diff --git a/.gitattributes b/.gitattributes index bcdeffc09a11..84b47a6fc56e 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,3 +1,4 @@ .github/ export-ignore +datafusion/core/tests/data/newlines_in_values.csv text eol=lf datafusion/proto/src/generated/prost.rs linguist-generated datafusion/proto/src/generated/pbjson.rs linguist-generated diff --git a/.github/actions/setup-builder/action.yaml b/.github/actions/setup-builder/action.yaml index 5578517ec359..0f45d51835f4 100644 --- a/.github/actions/setup-builder/action.yaml +++ b/.github/actions/setup-builder/action.yaml @@ -28,16 +28,18 @@ runs: - name: Install Build Dependencies shell: bash run: | - apt-get update - apt-get install -y protobuf-compiler + RETRY="ci/scripts/retry" + "${RETRY}" apt-get update + "${RETRY}" apt-get install -y protobuf-compiler - name: Setup Rust toolchain shell: bash # rustfmt is needed for the substrait build script run: | + RETRY="ci/scripts/retry" echo "Installing ${{ inputs.rust-version }}" - rustup toolchain install ${{ inputs.rust-version }} - rustup default ${{ inputs.rust-version }} - rustup component add rustfmt + "${RETRY}" rustup toolchain install ${{ inputs.rust-version }} + "${RETRY}" rustup default ${{ inputs.rust-version }} + "${RETRY}" rustup component add rustfmt - name: Configure rust runtime env uses: ./.github/actions/setup-rust-runtime - name: Fixup git permissions diff --git a/.github/workflows/dev_pr/labeler.yml b/.github/workflows/dev_pr/labeler.yml index 34a37948785b..4e44e47f5968 100644 --- a/.github/workflows/dev_pr/labeler.yml +++ b/.github/workflows/dev_pr/labeler.yml @@ -17,11 +17,11 @@ development-process: - changed-files: - - any-glob-to-any-file: ['dev/**.*', '.github/**.*', 'ci/**.*', '.asf.yaml'] + - any-glob-to-any-file: ['dev/**/*', '.github/**/*', 'ci/**/*', '.asf.yaml'] documentation: - changed-files: - - any-glob-to-any-file: ['docs/**.*', 'README.md', './**/README.md', 'DEVELOPERS.md', 'datafusion/docs/**.*'] + - any-glob-to-any-file: ['docs/**/*', 'README.md', './**/README.md', 'DEVELOPERS.md', 'datafusion/docs/**/*'] sql: - changed-files: @@ -33,16 +33,37 @@ logical-expr: physical-expr: - changed-files: - - any-glob-to-any-file: ['datafusion/physical-expr/**/*'] + - any-glob-to-any-file: ['datafusion/physical-expr/**/*', 'datafusion/physical-expr-common/**/*', 'datafusion/physical-expr-aggregate/**/*', 'datafusion/physical-plan/**/*'] + +catalog: + - changed-files: + - any-glob-to-any-file: ['datafusion/catalog/**/*'] + +common: + - changed-files: + - any-glob-to-any-file: ['datafusion/common/**/*', 'datafusion/common-runtime/**/*'] + +execution: + - changed-files: + - any-glob-to-any-file: ['datafusion/execution/**/*'] + +functions: + - changed-files: + - any-glob-to-any-file: ['datafusion/functions/**/*', 'datafusion/functions-aggregate/**/*', 'datafusion/functions-aggregate-common', 'datafusion/functions-nested'] + optimizer: - changed-files: - - any-glob-to-any-file: ['datafusion/optimizer/**/*'] + - any-glob-to-any-file: ['datafusion/optimizer/**/*', 'datafusion/physical-optimizer/**/*'] core: - changed-files: - any-glob-to-any-file: ['datafusion/core/**/*'] +proto: + - changed-files: + - any-glob-to-any-file: ['datafusion/proto/**/*', 'datafusion/proto-common/**/*'] + substrait: - changed-files: - any-glob-to-any-file: ['datafusion/substrait/**/*'] diff --git a/.github/workflows/large_files.yml b/.github/workflows/large_files.yml new file mode 100644 index 000000000000..aa96d55a0d85 --- /dev/null +++ b/.github/workflows/large_files.yml @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +name: Large files PR check + +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + +on: + pull_request: + +jobs: + check-files: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Check size of new Git objects + env: + # 1 MB ought to be enough for anybody. + # TODO in case we may want to consciously commit a bigger file to the repo without using Git LFS we may disable the check e.g. with a label + MAX_FILE_SIZE_BYTES: 1048576 + shell: bash + run: | + git rev-list --objects ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }} \ + > pull-request-objects.txt + exit_code=0 + while read -r id path; do + # Skip objects which are not files (commits, trees) + if [ ! -z "${path}" ]; then + size="$(git cat-file -s "${id}")" + if [ "${size}" -gt "${MAX_FILE_SIZE_BYTES}" ]; then + exit_code=1 + echo "Object ${id} [${path}] has size ${size}, exceeding ${MAX_FILE_SIZE_BYTES} limit." >&2 + echo "::error file=${path}::File ${path} has size ${size}, exceeding ${MAX_FILE_SIZE_BYTES} limit." + fi + fi + done < pull-request-objects.txt + exit "${exit_code}" diff --git a/.github/workflows/pr_benchmarks.yml b/.github/workflows/pr_benchmarks.yml deleted file mode 100644 index 5827c42e85ae..000000000000 --- a/.github/workflows/pr_benchmarks.yml +++ /dev/null @@ -1,101 +0,0 @@ -# Runs the benchmark command on the PR and -# on the branch, posting the results as a comment back the PR -name: Benchmarks - -on: - issue_comment: - -jobs: - benchmark: - name: Run Benchmarks - runs-on: ubuntu-latest - if: github.event.issue.pull_request && contains(github.event.comment.body, '/benchmark') - steps: - - name: Dump GitHub context - env: - GITHUB_CONTEXT: ${{ toJSON(github) }} - run: echo "$GITHUB_CONTEXT" - - - name: Checkout PR changes - uses: actions/checkout@v4 - with: - ref: refs/pull/${{ github.event.issue.number }}/head - - - name: Setup test data - # Workaround for `the input device is not a TTY`, appropriated from https://github.com/actions/runner/issues/241 - shell: 'script -q -e -c "bash -e {0}"' - run: | - cd benchmarks - mkdir data - - # Setup the TPC-H data sets for scale factors 1 and 10 - ./bench.sh data tpch - ./bench.sh data tpch10 - - - name: Generate unique result names - run: | - echo "HEAD_LONG_SHA=$(git log -1 --format='%H')" >> "$GITHUB_ENV" - echo "HEAD_SHORT_SHA=$(git log -1 --format='%h' --abbrev=7)" >> "$GITHUB_ENV" - echo "BASE_SHORT_SHA=$(echo "${{ github.sha }}" | cut -c1-7)" >> "$GITHUB_ENV" - - - name: Benchmark PR changes - env: - RESULTS_NAME: ${{ env.HEAD_SHORT_SHA }} - run: | - cd benchmarks - - ./bench.sh run tpch - ./bench.sh run tpch_mem - ./bench.sh run tpch10 - - # For some reason this step doesn't seem to propagate the env var down into the script - if [ -d "results/HEAD" ]; then - echo "Moving results into ${{ env.HEAD_SHORT_SHA }}" - mv results/HEAD results/${{ env.HEAD_SHORT_SHA }} - fi - - - name: Checkout base commit - uses: actions/checkout@v4 - with: - ref: ${{ github.sha }} - clean: false - - - name: Benchmark baseline and generate comparison message - env: - RESULTS_NAME: ${{ env.BASE_SHORT_SHA }} - run: | - cd benchmarks - - ./bench.sh run tpch - ./bench.sh run tpch_mem - ./bench.sh run tpch10 - - echo ${{ github.event.issue.number }} > pr - - pip3 install rich - cat > message.md < - Benchmarks comparing ${{ github.sha }} (main) and ${{ env.HEAD_LONG_SHA }} (PR) - - \`\`\` - $(./bench.sh compare ${{ env.BASE_SHORT_SHA }} ${{ env.HEAD_SHORT_SHA }}) - \`\`\` - - - EOF - - cat message.md - - - name: Upload benchmark comparison message - uses: actions/upload-artifact@v4 - with: - name: message - path: benchmarks/message.md - - - name: Upload PR number - uses: actions/upload-artifact@v4 - with: - name: pr - path: benchmarks/pr diff --git a/.github/workflows/pr_comment.yml b/.github/workflows/pr_comment.yml deleted file mode 100644 index 8b6df1c75687..000000000000 --- a/.github/workflows/pr_comment.yml +++ /dev/null @@ -1,53 +0,0 @@ -# Downloads any `message` artifacts created by other jobs -# and posts them as comments to the PR -name: PR Comment - -on: - workflow_run: - workflows: ["Benchmarks"] - types: - - completed - -jobs: - comment: - name: PR Comment - runs-on: ubuntu-latest - if: github.event.workflow_run.conclusion == 'success' - steps: - - name: Dump GitHub context - env: - GITHUB_CONTEXT: ${{ toJSON(github) }} - run: echo "$GITHUB_CONTEXT" - - - name: Download comment message - uses: actions/download-artifact@v4 - with: - name: message - run-id: ${{ github.event.workflow_run.id }} - github-token: ${{ secrets.GITHUB_TOKEN }} - - - name: Download pr number - uses: actions/download-artifact@v4 - with: - name: pr - run-id: ${{ github.event.workflow_run.id }} - github-token: ${{ secrets.GITHUB_TOKEN }} - - - name: Print message and pr number - run: | - cat pr - echo "PR_NUMBER=$(cat pr)" >> "$GITHUB_ENV" - cat message.md - - - name: Post the comment - uses: actions/github-script@v7 - with: - script: | - const fs = require('fs'); - const content = fs.readFileSync('message.md', 'utf8'); - github.rest.issues.createComment({ - issue_number: process.env.PR_NUMBER, - owner: context.repo.owner, - repo: context.repo.repo, - body: content, - }) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index ce4b4b06cf44..39b7b2b17857 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -62,8 +62,7 @@ jobs: ~/.cargo/git/db/ ./target/ ./datafusion-cli/target/ - # this key equals the ones on `linux-build-lib` for re-use - key: cargo-cache-benchmark-${{ hashFiles('datafusion/**/Cargo.toml', 'benchmarks/Cargo.toml', 'datafusion-cli/Cargo.toml') }} + key: cargo-cache-${{ hashFiles('**/Cargo.toml', '**/Cargo.lock') }} - name: Check datafusion without default features # Some of the test binaries require the parquet feature still @@ -90,8 +89,8 @@ jobs: # Ensure that the datafusion crate can be built with only a subset of the function # packages enabled. - - name: Check datafusion (array_expressions) - run: cargo check --no-default-features --features=array_expressions -p datafusion + - name: Check datafusion (nested_expressions) + run: cargo check --no-default-features --features=nested_expressions -p datafusion - name: Check datafusion (crypto) run: cargo check --no-default-features --features=crypto_expressions -p datafusion @@ -234,11 +233,7 @@ jobs: with: rust-version: stable - name: Run cargo doc - run: | - export RUSTDOCFLAGS="-D warnings -A rustdoc::private-intra-doc-links" - cargo doc --document-private-items --no-deps --workspace - cd datafusion-cli - cargo doc --document-private-items --no-deps + run: ci/scripts/rust_docs.sh linux-wasm-pack: name: build with wasm-pack @@ -526,7 +521,7 @@ jobs: run: taplo format --check config-docs-check: - name: check configs.md is up-to-date + name: check configs.md and ***_functions.md is up-to-date needs: [ linux-build-lib ] runs-on: ubuntu-latest container: @@ -547,6 +542,11 @@ jobs: # If you encounter an error, run './dev/update_config_docs.sh' and commit ./dev/update_config_docs.sh git diff --exit-code + - name: Check if any of the ***_functions.md has been modified + run: | + # If you encounter an error, run './dev/update_function_docs.sh' and commit + ./dev/update_function_docs.sh + git diff --exit-code # Verify MSRV for the crates which are directly used by other projects: # - datafusion @@ -567,18 +567,32 @@ jobs: - name: Check datafusion working-directory: datafusion/core run: | - # If you encounter an error with any of the commands below - # it means some crate in your dependency tree has a higher - # MSRV (Min Supported Rust Version) than the one specified - # in the `rust-version` key of `Cargo.toml`. Check your - # dependencies or update the version in `Cargo.toml` - cargo msrv verify + # If you encounter an error with any of the commands below it means + # your code or some crate in the dependency tree has a higher MSRV + # (Min Supported Rust Version) than the one specified in the + # `rust-version` key of `Cargo.toml`. + # + # To reproduce: + # 1. Install the version of Rust that is failing. Example: + # rustup install 1.79.0 + # 2. Run the command that failed with that version. Example: + # cargo +1.79.0 check -p datafusion + # + # To resolve, either: + # 1. Change your code to use older Rust features, + # 2. Revert dependency update + # 3. Update the MSRV version in `Cargo.toml` + # + # Please see the DataFusion Rust Version Compatibility Policy before + # updating Cargo.toml. You may have to update the code instead. + # https://github.com/apache/datafusion/blob/main/README.md#rust-version-compatibility-policy + cargo msrv --output-format json --log-target stdout verify - name: Check datafusion-substrait working-directory: datafusion/substrait - run: cargo msrv verify + run: cargo msrv --output-format json --log-target stdout verify - name: Check datafusion-proto working-directory: datafusion/proto - run: cargo msrv verify + run: cargo msrv --output-format json --log-target stdout verify - name: Check datafusion-cli working-directory: datafusion-cli - run: cargo msrv verify + run: cargo msrv --output-format json --log-target stdout verify \ No newline at end of file diff --git a/.gitignore b/.gitignore index 05479fd0f07d..05570eacf630 100644 --- a/.gitignore +++ b/.gitignore @@ -16,45 +16,11 @@ # under the License. apache-rat-*.jar -arrow-src.tar -arrow-src.tar.gz - -# Compiled source -*.a -*.dll -*.o -*.py[ocd] -*.so -*.so.* -*.bundle -*.dylib -.build_cache_dir -dependency-reduced-pom.xml -MANIFEST -compile_commands.json -build.ninja - -# Generated Visual Studio files -*.vcxproj -*.vcxproj.* -*.sln -*.iml # Linux perf sample data perf.data perf.data.old -cpp/.idea/ -.clangd/ -cpp/.clangd/ -cpp/apidoc/xml/ -docs/example.gz -docs/example1.dat -docs/example3.dat -python/.eggs/ -python/doc/ -# Egg metadata -*.egg-info .vscode .idea/ @@ -66,16 +32,9 @@ docker_cache .*.swp .*.swo -site/ - -# R files -**/.Rproj.user -**/*.Rcheck/ -**/.Rhistory -.Rproj.user +venv/* # macOS -cpp/Brewfile.lock.json .DS_Store # docker volumes used for caching @@ -90,9 +49,6 @@ rusty-tags.vi .history .flatbuffers/ -.vscode -venv/* - # apache release artifacts dev/dist diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml deleted file mode 100644 index 9d2d2d81d680..000000000000 --- a/.pre-commit-config.yaml +++ /dev/null @@ -1,69 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# To use this, install the python package `pre-commit` and -# run once `pre-commit install`. This will setup a git pre-commit-hook -# that is executed on each commit and will report the linting problems. -# To run all hooks on all files use `pre-commit run -a` - -repos: - - repo: local - hooks: - - id: rat - name: Release Audit Tool - language: system - entry: bash -c "git archive HEAD --prefix=apache-arrow/ --output=arrow-src.tar && ./dev/release/run-rat.sh arrow-src.tar" - always_run: true - pass_filenames: false - - id: rustfmt - name: Rust Format - language: system - entry: bash -c "cd rust && cargo +stable fmt --all -- --check" - files: ^rust/.*\.rs$ - types: - - file - - rust - - id: cmake-format - name: CMake Format - language: python - entry: python run-cmake-format.py - types: [cmake] - additional_dependencies: - - cmake_format==0.5.2 - - id: hadolint - name: Docker Format - language: docker_image - types: - - dockerfile - entry: --entrypoint /bin/hadolint hadolint/hadolint:latest - - exclude: ^dev/.*$ - - repo: git://github.com/pre-commit/pre-commit-hooks - sha: v1.2.3 - hooks: - - id: flake8 - name: Python Format - files: ^(python|dev|integration)/ - types: - - file - - python - - id: flake8 - name: Cython Format - files: ^python/ - types: - - file - - cython - args: [--config=python/.flake8.cython] diff --git a/CHANGELOG.md b/CHANGELOG.md index ea0c339ac451..c481ce0b96a0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ under the License. --> -* [DataFusion CHANGELOG](./datafusion/CHANGELOG.md) +Change logs for each release can be found [here](dev/changelog). + For older versions, see [apache/arrow/CHANGELOG.md](https://github.com/apache/arrow/blob/master/CHANGELOG.md). diff --git a/Cargo.toml b/Cargo.toml index 3ca3af284675..54bc68aa6329 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,42 +16,51 @@ # under the License. [workspace] +# datafusion-cli is excluded because of its Cargo.lock. See datafusion-cli/README.md. exclude = ["datafusion-cli", "dev/depcheck"] members = [ "datafusion/common", "datafusion/common-runtime", + "datafusion/catalog", "datafusion/core", "datafusion/expr", + "datafusion/expr-common", "datafusion/execution", - "datafusion/functions-aggregate", + "datafusion/ffi", "datafusion/functions", - "datafusion/functions-array", + "datafusion/functions-aggregate", + "datafusion/functions-aggregate-common", + "datafusion/functions-nested", + "datafusion/functions-window", + "datafusion/functions-window-common", "datafusion/optimizer", - "datafusion/physical-expr-common", "datafusion/physical-expr", + "datafusion/physical-expr-common", + "datafusion/physical-optimizer", "datafusion/physical-plan", "datafusion/proto", "datafusion/proto/gen", + "datafusion/proto-common", + "datafusion/proto-common/gen", "datafusion/sql", "datafusion/sqllogictest", "datafusion/substrait", "datafusion/wasmtest", "datafusion-examples", - "docs", "test-utils", "benchmarks", ] resolver = "2" [workspace.package] -authors = ["Apache Arrow "] +authors = ["Apache DataFusion "] edition = "2021" -homepage = "https://github.com/apache/datafusion" +homepage = "https://datafusion.apache.org" license = "Apache-2.0" readme = "README.md" repository = "https://github.com/apache/datafusion" -rust-version = "1.73" -version = "37.1.0" +rust-version = "1.79" +version = "43.0.0" [workspace.dependencies] # We turn off default-features for some dependencies here so the workspaces which inherit them can @@ -59,51 +68,80 @@ version = "37.1.0" # for the inherited dependency but cannot do the reverse (override from true to false). # # See for more detaiils: https://github.com/rust-lang/cargo/issues/11329 -arrow = { version = "51.0.0", features = ["prettyprint"] } -arrow-array = { version = "51.0.0", default-features = false, features = ["chrono-tz"] } -arrow-buffer = { version = "51.0.0", default-features = false } -arrow-flight = { version = "51.0.0", features = ["flight-sql-experimental"] } -arrow-ipc = { version = "51.0.0", default-features = false, features = ["lz4"] } -arrow-ord = { version = "51.0.0", default-features = false } -arrow-schema = { version = "51.0.0", default-features = false } -arrow-string = { version = "51.0.0", default-features = false } +ahash = { version = "0.8", default-features = false, features = [ + "runtime-rng", +] } +arrow = { git = "https://github.com/influxdata/arrow-rs", rev = "aa8c048", features = [ + "prettyprint", +] } +arrow-array = { git = "https://github.com/influxdata/arrow-rs", rev = "aa8c048", default-features = false, features = [ + "chrono-tz", +] } +arrow-buffer = { git = "https://github.com/influxdata/arrow-rs", rev = "aa8c048", default-features = false } +arrow-flight = { git = "https://github.com/influxdata/arrow-rs", rev = "aa8c048", features = [ + "flight-sql-experimental", +] } +arrow-ipc = { git = "https://github.com/influxdata/arrow-rs", rev = "aa8c048", default-features = false, features = [ + "lz4", +] } +arrow-ord = { git = "https://github.com/influxdata/arrow-rs", rev = "aa8c048", default-features = false } +arrow-schema = { git = "https://github.com/influxdata/arrow-rs", rev = "aa8c048", default-features = false } +arrow-string = { git = "https://github.com/influxdata/arrow-rs", rev = "aa8c048", default-features = false } async-trait = "0.1.73" bigdecimal = "=0.4.1" bytes = "1.4" -chrono = { version = "0.4.34", default-features = false } +chrono = { version = "0.4.38", default-features = false } ctor = "0.2.0" -dashmap = "5.4.0" -datafusion = { path = "datafusion/core", version = "37.1.0", default-features = false } -datafusion-common = { path = "datafusion/common", version = "37.1.0", default-features = false } -datafusion-common-runtime = { path = "datafusion/common-runtime", version = "37.1.0" } -datafusion-execution = { path = "datafusion/execution", version = "37.1.0" } -datafusion-expr = { path = "datafusion/expr", version = "37.1.0" } -datafusion-functions = { path = "datafusion/functions", version = "37.1.0" } -datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "37.1.0" } -datafusion-functions-array = { path = "datafusion/functions-array", version = "37.1.0" } -datafusion-optimizer = { path = "datafusion/optimizer", version = "37.1.0", default-features = false } -datafusion-physical-expr = { path = "datafusion/physical-expr", version = "37.1.0", default-features = false } -datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "37.1.0", default-features = false } -datafusion-physical-plan = { path = "datafusion/physical-plan", version = "37.1.0" } -datafusion-proto = { path = "datafusion/proto", version = "37.1.0" } -datafusion-sql = { path = "datafusion/sql", version = "37.1.0" } -datafusion-sqllogictest = { path = "datafusion/sqllogictest", version = "37.1.0" } -datafusion-substrait = { path = "datafusion/substrait", version = "37.1.0" } +dashmap = "6.0.1" +datafusion = { path = "datafusion/core", version = "43.0.0", default-features = false } +datafusion-catalog = { path = "datafusion/catalog", version = "43.0.0" } +datafusion-common = { path = "datafusion/common", version = "43.0.0", default-features = false } +datafusion-common-runtime = { path = "datafusion/common-runtime", version = "43.0.0" } +datafusion-execution = { path = "datafusion/execution", version = "43.0.0" } +datafusion-expr = { path = "datafusion/expr", version = "43.0.0" } +datafusion-expr-common = { path = "datafusion/expr-common", version = "43.0.0" } +datafusion-ffi = { path = "datafusion/ffi", version = "43.0.0" } +datafusion-functions = { path = "datafusion/functions", version = "43.0.0" } +datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", version = "43.0.0" } +datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "43.0.0" } +datafusion-functions-nested = { path = "datafusion/functions-nested", version = "43.0.0" } +datafusion-functions-window = { path = "datafusion/functions-window", version = "43.0.0" } +datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "43.0.0" } +datafusion-optimizer = { path = "datafusion/optimizer", version = "43.0.0", default-features = false } +datafusion-physical-expr = { path = "datafusion/physical-expr", version = "43.0.0", default-features = false } +datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "43.0.0", default-features = false } +datafusion-physical-optimizer = { path = "datafusion/physical-optimizer", version = "43.0.0" } +datafusion-physical-plan = { path = "datafusion/physical-plan", version = "43.0.0" } +datafusion-proto = { path = "datafusion/proto", version = "43.0.0" } +datafusion-proto-common = { path = "datafusion/proto-common", version = "43.0.0" } +datafusion-sql = { path = "datafusion/sql", version = "43.0.0" } +datafusion-sqllogictest = { path = "datafusion/sqllogictest", version = "43.0.0" } +datafusion-substrait = { path = "datafusion/substrait", version = "43.0.0" } doc-comment = "0.3" env_logger = "0.11" futures = "0.3" half = { version = "2.2.1", default-features = false } +hashbrown = { version = "0.14.5", features = ["raw"] } indexmap = "2.0.0" -itertools = "0.12" +itertools = "0.13" log = "^0.4" num_cpus = "1.13.0" -object_store = { version = "0.9.1", default-features = false } +object_store = { version = "0.11.0", default-features = false } parking_lot = "0.12" -parquet = { version = "51.0.0", default-features = false, features = ["arrow", "async", "object_store"] } +parquet = { git = "https://github.com/influxdata/arrow-rs", rev = "aa8c048", default-features = false, features = [ + "arrow", + "async", + "object_store", +] } +pbjson = { version = "0.7.0" } +# Should match arrow-flight's version of prost. +prost = "0.12.3" +prost-derive = "0.12.3" rand = "0.8" -rstest = "0.19.0" +regex = "1.8" +rstest = "0.23.0" serde_json = "1" -sqlparser = { version = "0.45.0", features = ["visitor"] } +sqlparser = { version = "0.51.0", features = ["visitor"] } tempfile = "3" thiserror = "1.0.44" tokio = { version = "1.36", features = ["macros", "rt", "sync"] } @@ -132,4 +170,5 @@ rpath = false large_futures = "warn" [workspace.lints.rust] -unused_imports = "deny" +unexpected_cfgs = { level = "warn", check-cfg = ["cfg(tarpaulin)"] } +unused_qualifications = "deny" diff --git a/README.md b/README.md index 8c2392850953..f89935d597c2 100644 --- a/README.md +++ b/README.md @@ -33,16 +33,45 @@ [discord-badge]: https://img.shields.io/discord/885562378132000778.svg?logo=discord&style=flat-square [discord-url]: https://discord.com/invite/Qw5gKqHxUM -[Website](https://github.com/apache/datafusion) | -[Guides](https://github.com/apache/datafusion/tree/main/docs) | +[Website](https://datafusion.apache.org/) | [API Docs](https://docs.rs/datafusion/latest/datafusion/) | [Chat](https://discord.com/channels/885562378132000778/885562378132000781) -logo - -Apache DataFusion is a very fast, extensible query engine for building high-quality data-centric systems in -[Rust](http://rustlang.org), using the [Apache Arrow](https://arrow.apache.org) -in-memory format. [Python Bindings](https://github.com/apache/datafusion-python) are also available. DataFusion offers SQL and Dataframe APIs, excellent [performance](https://benchmark.clickhouse.com/), built-in support for CSV, Parquet, JSON, and Avro, extensive customization, and a great community. + + logo + + +DataFusion is an extensible query engine written in [Rust] that +uses [Apache Arrow] as its in-memory format. + +This crate provides libraries and binaries for developers building fast and +feature rich database and analytic systems, customized to particular workloads. +See [use cases] for examples. The following related subprojects target end users: + +- [DataFusion Python](https://github.com/apache/datafusion-python/) offers a Python interface for SQL and DataFrame + queries. +- [DataFusion Ray](https://github.com/apache/datafusion-ray/) provides a distributed version of DataFusion that scales + out on Ray clusters. +- [DataFusion Comet](https://github.com/apache/datafusion-comet/) is an accelerator for Apache Spark based on + DataFusion. + +"Out of the box," +DataFusion offers [SQL] and [`Dataframe`] APIs, excellent [performance], +built-in support for CSV, Parquet, JSON, and Avro, extensive customization, and +a great community. + +DataFusion features a full query planner, a columnar, streaming, multi-threaded, +vectorized execution engine, and partitioned data sources. You can +customize DataFusion at almost all points including additional data sources, +query languages, functions, custom operators and more. +See the [Architecture] section for more details. + +[rust]: http://rustlang.org +[apache arrow]: https://arrow.apache.org +[use cases]: https://datafusion.apache.org/user-guide/introduction.html#use-cases +[python bindings]: https://github.com/apache/datafusion-python +[performance]: https://benchmark.clickhouse.com/ +[architecture]: https://datafusion.apache.org/contributor-guide/architecture.html Here are links to some important information @@ -51,7 +80,7 @@ Here are links to some important information - [Rust Getting Started](https://datafusion.apache.org/user-guide/example-usage.html) - [Rust DataFrame API](https://datafusion.apache.org/user-guide/dataframe.html) - [Rust API docs](https://docs.rs/datafusion/latest/datafusion) -- [Rust Examples](https://github.com/apache/datafusion/tree/master/datafusion-examples) +- [Rust Examples](https://github.com/apache/datafusion/tree/main/datafusion-examples) - [Python DataFrame API](https://arrow.apache.org/datafusion-python/) - [Architecture](https://docs.rs/datafusion/latest/datafusion/index.html#architecture) @@ -75,7 +104,7 @@ This crate has several [features] which can be specified in your `Cargo.toml`. Default features: -- `array_expressions`: functions for working with arrays such as `array_to_string` +- `nested_expressions`: functions for working with nested type function such as `array_to_string` - `compression`: reading files compressed with `xz2`, `bzip2`, `flate2`, and `zstd` - `crypto_expressions`: cryptographic functions such as `md5` and `sha256` - `datetime_expressions`: date and time functions such as `to_timestamp` @@ -97,9 +126,16 @@ Optional features: ## Rust Version Compatibility Policy -DataFusion's Minimum Required Stable Rust Version (MSRV) policy is to support -each stable Rust version for 6 months after it is -[released](https://github.com/rust-lang/rust/blob/master/RELEASES.md). This -generally translates to support for the most recent 3 to 4 stable Rust versions. +DataFusion's Minimum Required Stable Rust Version (MSRV) policy is to support stable [4 latest +Rust versions](https://releases.rs) OR the stable minor Rust version as of 4 months, whichever is lower. + +For example, given the releases `1.78.0`, `1.79.0`, `1.80.0`, `1.80.1` and `1.81.0` DataFusion will support 1.78.0, which is 3 minor versions prior to the most minor recent `1.81`. + +If a hotfix is released for the minimum supported Rust version (MSRV), the MSRV will be the minor version with all hotfixes, even if it surpasses the four-month window. We enforce this policy using a [MSRV CI Check](https://github.com/search?q=repo%3Aapache%2Fdatafusion+rust-version+language%3ATOML+path%3A%2F%5ECargo.toml%2F&type=code) + +## DataFusion API evolution policy + +Public methods in Apache DataFusion are subject to evolve as part of the API lifecycle. +Deprecated methods will be phased out in accordance with the [policy](https://datafusion.apache.org/library-user-guide/api-health.html), ensuring the API is stable and healthy. diff --git a/benchmarks/.gitignore b/benchmarks/.gitignore index 2c574ff30d12..c35b1a7c1944 100644 --- a/benchmarks/.gitignore +++ b/benchmarks/.gitignore @@ -1,2 +1,3 @@ data -results \ No newline at end of file +results +venv diff --git a/benchmarks/README.md b/benchmarks/README.md index cb31c47d9b0b..a9aa1afb97a1 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -67,6 +67,13 @@ Create / download a specific dataset (TPCH) Data is placed in the `data` subdirectory. +## Select join algorithm +The benchmark runs with `prefer_hash_join == true` by default, which enforces HASH join algorithm. +To run TPCH benchmarks with join other than HASH: +```shell +PREFER_HASH_JOIN=false ./bench.sh run tpch +``` + ## Comparing performance of main and a branch ```shell @@ -150,7 +157,7 @@ Benchmark tpch_mem.json └──────────────┴──────────────┴──────────────┴───────────────┘ ``` -Note that you can also execute an automatic comparison of the changes in a given PR against the base +Note that you can also execute an automatic comparison of the changes in a given PR against the base just by including the trigger `/benchmark` in any comment. ### Running Benchmarks Manually @@ -177,7 +184,6 @@ The benchmark program also supports CSV and Parquet input file formats and a uti ```bash cargo run --release --bin tpch -- convert --input ./data --output /mnt/tpch-parquet --format parquet ``` - Or if you want to verify and run all the queries in the benchmark, you can just run `cargo test`. ### Comparing results between runs @@ -261,7 +267,7 @@ SUBCOMMANDS: # Benchmarks -The output of `dfbench` help includes a descripion of each benchmark, which is reproducedd here for convenience +The output of `dfbench` help includes a description of each benchmark, which is reproduced here for convenience ## ClickBench @@ -324,6 +330,16 @@ steps. The tests sort the entire dataset using several different sort orders. +## IMDB + +Run Join Order Benchmark (JOB) on IMDB dataset. + +The Internet Movie Database (IMDB) dataset contains real-world movie data. Unlike synthetic datasets like TPCH, which assume uniform data distribution and uncorrelated columns, the IMDB dataset includes skewed data and correlated columns (which are common for real dataset), making it more suitable for testing query optimizers, particularly for cardinality estimation. + +This benchmark is derived from [Join Order Benchmark](https://github.com/gregrahn/join-order-benchmark). + +See paper [How Good Are Query Optimizers, Really](http://www.vldb.org/pvldb/vol9/p204-leis.pdf) for more details. + ## TPCH Run the tpch benchmark. @@ -336,6 +352,34 @@ This benchmarks is derived from the [TPC-H][1] version [2]: https://github.com/databricks/tpch-dbgen.git, [2.17.1]: https://www.tpc.org/tpc_documents_current_versions/pdf/tpc-h_v2.17.1.pdf +## External Aggregation + +Run the benchmark for aggregations with limited memory. + +When the memory limit is exceeded, the aggregation intermediate results will be spilled to disk, and finally read back with sort-merge. + +External aggregation benchmarks run several aggregation queries with different memory limits, on TPCH `lineitem` table. Queries can be found in [`external_aggr.rs`](src/bin/external_aggr.rs). + +This benchmark is inspired by [DuckDB's external aggregation paper](https://hannes.muehleisen.org/publications/icde2024-out-of-core-kuiper-boncz-muehleisen.pdf), specifically Section VI. + +### External Aggregation Example Runs +1. Run all queries with predefined memory limits: +```bash +# Under 'benchmarks/' directory +cargo run --release --bin external_aggr -- benchmark -n 4 --iterations 3 -p '....../data/tpch_sf1' -o '/tmp/aggr.json' +``` + +2. Run a query with specific memory limit: +```bash +cargo run --release --bin external_aggr -- benchmark -n 4 --iterations 3 -p '....../data/tpch_sf1' -o '/tmp/aggr.json' --query 1 --memory-limit 30M +``` + +3. Run all queries with `bench.sh` script: +```bash +./bench.sh data external_aggr +./bench.sh run external_aggr +``` + # Older Benchmarks diff --git a/benchmarks/bench.sh b/benchmarks/bench.sh index 088edc56dfb0..47c5d1261605 100755 --- a/benchmarks/bench.sh +++ b/benchmarks/bench.sh @@ -34,8 +34,9 @@ COMMAND= BENCHMARK=all DATAFUSION_DIR=${DATAFUSION_DIR:-$SCRIPT_DIR/..} DATA_DIR=${DATA_DIR:-$SCRIPT_DIR/data} -#CARGO_COMMAND=${CARGO_COMMAND:-"cargo run --release"} -CARGO_COMMAND=${CARGO_COMMAND:-"cargo run --profile release-nonlto"} # for faster iterations +CARGO_COMMAND=${CARGO_COMMAND:-"cargo run --release"} +PREFER_HASH_JOIN=${PREFER_HASH_JOIN:-true} +VIRTUAL_ENV=${VIRTUAL_ENV:-$SCRIPT_DIR/venv} usage() { echo " @@ -45,6 +46,7 @@ Usage: $0 data [benchmark] $0 run [benchmark] $0 compare +$0 venv ********** Examples: @@ -52,8 +54,8 @@ Examples: # Create the datasets for all benchmarks in $DATA_DIR ./bench.sh data -# Run the 'tpch' benchmark on the datafusion checkout in /source/arrow-datafusion -DATAFUSION_DIR=/source/arrow-datafusion ./bench.sh run tpch +# Run the 'tpch' benchmark on the datafusion checkout in /source/datafusion +DATAFUSION_DIR=/source/datafusion ./bench.sh run tpch ********** * Commands @@ -61,28 +63,32 @@ DATAFUSION_DIR=/source/arrow-datafusion ./bench.sh run tpch data: Generates or downloads data needed for benchmarking run: Runs the named benchmark compare: Compares results from benchmark runs +venv: Creates new venv (unless already exists) and installs compare's requirements into it ********** * Benchmarks ********** all(default): Data/Run/Compare for all benchmarks -tpch: TPCH inspired benchmark on Scale Factor (SF) 1 (~1GB), single parquet file per table +tpch: TPCH inspired benchmark on Scale Factor (SF) 1 (~1GB), single parquet file per table, hash join tpch_mem: TPCH inspired benchmark on Scale Factor (SF) 1 (~1GB), query from memory -tpch10: TPCH inspired benchmark on Scale Factor (SF) 10 (~10GB), single parquet file per table -tpch10_mem: TPCH inspired benchmark on Scale Factor (SF) 10 (~10GB), query from memory +tpch10: TPCH inspired benchmark on Scale Factor (SF) 10 (~10GB), single parquet file per table, hash join +tpch_mem10: TPCH inspired benchmark on Scale Factor (SF) 10 (~10GB), query from memory parquet: Benchmark of parquet reader's filtering speed sort: Benchmark of sorting speed clickbench_1: ClickBench queries against a single parquet file clickbench_partitioned: ClickBench queries against a partitioned (100 files) parquet -clickbench_extended: ClickBench "inspired" queries against a single parquet (DataFusion specific) +clickbench_extended: ClickBench \"inspired\" queries against a single parquet (DataFusion specific) +external_aggr: External aggregation benchmark ********** * Supported Configuration (Environment Variables) ********** -DATA_DIR directory to store datasets -CARGO_COMMAND command that runs the benchmark binary -DATAFUSION_DIR directory to use (default $DATAFUSION_DIR) -RESULTS_NAME folder where the benchmark files are stored +DATA_DIR directory to store datasets +CARGO_COMMAND command that runs the benchmark binary +DATAFUSION_DIR directory to use (default $DATAFUSION_DIR) +RESULTS_NAME folder where the benchmark files are stored +PREFER_HASH_JOIN Prefer hash join algorithm (default true) +VENV_PATH Python venv to use for compare and venv commands (default ./venv, override by /bin/activate) " exit 1 } @@ -101,7 +107,7 @@ while [[ $# -gt 0 ]]; do shift # past argument usage ;; - -*|--*) + -*) echo "Unknown option $1" exit 1 ;; @@ -129,6 +135,7 @@ main() { echo "BENCHMARK: ${BENCHMARK}" echo "DATA_DIR: ${DATA_DIR}" echo "CARGO_COMMAND: ${CARGO_COMMAND}" + echo "PREFER_HASH_JOIN: ${PREFER_HASH_JOIN}" echo "***************************" case "$BENCHMARK" in all) @@ -136,6 +143,7 @@ main() { data_tpch "10" data_clickbench_1 data_clickbench_partitioned + data_imdb ;; tpch) data_tpch "1" @@ -160,6 +168,13 @@ main() { clickbench_extended) data_clickbench_1 ;; + imdb) + data_imdb + ;; + external_aggr) + # same data as for tpch + data_tpch "1" + ;; *) echo "Error: unknown benchmark '$BENCHMARK' for data generation" usage @@ -169,7 +184,7 @@ main() { run) # Parse positional parameters BENCHMARK=${ARG2:-"${BENCHMARK}"} - BRANCH_NAME=$(cd ${DATAFUSION_DIR} && git rev-parse --abbrev-ref HEAD) + BRANCH_NAME=$(cd "${DATAFUSION_DIR}" && git rev-parse --abbrev-ref HEAD) BRANCH_NAME=${BRANCH_NAME//\//_} # mind blowing syntax to replace / with _ RESULTS_NAME=${RESULTS_NAME:-"${BRANCH_NAME}"} RESULTS_DIR=${RESULTS_DIR:-"$SCRIPT_DIR/results/$RESULTS_NAME"} @@ -183,6 +198,7 @@ main() { echo "DATA_DIR: ${DATA_DIR}" echo "RESULTS_DIR: ${RESULTS_DIR}" echo "CARGO_COMMAND: ${CARGO_COMMAND}" + echo "PREFER_HASH_JOIN: ${PREFER_HASH_JOIN}" echo "***************************" # navigate to the appropriate directory @@ -200,6 +216,8 @@ main() { run_clickbench_1 run_clickbench_partitioned run_clickbench_extended + run_imdb + run_external_aggr ;; tpch) run_tpch "1" @@ -213,12 +231,6 @@ main() { tpch_mem10) run_tpch_mem "10" ;; - tpch_smj) - run_tpch_smj "1" - ;; - tpch_smj10) - run_tpch_smj "10" - ;; parquet) run_parquet ;; @@ -234,6 +246,12 @@ main() { clickbench_extended) run_clickbench_extended ;; + imdb) + run_imdb + ;; + external_aggr) + run_external_aggr + ;; *) echo "Error: unknown benchmark '$BENCHMARK' for run" usage @@ -243,9 +261,10 @@ main() { echo "Done" ;; compare) - BRANCH1=$1 - BRANCH2=$2 - compare_benchmarks + compare_benchmarks "$ARG2" "$ARG3" + ;; + venv) + setup_venv ;; "") usage @@ -286,7 +305,7 @@ data_tpch() { echo " tbl files exist ($FILE exists)." else echo " creating tbl files with tpch_dbgen..." - docker run -v "${TPCH_DIR}":/data -it --rm ghcr.io/scalytics/tpch-docker:main -vf -s ${SCALE_FACTOR} + docker run -v "${TPCH_DIR}":/data -it --rm ghcr.io/scalytics/tpch-docker:main -vf -s "${SCALE_FACTOR}" fi # Copy expected answers into the ./data/answers directory if it does not already exist @@ -323,22 +342,7 @@ run_tpch() { RESULTS_FILE="${RESULTS_DIR}/tpch_sf${SCALE_FACTOR}.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running tpch benchmark..." - $CARGO_COMMAND --bin tpch -- benchmark datafusion --iterations 5 --path "${TPCH_DIR}" --format parquet -o ${RESULTS_FILE} -} - -# Runs the tpch benchmark with sort merge join -run_tpch_smj() { - SCALE_FACTOR=$1 - if [ -z "$SCALE_FACTOR" ] ; then - echo "Internal error: Scale factor not specified" - exit 1 - fi - TPCH_DIR="${DATA_DIR}/tpch_sf${SCALE_FACTOR}" - - RESULTS_FILE="${RESULTS_DIR}/tpch_smj_sf${SCALE_FACTOR}.json" - echo "RESULTS_FILE: ${RESULTS_FILE}" - echo "Running tpch SMJ benchmark..." - $CARGO_COMMAND --bin tpch -- benchmark datafusion --iterations 5 --path "${TPCH_DIR}" --prefer_hash_join false --format parquet -o ${RESULTS_FILE} + $CARGO_COMMAND --bin tpch -- benchmark datafusion --iterations 5 --path "${TPCH_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --format parquet -o "${RESULTS_FILE}" } # Runs the tpch in memory @@ -354,7 +358,7 @@ run_tpch_mem() { echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running tpch_mem benchmark..." # -m means in memory - $CARGO_COMMAND --bin tpch -- benchmark datafusion --iterations 5 --path "${TPCH_DIR}" -m --format parquet -o ${RESULTS_FILE} + $CARGO_COMMAND --bin tpch -- benchmark datafusion --iterations 5 --path "${TPCH_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" -m --format parquet -o "${RESULTS_FILE}" } # Runs the parquet filter benchmark @@ -362,7 +366,7 @@ run_parquet() { RESULTS_FILE="${RESULTS_DIR}/parquet.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running parquet filter benchmark..." - $CARGO_COMMAND --bin parquet -- filter --path "${DATA_DIR}" --scale-factor 1.0 --iterations 5 -o ${RESULTS_FILE} + $CARGO_COMMAND --bin parquet -- filter --path "${DATA_DIR}" --scale-factor 1.0 --iterations 5 -o "${RESULTS_FILE}" } # Runs the sort benchmark @@ -370,7 +374,7 @@ run_sort() { RESULTS_FILE="${RESULTS_DIR}/sort.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running sort benchmark..." - $CARGO_COMMAND --bin parquet -- sort --path "${DATA_DIR}" --scale-factor 1.0 --iterations 5 -o ${RESULTS_FILE} + $CARGO_COMMAND --bin parquet -- sort --path "${DATA_DIR}" --scale-factor 1.0 --iterations 5 -o "${RESULTS_FILE}" } @@ -382,7 +386,7 @@ data_clickbench_1() { pushd "${DATA_DIR}" > /dev/null # Avoid downloading if it already exists and is the right size - OUTPUT_SIZE=`wc -c hits.parquet 2>/dev/null | awk '{print $1}' || true` + OUTPUT_SIZE=$(wc -c hits.parquet 2>/dev/null | awk '{print $1}' || true) echo -n "Checking hits.parquet..." if test "${OUTPUT_SIZE}" = "14779976446"; then echo -n "... found ${OUTPUT_SIZE} bytes ..." @@ -406,7 +410,7 @@ data_clickbench_partitioned() { pushd "${DATA_DIR}/hits_partitioned" > /dev/null echo -n "Checking hits_partitioned..." - OUTPUT_SIZE=`wc -c * 2>/dev/null | tail -n 1 | awk '{print $1}' || true` + OUTPUT_SIZE=$(wc -c -- * 2>/dev/null | tail -n 1 | awk '{print $1}' || true) if test "${OUTPUT_SIZE}" = "14737666736"; then echo -n "... found ${OUTPUT_SIZE} bytes ..." else @@ -424,7 +428,7 @@ run_clickbench_1() { RESULTS_FILE="${RESULTS_DIR}/clickbench_1.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running clickbench (1 file) benchmark..." - $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries.sql" -o ${RESULTS_FILE} + $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries.sql" -o "${RESULTS_FILE}" } # Runs the clickbench benchmark with the partitioned parquet files @@ -432,7 +436,7 @@ run_clickbench_partitioned() { RESULTS_FILE="${RESULTS_DIR}/clickbench_partitioned.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running clickbench (partitioned, 100 files) benchmark..." - $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits_partitioned" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries.sql" -o ${RESULTS_FILE} + $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits_partitioned" --queries-path "${SCRIPT_DIR}/queries/clickbench/queries.sql" -o "${RESULTS_FILE}" } # Runs the clickbench "extended" benchmark with a single large parquet file @@ -440,14 +444,116 @@ run_clickbench_extended() { RESULTS_FILE="${RESULTS_DIR}/clickbench_extended.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running clickbench (1 file) extended benchmark..." - $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/extended.sql" -o ${RESULTS_FILE} + $CARGO_COMMAND --bin dfbench -- clickbench --iterations 5 --path "${DATA_DIR}/hits.parquet" --queries-path "${SCRIPT_DIR}/queries/clickbench/extended.sql" -o "${RESULTS_FILE}" +} + +# Downloads the csv.gz files IMDB datasets from Peter Boncz's homepage(one of the JOB paper authors) +# http://homepages.cwi.nl/~boncz/job/imdb.tgz +data_imdb() { + local imdb_dir="${DATA_DIR}/imdb" + local imdb_temp_gz="${imdb_dir}/imdb.tgz" + local imdb_url="https://homepages.cwi.nl/~boncz/job/imdb.tgz" + + # imdb has 21 files, we just separate them into 3 groups for better readability + local first_required_files=( + "aka_name.parquet" + "aka_title.parquet" + "cast_info.parquet" + "char_name.parquet" + "comp_cast_type.parquet" + "company_name.parquet" + "company_type.parquet" + ) + + local second_required_files=( + "complete_cast.parquet" + "info_type.parquet" + "keyword.parquet" + "kind_type.parquet" + "link_type.parquet" + "movie_companies.parquet" + "movie_info.parquet" + ) + + local third_required_files=( + "movie_info_idx.parquet" + "movie_keyword.parquet" + "movie_link.parquet" + "name.parquet" + "person_info.parquet" + "role_type.parquet" + "title.parquet" + ) + + # Combine the three arrays into one + local required_files=("${first_required_files[@]}" "${second_required_files[@]}" "${third_required_files[@]}") + local convert_needed=false + + # Create directory if it doesn't exist + mkdir -p "${imdb_dir}" + + # Check if required files exist + for file in "${required_files[@]}"; do + if [ ! -f "${imdb_dir}/${file}" ]; then + convert_needed=true + break + fi + done + + if [ "$convert_needed" = true ]; then + if [ ! -f "${imdb_dir}/imdb.tgz" ]; then + echo "Downloading IMDB dataset..." + + # Download the dataset + curl -o "${imdb_temp_gz}" "${imdb_url}" + + # Extract the dataset + tar -xzvf "${imdb_temp_gz}" -C "${imdb_dir}" + $CARGO_COMMAND --bin imdb -- convert --input ${imdb_dir} --output ${imdb_dir} --format parquet + else + echo "IMDB.tgz already exists." + + # Extract the dataset + tar -xzvf "${imdb_temp_gz}" -C "${imdb_dir}" + $CARGO_COMMAND --bin imdb -- convert --input ${imdb_dir} --output ${imdb_dir} --format parquet + fi + echo "IMDB dataset downloaded and extracted." + else + echo "IMDB dataset already exists and contains required parquet files." + fi +} + +# Runs the imdb benchmark +run_imdb() { + IMDB_DIR="${DATA_DIR}/imdb" + + RESULTS_FILE="${RESULTS_DIR}/imdb.json" + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running imdb benchmark..." + $CARGO_COMMAND --bin imdb -- benchmark datafusion --iterations 5 --path "${IMDB_DIR}" --prefer_hash_join "${PREFER_HASH_JOIN}" --format parquet -o "${RESULTS_FILE}" +} + +# Runs the external aggregation benchmark +run_external_aggr() { + # Use TPC-H SF1 dataset + TPCH_DIR="${DATA_DIR}/tpch_sf1" + RESULTS_FILE="${RESULTS_DIR}/external_aggr.json" + echo "RESULTS_FILE: ${RESULTS_FILE}" + echo "Running external aggregation benchmark..." + + # Only parquet is supported. + # Since per-operator memory limit is calculated as (total-memory-limit / + # number-of-partitions), and by default `--partitions` is set to number of + # CPU cores, we set a constant number of partitions to prevent this + # benchmark to fail on some machines. + $CARGO_COMMAND --bin external_aggr -- benchmark --partitions 4 --iterations 5 --path "${TPCH_DIR}" -o "${RESULTS_FILE}" } compare_benchmarks() { BASE_RESULTS_DIR="${SCRIPT_DIR}/results" - BRANCH1="${ARG2}" - BRANCH2="${ARG3}" + BRANCH1="$1" + BRANCH2="$2" if [ -z "$BRANCH1" ] ; then echo " not specified. Available branches:" ls -1 "${BASE_RESULTS_DIR}" @@ -461,14 +567,14 @@ compare_benchmarks() { fi echo "Comparing ${BRANCH1} and ${BRANCH2}" - for bench in `ls ${BASE_RESULTS_DIR}/${BRANCH1}` ; do - RESULTS_FILE1="${BASE_RESULTS_DIR}/${BRANCH1}/${bench}" - RESULTS_FILE2="${BASE_RESULTS_DIR}/${BRANCH2}/${bench}" + for RESULTS_FILE1 in "${BASE_RESULTS_DIR}/${BRANCH1}"/*.json ; do + BENCH=$(basename "${RESULTS_FILE1}") + RESULTS_FILE2="${BASE_RESULTS_DIR}/${BRANCH2}/${BENCH}" if test -f "${RESULTS_FILE2}" ; then echo "--------------------" - echo "Benchmark ${bench}" + echo "Benchmark ${BENCH}" echo "--------------------" - python3 "${SCRIPT_DIR}"/compare.py "${RESULTS_FILE1}" "${RESULTS_FILE2}" + PATH=$VIRTUAL_ENV/bin:$PATH python3 "${SCRIPT_DIR}"/compare.py "${RESULTS_FILE1}" "${RESULTS_FILE2}" else echo "Note: Skipping ${RESULTS_FILE1} as ${RESULTS_FILE2} does not exist" fi @@ -476,5 +582,10 @@ compare_benchmarks() { } +setup_venv() { + python3 -m venv "$VIRTUAL_ENV" + PATH=$VIRTUAL_ENV/bin:$PATH python3 -m pip install -r requirements.txt +} + # And start the process up main diff --git a/benchmarks/compare.py b/benchmarks/compare.py index ec2b28fa0556..2574c0735ca8 100755 --- a/benchmarks/compare.py +++ b/benchmarks/compare.py @@ -29,7 +29,7 @@ from rich.console import Console from rich.table import Table except ImportError: - print("Try `pip install rich` for using this script.") + print("Couldn't import modules -- run `./bench.sh venv` first") raise diff --git a/benchmarks/queries/clickbench/README.md b/benchmarks/queries/clickbench/README.md index 29b1a7588f17..6797797409c1 100644 --- a/benchmarks/queries/clickbench/README.md +++ b/benchmarks/queries/clickbench/README.md @@ -14,7 +14,7 @@ ClickBench is focused on aggregation and filtering performance (though it has no The "extended" queries are not part of the official ClickBench benchmark. Instead they are used to test other DataFusion features that are not covered by -the standard benchmark Each description below is for the corresponding line in +the standard benchmark. Each description below is for the corresponding line in `extended.sql` (line 1 is `Q0`, line 2 is `Q1`, etc.) ### Q0: Data Exploration @@ -58,6 +58,77 @@ LIMIT 10; ``` +### Q3: What is the income distribution for users in specific regions + +**Question**: "What regions and social networks have the highest variance of parameter price?" + +**Important Query Properties**: STDDEV and VAR aggregation functions, GROUP BY multiple small ints + +```sql +SELECT "SocialSourceNetworkID", "RegionID", COUNT(*), AVG("Age"), AVG("ParamPrice"), STDDEV("ParamPrice") as s, VAR("ParamPrice") +FROM 'hits.parquet' +GROUP BY "SocialSourceNetworkID", "RegionID" +HAVING s IS NOT NULL +ORDER BY s DESC +LIMIT 10; +``` + +### Q4: Response start time distribution analysis (median) + +**Question**: Find the WatchIDs with the highest median "ResponseStartTiming" without Java enabled + +**Important Query Properties**: MEDIAN, functions, high cardinality grouping that skips intermediate aggregation + +Note this query is somewhat synthetic as "WatchID" is almost unique (there are a few duplicates) + +```sql +SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, MEDIAN("ResponseStartTiming") tmed, MAX("ResponseStartTiming") tmax +FROM 'hits.parquet' +WHERE "JavaEnable" = 0 -- filters to 32M of 100M rows +GROUP BY "ClientIP", "WatchID" +HAVING c > 1 +ORDER BY tmed DESC +LIMIT 10; +``` + +Results look like + ++-------------+---------------------+---+------+------+------+ +| ClientIP | WatchID | c | tmin | tmed | tmax | ++-------------+---------------------+---+------+------+------+ +| 1611957945 | 6655575552203051303 | 2 | 0 | 0 | 0 | +| -1402644643 | 8566928176839891583 | 2 | 0 | 0 | 0 | ++-------------+---------------------+---+------+------+------+ + + +### Q5: Response start time distribution analysis (p95) + +**Question**: Find the WatchIDs with the highest p95 "ResponseStartTiming" without Java enabled + +**Important Query Properties**: APPROX_PERCENTILE_CONT, functions, high cardinality grouping that skips intermediate aggregation + +Note this query is somewhat synthetic as "WatchID" is almost unique (there are a few duplicates) + +```sql +SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, APPROX_PERCENTILE_CONT("ResponseStartTiming", 0.95) tp95, MAX("ResponseStartTiming") tmax +FROM 'hits.parquet' +WHERE "JavaEnable" = 0 -- filters to 32M of 100M rows +GROUP BY "ClientIP", "WatchID" +HAVING c > 1 +ORDER BY tp95 DESC +LIMIT 10; +``` + +Results look like + ++-------------+---------------------+---+------+------+------+ +| ClientIP | WatchID | c | tmin | tp95 | tmax | ++-------------+---------------------+---+------+------+------+ +| 1611957945 | 6655575552203051303 | 2 | 0 | 0 | 0 | +| -1402644643 | 8566928176839891583 | 2 | 0 | 0 | 0 | ++-------------+---------------------+---+------+------+------+ + + ## Data Notes Here are some interesting statistics about the data used in the queries diff --git a/benchmarks/queries/clickbench/extended.sql b/benchmarks/queries/clickbench/extended.sql index 0a2999fceb49..fbabaf2a7021 100644 --- a/benchmarks/queries/clickbench/extended.sql +++ b/benchmarks/queries/clickbench/extended.sql @@ -1,3 +1,6 @@ SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DISTINCT "MobilePhoneModel") FROM hits; SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage") FROM hits; -SELECT "BrowserCountry", COUNT(DISTINCT "SocialNetwork"), COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserLanguage"), COUNT(DISTINCT "SocialAction") FROM hits GROUP BY 1 ORDER BY 2 DESC LIMIT 10; \ No newline at end of file +SELECT "BrowserCountry", COUNT(DISTINCT "SocialNetwork"), COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserLanguage"), COUNT(DISTINCT "SocialAction") FROM hits GROUP BY 1 ORDER BY 2 DESC LIMIT 10; +SELECT "SocialSourceNetworkID", "RegionID", COUNT(*), AVG("Age"), AVG("ParamPrice"), STDDEV("ParamPrice") as s, VAR("ParamPrice") FROM hits GROUP BY "SocialSourceNetworkID", "RegionID" HAVING s IS NOT NULL ORDER BY s DESC LIMIT 10; +SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, MEDIAN("ResponseStartTiming") tmed, MAX("ResponseStartTiming") tmax FROM hits WHERE "JavaEnable" = 0 GROUP BY "ClientIP", "WatchID" HAVING c > 1 ORDER BY tmed DESC LIMIT 10; +SELECT "ClientIP", "WatchID", COUNT(*) c, MIN("ResponseStartTiming") tmin, APPROX_PERCENTILE_CONT("ResponseStartTiming", 0.95) tp95, MAX("ResponseStartTiming") tmax FROM 'hits' WHERE "JavaEnable" = 0 GROUP BY "ClientIP", "WatchID" HAVING c > 1 ORDER BY tp95 DESC LIMIT 10; \ No newline at end of file diff --git a/benchmarks/queries/imdb/10a.sql b/benchmarks/queries/imdb/10a.sql new file mode 100644 index 000000000000..95b049b77479 --- /dev/null +++ b/benchmarks/queries/imdb/10a.sql @@ -0,0 +1 @@ +SELECT MIN(chn.name) AS uncredited_voiced_character, MIN(t.title) AS russian_movie FROM char_name AS chn, cast_info AS ci, company_name AS cn, company_type AS ct, movie_companies AS mc, role_type AS rt, title AS t WHERE ci.note like '%(voice)%' and ci.note like '%(uncredited)%' AND cn.country_code = '[ru]' AND rt.role = 'actor' AND t.production_year > 2005 AND t.id = mc.movie_id AND t.id = ci.movie_id AND ci.movie_id = mc.movie_id AND chn.id = ci.person_role_id AND rt.id = ci.role_id AND cn.id = mc.company_id AND ct.id = mc.company_type_id; diff --git a/benchmarks/queries/imdb/10b.sql b/benchmarks/queries/imdb/10b.sql new file mode 100644 index 000000000000..c32153631412 --- /dev/null +++ b/benchmarks/queries/imdb/10b.sql @@ -0,0 +1 @@ +SELECT MIN(chn.name) AS character, MIN(t.title) AS russian_mov_with_actor_producer FROM char_name AS chn, cast_info AS ci, company_name AS cn, company_type AS ct, movie_companies AS mc, role_type AS rt, title AS t WHERE ci.note like '%(producer)%' AND cn.country_code = '[ru]' AND rt.role = 'actor' AND t.production_year > 2010 AND t.id = mc.movie_id AND t.id = ci.movie_id AND ci.movie_id = mc.movie_id AND chn.id = ci.person_role_id AND rt.id = ci.role_id AND cn.id = mc.company_id AND ct.id = mc.company_type_id; diff --git a/benchmarks/queries/imdb/10c.sql b/benchmarks/queries/imdb/10c.sql new file mode 100644 index 000000000000..b862cf4fa7ac --- /dev/null +++ b/benchmarks/queries/imdb/10c.sql @@ -0,0 +1 @@ +SELECT MIN(chn.name) AS character, MIN(t.title) AS movie_with_american_producer FROM char_name AS chn, cast_info AS ci, company_name AS cn, company_type AS ct, movie_companies AS mc, role_type AS rt, title AS t WHERE ci.note like '%(producer)%' AND cn.country_code = '[us]' AND t.production_year > 1990 AND t.id = mc.movie_id AND t.id = ci.movie_id AND ci.movie_id = mc.movie_id AND chn.id = ci.person_role_id AND rt.id = ci.role_id AND cn.id = mc.company_id AND ct.id = mc.company_type_id; diff --git a/benchmarks/queries/imdb/11a.sql b/benchmarks/queries/imdb/11a.sql new file mode 100644 index 000000000000..f835968e900b --- /dev/null +++ b/benchmarks/queries/imdb/11a.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS from_company, MIN(lt.link) AS movie_link_type, MIN(t.title) AS non_polish_sequel_movie FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_keyword AS mk, movie_link AS ml, title AS t WHERE cn.country_code !='[pl]' AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') AND ct.kind ='production companies' AND k.keyword ='sequel' AND lt.link LIKE '%follow%' AND mc.note IS NULL AND t.production_year BETWEEN 1950 AND 2000 AND lt.id = ml.link_type_id AND ml.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_type_id = ct.id AND mc.company_id = cn.id AND ml.movie_id = mk.movie_id AND ml.movie_id = mc.movie_id AND mk.movie_id = mc.movie_id; diff --git a/benchmarks/queries/imdb/11b.sql b/benchmarks/queries/imdb/11b.sql new file mode 100644 index 000000000000..2411e19ea608 --- /dev/null +++ b/benchmarks/queries/imdb/11b.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS from_company, MIN(lt.link) AS movie_link_type, MIN(t.title) AS sequel_movie FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_keyword AS mk, movie_link AS ml, title AS t WHERE cn.country_code !='[pl]' AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') AND ct.kind ='production companies' AND k.keyword ='sequel' AND lt.link LIKE '%follows%' AND mc.note IS NULL AND t.production_year = 1998 and t.title like '%Money%' AND lt.id = ml.link_type_id AND ml.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_type_id = ct.id AND mc.company_id = cn.id AND ml.movie_id = mk.movie_id AND ml.movie_id = mc.movie_id AND mk.movie_id = mc.movie_id; diff --git a/benchmarks/queries/imdb/11c.sql b/benchmarks/queries/imdb/11c.sql new file mode 100644 index 000000000000..3bf794678918 --- /dev/null +++ b/benchmarks/queries/imdb/11c.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS from_company, MIN(mc.note) AS production_note, MIN(t.title) AS movie_based_on_book FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_keyword AS mk, movie_link AS ml, title AS t WHERE cn.country_code !='[pl]' and (cn.name like '20th Century Fox%' or cn.name like 'Twentieth Century Fox%') AND ct.kind != 'production companies' and ct.kind is not NULL AND k.keyword in ('sequel', 'revenge', 'based-on-novel') AND mc.note is not NULL AND t.production_year > 1950 AND lt.id = ml.link_type_id AND ml.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_type_id = ct.id AND mc.company_id = cn.id AND ml.movie_id = mk.movie_id AND ml.movie_id = mc.movie_id AND mk.movie_id = mc.movie_id; diff --git a/benchmarks/queries/imdb/11d.sql b/benchmarks/queries/imdb/11d.sql new file mode 100644 index 000000000000..0bc33e1d6e88 --- /dev/null +++ b/benchmarks/queries/imdb/11d.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS from_company, MIN(mc.note) AS production_note, MIN(t.title) AS movie_based_on_book FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_keyword AS mk, movie_link AS ml, title AS t WHERE cn.country_code !='[pl]' AND ct.kind != 'production companies' and ct.kind is not NULL AND k.keyword in ('sequel', 'revenge', 'based-on-novel') AND mc.note is not NULL AND t.production_year > 1950 AND lt.id = ml.link_type_id AND ml.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_type_id = ct.id AND mc.company_id = cn.id AND ml.movie_id = mk.movie_id AND ml.movie_id = mc.movie_id AND mk.movie_id = mc.movie_id; diff --git a/benchmarks/queries/imdb/12a.sql b/benchmarks/queries/imdb/12a.sql new file mode 100644 index 000000000000..22add74bd55d --- /dev/null +++ b/benchmarks/queries/imdb/12a.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS drama_horror_movie FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, title AS t WHERE cn.country_code = '[us]' AND ct.kind = 'production companies' AND it1.info = 'genres' AND it2.info = 'rating' AND mi.info in ('Drama', 'Horror') AND mi_idx.info > '8.0' AND t.production_year between 2005 and 2008 AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND mi.info_type_id = it1.id AND mi_idx.info_type_id = it2.id AND t.id = mc.movie_id AND ct.id = mc.company_type_id AND cn.id = mc.company_id AND mc.movie_id = mi.movie_id AND mc.movie_id = mi_idx.movie_id AND mi.movie_id = mi_idx.movie_id; diff --git a/benchmarks/queries/imdb/12b.sql b/benchmarks/queries/imdb/12b.sql new file mode 100644 index 000000000000..fc30ad550d10 --- /dev/null +++ b/benchmarks/queries/imdb/12b.sql @@ -0,0 +1 @@ +SELECT MIN(mi.info) AS budget, MIN(t.title) AS unsuccsessful_movie FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, title AS t WHERE cn.country_code ='[us]' AND ct.kind is not NULL and (ct.kind ='production companies' or ct.kind = 'distributors') AND it1.info ='budget' AND it2.info ='bottom 10 rank' AND t.production_year >2000 AND (t.title LIKE 'Birdemic%' OR t.title LIKE '%Movie%') AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND mi.info_type_id = it1.id AND mi_idx.info_type_id = it2.id AND t.id = mc.movie_id AND ct.id = mc.company_type_id AND cn.id = mc.company_id AND mc.movie_id = mi.movie_id AND mc.movie_id = mi_idx.movie_id AND mi.movie_id = mi_idx.movie_id; diff --git a/benchmarks/queries/imdb/12c.sql b/benchmarks/queries/imdb/12c.sql new file mode 100644 index 000000000000..64a340b2381e --- /dev/null +++ b/benchmarks/queries/imdb/12c.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS mainstream_movie FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, title AS t WHERE cn.country_code = '[us]' AND ct.kind = 'production companies' AND it1.info = 'genres' AND it2.info = 'rating' AND mi.info in ('Drama', 'Horror', 'Western', 'Family') AND mi_idx.info > '7.0' AND t.production_year between 2000 and 2010 AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND mi.info_type_id = it1.id AND mi_idx.info_type_id = it2.id AND t.id = mc.movie_id AND ct.id = mc.company_type_id AND cn.id = mc.company_id AND mc.movie_id = mi.movie_id AND mc.movie_id = mi_idx.movie_id AND mi.movie_id = mi_idx.movie_id; diff --git a/benchmarks/queries/imdb/13a.sql b/benchmarks/queries/imdb/13a.sql new file mode 100644 index 000000000000..95eb439d1e22 --- /dev/null +++ b/benchmarks/queries/imdb/13a.sql @@ -0,0 +1 @@ +SELECT MIN(mi.info) AS release_date, MIN(miidx.info) AS rating, MIN(t.title) AS german_movie FROM company_name AS cn, company_type AS ct, info_type AS it, info_type AS it2, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS miidx, title AS t WHERE cn.country_code ='[de]' AND ct.kind ='production companies' AND it.info ='rating' AND it2.info ='release dates' AND kt.kind ='movie' AND mi.movie_id = t.id AND it2.id = mi.info_type_id AND kt.id = t.kind_id AND mc.movie_id = t.id AND cn.id = mc.company_id AND ct.id = mc.company_type_id AND miidx.movie_id = t.id AND it.id = miidx.info_type_id AND mi.movie_id = miidx.movie_id AND mi.movie_id = mc.movie_id AND miidx.movie_id = mc.movie_id; diff --git a/benchmarks/queries/imdb/13b.sql b/benchmarks/queries/imdb/13b.sql new file mode 100644 index 000000000000..4b6f75ab0ae6 --- /dev/null +++ b/benchmarks/queries/imdb/13b.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS producing_company, MIN(miidx.info) AS rating, MIN(t.title) AS movie_about_winning FROM company_name AS cn, company_type AS ct, info_type AS it, info_type AS it2, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS miidx, title AS t WHERE cn.country_code ='[us]' AND ct.kind ='production companies' AND it.info ='rating' AND it2.info ='release dates' AND kt.kind ='movie' AND t.title != '' AND (t.title LIKE '%Champion%' OR t.title LIKE '%Loser%') AND mi.movie_id = t.id AND it2.id = mi.info_type_id AND kt.id = t.kind_id AND mc.movie_id = t.id AND cn.id = mc.company_id AND ct.id = mc.company_type_id AND miidx.movie_id = t.id AND it.id = miidx.info_type_id AND mi.movie_id = miidx.movie_id AND mi.movie_id = mc.movie_id AND miidx.movie_id = mc.movie_id; diff --git a/benchmarks/queries/imdb/13c.sql b/benchmarks/queries/imdb/13c.sql new file mode 100644 index 000000000000..9e8c92327bd5 --- /dev/null +++ b/benchmarks/queries/imdb/13c.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS producing_company, MIN(miidx.info) AS rating, MIN(t.title) AS movie_about_winning FROM company_name AS cn, company_type AS ct, info_type AS it, info_type AS it2, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS miidx, title AS t WHERE cn.country_code ='[us]' AND ct.kind ='production companies' AND it.info ='rating' AND it2.info ='release dates' AND kt.kind ='movie' AND t.title != '' AND (t.title LIKE 'Champion%' OR t.title LIKE 'Loser%') AND mi.movie_id = t.id AND it2.id = mi.info_type_id AND kt.id = t.kind_id AND mc.movie_id = t.id AND cn.id = mc.company_id AND ct.id = mc.company_type_id AND miidx.movie_id = t.id AND it.id = miidx.info_type_id AND mi.movie_id = miidx.movie_id AND mi.movie_id = mc.movie_id AND miidx.movie_id = mc.movie_id; diff --git a/benchmarks/queries/imdb/13d.sql b/benchmarks/queries/imdb/13d.sql new file mode 100644 index 000000000000..a8bc567cabe1 --- /dev/null +++ b/benchmarks/queries/imdb/13d.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS producing_company, MIN(miidx.info) AS rating, MIN(t.title) AS movie FROM company_name AS cn, company_type AS ct, info_type AS it, info_type AS it2, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS miidx, title AS t WHERE cn.country_code ='[us]' AND ct.kind ='production companies' AND it.info ='rating' AND it2.info ='release dates' AND kt.kind ='movie' AND mi.movie_id = t.id AND it2.id = mi.info_type_id AND kt.id = t.kind_id AND mc.movie_id = t.id AND cn.id = mc.company_id AND ct.id = mc.company_type_id AND miidx.movie_id = t.id AND it.id = miidx.info_type_id AND mi.movie_id = miidx.movie_id AND mi.movie_id = mc.movie_id AND miidx.movie_id = mc.movie_id; diff --git a/benchmarks/queries/imdb/14a.sql b/benchmarks/queries/imdb/14a.sql new file mode 100644 index 000000000000..af1a7c8983a6 --- /dev/null +++ b/benchmarks/queries/imdb/14a.sql @@ -0,0 +1 @@ +SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS northern_dark_movie FROM info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t WHERE it1.info = 'countries' AND it2.info = 'rating' AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') AND kt.kind = 'movie' AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'USA', 'American') AND mi_idx.info < '8.5' AND t.production_year > 2010 AND kt.id = t.kind_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mi_idx.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mi_idx.movie_id AND mi.movie_id = mi_idx.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id; diff --git a/benchmarks/queries/imdb/14b.sql b/benchmarks/queries/imdb/14b.sql new file mode 100644 index 000000000000..c606ebc73dd4 --- /dev/null +++ b/benchmarks/queries/imdb/14b.sql @@ -0,0 +1 @@ +SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS western_dark_production FROM info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t WHERE it1.info = 'countries' AND it2.info = 'rating' AND k.keyword in ('murder', 'murder-in-title') AND kt.kind = 'movie' AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'USA', 'American') AND mi_idx.info > '6.0' AND t.production_year > 2010 and (t.title like '%murder%' or t.title like '%Murder%' or t.title like '%Mord%') AND kt.id = t.kind_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mi_idx.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mi_idx.movie_id AND mi.movie_id = mi_idx.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id; diff --git a/benchmarks/queries/imdb/14c.sql b/benchmarks/queries/imdb/14c.sql new file mode 100644 index 000000000000..2a6dffde2639 --- /dev/null +++ b/benchmarks/queries/imdb/14c.sql @@ -0,0 +1 @@ +SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS north_european_dark_production FROM info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t WHERE it1.info = 'countries' AND it2.info = 'rating' AND k.keyword is not null and k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') AND kt.kind in ('movie', 'episode') AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Danish', 'Norwegian', 'German', 'USA', 'American') AND mi_idx.info < '8.5' AND t.production_year > 2005 AND kt.id = t.kind_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mi_idx.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mi_idx.movie_id AND mi.movie_id = mi_idx.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id; diff --git a/benchmarks/queries/imdb/15a.sql b/benchmarks/queries/imdb/15a.sql new file mode 100644 index 000000000000..1d052f004426 --- /dev/null +++ b/benchmarks/queries/imdb/15a.sql @@ -0,0 +1 @@ +SELECT MIN(mi.info) AS release_date, MIN(t.title) AS internet_movie FROM aka_title AS at, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t WHERE cn.country_code = '[us]' AND it1.info = 'release dates' AND mc.note like '%(200%)%' and mc.note like '%(worldwide)%' AND mi.note like '%internet%' AND mi.info like 'USA:% 200%' AND t.production_year > 2000 AND t.id = at.movie_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mc.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mc.movie_id AND mk.movie_id = at.movie_id AND mi.movie_id = mc.movie_id AND mi.movie_id = at.movie_id AND mc.movie_id = at.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND cn.id = mc.company_id AND ct.id = mc.company_type_id; diff --git a/benchmarks/queries/imdb/15b.sql b/benchmarks/queries/imdb/15b.sql new file mode 100644 index 000000000000..21c81358fa7a --- /dev/null +++ b/benchmarks/queries/imdb/15b.sql @@ -0,0 +1 @@ +SELECT MIN(mi.info) AS release_date, MIN(t.title) AS youtube_movie FROM aka_title AS at, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t WHERE cn.country_code = '[us]' and cn.name = 'YouTube' AND it1.info = 'release dates' AND mc.note like '%(200%)%' and mc.note like '%(worldwide)%' AND mi.note like '%internet%' AND mi.info like 'USA:% 200%' AND t.production_year between 2005 and 2010 AND t.id = at.movie_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mc.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mc.movie_id AND mk.movie_id = at.movie_id AND mi.movie_id = mc.movie_id AND mi.movie_id = at.movie_id AND mc.movie_id = at.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND cn.id = mc.company_id AND ct.id = mc.company_type_id; diff --git a/benchmarks/queries/imdb/15c.sql b/benchmarks/queries/imdb/15c.sql new file mode 100644 index 000000000000..2d08c5203974 --- /dev/null +++ b/benchmarks/queries/imdb/15c.sql @@ -0,0 +1 @@ +SELECT MIN(mi.info) AS release_date, MIN(t.title) AS modern_american_internet_movie FROM aka_title AS at, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t WHERE cn.country_code = '[us]' AND it1.info = 'release dates' AND mi.note like '%internet%' AND mi.info is not NULL and (mi.info like 'USA:% 199%' or mi.info like 'USA:% 200%') AND t.production_year > 1990 AND t.id = at.movie_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mc.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mc.movie_id AND mk.movie_id = at.movie_id AND mi.movie_id = mc.movie_id AND mi.movie_id = at.movie_id AND mc.movie_id = at.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND cn.id = mc.company_id AND ct.id = mc.company_type_id; diff --git a/benchmarks/queries/imdb/15d.sql b/benchmarks/queries/imdb/15d.sql new file mode 100644 index 000000000000..040e9815d86c --- /dev/null +++ b/benchmarks/queries/imdb/15d.sql @@ -0,0 +1 @@ +SELECT MIN(at.title) AS aka_title, MIN(t.title) AS internet_movie_title FROM aka_title AS at, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t WHERE cn.country_code = '[us]' AND it1.info = 'release dates' AND mi.note like '%internet%' AND t.production_year > 1990 AND t.id = at.movie_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mc.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mc.movie_id AND mk.movie_id = at.movie_id AND mi.movie_id = mc.movie_id AND mi.movie_id = at.movie_id AND mc.movie_id = at.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND cn.id = mc.company_id AND ct.id = mc.company_type_id; diff --git a/benchmarks/queries/imdb/16a.sql b/benchmarks/queries/imdb/16a.sql new file mode 100644 index 000000000000..aaa0020269d2 --- /dev/null +++ b/benchmarks/queries/imdb/16a.sql @@ -0,0 +1 @@ +SELECT MIN(an.name) AS cool_actor_pseudonym, MIN(t.title) AS series_named_after_char FROM aka_name AS an, cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t WHERE cn.country_code ='[us]' AND k.keyword ='character-name-in-title' AND t.episode_nr >= 50 AND t.episode_nr < 100 AND an.person_id = n.id AND n.id = ci.person_id AND ci.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_id = cn.id AND an.person_id = ci.person_id AND ci.movie_id = mc.movie_id AND ci.movie_id = mk.movie_id AND mc.movie_id = mk.movie_id; diff --git a/benchmarks/queries/imdb/16b.sql b/benchmarks/queries/imdb/16b.sql new file mode 100644 index 000000000000..c6c0bef319de --- /dev/null +++ b/benchmarks/queries/imdb/16b.sql @@ -0,0 +1 @@ +SELECT MIN(an.name) AS cool_actor_pseudonym, MIN(t.title) AS series_named_after_char FROM aka_name AS an, cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t WHERE cn.country_code ='[us]' AND k.keyword ='character-name-in-title' AND an.person_id = n.id AND n.id = ci.person_id AND ci.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_id = cn.id AND an.person_id = ci.person_id AND ci.movie_id = mc.movie_id AND ci.movie_id = mk.movie_id AND mc.movie_id = mk.movie_id; diff --git a/benchmarks/queries/imdb/16c.sql b/benchmarks/queries/imdb/16c.sql new file mode 100644 index 000000000000..5c3b35752195 --- /dev/null +++ b/benchmarks/queries/imdb/16c.sql @@ -0,0 +1 @@ +SELECT MIN(an.name) AS cool_actor_pseudonym, MIN(t.title) AS series_named_after_char FROM aka_name AS an, cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t WHERE cn.country_code ='[us]' AND k.keyword ='character-name-in-title' AND t.episode_nr < 100 AND an.person_id = n.id AND n.id = ci.person_id AND ci.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_id = cn.id AND an.person_id = ci.person_id AND ci.movie_id = mc.movie_id AND ci.movie_id = mk.movie_id AND mc.movie_id = mk.movie_id; diff --git a/benchmarks/queries/imdb/16d.sql b/benchmarks/queries/imdb/16d.sql new file mode 100644 index 000000000000..c9e1b5f25ce5 --- /dev/null +++ b/benchmarks/queries/imdb/16d.sql @@ -0,0 +1 @@ +SELECT MIN(an.name) AS cool_actor_pseudonym, MIN(t.title) AS series_named_after_char FROM aka_name AS an, cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t WHERE cn.country_code ='[us]' AND k.keyword ='character-name-in-title' AND t.episode_nr >= 5 AND t.episode_nr < 100 AND an.person_id = n.id AND n.id = ci.person_id AND ci.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_id = cn.id AND an.person_id = ci.person_id AND ci.movie_id = mc.movie_id AND ci.movie_id = mk.movie_id AND mc.movie_id = mk.movie_id; diff --git a/benchmarks/queries/imdb/17a.sql b/benchmarks/queries/imdb/17a.sql new file mode 100644 index 000000000000..e854a957e429 --- /dev/null +++ b/benchmarks/queries/imdb/17a.sql @@ -0,0 +1 @@ +SELECT MIN(n.name) AS member_in_charnamed_american_movie, MIN(n.name) AS a1 FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t WHERE cn.country_code ='[us]' AND k.keyword ='character-name-in-title' AND n.name LIKE 'B%' AND n.id = ci.person_id AND ci.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_id = cn.id AND ci.movie_id = mc.movie_id AND ci.movie_id = mk.movie_id AND mc.movie_id = mk.movie_id; diff --git a/benchmarks/queries/imdb/17b.sql b/benchmarks/queries/imdb/17b.sql new file mode 100644 index 000000000000..903f2196b278 --- /dev/null +++ b/benchmarks/queries/imdb/17b.sql @@ -0,0 +1 @@ +SELECT MIN(n.name) AS member_in_charnamed_movie, MIN(n.name) AS a1 FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t WHERE k.keyword ='character-name-in-title' AND n.name LIKE 'Z%' AND n.id = ci.person_id AND ci.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_id = cn.id AND ci.movie_id = mc.movie_id AND ci.movie_id = mk.movie_id AND mc.movie_id = mk.movie_id; diff --git a/benchmarks/queries/imdb/17c.sql b/benchmarks/queries/imdb/17c.sql new file mode 100644 index 000000000000..a96faa0b4339 --- /dev/null +++ b/benchmarks/queries/imdb/17c.sql @@ -0,0 +1 @@ +SELECT MIN(n.name) AS member_in_charnamed_movie, MIN(n.name) AS a1 FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t WHERE k.keyword ='character-name-in-title' AND n.name LIKE 'X%' AND n.id = ci.person_id AND ci.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_id = cn.id AND ci.movie_id = mc.movie_id AND ci.movie_id = mk.movie_id AND mc.movie_id = mk.movie_id; diff --git a/benchmarks/queries/imdb/17d.sql b/benchmarks/queries/imdb/17d.sql new file mode 100644 index 000000000000..73e1f2c30976 --- /dev/null +++ b/benchmarks/queries/imdb/17d.sql @@ -0,0 +1 @@ +SELECT MIN(n.name) AS member_in_charnamed_movie FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t WHERE k.keyword ='character-name-in-title' AND n.name LIKE '%Bert%' AND n.id = ci.person_id AND ci.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_id = cn.id AND ci.movie_id = mc.movie_id AND ci.movie_id = mk.movie_id AND mc.movie_id = mk.movie_id; diff --git a/benchmarks/queries/imdb/17e.sql b/benchmarks/queries/imdb/17e.sql new file mode 100644 index 000000000000..65ea73ed0510 --- /dev/null +++ b/benchmarks/queries/imdb/17e.sql @@ -0,0 +1 @@ +SELECT MIN(n.name) AS member_in_charnamed_movie FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t WHERE cn.country_code ='[us]' AND k.keyword ='character-name-in-title' AND n.id = ci.person_id AND ci.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_id = cn.id AND ci.movie_id = mc.movie_id AND ci.movie_id = mk.movie_id AND mc.movie_id = mk.movie_id; diff --git a/benchmarks/queries/imdb/17f.sql b/benchmarks/queries/imdb/17f.sql new file mode 100644 index 000000000000..542233d63e9d --- /dev/null +++ b/benchmarks/queries/imdb/17f.sql @@ -0,0 +1 @@ +SELECT MIN(n.name) AS member_in_charnamed_movie FROM cast_info AS ci, company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, name AS n, title AS t WHERE k.keyword ='character-name-in-title' AND n.name LIKE '%B%' AND n.id = ci.person_id AND ci.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_id = cn.id AND ci.movie_id = mc.movie_id AND ci.movie_id = mk.movie_id AND mc.movie_id = mk.movie_id; diff --git a/benchmarks/queries/imdb/18a.sql b/benchmarks/queries/imdb/18a.sql new file mode 100644 index 000000000000..275e04bdb184 --- /dev/null +++ b/benchmarks/queries/imdb/18a.sql @@ -0,0 +1 @@ +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(t.title) AS movie_title FROM cast_info AS ci, info_type AS it1, info_type AS it2, movie_info AS mi, movie_info_idx AS mi_idx, name AS n, title AS t WHERE ci.note in ('(producer)', '(executive producer)') AND it1.info = 'budget' AND it2.info = 'votes' AND n.gender = 'm' and n.name like '%Tim%' AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND t.id = ci.movie_id AND ci.movie_id = mi.movie_id AND ci.movie_id = mi_idx.movie_id AND mi.movie_id = mi_idx.movie_id AND n.id = ci.person_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id; diff --git a/benchmarks/queries/imdb/18b.sql b/benchmarks/queries/imdb/18b.sql new file mode 100644 index 000000000000..3ae40ed93d2f --- /dev/null +++ b/benchmarks/queries/imdb/18b.sql @@ -0,0 +1 @@ +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(t.title) AS movie_title FROM cast_info AS ci, info_type AS it1, info_type AS it2, movie_info AS mi, movie_info_idx AS mi_idx, name AS n, title AS t WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') AND it1.info = 'genres' AND it2.info = 'rating' AND mi.info in ('Horror', 'Thriller') and mi.note is NULL AND mi_idx.info > '8.0' AND n.gender is not null and n.gender = 'f' AND t.production_year between 2008 and 2014 AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND t.id = ci.movie_id AND ci.movie_id = mi.movie_id AND ci.movie_id = mi_idx.movie_id AND mi.movie_id = mi_idx.movie_id AND n.id = ci.person_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id; diff --git a/benchmarks/queries/imdb/18c.sql b/benchmarks/queries/imdb/18c.sql new file mode 100644 index 000000000000..01f28ea527fe --- /dev/null +++ b/benchmarks/queries/imdb/18c.sql @@ -0,0 +1 @@ +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(t.title) AS movie_title FROM cast_info AS ci, info_type AS it1, info_type AS it2, movie_info AS mi, movie_info_idx AS mi_idx, name AS n, title AS t WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') AND it1.info = 'genres' AND it2.info = 'votes' AND mi.info in ('Horror', 'Action', 'Sci-Fi', 'Thriller', 'Crime', 'War') AND n.gender = 'm' AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND t.id = ci.movie_id AND ci.movie_id = mi.movie_id AND ci.movie_id = mi_idx.movie_id AND mi.movie_id = mi_idx.movie_id AND n.id = ci.person_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id; diff --git a/benchmarks/queries/imdb/19a.sql b/benchmarks/queries/imdb/19a.sql new file mode 100644 index 000000000000..ceaae671fd20 --- /dev/null +++ b/benchmarks/queries/imdb/19a.sql @@ -0,0 +1 @@ +SELECT MIN(n.name) AS voicing_actress, MIN(t.title) AS voiced_movie FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, movie_companies AS mc, movie_info AS mi, name AS n, role_type AS rt, title AS t WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') AND cn.country_code ='[us]' AND it.info = 'release dates' AND mc.note is not NULL and (mc.note like '%(USA)%' or mc.note like '%(worldwide)%') AND mi.info is not null and (mi.info like 'Japan:%200%' or mi.info like 'USA:%200%') AND n.gender ='f' and n.name like '%Ang%' AND rt.role ='actress' AND t.production_year between 2005 and 2009 AND t.id = mi.movie_id AND t.id = mc.movie_id AND t.id = ci.movie_id AND mc.movie_id = ci.movie_id AND mc.movie_id = mi.movie_id AND mi.movie_id = ci.movie_id AND cn.id = mc.company_id AND it.id = mi.info_type_id AND n.id = ci.person_id AND rt.id = ci.role_id AND n.id = an.person_id AND ci.person_id = an.person_id AND chn.id = ci.person_role_id; diff --git a/benchmarks/queries/imdb/19b.sql b/benchmarks/queries/imdb/19b.sql new file mode 100644 index 000000000000..62e852ba3ec6 --- /dev/null +++ b/benchmarks/queries/imdb/19b.sql @@ -0,0 +1 @@ +SELECT MIN(n.name) AS voicing_actress, MIN(t.title) AS kung_fu_panda FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, movie_companies AS mc, movie_info AS mi, name AS n, role_type AS rt, title AS t WHERE ci.note = '(voice)' AND cn.country_code ='[us]' AND it.info = 'release dates' AND mc.note like '%(200%)%' and (mc.note like '%(USA)%' or mc.note like '%(worldwide)%') AND mi.info is not null and (mi.info like 'Japan:%2007%' or mi.info like 'USA:%2008%') AND n.gender ='f' and n.name like '%Angel%' AND rt.role ='actress' AND t.production_year between 2007 and 2008 and t.title like '%Kung%Fu%Panda%' AND t.id = mi.movie_id AND t.id = mc.movie_id AND t.id = ci.movie_id AND mc.movie_id = ci.movie_id AND mc.movie_id = mi.movie_id AND mi.movie_id = ci.movie_id AND cn.id = mc.company_id AND it.id = mi.info_type_id AND n.id = ci.person_id AND rt.id = ci.role_id AND n.id = an.person_id AND ci.person_id = an.person_id AND chn.id = ci.person_role_id; diff --git a/benchmarks/queries/imdb/19c.sql b/benchmarks/queries/imdb/19c.sql new file mode 100644 index 000000000000..6885af5012fc --- /dev/null +++ b/benchmarks/queries/imdb/19c.sql @@ -0,0 +1 @@ +SELECT MIN(n.name) AS voicing_actress, MIN(t.title) AS jap_engl_voiced_movie FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, movie_companies AS mc, movie_info AS mi, name AS n, role_type AS rt, title AS t WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') AND cn.country_code ='[us]' AND it.info = 'release dates' AND mi.info is not null and (mi.info like 'Japan:%200%' or mi.info like 'USA:%200%') AND n.gender ='f' and n.name like '%An%' AND rt.role ='actress' AND t.production_year > 2000 AND t.id = mi.movie_id AND t.id = mc.movie_id AND t.id = ci.movie_id AND mc.movie_id = ci.movie_id AND mc.movie_id = mi.movie_id AND mi.movie_id = ci.movie_id AND cn.id = mc.company_id AND it.id = mi.info_type_id AND n.id = ci.person_id AND rt.id = ci.role_id AND n.id = an.person_id AND ci.person_id = an.person_id AND chn.id = ci.person_role_id; diff --git a/benchmarks/queries/imdb/19d.sql b/benchmarks/queries/imdb/19d.sql new file mode 100644 index 000000000000..06fcc76ba7ad --- /dev/null +++ b/benchmarks/queries/imdb/19d.sql @@ -0,0 +1 @@ +SELECT MIN(n.name) AS voicing_actress, MIN(t.title) AS jap_engl_voiced_movie FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, movie_companies AS mc, movie_info AS mi, name AS n, role_type AS rt, title AS t WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') AND cn.country_code ='[us]' AND it.info = 'release dates' AND n.gender ='f' AND rt.role ='actress' AND t.production_year > 2000 AND t.id = mi.movie_id AND t.id = mc.movie_id AND t.id = ci.movie_id AND mc.movie_id = ci.movie_id AND mc.movie_id = mi.movie_id AND mi.movie_id = ci.movie_id AND cn.id = mc.company_id AND it.id = mi.info_type_id AND n.id = ci.person_id AND rt.id = ci.role_id AND n.id = an.person_id AND ci.person_id = an.person_id AND chn.id = ci.person_role_id; diff --git a/benchmarks/queries/imdb/1a.sql b/benchmarks/queries/imdb/1a.sql new file mode 100644 index 000000000000..07b351638857 --- /dev/null +++ b/benchmarks/queries/imdb/1a.sql @@ -0,0 +1 @@ +SELECT MIN(mc.note) AS production_note, MIN(t.title) AS movie_title, MIN(t.production_year) AS movie_year FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info_idx AS mi_idx, title AS t WHERE ct.kind = 'production companies' AND it.info = 'top 250 rank' AND mc.note not like '%(as Metro-Goldwyn-Mayer Pictures)%' and (mc.note like '%(co-production)%' or mc.note like '%(presents)%') AND ct.id = mc.company_type_id AND t.id = mc.movie_id AND t.id = mi_idx.movie_id AND mc.movie_id = mi_idx.movie_id AND it.id = mi_idx.info_type_id; diff --git a/benchmarks/queries/imdb/1b.sql b/benchmarks/queries/imdb/1b.sql new file mode 100644 index 000000000000..f2901e8b5262 --- /dev/null +++ b/benchmarks/queries/imdb/1b.sql @@ -0,0 +1 @@ +SELECT MIN(mc.note) AS production_note, MIN(t.title) AS movie_title, MIN(t.production_year) AS movie_year FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info_idx AS mi_idx, title AS t WHERE ct.kind = 'production companies' AND it.info = 'bottom 10 rank' AND mc.note not like '%(as Metro-Goldwyn-Mayer Pictures)%' AND t.production_year between 2005 and 2010 AND ct.id = mc.company_type_id AND t.id = mc.movie_id AND t.id = mi_idx.movie_id AND mc.movie_id = mi_idx.movie_id AND it.id = mi_idx.info_type_id; diff --git a/benchmarks/queries/imdb/1c.sql b/benchmarks/queries/imdb/1c.sql new file mode 100644 index 000000000000..94e66c30aa14 --- /dev/null +++ b/benchmarks/queries/imdb/1c.sql @@ -0,0 +1 @@ +SELECT MIN(mc.note) AS production_note, MIN(t.title) AS movie_title, MIN(t.production_year) AS movie_year FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info_idx AS mi_idx, title AS t WHERE ct.kind = 'production companies' AND it.info = 'top 250 rank' AND mc.note not like '%(as Metro-Goldwyn-Mayer Pictures)%' and (mc.note like '%(co-production)%') AND t.production_year >2010 AND ct.id = mc.company_type_id AND t.id = mc.movie_id AND t.id = mi_idx.movie_id AND mc.movie_id = mi_idx.movie_id AND it.id = mi_idx.info_type_id; diff --git a/benchmarks/queries/imdb/1d.sql b/benchmarks/queries/imdb/1d.sql new file mode 100644 index 000000000000..52f58e80c811 --- /dev/null +++ b/benchmarks/queries/imdb/1d.sql @@ -0,0 +1 @@ +SELECT MIN(mc.note) AS production_note, MIN(t.title) AS movie_title, MIN(t.production_year) AS movie_year FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info_idx AS mi_idx, title AS t WHERE ct.kind = 'production companies' AND it.info = 'bottom 10 rank' AND mc.note not like '%(as Metro-Goldwyn-Mayer Pictures)%' AND t.production_year >2000 AND ct.id = mc.company_type_id AND t.id = mc.movie_id AND t.id = mi_idx.movie_id AND mc.movie_id = mi_idx.movie_id AND it.id = mi_idx.info_type_id; diff --git a/benchmarks/queries/imdb/20a.sql b/benchmarks/queries/imdb/20a.sql new file mode 100644 index 000000000000..2a1c269d6a51 --- /dev/null +++ b/benchmarks/queries/imdb/20a.sql @@ -0,0 +1 @@ +SELECT MIN(t.title) AS complete_downey_ironman_movie FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, keyword AS k, kind_type AS kt, movie_keyword AS mk, name AS n, title AS t WHERE cct1.kind = 'cast' AND cct2.kind like '%complete%' AND chn.name not like '%Sherlock%' and (chn.name like '%Tony%Stark%' or chn.name like '%Iron%Man%') AND k.keyword in ('superhero', 'sequel', 'second-part', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence') AND kt.kind = 'movie' AND t.production_year > 1950 AND kt.id = t.kind_id AND t.id = mk.movie_id AND t.id = ci.movie_id AND t.id = cc.movie_id AND mk.movie_id = ci.movie_id AND mk.movie_id = cc.movie_id AND ci.movie_id = cc.movie_id AND chn.id = ci.person_role_id AND n.id = ci.person_id AND k.id = mk.keyword_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id; diff --git a/benchmarks/queries/imdb/20b.sql b/benchmarks/queries/imdb/20b.sql new file mode 100644 index 000000000000..4c2455a52eb1 --- /dev/null +++ b/benchmarks/queries/imdb/20b.sql @@ -0,0 +1 @@ +SELECT MIN(t.title) AS complete_downey_ironman_movie FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, keyword AS k, kind_type AS kt, movie_keyword AS mk, name AS n, title AS t WHERE cct1.kind = 'cast' AND cct2.kind like '%complete%' AND chn.name not like '%Sherlock%' and (chn.name like '%Tony%Stark%' or chn.name like '%Iron%Man%') AND k.keyword in ('superhero', 'sequel', 'second-part', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence') AND kt.kind = 'movie' AND n.name LIKE '%Downey%Robert%' AND t.production_year > 2000 AND kt.id = t.kind_id AND t.id = mk.movie_id AND t.id = ci.movie_id AND t.id = cc.movie_id AND mk.movie_id = ci.movie_id AND mk.movie_id = cc.movie_id AND ci.movie_id = cc.movie_id AND chn.id = ci.person_role_id AND n.id = ci.person_id AND k.id = mk.keyword_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id; diff --git a/benchmarks/queries/imdb/20c.sql b/benchmarks/queries/imdb/20c.sql new file mode 100644 index 000000000000..b85b22f6b4f2 --- /dev/null +++ b/benchmarks/queries/imdb/20c.sql @@ -0,0 +1 @@ +SELECT MIN(n.name) AS cast_member, MIN(t.title) AS complete_dynamic_hero_movie FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, keyword AS k, kind_type AS kt, movie_keyword AS mk, name AS n, title AS t WHERE cct1.kind = 'cast' AND cct2.kind like '%complete%' AND chn.name is not NULL and (chn.name like '%man%' or chn.name like '%Man%') AND k.keyword in ('superhero', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence', 'magnet', 'web', 'claw', 'laser') AND kt.kind = 'movie' AND t.production_year > 2000 AND kt.id = t.kind_id AND t.id = mk.movie_id AND t.id = ci.movie_id AND t.id = cc.movie_id AND mk.movie_id = ci.movie_id AND mk.movie_id = cc.movie_id AND ci.movie_id = cc.movie_id AND chn.id = ci.person_role_id AND n.id = ci.person_id AND k.id = mk.keyword_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id; diff --git a/benchmarks/queries/imdb/21a.sql b/benchmarks/queries/imdb/21a.sql new file mode 100644 index 000000000000..8a66a00be6cb --- /dev/null +++ b/benchmarks/queries/imdb/21a.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS company_name, MIN(lt.link) AS link_type, MIN(t.title) AS western_follow_up FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t WHERE cn.country_code !='[pl]' AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') AND ct.kind ='production companies' AND k.keyword ='sequel' AND lt.link LIKE '%follow%' AND mc.note IS NULL AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German') AND t.production_year BETWEEN 1950 AND 2000 AND lt.id = ml.link_type_id AND ml.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_type_id = ct.id AND mc.company_id = cn.id AND mi.movie_id = t.id AND ml.movie_id = mk.movie_id AND ml.movie_id = mc.movie_id AND mk.movie_id = mc.movie_id AND ml.movie_id = mi.movie_id AND mk.movie_id = mi.movie_id AND mc.movie_id = mi.movie_id; diff --git a/benchmarks/queries/imdb/21b.sql b/benchmarks/queries/imdb/21b.sql new file mode 100644 index 000000000000..90d3a5a4c078 --- /dev/null +++ b/benchmarks/queries/imdb/21b.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS company_name, MIN(lt.link) AS link_type, MIN(t.title) AS german_follow_up FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t WHERE cn.country_code !='[pl]' AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') AND ct.kind ='production companies' AND k.keyword ='sequel' AND lt.link LIKE '%follow%' AND mc.note IS NULL AND mi.info IN ('Germany', 'German') AND t.production_year BETWEEN 2000 AND 2010 AND lt.id = ml.link_type_id AND ml.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_type_id = ct.id AND mc.company_id = cn.id AND mi.movie_id = t.id AND ml.movie_id = mk.movie_id AND ml.movie_id = mc.movie_id AND mk.movie_id = mc.movie_id AND ml.movie_id = mi.movie_id AND mk.movie_id = mi.movie_id AND mc.movie_id = mi.movie_id; diff --git a/benchmarks/queries/imdb/21c.sql b/benchmarks/queries/imdb/21c.sql new file mode 100644 index 000000000000..16a42ae6f426 --- /dev/null +++ b/benchmarks/queries/imdb/21c.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS company_name, MIN(lt.link) AS link_type, MIN(t.title) AS western_follow_up FROM company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t WHERE cn.country_code !='[pl]' AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') AND ct.kind ='production companies' AND k.keyword ='sequel' AND lt.link LIKE '%follow%' AND mc.note IS NULL AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'English') AND t.production_year BETWEEN 1950 AND 2010 AND lt.id = ml.link_type_id AND ml.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_type_id = ct.id AND mc.company_id = cn.id AND mi.movie_id = t.id AND ml.movie_id = mk.movie_id AND ml.movie_id = mc.movie_id AND mk.movie_id = mc.movie_id AND ml.movie_id = mi.movie_id AND mk.movie_id = mi.movie_id AND mc.movie_id = mi.movie_id; diff --git a/benchmarks/queries/imdb/22a.sql b/benchmarks/queries/imdb/22a.sql new file mode 100644 index 000000000000..e513799698c5 --- /dev/null +++ b/benchmarks/queries/imdb/22a.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS western_violent_movie FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t WHERE cn.country_code != '[us]' AND it1.info = 'countries' AND it2.info = 'rating' AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') AND kt.kind in ('movie', 'episode') AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' AND mi.info IN ('Germany', 'German', 'USA', 'American') AND mi_idx.info < '7.0' AND t.production_year > 2008 AND kt.id = t.kind_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mi_idx.movie_id AND t.id = mc.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mi_idx.movie_id AND mk.movie_id = mc.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mc.movie_id AND mc.movie_id = mi_idx.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND ct.id = mc.company_type_id AND cn.id = mc.company_id; diff --git a/benchmarks/queries/imdb/22b.sql b/benchmarks/queries/imdb/22b.sql new file mode 100644 index 000000000000..f98d0ea8099d --- /dev/null +++ b/benchmarks/queries/imdb/22b.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS western_violent_movie FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t WHERE cn.country_code != '[us]' AND it1.info = 'countries' AND it2.info = 'rating' AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') AND kt.kind in ('movie', 'episode') AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' AND mi.info IN ('Germany', 'German', 'USA', 'American') AND mi_idx.info < '7.0' AND t.production_year > 2009 AND kt.id = t.kind_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mi_idx.movie_id AND t.id = mc.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mi_idx.movie_id AND mk.movie_id = mc.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mc.movie_id AND mc.movie_id = mi_idx.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND ct.id = mc.company_type_id AND cn.id = mc.company_id; diff --git a/benchmarks/queries/imdb/22c.sql b/benchmarks/queries/imdb/22c.sql new file mode 100644 index 000000000000..cf757956e0de --- /dev/null +++ b/benchmarks/queries/imdb/22c.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS western_violent_movie FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t WHERE cn.country_code != '[us]' AND it1.info = 'countries' AND it2.info = 'rating' AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') AND kt.kind in ('movie', 'episode') AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Danish', 'Norwegian', 'German', 'USA', 'American') AND mi_idx.info < '8.5' AND t.production_year > 2005 AND kt.id = t.kind_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mi_idx.movie_id AND t.id = mc.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mi_idx.movie_id AND mk.movie_id = mc.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mc.movie_id AND mc.movie_id = mi_idx.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND ct.id = mc.company_type_id AND cn.id = mc.company_id; diff --git a/benchmarks/queries/imdb/22d.sql b/benchmarks/queries/imdb/22d.sql new file mode 100644 index 000000000000..a47feeb05157 --- /dev/null +++ b/benchmarks/queries/imdb/22d.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS western_violent_movie FROM company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t WHERE cn.country_code != '[us]' AND it1.info = 'countries' AND it2.info = 'rating' AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') AND kt.kind in ('movie', 'episode') AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Danish', 'Norwegian', 'German', 'USA', 'American') AND mi_idx.info < '8.5' AND t.production_year > 2005 AND kt.id = t.kind_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mi_idx.movie_id AND t.id = mc.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mi_idx.movie_id AND mk.movie_id = mc.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mc.movie_id AND mc.movie_id = mi_idx.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND ct.id = mc.company_type_id AND cn.id = mc.company_id; diff --git a/benchmarks/queries/imdb/23a.sql b/benchmarks/queries/imdb/23a.sql new file mode 100644 index 000000000000..724da913b51a --- /dev/null +++ b/benchmarks/queries/imdb/23a.sql @@ -0,0 +1 @@ +SELECT MIN(kt.kind) AS movie_kind, MIN(t.title) AS complete_us_internet_movie FROM complete_cast AS cc, comp_cast_type AS cct1, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t WHERE cct1.kind = 'complete+verified' AND cn.country_code = '[us]' AND it1.info = 'release dates' AND kt.kind in ('movie') AND mi.note like '%internet%' AND mi.info is not NULL and (mi.info like 'USA:% 199%' or mi.info like 'USA:% 200%') AND t.production_year > 2000 AND kt.id = t.kind_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mc.movie_id AND t.id = cc.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mc.movie_id AND mk.movie_id = cc.movie_id AND mi.movie_id = mc.movie_id AND mi.movie_id = cc.movie_id AND mc.movie_id = cc.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND cn.id = mc.company_id AND ct.id = mc.company_type_id AND cct1.id = cc.status_id; diff --git a/benchmarks/queries/imdb/23b.sql b/benchmarks/queries/imdb/23b.sql new file mode 100644 index 000000000000..e39f0ecc28a2 --- /dev/null +++ b/benchmarks/queries/imdb/23b.sql @@ -0,0 +1 @@ +SELECT MIN(kt.kind) AS movie_kind, MIN(t.title) AS complete_nerdy_internet_movie FROM complete_cast AS cc, comp_cast_type AS cct1, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t WHERE cct1.kind = 'complete+verified' AND cn.country_code = '[us]' AND it1.info = 'release dates' AND k.keyword in ('nerd', 'loner', 'alienation', 'dignity') AND kt.kind in ('movie') AND mi.note like '%internet%' AND mi.info like 'USA:% 200%' AND t.production_year > 2000 AND kt.id = t.kind_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mc.movie_id AND t.id = cc.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mc.movie_id AND mk.movie_id = cc.movie_id AND mi.movie_id = mc.movie_id AND mi.movie_id = cc.movie_id AND mc.movie_id = cc.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND cn.id = mc.company_id AND ct.id = mc.company_type_id AND cct1.id = cc.status_id; diff --git a/benchmarks/queries/imdb/23c.sql b/benchmarks/queries/imdb/23c.sql new file mode 100644 index 000000000000..839d762d0533 --- /dev/null +++ b/benchmarks/queries/imdb/23c.sql @@ -0,0 +1 @@ +SELECT MIN(kt.kind) AS movie_kind, MIN(t.title) AS complete_us_internet_movie FROM complete_cast AS cc, comp_cast_type AS cct1, company_name AS cn, company_type AS ct, info_type AS it1, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, title AS t WHERE cct1.kind = 'complete+verified' AND cn.country_code = '[us]' AND it1.info = 'release dates' AND kt.kind in ('movie', 'tv movie', 'video movie', 'video game') AND mi.note like '%internet%' AND mi.info is not NULL and (mi.info like 'USA:% 199%' or mi.info like 'USA:% 200%') AND t.production_year > 1990 AND kt.id = t.kind_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mc.movie_id AND t.id = cc.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mc.movie_id AND mk.movie_id = cc.movie_id AND mi.movie_id = mc.movie_id AND mi.movie_id = cc.movie_id AND mc.movie_id = cc.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND cn.id = mc.company_id AND ct.id = mc.company_type_id AND cct1.id = cc.status_id; diff --git a/benchmarks/queries/imdb/24a.sql b/benchmarks/queries/imdb/24a.sql new file mode 100644 index 000000000000..8f10621e0209 --- /dev/null +++ b/benchmarks/queries/imdb/24a.sql @@ -0,0 +1 @@ +SELECT MIN(chn.name) AS voiced_char_name, MIN(n.name) AS voicing_actress_name, MIN(t.title) AS voiced_action_movie_jap_eng FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, name AS n, role_type AS rt, title AS t WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') AND cn.country_code ='[us]' AND it.info = 'release dates' AND k.keyword in ('hero', 'martial-arts', 'hand-to-hand-combat') AND mi.info is not null and (mi.info like 'Japan:%201%' or mi.info like 'USA:%201%') AND n.gender ='f' and n.name like '%An%' AND rt.role ='actress' AND t.production_year > 2010 AND t.id = mi.movie_id AND t.id = mc.movie_id AND t.id = ci.movie_id AND t.id = mk.movie_id AND mc.movie_id = ci.movie_id AND mc.movie_id = mi.movie_id AND mc.movie_id = mk.movie_id AND mi.movie_id = ci.movie_id AND mi.movie_id = mk.movie_id AND ci.movie_id = mk.movie_id AND cn.id = mc.company_id AND it.id = mi.info_type_id AND n.id = ci.person_id AND rt.id = ci.role_id AND n.id = an.person_id AND ci.person_id = an.person_id AND chn.id = ci.person_role_id AND k.id = mk.keyword_id; diff --git a/benchmarks/queries/imdb/24b.sql b/benchmarks/queries/imdb/24b.sql new file mode 100644 index 000000000000..d8a2836000b2 --- /dev/null +++ b/benchmarks/queries/imdb/24b.sql @@ -0,0 +1 @@ +SELECT MIN(chn.name) AS voiced_char_name, MIN(n.name) AS voicing_actress_name, MIN(t.title) AS kung_fu_panda FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, name AS n, role_type AS rt, title AS t WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') AND cn.country_code ='[us]' AND cn.name = 'DreamWorks Animation' AND it.info = 'release dates' AND k.keyword in ('hero', 'martial-arts', 'hand-to-hand-combat', 'computer-animated-movie') AND mi.info is not null and (mi.info like 'Japan:%201%' or mi.info like 'USA:%201%') AND n.gender ='f' and n.name like '%An%' AND rt.role ='actress' AND t.production_year > 2010 AND t.title like 'Kung Fu Panda%' AND t.id = mi.movie_id AND t.id = mc.movie_id AND t.id = ci.movie_id AND t.id = mk.movie_id AND mc.movie_id = ci.movie_id AND mc.movie_id = mi.movie_id AND mc.movie_id = mk.movie_id AND mi.movie_id = ci.movie_id AND mi.movie_id = mk.movie_id AND ci.movie_id = mk.movie_id AND cn.id = mc.company_id AND it.id = mi.info_type_id AND n.id = ci.person_id AND rt.id = ci.role_id AND n.id = an.person_id AND ci.person_id = an.person_id AND chn.id = ci.person_role_id AND k.id = mk.keyword_id; diff --git a/benchmarks/queries/imdb/25a.sql b/benchmarks/queries/imdb/25a.sql new file mode 100644 index 000000000000..bc55cc01d26b --- /dev/null +++ b/benchmarks/queries/imdb/25a.sql @@ -0,0 +1 @@ +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS male_writer, MIN(t.title) AS violent_movie_title FROM cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') AND it1.info = 'genres' AND it2.info = 'votes' AND k.keyword in ('murder', 'blood', 'gore', 'death', 'female-nudity') AND mi.info = 'Horror' AND n.gender = 'm' AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND t.id = ci.movie_id AND t.id = mk.movie_id AND ci.movie_id = mi.movie_id AND ci.movie_id = mi_idx.movie_id AND ci.movie_id = mk.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mk.movie_id AND mi_idx.movie_id = mk.movie_id AND n.id = ci.person_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND k.id = mk.keyword_id; diff --git a/benchmarks/queries/imdb/25b.sql b/benchmarks/queries/imdb/25b.sql new file mode 100644 index 000000000000..3457655bb9eb --- /dev/null +++ b/benchmarks/queries/imdb/25b.sql @@ -0,0 +1 @@ +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS male_writer, MIN(t.title) AS violent_movie_title FROM cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') AND it1.info = 'genres' AND it2.info = 'votes' AND k.keyword in ('murder', 'blood', 'gore', 'death', 'female-nudity') AND mi.info = 'Horror' AND n.gender = 'm' AND t.production_year > 2010 AND t.title like 'Vampire%' AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND t.id = ci.movie_id AND t.id = mk.movie_id AND ci.movie_id = mi.movie_id AND ci.movie_id = mi_idx.movie_id AND ci.movie_id = mk.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mk.movie_id AND mi_idx.movie_id = mk.movie_id AND n.id = ci.person_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND k.id = mk.keyword_id; diff --git a/benchmarks/queries/imdb/25c.sql b/benchmarks/queries/imdb/25c.sql new file mode 100644 index 000000000000..cf56a313d861 --- /dev/null +++ b/benchmarks/queries/imdb/25c.sql @@ -0,0 +1 @@ +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS male_writer, MIN(t.title) AS violent_movie_title FROM cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') AND it1.info = 'genres' AND it2.info = 'votes' AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') AND mi.info in ('Horror', 'Action', 'Sci-Fi', 'Thriller', 'Crime', 'War') AND n.gender = 'm' AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND t.id = ci.movie_id AND t.id = mk.movie_id AND ci.movie_id = mi.movie_id AND ci.movie_id = mi_idx.movie_id AND ci.movie_id = mk.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mk.movie_id AND mi_idx.movie_id = mk.movie_id AND n.id = ci.person_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND k.id = mk.keyword_id; diff --git a/benchmarks/queries/imdb/26a.sql b/benchmarks/queries/imdb/26a.sql new file mode 100644 index 000000000000..b431f204c6dc --- /dev/null +++ b/benchmarks/queries/imdb/26a.sql @@ -0,0 +1 @@ +SELECT MIN(chn.name) AS character_name, MIN(mi_idx.info) AS rating, MIN(n.name) AS playing_actor, MIN(t.title) AS complete_hero_movie FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, info_type AS it2, keyword AS k, kind_type AS kt, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t WHERE cct1.kind = 'cast' AND cct2.kind like '%complete%' AND chn.name is not NULL and (chn.name like '%man%' or chn.name like '%Man%') AND it2.info = 'rating' AND k.keyword in ('superhero', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence', 'magnet', 'web', 'claw', 'laser') AND kt.kind = 'movie' AND mi_idx.info > '7.0' AND t.production_year > 2000 AND kt.id = t.kind_id AND t.id = mk.movie_id AND t.id = ci.movie_id AND t.id = cc.movie_id AND t.id = mi_idx.movie_id AND mk.movie_id = ci.movie_id AND mk.movie_id = cc.movie_id AND mk.movie_id = mi_idx.movie_id AND ci.movie_id = cc.movie_id AND ci.movie_id = mi_idx.movie_id AND cc.movie_id = mi_idx.movie_id AND chn.id = ci.person_role_id AND n.id = ci.person_id AND k.id = mk.keyword_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id AND it2.id = mi_idx.info_type_id; diff --git a/benchmarks/queries/imdb/26b.sql b/benchmarks/queries/imdb/26b.sql new file mode 100644 index 000000000000..882d234d77e0 --- /dev/null +++ b/benchmarks/queries/imdb/26b.sql @@ -0,0 +1 @@ +SELECT MIN(chn.name) AS character_name, MIN(mi_idx.info) AS rating, MIN(t.title) AS complete_hero_movie FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, info_type AS it2, keyword AS k, kind_type AS kt, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t WHERE cct1.kind = 'cast' AND cct2.kind like '%complete%' AND chn.name is not NULL and (chn.name like '%man%' or chn.name like '%Man%') AND it2.info = 'rating' AND k.keyword in ('superhero', 'marvel-comics', 'based-on-comic', 'fight') AND kt.kind = 'movie' AND mi_idx.info > '8.0' AND t.production_year > 2005 AND kt.id = t.kind_id AND t.id = mk.movie_id AND t.id = ci.movie_id AND t.id = cc.movie_id AND t.id = mi_idx.movie_id AND mk.movie_id = ci.movie_id AND mk.movie_id = cc.movie_id AND mk.movie_id = mi_idx.movie_id AND ci.movie_id = cc.movie_id AND ci.movie_id = mi_idx.movie_id AND cc.movie_id = mi_idx.movie_id AND chn.id = ci.person_role_id AND n.id = ci.person_id AND k.id = mk.keyword_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id AND it2.id = mi_idx.info_type_id; diff --git a/benchmarks/queries/imdb/26c.sql b/benchmarks/queries/imdb/26c.sql new file mode 100644 index 000000000000..4b9eae0b7633 --- /dev/null +++ b/benchmarks/queries/imdb/26c.sql @@ -0,0 +1 @@ +SELECT MIN(chn.name) AS character_name, MIN(mi_idx.info) AS rating, MIN(t.title) AS complete_hero_movie FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, info_type AS it2, keyword AS k, kind_type AS kt, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t WHERE cct1.kind = 'cast' AND cct2.kind like '%complete%' AND chn.name is not NULL and (chn.name like '%man%' or chn.name like '%Man%') AND it2.info = 'rating' AND k.keyword in ('superhero', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence', 'magnet', 'web', 'claw', 'laser') AND kt.kind = 'movie' AND t.production_year > 2000 AND kt.id = t.kind_id AND t.id = mk.movie_id AND t.id = ci.movie_id AND t.id = cc.movie_id AND t.id = mi_idx.movie_id AND mk.movie_id = ci.movie_id AND mk.movie_id = cc.movie_id AND mk.movie_id = mi_idx.movie_id AND ci.movie_id = cc.movie_id AND ci.movie_id = mi_idx.movie_id AND cc.movie_id = mi_idx.movie_id AND chn.id = ci.person_role_id AND n.id = ci.person_id AND k.id = mk.keyword_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id AND it2.id = mi_idx.info_type_id; diff --git a/benchmarks/queries/imdb/27a.sql b/benchmarks/queries/imdb/27a.sql new file mode 100644 index 000000000000..239673cd8147 --- /dev/null +++ b/benchmarks/queries/imdb/27a.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS producing_company, MIN(lt.link) AS link_type, MIN(t.title) AS complete_western_sequel FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t WHERE cct1.kind in ('cast', 'crew') AND cct2.kind = 'complete' AND cn.country_code !='[pl]' AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') AND ct.kind ='production companies' AND k.keyword ='sequel' AND lt.link LIKE '%follow%' AND mc.note IS NULL AND mi.info IN ('Sweden', 'Germany','Swedish', 'German') AND t.production_year BETWEEN 1950 AND 2000 AND lt.id = ml.link_type_id AND ml.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_type_id = ct.id AND mc.company_id = cn.id AND mi.movie_id = t.id AND t.id = cc.movie_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id AND ml.movie_id = mk.movie_id AND ml.movie_id = mc.movie_id AND mk.movie_id = mc.movie_id AND ml.movie_id = mi.movie_id AND mk.movie_id = mi.movie_id AND mc.movie_id = mi.movie_id AND ml.movie_id = cc.movie_id AND mk.movie_id = cc.movie_id AND mc.movie_id = cc.movie_id AND mi.movie_id = cc.movie_id; diff --git a/benchmarks/queries/imdb/27b.sql b/benchmarks/queries/imdb/27b.sql new file mode 100644 index 000000000000..4bf85260f22d --- /dev/null +++ b/benchmarks/queries/imdb/27b.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS producing_company, MIN(lt.link) AS link_type, MIN(t.title) AS complete_western_sequel FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t WHERE cct1.kind in ('cast', 'crew') AND cct2.kind = 'complete' AND cn.country_code !='[pl]' AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') AND ct.kind ='production companies' AND k.keyword ='sequel' AND lt.link LIKE '%follow%' AND mc.note IS NULL AND mi.info IN ('Sweden', 'Germany','Swedish', 'German') AND t.production_year = 1998 AND lt.id = ml.link_type_id AND ml.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_type_id = ct.id AND mc.company_id = cn.id AND mi.movie_id = t.id AND t.id = cc.movie_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id AND ml.movie_id = mk.movie_id AND ml.movie_id = mc.movie_id AND mk.movie_id = mc.movie_id AND ml.movie_id = mi.movie_id AND mk.movie_id = mi.movie_id AND mc.movie_id = mi.movie_id AND ml.movie_id = cc.movie_id AND mk.movie_id = cc.movie_id AND mc.movie_id = cc.movie_id AND mi.movie_id = cc.movie_id; diff --git a/benchmarks/queries/imdb/27c.sql b/benchmarks/queries/imdb/27c.sql new file mode 100644 index 000000000000..dc26ebff6851 --- /dev/null +++ b/benchmarks/queries/imdb/27c.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS producing_company, MIN(lt.link) AS link_type, MIN(t.title) AS complete_western_sequel FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, keyword AS k, link_type AS lt, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, movie_link AS ml, title AS t WHERE cct1.kind = 'cast' AND cct2.kind like 'complete%' AND cn.country_code !='[pl]' AND (cn.name LIKE '%Film%' OR cn.name LIKE '%Warner%') AND ct.kind ='production companies' AND k.keyword ='sequel' AND lt.link LIKE '%follow%' AND mc.note IS NULL AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'English') AND t.production_year BETWEEN 1950 AND 2010 AND lt.id = ml.link_type_id AND ml.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND t.id = mc.movie_id AND mc.company_type_id = ct.id AND mc.company_id = cn.id AND mi.movie_id = t.id AND t.id = cc.movie_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id AND ml.movie_id = mk.movie_id AND ml.movie_id = mc.movie_id AND mk.movie_id = mc.movie_id AND ml.movie_id = mi.movie_id AND mk.movie_id = mi.movie_id AND mc.movie_id = mi.movie_id AND ml.movie_id = cc.movie_id AND mk.movie_id = cc.movie_id AND mc.movie_id = cc.movie_id AND mi.movie_id = cc.movie_id; diff --git a/benchmarks/queries/imdb/28a.sql b/benchmarks/queries/imdb/28a.sql new file mode 100644 index 000000000000..8cb1177386da --- /dev/null +++ b/benchmarks/queries/imdb/28a.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS complete_euro_dark_movie FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t WHERE cct1.kind = 'crew' AND cct2.kind != 'complete+verified' AND cn.country_code != '[us]' AND it1.info = 'countries' AND it2.info = 'rating' AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') AND kt.kind in ('movie', 'episode') AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Danish', 'Norwegian', 'German', 'USA', 'American') AND mi_idx.info < '8.5' AND t.production_year > 2000 AND kt.id = t.kind_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mi_idx.movie_id AND t.id = mc.movie_id AND t.id = cc.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mi_idx.movie_id AND mk.movie_id = mc.movie_id AND mk.movie_id = cc.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mc.movie_id AND mi.movie_id = cc.movie_id AND mc.movie_id = mi_idx.movie_id AND mc.movie_id = cc.movie_id AND mi_idx.movie_id = cc.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND ct.id = mc.company_type_id AND cn.id = mc.company_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id; diff --git a/benchmarks/queries/imdb/28b.sql b/benchmarks/queries/imdb/28b.sql new file mode 100644 index 000000000000..10f43c898226 --- /dev/null +++ b/benchmarks/queries/imdb/28b.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS complete_euro_dark_movie FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t WHERE cct1.kind = 'crew' AND cct2.kind != 'complete+verified' AND cn.country_code != '[us]' AND it1.info = 'countries' AND it2.info = 'rating' AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') AND kt.kind in ('movie', 'episode') AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' AND mi.info IN ('Sweden', 'Germany', 'Swedish', 'German') AND mi_idx.info > '6.5' AND t.production_year > 2005 AND kt.id = t.kind_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mi_idx.movie_id AND t.id = mc.movie_id AND t.id = cc.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mi_idx.movie_id AND mk.movie_id = mc.movie_id AND mk.movie_id = cc.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mc.movie_id AND mi.movie_id = cc.movie_id AND mc.movie_id = mi_idx.movie_id AND mc.movie_id = cc.movie_id AND mi_idx.movie_id = cc.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND ct.id = mc.company_type_id AND cn.id = mc.company_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id; diff --git a/benchmarks/queries/imdb/28c.sql b/benchmarks/queries/imdb/28c.sql new file mode 100644 index 000000000000..6b2e4047ae8a --- /dev/null +++ b/benchmarks/queries/imdb/28c.sql @@ -0,0 +1 @@ +SELECT MIN(cn.name) AS movie_company, MIN(mi_idx.info) AS rating, MIN(t.title) AS complete_euro_dark_movie FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, company_name AS cn, company_type AS ct, info_type AS it1, info_type AS it2, keyword AS k, kind_type AS kt, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t WHERE cct1.kind = 'cast' AND cct2.kind = 'complete' AND cn.country_code != '[us]' AND it1.info = 'countries' AND it2.info = 'rating' AND k.keyword in ('murder', 'murder-in-title', 'blood', 'violence') AND kt.kind in ('movie', 'episode') AND mc.note not like '%(USA)%' and mc.note like '%(200%)%' AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Danish', 'Norwegian', 'German', 'USA', 'American') AND mi_idx.info < '8.5' AND t.production_year > 2005 AND kt.id = t.kind_id AND t.id = mi.movie_id AND t.id = mk.movie_id AND t.id = mi_idx.movie_id AND t.id = mc.movie_id AND t.id = cc.movie_id AND mk.movie_id = mi.movie_id AND mk.movie_id = mi_idx.movie_id AND mk.movie_id = mc.movie_id AND mk.movie_id = cc.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mc.movie_id AND mi.movie_id = cc.movie_id AND mc.movie_id = mi_idx.movie_id AND mc.movie_id = cc.movie_id AND mi_idx.movie_id = cc.movie_id AND k.id = mk.keyword_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND ct.id = mc.company_type_id AND cn.id = mc.company_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id; diff --git a/benchmarks/queries/imdb/29a.sql b/benchmarks/queries/imdb/29a.sql new file mode 100644 index 000000000000..3033acbe6cf3 --- /dev/null +++ b/benchmarks/queries/imdb/29a.sql @@ -0,0 +1 @@ +SELECT MIN(chn.name) AS voiced_char, MIN(n.name) AS voicing_actress, MIN(t.title) AS voiced_animation FROM aka_name AS an, complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, info_type AS it3, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, name AS n, person_info AS pi, role_type AS rt, title AS t WHERE cct1.kind ='cast' AND cct2.kind ='complete+verified' AND chn.name = 'Queen' AND ci.note in ('(voice)', '(voice) (uncredited)', '(voice: English version)') AND cn.country_code ='[us]' AND it.info = 'release dates' AND it3.info = 'trivia' AND k.keyword = 'computer-animation' AND mi.info is not null and (mi.info like 'Japan:%200%' or mi.info like 'USA:%200%') AND n.gender ='f' and n.name like '%An%' AND rt.role ='actress' AND t.title = 'Shrek 2' AND t.production_year between 2000 and 2010 AND t.id = mi.movie_id AND t.id = mc.movie_id AND t.id = ci.movie_id AND t.id = mk.movie_id AND t.id = cc.movie_id AND mc.movie_id = ci.movie_id AND mc.movie_id = mi.movie_id AND mc.movie_id = mk.movie_id AND mc.movie_id = cc.movie_id AND mi.movie_id = ci.movie_id AND mi.movie_id = mk.movie_id AND mi.movie_id = cc.movie_id AND ci.movie_id = mk.movie_id AND ci.movie_id = cc.movie_id AND mk.movie_id = cc.movie_id AND cn.id = mc.company_id AND it.id = mi.info_type_id AND n.id = ci.person_id AND rt.id = ci.role_id AND n.id = an.person_id AND ci.person_id = an.person_id AND chn.id = ci.person_role_id AND n.id = pi.person_id AND ci.person_id = pi.person_id AND it3.id = pi.info_type_id AND k.id = mk.keyword_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id; diff --git a/benchmarks/queries/imdb/29b.sql b/benchmarks/queries/imdb/29b.sql new file mode 100644 index 000000000000..88d50fc7b783 --- /dev/null +++ b/benchmarks/queries/imdb/29b.sql @@ -0,0 +1 @@ +SELECT MIN(chn.name) AS voiced_char, MIN(n.name) AS voicing_actress, MIN(t.title) AS voiced_animation FROM aka_name AS an, complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, info_type AS it3, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, name AS n, person_info AS pi, role_type AS rt, title AS t WHERE cct1.kind ='cast' AND cct2.kind ='complete+verified' AND chn.name = 'Queen' AND ci.note in ('(voice)', '(voice) (uncredited)', '(voice: English version)') AND cn.country_code ='[us]' AND it.info = 'release dates' AND it3.info = 'height' AND k.keyword = 'computer-animation' AND mi.info like 'USA:%200%' AND n.gender ='f' and n.name like '%An%' AND rt.role ='actress' AND t.title = 'Shrek 2' AND t.production_year between 2000 and 2005 AND t.id = mi.movie_id AND t.id = mc.movie_id AND t.id = ci.movie_id AND t.id = mk.movie_id AND t.id = cc.movie_id AND mc.movie_id = ci.movie_id AND mc.movie_id = mi.movie_id AND mc.movie_id = mk.movie_id AND mc.movie_id = cc.movie_id AND mi.movie_id = ci.movie_id AND mi.movie_id = mk.movie_id AND mi.movie_id = cc.movie_id AND ci.movie_id = mk.movie_id AND ci.movie_id = cc.movie_id AND mk.movie_id = cc.movie_id AND cn.id = mc.company_id AND it.id = mi.info_type_id AND n.id = ci.person_id AND rt.id = ci.role_id AND n.id = an.person_id AND ci.person_id = an.person_id AND chn.id = ci.person_role_id AND n.id = pi.person_id AND ci.person_id = pi.person_id AND it3.id = pi.info_type_id AND k.id = mk.keyword_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id; diff --git a/benchmarks/queries/imdb/29c.sql b/benchmarks/queries/imdb/29c.sql new file mode 100644 index 000000000000..cb951781827c --- /dev/null +++ b/benchmarks/queries/imdb/29c.sql @@ -0,0 +1 @@ +SELECT MIN(chn.name) AS voiced_char, MIN(n.name) AS voicing_actress, MIN(t.title) AS voiced_animation FROM aka_name AS an, complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, char_name AS chn, cast_info AS ci, company_name AS cn, info_type AS it, info_type AS it3, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_keyword AS mk, name AS n, person_info AS pi, role_type AS rt, title AS t WHERE cct1.kind ='cast' AND cct2.kind ='complete+verified' AND ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') AND cn.country_code ='[us]' AND it.info = 'release dates' AND it3.info = 'trivia' AND k.keyword = 'computer-animation' AND mi.info is not null and (mi.info like 'Japan:%200%' or mi.info like 'USA:%200%') AND n.gender ='f' and n.name like '%An%' AND rt.role ='actress' AND t.production_year between 2000 and 2010 AND t.id = mi.movie_id AND t.id = mc.movie_id AND t.id = ci.movie_id AND t.id = mk.movie_id AND t.id = cc.movie_id AND mc.movie_id = ci.movie_id AND mc.movie_id = mi.movie_id AND mc.movie_id = mk.movie_id AND mc.movie_id = cc.movie_id AND mi.movie_id = ci.movie_id AND mi.movie_id = mk.movie_id AND mi.movie_id = cc.movie_id AND ci.movie_id = mk.movie_id AND ci.movie_id = cc.movie_id AND mk.movie_id = cc.movie_id AND cn.id = mc.company_id AND it.id = mi.info_type_id AND n.id = ci.person_id AND rt.id = ci.role_id AND n.id = an.person_id AND ci.person_id = an.person_id AND chn.id = ci.person_role_id AND n.id = pi.person_id AND ci.person_id = pi.person_id AND it3.id = pi.info_type_id AND k.id = mk.keyword_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id; diff --git a/benchmarks/queries/imdb/2a.sql b/benchmarks/queries/imdb/2a.sql new file mode 100644 index 000000000000..f3ef4db75fea --- /dev/null +++ b/benchmarks/queries/imdb/2a.sql @@ -0,0 +1 @@ +SELECT MIN(t.title) AS movie_title FROM company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, title AS t WHERE cn.country_code ='[de]' AND k.keyword ='character-name-in-title' AND cn.id = mc.company_id AND mc.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND mc.movie_id = mk.movie_id; diff --git a/benchmarks/queries/imdb/2b.sql b/benchmarks/queries/imdb/2b.sql new file mode 100644 index 000000000000..82b2123fbccd --- /dev/null +++ b/benchmarks/queries/imdb/2b.sql @@ -0,0 +1 @@ +SELECT MIN(t.title) AS movie_title FROM company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, title AS t WHERE cn.country_code ='[nl]' AND k.keyword ='character-name-in-title' AND cn.id = mc.company_id AND mc.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND mc.movie_id = mk.movie_id; diff --git a/benchmarks/queries/imdb/2c.sql b/benchmarks/queries/imdb/2c.sql new file mode 100644 index 000000000000..b5f9b75dd68b --- /dev/null +++ b/benchmarks/queries/imdb/2c.sql @@ -0,0 +1 @@ +SELECT MIN(t.title) AS movie_title FROM company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, title AS t WHERE cn.country_code ='[sm]' AND k.keyword ='character-name-in-title' AND cn.id = mc.company_id AND mc.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND mc.movie_id = mk.movie_id; diff --git a/benchmarks/queries/imdb/2d.sql b/benchmarks/queries/imdb/2d.sql new file mode 100644 index 000000000000..4a2791946548 --- /dev/null +++ b/benchmarks/queries/imdb/2d.sql @@ -0,0 +1 @@ +SELECT MIN(t.title) AS movie_title FROM company_name AS cn, keyword AS k, movie_companies AS mc, movie_keyword AS mk, title AS t WHERE cn.country_code ='[us]' AND k.keyword ='character-name-in-title' AND cn.id = mc.company_id AND mc.movie_id = t.id AND t.id = mk.movie_id AND mk.keyword_id = k.id AND mc.movie_id = mk.movie_id; diff --git a/benchmarks/queries/imdb/30a.sql b/benchmarks/queries/imdb/30a.sql new file mode 100644 index 000000000000..698872fa8337 --- /dev/null +++ b/benchmarks/queries/imdb/30a.sql @@ -0,0 +1 @@ +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS complete_violent_movie FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t WHERE cct1.kind in ('cast', 'crew') AND cct2.kind ='complete+verified' AND ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') AND it1.info = 'genres' AND it2.info = 'votes' AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') AND mi.info in ('Horror', 'Thriller') AND n.gender = 'm' AND t.production_year > 2000 AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND t.id = ci.movie_id AND t.id = mk.movie_id AND t.id = cc.movie_id AND ci.movie_id = mi.movie_id AND ci.movie_id = mi_idx.movie_id AND ci.movie_id = mk.movie_id AND ci.movie_id = cc.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mk.movie_id AND mi.movie_id = cc.movie_id AND mi_idx.movie_id = mk.movie_id AND mi_idx.movie_id = cc.movie_id AND mk.movie_id = cc.movie_id AND n.id = ci.person_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND k.id = mk.keyword_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id; diff --git a/benchmarks/queries/imdb/30b.sql b/benchmarks/queries/imdb/30b.sql new file mode 100644 index 000000000000..5fdb8493496c --- /dev/null +++ b/benchmarks/queries/imdb/30b.sql @@ -0,0 +1 @@ +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS complete_gore_movie FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t WHERE cct1.kind in ('cast', 'crew') AND cct2.kind ='complete+verified' AND ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') AND it1.info = 'genres' AND it2.info = 'votes' AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') AND mi.info in ('Horror', 'Thriller') AND n.gender = 'm' AND t.production_year > 2000 and (t.title like '%Freddy%' or t.title like '%Jason%' or t.title like 'Saw%') AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND t.id = ci.movie_id AND t.id = mk.movie_id AND t.id = cc.movie_id AND ci.movie_id = mi.movie_id AND ci.movie_id = mi_idx.movie_id AND ci.movie_id = mk.movie_id AND ci.movie_id = cc.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mk.movie_id AND mi.movie_id = cc.movie_id AND mi_idx.movie_id = mk.movie_id AND mi_idx.movie_id = cc.movie_id AND mk.movie_id = cc.movie_id AND n.id = ci.person_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND k.id = mk.keyword_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id; diff --git a/benchmarks/queries/imdb/30c.sql b/benchmarks/queries/imdb/30c.sql new file mode 100644 index 000000000000..a18087e39222 --- /dev/null +++ b/benchmarks/queries/imdb/30c.sql @@ -0,0 +1 @@ +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS complete_violent_movie FROM complete_cast AS cc, comp_cast_type AS cct1, comp_cast_type AS cct2, cast_info AS ci, info_type AS it1, info_type AS it2, keyword AS k, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t WHERE cct1.kind = 'cast' AND cct2.kind ='complete+verified' AND ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') AND it1.info = 'genres' AND it2.info = 'votes' AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') AND mi.info in ('Horror', 'Action', 'Sci-Fi', 'Thriller', 'Crime', 'War') AND n.gender = 'm' AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND t.id = ci.movie_id AND t.id = mk.movie_id AND t.id = cc.movie_id AND ci.movie_id = mi.movie_id AND ci.movie_id = mi_idx.movie_id AND ci.movie_id = mk.movie_id AND ci.movie_id = cc.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mk.movie_id AND mi.movie_id = cc.movie_id AND mi_idx.movie_id = mk.movie_id AND mi_idx.movie_id = cc.movie_id AND mk.movie_id = cc.movie_id AND n.id = ci.person_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND k.id = mk.keyword_id AND cct1.id = cc.subject_id AND cct2.id = cc.status_id; diff --git a/benchmarks/queries/imdb/31a.sql b/benchmarks/queries/imdb/31a.sql new file mode 100644 index 000000000000..7dd855011f2a --- /dev/null +++ b/benchmarks/queries/imdb/31a.sql @@ -0,0 +1 @@ +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS violent_liongate_movie FROM cast_info AS ci, company_name AS cn, info_type AS it1, info_type AS it2, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') AND cn.name like 'Lionsgate%' AND it1.info = 'genres' AND it2.info = 'votes' AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') AND mi.info in ('Horror', 'Thriller') AND n.gender = 'm' AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND t.id = ci.movie_id AND t.id = mk.movie_id AND t.id = mc.movie_id AND ci.movie_id = mi.movie_id AND ci.movie_id = mi_idx.movie_id AND ci.movie_id = mk.movie_id AND ci.movie_id = mc.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mk.movie_id AND mi.movie_id = mc.movie_id AND mi_idx.movie_id = mk.movie_id AND mi_idx.movie_id = mc.movie_id AND mk.movie_id = mc.movie_id AND n.id = ci.person_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND k.id = mk.keyword_id AND cn.id = mc.company_id; diff --git a/benchmarks/queries/imdb/31b.sql b/benchmarks/queries/imdb/31b.sql new file mode 100644 index 000000000000..3be5680f7d00 --- /dev/null +++ b/benchmarks/queries/imdb/31b.sql @@ -0,0 +1 @@ +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS violent_liongate_movie FROM cast_info AS ci, company_name AS cn, info_type AS it1, info_type AS it2, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') AND cn.name like 'Lionsgate%' AND it1.info = 'genres' AND it2.info = 'votes' AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') AND mc.note like '%(Blu-ray)%' AND mi.info in ('Horror', 'Thriller') AND n.gender = 'm' AND t.production_year > 2000 and (t.title like '%Freddy%' or t.title like '%Jason%' or t.title like 'Saw%') AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND t.id = ci.movie_id AND t.id = mk.movie_id AND t.id = mc.movie_id AND ci.movie_id = mi.movie_id AND ci.movie_id = mi_idx.movie_id AND ci.movie_id = mk.movie_id AND ci.movie_id = mc.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mk.movie_id AND mi.movie_id = mc.movie_id AND mi_idx.movie_id = mk.movie_id AND mi_idx.movie_id = mc.movie_id AND mk.movie_id = mc.movie_id AND n.id = ci.person_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND k.id = mk.keyword_id AND cn.id = mc.company_id; diff --git a/benchmarks/queries/imdb/31c.sql b/benchmarks/queries/imdb/31c.sql new file mode 100644 index 000000000000..156ea2d5eee2 --- /dev/null +++ b/benchmarks/queries/imdb/31c.sql @@ -0,0 +1 @@ +SELECT MIN(mi.info) AS movie_budget, MIN(mi_idx.info) AS movie_votes, MIN(n.name) AS writer, MIN(t.title) AS violent_liongate_movie FROM cast_info AS ci, company_name AS cn, info_type AS it1, info_type AS it2, keyword AS k, movie_companies AS mc, movie_info AS mi, movie_info_idx AS mi_idx, movie_keyword AS mk, name AS n, title AS t WHERE ci.note in ('(writer)', '(head writer)', '(written by)', '(story)', '(story editor)') AND cn.name like 'Lionsgate%' AND it1.info = 'genres' AND it2.info = 'votes' AND k.keyword in ('murder', 'violence', 'blood', 'gore', 'death', 'female-nudity', 'hospital') AND mi.info in ('Horror', 'Action', 'Sci-Fi', 'Thriller', 'Crime', 'War') AND t.id = mi.movie_id AND t.id = mi_idx.movie_id AND t.id = ci.movie_id AND t.id = mk.movie_id AND t.id = mc.movie_id AND ci.movie_id = mi.movie_id AND ci.movie_id = mi_idx.movie_id AND ci.movie_id = mk.movie_id AND ci.movie_id = mc.movie_id AND mi.movie_id = mi_idx.movie_id AND mi.movie_id = mk.movie_id AND mi.movie_id = mc.movie_id AND mi_idx.movie_id = mk.movie_id AND mi_idx.movie_id = mc.movie_id AND mk.movie_id = mc.movie_id AND n.id = ci.person_id AND it1.id = mi.info_type_id AND it2.id = mi_idx.info_type_id AND k.id = mk.keyword_id AND cn.id = mc.company_id; diff --git a/benchmarks/queries/imdb/32a.sql b/benchmarks/queries/imdb/32a.sql new file mode 100644 index 000000000000..9647fb71065d --- /dev/null +++ b/benchmarks/queries/imdb/32a.sql @@ -0,0 +1 @@ +SELECT MIN(lt.link) AS link_type, MIN(t1.title) AS first_movie, MIN(t2.title) AS second_movie FROM keyword AS k, link_type AS lt, movie_keyword AS mk, movie_link AS ml, title AS t1, title AS t2 WHERE k.keyword ='10,000-mile-club' AND mk.keyword_id = k.id AND t1.id = mk.movie_id AND ml.movie_id = t1.id AND ml.linked_movie_id = t2.id AND lt.id = ml.link_type_id AND mk.movie_id = t1.id; diff --git a/benchmarks/queries/imdb/32b.sql b/benchmarks/queries/imdb/32b.sql new file mode 100644 index 000000000000..6d096ab43405 --- /dev/null +++ b/benchmarks/queries/imdb/32b.sql @@ -0,0 +1 @@ +SELECT MIN(lt.link) AS link_type, MIN(t1.title) AS first_movie, MIN(t2.title) AS second_movie FROM keyword AS k, link_type AS lt, movie_keyword AS mk, movie_link AS ml, title AS t1, title AS t2 WHERE k.keyword ='character-name-in-title' AND mk.keyword_id = k.id AND t1.id = mk.movie_id AND ml.movie_id = t1.id AND ml.linked_movie_id = t2.id AND lt.id = ml.link_type_id AND mk.movie_id = t1.id; diff --git a/benchmarks/queries/imdb/33a.sql b/benchmarks/queries/imdb/33a.sql new file mode 100644 index 000000000000..24aac4e20797 --- /dev/null +++ b/benchmarks/queries/imdb/33a.sql @@ -0,0 +1 @@ +SELECT MIN(cn1.name) AS first_company, MIN(cn2.name) AS second_company, MIN(mi_idx1.info) AS first_rating, MIN(mi_idx2.info) AS second_rating, MIN(t1.title) AS first_movie, MIN(t2.title) AS second_movie FROM company_name AS cn1, company_name AS cn2, info_type AS it1, info_type AS it2, kind_type AS kt1, kind_type AS kt2, link_type AS lt, movie_companies AS mc1, movie_companies AS mc2, movie_info_idx AS mi_idx1, movie_info_idx AS mi_idx2, movie_link AS ml, title AS t1, title AS t2 WHERE cn1.country_code = '[us]' AND it1.info = 'rating' AND it2.info = 'rating' AND kt1.kind in ('tv series') AND kt2.kind in ('tv series') AND lt.link in ('sequel', 'follows', 'followed by') AND mi_idx2.info < '3.0' AND t2.production_year between 2005 and 2008 AND lt.id = ml.link_type_id AND t1.id = ml.movie_id AND t2.id = ml.linked_movie_id AND it1.id = mi_idx1.info_type_id AND t1.id = mi_idx1.movie_id AND kt1.id = t1.kind_id AND cn1.id = mc1.company_id AND t1.id = mc1.movie_id AND ml.movie_id = mi_idx1.movie_id AND ml.movie_id = mc1.movie_id AND mi_idx1.movie_id = mc1.movie_id AND it2.id = mi_idx2.info_type_id AND t2.id = mi_idx2.movie_id AND kt2.id = t2.kind_id AND cn2.id = mc2.company_id AND t2.id = mc2.movie_id AND ml.linked_movie_id = mi_idx2.movie_id AND ml.linked_movie_id = mc2.movie_id AND mi_idx2.movie_id = mc2.movie_id; diff --git a/benchmarks/queries/imdb/33b.sql b/benchmarks/queries/imdb/33b.sql new file mode 100644 index 000000000000..fe6fd75a6948 --- /dev/null +++ b/benchmarks/queries/imdb/33b.sql @@ -0,0 +1 @@ +SELECT MIN(cn1.name) AS first_company, MIN(cn2.name) AS second_company, MIN(mi_idx1.info) AS first_rating, MIN(mi_idx2.info) AS second_rating, MIN(t1.title) AS first_movie, MIN(t2.title) AS second_movie FROM company_name AS cn1, company_name AS cn2, info_type AS it1, info_type AS it2, kind_type AS kt1, kind_type AS kt2, link_type AS lt, movie_companies AS mc1, movie_companies AS mc2, movie_info_idx AS mi_idx1, movie_info_idx AS mi_idx2, movie_link AS ml, title AS t1, title AS t2 WHERE cn1.country_code = '[nl]' AND it1.info = 'rating' AND it2.info = 'rating' AND kt1.kind in ('tv series') AND kt2.kind in ('tv series') AND lt.link LIKE '%follow%' AND mi_idx2.info < '3.0' AND t2.production_year = 2007 AND lt.id = ml.link_type_id AND t1.id = ml.movie_id AND t2.id = ml.linked_movie_id AND it1.id = mi_idx1.info_type_id AND t1.id = mi_idx1.movie_id AND kt1.id = t1.kind_id AND cn1.id = mc1.company_id AND t1.id = mc1.movie_id AND ml.movie_id = mi_idx1.movie_id AND ml.movie_id = mc1.movie_id AND mi_idx1.movie_id = mc1.movie_id AND it2.id = mi_idx2.info_type_id AND t2.id = mi_idx2.movie_id AND kt2.id = t2.kind_id AND cn2.id = mc2.company_id AND t2.id = mc2.movie_id AND ml.linked_movie_id = mi_idx2.movie_id AND ml.linked_movie_id = mc2.movie_id AND mi_idx2.movie_id = mc2.movie_id; diff --git a/benchmarks/queries/imdb/33c.sql b/benchmarks/queries/imdb/33c.sql new file mode 100644 index 000000000000..c9f0907d3f90 --- /dev/null +++ b/benchmarks/queries/imdb/33c.sql @@ -0,0 +1 @@ +SELECT MIN(cn1.name) AS first_company, MIN(cn2.name) AS second_company, MIN(mi_idx1.info) AS first_rating, MIN(mi_idx2.info) AS second_rating, MIN(t1.title) AS first_movie, MIN(t2.title) AS second_movie FROM company_name AS cn1, company_name AS cn2, info_type AS it1, info_type AS it2, kind_type AS kt1, kind_type AS kt2, link_type AS lt, movie_companies AS mc1, movie_companies AS mc2, movie_info_idx AS mi_idx1, movie_info_idx AS mi_idx2, movie_link AS ml, title AS t1, title AS t2 WHERE cn1.country_code != '[us]' AND it1.info = 'rating' AND it2.info = 'rating' AND kt1.kind in ('tv series', 'episode') AND kt2.kind in ('tv series', 'episode') AND lt.link in ('sequel', 'follows', 'followed by') AND mi_idx2.info < '3.5' AND t2.production_year between 2000 and 2010 AND lt.id = ml.link_type_id AND t1.id = ml.movie_id AND t2.id = ml.linked_movie_id AND it1.id = mi_idx1.info_type_id AND t1.id = mi_idx1.movie_id AND kt1.id = t1.kind_id AND cn1.id = mc1.company_id AND t1.id = mc1.movie_id AND ml.movie_id = mi_idx1.movie_id AND ml.movie_id = mc1.movie_id AND mi_idx1.movie_id = mc1.movie_id AND it2.id = mi_idx2.info_type_id AND t2.id = mi_idx2.movie_id AND kt2.id = t2.kind_id AND cn2.id = mc2.company_id AND t2.id = mc2.movie_id AND ml.linked_movie_id = mi_idx2.movie_id AND ml.linked_movie_id = mc2.movie_id AND mi_idx2.movie_id = mc2.movie_id; diff --git a/benchmarks/queries/imdb/3a.sql b/benchmarks/queries/imdb/3a.sql new file mode 100644 index 000000000000..231c957be207 --- /dev/null +++ b/benchmarks/queries/imdb/3a.sql @@ -0,0 +1 @@ +SELECT MIN(t.title) AS movie_title FROM keyword AS k, movie_info AS mi, movie_keyword AS mk, title AS t WHERE k.keyword like '%sequel%' AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German') AND t.production_year > 2005 AND t.id = mi.movie_id AND t.id = mk.movie_id AND mk.movie_id = mi.movie_id AND k.id = mk.keyword_id; diff --git a/benchmarks/queries/imdb/3b.sql b/benchmarks/queries/imdb/3b.sql new file mode 100644 index 000000000000..fd21efc81014 --- /dev/null +++ b/benchmarks/queries/imdb/3b.sql @@ -0,0 +1 @@ +SELECT MIN(t.title) AS movie_title FROM keyword AS k, movie_info AS mi, movie_keyword AS mk, title AS t WHERE k.keyword like '%sequel%' AND mi.info IN ('Bulgaria') AND t.production_year > 2010 AND t.id = mi.movie_id AND t.id = mk.movie_id AND mk.movie_id = mi.movie_id AND k.id = mk.keyword_id; diff --git a/benchmarks/queries/imdb/3c.sql b/benchmarks/queries/imdb/3c.sql new file mode 100644 index 000000000000..5f34232a2e61 --- /dev/null +++ b/benchmarks/queries/imdb/3c.sql @@ -0,0 +1 @@ +SELECT MIN(t.title) AS movie_title FROM keyword AS k, movie_info AS mi, movie_keyword AS mk, title AS t WHERE k.keyword like '%sequel%' AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'USA', 'American') AND t.production_year > 1990 AND t.id = mi.movie_id AND t.id = mk.movie_id AND mk.movie_id = mi.movie_id AND k.id = mk.keyword_id; diff --git a/benchmarks/queries/imdb/4a.sql b/benchmarks/queries/imdb/4a.sql new file mode 100644 index 000000000000..636afab02c8a --- /dev/null +++ b/benchmarks/queries/imdb/4a.sql @@ -0,0 +1 @@ +SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS movie_title FROM info_type AS it, keyword AS k, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t WHERE it.info ='rating' AND k.keyword like '%sequel%' AND mi_idx.info > '5.0' AND t.production_year > 2005 AND t.id = mi_idx.movie_id AND t.id = mk.movie_id AND mk.movie_id = mi_idx.movie_id AND k.id = mk.keyword_id AND it.id = mi_idx.info_type_id; diff --git a/benchmarks/queries/imdb/4b.sql b/benchmarks/queries/imdb/4b.sql new file mode 100644 index 000000000000..ebd3e8992060 --- /dev/null +++ b/benchmarks/queries/imdb/4b.sql @@ -0,0 +1 @@ +SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS movie_title FROM info_type AS it, keyword AS k, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t WHERE it.info ='rating' AND k.keyword like '%sequel%' AND mi_idx.info > '9.0' AND t.production_year > 2010 AND t.id = mi_idx.movie_id AND t.id = mk.movie_id AND mk.movie_id = mi_idx.movie_id AND k.id = mk.keyword_id AND it.id = mi_idx.info_type_id; diff --git a/benchmarks/queries/imdb/4c.sql b/benchmarks/queries/imdb/4c.sql new file mode 100644 index 000000000000..309281200f98 --- /dev/null +++ b/benchmarks/queries/imdb/4c.sql @@ -0,0 +1 @@ +SELECT MIN(mi_idx.info) AS rating, MIN(t.title) AS movie_title FROM info_type AS it, keyword AS k, movie_info_idx AS mi_idx, movie_keyword AS mk, title AS t WHERE it.info ='rating' AND k.keyword like '%sequel%' AND mi_idx.info > '2.0' AND t.production_year > 1990 AND t.id = mi_idx.movie_id AND t.id = mk.movie_id AND mk.movie_id = mi_idx.movie_id AND k.id = mk.keyword_id AND it.id = mi_idx.info_type_id; diff --git a/benchmarks/queries/imdb/5a.sql b/benchmarks/queries/imdb/5a.sql new file mode 100644 index 000000000000..04aae9881f7e --- /dev/null +++ b/benchmarks/queries/imdb/5a.sql @@ -0,0 +1 @@ +SELECT MIN(t.title) AS typical_european_movie FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info AS mi, title AS t WHERE ct.kind = 'production companies' AND mc.note like '%(theatrical)%' and mc.note like '%(France)%' AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German') AND t.production_year > 2005 AND t.id = mi.movie_id AND t.id = mc.movie_id AND mc.movie_id = mi.movie_id AND ct.id = mc.company_type_id AND it.id = mi.info_type_id; diff --git a/benchmarks/queries/imdb/5b.sql b/benchmarks/queries/imdb/5b.sql new file mode 100644 index 000000000000..f03a519d61b3 --- /dev/null +++ b/benchmarks/queries/imdb/5b.sql @@ -0,0 +1 @@ +SELECT MIN(t.title) AS american_vhs_movie FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info AS mi, title AS t WHERE ct.kind = 'production companies' AND mc.note like '%(VHS)%' and mc.note like '%(USA)%' and mc.note like '%(1994)%' AND mi.info IN ('USA', 'America') AND t.production_year > 2010 AND t.id = mi.movie_id AND t.id = mc.movie_id AND mc.movie_id = mi.movie_id AND ct.id = mc.company_type_id AND it.id = mi.info_type_id; diff --git a/benchmarks/queries/imdb/5c.sql b/benchmarks/queries/imdb/5c.sql new file mode 100644 index 000000000000..2705e7e2c7a0 --- /dev/null +++ b/benchmarks/queries/imdb/5c.sql @@ -0,0 +1 @@ +SELECT MIN(t.title) AS american_movie FROM company_type AS ct, info_type AS it, movie_companies AS mc, movie_info AS mi, title AS t WHERE ct.kind = 'production companies' AND mc.note not like '%(TV)%' and mc.note like '%(USA)%' AND mi.info IN ('Sweden', 'Norway', 'Germany', 'Denmark', 'Swedish', 'Denish', 'Norwegian', 'German', 'USA', 'American') AND t.production_year > 1990 AND t.id = mi.movie_id AND t.id = mc.movie_id AND mc.movie_id = mi.movie_id AND ct.id = mc.company_type_id AND it.id = mi.info_type_id; diff --git a/benchmarks/queries/imdb/6a.sql b/benchmarks/queries/imdb/6a.sql new file mode 100644 index 000000000000..34b3a6da5fd2 --- /dev/null +++ b/benchmarks/queries/imdb/6a.sql @@ -0,0 +1 @@ +SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS marvel_movie FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n, title AS t WHERE k.keyword = 'marvel-cinematic-universe' AND n.name LIKE '%Downey%Robert%' AND t.production_year > 2010 AND k.id = mk.keyword_id AND t.id = mk.movie_id AND t.id = ci.movie_id AND ci.movie_id = mk.movie_id AND n.id = ci.person_id; diff --git a/benchmarks/queries/imdb/6b.sql b/benchmarks/queries/imdb/6b.sql new file mode 100644 index 000000000000..1233c41e66b0 --- /dev/null +++ b/benchmarks/queries/imdb/6b.sql @@ -0,0 +1 @@ +SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS hero_movie FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n, title AS t WHERE k.keyword in ('superhero', 'sequel', 'second-part', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence') AND n.name LIKE '%Downey%Robert%' AND t.production_year > 2014 AND k.id = mk.keyword_id AND t.id = mk.movie_id AND t.id = ci.movie_id AND ci.movie_id = mk.movie_id AND n.id = ci.person_id; diff --git a/benchmarks/queries/imdb/6c.sql b/benchmarks/queries/imdb/6c.sql new file mode 100644 index 000000000000..d1f97746e15e --- /dev/null +++ b/benchmarks/queries/imdb/6c.sql @@ -0,0 +1 @@ +SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS marvel_movie FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n, title AS t WHERE k.keyword = 'marvel-cinematic-universe' AND n.name LIKE '%Downey%Robert%' AND t.production_year > 2014 AND k.id = mk.keyword_id AND t.id = mk.movie_id AND t.id = ci.movie_id AND ci.movie_id = mk.movie_id AND n.id = ci.person_id; diff --git a/benchmarks/queries/imdb/6d.sql b/benchmarks/queries/imdb/6d.sql new file mode 100644 index 000000000000..07729510a454 --- /dev/null +++ b/benchmarks/queries/imdb/6d.sql @@ -0,0 +1 @@ +SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS hero_movie FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n, title AS t WHERE k.keyword in ('superhero', 'sequel', 'second-part', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence') AND n.name LIKE '%Downey%Robert%' AND t.production_year > 2000 AND k.id = mk.keyword_id AND t.id = mk.movie_id AND t.id = ci.movie_id AND ci.movie_id = mk.movie_id AND n.id = ci.person_id; diff --git a/benchmarks/queries/imdb/6e.sql b/benchmarks/queries/imdb/6e.sql new file mode 100644 index 000000000000..2e77873fd81d --- /dev/null +++ b/benchmarks/queries/imdb/6e.sql @@ -0,0 +1 @@ +SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS marvel_movie FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n, title AS t WHERE k.keyword = 'marvel-cinematic-universe' AND n.name LIKE '%Downey%Robert%' AND t.production_year > 2000 AND k.id = mk.keyword_id AND t.id = mk.movie_id AND t.id = ci.movie_id AND ci.movie_id = mk.movie_id AND n.id = ci.person_id; diff --git a/benchmarks/queries/imdb/6f.sql b/benchmarks/queries/imdb/6f.sql new file mode 100644 index 000000000000..603901129107 --- /dev/null +++ b/benchmarks/queries/imdb/6f.sql @@ -0,0 +1 @@ +SELECT MIN(k.keyword) AS movie_keyword, MIN(n.name) AS actor_name, MIN(t.title) AS hero_movie FROM cast_info AS ci, keyword AS k, movie_keyword AS mk, name AS n, title AS t WHERE k.keyword in ('superhero', 'sequel', 'second-part', 'marvel-comics', 'based-on-comic', 'tv-special', 'fight', 'violence') AND t.production_year > 2000 AND k.id = mk.keyword_id AND t.id = mk.movie_id AND t.id = ci.movie_id AND ci.movie_id = mk.movie_id AND n.id = ci.person_id; diff --git a/benchmarks/queries/imdb/7a.sql b/benchmarks/queries/imdb/7a.sql new file mode 100644 index 000000000000..c6b26ce36f11 --- /dev/null +++ b/benchmarks/queries/imdb/7a.sql @@ -0,0 +1 @@ +SELECT MIN(n.name) AS of_person, MIN(t.title) AS biography_movie FROM aka_name AS an, cast_info AS ci, info_type AS it, link_type AS lt, movie_link AS ml, name AS n, person_info AS pi, title AS t WHERE an.name LIKE '%a%' AND it.info ='mini biography' AND lt.link ='features' AND n.name_pcode_cf BETWEEN 'A' AND 'F' AND (n.gender='m' OR (n.gender = 'f' AND n.name LIKE 'B%')) AND pi.note ='Volker Boehm' AND t.production_year BETWEEN 1980 AND 1995 AND n.id = an.person_id AND n.id = pi.person_id AND ci.person_id = n.id AND t.id = ci.movie_id AND ml.linked_movie_id = t.id AND lt.id = ml.link_type_id AND it.id = pi.info_type_id AND pi.person_id = an.person_id AND pi.person_id = ci.person_id AND an.person_id = ci.person_id AND ci.movie_id = ml.linked_movie_id; diff --git a/benchmarks/queries/imdb/7b.sql b/benchmarks/queries/imdb/7b.sql new file mode 100644 index 000000000000..4e4f6e7615cb --- /dev/null +++ b/benchmarks/queries/imdb/7b.sql @@ -0,0 +1 @@ +SELECT MIN(n.name) AS of_person, MIN(t.title) AS biography_movie FROM aka_name AS an, cast_info AS ci, info_type AS it, link_type AS lt, movie_link AS ml, name AS n, person_info AS pi, title AS t WHERE an.name LIKE '%a%' AND it.info ='mini biography' AND lt.link ='features' AND n.name_pcode_cf LIKE 'D%' AND n.gender='m' AND pi.note ='Volker Boehm' AND t.production_year BETWEEN 1980 AND 1984 AND n.id = an.person_id AND n.id = pi.person_id AND ci.person_id = n.id AND t.id = ci.movie_id AND ml.linked_movie_id = t.id AND lt.id = ml.link_type_id AND it.id = pi.info_type_id AND pi.person_id = an.person_id AND pi.person_id = ci.person_id AND an.person_id = ci.person_id AND ci.movie_id = ml.linked_movie_id; diff --git a/benchmarks/queries/imdb/7c.sql b/benchmarks/queries/imdb/7c.sql new file mode 100644 index 000000000000..a399342fae02 --- /dev/null +++ b/benchmarks/queries/imdb/7c.sql @@ -0,0 +1 @@ +SELECT MIN(n.name) AS cast_member_name, MIN(pi.info) AS cast_member_info FROM aka_name AS an, cast_info AS ci, info_type AS it, link_type AS lt, movie_link AS ml, name AS n, person_info AS pi, title AS t WHERE an.name is not NULL and (an.name LIKE '%a%' or an.name LIKE 'A%') AND it.info ='mini biography' AND lt.link in ('references', 'referenced in', 'features', 'featured in') AND n.name_pcode_cf BETWEEN 'A' AND 'F' AND (n.gender='m' OR (n.gender = 'f' AND n.name LIKE 'A%')) AND pi.note is not NULL AND t.production_year BETWEEN 1980 AND 2010 AND n.id = an.person_id AND n.id = pi.person_id AND ci.person_id = n.id AND t.id = ci.movie_id AND ml.linked_movie_id = t.id AND lt.id = ml.link_type_id AND it.id = pi.info_type_id AND pi.person_id = an.person_id AND pi.person_id = ci.person_id AND an.person_id = ci.person_id AND ci.movie_id = ml.linked_movie_id; diff --git a/benchmarks/queries/imdb/8a.sql b/benchmarks/queries/imdb/8a.sql new file mode 100644 index 000000000000..66ed05880d5f --- /dev/null +++ b/benchmarks/queries/imdb/8a.sql @@ -0,0 +1 @@ +SELECT MIN(an1.name) AS actress_pseudonym, MIN(t.title) AS japanese_movie_dubbed FROM aka_name AS an1, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n1, role_type AS rt, title AS t WHERE ci.note ='(voice: English version)' AND cn.country_code ='[jp]' AND mc.note like '%(Japan)%' and mc.note not like '%(USA)%' AND n1.name like '%Yo%' and n1.name not like '%Yu%' AND rt.role ='actress' AND an1.person_id = n1.id AND n1.id = ci.person_id AND ci.movie_id = t.id AND t.id = mc.movie_id AND mc.company_id = cn.id AND ci.role_id = rt.id AND an1.person_id = ci.person_id AND ci.movie_id = mc.movie_id; diff --git a/benchmarks/queries/imdb/8b.sql b/benchmarks/queries/imdb/8b.sql new file mode 100644 index 000000000000..044b5f8e8649 --- /dev/null +++ b/benchmarks/queries/imdb/8b.sql @@ -0,0 +1 @@ +SELECT MIN(an.name) AS acress_pseudonym, MIN(t.title) AS japanese_anime_movie FROM aka_name AS an, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n, role_type AS rt, title AS t WHERE ci.note ='(voice: English version)' AND cn.country_code ='[jp]' AND mc.note like '%(Japan)%' and mc.note not like '%(USA)%' and (mc.note like '%(2006)%' or mc.note like '%(2007)%') AND n.name like '%Yo%' and n.name not like '%Yu%' AND rt.role ='actress' AND t.production_year between 2006 and 2007 and (t.title like 'One Piece%' or t.title like 'Dragon Ball Z%') AND an.person_id = n.id AND n.id = ci.person_id AND ci.movie_id = t.id AND t.id = mc.movie_id AND mc.company_id = cn.id AND ci.role_id = rt.id AND an.person_id = ci.person_id AND ci.movie_id = mc.movie_id; diff --git a/benchmarks/queries/imdb/8c.sql b/benchmarks/queries/imdb/8c.sql new file mode 100644 index 000000000000..d02b74c02c5e --- /dev/null +++ b/benchmarks/queries/imdb/8c.sql @@ -0,0 +1 @@ +SELECT MIN(a1.name) AS writer_pseudo_name, MIN(t.title) AS movie_title FROM aka_name AS a1, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n1, role_type AS rt, title AS t WHERE cn.country_code ='[us]' AND rt.role ='writer' AND a1.person_id = n1.id AND n1.id = ci.person_id AND ci.movie_id = t.id AND t.id = mc.movie_id AND mc.company_id = cn.id AND ci.role_id = rt.id AND a1.person_id = ci.person_id AND ci.movie_id = mc.movie_id; diff --git a/benchmarks/queries/imdb/8d.sql b/benchmarks/queries/imdb/8d.sql new file mode 100644 index 000000000000..0834c0ff5cb7 --- /dev/null +++ b/benchmarks/queries/imdb/8d.sql @@ -0,0 +1 @@ +SELECT MIN(an1.name) AS costume_designer_pseudo, MIN(t.title) AS movie_with_costumes FROM aka_name AS an1, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n1, role_type AS rt, title AS t WHERE cn.country_code ='[us]' AND rt.role ='costume designer' AND an1.person_id = n1.id AND n1.id = ci.person_id AND ci.movie_id = t.id AND t.id = mc.movie_id AND mc.company_id = cn.id AND ci.role_id = rt.id AND an1.person_id = ci.person_id AND ci.movie_id = mc.movie_id; diff --git a/benchmarks/queries/imdb/9a.sql b/benchmarks/queries/imdb/9a.sql new file mode 100644 index 000000000000..593b16213b06 --- /dev/null +++ b/benchmarks/queries/imdb/9a.sql @@ -0,0 +1 @@ +SELECT MIN(an.name) AS alternative_name, MIN(chn.name) AS character_name, MIN(t.title) AS movie FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n, role_type AS rt, title AS t WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') AND cn.country_code ='[us]' AND mc.note is not NULL and (mc.note like '%(USA)%' or mc.note like '%(worldwide)%') AND n.gender ='f' and n.name like '%Ang%' AND rt.role ='actress' AND t.production_year between 2005 and 2015 AND ci.movie_id = t.id AND t.id = mc.movie_id AND ci.movie_id = mc.movie_id AND mc.company_id = cn.id AND ci.role_id = rt.id AND n.id = ci.person_id AND chn.id = ci.person_role_id AND an.person_id = n.id AND an.person_id = ci.person_id; diff --git a/benchmarks/queries/imdb/9b.sql b/benchmarks/queries/imdb/9b.sql new file mode 100644 index 000000000000..a4933fd6856e --- /dev/null +++ b/benchmarks/queries/imdb/9b.sql @@ -0,0 +1 @@ +SELECT MIN(an.name) AS alternative_name, MIN(chn.name) AS voiced_character, MIN(n.name) AS voicing_actress, MIN(t.title) AS american_movie FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n, role_type AS rt, title AS t WHERE ci.note = '(voice)' AND cn.country_code ='[us]' AND mc.note like '%(200%)%' and (mc.note like '%(USA)%' or mc.note like '%(worldwide)%') AND n.gender ='f' and n.name like '%Angel%' AND rt.role ='actress' AND t.production_year between 2007 and 2010 AND ci.movie_id = t.id AND t.id = mc.movie_id AND ci.movie_id = mc.movie_id AND mc.company_id = cn.id AND ci.role_id = rt.id AND n.id = ci.person_id AND chn.id = ci.person_role_id AND an.person_id = n.id AND an.person_id = ci.person_id; diff --git a/benchmarks/queries/imdb/9c.sql b/benchmarks/queries/imdb/9c.sql new file mode 100644 index 000000000000..0be511810cf6 --- /dev/null +++ b/benchmarks/queries/imdb/9c.sql @@ -0,0 +1 @@ +SELECT MIN(an.name) AS alternative_name, MIN(chn.name) AS voiced_character_name, MIN(n.name) AS voicing_actress, MIN(t.title) AS american_movie FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n, role_type AS rt, title AS t WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') AND cn.country_code ='[us]' AND n.gender ='f' and n.name like '%An%' AND rt.role ='actress' AND ci.movie_id = t.id AND t.id = mc.movie_id AND ci.movie_id = mc.movie_id AND mc.company_id = cn.id AND ci.role_id = rt.id AND n.id = ci.person_id AND chn.id = ci.person_role_id AND an.person_id = n.id AND an.person_id = ci.person_id; diff --git a/benchmarks/queries/imdb/9d.sql b/benchmarks/queries/imdb/9d.sql new file mode 100644 index 000000000000..51262ca5ebae --- /dev/null +++ b/benchmarks/queries/imdb/9d.sql @@ -0,0 +1 @@ +SELECT MIN(an.name) AS alternative_name, MIN(chn.name) AS voiced_char_name, MIN(n.name) AS voicing_actress, MIN(t.title) AS american_movie FROM aka_name AS an, char_name AS chn, cast_info AS ci, company_name AS cn, movie_companies AS mc, name AS n, role_type AS rt, title AS t WHERE ci.note in ('(voice)', '(voice: Japanese version)', '(voice) (uncredited)', '(voice: English version)') AND cn.country_code ='[us]' AND n.gender ='f' AND rt.role ='actress' AND ci.movie_id = t.id AND t.id = mc.movie_id AND ci.movie_id = mc.movie_id AND mc.company_id = cn.id AND ci.role_id = rt.id AND n.id = ci.person_id AND chn.id = ci.person_role_id AND an.person_id = n.id AND an.person_id = ci.person_id; diff --git a/.github_changelog_generator b/benchmarks/requirements.txt similarity index 62% rename from .github_changelog_generator rename to benchmarks/requirements.txt index 45eef2f51836..20a5a2bddbf2 100644 --- a/.github_changelog_generator +++ b/benchmarks/requirements.txt @@ -1,5 +1,3 @@ -#!/bin/bash -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -16,13 +14,5 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# -# some issues are just documentation -add-sections={"documentation":{"prefix":"**Documentation updates:**","labels":["documentation"]},"performance":{"prefix":"**Performance improvements:**","labels":["performance"]}} -# uncomment to not show PRs. TBD if we shown them or not. -#pull-requests=false -# so that the component is shown associated with the issue -issue-line-labels=sql -exclude-labels=development-process,invalid -breaking-labels=api change +rich diff --git a/benchmarks/src/bin/dfbench.rs b/benchmarks/src/bin/dfbench.rs index 441b6cdc0293..f7b84116e793 100644 --- a/benchmarks/src/bin/dfbench.rs +++ b/benchmarks/src/bin/dfbench.rs @@ -20,6 +20,11 @@ use datafusion::error::Result; use structopt::StructOpt; +#[cfg(all(feature = "snmalloc", feature = "mimalloc"))] +compile_error!( + "feature \"snmalloc\" and feature \"mimalloc\" cannot be enabled at the same time" +); + #[cfg(feature = "snmalloc")] #[global_allocator] static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; @@ -28,7 +33,7 @@ static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; #[global_allocator] static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; -use datafusion_benchmarks::{clickbench, parquet_filter, sort, tpch}; +use datafusion_benchmarks::{clickbench, imdb, parquet_filter, sort, tpch}; #[derive(Debug, StructOpt)] #[structopt(about = "benchmark command")] @@ -38,6 +43,7 @@ enum Options { Clickbench(clickbench::RunOpt), ParquetFilter(parquet_filter::RunOpt), Sort(sort::RunOpt), + Imdb(imdb::RunOpt), } // Main benchmark runner entrypoint @@ -51,5 +57,6 @@ pub async fn main() -> Result<()> { Options::Clickbench(opt) => opt.run().await, Options::ParquetFilter(opt) => opt.run().await, Options::Sort(opt) => opt.run().await, + Options::Imdb(opt) => opt.run().await, } } diff --git a/benchmarks/src/bin/external_aggr.rs b/benchmarks/src/bin/external_aggr.rs new file mode 100644 index 000000000000..6438593a20a0 --- /dev/null +++ b/benchmarks/src/bin/external_aggr.rs @@ -0,0 +1,385 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! external_aggr binary entrypoint + +use std::collections::HashMap; +use std::path::PathBuf; +use std::sync::Arc; +use std::sync::OnceLock; +use structopt::StructOpt; + +use arrow::record_batch::RecordBatch; +use arrow::util::pretty; +use datafusion::datasource::file_format::parquet::ParquetFormat; +use datafusion::datasource::listing::{ + ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, +}; +use datafusion::datasource::{MemTable, TableProvider}; +use datafusion::error::Result; +use datafusion::execution::memory_pool::FairSpillPool; +use datafusion::execution::memory_pool::{human_readable_size, units}; +use datafusion::execution::runtime_env::RuntimeConfig; +use datafusion::physical_plan::display::DisplayableExecutionPlan; +use datafusion::physical_plan::{collect, displayable}; +use datafusion::prelude::*; +use datafusion_benchmarks::util::{BenchmarkRun, CommonOpt}; +use datafusion_common::instant::Instant; +use datafusion_common::{exec_datafusion_err, exec_err, DEFAULT_PARQUET_EXTENSION}; + +#[derive(Debug, StructOpt)] +#[structopt( + name = "datafusion-external-aggregation", + about = "DataFusion external aggregation benchmark" +)] +enum ExternalAggrOpt { + Benchmark(ExternalAggrConfig), +} + +#[derive(Debug, StructOpt)] +struct ExternalAggrConfig { + /// Query number. If not specified, runs all queries + #[structopt(short, long)] + query: Option, + + /// Memory limit (e.g. '100M', '1.5G'). If not specified, run all pre-defined memory limits for given query. + #[structopt(long)] + memory_limit: Option, + + /// Common options + #[structopt(flatten)] + common: CommonOpt, + + /// Path to data files (lineitem). Only parquet format is supported + #[structopt(parse(from_os_str), required = true, short = "p", long = "path")] + path: PathBuf, + + /// Load the data into a MemTable before executing the query + #[structopt(short = "m", long = "mem-table")] + mem_table: bool, + + /// Path to JSON benchmark result to be compare using `compare.py` + #[structopt(parse(from_os_str), short = "o", long = "output")] + output_path: Option, +} + +struct QueryResult { + elapsed: std::time::Duration, + row_count: usize, +} + +/// Query Memory Limits +/// Map query id to predefined memory limits +/// +/// Q1 requires 36MiB for aggregation +/// Memory limits to run: 64MiB, 32MiB, 16MiB +/// Q2 requires 250MiB for aggregation +/// Memory limits to run: 512MiB, 256MiB, 128MiB, 64MiB, 32MiB +static QUERY_MEMORY_LIMITS: OnceLock>> = OnceLock::new(); + +impl ExternalAggrConfig { + const AGGR_TABLES: [&'static str; 1] = ["lineitem"]; + const AGGR_QUERIES: [&'static str; 2] = [ + // Q1: Output size is ~25% of lineitem table + r#" + SELECT count(*) + FROM ( + SELECT DISTINCT l_orderkey + FROM lineitem + ) + "#, + // Q2: Output size is ~99% of lineitem table + r#" + SELECT count(*) + FROM ( + SELECT DISTINCT l_orderkey, l_suppkey + FROM lineitem + ) + "#, + ]; + + fn init_query_memory_limits() -> &'static HashMap> { + use units::*; + QUERY_MEMORY_LIMITS.get_or_init(|| { + let mut map = HashMap::new(); + map.insert(1, vec![64 * MB, 32 * MB, 16 * MB]); + map.insert(2, vec![512 * MB, 256 * MB, 128 * MB, 64 * MB, 32 * MB]); + map + }) + } + + /// If `--query` and `--memory-limit` is not speicified, run all queries + /// with pre-configured memory limits + /// If only `--query` is specified, run the query with all memory limits + /// for this query + /// If both `--query` and `--memory-limit` are specified, run the query + /// with the specified memory limit + pub async fn run(&self) -> Result<()> { + let mut benchmark_run = BenchmarkRun::new(); + + let memory_limit = match &self.memory_limit { + Some(limit) => Some(Self::parse_memory_limit(limit)?), + None => None, + }; + + let query_range = match self.query { + Some(query_id) => query_id..=query_id, + None => 1..=Self::AGGR_QUERIES.len(), + }; + + // Each element is (query_id, memory_limit) + // e.g. [(1, 64_000), (1, 32_000)...] means first run Q1 with 64KiB + // memory limit, next run Q1 with 32KiB memory limit, etc. + let mut query_executions = vec![]; + // Setup `query_executions` + for query_id in query_range { + if query_id > Self::AGGR_QUERIES.len() { + return exec_err!( + "Invalid '--query'(query number) {} for external aggregation benchmark.", + query_id + ); + } + + match memory_limit { + Some(limit) => { + query_executions.push((query_id, limit)); + } + None => { + let memory_limits_table = Self::init_query_memory_limits(); + let memory_limits = memory_limits_table.get(&query_id).unwrap(); + for limit in memory_limits { + query_executions.push((query_id, *limit)); + } + } + } + } + + for (query_id, mem_limit) in query_executions { + benchmark_run.start_new_case(&format!( + "{query_id}({})", + human_readable_size(mem_limit as usize) + )); + + let query_results = self.benchmark_query(query_id, mem_limit).await?; + for iter in query_results { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + + benchmark_run.maybe_write_json(self.output_path.as_ref())?; + + Ok(()) + } + + /// Benchmark query `query_id` in `AGGR_QUERIES` + async fn benchmark_query( + &self, + query_id: usize, + mem_limit: u64, + ) -> Result> { + let query_name = + format!("Q{query_id}({})", human_readable_size(mem_limit as usize)); + let config = self.common.config(); + let runtime_config = RuntimeConfig::new() + .with_memory_pool(Arc::new(FairSpillPool::new(mem_limit as usize))) + .build_arc()?; + let ctx = SessionContext::new_with_config_rt(config, runtime_config); + + // register tables + self.register_tables(&ctx).await?; + + let mut millis = vec![]; + // run benchmark + let mut query_results = vec![]; + for i in 0..self.iterations() { + let start = Instant::now(); + + let query_idx = query_id - 1; // 1-indexed -> 0-indexed + let sql = Self::AGGR_QUERIES[query_idx]; + + let result = self.execute_query(&ctx, sql).await?; + + let elapsed = start.elapsed(); //.as_secs_f64() * 1000.0; + let ms = elapsed.as_secs_f64() * 1000.0; + millis.push(ms); + + let row_count = result.iter().map(|b| b.num_rows()).sum(); + println!( + "{query_name} iteration {i} took {ms:.1} ms and returned {row_count} rows" + ); + query_results.push(QueryResult { elapsed, row_count }); + } + + let avg = millis.iter().sum::() / millis.len() as f64; + println!("{query_name} avg time: {avg:.2} ms"); + + Ok(query_results) + } + + async fn register_tables(&self, ctx: &SessionContext) -> Result<()> { + for table in Self::AGGR_TABLES { + let table_provider = { self.get_table(ctx, table).await? }; + + if self.mem_table { + println!("Loading table '{table}' into memory"); + let start = Instant::now(); + let memtable = + MemTable::load(table_provider, Some(self.partitions()), &ctx.state()) + .await?; + println!( + "Loaded table '{}' into memory in {} ms", + table, + start.elapsed().as_millis() + ); + ctx.register_table(table, Arc::new(memtable))?; + } else { + ctx.register_table(table, table_provider)?; + } + } + Ok(()) + } + + async fn execute_query( + &self, + ctx: &SessionContext, + sql: &str, + ) -> Result> { + let debug = self.common.debug; + let plan = ctx.sql(sql).await?; + let (state, plan) = plan.into_parts(); + + if debug { + println!("=== Logical plan ===\n{plan}\n"); + } + + let plan = state.optimize(&plan)?; + if debug { + println!("=== Optimized logical plan ===\n{plan}\n"); + } + let physical_plan = state.create_physical_plan(&plan).await?; + if debug { + println!( + "=== Physical plan ===\n{}\n", + displayable(physical_plan.as_ref()).indent(true) + ); + } + let result = collect(physical_plan.clone(), state.task_ctx()).await?; + if debug { + println!( + "=== Physical plan with metrics ===\n{}\n", + DisplayableExecutionPlan::with_metrics(physical_plan.as_ref()) + .indent(true) + ); + if !result.is_empty() { + // do not call print_batches if there are no batches as the result is confusing + // and makes it look like there is a batch with no columns + pretty::print_batches(&result)?; + } + } + Ok(result) + } + + async fn get_table( + &self, + ctx: &SessionContext, + table: &str, + ) -> Result> { + let path = self.path.to_str().unwrap(); + + // Obtain a snapshot of the SessionState + let state = ctx.state(); + let path = format!("{path}/{table}"); + let format = Arc::new( + ParquetFormat::default() + .with_options(ctx.state().table_options().parquet.clone()), + ); + let extension = DEFAULT_PARQUET_EXTENSION; + + let options = ListingOptions::new(format) + .with_file_extension(extension) + .with_collect_stat(state.config().collect_statistics()); + + let table_path = ListingTableUrl::parse(path)?; + let config = ListingTableConfig::new(table_path).with_listing_options(options); + let config = config.infer_schema(&state).await?; + + Ok(Arc::new(ListingTable::try_new(config)?)) + } + + fn iterations(&self) -> usize { + self.common.iterations + } + + fn partitions(&self) -> usize { + self.common.partitions.unwrap_or(num_cpus::get()) + } + + /// Parse memory limit from string to number of bytes + /// e.g. '1.5G', '100M' -> 1572864 + fn parse_memory_limit(limit: &str) -> Result { + let (number, unit) = limit.split_at(limit.len() - 1); + let number: f64 = number.parse().map_err(|_| { + exec_datafusion_err!("Failed to parse number from memory limit '{}'", limit) + })?; + + match unit { + "K" => Ok((number * 1024.0) as u64), + "M" => Ok((number * 1024.0 * 1024.0) as u64), + "G" => Ok((number * 1024.0 * 1024.0 * 1024.0) as u64), + _ => exec_err!("Unsupported unit '{}' in memory limit '{}'", unit, limit), + } + } +} + +#[tokio::main] +pub async fn main() -> Result<()> { + env_logger::init(); + + match ExternalAggrOpt::from_args() { + ExternalAggrOpt::Benchmark(opt) => opt.run().await?, + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_memory_limit_all() { + // Test valid inputs + assert_eq!( + ExternalAggrConfig::parse_memory_limit("100K").unwrap(), + 102400 + ); + assert_eq!( + ExternalAggrConfig::parse_memory_limit("1.5M").unwrap(), + 1572864 + ); + assert_eq!( + ExternalAggrConfig::parse_memory_limit("2G").unwrap(), + 2147483648 + ); + + // Test invalid unit + assert!(ExternalAggrConfig::parse_memory_limit("500X").is_err()); + + // Test invalid number + assert!(ExternalAggrConfig::parse_memory_limit("abcM").is_err()); + } +} diff --git a/benchmarks/src/bin/h2o.rs b/benchmarks/src/bin/h2o.rs index 1bb8cb9d43e4..1ddeb786a591 100644 --- a/benchmarks/src/bin/h2o.rs +++ b/benchmarks/src/bin/h2o.rs @@ -26,7 +26,7 @@ use datafusion::datasource::listing::{ use datafusion::datasource::MemTable; use datafusion::prelude::CsvReadOptions; use datafusion::{arrow::util::pretty, error::Result, prelude::SessionContext}; -use datafusion_benchmarks::BenchmarkRun; +use datafusion_benchmarks::util::BenchmarkRun; use std::path::PathBuf; use std::sync::Arc; use structopt::StructOpt; diff --git a/benchmarks/src/bin/imdb.rs b/benchmarks/src/bin/imdb.rs new file mode 100644 index 000000000000..13421f8a89a9 --- /dev/null +++ b/benchmarks/src/bin/imdb.rs @@ -0,0 +1,60 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! IMDB binary entrypoint + +use datafusion::error::Result; +use datafusion_benchmarks::imdb; +use structopt::StructOpt; + +#[cfg(all(feature = "snmalloc", feature = "mimalloc"))] +compile_error!( + "feature \"snmalloc\" and feature \"mimalloc\" cannot be enabled at the same time" +); + +#[cfg(feature = "snmalloc")] +#[global_allocator] +static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; + +#[cfg(feature = "mimalloc")] +#[global_allocator] +static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; + +#[derive(Debug, StructOpt)] +#[structopt(about = "benchmark command")] +enum BenchmarkSubCommandOpt { + #[structopt(name = "datafusion")] + DataFusionBenchmark(imdb::RunOpt), +} + +#[derive(Debug, StructOpt)] +#[structopt(name = "IMDB", about = "IMDB Dataset Processing.")] +enum ImdbOpt { + Benchmark(BenchmarkSubCommandOpt), + Convert(imdb::ConvertOpt), +} + +#[tokio::main] +pub async fn main() -> Result<()> { + env_logger::init(); + match ImdbOpt::from_args() { + ImdbOpt::Benchmark(BenchmarkSubCommandOpt::DataFusionBenchmark(opt)) => { + opt.run().await + } + ImdbOpt::Convert(opt) => opt.run().await, + } +} diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs index fc0f4ca0613c..3270b082cfb4 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -21,6 +21,11 @@ use datafusion::error::Result; use datafusion_benchmarks::tpch; use structopt::StructOpt; +#[cfg(all(feature = "snmalloc", feature = "mimalloc"))] +compile_error!( + "feature \"snmalloc\" and feature \"mimalloc\" cannot be enabled at the same time" +); + #[cfg(feature = "snmalloc")] #[global_allocator] static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; diff --git a/benchmarks/src/clickbench.rs b/benchmarks/src/clickbench.rs index 41dffc55f371..46dd4b18825b 100644 --- a/benchmarks/src/clickbench.rs +++ b/benchmarks/src/clickbench.rs @@ -18,6 +18,7 @@ use std::path::Path; use std::path::PathBuf; +use crate::util::{BenchmarkRun, CommonOpt}; use datafusion::{ error::{DataFusionError, Result}, prelude::SessionContext, @@ -26,8 +27,6 @@ use datafusion_common::exec_datafusion_err; use datafusion_common::instant::Instant; use structopt::StructOpt; -use crate::{BenchmarkRun, CommonOpt}; - /// Run the clickbench benchmark /// /// The ClickBench[1] benchmarks are widely cited in the industry and @@ -116,7 +115,15 @@ impl RunOpt { None => queries.min_query_id()..=queries.max_query_id(), }; - let config = self.common.config(); + // configure parquet options + let mut config = self.common.config(); + { + let parquet_options = &mut config.options_mut().execution.parquet; + // The hits_partitioned dataset specifies string columns + // as binary due to how it was written. Force it to strings + parquet_options.binary_as_string = true; + } + let ctx = SessionContext::new_with_config(config); self.register_hits(&ctx).await?; @@ -143,7 +150,7 @@ impl RunOpt { Ok(()) } - /// Registrs the `hits.parquet` as a table named `hits` + /// Registers the `hits.parquet` as a table named `hits` async fn register_hits(&self, ctx: &SessionContext) -> Result<()> { let options = Default::default(); let path = self.path.as_os_str().to_str().unwrap(); diff --git a/benchmarks/src/imdb/convert.rs b/benchmarks/src/imdb/convert.rs new file mode 100644 index 000000000000..4e470d711da5 --- /dev/null +++ b/benchmarks/src/imdb/convert.rs @@ -0,0 +1,110 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::dataframe::DataFrameWriteOptions; +use datafusion_common::instant::Instant; +use std::path::PathBuf; + +use datafusion::error::Result; +use datafusion::prelude::*; +use structopt::StructOpt; + +use datafusion::common::not_impl_err; + +use super::get_imdb_table_schema; +use super::IMDB_TABLES; + +#[derive(Debug, StructOpt)] +pub struct ConvertOpt { + /// Path to csv files + #[structopt(parse(from_os_str), required = true, short = "i", long = "input")] + input_path: PathBuf, + + /// Output path + #[structopt(parse(from_os_str), required = true, short = "o", long = "output")] + output_path: PathBuf, + + /// Output file format: `csv` or `parquet` + #[structopt(short = "f", long = "format")] + file_format: String, + + /// Batch size when reading CSV or Parquet files + #[structopt(short = "s", long = "batch-size", default_value = "8192")] + batch_size: usize, +} + +impl ConvertOpt { + pub async fn run(self) -> Result<()> { + let input_path = self.input_path.to_str().unwrap(); + let output_path = self.output_path.to_str().unwrap(); + let config = SessionConfig::new().with_batch_size(self.batch_size); + let ctx = SessionContext::new_with_config(config); + + for table in IMDB_TABLES { + let start = Instant::now(); + let schema = get_imdb_table_schema(table); + let input_path = format!("{input_path}/{table}.csv"); + let output_path = format!("{output_path}/{table}.parquet"); + let options = CsvReadOptions::new() + .schema(&schema) + .has_header(false) + .delimiter(b',') + .escape(b'\\') + .file_extension(".csv"); + + let mut csv = ctx.read_csv(&input_path, options).await?; + + // Select all apart from the padding column + let selection = csv + .schema() + .iter() + .take(schema.fields.len()) + .map(Expr::from) + .collect(); + + csv = csv.select(selection)?; + + println!( + "Converting '{}' to {} files in directory '{}'", + &input_path, self.file_format, &output_path + ); + match self.file_format.as_str() { + "csv" => { + csv.write_csv( + output_path.as_str(), + DataFrameWriteOptions::new(), + None, + ) + .await?; + } + "parquet" => { + csv.write_parquet( + output_path.as_str(), + DataFrameWriteOptions::new(), + None, + ) + .await?; + } + other => { + return not_impl_err!("Invalid output format: {other}"); + } + } + println!("Conversion completed in {} ms", start.elapsed().as_millis()); + } + Ok(()) + } +} diff --git a/benchmarks/src/imdb/mod.rs b/benchmarks/src/imdb/mod.rs new file mode 100644 index 000000000000..6a45242e6ff4 --- /dev/null +++ b/benchmarks/src/imdb/mod.rs @@ -0,0 +1,236 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Benchmark derived from IMDB dataset. + +use datafusion::{ + arrow::datatypes::{DataType, Field, Schema}, + common::plan_err, + error::Result, +}; +mod convert; +pub use convert::ConvertOpt; + +use std::fs; +mod run; +pub use run::RunOpt; + +// we have 21 tables in the IMDB dataset +pub const IMDB_TABLES: &[&str] = &[ + "aka_name", + "aka_title", + "cast_info", + "char_name", + "comp_cast_type", + "company_name", + "company_type", + "complete_cast", + "info_type", + "keyword", + "kind_type", + "link_type", + "movie_companies", + "movie_info_idx", + "movie_keyword", + "movie_link", + "name", + "role_type", + "title", + "movie_info", + "person_info", +]; + +/// Get the schema for the IMDB dataset tables +/// see benchmarks/data/imdb/schematext.sql +pub fn get_imdb_table_schema(table: &str) -> Schema { + match table { + "aka_name" => Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("person_id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + Field::new("imdb_index", DataType::Utf8, true), + Field::new("name_pcode_cf", DataType::Utf8, true), + Field::new("name_pcode_nf", DataType::Utf8, true), + Field::new("surname_pcode", DataType::Utf8, true), + Field::new("md5sum", DataType::Utf8, true), + ]), + "aka_title" => Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("movie_id", DataType::Int32, false), + Field::new("title", DataType::Utf8, true), + Field::new("imdb_index", DataType::Utf8, true), + Field::new("kind_id", DataType::Int32, false), + Field::new("production_year", DataType::Int32, true), + Field::new("phonetic_code", DataType::Utf8, true), + Field::new("episode_of_id", DataType::Int32, true), + Field::new("season_nr", DataType::Int32, true), + Field::new("episode_nr", DataType::Int32, true), + Field::new("note", DataType::Utf8, true), + Field::new("md5sum", DataType::Utf8, true), + ]), + "cast_info" => Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("person_id", DataType::Int32, false), + Field::new("movie_id", DataType::Int32, false), + Field::new("person_role_id", DataType::Int32, true), + Field::new("note", DataType::Utf8, true), + Field::new("nr_order", DataType::Int32, true), + Field::new("role_id", DataType::Int32, false), + ]), + "char_name" => Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("name", DataType::Utf8, false), + Field::new("imdb_index", DataType::Utf8, true), + Field::new("imdb_id", DataType::Int32, true), + Field::new("name_pcode_nf", DataType::Utf8, true), + Field::new("surname_pcode", DataType::Utf8, true), + Field::new("md5sum", DataType::Utf8, true), + ]), + "comp_cast_type" => Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("kind", DataType::Utf8, false), + ]), + "company_name" => Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("name", DataType::Utf8, false), + Field::new("country_code", DataType::Utf8, true), + Field::new("imdb_id", DataType::Int32, true), + Field::new("name_pcode_nf", DataType::Utf8, true), + Field::new("name_pcode_sf", DataType::Utf8, true), + Field::new("md5sum", DataType::Utf8, true), + ]), + "company_type" => Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("kind", DataType::Utf8, true), + ]), + "complete_cast" => Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("movie_id", DataType::Int32, true), + Field::new("subject_id", DataType::Int32, false), + Field::new("status_id", DataType::Int32, false), + ]), + "info_type" => Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("info", DataType::Utf8, false), + ]), + "keyword" => Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("keyword", DataType::Utf8, false), + Field::new("phonetic_code", DataType::Utf8, true), + ]), + "kind_type" => Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("kind", DataType::Utf8, true), + ]), + "link_type" => Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("link", DataType::Utf8, false), + ]), + "movie_companies" => Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("movie_id", DataType::Int32, false), + Field::new("company_id", DataType::Int32, false), + Field::new("company_type_id", DataType::Int32, false), + Field::new("note", DataType::Utf8, true), + ]), + "movie_info_idx" => Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("movie_id", DataType::Int32, false), + Field::new("info_type_id", DataType::Int32, false), + Field::new("info", DataType::Utf8, false), + Field::new("note", DataType::Utf8, true), + ]), + "movie_keyword" => Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("movie_id", DataType::Int32, false), + Field::new("keyword_id", DataType::Int32, false), + ]), + "movie_link" => Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("movie_id", DataType::Int32, false), + Field::new("linked_movie_id", DataType::Int32, false), + Field::new("link_type_id", DataType::Int32, false), + ]), + "name" => Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("name", DataType::Utf8, false), + Field::new("imdb_index", DataType::Utf8, true), + Field::new("imdb_id", DataType::Int32, true), + Field::new("gender", DataType::Utf8, true), + Field::new("name_pcode_cf", DataType::Utf8, true), + Field::new("name_pcode_nf", DataType::Utf8, true), + Field::new("surname_pcode", DataType::Utf8, true), + Field::new("md5sum", DataType::Utf8, true), + ]), + "role_type" => Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("role", DataType::Utf8, false), + ]), + "title" => Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("title", DataType::Utf8, false), + Field::new("imdb_index", DataType::Utf8, true), + Field::new("kind_id", DataType::Int32, false), + Field::new("production_year", DataType::Int32, true), + Field::new("imdb_id", DataType::Int32, true), + Field::new("phonetic_code", DataType::Utf8, true), + Field::new("episode_of_id", DataType::Int32, true), + Field::new("season_nr", DataType::Int32, true), + Field::new("episode_nr", DataType::Int32, true), + Field::new("series_years", DataType::Utf8, true), + Field::new("md5sum", DataType::Utf8, true), + ]), + "movie_info" => Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("movie_id", DataType::Int32, false), + Field::new("info_type_id", DataType::Int32, false), + Field::new("info", DataType::Utf8, false), + Field::new("note", DataType::Utf8, true), + ]), + "person_info" => Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("person_id", DataType::Int32, false), + Field::new("info_type_id", DataType::Int32, false), + Field::new("info", DataType::Utf8, false), + Field::new("note", DataType::Utf8, true), + ]), + _ => unimplemented!("Schema for table {} is not implemented", table), + } +} + +/// Get the SQL statements from the specified query file +pub fn get_query_sql(query: &str) -> Result> { + let possibilities = vec![ + format!("queries/imdb/{query}.sql"), + format!("benchmarks/queries/imdb/{query}.sql"), + ]; + let mut errors = vec![]; + for filename in possibilities { + match fs::read_to_string(&filename) { + Ok(contents) => { + return Ok(contents + .split(';') + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .map(|s| s.to_string()) + .collect()); + } + Err(e) => errors.push(format!("{filename}: {e}")), + }; + } + plan_err!("invalid query. Could not find query: {:?}", errors) +} diff --git a/benchmarks/src/imdb/run.rs b/benchmarks/src/imdb/run.rs new file mode 100644 index 000000000000..47c356990881 --- /dev/null +++ b/benchmarks/src/imdb/run.rs @@ -0,0 +1,822 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::path::PathBuf; +use std::sync::Arc; + +use super::{get_imdb_table_schema, get_query_sql, IMDB_TABLES}; +use crate::util::{BenchmarkRun, CommonOpt}; + +use arrow::record_batch::RecordBatch; +use arrow::util::pretty::{self, pretty_format_batches}; +use datafusion::datasource::file_format::csv::CsvFormat; +use datafusion::datasource::file_format::parquet::ParquetFormat; +use datafusion::datasource::file_format::FileFormat; +use datafusion::datasource::listing::{ + ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, +}; +use datafusion::datasource::{MemTable, TableProvider}; +use datafusion::error::Result; +use datafusion::physical_plan::display::DisplayableExecutionPlan; +use datafusion::physical_plan::{collect, displayable}; +use datafusion::prelude::*; +use datafusion_common::instant::Instant; +use datafusion_common::{DEFAULT_CSV_EXTENSION, DEFAULT_PARQUET_EXTENSION}; + +use log::info; +use structopt::StructOpt; + +// hack to avoid `default_value is meaningless for bool` errors +type BoolDefaultTrue = bool; + +/// Run the imdb benchmark (a.k.a. JOB). +/// +/// This benchmarks is derived from the [Join Order Benchmark / JOB] proposed in paper [How Good Are Query Optimizers, Really?][1]. +/// The data and answers are downloaded from +/// [2] and [3]. +/// +/// [1]: https://www.vldb.org/pvldb/vol9/p204-leis.pdf +/// [2]: http://homepages.cwi.nl/~boncz/job/imdb.tgz +/// [3]: https://db.in.tum.de/~leis/qo/job.tgz + +#[derive(Debug, StructOpt, Clone)] +#[structopt(verbatim_doc_comment)] +pub struct RunOpt { + /// Query number. If not specified, runs all queries + #[structopt(short, long)] + query: Option, + + /// Common options + #[structopt(flatten)] + common: CommonOpt, + + /// Path to data files + #[structopt(parse(from_os_str), required = true, short = "p", long = "path")] + path: PathBuf, + + /// File format: `csv` or `parquet` + #[structopt(short = "f", long = "format", default_value = "csv")] + file_format: String, + + /// Load the data into a MemTable before executing the query + #[structopt(short = "m", long = "mem-table")] + mem_table: bool, + + /// Path to machine readable output file + #[structopt(parse(from_os_str), short = "o", long = "output")] + output_path: Option, + + /// Whether to disable collection of statistics (and cost based optimizations) or not. + #[structopt(short = "S", long = "disable-statistics")] + disable_statistics: bool, + + /// If true then hash join used, if false then sort merge join + /// True by default. + #[structopt(short = "j", long = "prefer_hash_join", default_value = "true")] + prefer_hash_join: BoolDefaultTrue, +} + +const IMDB_QUERY_START_ID: usize = 1; +const IMDB_QUERY_END_ID: usize = 113; + +fn map_query_id_to_str(query_id: usize) -> &'static str { + match query_id { + // 1 + 1 => "1a", + 2 => "1b", + 3 => "1c", + 4 => "1d", + + // 2 + 5 => "2a", + 6 => "2b", + 7 => "2c", + 8 => "2d", + + // 3 + 9 => "3a", + 10 => "3b", + 11 => "3c", + + // 4 + 12 => "4a", + 13 => "4b", + 14 => "4c", + + // 5 + 15 => "5a", + 16 => "5b", + 17 => "5c", + + // 6 + 18 => "6a", + 19 => "6b", + 20 => "6c", + 21 => "6d", + 22 => "6e", + 23 => "6f", + + // 7 + 24 => "7a", + 25 => "7b", + 26 => "7c", + + // 8 + 27 => "8a", + 28 => "8b", + 29 => "8c", + 30 => "8d", + + // 9 + 31 => "9a", + 32 => "9b", + 33 => "9c", + 34 => "9d", + + // 10 + 35 => "10a", + 36 => "10b", + 37 => "10c", + + // 11 + 38 => "11a", + 39 => "11b", + 40 => "11c", + 41 => "11d", + + // 12 + 42 => "12a", + 43 => "12b", + 44 => "12c", + + // 13 + 45 => "13a", + 46 => "13b", + 47 => "13c", + 48 => "13d", + + // 14 + 49 => "14a", + 50 => "14b", + 51 => "14c", + + // 15 + 52 => "15a", + 53 => "15b", + 54 => "15c", + 55 => "15d", + + // 16 + 56 => "16a", + 57 => "16b", + 58 => "16c", + 59 => "16d", + + // 17 + 60 => "17a", + 61 => "17b", + 62 => "17c", + 63 => "17d", + 64 => "17e", + 65 => "17f", + + // 18 + 66 => "18a", + 67 => "18b", + 68 => "18c", + + // 19 + 69 => "19a", + 70 => "19b", + 71 => "19c", + 72 => "19d", + + // 20 + 73 => "20a", + 74 => "20b", + 75 => "20c", + + // 21 + 76 => "21a", + 77 => "21b", + 78 => "21c", + + // 22 + 79 => "22a", + 80 => "22b", + 81 => "22c", + 82 => "22d", + + // 23 + 83 => "23a", + 84 => "23b", + 85 => "23c", + + // 24 + 86 => "24a", + 87 => "24b", + + // 25 + 88 => "25a", + 89 => "25b", + 90 => "25c", + + // 26 + 91 => "26a", + 92 => "26b", + 93 => "26c", + + // 27 + 94 => "27a", + 95 => "27b", + 96 => "27c", + + // 28 + 97 => "28a", + 98 => "28b", + 99 => "28c", + + // 29 + 100 => "29a", + 101 => "29b", + 102 => "29c", + + // 30 + 103 => "30a", + 104 => "30b", + 105 => "30c", + + // 31 + 106 => "31a", + 107 => "31b", + 108 => "31c", + + // 32 + 109 => "32a", + 110 => "32b", + + // 33 + 111 => "33a", + 112 => "33b", + 113 => "33c", + + // Fallback for unknown query_id + _ => "unknown", + } +} + +impl RunOpt { + pub async fn run(self) -> Result<()> { + println!("Running benchmarks with the following options: {self:?}"); + let query_range = match self.query { + Some(query_id) => query_id..=query_id, + None => IMDB_QUERY_START_ID..=IMDB_QUERY_END_ID, + }; + + let mut benchmark_run = BenchmarkRun::new(); + for query_id in query_range { + benchmark_run.start_new_case(&format!("Query {query_id}")); + let query_run = self.benchmark_query(query_id).await?; + for iter in query_run { + benchmark_run.write_iter(iter.elapsed, iter.row_count); + } + } + benchmark_run.maybe_write_json(self.output_path.as_ref())?; + Ok(()) + } + + async fn benchmark_query(&self, query_id: usize) -> Result> { + let mut config = self + .common + .config() + .with_collect_statistics(!self.disable_statistics); + config.options_mut().optimizer.prefer_hash_join = self.prefer_hash_join; + + let ctx = SessionContext::new_with_config(config); + + // register tables + self.register_tables(&ctx).await?; + + let mut millis = vec![]; + // run benchmark + let mut query_results = vec![]; + for i in 0..self.iterations() { + let start = Instant::now(); + + let query_id_str = map_query_id_to_str(query_id); + let sql = &get_query_sql(query_id_str)?; + + let mut result = vec![]; + + for query in sql { + result = self.execute_query(&ctx, query).await?; + } + + let elapsed = start.elapsed(); //.as_secs_f64() * 1000.0; + let ms = elapsed.as_secs_f64() * 1000.0; + millis.push(ms); + info!("output:\n\n{}\n\n", pretty_format_batches(&result)?); + let row_count = result.iter().map(|b| b.num_rows()).sum(); + println!( + "Query {query_id} iteration {i} took {ms:.1} ms and returned {row_count} rows" + ); + query_results.push(QueryResult { elapsed, row_count }); + } + + let avg = millis.iter().sum::() / millis.len() as f64; + println!("Query {query_id} avg time: {avg:.2} ms"); + + Ok(query_results) + } + + async fn register_tables(&self, ctx: &SessionContext) -> Result<()> { + for table in IMDB_TABLES { + let table_provider = { self.get_table(ctx, table).await? }; + + if self.mem_table { + println!("Loading table '{table}' into memory"); + let start = Instant::now(); + let memtable = + MemTable::load(table_provider, Some(self.partitions()), &ctx.state()) + .await?; + println!( + "Loaded table '{}' into memory in {} ms", + table, + start.elapsed().as_millis() + ); + ctx.register_table(*table, Arc::new(memtable))?; + } else { + ctx.register_table(*table, table_provider)?; + } + } + Ok(()) + } + + async fn execute_query( + &self, + ctx: &SessionContext, + sql: &str, + ) -> Result> { + let debug = self.common.debug; + let plan = ctx.sql(sql).await?; + let (state, plan) = plan.into_parts(); + + if debug { + println!("=== Logical plan ===\n{plan}\n"); + } + + let plan = state.optimize(&plan)?; + if debug { + println!("=== Optimized logical plan ===\n{plan}\n"); + } + let physical_plan = state.create_physical_plan(&plan).await?; + if debug { + println!( + "=== Physical plan ===\n{}\n", + displayable(physical_plan.as_ref()).indent(true) + ); + } + let result = collect(physical_plan.clone(), state.task_ctx()).await?; + if debug { + println!( + "=== Physical plan with metrics ===\n{}\n", + DisplayableExecutionPlan::with_metrics(physical_plan.as_ref()) + .indent(true) + ); + if !result.is_empty() { + // do not call print_batches if there are no batches as the result is confusing + // and makes it look like there is a batch with no columns + pretty::print_batches(&result)?; + } + } + Ok(result) + } + + async fn get_table( + &self, + ctx: &SessionContext, + table: &str, + ) -> Result> { + let path = self.path.to_str().unwrap(); + let table_format = self.file_format.as_str(); + + // Obtain a snapshot of the SessionState + let state = ctx.state(); + let (format, path, extension): (Arc, String, &'static str) = + match table_format { + // dbgen creates .tbl ('|' delimited) files without header + "tbl" => { + let path = format!("{path}/{table}.tbl"); + + let format = CsvFormat::default() + .with_delimiter(b'|') + .with_has_header(false); + + (Arc::new(format), path, ".tbl") + } + "csv" => { + let path = format!("{path}/{table}.csv"); + let format = CsvFormat::default() + .with_delimiter(b',') + .with_escape(Some(b'\\')) + .with_has_header(false); + + (Arc::new(format), path, DEFAULT_CSV_EXTENSION) + } + "parquet" => { + let path = format!("{path}/{table}.parquet"); + let format = ParquetFormat::default() + .with_options(ctx.state().table_options().parquet.clone()); + (Arc::new(format), path, DEFAULT_PARQUET_EXTENSION) + } + other => { + unimplemented!("Invalid file format '{}'", other); + } + }; + + let options = ListingOptions::new(format) + .with_file_extension(extension) + .with_collect_stat(state.config().collect_statistics()); + + let table_path = ListingTableUrl::parse(path)?; + let config = ListingTableConfig::new(table_path).with_listing_options(options); + let config = match table_format { + "parquet" => config.with_schema(Arc::new(get_imdb_table_schema(table))), + "csv" => config.with_schema(Arc::new(get_imdb_table_schema(table))), + _ => unreachable!(), + }; + + Ok(Arc::new(ListingTable::try_new(config)?)) + } + + fn iterations(&self) -> usize { + self.common.iterations + } + + fn partitions(&self) -> usize { + self.common.partitions.unwrap_or(num_cpus::get()) + } +} + +struct QueryResult { + elapsed: std::time::Duration, + row_count: usize, +} + +#[cfg(test)] +// Only run with "ci" mode when we have the data +#[cfg(feature = "ci")] +mod tests { + use std::path::Path; + + use super::*; + + use crate::util::CommonOpt; + use datafusion::common::exec_err; + use datafusion::error::Result; + use datafusion_proto::bytes::{ + logical_plan_from_bytes, logical_plan_to_bytes, physical_plan_from_bytes, + physical_plan_to_bytes, + }; + + fn get_imdb_data_path() -> Result { + let path = + std::env::var("IMDB_DATA").unwrap_or_else(|_| "benchmarks/data".to_string()); + if !Path::new(&path).exists() { + return exec_err!( + "Benchmark data not found (set IMDB_DATA env var to override): {}", + path + ); + } + Ok(path) + } + + async fn round_trip_logical_plan(query: usize) -> Result<()> { + let ctx = SessionContext::default(); + let path = get_imdb_data_path()?; + let common = CommonOpt { + iterations: 1, + partitions: Some(2), + batch_size: 8192, + debug: false, + }; + let opt = RunOpt { + query: Some(query), + common, + path: PathBuf::from(path.to_string()), + file_format: "parquet".to_string(), + mem_table: false, + output_path: None, + disable_statistics: false, + prefer_hash_join: true, + }; + opt.register_tables(&ctx).await?; + let queries = get_query_sql(map_query_id_to_str(query))?; + for query in queries { + let plan = ctx.sql(&query).await?; + let plan = plan.into_optimized_plan()?; + let bytes = logical_plan_to_bytes(&plan)?; + let plan2 = logical_plan_from_bytes(&bytes, &ctx)?; + let plan_formatted = format!("{}", plan.display_indent()); + let plan2_formatted = format!("{}", plan2.display_indent()); + assert_eq!(plan_formatted, plan2_formatted); + } + Ok(()) + } + + async fn round_trip_physical_plan(query: usize) -> Result<()> { + let ctx = SessionContext::default(); + let path = get_imdb_data_path()?; + let common = CommonOpt { + iterations: 1, + partitions: Some(2), + batch_size: 8192, + debug: false, + }; + let opt = RunOpt { + query: Some(query), + common, + path: PathBuf::from(path.to_string()), + file_format: "parquet".to_string(), + mem_table: false, + output_path: None, + disable_statistics: false, + prefer_hash_join: true, + }; + opt.register_tables(&ctx).await?; + let queries = get_query_sql(map_query_id_to_str(query))?; + for query in queries { + let plan = ctx.sql(&query).await?; + let plan = plan.create_physical_plan().await?; + let bytes = physical_plan_to_bytes(plan.clone())?; + let plan2 = physical_plan_from_bytes(&bytes, &ctx)?; + let plan_formatted = format!("{}", displayable(plan.as_ref()).indent(false)); + let plan2_formatted = + format!("{}", displayable(plan2.as_ref()).indent(false)); + assert_eq!(plan_formatted, plan2_formatted); + } + Ok(()) + } + + macro_rules! test_round_trip_logical { + ($tn:ident, $query:expr) => { + #[tokio::test] + async fn $tn() -> Result<()> { + round_trip_logical_plan($query).await + } + }; + } + + macro_rules! test_round_trip_physical { + ($tn:ident, $query:expr) => { + #[tokio::test] + async fn $tn() -> Result<()> { + round_trip_physical_plan($query).await + } + }; + } + + // logical plan tests + test_round_trip_logical!(round_trip_logical_plan_1a, 1); + test_round_trip_logical!(round_trip_logical_plan_1b, 2); + test_round_trip_logical!(round_trip_logical_plan_1c, 3); + test_round_trip_logical!(round_trip_logical_plan_1d, 4); + test_round_trip_logical!(round_trip_logical_plan_2a, 5); + test_round_trip_logical!(round_trip_logical_plan_2b, 6); + test_round_trip_logical!(round_trip_logical_plan_2c, 7); + test_round_trip_logical!(round_trip_logical_plan_2d, 8); + test_round_trip_logical!(round_trip_logical_plan_3a, 9); + test_round_trip_logical!(round_trip_logical_plan_3b, 10); + test_round_trip_logical!(round_trip_logical_plan_3c, 11); + test_round_trip_logical!(round_trip_logical_plan_4a, 12); + test_round_trip_logical!(round_trip_logical_plan_4b, 13); + test_round_trip_logical!(round_trip_logical_plan_4c, 14); + test_round_trip_logical!(round_trip_logical_plan_5a, 15); + test_round_trip_logical!(round_trip_logical_plan_5b, 16); + test_round_trip_logical!(round_trip_logical_plan_5c, 17); + test_round_trip_logical!(round_trip_logical_plan_6a, 18); + test_round_trip_logical!(round_trip_logical_plan_6b, 19); + test_round_trip_logical!(round_trip_logical_plan_6c, 20); + test_round_trip_logical!(round_trip_logical_plan_6d, 21); + test_round_trip_logical!(round_trip_logical_plan_6e, 22); + test_round_trip_logical!(round_trip_logical_plan_6f, 23); + test_round_trip_logical!(round_trip_logical_plan_7a, 24); + test_round_trip_logical!(round_trip_logical_plan_7b, 25); + test_round_trip_logical!(round_trip_logical_plan_7c, 26); + test_round_trip_logical!(round_trip_logical_plan_8a, 27); + test_round_trip_logical!(round_trip_logical_plan_8b, 28); + test_round_trip_logical!(round_trip_logical_plan_8c, 29); + test_round_trip_logical!(round_trip_logical_plan_8d, 30); + test_round_trip_logical!(round_trip_logical_plan_9a, 31); + test_round_trip_logical!(round_trip_logical_plan_9b, 32); + test_round_trip_logical!(round_trip_logical_plan_9c, 33); + test_round_trip_logical!(round_trip_logical_plan_9d, 34); + test_round_trip_logical!(round_trip_logical_plan_10a, 35); + test_round_trip_logical!(round_trip_logical_plan_10b, 36); + test_round_trip_logical!(round_trip_logical_plan_10c, 37); + test_round_trip_logical!(round_trip_logical_plan_11a, 38); + test_round_trip_logical!(round_trip_logical_plan_11b, 39); + test_round_trip_logical!(round_trip_logical_plan_11c, 40); + test_round_trip_logical!(round_trip_logical_plan_11d, 41); + test_round_trip_logical!(round_trip_logical_plan_12a, 42); + test_round_trip_logical!(round_trip_logical_plan_12b, 43); + test_round_trip_logical!(round_trip_logical_plan_12c, 44); + test_round_trip_logical!(round_trip_logical_plan_13a, 45); + test_round_trip_logical!(round_trip_logical_plan_13b, 46); + test_round_trip_logical!(round_trip_logical_plan_13c, 47); + test_round_trip_logical!(round_trip_logical_plan_13d, 48); + test_round_trip_logical!(round_trip_logical_plan_14a, 49); + test_round_trip_logical!(round_trip_logical_plan_14b, 50); + test_round_trip_logical!(round_trip_logical_plan_14c, 51); + test_round_trip_logical!(round_trip_logical_plan_15a, 52); + test_round_trip_logical!(round_trip_logical_plan_15b, 53); + test_round_trip_logical!(round_trip_logical_plan_15c, 54); + test_round_trip_logical!(round_trip_logical_plan_15d, 55); + test_round_trip_logical!(round_trip_logical_plan_16a, 56); + test_round_trip_logical!(round_trip_logical_plan_16b, 57); + test_round_trip_logical!(round_trip_logical_plan_16c, 58); + test_round_trip_logical!(round_trip_logical_plan_16d, 59); + test_round_trip_logical!(round_trip_logical_plan_17a, 60); + test_round_trip_logical!(round_trip_logical_plan_17b, 61); + test_round_trip_logical!(round_trip_logical_plan_17c, 62); + test_round_trip_logical!(round_trip_logical_plan_17d, 63); + test_round_trip_logical!(round_trip_logical_plan_17e, 64); + test_round_trip_logical!(round_trip_logical_plan_17f, 65); + test_round_trip_logical!(round_trip_logical_plan_18a, 66); + test_round_trip_logical!(round_trip_logical_plan_18b, 67); + test_round_trip_logical!(round_trip_logical_plan_18c, 68); + test_round_trip_logical!(round_trip_logical_plan_19a, 69); + test_round_trip_logical!(round_trip_logical_plan_19b, 70); + test_round_trip_logical!(round_trip_logical_plan_19c, 71); + test_round_trip_logical!(round_trip_logical_plan_19d, 72); + test_round_trip_logical!(round_trip_logical_plan_20a, 73); + test_round_trip_logical!(round_trip_logical_plan_20b, 74); + test_round_trip_logical!(round_trip_logical_plan_20c, 75); + test_round_trip_logical!(round_trip_logical_plan_21a, 76); + test_round_trip_logical!(round_trip_logical_plan_21b, 77); + test_round_trip_logical!(round_trip_logical_plan_21c, 78); + test_round_trip_logical!(round_trip_logical_plan_22a, 79); + test_round_trip_logical!(round_trip_logical_plan_22b, 80); + test_round_trip_logical!(round_trip_logical_plan_22c, 81); + test_round_trip_logical!(round_trip_logical_plan_22d, 82); + test_round_trip_logical!(round_trip_logical_plan_23a, 83); + test_round_trip_logical!(round_trip_logical_plan_23b, 84); + test_round_trip_logical!(round_trip_logical_plan_23c, 85); + test_round_trip_logical!(round_trip_logical_plan_24a, 86); + test_round_trip_logical!(round_trip_logical_plan_24b, 87); + test_round_trip_logical!(round_trip_logical_plan_25a, 88); + test_round_trip_logical!(round_trip_logical_plan_25b, 89); + test_round_trip_logical!(round_trip_logical_plan_25c, 90); + test_round_trip_logical!(round_trip_logical_plan_26a, 91); + test_round_trip_logical!(round_trip_logical_plan_26b, 92); + test_round_trip_logical!(round_trip_logical_plan_26c, 93); + test_round_trip_logical!(round_trip_logical_plan_27a, 94); + test_round_trip_logical!(round_trip_logical_plan_27b, 95); + test_round_trip_logical!(round_trip_logical_plan_27c, 96); + test_round_trip_logical!(round_trip_logical_plan_28a, 97); + test_round_trip_logical!(round_trip_logical_plan_28b, 98); + test_round_trip_logical!(round_trip_logical_plan_28c, 99); + test_round_trip_logical!(round_trip_logical_plan_29a, 100); + test_round_trip_logical!(round_trip_logical_plan_29b, 101); + test_round_trip_logical!(round_trip_logical_plan_29c, 102); + test_round_trip_logical!(round_trip_logical_plan_30a, 103); + test_round_trip_logical!(round_trip_logical_plan_30b, 104); + test_round_trip_logical!(round_trip_logical_plan_30c, 105); + test_round_trip_logical!(round_trip_logical_plan_31a, 106); + test_round_trip_logical!(round_trip_logical_plan_31b, 107); + test_round_trip_logical!(round_trip_logical_plan_31c, 108); + test_round_trip_logical!(round_trip_logical_plan_32a, 109); + test_round_trip_logical!(round_trip_logical_plan_32b, 110); + test_round_trip_logical!(round_trip_logical_plan_33a, 111); + test_round_trip_logical!(round_trip_logical_plan_33b, 112); + test_round_trip_logical!(round_trip_logical_plan_33c, 113); + + // physical plan tests + test_round_trip_physical!(round_trip_physical_plan_1a, 1); + test_round_trip_physical!(round_trip_physical_plan_1b, 2); + test_round_trip_physical!(round_trip_physical_plan_1c, 3); + test_round_trip_physical!(round_trip_physical_plan_1d, 4); + test_round_trip_physical!(round_trip_physical_plan_2a, 5); + test_round_trip_physical!(round_trip_physical_plan_2b, 6); + test_round_trip_physical!(round_trip_physical_plan_2c, 7); + test_round_trip_physical!(round_trip_physical_plan_2d, 8); + test_round_trip_physical!(round_trip_physical_plan_3a, 9); + test_round_trip_physical!(round_trip_physical_plan_3b, 10); + test_round_trip_physical!(round_trip_physical_plan_3c, 11); + test_round_trip_physical!(round_trip_physical_plan_4a, 12); + test_round_trip_physical!(round_trip_physical_plan_4b, 13); + test_round_trip_physical!(round_trip_physical_plan_4c, 14); + test_round_trip_physical!(round_trip_physical_plan_5a, 15); + test_round_trip_physical!(round_trip_physical_plan_5b, 16); + test_round_trip_physical!(round_trip_physical_plan_5c, 17); + test_round_trip_physical!(round_trip_physical_plan_6a, 18); + test_round_trip_physical!(round_trip_physical_plan_6b, 19); + test_round_trip_physical!(round_trip_physical_plan_6c, 20); + test_round_trip_physical!(round_trip_physical_plan_6d, 21); + test_round_trip_physical!(round_trip_physical_plan_6e, 22); + test_round_trip_physical!(round_trip_physical_plan_6f, 23); + test_round_trip_physical!(round_trip_physical_plan_7a, 24); + test_round_trip_physical!(round_trip_physical_plan_7b, 25); + test_round_trip_physical!(round_trip_physical_plan_7c, 26); + test_round_trip_physical!(round_trip_physical_plan_8a, 27); + test_round_trip_physical!(round_trip_physical_plan_8b, 28); + test_round_trip_physical!(round_trip_physical_plan_8c, 29); + test_round_trip_physical!(round_trip_physical_plan_8d, 30); + test_round_trip_physical!(round_trip_physical_plan_9a, 31); + test_round_trip_physical!(round_trip_physical_plan_9b, 32); + test_round_trip_physical!(round_trip_physical_plan_9c, 33); + test_round_trip_physical!(round_trip_physical_plan_9d, 34); + test_round_trip_physical!(round_trip_physical_plan_10a, 35); + test_round_trip_physical!(round_trip_physical_plan_10b, 36); + test_round_trip_physical!(round_trip_physical_plan_10c, 37); + test_round_trip_physical!(round_trip_physical_plan_11a, 38); + test_round_trip_physical!(round_trip_physical_plan_11b, 39); + test_round_trip_physical!(round_trip_physical_plan_11c, 40); + test_round_trip_physical!(round_trip_physical_plan_11d, 41); + test_round_trip_physical!(round_trip_physical_plan_12a, 42); + test_round_trip_physical!(round_trip_physical_plan_12b, 43); + test_round_trip_physical!(round_trip_physical_plan_12c, 44); + test_round_trip_physical!(round_trip_physical_plan_13a, 45); + test_round_trip_physical!(round_trip_physical_plan_13b, 46); + test_round_trip_physical!(round_trip_physical_plan_13c, 47); + test_round_trip_physical!(round_trip_physical_plan_13d, 48); + test_round_trip_physical!(round_trip_physical_plan_14a, 49); + test_round_trip_physical!(round_trip_physical_plan_14b, 50); + test_round_trip_physical!(round_trip_physical_plan_14c, 51); + test_round_trip_physical!(round_trip_physical_plan_15a, 52); + test_round_trip_physical!(round_trip_physical_plan_15b, 53); + test_round_trip_physical!(round_trip_physical_plan_15c, 54); + test_round_trip_physical!(round_trip_physical_plan_15d, 55); + test_round_trip_physical!(round_trip_physical_plan_16a, 56); + test_round_trip_physical!(round_trip_physical_plan_16b, 57); + test_round_trip_physical!(round_trip_physical_plan_16c, 58); + test_round_trip_physical!(round_trip_physical_plan_16d, 59); + test_round_trip_physical!(round_trip_physical_plan_17a, 60); + test_round_trip_physical!(round_trip_physical_plan_17b, 61); + test_round_trip_physical!(round_trip_physical_plan_17c, 62); + test_round_trip_physical!(round_trip_physical_plan_17d, 63); + test_round_trip_physical!(round_trip_physical_plan_17e, 64); + test_round_trip_physical!(round_trip_physical_plan_17f, 65); + test_round_trip_physical!(round_trip_physical_plan_18a, 66); + test_round_trip_physical!(round_trip_physical_plan_18b, 67); + test_round_trip_physical!(round_trip_physical_plan_18c, 68); + test_round_trip_physical!(round_trip_physical_plan_19a, 69); + test_round_trip_physical!(round_trip_physical_plan_19b, 70); + test_round_trip_physical!(round_trip_physical_plan_19c, 71); + test_round_trip_physical!(round_trip_physical_plan_19d, 72); + test_round_trip_physical!(round_trip_physical_plan_20a, 73); + test_round_trip_physical!(round_trip_physical_plan_20b, 74); + test_round_trip_physical!(round_trip_physical_plan_20c, 75); + test_round_trip_physical!(round_trip_physical_plan_21a, 76); + test_round_trip_physical!(round_trip_physical_plan_21b, 77); + test_round_trip_physical!(round_trip_physical_plan_21c, 78); + test_round_trip_physical!(round_trip_physical_plan_22a, 79); + test_round_trip_physical!(round_trip_physical_plan_22b, 80); + test_round_trip_physical!(round_trip_physical_plan_22c, 81); + test_round_trip_physical!(round_trip_physical_plan_22d, 82); + test_round_trip_physical!(round_trip_physical_plan_23a, 83); + test_round_trip_physical!(round_trip_physical_plan_23b, 84); + test_round_trip_physical!(round_trip_physical_plan_23c, 85); + test_round_trip_physical!(round_trip_physical_plan_24a, 86); + test_round_trip_physical!(round_trip_physical_plan_24b, 87); + test_round_trip_physical!(round_trip_physical_plan_25a, 88); + test_round_trip_physical!(round_trip_physical_plan_25b, 89); + test_round_trip_physical!(round_trip_physical_plan_25c, 90); + test_round_trip_physical!(round_trip_physical_plan_26a, 91); + test_round_trip_physical!(round_trip_physical_plan_26b, 92); + test_round_trip_physical!(round_trip_physical_plan_26c, 93); + test_round_trip_physical!(round_trip_physical_plan_27a, 94); + test_round_trip_physical!(round_trip_physical_plan_27b, 95); + test_round_trip_physical!(round_trip_physical_plan_27c, 96); + test_round_trip_physical!(round_trip_physical_plan_28a, 97); + test_round_trip_physical!(round_trip_physical_plan_28b, 98); + test_round_trip_physical!(round_trip_physical_plan_28c, 99); + test_round_trip_physical!(round_trip_physical_plan_29a, 100); + test_round_trip_physical!(round_trip_physical_plan_29b, 101); + test_round_trip_physical!(round_trip_physical_plan_29c, 102); + test_round_trip_physical!(round_trip_physical_plan_30a, 103); + test_round_trip_physical!(round_trip_physical_plan_30b, 104); + test_round_trip_physical!(round_trip_physical_plan_30c, 105); + test_round_trip_physical!(round_trip_physical_plan_31a, 106); + test_round_trip_physical!(round_trip_physical_plan_31b, 107); + test_round_trip_physical!(round_trip_physical_plan_31c, 108); + test_round_trip_physical!(round_trip_physical_plan_32a, 109); + test_round_trip_physical!(round_trip_physical_plan_32b, 110); + test_round_trip_physical!(round_trip_physical_plan_33a, 111); + test_round_trip_physical!(round_trip_physical_plan_33b, 112); + test_round_trip_physical!(round_trip_physical_plan_33c, 113); +} diff --git a/benchmarks/src/lib.rs b/benchmarks/src/lib.rs index f81220aa2c94..02410e0cfa01 100644 --- a/benchmarks/src/lib.rs +++ b/benchmarks/src/lib.rs @@ -17,8 +17,8 @@ //! DataFusion benchmark runner pub mod clickbench; +pub mod imdb; pub mod parquet_filter; pub mod sort; pub mod tpch; -mod util; -pub use util::*; +pub mod util; diff --git a/benchmarks/src/parquet_filter.rs b/benchmarks/src/parquet_filter.rs index 5c98a2f8be3d..34103af0ffd2 100644 --- a/benchmarks/src/parquet_filter.rs +++ b/benchmarks/src/parquet_filter.rs @@ -17,7 +17,7 @@ use std::path::PathBuf; -use crate::{AccessLogOpt, BenchmarkRun, CommonOpt}; +use crate::util::{AccessLogOpt, BenchmarkRun, CommonOpt}; use arrow::util::pretty; use datafusion::common::Result; diff --git a/benchmarks/src/sort.rs b/benchmarks/src/sort.rs index 19eec2949ef6..b2038c432f77 100644 --- a/benchmarks/src/sort.rs +++ b/benchmarks/src/sort.rs @@ -18,11 +18,11 @@ use std::path::PathBuf; use std::sync::Arc; -use crate::{AccessLogOpt, BenchmarkRun, CommonOpt}; +use crate::util::{AccessLogOpt, BenchmarkRun, CommonOpt}; use arrow::util::pretty; use datafusion::common::Result; -use datafusion::physical_expr::PhysicalSortExpr; +use datafusion::physical_expr::{LexOrdering, LexOrderingRef, PhysicalSortExpr}; use datafusion::physical_plan::collect; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::prelude::{SessionConfig, SessionContext}; @@ -170,13 +170,13 @@ impl RunOpt { async fn exec_sort( ctx: &SessionContext, - expr: &[PhysicalSortExpr], + expr: LexOrderingRef<'_>, test_file: &TestParquetFile, debug: bool, ) -> Result<(usize, std::time::Duration)> { let start = Instant::now(); let scan = test_file.create_scan(ctx, None).await?; - let exec = Arc::new(SortExec::new(expr.to_owned(), scan)); + let exec = Arc::new(SortExec::new(LexOrdering::new(expr.to_owned()), scan)); let task_ctx = ctx.task_ctx(); let result = collect(exec, task_ctx).await?; let elapsed = start.elapsed(); diff --git a/benchmarks/src/tpch/run.rs b/benchmarks/src/tpch/run.rs index f2a93d2ea549..9ff1f72d8606 100644 --- a/benchmarks/src/tpch/run.rs +++ b/benchmarks/src/tpch/run.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use super::{ get_query_sql, get_tbl_tpch_table_schema, get_tpch_table_schema, TPCH_TABLES, }; -use crate::{BenchmarkRun, CommonOpt}; +use crate::util::{BenchmarkRun, CommonOpt}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::{self, pretty_format_batches}; @@ -200,12 +200,12 @@ impl RunOpt { let (state, plan) = plan.into_parts(); if debug { - println!("=== Logical plan ===\n{plan:?}\n"); + println!("=== Logical plan ===\n{plan}\n"); } let plan = state.optimize(&plan)?; if debug { - println!("=== Optimized logical plan ===\n{plan:?}\n"); + println!("=== Optimized logical plan ===\n{plan}\n"); } let physical_plan = state.create_physical_plan(&plan).await?; if debug { @@ -263,7 +263,8 @@ impl RunOpt { } "parquet" => { let path = format!("{path}/{table}"); - let format = ParquetFormat::default().with_enable_pruning(true); + let format = ParquetFormat::default() + .with_options(ctx.state().table_options().parquet.clone()); (Arc::new(format), path, DEFAULT_PARQUET_EXTENSION) } diff --git a/ci/scripts/retry b/ci/scripts/retry new file mode 100755 index 000000000000..0569dea58c94 --- /dev/null +++ b/ci/scripts/retry @@ -0,0 +1,21 @@ +#!/usr/bin/env bash + +set -euo pipefail + +x() { + echo "+ $*" >&2 + "$@" +} + +max_retry_time_seconds=$(( 3 * 60 )) +retry_delay_seconds=10 + +END=$(( $(date +%s) + ${max_retry_time_seconds} )) + +while (( $(date +%s) < $END )); do + x "$@" && exit 0 + sleep "${retry_delay_seconds}" +done + +echo "$0: retrying [$*] timed out" >&2 +exit 1 diff --git a/ci/scripts/rust_docs.sh b/ci/scripts/rust_docs.sh index cf83b80b5132..5c93711b6fb6 100755 --- a/ci/scripts/rust_docs.sh +++ b/ci/scripts/rust_docs.sh @@ -18,7 +18,7 @@ # under the License. set -ex -export RUSTDOCFLAGS="-D warnings -A rustdoc::private-intra-doc-links" +export RUSTDOCFLAGS="-D warnings" cargo doc --document-private-items --no-deps --workspace cd datafusion-cli cargo doc --document-private-items --no-deps diff --git a/ci/scripts/rust_example.sh b/ci/scripts/rust_example.sh index 675dc4e527d0..1bb97c88106f 100755 --- a/ci/scripts/rust_example.sh +++ b/ci/scripts/rust_example.sh @@ -19,7 +19,6 @@ set -ex cd datafusion-examples/examples/ -cargo fmt --all -- --check cargo check --examples files=$(ls .) @@ -29,5 +28,6 @@ do # Skip tests that rely on external storage and flight if [ ! -d $filename ]; then cargo run --example $example_name + cargo clean -p datafusion-examples fi done diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index ba3e68e4011f..b37253d1a135 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -4,18 +4,18 @@ version = 3 [[package]] name = "addr2line" -version = "0.21.0" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" +checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" dependencies = [ "gimli", ] [[package]] -name = "adler" -version = "1.0.2" +name = "adler2" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" [[package]] name = "adler32" @@ -82,11 +82,54 @@ dependencies = [ "libc", ] +[[package]] +name = "anstream" +version = "0.6.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + [[package]] name = "anstyle" -version = "1.0.6" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" + +[[package]] +name = "anstyle-parse" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" +dependencies = [ + "windows-sys 0.59.0", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8901269c6307e8d93993578286ac0edf7f195079ffff5ebdeea6a59ffb7e36bc" +checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125" +dependencies = [ + "anstyle", + "windows-sys 0.59.0", +] [[package]] name = "apache-avro" @@ -118,21 +161,21 @@ dependencies = [ [[package]] name = "arrayref" -version = "0.3.7" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b4930d2cb77ce62f89ee5d5289b4ac049559b1c45539271f5ed4fdc7db34545" +checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" [[package]] name = "arrayvec" -version = "0.7.4" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "arrow" -version = "51.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "219d05930b81663fd3b32e3bde8ce5bff3c4d23052a99f11a8fa50a3b47b2658" +checksum = "4caf25cdc4a985f91df42ed9e9308e1adbcd341a31a72605c697033fcef163e3" dependencies = [ "arrow-arith", "arrow-array", @@ -151,9 +194,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "51.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0272150200c07a86a390be651abdd320a2d12e84535f0837566ca87ecd8f95e0" +checksum = "91f2dfd1a7ec0aca967dfaa616096aec49779adc8eccec005e2f5e4111b1192a" dependencies = [ "arrow-array", "arrow-buffer", @@ -166,9 +209,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "51.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8010572cf8c745e242d1b632bd97bd6d4f40fefed5ed1290a8f433abaa686fea" +checksum = "d39387ca628be747394890a6e47f138ceac1aa912eab64f02519fed24b637af8" dependencies = [ "ahash", "arrow-buffer", @@ -177,15 +220,15 @@ dependencies = [ "chrono", "chrono-tz", "half", - "hashbrown 0.14.3", + "hashbrown 0.14.5", "num", ] [[package]] name = "arrow-buffer" -version = "51.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d0a2432f0cba5692bf4cb757469c66791394bac9ec7ce63c1afe74744c37b27" +checksum = "9e51e05228852ffe3eb391ce7178a0f97d2cf80cc6ef91d3c4a6b3cb688049ec" dependencies = [ "bytes", "half", @@ -194,9 +237,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "51.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9abc10cd7995e83505cc290df9384d6e5412b207b79ce6bdff89a10505ed2cba" +checksum = "d09aea56ec9fa267f3f3f6cdab67d8a9974cbba90b3aa38c8fe9d0bb071bd8c1" dependencies = [ "arrow-array", "arrow-buffer", @@ -204,7 +247,7 @@ dependencies = [ "arrow-schema", "arrow-select", "atoi", - "base64 0.22.0", + "base64 0.22.1", "chrono", "comfy-table", "half", @@ -215,9 +258,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "51.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95cbcba196b862270bf2a5edb75927380a7f3a163622c61d40cbba416a6305f2" +checksum = "c07b5232be87d115fde73e32f2ca7f1b353bff1b44ac422d3c6fc6ae38f11f0d" dependencies = [ "arrow-array", "arrow-buffer", @@ -234,9 +277,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "51.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2742ac1f6650696ab08c88f6dd3f0eb68ce10f8c253958a18c943a68cd04aec5" +checksum = "b98ae0af50890b494cebd7d6b04b35e896205c1d1df7b29a6272c5d0d0249ef5" dependencies = [ "arrow-buffer", "arrow-schema", @@ -246,9 +289,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "51.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a42ea853130f7e78b9b9d178cb4cd01dee0f78e64d96c2949dc0a915d6d9e19d" +checksum = "0ed91bdeaff5a1c00d28d8f73466bcb64d32bbd7093b5a30156b4b9f4dba3eee" dependencies = [ "arrow-array", "arrow-buffer", @@ -261,9 +304,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "51.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eaafb5714d4e59feae964714d724f880511500e3569cc2a94d02456b403a2a49" +checksum = "0471f51260a5309307e5d409c9dc70aede1cd9cf1d4ff0f0a1e8e1a2dd0e0d3c" dependencies = [ "arrow-array", "arrow-buffer", @@ -272,7 +315,7 @@ dependencies = [ "arrow-schema", "chrono", "half", - "indexmap 2.2.6", + "indexmap", "lexical-core", "num", "serde", @@ -281,9 +324,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "51.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3e6b61e3dc468f503181dccc2fc705bdcc5f2f146755fa5b56d0a6c5943f412" +checksum = "2883d7035e0b600fb4c30ce1e50e66e53d8656aa729f2bfa4b51d359cf3ded52" dependencies = [ "arrow-array", "arrow-buffer", @@ -296,9 +339,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "51.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "848ee52bb92eb459b811fb471175ea3afcf620157674c8794f539838920f9228" +checksum = "552907e8e587a6fde4f8843fd7a27a576a260f65dab6c065741ea79f633fc5be" dependencies = [ "ahash", "arrow-array", @@ -306,20 +349,19 @@ dependencies = [ "arrow-data", "arrow-schema", "half", - "hashbrown 0.14.3", ] [[package]] name = "arrow-schema" -version = "51.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02d9483aaabe910c4781153ae1b6ae0393f72d9ef757d38d09d450070cf2e528" +checksum = "539ada65246b949bd99ffa0881a9a15a4a529448af1a07a9838dd78617dafab1" [[package]] name = "arrow-select" -version = "51.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "849524fa70e0e3c5ab58394c770cb8f514d0122d20de08475f7b472ed8075830" +checksum = "6259e566b752da6dceab91766ed8b2e67bf6270eb9ad8a6e07a33c1bede2b125" dependencies = [ "ahash", "arrow-array", @@ -331,9 +373,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "51.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9373cb5a021aee58863498c37eb484998ef13377f69989c6c5ccfbd258236cdb" +checksum = "f3179ccbd18ebf04277a095ba7321b93fd1f774f18816bd5f6b3ce2f594edb6c" dependencies = [ "arrow-array", "arrow-buffer", @@ -348,13 +390,14 @@ dependencies = [ [[package]] name = "assert_cmd" -version = "2.0.14" +version = "2.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed72493ac66d5804837f480ab3766c72bdfab91a65e565fc54fa9e42db0073a8" +checksum = "dc1835b7f27878de8525dc71410b5a31cdcc5f230aed5ba5df968e09c201b23d" dependencies = [ "anstyle", "bstr", "doc-comment", + "libc", "predicates", "predicates-core", "predicates-tree", @@ -363,9 +406,9 @@ dependencies = [ [[package]] name = "async-compression" -version = "0.4.8" +version = "0.4.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07dbbf24db18d609b1462965249abdf49129ccad073ec257da372adc83259c60" +checksum = "0cb8f1d480b0ea3783ab015936d2a55c87e219676f0c0b7dec61494043f21857" dependencies = [ "bzip2", "flate2", @@ -375,19 +418,19 @@ dependencies = [ "pin-project-lite", "tokio", "xz2", - "zstd 0.13.1", - "zstd-safe 7.1.0", + "zstd 0.13.2", + "zstd-safe 7.2.1", ] [[package]] name = "async-trait" -version = "0.1.80" +version = "0.1.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" +checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn", ] [[package]] @@ -400,178 +443,169 @@ dependencies = [ ] [[package]] -name = "atty" -version = "0.2.14" +name = "atomic-waker" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" -dependencies = [ - "hermit-abi 0.1.19", - "libc", - "winapi", -] +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" [[package]] name = "autocfg" -version = "1.2.0" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "aws-config" -version = "0.55.3" +version = "1.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcdcf0d683fe9c23d32cf5b53c9918ea0a500375a9fb20109802552658e576c9" +checksum = "2d6448cfb224dd6a9b9ac734f58622dd0d4751f3589f3b777345745f46b2eb14" dependencies = [ "aws-credential-types", - "aws-http", + "aws-runtime", "aws-sdk-sso", + "aws-sdk-ssooidc", "aws-sdk-sts", "aws-smithy-async", - "aws-smithy-client", "aws-smithy-http", - "aws-smithy-http-tower", "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", "aws-smithy-types", "aws-types", "bytes", - "fastrand 1.9.0", + "fastrand", "hex", - "http", - "hyper", - "ring 0.16.20", + "http 0.2.12", + "ring", "time", "tokio", - "tower", "tracing", + "url", "zeroize", ] [[package]] name = "aws-credential-types" -version = "0.55.3" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fcdb2f7acbc076ff5ad05e7864bdb191ca70a6fd07668dc3a1a8bcd051de5ae" +checksum = "60e8f6b615cb5fc60a98132268508ad104310f0cfb25a1c22eee76efdf9154da" dependencies = [ "aws-smithy-async", + "aws-smithy-runtime-api", "aws-smithy-types", - "fastrand 1.9.0", - "tokio", - "tracing", "zeroize", ] [[package]] -name = "aws-endpoint" -version = "0.55.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cce1c41a6cfaa726adee9ebb9a56fcd2bbfd8be49fd8a04c5e20fd968330b04" -dependencies = [ - "aws-smithy-http", - "aws-smithy-types", - "aws-types", - "http", - "regex", - "tracing", -] - -[[package]] -name = "aws-http" -version = "0.55.3" +name = "aws-runtime" +version = "1.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aadbc44e7a8f3e71c8b374e03ecd972869eb91dd2bc89ed018954a52ba84bc44" +checksum = "a10d5c055aa540164d9561a0e2e74ad30f0dcf7393c3a92f6733ddf9c5762468" dependencies = [ "aws-credential-types", + "aws-sigv4", + "aws-smithy-async", "aws-smithy-http", + "aws-smithy-runtime", + "aws-smithy-runtime-api", "aws-smithy-types", "aws-types", "bytes", - "http", - "http-body", - "lazy_static", + "fastrand", + "http 0.2.12", + "http-body 0.4.6", + "once_cell", "percent-encoding", "pin-project-lite", "tracing", + "uuid", ] [[package]] name = "aws-sdk-sso" -version = "0.28.0" +version = "1.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8b812340d86d4a766b2ca73f740dfd47a97c2dff0c06c8517a16d88241957e4" +checksum = "ded855583fa1d22e88fe39fd6062b062376e50a8211989e07cf5e38d52eb3453" dependencies = [ "aws-credential-types", - "aws-endpoint", - "aws-http", - "aws-sig-auth", + "aws-runtime", "aws-smithy-async", - "aws-smithy-client", "aws-smithy-http", - "aws-smithy-http-tower", "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", "aws-smithy-types", "aws-types", "bytes", - "http", - "regex", - "tokio-stream", - "tower", + "http 0.2.12", + "once_cell", + "regex-lite", "tracing", ] [[package]] -name = "aws-sdk-sts" -version = "0.28.0" +name = "aws-sdk-ssooidc" +version = "1.49.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "265fac131fbfc188e5c3d96652ea90ecc676a934e3174eaaee523c6cec040b3b" +checksum = "9177ea1192e6601ae16c7273385690d88a7ed386a00b74a6bc894d12103cd933" dependencies = [ "aws-credential-types", - "aws-endpoint", - "aws-http", - "aws-sig-auth", + "aws-runtime", "aws-smithy-async", - "aws-smithy-client", "aws-smithy-http", - "aws-smithy-http-tower", "aws-smithy-json", - "aws-smithy-query", + "aws-smithy-runtime", + "aws-smithy-runtime-api", "aws-smithy-types", - "aws-smithy-xml", "aws-types", "bytes", - "http", - "regex", - "tower", + "http 0.2.12", + "once_cell", + "regex-lite", "tracing", ] [[package]] -name = "aws-sig-auth" -version = "0.55.3" +name = "aws-sdk-sts" +version = "1.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b94acb10af0c879ecd5c7bdf51cda6679a0a4f4643ce630905a77673bfa3c61" +checksum = "823ef553cf36713c97453e2ddff1eb8f62be7f4523544e2a5db64caf80100f0a" dependencies = [ "aws-credential-types", - "aws-sigv4", + "aws-runtime", + "aws-smithy-async", "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-query", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-smithy-xml", "aws-types", - "http", + "http 0.2.12", + "once_cell", + "regex-lite", "tracing", ] [[package]] name = "aws-sigv4" -version = "0.55.3" +version = "1.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d2ce6f507be68e968a33485ced670111d1cbad161ddbbab1e313c03d37d8f4c" +checksum = "5619742a0d8f253be760bfbb8e8e8368c69e3587e4637af5754e488a611499b1" dependencies = [ + "aws-credential-types", "aws-smithy-http", + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", "form_urlencoded", "hex", "hmac", - "http", + "http 0.2.12", + "http 1.1.0", "once_cell", "percent-encoding", - "regex", "sha2", "time", "tracing", @@ -579,53 +613,28 @@ dependencies = [ [[package]] name = "aws-smithy-async" -version = "0.55.3" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13bda3996044c202d75b91afeb11a9afae9db9a721c6a7a427410018e286b880" +checksum = "62220bc6e97f946ddd51b5f1361f78996e704677afc518a4ff66b7a72ea1378c" dependencies = [ "futures-util", "pin-project-lite", "tokio", - "tokio-stream", -] - -[[package]] -name = "aws-smithy-client" -version = "0.55.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a86aa6e21e86c4252ad6a0e3e74da9617295d8d6e374d552be7d3059c41cedd" -dependencies = [ - "aws-smithy-async", - "aws-smithy-http", - "aws-smithy-http-tower", - "aws-smithy-types", - "bytes", - "fastrand 1.9.0", - "http", - "http-body", - "hyper", - "hyper-rustls 0.23.2", - "lazy_static", - "pin-project-lite", - "rustls 0.20.9", - "tokio", - "tower", - "tracing", ] [[package]] name = "aws-smithy-http" -version = "0.55.3" +version = "0.60.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b3b693869133551f135e1f2c77cb0b8277d9e3e17feaf2213f735857c4f0d28" +checksum = "5c8bc3e8fdc6b8d07d976e301c02fe553f72a39b7a9fea820e023268467d7ab6" dependencies = [ + "aws-smithy-runtime-api", "aws-smithy-types", "bytes", "bytes-utils", "futures-core", - "http", - "http-body", - "hyper", + "http 0.2.12", + "http-body 0.4.6", "once_cell", "percent-encoding", "pin-project-lite", @@ -634,91 +643,130 @@ dependencies = [ ] [[package]] -name = "aws-smithy-http-tower" -version = "0.55.3" +name = "aws-smithy-json" +version = "0.60.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ae4f6c5798a247fac98a867698197d9ac22643596dc3777f0c76b91917616b9" +checksum = "4683df9469ef09468dad3473d129960119a0d3593617542b7d52086c8486f2d6" dependencies = [ - "aws-smithy-http", "aws-smithy-types", - "bytes", - "http", - "http-body", - "pin-project-lite", - "tower", - "tracing", ] [[package]] -name = "aws-smithy-json" -version = "0.55.3" +name = "aws-smithy-query" +version = "0.60.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23f9f42fbfa96d095194a632fbac19f60077748eba536eb0b9fecc28659807f8" +checksum = "f2fbd61ceb3fe8a1cb7352e42689cec5335833cd9f94103a61e98f9bb61c64bb" dependencies = [ "aws-smithy-types", + "urlencoding", ] [[package]] -name = "aws-smithy-query" -version = "0.55.3" +name = "aws-smithy-runtime" +version = "1.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "98819eb0b04020a1c791903533b638534ae6c12e2aceda3e6e6fba015608d51d" +checksum = "be28bd063fa91fd871d131fc8b68d7cd4c5fa0869bea68daca50dcb1cbd76be2" dependencies = [ + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-runtime-api", "aws-smithy-types", - "urlencoding", + "bytes", + "fastrand", + "h2 0.3.26", + "http 0.2.12", + "http-body 0.4.6", + "http-body 1.0.1", + "httparse", + "hyper 0.14.31", + "hyper-rustls 0.24.2", + "once_cell", + "pin-project-lite", + "pin-utils", + "rustls 0.21.12", + "tokio", + "tracing", +] + +[[package]] +name = "aws-smithy-runtime-api" +version = "1.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e086682a53d3aa241192aa110fa8dfce98f2f5ac2ead0de84d41582c7e8fdb96" +dependencies = [ + "aws-smithy-async", + "aws-smithy-types", + "bytes", + "http 0.2.12", + "http 1.1.0", + "pin-project-lite", + "tokio", + "tracing", + "zeroize", ] [[package]] name = "aws-smithy-types" -version = "0.55.3" +version = "1.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16a3d0bf4f324f4ef9793b86a1701d9700fbcdbd12a846da45eed104c634c6e8" +checksum = "07c9cdc179e6afbf5d391ab08c85eac817b51c87e1892a5edb5f7bbdc64314b4" dependencies = [ "base64-simd", + "bytes", + "bytes-utils", + "futures-core", + "http 0.2.12", + "http 1.1.0", + "http-body 0.4.6", + "http-body 1.0.1", + "http-body-util", "itoa", "num-integer", + "pin-project-lite", + "pin-utils", "ryu", + "serde", "time", + "tokio", + "tokio-util", ] [[package]] name = "aws-smithy-xml" -version = "0.55.3" +version = "0.60.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1b9d12875731bd07e767be7baad95700c3137b56730ec9ddeedb52a5e5ca63b" +checksum = "ab0b0166827aa700d3dc519f72f8b3a91c35d0b8d042dc5d643a91e6f80648fc" dependencies = [ "xmlparser", ] [[package]] name = "aws-types" -version = "0.55.3" +version = "1.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dd209616cc8d7bfb82f87811a5c655dc97537f592689b18743bddf5dc5c4829" +checksum = "5221b91b3e441e6675310829fd8984801b772cb1546ef6c0e54dec9f1ac13fef" dependencies = [ "aws-credential-types", "aws-smithy-async", - "aws-smithy-client", - "aws-smithy-http", + "aws-smithy-runtime-api", "aws-smithy-types", - "http", "rustc_version", "tracing", ] [[package]] name = "backtrace" -version = "0.3.71" +version = "0.3.74" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b05800d2e817c8b3b4b54abd461726265fa9789ae34330622f2db9ee696f9d" +checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" dependencies = [ "addr2line", - "cc", "cfg-if", "libc", "miniz_oxide", "object", "rustc-demangle", + "windows-targets 0.52.6", ] [[package]] @@ -729,9 +777,9 @@ checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" [[package]] name = "base64" -version = "0.22.0" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9475866fec1451be56a3c2400fd081ff546538961565ccb5b7142cbd22bc7a51" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] name = "base64-simd" @@ -751,9 +799,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" [[package]] name = "blake2" @@ -766,9 +814,9 @@ dependencies = [ [[package]] name = "blake3" -version = "1.5.1" +version = "1.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30cca6d3674597c30ddf2c587bf8d9d65c9a84d2326d941cc79c9842dfe0ef52" +checksum = "d82033247fd8e890df8f740e407ad4d038debb9eb1f40533fffb32e7d17dc6f7" dependencies = [ "arrayref", "arrayvec", @@ -788,9 +836,9 @@ dependencies = [ [[package]] name = "brotli" -version = "3.5.0" +version = "7.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d640d25bc63c50fb1f0b545ffd80207d2e10a4c965530809b40ba3386825c391" +checksum = "cc97b8f16f944bba54f0433f07e30be199b6dc2bd25937444bbad560bcea29bd" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -799,9 +847,9 @@ dependencies = [ [[package]] name = "brotli-decompressor" -version = "2.5.1" +version = "4.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e2e4afe60d7dd600fdd3de8d0f08c2b7ec039712e3b6137ff98b7004e82de4f" +checksum = "9a45bd2e4095a8b518033b128020dd4a55aab1c0a381ba4404a472630f4bc362" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -809,9 +857,9 @@ dependencies = [ [[package]] name = "bstr" -version = "1.9.1" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05efc5cfd9110c8416e471df0e96702d58690178e206e61b7173706673c93706" +checksum = "40723b8fb387abc38f4f4a37c09073622e41dd12327033091ef8950659e6dc0c" dependencies = [ "memchr", "regex-automata", @@ -832,9 +880,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.6.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" +checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" [[package]] name = "bytes-utils" @@ -869,12 +917,13 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.94" +version = "1.1.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17f6e324229dc011159fcc089755d1e2e216a90d43a7dea6853ca740b84f35e7" +checksum = "67b9470d453346108f93a59222a9a1a5724db32d0a4727b7ab7ace4b4d822dc9" dependencies = [ "jobserver", "libc", + "shlex", ] [[package]] @@ -883,6 +932,18 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cfg_aliases" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" + +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "chrono" version = "0.4.38" @@ -893,14 +954,14 @@ dependencies = [ "iana-time-zone", "num-traits", "serde", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] name = "chrono-tz" -version = "0.8.6" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d59ae0466b83e838b81a54256c39d5d7c20b9d7daa10510a242d9b75abd5936e" +checksum = "cd6dd8046d00723a59a2f8c5f295c515b9bb9a331ee4f8f3d4dd49e428acd3b6" dependencies = [ "chrono", "chrono-tz-build", @@ -909,73 +970,77 @@ dependencies = [ [[package]] name = "chrono-tz-build" -version = "0.2.1" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "433e39f13c9a060046954e0592a8d0a4bcb1040125cbf91cb8ee58964cfb350f" +checksum = "e94fea34d77a245229e7746bd2beb786cd2a896f306ff491fb8cecb3074b10a7" dependencies = [ "parse-zoneinfo", - "phf", "phf_codegen", ] [[package]] name = "clap" -version = "3.2.25" +version = "4.5.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ea181bf566f71cb9a5d17a59e1871af638180a18fb0035c92ae62b705207123" +checksum = "b97f376d85a664d5837dbae44bf546e6477a679ff6610010f17276f686d867e8" dependencies = [ - "atty", - "bitflags 1.3.2", + "clap_builder", "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.5.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19bc80abd44e4bed93ca373a0704ccbd1b710dc5749406201bb018272808dc54" +dependencies = [ + "anstream", + "anstyle", "clap_lex", - "indexmap 1.9.3", - "once_cell", "strsim", - "termcolor", - "textwrap", ] [[package]] name = "clap_derive" -version = "3.2.25" +version = "4.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae6371b8bdc8b7d3959e9cf7b22d4435ef3e79e138688421ec654acf8c81b008" +checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" dependencies = [ - "heck", - "proc-macro-error", + "heck 0.5.0", "proc-macro2", "quote", - "syn 1.0.109", + "syn", ] [[package]] name = "clap_lex" -version = "0.2.4" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2850f2f5a82cbf437dd5af4d49848fbdfc27c157c3d010345776f952765261c5" -dependencies = [ - "os_str_bytes", -] +checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" [[package]] name = "clipboard-win" -version = "4.5.0" +version = "5.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7191c27c2357d9b7ef96baac1773290d4ca63b24205b82a3fd8a0637afcf0362" +checksum = "15efe7a882b08f34e38556b14f2fb3daa98769d06c7f0c1b076dfd0d983bc892" dependencies = [ "error-code", - "str-buf", - "winapi", ] +[[package]] +name = "colorchoice" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" + [[package]] name = "comfy-table" version = "7.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b34115915337defe99b2aff5c2ce6771e5fbc4079f4b506301f5cf394c8452f7" dependencies = [ - "strum 0.26.2", - "strum_macros 0.26.2", + "strum 0.26.3", + "strum_macros 0.26.4", "unicode-width", ] @@ -1001,9 +1066,9 @@ dependencies = [ [[package]] name = "constant_time_eq" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7144d30dcf0fafbce74250a3963025d8d52177934239851c917d29f1df280c2" +checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" [[package]] name = "core-foundation" @@ -1017,9 +1082,9 @@ dependencies = [ [[package]] name = "core-foundation-sys" -version = "0.8.6" +version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] name = "core2" @@ -1032,22 +1097,28 @@ dependencies = [ [[package]] name = "cpufeatures" -version = "0.2.12" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" +checksum = "608697df725056feaccfa42cffdaeeec3fccc4ffc38358ecd19b243e716a78e0" dependencies = [ "libc", ] [[package]] name = "crc32fast" -version = "1.4.0" +version = "1.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3855a8a784b474f333699ef2bbca9db2c4a1f6d9088a90a2d25b1eb53111eaa" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam-utils" +version = "0.8.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" + [[package]] name = "crunchy" version = "0.2.2" @@ -1092,23 +1163,24 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "edb49164822f3ee45b17acd4a208cfc1251410cf0cad9a833234c9890774dd9f" dependencies = [ "quote", - "syn 2.0.60", + "syn", ] [[package]] name = "dary_heap" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7762d17f1241643615821a8455a0b2c3e803784b058693d990b11f2dce25a0ca" +checksum = "04d2cd9c18b9f454ed67da600630b021a8a80bf33f8c95896ab33aaf1c26b728" [[package]] name = "dashmap" -version = "5.5.3" +version = "6.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" dependencies = [ "cfg-if", - "hashbrown 0.14.3", + "crossbeam-utils", + "hashbrown 0.14.5", "lock_api", "once_cell", "parking_lot_core", @@ -1116,7 +1188,7 @@ dependencies = [ [[package]] name = "datafusion" -version = "37.1.0" +version = "43.0.0" dependencies = [ "ahash", "apache-avro", @@ -1130,23 +1202,27 @@ dependencies = [ "bzip2", "chrono", "dashmap", + "datafusion-catalog", "datafusion-common", "datafusion-common-runtime", "datafusion-execution", "datafusion-expr", "datafusion-functions", "datafusion-functions-aggregate", - "datafusion-functions-array", + "datafusion-functions-nested", + "datafusion-functions-window", "datafusion-optimizer", "datafusion-physical-expr", + "datafusion-physical-expr-common", + "datafusion-physical-optimizer", "datafusion-physical-plan", "datafusion-sql", "flate2", "futures", "glob", "half", - "hashbrown 0.14.3", - "indexmap 2.2.6", + "hashbrown 0.14.5", + "indexmap", "itertools", "log", "num-traits", @@ -1154,6 +1230,7 @@ dependencies = [ "object_store", "parking_lot", "parquet", + "paste", "pin-project-lite", "rand", "sqlparser", @@ -1163,18 +1240,34 @@ dependencies = [ "url", "uuid", "xz2", - "zstd 0.13.1", + "zstd 0.13.2", +] + +[[package]] +name = "datafusion-catalog" +version = "43.0.0" +dependencies = [ + "arrow-schema", + "async-trait", + "datafusion-common", + "datafusion-execution", + "datafusion-expr", + "datafusion-physical-plan", + "parking_lot", ] [[package]] name = "datafusion-cli" -version = "37.1.0" +version = "43.0.0" dependencies = [ "arrow", "assert_cmd", "async-trait", "aws-config", "aws-credential-types", + "aws-sdk-sso", + "aws-sdk-ssooidc", + "aws-sdk-sts", "clap", "ctor", "datafusion", @@ -1195,7 +1288,7 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "37.1.0" +version = "43.0.0" dependencies = [ "ahash", "apache-avro", @@ -1205,24 +1298,29 @@ dependencies = [ "arrow-schema", "chrono", "half", + "hashbrown 0.14.5", + "indexmap", "instant", "libc", "num_cpus", "object_store", "parquet", + "paste", "sqlparser", + "tokio", ] [[package]] name = "datafusion-common-runtime" -version = "37.1.0" +version = "43.0.0" dependencies = [ + "log", "tokio", ] [[package]] name = "datafusion-execution" -version = "37.1.0" +version = "43.0.0" dependencies = [ "arrow", "chrono", @@ -1230,7 +1328,7 @@ dependencies = [ "datafusion-common", "datafusion-expr", "futures", - "hashbrown 0.14.3", + "hashbrown 0.14.5", "log", "object_store", "parking_lot", @@ -1241,34 +1339,50 @@ dependencies = [ [[package]] name = "datafusion-expr" -version = "37.1.0" +version = "43.0.0" dependencies = [ "ahash", "arrow", "arrow-array", + "arrow-buffer", "chrono", "datafusion-common", + "datafusion-expr-common", + "datafusion-functions-aggregate-common", + "datafusion-functions-window-common", + "datafusion-physical-expr-common", + "indexmap", "paste", "serde_json", "sqlparser", - "strum 0.26.2", - "strum_macros 0.26.2", + "strum 0.26.3", + "strum_macros 0.26.4", +] + +[[package]] +name = "datafusion-expr-common" +version = "43.0.0" +dependencies = [ + "arrow", + "datafusion-common", + "itertools", + "paste", ] [[package]] name = "datafusion-functions" -version = "37.1.0" +version = "43.0.0" dependencies = [ "arrow", - "base64 0.22.0", + "arrow-buffer", + "base64 0.22.1", "blake2", "blake3", "chrono", "datafusion-common", "datafusion-execution", "datafusion-expr", - "datafusion-physical-expr", - "hashbrown 0.14.3", + "hashbrown 0.14.5", "hex", "itertools", "log", @@ -1282,21 +1396,38 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" -version = "37.1.0" +version = "43.0.0" dependencies = [ + "ahash", "arrow", + "arrow-schema", "datafusion-common", "datafusion-execution", "datafusion-expr", + "datafusion-functions-aggregate-common", + "datafusion-physical-expr", "datafusion-physical-expr-common", + "half", + "indexmap", "log", "paste", - "sqlparser", ] [[package]] -name = "datafusion-functions-array" -version = "37.1.0" +name = "datafusion-functions-aggregate-common" +version = "43.0.0" +dependencies = [ + "ahash", + "arrow", + "datafusion-common", + "datafusion-expr-common", + "datafusion-physical-expr-common", + "rand", +] + +[[package]] +name = "datafusion-functions-nested" +version = "43.0.0" dependencies = [ "arrow", "arrow-array", @@ -1307,14 +1438,38 @@ dependencies = [ "datafusion-execution", "datafusion-expr", "datafusion-functions", + "datafusion-functions-aggregate", + "datafusion-physical-expr-common", "itertools", "log", "paste", + "rand", +] + +[[package]] +name = "datafusion-functions-window" +version = "43.0.0" +dependencies = [ + "datafusion-common", + "datafusion-expr", + "datafusion-functions-window-common", + "datafusion-physical-expr", + "datafusion-physical-expr-common", + "log", + "paste", +] + +[[package]] +name = "datafusion-functions-window-common" +version = "43.0.0" +dependencies = [ + "datafusion-common", + "datafusion-physical-expr-common", ] [[package]] name = "datafusion-optimizer" -version = "37.1.0" +version = "43.0.0" dependencies = [ "arrow", "async-trait", @@ -1322,15 +1477,17 @@ dependencies = [ "datafusion-common", "datafusion-expr", "datafusion-physical-expr", - "hashbrown 0.14.3", + "hashbrown 0.14.5", + "indexmap", "itertools", "log", + "paste", "regex-syntax", ] [[package]] name = "datafusion-physical-expr" -version = "37.1.0" +version = "43.0.0" dependencies = [ "ahash", "arrow", @@ -1339,56 +1496,71 @@ dependencies = [ "arrow-ord", "arrow-schema", "arrow-string", - "base64 0.22.0", "chrono", "datafusion-common", - "datafusion-execution", "datafusion-expr", - "datafusion-functions-aggregate", + "datafusion-expr-common", + "datafusion-functions-aggregate-common", "datafusion-physical-expr-common", "half", - "hashbrown 0.14.3", - "hex", - "indexmap 2.2.6", + "hashbrown 0.14.5", + "indexmap", "itertools", "log", "paste", "petgraph", - "regex", ] [[package]] name = "datafusion-physical-expr-common" -version = "37.1.0" +version = "43.0.0" dependencies = [ + "ahash", "arrow", "datafusion-common", - "datafusion-expr", + "datafusion-expr-common", + "hashbrown 0.14.5", + "rand", ] [[package]] -name = "datafusion-physical-plan" -version = "37.1.0" +name = "datafusion-physical-optimizer" +version = "43.0.0" dependencies = [ - "ahash", "arrow", - "arrow-array", - "arrow-buffer", - "arrow-ord", "arrow-schema", - "async-trait", - "chrono", + "datafusion-common", + "datafusion-execution", + "datafusion-expr-common", + "datafusion-physical-expr", + "datafusion-physical-plan", + "itertools", +] + +[[package]] +name = "datafusion-physical-plan" +version = "43.0.0" +dependencies = [ + "ahash", + "arrow", + "arrow-array", + "arrow-buffer", + "arrow-ord", + "arrow-schema", + "async-trait", + "chrono", "datafusion-common", "datafusion-common-runtime", "datafusion-execution", "datafusion-expr", - "datafusion-functions-aggregate", + "datafusion-functions-aggregate-common", + "datafusion-functions-window-common", "datafusion-physical-expr", "datafusion-physical-expr-common", "futures", "half", - "hashbrown 0.14.3", - "indexmap 2.2.6", + "hashbrown 0.14.5", + "indexmap", "itertools", "log", "once_cell", @@ -1400,16 +1572,18 @@ dependencies = [ [[package]] name = "datafusion-sql" -version = "37.1.0" +version = "43.0.0" dependencies = [ "arrow", "arrow-array", "arrow-schema", "datafusion-common", "datafusion-expr", + "indexmap", "log", + "regex", "sqlparser", - "strum 0.26.2", + "strum 0.26.3", ] [[package]] @@ -1440,43 +1614,34 @@ dependencies = [ [[package]] name = "dirs" -version = "4.0.0" +version = "5.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3aa72a6f96ea37bbc5aa912f6788242832f75369bdfdadcb0e38423f100059" +checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225" dependencies = [ "dirs-sys", ] -[[package]] -name = "dirs-next" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b98cf8ebf19c3d1b223e151f99a4f9f0690dca41414773390fc824184ac833e1" -dependencies = [ - "cfg-if", - "dirs-sys-next", -] - [[package]] name = "dirs-sys" -version = "0.3.7" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b1d1d91c932ef41c0f2663aa8b0ca0342d444d842c06914aa0a7e352d0bada6" +checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c" dependencies = [ "libc", + "option-ext", "redox_users", - "winapi", + "windows-sys 0.48.0", ] [[package]] -name = "dirs-sys-next" -version = "0.1.2" +name = "displaydoc" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ebda144c4fe02d1f7ea1a7d9641b6fc6b580adcfa024ae48797ecdeb6825b4d" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ - "libc", - "redox_users", - "winapi", + "proc-macro2", + "quote", + "syn", ] [[package]] @@ -1487,36 +1652,37 @@ checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" [[package]] name = "either" -version = "1.11.0" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" [[package]] -name = "encoding_rs" -version = "0.8.34" +name = "endian-type" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b45de904aa0b010bce2ab45264d0631681847fa7b6f2eaa7dab7619943bc4f59" -dependencies = [ - "cfg-if", -] +checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" [[package]] -name = "endian-type" +name = "env_filter" version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" +checksum = "4f2c92ceda6ceec50f43169f9ee8424fe2db276791afde7b2cd8bc084cb376ab" +dependencies = [ + "log", + "regex", +] [[package]] name = "env_logger" -version = "0.9.3" +version = "0.11.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a12e6657c4c97ebab115a42dcee77225f7f482cdd841cf7088c657a42e9e00e7" +checksum = "e13fa619b91fb2381732789fc5de83b45675e882f66623b7d8cb4f643017018d" dependencies = [ - "atty", + "anstream", + "anstyle", + "env_filter", "humantime", "log", - "regex", - "termcolor", ] [[package]] @@ -1527,9 +1693,9 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" -version = "0.3.8" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" +checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" dependencies = [ "libc", "windows-sys 0.52.0", @@ -1537,38 +1703,25 @@ dependencies = [ [[package]] name = "error-code" -version = "2.3.1" +version = "3.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64f18991e7bf11e7ffee451b5318b5c1a73c52d0d0ada6e5a3017c8c1ced6a21" -dependencies = [ - "libc", - "str-buf", -] +checksum = "a5d9305ccc6942a704f4335694ecd3de2ea531b114ac2d51f5f843750787a92f" [[package]] name = "fastrand" -version = "1.9.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be" -dependencies = [ - "instant", -] - -[[package]] -name = "fastrand" -version = "2.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "658bd65b1cf4c852a3cc96f18a8ce7b5640f6b703f905c7d74532294c2a63984" +checksum = "e8c02a5121d4ea3eb16a80748c74f5549a5665e4c21333c6098f283870fbdea6" [[package]] name = "fd-lock" -version = "3.0.13" +version = "4.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef033ed5e9bad94e55838ca0ca906db0e043f517adda0c8b79c7a8c66c93c1b5" +checksum = "7e5768da2206272c81ef0b5e951a41862938a6070da63bcea197899942d3b947" dependencies = [ "cfg-if", "rustix", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -1579,9 +1732,9 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" [[package]] name = "flatbuffers" -version = "23.5.26" +version = "24.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dac53e22462d78c16d64a1cd22371b54cc3fe94aa15e7886a2fa6e5d1ab8640" +checksum = "8add37afff2d4ffa83bc748a70b4b1370984f6980768554182424ef71447c35f" dependencies = [ "bitflags 1.3.2", "rustc_version", @@ -1589,9 +1742,9 @@ dependencies = [ [[package]] name = "flate2" -version = "1.0.28" +version = "1.0.34" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46303f565772937ffe1d394a4fac6f411c6013172fadde9dcdb1e147a086940e" +checksum = "a1b589b4dc103969ad3cf85c950899926ec64300a1a46d76c03a6072957036f0" dependencies = [ "crc32fast", "miniz_oxide", @@ -1623,9 +1776,9 @@ dependencies = [ [[package]] name = "futures" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" dependencies = [ "futures-channel", "futures-core", @@ -1638,9 +1791,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", "futures-sink", @@ -1648,15 +1801,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" [[package]] name = "futures-executor" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" dependencies = [ "futures-core", "futures-task", @@ -1665,32 +1818,32 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" [[package]] name = "futures-macro" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn", ] [[package]] name = "futures-sink" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" [[package]] name = "futures-task" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" [[package]] name = "futures-timer" @@ -1700,9 +1853,9 @@ checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" [[package]] name = "futures-util" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ "futures-channel", "futures-core", @@ -1728,9 +1881,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.14" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" dependencies = [ "cfg-if", "libc", @@ -1739,9 +1892,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.28.1" +version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "glob" @@ -1760,8 +1913,27 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http", - "indexmap 2.2.6", + "http 0.2.12", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "h2" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "524e8ac6999421f49a846c2d4411f337e53497d8ec55d67753beffa43c5d9205" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http 1.1.0", + "indexmap", "slab", "tokio", "tokio-util", @@ -1781,28 +1953,19 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" - -[[package]] -name = "hashbrown" -version = "0.13.2" +version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" dependencies = [ "ahash", + "allocator-api2", ] [[package]] name = "hashbrown" -version = "0.14.3" +version = "0.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" -dependencies = [ - "ahash", - "allocator-api2", -] +checksum = "3a9bfc1af68b1726ea47d3d5109de126281def866b33970e10fbab11b5dafab3" [[package]] name = "heck" @@ -1811,13 +1974,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" [[package]] -name = "hermit-abi" -version = "0.1.19" +name = "heck" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" -dependencies = [ - "libc", -] +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] name = "hermit-abi" @@ -1840,6 +2000,15 @@ dependencies = [ "digest", ] +[[package]] +name = "home" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" +dependencies = [ + "windows-sys 0.52.0", +] + [[package]] name = "http" version = "0.2.12" @@ -1851,6 +2020,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http-body" version = "0.4.6" @@ -1858,15 +2038,38 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" dependencies = [ "bytes", - "http", + "http 0.2.12", + "pin-project-lite", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http 1.1.0", +] + +[[package]] +name = "http-body-util" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" +dependencies = [ + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", "pin-project-lite", ] [[package]] name = "httparse" -version = "1.8.0" +version = "1.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" +checksum = "7d71d3574edd2771538b901e6549113b4006ece66150fb69c0fb6d9a2adae946" [[package]] name = "httpdate" @@ -1882,17 +2085,17 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "0.14.28" +version = "0.14.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf96e135eb83a2a8ddf766e426a841d8ddd7449d5f00d34ea02b41d2f19eef80" +checksum = "8c08302e8fa335b151b788c775ff56e7a03ae64ff85c548ee820fecb70356e85" dependencies = [ "bytes", "futures-channel", "futures-core", "futures-util", - "h2", - "http", - "http-body", + "h2 0.3.26", + "http 0.2.12", + "http-body 0.4.6", "httparse", "httpdate", "itoa", @@ -1905,18 +2108,23 @@ dependencies = [ ] [[package]] -name = "hyper-rustls" -version = "0.23.2" +name = "hyper" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1788965e61b367cd03a62950836d5cd41560c3577d90e40e0819373194d1661c" +checksum = "bbbff0a806a4728c99295b254c8838933b5b082d75e3cb70c8dab21fdfbcfa9a" dependencies = [ - "http", - "hyper", - "log", - "rustls 0.20.9", - "rustls-native-certs", + "bytes", + "futures-channel", + "futures-util", + "h2 0.4.6", + "http 1.1.0", + "http-body 1.0.1", + "httparse", + "itoa", + "pin-project-lite", + "smallvec", "tokio", - "tokio-rustls 0.23.4", + "want", ] [[package]] @@ -1926,18 +2134,57 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" dependencies = [ "futures-util", - "http", - "hyper", - "rustls 0.21.10", + "http 0.2.12", + "hyper 0.14.31", + "log", + "rustls 0.21.12", + "rustls-native-certs 0.6.3", "tokio", "tokio-rustls 0.24.1", ] +[[package]] +name = "hyper-rustls" +version = "0.27.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" +dependencies = [ + "futures-util", + "http 1.1.0", + "hyper 1.5.0", + "hyper-util", + "rustls 0.23.16", + "rustls-native-certs 0.8.0", + "rustls-pki-types", + "tokio", + "tokio-rustls 0.26.0", + "tower-service", +] + +[[package]] +name = "hyper-util" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "hyper 1.5.0", + "pin-project-lite", + "socket2", + "tokio", + "tower-service", + "tracing", +] + [[package]] name = "iana-time-zone" -version = "0.1.60" +version = "0.1.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" dependencies = [ "android_system_properties", "core-foundation-sys", @@ -1956,41 +2203,160 @@ dependencies = [ "cc", ] +[[package]] +name = "icu_collections" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locid" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_locid_transform" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_locid_transform_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_locid_transform_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e" + +[[package]] +name = "icu_normalizer" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "utf16_iter", + "utf8_iter", + "write16", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516" + +[[package]] +name = "icu_properties" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_locid_transform", + "icu_properties_data", + "icu_provider", + "tinystr", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569" + +[[package]] +name = "icu_provider" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" +dependencies = [ + "displaydoc", + "icu_locid", + "icu_provider_macros", + "stable_deref_trait", + "tinystr", + "writeable", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_provider_macros" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "idna" -version = "0.5.0" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" +checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" dependencies = [ - "unicode-bidi", - "unicode-normalization", + "idna_adapter", + "smallvec", + "utf8_iter", ] [[package]] -name = "indexmap" -version = "1.9.3" +name = "idna_adapter" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" dependencies = [ - "autocfg", - "hashbrown 0.12.3", + "icu_normalizer", + "icu_properties", ] [[package]] name = "indexmap" -version = "2.2.6" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" +checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" dependencies = [ "equivalent", - "hashbrown 0.14.3", + "hashbrown 0.15.1", ] [[package]] name = "instant" -version = "0.1.12" +version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" +checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222" dependencies = [ "cfg-if", "js-sys", @@ -2006,15 +2372,21 @@ checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02" [[package]] name = "ipnet" -version = "2.9.0" +version = "2.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" +checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708" + +[[package]] +name = "is_terminal_polyfill" +version = "1.70.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" [[package]] name = "itertools" -version = "0.12.1" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" dependencies = [ "either", ] @@ -2027,33 +2399,33 @@ checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" [[package]] name = "jobserver" -version = "0.1.30" +version = "0.1.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "685a7d121ee3f65ae4fddd72b25a04bb36b6af81bc0828f7d5434c0fe60fa3a2" +checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" dependencies = [ "libc", ] [[package]] name = "js-sys" -version = "0.3.69" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +checksum = "6a88f1bda2bd75b0452a14784937d796722fdebfe50df998aeb3f0b7603019a9" dependencies = [ "wasm-bindgen", ] [[package]] name = "lazy_static" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "lexical-core" -version = "0.8.5" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cde5de06e8d4c2faabc400238f9ae1c74d5412d03a7bd067645ccbc47070e46" +checksum = "0431c65b318a590c1de6b8fd6e72798c92291d27762d94c9e6c37ed7a73d8458" dependencies = [ "lexical-parse-float", "lexical-parse-integer", @@ -2064,9 +2436,9 @@ dependencies = [ [[package]] name = "lexical-parse-float" -version = "0.8.5" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "683b3a5ebd0130b8fb52ba0bdc718cc56815b6a097e28ae5a6997d0ad17dc05f" +checksum = "eb17a4bdb9b418051aa59d41d65b1c9be5affab314a872e5ad7f06231fb3b4e0" dependencies = [ "lexical-parse-integer", "lexical-util", @@ -2075,9 +2447,9 @@ dependencies = [ [[package]] name = "lexical-parse-integer" -version = "0.8.6" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d0994485ed0c312f6d965766754ea177d07f9c00c9b82a5ee62ed5b47945ee9" +checksum = "5df98f4a4ab53bf8b175b363a34c7af608fe31f93cc1fb1bf07130622ca4ef61" dependencies = [ "lexical-util", "static_assertions", @@ -2085,18 +2457,18 @@ dependencies = [ [[package]] name = "lexical-util" -version = "0.8.5" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5255b9ff16ff898710eb9eb63cb39248ea8a5bb036bea8085b1a767ff6c4e3fc" +checksum = "85314db53332e5c192b6bca611fb10c114a80d1b831ddac0af1e9be1b9232ca0" dependencies = [ "static_assertions", ] [[package]] name = "lexical-write-float" -version = "0.8.5" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "accabaa1c4581f05a3923d1b4cfd124c329352288b7b9da09e766b0668116862" +checksum = "6e7c3ad4e37db81c1cbe7cf34610340adc09c322871972f74877a712abc6c809" dependencies = [ "lexical-util", "lexical-write-integer", @@ -2105,9 +2477,9 @@ dependencies = [ [[package]] name = "lexical-write-integer" -version = "0.8.5" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1b6f3d1f4422866b68192d62f77bc5c700bee84f3069f2469d7bc8c77852446" +checksum = "eb89e9f6958b83258afa3deed90b5de9ef68eef090ad5086c791cd2345610162" dependencies = [ "lexical-util", "static_assertions", @@ -2115,15 +2487,15 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.153" +version = "0.2.161" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" +checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1" [[package]] name = "libflate" -version = "2.0.0" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f7d5654ae1795afc7ff76f4365c2c8791b0feb18e8996a96adad8ffd7c3b2bf" +checksum = "45d9dfdc14ea4ef0900c1cddbc8dcd553fbaacd8a4a282cf4018ae9dd04fb21e" dependencies = [ "adler32", "core2", @@ -2134,26 +2506,26 @@ dependencies = [ [[package]] name = "libflate_lz77" -version = "2.0.0" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be5f52fb8c451576ec6b79d3f4deb327398bc05bbdbd99021a6e77a4c855d524" +checksum = "e6e0d73b369f386f1c44abd9c570d5318f55ccde816ff4b562fa452e5182863d" dependencies = [ "core2", - "hashbrown 0.13.2", + "hashbrown 0.14.5", "rle-decode-fast", ] [[package]] name = "libm" -version = "0.2.8" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" +checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" [[package]] name = "libmimalloc-sys" -version = "0.1.35" +version = "0.1.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3979b5c37ece694f1f5e51e7ecc871fdb0f517ed04ee45f88d15d6d553cb9664" +checksum = "23aa6811d3bd4deb8a84dde645f943476d13b248d818edcf8ce0b2f37f036b44" dependencies = [ "cc", "libc", @@ -2165,21 +2537,27 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "libc", ] [[package]] name = "linux-raw-sys" -version = "0.4.13" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" + +[[package]] +name = "litemap" +version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" +checksum = "643cb0b8d4fcc284004d5fd0d67ccf61dfffadb7f75e1e71bc420f4688a3a704" [[package]] name = "lock_api" -version = "0.4.11" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" dependencies = [ "autocfg", "scopeguard", @@ -2187,9 +2565,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.21" +version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "lz4_flex" @@ -2223,15 +2601,15 @@ dependencies = [ [[package]] name = "memchr" -version = "2.7.2" +version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "mimalloc" -version = "0.1.39" +version = "0.1.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa01922b5ea280a911e323e4d2fd24b7fe5cc4042e0d2cda3c40775cdc4bdc9c" +checksum = "68914350ae34959d83f732418d51e2427a794055d0b9529f48259ac07af65633" dependencies = [ "libmimalloc-sys", ] @@ -2244,22 +2622,23 @@ checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" [[package]] name = "miniz_oxide" -version = "0.7.2" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7" +checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" dependencies = [ - "adler", + "adler2", ] [[package]] name = "mio" -version = "0.8.11" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" +checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" dependencies = [ + "hermit-abi", "libc", "wasi", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -2273,12 +2652,13 @@ dependencies = [ [[package]] name = "nix" -version = "0.26.4" +version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "598beaf3cc6fdd9a5dfb1630c2800c7acd31df7aaf0f565796fba2b53ca1af1b" +checksum = "ab2156c4fce2f8df6c499cc1c763e4394b7482525bf2a9701c9d79d215f519e4" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.6.0", "cfg-if", + "cfg_aliases 0.1.1", "libc", ] @@ -2290,9 +2670,9 @@ checksum = "61807f77802ff30975e01f4f071c8ba10c022052f98b3294119f3e615d13e5be" [[package]] name = "num" -version = "0.4.2" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3135b08af27d103b0a51f2ae0f8632117b7b185ccf931445affa8df530576a41" +checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" dependencies = [ "num-bigint", "num-complex", @@ -2304,20 +2684,19 @@ dependencies = [ [[package]] name = "num-bigint" -version = "0.4.4" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" dependencies = [ - "autocfg", "num-integer", "num-traits", ] [[package]] name = "num-complex" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23c6602fda94a57c990fe0df199a035d83576b496aa29f4e634a8ac6004e68a6" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" dependencies = [ "num-traits", ] @@ -2339,9 +2718,9 @@ dependencies = [ [[package]] name = "num-iter" -version = "0.1.44" +version = "0.1.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d869c01cc0c455284163fd0092f1f93835385ccab5a98a0dcc497b2f8bf055a9" +checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" dependencies = [ "autocfg", "num-integer", @@ -2350,11 +2729,10 @@ dependencies = [ [[package]] name = "num-rational" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" dependencies = [ - "autocfg", "num-bigint", "num-integer", "num-traits", @@ -2362,9 +2740,9 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.18" +version = "0.2.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", "libm", @@ -2376,32 +2754,32 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi 0.3.9", + "hermit-abi", "libc", ] [[package]] name = "object" -version = "0.32.2" +version = "0.36.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" +checksum = "aedf0a2d09c573ed1d8d85b30c119153926a2b36dce0ab28322c09a117a4683e" dependencies = [ "memchr", ] [[package]] name = "object_store" -version = "0.9.1" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8718f8b65fdf67a45108d1548347d4af7d71fb81ce727bbf9e3b2535e079db3" +checksum = "6eb4c22c6154a1e759d7099f9ffad7cc5ef8245f9efbab4a41b92623079c82f3" dependencies = [ "async-trait", - "base64 0.21.7", + "base64 0.22.1", "bytes", "chrono", "futures", "humantime", - "hyper", + "hyper 1.5.0", "itertools", "md-5", "parking_lot", @@ -2409,8 +2787,8 @@ dependencies = [ "quick-xml", "rand", "reqwest", - "ring 0.17.8", - "rustls-pemfile 2.1.2", + "ring", + "rustls-pemfile 2.2.0", "serde", "serde_json", "snafu", @@ -2422,9 +2800,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.19.0" +version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" [[package]] name = "openssl-probe" @@ -2432,6 +2810,12 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" +[[package]] +name = "option-ext" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" + [[package]] name = "ordered-float" version = "2.10.1" @@ -2441,12 +2825,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "os_str_bytes" -version = "6.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2355d85b9a3786f481747ced0e0ff2ba35213a1f9bd406ed906554d7af805a1" - [[package]] name = "outref" version = "0.5.1" @@ -2455,9 +2833,9 @@ checksum = "4030760ffd992bef45b0ae3f10ce1aba99e33464c90d14dd7c039884963ddc7a" [[package]] name = "parking_lot" -version = "0.12.1" +version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" dependencies = [ "lock_api", "parking_lot_core", @@ -2465,22 +2843,22 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.9" +version = "0.9.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" dependencies = [ "cfg-if", "libc", "redox_syscall", "smallvec", - "windows-targets 0.48.5", + "windows-targets 0.52.6", ] [[package]] name = "parquet" -version = "51.0.0" +version = "53.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "096795d4f47f65fd3ee1ec5a98b77ab26d602f2cc785b0e4be5443add17ecc32" +checksum = "dea02606ba6f5e856561d8d507dba8bac060aefca2a6c0f1aa1d361fed91ff3e" dependencies = [ "ahash", "arrow-array", @@ -2490,14 +2868,14 @@ dependencies = [ "arrow-ipc", "arrow-schema", "arrow-select", - "base64 0.22.0", + "base64 0.22.1", "brotli", "bytes", "chrono", "flate2", "futures", "half", - "hashbrown 0.14.3", + "hashbrown 0.14.5", "lz4_flex", "num", "num-bigint", @@ -2508,23 +2886,24 @@ dependencies = [ "thrift", "tokio", "twox-hash", - "zstd 0.13.1", + "zstd 0.13.2", + "zstd-sys", ] [[package]] name = "parse-zoneinfo" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c705f256449c60da65e11ff6626e0c16a0a0b96aaa348de61376b249bc340f41" +checksum = "1f2a05b18d44e2957b88f96ba460715e295bc1d7510468a2f3d3b44535d26c24" dependencies = [ "regex", ] [[package]] name = "paste" -version = "1.0.14" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" [[package]] name = "percent-encoding" @@ -2534,12 +2913,12 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "petgraph" -version = "0.6.4" +version = "0.6.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1d3afd2628e69da2be385eb6f2fd57c8ac7977ceeff6dc166ff1657b0e386a9" +checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" dependencies = [ "fixedbitset", - "indexmap 2.2.6", + "indexmap", ] [[package]] @@ -2580,31 +2959,11 @@ dependencies = [ "siphasher", ] -[[package]] -name = "pin-project" -version = "1.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6bf43b791c5b9e34c3d182969b4abb522f9343702850a2e57f460d00d09b4b3" -dependencies = [ - "pin-project-internal", -] - -[[package]] -name = "pin-project-internal" -version = "1.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.60", -] - [[package]] name = "pin-project-lite" -version = "0.2.14" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" +checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" [[package]] name = "pin-utils" @@ -2614,9 +2973,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkg-config" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" +checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" [[package]] name = "powerfmt" @@ -2626,15 +2985,18 @@ checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppv-lite86" -version = "0.2.17" +version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +dependencies = [ + "zerocopy", +] [[package]] name = "predicates" -version = "3.1.0" +version = "3.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68b87bfd4605926cdfefc1c3b5f8fe560e3feca9d5552cf68c466d3d8236c7e8" +checksum = "7e9086cc7640c29a356d1a29fd134380bee9d8f79a17410aa76e7ad295f42c97" dependencies = [ "anstyle", "difflib", @@ -2646,74 +3008,108 @@ dependencies = [ [[package]] name = "predicates-core" -version = "1.0.6" +version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b794032607612e7abeb4db69adb4e33590fa6cf1149e95fd7cb00e634b92f174" +checksum = "ae8177bee8e75d6846599c6b9ff679ed51e882816914eec639944d7c9aa11931" [[package]] name = "predicates-tree" -version = "1.0.9" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "368ba315fb8c5052ab692e68a0eefec6ec57b23a36959c14496f0b0df2c0cecf" +checksum = "41b740d195ed3166cd147c8047ec98db0e22ec019eb8eeb76d343b795304fb13" dependencies = [ "predicates-core", "termtree", ] [[package]] -name = "proc-macro-error" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" -dependencies = [ - "proc-macro-error-attr", - "proc-macro2", - "quote", - "syn 1.0.109", - "version_check", -] - -[[package]] -name = "proc-macro-error-attr" -version = "1.0.4" +name = "proc-macro-crate" +version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +checksum = "8ecf48c7ca261d60b74ab1a7b20da18bede46776b2e55535cb958eb595c5fa7b" dependencies = [ - "proc-macro2", - "quote", - "version_check", + "toml_edit", ] [[package]] name = "proc-macro2" -version = "1.0.81" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d1597b0c024618f09a9c3b8655b7e430397a36d23fdafec26d6965e9eec3eba" +checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" dependencies = [ "unicode-ident", ] [[package]] name = "quad-rand" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "658fa1faf7a4cc5f057c9ee5ef560f717ad9d8dc66d975267f709624d6e1ab88" +checksum = "b76f1009795ca44bb5aaae8fd3f18953e209259c33d9b059b1f53d58ab7511db" [[package]] name = "quick-xml" -version = "0.31.0" +version = "0.36.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1004a344b30a54e2ee58d66a71b32d2db2feb0a31f9a2d302bf0536f15de2a33" +checksum = "f7649a7b4df05aed9ea7ec6f628c67c9953a43869b8bc50929569b2999d443fe" dependencies = [ "memchr", "serde", ] +[[package]] +name = "quinn" +version = "0.11.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c7c5fdde3cdae7203427dc4f0a68fe0ed09833edc525a03456b153b79828684" +dependencies = [ + "bytes", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash", + "rustls 0.23.16", + "socket2", + "thiserror", + "tokio", + "tracing", +] + +[[package]] +name = "quinn-proto" +version = "0.11.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fadfaed2cd7f389d0161bb73eeb07b7b78f8691047a6f3e73caaeae55310a4a6" +dependencies = [ + "bytes", + "rand", + "ring", + "rustc-hash", + "rustls 0.23.16", + "slab", + "thiserror", + "tinyvec", + "tracing", +] + +[[package]] +name = "quinn-udp" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e346e016eacfff12233c243718197ca12f148c84e1e84268a896699b41c71780" +dependencies = [ + "cfg_aliases 0.2.1", + "libc", + "once_cell", + "socket2", + "tracing", + "windows-sys 0.59.0", +] + [[package]] name = "quote" -version = "1.0.36" +version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" dependencies = [ "proc-macro2", ] @@ -2760,18 +3156,18 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.4.1" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" +checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.6.0", ] [[package]] name = "redox_users" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd283d9651eeda4b2a83a43c1c91b266c40fd76ecd39a50a8c630ae69dc72891" +checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ "getrandom", "libredox", @@ -2780,9 +3176,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.4" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", @@ -2792,9 +3188,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.6" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" +checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" dependencies = [ "aho-corasick", "memchr", @@ -2803,32 +3199,39 @@ dependencies = [ [[package]] name = "regex-lite" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30b661b2f27137bdbc16f00eda72866a92bb28af1753ffbd56744fb6e2e9cd8e" +checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" [[package]] name = "regex-syntax" -version = "0.8.3" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" + +[[package]] +name = "relative-path" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" [[package]] name = "reqwest" -version = "0.11.27" +version = "0.12.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62" +checksum = "a77c62af46e79de0a562e1a9849205ffcb7fc1238876e9bd743357570e04046f" dependencies = [ - "base64 0.21.7", + "base64 0.22.1", "bytes", - "encoding_rs", "futures-core", "futures-util", - "h2", - "http", - "http-body", - "hyper", - "hyper-rustls 0.24.2", + "h2 0.4.6", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", + "hyper 1.5.0", + "hyper-rustls 0.27.3", + "hyper-util", "ipnet", "js-sys", "log", @@ -2836,39 +3239,25 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", - "rustls 0.21.10", - "rustls-native-certs", - "rustls-pemfile 1.0.4", + "quinn", + "rustls 0.23.16", + "rustls-native-certs 0.8.0", + "rustls-pemfile 2.2.0", + "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", "sync_wrapper", - "system-configuration", - "tokio", - "tokio-rustls 0.24.1", - "tokio-util", - "tower-service", - "url", - "wasm-bindgen", - "wasm-bindgen-futures", - "wasm-streams", - "web-sys", - "winreg", -] - -[[package]] -name = "ring" -version = "0.16.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc" -dependencies = [ - "cc", - "libc", - "once_cell", - "spin 0.5.2", - "untrusted 0.7.1", + "tokio", + "tokio-rustls 0.26.0", + "tokio-util", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "wasm-streams", "web-sys", - "winapi", + "windows-registry", ] [[package]] @@ -2881,8 +3270,8 @@ dependencies = [ "cfg-if", "getrandom", "libc", - "spin 0.9.8", - "untrusted 0.9.0", + "spin", + "untrusted", "windows-sys 0.52.0", ] @@ -2894,9 +3283,9 @@ checksum = "3582f63211428f83597b51b2ddb88e2a91a9d52d12831f9d08f5e624e8977422" [[package]] name = "rstest" -version = "0.17.0" +version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de1bb486a691878cd320c2f0d319ba91eeaa2e894066d8b5f8f117c000e9d962" +checksum = "7b423f0e62bdd61734b67cd21ff50871dfaeb9cc74f869dcd6af974fbcb19936" dependencies = [ "futures", "futures-timer", @@ -2906,40 +3295,50 @@ dependencies = [ [[package]] name = "rstest_macros" -version = "0.17.0" +version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "290ca1a1c8ca7edb7c3283bd44dc35dd54fdec6253a3912e201ba1072018fca8" +checksum = "c5e1711e7d14f74b12a58411c542185ef7fb7f2e7f8ee6e2940a883628522b42" dependencies = [ "cfg-if", + "glob", + "proc-macro-crate", "proc-macro2", "quote", + "regex", + "relative-path", "rustc_version", - "syn 1.0.109", + "syn", "unicode-ident", ] [[package]] name = "rustc-demangle" -version = "0.1.23" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" + +[[package]] +name = "rustc-hash" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" +checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152" [[package]] name = "rustc_version" -version = "0.4.0" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" dependencies = [ "semver", ] [[package]] name = "rustix" -version = "0.38.32" +version = "0.38.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65e04861e65f21776e67888bfbea442b3642beaa0138fdb1dd7a84a52dffdb89" +checksum = "375116bee2be9ed569afe2154ea6a99dfdffd257f533f187498c2a8f5feaf4ee" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "errno", "libc", "linux-raw-sys", @@ -2948,26 +3347,28 @@ dependencies = [ [[package]] name = "rustls" -version = "0.20.9" +version = "0.21.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b80e3dec595989ea8510028f30c408a4630db12c9cbb8de34203b89d6577e99" +checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e" dependencies = [ "log", - "ring 0.16.20", + "ring", + "rustls-webpki 0.101.7", "sct", - "webpki", ] [[package]] name = "rustls" -version = "0.21.10" +version = "0.23.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba" +checksum = "eee87ff5d9b36712a58574e12e9f0ea80f915a5b0ac518d322b24a465617925e" dependencies = [ - "log", - "ring 0.17.8", - "rustls-webpki", - "sct", + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki 0.102.8", + "subtle", + "zeroize", ] [[package]] @@ -2982,6 +3383,19 @@ dependencies = [ "security-framework", ] +[[package]] +name = "rustls-native-certs" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcaf18a4f2be7326cd874a5fa579fae794320a0f388d365dca7e480e55f83f8a" +dependencies = [ + "openssl-probe", + "rustls-pemfile 2.2.0", + "rustls-pki-types", + "schannel", + "security-framework", +] + [[package]] name = "rustls-pemfile" version = "1.0.4" @@ -2993,19 +3407,18 @@ dependencies = [ [[package]] name = "rustls-pemfile" -version = "2.1.2" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" dependencies = [ - "base64 0.22.0", "rustls-pki-types", ] [[package]] name = "rustls-pki-types" -version = "1.4.1" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ecd36cc4259e3e4514335c4a138c6b43171a8d61d8f5c9348f9fc7529416f247" +checksum = "16f1201b3c9a7ee8039bcadc17b7e605e2945b27eee7631788c1bd2b0643674b" [[package]] name = "rustls-webpki" @@ -3013,44 +3426,54 @@ version = "0.101.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" dependencies = [ - "ring 0.17.8", - "untrusted 0.9.0", + "ring", + "untrusted", +] + +[[package]] +name = "rustls-webpki" +version = "0.102.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", ] [[package]] name = "rustversion" -version = "1.0.15" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80af6f9131f277a45a3fba6ce8e2258037bb0477a67e610d3c1fe046ab31de47" +checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248" [[package]] name = "rustyline" -version = "11.0.0" +version = "14.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5dfc8644681285d1fb67a467fb3021bfea306b99b4146b166a1fe3ada965eece" +checksum = "7803e8936da37efd9b6d4478277f4b2b9bb5cdb37a113e8d63222e58da647e63" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.6.0", "cfg-if", "clipboard-win", - "dirs-next", "fd-lock", + "home", "libc", "log", "memchr", "nix", "radix_trie", - "scopeguard", "unicode-segmentation", "unicode-width", "utf8parse", - "winapi", + "windows-sys 0.52.0", ] [[package]] name = "ryu" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" [[package]] name = "same-file" @@ -3063,11 +3486,11 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.23" +version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534" +checksum = "01227be5826fa0690321a2ba6c5cd57a19cf3f6a09e76973b58e61de6ab9d1c1" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -3082,17 +3505,17 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" dependencies = [ - "ring 0.17.8", - "untrusted 0.9.0", + "ring", + "untrusted", ] [[package]] name = "security-framework" -version = "2.10.0" +version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "770452e37cad93e0a50d5abc3990d2bc351c36d0328f86cefec2f2fb206eaef6" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.6.0", "core-foundation", "core-foundation-sys", "libc", @@ -3101,9 +3524,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.10.0" +version = "2.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41f3cc463c0ef97e11c3461a9d3787412d30e8e7eb907c79180c4a57bf7c04ef" +checksum = "ea4a292869320c0272d7bc55a5a6aafaff59b4f63404a003887b679a2e05b4b6" dependencies = [ "core-foundation-sys", "libc", @@ -3111,9 +3534,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.22" +version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca" +checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" [[package]] name = "seq-macro" @@ -3123,31 +3546,32 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.198" +version = "1.0.214" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9846a40c979031340571da2545a4e5b7c4163bdae79b301d5f86d03979451fcc" +checksum = "f55c3193aca71c12ad7890f1785d2b73e1b9f63a0bbc353c08ef26fe03fc56b5" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.198" +version = "1.0.214" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e88edab869b01783ba905e7d0153f9fc1a6505a96e4ad3018011eedb838566d9" +checksum = "de523f781f095e28fa605cdce0f8307e451cc0fd14e2eb4cd2e98a355b147766" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn", ] [[package]] name = "serde_json" -version = "1.0.116" +version = "1.0.132" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e17db7126d17feb94eb3fad46bf1a96b034e8aacbc2e775fe81505f8b0b2813" +checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" dependencies = [ "itoa", + "memchr", "ryu", "serde", ] @@ -3175,11 +3599,17 @@ dependencies = [ "digest", ] +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + [[package]] name = "signal-hook-registry" -version = "1.4.1" +version = "1.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" +checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" dependencies = [ "libc", ] @@ -3207,24 +3637,23 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "snafu" -version = "0.7.5" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4de37ad025c587a29e8f3f5605c00f70b98715ef90b9061a815b9e59e9042d6" +checksum = "223891c85e2a29c3fe8fb900c1fae5e69c2e42415e3177752e8718475efa5019" dependencies = [ - "doc-comment", "snafu-derive", ] [[package]] name = "snafu-derive" -version = "0.7.5" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "990079665f075b699031e9c08fd3ab99be5029b96f3b78dc0709e8f77e4efebf" +checksum = "03c3c6b7927ffe7ecaa769ee0e3994da3b8cafc8f444578982c83ecb161af917" dependencies = [ - "heck", + "heck 0.5.0", "proc-macro2", "quote", - "syn 1.0.109", + "syn", ] [[package]] @@ -3235,20 +3664,14 @@ checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" [[package]] name = "socket2" -version = "0.5.6" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05ffd9c0a93b7543e062e759284fcf5f5e3b098501104bfbdde4d404db792871" +checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c" dependencies = [ "libc", "windows-sys 0.52.0", ] -[[package]] -name = "spin" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" - [[package]] name = "spin" version = "0.9.8" @@ -3257,9 +3680,9 @@ checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" [[package]] name = "sqlparser" -version = "0.45.0" +version = "0.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7bbffee862a796d67959a89859d6b1046bb5016d63e23835ad0da182777bbe0" +checksum = "5fe11944a61da0da3f592e19a45ebe5ab92dc14a779907ff1f08fbb797bfefc7" dependencies = [ "log", "sqlparser_derive", @@ -3273,26 +3696,26 @@ checksum = "01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn", ] [[package]] -name = "static_assertions" -version = "1.1.0" +name = "stable_deref_trait" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] -name = "str-buf" -version = "1.0.6" +name = "static_assertions" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e08d8363704e6c71fc928674353e6b7c23dcea9d82d7012c8faf2a3a025f8d0" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" [[package]] name = "strsim" -version = "0.10.0" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" [[package]] name = "strum" @@ -3302,11 +3725,11 @@ checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" [[package]] name = "strum" -version = "0.26.2" +version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d8cec3501a5194c432b2b7976db6b7d10ec95c253208b45f83f7136aa985e29" +checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" dependencies = [ - "strum_macros 0.26.2", + "strum_macros 0.26.4", ] [[package]] @@ -3315,48 +3738,37 @@ version = "0.25.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0" dependencies = [ - "heck", + "heck 0.4.1", "proc-macro2", "quote", "rustversion", - "syn 2.0.60", + "syn", ] [[package]] name = "strum_macros" -version = "0.26.2" +version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6cf59daf282c0a494ba14fd21610a0325f9f90ec9d1231dea26bcb1d696c946" +checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" dependencies = [ - "heck", + "heck 0.5.0", "proc-macro2", "quote", "rustversion", - "syn 2.0.60", + "syn", ] [[package]] name = "subtle" -version = "2.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" - -[[package]] -name = "syn" -version = "1.0.109" +version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.60" +version = "2.0.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "909518bc7b1c9b779f1bbf07f2929d35af9f0f37e47c6e9ef7f9dddc1e1821f3" +checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d" dependencies = [ "proc-macro2", "quote", @@ -3365,50 +3777,35 @@ dependencies = [ [[package]] name = "sync_wrapper" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" - -[[package]] -name = "system-configuration" -version = "0.5.1" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" +checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" dependencies = [ - "bitflags 1.3.2", - "core-foundation", - "system-configuration-sys", + "futures-core", ] [[package]] -name = "system-configuration-sys" -version = "0.5.0" +name = "synstructure" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" +checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ - "core-foundation-sys", - "libc", + "proc-macro2", + "quote", + "syn", ] [[package]] name = "tempfile" -version = "3.10.1" +version = "3.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" +checksum = "f0f2c9fc62d0beef6951ccffd757e241266a2c833136efbe35af6cd2567dca5b" dependencies = [ "cfg-if", - "fastrand 2.0.2", + "fastrand", + "once_cell", "rustix", - "windows-sys 0.52.0", -] - -[[package]] -name = "termcolor" -version = "1.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" -dependencies = [ - "winapi-util", + "windows-sys 0.59.0", ] [[package]] @@ -3417,30 +3814,24 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" -[[package]] -name = "textwrap" -version = "0.16.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23d434d3f8967a09480fb04132ebe0a3e088c173e6d0ee7897abbdf4eab0f8b9" - [[package]] name = "thiserror" -version = "1.0.58" +version = "1.0.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297" +checksum = "02dd99dc800bbb97186339685293e1cc5d9df1f8fae2d0aecd9ff1c77efea892" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.58" +version = "1.0.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" +checksum = "a7c61ec9a6f64d2793d8a45faba21efbe3ced62a886d44c36a009b2b519b4c7e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn", ] [[package]] @@ -3493,11 +3884,21 @@ dependencies = [ "crunchy", ] +[[package]] +name = "tinystr" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" +dependencies = [ + "displaydoc", + "zerovec", +] + [[package]] name = "tinyvec" -version = "1.6.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" +checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" dependencies = [ "tinyvec_macros", ] @@ -3510,43 +3911,31 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.37.0" +version = "1.41.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1adbebffeca75fcfd058afa480fb6c0b81e165a0323f9c9d39c9697e37c46787" +checksum = "145f3413504347a2be84393cc8a7d2fb4d863b375909ea59f2158261aa258bbb" dependencies = [ "backtrace", "bytes", "libc", "mio", - "num_cpus", "parking_lot", "pin-project-lite", "signal-hook-registry", "socket2", "tokio-macros", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] name = "tokio-macros" -version = "2.2.0" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" +checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", -] - -[[package]] -name = "tokio-rustls" -version = "0.23.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c43ee83903113e03984cb9e5cebe6c04a5116269e900e3ddba8f068a62adda59" -dependencies = [ - "rustls 0.20.9", - "tokio", - "webpki", + "syn", ] [[package]] @@ -3555,62 +3944,56 @@ version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" dependencies = [ - "rustls 0.21.10", + "rustls 0.21.12", "tokio", ] [[package]] -name = "tokio-stream" -version = "0.1.15" +name = "tokio-rustls" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af" +checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "futures-core", - "pin-project-lite", + "rustls 0.23.16", + "rustls-pki-types", "tokio", ] [[package]] name = "tokio-util" -version = "0.7.10" +version = "0.7.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5419f34732d9eb6ee4c3578b7989078579b7f039cbbb9ca2c4da015749371e15" +checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" dependencies = [ "bytes", "futures-core", "futures-sink", "pin-project-lite", "tokio", - "tracing", ] [[package]] -name = "tower" -version = "0.4.13" +name = "toml_datetime" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" -dependencies = [ - "futures-core", - "futures-util", - "pin-project", - "pin-project-lite", - "tokio", - "tower-layer", - "tower-service", - "tracing", -] +checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" [[package]] -name = "tower-layer" -version = "0.3.2" +name = "toml_edit" +version = "0.22.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" +checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" +dependencies = [ + "indexmap", + "toml_datetime", + "winnow", +] [[package]] name = "tower-service" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" @@ -3618,7 +4001,6 @@ version = "0.1.40" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" dependencies = [ - "log", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -3632,7 +4014,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn", ] [[package]] @@ -3677,7 +4059,7 @@ checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn", ] [[package]] @@ -3686,44 +4068,23 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" -[[package]] -name = "unicode-bidi" -version = "0.3.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" - [[package]] name = "unicode-ident" -version = "1.0.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" - -[[package]] -name = "unicode-normalization" -version = "0.1.23" +version = "1.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" -dependencies = [ - "tinyvec", -] +checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" [[package]] name = "unicode-segmentation" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" +checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" [[package]] name = "unicode-width" -version = "0.1.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" - -[[package]] -name = "untrusted" -version = "0.7.1" +version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" +checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" [[package]] name = "untrusted" @@ -3733,9 +4094,9 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "url" -version = "2.5.0" +version = "2.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" +checksum = "8d157f1b96d14500ffdc1f10ba712e780825526c03d9a49b4d0324b0d9113ada" dependencies = [ "form_urlencoded", "idna", @@ -3748,17 +4109,29 @@ version = "2.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" +[[package]] +name = "utf16_iter" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + [[package]] name = "utf8parse" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.8.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a183cf7feeba97b4dd1c0d46788634f6221d87fa961b305bed08c851829efcc0" +checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" dependencies = [ "getrandom", "serde", @@ -3766,9 +4139,9 @@ dependencies = [ [[package]] name = "version_check" -version = "0.9.4" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" [[package]] name = "vsimd" @@ -3812,34 +4185,35 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.92" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +checksum = "128d1e363af62632b8eb57219c8fd7877144af57558fb2ef0368d0087bddeb2e" dependencies = [ "cfg-if", + "once_cell", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.92" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +checksum = "cb6dd4d3ca0ddffd1dd1c9c04f94b868c37ff5fac97c30b97cff2d74fce3a358" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.60", + "syn", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.42" +version = "0.4.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76bc14366121efc8dbb487ab05bcc9d346b3b5ec0eaa76e46594cabbe51762c0" +checksum = "cc7ec4f8827a71586374db3e87abdb5a2bb3a15afed140221307c3ec06b1f63b" dependencies = [ "cfg-if", "js-sys", @@ -3849,9 +4223,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.92" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +checksum = "e79384be7f8f5a9dd5d7167216f022090cf1f9ec128e6e6a482a2cb5c5422c56" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -3859,28 +4233,28 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.92" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +checksum = "26c6ab57572f7a24a4985830b120de1594465e5d500f24afe89e16b4e833ef68" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.92" +version = "0.2.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" +checksum = "65fc09f10666a9f147042251e0dda9c18f166ff7de300607007e96bdebc1068d" [[package]] name = "wasm-streams" -version = "0.4.0" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b65dc4c90b63b118468cf747d8bf3566c1913ef60be765b5730ead9e0a3ba129" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" dependencies = [ "futures-util", "js-sys", @@ -3891,62 +4265,60 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.69" +version = "0.3.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" +checksum = "f6488b90108c040df0fe62fa815cbdee25124641df01814dd7282749234c6112" dependencies = [ "js-sys", "wasm-bindgen", ] [[package]] -name = "webpki" -version = "0.22.4" +name = "winapi-util" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed63aea5ce73d0ff405984102c42de94fc55a6b75765d621c65262469b3c9b53" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "ring 0.17.8", - "untrusted 0.9.0", + "windows-sys 0.59.0", ] [[package]] -name = "winapi" -version = "0.3.9" +name = "windows-core" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "winapi-i686-pc-windows-gnu", - "winapi-x86_64-pc-windows-gnu", + "windows-targets 0.52.6", ] [[package]] -name = "winapi-i686-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" - -[[package]] -name = "winapi-util" -version = "0.1.6" +name = "windows-registry" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" +checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0" dependencies = [ - "winapi", + "windows-result", + "windows-strings", + "windows-targets 0.52.6", ] [[package]] -name = "winapi-x86_64-pc-windows-gnu" -version = "0.4.0" +name = "windows-result" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e" +dependencies = [ + "windows-targets 0.52.6", +] [[package]] -name = "windows-core" -version = "0.52.0" +name = "windows-strings" +version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" dependencies = [ - "windows-targets 0.52.5", + "windows-result", + "windows-targets 0.52.6", ] [[package]] @@ -3964,7 +4336,16 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.5", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", ] [[package]] @@ -3984,18 +4365,18 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm 0.52.5", - "windows_aarch64_msvc 0.52.5", - "windows_i686_gnu 0.52.5", + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", "windows_i686_gnullvm", - "windows_i686_msvc 0.52.5", - "windows_x86_64_gnu 0.52.5", - "windows_x86_64_gnullvm 0.52.5", - "windows_x86_64_msvc 0.52.5", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", ] [[package]] @@ -4006,9 +4387,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] name = "windows_aarch64_msvc" @@ -4018,9 +4399,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] name = "windows_i686_gnu" @@ -4030,15 +4411,15 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" [[package]] name = "windows_i686_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] name = "windows_i686_msvc" @@ -4048,9 +4429,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] name = "windows_x86_64_gnu" @@ -4060,9 +4441,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] name = "windows_x86_64_gnullvm" @@ -4072,9 +4453,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] name = "windows_x86_64_msvc" @@ -4084,20 +4465,31 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] -name = "winreg" -version = "0.50.0" +name = "winnow" +version = "0.6.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" +checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" dependencies = [ - "cfg-if", - "windows-sys 0.48.0", + "memchr", ] +[[package]] +name = "write16" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" + +[[package]] +name = "writeable" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" + [[package]] name = "xmlparser" version = "0.13.6" @@ -4113,31 +4505,99 @@ dependencies = [ "lzma-sys", ] +[[package]] +name = "yoke" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c5b1314b079b0930c31e3af543d8ee1757b1951ae1e1565ec704403a7240ca5" +dependencies = [ + "serde", + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + [[package]] name = "zerocopy" -version = "0.7.32" +version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74d4d3961e53fa4c9a25a8637fc2bfaf2595b3d3ae34875568a5cf64787716be" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ + "byteorder", "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.32" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zerofrom" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91ec111ce797d0e0784a1116d0ddcdbea84322cd79e5d5ad173daeba4f93ab55" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" +checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn", + "synstructure", ] [[package]] name = "zeroize" -version = "1.7.0" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" + +[[package]] +name = "zerovec" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" +checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] [[package]] name = "zstd" @@ -4150,11 +4610,11 @@ dependencies = [ [[package]] name = "zstd" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d789b1514203a1120ad2429eae43a7bd32b90976a7bb8a05f7ec02fa88cc23a" +checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9" dependencies = [ - "zstd-safe 7.1.0", + "zstd-safe 7.2.1", ] [[package]] @@ -4169,18 +4629,18 @@ dependencies = [ [[package]] name = "zstd-safe" -version = "7.1.0" +version = "7.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cd99b45c6bc03a018c8b8a86025678c87e55526064e38f9df301989dce7ec0a" +checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059" dependencies = [ "zstd-sys", ] [[package]] name = "zstd-sys" -version = "2.0.10+zstd.1.5.6" +version = "2.0.13+zstd.1.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c253a4914af5bafc8fa8c86ee400827e83cf6ec01195ec1f1ed8441bf00d65aa" +checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa" dependencies = [ "cc", "pkg-config", diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index c6019bc5970c..784d47220c7c 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -18,24 +18,28 @@ [package] name = "datafusion-cli" description = "Command Line Client for DataFusion query engine." -version = "37.1.0" -authors = ["Apache Arrow "] +version = "43.0.0" +authors = ["Apache DataFusion "] edition = "2021" keywords = ["arrow", "datafusion", "query", "sql"] license = "Apache-2.0" -homepage = "https://github.com/apache/datafusion" +homepage = "https://datafusion.apache.org" repository = "https://github.com/apache/datafusion" # Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.73" +rust-version = "1.79" readme = "README.md" [dependencies] -arrow = "51.0.0" -async-trait = "0.1.41" -aws-config = "0.55" -aws-credential-types = "0.55" -clap = { version = "3", features = ["derive", "cargo"] } -datafusion = { path = "../datafusion/core", version = "37.1.0", features = [ +arrow = { version = "53.0.0" } +async-trait = "0.1.73" +aws-config = "1.5.5" +aws-sdk-sso = "1.43.0" +aws-sdk-ssooidc = "1.44.0" +aws-sdk-sts = "1.43.0" +# end pin aws-sdk crates +aws-credential-types = "1.2.0" +clap = { version = "4.5.16", features = ["derive", "cargo"] } +datafusion = { path = "../datafusion/core", version = "43.0.0", features = [ "avro", "crypto_expressions", "datetime_expressions", @@ -45,15 +49,15 @@ datafusion = { path = "../datafusion/core", version = "37.1.0", features = [ "unicode_expressions", "compression", ] } -dirs = "4.0.0" -env_logger = "0.9" +dirs = "5.0.1" +env_logger = "0.11" futures = "0.3" mimalloc = { version = "0.1", default-features = false } -object_store = { version = "0.9.0", features = ["aws", "gcp", "http"] } +object_store = { version = "0.11.0", features = ["aws", "gcp", "http"] } parking_lot = { version = "0.12" } -parquet = { version = "51.0.0", default-features = false } +parquet = { version = "53.0.0", default-features = false } regex = "1.8" -rustyline = "11.0" +rustyline = "14.0" tokio = { version = "1.24", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot", "signal"] } url = "2.2" @@ -61,4 +65,4 @@ url = "2.2" assert_cmd = "2.0" ctor = "0.2.0" predicates = "3.0" -rstest = "0.17" +rstest = "0.22" diff --git a/datafusion-cli/Dockerfile b/datafusion-cli/Dockerfile index 9dbab5b1ed75..79c24f6baf3e 100644 --- a/datafusion-cli/Dockerfile +++ b/datafusion-cli/Dockerfile @@ -15,22 +15,23 @@ # specific language governing permissions and limitations # under the License. -FROM rust:1.73-bullseye as builder +FROM rust:1.79-bookworm AS builder -COPY . /usr/src/arrow-datafusion -COPY ./datafusion /usr/src/arrow-datafusion/datafusion +COPY . /usr/src/datafusion +COPY ./datafusion /usr/src/datafusion/datafusion +COPY ./datafusion-cli /usr/src/datafusion/datafusion-cli -COPY ./datafusion-cli /usr/src/arrow-datafusion/datafusion-cli - -WORKDIR /usr/src/arrow-datafusion/datafusion-cli +WORKDIR /usr/src/datafusion/datafusion-cli RUN rustup component add rustfmt RUN cargo build --release -FROM debian:bullseye-slim +FROM debian:bookworm-slim + +COPY --from=builder /usr/src/datafusion/datafusion-cli/target/release/datafusion-cli /usr/local/bin -COPY --from=builder /usr/src/arrow-datafusion/datafusion-cli/target/release/datafusion-cli /usr/local/bin +RUN mkdir /data ENTRYPOINT ["datafusion-cli"] diff --git a/datafusion-cli/examples/cli-session-context.rs b/datafusion-cli/examples/cli-session-context.rs new file mode 100644 index 000000000000..1a8f15c8731b --- /dev/null +++ b/datafusion-cli/examples/cli-session-context.rs @@ -0,0 +1,95 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Shows an example of a custom session context that unions the input plan with itself. +//! To run this example, use `cargo run --example cli-session-context` from within the `datafusion-cli` directory. + +use std::sync::Arc; + +use datafusion::{ + dataframe::DataFrame, + error::DataFusionError, + execution::{context::SessionState, TaskContext}, + logical_expr::{LogicalPlan, LogicalPlanBuilder}, + prelude::SessionContext, +}; +use datafusion_cli::{ + cli_context::CliSessionContext, exec::exec_from_repl, print_options::PrintOptions, +}; +use object_store::ObjectStore; + +/// This is a toy example of a custom session context that unions the input plan with itself. +struct MyUnionerContext { + ctx: SessionContext, +} + +impl Default for MyUnionerContext { + fn default() -> Self { + Self { + ctx: SessionContext::new(), + } + } +} + +#[async_trait::async_trait] +impl CliSessionContext for MyUnionerContext { + fn task_ctx(&self) -> Arc { + self.ctx.task_ctx() + } + + fn session_state(&self) -> SessionState { + self.ctx.state() + } + + fn register_object_store( + &self, + url: &url::Url, + object_store: Arc, + ) -> Option> { + self.ctx.register_object_store(url, object_store) + } + + fn register_table_options_extension_from_scheme(&self, _scheme: &str) { + unimplemented!() + } + + async fn execute_logical_plan( + &self, + plan: LogicalPlan, + ) -> Result { + let new_plan = LogicalPlanBuilder::from(plan.clone()) + .union(plan.clone())? + .build()?; + + self.ctx.execute_logical_plan(new_plan).await + } +} + +#[tokio::main] +/// Runs the example. +pub async fn main() { + let my_ctx = MyUnionerContext::default(); + + let mut print_options = PrintOptions { + format: datafusion_cli::print_format::PrintFormat::Automatic, + quiet: false, + maxrows: datafusion_cli::print_options::MaxRows::Unlimited, + color: true, + }; + + exec_from_repl(&my_ctx, &mut print_options).await.unwrap(); +} diff --git a/datafusion-cli/src/catalog.rs b/datafusion-cli/src/catalog.rs index faa657da6511..ceb72dbc546b 100644 --- a/datafusion-cli/src/catalog.rs +++ b/datafusion-cli/src/catalog.rs @@ -20,28 +20,27 @@ use std::sync::{Arc, Weak}; use crate::object_storage::{get_object_store, AwsOptions, GcpOptions}; -use datafusion::catalog::schema::SchemaProvider; -use datafusion::catalog::{CatalogProvider, CatalogProviderList}; +use datafusion::catalog::{CatalogProvider, CatalogProviderList, SchemaProvider}; + use datafusion::common::plan_datafusion_err; -use datafusion::datasource::listing::{ - ListingTable, ListingTableConfig, ListingTableUrl, -}; +use datafusion::datasource::listing::ListingTableUrl; use datafusion::datasource::TableProvider; use datafusion::error::Result; use datafusion::execution::context::SessionState; +use datafusion::execution::session_state::SessionStateBuilder; use async_trait::async_trait; use dirs::home_dir; use parking_lot::RwLock; -/// Wraps another catalog, automatically creating table providers -/// for local files if needed -pub struct DynamicFileCatalog { +/// Wraps another catalog, automatically register require object stores for the file locations +#[derive(Debug)] +pub struct DynamicObjectStoreCatalog { inner: Arc, state: Weak>, } -impl DynamicFileCatalog { +impl DynamicObjectStoreCatalog { pub fn new( inner: Arc, state: Weak>, @@ -50,7 +49,7 @@ impl DynamicFileCatalog { } } -impl CatalogProviderList for DynamicFileCatalog { +impl CatalogProviderList for DynamicObjectStoreCatalog { fn as_any(&self) -> &dyn Any { self } @@ -69,19 +68,20 @@ impl CatalogProviderList for DynamicFileCatalog { fn catalog(&self, name: &str) -> Option> { let state = self.state.clone(); - self.inner - .catalog(name) - .map(|catalog| Arc::new(DynamicFileCatalogProvider::new(catalog, state)) as _) + self.inner.catalog(name).map(|catalog| { + Arc::new(DynamicObjectStoreCatalogProvider::new(catalog, state)) as _ + }) } } /// Wraps another catalog provider -struct DynamicFileCatalogProvider { +#[derive(Debug)] +struct DynamicObjectStoreCatalogProvider { inner: Arc, state: Weak>, } -impl DynamicFileCatalogProvider { +impl DynamicObjectStoreCatalogProvider { pub fn new( inner: Arc, state: Weak>, @@ -90,7 +90,7 @@ impl DynamicFileCatalogProvider { } } -impl CatalogProvider for DynamicFileCatalogProvider { +impl CatalogProvider for DynamicObjectStoreCatalogProvider { fn as_any(&self) -> &dyn Any { self } @@ -101,9 +101,9 @@ impl CatalogProvider for DynamicFileCatalogProvider { fn schema(&self, name: &str) -> Option> { let state = self.state.clone(); - self.inner - .schema(name) - .map(|schema| Arc::new(DynamicFileSchemaProvider::new(schema, state)) as _) + self.inner.schema(name).map(|schema| { + Arc::new(DynamicObjectStoreSchemaProvider::new(schema, state)) as _ + }) } fn register_schema( @@ -115,13 +115,15 @@ impl CatalogProvider for DynamicFileCatalogProvider { } } -/// Wraps another schema provider -struct DynamicFileSchemaProvider { +/// Wraps another schema provider. [DynamicObjectStoreSchemaProvider] is responsible for registering the required +/// object stores for the file locations. +#[derive(Debug)] +struct DynamicObjectStoreSchemaProvider { inner: Arc, state: Weak>, } -impl DynamicFileSchemaProvider { +impl DynamicObjectStoreSchemaProvider { pub fn new( inner: Arc, state: Weak>, @@ -131,7 +133,7 @@ impl DynamicFileSchemaProvider { } #[async_trait] -impl SchemaProvider for DynamicFileSchemaProvider { +impl SchemaProvider for DynamicObjectStoreSchemaProvider { fn as_any(&self) -> &dyn Any { self } @@ -149,9 +151,11 @@ impl SchemaProvider for DynamicFileSchemaProvider { } async fn table(&self, name: &str) -> Result>> { - let inner_table = self.inner.table(name).await?; - if inner_table.is_some() { - return Ok(inner_table); + let inner_table = self.inner.table(name).await; + if inner_table.is_ok() { + if let Some(inner_table) = inner_table? { + return Ok(Some(inner_table)); + } } // if the inner schema provider didn't have a table by @@ -162,6 +166,7 @@ impl SchemaProvider for DynamicFileSchemaProvider { .ok_or_else(|| plan_datafusion_err!("locking error"))? .read() .clone(); + let mut builder = SessionStateBuilder::from(state.clone()); let optimized_name = substitute_tilde(name.to_owned()); let table_url = ListingTableUrl::parse(optimized_name.as_str())?; let scheme = table_url.scheme(); @@ -178,13 +183,18 @@ impl SchemaProvider for DynamicFileSchemaProvider { // to any command options so the only choice is to use an empty collection match scheme { "s3" | "oss" | "cos" => { - state = state.add_table_options_extension(AwsOptions::default()); + if let Some(table_options) = builder.table_options() { + table_options.extensions.insert(AwsOptions::default()) + } } "gs" | "gcs" => { - state = state.add_table_options_extension(GcpOptions::default()) + if let Some(table_options) = builder.table_options() { + table_options.extensions.insert(GcpOptions::default()) + } } _ => {} }; + state = builder.build(); let store = get_object_store( &state, table_url.scheme(), @@ -195,16 +205,7 @@ impl SchemaProvider for DynamicFileSchemaProvider { state.runtime_env().register_object_store(url, store); } } - - let config = match ListingTableConfig::new(table_url).infer(&state).await { - Ok(cfg) => cfg, - Err(_) => { - // treat as non-existing - return Ok(None); - } - }; - - Ok(Some(Arc::new(ListingTable::try_new(config)?))) + self.inner.table(name).await } fn deregister_table(&self, name: &str) -> Result>> { @@ -215,7 +216,8 @@ impl SchemaProvider for DynamicFileSchemaProvider { self.inner.table_exist(name) } } -fn substitute_tilde(cur: String) -> String { + +pub fn substitute_tilde(cur: String) -> String { if let Some(usr_dir_path) = home_dir() { if let Some(usr_dir) = usr_dir_path.to_str() { if cur.starts_with('~') && !usr_dir.is_empty() { @@ -225,24 +227,25 @@ fn substitute_tilde(cur: String) -> String { } cur } - #[cfg(test)] mod tests { + use super::*; - use datafusion::catalog::schema::SchemaProvider; + use datafusion::catalog::SchemaProvider; use datafusion::prelude::SessionContext; fn setup_context() -> (SessionContext, Arc) { - let mut ctx = SessionContext::new(); - ctx.register_catalog_list(Arc::new(DynamicFileCatalog::new( - ctx.state().catalog_list(), + let ctx = SessionContext::new(); + ctx.register_catalog_list(Arc::new(DynamicObjectStoreCatalog::new( + ctx.state().catalog_list().clone(), ctx.state_weak_ref(), ))); - let provider = - &DynamicFileCatalog::new(ctx.state().catalog_list(), ctx.state_weak_ref()) - as &dyn CatalogProviderList; + let provider = &DynamicObjectStoreCatalog::new( + ctx.state().catalog_list().clone(), + ctx.state_weak_ref(), + ) as &dyn CatalogProviderList; let catalog = provider .catalog(provider.catalog_names().first().unwrap()) .unwrap(); @@ -262,7 +265,7 @@ mod tests { let (ctx, schema) = setup_context(); // That's a non registered table so expecting None here - let table = schema.table(&location).await.unwrap(); + let table = schema.table(&location).await?; assert!(table.is_none()); // It should still create an object store for the location in the SessionState @@ -286,7 +289,7 @@ mod tests { let (ctx, schema) = setup_context(); - let table = schema.table(&location).await.unwrap(); + let table = schema.table(&location).await?; assert!(table.is_none()); let store = ctx @@ -308,7 +311,7 @@ mod tests { let (ctx, schema) = setup_context(); - let table = schema.table(&location).await.unwrap(); + let table = schema.table(&location).await?; assert!(table.is_none()); let store = ctx @@ -330,6 +333,7 @@ mod tests { assert!(schema.table(location).await.is_err()); } + #[cfg(not(target_os = "windows"))] #[test] fn test_substitute_tilde() { diff --git a/datafusion-cli/src/cli_context.rs b/datafusion-cli/src/cli_context.rs new file mode 100644 index 000000000000..516929ebacf1 --- /dev/null +++ b/datafusion-cli/src/cli_context.rs @@ -0,0 +1,98 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use datafusion::{ + dataframe::DataFrame, + error::DataFusionError, + execution::{context::SessionState, TaskContext}, + logical_expr::LogicalPlan, + prelude::SessionContext, +}; +use object_store::ObjectStore; + +use crate::object_storage::{AwsOptions, GcpOptions}; + +#[async_trait::async_trait] +/// The CLI session context trait provides a way to have a session context that can be used with datafusion's CLI code. +pub trait CliSessionContext { + /// Get an atomic reference counted task context. + fn task_ctx(&self) -> Arc; + + /// Get the session state. + fn session_state(&self) -> SessionState; + + /// Register an object store with the session context. + fn register_object_store( + &self, + url: &url::Url, + object_store: Arc, + ) -> Option>; + + /// Register table options extension from scheme. + fn register_table_options_extension_from_scheme(&self, scheme: &str); + + /// Execute a logical plan and return a DataFrame. + async fn execute_logical_plan( + &self, + plan: LogicalPlan, + ) -> Result; +} + +#[async_trait::async_trait] +impl CliSessionContext for SessionContext { + fn task_ctx(&self) -> Arc { + self.task_ctx() + } + + fn session_state(&self) -> SessionState { + self.state() + } + + fn register_object_store( + &self, + url: &url::Url, + object_store: Arc, + ) -> Option> { + self.register_object_store(url, object_store) + } + + fn register_table_options_extension_from_scheme(&self, scheme: &str) { + match scheme { + // For Amazon S3 or Alibaba Cloud OSS + "s3" | "oss" | "cos" => { + // Register AWS specific table options in the session context: + self.register_table_options_extension(AwsOptions::default()) + } + // For Google Cloud Storage + "gs" | "gcs" => { + // Register GCP specific table options in the session context: + self.register_table_options_extension(GcpOptions::default()) + } + // For unsupported schemes, do nothing: + _ => {} + } + } + + async fn execute_logical_plan( + &self, + plan: LogicalPlan, + ) -> Result { + self.execute_logical_plan(plan).await + } +} diff --git a/datafusion-cli/src/command.rs b/datafusion-cli/src/command.rs index d3d7b65f0a50..f0eb58a23391 100644 --- a/datafusion-cli/src/command.rs +++ b/datafusion-cli/src/command.rs @@ -17,18 +17,18 @@ //! Command within CLI -use crate::exec::exec_from_lines; +use crate::cli_context::CliSessionContext; +use crate::exec::{exec_and_print, exec_from_lines}; use crate::functions::{display_all_functions, Function}; use crate::print_format::PrintFormat; use crate::print_options::PrintOptions; -use clap::ArgEnum; +use clap::ValueEnum; use datafusion::arrow::array::{ArrayRef, StringArray}; use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::arrow::record_batch::RecordBatch; use datafusion::common::exec_err; use datafusion::common::instant::Instant; use datafusion::error::{DataFusionError, Result}; -use datafusion::prelude::SessionContext; use std::fs::File; use std::io::BufReader; use std::str::FromStr; @@ -55,21 +55,21 @@ pub enum OutputFormat { impl Command { pub async fn execute( &self, - ctx: &mut SessionContext, + ctx: &dyn CliSessionContext, print_options: &mut PrintOptions, ) -> Result<()> { - let now = Instant::now(); match self { - Self::Help => print_options.print_batches(&[all_commands_info()], now), + Self::Help => { + let now = Instant::now(); + let command_batch = all_commands_info(); + print_options.print_batches(command_batch.schema(), &[command_batch], now) + } Self::ListTables => { - let df = ctx.sql("SHOW TABLES").await?; - let batches = df.collect().await?; - print_options.print_batches(&batches, now) + exec_and_print(ctx, print_options, "SHOW TABLES".into()).await } Self::DescribeTableStmt(name) => { - let df = ctx.sql(&format!("SHOW COLUMNS FROM {}", name)).await?; - let batches = df.collect().await?; - print_options.print_batches(&batches, now) + exec_and_print(ctx, print_options, format!("SHOW COLUMNS FROM {}", name)) + .await } Self::Include(filename) => { if let Some(filename) = filename { diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index 5fbcea0c0683..18906536691e 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -22,26 +22,26 @@ use std::fs::File; use std::io::prelude::*; use std::io::BufReader; +use crate::cli_context::CliSessionContext; use crate::helper::split_from_semicolon; use crate::print_format::PrintFormat; use crate::{ command::{Command, OutputFormat}, helper::{unescape_input, CliHelper}, - object_storage::{get_object_store, register_options}, + object_storage::get_object_store, print_options::{MaxRows, PrintOptions}, }; use datafusion::common::instant::Instant; use datafusion::common::plan_datafusion_err; +use datafusion::config::ConfigFileType; use datafusion::datasource::listing::ListingTableUrl; use datafusion::error::{DataFusionError, Result}; use datafusion::logical_expr::{DdlStatement, LogicalPlan}; use datafusion::physical_plan::{collect, execute_stream, ExecutionPlanProperties}; -use datafusion::prelude::SessionContext; use datafusion::sql::parser::{DFParser, Statement}; use datafusion::sql::sqlparser::dialect::dialect_from_str; -use datafusion::common::FileType; use datafusion::sql::sqlparser; use rustyline::error::ReadlineError; use rustyline::Editor; @@ -49,7 +49,7 @@ use tokio::signal; /// run and execute SQL statements and commands, against a context with the given print options pub async fn exec_from_commands( - ctx: &mut SessionContext, + ctx: &dyn CliSessionContext, commands: Vec, print_options: &PrintOptions, ) -> Result<()> { @@ -62,7 +62,7 @@ pub async fn exec_from_commands( /// run and execute SQL statements and commands from a file, against a context with the given print options pub async fn exec_from_lines( - ctx: &mut SessionContext, + ctx: &dyn CliSessionContext, reader: &mut BufReader, print_options: &PrintOptions, ) -> Result<()> { @@ -70,6 +70,9 @@ pub async fn exec_from_lines( for line in reader.lines() { match line { + Ok(line) if line.starts_with("#!") => { + continue; + } Ok(line) if line.starts_with("--") => { continue; } @@ -81,7 +84,7 @@ pub async fn exec_from_lines( Ok(_) => {} Err(err) => eprintln!("{err}"), } - query = "".to_owned(); + query = "".to_string(); } else { query.push('\n'); } @@ -102,7 +105,7 @@ pub async fn exec_from_lines( } pub async fn exec_from_files( - ctx: &mut SessionContext, + ctx: &dyn CliSessionContext, files: Vec, print_options: &PrintOptions, ) -> Result<()> { @@ -121,7 +124,7 @@ pub async fn exec_from_files( /// run and execute SQL statements and commands against a context with the given print options pub async fn exec_from_repl( - ctx: &mut SessionContext, + ctx: &dyn CliSessionContext, print_options: &mut PrintOptions, ) -> rustyline::Result<()> { let mut rl = Editor::new()?; @@ -203,8 +206,8 @@ pub async fn exec_from_repl( rl.save_history(".history") } -async fn exec_and_print( - ctx: &mut SessionContext, +pub(super) async fn exec_and_print( + ctx: &dyn CliSessionContext, print_options: &PrintOptions, sql: String, ) -> Result<()> { @@ -235,8 +238,9 @@ async fn exec_and_print( let stream = execute_stream(physical_plan, task_ctx.clone())?; print_options.print_stream(stream, now).await?; } else { + let schema = physical_plan.schema(); let results = collect(physical_plan, task_ctx.clone()).await?; - adjusted.into_inner().print_batches(&results, now)?; + adjusted.into_inner().print_batches(schema, &results, now)?; } } @@ -289,33 +293,44 @@ impl AdjustedPrintOptions { } } +fn config_file_type_from_str(ext: &str) -> Option { + match ext.to_lowercase().as_str() { + "csv" => Some(ConfigFileType::CSV), + "json" => Some(ConfigFileType::JSON), + "parquet" => Some(ConfigFileType::PARQUET), + _ => None, + } +} + async fn create_plan( - ctx: &mut SessionContext, + ctx: &dyn CliSessionContext, statement: Statement, ) -> Result { - let mut plan = ctx.state().statement_to_plan(statement).await?; + let mut plan = ctx.session_state().statement_to_plan(statement).await?; // Note that cmd is a mutable reference so that create_external_table function can remove all // datafusion-cli specific options before passing through to datafusion. Otherwise, datafusion // will raise Configuration errors. if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { + // To support custom formats, treat error as None + let format = config_file_type_from_str(&cmd.file_type); register_object_store_and_config_extensions( ctx, &cmd.location, &cmd.options, - None, + format, ) .await?; } if let LogicalPlan::Copy(copy_to) = &mut plan { - let format: FileType = (©_to.format_options).into(); + let format = config_file_type_from_str(©_to.file_type.get_ext()); register_object_store_and_config_extensions( ctx, ©_to.output_url, ©_to.options, - Some(format), + format, ) .await?; } @@ -350,10 +365,10 @@ async fn create_plan( /// alteration fails, or if the object store cannot be retrieved and registered /// successfully. pub(crate) async fn register_object_store_and_config_extensions( - ctx: &SessionContext, + ctx: &dyn CliSessionContext, location: &String, options: &HashMap, - format: Option, + format: Option, ) -> Result<()> { // Parse the location URL to extract the scheme and other components let table_path = ListingTableUrl::parse(location)?; @@ -365,20 +380,21 @@ pub(crate) async fn register_object_store_and_config_extensions( let url = table_path.as_ref(); // Register the options based on the scheme extracted from the location - register_options(ctx, scheme); + ctx.register_table_options_extension_from_scheme(scheme); // Clone and modify the default table options based on the provided options - let mut table_options = ctx.state().default_table_options().clone(); + let mut table_options = ctx.session_state().default_table_options(); if let Some(format) = format { - table_options.set_file_format(format); + table_options.set_config_format(format); } table_options.alter_with_string_hash_map(options)?; // Retrieve the appropriate object store based on the scheme, URL, and modified table options - let store = get_object_store(&ctx.state(), scheme, url, &table_options).await?; + let store = + get_object_store(&ctx.session_state(), scheme, url, &table_options).await?; // Register the retrieved object store in the session context's runtime environment - ctx.runtime_env().register_object_store(url, store); + ctx.register_object_store(url, store); Ok(()) } @@ -387,9 +403,9 @@ pub(crate) async fn register_object_store_and_config_extensions( mod tests { use super::*; - use datafusion::common::config::FormatOptions; use datafusion::common::plan_err; + use datafusion::prelude::SessionContext; use url::Url; async fn create_external_table_test(location: &str, sql: &str) -> Result<()> { @@ -397,11 +413,12 @@ mod tests { let plan = ctx.state().create_logical_plan(sql).await?; if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { + let format = config_file_type_from_str(&cmd.file_type); register_object_store_and_config_extensions( &ctx, &cmd.location, &cmd.options, - None, + format, ) .await?; } else { @@ -422,12 +439,12 @@ mod tests { let plan = ctx.state().create_logical_plan(sql).await?; if let LogicalPlan::Copy(cmd) = &plan { - let format: FileType = (&cmd.format_options).into(); + let format = config_file_type_from_str(&cmd.file_type.get_ext()); register_object_store_and_config_extensions( &ctx, &cmd.output_url, &cmd.options, - Some(format), + format, ) .await?; } else { @@ -459,7 +476,7 @@ mod tests { "cos://bucket/path/file.parquet", "gcs://bucket/path/file.parquet", ]; - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let task_ctx = ctx.task_ctx(); let dialect = &task_ctx.session_config().options().sql_parser.dialect; let dialect = dialect_from_str(dialect).ok_or_else(|| { @@ -474,10 +491,10 @@ mod tests { let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?; for statement in statements { //Should not fail - let mut plan = create_plan(&mut ctx, statement).await?; + let mut plan = create_plan(&ctx, statement).await?; if let LogicalPlan::Copy(copy_to) = &mut plan { assert_eq!(copy_to.output_url, location); - assert!(matches!(copy_to.format_options, FormatOptions::PARQUET(_))); + assert_eq!(copy_to.file_type.get_ext(), "parquet".to_string()); ctx.runtime_env() .object_store_registry .get_store(&Url::parse(©_to.output_url).unwrap())?; @@ -600,4 +617,16 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn create_external_table_format_option() -> Result<()> { + let location = "path/to/file.cvs"; + + // Test with format options + let sql = + format!("CREATE EXTERNAL TABLE test STORED AS CSV LOCATION '{location}' OPTIONS('format.has_header' 'true')"); + create_external_table_test(location, &sql).await.unwrap(); + + Ok(()) + } } diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs index 806e2bb39cd4..c622463de033 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -22,16 +22,17 @@ use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; use async_trait::async_trait; +use datafusion::catalog::Session; use datafusion::common::{plan_err, Column}; use datafusion::datasource::function::TableFunctionImpl; use datafusion::datasource::TableProvider; use datafusion::error::Result; -use datafusion::execution::context::SessionState; use datafusion::logical_expr::Expr; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::ExecutionPlan; use datafusion::scalar::ScalarValue; use parquet::basic::ConvertedType; +use parquet::data_type::{ByteArray, FixedLenByteArray}; use parquet::file::reader::FileReader; use parquet::file::serialized_reader::SerializedFileReader; use parquet::file::statistics::Statistics; @@ -213,6 +214,7 @@ pub fn display_all_functions() -> Result<()> { } /// PARQUET_META table function +#[derive(Debug)] struct ParquetMetadataTable { schema: SchemaRef, batch: RecordBatch, @@ -234,7 +236,7 @@ impl TableProvider for ParquetMetadataTable { async fn scan( &self, - _state: &SessionState, + _state: &dyn Session, projection: Option<&Vec>, _filters: &[Expr], _limit: Option, @@ -250,49 +252,70 @@ impl TableProvider for ParquetMetadataTable { fn convert_parquet_statistics( value: &Statistics, converted_type: ConvertedType, -) -> (String, String) { +) -> (Option, Option) { match (value, converted_type) { - (Statistics::Boolean(val), _) => (val.min().to_string(), val.max().to_string()), - (Statistics::Int32(val), _) => (val.min().to_string(), val.max().to_string()), - (Statistics::Int64(val), _) => (val.min().to_string(), val.max().to_string()), - (Statistics::Int96(val), _) => (val.min().to_string(), val.max().to_string()), - (Statistics::Float(val), _) => (val.min().to_string(), val.max().to_string()), - (Statistics::Double(val), _) => (val.min().to_string(), val.max().to_string()), - (Statistics::ByteArray(val), ConvertedType::UTF8) => { - let min_bytes = val.min(); - let max_bytes = val.max(); - let min = min_bytes - .as_utf8() - .map(|v| v.to_string()) - .unwrap_or_else(|_| min_bytes.to_string()); - - let max = max_bytes - .as_utf8() - .map(|v| v.to_string()) - .unwrap_or_else(|_| max_bytes.to_string()); - (min, max) - } - (Statistics::ByteArray(val), _) => (val.min().to_string(), val.max().to_string()), - (Statistics::FixedLenByteArray(val), ConvertedType::UTF8) => { - let min_bytes = val.min(); - let max_bytes = val.max(); - let min = min_bytes - .as_utf8() - .map(|v| v.to_string()) - .unwrap_or_else(|_| min_bytes.to_string()); - - let max = max_bytes - .as_utf8() - .map(|v| v.to_string()) - .unwrap_or_else(|_| max_bytes.to_string()); - (min, max) - } - (Statistics::FixedLenByteArray(val), _) => { - (val.min().to_string(), val.max().to_string()) - } + (Statistics::Boolean(val), _) => ( + val.min_opt().map(|v| v.to_string()), + val.max_opt().map(|v| v.to_string()), + ), + (Statistics::Int32(val), _) => ( + val.min_opt().map(|v| v.to_string()), + val.max_opt().map(|v| v.to_string()), + ), + (Statistics::Int64(val), _) => ( + val.min_opt().map(|v| v.to_string()), + val.max_opt().map(|v| v.to_string()), + ), + (Statistics::Int96(val), _) => ( + val.min_opt().map(|v| v.to_string()), + val.max_opt().map(|v| v.to_string()), + ), + (Statistics::Float(val), _) => ( + val.min_opt().map(|v| v.to_string()), + val.max_opt().map(|v| v.to_string()), + ), + (Statistics::Double(val), _) => ( + val.min_opt().map(|v| v.to_string()), + val.max_opt().map(|v| v.to_string()), + ), + (Statistics::ByteArray(val), ConvertedType::UTF8) => ( + byte_array_to_string(val.min_opt()), + byte_array_to_string(val.max_opt()), + ), + (Statistics::ByteArray(val), _) => ( + val.min_opt().map(|v| v.to_string()), + val.max_opt().map(|v| v.to_string()), + ), + (Statistics::FixedLenByteArray(val), ConvertedType::UTF8) => ( + fixed_len_byte_array_to_string(val.min_opt()), + fixed_len_byte_array_to_string(val.max_opt()), + ), + (Statistics::FixedLenByteArray(val), _) => ( + val.min_opt().map(|v| v.to_string()), + val.max_opt().map(|v| v.to_string()), + ), } } +/// Convert to a string if it has utf8 encoding, otherwise print bytes directly +fn byte_array_to_string(val: Option<&ByteArray>) -> Option { + val.map(|v| { + v.as_utf8() + .map(|s| s.to_string()) + .unwrap_or_else(|_e| v.to_string()) + }) +} + +/// Convert to a string if it has utf8 encoding, otherwise print bytes directly +fn fixed_len_byte_array_to_string(val: Option<&FixedLenByteArray>) -> Option { + val.map(|v| { + v.as_utf8() + .map(|s| s.to_string()) + .unwrap_or_else(|_e| v.to_string()) + }) +} + +#[derive(Debug)] pub struct ParquetMetadataFunc {} impl TableFunctionImpl for ParquetMetadataFunc { @@ -376,17 +399,13 @@ impl TableFunctionImpl for ParquetMetadataFunc { let converted_type = column.column_descr().converted_type(); if let Some(s) = column.statistics() { - let (min_val, max_val) = if s.has_min_max_set() { - let (min_val, max_val) = - convert_parquet_statistics(s, converted_type); - (Some(min_val), Some(max_val)) - } else { - (None, None) - }; + let (min_val, max_val) = + convert_parquet_statistics(s, converted_type); stats_min_arr.push(min_val.clone()); stats_max_arr.push(max_val.clone()); - stats_null_count_arr.push(Some(s.null_count() as i64)); - stats_distinct_count_arr.push(s.distinct_count().map(|c| c as i64)); + stats_null_count_arr.push(s.null_count_opt().map(|c| c as i64)); + stats_distinct_count_arr + .push(s.distinct_count_opt().map(|c| c as i64)); stats_min_value_arr.push(min_val); stats_max_value_arr.push(max_val); } else { diff --git a/datafusion-cli/src/helper.rs b/datafusion-cli/src/helper.rs index 85f14c1736dc..86a51f40a0a4 100644 --- a/datafusion-cli/src/helper.rs +++ b/datafusion-cli/src/helper.rs @@ -20,25 +20,20 @@ use std::borrow::Cow; +use crate::highlighter::{NoSyntaxHighlighter, SyntaxHighlighter}; + use datafusion::common::sql_datafusion_err; use datafusion::error::DataFusionError; use datafusion::sql::parser::{DFParser, Statement}; use datafusion::sql::sqlparser::dialect::dialect_from_str; use datafusion::sql::sqlparser::parser::ParserError; -use rustyline::completion::Completer; -use rustyline::completion::FilenameCompleter; -use rustyline::completion::Pair; + +use rustyline::completion::{Completer, FilenameCompleter, Pair}; use rustyline::error::ReadlineError; use rustyline::highlight::Highlighter; use rustyline::hint::Hinter; -use rustyline::validate::ValidationContext; -use rustyline::validate::ValidationResult; -use rustyline::validate::Validator; -use rustyline::Context; -use rustyline::Helper; -use rustyline::Result; - -use crate::highlighter::{NoSyntaxHighlighter, SyntaxHighlighter}; +use rustyline::validate::{ValidationContext, ValidationResult, Validator}; +use rustyline::{Context, Helper, Result}; pub struct CliHelper { completer: FilenameCompleter, @@ -123,8 +118,8 @@ impl Highlighter for CliHelper { self.highlighter.highlight(line, pos) } - fn highlight_char(&self, line: &str, pos: usize) -> bool { - self.highlighter.highlight_char(line, pos) + fn highlight_char(&self, line: &str, pos: usize, forced: bool) -> bool { + self.highlighter.highlight_char(line, pos, forced) } } @@ -257,54 +252,71 @@ mod tests { fn unescape_readline_input() -> Result<()> { let validator = CliHelper::default(); - // shoule be valid + // should be valid let result = readline_direct( - Cursor::new(r"create external table test stored as csv location 'data.csv' delimiter ',';".as_bytes()), - &validator, - )?; + Cursor::new( + r"create external table test stored as csv location 'data.csv' options ('format.delimiter' ',');" + .as_bytes(), + ), + &validator, + )?; assert!(matches!(result, ValidationResult::Valid(None))); let result = readline_direct( - Cursor::new(r"create external table test stored as csv location 'data.csv' delimiter '\0';".as_bytes()), - &validator, - )?; + Cursor::new( + r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\0');" + .as_bytes()), + &validator, + )?; assert!(matches!(result, ValidationResult::Valid(None))); let result = readline_direct( - Cursor::new(r"create external table test stored as csv location 'data.csv' delimiter '\n';".as_bytes()), - &validator, - )?; + Cursor::new( + r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\n');" + .as_bytes()), + &validator, + )?; assert!(matches!(result, ValidationResult::Valid(None))); let result = readline_direct( - Cursor::new(r"create external table test stored as csv location 'data.csv' delimiter '\r';".as_bytes()), - &validator, - )?; + Cursor::new( + r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\r');" + .as_bytes()), + &validator, + )?; assert!(matches!(result, ValidationResult::Valid(None))); let result = readline_direct( - Cursor::new(r"create external table test stored as csv location 'data.csv' delimiter '\t';".as_bytes()), - &validator, - )?; + Cursor::new( + r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\t');" + .as_bytes()), + &validator, + )?; assert!(matches!(result, ValidationResult::Valid(None))); let result = readline_direct( - Cursor::new(r"create external table test stored as csv location 'data.csv' delimiter '\\';".as_bytes()), - &validator, - )?; + Cursor::new( + r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\\');" + .as_bytes()), + &validator, + )?; assert!(matches!(result, ValidationResult::Valid(None))); - // should be invalid let result = readline_direct( - Cursor::new(r"create external table test stored as csv location 'data.csv' delimiter ',,';".as_bytes()), - &validator, - )?; - assert!(matches!(result, ValidationResult::Invalid(Some(_)))); + Cursor::new( + r"create external table test stored as csv location 'data.csv' options ('format.delimiter' ',,');" + .as_bytes()), + &validator, + )?; + assert!(matches!(result, ValidationResult::Valid(None))); + // should be invalid let result = readline_direct( - Cursor::new(r"create external table test stored as csv location 'data.csv' delimiter '\u{07}';".as_bytes()), - &validator, - )?; + Cursor::new( + r"create external table test stored as csv location 'data.csv' options ('format.delimiter' '\u{07}');" + .as_bytes()), + &validator, + )?; assert!(matches!(result, ValidationResult::Invalid(Some(_)))); Ok(()) @@ -314,7 +326,7 @@ mod tests { fn sql_dialect() -> Result<()> { let mut validator = CliHelper::default(); - // shoule be invalid in generic dialect + // should be invalid in generic dialect let result = readline_direct(Cursor::new(r"select 1 # 2;".as_bytes()), &validator)?; assert!( diff --git a/datafusion-cli/src/highlighter.rs b/datafusion-cli/src/highlighter.rs index 0bb75510b524..530b87af8732 100644 --- a/datafusion-cli/src/highlighter.rs +++ b/datafusion-cli/src/highlighter.rs @@ -73,7 +73,7 @@ impl Highlighter for SyntaxHighlighter { } } - fn highlight_char(&self, line: &str, _: usize) -> bool { + fn highlight_char(&self, line: &str, _pos: usize, _forced: bool) -> bool { !line.is_empty() } } diff --git a/datafusion-cli/src/lib.rs b/datafusion-cli/src/lib.rs index 139a60b8cf16..fbfc9242a61d 100644 --- a/datafusion-cli/src/lib.rs +++ b/datafusion-cli/src/lib.rs @@ -19,11 +19,13 @@ pub const DATAFUSION_CLI_VERSION: &str = env!("CARGO_PKG_VERSION"); pub mod catalog; +pub mod cli_context; pub mod command; pub mod exec; pub mod functions; pub mod helper; pub mod highlighter; pub mod object_storage; +pub mod pool_type; pub mod print_format; pub mod print_options; diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index 6f71ccafb729..4c6c352ff339 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -19,7 +19,6 @@ use std::collections::HashMap; use std::env; use std::path::Path; use std::process::ExitCode; -use std::str::FromStr; use std::sync::{Arc, OnceLock}; use datafusion::error::{DataFusionError, Result}; @@ -27,10 +26,11 @@ use datafusion::execution::context::SessionConfig; use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool}; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::prelude::SessionContext; -use datafusion_cli::catalog::DynamicFileCatalog; +use datafusion_cli::catalog::DynamicObjectStoreCatalog; use datafusion_cli::functions::ParquetMetadataFunc; use datafusion_cli::{ exec, + pool_type::PoolType, print_format::PrintFormat, print_options::{MaxRows, PrintOptions}, DATAFUSION_CLI_VERSION, @@ -42,24 +42,6 @@ use mimalloc::MiMalloc; #[global_allocator] static GLOBAL: MiMalloc = MiMalloc; -#[derive(PartialEq, Debug)] -enum PoolType { - Greedy, - Fair, -} - -impl FromStr for PoolType { - type Err = String; - - fn from_str(s: &str) -> Result { - match s { - "Greedy" | "greedy" => Ok(PoolType::Greedy), - "Fair" | "fair" => Ok(PoolType::Fair), - _ => Err(format!("Invalid memory pool type '{}'", s)), - } - } -} - #[derive(Debug, Parser, PartialEq)] #[clap(author, version, about, long_about= None)] struct Args { @@ -67,7 +49,7 @@ struct Args { short = 'p', long, help = "Path to your data, default to current directory", - validator(is_valid_data_dir) + value_parser(parse_valid_data_dir) )] data_path: Option, @@ -75,16 +57,16 @@ struct Args { short = 'b', long, help = "The batch size of each query, or use DataFusion default", - validator(is_valid_batch_size) + value_parser(parse_batch_size) )] batch_size: Option, #[clap( short = 'c', long, - multiple_values = true, + num_args = 0.., help = "Execute the given command string(s), then exit. Commands are expected to be non empty.", - validator(is_valid_command) + value_parser(parse_command) )] command: Vec, @@ -92,30 +74,30 @@ struct Args { short = 'm', long, help = "The memory pool limitation (e.g. '10g'), default to None (no limit)", - validator(is_valid_memory_pool_size) + value_parser(extract_memory_pool_size) )] - memory_limit: Option, + memory_limit: Option, #[clap( short, long, - multiple_values = true, + num_args = 0.., help = "Execute commands from file(s), then exit", - validator(is_valid_file) + value_parser(parse_valid_file) )] file: Vec, #[clap( short = 'r', long, - multiple_values = true, + num_args = 0.., help = "Run the provided files on startup instead of ~/.datafusionrc", - validator(is_valid_file), + value_parser(parse_valid_file), conflicts_with = "file" )] rc: Option>, - #[clap(long, arg_enum, default_value_t = PrintFormat::Automatic)] + #[clap(long, value_enum, default_value_t = PrintFormat::Automatic)] format: PrintFormat, #[clap( @@ -127,13 +109,14 @@ struct Args { #[clap( long, - help = "Specify the memory pool type 'greedy' or 'fair', default to 'greedy'" + help = "Specify the memory pool type 'greedy' or 'fair'", + default_value_t = PoolType::Greedy )] - mem_pool_type: Option, + mem_pool_type: PoolType, #[clap( long, - help = "The max number of rows to display for 'Table' format\n[default: 40] [possible values: numbers(0/10/...), inf(no limit)]", + help = "The max number of rows to display for 'Table' format\n[possible values: numbers(0/10/...), inf(no limit)]", default_value = "40" )] maxrows: MaxRows, @@ -177,18 +160,12 @@ async fn main_inner() -> Result<()> { let rt_config = // set memory pool size if let Some(memory_limit) = args.memory_limit { - let memory_limit = extract_memory_pool_size(&memory_limit).unwrap(); // set memory pool type - if let Some(mem_pool_type) = args.mem_pool_type { - match mem_pool_type { - PoolType::Greedy => rt_config - .with_memory_pool(Arc::new(GreedyMemoryPool::new(memory_limit))), - PoolType::Fair => rt_config - .with_memory_pool(Arc::new(FairSpillPool::new(memory_limit))), - } - } else { - rt_config - .with_memory_pool(Arc::new(GreedyMemoryPool::new(memory_limit))) + match args.mem_pool_type { + PoolType::Fair => rt_config + .with_memory_pool(Arc::new(FairSpillPool::new(memory_limit))), + PoolType::Greedy => rt_config + .with_memory_pool(Arc::new(GreedyMemoryPool::new(memory_limit))) } } else { rt_config @@ -196,12 +173,14 @@ async fn main_inner() -> Result<()> { let runtime_env = create_runtime_env(rt_config.clone())?; - let mut ctx = - SessionContext::new_with_config_rt(session_config.clone(), Arc::new(runtime_env)); + // enable dynamic file query + let ctx = + SessionContext::new_with_config_rt(session_config.clone(), Arc::new(runtime_env)) + .enable_url_table(); ctx.refresh_catalogs().await?; - // install dynamic catalog provider that knows how to open files - ctx.register_catalog_list(Arc::new(DynamicFileCatalog::new( - ctx.state().catalog_list(), + // install dynamic catalog provider that can register required object stores + ctx.register_catalog_list(Arc::new(DynamicObjectStoreCatalog::new( + ctx.state().catalog_list().clone(), ctx.state_weak_ref(), ))); // register `parquet_metadata` table function to get metadata from parquet files @@ -233,62 +212,55 @@ async fn main_inner() -> Result<()> { if commands.is_empty() && files.is_empty() { if !rc.is_empty() { - exec::exec_from_files(&mut ctx, rc, &print_options).await?; + exec::exec_from_files(&ctx, rc, &print_options).await?; } // TODO maybe we can have thiserror for cli but for now let's keep it simple - return exec::exec_from_repl(&mut ctx, &mut print_options) + return exec::exec_from_repl(&ctx, &mut print_options) .await .map_err(|e| DataFusionError::External(Box::new(e))); } if !files.is_empty() { - exec::exec_from_files(&mut ctx, files, &print_options).await?; + exec::exec_from_files(&ctx, files, &print_options).await?; } if !commands.is_empty() { - exec::exec_from_commands(&mut ctx, commands, &print_options).await?; + exec::exec_from_commands(&ctx, commands, &print_options).await?; } Ok(()) } fn create_runtime_env(rn_config: RuntimeConfig) -> Result { - RuntimeEnv::new(rn_config) + RuntimeEnv::try_new(rn_config) } -fn is_valid_file(dir: &str) -> Result<(), String> { +fn parse_valid_file(dir: &str) -> Result { if Path::new(dir).is_file() { - Ok(()) + Ok(dir.to_string()) } else { Err(format!("Invalid file '{}'", dir)) } } -fn is_valid_data_dir(dir: &str) -> Result<(), String> { +fn parse_valid_data_dir(dir: &str) -> Result { if Path::new(dir).is_dir() { - Ok(()) + Ok(dir.to_string()) } else { Err(format!("Invalid data directory '{}'", dir)) } } -fn is_valid_batch_size(size: &str) -> Result<(), String> { +fn parse_batch_size(size: &str) -> Result { match size.parse::() { - Ok(size) if size > 0 => Ok(()), + Ok(size) if size > 0 => Ok(size), _ => Err(format!("Invalid batch size '{}'", size)), } } -fn is_valid_memory_pool_size(size: &str) -> Result<(), String> { - match extract_memory_pool_size(size) { - Ok(_) => Ok(()), - Err(e) => Err(e), - } -} - -fn is_valid_command(command: &str) -> Result<(), String> { +fn parse_command(command: &str) -> Result { if !command.is_empty() { - Ok(()) + Ok(command.to_string()) } else { Err("-c flag expects only non empty commands".to_string()) } @@ -304,7 +276,7 @@ enum ByteUnit { } impl ByteUnit { - fn multiplier(&self) -> usize { + fn multiplier(&self) -> u64 { match self { ByteUnit::Byte => 1, ByteUnit::KiB => 1 << 10, @@ -349,8 +321,12 @@ fn extract_memory_pool_size(size: &str) -> Result { let unit = byte_suffixes() .get(suffix) .ok_or_else(|| format!("Invalid memory pool size '{}'", size))?; + let memory_pool_size = usize::try_from(unit.multiplier()) + .ok() + .and_then(|multiplier| num.checked_mul(multiplier)) + .ok_or_else(|| format!("Memory pool size '{}' is too large", size))?; - Ok(num * unit.multiplier()) + Ok(memory_pool_size) } else { Err(format!("Invalid memory pool size '{}'", size)) } diff --git a/datafusion-cli/src/object_storage.rs b/datafusion-cli/src/object_storage.rs index 85e0009bd267..3d999766e03f 100644 --- a/datafusion-cli/src/object_storage.rs +++ b/datafusion-cli/src/object_storage.rs @@ -25,9 +25,9 @@ use datafusion::common::config::{ use datafusion::common::{config_err, exec_datafusion_err, exec_err}; use datafusion::error::{DataFusionError, Result}; use datafusion::execution::context::SessionState; -use datafusion::prelude::SessionContext; use async_trait::async_trait; +use aws_config::BehaviorVersion; use aws_credential_types::provider::ProvideCredentials; use object_store::aws::{AmazonS3Builder, AwsCredential}; use object_store::gcp::GoogleCloudStorageBuilder; @@ -62,7 +62,7 @@ pub async fn get_s3_object_store_builder( builder = builder.with_token(session_token); } } else { - let config = aws_config::from_env().load().await; + let config = aws_config::defaults(BehaviorVersion::latest()).load().await; if let Some(region) = config.region() { builder = builder.with_region(region.to_string()); } @@ -392,48 +392,6 @@ impl ConfigExtension for GcpOptions { const PREFIX: &'static str = "gcp"; } -/// Registers storage options for different cloud storage schemes in a given -/// session context. -/// -/// This function is responsible for extending the session context with specific -/// options based on the storage scheme being used. These options are essential -/// for handling interactions with different cloud storage services such as Amazon -/// S3, Alibaba Cloud OSS, Google Cloud Storage, etc. -/// -/// # Parameters -/// -/// * `ctx` - A mutable reference to the session context where table options are -/// to be registered. The session context holds configuration and environment -/// for the current session. -/// * `scheme` - A string slice that represents the cloud storage scheme. This -/// determines which set of options will be registered in the session context. -/// -/// # Supported Schemes -/// -/// * `s3` or `oss` - Registers `AwsOptions` which are configurations specific to -/// Amazon S3 and Alibaba Cloud OSS. -/// * `gs` or `gcs` - Registers `GcpOptions` which are configurations specific to -/// Google Cloud Storage. -/// -/// NOTE: This function will not perform any action when given an unsupported scheme. -pub(crate) fn register_options(ctx: &SessionContext, scheme: &str) { - // Match the provided scheme against supported cloud storage schemes: - match scheme { - // For Amazon S3 or Alibaba Cloud OSS - "s3" | "oss" | "cos" => { - // Register AWS specific table options in the session context: - ctx.register_table_options_extension(AwsOptions::default()) - } - // For Google Cloud Storage - "gs" | "gcs" => { - // Register GCP specific table options in the session context: - ctx.register_table_options_extension(GcpOptions::default()) - } - // For unsupported schemes, do nothing: - _ => {} - } -} - pub(crate) async fn get_object_store( state: &SessionState, scheme: &str, @@ -498,6 +456,8 @@ pub(crate) async fn get_object_store( #[cfg(test)] mod tests { + use crate::cli_context::CliSessionContext; + use super::*; use datafusion::common::plan_err; @@ -534,8 +494,8 @@ mod tests { let mut plan = ctx.state().create_logical_plan(&sql).await?; if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { - register_options(&ctx, scheme); - let mut table_options = ctx.state().default_table_options().clone(); + ctx.register_table_options_extension_from_scheme(scheme); + let mut table_options = ctx.state().default_table_options(); table_options.alter_with_string_hash_map(&cmd.options)?; let aws_options = table_options.extensions.get::().unwrap(); let builder = @@ -579,8 +539,8 @@ mod tests { let mut plan = ctx.state().create_logical_plan(&sql).await?; if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { - register_options(&ctx, scheme); - let mut table_options = ctx.state().default_table_options().clone(); + ctx.register_table_options_extension_from_scheme(scheme); + let mut table_options = ctx.state().default_table_options(); table_options.alter_with_string_hash_map(&cmd.options)?; let aws_options = table_options.extensions.get::().unwrap(); let err = get_s3_object_store_builder(table_url.as_ref(), aws_options) @@ -605,8 +565,8 @@ mod tests { let mut plan = ctx.state().create_logical_plan(&sql).await?; if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { - register_options(&ctx, scheme); - let mut table_options = ctx.state().default_table_options().clone(); + ctx.register_table_options_extension_from_scheme(scheme); + let mut table_options = ctx.state().default_table_options(); table_options.alter_with_string_hash_map(&cmd.options)?; let aws_options = table_options.extensions.get::().unwrap(); // ensure this isn't an error @@ -633,8 +593,8 @@ mod tests { let mut plan = ctx.state().create_logical_plan(&sql).await?; if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { - register_options(&ctx, scheme); - let mut table_options = ctx.state().default_table_options().clone(); + ctx.register_table_options_extension_from_scheme(scheme); + let mut table_options = ctx.state().default_table_options(); table_options.alter_with_string_hash_map(&cmd.options)?; let aws_options = table_options.extensions.get::().unwrap(); let builder = get_oss_object_store_builder(table_url.as_ref(), aws_options)?; @@ -670,8 +630,8 @@ mod tests { let mut plan = ctx.state().create_logical_plan(&sql).await?; if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &mut plan { - register_options(&ctx, scheme); - let mut table_options = ctx.state().default_table_options().clone(); + ctx.register_table_options_extension_from_scheme(scheme); + let mut table_options = ctx.state().default_table_options(); table_options.alter_with_string_hash_map(&cmd.options)?; let gcp_options = table_options.extensions.get::().unwrap(); let builder = get_gcs_object_store_builder(table_url.as_ref(), gcp_options)?; diff --git a/datafusion-cli/src/pool_type.rs b/datafusion-cli/src/pool_type.rs new file mode 100644 index 000000000000..269790b61f5a --- /dev/null +++ b/datafusion-cli/src/pool_type.rs @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ + fmt::{self, Display, Formatter}, + str::FromStr, +}; + +#[derive(PartialEq, Debug, Clone)] +pub enum PoolType { + Greedy, + Fair, +} + +impl FromStr for PoolType { + type Err = String; + + fn from_str(s: &str) -> Result { + match s { + "Greedy" | "greedy" => Ok(PoolType::Greedy), + "Fair" | "fair" => Ok(PoolType::Fair), + _ => Err(format!("Invalid memory pool type '{}'", s)), + } + } +} + +impl Display for PoolType { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + match self { + PoolType::Greedy => write!(f, "greedy"), + PoolType::Fair => write!(f, "fair"), + } + } +} diff --git a/datafusion-cli/src/print_format.rs b/datafusion-cli/src/print_format.rs index 2de52be612bb..92cb106d622b 100644 --- a/datafusion-cli/src/print_format.rs +++ b/datafusion-cli/src/print_format.rs @@ -22,6 +22,7 @@ use std::str::FromStr; use crate::print_options::MaxRows; use arrow::csv::writer::WriterBuilder; +use arrow::datatypes::SchemaRef; use arrow::json::{ArrayWriter, LineDelimitedWriter}; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches_with_options; @@ -29,7 +30,7 @@ use datafusion::common::format::DEFAULT_FORMAT_OPTIONS; use datafusion::error::Result; /// Allow records to be printed in different formats -#[derive(Debug, PartialEq, Eq, clap::ArgEnum, Clone, Copy)] +#[derive(Debug, PartialEq, Eq, clap::ValueEnum, Clone, Copy)] pub enum PrintFormat { Csv, Tsv, @@ -43,7 +44,7 @@ impl FromStr for PrintFormat { type Err = String; fn from_str(s: &str) -> Result { - clap::ArgEnum::from_str(s, true) + clap::ValueEnum::from_str(s, true) } } @@ -157,6 +158,7 @@ impl PrintFormat { pub fn print_batches( &self, writer: &mut W, + schema: SchemaRef, batches: &[RecordBatch], maxrows: MaxRows, with_header: bool, @@ -168,7 +170,7 @@ impl PrintFormat { .cloned() .collect(); if batches.is_empty() { - return Ok(()); + return self.print_empty(writer, schema); } match self { @@ -186,6 +188,27 @@ impl PrintFormat { Self::NdJson => batches_to_json!(LineDelimitedWriter, writer, &batches), } } + + /// Print when the result batches contain no rows + fn print_empty( + &self, + writer: &mut W, + schema: SchemaRef, + ) -> Result<()> { + match self { + // Print column headers for Table format + Self::Table if !schema.fields().is_empty() => { + let empty_batch = RecordBatch::new_empty(schema); + let formatted = pretty_format_batches_with_options( + &[empty_batch], + &DEFAULT_FORMAT_OPTIONS, + )?; + writeln!(writer, "{}", formatted)?; + } + _ => {} + } + Ok(()) + } } #[cfg(test)] @@ -193,7 +216,7 @@ mod tests { use super::*; use std::sync::Arc; - use arrow::array::{ArrayRef, Int32Array}; + use arrow::array::Int32Array; use arrow::datatypes::{DataType, Field, Schema}; #[test] @@ -201,7 +224,6 @@ mod tests { for format in [ PrintFormat::Csv, PrintFormat::Tsv, - PrintFormat::Table, PrintFormat::Json, PrintFormat::NdJson, PrintFormat::Automatic, @@ -209,10 +231,26 @@ mod tests { // no output for empty batches, even with header set PrintBatchesTest::new() .with_format(format) + .with_schema(three_column_schema()) .with_batches(vec![]) .with_expected(&[""]) .run(); } + + // output column headers for empty batches when format is Table + #[rustfmt::skip] + let expected = &[ + "+---+---+---+", + "| a | b | c |", + "+---+---+---+", + "+---+---+---+", + ]; + PrintBatchesTest::new() + .with_format(PrintFormat::Table) + .with_schema(three_column_schema()) + .with_batches(vec![]) + .with_expected(expected) + .run(); } #[test] @@ -385,6 +423,7 @@ mod tests { for max_rows in [MaxRows::Unlimited, MaxRows::Limited(5), MaxRows::Limited(3)] { PrintBatchesTest::new() .with_format(PrintFormat::Table) + .with_schema(one_column_schema()) .with_batches(vec![one_column_batch()]) .with_maxrows(max_rows) .with_expected(expected) @@ -450,15 +489,15 @@ mod tests { let empty_batch = RecordBatch::new_empty(batch.schema()); #[rustfmt::skip] - let expected =&[ - "+---+", - "| a |", - "+---+", - "| 1 |", - "| 2 |", - "| 3 |", - "+---+", - ]; + let expected =&[ + "+---+", + "| a |", + "+---+", + "| 1 |", + "| 2 |", + "| 3 |", + "+---+", + ]; PrintBatchesTest::new() .with_format(PrintFormat::Table) @@ -468,14 +507,32 @@ mod tests { } #[test] - fn test_print_batches_empty_batches_no_header() { + fn test_print_batches_empty_batch() { let empty_batch = RecordBatch::new_empty(one_column_batch().schema()); - // empty batches should not print a header - let expected = &[""]; + // Print column headers for empty batch when format is Table + #[rustfmt::skip] + let expected =&[ + "+---+", + "| a |", + "+---+", + "+---+", + ]; + + PrintBatchesTest::new() + .with_format(PrintFormat::Table) + .with_schema(one_column_schema()) + .with_batches(vec![empty_batch]) + .with_header(WithHeader::Yes) + .with_expected(expected) + .run(); + // No output for empty batch when schema contains no columns + let empty_batch = RecordBatch::new_empty(Arc::new(Schema::empty())); + let expected = &[""]; PrintBatchesTest::new() .with_format(PrintFormat::Table) + .with_schema(Arc::new(Schema::empty())) .with_batches(vec![empty_batch]) .with_header(WithHeader::Yes) .with_expected(expected) @@ -485,6 +542,7 @@ mod tests { #[derive(Debug)] struct PrintBatchesTest { format: PrintFormat, + schema: SchemaRef, batches: Vec, maxrows: MaxRows, with_header: WithHeader, @@ -504,6 +562,7 @@ mod tests { fn new() -> Self { Self { format: PrintFormat::Table, + schema: Arc::new(Schema::empty()), batches: vec![], maxrows: MaxRows::Unlimited, with_header: WithHeader::Ignored, @@ -517,6 +576,12 @@ mod tests { self } + // set the schema + fn with_schema(mut self, schema: SchemaRef) -> Self { + self.schema = schema; + self + } + /// set the batches to convert fn with_batches(mut self, batches: Vec) -> Self { self.batches = batches; @@ -573,21 +638,31 @@ mod tests { fn output_with_header(&self, with_header: bool) -> String { let mut buffer: Vec = vec![]; self.format - .print_batches(&mut buffer, &self.batches, self.maxrows, with_header) + .print_batches( + &mut buffer, + self.schema.clone(), + &self.batches, + self.maxrows, + with_header, + ) .unwrap(); String::from_utf8(buffer).unwrap() } } - /// Return a batch with three columns and three rows - fn three_column_batch() -> RecordBatch { - let schema = Arc::new(Schema::new(vec![ + /// Return a schema with three columns + fn three_column_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), Field::new("c", DataType::Int32, false), - ])); + ])) + } + + /// Return a batch with three columns and three rows + fn three_column_batch() -> RecordBatch { RecordBatch::try_new( - schema, + three_column_schema(), vec![ Arc::new(Int32Array::from(vec![1, 2, 3])), Arc::new(Int32Array::from(vec![4, 5, 6])), @@ -597,12 +672,17 @@ mod tests { .unwrap() } + /// Return a schema with one column + fn one_column_schema() -> SchemaRef { + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])) + } + /// return a batch with one column and three rows fn one_column_batch() -> RecordBatch { - RecordBatch::try_from_iter(vec![( - "a", - Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, - )]) + RecordBatch::try_new( + one_column_schema(), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + ) .unwrap() } diff --git a/datafusion-cli/src/print_options.rs b/datafusion-cli/src/print_options.rs index bede5dd15eb6..e80cc55663ae 100644 --- a/datafusion-cli/src/print_options.rs +++ b/datafusion-cli/src/print_options.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use datafusion::common::instant::Instant; use std::fmt::{Display, Formatter}; use std::io::Write; use std::pin::Pin; @@ -23,7 +22,9 @@ use std::str::FromStr; use crate::print_format::PrintFormat; +use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; +use datafusion::common::instant::Instant; use datafusion::common::DataFusionError; use datafusion::error::Result; use datafusion::physical_plan::RecordBatchStream; @@ -98,6 +99,7 @@ impl PrintOptions { /// Print the batches to stdout using the specified format pub fn print_batches( &self, + schema: SchemaRef, batches: &[RecordBatch], query_start_time: Instant, ) -> Result<()> { @@ -105,7 +107,7 @@ impl PrintOptions { let mut writer = stdout.lock(); self.format - .print_batches(&mut writer, batches, self.maxrows, true)?; + .print_batches(&mut writer, schema, batches, self.maxrows, true)?; let row_count: usize = batches.iter().map(|b| b.num_rows()).sum(); let formatted_exec_details = get_execution_details_formatted( @@ -148,6 +150,7 @@ impl PrintOptions { row_count += batch.num_rows(); self.format.print_batches( &mut writer, + batch.schema(), &[batch], MaxRows::Unlimited, with_header, diff --git a/datafusion-cli/tests/cli_integration.rs b/datafusion-cli/tests/cli_integration.rs index 119a0aa39d3c..27cabf15afec 100644 --- a/datafusion-cli/tests/cli_integration.rs +++ b/datafusion-cli/tests/cli_integration.rs @@ -28,6 +28,8 @@ fn init() { let _ = env_logger::try_init(); } +// Disabled due to https://github.com/apache/datafusion/issues/10793 +#[cfg(not(target_family = "windows"))] #[rstest] #[case::exec_from_commands( ["--command", "select 1", "--format", "json", "-q"], diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index 0074a2b8d40c..e2432abdc138 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -62,8 +62,10 @@ dashmap = { workspace = true } datafusion = { workspace = true, default-features = true, features = ["avro"] } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } +datafusion-functions-window-common = { workspace = true } datafusion-optimizer = { workspace = true, default-features = true } datafusion-physical-expr = { workspace = true, default-features = true } +datafusion-proto = { workspace = true } datafusion-sql = { workspace = true } env_logger = { workspace = true } futures = { workspace = true } @@ -71,12 +73,16 @@ log = { workspace = true } mimalloc = { version = "0.1", default-features = false } num_cpus = { workspace = true } object_store = { workspace = true, features = ["aws", "http"] } -prost = { version = "0.12", default-features = false } -prost-derive = { version = "0.12", default-features = false } +prost = { workspace = true } +prost-derive = { workspace = true } serde = { version = "1.0.136", features = ["derive"] } serde_json = { workspace = true } tempfile = { workspace = true } +test-utils = { path = "../test-utils" } tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot"] } -tonic = "0.11" +tonic = "0.12.1" url = { workspace = true } uuid = "1.7" + +[target.'cfg(not(target_os = "windows"))'.dev-dependencies] +nix = { version = "0.28.0", features = ["fs"] } diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index 4b0e64ebdb7e..5f032c3e9cff 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -19,7 +19,8 @@ # DataFusion Examples -This crate includes several examples of how to use various DataFusion APIs and help you on your way. +This crate includes end to end, highly commented examples of how to use +various DataFusion APIs to help you get started. ## Prerequisites: @@ -27,7 +28,7 @@ Run `git submodule update --init` to init test files. ## Running Examples -To run the examples, use the `cargo run` command, such as: +To run an example, use the `cargo run` command, such as: ```bash git clone https://github.com/apache/datafusion @@ -35,9 +36,12 @@ cd datafusion # Download test data git submodule update --init -# Run the `csv_sql` example: +# Change to the examples directory +cd datafusion-examples/examples + +# Run the `dataframe` example: # ... use the equivalent for other examples -cargo run --example csv_sql +cargo run --example dataframe ``` ## Single Process @@ -45,31 +49,40 @@ cargo run --example csv_sql - [`advanced_udaf.rs`](examples/advanced_udaf.rs): Define and invoke a more complicated User Defined Aggregate Function (UDAF) - [`advanced_udf.rs`](examples/advanced_udf.rs): Define and invoke a more complicated User Defined Scalar Function (UDF) - [`advanced_udwf.rs`](examples/advanced_udwf.rs): Define and invoke a more complicated User Defined Window Function (UDWF) -- [`avro_sql.rs`](examples/avro_sql.rs): Build and run a query plan from a SQL statement against a local AVRO file +- [`advanced_parquet_index.rs`](examples/advanced_parquet_index.rs): Creates a detailed secondary index that covers the contents of several parquet files +- [`analyzer_rule.rs`](examples/analyzer_rule.rs): Use a custom AnalyzerRule to change a query's semantics (row level access control) - [`catalog.rs`](examples/catalog.rs): Register the table into a custom catalog -- [`csv_sql.rs`](examples/csv_sql.rs): Build and run a query plan from a SQL statement against a local CSV file +- [`composed_extension_codec`](examples/composed_extension_codec.rs): Example of using multiple extension codecs for serialization / deserialization - [`csv_sql_streaming.rs`](examples/csv_sql_streaming.rs): Build and run a streaming query plan from a SQL statement against a local CSV file - [`custom_datasource.rs`](examples/custom_datasource.rs): Run queries against a custom datasource (TableProvider) +- [`custom_file_format.rs`](examples/custom_file_format.rs): Write data to a custom file format - [`dataframe-to-s3.rs`](examples/external_dependency/dataframe-to-s3.rs): Run a query using a DataFrame against a parquet file from s3 and writing back to s3 - [`dataframe.rs`](examples/dataframe.rs): Run a query using a DataFrame against a local parquet file - [`dataframe_in_memory.rs`](examples/dataframe_in_memory.rs): Run a query using a DataFrame against data in memory - [`dataframe_output.rs`](examples/dataframe_output.rs): Examples of methods which write data out from a DataFrame - [`deserialize_to_struct.rs`](examples/deserialize_to_struct.rs): Convert query results into rust structs using serde - [`expr_api.rs`](examples/expr_api.rs): Create, execute, simplify and analyze `Expr`s +- [`file_stream_provider.rs`](examples/file_stream_provider.rs): Run a query on `FileStreamProvider` which implements `StreamProvider` for reading and writing to arbitrary stream sources / sinks. - [`flight_sql_server.rs`](examples/flight/flight_sql_server.rs): Run DataFusion as a standalone process and execute SQL queries from JDBC clients - [`function_factory.rs`](examples/function_factory.rs): Register `CREATE FUNCTION` handler to implement SQL macros - [`make_date.rs`](examples/make_date.rs): Examples of using the make_date function - [`memtable.rs`](examples/memtable.rs): Create an query data in memory using SQL and `RecordBatch`es -- [`parquet_sql.rs`](examples/parquet_sql.rs): Build and run a query plan from a SQL statement against a local Parquet file +- [`optimizer_rule.rs`](examples/optimizer_rule.rs): Use a custom OptimizerRule to replace certain predicates +- [`parquet_index.rs`](examples/parquet_index.rs): Create an secondary index over several parquet files and use it to speed up queries - [`parquet_sql_multiple_files.rs`](examples/parquet_sql_multiple_files.rs): Build and run a query plan from a SQL statement against multiple local Parquet files -- [`pruning.rs`](examples/parquet_sql.rs): Use pruning to rule out files based on statistics +- [`parquet_exec_visitor.rs`](examples/parquet_exec_visitor.rs): Extract statistics by visiting an ExecutionPlan after execution +- [`parse_sql_expr.rs`](examples/parse_sql_expr.rs): Parse SQL text into DataFusion `Expr`. +- [`plan_to_sql.rs`](examples/plan_to_sql.rs): Generate SQL from DataFusion `Expr` and `LogicalPlan` +- [`planner_api.rs](examples/planner_api.rs): APIs to manipulate logical and physical plans +- [`pruning.rs`](examples/pruning.rs): Use pruning to rule out files based on statistics - [`query-aws-s3.rs`](examples/external_dependency/query-aws-s3.rs): Configure `object_store` and run a query against files stored in AWS S3 - [`query-http-csv.rs`](examples/query-http-csv.rs): Configure `object_store` and run a query against files vi HTTP - [`regexp.rs`](examples/regexp.rs): Examples of using regular expression functions -- [`rewrite_expr.rs`](examples/rewrite_expr.rs): Define and invoke a custom Query Optimizer pass - [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF) - [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined Scalar Function (UDF) - [`simple_udfw.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF) +- [`sql_analysis.rs`](examples/sql_analysis.rs): Analyse SQL queries with DataFusion structures +- [`sql_frontend.rs`](examples/sql_frontend.rs): Create LogicalPlans (only) from sql strings - [`sql_dialect.rs`](examples/sql_dialect.rs): Example of implementing a custom SQL dialect on top of `DFParser` - [`to_char.rs`](examples/to_char.rs): Examples of using the to_char function - [`to_timestamp.rs`](examples/to_timestamp.rs): Examples of using to_timestamp functions diff --git a/datafusion-examples/examples/advanced_parquet_index.rs b/datafusion-examples/examples/advanced_parquet_index.rs new file mode 100644 index 000000000000..f6860bb5b87a --- /dev/null +++ b/datafusion-examples/examples/advanced_parquet_index.rs @@ -0,0 +1,664 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; +use arrow_schema::SchemaRef; +use async_trait::async_trait; +use bytes::Bytes; +use datafusion::catalog::Session; +use datafusion::datasource::listing::PartitionedFile; +use datafusion::datasource::physical_plan::parquet::{ + ParquetAccessPlan, ParquetExecBuilder, +}; +use datafusion::datasource::physical_plan::{ + parquet::ParquetFileReaderFactory, FileMeta, FileScanConfig, +}; +use datafusion::datasource::TableProvider; +use datafusion::execution::object_store::ObjectStoreUrl; +use datafusion::parquet::arrow::arrow_reader::{ + ArrowReaderOptions, ParquetRecordBatchReaderBuilder, RowSelection, RowSelector, +}; +use datafusion::parquet::arrow::async_reader::{AsyncFileReader, ParquetObjectReader}; +use datafusion::parquet::arrow::ArrowWriter; +use datafusion::parquet::file::metadata::ParquetMetaData; +use datafusion::parquet::file::properties::{EnabledStatistics, WriterProperties}; +use datafusion::parquet::schema::types::ColumnPath; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_optimizer::pruning::PruningPredicate; +use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::*; +use datafusion_common::{ + internal_datafusion_err, DFSchema, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr::utils::conjunction; +use datafusion_expr::{TableProviderFilterPushDown, TableType}; +use datafusion_physical_expr::utils::{Guarantee, LiteralGuarantee}; +use futures::future::BoxFuture; +use futures::FutureExt; +use object_store::ObjectStore; +use std::any::Any; +use std::collections::{HashMap, HashSet}; +use std::fs::File; +use std::ops::Range; +use std::path::{Path, PathBuf}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use tempfile::TempDir; +use url::Url; + +/// This example demonstrates using low level DataFusion APIs to read only +/// certain row groups and ranges from parquet files, based on external +/// information. +/// +/// Using these APIs, you can instruct DataFusion's parquet reader to skip +/// ("prune") portions of files that do not contain relevant data. These APIs +/// can be useful for doing low latency queries over a large number of Parquet +/// files on remote storage (e.g. S3) where the cost of reading the metadata for +/// each file is high (e.g. because it requires a network round trip to the +/// storage service). +/// +/// Depending on the information from the index, DataFusion can make a request +/// to the storage service (e.g. S3) to read only the necessary data. +/// +/// Note that this example uses a hard coded index implementation. For a more +/// realistic example of creating an index to prune files, see the +/// `parquet_index.rs` example. +/// +/// Specifically, this example illustrates how to: +/// 1. Use [`ParquetFileReaderFactory`] to avoid re-reading parquet metadata on each query +/// 2. Use [`PruningPredicate`] for predicate analysis +/// 3. Pass a row group selection to [`ParuetExec`] +/// 4. Pass a row selection (within a row group) to [`ParquetExec`] +/// +/// Note this is a *VERY* low level example for people who want to build their +/// own custom indexes (e.g. for low latency queries). Most users should use +/// higher level APIs for reading parquet files: +/// [`SessionContext::read_parquet`] or [`ListingTable`], which also do file +/// pruning based on parquet statistics (using the same underlying APIs) +/// +/// # Diagram +/// +/// This diagram shows how the `ParquetExec` is configured to do only a single +/// (range) read from a parquet file, for the data that is needed. It does +/// not read the file footer or any of the row groups that are not needed. +/// +/// ```text +/// ┌───────────────────────┐ The TableProvider configures the +/// │ ┌───────────────────┐ │ ParquetExec: +/// │ │ │ │ +/// │ └───────────────────┘ │ +/// │ ┌───────────────────┐ │ +/// Row │ │ │ │ 1. To read only specific Row +/// Groups │ └───────────────────┘ │ Groups (the ParquetExec tries +/// │ ┌───────────────────┐ │ to reduce this further based +/// │ │ │ │ on metadata) +/// │ └───────────────────┘ │ ┌────────────────────┐ +/// │ ┌───────────────────┐ │ │ │ +/// │ │ │◀┼ ─ ─ ┐ │ ParquetExec │ +/// │ └───────────────────┘ │ │ (Parquet Reader) │ +/// │ ... │ └ ─ ─ ─ ─│ │ +/// │ ┌───────────────────┐ │ │ ╔═══════════════╗ │ +/// │ │ │ │ │ ║ParquetMetadata║ │ +/// │ └───────────────────┘ │ │ ╚═══════════════╝ │ +/// │ ╔═══════════════════╗ │ └────────────────────┘ +/// │ ║ Thrift metadata ║ │ +/// │ ╚═══════════════════╝ │ 1. With cached ParquetMetadata, so +/// └───────────────────────┘ the ParquetExec does not re-read / +/// Parquet File decode the thrift footer +/// +/// ``` +/// +/// Within a Row Group, Column Chunks store data in DataPages. This example also +/// shows how to configure the ParquetExec to read a `RowSelection` (row ranges) +/// which will skip unneeded data pages. This requires that the Parquet file has +/// a [Page Index]. +/// +/// ```text +/// ┌───────────────────────┐ If the RowSelection does not include any +/// │ ... │ rows from a particular Data Page, that +/// │ │ Data Page is not fetched or decoded. +/// │ ┌───────────────────┐ │ Note this requires a PageIndex +/// │ │ ┌──────────┐ │ │ +/// Row │ │ │DataPage 0│ │ │ ┌────────────────────┐ +/// Groups │ │ └──────────┘ │ │ │ │ +/// │ │ ┌──────────┐ │ │ │ ParquetExec │ +/// │ │ ... │DataPage 1│ ◀┼ ┼ ─ ─ ─ │ (Parquet Reader) │ +/// │ │ └──────────┘ │ │ └ ─ ─ ─ ─ ─│ │ +/// │ │ ┌──────────┐ │ │ │ ╔═══════════════╗ │ +/// │ │ │DataPage 2│ │ │ If only rows │ ║ParquetMetadata║ │ +/// │ │ └──────────┘ │ │ from DataPage 1 │ ╚═══════════════╝ │ +/// │ └───────────────────┘ │ are selected, └────────────────────┘ +/// │ │ only DataPage 1 +/// │ ... │ is fetched and +/// │ │ decoded +/// │ ╔═══════════════════╗ │ +/// │ ║ Thrift metadata ║ │ +/// │ ╚═══════════════════╝ │ +/// └───────────────────────┘ +/// Parquet File +/// ``` +/// +/// [`ListingTable`]: datafusion::datasource::listing::ListingTable +/// [Page Index](https://github.com/apache/parquet-format/blob/master/PageIndex.md) +#[tokio::main] +async fn main() -> Result<()> { + // the object store is used to read the parquet files (in this case, it is + // a local file system, but in a real system it could be S3, GCS, etc) + let object_store: Arc = + Arc::new(object_store::local::LocalFileSystem::new()); + + // Create a custom table provider with our special index. + let provider = Arc::new(IndexTableProvider::try_new(Arc::clone(&object_store))?); + + // SessionContext for running queries that has the table provider + // registered as "index_table" + let ctx = SessionContext::new(); + ctx.register_table("index_table", Arc::clone(&provider) as _)?; + + // register object store provider for urls like `file://` work + let url = Url::try_from("file://").unwrap(); + ctx.register_object_store(&url, object_store); + + // Select data from the table without any predicates (and thus no pruning) + println!("** Select data, no predicates:"); + ctx.sql("SELECT avg(id), max(text) FROM index_table") + .await? + .show() + .await?; + // the underlying parquet reader makes 10 IO requests, one for each row group + + // Now, run a query that has a predicate that our index can handle + // + // For this query, the access plan specifies skipping 8 row groups + // and scanning 2 of them. The skipped row groups are not read at all: + // + // [Skip, Skip, Scan, Skip, Skip, Skip, Skip, Scan, Skip, Skip] + // + // Note that the parquet reader makes 2 IO requests - one for the data from + // each row group. + println!("** Select data, predicate `id IN (250, 750)`"); + ctx.sql("SELECT text FROM index_table WHERE id IN (250, 750)") + .await? + .show() + .await?; + + // Finally, demonstrate scanning sub ranges within the row groups. + // Parquet's minimum decode unit is a page, so specifying ranges + // within a row group can be used to skip pages within a row group. + // + // For this query, the access plan specifies skipping all but the last row + // group and within the last row group, reading only the row with id 950 + // + // [Skip, Skip, Skip, Skip, Skip, Skip, Skip, Skip, Skip, Selection(skip 49, select 1, skip 50)] + // + // Note that the parquet reader makes a single IO request - for the data + // pages that must be decoded + // + // Note: in order to prune pages, the Page Index must be loaded and the + // ParquetExec will load it on demand if not present. To avoid a second IO + // during query, this example loaded the Page Index pre-emptively by setting + // `ArrowReader::with_page_index` in `IndexedFile::try_new` + provider.set_use_row_selection(true); + println!("** Select data, predicate `id = 950`"); + ctx.sql("SELECT text FROM index_table WHERE id = 950") + .await? + .show() + .await?; + + Ok(()) +} + +/// DataFusion `TableProvider` that uses knowledge of how data is distributed in +/// a file to prune row groups and rows from the file. +/// +/// `file1.parquet` contains values `0..1000` +#[derive(Debug)] +pub struct IndexTableProvider { + /// Where the file is stored (cleanup on drop) + #[allow(dead_code)] + tmpdir: TempDir, + /// The file that is being read. + indexed_file: IndexedFile, + /// The underlying object store + object_store: Arc, + /// if true, use row selections in addition to row group selections + use_row_selections: AtomicBool, +} +impl IndexTableProvider { + /// Create a new IndexTableProvider + /// * `object_store` - the object store implementation to use for reading files + pub fn try_new(object_store: Arc) -> Result { + let tmpdir = TempDir::new().expect("Can't make temporary directory"); + + let indexed_file = + IndexedFile::try_new(tmpdir.path().join("indexed_file.parquet"), 0..1000)?; + + Ok(Self { + indexed_file, + tmpdir, + object_store, + use_row_selections: AtomicBool::new(false), + }) + } + + /// set the value of use row selections + pub fn set_use_row_selection(&self, use_row_selections: bool) { + self.use_row_selections + .store(use_row_selections, Ordering::SeqCst); + } + + /// return the value of use row selections + pub fn use_row_selections(&self) -> bool { + self.use_row_selections.load(Ordering::SeqCst) + } + + /// convert filters like `a = 1`, `b = 2` + /// to a single predicate like `a = 1 AND b = 2` suitable for execution + fn filters_to_predicate( + &self, + state: &dyn Session, + filters: &[Expr], + ) -> Result> { + let df_schema = DFSchema::try_from(self.schema())?; + + let predicate = conjunction(filters.to_vec()); + let predicate = predicate + .map(|predicate| state.create_physical_expr(predicate, &df_schema)) + .transpose()? + // if there are no filters, use a literal true to have a predicate + // that always evaluates to true we can pass to the index + .unwrap_or_else(|| datafusion_physical_expr::expressions::lit(true)); + + Ok(predicate) + } + + /// Returns a [`ParquetAccessPlan`] that specifies how to scan the + /// parquet file. + /// + /// A `ParquetAccessPlan` specifies which row groups and which rows within + /// those row groups to scan. + fn create_plan( + &self, + predicate: &Arc, + ) -> Result { + // In this example, we use the PruningPredicate's literal guarantees to + // analyze the predicate. In a real system, using + // `PruningPredicate::prune` would likely be easier to do. + let pruning_predicate = + PruningPredicate::try_new(Arc::clone(predicate), self.schema())?; + + // The PruningPredicate's guarantees must all be satisfied in order for + // the predicate to possibly evaluate to true. + let guarantees = pruning_predicate.literal_guarantees(); + let Some(constants) = self.value_constants(guarantees) else { + return Ok(self.indexed_file.scan_all_plan()); + }; + + // Begin with a plan that skips all row groups. + let mut plan = self.indexed_file.scan_none_plan(); + + // determine which row groups have the values in the guarantees + for value in constants { + let ScalarValue::Int32(Some(val)) = value else { + // if we have unexpected type of constant, no pruning is possible + return Ok(self.indexed_file.scan_all_plan()); + }; + + // Since we know the values in the files are between 0..1000 and + // evenly distributed between in row groups, calculate in what row + // group this value appears and tell the parquet reader to read it + let val = *val as usize; + let num_rows_in_row_group = 1000 / plan.len(); + let row_group_index = val / num_rows_in_row_group; + plan.scan(row_group_index); + + // If we want to use row selections, which the parquet reader can + // use to skip data pages when the parquet file has a "page index" + // and the reader is configured to read it, add a row selection + if self.use_row_selections() { + let offset_in_row_group = val - row_group_index * num_rows_in_row_group; + let selection = RowSelection::from(vec![ + // skip rows before the desired row + RowSelector::skip(offset_in_row_group.saturating_sub(1)), + // select the actual row + RowSelector::select(1), + // skip any remaining rows in the group + RowSelector::skip(num_rows_in_row_group - offset_in_row_group), + ]); + + plan.scan_selection(row_group_index, selection); + } + } + + Ok(plan) + } + + /// Returns the set of constants that the `"id"` column must take in order + /// for the predicate to be true. + /// + /// If `None` is returned, we can't extract the necessary information from + /// the guarantees. + fn value_constants<'a>( + &self, + guarantees: &'a [LiteralGuarantee], + ) -> Option<&'a HashSet> { + // only handle a single guarantee for column in this example + if guarantees.len() != 1 { + return None; + } + let guarantee = guarantees.first()?; + + // Only handle IN guarantees for the "in" column + if guarantee.guarantee != Guarantee::In || guarantee.column.name() != "id" { + return None; + } + Some(&guarantee.literals) + } +} + +/// Stores information needed to scan a file +#[derive(Debug)] +struct IndexedFile { + /// File name + file_name: String, + /// The path of the file + path: PathBuf, + /// The size of the file + file_size: u64, + /// The pre-parsed parquet metadata for the file + metadata: Arc, + /// The arrow schema of the file + schema: SchemaRef, +} + +impl IndexedFile { + fn try_new(path: impl AsRef, value_range: Range) -> Result { + let path = path.as_ref(); + // write the actual file + make_demo_file(path, value_range)?; + + // Now, open the file and read its size and metadata + let file_name = path + .file_name() + .ok_or_else(|| internal_datafusion_err!("Invalid path"))? + .to_str() + .ok_or_else(|| internal_datafusion_err!("Invalid filename"))? + .to_string(); + let file_size = path.metadata()?.len(); + + let file = File::open(path).map_err(|e| { + DataFusionError::from(e).context(format!("Error opening file {path:?}")) + })?; + + let options = ArrowReaderOptions::new() + // Load the page index when reading metadata to cache + // so it is available to interpret row selections + .with_page_index(true); + let reader = + ParquetRecordBatchReaderBuilder::try_new_with_options(file, options)?; + let metadata = reader.metadata().clone(); + let schema = reader.schema().clone(); + + // canonicalize after writing the file + let path = std::fs::canonicalize(path)?; + + Ok(Self { + file_name, + path, + file_size, + metadata, + schema, + }) + } + + /// Return a `PartitionedFile` to scan the underlying file + /// + /// The returned value does not have any `ParquetAccessPlan` specified in + /// its extensions. + fn partitioned_file(&self) -> PartitionedFile { + PartitionedFile::new(self.path.display().to_string(), self.file_size) + } + + /// Return a `ParquetAccessPlan` that scans all row groups in the file + fn scan_all_plan(&self) -> ParquetAccessPlan { + ParquetAccessPlan::new_all(self.metadata.num_row_groups()) + } + + /// Return a `ParquetAccessPlan` that scans no row groups in the file + fn scan_none_plan(&self) -> ParquetAccessPlan { + ParquetAccessPlan::new_none(self.metadata.num_row_groups()) + } +} + +/// Implement the TableProvider trait for IndexTableProvider +/// so that we can query it as a table. +#[async_trait] +impl TableProvider for IndexTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.indexed_file.schema) + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + state: &dyn Session, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, + ) -> Result> { + let indexed_file = &self.indexed_file; + let predicate = self.filters_to_predicate(state, filters)?; + + // Figure out which row groups to scan based on the predicate + let access_plan = self.create_plan(&predicate)?; + println!("{access_plan:?}"); + + let partitioned_file = indexed_file + .partitioned_file() + // provide the starting access plan to the ParquetExec by + // storing it as "extensions" on PartitionedFile + .with_extensions(Arc::new(access_plan) as _); + + // Prepare for scanning + let schema = self.schema(); + let object_store_url = ObjectStoreUrl::parse("file://")?; + let file_scan_config = FileScanConfig::new(object_store_url, schema) + .with_limit(limit) + .with_projection(projection.cloned()) + .with_file(partitioned_file); + + // Configure a factory interface to avoid re-reading the metadata for each file + let reader_factory = + CachedParquetFileReaderFactory::new(Arc::clone(&self.object_store)) + .with_file(indexed_file); + + // Finally, put it all together into a ParquetExec + Ok(ParquetExecBuilder::new(file_scan_config) + // provide the predicate so the ParquetExec can try and prune + // row groups internally + .with_predicate(predicate) + // provide the factory to create parquet reader without re-reading metadata + .with_parquet_file_reader_factory(Arc::new(reader_factory)) + .build_arc()) + } + + /// Tell DataFusion to push filters down to the scan method + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> Result> { + // Inexact because the pruning can't handle all expressions and pruning + // is not done at the row level -- there may be rows in returned files + // that do not pass the filter + Ok(vec![TableProviderFilterPushDown::Inexact; filters.len()]) + } +} + +/// A custom [`ParquetFileReaderFactory`] that handles opening parquet files +/// from object storage, and uses pre-loaded metadata. + +#[derive(Debug)] +struct CachedParquetFileReaderFactory { + /// The underlying object store implementation for reading file data + object_store: Arc, + /// The parquet metadata for each file in the index, keyed by the file name + /// (e.g. `file1.parquet`) + metadata: HashMap>, +} + +impl CachedParquetFileReaderFactory { + fn new(object_store: Arc) -> Self { + Self { + object_store, + metadata: HashMap::new(), + } + } + /// Add the pre-parsed information about the file to the factor + fn with_file(mut self, indexed_file: &IndexedFile) -> Self { + self.metadata.insert( + indexed_file.file_name.clone(), + Arc::clone(&indexed_file.metadata), + ); + self + } +} + +impl ParquetFileReaderFactory for CachedParquetFileReaderFactory { + fn create_reader( + &self, + _partition_index: usize, + file_meta: FileMeta, + metadata_size_hint: Option, + _metrics: &ExecutionPlanMetricsSet, + ) -> Result> { + // for this example we ignore the partition index and metrics + // but in a real system you would likely use them to report details on + // the performance of the reader. + let filename = file_meta + .location() + .parts() + .last() + .expect("No path in location") + .as_ref() + .to_string(); + + let object_store = Arc::clone(&self.object_store); + let mut inner = ParquetObjectReader::new(object_store, file_meta.object_meta); + + if let Some(hint) = metadata_size_hint { + inner = inner.with_footer_size_hint(hint) + }; + + let metadata = self + .metadata + .get(&filename) + .expect("metadata for file not found: {filename}"); + Ok(Box::new(ParquetReaderWithCache { + filename, + metadata: Arc::clone(metadata), + inner, + })) + } +} + +/// wrapper around a ParquetObjectReader that caches metadata +struct ParquetReaderWithCache { + filename: String, + metadata: Arc, + inner: ParquetObjectReader, +} + +impl AsyncFileReader for ParquetReaderWithCache { + fn get_bytes( + &mut self, + range: Range, + ) -> BoxFuture<'_, datafusion::parquet::errors::Result> { + println!("get_bytes: {} Reading range {:?}", self.filename, range); + self.inner.get_bytes(range) + } + + fn get_byte_ranges( + &mut self, + ranges: Vec>, + ) -> BoxFuture<'_, datafusion::parquet::errors::Result>> { + println!( + "get_byte_ranges: {} Reading ranges {:?}", + self.filename, ranges + ); + self.inner.get_byte_ranges(ranges) + } + + fn get_metadata( + &mut self, + ) -> BoxFuture<'_, datafusion::parquet::errors::Result>> { + println!("get_metadata: {} returning cached metadata", self.filename); + + // return the cached metadata so the parquet reader does not read it + let metadata = self.metadata.clone(); + async move { Ok(metadata) }.boxed() + } +} + +/// Creates a new parquet file at the specified path. +/// +/// * id: Int32 +/// * text: Utf8 +/// +/// The `id` column increases sequentially from `min_value` to `max_value` +/// The `text` column is a repeating sequence of `TheTextValue{i}` +/// +/// Each row group has 100 rows +fn make_demo_file(path: impl AsRef, value_range: Range) -> Result<()> { + let path = path.as_ref(); + let file = File::create(path)?; + + let id = Int32Array::from_iter_values(value_range.clone()); + let text = + StringArray::from_iter_values(value_range.map(|i| format!("TheTextValue{i}"))); + + let batch = RecordBatch::try_from_iter(vec![ + ("id", Arc::new(id) as ArrayRef), + ("text", Arc::new(text) as ArrayRef), + ])?; + + let schema = batch.schema(); + + // enable page statistics for the tag column, + // for everything else. + let props = WriterProperties::builder() + .set_max_row_group_size(100) + // compute column chunk (per row group) statistics by default + .set_statistics_enabled(EnabledStatistics::Chunk) + // compute column page statistics for the tag column + .set_column_statistics_enabled(ColumnPath::from("tag"), EnabledStatistics::Page) + .build(); + + // write the actual values to the file + let mut writer = ArrowWriter::try_new(file, schema, Some(props))?; + writer.write(&batch)?; + writer.close()?; + + Ok(()) +} diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index 342a23b6e73d..414596bdc678 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -31,8 +31,8 @@ use datafusion::error::Result; use datafusion::prelude::*; use datafusion_common::{cast::as_float64_array, ScalarValue}; use datafusion_expr::{ - function::AccumulatorArgs, Accumulator, AggregateUDF, AggregateUDFImpl, - GroupsAccumulator, Signature, + function::{AccumulatorArgs, StateFieldsArgs}, + Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature, }; /// This example shows how to use the full AggregateUDFImpl API to implement a user @@ -92,25 +92,23 @@ impl AggregateUDFImpl for GeoMeanUdaf { } /// This is the description of the state. accumulator's state() must match the types here. - fn state_fields( - &self, - _name: &str, - value_type: DataType, - _ordering_fields: Vec, - ) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ - Field::new("prod", value_type, true), + Field::new("prod", args.return_type.clone(), true), Field::new("n", DataType::UInt32, true), ]) } /// Tell DataFusion that this aggregate supports the more performant `GroupsAccumulator` /// which is used for cases when there are grouping columns in the query - fn groups_accumulator_supported(&self) -> bool { + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { true } - fn create_groups_accumulator(&self) -> Result> { + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { Ok(Box::new(GeometricMeanGroupsAccumulator::new())) } } @@ -195,7 +193,7 @@ impl Accumulator for GeometricMean { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } @@ -341,7 +339,7 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator { Ok(()) } - /// Generate output, as specififed by `emit_to` and update the intermediate state + /// Generate output, as specified by `emit_to` and update the intermediate state fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result { let counts = emit_to.take_needed(&mut self.counts); let prods = emit_to.take_needed(&mut self.prods); @@ -396,8 +394,8 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator { } fn size(&self) -> usize { - self.counts.capacity() * std::mem::size_of::() - + self.prods.capacity() * std::mem::size_of::() + self.counts.capacity() * size_of::() + + self.prods.capacity() * size_of::() } } diff --git a/datafusion-examples/examples/advanced_udf.rs b/datafusion-examples/examples/advanced_udf.rs index c8063c0eb1e3..9a3ee9c8ebcd 100644 --- a/datafusion-examples/examples/advanced_udf.rs +++ b/datafusion-examples/examples/advanced_udf.rs @@ -15,26 +15,21 @@ // specific language governing permissions and limitations // under the License. -use datafusion::{ - arrow::{ - array::{ArrayRef, Float32Array, Float64Array}, - datatypes::DataType, - record_batch::RecordBatch, - }, - logical_expr::Volatility, -}; use std::any::Any; +use std::sync::Arc; -use arrow::array::{new_null_array, Array, AsArray}; +use arrow::array::{ + new_null_array, Array, ArrayRef, AsArray, Float32Array, Float64Array, +}; use arrow::compute; -use arrow::datatypes::Float64Type; +use arrow::datatypes::{DataType, Float64Type}; +use arrow::record_batch::RecordBatch; use datafusion::error::Result; +use datafusion::logical_expr::Volatility; use datafusion::prelude::*; use datafusion_common::{internal_err, ScalarValue}; -use datafusion_expr::{ - ColumnarValue, FuncMonotonicity, ScalarUDF, ScalarUDFImpl, Signature, -}; -use std::sync::Arc; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_expr::{ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature}; /// This example shows how to use the full ScalarUDFImpl API to implement a user /// defined function. As in the `simple_udf.rs` example, this struct implements @@ -186,8 +181,9 @@ impl ScalarUDFImpl for PowUdf { &self.aliases } - fn monotonicity(&self) -> Result> { - Ok(Some(vec![Some(true)])) + fn output_ordering(&self, input: &[ExprProperties]) -> Result { + // The POW function preserves the order of its argument. + Ok(input[0].sort_properties) } } diff --git a/datafusion-examples/examples/advanced_udwf.rs b/datafusion-examples/examples/advanced_udwf.rs index 41c6381df5d4..1c20e292f091 100644 --- a/datafusion-examples/examples/advanced_udwf.rs +++ b/datafusion-examples/examples/advanced_udwf.rs @@ -22,12 +22,15 @@ use arrow::{ array::{ArrayRef, AsArray, Float64Array}, datatypes::Float64Type, }; +use arrow_schema::Field; use datafusion::error::Result; use datafusion::prelude::*; use datafusion_common::ScalarValue; +use datafusion_expr::function::WindowUDFFieldArgs; use datafusion_expr::{ PartitionEvaluator, Signature, WindowFrame, WindowUDF, WindowUDFImpl, }; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; /// This example shows how to use the full WindowUDFImpl API to implement a user /// defined window function. As in the `simple_udwf.rs` example, this struct implements @@ -70,16 +73,18 @@ impl WindowUDFImpl for SmoothItUdf { &self.signature } - /// What is the type of value that will be returned by this function. - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Float64) - } - - /// Create a `PartitionEvalutor` to evaluate this function on a new + /// Create a `PartitionEvaluator` to evaluate this function on a new /// partition. - fn partition_evaluator(&self) -> Result> { + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { Ok(Box::new(MyPartitionEvaluator::new())) } + + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::Float64, true)) + } } /// This implements the lowest level evaluation for a window function @@ -216,12 +221,12 @@ async fn main() -> Result<()> { df.show().await?; // Now, run the function using the DataFrame API: - let window_expr = smooth_it.call( - vec![col("speed")], // smooth_it(speed) - vec![col("car")], // PARTITION BY car - vec![col("time").sort(true, true)], // ORDER BY time ASC - WindowFrame::new(None), - ); + let window_expr = smooth_it + .call(vec![col("speed")]) // smooth_it(speed) + .partition_by(vec![col("car")]) // PARTITION BY car + .order_by(vec![col("time").sort(true, true)]) // ORDER BY time ASC + .window_frame(WindowFrame::new(None)) + .build()?; let df = ctx.table("cars").await?.window(vec![window_expr])?; // print the results diff --git a/datafusion-examples/examples/analyzer_rule.rs b/datafusion-examples/examples/analyzer_rule.rs new file mode 100644 index 000000000000..bd067be97b8b --- /dev/null +++ b/datafusion-examples/examples/analyzer_rule.rs @@ -0,0 +1,200 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; +use datafusion::prelude::SessionContext; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::Result; +use datafusion_expr::{col, lit, Expr, LogicalPlan, LogicalPlanBuilder}; +use datafusion_optimizer::analyzer::AnalyzerRule; +use std::sync::{Arc, Mutex}; + +/// This example demonstrates how to add your own [`AnalyzerRule`] to +/// DataFusion. +/// +/// [`AnalyzerRule`]s transform [`LogicalPlan`]s prior to the DataFusion +/// optimization process, and can be used to change the plan's semantics (e.g. +/// output types). +/// +/// This example shows an `AnalyzerRule` which implements a simplistic of row +/// level access control scheme by introducing a filter to the query. +/// +/// See [optimizer_rule.rs] for an example of a optimizer rule +#[tokio::main] +pub async fn main() -> Result<()> { + // AnalyzerRules run before OptimizerRules. + // + // DataFusion includes several built in AnalyzerRules for tasks such as type + // coercion which change the types of expressions in the plan. Add our new + // rule to the context to run it during the analysis phase. + let rule = Arc::new(RowLevelAccessControl::new()); + let ctx = SessionContext::new(); + ctx.add_analyzer_rule(Arc::clone(&rule) as _); + + ctx.register_batch("employee", employee_batch())?; + + // Now, planning any SQL statement also invokes the AnalyzerRule + let plan = ctx + .sql("SELECT * FROM employee") + .await? + .into_optimized_plan()?; + + // Printing the query plan shows a filter has been added + // + // Filter: employee.position = Utf8("Engineer") + // TableScan: employee projection=[name, age, position] + println!("Logical Plan:\n\n{}\n", plan.display_indent()); + + // Execute the query, and indeed no Manager's are returned + // + // +-----------+-----+----------+ + // | name | age | position | + // +-----------+-----+----------+ + // | Andy | 11 | Engineer | + // | Oleks | 33 | Engineer | + // | Xiangpeng | 55 | Engineer | + // +-----------+-----+----------+ + ctx.sql("SELECT * FROM employee").await?.show().await?; + + // We can now change the access level to "Manager" and see the results + // + // +----------+-----+----------+ + // | name | age | position | + // +----------+-----+----------+ + // | Andrew | 22 | Manager | + // | Chunchun | 44 | Manager | + // +----------+-----+----------+ + rule.set_show_position("Manager"); + ctx.sql("SELECT * FROM employee").await?.show().await?; + + // The filters introduced by our AnalyzerRule are treated the same as any + // other filter by the DataFusion optimizer, including predicate push down + // (including into scans), simplifications, and similar optimizations. + // + // For example adding another predicate to the query + let plan = ctx + .sql("SELECT * FROM employee WHERE age > 30") + .await? + .into_optimized_plan()?; + + // We can see the DataFusion Optimizer has combined the filters together + // when we print out the plan + // + // Filter: employee.age > Int32(30) AND employee.position = Utf8("Manager") + // TableScan: employee projection=[name, age, position] + println!("Logical Plan:\n\n{}\n", plan.display_indent()); + + Ok(()) +} + +/// Example AnalyzerRule that implements a very basic "row level access +/// control" +/// +/// In this case, it adds a filter to the plan that removes all managers from +/// the result set. +#[derive(Debug)] +struct RowLevelAccessControl { + /// Models the current access level of the session + /// + /// This is value of the position column which should be included in the + /// result set. It is wrapped in a `Mutex` so we can change it during query + show_position: Mutex, +} + +impl RowLevelAccessControl { + fn new() -> Self { + Self { + show_position: Mutex::new("Engineer".to_string()), + } + } + + /// return the current position to show, as an expression + fn show_position(&self) -> Expr { + lit(self.show_position.lock().unwrap().clone()) + } + + /// specifies a different position to show in the result set + fn set_show_position(&self, access_level: impl Into) { + *self.show_position.lock().unwrap() = access_level.into(); + } +} + +impl AnalyzerRule for RowLevelAccessControl { + fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result { + // use the TreeNode API to recursively walk the LogicalPlan tree + // and all of its children (inputs) + let transfomed_plan = plan.transform(|plan| { + // This closure is called for each LogicalPlan node + // if it is a Scan node, add a filter to remove all managers + if is_employee_table_scan(&plan) { + // Use the LogicalPlanBuilder to add a filter to the plan + let filter = LogicalPlanBuilder::from(plan) + // Filter Expression: position = + .filter(col("position").eq(self.show_position()))? + .build()?; + + // `Transformed::yes` signals the plan was changed + Ok(Transformed::yes(filter)) + } else { + // `Transformed::no` + // signals the plan was not changed + Ok(Transformed::no(plan)) + } + })?; + + // the result of calling transform is a `Transformed` structure which + // contains + // + // 1. a flag signaling if any rewrite took place + // 2. a flag if the recursion stopped early + // 3. The actual transformed data (a LogicalPlan in this case) + // + // This example does not need the value of either flag, so simply + // extract the LogicalPlan "data" + Ok(transfomed_plan.data) + } + + fn name(&self) -> &str { + "table_access" + } +} + +fn is_employee_table_scan(plan: &LogicalPlan) -> bool { + if let LogicalPlan::TableScan(scan) = plan { + scan.table_name.table() == "employee" + } else { + false + } +} + +/// Return a RecordBatch with made up data about fictional employees +fn employee_batch() -> RecordBatch { + let name: ArrayRef = Arc::new(StringArray::from_iter_values([ + "Andy", + "Andrew", + "Oleks", + "Chunchun", + "Xiangpeng", + ])); + let age: ArrayRef = Arc::new(Int32Array::from(vec![11, 22, 33, 44, 55])); + let position = Arc::new(StringArray::from_iter_values([ + "Engineer", "Manager", "Engineer", "Manager", "Engineer", + ])); + RecordBatch::try_from_iter(vec![("name", name), ("age", age), ("position", position)]) + .unwrap() +} diff --git a/datafusion-examples/examples/avro_sql.rs b/datafusion-examples/examples/avro_sql.rs deleted file mode 100644 index ac1053aa1881..000000000000 --- a/datafusion-examples/examples/avro_sql.rs +++ /dev/null @@ -1,51 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use datafusion::arrow::util::pretty; - -use datafusion::error::Result; -use datafusion::prelude::*; - -/// This example demonstrates executing a simple query against an Arrow data source (Avro) and -/// fetching results -#[tokio::main] -async fn main() -> Result<()> { - // create local execution context - let ctx = SessionContext::new(); - - let testdata = datafusion::test_util::arrow_test_data(); - - // register avro file with the execution context - let avro_file = &format!("{testdata}/avro/alltypes_plain.avro"); - ctx.register_avro("alltypes_plain", avro_file, AvroReadOptions::default()) - .await?; - - // execute the query - let df = ctx - .sql( - "SELECT int_col, double_col, CAST(date_string_col as VARCHAR) \ - FROM alltypes_plain \ - WHERE id > 1 AND tinyint_col < double_col", - ) - .await?; - let results = df.collect().await?; - - // print the results - pretty::print_batches(&results)?; - - Ok(()) -} diff --git a/datafusion-examples/examples/catalog.rs b/datafusion-examples/examples/catalog.rs index 5bc2cadac128..f40f1dfb5a15 100644 --- a/datafusion-examples/examples/catalog.rs +++ b/datafusion-examples/examples/catalog.rs @@ -19,10 +19,7 @@ use async_trait::async_trait; use datafusion::{ arrow::util::pretty, - catalog::{ - schema::SchemaProvider, - {CatalogProvider, CatalogProviderList}, - }, + catalog::{CatalogProvider, CatalogProviderList, SchemaProvider}, datasource::{ file_format::{csv::CsvFormat, FileFormat}, listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl}, @@ -47,13 +44,13 @@ async fn main() -> Result<()> { let dir_a = prepare_example_data()?; let dir_b = prepare_example_data()?; - let mut ctx = SessionContext::new(); + let ctx = SessionContext::new(); let state = ctx.state(); - let catlist = Arc::new(CustomCatalogProviderList::new()); + let cataloglist = Arc::new(CustomCatalogProviderList::new()); // use our custom catalog list for context. each context has a single catalog list. // context will by default have [`MemoryCatalogProviderList`] - ctx.register_catalog_list(catlist.clone()); + ctx.register_catalog_list(cataloglist.clone()); // initialize our catalog and schemas let catalog = DirCatalog::new(); @@ -83,8 +80,8 @@ async fn main() -> Result<()> { // register our catalog in the context ctx.register_catalog("dircat", Arc::new(catalog)); { - // catalog was passed down into our custom catalog list since we overide the ctx's default - let catalogs = catlist.catalogs.read().unwrap(); + // catalog was passed down into our custom catalog list since we override the ctx's default + let catalogs = cataloglist.catalogs.read().unwrap(); assert!(catalogs.contains_key("dircat")); }; @@ -138,6 +135,7 @@ struct DirSchemaOpts<'a> { format: Arc, } /// Schema where every file with extension `ext` in a given `dir` is a table. +#[derive(Debug)] struct DirSchema { ext: String, tables: RwLock>>, @@ -146,8 +144,8 @@ impl DirSchema { async fn create(state: &SessionState, opts: DirSchemaOpts<'_>) -> Result> { let DirSchemaOpts { ext, dir, format } = opts; let mut tables = HashMap::new(); - let listdir = std::fs::read_dir(dir).unwrap(); - for res in listdir { + let direntries = std::fs::read_dir(dir).unwrap(); + for res in direntries { let entry = res.unwrap(); let filename = entry.file_name().to_str().unwrap().to_string(); if !filename.ends_with(ext) { @@ -221,6 +219,7 @@ impl SchemaProvider for DirSchema { } } /// Catalog holds multiple schemas +#[derive(Debug)] struct DirCatalog { schemas: RwLock>>, } @@ -262,6 +261,7 @@ impl CatalogProvider for DirCatalog { } } /// Catalog lists holds multiple catalog providers. Each context has a single catalog list. +#[derive(Debug)] struct CustomCatalogProviderList { catalogs: RwLock>>, } diff --git a/datafusion-examples/examples/composed_extension_codec.rs b/datafusion-examples/examples/composed_extension_codec.rs new file mode 100644 index 000000000000..5c34eccf26e1 --- /dev/null +++ b/datafusion-examples/examples/composed_extension_codec.rs @@ -0,0 +1,291 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This example demonstrates how to compose multiple PhysicalExtensionCodecs +//! +//! This can be helpful when an Execution plan tree has different nodes from different crates +//! that need to be serialized. +//! +//! For example if your plan has `ShuffleWriterExec` from `datafusion-ballista` and `DeltaScan` from `deltalake` +//! both crates both provide PhysicalExtensionCodec and this example shows how to combine them together +//! +//! ```text +//! ShuffleWriterExec +//! ProjectionExec +//! ... +//! DeltaScan +//! ``` + +use std::any::Any; +use std::fmt::Debug; +use std::ops::Deref; +use std::sync::Arc; + +use datafusion::common::Result; +use datafusion::physical_plan::{DisplayAs, ExecutionPlan}; +use datafusion::prelude::SessionContext; +use datafusion_common::{internal_err, DataFusionError}; +use datafusion_expr::registry::FunctionRegistry; +use datafusion_expr::{AggregateUDF, ScalarUDF}; +use datafusion_proto::physical_plan::{AsExecutionPlan, PhysicalExtensionCodec}; +use datafusion_proto::protobuf; + +#[tokio::main] +async fn main() { + // build execution plan that has both types of nodes + // + // Note each node requires a different `PhysicalExtensionCodec` to decode + let exec_plan = Arc::new(ParentExec { + input: Arc::new(ChildExec {}), + }); + let ctx = SessionContext::new(); + + let composed_codec = ComposedPhysicalExtensionCodec { + codecs: vec![ + Arc::new(ParentPhysicalExtensionCodec {}), + Arc::new(ChildPhysicalExtensionCodec {}), + ], + }; + + // serialize execution plan to proto + let proto: protobuf::PhysicalPlanNode = + protobuf::PhysicalPlanNode::try_from_physical_plan( + exec_plan.clone(), + &composed_codec, + ) + .expect("to proto"); + + // deserialize proto back to execution plan + let runtime = ctx.runtime_env(); + let result_exec_plan: Arc = proto + .try_into_physical_plan(&ctx, runtime.deref(), &composed_codec) + .expect("from proto"); + + // assert that the original and deserialized execution plans are equal + assert_eq!(format!("{exec_plan:?}"), format!("{result_exec_plan:?}")); +} + +/// This example has two types of nodes: `ParentExec` and `ChildExec` which can only +/// be serialized with different `PhysicalExtensionCodec`s +#[derive(Debug)] +struct ParentExec { + input: Arc, +} + +impl DisplayAs for ParentExec { + fn fmt_as( + &self, + _t: datafusion::physical_plan::DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + write!(f, "ParentExec") + } +} + +impl ExecutionPlan for ParentExec { + fn name(&self) -> &str { + "ParentExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &datafusion::physical_plan::PlanProperties { + unreachable!() + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + unreachable!() + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + unreachable!() + } +} + +/// A PhysicalExtensionCodec that can serialize and deserialize ParentExec +#[derive(Debug)] +struct ParentPhysicalExtensionCodec; + +impl PhysicalExtensionCodec for ParentPhysicalExtensionCodec { + fn try_decode( + &self, + buf: &[u8], + inputs: &[Arc], + _registry: &dyn FunctionRegistry, + ) -> Result> { + if buf == "ParentExec".as_bytes() { + Ok(Arc::new(ParentExec { + input: inputs[0].clone(), + })) + } else { + internal_err!("Not supported") + } + } + + fn try_encode(&self, node: Arc, buf: &mut Vec) -> Result<()> { + if node.as_any().downcast_ref::().is_some() { + buf.extend_from_slice("ParentExec".as_bytes()); + Ok(()) + } else { + internal_err!("Not supported") + } + } +} + +#[derive(Debug)] +struct ChildExec {} + +impl DisplayAs for ChildExec { + fn fmt_as( + &self, + _t: datafusion::physical_plan::DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + write!(f, "ChildExec") + } +} + +impl ExecutionPlan for ChildExec { + fn name(&self) -> &str { + "ChildExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &datafusion::physical_plan::PlanProperties { + unreachable!() + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + unreachable!() + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + unreachable!() + } +} + +/// A PhysicalExtensionCodec that can serialize and deserialize ChildExec +#[derive(Debug)] +struct ChildPhysicalExtensionCodec; + +impl PhysicalExtensionCodec for ChildPhysicalExtensionCodec { + fn try_decode( + &self, + buf: &[u8], + _inputs: &[Arc], + _registry: &dyn FunctionRegistry, + ) -> Result> { + if buf == "ChildExec".as_bytes() { + Ok(Arc::new(ChildExec {})) + } else { + internal_err!("Not supported") + } + } + + fn try_encode(&self, node: Arc, buf: &mut Vec) -> Result<()> { + if node.as_any().downcast_ref::().is_some() { + buf.extend_from_slice("ChildExec".as_bytes()); + Ok(()) + } else { + internal_err!("Not supported") + } + } +} + +/// A PhysicalExtensionCodec that tries one of multiple inner codecs +/// until one works +#[derive(Debug)] +struct ComposedPhysicalExtensionCodec { + codecs: Vec>, +} + +impl ComposedPhysicalExtensionCodec { + fn try_any( + &self, + mut f: impl FnMut(&dyn PhysicalExtensionCodec) -> Result, + ) -> Result { + let mut last_err = None; + for codec in &self.codecs { + match f(codec.as_ref()) { + Ok(node) => return Ok(node), + Err(err) => last_err = Some(err), + } + } + + Err(last_err.unwrap_or_else(|| { + DataFusionError::NotImplemented("Empty list of composed codecs".to_owned()) + })) + } +} + +impl PhysicalExtensionCodec for ComposedPhysicalExtensionCodec { + fn try_decode( + &self, + buf: &[u8], + inputs: &[Arc], + registry: &dyn FunctionRegistry, + ) -> Result> { + self.try_any(|codec| codec.try_decode(buf, inputs, registry)) + } + + fn try_encode(&self, node: Arc, buf: &mut Vec) -> Result<()> { + self.try_any(|codec| codec.try_encode(node.clone(), buf)) + } + + fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { + self.try_any(|codec| codec.try_decode_udf(name, buf)) + } + + fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { + self.try_any(|codec| codec.try_encode_udf(node, buf)) + } + + fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result> { + self.try_any(|codec| codec.try_decode_udaf(name, buf)) + } + + fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec) -> Result<()> { + self.try_any(|codec| codec.try_encode_udaf(node, buf)) + } +} diff --git a/datafusion-examples/examples/csv_opener.rs b/datafusion-examples/examples/csv_opener.rs index 96753c8c5260..e7b7ead109bc 100644 --- a/datafusion-examples/examples/csv_opener.rs +++ b/datafusion-examples/examples/csv_opener.rs @@ -17,7 +17,6 @@ use std::{sync::Arc, vec}; -use datafusion::common::Statistics; use datafusion::{ assert_batches_eq, datasource::{ @@ -48,7 +47,9 @@ async fn main() -> Result<()> { true, b',', b'"', + None, object_store, + Some(b'#'), ); let opener = CsvOpener::new(Arc::new(config), FileCompressionType::UNCOMPRESSED); @@ -58,16 +59,11 @@ async fn main() -> Result<()> { let path = std::path::Path::new(&path).canonicalize()?; - let scan_config = FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_schema: schema.clone(), - file_groups: vec![vec![PartitionedFile::new(path.display().to_string(), 10)]], - statistics: Statistics::new_unknown(&schema), - projection: Some(vec![12, 0]), - limit: Some(5), - table_partition_cols: vec![], - output_ordering: vec![], - }; + let scan_config = + FileScanConfig::new(ObjectStoreUrl::local_filesystem(), schema.clone()) + .with_projection(Some(vec![12, 0])) + .with_limit(Some(5)) + .with_file(PartitionedFile::new(path.display().to_string(), 10)); let result = FileStream::new(&scan_config, 0, opener, &ExecutionPlanMetricsSet::new()) diff --git a/datafusion-examples/examples/csv_sql.rs b/datafusion-examples/examples/csv_sql.rs deleted file mode 100644 index 851fdcb626d2..000000000000 --- a/datafusion-examples/examples/csv_sql.rs +++ /dev/null @@ -1,70 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use datafusion::datasource::file_format::file_compression_type::FileCompressionType; -use datafusion::error::Result; -use datafusion::prelude::*; - -/// This example demonstrates executing a simple query against an Arrow data source (CSV) and -/// fetching results -#[tokio::main] -async fn main() -> Result<()> { - // create local execution context - let ctx = SessionContext::new(); - - let testdata = datafusion::test_util::arrow_test_data(); - - // register csv file with the execution context - ctx.register_csv( - "aggregate_test_100", - &format!("{testdata}/csv/aggregate_test_100.csv"), - CsvReadOptions::new(), - ) - .await?; - - // execute the query - let df = ctx - .sql( - "SELECT c1, MIN(c12), MAX(c12) \ - FROM aggregate_test_100 \ - WHERE c11 > 0.1 AND c11 < 0.9 \ - GROUP BY c1", - ) - .await?; - - // print the results - df.show().await?; - - // query compressed CSV with specific options - let csv_options = CsvReadOptions::default() - .has_header(true) - .file_compression_type(FileCompressionType::GZIP) - .file_extension("csv.gz"); - let df = ctx - .read_csv( - &format!("{testdata}/csv/aggregate_test_100.csv.gz"), - csv_options, - ) - .await?; - let df = df - .filter(col("c1").eq(lit("a")))? - .select_columns(&["c2", "c3"])?; - - df.show().await?; - - Ok(()) -} diff --git a/datafusion-examples/examples/custom_datasource.rs b/datafusion-examples/examples/custom_datasource.rs index c2ea6f2b52a1..7440e592962b 100644 --- a/datafusion-examples/examples/custom_datasource.rs +++ b/datafusion-examples/examples/custom_datasource.rs @@ -26,7 +26,7 @@ use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::arrow::record_batch::RecordBatch; use datafusion::datasource::{provider_as_source, TableProvider, TableType}; use datafusion::error::Result; -use datafusion::execution::context::{SessionState, TaskContext}; +use datafusion::execution::context::TaskContext; use datafusion::physical_plan::memory::MemoryStream; use datafusion::physical_plan::{ project_schema, DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan, @@ -37,6 +37,7 @@ use datafusion_expr::LogicalPlanBuilder; use datafusion_physical_expr::EquivalenceProperties; use async_trait::async_trait; +use datafusion::catalog::Session; use tokio::time::timeout; /// This example demonstrates executing a simple query against a custom datasource @@ -109,7 +110,7 @@ struct CustomDataSourceInner { } impl Debug for CustomDataSource { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.write_str("custom_db") } } @@ -175,7 +176,7 @@ impl TableProvider for CustomDataSource { async fn scan( &self, - _state: &SessionState, + _state: &dyn Session, projection: Option<&Vec>, // filters and limit can be used here to inject some push-down operations if needed _filters: &[Expr], @@ -219,7 +220,7 @@ impl CustomExec { } impl DisplayAs for CustomExec { - fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> std::fmt::Result { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> fmt::Result { write!(f, "CustomExec") } } @@ -237,7 +238,7 @@ impl ExecutionPlan for CustomExec { &self.cache } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { vec![] } diff --git a/datafusion-examples/examples/custom_file_format.rs b/datafusion-examples/examples/custom_file_format.rs new file mode 100644 index 000000000000..95168597ebaa --- /dev/null +++ b/datafusion-examples/examples/custom_file_format.rs @@ -0,0 +1,235 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{any::Any, sync::Arc}; + +use arrow::{ + array::{AsArray, RecordBatch, StringArray, UInt8Array}, + datatypes::UInt64Type, +}; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use datafusion::execution::session_state::SessionStateBuilder; +use datafusion::physical_expr::LexRequirement; +use datafusion::{ + datasource::{ + file_format::{ + csv::CsvFormatFactory, file_compression_type::FileCompressionType, + FileFormat, FileFormatFactory, + }, + physical_plan::{FileScanConfig, FileSinkConfig}, + MemTable, + }, + error::Result, + execution::context::SessionState, + physical_plan::ExecutionPlan, + prelude::SessionContext, +}; +use datafusion_common::{GetExt, Statistics}; +use datafusion_physical_expr::PhysicalExpr; +use object_store::{ObjectMeta, ObjectStore}; +use tempfile::tempdir; + +/// Example of a custom file format that reads and writes TSV files. +/// +/// TSVFileFormatFactory is responsible for creating instances of TSVFileFormat. +/// The former, once registered with the SessionState, will then be used +/// to facilitate SQL operations on TSV files, such as `COPY TO` shown here. + +#[derive(Debug)] +/// Custom file format that reads and writes TSV files +/// +/// This file format is a wrapper around the CSV file format +/// for demonstration purposes. +struct TSVFileFormat { + csv_file_format: Arc, +} + +impl TSVFileFormat { + pub fn new(csv_file_format: Arc) -> Self { + Self { csv_file_format } + } +} + +#[async_trait::async_trait] +impl FileFormat for TSVFileFormat { + fn as_any(&self) -> &dyn Any { + self + } + + fn get_ext(&self) -> String { + "tsv".to_string() + } + + fn get_ext_with_compression(&self, c: &FileCompressionType) -> Result { + if c == &FileCompressionType::UNCOMPRESSED { + Ok("tsv".to_string()) + } else { + todo!("Compression not supported") + } + } + + async fn infer_schema( + &self, + state: &SessionState, + store: &Arc, + objects: &[ObjectMeta], + ) -> Result { + self.csv_file_format + .infer_schema(state, store, objects) + .await + } + + async fn infer_stats( + &self, + state: &SessionState, + store: &Arc, + table_schema: SchemaRef, + object: &ObjectMeta, + ) -> Result { + self.csv_file_format + .infer_stats(state, store, table_schema, object) + .await + } + + async fn create_physical_plan( + &self, + state: &SessionState, + conf: FileScanConfig, + filters: Option<&Arc>, + ) -> Result> { + self.csv_file_format + .create_physical_plan(state, conf, filters) + .await + } + + async fn create_writer_physical_plan( + &self, + input: Arc, + state: &SessionState, + conf: FileSinkConfig, + order_requirements: Option, + ) -> Result> { + self.csv_file_format + .create_writer_physical_plan(input, state, conf, order_requirements) + .await + } +} + +#[derive(Default, Debug)] +/// Factory for creating TSV file formats +/// +/// This factory is a wrapper around the CSV file format factory +/// for demonstration purposes. +pub struct TSVFileFactory { + csv_file_factory: CsvFormatFactory, +} + +impl TSVFileFactory { + pub fn new() -> Self { + Self { + csv_file_factory: CsvFormatFactory::new(), + } + } +} + +impl FileFormatFactory for TSVFileFactory { + fn create( + &self, + state: &SessionState, + format_options: &std::collections::HashMap, + ) -> Result> { + let mut new_options = format_options.clone(); + new_options.insert("format.delimiter".to_string(), "\t".to_string()); + + let csv_file_format = self.csv_file_factory.create(state, &new_options)?; + let tsv_file_format = Arc::new(TSVFileFormat::new(csv_file_format)); + + Ok(tsv_file_format) + } + + fn default(&self) -> Arc { + todo!() + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +impl GetExt for TSVFileFactory { + fn get_ext(&self) -> String { + "tsv".to_string() + } +} + +#[tokio::main] +async fn main() -> Result<()> { + // Create a new context with the default configuration + let mut state = SessionStateBuilder::new().with_default_features().build(); + + // Register the custom file format + let file_format = Arc::new(TSVFileFactory::new()); + state.register_file_format(file_format, true).unwrap(); + + // Create a new context with the custom file format + let ctx = SessionContext::new_with_state(state); + + let mem_table = create_mem_table(); + ctx.register_table("mem_table", mem_table).unwrap(); + + let temp_dir = tempdir().unwrap(); + let table_save_path = temp_dir.path().join("mem_table.tsv"); + + let d = ctx + .sql(&format!( + "COPY mem_table TO '{}' STORED AS TSV;", + table_save_path.display(), + )) + .await?; + + let results = d.collect().await?; + println!( + "Number of inserted rows: {:?}", + (results[0] + .column_by_name("count") + .unwrap() + .as_primitive::() + .value(0)) + ); + + Ok(()) +} + +// create a simple mem table +fn create_mem_table() -> Arc { + let fields = vec![ + Field::new("id", DataType::UInt8, false), + Field::new("data", DataType::Utf8, false), + ]; + let schema = Arc::new(Schema::new(fields)); + + let partitions = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(UInt8Array::from(vec![1, 2])), + Arc::new(StringArray::from(vec!["foo", "bar"])), + ], + ) + .unwrap(); + + Arc::new(MemTable::try_new(schema, vec![vec![partitions]]).unwrap()) +} diff --git a/datafusion-examples/examples/dataframe.rs b/datafusion-examples/examples/dataframe.rs index ea01c53b1c62..d7e0068ef88f 100644 --- a/datafusion-examples/examples/dataframe.rs +++ b/datafusion-examples/examples/dataframe.rs @@ -64,6 +64,12 @@ async fn main() -> Result<()> { .await?; parquet_df.describe().await.unwrap().show().await?; + let dyn_ctx = ctx.enable_url_table(); + let df = dyn_ctx + .sql(&format!("SELECT * FROM '{}'", file_path.to_str().unwrap())) + .await?; + df.show().await?; + Ok(()) } diff --git a/datafusion-examples/examples/dataframe_subquery.rs b/datafusion-examples/examples/dataframe_subquery.rs index 9fb61008b9f6..3e3d0c1b5a84 100644 --- a/datafusion-examples/examples/dataframe_subquery.rs +++ b/datafusion-examples/examples/dataframe_subquery.rs @@ -19,6 +19,8 @@ use arrow_schema::DataType; use std::sync::Arc; use datafusion::error::Result; +use datafusion::functions_aggregate::average::avg; +use datafusion::functions_aggregate::min_max::max; use datafusion::prelude::*; use datafusion::test_util::arrow_test_data; use datafusion_common::ScalarValue; diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index 6e9c42480c32..0eb823302acf 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -24,17 +24,16 @@ use arrow::record_batch::RecordBatch; use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion::common::DFSchema; use datafusion::error::Result; +use datafusion::functions_aggregate::first_last::first_value_udaf; use datafusion::optimizer::simplify_expressions::ExprSimplifier; -use datafusion::physical_expr::{ - analyze, create_physical_expr, AnalysisContext, ExprBoundaries, PhysicalExpr, -}; +use datafusion::physical_expr::{analyze, AnalysisContext, ExprBoundaries}; use datafusion::prelude::*; use datafusion_common::{ScalarValue, ToDFSchema}; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::BinaryExpr; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::simplify::SimplifyContext; -use datafusion_expr::{ColumnarValue, ExprSchemable, Operator}; +use datafusion_expr::{ColumnarValue, ExprFunctionExt, ExprSchemable, Operator}; /// This example demonstrates the DataFusion [`Expr`] API. /// @@ -46,11 +45,12 @@ use datafusion_expr::{ColumnarValue, ExprSchemable, Operator}; /// also comes with APIs for evaluation, simplification, and analysis. /// /// The code in this example shows how to: -/// 1. Create [`Exprs`] using different APIs: [`main`]` -/// 2. Evaluate [`Exprs`] against data: [`evaluate_demo`] -/// 3. Simplify expressions: [`simplify_demo`] -/// 4. Analyze predicates for boundary ranges: [`range_analysis_demo`] -/// 5. Get the types of the expressions: [`expression_type_demo`] +/// 1. Create [`Expr`]s using different APIs: [`main`]` +/// 2. Use the fluent API to easily create complex [`Expr`]s: [`expr_fn_demo`] +/// 3. Evaluate [`Expr`]s against data: [`evaluate_demo`] +/// 4. Simplify expressions: [`simplify_demo`] +/// 5. Analyze predicates for boundary ranges: [`range_analysis_demo`] +/// 6. Get the types of the expressions: [`expression_type_demo`] #[tokio::main] async fn main() -> Result<()> { // The easiest way to do create expressions is to use the @@ -65,6 +65,9 @@ async fn main() -> Result<()> { )); assert_eq!(expr, expr2); + // See how to build aggregate functions with the expr_fn API + expr_fn_demo()?; + // See how to evaluate expressions evaluate_demo()?; @@ -80,6 +83,33 @@ async fn main() -> Result<()> { Ok(()) } +/// DataFusion's `expr_fn` API makes it easy to create [`Expr`]s for the +/// full range of expression types such as aggregates and window functions. +fn expr_fn_demo() -> Result<()> { + // Let's say you want to call the "first_value" aggregate function + let first_value = first_value_udaf(); + + // For example, to create the expression `FIRST_VALUE(price)` + // These expressions can be passed to `DataFrame::aggregate` and other + // APIs that take aggregate expressions. + let agg = first_value.call(vec![col("price")]); + assert_eq!(agg.to_string(), "first_value(price)"); + + // You can use the ExprFunctionExt trait to create more complex aggregates + // such as `FIRST_VALUE(price FILTER quantity > 100 ORDER BY ts ) + let agg = first_value + .call(vec![col("price")]) + .order_by(vec![col("ts").sort(false, false)]) + .filter(col("quantity").gt(lit(100))) + .build()?; // build the aggregate + assert_eq!( + agg.to_string(), + "first_value(price) FILTER (WHERE quantity > Int32(100)) ORDER BY [ts DESC NULLS LAST]" + ); + + Ok(()) +} + /// DataFusion can also evaluate arbitrary expressions on Arrow arrays. fn evaluate_demo() -> Result<()> { // For example, let's say you have some integers in an array @@ -92,7 +122,8 @@ fn evaluate_demo() -> Result<()> { let expr = col("a").lt(lit(5)).or(col("a").eq(lit(8))); // First, you make a "physical expression" from the logical `Expr` - let physical_expr = physical_expr(&batch.schema(), expr)?; + let df_schema = DFSchema::try_from(batch.schema())?; + let physical_expr = SessionContext::new().create_physical_expr(expr, &df_schema)?; // Now, you can evaluate the expression against the RecordBatch let result = physical_expr.evaluate(&batch)?; @@ -146,16 +177,12 @@ fn simplify_demo() -> Result<()> { ); // here are some other examples of what DataFusion is capable of - let schema = Schema::new(vec![ - make_field("i", DataType::Int64), - make_field("b", DataType::Boolean), - ]) - .to_dfschema_ref()?; + let schema = Schema::new(vec![make_field("i", DataType::Int64)]).to_dfschema_ref()?; let context = SimplifyContext::new(&props).with_schema(schema.clone()); let simplifier = ExprSimplifier::new(context); // basic arithmetic simplification - // i + 1 + 2 => a + 3 + // i + 1 + 2 => i + 3 // (note this is not done if the expr is (col("i") + (lit(1) + lit(2)))) assert_eq!( simplifier.simplify(col("i") + (lit(1) + lit(2)))?, @@ -178,7 +205,7 @@ fn simplify_demo() -> Result<()> { ); // String --> Date simplification - // `cast('2020-09-01' as date)` --> 18500 + // `cast('2020-09-01' as date)` --> 18506 # number of days since epoch 1970-01-01 assert_eq!( simplifier.simplify(lit("2020-09-01").cast_to(&DataType::Date32, &schema)?)?, lit(ScalarValue::Date32(Some(18506))) @@ -213,7 +240,7 @@ fn range_analysis_demo() -> Result<()> { // `date < '2020-10-01' AND date > '2020-09-01'` // As always, we need to tell DataFusion the type of column "date" - let schema = Schema::new(vec![make_field("date", DataType::Date32)]); + let schema = Arc::new(Schema::new(vec![make_field("date", DataType::Date32)])); // You can provide DataFusion any known boundaries on the values of `date` // (for example, maybe you know you only have data up to `2020-09-15`), but @@ -222,9 +249,13 @@ fn range_analysis_demo() -> Result<()> { let boundaries = ExprBoundaries::try_new_unbounded(&schema)?; // Now, we invoke the analysis code to perform the range analysis - let physical_expr = physical_expr(&schema, expr)?; - let analysis_result = - analyze(&physical_expr, AnalysisContext::new(boundaries), &schema)?; + let df_schema = DFSchema::try_from(schema)?; + let physical_expr = SessionContext::new().create_physical_expr(expr, &df_schema)?; + let analysis_result = analyze( + &physical_expr, + AnalysisContext::new(boundaries), + df_schema.as_ref(), + )?; // The results of the analysis is an range, encoded as an `Interval`, for // each column in the schema, that must be true in order for the predicate @@ -248,21 +279,6 @@ fn make_ts_field(name: &str) -> Field { make_field(name, DataType::Timestamp(TimeUnit::Nanosecond, tz)) } -/// Build a physical expression from a logical one, after applying simplification and type coercion -pub fn physical_expr(schema: &Schema, expr: Expr) -> Result> { - let df_schema = schema.clone().to_dfschema_ref()?; - - // Simplify - let props = ExecutionProps::new(); - let simplifier = - ExprSimplifier::new(SimplifyContext::new(&props).with_schema(df_schema.clone())); - - // apply type coercion here to ensure types match - let expr = simplifier.coerce(expr, df_schema.clone())?; - - create_physical_expr(&expr, df_schema.as_ref(), &props) -} - /// This function shows how to use `Expr::get_type` to retrieve the DataType /// of an expression fn expression_type_demo() -> Result<()> { @@ -272,14 +288,14 @@ fn expression_type_demo() -> Result<()> { // types of the input expressions. You can provide this information using // a schema. In this case we create a schema where the column `c` is of // type Utf8 (a String / VARCHAR) - let schema = DFSchema::from_unqualifed_fields( + let schema = DFSchema::from_unqualified_fields( vec![Field::new("c", DataType::Utf8, true)].into(), HashMap::new(), )?; assert_eq!("Utf8", format!("{}", expr.get_type(&schema).unwrap())); // Using a schema where the column `foo` is of type Int32 - let schema = DFSchema::from_unqualifed_fields( + let schema = DFSchema::from_unqualified_fields( vec![Field::new("c", DataType::Int32, true)].into(), HashMap::new(), )?; @@ -288,7 +304,7 @@ fn expression_type_demo() -> Result<()> { // Get the type of an expression that adds 2 columns. Adding an Int32 // and Float32 results in Float32 type let expr = col("c1") + col("c2"); - let schema = DFSchema::from_unqualifed_fields( + let schema = DFSchema::from_unqualified_fields( vec![ Field::new("c1", DataType::Int32, true), Field::new("c2", DataType::Float32, true), diff --git a/datafusion-examples/examples/external_dependency/dataframe-to-s3.rs b/datafusion-examples/examples/external_dependency/dataframe-to-s3.rs index 8d56c440da36..e75ba5dd5328 100644 --- a/datafusion-examples/examples/external_dependency/dataframe-to-s3.rs +++ b/datafusion-examples/examples/external_dependency/dataframe-to-s3.rs @@ -20,10 +20,10 @@ use std::sync::Arc; use datafusion::dataframe::DataFrameWriteOptions; use datafusion::datasource::file_format::parquet::ParquetFormat; +use datafusion::datasource::file_format::FileFormat; use datafusion::datasource::listing::ListingOptions; use datafusion::error::Result; use datafusion::prelude::*; -use datafusion_common::{FileType, GetExt}; use object_store::aws::AmazonS3Builder; use url::Url; @@ -49,13 +49,12 @@ async fn main() -> Result<()> { let path = format!("s3://{bucket_name}"); let s3_url = Url::parse(&path).unwrap(); let arc_s3 = Arc::new(s3); - ctx.runtime_env() - .register_object_store(&s3_url, arc_s3.clone()); + ctx.register_object_store(&s3_url, arc_s3.clone()); let path = format!("s3://{bucket_name}/test_data/"); let file_format = ParquetFormat::default().with_enable_pruning(true); let listing_options = ListingOptions::new(Arc::new(file_format)) - .with_file_extension(FileType::PARQUET.get_ext()); + .with_file_extension(ParquetFormat::default().get_ext()); ctx.register_listing_table("test", &path, listing_options, None, None) .await?; @@ -80,7 +79,7 @@ async fn main() -> Result<()> { let file_format = ParquetFormat::default().with_enable_pruning(true); let listing_options = ListingOptions::new(Arc::new(file_format)) - .with_file_extension(FileType::PARQUET.get_ext()); + .with_file_extension(ParquetFormat::default().get_ext()); ctx.register_listing_table("test2", &out_path, listing_options, None, None) .await?; diff --git a/datafusion-examples/examples/external_dependency/query-aws-s3.rs b/datafusion-examples/examples/external_dependency/query-aws-s3.rs index cbb6486b4eec..da2d7e4879f9 100644 --- a/datafusion-examples/examples/external_dependency/query-aws-s3.rs +++ b/datafusion-examples/examples/external_dependency/query-aws-s3.rs @@ -48,8 +48,7 @@ async fn main() -> Result<()> { let path = format!("s3://{bucket_name}"); let s3_url = Url::parse(&path).unwrap(); - ctx.runtime_env() - .register_object_store(&s3_url, Arc::new(s3)); + ctx.register_object_store(&s3_url, Arc::new(s3)); // cannot query the parquet files from this bucket because the path contains a whitespace // and we don't support that yet @@ -64,5 +63,14 @@ async fn main() -> Result<()> { // print the results df.show().await?; + // dynamic query by the file path + let ctx = ctx.enable_url_table(); + let df = ctx + .sql(format!(r#"SELECT * FROM '{}' LIMIT 10"#, &path).as_str()) + .await?; + + // print the results + df.show().await?; + Ok(()) } diff --git a/datafusion-examples/examples/file_stream_provider.rs b/datafusion-examples/examples/file_stream_provider.rs new file mode 100644 index 000000000000..e4fd937fd373 --- /dev/null +++ b/datafusion-examples/examples/file_stream_provider.rs @@ -0,0 +1,202 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[cfg(not(target_os = "windows"))] +mod non_windows { + use datafusion::assert_batches_eq; + use datafusion_common::instant::Instant; + use std::fs::{File, OpenOptions}; + use std::io::Write; + use std::path::PathBuf; + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::Arc; + use std::thread; + use std::time::Duration; + + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::SchemaRef; + use futures::StreamExt; + use nix::sys::stat; + use nix::unistd; + use tempfile::TempDir; + use tokio::task::JoinSet; + + use datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; + use datafusion::datasource::TableProvider; + use datafusion::prelude::{SessionConfig, SessionContext}; + use datafusion_common::{exec_err, Result}; + use datafusion_expr::SortExpr; + + // Number of lines written to FIFO + const TEST_BATCH_SIZE: usize = 5; + const TEST_DATA_SIZE: usize = 5; + + /// Makes a TableProvider for a fifo file using `StreamTable` with the `StreamProvider` trait + fn fifo_table( + schema: SchemaRef, + path: impl Into, + sort: Vec>, + ) -> Arc { + let source = FileStreamProvider::new_file(schema, path.into()) + .with_batch_size(TEST_BATCH_SIZE) + .with_header(true); + let config = StreamConfig::new(Arc::new(source)).with_order(sort); + Arc::new(StreamTable::new(Arc::new(config))) + } + + fn create_fifo_file(tmp_dir: &TempDir, file_name: &str) -> Result { + let file_path = tmp_dir.path().join(file_name); + // Simulate an infinite environment via a FIFO file + if let Err(e) = unistd::mkfifo(&file_path, stat::Mode::S_IRWXU) { + exec_err!("{}", e) + } else { + Ok(file_path) + } + } + + fn write_to_fifo( + mut file: &File, + line: &str, + ref_time: Instant, + broken_pipe_timeout: Duration, + ) -> Result<()> { + // We need to handle broken pipe error until the reader is ready. This + // is why we use a timeout to limit the wait duration for the reader. + // If the error is different than broken pipe, we fail immediately. + while let Err(e) = file.write_all(line.as_bytes()) { + if e.raw_os_error().unwrap() == 32 { + let interval = Instant::now().duration_since(ref_time); + if interval < broken_pipe_timeout { + thread::sleep(Duration::from_millis(100)); + continue; + } + } + return exec_err!("{}", e); + } + Ok(()) + } + + fn create_writing_thread( + file_path: PathBuf, + maybe_header: Option, + lines: Vec, + waiting_lock: Arc, + wait_until: usize, + tasks: &mut JoinSet<()>, + ) { + // Timeout for a long period of BrokenPipe error + let broken_pipe_timeout = Duration::from_secs(10); + let sa = file_path; + // Spawn a new thread to write to the FIFO file + #[allow(clippy::disallowed_methods)] // spawn allowed only in tests + tasks.spawn_blocking(move || { + let file = OpenOptions::new().write(true).open(sa).unwrap(); + // Reference time to use when deciding to fail the test + let execution_start = Instant::now(); + if let Some(header) = maybe_header { + write_to_fifo(&file, &header, execution_start, broken_pipe_timeout) + .unwrap(); + } + for (cnt, line) in lines.iter().enumerate() { + while waiting_lock.load(Ordering::SeqCst) && cnt > wait_until { + thread::sleep(Duration::from_millis(50)); + } + write_to_fifo(&file, line, execution_start, broken_pipe_timeout).unwrap(); + } + drop(file); + }); + } + + /// This example demonstrates a scanning against an Arrow data source (JSON) and + /// fetching results + pub async fn main() -> Result<()> { + // Create session context + let config = SessionConfig::new() + .with_batch_size(TEST_BATCH_SIZE) + .with_collect_statistics(false) + .with_target_partitions(1); + let ctx = SessionContext::new_with_config(config); + let tmp_dir = TempDir::new()?; + let fifo_path = create_fifo_file(&tmp_dir, "fifo_unbounded.csv")?; + + let mut tasks: JoinSet<()> = JoinSet::new(); + let waiting = Arc::new(AtomicBool::new(true)); + + let data_iter = 0..TEST_DATA_SIZE; + let lines = data_iter + .map(|i| format!("{},{}\n", i, i + 1)) + .collect::>(); + + create_writing_thread( + fifo_path.clone(), + Some("a1,a2\n".to_owned()), + lines.clone(), + waiting.clone(), + TEST_DATA_SIZE, + &mut tasks, + ); + + // Create schema + let schema = Arc::new(Schema::new(vec![ + Field::new("a1", DataType::UInt32, false), + Field::new("a2", DataType::UInt32, false), + ])); + + // Specify the ordering: + let order = vec![vec![datafusion_expr::col("a1").sort(true, false)]]; + + let provider = fifo_table(schema.clone(), fifo_path, order.clone()); + ctx.register_table("fifo", provider)?; + + let df = ctx.sql("SELECT * FROM fifo").await.unwrap(); + let mut stream = df.execute_stream().await.unwrap(); + + let mut batches = Vec::new(); + if let Some(Ok(batch)) = stream.next().await { + batches.push(batch) + } + + let expected = vec![ + "+----+----+", + "| a1 | a2 |", + "+----+----+", + "| 0 | 1 |", + "| 1 | 2 |", + "| 2 | 3 |", + "| 3 | 4 |", + "| 4 | 5 |", + "+----+----+", + ]; + + assert_batches_eq!(&expected, &batches); + + Ok(()) + } +} + +#[tokio::main] +async fn main() -> datafusion_common::Result<()> { + #[cfg(target_os = "windows")] + { + println!("file_stream_provider example does not work on windows"); + Ok(()) + } + #[cfg(not(target_os = "windows"))] + { + non_windows::main().await + } +} diff --git a/datafusion-examples/examples/flight/flight_server.rs b/datafusion-examples/examples/flight/flight_server.rs index f9d1b8029f04..cc5f43746ddf 100644 --- a/datafusion-examples/examples/flight/flight_server.rs +++ b/datafusion-examples/examples/flight/flight_server.rs @@ -105,7 +105,7 @@ impl FlightService for FlightServiceImpl { } // add an initial FlightData message that sends schema - let options = datafusion::arrow::ipc::writer::IpcWriteOptions::default(); + let options = arrow::ipc::writer::IpcWriteOptions::default(); let schema_flight_data = SchemaAsIpc::new(&schema, &options); let mut flights = vec![FlightData::from(schema_flight_data)]; diff --git a/datafusion-examples/examples/flight/flight_sql_server.rs b/datafusion-examples/examples/flight/flight_sql_server.rs index f04a559d002e..2e46daf7cb4e 100644 --- a/datafusion-examples/examples/flight/flight_sql_server.rs +++ b/datafusion-examples/examples/flight/flight_sql_server.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use arrow::array::{ArrayRef, StringArray}; use arrow::ipc::writer::IpcWriteOptions; use arrow::record_batch::RecordBatch; use arrow_flight::encode::FlightDataEncoderBuilder; @@ -22,24 +23,16 @@ use arrow_flight::flight_descriptor::DescriptorType; use arrow_flight::flight_service_server::{FlightService, FlightServiceServer}; use arrow_flight::sql::server::{FlightSqlService, PeekableFlightDataStream}; use arrow_flight::sql::{ - ActionBeginSavepointRequest, ActionBeginSavepointResult, - ActionBeginTransactionRequest, ActionBeginTransactionResult, - ActionCancelQueryRequest, ActionCancelQueryResult, ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest, - ActionCreatePreparedStatementResult, ActionCreatePreparedSubstraitPlanRequest, - ActionEndSavepointRequest, ActionEndTransactionRequest, Any, CommandGetCatalogs, - CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, - CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, - CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo, - CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery, - CommandStatementSubstraitPlan, CommandStatementUpdate, ProstMessageExt, SqlInfo, - TicketStatementQuery, + ActionCreatePreparedStatementResult, Any, CommandGetTables, + CommandPreparedStatementQuery, CommandPreparedStatementUpdate, ProstMessageExt, + SqlInfo, }; use arrow_flight::{ Action, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse, IpcMessage, SchemaAsIpc, Ticket, }; -use arrow_schema::Schema; +use arrow_schema::{DataType, Field, Schema}; use dashmap::DashMap; use datafusion::logical_expr::LogicalPlan; use datafusion::prelude::{DataFrame, ParquetReadOptions, SessionConfig, SessionContext}; @@ -165,6 +158,43 @@ impl FlightSqlServiceImpl { } } + async fn tables(&self, ctx: Arc) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![ + Field::new("catalog_name", DataType::Utf8, true), + Field::new("db_schema_name", DataType::Utf8, true), + Field::new("table_name", DataType::Utf8, false), + Field::new("table_type", DataType::Utf8, false), + ])); + + let mut catalogs = vec![]; + let mut schemas = vec![]; + let mut names = vec![]; + let mut types = vec![]; + for catalog in ctx.catalog_names() { + let catalog_provider = ctx.catalog(&catalog).unwrap(); + for schema in catalog_provider.schema_names() { + let schema_provider = catalog_provider.schema(&schema).unwrap(); + for table in schema_provider.table_names() { + let table_provider = + schema_provider.table(&table).await.unwrap().unwrap(); + catalogs.push(catalog.clone()); + schemas.push(schema.clone()); + names.push(table.clone()); + types.push(table_provider.table_type().to_string()) + } + } + } + + RecordBatch::try_new( + schema, + [catalogs, schemas, names, types] + .into_iter() + .map(|i| Arc::new(StringArray::from(i)) as ArrayRef) + .collect::>(), + ) + .unwrap() + } + fn remove_plan(&self, handle: &str) -> Result<(), Status> { self.statements.remove(&handle.to_string()); Ok(()) @@ -246,27 +276,6 @@ impl FlightSqlService for FlightSqlServiceImpl { Ok(Response::new(Box::pin(stream))) } - async fn get_flight_info_statement( - &self, - query: CommandStatementQuery, - _request: Request, - ) -> Result, Status> { - info!("get_flight_info_statement query:\n{}", query.query); - - Err(Status::unimplemented("Implement get_flight_info_statement")) - } - - async fn get_flight_info_substrait_plan( - &self, - _query: CommandStatementSubstraitPlan, - _request: Request, - ) -> Result, Status> { - info!("get_flight_info_substrait_plan"); - Err(Status::unimplemented( - "Implement get_flight_info_substrait_plan", - )) - } - async fn get_flight_info_prepared_statement( &self, cmd: CommandPreparedStatementQuery, @@ -304,267 +313,50 @@ impl FlightSqlService for FlightSqlServiceImpl { }; let buf = fetch.as_any().encode_to_vec().into(); let ticket = Ticket { ticket: buf }; - let endpoint = FlightEndpoint { - ticket: Some(ticket), - location: vec![], - expiration_time: None, - app_metadata: Default::default(), - }; - let endpoints = vec![endpoint]; - - let message = SchemaAsIpc::new(&schema, &IpcWriteOptions::default()) - .try_into() - .map_err(|e| status!("Unable to serialize schema", e))?; - let IpcMessage(schema_bytes) = message; - let flight_desc = FlightDescriptor { - r#type: DescriptorType::Cmd.into(), - cmd: Default::default(), - path: vec![], - }; - // send -1 for total_records and total_bytes instead of iterating over all the - // batches to get num_rows() and total byte size. - let info = FlightInfo { - schema: schema_bytes, - flight_descriptor: Some(flight_desc), - endpoint: endpoints, - total_records: -1_i64, - total_bytes: -1_i64, - ordered: false, - app_metadata: Default::default(), - }; + let info = FlightInfo::new() + // Encode the Arrow schema + .try_with_schema(&schema) + .expect("encoding failed") + .with_endpoint(FlightEndpoint::new().with_ticket(ticket)) + .with_descriptor(FlightDescriptor { + r#type: DescriptorType::Cmd.into(), + cmd: Default::default(), + path: vec![], + }); let resp = Response::new(info); Ok(resp) } - async fn get_flight_info_catalogs( - &self, - _query: CommandGetCatalogs, - _request: Request, - ) -> Result, Status> { - info!("get_flight_info_catalogs"); - Err(Status::unimplemented("Implement get_flight_info_catalogs")) - } - - async fn get_flight_info_schemas( - &self, - _query: CommandGetDbSchemas, - _request: Request, - ) -> Result, Status> { - info!("get_flight_info_schemas"); - Err(Status::unimplemented("Implement get_flight_info_schemas")) - } - async fn get_flight_info_tables( &self, _query: CommandGetTables, - _request: Request, + request: Request, ) -> Result, Status> { info!("get_flight_info_tables"); - Err(Status::unimplemented("Implement get_flight_info_tables")) - } - - async fn get_flight_info_table_types( - &self, - _query: CommandGetTableTypes, - _request: Request, - ) -> Result, Status> { - info!("get_flight_info_table_types"); - Err(Status::unimplemented( - "Implement get_flight_info_table_types", - )) - } - - async fn get_flight_info_sql_info( - &self, - _query: CommandGetSqlInfo, - _request: Request, - ) -> Result, Status> { - info!("get_flight_info_sql_info"); - Err(Status::unimplemented("Implement CommandGetSqlInfo")) - } - - async fn get_flight_info_primary_keys( - &self, - _query: CommandGetPrimaryKeys, - _request: Request, - ) -> Result, Status> { - info!("get_flight_info_primary_keys"); - Err(Status::unimplemented( - "Implement get_flight_info_primary_keys", - )) - } - - async fn get_flight_info_exported_keys( - &self, - _query: CommandGetExportedKeys, - _request: Request, - ) -> Result, Status> { - info!("get_flight_info_exported_keys"); - Err(Status::unimplemented( - "Implement get_flight_info_exported_keys", - )) - } - - async fn get_flight_info_imported_keys( - &self, - _query: CommandGetImportedKeys, - _request: Request, - ) -> Result, Status> { - info!("get_flight_info_imported_keys"); - Err(Status::unimplemented( - "Implement get_flight_info_imported_keys", - )) - } - - async fn get_flight_info_cross_reference( - &self, - _query: CommandGetCrossReference, - _request: Request, - ) -> Result, Status> { - info!("get_flight_info_cross_reference"); - Err(Status::unimplemented( - "Implement get_flight_info_cross_reference", - )) - } - - async fn get_flight_info_xdbc_type_info( - &self, - _query: CommandGetXdbcTypeInfo, - _request: Request, - ) -> Result, Status> { - info!("get_flight_info_xdbc_type_info"); - Err(Status::unimplemented( - "Implement get_flight_info_xdbc_type_info", - )) - } - - async fn do_get_statement( - &self, - _ticket: TicketStatementQuery, - _request: Request, - ) -> Result::DoGetStream>, Status> { - info!("do_get_statement"); - Err(Status::unimplemented("Implement do_get_statement")) - } - - async fn do_get_prepared_statement( - &self, - _query: CommandPreparedStatementQuery, - _request: Request, - ) -> Result::DoGetStream>, Status> { - info!("do_get_prepared_statement"); - Err(Status::unimplemented("Implement do_get_prepared_statement")) - } - - async fn do_get_catalogs( - &self, - _query: CommandGetCatalogs, - _request: Request, - ) -> Result::DoGetStream>, Status> { - info!("do_get_catalogs"); - Err(Status::unimplemented("Implement do_get_catalogs")) - } - - async fn do_get_schemas( - &self, - _query: CommandGetDbSchemas, - _request: Request, - ) -> Result::DoGetStream>, Status> { - info!("do_get_schemas"); - Err(Status::unimplemented("Implement do_get_schemas")) - } - - async fn do_get_tables( - &self, - _query: CommandGetTables, - _request: Request, - ) -> Result::DoGetStream>, Status> { - info!("do_get_tables"); - Err(Status::unimplemented("Implement do_get_tables")) - } - - async fn do_get_table_types( - &self, - _query: CommandGetTableTypes, - _request: Request, - ) -> Result::DoGetStream>, Status> { - info!("do_get_table_types"); - Err(Status::unimplemented("Implement do_get_table_types")) - } - - async fn do_get_sql_info( - &self, - _query: CommandGetSqlInfo, - _request: Request, - ) -> Result::DoGetStream>, Status> { - info!("do_get_sql_info"); - Err(Status::unimplemented("Implement do_get_sql_info")) - } - - async fn do_get_primary_keys( - &self, - _query: CommandGetPrimaryKeys, - _request: Request, - ) -> Result::DoGetStream>, Status> { - info!("do_get_primary_keys"); - Err(Status::unimplemented("Implement do_get_primary_keys")) - } - - async fn do_get_exported_keys( - &self, - _query: CommandGetExportedKeys, - _request: Request, - ) -> Result::DoGetStream>, Status> { - info!("do_get_exported_keys"); - Err(Status::unimplemented("Implement do_get_exported_keys")) - } - - async fn do_get_imported_keys( - &self, - _query: CommandGetImportedKeys, - _request: Request, - ) -> Result::DoGetStream>, Status> { - info!("do_get_imported_keys"); - Err(Status::unimplemented("Implement do_get_imported_keys")) - } - - async fn do_get_cross_reference( - &self, - _query: CommandGetCrossReference, - _request: Request, - ) -> Result::DoGetStream>, Status> { - info!("do_get_cross_reference"); - Err(Status::unimplemented("Implement do_get_cross_reference")) - } + let ctx = self.get_ctx(&request)?; + let data = self.tables(ctx).await; + let schema = data.schema(); - async fn do_get_xdbc_type_info( - &self, - _query: CommandGetXdbcTypeInfo, - _request: Request, - ) -> Result::DoGetStream>, Status> { - info!("do_get_xdbc_type_info"); - Err(Status::unimplemented("Implement do_get_xdbc_type_info")) - } + let uuid = Uuid::new_v4().hyphenated().to_string(); + self.results.insert(uuid.clone(), vec![data]); - async fn do_put_statement_update( - &self, - _ticket: CommandStatementUpdate, - _request: Request, - ) -> Result { - info!("do_put_statement_update"); - Err(Status::unimplemented("Implement do_put_statement_update")) - } + let fetch = FetchResults { handle: uuid }; + let buf = fetch.as_any().encode_to_vec().into(); + let ticket = Ticket { ticket: buf }; - async fn do_put_prepared_statement_query( - &self, - _query: CommandPreparedStatementQuery, - _request: Request, - ) -> Result::DoPutStream>, Status> { - info!("do_put_prepared_statement_query"); - Err(Status::unimplemented( - "Implement do_put_prepared_statement_query", - )) + let info = FlightInfo::new() + // Encode the Arrow schema + .try_with_schema(&schema) + .expect("encoding failed") + .with_endpoint(FlightEndpoint::new().with_ticket(ticket)) + .with_descriptor(FlightDescriptor { + r#type: DescriptorType::Cmd.into(), + cmd: Default::default(), + path: vec![], + }); + let resp = Response::new(info); + Ok(resp) } async fn do_put_prepared_statement_update( @@ -578,17 +370,6 @@ impl FlightSqlService for FlightSqlServiceImpl { Ok(-1) } - async fn do_put_substrait_plan( - &self, - _query: CommandStatementSubstraitPlan, - _request: Request, - ) -> Result { - info!("do_put_prepared_statement_update"); - Err(Status::unimplemented( - "Implement do_put_prepared_statement_update", - )) - } - async fn do_action_create_prepared_statement( &self, query: ActionCreatePreparedStatementRequest, @@ -639,64 +420,6 @@ impl FlightSqlService for FlightSqlServiceImpl { Ok(()) } - async fn do_action_create_prepared_substrait_plan( - &self, - _query: ActionCreatePreparedSubstraitPlanRequest, - _request: Request, - ) -> Result { - info!("do_action_create_prepared_substrait_plan"); - Err(Status::unimplemented( - "Implement do_action_create_prepared_substrait_plan", - )) - } - - async fn do_action_begin_transaction( - &self, - _query: ActionBeginTransactionRequest, - _request: Request, - ) -> Result { - info!("do_action_begin_transaction"); - Err(Status::unimplemented( - "Implement do_action_begin_transaction", - )) - } - - async fn do_action_end_transaction( - &self, - _query: ActionEndTransactionRequest, - _request: Request, - ) -> Result<(), Status> { - info!("do_action_end_transaction"); - Err(Status::unimplemented("Implement do_action_end_transaction")) - } - - async fn do_action_begin_savepoint( - &self, - _query: ActionBeginSavepointRequest, - _request: Request, - ) -> Result { - info!("do_action_begin_savepoint"); - Err(Status::unimplemented("Implement do_action_begin_savepoint")) - } - - async fn do_action_end_savepoint( - &self, - _query: ActionEndSavepointRequest, - _request: Request, - ) -> Result<(), Status> { - info!("do_action_end_savepoint"); - Err(Status::unimplemented("Implement do_action_end_savepoint")) - } - - async fn do_action_cancel_query( - &self, - _query: ActionCancelQueryRequest, - _request: Request, - ) -> Result { - info!("do_action_cancel_query"); - Err(Status::unimplemented("Implement do_action_cancel_query")) - } - async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {} } diff --git a/datafusion-examples/examples/function_factory.rs b/datafusion-examples/examples/function_factory.rs index 3973e50474ba..b42f25437d77 100644 --- a/datafusion-examples/examples/function_factory.rs +++ b/datafusion-examples/examples/function_factory.rs @@ -15,17 +15,18 @@ // specific language governing permissions and limitations // under the License. +use std::result::Result as RResult; +use std::sync::Arc; + use datafusion::error::Result; use datafusion::execution::context::{ FunctionFactory, RegisterFunction, SessionContext, SessionState, }; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{exec_err, internal_err, DataFusionError}; -use datafusion_expr::simplify::ExprSimplifyResult; -use datafusion_expr::simplify::SimplifyInfo; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{CreateFunction, Expr, ScalarUDF, ScalarUDFImpl, Signature}; -use std::result::Result as RResult; -use std::sync::Arc; /// This example shows how to utilize [FunctionFactory] to implement simple /// SQL-macro like functions using a `CREATE FUNCTION` statement. The same @@ -120,7 +121,7 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { &self.name } - fn signature(&self) -> &datafusion_expr::Signature { + fn signature(&self) -> &Signature { &self.signature } @@ -156,8 +157,8 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { &[] } - fn monotonicity(&self) -> Result> { - Ok(None) + fn output_ordering(&self, _input: &[ExprProperties]) -> Result { + Ok(SortProperties::Unordered) } } @@ -211,7 +212,7 @@ impl TryFrom for ScalarFunctionWrapper { name: definition.name, expr: definition .params - .return_ + .function_body .expect("Expression has to be defined!"), return_type: definition .return_type diff --git a/datafusion-examples/examples/json_opener.rs b/datafusion-examples/examples/json_opener.rs index ee33f969caa9..7bc431c5c5ee 100644 --- a/datafusion-examples/examples/json_opener.rs +++ b/datafusion-examples/examples/json_opener.rs @@ -29,7 +29,6 @@ use datafusion::{ error::Result, physical_plan::metrics::ExecutionPlanMetricsSet, }; -use datafusion_common::Statistics; use futures::StreamExt; use object_store::ObjectStore; @@ -45,7 +44,7 @@ async fn main() -> Result<()> { {"num":2,"str":"hello"} {"num":4,"str":"foo"}"#, ); - object_store.put(&path, data).await.unwrap(); + object_store.put(&path, data.into()).await.unwrap(); let schema = Arc::new(Schema::new(vec![ Field::new("num", DataType::Int64, false), @@ -61,16 +60,11 @@ async fn main() -> Result<()> { Arc::new(object_store), ); - let scan_config = FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_schema: schema.clone(), - file_groups: vec![vec![PartitionedFile::new(path.to_string(), 10)]], - statistics: Statistics::new_unknown(&schema), - projection: Some(vec![1, 0]), - limit: Some(5), - table_partition_cols: vec![], - output_ordering: vec![], - }; + let scan_config = + FileScanConfig::new(ObjectStoreUrl::local_filesystem(), schema.clone()) + .with_projection(Some(vec![1, 0])) + .with_limit(Some(5)) + .with_file(PartitionedFile::new(path.to_string(), 10)); let result = FileStream::new(&scan_config, 0, opener, &ExecutionPlanMetricsSet::new()) diff --git a/datafusion-examples/examples/optimizer_rule.rs b/datafusion-examples/examples/optimizer_rule.rs new file mode 100644 index 000000000000..e0b552620a9a --- /dev/null +++ b/datafusion-examples/examples/optimizer_rule.rs @@ -0,0 +1,221 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray}; +use arrow_schema::DataType; +use datafusion::prelude::SessionContext; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{assert_batches_eq, Result, ScalarValue}; +use datafusion_expr::{ + BinaryExpr, ColumnarValue, Expr, LogicalPlan, Operator, ScalarUDF, ScalarUDFImpl, + Signature, Volatility, +}; +use datafusion_optimizer::optimizer::ApplyOrder; +use datafusion_optimizer::{OptimizerConfig, OptimizerRule}; +use std::any::Any; +use std::sync::Arc; + +/// This example demonstrates how to add your own [`OptimizerRule`] +/// to DataFusion. +/// +/// [`OptimizerRule`]s transform [`LogicalPlan`]s into an equivalent (but +/// hopefully faster) form. +/// +/// See [analyzer_rule.rs] for an example of AnalyzerRules, which are for +/// changing plan semantics. +#[tokio::main] +pub async fn main() -> Result<()> { + // DataFusion includes many built in OptimizerRules for tasks such as outer + // to inner join conversion and constant folding. + // + // Note you can change the order of optimizer rules using the lower level + // `SessionState` API + let ctx = SessionContext::new(); + ctx.add_optimizer_rule(Arc::new(MyOptimizerRule {})); + + // Now, let's plan and run queries with the new rule + ctx.register_batch("person", person_batch())?; + let sql = "SELECT * FROM person WHERE age = 22"; + let plan = ctx.sql(sql).await?.into_optimized_plan()?; + + // We can see the effect of our rewrite on the output plan that the filter + // has been rewritten to `my_eq` + assert_eq!( + plan.display_indent().to_string(), + "Filter: my_eq(person.age, Int32(22))\ + \n TableScan: person projection=[name, age]" + ); + + // The query below doesn't respect a filter `where age = 22` because + // the plan has been rewritten using UDF which returns always true + // + // And the output verifies the predicates have been changed (as the my_eq + // function always returns true) + assert_batches_eq!( + [ + "+--------+-----+", + "| name | age |", + "+--------+-----+", + "| Andy | 11 |", + "| Andrew | 22 |", + "| Oleks | 33 |", + "+--------+-----+", + ], + &ctx.sql(sql).await?.collect().await? + ); + + // however we can see the rule doesn't trigger for queries with predicates + // other than `=` + assert_batches_eq!( + [ + "+-------+-----+", + "| name | age |", + "+-------+-----+", + "| Andy | 11 |", + "| Oleks | 33 |", + "+-------+-----+", + ], + &ctx.sql("SELECT * FROM person WHERE age <> 22") + .await? + .collect() + .await? + ); + + Ok(()) +} + +/// An example OptimizerRule that replaces all `col = ` predicates with a +/// user defined function +#[derive(Default, Debug)] +struct MyOptimizerRule {} + +impl OptimizerRule for MyOptimizerRule { + fn name(&self) -> &str { + "my_optimizer_rule" + } + + // New OptimizerRules should use the "rewrite" api as it is more efficient + fn supports_rewrite(&self) -> bool { + true + } + + /// Ask the optimizer to handle the plan recursion. `rewrite` will be called + /// on each plan node. + fn apply_order(&self) -> Option { + Some(ApplyOrder::BottomUp) + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + plan.map_expressions(|expr| { + // This closure is called for all expressions in the current plan + // + // For example, given a plan like `SELECT a + b, 5 + 10` + // + // The closure would be called twice: + // 1. once for `a + b` + // 2. once for `5 + 10` + self.rewrite_expr(expr) + }) + } +} + +impl MyOptimizerRule { + /// Rewrites an Expr replacing all ` = ` expressions with + /// a call to my_eq udf + fn rewrite_expr(&self, expr: Expr) -> Result> { + // do a bottom up rewrite of the expression tree + expr.transform_up(|expr| { + // Closure called for each sub tree + match expr { + Expr::BinaryExpr(binary_expr) if is_binary_eq(&binary_expr) => { + // destruture the expression + let BinaryExpr { left, op: _, right } = binary_expr; + // rewrite to `my_eq(left, right)` + let udf = ScalarUDF::new_from_impl(MyEq::new()); + let call = udf.call(vec![*left, *right]); + Ok(Transformed::yes(call)) + } + _ => Ok(Transformed::no(expr)), + } + }) + // Note that the TreeNode API handles propagating the transformed flag + // and errors up the call chain + } +} + +/// return true of the expression is an equality expression for a literal or +/// column reference +fn is_binary_eq(binary_expr: &BinaryExpr) -> bool { + binary_expr.op == Operator::Eq + && is_lit_or_col(binary_expr.left.as_ref()) + && is_lit_or_col(binary_expr.right.as_ref()) +} + +/// Return true if the expression is a literal or column reference +fn is_lit_or_col(expr: &Expr) -> bool { + matches!(expr, Expr::Column(_) | Expr::Literal(_)) +} + +/// A simple user defined filter function +#[derive(Debug, Clone)] +struct MyEq { + signature: Signature, +} + +impl MyEq { + fn new() -> Self { + Self { + signature: Signature::any(2, Volatility::Stable), + } + } +} + +impl ScalarUDFImpl for MyEq { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "my_eq" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Boolean) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + // this example simply returns "true" which is not what a real + // implementation would do. + Ok(ColumnarValue::Scalar(ScalarValue::from(true))) + } +} + +/// Return a RecordBatch with made up data +fn person_batch() -> RecordBatch { + let name: ArrayRef = + Arc::new(StringArray::from_iter_values(["Andy", "Andrew", "Oleks"])); + let age: ArrayRef = Arc::new(Int32Array::from(vec![11, 22, 33])); + RecordBatch::try_from_iter(vec![("name", name), ("age", age)]).unwrap() +} diff --git a/datafusion-examples/examples/parquet_exec_visitor.rs b/datafusion-examples/examples/parquet_exec_visitor.rs new file mode 100644 index 000000000000..eeb288beb0df --- /dev/null +++ b/datafusion-examples/examples/parquet_exec_visitor.rs @@ -0,0 +1,110 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use datafusion::datasource::file_format::parquet::ParquetFormat; +use datafusion::datasource::listing::{ListingOptions, PartitionedFile}; +use datafusion::datasource::physical_plan::ParquetExec; +use datafusion::execution::context::SessionContext; +use datafusion::physical_plan::metrics::MetricValue; +use datafusion::physical_plan::{ + execute_stream, visit_execution_plan, ExecutionPlan, ExecutionPlanVisitor, +}; +use futures::StreamExt; + +/// Example of collecting metrics after execution by visiting the `ExecutionPlan` +#[tokio::main] +async fn main() { + let ctx = SessionContext::new(); + + let test_data = datafusion::test_util::parquet_test_data(); + + // Configure listing options + let file_format = ParquetFormat::default().with_enable_pruning(true); + let listing_options = ListingOptions::new(Arc::new(file_format)); + + // First example were we use an absolute path, which requires no additional setup. + let _ = ctx + .register_listing_table( + "my_table", + &format!("file://{test_data}/alltypes_plain.parquet"), + listing_options.clone(), + None, + None, + ) + .await; + + let df = ctx.sql("SELECT * FROM my_table").await.unwrap(); + let plan = df.create_physical_plan().await.unwrap(); + + // Create empty visitor + let mut visitor = ParquetExecVisitor { + file_groups: None, + bytes_scanned: None, + }; + + // Make sure you execute the plan to collect actual execution statistics. + // For example, in this example the `file_scan_config` is known without executing + // but the `bytes_scanned` would be None if we did not execute. + let mut batch_stream = execute_stream(plan.clone(), ctx.task_ctx()).unwrap(); + while let Some(batch) = batch_stream.next().await { + println!("Batch rows: {}", batch.unwrap().num_rows()); + } + + visit_execution_plan(plan.as_ref(), &mut visitor).unwrap(); + + println!( + "ParquetExecVisitor bytes_scanned: {:?}", + visitor.bytes_scanned + ); + println!( + "ParquetExecVisitor file_groups: {:?}", + visitor.file_groups.unwrap() + ); +} + +/// Define a struct with fields to hold the execution information you want to +/// collect. In this case, I want information on how many bytes were scanned +/// and `file_groups` from the FileScanConfig. +#[derive(Debug)] +struct ParquetExecVisitor { + file_groups: Option>>, + bytes_scanned: Option, +} + +impl ExecutionPlanVisitor for ParquetExecVisitor { + type Error = datafusion_common::DataFusionError; + + /// This function is called once for every node in the tree. + /// Based on your needs implement either `pre_visit` (visit each node before its children/inputs) + /// or `post_visit` (visit each node after its children/inputs) + fn pre_visit(&mut self, plan: &dyn ExecutionPlan) -> Result { + // If needed match on a specific `ExecutionPlan` node type + let maybe_parquet_exec = plan.as_any().downcast_ref::(); + if let Some(parquet_exec) = maybe_parquet_exec { + self.file_groups = Some(parquet_exec.base_config().file_groups.clone()); + + let metrics = match parquet_exec.metrics() { + None => return Ok(true), + Some(metrics) => metrics, + }; + self.bytes_scanned = metrics.sum_by_name("bytes_scanned"); + } + Ok(true) + } +} diff --git a/datafusion-examples/examples/parquet_index.rs b/datafusion-examples/examples/parquet_index.rs new file mode 100644 index 000000000000..d6e17764442d --- /dev/null +++ b/datafusion-examples/examples/parquet_index.rs @@ -0,0 +1,703 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Array, ArrayRef, AsArray, BooleanArray, Int32Array, RecordBatch, StringArray, + UInt64Array, +}; +use arrow::datatypes::Int32Type; +use arrow::util::pretty::pretty_format_batches; +use arrow_schema::SchemaRef; +use async_trait::async_trait; +use datafusion::catalog::Session; +use datafusion::datasource::listing::PartitionedFile; +use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec}; +use datafusion::datasource::TableProvider; +use datafusion::execution::object_store::ObjectStoreUrl; +use datafusion::parquet::arrow::arrow_reader::statistics::StatisticsConverter; +use datafusion::parquet::arrow::{ + arrow_reader::ParquetRecordBatchReaderBuilder, ArrowWriter, +}; +use datafusion::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::*; +use datafusion_common::{ + internal_datafusion_err, DFSchema, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr::{utils::conjunction, TableProviderFilterPushDown, TableType}; +use datafusion_physical_expr::PhysicalExpr; +use std::any::Any; +use std::collections::HashSet; +use std::fmt::Display; +use std::fs::{self, DirEntry, File}; +use std::ops::Range; +use std::path::{Path, PathBuf}; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; +use tempfile::TempDir; +use url::Url; + +/// This example demonstrates building a secondary index over multiple Parquet +/// files and using that index during query to skip ("prune") files that do not +/// contain relevant data. +/// +/// This example rules out relevant data using min/max values of a column +/// extracted from the Parquet metadata. In a real system, the index could be +/// more sophisticated, e.g. using inverted indices, bloom filters or other +/// techniques. +/// +/// Note this is a low level example for people who want to build their own +/// custom indexes. To read a directory of parquet files as a table, you can use +/// a higher level API such as [`SessionContext::read_parquet`] or +/// [`ListingTable`], which also do file pruning based on parquet statistics +/// (using the same underlying APIs) +/// +/// For a more advanced example of using an index to prune row groups within a +/// file, see the (forthcoming) `advanced_parquet_index` example. +/// +/// # Diagram +/// +/// ```text +/// ┏━━━━━━━━━━━━━━━━━━━━━━━━┓ +/// ┃ Index ┃ +/// ┃ ┃ +/// step 1: predicate is ┌ ─ ─ ─ ─▶┃ (sometimes referred to ┃ +/// evaluated against ┃ as a "catalog" or ┃ +/// data in the index │ ┃ "metastore") ┃ +/// (using ┗━━━━━━━━━━━━━━━━━━━━━━━━┛ +/// PruningPredicate) │ │ +/// +/// │ │ +/// ┌──────────────┐ +/// │ value = 150 │─ ─ ─ ─ ┘ │ +/// └──────────────┘ ┌─────────────┐ +/// Predicate from query │ │ │ +/// └─────────────┘ +/// │ ┌─────────────┐ +/// step 2: Index returns only ─ ▶│ │ +/// parquet files that might have └─────────────┘ +/// matching data. ... +/// ┌─────────────┐ +/// Thus some parquet files are │ │ +/// "pruned" and thus are not └─────────────┘ +/// scanned at all Parquet Files +/// +/// ``` +/// +/// [`ListingTable`]: datafusion::datasource::listing::ListingTable +#[tokio::main] +async fn main() -> Result<()> { + // Demo data has three files, each with schema + // * file_name (string) + // * value (int32) + // + // The files are as follows: + // * file1.parquet (value: 0..100) + // * file2.parquet (value: 100..200) + // * file3.parquet (value: 200..3000) + let data = DemoData::try_new()?; + + // Create a table provider with and our special index. + let provider = Arc::new(IndexTableProvider::try_new(data.path())?); + println!("** Table Provider:"); + println!("{provider}\n"); + + // Create a SessionContext for running queries that has the table provider + // registered as "index_table" + let ctx = SessionContext::new(); + ctx.register_table("index_table", Arc::clone(&provider) as _)?; + + // register object store provider for urls like `file://` work + let url = Url::try_from("file://").unwrap(); + let object_store = object_store::local::LocalFileSystem::new(); + ctx.register_object_store(&url, Arc::new(object_store)); + + // Select data from the table without any predicates (and thus no pruning) + println!("** Select data, no predicates:"); + ctx.sql("SELECT file_name, value FROM index_table LIMIT 10") + .await? + .show() + .await?; + println!("Files pruned: {}\n", provider.index().last_num_pruned()); + + // Run a query that uses the index to prune files. + // + // Using the predicate "value = 150", the IndexTable can skip reading file 1 + // (max value 100) and file 3 (min value of 200) + println!("** Select data, predicate `value = 150`"); + ctx.sql("SELECT file_name, value FROM index_table WHERE value = 150") + .await? + .show() + .await?; + println!("Files pruned: {}\n", provider.index().last_num_pruned()); + + // likewise, we can use a more complicated predicate like + // "value < 20 OR value > 500" to read only file 1 and file 3 + println!("** Select data, predicate `value < 20 OR value > 500`"); + ctx.sql( + "SELECT file_name, count(value) FROM index_table \ + WHERE value < 20 OR value > 500 GROUP BY file_name", + ) + .await? + .show() + .await?; + println!("Files pruned: {}\n", provider.index().last_num_pruned()); + + Ok(()) +} + +/// DataFusion `TableProvider` that uses [`IndexTableProvider`], a secondary +/// index to decide which Parquet files to read. +#[derive(Debug)] +pub struct IndexTableProvider { + /// The index of the parquet files in the directory + index: ParquetMetadataIndex, + /// the directory in which the files are stored + dir: PathBuf, +} + +impl Display for IndexTableProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "IndexTableProvider")?; + writeln!(f, "---- Index ----")?; + write!(f, "{}", self.index) + } +} + +impl IndexTableProvider { + /// Create a new IndexTableProvider + pub fn try_new(dir: impl Into) -> Result { + let dir = dir.into(); + + // Create an index of the parquet files in the directory as we see them. + let mut index_builder = ParquetMetadataIndexBuilder::new(); + + let files = read_dir(&dir)?; + for file in &files { + index_builder.add_file(&file.path())?; + } + + let index = index_builder.build()?; + + Ok(Self { index, dir }) + } + + /// return a reference to the underlying index + fn index(&self) -> &ParquetMetadataIndex { + &self.index + } +} + +#[async_trait] +impl TableProvider for IndexTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.index.schema().clone() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + state: &dyn Session, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, + ) -> Result> { + let df_schema = DFSchema::try_from(self.schema())?; + // convert filters like [`a = 1`, `b = 2`] to a single filter like `a = 1 AND b = 2` + let predicate = conjunction(filters.to_vec()); + let predicate = predicate + .map(|predicate| state.create_physical_expr(predicate, &df_schema)) + .transpose()? + // if there are no filters, use a literal true to have a predicate + // that always evaluates to true we can pass to the index + .unwrap_or_else(|| datafusion_physical_expr::expressions::lit(true)); + + // Use the index to find the files that might have data that matches the + // predicate. Any file that can not have data that matches the predicate + // will not be returned. + let files = self.index.get_files(predicate.clone())?; + + let object_store_url = ObjectStoreUrl::parse("file://")?; + let mut file_scan_config = FileScanConfig::new(object_store_url, self.schema()) + .with_projection(projection.cloned()) + .with_limit(limit); + + // Transform to the format needed to pass to ParquetExec + // Create one file group per file (default to scanning them all in parallel) + for (file_name, file_size) in files { + let path = self.dir.join(file_name); + let canonical_path = fs::canonicalize(path)?; + file_scan_config = file_scan_config.with_file(PartitionedFile::new( + canonical_path.display().to_string(), + file_size, + )); + } + let exec = ParquetExec::builder(file_scan_config) + .with_predicate(predicate) + .build_arc(); + + Ok(exec) + } + + /// Tell DataFusion to push filters down to the scan method + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> Result> { + // Inexact because the pruning can't handle all expressions and pruning + // is not done at the row level -- there may be rows in returned files + // that do not pass the filter + Ok(vec![TableProviderFilterPushDown::Inexact; filters.len()]) + } +} + +/// Simple in memory secondary index for a set of parquet files +/// +/// The index is represented as an arrow [`RecordBatch`] that can be passed +/// directly by the DataFusion [`PruningPredicate`] API +/// +/// The `RecordBatch` looks as follows. +/// +/// ```text +/// +---------------+-----------+-----------+------------------+------------------+ +/// | file_name | file_size | row_count | value_column_min | value_column_max | +/// +---------------+-----------+-----------+------------------+------------------+ +/// | file1.parquet | 6062 | 100 | 0 | 99 | +/// | file2.parquet | 6062 | 100 | 100 | 199 | +/// | file3.parquet | 163310 | 2800 | 200 | 2999 | +/// +---------------+-----------+-----------+------------------+------------------+ +/// ``` +/// +/// It must store file_name and file_size to construct `PartitionedFile`. +/// +/// Note a more advanced index might store finer grained information, such as information +/// about each row group within a file +#[derive(Debug)] +struct ParquetMetadataIndex { + file_schema: SchemaRef, + /// The index of the parquet files. See the struct level documentation for + /// the schema of this index. + index: RecordBatch, + /// The number of files that were pruned in the last query + last_num_pruned: AtomicUsize, +} + +impl Display for ParquetMetadataIndex { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!( + f, + "ParquetMetadataIndex(last_num_pruned: {})", + self.last_num_pruned() + )?; + let batches = pretty_format_batches(&[self.index.clone()]).unwrap(); + write!(f, "{batches}",) + } +} + +impl ParquetMetadataIndex { + /// the schema of the *files* in the index (not the index's schema) + fn schema(&self) -> &SchemaRef { + &self.file_schema + } + + /// number of files in the index + fn len(&self) -> usize { + self.index.num_rows() + } + + /// Return a [`PartitionedFile`] for the specified file offset + /// + /// For example, if the index batch contained data like + /// + /// ```text + /// fileA + /// fileB + /// fileC + /// ``` + /// + /// `get_file(1)` would return `(fileB, size)` + fn get_file(&self, file_offset: usize) -> (&str, u64) { + // Filenames and sizes are always non null, so we don't have to check is_valid + let file_name = self.file_names().value(file_offset); + let file_size = self.file_size().value(file_offset); + (file_name, file_size) + } + + /// Return the number of files that were pruned in the last query + pub fn last_num_pruned(&self) -> usize { + self.last_num_pruned.load(Ordering::SeqCst) + } + + /// Set the number of files that were pruned in the last query + fn set_last_num_pruned(&self, num_pruned: usize) { + self.last_num_pruned.store(num_pruned, Ordering::SeqCst); + } + + /// Return all the files matching the predicate + /// + /// Returns a tuple `(file_name, file_size)` + pub fn get_files( + &self, + predicate: Arc, + ) -> Result> { + // Use the PruningPredicate API to determine which files can not + // possibly have any relevant data. + let pruning_predicate = + PruningPredicate::try_new(predicate, self.schema().clone())?; + + // Now evaluate the pruning predicate into a boolean mask, one element per + // file in the index. If the mask is true, the file may have rows that + // match the predicate. If the mask is false, we know the file can not have *any* + // rows that match the predicate and thus can be skipped. + let file_mask = pruning_predicate.prune(self)?; + + let num_left = file_mask.iter().filter(|x| **x).count(); + self.set_last_num_pruned(self.len() - num_left); + + // Return only files that match the predicate from the index + let files_and_sizes: Vec<_> = file_mask + .into_iter() + .enumerate() + .filter_map(|(file, keep)| { + if keep { + Some(self.get_file(file)) + } else { + None + } + }) + .collect(); + Ok(files_and_sizes) + } + + /// Return the file_names column of this index + fn file_names(&self) -> &StringArray { + self.index + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + } + + /// Return the file_size column of this index + fn file_size(&self) -> &UInt64Array { + self.index + .column(1) + .as_any() + .downcast_ref::() + .unwrap() + } + + /// Reference to the row count column + fn row_counts_ref(&self) -> &ArrayRef { + self.index.column(2) + } + + /// Reference to the column minimum values + fn value_column_mins(&self) -> &ArrayRef { + self.index.column(3) + } + + /// Reference to the column maximum values + fn value_column_maxes(&self) -> &ArrayRef { + self.index.column(4) + } +} + +/// In order to use the PruningPredicate API, we need to provide DataFusion +/// the required statistics via the [`PruningStatistics`] trait +impl PruningStatistics for ParquetMetadataIndex { + /// return the minimum values for the value column + fn min_values(&self, column: &Column) -> Option { + if column.name.eq("value") { + Some(self.value_column_mins().clone()) + } else { + None + } + } + + /// return the maximum values for the value column + fn max_values(&self, column: &Column) -> Option { + if column.name.eq("value") { + Some(self.value_column_maxes().clone()) + } else { + None + } + } + + /// return the number of "containers". In this example, each "container" is + /// a file (aka a row in the index) + fn num_containers(&self) -> usize { + self.len() + } + + /// Return `None` to signal we don't have any information about null + /// counts in the index, + fn null_counts(&self, _column: &Column) -> Option { + None + } + + /// return the row counts for each file + fn row_counts(&self, _column: &Column) -> Option { + Some(self.row_counts_ref().clone()) + } + + /// The `contained` API can be used with structures such as Bloom filters, + /// but is not used in this example, so return `None` + fn contained( + &self, + _column: &Column, + _values: &HashSet, + ) -> Option { + None + } +} + +/// Builds a [`ParquetMetadataIndex`] from a set of parquet files +#[derive(Debug, Default)] +struct ParquetMetadataIndexBuilder { + file_schema: Option, + filenames: Vec, + file_sizes: Vec, + row_counts: Vec, + /// Holds the min/max value of the value column for each file + value_column_mins: Vec, + value_column_maxs: Vec, +} + +impl ParquetMetadataIndexBuilder { + fn new() -> Self { + Self::default() + } + + /// Add a new file to the index + fn add_file(&mut self, file: &Path) -> Result<()> { + let file_name = file + .file_name() + .ok_or_else(|| internal_datafusion_err!("No filename"))? + .to_str() + .ok_or_else(|| internal_datafusion_err!("Invalid filename"))?; + let file_size = file.metadata()?.len(); + + let file = File::open(file).map_err(|e| { + DataFusionError::from(e).context(format!("Error opening file {file:?}")) + })?; + + let reader = ParquetRecordBatchReaderBuilder::try_new(file)?; + + // Get the schema of the file. A real system might have to handle the + // case where the schema of the file is not the same as the schema of + // the other files e.g. using SchemaAdapter. + if self.file_schema.is_none() { + self.file_schema = Some(reader.schema().clone()); + } + + // extract the parquet statistics from the file's footer + let metadata = reader.metadata(); + let row_groups = metadata.row_groups(); + + // Extract the min/max values for each row group from the statistics + let converter = StatisticsConverter::try_new( + "value", + reader.schema(), + reader.parquet_schema(), + )?; + let row_counts = converter + .row_group_row_counts(row_groups.iter())? + .ok_or_else(|| { + internal_datafusion_err!("Row group row counts are missing") + })?; + let value_column_mins = converter.row_group_mins(row_groups.iter())?; + let value_column_maxes = converter.row_group_maxes(row_groups.iter())?; + + // In a real system you would have to handle nulls, which represent + // unknown statistics. All statistics are known in this example + assert_eq!(row_counts.null_count(), 0); + assert_eq!(value_column_mins.null_count(), 0); + assert_eq!(value_column_maxes.null_count(), 0); + + // The statistics gathered above are for each row group. We need to + // aggregate them together to compute the overall file row count, + // min and max. + let row_count = row_counts + .iter() + .flatten() // skip nulls (should be none) + .sum::(); + let value_column_min = value_column_mins + .as_primitive::() + .iter() + .flatten() // skip nulls (i.e. min is unknown) + .min() + .unwrap_or_default(); + let value_column_max = value_column_maxes + .as_primitive::() + .iter() + .flatten() // skip nulls (i.e. max is unknown) + .max() + .unwrap_or_default(); + + // sanity check the statistics + assert_eq!(row_count, metadata.file_metadata().num_rows() as u64); + + self.add_row( + file_name, + file_size, + row_count, + value_column_min, + value_column_max, + ); + Ok(()) + } + + /// Add an entry for a single new file to the in progress index + fn add_row( + &mut self, + file_name: impl Into, + file_size: u64, + row_count: u64, + value_column_min: i32, + value_column_max: i32, + ) { + self.filenames.push(file_name.into()); + self.file_sizes.push(file_size); + self.row_counts.push(row_count); + self.value_column_mins.push(value_column_min); + self.value_column_maxs.push(value_column_max); + } + + /// Build the index from the files added + fn build(self) -> Result { + let Some(file_schema) = self.file_schema else { + return Err(internal_datafusion_err!("No files added to index")); + }; + + let file_name: ArrayRef = Arc::new(StringArray::from(self.filenames)); + let file_size: ArrayRef = Arc::new(UInt64Array::from(self.file_sizes)); + let row_count: ArrayRef = Arc::new(UInt64Array::from(self.row_counts)); + let value_column_min: ArrayRef = + Arc::new(Int32Array::from(self.value_column_mins)); + let value_column_max: ArrayRef = + Arc::new(Int32Array::from(self.value_column_maxs)); + + let index = RecordBatch::try_from_iter(vec![ + ("file_name", file_name), + ("file_size", file_size), + ("row_count", row_count), + ("value_column_min", value_column_min), + ("value_column_max", value_column_max), + ])?; + + Ok(ParquetMetadataIndex { + file_schema, + index, + last_num_pruned: AtomicUsize::new(0), + }) + } +} + +/// Return a list of the directory entries in the given directory, sorted by name +fn read_dir(dir: &Path) -> Result> { + let mut files = dir + .read_dir() + .map_err(|e| { + DataFusionError::from(e).context(format!("Error reading directory {dir:?}")) + })? + .map(|entry| { + entry.map_err(|e| { + DataFusionError::from(e) + .context(format!("Error reading directory entry in {dir:?}")) + }) + }) + .collect::>>()?; + files.sort_by_key(|entry| entry.file_name()); + Ok(files) +} + +/// Demonstration Data +/// +/// Makes a directory with three parquet files +/// +/// The schema of the files is +/// * file_name (string) +/// * value (int32) +/// +/// The files are as follows: +/// * file1.parquet (values 0..100) +/// * file2.parquet (values 100..200) +/// * file3.parquet (values 200..3000) +struct DemoData { + tmpdir: TempDir, +} + +impl DemoData { + fn try_new() -> Result { + let tmpdir = TempDir::new()?; + make_demo_file(tmpdir.path().join("file1.parquet"), 0..100)?; + make_demo_file(tmpdir.path().join("file2.parquet"), 100..200)?; + make_demo_file(tmpdir.path().join("file3.parquet"), 200..3000)?; + + Ok(Self { tmpdir }) + } + + fn path(&self) -> PathBuf { + self.tmpdir.path().into() + } +} + +/// Creates a new parquet file at the specified path. +/// +/// The `value` column increases sequentially from `min_value` to `max_value` +/// with the following schema: +/// +/// * file_name: Utf8 +/// * value: Int32 +fn make_demo_file(path: impl AsRef, value_range: Range) -> Result<()> { + let path = path.as_ref(); + let file = File::create(path)?; + let filename = path + .file_name() + .ok_or_else(|| internal_datafusion_err!("No filename"))? + .to_str() + .ok_or_else(|| internal_datafusion_err!("Invalid filename"))?; + + let num_values = value_range.len(); + let file_names = + StringArray::from_iter_values(std::iter::repeat(&filename).take(num_values)); + let values = Int32Array::from_iter_values(value_range); + let batch = RecordBatch::try_from_iter(vec![ + ("file_name", Arc::new(file_names) as ArrayRef), + ("value", Arc::new(values) as ArrayRef), + ])?; + + let schema = batch.schema(); + + // write the actual values to the file + let props = None; + let mut writer = ArrowWriter::try_new(file, schema, props)?; + writer.write(&batch)?; + writer.finish()?; + + Ok(()) +} diff --git a/datafusion-examples/examples/parquet_sql.rs b/datafusion-examples/examples/parquet_sql.rs deleted file mode 100644 index fb438a7832cb..000000000000 --- a/datafusion-examples/examples/parquet_sql.rs +++ /dev/null @@ -1,51 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use datafusion::error::Result; -use datafusion::prelude::*; - -/// This example demonstrates executing a simple query against an Arrow data source (Parquet) and -/// fetching results -#[tokio::main] -async fn main() -> Result<()> { - // create local session context - let ctx = SessionContext::new(); - - let testdata = datafusion::test_util::parquet_test_data(); - - // register parquet file with the execution context - ctx.register_parquet( - "alltypes_plain", - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) - .await?; - - // execute the query - let df = ctx - .sql( - "SELECT int_col, double_col, CAST(date_string_col as VARCHAR) \ - FROM alltypes_plain \ - WHERE id > 1 AND tinyint_col < double_col", - ) - .await?; - - // print the results - df.show().await?; - - Ok(()) -} diff --git a/datafusion-examples/examples/parquet_sql_multiple_files.rs b/datafusion-examples/examples/parquet_sql_multiple_files.rs index 30ca1df73d91..b0d3922a3278 100644 --- a/datafusion-examples/examples/parquet_sql_multiple_files.rs +++ b/datafusion-examples/examples/parquet_sql_multiple_files.rs @@ -80,7 +80,7 @@ async fn main() -> Result<(), Box> { let local_fs = Arc::new(LocalFileSystem::default()); let u = url::Url::parse("file://./")?; - ctx.runtime_env().register_object_store(&u, local_fs); + ctx.register_object_store(&u, local_fs); // Register a listing table - this will use all files in the directory as data sources // for the query diff --git a/datafusion-examples/examples/parse_sql_expr.rs b/datafusion-examples/examples/parse_sql_expr.rs new file mode 100644 index 000000000000..e23e5accae39 --- /dev/null +++ b/datafusion-examples/examples/parse_sql_expr.rs @@ -0,0 +1,166 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion::{ + assert_batches_eq, + error::Result, + prelude::{ParquetReadOptions, SessionContext}, +}; +use datafusion_common::DFSchema; +use datafusion_expr::{col, lit}; +use datafusion_sql::unparser::Unparser; + +/// This example demonstrates the programmatic parsing of SQL expressions using +/// the DataFusion [`SessionContext::parse_sql_expr`] API or the [`DataFrame::parse_sql_expr`] API. +/// +/// +/// The code in this example shows how to: +/// +/// 1. [`simple_session_context_parse_sql_expr_demo`]: Parse a simple SQL text into a logical +/// expression using a schema at [`SessionContext`]. +/// +/// 2. [`simple_dataframe_parse_sql_expr_demo`]: Parse a simple SQL text into a logical expression +/// using a schema at [`DataFrame`]. +/// +/// 3. [`query_parquet_demo`]: Query a parquet file using the parsed_sql_expr from a DataFrame. +/// +/// 4. [`round_trip_parse_sql_expr_demo`]: Parse a SQL text and convert it back to SQL using [`Unparser`]. + +#[tokio::main] +async fn main() -> Result<()> { + // See how to evaluate expressions + simple_session_context_parse_sql_expr_demo()?; + simple_dataframe_parse_sql_expr_demo().await?; + query_parquet_demo().await?; + round_trip_parse_sql_expr_demo().await?; + Ok(()) +} + +/// DataFusion can parse a SQL text to a logical expression against a schema at [`SessionContext`]. +fn simple_session_context_parse_sql_expr_demo() -> Result<()> { + let sql = "a < 5 OR a = 8"; + let expr = col("a").lt(lit(5_i64)).or(col("a").eq(lit(8_i64))); + + // provide type information that `a` is an Int32 + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let df_schema = DFSchema::try_from(schema).unwrap(); + let ctx = SessionContext::new(); + + let parsed_expr = ctx.parse_sql_expr(sql, &df_schema)?; + + assert_eq!(parsed_expr, expr); + + Ok(()) +} + +/// DataFusion can parse a SQL text to an logical expression using schema at [`DataFrame`]. +async fn simple_dataframe_parse_sql_expr_demo() -> Result<()> { + let sql = "int_col < 5 OR double_col = 8.0"; + let expr = col("int_col") + .lt(lit(5_i64)) + .or(col("double_col").eq(lit(8.0_f64))); + + let ctx = SessionContext::new(); + let testdata = datafusion::test_util::parquet_test_data(); + let df = ctx + .read_parquet( + &format!("{testdata}/alltypes_plain.parquet"), + ParquetReadOptions::default(), + ) + .await?; + + let parsed_expr = df.parse_sql_expr(sql)?; + + assert_eq!(parsed_expr, expr); + + Ok(()) +} + +async fn query_parquet_demo() -> Result<()> { + let ctx = SessionContext::new(); + let testdata = datafusion::test_util::parquet_test_data(); + let df = ctx + .read_parquet( + &format!("{testdata}/alltypes_plain.parquet"), + ParquetReadOptions::default(), + ) + .await?; + + let df = df + .clone() + .select(vec![ + df.parse_sql_expr("int_col")?, + df.parse_sql_expr("double_col")?, + ])? + .filter(df.parse_sql_expr("int_col < 5 OR double_col = 8.0")?)? + .aggregate( + vec![df.parse_sql_expr("double_col")?], + vec![df.parse_sql_expr("SUM(int_col) as sum_int_col")?], + )? + // Directly parsing the SQL text into a sort expression is not supported yet, so + // construct it programmatically + .sort(vec![col("double_col").sort(false, false)])? + .limit(0, Some(1))?; + + let result = df.collect().await?; + + assert_batches_eq!( + &[ + "+------------+----------------------+", + "| double_col | sum(?table?.int_col) |", + "+------------+----------------------+", + "| 10.1 | 4 |", + "+------------+----------------------+", + ], + &result + ); + + Ok(()) +} + +/// DataFusion can parse a SQL text and convert it back to SQL using [`Unparser`]. +async fn round_trip_parse_sql_expr_demo() -> Result<()> { + let sql = "((int_col < 5) OR (double_col = 8))"; + + let ctx = SessionContext::new(); + let testdata = datafusion::test_util::parquet_test_data(); + let df = ctx + .read_parquet( + &format!("{testdata}/alltypes_plain.parquet"), + ParquetReadOptions::default(), + ) + .await?; + + let parsed_expr = df.parse_sql_expr(sql)?; + + let unparser = Unparser::default(); + let round_trip_sql = unparser.expr_to_sql(&parsed_expr)?.to_string(); + + assert_eq!(sql, round_trip_sql); + + // enable pretty-unparsing. This make the output more human-readable + // but can be problematic when passed to other SQL engines due to + // difference in precedence rules between DataFusion and target engines. + let unparser = Unparser::default().with_pretty(true); + + let pretty = "int_col < 5 OR double_col = 8"; + let pretty_round_trip_sql = unparser.expr_to_sql(&parsed_expr)?.to_string(); + assert_eq!(pretty, pretty_round_trip_sql); + + Ok(()) +} diff --git a/datafusion-examples/examples/plan_to_sql.rs b/datafusion-examples/examples/plan_to_sql.rs new file mode 100644 index 000000000000..8ea7c2951223 --- /dev/null +++ b/datafusion-examples/examples/plan_to_sql.rs @@ -0,0 +1,154 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::error::Result; + +use datafusion::prelude::*; +use datafusion::sql::unparser::expr_to_sql; +use datafusion_sql::unparser::dialect::CustomDialectBuilder; +use datafusion_sql::unparser::{plan_to_sql, Unparser}; + +/// This example demonstrates the programmatic construction of SQL strings using +/// the DataFusion Expr [`Expr`] and LogicalPlan [`LogicalPlan`] API. +/// +/// +/// The code in this example shows how to: +/// +/// 1. [`simple_expr_to_sql_demo`]: Create a simple expression [`Exprs`] with +/// fluent API and convert to sql suitable for passing to another database +/// +/// 2. [`simple_expr_to_pretty_sql_demo`] Create a simple expression +/// [`Exprs`] with fluent API and convert to sql without extra parentheses, +/// suitable for displaying to humans +/// +/// 3. [`simple_expr_to_sql_demo_escape_mysql_style`]" Create a simple +/// expression [`Exprs`] with fluent API and convert to sql escaping column +/// names in MySQL style. +/// +/// 4. [`simple_plan_to_sql_demo`]: Create a simple logical plan using the +/// DataFrames API and convert to sql string. +/// +/// 5. [`round_trip_plan_to_sql_demo`]: Create a logical plan from a SQL string, modify it using the +/// DataFrames API and convert it back to a sql string. + +#[tokio::main] +async fn main() -> Result<()> { + // See how to evaluate expressions + simple_expr_to_sql_demo()?; + simple_expr_to_pretty_sql_demo()?; + simple_expr_to_sql_demo_escape_mysql_style()?; + simple_plan_to_sql_demo().await?; + round_trip_plan_to_sql_demo().await?; + Ok(()) +} + +/// DataFusion can convert expressions to SQL, using column name escaping +/// PostgreSQL style. +fn simple_expr_to_sql_demo() -> Result<()> { + let expr = col("a").lt(lit(5)).or(col("a").eq(lit(8))); + let sql = expr_to_sql(&expr)?.to_string(); + assert_eq!(sql, r#"((a < 5) OR (a = 8))"#); + Ok(()) +} + +/// DataFusioon can remove parentheses when converting an expression to SQL. +/// Note that output is intended for humans, not for other SQL engines, +/// as difference in precedence rules can cause expressions to be parsed differently. +fn simple_expr_to_pretty_sql_demo() -> Result<()> { + let expr = col("a").lt(lit(5)).or(col("a").eq(lit(8))); + let unparser = Unparser::default().with_pretty(true); + let sql = unparser.expr_to_sql(&expr)?.to_string(); + assert_eq!(sql, r#"a < 5 OR a = 8"#); + Ok(()) +} + +/// DataFusion can convert expressions to SQL without escaping column names using +/// using a custom dialect and an explicit unparser +fn simple_expr_to_sql_demo_escape_mysql_style() -> Result<()> { + let expr = col("a").lt(lit(5)).or(col("a").eq(lit(8))); + let dialect = CustomDialectBuilder::new() + .with_identifier_quote_style('`') + .build(); + let unparser = Unparser::new(&dialect); + let sql = unparser.expr_to_sql(&expr)?.to_string(); + assert_eq!(sql, r#"((`a` < 5) OR (`a` = 8))"#); + Ok(()) +} + +/// DataFusion can convert a logic plan created using the DataFrames API to read from a parquet file +/// to SQL, using column name escaping PostgreSQL style. +async fn simple_plan_to_sql_demo() -> Result<()> { + let ctx = SessionContext::new(); + + let testdata = datafusion::test_util::parquet_test_data(); + let df = ctx + .read_parquet( + &format!("{testdata}/alltypes_plain.parquet"), + ParquetReadOptions::default(), + ) + .await? + .select_columns(&["id", "int_col", "double_col", "date_string_col"])?; + + // Convert the data frame to a SQL string + let sql = plan_to_sql(df.logical_plan())?.to_string(); + + assert_eq!( + sql, + r#"SELECT "?table?".id, "?table?".int_col, "?table?".double_col, "?table?".date_string_col FROM "?table?""# + ); + + Ok(()) +} + +/// DataFusion can also be used to parse SQL, programmatically modify the query +/// (in this case adding a filter) and then and converting back to SQL. +async fn round_trip_plan_to_sql_demo() -> Result<()> { + let ctx = SessionContext::new(); + + let testdata = datafusion::test_util::parquet_test_data(); + + // register parquet file with the execution context + ctx.register_parquet( + "alltypes_plain", + &format!("{testdata}/alltypes_plain.parquet"), + ParquetReadOptions::default(), + ) + .await?; + + // create a logical plan from a SQL string and then programmatically add new filters + let df = ctx + // Use SQL to read some data from the parquet file + .sql( + "SELECT int_col, double_col, CAST(date_string_col as VARCHAR) \ + FROM alltypes_plain", + ) + .await? + // Add id > 1 and tinyint_col < double_col filter + .filter( + col("id") + .gt(lit(1)) + .and(col("tinyint_col").lt(col("double_col"))), + )?; + + let sql = plan_to_sql(df.logical_plan())?.to_string(); + assert_eq!( + sql, + r#"SELECT alltypes_plain.int_col, alltypes_plain.double_col, CAST(alltypes_plain.date_string_col AS VARCHAR) FROM alltypes_plain WHERE ((alltypes_plain.id > 1) AND (alltypes_plain.tinyint_col < alltypes_plain.double_col))"# + ); + + Ok(()) +} diff --git a/datafusion-examples/examples/planner_api.rs b/datafusion-examples/examples/planner_api.rs new file mode 100644 index 000000000000..35cf766ba1af --- /dev/null +++ b/datafusion-examples/examples/planner_api.rs @@ -0,0 +1,143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::error::Result; +use datafusion::physical_plan::displayable; +use datafusion::physical_planner::DefaultPhysicalPlanner; +use datafusion::prelude::*; +use datafusion_expr::{LogicalPlan, PlanType}; + +/// This example demonstrates the process of converting logical plan +/// into physical execution plans using DataFusion. +/// +/// Planning phase in DataFusion contains several steps: +/// 1. Analyzing and optimizing logical plan +/// 2. Converting logical plan into physical plan +/// +/// The code in this example shows two ways to convert a logical plan into +/// physical plan: +/// - Via the combined `create_physical_plan` API. +/// - Utilizing the analyzer, optimizer, and query planner APIs separately. +#[tokio::main] +async fn main() -> Result<()> { + // Set up a DataFusion context and load a Parquet file + let ctx = SessionContext::new(); + let testdata = datafusion::test_util::parquet_test_data(); + let df = ctx + .read_parquet( + &format!("{testdata}/alltypes_plain.parquet"), + ParquetReadOptions::default(), + ) + .await?; + + // Construct the input logical plan using DataFrame API + let df = df + .clone() + .select(vec![ + df.parse_sql_expr("int_col")?, + df.parse_sql_expr("double_col")?, + ])? + .filter(df.parse_sql_expr("int_col < 5 OR double_col = 8.0")?)? + .aggregate( + vec![df.parse_sql_expr("double_col")?], + vec![df.parse_sql_expr("SUM(int_col) as sum_int_col")?], + )? + .limit(0, Some(1))?; + let logical_plan = df.logical_plan().clone(); + + to_physical_plan_in_one_api_demo(&logical_plan, &ctx).await?; + + to_physical_plan_step_by_step_demo(logical_plan, &ctx).await?; + + Ok(()) +} + +/// Converts a logical plan into a physical plan using the combined +/// `create_physical_plan` API. It will first optimize the logical +/// plan and then convert it into physical plan. +async fn to_physical_plan_in_one_api_demo( + input: &LogicalPlan, + ctx: &SessionContext, +) -> Result<()> { + let physical_plan = ctx.state().create_physical_plan(input).await?; + + println!( + "Physical plan direct from logical plan:\n\n{}\n\n", + displayable(physical_plan.as_ref()) + .to_stringified(false, PlanType::InitialPhysicalPlan) + .plan + ); + + Ok(()) +} + +/// Converts a logical plan into a physical plan by utilizing the analyzer, +/// optimizer, and query planner APIs separately. This flavor gives more +/// control over the planning process. +async fn to_physical_plan_step_by_step_demo( + input: LogicalPlan, + ctx: &SessionContext, +) -> Result<()> { + // First analyze the logical plan + let analyzed_logical_plan = ctx.state().analyzer().execute_and_check( + input, + ctx.state().config_options(), + |_, _| (), + )?; + println!("Analyzed logical plan:\n\n{:?}\n\n", analyzed_logical_plan); + + // Optimize the analyzed logical plan + let optimized_logical_plan = ctx.state().optimizer().optimize( + analyzed_logical_plan, + &ctx.state(), + |_, _| (), + )?; + println!( + "Optimized logical plan:\n\n{:?}\n\n", + optimized_logical_plan + ); + + // Create the physical plan + let physical_plan = ctx + .state() + .query_planner() + .create_physical_plan(&optimized_logical_plan, &ctx.state()) + .await?; + println!( + "Final physical plan:\n\n{}\n\n", + displayable(physical_plan.as_ref()) + .to_stringified(false, PlanType::InitialPhysicalPlan) + .plan + ); + + // Call the physical optimizer with an existing physical plan (in this + // case the plan is already optimized, but an unoptimized plan would + // typically be used in this context) + // Note that this is not part of the trait but a public method + // on DefaultPhysicalPlanner. Not all planners will provide this feature. + let planner = DefaultPhysicalPlanner::default(); + let physical_plan = + planner.optimize_physical_plan(physical_plan, &ctx.state(), |_, _| {})?; + println!( + "Optimized physical plan:\n\n{}\n\n", + displayable(physical_plan.as_ref()) + .to_stringified(false, PlanType::InitialPhysicalPlan) + .plan + ); + + Ok(()) +} diff --git a/datafusion-examples/examples/pruning.rs b/datafusion-examples/examples/pruning.rs index 3fa35049a8da..c090cd2bcca9 100644 --- a/datafusion-examples/examples/pruning.rs +++ b/datafusion-examples/examples/pruning.rs @@ -33,6 +33,11 @@ use std::sync::Arc; /// quickly eliminate entire files / partitions / row groups of data from /// consideration using statistical information from a catalog or other /// metadata. +/// +/// This example uses a user defined catalog to supply pruning information, as +/// one might do as part of a higher level storage engine. See +/// `parquet_index.rs` for an example that uses pruning in the context of an +/// individual query. #[tokio::main] async fn main() { // In this example, we'll use the PruningPredicate to determine if diff --git a/datafusion-examples/examples/query-http-csv.rs b/datafusion-examples/examples/query-http-csv.rs index 928d70271159..fa3fd2ac068d 100644 --- a/datafusion-examples/examples/query-http-csv.rs +++ b/datafusion-examples/examples/query-http-csv.rs @@ -34,8 +34,7 @@ async fn main() -> Result<()> { .with_url(base_url.clone()) .build() .unwrap(); - ctx.runtime_env() - .register_object_store(&base_url, Arc::new(http_store)); + ctx.register_object_store(&base_url, Arc::new(http_store)); // register csv file with the execution context ctx.register_csv( diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs deleted file mode 100644 index 9b94a71a501c..000000000000 --- a/datafusion-examples/examples/rewrite_expr.rs +++ /dev/null @@ -1,255 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{plan_err, Result, ScalarValue}; -use datafusion_expr::{ - AggregateUDF, Between, Expr, Filter, LogicalPlan, ScalarUDF, TableSource, WindowUDF, -}; -use datafusion_optimizer::analyzer::{Analyzer, AnalyzerRule}; -use datafusion_optimizer::optimizer::Optimizer; -use datafusion_optimizer::{utils, OptimizerConfig, OptimizerContext, OptimizerRule}; -use datafusion_sql::planner::{ContextProvider, SqlToRel}; -use datafusion_sql::sqlparser::dialect::PostgreSqlDialect; -use datafusion_sql::sqlparser::parser::Parser; -use datafusion_sql::TableReference; -use std::any::Any; -use std::sync::Arc; - -pub fn main() -> Result<()> { - // produce a logical plan using the datafusion-sql crate - let dialect = PostgreSqlDialect {}; - let sql = "SELECT * FROM person WHERE age BETWEEN 21 AND 32"; - let statements = Parser::parse_sql(&dialect, sql)?; - - // produce a logical plan using the datafusion-sql crate - let context_provider = MyContextProvider::default(); - let sql_to_rel = SqlToRel::new(&context_provider); - let logical_plan = sql_to_rel.sql_statement_to_plan(statements[0].clone())?; - println!( - "Unoptimized Logical Plan:\n\n{}\n", - logical_plan.display_indent() - ); - - // run the analyzer with our custom rule - let config = OptimizerContext::default().with_skip_failing_rules(false); - let analyzer = Analyzer::with_rules(vec![Arc::new(MyAnalyzerRule {})]); - let analyzed_plan = - analyzer.execute_and_check(&logical_plan, config.options(), |_, _| {})?; - println!( - "Analyzed Logical Plan:\n\n{}\n", - analyzed_plan.display_indent() - ); - - // then run the optimizer with our custom rule - let optimizer = Optimizer::with_rules(vec![Arc::new(MyOptimizerRule {})]); - let optimized_plan = optimizer.optimize(analyzed_plan, &config, observe)?; - println!( - "Optimized Logical Plan:\n\n{}\n", - optimized_plan.display_indent() - ); - - Ok(()) -} - -fn observe(plan: &LogicalPlan, rule: &dyn OptimizerRule) { - println!( - "After applying rule '{}':\n{}\n", - rule.name(), - plan.display_indent() - ) -} - -/// An example analyzer rule that changes Int64 literals to UInt64 -struct MyAnalyzerRule {} - -impl AnalyzerRule for MyAnalyzerRule { - fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result { - Self::analyze_plan(plan) - } - - fn name(&self) -> &str { - "my_analyzer_rule" - } -} - -impl MyAnalyzerRule { - fn analyze_plan(plan: LogicalPlan) -> Result { - plan.transform(|plan| { - Ok(match plan { - LogicalPlan::Filter(filter) => { - let predicate = Self::analyze_expr(filter.predicate.clone())?; - Transformed::yes(LogicalPlan::Filter(Filter::try_new( - predicate, - filter.input, - )?)) - } - _ => Transformed::no(plan), - }) - }) - .data() - } - - fn analyze_expr(expr: Expr) -> Result { - expr.transform(|expr| { - // closure is invoked for all sub expressions - Ok(match expr { - Expr::Literal(ScalarValue::Int64(i)) => { - // transform to UInt64 - Transformed::yes(Expr::Literal(ScalarValue::UInt64( - i.map(|i| i as u64), - ))) - } - _ => Transformed::no(expr), - }) - }) - .data() - } -} - -/// An example optimizer rule that rewrite BETWEEN expression to binary compare expressions -struct MyOptimizerRule {} - -impl OptimizerRule for MyOptimizerRule { - fn name(&self) -> &str { - "my_optimizer_rule" - } - - fn try_optimize( - &self, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, - ) -> Result> { - // recurse down and optimize children first - let optimized_plan = utils::optimize_children(self, plan, config)?; - match optimized_plan { - Some(LogicalPlan::Filter(filter)) => { - let predicate = my_rewrite(filter.predicate.clone())?; - Ok(Some(LogicalPlan::Filter(Filter::try_new( - predicate, - filter.input, - )?))) - } - Some(optimized_plan) => Ok(Some(optimized_plan)), - None => match plan { - LogicalPlan::Filter(filter) => { - let predicate = my_rewrite(filter.predicate.clone())?; - Ok(Some(LogicalPlan::Filter(Filter::try_new( - predicate, - filter.input.clone(), - )?))) - } - _ => Ok(None), - }, - } - } -} - -/// use rewrite_expr to modify the expression tree. -fn my_rewrite(expr: Expr) -> Result { - expr.transform(|expr| { - // closure is invoked for all sub expressions - Ok(match expr { - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { - // unbox - let expr: Expr = *expr; - let low: Expr = *low; - let high: Expr = *high; - if negated { - Transformed::yes(expr.clone().lt(low).or(expr.gt(high))) - } else { - Transformed::yes(expr.clone().gt_eq(low).and(expr.lt_eq(high))) - } - } - _ => Transformed::no(expr), - }) - }) - .data() -} - -#[derive(Default)] -struct MyContextProvider { - options: ConfigOptions, -} - -impl ContextProvider for MyContextProvider { - fn get_table_source(&self, name: TableReference) -> Result> { - if name.table() == "person" { - Ok(Arc::new(MyTableSource { - schema: Arc::new(Schema::new(vec![ - Field::new("name", DataType::Utf8, false), - Field::new("age", DataType::UInt8, false), - ])), - })) - } else { - plan_err!("table not found") - } - } - - fn get_function_meta(&self, _name: &str) -> Option> { - None - } - - fn get_aggregate_meta(&self, _name: &str) -> Option> { - None - } - - fn get_variable_type(&self, _variable_names: &[String]) -> Option { - None - } - - fn get_window_meta(&self, _name: &str) -> Option> { - None - } - - fn options(&self) -> &ConfigOptions { - &self.options - } - - fn udfs_names(&self) -> Vec { - Vec::new() - } - - fn udafs_names(&self) -> Vec { - Vec::new() - } - - fn udwfs_names(&self) -> Vec { - Vec::new() - } -} - -struct MyTableSource { - schema: SchemaRef, -} - -impl TableSource for MyTableSource { - fn as_any(&self) -> &dyn Any { - self - } - - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/simple_udaf.rs index 140fc0d3572d..ef97bf9763b0 100644 --- a/datafusion-examples/examples/simple_udaf.rs +++ b/datafusion-examples/examples/simple_udaf.rs @@ -131,7 +131,7 @@ impl Accumulator for GeometricMean { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } diff --git a/datafusion-examples/examples/simple_udf.rs b/datafusion-examples/examples/simple_udf.rs index 64cf7857e223..6879a17f34be 100644 --- a/datafusion-examples/examples/simple_udf.rs +++ b/datafusion-examples/examples/simple_udf.rs @@ -109,7 +109,7 @@ async fn main() -> Result<()> { // expects two f64 vec![DataType::Float64, DataType::Float64], // returns f64 - Arc::new(DataType::Float64), + DataType::Float64, Volatility::Immutable, pow, ); diff --git a/datafusion-examples/examples/simple_udtf.rs b/datafusion-examples/examples/simple_udtf.rs index c68c21fab169..6faa397ef60f 100644 --- a/datafusion-examples/examples/simple_udtf.rs +++ b/datafusion-examples/examples/simple_udtf.rs @@ -20,10 +20,11 @@ use arrow::csv::ReaderBuilder; use async_trait::async_trait; use datafusion::arrow::datatypes::SchemaRef; use datafusion::arrow::record_batch::RecordBatch; +use datafusion::catalog::Session; use datafusion::datasource::function::TableFunctionImpl; use datafusion::datasource::TableProvider; use datafusion::error::Result; -use datafusion::execution::context::{ExecutionProps, SessionState}; +use datafusion::execution::context::ExecutionProps; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionContext; @@ -35,7 +36,6 @@ use std::fs::File; use std::io::Seek; use std::path::Path; use std::sync::Arc; - // To define your own table function, you only need to do the following 3 things: // 1. Implement your own [`TableProvider`] // 2. Implement your own [`TableFunctionImpl`] and return your [`TableProvider`] @@ -73,6 +73,7 @@ async fn main() -> Result<()> { /// Usage: `read_csv(filename, [limit])` /// /// [`read_csv`]: https://duckdb.org/docs/data/csv/overview.html +#[derive(Debug)] struct LocalCsvTable { schema: SchemaRef, limit: Option, @@ -95,7 +96,7 @@ impl TableProvider for LocalCsvTable { async fn scan( &self, - _state: &SessionState, + _state: &dyn Session, projection: Option<&Vec>, _filters: &[Expr], _limit: Option, @@ -127,6 +128,7 @@ impl TableProvider for LocalCsvTable { } } +#[derive(Debug)] struct LocalCsvTableFunc {} impl TableFunctionImpl for LocalCsvTableFunc { diff --git a/datafusion-examples/examples/simple_udwf.rs b/datafusion-examples/examples/simple_udwf.rs index 95339eff1cae..22dfbbbf0c3a 100644 --- a/datafusion-examples/examples/simple_udwf.rs +++ b/datafusion-examples/examples/simple_udwf.rs @@ -118,12 +118,12 @@ async fn main() -> Result<()> { df.show().await?; // Now, run the function using the DataFrame API: - let window_expr = smooth_it.call( - vec![col("speed")], // smooth_it(speed) - vec![col("car")], // PARTITION BY car - vec![col("time").sort(true, true)], // ORDER BY time ASC - WindowFrame::new(None), - ); + let window_expr = smooth_it + .call(vec![col("speed")]) // smooth_it(speed) + .partition_by(vec![col("car")]) // PARTITION BY car + .order_by(vec![col("time").sort(true, true)]) // ORDER BY time ASC + .window_frame(WindowFrame::new(None)) + .build()?; let df = ctx.table("cars").await?.window(vec![window_expr])?; // print the results @@ -132,7 +132,7 @@ async fn main() -> Result<()> { Ok(()) } -/// Create a `PartitionEvalutor` to evaluate this function on a new +/// Create a `PartitionEvaluator` to evaluate this function on a new /// partition. fn make_partition_evaluator() -> Result> { Ok(Box::new(MyPartitionEvaluator::new())) diff --git a/datafusion-examples/examples/simplify_udaf_expression.rs b/datafusion-examples/examples/simplify_udaf_expression.rs new file mode 100644 index 000000000000..52a27317e3c3 --- /dev/null +++ b/datafusion-examples/examples/simplify_udaf_expression.rs @@ -0,0 +1,176 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{any::Any, sync::Arc}; + +use arrow_schema::{Field, Schema}; + +use datafusion::arrow::{array::Float32Array, record_batch::RecordBatch}; +use datafusion::error::Result; +use datafusion::functions_aggregate::average::avg_udaf; +use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; +use datafusion::{assert_batches_eq, prelude::*}; +use datafusion_common::cast::as_float64_array; +use datafusion_expr::function::{AggregateFunctionSimplification, StateFieldsArgs}; +use datafusion_expr::simplify::SimplifyInfo; +use datafusion_expr::{ + expr::AggregateFunction, function::AccumulatorArgs, Accumulator, AggregateUDF, + AggregateUDFImpl, GroupsAccumulator, Signature, +}; + +/// This example shows how to use the AggregateUDFImpl::simplify API to simplify/replace user +/// defined aggregate function with a different expression which is defined in the `simplify` method. + +#[derive(Debug, Clone)] +struct BetterAvgUdaf { + signature: Signature, +} + +impl BetterAvgUdaf { + /// Create a new instance of the GeoMeanUdaf struct + fn new() -> Self { + Self { + signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for BetterAvgUdaf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "better_avg" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + unimplemented!("should not be invoked") + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + unimplemented!("should not be invoked") + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + unimplemented!("should not get here"); + } + + // we override method, to return new expression which would substitute + // user defined function call + fn simplify(&self) -> Option { + // as an example for this functionality we replace UDF function + // with build-in aggregate function to illustrate the use + let simplify = |aggregate_function: AggregateFunction, _: &dyn SimplifyInfo| { + Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + avg_udaf(), + // yes it is the same Avg, `BetterAvgUdaf` was just a + // marketing pitch :) + aggregate_function.args, + aggregate_function.distinct, + aggregate_function.filter, + aggregate_function.order_by, + aggregate_function.null_treatment, + ))) + }; + + Some(Box::new(simplify)) + } +} + +// create local session context with an in-memory table +fn create_context() -> Result { + use datafusion::datasource::MemTable; + // define a schema. + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float32, false), + Field::new("b", DataType::Float32, false), + ])); + + // define data in two partitions + let batch1 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])), + Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])), + ], + )?; + let batch2 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Float32Array::from(vec![16.0])), + Arc::new(Float32Array::from(vec![2.0])), + ], + )?; + + let ctx = SessionContext::new(); + + // declare a table in memory. In spark API, this corresponds to createDataFrame(...). + let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?; + ctx.register_table("t", Arc::new(provider))?; + Ok(ctx) +} + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = create_context()?; + + let better_avg = AggregateUDF::from(BetterAvgUdaf::new()); + ctx.register_udaf(better_avg.clone()); + + let result = ctx + .sql("SELECT better_avg(a) FROM t group by b") + .await? + .collect() + .await?; + + let expected = [ + "+-----------------+", + "| better_avg(t.a) |", + "+-----------------+", + "| 7.5 |", + "+-----------------+", + ]; + + assert_batches_eq!(expected, &result); + + let df = ctx.table("t").await?; + let df = df.aggregate(vec![], vec![better_avg.call(vec![col("a")])])?; + + let results = df.collect().await?; + let result = as_float64_array(results[0].column(0))?; + + assert!((result.value(0) - 7.5).abs() < f64::EPSILON); + println!("The average of [2,4,8,16] is {}", result.value(0)); + + Ok(()) +} diff --git a/datafusion-examples/examples/simplify_udwf_expression.rs b/datafusion-examples/examples/simplify_udwf_expression.rs new file mode 100644 index 000000000000..117063df4e0d --- /dev/null +++ b/datafusion-examples/examples/simplify_udwf_expression.rs @@ -0,0 +1,133 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; + +use arrow_schema::{DataType, Field}; + +use datafusion::execution::context::SessionContext; +use datafusion::functions_aggregate::average::avg_udaf; +use datafusion::{error::Result, execution::options::CsvReadOptions}; +use datafusion_expr::function::{WindowFunctionSimplification, WindowUDFFieldArgs}; +use datafusion_expr::{ + expr::WindowFunction, simplify::SimplifyInfo, Expr, PartitionEvaluator, Signature, + Volatility, WindowUDF, WindowUDFImpl, +}; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; + +/// This UDWF will show how to use the WindowUDFImpl::simplify() API +#[derive(Debug, Clone)] +struct SimplifySmoothItUdf { + signature: Signature, +} + +impl SimplifySmoothItUdf { + fn new() -> Self { + Self { + signature: Signature::exact( + // this function will always take one arguments of type f64 + vec![DataType::Float64], + // this function is deterministic and will always return the same + // result for the same input + Volatility::Immutable, + ), + } + } +} +impl WindowUDFImpl for SimplifySmoothItUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "simplify_smooth_it" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + todo!() + } + + /// this function will simplify `SimplifySmoothItUdf` to `SmoothItUdf`. + fn simplify(&self) -> Option { + let simplify = |window_function: WindowFunction, _: &dyn SimplifyInfo| { + Ok(Expr::WindowFunction(WindowFunction { + fun: datafusion_expr::WindowFunctionDefinition::AggregateUDF(avg_udaf()), + args: window_function.args, + partition_by: window_function.partition_by, + order_by: window_function.order_by, + window_frame: window_function.window_frame, + null_treatment: window_function.null_treatment, + })) + }; + + Some(Box::new(simplify)) + } + + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::Float64, true)) + } +} + +// create local execution context with `cars.csv` registered as a table named `cars` +async fn create_context() -> Result { + // declare a new context. In spark API, this corresponds to a new spark SQL session + let ctx = SessionContext::new(); + + // declare a table in memory. In spark API, this corresponds to createDataFrame(...). + println!("pwd: {}", std::env::current_dir().unwrap().display()); + let csv_path = "../../datafusion/core/tests/data/cars.csv".to_string(); + let read_options = CsvReadOptions::default().has_header(true); + + ctx.register_csv("cars", &csv_path, read_options).await?; + Ok(ctx) +} + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = create_context().await?; + let simplify_smooth_it = WindowUDF::from(SimplifySmoothItUdf::new()); + ctx.register_udwf(simplify_smooth_it.clone()); + + // Use SQL to run the new window function + let df = ctx.sql("SELECT * from cars").await?; + // print the results + df.show().await?; + + let df = ctx + .sql( + "SELECT \ + car, \ + speed, \ + simplify_smooth_it(speed) OVER (PARTITION BY car ORDER BY time) AS smooth_speed,\ + time \ + from cars \ + ORDER BY \ + car", + ) + .await?; + // print the results + df.show().await?; + + Ok(()) +} diff --git a/datafusion-examples/examples/sql_analysis.rs b/datafusion-examples/examples/sql_analysis.rs new file mode 100644 index 000000000000..2158b8e4b016 --- /dev/null +++ b/datafusion-examples/examples/sql_analysis.rs @@ -0,0 +1,309 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This example shows how to use the structures that DataFusion provides to perform +//! Analysis on SQL queries and their plans. +//! +//! As a motivating example, we show how to count the number of JOINs in a query +//! as well as how many join tree's there are with their respective join count + +use std::sync::Arc; + +use datafusion::common::Result; +use datafusion::{ + datasource::MemTable, + execution::context::{SessionConfig, SessionContext}, +}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion_expr::LogicalPlan; +use test_utils::tpcds::tpcds_schemas; + +/// Counts the total number of joins in a plan +fn total_join_count(plan: &LogicalPlan) -> usize { + let mut total = 0; + + // We can use the TreeNode API to walk over a LogicalPlan. + plan.apply(|node| { + // if we encounter a join we update the running count + if matches!(node, LogicalPlan::Join(_)) { + total += 1; + } + Ok(TreeNodeRecursion::Continue) + }) + .unwrap(); + + total +} + +/// Counts the total number of joins in a plan and collects every join tree in +/// the plan with their respective join count. +/// +/// Join Tree Definition: the largest subtree consisting entirely of joins +/// +/// For example, this plan: +/// +/// ```text +/// JOIN +/// / \ +/// A JOIN +/// / \ +/// B C +/// ``` +/// +/// has a single join tree `(A-B-C)` which will result in `(2, [2])` +/// +/// This plan: +/// +/// ```text +/// JOIN +/// / \ +/// A GROUP +/// | +/// JOIN +/// / \ +/// B C +/// ``` +/// +/// Has two join trees `(A-, B-C)` which will result in `(2, [1, 1])` +fn count_trees(plan: &LogicalPlan) -> (usize, Vec) { + // this works the same way as `total_count`, but now when we encounter a Join + // we try to collect it's entire tree + let mut to_visit = vec![plan]; + let mut total = 0; + let mut groups = vec![]; + + while let Some(node) = to_visit.pop() { + // if we encounter a join, we know were at the root of the tree + // count this tree and recurse on it's inputs + if matches!(node, LogicalPlan::Join(_)) { + let (group_count, inputs) = count_tree(node); + total += group_count; + groups.push(group_count); + to_visit.extend(inputs); + } else { + to_visit.extend(node.inputs()); + } + } + + (total, groups) +} + +/// Count the entire join tree and return its inputs using TreeNode API +/// +/// For example, if this function receives following plan: +/// +/// ```text +/// JOIN +/// / \ +/// A GROUP +/// | +/// JOIN +/// / \ +/// B C +/// ``` +/// +/// It will return `(1, [A, GROUP])` +fn count_tree(join: &LogicalPlan) -> (usize, Vec<&LogicalPlan>) { + let mut inputs = Vec::new(); + let mut total = 0; + + join.apply(|node| { + // Some extra knowledge: + // + // optimized plans have their projections pushed down as far as + // possible, which sometimes results in a projection going in between 2 + // subsequent joins giving the illusion these joins are not "related", + // when in fact they are. + // + // This plan: + // JOIN + // / \ + // A PROJECTION + // | + // JOIN + // / \ + // B C + // + // is the same as: + // + // JOIN + // / \ + // A JOIN + // / \ + // B C + // we can continue the recursion in this case + if let LogicalPlan::Projection(_) = node { + return Ok(TreeNodeRecursion::Continue); + } + + // any join we count + if matches!(node, LogicalPlan::Join(_)) { + total += 1; + Ok(TreeNodeRecursion::Continue) + } else { + inputs.push(node); + // skip children of input node + Ok(TreeNodeRecursion::Jump) + } + }) + .unwrap(); + + (total, inputs) +} + +#[tokio::main] +async fn main() -> Result<()> { + // To show how we can count the joins in a sql query we'll be using query 88 + // from the TPC-DS benchmark. + // + // q8 has many joins, cross-joins and multiple join-trees, perfect for our + // example: + + let tpcds_query_88 = " +select * +from + (select count(*) h8_30_to_9 + from store_sales, household_demographics , time_dim, store + where ss_sold_time_sk = time_dim.t_time_sk + and ss_hdemo_sk = household_demographics.hd_demo_sk + and ss_store_sk = s_store_sk + and time_dim.t_hour = 8 + and time_dim.t_minute >= 30 + and ((household_demographics.hd_dep_count = 3 and household_demographics.hd_vehicle_count<=3+2) or + (household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or + (household_demographics.hd_dep_count = 1 and household_demographics.hd_vehicle_count<=1+2)) + and store.s_store_name = 'ese') s1, + (select count(*) h9_to_9_30 + from store_sales, household_demographics , time_dim, store + where ss_sold_time_sk = time_dim.t_time_sk + and ss_hdemo_sk = household_demographics.hd_demo_sk + and ss_store_sk = s_store_sk + and time_dim.t_hour = 9 + and time_dim.t_minute < 30 + and ((household_demographics.hd_dep_count = 3 and household_demographics.hd_vehicle_count<=3+2) or + (household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or + (household_demographics.hd_dep_count = 1 and household_demographics.hd_vehicle_count<=1+2)) + and store.s_store_name = 'ese') s2, + (select count(*) h9_30_to_10 + from store_sales, household_demographics , time_dim, store + where ss_sold_time_sk = time_dim.t_time_sk + and ss_hdemo_sk = household_demographics.hd_demo_sk + and ss_store_sk = s_store_sk + and time_dim.t_hour = 9 + and time_dim.t_minute >= 30 + and ((household_demographics.hd_dep_count = 3 and household_demographics.hd_vehicle_count<=3+2) or + (household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or + (household_demographics.hd_dep_count = 1 and household_demographics.hd_vehicle_count<=1+2)) + and store.s_store_name = 'ese') s3, + (select count(*) h10_to_10_30 + from store_sales, household_demographics , time_dim, store + where ss_sold_time_sk = time_dim.t_time_sk + and ss_hdemo_sk = household_demographics.hd_demo_sk + and ss_store_sk = s_store_sk + and time_dim.t_hour = 10 + and time_dim.t_minute < 30 + and ((household_demographics.hd_dep_count = 3 and household_demographics.hd_vehicle_count<=3+2) or + (household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or + (household_demographics.hd_dep_count = 1 and household_demographics.hd_vehicle_count<=1+2)) + and store.s_store_name = 'ese') s4, + (select count(*) h10_30_to_11 + from store_sales, household_demographics , time_dim, store + where ss_sold_time_sk = time_dim.t_time_sk + and ss_hdemo_sk = household_demographics.hd_demo_sk + and ss_store_sk = s_store_sk + and time_dim.t_hour = 10 + and time_dim.t_minute >= 30 + and ((household_demographics.hd_dep_count = 3 and household_demographics.hd_vehicle_count<=3+2) or + (household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or + (household_demographics.hd_dep_count = 1 and household_demographics.hd_vehicle_count<=1+2)) + and store.s_store_name = 'ese') s5, + (select count(*) h11_to_11_30 + from store_sales, household_demographics , time_dim, store + where ss_sold_time_sk = time_dim.t_time_sk + and ss_hdemo_sk = household_demographics.hd_demo_sk + and ss_store_sk = s_store_sk + and time_dim.t_hour = 11 + and time_dim.t_minute < 30 + and ((household_demographics.hd_dep_count = 3 and household_demographics.hd_vehicle_count<=3+2) or + (household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or + (household_demographics.hd_dep_count = 1 and household_demographics.hd_vehicle_count<=1+2)) + and store.s_store_name = 'ese') s6, + (select count(*) h11_30_to_12 + from store_sales, household_demographics , time_dim, store + where ss_sold_time_sk = time_dim.t_time_sk + and ss_hdemo_sk = household_demographics.hd_demo_sk + and ss_store_sk = s_store_sk + and time_dim.t_hour = 11 + and time_dim.t_minute >= 30 + and ((household_demographics.hd_dep_count = 3 and household_demographics.hd_vehicle_count<=3+2) or + (household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or + (household_demographics.hd_dep_count = 1 and household_demographics.hd_vehicle_count<=1+2)) + and store.s_store_name = 'ese') s7, + (select count(*) h12_to_12_30 + from store_sales, household_demographics , time_dim, store + where ss_sold_time_sk = time_dim.t_time_sk + and ss_hdemo_sk = household_demographics.hd_demo_sk + and ss_store_sk = s_store_sk + and time_dim.t_hour = 12 + and time_dim.t_minute < 30 + and ((household_demographics.hd_dep_count = 3 and household_demographics.hd_vehicle_count<=3+2) or + (household_demographics.hd_dep_count = 0 and household_demographics.hd_vehicle_count<=0+2) or + (household_demographics.hd_dep_count = 1 and household_demographics.hd_vehicle_count<=1+2)) + and store.s_store_name = 'ese') s8;"; + + // first set up the config + let config = SessionConfig::default(); + let ctx = SessionContext::new_with_config(config); + + // register the tables of the TPC-DS query + let tables = tpcds_schemas(); + for table in tables { + ctx.register_table( + table.name, + Arc::new(MemTable::try_new(Arc::new(table.schema.clone()), vec![])?), + )?; + } + // We can create a LogicalPlan from a SQL query like this + let logical_plan = ctx.sql(tpcds_query_88).await?.into_optimized_plan()?; + + println!( + "Optimized Logical Plan:\n\n{}\n", + logical_plan.display_indent() + ); + // we can get the total count (query 88 has 31 joins: 7 CROSS joins and 24 INNER joins => 40 input relations) + let total_join_count = total_join_count(&logical_plan); + assert_eq!(31, total_join_count); + + println!("The plan has {total_join_count} joins."); + + // Furthermore the 24 inner joins are 8 groups of 3 joins with the 7 + // cross-joins combining them we can get these groups using the + // `count_trees` method + let (total_join_count, trees) = count_trees(&logical_plan); + assert_eq!( + (total_join_count, &trees), + // query 88 is very straightforward, we know the cross-join group is at + // the top of the plan followed by the INNER joins + (31, &vec![7, 3, 3, 3, 3, 3, 3, 3, 3]) + ); + + println!( + "And following join-trees (number represents join amount in tree): {trees:?}" + ); + + Ok(()) +} diff --git a/datafusion-examples/examples/sql_frontend.rs b/datafusion-examples/examples/sql_frontend.rs new file mode 100644 index 000000000000..839ee95eb181 --- /dev/null +++ b/datafusion-examples/examples/sql_frontend.rs @@ -0,0 +1,207 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion_common::config::ConfigOptions; +use datafusion_common::{plan_err, Result}; +use datafusion_expr::{ + AggregateUDF, Expr, LogicalPlan, ScalarUDF, TableProviderFilterPushDown, TableSource, + WindowUDF, +}; +use datafusion_optimizer::{ + Analyzer, AnalyzerRule, Optimizer, OptimizerConfig, OptimizerContext, OptimizerRule, +}; +use datafusion_sql::planner::{ContextProvider, SqlToRel}; +use datafusion_sql::sqlparser::dialect::PostgreSqlDialect; +use datafusion_sql::sqlparser::parser::Parser; +use datafusion_sql::TableReference; +use std::any::Any; +use std::sync::Arc; + +/// This example shows how to use DataFusion's SQL planner to parse SQL text and +/// build `LogicalPlan`s without executing them. +/// +/// For example, if you need a SQL planner and optimizer like Apache Calcite, +/// but do not want a Java runtime dependency for some reason, you could use +/// DataFusion as a SQL frontend. +/// +/// Normally, users interact with DataFusion via SessionContext. However, using +/// SessionContext requires depending on the full `datafusion` core crate. +/// +/// In this example, we demonstrate how to use the lower level APIs directly, +/// which only requires the `datafusion-sql` dependency. +pub fn main() -> Result<()> { + // First, we parse the SQL string. Note that we use the DataFusion + // Parser, which wraps the `sqlparser-rs` SQL parser and adds DataFusion + // specific syntax such as `CREATE EXTERNAL TABLE` + let dialect = PostgreSqlDialect {}; + let sql = "SELECT name FROM person WHERE age BETWEEN 21 AND 32"; + let statements = Parser::parse_sql(&dialect, sql)?; + + // Now, use DataFusion's SQL planner, called `SqlToRel` to create a + // `LogicalPlan` from the parsed statement + // + // To invoke SqlToRel we must provide it schema and function information + // via an object that implements the `ContextProvider` trait + let context_provider = MyContextProvider::default(); + let sql_to_rel = SqlToRel::new(&context_provider); + let logical_plan = sql_to_rel.sql_statement_to_plan(statements[0].clone())?; + + // Here is the logical plan that was generated: + assert_eq!( + logical_plan.display_indent().to_string(), + "Projection: person.name\ + \n Filter: person.age BETWEEN Int64(21) AND Int64(32)\ + \n TableScan: person" + ); + + // The initial LogicalPlan is a mechanical translation from the parsed SQL + // and often can not run without the Analyzer passes. + // + // In this example, `person.age` is actually a different data type (Int8) + // than the values to which it is compared to which are Int64. Most + // execution engines, including DataFusion's, will fail if you provide such + // a plan. + // + // To prepare it to run, we must apply type coercion to align types, and + // check for other semantic errors. In DataFusion this is done by a + // component called the Analyzer. + let config = OptimizerContext::default().with_skip_failing_rules(false); + let analyzed_plan = Analyzer::new().execute_and_check( + logical_plan, + config.options(), + observe_analyzer, + )?; + // Note that the Analyzer has added a CAST to the plan to align the types + assert_eq!( + analyzed_plan.display_indent().to_string(), + "Projection: person.name\ + \n Filter: CAST(person.age AS Int64) BETWEEN Int64(21) AND Int64(32)\ + \n TableScan: person", + ); + + // As we can see, the Analyzer added a CAST so the types are the same + // (Int64). However, this plan is not as efficient as it could be, as it + // will require casting *each row* of the input to UInt64 before comparison + // to 21 and 32. To optimize this query's performance, it is better to cast + // the constants once at plan time to UInt8. + // + // Query optimization is handled in DataFusion by a component called the + // Optimizer, which we now invoke + // + let optimized_plan = + Optimizer::new().optimize(analyzed_plan, &config, observe_optimizer)?; + + // Show the fully optimized plan. Note that the optimizer did several things + // to prepare this plan for execution: + // + // 1. Removed casts from person.age as we described above + // 2. Converted BETWEEN to two single column inequalities (which are typically faster to execute) + // 3. Pushed the projection of `name` down to the scan (so the scan only returns that column) + // 4. Pushed the filter into the scan + // 5. Removed the projection as it was only serving to pass through the name column + assert_eq!( + optimized_plan.display_indent().to_string(), + "TableScan: person projection=[name], full_filters=[person.age >= UInt8(21), person.age <= UInt8(32)]" + ); + + Ok(()) +} + +// Note that both the optimizer and the analyzer take a callback, called an +// "observer" that is invoked after each pass. We do not do anything with these +// callbacks in this example + +fn observe_analyzer(_plan: &LogicalPlan, _rule: &dyn AnalyzerRule) {} +fn observe_optimizer(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} + +/// Implements the `ContextProvider` trait required to plan SQL +#[derive(Default)] +struct MyContextProvider { + options: ConfigOptions, +} + +impl ContextProvider for MyContextProvider { + fn get_table_source(&self, name: TableReference) -> Result> { + if name.table() == "person" { + Ok(Arc::new(MyTableSource { + schema: Arc::new(Schema::new(vec![ + Field::new("name", DataType::Utf8, false), + Field::new("age", DataType::UInt8, false), + ])), + })) + } else { + plan_err!("Table {} not found", name.table()) + } + } + + fn get_function_meta(&self, _name: &str) -> Option> { + None + } + + fn get_aggregate_meta(&self, _name: &str) -> Option> { + None + } + + fn get_variable_type(&self, _variable_names: &[String]) -> Option { + None + } + + fn get_window_meta(&self, _name: &str) -> Option> { + None + } + + fn options(&self) -> &ConfigOptions { + &self.options + } + + fn udf_names(&self) -> Vec { + Vec::new() + } + + fn udaf_names(&self) -> Vec { + Vec::new() + } + + fn udwf_names(&self) -> Vec { + Vec::new() + } +} + +/// TableSource is the part of TableProvider needed for creating a LogicalPlan. +struct MyTableSource { + schema: SchemaRef, +} + +impl TableSource for MyTableSource { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + // For this example, we report to the DataFusion optimizer that + // this provider can apply filters during the scan + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> Result> { + Ok(vec![TableProviderFilterPushDown::Exact; filters.len()]) + } +} diff --git a/datafusion/CHANGELOG.md b/datafusion/CHANGELOG.md index 6ab4801b8c60..71c6689c0cbd 100644 --- a/datafusion/CHANGELOG.md +++ b/datafusion/CHANGELOG.md @@ -19,40 +19,4 @@ # Changelog -- [37.1.0](../dev/changelog/37.1.0.md) -- [37.0.0](../dev/changelog/37.0.0.md) -- [36.0.0](../dev/changelog/36.0.0.md) -- [35.0.0](../dev/changelog/35.0.0.md) -- [34.0.0](../dev/changelog/34.0.0.md) -- [33.0.0](../dev/changelog/33.0.0.md) -- [32.0.0](../dev/changelog/32.0.0.md) -- [31.0.0](../dev/changelog/31.0.0.md) -- [30.0.0](../dev/changelog/30.0.0.md) -- [29.0.0](../dev/changelog/29.0.0.md) -- [28.0.0](../dev/changelog/28.0.0.md) -- [27.0.0](../dev/changelog/27.0.0.md) -- [26.0.0](../dev/changelog/26.0.0.md) -- [25.0.0](../dev/changelog/25.0.0.md) -- [24.0.0](../dev/changelog/24.0.0.md) -- [23.0.0](../dev/changelog/23.0.0.md) -- [22.0.0](../dev/changelog/22.0.0.md) -- [21.1.0](../dev/changelog/21.1.0.md) -- [21.0.0](../dev/changelog/21.0.0.md) -- [20.0.0](../dev/changelog/20.0.0.md) -- [19.0.0](../dev/changelog/19.0.0.md) -- [18.0.0](../dev/changelog/18.0.0.md) -- [17.0.0](../dev/changelog/17.0.0.md) -- [16.1.0](../dev/changelog/16.1.0.md) -- [16.0.0](../dev/changelog/16.0.0.md) -- [15.0.0](../dev/changelog/15.0.0.md) -- [14.0.0](../dev/changelog/14.0.0.md) -- [13.0.0](../dev/changelog/13.0.0.md) -- [12.0.0](../dev/changelog/12.0.0.md) -- [11.0.0](../dev/changelog/11.0.0.md) -- [10.0.0](../dev/changelog/10.0.0.md) -- [9.0.0](../dev/changelog/9.0.0.md) -- [8.0.0](../dev/changelog/8.0.0.md) -- [7.1.0](../dev/changelog/7.1.0.md) -- [7.0.0](../dev/changelog/7.0.0.md) -- [6.0.0](../dev/changelog/6.0.0.md) -- [5.0.0](../dev/changelog/5.0.0.md) +Change logs for each release can be found [here](https://github.com/apache/datafusion/tree/main/dev/changelog). diff --git a/datafusion/catalog/Cargo.toml b/datafusion/catalog/Cargo.toml new file mode 100644 index 000000000000..f9801352087d --- /dev/null +++ b/datafusion/catalog/Cargo.toml @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-catalog" +description = "datafusion-catalog" +authors.workspace = true +edition.workspace = true +homepage.workspace = true +license.workspace = true +readme.workspace = true +repository.workspace = true +rust-version.workspace = true +version.workspace = true + +[dependencies] +arrow-schema = { workspace = true } +async-trait = { workspace = true } +datafusion-common = { workspace = true } +datafusion-execution = { workspace = true } +datafusion-expr = { workspace = true } +datafusion-physical-plan = { workspace = true } +parking_lot = { workspace = true } + +[lints] +workspace = true diff --git a/datafusion/catalog/README.md b/datafusion/catalog/README.md new file mode 100644 index 000000000000..5b201e736fdc --- /dev/null +++ b/datafusion/catalog/README.md @@ -0,0 +1,26 @@ + + +# DataFusion Catalog + +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. + +This crate is a submodule of DataFusion that provides catalog management functionality, including catalogs, schemas, and tables. + +[df]: https://crates.io/crates/datafusion diff --git a/datafusion/core/src/catalog/mod.rs b/datafusion/catalog/src/catalog.rs similarity index 50% rename from datafusion/core/src/catalog/mod.rs rename to datafusion/catalog/src/catalog.rs index 209d9b2af297..048a7f14ed37 100644 --- a/datafusion/core/src/catalog/mod.rs +++ b/datafusion/catalog/src/catalog.rs @@ -15,90 +15,13 @@ // specific language governing permissions and limitations // under the License. -//! Interfaces and default implementations of catalogs and schemas. - -pub mod information_schema; -pub mod listing_schema; -pub mod schema; - -pub use datafusion_sql::{ResolvedTableReference, TableReference}; - -use crate::catalog::schema::SchemaProvider; -use dashmap::DashMap; -use datafusion_common::{exec_err, not_impl_err, Result}; use std::any::Any; +use std::fmt::Debug; use std::sync::Arc; -/// Represent a list of named [`CatalogProvider`]s. -/// -/// Please see the documentation on `CatalogProvider` for details of -/// implementing a custom catalog. -pub trait CatalogProviderList: Sync + Send { - /// Returns the catalog list as [`Any`] - /// so that it can be downcast to a specific implementation. - fn as_any(&self) -> &dyn Any; - - /// Adds a new catalog to this catalog list - /// If a catalog of the same name existed before, it is replaced in the list and returned. - fn register_catalog( - &self, - name: String, - catalog: Arc, - ) -> Option>; - - /// Retrieves the list of available catalog names - fn catalog_names(&self) -> Vec; - - /// Retrieves a specific catalog by name, provided it exists. - fn catalog(&self, name: &str) -> Option>; -} - -/// See [`CatalogProviderList`] -#[deprecated(since = "35.0.0", note = "use [`CatalogProviderList`] instead")] -pub trait CatalogList: CatalogProviderList {} - -/// Simple in-memory list of catalogs -pub struct MemoryCatalogProviderList { - /// Collection of catalogs containing schemas and ultimately TableProviders - pub catalogs: DashMap>, -} - -impl MemoryCatalogProviderList { - /// Instantiates a new `MemoryCatalogProviderList` with an empty collection of catalogs - pub fn new() -> Self { - Self { - catalogs: DashMap::new(), - } - } -} - -impl Default for MemoryCatalogProviderList { - fn default() -> Self { - Self::new() - } -} - -impl CatalogProviderList for MemoryCatalogProviderList { - fn as_any(&self) -> &dyn Any { - self - } - - fn register_catalog( - &self, - name: String, - catalog: Arc, - ) -> Option> { - self.catalogs.insert(name, catalog) - } - - fn catalog_names(&self) -> Vec { - self.catalogs.iter().map(|c| c.key().clone()).collect() - } - - fn catalog(&self, name: &str) -> Option> { - self.catalogs.get(name).map(|c| c.value().clone()) - } -} +pub use crate::schema::SchemaProvider; +use datafusion_common::not_impl_err; +use datafusion_common::Result; /// Represents a catalog, comprising a number of named schemas. /// @@ -112,18 +35,16 @@ impl CatalogProviderList for MemoryCatalogProviderList { /// * [`CatalogProviderList`]: a collection of `CatalogProvider`s /// * [`CatalogProvider`]: a collection of `SchemaProvider`s (sometimes called a "database" in other systems) /// * [`SchemaProvider`]: a collection of `TableProvider`s (often called a "schema" in other systems) -/// * [`TableProvider]`: individual tables +/// * [`TableProvider`]: individual tables /// /// # Implementing Catalogs /// /// To implement a catalog, you implement at least one of the [`CatalogProviderList`], /// [`CatalogProvider`] and [`SchemaProvider`] traits and register them -/// appropriately the [`SessionContext`]. -/// -/// [`SessionContext`]: crate::execution::context::SessionContext +/// appropriately in the `SessionContext`. /// /// DataFusion comes with a simple in-memory catalog implementation, -/// [`MemoryCatalogProvider`], that is used by default and has no persistence. +/// `MemoryCatalogProvider`, that is used by default and has no persistence. /// DataFusion does not include more complex Catalog implementations because /// catalog management is a key design choice for most data systems, and thus /// it is unlikely that any general-purpose catalog implementation will work @@ -157,23 +78,21 @@ impl CatalogProviderList for MemoryCatalogProviderList { /// access required to read table details (e.g. statistics). /// /// The pattern that DataFusion itself uses to plan SQL queries is to walk over -/// the query to [find all schema / table references in an `async` function], +/// the query to find all table references, /// performing required remote catalog in parallel, and then plans the query /// using that snapshot. /// -/// [find all schema / table references in an `async` function]: crate::execution::context::SessionState::resolve_table_references -/// /// # Example Catalog Implementations /// /// Here are some examples of how to implement custom catalogs: /// /// * [`datafusion-cli`]: [`DynamicFileCatalogProvider`] catalog provider -/// that treats files and directories on a filesystem as tables. +/// that treats files and directories on a filesystem as tables. /// /// * The [`catalog.rs`]: a simple directory based catalog. /// -/// * [delta-rs]: [`UnityCatalogProvider`] implementation that can -/// read from Delta Lake tables +/// * [delta-rs]: [`UnityCatalogProvider`] implementation that can +/// read from Delta Lake tables /// /// [`datafusion-cli`]: https://datafusion.apache.org/user-guide/cli/index.html /// [`DynamicFileCatalogProvider`]: https://github.com/apache/datafusion/blob/31b9b48b08592b7d293f46e75707aad7dadd7cbc/datafusion-cli/src/catalog.rs#L75 @@ -181,9 +100,9 @@ impl CatalogProviderList for MemoryCatalogProviderList { /// [delta-rs]: https://github.com/delta-io/delta-rs /// [`UnityCatalogProvider`]: https://github.com/delta-io/delta-rs/blob/951436ecec476ce65b5ed3b58b50fb0846ca7b91/crates/deltalake-core/src/data_catalog/unity/datafusion.rs#L111-L123 /// -/// [`TableProvider]: crate::datasource::TableProvider +/// [`TableProvider`]: crate::TableProvider -pub trait CatalogProvider: Sync + Send { +pub trait CatalogProvider: Debug + Sync + Send { /// Returns the catalog provider as [`Any`] /// so that it can be downcast to a specific implementation. fn as_any(&self) -> &dyn Any; @@ -213,7 +132,7 @@ pub trait CatalogProvider: Sync + Send { /// Removes a schema from this catalog. Implementations of this method should return /// errors if the schema exists but cannot be dropped. For example, in DataFusion's - /// default in-memory catalog, [`MemoryCatalogProvider`], a non-empty schema + /// default in-memory catalog, `MemoryCatalogProvider`, a non-empty schema /// will only be successfully dropped when `cascade` is true. /// This is equivalent to how DROP SCHEMA works in PostgreSQL. /// @@ -230,137 +149,26 @@ pub trait CatalogProvider: Sync + Send { } } -/// Simple in-memory implementation of a catalog. -pub struct MemoryCatalogProvider { - schemas: DashMap>, -} - -impl MemoryCatalogProvider { - /// Instantiates a new MemoryCatalogProvider with an empty collection of schemas. - pub fn new() -> Self { - Self { - schemas: DashMap::new(), - } - } -} - -impl Default for MemoryCatalogProvider { - fn default() -> Self { - Self::new() - } -} - -impl CatalogProvider for MemoryCatalogProvider { - fn as_any(&self) -> &dyn Any { - self - } - - fn schema_names(&self) -> Vec { - self.schemas.iter().map(|s| s.key().clone()).collect() - } - - fn schema(&self, name: &str) -> Option> { - self.schemas.get(name).map(|s| s.value().clone()) - } - - fn register_schema( - &self, - name: &str, - schema: Arc, - ) -> Result>> { - Ok(self.schemas.insert(name.into(), schema)) - } +/// Represent a list of named [`CatalogProvider`]s. +/// +/// Please see the documentation on `CatalogProvider` for details of +/// implementing a custom catalog. +pub trait CatalogProviderList: Debug + Sync + Send { + /// Returns the catalog list as [`Any`] + /// so that it can be downcast to a specific implementation. + fn as_any(&self) -> &dyn Any; - fn deregister_schema( + /// Adds a new catalog to this catalog list + /// If a catalog of the same name existed before, it is replaced in the list and returned. + fn register_catalog( &self, - name: &str, - cascade: bool, - ) -> Result>> { - if let Some(schema) = self.schema(name) { - let table_names = schema.table_names(); - match (table_names.is_empty(), cascade) { - (true, _) | (false, true) => { - let (_, removed) = self.schemas.remove(name).unwrap(); - Ok(Some(removed)) - } - (false, false) => exec_err!( - "Cannot drop schema {} because other tables depend on it: {}", - name, - itertools::join(table_names.iter(), ", ") - ), - } - } else { - Ok(None) - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::catalog::schema::MemorySchemaProvider; - use crate::datasource::empty::EmptyTable; - use crate::datasource::TableProvider; - use arrow::datatypes::Schema; - - #[test] - fn default_register_schema_not_supported() { - // mimic a new CatalogProvider and ensure it does not support registering schemas - struct TestProvider {} - impl CatalogProvider for TestProvider { - fn as_any(&self) -> &dyn Any { - self - } - - fn schema_names(&self) -> Vec { - unimplemented!() - } - - fn schema(&self, _name: &str) -> Option> { - unimplemented!() - } - } - - let schema = Arc::new(MemorySchemaProvider::new()) as Arc; - let catalog = Arc::new(TestProvider {}); - - match catalog.register_schema("foo", schema) { - Ok(_) => panic!("unexpected OK"), - Err(e) => assert_eq!(e.strip_backtrace(), "This feature is not implemented: Registering new schemas is not supported"), - }; - } - - #[test] - fn memory_catalog_dereg_nonempty_schema() { - let cat = Arc::new(MemoryCatalogProvider::new()) as Arc; - - let schema = Arc::new(MemorySchemaProvider::new()) as Arc; - let test_table = Arc::new(EmptyTable::new(Arc::new(Schema::empty()))) - as Arc; - schema.register_table("t".into(), test_table).unwrap(); - - cat.register_schema("foo", schema.clone()).unwrap(); - - assert!( - cat.deregister_schema("foo", false).is_err(), - "dropping empty schema without cascade should error" - ); - assert!(cat.deregister_schema("foo", true).unwrap().is_some()); - } - - #[test] - fn memory_catalog_dereg_empty_schema() { - let cat = Arc::new(MemoryCatalogProvider::new()) as Arc; - - let schema = Arc::new(MemorySchemaProvider::new()) as Arc; - cat.register_schema("foo", schema).unwrap(); + name: String, + catalog: Arc, + ) -> Option>; - assert!(cat.deregister_schema("foo", false).unwrap().is_some()); - } + /// Retrieves the list of available catalog names + fn catalog_names(&self) -> Vec; - #[test] - fn memory_catalog_dereg_missing() { - let cat = Arc::new(MemoryCatalogProvider::new()) as Arc; - assert!(cat.deregister_schema("foo", false).unwrap().is_none()); - } + /// Retrieves a specific catalog by name, provided it exists. + fn catalog(&self, name: &str) -> Option>; } diff --git a/datafusion/catalog/src/dynamic_file/catalog.rs b/datafusion/catalog/src/dynamic_file/catalog.rs new file mode 100644 index 000000000000..ccccb9762eb4 --- /dev/null +++ b/datafusion/catalog/src/dynamic_file/catalog.rs @@ -0,0 +1,187 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`DynamicFileCatalog`] that creates tables from file paths + +use crate::{CatalogProvider, CatalogProviderList, SchemaProvider, TableProvider}; +use async_trait::async_trait; +use std::any::Any; +use std::fmt::Debug; +use std::sync::Arc; + +/// Wrap another catalog provider list +#[derive(Debug)] +pub struct DynamicFileCatalog { + /// The inner catalog provider list + inner: Arc, + /// The factory that can create a table provider from the file path + factory: Arc, +} + +impl DynamicFileCatalog { + pub fn new( + inner: Arc, + factory: Arc, + ) -> Self { + Self { inner, factory } + } +} + +impl CatalogProviderList for DynamicFileCatalog { + fn as_any(&self) -> &dyn Any { + self + } + + fn register_catalog( + &self, + name: String, + catalog: Arc, + ) -> Option> { + self.inner.register_catalog(name, catalog) + } + + fn catalog_names(&self) -> Vec { + self.inner.catalog_names() + } + + fn catalog(&self, name: &str) -> Option> { + self.inner.catalog(name).map(|catalog| { + Arc::new(DynamicFileCatalogProvider::new( + catalog, + Arc::clone(&self.factory), + )) as _ + }) + } +} + +/// Wraps another catalog provider +#[derive(Debug)] +struct DynamicFileCatalogProvider { + /// The inner catalog provider + inner: Arc, + /// The factory that can create a table provider from the file path + factory: Arc, +} + +impl DynamicFileCatalogProvider { + pub fn new( + inner: Arc, + factory: Arc, + ) -> Self { + Self { inner, factory } + } +} + +impl CatalogProvider for DynamicFileCatalogProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema_names(&self) -> Vec { + self.inner.schema_names() + } + + fn schema(&self, name: &str) -> Option> { + self.inner.schema(name).map(|schema| { + Arc::new(DynamicFileSchemaProvider::new( + schema, + Arc::clone(&self.factory), + )) as _ + }) + } + + fn register_schema( + &self, + name: &str, + schema: Arc, + ) -> datafusion_common::Result>> { + self.inner.register_schema(name, schema) + } +} + +/// Implements the [DynamicFileSchemaProvider] that can create tables provider from the file path. +/// +/// The provider will try to create a table provider from the file path if the table provider +/// isn't exist in the inner schema provider. +#[derive(Debug)] +pub struct DynamicFileSchemaProvider { + /// The inner schema provider + inner: Arc, + /// The factory that can create a table provider from the file path + factory: Arc, +} + +impl DynamicFileSchemaProvider { + /// Create a new [DynamicFileSchemaProvider] with the given inner schema provider. + pub fn new( + inner: Arc, + factory: Arc, + ) -> Self { + Self { inner, factory } + } +} + +#[async_trait] +impl SchemaProvider for DynamicFileSchemaProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn table_names(&self) -> Vec { + self.inner.table_names() + } + + async fn table( + &self, + name: &str, + ) -> datafusion_common::Result>> { + if let Some(table) = self.inner.table(name).await? { + return Ok(Some(table)); + }; + + self.factory.try_new(name).await + } + + fn register_table( + &self, + name: String, + table: Arc, + ) -> datafusion_common::Result>> { + self.inner.register_table(name, table) + } + + fn deregister_table( + &self, + name: &str, + ) -> datafusion_common::Result>> { + self.inner.deregister_table(name) + } + + fn table_exist(&self, name: &str) -> bool { + self.inner.table_exist(name) + } +} + +/// [UrlTableFactory] is a factory that can create a table provider from the given url. +#[async_trait] +pub trait UrlTableFactory: Debug + Sync + Send { + /// create a new table provider from the provided url + async fn try_new( + &self, + url: &str, + ) -> datafusion_common::Result>>; +} diff --git a/docs/src/lib.rs b/datafusion/catalog/src/dynamic_file/mod.rs similarity index 95% rename from docs/src/lib.rs rename to datafusion/catalog/src/dynamic_file/mod.rs index f73132468ec9..59142333dd54 100644 --- a/docs/src/lib.rs +++ b/datafusion/catalog/src/dynamic_file/mod.rs @@ -15,5 +15,4 @@ // specific language governing permissions and limitations // under the License. -#[cfg(test)] -mod library_logical_plan; +pub(crate) mod catalog; diff --git a/datafusion/catalog/src/lib.rs b/datafusion/catalog/src/lib.rs new file mode 100644 index 000000000000..21630f267d2c --- /dev/null +++ b/datafusion/catalog/src/lib.rs @@ -0,0 +1,28 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod catalog; +mod dynamic_file; +mod schema; +mod session; +mod table; + +pub use catalog::*; +pub use dynamic_file::catalog::*; +pub use schema::*; +pub use session::*; +pub use table::*; diff --git a/datafusion/catalog/src/schema.rs b/datafusion/catalog/src/schema.rs new file mode 100644 index 000000000000..5b37348fd742 --- /dev/null +++ b/datafusion/catalog/src/schema.rs @@ -0,0 +1,82 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Describes the interface and built-in implementations of schemas, +//! representing collections of named tables. + +use async_trait::async_trait; +use datafusion_common::{exec_err, DataFusionError}; +use std::any::Any; +use std::fmt::Debug; +use std::sync::Arc; + +use crate::table::TableProvider; +use datafusion_common::Result; + +/// Represents a schema, comprising a number of named tables. +/// +/// Please see [`CatalogProvider`] for details of implementing a custom catalog. +/// +/// [`CatalogProvider`]: super::CatalogProvider +#[async_trait] +pub trait SchemaProvider: Debug + Sync + Send { + /// Returns the owner of the Schema, default is None. This value is reported + /// as part of `information_tables.schemata + fn owner_name(&self) -> Option<&str> { + None + } + + /// Returns this `SchemaProvider` as [`Any`] so that it can be downcast to a + /// specific implementation. + fn as_any(&self) -> &dyn Any; + + /// Retrieves the list of available table names in this schema. + fn table_names(&self) -> Vec; + + /// Retrieves a specific table from the schema by name, if it exists, + /// otherwise returns `None`. + async fn table( + &self, + name: &str, + ) -> Result>, DataFusionError>; + + /// If supported by the implementation, adds a new table named `name` to + /// this schema. + /// + /// If a table of the same name was already registered, returns "Table + /// already exists" error. + #[allow(unused_variables)] + fn register_table( + &self, + name: String, + table: Arc, + ) -> Result>> { + exec_err!("schema provider does not support registering tables") + } + + /// If supported by the implementation, removes the `name` table from this + /// schema and returns the previously registered [`TableProvider`], if any. + /// + /// If no `name` table exists, returns Ok(None). + #[allow(unused_variables)] + fn deregister_table(&self, name: &str) -> Result>> { + exec_err!("schema provider does not support deregistering tables") + } + + /// Returns true if table exist in the schema provider, false otherwise. + fn table_exist(&self, name: &str) -> bool; +} diff --git a/datafusion/catalog/src/session.rs b/datafusion/catalog/src/session.rs new file mode 100644 index 000000000000..db49529ac43f --- /dev/null +++ b/datafusion/catalog/src/session.rs @@ -0,0 +1,171 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use async_trait::async_trait; +use datafusion_common::config::ConfigOptions; +use datafusion_common::{DFSchema, Result}; +use datafusion_execution::config::SessionConfig; +use datafusion_execution::runtime_env::RuntimeEnv; +use datafusion_execution::TaskContext; +use datafusion_expr::execution_props::ExecutionProps; +use datafusion_expr::{AggregateUDF, Expr, LogicalPlan, ScalarUDF, WindowUDF}; +use datafusion_physical_plan::{ExecutionPlan, PhysicalExpr}; +use parking_lot::{Mutex, RwLock}; +use std::any::Any; +use std::collections::HashMap; +use std::sync::{Arc, Weak}; + +/// Interface for accessing [`SessionState`] from the catalog. +/// +/// This trait provides access to the information needed to plan and execute +/// queries, such as configuration, functions, and runtime environment. See the +/// documentation on [`SessionState`] for more information. +/// +/// Historically, the `SessionState` struct was passed directly to catalog +/// traits such as [`TableProvider`], which required a direct dependency on the +/// DataFusion core. The interface required is now defined by this trait. See +/// [#10782] for more details. +/// +/// [#10782]: https://github.com/apache/datafusion/issues/10782 +/// +/// # Migration from `SessionState` +/// +/// Using trait methods is preferred, as the implementation may change in future +/// versions. However, you can downcast a `Session` to a `SessionState` as shown +/// in the example below. If you find yourself needing to do this, please open +/// an issue on the DataFusion repository so we can extend the trait to provide +/// the required information. +/// +/// ``` +/// # use datafusion_catalog::Session; +/// # use datafusion_common::{Result, exec_datafusion_err}; +/// # struct SessionState {} +/// // Given a `Session` reference, get the concrete `SessionState` reference +/// // Note: this may stop working in future versions, +/// fn session_state_from_session(session: &dyn Session) -> Result<&SessionState> { +/// session.as_any() +/// .downcast_ref::() +/// .ok_or_else(|| exec_datafusion_err!("Failed to downcast Session to SessionState")) +/// } +/// ``` +/// +/// [`SessionState`]: https://docs.rs/datafusion/latest/datafusion/execution/session_state/struct.SessionState.html +/// [`TableProvider`]: crate::TableProvider +#[async_trait] +pub trait Session: Send + Sync { + /// Return the session ID + fn session_id(&self) -> &str; + + /// Return the [`SessionConfig`] + fn config(&self) -> &SessionConfig; + + /// return the [`ConfigOptions`] + fn config_options(&self) -> &ConfigOptions { + self.config().options() + } + + /// Creates a physical [`ExecutionPlan`] plan from a [`LogicalPlan`]. + /// + /// Note: this will optimize the provided plan first. + /// + /// This function will error for [`LogicalPlan`]s such as catalog DDL like + /// `CREATE TABLE`, which do not have corresponding physical plans and must + /// be handled by another layer, typically the `SessionContext`. + async fn create_physical_plan( + &self, + logical_plan: &LogicalPlan, + ) -> Result>; + + /// Create a [`PhysicalExpr`] from an [`Expr`] after applying type + /// coercion, and function rewrites. + /// + /// Note: The expression is not simplified or otherwise optimized: `a = 1 + /// + 2` will not be simplified to `a = 3` as this is a more involved process. + /// See the [expr_api] example for how to simplify expressions. + /// + /// [expr_api]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/expr_api.rs + fn create_physical_expr( + &self, + expr: Expr, + df_schema: &DFSchema, + ) -> Result>; + + /// Return reference to scalar_functions + fn scalar_functions(&self) -> &HashMap>; + + /// Return reference to aggregate_functions + fn aggregate_functions(&self) -> &HashMap>; + + /// Return reference to window functions + fn window_functions(&self) -> &HashMap>; + + /// Return the runtime env + fn runtime_env(&self) -> &Arc; + + /// Return the execution properties + fn execution_props(&self) -> &ExecutionProps; + + fn as_any(&self) -> &dyn Any; +} + +/// Create a new task context instance from Session +impl From<&dyn Session> for TaskContext { + fn from(state: &dyn Session) -> Self { + let task_id = None; + TaskContext::new( + task_id, + state.session_id().to_string(), + state.config().clone(), + state.scalar_functions().clone(), + state.aggregate_functions().clone(), + state.window_functions().clone(), + state.runtime_env().clone(), + ) + } +} +type SessionRefLock = Arc>>>>; +/// The state store that stores the reference of the runtime session state. +#[derive(Debug)] +pub struct SessionStore { + session: SessionRefLock, +} + +impl SessionStore { + /// Create a new [SessionStore] + pub fn new() -> Self { + Self { + session: Arc::new(Mutex::new(None)), + } + } + + /// Set the session state of the store + pub fn with_state(&self, state: Weak>) { + let mut lock = self.session.lock(); + *lock = Some(state); + } + + /// Get the current session of the store + pub fn get_session(&self) -> Weak> { + self.session.lock().clone().unwrap() + } +} + +impl Default for SessionStore { + fn default() -> Self { + Self::new() + } +} diff --git a/datafusion/catalog/src/table.rs b/datafusion/catalog/src/table.rs new file mode 100644 index 000000000000..ca3a2bef882e --- /dev/null +++ b/datafusion/catalog/src/table.rs @@ -0,0 +1,296 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::borrow::Cow; +use std::fmt::Debug; +use std::sync::Arc; + +use crate::session::Session; +use arrow_schema::SchemaRef; +use async_trait::async_trait; +use datafusion_common::Result; +use datafusion_common::{not_impl_err, Constraints, Statistics}; +use datafusion_expr::dml::InsertOp; +use datafusion_expr::{ + CreateExternalTable, Expr, LogicalPlan, TableProviderFilterPushDown, TableType, +}; +use datafusion_physical_plan::ExecutionPlan; + +/// Source table +#[async_trait] +pub trait TableProvider: Debug + Sync + Send { + /// Returns the table provider as [`Any`](std::any::Any) so that it can be + /// downcast to a specific implementation. + fn as_any(&self) -> &dyn Any; + + /// Get a reference to the schema for this table + fn schema(&self) -> SchemaRef; + + /// Get a reference to the constraints of the table. + /// Returns: + /// - `None` for tables that do not support constraints. + /// - `Some(&Constraints)` for tables supporting constraints. + /// Therefore, a `Some(&Constraints::empty())` return value indicates that + /// this table supports constraints, but there are no constraints. + fn constraints(&self) -> Option<&Constraints> { + None + } + + /// Get the type of this table for metadata/catalog purposes. + fn table_type(&self) -> TableType; + + /// Get the create statement used to create this table, if available. + fn get_table_definition(&self) -> Option<&str> { + None + } + + /// Get the [`LogicalPlan`] of this table, if available. + fn get_logical_plan(&self) -> Option> { + None + } + + /// Get the default value for a column, if available. + fn get_column_default(&self, _column: &str) -> Option<&Expr> { + None + } + + /// Create an [`ExecutionPlan`] for scanning the table with optionally + /// specified `projection`, `filter` and `limit`, described below. + /// + /// The `ExecutionPlan` is responsible scanning the datasource's + /// partitions in a streaming, parallelized fashion. + /// + /// # Projection + /// + /// If specified, only a subset of columns should be returned, in the order + /// specified. The projection is a set of indexes of the fields in + /// [`Self::schema`]. + /// + /// DataFusion provides the projection to scan only the columns actually + /// used in the query to improve performance, an optimization called + /// "Projection Pushdown". Some datasources, such as Parquet, can use this + /// information to go significantly faster when only a subset of columns is + /// required. + /// + /// # Filters + /// + /// A list of boolean filter [`Expr`]s to evaluate *during* the scan, in the + /// manner specified by [`Self::supports_filters_pushdown`]. Only rows for + /// which *all* of the `Expr`s evaluate to `true` must be returned (aka the + /// expressions are `AND`ed together). + /// + /// To enable filter pushdown you must override + /// [`Self::supports_filters_pushdown`] as the default implementation does + /// not and `filters` will be empty. + /// + /// DataFusion pushes filtering into the scans whenever possible + /// ("Filter Pushdown"), and depending on the format and the + /// implementation of the format, evaluating the predicate during the scan + /// can increase performance significantly. + /// + /// ## Note: Some columns may appear *only* in Filters + /// + /// In certain cases, a query may only use a certain column in a Filter that + /// has been completely pushed down to the scan. In this case, the + /// projection will not contain all the columns found in the filter + /// expressions. + /// + /// For example, given the query `SELECT t.a FROM t WHERE t.b > 5`, + /// + /// ```text + /// ┌────────────────────┐ + /// │ Projection(t.a) │ + /// └────────────────────┘ + /// ▲ + /// │ + /// │ + /// ┌────────────────────┐ Filter ┌────────────────────┐ Projection ┌────────────────────┐ + /// │ Filter(t.b > 5) │────Pushdown──▶ │ Projection(t.a) │ ───Pushdown───▶ │ Projection(t.a) │ + /// └────────────────────┘ └────────────────────┘ └────────────────────┘ + /// ▲ ▲ ▲ + /// │ │ │ + /// │ │ ┌────────────────────┐ + /// ┌────────────────────┐ ┌────────────────────┐ │ Scan │ + /// │ Scan │ │ Scan │ │ filter=(t.b > 5) │ + /// └────────────────────┘ │ filter=(t.b > 5) │ │ projection=(t.a) │ + /// └────────────────────┘ └────────────────────┘ + /// + /// Initial Plan If `TableProviderFilterPushDown` Projection pushdown notes that + /// returns true, filter pushdown the scan only needs t.a + /// pushes the filter into the scan + /// BUT internally evaluating the + /// predicate still requires t.b + /// ``` + /// + /// # Limit + /// + /// If `limit` is specified, must only produce *at least* this many rows, + /// (though it may return more). Like Projection Pushdown and Filter + /// Pushdown, DataFusion pushes `LIMIT`s as far down in the plan as + /// possible, called "Limit Pushdown" as some sources can use this + /// information to improve their performance. Note that if there are any + /// Inexact filters pushed down, the LIMIT cannot be pushed down. This is + /// because inexact filters do not guarantee that every filtered row is + /// removed, so applying the limit could lead to too few rows being available + /// to return as a final result. + async fn scan( + &self, + state: &dyn Session, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, + ) -> Result>; + + /// Specify if DataFusion should provide filter expressions to the + /// TableProvider to apply *during* the scan. + /// + /// Some TableProviders can evaluate filters more efficiently than the + /// `Filter` operator in DataFusion, for example by using an index. + /// + /// # Parameters and Return Value + /// + /// The return `Vec` must have one element for each element of the `filters` + /// argument. The value of each element indicates if the TableProvider can + /// apply the corresponding filter during the scan. The position in the return + /// value corresponds to the expression in the `filters` parameter. + /// + /// If the length of the resulting `Vec` does not match the `filters` input + /// an error will be thrown. + /// + /// Each element in the resulting `Vec` is one of the following: + /// * [`Exact`] or [`Inexact`]: The TableProvider can apply the filter + /// during scan + /// * [`Unsupported`]: The TableProvider cannot apply the filter during scan + /// + /// By default, this function returns [`Unsupported`] for all filters, + /// meaning no filters will be provided to [`Self::scan`]. + /// + /// [`Unsupported`]: TableProviderFilterPushDown::Unsupported + /// [`Exact`]: TableProviderFilterPushDown::Exact + /// [`Inexact`]: TableProviderFilterPushDown::Inexact + /// # Example + /// + /// ```rust + /// # use std::any::Any; + /// # use std::sync::Arc; + /// # use arrow_schema::SchemaRef; + /// # use async_trait::async_trait; + /// # use datafusion_catalog::{TableProvider, Session}; + /// # use datafusion_common::Result; + /// # use datafusion_expr::{Expr, TableProviderFilterPushDown, TableType}; + /// # use datafusion_physical_plan::ExecutionPlan; + /// // Define a struct that implements the TableProvider trait + /// #[derive(Debug)] + /// struct TestDataSource {} + /// + /// #[async_trait] + /// impl TableProvider for TestDataSource { + /// # fn as_any(&self) -> &dyn Any { todo!() } + /// # fn schema(&self) -> SchemaRef { todo!() } + /// # fn table_type(&self) -> TableType { todo!() } + /// # async fn scan(&self, s: &dyn Session, p: Option<&Vec>, f: &[Expr], l: Option) -> Result> { + /// todo!() + /// # } + /// // Override the supports_filters_pushdown to evaluate which expressions + /// // to accept as pushdown predicates. + /// fn supports_filters_pushdown(&self, filters: &[&Expr]) -> Result> { + /// // Process each filter + /// let support: Vec<_> = filters.iter().map(|expr| { + /// match expr { + /// // This example only supports a between expr with a single column named "c1". + /// Expr::Between(between_expr) => { + /// between_expr.expr + /// .try_as_col() + /// .map(|column| { + /// if column.name == "c1" { + /// TableProviderFilterPushDown::Exact + /// } else { + /// TableProviderFilterPushDown::Unsupported + /// } + /// }) + /// // If there is no column in the expr set the filter to unsupported. + /// .unwrap_or(TableProviderFilterPushDown::Unsupported) + /// } + /// _ => { + /// // For all other cases return Unsupported. + /// TableProviderFilterPushDown::Unsupported + /// } + /// } + /// }).collect(); + /// Ok(support) + /// } + /// } + /// ``` + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> Result> { + Ok(vec![ + TableProviderFilterPushDown::Unsupported; + filters.len() + ]) + } + + /// Get statistics for this table, if available + fn statistics(&self) -> Option { + None + } + + /// Return an [`ExecutionPlan`] to insert data into this table, if + /// supported. + /// + /// The returned plan should return a single row in a UInt64 + /// column called "count" such as the following + /// + /// ```text + /// +-------+, + /// | count |, + /// +-------+, + /// | 6 |, + /// +-------+, + /// ``` + /// + /// # See Also + /// + /// See [`DataSinkExec`] for the common pattern of inserting a + /// streams of `RecordBatch`es as files to an ObjectStore. + /// + /// [`DataSinkExec`]: datafusion_physical_plan::insert::DataSinkExec + async fn insert_into( + &self, + _state: &dyn Session, + _input: Arc, + _insert_op: InsertOp, + ) -> Result> { + not_impl_err!("Insert into not implemented for this table") + } +} + +/// A factory which creates [`TableProvider`]s at runtime given a URL. +/// +/// For example, this can be used to create a table "on the fly" +/// from a directory of files only when that name is referenced. +#[async_trait] +pub trait TableProviderFactory: Debug + Sync + Send { + /// Create a TableProvider with the given url + async fn create( + &self, + state: &dyn Session, + cmd: &CreateExternalTable, + ) -> Result>; +} diff --git a/datafusion/common-runtime/Cargo.toml b/datafusion/common-runtime/Cargo.toml index c10436087675..a21c72cd9f83 100644 --- a/datafusion/common-runtime/Cargo.toml +++ b/datafusion/common-runtime/Cargo.toml @@ -36,4 +36,8 @@ name = "datafusion_common_runtime" path = "src/lib.rs" [dependencies] +log = { workspace = true } tokio = { workspace = true } + +[dev-dependencies] +tokio = { version = "1.36", features = ["rt", "rt-multi-thread", "time"] } diff --git a/datafusion/common-runtime/src/common.rs b/datafusion/common-runtime/src/common.rs index 2f7ddb972f42..30f7526bc0b2 100644 --- a/datafusion/common-runtime/src/common.rs +++ b/datafusion/common-runtime/src/common.rs @@ -60,8 +60,8 @@ impl SpawnedTask { } /// Joins the task and unwinds the panic if it happens. - pub async fn join_unwind(self) -> R { - self.join().await.unwrap_or_else(|e| { + pub async fn join_unwind(self) -> Result { + self.join().await.map_err(|e| { // `JoinError` can be caused either by panic or cancellation. We have to handle panics: if e.is_panic() { std::panic::resume_unwind(e.into_panic()); @@ -69,9 +69,53 @@ impl SpawnedTask { // Cancellation may be caused by two reasons: // 1. Abort is called, but since we consumed `self`, it's not our case (`JoinHandle` not accessible outside). // 2. The runtime is shutting down. - // So we consider this branch as unreachable. - unreachable!("SpawnedTask was cancelled unexpectedly"); + log::warn!("SpawnedTask was polled during shutdown"); + e } }) } } + +#[cfg(test)] +mod tests { + use super::*; + + use std::future::{pending, Pending}; + + use tokio::runtime::Runtime; + + #[tokio::test] + async fn runtime_shutdown() { + let rt = Runtime::new().unwrap(); + let task = rt + .spawn(async { + SpawnedTask::spawn(async { + let fut: Pending<()> = pending(); + fut.await; + unreachable!("should never return"); + }) + }) + .await + .unwrap(); + + // caller shutdown their DF runtime (e.g. timeout, error in caller, etc) + rt.shutdown_background(); + + // race condition + // poll occurs during shutdown (buffered stream poll calls, etc) + assert!(matches!( + task.join_unwind().await, + Err(e) if e.is_cancelled() + )); + } + + #[tokio::test] + #[should_panic(expected = "foo")] + async fn panic_resume() { + // this should panic w/o an `unwrap` + SpawnedTask::spawn(async { panic!("foo") }) + .join_unwind() + .await + .ok(); + } +} diff --git a/datafusion/common-runtime/src/lib.rs b/datafusion/common-runtime/src/lib.rs index e8624163f224..8145bb110464 100644 --- a/datafusion/common-runtime/src/lib.rs +++ b/datafusion/common-runtime/src/lib.rs @@ -14,6 +14,8 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] pub mod common; diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index 2391b2f83087..0747672a18f6 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -39,11 +39,10 @@ path = "src/lib.rs" avro = ["apache-avro"] backtrace = [] pyarrow = ["pyo3", "arrow/pyarrow", "parquet"] +force_hash_collisions = [] [dependencies] -ahash = { version = "0.8", default-features = false, features = [ - "runtime-rng", -] } +ahash = { workspace = true } apache-avro = { version = "0.16", default-features = false, features = [ "bzip", "snappy", @@ -56,12 +55,16 @@ arrow-buffer = { workspace = true } arrow-schema = { workspace = true } chrono = { workspace = true } half = { workspace = true } +hashbrown = { workspace = true } +indexmap = { workspace = true } libc = "0.2.140" num_cpus = { workspace = true } object_store = { workspace = true, optional = true } parquet = { workspace = true, optional = true, default-features = true } -pyo3 = { version = "0.20.0", optional = true } +paste = "1.0.15" +pyo3 = { version = "0.22.0", optional = true } sqlparser = { workspace = true } +tokio = { workspace = true } [target.'cfg(target_family = "wasm")'.dependencies] instant = { version = "0.1", features = ["wasm-bindgen"] } diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs index 0dc0532bbb6f..0586fcf5e2ae 100644 --- a/datafusion/common/src/cast.rs +++ b/datafusion/common/src/cast.rs @@ -36,6 +36,7 @@ use arrow::{ }, datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType}, }; +use arrow_array::{BinaryViewArray, StringViewArray}; // Downcast ArrayRef to Date32Array pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array> { @@ -87,6 +88,11 @@ pub fn as_string_array(array: &dyn Array) -> Result<&StringArray> { Ok(downcast_value!(array, StringArray)) } +// Downcast ArrayRef to StringViewArray +pub fn as_string_view_array(array: &dyn Array) -> Result<&StringViewArray> { + Ok(downcast_value!(array, StringViewArray)) +} + // Downcast ArrayRef to UInt32Array pub fn as_uint32_array(array: &dyn Array) -> Result<&UInt32Array> { Ok(downcast_value!(array, UInt32Array)) @@ -221,6 +227,11 @@ pub fn as_binary_array(array: &dyn Array) -> Result<&BinaryArray> { Ok(downcast_value!(array, BinaryArray)) } +// Downcast ArrayRef to BinaryViewArray +pub fn as_binary_view_array(array: &dyn Array) -> Result<&BinaryViewArray> { + Ok(downcast_value!(array, BinaryViewArray)) +} + // Downcast ArrayRef to FixedSizeListArray pub fn as_fixed_size_list_array(array: &dyn Array) -> Result<&FixedSizeListArray> { Ok(downcast_value!(array, FixedSizeListArray)) diff --git a/datafusion/common/src/column.rs b/datafusion/common/src/column.rs index 911ff079def1..d855198fa7c6 100644 --- a/datafusion/common/src/column.rs +++ b/datafusion/common/src/column.rs @@ -26,7 +26,6 @@ use std::collections::HashSet; use std::convert::Infallible; use std::fmt; use std::str::FromStr; -use std::sync::Arc; /// A named reference to a qualified field in a schema. #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] @@ -109,24 +108,31 @@ impl Column { /// `foo.BAR` would be parsed to a reference to relation `foo`, column name `bar` (lower case) /// where `"foo.BAR"` would be parsed to a reference to column named `foo.BAR` pub fn from_qualified_name(flat_name: impl Into) -> Self { - let flat_name: &str = &flat_name.into(); - Self::from_idents(&mut parse_identifiers_normalized(flat_name, false)) + let flat_name = flat_name.into(); + Self::from_idents(&mut parse_identifiers_normalized(&flat_name, false)) .unwrap_or_else(|| Self { relation: None, - name: flat_name.to_owned(), + name: flat_name, }) } /// Deserialize a fully qualified name string into a column preserving column text case pub fn from_qualified_name_ignore_case(flat_name: impl Into) -> Self { - let flat_name: &str = &flat_name.into(); - Self::from_idents(&mut parse_identifiers_normalized(flat_name, true)) + let flat_name = flat_name.into(); + Self::from_idents(&mut parse_identifiers_normalized(&flat_name, true)) .unwrap_or_else(|| Self { relation: None, - name: flat_name.to_owned(), + name: flat_name, }) } + /// return the column's name. + /// + /// Note: This ignores the relation and returns the column name only. + pub fn name(&self) -> &str { + &self.name + } + /// Serialize column into a flat name string pub fn flat_name(&self) -> String { match &self.relation { @@ -149,79 +155,6 @@ impl Column { } } - /// Qualify column if not done yet. - /// - /// If this column already has a [relation](Self::relation), it will be returned as is and the given parameters are - /// ignored. Otherwise this will search through the given schemas to find the column. This will use the first schema - /// that matches. - /// - /// A schema matches if there is a single column that -- when unqualified -- matches this column. There is an - /// exception for `USING` statements, see below. - /// - /// # Using columns - /// Take the following SQL statement: - /// - /// ```sql - /// SELECT id FROM t1 JOIN t2 USING(id) - /// ``` - /// - /// In this case, both `t1.id` and `t2.id` will match unqualified column `id`. To express this possibility, use - /// `using_columns`. Each entry in this array is a set of columns that are bound together via a `USING` clause. So - /// in this example this would be `[{t1.id, t2.id}]`. - #[deprecated( - since = "20.0.0", - note = "use normalize_with_schemas_and_ambiguity_check instead" - )] - pub fn normalize_with_schemas( - self, - schemas: &[&Arc], - using_columns: &[HashSet], - ) -> Result { - if self.relation.is_some() { - return Ok(self); - } - - for schema in schemas { - let qualified_fields = - schema.qualified_fields_with_unqualified_name(&self.name); - match qualified_fields.len() { - 0 => continue, - 1 => { - return Ok(Column::from(qualified_fields[0])); - } - _ => { - // More than 1 fields in this schema have their names set to self.name. - // - // This should only happen when a JOIN query with USING constraint references - // join columns using unqualified column name. For example: - // - // ```sql - // SELECT id FROM t1 JOIN t2 USING(id) - // ``` - // - // In this case, both `t1.id` and `t2.id` will match unqualified column `id`. - // We will use the relation from the first matched field to normalize self. - - // Compare matched fields with one USING JOIN clause at a time - let columns = schema.columns_with_unqualified_name(&self.name); - for using_col in using_columns { - let all_matched = columns.iter().all(|f| using_col.contains(f)); - // All matched fields belong to the same using column set, in orther words - // the same join clause. We simply pick the qualifer from the first match. - if all_matched { - return Ok(columns[0].clone()); - } - } - } - } - } - - _schema_err!(SchemaError::FieldNotFound { - field: Box::new(Column::new(self.relation.clone(), self.name)), - valid_fields: schemas.iter().flat_map(|s| s.columns()).collect(), - }) - } - /// Qualify column if not done yet. /// /// If this column already has a [relation](Self::relation), it will be returned as is and the given parameters are @@ -296,7 +229,7 @@ impl Column { for using_col in using_columns { let all_matched = columns.iter().all(|c| using_col.contains(c)); // All matched fields belong to the same using column set, in orther words - // the same join clause. We simply pick the qualifer from the first match. + // the same join clause. We simply pick the qualifier from the first match. if all_matched { return Ok(columns[0].clone()); } @@ -374,6 +307,7 @@ mod tests { use super::*; use arrow::datatypes::DataType; use arrow_schema::SchemaBuilder; + use std::sync::Arc; fn create_qualified_schema(qualifier: &str, names: Vec<&str>) -> Result { let mut schema_builder = SchemaBuilder::new(); diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 30ab9a339b54..1fa32aefb33f 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -24,7 +24,7 @@ use std::str::FromStr; use crate::error::_config_err; use crate::parsers::CompressionTypeVariant; -use crate::{DataFusionError, FileType, Result}; +use crate::{DataFusionError, Result}; /// A macro that wraps a configuration struct and automatically derives /// [`Default`] and [`ConfigField`] for it, allowing it to be used @@ -130,9 +130,9 @@ macro_rules! config_namespace { $( stringify!($field_name) => self.$field_name.set(rem, value), )* - _ => return Err(DataFusionError::Configuration(format!( + _ => return _config_err!( "Config value \"{}\" not found on {}", key, stringify!($struct_name) - ))) + ) } } @@ -181,8 +181,19 @@ config_namespace! { /// Type of `TableProvider` to use when loading `default` schema pub format: Option, default = None - /// If the file has a header - pub has_header: bool, default = false + /// Default value for `format.has_header` for `CREATE EXTERNAL TABLE` + /// if not specified explicitly in the statement. + pub has_header: bool, default = true + + /// Specifies whether newlines in (quoted) CSV values are supported. + /// + /// This is the default value for `format.newlines_in_values` for `CREATE EXTERNAL TABLE` + /// if not specified explicitly in the statement. + /// + /// Parsing newlines in quoted values may be affected by execution behaviour such as + /// parallel file scanning. Setting this to `true` ensures that newlines in values are + /// parsed successfully, which may reduce performance. + pub newlines_in_values: bool, default = false } } @@ -199,10 +210,18 @@ config_namespace! { /// When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted) pub enable_ident_normalization: bool, default = true + /// When set to true, SQL parser will normalize options value (convert value to lowercase) + pub enable_options_value_normalization: bool, default = true + /// Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, /// MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, and Ansi. pub dialect: String, default = "generic".to_string() + /// If true, permit lengths for `VARCHAR` such as `VARCHAR(20)`, but + /// ignore the length. If false, error if a `VARCHAR` with a length is + /// specified. The Arrow type system does not have a notion of maximum + /// string length and thus DataFusion can not enforce such limits. + pub support_varchar_with_length: bool, default = true } } @@ -242,9 +261,6 @@ config_namespace! { /// Parquet options pub parquet: ParquetOptions, default = Default::default() - /// Aggregate options - pub aggregate: AggregateOptions, default = Default::default() - /// Fan-out during initial physical planning. /// /// This is mostly use to plan `UNION` children in parallel. @@ -252,6 +268,17 @@ config_namespace! { /// Defaults to the number of CPU cores on the system pub planning_concurrency: usize, default = num_cpus::get() + /// When set to true, skips verifying that the schema produced by + /// planning the input of `LogicalPlan::Aggregate` exactly matches the + /// schema of the input plan. + /// + /// When set to false, if the schema does not match exactly + /// (including nullability and metadata), a planning error will be raised. + /// + /// This is used to workaround bugs in the planner that are now caught by + /// the new schema verification step. + pub skip_physical_aggregate_schema_check: bool, default = false + /// Specifies the reserved memory for each spillable sort operation to /// facilitate an in-memory merge. /// @@ -297,97 +324,146 @@ config_namespace! { /// Should DataFusion support recursive CTEs pub enable_recursive_ctes: bool, default = true + + /// Attempt to eliminate sorts by packing & sorting files with non-overlapping + /// statistics into the same file groups. + /// Currently experimental + pub split_file_groups_by_statistics: bool, default = false + + /// Should DataFusion keep the columns used for partition_by in the output RecordBatches + pub keep_partition_by_columns: bool, default = false + + /// Aggregation ratio (number of distinct groups / number of input rows) + /// threshold for skipping partial aggregation. If the value is greater + /// then partial aggregation will skip aggregation for further input + pub skip_partial_aggregation_probe_ratio_threshold: f64, default = 0.8 + + /// Number of input rows partial aggregation partition should process, before + /// aggregation ratio check and trying to switch to skipping aggregation mode + pub skip_partial_aggregation_probe_rows_threshold: usize, default = 100_000 + + /// Should DataFusion use row number estimates at the input to decide + /// whether increasing parallelism is beneficial or not. By default, + /// only exact row numbers (not estimates) are used for this decision. + /// Setting this flag to `true` will likely produce better plans. + /// if the source of statistics is accurate. + /// We plan to make this the default in the future. + pub use_row_number_estimates_to_optimize_partitioning: bool, default = false + + /// Should DataFusion enforce batch size in joins or not. By default, + /// DataFusion will not enforce batch size in joins. Enforcing batch size + /// in joins can reduce memory usage when joining large + /// tables with a highly-selective join filter, but is also slightly slower. + pub enforce_batch_size_in_joins: bool, default = false } } config_namespace! { - /// Options related to parquet files + /// Options for reading and writing parquet files /// /// See also: [`SessionConfig`] /// /// [`SessionConfig`]: https://docs.rs/datafusion/latest/datafusion/prelude/struct.SessionConfig.html pub struct ParquetOptions { - /// If true, reads the Parquet data page level metadata (the + // The following options affect reading parquet files + + /// (reading) If true, reads the Parquet data page level metadata (the /// Page Index), if present, to reduce the I/O and number of /// rows decoded. pub enable_page_index: bool, default = true - /// If true, the parquet reader attempts to skip entire row groups based + /// (reading) If true, the parquet reader attempts to skip entire row groups based /// on the predicate in the query and the metadata (min/max values) stored in /// the parquet file pub pruning: bool, default = true - /// If true, the parquet reader skip the optional embedded metadata that may be in + /// (reading) If true, the parquet reader skip the optional embedded metadata that may be in /// the file Schema. This setting can help avoid schema conflicts when querying /// multiple parquet files with schemas containing compatible types but different metadata pub skip_metadata: bool, default = true - /// If specified, the parquet reader will try and fetch the last `size_hint` + /// (reading) If specified, the parquet reader will try and fetch the last `size_hint` /// bytes of the parquet file optimistically. If not specified, two reads are required: /// One read to fetch the 8-byte parquet footer and /// another to fetch the metadata length encoded in the footer pub metadata_size_hint: Option, default = None - /// If true, filter expressions are be applied during the parquet decoding operation to + /// (reading) If true, filter expressions are be applied during the parquet decoding operation to /// reduce the number of rows decoded. This optimization is sometimes called "late materialization". pub pushdown_filters: bool, default = false - /// If true, filter expressions evaluated during the parquet decoding operation + /// (reading) If true, filter expressions evaluated during the parquet decoding operation /// will be reordered heuristically to minimize the cost of evaluation. If false, /// the filters are applied in the same order as written in the query pub reorder_filters: bool, default = false - // The following map to parquet::file::properties::WriterProperties + /// (reading) If true, parquet reader will read columns of `Utf8/Utf8Large` with `Utf8View`, + /// and `Binary/BinaryLarge` with `BinaryView`. + pub schema_force_view_types: bool, default = false + + /// (reading) If true, parquet reader will read columns of + /// `Binary/LargeBinary` with `Utf8`, and `BinaryView` with `Utf8View`. + /// + /// Parquet files generated by some legacy writers do not correctly set + /// the UTF8 flag for strings, causing string columns to be loaded as + /// BLOB instead. + pub binary_as_string: bool, default = false + + // The following options affect writing to parquet files + // and map to parquet::file::properties::WriterProperties - /// Sets best effort maximum size of data page in bytes + /// (writing) Sets best effort maximum size of data page in bytes pub data_pagesize_limit: usize, default = 1024 * 1024 - /// Sets write_batch_size in bytes + /// (writing) Sets write_batch_size in bytes pub write_batch_size: usize, default = 1024 - /// Sets parquet writer version + /// (writing) Sets parquet writer version /// valid values are "1.0" and "2.0" - pub writer_version: String, default = "1.0".into() + pub writer_version: String, default = "1.0".to_string() - /// Sets default parquet compression codec + /// (writing) Sets default parquet compression codec. /// Valid values are: uncompressed, snappy, gzip(level), /// lzo, brotli(level), lz4, zstd(level), and lz4_raw. /// These values are not case sensitive. If NULL, uses /// default parquet writer setting + /// + /// Note that this default setting is not the same as + /// the default parquet writer setting. pub compression: Option, default = Some("zstd(3)".into()) - /// Sets if dictionary encoding is enabled. If NULL, uses + /// (writing) Sets if dictionary encoding is enabled. If NULL, uses /// default parquet writer setting - pub dictionary_enabled: Option, default = None + pub dictionary_enabled: Option, default = Some(true) - /// Sets best effort maximum dictionary page size, in bytes + /// (writing) Sets best effort maximum dictionary page size, in bytes pub dictionary_page_size_limit: usize, default = 1024 * 1024 - /// Sets if statistics are enabled for any column + /// (writing) Sets if statistics are enabled for any column /// Valid values are: "none", "chunk", and "page" /// These values are not case sensitive. If NULL, uses /// default parquet writer setting - pub statistics_enabled: Option, default = None + pub statistics_enabled: Option, default = Some("page".into()) - /// Sets max statistics size for any column. If NULL, uses + /// (writing) Sets max statistics size for any column. If NULL, uses /// default parquet writer setting - pub max_statistics_size: Option, default = None + pub max_statistics_size: Option, default = Some(4096) - /// Target maximum number of rows in each row group (defaults to 1M + /// (writing) Target maximum number of rows in each row group (defaults to 1M /// rows). Writing larger row groups requires more memory to write, but /// can get better compression and be faster to read. - pub max_row_group_size: usize, default = 1024 * 1024 + pub max_row_group_size: usize, default = 1024 * 1024 - /// Sets "created by" property + /// (writing) Sets "created by" property pub created_by: String, default = concat!("datafusion version ", env!("CARGO_PKG_VERSION")).into() - /// Sets column index truncate length - pub column_index_truncate_length: Option, default = None + /// (writing) Sets column index truncate length + pub column_index_truncate_length: Option, default = Some(64) - /// Sets best effort maximum number of rows in data page - pub data_page_row_count_limit: usize, default = usize::MAX + /// (writing) Sets best effort maximum number of rows in data page + pub data_page_row_count_limit: usize, default = 20_000 - /// Sets default encoding for any column + /// (writing) Sets default encoding for any column. /// Valid values are: plain, plain_dictionary, rle, /// bit_packed, delta_binary_packed, delta_length_byte_array, /// delta_byte_array, rle_dictionary, and byte_stream_split. @@ -395,24 +471,27 @@ config_namespace! { /// default parquet writer setting pub encoding: Option, default = None - /// Sets if bloom filter is enabled for any column - pub bloom_filter_enabled: bool, default = false + /// (writing) Use any available bloom filters when reading parquet files + pub bloom_filter_on_read: bool, default = true - /// Sets bloom filter false positive probability. If NULL, uses + /// (writing) Write bloom filters for all columns when creating parquet files + pub bloom_filter_on_write: bool, default = false + + /// (writing) Sets bloom filter false positive probability. If NULL, uses /// default parquet writer setting pub bloom_filter_fpp: Option, default = None - /// Sets bloom filter number of distinct values. If NULL, uses + /// (writing) Sets bloom filter number of distinct values. If NULL, uses /// default parquet writer setting pub bloom_filter_ndv: Option, default = None - /// Controls whether DataFusion will attempt to speed up writing + /// (writing) Controls whether DataFusion will attempt to speed up writing /// parquet files by serializing them in parallel. Each column /// in each row group in each output file are serialized in parallel /// leveraging a maximum possible core count of n_files*n_row_groups*n_columns. pub allow_single_file_parallelism: bool, default = true - /// By default parallel parquet writer is tuned for minimum + /// (writing) By default parallel parquet writer is tuned for minimum /// memory usage in a streaming execution plan. You may see /// a performance benefit when writing large parquet files /// by increasing maximum_parallel_row_group_writers and @@ -423,7 +502,7 @@ config_namespace! { /// data frame. pub maximum_parallel_row_group_writers: usize, default = 1 - /// By default parallel parquet writer is tuned for minimum + /// (writing) By default parallel parquet writer is tuned for minimum /// memory usage in a streaming execution plan. You may see /// a performance benefit when writing large parquet files /// by increasing maximum_parallel_row_group_writers and @@ -433,28 +512,6 @@ config_namespace! { /// writing out already in-memory data, such as from a cached /// data frame. pub maximum_buffered_record_batches_per_stream: usize, default = 2 - - } -} - -config_namespace! { - /// Options related to aggregate execution - /// - /// See also: [`SessionConfig`] - /// - /// [`SessionConfig`]: https://docs.rs/datafusion/latest/datafusion/prelude/struct.SessionConfig.html - pub struct AggregateOptions { - /// Specifies the threshold for using `ScalarValue`s to update - /// accumulators during high-cardinality aggregations for each input batch. - /// - /// The aggregation is considered high-cardinality if the number of affected groups - /// is greater than or equal to `batch_size / scalar_update_factor`. In such cases, - /// `ScalarValue`s are utilized for updating accumulators, rather than the default - /// batch-slice approach. This can lead to performance improvements. - /// - /// By adjusting the `scalar_update_factor`, you can balance the trade-off between - /// more efficient accumulator updates and the number of groups affected. - pub scalar_update_factor: usize, default = 10 } } @@ -571,6 +628,14 @@ config_namespace! { /// when an exact selectivity cannot be determined. Valid values are /// between 0 (no selectivity) and 100 (all rows are selected). pub default_filter_selectivity: u8, default = 20 + + /// When set to true, the optimizer will not attempt to convert Union to Interleave + pub prefer_existing_union: bool, default = false + + /// When set to true, if the returned type is a view type + /// then the output will be coerced to a non-view. + /// Coerces `Utf8View` to `LargeUtf8`, and `BinaryView` to `LargeBinary`. + pub expand_views_at_output: bool, default = false } } @@ -593,6 +658,9 @@ config_namespace! { /// When set to true, the explain statement will print the partition sizes pub show_sizes: bool, default = true + + /// When set to true, the explain statement will print schema information + pub show_schema: bool, default = false } } @@ -664,22 +732,17 @@ impl ConfigOptions { /// Set a configuration option pub fn set(&mut self, key: &str, value: &str) -> Result<()> { - let (prefix, key) = key.split_once('.').ok_or_else(|| { - DataFusionError::Configuration(format!( - "could not find config namespace for key \"{key}\"", - )) - })?; + let Some((prefix, key)) = key.split_once('.') else { + return _config_err!("could not find config namespace for key \"{key}\""); + }; if prefix == "datafusion" { return ConfigField::set(self, key, value); } - let e = self.extensions.0.get_mut(prefix); - let e = e.ok_or_else(|| { - DataFusionError::Configuration(format!( - "Could not find config namespace \"{prefix}\"" - )) - })?; + let Some(e) = self.extensions.0.get_mut(prefix) else { + return _config_err!("Could not find config namespace \"{prefix}\""); + }; e.0.set(key, value) } @@ -724,7 +787,7 @@ impl ConfigOptions { /// /// Only the built-in configurations will be extracted from the hash map /// and other key value pairs will be ignored. - pub fn from_string_hash_map(settings: HashMap) -> Result { + pub fn from_string_hash_map(settings: &HashMap) -> Result { struct Visitor(Vec); impl Visit for Visitor { @@ -824,7 +887,7 @@ pub trait ConfigExtension: ExtensionOptions { } /// An object-safe API for storing arbitrary configuration -pub trait ExtensionOptions: Send + Sync + std::fmt::Debug + 'static { +pub trait ExtensionOptions: Send + Sync + fmt::Debug + 'static { /// Return `self` as [`Any`] /// /// This is needed until trait upcasting is stabilised @@ -1109,6 +1172,16 @@ macro_rules! extensions_options { } } +/// These file types have special built in behavior for configuration. +/// Use TableOptions::Extensions for configuring other file types. +#[derive(Debug, Clone)] +pub enum ConfigFileType { + CSV, + #[cfg(feature = "parquet")] + PARQUET, + JSON, +} + /// Represents the configuration options available for handling different table formats within a data processing application. /// This struct encompasses options for various file formats including CSV, Parquet, and JSON, allowing for flexible configuration /// of parsing and writing behaviors specific to each format. Additionally, it supports extending functionality through custom extensions. @@ -1127,7 +1200,7 @@ pub struct TableOptions { /// The current file format that the table operations should assume. This option allows /// for dynamic switching between the supported file types (e.g., CSV, Parquet, JSON). - pub current_format: Option, + pub current_format: Option, /// Optional extensions that can be used to extend or customize the behavior of the table /// options. Extensions can be registered using `Extensions::insert` and might include @@ -1145,10 +1218,9 @@ impl ConfigField for TableOptions { if let Some(file_type) = &self.current_format { match file_type { #[cfg(feature = "parquet")] - FileType::PARQUET => self.parquet.visit(v, "format", ""), - FileType::CSV => self.csv.visit(v, "format", ""), - FileType::JSON => self.json.visit(v, "format", ""), - _ => {} + ConfigFileType::PARQUET => self.parquet.visit(v, "format", ""), + ConfigFileType::CSV => self.csv.visit(v, "format", ""), + ConfigFileType::JSON => self.json.visit(v, "format", ""), } } else { self.csv.visit(v, "csv", ""); @@ -1165,7 +1237,7 @@ impl ConfigField for TableOptions { /// # Parameters /// /// * `key`: The configuration key specifying which setting to adjust, prefixed with the format (e.g., "format.delimiter") - /// for CSV format. + /// for CSV format. /// * `value`: The value to set for the specified configuration key. /// /// # Returns @@ -1175,19 +1247,18 @@ impl ConfigField for TableOptions { fn set(&mut self, key: &str, value: &str) -> Result<()> { // Extensions are handled in the public `ConfigOptions::set` let (key, rem) = key.split_once('.').unwrap_or((key, "")); - let Some(format) = &self.current_format else { - return _config_err!("Specify a format for TableOptions"); - }; match key { - "format" => match format { - #[cfg(feature = "parquet")] - FileType::PARQUET => self.parquet.set(rem, value), - FileType::CSV => self.csv.set(rem, value), - FileType::JSON => self.json.set(rem, value), - _ => { - _config_err!("Config value \"{key}\" is not supported on {}", format) + "format" => { + let Some(format) = &self.current_format else { + return _config_err!("Specify a format for TableOptions"); + }; + match format { + #[cfg(feature = "parquet")] + ConfigFileType::PARQUET => self.parquet.set(rem, value), + ConfigFileType::CSV => self.csv.set(rem, value), + ConfigFileType::JSON => self.json.set(rem, value), } - }, + } _ => _config_err!("Config value \"{key}\" not found on TableOptions"), } } @@ -1203,15 +1274,6 @@ impl TableOptions { Self::default() } - /// Sets the file format for the table. - /// - /// # Parameters - /// - /// * `format`: The file format to use (e.g., CSV, Parquet). - pub fn set_file_format(&mut self, format: FileType) { - self.current_format = Some(format); - } - /// Creates a new `TableOptions` instance initialized with settings from a given session config. /// /// # Parameters @@ -1242,6 +1304,15 @@ impl TableOptions { clone } + /// Sets the file format for the table. + /// + /// # Parameters + /// + /// * `format`: The file format to use (e.g., CSV, Parquet). + pub fn set_config_format(&mut self, format: ConfigFileType) { + self.current_format = Some(format); + } + /// Sets the extensions for this `TableOptions` instance. /// /// # Parameters @@ -1267,22 +1338,21 @@ impl TableOptions { /// /// A result indicating success or failure in setting the configuration option. pub fn set(&mut self, key: &str, value: &str) -> Result<()> { - let (prefix, _) = key.split_once('.').ok_or_else(|| { - DataFusionError::Configuration(format!( - "could not find config namespace for key \"{key}\"" - )) - })?; + let Some((prefix, _)) = key.split_once('.') else { + return _config_err!("could not find config namespace for key \"{key}\""); + }; if prefix == "format" { return ConfigField::set(self, key, value); } - let e = self.extensions.0.get_mut(prefix); - let e = e.ok_or_else(|| { - DataFusionError::Configuration(format!( - "Could not find config namespace \"{prefix}\"" - )) - })?; + if prefix == "execution" { + return Ok(()); + } + + let Some(e) = self.extensions.0.get_mut(prefix) else { + return _config_err!("Could not find config namespace \"{prefix}\""); + }; e.0.set(key, value) } @@ -1364,12 +1434,38 @@ impl TableOptions { /// Options that control how Parquet files are read, including global options /// that apply to all columns and optional column-specific overrides +/// +/// Closely tied to [`ParquetWriterOptions`](crate::file_options::parquet_writer::ParquetWriterOptions). +/// Properties not included in [`TableParquetOptions`] may not be configurable at the external API +/// (e.g. sorting_columns). #[derive(Clone, Default, Debug, PartialEq)] pub struct TableParquetOptions { /// Global Parquet options that propagates to all columns. pub global: ParquetOptions, /// Column specific options. Default usage is parquet.XX::column. - pub column_specific_options: HashMap, + pub column_specific_options: HashMap, + /// Additional file-level metadata to include. Inserted into the key_value_metadata + /// for the written [`FileMetaData`](https://docs.rs/parquet/latest/parquet/file/metadata/struct.FileMetaData.html). + /// + /// Multiple entries are permitted + /// ```sql + /// OPTIONS ( + /// 'format.metadata::key1' '', + /// 'format.metadata::key2' 'value', + /// 'format.metadata::key3' 'value has spaces', + /// 'format.metadata::key4' 'value has special chars :: :', + /// 'format.metadata::key_dupe' 'original will be overwritten', + /// 'format.metadata::key_dupe' 'final' + /// ) + /// ``` + pub key_value_metadata: HashMap>, +} + +impl TableParquetOptions { + /// Return new default TableParquetOptions + pub fn new() -> Self { + Self::default() + } } impl ConfigField for TableParquetOptions { @@ -1380,8 +1476,24 @@ impl ConfigField for TableParquetOptions { } fn set(&mut self, key: &str, value: &str) -> Result<()> { - // Determine the key if it's a global or column-specific setting - if key.contains("::") { + // Determine if the key is a global, metadata, or column-specific setting + if key.starts_with("metadata::") { + let k = match key.split("::").collect::>()[..] { + [_meta] | [_meta, ""] => { + return _config_err!( + "Invalid metadata key provided, missing key in metadata::" + ) + } + [_meta, k] => k.into(), + _ => { + return _config_err!( + "Invalid metadata key provided, found too many '::' in \"{key}\"" + ) + } + }; + self.key_value_metadata.insert(k, Some(value.into())); + Ok(()) + } else if key.contains("::") { self.column_specific_options.set(key, value) } else { self.global.set(key, value) @@ -1451,10 +1563,7 @@ macro_rules! config_namespace_with_hashmap { inner_value.set(inner_key, value) } - _ => Err(DataFusionError::Configuration(format!( - "Unrecognized key '{}'.", - key - ))), + _ => _config_err!("Unrecognized key '{key}'."), } } @@ -1472,7 +1581,10 @@ macro_rules! config_namespace_with_hashmap { } config_namespace_with_hashmap! { - pub struct ColumnOptions { + /// Options controlling parquet format for individual columns. + /// + /// See [`ParquetOptions`] for more details + pub struct ParquetColumnOptions { /// Sets if bloom filter is enabled for the column path. pub bloom_filter_enabled: Option, default = None @@ -1518,18 +1630,32 @@ config_namespace_with_hashmap! { config_namespace! { /// Options controlling CSV format pub struct CsvOptions { - pub has_header: bool, default = true + /// Specifies whether there is a CSV header (i.e. the first line + /// consists of is column names). The value `None` indicates that + /// the configuration should be consulted. + pub has_header: Option, default = None pub delimiter: u8, default = b',' pub quote: u8, default = b'"' + pub terminator: Option, default = None pub escape: Option, default = None + pub double_quote: Option, default = None + /// Specifies whether newlines in (quoted) values are supported. + /// + /// Parsing newlines in quoted values may be affected by execution behaviour such as + /// parallel file scanning. Setting this to `true` ensures that newlines in values are + /// parsed successfully, which may reduce performance. + /// + /// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting. + pub newlines_in_values: Option, default = None pub compression: CompressionTypeVariant, default = CompressionTypeVariant::UNCOMPRESSED pub schema_infer_max_rec: usize, default = 100 - pub date_format: Option, default = None - pub datetime_format: Option, default = None - pub timestamp_format: Option, default = None - pub timestamp_tz_format: Option, default = None - pub time_format: Option, default = None - pub null_value: Option, default = None + pub date_format: Option, default = None + pub datetime_format: Option, default = None + pub timestamp_format: Option, default = None + pub timestamp_tz_format: Option, default = None + pub time_format: Option, default = None + pub null_value: Option, default = None + pub comment: Option, default = None } } @@ -1554,12 +1680,14 @@ impl CsvOptions { /// Set true to indicate that the first line is a header. /// - default to true pub fn with_has_header(mut self, has_header: bool) -> Self { - self.has_header = has_header; + self.has_header = Some(has_header); self } - /// True if the first line is a header. - pub fn has_header(&self) -> bool { + /// Returns true if the first line is a header. If format options does not + /// specify whether there is a header, returns `None` (indicating that the + /// configuration should be consulted). + pub fn has_header(&self) -> Option { self.has_header } @@ -1577,6 +1705,13 @@ impl CsvOptions { self } + /// The character that terminates a row. + /// - default to None (CRLF) + pub fn with_terminator(mut self, terminator: Option) -> Self { + self.terminator = terminator; + self + } + /// The escape character in a row. /// - default is None pub fn with_escape(mut self, escape: Option) -> Self { @@ -1584,6 +1719,25 @@ impl CsvOptions { self } + /// Set true to indicate that the CSV quotes should be doubled. + /// - default to true + pub fn with_double_quote(mut self, double_quote: bool) -> Self { + self.double_quote = Some(double_quote); + self + } + + /// Specifies whether newlines in (quoted) values are supported. + /// + /// Parsing newlines in quoted values may be affected by execution behaviour such as + /// parallel file scanning. Setting this to `true` ensures that newlines in values are + /// parsed successfully, which may reduce performance. + /// + /// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting. + pub fn with_newlines_in_values(mut self, newlines_in_values: bool) -> Self { + self.newlines_in_values = Some(newlines_in_values); + self + } + /// Set a `CompressionTypeVariant` of CSV /// - defaults to `CompressionTypeVariant::UNCOMPRESSED` pub fn with_file_compression_type( @@ -1604,6 +1758,11 @@ impl CsvOptions { self.quote } + /// The terminator character. + pub fn terminator(&self) -> Option { + self.terminator + } + /// The escape character. pub fn escape(&self) -> Option { self.escape @@ -1618,7 +1777,10 @@ config_namespace! { } } +pub trait FormatOptionsExt: Display {} + #[derive(Debug, Clone, PartialEq)] +#[allow(clippy::large_enum_variant)] pub enum FormatOptions { CSV(CsvOptions), JSON(JsonOptions), @@ -1627,6 +1789,7 @@ pub enum FormatOptions { AVRO, ARROW, } + impl Display for FormatOptions { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let out = match self { @@ -1641,28 +1804,15 @@ impl Display for FormatOptions { } } -impl From for FormatOptions { - fn from(value: FileType) -> Self { - match value { - FileType::ARROW => FormatOptions::ARROW, - FileType::AVRO => FormatOptions::AVRO, - #[cfg(feature = "parquet")] - FileType::PARQUET => FormatOptions::PARQUET(TableParquetOptions::default()), - FileType::CSV => FormatOptions::CSV(CsvOptions::default()), - FileType::JSON => FormatOptions::JSON(JsonOptions::default()), - } - } -} - #[cfg(test)] mod tests { use std::any::Any; use std::collections::HashMap; use crate::config::{ - ConfigEntry, ConfigExtension, ExtensionOptions, Extensions, TableOptions, + ConfigEntry, ConfigExtension, ConfigFileType, ExtensionOptions, Extensions, + TableOptions, }; - use crate::FileType; #[derive(Default, Debug, Clone)] pub struct TestExtensionConfig { @@ -1720,7 +1870,7 @@ mod tests { let mut extension = Extensions::new(); extension.insert(TestExtensionConfig::default()); let mut table_config = TableOptions::new().with_extensions(extension); - table_config.set_file_format(FileType::CSV); + table_config.set_config_format(ConfigFileType::CSV); table_config.set("format.delimiter", ";").unwrap(); assert_eq!(table_config.csv.delimiter, b';'); table_config.set("test.bootstrap.servers", "asd").unwrap(); @@ -1737,7 +1887,7 @@ mod tests { #[test] fn csv_u8_table_options() { let mut table_config = TableOptions::new(); - table_config.set_file_format(FileType::CSV); + table_config.set_config_format(ConfigFileType::CSV); table_config.set("format.delimiter", ";").unwrap(); assert_eq!(table_config.csv.delimiter as char, ';'); table_config.set("format.escape", "\"").unwrap(); @@ -1750,7 +1900,7 @@ mod tests { #[test] fn parquet_table_options() { let mut table_config = TableOptions::new(); - table_config.set_file_format(FileType::PARQUET); + table_config.set_config_format(ConfigFileType::PARQUET); table_config .set("format.bloom_filter_enabled::col1", "true") .unwrap(); @@ -1764,7 +1914,7 @@ mod tests { #[test] fn parquet_table_options_config_entry() { let mut table_config = TableOptions::new(); - table_config.set_file_format(FileType::PARQUET); + table_config.set_config_format(ConfigFileType::PARQUET); table_config .set("format.bloom_filter_enabled::col1", "true") .unwrap(); @@ -1773,4 +1923,38 @@ mod tests { .iter() .any(|item| item.key == "format.bloom_filter_enabled::col1")) } + + #[cfg(feature = "parquet")] + #[test] + fn parquet_table_options_config_metadata_entry() { + let mut table_config = TableOptions::new(); + table_config.set_config_format(ConfigFileType::PARQUET); + table_config.set("format.metadata::key1", "").unwrap(); + table_config.set("format.metadata::key2", "value2").unwrap(); + table_config + .set("format.metadata::key3", "value with spaces ") + .unwrap(); + table_config + .set("format.metadata::key4", "value with special chars :: :") + .unwrap(); + + let parsed_metadata = table_config.parquet.key_value_metadata.clone(); + assert_eq!(parsed_metadata.get("should not exist1"), None); + assert_eq!(parsed_metadata.get("key1"), Some(&Some("".into()))); + assert_eq!(parsed_metadata.get("key2"), Some(&Some("value2".into()))); + assert_eq!( + parsed_metadata.get("key3"), + Some(&Some("value with spaces ".into())) + ); + assert_eq!( + parsed_metadata.get("key4"), + Some(&Some("value with special chars :: :".into())) + ); + + // duplicate keys are overwritten + table_config.set("format.metadata::key_dupe", "A").unwrap(); + table_config.set("format.metadata::key_dupe", "B").unwrap(); + let parsed_metadata = table_config.parquet.key_value_metadata; + assert_eq!(parsed_metadata.get("key_dupe"), Some(&Some("B".into()))); + } } diff --git a/datafusion/common/src/cse.rs b/datafusion/common/src/cse.rs new file mode 100644 index 000000000000..ab02915858cd --- /dev/null +++ b/datafusion/common/src/cse.rs @@ -0,0 +1,816 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Common Subexpression Elimination logic implemented in [`CSE`] can be controlled with +//! a [`CSEController`], that defines how to eliminate common subtrees from a particular +//! [`TreeNode`] tree. + +use crate::hash_utils::combine_hashes; +use crate::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, + TreeNodeVisitor, +}; +use crate::Result; +use indexmap::IndexMap; +use std::collections::HashMap; +use std::hash::{BuildHasher, Hash, Hasher, RandomState}; +use std::marker::PhantomData; +use std::sync::Arc; + +/// Hashes the direct content of an [`TreeNode`] without recursing into its children. +/// +/// This method is useful to incrementally compute hashes, such as in [`CSE`] which builds +/// a deep hash of a node and its descendants during the bottom-up phase of the first +/// traversal and so avoid computing the hash of the node and then the hash of its +/// descendants separately. +/// +/// If a node doesn't have any children then the value returned by `hash_node()` is +/// similar to '.hash()`, but not necessarily returns the same value. +pub trait HashNode { + fn hash_node(&self, state: &mut H); +} + +impl HashNode for Arc { + fn hash_node(&self, state: &mut H) { + (**self).hash_node(state); + } +} + +/// Identifier that represents a [`TreeNode`] tree. +/// +/// This identifier is designed to be efficient and "hash", "accumulate", "equal" and +/// "have no collision (as low as possible)" +#[derive(Debug, Eq, PartialEq)] +struct Identifier<'n, N> { + // Hash of `node` built up incrementally during the first, visiting traversal. + // Its value is not necessarily equal to default hash of the node. E.g. it is not + // equal to `expr.hash()` if the node is `Expr`. + hash: u64, + node: &'n N, +} + +impl Clone for Identifier<'_, N> { + fn clone(&self) -> Self { + *self + } +} +impl Copy for Identifier<'_, N> {} + +impl Hash for Identifier<'_, N> { + fn hash(&self, state: &mut H) { + state.write_u64(self.hash); + } +} + +impl<'n, N: HashNode> Identifier<'n, N> { + fn new(node: &'n N, random_state: &RandomState) -> Self { + let mut hasher = random_state.build_hasher(); + node.hash_node(&mut hasher); + let hash = hasher.finish(); + Self { hash, node } + } + + fn combine(mut self, other: Option) -> Self { + other.map_or(self, |other_id| { + self.hash = combine_hashes(self.hash, other_id.hash); + self + }) + } +} + +/// A cache that contains the postorder index and the identifier of [`TreeNode`]s by the +/// preorder index of the nodes. +/// +/// This cache is filled by [`CSEVisitor`] during the first traversal and is +/// used by [`CSERewriter`] during the second traversal. +/// +/// The purpose of this cache is to quickly find the identifier of a node during the +/// second traversal. +/// +/// Elements in this array are added during `f_down` so the indexes represent the preorder +/// index of nodes and thus element 0 belongs to the root of the tree. +/// +/// The elements of the array are tuples that contain: +/// - Postorder index that belongs to the preorder index. Assigned during `f_up`, start +/// from 0. +/// - The optional [`Identifier`] of the node. If none the node should not be considered +/// for CSE. +/// +/// # Example +/// An expression tree like `(a + b)` would have the following `IdArray`: +/// ```text +/// [ +/// (2, Some(Identifier(hash_of("a + b"), &"a + b"))), +/// (1, Some(Identifier(hash_of("a"), &"a"))), +/// (0, Some(Identifier(hash_of("b"), &"b"))) +/// ] +/// ``` +type IdArray<'n, N> = Vec<(usize, Option>)>; + +#[derive(PartialEq, Eq)] +/// How many times a node is evaluated. A node can be considered common if evaluated +/// surely at least 2 times or surely only once but also conditionally. +enum NodeEvaluation { + SurelyOnce, + ConditionallyAtLeastOnce, + Common, +} + +/// A map that contains the evaluation stats of [`TreeNode`]s by their identifiers. +type NodeStats<'n, N> = HashMap, NodeEvaluation>; + +/// A map that contains the common [`TreeNode`]s and their alias by their identifiers, +/// extracted during the second, rewriting traversal. +type CommonNodes<'n, N> = IndexMap, (N, String)>; + +type ChildrenList = (Vec, Vec); + +/// The [`TreeNode`] specific definition of elimination. +pub trait CSEController { + /// The type of the tree nodes. + type Node; + + /// Splits the children to normal and conditionally evaluated ones or returns `None` + /// if all are always evaluated. + fn conditional_children(node: &Self::Node) -> Option>; + + // Returns true if a node is valid. If a node is invalid then it can't be eliminated. + // Validity is propagated up which means no subtree can be eliminated that contains + // an invalid node. + // (E.g. volatile expressions are not valid and subtrees containing such a node can't + // be extracted.) + fn is_valid(node: &Self::Node) -> bool; + + // Returns true if a node should be ignored during CSE. Contrary to validity of a node, + // it is not propagated up. + fn is_ignored(&self, node: &Self::Node) -> bool; + + // Generates a new name for the extracted subtree. + fn generate_alias(&self) -> String; + + // Replaces a node to the generated alias. + fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node; + + // A helper method called on each node during top-down traversal during the second, + // rewriting traversal of CSE. + fn rewrite_f_down(&mut self, _node: &Self::Node) {} + + // A helper method called on each node during bottom-up traversal during the second, + // rewriting traversal of CSE. + fn rewrite_f_up(&mut self, _node: &Self::Node) {} +} + +/// The result of potentially rewriting a list of [`TreeNode`]s to eliminate common +/// subtrees. +#[derive(Debug)] +pub enum FoundCommonNodes { + /// No common [`TreeNode`]s were found + No { original_nodes_list: Vec> }, + + /// Common [`TreeNode`]s were found + Yes { + /// extracted common [`TreeNode`] + common_nodes: Vec<(N, String)>, + + /// new [`TreeNode`]s with common subtrees replaced + new_nodes_list: Vec>, + + /// original [`TreeNode`]s + original_nodes_list: Vec>, + }, +} + +/// Go through a [`TreeNode`] tree and generate identifiers for each subtrees. +/// +/// An identifier contains information of the [`TreeNode`] itself and its subtrees. +/// This visitor implementation use a stack `visit_stack` to track traversal, which +/// lets us know when a subtree's visiting is finished. When `pre_visit` is called +/// (traversing to a new node), an `EnterMark` and an `NodeItem` will be pushed into stack. +/// And try to pop out a `EnterMark` on leaving a node (`f_up()`). All `NodeItem` +/// before the first `EnterMark` is considered to be sub-tree of the leaving node. +/// +/// This visitor also records identifier in `id_array`. Makes the following traverse +/// pass can get the identifier of a node without recalculate it. We assign each node +/// in the tree a series number, start from 1, maintained by `series_number`. +/// Series number represents the order we left (`f_up()`) a node. Has the property +/// that child node's series number always smaller than parent's. While `id_array` is +/// organized in the order we enter (`f_down()`) a node. `node_count` helps us to +/// get the index of `id_array` for each node. +/// +/// A [`TreeNode`] without any children (column, literal etc.) will not have identifier +/// because they should not be recognized as common subtree. +struct CSEVisitor<'a, 'n, N, C: CSEController> { + /// statistics of [`TreeNode`]s + node_stats: &'a mut NodeStats<'n, N>, + + /// cache to speed up second traversal + id_array: &'a mut IdArray<'n, N>, + + /// inner states + visit_stack: Vec>, + + /// preorder index, start from 0. + down_index: usize, + + /// postorder index, start from 0. + up_index: usize, + + /// a [`RandomState`] to generate hashes during the first traversal + random_state: &'a RandomState, + + /// a flag to indicate that common [`TreeNode`]s found + found_common: bool, + + /// if we are in a conditional branch. A conditional branch means that the [`TreeNode`] + /// might not be executed depending on the runtime values of other [`TreeNode`]s, and + /// thus can not be extracted as a common [`TreeNode`]. + conditional: bool, + + controller: &'a C, +} + +/// Record item that used when traversing a [`TreeNode`] tree. +enum VisitRecord<'n, N> { + /// Marks the beginning of [`TreeNode`]. It contains: + /// - The post-order index assigned during the first, visiting traversal. + EnterMark(usize), + + /// Marks an accumulated subtree. It contains: + /// - The accumulated identifier of a subtree. + /// - A accumulated boolean flag if the subtree is valid for CSE. + /// The flag is propagated up from children to parent. (E.g. volatile expressions + /// are not valid and can't be extracted, but non-volatile children of volatile + /// expressions can be extracted.) + NodeItem(Identifier<'n, N>, bool), +} + +impl<'n, N: TreeNode + HashNode, C: CSEController> CSEVisitor<'_, 'n, N, C> { + /// Find the first `EnterMark` in the stack, and accumulates every `NodeItem` before + /// it. Returns a tuple that contains: + /// - The pre-order index of the [`TreeNode`] we marked. + /// - The accumulated identifier of the children of the marked [`TreeNode`]. + /// - An accumulated boolean flag from the children of the marked [`TreeNode`] if all + /// children are valid for CSE (i.e. it is safe to extract the [`TreeNode`] as a + /// common [`TreeNode`] from its children POV). + /// (E.g. if any of the children of the marked expression is not valid (e.g. is + /// volatile) then the expression is also not valid, so we can propagate this + /// information up from children to parents via `visit_stack` during the first, + /// visiting traversal and no need to test the expression's validity beforehand with + /// an extra traversal). + fn pop_enter_mark(&mut self) -> (usize, Option>, bool) { + let mut node_id = None; + let mut is_valid = true; + + while let Some(item) = self.visit_stack.pop() { + match item { + VisitRecord::EnterMark(down_index) => { + return (down_index, node_id, is_valid); + } + VisitRecord::NodeItem(sub_node_id, sub_node_is_valid) => { + node_id = Some(sub_node_id.combine(node_id)); + is_valid &= sub_node_is_valid; + } + } + } + unreachable!("EnterMark should paired with NodeItem"); + } +} + +impl<'n, N: TreeNode + HashNode + Eq, C: CSEController> TreeNodeVisitor<'n> + for CSEVisitor<'_, 'n, N, C> +{ + type Node = N; + + fn f_down(&mut self, node: &'n Self::Node) -> Result { + self.id_array.push((0, None)); + self.visit_stack + .push(VisitRecord::EnterMark(self.down_index)); + self.down_index += 1; + + // If a node can short-circuit then some of its children might not be executed so + // count the occurrence either normal or conditional. + Ok(if self.conditional { + // If we are already in a conditionally evaluated subtree then continue + // traversal. + TreeNodeRecursion::Continue + } else { + // If we are already in a node that can short-circuit then start new + // traversals on its normal conditional children. + match C::conditional_children(node) { + Some((normal, conditional)) => { + normal + .into_iter() + .try_for_each(|n| n.visit(self).map(|_| ()))?; + self.conditional = true; + conditional + .into_iter() + .try_for_each(|n| n.visit(self).map(|_| ()))?; + self.conditional = false; + + TreeNodeRecursion::Jump + } + + // In case of non-short-circuit node continue the traversal. + _ => TreeNodeRecursion::Continue, + } + }) + } + + fn f_up(&mut self, node: &'n Self::Node) -> Result { + let (down_index, sub_node_id, sub_node_is_valid) = self.pop_enter_mark(); + + let node_id = Identifier::new(node, self.random_state).combine(sub_node_id); + let is_valid = C::is_valid(node) && sub_node_is_valid; + + self.id_array[down_index].0 = self.up_index; + if is_valid && !self.controller.is_ignored(node) { + self.id_array[down_index].1 = Some(node_id); + self.node_stats + .entry(node_id) + .and_modify(|evaluation| { + if *evaluation == NodeEvaluation::SurelyOnce + || *evaluation == NodeEvaluation::ConditionallyAtLeastOnce + && !self.conditional + { + *evaluation = NodeEvaluation::Common; + self.found_common = true; + } + }) + .or_insert_with(|| { + if self.conditional { + NodeEvaluation::ConditionallyAtLeastOnce + } else { + NodeEvaluation::SurelyOnce + } + }); + } + self.visit_stack + .push(VisitRecord::NodeItem(node_id, is_valid)); + self.up_index += 1; + + Ok(TreeNodeRecursion::Continue) + } +} + +/// Rewrite a [`TreeNode`] tree by replacing detected common subtrees with the +/// corresponding temporary [`TreeNode`], that column contains the evaluate result of +/// replaced [`TreeNode`] tree. +struct CSERewriter<'a, 'n, N, C: CSEController> { + /// statistics of [`TreeNode`]s + node_stats: &'a NodeStats<'n, N>, + + /// cache to speed up second traversal + id_array: &'a IdArray<'n, N>, + + /// common [`TreeNode`]s, that are replaced during the second traversal, are collected + /// to this map + common_nodes: &'a mut CommonNodes<'n, N>, + + // preorder index, starts from 0. + down_index: usize, + + controller: &'a mut C, +} + +impl> TreeNodeRewriter + for CSERewriter<'_, '_, N, C> +{ + type Node = N; + + fn f_down(&mut self, node: Self::Node) -> Result> { + self.controller.rewrite_f_down(&node); + + let (up_index, node_id) = self.id_array[self.down_index]; + self.down_index += 1; + + // Handle nodes with identifiers only + if let Some(node_id) = node_id { + let evaluation = self.node_stats.get(&node_id).unwrap(); + if *evaluation == NodeEvaluation::Common { + // step index to skip all sub-node (which has smaller series number). + while self.down_index < self.id_array.len() + && self.id_array[self.down_index].0 < up_index + { + self.down_index += 1; + } + + let (node, alias) = + self.common_nodes.entry(node_id).or_insert_with(|| { + let node_alias = self.controller.generate_alias(); + (node, node_alias) + }); + + let rewritten = self.controller.rewrite(node, alias); + + return Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump)); + } + } + + Ok(Transformed::no(node)) + } + + fn f_up(&mut self, node: Self::Node) -> Result> { + self.controller.rewrite_f_up(&node); + + Ok(Transformed::no(node)) + } +} + +/// The main entry point of Common Subexpression Elimination. +/// +/// [`CSE`] requires a [`CSEController`], that defines how common subtrees of a particular +/// [`TreeNode`] tree can be eliminated. The elimination process can be started with the +/// [`CSE::extract_common_nodes()`] method. +pub struct CSE> { + random_state: RandomState, + phantom_data: PhantomData, + controller: C, +} + +impl> CSE { + pub fn new(controller: C) -> Self { + Self { + random_state: RandomState::new(), + phantom_data: PhantomData, + controller, + } + } + + /// Add an identifier to `id_array` for every [`TreeNode`] in this tree. + fn node_to_id_array<'n>( + &self, + node: &'n N, + node_stats: &mut NodeStats<'n, N>, + id_array: &mut IdArray<'n, N>, + ) -> Result { + let mut visitor = CSEVisitor { + node_stats, + id_array, + visit_stack: vec![], + down_index: 0, + up_index: 0, + random_state: &self.random_state, + found_common: false, + conditional: false, + controller: &self.controller, + }; + node.visit(&mut visitor)?; + + Ok(visitor.found_common) + } + + /// Returns the identifier list for each element in `nodes` and a flag to indicate if + /// rewrite phase of CSE make sense. + /// + /// Returns and array with 1 element for each input node in `nodes` + /// + /// Each element is itself the result of [`CSE::node_to_id_array`] for that node + /// (e.g. the identifiers for each node in the tree) + fn to_arrays<'n>( + &self, + nodes: &'n [N], + node_stats: &mut NodeStats<'n, N>, + ) -> Result<(bool, Vec>)> { + let mut found_common = false; + nodes + .iter() + .map(|n| { + let mut id_array = vec![]; + self.node_to_id_array(n, node_stats, &mut id_array) + .map(|fc| { + found_common |= fc; + + id_array + }) + }) + .collect::>>() + .map(|id_arrays| (found_common, id_arrays)) + } + + /// Replace common subtrees in `node` with the corresponding temporary + /// [`TreeNode`], updating `common_nodes` with any replaced [`TreeNode`] + fn replace_common_node<'n>( + &mut self, + node: N, + id_array: &IdArray<'n, N>, + node_stats: &NodeStats<'n, N>, + common_nodes: &mut CommonNodes<'n, N>, + ) -> Result { + if id_array.is_empty() { + Ok(Transformed::no(node)) + } else { + node.rewrite(&mut CSERewriter { + node_stats, + id_array, + common_nodes, + down_index: 0, + controller: &mut self.controller, + }) + } + .data() + } + + /// Replace common subtrees in `nodes_list` with the corresponding temporary + /// [`TreeNode`], updating `common_nodes` with any replaced [`TreeNode`]. + fn rewrite_nodes_list<'n>( + &mut self, + nodes_list: Vec>, + arrays_list: &[Vec>], + node_stats: &NodeStats<'n, N>, + common_nodes: &mut CommonNodes<'n, N>, + ) -> Result>> { + nodes_list + .into_iter() + .zip(arrays_list.iter()) + .map(|(nodes, arrays)| { + nodes + .into_iter() + .zip(arrays.iter()) + .map(|(node, id_array)| { + self.replace_common_node(node, id_array, node_stats, common_nodes) + }) + .collect::>>() + }) + .collect::>>() + } + + /// Extracts common [`TreeNode`]s and rewrites `nodes_list`. + /// + /// Returns [`FoundCommonNodes`] recording the result of the extraction. + pub fn extract_common_nodes( + &mut self, + nodes_list: Vec>, + ) -> Result> { + let mut found_common = false; + let mut node_stats = NodeStats::new(); + let id_arrays_list = nodes_list + .iter() + .map(|nodes| { + self.to_arrays(nodes, &mut node_stats) + .map(|(fc, id_arrays)| { + found_common |= fc; + + id_arrays + }) + }) + .collect::>>()?; + if found_common { + let mut common_nodes = CommonNodes::new(); + let new_nodes_list = self.rewrite_nodes_list( + // Must clone the list of nodes as Identifiers use references to original + // nodes so we have to keep them intact. + nodes_list.clone(), + &id_arrays_list, + &node_stats, + &mut common_nodes, + )?; + assert!(!common_nodes.is_empty()); + + Ok(FoundCommonNodes::Yes { + common_nodes: common_nodes.into_values().collect(), + new_nodes_list, + original_nodes_list: nodes_list, + }) + } else { + Ok(FoundCommonNodes::No { + original_nodes_list: nodes_list, + }) + } + } +} + +#[cfg(test)] +mod test { + use crate::alias::AliasGenerator; + use crate::cse::{CSEController, HashNode, IdArray, Identifier, NodeStats, CSE}; + use crate::tree_node::tests::TestTreeNode; + use crate::Result; + use std::collections::HashSet; + use std::hash::{Hash, Hasher}; + + const CSE_PREFIX: &str = "__common_node"; + + #[derive(Clone, Copy)] + pub enum TestTreeNodeMask { + Normal, + NormalAndAggregates, + } + + pub struct TestTreeNodeCSEController<'a> { + alias_generator: &'a AliasGenerator, + mask: TestTreeNodeMask, + } + + impl<'a> TestTreeNodeCSEController<'a> { + fn new(alias_generator: &'a AliasGenerator, mask: TestTreeNodeMask) -> Self { + Self { + alias_generator, + mask, + } + } + } + + impl CSEController for TestTreeNodeCSEController<'_> { + type Node = TestTreeNode; + + fn conditional_children( + _: &Self::Node, + ) -> Option<(Vec<&Self::Node>, Vec<&Self::Node>)> { + None + } + + fn is_valid(_node: &Self::Node) -> bool { + true + } + + fn is_ignored(&self, node: &Self::Node) -> bool { + let is_leaf = node.is_leaf(); + let is_aggr = node.data == "avg" || node.data == "sum"; + + match self.mask { + TestTreeNodeMask::Normal => is_leaf || is_aggr, + TestTreeNodeMask::NormalAndAggregates => is_leaf, + } + } + + fn generate_alias(&self) -> String { + self.alias_generator.next(CSE_PREFIX) + } + + fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node { + TestTreeNode::new_leaf(format!("alias({}, {})", node.data, alias)) + } + } + + impl HashNode for TestTreeNode { + fn hash_node(&self, state: &mut H) { + self.data.hash(state); + } + } + + #[test] + fn id_array_visitor() -> Result<()> { + let alias_generator = AliasGenerator::new(); + let eliminator = CSE::new(TestTreeNodeCSEController::new( + &alias_generator, + TestTreeNodeMask::Normal, + )); + + let a_plus_1 = TestTreeNode::new( + vec![ + TestTreeNode::new_leaf("a".to_string()), + TestTreeNode::new_leaf("1".to_string()), + ], + "+".to_string(), + ); + let avg_c = TestTreeNode::new( + vec![TestTreeNode::new_leaf("c".to_string())], + "avg".to_string(), + ); + let sum_a_plus_1 = TestTreeNode::new(vec![a_plus_1], "sum".to_string()); + let sum_a_plus_1_minus_avg_c = + TestTreeNode::new(vec![sum_a_plus_1, avg_c], "-".to_string()); + let root = TestTreeNode::new( + vec![ + sum_a_plus_1_minus_avg_c, + TestTreeNode::new_leaf("2".to_string()), + ], + "*".to_string(), + ); + + let [sum_a_plus_1_minus_avg_c, _] = root.children.as_slice() else { + panic!("Cannot extract subtree references") + }; + let [sum_a_plus_1, avg_c] = sum_a_plus_1_minus_avg_c.children.as_slice() else { + panic!("Cannot extract subtree references") + }; + let [a_plus_1] = sum_a_plus_1.children.as_slice() else { + panic!("Cannot extract subtree references") + }; + + // skip aggregates + let mut id_array = vec![]; + eliminator.node_to_id_array(&root, &mut NodeStats::new(), &mut id_array)?; + + // Collect distinct hashes and set them to 0 in `id_array` + fn collect_hashes( + id_array: &mut IdArray<'_, TestTreeNode>, + ) -> HashSet { + id_array + .iter_mut() + .flat_map(|(_, id_option)| { + id_option.as_mut().map(|node_id| { + let hash = node_id.hash; + node_id.hash = 0; + hash + }) + }) + .collect::>() + } + + let hashes = collect_hashes(&mut id_array); + assert_eq!(hashes.len(), 3); + + let expected = vec![ + ( + 8, + Some(Identifier { + hash: 0, + node: &root, + }), + ), + ( + 6, + Some(Identifier { + hash: 0, + node: sum_a_plus_1_minus_avg_c, + }), + ), + (3, None), + ( + 2, + Some(Identifier { + hash: 0, + node: a_plus_1, + }), + ), + (0, None), + (1, None), + (5, None), + (4, None), + (7, None), + ]; + assert_eq!(expected, id_array); + + // include aggregates + let eliminator = CSE::new(TestTreeNodeCSEController::new( + &alias_generator, + TestTreeNodeMask::NormalAndAggregates, + )); + + let mut id_array = vec![]; + eliminator.node_to_id_array(&root, &mut NodeStats::new(), &mut id_array)?; + + let hashes = collect_hashes(&mut id_array); + assert_eq!(hashes.len(), 5); + + let expected = vec![ + ( + 8, + Some(Identifier { + hash: 0, + node: &root, + }), + ), + ( + 6, + Some(Identifier { + hash: 0, + node: sum_a_plus_1_minus_avg_c, + }), + ), + ( + 3, + Some(Identifier { + hash: 0, + node: sum_a_plus_1, + }), + ), + ( + 2, + Some(Identifier { + hash: 0, + node: a_plus_1, + }), + ), + (0, None), + (1, None), + ( + 5, + Some(Identifier { + hash: 0, + node: avg_c, + }), + ), + (4, None), + (7, None), + ]; + assert_eq!(expected, id_array); + + Ok(()) + } +} diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index f1909f0dc8e1..aa2d93989da1 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -98,7 +98,7 @@ pub type DFSchemaRef = Arc; /// use arrow::datatypes::Field; /// use std::collections::HashMap; /// -/// let df_schema = DFSchema::from_unqualifed_fields(vec![ +/// let df_schema = DFSchema::from_unqualified_fields(vec![ /// Field::new("c1", arrow::datatypes::DataType::Int32, false), /// ].into(),HashMap::new()).unwrap(); /// let schema = Schema::from(df_schema); @@ -125,6 +125,20 @@ impl DFSchema { } } + /// Return a reference to the inner Arrow [`Schema`] + /// + /// Note this does not have the qualifier information + pub fn as_arrow(&self) -> &Schema { + self.inner.as_ref() + } + + /// Return a reference to the inner Arrow [`SchemaRef`] + /// + /// Note this does not have the qualifier information + pub fn inner(&self) -> &SchemaRef { + &self.inner + } + /// Create a `DFSchema` from an Arrow schema where all the fields have a given qualifier pub fn new_with_metadata( qualified_fields: Vec<(Option, Arc)>, @@ -145,6 +159,19 @@ impl DFSchema { } /// Create a new `DFSchema` from a list of Arrow [Field]s + #[allow(deprecated)] + pub fn from_unqualified_fields( + fields: Fields, + metadata: HashMap, + ) -> Result { + Self::from_unqualifed_fields(fields, metadata) + } + + /// Create a new `DFSchema` from a list of Arrow [Field]s + #[deprecated( + since = "40.0.0", + note = "Please use `from_unqualified_fields` instead (this one's name is a typo). This method is subject to be removed soon" + )] pub fn from_unqualifed_fields( fields: Fields, metadata: HashMap, @@ -184,7 +211,7 @@ impl DFSchema { schema: &SchemaRef, ) -> Result { let dfschema = Self { - inner: schema.clone(), + inner: Arc::clone(schema), field_qualifiers: qualifiers, functional_dependencies: FunctionalDependencies::empty(), }; @@ -199,7 +226,12 @@ impl DFSchema { for (field, qualifier) in self.inner.fields().iter().zip(&self.field_qualifiers) { if let Some(qualifier) = qualifier { - qualified_names.insert((qualifier, field.name())); + if !qualified_names.insert((qualifier, field.name())) { + return _schema_err!(SchemaError::DuplicateQualifiedField { + qualifier: Box::new(qualifier.clone()), + name: field.name().to_string(), + }); + } } else if !unqualified_names.insert(field.name()) { return _schema_err!(SchemaError::DuplicateUnqualifiedField { name: field.name().to_string() @@ -283,8 +315,7 @@ impl DFSchema { None => self_unqualified_names.contains(field.name().as_str()), }; if !duplicated_field { - // self.inner.fields.push(field.clone()); - schema_builder.push(field.clone()); + schema_builder.push(Arc::clone(field)); qualifiers.push(qualifier.cloned()); } } @@ -328,18 +359,7 @@ impl DFSchema { // qualifier and name. (Some(q), Some(field_q)) => q.resolved_eq(field_q) && f.name() == name, // field to lookup is qualified but current field is unqualified. - (Some(qq), None) => { - // the original field may now be aliased with a name that matches the - // original qualified name - let column = Column::from_qualified_name(f.name()); - match column { - Column { - relation: Some(r), - name: column_name, - } => &r == qq && column_name == name, - _ => false, - } - } + (Some(_), None) => false, // field to lookup is unqualified, no need to compare qualifier (None, Some(_)) | (None, None) => f.name() == name, }) @@ -347,9 +367,22 @@ impl DFSchema { matches.next() } - /// Find the index of the column with the given qualifier and name - pub fn index_of_column(&self, col: &Column) -> Result { + /// Find the index of the column with the given qualifier and name, + /// returning `None` if not found + /// + /// See [Self::index_of_column] for a version that returns an error if the + /// column is not found + pub fn maybe_index_of_column(&self, col: &Column) -> Option { self.index_of_column_by_name(col.relation.as_ref(), &col.name) + } + + /// Find the index of the column with the given qualifier and name, + /// returning `Err` if not found + /// + /// See [Self::maybe_index_of_column] for a version that returns `None` if + /// the column is not found + pub fn index_of_column(&self, col: &Column) -> Result { + self.maybe_index_of_column(col) .ok_or_else(|| field_not_found(col.relation.clone(), &col.name, self)) } @@ -481,34 +514,8 @@ impl DFSchema { /// Find the field with the given name pub fn field_with_unqualified_name(&self, name: &str) -> Result<&Field> { - let matches = self.qualified_fields_with_unqualified_name(name); - match matches.len() { - 0 => Err(unqualified_field_not_found(name, self)), - 1 => Ok(matches[0].1), - _ => { - // When `matches` size > 1, it doesn't necessarily mean an `ambiguous name` problem. - // Because name may generate from Alias/... . It means that it don't own qualifier. - // For example: - // Join on id = b.id - // Project a.id as id TableScan b id - // In this case, there isn't `ambiguous name` problem. When `matches` just contains - // one field without qualifier, we should return it. - let fields_without_qualifier = matches - .iter() - .filter(|(q, _)| q.is_none()) - .collect::>(); - if fields_without_qualifier.len() == 1 { - Ok(fields_without_qualifier[0].1) - } else { - _schema_err!(SchemaError::AmbiguousReference { - field: Column { - relation: None, - name: name.to_string(), - }, - }) - } - } - } + self.qualified_field_with_unqualified_name(name) + .map(|(_, field)| field) } /// Find the field with the given qualified name @@ -639,7 +646,7 @@ impl DFSchema { /// than datatype_is_semantically_equal in that a Dictionary type is logically /// equal to a plain V type, but not semantically equal. Dictionary is also /// logically equal to Dictionary. - fn datatype_is_logically_equal(dt1: &DataType, dt2: &DataType) -> bool { + pub fn datatype_is_logically_equal(dt1: &DataType, dt2: &DataType) -> bool { // check nested fields match (dt1, dt2) { (DataType::Dictionary(_, v1), DataType::Dictionary(_, v2)) => { @@ -793,13 +800,35 @@ impl From<&DFSchema> for Schema { } } +/// Allow DFSchema to be converted into an Arrow `&Schema` +impl AsRef for DFSchema { + fn as_ref(&self) -> &Schema { + self.as_arrow() + } +} + +/// Allow DFSchema to be converted into an Arrow `&SchemaRef` (to clone, for +/// example) +impl AsRef for DFSchema { + fn as_ref(&self) -> &SchemaRef { + self.inner() + } +} + /// Create a `DFSchema` from an Arrow schema impl TryFrom for DFSchema { type Error = DataFusionError; fn try_from(schema: Schema) -> Result { + Self::try_from(Arc::new(schema)) + } +} + +impl TryFrom for DFSchema { + type Error = DataFusionError; + fn try_from(schema: SchemaRef) -> Result { let field_count = schema.fields.len(); let dfschema = Self { - inner: schema.into(), + inner: schema, field_qualifiers: vec![None; field_count], functional_dependencies: FunctionalDependencies::empty(), }; @@ -843,12 +872,7 @@ impl ToDFSchema for Schema { impl ToDFSchema for SchemaRef { fn to_dfschema(self) -> Result { - // Attempt to use the Schema directly if there are no other - // references, otherwise clone - match Self::try_unwrap(self) { - Ok(schema) => DFSchema::try_from(schema), - Err(schemaref) => DFSchema::try_from(schemaref.as_ref().clone()), - } + DFSchema::try_from(self) } } @@ -1118,7 +1142,10 @@ mod tests { let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; let right = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; let join = left.join(&right); - assert!(join.err().is_none()); + assert_eq!( + join.unwrap_err().strip_backtrace(), + "Schema error: Schema contains duplicate qualified field name t1.c0", + ); Ok(()) } @@ -1211,15 +1238,14 @@ mod tests { #[test] fn into() { // Demonstrate how to convert back and forth between Schema, SchemaRef, DFSchema, and DFSchemaRef - let metadata = test_metadata(); let arrow_schema = Schema::new_with_metadata( vec![Field::new("c0", DataType::Int64, true)], - metadata.clone(), + test_metadata(), ); let arrow_schema_ref = Arc::new(arrow_schema.clone()); let df_schema = DFSchema { - inner: arrow_schema_ref.clone(), + inner: Arc::clone(&arrow_schema_ref), field_qualifiers: vec![None; arrow_schema_ref.fields.len()], functional_dependencies: FunctionalDependencies::empty(), }; @@ -1227,7 +1253,7 @@ mod tests { { let arrow_schema = arrow_schema.clone(); - let arrow_schema_ref = arrow_schema_ref.clone(); + let arrow_schema_ref = Arc::clone(&arrow_schema_ref); assert_eq!(df_schema, arrow_schema.to_dfschema().unwrap()); assert_eq!(df_schema, arrow_schema_ref.to_dfschema().unwrap()); @@ -1235,7 +1261,7 @@ mod tests { { let arrow_schema = arrow_schema.clone(); - let arrow_schema_ref = arrow_schema_ref.clone(); + let arrow_schema_ref = Arc::clone(&arrow_schema_ref); assert_eq!(df_schema_ref, arrow_schema.to_dfschema_ref().unwrap()); assert_eq!(df_schema_ref, arrow_schema_ref.to_dfschema_ref().unwrap()); @@ -1253,7 +1279,7 @@ mod tests { ]) } #[test] - fn test_dfschema_to_schema_convertion() { + fn test_dfschema_to_schema_conversion() { let mut a_metadata = HashMap::new(); a_metadata.insert("key".to_string(), "value".to_string()); let a_field = Field::new("a", DataType::Int64, false).with_metadata(a_metadata); @@ -1265,7 +1291,7 @@ mod tests { let schema = Arc::new(Schema::new(vec![a_field, b_field])); let df_schema = DFSchema { - inner: schema.clone(), + inner: Arc::clone(&schema), field_qualifiers: vec![None; schema.fields.len()], functional_dependencies: FunctionalDependencies::empty(), }; diff --git a/datafusion/common/src/display/mod.rs b/datafusion/common/src/display/mod.rs index 4d1d48bf9fcc..c12e7419e4b6 100644 --- a/datafusion/common/src/display/mod.rs +++ b/datafusion/common/src/display/mod.rs @@ -27,7 +27,7 @@ use std::{ /// Represents which type of plan, when storing multiple /// for use in EXPLAIN plans -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum PlanType { /// The initial LogicalPlan provided to DataFusion InitialLogicalPlan, @@ -49,15 +49,19 @@ pub enum PlanType { InitialPhysicalPlan, /// The initial physical plan with stats, prepared for execution InitialPhysicalPlanWithStats, + /// The initial physical plan with schema, prepared for execution + InitialPhysicalPlanWithSchema, /// The ExecutionPlan which results from applying an optimizer pass OptimizedPhysicalPlan { /// The name of the optimizer which produced this plan optimizer_name: String, }, - /// The final, fully optimized physical which would be executed + /// The final, fully optimized physical plan which would be executed FinalPhysicalPlan, - /// The final with stats, fully optimized physical which would be executed + /// The final with stats, fully optimized physical plan which would be executed FinalPhysicalPlanWithStats, + /// The final with schema, fully optimized physical plan which would be executed + FinalPhysicalPlanWithSchema, } impl Display for PlanType { @@ -76,17 +80,23 @@ impl Display for PlanType { PlanType::InitialPhysicalPlanWithStats => { write!(f, "initial_physical_plan_with_stats") } + PlanType::InitialPhysicalPlanWithSchema => { + write!(f, "initial_physical_plan_with_schema") + } PlanType::OptimizedPhysicalPlan { optimizer_name } => { write!(f, "physical_plan after {optimizer_name}") } PlanType::FinalPhysicalPlan => write!(f, "physical_plan"), PlanType::FinalPhysicalPlanWithStats => write!(f, "physical_plan_with_stats"), + PlanType::FinalPhysicalPlanWithSchema => { + write!(f, "physical_plan_with_schema") + } } } } /// Represents some sort of execution plan, in String form -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub struct StringifiedPlan { /// An identifier of what type of plan this string represents pub plan_type: PlanType, diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 4d9233d1f7c9..05988d6c6da4 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -19,6 +19,7 @@ #[cfg(feature = "backtrace")] use std::backtrace::{Backtrace, BacktraceStatus}; +use std::borrow::Cow; use std::error::Error; use std::fmt::{Display, Formatter}; use std::io; @@ -33,6 +34,7 @@ use arrow::error::ArrowError; #[cfg(feature = "parquet")] use parquet::errors::ParquetError; use sqlparser::parser::ParserError; +use tokio::task::JoinError; /// Result type for operations that could result in an [DataFusionError] pub type Result = result::Result; @@ -111,6 +113,10 @@ pub enum DataFusionError { /// SQL method, opened a CSV file that is broken, or tried to divide an /// integer by zero. Execution(String), + /// [`JoinError`] during execution of the query. + /// + /// This error can unoccur for unjoined tasks, such as execution shutdown. + ExecutionJoin(JoinError), /// Error when resources (such as memory of scratch disk space) are exhausted. /// /// This error is thrown when a consumer cannot acquire additional memory @@ -281,64 +287,9 @@ impl From for DataFusionError { impl Display for DataFusionError { fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { - match *self { - DataFusionError::ArrowError(ref desc, ref backtrace) => { - let backtrace = backtrace.clone().unwrap_or("".to_owned()); - write!(f, "Arrow error: {desc}{backtrace}") - } - #[cfg(feature = "parquet")] - DataFusionError::ParquetError(ref desc) => { - write!(f, "Parquet error: {desc}") - } - #[cfg(feature = "avro")] - DataFusionError::AvroError(ref desc) => { - write!(f, "Avro error: {desc}") - } - DataFusionError::IoError(ref desc) => { - write!(f, "IO error: {desc}") - } - DataFusionError::SQL(ref desc, ref backtrace) => { - let backtrace: String = backtrace.clone().unwrap_or("".to_owned()); - write!(f, "SQL error: {desc:?}{backtrace}") - } - DataFusionError::Configuration(ref desc) => { - write!(f, "Invalid or Unsupported Configuration: {desc}") - } - DataFusionError::NotImplemented(ref desc) => { - write!(f, "This feature is not implemented: {desc}") - } - DataFusionError::Internal(ref desc) => { - write!(f, "Internal error: {desc}.\nThis was likely caused by a bug in DataFusion's \ - code and we would welcome that you file an bug report in our issue tracker") - } - DataFusionError::Plan(ref desc) => { - write!(f, "Error during planning: {desc}") - } - DataFusionError::SchemaError(ref desc, ref backtrace) => { - let backtrace: &str = - &backtrace.as_ref().clone().unwrap_or("".to_owned()); - write!(f, "Schema error: {desc}{backtrace}") - } - DataFusionError::Execution(ref desc) => { - write!(f, "Execution error: {desc}") - } - DataFusionError::ResourcesExhausted(ref desc) => { - write!(f, "Resources exhausted: {desc}") - } - DataFusionError::External(ref desc) => { - write!(f, "External error: {desc}") - } - #[cfg(feature = "object_store")] - DataFusionError::ObjectStore(ref desc) => { - write!(f, "Object Store error: {desc}") - } - DataFusionError::Context(ref desc, ref err) => { - write!(f, "{}\ncaused by\n{}", desc, *err) - } - DataFusionError::Substrait(ref desc) => { - write!(f, "Substrait error: {desc}") - } - } + let error_prefix = self.error_prefix(); + let message = self.message(); + write!(f, "{error_prefix}{message}") } } @@ -360,6 +311,7 @@ impl Error for DataFusionError { DataFusionError::Plan(_) => None, DataFusionError::SchemaError(e, _) => Some(e), DataFusionError::Execution(_) => None, + DataFusionError::ExecutionJoin(e) => Some(e), DataFusionError::ResourcesExhausted(_) => None, DataFusionError::External(e) => Some(e.as_ref()), DataFusionError::Context(_, e) => Some(e.as_ref()), @@ -375,7 +327,8 @@ impl From for io::Error { } impl DataFusionError { - const BACK_TRACE_SEP: &'static str = "\n\nbacktrace: "; + /// The separator between the error message and the backtrace + pub const BACK_TRACE_SEP: &'static str = "\n\nbacktrace: "; /// Get deepest underlying [`DataFusionError`] /// @@ -419,6 +372,9 @@ impl DataFusionError { Self::Context(description.into(), Box::new(self)) } + /// Strips backtrace out of the error message + /// If backtrace enabled then error has a format "message" [`Self::BACK_TRACE_SEP`] "backtrace" + /// The method strips the backtrace and outputs "message" pub fn strip_backtrace(&self) -> String { self.to_string() .split(Self::BACK_TRACE_SEP) @@ -450,6 +406,71 @@ impl DataFusionError { #[cfg(not(feature = "backtrace"))] "".to_owned() } + + fn error_prefix(&self) -> &'static str { + match self { + DataFusionError::ArrowError(_, _) => "Arrow error: ", + #[cfg(feature = "parquet")] + DataFusionError::ParquetError(_) => "Parquet error: ", + #[cfg(feature = "avro")] + DataFusionError::AvroError(_) => "Avro error: ", + #[cfg(feature = "object_store")] + DataFusionError::ObjectStore(_) => "Object Store error: ", + DataFusionError::IoError(_) => "IO error: ", + DataFusionError::SQL(_, _) => "SQL error: ", + DataFusionError::NotImplemented(_) => "This feature is not implemented: ", + DataFusionError::Internal(_) => "Internal error: ", + DataFusionError::Plan(_) => "Error during planning: ", + DataFusionError::Configuration(_) => "Invalid or Unsupported Configuration: ", + DataFusionError::SchemaError(_, _) => "Schema error: ", + DataFusionError::Execution(_) => "Execution error: ", + DataFusionError::ExecutionJoin(_) => "ExecutionJoin error: ", + DataFusionError::ResourcesExhausted(_) => "Resources exhausted: ", + DataFusionError::External(_) => "External error: ", + DataFusionError::Context(_, _) => "", + DataFusionError::Substrait(_) => "Substrait error: ", + } + } + + pub fn message(&self) -> Cow { + match *self { + DataFusionError::ArrowError(ref desc, ref backtrace) => { + let backtrace = backtrace.clone().unwrap_or("".to_owned()); + Cow::Owned(format!("{desc}{backtrace}")) + } + #[cfg(feature = "parquet")] + DataFusionError::ParquetError(ref desc) => Cow::Owned(desc.to_string()), + #[cfg(feature = "avro")] + DataFusionError::AvroError(ref desc) => Cow::Owned(desc.to_string()), + DataFusionError::IoError(ref desc) => Cow::Owned(desc.to_string()), + DataFusionError::SQL(ref desc, ref backtrace) => { + let backtrace: String = backtrace.clone().unwrap_or("".to_owned()); + Cow::Owned(format!("{desc:?}{backtrace}")) + } + DataFusionError::Configuration(ref desc) => Cow::Owned(desc.to_string()), + DataFusionError::NotImplemented(ref desc) => Cow::Owned(desc.to_string()), + DataFusionError::Internal(ref desc) => Cow::Owned(format!( + "{desc}.\nThis was likely caused by a bug in DataFusion's \ + code and we would welcome that you file an bug report in our issue tracker" + )), + DataFusionError::Plan(ref desc) => Cow::Owned(desc.to_string()), + DataFusionError::SchemaError(ref desc, ref backtrace) => { + let backtrace: &str = + &backtrace.as_ref().clone().unwrap_or("".to_owned()); + Cow::Owned(format!("{desc}{backtrace}")) + } + DataFusionError::Execution(ref desc) => Cow::Owned(desc.to_string()), + DataFusionError::ExecutionJoin(ref desc) => Cow::Owned(desc.to_string()), + DataFusionError::ResourcesExhausted(ref desc) => Cow::Owned(desc.to_string()), + DataFusionError::External(ref desc) => Cow::Owned(desc.to_string()), + #[cfg(feature = "object_store")] + DataFusionError::ObjectStore(ref desc) => Cow::Owned(desc.to_string()), + DataFusionError::Context(ref desc, ref err) => { + Cow::Owned(format!("{desc}\ncaused by\n{}", *err)) + } + DataFusionError::Substrait(ref desc) => Cow::Owned(desc.to_string()), + } + } } /// Unwrap an `Option` if possible. Otherwise return an `DataFusionError::Internal`. @@ -468,13 +489,6 @@ macro_rules! unwrap_or_internal_err { }; } -macro_rules! with_dollar_sign { - ($($body:tt)*) => { - macro_rules! __with_dollar_sign { $($body)* } - __with_dollar_sign!($); - } -} - /// Add a macros for concise DataFusionError::* errors declaration /// supports placeholders the same way as `format!` /// Examples: @@ -488,37 +502,41 @@ macro_rules! with_dollar_sign { /// `NAME_DF_ERR` - macro name for wrapping DataFusionError::*. Needed to keep backtrace opportunity /// in construction where DataFusionError::* used directly, like `map_err`, `ok_or_else`, etc macro_rules! make_error { - ($NAME_ERR:ident, $NAME_DF_ERR: ident, $ERR:ident) => { - with_dollar_sign! { - ($d:tt) => { - /// Macro wraps `$ERR` to add backtrace feature - #[macro_export] - macro_rules! $NAME_DF_ERR { - ($d($d args:expr),*) => { - $crate::DataFusionError::$ERR( - format!( - "{}{}", - format!($d($d args),*), - $crate::DataFusionError::get_back_trace(), - ).into() - ) - } + ($NAME_ERR:ident, $NAME_DF_ERR: ident, $ERR:ident) => { make_error!(@inner ($), $NAME_ERR, $NAME_DF_ERR, $ERR); }; + (@inner ($d:tt), $NAME_ERR:ident, $NAME_DF_ERR:ident, $ERR:ident) => { + ::paste::paste!{ + /// Macro wraps `$ERR` to add backtrace feature + #[macro_export] + macro_rules! $NAME_DF_ERR { + ($d($d args:expr),*) => { + $crate::DataFusionError::$ERR( + ::std::format!( + "{}{}", + ::std::format!($d($d args),*), + $crate::DataFusionError::get_back_trace(), + ).into() + ) } + } - /// Macro wraps Err(`$ERR`) to add backtrace feature - #[macro_export] - macro_rules! $NAME_ERR { - ($d($d args:expr),*) => { - Err($crate::DataFusionError::$ERR( - format!( - "{}{}", - format!($d($d args),*), - $crate::DataFusionError::get_back_trace(), - ).into() - )) - } + /// Macro wraps Err(`$ERR`) to add backtrace feature + #[macro_export] + macro_rules! $NAME_ERR { + ($d($d args:expr),*) => { + Err($crate::[<_ $NAME_DF_ERR>]!($d($d args),*)) } } + + + // Note: Certain macros are used in this crate, but not all. + // This macro generates a use or all of them in case they are needed + // so we allow unused code to avoid warnings when they are not used + #[doc(hidden)] + #[allow(unused)] + pub use $NAME_ERR as [<_ $NAME_ERR>]; + #[doc(hidden)] + #[allow(unused)] + pub use $NAME_DF_ERR as [<_ $NAME_DF_ERR>]; } }; } @@ -541,6 +559,9 @@ make_error!(config_err, config_datafusion_err, Configuration); // Exposes a macro to create `DataFusionError::Substrait` with optional backtrace make_error!(substrait_err, substrait_datafusion_err, Substrait); +// Exposes a macro to create `DataFusionError::ResourcesExhausted` with optional backtrace +make_error!(resources_err, resources_datafusion_err, ResourcesExhausted); + // Exposes a macro to create `DataFusionError::SQL` with optional backtrace #[macro_export] macro_rules! sql_datafusion_err { @@ -597,12 +618,6 @@ macro_rules! schema_err { // To avoid compiler error when using macro in the same crate: // macros from the current crate cannot be referred to by absolute paths -pub use config_err as _config_err; -pub use internal_datafusion_err as _internal_datafusion_err; -pub use internal_err as _internal_err; -pub use not_impl_err as _not_impl_err; -pub use plan_datafusion_err as _plan_datafusion_err; -pub use plan_err as _plan_err; pub use schema_err as _schema_err; /// Create a "field not found" DataFusion::SchemaError @@ -646,11 +661,16 @@ mod test { assert_eq!(res.strip_backtrace(), "Arrow error: Schema error: bar"); } - // RUST_BACKTRACE=1 cargo test --features backtrace --package datafusion-common --lib -- error::test::test_backtrace + // To pass the test the environment variable RUST_BACKTRACE should be set to 1 to enforce backtrace #[cfg(feature = "backtrace")] #[test] #[allow(clippy::unnecessary_literal_unwrap)] fn test_enabled_backtrace() { + match std::env::var("RUST_BACKTRACE") { + Ok(val) if val == "1" => {} + _ => panic!("Environment variable RUST_BACKTRACE must be set to 1"), + }; + let res: Result<(), DataFusionError> = plan_err!("Err"); let err = res.unwrap_err().to_string(); assert!(err.contains(DataFusionError::BACK_TRACE_SEP)); diff --git a/datafusion/common/src/file_options/csv_writer.rs b/datafusion/common/src/file_options/csv_writer.rs index 5f1a62682f8d..943288af9164 100644 --- a/datafusion/common/src/file_options/csv_writer.rs +++ b/datafusion/common/src/file_options/csv_writer.rs @@ -50,7 +50,8 @@ impl TryFrom<&CsvOptions> for CsvWriterOptions { fn try_from(value: &CsvOptions) -> Result { let mut builder = WriterBuilder::default() - .with_header(value.has_header) + .with_header(value.has_header.unwrap_or(true)) + .with_quote(value.quote) .with_delimiter(value.delimiter); if let Some(v) = &value.date_format { @@ -62,12 +63,21 @@ impl TryFrom<&CsvOptions> for CsvWriterOptions { if let Some(v) = &value.timestamp_format { builder = builder.with_timestamp_format(v.into()) } + if let Some(v) = &value.timestamp_tz_format { + builder = builder.with_timestamp_tz_format(v.into()) + } if let Some(v) = &value.time_format { builder = builder.with_time_format(v.into()) } if let Some(v) = &value.null_value { builder = builder.with_null(v.into()) } + if let Some(v) = &value.escape { + builder = builder.with_escape(*v) + } + if let Some(v) = &value.double_quote { + builder = builder.with_double_quote(*v) + } Ok(CsvWriterOptions { writer_options: builder, compression: value.compression, diff --git a/datafusion/common/src/file_options/file_type.rs b/datafusion/common/src/file_options/file_type.rs index fc0bb7445645..2648f7289798 100644 --- a/datafusion/common/src/file_options/file_type.rs +++ b/datafusion/common/src/file_options/file_type.rs @@ -17,11 +17,8 @@ //! File type abstraction -use std::fmt::{self, Display}; -use std::str::FromStr; - -use crate::config::FormatOptions; -use crate::error::{DataFusionError, Result}; +use std::any::Any; +use std::fmt::Display; /// The default file extension of arrow files pub const DEFAULT_ARROW_EXTENSION: &str = ".arrow"; @@ -40,107 +37,10 @@ pub trait GetExt { fn get_ext(&self) -> String; } -/// Readable file type -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum FileType { - /// Apache Arrow file - ARROW, - /// Apache Avro file - AVRO, - /// Apache Parquet file - #[cfg(feature = "parquet")] - PARQUET, - /// CSV file - CSV, - /// JSON file - JSON, -} - -impl From<&FormatOptions> for FileType { - fn from(value: &FormatOptions) -> Self { - match value { - FormatOptions::CSV(_) => FileType::CSV, - FormatOptions::JSON(_) => FileType::JSON, - #[cfg(feature = "parquet")] - FormatOptions::PARQUET(_) => FileType::PARQUET, - FormatOptions::AVRO => FileType::AVRO, - FormatOptions::ARROW => FileType::ARROW, - } - } -} - -impl GetExt for FileType { - fn get_ext(&self) -> String { - match self { - FileType::ARROW => DEFAULT_ARROW_EXTENSION.to_owned(), - FileType::AVRO => DEFAULT_AVRO_EXTENSION.to_owned(), - #[cfg(feature = "parquet")] - FileType::PARQUET => DEFAULT_PARQUET_EXTENSION.to_owned(), - FileType::CSV => DEFAULT_CSV_EXTENSION.to_owned(), - FileType::JSON => DEFAULT_JSON_EXTENSION.to_owned(), - } - } -} - -impl Display for FileType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let out = match self { - FileType::CSV => "csv", - FileType::JSON => "json", - #[cfg(feature = "parquet")] - FileType::PARQUET => "parquet", - FileType::AVRO => "avro", - FileType::ARROW => "arrow", - }; - write!(f, "{}", out) - } -} - -impl FromStr for FileType { - type Err = DataFusionError; - - fn from_str(s: &str) -> Result { - let s = s.to_uppercase(); - match s.as_str() { - "ARROW" => Ok(FileType::ARROW), - "AVRO" => Ok(FileType::AVRO), - #[cfg(feature = "parquet")] - "PARQUET" => Ok(FileType::PARQUET), - "CSV" => Ok(FileType::CSV), - "JSON" | "NDJSON" => Ok(FileType::JSON), - _ => Err(DataFusionError::NotImplemented(format!( - "Unknown FileType: {s}" - ))), - } - } -} - -#[cfg(test)] -#[cfg(feature = "parquet")] -mod tests { - use std::str::FromStr; - - use crate::error::DataFusionError; - use crate::FileType; - - #[test] - fn from_str() { - for (ext, file_type) in [ - ("csv", FileType::CSV), - ("CSV", FileType::CSV), - ("json", FileType::JSON), - ("JSON", FileType::JSON), - ("avro", FileType::AVRO), - ("AVRO", FileType::AVRO), - ("parquet", FileType::PARQUET), - ("PARQUET", FileType::PARQUET), - ] { - assert_eq!(FileType::from_str(ext).unwrap(), file_type); - } - - assert!(matches!( - FileType::from_str("Unknown"), - Err(DataFusionError::NotImplemented(_)) - )); - } +/// Defines the functionality needed for logical planning for +/// a type of file which will be read or written to storage. +pub trait FileType: GetExt + Display + Send + Sync { + /// Returns the table source as [`Any`] so that it can be + /// downcast to a specific implementation. + fn as_any(&self) -> &dyn Any; } diff --git a/datafusion/common/src/file_options/mod.rs b/datafusion/common/src/file_options/mod.rs index eb1ce1b364fd..77781457d0d2 100644 --- a/datafusion/common/src/file_options/mod.rs +++ b/datafusion/common/src/file_options/mod.rs @@ -32,10 +32,10 @@ mod tests { use super::parquet_writer::ParquetWriterOptions; use crate::{ - config::TableOptions, + config::{ConfigFileType, TableOptions}, file_options::{csv_writer::CsvWriterOptions, json_writer::JsonWriterOptions}, parsers::CompressionTypeVariant, - FileType, Result, + Result, }; use parquet::{ @@ -67,7 +67,7 @@ mod tests { "format.data_page_row_count_limit".to_owned(), "123".to_owned(), ); - option_map.insert("format.bloom_filter_enabled".to_owned(), "true".to_owned()); + option_map.insert("format.bloom_filter_on_write".to_owned(), "true".to_owned()); option_map.insert("format.encoding".to_owned(), "plain".to_owned()); option_map.insert("format.dictionary_enabled".to_owned(), "true".to_owned()); option_map.insert("format.compression".to_owned(), "zstd(4)".to_owned()); @@ -76,7 +76,7 @@ mod tests { option_map.insert("format.bloom_filter_ndv".to_owned(), "123".to_owned()); let mut table_config = TableOptions::new(); - table_config.set_file_format(FileType::PARQUET); + table_config.set_config_format(ConfigFileType::PARQUET); table_config.alter_with_string_hash_map(&option_map)?; let parquet_options = ParquetWriterOptions::try_from(&table_config.parquet)?; @@ -124,6 +124,10 @@ mod tests { 123 ); + // properties which remain as default on WriterProperties + assert_eq!(properties.key_value_metadata(), None); + assert_eq!(properties.sorting_columns(), None); + Ok(()) } @@ -177,7 +181,7 @@ mod tests { ); let mut table_config = TableOptions::new(); - table_config.set_file_format(FileType::PARQUET); + table_config.set_config_format(ConfigFileType::PARQUET); table_config.alter_with_string_hash_map(&option_map)?; let parquet_options = ParquetWriterOptions::try_from(&table_config.parquet)?; @@ -280,7 +284,7 @@ mod tests { option_map.insert("format.delimiter".to_owned(), ";".to_owned()); let mut table_config = TableOptions::new(); - table_config.set_file_format(FileType::CSV); + table_config.set_config_format(ConfigFileType::CSV); table_config.alter_with_string_hash_map(&option_map)?; let csv_options = CsvWriterOptions::try_from(&table_config.csv)?; @@ -302,7 +306,7 @@ mod tests { option_map.insert("format.compression".to_owned(), "gzip".to_owned()); let mut table_config = TableOptions::new(); - table_config.set_file_format(FileType::JSON); + table_config.set_config_format(ConfigFileType::JSON); table_config.alter_with_string_hash_map(&option_map)?; let json_options = JsonWriterOptions::try_from(&table_config.json)?; diff --git a/datafusion/common/src/file_options/parquet_writer.rs b/datafusion/common/src/file_options/parquet_writer.rs index 28e73ba48f53..dd9d67d6bb47 100644 --- a/datafusion/common/src/file_options/parquet_writer.rs +++ b/datafusion/common/src/file_options/parquet_writer.rs @@ -17,17 +17,25 @@ //! Options related to how parquet files should be written -use crate::{config::TableParquetOptions, DataFusionError, Result}; +use crate::{ + config::{ParquetOptions, TableParquetOptions}, + DataFusionError, Result, +}; use parquet::{ basic::{BrotliLevel, GzipLevel, ZstdLevel}, - file::properties::{EnabledStatistics, WriterProperties, WriterVersion}, + file::properties::{ + EnabledStatistics, WriterProperties, WriterPropertiesBuilder, WriterVersion, + DEFAULT_MAX_STATISTICS_SIZE, DEFAULT_STATISTICS_ENABLED, + }, + format::KeyValue, schema::types::ColumnPath, }; /// Options for writing parquet files #[derive(Clone, Debug)] pub struct ParquetWriterOptions { + /// parquet-rs writer properties pub writer_options: WriterProperties, } @@ -46,57 +54,43 @@ impl ParquetWriterOptions { impl TryFrom<&TableParquetOptions> for ParquetWriterOptions { type Error = DataFusionError; - fn try_from(parquet_options: &TableParquetOptions) -> Result { - let parquet_session_options = &parquet_options.global; - let mut builder = WriterProperties::builder() - .set_data_page_size_limit(parquet_session_options.data_pagesize_limit) - .set_write_batch_size(parquet_session_options.write_batch_size) - .set_writer_version(parse_version_string( - &parquet_session_options.writer_version, - )?) - .set_dictionary_page_size_limit( - parquet_session_options.dictionary_page_size_limit, - ) - .set_max_row_group_size(parquet_session_options.max_row_group_size) - .set_created_by(parquet_session_options.created_by.clone()) - .set_column_index_truncate_length( - parquet_session_options.column_index_truncate_length, - ) - .set_data_page_row_count_limit( - parquet_session_options.data_page_row_count_limit, - ) - .set_bloom_filter_enabled(parquet_session_options.bloom_filter_enabled); - - if let Some(encoding) = &parquet_session_options.encoding { - builder = builder.set_encoding(parse_encoding_string(encoding)?); - } - - if let Some(enabled) = parquet_session_options.dictionary_enabled { - builder = builder.set_dictionary_enabled(enabled); - } - - if let Some(compression) = &parquet_session_options.compression { - builder = builder.set_compression(parse_compression_string(compression)?); - } - - if let Some(statistics) = &parquet_session_options.statistics_enabled { - builder = - builder.set_statistics_enabled(parse_statistics_string(statistics)?); - } - - if let Some(size) = parquet_session_options.max_statistics_size { - builder = builder.set_max_statistics_size(size); - } + fn try_from(parquet_table_options: &TableParquetOptions) -> Result { + // ParquetWriterOptions will have defaults for the remaining fields (e.g. sorting_columns) + Ok(ParquetWriterOptions { + writer_options: WriterPropertiesBuilder::try_from(parquet_table_options)? + .build(), + }) + } +} - if let Some(fpp) = parquet_session_options.bloom_filter_fpp { - builder = builder.set_bloom_filter_fpp(fpp); - } +impl TryFrom<&TableParquetOptions> for WriterPropertiesBuilder { + type Error = DataFusionError; - if let Some(ndv) = parquet_session_options.bloom_filter_ndv { - builder = builder.set_bloom_filter_ndv(ndv); + /// Convert the session's [`TableParquetOptions`] into a single write action's [`WriterPropertiesBuilder`]. + /// + /// The returned [`WriterPropertiesBuilder`] includes customizations applicable per column. + fn try_from(table_parquet_options: &TableParquetOptions) -> Result { + // Table options include kv_metadata and col-specific options + let TableParquetOptions { + global, + column_specific_options, + key_value_metadata, + } = table_parquet_options; + + let mut builder = global.into_writer_properties_builder()?; + + if !key_value_metadata.is_empty() { + builder = builder.set_key_value_metadata(Some( + key_value_metadata + .to_owned() + .drain() + .map(|(key, value)| KeyValue { key, value }) + .collect(), + )); } - for (column, options) in &parquet_options.column_specific_options { + // Apply column-specific options: + for (column, options) in column_specific_options { let path = ColumnPath::new(column.split('.').map(|s| s.to_owned()).collect()); if let Some(bloom_filter_enabled) = options.bloom_filter_enabled { @@ -141,9 +135,90 @@ impl TryFrom<&TableParquetOptions> for ParquetWriterOptions { builder.set_column_max_statistics_size(path, max_statistics_size); } } - Ok(ParquetWriterOptions { - writer_options: builder.build(), - }) + + Ok(builder) + } +} + +impl ParquetOptions { + /// Convert the global session options, [`ParquetOptions`], into a single write action's [`WriterPropertiesBuilder`]. + /// + /// The returned [`WriterPropertiesBuilder`] can then be further modified with additional options + /// applied per column; a customization which is not applicable for [`ParquetOptions`]. + pub fn into_writer_properties_builder(&self) -> Result { + let ParquetOptions { + data_pagesize_limit, + write_batch_size, + writer_version, + compression, + dictionary_enabled, + dictionary_page_size_limit, + statistics_enabled, + max_statistics_size, + max_row_group_size, + created_by, + column_index_truncate_length, + data_page_row_count_limit, + encoding, + bloom_filter_on_write, + bloom_filter_fpp, + bloom_filter_ndv, + + // not in WriterProperties + enable_page_index: _, + pruning: _, + skip_metadata: _, + metadata_size_hint: _, + pushdown_filters: _, + reorder_filters: _, + allow_single_file_parallelism: _, + maximum_parallel_row_group_writers: _, + maximum_buffered_record_batches_per_stream: _, + bloom_filter_on_read: _, // reads not used for writer props + schema_force_view_types: _, + binary_as_string: _, // not used for writer props + } = self; + + let mut builder = WriterProperties::builder() + .set_data_page_size_limit(*data_pagesize_limit) + .set_write_batch_size(*write_batch_size) + .set_writer_version(parse_version_string(writer_version.as_str())?) + .set_dictionary_page_size_limit(*dictionary_page_size_limit) + .set_statistics_enabled( + statistics_enabled + .as_ref() + .and_then(|s| parse_statistics_string(s).ok()) + .unwrap_or(DEFAULT_STATISTICS_ENABLED), + ) + .set_max_statistics_size( + max_statistics_size.unwrap_or(DEFAULT_MAX_STATISTICS_SIZE), + ) + .set_max_row_group_size(*max_row_group_size) + .set_created_by(created_by.clone()) + .set_column_index_truncate_length(*column_index_truncate_length) + .set_data_page_row_count_limit(*data_page_row_count_limit) + .set_bloom_filter_enabled(*bloom_filter_on_write); + + if let Some(bloom_filter_fpp) = bloom_filter_fpp { + builder = builder.set_bloom_filter_fpp(*bloom_filter_fpp); + }; + if let Some(bloom_filter_ndv) = bloom_filter_ndv { + builder = builder.set_bloom_filter_ndv(*bloom_filter_ndv); + }; + if let Some(dictionary_enabled) = dictionary_enabled { + builder = builder.set_dictionary_enabled(*dictionary_enabled); + }; + + // We do not have access to default ColumnProperties set in Arrow. + // Therefore, only overwrite if these settings exist. + if let Some(compression) = compression { + builder = builder.set_compression(parse_compression_string(compression)?); + } + if let Some(encoding) = encoding { + builder = builder.set_encoding(parse_encoding_string(encoding)?); + } + + Ok(builder) } } @@ -293,3 +368,374 @@ pub(crate) fn parse_statistics_string(str_setting: &str) -> Result ParquetColumnOptions { + ParquetColumnOptions { + compression: Some("zstd(22)".into()), + dictionary_enabled: src_col_defaults.dictionary_enabled.map(|v| !v), + statistics_enabled: Some("none".into()), + max_statistics_size: Some(72), + encoding: Some("RLE".into()), + bloom_filter_enabled: Some(true), + bloom_filter_fpp: Some(0.72), + bloom_filter_ndv: Some(72), + } + } + + fn parquet_options_with_non_defaults() -> ParquetOptions { + let defaults = ParquetOptions::default(); + let writer_version = if defaults.writer_version.eq("1.0") { + "2.0" + } else { + "1.0" + }; + + ParquetOptions { + data_pagesize_limit: 42, + write_batch_size: 42, + writer_version: writer_version.into(), + compression: Some("zstd(22)".into()), + dictionary_enabled: Some(!defaults.dictionary_enabled.unwrap_or(false)), + dictionary_page_size_limit: 42, + statistics_enabled: Some("chunk".into()), + max_statistics_size: Some(42), + max_row_group_size: 42, + created_by: "wordy".into(), + column_index_truncate_length: Some(42), + data_page_row_count_limit: 42, + encoding: Some("BYTE_STREAM_SPLIT".into()), + bloom_filter_on_write: !defaults.bloom_filter_on_write, + bloom_filter_fpp: Some(0.42), + bloom_filter_ndv: Some(42), + + // not in WriterProperties, but itemizing here to not skip newly added props + enable_page_index: defaults.enable_page_index, + pruning: defaults.pruning, + skip_metadata: defaults.skip_metadata, + metadata_size_hint: defaults.metadata_size_hint, + pushdown_filters: defaults.pushdown_filters, + reorder_filters: defaults.reorder_filters, + allow_single_file_parallelism: defaults.allow_single_file_parallelism, + maximum_parallel_row_group_writers: defaults + .maximum_parallel_row_group_writers, + maximum_buffered_record_batches_per_stream: defaults + .maximum_buffered_record_batches_per_stream, + bloom_filter_on_read: defaults.bloom_filter_on_read, + schema_force_view_types: defaults.schema_force_view_types, + binary_as_string: defaults.binary_as_string, + } + } + + fn extract_column_options( + props: &WriterProperties, + col: ColumnPath, + ) -> ParquetColumnOptions { + let bloom_filter_default_props = props.bloom_filter_properties(&col); + + ParquetColumnOptions { + bloom_filter_enabled: Some(bloom_filter_default_props.is_some()), + encoding: props.encoding(&col).map(|s| s.to_string()), + dictionary_enabled: Some(props.dictionary_enabled(&col)), + compression: match props.compression(&col) { + Compression::ZSTD(lvl) => { + Some(format!("zstd({})", lvl.compression_level())) + } + _ => None, + }, + statistics_enabled: Some( + match props.statistics_enabled(&col) { + EnabledStatistics::None => "none", + EnabledStatistics::Chunk => "chunk", + EnabledStatistics::Page => "page", + } + .into(), + ), + bloom_filter_fpp: bloom_filter_default_props.map(|p| p.fpp), + bloom_filter_ndv: bloom_filter_default_props.map(|p| p.ndv), + max_statistics_size: Some(props.max_statistics_size(&col)), + } + } + + /// For testing only, take a single write's props and convert back into the session config. + /// (use identity to confirm correct.) + fn session_config_from_writer_props(props: &WriterProperties) -> TableParquetOptions { + let default_col = ColumnPath::from("col doesn't have specific config"); + let default_col_props = extract_column_options(props, default_col); + + let configured_col = ColumnPath::from(COL_NAME); + let configured_col_props = extract_column_options(props, configured_col); + + let key_value_metadata = props + .key_value_metadata() + .map(|pairs| { + HashMap::from_iter( + pairs + .iter() + .cloned() + .map(|KeyValue { key, value }| (key, value)), + ) + }) + .unwrap_or_default(); + + let global_options_defaults = ParquetOptions::default(); + + let column_specific_options = if configured_col_props.eq(&default_col_props) { + HashMap::default() + } else { + HashMap::from([(COL_NAME.into(), configured_col_props)]) + }; + + TableParquetOptions { + global: ParquetOptions { + // global options + data_pagesize_limit: props.dictionary_page_size_limit(), + write_batch_size: props.write_batch_size(), + writer_version: format!("{}.0", props.writer_version().as_num()), + dictionary_page_size_limit: props.dictionary_page_size_limit(), + max_row_group_size: props.max_row_group_size(), + created_by: props.created_by().to_string(), + column_index_truncate_length: props.column_index_truncate_length(), + data_page_row_count_limit: props.data_page_row_count_limit(), + + // global options which set the default column props + encoding: default_col_props.encoding, + compression: default_col_props.compression, + dictionary_enabled: default_col_props.dictionary_enabled, + statistics_enabled: default_col_props.statistics_enabled, + max_statistics_size: default_col_props.max_statistics_size, + bloom_filter_on_write: default_col_props + .bloom_filter_enabled + .unwrap_or_default(), + bloom_filter_fpp: default_col_props.bloom_filter_fpp, + bloom_filter_ndv: default_col_props.bloom_filter_ndv, + + // not in WriterProperties + enable_page_index: global_options_defaults.enable_page_index, + pruning: global_options_defaults.pruning, + skip_metadata: global_options_defaults.skip_metadata, + metadata_size_hint: global_options_defaults.metadata_size_hint, + pushdown_filters: global_options_defaults.pushdown_filters, + reorder_filters: global_options_defaults.reorder_filters, + allow_single_file_parallelism: global_options_defaults + .allow_single_file_parallelism, + maximum_parallel_row_group_writers: global_options_defaults + .maximum_parallel_row_group_writers, + maximum_buffered_record_batches_per_stream: global_options_defaults + .maximum_buffered_record_batches_per_stream, + bloom_filter_on_read: global_options_defaults.bloom_filter_on_read, + schema_force_view_types: global_options_defaults.schema_force_view_types, + binary_as_string: global_options_defaults.binary_as_string, + }, + column_specific_options, + key_value_metadata, + } + } + + #[test] + fn table_parquet_opts_to_writer_props() { + // ParquetOptions, all props set to non-default + let parquet_options = parquet_options_with_non_defaults(); + + // TableParquetOptions, using ParquetOptions for global settings + let key = "foo".to_string(); + let value = Some("bar".into()); + let table_parquet_opts = TableParquetOptions { + global: parquet_options.clone(), + column_specific_options: [( + COL_NAME.into(), + column_options_with_non_defaults(&parquet_options), + )] + .into(), + key_value_metadata: [(key, value)].into(), + }; + + let writer_props = WriterPropertiesBuilder::try_from(&table_parquet_opts) + .unwrap() + .build(); + assert_eq!( + table_parquet_opts, + session_config_from_writer_props(&writer_props), + "the writer_props should have the same configuration as the session's TableParquetOptions", + ); + } + + /// Ensure that the configuration defaults for writing parquet files are + /// consistent with the options in arrow-rs + #[test] + fn test_defaults_match() { + // ensure the global settings are the same + let default_table_writer_opts = TableParquetOptions::default(); + let default_parquet_opts = ParquetOptions::default(); + assert_eq!( + default_table_writer_opts.global, + default_parquet_opts, + "should have matching defaults for TableParquetOptions.global and ParquetOptions", + ); + + // WriterProperties::default, a.k.a. using extern parquet's defaults + let default_writer_props = WriterProperties::new(); + + // WriterProperties::try_from(TableParquetOptions::default), a.k.a. using datafusion's defaults + let from_datafusion_defaults = + WriterPropertiesBuilder::try_from(&default_table_writer_opts) + .unwrap() + .build(); + + // Expected: how the defaults should not match + assert_ne!( + default_writer_props.created_by(), + from_datafusion_defaults.created_by(), + "should have different created_by sources", + ); + assert!( + default_writer_props.created_by().starts_with("parquet-rs version"), + "should indicate that writer_props defaults came from the extern parquet crate", + ); + assert!( + default_table_writer_opts + .global + .created_by + .starts_with("datafusion version"), + "should indicate that table_parquet_opts defaults came from datafusion", + ); + + // Expected: the datafusion default compression is different from arrow-rs's parquet + assert_eq!( + default_writer_props.compression(&"default".into()), + Compression::UNCOMPRESSED, + "extern parquet's default is None" + ); + assert!( + matches!( + from_datafusion_defaults.compression(&"default".into()), + Compression::ZSTD(_) + ), + "datafusion's default is zstd" + ); + + // Expected: the remaining should match + let same_created_by = default_table_writer_opts.global.created_by.clone(); + let mut from_extern_parquet = + session_config_from_writer_props(&default_writer_props); + from_extern_parquet.global.created_by = same_created_by; + from_extern_parquet.global.compression = Some("zstd(3)".into()); + + assert_eq!( + default_table_writer_opts, + from_extern_parquet, + "the default writer_props should have the same configuration as the session's default TableParquetOptions", + ); + } + + #[test] + fn test_bloom_filter_defaults() { + // the TableParquetOptions::default, with only the bloom filter turned on + let mut default_table_writer_opts = TableParquetOptions::default(); + default_table_writer_opts.global.bloom_filter_on_write = true; + let from_datafusion_defaults = + WriterPropertiesBuilder::try_from(&default_table_writer_opts) + .unwrap() + .build(); + + // the WriterProperties::default, with only the bloom filter turned on + let default_writer_props = WriterProperties::builder() + .set_bloom_filter_enabled(true) + .build(); + + assert_eq!( + default_writer_props.bloom_filter_properties(&"default".into()), + from_datafusion_defaults.bloom_filter_properties(&"default".into()), + "parquet and datafusion props, should have the same bloom filter props", + ); + assert_eq!( + default_writer_props.bloom_filter_properties(&"default".into()), + Some(&BloomFilterProperties::default()), + "should use the default bloom filter props" + ); + } + + #[test] + fn test_bloom_filter_set_fpp_only() { + // the TableParquetOptions::default, with only fpp set + let mut default_table_writer_opts = TableParquetOptions::default(); + default_table_writer_opts.global.bloom_filter_on_write = true; + default_table_writer_opts.global.bloom_filter_fpp = Some(0.42); + let from_datafusion_defaults = + WriterPropertiesBuilder::try_from(&default_table_writer_opts) + .unwrap() + .build(); + + // the WriterProperties::default, with only fpp set + let default_writer_props = WriterProperties::builder() + .set_bloom_filter_enabled(true) + .set_bloom_filter_fpp(0.42) + .build(); + + assert_eq!( + default_writer_props.bloom_filter_properties(&"default".into()), + from_datafusion_defaults.bloom_filter_properties(&"default".into()), + "parquet and datafusion props, should have the same bloom filter props", + ); + assert_eq!( + default_writer_props.bloom_filter_properties(&"default".into()), + Some(&BloomFilterProperties { + fpp: 0.42, + ndv: DEFAULT_BLOOM_FILTER_NDV + }), + "should have only the fpp set, and the ndv at default", + ); + } + + #[test] + fn test_bloom_filter_set_ndv_only() { + // the TableParquetOptions::default, with only ndv set + let mut default_table_writer_opts = TableParquetOptions::default(); + default_table_writer_opts.global.bloom_filter_on_write = true; + default_table_writer_opts.global.bloom_filter_ndv = Some(42); + let from_datafusion_defaults = + WriterPropertiesBuilder::try_from(&default_table_writer_opts) + .unwrap() + .build(); + + // the WriterProperties::default, with only ndv set + let default_writer_props = WriterProperties::builder() + .set_bloom_filter_enabled(true) + .set_bloom_filter_ndv(42) + .build(); + + assert_eq!( + default_writer_props.bloom_filter_properties(&"default".into()), + from_datafusion_defaults.bloom_filter_properties(&"default".into()), + "parquet and datafusion props, should have the same bloom filter props", + ); + assert_eq!( + default_writer_props.bloom_filter_properties(&"default".into()), + Some(&BloomFilterProperties { + fpp: DEFAULT_BLOOM_FILTER_FPP, + ndv: 42 + }), + "should have only the ndv set, and the fpp at default", + ); + } +} diff --git a/datafusion/common/src/functional_dependencies.rs b/datafusion/common/src/functional_dependencies.rs index d1c3747b52b4..31eafc744390 100644 --- a/datafusion/common/src/functional_dependencies.rs +++ b/datafusion/common/src/functional_dependencies.rs @@ -23,14 +23,11 @@ use std::fmt::{Display, Formatter}; use std::ops::Deref; use std::vec::IntoIter; -use crate::error::_plan_err; use crate::utils::{merge_and_order_indices, set_difference}; -use crate::{DFSchema, DFSchemaRef, DataFusionError, JoinType, Result}; - -use sqlparser::ast::TableConstraint; +use crate::{DFSchema, JoinType}; /// This object defines a constraint on a table. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum Constraint { /// Columns with the given indices form a composite primary key (they are /// jointly unique and not nullable): @@ -40,7 +37,7 @@ pub enum Constraint { } /// This object encapsulates a list of functional constraints: -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub struct Constraints { inner: Vec, } @@ -60,74 +57,6 @@ impl Constraints { Self { inner: constraints } } - /// Convert each `TableConstraint` to corresponding `Constraint` - pub fn new_from_table_constraints( - constraints: &[TableConstraint], - df_schema: &DFSchemaRef, - ) -> Result { - let constraints = constraints - .iter() - .map(|c: &TableConstraint| match c { - TableConstraint::Unique { name, columns, .. } => { - let field_names = df_schema.field_names(); - // Get unique constraint indices in the schema: - let indices = columns - .iter() - .map(|u| { - let idx = field_names - .iter() - .position(|item| *item == u.value) - .ok_or_else(|| { - let name = name - .as_ref() - .map(|name| format!("with name '{name}' ")) - .unwrap_or("".to_string()); - DataFusionError::Execution( - format!("Column for unique constraint {}not found in schema: {}", name,u.value) - ) - })?; - Ok(idx) - }) - .collect::>>()?; - Ok(Constraint::Unique(indices)) - } - TableConstraint::PrimaryKey { columns, .. } => { - let field_names = df_schema.field_names(); - // Get primary key indices in the schema: - let indices = columns - .iter() - .map(|pk| { - let idx = field_names - .iter() - .position(|item| *item == pk.value) - .ok_or_else(|| { - DataFusionError::Execution(format!( - "Column for primary key not found in schema: {}", - pk.value - )) - })?; - Ok(idx) - }) - .collect::>>()?; - Ok(Constraint::PrimaryKey(indices)) - } - TableConstraint::ForeignKey { .. } => { - _plan_err!("Foreign key constraints are not currently supported") - } - TableConstraint::Check { .. } => { - _plan_err!("Check constraints are not currently supported") - } - TableConstraint::Index { .. } => { - _plan_err!("Indexes are not currently supported") - } - TableConstraint::FulltextOrSpatial { .. } => { - _plan_err!("Indexes are not currently supported") - } - }) - .collect::>>()?; - Ok(Constraints::new_unverified(constraints)) - } - /// Check whether constraints is empty pub fn is_empty(&self) -> bool { self.inner.is_empty() @@ -405,7 +334,7 @@ impl FunctionalDependencies { left_func_dependencies.extend(right_func_dependencies); left_func_dependencies } - JoinType::LeftSemi | JoinType::LeftAnti => { + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { // These joins preserve functional dependencies of the left side: left_func_dependencies } @@ -433,7 +362,7 @@ impl FunctionalDependencies { } /// This function ensures that functional dependencies involving uniquely - /// occuring determinant keys cover their entire table in terms of + /// occurring determinant keys cover their entire table in terms of /// dependent columns. pub fn extend_target_indices(&mut self, n_out: usize) { self.deps.iter_mut().for_each( @@ -524,22 +453,31 @@ pub fn aggregate_functional_dependencies( } } - // If we have a single GROUP BY key, we can guarantee uniqueness after + // When we have a GROUP BY key, we can guarantee uniqueness after // aggregation: - if group_by_expr_names.len() == 1 { - // If `source_indices` contain 0, delete this functional dependency - // as it will be added anyway with mode `Dependency::Single`: - aggregate_func_dependencies.retain(|item| !item.source_indices.contains(&0)); - // Add a new functional dependency associated with the whole table: - aggregate_func_dependencies.push( - // Use nullable property of the group by expression - FunctionalDependence::new( - vec![0], - target_indices, - aggr_fields[0].is_nullable(), - ) - .with_mode(Dependency::Single), - ); + if !group_by_expr_names.is_empty() { + let count = group_by_expr_names.len(); + let source_indices = (0..count).collect::>(); + let nullable = source_indices + .iter() + .any(|idx| aggr_fields[*idx].is_nullable()); + // If GROUP BY expressions do not already act as a determinant: + if !aggregate_func_dependencies.iter().any(|item| { + // If `item.source_indices` is a subset of GROUP BY expressions, we shouldn't add + // them since `item.source_indices` defines this relation already. + + // The following simple comparison is working well because + // GROUP BY expressions come here as a prefix. + item.source_indices.iter().all(|idx| idx < &count) + }) { + // Add a new functional dependency associated with the whole table: + // Use nullable property of the GROUP BY expression: + aggregate_func_dependencies.push( + // Use nullable property of the GROUP BY expression: + FunctionalDependence::new(source_indices, target_indices, nullable) + .with_mode(Dependency::Single), + ); + } } FunctionalDependencies::new(aggregate_func_dependencies) } diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs index 9819fc7b344d..8bd646626e06 100644 --- a/datafusion/common/src/hash_utils.rs +++ b/datafusion/common/src/hash_utils.rs @@ -17,28 +17,35 @@ //! Functionality used both on logical and physical plans +#[cfg(not(feature = "force_hash_collisions"))] use std::sync::Arc; use ahash::RandomState; use arrow::array::*; use arrow::datatypes::*; -use arrow::row::Rows; +#[cfg(not(feature = "force_hash_collisions"))] use arrow::{downcast_dictionary_array, downcast_primitive_array}; +use arrow_buffer::IntervalDayTime; +use arrow_buffer::IntervalMonthDayNano; +#[cfg(not(feature = "force_hash_collisions"))] use crate::cast::{ - as_boolean_array, as_fixed_size_list_array, as_generic_binary_array, - as_large_list_array, as_list_array, as_primitive_array, as_string_array, - as_struct_array, + as_binary_view_array, as_boolean_array, as_fixed_size_list_array, + as_generic_binary_array, as_large_list_array, as_list_array, as_map_array, + as_primitive_array, as_string_array, as_string_view_array, as_struct_array, }; -use crate::error::{Result, _internal_err}; +use crate::error::Result; +#[cfg(not(feature = "force_hash_collisions"))] +use crate::error::_internal_err; // Combines two hashes into one hash #[inline] -fn combine_hashes(l: u64, r: u64) -> u64 { +pub fn combine_hashes(l: u64, r: u64) -> u64 { let hash = (17 * 37u64).wrapping_add(l); hash.wrapping_mul(37).wrapping_add(r) } +#[cfg(not(feature = "force_hash_collisions"))] fn hash_null(random_state: &RandomState, hashes_buffer: &'_ mut [u64], mul_col: bool) { if mul_col { hashes_buffer.iter_mut().for_each(|hash| { @@ -72,7 +79,7 @@ macro_rules! hash_value { }; } hash_value!(i8, i16, i32, i64, i128, i256, u8, u16, u32, u64); -hash_value!(bool, str, [u8]); +hash_value!(bool, str, [u8], IntervalDayTime, IntervalMonthDayNano); macro_rules! hash_float_value { ($(($t:ty, $i:ty)),+) => { @@ -88,14 +95,14 @@ hash_float_value!((half::f16, u16), (f32, u32), (f64, u64)); /// Builds hash values of PrimitiveArray and writes them into `hashes_buffer` /// If `rehash==true` this combines the previous hash value in the buffer /// with the new hash using `combine_hashes` +#[cfg(not(feature = "force_hash_collisions"))] fn hash_array_primitive( array: &PrimitiveArray, random_state: &RandomState, hashes_buffer: &mut [u64], rehash: bool, ) where - T: ArrowPrimitiveType, - ::Native: HashValue, + T: ArrowPrimitiveType, { assert_eq!( hashes_buffer.len(), @@ -133,6 +140,7 @@ fn hash_array_primitive( /// Hashes one array into the `hashes_buffer` /// If `rehash==true` this combines the previous hash value in the buffer /// with the new hash using `combine_hashes` +#[cfg(not(feature = "force_hash_collisions"))] fn hash_array( array: T, random_state: &RandomState, @@ -178,6 +186,7 @@ fn hash_array( } /// Hash the values in a dictionary array +#[cfg(not(feature = "force_hash_collisions"))] fn hash_dictionary( array: &DictionaryArray, random_state: &RandomState, @@ -208,6 +217,7 @@ fn hash_dictionary( Ok(()) } +#[cfg(not(feature = "force_hash_collisions"))] fn hash_struct_array( array: &StructArray, random_state: &RandomState, @@ -234,6 +244,43 @@ fn hash_struct_array( Ok(()) } +// only adding this `cfg` b/c this function is only used with this `cfg` +#[cfg(not(feature = "force_hash_collisions"))] +fn hash_map_array( + array: &MapArray, + random_state: &RandomState, + hashes_buffer: &mut [u64], +) -> Result<()> { + let nulls = array.nulls(); + let offsets = array.offsets(); + + // Create hashes for each entry in each row + let mut values_hashes = vec![0u64; array.entries().len()]; + create_hashes(array.entries().columns(), random_state, &mut values_hashes)?; + + // Combine the hashes for entries on each row with each other and previous hash for that row + if let Some(nulls) = nulls { + for (i, (start, stop)) in offsets.iter().zip(offsets.iter().skip(1)).enumerate() { + if nulls.is_valid(i) { + let hash = &mut hashes_buffer[i]; + for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] { + *hash = combine_hashes(*hash, *values_hash); + } + } + } + } else { + for (i, (start, stop)) in offsets.iter().zip(offsets.iter().skip(1)).enumerate() { + let hash = &mut hashes_buffer[i]; + for values_hash in &values_hashes[start.as_usize()..stop.as_usize()] { + *hash = combine_hashes(*hash, *values_hash); + } + } + } + + Ok(()) +} + +#[cfg(not(feature = "force_hash_collisions"))] fn hash_list_array( array: &GenericListArray, random_state: &RandomState, @@ -242,7 +289,7 @@ fn hash_list_array( where OffsetSize: OffsetSizeTrait, { - let values = array.values().clone(); + let values = Arc::clone(array.values()); let offsets = array.value_offsets(); let nulls = array.nulls(); let mut values_hashes = vec![0u64; values.len()]; @@ -267,12 +314,13 @@ where Ok(()) } +#[cfg(not(feature = "force_hash_collisions"))] fn hash_fixed_list_array( array: &FixedSizeListArray, random_state: &RandomState, hashes_buffer: &mut [u64], ) -> Result<()> { - let values = array.values().clone(); + let values = Arc::clone(array.values()); let value_len = array.value_length(); let offset_size = value_len as usize / array.len(); let nulls = array.nulls(); @@ -315,38 +363,6 @@ pub fn create_hashes<'a>( Ok(hashes_buffer) } -/// Test version of `create_row_hashes` that produces the same value for -/// all hashes (to test collisions) -/// -/// See comments on `hashes_buffer` for more details -#[cfg(feature = "force_hash_collisions")] -pub fn create_row_hashes<'a>( - _rows: &[Vec], - _random_state: &RandomState, - hashes_buffer: &'a mut Vec, -) -> Result<&'a mut Vec> { - for hash in hashes_buffer.iter_mut() { - *hash = 0 - } - Ok(hashes_buffer) -} - -/// Creates hash values for every row, based on their raw bytes. -#[cfg(not(feature = "force_hash_collisions"))] -pub fn create_row_hashes<'a>( - rows: &[Vec], - random_state: &RandomState, - hashes_buffer: &'a mut Vec, -) -> Result<&'a mut Vec> { - for hash in hashes_buffer.iter_mut() { - *hash = 0 - } - for (i, hash) in hashes_buffer.iter_mut().enumerate() { - *hash = random_state.hash_one(&rows[i]); - } - Ok(hashes_buffer) -} - /// Creates hash values for every row, based on the values in the /// columns. /// @@ -367,8 +383,10 @@ pub fn create_hashes<'a>( DataType::Null => hash_null(random_state, hashes_buffer, rehash), DataType::Boolean => hash_array(as_boolean_array(array)?, random_state, hashes_buffer, rehash), DataType::Utf8 => hash_array(as_string_array(array)?, random_state, hashes_buffer, rehash), + DataType::Utf8View => hash_array(as_string_view_array(array)?, random_state, hashes_buffer, rehash), DataType::LargeUtf8 => hash_array(as_largestring_array(array), random_state, hashes_buffer, rehash), DataType::Binary => hash_array(as_generic_binary_array::(array)?, random_state, hashes_buffer, rehash), + DataType::BinaryView => hash_array(as_binary_view_array(array)?, random_state, hashes_buffer, rehash), DataType::LargeBinary => hash_array(as_generic_binary_array::(array)?, random_state, hashes_buffer, rehash), DataType::FixedSizeBinary(_) => { let array: &FixedSizeBinaryArray = array.as_any().downcast_ref().unwrap(); @@ -398,6 +416,10 @@ pub fn create_hashes<'a>( let array = as_large_list_array(array)?; hash_list_array(array, random_state, hashes_buffer)?; } + DataType::Map(_, _) => { + let array = as_map_array(array)?; + hash_map_array(array, random_state, hashes_buffer)?; + } DataType::FixedSizeList(_,_) => { let array = as_fixed_size_list_array(array)?; hash_fixed_list_array(array, random_state, hashes_buffer)?; @@ -414,41 +436,13 @@ pub fn create_hashes<'a>( Ok(hashes_buffer) } -/// Test version of `create_row_hashes_v2` that produces the same value for -/// all hashes (to test collisions) -/// -/// See comments on `hashes_buffer` for more details -#[cfg(feature = "force_hash_collisions")] -pub fn create_row_hashes_v2<'a>( - _rows: &Rows, - _random_state: &RandomState, - hashes_buffer: &'a mut Vec, -) -> Result<&'a mut Vec> { - for hash in hashes_buffer.iter_mut() { - *hash = 0 - } - Ok(hashes_buffer) -} - -/// Creates hash values for every row, based on their raw bytes. -#[cfg(not(feature = "force_hash_collisions"))] -pub fn create_row_hashes_v2<'a>( - rows: &Rows, - random_state: &RandomState, - hashes_buffer: &'a mut Vec, -) -> Result<&'a mut Vec> { - for hash in hashes_buffer.iter_mut() { - *hash = 0 - } - for (i, hash) in hashes_buffer.iter_mut().enumerate() { - *hash = random_state.hash_one(rows.row(i)); - } - Ok(hashes_buffer) -} - #[cfg(test)] mod tests { - use arrow::{array::*, datatypes::*}; + use std::sync::Arc; + + use arrow::array::*; + #[cfg(not(feature = "force_hash_collisions"))] + use arrow::datatypes::*; use super::*; @@ -484,22 +478,57 @@ mod tests { Ok(()) } - #[test] - fn create_hashes_binary() -> Result<()> { - let byte_array = Arc::new(BinaryArray::from_vec(vec![ - &[4, 3, 2], - &[4, 3, 2], - &[1, 2, 3], - ])); + macro_rules! create_hash_binary { + ($NAME:ident, $ARRAY:ty) => { + #[cfg(not(feature = "force_hash_collisions"))] + #[test] + fn $NAME() { + let binary = [ + Some(b"short".to_byte_slice()), + None, + Some(b"long but different 12 bytes string"), + Some(b"short2"), + Some(b"Longer than 12 bytes string"), + Some(b"short"), + Some(b"Longer than 12 bytes string"), + ]; + + let binary_array = Arc::new(binary.iter().cloned().collect::<$ARRAY>()); + let ref_array = Arc::new(binary.iter().cloned().collect::()); + + let random_state = RandomState::with_seeds(0, 0, 0, 0); + + let mut binary_hashes = vec![0; binary.len()]; + create_hashes(&[binary_array], &random_state, &mut binary_hashes) + .unwrap(); + + let mut ref_hashes = vec![0; binary.len()]; + create_hashes(&[ref_array], &random_state, &mut ref_hashes).unwrap(); + + // Null values result in a zero hash, + for (val, hash) in binary.iter().zip(binary_hashes.iter()) { + match val { + Some(_) => assert_ne!(*hash, 0), + None => assert_eq!(*hash, 0), + } + } - let random_state = RandomState::with_seeds(0, 0, 0, 0); - let hashes_buff = &mut vec![0; byte_array.len()]; - let hashes = create_hashes(&[byte_array], &random_state, hashes_buff)?; - assert_eq!(hashes.len(), 3,); + // same logical values should hash to the same hash value + assert_eq!(binary_hashes, ref_hashes); - Ok(()) + // Same values should map to same hash values + assert_eq!(binary[0], binary[5]); + assert_eq!(binary[4], binary[6]); + + // different binary should map to different hash values + assert_ne!(binary[0], binary[2]); + } + }; } + create_hash_binary!(binary_array, BinaryArray); + create_hash_binary!(binary_view_array, BinaryViewArray); + #[test] fn create_hashes_fixed_size_binary() -> Result<()> { let input_arg = vec![vec![1, 2], vec![5, 6], vec![5, 6]]; @@ -515,6 +544,64 @@ mod tests { Ok(()) } + macro_rules! create_hash_string { + ($NAME:ident, $ARRAY:ty) => { + #[cfg(not(feature = "force_hash_collisions"))] + #[test] + fn $NAME() { + let strings = [ + Some("short"), + None, + Some("long but different 12 bytes string"), + Some("short2"), + Some("Longer than 12 bytes string"), + Some("short"), + Some("Longer than 12 bytes string"), + ]; + + let string_array = Arc::new(strings.iter().cloned().collect::<$ARRAY>()); + let dict_array = Arc::new( + strings + .iter() + .cloned() + .collect::>(), + ); + + let random_state = RandomState::with_seeds(0, 0, 0, 0); + + let mut string_hashes = vec![0; strings.len()]; + create_hashes(&[string_array], &random_state, &mut string_hashes) + .unwrap(); + + let mut dict_hashes = vec![0; strings.len()]; + create_hashes(&[dict_array], &random_state, &mut dict_hashes).unwrap(); + + // Null values result in a zero hash, + for (val, hash) in strings.iter().zip(string_hashes.iter()) { + match val { + Some(_) => assert_ne!(*hash, 0), + None => assert_eq!(*hash, 0), + } + } + + // same logical values should hash to the same hash value + assert_eq!(string_hashes, dict_hashes); + + // Same values should map to same hash values + assert_eq!(strings[0], strings[5]); + assert_eq!(strings[4], strings[6]); + + // different strings should map to different hash values + assert_ne!(strings[0], strings[2]); + } + }; + } + + create_hash_string!(string_array, StringArray); + create_hash_string!(large_string_array, LargeStringArray); + create_hash_string!(string_view_array, StringArray); + create_hash_string!(dict_string_array, DictionaryArray); + #[test] // Tests actual values of hashes, which are different if forcing collisions #[cfg(not(feature = "force_hash_collisions"))] @@ -570,6 +657,7 @@ mod tests { Some(vec![Some(3), None, Some(5)]), None, Some(vec![Some(0), Some(1), Some(2)]), + Some(vec![]), ]; let list_array = Arc::new(ListArray::from_iter_primitive::(data)) as ArrayRef; @@ -579,6 +667,7 @@ mod tests { assert_eq!(hashes[0], hashes[5]); assert_eq!(hashes[1], hashes[4]); assert_eq!(hashes[2], hashes[3]); + assert_eq!(hashes[1], hashes[6]); // null vs empty list } #[test] @@ -620,19 +709,19 @@ mod tests { vec![ ( Arc::new(Field::new("bool", DataType::Boolean, false)), - boolarr.clone() as ArrayRef, + Arc::clone(&boolarr) as ArrayRef, ), ( Arc::new(Field::new("i32", DataType::Int32, false)), - i32arr.clone() as ArrayRef, + Arc::clone(&i32arr) as ArrayRef, ), ( Arc::new(Field::new("i32", DataType::Int32, false)), - i32arr.clone() as ArrayRef, + Arc::clone(&i32arr) as ArrayRef, ), ( Arc::new(Field::new("bool", DataType::Boolean, false)), - boolarr.clone() as ArrayRef, + Arc::clone(&boolarr) as ArrayRef, ), ], Buffer::from(&[0b001011]), @@ -690,6 +779,64 @@ mod tests { assert_eq!(hashes[0], hashes[1]); } + #[test] + // Tests actual values of hashes, which are different if forcing collisions + #[cfg(not(feature = "force_hash_collisions"))] + fn create_hashes_for_map_arrays() { + let mut builder = + MapBuilder::new(None, StringBuilder::new(), Int32Builder::new()); + // Row 0 + builder.keys().append_value("key1"); + builder.keys().append_value("key2"); + builder.values().append_value(1); + builder.values().append_value(2); + builder.append(true).unwrap(); + // Row 1 + builder.keys().append_value("key1"); + builder.keys().append_value("key2"); + builder.values().append_value(1); + builder.values().append_value(2); + builder.append(true).unwrap(); + // Row 2 + builder.keys().append_value("key1"); + builder.keys().append_value("key2"); + builder.values().append_value(1); + builder.values().append_value(3); + builder.append(true).unwrap(); + // Row 3 + builder.keys().append_value("key1"); + builder.keys().append_value("key3"); + builder.values().append_value(1); + builder.values().append_value(2); + builder.append(true).unwrap(); + // Row 4 + builder.keys().append_value("key1"); + builder.values().append_value(1); + builder.append(true).unwrap(); + // Row 5 + builder.keys().append_value("key1"); + builder.values().append_null(); + builder.append(true).unwrap(); + // Row 6 + builder.append(true).unwrap(); + // Row 7 + builder.keys().append_value("key1"); + builder.values().append_value(1); + builder.append(false).unwrap(); + + let array = Arc::new(builder.finish()) as ArrayRef; + + let random_state = RandomState::with_seeds(0, 0, 0, 0); + let mut hashes = vec![0; array.len()]; + create_hashes(&[array], &random_state, &mut hashes).unwrap(); + assert_eq!(hashes[0], hashes[1]); // same value + assert_ne!(hashes[0], hashes[2]); // different value + assert_ne!(hashes[0], hashes[3]); // different key + assert_ne!(hashes[0], hashes[4]); // missing an entry + assert_ne!(hashes[4], hashes[5]); // filled vs null value + assert_eq!(hashes[6], hashes[7]); // empty vs null map + } + #[test] // Tests actual values of hashes, which are different if forcing collisions #[cfg(not(feature = "force_hash_collisions"))] @@ -708,7 +855,12 @@ mod tests { let random_state = RandomState::with_seeds(0, 0, 0, 0); let mut one_col_hashes = vec![0; strings1.len()]; - create_hashes(&[dict_array.clone()], &random_state, &mut one_col_hashes).unwrap(); + create_hashes( + &[Arc::clone(&dict_array) as ArrayRef], + &random_state, + &mut one_col_hashes, + ) + .unwrap(); let mut two_col_hashes = vec![0; strings1.len()]; create_hashes( diff --git a/datafusion/common/src/join_type.rs b/datafusion/common/src/join_type.rs index 0a00a57ba45f..e98f34199b27 100644 --- a/datafusion/common/src/join_type.rs +++ b/datafusion/common/src/join_type.rs @@ -26,7 +26,7 @@ use crate::error::_not_impl_err; use crate::{DataFusionError, Result}; /// Join type -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)] pub enum JoinType { /// Inner Join Inner, @@ -44,6 +44,20 @@ pub enum JoinType { LeftAnti, /// Right Anti Join RightAnti, + /// Left Mark join + /// + /// Returns one record for each record from the left input. The output contains an additional + /// column "mark" which is true if there is at least one match in the right input where the + /// join condition evaluates to true. Otherwise, the mark column is false. For more details see + /// [1]. This join type is used to decorrelate EXISTS subqueries used inside disjunctive + /// predicates. + /// + /// Note: This we currently do not implement the full null semantics for the mark join described + /// in [1] which will be needed if we and ANY subqueries. In our version the mark column will + /// only be true for had a match and false when no match was found, never null. + /// + /// [1]: http://btw2017.informatik.uni-stuttgart.de/slidesandpapers/F1-10-37/paper_web.pdf + LeftMark, } impl JoinType { @@ -63,6 +77,7 @@ impl Display for JoinType { JoinType::RightSemi => "RightSemi", JoinType::LeftAnti => "LeftAnti", JoinType::RightAnti => "RightAnti", + JoinType::LeftMark => "LeftMark", }; write!(f, "{join_type}") } @@ -82,13 +97,14 @@ impl FromStr for JoinType { "RIGHTSEMI" => Ok(JoinType::RightSemi), "LEFTANTI" => Ok(JoinType::LeftAnti), "RIGHTANTI" => Ok(JoinType::RightAnti), + "LEFTMARK" => Ok(JoinType::LeftMark), _ => _not_impl_err!("The join type {s} does not exist or is not implemented"), } } } /// Join constraint -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)] pub enum JoinConstraint { /// Join ON On, @@ -97,10 +113,11 @@ pub enum JoinConstraint { } impl Display for JoinSide { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self { JoinSide::Left => write!(f, "left"), JoinSide::Right => write!(f, "right"), + JoinSide::None => write!(f, "none"), } } } @@ -113,6 +130,9 @@ pub enum JoinSide { Left, /// Right side of the join Right, + /// Neither side of the join, used for Mark joins where the mark column does not belong to + /// either side of the join + None, } impl JoinSide { @@ -121,6 +141,7 @@ impl JoinSide { match self { JoinSide::Left => JoinSide::Right, JoinSide::Right => JoinSide::Left, + JoinSide::None => JoinSide::None, } } } diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index e64acd0bfefe..08431a36e82f 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -14,10 +14,11 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] mod column; mod dfschema; -mod error; mod functional_dependencies; mod join_type; mod param_value; @@ -30,7 +31,9 @@ mod unnest; pub mod alias; pub mod cast; pub mod config; +pub mod cse; pub mod display; +pub mod error; pub mod file_options; pub mod format; pub mod hash_utils; @@ -41,6 +44,7 @@ pub mod scalar; pub mod stats; pub mod test_util; pub mod tree_node; +pub mod types; pub mod utils; /// Reexport arrow crate @@ -54,8 +58,8 @@ pub use error::{ SharedResult, }; pub use file_options::file_type::{ - FileType, GetExt, DEFAULT_ARROW_EXTENSION, DEFAULT_AVRO_EXTENSION, - DEFAULT_CSV_EXTENSION, DEFAULT_JSON_EXTENSION, DEFAULT_PARQUET_EXTENSION, + GetExt, DEFAULT_ARROW_EXTENSION, DEFAULT_AVRO_EXTENSION, DEFAULT_CSV_EXTENSION, + DEFAULT_JSON_EXTENSION, DEFAULT_PARQUET_EXTENSION, }; pub use functional_dependencies::{ aggregate_functional_dependencies, get_required_group_by_exprs_indices, @@ -68,9 +72,21 @@ pub use scalar::{ScalarType, ScalarValue}; pub use schema_reference::SchemaReference; pub use stats::{ColumnStatistics, Statistics}; pub use table_reference::{ResolvedTableReference, TableReference}; -pub use unnest::UnnestOptions; +pub use unnest::{RecursionUnnestOption, UnnestOptions}; pub use utils::project_schema; +// These are hidden from docs purely to avoid polluting the public view of what this crate exports. +// These are just re-exports of macros by the same name, which gets around the 'cannot refer to +// macro-expanded macro_export macros by their full path' error. +// The design to get around this comes from this comment: +// https://github.com/rust-lang/rust/pull/52234#issuecomment-976702997 +#[doc(hidden)] +pub use error::{ + _config_datafusion_err, _exec_datafusion_err, _internal_datafusion_err, + _not_impl_datafusion_err, _plan_datafusion_err, _resources_datafusion_err, + _substrait_datafusion_err, +}; + /// Downcast an Arrow Array to a concrete type, return an `DataFusionError::Internal` if the cast is /// not possible. In normal usage of DataFusion the downcast should always succeed. /// diff --git a/datafusion/common/src/parsers.rs b/datafusion/common/src/parsers.rs index e23edb4e2adb..c73c8a55f18c 100644 --- a/datafusion/common/src/parsers.rs +++ b/datafusion/common/src/parsers.rs @@ -18,7 +18,6 @@ //! Interval parsing logic use std::fmt::Display; -use std::result; use std::str::FromStr; use sqlparser::parser::ParserError; @@ -41,7 +40,7 @@ pub enum CompressionTypeVariant { impl FromStr for CompressionTypeVariant { type Err = ParserError; - fn from_str(s: &str) -> result::Result { + fn from_str(s: &str) -> Result { let s = s.to_uppercase(); match s.as_str() { "GZIP" | "GZ" => Ok(Self::GZIP), diff --git a/datafusion/common/src/pyarrow.rs b/datafusion/common/src/pyarrow.rs index f4356477532f..bdcf831c7884 100644 --- a/datafusion/common/src/pyarrow.rs +++ b/datafusion/common/src/pyarrow.rs @@ -22,8 +22,8 @@ use arrow::pyarrow::{FromPyArrow, ToPyArrow}; use arrow_array::Array; use pyo3::exceptions::PyException; use pyo3::prelude::PyErr; -use pyo3::types::PyList; -use pyo3::{FromPyObject, IntoPy, PyAny, PyObject, PyResult, Python}; +use pyo3::types::{PyAnyMethods, PyList}; +use pyo3::{Bound, FromPyObject, IntoPy, PyAny, PyObject, PyResult, Python}; use crate::{DataFusionError, ScalarValue}; @@ -34,18 +34,18 @@ impl From for PyErr { } impl FromPyArrow for ScalarValue { - fn from_pyarrow(value: &PyAny) -> PyResult { + fn from_pyarrow_bound(value: &Bound<'_, PyAny>) -> PyResult { let py = value.py(); let typ = value.getattr("type")?; let val = value.call_method0("as_py")?; // construct pyarrow array from the python value and pyarrow type - let factory = py.import("pyarrow")?.getattr("array")?; - let args = PyList::new(py, [val]); + let factory = py.import_bound("pyarrow")?.getattr("array")?; + let args = PyList::new_bound(py, [val]); let array = factory.call1((args, typ))?; // convert the pyarrow array to rust array using C data interface - let array = arrow::array::make_array(ArrayData::from_pyarrow(array)?); + let array = arrow::array::make_array(ArrayData::from_pyarrow_bound(&array)?); let scalar = ScalarValue::try_from_array(&array, 0)?; Ok(scalar) @@ -64,8 +64,8 @@ impl ToPyArrow for ScalarValue { } impl<'source> FromPyObject<'source> for ScalarValue { - fn extract(value: &'source PyAny) -> PyResult { - Self::from_pyarrow(value) + fn extract_bound(value: &Bound<'source, PyAny>) -> PyResult { + Self::from_pyarrow_bound(value) } } @@ -86,19 +86,19 @@ mod tests { fn init_python() { prepare_freethreaded_python(); Python::with_gil(|py| { - if py.run("import pyarrow", None, None).is_err() { - let locals = PyDict::new(py); - py.run( + if py.run_bound("import pyarrow", None, None).is_err() { + let locals = PyDict::new_bound(py); + py.run_bound( "import sys; executable = sys.executable; python_path = sys.path", None, - Some(locals), + Some(&locals), ) .expect("Couldn't get python info"); - let executable = locals.get_item("executable").unwrap().unwrap(); + let executable = locals.get_item("executable").unwrap(); let executable: String = executable.extract().unwrap(); - let python_path = locals.get_item("python_path").unwrap().unwrap(); - let python_path: Vec<&str> = python_path.extract().unwrap(); + let python_path = locals.get_item("python_path").unwrap(); + let python_path: Vec = python_path.extract().unwrap(); panic!("pyarrow not found\nExecutable: {executable}\nPython path: {python_path:?}\n\ HINT: try `pip install pyarrow`\n\ @@ -125,9 +125,10 @@ mod tests { Python::with_gil(|py| { for scalar in example_scalars.iter() { - let result = - ScalarValue::from_pyarrow(scalar.to_pyarrow(py).unwrap().as_ref(py)) - .unwrap(); + let result = ScalarValue::from_pyarrow_bound( + scalar.to_pyarrow(py).unwrap().bind(py), + ) + .unwrap(); assert_eq!(scalar, &result); } }); diff --git a/datafusion/common/src/scalar/consts.rs b/datafusion/common/src/scalar/consts.rs new file mode 100644 index 000000000000..efcde651841b --- /dev/null +++ b/datafusion/common/src/scalar/consts.rs @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Constants defined for scalar construction. + +// PI ~ 3.1415927 in f32 +#[allow(clippy::approx_constant)] +pub(super) const PI_UPPER_F32: f32 = 3.141593_f32; + +// PI ~ 3.141592653589793 in f64 +pub(super) const PI_UPPER_F64: f64 = 3.141592653589794_f64; + +// -PI ~ -3.1415927 in f32 +#[allow(clippy::approx_constant)] +pub(super) const NEGATIVE_PI_LOWER_F32: f32 = -3.141593_f32; + +// -PI ~ -3.141592653589793 in f64 +pub(super) const NEGATIVE_PI_LOWER_F64: f64 = -3.141592653589794_f64; + +// PI / 2 ~ 1.5707964 in f32 +pub(super) const FRAC_PI_2_UPPER_F32: f32 = 1.5707965_f32; + +// PI / 2 ~ 1.5707963267948966 in f64 +pub(super) const FRAC_PI_2_UPPER_F64: f64 = 1.5707963267948967_f64; + +// -PI / 2 ~ -1.5707964 in f32 +pub(super) const NEGATIVE_FRAC_PI_2_LOWER_F32: f32 = -1.5707965_f32; + +// -PI / 2 ~ -1.5707963267948966 in f64 +pub(super) const NEGATIVE_FRAC_PI_2_LOWER_F64: f64 = -1.5707963267948967_f64; diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index e71d82fb3beb..5595f4f9fa70 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -17,6 +17,7 @@ //! [`ScalarValue`]: stores single values +mod consts; mod struct_builder; use std::borrow::Borrow; @@ -25,15 +26,18 @@ use std::collections::{HashSet, VecDeque}; use std::convert::Infallible; use std::fmt; use std::hash::Hash; +use std::hash::Hasher; use std::iter::repeat; +use std::mem::{size_of, size_of_val}; use std::str::FromStr; use std::sync::Arc; +use crate::arrow_datafusion_err; use crate::cast::{ as_decimal128_array, as_decimal256_array, as_dictionary_array, as_fixed_size_binary_array, as_fixed_size_list_array, }; -use crate::error::{DataFusionError, Result, _internal_err, _not_impl_err}; +use crate::error::{DataFusionError, Result, _exec_err, _internal_err, _not_impl_err}; use crate::hash_utils::create_hashes; use crate::utils::{ array_into_fixed_size_list_array, array_into_large_list_array, array_into_list_array, @@ -45,16 +49,18 @@ use arrow::{ compute::kernels::cast::{cast_with_options, CastOptions}, datatypes::{ i256, ArrowDictionaryKeyType, ArrowNativeType, ArrowTimestampType, DataType, - Date32Type, Field, Float32Type, Int16Type, Int32Type, Int64Type, Int8Type, - IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, + Date32Type, Date64Type, Field, Float32Type, Int16Type, Int32Type, Int64Type, + Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, DECIMAL128_MAX_PRECISION, }, }; -use arrow_buffer::Buffer; +use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano, ScalarBuffer}; use arrow_schema::{UnionFields, UnionMode}; +use crate::format::DEFAULT_CAST_OPTIONS; +use half::f16; pub use struct_builder::ScalarStructBuilder; /// A dynamically typed, nullable single value. @@ -100,7 +106,7 @@ pub use struct_builder::ScalarStructBuilder; /// /// `ScalarValue` represents null values in the same way as Arrow. Nulls are /// "typed" in the sense that a null value in an [`Int32Array`] is different -/// than a null value in a [`Float64Array`], and is different than the values in +/// from a null value in a [`Float64Array`], and is different from the values in /// a [`NullArray`]. /// /// ``` @@ -124,7 +130,7 @@ pub use struct_builder::ScalarStructBuilder; /// /// # Nested Types /// -/// `List` / `LargeList` / `FixedSizeList` / `Struct` are represented as a +/// `List` / `LargeList` / `FixedSizeList` / `Struct` / `Map` are represented as a /// single element array of the corresponding type. /// /// ## Example: Creating [`ScalarValue::Struct`] using [`ScalarStructBuilder`] @@ -192,6 +198,8 @@ pub enum ScalarValue { Null, /// true or false value Boolean(Option), + /// 16bit float + Float16(Option), /// 32bit float Float32(Option), /// 64bit float @@ -218,10 +226,14 @@ pub enum ScalarValue { UInt64(Option), /// utf-8 encoded string. Utf8(Option), + /// utf-8 encoded string but from view types. + Utf8View(Option), /// utf-8 encoded string representing a LargeString's arrow type. LargeUtf8(Option), /// binary Binary(Option>), + /// binary but from view types. + BinaryView(Option>), /// fixed size binary FixedSizeBinary(i32, Option>), /// large binary @@ -239,6 +251,8 @@ pub enum ScalarValue { /// Represents a single element [`StructArray`] as an [`ArrayRef`]. See /// [`ScalarValue`] for examples of how to create instances of this type. Struct(Arc), + /// Represents a single element [`MapArray`] as an [`ArrayRef`]. + Map(Arc), /// Date stored as a signed 32bit int days since UNIX epoch 1970-01-01 Date32(Option), /// Date stored as a signed 64bit int milliseconds since UNIX epoch 1970-01-01 @@ -263,11 +277,11 @@ pub enum ScalarValue { IntervalYearMonth(Option), /// Number of elapsed days and milliseconds (no leap seconds) /// stored as 2 contiguous 32-bit signed integers - IntervalDayTime(Option), + IntervalDayTime(Option), /// A triple of the number of elapsed months, days, and nanoseconds. /// Months and days are encoded as 32-bit signed integers. /// Nanoseconds is encoded as a 64-bit signed integer (no leap seconds). - IntervalMonthDayNano(Option), + IntervalMonthDayNano(Option), /// Duration in seconds DurationSecond(Option), /// Duration in milliseconds @@ -285,6 +299,12 @@ pub enum ScalarValue { Dictionary(Box, Box), } +impl Hash for Fl { + fn hash(&self, state: &mut H) { + self.0.to_bits().hash(state); + } +} + // manual implementation of `PartialEq` impl PartialEq for ScalarValue { fn eq(&self, other: &Self) -> bool { @@ -307,7 +327,12 @@ impl PartialEq for ScalarValue { (Some(f1), Some(f2)) => f1.to_bits() == f2.to_bits(), _ => v1.eq(v2), }, + (Float16(v1), Float16(v2)) => match (v1, v2) { + (Some(f1), Some(f2)) => f1.to_bits() == f2.to_bits(), + _ => v1.eq(v2), + }, (Float32(_), _) => false, + (Float16(_), _) => false, (Float64(v1), Float64(v2)) => match (v1, v2) { (Some(f1), Some(f2)) => f1.to_bits() == f2.to_bits(), _ => v1.eq(v2), @@ -331,10 +356,14 @@ impl PartialEq for ScalarValue { (UInt64(_), _) => false, (Utf8(v1), Utf8(v2)) => v1.eq(v2), (Utf8(_), _) => false, + (Utf8View(v1), Utf8View(v2)) => v1.eq(v2), + (Utf8View(_), _) => false, (LargeUtf8(v1), LargeUtf8(v2)) => v1.eq(v2), (LargeUtf8(_), _) => false, (Binary(v1), Binary(v2)) => v1.eq(v2), (Binary(_), _) => false, + (BinaryView(v1), BinaryView(v2)) => v1.eq(v2), + (BinaryView(_), _) => false, (FixedSizeBinary(_, v1), FixedSizeBinary(_, v2)) => v1.eq(v2), (FixedSizeBinary(_, _), _) => false, (LargeBinary(v1), LargeBinary(v2)) => v1.eq(v2), @@ -347,6 +376,8 @@ impl PartialEq for ScalarValue { (LargeList(_), _) => false, (Struct(v1), Struct(v2)) => v1.eq(v2), (Struct(_), _) => false, + (Map(v1), Map(v2)) => v1.eq(v2), + (Map(_), _) => false, (Date32(v1), Date32(v2)) => v1.eq(v2), (Date32(_), _) => false, (Date64(v1), Date64(v2)) => v1.eq(v2), @@ -425,7 +456,12 @@ impl PartialOrd for ScalarValue { (Some(f1), Some(f2)) => Some(f1.total_cmp(f2)), _ => v1.partial_cmp(v2), }, + (Float16(v1), Float16(v2)) => match (v1, v2) { + (Some(f1), Some(f2)) => Some(f1.total_cmp(f2)), + _ => v1.partial_cmp(v2), + }, (Float32(_), _) => None, + (Float16(_), _) => None, (Float64(v1), Float64(v2)) => match (v1, v2) { (Some(f1), Some(f2)) => Some(f1.total_cmp(f2)), _ => v1.partial_cmp(v2), @@ -451,8 +487,12 @@ impl PartialOrd for ScalarValue { (Utf8(_), _) => None, (LargeUtf8(v1), LargeUtf8(v2)) => v1.partial_cmp(v2), (LargeUtf8(_), _) => None, + (Utf8View(v1), Utf8View(v2)) => v1.partial_cmp(v2), + (Utf8View(_), _) => None, (Binary(v1), Binary(v2)) => v1.partial_cmp(v2), (Binary(_), _) => None, + (BinaryView(v1), BinaryView(v2)) => v1.partial_cmp(v2), + (BinaryView(_), _) => None, (FixedSizeBinary(_, v1), FixedSizeBinary(_, v2)) => v1.partial_cmp(v2), (FixedSizeBinary(_, _), _) => None, (LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2), @@ -470,6 +510,8 @@ impl PartialOrd for ScalarValue { partial_cmp_struct(struct_arr1, struct_arr2) } (Struct(_), _) => None, + (Map(map_arr1), Map(map_arr2)) => partial_cmp_map(map_arr1, map_arr2), + (Map(_), _) => None, (Date32(v1), Date32(v2)) => v1.partial_cmp(v2), (Date32(_), _) => None, (Date64(v1), Date64(v2)) => v1.partial_cmp(v2), @@ -599,6 +641,34 @@ fn partial_cmp_struct(s1: &Arc, s2: &Arc) -> Option, m2: &Arc) -> Option { + if m1.len() != m2.len() { + return None; + } + + if m1.data_type() != m2.data_type() { + return None; + } + + for col_index in 0..m1.len() { + let arr1 = m1.entries().column(col_index); + let arr2 = m2.entries().column(col_index); + + let lt_res = arrow::compute::kernels::cmp::lt(arr1, arr2).ok()?; + let eq_res = arrow::compute::kernels::cmp::eq(arr1, arr2).ok()?; + + for j in 0..lt_res.len() { + if lt_res.is_valid(j) && lt_res.value(j) { + return Some(Ordering::Less); + } + if eq_res.is_valid(j) && !eq_res.value(j) { + return Some(Ordering::Greater); + } + } + } + Some(Ordering::Equal) +} + impl Eq for ScalarValue {} //Float wrapper over f32/f64. Just because we cannot build std::hash::Hash for floats directly we have to do it through type wrapper @@ -622,8 +692,8 @@ hash_float_value!((f64, u64), (f32, u32)); // # Panics // // Panics if there is an error when creating hash values for rows -impl std::hash::Hash for ScalarValue { - fn hash(&self, state: &mut H) { +impl Hash for ScalarValue { + fn hash(&self, state: &mut H) { use ScalarValue::*; match self { Decimal128(v, p, s) => { @@ -637,6 +707,7 @@ impl std::hash::Hash for ScalarValue { s.hash(state) } Boolean(v) => v.hash(state), + Float16(v) => v.map(Fl).hash(state), Float32(v) => v.map(Fl).hash(state), Float64(v) => v.map(Fl).hash(state), Int8(v) => v.hash(state), @@ -647,11 +718,10 @@ impl std::hash::Hash for ScalarValue { UInt16(v) => v.hash(state), UInt32(v) => v.hash(state), UInt64(v) => v.hash(state), - Utf8(v) => v.hash(state), - LargeUtf8(v) => v.hash(state), - Binary(v) => v.hash(state), - FixedSizeBinary(_, v) => v.hash(state), - LargeBinary(v) => v.hash(state), + Utf8(v) | LargeUtf8(v) | Utf8View(v) => v.hash(state), + Binary(v) | FixedSizeBinary(_, v) | LargeBinary(v) | BinaryView(v) => { + v.hash(state) + } List(arr) => { hash_nested_array(arr.to_owned() as ArrayRef, state); } @@ -664,6 +734,9 @@ impl std::hash::Hash for ScalarValue { Struct(arr) => { hash_nested_array(arr.to_owned() as ArrayRef, state); } + Map(arr) => { + hash_nested_array(arr.to_owned() as ArrayRef, state); + } Date32(v) => v.hash(state), Date64(v) => v.hash(state), Time32Second(v) => v.hash(state), @@ -696,7 +769,7 @@ impl std::hash::Hash for ScalarValue { } } -fn hash_nested_array(arr: ArrayRef, state: &mut H) { +fn hash_nested_array(arr: ArrayRef, state: &mut H) { let arrays = vec![arr.to_owned()]; let hashes_buffer = &mut vec![0; arr.len()]; let random_state = ahash::RandomState::with_seeds(0, 0, 0, 0); @@ -730,9 +803,13 @@ fn dict_from_scalar( let values_array = value.to_array_of_size(1)?; // Create a key array with `size` elements, each of 0 - let key_array: PrimitiveArray = std::iter::repeat(Some(K::default_value())) - .take(size) - .collect(); + let key_array: PrimitiveArray = repeat(if value.is_null() { + None + } else { + Some(K::default_value()) + }) + .take(size) + .collect(); // create a new DictionaryArray // @@ -903,6 +980,11 @@ impl ScalarValue { ScalarValue::from(val.into()) } + /// Returns a [`ScalarValue::Utf8View`] representing `val` + pub fn new_utf8view(val: impl Into) -> Self { + ScalarValue::Utf8View(Some(val.into())) + } + /// Returns a [`ScalarValue::IntervalYearMonth`] representing /// `years` years and `months` months pub fn new_interval_ym(years: i32, months: i32) -> Self { @@ -938,6 +1020,123 @@ impl ScalarValue { } } + /// Returns a [`ScalarValue`] representing PI + pub fn new_pi(datatype: &DataType) -> Result { + match datatype { + DataType::Float32 => Ok(ScalarValue::from(std::f32::consts::PI)), + DataType::Float64 => Ok(ScalarValue::from(std::f64::consts::PI)), + _ => _internal_err!("PI is not supported for data type: {:?}", datatype), + } + } + + /// Returns a [`ScalarValue`] representing PI's upper bound + pub fn new_pi_upper(datatype: &DataType) -> Result { + // TODO: replace the constants with next_up/next_down when + // they are stabilized: https://doc.rust-lang.org/std/primitive.f64.html#method.next_up + match datatype { + DataType::Float32 => Ok(ScalarValue::from(consts::PI_UPPER_F32)), + DataType::Float64 => Ok(ScalarValue::from(consts::PI_UPPER_F64)), + _ => { + _internal_err!("PI_UPPER is not supported for data type: {:?}", datatype) + } + } + } + + /// Returns a [`ScalarValue`] representing -PI's lower bound + pub fn new_negative_pi_lower(datatype: &DataType) -> Result { + match datatype { + DataType::Float32 => Ok(ScalarValue::from(consts::NEGATIVE_PI_LOWER_F32)), + DataType::Float64 => Ok(ScalarValue::from(consts::NEGATIVE_PI_LOWER_F64)), + _ => { + _internal_err!("-PI_LOWER is not supported for data type: {:?}", datatype) + } + } + } + + /// Returns a [`ScalarValue`] representing FRAC_PI_2's upper bound + pub fn new_frac_pi_2_upper(datatype: &DataType) -> Result { + match datatype { + DataType::Float32 => Ok(ScalarValue::from(consts::FRAC_PI_2_UPPER_F32)), + DataType::Float64 => Ok(ScalarValue::from(consts::FRAC_PI_2_UPPER_F64)), + _ => { + _internal_err!( + "PI_UPPER/2 is not supported for data type: {:?}", + datatype + ) + } + } + } + + // Returns a [`ScalarValue`] representing FRAC_PI_2's lower bound + pub fn new_neg_frac_pi_2_lower(datatype: &DataType) -> Result { + match datatype { + DataType::Float32 => { + Ok(ScalarValue::from(consts::NEGATIVE_FRAC_PI_2_LOWER_F32)) + } + DataType::Float64 => { + Ok(ScalarValue::from(consts::NEGATIVE_FRAC_PI_2_LOWER_F64)) + } + _ => { + _internal_err!( + "-PI/2_LOWER is not supported for data type: {:?}", + datatype + ) + } + } + } + + /// Returns a [`ScalarValue`] representing -PI + pub fn new_negative_pi(datatype: &DataType) -> Result { + match datatype { + DataType::Float32 => Ok(ScalarValue::from(-std::f32::consts::PI)), + DataType::Float64 => Ok(ScalarValue::from(-std::f64::consts::PI)), + _ => _internal_err!("-PI is not supported for data type: {:?}", datatype), + } + } + + /// Returns a [`ScalarValue`] representing PI/2 + pub fn new_frac_pi_2(datatype: &DataType) -> Result { + match datatype { + DataType::Float32 => Ok(ScalarValue::from(std::f32::consts::FRAC_PI_2)), + DataType::Float64 => Ok(ScalarValue::from(std::f64::consts::FRAC_PI_2)), + _ => _internal_err!("PI/2 is not supported for data type: {:?}", datatype), + } + } + + /// Returns a [`ScalarValue`] representing -PI/2 + pub fn new_neg_frac_pi_2(datatype: &DataType) -> Result { + match datatype { + DataType::Float32 => Ok(ScalarValue::from(-std::f32::consts::FRAC_PI_2)), + DataType::Float64 => Ok(ScalarValue::from(-std::f64::consts::FRAC_PI_2)), + _ => _internal_err!("-PI/2 is not supported for data type: {:?}", datatype), + } + } + + /// Returns a [`ScalarValue`] representing infinity + pub fn new_infinity(datatype: &DataType) -> Result { + match datatype { + DataType::Float32 => Ok(ScalarValue::from(f32::INFINITY)), + DataType::Float64 => Ok(ScalarValue::from(f64::INFINITY)), + _ => { + _internal_err!("Infinity is not supported for data type: {:?}", datatype) + } + } + } + + /// Returns a [`ScalarValue`] representing negative infinity + pub fn new_neg_infinity(datatype: &DataType) -> Result { + match datatype { + DataType::Float32 => Ok(ScalarValue::from(f32::NEG_INFINITY)), + DataType::Float64 => Ok(ScalarValue::from(f64::NEG_INFINITY)), + _ => { + _internal_err!( + "Negative Infinity is not supported for data type: {:?}", + datatype + ) + } + } + } + /// Create a zero value in the given type. pub fn new_zero(datatype: &DataType) -> Result { Ok(match datatype { @@ -950,6 +1149,7 @@ impl ScalarValue { DataType::UInt16 => ScalarValue::UInt16(Some(0)), DataType::UInt32 => ScalarValue::UInt32(Some(0)), DataType::UInt64 => ScalarValue::UInt64(Some(0)), + DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(0.0))), DataType::Float32 => ScalarValue::Float32(Some(0.0)), DataType::Float64 => ScalarValue::Float64(Some(0.0)), DataType::Timestamp(TimeUnit::Second, tz) => { @@ -968,10 +1168,10 @@ impl ScalarValue { ScalarValue::IntervalYearMonth(Some(0)) } DataType::Interval(IntervalUnit::DayTime) => { - ScalarValue::IntervalDayTime(Some(0)) + ScalarValue::IntervalDayTime(Some(IntervalDayTime::ZERO)) } DataType::Interval(IntervalUnit::MonthDayNano) => { - ScalarValue::IntervalMonthDayNano(Some(0)) + ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano::ZERO)) } DataType::Duration(TimeUnit::Second) => ScalarValue::DurationSecond(Some(0)), DataType::Duration(TimeUnit::Millisecond) => { @@ -993,7 +1193,6 @@ impl ScalarValue { /// Create an one value in the given type. pub fn new_one(datatype: &DataType) -> Result { - assert!(datatype.is_primitive()); Ok(match datatype { DataType::Int8 => ScalarValue::Int8(Some(1)), DataType::Int16 => ScalarValue::Int16(Some(1)), @@ -1003,6 +1202,7 @@ impl ScalarValue { DataType::UInt16 => ScalarValue::UInt16(Some(1)), DataType::UInt32 => ScalarValue::UInt32(Some(1)), DataType::UInt64 => ScalarValue::UInt64(Some(1)), + DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(1.0))), DataType::Float32 => ScalarValue::Float32(Some(1.0)), DataType::Float64 => ScalarValue::Float64(Some(1.0)), _ => { @@ -1015,12 +1215,12 @@ impl ScalarValue { /// Create a negative one value in the given type. pub fn new_negative_one(datatype: &DataType) -> Result { - assert!(datatype.is_primitive()); Ok(match datatype { DataType::Int8 | DataType::UInt8 => ScalarValue::Int8(Some(-1)), DataType::Int16 | DataType::UInt16 => ScalarValue::Int16(Some(-1)), DataType::Int32 | DataType::UInt32 => ScalarValue::Int32(Some(-1)), DataType::Int64 | DataType::UInt64 => ScalarValue::Int64(Some(-1)), + DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(-1.0))), DataType::Float32 => ScalarValue::Float32(Some(-1.0)), DataType::Float64 => ScalarValue::Float64(Some(-1.0)), _ => { @@ -1032,7 +1232,6 @@ impl ScalarValue { } pub fn new_ten(datatype: &DataType) -> Result { - assert!(datatype.is_primitive()); Ok(match datatype { DataType::Int8 => ScalarValue::Int8(Some(10)), DataType::Int16 => ScalarValue::Int16(Some(10)), @@ -1042,11 +1241,12 @@ impl ScalarValue { DataType::UInt16 => ScalarValue::UInt16(Some(10)), DataType::UInt32 => ScalarValue::UInt32(Some(10)), DataType::UInt64 => ScalarValue::UInt64(Some(10)), + DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(10.0))), DataType::Float32 => ScalarValue::Float32(Some(10.0)), DataType::Float64 => ScalarValue::Float64(Some(10.0)), _ => { return _not_impl_err!( - "Can't create a negative one scalar from data_type \"{datatype:?}\"" + "Can't create a ten scalar from data_type \"{datatype:?}\"" ); } }) @@ -1082,17 +1282,21 @@ impl ScalarValue { ScalarValue::TimestampNanosecond(_, tz_opt) => { DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()) } + ScalarValue::Float16(_) => DataType::Float16, ScalarValue::Float32(_) => DataType::Float32, ScalarValue::Float64(_) => DataType::Float64, ScalarValue::Utf8(_) => DataType::Utf8, ScalarValue::LargeUtf8(_) => DataType::LargeUtf8, + ScalarValue::Utf8View(_) => DataType::Utf8View, ScalarValue::Binary(_) => DataType::Binary, + ScalarValue::BinaryView(_) => DataType::BinaryView, ScalarValue::FixedSizeBinary(sz, _) => DataType::FixedSizeBinary(*sz), ScalarValue::LargeBinary(_) => DataType::LargeBinary, ScalarValue::List(arr) => arr.data_type().to_owned(), ScalarValue::LargeList(arr) => arr.data_type().to_owned(), ScalarValue::FixedSizeList(arr) => arr.data_type().to_owned(), ScalarValue::Struct(arr) => arr.data_type().to_owned(), + ScalarValue::Map(arr) => arr.data_type().to_owned(), ScalarValue::Date32(_) => DataType::Date32, ScalarValue::Date64(_) => DataType::Date64, ScalarValue::Time32Second(_) => DataType::Time32(TimeUnit::Second), @@ -1124,59 +1328,113 @@ impl ScalarValue { } } - /// Getter for the `DataType` of the value. - /// - /// Suggest using [`Self::data_type`] as a more standard API - #[deprecated(since = "31.0.0", note = "use data_type instead")] - pub fn get_datatype(&self) -> DataType { - self.data_type() - } - /// Calculate arithmetic negation for a scalar value pub fn arithmetic_negate(&self) -> Result { + fn neg_checked_with_ctx( + v: T, + ctx: impl Fn() -> String, + ) -> Result { + v.neg_checked() + .map_err(|e| arrow_datafusion_err!(e).context(ctx())) + } match self { ScalarValue::Int8(None) | ScalarValue::Int16(None) | ScalarValue::Int32(None) | ScalarValue::Int64(None) + | ScalarValue::Float16(None) | ScalarValue::Float32(None) | ScalarValue::Float64(None) => Ok(self.clone()), + ScalarValue::Float16(Some(v)) => { + Ok(ScalarValue::Float16(Some(f16::from_f32(-v.to_f32())))) + } ScalarValue::Float64(Some(v)) => Ok(ScalarValue::Float64(Some(-v))), ScalarValue::Float32(Some(v)) => Ok(ScalarValue::Float32(Some(-v))), - ScalarValue::Int8(Some(v)) => Ok(ScalarValue::Int8(Some(-v))), - ScalarValue::Int16(Some(v)) => Ok(ScalarValue::Int16(Some(-v))), - ScalarValue::Int32(Some(v)) => Ok(ScalarValue::Int32(Some(-v))), - ScalarValue::Int64(Some(v)) => Ok(ScalarValue::Int64(Some(-v))), - ScalarValue::IntervalYearMonth(Some(v)) => { - Ok(ScalarValue::IntervalYearMonth(Some(-v))) - } + ScalarValue::Int8(Some(v)) => Ok(ScalarValue::Int8(Some(v.neg_checked()?))), + ScalarValue::Int16(Some(v)) => Ok(ScalarValue::Int16(Some(v.neg_checked()?))), + ScalarValue::Int32(Some(v)) => Ok(ScalarValue::Int32(Some(v.neg_checked()?))), + ScalarValue::Int64(Some(v)) => Ok(ScalarValue::Int64(Some(v.neg_checked()?))), + ScalarValue::IntervalYearMonth(Some(v)) => Ok( + ScalarValue::IntervalYearMonth(Some(neg_checked_with_ctx(*v, || { + format!("In negation of IntervalYearMonth({v})") + })?)), + ), ScalarValue::IntervalDayTime(Some(v)) => { let (days, ms) = IntervalDayTimeType::to_parts(*v); - let val = IntervalDayTimeType::make_value(-days, -ms); + let val = IntervalDayTimeType::make_value( + neg_checked_with_ctx(days, || { + format!("In negation of days {days} in IntervalDayTime") + })?, + neg_checked_with_ctx(ms, || { + format!("In negation of milliseconds {ms} in IntervalDayTime") + })?, + ); Ok(ScalarValue::IntervalDayTime(Some(val))) } ScalarValue::IntervalMonthDayNano(Some(v)) => { let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(*v); - let val = IntervalMonthDayNanoType::make_value(-months, -days, -nanos); + let val = IntervalMonthDayNanoType::make_value( + neg_checked_with_ctx(months, || { + format!("In negation of months {months} of IntervalMonthDayNano") + })?, + neg_checked_with_ctx(days, || { + format!("In negation of days {days} of IntervalMonthDayNano") + })?, + neg_checked_with_ctx(nanos, || { + format!("In negation of nanos {nanos} of IntervalMonthDayNano") + })?, + ); Ok(ScalarValue::IntervalMonthDayNano(Some(val))) } ScalarValue::Decimal128(Some(v), precision, scale) => { - Ok(ScalarValue::Decimal128(Some(-v), *precision, *scale)) + Ok(ScalarValue::Decimal128( + Some(neg_checked_with_ctx(*v, || { + format!("In negation of Decimal128({v}, {precision}, {scale})") + })?), + *precision, + *scale, + )) + } + ScalarValue::Decimal256(Some(v), precision, scale) => { + Ok(ScalarValue::Decimal256( + Some(neg_checked_with_ctx(*v, || { + format!("In negation of Decimal256({v}, {precision}, {scale})") + })?), + *precision, + *scale, + )) } - ScalarValue::Decimal256(Some(v), precision, scale) => Ok( - ScalarValue::Decimal256(Some(v.neg_wrapping()), *precision, *scale), - ), ScalarValue::TimestampSecond(Some(v), tz) => { - Ok(ScalarValue::TimestampSecond(Some(-v), tz.clone())) + Ok(ScalarValue::TimestampSecond( + Some(neg_checked_with_ctx(*v, || { + format!("In negation of TimestampSecond({v})") + })?), + tz.clone(), + )) } ScalarValue::TimestampNanosecond(Some(v), tz) => { - Ok(ScalarValue::TimestampNanosecond(Some(-v), tz.clone())) + Ok(ScalarValue::TimestampNanosecond( + Some(neg_checked_with_ctx(*v, || { + format!("In negation of TimestampNanoSecond({v})") + })?), + tz.clone(), + )) } ScalarValue::TimestampMicrosecond(Some(v), tz) => { - Ok(ScalarValue::TimestampMicrosecond(Some(-v), tz.clone())) + Ok(ScalarValue::TimestampMicrosecond( + Some(neg_checked_with_ctx(*v, || { + format!("In negation of TimestampMicroSecond({v})") + })?), + tz.clone(), + )) } ScalarValue::TimestampMillisecond(Some(v), tz) => { - Ok(ScalarValue::TimestampMillisecond(Some(-v), tz.clone())) + Ok(ScalarValue::TimestampMillisecond( + Some(neg_checked_with_ctx(*v, || { + format!("In negation of TimestampMilliSecond({v})") + })?), + tz.clone(), + )) } value => _internal_err!( "Can not run arithmetic negative on scalar value {value:?}" @@ -1276,6 +1534,7 @@ impl ScalarValue { match self { ScalarValue::Boolean(v) => v.is_none(), ScalarValue::Null => true, + ScalarValue::Float16(v) => v.is_none(), ScalarValue::Float32(v) => v.is_none(), ScalarValue::Float64(v) => v.is_none(), ScalarValue::Decimal128(v, _, _) => v.is_none(), @@ -1288,17 +1547,20 @@ impl ScalarValue { ScalarValue::UInt16(v) => v.is_none(), ScalarValue::UInt32(v) => v.is_none(), ScalarValue::UInt64(v) => v.is_none(), - ScalarValue::Utf8(v) => v.is_none(), - ScalarValue::LargeUtf8(v) => v.is_none(), - ScalarValue::Binary(v) => v.is_none(), - ScalarValue::FixedSizeBinary(_, v) => v.is_none(), - ScalarValue::LargeBinary(v) => v.is_none(), + ScalarValue::Utf8(v) + | ScalarValue::Utf8View(v) + | ScalarValue::LargeUtf8(v) => v.is_none(), + ScalarValue::Binary(v) + | ScalarValue::BinaryView(v) + | ScalarValue::FixedSizeBinary(_, v) + | ScalarValue::LargeBinary(v) => v.is_none(), // arr.len() should be 1 for a list scalar, but we don't seem to // enforce that anywhere, so we still check against array length. ScalarValue::List(arr) => arr.len() == arr.null_count(), ScalarValue::LargeList(arr) => arr.len() == arr.null_count(), ScalarValue::FixedSizeList(arr) => arr.len() == arr.null_count(), ScalarValue::Struct(arr) => arr.len() == arr.null_count(), + ScalarValue::Map(arr) => arr.len() == arr.null_count(), ScalarValue::Date32(v) => v.is_none(), ScalarValue::Date64(v) => v.is_none(), ScalarValue::Time32Second(v) => v.is_none(), @@ -1316,7 +1578,10 @@ impl ScalarValue { ScalarValue::DurationMillisecond(v) => v.is_none(), ScalarValue::DurationMicrosecond(v) => v.is_none(), ScalarValue::DurationNanosecond(v) => v.is_none(), - ScalarValue::Union(v, _, _) => v.is_none(), + ScalarValue::Union(v, _, _) => match v { + Some((_, s)) => s.is_null(), + None => true, + }, ScalarValue::Dictionary(_, v) => v.is_null(), } } @@ -1339,6 +1604,9 @@ impl ScalarValue { (Self::UInt32(Some(l)), Self::UInt32(Some(r))) => Some(l.abs_diff(*r) as _), (Self::UInt64(Some(l)), Self::UInt64(Some(r))) => Some(l.abs_diff(*r) as _), // TODO: we might want to look into supporting ceil/floor here for floats. + (Self::Float16(Some(l)), Self::Float16(Some(r))) => { + Some((f16::to_f32(*l) - f16::to_f32(*r)).abs().round() as _) + } (Self::Float32(Some(l)), Self::Float32(Some(r))) => { Some((l - r).abs().round() as _) } @@ -1438,9 +1706,7 @@ impl ScalarValue { // figure out the type based on the first element let data_type = match scalars.peek() { None => { - return _internal_err!( - "Empty iterator passed to ScalarValue::iter_to_array" - ); + return _exec_err!("Empty iterator passed to ScalarValue::iter_to_array"); } Some(sv) => sv.data_type(), }; @@ -1454,7 +1720,7 @@ impl ScalarValue { if let ScalarValue::$SCALAR_TY(v) = sv { Ok(v) } else { - _internal_err!( + _exec_err!( "Inconsistent types in ScalarValue::iter_to_array. \ Expected {:?}, got {:?}", data_type, sv @@ -1474,7 +1740,7 @@ impl ScalarValue { if let ScalarValue::$SCALAR_TY(v, _) = sv { Ok(v) } else { - _internal_err!( + _exec_err!( "Inconsistent types in ScalarValue::iter_to_array. \ Expected {:?}, got {:?}", data_type, sv @@ -1496,7 +1762,7 @@ impl ScalarValue { if let ScalarValue::$SCALAR_TY(v) = sv { Ok(v) } else { - _internal_err!( + _exec_err!( "Inconsistent types in ScalarValue::iter_to_array. \ Expected {:?}, got {:?}", data_type, sv @@ -1522,6 +1788,7 @@ impl ScalarValue { } DataType::Null => ScalarValue::iter_to_null_array(scalars)?, DataType::Boolean => build_array_primitive!(BooleanArray, Boolean), + DataType::Float16 => build_array_primitive!(Float16Array, Float16), DataType::Float32 => build_array_primitive!(Float32Array, Float32), DataType::Float64 => build_array_primitive!(Float64Array, Float64), DataType::Int8 => build_array_primitive!(Int8Array, Int8), @@ -1532,8 +1799,10 @@ impl ScalarValue { DataType::UInt16 => build_array_primitive!(UInt16Array, UInt16), DataType::UInt32 => build_array_primitive!(UInt32Array, UInt32), DataType::UInt64 => build_array_primitive!(UInt64Array, UInt64), + DataType::Utf8View => build_array_string!(StringViewArray, Utf8View), DataType::Utf8 => build_array_string!(StringArray, Utf8), DataType::LargeUtf8 => build_array_string!(LargeStringArray, LargeUtf8), + DataType::BinaryView => build_array_string!(BinaryViewArray, BinaryView), DataType::Binary => build_array_string!(BinaryArray, Binary), DataType::LargeBinary => build_array_string!(LargeBinaryArray, LargeBinary), DataType::Date32 => build_array_primitive!(Date32Array, Date32), @@ -1608,8 +1877,11 @@ impl ScalarValue { if let Some(DataType::FixedSizeList(f, l)) = first_non_null_data_type { for array in arrays.iter_mut() { if array.is_null(0) { - *array = - Arc::new(FixedSizeListArray::new_null(f.clone(), l, 1)); + *array = Arc::new(FixedSizeListArray::new_null( + Arc::clone(&f), + l, + 1, + )); } } } @@ -1618,6 +1890,7 @@ impl ScalarValue { } DataType::List(_) | DataType::LargeList(_) + | DataType::Map(_, _) | DataType::Struct(_) | DataType::Union(_, _) => { let arrays = scalars.map(|s| s.to_array()).collect::>>()?; @@ -1632,11 +1905,11 @@ impl ScalarValue { if &inner_key_type == key_type { Ok(*scalar) } else { - _internal_err!("Expected inner key type of {key_type} but found: {inner_key_type}, value was ({scalar:?})") + _exec_err!("Expected inner key type of {key_type} but found: {inner_key_type}, value was ({scalar:?})") } } _ => { - _internal_err!( + _exec_err!( "Expected scalar of type {value_type} but found: {scalar} {scalar:?}" ) } @@ -1664,7 +1937,7 @@ impl ScalarValue { if let ScalarValue::FixedSizeBinary(_, v) = sv { Ok(v) } else { - _internal_err!( + _exec_err!( "Inconsistent types in ScalarValue::iter_to_array. \ Expected {data_type:?}, got {sv:?}" ) @@ -1682,25 +1955,20 @@ impl ScalarValue { // not supported if the TimeUnit is not valid (Time32 can // only be used with Second and Millisecond, Time64 only // with Microsecond and Nanosecond) - DataType::Float16 - | DataType::Time32(TimeUnit::Microsecond) + DataType::Time32(TimeUnit::Microsecond) | DataType::Time32(TimeUnit::Nanosecond) | DataType::Time64(TimeUnit::Second) | DataType::Time64(TimeUnit::Millisecond) - | DataType::Map(_, _) | DataType::RunEndEncoded(_, _) - | DataType::Utf8View - | DataType::BinaryView | DataType::ListView(_) | DataType::LargeListView(_) => { - return _internal_err!( + return _not_impl_err!( "Unsupported creation of {:?} array from ScalarValue {:?}", data_type, scalars.peek() ); } }; - Ok(array) } @@ -1781,7 +2049,7 @@ impl ScalarValue { scale: i8, size: usize, ) -> Result { - Ok(std::iter::repeat(value) + Ok(repeat(value) .take(size) .collect::() .with_precision_and_scale(precision, scale)?) @@ -1803,7 +2071,7 @@ impl ScalarValue { /// ScalarValue::Int32(Some(2)) /// ]; /// - /// let result = ScalarValue::new_list(&scalars, &DataType::Int32); + /// let result = ScalarValue::new_list(&scalars, &DataType::Int32, true); /// /// let expected = ListArray::from_iter_primitive::( /// vec![ @@ -1812,13 +2080,35 @@ impl ScalarValue { /// /// assert_eq!(*result, expected); /// ``` - pub fn new_list(values: &[ScalarValue], data_type: &DataType) -> Arc { + pub fn new_list( + values: &[ScalarValue], + data_type: &DataType, + nullable: bool, + ) -> Arc { let values = if values.is_empty() { new_empty_array(data_type) } else { Self::iter_to_array(values.iter().cloned()).unwrap() }; - Arc::new(array_into_list_array(values)) + Arc::new(array_into_list_array(values, nullable)) + } + + /// Same as [`ScalarValue::new_list`] but with nullable set to true. + pub fn new_list_nullable( + values: &[ScalarValue], + data_type: &DataType, + ) -> Arc { + Self::new_list(values, data_type, true) + } + + /// Create ListArray with Null with specific data type + /// + /// - new_null_list(i32, nullable, 1): `ListArray[NULL]` + pub fn new_null_list(data_type: DataType, nullable: bool, null_len: usize) -> Self { + let data_type = DataType::List(Field::new_list_field(data_type, nullable).into()); + Self::List(Arc::new(ListArray::from(ArrayData::new_null( + &data_type, null_len, + )))) } /// Converts `IntoIterator` where each element has type corresponding to @@ -1837,7 +2127,7 @@ impl ScalarValue { /// ScalarValue::Int32(Some(2)) /// ]; /// - /// let result = ScalarValue::new_list_from_iter(scalars.into_iter(), &DataType::Int32); + /// let result = ScalarValue::new_list_from_iter(scalars.into_iter(), &DataType::Int32, true); /// /// let expected = ListArray::from_iter_primitive::( /// vec![ @@ -1849,13 +2139,14 @@ impl ScalarValue { pub fn new_list_from_iter( values: impl IntoIterator + ExactSizeIterator, data_type: &DataType, + nullable: bool, ) -> Arc { let values = if values.len() == 0 { new_empty_array(data_type) } else { Self::iter_to_array(values).unwrap() }; - Arc::new(array_into_list_array(values)) + Arc::new(array_into_list_array(values, nullable)) } /// Converts `Vec` where each element has type corresponding to @@ -1921,6 +2212,9 @@ impl ScalarValue { ScalarValue::Float32(e) => { build_array_from_option!(Float32, Float32Array, e, size) } + ScalarValue::Float16(e) => { + build_array_from_option!(Float16, Float16Array, e, size) + } ScalarValue::Int8(e) => build_array_from_option!(Int8, Int8Array, e, size), ScalarValue::Int16(e) => build_array_from_option!(Int16, Int16Array, e, size), ScalarValue::Int32(e) => build_array_from_option!(Int32, Int32Array, e, size), @@ -1978,6 +2272,12 @@ impl ScalarValue { } None => new_null_array(&DataType::Utf8, size), }, + ScalarValue::Utf8View(e) => match e { + Some(value) => { + Arc::new(StringViewArray::from_iter_values(repeat(value).take(size))) + } + None => new_null_array(&DataType::Utf8View, size), + }, ScalarValue::LargeUtf8(e) => match e { Some(value) => { Arc::new(LargeStringArray::from_iter_values(repeat(value).take(size))) @@ -1994,6 +2294,16 @@ impl ScalarValue { Arc::new(repeat(None::<&str>).take(size).collect::()) } }, + ScalarValue::BinaryView(e) => match e { + Some(value) => Arc::new( + repeat(Some(value.as_slice())) + .take(size) + .collect::(), + ), + None => { + Arc::new(repeat(None::<&str>).take(size).collect::()) + } + }, ScalarValue::FixedSizeBinary(s, e) => match e { Some(value) => Arc::new( FixedSizeBinaryArray::try_from_sparse_iter_with_size( @@ -2034,6 +2344,9 @@ impl ScalarValue { ScalarValue::Struct(arr) => { Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? } + ScalarValue::Map(arr) => { + Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? + } ScalarValue::Date32(e) => { build_array_from_option!(Date32, Date32Array, e, size) } @@ -2127,9 +2440,8 @@ impl ScalarValue { ), ScalarValue::Union(value, fields, _mode) => match value { Some((v_id, value)) => { - let mut field_type_ids = Vec::::with_capacity(fields.len()); - let mut child_arrays = - Vec::<(Field, ArrayRef)>::with_capacity(fields.len()); + let mut new_fields = Vec::with_capacity(fields.len()); + let mut child_arrays = Vec::::with_capacity(fields.len()); for (f_id, field) in fields.iter() { let ar = if f_id == *v_id { value.to_array_of_size(size)? @@ -2138,14 +2450,14 @@ impl ScalarValue { new_null_array(dt, size) }; let field = (**field).clone(); - child_arrays.push((field, ar)); - field_type_ids.push(f_id); + child_arrays.push(ar); + new_fields.push(field.clone()); } - let type_ids = repeat(*v_id).take(size).collect::>(); - let type_ids = Buffer::from_slice_ref(type_ids); - let value_offsets: Option = None; + let type_ids = repeat(*v_id).take(size); + let type_ids = ScalarBuffer::::from_iter(type_ids); + let value_offsets: Option> = None; let ar = UnionArray::try_new( - field_type_ids.as_slice(), + fields.clone(), type_ids, value_offsets, child_arrays, @@ -2206,8 +2518,12 @@ impl ScalarValue { } fn list_to_array_of_size(arr: &dyn Array, size: usize) -> Result { - let arrays = std::iter::repeat(arr).take(size).collect::>(); - Ok(arrow::compute::concat(arrays.as_slice())?) + let arrays = repeat(arr).take(size).collect::>(); + let ret = match !arrays.is_empty() { + true => arrow::compute::concat(arrays.as_slice())?, + false => arr.slice(0, 0), + }; + Ok(ret) } /// Retrieve ScalarValue for each row in `array` @@ -2247,7 +2563,7 @@ impl ScalarValue { /// use datafusion_common::ScalarValue; /// use arrow::array::ListArray; /// use arrow::datatypes::{DataType, Int32Type}; - /// use datafusion_common::utils::array_into_list_array; + /// use datafusion_common::utils::array_into_list_array_nullable; /// use std::sync::Arc; /// /// let list_arr = ListArray::from_iter_primitive::(vec![ @@ -2256,7 +2572,7 @@ impl ScalarValue { /// ]); /// /// // Wrap into another layer of list, we got nested array as [ [[1,2,3], [4,5]] ] - /// let list_arr = array_into_list_array(Arc::new(list_arr)); + /// let list_arr = array_into_list_array_nullable(Arc::new(list_arr)); /// /// // Convert the array into Scalar Values for each row, we got 1D arrays in this example /// let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&list_arr).unwrap(); @@ -2322,6 +2638,7 @@ impl ScalarValue { DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean)?, DataType::Float64 => typed_cast!(array, index, Float64Array, Float64)?, DataType::Float32 => typed_cast!(array, index, Float32Array, Float32)?, + DataType::Float16 => typed_cast!(array, index, Float16Array, Float16)?, DataType::UInt64 => typed_cast!(array, index, UInt64Array, UInt64)?, DataType::UInt32 => typed_cast!(array, index, UInt32Array, UInt32)?, DataType::UInt16 => typed_cast!(array, index, UInt16Array, UInt16)?, @@ -2334,15 +2651,20 @@ impl ScalarValue { DataType::LargeBinary => { typed_cast!(array, index, LargeBinaryArray, LargeBinary)? } + DataType::BinaryView => { + typed_cast!(array, index, BinaryViewArray, BinaryView)? + } DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8)?, DataType::LargeUtf8 => { typed_cast!(array, index, LargeStringArray, LargeUtf8)? } - DataType::List(_) => { + DataType::Utf8View => typed_cast!(array, index, StringViewArray, Utf8View)?, + DataType::List(field) => { let list_array = array.as_list::(); let nested_array = list_array.value(index); // Produces a single element `ListArray` with the value at `index`. - let arr = Arc::new(array_into_list_array(nested_array)); + let arr = + Arc::new(array_into_list_array(nested_array, field.is_nullable())); ScalarValue::List(arr) } @@ -2473,7 +2795,17 @@ impl ScalarValue { DataType::Duration(TimeUnit::Nanosecond) => { typed_cast!(array, index, DurationNanosecondArray, DurationNanosecond)? } - + DataType::Map(_, _) => { + let a = array.slice(index, 1); + Self::Map(Arc::new(a.as_map().to_owned())) + } + DataType::Union(fields, mode) => { + let array = as_union_array(array); + let ti = array.type_id(index); + let index = array.value_offset(index); + let value = ScalarValue::try_from_array(array.child(ti), index)?; + ScalarValue::Union(Some((ti, Box::new(value))), fields.clone(), *mode) + } other => { return _not_impl_err!( "Can't create a scalar from array of type \"{other:?}\"" @@ -2484,22 +2816,30 @@ impl ScalarValue { /// Try to parse `value` into a ScalarValue of type `target_type` pub fn try_from_string(value: String, target_type: &DataType) -> Result { - let value = ScalarValue::from(value); - let cast_options = CastOptions { - safe: false, - format_options: Default::default(), - }; - let cast_arr = cast_with_options(&value.to_array()?, target_type, &cast_options)?; - ScalarValue::try_from_array(&cast_arr, 0) + ScalarValue::from(value).cast_to(target_type) } /// Try to cast this value to a ScalarValue of type `data_type` - pub fn cast_to(&self, data_type: &DataType) -> Result { - let cast_options = CastOptions { - safe: false, - format_options: Default::default(), + pub fn cast_to(&self, target_type: &DataType) -> Result { + self.cast_to_with_options(target_type, &DEFAULT_CAST_OPTIONS) + } + + /// Try to cast this value to a ScalarValue of type `data_type` with [`CastOptions`] + pub fn cast_to_with_options( + &self, + target_type: &DataType, + cast_options: &CastOptions<'static>, + ) -> Result { + let scalar_array = match (self, target_type) { + ( + ScalarValue::Float64(Some(float_ts)), + DataType::Timestamp(TimeUnit::Nanosecond, None), + ) => ScalarValue::Int64(Some((float_ts * 1_000_000_000_f64).trunc() as i64)) + .to_array()?, + _ => self.to_array()?, }; - let cast_arr = cast_with_options(&self.to_array()?, data_type, &cast_options)?; + + let cast_arr = cast_with_options(&scalar_array, target_type, cast_options)?; ScalarValue::try_from_array(&cast_arr, 0) } @@ -2591,6 +2931,9 @@ impl ScalarValue { ScalarValue::Boolean(val) => { eq_array_primitive!(array, index, BooleanArray, val)? } + ScalarValue::Float16(val) => { + eq_array_primitive!(array, index, Float16Array, val)? + } ScalarValue::Float32(val) => { eq_array_primitive!(array, index, Float32Array, val)? } @@ -2622,12 +2965,18 @@ impl ScalarValue { ScalarValue::Utf8(val) => { eq_array_primitive!(array, index, StringArray, val)? } + ScalarValue::Utf8View(val) => { + eq_array_primitive!(array, index, StringViewArray, val)? + } ScalarValue::LargeUtf8(val) => { eq_array_primitive!(array, index, LargeStringArray, val)? } ScalarValue::Binary(val) => { eq_array_primitive!(array, index, BinaryArray, val)? } + ScalarValue::BinaryView(val) => { + eq_array_primitive!(array, index, BinaryViewArray, val)? + } ScalarValue::FixedSizeBinary(_, val) => { eq_array_primitive!(array, index, FixedSizeBinaryArray, val)? } @@ -2646,6 +2995,9 @@ impl ScalarValue { ScalarValue::Struct(arr) => { Self::eq_array_list(&(arr.to_owned() as ArrayRef), array, index) } + ScalarValue::Map(arr) => { + Self::eq_array_list(&(arr.to_owned() as ArrayRef), array, index) + } ScalarValue::Date32(val) => { eq_array_primitive!(array, index, Date32Array, val)? } @@ -2697,8 +3049,15 @@ impl ScalarValue { ScalarValue::DurationNanosecond(val) => { eq_array_primitive!(array, index, DurationNanosecondArray, val)? } - ScalarValue::Union(_, _, _) => { - return _not_impl_err!("Union is not supported yet") + ScalarValue::Union(value, _, _) => { + let array = as_union_array(array); + let ti = array.type_id(index); + let index = array.value_offset(index); + if let Some((ti_v, value)) = value { + ti_v == &ti && value.eq_array(array.child(ti), index)? + } else { + array.child(ti).is_null(index) + } } ScalarValue::Dictionary(key_type, v) => { let (values_array, values_index) = match key_type.as_ref() { @@ -2730,10 +3089,11 @@ impl ScalarValue { /// Estimate size if bytes including `Self`. For values with internal containers such as `String` /// includes the allocated size (`capacity`) rather than the current length (`len`) pub fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) + match self { ScalarValue::Null | ScalarValue::Boolean(_) + | ScalarValue::Float16(_) | ScalarValue::Float32(_) | ScalarValue::Float64(_) | ScalarValue::Decimal128(_, _, _) @@ -2759,7 +3119,9 @@ impl ScalarValue { | ScalarValue::DurationMillisecond(_) | ScalarValue::DurationMicrosecond(_) | ScalarValue::DurationNanosecond(_) => 0, - ScalarValue::Utf8(s) | ScalarValue::LargeUtf8(s) => { + ScalarValue::Utf8(s) + | ScalarValue::LargeUtf8(s) + | ScalarValue::Utf8View(s) => { s.as_ref().map(|s| s.capacity()).unwrap_or_default() } ScalarValue::TimestampSecond(_, s) @@ -2770,21 +3132,23 @@ impl ScalarValue { } ScalarValue::Binary(b) | ScalarValue::FixedSizeBinary(_, b) - | ScalarValue::LargeBinary(b) => { + | ScalarValue::LargeBinary(b) + | ScalarValue::BinaryView(b) => { b.as_ref().map(|b| b.capacity()).unwrap_or_default() } ScalarValue::List(arr) => arr.get_array_memory_size(), ScalarValue::LargeList(arr) => arr.get_array_memory_size(), ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(), ScalarValue::Struct(arr) => arr.get_array_memory_size(), + ScalarValue::Map(arr) => arr.get_array_memory_size(), ScalarValue::Union(vals, fields, _mode) => { vals.as_ref() - .map(|(_id, sv)| sv.size() - std::mem::size_of_val(sv)) + .map(|(_id, sv)| sv.size() - size_of_val(sv)) .unwrap_or_default() // `fields` is boxed, so it is NOT already included in `self` - + std::mem::size_of_val(fields) - + (std::mem::size_of::() * fields.len()) - + fields.iter().map(|(_idx, field)| field.size() - std::mem::size_of_val(field)).sum::() + + size_of_val(fields) + + (size_of::() * fields.len()) + + fields.iter().map(|(_idx, field)| field.size() - size_of_val(field)).sum::() } ScalarValue::Dictionary(dt, sv) => { // `dt` and `sv` are boxed, so they are NOT already included in `self` @@ -2797,11 +3161,11 @@ impl ScalarValue { /// /// Includes the size of the [`Vec`] container itself. pub fn size_of_vec(vec: &Vec) -> usize { - std::mem::size_of_val(vec) - + (std::mem::size_of::() * vec.capacity()) + size_of_val(vec) + + (size_of::() * vec.capacity()) + vec .iter() - .map(|sv| sv.size() - std::mem::size_of_val(sv)) + .map(|sv| sv.size() - size_of_val(sv)) .sum::() } @@ -2809,11 +3173,11 @@ impl ScalarValue { /// /// Includes the size of the [`VecDeque`] container itself. pub fn size_of_vec_deque(vec_deque: &VecDeque) -> usize { - std::mem::size_of_val(vec_deque) - + (std::mem::size_of::() * vec_deque.capacity()) + size_of_val(vec_deque) + + (size_of::() * vec_deque.capacity()) + vec_deque .iter() - .map(|sv| sv.size() - std::mem::size_of_val(sv)) + .map(|sv| sv.size() - size_of_val(sv)) .sum::() } @@ -2821,11 +3185,11 @@ impl ScalarValue { /// /// Includes the size of the [`HashSet`] container itself. pub fn size_of_hashset(set: &HashSet) -> usize { - std::mem::size_of_val(set) - + (std::mem::size_of::() * set.capacity()) + size_of_val(set) + + (size_of::() * set.capacity()) + set .iter() - .map(|sv| sv.size() - std::mem::size_of_val(sv)) + .map(|sv| sv.size() - size_of_val(sv)) .sum::() } } @@ -3018,6 +3382,7 @@ impl TryFrom<&DataType> for ScalarValue { fn try_from(data_type: &DataType) -> Result { Ok(match data_type { DataType::Boolean => ScalarValue::Boolean(None), + DataType::Float16 => ScalarValue::Float16(None), DataType::Float64 => ScalarValue::Float64(None), DataType::Float32 => ScalarValue::Float32(None), DataType::Int8 => ScalarValue::Int8(None), @@ -3036,7 +3401,9 @@ impl TryFrom<&DataType> for ScalarValue { } DataType::Utf8 => ScalarValue::Utf8(None), DataType::LargeUtf8 => ScalarValue::LargeUtf8(None), + DataType::Utf8View => ScalarValue::Utf8View(None), DataType::Binary => ScalarValue::Binary(None), + DataType::BinaryView => ScalarValue::BinaryView(None), DataType::FixedSizeBinary(len) => ScalarValue::FixedSizeBinary(*len, None), DataType::LargeBinary => ScalarValue::LargeBinary(None), DataType::Date32 => ScalarValue::Date32(None), @@ -3085,52 +3452,33 @@ impl TryFrom<&DataType> for ScalarValue { Box::new(value_type.as_ref().try_into()?), ), // `ScalaValue::List` contains single element `ListArray`. - DataType::List(field) => ScalarValue::List( - new_null_array( - &DataType::List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - ))), - 1, - ) - .as_list::() - .to_owned() - .into(), - ), - // 'ScalarValue::LargeList' contains single element `LargeListArray - DataType::LargeList(field) => ScalarValue::LargeList( - new_null_array( - &DataType::LargeList(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - ))), - 1, - ) - .as_list::() - .to_owned() - .into(), - ), + DataType::List(field_ref) => ScalarValue::List(Arc::new( + GenericListArray::new_null(Arc::clone(field_ref), 1), + )), + // `ScalarValue::LargeList` contains single element `LargeListArray`. + DataType::LargeList(field_ref) => ScalarValue::LargeList(Arc::new( + GenericListArray::new_null(Arc::clone(field_ref), 1), + )), // `ScalaValue::FixedSizeList` contains single element `FixedSizeList`. - DataType::FixedSizeList(field, _) => ScalarValue::FixedSizeList( - new_null_array( - &DataType::FixedSizeList( - Arc::new(Field::new("item", field.data_type().clone(), true)), - 1, - ), + DataType::FixedSizeList(field_ref, fixed_length) => { + ScalarValue::FixedSizeList(Arc::new(FixedSizeListArray::new_null( + Arc::clone(field_ref), + *fixed_length, 1, - ) - .as_fixed_size_list() - .to_owned() - .into(), - ), + ))) + } DataType::Struct(fields) => ScalarValue::Struct( new_null_array(&DataType::Struct(fields.to_owned()), 1) .as_struct() .to_owned() .into(), ), + DataType::Map(fields, sorted) => ScalarValue::Map( + new_null_array(&DataType::Map(fields.to_owned(), sorted.to_owned()), 1) + .as_map() + .to_owned() + .into(), + ), DataType::Union(fields, mode) => { ScalarValue::Union(None, fields.clone(), *mode) } @@ -3168,6 +3516,7 @@ impl fmt::Display for ScalarValue { write!(f, "{v:?},{p:?},{s:?}")?; } ScalarValue::Boolean(e) => format_option!(f, e)?, + ScalarValue::Float16(e) => format_option!(f, e)?, ScalarValue::Float32(e) => format_option!(f, e)?, ScalarValue::Float64(e) => format_option!(f, e)?, ScalarValue::Int8(e) => format_option!(f, e)?, @@ -3182,33 +3531,44 @@ impl fmt::Display for ScalarValue { ScalarValue::TimestampMillisecond(e, _) => format_option!(f, e)?, ScalarValue::TimestampMicrosecond(e, _) => format_option!(f, e)?, ScalarValue::TimestampNanosecond(e, _) => format_option!(f, e)?, - ScalarValue::Utf8(e) => format_option!(f, e)?, - ScalarValue::LargeUtf8(e) => format_option!(f, e)?, + ScalarValue::Utf8(e) + | ScalarValue::LargeUtf8(e) + | ScalarValue::Utf8View(e) => format_option!(f, e)?, ScalarValue::Binary(e) | ScalarValue::FixedSizeBinary(_, e) - | ScalarValue::LargeBinary(e) => match e { - Some(l) => write!( - f, - "{}", - l.iter() - .map(|v| format!("{v}")) - .collect::>() - .join(",") - )?, + | ScalarValue::LargeBinary(e) + | ScalarValue::BinaryView(e) => match e { + Some(bytes) => { + // print up to first 10 bytes, with trailing ... if needed + for b in bytes.iter().take(10) { + write!(f, "{b:02X}")?; + } + if bytes.len() > 10 { + write!(f, "...")?; + } + } None => write!(f, "NULL")?, }, ScalarValue::List(arr) => fmt_list(arr.to_owned() as ArrayRef, f)?, ScalarValue::LargeList(arr) => fmt_list(arr.to_owned() as ArrayRef, f)?, ScalarValue::FixedSizeList(arr) => fmt_list(arr.to_owned() as ArrayRef, f)?, - ScalarValue::Date32(e) => format_option!(f, e)?, - ScalarValue::Date64(e) => format_option!(f, e)?, + ScalarValue::Date32(e) => { + format_option!(f, e.map(|v| Date32Type::to_naive_date(v).to_string()))? + } + ScalarValue::Date64(e) => { + format_option!(f, e.map(|v| Date64Type::to_naive_date(v).to_string()))? + } ScalarValue::Time32Second(e) => format_option!(f, e)?, ScalarValue::Time32Millisecond(e) => format_option!(f, e)?, ScalarValue::Time64Microsecond(e) => format_option!(f, e)?, ScalarValue::Time64Nanosecond(e) => format_option!(f, e)?, - ScalarValue::IntervalDayTime(e) => format_option!(f, e)?, ScalarValue::IntervalYearMonth(e) => format_option!(f, e)?, - ScalarValue::IntervalMonthDayNano(e) => format_option!(f, e)?, + ScalarValue::IntervalMonthDayNano(e) => { + format_option!(f, e.map(|v| format!("{v:?}")))? + } + ScalarValue::IntervalDayTime(e) => { + format_option!(f, e.map(|v| format!("{v:?}")))?; + } ScalarValue::DurationSecond(e) => format_option!(f, e)?, ScalarValue::DurationMillisecond(e) => format_option!(f, e)?, ScalarValue::DurationMicrosecond(e) => format_option!(f, e)?, @@ -3232,9 +3592,8 @@ impl fmt::Display for ScalarValue { columns .iter() .zip(fields.iter()) - .enumerate() - .map(|(index, (column, field))| { - if nulls.is_some_and(|b| b.is_null(index)) { + .map(|(column, field)| { + if nulls.is_some_and(|b| b.is_null(0)) { format!("{}:NULL", field.name()) } else if let DataType::Struct(_) = field.data_type() { let sv = ScalarValue::Struct(Arc::new( @@ -3250,6 +3609,43 @@ impl fmt::Display for ScalarValue { .join(",") )? } + ScalarValue::Map(map_arr) => { + if map_arr.null_count() == map_arr.len() { + write!(f, "NULL")?; + return Ok(()); + } + + write!( + f, + "[{}]", + map_arr + .iter() + .map(|struct_array| { + if let Some(arr) = struct_array { + let mut buffer = VecDeque::new(); + for i in 0..arr.len() { + let key = + array_value_to_string(arr.column(0), i).unwrap(); + let value = + array_value_to_string(arr.column(1), i).unwrap(); + buffer.push_back(format!("{}:{}", key, value)); + } + format!( + "{{{}}}", + buffer + .into_iter() + .collect::>() + .join(",") + .as_str() + ) + } else { + "NULL".to_string() + } + }) + .collect::>() + .join(",") + )? + } ScalarValue::Union(val, _fields, _mode) => match val { Some((id, val)) => write!(f, "{}:{}", id, val)?, None => write!(f, "NULL")?, @@ -3271,12 +3667,25 @@ fn fmt_list(arr: ArrayRef, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "{value_formatter}") } +/// writes a byte array to formatter. `[1, 2, 3]` ==> `"1,2,3"` +fn fmt_binary(data: &[u8], f: &mut fmt::Formatter) -> fmt::Result { + let mut iter = data.iter(); + if let Some(b) = iter.next() { + write!(f, "{b}")?; + } + for b in iter { + write!(f, ",{b}")?; + } + Ok(()) +} + impl fmt::Debug for ScalarValue { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { ScalarValue::Decimal128(_, _, _) => write!(f, "Decimal128({self})"), ScalarValue::Decimal256(_, _, _) => write!(f, "Decimal256({self})"), ScalarValue::Boolean(_) => write!(f, "Boolean({self})"), + ScalarValue::Float16(_) => write!(f, "Float16({self})"), ScalarValue::Float32(_) => write!(f, "Float32({self})"), ScalarValue::Float64(_) => write!(f, "Float64({self})"), ScalarValue::Int8(_) => write!(f, "Int8({self})"), @@ -3301,18 +3710,36 @@ impl fmt::Debug for ScalarValue { } ScalarValue::Utf8(None) => write!(f, "Utf8({self})"), ScalarValue::Utf8(Some(_)) => write!(f, "Utf8(\"{self}\")"), + ScalarValue::Utf8View(None) => write!(f, "Utf8View({self})"), + ScalarValue::Utf8View(Some(_)) => write!(f, "Utf8View(\"{self}\")"), ScalarValue::LargeUtf8(None) => write!(f, "LargeUtf8({self})"), ScalarValue::LargeUtf8(Some(_)) => write!(f, "LargeUtf8(\"{self}\")"), ScalarValue::Binary(None) => write!(f, "Binary({self})"), - ScalarValue::Binary(Some(_)) => write!(f, "Binary(\"{self}\")"), + ScalarValue::Binary(Some(b)) => { + write!(f, "Binary(\"")?; + fmt_binary(b.as_slice(), f)?; + write!(f, "\")") + } + ScalarValue::BinaryView(None) => write!(f, "BinaryView({self})"), + ScalarValue::BinaryView(Some(b)) => { + write!(f, "BinaryView(\"")?; + fmt_binary(b.as_slice(), f)?; + write!(f, "\")") + } ScalarValue::FixedSizeBinary(size, None) => { write!(f, "FixedSizeBinary({size}, {self})") } - ScalarValue::FixedSizeBinary(size, Some(_)) => { - write!(f, "FixedSizeBinary({size}, \"{self}\")") + ScalarValue::FixedSizeBinary(size, Some(b)) => { + write!(f, "FixedSizeBinary({size}, \"")?; + fmt_binary(b.as_slice(), f)?; + write!(f, "\")") } ScalarValue::LargeBinary(None) => write!(f, "LargeBinary({self})"), - ScalarValue::LargeBinary(Some(_)) => write!(f, "LargeBinary(\"{self}\")"), + ScalarValue::LargeBinary(Some(b)) => { + write!(f, "LargeBinary(\"")?; + fmt_binary(b.as_slice(), f)?; + write!(f, "\")") + } ScalarValue::FixedSizeList(_) => write!(f, "FixedSizeList({self})"), ScalarValue::List(_) => write!(f, "List({self})"), ScalarValue::LargeList(_) => write!(f, "LargeList({self})"), @@ -3338,6 +3765,33 @@ impl fmt::Debug for ScalarValue { .join(",") ) } + ScalarValue::Map(map_arr) => { + write!( + f, + "Map([{}])", + map_arr + .iter() + .map(|struct_array| { + if let Some(arr) = struct_array { + let buffer: Vec = (0..arr.len()) + .map(|i| { + let key = array_value_to_string(arr.column(0), i) + .unwrap(); + let value = + array_value_to_string(arr.column(1), i) + .unwrap(); + format!("{key:?}:{value:?}") + }) + .collect(); + format!("{{{}}}", buffer.join(",")) + } else { + "NULL".to_string() + } + }) + .collect::>() + .join(",") + ) + } ScalarValue::Date32(_) => write!(f, "Date32(\"{self}\")"), ScalarValue::Date64(_) => write!(f, "Date64(\"{self}\")"), ScalarValue::Time32Second(_) => write!(f, "Time32Second(\"{self}\")"), @@ -3426,17 +3880,45 @@ mod tests { use super::*; use crate::cast::{ - as_string_array, as_struct_array, as_uint32_array, as_uint64_array, + as_map_array, as_string_array, as_struct_array, as_uint32_array, as_uint64_array, }; use crate::assert_batches_eq; + use crate::utils::array_into_list_array_nullable; use arrow::buffer::OffsetBuffer; use arrow::compute::{is_null, kernels}; + use arrow::error::ArrowError; use arrow::util::pretty::pretty_format_columns; + use arrow_buffer::{Buffer, NullBuffer}; use arrow_schema::Fields; use chrono::NaiveDate; use rand::Rng; + #[test] + fn test_scalar_value_from_for_map() { + let string_builder = StringBuilder::new(); + let int_builder = Int32Builder::with_capacity(4); + let mut builder = MapBuilder::new(None, string_builder, int_builder); + builder.keys().append_value("joe"); + builder.values().append_value(1); + builder.append(true).unwrap(); + + builder.keys().append_value("blogs"); + builder.values().append_value(2); + builder.keys().append_value("foo"); + builder.values().append_value(4); + builder.append(true).unwrap(); + builder.append(true).unwrap(); + builder.append(false).unwrap(); + + let expected = builder.finish(); + + let sv = ScalarValue::Map(Arc::new(expected.clone())); + let map_arr = sv.to_array().unwrap(); + let actual = as_map_array(&map_arr).unwrap(); + assert_eq!(actual, &expected); + } + #[test] fn test_scalar_value_from_for_struct() { let boolean = Arc::new(BooleanArray::from(vec![false])); @@ -3445,11 +3927,11 @@ mod tests { let expected = StructArray::from(vec![ ( Arc::new(Field::new("b", DataType::Boolean, false)), - boolean.clone() as ArrayRef, + Arc::clone(&boolean) as ArrayRef, ), ( Arc::new(Field::new("c", DataType::Int32, false)), - int.clone() as ArrayRef, + Arc::clone(&int) as ArrayRef, ), ]); @@ -3491,11 +3973,11 @@ mod tests { let struct_array = StructArray::from(vec![ ( Arc::new(Field::new("b", DataType::Boolean, false)), - boolean.clone() as ArrayRef, + Arc::clone(&boolean) as ArrayRef, ), ( Arc::new(Field::new("c", DataType::Int32, false)), - int.clone() as ArrayRef, + Arc::clone(&int) as ArrayRef, ), ]); let sv = ScalarValue::Struct(Arc::new(struct_array)); @@ -3509,11 +3991,11 @@ mod tests { let struct_array = StructArray::from(vec![ ( Arc::new(Field::new("b", DataType::Boolean, false)), - boolean.clone() as ArrayRef, + Arc::clone(&boolean) as ArrayRef, ), ( Arc::new(Field::new("c", DataType::Int32, false)), - int.clone() as ArrayRef, + Arc::clone(&int) as ArrayRef, ), ]); @@ -3545,7 +4027,7 @@ mod tests { fn test_to_array_of_size_for_fsl() { let values = Int32Array::from_iter([Some(1), None, Some(2)]); let field = Arc::new(Field::new("item", DataType::Int32, true)); - let arr = FixedSizeListArray::new(field.clone(), 3, Arc::new(values), None); + let arr = FixedSizeListArray::new(Arc::clone(&field), 3, Arc::new(values), None); let sv = ScalarValue::FixedSizeList(Arc::new(arr)); let actual_arr = sv .to_array_of_size(2) @@ -3560,6 +4042,12 @@ mod tests { &expected_arr, as_fixed_size_list_array(actual_arr.as_ref()).unwrap() ); + + let empty_array = sv + .to_array_of_size(0) + .expect("Failed to convert to empty array"); + + assert_eq!(empty_array.len(), 0); } #[test] @@ -3570,9 +4058,9 @@ mod tests { ScalarValue::from("data-fusion"), ]; - let result = ScalarValue::new_list(scalars.as_slice(), &DataType::Utf8); + let result = ScalarValue::new_list_nullable(scalars.as_slice(), &DataType::Utf8); - let expected = array_into_list_array(Arc::new(StringArray::from(vec![ + let expected = array_into_list_array_nullable(Arc::new(StringArray::from(vec![ "rust", "arrow", "data-fusion", @@ -3625,13 +4113,13 @@ mod tests { fn test_iter_to_array_fixed_size_list() { let field = Arc::new(Field::new("item", DataType::Int32, true)); let f1 = Arc::new(FixedSizeListArray::new( - field.clone(), + Arc::clone(&field), 3, Arc::new(Int32Array::from(vec![1, 2, 3])), None, )); let f2 = Arc::new(FixedSizeListArray::new( - field.clone(), + Arc::clone(&field), 3, Arc::new(Int32Array::from(vec![4, 5, 6])), None, @@ -3639,7 +4127,7 @@ mod tests { let f_nulls = Arc::new(FixedSizeListArray::new_null(field, 1, 1)); let scalars = vec![ - ScalarValue::FixedSizeList(f_nulls.clone()), + ScalarValue::FixedSizeList(Arc::clone(&f_nulls)), ScalarValue::FixedSizeList(f1), ScalarValue::FixedSizeList(f2), ScalarValue::FixedSizeList(f_nulls), @@ -3784,10 +4272,12 @@ mod tests { #[test] fn iter_to_array_string_test() { - let arr1 = - array_into_list_array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); - let arr2 = - array_into_list_array(Arc::new(StringArray::from(vec!["rust", "world"]))); + let arr1 = array_into_list_array_nullable(Arc::new(StringArray::from(vec![ + "foo", "bar", "baz", + ]))); + let arr2 = array_into_list_array_nullable(Arc::new(StringArray::from(vec![ + "rust", "world", + ]))); let scalars = vec![ ScalarValue::List(Arc::new(arr1)), @@ -3888,7 +4378,7 @@ mod tests { .strip_backtrace(); assert_eq!( err, - "Arrow error: Compute error: Overflow happened on: 2147483647 - -2147483648" + "Arrow error: Arithmetic overflow: Overflow happened on: 2147483647 - -2147483648" ) } @@ -3909,7 +4399,7 @@ mod tests { .sub_checked(&int_value_2) .unwrap_err() .strip_backtrace(); - assert_eq!(err, "Arrow error: Compute error: Overflow happened on: 9223372036854775807 - -9223372036854775808") + assert_eq!(err, "Arrow error: Arithmetic overflow: Overflow happened on: 9223372036854775807 - -9223372036854775808") } #[test] @@ -3961,14 +4451,18 @@ mod tests { let right_array = right.to_array().expect("Failed to convert to array"); let arrow_left_array = left_array.as_primitive::(); let arrow_right_array = right_array.as_primitive::(); - let arrow_result = kernels::numeric::add(arrow_left_array, arrow_right_array); + let arrow_result = add(arrow_left_array, arrow_right_array); assert_eq!(scalar_result.is_ok(), arrow_result.is_ok()); } #[test] fn test_interval_add_timestamp() -> Result<()> { - let interval = ScalarValue::IntervalMonthDayNano(Some(123)); + let interval = ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano { + months: 1, + days: 2, + nanoseconds: 3, + })); let timestamp = ScalarValue::TimestampNanosecond(Some(123), None); let result = interval.add(×tamp)?; let expect = timestamp.add(&interval)?; @@ -3980,7 +4474,10 @@ mod tests { let expect = timestamp.add(&interval)?; assert_eq!(result, expect); - let interval = ScalarValue::IntervalDayTime(Some(123)); + let interval = ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 1, + milliseconds: 23, + })); let timestamp = ScalarValue::TimestampNanosecond(Some(123), None); let result = interval.add(×tamp)?; let expect = timestamp.add(&interval)?; @@ -4187,7 +4684,7 @@ mod tests { #[test] fn scalar_list_null_to_array() { - let list_array = ScalarValue::new_list(&[], &DataType::UInt64); + let list_array = ScalarValue::new_list_nullable(&[], &DataType::UInt64); assert_eq!(list_array.len(), 1); assert_eq!(list_array.values().len(), 0); @@ -4208,7 +4705,7 @@ mod tests { ScalarValue::UInt64(None), ScalarValue::UInt64(Some(101)), ]; - let list_array = ScalarValue::new_list(&values, &DataType::UInt64); + let list_array = ScalarValue::new_list_nullable(&values, &DataType::UInt64); assert_eq!(list_array.len(), 1); assert_eq!(list_array.values().len(), 3); @@ -4455,28 +4952,49 @@ mod tests { let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); - assert_eq!(non_null_list_scalar.data_type(), data_type.clone()); + assert_eq!(non_null_list_scalar.data_type(), data_type); assert_eq!(null_list_scalar.data_type(), data_type); } #[test] - fn scalar_try_from_list() { - let data_type = - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); - let data_type = &data_type; - let scalar: ScalarValue = data_type.try_into().unwrap(); + fn scalar_try_from_list_datatypes() { + let inner_field = Arc::new(Field::new("item", DataType::Int32, true)); + // Test for List + let data_type = &DataType::List(Arc::clone(&inner_field)); + let scalar: ScalarValue = data_type.try_into().unwrap(); let expected = ScalarValue::List( - new_null_array( - &DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - 1, - ) - .as_list::() - .to_owned() - .into(), + new_null_array(data_type, 1) + .as_list::() + .to_owned() + .into(), ); + assert_eq!(expected, scalar); + assert!(expected.is_null()); - assert_eq!(expected, scalar) + // Test for LargeList + let data_type = &DataType::LargeList(Arc::clone(&inner_field)); + let scalar: ScalarValue = data_type.try_into().unwrap(); + let expected = ScalarValue::LargeList( + new_null_array(data_type, 1) + .as_list::() + .to_owned() + .into(), + ); + assert_eq!(expected, scalar); + assert!(expected.is_null()); + + // Test for FixedSizeList(5) + let data_type = &DataType::FixedSizeList(Arc::clone(&inner_field), 5); + let scalar: ScalarValue = data_type.try_into().unwrap(); + let expected = ScalarValue::FixedSizeList( + new_null_array(data_type, 1) + .as_fixed_size_list() + .to_owned() + .into(), + ); + assert_eq!(expected, scalar); + assert!(expected.is_null()); } #[test] @@ -4548,13 +5066,13 @@ mod tests { // thus the size of the enum appears to as well // The value may also change depending on rust version - assert_eq!(std::mem::size_of::(), 64); + assert_eq!(size_of::(), 64); } #[test] fn memory_size() { let sv = ScalarValue::Binary(Some(Vec::with_capacity(10))); - assert_eq!(sv.size(), std::mem::size_of::() + 10,); + assert_eq!(sv.size(), size_of::() + 10,); let sv_size = sv.size(); let mut v = Vec::with_capacity(10); @@ -4563,9 +5081,7 @@ mod tests { assert_eq!(v.capacity(), 10); assert_eq!( ScalarValue::size_of_vec(&v), - std::mem::size_of::>() - + (9 * std::mem::size_of::()) - + sv_size, + size_of::>() + (9 * size_of::()) + sv_size, ); let mut s = HashSet::with_capacity(0); @@ -4575,8 +5091,8 @@ mod tests { let s_capacity = s.capacity(); assert_eq!( ScalarValue::size_of_hashset(&s), - std::mem::size_of::>() - + ((s_capacity - 1) * std::mem::size_of::()) + size_of::>() + + ((s_capacity - 1) * size_of::()) + sv_size, ); } @@ -4609,6 +5125,17 @@ mod tests { let str_vals = [Some("foo"), None, Some("bar")]; + let interval_dt_vals = [ + Some(IntervalDayTime::MINUS_ONE), + None, + Some(IntervalDayTime::ONE), + ]; + let interval_mdn_vals = [ + Some(IntervalMonthDayNano::MINUS_ONE), + None, + Some(IntervalMonthDayNano::ONE), + ]; + /// Test each value in `scalar` with the corresponding element /// at `array`. Assumes each element is unique (aka not equal /// with all other indexes) @@ -4754,7 +5281,12 @@ mod tests { Some("UTC".into()) ), make_test_case!(i32_vals, IntervalYearMonthArray, IntervalYearMonth), - make_test_case!(i64_vals, IntervalDayTimeArray, IntervalDayTime), + make_test_case!(interval_dt_vals, IntervalDayTimeArray, IntervalDayTime), + make_test_case!( + interval_mdn_vals, + IntervalMonthDayNanoArray, + IntervalMonthDayNano + ), make_str_dict_test_case!(str_vals, Int8Type), make_str_dict_test_case!(str_vals, Int16Type), make_str_dict_test_case!(str_vals, Int32Type), @@ -4859,35 +5391,35 @@ mod tests { let field_f = Arc::new(Field::new("f", DataType::Int64, false)); let field_d = Arc::new(Field::new( "D", - DataType::Struct(vec![field_e.clone(), field_f.clone()].into()), + DataType::Struct(vec![Arc::clone(&field_e), Arc::clone(&field_f)].into()), false, )); let struct_array = StructArray::from(vec![ ( - field_e.clone(), + Arc::clone(&field_e), Arc::new(Int16Array::from(vec![2])) as ArrayRef, ), ( - field_f.clone(), + Arc::clone(&field_f), Arc::new(Int64Array::from(vec![3])) as ArrayRef, ), ]); let struct_array = StructArray::from(vec![ ( - field_a.clone(), + Arc::clone(&field_a), Arc::new(Int32Array::from(vec![23])) as ArrayRef, ), ( - field_b.clone(), + Arc::clone(&field_b), Arc::new(BooleanArray::from(vec![false])) as ArrayRef, ), ( - field_c.clone(), + Arc::clone(&field_c), Arc::new(StringArray::from(vec!["Hello"])) as ArrayRef, ), - (field_d.clone(), Arc::new(struct_array) as ArrayRef), + (Arc::clone(&field_d), Arc::new(struct_array) as ArrayRef), ]); let scalar = ScalarValue::Struct(Arc::new(struct_array)); @@ -4897,26 +5429,26 @@ mod tests { let expected = Arc::new(StructArray::from(vec![ ( - field_a.clone(), + Arc::clone(&field_a), Arc::new(Int32Array::from(vec![23, 23])) as ArrayRef, ), ( - field_b.clone(), + Arc::clone(&field_b), Arc::new(BooleanArray::from(vec![false, false])) as ArrayRef, ), ( - field_c.clone(), + Arc::clone(&field_c), Arc::new(StringArray::from(vec!["Hello", "Hello"])) as ArrayRef, ), ( - field_d.clone(), + Arc::clone(&field_d), Arc::new(StructArray::from(vec![ ( - field_e.clone(), + Arc::clone(&field_e), Arc::new(Int16Array::from(vec![2, 2])) as ArrayRef, ), ( - field_f.clone(), + Arc::clone(&field_f), Arc::new(Int64Array::from(vec![3, 3])) as ArrayRef, ), ])) as ArrayRef, @@ -4995,26 +5527,26 @@ mod tests { let expected = Arc::new(StructArray::from(vec![ ( - field_a.clone(), + Arc::clone(&field_a), Arc::new(Int32Array::from(vec![23, 7, -1000])) as ArrayRef, ), ( - field_b.clone(), + Arc::clone(&field_b), Arc::new(BooleanArray::from(vec![false, true, true])) as ArrayRef, ), ( - field_c.clone(), + Arc::clone(&field_c), Arc::new(StringArray::from(vec!["Hello", "World", "!!!!!"])) as ArrayRef, ), ( - field_d.clone(), + Arc::clone(&field_d), Arc::new(StructArray::from(vec![ ( - field_e.clone(), + Arc::clone(&field_e), Arc::new(Int16Array::from(vec![2, 4, 6])) as ArrayRef, ), ( - field_f.clone(), + Arc::clone(&field_f), Arc::new(Int64Array::from(vec![3, 5, 7])) as ArrayRef, ), ])) as ArrayRef, @@ -5024,6 +5556,112 @@ mod tests { assert_eq!(&array, &expected); } + #[test] + fn test_scalar_union_sparse() { + let field_a = Arc::new(Field::new("A", DataType::Int32, true)); + let field_b = Arc::new(Field::new("B", DataType::Boolean, true)); + let field_c = Arc::new(Field::new("C", DataType::Utf8, true)); + let fields = UnionFields::from_iter([(0, field_a), (1, field_b), (2, field_c)]); + + let mut values_a = vec![None; 6]; + values_a[0] = Some(42); + let mut values_b = vec![None; 6]; + values_b[1] = Some(true); + let mut values_c = vec![None; 6]; + values_c[2] = Some("foo"); + let children: Vec = vec![ + Arc::new(Int32Array::from(values_a)), + Arc::new(BooleanArray::from(values_b)), + Arc::new(StringArray::from(values_c)), + ]; + + let type_ids = ScalarBuffer::from(vec![0, 1, 2, 0, 1, 2]); + let array: ArrayRef = Arc::new( + UnionArray::try_new(fields.clone(), type_ids, None, children) + .expect("UnionArray"), + ); + + let expected = [ + (0, ScalarValue::from(42)), + (1, ScalarValue::from(true)), + (2, ScalarValue::from("foo")), + (0, ScalarValue::Int32(None)), + (1, ScalarValue::Boolean(None)), + (2, ScalarValue::Utf8(None)), + ]; + + for (i, (ti, value)) in expected.into_iter().enumerate() { + let is_null = value.is_null(); + let value = Some((ti, Box::new(value))); + let expected = ScalarValue::Union(value, fields.clone(), UnionMode::Sparse); + let actual = ScalarValue::try_from_array(&array, i).expect("try_from_array"); + + assert_eq!( + actual, expected, + "[{i}] {actual} was not equal to {expected}" + ); + + assert!( + expected.eq_array(&array, i).expect("eq_array"), + "[{i}] {expected}.eq_array was false" + ); + + if is_null { + assert!(actual.is_null(), "[{i}] {actual} was not null") + } + } + } + + #[test] + fn test_scalar_union_dense() { + let field_a = Arc::new(Field::new("A", DataType::Int32, true)); + let field_b = Arc::new(Field::new("B", DataType::Boolean, true)); + let field_c = Arc::new(Field::new("C", DataType::Utf8, true)); + let fields = UnionFields::from_iter([(0, field_a), (1, field_b), (2, field_c)]); + let children: Vec = vec![ + Arc::new(Int32Array::from(vec![Some(42), None])), + Arc::new(BooleanArray::from(vec![Some(true), None])), + Arc::new(StringArray::from(vec![Some("foo"), None])), + ]; + + let type_ids = ScalarBuffer::from(vec![0, 1, 2, 0, 1, 2]); + let offsets = ScalarBuffer::from(vec![0, 0, 0, 1, 1, 1]); + let array: ArrayRef = Arc::new( + UnionArray::try_new(fields.clone(), type_ids, Some(offsets), children) + .expect("UnionArray"), + ); + + let expected = [ + (0, ScalarValue::from(42)), + (1, ScalarValue::from(true)), + (2, ScalarValue::from("foo")), + (0, ScalarValue::Int32(None)), + (1, ScalarValue::Boolean(None)), + (2, ScalarValue::Utf8(None)), + ]; + + for (i, (ti, value)) in expected.into_iter().enumerate() { + let is_null = value.is_null(); + let value = Some((ti, Box::new(value))); + let expected = ScalarValue::Union(value, fields.clone(), UnionMode::Dense); + let actual = ScalarValue::try_from_array(&array, i).expect("try_from_array"); + + assert_eq!( + actual, expected, + "[{i}] {actual} was not equal to {expected}" + ); + + assert!( + expected.eq_array(&array, i).expect("eq_array"), + "[{i}] {expected}.eq_array was false" + ); + + if is_null { + assert!(actual.is_null(), "[{i}] {actual} was not null") + } + } + } + #[test] fn test_lists_in_struct() { let field_a = Arc::new(Field::new("A", DataType::Utf8, false)); @@ -5078,11 +5716,11 @@ mod tests { let array = as_struct_array(&array).unwrap(); let expected = StructArray::from(vec![ ( - field_a.clone(), + Arc::clone(&field_a), Arc::new(StringArray::from(vec!["First", "Second", "Third"])) as ArrayRef, ), ( - field_primitive_list.clone(), + Arc::clone(&field_primitive_list), Arc::new(ListArray::from_iter_primitive::(vec![ Some(vec![Some(1), Some(2), Some(3)]), Some(vec![Some(4), Some(5)]), @@ -5095,14 +5733,14 @@ mod tests { // Define list-of-structs scalars - let nl0_array = ScalarValue::iter_to_array(vec![s0.clone(), s1.clone()]).unwrap(); - let nl0 = ScalarValue::List(Arc::new(array_into_list_array(nl0_array))); + let nl0_array = ScalarValue::iter_to_array(vec![s0, s1.clone()]).unwrap(); + let nl0 = ScalarValue::List(Arc::new(array_into_list_array_nullable(nl0_array))); - let nl1_array = ScalarValue::iter_to_array(vec![s2.clone()]).unwrap(); - let nl1 = ScalarValue::List(Arc::new(array_into_list_array(nl1_array))); + let nl1_array = ScalarValue::iter_to_array(vec![s2]).unwrap(); + let nl1 = ScalarValue::List(Arc::new(array_into_list_array_nullable(nl1_array))); - let nl2_array = ScalarValue::iter_to_array(vec![s1.clone()]).unwrap(); - let nl2 = ScalarValue::List(Arc::new(array_into_list_array(nl2_array))); + let nl2_array = ScalarValue::iter_to_array(vec![s1]).unwrap(); + let nl2 = ScalarValue::List(Arc::new(array_into_list_array_nullable(nl2_array))); // iter_to_array for list-of-struct let array = ScalarValue::iter_to_array(vec![nl0, nl1, nl2]).unwrap(); @@ -5325,6 +5963,13 @@ mod tests { ScalarValue::Utf8(None), DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), ); + + check_scalar_cast(ScalarValue::Utf8(None), DataType::Utf8View); + check_scalar_cast(ScalarValue::from("foo"), DataType::Utf8View); + check_scalar_cast( + ScalarValue::from("larger than 12 bytes string"), + DataType::Utf8View, + ); } // mimics how casting work on scalar values by `casting` `scalar` to `desired_type` @@ -5363,6 +6008,88 @@ mod tests { Ok(()) } + #[test] + #[allow(arithmetic_overflow)] // we want to test them + fn test_scalar_negative_overflows() -> Result<()> { + macro_rules! test_overflow_on_value { + ($($val:expr),* $(,)?) => {$( + { + let value: ScalarValue = $val; + let err = value.arithmetic_negate().expect_err("Should receive overflow error on negating {value:?}"); + let root_err = err.find_root(); + match root_err{ + DataFusionError::ArrowError( + ArrowError::ArithmeticOverflow(_), + _, + ) => {} + _ => return Err(err), + }; + } + )*}; + } + test_overflow_on_value!( + // the integers + i8::MIN.into(), + i16::MIN.into(), + i32::MIN.into(), + i64::MIN.into(), + // for decimals, only value needs to be tested + ScalarValue::try_new_decimal128(i128::MIN, 10, 5)?, + ScalarValue::Decimal256(Some(i256::MIN), 20, 5), + // interval, check all possible values + ScalarValue::IntervalYearMonth(Some(i32::MIN)), + ScalarValue::new_interval_dt(i32::MIN, 999), + ScalarValue::new_interval_dt(1, i32::MIN), + ScalarValue::new_interval_mdn(i32::MIN, 15, 123_456), + ScalarValue::new_interval_mdn(12, i32::MIN, 123_456), + ScalarValue::new_interval_mdn(12, 15, i64::MIN), + // tz doesn't matter when negating + ScalarValue::TimestampSecond(Some(i64::MIN), None), + ScalarValue::TimestampMillisecond(Some(i64::MIN), None), + ScalarValue::TimestampMicrosecond(Some(i64::MIN), None), + ScalarValue::TimestampNanosecond(Some(i64::MIN), None), + ); + + let float_cases = [ + ( + ScalarValue::Float16(Some(f16::MIN)), + ScalarValue::Float16(Some(f16::MAX)), + ), + ( + ScalarValue::Float16(Some(f16::MAX)), + ScalarValue::Float16(Some(f16::MIN)), + ), + (f32::MIN.into(), f32::MAX.into()), + (f32::MAX.into(), f32::MIN.into()), + (f64::MIN.into(), f64::MAX.into()), + (f64::MAX.into(), f64::MIN.into()), + ]; + // skip float 16 because they aren't supported + for (test, expected) in float_cases.into_iter().skip(2) { + assert_eq!(test.arithmetic_negate()?, expected); + } + Ok(()) + } + + #[test] + fn f16_test_overflow() { + // TODO: if negate supports f16, add these cases to `test_scalar_negative_overflows` test case + let cases = [ + ( + ScalarValue::Float16(Some(f16::MIN)), + ScalarValue::Float16(Some(f16::MAX)), + ), + ( + ScalarValue::Float16(Some(f16::MAX)), + ScalarValue::Float16(Some(f16::MIN)), + ), + ]; + + for (test, expected) in cases { + assert_eq!(test.arithmetic_negate().unwrap(), expected); + } + } + macro_rules! expect_operation_error { ($TEST_NAME:ident, $FUNCTION:ident, $EXPECTED_ERROR:expr) => { #[test] @@ -5514,6 +6241,21 @@ mod tests { ScalarValue::UInt64(Some(10)), 5, ), + ( + ScalarValue::Float16(Some(f16::from_f32(1.1))), + ScalarValue::Float16(Some(f16::from_f32(1.9))), + 1, + ), + ( + ScalarValue::Float16(Some(f16::from_f32(-5.3))), + ScalarValue::Float16(Some(f16::from_f32(-9.2))), + 4, + ), + ( + ScalarValue::Float16(Some(f16::from_f32(-5.3))), + ScalarValue::Float16(Some(f16::from_f32(-9.7))), + 4, + ), ( ScalarValue::Float32(Some(1.0)), ScalarValue::Float32(Some(2.0)), @@ -5586,6 +6328,14 @@ mod tests { // Different type (ScalarValue::Int8(Some(1)), ScalarValue::Int16(Some(1))), (ScalarValue::Int8(Some(1)), ScalarValue::Float32(Some(1.0))), + ( + ScalarValue::Float16(Some(f16::from_f32(1.0))), + ScalarValue::Float32(Some(1.0)), + ), + ( + ScalarValue::Float16(Some(f16::from_f32(1.0))), + ScalarValue::Int32(Some(1)), + ), ( ScalarValue::Float64(Some(1.1)), ScalarValue::Float32(Some(2.2)), @@ -5726,18 +6476,18 @@ mod tests { let struct_value = vec![ ( - fields[0].clone(), + Arc::clone(&fields[0]), Arc::new(UInt64Array::from(vec![Some(1)])) as ArrayRef, ), ( - fields[1].clone(), + Arc::clone(&fields[1]), Arc::new(StructArray::from(vec![ ( - fields_b[0].clone(), + Arc::clone(&fields_b[0]), Arc::new(UInt64Array::from(vec![Some(2)])) as ArrayRef, ), ( - fields_b[1].clone(), + Arc::clone(&fields_b[1]), Arc::new(UInt64Array::from(vec![Some(3)])) as ArrayRef, ), ])) as ArrayRef, @@ -5746,19 +6496,19 @@ mod tests { let struct_value_with_nulls = vec![ ( - fields[0].clone(), + Arc::clone(&fields[0]), Arc::new(UInt64Array::from(vec![Some(1)])) as ArrayRef, ), ( - fields[1].clone(), + Arc::clone(&fields[1]), Arc::new(StructArray::from(( vec![ ( - fields_b[0].clone(), + Arc::clone(&fields_b[0]), Arc::new(UInt64Array::from(vec![Some(2)])) as ArrayRef, ), ( - fields_b[1].clone(), + Arc::clone(&fields_b[1]), Arc::new(UInt64Array::from(vec![Some(3)])) as ArrayRef, ), ], @@ -5833,6 +6583,7 @@ mod tests { .unwrap(); assert_eq!(s.to_string(), "{a:1,b:}"); + assert_eq!(format!("{s:?}"), r#"Struct({a:1,b:})"#); let ScalarValue::Struct(arr) = s else { panic!("Expected struct"); @@ -5850,6 +6601,43 @@ mod tests { assert_batches_eq!(&expected, &[batch]); } + #[test] + fn test_null_bug() { + let field_a = Field::new("a", DataType::Int32, true); + let field_b = Field::new("b", DataType::Int32, true); + let fields = Fields::from(vec![field_a, field_b]); + + let array_a = Arc::new(Int32Array::from_iter_values([1])); + let array_b = Arc::new(Int32Array::from_iter_values([2])); + let arrays: Vec = vec![array_a, array_b]; + + let mut not_nulls = BooleanBufferBuilder::new(1); + not_nulls.append(true); + let not_nulls = not_nulls.finish(); + let not_nulls = Some(NullBuffer::new(not_nulls)); + + let ar = StructArray::new(fields, arrays, not_nulls); + let s = ScalarValue::Struct(Arc::new(ar)); + + assert_eq!(s.to_string(), "{a:1,b:2}"); + assert_eq!(format!("{s:?}"), r#"Struct({a:1,b:2})"#); + + let ScalarValue::Struct(arr) = s else { + panic!("Expected struct"); + }; + + //verify compared to arrow display + let batch = RecordBatch::try_from_iter(vec![("s", arr as _)]).unwrap(); + let expected = [ + "+--------------+", + "| s |", + "+--------------+", + "| {a: 1, b: 2} |", + "+--------------+", + ]; + assert_batches_eq!(&expected, &[batch]); + } + #[test] fn test_struct_display_null() { let fields = vec![Field::new("a", DataType::Int32, false)]; @@ -5874,10 +6662,147 @@ mod tests { assert_batches_eq!(&expected, &[batch]); } + #[test] + fn test_map_display_and_debug() { + let string_builder = StringBuilder::new(); + let int_builder = Int32Builder::with_capacity(4); + let mut builder = MapBuilder::new(None, string_builder, int_builder); + builder.keys().append_value("joe"); + builder.values().append_value(1); + builder.append(true).unwrap(); + + builder.keys().append_value("blogs"); + builder.values().append_value(2); + builder.keys().append_value("foo"); + builder.values().append_value(4); + builder.append(true).unwrap(); + builder.append(true).unwrap(); + builder.append(false).unwrap(); + + let map_value = ScalarValue::Map(Arc::new(builder.finish())); + + assert_eq!(map_value.to_string(), "[{joe:1},{blogs:2,foo:4},{},NULL]"); + assert_eq!( + format!("{map_value:?}"), + r#"Map([{"joe":"1"},{"blogs":"2","foo":"4"},{},NULL])"# + ); + + let ScalarValue::Map(arr) = map_value else { + panic!("Expected map"); + }; + + //verify compared to arrow display + let batch = RecordBatch::try_from_iter(vec![("m", arr as _)]).unwrap(); + let expected = [ + "+--------------------+", + "| m |", + "+--------------------+", + "| {joe: 1} |", + "| {blogs: 2, foo: 4} |", + "| {} |", + "| |", + "+--------------------+", + ]; + assert_batches_eq!(&expected, &[batch]); + } + + #[test] + fn test_binary_display() { + let no_binary_value = ScalarValue::Binary(None); + assert_eq!(format!("{no_binary_value}"), "NULL"); + let single_binary_value = ScalarValue::Binary(Some(vec![42u8])); + assert_eq!(format!("{single_binary_value}"), "2A"); + let small_binary_value = ScalarValue::Binary(Some(vec![1u8, 2, 3])); + assert_eq!(format!("{small_binary_value}"), "010203"); + let large_binary_value = + ScalarValue::Binary(Some(vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])); + assert_eq!(format!("{large_binary_value}"), "0102030405060708090A..."); + + let no_binary_value = ScalarValue::BinaryView(None); + assert_eq!(format!("{no_binary_value}"), "NULL"); + let small_binary_value = ScalarValue::BinaryView(Some(vec![1u8, 2, 3])); + assert_eq!(format!("{small_binary_value}"), "010203"); + let large_binary_value = + ScalarValue::BinaryView(Some(vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])); + assert_eq!(format!("{large_binary_value}"), "0102030405060708090A..."); + + let no_binary_value = ScalarValue::LargeBinary(None); + assert_eq!(format!("{no_binary_value}"), "NULL"); + let small_binary_value = ScalarValue::LargeBinary(Some(vec![1u8, 2, 3])); + assert_eq!(format!("{small_binary_value}"), "010203"); + let large_binary_value = + ScalarValue::LargeBinary(Some(vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])); + assert_eq!(format!("{large_binary_value}"), "0102030405060708090A..."); + + let no_binary_value = ScalarValue::FixedSizeBinary(3, None); + assert_eq!(format!("{no_binary_value}"), "NULL"); + let small_binary_value = ScalarValue::FixedSizeBinary(3, Some(vec![1u8, 2, 3])); + assert_eq!(format!("{small_binary_value}"), "010203"); + let large_binary_value = ScalarValue::FixedSizeBinary( + 11, + Some(vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]), + ); + assert_eq!(format!("{large_binary_value}"), "0102030405060708090A..."); + } + + #[test] + fn test_binary_debug() { + let no_binary_value = ScalarValue::Binary(None); + assert_eq!(format!("{no_binary_value:?}"), "Binary(NULL)"); + let single_binary_value = ScalarValue::Binary(Some(vec![42u8])); + assert_eq!(format!("{single_binary_value:?}"), "Binary(\"42\")"); + let small_binary_value = ScalarValue::Binary(Some(vec![1u8, 2, 3])); + assert_eq!(format!("{small_binary_value:?}"), "Binary(\"1,2,3\")"); + let large_binary_value = + ScalarValue::Binary(Some(vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])); + assert_eq!( + format!("{large_binary_value:?}"), + "Binary(\"1,2,3,4,5,6,7,8,9,10,11\")" + ); + + let no_binary_value = ScalarValue::BinaryView(None); + assert_eq!(format!("{no_binary_value:?}"), "BinaryView(NULL)"); + let small_binary_value = ScalarValue::BinaryView(Some(vec![1u8, 2, 3])); + assert_eq!(format!("{small_binary_value:?}"), "BinaryView(\"1,2,3\")"); + let large_binary_value = + ScalarValue::BinaryView(Some(vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])); + assert_eq!( + format!("{large_binary_value:?}"), + "BinaryView(\"1,2,3,4,5,6,7,8,9,10,11\")" + ); + + let no_binary_value = ScalarValue::LargeBinary(None); + assert_eq!(format!("{no_binary_value:?}"), "LargeBinary(NULL)"); + let small_binary_value = ScalarValue::LargeBinary(Some(vec![1u8, 2, 3])); + assert_eq!(format!("{small_binary_value:?}"), "LargeBinary(\"1,2,3\")"); + let large_binary_value = + ScalarValue::LargeBinary(Some(vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])); + assert_eq!( + format!("{large_binary_value:?}"), + "LargeBinary(\"1,2,3,4,5,6,7,8,9,10,11\")" + ); + + let no_binary_value = ScalarValue::FixedSizeBinary(3, None); + assert_eq!(format!("{no_binary_value:?}"), "FixedSizeBinary(3, NULL)"); + let small_binary_value = ScalarValue::FixedSizeBinary(3, Some(vec![1u8, 2, 3])); + assert_eq!( + format!("{small_binary_value:?}"), + "FixedSizeBinary(3, \"1,2,3\")" + ); + let large_binary_value = ScalarValue::FixedSizeBinary( + 11, + Some(vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]), + ); + assert_eq!( + format!("{large_binary_value:?}"), + "FixedSizeBinary(11, \"1,2,3,4,5,6,7,8,9,10,11\")" + ); + } + #[test] fn test_build_timestamp_millisecond_list() { let values = vec![ScalarValue::TimestampMillisecond(Some(1), None)]; - let arr = ScalarValue::new_list( + let arr = ScalarValue::new_list_nullable( &values, &DataType::Timestamp(TimeUnit::Millisecond, None), ); @@ -5888,7 +6813,7 @@ mod tests { fn test_newlist_timestamp_zone() { let s: &'static str = "UTC"; let values = vec![ScalarValue::TimestampMillisecond(Some(1), Some(s.into()))]; - let arr = ScalarValue::new_list( + let arr = ScalarValue::new_list_nullable( &values, &DataType::Timestamp(TimeUnit::Millisecond, Some(s.into())), ); @@ -6003,4 +6928,44 @@ mod tests { } intervals } + + fn union_fields() -> UnionFields { + [ + (0, Arc::new(Field::new("A", DataType::Int32, true))), + (1, Arc::new(Field::new("B", DataType::Float64, true))), + ] + .into_iter() + .collect() + } + + #[test] + fn sparse_scalar_union_is_null() { + let sparse_scalar = ScalarValue::Union( + Some((0_i8, Box::new(ScalarValue::Int32(None)))), + union_fields(), + UnionMode::Sparse, + ); + assert!(sparse_scalar.is_null()); + } + + #[test] + fn dense_scalar_union_is_null() { + let dense_scalar = ScalarValue::Union( + Some((0_i8, Box::new(ScalarValue::Int32(None)))), + union_fields(), + UnionMode::Dense, + ); + assert!(dense_scalar.is_null()); + } + + #[test] + fn null_dictionary_scalar_produces_null_dictionary_array() { + let dictionary_scalar = ScalarValue::Dictionary( + Box::new(DataType::Int32), + Box::new(ScalarValue::Null), + ); + assert!(dictionary_scalar.is_null()); + let dictionary_array = dictionary_scalar.to_array().unwrap(); + assert!(dictionary_array.is_null(0)); + } } diff --git a/datafusion/common/src/scalar/struct_builder.rs b/datafusion/common/src/scalar/struct_builder.rs index b1a34e4a61d0..4a6a8f0289a7 100644 --- a/datafusion/common/src/scalar/struct_builder.rs +++ b/datafusion/common/src/scalar/struct_builder.rs @@ -144,7 +144,7 @@ impl IntoFieldRef for FieldRef { impl IntoFieldRef for &FieldRef { fn into_field_ref(self) -> FieldRef { - self.clone() + Arc::clone(self) } } diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index 6cefef8d0eb5..1aa42705e7f8 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -19,13 +19,13 @@ use std::fmt::{self, Debug, Display}; -use crate::ScalarValue; +use crate::{Result, ScalarValue}; -use arrow_schema::Schema; +use arrow_schema::{Schema, SchemaRef}; /// Represents a value with a degree of certainty. `Precision` is used to /// propagate information the precision of statistical values. -#[derive(Clone, PartialEq, Eq, Default)] +#[derive(Clone, PartialEq, Eq, Default, Copy)] pub enum Precision { /// The exact value is known Exact(T), @@ -190,7 +190,7 @@ impl Precision { } } -impl Debug for Precision { +impl Debug for Precision { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Precision::Exact(inner) => write!(f, "Exact({:?})", inner), @@ -200,7 +200,7 @@ impl Debug for Precision } } -impl Display for Precision { +impl Display for Precision { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Precision::Exact(inner) => write!(f, "Exact({:?})", inner), @@ -247,26 +247,121 @@ impl Statistics { /// If the exactness of a [`Statistics`] instance is lost, this function relaxes /// the exactness of all information by converting them [`Precision::Inexact`]. - pub fn into_inexact(self) -> Self { - Statistics { - num_rows: self.num_rows.to_inexact(), - total_byte_size: self.total_byte_size.to_inexact(), - column_statistics: self - .column_statistics - .into_iter() - .map(|cs| ColumnStatistics { - null_count: cs.null_count.to_inexact(), - max_value: cs.max_value.to_inexact(), - min_value: cs.min_value.to_inexact(), - distinct_count: cs.distinct_count.to_inexact(), - }) - .collect::>(), + pub fn to_inexact(mut self) -> Self { + self.num_rows = self.num_rows.to_inexact(); + self.total_byte_size = self.total_byte_size.to_inexact(); + self.column_statistics = self + .column_statistics + .into_iter() + .map(|s| s.to_inexact()) + .collect(); + self + } + + /// Project the statistics to the given column indices. + /// + /// For example, if we had statistics for columns `{"a", "b", "c"}`, + /// projecting to `vec![2, 1]` would return statistics for columns `{"c", + /// "b"}`. + pub fn project(mut self, projection: Option<&Vec>) -> Self { + let Some(projection) = projection else { + return self; + }; + + // todo: it would be nice to avoid cloning column statistics if + // possible (e.g. if the projection did not contain duplicates) + self.column_statistics = projection + .iter() + .map(|&i| self.column_statistics[i].clone()) + .collect(); + + self + } + + /// Calculates the statistics after `fetch` and `skip` operations apply. + /// Here, `self` denotes per-partition statistics. Use the `n_partitions` + /// parameter to compute global statistics in a multi-partition setting. + pub fn with_fetch( + mut self, + schema: SchemaRef, + fetch: Option, + skip: usize, + n_partitions: usize, + ) -> Result { + let fetch_val = fetch.unwrap_or(usize::MAX); + + self.num_rows = match self { + Statistics { + num_rows: Precision::Exact(nr), + .. + } + | Statistics { + num_rows: Precision::Inexact(nr), + .. + } => { + // Here, the inexact case gives us an upper bound on the number of rows. + if nr <= skip { + // All input data will be skipped: + Precision::Exact(0) + } else if nr <= fetch_val && skip == 0 { + // If the input does not reach the `fetch` globally, and `skip` + // is zero (meaning the input and output are identical), return + // input stats as is. + // TODO: Can input stats still be used, but adjusted, when `skip` + // is non-zero? + return Ok(self); + } else if nr - skip <= fetch_val { + // After `skip` input rows are skipped, the remaining rows are + // less than or equal to the `fetch` values, so `num_rows` must + // equal the remaining rows. + check_num_rows( + (nr - skip).checked_mul(n_partitions), + // We know that we have an estimate for the number of rows: + self.num_rows.is_exact().unwrap(), + ) + } else { + // At this point we know that we were given a `fetch` value + // as the `None` case would go into the branch above. Since + // the input has more rows than `fetch + skip`, the number + // of rows will be the `fetch`, but we won't be able to + // predict the other statistics. + check_num_rows( + fetch_val.checked_mul(n_partitions), + // We know that we have an estimate for the number of rows: + self.num_rows.is_exact().unwrap(), + ) + } + } + Statistics { + num_rows: Precision::Absent, + .. + } => check_num_rows(fetch.and_then(|v| v.checked_mul(n_partitions)), false), + }; + self.column_statistics = Statistics::unknown_column(&schema); + self.total_byte_size = Precision::Absent; + Ok(self) + } +} + +/// Creates an estimate of the number of rows in the output using the given +/// optional value and exactness flag. +fn check_num_rows(value: Option, is_exact: bool) -> Precision { + if let Some(value) = value { + if is_exact { + Precision::Exact(value) + } else { + // If the input stats are inexact, so are the output stats. + Precision::Inexact(value) } + } else { + // If the estimate is not available (e.g. due to an overflow), we can + // not produce a reliable estimate. + Precision::Absent } } impl Display for Statistics { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { // string of column statistics let column_stats = self .column_statistics @@ -336,14 +431,25 @@ impl ColumnStatistics { } /// Returns a [`ColumnStatistics`] instance having all [`Precision::Absent`] parameters. - pub fn new_unknown() -> ColumnStatistics { - ColumnStatistics { + pub fn new_unknown() -> Self { + Self { null_count: Precision::Absent, max_value: Precision::Absent, min_value: Precision::Absent, distinct_count: Precision::Absent, } } + + /// If the exactness of a [`ColumnStatistics`] instance is lost, this + /// function relaxes the exactness of all information by converting them + /// [`Precision::Inexact`]. + pub fn to_inexact(mut self) -> Self { + self.null_count = self.null_count.to_inexact(); + self.max_value = self.max_value.to_inexact(); + self.min_value = self.min_value.to_inexact(); + self.distinct_count = self.distinct_count.to_inexact(); + self + } } #[cfg(test)] @@ -417,9 +523,9 @@ mod tests { let inexact_precision = Precision::Inexact(42); let absent_precision = Precision::::Absent; - assert_eq!(exact_precision.clone().to_inexact(), inexact_precision); - assert_eq!(inexact_precision.clone().to_inexact(), inexact_precision); - assert_eq!(absent_precision.clone().to_inexact(), absent_precision); + assert_eq!(exact_precision.to_inexact(), inexact_precision); + assert_eq!(inexact_precision.to_inexact(), inexact_precision); + assert_eq!(absent_precision.to_inexact(), absent_precision); } #[test] @@ -459,4 +565,20 @@ mod tests { assert_eq!(precision2.multiply(&precision3), Precision::Inexact(15)); assert_eq!(precision1.multiply(&absent_precision), Precision::Absent); } + + #[test] + fn test_precision_cloning() { + // Precision is copy + let precision: Precision = Precision::Exact(42); + let p2 = precision; + assert_eq!(precision, p2); + + // Precision is not copy (requires .clone()) + let precision: Precision = + Precision::Exact(ScalarValue::Int64(Some(42))); + // Clippy would complain about this if it were Copy + #[allow(clippy::redundant_clone)] + let p2 = precision.clone(); + assert_eq!(precision, p2); + } } diff --git a/datafusion/common/src/table_reference.rs b/datafusion/common/src/table_reference.rs index b6ccaa74d5fc..67f3da4f48de 100644 --- a/datafusion/common/src/table_reference.rs +++ b/datafusion/common/src/table_reference.rs @@ -62,7 +62,7 @@ impl std::fmt::Display for ResolvedTableReference { /// assert_eq!(table_reference, TableReference::bare("mytable")); /// /// // Get a table reference to 'MyTable' (note the capitalization) using double quotes -/// // (programatically it is better to use `TableReference::bare` for this) +/// // (programmatically it is better to use `TableReference::bare` for this) /// let table_reference = TableReference::from(r#""MyTable""#); /// assert_eq!(table_reference, TableReference::bare("MyTable")); /// diff --git a/datafusion/common/src/test_util.rs b/datafusion/common/src/test_util.rs index 9c3dfe62e119..d3b8c8451258 100644 --- a/datafusion/common/src/test_util.rs +++ b/datafusion/common/src/test_util.rs @@ -29,6 +29,27 @@ use std::{error::Error, path::PathBuf}; /// Expects to be called about like this: /// /// `assert_batch_eq!(expected_lines: &[&str], batches: &[RecordBatch])` +/// +/// # Example +/// ``` +/// # use std::sync::Arc; +/// # use arrow::record_batch::RecordBatch; +/// # use arrow_array::{ArrayRef, Int32Array}; +/// # use datafusion_common::assert_batches_eq; +/// let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); +/// let batch = RecordBatch::try_from_iter([("column", col)]).unwrap(); +/// // Expected output is a vec of strings +/// let expected = vec![ +/// "+--------+", +/// "| column |", +/// "+--------+", +/// "| 1 |", +/// "| 2 |", +/// "+--------+", +/// ]; +/// // compare the formatted output of the record batch with the expected output +/// assert_batches_eq!(expected, &[batch]); +/// ``` #[macro_export] macro_rules! assert_batches_eq { ($EXPECTED_LINES: expr, $CHUNKS: expr) => { @@ -56,8 +77,7 @@ macro_rules! assert_batches_eq { /// vector of strings in a way that order does not matter. /// This is a macro so errors appear on the correct line /// -/// Designed so that failure output can be directly copy/pasted -/// into the test code as expected results. +/// See [`assert_batches_eq`] for more details and example. /// /// Expects to be called about like this: /// @@ -259,8 +279,88 @@ pub fn get_data_dir( } } +#[macro_export] +macro_rules! create_array { + (Boolean, $values: expr) => { + std::sync::Arc::new(arrow::array::BooleanArray::from($values)) + }; + (Int8, $values: expr) => { + std::sync::Arc::new(arrow::array::Int8Array::from($values)) + }; + (Int16, $values: expr) => { + std::sync::Arc::new(arrow::array::Int16Array::from($values)) + }; + (Int32, $values: expr) => { + std::sync::Arc::new(arrow::array::Int32Array::from($values)) + }; + (Int64, $values: expr) => { + std::sync::Arc::new(arrow::array::Int64Array::from($values)) + }; + (UInt8, $values: expr) => { + std::sync::Arc::new(arrow::array::UInt8Array::from($values)) + }; + (UInt16, $values: expr) => { + std::sync::Arc::new(arrow::array::UInt16Array::from($values)) + }; + (UInt32, $values: expr) => { + std::sync::Arc::new(arrow::array::UInt32Array::from($values)) + }; + (UInt64, $values: expr) => { + std::sync::Arc::new(arrow::array::UInt64Array::from($values)) + }; + (Float16, $values: expr) => { + std::sync::Arc::new(arrow::array::Float16Array::from($values)) + }; + (Float32, $values: expr) => { + std::sync::Arc::new(arrow::array::Float32Array::from($values)) + }; + (Float64, $values: expr) => { + std::sync::Arc::new(arrow::array::Float64Array::from($values)) + }; + (Utf8, $values: expr) => { + std::sync::Arc::new(arrow::array::StringArray::from($values)) + }; +} + +/// Creates a record batch from literal slice of values, suitable for rapid +/// testing and development. +/// +/// Example: +/// ``` +/// use datafusion_common::{record_batch, create_array}; +/// let batch = record_batch!( +/// ("a", Int32, vec![1, 2, 3]), +/// ("b", Float64, vec![Some(4.0), None, Some(5.0)]), +/// ("c", Utf8, vec!["alpha", "beta", "gamma"]) +/// ); +/// ``` +#[macro_export] +macro_rules! record_batch { + ($(($name: expr, $type: ident, $values: expr)),*) => { + { + let schema = std::sync::Arc::new(arrow_schema::Schema::new(vec![ + $( + arrow_schema::Field::new($name, arrow_schema::DataType::$type, true), + )* + ])); + + let batch = arrow_array::RecordBatch::try_new( + schema, + vec![$( + $crate::create_array!($type, $values), + )*] + ); + + batch + } + } +} + #[cfg(test)] mod tests { + use crate::cast::{as_float64_array, as_int32_array, as_string_array}; + use crate::error::Result; + use super::*; use std::env; @@ -313,4 +413,44 @@ mod tests { let res = parquet_test_data(); assert!(PathBuf::from(res).is_dir()); } + + #[test] + fn test_create_record_batch() -> Result<()> { + use arrow_array::Array; + + let batch = record_batch!( + ("a", Int32, vec![1, 2, 3, 4]), + ("b", Float64, vec![Some(4.0), None, Some(5.0), None]), + ("c", Utf8, vec!["alpha", "beta", "gamma", "delta"]) + )?; + + assert_eq!(3, batch.num_columns()); + assert_eq!(4, batch.num_rows()); + + let values: Vec<_> = as_int32_array(batch.column(0))? + .values() + .iter() + .map(|v| v.to_owned()) + .collect(); + assert_eq!(values, vec![1, 2, 3, 4]); + + let values: Vec<_> = as_float64_array(batch.column(1))? + .values() + .iter() + .map(|v| v.to_owned()) + .collect(); + assert_eq!(values, vec![4.0, 0.0, 5.0, 0.0]); + + let nulls: Vec<_> = as_float64_array(batch.column(1))? + .nulls() + .unwrap() + .iter() + .collect(); + assert_eq!(nulls, vec![true, false, true, false]); + + let values: Vec<_> = as_string_array(batch.column(2))?.iter().flatten().collect(); + assert_eq!(values, vec!["alpha", "beta", "gamma", "delta"]); + + Ok(()) + } } diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 43026f3a9206..a0ad1e80be9b 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -43,14 +43,14 @@ macro_rules! handle_transform_recursion { /// There are three categories of TreeNode APIs: /// /// 1. "Inspecting" APIs to traverse a tree of `&TreeNodes`: -/// [`apply`], [`visit`], [`exists`]. +/// [`apply`], [`visit`], [`exists`]. /// /// 2. "Transforming" APIs that traverse and consume a tree of `TreeNode`s -/// producing possibly changed `TreeNode`s: [`transform`], [`transform_up`], -/// [`transform_down`], [`transform_down_up`], and [`rewrite`]. +/// producing possibly changed `TreeNode`s: [`transform`], [`transform_up`], +/// [`transform_down`], [`transform_down_up`], and [`rewrite`]. /// /// 3. Internal APIs used to implement the `TreeNode` API: [`apply_children`], -/// and [`map_children`]. +/// and [`map_children`]. /// /// | Traversal Order | Inspecting | Transforming | /// | --- | --- | --- | @@ -123,8 +123,8 @@ pub trait TreeNode: Sized { /// TreeNodeVisitor::f_up(ChildNode2) /// TreeNodeVisitor::f_up(ParentNode) /// ``` - fn visit>( - &self, + fn visit<'n, V: TreeNodeVisitor<'n, Node = Self>>( + &'n self, visitor: &mut V, ) -> Result { visitor @@ -190,12 +190,12 @@ pub trait TreeNode: Sized { /// # See Also /// * [`Self::transform_down`] for the equivalent transformation API. /// * [`Self::visit`] for both top-down and bottom up traversal. - fn apply Result>( - &self, + fn apply<'n, F: FnMut(&'n Self) -> Result>( + &'n self, mut f: F, ) -> Result { - fn apply_impl Result>( - node: &N, + fn apply_impl<'n, N: TreeNode, F: FnMut(&'n N) -> Result>( + node: &'n N, f: &mut F, ) -> Result { f(node)?.visit_children(|| node.apply_children(|c| apply_impl(c, f))) @@ -238,15 +238,6 @@ pub trait TreeNode: Sized { transform_down_impl(self, &mut f) } - /// Same as [`Self::transform_down`] but with a mutable closure. - #[deprecated(since = "38.0.0", note = "Use `transform_down` instead")] - fn transform_down_mut Result>>( - self, - f: &mut F, - ) -> Result> { - self.transform_down(f) - } - /// Recursively rewrite the node using `f` in a bottom-up (post-order) /// fashion. /// @@ -271,15 +262,6 @@ pub trait TreeNode: Sized { transform_up_impl(self, &mut f) } - /// Same as [`Self::transform_up`] but with a mutable closure. - #[deprecated(since = "38.0.0", note = "Use `transform_up` instead")] - fn transform_up_mut Result>>( - self, - f: &mut F, - ) -> Result> { - self.transform_up(f) - } - /// Transforms the node using `f_down` while traversing the tree top-down /// (pre-order), and using `f_up` while traversing the tree bottom-up /// (post-order). @@ -427,8 +409,8 @@ pub trait TreeNode: Sized { /// /// Description: Apply `f` to inspect node's children (but not the node /// itself). - fn apply_children Result>( - &self, + fn apply_children<'n, F: FnMut(&'n Self) -> Result>( + &'n self, f: F, ) -> Result; @@ -466,19 +448,19 @@ pub trait TreeNode: Sized { /// /// # See Also: /// * [`TreeNode::rewrite`] to rewrite owned `TreeNode`s -pub trait TreeNodeVisitor: Sized { +pub trait TreeNodeVisitor<'n>: Sized { /// The node type which is visitable. type Node: TreeNode; /// Invoked while traversing down the tree, before any children are visited. /// Default implementation continues the recursion. - fn f_down(&mut self, _node: &Self::Node) -> Result { + fn f_down(&mut self, _node: &'n Self::Node) -> Result { Ok(TreeNodeRecursion::Continue) } /// Invoked while traversing up the tree after children are visited. Default /// implementation continues the recursion. - fn f_up(&mut self, _node: &Self::Node) -> Result { + fn f_up(&mut self, _node: &'n Self::Node) -> Result { Ok(TreeNodeRecursion::Continue) } } @@ -486,6 +468,9 @@ pub trait TreeNodeVisitor: Sized { /// A [Visitor](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively /// rewriting [`TreeNode`]s via [`TreeNode::rewrite`]. /// +/// For example you can implement this trait on a struct to rewrite `Expr` or +/// `LogicalPlan` that needs to track state during the rewrite. +/// /// See [`TreeNode`] for more details on available APIs /// /// When passed to [`TreeNode::rewrite`], [`TreeNodeRewriter::f_down`] and @@ -582,7 +567,11 @@ impl TreeNodeRecursion { /// Result of tree walk / transformation APIs /// -/// API users control the transformation by returning: +/// `Transformed` is a wrapper around the tree node data (e.g. `Expr` or +/// `LogicalPlan`). It is used to indicate whether the node was transformed +/// and how the recursion should proceed. +/// +/// [`TreeNode`] API users control the transformation by returning: /// - The resulting (possibly transformed) node, /// - `transformed`: flag indicating whether any change was made to the node /// - `tnr`: [`TreeNodeRecursion`] specifying how to proceed with the recursion. @@ -592,7 +581,66 @@ impl TreeNodeRecursion { /// - `transformed`: flag indicating whether any change was made to the node /// - `tnr`: [`TreeNodeRecursion`] specifying how the recursion ended. /// -/// Example APIs: +/// See also +/// * [`Transformed::update_data`] to modify the node without changing the `transformed` flag +/// * [`Transformed::map_data`] for fallable operation that return the same type +/// * [`Transformed::transform_data`] to chain fallable transformations +/// * [`TransformedResult`] for working with `Result>` +/// +/// # Examples +/// +/// Use [`Transformed::yes`] and [`Transformed::no`] to signal that a node was +/// rewritten and the recursion should continue: +/// +/// ``` +/// # use datafusion_common::tree_node::Transformed; +/// # // note use i64 instead of Expr as Expr is not in datafusion-common +/// # fn orig_expr() -> i64 { 1 } +/// # fn make_new_expr(i: i64) -> i64 { 2 } +/// let expr = orig_expr(); +/// +/// // Create a new `Transformed` object signaling the node was not rewritten +/// let ret = Transformed::no(expr.clone()); +/// assert!(!ret.transformed); +/// +/// // Create a new `Transformed` object signaling the node was rewritten +/// let ret = Transformed::yes(expr); +/// assert!(ret.transformed) +/// ``` +/// +/// Access the node within the `Transformed` object: +/// ``` +/// # use datafusion_common::tree_node::Transformed; +/// # // note use i64 instead of Expr as Expr is not in datafusion-common +/// # fn orig_expr() -> i64 { 1 } +/// # fn make_new_expr(i: i64) -> i64 { 2 } +/// let expr = orig_expr(); +/// +/// // `Transformed` object signaling the node was not rewritten +/// let ret = Transformed::no(expr.clone()); +/// // Access the inner object using .data +/// assert_eq!(expr, ret.data); +/// ``` +/// +/// Transform the node within the `Transformed` object. +/// +/// ``` +/// # use datafusion_common::tree_node::Transformed; +/// # // note use i64 instead of Expr as Expr is not in datafusion-common +/// # fn orig_expr() -> i64 { 1 } +/// # fn make_new_expr(i: i64) -> i64 { 2 } +/// let expr = orig_expr(); +/// let ret = Transformed::no(expr.clone()) +/// .transform_data(|expr| { +/// // closure returns a result and potentially transforms the node +/// // in this example, it does transform the node +/// let new_expr = make_new_expr(expr); +/// Ok(Transformed::yes(new_expr)) +/// }).unwrap(); +/// // transformed flag is the union of the original ans closure's transformed flag +/// assert!(ret.transformed); +/// ``` +/// # Example APIs that use `TreeNode` /// - [`TreeNode`], /// - [`TreeNode::rewrite`], /// - [`TreeNode::transform_down`], @@ -615,6 +663,11 @@ impl Transformed { } } + /// Create a `Transformed` with `transformed` and [`TreeNodeRecursion::Continue`]. + pub fn new_transformed(data: T, transformed: bool) -> Self { + Self::new(data, transformed, TreeNodeRecursion::Continue) + } + /// Wrapper for transformed data with [`TreeNodeRecursion::Continue`] statement. pub fn yes(data: T) -> Self { Self::new(data, true, TreeNodeRecursion::Continue) @@ -625,17 +678,23 @@ impl Transformed { Self::new(data, false, TreeNodeRecursion::Continue) } - /// Applies the given `f` to the data of this [`Transformed`] object. + /// Applies an infallible `f` to the data of this [`Transformed`] object, + /// without modifying the `transformed` flag. pub fn update_data U>(self, f: F) -> Transformed { Transformed::new(f(self.data), self.transformed, self.tnr) } - /// Maps the data of [`Transformed`] object to the result of the given `f`. + /// Applies a fallible `f` (returns `Result`) to the data of this + /// [`Transformed`] object, without modifying the `transformed` flag. pub fn map_data Result>(self, f: F) -> Result> { f(self.data).map(|data| Transformed::new(data, self.transformed, self.tnr)) } - /// Maps the [`Transformed`] object to the result of the given `f`. + /// Applies a fallible transforming `f` to the data of this [`Transformed`] + /// object. + /// + /// The returned `Transformed` object has the `transformed` flag set if either + /// `self` or the return value of `f` have the `transformed` flag set. pub fn transform_data Result>>( self, f: F, @@ -822,6 +881,22 @@ macro_rules! map_until_stop_and_collect { } /// Transformation helper to access [`Transformed`] fields in a [`Result`] easily. +/// +/// # Example +/// Access the internal data of a `Result>` +/// as a `Result` using the `data` method: +/// ``` +/// # use datafusion_common::Result; +/// # use datafusion_common::tree_node::{Transformed, TransformedResult}; +/// # // note use i64 instead of Expr as Expr is not in datafusion-common +/// # fn update_expr() -> i64 { 1 } +/// # fn main() -> Result<()> { +/// let transformed: Result> = Ok(Transformed::yes(update_expr())); +/// // access the internal data of the transformed result, or return the error +/// let transformed_expr = transformed.data()?; +/// # Ok(()) +/// # } +/// ``` pub trait TransformedResult { fn data(self) -> Result; @@ -849,7 +924,7 @@ impl TransformedResult for Result> { /// its related `Arc` will automatically implement [`TreeNode`]. pub trait DynTreeNode { /// Returns all children of the specified `TreeNode`. - fn arc_children(&self) -> Vec>; + fn arc_children(&self) -> Vec<&Arc>; /// Constructs a new node with the specified children. fn with_new_arc_children( @@ -862,11 +937,11 @@ pub trait DynTreeNode { /// Blanket implementation for any `Arc` where `T` implements [`DynTreeNode`] /// (such as [`Arc`]). impl TreeNode for Arc { - fn apply_children Result>( - &self, + fn apply_children<'n, F: FnMut(&'n Self) -> Result>( + &'n self, f: F, ) -> Result { - self.arc_children().iter().apply_until_stop(f) + self.arc_children().into_iter().apply_until_stop(f) } fn map_children Result>>( @@ -875,7 +950,10 @@ impl TreeNode for Arc { ) -> Result> { let children = self.arc_children(); if !children.is_empty() { - let new_children = children.into_iter().map_until_stop_and_collect(f)?; + let new_children = children + .into_iter() + .cloned() + .map_until_stop_and_collect(f)?; // Propagate up `new_children.transformed` and `new_children.tnr` // along with the node containing transformed children. if new_children.transformed { @@ -897,7 +975,7 @@ impl TreeNode for Arc { /// involving payloads, by enforcing rules for detaching and reattaching child nodes. pub trait ConcreteTreeNode: Sized { /// Provides read-only access to child nodes. - fn children(&self) -> Vec<&Self>; + fn children(&self) -> &[Self]; /// Detaches the node from its children, returning the node itself and its detached children. fn take_children(self) -> (Self, Vec); @@ -907,11 +985,11 @@ pub trait ConcreteTreeNode: Sized { } impl TreeNode for T { - fn apply_children Result>( - &self, + fn apply_children<'n, F: FnMut(&'n Self) -> Result>( + &'n self, f: F, ) -> Result { - self.children().into_iter().apply_until_stop(f) + self.children().iter().apply_until_stop(f) } fn map_children Result>>( @@ -931,7 +1009,8 @@ impl TreeNode for T { } #[cfg(test)] -mod tests { +pub(crate) mod tests { + use std::collections::HashMap; use std::fmt::Display; use crate::tree_node::{ @@ -940,21 +1019,32 @@ mod tests { }; use crate::Result; - #[derive(PartialEq, Debug)] - struct TestTreeNode { - children: Vec>, - data: T, + #[derive(Debug, Eq, Hash, PartialEq, Clone)] + pub struct TestTreeNode { + pub(crate) children: Vec>, + pub(crate) data: T, } impl TestTreeNode { - fn new(children: Vec>, data: T) -> Self { + pub(crate) fn new(children: Vec>, data: T) -> Self { Self { children, data } } + + pub(crate) fn new_leaf(data: T) -> Self { + Self { + children: vec![], + data, + } + } + + pub(crate) fn is_leaf(&self) -> bool { + self.children.is_empty() + } } impl TreeNode for TestTreeNode { - fn apply_children Result>( - &self, + fn apply_children<'n, F: FnMut(&'n Self) -> Result>( + &'n self, f: F, ) -> Result { self.children.iter().apply_until_stop(f) @@ -989,12 +1079,12 @@ mod tests { // | // A fn test_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "a".to_string()); - let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_a = TestTreeNode::new_leaf("a".to_string()); + let node_b = TestTreeNode::new_leaf("b".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string()); let node_i = TestTreeNode::new(vec![node_f], "i".to_string()); @@ -1033,13 +1123,13 @@ mod tests { // Expected transformed tree after a combined traversal fn transformed_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_up(f_down(d))".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string()); - let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string()); + let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string()); @@ -1049,12 +1139,12 @@ mod tests { // Expected transformed tree after a top-down traversal fn transformed_down_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string()); + let node_a = TestTreeNode::new_leaf("f_down(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_down(b)".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "f_down(h)".to_string()); + let node_h = TestTreeNode::new_leaf("f_down(h)".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1063,12 +1153,12 @@ mod tests { // Expected transformed tree after a bottom-up traversal fn transformed_up_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(b)".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "f_up(h)".to_string()); + let node_h = TestTreeNode::new_leaf("f_up(h)".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_up(g)".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_up(i)".to_string()); @@ -1105,12 +1195,12 @@ mod tests { } fn f_down_jump_on_a_transformed_down_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string()); + let node_a = TestTreeNode::new_leaf("f_down(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_down(b)".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "f_down(h)".to_string()); + let node_h = TestTreeNode::new_leaf("f_down(h)".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1139,12 +1229,12 @@ mod tests { } fn f_down_jump_on_e_transformed_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "a".to_string()); - let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_a = TestTreeNode::new_leaf("a".to_string()); + let node_b = TestTreeNode::new_leaf("b".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string()); - let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string()); + let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string()); @@ -1153,12 +1243,12 @@ mod tests { } fn f_down_jump_on_e_transformed_down_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "a".to_string()); - let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_a = TestTreeNode::new_leaf("a".to_string()); + let node_b = TestTreeNode::new_leaf("b".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "f_down(h)".to_string()); + let node_h = TestTreeNode::new_leaf("f_down(h)".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1192,12 +1282,12 @@ mod tests { } fn f_up_jump_on_a_transformed_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string()); + let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string()); @@ -1206,12 +1296,12 @@ mod tests { } fn f_up_jump_on_a_transformed_up_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(b)".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); - let node_h = TestTreeNode::new(vec![], "f_up(h)".to_string()); + let node_h = TestTreeNode::new_leaf("f_up(h)".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_up(g)".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_up(i)".to_string()); @@ -1275,12 +1365,12 @@ mod tests { } fn f_down_stop_on_a_transformed_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_a = TestTreeNode::new_leaf("f_down(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1288,12 +1378,12 @@ mod tests { } fn f_down_stop_on_a_transformed_down_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string()); + let node_a = TestTreeNode::new_leaf("f_down(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_down(b)".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1309,12 +1399,12 @@ mod tests { } fn f_down_stop_on_e_transformed_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "a".to_string()); - let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_a = TestTreeNode::new_leaf("a".to_string()); + let node_b = TestTreeNode::new_leaf("b".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1322,12 +1412,12 @@ mod tests { } fn f_down_stop_on_e_transformed_down_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "a".to_string()); - let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_a = TestTreeNode::new_leaf("a".to_string()); + let node_b = TestTreeNode::new_leaf("b".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1354,12 +1444,12 @@ mod tests { } fn f_up_stop_on_a_transformed_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1367,12 +1457,12 @@ mod tests { } fn f_up_stop_on_a_transformed_up_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(b)".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string()); let node_i = TestTreeNode::new(vec![node_f], "i".to_string()); @@ -1402,13 +1492,13 @@ mod tests { } fn f_up_stop_on_e_transformed_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_up(f_down(d))".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); @@ -1416,12 +1506,12 @@ mod tests { } fn f_up_stop_on_e_transformed_up_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); + let node_a = TestTreeNode::new_leaf("f_up(a)".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(b)".to_string()); let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string()); let node_i = TestTreeNode::new(vec![node_f], "i".to_string()); @@ -1453,15 +1543,15 @@ mod tests { } } - impl TreeNodeVisitor for TestVisitor { + impl<'n, T: Display> TreeNodeVisitor<'n> for TestVisitor { type Node = TestTreeNode; - fn f_down(&mut self, node: &Self::Node) -> Result { + fn f_down(&mut self, node: &'n Self::Node) -> Result { self.visits.push(format!("f_down({})", node.data)); (*self.f_down)(node) } - fn f_up(&mut self, node: &Self::Node) -> Result { + fn f_up(&mut self, node: &'n Self::Node) -> Result { self.visits.push(format!("f_up({})", node.data)); (*self.f_up)(node) } @@ -1906,4 +1996,87 @@ mod tests { TreeNodeRecursion::Stop ) ); + + // F + // / | \ + // / | \ + // E C A + // | / \ + // C B D + // / \ | + // B D A + // | + // A + #[test] + fn test_apply_and_visit_references() -> Result<()> { + let node_a = TestTreeNode::new_leaf("a".to_string()); + let node_b = TestTreeNode::new_leaf("b".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); + let node_a_2 = TestTreeNode::new_leaf("a".to_string()); + let node_b_2 = TestTreeNode::new_leaf("b".to_string()); + let node_d_2 = TestTreeNode::new(vec![node_a_2], "d".to_string()); + let node_c_2 = TestTreeNode::new(vec![node_b_2, node_d_2], "c".to_string()); + let node_a_3 = TestTreeNode::new_leaf("a".to_string()); + let tree = TestTreeNode::new(vec![node_e, node_c_2, node_a_3], "f".to_string()); + + let node_f_ref = &tree; + let node_e_ref = &node_f_ref.children[0]; + let node_c_ref = &node_e_ref.children[0]; + let node_b_ref = &node_c_ref.children[0]; + let node_d_ref = &node_c_ref.children[1]; + let node_a_ref = &node_d_ref.children[0]; + + let mut m: HashMap<&TestTreeNode, usize> = HashMap::new(); + tree.apply(|e| { + *m.entry(e).or_insert(0) += 1; + Ok(TreeNodeRecursion::Continue) + })?; + + let expected = HashMap::from([ + (node_f_ref, 1), + (node_e_ref, 1), + (node_c_ref, 2), + (node_d_ref, 2), + (node_b_ref, 2), + (node_a_ref, 3), + ]); + assert_eq!(m, expected); + + struct TestVisitor<'n> { + m: HashMap<&'n TestTreeNode, (usize, usize)>, + } + + impl<'n> TreeNodeVisitor<'n> for TestVisitor<'n> { + type Node = TestTreeNode; + + fn f_down(&mut self, node: &'n Self::Node) -> Result { + let (down_count, _) = self.m.entry(node).or_insert((0, 0)); + *down_count += 1; + Ok(TreeNodeRecursion::Continue) + } + + fn f_up(&mut self, node: &'n Self::Node) -> Result { + let (_, up_count) = self.m.entry(node).or_insert((0, 0)); + *up_count += 1; + Ok(TreeNodeRecursion::Continue) + } + } + + let mut visitor = TestVisitor { m: HashMap::new() }; + tree.visit(&mut visitor)?; + + let expected = HashMap::from([ + (node_f_ref, (1, 1)), + (node_e_ref, (1, 1)), + (node_c_ref, (2, 2)), + (node_d_ref, (2, 2)), + (node_b_ref, (2, 2)), + (node_a_ref, (3, 3)), + ]); + assert_eq!(visitor.m, expected); + + Ok(()) + } } diff --git a/datafusion/common/src/types/builtin.rs b/datafusion/common/src/types/builtin.rs new file mode 100644 index 000000000000..c6105d37c3bd --- /dev/null +++ b/datafusion/common/src/types/builtin.rs @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::types::{LogicalTypeRef, NativeType}; +use std::sync::{Arc, OnceLock}; + +macro_rules! singleton { + ($name:ident, $getter:ident, $ty:ident) => { + // TODO: Use LazyLock instead of getter function when MSRV gets bumped + static $name: OnceLock = OnceLock::new(); + + #[doc = "Getter for singleton instance of a logical type representing"] + #[doc = concat!("[`NativeType::", stringify!($ty), "`].")] + pub fn $getter() -> LogicalTypeRef { + Arc::clone($name.get_or_init(|| Arc::new(NativeType::$ty))) + } + }; +} + +singleton!(LOGICAL_NULL, logical_null, Null); +singleton!(LOGICAL_BOOLEAN, logical_boolean, Boolean); +singleton!(LOGICAL_INT8, logical_int8, Int8); +singleton!(LOGICAL_INT16, logical_int16, Int16); +singleton!(LOGICAL_INT32, logical_int32, Int32); +singleton!(LOGICAL_INT64, logical_int64, Int64); +singleton!(LOGICAL_UINT8, logical_uint8, UInt8); +singleton!(LOGICAL_UINT16, logical_uint16, UInt16); +singleton!(LOGICAL_UINT32, logical_uint32, UInt32); +singleton!(LOGICAL_UINT64, logical_uint64, UInt64); +singleton!(LOGICAL_FLOAT16, logical_float16, Float16); +singleton!(LOGICAL_FLOAT32, logical_float32, Float32); +singleton!(LOGICAL_FLOAT64, logical_float64, Float64); +singleton!(LOGICAL_DATE, logical_date, Date); +singleton!(LOGICAL_BINARY, logical_binary, Binary); +singleton!(LOGICAL_STRING, logical_string, String); diff --git a/datafusion/common/src/types/field.rs b/datafusion/common/src/types/field.rs new file mode 100644 index 000000000000..85c7c157272a --- /dev/null +++ b/datafusion/common/src/types/field.rs @@ -0,0 +1,114 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_schema::{Field, Fields, UnionFields}; +use std::hash::{Hash, Hasher}; +use std::{ops::Deref, sync::Arc}; + +use super::{LogicalTypeRef, NativeType}; + +/// A record of a logical type, its name and its nullability. +#[derive(Debug, Clone, Eq, PartialOrd, Ord)] +pub struct LogicalField { + pub name: String, + pub logical_type: LogicalTypeRef, + pub nullable: bool, +} + +impl PartialEq for LogicalField { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + && self.logical_type.eq(&other.logical_type) + && self.nullable == other.nullable + } +} + +impl Hash for LogicalField { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.logical_type.hash(state); + self.nullable.hash(state); + } +} + +impl From<&Field> for LogicalField { + fn from(value: &Field) -> Self { + Self { + name: value.name().clone(), + logical_type: Arc::new(NativeType::from(value.data_type().clone())), + nullable: value.is_nullable(), + } + } +} + +/// A reference counted [`LogicalField`]. +pub type LogicalFieldRef = Arc; + +/// A cheaply cloneable, owned collection of [`LogicalFieldRef`]. +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct LogicalFields(Arc<[LogicalFieldRef]>); + +impl Deref for LogicalFields { + type Target = [LogicalFieldRef]; + + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } +} + +impl From<&Fields> for LogicalFields { + fn from(value: &Fields) -> Self { + value + .iter() + .map(|field| Arc::new(LogicalField::from(field.as_ref()))) + .collect() + } +} + +impl FromIterator for LogicalFields { + fn from_iter>(iter: T) -> Self { + Self(iter.into_iter().collect()) + } +} + +/// A cheaply cloneable, owned collection of [`LogicalFieldRef`] and their +/// corresponding type ids. +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct LogicalUnionFields(Arc<[(i8, LogicalFieldRef)]>); + +impl Deref for LogicalUnionFields { + type Target = [(i8, LogicalFieldRef)]; + + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } +} + +impl From<&UnionFields> for LogicalUnionFields { + fn from(value: &UnionFields) -> Self { + value + .iter() + .map(|(i, field)| (i, Arc::new(LogicalField::from(field.as_ref())))) + .collect() + } +} + +impl FromIterator<(i8, LogicalFieldRef)> for LogicalUnionFields { + fn from_iter>(iter: T) -> Self { + Self(iter.into_iter().collect()) + } +} diff --git a/datafusion/common/src/types/logical.rs b/datafusion/common/src/types/logical.rs new file mode 100644 index 000000000000..bde393992a0c --- /dev/null +++ b/datafusion/common/src/types/logical.rs @@ -0,0 +1,128 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::NativeType; +use crate::error::Result; +use arrow_schema::DataType; +use core::fmt; +use std::{cmp::Ordering, hash::Hash, sync::Arc}; + +/// Signature that uniquely identifies a type among other types. +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum TypeSignature<'a> { + /// Represents a built-in native type. + Native(&'a NativeType), + /// Represents an arrow-compatible extension type. + /// () + /// + /// The `name` should contain the same value as 'ARROW:extension:name'. + Extension { + name: &'a str, + parameters: &'a [TypeParameter<'a>], + }, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum TypeParameter<'a> { + Type(TypeSignature<'a>), + Number(i128), +} + +/// A reference counted [`LogicalType`]. +pub type LogicalTypeRef = Arc; + +/// Representation of a logical type with its signature and its native backing +/// type. +/// +/// The logical type is meant to be used during the DataFusion logical planning +/// phase in order to reason about logical types without worrying about their +/// underlying physical implementation. +/// +/// ### Extension types +/// +/// [`LogicalType`] is a trait in order to allow the possibility of declaring +/// extension types: +/// +/// ``` +/// use datafusion_common::types::{LogicalType, NativeType, TypeSignature}; +/// +/// struct JSON {} +/// +/// impl LogicalType for JSON { +/// fn native(&self) -> &NativeType { +/// &NativeType::String +/// } +/// +/// fn signature(&self) -> TypeSignature<'_> { +/// TypeSignature::Extension { +/// name: "JSON", +/// parameters: &[], +/// } +/// } +/// } +/// ``` +pub trait LogicalType: Sync + Send { + /// Get the native backing type of this logical type. + fn native(&self) -> &NativeType; + /// Get the unique type signature for this logical type. Logical types with identical + /// signatures are considered equal. + fn signature(&self) -> TypeSignature<'_>; + + /// Get the default physical type to cast `origin` to in order to obtain a physical type + /// that is logically compatible with this logical type. + fn default_cast_for(&self, origin: &DataType) -> Result { + self.native().default_cast_for(origin) + } +} + +impl fmt::Debug for dyn LogicalType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("LogicalType") + .field(&self.signature()) + .field(&self.native()) + .finish() + } +} + +impl PartialEq for dyn LogicalType { + fn eq(&self, other: &Self) -> bool { + self.signature().eq(&other.signature()) + } +} + +impl Eq for dyn LogicalType {} + +impl PartialOrd for dyn LogicalType { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for dyn LogicalType { + fn cmp(&self, other: &Self) -> Ordering { + self.signature() + .cmp(&other.signature()) + .then(self.native().cmp(other.native())) + } +} + +impl Hash for dyn LogicalType { + fn hash(&self, state: &mut H) { + self.signature().hash(state); + self.native().hash(state); + } +} diff --git a/datafusion/core/src/variable/mod.rs b/datafusion/common/src/types/mod.rs similarity index 85% rename from datafusion/core/src/variable/mod.rs rename to datafusion/common/src/types/mod.rs index 475f7570a8ee..2f9ce4ce0282 100644 --- a/datafusion/core/src/variable/mod.rs +++ b/datafusion/common/src/types/mod.rs @@ -15,6 +15,12 @@ // specific language governing permissions and limitations // under the License. -//! Variable provider for `@name` and `@@name` style runtime values. +mod builtin; +mod field; +mod logical; +mod native; -pub use datafusion_expr::var_provider::{VarProvider, VarType}; +pub use builtin::*; +pub use field::*; +pub use logical::*; +pub use native::*; diff --git a/datafusion/common/src/types/native.rs b/datafusion/common/src/types/native.rs new file mode 100644 index 000000000000..bfb546783ea2 --- /dev/null +++ b/datafusion/common/src/types/native.rs @@ -0,0 +1,399 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::{ + LogicalField, LogicalFieldRef, LogicalFields, LogicalType, LogicalUnionFields, + TypeSignature, +}; +use crate::error::{Result, _internal_err}; +use arrow::compute::can_cast_types; +use arrow_schema::{ + DataType, Field, FieldRef, Fields, IntervalUnit, TimeUnit, UnionFields, +}; +use std::sync::Arc; + +/// Representation of a type that DataFusion can handle natively. It is a subset +/// of the physical variants in Arrow's native [`DataType`]. +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum NativeType { + /// Null type + Null, + /// A boolean type representing the values `true` and `false`. + Boolean, + /// A signed 8-bit integer. + Int8, + /// A signed 16-bit integer. + Int16, + /// A signed 32-bit integer. + Int32, + /// A signed 64-bit integer. + Int64, + /// An unsigned 8-bit integer. + UInt8, + /// An unsigned 16-bit integer. + UInt16, + /// An unsigned 32-bit integer. + UInt32, + /// An unsigned 64-bit integer. + UInt64, + /// A 16-bit floating point number. + Float16, + /// A 32-bit floating point number. + Float32, + /// A 64-bit floating point number. + Float64, + /// A timestamp with an optional timezone. + /// + /// Time is measured as a Unix epoch, counting the seconds from + /// 00:00:00.000 on 1 January 1970, excluding leap seconds, + /// as a signed 64-bit integer. + /// + /// The time zone is a string indicating the name of a time zone, one of: + /// + /// * As used in the Olson time zone database (the "tz database" or + /// "tzdata"), such as "America/New_York" + /// * An absolute time zone offset of the form +XX:XX or -XX:XX, such as +07:30 + /// + /// Timestamps with a non-empty timezone + /// ------------------------------------ + /// + /// If a Timestamp column has a non-empty timezone value, its epoch is + /// 1970-01-01 00:00:00 (January 1st 1970, midnight) in the *UTC* timezone + /// (the Unix epoch), regardless of the Timestamp's own timezone. + /// + /// Therefore, timestamp values with a non-empty timezone correspond to + /// physical points in time together with some additional information about + /// how the data was obtained and/or how to display it (the timezone). + /// + /// For example, the timestamp value 0 with the timezone string "Europe/Paris" + /// corresponds to "January 1st 1970, 00h00" in the UTC timezone, but the + /// application may prefer to display it as "January 1st 1970, 01h00" in + /// the Europe/Paris timezone (which is the same physical point in time). + /// + /// One consequence is that timestamp values with a non-empty timezone + /// can be compared and ordered directly, since they all share the same + /// well-known point of reference (the Unix epoch). + /// + /// Timestamps with an unset / empty timezone + /// ----------------------------------------- + /// + /// If a Timestamp column has no timezone value, its epoch is + /// 1970-01-01 00:00:00 (January 1st 1970, midnight) in an *unknown* timezone. + /// + /// Therefore, timestamp values without a timezone cannot be meaningfully + /// interpreted as physical points in time, but only as calendar / clock + /// indications ("wall clock time") in an unspecified timezone. + /// + /// For example, the timestamp value 0 with an empty timezone string + /// corresponds to "January 1st 1970, 00h00" in an unknown timezone: there + /// is not enough information to interpret it as a well-defined physical + /// point in time. + /// + /// One consequence is that timestamp values without a timezone cannot + /// be reliably compared or ordered, since they may have different points of + /// reference. In particular, it is *not* possible to interpret an unset + /// or empty timezone as the same as "UTC". + /// + /// Conversion between timezones + /// ---------------------------- + /// + /// If a Timestamp column has a non-empty timezone, changing the timezone + /// to a different non-empty value is a metadata-only operation: + /// the timestamp values need not change as their point of reference remains + /// the same (the Unix epoch). + /// + /// However, if a Timestamp column has no timezone value, changing it to a + /// non-empty value requires to think about the desired semantics. + /// One possibility is to assume that the original timestamp values are + /// relative to the epoch of the timezone being set; timestamp values should + /// then adjusted to the Unix epoch (for example, changing the timezone from + /// empty to "Europe/Paris" would require converting the timestamp values + /// from "Europe/Paris" to "UTC", which seems counter-intuitive but is + /// nevertheless correct). + /// + /// ``` + /// # use arrow_schema::{DataType, TimeUnit}; + /// DataType::Timestamp(TimeUnit::Second, None); + /// DataType::Timestamp(TimeUnit::Second, Some("literal".into())); + /// DataType::Timestamp(TimeUnit::Second, Some("string".to_string().into())); + /// ``` + Timestamp(TimeUnit, Option>), + /// A signed date representing the elapsed time since UNIX epoch (1970-01-01) + /// in days. + Date, + /// A signed time representing the elapsed time since midnight in the unit of `TimeUnit`. + Time(TimeUnit), + /// Measure of elapsed time in either seconds, milliseconds, microseconds or nanoseconds. + Duration(TimeUnit), + /// A "calendar" interval which models types that don't necessarily + /// have a precise duration without the context of a base timestamp (e.g. + /// days can differ in length during day light savings time transitions). + Interval(IntervalUnit), + /// Opaque binary data of variable length. + Binary, + /// Opaque binary data of fixed size. + /// Enum parameter specifies the number of bytes per value. + FixedSizeBinary(i32), + /// A variable-length string in Unicode with UTF-8 encoding. + String, + /// A list of some logical data type with variable length. + List(LogicalFieldRef), + /// A list of some logical data type with fixed length. + FixedSizeList(LogicalFieldRef, i32), + /// A nested type that contains a number of sub-fields. + Struct(LogicalFields), + /// A nested type that can represent slots of differing types. + Union(LogicalUnionFields), + /// Decimal value with precision and scale + /// + /// * precision is the total number of digits + /// * scale is the number of digits past the decimal + /// + /// For example the number 123.45 has precision 5 and scale 2. + /// + /// In certain situations, scale could be negative number. For + /// negative scale, it is the number of padding 0 to the right + /// of the digits. + /// + /// For example the number 12300 could be treated as a decimal + /// has precision 3 and scale -2. + Decimal(u8, i8), + /// A Map is a type that an association between a key and a value. + /// + /// The key and value types are not constrained, but keys should be + /// hashable and unique. + /// + /// In a field with Map type, key type and the second the value type. The names of the + /// child fields may be respectively "entries", "key", and "value", but this is + /// not enforced. + Map(LogicalFieldRef), +} + +impl LogicalType for NativeType { + fn native(&self) -> &NativeType { + self + } + + fn signature(&self) -> TypeSignature<'_> { + TypeSignature::Native(self) + } + + fn default_cast_for(&self, origin: &DataType) -> Result { + use DataType::*; + + fn default_field_cast(to: &LogicalField, from: &Field) -> Result { + Ok(Arc::new(Field::new( + to.name.clone(), + to.logical_type.default_cast_for(from.data_type())?, + to.nullable, + ))) + } + + Ok(match (self, origin) { + (Self::Null, _) => Null, + (Self::Boolean, _) => Boolean, + (Self::Int8, _) => Int8, + (Self::Int16, _) => Int16, + (Self::Int32, _) => Int32, + (Self::Int64, _) => Int64, + (Self::UInt8, _) => UInt8, + (Self::UInt16, _) => UInt16, + (Self::UInt32, _) => UInt32, + (Self::UInt64, _) => UInt64, + (Self::Float16, _) => Float16, + (Self::Float32, _) => Float32, + (Self::Float64, _) => Float64, + (Self::Decimal(p, s), _) if p <= &38 => Decimal128(*p, *s), + (Self::Decimal(p, s), _) => Decimal256(*p, *s), + (Self::Timestamp(tu, tz), _) => Timestamp(*tu, tz.clone()), + (Self::Date, _) => Date32, + (Self::Time(tu), _) => match tu { + TimeUnit::Second | TimeUnit::Millisecond => Time32(*tu), + TimeUnit::Microsecond | TimeUnit::Nanosecond => Time64(*tu), + }, + (Self::Duration(tu), _) => Duration(*tu), + (Self::Interval(iu), _) => Interval(*iu), + (Self::Binary, LargeUtf8) => LargeBinary, + (Self::Binary, Utf8View) => BinaryView, + (Self::Binary, data_type) if can_cast_types(data_type, &BinaryView) => { + BinaryView + } + (Self::Binary, data_type) if can_cast_types(data_type, &LargeBinary) => { + LargeBinary + } + (Self::Binary, data_type) if can_cast_types(data_type, &Binary) => Binary, + (Self::FixedSizeBinary(size), _) => FixedSizeBinary(*size), + (Self::String, LargeBinary) => LargeUtf8, + (Self::String, BinaryView) => Utf8View, + (Self::String, data_type) if can_cast_types(data_type, &Utf8View) => Utf8View, + (Self::String, data_type) if can_cast_types(data_type, &LargeUtf8) => { + LargeUtf8 + } + (Self::String, data_type) if can_cast_types(data_type, &Utf8) => Utf8, + (Self::List(to_field), List(from_field) | FixedSizeList(from_field, _)) => { + List(default_field_cast(to_field, from_field)?) + } + (Self::List(to_field), LargeList(from_field)) => { + LargeList(default_field_cast(to_field, from_field)?) + } + (Self::List(to_field), ListView(from_field)) => { + ListView(default_field_cast(to_field, from_field)?) + } + (Self::List(to_field), LargeListView(from_field)) => { + LargeListView(default_field_cast(to_field, from_field)?) + } + // List array where each element is a len 1 list of the origin type + (Self::List(field), _) => List(Arc::new(Field::new( + field.name.clone(), + field.logical_type.default_cast_for(origin)?, + field.nullable, + ))), + ( + Self::FixedSizeList(to_field, to_size), + FixedSizeList(from_field, from_size), + ) if from_size == to_size => { + FixedSizeList(default_field_cast(to_field, from_field)?, *to_size) + } + ( + Self::FixedSizeList(to_field, size), + List(from_field) + | LargeList(from_field) + | ListView(from_field) + | LargeListView(from_field), + ) => FixedSizeList(default_field_cast(to_field, from_field)?, *size), + // FixedSizeList array where each element is a len 1 list of the origin type + (Self::FixedSizeList(field, size), _) => FixedSizeList( + Arc::new(Field::new( + field.name.clone(), + field.logical_type.default_cast_for(origin)?, + field.nullable, + )), + *size, + ), + // From https://github.com/apache/arrow-rs/blob/56525efbd5f37b89d1b56aa51709cab9f81bc89e/arrow-cast/src/cast/mod.rs#L189-L196 + (Self::Struct(to_fields), Struct(from_fields)) + if from_fields.len() == to_fields.len() => + { + Struct( + from_fields + .iter() + .zip(to_fields.iter()) + .map(|(from, to)| default_field_cast(to, from)) + .collect::>()?, + ) + } + (Self::Struct(to_fields), Null) => Struct( + to_fields + .iter() + .map(|field| { + Ok(Arc::new(Field::new( + field.name.clone(), + field.logical_type.default_cast_for(&Null)?, + field.nullable, + ))) + }) + .collect::>()?, + ), + (Self::Map(to_field), Map(from_field, sorted)) => { + Map(default_field_cast(to_field, from_field)?, *sorted) + } + (Self::Map(field), Null) => Map( + Arc::new(Field::new( + field.name.clone(), + field.logical_type.default_cast_for(&Null)?, + field.nullable, + )), + false, + ), + (Self::Union(to_fields), Union(from_fields, mode)) + if from_fields.len() == to_fields.len() => + { + Union( + from_fields + .iter() + .zip(to_fields.iter()) + .map(|((_, from), (i, to))| { + Ok((*i, default_field_cast(to, from)?)) + }) + .collect::>()?, + *mode, + ) + } + _ => { + return _internal_err!( + "Unavailable default cast for native type {:?} from physical type {:?}", + self, + origin + ) + } + }) + } +} + +// The following From, From, ... implementations are temporary +// mapping solutions to provide backwards compatibility while transitioning from +// the purely physical system to a logical / physical system. + +impl From for NativeType { + fn from(value: DataType) -> Self { + use NativeType::*; + match value { + DataType::Null => Null, + DataType::Boolean => Boolean, + DataType::Int8 => Int8, + DataType::Int16 => Int16, + DataType::Int32 => Int32, + DataType::Int64 => Int64, + DataType::UInt8 => UInt8, + DataType::UInt16 => UInt16, + DataType::UInt32 => UInt32, + DataType::UInt64 => UInt64, + DataType::Float16 => Float16, + DataType::Float32 => Float32, + DataType::Float64 => Float64, + DataType::Timestamp(tu, tz) => Timestamp(tu, tz), + DataType::Date32 | DataType::Date64 => Date, + DataType::Time32(tu) | DataType::Time64(tu) => Time(tu), + DataType::Duration(tu) => Duration(tu), + DataType::Interval(iu) => Interval(iu), + DataType::Binary | DataType::LargeBinary | DataType::BinaryView => Binary, + DataType::FixedSizeBinary(size) => FixedSizeBinary(size), + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => String, + DataType::List(field) + | DataType::ListView(field) + | DataType::LargeList(field) + | DataType::LargeListView(field) => List(Arc::new(field.as_ref().into())), + DataType::FixedSizeList(field, size) => { + FixedSizeList(Arc::new(field.as_ref().into()), size) + } + DataType::Struct(fields) => Struct(LogicalFields::from(&fields)), + DataType::Union(union_fields, _) => { + Union(LogicalUnionFields::from(&union_fields)) + } + DataType::Decimal128(p, s) | DataType::Decimal256(p, s) => Decimal(p, s), + DataType::Map(field, _) => Map(Arc::new(field.as_ref().into())), + DataType::Dictionary(_, data_type) => data_type.as_ref().clone().into(), + DataType::RunEndEncoded(_, field) => field.data_type().clone().into(), + } + } +} + +impl From<&DataType> for NativeType { + fn from(value: &DataType) -> Self { + value.clone().into() + } +} diff --git a/datafusion/common/src/unnest.rs b/datafusion/common/src/unnest.rs index fd92267f9b4c..db48edd06160 100644 --- a/datafusion/common/src/unnest.rs +++ b/datafusion/common/src/unnest.rs @@ -17,6 +17,8 @@ //! [`UnnestOptions`] for unnesting structured types +use crate::Column; + /// Options for unnesting a column that contains a list type, /// replicating values in the other, non nested rows. /// @@ -60,10 +62,27 @@ /// └─────────┘ └─────┘ └─────────┘ └─────┘ /// c1 c2 c1 c2 /// ``` +/// +/// `recursions` instruct how a column should be unnested (e.g unnesting a column multiple +/// time, with depth = 1 and depth = 2). Any unnested column not being mentioned inside this +/// options is inferred to be unnested with depth = 1 #[derive(Debug, Clone, PartialEq, PartialOrd, Hash, Eq)] pub struct UnnestOptions { /// Should nulls in the input be preserved? Defaults to true pub preserve_nulls: bool, + /// If specific columns need to be unnested multiple times (e.g at different depth), + /// declare them here. Any unnested columns not being mentioned inside this option + /// will be unnested with depth = 1 + pub recursions: Vec, +} + +/// Instruction on how to unnest a column (mostly with a list type) +/// such as how to name the output, and how many level it should be unnested +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd)] +pub struct RecursionUnnestOption { + pub input_column: Column, + pub output_column: Column, + pub depth: usize, } impl Default for UnnestOptions { @@ -71,6 +90,7 @@ impl Default for UnnestOptions { Self { // default to true to maintain backwards compatible behavior preserve_nulls: true, + recursions: vec![], } } } @@ -87,4 +107,10 @@ impl UnnestOptions { self.preserve_nulls = preserve_nulls; self } + + /// Set the recursions for the unnest operation + pub fn with_recursions(mut self, recursion: RecursionUnnestOption) -> Self { + self.recursions.push(recursion); + self + } } diff --git a/datafusion/common/src/utils/expr.rs b/datafusion/common/src/utils/expr.rs new file mode 100644 index 000000000000..0fe4546b8538 --- /dev/null +++ b/datafusion/common/src/utils/expr.rs @@ -0,0 +1,24 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Expression utilities + +use crate::ScalarValue; + +/// The value to which `COUNT(*)` is expanded to in +/// `COUNT()` expressions +pub const COUNT_STAR_EXPANSION: ScalarValue = ScalarValue::Int64(Some(1)); diff --git a/datafusion/common/src/utils/memory.rs b/datafusion/common/src/utils/memory.rs new file mode 100644 index 000000000000..d5ce59e3421b --- /dev/null +++ b/datafusion/common/src/utils/memory.rs @@ -0,0 +1,135 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module provides a function to estimate the memory size of a HashTable prior to alloaction + +use crate::{DataFusionError, Result}; +use std::mem::size_of; + +/// Estimates the memory size required for a hash table prior to allocation. +/// +/// # Parameters +/// - `num_elements`: The number of elements expected in the hash table. +/// - `fixed_size`: A fixed overhead size associated with the collection +/// (e.g., HashSet or HashTable). +/// - `T`: The type of elements stored in the hash table. +/// +/// # Details +/// This function calculates the estimated memory size by considering: +/// - An overestimation of buckets to keep approximately 1/8 of them empty. +/// - The total memory size is computed as: +/// - The size of each entry (`T`) multiplied by the estimated number of +/// buckets. +/// - One byte overhead for each bucket. +/// - The fixed size overhead of the collection. +/// - If the estimation overflows, we return a [`DataFusionError`] +/// +/// # Examples +/// --- +/// +/// ## From within a struct +/// +/// ```rust +/// # use datafusion_common::utils::memory::estimate_memory_size; +/// # use datafusion_common::Result; +/// +/// struct MyStruct { +/// values: Vec, +/// other_data: usize, +/// } +/// +/// impl MyStruct { +/// fn size(&self) -> Result { +/// let num_elements = self.values.len(); +/// let fixed_size = std::mem::size_of_val(self) + +/// std::mem::size_of_val(&self.values); +/// +/// estimate_memory_size::(num_elements, fixed_size) +/// } +/// } +/// ``` +/// --- +/// ## With a simple collection +/// +/// ```rust +/// # use datafusion_common::utils::memory::estimate_memory_size; +/// # use std::collections::HashMap; +/// +/// let num_rows = 100; +/// let fixed_size = std::mem::size_of::>(); +/// let estimated_hashtable_size = +/// estimate_memory_size::<(u64, u64)>(num_rows,fixed_size) +/// .expect("Size estimation failed"); +/// ``` +pub fn estimate_memory_size(num_elements: usize, fixed_size: usize) -> Result { + // For the majority of cases hashbrown overestimates the bucket quantity + // to keep ~1/8 of them empty. We take this factor into account by + // multiplying the number of elements with a fixed ratio of 8/7 (~1.14). + // This formula leads to overallocation for small tables (< 8 elements) + // but should be fine overall. + num_elements + .checked_mul(8) + .and_then(|overestimate| { + let estimated_buckets = (overestimate / 7).next_power_of_two(); + // + size of entry * number of buckets + // + 1 byte for each bucket + // + fixed size of collection (HashSet/HashTable) + size_of::() + .checked_mul(estimated_buckets)? + .checked_add(estimated_buckets)? + .checked_add(fixed_size) + }) + .ok_or_else(|| { + DataFusionError::Execution( + "usize overflow while estimating the number of buckets".to_string(), + ) + }) +} + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + + use super::estimate_memory_size; + + #[test] + fn test_estimate_memory() { + // size (bytes): 48 + let fixed_size = size_of::>(); + + // estimated buckets: 16 = (8 * 8 / 7).next_power_of_two() + let num_elements = 8; + // size (bytes): 128 = 16 * 4 + 16 + 48 + let estimated = estimate_memory_size::(num_elements, fixed_size).unwrap(); + assert_eq!(estimated, 128); + + // estimated buckets: 64 = (40 * 8 / 7).next_power_of_two() + let num_elements = 40; + // size (bytes): 368 = 64 * 4 + 64 + 48 + let estimated = estimate_memory_size::(num_elements, fixed_size).unwrap(); + assert_eq!(estimated, 368); + } + + #[test] + fn test_estimate_memory_overflow() { + let num_elements = usize::MAX; + let fixed_size = size_of::>(); + let estimated = estimate_memory_size::(num_elements, fixed_size); + + assert!(estimated.is_err()); + } +} diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils/mod.rs similarity index 86% rename from datafusion/common/src/utils.rs rename to datafusion/common/src/utils/mod.rs index 102e4d73083e..dacf90af9bbf 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils/mod.rs @@ -17,23 +17,27 @@ //! This module provides the bisect function, which implements binary search. +pub mod expr; +pub mod memory; +pub mod proxy; +pub mod string_utils; + use crate::error::{_internal_datafusion_err, _internal_err}; -use crate::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; -use arrow::array::{ArrayRef, PrimitiveArray}; +use crate::{DataFusionError, Result, ScalarValue}; +use arrow::array::ArrayRef; use arrow::buffer::OffsetBuffer; -use arrow::compute; use arrow::compute::{partition, SortColumn, SortOptions}; -use arrow::datatypes::{Field, SchemaRef, UInt32Type}; -use arrow::record_batch::RecordBatch; +use arrow::datatypes::{Field, SchemaRef}; +use arrow_array::cast::AsArray; use arrow_array::{ - Array, FixedSizeListArray, LargeListArray, ListArray, RecordBatchOptions, + Array, FixedSizeListArray, LargeListArray, ListArray, OffsetSizeTrait, }; use arrow_schema::DataType; use sqlparser::ast::Ident; use sqlparser::dialect::GenericDialect; use sqlparser::parser::Parser; use std::borrow::{Borrow, Cow}; -use std::cmp::Ordering; +use std::cmp::{min, Ordering}; use std::collections::HashSet; use std::ops::Range; use std::sync::Arc; @@ -86,20 +90,6 @@ pub fn get_row_at_idx(columns: &[ArrayRef], idx: usize) -> Result, -) -> Result { - let new_columns = get_arrayref_at_indices(record_batch.columns(), indices)?; - RecordBatch::try_new_with_options( - record_batch.schema(), - new_columns, - &RecordBatchOptions::new().with_row_count(Some(indices.len())), - ) - .map_err(|e| arrow_datafusion_err!(e)) -} - /// This function compares two tuples depending on the given sort options. pub fn compare_rows( x: &[ScalarValue], @@ -242,7 +232,10 @@ pub fn evaluate_partition_ranges( end: num_rows, }] } else { - let cols: Vec<_> = partition_columns.iter().map(|x| x.values.clone()).collect(); + let cols: Vec<_> = partition_columns + .iter() + .map(|x| Arc::clone(&x.values)) + .collect(); partition(&cols)?.ranges() }) } @@ -280,24 +273,6 @@ pub(crate) fn parse_identifiers(s: &str) -> Result> { Ok(idents) } -/// Construct a new [`Vec`] of [`ArrayRef`] from the rows of the `arrays` at the `indices`. -pub fn get_arrayref_at_indices( - arrays: &[ArrayRef], - indices: &PrimitiveArray, -) -> Result> { - arrays - .iter() - .map(|array| { - compute::take( - array.as_ref(), - indices, - None, // None: no index check - ) - .map_err(|e| arrow_datafusion_err!(e)) - }) - .collect() -} - pub(crate) fn parse_identifiers_normalized(s: &str, ignore_case: bool) -> Vec { parse_identifiers(s) .unwrap_or_default() @@ -329,7 +304,7 @@ pub fn get_at_indices>( /// This function finds the longest prefix of the form 0, 1, 2, ... within the /// collection `sequence`. Examples: /// - For 0, 1, 2, 4, 5; we would produce 3, meaning 0, 1, 2 is the longest satisfying -/// prefix. +/// prefix. /// - For 1, 2, 3, 4; we would produce 0, meaning there is no such prefix. pub fn longest_consecutive_prefix>( sequence: impl IntoIterator, @@ -348,10 +323,19 @@ pub fn longest_consecutive_prefix>( /// Wrap an array into a single element `ListArray`. /// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` -pub fn array_into_list_array(arr: ArrayRef) -> ListArray { +/// The field in the list array is nullable. +pub fn array_into_list_array_nullable(arr: ArrayRef) -> ListArray { + array_into_list_array(arr, true) +} + +/// Array Utils + +/// Wrap an array into a single element `ListArray`. +/// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` +pub fn array_into_list_array(arr: ArrayRef, nullable: bool) -> ListArray { let offsets = OffsetBuffer::from_lengths([arr.len()]); ListArray::new( - Arc::new(Field::new("item", arr.data_type().to_owned(), true)), + Arc::new(Field::new_list_field(arr.data_type().to_owned(), nullable)), offsets, arr, None, @@ -363,7 +347,7 @@ pub fn array_into_list_array(arr: ArrayRef) -> ListArray { pub fn array_into_large_list_array(arr: ArrayRef) -> LargeListArray { let offsets = OffsetBuffer::from_lengths([arr.len()]); LargeListArray::new( - Arc::new(Field::new("item", arr.data_type().to_owned(), true)), + Arc::new(Field::new_list_field(arr.data_type().to_owned(), true)), offsets, arr, None, @@ -376,7 +360,7 @@ pub fn array_into_fixed_size_list_array( ) -> FixedSizeListArray { let list_size = list_size as i32; FixedSizeListArray::new( - Arc::new(Field::new("item", arr.data_type().to_owned(), true)), + Arc::new(Field::new_list_field(arr.data_type().to_owned(), true)), list_size, arr, None, @@ -417,13 +401,23 @@ pub fn arrays_into_list_array( let data_type = arr[0].data_type().to_owned(); let values = arr.iter().map(|x| x.as_ref()).collect::>(); Ok(ListArray::new( - Arc::new(Field::new("item", data_type, true)), + Arc::new(Field::new_list_field(data_type, true)), OffsetBuffer::from_lengths(lens), arrow::compute::concat(values.as_slice())?, None, )) } +/// Helper function to convert a ListArray into a vector of ArrayRefs. +pub fn list_to_arrays(a: &ArrayRef) -> Vec { + a.as_list::().iter().flatten().collect::>() +} + +/// Helper function to convert a FixedSizeListArray into a vector of ArrayRefs. +pub fn fixed_size_list_to_arrays(a: &ArrayRef) -> Vec { + a.as_fixed_size_list().iter().flatten().collect::>() +} + /// Get the base type of a data type. /// /// Example @@ -432,7 +426,7 @@ pub fn arrays_into_list_array( /// use datafusion_common::utils::base_type; /// use std::sync::Arc; /// -/// let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); +/// let data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); /// assert_eq!(base_type(&data_type), DataType::Int32); /// /// let data_type = DataType::Int32; @@ -455,10 +449,10 @@ pub fn base_type(data_type: &DataType) -> DataType { /// use datafusion_common::utils::coerced_type_with_base_type_only; /// use std::sync::Arc; /// -/// let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); +/// let data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); /// let base_type = DataType::Float64; /// let coerced_type = coerced_type_with_base_type_only(&data_type, &base_type); -/// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new("item", DataType::Float64, true)))); +/// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new_list_field(DataType::Float64, true)))); pub fn coerced_type_with_base_type_only( data_type: &DataType, base_type: &DataType, @@ -525,34 +519,6 @@ pub fn list_ndims(data_type: &DataType) -> u64 { } } -/// An extension trait for smart pointers. Provides an interface to get a -/// raw pointer to the data (with metadata stripped away). -/// -/// This is useful to see if two smart pointers point to the same allocation. -pub trait DataPtr { - /// Returns a raw pointer to the data, stripping away all metadata. - fn data_ptr(this: &Self) -> *const (); - - /// Check if two pointers point to the same data. - fn data_ptr_eq(this: &Self, other: &Self) -> bool { - // Discard pointer metadata (including the v-table). - let this = Self::data_ptr(this); - let other = Self::data_ptr(other); - - std::ptr::eq(this, other) - } -} - -// Currently, it's brittle to compare `Arc`s of dyn traits with `Arc::ptr_eq` -// due to this check including v-table equality. It may be possible to use -// `Arc::ptr_eq` directly if a fix to https://github.com/rust-lang/rust/issues/103763 -// is stabilized. -impl DataPtr for Arc { - fn data_ptr(this: &Self) -> *const () { - Arc::as_ptr(this) as *const () - } -} - /// Adopted from strsim-rs for string similarity metrics pub mod datafusion_strsim { // Source: https://github.com/dguo/strsim-rs/blob/master/src/lib.rs @@ -679,6 +645,85 @@ pub fn find_indices>( .ok_or_else(|| DataFusionError::Execution("Target not found".to_string())) } +/// Transposes the given vector of vectors. +pub fn transpose(original: Vec>) -> Vec> { + match original.as_slice() { + [] => vec![], + [first, ..] => { + let mut result = (0..first.len()).map(|_| vec![]).collect::>(); + for row in original { + for (item, transposed_row) in row.into_iter().zip(&mut result) { + transposed_row.push(item); + } + } + result + } + } +} + +/// Computes the `skip` and `fetch` parameters of a single limit that would be +/// equivalent to two consecutive limits with the given `skip`/`fetch` parameters. +/// +/// There are multiple cases to consider: +/// +/// # Case 0: Parent and child are disjoint (`child_fetch <= skip`). +/// +/// ```text +/// Before merging: +/// |........skip........|---fetch-->| Parent limit +/// |...child_skip...|---child_fetch-->| Child limit +/// ``` +/// +/// After merging: +/// ```text +/// |.........(child_skip + skip).........| +/// ``` +/// +/// # Case 1: Parent is beyond child's range (`skip < child_fetch <= skip + fetch`). +/// +/// Before merging: +/// ```text +/// |...skip...|------------fetch------------>| Parent limit +/// |...child_skip...|-------------child_fetch------------>| Child limit +/// ``` +/// +/// After merging: +/// ```text +/// |....(child_skip + skip)....|---(child_fetch - skip)-->| +/// ``` +/// +/// # Case 2: Parent is within child's range (`skip + fetch < child_fetch`). +/// +/// Before merging: +/// ```text +/// |...skip...|---fetch-->| Parent limit +/// |...child_skip...|-------------child_fetch------------>| Child limit +/// ``` +/// +/// After merging: +/// ```text +/// |....(child_skip + skip)....|---fetch-->| +/// ``` +pub fn combine_limit( + parent_skip: usize, + parent_fetch: Option, + child_skip: usize, + child_fetch: Option, +) -> (usize, Option) { + let combined_skip = child_skip.saturating_add(parent_skip); + + let combined_fetch = match (parent_fetch, child_fetch) { + (Some(parent_fetch), Some(child_fetch)) => { + Some(min(parent_fetch, child_fetch.saturating_sub(parent_skip))) + } + (Some(parent_fetch), None) => Some(parent_fetch), + (None, Some(child_fetch)) => Some(child_fetch.saturating_sub(parent_skip)), + (None, None) => None, + }; + + (combined_skip, combined_fetch) +} + #[cfg(test)] mod tests { use crate::ScalarValue::Null; @@ -923,39 +968,6 @@ mod tests { Ok(()) } - #[test] - fn test_get_arrayref_at_indices() -> Result<()> { - let arrays: Vec = vec![ - Arc::new(Float64Array::from(vec![5.0, 7.0, 8.0, 9., 10.])), - Arc::new(Float64Array::from(vec![2.0, 3.0, 3.0, 4.0, 5.0])), - Arc::new(Float64Array::from(vec![5.0, 7.0, 8.0, 10., 11.0])), - Arc::new(Float64Array::from(vec![15.0, 13.0, 8.0, 5., 0.0])), - ]; - - let row_indices_vec: Vec> = vec![ - // Get rows 0 and 1 - vec![0, 1], - // Get rows 0 and 1 - vec![0, 2], - // Get rows 1 and 3 - vec![1, 3], - // Get rows 2 and 4 - vec![2, 4], - ]; - for row_indices in row_indices_vec { - let indices = PrimitiveArray::from_iter_values(row_indices.iter().cloned()); - let chunk = get_arrayref_at_indices(&arrays, &indices)?; - for (arr_orig, arr_chunk) in arrays.iter().zip(&chunk) { - for (idx, orig_idx) in row_indices.iter().enumerate() { - let res1 = ScalarValue::try_from_array(arr_orig, *orig_idx as usize)?; - let res2 = ScalarValue::try_from_array(arr_chunk, idx)?; - assert_eq!(res1, res2); - } - } - } - Ok(()) - } - #[test] fn test_get_at_indices() -> Result<()> { let in_vec = vec![1, 2, 3, 4, 5, 6, 7]; @@ -974,26 +986,6 @@ mod tests { assert_eq!(longest_consecutive_prefix([1, 2, 3, 4]), 0); } - #[test] - fn arc_data_ptr_eq() { - let x = Arc::new(()); - let y = Arc::new(()); - let y_clone = Arc::clone(&y); - - assert!( - Arc::data_ptr_eq(&x, &x), - "same `Arc`s should point to same data" - ); - assert!( - !Arc::data_ptr_eq(&x, &y), - "different `Arc`s should point to different data" - ); - assert!( - Arc::data_ptr_eq(&y, &y_clone), - "cloned `Arc` should point to same data as the original" - ); - } - #[test] fn test_merge_and_order_indices() { assert_eq!( @@ -1038,4 +1030,13 @@ mod tests { assert!(find_indices(&[0, 3, 4], [0, 2]).is_err()); Ok(()) } + + #[test] + fn test_transpose() -> Result<()> { + let in_data = vec![vec![1, 2, 3], vec![4, 5, 6]]; + let transposed = transpose(in_data); + let expected = vec![vec![1, 4], vec![2, 5], vec![3, 6]]; + assert_eq!(expected, transposed); + Ok(()) + } } diff --git a/datafusion/execution/src/memory_pool/proxy.rs b/datafusion/common/src/utils/proxy.rs similarity index 92% rename from datafusion/execution/src/memory_pool/proxy.rs rename to datafusion/common/src/utils/proxy.rs index 29874fdaed02..5d14a1517129 100644 --- a/datafusion/execution/src/memory_pool/proxy.rs +++ b/datafusion/common/src/utils/proxy.rs @@ -18,6 +18,7 @@ //! [`VecAllocExt`] and [`RawTableAllocExt`] to help tracking of memory allocations use hashbrown::raw::{Bucket, RawTable}; +use std::mem::size_of; /// Extension trait for [`Vec`] to account for allocations. pub trait VecAllocExt { @@ -31,7 +32,7 @@ pub trait VecAllocExt { /// /// # Example: /// ``` - /// # use datafusion_execution::memory_pool::proxy::VecAllocExt; + /// # use datafusion_common::utils::proxy::VecAllocExt; /// // use allocated to incrementally track how much memory is allocated in the vec /// let mut allocated = 0; /// let mut vec = Vec::new(); @@ -49,7 +50,7 @@ pub trait VecAllocExt { /// ``` /// # Example with other allocations: /// ``` - /// # use datafusion_execution::memory_pool::proxy::VecAllocExt; + /// # use datafusion_common::utils::proxy::VecAllocExt; /// // You can use the same allocated size to track memory allocated by /// // another source. For example /// let mut allocated = 27; @@ -68,7 +69,7 @@ pub trait VecAllocExt { /// /// # Example: /// ``` - /// # use datafusion_execution::memory_pool::proxy::VecAllocExt; + /// # use datafusion_common::utils::proxy::VecAllocExt; /// let mut vec = Vec::new(); /// // Push data into the vec and the accounting will be updated to reflect /// // memory allocation @@ -93,7 +94,7 @@ impl VecAllocExt for Vec { let new_capacity = self.capacity(); if new_capacity > prev_capacty { // capacity changed, so we allocated more - let bump_size = (new_capacity - prev_capacty) * std::mem::size_of::(); + let bump_size = (new_capacity - prev_capacty) * size_of::(); // Note multiplication should never overflow because `push` would // have panic'd first, but the checked_add could potentially // overflow since accounting could be tracking additional values, and @@ -102,7 +103,7 @@ impl VecAllocExt for Vec { } } fn allocated_size(&self) -> usize { - std::mem::size_of::() * self.capacity() + size_of::() * self.capacity() } } @@ -119,7 +120,7 @@ pub trait RawTableAllocExt { /// /// # Example: /// ``` - /// # use datafusion_execution::memory_pool::proxy::RawTableAllocExt; + /// # use datafusion_common::utils::proxy::RawTableAllocExt; /// # use hashbrown::raw::RawTable; /// let mut table = RawTable::new(); /// let mut allocated = 0; @@ -157,7 +158,7 @@ impl RawTableAllocExt for RawTable { // need to request more memory let bump_elements = self.capacity().max(16); - let bump_size = bump_elements * std::mem::size_of::(); + let bump_size = bump_elements * size_of::(); *accounting = (*accounting).checked_add(bump_size).expect("overflow"); self.reserve(bump_elements, hasher); diff --git a/datafusion/sql/src/expr/json_access.rs b/datafusion/common/src/utils/string_utils.rs similarity index 60% rename from datafusion/sql/src/expr/json_access.rs rename to datafusion/common/src/utils/string_utils.rs index b24482f88297..a2231e6786a7 100644 --- a/datafusion/sql/src/expr/json_access.rs +++ b/datafusion/common/src/utils/string_utils.rs @@ -15,17 +15,17 @@ // specific language governing permissions and limitations // under the License. -use crate::planner::{ContextProvider, SqlToRel}; -use datafusion_common::{not_impl_err, Result}; -use datafusion_expr::Operator; -use sqlparser::ast::JsonOperator; +//! Utilities for working with strings -impl<'a, S: ContextProvider> SqlToRel<'a, S> { - pub(crate) fn parse_sql_json_access(&self, op: JsonOperator) -> Result { - match op { - JsonOperator::AtArrow => Ok(Operator::AtArrow), - JsonOperator::ArrowAt => Ok(Operator::ArrowAt), - _ => not_impl_err!("Unsupported SQL json operator {op:?}"), - } +use arrow::{array::AsArray, datatypes::DataType}; +use arrow_array::Array; + +/// Convenient function to convert an Arrow string array to a vector of strings +pub fn string_array_to_vec(array: &dyn Array) -> Vec> { + match array.data_type() { + DataType::Utf8 => array.as_string::().iter().collect(), + DataType::LargeUtf8 => array.as_string::().iter().collect(), + DataType::Utf8View => array.as_string_view().iter().collect(), + _ => unreachable!(), } } diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 2bd552aacc44..8c4ad80e2924 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -20,7 +20,7 @@ name = "datafusion" description = "DataFusion is an in-memory query engine that uses Apache Arrow as the memory model" keywords = ["arrow", "query", "sql"] include = ["benches/*.rs", "src/**/*.rs", "Cargo.toml"] -readme = "README.md" +readme = "../../README.md" version = { workspace = true } edition = { workspace = true } homepage = { workspace = true } @@ -30,7 +30,7 @@ authors = { workspace = true } # Specify MSRV here as `cargo msrv` doesn't support workspace version and fails with # "Unable to find key 'package.rust-version' (or 'package.metadata.msrv') in 'arrow-datafusion/Cargo.toml'" # https://github.com/foresterre/cargo-msrv/issues/590 -rust-version = "1.73" +rust-version = "1.79" [lints] workspace = true @@ -40,15 +40,17 @@ name = "datafusion" path = "src/lib.rs" [features] +nested_expressions = ["datafusion-functions-nested"] +# This feature is deprecated. Use the `nested_expressions` feature instead. +array_expressions = ["nested_expressions"] # Used to enable the avro format -array_expressions = ["datafusion-functions-array"] avro = ["apache-avro", "num-traits", "datafusion-common/avro"] backtrace = ["datafusion-common/backtrace"] compression = ["xz2", "bzip2", "flate2", "zstd", "async-compression", "tokio-util"] crypto_expressions = ["datafusion-functions/crypto_expressions"] datetime_expressions = ["datafusion-functions/datetime_expressions"] default = [ - "array_expressions", + "nested_expressions", "crypto_expressions", "datetime_expressions", "encoding_expressions", @@ -60,13 +62,11 @@ default = [ ] encoding_expressions = ["datafusion-functions/encoding_expressions"] # Used for testing ONLY: causes all values to hash to the same value (test for collisions) -force_hash_collisions = [] +force_hash_collisions = ["datafusion-physical-plan/force_hash_collisions", "datafusion-common/force_hash_collisions"] math_expressions = ["datafusion-functions/math_expressions"] parquet = ["datafusion-common/parquet", "dep:parquet"] pyarrow = ["datafusion-common/pyarrow", "parquet"] regex_expressions = [ - "datafusion-physical-expr/regex_expressions", - "datafusion-optimizer/regex_expressions", "datafusion-functions/regex_expressions", ] serde = ["arrow-schema/serde"] @@ -77,7 +77,7 @@ unicode_expressions = [ ] [dependencies] -ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] } +ahash = { workspace = true } apache-avro = { version = "0.16", optional = true } arrow = { workspace = true } arrow-array = { workspace = true } @@ -96,22 +96,26 @@ bytes = { workspace = true } bzip2 = { version = "0.4.3", optional = true } chrono = { workspace = true } dashmap = { workspace = true } +datafusion-catalog = { workspace = true } datafusion-common = { workspace = true, features = ["object_store"] } datafusion-common-runtime = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-functions = { workspace = true } datafusion-functions-aggregate = { workspace = true } -datafusion-functions-array = { workspace = true, optional = true } +datafusion-functions-nested = { workspace = true, optional = true } +datafusion-functions-window = { workspace = true } datafusion-optimizer = { workspace = true } datafusion-physical-expr = { workspace = true } +datafusion-physical-expr-common = { workspace = true } +datafusion-physical-optimizer = { workspace = true } datafusion-physical-plan = { workspace = true } datafusion-sql = { workspace = true } flate2 = { version = "1.0.24", optional = true } futures = { workspace = true } glob = "0.3.0" half = { workspace = true } -hashbrown = { version = "0.14", features = ["raw"] } +hashbrown = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true } log = { workspace = true } @@ -120,6 +124,7 @@ num_cpus = { workspace = true } object_store = { workspace = true } parking_lot = { workspace = true } parquet = { workspace = true, optional = true, default-features = true } +paste = "1.0.15" pin-project-lite = "^0.2.7" rand = { workspace = true } sqlparser = { workspace = true } @@ -132,11 +137,13 @@ xz2 = { version = "0.1", optional = true, features = ["static"] } zstd = { version = "0.13", optional = true, default-features = false } [dev-dependencies] +arrow-buffer = { workspace = true } async-trait = { workspace = true } bigdecimal = { workspace = true } criterion = { version = "0.5", features = ["async_tokio"] } csv = "1.1.6" ctor = { workspace = true } +datafusion-functions-window-common = { workspace = true } doc-comment = { workspace = true } env_logger = { workspace = true } half = { workspace = true, default-features = true } @@ -145,7 +152,7 @@ postgres-protocol = "0.6.4" postgres-types = { version = "0.2.4", features = ["derive", "with-chrono-0_4"] } rand = { workspace = true, features = ["small_rng"] } rand_distr = "0.4.3" -regex = "1.5.4" +regex = { workspace = true } rstest = { workspace = true } rust_decimal = { version = "1.27.0", features = ["tokio-pg"] } serde_json = { workspace = true } @@ -155,7 +162,7 @@ tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot", "fs"] tokio-postgres = "0.7.7" [target.'cfg(not(target_os = "windows"))'.dev-dependencies] -nix = { version = "0.28.0", features = ["fs"] } +nix = { version = "0.29.0", features = ["fs"] } [[bench]] harness = false @@ -209,3 +216,8 @@ name = "sort" [[bench]] harness = false name = "topk_aggregate" + +[[bench]] +harness = false +name = "map_query_sql" +required-features = ["nested_expressions"] diff --git a/datafusion/core/benches/aggregate_query_sql.rs b/datafusion/core/benches/aggregate_query_sql.rs index 3734cfbe313c..1d8d87ada784 100644 --- a/datafusion/core/benches/aggregate_query_sql.rs +++ b/datafusion/core/benches/aggregate_query_sql.rs @@ -163,6 +163,16 @@ fn criterion_benchmark(c: &mut Criterion) { ) }) }); + + c.bench_function("aggregate_query_distinct_median", |b| { + b.iter(|| { + query( + ctx.clone(), + "SELECT MEDIAN(DISTINCT u64_wide), MEDIAN(DISTINCT u64_narrow) \ + FROM t", + ) + }) + }); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/core/benches/filter_query_sql.rs b/datafusion/core/benches/filter_query_sql.rs index 01adc357b39a..0e09ae09d7c2 100644 --- a/datafusion/core/benches/filter_query_sql.rs +++ b/datafusion/core/benches/filter_query_sql.rs @@ -27,7 +27,7 @@ use futures::executor::block_on; use std::sync::Arc; use tokio::runtime::Runtime; -async fn query(ctx: &mut SessionContext, sql: &str) { +async fn query(ctx: &SessionContext, sql: &str) { let rt = Runtime::new().unwrap(); // execute the query @@ -70,25 +70,25 @@ fn criterion_benchmark(c: &mut Criterion) { let batch_size = 4096; // 2^12 c.bench_function("filter_array", |b| { - let mut ctx = create_context(array_len, batch_size).unwrap(); - b.iter(|| block_on(query(&mut ctx, "select f32, f64 from t where f32 >= f64"))) + let ctx = create_context(array_len, batch_size).unwrap(); + b.iter(|| block_on(query(&ctx, "select f32, f64 from t where f32 >= f64"))) }); c.bench_function("filter_scalar", |b| { - let mut ctx = create_context(array_len, batch_size).unwrap(); + let ctx = create_context(array_len, batch_size).unwrap(); b.iter(|| { block_on(query( - &mut ctx, + &ctx, "select f32, f64 from t where f32 >= 250 and f64 > 250", )) }) }); c.bench_function("filter_scalar in list", |b| { - let mut ctx = create_context(array_len, batch_size).unwrap(); + let ctx = create_context(array_len, batch_size).unwrap(); b.iter(|| { block_on(query( - &mut ctx, + &ctx, "select f32, f64 from t where f32 in (10, 20, 30, 40)", )) }) diff --git a/datafusion/core/benches/map_query_sql.rs b/datafusion/core/benches/map_query_sql.rs new file mode 100644 index 000000000000..e4c5f7c5deb3 --- /dev/null +++ b/datafusion/core/benches/map_query_sql.rs @@ -0,0 +1,93 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow_array::{ArrayRef, Int32Array, RecordBatch}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use parking_lot::Mutex; +use rand::prelude::ThreadRng; +use rand::Rng; +use tokio::runtime::Runtime; + +use datafusion::prelude::SessionContext; +use datafusion_common::ScalarValue; +use datafusion_expr::Expr; +use datafusion_functions_nested::map::map; + +mod data_utils; + +fn build_keys(rng: &mut ThreadRng) -> Vec { + let mut keys = vec![]; + for _ in 0..1000 { + keys.push(rng.gen_range(0..9999).to_string()); + } + keys +} + +fn build_values(rng: &mut ThreadRng) -> Vec { + let mut values = vec![]; + for _ in 0..1000 { + values.push(rng.gen_range(0..9999)); + } + values +} + +fn t_batch(num: i32) -> RecordBatch { + let value: Vec = (0..num).collect(); + let c1: ArrayRef = Arc::new(Int32Array::from(value)); + RecordBatch::try_from_iter(vec![("c1", c1)]).unwrap() +} + +fn create_context(num: i32) -> datafusion_common::Result>> { + let ctx = SessionContext::new(); + ctx.register_batch("t", t_batch(num))?; + Ok(Arc::new(Mutex::new(ctx))) +} + +fn criterion_benchmark(c: &mut Criterion) { + let ctx = create_context(1).unwrap(); + let rt = Runtime::new().unwrap(); + let df = rt.block_on(ctx.lock().table("t")).unwrap(); + + let mut rng = rand::thread_rng(); + let keys = build_keys(&mut rng); + let values = build_values(&mut rng); + let mut key_buffer = Vec::new(); + let mut value_buffer = Vec::new(); + + for i in 0..1000 { + key_buffer.push(Expr::Literal(ScalarValue::Utf8(Some(keys[i].clone())))); + value_buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])))); + } + c.bench_function("map_1000_1", |b| { + b.iter(|| { + black_box( + rt.block_on( + df.clone() + .select(vec![map(key_buffer.clone(), value_buffer.clone())]) + .unwrap() + .collect(), + ) + .unwrap(), + ); + }); + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/core/benches/parquet_query_sql.rs b/datafusion/core/benches/parquet_query_sql.rs index bc4298786002..f82a126c5652 100644 --- a/datafusion/core/benches/parquet_query_sql.rs +++ b/datafusion/core/benches/parquet_query_sql.rs @@ -249,7 +249,7 @@ fn criterion_benchmark(c: &mut Criterion) { } // Temporary file must outlive the benchmarks, it is deleted when dropped - std::mem::drop(temp_file); + drop(temp_file); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/core/benches/physical_plan.rs b/datafusion/core/benches/physical_plan.rs index 3ad71be1f447..349c2e438195 100644 --- a/datafusion/core/benches/physical_plan.rs +++ b/datafusion/core/benches/physical_plan.rs @@ -36,6 +36,7 @@ use datafusion::physical_plan::{ memory::MemoryExec, }; use datafusion::prelude::SessionContext; +use datafusion_physical_expr_common::sort_expr::LexOrdering; // Initialise the operator using the provided record batches and the sort key // as inputs. All record batches must have the same schema. @@ -52,7 +53,7 @@ fn sort_preserving_merge_operator( expr: col(name, &schema).unwrap(), options: Default::default(), }) - .collect::>(); + .collect::(); let exec = MemoryExec::try_new( &batches.into_iter().map(|rb| vec![rb]).collect::>(), diff --git a/datafusion/core/benches/sort.rs b/datafusion/core/benches/sort.rs index 94a39bbb2af3..14e80ce364e3 100644 --- a/datafusion/core/benches/sort.rs +++ b/datafusion/core/benches/sort.rs @@ -21,7 +21,7 @@ //! 1. Creates a list of tuples (sorted if necessary) //! //! 2. Divides those tuples across some number of streams of [`RecordBatch`] -//! preserving any ordering +//! preserving any ordering //! //! 3. Times how long it takes for a given sort plan to process the input //! @@ -89,6 +89,7 @@ use datafusion_physical_expr::{expressions::col, PhysicalSortExpr}; /// Benchmarks for SortPreservingMerge stream use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use futures::StreamExt; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; @@ -257,7 +258,7 @@ impl BenchCase { } /// Make sort exprs for each column in `schema` -fn make_sort_exprs(schema: &Schema) -> Vec { +fn make_sort_exprs(schema: &Schema) -> LexOrdering { schema .fields() .iter() diff --git a/datafusion/core/benches/sql_planner.rs b/datafusion/core/benches/sql_planner.rs index c0e02d388af4..140e266a0272 100644 --- a/datafusion/core/benches/sql_planner.rs +++ b/datafusion/core/benches/sql_planner.rs @@ -15,22 +15,31 @@ // specific language governing permissions and limitations // under the License. +extern crate arrow; #[macro_use] extern crate criterion; -extern crate arrow; extern crate datafusion; mod data_utils; + use crate::criterion::Criterion; use arrow::datatypes::{DataType, Field, Fields, Schema}; use datafusion::datasource::MemTable; use datafusion::execution::context::SessionContext; +use itertools::Itertools; +use std::fs::File; +use std::io::{BufRead, BufReader}; +use std::path::PathBuf; use std::sync::Arc; use test_utils::tpcds::tpcds_schemas; use test_utils::tpch::tpch_schemas; use test_utils::TableDef; use tokio::runtime::Runtime; +const BENCHMARKS_PATH_1: &str = "../../benchmarks/"; +const BENCHMARKS_PATH_2: &str = "./benchmarks/"; +const CLICKBENCH_DATA_PATH: &str = "data/hits_partitioned/"; + /// Create a logical plan from the specified sql fn logical_plan(ctx: &SessionContext, sql: &str) { let rt = Runtime::new().unwrap(); @@ -60,7 +69,9 @@ fn create_schema(column_prefix: &str, num_columns: usize) -> Schema { fn create_table_provider(column_prefix: &str, num_columns: usize) -> Arc { let schema = Arc::new(create_schema(column_prefix, num_columns)); - MemTable::try_new(schema, vec![]).map(Arc::new).unwrap() + MemTable::try_new(schema, vec![vec![]]) + .map(Arc::new) + .unwrap() } fn create_context() -> SessionContext { @@ -82,14 +93,44 @@ fn register_defs(ctx: SessionContext, defs: Vec) -> SessionContext { defs.iter().for_each(|TableDef { name, schema }| { ctx.register_table( name, - Arc::new(MemTable::try_new(Arc::new(schema.clone()), vec![]).unwrap()), + Arc::new(MemTable::try_new(Arc::new(schema.clone()), vec![vec![]]).unwrap()), ) .unwrap(); }); ctx } +fn register_clickbench_hits_table() -> SessionContext { + let ctx = SessionContext::new(); + let rt = Runtime::new().unwrap(); + + // use an external table for clickbench benchmarks + let path = + if PathBuf::from(format!("{BENCHMARKS_PATH_1}{CLICKBENCH_DATA_PATH}")).exists() { + format!("{BENCHMARKS_PATH_1}{CLICKBENCH_DATA_PATH}") + } else { + format!("{BENCHMARKS_PATH_2}{CLICKBENCH_DATA_PATH}") + }; + + let sql = format!("CREATE EXTERNAL TABLE hits STORED AS PARQUET LOCATION '{path}'"); + + rt.block_on(ctx.sql(&sql)).unwrap(); + + let count = + rt.block_on(async { ctx.table("hits").await.unwrap().count().await.unwrap() }); + assert!(count > 0); + ctx +} + fn criterion_benchmark(c: &mut Criterion) { + // verify that we can load the clickbench data prior to running the benchmark + if !PathBuf::from(format!("{BENCHMARKS_PATH_1}{CLICKBENCH_DATA_PATH}")).exists() + && !PathBuf::from(format!("{BENCHMARKS_PATH_2}{CLICKBENCH_DATA_PATH}")).exists() + { + panic!("benchmarks/data/hits_partitioned/ could not be loaded. Please run \ + 'benchmarks/bench.sh data clickbench_partitioned' prior to running this benchmark") + } + let ctx = create_context(); // Test simplest @@ -144,6 +185,85 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); + c.bench_function("physical_select_aggregates_from_200", |b| { + let mut aggregates = String::new(); + for i in 0..200 { + if i > 0 { + aggregates.push_str(", "); + } + aggregates.push_str(format!("MAX(a{})", i).as_str()); + } + let query = format!("SELECT {} FROM t1", aggregates); + b.iter(|| { + physical_plan(&ctx, &query); + }); + }); + + // Benchmark for Physical Planning Joins + c.bench_function("physical_join_consider_sort", |b| { + b.iter(|| { + physical_plan( + &ctx, + "SELECT t1.a7, t2.b8 \ + FROM t1, t2 WHERE a7 = b7 \ + ORDER BY a7", + ); + }); + }); + + c.bench_function("physical_theta_join_consider_sort", |b| { + b.iter(|| { + physical_plan( + &ctx, + "SELECT t1.a7, t2.b8 \ + FROM t1, t2 WHERE a7 < b7 \ + ORDER BY a7", + ); + }); + }); + + c.bench_function("physical_many_self_joins", |b| { + b.iter(|| { + physical_plan( + &ctx, + "SELECT ta.a9, tb.a10, tc.a11, td.a12, te.a13, tf.a14 \ + FROM t1 AS ta, t1 AS tb, t1 AS tc, t1 AS td, t1 AS te, t1 AS tf \ + WHERE ta.a9 = tb.a10 AND tb.a10 = tc.a11 AND tc.a11 = td.a12 AND \ + td.a12 = te.a13 AND te.a13 = tf.a14", + ); + }); + }); + + c.bench_function("physical_unnest_to_join", |b| { + b.iter(|| { + physical_plan( + &ctx, + "SELECT t1.a7 \ + FROM t1 WHERE a7 = (SELECT b8 FROM t2)", + ); + }); + }); + + c.bench_function("physical_intersection", |b| { + b.iter(|| { + physical_plan( + &ctx, + "SELECT t1.a7 FROM t1 \ + INTERSECT SELECT t2.b8 FROM t2", + ); + }); + }); + // these two queries should be equivalent + c.bench_function("physical_join_distinct", |b| { + b.iter(|| { + logical_plan( + &ctx, + "SELECT DISTINCT t1.a7 \ + FROM t1, t2 WHERE t1.a7 = t2.b8", + ); + }); + }); + // --- TPC-H --- let tpch_ctx = register_defs(SessionContext::new(), tpch_schemas()); @@ -154,9 +274,15 @@ fn criterion_benchmark(c: &mut Criterion) { "q16", "q17", "q18", "q19", "q20", "q21", "q22", ]; + let benchmarks_path = if PathBuf::from(BENCHMARKS_PATH_1).exists() { + BENCHMARKS_PATH_1 + } else { + BENCHMARKS_PATH_2 + }; + for q in tpch_queries { let sql = - std::fs::read_to_string(format!("../../benchmarks/queries/{q}.sql")).unwrap(); + std::fs::read_to_string(format!("{benchmarks_path}queries/{q}.sql")).unwrap(); c.bench_function(&format!("physical_plan_tpch_{}", q), |b| { b.iter(|| physical_plan(&tpch_ctx, &sql)) }); @@ -165,7 +291,7 @@ fn criterion_benchmark(c: &mut Criterion) { let all_tpch_sql_queries = tpch_queries .iter() .map(|q| { - std::fs::read_to_string(format!("../../benchmarks/queries/{q}.sql")).unwrap() + std::fs::read_to_string(format!("{benchmarks_path}queries/{q}.sql")).unwrap() }) .collect::>(); @@ -177,26 +303,25 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); - c.bench_function("logical_plan_tpch_all", |b| { - b.iter(|| { - for sql in &all_tpch_sql_queries { - logical_plan(&tpch_ctx, sql) - } - }) - }); + // c.bench_function("logical_plan_tpch_all", |b| { + // b.iter(|| { + // for sql in &all_tpch_sql_queries { + // logical_plan(&tpch_ctx, sql) + // } + // }) + // }); // --- TPC-DS --- let tpcds_ctx = register_defs(SessionContext::new(), tpcds_schemas()); - - // 10, 35: Physical plan does not support logical expression Exists() - // 45: Physical plan does not support logical expression () - // 41: Optimizing disjunctions not supported - let ignored = [10, 35, 41, 45]; + let tests_path = if PathBuf::from("./tests/").exists() { + "./tests/" + } else { + "datafusion/core/tests/" + }; let raw_tpcds_sql_queries = (1..100) - .filter(|q| !ignored.contains(q)) - .map(|q| std::fs::read_to_string(format!("./tests/tpc-ds/{q}.sql")).unwrap()) + .map(|q| std::fs::read_to_string(format!("{tests_path}tpc-ds/{q}.sql")).unwrap()) .collect::>(); // some queries have multiple statements @@ -213,10 +338,53 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); - c.bench_function("logical_plan_tpcds_all", |b| { + // c.bench_function("logical_plan_tpcds_all", |b| { + // b.iter(|| { + // for sql in &all_tpcds_sql_queries { + // logical_plan(&tpcds_ctx, sql) + // } + // }) + // }); + + // -- clickbench -- + + let queries_file = + File::open(format!("{benchmarks_path}queries/clickbench/queries.sql")).unwrap(); + let extended_file = + File::open(format!("{benchmarks_path}queries/clickbench/extended.sql")).unwrap(); + + let clickbench_queries: Vec = BufReader::new(queries_file) + .lines() + .chain(BufReader::new(extended_file).lines()) + .map(|l| l.expect("Could not parse line")) + .collect_vec(); + + let clickbench_ctx = register_clickbench_hits_table(); + + // for (i, sql) in clickbench_queries.iter().enumerate() { + // c.bench_function(&format!("logical_plan_clickbench_q{}", i + 1), |b| { + // b.iter(|| logical_plan(&clickbench_ctx, sql)) + // }); + // } + + for (i, sql) in clickbench_queries.iter().enumerate() { + c.bench_function(&format!("physical_plan_clickbench_q{}", i + 1), |b| { + b.iter(|| physical_plan(&clickbench_ctx, sql)) + }); + } + + // c.bench_function("logical_plan_clickbench_all", |b| { + // b.iter(|| { + // for sql in &clickbench_queries { + // logical_plan(&clickbench_ctx, sql) + // } + // }) + // }); + + c.bench_function("physical_plan_clickbench_all", |b| { b.iter(|| { - for sql in &all_tpcds_sql_queries { - logical_plan(&tpcds_ctx, sql) + for sql in &clickbench_queries { + physical_plan(&clickbench_ctx, sql) } }) }); diff --git a/datafusion/core/benches/sql_query_with_io.rs b/datafusion/core/benches/sql_query_with_io.rs index 916f48ce40c6..aef39a04e47e 100644 --- a/datafusion/core/benches/sql_query_with_io.rs +++ b/datafusion/core/benches/sql_query_with_io.rs @@ -96,7 +96,7 @@ async fn setup_files(store: Arc) { let location = Path::from(format!( "{table_name}/partition={partition}/{file}.parquet" )); - store.put(&location, data).await.unwrap(); + store.put(&location, data.into()).await.unwrap(); } } } diff --git a/datafusion/core/src/bin/print_functions_docs.rs b/datafusion/core/src/bin/print_functions_docs.rs new file mode 100644 index 000000000000..3aedcbc2aa63 --- /dev/null +++ b/datafusion/core/src/bin/print_functions_docs.rs @@ -0,0 +1,297 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::execution::SessionStateDefaults; +use datafusion_common::{not_impl_err, Result}; +use datafusion_expr::{ + aggregate_doc_sections, scalar_doc_sections, window_doc_sections, AggregateUDF, + DocSection, Documentation, ScalarUDF, WindowUDF, +}; +use hashbrown::HashSet; +use itertools::Itertools; +use std::env::args; +use std::fmt::Write as _; + +/// Print documentation for all functions of a given type to stdout +/// +/// Usage: `cargo run --bin print_functions_docs -- ` +/// +/// Called from `dev/update_function_docs.sh` +fn main() -> Result<()> { + let args: Vec = args().collect(); + + if args.len() != 2 { + panic!( + "Usage: {} type (one of 'aggregate', 'scalar', 'window')", + args[0] + ); + } + + let function_type = args[1].trim().to_lowercase(); + let docs = match function_type.as_str() { + "aggregate" => print_aggregate_docs(), + "scalar" => print_scalar_docs(), + "window" => print_window_docs(), + _ => { + panic!("Unknown function type: {}", function_type) + } + }?; + + println!("{docs}"); + Ok(()) +} + +fn print_aggregate_docs() -> Result { + let mut providers: Vec> = vec![]; + + for f in SessionStateDefaults::default_aggregate_functions() { + providers.push(Box::new(f.as_ref().clone())); + } + + print_docs(providers, aggregate_doc_sections::doc_sections()) +} + +fn print_scalar_docs() -> Result { + let mut providers: Vec> = vec![]; + + for f in SessionStateDefaults::default_scalar_functions() { + providers.push(Box::new(f.as_ref().clone())); + } + + print_docs(providers, scalar_doc_sections::doc_sections()) +} + +fn print_window_docs() -> Result { + let mut providers: Vec> = vec![]; + + for f in SessionStateDefaults::default_window_functions() { + providers.push(Box::new(f.as_ref().clone())); + } + + print_docs(providers, window_doc_sections::doc_sections()) +} + +fn print_docs( + providers: Vec>, + doc_sections: Vec, +) -> Result { + let mut docs = "".to_string(); + + // Ensure that all providers have documentation + let mut providers_with_no_docs = HashSet::new(); + + // doc sections only includes sections that have 'include' == true + for doc_section in doc_sections { + // make sure there is at least one function that is in this doc section + if !&providers.iter().any(|f| { + if let Some(documentation) = f.get_documentation() { + documentation.doc_section == doc_section + } else { + false + } + }) { + continue; + } + + // filter out functions that are not in this doc section + let providers: Vec<&Box> = providers + .iter() + .filter(|&f| { + if let Some(documentation) = f.get_documentation() { + documentation.doc_section == doc_section + } else { + providers_with_no_docs.insert(f.get_name()); + false + } + }) + .collect::>(); + + // write out section header + let _ = writeln!(docs, "\n## {} \n", doc_section.label); + + if let Some(description) = doc_section.description { + let _ = writeln!(docs, "{description}"); + } + + // names is a sorted list of function names and aliases since we display + // both in the documentation + let names = get_names_and_aliases(&providers); + + // write out the list of function names and aliases + names.iter().for_each(|name| { + let _ = writeln!(docs, "- [{name}](#{name})"); + }); + + // write out each function and alias in the order of the sorted name list + for name in names { + let f = providers + .iter() + .find(|f| f.get_name() == name || f.get_aliases().contains(&name)) + .unwrap(); + + let aliases = f.get_aliases(); + let documentation = f.get_documentation(); + + // if this name is an alias we need to display what it's an alias of + if aliases.contains(&name) { + let fname = f.get_name(); + let _ = writeln!(docs, r#"### `{name}`"#); + let _ = writeln!(docs, "_Alias of [{fname}](#{fname})._"); + continue; + } + + // otherwise display the documentation for the function + let Some(documentation) = documentation else { + unreachable!() + }; + + // first, the name, description and syntax example + let _ = write!( + docs, + r#" +### `{}` + +{} + +``` +{} +``` +"#, + name, documentation.description, documentation.syntax_example + ); + + // next, arguments + if let Some(args) = &documentation.arguments { + let _ = writeln!(docs, "#### Arguments\n"); + for (arg_name, arg_desc) in args { + let _ = writeln!(docs, "- **{arg_name}**: {arg_desc}"); + } + } + + // next, sql example if provided + if let Some(example) = &documentation.sql_example { + let _ = writeln!( + docs, + r#" +#### Example + +{} +"#, + example + ); + } + + if let Some(alt_syntax) = &documentation.alternative_syntax { + let _ = writeln!(docs, "#### Alternative Syntax\n"); + for syntax in alt_syntax { + let _ = writeln!(docs, "```sql\n{}\n```", syntax); + } + } + + // next, aliases + if !f.get_aliases().is_empty() { + let _ = writeln!(docs, "#### Aliases"); + + for alias in f.get_aliases() { + let _ = writeln!(docs, "- {}", alias.replace("_", r#"\_"#)); + } + } + + // finally, any related udfs + if let Some(related_udfs) = &documentation.related_udfs { + let _ = writeln!(docs, "\n**Related functions**:"); + + for related in related_udfs { + let _ = writeln!(docs, "- [{related}](#{related})"); + } + } + } + } + + // If there are any functions that do not have documentation, print them out + // eventually make this an error: https://github.com/apache/datafusion/issues/12872 + if !providers_with_no_docs.is_empty() { + eprintln!("INFO: The following functions do not have documentation:"); + for f in &providers_with_no_docs { + eprintln!(" - {f}"); + } + not_impl_err!("Some functions do not have documentation. Please implement `documentation` for: {providers_with_no_docs:?}") + } else { + Ok(docs) + } +} + +/// Trait for accessing name / aliases / documentation for differnet functions +trait DocProvider { + fn get_name(&self) -> String; + fn get_aliases(&self) -> Vec; + fn get_documentation(&self) -> Option<&Documentation>; +} + +impl DocProvider for AggregateUDF { + fn get_name(&self) -> String { + self.name().to_string() + } + fn get_aliases(&self) -> Vec { + self.aliases().iter().map(|a| a.to_string()).collect() + } + fn get_documentation(&self) -> Option<&Documentation> { + self.documentation() + } +} + +impl DocProvider for ScalarUDF { + fn get_name(&self) -> String { + self.name().to_string() + } + fn get_aliases(&self) -> Vec { + self.aliases().iter().map(|a| a.to_string()).collect() + } + fn get_documentation(&self) -> Option<&Documentation> { + self.documentation() + } +} + +impl DocProvider for WindowUDF { + fn get_name(&self) -> String { + self.name().to_string() + } + fn get_aliases(&self) -> Vec { + self.aliases().iter().map(|a| a.to_string()).collect() + } + fn get_documentation(&self) -> Option<&Documentation> { + self.documentation() + } +} + +#[allow(clippy::borrowed_box)] +#[allow(clippy::ptr_arg)] +fn get_names_and_aliases(functions: &Vec<&Box>) -> Vec { + functions + .iter() + .flat_map(|f| { + if f.get_aliases().is_empty() { + vec![f.get_name().to_string()] + } else { + let mut names = vec![f.get_name().to_string()]; + names.extend(f.get_aliases().iter().cloned()); + names + } + }) + .sorted() + .collect_vec() +} diff --git a/datafusion/core/src/catalog/schema.rs b/datafusion/core/src/catalog/schema.rs deleted file mode 100644 index 8249c3a5330f..000000000000 --- a/datafusion/core/src/catalog/schema.rs +++ /dev/null @@ -1,233 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Describes the interface and built-in implementations of schemas, -//! representing collections of named tables. - -use async_trait::async_trait; -use dashmap::DashMap; -use datafusion_common::{exec_err, DataFusionError}; -use std::any::Any; -use std::sync::Arc; - -use crate::datasource::TableProvider; -use crate::error::Result; - -/// Represents a schema, comprising a number of named tables. -/// -/// Please see [`CatalogProvider`] for details of implementing a custom catalog. -/// -/// [`CatalogProvider`]: super::CatalogProvider -#[async_trait] -pub trait SchemaProvider: Sync + Send { - /// Returns the owner of the Schema, default is None. This value is reported - /// as part of `information_tables.schemata - fn owner_name(&self) -> Option<&str> { - None - } - - /// Returns this `SchemaProvider` as [`Any`] so that it can be downcast to a - /// specific implementation. - fn as_any(&self) -> &dyn Any; - - /// Retrieves the list of available table names in this schema. - fn table_names(&self) -> Vec; - - /// Retrieves a specific table from the schema by name, if it exists, - /// otherwise returns `None`. - async fn table( - &self, - name: &str, - ) -> Result>, DataFusionError>; - - /// If supported by the implementation, adds a new table named `name` to - /// this schema. - /// - /// If a table of the same name was already registered, returns "Table - /// already exists" error. - #[allow(unused_variables)] - fn register_table( - &self, - name: String, - table: Arc, - ) -> Result>> { - exec_err!("schema provider does not support registering tables") - } - - /// If supported by the implementation, removes the `name` table from this - /// schema and returns the previously registered [`TableProvider`], if any. - /// - /// If no `name` table exists, returns Ok(None). - #[allow(unused_variables)] - fn deregister_table(&self, name: &str) -> Result>> { - exec_err!("schema provider does not support deregistering tables") - } - - /// Returns true if table exist in the schema provider, false otherwise. - fn table_exist(&self, name: &str) -> bool; -} - -/// Simple in-memory implementation of a schema. -pub struct MemorySchemaProvider { - tables: DashMap>, -} - -impl MemorySchemaProvider { - /// Instantiates a new MemorySchemaProvider with an empty collection of tables. - pub fn new() -> Self { - Self { - tables: DashMap::new(), - } - } -} - -impl Default for MemorySchemaProvider { - fn default() -> Self { - Self::new() - } -} - -#[async_trait] -impl SchemaProvider for MemorySchemaProvider { - fn as_any(&self) -> &dyn Any { - self - } - - fn table_names(&self) -> Vec { - self.tables - .iter() - .map(|table| table.key().clone()) - .collect() - } - - async fn table( - &self, - name: &str, - ) -> Result>, DataFusionError> { - Ok(self.tables.get(name).map(|table| table.value().clone())) - } - - fn register_table( - &self, - name: String, - table: Arc, - ) -> Result>> { - if self.table_exist(name.as_str()) { - return exec_err!("The table {name} already exists"); - } - Ok(self.tables.insert(name, table)) - } - - fn deregister_table(&self, name: &str) -> Result>> { - Ok(self.tables.remove(name).map(|(_, table)| table)) - } - - fn table_exist(&self, name: &str) -> bool { - self.tables.contains_key(name) - } -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use arrow::datatypes::Schema; - - use crate::assert_batches_eq; - use crate::catalog::schema::{MemorySchemaProvider, SchemaProvider}; - use crate::catalog::{CatalogProvider, MemoryCatalogProvider}; - use crate::datasource::empty::EmptyTable; - use crate::datasource::listing::{ListingTable, ListingTableConfig, ListingTableUrl}; - use crate::prelude::SessionContext; - - #[tokio::test] - async fn test_mem_provider() { - let provider = MemorySchemaProvider::new(); - let table_name = "test_table_exist"; - assert!(!provider.table_exist(table_name)); - assert!(provider.deregister_table(table_name).unwrap().is_none()); - let test_table = EmptyTable::new(Arc::new(Schema::empty())); - // register table successfully - assert!(provider - .register_table(table_name.to_string(), Arc::new(test_table)) - .unwrap() - .is_none()); - assert!(provider.table_exist(table_name)); - let other_table = EmptyTable::new(Arc::new(Schema::empty())); - let result = - provider.register_table(table_name.to_string(), Arc::new(other_table)); - assert!(result.is_err()); - } - - #[tokio::test] - async fn test_schema_register_listing_table() { - let testdata = crate::test_util::parquet_test_data(); - let testdir = if testdata.starts_with('/') { - format!("file://{testdata}") - } else { - format!("file:///{testdata}") - }; - let filename = if testdir.ends_with('/') { - format!("{}{}", testdir, "alltypes_plain.parquet") - } else { - format!("{}/{}", testdir, "alltypes_plain.parquet") - }; - - let table_path = ListingTableUrl::parse(filename).unwrap(); - - let catalog = MemoryCatalogProvider::new(); - let schema = MemorySchemaProvider::new(); - - let ctx = SessionContext::new(); - - let config = ListingTableConfig::new(table_path) - .infer(&ctx.state()) - .await - .unwrap(); - let table = ListingTable::try_new(config).unwrap(); - - schema - .register_table("alltypes_plain".to_string(), Arc::new(table)) - .unwrap(); - - catalog.register_schema("active", Arc::new(schema)).unwrap(); - ctx.register_catalog("cat", Arc::new(catalog)); - - let df = ctx - .sql("SELECT id, bool_col FROM cat.active.alltypes_plain") - .await - .unwrap(); - - let actual = df.collect().await.unwrap(); - - let expected = [ - "+----+----------+", - "| id | bool_col |", - "+----+----------+", - "| 4 | true |", - "| 5 | false |", - "| 6 | true |", - "| 7 | false |", - "| 2 | true |", - "| 3 | false |", - "| 0 | true |", - "| 1 | false |", - "+----+----------+", - ]; - assert_batches_eq!(expected, &actual); - } -} diff --git a/datafusion/core/src/catalog/information_schema.rs b/datafusion/core/src/catalog_common/information_schema.rs similarity index 93% rename from datafusion/core/src/catalog/information_schema.rs rename to datafusion/core/src/catalog_common/information_schema.rs index cd8f7649534f..180994b1cbe8 100644 --- a/datafusion/core/src/catalog/information_schema.rs +++ b/datafusion/core/src/catalog_common/information_schema.rs @@ -15,22 +15,22 @@ // specific language governing permissions and limitations // under the License. -//! Implements the SQL [Information Schema] for DataFusion. +//! [`InformationSchemaProvider`] that implements the SQL [Information Schema] for DataFusion. //! //! [Information Schema]: https://en.wikipedia.org/wiki/Information_schema -use async_trait::async_trait; -use datafusion_common::DataFusionError; -use std::{any::Any, sync::Arc}; - use arrow::{ array::{StringBuilder, UInt64Builder}, datatypes::{DataType, Field, Schema, SchemaRef}, record_batch::RecordBatch, }; +use async_trait::async_trait; +use datafusion_common::DataFusionError; +use std::fmt::Debug; +use std::{any::Any, sync::Arc}; +use crate::catalog::{CatalogProviderList, SchemaProvider, TableProvider}; use crate::datasource::streaming::StreamingTable; -use crate::datasource::TableProvider; use crate::execution::context::TaskContext; use crate::logical_expr::TableType; use crate::physical_plan::stream::RecordBatchStreamAdapter; @@ -40,8 +40,6 @@ use crate::{ physical_plan::streaming::PartitionStream, }; -use super::{schema::SchemaProvider, CatalogProviderList}; - pub(crate) const INFORMATION_SCHEMA: &str = "information_schema"; pub(crate) const TABLES: &str = "tables"; pub(crate) const VIEWS: &str = "views"; @@ -59,6 +57,7 @@ pub const INFORMATION_SCHEMA_TABLES: &[&str] = /// demand. This means that if more tables are added to the underlying /// providers, they will appear the next time the `information_schema` /// table is queried. +#[derive(Debug)] pub struct InformationSchemaProvider { config: InformationSchemaConfig, } @@ -72,7 +71,7 @@ impl InformationSchemaProvider { } } -#[derive(Clone)] +#[derive(Clone, Debug)] struct InformationSchemaConfig { catalog_list: Arc, } @@ -107,26 +106,14 @@ impl InformationSchemaConfig { } // Add a final list for the information schema tables themselves - builder.add_table(&catalog_name, INFORMATION_SCHEMA, TABLES, TableType::View); - builder.add_table(&catalog_name, INFORMATION_SCHEMA, VIEWS, TableType::View); - builder.add_table( - &catalog_name, - INFORMATION_SCHEMA, - COLUMNS, - TableType::View, - ); - builder.add_table( - &catalog_name, - INFORMATION_SCHEMA, - DF_SETTINGS, - TableType::View, - ); - builder.add_table( - &catalog_name, - INFORMATION_SCHEMA, - SCHEMATA, - TableType::View, - ); + for table_name in INFORMATION_SCHEMA_TABLES { + builder.add_table( + &catalog_name, + INFORMATION_SCHEMA, + table_name, + TableType::View, + ); + } } Ok(()) @@ -225,18 +212,15 @@ impl InformationSchemaConfig { #[async_trait] impl SchemaProvider for InformationSchemaProvider { - fn as_any(&self) -> &(dyn Any + 'static) { + fn as_any(&self) -> &dyn Any { self } fn table_names(&self) -> Vec { - vec![ - TABLES.to_string(), - VIEWS.to_string(), - COLUMNS.to_string(), - DF_SETTINGS.to_string(), - SCHEMATA.to_string(), - ] + INFORMATION_SCHEMA_TABLES + .iter() + .map(|t| t.to_string()) + .collect() } async fn table( @@ -244,18 +228,13 @@ impl SchemaProvider for InformationSchemaProvider { name: &str, ) -> Result>, DataFusionError> { let config = self.config.clone(); - let table: Arc = if name.eq_ignore_ascii_case("tables") { - Arc::new(InformationSchemaTables::new(config)) - } else if name.eq_ignore_ascii_case("columns") { - Arc::new(InformationSchemaColumns::new(config)) - } else if name.eq_ignore_ascii_case("views") { - Arc::new(InformationSchemaViews::new(config)) - } else if name.eq_ignore_ascii_case("df_settings") { - Arc::new(InformationSchemaDfSettings::new(config)) - } else if name.eq_ignore_ascii_case("schemata") { - Arc::new(InformationSchemata::new(config)) - } else { - return Ok(None); + let table: Arc = match name.to_ascii_lowercase().as_str() { + TABLES => Arc::new(InformationSchemaTables::new(config)), + COLUMNS => Arc::new(InformationSchemaColumns::new(config)), + VIEWS => Arc::new(InformationSchemaViews::new(config)), + DF_SETTINGS => Arc::new(InformationSchemaDfSettings::new(config)), + SCHEMATA => Arc::new(InformationSchemata::new(config)), + _ => return Ok(None), }; Ok(Some(Arc::new( @@ -264,13 +243,11 @@ impl SchemaProvider for InformationSchemaProvider { } fn table_exist(&self, name: &str) -> bool { - matches!( - name.to_ascii_lowercase().as_str(), - TABLES | VIEWS | COLUMNS | SCHEMATA - ) + INFORMATION_SCHEMA_TABLES.contains(&name.to_ascii_lowercase().as_str()) } } +#[derive(Debug)] struct InformationSchemaTables { schema: SchemaRef, config: InformationSchemaConfig, @@ -337,7 +314,7 @@ impl InformationSchemaTablesBuilder { table_name: impl AsRef, table_type: TableType, ) { - // Note: append_value is actually infallable. + // Note: append_value is actually infallible. self.catalog_names.append_value(catalog_name.as_ref()); self.schema_names.append_value(schema_name.as_ref()); self.table_names.append_value(table_name.as_ref()); @@ -362,6 +339,7 @@ impl InformationSchemaTablesBuilder { } } +#[derive(Debug)] struct InformationSchemaViews { schema: SchemaRef, config: InformationSchemaConfig, @@ -428,7 +406,7 @@ impl InformationSchemaViewBuilder { table_name: impl AsRef, definition: Option>, ) { - // Note: append_value is actually infallable. + // Note: append_value is actually infallible. self.catalog_names.append_value(catalog_name.as_ref()); self.schema_names.append_value(schema_name.as_ref()); self.table_names.append_value(table_name.as_ref()); @@ -449,6 +427,7 @@ impl InformationSchemaViewBuilder { } } +#[derive(Debug)] struct InformationSchemaColumns { schema: SchemaRef, config: InformationSchemaConfig, @@ -665,6 +644,7 @@ impl InformationSchemaColumnsBuilder { } } +#[derive(Debug)] struct InformationSchemata { schema: SchemaRef, config: InformationSchemaConfig, @@ -766,6 +746,7 @@ impl PartitionStream for InformationSchemata { } } +#[derive(Debug)] struct InformationSchemaDfSettings { schema: SchemaRef, config: InformationSchemaConfig, diff --git a/datafusion/core/src/catalog/listing_schema.rs b/datafusion/core/src/catalog_common/listing_schema.rs similarity index 93% rename from datafusion/core/src/catalog/listing_schema.rs rename to datafusion/core/src/catalog_common/listing_schema.rs index a5960b21dff5..665ea58c5f75 100644 --- a/datafusion/core/src/catalog/listing_schema.rs +++ b/datafusion/core/src/catalog_common/listing_schema.rs @@ -15,19 +15,16 @@ // specific language governing permissions and limitations // under the License. -//! listing_schema contains a SchemaProvider that scans ObjectStores for tables automatically +//! [`ListingSchemaProvider`]: [`SchemaProvider`] that scans ObjectStores for tables automatically use std::any::Any; use std::collections::{HashMap, HashSet}; use std::path::Path; use std::sync::{Arc, Mutex}; -use crate::catalog::schema::SchemaProvider; -use crate::datasource::provider::TableProviderFactory; -use crate::datasource::TableProvider; +use crate::catalog::{SchemaProvider, TableProvider, TableProviderFactory}; use crate::execution::context::SessionState; -use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{Constraints, DFSchema, DataFusionError, TableReference}; use datafusion_expr::CreateExternalTable; @@ -51,6 +48,7 @@ use object_store::ObjectStore; /// - `s3://host.example.com:3000/data/tpch/customer/_delta_log/` /// /// [`ObjectStore`]: object_store::ObjectStore +#[derive(Debug)] pub struct ListingSchemaProvider { authority: String, path: object_store::path::Path, @@ -58,7 +56,6 @@ pub struct ListingSchemaProvider { store: Arc, tables: Arc>>>, format: String, - has_header: bool, } impl ListingSchemaProvider { @@ -77,7 +74,6 @@ impl ListingSchemaProvider { factory: Arc, store: Arc, format: String, - has_header: bool, ) -> Self { Self { authority, @@ -86,7 +82,6 @@ impl ListingSchemaProvider { store, tables: Arc::new(Mutex::new(HashMap::new())), format, - has_header, } } @@ -139,12 +134,10 @@ impl ListingSchemaProvider { name, location: table_url, file_type: self.format.clone(), - has_header: self.has_header, - delimiter: ',', table_partition_cols: vec![], if_not_exists: false, + temporary: false, definition: None, - file_compression_type: CompressionTypeVariant::UNCOMPRESSED, order_exprs: vec![], unbounded: false, options: Default::default(), diff --git a/datafusion/core/src/catalog_common/memory.rs b/datafusion/core/src/catalog_common/memory.rs new file mode 100644 index 000000000000..f25146616891 --- /dev/null +++ b/datafusion/core/src/catalog_common/memory.rs @@ -0,0 +1,355 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`MemoryCatalogProvider`], [`MemoryCatalogProviderList`]: In-memory +//! implementations of [`CatalogProviderList`] and [`CatalogProvider`]. + +use crate::catalog::{ + CatalogProvider, CatalogProviderList, SchemaProvider, TableProvider, +}; +use async_trait::async_trait; +use dashmap::DashMap; +use datafusion_common::{exec_err, DataFusionError}; +use std::any::Any; +use std::sync::Arc; + +/// Simple in-memory list of catalogs +#[derive(Debug)] +pub struct MemoryCatalogProviderList { + /// Collection of catalogs containing schemas and ultimately TableProviders + pub catalogs: DashMap>, +} + +impl MemoryCatalogProviderList { + /// Instantiates a new `MemoryCatalogProviderList` with an empty collection of catalogs + pub fn new() -> Self { + Self { + catalogs: DashMap::new(), + } + } +} + +impl Default for MemoryCatalogProviderList { + fn default() -> Self { + Self::new() + } +} + +impl CatalogProviderList for MemoryCatalogProviderList { + fn as_any(&self) -> &dyn Any { + self + } + + fn register_catalog( + &self, + name: String, + catalog: Arc, + ) -> Option> { + self.catalogs.insert(name, catalog) + } + + fn catalog_names(&self) -> Vec { + self.catalogs.iter().map(|c| c.key().clone()).collect() + } + + fn catalog(&self, name: &str) -> Option> { + self.catalogs.get(name).map(|c| c.value().clone()) + } +} + +/// Simple in-memory implementation of a catalog. +#[derive(Debug)] +pub struct MemoryCatalogProvider { + schemas: DashMap>, +} + +impl MemoryCatalogProvider { + /// Instantiates a new MemoryCatalogProvider with an empty collection of schemas. + pub fn new() -> Self { + Self { + schemas: DashMap::new(), + } + } +} + +impl Default for MemoryCatalogProvider { + fn default() -> Self { + Self::new() + } +} + +impl CatalogProvider for MemoryCatalogProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema_names(&self) -> Vec { + self.schemas.iter().map(|s| s.key().clone()).collect() + } + + fn schema(&self, name: &str) -> Option> { + self.schemas.get(name).map(|s| s.value().clone()) + } + + fn register_schema( + &self, + name: &str, + schema: Arc, + ) -> datafusion_common::Result>> { + Ok(self.schemas.insert(name.into(), schema)) + } + + fn deregister_schema( + &self, + name: &str, + cascade: bool, + ) -> datafusion_common::Result>> { + if let Some(schema) = self.schema(name) { + let table_names = schema.table_names(); + match (table_names.is_empty(), cascade) { + (true, _) | (false, true) => { + let (_, removed) = self.schemas.remove(name).unwrap(); + Ok(Some(removed)) + } + (false, false) => exec_err!( + "Cannot drop schema {} because other tables depend on it: {}", + name, + itertools::join(table_names.iter(), ", ") + ), + } + } else { + Ok(None) + } + } +} + +/// Simple in-memory implementation of a schema. +#[derive(Debug)] +pub struct MemorySchemaProvider { + tables: DashMap>, +} + +impl MemorySchemaProvider { + /// Instantiates a new MemorySchemaProvider with an empty collection of tables. + pub fn new() -> Self { + Self { + tables: DashMap::new(), + } + } +} + +impl Default for MemorySchemaProvider { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl SchemaProvider for MemorySchemaProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn table_names(&self) -> Vec { + self.tables + .iter() + .map(|table| table.key().clone()) + .collect() + } + + async fn table( + &self, + name: &str, + ) -> datafusion_common::Result>, DataFusionError> { + Ok(self.tables.get(name).map(|table| table.value().clone())) + } + + fn register_table( + &self, + name: String, + table: Arc, + ) -> datafusion_common::Result>> { + if self.table_exist(name.as_str()) { + return exec_err!("The table {name} already exists"); + } + Ok(self.tables.insert(name, table)) + } + + fn deregister_table( + &self, + name: &str, + ) -> datafusion_common::Result>> { + Ok(self.tables.remove(name).map(|(_, table)| table)) + } + + fn table_exist(&self, name: &str) -> bool { + self.tables.contains_key(name) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::catalog::CatalogProvider; + use crate::catalog_common::memory::MemorySchemaProvider; + use crate::datasource::empty::EmptyTable; + use crate::datasource::listing::{ListingTable, ListingTableConfig, ListingTableUrl}; + use crate::prelude::SessionContext; + use arrow_schema::Schema; + use datafusion_common::assert_batches_eq; + use std::any::Any; + use std::sync::Arc; + + #[test] + fn memory_catalog_dereg_nonempty_schema() { + let cat = Arc::new(MemoryCatalogProvider::new()) as Arc; + + let schema = Arc::new(MemorySchemaProvider::new()) as Arc; + let test_table = Arc::new(EmptyTable::new(Arc::new(Schema::empty()))) + as Arc; + schema.register_table("t".into(), test_table).unwrap(); + + cat.register_schema("foo", schema.clone()).unwrap(); + + assert!( + cat.deregister_schema("foo", false).is_err(), + "dropping empty schema without cascade should error" + ); + assert!(cat.deregister_schema("foo", true).unwrap().is_some()); + } + + #[test] + fn memory_catalog_dereg_empty_schema() { + let cat = Arc::new(MemoryCatalogProvider::new()) as Arc; + + let schema = Arc::new(MemorySchemaProvider::new()) as Arc; + cat.register_schema("foo", schema).unwrap(); + + assert!(cat.deregister_schema("foo", false).unwrap().is_some()); + } + + #[test] + fn memory_catalog_dereg_missing() { + let cat = Arc::new(MemoryCatalogProvider::new()) as Arc; + assert!(cat.deregister_schema("foo", false).unwrap().is_none()); + } + + #[test] + fn default_register_schema_not_supported() { + // mimic a new CatalogProvider and ensure it does not support registering schemas + #[derive(Debug)] + struct TestProvider {} + impl CatalogProvider for TestProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema_names(&self) -> Vec { + unimplemented!() + } + + fn schema(&self, _name: &str) -> Option> { + unimplemented!() + } + } + + let schema = Arc::new(MemorySchemaProvider::new()) as Arc; + let catalog = Arc::new(TestProvider {}); + + match catalog.register_schema("foo", schema) { + Ok(_) => panic!("unexpected OK"), + Err(e) => assert_eq!(e.strip_backtrace(), "This feature is not implemented: Registering new schemas is not supported"), + }; + } + + #[tokio::test] + async fn test_mem_provider() { + let provider = MemorySchemaProvider::new(); + let table_name = "test_table_exist"; + assert!(!provider.table_exist(table_name)); + assert!(provider.deregister_table(table_name).unwrap().is_none()); + let test_table = EmptyTable::new(Arc::new(Schema::empty())); + // register table successfully + assert!(provider + .register_table(table_name.to_string(), Arc::new(test_table)) + .unwrap() + .is_none()); + assert!(provider.table_exist(table_name)); + let other_table = EmptyTable::new(Arc::new(Schema::empty())); + let result = + provider.register_table(table_name.to_string(), Arc::new(other_table)); + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_schema_register_listing_table() { + let testdata = crate::test_util::parquet_test_data(); + let testdir = if testdata.starts_with('/') { + format!("file://{testdata}") + } else { + format!("file:///{testdata}") + }; + let filename = if testdir.ends_with('/') { + format!("{}{}", testdir, "alltypes_plain.parquet") + } else { + format!("{}/{}", testdir, "alltypes_plain.parquet") + }; + + let table_path = ListingTableUrl::parse(filename).unwrap(); + + let catalog = MemoryCatalogProvider::new(); + let schema = MemorySchemaProvider::new(); + + let ctx = SessionContext::new(); + + let config = ListingTableConfig::new(table_path) + .infer(&ctx.state()) + .await + .unwrap(); + let table = ListingTable::try_new(config).unwrap(); + + schema + .register_table("alltypes_plain".to_string(), Arc::new(table)) + .unwrap(); + + catalog.register_schema("active", Arc::new(schema)).unwrap(); + ctx.register_catalog("cat", Arc::new(catalog)); + + let df = ctx + .sql("SELECT id, bool_col FROM cat.active.alltypes_plain") + .await + .unwrap(); + + let actual = df.collect().await.unwrap(); + + let expected = [ + "+----+----------+", + "| id | bool_col |", + "+----+----------+", + "| 4 | true |", + "| 5 | false |", + "| 6 | true |", + "| 7 | false |", + "| 2 | true |", + "| 3 | false |", + "| 0 | true |", + "| 1 | false |", + "+----+----------+", + ]; + assert_batches_eq!(expected, &actual); + } +} diff --git a/datafusion/core/src/catalog_common/mod.rs b/datafusion/core/src/catalog_common/mod.rs new file mode 100644 index 000000000000..68c78dda4899 --- /dev/null +++ b/datafusion/core/src/catalog_common/mod.rs @@ -0,0 +1,273 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Interfaces and default implementations of catalogs and schemas. +//! +//! Implementations +//! * Simple memory based catalog: [`MemoryCatalogProviderList`], [`MemoryCatalogProvider`], [`MemorySchemaProvider`] +//! * Information schema: [`information_schema`] +//! * Listing schema: [`listing_schema`] + +pub mod information_schema; +pub mod listing_schema; +pub mod memory; + +pub use crate::catalog::{CatalogProvider, CatalogProviderList, SchemaProvider}; +pub use memory::{ + MemoryCatalogProvider, MemoryCatalogProviderList, MemorySchemaProvider, +}; + +pub use datafusion_sql::{ResolvedTableReference, TableReference}; + +use std::collections::BTreeSet; +use std::ops::ControlFlow; + +/// Collects all tables and views referenced in the SQL statement. CTEs are collected separately. +/// This can be used to determine which tables need to be in the catalog for a query to be planned. +/// +/// # Returns +/// +/// A `(table_refs, ctes)` tuple, the first element contains table and view references and the second +/// element contains any CTE aliases that were defined and possibly referenced. +/// +/// ## Example +/// +/// ``` +/// # use datafusion_sql::parser::DFParser; +/// # use datafusion::catalog_common::resolve_table_references; +/// let query = "SELECT a FROM foo where x IN (SELECT y FROM bar)"; +/// let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap(); +/// let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap(); +/// assert_eq!(table_refs.len(), 2); +/// assert_eq!(table_refs[0].to_string(), "bar"); +/// assert_eq!(table_refs[1].to_string(), "foo"); +/// assert_eq!(ctes.len(), 0); +/// ``` +/// +/// ## Example with CTEs +/// +/// ``` +/// # use datafusion_sql::parser::DFParser; +/// # use datafusion::catalog_common::resolve_table_references; +/// let query = "with my_cte as (values (1), (2)) SELECT * from my_cte;"; +/// let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap(); +/// let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap(); +/// assert_eq!(table_refs.len(), 0); +/// assert_eq!(ctes.len(), 1); +/// assert_eq!(ctes[0].to_string(), "my_cte"); +/// ``` +pub fn resolve_table_references( + statement: &datafusion_sql::parser::Statement, + enable_ident_normalization: bool, +) -> datafusion_common::Result<(Vec, Vec)> { + use crate::sql::planner::object_name_to_table_reference; + use datafusion_sql::parser::{ + CopyToSource, CopyToStatement, Statement as DFStatement, + }; + use information_schema::INFORMATION_SCHEMA; + use information_schema::INFORMATION_SCHEMA_TABLES; + use sqlparser::ast::*; + + struct RelationVisitor { + relations: BTreeSet, + all_ctes: BTreeSet, + ctes_in_scope: Vec, + } + + impl RelationVisitor { + /// Record the reference to `relation`, if it's not a CTE reference. + fn insert_relation(&mut self, relation: &ObjectName) { + if !self.relations.contains(relation) + && !self.ctes_in_scope.contains(relation) + { + self.relations.insert(relation.clone()); + } + } + } + + impl Visitor for RelationVisitor { + type Break = (); + + fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<()> { + self.insert_relation(relation); + ControlFlow::Continue(()) + } + + fn pre_visit_query(&mut self, q: &Query) -> ControlFlow { + if let Some(with) = &q.with { + for cte in &with.cte_tables { + // The non-recursive CTE name is not in scope when evaluating the CTE itself, so this is valid: + // `WITH t AS (SELECT * FROM t) SELECT * FROM t` + // Where the first `t` refers to a predefined table. So we are careful here + // to visit the CTE first, before putting it in scope. + if !with.recursive { + // This is a bit hackish as the CTE will be visited again as part of visiting `q`, + // but thankfully `insert_relation` is idempotent. + cte.visit(self); + } + self.ctes_in_scope + .push(ObjectName(vec![cte.alias.name.clone()])); + } + } + ControlFlow::Continue(()) + } + + fn post_visit_query(&mut self, q: &Query) -> ControlFlow { + if let Some(with) = &q.with { + for _ in &with.cte_tables { + // Unwrap: We just pushed these in `pre_visit_query` + self.all_ctes.insert(self.ctes_in_scope.pop().unwrap()); + } + } + ControlFlow::Continue(()) + } + + fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<()> { + if let Statement::ShowCreate { + obj_type: ShowCreateObject::Table | ShowCreateObject::View, + obj_name, + } = statement + { + self.insert_relation(obj_name) + } + + // SHOW statements will later be rewritten into a SELECT from the information_schema + let requires_information_schema = matches!( + statement, + Statement::ShowFunctions { .. } + | Statement::ShowVariable { .. } + | Statement::ShowStatus { .. } + | Statement::ShowVariables { .. } + | Statement::ShowCreate { .. } + | Statement::ShowColumns { .. } + | Statement::ShowTables { .. } + | Statement::ShowCollation { .. } + ); + if requires_information_schema { + for s in INFORMATION_SCHEMA_TABLES { + self.relations.insert(ObjectName(vec![ + Ident::new(INFORMATION_SCHEMA), + Ident::new(*s), + ])); + } + } + ControlFlow::Continue(()) + } + } + + let mut visitor = RelationVisitor { + relations: BTreeSet::new(), + all_ctes: BTreeSet::new(), + ctes_in_scope: vec![], + }; + + fn visit_statement(statement: &DFStatement, visitor: &mut RelationVisitor) { + match statement { + DFStatement::Statement(s) => { + let _ = s.as_ref().visit(visitor); + } + DFStatement::CreateExternalTable(table) => { + visitor.relations.insert(table.name.clone()); + } + DFStatement::CopyTo(CopyToStatement { source, .. }) => match source { + CopyToSource::Relation(table_name) => { + visitor.insert_relation(table_name); + } + CopyToSource::Query(query) => { + query.visit(visitor); + } + }, + DFStatement::Explain(explain) => visit_statement(&explain.statement, visitor), + } + } + + visit_statement(statement, &mut visitor); + + let table_refs = visitor + .relations + .into_iter() + .map(|x| object_name_to_table_reference(x, enable_ident_normalization)) + .collect::>()?; + let ctes = visitor + .all_ctes + .into_iter() + .map(|x| object_name_to_table_reference(x, enable_ident_normalization)) + .collect::>()?; + Ok((table_refs, ctes)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn resolve_table_references_shadowed_cte() { + use datafusion_sql::parser::DFParser; + + // An interesting edge case where the `t` name is used both as an ordinary table reference + // and as a CTE reference. + let query = "WITH t AS (SELECT * FROM t) SELECT * FROM t"; + let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap(); + let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap(); + assert_eq!(table_refs.len(), 1); + assert_eq!(ctes.len(), 1); + assert_eq!(ctes[0].to_string(), "t"); + assert_eq!(table_refs[0].to_string(), "t"); + + // UNION is a special case where the CTE is not in scope for the second branch. + let query = "(with t as (select 1) select * from t) union (select * from t)"; + let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap(); + let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap(); + assert_eq!(table_refs.len(), 1); + assert_eq!(ctes.len(), 1); + assert_eq!(ctes[0].to_string(), "t"); + assert_eq!(table_refs[0].to_string(), "t"); + + // Nested CTEs are also handled. + // Here the first `u` is a CTE, but the second `u` is a table reference. + // While `t` is always a CTE. + let query = "(with t as (with u as (select 1) select * from u) select * from u cross join t)"; + let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap(); + let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap(); + assert_eq!(table_refs.len(), 1); + assert_eq!(ctes.len(), 2); + assert_eq!(ctes[0].to_string(), "t"); + assert_eq!(ctes[1].to_string(), "u"); + assert_eq!(table_refs[0].to_string(), "u"); + } + + #[test] + fn resolve_table_references_recursive_cte() { + use datafusion_sql::parser::DFParser; + + let query = " + WITH RECURSIVE nodes AS ( + SELECT 1 as id + UNION ALL + SELECT id + 1 as id + FROM nodes + WHERE id < 10 + ) + SELECT * FROM nodes + "; + let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap(); + let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap(); + assert_eq!(table_refs.len(), 0); + assert_eq!(ctes.len(), 1); + assert_eq!(ctes[0].to_string(), "nodes"); + } +} diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index f877b7d698b4..2c71cb80d755 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -21,11 +21,15 @@ mod parquet; use std::any::Any; +use std::borrow::Cow; use std::collections::HashMap; use std::sync::Arc; use crate::arrow::record_batch::RecordBatch; use crate::arrow::util::pretty; +use crate::datasource::file_format::csv::CsvFormatFactory; +use crate::datasource::file_format::format_as_file_type; +use crate::datasource::file_format::json::JsonFormatFactory; use crate::datasource::{provider_as_source, MemTable, TableProvider}; use crate::error::Result; use crate::execution::context::{SessionState, TaskContext}; @@ -44,22 +48,28 @@ use arrow::array::{Array, ArrayRef, Int64Array, StringArray}; use arrow::compute::{cast, concat}; use arrow::datatypes::{DataType, Field}; use arrow_schema::{Schema, SchemaRef}; -use datafusion_common::config::{CsvOptions, FormatOptions, JsonOptions}; +use datafusion_common::config::{CsvOptions, JsonOptions}; use datafusion_common::{ plan_err, Column, DFSchema, DataFusionError, ParamValues, SchemaError, UnnestOptions, }; +use datafusion_expr::dml::InsertOp; +use datafusion_expr::{case, is_null, lit, SortExpr}; use datafusion_expr::{ - avg, count, is_null, max, median, min, stddev, utils::COUNT_STAR_EXPANSION, - TableProviderFilterPushDown, UNNAMED_TABLE, + utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, +}; +use datafusion_functions_aggregate::expr_fn::{ + avg, count, max, median, min, stddev, sum, }; use async_trait::async_trait; +use datafusion_catalog::Session; /// Contains options that control how data is /// written out from a DataFrame pub struct DataFrameWriteOptions { - /// Controls if existing data should be overwritten - overwrite: bool, + /// Controls how new data should be written to the table, determining whether + /// to append, overwrite, or replace existing data. + insert_op: InsertOp, /// Controls if all partitions should be coalesced into a single output file /// Generally will have slower performance when set to true. single_file_output: bool, @@ -72,14 +82,15 @@ impl DataFrameWriteOptions { /// Create a new DataFrameWriteOptions with default values pub fn new() -> Self { DataFrameWriteOptions { - overwrite: false, + insert_op: InsertOp::Append, single_file_output: false, partition_by: vec![], } } - /// Set the overwrite option to true or false - pub fn with_overwrite(mut self, overwrite: bool) -> Self { - self.overwrite = overwrite; + + /// Set the insert operation + pub fn with_insert_operation(mut self, insert_op: InsertOp) -> Self { + self.insert_op = insert_op; self } @@ -110,15 +121,15 @@ impl Default for DataFrameWriteOptions { /// The typical workflow using DataFrames looks like /// /// 1. Create a DataFrame via methods on [SessionContext], such as [`read_csv`] -/// and [`read_parquet`]. +/// and [`read_parquet`]. /// /// 2. Build a desired calculation by calling methods such as [`filter`], -/// [`select`], [`aggregate`], and [`limit`] +/// [`select`], [`aggregate`], and [`limit`] /// /// 3. Execute into [`RecordBatch`]es by calling [`collect`] /// /// A `DataFrame` is a wrapper around a [`LogicalPlan`] and the [`SessionState`] -/// required for execution. +/// required for execution. /// /// DataFrames are "lazy" in the sense that most methods do not actually compute /// anything, they just build up a plan. Calling [`collect`] executes the plan @@ -139,6 +150,7 @@ impl Default for DataFrameWriteOptions { /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; +/// # use datafusion::functions_aggregate::expr_fn::min; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); @@ -174,6 +186,33 @@ impl DataFrame { } } + /// Creates logical expression from a SQL query text. + /// The expression is created and processed against the current schema. + /// + /// # Example: Parsing SQL queries + /// ``` + /// # use arrow::datatypes::{DataType, Field, Schema}; + /// # use datafusion::prelude::*; + /// # use datafusion_common::{DFSchema, Result}; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// // datafusion will parse number as i64 first. + /// let sql = "a > 1 and b in (1, 10)"; + /// let expected = col("a").gt(lit(1 as i64)) + /// .and(col("b").in_list(vec![lit(1 as i64), lit(10 as i64)], false)); + /// let ctx = SessionContext::new(); + /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let expr = df.parse_sql_expr(sql)?; + /// assert_eq!(expected, expr); + /// # Ok(()) + /// # } + /// ``` + pub fn parse_sql_expr(&self, sql: &str) -> Result { + let df_schema = self.schema(); + + self.session_state.create_logical_expr(sql, df_schema) + } + /// Consume the DataFrame and produce a physical plan pub async fn create_physical_plan(self) -> Result> { self.session_state.create_physical_plan(&self.plan).await @@ -185,11 +224,20 @@ impl DataFrame { /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; + /// # use datafusion_common::assert_batches_sorted_eq; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; /// let df = df.select_columns(&["a", "b"])?; + /// let expected = vec![ + /// "+---+---+", + /// "| a | b |", + /// "+---+---+", + /// "| 1 | 2 |", + /// "+---+---+" + /// ]; + /// # assert_batches_sorted_eq!(expected, &df.collect().await?); /// # Ok(()) /// # } /// ``` @@ -208,6 +256,31 @@ impl DataFrame { .collect(); self.select(expr) } + /// Project arbitrary list of expression strings into a new `DataFrame`. + /// Method will parse string expressions into logical plan expressions. + /// + /// The output `DataFrame` has one column for each element in `exprs`. + /// + /// # Example + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let ctx = SessionContext::new(); + /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let df : DataFrame = df.select_exprs(&["a * b", "c"])?; + /// # Ok(()) + /// # } + /// ``` + pub fn select_exprs(self, exprs: &[&str]) -> Result { + let expr_list = exprs + .iter() + .map(|e| self.parse_sql_expr(e)) + .collect::>>()?; + + self.select(expr_list) + } /// Project arbitrary expressions (like SQL SELECT expressions) into a new /// `DataFrame`. @@ -218,11 +291,20 @@ impl DataFrame { /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; + /// # use datafusion_common::assert_batches_sorted_eq; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; - /// let df = df.select(vec![col("a") * col("b"), col("c")])?; + /// let df = df.select(vec![col("a"), col("b") * col("c")])?; + /// let expected = vec![ + /// "+---+-----------------------+", + /// "| a | ?table?.b * ?table?.c |", + /// "+---+-----------------------+", + /// "| 1 | 6 |", + /// "+---+-----------------------+" + /// ]; + /// # assert_batches_sorted_eq!(expected, &df.collect().await?); /// # Ok(()) /// # } /// ``` @@ -241,42 +323,84 @@ impl DataFrame { }) } - /// Expand each list element of a column to multiple rows. - #[deprecated(since = "37.0.0", note = "use unnest_columns instead")] - pub fn unnest_column(self, column: &str) -> Result { - self.unnest_columns(&[column]) - } - - /// Expand each list element of a column to multiple rows, with - /// behavior controlled by [`UnnestOptions`]. + /// Returns a new DataFrame containing all columns except the specified columns. /// - /// Please see the documentation on [`UnnestOptions`] for more - /// details about the meaning of unnest. - #[deprecated(since = "37.0.0", note = "use unnest_columns_with_options instead")] - pub fn unnest_column_with_options( - self, - column: &str, - options: UnnestOptions, - ) -> Result { - self.unnest_columns_with_options(&[column], options) + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # use datafusion_common::assert_batches_sorted_eq; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let ctx = SessionContext::new(); + /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// // +----+----+----+ + /// // | a | b | c | + /// // +----+----+----+ + /// // | 1 | 2 | 3 | + /// // +----+----+----+ + /// let df = df.drop_columns(&["a"])?; + /// let expected = vec![ + /// "+---+---+", + /// "| b | c |", + /// "+---+---+", + /// "| 2 | 3 |", + /// "+---+---+" + /// ]; + /// # assert_batches_sorted_eq!(expected, &df.collect().await?); + /// # Ok(()) + /// # } + /// ``` + pub fn drop_columns(self, columns: &[&str]) -> Result { + let fields_to_drop = columns + .iter() + .map(|name| { + self.plan + .schema() + .qualified_field_with_unqualified_name(name) + }) + .filter(|r| r.is_ok()) + .collect::>>()?; + let expr: Vec = self + .plan + .schema() + .fields() + .into_iter() + .enumerate() + .map(|(idx, _)| self.plan.schema().qualified_field(idx)) + .filter(|(qualifier, f)| !fields_to_drop.contains(&(*qualifier, f))) + .map(|(qualifier, field)| Expr::Column(Column::from((qualifier, field)))) + .collect(); + self.select(expr) } - /// Expand multiple list columns into a set of rows. - /// - /// See also: + /// Expand multiple list/struct columns into a set of rows and new columns. /// - /// 1. [`UnnestOptions`] documentation for the behavior of `unnest` - /// 2. [`Self::unnest_column_with_options`] + /// See also: [`UnnestOptions`] documentation for the behavior of `unnest` /// /// # Example /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; + /// # use datafusion_common::assert_batches_sorted_eq; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; - /// let df = df.unnest_columns(&["a", "b"])?; + /// let df = ctx.read_json("tests/data/unnest.json", NdJsonReadOptions::default()).await?; + /// // expand into multiple columns if it's json array, flatten field name if it's nested structure + /// let df = df.unnest_columns(&["b","c","d"])?; + /// let expected = vec![ + /// "+---+------+-------+-----+-----+", + /// "| a | b | c | d.e | d.f |", + /// "+---+------+-------+-----+-----+", + /// "| 1 | 2.0 | false | 1 | 2 |", + /// "| 1 | 1.3 | true | 1 | 2 |", + /// "| 1 | -6.1 | | 1 | 2 |", + /// "| 2 | 3.0 | false | | |", + /// "| 2 | 2.3 | true | | |", + /// "| 2 | -7.1 | | | |", + /// "+---+------+-------+-----+-----+" + /// ]; + /// # assert_batches_sorted_eq!(expected, &df.collect().await?); /// # Ok(()) /// # } /// ``` @@ -314,11 +438,23 @@ impl DataFrame { /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; + /// # use datafusion_common::assert_batches_sorted_eq; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let df = ctx.read_csv("tests/data/example_long.csv", CsvReadOptions::new()).await?; /// let df = df.filter(col("a").lt_eq(col("b")))?; + /// // all rows where a <= b are returned + /// let expected = vec![ + /// "+---+---+---+", + /// "| a | b | c |", + /// "+---+---+---+", + /// "| 1 | 2 | 3 |", + /// "| 4 | 5 | 6 |", + /// "| 7 | 8 | 9 |", + /// "+---+---+---+" + /// ]; + /// # assert_batches_sorted_eq!(expected, &df.collect().await?); /// # Ok(()) /// # } /// ``` @@ -339,16 +475,35 @@ impl DataFrame { /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; + /// # use datafusion::functions_aggregate::expr_fn::min; + /// # use datafusion_common::assert_batches_sorted_eq; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let df = ctx.read_csv("tests/data/example_long.csv", CsvReadOptions::new()).await?; /// /// // The following use is the equivalent of "SELECT MIN(b) GROUP BY a" - /// let _ = df.clone().aggregate(vec![col("a")], vec![min(col("b"))])?; - /// + /// let df1 = df.clone().aggregate(vec![col("a")], vec![min(col("b"))])?; + /// let expected1 = vec![ + /// "+---+----------------+", + /// "| a | min(?table?.b) |", + /// "+---+----------------+", + /// "| 1 | 2 |", + /// "| 4 | 5 |", + /// "| 7 | 8 |", + /// "+---+----------------+" + /// ]; + /// assert_batches_sorted_eq!(expected1, &df1.collect().await?); /// // The following use is the equivalent of "SELECT MIN(b)" - /// let _ = df.aggregate(vec![], vec![min(col("b"))])?; + /// let df2 = df.aggregate(vec![], vec![min(col("b"))])?; + /// let expected2 = vec![ + /// "+----------------+", + /// "| min(?table?.b) |", + /// "+----------------+", + /// "| 2 |", + /// "+----------------+" + /// ]; + /// # assert_batches_sorted_eq!(expected2, &df2.collect().await?); /// # Ok(()) /// # } /// ``` @@ -357,9 +512,26 @@ impl DataFrame { group_expr: Vec, aggr_expr: Vec, ) -> Result { + let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]); + let aggr_expr_len = aggr_expr.len(); let plan = LogicalPlanBuilder::from(self.plan) .aggregate(group_expr, aggr_expr)? .build()?; + let plan = if is_grouping_set { + let grouping_id_pos = plan.schema().fields().len() - 1 - aggr_expr_len; + // For grouping sets we do a project to not expose the internal grouping id + let exprs = plan + .schema() + .columns() + .into_iter() + .enumerate() + .filter(|(idx, _)| *idx != grouping_id_pos) + .map(|(_, column)| Expr::Column(column)) + .collect::>(); + LogicalPlanBuilder::from(plan).project(exprs)?.build()? + } else { + plan + }; Ok(DataFrame { session_state: self.session_state, plan, @@ -388,11 +560,21 @@ impl DataFrame { /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; + /// # use datafusion_common::assert_batches_sorted_eq; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; - /// let df = df.limit(0, Some(100))?; + /// let df = ctx.read_csv("tests/data/example_long.csv", CsvReadOptions::new()).await?; + /// let df = df.limit(1, Some(2))?; + /// let expected = vec![ + /// "+---+---+---+", + /// "| a | b | c |", + /// "+---+---+---+", + /// "| 4 | 5 | 6 |", + /// "| 7 | 8 | 9 |", + /// "+---+---+---+" + /// ]; + /// # assert_batches_sorted_eq!(expected, &df.collect().await?); /// # Ok(()) /// # } /// ``` @@ -414,12 +596,22 @@ impl DataFrame { /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; + /// # use datafusion_common::assert_batches_sorted_eq; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await? ; /// let d2 = df.clone(); /// let df = df.union(d2)?; + /// let expected = vec![ + /// "+---+---+---+", + /// "| a | b | c |", + /// "+---+---+---+", + /// "| 1 | 2 | 3 |", + /// "| 1 | 2 | 3 |", + /// "+---+---+---+" + /// ]; + /// # assert_batches_sorted_eq!(expected, &df.collect().await?); /// # Ok(()) /// # } /// ``` @@ -442,12 +634,22 @@ impl DataFrame { /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; + /// # use datafusion_common::assert_batches_sorted_eq; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; /// let d2 = df.clone(); /// let df = df.union_distinct(d2)?; + /// // df2 are duplicate of df + /// let expected = vec![ + /// "+---+---+---+", + /// "| a | b | c |", + /// "+---+---+---+", + /// "| 1 | 2 | 3 |", + /// "+---+---+---+" + /// ]; + /// # assert_batches_sorted_eq!(expected, &df.collect().await?); /// # Ok(()) /// # } /// ``` @@ -467,11 +669,20 @@ impl DataFrame { /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; + /// # use datafusion_common::assert_batches_sorted_eq; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; /// let df = df.distinct()?; + /// let expected = vec![ + /// "+---+---+---+", + /// "| a | b | c |", + /// "+---+---+---+", + /// "| 1 | 2 | 3 |", + /// "+---+---+---+" + /// ]; + /// # assert_batches_sorted_eq!(expected, &df.collect().await?); /// # Ok(()) /// # } /// ``` @@ -483,6 +694,47 @@ impl DataFrame { }) } + /// Return a new `DataFrame` with duplicated rows removed as per the specified expression list + /// according to the provided sorting expressions grouped by the `DISTINCT ON` clause + /// expressions. + /// + /// # Example + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::error::Result; + /// # use datafusion_common::assert_batches_sorted_eq; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let ctx = SessionContext::new(); + /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await? + /// // Return a single row (a, b) for each distinct value of a + /// .distinct_on(vec![col("a")], vec![col("a"), col("b")], None)?; + /// let expected = vec![ + /// "+---+---+", + /// "| a | b |", + /// "+---+---+", + /// "| 1 | 2 |", + /// "+---+---+" + /// ]; + /// # assert_batches_sorted_eq!(expected, &df.collect().await?); + /// # Ok(()) + /// # } + /// ``` + pub fn distinct_on( + self, + on_expr: Vec, + select_expr: Vec, + sort_expr: Option>, + ) -> Result { + let plan = LogicalPlanBuilder::from(self.plan) + .distinct_on(on_expr, select_expr, sort_expr)? + .build()?; + Ok(DataFrame { + session_state: self.session_state, + plan, + }) + } + /// Return a new `DataFrame` that has statistics for a DataFrame. /// /// Only summarizes numeric datatypes at the moment and returns nulls for @@ -493,12 +745,26 @@ impl DataFrame { /// # use datafusion::prelude::*; /// # use datafusion::error::Result; /// # use arrow::util::pretty; + /// # use datafusion_common::assert_batches_sorted_eq; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/tpch-csv/customer.csv", CsvReadOptions::new()).await?; - /// df.describe().await.unwrap(); - /// + /// let stat = df.describe().await?; + /// # // some output column are ignored + /// let expected = vec![ + /// "+------------+--------------------+--------------------+------------------------------------+--------------------+-----------------+--------------------+--------------+----------------------------------------------------------------------------------------------------------+", + /// "| describe | c_custkey | c_name | c_address | c_nationkey | c_phone | c_acctbal | c_mktsegment | c_comment |", + /// "+------------+--------------------+--------------------+------------------------------------+--------------------+-----------------+--------------------+--------------+----------------------------------------------------------------------------------------------------------+", + /// "| count | 9.0 | 9 | 9 | 9.0 | 9 | 9.0 | 9 | 9 |", + /// "| max | 10.0 | Customer#000000010 | xKiAFTjUsCuxfeleNqefumTrjS | 20.0 | 30-114-968-4951 | 9561.95 | MACHINERY | tions. even deposits boost according to the slyly bold packages. final accounts cajole requests. furious |", + /// "| mean | 6.0 | null | null | 9.88888888888889 | null | 5153.2155555555555 | null | null |", + /// "| median | 6.0 | null | null | 8.0 | null | 6819.74 | null | null |", + /// "| min | 2.0 | Customer#000000002 | 6LrEaV6KR6PLVcgl2ArL Q3rqzLzcT1 v2 | 1.0 | 11-719-748-3364 | 121.65 | AUTOMOBILE | deposits eat slyly ironic, even instructions. express foxes detect slyly. blithely even accounts abov |", + /// "| null_count | 0.0 | 0 | 0 | 0.0 | 0 | 0.0 | 0 | 0 |", + /// "| std | 2.7386127875258306 | null | null | 7.2188026092359046 | null | 3522.169804254585 | null | null |", + /// "+------------+--------------------+--------------------+------------------------------------+--------------------+-----------------+--------------------+--------------+----------------------------------------------------------------------------------------------------------+"]; + /// assert_batches_sorted_eq!(expected, &stat.collect().await?); /// # Ok(()) /// # } /// ``` @@ -534,7 +800,13 @@ impl DataFrame { vec![], original_schema_fields .clone() - .map(|f| count(is_null(col(f.name()))).alias(f.name())) + .map(|f| { + sum(case(is_null(col(f.name()))) + .when(lit(true), lit(1)) + .otherwise(lit(0)) + .unwrap()) + .alias(f.name()) + }) .collect::>(), ), // mean aggregation @@ -607,7 +879,10 @@ impl DataFrame { { let column = batchs[0].column_by_name(field.name()).unwrap(); - if field.data_type().is_numeric() { + + if column.data_type().is_null() { + Arc::new(StringArray::from(vec!["null"])) + } else if field.data_type().is_numeric() { cast(column, &DataType::Float64)? } else { cast(column, &DataType::Utf8)? @@ -662,6 +937,15 @@ impl DataFrame { }) } + /// Apply a sort by provided expressions with default direction + pub fn sort_by(self, expr: Vec) -> Result { + self.sort( + expr.into_iter() + .map(|e| e.sort(true, false)) + .collect::>(), + ) + } + /// Sort the DataFrame by the specified sorting expressions. /// /// Note that any expression can be turned into @@ -672,18 +956,29 @@ impl DataFrame { /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; + /// # use datafusion_common::assert_batches_sorted_eq; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let df = ctx.read_csv("tests/data/example_long.csv", CsvReadOptions::new()).await?; /// let df = df.sort(vec![ - /// col("a").sort(true, true), // a ASC, nulls first - /// col("b").sort(false, false), // b DESC, nulls last + /// col("a").sort(false, true), // a DESC, nulls first + /// col("b").sort(true, false), // b ASC, nulls last /// ])?; + /// let expected = vec![ + /// "+---+---+---+", + /// "| a | b | c |", + /// "+---+---+---+", + /// "| 1 | 2 | 3 |", + /// "| 4 | 5 | 6 |", + /// "| 7 | 8 | 9 |", + /// "+---+---+---+", + /// ]; + /// # assert_batches_sorted_eq!(expected, &df.collect().await?); /// # Ok(()) /// # } /// ``` - pub fn sort(self, expr: Vec) -> Result { + pub fn sort(self, expr: Vec) -> Result { let plan = LogicalPlanBuilder::from(self.plan).sort(expr)?.build()?; Ok(DataFrame { session_state: self.session_state, @@ -709,6 +1004,7 @@ impl DataFrame { /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; + /// # use datafusion_common::assert_batches_sorted_eq; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); @@ -721,7 +1017,14 @@ impl DataFrame { /// // Perform the equivalent of `left INNER JOIN right ON (a = a2 AND b = b2)` /// // finding all pairs of rows from `left` and `right` where `a = a2` and `b = b2`. /// let join = left.join(right, JoinType::Inner, &["a", "b"], &["a2", "b2"], None)?; - /// let batches = join.collect().await?; + /// let expected = vec![ + /// "+---+---+---+----+----+----+", + /// "| a | b | c | a2 | b2 | c2 |", + /// "+---+---+---+----+----+----+", + /// "| 1 | 2 | 3 | 1 | 2 | 3 |", + /// "+---+---+---+----+----+----+" + /// ]; + /// assert_batches_sorted_eq!(expected, &join.collect().await?); /// # Ok(()) /// # } /// ``` @@ -758,6 +1061,7 @@ impl DataFrame { /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; + /// # use datafusion_common::assert_batches_sorted_eq; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); @@ -781,7 +1085,13 @@ impl DataFrame { /// JoinType::Inner, /// [col("a").not_eq(col("a2")), col("b").not_eq(col("b2"))], /// )?; - /// let batches = join_on.collect().await?; + /// let expected = vec![ + /// "+---+---+---+----+----+----+", + /// "| a | b | c | a2 | b2 | c2 |", + /// "+---+---+---+----+----+----+", + /// "+---+---+---+----+----+----+" + /// ]; + /// # assert_batches_sorted_eq!(expected, &join_on.collect().await?); /// # Ok(()) /// # } /// ``` @@ -791,9 +1101,8 @@ impl DataFrame { join_type: JoinType, on_exprs: impl IntoIterator, ) -> Result { - let expr = on_exprs.into_iter().reduce(Expr::and); let plan = LogicalPlanBuilder::from(self.plan) - .join_on(right.plan, join_type, expr)? + .join_on(right.plan, join_type, on_exprs)? .build()?; Ok(DataFrame { session_state: self.session_state, @@ -807,11 +1116,22 @@ impl DataFrame { /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; + /// # use datafusion_common::assert_batches_sorted_eq; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let df = ctx.read_csv("tests/data/example_long.csv", CsvReadOptions::new()).await?; /// let df1 = df.repartition(Partitioning::RoundRobinBatch(4))?; + /// let expected = vec![ + /// "+---+---+---+", + /// "| a | b | c |", + /// "+---+---+---+", + /// "| 1 | 2 | 3 |", + /// "| 4 | 5 | 6 |", + /// "| 7 | 8 | 9 |", + /// "+---+---+---+" + /// ]; + /// # assert_batches_sorted_eq!(expected, &df1.collect().await?); /// # Ok(()) /// # } /// ``` @@ -838,16 +1158,14 @@ impl DataFrame { /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; - /// let count = df.count().await?; + /// let count = df.count().await?; // 1 + /// # assert_eq!(count, 1); /// # Ok(()) /// # } /// ``` pub async fn count(self) -> Result { let rows = self - .aggregate( - vec![], - vec![datafusion_expr::count(Expr::Literal(COUNT_STAR_EXPANSION))], - )? + .aggregate(vec![], vec![count(Expr::Literal(COUNT_STAR_EXPANSION))])? .collect() .await?; let len = *rows @@ -1026,7 +1344,9 @@ impl DataFrame { } /// Return a reference to the unoptimized [`LogicalPlan`] that comprises - /// this DataFrame. See [`Self::into_unoptimized_plan`] for more details. + /// this DataFrame. + /// + /// See [`Self::into_unoptimized_plan`] for more details. pub fn logical_plan(&self) -> &LogicalPlan { &self.plan } @@ -1043,6 +1363,9 @@ impl DataFrame { /// snapshot of the [`SessionState`] attached to this [`DataFrame`] and /// consequently subsequent operations may take place against a different /// state (e.g. a different value of `now()`) + /// + /// See [`Self::into_parts`] to retrieve the owned [`LogicalPlan`] and + /// corresponding [`SessionState`]. pub fn into_unoptimized_plan(self) -> LogicalPlan { self.plan } @@ -1050,7 +1373,7 @@ impl DataFrame { /// Return the optimized [`LogicalPlan`] represented by this DataFrame. /// /// Note: This method should not be used outside testing -- see - /// [`Self::into_optimized_plan`] for more details. + /// [`Self::into_unoptimized_plan`] for more details. pub fn into_optimized_plan(self) -> Result { // Optimize the plan first for better UX self.session_state.optimize(&self.plan) @@ -1117,12 +1440,21 @@ impl DataFrame { /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; + /// # use datafusion_common::assert_batches_sorted_eq; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; - /// let d2 = df.clone(); + /// let d2 = ctx.read_csv("tests/data/example_long.csv", CsvReadOptions::new()).await?; /// let df = df.intersect(d2)?; + /// let expected = vec![ + /// "+---+---+---+", + /// "| a | b | c |", + /// "+---+---+---+", + /// "| 1 | 2 | 3 |", + /// "+---+---+---+" + /// ]; + /// # assert_batches_sorted_eq!(expected, &df.collect().await?); /// # Ok(()) /// # } /// ``` @@ -1141,12 +1473,23 @@ impl DataFrame { /// ``` /// # use datafusion::prelude::*; /// # use datafusion::error::Result; + /// # use datafusion_common::assert_batches_sorted_eq; /// # #[tokio::main] /// # async fn main() -> Result<()> { /// let ctx = SessionContext::new(); - /// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; - /// let d2 = df.clone(); - /// let df = df.except(d2)?; + /// let df = ctx.read_csv("tests/data/example_long.csv", CsvReadOptions::new()).await?; + /// let d2 = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + /// let result = df.except(d2)?; + /// // those columns are not in example.csv, but in example_long.csv + /// let expected = vec![ + /// "+---+---+---+", + /// "| a | b | c |", + /// "+---+---+---+", + /// "| 4 | 5 | 6 |", + /// "| 7 | 8 | 9 |", + /// "+---+---+---+" + /// ]; + /// # assert_batches_sorted_eq!(expected, &result.collect().await?); /// # Ok(()) /// # } /// ``` @@ -1179,7 +1522,7 @@ impl DataFrame { self.plan, table_name.to_owned(), &arrow_schema, - write_options.overwrite, + write_options.insert_op, )? .build()?; @@ -1220,18 +1563,25 @@ impl DataFrame { options: DataFrameWriteOptions, writer_options: Option, ) -> Result, DataFusionError> { - if options.overwrite { - return Err(DataFusionError::NotImplemented( - "Overwrites are not implemented for DataFrame::write_csv.".to_owned(), - )); + if options.insert_op != InsertOp::Append { + return Err(DataFusionError::NotImplemented(format!( + "{} is not implemented for DataFrame::write_csv.", + options.insert_op + ))); } - let props = writer_options - .unwrap_or_else(|| self.session_state.default_table_options().csv); + + let format = if let Some(csv_opts) = writer_options { + Arc::new(CsvFormatFactory::new_with_options(csv_opts)) + } else { + Arc::new(CsvFormatFactory::new()) + }; + + let file_type = format_as_file_type(format); let plan = LogicalPlanBuilder::copy_to( self.plan, path.into(), - FormatOptions::CSV(props), + file_type, HashMap::new(), options.partition_by, )? @@ -1274,19 +1624,25 @@ impl DataFrame { options: DataFrameWriteOptions, writer_options: Option, ) -> Result, DataFusionError> { - if options.overwrite { - return Err(DataFusionError::NotImplemented( - "Overwrites are not implemented for DataFrame::write_json.".to_owned(), - )); + if options.insert_op != InsertOp::Append { + return Err(DataFusionError::NotImplemented(format!( + "{} is not implemented for DataFrame::write_json.", + options.insert_op + ))); } - let props = writer_options - .unwrap_or_else(|| self.session_state.default_table_options().json); + let format = if let Some(json_opts) = writer_options { + Arc::new(JsonFormatFactory::new_with_options(json_opts)) + } else { + Arc::new(JsonFormatFactory::new()) + }; + + let file_type = format_as_file_type(format); let plan = LogicalPlanBuilder::copy_to( self.plan, path.into(), - FormatOptions::JSON(props), + file_type, Default::default(), options.partition_by, )? @@ -1316,23 +1672,32 @@ impl DataFrame { /// ``` pub fn with_column(self, name: &str, expr: Expr) -> Result { let window_func_exprs = find_window_exprs(&[expr.clone()]); - let plan = if window_func_exprs.is_empty() { - self.plan + + let (window_fn_str, plan) = if window_func_exprs.is_empty() { + (None, self.plan) } else { - LogicalPlanBuilder::window_plan(self.plan, window_func_exprs)? + ( + Some(window_func_exprs[0].to_string()), + LogicalPlanBuilder::window_plan(self.plan, window_func_exprs)?, + ) }; - let new_column = expr.alias(name); let mut col_exists = false; + let new_column = expr.alias(name); let mut fields: Vec = plan .schema() .iter() - .map(|(qualifier, field)| { + .filter_map(|(qualifier, field)| { if field.name() == name { col_exists = true; - new_column.clone() + Some(new_column.clone()) } else { - col(Column::from((qualifier, field))) + let e = col(Column::from((qualifier, field))); + window_fn_str + .as_ref() + .filter(|s| *s == &e.to_string()) + .is_none() + .then_some(e) } }) .collect(); @@ -1354,7 +1719,7 @@ impl DataFrame { /// /// The method supports case sensitive rename with wrapping column name into one of following symbols ( " or ' or ` ) /// - /// Alternatively setting Datafusion param `datafusion.sql_parser.enable_ident_normalization` to `false` will enable + /// Alternatively setting DataFusion param `datafusion.sql_parser.enable_ident_normalization` to `false` will enable /// case sensitive rename without need to wrap column name into special symbols /// /// # Example @@ -1428,7 +1793,7 @@ impl DataFrame { /// # #[tokio::main] /// # async fn main() -> Result<()> { /// # use datafusion_common::ScalarValue; - /// let mut ctx = SessionContext::new(); + /// let ctx = SessionContext::new(); /// # ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new()).await?; /// let results = ctx /// .sql("SELECT a FROM example WHERE b = $1") @@ -1507,6 +1872,7 @@ impl DataFrame { } } +#[derive(Debug)] struct DataFrameTableProvider { plan: LogicalPlan, } @@ -1517,8 +1883,8 @@ impl TableProvider for DataFrameTableProvider { self } - fn get_logical_plan(&self) -> Option<&LogicalPlan> { - Some(&self.plan) + fn get_logical_plan(&self) -> Option> { + Some(Cow::Borrowed(&self.plan)) } fn supports_filters_pushdown( @@ -1540,7 +1906,7 @@ impl TableProvider for DataFrameTableProvider { async fn scan( &self, - state: &SessionState, + state: &dyn Session, projection: Option<&Vec>, filters: &[Expr], limit: Option, @@ -1575,15 +1941,20 @@ mod tests { use crate::physical_plan::{ColumnarValue, Partitioning, PhysicalExpr}; use crate::test_util::{register_aggregate_csv, test_table, test_table_with_name}; - use arrow::array::{self, Int32Array}; - use datafusion_common::{Constraint, Constraints}; + use arrow::array::Int32Array; + use datafusion_common::{assert_batches_eq, Constraint, Constraints, ScalarValue}; use datafusion_common_runtime::SpawnedTask; + use datafusion_expr::expr::WindowFunction; use datafusion_expr::{ - cast, count_distinct, create_udf, expr, lit, sum, BuiltInWindowFunction, - ScalarFunctionImplementation, Volatility, WindowFrame, WindowFunctionDefinition, + cast, create_udf, lit, BuiltInWindowFunction, ExprFunctionExt, + ScalarFunctionImplementation, Volatility, WindowFrame, WindowFrameBound, + WindowFrameUnits, WindowFunctionDefinition, }; + use datafusion_functions_aggregate::expr_fn::{array_agg, count_distinct}; + use datafusion_functions_window::expr_fn::row_number; use datafusion_physical_expr::expressions::Column; use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties}; + use sqlparser::ast::NullTreatment; // Get string representation of the plan async fn assert_physical_plan(df: &DataFrame, expected: Vec<&str>) { @@ -1608,8 +1979,8 @@ mod tests { let batch = RecordBatch::try_new( dual_schema.clone(), vec![ - Arc::new(array::Int32Array::from(vec![1])), - Arc::new(array::StringArray::from(vec!["a"])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(StringArray::from(vec!["a"])), ], ) .unwrap(); @@ -1657,6 +2028,43 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_coalesce_schema() -> Result<()> { + let ctx = SessionContext::new(); + + let query = r#"SELECT COALESCE(null, 5)"#; + + let result = ctx.sql(query).await?; + assert_logical_expr_schema_eq_physical_expr_schema(result).await?; + Ok(()) + } + + #[tokio::test] + async fn test_coalesce_from_values_schema() -> Result<()> { + let ctx = SessionContext::new(); + + let query = r#"SELECT COALESCE(column1, column2) FROM VALUES (null, 1.2)"#; + + let result = ctx.sql(query).await?; + assert_logical_expr_schema_eq_physical_expr_schema(result).await?; + Ok(()) + } + + #[tokio::test] + async fn test_coalesce_from_values_schema_multiple_rows() -> Result<()> { + let ctx = SessionContext::new(); + + let query = r#"SELECT COALESCE(column1, column2) + FROM VALUES + (null, 1.2), + (1.1, null), + (2, 5);"#; + + let result = ctx.sql(query).await?; + assert_logical_expr_schema_eq_physical_expr_schema(result).await?; + Ok(()) + } + #[tokio::test] async fn test_array_agg_schema() -> Result<()> { let ctx = SessionContext::new(); @@ -1740,20 +2148,43 @@ mod tests { Ok(()) } + #[tokio::test] + async fn select_exprs() -> Result<()> { + // build plan using `select_expr`` + let t = test_table().await?; + let plan = t + .clone() + .select_exprs(&["c1", "c2", "c11", "c2 * c11"])? + .plan; + + // build plan using select + let expected_plan = t + .select(vec![ + col("c1"), + col("c2"), + col("c11"), + col("c2") * col("c11"), + ])? + .plan; + + assert_same_plan(&expected_plan, &plan); + + Ok(()) + } + #[tokio::test] async fn select_with_window_exprs() -> Result<()> { // build plan using Table API let t = test_table().await?; - let first_row = Expr::WindowFunction(expr::WindowFunction::new( + let first_row = Expr::WindowFunction(WindowFunction::new( WindowFunctionDefinition::BuiltInWindowFunction( BuiltInWindowFunction::FirstValue, ), vec![col("aggregate_test_100.c1")], - vec![col("aggregate_test_100.c2")], - vec![], - WindowFrame::new(None), - None, - )); + )) + .partition_by(vec![col("aggregate_test_100.c2")]) + .build() + .unwrap(); let t2 = t.select(vec![col("c1"), first_row])?; let plan = t2.plan.clone(); @@ -1788,90 +2219,223 @@ mod tests { } #[tokio::test] - async fn aggregate() -> Result<()> { - // build plan using DataFrame API - let df = test_table().await?; - let group_expr = vec![col("c1")]; - let aggr_expr = vec![ - min(col("c12")), - max(col("c12")), - avg(col("c12")), - sum(col("c12")), - count(col("c12")), - count_distinct(col("c12")), - ]; + async fn drop_columns() -> Result<()> { + // build plan using Table API + let t = test_table().await?; + let t2 = t.drop_columns(&["c2", "c11"])?; + let plan = t2.plan.clone(); - let df: Vec = df.aggregate(group_expr, aggr_expr)?.collect().await?; + // build query using SQL + let sql_plan = create_plan( + "SELECT c1,c3,c4,c5,c6,c7,c8,c9,c10,c12,c13 FROM aggregate_test_100", + ) + .await?; - assert_batches_sorted_eq!( - ["+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+", - "| c1 | MIN(aggregate_test_100.c12) | MAX(aggregate_test_100.c12) | AVG(aggregate_test_100.c12) | SUM(aggregate_test_100.c12) | COUNT(aggregate_test_100.c12) | COUNT(DISTINCT aggregate_test_100.c12) |", - "+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+", - "| a | 0.02182578039211991 | 0.9800193410444061 | 0.48754517466109415 | 10.238448667882977 | 21 | 21 |", - "| b | 0.04893135681998029 | 0.9185813970744787 | 0.41040709263815384 | 7.797734760124923 | 19 | 19 |", - "| c | 0.0494924465469434 | 0.991517828651004 | 0.6600456536439784 | 13.860958726523545 | 21 | 21 |", - "| d | 0.061029375346466685 | 0.9748360509016578 | 0.48855379387549824 | 8.793968289758968 | 18 | 18 |", - "| e | 0.01479305307777301 | 0.9965400387585364 | 0.48600669271341534 | 10.206140546981722 | 21 | 21 |", - "+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+"], - &df - ); + // the two plans should be identical + assert_same_plan(&plan, &sql_plan); Ok(()) } #[tokio::test] - async fn test_aggregate_with_pk() -> Result<()> { - // create the dataframe - let config = SessionConfig::new().with_target_partitions(1); - let ctx = SessionContext::new_with_config(config); - - let df = ctx.read_table(table_with_constraints())?; - - // GROUP BY id - let group_expr = vec![col("id")]; - let aggr_expr = vec![]; - let df = df.aggregate(group_expr, aggr_expr)?; + async fn drop_columns_with_duplicates() -> Result<()> { + // build plan using Table API + let t = test_table().await?; + let t2 = t.drop_columns(&["c2", "c11", "c2", "c2"])?; + let plan = t2.plan.clone(); - // Since id and name are functionally dependant, we can use name among - // expression even if it is not part of the group by expression and can - // select "name" column even though it wasn't explicitly grouped - let df = df.select(vec![col("id"), col("name")])?; - assert_physical_plan( - &df, - vec![ - "AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]", - " MemoryExec: partitions=1, partition_sizes=[1]", - ], + // build query using SQL + let sql_plan = create_plan( + "SELECT c1,c3,c4,c5,c6,c7,c8,c9,c10,c12,c13 FROM aggregate_test_100", ) - .await; - - let df_results = df.collect().await?; + .await?; - #[rustfmt::skip] - assert_batches_sorted_eq!([ - "+----+------+", - "| id | name |", - "+----+------+", - "| 1 | a |", - "+----+------+" - ], - &df_results - ); + // the two plans should be identical + assert_same_plan(&plan, &sql_plan); Ok(()) } #[tokio::test] - async fn test_aggregate_with_pk2() -> Result<()> { - // create the dataframe - let config = SessionConfig::new().with_target_partitions(1); - let ctx = SessionContext::new_with_config(config); + async fn drop_columns_with_nonexistent_columns() -> Result<()> { + // build plan using Table API + let t = test_table().await?; + let t2 = t.drop_columns(&["canada", "c2", "rocks"])?; + let plan = t2.plan.clone(); - let df = ctx.read_table(table_with_constraints())?; + // build query using SQL + let sql_plan = create_plan( + "SELECT c1,c3,c4,c5,c6,c7,c8,c9,c10,c11,c12,c13 FROM aggregate_test_100", + ) + .await?; - // GROUP BY id - let group_expr = vec![col("id")]; - let aggr_expr = vec![]; + // the two plans should be identical + assert_same_plan(&plan, &sql_plan); + + Ok(()) + } + + #[tokio::test] + async fn drop_columns_with_empty_array() -> Result<()> { + // build plan using Table API + let t = test_table().await?; + let t2 = t.drop_columns(&[])?; + let plan = t2.plan.clone(); + + // build query using SQL + let sql_plan = create_plan( + "SELECT c1,c2,c3,c4,c5,c6,c7,c8,c9,c10,c11,c12,c13 FROM aggregate_test_100", + ) + .await?; + + // the two plans should be identical + assert_same_plan(&plan, &sql_plan); + + Ok(()) + } + + #[tokio::test] + async fn drop_with_quotes() -> Result<()> { + // define data with a column name that has a "." in it: + let array1: Int32Array = [1, 10].into_iter().collect(); + let array2: Int32Array = [2, 11].into_iter().collect(); + let batch = RecordBatch::try_from_iter(vec![ + ("f\"c1", Arc::new(array1) as _), + ("f\"c2", Arc::new(array2) as _), + ])?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + + let df = ctx.table("t").await?.drop_columns(&["f\"c1"])?; + + let df_results = df.collect().await?; + + assert_batches_sorted_eq!( + [ + "+------+", + "| f\"c2 |", + "+------+", + "| 2 |", + "| 11 |", + "+------+" + ], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn drop_with_periods() -> Result<()> { + // define data with a column name that has a "." in it: + let array1: Int32Array = [1, 10].into_iter().collect(); + let array2: Int32Array = [2, 11].into_iter().collect(); + let batch = RecordBatch::try_from_iter(vec![ + ("f.c1", Arc::new(array1) as _), + ("f.c2", Arc::new(array2) as _), + ])?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + + let df = ctx.table("t").await?.drop_columns(&["f.c1"])?; + + let df_results = df.collect().await?; + + assert_batches_sorted_eq!( + ["+------+", "| f.c2 |", "+------+", "| 2 |", "| 11 |", "+------+"], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn aggregate() -> Result<()> { + // build plan using DataFrame API + let df = test_table().await?; + let group_expr = vec![col("c1")]; + let aggr_expr = vec![ + min(col("c12")), + max(col("c12")), + avg(col("c12")), + sum(col("c12")), + count(col("c12")), + count_distinct(col("c12")), + ]; + + let df: Vec = df.aggregate(group_expr, aggr_expr)?.collect().await?; + + assert_batches_sorted_eq!( + ["+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+", + "| c1 | min(aggregate_test_100.c12) | max(aggregate_test_100.c12) | avg(aggregate_test_100.c12) | sum(aggregate_test_100.c12) | count(aggregate_test_100.c12) | count(DISTINCT aggregate_test_100.c12) |", + "+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+", + "| a | 0.02182578039211991 | 0.9800193410444061 | 0.48754517466109415 | 10.238448667882977 | 21 | 21 |", + "| b | 0.04893135681998029 | 0.9185813970744787 | 0.41040709263815384 | 7.797734760124923 | 19 | 19 |", + "| c | 0.0494924465469434 | 0.991517828651004 | 0.6600456536439784 | 13.860958726523545 | 21 | 21 |", + "| d | 0.061029375346466685 | 0.9748360509016578 | 0.48855379387549824 | 8.793968289758968 | 18 | 18 |", + "| e | 0.01479305307777301 | 0.9965400387585364 | 0.48600669271341534 | 10.206140546981722 | 21 | 21 |", + "+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+"], + &df + ); + + Ok(()) + } + + #[tokio::test] + async fn test_aggregate_with_pk() -> Result<()> { + // create the dataframe + let config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(config); + + let df = ctx.read_table(table_with_constraints())?; + + // GROUP BY id + let group_expr = vec![col("id")]; + let aggr_expr = vec![]; + let df = df.aggregate(group_expr, aggr_expr)?; + + // Since id and name are functionally dependant, we can use name among + // expression even if it is not part of the group by expression and can + // select "name" column even though it wasn't explicitly grouped + let df = df.select(vec![col("id"), col("name")])?; + assert_physical_plan( + &df, + vec![ + "AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]", + " MemoryExec: partitions=1, partition_sizes=[1]", + ], + ) + .await; + + let df_results = df.collect().await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!([ + "+----+------+", + "| id | name |", + "+----+------+", + "| 1 | a |", + "+----+------+" + ], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn test_aggregate_with_pk2() -> Result<()> { + // create the dataframe + let config = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(config); + + let df = ctx.read_table(table_with_constraints())?; + + // GROUP BY id + let group_expr = vec![col("id")]; + let aggr_expr = vec![]; let df = df.aggregate(group_expr, aggr_expr)?; // Predicate refers to id, and name fields: @@ -2036,6 +2600,219 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_aggregate_with_union() -> Result<()> { + let df = test_table().await?; + + let df1 = df + .clone() + // GROUP BY `c1` + .aggregate(vec![col("c1")], vec![min(col("c2"))])? + // SELECT `c1` , min(c2) as `result` + .select(vec![col("c1"), min(col("c2")).alias("result")])?; + let df2 = df + .clone() + // GROUP BY `c1` + .aggregate(vec![col("c1")], vec![max(col("c3"))])? + // SELECT `c1` , max(c3) as `result` + .select(vec![col("c1"), max(col("c3")).alias("result")])?; + + let df_union = df1.union(df2)?; + let df = df_union + // GROUP BY `c1` + .aggregate( + vec![col("c1")], + vec![sum(col("result")).alias("sum_result")], + )? + // SELECT `c1`, sum(result) as `sum_result` + .select(vec![(col("c1")), col("sum_result")])?; + + let df_results = df.collect().await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!( + [ + "+----+------------+", + "| c1 | sum_result |", + "+----+------------+", + "| a | 84 |", + "| b | 69 |", + "| c | 124 |", + "| d | 126 |", + "| e | 121 |", + "+----+------------+" + ], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn test_aggregate_subexpr() -> Result<()> { + let df = test_table().await?; + + let group_expr = col("c2") + lit(1); + let aggr_expr = sum(col("c3") + lit(2)); + + let df = df + // GROUP BY `c2 + 1` + .aggregate(vec![group_expr.clone()], vec![aggr_expr.clone()])? + // SELECT `c2 + 1` as c2 + 10, sum(c3 + 2) + 20 + // SELECT expressions contain aggr_expr and group_expr as subexpressions + .select(vec![ + group_expr.alias("c2") + lit(10), + (aggr_expr + lit(20)).alias("sum"), + ])?; + + let df_results = df.collect().await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!([ + "+----------------+------+", + "| c2 + Int32(10) | sum |", + "+----------------+------+", + "| 12 | 431 |", + "| 13 | 248 |", + "| 14 | 453 |", + "| 15 | 95 |", + "| 16 | -146 |", + "+----------------+------+", + ], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn test_aggregate_name_collision() -> Result<()> { + let df = test_table().await?; + + let collided_alias = "aggregate_test_100.c2 + aggregate_test_100.c3"; + let group_expr = lit(1).alias(collided_alias); + + let df = df + // GROUP BY 1 + .aggregate(vec![group_expr], vec![])? + // SELECT `aggregate_test_100.c2 + aggregate_test_100.c3` + .select(vec![ + (col("aggregate_test_100.c2") + col("aggregate_test_100.c3")), + ]) + // The select expr has the same display_name as the group_expr, + // but since they are different expressions, it should fail. + .expect_err("Expected error"); + let expected = "Schema error: No field named aggregate_test_100.c2. \ + Valid fields are \"aggregate_test_100.c2 + aggregate_test_100.c3\"."; + assert_eq!(df.strip_backtrace(), expected); + + Ok(()) + } + + #[tokio::test] + async fn window_using_aggregates() -> Result<()> { + // build plan using DataFrame API + let df = test_table().await?.filter(col("c1").eq(lit("a")))?; + let mut aggr_expr = vec![ + ( + datafusion_functions_aggregate::first_last::first_value_udaf(), + "first_value", + ), + ( + datafusion_functions_aggregate::first_last::last_value_udaf(), + "last_val", + ), + ( + datafusion_functions_aggregate::approx_distinct::approx_distinct_udaf(), + "approx_distinct", + ), + ( + datafusion_functions_aggregate::approx_median::approx_median_udaf(), + "approx_median", + ), + ( + datafusion_functions_aggregate::median::median_udaf(), + "median", + ), + (datafusion_functions_aggregate::min_max::max_udaf(), "max"), + (datafusion_functions_aggregate::min_max::min_udaf(), "min"), + ] + .into_iter() + .map(|(func, name)| { + let w = WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(func), + vec![col("c3")], + ); + + Expr::WindowFunction(w) + .null_treatment(NullTreatment::IgnoreNulls) + .order_by(vec![col("c2").sort(true, true), col("c3").sort(true, true)]) + .window_frame(WindowFrame::new_bounds( + WindowFrameUnits::Rows, + WindowFrameBound::Preceding(ScalarValue::UInt64(None)), + WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))), + )) + .build() + .unwrap() + .alias(name) + }) + .collect::>(); + aggr_expr.extend_from_slice(&[col("c2"), col("c3")]); + + let df: Vec = df.select(aggr_expr)?.collect().await?; + + assert_batches_sorted_eq!( + [ + "+-------------+----------+-----------------+---------------+--------+-----+------+----+------+", + "| first_value | last_val | approx_distinct | approx_median | median | max | min | c2 | c3 |", + "+-------------+----------+-----------------+---------------+--------+-----+------+----+------+", + "| | | | | | | | 1 | -85 |", + "| -85 | -101 | 14 | -12 | -101 | 83 | -101 | 4 | -54 |", + "| -85 | -101 | 17 | -25 | -101 | 83 | -101 | 5 | -31 |", + "| -85 | -12 | 10 | -32 | -12 | 83 | -85 | 3 | 13 |", + "| -85 | -25 | 3 | -56 | -25 | -25 | -85 | 1 | -5 |", + "| -85 | -31 | 18 | -29 | -31 | 83 | -101 | 5 | 36 |", + "| -85 | -38 | 16 | -25 | -38 | 83 | -101 | 4 | 65 |", + "| -85 | -43 | 7 | -43 | -43 | 83 | -85 | 2 | 45 |", + "| -85 | -48 | 6 | -35 | -48 | 83 | -85 | 2 | -43 |", + "| -85 | -5 | 4 | -37 | -5 | -5 | -85 | 1 | 83 |", + "| -85 | -54 | 15 | -17 | -54 | 83 | -101 | 4 | -38 |", + "| -85 | -56 | 2 | -70 | -56 | -56 | -85 | 1 | -25 |", + "| -85 | -72 | 9 | -43 | -72 | 83 | -85 | 3 | -12 |", + "| -85 | -85 | 1 | -85 | -85 | -85 | -85 | 1 | -56 |", + "| -85 | 13 | 11 | -17 | 13 | 83 | -85 | 3 | 14 |", + "| -85 | 13 | 11 | -25 | 13 | 83 | -85 | 3 | 13 |", + "| -85 | 14 | 12 | -12 | 14 | 83 | -85 | 3 | 17 |", + "| -85 | 17 | 13 | -11 | 17 | 83 | -85 | 4 | -101 |", + "| -85 | 45 | 8 | -34 | 45 | 83 | -85 | 3 | -72 |", + "| -85 | 65 | 17 | -17 | 65 | 83 | -101 | 5 | -101 |", + "| -85 | 83 | 5 | -25 | 83 | 83 | -85 | 2 | -48 |", + "+-------------+----------+-----------------+---------------+--------+-----+------+----+------+", + ], + &df + ); + + Ok(()) + } + + // Test issue: https://github.com/apache/datafusion/issues/10346 + #[tokio::test] + async fn test_select_over_aggregate_schema() -> Result<()> { + let df = test_table() + .await? + .with_column("c", col("c1"))? + .aggregate(vec![], vec![array_agg(col("c")).alias("c")])? + .select(vec![col("c")])?; + + assert_eq!(df.schema().fields().len(), 1); + let field = df.schema().field(0); + // There are two columns named 'c', one from the input of the aggregate and the other from the output. + // Select should return the column from the output of the aggregate, which is a list. + assert!(matches!(field.data_type(), DataType::List(_))); + + Ok(()) + } + #[tokio::test] async fn test_distinct() -> Result<()> { let t = test_table().await?; @@ -2099,6 +2876,91 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_distinct_on() -> Result<()> { + let t = test_table().await?; + let plan = t + .distinct_on(vec![col("c1")], vec![col("aggregate_test_100.c1")], None) + .unwrap(); + + let sql_plan = + create_plan("select distinct on (c1) c1 from aggregate_test_100").await?; + + assert_same_plan(&plan.plan.clone(), &sql_plan); + + let df_results = plan.clone().collect().await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!( + ["+----+", + "| c1 |", + "+----+", + "| a |", + "| b |", + "| c |", + "| d |", + "| e |", + "+----+"], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn test_distinct_on_sort_by() -> Result<()> { + let t = test_table().await?; + let plan = t + .select(vec![col("c1")]) + .unwrap() + .distinct_on( + vec![col("c1")], + vec![col("c1")], + Some(vec![col("c1").sort(true, true)]), + ) + .unwrap() + .sort(vec![col("c1").sort(true, true)]) + .unwrap(); + + let df_results = plan.clone().collect().await?; + + #[rustfmt::skip] + assert_batches_sorted_eq!( + ["+----+", + "| c1 |", + "+----+", + "| a |", + "| b |", + "| c |", + "| d |", + "| e |", + "+----+"], + &df_results + ); + + Ok(()) + } + + #[tokio::test] + async fn test_distinct_on_sort_by_unprojected() -> Result<()> { + let t = test_table().await?; + let err = t + .select(vec![col("c1")]) + .unwrap() + .distinct_on( + vec![col("c1")], + vec![col("c1")], + Some(vec![col("c1").sort(true, true)]), + ) + .unwrap() + // try to sort on some value not present in input to distinct + .sort(vec![col("c2").sort(true, true)]) + .unwrap_err(); + assert_eq!(err.strip_backtrace(), "Error during planning: For SELECT DISTINCT, ORDER BY expressions c2 must appear in select list"); + + Ok(()) + } + #[tokio::test] async fn join() -> Result<()> { let left = test_table().await?.select_columns(&["c1", "c2"])?; @@ -2134,11 +2996,35 @@ mod tests { \n TableScan: a\ \n Projection: b.c1, b.c2\ \n TableScan: b"; - assert_eq!(expected_plan, format!("{:?}", join.logical_plan())); + assert_eq!(expected_plan, format!("{}", join.logical_plan())); Ok(()) } + #[tokio::test] + async fn join_on_filter_datatype() -> Result<()> { + let left = test_table_with_name("a").await?.select_columns(&["c1"])?; + let right = test_table_with_name("b").await?.select_columns(&["c1"])?; + + // JOIN ON untyped NULL + let join = left.clone().join_on( + right.clone(), + JoinType::Inner, + Some(Expr::Literal(ScalarValue::Null)), + )?; + let expected_plan = "EmptyRelation"; + assert_eq!(expected_plan, format!("{}", join.into_optimized_plan()?)); + + // JOIN ON expression must be boolean type + let join = left.join_on(right, JoinType::Inner, Some(lit("TRUE")))?; + let expected = join.into_optimized_plan().unwrap_err(); + assert_eq!( + expected.strip_backtrace(), + "type_coercion\ncaused by\nError during planning: Join condition must be boolean type, but got Utf8" + ); + Ok(()) + } + #[tokio::test] async fn join_ambiguous_filter() -> Result<()> { let left = test_table_with_name("a") @@ -2204,8 +3090,8 @@ mod tests { #[tokio::test] async fn registry() -> Result<()> { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx, "aggregate_test_100").await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx, "aggregate_test_100").await?; // declare the udf let my_fn: ScalarFunctionImplementation = @@ -2215,7 +3101,7 @@ mod tests { ctx.register_udf(create_udf( "my_fn", vec![DataType::Float64], - Arc::new(DataType::Float64), + DataType::Float64, Volatility::Immutable, my_fn, )); @@ -2301,7 +3187,7 @@ mod tests { assert_batches_sorted_eq!( [ "+----+-----------------------------+", - "| c1 | SUM(aggregate_test_100.c12) |", + "| c1 | sum(aggregate_test_100.c12) |", "+----+-----------------------------+", "| a | 10.238448667882977 |", "| b | 7.797734760124923 |", @@ -2317,7 +3203,7 @@ mod tests { assert_batches_sorted_eq!( [ "+----+---------------------+", - "| c1 | SUM(test_table.c12) |", + "| c1 | sum(test_table.c12) |", "+----+---------------------+", "| a | 10.238448667882977 |", "| b | 7.797734760124923 |", @@ -2338,8 +3224,8 @@ mod tests { /// Create a logical plan from a SQL query async fn create_plan(sql: &str) -> Result { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx, "aggregate_test_100").await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx, "aggregate_test_100").await?; Ok(ctx.sql(sql).await?.into_unoptimized_plan()) } @@ -2421,6 +3307,41 @@ mod tests { Ok(()) } + // Test issues: https://github.com/apache/datafusion/issues/11982 + // and https://github.com/apache/datafusion/issues/12425 + // Window function was creating unwanted projection when using with_column() method. + #[tokio::test] + async fn test_window_function_with_column() -> Result<()> { + let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?; + let ctx = SessionContext::new(); + let df_impl = DataFrame::new(ctx.state(), df.plan.clone()); + let func = row_number().alias("row_num"); + + // This first `with_column` results in a column without a `qualifier` + let df_impl = df_impl.with_column("s", col("c2") + col("c3"))?; + + // This second `with_column` should only alias `func` as `"r"` + let df = df_impl.with_column("r", func)?.limit(0, Some(2))?; + + df.clone().show().await?; + assert_eq!(5, df.schema().fields().len()); + + let df_results = df.clone().collect().await?; + assert_batches_sorted_eq!( + [ + "+----+----+-----+-----+---+", + "| c1 | c2 | c3 | s | r |", + "+----+----+-----+-----+---+", + "| c | 2 | 1 | 3 | 1 |", + "| d | 5 | -40 | -35 | 2 |", + "+----+----+-----+-----+---+", + ], + &df_results + ); + + Ok(()) + } + // Test issue: https://github.com/apache/datafusion/issues/7790 // The join operation outputs two identical column names, but they belong to different relations. #[tokio::test] @@ -2469,20 +3390,19 @@ mod tests { \n Inner Join: t1.c1 = t2.c1\ \n TableScan: t1\ \n TableScan: t2", - format!("{:?}", df_with_column.logical_plan()) + format!("{}", df_with_column.logical_plan()) ); assert_eq!( "\ Projection: t1.c1, t2.c1, Boolean(true) AS new_column\ - \n Limit: skip=0, fetch=1\ - \n Sort: t1.c1 ASC NULLS FIRST, fetch=1\ - \n Inner Join: t1.c1 = t2.c1\ - \n SubqueryAlias: t1\ - \n TableScan: aggregate_test_100 projection=[c1]\ - \n SubqueryAlias: t2\ - \n TableScan: aggregate_test_100 projection=[c1]", - format!("{:?}", df_with_column.clone().into_optimized_plan()?) + \n Sort: t1.c1 ASC NULLS FIRST, fetch=1\ + \n Inner Join: t1.c1 = t2.c1\ + \n SubqueryAlias: t1\ + \n TableScan: aggregate_test_100 projection=[c1]\ + \n SubqueryAlias: t2\ + \n TableScan: aggregate_test_100 projection=[c1]", + format!("{}", df_with_column.clone().into_optimized_plan()?) ); let df_results = df_with_column.collect().await?; @@ -2500,65 +3420,19 @@ mod tests { Ok(()) } - // Table 't1' self join - // Supplementary test of issue: https://github.com/apache/datafusion/issues/7790 - #[tokio::test] - async fn with_column_self_join() -> Result<()> { - let df = test_table().await?.select_columns(&["c1"])?; - let ctx = SessionContext::new(); - - ctx.register_table("t1", df.into_view())?; - - let df = ctx - .table("t1") - .await? - .join( - ctx.table("t1").await?, - JoinType::Inner, - &["c1"], - &["c1"], - None, - )? - .sort(vec![ - // make the test deterministic - col("t1.c1").sort(true, true), - ])? - .limit(0, Some(1))?; - - let df_results = df.clone().collect().await?; - assert_batches_sorted_eq!( - [ - "+----+----+", - "| c1 | c1 |", - "+----+----+", - "| a | a |", - "+----+----+", - ], - &df_results - ); - - let actual_err = df.clone().with_column("new_column", lit(true)).unwrap_err(); - let expected_err = "Error during planning: Projections require unique expression names \ - but the expression \"t1.c1\" at position 0 and \"t1.c1\" at position 1 have the same name. \ - Consider aliasing (\"AS\") one of them."; - assert_eq!(actual_err.strip_backtrace(), expected_err); - - Ok(()) - } - #[tokio::test] async fn with_column_renamed() -> Result<()> { let df = test_table() .await? .select_columns(&["c1", "c2", "c3"])? .filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))? - .limit(0, Some(1))? .sort(vec![ // make the test deterministic col("c1").sort(true, true), col("c2").sort(true, true), col("c3").sort(true, true), ])? + .limit(0, Some(1))? .with_column("sum", col("c2") + col("c3"))?; let df_sum_renamed = df @@ -2574,11 +3448,11 @@ mod tests { assert_batches_sorted_eq!( [ - "+-----+-----+----+-------+", - "| one | two | c3 | total |", - "+-----+-----+----+-------+", - "| a | 3 | 13 | 16 |", - "+-----+-----+----+-------+" + "+-----+-----+-----+-------+", + "| one | two | c3 | total |", + "+-----+-----+-----+-------+", + "| a | 3 | -72 | -69 |", + "+-----+-----+-----+-------+", ], &df_sum_renamed ); @@ -2664,19 +3538,18 @@ mod tests { \n Inner Join: t1.c1 = t2.c1\ \n TableScan: t1\ \n TableScan: t2", - format!("{:?}", df_renamed.logical_plan()) + format!("{}", df_renamed.logical_plan()) ); assert_eq!("\ Projection: t1.c1 AS AAA, t1.c2, t1.c3, t2.c1, t2.c2, t2.c3\ - \n Limit: skip=0, fetch=1\ - \n Sort: t1.c1 ASC NULLS FIRST, t1.c2 ASC NULLS FIRST, t1.c3 ASC NULLS FIRST, t2.c1 ASC NULLS FIRST, t2.c2 ASC NULLS FIRST, t2.c3 ASC NULLS FIRST, fetch=1\ - \n Inner Join: t1.c1 = t2.c1\ - \n SubqueryAlias: t1\ - \n TableScan: aggregate_test_100 projection=[c1, c2, c3]\ - \n SubqueryAlias: t2\ - \n TableScan: aggregate_test_100 projection=[c1, c2, c3]", - format!("{:?}", df_renamed.clone().into_optimized_plan()?) + \n Sort: t1.c1 ASC NULLS FIRST, t1.c2 ASC NULLS FIRST, t1.c3 ASC NULLS FIRST, t2.c1 ASC NULLS FIRST, t2.c2 ASC NULLS FIRST, t2.c3 ASC NULLS FIRST, fetch=1\ + \n Inner Join: t1.c1 = t2.c1\ + \n SubqueryAlias: t1\ + \n TableScan: aggregate_test_100 projection=[c1, c2, c3]\ + \n SubqueryAlias: t2\ + \n TableScan: aggregate_test_100 projection=[c1, c2, c3]", + format!("{}", df_renamed.clone().into_optimized_plan()?) ); let df_results = df_renamed.collect().await?; @@ -2697,14 +3570,13 @@ mod tests { #[tokio::test] async fn with_column_renamed_case_sensitive() -> Result<()> { - let config = - SessionConfig::from_string_hash_map(std::collections::HashMap::from([( - "datafusion.sql_parser.enable_ident_normalization".to_owned(), - "false".to_owned(), - )]))?; - let mut ctx = SessionContext::new_with_config(config); + let config = SessionConfig::from_string_hash_map(&HashMap::from([( + "datafusion.sql_parser.enable_ident_normalization".to_owned(), + "false".to_owned(), + )]))?; + let ctx = SessionContext::new_with_config(config); let name = "aggregate_test_100"; - register_aggregate_csv(&mut ctx, name).await?; + register_aggregate_csv(&ctx, name).await?; let df = ctx.table(name); let df = df @@ -2773,7 +3645,7 @@ mod tests { #[tokio::test] async fn row_writer_resize_test() -> Result<()> { - let schema = Arc::new(Schema::new(vec![arrow::datatypes::Field::new( + let schema = Arc::new(Schema::new(vec![Field::new( "column_1", DataType::Utf8, false, @@ -2782,7 +3654,7 @@ mod tests { let data = RecordBatch::try_new( schema, vec![ - Arc::new(arrow::array::StringArray::from(vec![ + Arc::new(StringArray::from(vec![ Some("2a0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), Some("3a0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000800"), ])) @@ -2794,7 +3666,7 @@ mod tests { let sql = r#" SELECT - COUNT(1) + count(1) FROM test GROUP BY @@ -2861,7 +3733,7 @@ mod tests { assert_eq!( "TableScan: ?table? projection=[c2, c3, sum]", - format!("{:?}", cached_df.clone().into_optimized_plan()?) + format!("{}", cached_df.clone().into_optimized_plan()?) ); let df_results = df.collect().await?; @@ -2992,6 +3864,7 @@ mod tests { JoinType::RightSemi, JoinType::LeftAnti, JoinType::RightAnti, + JoinType::LeftMark, ]; let default_partition_count = SessionConfig::new().target_partitions(); @@ -3009,10 +3882,10 @@ mod tests { let join_schema = physical_plan.schema(); match join_type { - JoinType::Inner - | JoinType::Left + JoinType::Left | JoinType::LeftSemi - | JoinType::LeftAnti => { + | JoinType::LeftAnti + | JoinType::LeftMark => { let left_exprs: Vec> = vec![ Arc::new(Column::new_with_schema("c1", &join_schema)?), Arc::new(Column::new_with_schema("c2", &join_schema)?), @@ -3022,7 +3895,10 @@ mod tests { &Partitioning::Hash(left_exprs, default_partition_count) ); } - JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { + JoinType::Inner + | JoinType::Right + | JoinType::RightSemi + | JoinType::RightAnti => { let right_exprs: Vec> = vec![ Arc::new(Column::new_with_schema("c2_c1", &join_schema)?), Arc::new(Column::new_with_schema("c2_c2", &join_schema)?), @@ -3042,6 +3918,83 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_except_nested_struct() -> Result<()> { + use arrow::array::StructArray; + + let nested_schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("lat", DataType::Int32, true), + Field::new("long", DataType::Int32, true), + ])); + let schema = Arc::new(Schema::new(vec![ + Field::new("value", DataType::Int32, true), + Field::new( + "nested", + DataType::Struct(nested_schema.fields.clone()), + true, + ), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)])), + Arc::new(StructArray::from(vec![ + ( + Arc::new(Field::new("id", DataType::Int32, true)), + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + ), + ( + Arc::new(Field::new("lat", DataType::Int32, true)), + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + ), + ( + Arc::new(Field::new("long", DataType::Int32, true)), + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + ), + ])), + ], + ) + .unwrap(); + + let updated_batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![Some(1), Some(12), Some(3)])), + Arc::new(StructArray::from(vec![ + ( + Arc::new(Field::new("id", DataType::Int32, true)), + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + ), + ( + Arc::new(Field::new("lat", DataType::Int32, true)), + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + ), + ( + Arc::new(Field::new("long", DataType::Int32, true)), + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + ), + ])), + ], + ) + .unwrap(); + + let ctx = SessionContext::new(); + let before = ctx.read_batch(batch).expect("Failed to make DataFrame"); + let after = ctx + .read_batch(updated_batch) + .expect("Failed to make DataFrame"); + + let diff = before + .except(after) + .expect("Failed to except") + .collect() + .await?; + assert_eq!(diff.len(), 1); + Ok(()) + } + #[tokio::test] async fn nested_explain_should_fail() -> Result<()> { let ctx = SessionContext::new(); @@ -3053,4 +4006,34 @@ mod tests { assert!(result.is_err()); Ok(()) } + + // Test issue: https://github.com/apache/datafusion/issues/12065 + #[tokio::test] + async fn filtered_aggr_with_param_values() -> Result<()> { + let cfg = SessionConfig::new().set( + "datafusion.sql_parser.dialect", + &ScalarValue::from("PostgreSQL"), + ); + let ctx = SessionContext::new_with_config(cfg); + register_aggregate_csv(&ctx, "table1").await?; + + let df = ctx + .sql("select count (c2) filter (where c3 > $1) from table1") + .await? + .with_param_values(ParamValues::List(vec![ScalarValue::from(10u64)])); + + let df_results = df?.collect().await?; + assert_batches_eq!( + &[ + "+------------------------------------------------+", + "| count(table1.c2) FILTER (WHERE table1.c3 > $1) |", + "+------------------------------------------------+", + "| 54 |", + "+------------------------------------------------+", + ], + &df_results + ); + + Ok(()) + } } diff --git a/datafusion/core/src/dataframe/parquet.rs b/datafusion/core/src/dataframe/parquet.rs index 0ec46df0ae5d..f90b35fde6ba 100644 --- a/datafusion/core/src/dataframe/parquet.rs +++ b/datafusion/core/src/dataframe/parquet.rs @@ -15,11 +15,18 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + +use crate::datasource::file_format::{ + format_as_file_type, parquet::ParquetFormatFactory, +}; + use super::{ DataFrame, DataFrameWriteOptions, DataFusionError, LogicalPlanBuilder, RecordBatch, }; -use datafusion_common::config::{FormatOptions, TableParquetOptions}; +use datafusion_common::config::TableParquetOptions; +use datafusion_expr::dml::InsertOp; impl DataFrame { /// Execute the `DataFrame` and write the results to Parquet file(s). @@ -51,19 +58,25 @@ impl DataFrame { options: DataFrameWriteOptions, writer_options: Option, ) -> Result, DataFusionError> { - if options.overwrite { - return Err(DataFusionError::NotImplemented( - "Overwrites are not implemented for DataFrame::write_parquet.".to_owned(), - )); + if options.insert_op != InsertOp::Append { + return Err(DataFusionError::NotImplemented(format!( + "{} is not implemented for DataFrame::write_parquet.", + options.insert_op + ))); } - let props = writer_options - .unwrap_or_else(|| self.session_state.default_table_options().parquet); + let format = if let Some(parquet_opts) = writer_options { + Arc::new(ParquetFormatFactory::new_with_options(parquet_opts)) + } else { + Arc::new(ParquetFormatFactory::new()) + }; + + let file_type = format_as_file_type(format); let plan = LogicalPlanBuilder::copy_to( self.plan, path.into(), - FormatOptions::PARQUET(props), + file_type, Default::default(), options.partition_by, )? @@ -178,14 +191,14 @@ mod tests { async fn write_parquet_with_small_rg_size() -> Result<()> { // This test verifies writing a parquet file with small rg size // relative to datafusion.execution.batch_size does not panic - let mut ctx = SessionContext::new_with_config( - SessionConfig::from_string_hash_map(HashMap::from_iter( + let ctx = SessionContext::new_with_config(SessionConfig::from_string_hash_map( + &HashMap::from_iter( [("datafusion.execution.batch_size", "10")] .iter() .map(|(s1, s2)| (s1.to_string(), s2.to_string())), - ))?, - ); - register_aggregate_csv(&mut ctx, "aggregate_test_100").await?; + ), + )?); + register_aggregate_csv(&ctx, "aggregate_test_100").await?; let test_df = ctx.table("aggregate_test_100").await?; let output_path = "file://local/test.parquet"; diff --git a/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs b/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs index a16c1ae3333f..9f089c7c0cea 100644 --- a/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs +++ b/datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs @@ -203,14 +203,10 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { Arc::new(builder.finish()) } - fn build_primitive_array( - &self, - rows: RecordSlice, - col_name: &str, - ) -> ArrayRef + fn build_primitive_array(&self, rows: RecordSlice, col_name: &str) -> ArrayRef where - T: ArrowNumericType, - T::Native: num_traits::cast::NumCast, + T: ArrowNumericType + Resolver, + T::Native: NumCast, { Arc::new( rows.iter() @@ -358,7 +354,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { let builder = builder .as_any_mut() .downcast_mut::>() - .ok_or_else(||ArrowError::SchemaError( + .ok_or_else(||SchemaError( "Cast failed for ListBuilder during nested data parsing".to_string(), ))?; for val in vals { @@ -373,7 +369,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { builder.append(true); } DataType::Dictionary(_, _) => { - let builder = builder.as_any_mut().downcast_mut::>>().ok_or_else(||ArrowError::SchemaError( + let builder = builder.as_any_mut().downcast_mut::>>().ok_or_else(||SchemaError( "Cast failed for ListBuilder during nested data parsing".to_string(), ))?; for val in vals { @@ -406,7 +402,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { col_name: &str, ) -> ArrowResult where - T::Native: num_traits::cast::NumCast, + T::Native: NumCast, T: ArrowPrimitiveType + ArrowDictionaryKeyType, { let mut builder: StringDictionaryBuilder = @@ -457,12 +453,10 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { DataType::UInt64 => { self.build_dictionary_array::(rows, col_name) } - _ => Err(ArrowError::SchemaError( - "unsupported dictionary key type".to_string(), - )), + _ => Err(SchemaError("unsupported dictionary key type".to_string())), } } else { - Err(ArrowError::SchemaError( + Err(SchemaError( "dictionary types other than UTF-8 not yet supported".to_string(), )) } @@ -536,7 +530,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { DataType::UInt32 => self.read_primitive_list_values::(rows), DataType::UInt64 => self.read_primitive_list_values::(rows), DataType::Float16 => { - return Err(ArrowError::SchemaError("Float16 not supported".to_string())) + return Err(SchemaError("Float16 not supported".to_string())) } DataType::Float32 => self.read_primitive_list_values::(rows), DataType::Float64 => self.read_primitive_list_values::(rows), @@ -545,7 +539,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { | DataType::Date64 | DataType::Time32(_) | DataType::Time64(_) => { - return Err(ArrowError::SchemaError( + return Err(SchemaError( "Temporal types are not yet supported, see ARROW-4803".to_string(), )) } @@ -577,7 +571,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { // extract list values, with non-lists converted to Value::Null let array_item_count = rows .iter() - .map(|row| match row { + .map(|row| match maybe_resolve_union(row) { Value::Array(values) => values.len(), _ => 1, }) @@ -627,7 +621,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { .unwrap() } datatype => { - return Err(ArrowError::SchemaError(format!( + return Err(SchemaError(format!( "Nested list of {datatype:?} not supported" ))); } @@ -741,7 +735,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { &field_path, ), t => { - return Err(ArrowError::SchemaError(format!( + return Err(SchemaError(format!( "TimeUnit {t:?} not supported with Time64" ))) } @@ -755,7 +749,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { &field_path, ), t => { - return Err(ArrowError::SchemaError(format!( + return Err(SchemaError(format!( "TimeUnit {t:?} not supported with Time32" ))) } @@ -858,7 +852,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { make_array(data) } _ => { - return Err(ArrowError::SchemaError(format!( + return Err(SchemaError(format!( "type {:?} not supported", field.data_type() ))) @@ -874,7 +868,7 @@ impl<'a, R: Read> AvroArrowArrayReader<'a, R> { fn read_primitive_list_values(&self, rows: &[&Value]) -> ArrayData where T: ArrowPrimitiveType + ArrowNumericType, - T::Native: num_traits::cast::NumCast, + T::Native: NumCast, { let values = rows .iter() @@ -974,7 +968,7 @@ fn resolve_u8(v: &Value) -> AvroResult { other => Err(AvroError::GetU8(other.into())), }?; if let Value::Int(n) = int { - if n >= 0 && n <= std::convert::From::from(u8::MAX) { + if n >= 0 && n <= From::from(u8::MAX) { return Ok(n as u8); } } @@ -1052,7 +1046,7 @@ fn maybe_resolve_union(value: &Value) -> &Value { impl Resolver for N where N: ArrowNumericType, - N::Native: num_traits::cast::NumCast, + N::Native: NumCast, { fn resolve(value: &Value) -> Option { let value = maybe_resolve_union(value); @@ -1647,6 +1641,93 @@ mod test { assert_batches_eq!(expected, &[batch]); } + #[test] + fn test_avro_nullable_struct_array() { + let schema = apache_avro::Schema::parse_str( + r#" + { + "type": "record", + "name": "r1", + "fields": [ + { + "name": "col1", + "type": [ + "null", + { + "type": "array", + "items": { + "type": [ + "null", + { + "type": "record", + "name": "Item", + "fields": [ + { + "name": "id", + "type": "long" + } + ] + } + ] + } + } + ], + "default": null + } + ] + }"#, + ) + .unwrap(); + let jv1 = serde_json::json!({ + "col1": [ + { + "id": 234 + }, + { + "id": 345 + } + ] + }); + let r1 = apache_avro::to_value(jv1) + .unwrap() + .resolve(&schema) + .unwrap(); + let r2 = apache_avro::to_value(serde_json::json!({ "col1": null })) + .unwrap() + .resolve(&schema) + .unwrap(); + + let mut w = apache_avro::Writer::new(&schema, vec![]); + for _i in 0..5 { + w.append(r1.clone()).unwrap(); + } + w.append(r2).unwrap(); + let bytes = w.into_inner().unwrap(); + + let mut reader = ReaderBuilder::new() + .read_schema() + .with_batch_size(20) + .build(std::io::Cursor::new(bytes)) + .unwrap(); + let batch = reader.next().unwrap().unwrap(); + assert_eq!(batch.num_rows(), 6); + assert_eq!(batch.num_columns(), 1); + + let expected = [ + "+------------------------+", + "| col1 |", + "+------------------------+", + "| [{id: 234}, {id: 345}] |", + "| [{id: 234}, {id: 345}] |", + "| [{id: 234}, {id: 345}] |", + "| [{id: 234}, {id: 345}] |", + "| [{id: 234}, {id: 345}] |", + "| |", + "+------------------------+", + ]; + assert_batches_eq!(expected, &[batch]); + } + #[test] fn test_avro_iterator() { let reader = build_reader("alltypes_plain.avro", 5); diff --git a/datafusion/core/src/datasource/avro_to_arrow/mod.rs b/datafusion/core/src/datasource/avro_to_arrow/mod.rs index af0bb86a3e27..71184a78c96f 100644 --- a/datafusion/core/src/datasource/avro_to_arrow/mod.rs +++ b/datafusion/core/src/datasource/avro_to_arrow/mod.rs @@ -30,6 +30,8 @@ use crate::arrow::datatypes::Schema; use crate::error::Result; #[cfg(feature = "avro")] pub use reader::{Reader, ReaderBuilder}; +#[cfg(feature = "avro")] +pub use schema::to_arrow_schema; use std::io::Read; #[cfg(feature = "avro")] @@ -37,7 +39,7 @@ use std::io::Read; pub fn read_avro_schema_from_reader(reader: &mut R) -> Result { let avro_reader = apache_avro::Reader::new(reader)?; let schema = avro_reader.writer_schema(); - schema::to_arrow_schema(schema) + to_arrow_schema(schema) } #[cfg(not(feature = "avro"))] diff --git a/datafusion/core/src/datasource/cte_worktable.rs b/datafusion/core/src/datasource/cte_worktable.rs index afc4536f068e..23f57b12ae08 100644 --- a/datafusion/core/src/datasource/cte_worktable.rs +++ b/datafusion/core/src/datasource/cte_worktable.rs @@ -17,11 +17,12 @@ //! CteWorkTable implementation used for recursive queries -use std::any::Any; use std::sync::Arc; +use std::{any::Any, borrow::Cow}; use arrow::datatypes::SchemaRef; use async_trait::async_trait; +use datafusion_catalog::Session; use datafusion_physical_plan::work_table::WorkTableExec; use crate::{ @@ -31,11 +32,11 @@ use crate::{ }; use crate::datasource::{TableProvider, TableType}; -use crate::execution::context::SessionState; /// The temporary working table where the previous iteration of a recursive query is stored /// Naming is based on PostgreSQL's implementation. /// See here for more details: www.postgresql.org/docs/11/queries-with.html#id-1.5.6.12.5.4 +#[derive(Debug)] pub struct CteWorkTable { /// The name of the CTE work table // WIP, see https://github.com/apache/datafusion/issues/462 @@ -63,7 +64,7 @@ impl TableProvider for CteWorkTable { self } - fn get_logical_plan(&self) -> Option<&LogicalPlan> { + fn get_logical_plan(&self) -> Option> { None } @@ -77,7 +78,7 @@ impl TableProvider for CteWorkTable { async fn scan( &self, - _state: &SessionState, + _state: &dyn Session, _projection: Option<&Vec>, _filters: &[Expr], _limit: Option, diff --git a/datafusion/core/src/datasource/default_table_source.rs b/datafusion/core/src/datasource/default_table_source.rs index 977e681d6641..b4a5a76fc9ff 100644 --- a/datafusion/core/src/datasource/default_table_source.rs +++ b/datafusion/core/src/datasource/default_table_source.rs @@ -17,8 +17,8 @@ //! Default TableSource implementation used in DataFusion physical plans -use std::any::Any; use std::sync::Arc; +use std::{any::Any, borrow::Cow}; use crate::datasource::TableProvider; @@ -70,7 +70,7 @@ impl TableSource for DefaultTableSource { self.table_provider.supports_filters_pushdown(filter) } - fn get_logical_plan(&self) -> Option<&datafusion_expr::LogicalPlan> { + fn get_logical_plan(&self) -> Option> { self.table_provider.get_logical_plan() } diff --git a/datafusion/core/src/datasource/dynamic_file.rs b/datafusion/core/src/datasource/dynamic_file.rs new file mode 100644 index 000000000000..6654d0871c3f --- /dev/null +++ b/datafusion/core/src/datasource/dynamic_file.rs @@ -0,0 +1,87 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! dynamic_file_schema contains an [`UrlTableFactory`] implementation that +//! can create a [`ListingTable`] from the given url. + +use std::sync::Arc; + +use async_trait::async_trait; +use datafusion_catalog::{SessionStore, UrlTableFactory}; +use datafusion_common::plan_datafusion_err; + +use crate::datasource::listing::{ListingTable, ListingTableConfig, ListingTableUrl}; +use crate::datasource::TableProvider; +use crate::error::Result; +use crate::execution::context::SessionState; + +/// [DynamicListTableFactory] is a factory that can create a [ListingTable] from the given url. +#[derive(Default, Debug)] +pub struct DynamicListTableFactory { + /// The session store that contains the current session. + session_store: SessionStore, +} + +impl DynamicListTableFactory { + /// Create a new [DynamicListTableFactory] with the given state store. + pub fn new(session_store: SessionStore) -> Self { + Self { session_store } + } + + /// Get the session store. + pub fn session_store(&self) -> &SessionStore { + &self.session_store + } +} + +#[async_trait] +impl UrlTableFactory for DynamicListTableFactory { + async fn try_new(&self, url: &str) -> Result>> { + let Ok(table_url) = ListingTableUrl::parse(url) else { + return Ok(None); + }; + + let state = &self + .session_store() + .get_session() + .upgrade() + .and_then(|session| { + session + .read() + .as_any() + .downcast_ref::() + .cloned() + }) + .ok_or_else(|| plan_datafusion_err!("get current SessionStore error"))?; + + match ListingTableConfig::new(table_url.clone()) + .infer_options(state) + .await + { + Ok(cfg) => { + let cfg = cfg + .infer_partitions_from_path(state) + .await? + .infer_schema(state) + .await?; + ListingTable::try_new(cfg) + .map(|table| Some(Arc::new(table) as Arc)) + } + Err(_) => Ok(None), + } + } +} diff --git a/datafusion/core/src/datasource/empty.rs b/datafusion/core/src/datasource/empty.rs index 5100987520ee..bc5b82bd8c5b 100644 --- a/datafusion/core/src/datasource/empty.rs +++ b/datafusion/core/src/datasource/empty.rs @@ -22,16 +22,17 @@ use std::sync::Arc; use arrow::datatypes::*; use async_trait::async_trait; +use datafusion_catalog::Session; use datafusion_common::project_schema; use crate::datasource::{TableProvider, TableType}; use crate::error::Result; -use crate::execution::context::SessionState; use crate::logical_expr::Expr; use crate::physical_plan::{empty::EmptyExec, ExecutionPlan}; /// An empty plan that is useful for testing and generating plans /// without mapping them to actual data. +#[derive(Debug)] pub struct EmptyTable { schema: SchemaRef, partitions: usize, @@ -69,7 +70,7 @@ impl TableProvider for EmptyTable { async fn scan( &self, - _state: &SessionState, + _state: &dyn Session, projection: Option<&Vec>, _filters: &[Expr], _limit: Option, diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs index 9d58465191e1..c10ebbd6c9ea 100644 --- a/datafusion/core/src/datasource/file_format/arrow.rs +++ b/datafusion/core/src/datasource/file_format/arrow.rs @@ -21,12 +21,14 @@ use std::any::Any; use std::borrow::Cow; +use std::collections::HashMap; use std::fmt::{self, Debug}; use std::sync::Arc; use super::file_compression_type::FileCompressionType; use super::write::demux::start_demuxer_task; use super::write::{create_writer, SharedBuffer}; +use super::FileFormatFactory; use crate::datasource::file_format::FileFormat; use crate::datasource::physical_plan::{ ArrowExec, FileGroupDisplay, FileScanConfig, FileSinkConfig, @@ -40,14 +42,19 @@ use arrow::ipc::reader::FileReader; use arrow::ipc::writer::IpcWriteOptions; use arrow::ipc::{root_as_message, CompressionType}; use arrow_schema::{ArrowError, Schema, SchemaRef}; -use datafusion_common::{not_impl_err, DataFusionError, FileType, Statistics}; +use datafusion_common::parsers::CompressionTypeVariant; +use datafusion_common::{ + not_impl_err, DataFusionError, GetExt, Statistics, DEFAULT_ARROW_EXTENSION, +}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; +use datafusion_expr::dml::InsertOp; +use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_plan::insert::{DataSink, DataSinkExec}; use datafusion_physical_plan::metrics::MetricsSet; use async_trait::async_trait; use bytes::Bytes; +use datafusion_physical_expr_common::sort_expr::LexRequirement; use futures::stream::BoxStream; use futures::StreamExt; use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; @@ -61,6 +68,42 @@ const INITIAL_BUFFER_BYTES: usize = 1048576; /// If the buffered Arrow data exceeds this size, it is flushed to object store const BUFFER_FLUSH_BYTES: usize = 1024000; +#[derive(Default, Debug)] +/// Factory struct used to create [ArrowFormat] +pub struct ArrowFormatFactory; + +impl ArrowFormatFactory { + /// Creates an instance of [ArrowFormatFactory] + pub fn new() -> Self { + Self {} + } +} + +impl FileFormatFactory for ArrowFormatFactory { + fn create( + &self, + _state: &SessionState, + _format_options: &HashMap, + ) -> Result> { + Ok(Arc::new(ArrowFormat)) + } + + fn default(&self) -> Arc { + Arc::new(ArrowFormat) + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +impl GetExt for ArrowFormatFactory { + fn get_ext(&self) -> String { + // Removes the dot, i.e. ".parquet" -> "parquet" + DEFAULT_ARROW_EXTENSION[1..].to_string() + } +} + /// Arrow `FileFormat` implementation. #[derive(Default, Debug)] pub struct ArrowFormat; @@ -71,6 +114,23 @@ impl FileFormat for ArrowFormat { self } + fn get_ext(&self) -> String { + ArrowFormatFactory::new().get_ext() + } + + fn get_ext_with_compression( + &self, + file_compression_type: &FileCompressionType, + ) -> Result { + let ext = self.get_ext(); + match file_compression_type.get_variant() { + CompressionTypeVariant::UNCOMPRESSED => Ok(ext), + _ => Err(DataFusionError::Internal( + "Arrow FileFormat does not support compression.".into(), + )), + } + } + async fn infer_schema( &self, _state: &SessionState, @@ -120,9 +180,9 @@ impl FileFormat for ArrowFormat { input: Arc, _state: &SessionState, conf: FileSinkConfig, - order_requirements: Option>, + order_requirements: Option, ) -> Result> { - if conf.overwrite { + if conf.insert_op != InsertOp::Append { return not_impl_err!("Overwrites are not implemented yet for Arrow format"); } @@ -136,10 +196,6 @@ impl FileFormat for ArrowFormat { order_requirements, )) as _) } - - fn file_type(&self) -> FileType { - FileType::ARROW - } } /// Implements [`DataSink`] for writing to arrow_ipc files @@ -227,6 +283,7 @@ impl DataSink for ArrowFileSink { part_col, self.config.table_paths[0].clone(), "arrow".into(), + self.config.keep_partition_by_columns, ); let mut file_write_tasks: JoinSet> = @@ -286,7 +343,10 @@ impl DataSink for ArrowFileSink { } } - demux_task.join_unwind().await?; + demux_task + .join_unwind() + .await + .map_err(DataFusionError::ExecutionJoin)??; Ok(row_count as u64) } } @@ -302,7 +362,7 @@ async fn infer_schema_from_file_stream( // Expected format: // - 6 bytes // - 2 bytes - // - 4 bytes, not present below v0.15.0 + // - 4 bytes, not present below v0.15.0 // - 4 bytes // // @@ -314,7 +374,7 @@ async fn infer_schema_from_file_stream( // Files should start with these magic bytes if bytes[0..6] != ARROW_MAGIC { return Err(ArrowError::ParseError( - "Arrow file does not contian correct header".to_string(), + "Arrow file does not contain correct header".to_string(), ))?; } diff --git a/datafusion/core/src/datasource/file_format/avro.rs b/datafusion/core/src/datasource/file_format/avro.rs index 132dae14c684..5190bdbe153a 100644 --- a/datafusion/core/src/datasource/file_format/avro.rs +++ b/datafusion/core/src/datasource/file_format/avro.rs @@ -18,16 +18,23 @@ //! [`AvroFormat`] Apache Avro [`FileFormat`] abstractions use std::any::Any; +use std::collections::HashMap; +use std::fmt; use std::sync::Arc; use arrow::datatypes::Schema; use arrow::datatypes::SchemaRef; use async_trait::async_trait; -use datafusion_common::FileType; +use datafusion_common::parsers::CompressionTypeVariant; +use datafusion_common::DataFusionError; +use datafusion_common::GetExt; +use datafusion_common::DEFAULT_AVRO_EXTENSION; use datafusion_physical_expr::PhysicalExpr; use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; +use super::file_compression_type::FileCompressionType; use super::FileFormat; +use super::FileFormatFactory; use crate::datasource::avro_to_arrow::read_avro_schema_from_reader; use crate::datasource::physical_plan::{AvroExec, FileScanConfig}; use crate::error::Result; @@ -35,6 +42,48 @@ use crate::execution::context::SessionState; use crate::physical_plan::ExecutionPlan; use crate::physical_plan::Statistics; +#[derive(Default)] +/// Factory struct used to create [AvroFormat] +pub struct AvroFormatFactory; + +impl AvroFormatFactory { + /// Creates an instance of [AvroFormatFactory] + pub fn new() -> Self { + Self {} + } +} + +impl FileFormatFactory for AvroFormatFactory { + fn create( + &self, + _state: &SessionState, + _format_options: &HashMap, + ) -> Result> { + Ok(Arc::new(AvroFormat)) + } + + fn default(&self) -> Arc { + Arc::new(AvroFormat) + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +impl fmt::Debug for AvroFormatFactory { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AvroFormatFactory").finish() + } +} + +impl GetExt for AvroFormatFactory { + fn get_ext(&self) -> String { + // Removes the dot, i.e. ".parquet" -> "parquet" + DEFAULT_AVRO_EXTENSION[1..].to_string() + } +} + /// Avro `FileFormat` implementation. #[derive(Default, Debug)] pub struct AvroFormat; @@ -45,6 +94,23 @@ impl FileFormat for AvroFormat { self } + fn get_ext(&self) -> String { + AvroFormatFactory::new().get_ext() + } + + fn get_ext_with_compression( + &self, + file_compression_type: &FileCompressionType, + ) -> Result { + let ext = self.get_ext(); + match file_compression_type.get_variant() { + CompressionTypeVariant::UNCOMPRESSED => Ok(ext), + _ => Err(DataFusionError::Internal( + "Avro FileFormat does not support compression.".into(), + )), + } + } + async fn infer_schema( &self, _state: &SessionState, @@ -89,10 +155,6 @@ impl FileFormat for AvroFormat { let exec = AvroExec::new(conf); Ok(Arc::new(exec)) } - - fn file_type(&self) -> FileType { - FileType::AVRO - } } #[cfg(test)] diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 645f98cd3fb0..0335c8aa3ff6 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -18,12 +18,12 @@ //! [`CsvFormat`], Comma Separated Value (CSV) [`FileFormat`] abstractions use std::any::Any; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::fmt::{self, Debug}; use std::sync::Arc; use super::write::orchestration::stateless_multipart_put; -use super::FileFormat; +use super::{FileFormat, FileFormatFactory}; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::file_format::write::BatchSerializer; use crate::datasource::physical_plan::{ @@ -32,26 +32,101 @@ use crate::datasource::physical_plan::{ use crate::error::Result; use crate::execution::context::SessionState; use crate::physical_plan::insert::{DataSink, DataSinkExec}; -use crate::physical_plan::{DisplayAs, DisplayFormatType, Statistics}; -use crate::physical_plan::{ExecutionPlan, SendableRecordBatchStream}; +use crate::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, SendableRecordBatchStream, Statistics, +}; use arrow::array::RecordBatch; use arrow::csv::WriterBuilder; use arrow::datatypes::SchemaRef; use arrow::datatypes::{DataType, Field, Fields, Schema}; -use datafusion_common::config::CsvOptions; +use datafusion_common::config::{ConfigField, ConfigFileType, CsvOptions}; use datafusion_common::file_options::csv_writer::CsvWriterOptions; -use datafusion_common::{exec_err, not_impl_err, DataFusionError, FileType}; +use datafusion_common::{ + exec_err, not_impl_err, DataFusionError, GetExt, DEFAULT_CSV_EXTENSION, +}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; +use datafusion_expr::dml::InsertOp; +use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_plan::metrics::MetricsSet; use async_trait::async_trait; use bytes::{Buf, Bytes}; +use datafusion_physical_expr_common::sort_expr::LexRequirement; use futures::stream::BoxStream; use futures::{pin_mut, Stream, StreamExt, TryStreamExt}; use object_store::{delimited::newline_delimited_stream, ObjectMeta, ObjectStore}; +#[derive(Default)] +/// Factory struct used to create [CsvFormatFactory] +pub struct CsvFormatFactory { + /// the options for csv file read + pub options: Option, +} + +impl CsvFormatFactory { + /// Creates an instance of [CsvFormatFactory] + pub fn new() -> Self { + Self { options: None } + } + + /// Creates an instance of [CsvFormatFactory] with customized default options + pub fn new_with_options(options: CsvOptions) -> Self { + Self { + options: Some(options), + } + } +} + +impl Debug for CsvFormatFactory { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("CsvFormatFactory") + .field("options", &self.options) + .finish() + } +} + +impl FileFormatFactory for CsvFormatFactory { + fn create( + &self, + state: &SessionState, + format_options: &HashMap, + ) -> Result> { + let csv_options = match &self.options { + None => { + let mut table_options = state.default_table_options(); + table_options.set_config_format(ConfigFileType::CSV); + table_options.alter_with_string_hash_map(format_options)?; + table_options.csv + } + Some(csv_options) => { + let mut csv_options = csv_options.clone(); + for (k, v) in format_options { + csv_options.set(k, v)?; + } + csv_options + } + }; + + Ok(Arc::new(CsvFormat::default().with_options(csv_options))) + } + + fn default(&self) -> Arc { + Arc::new(CsvFormat::default()) + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +impl GetExt for CsvFormatFactory { + fn get_ext(&self) -> String { + // Removes the dot, i.e. ".parquet" -> "parquet" + DEFAULT_CSV_EXTENSION[1..].to_string() + } +} + /// Character Separated Value `FileFormat` implementation. #[derive(Debug, Default)] pub struct CsvFormat { @@ -136,15 +211,22 @@ impl CsvFormat { /// Set true to indicate that the first line is a header. /// - default to true pub fn with_has_header(mut self, has_header: bool) -> Self { - self.options.has_header = has_header; + self.options.has_header = Some(has_header); self } - /// True if the first line is a header. - pub fn has_header(&self) -> bool { + /// Returns `Some(true)` if the first line is a header, `Some(false)` if + /// it is not, and `None` if it is not specified. + pub fn has_header(&self) -> Option { self.options.has_header } + /// Lines beginning with this byte are ignored. + pub fn with_comment(mut self, comment: Option) -> Self { + self.options.comment = comment; + self + } + /// The character separating values within a row. /// - default to ',' pub fn with_delimiter(mut self, delimiter: u8) -> Self { @@ -166,6 +248,25 @@ impl CsvFormat { self } + /// The character used to indicate the end of a row. + /// - default to None (CRLF) + pub fn with_terminator(mut self, terminator: Option) -> Self { + self.options.terminator = terminator; + self + } + + /// Specifies whether newlines in (quoted) values are supported. + /// + /// Parsing newlines in quoted values may be affected by execution behaviour such as + /// parallel file scanning. Setting this to `true` ensures that newlines in values are + /// parsed successfully, which may reduce performance. + /// + /// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting. + pub fn with_newlines_in_values(mut self, newlines_in_values: bool) -> Self { + self.options.newlines_in_values = Some(newlines_in_values); + self + } + /// Set a `FileCompressionType` of CSV /// - defaults to `FileCompressionType::UNCOMPRESSED` pub fn with_file_compression_type( @@ -198,9 +299,21 @@ impl FileFormat for CsvFormat { self } + fn get_ext(&self) -> String { + CsvFormatFactory::new().get_ext() + } + + fn get_ext_with_compression( + &self, + file_compression_type: &FileCompressionType, + ) -> Result { + let ext = self.get_ext(); + Ok(format!("{}{}", ext, file_compression_type.get_ext())) + } + async fn infer_schema( &self, - _state: &SessionState, + state: &SessionState, store: &Arc, objects: &[ObjectMeta], ) -> Result { @@ -211,8 +324,14 @@ impl FileFormat for CsvFormat { for object in objects { let stream = self.read_to_delimited_chunks(store, object).await; let (schema, records_read) = self - .infer_schema_from_stream(records_to_read, stream) - .await?; + .infer_schema_from_stream(state, records_to_read, stream) + .await + .map_err(|err| { + DataFusionError::Context( + format!("Error when processing CSV file {}", &object.location), + Box::new(err), + ) + })?; records_to_read -= records_read; schemas.push(schema); if records_to_read == 0 { @@ -236,33 +355,64 @@ impl FileFormat for CsvFormat { async fn create_physical_plan( &self, - _state: &SessionState, + state: &SessionState, conf: FileScanConfig, _filters: Option<&Arc>, ) -> Result> { - let exec = CsvExec::new( - conf, - self.options.has_header, - self.options.delimiter, - self.options.quote, - self.options.escape, - self.options.compression.into(), - ); + // Consult configuration options for default values + let has_header = self + .options + .has_header + .unwrap_or(state.config_options().catalog.has_header); + let newlines_in_values = self + .options + .newlines_in_values + .unwrap_or(state.config_options().catalog.newlines_in_values); + + let exec = CsvExec::builder(conf) + .with_has_header(has_header) + .with_delimeter(self.options.delimiter) + .with_quote(self.options.quote) + .with_terminator(self.options.terminator) + .with_escape(self.options.escape) + .with_comment(self.options.comment) + .with_newlines_in_values(newlines_in_values) + .with_file_compression_type(self.options.compression.into()) + .build(); Ok(Arc::new(exec)) } async fn create_writer_physical_plan( &self, input: Arc, - _state: &SessionState, + state: &SessionState, conf: FileSinkConfig, - order_requirements: Option>, + order_requirements: Option, ) -> Result> { - if conf.overwrite { + if conf.insert_op != InsertOp::Append { return not_impl_err!("Overwrites are not implemented yet for CSV"); } - let writer_options = CsvWriterOptions::try_from(&self.options)?; + // `has_header` and `newlines_in_values` fields of CsvOptions may inherit + // their values from session from configuration settings. To support + // this logic, writer options are built from the copy of `self.options` + // with updated values of these special fields. + let has_header = self + .options() + .has_header + .unwrap_or(state.config_options().catalog.has_header); + let newlines_in_values = self + .options() + .newlines_in_values + .unwrap_or(state.config_options().catalog.newlines_in_values); + + let options = self + .options() + .clone() + .with_has_header(has_header) + .with_newlines_in_values(newlines_in_values); + + let writer_options = CsvWriterOptions::try_from(&options)?; let sink_schema = conf.output_schema().clone(); let sink = Arc::new(CsvSink::new(conf, writer_options)); @@ -274,10 +424,6 @@ impl FileFormat for CsvFormat { order_requirements, )) as _) } - - fn file_type(&self) -> FileType { - FileType::CSV - } } impl CsvFormat { @@ -286,20 +432,38 @@ impl CsvFormat { /// number of lines that were read async fn infer_schema_from_stream( &self, + state: &SessionState, mut records_to_read: usize, stream: impl Stream>, ) -> Result<(Schema, usize)> { let mut total_records_read = 0; let mut column_names = vec![]; let mut column_type_possibilities = vec![]; - let mut first_chunk = true; + let mut record_number = -1; pin_mut!(stream); while let Some(chunk) = stream.next().await.transpose()? { - let format = arrow::csv::reader::Format::default() - .with_header(self.options.has_header && first_chunk) - .with_delimiter(self.options.delimiter); + record_number += 1; + let first_chunk = record_number == 0; + let mut format = arrow::csv::reader::Format::default() + .with_header( + first_chunk + && self + .options + .has_header + .unwrap_or(state.config_options().catalog.has_header), + ) + .with_delimiter(self.options.delimiter) + .with_quote(self.options.quote); + + if let Some(escape) = self.options.escape { + format = format.with_escape(escape); + } + + if let Some(comment) = self.options.comment { + format = format.with_comment(comment); + } let (Schema { fields, .. }, records_read) = format.infer_schema(chunk.reader(), Some(records_to_read))?; @@ -320,14 +484,14 @@ impl CsvFormat { (field.name().clone(), possibilities) }) .unzip(); - first_chunk = false; } else { if fields.len() != column_type_possibilities.len() { return exec_err!( "Encountered unequal lengths between records on CSV file whilst inferring schema. \ - Expected {} records, found {} records", + Expected {} fields, found {} fields at record {}", column_type_possibilities.len(), - fields.len() + fields.len(), + record_number + 1 ); } @@ -536,10 +700,12 @@ mod tests { use arrow::compute::concat_batches; use datafusion_common::cast::as_string_array; + use datafusion_common::internal_err; use datafusion_common::stats::Precision; - use datafusion_common::{internal_err, GetExt}; + use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_expr::{col, lit}; + use crate::execution::session_state::SessionStateBuilder; use chrono::DateTime; use object_store::local::LocalFileSystem; use object_store::path::Path; @@ -552,9 +718,10 @@ mod tests { let session_ctx = SessionContext::new_with_config(config); let state = session_ctx.state(); let task_ctx = state.task_ctx(); - // skip column 9 that overflows the automaticly discovered column type of i64 (u64 would work) + // skip column 9 that overflows the automatically discovered column type of i64 (u64 would work) let projection = Some(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12]); - let exec = get_exec(&state, "aggregate_test_100.csv", projection, None).await?; + let exec = + get_exec(&state, "aggregate_test_100.csv", projection, None, true).await?; let stream = exec.execute(0, task_ctx)?; let tt_batches: i32 = stream @@ -582,7 +749,7 @@ mod tests { let task_ctx = session_ctx.task_ctx(); let projection = Some(vec![0, 1, 2, 3]); let exec = - get_exec(&state, "aggregate_test_100.csv", projection, Some(1)).await?; + get_exec(&state, "aggregate_test_100.csv", projection, Some(1), true).await?; let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(4, batches[0].num_columns()); @@ -597,7 +764,8 @@ mod tests { let state = session_ctx.state(); let projection = None; - let exec = get_exec(&state, "aggregate_test_100.csv", projection, None).await?; + let exec = + get_exec(&state, "aggregate_test_100.csv", projection, None, true).await?; let x: Vec = exec .schema() @@ -616,7 +784,7 @@ mod tests { "c7: Int64", "c8: Int64", "c9: Int64", - "c10: Int64", + "c10: Utf8", "c11: Float64", "c12: Float64", "c13: Utf8" @@ -633,7 +801,8 @@ mod tests { let state = session_ctx.state(); let task_ctx = session_ctx.task_ctx(); let projection = Some(vec![0]); - let exec = get_exec(&state, "aggregate_test_100.csv", projection, None).await?; + let exec = + get_exec(&state, "aggregate_test_100.csv", projection, None, true).await?; let batches = collect(exec, task_ctx).await.expect("Collect batches"); @@ -703,6 +872,55 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_infer_schema_escape_chars() -> Result<()> { + let session_ctx = SessionContext::new(); + let state = session_ctx.state(); + let variable_object_store = Arc::new(VariableStream::new( + Bytes::from( + r#"c1,c2,c3,c4 +0.3,"Here, is a comma\"",third,3 +0.31,"double quotes are ok, "" quote",third again,9 +0.314,abc,xyz,27"#, + ), + 1, + )); + let object_meta = ObjectMeta { + location: Path::parse("/")?, + last_modified: DateTime::default(), + size: usize::MAX, + e_tag: None, + version: None, + }; + + let num_rows_to_read = 3; + let csv_format = CsvFormat::default() + .with_has_header(true) + .with_schema_infer_max_rec(num_rows_to_read) + .with_quote(b'"') + .with_escape(Some(b'\\')); + + let inferred_schema = csv_format + .infer_schema( + &state, + &(variable_object_store.clone() as Arc), + &[object_meta], + ) + .await?; + + let actual_fields: Vec<_> = inferred_schema + .fields() + .iter() + .map(|f| format!("{}: {:?}", f.name(), f.data_type())) + .collect(); + + assert_eq!( + vec!["c1: Float64", "c2: Utf8", "c3: Utf8", "c4: Int64",], + actual_fields + ); + Ok(()) + } + #[rstest( file_compression_type, case(FileCompressionType::UNCOMPRESSED), @@ -716,8 +934,15 @@ mod tests { async fn query_compress_data( file_compression_type: FileCompressionType, ) -> Result<()> { + let runtime = Arc::new(RuntimeEnvBuilder::new().build()?); + let mut cfg = SessionConfig::new(); + cfg.options_mut().catalog.has_header = true; + let session_state = SessionStateBuilder::new() + .with_config(cfg) + .with_runtime_env(runtime) + .with_default_features() + .build(); let integration = LocalFileSystem::new_with_prefix(arrow_test_data()).unwrap(); - let path = Path::from("csv/aggregate_test_100.csv"); let csv = CsvFormat::default().with_has_header(true); let records_to_read = csv.options().schema_infer_max_rec; @@ -744,7 +969,7 @@ mod tests { Field::new("c7", DataType::Int64, true), Field::new("c8", DataType::Int64, true), Field::new("c9", DataType::Int64, true), - Field::new("c10", DataType::Int64, true), + Field::new("c10", DataType::Utf8, true), Field::new("c11", DataType::Float64, true), Field::new("c12", DataType::Float64, true), Field::new("c13", DataType::Utf8, true), @@ -757,7 +982,7 @@ mod tests { .read_to_delimited_chunks_from_stream(compressed_stream.unwrap()) .await; let (schema, records_read) = compressed_csv - .infer_schema_from_stream(records_to_read, decoded_stream) + .infer_schema_from_stream(&session_state, records_to_read, decoded_stream) .await?; assert_eq!(expected, schema); @@ -803,9 +1028,10 @@ mod tests { file_name: &str, projection: Option>, limit: Option, + has_header: bool, ) -> Result> { - let root = format!("{}/csv", crate::test_util::arrow_test_data()); - let format = CsvFormat::default(); + let root = format!("{}/csv", arrow_test_data()); + let format = CsvFormat::default().with_has_header(has_header); scan_format(state, &format, &root, file_name, projection, limit).await } @@ -901,7 +1127,7 @@ mod tests { #[rustfmt::skip] let expected = ["+--------------+", - "| SUM(aggr.c2) |", + "| sum(aggr.c2) |", "+--------------+", "| 285 |", "+--------------+"]; @@ -938,7 +1164,7 @@ mod tests { #[rustfmt::skip] let expected = ["+--------------+", - "| SUM(aggr.c3) |", + "| sum(aggr.c3) |", "+--------------+", "| 781 |", "+--------------+"]; @@ -948,6 +1174,41 @@ mod tests { Ok(()) } + #[rstest(n_partitions, case(1), case(2), case(3), case(4))] + #[tokio::test] + async fn test_csv_parallel_newlines_in_values(n_partitions: usize) -> Result<()> { + let config = SessionConfig::new() + .with_repartition_file_scans(true) + .with_repartition_file_min_size(0) + .with_target_partitions(n_partitions); + let csv_options = CsvReadOptions::default() + .has_header(true) + .newlines_in_values(true); + let ctx = SessionContext::new_with_config(config); + let testdata = arrow_test_data(); + ctx.register_csv( + "aggr", + &format!("{testdata}/csv/aggregate_test_100.csv"), + csv_options, + ) + .await?; + + let query = "select sum(c3) from aggr;"; + let query_result = ctx.sql(query).await?.collect().await?; + let actual_partitions = count_query_csv_partitions(&ctx, query).await?; + + #[rustfmt::skip] + let expected = ["+--------------+", + "| sum(aggr.c3) |", + "+--------------+", + "| 781 |", + "+--------------+"]; + assert_batches_eq!(expected, &query_result); + assert_eq!(1, actual_partitions); // csv won't be scanned in parallel when newlines_in_values is set + + Ok(()) + } + /// Read a single empty csv file in parallel /// /// empty_0_byte.csv: @@ -1031,9 +1292,9 @@ mod tests { .with_repartition_file_min_size(0) .with_target_partitions(n_partitions); let ctx = SessionContext::new_with_config(config); - let file_format = CsvFormat::default().with_has_header(false); - let listing_options = ListingOptions::new(Arc::new(file_format)) - .with_file_extension(FileType::CSV.get_ext()); + let file_format = Arc::new(CsvFormat::default().with_has_header(false)); + let listing_options = ListingOptions::new(file_format.clone()) + .with_file_extension(file_format.get_ext()); ctx.register_listing_table( "empty", "tests/data/empty_files/all_empty/", @@ -1084,9 +1345,9 @@ mod tests { .with_repartition_file_min_size(0) .with_target_partitions(n_partitions); let ctx = SessionContext::new_with_config(config); - let file_format = CsvFormat::default().with_has_header(false); - let listing_options = ListingOptions::new(Arc::new(file_format)) - .with_file_extension(FileType::CSV.get_ext()); + let file_format = Arc::new(CsvFormat::default().with_has_header(false)); + let listing_options = ListingOptions::new(file_format.clone()) + .with_file_extension(file_format.get_ext()); ctx.register_listing_table( "empty", "tests/data/empty_files/some_empty", @@ -1104,7 +1365,7 @@ mod tests { #[rustfmt::skip] let expected = ["+---------------------+", - "| SUM(empty.column_1) |", + "| sum(empty.column_1) |", "+---------------------+", "| 10 |", "+---------------------+"]; @@ -1143,15 +1404,12 @@ mod tests { #[rustfmt::skip] let expected = ["+-----------------------+", - "| SUM(one_col.column_1) |", + "| sum(one_col.column_1) |", "+-----------------------+", "| 50 |", "+-----------------------+"]; - let file_size = if cfg!(target_os = "windows") { - 30 // new line on Win is '\r\n' - } else { - 20 - }; + + let file_size = std::fs::metadata("tests/data/one_col.csv")?.len() as usize; // A 20-Byte file at most get partitioned into 20 chunks let expected_partitions = if n_partitions <= file_size { n_partitions diff --git a/datafusion/core/src/datasource/file_format/file_compression_type.rs b/datafusion/core/src/datasource/file_format/file_compression_type.rs index c1fbe352d37b..a054094822d0 100644 --- a/datafusion/core/src/datasource/file_format/file_compression_type.rs +++ b/datafusion/core/src/datasource/file_format/file_compression_type.rs @@ -22,7 +22,7 @@ use std::str::FromStr; use crate::error::{DataFusionError, Result}; use datafusion_common::parsers::CompressionTypeVariant::{self, *}; -use datafusion_common::{FileType, GetExt}; +use datafusion_common::GetExt; #[cfg(feature = "compression")] use async_compression::tokio::bufread::{ @@ -112,6 +112,11 @@ impl FileCompressionType { variant: UNCOMPRESSED, }; + /// Read only access to self.variant + pub fn get_variant(&self) -> &CompressionTypeVariant { + &self.variant + } + /// The file is compressed or not pub const fn is_compressed(&self) -> bool { self.variant.is_compressed() @@ -245,90 +250,16 @@ pub trait FileTypeExt { fn get_ext_with_compression(&self, c: FileCompressionType) -> Result; } -impl FileTypeExt for FileType { - fn get_ext_with_compression(&self, c: FileCompressionType) -> Result { - let ext = self.get_ext(); - - match self { - FileType::JSON | FileType::CSV => Ok(format!("{}{}", ext, c.get_ext())), - FileType::AVRO | FileType::ARROW => match c.variant { - UNCOMPRESSED => Ok(ext), - _ => Err(DataFusionError::Internal( - "FileCompressionType can be specified for CSV/JSON FileType.".into(), - )), - }, - #[cfg(feature = "parquet")] - FileType::PARQUET => match c.variant { - UNCOMPRESSED => Ok(ext), - _ => Err(DataFusionError::Internal( - "FileCompressionType can be specified for CSV/JSON FileType.".into(), - )), - }, - } - } -} - #[cfg(test)] mod tests { use std::str::FromStr; - use crate::datasource::file_format::file_compression_type::{ - FileCompressionType, FileTypeExt, - }; + use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::error::DataFusionError; - use datafusion_common::file_options::file_type::FileType; - use bytes::Bytes; use futures::StreamExt; - #[test] - fn get_ext_with_compression() { - for (file_type, compression, extension) in [ - (FileType::CSV, FileCompressionType::UNCOMPRESSED, ".csv"), - (FileType::CSV, FileCompressionType::GZIP, ".csv.gz"), - (FileType::CSV, FileCompressionType::XZ, ".csv.xz"), - (FileType::CSV, FileCompressionType::BZIP2, ".csv.bz2"), - (FileType::CSV, FileCompressionType::ZSTD, ".csv.zst"), - (FileType::JSON, FileCompressionType::UNCOMPRESSED, ".json"), - (FileType::JSON, FileCompressionType::GZIP, ".json.gz"), - (FileType::JSON, FileCompressionType::XZ, ".json.xz"), - (FileType::JSON, FileCompressionType::BZIP2, ".json.bz2"), - (FileType::JSON, FileCompressionType::ZSTD, ".json.zst"), - ] { - assert_eq!( - file_type.get_ext_with_compression(compression).unwrap(), - extension - ); - } - - let mut ty_ext_tuple = vec![]; - ty_ext_tuple.push((FileType::AVRO, ".avro")); - #[cfg(feature = "parquet")] - ty_ext_tuple.push((FileType::PARQUET, ".parquet")); - - // Cannot specify compression for these file types - for (file_type, extension) in ty_ext_tuple { - assert_eq!( - file_type - .get_ext_with_compression(FileCompressionType::UNCOMPRESSED) - .unwrap(), - extension - ); - for compression in [ - FileCompressionType::GZIP, - FileCompressionType::XZ, - FileCompressionType::BZIP2, - FileCompressionType::ZSTD, - ] { - assert!(matches!( - file_type.get_ext_with_compression(compression), - Err(DataFusionError::Internal(_)) - )); - } - } - } - #[test] fn from_str() { for (ext, compression_type) in [ diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index efc0aa4328d8..fd97da52165b 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -18,13 +18,14 @@ //! [`JsonFormat`]: Line delimited JSON [`FileFormat`] abstractions use std::any::Any; +use std::collections::HashMap; use std::fmt; use std::fmt::Debug; use std::io::BufReader; use std::sync::Arc; use super::write::orchestration::stateless_multipart_put; -use super::{FileFormat, FileScanConfig}; +use super::{FileFormat, FileFormatFactory, FileScanConfig}; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::file_format::write::BatchSerializer; use crate::datasource::physical_plan::FileGroupDisplay; @@ -41,18 +42,90 @@ use arrow::datatypes::SchemaRef; use arrow::json; use arrow::json::reader::{infer_json_schema_from_iterator, ValueIter}; use arrow_array::RecordBatch; -use datafusion_common::config::JsonOptions; +use datafusion_common::config::{ConfigField, ConfigFileType, JsonOptions}; use datafusion_common::file_options::json_writer::JsonWriterOptions; -use datafusion_common::{not_impl_err, FileType}; +use datafusion_common::{not_impl_err, GetExt, DEFAULT_JSON_EXTENSION}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; +use datafusion_expr::dml::InsertOp; +use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_plan::metrics::MetricsSet; use datafusion_physical_plan::ExecutionPlan; use async_trait::async_trait; use bytes::{Buf, Bytes}; +use datafusion_physical_expr_common::sort_expr::LexRequirement; use object_store::{GetResultPayload, ObjectMeta, ObjectStore}; +#[derive(Default)] +/// Factory struct used to create [JsonFormat] +pub struct JsonFormatFactory { + /// the options carried by format factory + pub options: Option, +} + +impl JsonFormatFactory { + /// Creates an instance of [JsonFormatFactory] + pub fn new() -> Self { + Self { options: None } + } + + /// Creates an instance of [JsonFormatFactory] with customized default options + pub fn new_with_options(options: JsonOptions) -> Self { + Self { + options: Some(options), + } + } +} + +impl FileFormatFactory for JsonFormatFactory { + fn create( + &self, + state: &SessionState, + format_options: &HashMap, + ) -> Result> { + let json_options = match &self.options { + None => { + let mut table_options = state.default_table_options(); + table_options.set_config_format(ConfigFileType::JSON); + table_options.alter_with_string_hash_map(format_options)?; + table_options.json + } + Some(json_options) => { + let mut json_options = json_options.clone(); + for (k, v) in format_options { + json_options.set(k, v)?; + } + json_options + } + }; + + Ok(Arc::new(JsonFormat::default().with_options(json_options))) + } + + fn default(&self) -> Arc { + Arc::new(JsonFormat::default()) + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +impl GetExt for JsonFormatFactory { + fn get_ext(&self) -> String { + // Removes the dot, i.e. ".parquet" -> "parquet" + DEFAULT_JSON_EXTENSION[1..].to_string() + } +} + +impl Debug for JsonFormatFactory { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("JsonFormatFactory") + .field("options", &self.options) + .finish() + } +} + /// New line delimited JSON `FileFormat` implementation. #[derive(Debug, Default)] pub struct JsonFormat { @@ -95,6 +168,18 @@ impl FileFormat for JsonFormat { self } + fn get_ext(&self) -> String { + JsonFormatFactory::new().get_ext() + } + + fn get_ext_with_compression( + &self, + file_compression_type: &FileCompressionType, + ) -> Result { + let ext = self.get_ext(); + Ok(format!("{}{}", ext, file_compression_type.get_ext())) + } + async fn infer_schema( &self, _state: &SessionState, @@ -166,9 +251,9 @@ impl FileFormat for JsonFormat { input: Arc, _state: &SessionState, conf: FileSinkConfig, - order_requirements: Option>, + order_requirements: Option, ) -> Result> { - if conf.overwrite { + if conf.insert_op != InsertOp::Append { return not_impl_err!("Overwrites are not implemented yet for Json"); } @@ -184,10 +269,6 @@ impl FileFormat for JsonFormat { order_requirements, )) as _) } - - fn file_type(&self) -> FileType { - FileType::JSON - } } impl Default for JsonSerializer { @@ -219,7 +300,7 @@ impl BatchSerializer for JsonSerializer { pub struct JsonSink { /// Config options for writing data config: FileSinkConfig, - /// + /// Writer options for underlying Json writer writer_options: JsonWriterOptions, } @@ -474,7 +555,7 @@ mod tests { ctx.register_json("json_parallel", table_path, options) .await?; - let query = "SELECT SUM(a) FROM json_parallel;"; + let query = "SELECT sum(a) FROM json_parallel;"; let result = ctx.sql(query).await?.collect().await?; let actual_partitions = count_num_partitions(&ctx, query).await?; @@ -482,7 +563,7 @@ mod tests { #[rustfmt::skip] let expected = [ "+----------------------+", - "| SUM(json_parallel.a) |", + "| sum(json_parallel.a) |", "+----------------------+", "| -7 |", "+----------------------+" diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index 5ee0f7186703..24f1111517d2 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -32,7 +32,8 @@ pub mod parquet; pub mod write; use std::any::Any; -use std::fmt; +use std::collections::HashMap; +use std::fmt::{self, Display}; use std::sync::Arc; use crate::arrow::datatypes::SchemaRef; @@ -41,23 +42,57 @@ use crate::error::Result; use crate::execution::context::SessionState; use crate::physical_plan::{ExecutionPlan, Statistics}; -use datafusion_common::{not_impl_err, FileType}; -use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; +use arrow_schema::{DataType, Field, FieldRef, Schema}; +use datafusion_common::file_options::file_type::FileType; +use datafusion_common::{internal_err, not_impl_err, GetExt}; +use datafusion_expr::Expr; +use datafusion_physical_expr::PhysicalExpr; use async_trait::async_trait; +use datafusion_physical_expr_common::sort_expr::LexRequirement; +use file_compression_type::FileCompressionType; use object_store::{ObjectMeta, ObjectStore}; +use std::fmt::Debug; + +/// Factory for creating [`FileFormat`] instances based on session and command level options +/// +/// Users can provide their own `FileFormatFactory` to support arbitrary file formats +pub trait FileFormatFactory: Sync + Send + GetExt + Debug { + /// Initialize a [FileFormat] and configure based on session and command level options + fn create( + &self, + state: &SessionState, + format_options: &HashMap, + ) -> Result>; + + /// Initialize a [FileFormat] with all options set to default values + fn default(&self) -> Arc; + + /// Returns the table source as [`Any`] so that it can be + /// downcast to a specific implementation. + fn as_any(&self) -> &dyn Any; +} /// This trait abstracts all the file format specific implementations /// from the [`TableProvider`]. This helps code re-utilization across /// providers that support the same file formats. /// -/// [`TableProvider`]: crate::datasource::provider::TableProvider +/// [`TableProvider`]: crate::catalog::TableProvider #[async_trait] -pub trait FileFormat: Send + Sync + fmt::Debug { +pub trait FileFormat: Send + Sync + Debug { /// Returns the table provider as [`Any`](std::any::Any) so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; + /// Returns the extension for this FileFormat, e.g. "file.csv" -> csv + fn get_ext(&self) -> String; + + /// Returns the extension for this FileFormat when compressed, e.g. "file.csv.gz" -> csv + fn get_ext_with_compression( + &self, + _file_compression_type: &FileCompressionType, + ) -> Result; + /// Infer the common schema of the provided objects. The objects will usually /// be analysed up to a given number of records or files (as specified in the /// format config) then give the estimated common schema. This might fail if @@ -100,13 +135,249 @@ pub trait FileFormat: Send + Sync + fmt::Debug { _input: Arc, _state: &SessionState, _conf: FileSinkConfig, - _order_requirements: Option>, + _order_requirements: Option, ) -> Result> { not_impl_err!("Writer not implemented for this format") } - /// Returns the FileType corresponding to this FileFormat - fn file_type(&self) -> FileType; + /// Check if the specified file format has support for pushing down the provided filters within + /// the given schemas. Added initially to support the Parquet file format's ability to do this. + fn supports_filters_pushdown( + &self, + _file_schema: &Schema, + _table_schema: &Schema, + _filters: &[&Expr], + ) -> Result { + Ok(FilePushdownSupport::NoSupport) + } +} + +/// An enum to distinguish between different states when determining if certain filters can be +/// pushed down to file scanning +#[derive(Debug, PartialEq)] +pub enum FilePushdownSupport { + /// The file format/system being asked does not support any sort of pushdown. This should be + /// used even if the file format theoretically supports some sort of pushdown, but it's not + /// enabled or implemented yet. + NoSupport, + /// The file format/system being asked *does* support pushdown, but it can't make it work for + /// the provided filter/expression + NotSupportedForFilter, + /// The file format/system being asked *does* support pushdown and *can* make it work for the + /// provided filter/expression + Supported, +} + +/// A container of [FileFormatFactory] which also implements [FileType]. +/// This enables converting a dyn FileFormat to a dyn FileType. +/// The former trait is a superset of the latter trait, which includes execution time +/// relevant methods. [FileType] is only used in logical planning and only implements +/// the subset of methods required during logical planning. +#[derive(Debug)] +pub struct DefaultFileType { + file_format_factory: Arc, +} + +impl DefaultFileType { + /// Constructs a [DefaultFileType] wrapper from a [FileFormatFactory] + pub fn new(file_format_factory: Arc) -> Self { + Self { + file_format_factory, + } + } + + /// get a reference to the inner [FileFormatFactory] struct + pub fn as_format_factory(&self) -> &Arc { + &self.file_format_factory + } +} + +impl FileType for DefaultFileType { + fn as_any(&self) -> &dyn Any { + self + } +} + +impl Display for DefaultFileType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self.file_format_factory) + } +} + +impl GetExt for DefaultFileType { + fn get_ext(&self) -> String { + self.file_format_factory.get_ext() + } +} + +/// Converts a [FileFormatFactory] to a [FileType] +pub fn format_as_file_type( + file_format_factory: Arc, +) -> Arc { + Arc::new(DefaultFileType { + file_format_factory, + }) +} + +/// Converts a [FileType] to a [FileFormatFactory]. +/// Returns an error if the [FileType] cannot be +/// downcasted to a [DefaultFileType]. +pub fn file_type_to_format( + file_type: &Arc, +) -> Result> { + match file_type + .as_ref() + .as_any() + .downcast_ref::() + { + Some(source) => Ok(source.file_format_factory.clone()), + _ => internal_err!("FileType was not DefaultFileType"), + } +} + +/// Create a new field with the specified data type, copying the other +/// properties from the input field +fn field_with_new_type(field: &FieldRef, new_type: DataType) -> FieldRef { + Arc::new(field.as_ref().clone().with_data_type(new_type)) +} + +/// Transform a schema to use view types for Utf8 and Binary +/// +/// See [parquet::ParquetFormat::force_view_types] for details +pub fn transform_schema_to_view(schema: &Schema) -> Schema { + let transformed_fields: Vec> = schema + .fields + .iter() + .map(|field| match field.data_type() { + DataType::Utf8 | DataType::LargeUtf8 => { + field_with_new_type(field, DataType::Utf8View) + } + DataType::Binary | DataType::LargeBinary => { + field_with_new_type(field, DataType::BinaryView) + } + _ => field.clone(), + }) + .collect(); + Schema::new_with_metadata(transformed_fields, schema.metadata.clone()) +} + +/// Coerces the file schema if the table schema uses a view type. +pub(crate) fn coerce_file_schema_to_view_type( + table_schema: &Schema, + file_schema: &Schema, +) -> Option { + let mut transform = false; + let table_fields: HashMap<_, _> = table_schema + .fields + .iter() + .map(|f| { + let dt = f.data_type(); + if dt.equals_datatype(&DataType::Utf8View) + || dt.equals_datatype(&DataType::BinaryView) + { + transform = true; + } + (f.name(), dt) + }) + .collect(); + + if !transform { + return None; + } + + let transformed_fields: Vec> = file_schema + .fields + .iter() + .map( + |field| match (table_fields.get(field.name()), field.data_type()) { + (Some(DataType::Utf8View), DataType::Utf8 | DataType::LargeUtf8) => { + field_with_new_type(field, DataType::Utf8View) + } + ( + Some(DataType::BinaryView), + DataType::Binary | DataType::LargeBinary, + ) => field_with_new_type(field, DataType::BinaryView), + _ => field.clone(), + }, + ) + .collect(); + + Some(Schema::new_with_metadata( + transformed_fields, + file_schema.metadata.clone(), + )) +} + +/// Transform a schema so that any binary types are strings +pub fn transform_binary_to_string(schema: &Schema) -> Schema { + let transformed_fields: Vec> = schema + .fields + .iter() + .map(|field| match field.data_type() { + DataType::Binary => field_with_new_type(field, DataType::Utf8), + DataType::LargeBinary => field_with_new_type(field, DataType::LargeUtf8), + DataType::BinaryView => field_with_new_type(field, DataType::Utf8View), + _ => field.clone(), + }) + .collect(); + Schema::new_with_metadata(transformed_fields, schema.metadata.clone()) +} + +/// If the table schema uses a string type, coerce the file schema to use a string type. +/// +/// See [parquet::ParquetFormat::binary_as_string] for details +pub(crate) fn coerce_file_schema_to_string_type( + table_schema: &Schema, + file_schema: &Schema, +) -> Option { + let mut transform = false; + let table_fields: HashMap<_, _> = table_schema + .fields + .iter() + .map(|f| (f.name(), f.data_type())) + .collect(); + let transformed_fields: Vec> = file_schema + .fields + .iter() + .map( + |field| match (table_fields.get(field.name()), field.data_type()) { + // table schema uses string type, coerce the file schema to use string type + ( + Some(DataType::Utf8), + DataType::Binary | DataType::LargeBinary | DataType::BinaryView, + ) => { + transform = true; + field_with_new_type(field, DataType::Utf8) + } + // table schema uses large string type, coerce the file schema to use large string type + ( + Some(DataType::LargeUtf8), + DataType::Binary | DataType::LargeBinary | DataType::BinaryView, + ) => { + transform = true; + field_with_new_type(field, DataType::LargeUtf8) + } + // table schema uses string view type, coerce the file schema to use view type + ( + Some(DataType::Utf8View), + DataType::Binary | DataType::LargeBinary | DataType::BinaryView, + ) => { + transform = true; + field_with_new_type(field, DataType::Utf8View) + } + _ => field.clone(), + }, + ) + .collect(); + + if !transform { + None + } else { + Some(Schema::new_with_metadata( + transformed_fields, + file_schema.metadata.clone(), + )) + } } #[cfg(test)] @@ -124,10 +395,9 @@ pub(crate) mod test_util { use object_store::local::LocalFileSystem; use object_store::path::Path; use object_store::{ - GetOptions, GetResult, GetResultPayload, ListResult, MultipartId, PutOptions, - PutResult, + Attributes, GetOptions, GetResult, GetResultPayload, ListResult, MultipartUpload, + PutMultipartOpts, PutOptions, PutPayload, PutResult, }; - use tokio::io::AsyncWrite; pub async fn scan_format( state: &SessionState, @@ -150,22 +420,18 @@ pub(crate) mod test_util { object_meta: meta, partition_values: vec![], range: None, + statistics: None, extensions: None, }]]; let exec = format .create_physical_plan( state, - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_schema, - file_groups, - statistics, - projection, - limit, - table_partition_cols: vec![], - output_ordering: vec![], - }, + FileScanConfig::new(ObjectStoreUrl::local_filesystem(), file_schema) + .with_file_groups(file_groups) + .with_statistics(statistics) + .with_projection(projection) + .with_limit(limit), None, ) .await?; @@ -181,8 +447,8 @@ pub(crate) mod test_util { iterations_detected: Arc>, } - impl std::fmt::Display for VariableStream { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + impl Display for VariableStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "VariableStream") } } @@ -192,25 +458,17 @@ pub(crate) mod test_util { async fn put_opts( &self, _location: &Path, - _bytes: Bytes, + _payload: PutPayload, _opts: PutOptions, ) -> object_store::Result { unimplemented!() } - async fn put_multipart( - &self, - _location: &Path, - ) -> object_store::Result<(MultipartId, Box)> - { - unimplemented!() - } - - async fn abort_multipart( + async fn put_multipart_opts( &self, _location: &Path, - _multipart_id: &MultipartId, - ) -> object_store::Result<()> { + _opts: PutMultipartOpts, + ) -> object_store::Result> { unimplemented!() } @@ -236,6 +494,7 @@ pub(crate) mod test_util { version: None, }, range: Default::default(), + attributes: Attributes::default(), }) } diff --git a/datafusion/core/src/datasource/file_format/options.rs b/datafusion/core/src/datasource/file_format/options.rs index f5bd72495d66..1e0e28ef88cb 100644 --- a/datafusion/core/src/datasource/file_format/options.rs +++ b/datafusion/core/src/datasource/file_format/options.rs @@ -31,7 +31,6 @@ use crate::datasource::{ }; use crate::error::Result; use crate::execution::context::{SessionConfig, SessionState}; -use crate::logical_expr::Expr; use arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion_common::config::TableOptions; @@ -41,6 +40,7 @@ use datafusion_common::{ }; use async_trait::async_trait; +use datafusion_expr::SortExpr; /// Options that control the reading of CSV files. /// @@ -59,8 +59,20 @@ pub struct CsvReadOptions<'a> { pub delimiter: u8, /// An optional quote character. Defaults to `b'"'`. pub quote: u8, + /// An optional terminator character. Defaults to None (CRLF). + pub terminator: Option, /// An optional escape character. Defaults to None. pub escape: Option, + /// If enabled, lines beginning with this byte are ignored. + pub comment: Option, + /// Specifies whether newlines in (quoted) values are supported. + /// + /// Parsing newlines in quoted values may be affected by execution behaviour such as + /// parallel file scanning. Setting this to `true` ensures that newlines in values are + /// parsed successfully, which may reduce performance. + /// + /// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting. + pub newlines_in_values: bool, /// An optional schema representing the CSV files. If None, CSV reader will try to infer it /// based on data in file. pub schema: Option<&'a Schema>, @@ -74,7 +86,7 @@ pub struct CsvReadOptions<'a> { /// File compression type pub file_compression_type: FileCompressionType, /// Indicates how the file is sorted - pub file_sort_order: Vec>, + pub file_sort_order: Vec>, } impl<'a> Default for CsvReadOptions<'a> { @@ -92,11 +104,14 @@ impl<'a> CsvReadOptions<'a> { schema_infer_max_records: DEFAULT_SCHEMA_INFER_MAX_RECORD, delimiter: b',', quote: b'"', + terminator: None, escape: None, + newlines_in_values: false, file_extension: DEFAULT_CSV_EXTENSION, table_partition_cols: vec![], file_compression_type: FileCompressionType::UNCOMPRESSED, file_sort_order: vec![], + comment: None, } } @@ -106,6 +121,12 @@ impl<'a> CsvReadOptions<'a> { self } + /// Specify comment char to use for CSV read + pub fn comment(mut self, comment: u8) -> Self { + self.comment = Some(comment); + self + } + /// Specify delimiter to use for CSV read pub fn delimiter(mut self, delimiter: u8) -> Self { self.delimiter = delimiter; @@ -118,12 +139,30 @@ impl<'a> CsvReadOptions<'a> { self } + /// Specify terminator to use for CSV read + pub fn terminator(mut self, terminator: Option) -> Self { + self.terminator = terminator; + self + } + /// Specify delimiter to use for CSV read pub fn escape(mut self, escape: u8) -> Self { self.escape = Some(escape); self } + /// Specifies whether newlines in (quoted) values are supported. + /// + /// Parsing newlines in quoted values may be affected by execution behaviour such as + /// parallel file scanning. Setting this to `true` ensures that newlines in values are + /// parsed successfully, which may reduce performance. + /// + /// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting. + pub fn newlines_in_values(mut self, newlines_in_values: bool) -> Self { + self.newlines_in_values = newlines_in_values; + self + } + /// Specify the file extension for CSV file selection pub fn file_extension(mut self, file_extension: &'a str) -> Self { self.file_extension = file_extension; @@ -169,7 +208,7 @@ impl<'a> CsvReadOptions<'a> { } /// Configure if file has known sort order - pub fn file_sort_order(mut self, file_sort_order: Vec>) -> Self { + pub fn file_sort_order(mut self, file_sort_order: Vec>) -> Self { self.file_sort_order = file_sort_order; self } @@ -201,7 +240,7 @@ pub struct ParquetReadOptions<'a> { /// based on data in file. pub schema: Option<&'a Schema>, /// Indicates how the file is sorted - pub file_sort_order: Vec>, + pub file_sort_order: Vec>, } impl<'a> Default for ParquetReadOptions<'a> { @@ -218,6 +257,11 @@ impl<'a> Default for ParquetReadOptions<'a> { } impl<'a> ParquetReadOptions<'a> { + /// Create a new ParquetReadOptions with default values + pub fn new() -> Self { + Default::default() + } + /// Specify parquet_pruning pub fn parquet_pruning(mut self, parquet_pruning: bool) -> Self { self.parquet_pruning = Some(parquet_pruning); @@ -248,7 +292,7 @@ impl<'a> ParquetReadOptions<'a> { } /// Configure if file has known sort order - pub fn file_sort_order(mut self, file_sort_order: Vec>) -> Self { + pub fn file_sort_order(mut self, file_sort_order: Vec>) -> Self { self.file_sort_order = file_sort_order; self } @@ -367,7 +411,7 @@ pub struct NdJsonReadOptions<'a> { /// Flag indicating whether this file may be unbounded (as in a FIFO file). pub infinite: bool, /// Indicates how the file is sorted - pub file_sort_order: Vec>, + pub file_sort_order: Vec>, } impl<'a> Default for NdJsonReadOptions<'a> { @@ -422,7 +466,7 @@ impl<'a> NdJsonReadOptions<'a> { } /// Configure if file has known sort order - pub fn file_sort_order(mut self, file_sort_order: Vec>) -> Self { + pub fn file_sort_order(mut self, file_sort_order: Vec>) -> Self { self.file_sort_order = file_sort_order; self } @@ -477,9 +521,12 @@ impl ReadOptions<'_> for CsvReadOptions<'_> { let file_format = CsvFormat::default() .with_options(table_options.csv) .with_has_header(self.has_header) + .with_comment(self.comment) .with_delimiter(self.delimiter) .with_quote(self.quote) .with_escape(self.escape) + .with_terminator(self.terminator) + .with_newlines_in_values(self.newlines_in_values) .with_schema_infer_max_rec(self.schema_infer_max_records) .with_file_compression_type(self.file_compression_type.to_owned()); diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 66f506f9aa2e..b3f54e0773fd 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -20,42 +20,51 @@ use std::any::Any; use std::fmt; use std::fmt::Debug; +use std::ops::Range; use std::sync::Arc; use super::write::demux::start_demuxer_task; use super::write::{create_writer, SharedBuffer}; -use super::{FileFormat, FileScanConfig}; -use crate::arrow::array::{ - BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array, RecordBatch, +use super::{ + coerce_file_schema_to_string_type, coerce_file_schema_to_view_type, + transform_binary_to_string, transform_schema_to_view, FileFormat, FileFormatFactory, + FilePushdownSupport, FileScanConfig, }; -use crate::arrow::datatypes::{DataType, Fields, Schema, SchemaRef}; +use crate::arrow::array::RecordBatch; +use crate::arrow::datatypes::{Fields, Schema, SchemaRef}; use crate::datasource::file_format::file_compression_type::FileCompressionType; -use crate::datasource::physical_plan::{ - FileGroupDisplay, FileSinkConfig, ParquetExec, SchemaAdapter, -}; +use crate::datasource::physical_plan::{FileGroupDisplay, FileSinkConfig}; use crate::datasource::statistics::{create_max_min_accs, get_col_stats}; use crate::error::Result; use crate::execution::context::SessionState; -use crate::physical_plan::expressions::{MaxAccumulator, MinAccumulator}; use crate::physical_plan::insert::{DataSink, DataSinkExec}; use crate::physical_plan::{ Accumulator, DisplayAs, DisplayFormatType, ExecutionPlan, SendableRecordBatchStream, Statistics, }; -use datafusion_common::config::TableParquetOptions; +use arrow::compute::sum; +use datafusion_common::config::{ConfigField, ConfigFileType, TableParquetOptions}; use datafusion_common::file_options::parquet_writer::ParquetWriterOptions; +use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; use datafusion_common::{ - exec_err, internal_datafusion_err, not_impl_err, DataFusionError, FileType, + internal_datafusion_err, not_impl_err, DataFusionError, GetExt, + DEFAULT_PARQUET_EXTENSION, }; use datafusion_common_runtime::SpawnedTask; +use datafusion_execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; +use datafusion_expr::dml::InsertOp; +use datafusion_expr::Expr; +use datafusion_functions_aggregate::min_max::{MaxAccumulator, MinAccumulator}; +use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_plan::metrics::MetricsSet; use async_trait::async_trait; -use bytes::{BufMut, BytesMut}; +use bytes::Bytes; +use hashbrown::HashMap; +use log::debug; use object_store::buffered::BufWriter; use parquet::arrow::arrow_writer::{ compute_leaves, get_column_writers, ArrowColumnChunk, ArrowColumnWriter, @@ -64,20 +73,25 @@ use parquet::arrow::arrow_writer::{ use parquet::arrow::{ arrow_to_parquet_schema, parquet_to_arrow_schema, AsyncArrowWriter, }; -use parquet::file::footer::{decode_footer, decode_metadata}; -use parquet::file::metadata::ParquetMetaData; +use parquet::file::metadata::{ParquetMetaData, ParquetMetaDataReader, RowGroupMetaData}; use parquet::file::properties::WriterProperties; -use parquet::file::statistics::Statistics as ParquetStatistics; use parquet::file::writer::SerializedFileWriter; use parquet::format::FileMetaData; use tokio::io::{AsyncWrite, AsyncWriteExt}; use tokio::sync::mpsc::{self, Receiver, Sender}; use tokio::task::JoinSet; -use futures::{StreamExt, TryStreamExt}; -use hashbrown::HashMap; +use crate::datasource::physical_plan::parquet::{ + can_expr_be_pushed_down_with_schemas, ParquetExecBuilder, +}; +use datafusion_physical_expr_common::sort_expr::LexRequirement; +use futures::future::BoxFuture; +use futures::{FutureExt, StreamExt, TryStreamExt}; use object_store::path::Path; use object_store::{ObjectMeta, ObjectStore}; +use parquet::arrow::arrow_reader::statistics::StatisticsConverter; +use parquet::arrow::async_reader::MetadataFetch; +use parquet::errors::ParquetError; /// Initial writing buffer size. Note this is just a size hint for efficiency. It /// will grow beyond the set value if needed. @@ -87,6 +101,77 @@ const INITIAL_BUFFER_BYTES: usize = 1048576; /// this size, it is flushed to object store const BUFFER_FLUSH_BYTES: usize = 1024000; +#[derive(Default)] +/// Factory struct used to create [ParquetFormat] +pub struct ParquetFormatFactory { + /// inner options for parquet + pub options: Option, +} + +impl ParquetFormatFactory { + /// Creates an instance of [ParquetFormatFactory] + pub fn new() -> Self { + Self { options: None } + } + + /// Creates an instance of [ParquetFormatFactory] with customized default options + pub fn new_with_options(options: TableParquetOptions) -> Self { + Self { + options: Some(options), + } + } +} + +impl FileFormatFactory for ParquetFormatFactory { + fn create( + &self, + state: &SessionState, + format_options: &std::collections::HashMap, + ) -> Result> { + let parquet_options = match &self.options { + None => { + let mut table_options = state.default_table_options(); + table_options.set_config_format(ConfigFileType::PARQUET); + table_options.alter_with_string_hash_map(format_options)?; + table_options.parquet + } + Some(parquet_options) => { + let mut parquet_options = parquet_options.clone(); + for (k, v) in format_options { + parquet_options.set(k, v)?; + } + parquet_options + } + }; + + Ok(Arc::new( + ParquetFormat::default().with_options(parquet_options), + )) + } + + fn default(&self) -> Arc { + Arc::new(ParquetFormat::default()) + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +impl GetExt for ParquetFormatFactory { + fn get_ext(&self) -> String { + // Removes the dot, i.e. ".parquet" -> "parquet" + DEFAULT_PARQUET_EXTENSION[1..].to_string() + } +} + +impl Debug for ParquetFormatFactory { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ParquetFormatFactory") + .field("ParquetFormatFactory", &self.options) + .finish() + } +} /// The Apache Parquet `FileFormat` implementation #[derive(Debug, Default)] pub struct ParquetFormat { @@ -153,6 +238,45 @@ impl ParquetFormat { pub fn options(&self) -> &TableParquetOptions { &self.options } + + /// Return `true` if should use view types. + /// + /// If this returns true, DataFusion will instruct the parquet reader + /// to read string / binary columns using view `StringView` or `BinaryView` + /// if the table schema specifies those types, regardless of any embedded metadata + /// that may specify an alternate Arrow type. The parquet reader is optimized + /// for reading `StringView` and `BinaryView` and such queries are significantly faster. + /// + /// If this returns false, the parquet reader will read the columns according to the + /// defaults or any embedded Arrow type information. This may result in reading + /// `StringArrays` and then casting to `StringViewArray` which is less efficient. + pub fn force_view_types(&self) -> bool { + self.options.global.schema_force_view_types + } + + /// If true, will use view types. See [`Self::force_view_types`] for details + pub fn with_force_view_types(mut self, use_views: bool) -> Self { + self.options.global.schema_force_view_types = use_views; + self + } + + /// Return `true` if binary types will be read as strings. + /// + /// If this returns true, DataFusion will instruct the parquet reader + /// to read binary columns such as `Binary` or `BinaryView` as the + /// corresponding string type such as `Utf8` or `LargeUtf8`. + /// The parquet reader has special optimizations for `Utf8` and `LargeUtf8` + /// validation, and such queries are significantly faster than reading + /// binary columns and then casting to string columns. + pub fn binary_as_string(&self) -> bool { + self.options.global.binary_as_string + } + + /// If true, will read binary types as strings. See [`Self::binary_as_string`] for details + pub fn with_binary_as_string(mut self, binary_as_string: bool) -> Self { + self.options.global.binary_as_string = binary_as_string; + self + } } /// Clears all metadata (Schema level and field level) on an iterator @@ -188,6 +312,23 @@ impl FileFormat for ParquetFormat { self } + fn get_ext(&self) -> String { + ParquetFormatFactory::new().get_ext() + } + + fn get_ext_with_compression( + &self, + file_compression_type: &FileCompressionType, + ) -> Result { + let ext = self.get_ext(); + match file_compression_type.get_variant() { + CompressionTypeVariant::UNCOMPRESSED => Ok(ext), + _ => Err(DataFusionError::Internal( + "Parquet FileFormat does not support compression.".into(), + )), + } + } + async fn infer_schema( &self, state: &SessionState, @@ -226,6 +367,18 @@ impl FileFormat for ParquetFormat { Schema::try_merge(schemas) }?; + let schema = if self.binary_as_string() { + transform_binary_to_string(&schema) + } else { + schema + }; + + let schema = if self.force_view_types() { + transform_schema_to_view(&schema) + } else { + schema + }; + Ok(Arc::new(schema)) } @@ -252,17 +405,22 @@ impl FileFormat for ParquetFormat { conf: FileScanConfig, filters: Option<&Arc>, ) -> Result> { + let mut builder = + ParquetExecBuilder::new_with_options(conf, self.options.clone()); + // If enable pruning then combine the filters to build the predicate. // If disable pruning then set the predicate to None, thus readers // will not prune data based on the statistics. - let predicate = self.enable_pruning().then(|| filters.cloned()).flatten(); + if self.enable_pruning() { + if let Some(predicate) = filters.cloned() { + builder = builder.with_predicate(predicate); + } + } + if let Some(metadata_size_hint) = self.metadata_size_hint() { + builder = builder.with_metadata_size_hint(metadata_size_hint); + } - Ok(Arc::new(ParquetExec::new( - conf, - predicate, - self.metadata_size_hint(), - self.options.clone(), - ))) + Ok(builder.build_arc()) } async fn create_writer_physical_plan( @@ -270,9 +428,9 @@ impl FileFormat for ParquetFormat { input: Arc, _state: &SessionState, conf: FileSinkConfig, - order_requirements: Option>, + order_requirements: Option, ) -> Result> { - if conf.overwrite { + if conf.insert_op != InsertOp::Append { return not_impl_err!("Overwrites are not implemented yet for Parquet"); } @@ -287,88 +445,52 @@ impl FileFormat for ParquetFormat { )) as _) } - fn file_type(&self) -> FileType { - FileType::PARQUET + fn supports_filters_pushdown( + &self, + file_schema: &Schema, + table_schema: &Schema, + filters: &[&Expr], + ) -> Result { + if !self.options().global.pushdown_filters { + return Ok(FilePushdownSupport::NoSupport); + } + + let all_supported = filters.iter().all(|filter| { + can_expr_be_pushed_down_with_schemas(filter, file_schema, table_schema) + }); + + Ok(if all_supported { + FilePushdownSupport::Supported + } else { + FilePushdownSupport::NotSupportedForFilter + }) } } -fn summarize_min_max( - max_values: &mut [Option], - min_values: &mut [Option], - fields: &Fields, - i: usize, - stat: &ParquetStatistics, -) { - if !stat.has_min_max_set() { - max_values[i] = None; - min_values[i] = None; - return; - } - match stat { - ParquetStatistics::Boolean(s) if DataType::Boolean == *fields[i].data_type() => { - if let Some(max_value) = &mut max_values[i] { - max_value - .update_batch(&[Arc::new(BooleanArray::from(vec![*s.max()]))]) - .unwrap_or_else(|_| max_values[i] = None); - } - if let Some(min_value) = &mut min_values[i] { - min_value - .update_batch(&[Arc::new(BooleanArray::from(vec![*s.min()]))]) - .unwrap_or_else(|_| min_values[i] = None); - } - } - ParquetStatistics::Int32(s) if DataType::Int32 == *fields[i].data_type() => { - if let Some(max_value) = &mut max_values[i] { - max_value - .update_batch(&[Arc::new(Int32Array::from_value(*s.max(), 1))]) - .unwrap_or_else(|_| max_values[i] = None); - } - if let Some(min_value) = &mut min_values[i] { - min_value - .update_batch(&[Arc::new(Int32Array::from_value(*s.min(), 1))]) - .unwrap_or_else(|_| min_values[i] = None); - } - } - ParquetStatistics::Int64(s) if DataType::Int64 == *fields[i].data_type() => { - if let Some(max_value) = &mut max_values[i] { - max_value - .update_batch(&[Arc::new(Int64Array::from_value(*s.max(), 1))]) - .unwrap_or_else(|_| max_values[i] = None); - } - if let Some(min_value) = &mut min_values[i] { - min_value - .update_batch(&[Arc::new(Int64Array::from_value(*s.min(), 1))]) - .unwrap_or_else(|_| min_values[i] = None); - } - } - ParquetStatistics::Float(s) if DataType::Float32 == *fields[i].data_type() => { - if let Some(max_value) = &mut max_values[i] { - max_value - .update_batch(&[Arc::new(Float32Array::from(vec![*s.max()]))]) - .unwrap_or_else(|_| max_values[i] = None); - } - if let Some(min_value) = &mut min_values[i] { - min_value - .update_batch(&[Arc::new(Float32Array::from(vec![*s.min()]))]) - .unwrap_or_else(|_| min_values[i] = None); - } - } - ParquetStatistics::Double(s) if DataType::Float64 == *fields[i].data_type() => { - if let Some(max_value) = &mut max_values[i] { - max_value - .update_batch(&[Arc::new(Float64Array::from(vec![*s.max()]))]) - .unwrap_or_else(|_| max_values[i] = None); - } - if let Some(min_value) = &mut min_values[i] { - min_value - .update_batch(&[Arc::new(Float64Array::from(vec![*s.min()]))]) - .unwrap_or_else(|_| min_values[i] = None); - } - } - _ => { - max_values[i] = None; - min_values[i] = None; +/// [`MetadataFetch`] adapter for reading bytes from an [`ObjectStore`] +struct ObjectStoreFetch<'a> { + store: &'a dyn ObjectStore, + meta: &'a ObjectMeta, +} + +impl<'a> ObjectStoreFetch<'a> { + fn new(store: &'a dyn ObjectStore, meta: &'a ObjectMeta) -> Self { + Self { store, meta } + } +} + +impl<'a> MetadataFetch for ObjectStoreFetch<'a> { + fn fetch( + &mut self, + range: Range, + ) -> BoxFuture<'_, Result> { + async { + self.store + .get_range(&self.meta.location, range) + .await + .map_err(ParquetError::from) } + .boxed() } } @@ -383,57 +505,14 @@ pub async fn fetch_parquet_metadata( meta: &ObjectMeta, size_hint: Option, ) -> Result { - if meta.size < 8 { - return exec_err!("file size of {} is less than footer", meta.size); - } - - // If a size hint is provided, read more than the minimum size - // to try and avoid a second fetch. - let footer_start = if let Some(size_hint) = size_hint { - meta.size.saturating_sub(size_hint) - } else { - meta.size - 8 - }; - - let suffix = store - .get_range(&meta.location, footer_start..meta.size) - .await?; - - let suffix_len = suffix.len(); - - let mut footer = [0; 8]; - footer.copy_from_slice(&suffix[suffix_len - 8..suffix_len]); - - let length = decode_footer(&footer)?; - - if meta.size < length + 8 { - return exec_err!( - "file size of {} is less than footer + metadata {}", - meta.size, - length + 8 - ); - } - - // Did not fetch the entire file metadata in the initial read, need to make a second request - if length > suffix_len - 8 { - let metadata_start = meta.size - length - 8; - let remaining_metadata = store - .get_range(&meta.location, metadata_start..footer_start) - .await?; - - let mut metadata = BytesMut::with_capacity(length); - - metadata.put(remaining_metadata.as_ref()); - metadata.put(&suffix[..suffix_len - 8]); - - Ok(decode_metadata(metadata.as_ref())?) - } else { - let metadata_start = meta.size - length - 8; - - Ok(decode_metadata( - &suffix[metadata_start - footer_start..suffix_len - 8], - )?) - } + let file_size = meta.size; + let fetch = ObjectStoreFetch::new(store, meta); + + ParquetMetaDataReader::new() + .with_prefetch_hint(size_hint) + .load_and_finish(fetch, file_size) + .await + .map_err(DataFusionError::from) } /// Read and parse the schema of the Parquet file at location `path` @@ -452,6 +531,8 @@ async fn fetch_schema( } /// Read and parse the statistics of the Parquet file at location `path` +/// +/// See [`statistics_from_parquet_meta`] for more details async fn fetch_statistics( store: &dyn ObjectStore, table_schema: SchemaRef, @@ -459,84 +540,145 @@ async fn fetch_statistics( metadata_size_hint: Option, ) -> Result { let metadata = fetch_parquet_metadata(store, file, metadata_size_hint).await?; - let file_metadata = metadata.file_metadata(); - - let file_schema = parquet_to_arrow_schema( - file_metadata.schema_descr(), - file_metadata.key_value_metadata(), - )?; + statistics_from_parquet_meta_calc(&metadata, table_schema) +} - let num_fields = table_schema.fields().len(); - let fields = table_schema.fields(); +/// Convert statistics in [`ParquetMetaData`] into [`Statistics`] using ['StatisticsConverter`] +/// +/// The statistics are calculated for each column in the table schema +/// using the row group statistics in the parquet metadata. +pub fn statistics_from_parquet_meta_calc( + metadata: &ParquetMetaData, + table_schema: SchemaRef, +) -> Result { + let row_groups_metadata = metadata.row_groups(); - let mut num_rows = 0; - let mut total_byte_size = 0; - let mut null_counts = vec![Precision::Exact(0); num_fields]; + let mut statistics = Statistics::new_unknown(&table_schema); let mut has_statistics = false; + let mut num_rows = 0_usize; + let mut total_byte_size = 0_usize; + for row_group_meta in row_groups_metadata { + num_rows += row_group_meta.num_rows() as usize; + total_byte_size += row_group_meta.total_byte_size() as usize; + + if !has_statistics { + row_group_meta.columns().iter().for_each(|column| { + has_statistics = column.statistics().is_some(); + }); + } + } + statistics.num_rows = Precision::Exact(num_rows); + statistics.total_byte_size = Precision::Exact(total_byte_size); - let schema_adapter = SchemaAdapter::new(table_schema.clone()); - - let (mut max_values, mut min_values) = create_max_min_accs(&table_schema); - - for row_group_meta in metadata.row_groups() { - num_rows += row_group_meta.num_rows(); - total_byte_size += row_group_meta.total_byte_size(); + let file_metadata = metadata.file_metadata(); + let mut file_schema = parquet_to_arrow_schema( + file_metadata.schema_descr(), + file_metadata.key_value_metadata(), + )?; + if let Some(merged) = coerce_file_schema_to_string_type(&table_schema, &file_schema) { + file_schema = merged; + } - let mut column_stats: HashMap = HashMap::new(); + if let Some(merged) = coerce_file_schema_to_view_type(&table_schema, &file_schema) { + file_schema = merged; + } - for (i, column) in row_group_meta.columns().iter().enumerate() { - if let Some(stat) = column.statistics() { - has_statistics = true; - column_stats.insert(i, (stat.null_count(), stat)); - } - } + statistics.column_statistics = if has_statistics { + let (mut max_accs, mut min_accs) = create_max_min_accs(&table_schema); + let mut null_counts_array = + vec![Precision::Exact(0); table_schema.fields().len()]; - if has_statistics { - for (table_idx, null_cnt) in null_counts.iter_mut().enumerate() { - if let Some(file_idx) = - schema_adapter.map_column_index(table_idx, &file_schema) - { - if let Some((null_count, stats)) = column_stats.get(&file_idx) { - *null_cnt = null_cnt.add(&Precision::Exact(*null_count as usize)); - summarize_min_max( - &mut max_values, - &mut min_values, - fields, - table_idx, - stats, + table_schema + .fields() + .iter() + .enumerate() + .for_each(|(idx, field)| { + match StatisticsConverter::try_new( + field.name(), + &file_schema, + file_metadata.schema_descr(), + ) { + Ok(stats_converter) => { + summarize_min_max_null_counts( + &mut min_accs, + &mut max_accs, + &mut null_counts_array, + idx, + num_rows, + &stats_converter, + row_groups_metadata, ) - } else { - // If none statistics of current column exists, set the Max/Min Accumulator to None. - max_values[table_idx] = None; - min_values[table_idx] = None; + .ok(); + } + Err(e) => { + debug!("Failed to create statistics converter: {}", e); + null_counts_array[idx] = Precision::Exact(num_rows); } - } else { - *null_cnt = null_cnt.add(&Precision::Exact(num_rows as usize)); } - } - } - } + }); - let column_stats = if has_statistics { - get_col_stats(&table_schema, null_counts, &mut max_values, &mut min_values) + get_col_stats( + &table_schema, + null_counts_array, + &mut max_accs, + &mut min_accs, + ) } else { Statistics::unknown_column(&table_schema) }; - let statistics = Statistics { - num_rows: Precision::Exact(num_rows as usize), - total_byte_size: Precision::Exact(total_byte_size as usize), - column_statistics: column_stats, - }; - Ok(statistics) } +/// Deprecated +/// Use [`statistics_from_parquet_meta_calc`] instead. +/// This method was deprecated because it didn't need to be async so a new method was created +/// that exposes a synchronous API. +#[deprecated( + since = "40.0.0", + note = "please use `statistics_from_parquet_meta_calc` instead" +)] +pub async fn statistics_from_parquet_meta( + metadata: &ParquetMetaData, + table_schema: SchemaRef, +) -> Result { + statistics_from_parquet_meta_calc(metadata, table_schema) +} + +fn summarize_min_max_null_counts( + min_accs: &mut [Option], + max_accs: &mut [Option], + null_counts_array: &mut [Precision], + arrow_schema_index: usize, + num_rows: usize, + stats_converter: &StatisticsConverter, + row_groups_metadata: &[RowGroupMetaData], +) -> Result<()> { + let max_values = stats_converter.row_group_maxes(row_groups_metadata)?; + let min_values = stats_converter.row_group_mins(row_groups_metadata)?; + let null_counts = stats_converter.row_group_null_counts(row_groups_metadata)?; + + if let Some(max_acc) = &mut max_accs[arrow_schema_index] { + max_acc.update_batch(&[max_values])?; + } + + if let Some(min_acc) = &mut min_accs[arrow_schema_index] { + min_acc.update_batch(&[min_values])?; + } + + null_counts_array[arrow_schema_index] = Precision::Exact(match sum(&null_counts) { + Some(null_count) => null_count as usize, + None => num_rows, + }); + + Ok(()) +} + /// Implements [`DataSink`] for writing to a parquet file. pub struct ParquetSink { /// Config options for writing data config: FileSinkConfig, - /// + /// Underlying parquet options parquet_options: TableParquetOptions, /// File metadata from successfully produced parquet files. The Mutex is only used /// to allow inserting to HashMap from behind borrowed reference in DataSink::write_all. @@ -586,7 +728,9 @@ impl ParquetSink { /// of hive style partitioning where some columns are removed from the /// underlying files. fn get_writer_schema(&self) -> Arc { - if !self.config.table_partition_cols.is_empty() { + if !self.config.table_partition_cols.is_empty() + && !self.config.keep_partition_by_columns + { let schema = self.config.output_schema(); let partition_names: Vec<_> = self .config @@ -594,13 +738,14 @@ impl ParquetSink { .iter() .map(|(s, _)| s) .collect(); - Arc::new(Schema::new( + Arc::new(Schema::new_with_metadata( schema .fields() .iter() .filter(|f| !partition_names.contains(&f.name())) .map(|f| (**f).clone()) .collect::>(), + schema.metadata().clone(), )) } else { self.config.output_schema().clone() @@ -676,6 +821,7 @@ impl DataSink for ParquetSink { part_col, self.config.table_paths[0].clone(), "parquet".into(), + self.config.keep_partition_by_columns, ); let mut file_write_tasks: JoinSet< @@ -691,9 +837,13 @@ impl DataSink for ParquetSink { parquet_props.writer_options().clone(), ) .await?; + let mut reservation = + MemoryConsumer::new(format!("ParquetSink[{}]", path)) + .register(context.memory_pool()); file_write_tasks.spawn(async move { while let Some(batch) = rx.recv().await { writer.write(&batch).await?; + reservation.try_resize(writer.memory_size())?; } let file_metadata = writer .close() @@ -713,6 +863,7 @@ impl DataSink for ParquetSink { let schema = self.get_writer_schema(); let props = parquet_props.clone(); let parallel_options_clone = parallel_options.clone(); + let pool = Arc::clone(context.memory_pool()); file_write_tasks.spawn(async move { let file_metadata = output_single_parquet_file_parallelized( writer, @@ -720,6 +871,7 @@ impl DataSink for ParquetSink { schema, props.writer_options(), parallel_options_clone, + pool, ) .await?; Ok((path, file_metadata)) @@ -749,7 +901,10 @@ impl DataSink for ParquetSink { } } - demux_task.join_unwind().await?; + demux_task + .join_unwind() + .await + .map_err(DataFusionError::ExecutionJoin)??; Ok(row_count as u64) } @@ -760,14 +915,16 @@ impl DataSink for ParquetSink { async fn column_serializer_task( mut rx: Receiver, mut writer: ArrowColumnWriter, -) -> Result { + mut reservation: MemoryReservation, +) -> Result<(ArrowColumnWriter, MemoryReservation)> { while let Some(col) = rx.recv().await { writer.write(&col)?; + reservation.try_resize(writer.memory_size())?; } - Ok(writer) + Ok((writer, reservation)) } -type ColumnWriterTask = SpawnedTask>; +type ColumnWriterTask = SpawnedTask>; type ColSender = Sender; /// Spawns a parallel serialization task for each column @@ -777,6 +934,7 @@ fn spawn_column_parallel_row_group_writer( schema: Arc, parquet_props: Arc, max_buffer_size: usize, + pool: &Arc, ) -> Result<(Vec, Vec)> { let schema_desc = arrow_to_parquet_schema(&schema)?; let col_writers = get_column_writers(&schema_desc, &parquet_props, &schema)?; @@ -786,11 +944,17 @@ fn spawn_column_parallel_row_group_writer( let mut col_array_channels = Vec::with_capacity(num_columns); for writer in col_writers.into_iter() { // Buffer size of this channel limits the number of arrays queued up for column level serialization - let (send_array, recieve_array) = + let (send_array, receive_array) = mpsc::channel::(max_buffer_size); col_array_channels.push(send_array); - let task = SpawnedTask::spawn(column_serializer_task(recieve_array, writer)); + let reservation = + MemoryConsumer::new("ParquetSink(ArrowColumnWriter)").register(pool); + let task = SpawnedTask::spawn(column_serializer_task( + receive_array, + writer, + reservation, + )); col_writer_tasks.push(task); } @@ -806,7 +970,7 @@ struct ParallelParquetWriterOptions { /// This is the return type of calling [ArrowColumnWriter].close() on each column /// i.e. the Vec of encoded columns which can be appended to a row group -type RBStreamSerializeResult = Result<(Vec, usize)>; +type RBStreamSerializeResult = Result<(Vec, MemoryReservation, usize)>; /// Sends the ArrowArrays in passed [RecordBatch] through the channels to their respective /// parallel column serializers. @@ -819,12 +983,12 @@ async fn send_arrays_to_col_writers( let mut next_channel = 0; for (array, field) in rb.columns().iter().zip(schema.fields()) { for c in compute_leaves(field, array)? { - col_array_channels[next_channel] - .send(c) - .await - .map_err(|_| { - DataFusionError::Internal("Unable to send array to writer!".into()) - })?; + // Do not surface error from closed channel (means something + // else hit an error, and the plan is shutting down). + if col_array_channels[next_channel].send(c).await.is_err() { + return Ok(()); + } + next_channel += 1; } } @@ -837,16 +1001,25 @@ async fn send_arrays_to_col_writers( fn spawn_rg_join_and_finalize_task( column_writer_tasks: Vec, rg_rows: usize, + pool: &Arc, ) -> SpawnedTask { + let mut rg_reservation = + MemoryConsumer::new("ParquetSink(SerializedRowGroupWriter)").register(pool); + SpawnedTask::spawn(async move { let num_cols = column_writer_tasks.len(); let mut finalized_rg = Vec::with_capacity(num_cols); for task in column_writer_tasks.into_iter() { - let writer = task.join_unwind().await?; + let (writer, _col_reservation) = task + .join_unwind() + .await + .map_err(DataFusionError::ExecutionJoin)??; + let encoded_size = writer.get_estimated_total_bytes(); + rg_reservation.grow(encoded_size); finalized_rg.push(writer.close()?); } - Ok((finalized_rg, rg_rows)) + Ok((finalized_rg, rg_reservation, rg_rows)) }) } @@ -856,7 +1029,7 @@ fn spawn_rg_join_and_finalize_task( /// row group is reached, the parallel tasks are joined on another separate task /// and sent to a concatenation task. This task immediately continues to work /// on the next row group in parallel. So, parquet serialization is parallelized -/// accross both columns and row_groups, with a theoretical max number of parallel tasks +/// across both columns and row_groups, with a theoretical max number of parallel tasks /// given by n_columns * num_row_groups. fn spawn_parquet_parallel_serialization_task( mut data: Receiver, @@ -864,6 +1037,7 @@ fn spawn_parquet_parallel_serialization_task( schema: Arc, writer_props: Arc, parallel_options: ParallelParquetWriterOptions, + pool: Arc, ) -> SpawnedTask> { SpawnedTask::spawn(async move { let max_buffer_rb = parallel_options.max_buffered_record_batches_per_stream; @@ -873,6 +1047,7 @@ fn spawn_parquet_parallel_serialization_task( schema.clone(), writer_props.clone(), max_buffer_rb, + &pool, )?; let mut current_rg_rows = 0; @@ -899,13 +1074,14 @@ fn spawn_parquet_parallel_serialization_task( let finalize_rg_task = spawn_rg_join_and_finalize_task( column_writer_handles, max_row_group_rows, + &pool, ); - serialize_tx.send(finalize_rg_task).await.map_err(|_| { - DataFusionError::Internal( - "Unable to send closed RG to concat task!".into(), - ) - })?; + // Do not surface error from closed channel (means something + // else hit an error, and the plan is shutting down). + if serialize_tx.send(finalize_rg_task).await.is_err() { + return Ok(()); + } current_rg_rows = 0; rb = rb.slice(rows_left, rb.num_rows() - rows_left); @@ -915,6 +1091,7 @@ fn spawn_parquet_parallel_serialization_task( schema.clone(), writer_props.clone(), max_buffer_rb, + &pool, )?; } } @@ -923,14 +1100,17 @@ fn spawn_parquet_parallel_serialization_task( drop(col_array_channels); // Handle leftover rows as final rowgroup, which may be smaller than max_row_group_rows if current_rg_rows > 0 { - let finalize_rg_task = - spawn_rg_join_and_finalize_task(column_writer_handles, current_rg_rows); + let finalize_rg_task = spawn_rg_join_and_finalize_task( + column_writer_handles, + current_rg_rows, + &pool, + ); - serialize_tx.send(finalize_rg_task).await.map_err(|_| { - DataFusionError::Internal( - "Unable to send closed RG to concat task!".into(), - ) - })?; + // Do not surface error from closed channel (means something + // else hit an error, and the plan is shutting down). + if serialize_tx.send(finalize_rg_task).await.is_err() { + return Ok(()); + } } Ok(()) @@ -944,9 +1124,13 @@ async fn concatenate_parallel_row_groups( schema: Arc, writer_props: Arc, mut object_store_writer: Box, + pool: Arc, ) -> Result { let merged_buff = SharedBuffer::new(INITIAL_BUFFER_BYTES); + let mut file_reservation = + MemoryConsumer::new("ParquetSink(SerializedFileWriter)").register(&pool); + let schema_desc = arrow_to_parquet_schema(schema.as_ref())?; let mut parquet_writer = SerializedFileWriter::new( merged_buff.clone(), @@ -957,15 +1141,21 @@ async fn concatenate_parallel_row_groups( while let Some(task) = serialize_rx.recv().await { let result = task.join_unwind().await; let mut rg_out = parquet_writer.next_row_group()?; - let (serialized_columns, _cnt) = result?; + let (serialized_columns, mut rg_reservation, _cnt) = + result.map_err(DataFusionError::ExecutionJoin)??; for chunk in serialized_columns { chunk.append_to_row_group(&mut rg_out)?; + rg_reservation.free(); + let mut buff_to_flush = merged_buff.buffer.try_lock().unwrap(); + file_reservation.try_resize(buff_to_flush.len())?; + if buff_to_flush.len() > BUFFER_FLUSH_BYTES { object_store_writer .write_all(buff_to_flush.as_slice()) .await?; buff_to_flush.clear(); + file_reservation.try_resize(buff_to_flush.len())?; // will set to zero } } rg_out.close()?; @@ -976,6 +1166,7 @@ async fn concatenate_parallel_row_groups( object_store_writer.write_all(final_buff.as_slice()).await?; object_store_writer.shutdown().await?; + file_reservation.free(); Ok(file_metadata) } @@ -990,6 +1181,7 @@ async fn output_single_parquet_file_parallelized( output_schema: Arc, parquet_props: &WriterProperties, parallel_options: ParallelParquetWriterOptions, + pool: Arc, ) -> Result { let max_rowgroups = parallel_options.max_parallel_row_groups; // Buffer size of this channel limits maximum number of RowGroups being worked on in parallel @@ -1003,16 +1195,21 @@ async fn output_single_parquet_file_parallelized( output_schema.clone(), arc_props.clone(), parallel_options, + Arc::clone(&pool), ); let file_metadata = concatenate_parallel_row_groups( serialize_rx, output_schema.clone(), arc_props.clone(), object_store_writer, + pool, ) .await?; - launch_serialization_task.join_unwind().await?; + launch_serialization_task + .join_unwind() + .await + .map_err(DataFusionError::ExecutionJoin)??; Ok(file_metadata) } @@ -1100,8 +1297,10 @@ mod tests { use super::super::test_util::scan_format; use crate::datasource::listing::{ListingTableUrl, PartitionedFile}; use crate::physical_plan::collect; + use crate::test_util::bounded_stream; use std::fmt::{Display, Formatter}; use std::sync::atomic::{AtomicUsize, Ordering}; + use std::time::Duration; use super::*; @@ -1109,15 +1308,17 @@ mod tests { use crate::physical_plan::metrics::MetricValue; use crate::prelude::{SessionConfig, SessionContext}; use arrow::array::{Array, ArrayRef, StringArray}; - use arrow_schema::Field; + use arrow_array::types::Int32Type; + use arrow_array::{DictionaryArray, Int32Array, Int64Array}; + use arrow_schema::{DataType, Field}; use async_trait::async_trait; - use bytes::Bytes; use datafusion_common::cast::{ - as_binary_array, as_boolean_array, as_float32_array, as_float64_array, - as_int32_array, as_timestamp_nanosecond_array, + as_binary_array, as_binary_view_array, as_boolean_array, as_float32_array, + as_float64_array, as_int32_array, as_timestamp_nanosecond_array, }; use datafusion_common::config::ParquetOptions; use datafusion_common::ScalarValue; + use datafusion_common::ScalarValue::Utf8; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; @@ -1125,16 +1326,21 @@ mod tests { use log::error; use object_store::local::LocalFileSystem; use object_store::{ - GetOptions, GetResult, ListResult, MultipartId, PutOptions, PutResult, + GetOptions, GetResult, ListResult, MultipartUpload, PutMultipartOpts, PutOptions, + PutPayload, PutResult, }; use parquet::arrow::arrow_reader::ArrowReaderOptions; use parquet::arrow::ParquetRecordBatchStreamBuilder; - use parquet::file::metadata::{ParquetColumnIndex, ParquetOffsetIndex}; + use parquet::file::metadata::{KeyValue, ParquetColumnIndex, ParquetOffsetIndex}; use parquet::file::page_index::index::Index; use tokio::fs::File; - #[tokio::test] - async fn read_merged_batches() -> Result<()> { + enum ForceViews { + Yes, + No, + } + + async fn _run_read_merged_batches(force_views: ForceViews) -> Result<()> { let c1: ArrayRef = Arc::new(StringArray::from(vec![Some("Foo"), None, Some("bar")])); @@ -1148,7 +1354,11 @@ mod tests { let session = SessionContext::new(); let ctx = session.state(); - let format = ParquetFormat::default(); + let force_views = match force_views { + ForceViews::Yes => true, + ForceViews::No => false, + }; + let format = ParquetFormat::default().with_force_view_types(force_views); let schema = format.infer_schema(&ctx, &store, &meta).await.unwrap(); let stats = @@ -1178,6 +1388,14 @@ mod tests { Ok(()) } + #[tokio::test] + async fn read_merged_batches() -> Result<()> { + _run_read_merged_batches(ForceViews::No).await?; + _run_read_merged_batches(ForceViews::Yes).await?; + + Ok(()) + } + #[tokio::test] async fn is_schema_stable() -> Result<()> { let c1: ArrayRef = @@ -1205,7 +1423,7 @@ mod tests { .map(|i| i.to_string()) .collect(); let coll: Vec<_> = schema - .all_fields() + .flattened_fields() .into_iter() .map(|i| i.name().to_string()) .collect(); @@ -1221,7 +1439,7 @@ mod tests { } impl Display for RequestCountingObjectStore { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "RequestCounting({})", self.inner) } } @@ -1248,25 +1466,17 @@ mod tests { async fn put_opts( &self, _location: &Path, - _bytes: Bytes, + _payload: PutPayload, _opts: PutOptions, ) -> object_store::Result { Err(object_store::Error::NotImplemented) } - async fn put_multipart( + async fn put_multipart_opts( &self, _location: &Path, - ) -> object_store::Result<(MultipartId, Box)> - { - Err(object_store::Error::NotImplemented) - } - - async fn abort_multipart( - &self, - _location: &Path, - _multipart_id: &MultipartId, - ) -> object_store::Result<()> { + _opts: PutMultipartOpts, + ) -> object_store::Result> { Err(object_store::Error::NotImplemented) } @@ -1316,8 +1526,7 @@ mod tests { } } - #[tokio::test] - async fn fetch_metadata_with_size_hint() -> Result<()> { + async fn _run_fetch_metadata_with_size_hint(force_views: ForceViews) -> Result<()> { let c1: ArrayRef = Arc::new(StringArray::from(vec![Some("Foo"), None, Some("bar")])); @@ -1341,7 +1550,13 @@ mod tests { let session = SessionContext::new(); let ctx = session.state(); - let format = ParquetFormat::default().with_metadata_size_hint(Some(9)); + let force_views = match force_views { + ForceViews::Yes => true, + ForceViews::No => false, + }; + let format = ParquetFormat::default() + .with_metadata_size_hint(Some(9)) + .with_force_view_types(force_views); let schema = format .infer_schema(&ctx, &store.upcast(), &meta) .await @@ -1371,7 +1586,9 @@ mod tests { // ensure the requests were coalesced into a single request assert_eq!(store.request_count(), 1); - let format = ParquetFormat::default().with_metadata_size_hint(Some(size_hint)); + let format = ParquetFormat::default() + .with_metadata_size_hint(Some(size_hint)) + .with_force_view_types(force_views); let schema = format .infer_schema(&ctx, &store.upcast(), &meta) .await @@ -1406,6 +1623,146 @@ mod tests { Ok(()) } + #[tokio::test] + async fn fetch_metadata_with_size_hint() -> Result<()> { + _run_fetch_metadata_with_size_hint(ForceViews::No).await?; + _run_fetch_metadata_with_size_hint(ForceViews::Yes).await?; + + Ok(()) + } + + #[tokio::test] + async fn test_statistics_from_parquet_metadata_dictionary() -> Result<()> { + // Data for column c_dic: ["a", "b", "c", "d"] + let values = StringArray::from_iter_values(["a", "b", "c", "d"]); + let keys = Int32Array::from_iter_values([0, 1, 2, 3]); + let dic_array = + DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); + let c_dic: ArrayRef = Arc::new(dic_array); + + let batch1 = RecordBatch::try_from_iter(vec![("c_dic", c_dic)]).unwrap(); + + // Use store_parquet to write each batch to its own file + // . batch1 written into first file and includes: + // - column c_dic that has 4 rows with no null. Stats min and max of dictionary column is available. + let store = Arc::new(LocalFileSystem::new()) as _; + let (files, _file_names) = store_parquet(vec![batch1], false).await?; + + let state = SessionContext::new().state(); + let format = ParquetFormat::default(); + let schema = format.infer_schema(&state, &store, &files).await.unwrap(); + + // Fetch statistics for first file + let pq_meta = fetch_parquet_metadata(store.as_ref(), &files[0], None).await?; + let stats = statistics_from_parquet_meta_calc(&pq_meta, schema.clone())?; + assert_eq!(stats.num_rows, Precision::Exact(4)); + + // column c_dic + let c_dic_stats = &stats.column_statistics[0]; + + assert_eq!(c_dic_stats.null_count, Precision::Exact(0)); + assert_eq!( + c_dic_stats.max_value, + Precision::Exact(Utf8(Some("d".into()))) + ); + assert_eq!( + c_dic_stats.min_value, + Precision::Exact(Utf8(Some("a".into()))) + ); + + Ok(()) + } + + async fn _run_test_statistics_from_parquet_metadata( + force_views: ForceViews, + ) -> Result<()> { + // Data for column c1: ["Foo", null, "bar"] + let c1: ArrayRef = + Arc::new(StringArray::from(vec![Some("Foo"), None, Some("bar")])); + let batch1 = RecordBatch::try_from_iter(vec![("c1", c1.clone())]).unwrap(); + + // Data for column c2: [1, 2, null] + let c2: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), Some(2), None])); + let batch2 = RecordBatch::try_from_iter(vec![("c2", c2)]).unwrap(); + + // Use store_parquet to write each batch to its own file + // . batch1 written into first file and includes: + // - column c1 that has 3 rows with one null. Stats min and max of string column is missing for this test even the column has values + // . batch2 written into second file and includes: + // - column c2 that has 3 rows with one null. Stats min and max of int are available and 1 and 2 respectively + let store = Arc::new(LocalFileSystem::new()) as _; + let (files, _file_names) = store_parquet(vec![batch1, batch2], false).await?; + + let force_views = match force_views { + ForceViews::Yes => true, + ForceViews::No => false, + }; + + let mut state = SessionContext::new().state(); + state = set_view_state(state, force_views); + let format = ParquetFormat::default().with_force_view_types(force_views); + let schema = format.infer_schema(&state, &store, &files).await.unwrap(); + + let null_i64 = ScalarValue::Int64(None); + let null_utf8 = if force_views { + ScalarValue::Utf8View(None) + } else { + Utf8(None) + }; + + // Fetch statistics for first file + let pq_meta = fetch_parquet_metadata(store.as_ref(), &files[0], None).await?; + let stats = statistics_from_parquet_meta_calc(&pq_meta, schema.clone())?; + assert_eq!(stats.num_rows, Precision::Exact(3)); + // column c1 + let c1_stats = &stats.column_statistics[0]; + assert_eq!(c1_stats.null_count, Precision::Exact(1)); + let expected_type = if force_views { + ScalarValue::Utf8View + } else { + Utf8 + }; + assert_eq!( + c1_stats.max_value, + Precision::Exact(expected_type(Some("bar".to_string()))) + ); + assert_eq!( + c1_stats.min_value, + Precision::Exact(expected_type(Some("Foo".to_string()))) + ); + // column c2: missing from the file so the table treats all 3 rows as null + let c2_stats = &stats.column_statistics[1]; + assert_eq!(c2_stats.null_count, Precision::Exact(3)); + assert_eq!(c2_stats.max_value, Precision::Exact(null_i64.clone())); + assert_eq!(c2_stats.min_value, Precision::Exact(null_i64.clone())); + + // Fetch statistics for second file + let pq_meta = fetch_parquet_metadata(store.as_ref(), &files[1], None).await?; + let stats = statistics_from_parquet_meta_calc(&pq_meta, schema.clone())?; + assert_eq!(stats.num_rows, Precision::Exact(3)); + // column c1: missing from the file so the table treats all 3 rows as null + let c1_stats = &stats.column_statistics[0]; + assert_eq!(c1_stats.null_count, Precision::Exact(3)); + assert_eq!(c1_stats.max_value, Precision::Exact(null_utf8.clone())); + assert_eq!(c1_stats.min_value, Precision::Exact(null_utf8.clone())); + // column c2 + let c2_stats = &stats.column_statistics[1]; + assert_eq!(c2_stats.null_count, Precision::Exact(1)); + assert_eq!(c2_stats.max_value, Precision::Exact(2i64.into())); + assert_eq!(c2_stats.min_value, Precision::Exact(1i64.into())); + + Ok(()) + } + + #[tokio::test] + async fn test_statistics_from_parquet_metadata() -> Result<()> { + _run_test_statistics_from_parquet_metadata(ForceViews::No).await?; + + _run_test_statistics_from_parquet_metadata(ForceViews::Yes).await?; + + Ok(()) + } + #[tokio::test] async fn read_small_batches() -> Result<()> { let config = SessionConfig::new().with_batch_size(2); @@ -1480,10 +1837,31 @@ mod tests { Ok(()) } - #[tokio::test] - async fn read_alltypes_plain_parquet() -> Result<()> { + fn set_view_state(mut state: SessionState, use_views: bool) -> SessionState { + let mut options = TableParquetOptions::default(); + options.global.schema_force_view_types = use_views; + state + .register_file_format( + Arc::new(ParquetFormatFactory::new_with_options(options)), + true, + ) + .expect("ok"); + state + } + + async fn _run_read_alltypes_plain_parquet( + force_views: ForceViews, + expected: &str, + ) -> Result<()> { + let force_views = match force_views { + ForceViews::Yes => true, + ForceViews::No => false, + }; + let session_ctx = SessionContext::new(); - let state = session_ctx.state(); + let mut state = session_ctx.state(); + state = set_view_state(state, force_views); + let task_ctx = state.task_ctx(); let projection = None; let exec = get_exec(&state, "alltypes_plain.parquet", projection, None).await?; @@ -1495,8 +1873,20 @@ mod tests { .map(|f| format!("{}: {:?}", f.name(), f.data_type())) .collect(); let y = x.join("\n"); - assert_eq!( - "id: Int32\n\ + assert_eq!(expected, y); + + let batches = collect(exec, task_ctx).await?; + + assert_eq!(1, batches.len()); + assert_eq!(11, batches[0].num_columns()); + assert_eq!(8, batches[0].num_rows()); + + Ok(()) + } + + #[tokio::test] + async fn read_alltypes_plain_parquet() -> Result<()> { + let no_views = "id: Int32\n\ bool_col: Boolean\n\ tinyint_col: Int32\n\ smallint_col: Int32\n\ @@ -1506,15 +1896,21 @@ mod tests { double_col: Float64\n\ date_string_col: Binary\n\ string_col: Binary\n\ - timestamp_col: Timestamp(Nanosecond, None)", - y - ); + timestamp_col: Timestamp(Nanosecond, None)"; + _run_read_alltypes_plain_parquet(ForceViews::No, no_views).await?; - let batches = collect(exec, task_ctx).await?; - - assert_eq!(1, batches.len()); - assert_eq!(11, batches[0].num_columns()); - assert_eq!(8, batches[0].num_rows()); + let with_views = "id: Int32\n\ + bool_col: Boolean\n\ + tinyint_col: Int32\n\ + smallint_col: Int32\n\ + int_col: Int32\n\ + bigint_col: Int64\n\ + float_col: Float32\n\ + double_col: Float64\n\ + date_string_col: BinaryView\n\ + string_col: BinaryView\n\ + timestamp_col: Timestamp(Nanosecond, None)"; + _run_read_alltypes_plain_parquet(ForceViews::Yes, with_views).await?; Ok(()) } @@ -1651,7 +2047,9 @@ mod tests { #[tokio::test] async fn read_binary_alltypes_plain_parquet() -> Result<()> { let session_ctx = SessionContext::new(); - let state = session_ctx.state(); + let mut state = session_ctx.state(); + state = set_view_state(state, false); + let task_ctx = state.task_ctx(); let projection = Some(vec![9]); let exec = get_exec(&state, "alltypes_plain.parquet", projection, None).await?; @@ -1675,6 +2073,35 @@ mod tests { Ok(()) } + #[tokio::test] + async fn read_binaryview_alltypes_plain_parquet() -> Result<()> { + let session_ctx = SessionContext::new(); + let mut state = session_ctx.state(); + state = set_view_state(state, true); + + let task_ctx = state.task_ctx(); + let projection = Some(vec![9]); + let exec = get_exec(&state, "alltypes_plain.parquet", projection, None).await?; + + let batches = collect(exec, task_ctx).await?; + assert_eq!(1, batches.len()); + assert_eq!(1, batches[0].num_columns()); + assert_eq!(8, batches[0].num_rows()); + + let array = as_binary_view_array(batches[0].column(0))?; + let mut values: Vec<&str> = vec![]; + for i in 0..batches[0].num_rows() { + values.push(std::str::from_utf8(array.value(i)).unwrap()); + } + + assert_eq!( + "[\"0\", \"1\", \"0\", \"1\", \"0\", \"1\", \"0\", \"1\"]", + format!("{values:?}") + ); + + Ok(()) + } + #[tokio::test] async fn read_decimal_parquet() -> Result<()> { let session_ctx = SessionContext::new(); @@ -1772,7 +2199,7 @@ mod tests { // test result in int_col let int_col_index = page_index.get(4).unwrap(); - let int_col_offset = offset_index.get(4).unwrap(); + let int_col_offset = offset_index.get(4).unwrap().page_locations(); // 325 pages in int_col assert_eq!(int_col_offset.len(), 325); @@ -1809,8 +2236,13 @@ mod tests { limit: Option, ) -> Result> { let testdata = crate::test_util::parquet_test_data(); - let format = ParquetFormat::default(); - scan_format(state, &format, &testdata, file_name, projection, limit).await + + let format = state + .get_file_format_factory("parquet") + .map(|factory| factory.create(state, &Default::default()).unwrap()) + .unwrap_or(Arc::new(ParquetFormat::new())); + + scan_format(state, &*format, &testdata, file_name, projection, limit).await } fn build_ctx(store_url: &url::Url) -> Arc { @@ -1842,6 +2274,147 @@ mod tests { #[tokio::test] async fn parquet_sink_write() -> Result<()> { + let parquet_sink = create_written_parquet_sink("file:///").await?; + + // assert written + let mut written = parquet_sink.written(); + let written = written.drain(); + assert_eq!( + written.len(), + 1, + "expected a single parquet files to be written, instead found {}", + written.len() + ); + + // check the file metadata + let ( + path, + FileMetaData { + num_rows, + schema, + key_value_metadata, + .. + }, + ) = written.take(1).next().unwrap(); + let path_parts = path.parts().collect::>(); + assert_eq!(path_parts.len(), 1, "should not have path prefix"); + + assert_eq!(num_rows, 2, "file metadata to have 2 rows"); + assert!( + schema.iter().any(|col_schema| col_schema.name == "a"), + "output file metadata should contain col a" + ); + assert!( + schema.iter().any(|col_schema| col_schema.name == "b"), + "output file metadata should contain col b" + ); + + let mut key_value_metadata = key_value_metadata.unwrap(); + key_value_metadata.sort_by(|a, b| a.key.cmp(&b.key)); + let expected_metadata = vec![ + KeyValue { + key: "my-data".to_string(), + value: Some("stuff".to_string()), + }, + KeyValue { + key: "my-data-bool-key".to_string(), + value: None, + }, + ]; + assert_eq!(key_value_metadata, expected_metadata); + + Ok(()) + } + + #[tokio::test] + async fn parquet_sink_write_with_extension() -> Result<()> { + let filename = "test_file.custom_ext"; + let file_path = format!("file:///path/to/{}", filename); + let parquet_sink = create_written_parquet_sink(file_path.as_str()).await?; + + // assert written + let mut written = parquet_sink.written(); + let written = written.drain(); + assert_eq!( + written.len(), + 1, + "expected a single parquet file to be written, instead found {}", + written.len() + ); + + let (path, ..) = written.take(1).next().unwrap(); + + let path_parts = path.parts().collect::>(); + assert_eq!( + path_parts.len(), + 3, + "Expected 3 path parts, instead found {}", + path_parts.len() + ); + assert_eq!(path_parts.last().unwrap().as_ref(), filename); + + Ok(()) + } + + #[tokio::test] + async fn parquet_sink_write_with_directory_name() -> Result<()> { + let file_path = "file:///path/to"; + let parquet_sink = create_written_parquet_sink(file_path).await?; + + // assert written + let mut written = parquet_sink.written(); + let written = written.drain(); + assert_eq!( + written.len(), + 1, + "expected a single parquet file to be written, instead found {}", + written.len() + ); + + let (path, ..) = written.take(1).next().unwrap(); + + let path_parts = path.parts().collect::>(); + assert_eq!( + path_parts.len(), + 3, + "Expected 3 path parts, instead found {}", + path_parts.len() + ); + assert!(path_parts.last().unwrap().as_ref().ends_with(".parquet")); + + Ok(()) + } + + #[tokio::test] + async fn parquet_sink_write_with_folder_ending() -> Result<()> { + let file_path = "file:///path/to/"; + let parquet_sink = create_written_parquet_sink(file_path).await?; + + // assert written + let mut written = parquet_sink.written(); + let written = written.drain(); + assert_eq!( + written.len(), + 1, + "expected a single parquet file to be written, instead found {}", + written.len() + ); + + let (path, ..) = written.take(1).next().unwrap(); + + let path_parts = path.parts().collect::>(); + assert_eq!( + path_parts.len(), + 3, + "Expected 3 path parts, instead found {}", + path_parts.len() + ); + assert!(path_parts.last().unwrap().as_ref().ends_with(".parquet")); + + Ok(()) + } + + async fn create_written_parquet_sink(table_path: &str) -> Result> { let field_a = Field::new("a", DataType::Utf8, false); let field_b = Field::new("b", DataType::Utf8, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); @@ -1850,14 +2423,21 @@ mod tests { let file_sink_config = FileSinkConfig { object_store_url: object_store_url.clone(), file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)], - table_paths: vec![ListingTableUrl::parse("file:///")?], + table_paths: vec![ListingTableUrl::parse(table_path)?], output_schema: schema.clone(), table_partition_cols: vec![], - overwrite: true, + insert_op: InsertOp::Overwrite, + keep_partition_by_columns: false, }; let parquet_sink = Arc::new(ParquetSink::new( file_sink_config, - TableParquetOptions::default(), + TableParquetOptions { + key_value_metadata: std::collections::HashMap::from([ + ("my-data".to_string(), Some("stuff".to_string())), + ("my-data-bool-key".to_string(), None), + ]), + ..Default::default() + }, )); // create data @@ -1877,37 +2457,7 @@ mod tests { .await .unwrap(); - // assert written - let mut written = parquet_sink.written(); - let written = written.drain(); - assert_eq!( - written.len(), - 1, - "expected a single parquet files to be written, instead found {}", - written.len() - ); - - // check the file metadata - let ( - path, - FileMetaData { - num_rows, schema, .. - }, - ) = written.take(1).next().unwrap(); - let path_parts = path.parts().collect::>(); - assert_eq!(path_parts.len(), 1, "should not have path prefix"); - - assert_eq!(num_rows, 2, "file metdata to have 2 rows"); - assert!( - schema.iter().any(|col_schema| col_schema.name == "a"), - "output file metadata should contain col a" - ); - assert!( - schema.iter().any(|col_schema| col_schema.name == "b"), - "output file metadata should contain col b" - ); - - Ok(()) + Ok(parquet_sink) } #[tokio::test] @@ -1924,7 +2474,8 @@ mod tests { table_paths: vec![ListingTableUrl::parse("file:///")?], output_schema: schema.clone(), table_partition_cols: vec![("a".to_string(), DataType::Utf8)], // add partitioning - overwrite: true, + insert_op: InsertOp::Overwrite, + keep_partition_by_columns: false, }; let parquet_sink = Arc::new(ParquetSink::new( file_sink_config, @@ -1978,7 +2529,7 @@ mod tests { ); expected_partitions.remove(prefix); - assert_eq!(num_rows, 1, "file metdata to have 1 row"); + assert_eq!(num_rows, 1, "file metadata to have 1 row"); assert!( !schema.iter().any(|col_schema| col_schema.name == "a"), "output file metadata will not contain partitioned col a" @@ -1991,4 +2542,105 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn parquet_sink_write_memory_reservation() -> Result<()> { + async fn test_memory_reservation(global: ParquetOptions) -> Result<()> { + let field_a = Field::new("a", DataType::Utf8, false); + let field_b = Field::new("b", DataType::Utf8, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let object_store_url = ObjectStoreUrl::local_filesystem(); + + let file_sink_config = FileSinkConfig { + object_store_url: object_store_url.clone(), + file_groups: vec![PartitionedFile::new("/tmp".to_string(), 1)], + table_paths: vec![ListingTableUrl::parse("file:///")?], + output_schema: schema.clone(), + table_partition_cols: vec![], + insert_op: InsertOp::Overwrite, + keep_partition_by_columns: false, + }; + let parquet_sink = Arc::new(ParquetSink::new( + file_sink_config, + TableParquetOptions { + key_value_metadata: std::collections::HashMap::from([ + ("my-data".to_string(), Some("stuff".to_string())), + ("my-data-bool-key".to_string(), None), + ]), + global, + ..Default::default() + }, + )); + + // create data + let col_a: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar"])); + let col_b: ArrayRef = Arc::new(StringArray::from(vec!["baz", "baz"])); + let batch = + RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)]).unwrap(); + + // create task context + let task_context = build_ctx(object_store_url.as_ref()); + assert_eq!( + task_context.memory_pool().reserved(), + 0, + "no bytes are reserved yet" + ); + + let mut write_task = parquet_sink.write_all( + Box::pin(RecordBatchStreamAdapter::new( + schema, + bounded_stream(batch, 1000), + )), + &task_context, + ); + + // incrementally poll and check for memory reservation + let mut reserved_bytes = 0; + while futures::poll!(&mut write_task).is_pending() { + reserved_bytes += task_context.memory_pool().reserved(); + tokio::time::sleep(Duration::from_micros(1)).await; + } + assert!( + reserved_bytes > 0, + "should have bytes reserved during write" + ); + assert_eq!( + task_context.memory_pool().reserved(), + 0, + "no leaking byte reservation" + ); + + Ok(()) + } + + let write_opts = ParquetOptions { + allow_single_file_parallelism: false, + ..Default::default() + }; + test_memory_reservation(write_opts) + .await + .expect("should track for non-parallel writes"); + + let row_parallel_write_opts = ParquetOptions { + allow_single_file_parallelism: true, + maximum_parallel_row_group_writers: 10, + maximum_buffered_record_batches_per_stream: 1, + ..Default::default() + }; + test_memory_reservation(row_parallel_write_opts) + .await + .expect("should track for row-parallel writes"); + + let col_parallel_write_opts = ParquetOptions { + allow_single_file_parallelism: true, + maximum_parallel_row_group_writers: 1, + maximum_buffered_record_batches_per_stream: 2, + ..Default::default() + }; + test_memory_reservation(col_parallel_write_opts) + .await + .expect("should track for column-parallel writes"); + + Ok(()) + } } diff --git a/datafusion/core/src/datasource/file_format/write/demux.rs b/datafusion/core/src/datasource/file_format/write/demux.rs index d82c2471c596..b03676d53271 100644 --- a/datafusion/core/src/datasource/file_format/write/demux.rs +++ b/datafusion/core/src/datasource/file_format/write/demux.rs @@ -18,6 +18,7 @@ //! Module containing helper methods/traits related to enabling //! dividing input stream into multiple output files at execution time +use std::borrow::Cow; use std::collections::HashMap; use std::sync::Arc; @@ -31,7 +32,11 @@ use arrow_array::builder::UInt64Builder; use arrow_array::cast::AsArray; use arrow_array::{downcast_dictionary_array, RecordBatch, StringArray, StructArray}; use arrow_schema::{DataType, Schema}; -use datafusion_common::cast::as_string_array; +use chrono::NaiveDate; +use datafusion_common::cast::{ + as_boolean_array, as_date32_array, as_date64_array, as_int32_array, as_int64_array, + as_string_array, as_string_view_array, +}; use datafusion_common::{exec_datafusion_err, DataFusionError}; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::TaskContext; @@ -54,8 +59,9 @@ type DemuxedStreamReceiver = UnboundedReceiver<(Path, RecordBatchReceiver)>; /// which should be contained within the same output file. The outer channel /// is used to send a dynamic number of inner channels, representing a dynamic /// number of total output files. The caller is also responsible to monitor -/// the demux task for errors and abort accordingly. The single_file_ouput parameter -/// overrides all other settings to force only a single file to be written. +/// the demux task for errors and abort accordingly. A path with an extension will +/// force only a single file to be written with the extension from the path. Otherwise +/// the default extension will be used and the output will be split into multiple files. /// partition_by parameter will additionally split the input based on the unique /// values of a specific column ``` /// ┌───────────┐ ┌────────────┐ ┌─────────────┐ @@ -74,11 +80,13 @@ pub(crate) fn start_demuxer_task( context: &Arc, partition_by: Option>, base_output_path: ListingTableUrl, - file_extension: String, + default_extension: String, + keep_partition_by_columns: bool, ) -> (SpawnedTask>, DemuxedStreamReceiver) { let (tx, rx) = mpsc::unbounded_channel(); let context = context.clone(); - let single_file_output = !base_output_path.is_collection(); + let single_file_output = + !base_output_path.is_collection() && base_output_path.file_extension().is_some(); let task = match partition_by { Some(parts) => { // There could be an arbitrarily large number of parallel hive style partitions being written to, so we cannot @@ -90,7 +98,8 @@ pub(crate) fn start_demuxer_task( context, parts, base_output_path, - file_extension, + default_extension, + keep_partition_by_columns, ) .await }) @@ -101,7 +110,7 @@ pub(crate) fn start_demuxer_task( input, context, base_output_path, - file_extension, + default_extension, single_file_output, ) .await @@ -111,7 +120,7 @@ pub(crate) fn start_demuxer_task( (task, rx) } -/// Dynamically partitions input stream to acheive desired maximum rows per file +/// Dynamically partitions input stream to achieve desired maximum rows per file async fn row_count_demuxer( mut tx: UnboundedSender<(Path, Receiver)>, mut input: SendableRecordBatchStream, @@ -240,6 +249,7 @@ async fn hive_style_partitions_demuxer( partition_by: Vec<(String, DataType)>, base_output_path: ListingTableUrl, file_extension: String, + keep_partition_by_columns: bool, ) -> Result<()> { let write_id = rand::distributions::Alphanumeric.sample_string(&mut rand::thread_rng(), 16); @@ -272,9 +282,8 @@ async fn hive_style_partitions_demuxer( Some(part_tx) => part_tx, None => { // Create channel for previously unseen distinct partition key and notify consumer of new file - let (part_tx, part_rx) = tokio::sync::mpsc::channel::( - max_buffered_recordbatches, - ); + let (part_tx, part_rx) = + mpsc::channel::(max_buffered_recordbatches); let file_path = compute_hive_style_file_path( &part_key, &partition_by, @@ -298,9 +307,11 @@ async fn hive_style_partitions_demuxer( } }; - // remove partitions columns - let final_batch_to_send = - remove_partition_by_columns(&parted_batch, &partition_by)?; + let final_batch_to_send = if keep_partition_by_columns { + parted_batch + } else { + remove_partition_by_columns(&parted_batch, &partition_by)? + }; // Finally send the partial batch partitioned by distinct value! part_tx.send(final_batch_to_send).await.map_err(|_| { @@ -315,9 +326,11 @@ async fn hive_style_partitions_demuxer( fn compute_partition_keys_by_row<'a>( rb: &'a RecordBatch, partition_by: &'a [(String, DataType)], -) -> Result>> { +) -> Result>>> { let mut all_partition_values = vec![]; + const EPOCH_DAYS_FROM_CE: i32 = 719_163; + // For the purposes of writing partitioned data, we can rely on schema inference // to determine the type of the partition cols in order to provide a more ergonomic // UI which does not require specifying DataTypes manually. So, we ignore the @@ -337,7 +350,59 @@ fn compute_partition_keys_by_row<'a>( DataType::Utf8 => { let array = as_string_array(col_array)?; for i in 0..rb.num_rows() { - partition_values.push(array.value(i)); + partition_values.push(Cow::from(array.value(i))); + } + } + DataType::Utf8View => { + let array = as_string_view_array(col_array)?; + for i in 0..rb.num_rows() { + partition_values.push(Cow::from(array.value(i))); + } + } + DataType::Boolean => { + let array = as_boolean_array(col_array)?; + for i in 0..rb.num_rows() { + partition_values.push(Cow::from(array.value(i).to_string())); + } + } + DataType::Date32 => { + let array = as_date32_array(col_array)?; + // ISO-8601/RFC3339 format - yyyy-mm-dd + let format = "%Y-%m-%d"; + for i in 0..rb.num_rows() { + let date = NaiveDate::from_num_days_from_ce_opt( + EPOCH_DAYS_FROM_CE + array.value(i), + ) + .unwrap() + .format(format) + .to_string(); + partition_values.push(Cow::from(date)); + } + } + DataType::Date64 => { + let array = as_date64_array(col_array)?; + // ISO-8601/RFC3339 format - yyyy-mm-dd + let format = "%Y-%m-%d"; + for i in 0..rb.num_rows() { + let date = NaiveDate::from_num_days_from_ce_opt( + EPOCH_DAYS_FROM_CE + (array.value(i) / 86_400_000) as i32, + ) + .unwrap() + .format(format) + .to_string(); + partition_values.push(Cow::from(date)); + } + } + DataType::Int32 => { + let array = as_int32_array(col_array)?; + for i in 0..rb.num_rows() { + partition_values.push(Cow::from(array.value(i).to_string())); + } + } + DataType::Int64 => { + let array = as_int64_array(col_array)?; + for i in 0..rb.num_rows() { + partition_values.push(Cow::from(array.value(i).to_string())); } } DataType::Dictionary(_, _) => { @@ -349,7 +414,7 @@ fn compute_partition_keys_by_row<'a>( for val in array.values() { partition_values.push( - val.ok_or(exec_datafusion_err!("Cannot partition by null value for column {}", col))? + Cow::from(val.ok_or(exec_datafusion_err!("Cannot partition by null value for column {}", col))?), ); } }, @@ -372,13 +437,13 @@ fn compute_partition_keys_by_row<'a>( fn compute_take_arrays( rb: &RecordBatch, - all_partition_values: Vec>, + all_partition_values: Vec>>, ) -> HashMap, UInt64Builder> { let mut take_map = HashMap::new(); for i in 0..rb.num_rows() { let mut part_key = vec![]; for vals in all_partition_values.iter() { - part_key.push(vals[i].to_owned()); + part_key.push(vals[i].clone().into()); } let builder = take_map.entry(part_key).or_insert(UInt64Builder::new()); builder.append_value(i as u64); diff --git a/datafusion/core/src/datasource/file_format/write/orchestration.rs b/datafusion/core/src/datasource/file_format/write/orchestration.rs index 3ae2122de827..6f27e6f3889f 100644 --- a/datafusion/core/src/datasource/file_format/write/orchestration.rs +++ b/datafusion/core/src/datasource/file_format/write/orchestration.rs @@ -42,15 +42,51 @@ use tokio::task::JoinSet; type WriterType = Box; type SerializerType = Arc; -/// Serializes a single data stream in parallel and writes to an ObjectStore -/// concurrently. Data order is preserved. In the event of an error, -/// the ObjectStore writer is returned to the caller in addition to an error, -/// so that the caller may handle aborting failed writes. +/// Result of calling [`serialize_rb_stream_to_object_store`] +pub(crate) enum SerializedRecordBatchResult { + Success { + /// the writer + writer: WriterType, + + /// the number of rows successfully written + row_count: usize, + }, + Failure { + /// As explained in [`serialize_rb_stream_to_object_store`]: + /// - If an IO error occured that involved the ObjectStore writer, then the writer will not be returned to the caller + /// - Otherwise, the writer is returned to the caller + writer: Option, + + /// the actual error that occured + err: DataFusionError, + }, +} + +impl SerializedRecordBatchResult { + /// Create the success variant + pub fn success(writer: WriterType, row_count: usize) -> Self { + Self::Success { writer, row_count } + } + + pub fn failure(writer: Option, err: DataFusionError) -> Self { + Self::Failure { writer, err } + } +} + +/// Serializes a single data stream in parallel and writes to an ObjectStore concurrently. +/// Data order is preserved. +/// +/// In the event of a non-IO error which does not involve the ObjectStore writer, +/// the writer returned to the caller in addition to the error, +/// so that failed writes may be aborted. +/// +/// In the event of an IO error involving the ObjectStore writer, +/// the writer is dropped to avoid calling further methods on it which might panic. pub(crate) async fn serialize_rb_stream_to_object_store( mut data_rx: Receiver, serializer: Arc, mut writer: WriterType, -) -> std::result::Result<(WriterType, u64), (WriterType, DataFusionError)> { +) -> SerializedRecordBatchResult { let (tx, mut rx) = mpsc::channel::>>(100); let serialize_task = SpawnedTask::spawn(async move { @@ -81,43 +117,43 @@ pub(crate) async fn serialize_rb_stream_to_object_store( match writer.write_all(&bytes).await { Ok(_) => (), Err(e) => { - return Err(( - writer, + return SerializedRecordBatchResult::failure( + None, DataFusionError::Execution(format!( "Error writing to object store: {e}" )), - )) + ) } }; row_count += cnt; } Ok(Err(e)) => { // Return the writer along with the error - return Err((writer, e)); + return SerializedRecordBatchResult::failure(Some(writer), e); } Err(e) => { // Handle task panic or cancellation - return Err(( - writer, + return SerializedRecordBatchResult::failure( + Some(writer), DataFusionError::Execution(format!( "Serialization task panicked or was cancelled: {e}" )), - )); + ); } } } match serialize_task.join().await { Ok(Ok(_)) => (), - Ok(Err(e)) => return Err((writer, e)), + Ok(Err(e)) => return SerializedRecordBatchResult::failure(Some(writer), e), Err(_) => { - return Err(( - writer, + return SerializedRecordBatchResult::failure( + Some(writer), internal_datafusion_err!("Unknown error writing to object store"), - )) + ) } } - Ok((writer, row_count as u64)) + SerializedRecordBatchResult::success(writer, row_count) } type FileWriteBundle = (Receiver, SerializerType, WriterType); @@ -136,7 +172,7 @@ pub(crate) async fn stateless_serialize_and_write_files( // tracks the specific error triggering abort let mut triggering_error = None; // tracks if any errors were encountered in the process of aborting writers. - // if true, we may not have a guarentee that all written data was cleaned up. + // if true, we may not have a guarantee that all written data was cleaned up. let mut any_abort_errors = false; let mut join_set = JoinSet::new(); while let Some((data_rx, serializer, writer)) = rx.recv().await { @@ -148,14 +184,17 @@ pub(crate) async fn stateless_serialize_and_write_files( while let Some(result) = join_set.join_next().await { match result { Ok(res) => match res { - Ok((writer, cnt)) => { + SerializedRecordBatchResult::Success { + writer, + row_count: cnt, + } => { finished_writers.push(writer); row_count += cnt; } - Err((writer, e)) => { - finished_writers.push(writer); + SerializedRecordBatchResult::Failure { writer, err } => { + finished_writers.extend(writer); any_errors = true; - triggering_error = Some(e); + triggering_error = Some(err); } }, Err(e) => { @@ -183,12 +222,12 @@ pub(crate) async fn stateless_serialize_and_write_files( true => return internal_err!("Error encountered during writing to ObjectStore and failed to abort all writers. Partial result may have been written."), false => match triggering_error { Some(e) => return Err(e), - None => return internal_err!("Unknown Error encountered during writing to ObjectStore. All writers succesfully aborted.") + None => return internal_err!("Unknown Error encountered during writing to ObjectStore. All writers successfully aborted.") } } } - tx.send(row_count).map_err(|_| { + tx.send(row_count as u64).map_err(|_| { internal_datafusion_err!( "Error encountered while sending row count back to file sink!" ) @@ -224,6 +263,7 @@ pub(crate) async fn stateless_multipart_put( part_cols, base_output_path.clone(), file_extension, + config.keep_partition_by_columns, ); let rb_buffer_size = &context @@ -258,11 +298,11 @@ pub(crate) async fn stateless_multipart_put( write_coordinator_task.join_unwind(), demux_task.join_unwind() ); - r1?; - r2?; + r1.map_err(DataFusionError::ExecutionJoin)??; + r2.map_err(DataFusionError::ExecutionJoin)??; let total_count = rx_row_cnt.await.map_err(|_| { - internal_datafusion_err!("Did not receieve row count from write coordinater") + internal_datafusion_err!("Did not receive row count from write coordinator") })?; Ok(total_count) diff --git a/datafusion/core/src/datasource/function.rs b/datafusion/core/src/datasource/function.rs index 2fd352ee4eb3..37ce59f8207b 100644 --- a/datafusion/core/src/datasource/function.rs +++ b/datafusion/core/src/datasource/function.rs @@ -22,15 +22,17 @@ use super::TableProvider; use datafusion_common::Result; use datafusion_expr::Expr; +use std::fmt::Debug; use std::sync::Arc; /// A trait for table function implementations -pub trait TableFunctionImpl: Sync + Send { +pub trait TableFunctionImpl: Debug + Sync + Send { /// Create a table provider fn call(&self, args: &[Expr]) -> Result>; } /// A table that uses a function to generate data +#[derive(Debug)] pub struct TableFunction { /// Name of the table function name: String, @@ -49,6 +51,11 @@ impl TableFunction { &self.name } + /// Get the implementation of the table function + pub fn function(&self) -> &Arc { + &self.fun + } + /// Get the function implementation and generate a table pub fn create_table_provider(&self, args: &[Expr]) -> Result> { self.fun.call(args) diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 9dfd18f1881e..47012f777ad1 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -17,12 +17,16 @@ //! Helper functions for the table implementation +use std::collections::HashMap; +use std::mem; use std::sync::Arc; +use super::ListingTableUrl; use super::PartitionedFile; -use crate::datasource::listing::ListingTableUrl; use crate::execution::context::SessionState; -use crate::{error::Result, scalar::ScalarValue}; +use datafusion_common::internal_err; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::{BinaryExpr, Operator}; use arrow::{ array::{Array, ArrayRef, AsArray, StringBuilder}, @@ -37,8 +41,8 @@ use futures::{stream::BoxStream, StreamExt, TryStreamExt}; use log::{debug, trace}; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; -use datafusion_common::{internal_err, Column, DFSchema, DataFusionError}; -use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility}; +use datafusion_common::{Column, DFSchema, DataFusionError}; +use datafusion_expr::{Expr, Volatility}; use datafusion_physical_expr::create_physical_expr; use object_store::path::Path; use object_store::{ObjectMeta, ObjectStore}; @@ -47,78 +51,67 @@ use object_store::{ObjectMeta, ObjectStore}; /// This means that if this function returns true: /// - the table provider can filter the table partition values with this expression /// - the expression can be marked as `TableProviderFilterPushDown::Exact` once this filtering -/// was performed -pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { +/// was performed +pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { let mut is_applicable = true; - expr.apply(|expr| { - match expr { - Expr::Column(Column { ref name, .. }) => { - is_applicable &= col_names.contains(name); - if is_applicable { - Ok(TreeNodeRecursion::Jump) - } else { - Ok(TreeNodeRecursion::Stop) - } + expr.apply(|expr| match expr { + Expr::Column(Column { ref name, .. }) => { + is_applicable &= col_names.contains(&name.as_str()); + if is_applicable { + Ok(TreeNodeRecursion::Jump) + } else { + Ok(TreeNodeRecursion::Stop) } - Expr::Literal(_) - | Expr::Alias(_) - | Expr::OuterReferenceColumn(_, _) - | Expr::ScalarVariable(_, _) - | Expr::Not(_) - | Expr::IsNotNull(_) - | Expr::IsNull(_) - | Expr::IsTrue(_) - | Expr::IsFalse(_) - | Expr::IsUnknown(_) - | Expr::IsNotTrue(_) - | Expr::IsNotFalse(_) - | Expr::IsNotUnknown(_) - | Expr::Negative(_) - | Expr::Cast { .. } - | Expr::TryCast { .. } - | Expr::BinaryExpr { .. } - | Expr::Between { .. } - | Expr::Like { .. } - | Expr::SimilarTo { .. } - | Expr::InList { .. } - | Expr::Exists { .. } - | Expr::InSubquery(_) - | Expr::ScalarSubquery(_) - | Expr::GetIndexedField { .. } - | Expr::GroupingSet(_) - | Expr::Case { .. } => Ok(TreeNodeRecursion::Continue), - - Expr::ScalarFunction(scalar_function) => { - match &scalar_function.func_def { - ScalarFunctionDefinition::UDF(fun) => { - match fun.signature().volatility { - Volatility::Immutable => Ok(TreeNodeRecursion::Continue), - // TODO: Stable functions could be `applicable`, but that would require access to the context - Volatility::Stable | Volatility::Volatile => { - is_applicable = false; - Ok(TreeNodeRecursion::Stop) - } - } - } - ScalarFunctionDefinition::Name(_) => { - internal_err!("Function `Expr` with name should be resolved.") - } + } + Expr::Literal(_) + | Expr::Alias(_) + | Expr::OuterReferenceColumn(_, _) + | Expr::ScalarVariable(_, _) + | Expr::Not(_) + | Expr::IsNotNull(_) + | Expr::IsNull(_) + | Expr::IsTrue(_) + | Expr::IsFalse(_) + | Expr::IsUnknown(_) + | Expr::IsNotTrue(_) + | Expr::IsNotFalse(_) + | Expr::IsNotUnknown(_) + | Expr::Negative(_) + | Expr::Cast(_) + | Expr::TryCast(_) + | Expr::BinaryExpr(_) + | Expr::Between(_) + | Expr::Like(_) + | Expr::SimilarTo(_) + | Expr::InList(_) + | Expr::Exists(_) + | Expr::InSubquery(_) + | Expr::ScalarSubquery(_) + | Expr::GroupingSet(_) + | Expr::Case(_) => Ok(TreeNodeRecursion::Continue), + + Expr::ScalarFunction(scalar_function) => { + match scalar_function.func.signature().volatility { + Volatility::Immutable => Ok(TreeNodeRecursion::Continue), + // TODO: Stable functions could be `applicable`, but that would require access to the context + Volatility::Stable | Volatility::Volatile => { + is_applicable = false; + Ok(TreeNodeRecursion::Stop) } } + } - // TODO other expressions are not handled yet: - // - AGGREGATE, WINDOW and SORT should not end up in filter conditions, except maybe in some edge cases - // - Can `Wildcard` be considered as a `Literal`? - // - ScalarVariable could be `applicable`, but that would require access to the context - Expr::AggregateFunction { .. } - | Expr::Sort { .. } - | Expr::WindowFunction { .. } - | Expr::Wildcard { .. } - | Expr::Unnest { .. } - | Expr::Placeholder(_) => { - is_applicable = false; - Ok(TreeNodeRecursion::Stop) - } + // TODO other expressions are not handled yet: + // - AGGREGATE and WINDOW should not end up in filter conditions, except maybe in some edge cases + // - Can `Wildcard` be considered as a `Literal`? + // - ScalarVariable could be `applicable`, but that would require access to the context + Expr::AggregateFunction { .. } + | Expr::WindowFunction { .. } + | Expr::Wildcard { .. } + | Expr::Unnest { .. } + | Expr::Placeholder(_) => { + is_applicable = false; + Ok(TreeNodeRecursion::Stop) } }) .unwrap(); @@ -144,10 +137,22 @@ pub fn split_files( // effectively this is div with rounding up instead of truncating let chunk_size = (partitioned_files.len() + n - 1) / n; - partitioned_files - .chunks(chunk_size) - .map(|c| c.to_vec()) - .collect() + let mut chunks = Vec::with_capacity(n); + let mut current_chunk = Vec::with_capacity(chunk_size); + for file in partitioned_files.drain(..) { + current_chunk.push(file); + if current_chunk.len() == chunk_size { + let full_chunk = + mem::replace(&mut current_chunk, Vec::with_capacity(chunk_size)); + chunks.push(full_chunk); + } + } + + if !current_chunk.is_empty() { + chunks.push(current_chunk) + } + + chunks } struct Partition { @@ -177,9 +182,17 @@ async fn list_partitions( store: &dyn ObjectStore, table_path: &ListingTableUrl, max_depth: usize, + partition_prefix: Option, ) -> Result> { let partition = Partition { - path: table_path.prefix().clone(), + path: match partition_prefix { + Some(prefix) => Path::from_iter( + Path::from(table_path.prefix().as_ref()) + .parts() + .chain(Path::from(prefix.as_ref()).parts()), + ), + None => table_path.prefix().clone(), + }, depth: 0, files: None, }; @@ -259,7 +272,7 @@ async fn prune_partitions( .collect(); let schema = Arc::new(Schema::new(fields)); - let df_schema = DFSchema::from_unqualifed_fields( + let df_schema = DFSchema::from_unqualified_fields( partition_cols .iter() .map(|(n, d)| Field::new(n, d.clone(), true)) @@ -267,31 +280,26 @@ async fn prune_partitions( Default::default(), )?; - let batch = RecordBatch::try_new(schema.clone(), arrays)?; + let batch = RecordBatch::try_new(schema, arrays)?; // TODO: Plumb this down let props = ExecutionProps::new(); // Applies `filter` to `batch` returning `None` on error - let do_filter = |filter| -> Option { - let expr = create_physical_expr(filter, &df_schema, &props).ok()?; - expr.evaluate(&batch) - .ok()? - .into_array(partitions.len()) - .ok() + let do_filter = |filter| -> Result { + let expr = create_physical_expr(filter, &df_schema, &props)?; + expr.evaluate(&batch)?.into_array(partitions.len()) }; - //.Compute the conjunction of the filters, ignoring errors + //.Compute the conjunction of the filters let mask = filters .iter() - .fold(None, |acc, filter| match (acc, do_filter(filter)) { - (Some(a), Some(b)) => Some(and(&a, b.as_boolean()).unwrap_or(a)), - (None, Some(r)) => Some(r.as_boolean().clone()), - (r, None) => r, - }); + .map(|f| do_filter(f).map(|a| a.as_boolean().clone())) + .reduce(|a, b| Ok(and(&a?, &b?)?)); let mask = match mask { - Some(mask) => mask, + Some(Ok(mask)) => mask, + Some(Err(err)) => return Err(err), None => return Ok(partitions), }; @@ -313,10 +321,84 @@ async fn prune_partitions( Ok(filtered) } +#[derive(Debug)] +enum PartitionValue { + Single(String), + Multi, +} + +fn populate_partition_values<'a>( + partition_values: &mut HashMap<&'a str, PartitionValue>, + filter: &'a Expr, +) { + if let Expr::BinaryExpr(BinaryExpr { + ref left, + op, + ref right, + }) = filter + { + match op { + Operator::Eq => match (left.as_ref(), right.as_ref()) { + (Expr::Column(Column { ref name, .. }), Expr::Literal(val)) + | (Expr::Literal(val), Expr::Column(Column { ref name, .. })) => { + if partition_values + .insert(name, PartitionValue::Single(val.to_string())) + .is_some() + { + partition_values.insert(name, PartitionValue::Multi); + } + } + _ => {} + }, + Operator::And => { + populate_partition_values(partition_values, left); + populate_partition_values(partition_values, right); + } + _ => {} + } + } +} + +fn evaluate_partition_prefix<'a>( + partition_cols: &'a [(String, DataType)], + filters: &'a [Expr], +) -> Option { + let mut partition_values = HashMap::new(); + for filter in filters { + populate_partition_values(&mut partition_values, filter); + } + + if partition_values.is_empty() { + return None; + } + + let mut parts = vec![]; + for (p, _) in partition_cols { + match partition_values.get(p.as_str()) { + Some(PartitionValue::Single(val)) => { + // if a partition only has a single literal value, then it can be added to the + // prefix + parts.push(format!("{p}={val}")); + } + _ => { + // break on the first unconstrainted partition to create a common prefix + // for all covered partitions. + break; + } + } + } + + if parts.is_empty() { + None + } else { + Some(Path::from_iter(parts)) + } +} + /// Discover the partitions on the given path and prune out files /// that belong to irrelevant partitions using `filters` expressions. -/// `filters` might contain expressions that can be resolved only at the -/// file level (e.g. Parquet row group pruning). +/// `filters` should only contain expressions that can be evaluated +/// using only the partition columns. pub async fn pruned_partition_list<'a>( ctx: &'a SessionState, store: &'a dyn ObjectStore, @@ -327,6 +409,12 @@ pub async fn pruned_partition_list<'a>( ) -> Result>> { // if no partition col => simply list all the files if partition_cols.is_empty() { + if !filters.is_empty() { + return internal_err!( + "Got partition filters for unpartitioned table {}", + table_path + ); + } return Ok(Box::pin( table_path .list_all_files(ctx, store, file_extension) @@ -335,7 +423,10 @@ pub async fn pruned_partition_list<'a>( )); } - let partitions = list_partitions(store, table_path, partition_cols.len()).await?; + let partition_prefix = evaluate_partition_prefix(partition_cols, filters); + let partitions = + list_partitions(store, table_path, partition_cols.len(), partition_prefix) + .await?; debug!("Listed {} partitions", partitions.len()); let pruned = @@ -376,6 +467,7 @@ pub async fn pruned_partition_list<'a>( object_meta, partition_values: partition_values.clone(), range: None, + statistics: None, extensions: None, }) })); @@ -423,8 +515,10 @@ where mod tests { use std::ops::Not; - use crate::logical_expr::{case, col, lit}; + use futures::StreamExt; + use crate::test::object_store::make_test_store_and_state; + use datafusion_expr::{case, col, lit, Expr}; use super::*; @@ -539,13 +633,11 @@ mod tests { ]); let filter1 = Expr::eq(col("part1"), lit("p1v2")); let filter2 = Expr::eq(col("part2"), lit("p2v1")); - // filter3 cannot be resolved at partition pruning - let filter3 = Expr::eq(col("part2"), col("other")); let pruned = pruned_partition_list( &state, store.as_ref(), &ListingTableUrl::parse("file:///tablepath/").unwrap(), - &[filter1, filter2, filter3], + &[filter1, filter2], ".parquet", &[ (String::from("part1"), DataType::Utf8), @@ -651,35 +743,145 @@ mod tests { #[test] fn test_expr_applicable_for_cols() { assert!(expr_applicable_for_cols( - &[String::from("c1")], + &["c1"], &Expr::eq(col("c1"), lit("value")) )); assert!(!expr_applicable_for_cols( - &[String::from("c1")], + &["c1"], &Expr::eq(col("c2"), lit("value")) )); assert!(!expr_applicable_for_cols( - &[String::from("c1")], + &["c1"], &Expr::eq(col("c1"), col("c2")) )); assert!(expr_applicable_for_cols( - &[String::from("c1"), String::from("c2")], + &["c1", "c2"], &Expr::eq(col("c1"), col("c2")) )); assert!(expr_applicable_for_cols( - &[String::from("c1"), String::from("c2")], + &["c1", "c2"], &(Expr::eq(col("c1"), col("c2").alias("c2_alias"))).not() )); assert!(expr_applicable_for_cols( - &[String::from("c1"), String::from("c2")], + &["c1", "c2"], &(case(col("c1")) .when(lit("v1"), lit(true)) .otherwise(lit(false)) .expect("valid case expr")) )); - // static expression not relvant in this context but we + // static expression not relevant in this context but we // test it as an edge case anyway in case we want to generalize // this helper function assert!(expr_applicable_for_cols(&[], &lit(true))); } + + #[test] + fn test_evaluate_partition_prefix() { + let partitions = &[ + ("a".to_string(), DataType::Utf8), + ("b".to_string(), DataType::Int16), + ("c".to_string(), DataType::Boolean), + ]; + + assert_eq!( + evaluate_partition_prefix(partitions, &[col("a").eq(lit("foo"))]), + Some(Path::from("a=foo")), + ); + + assert_eq!( + evaluate_partition_prefix(partitions, &[lit("foo").eq(col("a"))]), + Some(Path::from("a=foo")), + ); + + assert_eq!( + evaluate_partition_prefix( + partitions, + &[col("a").eq(lit("foo")).and(col("b").eq(lit("bar")))], + ), + Some(Path::from("a=foo/b=bar")), + ); + + assert_eq!( + evaluate_partition_prefix( + partitions, + // list of filters should be evaluated as AND + &[col("a").eq(lit("foo")), col("b").eq(lit("bar")),], + ), + Some(Path::from("a=foo/b=bar")), + ); + + assert_eq!( + evaluate_partition_prefix( + partitions, + &[col("a") + .eq(lit("foo")) + .and(col("b").eq(lit("1"))) + .and(col("c").eq(lit("true")))], + ), + Some(Path::from("a=foo/b=1/c=true")), + ); + + // no prefix when filter is empty + assert_eq!(evaluate_partition_prefix(partitions, &[]), None); + + // b=foo results in no prefix because a is not restricted + assert_eq!( + evaluate_partition_prefix(partitions, &[Expr::eq(col("b"), lit("foo"))]), + None, + ); + + // a=foo and c=baz only results in preifx a=foo because b is not restricted + assert_eq!( + evaluate_partition_prefix( + partitions, + &[col("a").eq(lit("foo")).and(col("c").eq(lit("baz")))], + ), + Some(Path::from("a=foo")), + ); + + // partition with multiple values results in no prefix + assert_eq!( + evaluate_partition_prefix( + partitions, + &[Expr::and(col("a").eq(lit("foo")), col("a").eq(lit("bar")))], + ), + None, + ); + + // no prefix because partition a is not restricted to a single literal + assert_eq!( + evaluate_partition_prefix( + partitions, + &[Expr::or(col("a").eq(lit("foo")), col("a").eq(lit("bar")))], + ), + None, + ); + assert_eq!( + evaluate_partition_prefix(partitions, &[col("b").lt(lit(5))],), + None, + ); + } + + #[test] + fn test_evaluate_date_partition_prefix() { + let partitions = &[("a".to_string(), DataType::Date32)]; + assert_eq!( + evaluate_partition_prefix( + partitions, + &[col("a").eq(Expr::Literal(ScalarValue::Date32(Some(3))))], + ), + Some(Path::from("a=1970-01-04")), + ); + + let partitions = &[("a".to_string(), DataType::Date64)]; + assert_eq!( + evaluate_partition_prefix( + partitions, + &[col("a").eq(Expr::Literal(ScalarValue::Date64(Some( + 4 * 24 * 60 * 60 * 1000 + )))),], + ), + Some(Path::from("a=1970-01-05")), + ); + } } diff --git a/datafusion/core/src/datasource/listing/mod.rs b/datafusion/core/src/datasource/listing/mod.rs index b8c279c8a7f1..c5a441aacf1d 100644 --- a/datafusion/core/src/datasource/listing/mod.rs +++ b/datafusion/core/src/datasource/listing/mod.rs @@ -22,9 +22,9 @@ mod helpers; mod table; mod url; -use crate::error::Result; use chrono::TimeZone; -use datafusion_common::ScalarValue; +use datafusion_common::Result; +use datafusion_common::{ScalarValue, Statistics}; use futures::Stream; use object_store::{path::Path, ObjectMeta}; use std::pin::Pin; @@ -48,6 +48,13 @@ pub struct FileRange { pub end: i64, } +impl FileRange { + /// returns true if this file range contains the specified offset + pub fn contains(&self, offset: i64) -> bool { + offset >= self.start && offset < self.end + } +} + #[derive(Debug, Clone)] /// A single file or part of a file that should be read, along with its schema, statistics /// and partition column values that need to be appended to each row. @@ -67,9 +74,15 @@ pub struct PartitionedFile { pub partition_values: Vec, /// An optional file range for a more fine-grained parallel execution pub range: Option, + /// Optional statistics that describe the data in this file if known. + /// + /// DataFusion relies on these statistics for planning (in particular to sort file groups), + /// so if they are incorrect, incorrect answers may result. + pub statistics: Option, /// An optional field for user defined per object metadata pub extensions: Option>, } + impl PartitionedFile { /// Create a simple file without metadata or partition pub fn new(path: impl Into, size: u64) -> Self { @@ -83,6 +96,7 @@ impl PartitionedFile { }, partition_values: vec![], range: None, + statistics: None, extensions: None, } } @@ -98,7 +112,8 @@ impl PartitionedFile { version: None, }, partition_values: vec![], - range: None, + range: Some(FileRange { start, end }), + statistics: None, extensions: None, } .with_range(start, end) @@ -120,6 +135,17 @@ impl PartitionedFile { self.range = Some(FileRange { start, end }); self } + + /// Update the user defined extensions for this file. + /// + /// This can be used to pass reader specific information. + pub fn with_extensions( + mut self, + extensions: Arc, + ) -> Self { + self.extensions = Some(extensions); + self + } } impl From for PartitionedFile { @@ -128,6 +154,7 @@ impl From for PartitionedFile { object_meta, partition_values: vec![], range: None, + statistics: None, extensions: None, } } @@ -135,7 +162,7 @@ impl From for PartitionedFile { #[cfg(test)] mod tests { - use crate::datasource::listing::ListingTableUrl; + use super::ListingTableUrl; use datafusion_execution::object_store::{ DefaultObjectStoreRegistry, ObjectStoreRegistry, }; diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 6ee19828f1d4..15125fe5a090 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -18,48 +18,43 @@ //! The table implementation. use std::collections::HashMap; -use std::str::FromStr; -use std::{any::Any, sync::Arc}; +use std::{any::Any, str::FromStr, sync::Arc}; use super::helpers::{expr_applicable_for_cols, pruned_partition_list, split_files}; -use super::PartitionedFile; +use super::{ListingTableUrl, PartitionedFile}; -#[cfg(feature = "parquet")] -use crate::datasource::file_format::parquet::ParquetFormat; -use crate::datasource::{ - create_ordering, get_statistics_with_limit, TableProvider, TableType, -}; use crate::datasource::{ + create_ordering, file_format::{ - arrow::ArrowFormat, - avro::AvroFormat, - csv::CsvFormat, - file_compression_type::{FileCompressionType, FileTypeExt}, - json::JsonFormat, - FileFormat, + file_compression_type::FileCompressionType, FileFormat, FilePushdownSupport, }, - listing::ListingTableUrl, + get_statistics_with_limit, physical_plan::{FileScanConfig, FileSinkConfig}, }; -use crate::{ - error::{DataFusionError, Result}, - execution::context::SessionState, - logical_expr::{utils::conjunction, Expr, TableProviderFilterPushDown}, - physical_plan::{empty::EmptyExec, ExecutionPlan, Statistics}, -}; +use crate::execution::context::SessionState; +use datafusion_catalog::TableProvider; +use datafusion_common::{config_err, DataFusionError, Result}; +use datafusion_expr::dml::InsertOp; +use datafusion_expr::{utils::conjunction, Expr, TableProviderFilterPushDown}; +use datafusion_expr::{SortExpr, TableType}; +use datafusion_physical_plan::{empty::EmptyExec, ExecutionPlan, Statistics}; use arrow::datatypes::{DataType, Field, SchemaBuilder, SchemaRef}; use arrow_schema::Schema; use datafusion_common::{ - internal_err, plan_err, project_schema, Constraints, FileType, SchemaExt, ToDFSchema, + config_datafusion_err, internal_err, plan_err, project_schema, Constraints, + SchemaExt, ToDFSchema, +}; +use datafusion_execution::cache::{ + cache_manager::FileStatisticsCache, cache_unit::DefaultFileStatisticsCache, }; -use datafusion_execution::cache::cache_manager::FileStatisticsCache; -use datafusion_execution::cache::cache_unit::DefaultFileStatisticsCache; use datafusion_physical_expr::{ create_physical_expr, LexOrdering, PhysicalSortRequirement, }; use async_trait::async_trait; +use datafusion_catalog::Session; +use datafusion_physical_expr_common::sort_expr::LexRequirement; use futures::{future, stream, StreamExt, TryStreamExt}; use itertools::Itertools; use object_store::ObjectStore; @@ -119,9 +114,7 @@ impl ListingTableConfig { } } - fn infer_file_type(path: &str) -> Result<(FileType, String)> { - let err_msg = format!("Unable to infer file type from path: {path}"); - + fn infer_file_extension(path: &str) -> Result { let mut exts = path.rsplit('.'); let mut splitted = exts.next().unwrap_or(""); @@ -133,14 +126,7 @@ impl ListingTableConfig { splitted = exts.next().unwrap_or(""); } - let file_type = FileType::from_str(splitted) - .map_err(|_| DataFusionError::Internal(err_msg.to_owned()))?; - - let ext = file_type - .get_ext_with_compression(file_compression_type.to_owned()) - .map_err(|_| DataFusionError::Internal(err_msg))?; - - Ok((file_type, ext)) + Ok(splitted.to_string()) } /// Infer `ListingOptions` based on `table_path` suffix. @@ -161,25 +147,15 @@ impl ListingTableConfig { .await .ok_or_else(|| DataFusionError::Internal("No files for table".into()))??; - let (file_type, file_extension) = - ListingTableConfig::infer_file_type(file.location.as_ref())?; + let file_extension = + ListingTableConfig::infer_file_extension(file.location.as_ref())?; - let mut table_options = state.default_table_options(); - table_options.set_file_format(file_type.clone()); - let file_format: Arc = match file_type { - FileType::CSV => { - Arc::new(CsvFormat::default().with_options(table_options.csv)) - } - #[cfg(feature = "parquet")] - FileType::PARQUET => { - Arc::new(ParquetFormat::default().with_options(table_options.parquet)) - } - FileType::AVRO => Arc::new(AvroFormat), - FileType::JSON => { - Arc::new(JsonFormat::default().with_options(table_options.json)) - } - FileType::ARROW => Arc::new(ArrowFormat), - }; + let file_format = state + .get_file_format_factory(&file_extension) + .ok_or(config_datafusion_err!( + "No file_format found with extension {file_extension}" + ))? + .create(state, &HashMap::new())?; let listing_options = ListingOptions::new(file_format) .with_file_extension(file_extension) @@ -216,6 +192,38 @@ impl ListingTableConfig { pub async fn infer(self, state: &SessionState) -> Result { self.infer_options(state).await?.infer_schema(state).await } + + /// Infer the partition columns from the path. Requires `self.options` to be set prior to using. + pub async fn infer_partitions_from_path(self, state: &SessionState) -> Result { + match self.options { + Some(options) => { + let Some(url) = self.table_paths.first() else { + return config_err!("No table path found"); + }; + let partitions = options + .infer_partitions(state, url) + .await? + .into_iter() + .map(|col_name| { + ( + col_name, + DataType::Dictionary( + Box::new(DataType::UInt16), + Box::new(DataType::Utf8), + ), + ) + }) + .collect::>(); + let options = options.with_table_partition_cols(partitions); + Ok(Self { + table_paths: self.table_paths, + file_schema: self.file_schema, + options: Some(options), + }) + } + None => config_err!("No `ListingOptions` set for inferring schema"), + } + } } /// Options for creating a [`ListingTable`] @@ -250,19 +258,19 @@ pub struct ListingOptions { /// ordering (encapsulated by a `Vec`). If there aren't /// multiple equivalent orderings, the outer `Vec` will have a /// single element. - pub file_sort_order: Vec>, + pub file_sort_order: Vec>, } impl ListingOptions { /// Creates an options instance with the given format /// Default values: - /// - no file extension filter + /// - use default file extension filter /// - no input partition to discover /// - one target partition /// - stat collection pub fn new(format: Arc) -> Self { Self { - file_extension: String::new(), + file_extension: format.get_ext(), format, table_partition_cols: vec![], collect_stat: true, @@ -273,6 +281,7 @@ impl ListingOptions { /// Set file extension on [`ListingOptions`] and returns self. /// + /// # Example /// ``` /// # use std::sync::Arc; /// # use datafusion::prelude::SessionContext; @@ -290,6 +299,33 @@ impl ListingOptions { self } + /// Optionally set file extension on [`ListingOptions`] and returns self. + /// + /// If `file_extension` is `None`, the file extension will not be changed + /// + /// # Example + /// ``` + /// # use std::sync::Arc; + /// # use datafusion::prelude::SessionContext; + /// # use datafusion::datasource::{listing::ListingOptions, file_format::parquet::ParquetFormat}; + /// let extension = Some(".parquet"); + /// let listing_options = ListingOptions::new(Arc::new( + /// ParquetFormat::default() + /// )) + /// .with_file_extension_opt(extension); + /// + /// assert_eq!(listing_options.file_extension, ".parquet"); + /// ``` + pub fn with_file_extension_opt(mut self, file_extension: Option) -> Self + where + S: Into, + { + if let Some(file_extension) = file_extension { + self.file_extension = file_extension.into(); + } + self + } + /// Set `table partition columns` on [`ListingOptions`] and returns self. /// /// "partition columns," used to support [Hive Partitioning], are @@ -314,17 +350,17 @@ impl ListingOptions { ///# Notes /// /// - If only one level (e.g. `year` in the example above) is - /// specified, the other levels are ignored but the files are - /// still read. + /// specified, the other levels are ignored but the files are + /// still read. /// /// - Files that don't follow this partitioning scheme will be - /// ignored. + /// ignored. /// /// - Since the columns have the same value for all rows read from - /// each individual file (such as dates), they are typically - /// dictionary encoded for efficiency. You may use - /// [`wrap_partition_type_in_dict`] to request a - /// dictionary-encoded type. + /// each individual file (such as dates), they are typically + /// dictionary encoded for efficiency. You may use + /// [`wrap_partition_type_in_dict`] to request a + /// dictionary-encoded type. /// /// - The partition columns are solely extracted from the file path. Especially they are NOT part of the parquet files itself. /// @@ -413,7 +449,7 @@ impl ListingOptions { /// /// assert_eq!(listing_options.file_sort_order, file_sort_order); /// ``` - pub fn with_file_sort_order(mut self, file_sort_order: Vec>) -> Self { + pub fn with_file_sort_order(mut self, file_sort_order: Vec>) -> Self { self.file_sort_order = file_sort_order; self } @@ -437,7 +473,9 @@ impl ListingOptions { .try_collect() .await?; - self.format.infer_schema(state, &store, &files).await + let schema = self.format.infer_schema(state, &store, &files).await?; + + Ok(schema) } /// Infers the partition columns stored in `LOCATION` and compares @@ -499,7 +537,7 @@ impl ListingOptions { /// Infer the partitioning at the given path on the provided object store. /// For performance reasons, it doesn't read all the files on disk /// and therefore may fail to detect invalid partitioning. - async fn infer_partitions( + pub(crate) async fn infer_partitions( &self, state: &SessionState, table_path: &ListingTableUrl, @@ -547,20 +585,49 @@ impl ListingOptions { } } -/// Reads data from one or more files via an -/// [`ObjectStore`]. For example, from -/// local files or objects from AWS S3. Implements [`TableProvider`], -/// a DataFusion data source. +/// Reads data from one or more files as a single table. +/// +/// Implements [`TableProvider`], a DataFusion data source. The files are read +/// using an [`ObjectStore`] instance, for example from local files or objects +/// from AWS S3. +/// +/// For example, given the `table1` directory (or object store prefix) +/// +/// ```text +/// table1 +/// ├── file1.parquet +/// └── file2.parquet +/// ``` +/// +/// A `ListingTable` would read the files `file1.parquet` and `file2.parquet` as +/// a single table, merging the schemas if the files have compatible but not +/// identical schemas. +/// +/// Given the `table2` directory (or object store prefix) +/// +/// ```text +/// table2 +/// ├── date=2024-06-01 +/// │ ├── file3.parquet +/// │ └── file4.parquet +/// └── date=2024-06-02 +/// └── file5.parquet +/// ``` /// -/// # Features +/// A `ListingTable` would read the files `file3.parquet`, `file4.parquet`, and +/// `file5.parquet` as a single table, again merging schemas if necessary. /// -/// 1. Merges schemas if the files have compatible but not identical schemas +/// Given the hive style partitioning structure (e.g,. directories named +/// `date=2024-06-01` and `date=2026-06-02`), `ListingTable` also adds a `date` +/// column when reading the table: +/// * The files in `table2/date=2024-06-01` will have the value `2024-06-01` +/// * The files in `table2/date=2024-06-02` will have the value `2024-06-02`. /// -/// 2. Hive-style partitioning support, where a path such as -/// `/files/date=1/1/2022/data.parquet` is injected as a `date` column. +/// If the query has a predicate like `WHERE date = '2024-06-01'` +/// only the corresponding directory will be read. /// -/// 3. Projection pushdown for formats that support it such as such as -/// Parquet +/// `ListingTable` also supports filter and projection pushdown for formats that +/// support it as such as Parquet. /// /// # Example /// @@ -612,6 +679,7 @@ impl ListingOptions { /// # Ok(()) /// # } /// ``` +#[derive(Debug)] pub struct ListingTable { table_paths: Vec, /// File fields only @@ -651,10 +719,16 @@ impl ListingTable { builder.push(Field::new(part_col_name, part_col_type.clone(), false)); } + let table_schema = Arc::new( + builder + .finish() + .with_metadata(file_schema.metadata().clone()), + ); + let table = Self { table_paths: config.table_paths, file_schema, - table_schema: Arc::new(builder.finish()), + table_schema, options, definition: None, collected_statistics: Arc::new(DefaultFileStatisticsCache::default()), @@ -693,8 +767,8 @@ impl ListingTable { } /// Specify the SQL definition for this table, if any - pub fn with_definition(mut self, defintion: Option) -> Self { - self.definition = defintion; + pub fn with_definition(mut self, definition: Option) -> Self { + self.definition = definition; self } @@ -714,6 +788,16 @@ impl ListingTable { } } +// Expressions can be used for parttion pruning if they can be evaluated using +// only the partiton columns and there are partition columns. +fn can_be_evaluted_for_partition_pruning( + partition_column_names: &[&str], + expr: &Expr, +) -> bool { + !partition_column_names.is_empty() + && expr_applicable_for_cols(partition_column_names, expr) +} + #[async_trait] impl TableProvider for ListingTable { fn as_any(&self) -> &dyn Any { @@ -734,21 +818,11 @@ impl TableProvider for ListingTable { async fn scan( &self, - state: &SessionState, + state: &dyn Session, projection: Option<&Vec>, filters: &[Expr], limit: Option, ) -> Result> { - let (partitioned_file_lists, statistics) = - self.list_files_for_scan(state, filters, limit).await?; - - // if no files need to be read, return an `EmptyExec` - if partitioned_file_lists.is_empty() { - let schema = self.schema(); - let projected_schema = project_schema(&schema, projection)?; - return Ok(Arc::new(EmptyExec::new(projected_schema))); - } - // extract types of partition columns let table_partition_cols = self .options @@ -757,36 +831,86 @@ impl TableProvider for ListingTable { .map(|col| Ok(self.table_schema.field_with_name(&col.0)?.clone())) .collect::>>()?; - let filters = if let Some(expr) = conjunction(filters.to_vec()) { - // NOTE: Use the table schema (NOT file schema) here because `expr` may contain references to partition columns. - let table_df_schema = self.table_schema.as_ref().clone().to_dfschema()?; - let filters = - create_physical_expr(&expr, &table_df_schema, state.execution_props())?; - Some(filters) - } else { - None + let table_partition_col_names = table_partition_cols + .iter() + .map(|field| field.name().as_str()) + .collect::>(); + // If the filters can be resolved using only partition cols, there is no need to + // pushdown it to TableScan, otherwise, `unhandled` pruning predicates will be generated + let (partition_filters, filters): (Vec<_>, Vec<_>) = + filters.iter().cloned().partition(|filter| { + can_be_evaluted_for_partition_pruning(&table_partition_col_names, filter) + }); + // TODO (https://github.com/apache/datafusion/issues/11600) remove downcast_ref from here? + let session_state = state.as_any().downcast_ref::().unwrap(); + let (mut partitioned_file_lists, statistics) = self + .list_files_for_scan(session_state, &partition_filters, limit) + .await?; + + // if no files need to be read, return an `EmptyExec` + if partitioned_file_lists.is_empty() { + let projected_schema = project_schema(&self.schema(), projection)?; + return Ok(Arc::new(EmptyExec::new(projected_schema))); + } + + let output_ordering = self.try_create_output_ordering()?; + match state + .config_options() + .execution + .split_file_groups_by_statistics + .then(|| { + output_ordering.first().map(|output_ordering| { + FileScanConfig::split_groups_by_statistics( + &self.table_schema, + &partitioned_file_lists, + output_ordering, + ) + }) + }) + .flatten() + { + Some(Err(e)) => log::debug!("failed to split file groups by statistics: {e}"), + Some(Ok(new_groups)) => { + if new_groups.len() <= self.options.target_partitions { + partitioned_file_lists = new_groups; + } else { + log::debug!("attempted to split file groups by statistics, but there were more file groups than target_partitions; falling back to unordered") + } + } + None => {} // no ordering required }; - let object_store_url = if let Some(url) = self.table_paths.first() { - url.object_store() - } else { + let filters = conjunction(filters.to_vec()) + .map(|expr| -> Result<_> { + // NOTE: Use the table schema (NOT file schema) here because `expr` may contain references to partition columns. + let table_df_schema = self.table_schema.as_ref().clone().to_dfschema()?; + let filters = create_physical_expr( + &expr, + &table_df_schema, + state.execution_props(), + )?; + Ok(Some(filters)) + }) + .unwrap_or(Ok(None))?; + + let Some(object_store_url) = + self.table_paths.first().map(ListingTableUrl::object_store) + else { return Ok(Arc::new(EmptyExec::new(Arc::new(Schema::empty())))); }; + // create the execution plan self.options .format .create_physical_plan( - state, - FileScanConfig { - object_store_url, - file_schema: Arc::clone(&self.file_schema), - file_groups: partitioned_file_lists, - statistics, - projection: projection.cloned(), - limit, - output_ordering: self.try_create_output_ordering()?, - table_partition_cols, - }, + session_state, + FileScanConfig::new(object_store_url, Arc::clone(&self.file_schema)) + .with_file_groups(partitioned_file_lists) + .with_statistics(statistics) + .with_projection(projection.cloned()) + .with_limit(limit) + .with_output_ordering(output_ordering) + .with_table_partition_cols(table_partition_cols), filters.as_ref(), ) .await @@ -796,28 +920,36 @@ impl TableProvider for ListingTable { &self, filters: &[&Expr], ) -> Result> { - let support: Vec<_> = filters + let partition_column_names = self + .options + .table_partition_cols + .iter() + .map(|col| col.0.as_str()) + .collect::>(); + filters .iter() .map(|filter| { - if expr_applicable_for_cols( - &self - .options - .table_partition_cols - .iter() - .map(|x| x.0.clone()) - .collect::>(), - filter, - ) { + if can_be_evaluted_for_partition_pruning(&partition_column_names, filter) + { // if filter can be handled by partition pruning, it is exact - TableProviderFilterPushDown::Exact - } else { - // otherwise, we still might be able to handle the filter with file - // level mechanisms such as Parquet row group pruning. - TableProviderFilterPushDown::Inexact + return Ok(TableProviderFilterPushDown::Exact); + } + + // if we can't push it down completely with only the filename-based/path-based + // column names, then we should check if we can do parquet predicate pushdown + let supports_pushdown = self.options.format.supports_filters_pushdown( + &self.file_schema, + &self.table_schema, + &[filter], + )?; + + if supports_pushdown == FilePushdownSupport::Supported { + return Ok(TableProviderFilterPushDown::Exact); } + + Ok(TableProviderFilterPushDown::Inexact) }) - .collect(); - Ok(support) + .collect() } fn get_table_definition(&self) -> Option<&str> { @@ -826,18 +958,30 @@ impl TableProvider for ListingTable { async fn insert_into( &self, - state: &SessionState, + state: &dyn Session, input: Arc, - overwrite: bool, + insert_op: InsertOp, ) -> Result> { // Check that the schema of the plan matches the schema of this table. if !self .schema() .logically_equivalent_names_and_types(&input.schema()) { + // Return an error if schema of the input query does not match with the table schema. return plan_err!( - // Return an error if schema of the input query does not match with the table schema. - "Inserting query must have the same schema with the table." + "Inserting query must have the same schema with the table. \ + Expected: {:?}, got: {:?}", + self.schema() + .fields() + .iter() + .map(|field| field.data_type()) + .collect::>(), + input + .schema() + .fields() + .iter() + .map(|field| field.data_type()) + .collect::>() ); } @@ -852,8 +996,10 @@ impl TableProvider for ListingTable { // Get the object store for the table path. let store = state.runtime_env().object_store(table_path)?; + // TODO (https://github.com/apache/datafusion/issues/11600) remove downcast_ref from here? + let session_state = state.as_any().downcast_ref::().unwrap(); let file_list_stream = pruned_partition_list( - state, + session_state, store.as_ref(), table_path, &[], @@ -863,6 +1009,8 @@ impl TableProvider for ListingTable { .await?; let file_groups = file_list_stream.try_collect::>().await?; + let keep_partition_by_columns = + state.config_options().execution.keep_partition_by_columns; // Sink related option, apart from format let config = FileSinkConfig { @@ -871,11 +1019,11 @@ impl TableProvider for ListingTable { file_groups, output_schema: self.schema(), table_partition_cols: self.options.table_partition_cols.clone(), - overwrite, + insert_op, + keep_partition_by_columns, }; - let unsorted: Vec> = vec![]; - let order_requirements = if self.options().file_sort_order != unsorted { + let order_requirements = if !self.options().file_sort_order.is_empty() { // Multiple sort orders in outer vec are equivalent, so we pass only the first one let ordering = self .try_create_output_ordering()? @@ -885,19 +1033,19 @@ impl TableProvider for ListingTable { ))? .clone(); // Converts Vec> into type required by execution plan to specify its required input ordering - Some( + Some(LexRequirement::new( ordering .into_iter() .map(PhysicalSortRequirement::from) .collect::>(), - ) + )) } else { None }; self.options() .format - .create_writer_physical_plan(input, state, config, order_requirements) + .create_writer_physical_plan(input, session_state, config, order_requirements) .await } @@ -941,10 +1089,12 @@ impl ListingTable { if self.options.collect_stat { let statistics = self.do_collect_statistics(ctx, &store, &part_file).await?; - Ok((part_file, statistics)) as Result<(PartitionedFile, Statistics)> + Ok((part_file, statistics)) } else { - Ok((part_file, Statistics::new_unknown(&self.file_schema))) - as Result<(PartitionedFile, Statistics)> + Ok(( + part_file, + Arc::new(Statistics::new_unknown(&self.file_schema)), + )) } }) .boxed() @@ -974,12 +1124,12 @@ impl ListingTable { ctx: &SessionState, store: &Arc, part_file: &PartitionedFile, - ) -> Result { - let statistics_cache = self.collected_statistics.clone(); - return match statistics_cache + ) -> Result> { + match self + .collected_statistics .get_with_extra(&part_file.object_meta.location, &part_file.object_meta) { - Some(statistics) => Ok(statistics.as_ref().clone()), + Some(statistics) => Ok(statistics), None => { let statistics = self .options @@ -991,35 +1141,39 @@ impl ListingTable { &part_file.object_meta, ) .await?; - statistics_cache.put_with_extra( + let statistics = Arc::new(statistics); + self.collected_statistics.put_with_extra( &part_file.object_meta.location, - statistics.clone().into(), + statistics.clone(), &part_file.object_meta, ); Ok(statistics) } - }; + } } } #[cfg(test)] mod tests { - use super::*; + use crate::datasource::file_format::avro::AvroFormat; + use crate::datasource::file_format::csv::CsvFormat; + use crate::datasource::file_format::json::JsonFormat; #[cfg(feature = "parquet")] + use crate::datasource::file_format::parquet::ParquetFormat; use crate::datasource::{provider_as_source, MemTable}; use crate::execution::options::ArrowReadOptions; - use crate::physical_plan::collect; use crate::prelude::*; use crate::{ assert_batches_eq, test::{columns, object_store::register_test_store}, }; + use datafusion_physical_plan::collect; use arrow::record_batch::RecordBatch; use arrow_schema::SortOptions; use datafusion_common::stats::Precision; - use datafusion_common::{assert_contains, GetExt, ScalarValue}; + use datafusion_common::{assert_contains, ScalarValue}; use datafusion_expr::{BinaryExpr, LogicalPlanBuilder, Operator}; use datafusion_physical_expr::PhysicalSortExpr; use datafusion_physical_plan::ExecutionPlanProperties; @@ -1050,6 +1204,8 @@ mod tests { #[cfg(feature = "parquet")] #[tokio::test] async fn load_table_stats_by_default() -> Result<()> { + use crate::datasource::file_format::parquet::ParquetFormat; + let testdata = crate::test_util::parquet_test_data(); let filename = format!("{}/{}", testdata, "alltypes_plain.parquet"); let table_path = ListingTableUrl::parse(filename).unwrap(); @@ -1074,6 +1230,8 @@ mod tests { #[cfg(feature = "parquet")] #[tokio::test] async fn load_table_stats_when_no_stats() -> Result<()> { + use crate::datasource::file_format::parquet::ParquetFormat; + let testdata = crate::test_util::parquet_test_data(); let filename = format!("{}/{}", testdata, "alltypes_plain.parquet"); let table_path = ListingTableUrl::parse(filename).unwrap(); @@ -1108,17 +1266,13 @@ mod tests { let options = ListingOptions::new(Arc::new(ParquetFormat::default())); let schema = options.infer_schema(&state, &table_path).await.unwrap(); - use crate::physical_plan::expressions::col as physical_col; + use crate::datasource::file_format::parquet::ParquetFormat; + use datafusion_physical_plan::expressions::col as physical_col; use std::ops::Add; // (file_sort_order, expected_result) let cases = vec![ (vec![], Ok(vec![])), - // not a sort expr - ( - vec![vec![col("string_col")]], - Err("Expected Expr::Sort in output_ordering, but got string_col"), - ), // sort expr, but non column ( vec![vec![ @@ -1129,13 +1283,16 @@ mod tests { // ok with one column ( vec![vec![col("string_col").sort(true, false)]], - Ok(vec![vec![PhysicalSortExpr { - expr: physical_col("string_col", &schema).unwrap(), - options: SortOptions { - descending: false, - nulls_first: false, - }, - }]]) + Ok(vec![LexOrdering { + inner: vec![PhysicalSortExpr { + expr: physical_col("string_col", &schema).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }], + } + ]) ), // ok with two columns, different options ( @@ -1143,22 +1300,17 @@ mod tests { col("string_col").sort(true, false), col("int_col").sort(false, true), ]], - Ok(vec![vec![ - PhysicalSortExpr { - expr: physical_col("string_col", &schema).unwrap(), - options: SortOptions { - descending: false, - nulls_first: false, - }, - }, - PhysicalSortExpr { - expr: physical_col("int_col", &schema).unwrap(), - options: SortOptions { - descending: true, - nulls_first: true, - }, - }, - ]]) + Ok(vec![LexOrdering { + inner: vec![ + PhysicalSortExpr::new_default(physical_col("string_col", &schema).unwrap()) + .asc() + .nulls_last(), + PhysicalSortExpr::new_default(physical_col("int_col", &schema).unwrap()) + .desc() + .nulls_first() + ], + } + ]) ), ]; @@ -1199,7 +1351,7 @@ mod tests { register_test_store(&ctx, &[(&path, 100)]); let opt = ListingOptions::new(Arc::new(AvroFormat {})) - .with_file_extension(FileType::AVRO.get_ext()) + .with_file_extension(AvroFormat.get_ext()) .with_table_partition_cols(vec![(String::from("p1"), DataType::Utf8)]) .with_target_partitions(4); @@ -1247,6 +1399,7 @@ mod tests { "test:///bucket/key-prefix/", 12, 5, + Some(""), ) .await?; @@ -1261,6 +1414,7 @@ mod tests { "test:///bucket/key-prefix/", 4, 4, + Some(""), ) .await?; @@ -1276,12 +1430,19 @@ mod tests { "test:///bucket/key-prefix/", 2, 2, + Some(""), ) .await?; // no files => no groups - assert_list_files_for_scan_grouping(&[], "test:///bucket/key-prefix/", 2, 0) - .await?; + assert_list_files_for_scan_grouping( + &[], + "test:///bucket/key-prefix/", + 2, + 0, + Some(""), + ) + .await?; // files that don't match the prefix assert_list_files_for_scan_grouping( @@ -1293,6 +1454,21 @@ mod tests { "test:///bucket/key-prefix/", 10, 2, + Some(""), + ) + .await?; + + // files that don't match the prefix or the default file extention + assert_list_files_for_scan_grouping( + &[ + "bucket/key-prefix/file0.avro", + "bucket/key-prefix/file1.parquet", + "bucket/other-prefix/roguefile.avro", + ], + "test:///bucket/key-prefix/", + 10, + 1, + None, ) .await?; Ok(()) @@ -1313,6 +1489,7 @@ mod tests { &["test:///bucket/key1/", "test:///bucket/key2/"], 12, 5, + Some(""), ) .await?; @@ -1329,6 +1506,7 @@ mod tests { &["test:///bucket/key1/", "test:///bucket/key2/"], 5, 5, + Some(""), ) .await?; @@ -1345,11 +1523,13 @@ mod tests { &["test:///bucket/key1/"], 2, 2, + Some(""), ) .await?; // no files => no groups - assert_list_files_for_multi_paths(&[], &["test:///bucket/key1/"], 2, 0).await?; + assert_list_files_for_multi_paths(&[], &["test:///bucket/key1/"], 2, 0, Some("")) + .await?; // files that don't match the prefix assert_list_files_for_multi_paths( @@ -1364,6 +1544,24 @@ mod tests { &["test:///bucket/key3/"], 2, 1, + Some(""), + ) + .await?; + + // files that don't match the prefix or the default file ext + assert_list_files_for_multi_paths( + &[ + "bucket/key1/file0.avro", + "bucket/key1/file1.csv", + "bucket/key1/file2.avro", + "bucket/key2/file3.csv", + "bucket/key2/file4.avro", + "bucket/key3/file5.csv", + ], + &["test:///bucket/key1/", "test:///bucket/key3/"], + 2, + 2, + None, ) .await?; Ok(()) @@ -1391,6 +1589,7 @@ mod tests { table_prefix: &str, target_partitions: usize, output_partitioning: usize, + file_ext: Option<&str>, ) -> Result<()> { let ctx = SessionContext::new(); register_test_store(&ctx, &files.iter().map(|f| (*f, 10)).collect::>()); @@ -1398,7 +1597,7 @@ mod tests { let format = AvroFormat {}; let opt = ListingOptions::new(Arc::new(format)) - .with_file_extension("") + .with_file_extension_opt(file_ext) .with_target_partitions(target_partitions); let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); @@ -1424,6 +1623,7 @@ mod tests { table_prefix: &[&str], target_partitions: usize, output_partitioning: usize, + file_ext: Option<&str>, ) -> Result<()> { let ctx = SessionContext::new(); register_test_store(&ctx, &files.iter().map(|f| (*f, 10)).collect::>()); @@ -1431,7 +1631,7 @@ mod tests { let format = AvroFormat {}; let opt = ListingOptions::new(Arc::new(format)) - .with_file_extension("") + .with_file_extension_opt(file_ext) .with_target_partitions(target_partitions); let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); @@ -1462,7 +1662,7 @@ mod tests { "10".into(), ); helper_test_append_new_files_to_table( - FileType::JSON, + JsonFormat::default().get_ext(), FileCompressionType::UNCOMPRESSED, Some(config_map), 2, @@ -1480,7 +1680,7 @@ mod tests { "10".into(), ); helper_test_append_new_files_to_table( - FileType::CSV, + CsvFormat::default().get_ext(), FileCompressionType::UNCOMPRESSED, Some(config_map), 2, @@ -1498,7 +1698,7 @@ mod tests { "10".into(), ); helper_test_append_new_files_to_table( - FileType::PARQUET, + ParquetFormat::default().get_ext(), FileCompressionType::UNCOMPRESSED, Some(config_map), 2, @@ -1516,7 +1716,7 @@ mod tests { "20".into(), ); helper_test_append_new_files_to_table( - FileType::PARQUET, + ParquetFormat::default().get_ext(), FileCompressionType::UNCOMPRESSED, Some(config_map), 1, @@ -1537,8 +1737,8 @@ mod tests { helper_test_insert_into_sql( "csv", FileCompressionType::UNCOMPRESSED, - "WITH HEADER ROW", - None, + "", + Some(HashMap::from([("has_header".into(), "true".into())])), ) .await?; Ok(()) @@ -1579,7 +1779,7 @@ mod tests { "100".into(), ); config_map.insert( - "datafusion.execution.parquet.staistics_enabled".into(), + "datafusion.execution.parquet.statistics_enabled".into(), "none".into(), ); config_map.insert( @@ -1603,7 +1803,7 @@ mod tests { "50".into(), ); config_map.insert( - "datafusion.execution.parquet.bloom_filter_enabled".into(), + "datafusion.execution.parquet.bloom_filter_on_write".into(), "true".into(), ); config_map.insert( @@ -1653,7 +1853,7 @@ mod tests { "100".into(), ); config_map.insert( - "datafusion.execution.parquet.staistics_enabled".into(), + "datafusion.execution.parquet.statistics_enabled".into(), "none".into(), ); config_map.insert( @@ -1681,7 +1881,7 @@ mod tests { "delta_binary_packed".into(), ); config_map.insert( - "datafusion.execution.parquet.bloom_filter_enabled".into(), + "datafusion.execution.parquet.bloom_filter_on_write".into(), "true".into(), ); config_map.insert( @@ -1702,7 +1902,7 @@ mod tests { ); config_map.insert("datafusion.execution.batch_size".into(), "1".into()); helper_test_append_new_files_to_table( - FileType::PARQUET, + ParquetFormat::default().get_ext(), FileCompressionType::UNCOMPRESSED, Some(config_map), 2, @@ -1720,7 +1920,7 @@ mod tests { "zstd".into(), ); let e = helper_test_append_new_files_to_table( - FileType::PARQUET, + ParquetFormat::default().get_ext(), FileCompressionType::UNCOMPRESSED, Some(config_map), 2, @@ -1733,7 +1933,7 @@ mod tests { } async fn helper_test_append_new_files_to_table( - file_type: FileType, + file_type_ext: String, file_compression_type: FileCompressionType, session_config_map: Option>, expected_n_files_per_insert: usize, @@ -1741,7 +1941,7 @@ mod tests { // Create the initial context, schema, and batch. let session_ctx = match session_config_map { Some(cfg) => { - let config = SessionConfig::from_string_hash_map(cfg)?; + let config = SessionConfig::from_string_hash_map(&cfg)?; SessionContext::new_with_config(config) } None => SessionContext::new(), @@ -1770,8 +1970,8 @@ mod tests { // Register appropriate table depending on file_type we want to test let tmp_dir = TempDir::new()?; - match file_type { - FileType::CSV => { + match file_type_ext.as_str() { + "csv" => { session_ctx .register_csv( "t", @@ -1782,7 +1982,7 @@ mod tests { ) .await?; } - FileType::JSON => { + "json" => { session_ctx .register_json( "t", @@ -1793,7 +1993,7 @@ mod tests { ) .await?; } - FileType::PARQUET => { + "parquet" => { session_ctx .register_parquet( "t", @@ -1802,7 +2002,7 @@ mod tests { ) .await?; } - FileType::AVRO => { + "avro" => { session_ctx .register_avro( "t", @@ -1811,7 +2011,7 @@ mod tests { ) .await?; } - FileType::ARROW => { + "arrow" => { session_ctx .register_arrow( "t", @@ -1820,6 +2020,7 @@ mod tests { ) .await?; } + _ => panic!("Unrecognized file extension {file_type_ext}"), } // Create and register the source table with the provided schema and inserted data @@ -1838,7 +2039,8 @@ mod tests { // Therefore, we will have 8 partitions in the final plan. // Create an insert plan to insert the source data into the initial table let insert_into_table = - LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, false)?.build()?; + LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, InsertOp::Append)? + .build()?; // Create a physical plan from the insert plan let plan = session_ctx .state() @@ -1938,7 +2140,7 @@ mod tests { // Create the initial context let session_ctx = match session_config_map { Some(cfg) => { - let config = SessionConfig::from_string_hash_map(cfg)?; + let config = SessionConfig::from_string_hash_map(&cfg)?; SessionContext::new_with_config(config) } None => SessionContext::new(), diff --git a/datafusion/core/src/datasource/listing/url.rs b/datafusion/core/src/datasource/listing/url.rs index 82acb7a3b644..e627cacfbfc7 100644 --- a/datafusion/core/src/datasource/listing/url.rs +++ b/datafusion/core/src/datasource/listing/url.rs @@ -15,9 +15,9 @@ // specific language governing permissions and limitations // under the License. -use crate::datasource::object_store::ObjectStoreUrl; use crate::execution::context::SessionState; use datafusion_common::{DataFusionError, Result}; +use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_optimizer::OptimizerConfig; use futures::stream::BoxStream; use futures::{StreamExt, TryStreamExt}; @@ -53,7 +53,7 @@ impl ListingTableUrl { /// subdirectories. /// /// Similarly `s3://BUCKET/blob.csv` refers to `blob.csv` in the S3 bucket `BUCKET`, - /// wherease `s3://BUCKET/foo/` refers to all objects with the prefix `foo/` in the + /// whereas `s3://BUCKET/foo/` refers to all objects with the prefix `foo/` in the /// S3 bucket `BUCKET` /// /// # URL Encoding @@ -187,7 +187,20 @@ impl ListingTableUrl { /// Returns `true` if `path` refers to a collection of objects pub fn is_collection(&self) -> bool { - self.url.as_str().ends_with(DELIMITER) + self.url.path().ends_with(DELIMITER) + } + + /// Returns the file extension of the last path segment if it exists + pub fn file_extension(&self) -> Option<&str> { + if let Some(segments) = self.url.path_segments() { + if let Some(last_segment) = segments.last() { + if last_segment.contains(".") && !last_segment.ends_with(".") { + return last_segment.split('.').last(); + } + } + } + + None } /// Strips the prefix of this [`ListingTableUrl`] from the provided path, returning @@ -463,4 +476,84 @@ mod tests { Some(("/a/b/c//", "alltypes_plain*.parquet")), ); } + + #[test] + fn test_is_collection() { + fn test(input: &str, expected: bool, message: &str) { + let url = ListingTableUrl::parse(input).unwrap(); + assert_eq!(url.is_collection(), expected, "{message}"); + } + + test("https://a.b.c/path/", true, "path ends with / - collection"); + test( + "https://a.b.c/path/?a=b", + true, + "path ends with / - with query args - collection", + ); + test( + "https://a.b.c/path?a=b/", + false, + "path not ends with / - query ends with / - not collection", + ); + test( + "https://a.b.c/path/#a=b", + true, + "path ends with / - with fragment - collection", + ); + test( + "https://a.b.c/path#a=b/", + false, + "path not ends with / - fragment ends with / - not collection", + ); + } + + #[test] + fn test_file_extension() { + fn test(input: &str, expected: Option<&str>, message: &str) { + let url = ListingTableUrl::parse(input).unwrap(); + assert_eq!(url.file_extension(), expected, "{message}"); + } + + test("https://a.b.c/path/", None, "path ends with / - not a file"); + test( + "https://a.b.c/path/?a=b", + None, + "path ends with / - with query args - not a file", + ); + test( + "https://a.b.c/path?a=b/", + None, + "path not ends with / - query ends with / but no file extension", + ); + test( + "https://a.b.c/path/#a=b", + None, + "path ends with / - with fragment - not a file", + ); + test( + "https://a.b.c/path#a=b/", + None, + "path not ends with / - fragment ends with / but no file extension", + ); + test( + "file///some/path/", + None, + "file path ends with / - not a file", + ); + test( + "file///some/path/file", + None, + "file path does not end with - no extension", + ); + test( + "file///some/path/file.", + None, + "file path ends with . - no value after .", + ); + test( + "file///some/path/file.ext", + Some("ext"), + "file path ends with .ext - extension is ext", + ); + } } diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index 1a0eb34d1234..581d88d25884 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -17,27 +17,23 @@ //! Factory for creating ListingTables with default options +use std::collections::HashSet; use std::path::Path; -use std::str::FromStr; use std::sync::Arc; -#[cfg(feature = "parquet")] -use crate::datasource::file_format::parquet::ParquetFormat; -use crate::datasource::file_format::{ - arrow::ArrowFormat, avro::AvroFormat, csv::CsvFormat, json::JsonFormat, FileFormat, -}; +use crate::catalog::{TableProvider, TableProviderFactory}; use crate::datasource::listing::{ ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl, }; -use crate::datasource::provider::TableProviderFactory; -use crate::datasource::TableProvider; use crate::execution::context::SessionState; use arrow::datatypes::{DataType, SchemaRef}; -use datafusion_common::{arrow_datafusion_err, DataFusionError, FileType}; +use datafusion_common::{arrow_datafusion_err, plan_err, DataFusionError, ToDFSchema}; +use datafusion_common::{config_datafusion_err, Result}; use datafusion_expr::CreateExternalTable; use async_trait::async_trait; +use datafusion_catalog::Session; /// A `TableProviderFactory` capable of creating new `ListingTable`s #[derive(Debug, Default)] @@ -54,36 +50,20 @@ impl ListingTableFactory { impl TableProviderFactory for ListingTableFactory { async fn create( &self, - state: &SessionState, + state: &dyn Session, cmd: &CreateExternalTable, - ) -> datafusion_common::Result> { - let mut table_options = state.default_table_options(); - let file_type = FileType::from_str(cmd.file_type.as_str()).map_err(|_| { - DataFusionError::Execution(format!("Unknown FileType {}", cmd.file_type)) - })?; - table_options.set_file_format(file_type.clone()); - table_options.alter_with_string_hash_map(&cmd.options)?; + ) -> Result> { + // TODO (https://github.com/apache/datafusion/issues/11600) remove downcast_ref from here. Should file format factory be an extension to session state? + let session_state = state.as_any().downcast_ref::().unwrap(); + let file_format = session_state + .get_file_format_factory(cmd.file_type.as_str()) + .ok_or(config_datafusion_err!( + "Unable to create table with format {}! Could not find FileFormat.", + cmd.file_type + ))? + .create(session_state, &cmd.options)?; + let file_extension = get_extension(cmd.location.as_str()); - let file_format: Arc = match file_type { - FileType::CSV => { - let mut csv_options = table_options.csv; - csv_options.has_header = cmd.has_header; - csv_options.delimiter = cmd.delimiter as u8; - csv_options.compression = cmd.file_compression_type; - Arc::new(CsvFormat::default().with_options(csv_options)) - } - #[cfg(feature = "parquet")] - FileType::PARQUET => { - Arc::new(ParquetFormat::default().with_options(table_options.parquet)) - } - FileType::AVRO => Arc::new(AvroFormat), - FileType::JSON => { - let mut json_options = table_options.json; - json_options.compression = cmd.file_compression_type; - Arc::new(JsonFormat::default().with_options(json_options)) - } - FileType::ARROW => Arc::new(ArrowFormat), - }; let (provided_schema, table_partition_cols) = if cmd.schema.fields().is_empty() { ( @@ -111,7 +91,7 @@ impl TableProviderFactory for ListingTableFactory { .field_with_name(col) .map_err(|e| arrow_datafusion_err!(e)) }) - .collect::>>()? + .collect::>>()? .into_iter() .map(|f| (f.name().to_owned(), f.data_type().to_owned())) .collect(); @@ -134,17 +114,39 @@ impl TableProviderFactory for ListingTableFactory { .with_collect_stat(state.config().collect_statistics()) .with_file_extension(file_extension) .with_target_partitions(state.config().target_partitions()) - .with_table_partition_cols(table_partition_cols) - .with_file_sort_order(cmd.order_exprs.clone()); + .with_table_partition_cols(table_partition_cols); - options.validate_partitions(state, &table_path).await?; + options + .validate_partitions(session_state, &table_path) + .await?; let resolved_schema = match provided_schema { - None => options.infer_schema(state, &table_path).await?, + // We will need to check the table columns against the schema + // this is done so that we can do an ORDER BY for external table creation + // specifically for parquet file format. + // See: https://github.com/apache/datafusion/issues/7317 + None => { + let schema = options.infer_schema(session_state, &table_path).await?; + let df_schema = schema.clone().to_dfschema()?; + let column_refs: HashSet<_> = cmd + .order_exprs + .iter() + .flat_map(|sort| sort.iter()) + .flat_map(|s| s.expr.column_refs()) + .collect(); + + for column in &column_refs { + if !df_schema.has_column(column) { + return plan_err!("Column {column} is not in schema"); + } + } + + schema + } Some(s) => s, }; let config = ListingTableConfig::new(table_path) - .with_listing_options(options) + .with_listing_options(options.with_file_sort_order(cmd.order_exprs.clone())) .with_schema(resolved_schema); let provider = ListingTable::try_new(config)? .with_cache(state.runtime_env().cache_manager.get_file_statistic_cache()); @@ -170,9 +172,10 @@ mod tests { use std::collections::HashMap; use super::*; - use crate::execution::context::SessionContext; + use crate::{ + datasource::file_format::csv::CsvFormat, execution::context::SessionContext, + }; - use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{Constraints, DFSchema, TableReference}; #[tokio::test] @@ -191,16 +194,14 @@ mod tests { name, location: csv_file.path().to_str().unwrap().to_string(), file_type: "csv".to_string(), - has_header: true, - delimiter: ',', schema: Arc::new(DFSchema::empty()), table_partition_cols: vec![], if_not_exists: false, - file_compression_type: CompressionTypeVariant::UNCOMPRESSED, + temporary: false, definition: None, order_exprs: vec![], unbounded: false, - options: HashMap::new(), + options: HashMap::from([("format.has_header".into(), "true".into())]), constraints: Constraints::empty(), column_defaults: HashMap::new(), }; @@ -228,16 +229,15 @@ mod tests { let mut options = HashMap::new(); options.insert("format.schema_infer_max_rec".to_owned(), "1000".to_owned()); + options.insert("format.has_header".into(), "true".into()); let cmd = CreateExternalTable { name, location: csv_file.path().to_str().unwrap().to_string(), file_type: "csv".to_string(), - has_header: true, - delimiter: ',', schema: Arc::new(DFSchema::empty()), table_partition_cols: vec![], if_not_exists: false, - file_compression_type: CompressionTypeVariant::UNCOMPRESSED, + temporary: false, definition: None, order_exprs: vec![], unbounded: false, diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index aab42285a0b2..3c2d1b0205d6 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -37,8 +37,11 @@ use crate::physical_planner::create_physical_sort_exprs; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; +use datafusion_catalog::Session; use datafusion_common::{not_impl_err, plan_err, Constraints, DFSchema, SchemaExt}; use datafusion_execution::TaskContext; +use datafusion_expr::dml::InsertOp; +use datafusion_expr::SortExpr; use datafusion_physical_plan::metrics::MetricsSet; use async_trait::async_trait; @@ -63,7 +66,7 @@ pub struct MemTable { column_defaults: HashMap, /// Optional pre-known sort order(s). Must be `SortExpr`s. /// inserting data into this table removes the order - pub sort_order: Arc>>>, + pub sort_order: Arc>>>, } impl MemTable { @@ -117,7 +120,7 @@ impl MemTable { /// /// Note that multiple sort orders are supported, if some are known to be /// equivalent, - pub fn with_sort_order(self, mut sort_order: Vec>) -> Self { + pub fn with_sort_order(self, mut sort_order: Vec>) -> Self { std::mem::swap(self.sort_order.lock().as_mut(), &mut sort_order); self } @@ -206,7 +209,7 @@ impl TableProvider for MemTable { async fn scan( &self, - state: &SessionState, + state: &dyn Session, projection: Option<&Vec>, _filters: &[Expr], _limit: Option, @@ -238,7 +241,7 @@ impl TableProvider for MemTable { ) }) .collect::>>()?; - exec = exec.with_sort_information(file_sort_order); + exec = exec.try_with_sort_information(file_sort_order)?; } Ok(Arc::new(exec)) @@ -258,9 +261,9 @@ impl TableProvider for MemTable { /// * A plan that returns the number of rows written. async fn insert_into( &self, - _state: &SessionState, + _state: &dyn Session, input: Arc, - overwrite: bool, + insert_op: InsertOp, ) -> Result> { // If we are inserting into the table, any sort order may be messed up so reset it here *self.sort_order.lock() = vec![]; @@ -272,11 +275,23 @@ impl TableProvider for MemTable { .logically_equivalent_names_and_types(&input.schema()) { return plan_err!( - "Inserting query must have the same schema with the table." + "Inserting query must have the same schema with the table. \ + Expected: {:?}, got: {:?}", + self.schema() + .fields() + .iter() + .map(|field| field.data_type()) + .collect::>(), + input + .schema() + .fields() + .iter() + .map(|field| field.data_type()) + .collect::>() ); } - if overwrite { - return not_impl_err!("Overwrite not implemented for MemoryTable yet"); + if insert_op != InsertOp::Append { + return not_impl_err!("{insert_op} not implemented for MemoryTable yet"); } let sink = Arc::new(MemSink::new(self.batches.clone())); Ok(Arc::new(DataSinkExec::new( @@ -624,7 +639,8 @@ mod tests { let scan_plan = LogicalPlanBuilder::scan("source", source, None)?.build()?; // Create an insert plan to insert the source data into the initial table let insert_into_table = - LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, false)?.build()?; + LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, InsertOp::Append)? + .build()?; // Create a physical plan from the insert plan let plan = session_ctx .state() @@ -644,7 +660,7 @@ mod tests { Ok(partitions) } - /// Returns the value of results. For example, returns 6 given the follwing + /// Returns the value of results. For example, returns 6 given the following /// /// ```text /// +-------+, diff --git a/datafusion/core/src/datasource/mod.rs b/datafusion/core/src/datasource/mod.rs index 351967d35324..ad369b75e130 100644 --- a/datafusion/core/src/datasource/mod.rs +++ b/datafusion/core/src/datasource/mod.rs @@ -22,6 +22,7 @@ pub mod avro_to_arrow; pub mod cte_worktable; pub mod default_table_source; +pub mod dynamic_file; pub mod empty; pub mod file_format; pub mod function; @@ -30,6 +31,7 @@ pub mod listing_table_factory; pub mod memory; pub mod physical_plan; pub mod provider; +pub mod schema_adapter; mod statistics; pub mod stream; pub mod streaming; @@ -42,45 +44,46 @@ pub use self::default_table_source::{ provider_as_source, source_as_provider, DefaultTableSource, }; pub use self::memory::MemTable; -pub use self::provider::TableProvider; pub use self::view::ViewTable; +pub use crate::catalog::TableProvider; pub use crate::logical_expr::TableType; pub use statistics::get_statistics_with_limit; use arrow_schema::{Schema, SortOptions}; use datafusion_common::{plan_err, Result}; -use datafusion_expr::Expr; +use datafusion_expr::{Expr, SortExpr}; use datafusion_physical_expr::{expressions, LexOrdering, PhysicalSortExpr}; fn create_ordering( schema: &Schema, - sort_order: &[Vec], + sort_order: &[Vec], ) -> Result> { let mut all_sort_orders = vec![]; for exprs in sort_order { // Construct PhysicalSortExpr objects from Expr objects: - let mut sort_exprs = vec![]; - for expr in exprs { - match expr { - Expr::Sort(sort) => match sort.expr.as_ref() { - Expr::Column(col) => match expressions::col(&col.name, schema) { - Ok(expr) => { - sort_exprs.push(PhysicalSortExpr { - expr, - options: SortOptions { - descending: !sort.asc, - nulls_first: sort.nulls_first, - }, - }); - } - // Cannot find expression in the projected_schema, stop iterating - // since rest of the orderings are violated - Err(_) => break, + let mut sort_exprs = LexOrdering::default(); + for sort in exprs { + match &sort.expr { + Expr::Column(col) => match expressions::col(&col.name, schema) { + Ok(expr) => { + sort_exprs.push(PhysicalSortExpr { + expr, + options: SortOptions { + descending: !sort.asc, + nulls_first: sort.nulls_first, + }, + }); } - expr => return plan_err!("Expected single column references in output_ordering, got {expr}"), + // Cannot find expression in the projected_schema, stop iterating + // since rest of the orderings are violated + Err(_) => break, + }, + expr => { + return plan_err!( + "Expected single column references in output_ordering, got {expr}" + ) } - expr => return plan_err!("Expected Expr::Sort in output_ordering, but got {expr}"), } } if !sort_exprs.is_empty() { diff --git a/datafusion/core/src/datasource/physical_plan/arrow_file.rs b/datafusion/core/src/datasource/physical_plan/arrow_file.rs index 1e8775731015..39625a55ca15 100644 --- a/datafusion/core/src/datasource/physical_plan/arrow_file.rs +++ b/datafusion/core/src/datasource/physical_plan/arrow_file.rs @@ -31,6 +31,7 @@ use crate::physical_plan::{ DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, }; +use arrow::buffer::Buffer; use arrow_ipc::reader::FileDecoder; use arrow_schema::SchemaRef; use datafusion_common::config::ConfigOptions; @@ -134,7 +135,7 @@ impl ExecutionPlan for ArrowExec { &self.cache } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { Vec::new() } @@ -195,6 +196,23 @@ impl ExecutionPlan for ArrowExec { fn statistics(&self) -> Result { Ok(self.projected_statistics.clone()) } + + fn fetch(&self) -> Option { + self.base_config.limit + } + + fn with_fetch(&self, limit: Option) -> Option> { + let new_config = self.base_config.clone().with_limit(limit); + + Some(Arc::new(Self { + base_config: new_config, + projected_statistics: self.projected_statistics.clone(), + projected_schema: self.projected_schema.clone(), + projected_output_ordering: self.projected_output_ordering.clone(), + metrics: self.metrics.clone(), + cache: self.cache.clone(), + })) + } } pub struct ArrowOpener { @@ -283,7 +301,10 @@ impl FileOpener for ArrowOpener { for (dict_block, dict_result) in footer.dictionaries().iter().flatten().zip(dict_results) { - decoder.read_dictionary(dict_block, &dict_result.into())?; + decoder.read_dictionary( + dict_block, + &Buffer::from_bytes(dict_result.into()), + )?; } // filter recordbatches according to range @@ -318,11 +339,12 @@ impl FileOpener for ArrowOpener { .into_iter() .zip(recordbatch_results) .filter_map(move |(block, data)| { - match decoder.read_record_batch(&block, &data.into()) { - Ok(Some(record_batch)) => Some(Ok(record_batch)), - Ok(None) => None, - Err(err) => Some(Err(err)), - } + decoder + .read_record_batch( + &block, + &Buffer::from_bytes(data.into()), + ) + .transpose() }), ) .boxed()) diff --git a/datafusion/core/src/datasource/physical_plan/avro.rs b/datafusion/core/src/datasource/physical_plan/avro.rs index 4e5140e82d3f..ce72c4087424 100644 --- a/datafusion/core/src/datasource/physical_plan/avro.rs +++ b/datafusion/core/src/datasource/physical_plan/avro.rs @@ -111,7 +111,7 @@ impl ExecutionPlan for AvroExec { &self.cache } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { Vec::new() } @@ -164,6 +164,23 @@ impl ExecutionPlan for AvroExec { fn metrics(&self) -> Option { Some(self.metrics.clone_inner()) } + + fn fetch(&self) -> Option { + self.base_config.limit + } + + fn with_fetch(&self, limit: Option) -> Option> { + let new_config = self.base_config.clone().with_limit(limit); + + Some(Arc::new(Self { + base_config: new_config, + projected_statistics: self.projected_statistics.clone(), + projected_schema: self.projected_schema.clone(), + projected_output_ordering: self.projected_output_ordering.clone(), + metrics: self.metrics.clone(), + cache: self.cache.clone(), + })) + } } #[cfg(feature = "avro")] @@ -261,9 +278,7 @@ mod tests { let state = session_ctx.state(); let url = Url::parse("file://").unwrap(); - state - .runtime_env() - .register_object_store(&url, store.clone()); + session_ctx.register_object_store(&url, store.clone()); let testdata = crate::test_util::arrow_test_data(); let filename = format!("{testdata}/avro/alltypes_plain.avro"); @@ -273,16 +288,11 @@ mod tests { .infer_schema(&state, &store, &[meta.clone()]) .await?; - let avro_exec = AvroExec::new(FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: vec![vec![meta.into()]], - statistics: Statistics::new_unknown(&file_schema), - file_schema, - projection: Some(vec![0, 1, 2]), - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - }); + let avro_exec = AvroExec::new( + FileScanConfig::new(ObjectStoreUrl::local_filesystem(), file_schema) + .with_file(meta.into()) + .with_projection(Some(vec![0, 1, 2])), + ); assert_eq!( avro_exec .properties() @@ -350,16 +360,11 @@ mod tests { // Include the missing column in the projection let projection = Some(vec![0, 1, 2, actual_schema.fields().len()]); - let avro_exec = AvroExec::new(FileScanConfig { - object_store_url, - file_groups: vec![vec![meta.into()]], - statistics: Statistics::new_unknown(&file_schema), - file_schema, - projection, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - }); + let avro_exec = AvroExec::new( + FileScanConfig::new(object_store_url, file_schema) + .with_file(meta.into()) + .with_projection(projection), + ); assert_eq!( avro_exec .properties() @@ -424,18 +429,19 @@ mod tests { let mut partitioned_file = PartitionedFile::from(meta); partitioned_file.partition_values = vec![ScalarValue::from("2021-10-26")]; - let avro_exec = AvroExec::new(FileScanConfig { - // select specific columns of the files as well as the partitioning - // column which is supposed to be the last column in the table schema. - projection: Some(vec![0, 1, file_schema.fields().len(), 2]), - object_store_url, - file_groups: vec![vec![partitioned_file]], - statistics: Statistics::new_unknown(&file_schema), - file_schema, - limit: None, - table_partition_cols: vec![Field::new("date", DataType::Utf8, false)], - output_ordering: vec![], - }); + let projection = Some(vec![0, 1, file_schema.fields().len(), 2]); + let avro_exec = AvroExec::new( + FileScanConfig::new(object_store_url, file_schema) + // select specific columns of the files as well as the partitioning + // column which is supposed to be the last column in the table schema. + .with_projection(projection) + .with_file(partitioned_file) + .with_table_partition_cols(vec![Field::new( + "date", + DataType::Utf8, + false, + )]), + ); assert_eq!( avro_exec .properties() diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 879461c2eb1e..5beffc3b0581 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -49,7 +49,27 @@ use object_store::{GetOptions, GetResultPayload, ObjectStore}; use tokio::io::AsyncWriteExt; use tokio::task::JoinSet; -/// Execution plan for scanning a CSV file +/// Execution plan for scanning a CSV file. +/// +/// # Example: create a `CsvExec` +/// ``` +/// # use std::sync::Arc; +/// # use arrow::datatypes::Schema; +/// # use datafusion::datasource::{ +/// # physical_plan::{CsvExec, FileScanConfig}, +/// # listing::PartitionedFile, +/// # }; +/// # use datafusion_execution::object_store::ObjectStoreUrl; +/// # let object_store_url = ObjectStoreUrl::local_filesystem(); +/// # let file_schema = Arc::new(Schema::empty()); +/// // Create a CsvExec for reading the first 100MB of `file1.csv` +/// let file_scan_config = FileScanConfig::new(object_store_url, file_schema) +/// .with_file(PartitionedFile::new("file1.csv", 100*1024*1024)); +/// let exec = CsvExec::builder(file_scan_config) +/// .with_has_header(true) // The file has a header row +/// .with_newlines_in_values(true) // The file contains newlines in values +/// .build(); +/// ``` #[derive(Debug, Clone)] pub struct CsvExec { base_config: FileScanConfig, @@ -57,7 +77,10 @@ pub struct CsvExec { has_header: bool, delimiter: u8, quote: u8, + terminator: Option, escape: Option, + comment: Option, + newlines_in_values: bool, /// Execution metrics metrics: ExecutionPlanMetricsSet, /// Compression type of the file associated with CsvExec @@ -65,35 +88,184 @@ pub struct CsvExec { cache: PlanProperties, } -impl CsvExec { - /// Create a new CSV reader execution plan provided base and specific configurations - pub fn new( - base_config: FileScanConfig, - has_header: bool, - delimiter: u8, - quote: u8, - escape: Option, +/// Builder for [`CsvExec`]. +/// +/// See example on [`CsvExec`]. +#[derive(Debug, Clone)] +pub struct CsvExecBuilder { + file_scan_config: FileScanConfig, + file_compression_type: FileCompressionType, + // TODO: it seems like these format options could be reused across all the various CSV config + has_header: bool, + delimiter: u8, + quote: u8, + terminator: Option, + escape: Option, + comment: Option, + newlines_in_values: bool, +} + +impl CsvExecBuilder { + /// Create a new builder to read the provided file scan configuration. + pub fn new(file_scan_config: FileScanConfig) -> Self { + Self { + file_scan_config, + // TODO: these defaults are duplicated from `CsvOptions` - should they be computed? + has_header: false, + delimiter: b',', + quote: b'"', + terminator: None, + escape: None, + comment: None, + newlines_in_values: false, + file_compression_type: FileCompressionType::UNCOMPRESSED, + } + } + + /// Set whether the first row defines the column names. + /// + /// The default value is `false`. + pub fn with_has_header(mut self, has_header: bool) -> Self { + self.has_header = has_header; + self + } + + /// Set the column delimeter. + /// + /// The default is `,`. + pub fn with_delimeter(mut self, delimiter: u8) -> Self { + self.delimiter = delimiter; + self + } + + /// Set the quote character. + /// + /// The default is `"`. + pub fn with_quote(mut self, quote: u8) -> Self { + self.quote = quote; + self + } + + /// Set the line terminator. If not set, the default is CRLF. + /// + /// The default is None. + pub fn with_terminator(mut self, terminator: Option) -> Self { + self.terminator = terminator; + self + } + + /// Set the escape character. + /// + /// The default is `None` (i.e. quotes cannot be escaped). + pub fn with_escape(mut self, escape: Option) -> Self { + self.escape = escape; + self + } + + /// Set the comment character. + /// + /// The default is `None` (i.e. comments are not supported). + pub fn with_comment(mut self, comment: Option) -> Self { + self.comment = comment; + self + } + + /// Set whether newlines in (quoted) values are supported. + /// + /// Parsing newlines in quoted values may be affected by execution behaviour such as + /// parallel file scanning. Setting this to `true` ensures that newlines in values are + /// parsed successfully, which may reduce performance. + /// + /// The default value is `false`. + pub fn with_newlines_in_values(mut self, newlines_in_values: bool) -> Self { + self.newlines_in_values = newlines_in_values; + self + } + + /// Set the file compression type. + /// + /// The default is [`FileCompressionType::UNCOMPRESSED`]. + pub fn with_file_compression_type( + mut self, file_compression_type: FileCompressionType, ) -> Self { + self.file_compression_type = file_compression_type; + self + } + + /// Build a [`CsvExec`]. + #[must_use] + pub fn build(self) -> CsvExec { + let Self { + file_scan_config: base_config, + file_compression_type, + has_header, + delimiter, + quote, + terminator, + escape, + comment, + newlines_in_values, + } = self; + let (projected_schema, projected_statistics, projected_output_ordering) = base_config.project(); - let cache = Self::compute_properties( + let cache = CsvExec::compute_properties( projected_schema, &projected_output_ordering, &base_config, ); - Self { + + CsvExec { base_config, projected_statistics, has_header, delimiter, quote, + terminator, escape, + newlines_in_values, metrics: ExecutionPlanMetricsSet::new(), file_compression_type, cache, + comment, } } +} + +impl CsvExec { + /// Create a new CSV reader execution plan provided base and specific configurations + #[deprecated(since = "41.0.0", note = "use `CsvExec::builder` or `CsvExecBuilder`")] + #[allow(clippy::too_many_arguments)] + pub fn new( + base_config: FileScanConfig, + has_header: bool, + delimiter: u8, + quote: u8, + terminator: Option, + escape: Option, + comment: Option, + newlines_in_values: bool, + file_compression_type: FileCompressionType, + ) -> Self { + CsvExecBuilder::new(base_config) + .with_has_header(has_header) + .with_delimeter(delimiter) + .with_quote(quote) + .with_terminator(terminator) + .with_escape(escape) + .with_comment(comment) + .with_newlines_in_values(newlines_in_values) + .with_file_compression_type(file_compression_type) + .build() + } + + /// Return a [`CsvExecBuilder`]. + /// + /// See example on [`CsvExec`] and [`CsvExecBuilder`] for specifying CSV table options. + pub fn builder(file_scan_config: FileScanConfig) -> CsvExecBuilder { + CsvExecBuilder::new(file_scan_config) + } /// Ref to the base configs pub fn base_config(&self) -> &FileScanConfig { @@ -113,11 +285,32 @@ impl CsvExec { self.quote } + /// The line terminator + pub fn terminator(&self) -> Option { + self.terminator + } + + /// Lines beginning with this byte are ignored. + pub fn comment(&self) -> Option { + self.comment + } + /// The escape character pub fn escape(&self) -> Option { self.escape } + /// Specifies whether newlines in (quoted) values are supported. + /// + /// Parsing newlines in quoted values may be affected by execution behaviour such as + /// parallel file scanning. Setting this to `true` ensures that newlines in values are + /// parsed successfully, which may reduce performance. + /// + /// The default behaviour depends on the `datafusion.catalog.newlines_in_values` setting. + pub fn newlines_in_values(&self) -> bool { + self.newlines_in_values + } + fn output_partitioning_helper(file_scan_config: &FileScanConfig) -> Partitioning { Partitioning::UnknownPartitioning(file_scan_config.file_groups.len()) } @@ -173,7 +366,7 @@ impl ExecutionPlan for CsvExec { &self.cache } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { // this is a leaf node and has no children vec![] } @@ -188,15 +381,15 @@ impl ExecutionPlan for CsvExec { /// Redistribute files across partitions according to their size /// See comments on [`FileGroupPartitioner`] for more detail. /// - /// Return `None` if can't get repartitioned(empty/compressed file). + /// Return `None` if can't get repartitioned (empty, compressed file, or `newlines_in_values` set). fn repartitioned( &self, target_partitions: usize, config: &ConfigOptions, ) -> Result>> { let repartition_file_min_size = config.optimizer.repartition_file_min_size; - // Parallel execution on compressed CSV file is not supported yet. - if self.file_compression_type.is_compressed() { + // Parallel execution on compressed CSV files or files that must support newlines in values is not supported yet. + if self.file_compression_type.is_compressed() || self.newlines_in_values { return Ok(None); } @@ -233,9 +426,10 @@ impl ExecutionPlan for CsvExec { delimiter: self.delimiter, quote: self.quote, escape: self.escape, + terminator: self.terminator, object_store, + comment: self.comment, }); - let opener = CsvOpener { config, file_compression_type: self.file_compression_type.to_owned(), @@ -252,6 +446,29 @@ impl ExecutionPlan for CsvExec { fn metrics(&self) -> Option { Some(self.metrics.clone_inner()) } + + fn fetch(&self) -> Option { + self.base_config.limit + } + + fn with_fetch(&self, limit: Option) -> Option> { + let new_config = self.base_config.clone().with_limit(limit); + + Some(Arc::new(Self { + base_config: new_config, + projected_statistics: self.projected_statistics.clone(), + has_header: self.has_header, + delimiter: self.delimiter, + quote: self.quote, + escape: self.escape, + terminator: self.terminator, + comment: self.comment, + newlines_in_values: self.newlines_in_values, + metrics: self.metrics.clone(), + file_compression_type: self.file_compression_type, + cache: self.cache.clone(), + })) + } } /// A Config for [`CsvOpener`] @@ -263,11 +480,14 @@ pub struct CsvConfig { has_header: bool, delimiter: u8, quote: u8, + terminator: Option, escape: Option, object_store: Arc, + comment: Option, } impl CsvConfig { + #[allow(clippy::too_many_arguments)] /// Returns a [`CsvConfig`] pub fn new( batch_size: usize, @@ -276,7 +496,9 @@ impl CsvConfig { has_header: bool, delimiter: u8, quote: u8, + terminator: Option, object_store: Arc, + comment: Option, ) -> Self { Self { batch_size, @@ -285,8 +507,10 @@ impl CsvConfig { has_header, delimiter, quote, + terminator, escape: None, object_store, + comment, } } } @@ -302,13 +526,18 @@ impl CsvConfig { .with_batch_size(self.batch_size) .with_header(self.has_header) .with_quote(self.quote); - + if let Some(terminator) = self.terminator { + builder = builder.with_terminator(terminator); + } if let Some(proj) = &self.file_projection { builder = builder.with_projection(proj.clone()); } if let Some(escape) = self.escape { builder = builder.with_escape(escape) } + if let Some(comment) = self.comment { + builder = builder.with_comment(comment); + } builder } @@ -518,14 +747,16 @@ mod tests { use super::*; use crate::dataframe::DataFrameWriteOptions; + use crate::datasource::file_format::csv::CsvFormat; use crate::prelude::*; use crate::test::{partitioned_csv_config, partitioned_file_groups}; use crate::{scalar::ScalarValue, test_util::aggr_test_schema}; use arrow::datatypes::*; use datafusion_common::test_util::arrow_test_data; - use datafusion_common::FileType; + use datafusion_common::config::CsvOptions; + use datafusion_execution::object_store::ObjectStoreUrl; use object_store::chunked::ChunkedStore; use object_store::local::LocalFileSystem; use rstest::*; @@ -545,6 +776,8 @@ mod tests { async fn csv_exec_with_projection( file_compression_type: FileCompressionType, ) -> Result<()> { + use crate::datasource::file_format::csv::CsvFormat; + let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); let file_schema = aggr_test_schema(); @@ -556,22 +789,24 @@ mod tests { path.as_str(), filename, 1, - FileType::CSV, + Arc::new(CsvFormat::default()), file_compression_type.to_owned(), tmp_dir.path(), )?; - let mut config = partitioned_csv_config(file_schema, file_groups)?; + let mut config = partitioned_csv_config(file_schema, file_groups); config.projection = Some(vec![0, 2, 4]); - let csv = CsvExec::new( - config, - true, - b',', - b'"', - None, - file_compression_type.to_owned(), - ); + let csv = CsvExec::builder(config) + .with_has_header(true) + .with_delimeter(b',') + .with_quote(b'"') + .with_terminator(None) + .with_escape(None) + .with_comment(None) + .with_newlines_in_values(false) + .with_file_compression_type(file_compression_type) + .build(); assert_eq!(13, csv.base_config.file_schema.fields().len()); assert_eq!(3, csv.schema().fields().len()); @@ -610,7 +845,10 @@ mod tests { async fn csv_exec_with_mixed_order_projection( file_compression_type: FileCompressionType, ) -> Result<()> { - let session_ctx = SessionContext::new(); + use crate::datasource::file_format::csv::CsvFormat; + + let cfg = SessionConfig::new().set_str("datafusion.catalog.has_header", "true"); + let session_ctx = SessionContext::new_with_config(cfg); let task_ctx = session_ctx.task_ctx(); let file_schema = aggr_test_schema(); let path = format!("{}/csv", arrow_test_data()); @@ -621,22 +859,24 @@ mod tests { path.as_str(), filename, 1, - FileType::CSV, + Arc::new(CsvFormat::default()), file_compression_type.to_owned(), tmp_dir.path(), )?; - let mut config = partitioned_csv_config(file_schema, file_groups)?; + let mut config = partitioned_csv_config(file_schema, file_groups); config.projection = Some(vec![4, 0, 2]); - let csv = CsvExec::new( - config, - true, - b',', - b'"', - None, - file_compression_type.to_owned(), - ); + let csv = CsvExec::builder(config) + .with_has_header(true) + .with_delimeter(b',') + .with_quote(b'"') + .with_terminator(None) + .with_escape(None) + .with_comment(None) + .with_newlines_in_values(false) + .with_file_compression_type(file_compression_type.to_owned()) + .build(); assert_eq!(13, csv.base_config.file_schema.fields().len()); assert_eq!(3, csv.schema().fields().len()); @@ -675,7 +915,10 @@ mod tests { async fn csv_exec_with_limit( file_compression_type: FileCompressionType, ) -> Result<()> { - let session_ctx = SessionContext::new(); + use crate::datasource::file_format::csv::CsvFormat; + + let cfg = SessionConfig::new().set_str("datafusion.catalog.has_header", "true"); + let session_ctx = SessionContext::new_with_config(cfg); let task_ctx = session_ctx.task_ctx(); let file_schema = aggr_test_schema(); let path = format!("{}/csv", arrow_test_data()); @@ -686,22 +929,24 @@ mod tests { path.as_str(), filename, 1, - FileType::CSV, + Arc::new(CsvFormat::default()), file_compression_type.to_owned(), tmp_dir.path(), )?; - let mut config = partitioned_csv_config(file_schema, file_groups)?; + let mut config = partitioned_csv_config(file_schema, file_groups); config.limit = Some(5); - let csv = CsvExec::new( - config, - true, - b',', - b'"', - None, - file_compression_type.to_owned(), - ); + let csv = CsvExec::builder(config) + .with_has_header(true) + .with_delimeter(b',') + .with_quote(b'"') + .with_terminator(None) + .with_escape(None) + .with_comment(None) + .with_newlines_in_values(false) + .with_file_compression_type(file_compression_type.to_owned()) + .build(); assert_eq!(13, csv.base_config.file_schema.fields().len()); assert_eq!(13, csv.schema().fields().len()); @@ -738,6 +983,8 @@ mod tests { async fn csv_exec_with_missing_column( file_compression_type: FileCompressionType, ) -> Result<()> { + use crate::datasource::file_format::csv::CsvFormat; + let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); let file_schema = aggr_test_schema_with_missing_col(); @@ -749,22 +996,24 @@ mod tests { path.as_str(), filename, 1, - FileType::CSV, + Arc::new(CsvFormat::default()), file_compression_type.to_owned(), tmp_dir.path(), )?; - let mut config = partitioned_csv_config(file_schema, file_groups)?; + let mut config = partitioned_csv_config(file_schema, file_groups); config.limit = Some(5); - let csv = CsvExec::new( - config, - true, - b',', - b'"', - None, - file_compression_type.to_owned(), - ); + let csv = CsvExec::builder(config) + .with_has_header(true) + .with_delimeter(b',') + .with_quote(b'"') + .with_terminator(None) + .with_escape(None) + .with_comment(None) + .with_newlines_in_values(false) + .with_file_compression_type(file_compression_type.to_owned()) + .build(); assert_eq!(14, csv.base_config.file_schema.fields().len()); assert_eq!(14, csv.schema().fields().len()); @@ -791,6 +1040,8 @@ mod tests { async fn csv_exec_with_partition( file_compression_type: FileCompressionType, ) -> Result<()> { + use crate::datasource::file_format::csv::CsvFormat; + let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); let file_schema = aggr_test_schema(); @@ -802,12 +1053,12 @@ mod tests { path.as_str(), filename, 1, - FileType::CSV, + Arc::new(CsvFormat::default()), file_compression_type.to_owned(), tmp_dir.path(), )?; - let mut config = partitioned_csv_config(file_schema, file_groups)?; + let mut config = partitioned_csv_config(file_schema, file_groups); // Add partition columns config.table_partition_cols = vec![Field::new("date", DataType::Utf8, false)]; @@ -819,14 +1070,16 @@ mod tests { // we don't have `/date=xx/` in the path but that is ok because // partitions are resolved during scan anyway - let csv = CsvExec::new( - config, - true, - b',', - b'"', - None, - file_compression_type.to_owned(), - ); + let csv = CsvExec::builder(config) + .with_has_header(true) + .with_delimeter(b',') + .with_quote(b'"') + .with_terminator(None) + .with_escape(None) + .with_comment(None) + .with_newlines_in_values(false) + .with_file_compression_type(file_compression_type.to_owned()) + .build(); assert_eq!(13, csv.base_config.file_schema.fields().len()); assert_eq!(2, csv.schema().fields().len()); @@ -893,7 +1146,7 @@ mod tests { ) -> Result<()> { let ctx = SessionContext::new(); let url = Url::parse("file://").unwrap(); - ctx.runtime_env().register_object_store(&url, store.clone()); + ctx.register_object_store(&url, store.clone()); let task_ctx = ctx.task_ctx(); @@ -906,21 +1159,23 @@ mod tests { path.as_str(), filename, 1, - FileType::CSV, + Arc::new(CsvFormat::default()), file_compression_type.to_owned(), tmp_dir.path(), ) .unwrap(); - let config = partitioned_csv_config(file_schema, file_groups).unwrap(); - let csv = CsvExec::new( - config, - true, - b',', - b'"', - None, - file_compression_type.to_owned(), - ); + let config = partitioned_csv_config(file_schema, file_groups); + let csv = CsvExec::builder(config) + .with_has_header(true) + .with_delimeter(b',') + .with_quote(b'"') + .with_terminator(None) + .with_escape(None) + .with_comment(None) + .with_newlines_in_values(false) + .with_file_compression_type(file_compression_type.to_owned()) + .build(); let it = csv.execute(0, task_ctx).unwrap(); let batches: Vec<_> = it.try_collect().await.unwrap(); @@ -961,14 +1216,12 @@ mod tests { let session_ctx = SessionContext::new(); let store = object_store::memory::InMemory::new(); - let data = bytes::Bytes::from("a,b\n1,2\n3,4"); + let data = Bytes::from("a,b\n1,2\n3,4"); let path = object_store::path::Path::from("a.csv"); - store.put(&path, data).await.unwrap(); + store.put(&path, data.into()).await.unwrap(); let url = Url::parse("memory://").unwrap(); - session_ctx - .runtime_env() - .register_object_store(&url, Arc::new(store)); + session_ctx.register_object_store(&url, Arc::new(store)); let df = session_ctx .read_csv("memory:///", CsvReadOptions::new()) @@ -989,6 +1242,107 @@ mod tests { crate::assert_batches_eq!(expected, &result); } + #[tokio::test] + async fn test_terminator() { + let session_ctx = SessionContext::new(); + let store = object_store::memory::InMemory::new(); + + let data = Bytes::from("a,b\r1,2\r3,4"); + let path = object_store::path::Path::from("a.csv"); + store.put(&path, data.into()).await.unwrap(); + + let url = Url::parse("memory://").unwrap(); + session_ctx.register_object_store(&url, Arc::new(store)); + + let df = session_ctx + .read_csv("memory:///", CsvReadOptions::new().terminator(Some(b'\r'))) + .await + .unwrap(); + + let result = df.collect().await.unwrap(); + + let expected = [ + "+---+---+", + "| a | b |", + "+---+---+", + "| 1 | 2 |", + "| 3 | 4 |", + "+---+---+", + ]; + + crate::assert_batches_eq!(expected, &result); + + let e = session_ctx + .read_csv("memory:///", CsvReadOptions::new().terminator(Some(b'\n'))) + .await + .unwrap() + .collect() + .await + .unwrap_err(); + assert_eq!(e.strip_backtrace(), "Arrow error: Csv error: incorrect number of fields for line 1, expected 2 got more than 2") + } + + #[tokio::test] + async fn test_create_external_table_with_terminator() -> Result<()> { + let ctx = SessionContext::new(); + ctx.sql( + r#" + CREATE EXTERNAL TABLE t1 ( + col1 TEXT, + col2 TEXT + ) STORED AS CSV + LOCATION 'tests/data/cr_terminator.csv' + OPTIONS ('format.terminator' E'\r', 'format.has_header' 'true'); + "#, + ) + .await? + .collect() + .await?; + + let df = ctx.sql(r#"select * from t1"#).await?.collect().await?; + let expected = [ + "+------+--------+", + "| col1 | col2 |", + "+------+--------+", + "| id0 | value0 |", + "| id1 | value1 |", + "| id2 | value2 |", + "| id3 | value3 |", + "+------+--------+", + ]; + crate::assert_batches_eq!(expected, &df); + Ok(()) + } + + #[tokio::test] + async fn test_create_external_table_with_terminator_with_newlines_in_values( + ) -> Result<()> { + let ctx = SessionContext::new(); + ctx.sql(r#" + CREATE EXTERNAL TABLE t1 ( + col1 TEXT, + col2 TEXT + ) STORED AS CSV + LOCATION 'tests/data/newlines_in_values_cr_terminator.csv' + OPTIONS ('format.terminator' E'\r', 'format.has_header' 'true', 'format.newlines_in_values' 'true'); + "#).await?.collect().await?; + + let df = ctx.sql(r#"select * from t1"#).await?.collect().await?; + let expected = [ + "+-------+-----------------------------+", + "| col1 | col2 |", + "+-------+-----------------------------+", + "| 1 | hello\rworld |", + "| 2 | something\relse |", + "| 3 | \rmany\rlines\rmake\rgood test\r |", + "| 4 | unquoted |", + "| value | end |", + "+-------+-----------------------------+", + ]; + crate::assert_batches_eq!(expected, &df); + Ok(()) + } + #[tokio::test] async fn write_csv_results_error_handling() -> Result<()> { let ctx = SessionContext::new(); @@ -997,7 +1351,7 @@ mod tests { let tmp_dir = TempDir::new()?; let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?); let local_url = Url::parse("file://local").unwrap(); - ctx.runtime_env().register_object_store(&local_url, local); + ctx.register_object_store(&local_url, local); let options = CsvReadOptions::default() .schema_infer_max_records(2) .has_header(true); @@ -1017,7 +1371,9 @@ mod tests { // create partitioned input file and context let tmp_dir = TempDir::new()?; let ctx = SessionContext::new_with_config( - SessionConfig::new().with_target_partitions(8), + SessionConfig::new() + .with_target_partitions(8) + .set_str("datafusion.catalog.has_header", "false"), ); let schema = populate_csv_partitions(&tmp_dir, 8, ".csv")?; @@ -1035,7 +1391,7 @@ mod tests { let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?); let local_url = Url::parse("file://local").unwrap(); - ctx.runtime_env().register_object_store(&local_url, local); + ctx.register_object_store(&local_url, local); // execute a simple query and write the results to CSV let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out/"; @@ -1045,7 +1401,9 @@ mod tests { .await?; // create a new context and verify that the results were saved to a partitioned csv file - let ctx = SessionContext::new(); + let ctx = SessionContext::new_with_config( + SessionConfig::new().set_str("datafusion.catalog.has_header", "false"), + ); let schema = Arc::new(Schema::new(vec![ Field::new("c1", DataType::UInt32, false), @@ -1074,7 +1432,7 @@ mod tests { panic!("Did not find part_0 in csv output files!") } // register each partition as well as the top level dir - let csv_read_option = CsvReadOptions::new().schema(&schema); + let csv_read_option = CsvReadOptions::new().schema(&schema).has_header(false); ctx.register_csv( "part0", &format!("{out_dir}/{part_0_name}"), @@ -1126,4 +1484,36 @@ mod tests { Arc::new(schema) } + + /// Ensure that the default options are set correctly + #[test] + fn test_default_options() { + let file_scan_config = + FileScanConfig::new(ObjectStoreUrl::local_filesystem(), aggr_test_schema()) + .with_file(PartitionedFile::new("foo", 34)); + + let CsvExecBuilder { + file_scan_config: _, + file_compression_type: _, + has_header, + delimiter, + quote, + terminator, + escape, + comment, + newlines_in_values, + } = CsvExecBuilder::new(file_scan_config); + + let default_options = CsvOptions::default(); + assert_eq!(has_header, default_options.has_header.unwrap_or(false)); + assert_eq!(delimiter, default_options.delimiter); + assert_eq!(quote, default_options.quote); + assert_eq!(terminator, default_options.terminator); + assert_eq!(escape, default_options.escape); + assert_eq!(comment, default_options.comment); + assert_eq!( + newlines_in_values, + default_options.newlines_in_values.unwrap_or(false) + ); + } } diff --git a/datafusion/core/src/datasource/physical_plan/file_groups.rs b/datafusion/core/src/datasource/physical_plan/file_groups.rs index 6456bd5c7276..28f975ae193d 100644 --- a/datafusion/core/src/datasource/physical_plan/file_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/file_groups.rs @@ -256,7 +256,7 @@ impl FileGroupPartitioner { }, ) .flatten() - .group_by(|(partition_idx, _)| *partition_idx) + .chunk_by(|(partition_idx, _)| *partition_idx) .into_iter() .map(|(_, group)| group.map(|(_, vals)| vals).collect_vec()) .collect_vec(); @@ -394,7 +394,7 @@ mod test { #[test] fn repartition_empty_file_only() { let partitioned_file_empty = pfile("empty", 0); - let file_group = vec![vec![partitioned_file_empty.clone()]]; + let file_group = vec![vec![partitioned_file_empty]]; let partitioned_files = FileGroupPartitioner::new() .with_target_partitions(4) @@ -817,10 +817,7 @@ mod test { .with_preserve_order_within_groups(true) .repartition_file_groups(&file_groups); - assert_partitioned_files( - repartitioned.clone(), - repartitioned_preserving_sort.clone(), - ); + assert_partitioned_files(repartitioned.clone(), repartitioned_preserving_sort); repartitioned } } diff --git a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs index 1ea411cb6f59..74ab0126a557 100644 --- a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs +++ b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs @@ -19,10 +19,11 @@ //! file sources. use std::{ - borrow::Cow, collections::HashMap, fmt::Debug, marker::PhantomData, sync::Arc, vec, + borrow::Cow, collections::HashMap, fmt::Debug, marker::PhantomData, mem::size_of, + sync::Arc, vec, }; -use super::{get_projected_output_ordering, FileGroupPartitioner}; +use super::{get_projected_output_ordering, statistics::MinMaxStatistics}; use crate::datasource::{listing::PartitionedFile, object_store::ObjectStoreUrl}; use crate::{error::Result, scalar::ScalarValue}; @@ -34,6 +35,7 @@ use arrow_schema::{DataType, Field, Schema, SchemaRef}; use datafusion_common::stats::Precision; use datafusion_common::{exec_err, ColumnStatistics, DataFusionError, Statistics}; use datafusion_physical_expr::LexOrdering; +use datafusion_physical_expr_common::sort_expr::LexOrderingRef; use log::warn; @@ -54,7 +56,7 @@ pub fn wrap_partition_type_in_dict(val_type: DataType) -> DataType { } /// Convert a [`ScalarValue`] of partition columns to a type, as -/// decribed in the documentation of [`wrap_partition_type_in_dict`], +/// described in the documentation of [`wrap_partition_type_in_dict`], /// which can wrap the types. pub fn wrap_partition_value_in_dict(val: ScalarValue) -> ScalarValue { ScalarValue::Dictionary(Box::new(DataType::UInt16), Box::new(val)) @@ -62,12 +64,41 @@ pub fn wrap_partition_value_in_dict(val: ScalarValue) -> ScalarValue { /// The base configurations to provide when creating a physical plan for /// any given file format. +/// +/// # Example +/// ``` +/// # use std::sync::Arc; +/// # use arrow_schema::Schema; +/// use datafusion::datasource::listing::PartitionedFile; +/// # use datafusion::datasource::physical_plan::FileScanConfig; +/// # use datafusion_execution::object_store::ObjectStoreUrl; +/// # let file_schema = Arc::new(Schema::empty()); +/// // create FileScan config for reading data from file:// +/// let object_store_url = ObjectStoreUrl::local_filesystem(); +/// let config = FileScanConfig::new(object_store_url, file_schema) +/// .with_limit(Some(1000)) // read only the first 1000 records +/// .with_projection(Some(vec![2, 3])) // project columns 2 and 3 +/// // Read /tmp/file1.parquet with known size of 1234 bytes in a single group +/// .with_file(PartitionedFile::new("file1.parquet", 1234)) +/// // Read /tmp/file2.parquet 56 bytes and /tmp/file3.parquet 78 bytes +/// // in a single row group +/// .with_file_group(vec![ +/// PartitionedFile::new("file2.parquet", 56), +/// PartitionedFile::new("file3.parquet", 78), +/// ]); +/// ``` #[derive(Clone)] pub struct FileScanConfig { /// Object store URL, used to get an [`ObjectStore`] instance from /// [`RuntimeEnv::object_store`] /// + /// This `ObjectStoreUrl` should be the prefix of the absolute url for files + /// as `file://` or `s3://my_bucket`. It should not include the path to the + /// file itself. The relevant URL prefix must be registered via + /// [`RuntimeEnv::register_object_store`] + /// /// [`ObjectStore`]: object_store::ObjectStore + /// [`RuntimeEnv::register_object_store`]: datafusion_execution::runtime_env::RuntimeEnv::register_object_store /// [`RuntimeEnv::object_store`]: datafusion_execution::runtime_env::RuntimeEnv::object_store pub object_store_url: ObjectStoreUrl, /// Schema before `projection` is applied. It contains the all columns that may @@ -85,6 +116,7 @@ pub struct FileScanConfig { /// sequentially, one after the next. pub file_groups: Vec>, /// Estimated overall statistics of the files, taking `filters` into account. + /// Defaults to [`Statistics::new_unknown`]. pub statistics: Statistics, /// Columns on which to project the data. Indexes that are higher than the /// number of columns of `file_schema` refer to `table_partition_cols`. @@ -99,6 +131,86 @@ pub struct FileScanConfig { } impl FileScanConfig { + /// Create a new `FileScanConfig` with default settings for scanning files. + /// + /// See example on [`FileScanConfig`] + /// + /// No file groups are added by default. See [`Self::with_file`], [`Self::with_file_group]` and + /// [`Self::with_file_groups`]. + /// + /// # Parameters: + /// * `object_store_url`: See [`Self::object_store_url`] + /// * `file_schema`: See [`Self::file_schema`] + pub fn new(object_store_url: ObjectStoreUrl, file_schema: SchemaRef) -> Self { + let statistics = Statistics::new_unknown(&file_schema); + Self { + object_store_url, + file_schema, + file_groups: vec![], + statistics, + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![], + } + } + + /// Set the statistics of the files + pub fn with_statistics(mut self, statistics: Statistics) -> Self { + self.statistics = statistics; + self + } + + /// Set the projection of the files + pub fn with_projection(mut self, projection: Option>) -> Self { + self.projection = projection; + self + } + + /// Set the limit of the files + pub fn with_limit(mut self, limit: Option) -> Self { + self.limit = limit; + self + } + + /// Add a file as a single group + /// + /// See [Self::file_groups] for more information. + pub fn with_file(self, file: PartitionedFile) -> Self { + self.with_file_group(vec![file]) + } + + /// Add the file groups + /// + /// See [Self::file_groups] for more information. + pub fn with_file_groups( + mut self, + mut file_groups: Vec>, + ) -> Self { + self.file_groups.append(&mut file_groups); + self + } + + /// Add a new file group + /// + /// See [Self::file_groups] for more information + pub fn with_file_group(mut self, file_group: Vec) -> Self { + self.file_groups.push(file_group); + self + } + + /// Set the partitioning columns of the files + pub fn with_table_partition_cols(mut self, table_partition_cols: Vec) -> Self { + self.table_partition_cols = table_partition_cols; + self + } + + /// Set the output ordering of the files + pub fn with_output_ordering(mut self, output_ordering: Vec) -> Self { + self.output_ordering = output_ordering; + self + } + /// Project the schema and the statistics on the given column indices pub fn project(&self) -> (SchemaRef, Statistics, Vec) { if self.projection.is_none() && self.table_partition_cols.is_empty() { @@ -132,21 +244,24 @@ impl FileScanConfig { } let table_stats = Statistics { - num_rows: self.statistics.num_rows.clone(), + num_rows: self.statistics.num_rows, // TODO correct byte size? total_byte_size: Precision::Absent, column_statistics: table_cols_stats, }; - let table_schema = Arc::new( - Schema::new(table_fields).with_metadata(self.file_schema.metadata().clone()), - ); + let projected_schema = Arc::new(Schema::new_with_metadata( + table_fields, + self.file_schema.metadata().clone(), + )); + let projected_output_ordering = - get_projected_output_ordering(self, &table_schema); - (table_schema, table_stats, projected_output_ordering) + get_projected_output_ordering(self, &projected_schema); + + (projected_schema, table_stats, projected_output_ordering) } - #[allow(unused)] // Only used by avro + #[cfg_attr(not(feature = "avro"), allow(unused))] // Only used by avro pub(crate) fn projected_file_column_names(&self) -> Option> { self.projection.as_ref().map(|p| { p.iter() @@ -169,7 +284,12 @@ impl FileScanConfig { fields.map_or_else( || Arc::clone(&self.file_schema), - |f| Arc::new(Schema::new(f).with_metadata(self.file_schema.metadata.clone())), + |f| { + Arc::new(Schema::new_with_metadata( + f, + self.file_schema.metadata.clone(), + )) + }, ) } @@ -182,17 +302,69 @@ impl FileScanConfig { }) } - #[allow(missing_docs)] - #[deprecated(since = "33.0.0", note = "Use SessionContext::new_with_config")] - pub fn repartition_file_groups( - file_groups: Vec>, - target_partitions: usize, - repartition_file_min_size: usize, - ) -> Option>> { - FileGroupPartitioner::new() - .with_target_partitions(target_partitions) - .with_repartition_file_min_size(repartition_file_min_size) - .repartition_file_groups(&file_groups) + /// Attempts to do a bin-packing on files into file groups, such that any two files + /// in a file group are ordered and non-overlapping with respect to their statistics. + /// It will produce the smallest number of file groups possible. + pub fn split_groups_by_statistics( + table_schema: &SchemaRef, + file_groups: &[Vec], + sort_order: LexOrderingRef, + ) -> Result>> { + let flattened_files = file_groups.iter().flatten().collect::>(); + // First Fit: + // * Choose the first file group that a file can be placed into. + // * If it fits into no existing file groups, create a new one. + // + // By sorting files by min values and then applying first-fit bin packing, + // we can produce the smallest number of file groups such that + // files within a group are in order and non-overlapping. + // + // Source: Applied Combinatorics (Keller and Trotter), Chapter 6.8 + // https://www.appliedcombinatorics.org/book/s_posets_dilworth-intord.html + + if flattened_files.is_empty() { + return Ok(vec![]); + } + + let statistics = MinMaxStatistics::new_from_files( + sort_order, + table_schema, + None, + flattened_files.iter().copied(), + ) + .map_err(|e| { + e.context("construct min/max statistics for split_groups_by_statistics") + })?; + + let indices_sorted_by_min = statistics.min_values_sorted(); + let mut file_groups_indices: Vec> = vec![]; + + for (idx, min) in indices_sorted_by_min { + let file_group_to_insert = file_groups_indices.iter_mut().find(|group| { + // If our file is non-overlapping and comes _after_ the last file, + // it fits in this file group. + min > statistics.max( + *group + .last() + .expect("groups should be nonempty at construction"), + ) + }); + match file_group_to_insert { + Some(group) => group.push(idx), + None => file_groups_indices.push(vec![idx]), + } + } + + // Assemble indices back into groups of PartitionedFiles + Ok(file_groups_indices + .into_iter() + .map(|file_group_indices| { + file_group_indices + .into_iter() + .map(|idx| flattened_files[idx].clone()) + .collect() + }) + .collect()) } } @@ -327,7 +499,7 @@ impl ZeroBufferGenerator where T: ArrowNativeType, { - const SIZE: usize = std::mem::size_of::(); + const SIZE: usize = size_of::(); fn get_buffer(&mut self, n_vals: usize) -> Buffer { match &mut self.cache { @@ -503,7 +675,7 @@ mod tests { vec![table_partition_col.clone()], ); - // verify the proj_schema inlcudes the last column and exactly the same the field it is defined + // verify the proj_schema includes the last column and exactly the same the field it is defined let (proj_schema, _proj_statistics, _) = conf.project(); assert_eq!(proj_schema.fields().len(), file_schema.fields().len() + 1); assert_eq!( @@ -729,7 +901,7 @@ mod tests { schema.clone(), Some(vec![0, 3, 5, schema.fields().len()]), Statistics::new_unknown(&schema), - to_partition_cols(partition_cols.clone()), + to_partition_cols(partition_cols), ) .projected_file_schema(); @@ -762,7 +934,7 @@ mod tests { schema.clone(), None, Statistics::new_unknown(&schema), - to_partition_cols(partition_cols.clone()), + to_partition_cols(partition_cols), ) .projected_file_schema(); @@ -770,6 +942,277 @@ mod tests { assert_eq!(projection.fields(), schema.fields()); } + #[test] + fn test_split_groups_by_statistics() -> Result<()> { + use chrono::TimeZone; + use datafusion_common::DFSchema; + use datafusion_expr::execution_props::ExecutionProps; + use object_store::{path::Path, ObjectMeta}; + + struct File { + name: &'static str, + date: &'static str, + statistics: Vec>, + } + impl File { + fn new( + name: &'static str, + date: &'static str, + statistics: Vec>, + ) -> Self { + Self { + name, + date, + statistics, + } + } + } + + struct TestCase { + name: &'static str, + file_schema: Schema, + files: Vec, + sort: Vec, + expected_result: Result>, &'static str>, + } + + use datafusion_expr::col; + let cases = vec![ + TestCase { + name: "test sort", + file_schema: Schema::new(vec![Field::new( + "value".to_string(), + DataType::Float64, + false, + )]), + files: vec![ + File::new("0", "2023-01-01", vec![Some((0.00, 0.49))]), + File::new("1", "2023-01-01", vec![Some((0.50, 1.00))]), + File::new("2", "2023-01-02", vec![Some((0.00, 1.00))]), + ], + sort: vec![col("value").sort(true, false)], + expected_result: Ok(vec![vec!["0", "1"], vec!["2"]]), + }, + // same input but file '2' is in the middle + // test that we still order correctly + TestCase { + name: "test sort with files ordered differently", + file_schema: Schema::new(vec![Field::new( + "value".to_string(), + DataType::Float64, + false, + )]), + files: vec![ + File::new("0", "2023-01-01", vec![Some((0.00, 0.49))]), + File::new("2", "2023-01-02", vec![Some((0.00, 1.00))]), + File::new("1", "2023-01-01", vec![Some((0.50, 1.00))]), + ], + sort: vec![col("value").sort(true, false)], + expected_result: Ok(vec![vec!["0", "1"], vec!["2"]]), + }, + TestCase { + name: "reverse sort", + file_schema: Schema::new(vec![Field::new( + "value".to_string(), + DataType::Float64, + false, + )]), + files: vec![ + File::new("0", "2023-01-01", vec![Some((0.00, 0.49))]), + File::new("1", "2023-01-01", vec![Some((0.50, 1.00))]), + File::new("2", "2023-01-02", vec![Some((0.00, 1.00))]), + ], + sort: vec![col("value").sort(false, true)], + expected_result: Ok(vec![vec!["1", "0"], vec!["2"]]), + }, + // reject nullable sort columns + TestCase { + name: "no nullable sort columns", + file_schema: Schema::new(vec![Field::new( + "value".to_string(), + DataType::Float64, + true, // should fail because nullable + )]), + files: vec![ + File::new("0", "2023-01-01", vec![Some((0.00, 0.49))]), + File::new("1", "2023-01-01", vec![Some((0.50, 1.00))]), + File::new("2", "2023-01-02", vec![Some((0.00, 1.00))]), + ], + sort: vec![col("value").sort(true, false)], + expected_result: Err("construct min/max statistics for split_groups_by_statistics\ncaused by\nbuild min rows\ncaused by\ncreate sorting columns\ncaused by\nError during planning: cannot sort by nullable column") + }, + TestCase { + name: "all three non-overlapping", + file_schema: Schema::new(vec![Field::new( + "value".to_string(), + DataType::Float64, + false, + )]), + files: vec![ + File::new("0", "2023-01-01", vec![Some((0.00, 0.49))]), + File::new("1", "2023-01-01", vec![Some((0.50, 0.99))]), + File::new("2", "2023-01-02", vec![Some((1.00, 1.49))]), + ], + sort: vec![col("value").sort(true, false)], + expected_result: Ok(vec![vec!["0", "1", "2"]]), + }, + TestCase { + name: "all three overlapping", + file_schema: Schema::new(vec![Field::new( + "value".to_string(), + DataType::Float64, + false, + )]), + files: vec![ + File::new("0", "2023-01-01", vec![Some((0.00, 0.49))]), + File::new("1", "2023-01-01", vec![Some((0.00, 0.49))]), + File::new("2", "2023-01-02", vec![Some((0.00, 0.49))]), + ], + sort: vec![col("value").sort(true, false)], + expected_result: Ok(vec![vec!["0"], vec!["1"], vec!["2"]]), + }, + TestCase { + name: "empty input", + file_schema: Schema::new(vec![Field::new( + "value".to_string(), + DataType::Float64, + false, + )]), + files: vec![], + sort: vec![col("value").sort(true, false)], + expected_result: Ok(vec![]), + }, + TestCase { + name: "one file missing statistics", + file_schema: Schema::new(vec![Field::new( + "value".to_string(), + DataType::Float64, + false, + )]), + files: vec![ + File::new("0", "2023-01-01", vec![Some((0.00, 0.49))]), + File::new("1", "2023-01-01", vec![Some((0.00, 0.49))]), + File::new("2", "2023-01-02", vec![None]), + ], + sort: vec![col("value").sort(true, false)], + expected_result: Err("construct min/max statistics for split_groups_by_statistics\ncaused by\ncollect min/max values\ncaused by\nget min/max for column: 'value'\ncaused by\nError during planning: statistics not found"), + }, + ]; + + for case in cases { + let table_schema = Arc::new(Schema::new( + case.file_schema + .fields() + .clone() + .into_iter() + .cloned() + .chain(Some(Arc::new(Field::new( + "date".to_string(), + DataType::Utf8, + false, + )))) + .collect::>(), + )); + let sort_order = case + .sort + .into_iter() + .map(|expr| { + crate::physical_planner::create_physical_sort_expr( + &expr, + &DFSchema::try_from(table_schema.as_ref().clone())?, + &ExecutionProps::default(), + ) + }) + .collect::>>()?; + + let partitioned_files = + case.files.into_iter().map(From::from).collect::>(); + let result = FileScanConfig::split_groups_by_statistics( + &table_schema, + &[partitioned_files.clone()], + &sort_order, + ); + let results_by_name = result + .as_ref() + .map(|file_groups| { + file_groups + .iter() + .map(|file_group| { + file_group + .iter() + .map(|file| { + partitioned_files + .iter() + .find_map(|f| { + if f.object_meta == file.object_meta { + Some( + f.object_meta + .location + .as_ref() + .rsplit('/') + .next() + .unwrap() + .trim_end_matches(".parquet"), + ) + } else { + None + } + }) + .unwrap() + }) + .collect::>() + }) + .collect::>() + }) + .map_err(|e| e.to_string().leak() as &'static str); + + assert_eq!(results_by_name, case.expected_result, "{}", case.name); + } + + return Ok(()); + + impl From for PartitionedFile { + fn from(file: File) -> Self { + PartitionedFile { + object_meta: ObjectMeta { + location: Path::from(format!( + "data/date={}/{}.parquet", + file.date, file.name + )), + last_modified: chrono::Utc.timestamp_nanos(0), + size: 0, + e_tag: None, + version: None, + }, + partition_values: vec![ScalarValue::from(file.date)], + range: None, + statistics: Some(Statistics { + num_rows: Precision::Absent, + total_byte_size: Precision::Absent, + column_statistics: file + .statistics + .into_iter() + .map(|stats| { + stats + .map(|(min, max)| ColumnStatistics { + min_value: Precision::Exact(ScalarValue::from( + min, + )), + max_value: Precision::Exact(ScalarValue::from( + max, + )), + ..Default::default() + }) + .unwrap_or_default() + }) + .collect::>(), + }), + extensions: None, + } + } + } + } + // sets default for configs that play no role in projections fn config_for_projection( file_schema: SchemaRef, @@ -777,16 +1220,10 @@ mod tests { statistics: Statistics, table_partition_cols: Vec, ) -> FileScanConfig { - FileScanConfig { - file_schema, - file_groups: vec![vec![]], - limit: None, - object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), - projection, - statistics, - table_partition_cols, - output_ordering: vec![], - } + FileScanConfig::new(ObjectStoreUrl::parse("test:///").unwrap(), file_schema) + .with_projection(projection) + .with_statistics(statistics) + .with_table_partition_cols(table_partition_cols) } /// Convert partition columns from Vec to Vec diff --git a/datafusion/core/src/datasource/physical_plan/file_stream.rs b/datafusion/core/src/datasource/physical_plan/file_stream.rs index 619bcb29e2cc..6f354b31ae87 100644 --- a/datafusion/core/src/datasource/physical_plan/file_stream.rs +++ b/datafusion/core/src/datasource/physical_plan/file_stream.rs @@ -519,15 +519,12 @@ mod tests { use std::sync::Arc; use super::*; - use crate::datasource::file_format::write::BatchSerializer; use crate::datasource::object_store::ObjectStoreUrl; use crate::prelude::SessionContext; use crate::test::{make_partition, object_store::register_test_store}; use arrow_schema::Schema; - use datafusion_common::{internal_err, Statistics}; - - use bytes::Bytes; + use datafusion_common::internal_err; /// Test `FileOpener` which will simulate errors during file opening or scanning #[derive(Default)] @@ -646,16 +643,12 @@ mod tests { let on_error = self.on_error; - let config = FileScanConfig { - object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), - statistics: Statistics::new_unknown(&file_schema), + let config = FileScanConfig::new( + ObjectStoreUrl::parse("test:///").unwrap(), file_schema, - file_groups: vec![file_group], - projection: None, - limit: self.limit, - table_partition_cols: vec![], - output_ordering: vec![], - }; + ) + .with_file_group(file_group) + .with_limit(self.limit); let metrics_set = ExecutionPlanMetricsSet::new(); let file_stream = FileStream::new(&config, 0, self.opener, &metrics_set) .unwrap() @@ -974,14 +967,4 @@ mod tests { Ok(()) } - - struct TestSerializer { - bytes: Bytes, - } - - impl BatchSerializer for TestSerializer { - fn serialize(&self, _batch: RecordBatch, _initial: bool) -> Result { - Ok(self.bytes.clone()) - } - } } diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index 2ec1b91d08ea..cf8f129a5036 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -138,7 +138,7 @@ impl ExecutionPlan for NdJsonExec { &self.cache } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { Vec::new() } @@ -154,7 +154,7 @@ impl ExecutionPlan for NdJsonExec { target_partitions: usize, config: &datafusion_common::config::ConfigOptions, ) -> Result>> { - if self.file_compression_type == FileCompressionType::GZIP { + if self.file_compression_type.is_compressed() { return Ok(None); } let repartition_file_min_size = config.optimizer.repartition_file_min_size; @@ -206,6 +206,22 @@ impl ExecutionPlan for NdJsonExec { fn metrics(&self) -> Option { Some(self.metrics.clone_inner()) } + + fn fetch(&self) -> Option { + self.base_config.limit + } + + fn with_fetch(&self, limit: Option) -> Option> { + let new_config = self.base_config.clone().with_limit(limit); + + Some(Arc::new(Self { + base_config: new_config, + projected_statistics: self.projected_statistics.clone(), + metrics: self.metrics.clone(), + file_compression_type: self.file_compression_type, + cache: self.cache.clone(), + })) + } } /// A [`FileOpener`] that opens a JSON file and yields a [`FileOpenFuture`] @@ -383,9 +399,7 @@ mod tests { use std::path::Path; use super::*; - use crate::assert_batches_eq; use crate::dataframe::DataFrameWriteOptions; - use crate::datasource::file_format::file_compression_type::FileTypeExt; use crate::datasource::file_format::{json::JsonFormat, FileFormat}; use crate::datasource::object_store::ObjectStoreUrl; use crate::execution::context::SessionState; @@ -393,18 +407,14 @@ mod tests { CsvReadOptions, NdJsonReadOptions, SessionConfig, SessionContext, }; use crate::test::partitioned_file_groups; + use crate::{assert_batches_eq, assert_batches_sorted_eq}; use arrow::array::Array; use arrow::datatypes::{Field, SchemaBuilder}; use datafusion_common::cast::{as_int32_array, as_int64_array, as_string_array}; - use datafusion_common::FileType; - use flate2::write::GzEncoder; - use flate2::Compression; use object_store::chunked::ChunkedStore; use object_store::local::LocalFileSystem; use rstest::*; - use std::fs::File; - use std::io; use tempfile::TempDir; use url::Url; @@ -423,7 +433,7 @@ mod tests { TEST_DATA_BASE, filename, 1, - FileType::JSON, + Arc::new(JsonFormat::default()), file_compression_type.to_owned(), work_dir, ) @@ -450,14 +460,14 @@ mod tests { ) -> Result<()> { let ctx = SessionContext::new(); let url = Url::parse("file://").unwrap(); - ctx.runtime_env().register_object_store(&url, store.clone()); + ctx.register_object_store(&url, store.clone()); let filename = "1.json"; let tmp_dir = TempDir::new()?; let file_groups = partitioned_file_groups( TEST_DATA_BASE, filename, 1, - FileType::JSON, + Arc::new(JsonFormat::default()), file_compression_type.to_owned(), tmp_dir.path(), ) @@ -476,8 +486,8 @@ mod tests { let path_buf = Path::new(url.path()).join(path); let path = path_buf.to_str().unwrap(); - let ext = FileType::JSON - .get_ext_with_compression(file_compression_type.to_owned()) + let ext = JsonFormat::default() + .get_ext_with_compression(&file_compression_type) .unwrap(); let read_options = NdJsonReadOptions::default() @@ -525,16 +535,9 @@ mod tests { prepare_store(&state, file_compression_type.to_owned(), tmp_dir.path()).await; let exec = NdJsonExec::new( - FileScanConfig { - object_store_url, - file_groups, - statistics: Statistics::new_unknown(&file_schema), - file_schema, - projection: None, - limit: Some(3), - table_partition_cols: vec![], - output_ordering: vec![], - }, + FileScanConfig::new(object_store_url, file_schema) + .with_file_groups(file_groups) + .with_limit(Some(3)), file_compression_type.to_owned(), ); @@ -603,16 +606,9 @@ mod tests { let missing_field_idx = file_schema.fields.len() - 1; let exec = NdJsonExec::new( - FileScanConfig { - object_store_url, - file_groups, - statistics: Statistics::new_unknown(&file_schema), - file_schema, - projection: None, - limit: Some(3), - table_partition_cols: vec![], - output_ordering: vec![], - }, + FileScanConfig::new(object_store_url, file_schema) + .with_file_groups(file_groups) + .with_limit(Some(3)), file_compression_type.to_owned(), ); @@ -650,16 +646,9 @@ mod tests { prepare_store(&state, file_compression_type.to_owned(), tmp_dir.path()).await; let exec = NdJsonExec::new( - FileScanConfig { - object_store_url, - file_groups, - statistics: Statistics::new_unknown(&file_schema), - file_schema, - projection: Some(vec![0, 2]), - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - }, + FileScanConfig::new(object_store_url, file_schema) + .with_file_groups(file_groups) + .with_projection(Some(vec![0, 2])), file_compression_type.to_owned(), ); let inferred_schema = exec.schema(); @@ -702,16 +691,9 @@ mod tests { prepare_store(&state, file_compression_type.to_owned(), tmp_dir.path()).await; let exec = NdJsonExec::new( - FileScanConfig { - object_store_url, - file_groups, - statistics: Statistics::new_unknown(&file_schema), - file_schema, - projection: Some(vec![3, 0, 2]), - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - }, + FileScanConfig::new(object_store_url, file_schema) + .with_file_groups(file_groups) + .with_projection(Some(vec![3, 0, 2])), file_compression_type.to_owned(), ); let inferred_schema = exec.schema(); @@ -756,7 +738,7 @@ mod tests { let tmp_dir = TempDir::new()?; let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?); let local_url = Url::parse("file://local").unwrap(); - ctx.runtime_env().register_object_store(&local_url, local); + ctx.register_object_store(&local_url, local); // execute a simple query and write the results to CSV let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out/"; @@ -849,7 +831,7 @@ mod tests { let tmp_dir = TempDir::new()?; let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?); let local_url = Url::parse("file://local").unwrap(); - ctx.runtime_env().register_object_store(&local_url, local); + ctx.register_object_store(&local_url, local); let options = CsvReadOptions::default() .schema_infer_max_records(2) .has_header(true); @@ -892,36 +874,65 @@ mod tests { Ok(()) } - fn compress_file(path: &str, output_path: &str) -> io::Result<()> { - let input_file = File::open(path)?; - let mut reader = BufReader::new(input_file); - let output_file = File::create(output_path)?; - let writer = std::io::BufWriter::new(output_file); - - let mut encoder = GzEncoder::new(writer, Compression::default()); - io::copy(&mut reader, &mut encoder)?; - - encoder.finish()?; - Ok(()) - } + #[rstest( + file_compression_type, + case::uncompressed(FileCompressionType::UNCOMPRESSED), + case::gzip(FileCompressionType::GZIP), + case::bzip2(FileCompressionType::BZIP2), + case::xz(FileCompressionType::XZ), + case::zstd(FileCompressionType::ZSTD) + )] + #[cfg(feature = "compression")] #[tokio::test] - async fn test_disable_parallel_for_json_gz() -> Result<()> { + async fn test_json_with_repartitioing( + file_compression_type: FileCompressionType, + ) -> Result<()> { let config = SessionConfig::new() .with_repartition_file_scans(true) .with_repartition_file_min_size(0) .with_target_partitions(4); let ctx = SessionContext::new_with_config(config); - let path = format!("{TEST_DATA_BASE}/1.json"); - let compressed_path = format!("{}.gz", &path); - compress_file(&path, &compressed_path)?; + + let tmp_dir = TempDir::new()?; + let (store_url, file_groups, _) = + prepare_store(&ctx.state(), file_compression_type, tmp_dir.path()).await; + + // It's important to have less than `target_partitions` amount of file groups, to + // trigger repartitioning. + assert_eq!( + file_groups.len(), + 1, + "Expected prepared store with single file group" + ); + + let path = file_groups + .first() + .unwrap() + .first() + .unwrap() + .object_meta + .location + .as_ref(); + + let url: &Url = store_url.as_ref(); + let path_buf = Path::new(url.path()).join(path); + let path = path_buf.to_str().unwrap(); + let ext = JsonFormat::default() + .get_ext_with_compression(&file_compression_type) + .unwrap(); + let read_option = NdJsonReadOptions::default() - .file_compression_type(FileCompressionType::GZIP) - .file_extension("gz"); - let df = ctx.read_json(compressed_path.clone(), read_option).await?; + .file_compression_type(file_compression_type) + .file_extension(ext.as_str()); + + let df = ctx.read_json(path, read_option).await?; let res = df.collect().await; - fs::remove_file(&compressed_path)?; - assert_batches_eq!( + + // Output sort order is nondeterministic due to multiple + // target partitions. To handle it, assert compares sorted + // result. + assert_batches_sorted_eq!( &[ "+-----+------------------+---------------+------+", "| a | b | c | d |", diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index ddb8d032f3d8..9971e87282a5 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -26,6 +26,7 @@ mod file_stream; mod json; #[cfg(feature = "parquet")] pub mod parquet; +mod statistics; pub(crate) use self::csv::plan_to_csv; pub(crate) use self::json::plan_to_json; @@ -34,7 +35,8 @@ pub use self::parquet::{ParquetExec, ParquetFileMetrics, ParquetFileReaderFactor pub use arrow_file::ArrowExec; pub use avro::AvroExec; -pub use csv::{CsvConfig, CsvExec, CsvOpener}; +pub use csv::{CsvConfig, CsvExec, CsvExecBuilder, CsvOpener}; +use datafusion_expr::dml::InsertOp; pub use file_groups::FileGroupPartitioner; pub use file_scan_config::{ wrap_partition_type_in_dict, wrap_partition_value_in_dict, FileScanConfig, @@ -60,15 +62,10 @@ use crate::{ physical_plan::display::{display_orderings, ProjectSchemaDisplay}, }; -use arrow::{ - array::new_null_array, - compute::{can_cast_types, cast}, - datatypes::{DataType, Schema, SchemaRef}, - record_batch::{RecordBatch, RecordBatchOptions}, -}; -use datafusion_common::plan_err; +use arrow::datatypes::{DataType, SchemaRef}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::PhysicalSortExpr; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use futures::StreamExt; use log::debug; @@ -88,8 +85,11 @@ pub struct FileSinkConfig { /// A vector of column names and their corresponding data types, /// representing the partitioning columns for the file pub table_partition_cols: Vec<(String, DataType)>, - /// Controls whether existing data should be overwritten by this sink - pub overwrite: bool, + /// Controls how new data should be written to the file, determining whether + /// to append to, overwrite, or replace records in existing files. + pub insert_op: InsertOp, + /// Controls whether partition columns are kept for the file + pub keep_partition_by_columns: bool, } impl FileSinkConfig { @@ -240,125 +240,6 @@ where Ok(()) } -/// A utility which can adapt file-level record batches to a table schema which may have a schema -/// obtained from merging multiple file-level schemas. -/// -/// This is useful for enabling schema evolution in partitioned datasets. -/// -/// This has to be done in two stages. -/// -/// 1. Before reading the file, we have to map projected column indexes from the table schema to -/// the file schema. -/// -/// 2. After reading a record batch we need to map the read columns back to the expected columns -/// indexes and insert null-valued columns wherever the file schema was missing a colum present -/// in the table schema. -#[derive(Clone, Debug)] -pub(crate) struct SchemaAdapter { - /// Schema for the table - table_schema: SchemaRef, -} - -impl SchemaAdapter { - pub(crate) fn new(table_schema: SchemaRef) -> SchemaAdapter { - Self { table_schema } - } - - /// Map a column index in the table schema to a column index in a particular - /// file schema - /// - /// Panics if index is not in range for the table schema - pub(crate) fn map_column_index( - &self, - index: usize, - file_schema: &Schema, - ) -> Option { - let field = self.table_schema.field(index); - Some(file_schema.fields.find(field.name())?.0) - } - - /// Creates a `SchemaMapping` that can be used to cast or map the columns from the file schema to the table schema. - /// - /// If the provided `file_schema` contains columns of a different type to the expected - /// `table_schema`, the method will attempt to cast the array data from the file schema - /// to the table schema where possible. - /// - /// Returns a [`SchemaMapping`] that can be applied to the output batch - /// along with an ordered list of columns to project from the file - pub fn map_schema( - &self, - file_schema: &Schema, - ) -> Result<(SchemaMapping, Vec)> { - let mut projection = Vec::with_capacity(file_schema.fields().len()); - let mut field_mappings = vec![None; self.table_schema.fields().len()]; - - for (file_idx, file_field) in file_schema.fields.iter().enumerate() { - if let Some((table_idx, table_field)) = - self.table_schema.fields().find(file_field.name()) - { - match can_cast_types(file_field.data_type(), table_field.data_type()) { - true => { - field_mappings[table_idx] = Some(projection.len()); - projection.push(file_idx); - } - false => { - return plan_err!( - "Cannot cast file schema field {} of type {:?} to table schema field of type {:?}", - file_field.name(), - file_field.data_type(), - table_field.data_type() - ) - } - } - } - } - - Ok(( - SchemaMapping { - table_schema: self.table_schema.clone(), - field_mappings, - }, - projection, - )) - } -} - -/// The SchemaMapping struct holds a mapping from the file schema to the table schema -/// and any necessary type conversions that need to be applied. -#[derive(Debug)] -pub struct SchemaMapping { - /// The schema of the table. This is the expected schema after conversion and it should match the schema of the query result. - table_schema: SchemaRef, - /// Mapping from field index in `table_schema` to index in projected file_schema - field_mappings: Vec>, -} - -impl SchemaMapping { - /// Adapts a `RecordBatch` to match the `table_schema` using the stored mapping and conversions. - fn map_batch(&self, batch: RecordBatch) -> Result { - let batch_rows = batch.num_rows(); - let batch_cols = batch.columns().to_vec(); - - let cols = self - .table_schema - .fields() - .iter() - .zip(&self.field_mappings) - .map(|(field, file_idx)| match file_idx { - Some(batch_idx) => cast(&batch_cols[*batch_idx], field.data_type()), - None => Ok(new_null_array(field.data_type(), batch_rows)), - }) - .collect::, _>>()?; - - // Necessary to handle empty batches - let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows())); - - let schema = self.table_schema.clone(); - let record_batch = RecordBatch::try_new_with_options(schema, cols, &options)?; - Ok(record_batch) - } -} - /// A single file or part of a file that should be read, along with its schema, statistics pub struct FileMeta { /// Path for the file (e.g. URL, filesystem path, etc) @@ -448,16 +329,11 @@ impl From for FileMeta { fn get_projected_output_ordering( base_config: &FileScanConfig, projected_schema: &SchemaRef, -) -> Vec> { +) -> Vec { let mut all_orderings = vec![]; for output_ordering in &base_config.output_ordering { - if base_config.file_groups.iter().any(|group| group.len() > 1) { - debug!("Skipping specified output ordering {:?}. Some file group had more than one file: {:?}", - base_config.output_ordering[0], base_config.file_groups); - return vec![]; - } - let mut new_ordering = vec![]; - for PhysicalSortExpr { expr, options } in output_ordering { + let mut new_ordering = LexOrdering::default(); + for PhysicalSortExpr { expr, options } in output_ordering.iter() { if let Some(col) = expr.as_any().downcast_ref::() { let name = col.name(); if let Some((idx, _)) = projected_schema.column_with_name(name) { @@ -473,11 +349,45 @@ fn get_projected_output_ordering( // since rest of the orderings are violated break; } + // do not push empty entries // otherwise we may have `Some(vec![])` at the output ordering. - if !new_ordering.is_empty() { - all_orderings.push(new_ordering); + if new_ordering.is_empty() { + continue; } + + // Check if any file groups are not sorted + if base_config.file_groups.iter().any(|group| { + if group.len() <= 1 { + // File groups with <= 1 files are always sorted + return false; + } + + let statistics = match statistics::MinMaxStatistics::new_from_files( + &new_ordering, + projected_schema, + base_config.projection.as_deref(), + group, + ) { + Ok(statistics) => statistics, + Err(e) => { + log::trace!("Error fetching statistics for file group: {e}"); + // we can't prove that it's ordered, so we have to reject it + return true; + } + }; + + !statistics.is_sorted() + }) { + debug!( + "Skipping specified output ordering {:?}. \ + Some file groups couldn't be determined to be sorted: {:?}", + base_config.output_ordering[0], base_config.file_groups + ); + continue; + } + + all_orderings.push(new_ordering); } all_orderings } @@ -591,11 +501,14 @@ mod tests { use arrow_array::cast::AsArray; use arrow_array::types::{Float32Type, Float64Type, UInt32Type}; use arrow_array::{ - BinaryArray, BooleanArray, Float32Array, Int32Array, Int64Array, StringArray, - UInt64Array, + BinaryArray, BooleanArray, Float32Array, Int32Array, Int64Array, RecordBatch, + StringArray, UInt64Array, }; - use arrow_schema::Field; + use arrow_schema::{Field, Schema}; + use crate::datasource::schema_adapter::{ + DefaultSchemaAdapterFactory, SchemaAdapterFactory, + }; use chrono::Utc; #[test] @@ -606,7 +519,8 @@ mod tests { Field::new("c3", DataType::Float64, true), ])); - let adapter = SchemaAdapter::new(table_schema.clone()); + let adapter = DefaultSchemaAdapterFactory + .create(table_schema.clone(), table_schema.clone()); let file_schema = Schema::new(vec![ Field::new("c1", DataType::Utf8, true), @@ -663,7 +577,7 @@ mod tests { let indices = vec![1, 2, 4]; let schema = SchemaRef::from(table_schema.project(&indices).unwrap()); - let adapter = SchemaAdapter::new(schema); + let adapter = DefaultSchemaAdapterFactory.create(schema, table_schema.clone()); let (mapping, projection) = adapter.map_schema(&file_schema).unwrap(); let id = Int32Array::from(vec![Some(1), Some(2), Some(3)]); @@ -850,7 +764,7 @@ mod tests { /// create a PartitionedFile for testing fn partitioned_file(path: &str) -> PartitionedFile { let object_meta = ObjectMeta { - location: object_store::path::Path::parse(path).unwrap(), + location: Path::parse(path).unwrap(), last_modified: Utc::now(), size: 42, e_tag: None, @@ -861,6 +775,7 @@ mod tests { object_meta, partition_values: vec![], range: None, + statistics: None, extensions: None, } } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/access_plan.rs b/datafusion/core/src/datasource/physical_plan/parquet/access_plan.rs new file mode 100644 index 000000000000..ea3030664b7b --- /dev/null +++ b/datafusion/core/src/datasource/physical_plan/parquet/access_plan.rs @@ -0,0 +1,555 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::{internal_err, Result}; +use parquet::arrow::arrow_reader::{RowSelection, RowSelector}; +use parquet::file::metadata::RowGroupMetaData; + +/// A selection of rows and row groups within a ParquetFile to decode. +/// +/// A `ParquetAccessPlan` is used to limit the row groups and data pages a `ParquetExec` +/// will read and decode to improve performance. +/// +/// Note that page level pruning based on ArrowPredicate is applied after all of +/// these selections +/// +/// # Example +/// +/// For example, given a Parquet file with 4 row groups, a `ParquetAccessPlan` +/// can be used to specify skipping row group 0 and 2, scanning a range of rows +/// in row group 1, and scanning all rows in row group 3 as follows: +/// +/// ```rust +/// # use parquet::arrow::arrow_reader::{RowSelection, RowSelector}; +/// # use datafusion::datasource::physical_plan::parquet::ParquetAccessPlan; +/// // Default to scan all row groups +/// let mut access_plan = ParquetAccessPlan::new_all(4); +/// access_plan.skip(0); // skip row group +/// // Use parquet reader RowSelector to specify scanning rows 100-200 and 350-400 +/// // in a row group that has 1000 rows +/// let row_selection = RowSelection::from(vec![ +/// RowSelector::skip(100), +/// RowSelector::select(100), +/// RowSelector::skip(150), +/// RowSelector::select(50), +/// RowSelector::skip(600), // skip last 600 rows +/// ]); +/// access_plan.scan_selection(1, row_selection); +/// access_plan.skip(2); // skip row group 2 +/// // row group 3 is scanned by default +/// ``` +/// +/// The resulting plan would look like: +/// +/// ```text +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// +/// │ │ SKIP +/// +/// └ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘ +/// Row Group 0 +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// ┌────────────────┐ SCAN ONLY ROWS +/// │└────────────────┘ │ 100-200 +/// ┌────────────────┐ 350-400 +/// │└────────────────┘ │ +/// ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// Row Group 1 +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +/// SKIP +/// │ │ +/// +/// └ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘ +/// Row Group 2 +/// ┌───────────────────┐ +/// │ │ SCAN ALL ROWS +/// │ │ +/// │ │ +/// └───────────────────┘ +/// Row Group 3 +/// ``` +#[derive(Debug, Clone, PartialEq)] +pub struct ParquetAccessPlan { + /// How to access the i-th row group + row_groups: Vec, +} + +/// Describes how the parquet reader will access a row group +#[derive(Debug, Clone, PartialEq)] +pub enum RowGroupAccess { + /// Do not read the row group at all + Skip, + /// Read all rows from the row group + Scan, + /// Scan only the specified rows within the row group + Selection(RowSelection), +} + +impl RowGroupAccess { + /// Return true if this row group should be scanned + pub fn should_scan(&self) -> bool { + match self { + RowGroupAccess::Skip => false, + RowGroupAccess::Scan | RowGroupAccess::Selection(_) => true, + } + } +} + +impl ParquetAccessPlan { + /// Create a new `ParquetAccessPlan` that scans all row groups + pub fn new_all(row_group_count: usize) -> Self { + Self { + row_groups: vec![RowGroupAccess::Scan; row_group_count], + } + } + + /// Create a new `ParquetAccessPlan` that scans no row groups + pub fn new_none(row_group_count: usize) -> Self { + Self { + row_groups: vec![RowGroupAccess::Skip; row_group_count], + } + } + + /// Create a new `ParquetAccessPlan` from the specified [`RowGroupAccess`]es + pub fn new(row_groups: Vec) -> Self { + Self { row_groups } + } + + /// Set the i-th row group to the specified [`RowGroupAccess`] + pub fn set(&mut self, idx: usize, access: RowGroupAccess) { + self.row_groups[idx] = access; + } + + /// skips the i-th row group (should not be scanned) + pub fn skip(&mut self, idx: usize) { + self.set(idx, RowGroupAccess::Skip); + } + + /// scan the i-th row group + pub fn scan(&mut self, idx: usize) { + self.set(idx, RowGroupAccess::Scan); + } + + /// Return true if the i-th row group should be scanned + pub fn should_scan(&self, idx: usize) -> bool { + self.row_groups[idx].should_scan() + } + + /// Set to scan only the [`RowSelection`] in the specified row group. + /// + /// Behavior is different depending on the existing access + /// * [`RowGroupAccess::Skip`]: does nothing + /// * [`RowGroupAccess::Scan`]: Updates to scan only the rows in the `RowSelection` + /// * [`RowGroupAccess::Selection`]: Updates to scan only the intersection of the existing selection and the new selection + pub fn scan_selection(&mut self, idx: usize, selection: RowSelection) { + self.row_groups[idx] = match &self.row_groups[idx] { + // already skipping the entire row group + RowGroupAccess::Skip => RowGroupAccess::Skip, + RowGroupAccess::Scan => RowGroupAccess::Selection(selection), + RowGroupAccess::Selection(existing_selection) => { + RowGroupAccess::Selection(existing_selection.intersection(&selection)) + } + } + } + + /// Return an overall `RowSelection`, if needed + /// + /// This is used to compute the row selection for the parquet reader. See + /// [`ArrowReaderBuilder::with_row_selection`] for more details. + /// + /// Returns + /// * `None` if there are no [`RowGroupAccess::Selection`] + /// * `Some(selection)` if there are [`RowGroupAccess::Selection`]s + /// + /// The returned selection represents which rows to scan across any row + /// row groups which are not skipped. + /// + /// # Notes + /// + /// If there are no [`RowGroupAccess::Selection`]s, the overall row + /// selection is `None` because each row group is either entirely skipped or + /// scanned, which is covered by [`Self::row_group_indexes`]. + /// + /// If there are any [`RowGroupAccess::Selection`], an overall row selection + /// is returned for *all* the rows in the row groups that are not skipped. + /// Thus it includes a `Select` selection for any [`RowGroupAccess::Scan`]. + /// + /// # Errors + /// + /// Returns an error if any specified row selection does not specify + /// the same number of rows as in it's corresponding `row_group_metadata`. + /// + /// # Example: No Selections + /// + /// Given an access plan like this + /// + /// ```text + /// RowGroupAccess::Scan (scan all row group 0) + /// RowGroupAccess::Skip (skip row group 1) + /// RowGroupAccess::Scan (scan all row group 2) + /// RowGroupAccess::Scan (scan all row group 3) + /// ``` + /// + /// The overall row selection would be `None` because there are no + /// [`RowGroupAccess::Selection`]s. The row group indexes + /// returned by [`Self::row_group_indexes`] would be `0, 2, 3` . + /// + /// # Example: With Selections + /// + /// Given an access plan like this: + /// + /// ```text + /// RowGroupAccess::Scan (scan all row group 0) + /// RowGroupAccess::Skip (skip row group 1) + /// RowGroupAccess::Select (skip 50, scan 50, skip 900) (scan rows 50-100 in row group 2) + /// RowGroupAccess::Scan (scan all row group 3) + /// ``` + /// + /// Assuming each row group has 1000 rows, the resulting row selection would + /// be the rows to scan in row group 0, 2 and 4: + /// + /// ```text + /// RowSelection::Select(1000) (scan all rows in row group 0) + /// RowSelection::Skip(50) (skip first 50 rows in row group 2) + /// RowSelection::Select(50) (scan rows 50-100 in row group 2) + /// RowSelection::Skip(900) (skip last 900 rows in row group 2) + /// RowSelection::Select(1000) (scan all rows in row group 3) + /// ``` + /// + /// Note there is no entry for the (entirely) skipped row group 1. + /// + /// The row group indexes returned by [`Self::row_group_indexes`] would + /// still be `0, 2, 3` . + /// + /// [`ArrowReaderBuilder::with_row_selection`]: parquet::arrow::arrow_reader::ArrowReaderBuilder::with_row_selection + pub fn into_overall_row_selection( + self, + row_group_meta_data: &[RowGroupMetaData], + ) -> Result> { + assert_eq!(row_group_meta_data.len(), self.row_groups.len()); + // Intuition: entire row groups are filtered out using + // `row_group_indexes` which come from Skip and Scan. An overall + // RowSelection is only useful if there is any parts *within* a row group + // which can be filtered out, that is a `Selection`. + if !self + .row_groups + .iter() + .any(|rg| matches!(rg, RowGroupAccess::Selection(_))) + { + return Ok(None); + } + + // validate all Selections + for (idx, (rg, rg_meta)) in self + .row_groups + .iter() + .zip(row_group_meta_data.iter()) + .enumerate() + { + let RowGroupAccess::Selection(selection) = rg else { + continue; + }; + let rows_in_selection = selection + .iter() + .map(|selection| selection.row_count) + .sum::(); + + let row_group_row_count = rg_meta.num_rows(); + if rows_in_selection as i64 != row_group_row_count { + return internal_err!( + "Invalid ParquetAccessPlan Selection. Row group {idx} has {row_group_row_count} rows \ + but selection only specifies {rows_in_selection} rows. \ + Selection: {selection:?}" + ); + } + } + + let total_selection: RowSelection = self + .row_groups + .into_iter() + .zip(row_group_meta_data.iter()) + .flat_map(|(rg, rg_meta)| { + match rg { + RowGroupAccess::Skip => vec![], + RowGroupAccess::Scan => { + // need a row group access to scan the entire row group (need row group counts) + vec![RowSelector::select(rg_meta.num_rows() as usize)] + } + RowGroupAccess::Selection(selection) => { + let selection: Vec = selection.into(); + selection + } + } + }) + .collect(); + + Ok(Some(total_selection)) + } + + /// Return an iterator over the row group indexes that should be scanned + pub fn row_group_index_iter(&self) -> impl Iterator + '_ { + self.row_groups.iter().enumerate().filter_map(|(idx, b)| { + if b.should_scan() { + Some(idx) + } else { + None + } + }) + } + + /// Return a vec of all row group indexes to scan + pub fn row_group_indexes(&self) -> Vec { + self.row_group_index_iter().collect() + } + + /// Return the total number of row groups (not the total number or groups to + /// scan) + pub fn len(&self) -> usize { + self.row_groups.len() + } + + /// Return true if there are no row groups + pub fn is_empty(&self) -> bool { + self.row_groups.is_empty() + } + + /// Get a reference to the inner accesses + pub fn inner(&self) -> &[RowGroupAccess] { + &self.row_groups + } + + /// Covert into the inner row group accesses + pub fn into_inner(self) -> Vec { + self.row_groups + } +} + +#[cfg(test)] +mod test { + use super::*; + use datafusion_common::assert_contains; + use parquet::basic::LogicalType; + use parquet::file::metadata::ColumnChunkMetaData; + use parquet::schema::types::{SchemaDescPtr, SchemaDescriptor}; + use std::sync::{Arc, OnceLock}; + + #[test] + fn test_only_scans() { + let access_plan = ParquetAccessPlan::new(vec![ + RowGroupAccess::Scan, + RowGroupAccess::Scan, + RowGroupAccess::Scan, + RowGroupAccess::Scan, + ]); + + let row_group_indexes = access_plan.row_group_indexes(); + let row_selection = access_plan + .into_overall_row_selection(row_group_metadata()) + .unwrap(); + + // scan all row groups, no selection + assert_eq!(row_group_indexes, vec![0, 1, 2, 3]); + assert_eq!(row_selection, None); + } + + #[test] + fn test_only_skips() { + let access_plan = ParquetAccessPlan::new(vec![ + RowGroupAccess::Skip, + RowGroupAccess::Skip, + RowGroupAccess::Skip, + RowGroupAccess::Skip, + ]); + + let row_group_indexes = access_plan.row_group_indexes(); + let row_selection = access_plan + .into_overall_row_selection(row_group_metadata()) + .unwrap(); + + // skip all row groups, no selection + assert_eq!(row_group_indexes, vec![] as Vec); + assert_eq!(row_selection, None); + } + #[test] + fn test_mixed_1() { + let access_plan = ParquetAccessPlan::new(vec![ + RowGroupAccess::Scan, + RowGroupAccess::Selection( + // specifies all 20 rows in row group 1 + vec![ + RowSelector::select(5), + RowSelector::skip(7), + RowSelector::select(8), + ] + .into(), + ), + RowGroupAccess::Skip, + RowGroupAccess::Skip, + ]); + + let row_group_indexes = access_plan.row_group_indexes(); + let row_selection = access_plan + .into_overall_row_selection(row_group_metadata()) + .unwrap(); + + assert_eq!(row_group_indexes, vec![0, 1]); + assert_eq!( + row_selection, + Some( + vec![ + // select the entire first row group + RowSelector::select(10), + // selectors from the second row group + RowSelector::select(5), + RowSelector::skip(7), + RowSelector::select(8) + ] + .into() + ) + ); + } + + #[test] + fn test_mixed_2() { + let access_plan = ParquetAccessPlan::new(vec![ + RowGroupAccess::Skip, + RowGroupAccess::Scan, + RowGroupAccess::Selection( + // specify all 30 rows in row group 1 + vec![ + RowSelector::select(5), + RowSelector::skip(7), + RowSelector::select(18), + ] + .into(), + ), + RowGroupAccess::Scan, + ]); + + let row_group_indexes = access_plan.row_group_indexes(); + let row_selection = access_plan + .into_overall_row_selection(row_group_metadata()) + .unwrap(); + + assert_eq!(row_group_indexes, vec![1, 2, 3]); + assert_eq!( + row_selection, + Some( + vec![ + // select the entire second row group + RowSelector::select(20), + // selectors from the third row group + RowSelector::select(5), + RowSelector::skip(7), + RowSelector::select(18), + // select the entire fourth row group + RowSelector::select(40), + ] + .into() + ) + ); + } + + #[test] + fn test_invalid_too_few() { + let access_plan = ParquetAccessPlan::new(vec![ + RowGroupAccess::Scan, + // specify only 12 rows in selection, but row group 1 has 20 + RowGroupAccess::Selection( + vec![RowSelector::select(5), RowSelector::skip(7)].into(), + ), + RowGroupAccess::Scan, + RowGroupAccess::Scan, + ]); + + let row_group_indexes = access_plan.row_group_indexes(); + let err = access_plan + .into_overall_row_selection(row_group_metadata()) + .unwrap_err() + .to_string(); + assert_eq!(row_group_indexes, vec![0, 1, 2, 3]); + assert_contains!(err, "Internal error: Invalid ParquetAccessPlan Selection. Row group 1 has 20 rows but selection only specifies 12 rows"); + } + + #[test] + fn test_invalid_too_many() { + let access_plan = ParquetAccessPlan::new(vec![ + RowGroupAccess::Scan, + // specify 22 rows in selection, but row group 1 has only 20 + RowGroupAccess::Selection( + vec![ + RowSelector::select(10), + RowSelector::skip(2), + RowSelector::select(10), + ] + .into(), + ), + RowGroupAccess::Scan, + RowGroupAccess::Scan, + ]); + + let row_group_indexes = access_plan.row_group_indexes(); + let err = access_plan + .into_overall_row_selection(row_group_metadata()) + .unwrap_err() + .to_string(); + assert_eq!(row_group_indexes, vec![0, 1, 2, 3]); + assert_contains!(err, "Invalid ParquetAccessPlan Selection. Row group 1 has 20 rows but selection only specifies 22 rows"); + } + + static ROW_GROUP_METADATA: OnceLock> = OnceLock::new(); + + /// [`RowGroupMetaData`] that returns 4 row groups with 10, 20, 30, 40 rows + /// respectively + fn row_group_metadata() -> &'static [RowGroupMetaData] { + ROW_GROUP_METADATA.get_or_init(|| { + let schema_descr = get_test_schema_descr(); + let row_counts = [10, 20, 30, 40]; + + row_counts + .into_iter() + .map(|num_rows| { + let column = ColumnChunkMetaData::builder(schema_descr.column(0)) + .set_num_values(num_rows) + .build() + .unwrap(); + + RowGroupMetaData::builder(schema_descr.clone()) + .set_num_rows(num_rows) + .set_column_metadata(vec![column]) + .build() + .unwrap() + }) + .collect() + }) + } + + /// Single column schema with a single column named "a" of type `BYTE_ARRAY`/`String` + fn get_test_schema_descr() -> SchemaDescPtr { + use parquet::basic::Type as PhysicalType; + use parquet::schema::types::Type as SchemaType; + let field = SchemaType::primitive_type_builder("a", PhysicalType::BYTE_ARRAY) + .with_logical_type(Some(LogicalType::String)) + .build() + .unwrap(); + let schema = SchemaType::group_type_builder("schema") + .with_fields(vec![Arc::new(field)]) + .build() + .unwrap(); + Arc::new(SchemaDescriptor::new(Arc::new(schema))) + } +} diff --git a/datafusion/core/src/datasource/physical_plan/parquet/metrics.rs b/datafusion/core/src/datasource/physical_plan/parquet/metrics.rs index c2a7e4345a5b..f1b5f71530dc 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/metrics.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/metrics.rs @@ -29,24 +29,34 @@ use crate::physical_plan::metrics::{ pub struct ParquetFileMetrics { /// Number of times the predicate could not be evaluated pub predicate_evaluation_errors: Count, - /// Number of row groups whose bloom filters were checked and matched + /// Number of row groups whose bloom filters were checked and matched (not pruned) pub row_groups_matched_bloom_filter: Count, /// Number of row groups pruned by bloom filters pub row_groups_pruned_bloom_filter: Count, - /// Number of row groups whose statistics were checked and matched + /// Number of row groups whose statistics were checked and matched (not pruned) pub row_groups_matched_statistics: Count, /// Number of row groups pruned by statistics pub row_groups_pruned_statistics: Count, /// Total number of bytes scanned pub bytes_scanned: Count, /// Total rows filtered out by predicates pushed into parquet scan - pub pushdown_rows_filtered: Count, - /// Total time spent evaluating pushdown filters - pub pushdown_eval_time: Time, + pub pushdown_rows_pruned: Count, + /// Total rows passed predicates pushed into parquet scan + pub pushdown_rows_matched: Count, + /// Total time spent evaluating row-level pushdown filters + pub row_pushdown_eval_time: Time, + /// Total time spent evaluating row group-level statistics filters + pub statistics_eval_time: Time, + /// Total time spent evaluating row group Bloom Filters + pub bloom_filter_eval_time: Time, /// Total rows filtered out by parquet page index - pub page_index_rows_filtered: Count, + pub page_index_rows_pruned: Count, + /// Total rows passed through the parquet page index + pub page_index_rows_matched: Count, /// Total time spent evaluating parquet page index filters pub page_index_eval_time: Time, + /// Total time spent reading and parsing metadata from the footer + pub metadata_load_time: Time, } impl ParquetFileMetrics { @@ -80,21 +90,38 @@ impl ParquetFileMetrics { .with_new_label("filename", filename.to_string()) .counter("bytes_scanned", partition); - let pushdown_rows_filtered = MetricBuilder::new(metrics) + let pushdown_rows_pruned = MetricBuilder::new(metrics) .with_new_label("filename", filename.to_string()) - .counter("pushdown_rows_filtered", partition); + .counter("pushdown_rows_pruned", partition); + let pushdown_rows_matched = MetricBuilder::new(metrics) + .with_new_label("filename", filename.to_string()) + .counter("pushdown_rows_matched", partition); - let pushdown_eval_time = MetricBuilder::new(metrics) + let row_pushdown_eval_time = MetricBuilder::new(metrics) + .with_new_label("filename", filename.to_string()) + .subset_time("row_pushdown_eval_time", partition); + let statistics_eval_time = MetricBuilder::new(metrics) .with_new_label("filename", filename.to_string()) - .subset_time("pushdown_eval_time", partition); - let page_index_rows_filtered = MetricBuilder::new(metrics) + .subset_time("statistics_eval_time", partition); + let bloom_filter_eval_time = MetricBuilder::new(metrics) .with_new_label("filename", filename.to_string()) - .counter("page_index_rows_filtered", partition); + .subset_time("bloom_filter_eval_time", partition); + + let page_index_rows_pruned = MetricBuilder::new(metrics) + .with_new_label("filename", filename.to_string()) + .counter("page_index_rows_pruned", partition); + let page_index_rows_matched = MetricBuilder::new(metrics) + .with_new_label("filename", filename.to_string()) + .counter("page_index_rows_matched", partition); let page_index_eval_time = MetricBuilder::new(metrics) .with_new_label("filename", filename.to_string()) .subset_time("page_index_eval_time", partition); + let metadata_load_time = MetricBuilder::new(metrics) + .with_new_label("filename", filename.to_string()) + .subset_time("metadata_load_time", partition); + Self { predicate_evaluation_errors, row_groups_matched_bloom_filter, @@ -102,10 +129,15 @@ impl ParquetFileMetrics { row_groups_matched_statistics, row_groups_pruned_statistics, bytes_scanned, - pushdown_rows_filtered, - pushdown_eval_time, - page_index_rows_filtered, + pushdown_rows_pruned, + pushdown_rows_matched, + row_pushdown_eval_time, + page_index_rows_pruned, + page_index_rows_matched, + statistics_eval_time, + bloom_filter_eval_time, page_index_eval_time, + metadata_load_time, } } } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 73fb82980fc4..059f86ce110f 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -15,63 +15,250 @@ // specific language governing permissions and limitations // under the License. -//! Execution plan for reading Parquet files +//! [`ParquetExec`] Execution plan for reading Parquet files use std::any::Any; use std::fmt::Debug; -use std::ops::Range; use std::sync::Arc; use crate::datasource::listing::PartitionedFile; -use crate::datasource::physical_plan::file_stream::{ - FileOpenFuture, FileOpener, FileStream, -}; +use crate::datasource::physical_plan::file_stream::FileStream; use crate::datasource::physical_plan::{ - parquet::page_filter::PagePruningPredicate, DisplayAs, FileGroupPartitioner, - FileMeta, FileScanConfig, SchemaAdapter, + parquet::page_filter::PagePruningAccessPlanFilter, DisplayAs, FileGroupPartitioner, + FileScanConfig, }; use crate::{ config::{ConfigOptions, TableParquetOptions}, - datasource::listing::ListingTableUrl, - error::{DataFusionError, Result}, + error::Result, execution::context::TaskContext, physical_optimizer::pruning::PruningPredicate, physical_plan::{ metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}, - DisplayFormatType, ExecutionMode, ExecutionPlan, ExecutionPlanProperties, - Partitioning, PlanProperties, SendableRecordBatchStream, Statistics, + DisplayFormatType, ExecutionMode, ExecutionPlan, Partitioning, PlanProperties, + SendableRecordBatchStream, Statistics, }, }; -use arrow::datatypes::{DataType, SchemaRef}; -use arrow::error::ArrowError; +use arrow::datatypes::SchemaRef; use datafusion_physical_expr::{EquivalenceProperties, LexOrdering, PhysicalExpr}; -use bytes::Bytes; -use futures::future::BoxFuture; -use futures::{StreamExt, TryStreamExt}; use itertools::Itertools; use log::debug; -use object_store::buffered::BufWriter; -use object_store::path::Path; -use object_store::ObjectStore; -use parquet::arrow::arrow_reader::ArrowReaderOptions; -use parquet::arrow::async_reader::{AsyncFileReader, ParquetObjectReader}; -use parquet::arrow::{AsyncArrowWriter, ParquetRecordBatchStreamBuilder, ProjectionMask}; -use parquet::basic::{ConvertedType, LogicalType}; -use parquet::file::{metadata::ParquetMetaData, properties::WriterProperties}; -use parquet::schema::types::ColumnDescriptor; -use tokio::task::JoinSet; +mod access_plan; mod metrics; +mod opener; mod page_filter; +mod reader; mod row_filter; -mod row_groups; -mod statistics; +mod row_group_filter; +mod writer; +use crate::datasource::schema_adapter::{ + DefaultSchemaAdapterFactory, SchemaAdapterFactory, +}; +pub use access_plan::{ParquetAccessPlan, RowGroupAccess}; pub use metrics::ParquetFileMetrics; +use opener::ParquetOpener; +pub use reader::{DefaultParquetFileReaderFactory, ParquetFileReaderFactory}; +pub use row_filter::can_expr_be_pushed_down_with_schemas; +pub use writer::plan_to_parquet; -/// Execution plan for scanning one or more Parquet partitions +/// Execution plan for reading one or more Parquet files. +/// +/// ```text +/// ▲ +/// │ +/// │ Produce a stream of +/// │ RecordBatches +/// │ +/// ┌───────────────────────┐ +/// │ │ +/// │ ParquetExec │ +/// │ │ +/// └───────────────────────┘ +/// ▲ +/// │ Asynchronously read from one +/// │ or more parquet files via +/// │ ObjectStore interface +/// │ +/// │ +/// .───────────────────. +/// │ ) +/// │`───────────────────'│ +/// │ ObjectStore │ +/// │.───────────────────.│ +/// │ ) +/// `───────────────────' +/// +/// ``` +/// +/// # Example: Create a `ParquetExec` +/// ``` +/// # use std::sync::Arc; +/// # use arrow::datatypes::Schema; +/// # use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec}; +/// # use datafusion::datasource::listing::PartitionedFile; +/// # let file_schema = Arc::new(Schema::empty()); +/// # let object_store_url = ObjectStoreUrl::local_filesystem(); +/// # use datafusion_execution::object_store::ObjectStoreUrl; +/// # use datafusion_physical_expr::expressions::lit; +/// # let predicate = lit(true); +/// // Create a ParquetExec for reading `file1.parquet` with a file size of 100MB +/// let file_scan_config = FileScanConfig::new(object_store_url, file_schema) +/// .with_file(PartitionedFile::new("file1.parquet", 100*1024*1024)); +/// let exec = ParquetExec::builder(file_scan_config) +/// // Provide a predicate for filtering row groups/pages +/// .with_predicate(predicate) +/// .build(); +/// ``` +/// +/// # Features +/// +/// Supports the following optimizations: +/// +/// * Concurrent reads: reads from one or more files in parallel as multiple +/// partitions, including concurrently reading multiple row groups from a single +/// file. +/// +/// * Predicate push down: skips row groups, pages, rows based on metadata +/// and late materialization. See "Predicate Pushdown" below. +/// +/// * Projection pushdown: reads and decodes only the columns required. +/// +/// * Limit pushdown: stop execution early after some number of rows are read. +/// +/// * Custom readers: customize reading parquet files, e.g. to cache metadata, +/// coalesce I/O operations, etc. See [`ParquetFileReaderFactory`] for more +/// details. +/// +/// * Schema evolution: read parquet files with different schemas into a unified +/// table schema. See [`SchemaAdapterFactory`] for more details. +/// +/// * metadata_size_hint: controls the number of bytes read from the end of the +/// file in the initial I/O when the default [`ParquetFileReaderFactory`]. If a +/// custom reader is used, it supplies the metadata directly and this parameter +/// is ignored. [`ParquetExecBuilder::with_metadata_size_hint`] for more details. +/// +/// * User provided [`ParquetAccessPlan`]s to skip row groups and/or pages +/// based on external information. See "Implementing External Indexes" below +/// +/// # Predicate Pushdown +/// +/// `ParquetExec` uses the provided [`PhysicalExpr`] predicate as a filter to +/// skip reading unnecessary data and improve query performance using several techniques: +/// +/// * Row group pruning: skips entire row groups based on min/max statistics +/// found in [`ParquetMetaData`] and any Bloom filters that are present. +/// +/// * Page pruning: skips individual pages within a ColumnChunk using the +/// [Parquet PageIndex], if present. +/// +/// * Row filtering: skips rows within a page using a form of late +/// materialization. When possible, predicates are applied by the parquet +/// decoder *during* decode (see [`ArrowPredicate`] and [`RowFilter`] for more +/// details). This is only enabled if `ParquetScanOptions::pushdown_filters` is set to true. +/// +/// Note: If the predicate can not be used to accelerate the scan, it is ignored +/// (no error is raised on predicate evaluation errors). +/// +/// [`ArrowPredicate`]: parquet::arrow::arrow_reader::ArrowPredicate +/// [`RowFilter`]: parquet::arrow::arrow_reader::RowFilter +/// [Parquet PageIndex]: https://github.com/apache/parquet-format/blob/master/PageIndex.md +/// +/// # Example: rewriting `ParquetExec` +/// +/// You can modify a `ParquetExec` using [`ParquetExecBuilder`], for example +/// to change files or add a predicate. +/// +/// ```no_run +/// # use std::sync::Arc; +/// # use arrow::datatypes::Schema; +/// # use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec}; +/// # use datafusion::datasource::listing::PartitionedFile; +/// # fn parquet_exec() -> ParquetExec { unimplemented!() } +/// // Split a single ParquetExec into multiple ParquetExecs, one for each file +/// let exec = parquet_exec(); +/// let existing_file_groups = &exec.base_config().file_groups; +/// let new_execs = existing_file_groups +/// .iter() +/// .map(|file_group| { +/// // create a new exec by copying the existing exec into a builder +/// let new_exec = exec.clone() +/// .into_builder() +/// .with_file_groups(vec![file_group.clone()]) +/// .build(); +/// new_exec +/// }) +/// .collect::>(); +/// ``` +/// +/// # Implementing External Indexes +/// +/// It is possible to restrict the row groups and selections within those row +/// groups that the ParquetExec will consider by providing an initial +/// [`ParquetAccessPlan`] as `extensions` on [`PartitionedFile`]. This can be +/// used to implement external indexes on top of parquet files and select only +/// portions of the files. +/// +/// The `ParquetExec` will try and reduce any provided `ParquetAccessPlan` +/// further based on the contents of `ParquetMetadata` and other settings. +/// +/// ## Example of providing a ParquetAccessPlan +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_schema::{Schema, SchemaRef}; +/// # use datafusion::datasource::listing::PartitionedFile; +/// # use datafusion::datasource::physical_plan::parquet::ParquetAccessPlan; +/// # use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec}; +/// # use datafusion_execution::object_store::ObjectStoreUrl; +/// # fn schema() -> SchemaRef { +/// # Arc::new(Schema::empty()) +/// # } +/// // create an access plan to scan row group 0, 1 and 3 and skip row groups 2 and 4 +/// let mut access_plan = ParquetAccessPlan::new_all(5); +/// access_plan.skip(2); +/// access_plan.skip(4); +/// // provide the plan as extension to the FileScanConfig +/// let partitioned_file = PartitionedFile::new("my_file.parquet", 1234) +/// .with_extensions(Arc::new(access_plan)); +/// // create a ParquetExec to scan this file +/// let file_scan_config = FileScanConfig::new(ObjectStoreUrl::local_filesystem(), schema()) +/// .with_file(partitioned_file); +/// // this parquet exec will not even try to read row groups 2 and 4. Additional +/// // pruning based on predicates may also happen +/// let exec = ParquetExec::builder(file_scan_config).build(); +/// ``` +/// +/// For a complete example, see the [`advanced_parquet_index` example]). +/// +/// [`parquet_index_advanced` example]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_parquet_index.rs +/// +/// # Execution Overview +/// +/// * Step 1: [`ParquetExec::execute`] is called, returning a [`FileStream`] +/// configured to open parquet files with a `ParquetOpener`. +/// +/// * Step 2: When the stream is polled, the `ParquetOpener` is called to open +/// the file. +/// +/// * Step 3: The `ParquetOpener` gets the [`ParquetMetaData`] (file metadata) +/// via [`ParquetFileReaderFactory`], creating a [`ParquetAccessPlan`] by +/// applying predicates to metadata. The plan and projections are used to +/// determine what pages must be read. +/// +/// * Step 4: The stream begins reading data, fetching the required parquet +/// pages incrementally decoding them, and applying any row filters (see +/// [`Self::with_pushdown_filters`]). +/// +/// * Step 5: As each [`RecordBatch`] is read, it may be adapted by a +/// [`SchemaAdapter`] to match the table schema. By default missing columns are +/// filled with nulls, but this can be customized via [`SchemaAdapterFactory`]. +/// +/// [`RecordBatch`]: arrow::record_batch::RecordBatch +/// [`SchemaAdapter`]: crate::datasource::schema_adapter::SchemaAdapter +/// [`ParquetMetadata`]: parquet::file::metadata::ParquetMetaData #[derive(Debug, Clone)] pub struct ParquetExec { /// Base configuration for this scan @@ -81,10 +268,10 @@ pub struct ParquetExec { metrics: ExecutionPlanMetricsSet, /// Optional predicate for row filtering during parquet scan predicate: Option>, - /// Optional predicate for pruning row groups + /// Optional predicate for pruning row groups (derived from `predicate`) pruning_predicate: Option>, - /// Optional predicate for pruning pages - page_pruning_predicate: Option>, + /// Optional predicate for pruning pages (derived from `predicate`) + page_pruning_predicate: Option>, /// Optional hint for the size of the parquet metadata metadata_size_hint: Option, /// Optional user defined parquet file reader factory @@ -93,16 +280,140 @@ pub struct ParquetExec { cache: PlanProperties, /// Options for reading Parquet files table_parquet_options: TableParquetOptions, + /// Optional user defined schema adapter + schema_adapter_factory: Option>, } -impl ParquetExec { - /// Create a new Parquet reader execution plan provided file list and schema. - pub fn new( - base_config: FileScanConfig, - predicate: Option>, - metadata_size_hint: Option, +impl From for ParquetExecBuilder { + fn from(exec: ParquetExec) -> Self { + exec.into_builder() + } +} + +/// [`ParquetExecBuilder`], builder for [`ParquetExec`]. +/// +/// See example on [`ParquetExec`]. +pub struct ParquetExecBuilder { + file_scan_config: FileScanConfig, + predicate: Option>, + metadata_size_hint: Option, + table_parquet_options: TableParquetOptions, + parquet_file_reader_factory: Option>, + schema_adapter_factory: Option>, +} + +impl ParquetExecBuilder { + /// Create a new builder to read the provided file scan configuration + pub fn new(file_scan_config: FileScanConfig) -> Self { + Self::new_with_options(file_scan_config, TableParquetOptions::default()) + } + + /// Create a new builder to read the data specified in the file scan + /// configuration with the provided `TableParquetOptions`. + pub fn new_with_options( + file_scan_config: FileScanConfig, + table_parquet_options: TableParquetOptions, + ) -> Self { + Self { + file_scan_config, + predicate: None, + metadata_size_hint: None, + table_parquet_options, + parquet_file_reader_factory: None, + schema_adapter_factory: None, + } + } + + /// Update the list of files groups to read + pub fn with_file_groups(mut self, file_groups: Vec>) -> Self { + self.file_scan_config.file_groups = file_groups; + self + } + + /// Set the filter predicate when reading. + /// + /// See the "Predicate Pushdown" section of the [`ParquetExec`] documenation + /// for more details. + pub fn with_predicate(mut self, predicate: Arc) -> Self { + self.predicate = Some(predicate); + self + } + + /// Set the metadata size hint + /// + /// This value determines how many bytes at the end of the file the default + /// [`ParquetFileReaderFactory`] will request in the initial IO. If this is + /// too small, the ParquetExec will need to make additional IO requests to + /// read the footer. + pub fn with_metadata_size_hint(mut self, metadata_size_hint: usize) -> Self { + self.metadata_size_hint = Some(metadata_size_hint); + self + } + + /// Set the options for controlling how the ParquetExec reads parquet files. + /// + /// See also [`Self::new_with_options`] + pub fn with_table_parquet_options( + mut self, table_parquet_options: TableParquetOptions, ) -> Self { + self.table_parquet_options = table_parquet_options; + self + } + + /// Set optional user defined parquet file reader factory. + /// + /// You can use [`ParquetFileReaderFactory`] to more precisely control how + /// data is read from parquet files (e.g. skip re-reading metadata, coalesce + /// I/O operations, etc). + /// + /// The default reader factory reads directly from an [`ObjectStore`] + /// instance using individual I/O operations for the footer and each page. + /// + /// If a custom `ParquetFileReaderFactory` is provided, then data access + /// operations will be routed to this factory instead of [`ObjectStore`]. + /// + /// [`ObjectStore`]: object_store::ObjectStore + pub fn with_parquet_file_reader_factory( + mut self, + parquet_file_reader_factory: Arc, + ) -> Self { + self.parquet_file_reader_factory = Some(parquet_file_reader_factory); + self + } + + /// Set optional schema adapter factory. + /// + /// [`SchemaAdapterFactory`] allows user to specify how fields from the + /// parquet file get mapped to that of the table schema. The default schema + /// adapter uses arrow's cast library to map the parquet fields to the table + /// schema. + pub fn with_schema_adapter_factory( + mut self, + schema_adapter_factory: Arc, + ) -> Self { + self.schema_adapter_factory = Some(schema_adapter_factory); + self + } + + /// Convenience: build an `Arc`d `ParquetExec` from this builder + pub fn build_arc(self) -> Arc { + Arc::new(self.build()) + } + + /// Build a [`ParquetExec`] + #[must_use] + pub fn build(self) -> ParquetExec { + let Self { + file_scan_config, + predicate, + metadata_size_hint, + table_parquet_options, + parquet_file_reader_factory, + schema_adapter_factory, + } = self; + + let base_config = file_scan_config; debug!("Creating ParquetExec, files: {:?}, projection {:?}, predicate: {:?}, limit: {:?}", base_config.file_groups, base_config.projection, predicate, base_config.limit); @@ -125,28 +436,22 @@ impl ParquetExec { }) .filter(|p| !p.always_true()); - let page_pruning_predicate = predicate.as_ref().and_then(|predicate_expr| { - match PagePruningPredicate::try_new(predicate_expr, file_schema.clone()) { - Ok(pruning_predicate) => Some(Arc::new(pruning_predicate)), - Err(e) => { - debug!( - "Could not create page pruning predicate for '{:?}': {}", - pruning_predicate, e - ); - predicate_creation_errors.add(1); - None - } - } - }); + let page_pruning_predicate = predicate + .as_ref() + .map(|predicate_expr| { + PagePruningAccessPlanFilter::new(predicate_expr, file_schema.clone()) + }) + .map(Arc::new); let (projected_schema, projected_statistics, projected_output_ordering) = base_config.project(); - let cache = Self::compute_properties( + + let cache = ParquetExec::compute_properties( projected_schema, &projected_output_ordering, &base_config, ); - Self { + ParquetExec { base_config, projected_statistics, metrics, @@ -154,9 +459,70 @@ impl ParquetExec { pruning_predicate, page_pruning_predicate, metadata_size_hint, - parquet_file_reader_factory: None, + parquet_file_reader_factory, cache, table_parquet_options, + schema_adapter_factory, + } + } +} + +impl ParquetExec { + /// Create a new Parquet reader execution plan provided file list and schema. + #[deprecated( + since = "39.0.0", + note = "use `ParquetExec::builder` or `ParquetExecBuilder`" + )] + pub fn new( + base_config: FileScanConfig, + predicate: Option>, + metadata_size_hint: Option, + table_parquet_options: TableParquetOptions, + ) -> Self { + let mut builder = + ParquetExecBuilder::new_with_options(base_config, table_parquet_options); + if let Some(predicate) = predicate { + builder = builder.with_predicate(predicate); + } + if let Some(metadata_size_hint) = metadata_size_hint { + builder = builder.with_metadata_size_hint(metadata_size_hint); + } + builder.build() + } + + /// Return a [`ParquetExecBuilder`]. + /// + /// See example on [`ParquetExec`] and [`ParquetExecBuilder`] for specifying + /// parquet table options. + pub fn builder(file_scan_config: FileScanConfig) -> ParquetExecBuilder { + ParquetExecBuilder::new(file_scan_config) + } + + /// Convert this `ParquetExec` into a builder for modification + pub fn into_builder(self) -> ParquetExecBuilder { + // list out fields so it is clear what is being dropped + // (note the fields which are dropped are re-created as part of calling + // `build` on the builder) + let Self { + base_config, + projected_statistics: _, + metrics: _, + predicate, + pruning_predicate: _, + page_pruning_predicate: _, + metadata_size_hint, + parquet_file_reader_factory, + cache: _, + table_parquet_options, + schema_adapter_factory, + } = self; + ParquetExecBuilder { + file_scan_config: base_config, + predicate, + metadata_size_hint, + table_parquet_options, + parquet_file_reader_factory, + schema_adapter_factory, } } @@ -180,13 +546,15 @@ impl ParquetExec { self.pruning_predicate.as_ref() } + /// return the optional file reader factory + pub fn parquet_file_reader_factory( + &self, + ) -> Option<&Arc> { + self.parquet_file_reader_factory.as_ref() + } + /// Optional user defined parquet file reader factory. /// - /// `ParquetFileReaderFactory` complements `TableProvider`, It enables users to provide custom - /// implementation for data access operations. - /// - /// If custom `ParquetFileReaderFactory` is provided, then data access operations will be routed - /// to this factory instead of `ObjectStore`. pub fn with_parquet_file_reader_factory( mut self, parquet_file_reader_factory: Arc, @@ -195,11 +563,24 @@ impl ParquetExec { self } - /// If true, any filter [`Expr`]s on the scan will converted to a - /// [`RowFilter`](parquet::arrow::arrow_reader::RowFilter) in the - /// `ParquetRecordBatchStream`. These filters are applied by the - /// parquet decoder to skip unecessairly decoding other columns - /// which would not pass the predicate. Defaults to false + /// return the optional schema adapter factory + pub fn schema_adapter_factory(&self) -> Option<&Arc> { + self.schema_adapter_factory.as_ref() + } + + /// Optional schema adapter factory. + /// + /// See documentation on [`ParquetExecBuilder::with_schema_adapter_factory`] + pub fn with_schema_adapter_factory( + mut self, + schema_adapter_factory: Arc, + ) -> Self { + self.schema_adapter_factory = Some(schema_adapter_factory); + self + } + + /// If true, the predicate will be used during the parquet scan. + /// Defaults to false /// /// [`Expr`]: datafusion_expr::Expr pub fn with_pushdown_filters(mut self, pushdown_filters: bool) -> Self { @@ -243,14 +624,24 @@ impl ParquetExec { } /// If enabled, the reader will read by the bloom filter - pub fn with_enable_bloom_filter(mut self, enable_bloom_filter: bool) -> Self { - self.table_parquet_options.global.bloom_filter_enabled = enable_bloom_filter; + pub fn with_bloom_filter_on_read(mut self, bloom_filter_on_read: bool) -> Self { + self.table_parquet_options.global.bloom_filter_on_read = bloom_filter_on_read; self } - /// Return the value described in [`Self::with_enable_bloom_filter`] - fn enable_bloom_filter(&self) -> bool { - self.table_parquet_options.global.bloom_filter_enabled + /// If enabled, the writer will write by the bloom filter + pub fn with_bloom_filter_on_write( + mut self, + enable_bloom_filter_on_write: bool, + ) -> Self { + self.table_parquet_options.global.bloom_filter_on_write = + enable_bloom_filter_on_write; + self + } + + /// Return the value described in [`Self::with_bloom_filter_on_read`] + fn bloom_filter_on_read(&self) -> bool { + self.table_parquet_options.global.bloom_filter_on_read } fn output_partitioning_helper(file_config: &FileScanConfig) -> Partitioning { @@ -273,7 +664,14 @@ impl ParquetExec { ) } - fn with_file_groups(mut self, file_groups: Vec>) -> Self { + /// Updates the file groups to read and recalculates the output partitioning + /// + /// Note this function does not update statistics or other properties + /// that depend on the file groups. + fn with_file_groups_and_update_partitioning( + mut self, + file_groups: Vec>, + ) -> Self { self.base_config.file_groups = file_groups; // Changing file groups may invalidate output partitioning. Update it also let output_partitioning = Self::output_partitioning_helper(&self.base_config); @@ -300,14 +698,16 @@ impl DisplayAs for ParquetExec { .pruning_predicate .as_ref() .map(|pre| { + let mut guarantees = pre + .literal_guarantees() + .iter() + .map(|item| format!("{}", item)) + .collect_vec(); + guarantees.sort(); format!( ", pruning_predicate={}, required_guarantees=[{}]", pre.predicate_expr(), - pre.literal_guarantees() - .iter() - .map(|item| format!("{}", item)) - .collect_vec() - .join(", ") + guarantees.join(", ") ) }) .unwrap_or_default(); @@ -334,7 +734,7 @@ impl ExecutionPlan for ParquetExec { &self.cache } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { // this is a leaf node and has no children vec![] } @@ -364,7 +764,8 @@ impl ExecutionPlan for ParquetExec { let mut new_plan = self.clone(); if let Some(repartitioned_file_groups) = repartitioned_file_groups_option { - new_plan = new_plan.with_file_groups(repartitioned_file_groups); + new_plan = new_plan + .with_file_groups_and_update_partitioning(repartitioned_file_groups); } Ok(Some(Arc::new(new_plan))) } @@ -374,10 +775,12 @@ impl ExecutionPlan for ParquetExec { partition_index: usize, ctx: Arc, ) -> Result { - let projection = match self.base_config.file_column_projection_indices() { - Some(proj) => proj, - None => (0..self.base_config.file_schema.fields().len()).collect(), - }; + let projection = self + .base_config + .file_column_projection_indices() + .unwrap_or_else(|| { + (0..self.base_config.file_schema.fields().len()).collect() + }); let parquet_file_reader_factory = self .parquet_file_reader_factory @@ -387,11 +790,15 @@ impl ExecutionPlan for ParquetExec { ctx.runtime_env() .object_store(&self.base_config.object_store_url) .map(|store| { - Arc::new(DefaultParquetFileReaderFactory::new(store)) - as Arc + Arc::new(DefaultParquetFileReaderFactory::new(store)) as _ }) })?; + let schema_adapter_factory = self + .schema_adapter_factory + .clone() + .unwrap_or_else(|| Arc::new(DefaultSchemaAdapterFactory)); + let opener = ParquetOpener { partition_index, projection: Arc::from(projection), @@ -407,7 +814,8 @@ impl ExecutionPlan for ParquetExec { pushdown_filters: self.pushdown_filters(), reorder_filters: self.reorder_filters(), enable_page_index: self.enable_page_index(), - enable_bloom_filter: self.enable_bloom_filter(), + enable_bloom_filter: self.bloom_filter_on_read(), + schema_adapter_factory, }; let stream = @@ -421,178 +829,46 @@ impl ExecutionPlan for ParquetExec { } fn statistics(&self) -> Result { - Ok(self.projected_statistics.clone()) + // When filters are pushed down, we have no way of knowing the exact statistics. + // Note that pruning predicate is also a kind of filter pushdown. + // (bloom filters use `pruning_predicate` too) + let stats = if self.pruning_predicate.is_some() + || self.page_pruning_predicate.is_some() + || (self.predicate.is_some() && self.pushdown_filters()) + { + self.projected_statistics.clone().to_inexact() + } else { + self.projected_statistics.clone() + }; + Ok(stats) } -} - -/// Implements [`FileOpener`] for a parquet file -struct ParquetOpener { - partition_index: usize, - projection: Arc<[usize]>, - batch_size: usize, - limit: Option, - predicate: Option>, - pruning_predicate: Option>, - page_pruning_predicate: Option>, - table_schema: SchemaRef, - metadata_size_hint: Option, - metrics: ExecutionPlanMetricsSet, - parquet_file_reader_factory: Arc, - pushdown_filters: bool, - reorder_filters: bool, - enable_page_index: bool, - enable_bloom_filter: bool, -} - -impl FileOpener for ParquetOpener { - fn open(&self, file_meta: FileMeta) -> Result { - let file_range = file_meta.range.clone(); - - let file_metrics = ParquetFileMetrics::new( - self.partition_index, - file_meta.location().as_ref(), - &self.metrics, - ); - - let reader: Box = - self.parquet_file_reader_factory.create_reader( - self.partition_index, - file_meta, - self.metadata_size_hint, - &self.metrics, - )?; - - let batch_size = self.batch_size; - let projection = self.projection.clone(); - let projected_schema = SchemaRef::from(self.table_schema.project(&projection)?); - let schema_adapter = SchemaAdapter::new(projected_schema); - let predicate = self.predicate.clone(); - let pruning_predicate = self.pruning_predicate.clone(); - let page_pruning_predicate = self.page_pruning_predicate.clone(); - let table_schema = self.table_schema.clone(); - let reorder_predicates = self.reorder_filters; - let pushdown_filters = self.pushdown_filters; - let enable_page_index = should_enable_page_index( - self.enable_page_index, - &self.page_pruning_predicate, - ); - let enable_bloom_filter = self.enable_bloom_filter; - let limit = self.limit; - - Ok(Box::pin(async move { - let options = ArrowReaderOptions::new().with_page_index(enable_page_index); - let mut builder = - ParquetRecordBatchStreamBuilder::new_with_options(reader, options) - .await?; - - let file_schema = builder.schema().clone(); - - let (schema_mapping, adapted_projections) = - schema_adapter.map_schema(&file_schema)?; - // let predicate = predicate.map(|p| reassign_predicate_columns(p, builder.schema(), true)).transpose()?; - let mask = ProjectionMask::roots( - builder.parquet_schema(), - adapted_projections.iter().cloned(), - ); - - // Filter pushdown: evaluate predicates during scan - if let Some(predicate) = pushdown_filters.then_some(predicate).flatten() { - let row_filter = row_filter::build_row_filter( - &predicate, - &file_schema, - &table_schema, - builder.metadata(), - reorder_predicates, - &file_metrics, - ); - - match row_filter { - Ok(Some(filter)) => { - builder = builder.with_row_filter(filter); - } - Ok(None) => {} - Err(e) => { - debug!( - "Ignoring error building row filter for '{:?}': {}", - predicate, e - ); - } - }; - }; - - // Row group pruning by statistics: attempt to skip entire row_groups - // using metadata on the row groups - let file_metadata = builder.metadata().clone(); - let predicate = pruning_predicate.as_ref().map(|p| p.as_ref()); - let mut row_groups = row_groups::prune_row_groups_by_statistics( - &file_schema, - builder.parquet_schema(), - file_metadata.row_groups(), - file_range, - predicate, - &file_metrics, - ); - - // Bloom filter pruning: if bloom filters are enabled and then attempt to skip entire row_groups - // using bloom filters on the row groups - if enable_bloom_filter && !row_groups.is_empty() { - if let Some(predicate) = predicate { - row_groups = row_groups::prune_row_groups_by_bloom_filters( - &file_schema, - &mut builder, - &row_groups, - file_metadata.row_groups(), - predicate, - &file_metrics, - ) - .await; - } - } - - // page index pruning: if all data on individual pages can - // be ruled using page metadata, rows from other columns - // with that range can be skipped as well - if enable_page_index && !row_groups.is_empty() { - if let Some(p) = page_pruning_predicate { - let pruned = p.prune( - &file_schema, - builder.parquet_schema(), - &row_groups, - file_metadata.as_ref(), - &file_metrics, - )?; - if let Some(row_selection) = pruned { - builder = builder.with_row_selection(row_selection); - } - } - } - - if let Some(limit) = limit { - builder = builder.with_limit(limit) - } - - let stream = builder - .with_projection(mask) - .with_batch_size(batch_size) - .with_row_groups(row_groups) - .build()?; + fn fetch(&self) -> Option { + self.base_config.limit + } - let adapted = stream - .map_err(|e| ArrowError::ExternalError(Box::new(e))) - .map(move |maybe_batch| { - maybe_batch - .and_then(|b| schema_mapping.map_batch(b).map_err(Into::into)) - }); + fn with_fetch(&self, limit: Option) -> Option> { + let new_config = self.base_config.clone().with_limit(limit); - Ok(adapted.boxed()) + Some(Arc::new(Self { + base_config: new_config, + projected_statistics: self.projected_statistics.clone(), + metrics: self.metrics.clone(), + predicate: self.predicate.clone(), + pruning_predicate: self.pruning_predicate.clone(), + page_pruning_predicate: self.page_pruning_predicate.clone(), + metadata_size_hint: self.metadata_size_hint, + parquet_file_reader_factory: self.parquet_file_reader_factory.clone(), + cache: self.cache.clone(), + table_parquet_options: self.table_parquet_options.clone(), + schema_adapter_factory: self.schema_adapter_factory.clone(), })) } } fn should_enable_page_index( enable_page_index: bool, - page_pruning_predicate: &Option>, + page_pruning_predicate: &Option>, ) -> bool { enable_page_index && page_pruning_predicate.is_some() @@ -602,166 +878,6 @@ fn should_enable_page_index( .unwrap_or(false) } -/// Factory of parquet file readers. -/// -/// Provides means to implement custom data access interface. -pub trait ParquetFileReaderFactory: Debug + Send + Sync + 'static { - /// Provides `AsyncFileReader` over parquet file specified in `FileMeta` - fn create_reader( - &self, - partition_index: usize, - file_meta: FileMeta, - metadata_size_hint: Option, - metrics: &ExecutionPlanMetricsSet, - ) -> Result>; -} - -/// Default parquet reader factory. -#[derive(Debug)] -pub struct DefaultParquetFileReaderFactory { - store: Arc, -} - -impl DefaultParquetFileReaderFactory { - /// Create a factory. - pub fn new(store: Arc) -> Self { - Self { store } - } -} - -/// Implements [`AsyncFileReader`] for a parquet file in object storage -pub(crate) struct ParquetFileReader { - file_metrics: ParquetFileMetrics, - inner: ParquetObjectReader, -} - -impl AsyncFileReader for ParquetFileReader { - fn get_bytes( - &mut self, - range: Range, - ) -> BoxFuture<'_, parquet::errors::Result> { - self.file_metrics.bytes_scanned.add(range.end - range.start); - self.inner.get_bytes(range) - } - - fn get_byte_ranges( - &mut self, - ranges: Vec>, - ) -> BoxFuture<'_, parquet::errors::Result>> - where - Self: Send, - { - let total = ranges.iter().map(|r| r.end - r.start).sum(); - self.file_metrics.bytes_scanned.add(total); - self.inner.get_byte_ranges(ranges) - } - - fn get_metadata( - &mut self, - ) -> BoxFuture<'_, parquet::errors::Result>> { - self.inner.get_metadata() - } -} - -impl ParquetFileReaderFactory for DefaultParquetFileReaderFactory { - fn create_reader( - &self, - partition_index: usize, - file_meta: FileMeta, - metadata_size_hint: Option, - metrics: &ExecutionPlanMetricsSet, - ) -> Result> { - let file_metrics = ParquetFileMetrics::new( - partition_index, - file_meta.location().as_ref(), - metrics, - ); - let store = Arc::clone(&self.store); - let mut inner = ParquetObjectReader::new(store, file_meta.object_meta); - - if let Some(hint) = metadata_size_hint { - inner = inner.with_footer_size_hint(hint) - }; - - Ok(Box::new(ParquetFileReader { - inner, - file_metrics, - })) - } -} - -/// Executes a query and writes the results to a partitioned Parquet file. -pub async fn plan_to_parquet( - task_ctx: Arc, - plan: Arc, - path: impl AsRef, - writer_properties: Option, -) -> Result<()> { - let path = path.as_ref(); - let parsed = ListingTableUrl::parse(path)?; - let object_store_url = parsed.object_store(); - let store = task_ctx.runtime_env().object_store(&object_store_url)?; - let mut join_set = JoinSet::new(); - for i in 0..plan.output_partitioning().partition_count() { - let plan: Arc = plan.clone(); - let filename = format!("{}/part-{i}.parquet", parsed.prefix()); - let file = Path::parse(filename)?; - let propclone = writer_properties.clone(); - - let storeref = store.clone(); - let buf_writer = BufWriter::new(storeref, file.clone()); - let mut stream = plan.execute(i, task_ctx.clone())?; - join_set.spawn(async move { - let mut writer = - AsyncArrowWriter::try_new(buf_writer, plan.schema(), propclone)?; - while let Some(next_batch) = stream.next().await { - let batch = next_batch?; - writer.write(&batch).await?; - } - writer - .close() - .await - .map_err(DataFusionError::from) - .map(|_| ()) - }); - } - - while let Some(result) = join_set.join_next().await { - match result { - Ok(res) => res?, - Err(e) => { - if e.is_panic() { - std::panic::resume_unwind(e.into_panic()); - } else { - unreachable!(); - } - } - } - } - - Ok(()) -} - -// Convert parquet column schema to arrow data type, and just consider the -// decimal data type. -pub(crate) fn parquet_to_arrow_decimal_type( - parquet_column: &ColumnDescriptor, -) -> Option { - let type_ptr = parquet_column.self_type_ptr(); - match type_ptr.get_basic_info().logical_type() { - Some(LogicalType::Decimal { scale, precision }) => { - Some(DataType::Decimal128(precision as u8, scale as i8)) - } - _ => match type_ptr.get_basic_info().converted_type() { - ConvertedType::DECIMAL => Some(DataType::Decimal128( - type_ptr.get_precision() as u8, - type_ptr.get_scale() as i8, - )), - _ => None, - }, - } -} - #[cfg(test)] mod tests { // See also `parquet_exec` integration test @@ -789,19 +905,21 @@ mod tests { ArrayRef, Date64Array, Int32Array, Int64Array, Int8Array, StringArray, StructArray, }; - use arrow::datatypes::{Field, Schema, SchemaBuilder}; use arrow::record_batch::RecordBatch; - use arrow_schema::Fields; - use datafusion_common::{assert_contains, FileType, GetExt, ScalarValue, ToDFSchema}; - use datafusion_expr::execution_props::ExecutionProps; + use arrow_schema::{DataType, Fields}; + use datafusion_common::{assert_contains, ScalarValue}; use datafusion_expr::{col, lit, when, Expr}; - use datafusion_physical_expr::create_physical_expr; + use datafusion_physical_expr::planner::logical2physical; + use datafusion_physical_plan::ExecutionPlanProperties; use chrono::{TimeZone, Utc}; + use futures::StreamExt; use object_store::local::LocalFileSystem; + use object_store::path::Path; use object_store::ObjectMeta; use parquet::arrow::ArrowWriter; + use parquet::file::properties::WriterProperties; use tempfile::TempDir; use url::Url; @@ -885,28 +1003,23 @@ mod tests { // files with multiple pages let multi_page = page_index_predicate; let (meta, _files) = store_parquet(batches, multi_page).await.unwrap(); - let file_groups = meta.into_iter().map(Into::into).collect(); + let file_group = meta.into_iter().map(Into::into).collect(); // set up predicate (this is normally done by a layer higher up) let predicate = predicate.map(|p| logical2physical(&p, &file_schema)); // prepare the scan - let mut parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: vec![file_groups], - statistics: Statistics::new_unknown(&file_schema), - file_schema, - projection, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - }, - predicate, - None, - Default::default(), + let mut builder = ParquetExec::builder( + FileScanConfig::new(ObjectStoreUrl::local_filesystem(), file_schema) + .with_file_group(file_group) + .with_projection(projection), ); + if let Some(predicate) = predicate { + builder = builder.with_predicate(predicate); + } + let mut parquet_exec = builder.build(); + if pushdown_predicate { parquet_exec = parquet_exec .with_pushdown_filters(true) @@ -956,7 +1069,7 @@ mod tests { let tmp_dir = TempDir::new()?; let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?); let local_url = Url::parse("file://local").unwrap(); - ctx.runtime_env().register_object_store(&local_url, local); + ctx.register_object_store(&local_url, local); let options = CsvReadOptions::default() .schema_infer_max_records(2) @@ -1157,7 +1270,8 @@ mod tests { assert_batches_sorted_eq!(expected, &rt.batches.unwrap()); let metrics = rt.parquet_exec.metrics().unwrap(); // Note there are were 6 rows in total (across three batches) - assert_eq!(get_value(&metrics, "pushdown_rows_filtered"), 4); + assert_eq!(get_value(&metrics, "pushdown_rows_pruned"), 4); + assert_eq!(get_value(&metrics, "pushdown_rows_matched"), 2); } #[tokio::test] @@ -1309,7 +1423,8 @@ mod tests { assert_batches_sorted_eq!(expected, &rt.batches.unwrap()); let metrics = rt.parquet_exec.metrics().unwrap(); // Note there are were 6 rows in total (across three batches) - assert_eq!(get_value(&metrics, "pushdown_rows_filtered"), 5); + assert_eq!(get_value(&metrics, "pushdown_rows_pruned"), 5); + assert_eq!(get_value(&metrics, "pushdown_rows_matched"), 1); } #[tokio::test] @@ -1383,7 +1498,8 @@ mod tests { // There are 4 rows pruned in each of batch2, batch3, and // batch4 for a total of 12. batch1 had no pruning as c2 was // filled in as null - assert_eq!(get_value(&metrics, "page_index_rows_filtered"), 12); + assert_eq!(get_value(&metrics, "page_index_rows_pruned"), 12); + assert_eq!(get_value(&metrics, "page_index_rows_matched"), 6); } #[tokio::test] @@ -1538,6 +1654,7 @@ mod tests { object_meta: meta.clone(), partition_values: vec![], range: Some(FileRange { start, end }), + statistics: None, extensions: None, } } @@ -1548,21 +1665,11 @@ mod tests { expected_row_num: Option, file_schema: SchemaRef, ) -> Result<()> { - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups, - statistics: Statistics::new_unknown(&file_schema), - file_schema, - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - }, - None, - None, - Default::default(), - ); + let parquet_exec = ParquetExec::builder( + FileScanConfig::new(ObjectStoreUrl::local_filesystem(), file_schema) + .with_file_groups(file_groups), + ) + .build(); assert_eq!( parquet_exec .properties() @@ -1639,6 +1746,7 @@ mod tests { ), ], range: None, + statistics: None, extensions: None, }; @@ -1657,16 +1765,12 @@ mod tests { ), ]); - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url, - file_groups: vec![vec![partitioned_file]], - file_schema: schema.clone(), - statistics: Statistics::new_unknown(&schema), + let parquet_exec = ParquetExec::builder( + FileScanConfig::new(object_store_url, schema.clone()) + .with_file(partitioned_file) // file has 10 cols so index 12 should be month and 13 should be day - projection: Some(vec![0, 1, 2, 12, 13]), - limit: None, - table_partition_cols: vec![ + .with_projection(Some(vec![0, 1, 2, 12, 13])) + .with_table_partition_cols(vec![ Field::new("year", DataType::Utf8, false), Field::new("month", DataType::UInt8, false), Field::new( @@ -1677,13 +1781,9 @@ mod tests { ), false, ), - ], - output_ordering: vec![], - }, - None, - None, - Default::default(), - ); + ]), + ) + .build(); assert_eq!( parquet_exec.cache.output_partitioning().partition_count(), 1 @@ -1733,24 +1833,16 @@ mod tests { }, partition_values: vec![], range: None, + statistics: None, extensions: None, }; - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: vec![vec![partitioned_file]], - file_schema: Arc::new(Schema::empty()), - statistics: Statistics::new_unknown(&Schema::empty()), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - }, - None, - None, - Default::default(), - ); + let file_schema = Arc::new(Schema::empty()); + let parquet_exec = ParquetExec::builder( + FileScanConfig::new(ObjectStoreUrl::local_filesystem(), file_schema) + .with_file(partitioned_file), + ) + .build(); let mut results = parquet_exec.execute(0, state.task_ctx())?; let batch = results.next().await.unwrap(); @@ -1794,7 +1886,8 @@ mod tests { "+-----+" ]; assert_batches_sorted_eq!(expected, &rt.batches.unwrap()); - assert_eq!(get_value(&metrics, "page_index_rows_filtered"), 4); + assert_eq!(get_value(&metrics, "page_index_rows_pruned"), 4); + assert_eq!(get_value(&metrics, "page_index_rows_matched"), 2); assert!( get_value(&metrics, "page_index_eval_time") > 0, "no eval time in metrics: {metrics:#?}" @@ -1818,6 +1911,26 @@ mod tests { create_batch(vec![("c1", c1.clone())]) } + /// Returns a int64 array with contents: + /// "[-1, 1, null, 2, 3, null, null]" + fn int64_batch() -> RecordBatch { + let contents: ArrayRef = Arc::new(Int64Array::from(vec![ + Some(-1), + Some(1), + None, + Some(2), + Some(3), + None, + None, + ])); + + create_batch(vec![ + ("a", contents.clone()), + ("b", contents.clone()), + ("c", contents.clone()), + ]) + } + #[tokio::test] async fn parquet_exec_metrics() { // batch1: c1(string) @@ -1843,10 +1956,19 @@ mod tests { // pushdown predicates have eliminated all 4 bar rows and the // null row for 5 rows total - assert_eq!(get_value(&metrics, "pushdown_rows_filtered"), 5); + assert_eq!(get_value(&metrics, "pushdown_rows_pruned"), 5); + assert_eq!(get_value(&metrics, "pushdown_rows_matched"), 2); assert!( - get_value(&metrics, "pushdown_eval_time") > 0, - "no eval time in metrics: {metrics:#?}" + get_value(&metrics, "row_pushdown_eval_time") > 0, + "no pushdown eval time in metrics: {metrics:#?}" + ); + assert!( + get_value(&metrics, "statistics_eval_time") > 0, + "no statistics eval time in metrics: {metrics:#?}" + ); + assert!( + get_value(&metrics, "bloom_filter_eval_time") > 0, + "no Bloom Filter eval time in metrics: {metrics:#?}" ); } @@ -1883,6 +2005,93 @@ mod tests { assert_contains!(&display, "projection=[c1]"); } + #[tokio::test] + async fn parquet_exec_display_deterministic() { + // batches: a(int64), b(int64), c(int64) + let batches = int64_batch(); + + fn extract_required_guarantees(s: &str) -> Option<&str> { + s.split("required_guarantees=").nth(1) + } + + // Ensuring that the required_guarantees remain consistent across every display plan of the filter conditions + for _ in 0..100 { + // c = 1 AND b = 1 AND a = 1 + let filter0 = col("c") + .eq(lit(1)) + .and(col("b").eq(lit(1))) + .and(col("a").eq(lit(1))); + + let rt0 = RoundTrip::new() + .with_predicate(filter0) + .with_pushdown_predicate() + .round_trip(vec![batches.clone()]) + .await; + + let pruning_predicate = &rt0.parquet_exec.pruning_predicate; + assert!(pruning_predicate.is_some()); + + let display0 = displayable(rt0.parquet_exec.as_ref()) + .indent(true) + .to_string(); + + let guarantees0: &str = extract_required_guarantees(&display0) + .expect("Failed to extract required_guarantees"); + // Compare only the required_guarantees part (Because the file_groups part will not be the same) + assert_eq!( + guarantees0.trim(), + "[a in (1), b in (1), c in (1)]", + "required_guarantees don't match" + ); + } + + // c = 1 AND a = 1 AND b = 1 + let filter1 = col("c") + .eq(lit(1)) + .and(col("a").eq(lit(1))) + .and(col("b").eq(lit(1))); + + let rt1 = RoundTrip::new() + .with_predicate(filter1) + .with_pushdown_predicate() + .round_trip(vec![batches.clone()]) + .await; + + // b = 1 AND a = 1 AND c = 1 + let filter2 = col("b") + .eq(lit(1)) + .and(col("a").eq(lit(1))) + .and(col("c").eq(lit(1))); + + let rt2 = RoundTrip::new() + .with_predicate(filter2) + .with_pushdown_predicate() + .round_trip(vec![batches]) + .await; + + // should have a pruning predicate + let pruning_predicate = &rt1.parquet_exec.pruning_predicate; + assert!(pruning_predicate.is_some()); + let pruning_predicate = &rt2.parquet_exec.pruning_predicate; + assert!(pruning_predicate.is_some()); + + // convert to explain plan form + let display1 = displayable(rt1.parquet_exec.as_ref()) + .indent(true) + .to_string(); + let display2 = displayable(rt2.parquet_exec.as_ref()) + .indent(true) + .to_string(); + + let guarantees1 = extract_required_guarantees(&display1) + .expect("Failed to extract required_guarantees"); + let guarantees2 = extract_required_guarantees(&display2) + .expect("Failed to extract required_guarantees"); + + // Compare only the required_guarantees part (Because the predicate part will not be the same) + assert_eq!(guarantees1, guarantees2, "required_guarantees don't match"); + } + #[tokio::test] async fn parquet_exec_has_no_pruning_predicate_if_can_not_prune() { // batch1: c1(string) @@ -2009,16 +2218,16 @@ mod tests { // register a local file system object store for /tmp directory let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?); let local_url = Url::parse("file://local").unwrap(); - ctx.runtime_env().register_object_store(&local_url, local); + ctx.register_object_store(&local_url, local); // Configure listing options let file_format = ParquetFormat::default().with_enable_pruning(true); let listing_options = ListingOptions::new(Arc::new(file_format)) - .with_file_extension(FileType::PARQUET.get_ext()); + .with_file_extension(ParquetFormat::default().get_ext()); // execute a simple query and write the results to parquet let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out"; - std::fs::create_dir(&out_dir).unwrap(); + fs::create_dir(&out_dir).unwrap(); let df = ctx.sql("SELECT c1, c2 FROM test").await?; let schema: Schema = df.schema().into(); // Register a listing table - this will use all files in the directory as data sources @@ -2079,12 +2288,6 @@ mod tests { Ok(()) } - fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { - let df_schema = schema.clone().to_dfschema().unwrap(); - let execution_props = ExecutionProps::new(); - create_physical_expr(expr, &df_schema, &execution_props).unwrap() - } - #[tokio::test] async fn test_struct_filter_parquet() -> Result<()> { let tmp_dir = TempDir::new()?; @@ -2109,6 +2312,36 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_struct_filter_parquet_with_view_types() -> Result<()> { + let tmp_dir = TempDir::new().unwrap(); + let path = tmp_dir.path().to_str().unwrap().to_string() + "/test.parquet"; + write_file(&path); + + let ctx = SessionContext::new(); + + let mut options = TableParquetOptions::default(); + options.global.schema_force_view_types = true; + let opt = + ListingOptions::new(Arc::new(ParquetFormat::default().with_options(options))); + + ctx.register_listing_table("base_table", path, opt, None, None) + .await + .unwrap(); + let sql = "select * from base_table where name='test02'"; + let batch = ctx.sql(sql).await.unwrap().collect().await.unwrap(); + assert_eq!(batch.len(), 1); + let expected = [ + "+---------------------+----+--------+", + "| struct | id | name |", + "+---------------------+----+--------+", + "| {id: 4, name: aaa2} | 2 | test02 |", + "+---------------------+----+--------+", + ]; + crate::assert_batches_eq!(expected, &batch); + Ok(()) + } + fn write_file(file: &String) { let struct_fields = Fields::from(vec![ Field::new("id", DataType::Int64, false), diff --git a/datafusion/core/src/datasource/physical_plan/parquet/opener.rs b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs new file mode 100644 index 000000000000..4990cb4dd735 --- /dev/null +++ b/datafusion/core/src/datasource/physical_plan/parquet/opener.rs @@ -0,0 +1,299 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ParquetOpener`] for opening Parquet files + +use crate::datasource::file_format::{ + coerce_file_schema_to_string_type, coerce_file_schema_to_view_type, +}; +use crate::datasource::physical_plan::parquet::page_filter::PagePruningAccessPlanFilter; +use crate::datasource::physical_plan::parquet::row_group_filter::RowGroupAccessPlanFilter; +use crate::datasource::physical_plan::parquet::{ + row_filter, should_enable_page_index, ParquetAccessPlan, +}; +use crate::datasource::physical_plan::{ + FileMeta, FileOpenFuture, FileOpener, ParquetFileMetrics, ParquetFileReaderFactory, +}; +use crate::datasource::schema_adapter::SchemaAdapterFactory; +use crate::physical_optimizer::pruning::PruningPredicate; +use arrow_schema::{ArrowError, SchemaRef}; +use datafusion_common::{exec_err, Result}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; +use futures::{StreamExt, TryStreamExt}; +use log::debug; +use parquet::arrow::arrow_reader::{ArrowReaderMetadata, ArrowReaderOptions}; +use parquet::arrow::async_reader::AsyncFileReader; +use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask}; +use std::sync::Arc; + +/// Implements [`FileOpener`] for a parquet file +pub(super) struct ParquetOpener { + /// Execution partition index + pub partition_index: usize, + /// Column indexes in `table_schema` needed by the query + pub projection: Arc<[usize]>, + /// Target number of rows in each output RecordBatch + pub batch_size: usize, + /// Optional limit on the number of rows to read + pub limit: Option, + /// Optional predicate to apply during the scan + pub predicate: Option>, + /// Optional pruning predicate applied to row group statistics + pub pruning_predicate: Option>, + /// Optional pruning predicate applied to data page statistics + pub page_pruning_predicate: Option>, + /// Schema of the output table + pub table_schema: SchemaRef, + /// Optional hint for how large the initial request to read parquet metadata + /// should be + pub metadata_size_hint: Option, + /// Metrics for reporting + pub metrics: ExecutionPlanMetricsSet, + /// Factory for instantiating parquet reader + pub parquet_file_reader_factory: Arc, + /// Should the filters be evaluated during the parquet scan using + /// [`DataFusionArrowPredicate`](row_filter::DatafusionArrowPredicate)? + pub pushdown_filters: bool, + /// Should the filters be reordered to optimize the scan? + pub reorder_filters: bool, + /// Should the page index be read from parquet files, if present, to skip + /// data pages + pub enable_page_index: bool, + /// Should the bloom filter be read from parquet, if present, to skip row + /// groups + pub enable_bloom_filter: bool, + /// Schema adapter factory + pub schema_adapter_factory: Arc, +} + +impl FileOpener for ParquetOpener { + fn open(&self, file_meta: FileMeta) -> Result { + let file_range = file_meta.range.clone(); + let extensions = file_meta.extensions.clone(); + let file_name = file_meta.location().to_string(); + let file_metrics = + ParquetFileMetrics::new(self.partition_index, &file_name, &self.metrics); + + let mut reader: Box = + self.parquet_file_reader_factory.create_reader( + self.partition_index, + file_meta, + self.metadata_size_hint, + &self.metrics, + )?; + + let batch_size = self.batch_size; + + let projected_schema = + SchemaRef::from(self.table_schema.project(&self.projection)?); + let schema_adapter = self + .schema_adapter_factory + .create(projected_schema, self.table_schema.clone()); + let predicate = self.predicate.clone(); + let pruning_predicate = self.pruning_predicate.clone(); + let page_pruning_predicate = self.page_pruning_predicate.clone(); + let table_schema = self.table_schema.clone(); + let reorder_predicates = self.reorder_filters; + let pushdown_filters = self.pushdown_filters; + let enable_page_index = should_enable_page_index( + self.enable_page_index, + &self.page_pruning_predicate, + ); + let enable_bloom_filter = self.enable_bloom_filter; + let limit = self.limit; + + Ok(Box::pin(async move { + let options = ArrowReaderOptions::new().with_page_index(enable_page_index); + + let mut metadata_timer = file_metrics.metadata_load_time.timer(); + let metadata = + ArrowReaderMetadata::load_async(&mut reader, options.clone()).await?; + let mut schema = Arc::clone(metadata.schema()); + + if let Some(merged) = + coerce_file_schema_to_string_type(&table_schema, &schema) + { + schema = Arc::new(merged); + } + + // read with view types + if let Some(merged) = coerce_file_schema_to_view_type(&table_schema, &schema) + { + schema = Arc::new(merged); + } + + let options = ArrowReaderOptions::new() + .with_page_index(enable_page_index) + .with_schema(Arc::clone(&schema)); + let metadata = + ArrowReaderMetadata::try_new(Arc::clone(metadata.metadata()), options)?; + + metadata_timer.stop(); + + let mut builder = + ParquetRecordBatchStreamBuilder::new_with_metadata(reader, metadata); + + let file_schema = Arc::clone(builder.schema()); + + let (schema_mapping, adapted_projections) = + schema_adapter.map_schema(&file_schema)?; + + let mask = ProjectionMask::roots( + builder.parquet_schema(), + adapted_projections.iter().cloned(), + ); + + // Filter pushdown: evaluate predicates during scan + if let Some(predicate) = pushdown_filters.then_some(predicate).flatten() { + let row_filter = row_filter::build_row_filter( + &predicate, + &file_schema, + &table_schema, + builder.metadata(), + reorder_predicates, + &file_metrics, + Arc::clone(&schema_mapping), + ); + + match row_filter { + Ok(Some(filter)) => { + builder = builder.with_row_filter(filter); + } + Ok(None) => {} + Err(e) => { + debug!( + "Ignoring error building row filter for '{:?}': {}", + predicate, e + ); + } + }; + }; + + // Determine which row groups to actually read. The idea is to skip + // as many row groups as possible based on the metadata and query + let file_metadata = Arc::clone(builder.metadata()); + let predicate = pruning_predicate.as_ref().map(|p| p.as_ref()); + let rg_metadata = file_metadata.row_groups(); + // track which row groups to actually read + let access_plan = + create_initial_plan(&file_name, extensions, rg_metadata.len())?; + let mut row_groups = RowGroupAccessPlanFilter::new(access_plan); + // if there is a range restricting what parts of the file to read + if let Some(range) = file_range.as_ref() { + row_groups.prune_by_range(rg_metadata, range); + } + // If there is a predicate that can be evaluated against the metadata + if let Some(predicate) = predicate.as_ref() { + row_groups.prune_by_statistics( + &file_schema, + builder.parquet_schema(), + rg_metadata, + predicate, + &file_metrics, + ); + + if enable_bloom_filter && !row_groups.is_empty() { + row_groups + .prune_by_bloom_filters( + &file_schema, + &mut builder, + predicate, + &file_metrics, + ) + .await; + } + } + + let mut access_plan = row_groups.build(); + + // page index pruning: if all data on individual pages can + // be ruled using page metadata, rows from other columns + // with that range can be skipped as well + if enable_page_index && !access_plan.is_empty() { + if let Some(p) = page_pruning_predicate { + access_plan = p.prune_plan_with_page_index( + access_plan, + &file_schema, + builder.parquet_schema(), + file_metadata.as_ref(), + &file_metrics, + ); + } + } + + let row_group_indexes = access_plan.row_group_indexes(); + if let Some(row_selection) = + access_plan.into_overall_row_selection(rg_metadata)? + { + builder = builder.with_row_selection(row_selection); + } + + if let Some(limit) = limit { + builder = builder.with_limit(limit) + } + + let stream = builder + .with_projection(mask) + .with_batch_size(batch_size) + .with_row_groups(row_group_indexes) + .build()?; + + let adapted = stream + .map_err(|e| ArrowError::ExternalError(Box::new(e))) + .map(move |maybe_batch| { + maybe_batch + .and_then(|b| schema_mapping.map_batch(b).map_err(Into::into)) + }); + + Ok(adapted.boxed()) + })) + } +} + +/// Return the initial [`ParquetAccessPlan`] +/// +/// If the user has supplied one as an extension, use that +/// otherwise return a plan that scans all row groups +/// +/// Returns an error if an invalid `ParquetAccessPlan` is provided +/// +/// Note: file_name is only used for error messages +fn create_initial_plan( + file_name: &str, + extensions: Option>, + row_group_count: usize, +) -> Result { + if let Some(extensions) = extensions { + if let Some(access_plan) = extensions.downcast_ref::() { + let plan_len = access_plan.len(); + if plan_len != row_group_count { + return exec_err!( + "Invalid ParquetAccessPlan for {file_name}. Specified {plan_len} row groups, but file has {row_group_count}" + ); + } + + // check row group count matches the plan + return Ok(access_plan.clone()); + } else { + debug!("ParquetExec Ignoring unknown extension specified for {file_name}"); + } + } + + // default to scanning all row groups + Ok(ParquetAccessPlan::new_all(row_group_count)) +} diff --git a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs index 402cc106492e..ced07de974f6 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs @@ -17,40 +17,32 @@ //! Contains code to filter entire pages -use arrow::array::{ - BooleanArray, Decimal128Array, Float32Array, Float64Array, Int32Array, Int64Array, - StringArray, -}; -use arrow::datatypes::DataType; -use arrow::{array::ArrayRef, datatypes::SchemaRef, error::ArrowError}; +use super::metrics::ParquetFileMetrics; +use crate::datasource::physical_plan::parquet::ParquetAccessPlan; +use crate::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; +use arrow::array::BooleanArray; +use arrow::{array::ArrayRef, datatypes::SchemaRef}; use arrow_schema::Schema; -use datafusion_common::{DataFusionError, Result, ScalarValue}; -use datafusion_physical_expr::expressions::Column; +use datafusion_common::ScalarValue; use datafusion_physical_expr::{split_conjunction, PhysicalExpr}; use log::{debug, trace}; -use parquet::schema::types::{ColumnDescriptor, SchemaDescriptor}; +use parquet::arrow::arrow_reader::statistics::StatisticsConverter; +use parquet::file::metadata::{ParquetColumnIndex, ParquetOffsetIndex}; +use parquet::format::PageLocation; +use parquet::schema::types::SchemaDescriptor; use parquet::{ arrow::arrow_reader::{RowSelection, RowSelector}, - errors::ParquetError, - file::{ - metadata::{ParquetMetaData, RowGroupMetaData}, - page_index::index::Index, - }, - format::PageLocation, + file::metadata::{ParquetMetaData, RowGroupMetaData}, }; use std::collections::HashSet; use std::sync::Arc; -use crate::datasource::physical_plan::parquet::parquet_to_arrow_decimal_type; -use crate::datasource::physical_plan::parquet::statistics::{ - from_bytes_to_i128, parquet_column, -}; -use crate::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; - -use super::metrics::ParquetFileMetrics; - -/// A [`PagePruningPredicate`] provides the ability to construct a [`RowSelection`] -/// based on parquet page level statistics, if any +/// Filters a [`ParquetAccessPlan`] based on the [Parquet PageIndex], if present +/// +/// It does so by evaluating statistics from the [`ParquetColumnIndex`] and +/// [`ParquetOffsetIndex`] and converting them to [`RowSelection`]. +/// +/// [Parquet PageIndex]: https://github.com/apache/parquet-format/blob/master/PageIndex.md /// /// For example, given a row group with two column (chunks) for `A` /// and `B` with the following with page level statistics: @@ -99,470 +91,421 @@ use super::metrics::ParquetFileMetrics; /// /// Using `A > 35`: can rule out all of values in Page 1 (rows 0 -> 199) /// -/// Using `B = 'F'`: can rule out all vaues in Page 3 and Page 5 (rows 0 -> 99, and 250 -> 299) +/// Using `B = 'F'`: can rule out all values in Page 3 and Page 5 (rows 0 -> 99, and 250 -> 299) /// /// So we can entirely skip rows 0->199 and 250->299 as we know they /// can not contain rows that match the predicate. +/// +/// # Implementation notes +/// +/// Single column predicates are evaluated using the PageIndex information +/// for that column to determine which row ranges can be skipped based. +/// +/// The resulting [`RowSelection`]'s are combined into a final +/// row selection that is added to the [`ParquetAccessPlan`]. #[derive(Debug)] -pub struct PagePruningPredicate { +pub struct PagePruningAccessPlanFilter { + /// single column predicates (e.g. (`col = 5`) extracted from the overall + /// predicate. Must all be true for a row to be included in the result. predicates: Vec, } -impl PagePruningPredicate { - /// Create a new [`PagePruningPredicate`] - pub fn try_new(expr: &Arc, schema: SchemaRef) -> Result { +impl PagePruningAccessPlanFilter { + /// Create a new [`PagePruningAccessPlanFilter`] from a physical + /// expression. + pub fn new(expr: &Arc, schema: SchemaRef) -> Self { + // extract any single column predicates let predicates = split_conjunction(expr) .into_iter() .filter_map(|predicate| { - match PruningPredicate::try_new(predicate.clone(), schema.clone()) { - Ok(p) - if (!p.always_true()) - && (p.required_columns().n_columns() < 2) => - { - Some(Ok(p)) - } - _ => None, + let pp = + match PruningPredicate::try_new(predicate.clone(), schema.clone()) { + Ok(pp) => pp, + Err(e) => { + debug!("Ignoring error creating page pruning predicate: {e}"); + return None; + } + }; + + if pp.always_true() { + debug!("Ignoring always true page pruning predicate: {predicate}"); + return None; + } + + if pp.required_columns().single_column().is_none() { + debug!("Ignoring multi-column page pruning predicate: {predicate}"); + return None; } + + Some(pp) }) - .collect::>>()?; - Ok(Self { predicates }) + .collect::>(); + Self { predicates } } - /// Returns a [`RowSelection`] for the given file - pub fn prune( + /// Returns an updated [`ParquetAccessPlan`] by applying predicates to the + /// parquet page index, if any + pub fn prune_plan_with_page_index( &self, + mut access_plan: ParquetAccessPlan, arrow_schema: &Schema, parquet_schema: &SchemaDescriptor, - row_groups: &[usize], - file_metadata: &ParquetMetaData, + parquet_metadata: &ParquetMetaData, file_metrics: &ParquetFileMetrics, - ) -> Result> { + ) -> ParquetAccessPlan { // scoped timer updates on drop let _timer_guard = file_metrics.page_index_eval_time.timer(); if self.predicates.is_empty() { - return Ok(None); + return access_plan; } let page_index_predicates = &self.predicates; - let groups = file_metadata.row_groups(); + let groups = parquet_metadata.row_groups(); if groups.is_empty() { - return Ok(None); + return access_plan; } - let file_offset_indexes = file_metadata.offset_index(); - let file_page_indexes = file_metadata.column_index(); - let (file_offset_indexes, file_page_indexes) = match ( - file_offset_indexes, - file_page_indexes, - ) { - (Some(o), Some(i)) => (o, i), - _ => { - trace!( - "skip page pruning due to lack of indexes. Have offset: {}, column index: {}", - file_offset_indexes.is_some(), file_page_indexes.is_some() + if parquet_metadata.offset_index().is_none() + || parquet_metadata.column_index().is_none() + { + debug!( + "Can not prune pages due to lack of indexes. Have offset: {}, column index: {}", + parquet_metadata.offset_index().is_some(), parquet_metadata.column_index().is_some() ); - return Ok(None); - } + return access_plan; }; - let mut row_selections = Vec::with_capacity(page_index_predicates.len()); - for predicate in page_index_predicates { - // find column index in the parquet schema - let col_idx = find_column_index(predicate, arrow_schema, parquet_schema); - let mut selectors = Vec::with_capacity(row_groups.len()); - for r in row_groups.iter() { - let row_group_metadata = &groups[*r]; - - let rg_offset_indexes = file_offset_indexes.get(*r); - let rg_page_indexes = file_page_indexes.get(*r); - if let (Some(rg_page_indexes), Some(rg_offset_indexes), Some(col_idx)) = - (rg_page_indexes, rg_offset_indexes, col_idx) - { - selectors.extend( - prune_pages_in_one_row_group( - row_group_metadata, - predicate, - rg_offset_indexes.get(col_idx), - rg_page_indexes.get(col_idx), - groups[*r].column(col_idx).column_descr(), - file_metrics, - ) - .map_err(|e| { - ArrowError::ParquetError(format!( - "Fail in prune_pages_in_one_row_group: {e}" - )) - }), - ); + // track the total number of rows that should be skipped + let mut total_skip = 0; + // track the total number of rows that should not be skipped + let mut total_select = 0; + + // for each row group specified in the access plan + let row_group_indexes = access_plan.row_group_indexes(); + for row_group_index in row_group_indexes { + // The selection for this particular row group + let mut overall_selection = None; + for predicate in page_index_predicates { + let column = predicate + .required_columns() + .single_column() + .expect("Page pruning requires single column predicates"); + + let converter = StatisticsConverter::try_new( + column.name(), + arrow_schema, + parquet_schema, + ); + + let converter = match converter { + Ok(converter) => converter, + Err(e) => { + debug!( + "Could not create statistics converter for column {}: {e}", + column.name() + ); + continue; + } + }; + + let selection = prune_pages_in_one_row_group( + row_group_index, + predicate, + converter, + parquet_metadata, + file_metrics, + ); + + let Some(selection) = selection else { + trace!("No pages pruned in prune_pages_in_one_row_group"); + continue; + }; + + debug!("Use filter and page index to create RowSelection {:?} from predicate: {:?}", + &selection, + predicate.predicate_expr(), + ); + + overall_selection = update_selection(overall_selection, selection); + + // if the overall selection has ruled out all rows, no need to + // continue with the other predicates + let selects_any = overall_selection + .as_ref() + .map(|selection| selection.selects_any()) + .unwrap_or(true); + + if !selects_any { + break; + } + } + + if let Some(overall_selection) = overall_selection { + if overall_selection.selects_any() { + let rows_skipped = rows_skipped(&overall_selection); + let rows_selected = rows_selected(&overall_selection); + trace!("Overall selection from predicate skipped {rows_skipped}, selected {rows_selected}: {overall_selection:?}"); + total_skip += rows_skipped; + total_select += rows_selected; + access_plan.scan_selection(row_group_index, overall_selection) } else { + // Selection skips all rows, so skip the entire row group + let rows_skipped = groups[row_group_index].num_rows() as usize; + access_plan.skip(row_group_index); + total_skip += rows_skipped; trace!( - "Did not have enough metadata to prune with page indexes, \ - falling back to all rows", + "Overall selection from predicate is empty, \ + skipping all {rows_skipped} rows in row group {row_group_index}" ); - // fallback select all rows - let all_selected = - vec![RowSelector::select(groups[*r].num_rows() as usize)]; - selectors.push(all_selected); } } - debug!( - "Use filter and page index create RowSelection {:?} from predicate: {:?}", - &selectors, - predicate.predicate_expr(), - ); - row_selections.push(selectors.into_iter().flatten().collect::>()); } - let final_selection = combine_multi_col_selection(row_selections); - let total_skip = - final_selection.iter().fold( - 0, - |acc, x| { - if x.skip { - acc + x.row_count - } else { - acc - } - }, - ); - file_metrics.page_index_rows_filtered.add(total_skip); - Ok(Some(final_selection)) + file_metrics.page_index_rows_pruned.add(total_skip); + file_metrics.page_index_rows_matched.add(total_select); + access_plan } - /// Returns the number of filters in the [`PagePruningPredicate`] + /// Returns the number of filters in the [`PagePruningAccessPlanFilter`] pub fn filter_number(&self) -> usize { self.predicates.len() } } -/// Returns the column index in the row parquet schema for the single -/// column of a single column pruning predicate. -/// -/// For example, give the predicate `y > 5` -/// -/// And columns in the RowGroupMetadata like `['x', 'y', 'z']` will -/// return 1. -/// -/// Returns `None` if the column is not found, or if there are no -/// required columns, which is the case for predicate like `abs(i) = -/// 1` which are rewritten to `lit(true)` -/// -/// Panics: -/// -/// If the predicate contains more than one column reference (assumes -/// that `extract_page_index_push_down_predicates` only returns -/// predicate with one col) -fn find_column_index( - predicate: &PruningPredicate, - arrow_schema: &Schema, - parquet_schema: &SchemaDescriptor, -) -> Option { - let mut found_required_column: Option<&Column> = None; - - for required_column_details in predicate.required_columns().iter() { - let column = &required_column_details.0; - if let Some(found_required_column) = found_required_column.as_ref() { - // make sure it is the same name we have seen previously - assert_eq!( - column.name(), - found_required_column.name(), - "Unexpected multi column predicate" - ); - } else { - found_required_column = Some(column); - } - } +/// returns the number of rows skipped in the selection +/// TODO should this be upstreamed to RowSelection? +fn rows_skipped(selection: &RowSelection) -> usize { + selection + .iter() + .fold(0, |acc, x| if x.skip { acc + x.row_count } else { acc }) +} - let Some(column) = found_required_column.as_ref() else { - trace!("No column references in pruning predicate"); - return None; - }; +/// returns the number of rows not skipped in the selection +/// TODO should this be upstreamed to RowSelection? +fn rows_selected(selection: &RowSelection) -> usize { + selection + .iter() + .fold(0, |acc, x| if x.skip { acc } else { acc + x.row_count }) +} - parquet_column(parquet_schema, arrow_schema, column.name()).map(|x| x.0) +fn update_selection( + current_selection: Option, + row_selection: RowSelection, +) -> Option { + match current_selection { + None => Some(row_selection), + Some(current_selection) => Some(current_selection.intersection(&row_selection)), + } } -/// Intersects the [`RowSelector`]s +/// Returns a [`RowSelection`] for the rows in this row group to scan. /// -/// For exampe, given: -/// * `RowSelector1: [ Skip(0~199), Read(200~299)]` -/// * `RowSelector2: [ Skip(0~99), Read(100~249), Skip(250~299)]` +/// This Row Selection is formed from the page index and the predicate skips row +/// ranges that can be ruled out based on the predicate. /// -/// The final selection is the intersection of these `RowSelector`s: -/// * `final_selection:[ Skip(0~199), Read(200~249), Skip(250~299)]` -fn combine_multi_col_selection(row_selections: Vec>) -> RowSelection { - row_selections - .into_iter() - .map(RowSelection::from) - .reduce(|s1, s2| s1.intersection(&s2)) - .unwrap() -} - +/// Returns `None` if there is an error evaluating the predicate or the required +/// page information is not present. fn prune_pages_in_one_row_group( - group: &RowGroupMetaData, - predicate: &PruningPredicate, - col_offset_indexes: Option<&Vec>, - col_page_indexes: Option<&Index>, - col_desc: &ColumnDescriptor, + row_group_index: usize, + pruning_predicate: &PruningPredicate, + converter: StatisticsConverter<'_>, + parquet_metadata: &ParquetMetaData, metrics: &ParquetFileMetrics, -) -> Result> { - let num_rows = group.num_rows() as usize; - if let (Some(col_offset_indexes), Some(col_page_indexes)) = - (col_offset_indexes, col_page_indexes) - { - let target_type = parquet_to_arrow_decimal_type(col_desc); - let pruning_stats = PagesPruningStatistics { - col_page_indexes, - col_offset_indexes, - target_type: &target_type, - num_rows_in_row_group: group.num_rows(), - }; +) -> Option { + let pruning_stats = + PagesPruningStatistics::try_new(row_group_index, converter, parquet_metadata)?; + + // Each element in values is a boolean indicating whether the page may have + // values that match the predicate (true) or could not possibly have values + // that match the predicate (false). + let values = match pruning_predicate.prune(&pruning_stats) { + Ok(values) => values, + Err(e) => { + debug!("Error evaluating page index predicate values {e}"); + metrics.predicate_evaluation_errors.add(1); + return None; + } + }; - match predicate.prune(&pruning_stats) { - Ok(values) => { - let mut vec = Vec::with_capacity(values.len()); - let row_vec = create_row_count_in_each_page(col_offset_indexes, num_rows); - assert_eq!(row_vec.len(), values.len()); - let mut sum_row = *row_vec.first().unwrap(); - let mut selected = *values.first().unwrap(); - trace!("Pruned to {:?} using {:?}", values, pruning_stats); - for (i, &f) in values.iter().enumerate().skip(1) { - if f == selected { - sum_row += *row_vec.get(i).unwrap(); - } else { - let selector = if selected { - RowSelector::select(sum_row) - } else { - RowSelector::skip(sum_row) - }; - vec.push(selector); - sum_row = *row_vec.get(i).unwrap(); - selected = f; - } - } + // Convert the information of which pages to skip into a RowSelection + // that describes the ranges of rows to skip. + let Some(page_row_counts) = pruning_stats.page_row_counts() else { + debug!( + "Can not determine page row counts for row group {row_group_index}, skipping" + ); + metrics.predicate_evaluation_errors.add(1); + return None; + }; - let selector = if selected { - RowSelector::select(sum_row) - } else { - RowSelector::skip(sum_row) - }; - vec.push(selector); - return Ok(vec); - } - // stats filter array could not be built - // return a result which will not filter out any pages - Err(e) => { - debug!("Error evaluating page index predicate values {e}"); - metrics.predicate_evaluation_errors.add(1); - return Ok(vec![RowSelector::select(group.num_rows() as usize)]); - } + let mut vec = Vec::with_capacity(values.len()); + assert_eq!(page_row_counts.len(), values.len()); + let mut sum_row = *page_row_counts.first().unwrap(); + let mut selected = *values.first().unwrap(); + trace!("Pruned to {:?} using {:?}", values, pruning_stats); + for (i, &f) in values.iter().enumerate().skip(1) { + if f == selected { + sum_row += *page_row_counts.get(i).unwrap(); + } else { + let selector = if selected { + RowSelector::select(sum_row) + } else { + RowSelector::skip(sum_row) + }; + vec.push(selector); + sum_row = *page_row_counts.get(i).unwrap(); + selected = f; } } - Err(DataFusionError::ParquetError(ParquetError::General( - "Got some error in prune_pages_in_one_row_group, plz try open the debuglog mode" - .to_string(), - ))) -} -fn create_row_count_in_each_page( - location: &[PageLocation], - num_rows: usize, -) -> Vec { - let mut vec = Vec::with_capacity(location.len()); - location.windows(2).for_each(|x| { - let start = x[0].first_row_index as usize; - let end = x[1].first_row_index as usize; - vec.push(end - start); - }); - vec.push(num_rows - location.last().unwrap().first_row_index as usize); - vec + let selector = if selected { + RowSelector::select(sum_row) + } else { + RowSelector::skip(sum_row) + }; + vec.push(selector); + Some(RowSelection::from(vec)) } -/// Wraps one col page_index in one rowGroup statistics in a way -/// that implements [`PruningStatistics`] +/// Implement [`PruningStatistics`] for one column's PageIndex (column_index + offset_index) #[derive(Debug)] struct PagesPruningStatistics<'a> { - col_page_indexes: &'a Index, - col_offset_indexes: &'a Vec, - // target_type means the logical type in schema: like 'DECIMAL' is the logical type, but the - // real physical type in parquet file may be `INT32, INT64, FIXED_LEN_BYTE_ARRAY` - target_type: &'a Option, - num_rows_in_row_group: i64, + row_group_index: usize, + row_group_metadatas: &'a [RowGroupMetaData], + converter: StatisticsConverter<'a>, + column_index: &'a ParquetColumnIndex, + offset_index: &'a ParquetOffsetIndex, + page_offsets: &'a Vec, } -// Extract the min or max value calling `func` from page idex -macro_rules! get_min_max_values_for_page_index { - ($self:expr, $func:ident) => {{ - match $self.col_page_indexes { - Index::NONE => None, - Index::INT32(index) => { - match $self.target_type { - // int32 to decimal with the precision and scale - Some(DataType::Decimal128(precision, scale)) => { - let vec = &index.indexes; - let vec: Vec> = vec - .iter() - .map(|x| x.$func().and_then(|x| Some(*x as i128))) - .collect(); - Decimal128Array::from(vec) - .with_precision_and_scale(*precision, *scale) - .ok() - .map(|arr| Arc::new(arr) as ArrayRef) - } - _ => { - let vec = &index.indexes; - Some(Arc::new(Int32Array::from_iter( - vec.iter().map(|x| x.$func().cloned()), - ))) - } - } - } - Index::INT64(index) => { - match $self.target_type { - // int64 to decimal with the precision and scale - Some(DataType::Decimal128(precision, scale)) => { - let vec = &index.indexes; - let vec: Vec> = vec - .iter() - .map(|x| x.$func().and_then(|x| Some(*x as i128))) - .collect(); - Decimal128Array::from(vec) - .with_precision_and_scale(*precision, *scale) - .ok() - .map(|arr| Arc::new(arr) as ArrayRef) - } - _ => { - let vec = &index.indexes; - Some(Arc::new(Int64Array::from_iter( - vec.iter().map(|x| x.$func().cloned()), - ))) - } - } - } - Index::FLOAT(index) => { - let vec = &index.indexes; - Some(Arc::new(Float32Array::from_iter( - vec.iter().map(|x| x.$func().cloned()), - ))) - } - Index::DOUBLE(index) => { - let vec = &index.indexes; - Some(Arc::new(Float64Array::from_iter( - vec.iter().map(|x| x.$func().cloned()), - ))) - } - Index::BOOLEAN(index) => { - let vec = &index.indexes; - Some(Arc::new(BooleanArray::from_iter( - vec.iter().map(|x| x.$func().cloned()), - ))) - } - Index::BYTE_ARRAY(index) => match $self.target_type { - Some(DataType::Decimal128(precision, scale)) => { - let vec = &index.indexes; - Decimal128Array::from( - vec.iter() - .map(|x| { - x.$func() - .and_then(|x| Some(from_bytes_to_i128(x.as_ref()))) - }) - .collect::>>(), - ) - .with_precision_and_scale(*precision, *scale) - .ok() - .map(|arr| Arc::new(arr) as ArrayRef) - } - _ => { - let vec = &index.indexes; - let array: StringArray = vec - .iter() - .map(|x| x.$func()) - .map(|x| x.and_then(|x| std::str::from_utf8(x.as_ref()).ok())) - .collect(); - Some(Arc::new(array)) - } - }, - Index::INT96(_) => { - //Todo support these type - None - } - Index::FIXED_LEN_BYTE_ARRAY(index) => match $self.target_type { - Some(DataType::Decimal128(precision, scale)) => { - let vec = &index.indexes; - Decimal128Array::from( - vec.iter() - .map(|x| { - x.$func() - .and_then(|x| Some(from_bytes_to_i128(x.as_ref()))) - }) - .collect::>>(), - ) - .with_precision_and_scale(*precision, *scale) - .ok() - .map(|arr| Arc::new(arr) as ArrayRef) - } - _ => None, - }, - } - }}; -} +impl<'a> PagesPruningStatistics<'a> { + /// Creates a new [`PagesPruningStatistics`] for a column in a row group, if + /// possible. + /// + /// Returns None if the `parquet_metadata` does not have sufficient + /// information to create the statistics. + fn try_new( + row_group_index: usize, + converter: StatisticsConverter<'a>, + parquet_metadata: &'a ParquetMetaData, + ) -> Option { + let Some(parquet_column_index) = converter.parquet_column_index() else { + trace!( + "Column {:?} not in parquet file, skipping", + converter.arrow_field() + ); + return None; + }; + + let column_index = parquet_metadata.column_index()?; + let offset_index = parquet_metadata.offset_index()?; + let row_group_metadatas = parquet_metadata.row_groups(); + let Some(row_group_page_offsets) = offset_index.get(row_group_index) else { + trace!("No page offsets for row group {row_group_index}, skipping"); + return None; + }; + let Some(offset_index_metadata) = + row_group_page_offsets.get(parquet_column_index) + else { + trace!( + "No page offsets for column {:?} in row group {row_group_index}, skipping", + converter.arrow_field() + ); + return None; + }; + let page_offsets = offset_index_metadata.page_locations(); + + Some(Self { + row_group_index, + row_group_metadatas, + converter, + column_index, + offset_index, + page_offsets, + }) + } + + /// return the row counts in each data page, if possible. + fn page_row_counts(&self) -> Option> { + let row_group_metadata = self + .row_group_metadatas + .get(self.row_group_index) + // fail fast/panic if row_group_index is out of bounds + .unwrap(); + + let num_rows_in_row_group = row_group_metadata.num_rows() as usize; + + let page_offsets = self.page_offsets; + let mut vec = Vec::with_capacity(page_offsets.len()); + page_offsets.windows(2).for_each(|x| { + let start = x[0].first_row_index as usize; + let end = x[1].first_row_index as usize; + vec.push(end - start); + }); + vec.push(num_rows_in_row_group - page_offsets.last()?.first_row_index as usize); + Some(vec) + } +} impl<'a> PruningStatistics for PagesPruningStatistics<'a> { fn min_values(&self, _column: &datafusion_common::Column) -> Option { - get_min_max_values_for_page_index!(self, min) + match self.converter.data_page_mins( + self.column_index, + self.offset_index, + [&self.row_group_index], + ) { + Ok(min_values) => Some(min_values), + Err(e) => { + debug!("Error evaluating data page min values {e}"); + None + } + } } fn max_values(&self, _column: &datafusion_common::Column) -> Option { - get_min_max_values_for_page_index!(self, max) + match self.converter.data_page_maxes( + self.column_index, + self.offset_index, + [&self.row_group_index], + ) { + Ok(min_values) => Some(min_values), + Err(e) => { + debug!("Error evaluating data page max values {e}"); + None + } + } } fn num_containers(&self) -> usize { - self.col_offset_indexes.len() + self.page_offsets.len() } fn null_counts(&self, _column: &datafusion_common::Column) -> Option { - match self.col_page_indexes { - Index::NONE => None, - Index::BOOLEAN(index) => Some(Arc::new(Int64Array::from_iter( - index.indexes.iter().map(|x| x.null_count), - ))), - Index::INT32(index) => Some(Arc::new(Int64Array::from_iter( - index.indexes.iter().map(|x| x.null_count), - ))), - Index::INT64(index) => Some(Arc::new(Int64Array::from_iter( - index.indexes.iter().map(|x| x.null_count), - ))), - Index::FLOAT(index) => Some(Arc::new(Int64Array::from_iter( - index.indexes.iter().map(|x| x.null_count), - ))), - Index::DOUBLE(index) => Some(Arc::new(Int64Array::from_iter( - index.indexes.iter().map(|x| x.null_count), - ))), - Index::INT96(index) => Some(Arc::new(Int64Array::from_iter( - index.indexes.iter().map(|x| x.null_count), - ))), - Index::BYTE_ARRAY(index) => Some(Arc::new(Int64Array::from_iter( - index.indexes.iter().map(|x| x.null_count), - ))), - Index::FIXED_LEN_BYTE_ARRAY(index) => Some(Arc::new(Int64Array::from_iter( - index.indexes.iter().map(|x| x.null_count), - ))), + match self.converter.data_page_null_counts( + self.column_index, + self.offset_index, + [&self.row_group_index], + ) { + Ok(null_counts) => Some(Arc::new(null_counts)), + Err(e) => { + debug!("Error evaluating data page null counts {e}"); + None + } } } fn row_counts(&self, _column: &datafusion_common::Column) -> Option { - // see https://github.com/apache/arrow-rs/blob/91f0b1771308609ca27db0fb1d2d49571b3980d8/parquet/src/file/metadata.rs#L979-L982 - - let row_count_per_page = self.col_offset_indexes.windows(2).map(|location| { - Some(location[1].first_row_index - location[0].first_row_index) - }); - - // append the last page row count - let row_count_per_page = row_count_per_page.chain(std::iter::once(Some( - self.num_rows_in_row_group - - self.col_offset_indexes.last().unwrap().first_row_index, - ))); - - Some(Arc::new(Int64Array::from_iter(row_count_per_page))) + match self.converter.data_page_row_counts( + self.offset_index, + self.row_group_metadatas, + [&self.row_group_index], + ) { + Ok(row_counts) => row_counts.map(|a| Arc::new(a) as ArrayRef), + Err(e) => { + debug!("Error evaluating data page row counts {e}"); + None + } + } } fn contained( diff --git a/datafusion/core/src/datasource/physical_plan/parquet/reader.rs b/datafusion/core/src/datasource/physical_plan/parquet/reader.rs new file mode 100644 index 000000000000..8a4ba136fc96 --- /dev/null +++ b/datafusion/core/src/datasource/physical_plan/parquet/reader.rs @@ -0,0 +1,147 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ParquetFileReaderFactory`] and [`DefaultParquetFileReaderFactory`] for +//! low level control of parquet file readers + +use crate::datasource::physical_plan::{FileMeta, ParquetFileMetrics}; +use bytes::Bytes; +use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; +use futures::future::BoxFuture; +use object_store::ObjectStore; +use parquet::arrow::async_reader::{AsyncFileReader, ParquetObjectReader}; +use parquet::file::metadata::ParquetMetaData; +use std::fmt::Debug; +use std::ops::Range; +use std::sync::Arc; + +/// Interface for reading parquet files. +/// +/// The combined implementations of [`ParquetFileReaderFactory`] and +/// [`AsyncFileReader`] can be used to provide custom data access operations +/// such as pre-cached metadata, I/O coalescing, etc. +/// +/// See [`DefaultParquetFileReaderFactory`] for a simple implementation. +pub trait ParquetFileReaderFactory: Debug + Send + Sync + 'static { + /// Provides an `AsyncFileReader` for reading data from a parquet file specified + /// + /// # Notes + /// + /// If the resulting [`AsyncFileReader`] returns `ParquetMetaData` without + /// page index information, the reader will load it on demand. Thus it is important + /// to ensure that the returned `ParquetMetaData` has the necessary information + /// if you wish to avoid a subsequent I/O + /// + /// # Arguments + /// * partition_index - Index of the partition (for reporting metrics) + /// * file_meta - The file to be read + /// * metadata_size_hint - If specified, the first IO reads this many bytes from the footer + /// * metrics - Execution metrics + fn create_reader( + &self, + partition_index: usize, + file_meta: FileMeta, + metadata_size_hint: Option, + metrics: &ExecutionPlanMetricsSet, + ) -> datafusion_common::Result>; +} + +/// Default implementation of [`ParquetFileReaderFactory`] +/// +/// This implementation: +/// 1. Reads parquet directly from an underlying [`ObjectStore`] instance. +/// 2. Reads the footer and page metadata on demand. +/// 3. Does not cache metadata or coalesce I/O operations. +#[derive(Debug)] +pub struct DefaultParquetFileReaderFactory { + store: Arc, +} + +impl DefaultParquetFileReaderFactory { + /// Create a new `DefaultParquetFileReaderFactory`. + pub fn new(store: Arc) -> Self { + Self { store } + } +} + +/// Implements [`AsyncFileReader`] for a parquet file in object storage. +/// +/// This implementation uses the [`ParquetObjectReader`] to read data from the +/// object store on demand, as required, tracking the number of bytes read. +/// +/// This implementation does not coalesce I/O operations or cache bytes. Such +/// optimizations can be done either at the object store level or by providing a +/// custom implementation of [`ParquetFileReaderFactory`]. +pub(crate) struct ParquetFileReader { + pub file_metrics: ParquetFileMetrics, + pub inner: ParquetObjectReader, +} + +impl AsyncFileReader for ParquetFileReader { + fn get_bytes( + &mut self, + range: Range, + ) -> BoxFuture<'_, parquet::errors::Result> { + self.file_metrics.bytes_scanned.add(range.end - range.start); + self.inner.get_bytes(range) + } + + fn get_byte_ranges( + &mut self, + ranges: Vec>, + ) -> BoxFuture<'_, parquet::errors::Result>> + where + Self: Send, + { + let total = ranges.iter().map(|r| r.end - r.start).sum(); + self.file_metrics.bytes_scanned.add(total); + self.inner.get_byte_ranges(ranges) + } + + fn get_metadata( + &mut self, + ) -> BoxFuture<'_, parquet::errors::Result>> { + self.inner.get_metadata() + } +} + +impl ParquetFileReaderFactory for DefaultParquetFileReaderFactory { + fn create_reader( + &self, + partition_index: usize, + file_meta: FileMeta, + metadata_size_hint: Option, + metrics: &ExecutionPlanMetricsSet, + ) -> datafusion_common::Result> { + let file_metrics = ParquetFileMetrics::new( + partition_index, + file_meta.location().as_ref(), + metrics, + ); + let store = Arc::clone(&self.store); + let mut inner = ParquetObjectReader::new(store, file_meta.object_meta); + + if let Some(hint) = metadata_size_hint { + inner = inner.with_footer_size_hint(hint) + }; + + Ok(Box::new(ParquetFileReader { + inner, + file_metrics, + })) + } +} diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs index 5f89ff087f70..e876f840d1eb 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs @@ -15,78 +15,115 @@ // specific language governing permissions and limitations // under the License. +//! Utilities to push down of DataFusion filter predicates (any DataFusion +//! `PhysicalExpr` that evaluates to a [`BooleanArray`]) to the parquet decoder +//! level in `arrow-rs`. +//! +//! DataFusion will use a `ParquetRecordBatchStream` to read data from parquet +//! into [`RecordBatch`]es. +//! +//! The `ParquetRecordBatchStream` takes an optional `RowFilter` which is itself +//! a Vec of `Box`. During decoding, the predicates are +//! evaluated in order, to generate a mask which is used to avoid decoding rows +//! in projected columns which do not pass the filter which can significantly +//! reduce the amount of compute required for decoding and thus improve query +//! performance. +//! +//! Since the predicates are applied serially in the order defined in the +//! `RowFilter`, the optimal ordering depends on the exact filters. The best +//! filters to execute first have two properties: +//! +//! 1. They are relatively inexpensive to evaluate (e.g. they read +//! column chunks which are relatively small) +//! +//! 2. They filter many (contiguous) rows, reducing the amount of decoding +//! required for subsequent filters and projected columns +//! +//! If requested, this code will reorder the filters based on heuristics try and +//! reduce the evaluation cost. +//! +//! The basic algorithm for constructing the `RowFilter` is as follows +//! +//! 1. Break conjunctions into separate predicates. An expression +//! like `a = 1 AND (b = 2 AND c = 3)` would be +//! separated into the expressions `a = 1`, `b = 2`, and `c = 3`. +//! 2. Determine whether each predicate can be evaluated as an `ArrowPredicate`. +//! 3. Determine, for each predicate, the total compressed size of all +//! columns required to evaluate the predicate. +//! 4. Determine, for each predicate, whether all columns required to +//! evaluate the expression are sorted. +//! 5. Re-order the predicate by total size (from step 3). +//! 6. Partition the predicates according to whether they are sorted (from step 4) +//! 7. "Compile" each predicate `Expr` to a `DatafusionArrowPredicate`. +//! 8. Build the `RowFilter` with the sorted predicates followed by +//! the unsorted predicates. Within each partition, predicates are +//! still be sorted by size. + +use std::cmp::Ordering; use std::collections::BTreeSet; use std::sync::Arc; -use super::ParquetFileMetrics; -use crate::physical_plan::metrics; - use arrow::array::BooleanArray; use arrow::datatypes::{DataType, Schema}; use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; +use parquet::arrow::arrow_reader::{ArrowPredicate, RowFilter}; +use parquet::arrow::ProjectionMask; +use parquet::file::metadata::ParquetMetaData; + +use crate::datasource::schema_adapter::SchemaMapper; use datafusion_common::cast::as_boolean_array; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; -use datafusion_common::{arrow_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; use datafusion_physical_expr::expressions::{Column, Literal}; use datafusion_physical_expr::utils::reassign_predicate_columns; use datafusion_physical_expr::{split_conjunction, PhysicalExpr}; -use parquet::arrow::arrow_reader::{ArrowPredicate, RowFilter}; -use parquet::arrow::ProjectionMask; -use parquet::file::metadata::ParquetMetaData; +use crate::physical_plan::metrics; + +use super::ParquetFileMetrics; -/// This module contains utilities for enabling the pushdown of DataFusion filter predicates (which -/// can be any DataFusion `Expr` that evaluates to a `BooleanArray`) to the parquet decoder level in `arrow-rs`. -/// DataFusion will use a `ParquetRecordBatchStream` to read data from parquet into arrow `RecordBatch`es. -/// When constructing the `ParquetRecordBatchStream` you can provide a `RowFilter` which is itself just a vector -/// of `Box`. During decoding, the predicates are evaluated to generate a mask which is used -/// to avoid decoding rows in projected columns which are not selected which can significantly reduce the amount -/// of compute required for decoding. +/// A "compiled" predicate passed to `ParquetRecordBatchStream` to perform +/// row-level filtering during parquet decoding. /// -/// Since the predicates are applied serially in the order defined in the `RowFilter`, the optimal ordering -/// will depend on the exact filters. The best filters to execute first have two properties: -/// 1. The are relatively inexpensive to evaluate (e.g. they read column chunks which are relatively small) -/// 2. They filter a lot of rows, reducing the amount of decoding required for subsequent filters and projected columns +/// See the module level documentation for more information. /// -/// Given the metadata exposed by parquet, the selectivity of filters is not easy to estimate so the heuristics we use here primarily -/// focus on the evaluation cost. +/// Implements the `ArrowPredicate` trait used by the parquet decoder /// -/// The basic algorithm for constructing the `RowFilter` is as follows -/// 1. Recursively break conjunctions into separate predicates. An expression like `a = 1 AND (b = 2 AND c = 3)` would be -/// separated into the expressions `a = 1`, `b = 2`, and `c = 3`. -/// 2. Determine whether each predicate is suitable as an `ArrowPredicate`. As long as the predicate does not reference any projected columns -/// or columns with non-primitive types, then it is considered suitable. -/// 3. Determine, for each predicate, the total compressed size of all columns required to evaluate the predicate. -/// 4. Determine, for each predicate, whether all columns required to evaluate the expression are sorted. -/// 5. Re-order the predicate by total size (from step 3). -/// 6. Partition the predicates according to whether they are sorted (from step 4) -/// 7. "Compile" each predicate `Expr` to a `DatafusionArrowPredicate`. -/// 8. Build the `RowFilter` with the sorted predicates followed by the unsorted predicates. Within each partition -/// the predicates will still be sorted by size. - -/// A predicate which can be passed to `ParquetRecordBatchStream` to perform row-level -/// filtering during parquet decoding. +/// An expression can be evaluated as a `DatafusionArrowPredicate` if it: +/// * Does not reference any projected columns +/// * Does not reference columns with non-primitive types (e.g. structs / lists) #[derive(Debug)] pub(crate) struct DatafusionArrowPredicate { + /// the filter expression physical_expr: Arc, + /// Path to the columns in the parquet schema required to evaluate the + /// expression projection_mask: ProjectionMask, + /// Columns required to evaluate the expression in the arrow schema projection: Vec, /// how many rows were filtered out by this predicate - rows_filtered: metrics::Count, + rows_pruned: metrics::Count, + /// how many rows passed this predicate + rows_matched: metrics::Count, /// how long was spent evaluating this predicate time: metrics::Time, + /// used to perform type coercion while filtering rows + schema_mapping: Arc, } impl DatafusionArrowPredicate { + /// Create a new `DatafusionArrowPredicate` from a `FilterCandidate` pub fn try_new( candidate: FilterCandidate, schema: &Schema, metadata: &ParquetMetaData, - rows_filtered: metrics::Count, + rows_pruned: metrics::Count, + rows_matched: metrics::Count, time: metrics::Time, + schema_mapping: Arc, ) -> Result { let schema = Arc::new(schema.project(&candidate.projection)?); let physical_expr = reassign_predicate_columns(candidate.expr, &schema, true)?; @@ -96,7 +133,7 @@ impl DatafusionArrowPredicate { // on the order they appear in the file let projection = match candidate.projection.len() { 0 | 1 => vec![], - _ => remap_projection(&candidate.projection), + 2.. => remap_projection(&candidate.projection), }; Ok(Self { @@ -106,8 +143,10 @@ impl DatafusionArrowPredicate { metadata.file_metadata().schema_descr(), candidate.projection, ), - rows_filtered, + rows_pruned, + rows_matched, time, + schema_mapping, }) } } @@ -117,36 +156,42 @@ impl ArrowPredicate for DatafusionArrowPredicate { &self.projection_mask } - fn evaluate(&mut self, batch: RecordBatch) -> ArrowResult { - let batch = match self.projection.is_empty() { - true => batch, - false => batch.project(&self.projection)?, + fn evaluate(&mut self, mut batch: RecordBatch) -> ArrowResult { + if !self.projection.is_empty() { + batch = batch.project(&self.projection)?; }; + let batch = self.schema_mapping.map_partial_batch(batch)?; + // scoped timer updates on drop let mut timer = self.time.timer(); - match self - .physical_expr + + self.physical_expr .evaluate(&batch) .and_then(|v| v.into_array(batch.num_rows())) - { - Ok(array) => { + .and_then(|array| { let bool_arr = as_boolean_array(&array)?.clone(); - let num_filtered = bool_arr.len() - bool_arr.true_count(); - self.rows_filtered.add(num_filtered); + let num_matched = bool_arr.true_count(); + let num_pruned = bool_arr.len() - num_matched; + self.rows_pruned.add(num_pruned); + self.rows_matched.add(num_matched); timer.stop(); Ok(bool_arr) - } - Err(e) => Err(ArrowError::ComputeError(format!( - "Error evaluating filter predicate: {e:?}" - ))), - } + }) + .map_err(|e| { + ArrowError::ComputeError(format!( + "Error evaluating filter predicate: {e:?}" + )) + }) } } -/// A candidate expression for creating a `RowFilter` contains the -/// expression as well as data to estimate the cost of evaluating -/// the resulting expression. +/// A candidate expression for creating a `RowFilter`. +/// +/// Each candidate contains the expression as well as data to estimate the cost +/// of evaluating the resulting expression. +/// +/// See the module level documentation for more information. pub(crate) struct FilterCandidate { expr: Arc, required_bytes: usize, @@ -154,20 +199,50 @@ pub(crate) struct FilterCandidate { projection: Vec, } -/// Helper to build a `FilterCandidate`. This will do several things +/// Helper to build a `FilterCandidate`. +/// +/// This will do several things /// 1. Determine the columns required to evaluate the expression /// 2. Calculate data required to estimate the cost of evaluating the filter -/// 3. Rewrite column expressions in the predicate which reference columns not in the particular file schema. -/// This is relevant in the case where we have determined the table schema by merging all individual file schemas -/// and any given file may or may not contain all columns in the merged schema. If a particular column is not present -/// we replace the column expression with a literal expression that produces a null value. +/// 3. Rewrite column expressions in the predicate which reference columns not +/// in the particular file schema. +/// +/// # Schema Rewrite +/// +/// When parquet files are read in the context of "schema evolution" there are +/// potentially wo schemas: +/// +/// 1. The table schema (the columns of the table that the parquet file is part of) +/// 2. The file schema (the columns actually in the parquet file) +/// +/// There are times when the table schema contains columns that are not in the +/// file schema, such as when new columns have been added in new parquet files +/// but old files do not have the columns. +/// +/// When a file is missing a column from the table schema, the value of the +/// missing column is filled in with `NULL` via a `SchemaAdapter`. +/// +/// When a predicate is pushed down to the parquet reader, the predicate is +/// evaluated in the context of the file schema. If the predicate references a +/// column that is in the table schema but not in the file schema, the column +/// reference must be rewritten to a literal expression that represents the +/// `NULL` value that would be produced by the `SchemaAdapter`. +/// +/// For example, if: +/// * The table schema is `id, name, address` +/// * The file schema is `id, name` (missing the `address` column) +/// * predicate is `address = 'foo'` +/// +/// When evaluating the predicate as a filter on the parquet file, the predicate +/// must be rewritten to `NULL = 'foo'` as the `address` column will be filled +/// in with `NULL` values during the rest of the evaluation. struct FilterCandidateBuilder<'a> { expr: Arc, + /// The schema of this parquet file file_schema: &'a Schema, + /// The schema of the table (merged schema) -- columns may be in different + /// order than in the file and have columns that are not in the file schema table_schema: &'a Schema, - required_column_indices: BTreeSet, - non_primitive_columns: bool, - projected_columns: bool, } impl<'a> FilterCandidateBuilder<'a> { @@ -180,36 +255,88 @@ impl<'a> FilterCandidateBuilder<'a> { expr, file_schema, table_schema, - required_column_indices: BTreeSet::default(), + } + } + + /// Attempt to build a `FilterCandidate` from the expression + /// + /// # Return values + /// + /// * `Ok(Some(candidate))` if the expression can be used as an ArrowFilter + /// * `Ok(None)` if the expression cannot be used as an ArrowFilter + /// * `Err(e)` if an error occurs while building the candidate + pub fn build(self, metadata: &ParquetMetaData) -> Result> { + let Some((required_indices, rewritten_expr)) = + pushdown_columns(self.expr, self.file_schema, self.table_schema)? + else { + return Ok(None); + }; + + let required_bytes = size_of_columns(&required_indices, metadata)?; + let can_use_index = columns_sorted(&required_indices, metadata)?; + + Ok(Some(FilterCandidate { + expr: rewritten_expr, + required_bytes, + can_use_index, + projection: required_indices.into_iter().collect(), + })) + } +} + +// a struct that implements TreeNodeRewriter to traverse a PhysicalExpr tree structure to determine +// if any column references in the expression would prevent it from being predicate-pushed-down. +// if non_primitive_columns || projected_columns, it can't be pushed down. +// can't be reused between calls to `rewrite`; each construction must be used only once. +struct PushdownChecker<'schema> { + /// Does the expression require any non-primitive columns (like structs)? + non_primitive_columns: bool, + /// Does the expression reference any columns that are in the table + /// schema but not in the file schema? + projected_columns: bool, + // the indices of all the columns found within the given expression which exist inside the given + // [`file_schema`] + required_column_indices: BTreeSet, + file_schema: &'schema Schema, + table_schema: &'schema Schema, +} + +impl<'schema> PushdownChecker<'schema> { + fn new(file_schema: &'schema Schema, table_schema: &'schema Schema) -> Self { + Self { non_primitive_columns: false, projected_columns: false, + required_column_indices: BTreeSet::default(), + file_schema, + table_schema, } } - pub fn build( - mut self, - metadata: &ParquetMetaData, - ) -> Result> { - let expr = self.expr.clone().rewrite(&mut self).data()?; - - if self.non_primitive_columns || self.projected_columns { - Ok(None) - } else { - let required_bytes = - size_of_columns(&self.required_column_indices, metadata)?; - let can_use_index = columns_sorted(&self.required_column_indices, metadata)?; - - Ok(Some(FilterCandidate { - expr, - required_bytes, - can_use_index, - projection: self.required_column_indices.into_iter().collect(), - })) + fn check_single_column(&mut self, column_name: &str) -> Option { + if let Ok(idx) = self.file_schema.index_of(column_name) { + self.required_column_indices.insert(idx); + + if DataType::is_nested(self.file_schema.field(idx).data_type()) { + self.non_primitive_columns = true; + return Some(TreeNodeRecursion::Jump); + } + } else if self.table_schema.index_of(column_name).is_err() { + // If the column does not exist in the (un-projected) table schema then + // it must be a projected column. + self.projected_columns = true; + return Some(TreeNodeRecursion::Jump); } + + None + } + + #[inline] + fn prevents_pushdown(&self) -> bool { + self.non_primitive_columns || self.projected_columns } } -impl<'a> TreeNodeRewriter for FilterCandidateBuilder<'a> { +impl<'schema> TreeNodeRewriter for PushdownChecker<'schema> { type Node = Arc; fn f_down( @@ -217,42 +344,39 @@ impl<'a> TreeNodeRewriter for FilterCandidateBuilder<'a> { node: Arc, ) -> Result>> { if let Some(column) = node.as_any().downcast_ref::() { - if let Ok(idx) = self.file_schema.index_of(column.name()) { - self.required_column_indices.insert(idx); - - if DataType::is_nested(self.file_schema.field(idx).data_type()) { - self.non_primitive_columns = true; - return Ok(Transformed::new(node, false, TreeNodeRecursion::Jump)); - } - } else if self.table_schema.index_of(column.name()).is_err() { - // If the column does not exist in the (un-projected) table schema then - // it must be a projected column. - self.projected_columns = true; - return Ok(Transformed::new(node, false, TreeNodeRecursion::Jump)); + if let Some(recursion) = self.check_single_column(column.name()) { + return Ok(Transformed::new(node, false, recursion)); } } Ok(Transformed::no(node)) } + /// After visiting all children, rewrite column references to nulls if + /// they are not in the file schema. + /// We do this because they won't be relevant if they're not in the file schema, since that's + /// the only thing we're dealing with here as this is only used for the parquet pushdown during + /// scanning fn f_up( &mut self, expr: Arc, ) -> Result>> { if let Some(column) = expr.as_any().downcast_ref::() { + // if the expression is a column, is it in the file schema? if self.file_schema.field_with_name(column.name()).is_err() { - // the column expr must be in the table schema - return match self.table_schema.field_with_name(column.name()) { - Ok(field) => { - // return the null value corresponding to the data type + return self + .table_schema + .field_with_name(column.name()) + .and_then(|field| { + // Replace the column reference with a NULL (using the type from the table schema) + // e.g. `column = 'foo'` is rewritten be transformed to `NULL = 'foo'` + // + // See comments on `FilterCandidateBuilder` for more information let null_value = ScalarValue::try_from(field.data_type())?; - Ok(Transformed::yes(Arc::new(Literal::new(null_value)))) - } - Err(e) => { - // If the column is not in the table schema, should throw the error - arrow_err!(e) - } - }; + Ok(Transformed::yes(Arc::new(Literal::new(null_value)) as _)) + }) + // If the column is not in the table schema, should throw the error + .map_err(|e| arrow_datafusion_err!(e)); } } @@ -260,6 +384,69 @@ impl<'a> TreeNodeRewriter for FilterCandidateBuilder<'a> { } } +type ProjectionAndExpr = (BTreeSet, Arc); + +// Checks if a given expression can be pushed down into `ParquetExec` as opposed to being evaluated +// post-parquet-scan in a `FilterExec`. If it can be pushed down, this returns returns all the +// columns in the given expression so that they can be used in the parquet scanning, along with the +// expression rewritten as defined in [`PushdownChecker::f_up`] +fn pushdown_columns( + expr: Arc, + file_schema: &Schema, + table_schema: &Schema, +) -> Result> { + let mut checker = PushdownChecker::new(file_schema, table_schema); + + let expr = expr.rewrite(&mut checker).data()?; + + Ok((!checker.prevents_pushdown()).then_some((checker.required_column_indices, expr))) +} + +/// creates a PushdownChecker for a single use to check a given column with the given schemes. Used +/// to check preemptively if a column name would prevent pushdowning. +/// effectively does the inverse of [`pushdown_columns`] does, but with a single given column +/// (instead of traversing the entire tree to determine this) +fn would_column_prevent_pushdown( + column_name: &str, + file_schema: &Schema, + table_schema: &Schema, +) -> bool { + let mut checker = PushdownChecker::new(file_schema, table_schema); + + // the return of this is only used for [`PushdownChecker::f_down()`], so we can safely ignore + // it here. I'm just verifying we know the return type of this so nobody accidentally changes + // the return type of this fn and it gets implicitly ignored here. + let _: Option = checker.check_single_column(column_name); + + // and then return a value based on the state of the checker + checker.prevents_pushdown() +} + +/// Recurses through expr as a trea, finds all `column`s, and checks if any of them would prevent +/// this expression from being predicate pushed down. If any of them would, this returns false. +/// Otherwise, true. +pub fn can_expr_be_pushed_down_with_schemas( + expr: &datafusion_expr::Expr, + file_schema: &Schema, + table_schema: &Schema, +) -> bool { + let mut can_be_pushed = true; + expr.apply(|expr| match expr { + datafusion_expr::Expr::Column(column) => { + can_be_pushed &= + !would_column_prevent_pushdown(column.name(), file_schema, table_schema); + Ok(if can_be_pushed { + TreeNodeRecursion::Jump + } else { + TreeNodeRecursion::Stop + }) + } + _ => Ok(TreeNodeRecursion::Continue), + }) + .unwrap(); // we never return an Err, so we can safely unwrap this + can_be_pushed +} + /// Computes the projection required to go from the file's schema order to the projected /// order expected by this filter /// @@ -286,9 +473,11 @@ fn remap_projection(src: &[usize]) -> Vec { projection } -/// Calculate the total compressed size of all `Column's required for -/// predicate `Expr`. This should represent the total amount of file IO -/// required to evaluate the predicate. +/// Calculate the total compressed size of all `Column`'s required for +/// predicate `Expr`. +/// +/// This value represents the total amount of IO required to evaluate the +/// predicate. fn size_of_columns( columns: &BTreeSet, metadata: &ParquetMetaData, @@ -304,8 +493,10 @@ fn size_of_columns( Ok(total_size) } -/// For a given set of `Column`s required for predicate `Expr` determine whether all -/// columns are sorted. Sorted columns may be queried more efficiently in the presence of +/// For a given set of `Column`s required for predicate `Expr` determine whether +/// all columns are sorted. +/// +/// Sorted columns may be queried more efficiently in the presence of /// a PageIndex. fn columns_sorted( _columns: &BTreeSet, @@ -315,7 +506,20 @@ fn columns_sorted( Ok(false) } -/// Build a [`RowFilter`] from the given predicate `Expr` +/// Build a [`RowFilter`] from the given predicate `Expr` if possible +/// +/// # returns +/// * `Ok(Some(row_filter))` if the expression can be used as RowFilter +/// * `Ok(None)` if the expression cannot be used as an RowFilter +/// * `Err(e)` if an error occurs while building the filter +/// +/// Note that the returned `RowFilter` may not contains all conjuncts in the +/// original expression. This is because some conjuncts may not be able to be +/// evaluated as an `ArrowPredicate` and will be ignored. +/// +/// For example, if the expression is `a = 1 AND b = 2 AND c = 3` and `b = 2` +/// can not be evaluated for some reason, the returned `RowFilter` will contain +/// `a = 1` and `c = 3`. pub fn build_row_filter( expr: &Arc, file_schema: &Schema, @@ -323,87 +527,74 @@ pub fn build_row_filter( metadata: &ParquetMetaData, reorder_predicates: bool, file_metrics: &ParquetFileMetrics, + schema_mapping: Arc, ) -> Result> { - let rows_filtered = &file_metrics.pushdown_rows_filtered; - let time = &file_metrics.pushdown_eval_time; + let rows_pruned = &file_metrics.pushdown_rows_pruned; + let rows_matched = &file_metrics.pushdown_rows_matched; + let time = &file_metrics.row_pushdown_eval_time; + // Split into conjuncts: + // `a = 1 AND b = 2 AND c = 3` -> [`a = 1`, `b = 2`, `c = 3`] let predicates = split_conjunction(expr); + // Determine which conjuncts can be evaluated as ArrowPredicates, if any let mut candidates: Vec = predicates .into_iter() - .flat_map(|expr| { - if let Ok(candidate) = - FilterCandidateBuilder::new(expr.clone(), file_schema, table_schema) - .build(metadata) - { - candidate - } else { - None - } + .map(|expr| { + FilterCandidateBuilder::new(expr.clone(), file_schema, table_schema) + .build(metadata) }) + .collect::, _>>()? + .into_iter() + .flatten() .collect(); + // no candidates if candidates.is_empty() { - Ok(None) - } else if reorder_predicates { - candidates.sort_by_key(|c| c.required_bytes); - - let (indexed_candidates, other_candidates): (Vec<_>, Vec<_>) = - candidates.into_iter().partition(|c| c.can_use_index); - - let mut filters: Vec> = vec![]; - - for candidate in indexed_candidates { - let filter = DatafusionArrowPredicate::try_new( - candidate, - file_schema, - metadata, - rows_filtered.clone(), - time.clone(), - )?; - - filters.push(Box::new(filter)); - } - - for candidate in other_candidates { - let filter = DatafusionArrowPredicate::try_new( - candidate, - file_schema, - metadata, - rows_filtered.clone(), - time.clone(), - )?; + return Ok(None); + } - filters.push(Box::new(filter)); - } + if reorder_predicates { + candidates.sort_unstable_by(|c1, c2| { + match c1.can_use_index.cmp(&c2.can_use_index) { + Ordering::Equal => c1.required_bytes.cmp(&c2.required_bytes), + ord => ord, + } + }); + } - Ok(Some(RowFilter::new(filters))) - } else { - let mut filters: Vec> = vec![]; - for candidate in candidates { - let filter = DatafusionArrowPredicate::try_new( + candidates + .into_iter() + .map(|candidate| { + DatafusionArrowPredicate::try_new( candidate, file_schema, metadata, - rows_filtered.clone(), + rows_pruned.clone(), + rows_matched.clone(), time.clone(), - )?; - - filters.push(Box::new(filter)); - } - - Ok(Some(RowFilter::new(filters))) - } + Arc::clone(&schema_mapping), + ) + .map(|pred| Box::new(pred) as _) + }) + .collect::, _>>() + .map(|filters| Some(RowFilter::new(filters))) } #[cfg(test)] mod test { use super::*; + use crate::datasource::schema_adapter::{ + DefaultSchemaAdapterFactory, SchemaAdapterFactory, + }; + use arrow::datatypes::Field; - use datafusion_common::ToDFSchema; - use datafusion_expr::execution_props::ExecutionProps; + use arrow_schema::{Fields, TimeUnit::Nanosecond}; use datafusion_expr::{cast, col, lit, Expr}; - use datafusion_physical_expr::create_physical_expr; + use datafusion_physical_expr::planner::logical2physical; + use datafusion_physical_plan::metrics::{Count, Time}; + + use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder; use parquet::arrow::parquet_to_arrow_schema; use parquet::file::reader::{FileReader, SerializedFileReader}; use rand::prelude::*; @@ -473,6 +664,89 @@ mod test { ); } + #[test] + fn test_filter_type_coercion() { + let testdata = crate::test_util::parquet_test_data(); + let file = std::fs::File::open(format!("{testdata}/alltypes_plain.parquet")) + .expect("opening file"); + + let parquet_reader_builder = + ParquetRecordBatchReaderBuilder::try_new(file).expect("creating reader"); + let metadata = parquet_reader_builder.metadata().clone(); + let file_schema = parquet_reader_builder.schema().clone(); + + // This is the schema we would like to coerce to, + // which is different from the physical schema of the file. + let table_schema = Schema::new(vec![Field::new( + "timestamp_col", + DataType::Timestamp(Nanosecond, Some(Arc::from("UTC"))), + false, + )]); + + let table_ref = Arc::new(table_schema.clone()); + let schema_adapter = + DefaultSchemaAdapterFactory.create(Arc::clone(&table_ref), table_ref); + let (schema_mapping, _) = schema_adapter + .map_schema(&file_schema) + .expect("creating schema mapping"); + + let mut parquet_reader = parquet_reader_builder.build().expect("building reader"); + + // Parquet file is small, we only need 1 recordbatch + let first_rb = parquet_reader + .next() + .expect("expected record batch") + .expect("expected error free record batch"); + + // Test all should fail + let expr = col("timestamp_col").lt(Expr::Literal( + ScalarValue::TimestampNanosecond(Some(1), Some(Arc::from("UTC"))), + )); + let expr = logical2physical(&expr, &table_schema); + let candidate = FilterCandidateBuilder::new(expr, &file_schema, &table_schema) + .build(&metadata) + .expect("building candidate") + .expect("candidate expected"); + + let mut row_filter = DatafusionArrowPredicate::try_new( + candidate, + &file_schema, + &metadata, + Count::new(), + Count::new(), + Time::new(), + Arc::clone(&schema_mapping), + ) + .expect("creating filter predicate"); + + let filtered = row_filter.evaluate(first_rb.clone()); + assert!(matches!(filtered, Ok(a) if a == BooleanArray::from(vec![false; 8]))); + + // Test all should pass + let expr = col("timestamp_col").gt(Expr::Literal( + ScalarValue::TimestampNanosecond(Some(0), Some(Arc::from("UTC"))), + )); + let expr = logical2physical(&expr, &table_schema); + let candidate = FilterCandidateBuilder::new(expr, &file_schema, &table_schema) + .build(&metadata) + .expect("building candidate") + .expect("candidate expected"); + + let mut row_filter = DatafusionArrowPredicate::try_new( + candidate, + &file_schema, + &metadata, + Count::new(), + Count::new(), + Time::new(), + schema_mapping, + ) + .expect("creating filter predicate"); + + let filtered = row_filter.evaluate(first_rb); + assert!(matches!(filtered, Ok(a) if a == BooleanArray::from(vec![true; 8]))); + } + #[test] fn test_remap_projection() { let mut rng = thread_rng(); @@ -491,9 +765,86 @@ mod test { } } - fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { - let df_schema = schema.clone().to_dfschema().unwrap(); - let execution_props = ExecutionProps::new(); - create_physical_expr(expr, &df_schema, &execution_props).unwrap() + #[test] + fn nested_data_structures_prevent_pushdown() { + let table_schema = get_basic_table_schema(); + + let file_schema = Schema::new(vec![Field::new( + "list_col", + DataType::Struct(Fields::empty()), + true, + )]); + + let expr = col("list_col").is_not_null(); + + assert!(!can_expr_be_pushed_down_with_schemas( + &expr, + &file_schema, + &table_schema + )); + } + + #[test] + fn projected_columns_prevent_pushdown() { + let table_schema = get_basic_table_schema(); + + let file_schema = + Schema::new(vec![Field::new("existing_col", DataType::Int64, true)]); + + let expr = col("nonexistent_column").is_null(); + + assert!(!can_expr_be_pushed_down_with_schemas( + &expr, + &file_schema, + &table_schema + )); + } + + #[test] + fn basic_expr_doesnt_prevent_pushdown() { + let table_schema = get_basic_table_schema(); + + let file_schema = Schema::new(vec![Field::new("str_col", DataType::Utf8, true)]); + + let expr = col("str_col").is_null(); + + assert!(can_expr_be_pushed_down_with_schemas( + &expr, + &file_schema, + &table_schema + )); + } + + #[test] + fn complex_expr_doesnt_prevent_pushdown() { + let table_schema = get_basic_table_schema(); + + let file_schema = Schema::new(vec![ + Field::new("str_col", DataType::Utf8, true), + Field::new("int_col", DataType::UInt64, true), + ]); + + let expr = col("str_col") + .is_not_null() + .or(col("int_col").gt(Expr::Literal(ScalarValue::UInt64(Some(5))))); + + assert!(can_expr_be_pushed_down_with_schemas( + &expr, + &file_schema, + &table_schema + )); + } + + fn get_basic_table_schema() -> Schema { + let testdata = crate::test_util::parquet_test_data(); + let file = std::fs::File::open(format!("{testdata}/alltypes_plain.parquet")) + .expect("opening file"); + + let reader = SerializedFileReader::new(file).expect("creating reader"); + + let metadata = reader.metadata(); + + parquet_to_arrow_schema(metadata.file_metadata().schema_descr(), None) + .expect("parsing schema") } } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs similarity index 64% rename from datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs rename to datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs index bcd9e1fa4479..7406676652f6 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs @@ -15,13 +15,15 @@ // specific language governing permissions and limitations // under the License. +use crate::datasource::listing::FileRange; +use crate::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; use arrow::{array::ArrayRef, datatypes::Schema}; use arrow_array::BooleanArray; -use arrow_schema::FieldRef; -use datafusion_common::{Column, ScalarValue}; +use datafusion_common::{Column, Result, ScalarValue}; +use parquet::arrow::arrow_reader::statistics::StatisticsConverter; +use parquet::arrow::parquet_column; use parquet::basic::Type; use parquet::data_type::Decimal; -use parquet::file::metadata::ColumnChunkMetaData; use parquet::schema::types::SchemaDescriptor; use parquet::{ arrow::{async_reader::AsyncFileReader, ParquetRecordBatchStreamBuilder}, @@ -29,154 +31,198 @@ use parquet::{ file::metadata::RowGroupMetaData, }; use std::collections::{HashMap, HashSet}; +use std::sync::Arc; -use crate::datasource::listing::FileRange; -use crate::datasource::physical_plan::parquet::statistics::{ - max_statistics, min_statistics, parquet_column, -}; -use crate::physical_optimizer::pruning::{PruningPredicate, PruningStatistics}; - -use super::ParquetFileMetrics; +use super::{ParquetAccessPlan, ParquetFileMetrics}; -/// Prune row groups based on statistics -/// -/// Returns a vector of indexes into `groups` which should be scanned. -/// -/// If an index is NOT present in the returned Vec it means the -/// predicate filtered all the row group. +/// Reduces the [`ParquetAccessPlan`] based on row group level metadata. /// -/// If an index IS present in the returned Vec it means the predicate -/// did not filter out that row group. -/// -/// Note: This method currently ignores ColumnOrder -/// -pub(crate) fn prune_row_groups_by_statistics( - arrow_schema: &Schema, - parquet_schema: &SchemaDescriptor, - groups: &[RowGroupMetaData], - range: Option, - predicate: Option<&PruningPredicate>, - metrics: &ParquetFileMetrics, -) -> Vec { - let mut filtered = Vec::with_capacity(groups.len()); - for (idx, metadata) in groups.iter().enumerate() { - if let Some(range) = &range { - // figure out where the first dictionary page (or first data page are) +/// This struct implements the various types of pruning that are applied to a +/// set of row groups within a parquet file, progressively narrowing down the +/// set of row groups (and ranges/selections within those row groups) that +/// should be scanned, based on the available metadata. +#[derive(Debug, Clone, PartialEq)] +pub struct RowGroupAccessPlanFilter { + /// which row groups should be accessed + access_plan: ParquetAccessPlan, +} + +impl RowGroupAccessPlanFilter { + /// Create a new `RowGroupPlanBuilder` for pruning out the groups to scan + /// based on metadata and statistics + pub fn new(access_plan: ParquetAccessPlan) -> Self { + Self { access_plan } + } + + /// Return true if there are no row groups + pub fn is_empty(&self) -> bool { + self.access_plan.is_empty() + } + + /// Returns the inner access plan + pub fn build(self) -> ParquetAccessPlan { + self.access_plan + } + + /// Prune remaining row groups to only those within the specified range. + /// + /// Updates this set to mark row groups that should not be scanned + /// + /// # Panics + /// if `groups.len() != self.len()` + pub fn prune_by_range(&mut self, groups: &[RowGroupMetaData], range: &FileRange) { + assert_eq!(groups.len(), self.access_plan.len()); + for (idx, metadata) in groups.iter().enumerate() { + if !self.access_plan.should_scan(idx) { + continue; + } + + // Skip the row group if the first dictionary/data page are not + // within the range. + // // note don't use the location of metadata // let col = metadata.column(0); let offset = col .dictionary_page_offset() .unwrap_or_else(|| col.data_page_offset()); - if offset < range.start || offset >= range.end { - continue; + if !range.contains(offset) { + self.access_plan.skip(idx); } } + } + /// Prune remaining row groups using min/max/null_count statistics and + /// the [`PruningPredicate`] to determine if the predicate can not be true. + /// + /// Updates this set to mark row groups that should not be scanned + /// + /// Note: This method currently ignores ColumnOrder + /// + /// + /// # Panics + /// if `groups.len() != self.len()` + pub fn prune_by_statistics( + &mut self, + arrow_schema: &Schema, + parquet_schema: &SchemaDescriptor, + groups: &[RowGroupMetaData], + predicate: &PruningPredicate, + metrics: &ParquetFileMetrics, + ) { + // scoped timer updates on drop + let _timer_guard = metrics.statistics_eval_time.timer(); + + assert_eq!(groups.len(), self.access_plan.len()); + // Indexes of row groups still to scan + let row_group_indexes = self.access_plan.row_group_indexes(); + let row_group_metadatas = row_group_indexes + .iter() + .map(|&i| &groups[i]) + .collect::>(); - if let Some(predicate) = predicate { - let pruning_stats = RowGroupPruningStatistics { - parquet_schema, - row_group_metadata: metadata, - arrow_schema, - }; - match predicate.prune(&pruning_stats) { - Ok(values) => { - // NB: false means don't scan row group - if !values[0] { + let pruning_stats = RowGroupPruningStatistics { + parquet_schema, + row_group_metadatas, + arrow_schema, + }; + + // try to prune the row groups in a single call + match predicate.prune(&pruning_stats) { + Ok(values) => { + // values[i] is false means the predicate could not be true for row group i + for (idx, &value) in row_group_indexes.iter().zip(values.iter()) { + if !value { + self.access_plan.skip(*idx); metrics.row_groups_pruned_statistics.add(1); - continue; + } else { + metrics.row_groups_matched_statistics.add(1); } } - // stats filter array could not be built - // return a closure which will not filter out any row groups - Err(e) => { - log::debug!("Error evaluating row group predicate values {e}"); - metrics.predicate_evaluation_errors.add(1); - } } - metrics.row_groups_matched_statistics.add(1); + // stats filter array could not be built, so we can't prune + Err(e) => { + log::debug!("Error evaluating row group predicate values {e}"); + metrics.predicate_evaluation_errors.add(1); + } } - - filtered.push(idx) } - filtered -} -/// Prune row groups by bloom filters -/// -/// Returns a vector of indexes into `groups` which should be scanned. -/// -/// If an index is NOT present in the returned Vec it means the -/// predicate filtered all the row group. -/// -/// If an index IS present in the returned Vec it means the predicate -/// did not filter out that row group. -pub(crate) async fn prune_row_groups_by_bloom_filters< - T: AsyncFileReader + Send + 'static, ->( - arrow_schema: &Schema, - builder: &mut ParquetRecordBatchStreamBuilder, - row_groups: &[usize], - groups: &[RowGroupMetaData], - predicate: &PruningPredicate, - metrics: &ParquetFileMetrics, -) -> Vec { - let mut filtered = Vec::with_capacity(groups.len()); - for idx in row_groups { - // get all columns in the predicate that we could use a bloom filter with - let literal_columns = predicate.literal_columns(); - let mut column_sbbf = HashMap::with_capacity(literal_columns.len()); - - for column_name in literal_columns { - let Some((column_idx, _field)) = - parquet_column(builder.parquet_schema(), arrow_schema, &column_name) - else { + /// Prune remaining row groups using available bloom filters and the + /// [`PruningPredicate`]. + /// + /// Updates this set with row groups that should not be scanned + /// + /// # Panics + /// if the builder does not have the same number of row groups as this set + pub async fn prune_by_bloom_filters( + &mut self, + arrow_schema: &Schema, + builder: &mut ParquetRecordBatchStreamBuilder, + predicate: &PruningPredicate, + metrics: &ParquetFileMetrics, + ) { + // scoped timer updates on drop + let _timer_guard = metrics.bloom_filter_eval_time.timer(); + + assert_eq!(builder.metadata().num_row_groups(), self.access_plan.len()); + for idx in 0..self.access_plan.len() { + if !self.access_plan.should_scan(idx) { continue; - }; + } - let bf = match builder - .get_row_group_column_bloom_filter(*idx, column_idx) - .await - { - Ok(Some(bf)) => bf, - Ok(None) => continue, // no bloom filter for this column - Err(e) => { - log::debug!("Ignoring error reading bloom filter: {e}"); - metrics.predicate_evaluation_errors.add(1); + // Attempt to find bloom filters for filtering this row group + let literal_columns = predicate.literal_columns(); + let mut column_sbbf = HashMap::with_capacity(literal_columns.len()); + + for column_name in literal_columns { + let Some((column_idx, _field)) = + parquet_column(builder.parquet_schema(), arrow_schema, &column_name) + else { continue; - } - }; - let physical_type = - builder.parquet_schema().column(column_idx).physical_type(); + }; + + let bf = match builder + .get_row_group_column_bloom_filter(idx, column_idx) + .await + { + Ok(Some(bf)) => bf, + Ok(None) => continue, // no bloom filter for this column + Err(e) => { + log::debug!("Ignoring error reading bloom filter: {e}"); + metrics.predicate_evaluation_errors.add(1); + continue; + } + }; + let physical_type = + builder.parquet_schema().column(column_idx).physical_type(); - column_sbbf.insert(column_name.to_string(), (bf, physical_type)); - } + column_sbbf.insert(column_name.to_string(), (bf, physical_type)); + } - let stats = BloomFilterStatistics { column_sbbf }; + let stats = BloomFilterStatistics { column_sbbf }; - // Can this group be pruned? - let prune_group = match predicate.prune(&stats) { - Ok(values) => !values[0], - Err(e) => { - log::debug!("Error evaluating row group predicate on bloom filter: {e}"); - metrics.predicate_evaluation_errors.add(1); - false - } - }; + // Can this group be pruned? + let prune_group = match predicate.prune(&stats) { + Ok(values) => !values[0], + Err(e) => { + log::debug!( + "Error evaluating row group predicate on bloom filter: {e}" + ); + metrics.predicate_evaluation_errors.add(1); + false + } + }; - if prune_group { - metrics.row_groups_pruned_bloom_filter.add(1); - } else { - if !stats.column_sbbf.is_empty() { + if prune_group { + metrics.row_groups_pruned_bloom_filter.add(1); + self.access_plan.skip(idx) + } else if !stats.column_sbbf.is_empty() { metrics.row_groups_matched_bloom_filter.add(1); } - filtered.push(*idx); } } - filtered } - -/// Implements `PruningStatistics` for Parquet Split Block Bloom Filters (SBBF) +/// Implements [`PruningStatistics`] for Parquet Split Block Bloom Filters (SBBF) struct BloomFilterStatistics { /// Maps column name to the parquet bloom filter and parquet physical type column_sbbf: HashMap, @@ -224,8 +270,12 @@ impl PruningStatistics for BloomFilterStatistics { .iter() .map(|value| { match value { - ScalarValue::Utf8(Some(v)) => sbbf.check(&v.as_str()), - ScalarValue::Binary(Some(v)) => sbbf.check(v), + ScalarValue::Utf8(Some(v)) | ScalarValue::Utf8View(Some(v)) => { + sbbf.check(&v.as_str()) + } + ScalarValue::Binary(Some(v)) | ScalarValue::BinaryView(Some(v)) => { + sbbf.check(v) + } ScalarValue::FixedSizeBinary(_size, Some(v)) => sbbf.check(v), ScalarValue::Boolean(Some(v)) => sbbf.check(v), ScalarValue::Float64(Some(v)) => sbbf.check(v), @@ -299,49 +349,62 @@ impl PruningStatistics for BloomFilterStatistics { } } -/// Wraps [`RowGroupMetaData`] in a way that implements [`PruningStatistics`] -/// -/// Note: This should be implemented for an array of [`RowGroupMetaData`] instead -/// of per row-group +/// Wraps a slice of [`RowGroupMetaData`] in a way that implements [`PruningStatistics`] struct RowGroupPruningStatistics<'a> { parquet_schema: &'a SchemaDescriptor, - row_group_metadata: &'a RowGroupMetaData, + row_group_metadatas: Vec<&'a RowGroupMetaData>, arrow_schema: &'a Schema, } impl<'a> RowGroupPruningStatistics<'a> { - /// Lookups up the parquet column by name - fn column(&self, name: &str) -> Option<(&ColumnChunkMetaData, &FieldRef)> { - let (idx, field) = parquet_column(self.parquet_schema, self.arrow_schema, name)?; - Some((self.row_group_metadata.column(idx), field)) + /// Return an iterator over the row group metadata + fn metadata_iter(&'a self) -> impl Iterator + 'a { + self.row_group_metadatas.iter().copied() + } + + fn statistics_converter<'b>( + &'a self, + column: &'b Column, + ) -> Result> { + Ok(StatisticsConverter::try_new( + &column.name, + self.arrow_schema, + self.parquet_schema, + )?) } } impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> { fn min_values(&self, column: &Column) -> Option { - let (column, field) = self.column(&column.name)?; - min_statistics(field.data_type(), std::iter::once(column.statistics())).ok() + self.statistics_converter(column) + .and_then(|c| Ok(c.row_group_mins(self.metadata_iter())?)) + .ok() } fn max_values(&self, column: &Column) -> Option { - let (column, field) = self.column(&column.name)?; - max_statistics(field.data_type(), std::iter::once(column.statistics())).ok() + self.statistics_converter(column) + .and_then(|c| Ok(c.row_group_maxes(self.metadata_iter())?)) + .ok() } fn num_containers(&self) -> usize { - 1 + self.row_group_metadatas.len() } fn null_counts(&self, column: &Column) -> Option { - let (c, _) = self.column(&column.name)?; - let scalar = ScalarValue::UInt64(Some(c.statistics()?.null_count())); - scalar.to_array().ok() + self.statistics_converter(column) + .and_then(|c| Ok(c.row_group_null_counts(self.metadata_iter())?)) + .ok() + .map(|counts| Arc::new(counts) as ArrayRef) } fn row_counts(&self, column: &Column) -> Option { - let (c, _) = self.column(&column.name)?; - let scalar = ScalarValue::UInt64(Some(c.num_values() as u64)); - scalar.to_array().ok() + // row counts are the same for all columns in a row group + self.statistics_converter(column) + .and_then(|c| Ok(c.row_group_row_counts(self.metadata_iter())?)) + .ok() + .flatten() + .map(|counts| Arc::new(counts) as ArrayRef) } fn contained( @@ -355,25 +418,28 @@ impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> { #[cfg(test)] mod tests { + use std::ops::Rem; + use std::sync::Arc; + use super::*; - use crate::datasource::physical_plan::parquet::ParquetFileReader; + use crate::datasource::physical_plan::parquet::reader::ParquetFileReader; use crate::physical_plan::metrics::ExecutionPlanMetricsSet; + use arrow::datatypes::DataType::Decimal128; use arrow::datatypes::{DataType, Field}; - use datafusion_common::{Result, ToDFSchema}; - use datafusion_expr::execution_props::ExecutionProps; + use datafusion_common::Result; use datafusion_expr::{cast, col, lit, Expr}; - use datafusion_physical_expr::{create_physical_expr, PhysicalExpr}; + use datafusion_physical_expr::planner::logical2physical; + use parquet::arrow::arrow_to_parquet_schema; use parquet::arrow::async_reader::ParquetObjectReader; use parquet::basic::LogicalType; use parquet::data_type::{ByteArray, FixedLenByteArray}; + use parquet::file::metadata::ColumnChunkMetaData; use parquet::{ basic::Type as PhysicalType, file::statistics::Statistics as ParquetStatistics, schema::types::SchemaDescPtr, }; - use std::ops::Rem; - use std::sync::Arc; struct PrimitiveTypeField { name: &'static str, @@ -431,25 +497,35 @@ mod tests { let schema_descr = get_test_schema_descr(vec![field]); let rgm1 = get_row_group_meta_data( &schema_descr, - vec![ParquetStatistics::int32(Some(1), Some(10), None, 0, false)], + vec![ParquetStatistics::int32( + Some(1), + Some(10), + None, + Some(0), + false, + )], ); let rgm2 = get_row_group_meta_data( &schema_descr, - vec![ParquetStatistics::int32(Some(11), Some(20), None, 0, false)], + vec![ParquetStatistics::int32( + Some(11), + Some(20), + None, + Some(0), + false, + )], ); let metrics = parquet_file_metrics(); - assert_eq!( - prune_row_groups_by_statistics( - &schema, - &schema_descr, - &[rgm1, rgm2], - None, - Some(&pruning_predicate), - &metrics - ), - vec![1] + let mut row_groups = RowGroupAccessPlanFilter::new(ParquetAccessPlan::new_all(2)); + row_groups.prune_by_statistics( + &schema, + &schema_descr, + &[rgm1, rgm2], + &pruning_predicate, + &metrics, ); + assert_pruned(row_groups, ExpectedPruning::Some(vec![1])) } #[test] @@ -466,26 +542,30 @@ mod tests { let schema_descr = get_test_schema_descr(vec![field]); let rgm1 = get_row_group_meta_data( &schema_descr, - vec![ParquetStatistics::int32(None, None, None, 0, false)], + vec![ParquetStatistics::int32(None, None, None, Some(0), false)], ); let rgm2 = get_row_group_meta_data( &schema_descr, - vec![ParquetStatistics::int32(Some(11), Some(20), None, 0, false)], + vec![ParquetStatistics::int32( + Some(11), + Some(20), + None, + Some(0), + false, + )], ); let metrics = parquet_file_metrics(); // missing statistics for first row group mean that the result from the predicate expression // is null / undefined so the first row group can't be filtered out - assert_eq!( - prune_row_groups_by_statistics( - &schema, - &schema_descr, - &[rgm1, rgm2], - None, - Some(&pruning_predicate), - &metrics - ), - vec![0, 1] + let mut row_groups = RowGroupAccessPlanFilter::new(ParquetAccessPlan::new_all(2)); + row_groups.prune_by_statistics( + &schema, + &schema_descr, + &[rgm1, rgm2], + &pruning_predicate, + &metrics, ); + assert_pruned(row_groups, ExpectedPruning::None); } #[test] @@ -508,15 +588,15 @@ mod tests { let rgm1 = get_row_group_meta_data( &schema_descr, vec![ - ParquetStatistics::int32(Some(1), Some(10), None, 0, false), - ParquetStatistics::int32(Some(1), Some(10), None, 0, false), + ParquetStatistics::int32(Some(1), Some(10), None, Some(0), false), + ParquetStatistics::int32(Some(1), Some(10), None, Some(0), false), ], ); let rgm2 = get_row_group_meta_data( &schema_descr, vec![ - ParquetStatistics::int32(Some(11), Some(20), None, 0, false), - ParquetStatistics::int32(Some(11), Some(20), None, 0, false), + ParquetStatistics::int32(Some(11), Some(20), None, Some(0), false), + ParquetStatistics::int32(Some(11), Some(20), None, Some(0), false), ], ); @@ -524,17 +604,15 @@ mod tests { let groups = &[rgm1, rgm2]; // the first row group is still filtered out because the predicate expression can be partially evaluated // when conditions are joined using AND - assert_eq!( - prune_row_groups_by_statistics( - &schema, - &schema_descr, - groups, - None, - Some(&pruning_predicate), - &metrics - ), - vec![1] + let mut row_groups = RowGroupAccessPlanFilter::new(ParquetAccessPlan::new_all(2)); + row_groups.prune_by_statistics( + &schema, + &schema_descr, + groups, + &pruning_predicate, + &metrics, ); + assert_pruned(row_groups, ExpectedPruning::Some(vec![1])); // if conditions in predicate are joined with OR and an unsupported expression is used // this bypasses the entire predicate expression and no row groups are filtered out @@ -544,17 +622,15 @@ mod tests { // if conditions in predicate are joined with OR and an unsupported expression is used // this bypasses the entire predicate expression and no row groups are filtered out - assert_eq!( - prune_row_groups_by_statistics( - &schema, - &schema_descr, - groups, - None, - Some(&pruning_predicate), - &metrics - ), - vec![0, 1] + let mut row_groups = RowGroupAccessPlanFilter::new(ParquetAccessPlan::new_all(2)); + row_groups.prune_by_statistics( + &schema, + &schema_descr, + groups, + &pruning_predicate, + &metrics, ); + assert_pruned(row_groups, ExpectedPruning::None); } #[test] @@ -585,16 +661,16 @@ mod tests { let rgm1 = get_row_group_meta_data( &schema_descr, vec![ - ParquetStatistics::int32(Some(-10), Some(-1), None, 0, false), // c2 - ParquetStatistics::int32(Some(1), Some(10), None, 0, false), + ParquetStatistics::int32(Some(-10), Some(-1), None, Some(0), false), // c2 + ParquetStatistics::int32(Some(1), Some(10), None, Some(0), false), ], ); // rg1 has c2 greater than zero, c1 less than zero let rgm2 = get_row_group_meta_data( &schema_descr, vec![ - ParquetStatistics::int32(Some(1), Some(10), None, 0, false), - ParquetStatistics::int32(Some(-10), Some(-1), None, 0, false), + ParquetStatistics::int32(Some(1), Some(10), None, Some(0), false), + ParquetStatistics::int32(Some(-10), Some(-1), None, Some(0), false), ], ); @@ -602,17 +678,15 @@ mod tests { let groups = &[rgm1, rgm2]; // the first row group should be left because c1 is greater than zero // the second should be filtered out because c1 is less than zero - assert_eq!( - prune_row_groups_by_statistics( - &file_schema, // NB must be file schema, not table_schema - &schema_descr, - groups, - None, - Some(&pruning_predicate), - &metrics - ), - vec![0] + let mut row_groups = RowGroupAccessPlanFilter::new(ParquetAccessPlan::new_all(2)); + row_groups.prune_by_statistics( + &file_schema, + &schema_descr, + groups, + &pruning_predicate, + &metrics, ); + assert_pruned(row_groups, ExpectedPruning::Some(vec![0])); } fn gen_row_group_meta_data_for_pruning_predicate() -> Vec { @@ -623,15 +697,15 @@ mod tests { let rgm1 = get_row_group_meta_data( &schema_descr, vec![ - ParquetStatistics::int32(Some(1), Some(10), None, 0, false), - ParquetStatistics::boolean(Some(false), Some(true), None, 0, false), + ParquetStatistics::int32(Some(1), Some(10), None, Some(0), false), + ParquetStatistics::boolean(Some(false), Some(true), None, Some(0), false), ], ); let rgm2 = get_row_group_meta_data( &schema_descr, vec![ - ParquetStatistics::int32(Some(11), Some(20), None, 0, false), - ParquetStatistics::boolean(Some(false), Some(true), None, 1, false), + ParquetStatistics::int32(Some(11), Some(20), None, Some(0), false), + ParquetStatistics::boolean(Some(false), Some(true), None, Some(1), false), ], ); vec![rgm1, rgm2] @@ -653,17 +727,15 @@ mod tests { let metrics = parquet_file_metrics(); // First row group was filtered out because it contains no null value on "c2". - assert_eq!( - prune_row_groups_by_statistics( - &schema, - &schema_descr, - &groups, - None, - Some(&pruning_predicate), - &metrics - ), - vec![1] + let mut row_groups = RowGroupAccessPlanFilter::new(ParquetAccessPlan::new_all(2)); + row_groups.prune_by_statistics( + &schema, + &schema_descr, + &groups, + &pruning_predicate, + &metrics, ); + assert_pruned(row_groups, ExpectedPruning::Some(vec![1])); } #[test] @@ -687,17 +759,16 @@ mod tests { let metrics = parquet_file_metrics(); // bool = NULL always evaluates to NULL (and thus will not // pass predicates. Ideally these should both be false - assert_eq!( - prune_row_groups_by_statistics( - &schema, - &schema_descr, - &groups, - None, - Some(&pruning_predicate), - &metrics - ), - vec![1] + let mut row_groups = + RowGroupAccessPlanFilter::new(ParquetAccessPlan::new_all(groups.len())); + row_groups.prune_by_statistics( + &schema, + &schema_descr, + &groups, + &pruning_predicate, + &metrics, ); + assert_pruned(row_groups, ExpectedPruning::Some(vec![1])); } #[test] @@ -708,11 +779,8 @@ mod tests { // INT32: c1 > 5, the c1 is decimal(9,2) // The type of scalar value if decimal(9,2), don't need to do cast - let schema = Arc::new(Schema::new(vec![Field::new( - "c1", - DataType::Decimal128(9, 2), - false, - )])); + let schema = + Arc::new(Schema::new(vec![Field::new("c1", Decimal128(9, 2), false)])); let field = PrimitiveTypeField::new("c1", PhysicalType::INT32) .with_logical_type(LogicalType::Decimal { scale: 2, @@ -732,7 +800,7 @@ mod tests { Some(100), Some(600), None, - 0, + Some(0), false, )], ); @@ -740,36 +808,46 @@ mod tests { &schema_descr, // [0.1, 0.2] // c1 > 5, this row group will not be included in the results. - vec![ParquetStatistics::int32(Some(10), Some(20), None, 0, false)], + vec![ParquetStatistics::int32( + Some(10), + Some(20), + None, + Some(0), + false, + )], ); let rgm3 = get_row_group_meta_data( &schema_descr, // [1, None] // c1 > 5, this row group can not be filtered out, so will be included in the results. - vec![ParquetStatistics::int32(Some(100), None, None, 0, false)], + vec![ParquetStatistics::int32( + Some(100), + None, + None, + Some(0), + false, + )], ); let metrics = parquet_file_metrics(); - assert_eq!( - prune_row_groups_by_statistics( - &schema, - &schema_descr, - &[rgm1, rgm2, rgm3], - None, - Some(&pruning_predicate), - &metrics - ), - vec![0, 2] + let mut row_groups = RowGroupAccessPlanFilter::new(ParquetAccessPlan::new_all(3)); + row_groups.prune_by_statistics( + &schema, + &schema_descr, + &[rgm1, rgm2, rgm3], + &pruning_predicate, + &metrics, ); + assert_pruned(row_groups, ExpectedPruning::Some(vec![0, 2])); + } + #[test] + fn row_group_pruning_predicate_decimal_type2() { // INT32: c1 > 5, but parquet decimal type has different precision or scale to arrow decimal // The c1 type is decimal(9,0) in the parquet file, and the type of scalar is decimal(5,2). // We should convert all type to the coercion type, which is decimal(11,2) // The decimal of arrow is decimal(5,2), the decimal of parquet is decimal(9,0) - let schema = Arc::new(Schema::new(vec![Field::new( - "c1", - DataType::Decimal128(9, 0), - false, - )])); + let schema = + Arc::new(Schema::new(vec![Field::new("c1", Decimal128(9, 0), false)])); let field = PrimitiveTypeField::new("c1", PhysicalType::INT32) .with_logical_type(LogicalType::Decimal { @@ -779,7 +857,7 @@ mod tests { .with_scale(0) .with_precision(9); let schema_descr = get_test_schema_descr(vec![field]); - let expr = cast(col("c1"), DataType::Decimal128(11, 2)).gt(cast( + let expr = cast(col("c1"), Decimal128(11, 2)).gt(cast( lit(ScalarValue::Decimal128(Some(500), 5, 2)), Decimal128(11, 2), )); @@ -793,7 +871,7 @@ mod tests { Some(100), Some(600), None, - 0, + Some(0), false, )], ); @@ -801,37 +879,69 @@ mod tests { &schema_descr, // [10, 20] // c1 > 5, this row group will be included in the results. - vec![ParquetStatistics::int32(Some(10), Some(20), None, 0, false)], + vec![ParquetStatistics::int32( + Some(10), + Some(20), + None, + Some(0), + false, + )], ); let rgm3 = get_row_group_meta_data( &schema_descr, // [0, 2] // c1 > 5, this row group will not be included in the results. - vec![ParquetStatistics::int32(Some(0), Some(2), None, 0, false)], + vec![ParquetStatistics::int32( + Some(0), + Some(2), + None, + Some(0), + false, + )], ); let rgm4 = get_row_group_meta_data( &schema_descr, // [None, 2] - // c1 > 5, this row group can not be filtered out, so will be included in the results. - vec![ParquetStatistics::int32(None, Some(2), None, 0, false)], + // c1 > 5, this row group will also not be included in the results + // (the min value is unknown, but the max value is 2, so no values can be greater than 5) + vec![ParquetStatistics::int32( + None, + Some(2), + None, + Some(0), + false, + )], ); - let metrics = parquet_file_metrics(); - assert_eq!( - prune_row_groups_by_statistics( - &schema, - &schema_descr, - &[rgm1, rgm2, rgm3, rgm4], + let rgm5 = get_row_group_meta_data( + &schema_descr, + // [2, None] + // c1 > 5, this row group must be included + // (the min value is 2, but the max value is unknown, so it may have values greater than 5) + vec![ParquetStatistics::int32( + Some(2), + None, None, - Some(&pruning_predicate), - &metrics - ), - vec![0, 1, 3] + Some(0), + false, + )], ); - + let metrics = parquet_file_metrics(); + let mut row_groups = RowGroupAccessPlanFilter::new(ParquetAccessPlan::new_all(5)); + row_groups.prune_by_statistics( + &schema, + &schema_descr, + &[rgm1, rgm2, rgm3, rgm4, rgm5], + &pruning_predicate, + &metrics, + ); + assert_pruned(row_groups, ExpectedPruning::Some(vec![0, 1, 4])); + } + #[test] + fn row_group_pruning_predicate_decimal_type3() { // INT64: c1 < 5, the c1 is decimal(18,2) let schema = Arc::new(Schema::new(vec![Field::new( "c1", - DataType::Decimal128(18, 2), + Decimal128(18, 2), false, )])); let field = PrimitiveTypeField::new("c1", PhysicalType::INT64) @@ -852,38 +962,44 @@ mod tests { Some(600), Some(800), None, - 0, + Some(0), false, )], ); let rgm2 = get_row_group_meta_data( &schema_descr, // [0.1, 0.2] - vec![ParquetStatistics::int64(Some(10), Some(20), None, 0, false)], + vec![ParquetStatistics::int64( + Some(10), + Some(20), + None, + Some(0), + false, + )], ); let rgm3 = get_row_group_meta_data( &schema_descr, // [0.1, 0.2] - vec![ParquetStatistics::int64(None, None, None, 0, false)], + vec![ParquetStatistics::int64(None, None, None, Some(0), false)], ); let metrics = parquet_file_metrics(); - assert_eq!( - prune_row_groups_by_statistics( - &schema, - &schema_descr, - &[rgm1, rgm2, rgm3], - None, - Some(&pruning_predicate), - &metrics - ), - vec![1, 2] + let mut row_groups = RowGroupAccessPlanFilter::new(ParquetAccessPlan::new_all(3)); + row_groups.prune_by_statistics( + &schema, + &schema_descr, + &[rgm1, rgm2, rgm3], + &pruning_predicate, + &metrics, ); - + assert_pruned(row_groups, ExpectedPruning::Some(vec![1, 2])); + } + #[test] + fn row_group_pruning_predicate_decimal_type4() { // FIXED_LENGTH_BYTE_ARRAY: c1 = decimal128(100000, 28, 3), the c1 is decimal(18,2) // the type of parquet is decimal(18,2) let schema = Arc::new(Schema::new(vec![Field::new( "c1", - DataType::Decimal128(18, 2), + Decimal128(18, 2), false, )])); let field = PrimitiveTypeField::new("c1", PhysicalType::FIXED_LEN_BYTE_ARRAY) @@ -896,7 +1012,7 @@ mod tests { .with_byte_len(16); let schema_descr = get_test_schema_descr(vec![field]); // cast the type of c1 to decimal(28,3) - let left = cast(col("c1"), DataType::Decimal128(28, 3)); + let left = cast(col("c1"), Decimal128(28, 3)); let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3))); let expr = logical2physical(&expr, &schema); let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); @@ -913,7 +1029,7 @@ mod tests { 8000i128.to_be_bytes().to_vec(), ))), None, - 0, + Some(0), false, )], ); @@ -929,7 +1045,7 @@ mod tests { 20000i128.to_be_bytes().to_vec(), ))), None, - 0, + Some(0), false, )], ); @@ -937,27 +1053,31 @@ mod tests { let rgm3 = get_row_group_meta_data( &schema_descr, vec![ParquetStatistics::fixed_len_byte_array( - None, None, None, 0, false, + None, + None, + None, + Some(0), + false, )], ); let metrics = parquet_file_metrics(); - assert_eq!( - prune_row_groups_by_statistics( - &schema, - &schema_descr, - &[rgm1, rgm2, rgm3], - None, - Some(&pruning_predicate), - &metrics - ), - vec![1, 2] + let mut row_groups = RowGroupAccessPlanFilter::new(ParquetAccessPlan::new_all(3)); + row_groups.prune_by_statistics( + &schema, + &schema_descr, + &[rgm1, rgm2, rgm3], + &pruning_predicate, + &metrics, ); - + assert_pruned(row_groups, ExpectedPruning::Some(vec![1, 2])); + } + #[test] + fn row_group_pruning_predicate_decimal_type5() { // BYTE_ARRAY: c1 = decimal128(100000, 28, 3), the c1 is decimal(18,2) // the type of parquet is decimal(18,2) let schema = Arc::new(Schema::new(vec![Field::new( "c1", - DataType::Decimal128(18, 2), + Decimal128(18, 2), false, )])); let field = PrimitiveTypeField::new("c1", PhysicalType::BYTE_ARRAY) @@ -970,7 +1090,7 @@ mod tests { .with_byte_len(16); let schema_descr = get_test_schema_descr(vec![field]); // cast the type of c1 to decimal(28,3) - let left = cast(col("c1"), DataType::Decimal128(28, 3)); + let left = cast(col("c1"), Decimal128(28, 3)); let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3))); let expr = logical2physical(&expr, &schema); let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); @@ -983,7 +1103,7 @@ mod tests { // 80.00 Some(ByteArray::from(8000i128.to_be_bytes().to_vec())), None, - 0, + Some(0), false, )], ); @@ -995,26 +1115,30 @@ mod tests { // 200.00 Some(ByteArray::from(20000i128.to_be_bytes().to_vec())), None, - 0, + Some(0), false, )], ); let rgm3 = get_row_group_meta_data( &schema_descr, - vec![ParquetStatistics::byte_array(None, None, None, 0, false)], + vec![ParquetStatistics::byte_array( + None, + None, + None, + Some(0), + false, + )], ); let metrics = parquet_file_metrics(); - assert_eq!( - prune_row_groups_by_statistics( - &schema, - &schema_descr, - &[rgm1, rgm2, rgm3], - None, - Some(&pruning_predicate), - &metrics - ), - vec![1, 2] + let mut row_groups = RowGroupAccessPlanFilter::new(ParquetAccessPlan::new_all(3)); + row_groups.prune_by_statistics( + &schema, + &schema_descr, + &[rgm1, rgm2, rgm3], + &pruning_predicate, + &metrics, ); + assert_pruned(row_groups, ExpectedPruning::Some(vec![1, 2])); } fn get_row_group_meta_data( @@ -1075,12 +1199,6 @@ mod tests { ParquetFileMetrics::new(0, "file.parquet", &metrics) } - fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { - let df_schema = schema.clone().to_dfschema().unwrap(); - let execution_props = ExecutionProps::new(); - create_physical_expr(expr, &df_schema, &execution_props).unwrap() - } - #[tokio::test] async fn test_row_group_bloom_filter_pruning_predicate_simple_expr() { BloomFilterTest::new_data_index_bloom_encoding_stats() @@ -1091,7 +1209,7 @@ mod tests { } #[tokio::test] - async fn test_row_group_bloom_filter_pruning_predicate_mutiple_expr() { + async fn test_row_group_bloom_filter_pruning_predicate_multiple_expr() { BloomFilterTest::new_data_index_bloom_encoding_stats() .with_expect_all_pruned() // generate pruning predicate `(String = "Hello_Not_exists" OR String = "Hello_Not_exists2")` @@ -1105,6 +1223,25 @@ mod tests { .await } + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_multiple_expr_view() { + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_all_pruned() + // generate pruning predicate `(String = "Hello_Not_exists" OR String = "Hello_Not_exists2")` + .run( + lit("1").eq(lit("1")).and( + col(r#""String""#) + .eq(Expr::Literal(ScalarValue::Utf8View(Some(String::from( + "Hello_Not_Exists", + ))))) + .or(col(r#""String""#).eq(Expr::Literal(ScalarValue::Utf8View( + Some(String::from("Hello_Not_Exists2")), + )))), + ), + ) + .await + } + #[tokio::test] async fn test_row_group_bloom_filter_pruning_predicate_sql_in() { // load parquet file @@ -1126,16 +1263,14 @@ mod tests { let pruning_predicate = PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); - let row_groups = vec![0]; let pruned_row_groups = test_row_group_bloom_filter_pruning_predicate( file_name, data, &pruning_predicate, - &row_groups, ) .await .unwrap(); - assert!(pruned_row_groups.is_empty()); + assert!(pruned_row_groups.access_plan.row_group_indexes().is_empty()); } #[tokio::test] @@ -1174,6 +1309,26 @@ mod tests { .await } + #[tokio::test] + async fn test_row_group_bloom_filter_pruning_predicate_with_exists_3_values_view() { + BloomFilterTest::new_data_index_bloom_encoding_stats() + .with_expect_none_pruned() + // generate pruning predicate `(String = "Hello") OR (String = "the quick") OR (String = "are you")` + .run( + col(r#""String""#) + .eq(Expr::Literal(ScalarValue::Utf8View(Some(String::from( + "Hello", + ))))) + .or(col(r#""String""#).eq(Expr::Literal(ScalarValue::Utf8View( + Some(String::from("the quick")), + )))) + .or(col(r#""String""#).eq(Expr::Literal(ScalarValue::Utf8View( + Some(String::from("are you")), + )))), + ) + .await + } + #[tokio::test] async fn test_row_group_bloom_filter_pruning_predicate_with_or_not_eq() { BloomFilterTest::new_data_index_bloom_encoding_stats() @@ -1196,14 +1351,60 @@ mod tests { .await } + // What row groups are expected to be left after pruning + #[derive(Debug)] + enum ExpectedPruning { + All, + /// Only the specified row groups are expected to REMAIN (not what is pruned) + Some(Vec), + None, + } + + impl ExpectedPruning { + /// asserts that the pruned row group match this expectation + fn assert(&self, row_groups: &RowGroupAccessPlanFilter) { + let num_row_groups = row_groups.access_plan.len(); + assert!(num_row_groups > 0); + let num_pruned = (0..num_row_groups) + .filter_map(|i| { + if row_groups.access_plan.should_scan(i) { + None + } else { + Some(1) + } + }) + .sum::(); + + match self { + Self::All => { + assert_eq!( + num_row_groups, num_pruned, + "Expected all row groups to be pruned, but got {row_groups:?}" + ); + } + ExpectedPruning::None => { + assert_eq!( + num_pruned, 0, + "Expected no row groups to be pruned, but got {row_groups:?}" + ); + } + ExpectedPruning::Some(expected) => { + let actual = row_groups.access_plan.row_group_indexes(); + assert_eq!(expected, &actual, "Unexpected row groups pruned. Expected {expected:?}, got {actual:?}"); + } + } + } + } + + fn assert_pruned(row_groups: RowGroupAccessPlanFilter, expected: ExpectedPruning) { + expected.assert(&row_groups); + } + struct BloomFilterTest { file_name: String, schema: Schema, - // which row groups should be attempted to prune - row_groups: Vec, - // which row groups are expected to be left after pruning. Must be set - // otherwise will panic on run() - post_pruning_row_groups: Option>, + // which row groups are expected to be left after pruning + post_pruning_row_groups: ExpectedPruning, } impl BloomFilterTest { @@ -1234,8 +1435,7 @@ mod tests { Self { file_name: String::from("data_index_bloom_encoding_stats.parquet"), schema: Schema::new(vec![Field::new("String", DataType::Utf8, false)]), - row_groups: vec![0], - post_pruning_row_groups: None, + post_pruning_row_groups: ExpectedPruning::None, } } @@ -1248,20 +1448,19 @@ mod tests { DataType::Utf8, false, )]), - row_groups: vec![0], - post_pruning_row_groups: None, + post_pruning_row_groups: ExpectedPruning::None, } } /// Expect all row groups to be pruned pub fn with_expect_all_pruned(mut self) -> Self { - self.post_pruning_row_groups = Some(vec![]); + self.post_pruning_row_groups = ExpectedPruning::All; self } /// Expect all row groups not to be pruned pub fn with_expect_none_pruned(mut self) -> Self { - self.post_pruning_row_groups = Some(self.row_groups.clone()); + self.post_pruning_row_groups = ExpectedPruning::None; self } @@ -1270,13 +1469,9 @@ mod tests { let Self { file_name, schema, - row_groups, post_pruning_row_groups, } = self; - let post_pruning_row_groups = - post_pruning_row_groups.expect("post_pruning_row_groups must be set"); - let testdata = datafusion_common::test_util::parquet_test_data(); let path = format!("{testdata}/{file_name}"); let data = bytes::Bytes::from(std::fs::read(path).unwrap()); @@ -1289,20 +1484,20 @@ mod tests { &file_name, data, &pruning_predicate, - &row_groups, ) .await .unwrap(); - assert_eq!(pruned_row_groups, post_pruning_row_groups); + + post_pruning_row_groups.assert(&pruned_row_groups); } } + /// Evaluates the pruning predicate on the specified row groups and returns the row groups that are left async fn test_row_group_bloom_filter_pruning_predicate( file_name: &str, data: bytes::Bytes, pruning_predicate: &PruningPredicate, - row_groups: &[usize], - ) -> Result> { + ) -> Result { use object_store::{ObjectMeta, ObjectStore}; let object_meta = ObjectMeta { @@ -1314,7 +1509,7 @@ mod tests { }; let in_memory = object_store::memory::InMemory::new(); in_memory - .put(&object_meta.location, data) + .put(&object_meta.location, data.into()) .await .expect("put parquet file into in memory object store"); @@ -1327,17 +1522,17 @@ mod tests { }; let mut builder = ParquetRecordBatchStreamBuilder::new(reader).await.unwrap(); - let metadata = builder.metadata().clone(); - let pruned_row_group = prune_row_groups_by_bloom_filters( - pruning_predicate.schema(), - &mut builder, - row_groups, - metadata.row_groups(), - pruning_predicate, - &file_metrics, - ) - .await; + let access_plan = ParquetAccessPlan::new_all(builder.metadata().num_row_groups()); + let mut pruned_row_groups = RowGroupAccessPlanFilter::new(access_plan); + pruned_row_groups + .prune_by_bloom_filters( + pruning_predicate.schema(), + &mut builder, + pruning_predicate, + &file_metrics, + ) + .await; - Ok(pruned_row_group) + Ok(pruned_row_groups) } } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs deleted file mode 100644 index 8972c261b14a..000000000000 --- a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs +++ /dev/null @@ -1,922 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! [`min_statistics`] and [`max_statistics`] convert statistics in parquet format to arrow [`ArrayRef`]. - -// TODO: potentially move this to arrow-rs: https://github.com/apache/arrow-rs/issues/4328 - -use arrow::{array::ArrayRef, datatypes::DataType}; -use arrow_array::new_empty_array; -use arrow_schema::{FieldRef, Schema}; -use datafusion_common::{Result, ScalarValue}; -use parquet::file::statistics::Statistics as ParquetStatistics; -use parquet::schema::types::SchemaDescriptor; - -// Convert the bytes array to i128. -// The endian of the input bytes array must be big-endian. -pub(crate) fn from_bytes_to_i128(b: &[u8]) -> i128 { - // The bytes array are from parquet file and must be the big-endian. - // The endian is defined by parquet format, and the reference document - // https://github.com/apache/parquet-format/blob/54e53e5d7794d383529dd30746378f19a12afd58/src/main/thrift/parquet.thrift#L66 - i128::from_be_bytes(sign_extend_be(b)) -} - -// Copy from arrow-rs -// https://github.com/apache/arrow-rs/blob/733b7e7fd1e8c43a404c3ce40ecf741d493c21b4/parquet/src/arrow/buffer/bit_util.rs#L55 -// Convert the byte slice to fixed length byte array with the length of 16 -fn sign_extend_be(b: &[u8]) -> [u8; 16] { - assert!(b.len() <= 16, "Array too large, expected less than 16"); - let is_negative = (b[0] & 128u8) == 128u8; - let mut result = if is_negative { [255u8; 16] } else { [0u8; 16] }; - for (d, s) in result.iter_mut().skip(16 - b.len()).zip(b) { - *d = *s; - } - result -} - -/// Extract a single min/max statistics from a [`ParquetStatistics`] object -/// -/// * `$column_statistics` is the `ParquetStatistics` object -/// * `$func is the function` (`min`/`max`) to call to get the value -/// * `$bytes_func` is the function (`min_bytes`/`max_bytes`) to call to get the value as bytes -/// * `$target_arrow_type` is the [`DataType`] of the target statistics -macro_rules! get_statistic { - ($column_statistics:expr, $func:ident, $bytes_func:ident, $target_arrow_type:expr) => {{ - if !$column_statistics.has_min_max_set() { - return None; - } - match $column_statistics { - ParquetStatistics::Boolean(s) => Some(ScalarValue::Boolean(Some(*s.$func()))), - ParquetStatistics::Int32(s) => { - match $target_arrow_type { - // int32 to decimal with the precision and scale - Some(DataType::Decimal128(precision, scale)) => { - Some(ScalarValue::Decimal128( - Some(*s.$func() as i128), - *precision, - *scale, - )) - } - _ => Some(ScalarValue::Int32(Some(*s.$func()))), - } - } - ParquetStatistics::Int64(s) => { - match $target_arrow_type { - // int64 to decimal with the precision and scale - Some(DataType::Decimal128(precision, scale)) => { - Some(ScalarValue::Decimal128( - Some(*s.$func() as i128), - *precision, - *scale, - )) - } - _ => Some(ScalarValue::Int64(Some(*s.$func()))), - } - } - // 96 bit ints not supported - ParquetStatistics::Int96(_) => None, - ParquetStatistics::Float(s) => Some(ScalarValue::Float32(Some(*s.$func()))), - ParquetStatistics::Double(s) => Some(ScalarValue::Float64(Some(*s.$func()))), - ParquetStatistics::ByteArray(s) => { - match $target_arrow_type { - // decimal data type - Some(DataType::Decimal128(precision, scale)) => { - Some(ScalarValue::Decimal128( - Some(from_bytes_to_i128(s.$bytes_func())), - *precision, - *scale, - )) - } - _ => { - let s = std::str::from_utf8(s.$bytes_func()) - .map(|s| s.to_string()) - .ok(); - if s.is_none() { - log::debug!( - "Utf8 statistics is a non-UTF8 value, ignoring it." - ); - } - Some(ScalarValue::Utf8(s)) - } - } - } - // type not fully supported yet - ParquetStatistics::FixedLenByteArray(s) => { - match $target_arrow_type { - // just support specific logical data types, there are others each - // with their own ordering - Some(DataType::Decimal128(precision, scale)) => { - Some(ScalarValue::Decimal128( - Some(from_bytes_to_i128(s.$bytes_func())), - *precision, - *scale, - )) - } - Some(DataType::FixedSizeBinary(size)) => { - let value = s.$bytes_func().to_vec(); - let value = if value.len().try_into() == Ok(*size) { - Some(value) - } else { - log::debug!( - "FixedSizeBinary({}) statistics is a binary of size {}, ignoring it.", - size, - value.len(), - ); - None - }; - Some(ScalarValue::FixedSizeBinary( - *size, - value, - )) - } - _ => None, - } - } - } - }}; -} - -/// Lookups up the parquet column by name -/// -/// Returns the parquet column index and the corresponding arrow field -pub(crate) fn parquet_column<'a>( - parquet_schema: &SchemaDescriptor, - arrow_schema: &'a Schema, - name: &str, -) -> Option<(usize, &'a FieldRef)> { - let (root_idx, field) = arrow_schema.fields.find(name)?; - if field.data_type().is_nested() { - // Nested fields are not supported and require non-trivial logic - // to correctly walk the parquet schema accounting for the - // logical type rules - - // - // For example a ListArray could correspond to anything from 1 to 3 levels - // in the parquet schema - return None; - } - - // This could be made more efficient (#TBD) - let parquet_idx = (0..parquet_schema.columns().len()) - .find(|x| parquet_schema.get_column_root_idx(*x) == root_idx)?; - Some((parquet_idx, field)) -} - -/// Extracts the min statistics from an iterator of [`ParquetStatistics`] to an [`ArrayRef`] -pub(crate) fn min_statistics<'a, I: Iterator>>( - data_type: &DataType, - iterator: I, -) -> Result { - let scalars = iterator - .map(|x| x.and_then(|s| get_statistic!(s, min, min_bytes, Some(data_type)))); - collect_scalars(data_type, scalars) -} - -/// Extracts the max statistics from an iterator of [`ParquetStatistics`] to an [`ArrayRef`] -pub(crate) fn max_statistics<'a, I: Iterator>>( - data_type: &DataType, - iterator: I, -) -> Result { - let scalars = iterator - .map(|x| x.and_then(|s| get_statistic!(s, max, max_bytes, Some(data_type)))); - collect_scalars(data_type, scalars) -} - -/// Builds an array from an iterator of ScalarValue -fn collect_scalars>>( - data_type: &DataType, - iterator: I, -) -> Result { - let mut scalars = iterator.peekable(); - match scalars.peek().is_none() { - true => Ok(new_empty_array(data_type)), - false => { - let null = ScalarValue::try_from(data_type)?; - ScalarValue::iter_to_array(scalars.map(|x| x.unwrap_or_else(|| null.clone()))) - } - } -} - -#[cfg(test)] -mod test { - use super::*; - use arrow_array::{ - new_null_array, Array, BinaryArray, BooleanArray, Decimal128Array, Float32Array, - Float64Array, Int32Array, Int64Array, RecordBatch, StringArray, StructArray, - TimestampNanosecondArray, - }; - use arrow_schema::{Field, SchemaRef}; - use bytes::Bytes; - use datafusion_common::test_util::parquet_test_data; - use parquet::arrow::arrow_reader::ArrowReaderBuilder; - use parquet::arrow::arrow_writer::ArrowWriter; - use parquet::file::metadata::{ParquetMetaData, RowGroupMetaData}; - use parquet::file::properties::{EnabledStatistics, WriterProperties}; - use std::path::PathBuf; - use std::sync::Arc; - - // TODO error cases (with parquet statistics that are mismatched in expected type) - - #[test] - fn roundtrip_empty() { - let empty_bool_array = new_empty_array(&DataType::Boolean); - Test { - input: empty_bool_array.clone(), - expected_min: empty_bool_array.clone(), - expected_max: empty_bool_array.clone(), - } - .run() - } - - #[test] - fn roundtrip_bool() { - Test { - input: bool_array([ - // row group 1 - Some(true), - None, - Some(true), - // row group 2 - Some(true), - Some(false), - None, - // row group 3 - None, - None, - None, - ]), - expected_min: bool_array([Some(true), Some(false), None]), - expected_max: bool_array([Some(true), Some(true), None]), - } - .run() - } - - #[test] - fn roundtrip_int32() { - Test { - input: i32_array([ - // row group 1 - Some(1), - None, - Some(3), - // row group 2 - Some(0), - Some(5), - None, - // row group 3 - None, - None, - None, - ]), - expected_min: i32_array([Some(1), Some(0), None]), - expected_max: i32_array([Some(3), Some(5), None]), - } - .run() - } - - #[test] - fn roundtrip_int64() { - Test { - input: i64_array([ - // row group 1 - Some(1), - None, - Some(3), - // row group 2 - Some(0), - Some(5), - None, - // row group 3 - None, - None, - None, - ]), - expected_min: i64_array([Some(1), Some(0), None]), - expected_max: i64_array(vec![Some(3), Some(5), None]), - } - .run() - } - - #[test] - fn roundtrip_f32() { - Test { - input: f32_array([ - // row group 1 - Some(1.0), - None, - Some(3.0), - // row group 2 - Some(-1.0), - Some(5.0), - None, - // row group 3 - None, - None, - None, - ]), - expected_min: f32_array([Some(1.0), Some(-1.0), None]), - expected_max: f32_array([Some(3.0), Some(5.0), None]), - } - .run() - } - - #[test] - fn roundtrip_f64() { - Test { - input: f64_array([ - // row group 1 - Some(1.0), - None, - Some(3.0), - // row group 2 - Some(-1.0), - Some(5.0), - None, - // row group 3 - None, - None, - None, - ]), - expected_min: f64_array([Some(1.0), Some(-1.0), None]), - expected_max: f64_array([Some(3.0), Some(5.0), None]), - } - .run() - } - - #[test] - #[should_panic( - expected = "Inconsistent types in ScalarValue::iter_to_array. Expected Int64, got TimestampNanosecond(NULL, None)" - )] - // Due to https://github.com/apache/datafusion/issues/8295 - fn roundtrip_timestamp() { - Test { - input: timestamp_array([ - // row group 1 - Some(1), - None, - Some(3), - // row group 2 - Some(9), - Some(5), - None, - // row group 3 - None, - None, - None, - ]), - expected_min: timestamp_array([Some(1), Some(5), None]), - expected_max: timestamp_array([Some(3), Some(9), None]), - } - .run() - } - - #[test] - fn roundtrip_decimal() { - Test { - input: Arc::new( - Decimal128Array::from(vec![ - // row group 1 - Some(100), - None, - Some(22000), - // row group 2 - Some(500000), - Some(330000), - None, - // row group 3 - None, - None, - None, - ]) - .with_precision_and_scale(9, 2) - .unwrap(), - ), - expected_min: Arc::new( - Decimal128Array::from(vec![Some(100), Some(330000), None]) - .with_precision_and_scale(9, 2) - .unwrap(), - ), - expected_max: Arc::new( - Decimal128Array::from(vec![Some(22000), Some(500000), None]) - .with_precision_and_scale(9, 2) - .unwrap(), - ), - } - .run() - } - - #[test] - fn roundtrip_utf8() { - Test { - input: utf8_array([ - // row group 1 - Some("A"), - None, - Some("Q"), - // row group 2 - Some("ZZ"), - Some("AA"), - None, - // row group 3 - None, - None, - None, - ]), - expected_min: utf8_array([Some("A"), Some("AA"), None]), - expected_max: utf8_array([Some("Q"), Some("ZZ"), None]), - } - .run() - } - - #[test] - fn roundtrip_struct() { - let mut test = Test { - input: struct_array(vec![ - // row group 1 - (Some(true), Some(1)), - (None, None), - (Some(true), Some(3)), - // row group 2 - (Some(true), Some(0)), - (Some(false), Some(5)), - (None, None), - // row group 3 - (None, None), - (None, None), - (None, None), - ]), - expected_min: struct_array(vec![ - (Some(true), Some(1)), - (Some(true), Some(0)), - (None, None), - ]), - - expected_max: struct_array(vec![ - (Some(true), Some(3)), - (Some(true), Some(0)), - (None, None), - ]), - }; - // Due to https://github.com/apache/datafusion/issues/8334, - // statistics for struct arrays are not supported - test.expected_min = - new_null_array(test.input.data_type(), test.expected_min.len()); - test.expected_max = - new_null_array(test.input.data_type(), test.expected_min.len()); - test.run() - } - - #[test] - #[should_panic( - expected = "Inconsistent types in ScalarValue::iter_to_array. Expected Utf8, got Binary(NULL)" - )] - // Due to https://github.com/apache/datafusion/issues/8295 - fn roundtrip_binary() { - Test { - input: Arc::new(BinaryArray::from_opt_vec(vec![ - // row group 1 - Some(b"A"), - None, - Some(b"Q"), - // row group 2 - Some(b"ZZ"), - Some(b"AA"), - None, - // row group 3 - None, - None, - None, - ])), - expected_min: Arc::new(BinaryArray::from_opt_vec(vec![ - Some(b"A"), - Some(b"AA"), - None, - ])), - expected_max: Arc::new(BinaryArray::from_opt_vec(vec![ - Some(b"Q"), - Some(b"ZZ"), - None, - ])), - } - .run() - } - - #[test] - fn struct_and_non_struct() { - // Ensures that statistics for an array that appears *after* a struct - // array are not wrong - let struct_col = struct_array(vec![ - // row group 1 - (Some(true), Some(1)), - (None, None), - (Some(true), Some(3)), - ]); - let int_col = i32_array([Some(100), Some(200), Some(300)]); - let expected_min = i32_array([Some(100)]); - let expected_max = i32_array(vec![Some(300)]); - - // use a name that shadows a name in the struct column - match struct_col.data_type() { - DataType::Struct(fields) => { - assert_eq!(fields.get(1).unwrap().name(), "int_col") - } - _ => panic!("unexpected data type for struct column"), - }; - - let input_batch = RecordBatch::try_from_iter([ - ("struct_col", struct_col), - ("int_col", int_col), - ]) - .unwrap(); - - let schema = input_batch.schema(); - - let metadata = parquet_metadata(schema.clone(), input_batch); - let parquet_schema = metadata.file_metadata().schema_descr(); - - // read the int_col statistics - let (idx, _) = parquet_column(parquet_schema, &schema, "int_col").unwrap(); - assert_eq!(idx, 2); - - let row_groups = metadata.row_groups(); - let iter = row_groups.iter().map(|x| x.column(idx).statistics()); - - let min = min_statistics(&DataType::Int32, iter.clone()).unwrap(); - assert_eq!( - &min, - &expected_min, - "Min. Statistics\n\n{}\n\n", - DisplayStats(row_groups) - ); - - let max = max_statistics(&DataType::Int32, iter).unwrap(); - assert_eq!( - &max, - &expected_max, - "Max. Statistics\n\n{}\n\n", - DisplayStats(row_groups) - ); - } - - #[test] - fn nan_in_stats() { - // /parquet-testing/data/nan_in_stats.parquet - // row_groups: 1 - // "x": Double({min: Some(1.0), max: Some(NaN), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) - - TestFile::new("nan_in_stats.parquet") - .with_column(ExpectedColumn { - name: "x", - expected_min: Arc::new(Float64Array::from(vec![Some(1.0)])), - expected_max: Arc::new(Float64Array::from(vec![Some(f64::NAN)])), - }) - .run(); - } - - #[test] - fn alltypes_plain() { - // /parquet-testing/data/datapage_v1-snappy-compressed-checksum.parquet - // row_groups: 1 - // (has no statistics) - TestFile::new("alltypes_plain.parquet") - // No column statistics should be read as NULL, but with the right type - .with_column(ExpectedColumn { - name: "id", - expected_min: i32_array([None]), - expected_max: i32_array([None]), - }) - .with_column(ExpectedColumn { - name: "bool_col", - expected_min: bool_array([None]), - expected_max: bool_array([None]), - }) - .run(); - } - - #[test] - fn alltypes_tiny_pages() { - // /parquet-testing/data/alltypes_tiny_pages.parquet - // row_groups: 1 - // "id": Int32({min: Some(0), max: Some(7299), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) - // "bool_col": Boolean({min: Some(false), max: Some(true), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) - // "tinyint_col": Int32({min: Some(0), max: Some(9), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) - // "smallint_col": Int32({min: Some(0), max: Some(9), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) - // "int_col": Int32({min: Some(0), max: Some(9), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) - // "bigint_col": Int64({min: Some(0), max: Some(90), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) - // "float_col": Float({min: Some(0.0), max: Some(9.9), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) - // "double_col": Double({min: Some(0.0), max: Some(90.89999999999999), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) - // "date_string_col": ByteArray({min: Some(ByteArray { data: "01/01/09" }), max: Some(ByteArray { data: "12/31/10" }), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) - // "string_col": ByteArray({min: Some(ByteArray { data: "0" }), max: Some(ByteArray { data: "9" }), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) - // "timestamp_col": Int96({min: None, max: None, distinct_count: None, null_count: 0, min_max_deprecated: true, min_max_backwards_compatible: true}) - // "year": Int32({min: Some(2009), max: Some(2010), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) - // "month": Int32({min: Some(1), max: Some(12), distinct_count: None, null_count: 0, min_max_deprecated: false, min_max_backwards_compatible: false}) - TestFile::new("alltypes_tiny_pages.parquet") - .with_column(ExpectedColumn { - name: "id", - expected_min: i32_array([Some(0)]), - expected_max: i32_array([Some(7299)]), - }) - .with_column(ExpectedColumn { - name: "bool_col", - expected_min: bool_array([Some(false)]), - expected_max: bool_array([Some(true)]), - }) - .with_column(ExpectedColumn { - name: "tinyint_col", - expected_min: i32_array([Some(0)]), - expected_max: i32_array([Some(9)]), - }) - .with_column(ExpectedColumn { - name: "smallint_col", - expected_min: i32_array([Some(0)]), - expected_max: i32_array([Some(9)]), - }) - .with_column(ExpectedColumn { - name: "int_col", - expected_min: i32_array([Some(0)]), - expected_max: i32_array([Some(9)]), - }) - .with_column(ExpectedColumn { - name: "bigint_col", - expected_min: i64_array([Some(0)]), - expected_max: i64_array([Some(90)]), - }) - .with_column(ExpectedColumn { - name: "float_col", - expected_min: f32_array([Some(0.0)]), - expected_max: f32_array([Some(9.9)]), - }) - .with_column(ExpectedColumn { - name: "double_col", - expected_min: f64_array([Some(0.0)]), - expected_max: f64_array([Some(90.89999999999999)]), - }) - .with_column(ExpectedColumn { - name: "date_string_col", - expected_min: utf8_array([Some("01/01/09")]), - expected_max: utf8_array([Some("12/31/10")]), - }) - .with_column(ExpectedColumn { - name: "string_col", - expected_min: utf8_array([Some("0")]), - expected_max: utf8_array([Some("9")]), - }) - // File has no min/max for timestamp_col - .with_column(ExpectedColumn { - name: "timestamp_col", - expected_min: timestamp_array([None]), - expected_max: timestamp_array([None]), - }) - .with_column(ExpectedColumn { - name: "year", - expected_min: i32_array([Some(2009)]), - expected_max: i32_array([Some(2010)]), - }) - .with_column(ExpectedColumn { - name: "month", - expected_min: i32_array([Some(1)]), - expected_max: i32_array([Some(12)]), - }) - .run(); - } - - #[test] - fn fixed_length_decimal_legacy() { - // /parquet-testing/data/fixed_length_decimal_legacy.parquet - // row_groups: 1 - // "value": FixedLenByteArray({min: Some(FixedLenByteArray(ByteArray { data: Some(ByteBufferPtr { data: b"\0\0\0\0\0\xc8" }) })), max: Some(FixedLenByteArray(ByteArray { data: "\0\0\0\0\t`" })), distinct_count: None, null_count: 0, min_max_deprecated: true, min_max_backwards_compatible: true}) - - TestFile::new("fixed_length_decimal_legacy.parquet") - .with_column(ExpectedColumn { - name: "value", - expected_min: Arc::new( - Decimal128Array::from(vec![Some(200)]) - .with_precision_and_scale(13, 2) - .unwrap(), - ), - expected_max: Arc::new( - Decimal128Array::from(vec![Some(2400)]) - .with_precision_and_scale(13, 2) - .unwrap(), - ), - }) - .run(); - } - - const ROWS_PER_ROW_GROUP: usize = 3; - - /// Writes the input batch into a parquet file, with every every three rows as - /// their own row group, and compares the min/maxes to the expected values - struct Test { - input: ArrayRef, - expected_min: ArrayRef, - expected_max: ArrayRef, - } - - impl Test { - fn run(self) { - let Self { - input, - expected_min, - expected_max, - } = self; - - let input_batch = RecordBatch::try_from_iter([("c1", input)]).unwrap(); - - let schema = input_batch.schema(); - - let metadata = parquet_metadata(schema.clone(), input_batch); - let parquet_schema = metadata.file_metadata().schema_descr(); - - let row_groups = metadata.row_groups(); - - for field in schema.fields() { - if field.data_type().is_nested() { - let lookup = parquet_column(parquet_schema, &schema, field.name()); - assert_eq!(lookup, None); - continue; - } - - let (idx, f) = - parquet_column(parquet_schema, &schema, field.name()).unwrap(); - assert_eq!(f, field); - - let iter = row_groups.iter().map(|x| x.column(idx).statistics()); - let min = min_statistics(f.data_type(), iter.clone()).unwrap(); - assert_eq!( - &min, - &expected_min, - "Min. Statistics\n\n{}\n\n", - DisplayStats(row_groups) - ); - - let max = max_statistics(f.data_type(), iter).unwrap(); - assert_eq!( - &max, - &expected_max, - "Max. Statistics\n\n{}\n\n", - DisplayStats(row_groups) - ); - } - } - } - - /// Write the specified batches out as parquet and return the metadata - fn parquet_metadata(schema: SchemaRef, batch: RecordBatch) -> Arc { - let props = WriterProperties::builder() - .set_statistics_enabled(EnabledStatistics::Chunk) - .set_max_row_group_size(ROWS_PER_ROW_GROUP) - .build(); - - let mut buffer = Vec::new(); - let mut writer = ArrowWriter::try_new(&mut buffer, schema, Some(props)).unwrap(); - writer.write(&batch).unwrap(); - writer.close().unwrap(); - - let reader = ArrowReaderBuilder::try_new(Bytes::from(buffer)).unwrap(); - reader.metadata().clone() - } - - /// Formats the statistics nicely for display - struct DisplayStats<'a>(&'a [RowGroupMetaData]); - impl<'a> std::fmt::Display for DisplayStats<'a> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let row_groups = self.0; - writeln!(f, " row_groups: {}", row_groups.len())?; - for rg in row_groups { - for col in rg.columns() { - if let Some(statistics) = col.statistics() { - writeln!(f, " {}: {:?}", col.column_path(), statistics)?; - } - } - } - Ok(()) - } - } - - struct ExpectedColumn { - name: &'static str, - expected_min: ArrayRef, - expected_max: ArrayRef, - } - - /// Reads statistics out of the specified, and compares them to the expected values - struct TestFile { - file_name: &'static str, - expected_columns: Vec, - } - - impl TestFile { - fn new(file_name: &'static str) -> Self { - Self { - file_name, - expected_columns: Vec::new(), - } - } - - fn with_column(mut self, column: ExpectedColumn) -> Self { - self.expected_columns.push(column); - self - } - - /// Reads the specified parquet file and validates that the exepcted min/max - /// values for the specified columns are as expected. - fn run(self) { - let path = PathBuf::from(parquet_test_data()).join(self.file_name); - let file = std::fs::File::open(path).unwrap(); - let reader = ArrowReaderBuilder::try_new(file).unwrap(); - let arrow_schema = reader.schema(); - let metadata = reader.metadata(); - let row_groups = metadata.row_groups(); - let parquet_schema = metadata.file_metadata().schema_descr(); - - for expected_column in self.expected_columns { - let ExpectedColumn { - name, - expected_min, - expected_max, - } = expected_column; - - let (idx, field) = - parquet_column(parquet_schema, arrow_schema, name).unwrap(); - - let iter = row_groups.iter().map(|x| x.column(idx).statistics()); - let actual_min = min_statistics(field.data_type(), iter.clone()).unwrap(); - assert_eq!(&expected_min, &actual_min, "column {name}"); - - let actual_max = max_statistics(field.data_type(), iter).unwrap(); - assert_eq!(&expected_max, &actual_max, "column {name}"); - } - } - } - - fn bool_array(input: impl IntoIterator>) -> ArrayRef { - let array: BooleanArray = input.into_iter().collect(); - Arc::new(array) - } - - fn i32_array(input: impl IntoIterator>) -> ArrayRef { - let array: Int32Array = input.into_iter().collect(); - Arc::new(array) - } - - fn i64_array(input: impl IntoIterator>) -> ArrayRef { - let array: Int64Array = input.into_iter().collect(); - Arc::new(array) - } - - fn f32_array(input: impl IntoIterator>) -> ArrayRef { - let array: Float32Array = input.into_iter().collect(); - Arc::new(array) - } - - fn f64_array(input: impl IntoIterator>) -> ArrayRef { - let array: Float64Array = input.into_iter().collect(); - Arc::new(array) - } - - fn timestamp_array(input: impl IntoIterator>) -> ArrayRef { - let array: TimestampNanosecondArray = input.into_iter().collect(); - Arc::new(array) - } - - fn utf8_array<'a>(input: impl IntoIterator>) -> ArrayRef { - let array: StringArray = input - .into_iter() - .map(|s| s.map(|s| s.to_string())) - .collect(); - Arc::new(array) - } - - // returns a struct array with columns "bool_col" and "int_col" with the specified values - fn struct_array(input: Vec<(Option, Option)>) -> ArrayRef { - let boolean: BooleanArray = input.iter().map(|(b, _i)| b).collect(); - let int: Int32Array = input.iter().map(|(_b, i)| i).collect(); - - let nullable = true; - let struct_array = StructArray::from(vec![ - ( - Arc::new(Field::new("bool_col", DataType::Boolean, nullable)), - Arc::new(boolean) as ArrayRef, - ), - ( - Arc::new(Field::new("int_col", DataType::Int32, nullable)), - Arc::new(int) as ArrayRef, - ), - ]); - Arc::new(struct_array) - } -} diff --git a/datafusion/core/src/datasource/physical_plan/parquet/writer.rs b/datafusion/core/src/datasource/physical_plan/parquet/writer.rs new file mode 100644 index 000000000000..0c0c54691068 --- /dev/null +++ b/datafusion/core/src/datasource/physical_plan/parquet/writer.rs @@ -0,0 +1,80 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::datasource::listing::ListingTableUrl; +use datafusion_common::DataFusionError; +use datafusion_execution::TaskContext; +use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; +use futures::StreamExt; +use object_store::buffered::BufWriter; +use object_store::path::Path; +use parquet::arrow::AsyncArrowWriter; +use parquet::file::properties::WriterProperties; +use std::sync::Arc; +use tokio::task::JoinSet; + +/// Executes a query and writes the results to a partitioned Parquet file. +pub async fn plan_to_parquet( + task_ctx: Arc, + plan: Arc, + path: impl AsRef, + writer_properties: Option, +) -> datafusion_common::Result<()> { + let path = path.as_ref(); + let parsed = ListingTableUrl::parse(path)?; + let object_store_url = parsed.object_store(); + let store = task_ctx.runtime_env().object_store(&object_store_url)?; + let mut join_set = JoinSet::new(); + for i in 0..plan.output_partitioning().partition_count() { + let plan: Arc = plan.clone(); + let filename = format!("{}/part-{i}.parquet", parsed.prefix()); + let file = Path::parse(filename)?; + let propclone = writer_properties.clone(); + + let storeref = store.clone(); + let buf_writer = BufWriter::new(storeref, file.clone()); + let mut stream = plan.execute(i, task_ctx.clone())?; + join_set.spawn(async move { + let mut writer = + AsyncArrowWriter::try_new(buf_writer, plan.schema(), propclone)?; + while let Some(next_batch) = stream.next().await { + let batch = next_batch?; + writer.write(&batch).await?; + } + writer + .close() + .await + .map_err(DataFusionError::from) + .map(|_| ()) + }); + } + + while let Some(result) = join_set.join_next().await { + match result { + Ok(res) => res?, + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } + } + + Ok(()) +} diff --git a/datafusion/core/src/datasource/physical_plan/statistics.rs b/datafusion/core/src/datasource/physical_plan/statistics.rs new file mode 100644 index 000000000000..6af153a731b0 --- /dev/null +++ b/datafusion/core/src/datasource/physical_plan/statistics.rs @@ -0,0 +1,287 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/*! + * + * Use statistics to optimize physical planning. + * + * Currently, this module houses code to sort file groups if they are non-overlapping with + * respect to the required sort order. See [`MinMaxStatistics`] + * +*/ + +use std::sync::Arc; + +use crate::datasource::listing::PartitionedFile; + +use arrow::{ + compute::SortColumn, + row::{Row, Rows}, +}; +use arrow_array::RecordBatch; +use arrow_schema::SchemaRef; +use datafusion_common::{DataFusionError, Result}; +use datafusion_physical_expr::{expressions::Column, PhysicalSortExpr}; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexOrderingRef}; + +/// A normalized representation of file min/max statistics that allows for efficient sorting & comparison. +/// The min/max values are ordered by [`Self::sort_order`]. +/// Furthermore, any columns that are reversed in the sort order have their min/max values swapped. +pub(crate) struct MinMaxStatistics { + min_by_sort_order: Rows, + max_by_sort_order: Rows, + sort_order: LexOrdering, +} + +impl MinMaxStatistics { + /// Sort order used to sort the statistics + #[allow(unused)] + pub fn sort_order(&self) -> LexOrderingRef { + &self.sort_order + } + + /// Min value at index + #[allow(unused)] + pub fn min(&self, idx: usize) -> Row { + self.min_by_sort_order.row(idx) + } + + /// Max value at index + pub fn max(&self, idx: usize) -> Row { + self.max_by_sort_order.row(idx) + } + + pub fn new_from_files<'a>( + projected_sort_order: LexOrderingRef, // Sort order with respect to projected schema + projected_schema: &SchemaRef, // Projected schema + projection: Option<&[usize]>, // Indices of projection in full table schema (None = all columns) + files: impl IntoIterator, + ) -> Result { + use datafusion_common::ScalarValue; + + let statistics_and_partition_values = files + .into_iter() + .map(|file| { + file.statistics + .as_ref() + .zip(Some(file.partition_values.as_slice())) + }) + .collect::>>() + .ok_or_else(|| { + DataFusionError::Plan("Parquet file missing statistics".to_string()) + })?; + + // Helper function to get min/max statistics for a given column of projected_schema + let get_min_max = |i: usize| -> Result<(Vec, Vec)> { + Ok(statistics_and_partition_values + .iter() + .map(|(s, pv)| { + if i < s.column_statistics.len() { + s.column_statistics[i] + .min_value + .get_value() + .cloned() + .zip(s.column_statistics[i].max_value.get_value().cloned()) + .ok_or_else(|| { + DataFusionError::Plan("statistics not found".to_string()) + }) + } else { + let partition_value = &pv[i - s.column_statistics.len()]; + Ok((partition_value.clone(), partition_value.clone())) + } + }) + .collect::>>()? + .into_iter() + .unzip()) + }; + + let sort_columns = sort_columns_from_physical_sort_exprs(projected_sort_order) + .ok_or(DataFusionError::Plan( + "sort expression must be on column".to_string(), + ))?; + + // Project the schema & sort order down to just the relevant columns + let min_max_schema = Arc::new( + projected_schema + .project(&(sort_columns.iter().map(|c| c.index()).collect::>()))?, + ); + let min_max_sort_order = sort_columns + .iter() + .zip(projected_sort_order.iter()) + .enumerate() + .map(|(i, (col, sort))| PhysicalSortExpr { + expr: Arc::new(Column::new(col.name(), i)), + options: sort.options, + }) + .collect::>(); + + let (min_values, max_values): (Vec<_>, Vec<_>) = sort_columns + .iter() + .map(|c| { + // Reverse the projection to get the index of the column in the full statistics + // The file statistics contains _every_ column , but the sort column's index() + // refers to the index in projected_schema + let i = projection.map(|p| p[c.index()]).unwrap_or(c.index()); + + let (min, max) = get_min_max(i).map_err(|e| { + e.context(format!("get min/max for column: '{}'", c.name())) + })?; + Ok(( + ScalarValue::iter_to_array(min)?, + ScalarValue::iter_to_array(max)?, + )) + }) + .collect::>>() + .map_err(|e| e.context("collect min/max values"))? + .into_iter() + .unzip(); + + Self::new( + &min_max_sort_order, + &min_max_schema, + RecordBatch::try_new(Arc::clone(&min_max_schema), min_values).map_err( + |e| { + DataFusionError::ArrowError(e, Some("\ncreate min batch".to_string())) + }, + )?, + RecordBatch::try_new(Arc::clone(&min_max_schema), max_values).map_err( + |e| { + DataFusionError::ArrowError(e, Some("\ncreate max batch".to_string())) + }, + )?, + ) + } + + pub fn new( + sort_order: LexOrderingRef, + schema: &SchemaRef, + min_values: RecordBatch, + max_values: RecordBatch, + ) -> Result { + use arrow::row::*; + + let sort_fields = sort_order + .iter() + .map(|expr| { + expr.expr + .data_type(schema) + .map(|data_type| SortField::new_with_options(data_type, expr.options)) + }) + .collect::>>() + .map_err(|e| e.context("create sort fields"))?; + let converter = RowConverter::new(sort_fields)?; + + let sort_columns = sort_columns_from_physical_sort_exprs(sort_order).ok_or( + DataFusionError::Plan("sort expression must be on column".to_string()), + )?; + + // swap min/max if they're reversed in the ordering + let (new_min_cols, new_max_cols): (Vec<_>, Vec<_>) = sort_order + .iter() + .zip(sort_columns.iter().copied()) + .map(|(sort_expr, column)| { + if sort_expr.options.descending { + max_values + .column_by_name(column.name()) + .zip(min_values.column_by_name(column.name())) + } else { + min_values + .column_by_name(column.name()) + .zip(max_values.column_by_name(column.name())) + } + .ok_or_else(|| { + DataFusionError::Plan(format!( + "missing column in MinMaxStatistics::new: '{}'", + column.name() + )) + }) + }) + .collect::>>()? + .into_iter() + .unzip(); + + let [min, max] = [new_min_cols, new_max_cols].map(|cols| { + let values = RecordBatch::try_new( + min_values.schema(), + cols.into_iter().cloned().collect(), + )?; + let sorting_columns = sort_order + .iter() + .zip(sort_columns.iter().copied()) + .map(|(sort_expr, column)| { + let schema = values.schema(); + + let idx = schema.index_of(column.name())?; + let field = schema.field(idx); + + // check that sort columns are non-nullable + if field.is_nullable() { + return Err(DataFusionError::Plan( + "cannot sort by nullable column".to_string(), + )); + } + + Ok(SortColumn { + values: Arc::clone(values.column(idx)), + options: Some(sort_expr.options), + }) + }) + .collect::>>() + .map_err(|e| e.context("create sorting columns"))?; + converter + .convert_columns( + &sorting_columns + .into_iter() + .map(|c| c.values) + .collect::>(), + ) + .map_err(|e| { + DataFusionError::ArrowError(e, Some("convert columns".to_string())) + }) + }); + + Ok(Self { + min_by_sort_order: min.map_err(|e| e.context("build min rows"))?, + max_by_sort_order: max.map_err(|e| e.context("build max rows"))?, + sort_order: LexOrdering::from_ref(sort_order), + }) + } + + /// Return a sorted list of the min statistics together with the original indices + pub fn min_values_sorted(&self) -> Vec<(usize, Row<'_>)> { + let mut sort: Vec<_> = self.min_by_sort_order.iter().enumerate().collect(); + sort.sort_unstable_by(|(_, a), (_, b)| a.cmp(b)); + sort + } + + /// Check if the min/max statistics are in order and non-overlapping + pub fn is_sorted(&self) -> bool { + self.max_by_sort_order + .iter() + .zip(self.min_by_sort_order.iter().skip(1)) + .all(|(max, next_min)| max < next_min) + } +} + +fn sort_columns_from_physical_sort_exprs( + sort_order: LexOrderingRef, +) -> Option> { + sort_order + .iter() + .map(|expr| expr.expr.as_any().downcast_ref::()) + .collect::>>() +} diff --git a/datafusion/core/src/datasource/provider.rs b/datafusion/core/src/datasource/provider.rs index 7c58aded3108..9d4b67632a01 100644 --- a/datafusion/core/src/datasource/provider.rs +++ b/datafusion/core/src/datasource/provider.rs @@ -17,285 +17,17 @@ //! Data source traits -use std::any::Any; use std::sync::Arc; use async_trait::async_trait; -use datafusion_common::{not_impl_err, Constraints, Statistics}; -use datafusion_expr::{CreateExternalTable, LogicalPlan}; +use datafusion_catalog::Session; +use datafusion_expr::CreateExternalTable; pub use datafusion_expr::{TableProviderFilterPushDown, TableType}; -use crate::arrow::datatypes::SchemaRef; +use crate::catalog::{TableProvider, TableProviderFactory}; use crate::datasource::listing_table_factory::ListingTableFactory; use crate::datasource::stream::StreamTableFactory; use crate::error::Result; -use crate::execution::context::SessionState; -use crate::logical_expr::Expr; -use crate::physical_plan::ExecutionPlan; - -/// Source table -#[async_trait] -pub trait TableProvider: Sync + Send { - /// Returns the table provider as [`Any`](std::any::Any) so that it can be - /// downcast to a specific implementation. - fn as_any(&self) -> &dyn Any; - - /// Get a reference to the schema for this table - fn schema(&self) -> SchemaRef; - - /// Get a reference to the constraints of the table. - /// Returns: - /// - `None` for tables that do not support constraints. - /// - `Some(&Constraints)` for tables supporting constraints. - /// Therefore, a `Some(&Constraints::empty())` return value indicates that - /// this table supports constraints, but there are no constraints. - fn constraints(&self) -> Option<&Constraints> { - None - } - - /// Get the type of this table for metadata/catalog purposes. - fn table_type(&self) -> TableType; - - /// Get the create statement used to create this table, if available. - fn get_table_definition(&self) -> Option<&str> { - None - } - - /// Get the [`LogicalPlan`] of this table, if available - fn get_logical_plan(&self) -> Option<&LogicalPlan> { - None - } - - /// Get the default value for a column, if available. - fn get_column_default(&self, _column: &str) -> Option<&Expr> { - None - } - - /// Create an [`ExecutionPlan`] for scanning the table with optionally - /// specified `projection`, `filter` and `limit`, described below. - /// - /// The `ExecutionPlan` is responsible scanning the datasource's - /// partitions in a streaming, parallelized fashion. - /// - /// # Projection - /// - /// If specified, only a subset of columns should be returned, in the order - /// specified. The projection is a set of indexes of the fields in - /// [`Self::schema`]. - /// - /// DataFusion provides the projection to scan only the columns actually - /// used in the query to improve performance, an optimization called - /// "Projection Pushdown". Some datasources, such as Parquet, can use this - /// information to go significantly faster when only a subset of columns is - /// required. - /// - /// # Filters - /// - /// A list of boolean filter [`Expr`]s to evaluate *during* the scan, in the - /// manner specified by [`Self::supports_filters_pushdown`]. Only rows for - /// which *all* of the `Expr`s evaluate to `true` must be returned (aka the - /// expressions are `AND`ed together). - /// - /// To enable filter pushdown you must override - /// [`Self::supports_filters_pushdown`] as the default implementation does - /// not and `filters` will be empty. - /// - /// DataFusion pushes filtering into the scans whenever possible - /// ("Filter Pushdown"), and depending on the format and the - /// implementation of the format, evaluating the predicate during the scan - /// can increase performance significantly. - /// - /// ## Note: Some columns may appear *only* in Filters - /// - /// In certain cases, a query may only use a certain column in a Filter that - /// has been completely pushed down to the scan. In this case, the - /// projection will not contain all the columns found in the filter - /// expressions. - /// - /// For example, given the query `SELECT t.a FROM t WHERE t.b > 5`, - /// - /// ```text - /// ┌────────────────────┐ - /// │ Projection(t.a) │ - /// └────────────────────┘ - /// ▲ - /// │ - /// │ - /// ┌────────────────────┐ Filter ┌────────────────────┐ Projection ┌────────────────────┐ - /// │ Filter(t.b > 5) │────Pushdown──▶ │ Projection(t.a) │ ───Pushdown───▶ │ Projection(t.a) │ - /// └────────────────────┘ └────────────────────┘ └────────────────────┘ - /// ▲ ▲ ▲ - /// │ │ │ - /// │ │ ┌────────────────────┐ - /// ┌────────────────────┐ ┌────────────────────┐ │ Scan │ - /// │ Scan │ │ Scan │ │ filter=(t.b > 5) │ - /// └────────────────────┘ │ filter=(t.b > 5) │ │ projection=(t.a) │ - /// └────────────────────┘ └────────────────────┘ - /// - /// Initial Plan If `TableProviderFilterPushDown` Projection pushdown notes that - /// returns true, filter pushdown the scan only needs t.a - /// pushes the filter into the scan - /// BUT internally evaluating the - /// predicate still requires t.b - /// ``` - /// - /// # Limit - /// - /// If `limit` is specified, must only produce *at least* this many rows, - /// (though it may return more). Like Projection Pushdown and Filter - /// Pushdown, DataFusion pushes `LIMIT`s as far down in the plan as - /// possible, called "Limit Pushdown" as some sources can use this - /// information to improve their performance. Note that if there are any - /// Inexact filters pushed down, the LIMIT cannot be pushed down. This is - /// because inexact filters do not guarantee that every filtered row is - /// removed, so applying the limit could lead to too few rows being available - /// to return as a final result. - async fn scan( - &self, - state: &SessionState, - projection: Option<&Vec>, - filters: &[Expr], - limit: Option, - ) -> Result>; - - /// Specify if DataFusion should provide filter expressions to the - /// TableProvider to apply *during* the scan. - /// - /// Some TableProviders can evaluate filters more efficiently than the - /// `Filter` operator in DataFusion, for example by using an index. - /// - /// # Parameters and Return Value - /// - /// The return `Vec` must have one element for each element of the `filters` - /// argument. The value of each element indicates if the TableProvider can - /// apply the corresponding filter during the scan. The position in the return - /// value corresponds to the expression in the `filters` parameter. - /// - /// If the length of the resulting `Vec` does not match the `filters` input - /// an error will be thrown. - /// - /// Each element in the resulting `Vec` is one of the following: - /// * [`Exact`] or [`Inexact`]: The TableProvider can apply the filter - /// during scan - /// * [`Unsupported`]: The TableProvider cannot apply the filter during scan - /// - /// By default, this function returns [`Unsupported`] for all filters, - /// meaning no filters will be provided to [`Self::scan`]. - /// - /// [`Unsupported`]: TableProviderFilterPushDown::Unsupported - /// [`Exact`]: TableProviderFilterPushDown::Exact - /// [`Inexact`]: TableProviderFilterPushDown::Inexact - /// # Example - /// - /// ```rust - /// # use std::any::Any; - /// # use std::sync::Arc; - /// # use arrow_schema::SchemaRef; - /// # use async_trait::async_trait; - /// # use datafusion::datasource::TableProvider; - /// # use datafusion::error::{Result, DataFusionError}; - /// # use datafusion::execution::context::SessionState; - /// # use datafusion_expr::{Expr, TableProviderFilterPushDown, TableType}; - /// # use datafusion_physical_plan::ExecutionPlan; - /// // Define a struct that implements the TableProvider trait - /// struct TestDataSource {} - /// - /// #[async_trait] - /// impl TableProvider for TestDataSource { - /// # fn as_any(&self) -> &dyn Any { todo!() } - /// # fn schema(&self) -> SchemaRef { todo!() } - /// # fn table_type(&self) -> TableType { todo!() } - /// # async fn scan(&self, s: &SessionState, p: Option<&Vec>, f: &[Expr], l: Option) -> Result> { - /// todo!() - /// # } - /// // Override the supports_filters_pushdown to evaluate which expressions - /// // to accept as pushdown predicates. - /// fn supports_filters_pushdown(&self, filters: &[&Expr]) -> Result> { - /// // Process each filter - /// let support: Vec<_> = filters.iter().map(|expr| { - /// match expr { - /// // This example only supports a between expr with a single column named "c1". - /// Expr::Between(between_expr) => { - /// between_expr.expr - /// .try_into_col() - /// .map(|column| { - /// if column.name == "c1" { - /// TableProviderFilterPushDown::Exact - /// } else { - /// TableProviderFilterPushDown::Unsupported - /// } - /// }) - /// // If there is no column in the expr set the filter to unsupported. - /// .unwrap_or(TableProviderFilterPushDown::Unsupported) - /// } - /// _ => { - /// // For all other cases return Unsupported. - /// TableProviderFilterPushDown::Unsupported - /// } - /// } - /// }).collect(); - /// Ok(support) - /// } - /// } - /// ``` - fn supports_filters_pushdown( - &self, - filters: &[&Expr], - ) -> Result> { - Ok(vec![ - TableProviderFilterPushDown::Unsupported; - filters.len() - ]) - } - - /// Get statistics for this table, if available - fn statistics(&self) -> Option { - None - } - - /// Return an [`ExecutionPlan`] to insert data into this table, if - /// supported. - /// - /// The returned plan should return a single row in a UInt64 - /// column called "count" such as the following - /// - /// ```text - /// +-------+, - /// | count |, - /// +-------+, - /// | 6 |, - /// +-------+, - /// ``` - /// - /// # See Also - /// - /// See [`DataSinkExec`] for the common pattern of inserting a - /// streams of `RecordBatch`es as files to an ObjectStore. - /// - /// [`DataSinkExec`]: crate::physical_plan::insert::DataSinkExec - async fn insert_into( - &self, - _state: &SessionState, - _input: Arc, - _overwrite: bool, - ) -> Result> { - not_impl_err!("Insert into not implemented for this table") - } -} - -/// A factory which creates [`TableProvider`]s at runtime given a URL. -/// -/// For example, this can be used to create a table "on the fly" -/// from a directory of files only when that name is referenced. -#[async_trait] -pub trait TableProviderFactory: Sync + Send { - /// Create a TableProvider with the given url - async fn create( - &self, - state: &SessionState, - cmd: &CreateExternalTable, - ) -> Result>; -} /// The default [`TableProviderFactory`] /// @@ -318,7 +50,7 @@ impl DefaultTableFactory { impl TableProviderFactory for DefaultTableFactory { async fn create( &self, - state: &SessionState, + state: &dyn Session, cmd: &CreateExternalTable, ) -> Result> { let mut unbounded = cmd.unbounded; diff --git a/datafusion/core/src/datasource/schema_adapter.rs b/datafusion/core/src/datasource/schema_adapter.rs new file mode 100644 index 000000000000..5ba597e4b542 --- /dev/null +++ b/datafusion/core/src/datasource/schema_adapter.rs @@ -0,0 +1,647 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`SchemaAdapter`] and [`SchemaAdapterFactory`] to adapt file-level record batches to a table schema. +//! +//! Adapter provides a method of translating the RecordBatches that come out of the +//! physical format into how they should be used by DataFusion. For instance, a schema +//! can be stored external to a parquet file that maps parquet logical types to arrow types. + +use arrow::compute::{can_cast_types, cast}; +use arrow_array::{new_null_array, RecordBatch, RecordBatchOptions}; +use arrow_schema::{Schema, SchemaRef}; +use datafusion_common::plan_err; +use std::fmt::Debug; +use std::sync::Arc; + +/// Factory for creating [`SchemaAdapter`] +/// +/// This interface provides a way to implement custom schema adaptation logic +/// for ParquetExec (for example, to fill missing columns with default value +/// other than null). +/// +/// Most users should use [`DefaultSchemaAdapterFactory`]. See that struct for +/// more details and examples. +pub trait SchemaAdapterFactory: Debug + Send + Sync + 'static { + /// Create a [`SchemaAdapter`] + /// + /// Arguments: + /// + /// * `projected_table_schema`: The schema for the table, projected to + /// include only the fields being output (projected) by the this mapping. + /// + /// * `table_schema`: The entire table schema for the table + fn create( + &self, + projected_table_schema: SchemaRef, + table_schema: SchemaRef, + ) -> Box; +} + +/// Creates [`SchemaMapper`]s to map file-level [`RecordBatch`]es to a table +/// schema, which may have a schema obtained from merging multiple file-level +/// schemas. +/// +/// This is useful for implementing schema evolution in partitioned datasets. +/// +/// See [`DefaultSchemaAdapterFactory`] for more details and examples. +pub trait SchemaAdapter: Send + Sync { + /// Map a column index in the table schema to a column index in a particular + /// file schema + /// + /// This is used while reading a file to push down projections by mapping + /// projected column indexes from the table schema to the file schema + /// + /// Panics if index is not in range for the table schema + fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option; + + /// Creates a mapping for casting columns from the file schema to the table + /// schema. + /// + /// This is used after reading a record batch. The returned [`SchemaMapper`]: + /// + /// 1. Maps columns to the expected columns indexes + /// 2. Handles missing values (e.g. fills nulls or a default value) for + /// columns in the in the table schema not in the file schema + /// 2. Handles different types: if the column in the file schema has a + /// different type than `table_schema`, the mapper will resolve this + /// difference (e.g. by casting to the appropriate type) + /// + /// Returns: + /// * a [`SchemaMapper`] + /// * an ordered list of columns to project from the file + fn map_schema( + &self, + file_schema: &Schema, + ) -> datafusion_common::Result<(Arc, Vec)>; +} + +/// Maps, columns from a specific file schema to the table schema. +/// +/// See [`DefaultSchemaAdapterFactory`] for more details and examples. +pub trait SchemaMapper: Debug + Send + Sync { + /// Adapts a `RecordBatch` to match the `table_schema` + fn map_batch(&self, batch: RecordBatch) -> datafusion_common::Result; + + /// Adapts a [`RecordBatch`] that does not have all the columns from the + /// file schema. + /// + /// This method is used, for example, when applying a filter to a subset of + /// the columns as part of `DataFusionArrowPredicate` when `filter_pushdown` + /// is enabled. + /// + /// This method is slower than `map_batch` as it looks up columns by name. + fn map_partial_batch( + &self, + batch: RecordBatch, + ) -> datafusion_common::Result; +} + +/// Default [`SchemaAdapterFactory`] for mapping schemas. +/// +/// This can be used to adapt file-level record batches to a table schema and +/// implement schema evolution. +/// +/// Given an input file schema and a table schema, this factory returns +/// [`SchemaAdapter`] that return [`SchemaMapper`]s that: +/// +/// 1. Reorder columns +/// 2. Cast columns to the correct type +/// 3. Fill missing columns with nulls +/// +/// # Errors: +/// +/// * If a column in the table schema is non-nullable but is not present in the +/// file schema (i.e. it is missing), the returned mapper tries to fill it with +/// nulls resulting in a schema error. +/// +/// # Illustration of Schema Mapping +/// +/// ```text +/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +/// ┌───────┐ ┌───────┐ │ ┌───────┐ ┌───────┐ ┌───────┐ │ +/// ││ 1.0 │ │ "foo" │ ││ NULL │ │ "foo" │ │ "1.0" │ +/// ├───────┤ ├───────┤ │ Schema mapping ├───────┤ ├───────┤ ├───────┤ │ +/// ││ 2.0 │ │ "bar" │ ││ NULL │ │ "bar" │ │ "2.0" │ +/// └───────┘ └───────┘ │────────────────▶ └───────┘ └───────┘ └───────┘ │ +/// │ │ +/// column "c" column "b"│ column "a" column "b" column "c"│ +/// │ Float64 Utf8 │ Int32 Utf8 Utf8 +/// ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┘ +/// Input Record Batch Output Record Batch +/// +/// Schema { Schema { +/// "c": Float64, "a": Int32, +/// "b": Utf8, "b": Utf8, +/// } "c": Utf8, +/// } +/// ``` +/// +/// # Example of using the `DefaultSchemaAdapterFactory` to map [`RecordBatch`]s +/// +/// Note `SchemaMapping` also supports mapping partial batches, which is used as +/// part of predicate pushdown. +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow::datatypes::{DataType, Field, Schema}; +/// # use datafusion::datasource::schema_adapter::{DefaultSchemaAdapterFactory, SchemaAdapterFactory}; +/// # use datafusion_common::record_batch; +/// // Table has fields "a", "b" and "c" +/// let table_schema = Schema::new(vec![ +/// Field::new("a", DataType::Int32, true), +/// Field::new("b", DataType::Utf8, true), +/// Field::new("c", DataType::Utf8, true), +/// ]); +/// +/// // create an adapter to map the table schema to the file schema +/// let adapter = DefaultSchemaAdapterFactory::from_schema(Arc::new(table_schema)); +/// +/// // The file schema has fields "c" and "b" but "b" is stored as an 'Float64' +/// // instead of 'Utf8' +/// let file_schema = Schema::new(vec![ +/// Field::new("c", DataType::Utf8, true), +/// Field::new("b", DataType::Float64, true), +/// ]); +/// +/// // Get a mapping from the file schema to the table schema +/// let (mapper, _indices) = adapter.map_schema(&file_schema).unwrap(); +/// +/// let file_batch = record_batch!( +/// ("c", Utf8, vec!["foo", "bar"]), +/// ("b", Float64, vec![1.0, 2.0]) +/// ).unwrap(); +/// +/// let mapped_batch = mapper.map_batch(file_batch).unwrap(); +/// +/// // the mapped batch has the correct schema and the "b" column has been cast to Utf8 +/// let expected_batch = record_batch!( +/// ("a", Int32, vec![None, None]), // missing column filled with nulls +/// ("b", Utf8, vec!["1.0", "2.0"]), // b was cast to string and order was changed +/// ("c", Utf8, vec!["foo", "bar"]) +/// ).unwrap(); +/// assert_eq!(mapped_batch, expected_batch); +/// ``` +#[derive(Clone, Debug, Default)] +pub struct DefaultSchemaAdapterFactory; + +impl DefaultSchemaAdapterFactory { + /// Create a new factory for mapping batches from a file schema to a table + /// schema. + /// + /// This is a convenience for [`DefaultSchemaAdapterFactory::create`] with + /// the same schema for both the projected table schema and the table + /// schema. + pub fn from_schema(table_schema: SchemaRef) -> Box { + Self.create(Arc::clone(&table_schema), table_schema) + } +} + +impl SchemaAdapterFactory for DefaultSchemaAdapterFactory { + fn create( + &self, + projected_table_schema: SchemaRef, + table_schema: SchemaRef, + ) -> Box { + Box::new(DefaultSchemaAdapter { + projected_table_schema, + table_schema, + }) + } +} + +/// This SchemaAdapter requires both the table schema and the projected table +/// schema. See [`SchemaMapping`] for more details +#[derive(Clone, Debug)] +pub(crate) struct DefaultSchemaAdapter { + /// The schema for the table, projected to include only the fields being output (projected) by the + /// associated ParquetExec + projected_table_schema: SchemaRef, + /// The entire table schema for the table we're using this to adapt. + /// + /// This is used to evaluate any filters pushed down into the scan + /// which may refer to columns that are not referred to anywhere + /// else in the plan. + table_schema: SchemaRef, +} + +impl SchemaAdapter for DefaultSchemaAdapter { + /// Map a column index in the table schema to a column index in a particular + /// file schema + /// + /// Panics if index is not in range for the table schema + fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option { + let field = self.projected_table_schema.field(index); + Some(file_schema.fields.find(field.name())?.0) + } + + /// Creates a `SchemaMapping` for casting or mapping the columns from the + /// file schema to the table schema. + /// + /// If the provided `file_schema` contains columns of a different type to + /// the expected `table_schema`, the method will attempt to cast the array + /// data from the file schema to the table schema where possible. + /// + /// Returns a [`SchemaMapping`] that can be applied to the output batch + /// along with an ordered list of columns to project from the file + fn map_schema( + &self, + file_schema: &Schema, + ) -> datafusion_common::Result<(Arc, Vec)> { + let mut projection = Vec::with_capacity(file_schema.fields().len()); + let mut field_mappings = vec![None; self.projected_table_schema.fields().len()]; + + for (file_idx, file_field) in file_schema.fields.iter().enumerate() { + if let Some((table_idx, table_field)) = + self.projected_table_schema.fields().find(file_field.name()) + { + match can_cast_types(file_field.data_type(), table_field.data_type()) { + true => { + field_mappings[table_idx] = Some(projection.len()); + projection.push(file_idx); + } + false => { + return plan_err!( + "Cannot cast file schema field {} of type {:?} to table schema field of type {:?}", + file_field.name(), + file_field.data_type(), + table_field.data_type() + ) + } + } + } + } + + Ok(( + Arc::new(SchemaMapping { + projected_table_schema: self.projected_table_schema.clone(), + field_mappings, + table_schema: self.table_schema.clone(), + }), + projection, + )) + } +} + +/// The SchemaMapping struct holds a mapping from the file schema to the table +/// schema and any necessary type conversions. +/// +/// Note, because `map_batch` and `map_partial_batch` functions have different +/// needs, this struct holds two schemas: +/// +/// 1. The projected **table** schema +/// 2. The full table schema +/// +/// [`map_batch`] is used by the ParquetOpener to produce a RecordBatch which +/// has the projected schema, since that's the schema which is supposed to come +/// out of the execution of this query. Thus `map_batch` uses +/// `projected_table_schema` as it can only operate on the projected fields. +/// +/// [`map_partial_batch`] is used to create a RecordBatch with a schema that +/// can be used for Parquet predicate pushdown, meaning that it may contain +/// fields which are not in the projected schema (as the fields that parquet +/// pushdown filters operate can be completely distinct from the fields that are +/// projected (output) out of the ParquetExec). `map_partial_batch` thus uses +/// `table_schema` to create the resulting RecordBatch (as it could be operating +/// on any fields in the schema). +/// +/// [`map_batch`]: Self::map_batch +/// [`map_partial_batch`]: Self::map_partial_batch +#[derive(Debug)] +pub struct SchemaMapping { + /// The schema of the table. This is the expected schema after conversion + /// and it should match the schema of the query result. + projected_table_schema: SchemaRef, + /// Mapping from field index in `projected_table_schema` to index in + /// projected file_schema. + /// + /// They are Options instead of just plain `usize`s because the table could + /// have fields that don't exist in the file. + field_mappings: Vec>, + /// The entire table schema, as opposed to the projected_table_schema (which + /// only contains the columns that we are projecting out of this query). + /// This contains all fields in the table, regardless of if they will be + /// projected out or not. + table_schema: SchemaRef, +} + +impl SchemaMapper for SchemaMapping { + /// Adapts a `RecordBatch` to match the `projected_table_schema` using the stored mapping and + /// conversions. The produced RecordBatch has a schema that contains only the projected + /// columns, so if one needs a RecordBatch with a schema that references columns which are not + /// in the projected, it would be better to use `map_partial_batch` + fn map_batch(&self, batch: RecordBatch) -> datafusion_common::Result { + let batch_rows = batch.num_rows(); + let batch_cols = batch.columns().to_vec(); + + let cols = self + .projected_table_schema + // go through each field in the projected schema + .fields() + .iter() + // and zip it with the index that maps fields from the projected table schema to the + // projected file schema in `batch` + .zip(&self.field_mappings) + // and for each one... + .map(|(field, file_idx)| { + file_idx.map_or_else( + // If this field only exists in the table, and not in the file, then we know + // that it's null, so just return that. + || Ok(new_null_array(field.data_type(), batch_rows)), + // However, if it does exist in both, then try to cast it to the correct output + // type + |batch_idx| cast(&batch_cols[batch_idx], field.data_type()), + ) + }) + .collect::, _>>()?; + + // Necessary to handle empty batches + let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows())); + + let schema = self.projected_table_schema.clone(); + let record_batch = RecordBatch::try_new_with_options(schema, cols, &options)?; + Ok(record_batch) + } + + /// Adapts a [`RecordBatch`]'s schema into one that has all the correct output types and only + /// contains the fields that exist in both the file schema and table schema. + /// + /// Unlike `map_batch` this method also preserves the columns that + /// may not appear in the final output (`projected_table_schema`) but may + /// appear in push down predicates + fn map_partial_batch( + &self, + batch: RecordBatch, + ) -> datafusion_common::Result { + let batch_cols = batch.columns().to_vec(); + let schema = batch.schema(); + + // for each field in the batch's schema (which is based on a file, not a table)... + let (cols, fields) = schema + .fields() + .iter() + .zip(batch_cols.iter()) + .flat_map(|(field, batch_col)| { + self.table_schema + // try to get the same field from the table schema that we have stored in self + .field_with_name(field.name()) + // and if we don't have it, that's fine, ignore it. This may occur when we've + // created an external table whose fields are a subset of the fields in this + // file, then tried to read data from the file into this table. If that is the + // case here, it's fine to ignore because we don't care about this field + // anyways + .ok() + // but if we do have it, + .map(|table_field| { + // try to cast it into the correct output type. we don't want to ignore this + // error, though, so it's propagated. + cast(batch_col, table_field.data_type()) + // and if that works, return the field and column. + .map(|new_col| (new_col, table_field.clone())) + }) + }) + .collect::, _>>()? + .into_iter() + .unzip::<_, _, Vec<_>, Vec<_>>(); + + // Necessary to handle empty batches + let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows())); + + let schema = + Arc::new(Schema::new_with_metadata(fields, schema.metadata().clone())); + let record_batch = RecordBatch::try_new_with_options(schema, cols, &options)?; + Ok(record_batch) + } +} + +#[cfg(test)] +mod tests { + use std::fs; + use std::sync::Arc; + + use crate::assert_batches_sorted_eq; + use arrow::datatypes::{Field, Schema}; + use arrow::record_batch::RecordBatch; + use arrow_array::{Int32Array, StringArray}; + use arrow_schema::{DataType, SchemaRef}; + use object_store::path::Path; + use object_store::ObjectMeta; + + use crate::datasource::object_store::ObjectStoreUrl; + use crate::datasource::physical_plan::{FileScanConfig, ParquetExec}; + use crate::physical_plan::collect; + use crate::prelude::SessionContext; + + use crate::datasource::listing::PartitionedFile; + use crate::datasource::schema_adapter::{ + DefaultSchemaAdapterFactory, SchemaAdapter, SchemaAdapterFactory, SchemaMapper, + }; + use datafusion_common::record_batch; + #[cfg(feature = "parquet")] + use parquet::arrow::ArrowWriter; + use tempfile::TempDir; + + #[tokio::test] + async fn can_override_schema_adapter() { + // Test shows that SchemaAdapter can add a column that doesn't existing in the + // record batches returned from parquet. This can be useful for schema evolution + // where older files may not have all columns. + let tmp_dir = TempDir::new().unwrap(); + let table_dir = tmp_dir.path().join("parquet_test"); + fs::DirBuilder::new().create(table_dir.as_path()).unwrap(); + let f1 = Field::new("id", DataType::Int32, true); + + let file_schema = Arc::new(Schema::new(vec![f1.clone()])); + let filename = "part.parquet".to_string(); + let path = table_dir.as_path().join(filename.clone()); + let file = fs::File::create(path.clone()).unwrap(); + let mut writer = ArrowWriter::try_new(file, file_schema.clone(), None).unwrap(); + + let ids = Arc::new(Int32Array::from(vec![1i32])); + let rec_batch = RecordBatch::try_new(file_schema.clone(), vec![ids]).unwrap(); + + writer.write(&rec_batch).unwrap(); + writer.close().unwrap(); + + let location = Path::parse(path.to_str().unwrap()).unwrap(); + let metadata = fs::metadata(path.as_path()).expect("Local file metadata"); + let meta = ObjectMeta { + location, + last_modified: metadata.modified().map(chrono::DateTime::from).unwrap(), + size: metadata.len() as usize, + e_tag: None, + version: None, + }; + + let partitioned_file = PartitionedFile { + object_meta: meta, + partition_values: vec![], + range: None, + statistics: None, + extensions: None, + }; + + let f1 = Field::new("id", DataType::Int32, true); + let f2 = Field::new("extra_column", DataType::Utf8, true); + + let schema = Arc::new(Schema::new(vec![f1.clone(), f2.clone()])); + + // prepare the scan + let parquet_exec = ParquetExec::builder( + FileScanConfig::new(ObjectStoreUrl::local_filesystem(), schema) + .with_file(partitioned_file), + ) + .build() + .with_schema_adapter_factory(Arc::new(TestSchemaAdapterFactory {})); + + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); + let read = collect(Arc::new(parquet_exec), task_ctx).await.unwrap(); + + let expected = [ + "+----+--------------+", + "| id | extra_column |", + "+----+--------------+", + "| 1 | foo |", + "+----+--------------+", + ]; + + assert_batches_sorted_eq!(expected, &read); + } + + #[test] + fn default_schema_adapter() { + let table_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + ]); + + // file has a subset of the table schema fields and different type + let file_schema = Schema::new(vec![ + Field::new("c", DataType::Float64, true), // not in table schema + Field::new("b", DataType::Float64, true), + ]); + + let adapter = DefaultSchemaAdapterFactory::from_schema(Arc::new(table_schema)); + let (mapper, indices) = adapter.map_schema(&file_schema).unwrap(); + assert_eq!(indices, vec![1]); + + let file_batch = record_batch!(("b", Float64, vec![1.0, 2.0])).unwrap(); + + let mapped_batch = mapper.map_batch(file_batch).unwrap(); + + // the mapped batch has the correct schema and the "b" column has been cast to Utf8 + let expected_batch = record_batch!( + ("a", Int32, vec![None, None]), // missing column filled with nulls + ("b", Utf8, vec!["1.0", "2.0"]) // b was cast to string and order was changed + ) + .unwrap(); + assert_eq!(mapped_batch, expected_batch); + } + + #[test] + fn default_schema_adapter_non_nullable_columns() { + let table_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), // "a"" is declared non nullable + Field::new("b", DataType::Utf8, true), + ]); + let file_schema = Schema::new(vec![ + // since file doesn't have "a" it will be filled with nulls + Field::new("b", DataType::Float64, true), + ]); + + let adapter = DefaultSchemaAdapterFactory::from_schema(Arc::new(table_schema)); + let (mapper, indices) = adapter.map_schema(&file_schema).unwrap(); + assert_eq!(indices, vec![0]); + + let file_batch = record_batch!(("b", Float64, vec![1.0, 2.0])).unwrap(); + + // Mapping fails because it tries to fill in a non-nullable column with nulls + let err = mapper.map_batch(file_batch).unwrap_err().to_string(); + assert!(err.contains("Invalid argument error: Column 'a' is declared as non-nullable but contains null values"), "{err}"); + } + + #[derive(Debug)] + struct TestSchemaAdapterFactory; + + impl SchemaAdapterFactory for TestSchemaAdapterFactory { + fn create( + &self, + projected_table_schema: SchemaRef, + _table_schema: SchemaRef, + ) -> Box { + Box::new(TestSchemaAdapter { + table_schema: projected_table_schema, + }) + } + } + + struct TestSchemaAdapter { + /// Schema for the table + table_schema: SchemaRef, + } + + impl SchemaAdapter for TestSchemaAdapter { + fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option { + let field = self.table_schema.field(index); + Some(file_schema.fields.find(field.name())?.0) + } + + fn map_schema( + &self, + file_schema: &Schema, + ) -> datafusion_common::Result<(Arc, Vec)> { + let mut projection = Vec::with_capacity(file_schema.fields().len()); + + for (file_idx, file_field) in file_schema.fields.iter().enumerate() { + if self.table_schema.fields().find(file_field.name()).is_some() { + projection.push(file_idx); + } + } + + Ok((Arc::new(TestSchemaMapping {}), projection)) + } + } + + #[derive(Debug)] + struct TestSchemaMapping {} + + impl SchemaMapper for TestSchemaMapping { + fn map_batch( + &self, + batch: RecordBatch, + ) -> datafusion_common::Result { + let f1 = Field::new("id", DataType::Int32, true); + let f2 = Field::new("extra_column", DataType::Utf8, true); + + let schema = Arc::new(Schema::new(vec![f1, f2])); + + let extra_column = Arc::new(StringArray::from(vec!["foo"])); + let mut new_columns = batch.columns().to_vec(); + new_columns.push(extra_column); + + Ok(RecordBatch::try_new(schema, new_columns).unwrap()) + } + + fn map_partial_batch( + &self, + batch: RecordBatch, + ) -> datafusion_common::Result { + self.map_batch(batch) + } + } +} diff --git a/datafusion/core/src/datasource/statistics.rs b/datafusion/core/src/datasource/statistics.rs index c67227f966a2..201bbfd5c007 100644 --- a/datafusion/core/src/datasource/statistics.rs +++ b/datafusion/core/src/datasource/statistics.rs @@ -15,18 +15,26 @@ // specific language governing permissions and limitations // under the License. -use super::listing::PartitionedFile; -use crate::arrow::datatypes::{Schema, SchemaRef}; -use crate::error::Result; -use crate::physical_plan::expressions::{MaxAccumulator, MinAccumulator}; -use crate::physical_plan::{Accumulator, ColumnStatistics, Statistics}; +use std::mem; +use std::sync::Arc; + +use futures::{Stream, StreamExt}; use datafusion_common::stats::Precision; use datafusion_common::ScalarValue; -use futures::{Stream, StreamExt}; -use itertools::izip; -use itertools::multiunzip; +use crate::arrow::datatypes::SchemaRef; +use crate::error::Result; +use crate::physical_plan::{ColumnStatistics, Statistics}; + +#[cfg(feature = "parquet")] +use crate::{ + arrow::datatypes::Schema, + functions_aggregate::min_max::{MaxAccumulator, MinAccumulator}, + physical_plan::Accumulator, +}; + +use super::listing::PartitionedFile; /// Get all files as well as the file level summary statistics (no statistic for partition columns). /// If the optional `limit` is provided, includes only sufficient files. Needed to read up to @@ -34,7 +42,7 @@ use itertools::multiunzip; /// `ListingTable`. If it is false we only construct bare statistics and skip a potentially expensive /// call to `multiunzip` for constructing file level summary statistics. pub async fn get_statistics_with_limit( - all_files: impl Stream>, + all_files: impl Stream)>>, file_schema: SchemaRef, limit: Option, collect_stats: bool, @@ -47,9 +55,7 @@ pub async fn get_statistics_with_limit( // - zero for summations, and // - neutral element for extreme points. let size = file_schema.fields().len(); - let mut null_counts: Vec> = vec![Precision::Absent; size]; - let mut max_values: Vec> = vec![Precision::Absent; size]; - let mut min_values: Vec> = vec![Precision::Absent; size]; + let mut col_stats_set = vec![ColumnStatistics::default(); size]; let mut num_rows = Precision::::Absent; let mut total_byte_size = Precision::::Absent; @@ -57,16 +63,19 @@ pub async fn get_statistics_with_limit( let mut all_files = Box::pin(all_files.fuse()); if let Some(first_file) = all_files.next().await { - let (file, file_stats) = first_file?; + let (mut file, file_stats) = first_file?; + file.statistics = Some(file_stats.as_ref().clone()); result_files.push(file); // First file, we set them directly from the file statistics. num_rows = file_stats.num_rows; total_byte_size = file_stats.total_byte_size; - for (index, file_column) in file_stats.column_statistics.into_iter().enumerate() { - null_counts[index] = file_column.null_count; - max_values[index] = file_column.max_value; - min_values[index] = file_column.min_value; + for (index, file_column) in + file_stats.column_statistics.clone().into_iter().enumerate() + { + col_stats_set[index].null_count = file_column.null_count; + col_stats_set[index].max_value = file_column.max_value; + col_stats_set[index].min_value = file_column.min_value; } // If the number of rows exceeds the limit, we can stop processing @@ -79,7 +88,8 @@ pub async fn get_statistics_with_limit( }; if conservative_num_rows <= limit.unwrap_or(usize::MAX) { while let Some(current) = all_files.next().await { - let (file, file_stats) = current?; + let (mut file, file_stats) = current?; + file.statistics = Some(file_stats.as_ref().clone()); result_files.push(file); if !collect_stats { continue; @@ -94,33 +104,22 @@ pub async fn get_statistics_with_limit( total_byte_size = add_row_stats(file_stats.total_byte_size, total_byte_size); - (null_counts, max_values, min_values) = multiunzip( - izip!( - file_stats.column_statistics.into_iter(), - null_counts.into_iter(), - max_values.into_iter(), - min_values.into_iter() - ) - .map( - |( - ColumnStatistics { - null_count: file_nc, - max_value: file_max, - min_value: file_min, - distinct_count: _, - }, - null_count, - max_value, - min_value, - )| { - ( - add_row_stats(file_nc, null_count), - set_max_if_greater(file_max, max_value), - set_min_if_lesser(file_min, min_value), - ) - }, - ), - ); + for (file_col_stats, col_stats) in file_stats + .column_statistics + .iter() + .zip(col_stats_set.iter_mut()) + { + let ColumnStatistics { + null_count: file_nc, + max_value: file_max, + min_value: file_min, + distinct_count: _, + } = file_col_stats; + + col_stats.null_count = add_row_stats(*file_nc, col_stats.null_count); + set_max_if_greater(file_max, &mut col_stats.max_value); + set_min_if_lesser(file_min, &mut col_stats.min_value) + } // If the number of rows exceeds the limit, we can stop processing // files. This only applies when we know the number of rows. It also @@ -138,30 +137,36 @@ pub async fn get_statistics_with_limit( let mut statistics = Statistics { num_rows, total_byte_size, - column_statistics: get_col_stats_vec(null_counts, max_values, min_values), + column_statistics: col_stats_set, }; if all_files.next().await.is_some() { // If we still have files in the stream, it means that the limit kicked // in, and the statistic could have been different had we processed the // files in a different order. - statistics = statistics.into_inexact() + statistics = statistics.to_inexact() } Ok((result_files, statistics)) } +// only adding this cfg b/c this is the only feature it's used with currently +#[cfg(feature = "parquet")] pub(crate) fn create_max_min_accs( schema: &Schema, ) -> (Vec>, Vec>) { let max_values: Vec> = schema .fields() .iter() - .map(|field| MaxAccumulator::try_new(field.data_type()).ok()) + .map(|field| { + MaxAccumulator::try_new(min_max_aggregate_data_type(field.data_type())).ok() + }) .collect(); let min_values: Vec> = schema .fields() .iter() - .map(|field| MinAccumulator::try_new(field.data_type()).ok()) + .map(|field| { + MinAccumulator::try_new(min_max_aggregate_data_type(field.data_type())).ok() + }) .collect(); (max_values, min_values) } @@ -177,21 +182,8 @@ fn add_row_stats( } } -pub(crate) fn get_col_stats_vec( - null_counts: Vec>, - max_values: Vec>, - min_values: Vec>, -) -> Vec { - izip!(null_counts, max_values, min_values) - .map(|(null_count, max_value, min_value)| ColumnStatistics { - null_count, - max_value, - min_value, - distinct_count: Precision::Absent, - }) - .collect() -} - +// only adding this cfg b/c this is the only feature it's used with currently +#[cfg(feature = "parquet")] pub(crate) fn get_col_stats( schema: &Schema, null_counts: Vec>, @@ -209,7 +201,7 @@ pub(crate) fn get_col_stats( None => None, }; ColumnStatistics { - null_count: null_counts[i].clone(), + null_count: null_counts[i], max_value: max_value.map(Precision::Exact).unwrap_or(Precision::Absent), min_value: min_value.map(Precision::Exact).unwrap_or(Precision::Absent), distinct_count: Precision::Absent, @@ -218,48 +210,81 @@ pub(crate) fn get_col_stats( .collect() } +// Min/max aggregation can take Dictionary encode input but always produces unpacked +// (aka non Dictionary) output. We need to adjust the output data type to reflect this. +// The reason min/max aggregate produces unpacked output because there is only one +// min/max value per group; there is no needs to keep them Dictionary encode +// +// only adding this cfg b/c this is the only feature it's used with currently +#[cfg(feature = "parquet")] +fn min_max_aggregate_data_type( + input_type: &arrow_schema::DataType, +) -> &arrow_schema::DataType { + if let arrow_schema::DataType::Dictionary(_, value_type) = input_type { + value_type.as_ref() + } else { + input_type + } +} + /// If the given value is numerically greater than the original maximum value, /// return the new maximum value with appropriate exactness information. fn set_max_if_greater( - max_nominee: Precision, - max_values: Precision, -) -> Precision { - match (&max_values, &max_nominee) { - (Precision::Exact(val1), Precision::Exact(val2)) if val1 < val2 => max_nominee, + max_nominee: &Precision, + max_value: &mut Precision, +) { + match (&max_value, max_nominee) { + (Precision::Exact(val1), Precision::Exact(val2)) if val1 < val2 => { + *max_value = max_nominee.clone(); + } (Precision::Exact(val1), Precision::Inexact(val2)) | (Precision::Inexact(val1), Precision::Inexact(val2)) | (Precision::Inexact(val1), Precision::Exact(val2)) if val1 < val2 => { - max_nominee.to_inexact() + *max_value = max_nominee.clone().to_inexact(); + } + (Precision::Exact(_), Precision::Absent) => { + let exact_max = mem::take(max_value); + *max_value = exact_max.to_inexact(); + } + (Precision::Absent, Precision::Exact(_)) => { + *max_value = max_nominee.clone().to_inexact(); + } + (Precision::Absent, Precision::Inexact(_)) => { + *max_value = max_nominee.clone(); } - (Precision::Exact(_), Precision::Absent) => max_values.to_inexact(), - (Precision::Absent, Precision::Exact(_)) => max_nominee.to_inexact(), - (Precision::Absent, Precision::Inexact(_)) => max_nominee, - (Precision::Absent, Precision::Absent) => Precision::Absent, - _ => max_values, + _ => {} } } /// If the given value is numerically lesser than the original minimum value, /// return the new minimum value with appropriate exactness information. fn set_min_if_lesser( - min_nominee: Precision, - min_values: Precision, -) -> Precision { - match (&min_values, &min_nominee) { - (Precision::Exact(val1), Precision::Exact(val2)) if val1 > val2 => min_nominee, + min_nominee: &Precision, + min_value: &mut Precision, +) { + match (&min_value, min_nominee) { + (Precision::Exact(val1), Precision::Exact(val2)) if val1 > val2 => { + *min_value = min_nominee.clone(); + } (Precision::Exact(val1), Precision::Inexact(val2)) | (Precision::Inexact(val1), Precision::Inexact(val2)) | (Precision::Inexact(val1), Precision::Exact(val2)) if val1 > val2 => { - min_nominee.to_inexact() + *min_value = min_nominee.clone().to_inexact(); + } + (Precision::Exact(_), Precision::Absent) => { + let exact_min = mem::take(min_value); + *min_value = exact_min.to_inexact(); + } + (Precision::Absent, Precision::Exact(_)) => { + *min_value = min_nominee.clone().to_inexact(); + } + (Precision::Absent, Precision::Inexact(_)) => { + *min_value = min_nominee.clone(); } - (Precision::Exact(_), Precision::Absent) => min_values.to_inexact(), - (Precision::Absent, Precision::Exact(_)) => min_nominee.to_inexact(), - (Precision::Absent, Precision::Inexact(_)) => min_nominee, - (Precision::Absent, Precision::Absent) => Precision::Absent, - _ => min_values, + _ => {} } } diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs index 2059a5ffcfe4..34023fbbb620 100644 --- a/datafusion/core/src/datasource/stream.rs +++ b/datafusion/core/src/datasource/stream.rs @@ -25,24 +25,25 @@ use std::path::PathBuf; use std::str::FromStr; use std::sync::Arc; +use crate::catalog::{TableProvider, TableProviderFactory}; +use crate::datasource::create_ordering; + use arrow_array::{RecordBatch, RecordBatchReader, RecordBatchWriter}; use arrow_schema::SchemaRef; -use async_trait::async_trait; -use futures::StreamExt; - -use datafusion_common::{plan_err, Constraints, DataFusionError, Result}; +use datafusion_common::{config_err, plan_err, Constraints, DataFusionError, Result}; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; -use datafusion_expr::{CreateExternalTable, Expr, TableType}; +use datafusion_expr::dml::InsertOp; +use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType}; use datafusion_physical_plan::insert::{DataSink, DataSinkExec}; use datafusion_physical_plan::metrics::MetricsSet; use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder; use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; -use crate::datasource::provider::TableProviderFactory; -use crate::datasource::{create_ordering, TableProvider}; -use crate::execution::context::SessionState; +use async_trait::async_trait; +use datafusion_catalog::Session; +use futures::StreamExt; /// A [`TableProviderFactory`] for [`StreamTable`] #[derive(Debug, Default)] @@ -52,18 +53,32 @@ pub struct StreamTableFactory {} impl TableProviderFactory for StreamTableFactory { async fn create( &self, - state: &SessionState, + state: &dyn Session, cmd: &CreateExternalTable, ) -> Result> { let schema: SchemaRef = Arc::new(cmd.schema.as_ref().into()); let location = cmd.location.clone(); let encoding = cmd.file_type.parse()?; + let header = if let Ok(opt) = cmd + .options + .get("format.has_header") + .map(|has_header| bool::from_str(has_header)) + .transpose() + { + opt.unwrap_or(false) + } else { + return config_err!( + "Valid values for format.has_header option are 'true' or 'false'" + ); + }; - let config = StreamConfig::new_file(schema, location.into()) + let source = FileStreamProvider::new_file(schema, location.into()) .with_encoding(encoding) - .with_order(cmd.order_exprs.clone()) - .with_header(cmd.has_header) .with_batch_size(state.config().batch_size()) + .with_header(header); + + let config = StreamConfig::new(Arc::new(source)) + .with_order(cmd.order_exprs.clone()) .with_constraints(cmd.constraints.clone()); Ok(Arc::new(StreamTable(Arc::new(config)))) @@ -91,19 +106,44 @@ impl FromStr for StreamEncoding { } } -/// The configuration for a [`StreamTable`] +/// The StreamProvider trait is used as a generic interface for reading and writing from streaming +/// data sources (such as FIFO, Websocket, Kafka, etc.). Implementations of the provider are +/// responsible for providing a `RecordBatchReader` and optionally a `RecordBatchWriter`. +pub trait StreamProvider: std::fmt::Debug + Send + Sync { + /// Get a reference to the schema for this stream + fn schema(&self) -> &SchemaRef; + /// Provide `RecordBatchReader` + fn reader(&self) -> Result>; + /// Provide `RecordBatchWriter` + fn writer(&self) -> Result> { + unimplemented!() + } + /// Display implementation when using as a DataSink + fn stream_write_display( + &self, + t: DisplayFormatType, + f: &mut Formatter, + ) -> std::fmt::Result; +} + +/// Stream data from the file at `location` +/// +/// * Data will be read sequentially from the provided `location` +/// * New data will be appended to the end of the file +/// +/// The encoding can be configured with [`Self::with_encoding`] and +/// defaults to [`StreamEncoding::Csv`] #[derive(Debug)] -pub struct StreamConfig { - schema: SchemaRef, +pub struct FileStreamProvider { location: PathBuf, - batch_size: usize, encoding: StreamEncoding, + /// Get a reference to the schema for this file stream + pub schema: SchemaRef, header: bool, - order: Vec>, - constraints: Constraints, + batch_size: usize, } -impl StreamConfig { +impl FileStreamProvider { /// Stream data from the file at `location` /// /// * Data will be read sequentially from the provided `location` @@ -117,19 +157,11 @@ impl StreamConfig { location, batch_size: 1024, encoding: StreamEncoding::Csv, - order: vec![], header: false, - constraints: Constraints::empty(), } } - /// Specify a sort order for the stream - pub fn with_order(mut self, order: Vec>) -> Self { - self.order = order; - self - } - - /// Specify the batch size + /// Set the batch size (the number of rows to load at one time) pub fn with_batch_size(mut self, batch_size: usize) -> Self { self.batch_size = batch_size; self @@ -146,11 +178,11 @@ impl StreamConfig { self.encoding = encoding; self } +} - /// Assign constraints - pub fn with_constraints(mut self, constraints: Constraints) -> Self { - self.constraints = constraints; - self +impl StreamProvider for FileStreamProvider { + fn schema(&self) -> &SchemaRef { + &self.schema } fn reader(&self) -> Result> { @@ -198,6 +230,58 @@ impl StreamConfig { } } } + + fn stream_write_display( + &self, + _t: DisplayFormatType, + f: &mut Formatter, + ) -> std::fmt::Result { + f.debug_struct("StreamWrite") + .field("location", &self.location) + .field("batch_size", &self.batch_size) + .field("encoding", &self.encoding) + .field("header", &self.header) + .finish_non_exhaustive() + } +} + +/// The configuration for a [`StreamTable`] +#[derive(Debug)] +pub struct StreamConfig { + source: Arc, + order: Vec>, + constraints: Constraints, +} + +impl StreamConfig { + /// Create a new `StreamConfig` from a `StreamProvider` + pub fn new(source: Arc) -> Self { + Self { + source, + order: vec![], + constraints: Constraints::empty(), + } + } + + /// Specify a sort order for the stream + pub fn with_order(mut self, order: Vec>) -> Self { + self.order = order; + self + } + + /// Assign constraints + pub fn with_constraints(mut self, constraints: Constraints) -> Self { + self.constraints = constraints; + self + } + + fn reader(&self) -> Result> { + self.source.reader() + } + + fn writer(&self) -> Result> { + self.source.writer() + } } /// A [`TableProvider`] for an unbounded stream source @@ -210,6 +294,7 @@ impl StreamConfig { /// /// [Hadoop]: https://hadoop.apache.org/ /// [`ListingTable`]: crate::datasource::listing::ListingTable +#[derive(Debug)] pub struct StreamTable(Arc); impl StreamTable { @@ -226,7 +311,7 @@ impl TableProvider for StreamTable { } fn schema(&self) -> SchemaRef { - self.0.schema.clone() + self.0.source.schema().clone() } fn constraints(&self) -> Option<&Constraints> { @@ -239,37 +324,38 @@ impl TableProvider for StreamTable { async fn scan( &self, - _state: &SessionState, + _state: &dyn Session, projection: Option<&Vec>, _filters: &[Expr], - _limit: Option, + limit: Option, ) -> Result> { let projected_schema = match projection { Some(p) => { - let projected = self.0.schema.project(p)?; + let projected = self.0.source.schema().project(p)?; create_ordering(&projected, &self.0.order)? } - None => create_ordering(self.0.schema.as_ref(), &self.0.order)?, + None => create_ordering(self.0.source.schema(), &self.0.order)?, }; Ok(Arc::new(StreamingTableExec::try_new( - self.0.schema.clone(), + self.0.source.schema().clone(), vec![Arc::new(StreamRead(self.0.clone())) as _], projection, projected_schema, true, + limit, )?)) } async fn insert_into( &self, - _state: &SessionState, + _state: &dyn Session, input: Arc, - _overwrite: bool, + _insert_op: InsertOp, ) -> Result> { let ordering = match self.0.order.first() { Some(x) => { - let schema = self.0.schema.as_ref(); + let schema = self.0.source.schema(); let orders = create_ordering(schema, std::slice::from_ref(x))?; let ordering = orders.into_iter().next().unwrap(); Some(ordering.into_iter().map(Into::into).collect()) @@ -280,22 +366,23 @@ impl TableProvider for StreamTable { Ok(Arc::new(DataSinkExec::new( input, Arc::new(StreamWrite(self.0.clone())), - self.0.schema.clone(), + self.0.source.schema().clone(), ordering, ))) } } +#[derive(Debug)] struct StreamRead(Arc); impl PartitionStream for StreamRead { fn schema(&self) -> &SchemaRef { - &self.0.schema + self.0.source.schema() } fn execute(&self, _ctx: Arc) -> SendableRecordBatchStream { let config = self.0.clone(); - let schema = self.0.schema.clone(); + let schema = self.0.source.schema().clone(); let mut builder = RecordBatchReceiverStreamBuilder::new(schema, 2); let tx = builder.tx(); builder.spawn_blocking(move || { @@ -316,12 +403,7 @@ struct StreamWrite(Arc); impl DisplayAs for StreamWrite { fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { - f.debug_struct("StreamWrite") - .field("location", &self.0.location) - .field("batch_size", &self.0.batch_size) - .field("encoding", &self.0.encoding) - .field("header", &self.0.header) - .finish_non_exhaustive() + self.0.source.stream_write_display(_t, f) } } @@ -359,6 +441,9 @@ impl DataSink for StreamWrite { } } drop(sender); - write_task.join_unwind().await + write_task + .join_unwind() + .await + .map_err(DataFusionError::ExecutionJoin)? } } diff --git a/datafusion/core/src/datasource/streaming.rs b/datafusion/core/src/datasource/streaming.rs index f85db2280d8e..0a14cfefcdf2 100644 --- a/datafusion/core/src/datasource/streaming.rs +++ b/datafusion/core/src/datasource/streaming.rs @@ -23,16 +23,16 @@ use std::sync::Arc; use arrow::datatypes::SchemaRef; use async_trait::async_trait; -use datafusion_common::{plan_err, Result}; -use datafusion_expr::{Expr, TableType}; -use log::debug; - use crate::datasource::TableProvider; -use crate::execution::context::SessionState; use crate::physical_plan::streaming::{PartitionStream, StreamingTableExec}; use crate::physical_plan::ExecutionPlan; +use datafusion_catalog::Session; +use datafusion_common::{plan_err, Result}; +use datafusion_expr::{Expr, TableType}; +use log::debug; /// A [`TableProvider`] that streams a set of [`PartitionStream`] +#[derive(Debug)] pub struct StreamingTable { schema: SchemaRef, partitions: Vec>, @@ -50,7 +50,7 @@ impl StreamingTable { if !schema.contains(partition_schema) { debug!( "target schema does not contain partition schema. \ - Target_schema: {schema:?}. Partiton Schema: {partition_schema:?}" + Target_schema: {schema:?}. Partition Schema: {partition_schema:?}" ); return plan_err!("Mismatch between schema and batches"); } @@ -85,18 +85,18 @@ impl TableProvider for StreamingTable { async fn scan( &self, - _state: &SessionState, + _state: &dyn Session, projection: Option<&Vec>, _filters: &[Expr], - _limit: Option, + limit: Option, ) -> Result> { - // TODO: push limit down Ok(Arc::new(StreamingTableExec::try_new( self.schema.clone(), self.partitions.clone(), projection, None, self.infinite, + limit, )?)) } } diff --git a/datafusion/core/src/datasource/view.rs b/datafusion/core/src/datasource/view.rs index 3f024a6b4cb7..1ffe54e4b06c 100644 --- a/datafusion/core/src/datasource/view.rs +++ b/datafusion/core/src/datasource/view.rs @@ -17,23 +17,26 @@ //! View data source which uses a LogicalPlan as it's input. -use std::{any::Any, sync::Arc}; - -use arrow::datatypes::SchemaRef; -use async_trait::async_trait; -use datafusion_common::Column; -use datafusion_expr::{LogicalPlanBuilder, TableProviderFilterPushDown}; +use std::{any::Any, borrow::Cow, sync::Arc}; use crate::{ error::Result, logical_expr::{Expr, LogicalPlan}, physical_plan::ExecutionPlan, }; +use arrow::datatypes::SchemaRef; +use async_trait::async_trait; +use datafusion_catalog::Session; +use datafusion_common::config::ConfigOptions; +use datafusion_common::Column; +use datafusion_expr::{LogicalPlanBuilder, TableProviderFilterPushDown}; +use datafusion_optimizer::analyzer::expand_wildcard_rule::ExpandWildcardRule; +use datafusion_optimizer::Analyzer; use crate::datasource::{TableProvider, TableType}; -use crate::execution::context::SessionState; /// An implementation of `TableProvider` that uses another logical plan. +#[derive(Debug)] pub struct ViewTable { /// LogicalPlan of the view logical_plan: LogicalPlan, @@ -50,6 +53,7 @@ impl ViewTable { logical_plan: LogicalPlan, definition: Option, ) -> Result { + let logical_plan = Self::apply_required_rule(logical_plan)?; let table_schema = logical_plan.schema().as_ref().to_owned().into(); let view = Self { @@ -61,6 +65,15 @@ impl ViewTable { Ok(view) } + fn apply_required_rule(logical_plan: LogicalPlan) -> Result { + let options = ConfigOptions::default(); + Analyzer::with_rules(vec![Arc::new(ExpandWildcardRule::new())]).execute_and_check( + logical_plan, + &options, + |_, _| {}, + ) + } + /// Get definition ref pub fn definition(&self) -> Option<&String> { self.definition.as_ref() @@ -78,8 +91,8 @@ impl TableProvider for ViewTable { self } - fn get_logical_plan(&self) -> Option<&LogicalPlan> { - Some(&self.logical_plan) + fn get_logical_plan(&self) -> Option> { + Some(Cow::Borrowed(&self.logical_plan)) } fn schema(&self) -> SchemaRef { @@ -103,7 +116,7 @@ impl TableProvider for ViewTable { async fn scan( &self, - state: &SessionState, + state: &dyn Session, projection: Option<&Vec>, filters: &[Expr], limit: Option, @@ -232,6 +245,26 @@ mod tests { assert_batches_eq!(expected, &results); + let view_sql = + "CREATE VIEW replace_xyz AS SELECT * REPLACE (column1*2 as column1) FROM xyz"; + session_ctx.sql(view_sql).await?.collect().await?; + + let results = session_ctx + .sql("SELECT * FROM replace_xyz") + .await? + .collect() + .await?; + + let expected = [ + "+---------+---------+---------+", + "| column1 | column2 | column3 |", + "+---------+---------+---------+", + "| 2 | 2 | 3 |", + "| 8 | 5 | 6 |", + "+---------+---------+---------+", + ]; + + assert_batches_eq!(expected, &results); Ok(()) } diff --git a/datafusion/core/src/execution/context/avro.rs b/datafusion/core/src/execution/context/avro.rs index 2703529264e0..a31f2af642d0 100644 --- a/datafusion/core/src/execution/context/avro.rs +++ b/datafusion/core/src/execution/context/avro.rs @@ -15,10 +15,10 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; - use super::super::options::{AvroReadOptions, ReadOptions}; use super::{DataFilePaths, DataFrame, Result, SessionContext}; +use datafusion_common::TableReference; +use std::sync::Arc; impl SessionContext { /// Creates a [`DataFrame`] for reading an Avro data source. @@ -39,15 +39,15 @@ impl SessionContext { /// SQL statements executed against this context. pub async fn register_avro( &self, - name: &str, - table_path: &str, + table_ref: impl Into, + table_path: impl AsRef, options: AvroReadOptions<'_>, ) -> Result<()> { let listing_options = options .to_listing_options(&self.copied_config(), self.copied_table_options()); self.register_listing_table( - name, + table_ref, table_path, listing_options, options.schema.map(|s| Arc::new(s.to_owned())), @@ -57,29 +57,3 @@ impl SessionContext { Ok(()) } } - -#[cfg(test)] -mod tests { - use super::*; - - use async_trait::async_trait; - - // Test for compilation error when calling read_* functions from an #[async_trait] function. - // See https://github.com/apache/datafusion/issues/1154 - #[async_trait] - trait CallReadTrait { - async fn call_read_avro(&self) -> DataFrame; - } - - struct CallRead {} - - #[async_trait] - impl CallReadTrait for CallRead { - async fn call_read_avro(&self) -> DataFrame { - let ctx = SessionContext::new(); - ctx.read_avro("dummy", AvroReadOptions::default()) - .await - .unwrap() - } - } -} diff --git a/datafusion/core/src/execution/context/csv.rs b/datafusion/core/src/execution/context/csv.rs index 504ebf6d77cf..e97c70ef9812 100644 --- a/datafusion/core/src/execution/context/csv.rs +++ b/datafusion/core/src/execution/context/csv.rs @@ -15,9 +15,9 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; - use crate::datasource::physical_plan::plan_to_csv; +use datafusion_common::TableReference; +use std::sync::Arc; use super::super::options::{CsvReadOptions, ReadOptions}; use super::{DataFilePaths, DataFrame, ExecutionPlan, Result, SessionContext}; @@ -55,15 +55,15 @@ impl SessionContext { /// statements executed against this context. pub async fn register_csv( &self, - name: &str, - table_path: &str, + table_ref: impl Into, + table_path: impl AsRef, options: CsvReadOptions<'_>, ) -> Result<()> { let listing_options = options .to_listing_options(&self.copied_config(), self.copied_table_options()); self.register_listing_table( - name, + table_ref, table_path, listing_options, options.schema.map(|s| Arc::new(s.to_owned())), @@ -90,7 +90,6 @@ mod tests { use crate::assert_batches_eq; use crate::test_util::{plan_and_collect, populate_csv_partitions}; - use async_trait::async_trait; use tempfile::TempDir; #[tokio::test] @@ -111,12 +110,12 @@ mod tests { ) .await?; let results = - plan_and_collect(&ctx, "SELECT SUM(c1), SUM(c2), COUNT(*) FROM test").await?; + plan_and_collect(&ctx, "SELECT sum(c1), sum(c2), count(*) FROM test").await?; assert_eq!(results.len(), 1); let expected = [ "+--------------+--------------+----------+", - "| SUM(test.c1) | SUM(test.c2) | COUNT(*) |", + "| sum(test.c1) | sum(test.c2) | count(*) |", "+--------------+--------------+----------+", "| 10 | 110 | 20 |", "+--------------+--------------+----------+", @@ -125,21 +124,4 @@ mod tests { Ok(()) } - - // Test for compilation error when calling read_* functions from an #[async_trait] function. - // See https://github.com/apache/datafusion/issues/1154 - #[async_trait] - trait CallReadTrait { - async fn call_read_csv(&self) -> DataFrame; - } - - struct CallRead {} - - #[async_trait] - impl CallReadTrait for CallRead { - async fn call_read_csv(&self) -> DataFrame { - let ctx = SessionContext::new(); - ctx.read_csv("dummy", CsvReadOptions::new()).await.unwrap() - } - } } diff --git a/datafusion/core/src/execution/context/json.rs b/datafusion/core/src/execution/context/json.rs index c21e32cfdefb..c9a9492f9162 100644 --- a/datafusion/core/src/execution/context/json.rs +++ b/datafusion/core/src/execution/context/json.rs @@ -15,9 +15,9 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; - use crate::datasource::physical_plan::plan_to_json; +use datafusion_common::TableReference; +use std::sync::Arc; use super::super::options::{NdJsonReadOptions, ReadOptions}; use super::{DataFilePaths, DataFrame, ExecutionPlan, Result, SessionContext}; @@ -41,15 +41,15 @@ impl SessionContext { /// from SQL statements executed against this context. pub async fn register_json( &self, - name: &str, - table_path: &str, + table_ref: impl Into, + table_path: impl AsRef, options: NdJsonReadOptions<'_>, ) -> Result<()> { let listing_options = options .to_listing_options(&self.copied_config(), self.copied_table_options()); self.register_listing_table( - name, + table_ref, table_path, listing_options, options.schema.map(|s| Arc::new(s.to_owned())), diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index d83644597e78..333f83c673cc 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -15,86 +15,72 @@ // specific language governing permissions and limitations // under the License. -//! [`SessionContext`] contains methods for registering data sources and executing queries +//! [`SessionContext`] API for registering data sources and executing queries -use std::collections::{hash_map::Entry, HashMap, HashSet}; +use std::collections::HashSet; use std::fmt::Debug; -use std::ops::ControlFlow; use std::sync::{Arc, Weak}; use super::options::ReadOptions; use crate::{ - catalog::information_schema::{InformationSchemaProvider, INFORMATION_SCHEMA}, - catalog::listing_schema::ListingSchemaProvider, - catalog::schema::{MemorySchemaProvider, SchemaProvider}, catalog::{ - CatalogProvider, CatalogProviderList, MemoryCatalogProvider, - MemoryCatalogProviderList, + CatalogProvider, CatalogProviderList, TableProvider, TableProviderFactory, }, - config::ConfigOptions, + catalog_common::listing_schema::ListingSchemaProvider, + catalog_common::memory::MemorySchemaProvider, + catalog_common::MemoryCatalogProvider, dataframe::DataFrame, datasource::{ - cte_worktable::CteWorkTable, function::{TableFunction, TableFunctionImpl}, listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl}, - object_store::ObjectStoreUrl, - provider::{DefaultTableFactory, TableProviderFactory}, }, - datasource::{provider_as_source, MemTable, TableProvider, ViewTable}, + datasource::{provider_as_source, MemTable, ViewTable}, error::{DataFusionError, Result}, execution::{options::ArrowReadOptions, runtime_env::RuntimeEnv, FunctionRegistry}, logical_expr::AggregateUDF, + logical_expr::ScalarUDF, logical_expr::{ CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction, CreateMemoryTable, CreateView, DropCatalogSchema, DropFunction, DropTable, - DropView, Explain, LogicalPlan, LogicalPlanBuilder, PlanType, SetVariable, - TableSource, TableType, ToStringifiedPlan, UNNAMED_TABLE, + DropView, LogicalPlan, LogicalPlanBuilder, SetVariable, TableType, UNNAMED_TABLE, }, - optimizer::analyzer::{Analyzer, AnalyzerRule}, - optimizer::optimizer::{Optimizer, OptimizerConfig, OptimizerRule}, - physical_optimizer::optimizer::{PhysicalOptimizer, PhysicalOptimizerRule}, - physical_plan::{udf::ScalarUDF, ExecutionPlan}, - physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}, + physical_expr::PhysicalExpr, + physical_plan::ExecutionPlan, variable::{VarProvider, VarType}, }; -#[cfg(feature = "array_expressions")] -use crate::functions_array; -use crate::{functions, functions_aggregate}; - -use arrow::datatypes::{DataType, SchemaRef}; +use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use arrow_schema::Schema; use datafusion_common::{ - alias::AliasGenerator, config::{ConfigExtension, TableOptions}, exec_err, not_impl_err, plan_datafusion_err, plan_err, tree_node::{TreeNodeRecursion, TreeNodeVisitor}, - SchemaReference, TableReference, + DFSchema, SchemaReference, TableReference, }; use datafusion_execution::registry::SerializerRegistry; use datafusion_expr::{ + expr_rewriter::FunctionRewrite, logical_plan::{DdlStatement, Statement}, - var_provider::is_system_variables, - Expr, StringifiedPlan, UserDefinedLogicalNode, WindowUDF, -}; -use datafusion_sql::{ - parser::{CopyToSource, CopyToStatement, DFParser}, - planner::{object_name_to_table_reference, ContextProvider, ParserOptions, SqlToRel}, - ResolvedTableReference, + planner::ExprPlanner, + Expr, UserDefinedLogicalNode, WindowUDF, }; +// backwards compatibility +pub use crate::execution::session_state::SessionState; + +use crate::datasource::dynamic_file::DynamicListTableFactory; +use crate::execution::session_state::SessionStateBuilder; use async_trait::async_trait; use chrono::{DateTime, Utc}; -use parking_lot::RwLock; -use sqlparser::dialect::dialect_from_str; -use url::Url; -use uuid::Uuid; - +use datafusion_catalog::{DynamicFileCatalog, SessionStore, UrlTableFactory}; pub use datafusion_execution::config::SessionConfig; pub use datafusion_execution::TaskContext; pub use datafusion_expr::execution_props::ExecutionProps; -use datafusion_expr::expr_rewriter::FunctionRewrite; +use datafusion_optimizer::{AnalyzerRule, OptimizerRule}; +use object_store::ObjectStore; +use parking_lot::RwLock; +use url::Url; mod avro; mod csv; @@ -143,6 +129,9 @@ where /// the state of the connection between a user and an instance of the /// DataFusion engine. /// +/// See examples below for how to use the `SessionContext` to execute queries +/// and how to configure the session. +/// /// # Overview /// /// [`SessionContext`] provides the following functionality: @@ -159,6 +148,7 @@ where /// /// ``` /// use datafusion::prelude::*; +/// # use datafusion::functions_aggregate::expr_fn::min; /// # use datafusion::{error::Result, assert_batches_eq}; /// # #[tokio::main] /// # async fn main() -> Result<()> { @@ -173,7 +163,7 @@ where /// assert_batches_eq!( /// &[ /// "+---+----------------+", -/// "| a | MIN(?table?.b) |", +/// "| a | min(?table?.b) |", /// "+---+----------------+", /// "| 1 | 2 |", /// "+---+----------------+", @@ -193,17 +183,17 @@ where /// # use datafusion::{error::Result, assert_batches_eq}; /// # #[tokio::main] /// # async fn main() -> Result<()> { -/// let mut ctx = SessionContext::new(); +/// let ctx = SessionContext::new(); /// ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new()).await?; /// let results = ctx -/// .sql("SELECT a, MIN(b) FROM example GROUP BY a LIMIT 100") +/// .sql("SELECT a, min(b) FROM example GROUP BY a LIMIT 100") /// .await? /// .collect() /// .await?; /// assert_batches_eq!( /// &[ /// "+---+----------------+", -/// "| a | MIN(example.b) |", +/// "| a | min(example.b) |", /// "+---+----------------+", /// "| 1 | 2 |", /// "+---+----------------+", @@ -214,19 +204,62 @@ where /// # } /// ``` /// -/// # `SessionContext`, `SessionState`, and `TaskContext` +/// # Example: Configuring `SessionContext` +/// +/// The `SessionContext` can be configured by creating a [`SessionState`] using +/// [`SessionStateBuilder`]: +/// +/// ``` +/// # use std::sync::Arc; +/// # use datafusion::prelude::*; +/// # use datafusion::execution::SessionStateBuilder; +/// # use datafusion_execution::runtime_env::RuntimeEnvBuilder; +/// // Configure a 4k batch size +/// let config = SessionConfig::new() .with_batch_size(4 * 1024); +/// +/// // configure a memory limit of 1GB with 20% slop +/// let runtime_env = RuntimeEnvBuilder::new() +/// .with_memory_limit(1024 * 1024 * 1024, 0.80) +/// .build_arc() +/// .unwrap(); /// -/// A [`SessionContext`] can be created from a [`SessionConfig`] and -/// stores the state for a particular query session. A single -/// [`SessionContext`] can run multiple queries. +/// // Create a SessionState using the config and runtime_env +/// let state = SessionStateBuilder::new() +/// .with_config(config) +/// .with_runtime_env(runtime_env) +/// // include support for built in functions and configurations +/// .with_default_features() +/// .build(); +/// +/// // Create a SessionContext +/// let ctx = SessionContext::from(state); +/// ``` /// -/// [`SessionState`] contains information available during query -/// planning (creating [`LogicalPlan`]s and [`ExecutionPlan`]s). +/// # Relationship between `SessionContext`, `SessionState`, and `TaskContext` /// -/// [`TaskContext`] contains the state available during query -/// execution [`ExecutionPlan::execute`]. It contains a subset of the -/// information in[`SessionState`] and is created from a -/// [`SessionContext`] or a [`SessionState`]. +/// The state required to optimize, and evaluate queries is +/// broken into three levels to allow tailoring +/// +/// The objects are: +/// +/// 1. [`SessionContext`]: Most users should use a `SessionContext`. It contains +/// all information required to execute queries including high level APIs such +/// as [`SessionContext::sql`]. All queries run with the same `SessionContext` +/// share the same configuration and resources (e.g. memory limits). +/// +/// 2. [`SessionState`]: contains information required to plan and execute an +/// individual query (e.g. creating a [`LogicalPlan`] or [`ExecutionPlan`]). +/// Each query is planned and executed using its own `SessionState`, which can +/// be created with [`SessionContext::state`]. `SessionState` allows finer +/// grained control over query execution, for example disallowing DDL operations +/// such as `CREATE TABLE`. +/// +/// 3. [`TaskContext`] contains the state required for query execution (e.g. +/// [`ExecutionPlan::execute`]). It contains a subset of information in +/// [`SessionState`]. `TaskContext` allows executing [`ExecutionPlan`]s +/// [`PhysicalExpr`]s without requiring a full [`SessionState`]. +/// +/// [`PhysicalExpr`]: crate::physical_expr::PhysicalExpr #[derive(Clone)] pub struct SessionContext { /// UUID for the session @@ -279,13 +312,6 @@ impl SessionContext { Self::new_with_config_rt(config, runtime) } - /// Creates a new `SessionContext` using the provided - /// [`SessionConfig`] and a new [`RuntimeEnv`]. - #[deprecated(since = "32.0.0", note = "Use SessionContext::new_with_config")] - pub fn with_config(config: SessionConfig) -> Self { - Self::new_with_config(config) - } - /// Creates a new `SessionContext` using the provided /// [`SessionConfig`] and a [`RuntimeEnv`]. /// @@ -300,31 +326,116 @@ impl SessionContext { /// all `SessionContext`'s should be configured with the /// same `RuntimeEnv`. pub fn new_with_config_rt(config: SessionConfig, runtime: Arc) -> Self { - let state = SessionState::new_with_config_rt(config, runtime); + let state = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .build(); Self::new_with_state(state) } - /// Creates a new `SessionContext` using the provided - /// [`SessionConfig`] and a [`RuntimeEnv`]. - #[deprecated(since = "32.0.0", note = "Use SessionState::new_with_config_rt")] - pub fn with_config_rt(config: SessionConfig, runtime: Arc) -> Self { - Self::new_with_config_rt(config, runtime) - } - /// Creates a new `SessionContext` using the provided [`SessionState`] pub fn new_with_state(state: SessionState) -> Self { Self { - session_id: state.session_id.clone(), + session_id: state.session_id().to_string(), session_start_time: Utc::now(), state: Arc::new(RwLock::new(state)), } } - /// Creates a new `SessionContext` using the provided [`SessionState`] - #[deprecated(since = "32.0.0", note = "Use SessionState::new_with_state")] - pub fn with_state(state: SessionState) -> Self { - Self::new_with_state(state) + /// Enable querying local files as tables. + /// + /// This feature is security sensitive and should only be enabled for + /// systems that wish to permit direct access to the file system from SQL. + /// + /// When enabled, this feature permits direct access to arbitrary files via + /// SQL like + /// + /// ```sql + /// SELECT * from 'my_file.parquet' + /// ``` + /// + /// See [DynamicFileCatalog] for more details + /// + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::{error::Result, assert_batches_eq}; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// let ctx = SessionContext::new() + /// .enable_url_table(); // permit local file access + /// let results = ctx + /// .sql("SELECT a, MIN(b) FROM 'tests/data/example.csv' as example GROUP BY a LIMIT 100") + /// .await? + /// .collect() + /// .await?; + /// assert_batches_eq!( + /// &[ + /// "+---+----------------+", + /// "| a | min(example.b) |", + /// "+---+----------------+", + /// "| 1 | 2 |", + /// "+---+----------------+", + /// ], + /// &results + /// ); + /// # Ok(()) + /// # } + /// ``` + pub fn enable_url_table(self) -> Self { + let current_catalog_list = Arc::clone(self.state.read().catalog_list()); + let factory = Arc::new(DynamicListTableFactory::new(SessionStore::new())); + let catalog_list = Arc::new(DynamicFileCatalog::new( + current_catalog_list, + Arc::clone(&factory) as Arc, + )); + let ctx: SessionContext = self + .into_state_builder() + .with_catalog_list(catalog_list) + .build() + .into(); + // register new state with the factory + factory.session_store().with_state(ctx.state_weak_ref()); + ctx } + + /// Convert the current `SessionContext` into a [`SessionStateBuilder`] + /// + /// This is useful to switch back to `SessionState` with custom settings such as + /// [`Self::enable_url_table`]. + /// + /// Avoids cloning the SessionState if possible. + /// + /// # Example + /// ``` + /// # use std::sync::Arc; + /// # use datafusion::prelude::*; + /// # use datafusion::execution::SessionStateBuilder; + /// # use datafusion_optimizer::push_down_filter::PushDownFilter; + /// let my_rule = PushDownFilter{}; // pretend it is a new rule + /// // Create a new builder with a custom optimizer rule + /// let context: SessionContext = SessionStateBuilder::new() + /// .with_optimizer_rule(Arc::new(my_rule)) + /// .build() + /// .into(); + /// // Enable local file access and convert context back to a builder + /// let builder = context + /// .enable_url_table() + /// .into_state_builder(); + /// ``` + pub fn into_state_builder(self) -> SessionStateBuilder { + let SessionContext { + session_id: _, + session_start_time: _, + state, + } = self; + let state = match Arc::try_unwrap(state) { + Ok(rwlock) => rwlock.into_inner(), + Err(state) => state.read().clone(), + }; + SessionStateBuilder::from(state) + } + /// Returns the time this `SessionContext` was created pub fn session_start_time(&self) -> DateTime { self.session_start_time @@ -339,6 +450,46 @@ impl SessionContext { self } + /// Adds an optimizer rule to the end of the existing rules. + /// + /// See [`SessionState`] for more control of when the rule is applied. + pub fn add_optimizer_rule( + &self, + optimizer_rule: Arc, + ) { + self.state.write().append_optimizer_rule(optimizer_rule); + } + + /// Adds an analyzer rule to the end of the existing rules. + /// + /// See [`SessionState`] for more control of when the rule is applied. + pub fn add_analyzer_rule(&self, analyzer_rule: Arc) { + self.state.write().add_analyzer_rule(analyzer_rule); + } + + /// Registers an [`ObjectStore`] to be used with a specific URL prefix. + /// + /// See [`RuntimeEnv::register_object_store`] for more details. + /// + /// # Example: register a local object store for the "file://" URL prefix + /// ``` + /// # use std::sync::Arc; + /// # use datafusion::prelude::SessionContext; + /// # use datafusion_execution::object_store::ObjectStoreUrl; + /// let object_store_url = ObjectStoreUrl::parse("file://").unwrap(); + /// let object_store = object_store::local::LocalFileSystem::new(); + /// let ctx = SessionContext::new(); + /// // All files with the file:// url prefix will be read from the local file system + /// ctx.register_object_store(object_store_url.as_ref(), Arc::new(object_store)); + /// ``` + pub fn register_object_store( + &self, + url: &Url, + object_store: Arc, + ) -> Option> { + self.runtime_env().register_object_store(url, object_store) + } + /// Registers the [`RecordBatch`] as the specified table name pub fn register_batch( &self, @@ -356,7 +507,7 @@ impl SessionContext { /// Return the [RuntimeEnv] used to run queries with this `SessionContext` pub fn runtime_env(&self) -> Arc { - self.state.read().runtime_env.clone() + self.state.read().runtime_env().clone() } /// Returns an id that uniquely identifies this `SessionContext`. @@ -377,7 +528,7 @@ impl SessionContext { pub fn enable_ident_normalization(&self) -> bool { self.state .read() - .config + .config() .options() .sql_parser .enable_ident_normalization @@ -385,7 +536,7 @@ impl SessionContext { /// Return a copied version of config for this Session pub fn copied_config(&self) -> SessionConfig { - self.state.read().config.clone() + self.state.read().config().clone() } /// Return a copied version of table options for this Session @@ -410,7 +561,7 @@ impl SessionContext { /// # use datafusion::{error::Result, assert_batches_eq}; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = SessionContext::new(); + /// let ctx = SessionContext::new(); /// ctx /// .sql("CREATE TABLE foo (x INTEGER)") /// .await? @@ -438,7 +589,7 @@ impl SessionContext { /// # use datafusion::physical_plan::collect; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = SessionContext::new(); + /// let ctx = SessionContext::new(); /// let options = SQLOptions::new() /// .with_allow_ddl(false); /// let err = ctx.sql_with_options("CREATE TABLE foo (x INTEGER)", options) @@ -461,6 +612,32 @@ impl SessionContext { self.execute_logical_plan(plan).await } + /// Creates logical expressions from SQL query text. + /// + /// # Example: Parsing SQL queries + /// + /// ``` + /// # use arrow::datatypes::{DataType, Field, Schema}; + /// # use datafusion::prelude::*; + /// # use datafusion_common::{DFSchema, Result}; + /// # #[tokio::main] + /// # async fn main() -> Result<()> { + /// // datafusion will parse number as i64 first. + /// let sql = "a > 10"; + /// let expected = col("a").gt(lit(10 as i64)); + /// // provide type information that `a` is an Int32 + /// let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + /// let df_schema = DFSchema::try_from(schema).unwrap(); + /// let expr = SessionContext::new() + /// .parse_sql_expr(sql, &df_schema)?; + /// assert_eq!(expected, expr); + /// # Ok(()) + /// # } + /// ``` + pub fn parse_sql_expr(&self, sql: &str, df_schema: &DFSchema) -> Result { + self.state.read().create_logical_expr(sql, df_schema) + } + /// Execute the [`LogicalPlan`], return a [`DataFrame`]. This API /// is not featured limited (so all SQL such as `CREATE TABLE` and /// `COPY` will be run). @@ -476,30 +653,35 @@ impl SessionContext { // stack overflows. match ddl { DdlStatement::CreateExternalTable(cmd) => { - Box::pin(async move { self.create_external_table(&cmd).await }) - as std::pin::Pin + Send>> + (Box::pin(async move { self.create_external_table(&cmd).await }) + as std::pin::Pin + Send>>) + .await } DdlStatement::CreateMemoryTable(cmd) => { - Box::pin(self.create_memory_table(cmd)) + Box::pin(self.create_memory_table(cmd)).await + } + DdlStatement::CreateView(cmd) => { + Box::pin(self.create_view(cmd)).await } - DdlStatement::CreateView(cmd) => Box::pin(self.create_view(cmd)), DdlStatement::CreateCatalogSchema(cmd) => { - Box::pin(self.create_catalog_schema(cmd)) + Box::pin(self.create_catalog_schema(cmd)).await } DdlStatement::CreateCatalog(cmd) => { - Box::pin(self.create_catalog(cmd)) + Box::pin(self.create_catalog(cmd)).await } - DdlStatement::DropTable(cmd) => Box::pin(self.drop_table(cmd)), - DdlStatement::DropView(cmd) => Box::pin(self.drop_view(cmd)), + DdlStatement::DropTable(cmd) => Box::pin(self.drop_table(cmd)).await, + DdlStatement::DropView(cmd) => Box::pin(self.drop_view(cmd)).await, DdlStatement::DropCatalogSchema(cmd) => { - Box::pin(self.drop_schema(cmd)) + Box::pin(self.drop_schema(cmd)).await } DdlStatement::CreateFunction(cmd) => { - Box::pin(self.create_function(cmd)) + Box::pin(self.create_function(cmd)).await } - DdlStatement::DropFunction(cmd) => Box::pin(self.drop_function(cmd)), + DdlStatement::DropFunction(cmd) => { + Box::pin(self.drop_function(cmd)).await + } + ddl => Ok(DataFrame::new(self.state(), LogicalPlan::Ddl(ddl))), } - .await } // TODO what about the other statements (like TransactionStart and TransactionEnd) LogicalPlan::Statement(Statement::SetVariable(stmt)) => { @@ -510,6 +692,41 @@ impl SessionContext { } } + /// Create a [`PhysicalExpr`] from an [`Expr`] after applying type + /// coercion and function rewrites. + /// + /// Note: The expression is not [simplified] or otherwise optimized: + /// `a = 1 + 2` will not be simplified to `a = 3` as this is a more involved process. + /// See the [expr_api] example for how to simplify expressions. + /// + /// # Example + /// ``` + /// # use std::sync::Arc; + /// # use arrow::datatypes::{DataType, Field, Schema}; + /// # use datafusion::prelude::*; + /// # use datafusion_common::DFSchema; + /// // a = 1 (i64) + /// let expr = col("a").eq(lit(1i64)); + /// // provide type information that `a` is an Int32 + /// let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + /// let df_schema = DFSchema::try_from(schema).unwrap(); + /// // Create a PhysicalExpr. Note DataFusion automatically coerces (casts) `1i64` to `1i32` + /// let physical_expr = SessionContext::new() + /// .create_physical_expr(expr, &df_schema).unwrap(); + /// ``` + /// # See Also + /// * [`SessionState::create_physical_expr`] for a lower level API + /// + /// [simplified]: datafusion_optimizer::simplify_expressions + /// [expr_api]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/expr_api.rs + pub fn create_physical_expr( + &self, + expr: Expr, + df_schema: &DFSchema, + ) -> Result> { + self.state.read().create_physical_expr(expr, df_schema) + } + // return an empty dataframe fn return_empty_dataframe(&self) -> Result { let plan = LogicalPlanBuilder::empty(false).build()?; @@ -521,6 +738,11 @@ impl SessionContext { cmd: &CreateExternalTable, ) -> Result { let exist = self.table_exist(cmd.name.clone())?; + + if cmd.temporary { + return not_impl_err!("Temporary tables not supported"); + } + if exist { match cmd.if_not_exists { true => return self.return_empty_dataframe(), @@ -544,10 +766,16 @@ impl SessionContext { or_replace, constraints, column_defaults, + temporary, } = cmd; - let input = Arc::try_unwrap(input).unwrap_or_else(|e| e.as_ref().clone()); + let input = Arc::unwrap_or_clone(input); let input = self.state().optimize(&input)?; + + if temporary { + return not_impl_err!("Temporary tables not supported"); + } + let table = self.table(name.clone()).await; match (if_not_exists, or_replace, table) { (true, false, Ok(_)) => self.return_empty_dataframe(), @@ -596,10 +824,15 @@ impl SessionContext { input, or_replace, definition, + temporary, } = cmd; let view = self.table(name.clone()).await; + if temporary { + return not_impl_err!("Temporary views not supported"); + } + match (or_replace, view) { (true, Ok(_)) => { self.deregister_table(name.clone())?; @@ -610,7 +843,6 @@ impl SessionContext { } (_, Err(_)) => { let table = Arc::new(ViewTable::try_new((*input).clone(), definition)?); - self.register_table(name, table)?; self.return_empty_dataframe() } @@ -631,8 +863,8 @@ impl SessionContext { let (catalog, schema_name) = match tokens.len() { 1 => { let state = self.state.read(); - let name = &state.config.options().catalog.default_catalog; - let catalog = state.catalog_list.catalog(name).ok_or_else(|| { + let name = &state.config().options().catalog.default_catalog; + let catalog = state.catalog_list().catalog(name).ok_or_else(|| { DataFusionError::Execution(format!( "Missing default catalog '{name}'" )) @@ -675,7 +907,7 @@ impl SessionContext { let new_catalog = Arc::new(MemoryCatalogProvider::new()); self.state .write() - .catalog_list + .catalog_list() .register_catalog(catalog_name, new_catalog); self.return_empty_dataframe() } @@ -726,7 +958,7 @@ impl SessionContext { state.config_options().catalog.default_catalog.to_string() } }; - if let Some(catalog) = state.catalog_list.catalog(&catalog_name) { + if let Some(catalog) = state.catalog_list().catalog(&catalog_name) { catalog } else if allow_missing { return self.return_empty_dataframe(); @@ -752,7 +984,7 @@ impl SessionContext { } = stmt; let mut state = self.state.write(); - state.config.options_mut().set(&variable, &value)?; + state.config_mut().options_mut().set(&variable, &value)?; drop(state); self.return_empty_dataframe() @@ -765,8 +997,8 @@ impl SessionContext { let state = self.state.read().clone(); let file_type = cmd.file_type.to_uppercase(); let factory = - &state - .table_factories + state + .table_factories() .get(file_type.as_str()) .ok_or_else(|| { DataFusionError::Execution(format!( @@ -789,7 +1021,7 @@ impl SessionContext { let state = self.state.read(); let resolved = state.resolve_table_ref(table_ref); state - .catalog_list + .catalog_list() .catalog(&resolved.catalog) .and_then(|c| c.schema(&resolved.schema)) }; @@ -809,7 +1041,7 @@ impl SessionContext { async fn create_function(&self, stmt: CreateFunction) -> Result { let function = { let state = self.state.read().clone(); - let function_factory = &state.function_factory; + let function_factory = state.function_factory(); match function_factory { Some(f) => f.create(&state, stmt).await?, @@ -842,6 +1074,7 @@ impl SessionContext { dropped |= self.state.write().deregister_udf(&stmt.name)?.is_some(); dropped |= self.state.write().deregister_udaf(&stmt.name)?.is_some(); dropped |= self.state.write().deregister_udwf(&stmt.name)?.is_some(); + dropped |= self.state.write().deregister_udtf(&stmt.name)?.is_some(); // DROP FUNCTION IF EXISTS drops the specified function only if that // function exists and in this way, it avoids error. While the DROP FUNCTION @@ -863,16 +1096,13 @@ impl SessionContext { ) { self.state .write() - .execution_props + .execution_props_mut() .add_var_provider(variable_type, provider); } /// Register a table UDF with this context pub fn register_udtf(&self, name: &str, fun: Arc) { - self.state.write().table_functions.insert( - name.to_owned(), - Arc::new(TableFunction::new(name.to_owned(), fun)), - ); + self.state.write().register_udtf(name, fun) } /// Registers a scalar UDF within this context. @@ -882,6 +1112,7 @@ impl SessionContext { /// /// - `SELECT MY_FUNC(x)...` will look for a function named `"my_func"` /// - `SELECT "my_FUNC"(x)` will look for a function named `"my_FUNC"` + /// /// Any functions registered with the udf name or its aliases will be overwritten with this new function pub fn register_udf(&self, f: ScalarUDF) { let mut state = self.state.write(); @@ -925,6 +1156,11 @@ impl SessionContext { self.state.write().deregister_udwf(name).ok(); } + /// Deregisters a UDTF within this context. + pub fn deregister_udtf(&self, name: &str) { + self.state.write().deregister_udtf(name).ok(); + } + /// Creates a [`DataFrame`] for reading a data source. /// /// For more control such as reading multiple files, you can use @@ -1020,7 +1256,7 @@ impl SessionContext { // check schema uniqueness let mut batches = batches.into_iter().peekable(); let schema = if let Some(batch) = batches.peek() { - batch.schema().clone() + batch.schema() } else { Arc::new(Schema::empty()) }; @@ -1044,7 +1280,7 @@ impl SessionContext { /// [`ObjectStore`]: object_store::ObjectStore pub async fn register_listing_table( &self, - name: &str, + table_ref: impl Into, table_path: impl AsRef, options: ListingOptions, provided_schema: Option, @@ -1059,10 +1295,7 @@ impl SessionContext { .with_listing_options(options) .with_schema(resolved_schema); let table = ListingTable::try_new(config)?.with_definition(sql_definition); - self.register_table( - TableReference::Bare { table: name.into() }, - Arc::new(table), - )?; + self.register_table(table_ref, Arc::new(table))?; Ok(()) } @@ -1102,18 +1335,18 @@ impl SessionContext { let name = name.into(); self.state .read() - .catalog_list + .catalog_list() .register_catalog(name, catalog) } /// Retrieves the list of available catalog names. pub fn catalog_names(&self) -> Vec { - self.state.read().catalog_list.catalog_names() + self.state.read().catalog_list().catalog_names() } /// Retrieves a [`CatalogProvider`] instance by name pub fn catalog(&self, name: &str) -> Option> { - self.state.read().catalog_list.catalog(name) + self.state.read().catalog_list().catalog(name) } /// Registers a [`TableProvider`] as a table that can be @@ -1183,6 +1416,20 @@ impl SessionContext { Ok(DataFrame::new(self.state(), plan)) } + /// Retrieves a [`TableFunction`] reference by name. + /// + /// Returns an error if no table function has been registered with the provided name. + /// + /// [`register_udtf`]: SessionContext::register_udtf + pub fn table_function(&self, name: &str) -> Result> { + self.state + .read() + .table_functions() + .get(name) + .cloned() + .ok_or_else(|| plan_datafusion_err!("Table function '{name}' not found")) + } + /// Return a [`TableProvider`] for the specified table. pub async fn table_provider<'a>( &self, @@ -1202,32 +1449,45 @@ impl SessionContext { Arc::new(TaskContext::from(self)) } - /// Snapshots the [`SessionState`] of this [`SessionContext`] setting the - /// `query_execution_start_time` to the current time + /// Return a new [`SessionState`] suitable for executing a single query. + /// + /// Notes: + /// + /// 1. `query_execution_start_time` is set to the current time for the + /// returned state. + /// + /// 2. The returned state is not shared with the current session state + /// and this changes to the returned `SessionState` such as changing + /// [`ConfigOptions`] will not be reflected in this `SessionContext`. + /// + /// [`ConfigOptions`]: crate::config::ConfigOptions pub fn state(&self) -> SessionState { let mut state = self.state.read().clone(); - state.execution_props.start_execution(); + state.execution_props_mut().start_execution(); state } + /// Get reference to [`SessionState`] + pub fn state_ref(&self) -> Arc> { + self.state.clone() + } + /// Get weak reference to [`SessionState`] pub fn state_weak_ref(&self) -> Weak> { Arc::downgrade(&self.state) } /// Register [`CatalogProviderList`] in [`SessionState`] - pub fn register_catalog_list(&mut self, catalog_list: Arc) { - self.state.write().catalog_list = catalog_list; + pub fn register_catalog_list(&self, catalog_list: Arc) { + self.state.write().register_catalog_list(catalog_list) } - /// Registers a [`ConfigExtension`] as a table option extention that can be + /// Registers a [`ConfigExtension`] as a table option extension that can be /// referenced from SQL statements executed against this context. pub fn register_table_options_extension(&self, extension: T) { self.state .write() - .table_option_namespace - .extensions - .insert(extension) + .register_table_options_extension(extension) } } @@ -1247,15 +1507,18 @@ impl FunctionRegistry for SessionContext { fn udwf(&self, name: &str) -> Result> { self.state.read().udwf(name) } + fn register_udf(&mut self, udf: Arc) -> Result>> { self.state.write().register_udf(udf) } + fn register_udaf( &mut self, udaf: Arc, ) -> Result>> { self.state.write().register_udaf(udaf) } + fn register_udwf(&mut self, udwf: Arc) -> Result>> { self.state.write().register_udwf(udwf) } @@ -1266,41 +1529,54 @@ impl FunctionRegistry for SessionContext { ) -> Result<()> { self.state.write().register_function_rewrite(rewrite) } + + fn expr_planners(&self) -> Vec> { + self.state.read().expr_planners() + } + + fn register_expr_planner( + &mut self, + expr_planner: Arc, + ) -> Result<()> { + self.state.write().register_expr_planner(expr_planner) + } } -/// A planner used to add extensions to DataFusion logical and physical plans. -#[async_trait] -pub trait QueryPlanner { - /// Given a `LogicalPlan`, create an [`ExecutionPlan`] suitable for execution - async fn create_physical_plan( - &self, - logical_plan: &LogicalPlan, - session_state: &SessionState, - ) -> Result>; +/// Create a new task context instance from SessionContext +impl From<&SessionContext> for TaskContext { + fn from(session: &SessionContext) -> Self { + TaskContext::from(&*session.state.read()) + } } -/// The query planner used if no user defined planner is provided -struct DefaultQueryPlanner {} +impl From for SessionContext { + fn from(state: SessionState) -> Self { + Self::new_with_state(state) + } +} + +impl From for SessionStateBuilder { + fn from(session: SessionContext) -> Self { + session.into_state_builder() + } +} +/// A planner used to add extensions to DataFusion logical and physical plans. #[async_trait] -impl QueryPlanner for DefaultQueryPlanner { +pub trait QueryPlanner: Debug { /// Given a `LogicalPlan`, create an [`ExecutionPlan`] suitable for execution async fn create_physical_plan( &self, logical_plan: &LogicalPlan, session_state: &SessionState, - ) -> Result> { - let planner = DefaultPhysicalPlanner::default(); - planner - .create_physical_plan(logical_plan, session_state) - .await - } + ) -> Result>; } + /// A pluggable interface to handle `CREATE FUNCTION` statements /// and interact with [SessionState] to registers new udf, udaf or udwf. #[async_trait] -pub trait FunctionFactory: Sync + Send { +pub trait FunctionFactory: Debug + Sync + Send { /// Handles creation of user defined function specified in [CreateFunction] statement async fn create( &self, @@ -1320,930 +1596,10 @@ pub enum RegisterFunction { /// Table user defined function Table(String, Arc), } -/// Execution context for registering data sources and executing queries. -/// See [`SessionContext`] for a higher level API. -/// -/// Note that there is no `Default` or `new()` for SessionState, -/// to avoid accidentally running queries or other operations without passing through -/// the [`SessionConfig`] or [`RuntimeEnv`]. See [`SessionContext`]. -#[derive(Clone)] -pub struct SessionState { - /// A unique UUID that identifies the session - session_id: String, - /// Responsible for analyzing and rewrite a logical plan before optimization - analyzer: Analyzer, - /// Responsible for optimizing a logical plan - optimizer: Optimizer, - /// Responsible for optimizing a physical execution plan - physical_optimizers: PhysicalOptimizer, - /// Responsible for planning `LogicalPlan`s, and `ExecutionPlan` - query_planner: Arc, - /// Collection of catalogs containing schemas and ultimately TableProviders - catalog_list: Arc, - /// Table Functions - table_functions: HashMap>, - /// Scalar functions that are registered with the context - scalar_functions: HashMap>, - /// Aggregate functions registered in the context - aggregate_functions: HashMap>, - /// Window functions registered in the context - window_functions: HashMap>, - /// Deserializer registry for extensions. - serializer_registry: Arc, - /// Session configuration - config: SessionConfig, - /// Table options - table_option_namespace: TableOptions, - /// Execution properties - execution_props: ExecutionProps, - /// TableProviderFactories for different file formats. - /// - /// Maps strings like "JSON" to an instance of [`TableProviderFactory`] - /// - /// This is used to create [`TableProvider`] instances for the - /// `CREATE EXTERNAL TABLE ... STORED AS ` for custom file - /// formats other than those built into DataFusion - table_factories: HashMap>, - /// Runtime environment - runtime_env: Arc, - - /// [FunctionFactory] to support pluggable user defined function handler. - /// - /// It will be invoked on `CREATE FUNCTION` statements. - /// thus, changing dialect o PostgreSql is required - function_factory: Option>, -} - -impl Debug for SessionState { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("SessionState") - .field("session_id", &self.session_id) - // TODO should we print out more? - .finish() - } -} - -impl SessionState { - /// Returns new [`SessionState`] using the provided - /// [`SessionConfig`] and [`RuntimeEnv`]. - pub fn new_with_config_rt(config: SessionConfig, runtime: Arc) -> Self { - let catalog_list = - Arc::new(MemoryCatalogProviderList::new()) as Arc; - Self::new_with_config_rt_and_catalog_list(config, runtime, catalog_list) - } - - /// Returns new [`SessionState`] using the provided - /// [`SessionConfig`] and [`RuntimeEnv`]. - #[deprecated(since = "32.0.0", note = "Use SessionState::new_with_config_rt")] - pub fn with_config_rt(config: SessionConfig, runtime: Arc) -> Self { - Self::new_with_config_rt(config, runtime) - } - - /// Returns new [`SessionState`] using the provided - /// [`SessionConfig`], [`RuntimeEnv`], and [`CatalogProviderList`] - pub fn new_with_config_rt_and_catalog_list( - config: SessionConfig, - runtime: Arc, - catalog_list: Arc, - ) -> Self { - let session_id = Uuid::new_v4().to_string(); - - // Create table_factories for all default formats - let mut table_factories: HashMap> = - HashMap::new(); - #[cfg(feature = "parquet")] - table_factories.insert("PARQUET".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("CSV".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("JSON".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("NDJSON".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("AVRO".into(), Arc::new(DefaultTableFactory::new())); - table_factories.insert("ARROW".into(), Arc::new(DefaultTableFactory::new())); - - if config.create_default_catalog_and_schema() { - let default_catalog = MemoryCatalogProvider::new(); - - default_catalog - .register_schema( - &config.options().catalog.default_schema, - Arc::new(MemorySchemaProvider::new()), - ) - .expect("memory catalog provider can register schema"); - - Self::register_default_schema( - &config, - &table_factories, - &runtime, - &default_catalog, - ); - - catalog_list.register_catalog( - config.options().catalog.default_catalog.clone(), - Arc::new(default_catalog), - ); - } - - let mut new_self = SessionState { - session_id, - analyzer: Analyzer::new(), - optimizer: Optimizer::new(), - physical_optimizers: PhysicalOptimizer::new(), - query_planner: Arc::new(DefaultQueryPlanner {}), - catalog_list, - table_functions: HashMap::new(), - scalar_functions: HashMap::new(), - aggregate_functions: HashMap::new(), - window_functions: HashMap::new(), - serializer_registry: Arc::new(EmptySerializerRegistry), - table_option_namespace: TableOptions::default_from_session_config( - config.options(), - ), - config, - execution_props: ExecutionProps::new(), - runtime_env: runtime, - table_factories, - function_factory: None, - }; - - // register built in functions - functions::register_all(&mut new_self) - .expect("can not register built in functions"); - - // register crate of array expressions (if enabled) - #[cfg(feature = "array_expressions")] - functions_array::register_all(&mut new_self) - .expect("can not register array expressions"); - - functions_aggregate::register_all(&mut new_self) - .expect("can not register aggregate functions"); - - new_self - } - /// Returns new [`SessionState`] using the provided - /// [`SessionConfig`] and [`RuntimeEnv`]. - #[deprecated( - since = "32.0.0", - note = "Use SessionState::new_with_config_rt_and_catalog_list" - )] - pub fn with_config_rt_and_catalog_list( - config: SessionConfig, - runtime: Arc, - catalog_list: Arc, - ) -> Self { - Self::new_with_config_rt_and_catalog_list(config, runtime, catalog_list) - } - fn register_default_schema( - config: &SessionConfig, - table_factories: &HashMap>, - runtime: &Arc, - default_catalog: &MemoryCatalogProvider, - ) { - let url = config.options().catalog.location.as_ref(); - let format = config.options().catalog.format.as_ref(); - let (url, format) = match (url, format) { - (Some(url), Some(format)) => (url, format), - _ => return, - }; - let url = url.to_string(); - let format = format.to_string(); - - let has_header = config.options().catalog.has_header; - let url = Url::parse(url.as_str()).expect("Invalid default catalog location!"); - let authority = match url.host_str() { - Some(host) => format!("{}://{}", url.scheme(), host), - None => format!("{}://", url.scheme()), - }; - let path = &url.as_str()[authority.len()..]; - let path = object_store::path::Path::parse(path).expect("Can't parse path"); - let store = ObjectStoreUrl::parse(authority.as_str()) - .expect("Invalid default catalog url"); - let store = match runtime.object_store(store) { - Ok(store) => store, - _ => return, - }; - let factory = match table_factories.get(format.as_str()) { - Some(factory) => factory, - _ => return, - }; - let schema = ListingSchemaProvider::new( - authority, - path, - factory.clone(), - store, - format, - has_header, - ); - let _ = default_catalog - .register_schema("default", Arc::new(schema)) - .expect("Failed to register default schema"); - } - - fn resolve_table_ref( - &self, - table_ref: impl Into, - ) -> ResolvedTableReference { - let catalog = &self.config_options().catalog; - table_ref - .into() - .resolve(&catalog.default_catalog, &catalog.default_schema) - } - - pub(crate) fn schema_for_ref( - &self, - table_ref: impl Into, - ) -> Result> { - let resolved_ref = self.resolve_table_ref(table_ref); - if self.config.information_schema() && *resolved_ref.schema == *INFORMATION_SCHEMA - { - return Ok(Arc::new(InformationSchemaProvider::new( - self.catalog_list.clone(), - ))); - } - - self.catalog_list - .catalog(&resolved_ref.catalog) - .ok_or_else(|| { - plan_datafusion_err!( - "failed to resolve catalog: {}", - resolved_ref.catalog - ) - })? - .schema(&resolved_ref.schema) - .ok_or_else(|| { - plan_datafusion_err!("failed to resolve schema: {}", resolved_ref.schema) - }) - } - - /// Replace the random session id. - pub fn with_session_id(mut self, session_id: String) -> Self { - self.session_id = session_id; - self - } - - /// override default query planner with `query_planner` - pub fn with_query_planner( - mut self, - query_planner: Arc, - ) -> Self { - self.query_planner = query_planner; - self - } - - /// Override the [`AnalyzerRule`]s optimizer plan rules. - pub fn with_analyzer_rules( - mut self, - rules: Vec>, - ) -> Self { - self.analyzer = Analyzer::with_rules(rules); - self - } - - /// Replace the entire list of [`OptimizerRule`]s used to optimize plans - pub fn with_optimizer_rules( - mut self, - rules: Vec>, - ) -> Self { - self.optimizer = Optimizer::with_rules(rules); - self - } - - /// Replace the entire list of [`PhysicalOptimizerRule`]s used to optimize plans - pub fn with_physical_optimizer_rules( - mut self, - physical_optimizers: Vec>, - ) -> Self { - self.physical_optimizers = PhysicalOptimizer::with_rules(physical_optimizers); - self - } - - /// Add `analyzer_rule` to the end of the list of - /// [`AnalyzerRule`]s used to rewrite queries. - pub fn add_analyzer_rule( - mut self, - analyzer_rule: Arc, - ) -> Self { - self.analyzer.rules.push(analyzer_rule); - self - } - - /// Add `optimizer_rule` to the end of the list of - /// [`OptimizerRule`]s used to rewrite queries. - pub fn add_optimizer_rule( - mut self, - optimizer_rule: Arc, - ) -> Self { - self.optimizer.rules.push(optimizer_rule); - self - } - - /// Add `physical_optimizer_rule` to the end of the list of - /// [`PhysicalOptimizerRule`]s used to rewrite queries. - pub fn add_physical_optimizer_rule( - mut self, - physical_optimizer_rule: Arc, - ) -> Self { - self.physical_optimizers.rules.push(physical_optimizer_rule); - self - } - - /// Adds a new [`ConfigExtension`] to TableOptions - pub fn add_table_options_extension( - mut self, - extension: T, - ) -> Self { - self.table_option_namespace.extensions.insert(extension); - self - } - - /// Registers a [`FunctionFactory`] to handle `CREATE FUNCTION` statements - pub fn with_function_factory( - mut self, - function_factory: Arc, - ) -> Self { - self.function_factory = Some(function_factory); - self - } - - /// Registers a [`FunctionFactory`] to handle `CREATE FUNCTION` statements - pub fn set_function_factory(&mut self, function_factory: Arc) { - self.function_factory = Some(function_factory); - } - - /// Replace the extension [`SerializerRegistry`] - pub fn with_serializer_registry( - mut self, - registry: Arc, - ) -> Self { - self.serializer_registry = registry; - self - } - - /// Get the table factories - pub fn table_factories(&self) -> &HashMap> { - &self.table_factories - } - - /// Get the table factories - pub fn table_factories_mut( - &mut self, - ) -> &mut HashMap> { - &mut self.table_factories - } - - /// Parse an SQL string into an DataFusion specific AST - /// [`Statement`]. See [`SessionContext::sql`] for running queries. - pub fn sql_to_statement( - &self, - sql: &str, - dialect: &str, - ) -> Result { - let dialect = dialect_from_str(dialect).ok_or_else(|| { - plan_datafusion_err!( - "Unsupported SQL dialect: {dialect}. Available dialects: \ - Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \ - MsSQL, ClickHouse, BigQuery, Ansi." - ) - })?; - let mut statements = DFParser::parse_sql_with_dialect(sql, dialect.as_ref())?; - if statements.len() > 1 { - return not_impl_err!( - "The context currently only supports a single SQL statement" - ); - } - let statement = statements.pop_front().ok_or_else(|| { - DataFusionError::NotImplemented( - "The context requires a statement!".to_string(), - ) - })?; - Ok(statement) - } - - /// Resolve all table references in the SQL statement. - pub fn resolve_table_references( - &self, - statement: &datafusion_sql::parser::Statement, - ) -> Result> { - use crate::catalog::information_schema::INFORMATION_SCHEMA_TABLES; - use datafusion_sql::parser::Statement as DFStatement; - use sqlparser::ast::*; - - // Getting `TableProviders` is async but planing is not -- thus pre-fetch - // table providers for all relations referenced in this query - let mut relations = hashbrown::HashSet::with_capacity(10); - - struct RelationVisitor<'a>(&'a mut hashbrown::HashSet); - - impl<'a> RelationVisitor<'a> { - /// Record that `relation` was used in this statement - fn insert(&mut self, relation: &ObjectName) { - self.0.get_or_insert_with(relation, |_| relation.clone()); - } - } - - impl<'a> Visitor for RelationVisitor<'a> { - type Break = (); - - fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<()> { - self.insert(relation); - ControlFlow::Continue(()) - } - - fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<()> { - if let Statement::ShowCreate { - obj_type: ShowCreateObject::Table | ShowCreateObject::View, - obj_name, - } = statement - { - self.insert(obj_name) - } - ControlFlow::Continue(()) - } - } - - let mut visitor = RelationVisitor(&mut relations); - fn visit_statement(statement: &DFStatement, visitor: &mut RelationVisitor<'_>) { - match statement { - DFStatement::Statement(s) => { - let _ = s.as_ref().visit(visitor); - } - DFStatement::CreateExternalTable(table) => { - visitor - .0 - .insert(ObjectName(vec![Ident::from(table.name.as_str())])); - } - DFStatement::CopyTo(CopyToStatement { source, .. }) => match source { - CopyToSource::Relation(table_name) => { - visitor.insert(table_name); - } - CopyToSource::Query(query) => { - query.visit(visitor); - } - }, - DFStatement::Explain(explain) => { - visit_statement(&explain.statement, visitor) - } - } - } - - visit_statement(statement, &mut visitor); - - // Always include information_schema if available - if self.config.information_schema() { - for s in INFORMATION_SCHEMA_TABLES { - relations.insert(ObjectName(vec![ - Ident::new(INFORMATION_SCHEMA), - Ident::new(*s), - ])); - } - } - - let enable_ident_normalization = - self.config.options().sql_parser.enable_ident_normalization; - relations - .into_iter() - .map(|x| object_name_to_table_reference(x, enable_ident_normalization)) - .collect::>() - } - - /// Convert an AST Statement into a LogicalPlan - pub async fn statement_to_plan( - &self, - statement: datafusion_sql::parser::Statement, - ) -> Result { - let references = self.resolve_table_references(&statement)?; - - let mut provider = SessionContextProvider { - state: self, - tables: HashMap::with_capacity(references.len()), - }; - - let enable_ident_normalization = - self.config.options().sql_parser.enable_ident_normalization; - let parse_float_as_decimal = - self.config.options().sql_parser.parse_float_as_decimal; - for reference in references { - let resolved = &self.resolve_table_ref(reference); - if let Entry::Vacant(v) = provider.tables.entry(resolved.to_string()) { - if let Ok(schema) = self.schema_for_ref(resolved.clone()) { - if let Some(table) = schema.table(&resolved.table).await? { - v.insert(provider_as_source(table)); - } - } - } - } - - let query = SqlToRel::new_with_options( - &provider, - ParserOptions { - parse_float_as_decimal, - enable_ident_normalization, - }, - ); - query.statement_to_plan(statement) - } - - /// Creates a [`LogicalPlan`] from the provided SQL string. This - /// interface will plan any SQL DataFusion supports, including DML - /// like `CREATE TABLE`, and `COPY` (which can write to local - /// files. - /// - /// See [`SessionContext::sql`] and - /// [`SessionContext::sql_with_options`] for a higher-level - /// interface that handles DDL and verification of allowed - /// statements. - pub async fn create_logical_plan(&self, sql: &str) -> Result { - let dialect = self.config.options().sql_parser.dialect.as_str(); - let statement = self.sql_to_statement(sql, dialect)?; - let plan = self.statement_to_plan(statement).await?; - Ok(plan) - } - - /// Optimizes the logical plan by applying optimizer rules. - pub fn optimize(&self, plan: &LogicalPlan) -> Result { - if let LogicalPlan::Explain(e) = plan { - let mut stringified_plans = e.stringified_plans.clone(); - - // analyze & capture output of each rule - let analyzer_result = self.analyzer.execute_and_check( - e.plan.as_ref(), - self.options(), - |analyzed_plan, analyzer| { - let analyzer_name = analyzer.name().to_string(); - let plan_type = PlanType::AnalyzedLogicalPlan { analyzer_name }; - stringified_plans.push(analyzed_plan.to_stringified(plan_type)); - }, - ); - let analyzed_plan = match analyzer_result { - Ok(plan) => plan, - Err(DataFusionError::Context(analyzer_name, err)) => { - let plan_type = PlanType::AnalyzedLogicalPlan { analyzer_name }; - stringified_plans - .push(StringifiedPlan::new(plan_type, err.to_string())); - - return Ok(LogicalPlan::Explain(Explain { - verbose: e.verbose, - plan: e.plan.clone(), - stringified_plans, - schema: e.schema.clone(), - logical_optimization_succeeded: false, - })); - } - Err(e) => return Err(e), - }; - - // to delineate the analyzer & optimizer phases in explain output - stringified_plans - .push(analyzed_plan.to_stringified(PlanType::FinalAnalyzedLogicalPlan)); - - // optimize the child plan, capturing the output of each optimizer - let optimized_plan = self.optimizer.optimize( - analyzed_plan, - self, - |optimized_plan, optimizer| { - let optimizer_name = optimizer.name().to_string(); - let plan_type = PlanType::OptimizedLogicalPlan { optimizer_name }; - stringified_plans.push(optimized_plan.to_stringified(plan_type)); - }, - ); - let (plan, logical_optimization_succeeded) = match optimized_plan { - Ok(plan) => (Arc::new(plan), true), - Err(DataFusionError::Context(optimizer_name, err)) => { - let plan_type = PlanType::OptimizedLogicalPlan { optimizer_name }; - stringified_plans - .push(StringifiedPlan::new(plan_type, err.to_string())); - (e.plan.clone(), false) - } - Err(e) => return Err(e), - }; - - Ok(LogicalPlan::Explain(Explain { - verbose: e.verbose, - plan, - stringified_plans, - schema: e.schema.clone(), - logical_optimization_succeeded, - })) - } else { - let analyzed_plan = - self.analyzer - .execute_and_check(plan, self.options(), |_, _| {})?; - self.optimizer.optimize(analyzed_plan, self, |_, _| {}) - } - } - - /// Creates a physical plan from a logical plan. - /// - /// Note: this first calls [`Self::optimize`] on the provided - /// plan. - /// - /// This function will error for [`LogicalPlan`]s such as catalog - /// DDL `CREATE TABLE` must be handled by another layer. - pub async fn create_physical_plan( - &self, - logical_plan: &LogicalPlan, - ) -> Result> { - let logical_plan = self.optimize(logical_plan)?; - self.query_planner - .create_physical_plan(&logical_plan, self) - .await - } - - /// Return the session ID - pub fn session_id(&self) -> &str { - &self.session_id - } - - /// Return the runtime env - pub fn runtime_env(&self) -> &Arc { - &self.runtime_env - } - - /// Return the execution properties - pub fn execution_props(&self) -> &ExecutionProps { - &self.execution_props - } - - /// Return the [`SessionConfig`] - pub fn config(&self) -> &SessionConfig { - &self.config - } - - /// Return the mutable [`SessionConfig`]. - pub fn config_mut(&mut self) -> &mut SessionConfig { - &mut self.config - } - - /// Return the physical optimizers - pub fn physical_optimizers(&self) -> &[Arc] { - &self.physical_optimizers.rules - } - - /// return the configuration options - pub fn config_options(&self) -> &ConfigOptions { - self.config.options() - } - - /// return the TableOptions options with its extensions - pub fn default_table_options(&self) -> TableOptions { - self.table_option_namespace - .combine_with_session_config(self.config_options()) - } - - /// Get a new TaskContext to run in this session - pub fn task_ctx(&self) -> Arc { - Arc::new(TaskContext::from(self)) - } - - /// Return catalog list - pub fn catalog_list(&self) -> Arc { - self.catalog_list.clone() - } - - /// Return reference to scalar_functions - pub fn scalar_functions(&self) -> &HashMap> { - &self.scalar_functions - } - - /// Return reference to aggregate_functions - pub fn aggregate_functions(&self) -> &HashMap> { - &self.aggregate_functions - } - - /// Return reference to window functions - pub fn window_functions(&self) -> &HashMap> { - &self.window_functions - } - - /// Return [SerializerRegistry] for extensions - pub fn serializer_registry(&self) -> Arc { - self.serializer_registry.clone() - } - - /// Return version of the cargo package that produced this query - pub fn version(&self) -> &str { - env!("CARGO_PKG_VERSION") - } -} - -struct SessionContextProvider<'a> { - state: &'a SessionState, - tables: HashMap>, -} - -impl<'a> ContextProvider for SessionContextProvider<'a> { - fn get_table_source(&self, name: TableReference) -> Result> { - let name = self.state.resolve_table_ref(name).to_string(); - self.tables - .get(&name) - .cloned() - .ok_or_else(|| plan_datafusion_err!("table '{name}' not found")) - } - - fn get_table_function_source( - &self, - name: &str, - args: Vec, - ) -> Result> { - let tbl_func = self - .state - .table_functions - .get(name) - .cloned() - .ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?; - let provider = tbl_func.create_table_provider(&args)?; - - Ok(provider_as_source(provider)) - } - - /// Create a new CTE work table for a recursive CTE logical plan - /// This table will be used in conjunction with a Worktable physical plan - /// to read and write each iteration of a recursive CTE - fn create_cte_work_table( - &self, - name: &str, - schema: SchemaRef, - ) -> Result> { - let table = Arc::new(CteWorkTable::new(name, schema)); - Ok(provider_as_source(table)) - } - - fn get_function_meta(&self, name: &str) -> Option> { - self.state.scalar_functions().get(name).cloned() - } - - fn get_aggregate_meta(&self, name: &str) -> Option> { - self.state.aggregate_functions().get(name).cloned() - } - - fn get_window_meta(&self, name: &str) -> Option> { - self.state.window_functions().get(name).cloned() - } - - fn get_variable_type(&self, variable_names: &[String]) -> Option { - if variable_names.is_empty() { - return None; - } - - let provider_type = if is_system_variables(variable_names) { - VarType::System - } else { - VarType::UserDefined - }; - - self.state - .execution_props - .var_providers - .as_ref() - .and_then(|provider| provider.get(&provider_type)?.get_type(variable_names)) - } - - fn options(&self) -> &ConfigOptions { - self.state.config_options() - } - - fn udfs_names(&self) -> Vec { - self.state.scalar_functions().keys().cloned().collect() - } - - fn udafs_names(&self) -> Vec { - self.state.aggregate_functions().keys().cloned().collect() - } - - fn udwfs_names(&self) -> Vec { - self.state.window_functions().keys().cloned().collect() - } -} - -impl FunctionRegistry for SessionState { - fn udfs(&self) -> HashSet { - self.scalar_functions.keys().cloned().collect() - } - - fn udf(&self, name: &str) -> Result> { - let result = self.scalar_functions.get(name); - - result.cloned().ok_or_else(|| { - plan_datafusion_err!("There is no UDF named \"{name}\" in the registry") - }) - } - - fn udaf(&self, name: &str) -> Result> { - let result = self.aggregate_functions.get(name); - - result.cloned().ok_or_else(|| { - plan_datafusion_err!("There is no UDAF named \"{name}\" in the registry") - }) - } - - fn udwf(&self, name: &str) -> Result> { - let result = self.window_functions.get(name); - - result.cloned().ok_or_else(|| { - plan_datafusion_err!("There is no UDWF named \"{name}\" in the registry") - }) - } - - fn register_udf(&mut self, udf: Arc) -> Result>> { - udf.aliases().iter().for_each(|alias| { - self.scalar_functions.insert(alias.clone(), udf.clone()); - }); - Ok(self.scalar_functions.insert(udf.name().into(), udf)) - } - - fn register_udaf( - &mut self, - udaf: Arc, - ) -> Result>> { - udaf.aliases().iter().for_each(|alias| { - self.aggregate_functions.insert(alias.clone(), udaf.clone()); - }); - Ok(self.aggregate_functions.insert(udaf.name().into(), udaf)) - } - - fn register_udwf(&mut self, udwf: Arc) -> Result>> { - udwf.aliases().iter().for_each(|alias| { - self.window_functions.insert(alias.clone(), udwf.clone()); - }); - Ok(self.window_functions.insert(udwf.name().into(), udwf)) - } - - fn deregister_udf(&mut self, name: &str) -> Result>> { - let udf = self.scalar_functions.remove(name); - if let Some(udf) = &udf { - for alias in udf.aliases() { - self.scalar_functions.remove(alias); - } - } - Ok(udf) - } - - fn deregister_udaf(&mut self, name: &str) -> Result>> { - let udaf = self.aggregate_functions.remove(name); - if let Some(udaf) = &udaf { - for alias in udaf.aliases() { - self.aggregate_functions.remove(alias); - } - } - Ok(udaf) - } - - fn deregister_udwf(&mut self, name: &str) -> Result>> { - let udwf = self.window_functions.remove(name); - if let Some(udwf) = &udwf { - for alias in udwf.aliases() { - self.window_functions.remove(alias); - } - } - Ok(udwf) - } - - fn register_function_rewrite( - &mut self, - rewrite: Arc, - ) -> Result<()> { - self.analyzer.add_function_rewrite(rewrite); - Ok(()) - } -} - -impl OptimizerConfig for SessionState { - fn query_execution_start_time(&self) -> DateTime { - self.execution_props.query_execution_start_time - } - - fn alias_generator(&self) -> Arc { - self.execution_props.alias_generator.clone() - } - - fn options(&self) -> &ConfigOptions { - self.config_options() - } -} - -/// Create a new task context instance from SessionContext -impl From<&SessionContext> for TaskContext { - fn from(session: &SessionContext) -> Self { - TaskContext::from(&*session.state.read()) - } -} - -/// Create a new task context instance from SessionState -impl From<&SessionState> for TaskContext { - fn from(state: &SessionState) -> Self { - let task_id = None; - TaskContext::new( - task_id, - state.session_id.clone(), - state.config.clone(), - state.scalar_functions.clone(), - state.aggregate_functions.clone(), - state.window_functions.clone(), - state.runtime_env.clone(), - ) - } -} /// Default implementation of [SerializerRegistry] that throws unimplemented error /// for all requests. +#[derive(Debug)] pub struct EmptySerializerRegistry; impl SerializerRegistry for EmptySerializerRegistry { @@ -2297,13 +1653,13 @@ impl SQLOptions { Default::default() } - /// Should DML data modification commands (e.g. `INSERT and COPY`) be run? Defaults to `true`. + /// Should DDL data definition commands (e.g. `CREATE TABLE`) be run? Defaults to `true`. pub fn with_allow_ddl(mut self, allow: bool) -> Self { self.allow_ddl = allow; self } - /// Should DML data modification commands (e.g. `INSERT and COPY`) be run? Defaults to `true` + /// Should DML data modification commands (e.g. `INSERT` and `COPY`) be run? Defaults to `true` pub fn with_allow_dml(mut self, allow: bool) -> Self { self.allow_dml = allow; self @@ -2332,10 +1688,10 @@ impl<'a> BadPlanVisitor<'a> { } } -impl<'a> TreeNodeVisitor for BadPlanVisitor<'a> { +impl<'n, 'a> TreeNodeVisitor<'n> for BadPlanVisitor<'a> { type Node = LogicalPlan; - fn f_down(&mut self, node: &Self::Node) -> Result { + fn f_down(&mut self, node: &'n Self::Node) -> Result { match node { LogicalPlan::Ddl(ddl) if !self.options.allow_ddl => { plan_err!("DDL not supported: {}", ddl.name()) @@ -2362,12 +1718,15 @@ mod tests { use super::{super::options::CsvReadOptions, *}; use crate::assert_batches_eq; use crate::execution::memory_pool::MemoryConsumer; - use crate::execution::runtime_env::RuntimeConfig; + use crate::execution::runtime_env::RuntimeEnvBuilder; use crate::test; use crate::test_util::{plan_and_collect, populate_csv_partitions}; use datafusion_common_runtime::SpawnedTask; + use crate::catalog::SchemaProvider; + use crate::execution::session_state::SessionStateBuilder; + use crate::physical_planner::PhysicalPlanner; use async_trait::async_trait; use tempfile::TempDir; @@ -2494,13 +1853,16 @@ mod tests { let path = path.join("tests/tpch-csv"); let url = format!("file://{}", path.display()); - let rt_cfg = RuntimeConfig::new(); - let runtime = Arc::new(RuntimeEnv::new(rt_cfg).unwrap()); + let runtime = RuntimeEnvBuilder::new().build_arc()?; let cfg = SessionConfig::new() .set_str("datafusion.catalog.location", url.as_str()) .set_str("datafusion.catalog.format", "CSV") .set_str("datafusion.catalog.has_header", "true"); - let session_state = SessionState::new_with_config_rt(cfg, runtime); + let session_state = SessionStateBuilder::new() + .with_config(cfg) + .with_runtime_env(runtime) + .with_default_features() + .build(); let ctx = SessionContext::new_with_state(session_state); ctx.refresh_catalogs().await?; @@ -2523,12 +1885,47 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_dynamic_file_query() -> Result<()> { + let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let path = path.join("tests/tpch-csv/customer.csv"); + let url = format!("file://{}", path.display()); + let cfg = SessionConfig::new(); + let session_state = SessionStateBuilder::new() + .with_default_features() + .with_config(cfg) + .build(); + let ctx = SessionContext::new_with_state(session_state).enable_url_table(); + let result = plan_and_collect( + &ctx, + format!("select c_name from '{}' limit 3;", &url).as_str(), + ) + .await?; + + let actual = arrow::util::pretty::pretty_format_batches(&result) + .unwrap() + .to_string(); + let expected = r#"+--------------------+ +| c_name | ++--------------------+ +| Customer#000000002 | +| Customer#000000003 | +| Customer#000000004 | ++--------------------+"#; + assert_eq!(actual, expected); + + Ok(()) + } + #[tokio::test] async fn custom_query_planner() -> Result<()> { let runtime = Arc::new(RuntimeEnv::default()); - let session_state = - SessionState::new_with_config_rt(SessionConfig::new(), runtime) - .with_query_planner(Arc::new(MyQueryPlanner {})); + let session_state = SessionStateBuilder::new() + .with_config(SessionConfig::new()) + .with_runtime_env(runtime) + .with_default_features() + .with_query_planner(Arc::new(MyQueryPlanner {})) + .build(); let ctx = SessionContext::new_with_state(session_state); let df = ctx.sql("SELECT 1").await?; @@ -2669,7 +2066,7 @@ mod tests { let catalog_list_weak = { let state = ctx.state.read(); - Arc::downgrade(&state.catalog_list) + Arc::downgrade(state.catalog_list()) }; drop(ctx); @@ -2742,13 +2139,14 @@ mod tests { fn create_physical_expr( &self, _expr: &Expr, - _input_dfschema: &crate::common::DFSchema, + _input_dfschema: &DFSchema, _session_state: &SessionState, - ) -> Result> { + ) -> Result> { unimplemented!() } } + #[derive(Debug)] struct MyQueryPlanner {} #[async_trait] diff --git a/datafusion/core/src/execution/context/parquet.rs b/datafusion/core/src/execution/context/parquet.rs index f7ab15d95baa..3f23c150be83 100644 --- a/datafusion/core/src/execution/context/parquet.rs +++ b/datafusion/core/src/execution/context/parquet.rs @@ -21,6 +21,7 @@ use super::super::options::{ParquetReadOptions, ReadOptions}; use super::{DataFilePaths, DataFrame, ExecutionPlan, Result, SessionContext}; use crate::datasource::physical_plan::parquet::plan_to_parquet; +use datafusion_common::TableReference; use parquet::file::properties::WriterProperties; impl SessionContext { @@ -42,15 +43,15 @@ impl SessionContext { /// statements executed against this context. pub async fn register_parquet( &self, - name: &str, - table_path: &str, + table_ref: impl Into, + table_path: impl AsRef, options: ParquetReadOptions<'_>, ) -> Result<()> { let listing_options = options .to_listing_options(&self.copied_config(), self.copied_table_options()); self.register_listing_table( - name, + table_ref, table_path, listing_options, options.schema.map(|s| Arc::new(s.to_owned())), @@ -84,7 +85,6 @@ mod tests { use datafusion_common::config::TableParquetOptions; use datafusion_execution::config::SessionConfig; - use async_trait::async_trait; use tempfile::tempdir; #[tokio::test] @@ -107,7 +107,7 @@ mod tests { #[tokio::test] async fn read_with_glob_path_issue_2465() -> Result<()> { let config = - SessionConfig::from_string_hash_map(std::collections::HashMap::from([( + SessionConfig::from_string_hash_map(&std::collections::HashMap::from([( "datafusion.execution.listing_table_ignore_subdirectory".to_owned(), "false".to_owned(), )]))?; @@ -331,23 +331,4 @@ mod tests { assert_eq!(total_rows, 5); Ok(()) } - - // Test for compilation error when calling read_* functions from an #[async_trait] function. - // See https://github.com/apache/datafusion/issues/1154 - #[async_trait] - trait CallReadTrait { - async fn call_read_parquet(&self) -> DataFrame; - } - - struct CallRead {} - - #[async_trait] - impl CallReadTrait for CallRead { - async fn call_read_parquet(&self) -> DataFrame { - let ctx = SessionContext::new(); - ctx.read_parquet("dummy", ParquetReadOptions::default()) - .await - .unwrap() - } - } } diff --git a/datafusion/core/src/execution/mod.rs b/datafusion/core/src/execution/mod.rs index 7e757fabac8e..10aa16ffe47a 100644 --- a/datafusion/core/src/execution/mod.rs +++ b/datafusion/core/src/execution/mod.rs @@ -18,6 +18,13 @@ //! Shared state for query planning and execution. pub mod context; +pub mod session_state; +pub use session_state::{SessionState, SessionStateBuilder}; + +mod session_state_defaults; + +pub use session_state_defaults::SessionStateDefaults; + // backwards compatibility pub use crate::datasource::file_format::options; diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs new file mode 100644 index 000000000000..d50c912dd2fd --- /dev/null +++ b/datafusion/core/src/execution/session_state.rs @@ -0,0 +1,2035 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`SessionState`]: information required to run queries in a session + +use crate::catalog::{CatalogProviderList, SchemaProvider, TableProviderFactory}; +use crate::catalog_common::information_schema::{ + InformationSchemaProvider, INFORMATION_SCHEMA, +}; +use crate::catalog_common::MemoryCatalogProviderList; +use crate::datasource::cte_worktable::CteWorkTable; +use crate::datasource::file_format::{format_as_file_type, FileFormatFactory}; +use crate::datasource::function::{TableFunction, TableFunctionImpl}; +use crate::datasource::provider_as_source; +use crate::execution::context::{EmptySerializerRegistry, FunctionFactory, QueryPlanner}; +use crate::execution::SessionStateDefaults; +use crate::physical_optimizer::optimizer::PhysicalOptimizer; +use crate::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}; +use arrow_schema::{DataType, SchemaRef}; +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use datafusion_catalog::Session; +use datafusion_common::alias::AliasGenerator; +use datafusion_common::config::{ConfigExtension, ConfigOptions, TableOptions}; +use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; +use datafusion_common::file_options::file_type::FileType; +use datafusion_common::tree_node::TreeNode; +use datafusion_common::{ + config_err, not_impl_err, plan_datafusion_err, DFSchema, DataFusionError, + ResolvedTableReference, TableReference, +}; +use datafusion_execution::config::SessionConfig; +use datafusion_execution::runtime_env::RuntimeEnv; +use datafusion_execution::TaskContext; +use datafusion_expr::execution_props::ExecutionProps; +use datafusion_expr::expr_rewriter::FunctionRewrite; +use datafusion_expr::planner::ExprPlanner; +use datafusion_expr::registry::{FunctionRegistry, SerializerRegistry}; +use datafusion_expr::simplify::SimplifyInfo; +use datafusion_expr::var_provider::{is_system_variables, VarType}; +use datafusion_expr::{ + AggregateUDF, Explain, Expr, ExprSchemable, LogicalPlan, ScalarUDF, TableSource, + WindowUDF, +}; +use datafusion_optimizer::simplify_expressions::ExprSimplifier; +use datafusion_optimizer::{ + Analyzer, AnalyzerRule, Optimizer, OptimizerConfig, OptimizerRule, +}; +use datafusion_physical_expr::create_physical_expr; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_plan::ExecutionPlan; +use datafusion_sql::parser::{DFParser, Statement}; +use datafusion_sql::planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel}; +use itertools::Itertools; +use log::{debug, info}; +use object_store::ObjectStore; +use sqlparser::ast::Expr as SQLExpr; +use sqlparser::dialect::dialect_from_str; +use std::any::Any; +use std::collections::hash_map::Entry; +use std::collections::{HashMap, HashSet}; +use std::fmt::Debug; +use std::sync::Arc; +use url::Url; +use uuid::Uuid; + +/// `SessionState` contains all the necessary state to plan and execute queries, +/// such as configuration, functions, and runtime environment. Please see the +/// documentation on [`SessionContext`] for more information. +/// +/// +/// # Example: `SessionState` from a [`SessionContext`] +/// +/// ``` +/// use datafusion::prelude::*; +/// let ctx = SessionContext::new(); +/// let state = ctx.state(); +/// ``` +/// +/// # Example: `SessionState` via [`SessionStateBuilder`] +/// +/// You can also use [`SessionStateBuilder`] to build a `SessionState` object +/// directly: +/// +/// ``` +/// use datafusion::prelude::*; +/// # use datafusion::{error::Result, assert_batches_eq}; +/// # use datafusion::execution::session_state::SessionStateBuilder; +/// # use datafusion_execution::runtime_env::RuntimeEnv; +/// # use std::sync::Arc; +/// # #[tokio::main] +/// # async fn main() -> Result<()> { +/// let state = SessionStateBuilder::new() +/// .with_config(SessionConfig::new()) +/// .with_runtime_env(Arc::new(RuntimeEnv::default())) +/// .with_default_features() +/// .build(); +/// Ok(()) +/// # } +/// ``` +/// +/// Note that there is no `Default` or `new()` for SessionState, +/// to avoid accidentally running queries or other operations without passing through +/// the [`SessionConfig`] or [`RuntimeEnv`]. See [`SessionStateBuilder`] and +/// [`SessionContext`]. +/// +/// [`SessionContext`]: crate::execution::context::SessionContext +#[derive(Clone)] +pub struct SessionState { + /// A unique UUID that identifies the session + session_id: String, + /// Responsible for analyzing and rewrite a logical plan before optimization + analyzer: Analyzer, + /// Provides support for customising the SQL planner, e.g. to add support for custom operators like `->>` or `?` + expr_planners: Vec>, + /// Responsible for optimizing a logical plan + optimizer: Optimizer, + /// Responsible for optimizing a physical execution plan + physical_optimizers: PhysicalOptimizer, + /// Responsible for planning `LogicalPlan`s, and `ExecutionPlan` + query_planner: Arc, + /// Collection of catalogs containing schemas and ultimately TableProviders + catalog_list: Arc, + /// Table Functions + table_functions: HashMap>, + /// Scalar functions that are registered with the context + scalar_functions: HashMap>, + /// Aggregate functions registered in the context + aggregate_functions: HashMap>, + /// Window functions registered in the context + window_functions: HashMap>, + /// Deserializer registry for extensions. + serializer_registry: Arc, + /// Holds registered external FileFormat implementations + file_formats: HashMap>, + /// Session configuration + config: SessionConfig, + /// Table options + table_options: TableOptions, + /// Execution properties + execution_props: ExecutionProps, + /// TableProviderFactories for different file formats. + /// + /// Maps strings like "JSON" to an instance of [`TableProviderFactory`] + /// + /// This is used to create [`TableProvider`] instances for the + /// `CREATE EXTERNAL TABLE ... STORED AS ` for custom file + /// formats other than those built into DataFusion + /// + /// [`TableProvider`]: crate::catalog::TableProvider + table_factories: HashMap>, + /// Runtime environment + runtime_env: Arc, + /// [FunctionFactory] to support pluggable user defined function handler. + /// + /// It will be invoked on `CREATE FUNCTION` statements. + /// thus, changing dialect o PostgreSql is required + function_factory: Option>, +} + +impl Debug for SessionState { + /// Prefer having short fields at the top and long vector fields near the end + /// Group fields by + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SessionState") + .field("session_id", &self.session_id) + .field("config", &self.config) + .field("runtime_env", &self.runtime_env) + .field("catalog_list", &self.catalog_list) + .field("serializer_registry", &self.serializer_registry) + .field("file_formats", &self.file_formats) + .field("execution_props", &self.execution_props) + .field("table_options", &self.table_options) + .field("table_factories", &self.table_factories) + .field("function_factory", &self.function_factory) + .field("expr_planners", &self.expr_planners) + .field("query_planners", &self.query_planner) + .field("analyzer", &self.analyzer) + .field("optimizer", &self.optimizer) + .field("physical_optimizers", &self.physical_optimizers) + .field("table_functions", &self.table_functions) + .field("scalar_functions", &self.scalar_functions) + .field("aggregate_functions", &self.aggregate_functions) + .field("window_functions", &self.window_functions) + .finish() + } +} + +#[async_trait] +impl Session for SessionState { + fn session_id(&self) -> &str { + self.session_id() + } + + fn config(&self) -> &SessionConfig { + self.config() + } + + async fn create_physical_plan( + &self, + logical_plan: &LogicalPlan, + ) -> datafusion_common::Result> { + self.create_physical_plan(logical_plan).await + } + + fn create_physical_expr( + &self, + expr: Expr, + df_schema: &DFSchema, + ) -> datafusion_common::Result> { + self.create_physical_expr(expr, df_schema) + } + + fn scalar_functions(&self) -> &HashMap> { + self.scalar_functions() + } + + fn aggregate_functions(&self) -> &HashMap> { + self.aggregate_functions() + } + + fn window_functions(&self) -> &HashMap> { + self.window_functions() + } + + fn runtime_env(&self) -> &Arc { + self.runtime_env() + } + + fn execution_props(&self) -> &ExecutionProps { + self.execution_props() + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +impl SessionState { + /// Returns new [`SessionState`] using the provided + /// [`SessionConfig`] and [`RuntimeEnv`]. + #[deprecated(since = "41.0.0", note = "Use SessionStateBuilder")] + pub fn new_with_config_rt(config: SessionConfig, runtime: Arc) -> Self { + SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .build() + } + + /// Returns new [`SessionState`] using the provided + /// [`SessionConfig`], [`RuntimeEnv`], and [`CatalogProviderList`] + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] + pub fn new_with_config_rt_and_catalog_list( + config: SessionConfig, + runtime: Arc, + catalog_list: Arc, + ) -> Self { + SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_catalog_list(catalog_list) + .with_default_features() + .build() + } + + pub(crate) fn resolve_table_ref( + &self, + table_ref: impl Into, + ) -> ResolvedTableReference { + let catalog = &self.config_options().catalog; + table_ref + .into() + .resolve(&catalog.default_catalog, &catalog.default_schema) + } + + pub(crate) fn schema_for_ref( + &self, + table_ref: impl Into, + ) -> datafusion_common::Result> { + let resolved_ref = self.resolve_table_ref(table_ref); + if self.config.information_schema() && *resolved_ref.schema == *INFORMATION_SCHEMA + { + return Ok(Arc::new(InformationSchemaProvider::new( + self.catalog_list.clone(), + ))); + } + + self.catalog_list + .catalog(&resolved_ref.catalog) + .ok_or_else(|| { + plan_datafusion_err!( + "failed to resolve catalog: {}", + resolved_ref.catalog + ) + })? + .schema(&resolved_ref.schema) + .ok_or_else(|| { + plan_datafusion_err!("failed to resolve schema: {}", resolved_ref.schema) + }) + } + + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] + /// Replace the random session id. + pub fn with_session_id(mut self, session_id: String) -> Self { + self.session_id = session_id; + self + } + + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] + /// override default query planner with `query_planner` + pub fn with_query_planner( + mut self, + query_planner: Arc, + ) -> Self { + self.query_planner = query_planner; + self + } + + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] + /// Override the [`AnalyzerRule`]s optimizer plan rules. + pub fn with_analyzer_rules( + mut self, + rules: Vec>, + ) -> Self { + self.analyzer = Analyzer::with_rules(rules); + self + } + + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] + /// Replace the entire list of [`OptimizerRule`]s used to optimize plans + pub fn with_optimizer_rules( + mut self, + rules: Vec>, + ) -> Self { + self.optimizer = Optimizer::with_rules(rules); + self + } + + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] + /// Replace the entire list of [`PhysicalOptimizerRule`]s used to optimize plans + pub fn with_physical_optimizer_rules( + mut self, + physical_optimizers: Vec>, + ) -> Self { + self.physical_optimizers = PhysicalOptimizer::with_rules(physical_optimizers); + self + } + + /// Add `analyzer_rule` to the end of the list of + /// [`AnalyzerRule`]s used to rewrite queries. + pub fn add_analyzer_rule( + &mut self, + analyzer_rule: Arc, + ) -> &Self { + self.analyzer.rules.push(analyzer_rule); + self + } + + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] + /// Add `optimizer_rule` to the end of the list of + /// [`OptimizerRule`]s used to rewrite queries. + pub fn add_optimizer_rule( + mut self, + optimizer_rule: Arc, + ) -> Self { + self.optimizer.rules.push(optimizer_rule); + self + } + + // the add_optimizer_rule takes an owned reference + // it should probably be renamed to `with_optimizer_rule` to follow builder style + // and `add_optimizer_rule` that takes &mut self added instead of this + pub(crate) fn append_optimizer_rule( + &mut self, + optimizer_rule: Arc, + ) { + self.optimizer.rules.push(optimizer_rule); + } + + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] + /// Add `physical_optimizer_rule` to the end of the list of + /// [`PhysicalOptimizerRule`]s used to rewrite queries. + pub fn add_physical_optimizer_rule( + mut self, + physical_optimizer_rule: Arc, + ) -> Self { + self.physical_optimizers.rules.push(physical_optimizer_rule); + self + } + + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] + /// Adds a new [`ConfigExtension`] to TableOptions + pub fn add_table_options_extension( + mut self, + extension: T, + ) -> Self { + self.table_options.extensions.insert(extension); + self + } + + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] + /// Registers a [`FunctionFactory`] to handle `CREATE FUNCTION` statements + pub fn with_function_factory( + mut self, + function_factory: Arc, + ) -> Self { + self.function_factory = Some(function_factory); + self + } + + /// Registers a [`FunctionFactory`] to handle `CREATE FUNCTION` statements + pub fn set_function_factory(&mut self, function_factory: Arc) { + self.function_factory = Some(function_factory); + } + + #[deprecated(since = "40.0.0", note = "Use SessionStateBuilder")] + /// Replace the extension [`SerializerRegistry`] + pub fn with_serializer_registry( + mut self, + registry: Arc, + ) -> Self { + self.serializer_registry = registry; + self + } + + /// Get the function factory + pub fn function_factory(&self) -> Option<&Arc> { + self.function_factory.as_ref() + } + + /// Get the table factories + pub fn table_factories(&self) -> &HashMap> { + &self.table_factories + } + + /// Get the table factories + pub fn table_factories_mut( + &mut self, + ) -> &mut HashMap> { + &mut self.table_factories + } + + /// Parse an SQL string into an DataFusion specific AST + /// [`Statement`]. See [`SessionContext::sql`] for running queries. + /// + /// [`SessionContext::sql`]: crate::execution::context::SessionContext::sql + pub fn sql_to_statement( + &self, + sql: &str, + dialect: &str, + ) -> datafusion_common::Result { + let dialect = dialect_from_str(dialect).ok_or_else(|| { + plan_datafusion_err!( + "Unsupported SQL dialect: {dialect}. Available dialects: \ + Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \ + MsSQL, ClickHouse, BigQuery, Ansi." + ) + })?; + let mut statements = DFParser::parse_sql_with_dialect(sql, dialect.as_ref())?; + if statements.len() > 1 { + return not_impl_err!( + "The context currently only supports a single SQL statement" + ); + } + let statement = statements.pop_front().ok_or_else(|| { + plan_datafusion_err!("No SQL statements were provided in the query string") + })?; + Ok(statement) + } + + /// parse a sql string into a sqlparser-rs AST [`SQLExpr`]. + /// + /// See [`Self::create_logical_expr`] for parsing sql to [`Expr`]. + pub fn sql_to_expr( + &self, + sql: &str, + dialect: &str, + ) -> datafusion_common::Result { + let dialect = dialect_from_str(dialect).ok_or_else(|| { + plan_datafusion_err!( + "Unsupported SQL dialect: {dialect}. Available dialects: \ + Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \ + MsSQL, ClickHouse, BigQuery, Ansi." + ) + })?; + + let expr = DFParser::parse_sql_into_expr_with_dialect(sql, dialect.as_ref())?; + + Ok(expr) + } + + /// Resolve all table references in the SQL statement. Does not include CTE references. + /// + /// See [`catalog::resolve_table_references`] for more information. + /// + /// [`catalog::resolve_table_references`]: crate::catalog_common::resolve_table_references + pub fn resolve_table_references( + &self, + statement: &Statement, + ) -> datafusion_common::Result> { + let enable_ident_normalization = + self.config.options().sql_parser.enable_ident_normalization; + let (table_refs, _) = crate::catalog_common::resolve_table_references( + statement, + enable_ident_normalization, + )?; + Ok(table_refs) + } + + /// Convert an AST Statement into a LogicalPlan + pub async fn statement_to_plan( + &self, + statement: Statement, + ) -> datafusion_common::Result { + let references = self.resolve_table_references(&statement)?; + + let mut provider = SessionContextProvider { + state: self, + tables: HashMap::with_capacity(references.len()), + }; + + for reference in references { + let resolved = &self.resolve_table_ref(reference); + if let Entry::Vacant(v) = provider.tables.entry(resolved.to_string()) { + if let Ok(schema) = self.schema_for_ref(resolved.clone()) { + if let Some(table) = schema.table(&resolved.table).await? { + v.insert(provider_as_source(table)); + } + } + } + } + + let query = SqlToRel::new_with_options(&provider, self.get_parser_options()); + query.statement_to_plan(statement) + } + + fn get_parser_options(&self) -> ParserOptions { + let sql_parser_options = &self.config.options().sql_parser; + + ParserOptions { + parse_float_as_decimal: sql_parser_options.parse_float_as_decimal, + enable_ident_normalization: sql_parser_options.enable_ident_normalization, + enable_options_value_normalization: sql_parser_options + .enable_options_value_normalization, + support_varchar_with_length: sql_parser_options.support_varchar_with_length, + } + } + + /// Creates a [`LogicalPlan`] from the provided SQL string. This + /// interface will plan any SQL DataFusion supports, including DML + /// like `CREATE TABLE`, and `COPY` (which can write to local + /// files. + /// + /// See [`SessionContext::sql`] and + /// [`SessionContext::sql_with_options`] for a higher-level + /// interface that handles DDL and verification of allowed + /// statements. + /// + /// [`SessionContext::sql`]: crate::execution::context::SessionContext::sql + /// [`SessionContext::sql_with_options`]: crate::execution::context::SessionContext::sql_with_options + pub async fn create_logical_plan( + &self, + sql: &str, + ) -> datafusion_common::Result { + let dialect = self.config.options().sql_parser.dialect.as_str(); + let statement = self.sql_to_statement(sql, dialect)?; + let plan = self.statement_to_plan(statement).await?; + Ok(plan) + } + + /// Creates a datafusion style AST [`Expr`] from a SQL string. + /// + /// See example on [SessionContext::parse_sql_expr](crate::execution::context::SessionContext::parse_sql_expr) + pub fn create_logical_expr( + &self, + sql: &str, + df_schema: &DFSchema, + ) -> datafusion_common::Result { + let dialect = self.config.options().sql_parser.dialect.as_str(); + + let sql_expr = self.sql_to_expr(sql, dialect)?; + + let provider = SessionContextProvider { + state: self, + tables: HashMap::new(), + }; + + let query = SqlToRel::new_with_options(&provider, self.get_parser_options()); + query.sql_to_expr(sql_expr, df_schema, &mut PlannerContext::new()) + } + + /// Returns the [`Analyzer`] for this session + pub fn analyzer(&self) -> &Analyzer { + &self.analyzer + } + + /// Returns the [`Optimizer`] for this session + pub fn optimizer(&self) -> &Optimizer { + &self.optimizer + } + + /// Returns the [`QueryPlanner`] for this session + pub fn query_planner(&self) -> &Arc { + &self.query_planner + } + + /// Optimizes the logical plan by applying optimizer rules. + pub fn optimize(&self, plan: &LogicalPlan) -> datafusion_common::Result { + if let LogicalPlan::Explain(e) = plan { + let mut stringified_plans = e.stringified_plans.clone(); + + // analyze & capture output of each rule + let analyzer_result = self.analyzer.execute_and_check( + e.plan.as_ref().clone(), + self.options(), + |analyzed_plan, analyzer| { + let analyzer_name = analyzer.name().to_string(); + let plan_type = PlanType::AnalyzedLogicalPlan { analyzer_name }; + stringified_plans.push(analyzed_plan.to_stringified(plan_type)); + }, + ); + let analyzed_plan = match analyzer_result { + Ok(plan) => plan, + Err(DataFusionError::Context(analyzer_name, err)) => { + let plan_type = PlanType::AnalyzedLogicalPlan { analyzer_name }; + stringified_plans + .push(StringifiedPlan::new(plan_type, err.to_string())); + + return Ok(LogicalPlan::Explain(Explain { + verbose: e.verbose, + plan: e.plan.clone(), + stringified_plans, + schema: e.schema.clone(), + logical_optimization_succeeded: false, + })); + } + Err(e) => return Err(e), + }; + + // to delineate the analyzer & optimizer phases in explain output + stringified_plans + .push(analyzed_plan.to_stringified(PlanType::FinalAnalyzedLogicalPlan)); + + // optimize the child plan, capturing the output of each optimizer + let optimized_plan = self.optimizer.optimize( + analyzed_plan, + self, + |optimized_plan, optimizer| { + let optimizer_name = optimizer.name().to_string(); + let plan_type = PlanType::OptimizedLogicalPlan { optimizer_name }; + stringified_plans.push(optimized_plan.to_stringified(plan_type)); + }, + ); + let (plan, logical_optimization_succeeded) = match optimized_plan { + Ok(plan) => (Arc::new(plan), true), + Err(DataFusionError::Context(optimizer_name, err)) => { + let plan_type = PlanType::OptimizedLogicalPlan { optimizer_name }; + stringified_plans + .push(StringifiedPlan::new(plan_type, err.to_string())); + (e.plan.clone(), false) + } + Err(e) => return Err(e), + }; + + Ok(LogicalPlan::Explain(Explain { + verbose: e.verbose, + plan, + stringified_plans, + schema: e.schema.clone(), + logical_optimization_succeeded, + })) + } else { + let analyzed_plan = self.analyzer.execute_and_check( + plan.clone(), + self.options(), + |_, _| {}, + )?; + self.optimizer.optimize(analyzed_plan, self, |_, _| {}) + } + } + + /// Creates a physical [`ExecutionPlan`] plan from a [`LogicalPlan`]. + /// + /// Note: this first calls [`Self::optimize`] on the provided + /// plan. + /// + /// This function will error for [`LogicalPlan`]s such as catalog DDL like + /// `CREATE TABLE`, which do not have corresponding physical plans and must + /// be handled by another layer, typically [`SessionContext`]. + /// + /// [`SessionContext`]: crate::execution::context::SessionContext + pub async fn create_physical_plan( + &self, + logical_plan: &LogicalPlan, + ) -> datafusion_common::Result> { + let logical_plan = self.optimize(logical_plan)?; + self.query_planner + .create_physical_plan(&logical_plan, self) + .await + } + + /// Create a [`PhysicalExpr`] from an [`Expr`] after applying type + /// coercion, and function rewrites. + /// + /// Note: The expression is not [simplified] or otherwise optimized: + /// `a = 1 + 2` will not be simplified to `a = 3` as this is a more involved process. + /// See the [expr_api] example for how to simplify expressions. + /// + /// # See Also: + /// * [`SessionContext::create_physical_expr`] for a higher-level API + /// * [`create_physical_expr`] for a lower-level API + /// + /// [simplified]: datafusion_optimizer::simplify_expressions + /// [expr_api]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/expr_api.rs + /// [`SessionContext::create_physical_expr`]: crate::execution::context::SessionContext::create_physical_expr + pub fn create_physical_expr( + &self, + expr: Expr, + df_schema: &DFSchema, + ) -> datafusion_common::Result> { + let simplifier = + ExprSimplifier::new(SessionSimplifyProvider::new(self, df_schema)); + // apply type coercion here to ensure types match + let mut expr = simplifier.coerce(expr, df_schema)?; + + // rewrite Exprs to functions if necessary + let config_options = self.config_options(); + for rewrite in self.analyzer.function_rewrites() { + expr = expr + .transform_up(|expr| rewrite.rewrite(expr, df_schema, config_options))? + .data; + } + create_physical_expr(&expr, df_schema, self.execution_props()) + } + + /// Return the session ID + pub fn session_id(&self) -> &str { + &self.session_id + } + + /// Return the runtime env + pub fn runtime_env(&self) -> &Arc { + &self.runtime_env + } + + /// Return the execution properties + pub fn execution_props(&self) -> &ExecutionProps { + &self.execution_props + } + + /// Return mutable execution properties + pub fn execution_props_mut(&mut self) -> &mut ExecutionProps { + &mut self.execution_props + } + + /// Return the [`SessionConfig`] + pub fn config(&self) -> &SessionConfig { + &self.config + } + + /// Return the mutable [`SessionConfig`]. + pub fn config_mut(&mut self) -> &mut SessionConfig { + &mut self.config + } + + /// Return the logical optimizers + pub fn optimizers(&self) -> &[Arc] { + &self.optimizer.rules + } + + /// Return the physical optimizers + pub fn physical_optimizers(&self) -> &[Arc] { + &self.physical_optimizers.rules + } + + /// return the configuration options + pub fn config_options(&self) -> &ConfigOptions { + self.config.options() + } + + /// return the TableOptions options with its extensions + pub fn default_table_options(&self) -> TableOptions { + self.table_options + .combine_with_session_config(self.config_options()) + } + + /// Return the table options + pub fn table_options(&self) -> &TableOptions { + &self.table_options + } + + /// Return mutable table options + pub fn table_options_mut(&mut self) -> &mut TableOptions { + &mut self.table_options + } + + /// Registers a [`ConfigExtension`] as a table option extension that can be + /// referenced from SQL statements executed against this context. + pub fn register_table_options_extension(&mut self, extension: T) { + self.table_options.extensions.insert(extension) + } + + /// Adds or updates a [FileFormatFactory] which can be used with COPY TO or + /// CREATE EXTERNAL TABLE statements for reading and writing files of custom + /// formats. + pub fn register_file_format( + &mut self, + file_format: Arc, + overwrite: bool, + ) -> Result<(), DataFusionError> { + let ext = file_format.get_ext().to_lowercase(); + match (self.file_formats.entry(ext.clone()), overwrite){ + (Entry::Vacant(e), _) => {e.insert(file_format);}, + (Entry::Occupied(mut e), true) => {e.insert(file_format);}, + (Entry::Occupied(_), false) => return config_err!("File type already registered for extension {ext}. Set overwrite to true to replace this extension."), + }; + Ok(()) + } + + /// Retrieves a [FileFormatFactory] based on file extension which has been registered + /// via SessionContext::register_file_format. Extensions are not case sensitive. + pub fn get_file_format_factory( + &self, + ext: &str, + ) -> Option> { + self.file_formats.get(&ext.to_lowercase()).cloned() + } + + /// Get a new TaskContext to run in this session + pub fn task_ctx(&self) -> Arc { + Arc::new(TaskContext::from(self)) + } + + /// Return catalog list + pub fn catalog_list(&self) -> &Arc { + &self.catalog_list + } + + /// set the catalog list + pub(crate) fn register_catalog_list( + &mut self, + catalog_list: Arc, + ) { + self.catalog_list = catalog_list; + } + + /// Return reference to scalar_functions + pub fn scalar_functions(&self) -> &HashMap> { + &self.scalar_functions + } + + /// Return reference to aggregate_functions + pub fn aggregate_functions(&self) -> &HashMap> { + &self.aggregate_functions + } + + /// Return reference to window functions + pub fn window_functions(&self) -> &HashMap> { + &self.window_functions + } + + /// Return reference to table_functions + pub fn table_functions(&self) -> &HashMap> { + &self.table_functions + } + + /// Return [SerializerRegistry] for extensions + pub fn serializer_registry(&self) -> &Arc { + &self.serializer_registry + } + + /// Return version of the cargo package that produced this query + pub fn version(&self) -> &str { + env!("CARGO_PKG_VERSION") + } + + /// Register a user defined table function + pub fn register_udtf(&mut self, name: &str, fun: Arc) { + self.table_functions.insert( + name.to_owned(), + Arc::new(TableFunction::new(name.to_owned(), fun)), + ); + } + + /// Deregister a user defined table function + pub fn deregister_udtf( + &mut self, + name: &str, + ) -> datafusion_common::Result>> { + let udtf = self.table_functions.remove(name); + Ok(udtf.map(|x| x.function().clone())) + } +} + +/// A builder to be used for building [`SessionState`]'s. Defaults will +/// be used for all values unless explicitly provided. +/// +/// See example on [`SessionState`] +pub struct SessionStateBuilder { + session_id: Option, + analyzer: Option, + expr_planners: Option>>, + optimizer: Option, + physical_optimizers: Option, + query_planner: Option>, + catalog_list: Option>, + table_functions: Option>>, + scalar_functions: Option>>, + aggregate_functions: Option>>, + window_functions: Option>>, + serializer_registry: Option>, + file_formats: Option>>, + config: Option, + table_options: Option, + execution_props: Option, + table_factories: Option>>, + runtime_env: Option>, + function_factory: Option>, + // fields to support convenience functions + analyzer_rules: Option>>, + optimizer_rules: Option>>, + physical_optimizer_rules: Option>>, +} + +impl SessionStateBuilder { + /// Returns a new [`SessionStateBuilder`] with no options set. + pub fn new() -> Self { + Self { + session_id: None, + analyzer: None, + expr_planners: None, + optimizer: None, + physical_optimizers: None, + query_planner: None, + catalog_list: None, + table_functions: None, + scalar_functions: None, + aggregate_functions: None, + window_functions: None, + serializer_registry: None, + file_formats: None, + table_options: None, + config: None, + execution_props: None, + table_factories: None, + runtime_env: None, + function_factory: None, + // fields to support convenience functions + analyzer_rules: None, + optimizer_rules: None, + physical_optimizer_rules: None, + } + } + + /// Returns a new [SessionStateBuilder] based on an existing [SessionState] + /// The session id for the new builder will be unset; all other fields will + /// be cloned from what is set in the provided session state. If the default + /// catalog exists in existing session state, the new session state will not + /// create default catalog and schema. + pub fn new_from_existing(existing: SessionState) -> Self { + let default_catalog_exist = existing + .catalog_list() + .catalog(&existing.config.options().catalog.default_catalog) + .is_some(); + // The new `with_create_default_catalog_and_schema` should be false if the default catalog exists + let create_default_catalog_and_schema = existing + .config + .options() + .catalog + .create_default_catalog_and_schema + && !default_catalog_exist; + let new_config = existing + .config + .with_create_default_catalog_and_schema(create_default_catalog_and_schema); + Self { + session_id: None, + analyzer: Some(existing.analyzer), + expr_planners: Some(existing.expr_planners), + optimizer: Some(existing.optimizer), + physical_optimizers: Some(existing.physical_optimizers), + query_planner: Some(existing.query_planner), + catalog_list: Some(existing.catalog_list), + table_functions: Some(existing.table_functions), + scalar_functions: Some(existing.scalar_functions.into_values().collect_vec()), + aggregate_functions: Some( + existing.aggregate_functions.into_values().collect_vec(), + ), + window_functions: Some(existing.window_functions.into_values().collect_vec()), + serializer_registry: Some(existing.serializer_registry), + file_formats: Some(existing.file_formats.into_values().collect_vec()), + config: Some(new_config), + table_options: Some(existing.table_options), + execution_props: Some(existing.execution_props), + table_factories: Some(existing.table_factories), + runtime_env: Some(existing.runtime_env), + function_factory: existing.function_factory, + + // fields to support convenience functions + analyzer_rules: None, + optimizer_rules: None, + physical_optimizer_rules: None, + } + } + + /// Create default builder with defaults for table_factories, file formats, expr_planners and builtin + /// scalar, aggregate and windows functions. + pub fn with_default_features(self) -> Self { + self.with_table_factories(SessionStateDefaults::default_table_factories()) + .with_file_formats(SessionStateDefaults::default_file_formats()) + .with_expr_planners(SessionStateDefaults::default_expr_planners()) + .with_scalar_functions(SessionStateDefaults::default_scalar_functions()) + .with_aggregate_functions(SessionStateDefaults::default_aggregate_functions()) + .with_window_functions(SessionStateDefaults::default_window_functions()) + } + + /// Set the session id. + pub fn with_session_id(mut self, session_id: String) -> Self { + self.session_id = Some(session_id); + self + } + + /// Set the [`AnalyzerRule`]s optimizer plan rules. + pub fn with_analyzer_rules( + mut self, + rules: Vec>, + ) -> Self { + self.analyzer = Some(Analyzer::with_rules(rules)); + self + } + + /// Add `analyzer_rule` to the end of the list of + /// [`AnalyzerRule`]s used to rewrite queries. + pub fn with_analyzer_rule( + mut self, + analyzer_rule: Arc, + ) -> Self { + let mut rules = self.analyzer_rules.unwrap_or_default(); + rules.push(analyzer_rule); + self.analyzer_rules = Some(rules); + self + } + + /// Set the [`OptimizerRule`]s used to optimize plans. + pub fn with_optimizer_rules( + mut self, + rules: Vec>, + ) -> Self { + self.optimizer = Some(Optimizer::with_rules(rules)); + self + } + + /// Add `optimizer_rule` to the end of the list of + /// [`OptimizerRule`]s used to rewrite queries. + pub fn with_optimizer_rule( + mut self, + optimizer_rule: Arc, + ) -> Self { + let mut rules = self.optimizer_rules.unwrap_or_default(); + rules.push(optimizer_rule); + self.optimizer_rules = Some(rules); + self + } + + /// Set the [`ExprPlanner`]s used to customize the behavior of the SQL planner. + pub fn with_expr_planners( + mut self, + expr_planners: Vec>, + ) -> Self { + self.expr_planners = Some(expr_planners); + self + } + + /// Set the [`PhysicalOptimizerRule`]s used to optimize plans. + pub fn with_physical_optimizer_rules( + mut self, + physical_optimizers: Vec>, + ) -> Self { + self.physical_optimizers = + Some(PhysicalOptimizer::with_rules(physical_optimizers)); + self + } + + /// Add `physical_optimizer_rule` to the end of the list of + /// [`PhysicalOptimizerRule`]s used to rewrite queries. + pub fn with_physical_optimizer_rule( + mut self, + physical_optimizer_rule: Arc, + ) -> Self { + let mut rules = self.physical_optimizer_rules.unwrap_or_default(); + rules.push(physical_optimizer_rule); + self.physical_optimizer_rules = Some(rules); + self + } + + /// Set the [`QueryPlanner`] + pub fn with_query_planner( + mut self, + query_planner: Arc, + ) -> Self { + self.query_planner = Some(query_planner); + self + } + + /// Set the [`CatalogProviderList`] + pub fn with_catalog_list( + mut self, + catalog_list: Arc, + ) -> Self { + self.catalog_list = Some(catalog_list); + self + } + + /// Set the map of [`TableFunction`]s + pub fn with_table_functions( + mut self, + table_functions: HashMap>, + ) -> Self { + self.table_functions = Some(table_functions); + self + } + + /// Set the map of [`ScalarUDF`]s + pub fn with_scalar_functions( + mut self, + scalar_functions: Vec>, + ) -> Self { + self.scalar_functions = Some(scalar_functions); + self + } + + /// Set the map of [`AggregateUDF`]s + pub fn with_aggregate_functions( + mut self, + aggregate_functions: Vec>, + ) -> Self { + self.aggregate_functions = Some(aggregate_functions); + self + } + + /// Set the map of [`WindowUDF`]s + pub fn with_window_functions( + mut self, + window_functions: Vec>, + ) -> Self { + self.window_functions = Some(window_functions); + self + } + + /// Set the [`SerializerRegistry`] + pub fn with_serializer_registry( + mut self, + serializer_registry: Arc, + ) -> Self { + self.serializer_registry = Some(serializer_registry); + self + } + + /// Set the map of [`FileFormatFactory`]s + pub fn with_file_formats( + mut self, + file_formats: Vec>, + ) -> Self { + self.file_formats = Some(file_formats); + self + } + + /// Set the [`SessionConfig`] + pub fn with_config(mut self, config: SessionConfig) -> Self { + self.config = Some(config); + self + } + + /// Set the [`TableOptions`] + pub fn with_table_options(mut self, table_options: TableOptions) -> Self { + self.table_options = Some(table_options); + self + } + + /// Set the [`ExecutionProps`] + pub fn with_execution_props(mut self, execution_props: ExecutionProps) -> Self { + self.execution_props = Some(execution_props); + self + } + + /// Add a [`TableProviderFactory`] to the map of factories + pub fn with_table_factory( + mut self, + key: String, + table_factory: Arc, + ) -> Self { + let mut table_factories = self.table_factories.unwrap_or_default(); + table_factories.insert(key, table_factory); + self.table_factories = Some(table_factories); + self + } + + /// Set the map of [`TableProviderFactory`]s + pub fn with_table_factories( + mut self, + table_factories: HashMap>, + ) -> Self { + self.table_factories = Some(table_factories); + self + } + + /// Set the [`RuntimeEnv`] + pub fn with_runtime_env(mut self, runtime_env: Arc) -> Self { + self.runtime_env = Some(runtime_env); + self + } + + /// Set a [`FunctionFactory`] to handle `CREATE FUNCTION` statements + pub fn with_function_factory( + mut self, + function_factory: Option>, + ) -> Self { + self.function_factory = function_factory; + self + } + + /// Register an `ObjectStore` to the [`RuntimeEnv`]. See [`RuntimeEnv::register_object_store`] + /// for more details. + /// + /// Note that this creates a default [`RuntimeEnv`] if there isn't one passed in already. + /// + /// ``` + /// # use datafusion::prelude::*; + /// # use datafusion::execution::session_state::SessionStateBuilder; + /// # use datafusion_execution::runtime_env::RuntimeEnv; + /// # use url::Url; + /// # use std::sync::Arc; + /// # let http_store = object_store::local::LocalFileSystem::new(); + /// let url = Url::try_from("file://").unwrap(); + /// let object_store = object_store::local::LocalFileSystem::new(); + /// let state = SessionStateBuilder::new() + /// .with_config(SessionConfig::new()) + /// .with_object_store(&url, Arc::new(object_store)) + /// .with_default_features() + /// .build(); + /// ``` + pub fn with_object_store( + mut self, + url: &Url, + object_store: Arc, + ) -> Self { + if self.runtime_env.is_none() { + self.runtime_env = Some(Arc::new(RuntimeEnv::default())); + } + self.runtime_env + .as_ref() + .unwrap() + .register_object_store(url, object_store); + self + } + + /// Builds a [`SessionState`] with the current configuration. + /// + /// Note that there is an explicit option for enabling catalog and schema defaults + /// in [SessionConfig::create_default_catalog_and_schema] which if enabled + /// will be built here. + pub fn build(self) -> SessionState { + let Self { + session_id, + analyzer, + expr_planners, + optimizer, + physical_optimizers, + query_planner, + catalog_list, + table_functions, + scalar_functions, + aggregate_functions, + window_functions, + serializer_registry, + file_formats, + table_options, + config, + execution_props, + table_factories, + runtime_env, + function_factory, + analyzer_rules, + optimizer_rules, + physical_optimizer_rules, + } = self; + + let config = config.unwrap_or_default(); + let runtime_env = runtime_env.unwrap_or(Arc::new(RuntimeEnv::default())); + + let mut state = SessionState { + session_id: session_id.unwrap_or(Uuid::new_v4().to_string()), + analyzer: analyzer.unwrap_or_default(), + expr_planners: expr_planners.unwrap_or_default(), + optimizer: optimizer.unwrap_or_default(), + physical_optimizers: physical_optimizers.unwrap_or_default(), + query_planner: query_planner.unwrap_or(Arc::new(DefaultQueryPlanner {})), + catalog_list: catalog_list + .unwrap_or(Arc::new(MemoryCatalogProviderList::new()) + as Arc), + table_functions: table_functions.unwrap_or_default(), + scalar_functions: HashMap::new(), + aggregate_functions: HashMap::new(), + window_functions: HashMap::new(), + serializer_registry: serializer_registry + .unwrap_or(Arc::new(EmptySerializerRegistry)), + file_formats: HashMap::new(), + table_options: table_options + .unwrap_or(TableOptions::default_from_session_config(config.options())), + config, + execution_props: execution_props.unwrap_or_default(), + table_factories: table_factories.unwrap_or_default(), + runtime_env, + function_factory, + }; + + if let Some(file_formats) = file_formats { + for file_format in file_formats { + if let Err(e) = state.register_file_format(file_format, false) { + info!("Unable to register file format: {e}") + }; + } + } + + if let Some(scalar_functions) = scalar_functions { + scalar_functions.into_iter().for_each(|udf| { + let existing_udf = state.register_udf(udf); + if let Ok(Some(existing_udf)) = existing_udf { + debug!("Overwrote an existing UDF: {}", existing_udf.name()); + } + }); + } + + if let Some(aggregate_functions) = aggregate_functions { + aggregate_functions.into_iter().for_each(|udaf| { + let existing_udf = state.register_udaf(udaf); + if let Ok(Some(existing_udf)) = existing_udf { + debug!("Overwrote an existing UDF: {}", existing_udf.name()); + } + }); + } + + if let Some(window_functions) = window_functions { + window_functions.into_iter().for_each(|udwf| { + let existing_udf = state.register_udwf(udwf); + if let Ok(Some(existing_udf)) = existing_udf { + debug!("Overwrote an existing UDF: {}", existing_udf.name()); + } + }); + } + + if state.config.create_default_catalog_and_schema() { + let default_catalog = SessionStateDefaults::default_catalog( + &state.config, + &state.table_factories, + &state.runtime_env, + ); + + state.catalog_list.register_catalog( + state.config.options().catalog.default_catalog.clone(), + Arc::new(default_catalog), + ); + } + + if let Some(analyzer_rules) = analyzer_rules { + for analyzer_rule in analyzer_rules { + state.analyzer.rules.push(analyzer_rule); + } + } + + if let Some(optimizer_rules) = optimizer_rules { + for optimizer_rule in optimizer_rules { + state.optimizer.rules.push(optimizer_rule); + } + } + + if let Some(physical_optimizer_rules) = physical_optimizer_rules { + for physical_optimizer_rule in physical_optimizer_rules { + state + .physical_optimizers + .rules + .push(physical_optimizer_rule); + } + } + + state + } + + /// Returns the current session_id value + pub fn session_id(&self) -> &Option { + &self.session_id + } + + /// Returns the current analyzer value + pub fn analyzer(&mut self) -> &mut Option { + &mut self.analyzer + } + + /// Returns the current expr_planners value + pub fn expr_planners(&mut self) -> &mut Option>> { + &mut self.expr_planners + } + + /// Returns the current optimizer value + pub fn optimizer(&mut self) -> &mut Option { + &mut self.optimizer + } + + /// Returns the current physical_optimizers value + pub fn physical_optimizers(&mut self) -> &mut Option { + &mut self.physical_optimizers + } + + /// Returns the current query_planner value + pub fn query_planner(&mut self) -> &mut Option> { + &mut self.query_planner + } + + /// Returns the current catalog_list value + pub fn catalog_list(&mut self) -> &mut Option> { + &mut self.catalog_list + } + + /// Returns the current table_functions value + pub fn table_functions( + &mut self, + ) -> &mut Option>> { + &mut self.table_functions + } + + /// Returns the current scalar_functions value + pub fn scalar_functions(&mut self) -> &mut Option>> { + &mut self.scalar_functions + } + + /// Returns the current aggregate_functions value + pub fn aggregate_functions(&mut self) -> &mut Option>> { + &mut self.aggregate_functions + } + + /// Returns the current window_functions value + pub fn window_functions(&mut self) -> &mut Option>> { + &mut self.window_functions + } + + /// Returns the current serializer_registry value + pub fn serializer_registry(&mut self) -> &mut Option> { + &mut self.serializer_registry + } + + /// Returns the current file_formats value + pub fn file_formats(&mut self) -> &mut Option>> { + &mut self.file_formats + } + + /// Returns the current session_config value + pub fn config(&mut self) -> &mut Option { + &mut self.config + } + + /// Returns the current table_options value + pub fn table_options(&mut self) -> &mut Option { + &mut self.table_options + } + + /// Returns the current execution_props value + pub fn execution_props(&mut self) -> &mut Option { + &mut self.execution_props + } + + /// Returns the current table_factories value + pub fn table_factories( + &mut self, + ) -> &mut Option>> { + &mut self.table_factories + } + + /// Returns the current runtime_env value + pub fn runtime_env(&mut self) -> &mut Option> { + &mut self.runtime_env + } + + /// Returns the current function_factory value + pub fn function_factory(&mut self) -> &mut Option> { + &mut self.function_factory + } + + /// Returns the current analyzer_rules value + pub fn analyzer_rules( + &mut self, + ) -> &mut Option>> { + &mut self.analyzer_rules + } + + /// Returns the current optimizer_rules value + pub fn optimizer_rules( + &mut self, + ) -> &mut Option>> { + &mut self.optimizer_rules + } + + /// Returns the current physical_optimizer_rules value + pub fn physical_optimizer_rules( + &mut self, + ) -> &mut Option>> { + &mut self.physical_optimizer_rules + } +} + +impl Debug for SessionStateBuilder { + /// Prefer having short fields at the top and long vector fields near the end + /// Group fields by + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SessionStateBuilder") + .field("session_id", &self.session_id) + .field("config", &self.config) + .field("runtime_env", &self.runtime_env) + .field("catalog_list", &self.catalog_list) + .field("serializer_registry", &self.serializer_registry) + .field("file_formats", &self.file_formats) + .field("execution_props", &self.execution_props) + .field("table_options", &self.table_options) + .field("table_factories", &self.table_factories) + .field("function_factory", &self.function_factory) + .field("expr_planners", &self.expr_planners) + .field("query_planners", &self.query_planner) + .field("analyzer_rules", &self.analyzer_rules) + .field("analyzer", &self.analyzer) + .field("optimizer_rules", &self.optimizer_rules) + .field("optimizer", &self.optimizer) + .field("physical_optimizer_rules", &self.physical_optimizer_rules) + .field("physical_optimizers", &self.physical_optimizers) + .field("table_functions", &self.table_functions) + .field("scalar_functions", &self.scalar_functions) + .field("aggregate_functions", &self.aggregate_functions) + .field("window_functions", &self.window_functions) + .finish() + } +} + +impl Default for SessionStateBuilder { + fn default() -> Self { + Self::new() + } +} + +impl From for SessionStateBuilder { + fn from(state: SessionState) -> Self { + SessionStateBuilder::new_from_existing(state) + } +} + +/// Adapter that implements the [`ContextProvider`] trait for a [`SessionState`] +/// +/// This is used so the SQL planner can access the state of the session without +/// having a direct dependency on the [`SessionState`] struct (and core crate) +struct SessionContextProvider<'a> { + state: &'a SessionState, + tables: HashMap>, +} + +impl<'a> ContextProvider for SessionContextProvider<'a> { + fn get_expr_planners(&self) -> &[Arc] { + &self.state.expr_planners + } + + fn get_table_source( + &self, + name: TableReference, + ) -> datafusion_common::Result> { + let name = self.state.resolve_table_ref(name).to_string(); + self.tables + .get(&name) + .cloned() + .ok_or_else(|| plan_datafusion_err!("table '{name}' not found")) + } + + fn get_table_function_source( + &self, + name: &str, + args: Vec, + ) -> datafusion_common::Result> { + let tbl_func = self + .state + .table_functions + .get(name) + .cloned() + .ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?; + let provider = tbl_func.create_table_provider(&args)?; + + Ok(provider_as_source(provider)) + } + + /// Create a new CTE work table for a recursive CTE logical plan + /// This table will be used in conjunction with a Worktable physical plan + /// to read and write each iteration of a recursive CTE + fn create_cte_work_table( + &self, + name: &str, + schema: SchemaRef, + ) -> datafusion_common::Result> { + let table = Arc::new(CteWorkTable::new(name, schema)); + Ok(provider_as_source(table)) + } + + fn get_function_meta(&self, name: &str) -> Option> { + self.state.scalar_functions().get(name).cloned() + } + + fn get_aggregate_meta(&self, name: &str) -> Option> { + self.state.aggregate_functions().get(name).cloned() + } + + fn get_window_meta(&self, name: &str) -> Option> { + self.state.window_functions().get(name).cloned() + } + + fn get_variable_type(&self, variable_names: &[String]) -> Option { + if variable_names.is_empty() { + return None; + } + + let provider_type = if is_system_variables(variable_names) { + VarType::System + } else { + VarType::UserDefined + }; + + self.state + .execution_props + .var_providers + .as_ref() + .and_then(|provider| provider.get(&provider_type)?.get_type(variable_names)) + } + + fn options(&self) -> &ConfigOptions { + self.state.config_options() + } + + fn udf_names(&self) -> Vec { + self.state.scalar_functions().keys().cloned().collect() + } + + fn udaf_names(&self) -> Vec { + self.state.aggregate_functions().keys().cloned().collect() + } + + fn udwf_names(&self) -> Vec { + self.state.window_functions().keys().cloned().collect() + } + + fn get_file_type(&self, ext: &str) -> datafusion_common::Result> { + self.state + .file_formats + .get(&ext.to_lowercase()) + .ok_or(plan_datafusion_err!( + "There is no registered file format with ext {ext}" + )) + .map(|file_type| format_as_file_type(file_type.clone())) + } +} + +impl FunctionRegistry for SessionState { + fn udfs(&self) -> HashSet { + self.scalar_functions.keys().cloned().collect() + } + + fn udf(&self, name: &str) -> datafusion_common::Result> { + let result = self.scalar_functions.get(name); + + result.cloned().ok_or_else(|| { + plan_datafusion_err!("There is no UDF named \"{name}\" in the registry") + }) + } + + fn udaf(&self, name: &str) -> datafusion_common::Result> { + let result = self.aggregate_functions.get(name); + + result.cloned().ok_or_else(|| { + plan_datafusion_err!("There is no UDAF named \"{name}\" in the registry") + }) + } + + fn udwf(&self, name: &str) -> datafusion_common::Result> { + let result = self.window_functions.get(name); + + result.cloned().ok_or_else(|| { + plan_datafusion_err!("There is no UDWF named \"{name}\" in the registry") + }) + } + + fn register_udf( + &mut self, + udf: Arc, + ) -> datafusion_common::Result>> { + udf.aliases().iter().for_each(|alias| { + self.scalar_functions.insert(alias.clone(), udf.clone()); + }); + Ok(self.scalar_functions.insert(udf.name().into(), udf)) + } + + fn register_udaf( + &mut self, + udaf: Arc, + ) -> datafusion_common::Result>> { + udaf.aliases().iter().for_each(|alias| { + self.aggregate_functions.insert(alias.clone(), udaf.clone()); + }); + Ok(self.aggregate_functions.insert(udaf.name().into(), udaf)) + } + + fn register_udwf( + &mut self, + udwf: Arc, + ) -> datafusion_common::Result>> { + udwf.aliases().iter().for_each(|alias| { + self.window_functions.insert(alias.clone(), udwf.clone()); + }); + Ok(self.window_functions.insert(udwf.name().into(), udwf)) + } + + fn deregister_udf( + &mut self, + name: &str, + ) -> datafusion_common::Result>> { + let udf = self.scalar_functions.remove(name); + if let Some(udf) = &udf { + for alias in udf.aliases() { + self.scalar_functions.remove(alias); + } + } + Ok(udf) + } + + fn deregister_udaf( + &mut self, + name: &str, + ) -> datafusion_common::Result>> { + let udaf = self.aggregate_functions.remove(name); + if let Some(udaf) = &udaf { + for alias in udaf.aliases() { + self.aggregate_functions.remove(alias); + } + } + Ok(udaf) + } + + fn deregister_udwf( + &mut self, + name: &str, + ) -> datafusion_common::Result>> { + let udwf = self.window_functions.remove(name); + if let Some(udwf) = &udwf { + for alias in udwf.aliases() { + self.window_functions.remove(alias); + } + } + Ok(udwf) + } + + fn register_function_rewrite( + &mut self, + rewrite: Arc, + ) -> datafusion_common::Result<()> { + self.analyzer.add_function_rewrite(rewrite); + Ok(()) + } + + fn expr_planners(&self) -> Vec> { + self.expr_planners.clone() + } + + fn register_expr_planner( + &mut self, + expr_planner: Arc, + ) -> datafusion_common::Result<()> { + self.expr_planners.push(expr_planner); + Ok(()) + } +} + +impl OptimizerConfig for SessionState { + fn query_execution_start_time(&self) -> DateTime { + self.execution_props.query_execution_start_time + } + + fn alias_generator(&self) -> &Arc { + &self.execution_props.alias_generator + } + + fn options(&self) -> &ConfigOptions { + self.config_options() + } + + fn function_registry(&self) -> Option<&dyn FunctionRegistry> { + Some(self) + } +} + +/// Create a new task context instance from SessionState +impl From<&SessionState> for TaskContext { + fn from(state: &SessionState) -> Self { + let task_id = None; + TaskContext::new( + task_id, + state.session_id.clone(), + state.config.clone(), + state.scalar_functions.clone(), + state.aggregate_functions.clone(), + state.window_functions.clone(), + state.runtime_env.clone(), + ) + } +} + +/// The query planner used if no user defined planner is provided +#[derive(Debug)] +struct DefaultQueryPlanner {} + +#[async_trait] +impl QueryPlanner for DefaultQueryPlanner { + /// Given a `LogicalPlan`, create an [`ExecutionPlan`] suitable for execution + async fn create_physical_plan( + &self, + logical_plan: &LogicalPlan, + session_state: &SessionState, + ) -> datafusion_common::Result> { + let planner = DefaultPhysicalPlanner::default(); + planner + .create_physical_plan(logical_plan, session_state) + .await + } +} + +struct SessionSimplifyProvider<'a> { + state: &'a SessionState, + df_schema: &'a DFSchema, +} + +impl<'a> SessionSimplifyProvider<'a> { + fn new(state: &'a SessionState, df_schema: &'a DFSchema) -> Self { + Self { state, df_schema } + } +} + +impl<'a> SimplifyInfo for SessionSimplifyProvider<'a> { + fn is_boolean_type(&self, expr: &Expr) -> datafusion_common::Result { + Ok(expr.get_type(self.df_schema)? == DataType::Boolean) + } + + fn nullable(&self, expr: &Expr) -> datafusion_common::Result { + expr.nullable(self.df_schema) + } + + fn execution_props(&self) -> &ExecutionProps { + self.state.execution_props() + } + + fn get_data_type(&self, expr: &Expr) -> datafusion_common::Result { + expr.get_type(self.df_schema) + } +} + +#[cfg(test)] +mod tests { + use super::{SessionContextProvider, SessionStateBuilder}; + use crate::catalog_common::MemoryCatalogProviderList; + use crate::datasource::MemTable; + use crate::execution::context::SessionState; + use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray}; + use arrow_schema::{DataType, Field, Schema}; + use datafusion_common::DFSchema; + use datafusion_common::Result; + use datafusion_execution::config::SessionConfig; + use datafusion_expr::Expr; + use datafusion_optimizer::optimizer::OptimizerRule; + use datafusion_optimizer::Optimizer; + use datafusion_sql::planner::{PlannerContext, SqlToRel}; + use std::collections::HashMap; + use std::sync::Arc; + + #[test] + fn test_session_state_with_default_features() { + // test array planners with and without builtin planners + fn sql_to_expr(state: &SessionState) -> Result { + let provider = SessionContextProvider { + state, + tables: HashMap::new(), + }; + + let sql = "[1,2,3]"; + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let df_schema = DFSchema::try_from(schema)?; + let dialect = state.config.options().sql_parser.dialect.as_str(); + let sql_expr = state.sql_to_expr(sql, dialect)?; + + let query = SqlToRel::new_with_options(&provider, state.get_parser_options()); + query.sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new()) + } + + let state = SessionStateBuilder::new().with_default_features().build(); + + assert!(sql_to_expr(&state).is_ok()); + + // if no builtin planners exist, you should register your own, otherwise returns error + let state = SessionStateBuilder::new().build(); + + assert!(sql_to_expr(&state).is_err()) + } + + #[test] + fn test_from_existing() -> Result<()> { + fn employee_batch() -> RecordBatch { + let name: ArrayRef = + Arc::new(StringArray::from_iter_values(["Andy", "Andrew"])); + let age: ArrayRef = Arc::new(Int32Array::from(vec![11, 22])); + RecordBatch::try_from_iter(vec![("name", name), ("age", age)]).unwrap() + } + let batch = employee_batch(); + let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?; + + let session_state = SessionStateBuilder::new() + .with_catalog_list(Arc::new(MemoryCatalogProviderList::new())) + .build(); + let table_ref = session_state.resolve_table_ref("employee").to_string(); + session_state + .schema_for_ref(&table_ref)? + .register_table("employee".to_string(), Arc::new(table))?; + + let default_catalog = session_state + .config + .options() + .catalog + .default_catalog + .clone(); + let default_schema = session_state + .config + .options() + .catalog + .default_schema + .clone(); + let is_exist = session_state + .catalog_list() + .catalog(default_catalog.as_str()) + .unwrap() + .schema(default_schema.as_str()) + .unwrap() + .table_exist("employee"); + assert!(is_exist); + let new_state = SessionStateBuilder::new_from_existing(session_state).build(); + assert!(new_state + .catalog_list() + .catalog(default_catalog.as_str()) + .unwrap() + .schema(default_schema.as_str()) + .unwrap() + .table_exist("employee")); + + // if `with_create_default_catalog_and_schema` is disabled, the new one shouldn't create default catalog and schema + let disable_create_default = + SessionConfig::default().with_create_default_catalog_and_schema(false); + let without_default_state = SessionStateBuilder::new() + .with_config(disable_create_default) + .build(); + assert!(without_default_state + .catalog_list() + .catalog(&default_catalog) + .is_none()); + let new_state = + SessionStateBuilder::new_from_existing(without_default_state).build(); + assert!(new_state.catalog_list().catalog(&default_catalog).is_none()); + Ok(()) + } + + #[test] + fn test_session_state_with_optimizer_rules() { + #[derive(Default, Debug)] + struct DummyRule {} + + impl OptimizerRule for DummyRule { + fn name(&self) -> &str { + "dummy_rule" + } + } + // test building sessions with fresh set of rules + let state = SessionStateBuilder::new() + .with_optimizer_rules(vec![Arc::new(DummyRule {})]) + .build(); + + assert_eq!(state.optimizers().len(), 1); + + // test adding rules to default recommendations + let state = SessionStateBuilder::new() + .with_optimizer_rule(Arc::new(DummyRule {})) + .build(); + + assert_eq!( + state.optimizers().len(), + Optimizer::default().rules.len() + 1 + ); + } + + #[test] + fn test_with_table_factories() -> Result<()> { + use crate::test_util::TestTableFactory; + + let state = SessionStateBuilder::new().build(); + let table_factories = state.table_factories(); + assert!(table_factories.is_empty()); + + let table_factory = Arc::new(TestTableFactory {}); + let state = SessionStateBuilder::new() + .with_table_factory("employee".to_string(), table_factory) + .build(); + let table_factories = state.table_factories(); + assert_eq!(table_factories.len(), 1); + assert!(table_factories.contains_key("employee")); + Ok(()) + } +} diff --git a/datafusion/core/src/execution/session_state_defaults.rs b/datafusion/core/src/execution/session_state_defaults.rs new file mode 100644 index 000000000000..b5370efa0a97 --- /dev/null +++ b/datafusion/core/src/execution/session_state_defaults.rs @@ -0,0 +1,212 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::catalog::{CatalogProvider, TableProviderFactory}; +use crate::catalog_common::listing_schema::ListingSchemaProvider; +use crate::catalog_common::{MemoryCatalogProvider, MemorySchemaProvider}; +use crate::datasource::file_format::arrow::ArrowFormatFactory; +use crate::datasource::file_format::avro::AvroFormatFactory; +use crate::datasource::file_format::csv::CsvFormatFactory; +use crate::datasource::file_format::json::JsonFormatFactory; +#[cfg(feature = "parquet")] +use crate::datasource::file_format::parquet::ParquetFormatFactory; +use crate::datasource::file_format::FileFormatFactory; +use crate::datasource::provider::DefaultTableFactory; +use crate::execution::context::SessionState; +#[cfg(feature = "nested_expressions")] +use crate::functions_nested; +use crate::{functions, functions_aggregate, functions_window}; +use datafusion_execution::config::SessionConfig; +use datafusion_execution::object_store::ObjectStoreUrl; +use datafusion_execution::runtime_env::RuntimeEnv; +use datafusion_expr::planner::ExprPlanner; +use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; +use std::collections::HashMap; +use std::sync::Arc; +use url::Url; + +/// Defaults that are used as part of creating a SessionState such as table providers, +/// file formats, registering of builtin functions, etc. +pub struct SessionStateDefaults {} + +impl SessionStateDefaults { + /// returns a map of the default [`TableProviderFactory`]s + pub fn default_table_factories() -> HashMap> { + let mut table_factories: HashMap> = + HashMap::new(); + #[cfg(feature = "parquet")] + table_factories.insert("PARQUET".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("CSV".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("JSON".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("NDJSON".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("AVRO".into(), Arc::new(DefaultTableFactory::new())); + table_factories.insert("ARROW".into(), Arc::new(DefaultTableFactory::new())); + + table_factories + } + + /// returns the default MemoryCatalogProvider + pub fn default_catalog( + config: &SessionConfig, + table_factories: &HashMap>, + runtime: &Arc, + ) -> MemoryCatalogProvider { + let default_catalog = MemoryCatalogProvider::new(); + + default_catalog + .register_schema( + &config.options().catalog.default_schema, + Arc::new(MemorySchemaProvider::new()), + ) + .expect("memory catalog provider can register schema"); + + Self::register_default_schema(config, table_factories, runtime, &default_catalog); + + default_catalog + } + + /// returns the list of default [`ExprPlanner`]s + pub fn default_expr_planners() -> Vec> { + let expr_planners: Vec> = vec![ + Arc::new(functions::core::planner::CoreFunctionPlanner::default()), + // register crate of nested expressions (if enabled) + #[cfg(feature = "nested_expressions")] + Arc::new(functions_nested::planner::NestedFunctionPlanner), + #[cfg(feature = "nested_expressions")] + Arc::new(functions_nested::planner::FieldAccessPlanner), + #[cfg(any( + feature = "datetime_expressions", + feature = "unicode_expressions" + ))] + Arc::new(functions::planner::UserDefinedFunctionPlanner), + ]; + + expr_planners + } + + /// returns the list of default [`ScalarUDF']'s + pub fn default_scalar_functions() -> Vec> { + #[cfg_attr(not(feature = "nested_expressions"), allow(unused_mut))] + let mut functions: Vec> = functions::all_default_functions(); + + #[cfg(feature = "nested_expressions")] + functions.append(&mut functions_nested::all_default_nested_functions()); + + functions + } + + /// returns the list of default [`AggregateUDF']'s + pub fn default_aggregate_functions() -> Vec> { + functions_aggregate::all_default_aggregate_functions() + } + + /// returns the list of default [`WindowUDF']'s + pub fn default_window_functions() -> Vec> { + functions_window::all_default_window_functions() + } + + /// returns the list of default [`FileFormatFactory']'s + pub fn default_file_formats() -> Vec> { + let file_formats: Vec> = vec![ + #[cfg(feature = "parquet")] + Arc::new(ParquetFormatFactory::new()), + Arc::new(JsonFormatFactory::new()), + Arc::new(CsvFormatFactory::new()), + Arc::new(ArrowFormatFactory::new()), + Arc::new(AvroFormatFactory::new()), + ]; + + file_formats + } + + /// registers all builtin functions - scalar, array and aggregate + pub fn register_builtin_functions(state: &mut SessionState) { + Self::register_scalar_functions(state); + Self::register_array_functions(state); + Self::register_aggregate_functions(state); + } + + /// registers all the builtin scalar functions + pub fn register_scalar_functions(state: &mut SessionState) { + functions::register_all(state).expect("can not register built in functions"); + } + + /// registers all the builtin array functions + #[cfg_attr(not(feature = "nested_expressions"), allow(unused_variables))] + pub fn register_array_functions(state: &mut SessionState) { + // register crate of array expressions (if enabled) + #[cfg(feature = "nested_expressions")] + functions_nested::register_all(state) + .expect("can not register nested expressions"); + } + + /// registers all the builtin aggregate functions + pub fn register_aggregate_functions(state: &mut SessionState) { + functions_aggregate::register_all(state) + .expect("can not register aggregate functions"); + } + + /// registers the default schema + pub fn register_default_schema( + config: &SessionConfig, + table_factories: &HashMap>, + runtime: &Arc, + default_catalog: &MemoryCatalogProvider, + ) { + let url = config.options().catalog.location.as_ref(); + let format = config.options().catalog.format.as_ref(); + let (url, format) = match (url, format) { + (Some(url), Some(format)) => (url, format), + _ => return, + }; + let url = url.to_string(); + let format = format.to_string(); + + let url = Url::parse(url.as_str()).expect("Invalid default catalog location!"); + let authority = match url.host_str() { + Some(host) => format!("{}://{}", url.scheme(), host), + None => format!("{}://", url.scheme()), + }; + let path = &url.as_str()[authority.len()..]; + let path = object_store::path::Path::parse(path).expect("Can't parse path"); + let store = ObjectStoreUrl::parse(authority.as_str()) + .expect("Invalid default catalog url"); + let store = match runtime.object_store(store) { + Ok(store) => store, + _ => return, + }; + let factory = match table_factories.get(format.as_str()) { + Some(factory) => factory, + _ => return, + }; + let schema = + ListingSchemaProvider::new(authority, path, factory.clone(), store, format); + let _ = default_catalog + .register_schema("default", Arc::new(schema)) + .expect("Failed to register default schema"); + } + + /// registers the default [`FileFormatFactory`]s + pub fn register_default_file_formats(state: &mut SessionState) { + let formats = SessionStateDefaults::default_file_formats(); + for format in formats { + if let Err(e) = state.register_file_format(format, false) { + log::info!("Unable to register default file format: {e}") + }; + } + } +} diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index 24be185fb079..63d4fbc0bba5 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -17,24 +17,28 @@ #![warn(missing_docs, clippy::needless_borrow)] //! [DataFusion] is an extensible query engine written in Rust that -//! uses [Apache Arrow] as its in-memory format. DataFusion's many [use -//! cases] help developers build very fast and feature rich database -//! and analytic systems, customized to particular workloads. +//! uses [Apache Arrow] as its in-memory format. DataFusion's target users are +//! developers building fast and feature rich database and analytic systems, +//! customized to particular workloads. See [use cases] for examples. //! -//! "Out of the box," DataFusion quickly runs complex [SQL] and -//! [`DataFrame`] queries using a sophisticated query planner, a columnar, -//! multi-threaded, vectorized execution engine, and partitioned data -//! sources (Parquet, CSV, JSON, and Avro). +//! "Out of the box," DataFusion offers [SQL] and [`Dataframe`] APIs, +//! excellent [performance], built-in support for CSV, Parquet, JSON, and Avro, +//! extensive customization, and a great community. +//! [Python Bindings] are also available. //! -//! DataFusion is designed for easy customization such as supporting -//! additional data sources, query languages, functions, custom -//! operators and more. See the [Architecture] section for more details. +//! DataFusion features a full query planner, a columnar, streaming, multi-threaded, +//! vectorized execution engine, and partitioned data sources. You can +//! customize DataFusion at almost all points including additional data sources, +//! query languages, functions, custom operators and more. +//! See the [Architecture] section below for more details. //! //! [DataFusion]: https://datafusion.apache.org/ //! [Apache Arrow]: https://arrow.apache.org //! [use cases]: https://datafusion.apache.org/user-guide/introduction.html#use-cases //! [SQL]: https://datafusion.apache.org/user-guide/sql/index.html //! [`DataFrame`]: dataframe::DataFrame +//! [performance]: https://benchmark.clickhouse.com/ +//! [Python Bindings]: https://github.com/apache/datafusion-python //! [Architecture]: #architecture //! //! # Examples @@ -52,6 +56,7 @@ //! ```rust //! # use datafusion::prelude::*; //! # use datafusion::error::Result; +//! # use datafusion::functions_aggregate::expr_fn::min; //! # use datafusion::arrow::record_batch::RecordBatch; //! //! # #[tokio::main] @@ -75,7 +80,7 @@ //! //! let expected = vec![ //! "+---+----------------+", -//! "| a | MIN(?table?.b) |", +//! "| a | min(?table?.b) |", //! "+---+----------------+", //! "| 1 | 2 |", //! "+---+----------------+" @@ -113,7 +118,7 @@ //! //! let expected = vec![ //! "+---+----------------+", -//! "| a | MIN(example.b) |", +//! "| a | min(example.b) |", //! "+---+----------------+", //! "| 1 | 2 |", //! "+---+----------------+" @@ -130,11 +135,51 @@ //! //! [datafusion-examples]: https://github.com/apache/datafusion/tree/main/datafusion-examples //! +//! # Architecture +//! +//! +//! +//! You can find a formal description of DataFusion's architecture in our +//! [SIGMOD 2024 Paper]. +//! +//! [SIGMOD 2024 Paper]: https://dl.acm.org/doi/10.1145/3626246.3653368 +//! +//! ## Design Goals +//! DataFusion's Architecture Goals are: +//! +//! 1. Work “out of the box”: Provide a very fast, world class query engine with +//! minimal setup or required configuration. +//! +//! 2. Customizable everything: All behavior should be customizable by +//! implementing traits. +//! +//! 3. Architecturally boring 🥱: Follow industrial best practice rather than +//! trying cutting edge, but unproven, techniques. +//! +//! With these principles, users start with a basic, high-performance engine +//! and specialize it over time to suit their needs and available engineering +//! capacity. +//! +//! ## Overview Presentations +//! +//! The following presentations offer high level overviews of the +//! different components and how they interact together. +//! +//! - [Apr 2023]: The Apache DataFusion Architecture talks +//! - _Query Engine_: [recording](https://youtu.be/NVKujPxwSBA) and [slides](https://docs.google.com/presentation/d/1D3GDVas-8y0sA4c8EOgdCvEjVND4s2E7I6zfs67Y4j8/edit#slide=id.p) +//! - _Logical Plan and Expressions_: [recording](https://youtu.be/EzZTLiSJnhY) and [slides](https://docs.google.com/presentation/d/1ypylM3-w60kVDW7Q6S99AHzvlBgciTdjsAfqNP85K30) +//! - _Physical Plan and Execution_: [recording](https://youtu.be/2jkWU3_w6z0) and [slides](https://docs.google.com/presentation/d/1cA2WQJ2qg6tx6y4Wf8FH2WVSm9JQ5UgmBWATHdik0hg) +//! - [July 2022]: DataFusion and Arrow: Supercharge Your Data Analytical Tool with a Rusty Query Engine: [recording](https://www.youtube.com/watch?v=Rii1VTn3seQ) and [slides](https://docs.google.com/presentation/d/1q1bPibvu64k2b7LPi7Yyb0k3gA1BiUYiUbEklqW1Ckc/view#slide=id.g11054eeab4c_0_1165) +//! - [March 2021]: The DataFusion architecture is described in _Query Engine Design and the Rust-Based DataFusion in Apache Arrow_: [recording](https://www.youtube.com/watch?v=K6eCAVEk4kU) (DataFusion content starts [~ 15 minutes in](https://www.youtube.com/watch?v=K6eCAVEk4kU&t=875s)) and [slides](https://www.slideshare.net/influxdata/influxdb-iox-tech-talks-query-engine-design-and-the-rustbased-datafusion-in-apache-arrow-244161934) +//! - [February 2021]: How DataFusion is used within the Ballista Project is described in _Ballista: Distributed Compute with Rust and Apache Arrow_: [recording](https://www.youtube.com/watch?v=ZZHQaOap9pQ) +//! //! ## Customization and Extension //! -//! DataFusion is a "disaggregated" query engine. This -//! means developers can start with a working, full featured engine, and then -//! extend the areas they need to specialize for their usecase. For example, +//! DataFusion is designed to be highly extensible, so you can +//! start with a working, full featured engine, and then +//! specialize any behavior for your usecase. For example, //! some projects may add custom [`ExecutionPlan`] operators, or create their own //! query language that directly creates [`LogicalPlan`] rather than using the //! built in SQL planner, [`SqlToRel`]. @@ -142,7 +187,7 @@ //! In order to achieve this, DataFusion supports extension at many points: //! //! * read from any datasource ([`TableProvider`]) -//! * define your own catalogs, schemas, and table lists ([`CatalogProvider`]) +//! * define your own catalogs, schemas, and table lists ([`catalog`] and [`CatalogProvider`]) //! * build your own query language or plans ([`LogicalPlanBuilder`]) //! * declare and use user-defined functions ([`ScalarUDF`], and [`AggregateUDF`], [`WindowUDF`]) //! * add custom plan rewrite passes ([`AnalyzerRule`], [`OptimizerRule`] and [`PhysicalOptimizerRule`]) @@ -159,31 +204,7 @@ //! [`QueryPlanner`]: execution::context::QueryPlanner //! [`OptimizerRule`]: datafusion_optimizer::optimizer::OptimizerRule //! [`AnalyzerRule`]: datafusion_optimizer::analyzer::AnalyzerRule -//! [`PhysicalOptimizerRule`]: crate::physical_optimizer::optimizer::PhysicalOptimizerRule -//! -//! # Architecture -//! -//! -//! -//! You can find a formal description of DataFusion's architecture in our -//! [SIGMOD 2024 Paper]. -//! -//! [SIGMOD 2024 Paper]: https://github.com/apache/datafusion/files/14789704/DataFusion_Query_Engine___SIGMOD_2024-FINAL.pdf -//! -//! ## Overview Presentations -//! -//! The following presentations offer high level overviews of the -//! different components and how they interact together. -//! -//! - [Apr 2023]: The Apache DataFusion Architecture talks -//! - _Query Engine_: [recording](https://youtu.be/NVKujPxwSBA) and [slides](https://docs.google.com/presentation/d/1D3GDVas-8y0sA4c8EOgdCvEjVND4s2E7I6zfs67Y4j8/edit#slide=id.p) -//! - _Logical Plan and Expressions_: [recording](https://youtu.be/EzZTLiSJnhY) and [slides](https://docs.google.com/presentation/d/1ypylM3-w60kVDW7Q6S99AHzvlBgciTdjsAfqNP85K30) -//! - _Physical Plan and Execution_: [recording](https://youtu.be/2jkWU3_w6z0) and [slides](https://docs.google.com/presentation/d/1cA2WQJ2qg6tx6y4Wf8FH2WVSm9JQ5UgmBWATHdik0hg) -//! - [July 2022]: DataFusion and Arrow: Supercharge Your Data Analytical Tool with a Rusty Query Engine: [recording](https://www.youtube.com/watch?v=Rii1VTn3seQ) and [slides](https://docs.google.com/presentation/d/1q1bPibvu64k2b7LPi7Yyb0k3gA1BiUYiUbEklqW1Ckc/view#slide=id.g11054eeab4c_0_1165) -//! - [March 2021]: The DataFusion architecture is described in _Query Engine Design and the Rust-Based DataFusion in Apache Arrow_: [recording](https://www.youtube.com/watch?v=K6eCAVEk4kU) (DataFusion content starts [~ 15 minutes in](https://www.youtube.com/watch?v=K6eCAVEk4kU&t=875s)) and [slides](https://www.slideshare.net/influxdata/influxdb-iox-tech-talks-query-engine-design-and-the-rustbased-datafusion-in-apache-arrow-244161934) -//! - [February 2021]: How DataFusion is used within the Ballista Project is described in _Ballista: Distributed Compute with Rust and Apache Arrow_: [recording](https://www.youtube.com/watch?v=ZZHQaOap9pQ) +//! [`PhysicalOptimizerRule`]: crate::physical_optimizer::PhysicalOptimizerRule //! //! ## Query Planning and Execution Overview //! @@ -203,11 +224,11 @@ //! ``` //! //! 1. The query string is parsed to an Abstract Syntax Tree (AST) -//! [`Statement`] using [sqlparser]. +//! [`Statement`] using [sqlparser]. //! //! 2. The AST is converted to a [`LogicalPlan`] and logical -//! expressions [`Expr`]s to compute the desired result by the -//! [`SqlToRel`] planner. +//! expressions [`Expr`]s to compute the desired result by the +//! [`SqlToRel`] planner. //! //! [`Statement`]: https://docs.rs/sqlparser/latest/sqlparser/ast/enum.Statement.html //! @@ -239,17 +260,17 @@ //! optimizing, in the following manner: //! //! 1. The [`LogicalPlan`] is checked and rewritten to enforce -//! semantic rules, such as type coercion, by [`AnalyzerRule`]s +//! semantic rules, such as type coercion, by [`AnalyzerRule`]s //! //! 2. The [`LogicalPlan`] is rewritten by [`OptimizerRule`]s, such as -//! projection and filter pushdown, to improve its efficiency. +//! projection and filter pushdown, to improve its efficiency. //! //! 3. The [`LogicalPlan`] is converted to an [`ExecutionPlan`] by a -//! [`PhysicalPlanner`] +//! [`PhysicalPlanner`] //! //! 4. The [`ExecutionPlan`] is rewritten by -//! [`PhysicalOptimizerRule`]s, such as sort and join selection, to -//! improve its efficiency. +//! [`PhysicalOptimizerRule`]s, such as sort and join selection, to +//! improve its efficiency. //! //! ## Data Sources //! @@ -275,9 +296,9 @@ //! an [`ExecutionPlan`]s for execution. //! //! 1. [`ListingTable`]: Reads data from Parquet, JSON, CSV, or AVRO -//! files. Supports single files or multiple files with HIVE style -//! partitioning, optional compression, directly reading from remote -//! object store and more. +//! files. Supports single files or multiple files with HIVE style +//! partitioning, optional compression, directly reading from remote +//! object store and more. //! //! 2. [`MemTable`]: Reads data from in memory [`RecordBatch`]es. //! @@ -409,13 +430,13 @@ //! structures: //! //! 1. [`SessionContext`]: State needed for create [`LogicalPlan`]s such -//! as the table definitions, and the function registries. +//! as the table definitions, and the function registries. //! //! 2. [`TaskContext`]: State needed for execution such as the -//! [`MemoryPool`], [`DiskManager`], and [`ObjectStoreRegistry`]. +//! [`MemoryPool`], [`DiskManager`], and [`ObjectStoreRegistry`]. //! //! 3. [`ExecutionProps`]: Per-execution properties and data (such as -//! starting timestamps, etc). +//! starting timestamps, etc). //! //! [`SessionContext`]: crate::execution::context::SessionContext //! [`TaskContext`]: crate::execution::context::TaskContext @@ -442,12 +463,26 @@ //! * [datafusion_execution]: State and structures needed for execution //! * [datafusion_expr]: [`LogicalPlan`], [`Expr`] and related logical planning structure //! * [datafusion_functions]: Scalar function packages -//! * [datafusion_functions_array]: Scalar function packages for `ARRAY`s +//! * [datafusion_functions_nested]: Scalar function packages for `ARRAY`s, `MAP`s and `STRUCT`s //! * [datafusion_optimizer]: [`OptimizerRule`]s and [`AnalyzerRule`]s //! * [datafusion_physical_expr]: [`PhysicalExpr`] and related expressions //! * [datafusion_physical_plan]: [`ExecutionPlan`] and related expressions //! * [datafusion_sql]: SQL planner ([`SqlToRel`]) //! +//! ## Citing DataFusion in Academic Papers +//! +//! You can use the following citation to reference DataFusion in academic papers: +//! +//! ```text +//! @inproceedings{lamb2024apache +//! title={Apache Arrow DataFusion: A Fast, Embeddable, Modular Analytic Query Engine}, +//! author={Lamb, Andrew and Shen, Yijie and Heres, Dani{\"e}l and Chakraborty, Jayjeet and Kabak, Mehmet Ozan and Hsieh, Liang-Chi and Sun, Chao}, +//! booktitle={Companion of the 2024 International Conference on Management of Data}, +//! pages={5--17}, +//! year={2024} +//! } +//! ``` +//! //! [sqlparser]: https://docs.rs/sqlparser/latest/sqlparser //! [`SqlToRel`]: sql::planner::SqlToRel //! [`Expr`]: datafusion_expr::Expr @@ -459,7 +494,6 @@ //! [`PhysicalOptimizerRule`]: datafusion::physical_optimizer::optimizer::PhysicalOptimizerRule //! [`Schema`]: arrow::datatypes::Schema //! [`PhysicalExpr`]: physical_plan::PhysicalExpr -//! [`AggregateExpr`]: physical_plan::AggregateExpr //! [`RecordBatch`]: arrow::record_batch::RecordBatch //! [`RecordBatchReader`]: arrow::record_batch::RecordBatchReader //! [`Array`]: arrow::array::Array @@ -470,7 +504,7 @@ pub const DATAFUSION_VERSION: &str = env!("CARGO_PKG_VERSION"); extern crate core; extern crate sqlparser; -pub mod catalog; +pub mod catalog_common; pub mod dataframe; pub mod datasource; pub mod error; @@ -479,7 +513,6 @@ pub mod physical_optimizer; pub mod physical_planner; pub mod prelude; pub mod scalar; -pub mod variable; // re-export dependencies from arrow-rs to minimize version maintenance for crate users pub use arrow; @@ -505,6 +538,11 @@ pub use common::config; // NB datafusion execution is re-exported in the `execution` module +/// re-export of [`datafusion_catalog`] crate +pub mod catalog { + pub use datafusion_catalog::*; +} + /// re-export of [`datafusion_expr`] crate pub mod logical_expr { pub use datafusion_expr::*; @@ -515,6 +553,11 @@ pub mod optimizer { pub use datafusion_optimizer::*; } +/// re-export of [`datafusion_physical_expr`] crate +pub mod physical_expr_common { + pub use datafusion_physical_expr_common::*; +} + /// re-export of [`datafusion_physical_expr`] crate pub mod physical_expr { pub use datafusion_physical_expr::*; @@ -539,10 +582,17 @@ pub mod functions { pub use datafusion_functions::*; } -/// re-export of [`datafusion_functions_array`] crate, if "array_expressions" feature is enabled +/// re-export of [`datafusion_functions_nested`] crate, if "nested_expressions" feature is enabled +pub mod functions_nested { + #[cfg(feature = "nested_expressions")] + pub use datafusion_functions_nested::*; +} + +/// re-export of [`datafusion_functions_nested`] crate as [`functions_array`] for backward compatibility, if "nested_expressions" feature is enabled +#[deprecated(since = "41.0.0", note = "use datafusion-functions-nested instead")] pub mod functions_array { - #[cfg(feature = "array_expressions")] - pub use datafusion_functions_array::*; + #[cfg(feature = "nested_expressions")] + pub use datafusion_functions_nested::*; } /// re-export of [`datafusion_functions_aggregate`] crate @@ -550,6 +600,16 @@ pub mod functions_aggregate { pub use datafusion_functions_aggregate::*; } +/// re-export of [`datafusion_functions_window`] crate +pub mod functions_window { + pub use datafusion_functions_window::*; +} + +/// re-export of variable provider for `@name` and `@@name` style runtime values. +pub mod variable { + pub use datafusion_expr::var_provider::{VarProvider, VarType}; +} + #[cfg(test)] pub mod test; pub mod test_util; @@ -557,8 +617,77 @@ pub mod test_util; #[cfg(doctest)] doc_comment::doctest!("../../../README.md", readme_example_test); +// Instructions for Documentation Examples +// +// The following commands test the examples from the user guide as part of +// `cargo test --doc` +// +// # Adding new tests: +// +// Simply add code like this to your .md file and ensure your md file is +// included in the lists below. +// +// ```rust +// +// ``` +// +// Note that sometimes it helps to author the doctest as a standalone program +// first, and then copy it into the user guide. +// +// # Debugging Test Failures +// +// Unfortunately, the line numbers reported by doctest do not correspond to the +// line numbers of in the .md files. Thus, if a doctest fails, use the name of +// the test to find the relevant file in the list below, and then find the +// example in that file to fix. +// +// For example, if `user_guide_expressions(line 123)` fails, +// go to `docs/source/user-guide/expressions.md` to find the relevant problem. + #[cfg(doctest)] doc_comment::doctest!( "../../../docs/source/user-guide/example-usage.md", - user_guid_example_tests + user_guide_example_usage +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/crate-configuration.md", + user_guide_crate_configuration +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/configs.md", + user_guide_configs +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/dataframe.md", + user_guide_dataframe +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/user-guide/expressions.md", + user_guide_expressions +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/using-the-sql-api.md", + library_user_guide_sql_api +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/building-logical-plans.md", + library_user_guide_logical_plans +); + +#[cfg(doctest)] +doc_comment::doctest!( + "../../../docs/source/library-user-guide/using-the-dataframe-api.md", + library_user_guide_dataframe_api ); diff --git a/datafusion/core/src/physical_optimizer/coalesce_batches.rs b/datafusion/core/src/physical_optimizer/coalesce_batches.rs index 42b7463600dc..2f834813ede9 100644 --- a/datafusion/core/src/physical_optimizer/coalesce_batches.rs +++ b/datafusion/core/src/physical_optimizer/coalesce_batches.rs @@ -23,7 +23,6 @@ use std::sync::Arc; use crate::{ config::ConfigOptions, error::Result, - physical_optimizer::PhysicalOptimizerRule, physical_plan::{ coalesce_batches::CoalesceBatchesExec, filter::FilterExec, joins::HashJoinExec, repartition::RepartitionExec, Partitioning, @@ -31,10 +30,11 @@ use crate::{ }; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_physical_optimizer::PhysicalOptimizerRule; /// Optimizer rule that introduces CoalesceBatchesExec to avoid overhead with small batches that /// are produced by highly selective filters -#[derive(Default)] +#[derive(Default, Debug)] pub struct CoalesceBatches {} impl CoalesceBatches { diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs deleted file mode 100644 index 92787df461d3..000000000000 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ /dev/null @@ -1,475 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! CombinePartialFinalAggregate optimizer rule checks the adjacent Partial and Final AggregateExecs -//! and try to combine them if necessary - -use std::sync::Arc; - -use crate::error::Result; -use crate::physical_optimizer::PhysicalOptimizerRule; -use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; -use crate::physical_plan::ExecutionPlan; - -use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::{AggregateExpr, PhysicalExpr}; - -/// CombinePartialFinalAggregate optimizer rule combines the adjacent Partial and Final AggregateExecs -/// into a Single AggregateExec if their grouping exprs and aggregate exprs equal. -/// -/// This rule should be applied after the EnforceDistribution and EnforceSorting rules -/// -#[derive(Default)] -pub struct CombinePartialFinalAggregate {} - -impl CombinePartialFinalAggregate { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -impl PhysicalOptimizerRule for CombinePartialFinalAggregate { - fn optimize( - &self, - plan: Arc, - _config: &ConfigOptions, - ) -> Result> { - plan.transform_down(|plan| { - let transformed = - plan.as_any() - .downcast_ref::() - .and_then(|agg_exec| { - if matches!( - agg_exec.mode(), - AggregateMode::Final | AggregateMode::FinalPartitioned - ) { - agg_exec - .input() - .as_any() - .downcast_ref::() - .and_then(|input_agg_exec| { - if matches!( - input_agg_exec.mode(), - AggregateMode::Partial - ) && can_combine( - ( - agg_exec.group_by(), - agg_exec.aggr_expr(), - agg_exec.filter_expr(), - ), - ( - input_agg_exec.group_by(), - input_agg_exec.aggr_expr(), - input_agg_exec.filter_expr(), - ), - ) { - let mode = - if agg_exec.mode() == &AggregateMode::Final { - AggregateMode::Single - } else { - AggregateMode::SinglePartitioned - }; - AggregateExec::try_new( - mode, - input_agg_exec.group_by().clone(), - input_agg_exec.aggr_expr().to_vec(), - input_agg_exec.filter_expr().to_vec(), - input_agg_exec.input().clone(), - input_agg_exec.input_schema(), - ) - .map(|combined_agg| { - combined_agg.with_limit(agg_exec.limit()) - }) - .ok() - .map(Arc::new) - } else { - None - } - }) - } else { - None - } - }); - - Ok(if let Some(transformed) = transformed { - Transformed::yes(transformed) - } else { - Transformed::no(plan) - }) - }) - .data() - } - - fn name(&self) -> &str { - "CombinePartialFinalAggregate" - } - - fn schema_check(&self) -> bool { - true - } -} - -type GroupExprsRef<'a> = ( - &'a PhysicalGroupBy, - &'a [Arc], - &'a [Option>], -); - -type GroupExprs = ( - PhysicalGroupBy, - Vec>, - Vec>>, -); - -fn can_combine(final_agg: GroupExprsRef, partial_agg: GroupExprsRef) -> bool { - let (final_group_by, final_aggr_expr, final_filter_expr) = - normalize_group_exprs(final_agg); - let (input_group_by, input_aggr_expr, input_filter_expr) = - normalize_group_exprs(partial_agg); - - final_group_by.eq(&input_group_by) - && final_aggr_expr.len() == input_aggr_expr.len() - && final_aggr_expr - .iter() - .zip(input_aggr_expr.iter()) - .all(|(final_expr, partial_expr)| final_expr.eq(partial_expr)) - && final_filter_expr.len() == input_filter_expr.len() - && final_filter_expr.iter().zip(input_filter_expr.iter()).all( - |(final_expr, partial_expr)| match (final_expr, partial_expr) { - (Some(l), Some(r)) => l.eq(r), - (None, None) => true, - _ => false, - }, - ) -} - -// To compare the group expressions between the final and partial aggregations, need to discard all the column indexes and compare -fn normalize_group_exprs(group_exprs: GroupExprsRef) -> GroupExprs { - let (group, agg, filter) = group_exprs; - let new_group_expr = group - .expr() - .iter() - .map(|(expr, name)| (discard_column_index(expr.clone()), name.clone())) - .collect::>(); - let new_group = PhysicalGroupBy::new( - new_group_expr, - group.null_expr().to_vec(), - group.groups().to_vec(), - ); - (new_group, agg.to_vec(), filter.to_vec()) -} - -fn discard_column_index(group_expr: Arc) -> Arc { - group_expr - .clone() - .transform(|expr| { - let normalized_form: Option> = - match expr.as_any().downcast_ref::() { - Some(column) => Some(Arc::new(Column::new(column.name(), 0))), - None => None, - }; - Ok(if let Some(normalized_form) = normalized_form { - Transformed::yes(normalized_form) - } else { - Transformed::no(expr) - }) - }) - .data() - .unwrap_or(group_expr) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::datasource::listing::PartitionedFile; - use crate::datasource::object_store::ObjectStoreUrl; - use crate::datasource::physical_plan::{FileScanConfig, ParquetExec}; - use crate::physical_plan::expressions::lit; - use crate::physical_plan::repartition::RepartitionExec; - use crate::physical_plan::{displayable, Partitioning, Statistics}; - - use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use datafusion_physical_expr::expressions::{col, Count, Sum}; - - /// Runs the CombinePartialFinalAggregate optimizer and asserts the plan against the expected - macro_rules! assert_optimized { - ($EXPECTED_LINES: expr, $PLAN: expr) => { - let expected_lines: Vec<&str> = $EXPECTED_LINES.iter().map(|s| *s).collect(); - - // run optimizer - let optimizer = CombinePartialFinalAggregate {}; - let config = ConfigOptions::new(); - let optimized = optimizer.optimize($PLAN, &config)?; - // Now format correctly - let plan = displayable(optimized.as_ref()).indent(true).to_string(); - let actual_lines = trim_plan_display(&plan); - - assert_eq!( - &expected_lines, &actual_lines, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected_lines, actual_lines - ); - }; - } - - fn trim_plan_display(plan: &str) -> Vec<&str> { - plan.split('\n') - .map(|s| s.trim()) - .filter(|s| !s.is_empty()) - .collect() - } - - fn schema() -> SchemaRef { - Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int64, true), - Field::new("b", DataType::Int64, true), - Field::new("c", DataType::Int64, true), - ])) - } - - fn parquet_exec(schema: &SchemaRef) -> Arc { - Arc::new(ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), - file_schema: schema.clone(), - file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], - statistics: Statistics::new_unknown(schema), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - }, - None, - None, - Default::default(), - )) - } - - fn partial_aggregate_exec( - input: Arc, - group_by: PhysicalGroupBy, - aggr_expr: Vec>, - ) -> Arc { - let schema = input.schema(); - let n_aggr = aggr_expr.len(); - Arc::new( - AggregateExec::try_new( - AggregateMode::Partial, - group_by, - aggr_expr, - vec![None; n_aggr], - input, - schema, - ) - .unwrap(), - ) - } - - fn final_aggregate_exec( - input: Arc, - group_by: PhysicalGroupBy, - aggr_expr: Vec>, - ) -> Arc { - let schema = input.schema(); - let n_aggr = aggr_expr.len(); - Arc::new( - AggregateExec::try_new( - AggregateMode::Final, - group_by, - aggr_expr, - vec![None; n_aggr], - input, - schema, - ) - .unwrap(), - ) - } - - fn repartition_exec(input: Arc) -> Arc { - Arc::new( - RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(10)).unwrap(), - ) - } - - #[test] - fn aggregations_not_combined() -> Result<()> { - let schema = schema(); - - let aggr_expr = vec![Arc::new(Count::new( - lit(1i8), - "COUNT(1)".to_string(), - DataType::Int64, - )) as _]; - let plan = final_aggregate_exec( - repartition_exec(partial_aggregate_exec( - parquet_exec(&schema), - PhysicalGroupBy::default(), - aggr_expr.clone(), - )), - PhysicalGroupBy::default(), - aggr_expr, - ); - // should not combine the Partial/Final AggregateExecs - let expected = &[ - "AggregateExec: mode=Final, gby=[], aggr=[COUNT(1)]", - "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "AggregateExec: mode=Partial, gby=[], aggr=[COUNT(1)]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c]", - ]; - assert_optimized!(expected, plan); - - let aggr_expr1 = vec![Arc::new(Count::new( - lit(1i8), - "COUNT(1)".to_string(), - DataType::Int64, - )) as _]; - let aggr_expr2 = vec![Arc::new(Count::new( - lit(1i8), - "COUNT(2)".to_string(), - DataType::Int64, - )) as _]; - - let plan = final_aggregate_exec( - partial_aggregate_exec( - parquet_exec(&schema), - PhysicalGroupBy::default(), - aggr_expr1, - ), - PhysicalGroupBy::default(), - aggr_expr2, - ); - // should not combine the Partial/Final AggregateExecs - let expected = &[ - "AggregateExec: mode=Final, gby=[], aggr=[COUNT(2)]", - "AggregateExec: mode=Partial, gby=[], aggr=[COUNT(1)]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c]", - ]; - - assert_optimized!(expected, plan); - - Ok(()) - } - - #[test] - fn aggregations_combined() -> Result<()> { - let schema = schema(); - let aggr_expr = vec![Arc::new(Count::new( - lit(1i8), - "COUNT(1)".to_string(), - DataType::Int64, - )) as _]; - - let plan = final_aggregate_exec( - partial_aggregate_exec( - parquet_exec(&schema), - PhysicalGroupBy::default(), - aggr_expr.clone(), - ), - PhysicalGroupBy::default(), - aggr_expr, - ); - // should combine the Partial/Final AggregateExecs to tne Single AggregateExec - let expected = &[ - "AggregateExec: mode=Single, gby=[], aggr=[COUNT(1)]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c]", - ]; - - assert_optimized!(expected, plan); - Ok(()) - } - - #[test] - fn aggregations_with_group_combined() -> Result<()> { - let schema = schema(); - let aggr_expr = vec![Arc::new(Sum::new( - col("b", &schema)?, - "Sum(b)".to_string(), - DataType::Int64, - )) as _]; - - let groups: Vec<(Arc, String)> = - vec![(col("c", &schema)?, "c".to_string())]; - - let partial_group_by = PhysicalGroupBy::new_single(groups); - let partial_agg = partial_aggregate_exec( - parquet_exec(&schema), - partial_group_by, - aggr_expr.clone(), - ); - - let groups: Vec<(Arc, String)> = - vec![(col("c", &partial_agg.schema())?, "c".to_string())]; - let final_group_by = PhysicalGroupBy::new_single(groups); - - let plan = final_aggregate_exec(partial_agg, final_group_by, aggr_expr); - // should combine the Partial/Final AggregateExecs to tne Single AggregateExec - let expected = &[ - "AggregateExec: mode=Single, gby=[c@2 as c], aggr=[Sum(b)]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c]", - ]; - - assert_optimized!(expected, plan); - Ok(()) - } - - #[test] - fn aggregations_with_limit_combined() -> Result<()> { - let schema = schema(); - let aggr_expr = vec![]; - - let groups: Vec<(Arc, String)> = - vec![(col("c", &schema)?, "c".to_string())]; - - let partial_group_by = PhysicalGroupBy::new_single(groups); - let partial_agg = partial_aggregate_exec( - parquet_exec(&schema), - partial_group_by, - aggr_expr.clone(), - ); - - let groups: Vec<(Arc, String)> = - vec![(col("c", &partial_agg.schema())?, "c".to_string())]; - let final_group_by = PhysicalGroupBy::new_single(groups); - - let schema = partial_agg.schema(); - let final_agg = Arc::new( - AggregateExec::try_new( - AggregateMode::Final, - final_group_by, - aggr_expr, - vec![], - partial_agg, - schema, - ) - .unwrap() - .with_limit(Some(5)), - ); - let plan: Arc = final_agg; - // should combine the Partial/Final AggregateExecs to a Single AggregateExec - // with the final limit preserved - let expected = &[ - "AggregateExec: mode=Single, gby=[c@2 as c], aggr=[], lim=[5]", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c]", - ]; - - assert_optimized!(expected, plan); - Ok(()) - } -} diff --git a/datafusion/core/src/physical_optimizer/convert_first_last.rs b/datafusion/core/src/physical_optimizer/convert_first_last.rs deleted file mode 100644 index 14860eecf189..000000000000 --- a/datafusion/core/src/physical_optimizer/convert_first_last.rs +++ /dev/null @@ -1,260 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use datafusion_common::Result; -use datafusion_common::{ - config::ConfigOptions, - tree_node::{Transformed, TransformedResult, TreeNode}, -}; -use datafusion_physical_expr::expressions::{FirstValue, LastValue}; -use datafusion_physical_expr::{ - equivalence::ProjectionMapping, reverse_order_bys, AggregateExpr, - EquivalenceProperties, PhysicalSortRequirement, -}; -use datafusion_physical_plan::aggregates::concat_slices; -use datafusion_physical_plan::{ - aggregates::{AggregateExec, AggregateMode}, - ExecutionPlan, ExecutionPlanProperties, InputOrderMode, -}; -use std::sync::Arc; - -use datafusion_physical_plan::windows::get_ordered_partition_by_indices; - -use super::PhysicalOptimizerRule; - -/// The optimizer rule check the ordering requirements of the aggregate expressions. -/// And convert between FIRST_VALUE and LAST_VALUE if possible. -/// For example, If we have an ascending values and we want LastValue from the descending requirement, -/// it is equivalent to FirstValue with the current ascending ordering. -/// -/// The concrete example is that, says we have values c1 with [1, 2, 3], which is an ascending order. -/// If we want LastValue(c1 order by desc), which is the first value of reversed c1 [3, 2, 1], -/// so we can convert the aggregate expression to FirstValue(c1 order by asc), -/// since the current ordering is already satisfied, it saves our time! -#[derive(Default)] -pub struct OptimizeAggregateOrder {} - -impl OptimizeAggregateOrder { - pub fn new() -> Self { - Self::default() - } -} - -impl PhysicalOptimizerRule for OptimizeAggregateOrder { - fn optimize( - &self, - plan: Arc, - _config: &ConfigOptions, - ) -> Result> { - plan.transform_up(get_common_requirement_of_aggregate_input) - .data() - } - - fn name(&self) -> &str { - "OptimizeAggregateOrder" - } - - fn schema_check(&self) -> bool { - true - } -} - -fn get_common_requirement_of_aggregate_input( - plan: Arc, -) -> Result>> { - if let Some(aggr_exec) = plan.as_any().downcast_ref::() { - let input = aggr_exec.input(); - let mut aggr_expr = try_get_updated_aggr_expr_from_child(aggr_exec); - let group_by = aggr_exec.group_by(); - let mode = aggr_exec.mode(); - - let input_eq_properties = input.equivalence_properties(); - let groupby_exprs = group_by.input_exprs(); - // If existing ordering satisfies a prefix of the GROUP BY expressions, - // prefix requirements with this section. In this case, aggregation will - // work more efficiently. - let indices = get_ordered_partition_by_indices(&groupby_exprs, input); - let requirement = indices - .iter() - .map(|&idx| PhysicalSortRequirement { - expr: groupby_exprs[idx].clone(), - options: None, - }) - .collect::>(); - - try_convert_first_last_if_better( - &requirement, - &mut aggr_expr, - input_eq_properties, - )?; - - let required_input_ordering = (!requirement.is_empty()).then_some(requirement); - - let input_order_mode = - if indices.len() == groupby_exprs.len() && !indices.is_empty() { - InputOrderMode::Sorted - } else if !indices.is_empty() { - InputOrderMode::PartiallySorted(indices) - } else { - InputOrderMode::Linear - }; - let projection_mapping = - ProjectionMapping::try_new(group_by.expr(), &input.schema())?; - - let cache = AggregateExec::compute_properties( - input, - plan.schema().clone(), - &projection_mapping, - mode, - &input_order_mode, - ); - - let aggr_exec = aggr_exec.new_with_aggr_expr_and_ordering_info( - required_input_ordering, - aggr_expr, - cache, - input_order_mode, - ); - - Ok(Transformed::yes( - Arc::new(aggr_exec) as Arc - )) - } else { - Ok(Transformed::no(plan)) - } -} - -/// In `create_initial_plan` for LogicalPlan::Aggregate, we have a nested AggregateExec where the first layer -/// is in Partial mode and the second layer is in Final or Finalpartitioned mode. -/// If the first layer of aggregate plan is transformed, we need to update the child of the layer with final mode. -/// Therefore, we check it and get the updated aggregate expressions. -/// -/// If AggregateExec is created from elsewhere, we skip the check and return the original aggregate expressions. -fn try_get_updated_aggr_expr_from_child( - aggr_exec: &AggregateExec, -) -> Vec> { - let input = aggr_exec.input(); - if aggr_exec.mode() == &AggregateMode::Final - || aggr_exec.mode() == &AggregateMode::FinalPartitioned - { - // Some aggregators may be modified during initialization for - // optimization purposes. For example, a FIRST_VALUE may turn - // into a LAST_VALUE with the reverse ordering requirement. - // To reflect such changes to subsequent stages, use the updated - // `AggregateExpr`/`PhysicalSortExpr` objects. - // - // The bottom up transformation is the mirror of LogicalPlan::Aggregate creation in [create_initial_plan] - if let Some(c_aggr_exec) = input.as_any().downcast_ref::() { - if c_aggr_exec.mode() == &AggregateMode::Partial { - // If the input is an AggregateExec in Partial mode, then the - // input is a CoalescePartitionsExec. In this case, the - // AggregateExec is the second stage of aggregation. The - // requirements of the second stage are the requirements of - // the first stage. - return c_aggr_exec.aggr_expr().to_vec(); - } - } - } - - aggr_exec.aggr_expr().to_vec() -} - -/// Get the common requirement that satisfies all the aggregate expressions. -/// -/// # Parameters -/// -/// - `aggr_exprs`: A slice of `Arc` containing all the -/// aggregate expressions. -/// - `group_by`: A reference to a `PhysicalGroupBy` instance representing the -/// physical GROUP BY expression. -/// - `eq_properties`: A reference to an `EquivalenceProperties` instance -/// representing equivalence properties for ordering. -/// - `agg_mode`: A reference to an `AggregateMode` instance representing the -/// mode of aggregation. -/// -/// # Returns -/// -/// A `LexRequirement` instance, which is the requirement that satisfies all the -/// aggregate requirements. Returns an error in case of conflicting requirements. -/// -/// Similar to the one in datafusion/physical-plan/src/aggregates/mod.rs, but this -/// function care only the possible conversion between FIRST_VALUE and LAST_VALUE -fn try_convert_first_last_if_better( - prefix_requirement: &[PhysicalSortRequirement], - aggr_exprs: &mut [Arc], - eq_properties: &EquivalenceProperties, -) -> Result<()> { - for aggr_expr in aggr_exprs.iter_mut() { - let aggr_req = aggr_expr.order_bys().unwrap_or(&[]); - let reverse_aggr_req = reverse_order_bys(aggr_req); - let aggr_req = PhysicalSortRequirement::from_sort_exprs(aggr_req); - let reverse_aggr_req = - PhysicalSortRequirement::from_sort_exprs(&reverse_aggr_req); - - if let Some(first_value) = aggr_expr.as_any().downcast_ref::() { - let mut first_value = first_value.clone(); - - if eq_properties.ordering_satisfy_requirement(&concat_slices( - prefix_requirement, - &aggr_req, - )) { - first_value = first_value.with_requirement_satisfied(true); - *aggr_expr = Arc::new(first_value) as _; - } else if eq_properties.ordering_satisfy_requirement(&concat_slices( - prefix_requirement, - &reverse_aggr_req, - )) { - // Converting to LAST_VALUE enables more efficient execution - // given the existing ordering: - let mut last_value = first_value.convert_to_last(); - last_value = last_value.with_requirement_satisfied(true); - *aggr_expr = Arc::new(last_value) as _; - } else { - // Requirement is not satisfied with existing ordering. - first_value = first_value.with_requirement_satisfied(false); - *aggr_expr = Arc::new(first_value) as _; - } - continue; - } - if let Some(last_value) = aggr_expr.as_any().downcast_ref::() { - let mut last_value = last_value.clone(); - if eq_properties.ordering_satisfy_requirement(&concat_slices( - prefix_requirement, - &aggr_req, - )) { - last_value = last_value.with_requirement_satisfied(true); - *aggr_expr = Arc::new(last_value) as _; - } else if eq_properties.ordering_satisfy_requirement(&concat_slices( - prefix_requirement, - &reverse_aggr_req, - )) { - // Converting to FIRST_VALUE enables more efficient execution - // given the existing ordering: - let mut first_value = last_value.convert_to_first(); - first_value = first_value.with_requirement_satisfied(true); - *aggr_expr = Arc::new(first_value) as _; - } else { - // Requirement is not satisfied with existing ordering. - last_value = last_value.with_requirement_satisfied(false); - *aggr_expr = Arc::new(last_value) as _; - } - continue; - } - } - - Ok(()) -} diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index 14232f4933f8..6cd902db7244 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -24,14 +24,12 @@ use std::fmt::Debug; use std::sync::Arc; -use super::output_requirements::OutputRequirementExec; use crate::config::ConfigOptions; use crate::error::Result; use crate::physical_optimizer::utils::{ add_sort_above_with_check, is_coalesce_partitions, is_repartition, is_sort_preserving_merge, }; -use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::joins::{ @@ -46,6 +44,7 @@ use crate::physical_plan::windows::WindowAggExec; use crate::physical_plan::{Distribution, ExecutionPlan, Partitioning}; use arrow::compute::SortOptions; +use datafusion_common::stats::Precision; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_expr::logical_plan::JoinType; use datafusion_physical_expr::expressions::{Column, NoOp}; @@ -53,6 +52,9 @@ use datafusion_physical_expr::utils::map_columns_before_projection; use datafusion_physical_expr::{ physical_exprs_equal, EquivalenceProperties, PhysicalExpr, PhysicalExprRef, }; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_optimizer::output_requirements::OutputRequirementExec; +use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion_physical_plan::windows::{get_best_fitting_window, BoundedWindowAggExec}; use datafusion_physical_plan::ExecutionPlanProperties; @@ -175,7 +177,7 @@ use itertools::izip; /// /// This rule only chooses the exact match and satisfies the Distribution(a, b, c) /// by a HashPartition(a, b, c). -#[derive(Default)] +#[derive(Default, Debug)] pub struct EnforceDistribution {} impl EnforceDistribution { @@ -308,7 +310,7 @@ fn adjust_input_keys_ordering( return reorder_partitioned_join_keys( requirements, on, - vec![], + &[], &join_constructor, ) .map(Transformed::yes); @@ -327,7 +329,8 @@ fn adjust_input_keys_ordering( JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti - | JoinType::Full => vec![], + | JoinType::Full + | JoinType::LeftMark => vec![], }; } PartitionMode::Auto => { @@ -372,7 +375,7 @@ fn adjust_input_keys_ordering( return reorder_partitioned_join_keys( requirements, on, - sort_options.clone(), + sort_options, &join_constructor, ) .map(Transformed::yes); @@ -392,7 +395,7 @@ fn adjust_input_keys_ordering( let expr = proj.expr(); // For Projection, we need to transform the requirements to the columns before the Projection // And then to push down the requirements - // Construct a mapping from new name to the orginal Column + // Construct a mapping from new name to the original Column let new_required = map_columns_before_projection(&requirements.data, expr); if new_required.len() == requirements.data.len() { requirements.children[0].data = new_required; @@ -411,7 +414,7 @@ fn adjust_input_keys_ordering( } else { // By default, push down the parent requirements to children for child in requirements.children.iter_mut() { - child.data = requirements.data.clone(); + child.data.clone_from(&requirements.data); } } Ok(Transformed::yes(requirements)) @@ -420,7 +423,7 @@ fn adjust_input_keys_ordering( fn reorder_partitioned_join_keys( mut join_plan: PlanWithKeyRequirements, on: &[(PhysicalExprRef, PhysicalExprRef)], - sort_options: Vec, + sort_options: &[SortOptions], join_constructor: &F, ) -> Result where @@ -461,7 +464,7 @@ fn reorder_aggregate_keys( ) -> Result { let parent_required = &agg_node.data; let output_columns = agg_exec - .group_by() + .group_expr() .expr() .iter() .enumerate() @@ -474,7 +477,7 @@ fn reorder_aggregate_keys( .collect::>(); if parent_required.len() == output_exprs.len() - && agg_exec.group_by().null_expr().is_empty() + && agg_exec.group_expr().null_expr().is_empty() && !physical_exprs_equal(&output_exprs, parent_required) { if let Some(positions) = expected_expr_positions(&output_exprs, parent_required) { @@ -482,7 +485,7 @@ fn reorder_aggregate_keys( agg_exec.input().as_any().downcast_ref::() { if matches!(agg_exec.mode(), &AggregateMode::Partial) { - let group_exprs = agg_exec.group_by().expr(); + let group_exprs = agg_exec.group_expr().expr(); let new_group_exprs = positions .into_iter() .map(|idx| group_exprs[idx].clone()) @@ -566,7 +569,7 @@ fn shift_right_required( }) .collect::>(); - // if the parent required are all comming from the right side, the requirements can be pushdown + // if the parent required are all coming from the right side, the requirements can be pushdown (new_right_required.len() == parent_required.len()).then_some(new_right_required) } @@ -856,6 +859,7 @@ fn add_roundrobin_on_top( /// Adds a hash repartition operator: /// - to increase parallelism, and/or /// - to satisfy requirements of the subsequent operators. +/// /// Repartition(Hash) is added on top of operator `input`. /// /// # Arguments @@ -932,7 +936,7 @@ fn add_spm_on_top(input: DistributionContext) -> DistributionContext { let new_plan = if should_preserve_ordering { Arc::new(SortPreservingMergeExec::new( - input.plan.output_ordering().unwrap_or(&[]).to_vec(), + LexOrdering::from_ref(input.plan.output_ordering().unwrap_or(&[])), input.plan.clone(), )) as _ } else { @@ -1030,6 +1034,105 @@ fn replace_order_preserving_variants( context.update_plan_from_children() } +/// A struct to keep track of repartition requirements for each child node. +struct RepartitionRequirementStatus { + /// The distribution requirement for the node. + requirement: Distribution, + /// Designates whether round robin partitioning is theoretically beneficial; + /// i.e. the operator can actually utilize parallelism. + roundrobin_beneficial: bool, + /// Designates whether round robin partitioning is beneficial according to + /// the statistical information we have on the number of rows. + roundrobin_beneficial_stats: bool, + /// Designates whether hash partitioning is necessary. + hash_necessary: bool, +} + +/// Calculates the `RepartitionRequirementStatus` for each children to generate +/// consistent and sensible (in terms of performance) distribution requirements. +/// As an example, a hash join's left (build) child might produce +/// +/// ```text +/// RepartitionRequirementStatus { +/// .., +/// hash_necessary: true +/// } +/// ``` +/// +/// while its right (probe) child might have very few rows and produce: +/// +/// ```text +/// RepartitionRequirementStatus { +/// .., +/// hash_necessary: false +/// } +/// ``` +/// +/// These statuses are not consistent as all children should agree on hash +/// partitioning. This function aligns the statuses to generate consistent +/// hash partitions for each children. After alignment, the right child's +/// status would turn into: +/// +/// ```text +/// RepartitionRequirementStatus { +/// .., +/// hash_necessary: true +/// } +/// ``` +fn get_repartition_requirement_status( + plan: &Arc, + batch_size: usize, + should_use_estimates: bool, +) -> Result> { + let mut needs_alignment = false; + let children = plan.children(); + let rr_beneficial = plan.benefits_from_input_partitioning(); + let requirements = plan.required_input_distribution(); + let mut repartition_status_flags = vec![]; + for (child, requirement, roundrobin_beneficial) in + izip!(children.into_iter(), requirements, rr_beneficial) + { + // Decide whether adding a round robin is beneficial depending on + // the statistical information we have on the number of rows: + let roundrobin_beneficial_stats = match child.statistics()?.num_rows { + Precision::Exact(n_rows) => n_rows > batch_size, + Precision::Inexact(n_rows) => !should_use_estimates || (n_rows > batch_size), + Precision::Absent => true, + }; + let is_hash = matches!(requirement, Distribution::HashPartitioned(_)); + // Hash re-partitioning is necessary when the input has more than one + // partitions: + let multi_partitions = child.output_partitioning().partition_count() > 1; + let roundrobin_sensible = roundrobin_beneficial && roundrobin_beneficial_stats; + needs_alignment |= is_hash && (multi_partitions || roundrobin_sensible); + repartition_status_flags.push(( + is_hash, + RepartitionRequirementStatus { + requirement, + roundrobin_beneficial, + roundrobin_beneficial_stats, + hash_necessary: is_hash && multi_partitions, + }, + )); + } + // Align hash necessary flags for hash partitions to generate consistent + // hash partitions at each children: + if needs_alignment { + // When there is at least one hash requirement that is necessary or + // beneficial according to statistics, make all children require hash + // repartitioning: + for (is_hash, status) in &mut repartition_status_flags { + if *is_hash { + status.hash_necessary = true; + } + } + } + Ok(repartition_status_flags + .into_iter() + .map(|(_, status)| status) + .collect()) +} + /// This function checks whether we need to add additional data exchange /// operators to satisfy distribution requirements. Since this function /// takes care of such requirements, we should avoid manually adding data @@ -1049,6 +1152,9 @@ fn ensure_distribution( let enable_round_robin = config.optimizer.enable_round_robin_repartition; let repartition_file_scans = config.optimizer.repartition_file_scans; let batch_size = config.execution.batch_size; + let should_use_estimates = config + .execution + .use_row_number_estimates_to_optimize_partitioning; let is_unbounded = dist_context.plan.execution_mode().is_unbounded(); // Use order preserving variants either of the conditions true // - it is desired according to config @@ -1081,6 +1187,8 @@ fn ensure_distribution( } }; + let repartition_status_flags = + get_repartition_requirement_status(&plan, batch_size, should_use_estimates)?; // This loop iterates over all the children to: // - Increase parallelism for every child if it is beneficial. // - Satisfy the distribution requirements of every child, if it is not @@ -1088,33 +1196,32 @@ fn ensure_distribution( // We store the updated children in `new_children`. let children = izip!( children.into_iter(), - plan.required_input_distribution().iter(), plan.required_input_ordering().iter(), - plan.benefits_from_input_partitioning(), - plan.maintains_input_order() + plan.maintains_input_order(), + repartition_status_flags.into_iter() ) .map( - |(mut child, requirement, required_input_ordering, would_benefit, maintains)| { - // Don't need to apply when the returned row count is not greater than batch size - let num_rows = child.plan.statistics()?.num_rows; - let repartition_beneficial_stats = if num_rows.is_exact().unwrap_or(false) { - num_rows - .get_value() - .map(|value| value > &batch_size) - .unwrap() // safe to unwrap since is_exact() is true - } else { - true - }; - + |( + mut child, + required_input_ordering, + maintains, + RepartitionRequirementStatus { + requirement, + roundrobin_beneficial, + roundrobin_beneficial_stats, + hash_necessary, + }, + )| { let add_roundrobin = enable_round_robin // Operator benefits from partitioning (e.g. filter): - && (would_benefit && repartition_beneficial_stats) + && roundrobin_beneficial + && roundrobin_beneficial_stats // Unless partitioning increases the partition count, it is not beneficial: && child.plan.output_partitioning().partition_count() < target_partitions; // When `repartition_file_scans` is set, attempt to increase // parallelism at the source. - if repartition_file_scans && repartition_beneficial_stats { + if repartition_file_scans && roundrobin_beneficial_stats { if let Some(new_child) = child.plan.repartitioned(target_partitions, config)? { @@ -1123,7 +1230,7 @@ fn ensure_distribution( } // Satisfy the distribution requirement if it is unmet. - match requirement { + match &requirement { Distribution::SinglePartition => { child = add_spm_on_top(child); } @@ -1133,7 +1240,11 @@ fn ensure_distribution( // to increase parallelism. child = add_roundrobin_on_top(child, target_partitions)?; } - child = add_hash_on_top(child, exprs.to_vec(), target_partitions)?; + // When inserting hash is necessary to satisy hash requirement, insert hash repartition. + if hash_necessary { + child = + add_hash_on_top(child, exprs.to_vec(), target_partitions)?; + } } Distribution::UnspecifiedDistribution => { if add_roundrobin { @@ -1163,7 +1274,7 @@ fn ensure_distribution( // Make sure to satisfy ordering requirement: child = add_sort_above_with_check( child, - required_input_ordering.to_vec(), + required_input_ordering.clone(), None, ); } @@ -1192,7 +1303,11 @@ fn ensure_distribution( .collect::>>()?; let children_plans = children.iter().map(|c| c.plan.clone()).collect::>(); - plan = if plan.as_any().is::() && can_interleave(children_plans.iter()) { + + plan = if plan.as_any().is::() + && !config.optimizer.prefer_existing_union + && can_interleave(children_plans.iter()) + { // Add a special case for [`UnionExec`] since we want to "bubble up" // hash-partitioned data. So instead of // @@ -1286,7 +1401,6 @@ pub(crate) mod tests { use crate::datasource::object_store::ObjectStoreUrl; use crate::datasource::physical_plan::{CsvExec, FileScanConfig, ParquetExec}; use crate::physical_optimizer::enforce_sorting::EnforceSorting; - use crate::physical_optimizer::output_requirements::OutputRequirements; use crate::physical_optimizer::test_utils::{ check_integrity, coalesce_partitions_exec, repartition_exec, }; @@ -1297,15 +1411,17 @@ pub(crate) mod tests { use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::{displayable, DisplayAs, DisplayFormatType, Statistics}; + use datafusion_physical_optimizer::output_requirements::OutputRequirements; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::ScalarValue; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{BinaryExpr, Literal}; use datafusion_physical_expr::{ - expressions, expressions::binary, expressions::lit, LexOrdering, - PhysicalSortExpr, PhysicalSortRequirement, + expressions::binary, expressions::lit, LexOrdering, PhysicalSortExpr, + PhysicalSortRequirement, }; + use datafusion_physical_expr_common::sort_expr::LexRequirement; use datafusion_physical_plan::PlanProperties; /// Models operators like BoundedWindowExec that require an input @@ -1320,7 +1436,7 @@ pub(crate) mod tests { impl SortRequiredExec { fn new_with_requirement( input: Arc, - requirement: Vec, + requirement: LexOrdering, ) -> Self { let cache = Self::compute_properties(&input); Self { @@ -1346,11 +1462,7 @@ pub(crate) mod tests { _t: DisplayFormatType, f: &mut std::fmt::Formatter, ) -> std::fmt::Result { - write!( - f, - "SortRequiredExec: [{}]", - PhysicalSortExpr::format_list(&self.expr) - ) + write!(f, "SortRequiredExec: [{}]", self.expr) } } @@ -1371,16 +1483,18 @@ pub(crate) mod tests { vec![false] } - fn children(&self) -> Vec> { - vec![self.input.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.input] } // model that it requires the output ordering of its input - fn required_input_ordering(&self) -> Vec>> { + fn required_input_ordering(&self) -> Vec> { if self.expr.is_empty() { vec![None] } else { - vec![Some(PhysicalSortRequirement::from_sort_exprs(&self.expr))] + vec![Some(PhysicalSortRequirement::from_sort_exprs( + self.expr.iter(), + ))] } } @@ -1425,23 +1539,14 @@ pub(crate) mod tests { /// create a single parquet file that is sorted pub(crate) fn parquet_exec_with_sort( - output_ordering: Vec>, + output_ordering: Vec, ) -> Arc { - Arc::new(ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), - file_schema: schema(), - file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], - statistics: Statistics::new_unknown(&schema()), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering, - }, - None, - None, - Default::default(), - )) + ParquetExec::builder( + FileScanConfig::new(ObjectStoreUrl::parse("test:///").unwrap(), schema()) + .with_file(PartitionedFile::new("x".to_string(), 100)) + .with_output_ordering(output_ordering), + ) + .build_arc() } fn parquet_exec_multiple() -> Arc { @@ -1450,50 +1555,39 @@ pub(crate) mod tests { /// Created a sorted parquet exec with multiple files fn parquet_exec_multiple_sorted( - output_ordering: Vec>, + output_ordering: Vec, ) -> Arc { - Arc::new(ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), - file_schema: schema(), - file_groups: vec![ + ParquetExec::builder( + FileScanConfig::new(ObjectStoreUrl::parse("test:///").unwrap(), schema()) + .with_file_groups(vec![ vec![PartitionedFile::new("x".to_string(), 100)], vec![PartitionedFile::new("y".to_string(), 100)], - ], - statistics: Statistics::new_unknown(&schema()), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering, - }, - None, - None, - Default::default(), - )) + ]) + .with_output_ordering(output_ordering), + ) + .build_arc() } fn csv_exec() -> Arc { csv_exec_with_sort(vec![]) } - fn csv_exec_with_sort(output_ordering: Vec>) -> Arc { - Arc::new(CsvExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), - file_schema: schema(), - file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], - statistics: Statistics::new_unknown(&schema()), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering, - }, - false, - b',', - b'"', - None, - FileCompressionType::UNCOMPRESSED, - )) + fn csv_exec_with_sort(output_ordering: Vec) -> Arc { + Arc::new( + CsvExec::builder( + FileScanConfig::new(ObjectStoreUrl::parse("test:///").unwrap(), schema()) + .with_file(PartitionedFile::new("x".to_string(), 100)) + .with_output_ordering(output_ordering), + ) + .with_has_header(false) + .with_delimeter(b',') + .with_quote(b'"') + .with_escape(None) + .with_comment(None) + .with_newlines_in_values(false) + .with_file_compression_type(FileCompressionType::UNCOMPRESSED) + .build(), + ) } fn csv_exec_multiple() -> Arc { @@ -1501,29 +1595,25 @@ pub(crate) mod tests { } // Created a sorted parquet exec with multiple files - fn csv_exec_multiple_sorted( - output_ordering: Vec>, - ) -> Arc { - Arc::new(CsvExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), - file_schema: schema(), - file_groups: vec![ - vec![PartitionedFile::new("x".to_string(), 100)], - vec![PartitionedFile::new("y".to_string(), 100)], - ], - statistics: Statistics::new_unknown(&schema()), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering, - }, - false, - b',', - b'"', - None, - FileCompressionType::UNCOMPRESSED, - )) + fn csv_exec_multiple_sorted(output_ordering: Vec) -> Arc { + Arc::new( + CsvExec::builder( + FileScanConfig::new(ObjectStoreUrl::parse("test:///").unwrap(), schema()) + .with_file_groups(vec![ + vec![PartitionedFile::new("x".to_string(), 100)], + vec![PartitionedFile::new("y".to_string(), 100)], + ]) + .with_output_ordering(output_ordering), + ) + .with_has_header(false) + .with_delimeter(b',') + .with_quote(b'"') + .with_escape(None) + .with_comment(None) + .with_newlines_in_values(false) + .with_file_compression_type(FileCompressionType::UNCOMPRESSED) + .build(), + ) } fn projection_exec_with_alias( @@ -1554,8 +1644,7 @@ pub(crate) mod tests { .enumerate() .map(|(index, (_col, name))| { ( - Arc::new(expressions::Column::new(name, index)) - as Arc, + Arc::new(Column::new(name, index)) as Arc, name.clone(), ) }) @@ -1636,7 +1725,7 @@ pub(crate) mod tests { } fn sort_exec( - sort_exprs: Vec, + sort_exprs: LexOrdering, input: Arc, preserve_partitioning: bool, ) -> Arc { @@ -1646,7 +1735,7 @@ pub(crate) mod tests { } fn sort_preserving_merge_exec( - sort_exprs: Vec, + sort_exprs: LexOrdering, input: Arc, ) -> Arc { Arc::new(SortPreservingMergeExec::new(sort_exprs, input)) @@ -1721,16 +1810,25 @@ pub(crate) mod tests { /// * `TARGET_PARTITIONS` (optional) - number of partitions to repartition to /// * `REPARTITION_FILE_SCANS` (optional) - if true, will repartition file scans /// * `REPARTITION_FILE_MIN_SIZE` (optional) - minimum file size to repartition + /// * `PREFER_EXISTING_UNION` (optional) - if true, will not attempt to convert Union to Interleave macro_rules! assert_optimized { ($EXPECTED_LINES: expr, $PLAN: expr, $FIRST_ENFORCE_DIST: expr) => { - assert_optimized!($EXPECTED_LINES, $PLAN, $FIRST_ENFORCE_DIST, false, 10, false, 1024); + assert_optimized!($EXPECTED_LINES, $PLAN, $FIRST_ENFORCE_DIST, false, 10, false, 1024, false); }; ($EXPECTED_LINES: expr, $PLAN: expr, $FIRST_ENFORCE_DIST: expr, $PREFER_EXISTING_SORT: expr) => { - assert_optimized!($EXPECTED_LINES, $PLAN, $FIRST_ENFORCE_DIST, $PREFER_EXISTING_SORT, 10, false, 1024); + assert_optimized!($EXPECTED_LINES, $PLAN, $FIRST_ENFORCE_DIST, $PREFER_EXISTING_SORT, 10, false, 1024, false); + }; + + ($EXPECTED_LINES: expr, $PLAN: expr, $FIRST_ENFORCE_DIST: expr, $PREFER_EXISTING_SORT: expr, $PREFER_EXISTING_UNION: expr) => { + assert_optimized!($EXPECTED_LINES, $PLAN, $FIRST_ENFORCE_DIST, $PREFER_EXISTING_SORT, 10, false, 1024, $PREFER_EXISTING_UNION); }; ($EXPECTED_LINES: expr, $PLAN: expr, $FIRST_ENFORCE_DIST: expr, $PREFER_EXISTING_SORT: expr, $TARGET_PARTITIONS: expr, $REPARTITION_FILE_SCANS: expr, $REPARTITION_FILE_MIN_SIZE: expr) => { + assert_optimized!($EXPECTED_LINES, $PLAN, $FIRST_ENFORCE_DIST, $PREFER_EXISTING_SORT, $TARGET_PARTITIONS, $REPARTITION_FILE_SCANS, $REPARTITION_FILE_MIN_SIZE, false); + }; + + ($EXPECTED_LINES: expr, $PLAN: expr, $FIRST_ENFORCE_DIST: expr, $PREFER_EXISTING_SORT: expr, $TARGET_PARTITIONS: expr, $REPARTITION_FILE_SCANS: expr, $REPARTITION_FILE_MIN_SIZE: expr, $PREFER_EXISTING_UNION: expr) => { let expected_lines: Vec<&str> = $EXPECTED_LINES.iter().map(|s| *s).collect(); let mut config = ConfigOptions::new(); @@ -1738,6 +1836,9 @@ pub(crate) mod tests { config.optimizer.repartition_file_scans = $REPARTITION_FILE_SCANS; config.optimizer.repartition_file_min_size = $REPARTITION_FILE_MIN_SIZE; config.optimizer.prefer_existing_sort = $PREFER_EXISTING_SORT; + config.optimizer.prefer_existing_union = $PREFER_EXISTING_UNION; + // Use a small batch size, to trigger RoundRobin in tests + config.execution.batch_size = 1; // NOTE: These tests verify the joint `EnforceDistribution` + `EnforceSorting` cascade // because they were written prior to the separation of `BasicEnforcement` into @@ -1856,6 +1957,7 @@ pub(crate) mod tests { JoinType::Full, JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightSemi, JoinType::RightAnti, ]; @@ -1878,7 +1980,8 @@ pub(crate) mod tests { | JoinType::Right | JoinType::Full | JoinType::LeftSemi - | JoinType::LeftAnti => { + | JoinType::LeftAnti + | JoinType::LeftMark => { // Join on (a == c) let top_join_on = vec![( Arc::new(Column::new_with_schema("a", &join.schema()).unwrap()) @@ -1896,7 +1999,7 @@ pub(crate) mod tests { let expected = match join_type { // Should include 3 RepartitionExecs - JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => vec![ + JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => vec![ top_join_plan.as_str(), join_plan.as_str(), "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", @@ -1995,7 +2098,7 @@ pub(crate) mod tests { assert_optimized!(expected, top_join.clone(), true); assert_optimized!(expected, top_join, false); } - JoinType::LeftSemi | JoinType::LeftAnti => {} + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => {} } } @@ -2701,16 +2804,16 @@ pub(crate) mod tests { vec![ top_join_plan.as_str(), join_plan.as_str(), - "SortExec: expr=[a@0 ASC]", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortExec: expr=[b1@1 ASC]", + "SortExec: expr=[b1@1 ASC], preserve_partitioning=[true]", "RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortExec: expr=[c@2 ASC]", + "SortExec: expr=[c@2 ASC], preserve_partitioning=[true]", "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", @@ -2727,19 +2830,19 @@ pub(crate) mod tests { _ => vec![ top_join_plan.as_str(), // Below 2 operators are differences introduced, when join mode is changed - "SortExec: expr=[a@0 ASC]", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", join_plan.as_str(), - "SortExec: expr=[a@0 ASC]", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortExec: expr=[b1@1 ASC]", + "SortExec: expr=[b1@1 ASC], preserve_partitioning=[true]", "RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortExec: expr=[c@2 ASC]", + "SortExec: expr=[c@2 ASC], preserve_partitioning=[true]", "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", @@ -2755,16 +2858,16 @@ pub(crate) mod tests { join_plan.as_str(), "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "SortExec: expr=[a@0 ASC]", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", "RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@1 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "SortExec: expr=[b1@1 ASC]", + "SortExec: expr=[b1@1 ASC], preserve_partitioning=[false]", "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=c@2 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "SortExec: expr=[c@2 ASC]", + "SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", ], // Should include 8 RepartitionExecs (4 hash, 8 round-robin), 4 SortExecs @@ -2781,21 +2884,21 @@ pub(crate) mod tests { // Below 4 operators are differences introduced, when join mode is changed "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "SortExec: expr=[a@0 ASC]", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", "CoalescePartitionsExec", join_plan.as_str(), "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "SortExec: expr=[a@0 ASC]", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", "RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@1 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "SortExec: expr=[b1@1 ASC]", + "SortExec: expr=[b1@1 ASC], preserve_partitioning=[false]", "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=c@2 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "SortExec: expr=[c@2 ASC]", + "SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", ], }; @@ -2824,16 +2927,16 @@ pub(crate) mod tests { JoinType::Inner | JoinType::Right => vec![ top_join_plan.as_str(), join_plan.as_str(), - "SortExec: expr=[a@0 ASC]", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortExec: expr=[b1@1 ASC]", + "SortExec: expr=[b1@1 ASC], preserve_partitioning=[true]", "RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortExec: expr=[c@2 ASC]", + "SortExec: expr=[c@2 ASC], preserve_partitioning=[true]", "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", @@ -2841,19 +2944,19 @@ pub(crate) mod tests { // Should include 7 RepartitionExecs (4 hash, 3 round-robin) and 4 SortExecs JoinType::Left | JoinType::Full => vec![ top_join_plan.as_str(), - "SortExec: expr=[b1@6 ASC]", + "SortExec: expr=[b1@6 ASC], preserve_partitioning=[true]", "RepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=10", join_plan.as_str(), - "SortExec: expr=[a@0 ASC]", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortExec: expr=[b1@1 ASC]", + "SortExec: expr=[b1@1 ASC], preserve_partitioning=[true]", "RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortExec: expr=[c@2 ASC]", + "SortExec: expr=[c@2 ASC], preserve_partitioning=[true]", "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", @@ -2870,16 +2973,16 @@ pub(crate) mod tests { join_plan.as_str(), "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "SortExec: expr=[a@0 ASC]", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", "RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@1 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "SortExec: expr=[b1@1 ASC]", + "SortExec: expr=[b1@1 ASC], preserve_partitioning=[false]", "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=c@2 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "SortExec: expr=[c@2 ASC]", + "SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", ], // Should include 8 RepartitionExecs (4 of them preserves order) and 4 SortExecs @@ -2887,21 +2990,21 @@ pub(crate) mod tests { top_join_plan.as_str(), "RepartitionExec: partitioning=Hash([b1@6], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@6 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "SortExec: expr=[b1@6 ASC]", + "SortExec: expr=[b1@6 ASC], preserve_partitioning=[false]", "CoalescePartitionsExec", join_plan.as_str(), "RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "SortExec: expr=[a@0 ASC]", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", "RepartitionExec: partitioning=Hash([b1@1], 10), input_partitions=10, preserve_order=true, sort_exprs=b1@1 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "SortExec: expr=[b1@1 ASC]", + "SortExec: expr=[b1@1 ASC], preserve_partitioning=[false]", "ProjectionExec: expr=[a@0 as a1, b@1 as b1, c@2 as c1, d@3 as d1, e@4 as e1]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", "RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10, preserve_order=true, sort_exprs=c@2 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "SortExec: expr=[c@2 ASC]", + "SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", ], // this match arm cannot be reached @@ -2970,7 +3073,7 @@ pub(crate) mod tests { // Only two RepartitionExecs added let expected = &[ "SortMergeJoin: join_type=Inner, on=[(b3@1, b2@1), (a3@0, a2@0)]", - "SortExec: expr=[b3@1 ASC,a3@0 ASC]", + "SortExec: expr=[b3@1 ASC, a3@0 ASC], preserve_partitioning=[true]", "ProjectionExec: expr=[a1@0 as a3, b1@1 as b3]", "ProjectionExec: expr=[a1@1 as a1, b1@0 as b1]", "AggregateExec: mode=FinalPartitioned, gby=[b1@0 as b1, a1@1 as a1], aggr=[]", @@ -2978,7 +3081,7 @@ pub(crate) mod tests { "AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[]", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "SortExec: expr=[b2@1 ASC,a2@0 ASC]", + "SortExec: expr=[b2@1 ASC, a2@0 ASC], preserve_partitioning=[true]", "ProjectionExec: expr=[a@1 as a2, b@0 as b2]", "AggregateExec: mode=FinalPartitioned, gby=[b@0 as b, a@1 as a], aggr=[]", "RepartitionExec: partitioning=Hash([b@0, a@1], 10), input_partitions=10", @@ -2990,9 +3093,9 @@ pub(crate) mod tests { let expected_first_sort_enforcement = &[ "SortMergeJoin: join_type=Inner, on=[(b3@1, b2@1), (a3@0, a2@0)]", - "RepartitionExec: partitioning=Hash([b3@1, a3@0], 10), input_partitions=10, preserve_order=true, sort_exprs=b3@1 ASC,a3@0 ASC", + "RepartitionExec: partitioning=Hash([b3@1, a3@0], 10), input_partitions=10, preserve_order=true, sort_exprs=b3@1 ASC, a3@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "SortExec: expr=[b3@1 ASC,a3@0 ASC]", + "SortExec: expr=[b3@1 ASC, a3@0 ASC], preserve_partitioning=[false]", "CoalescePartitionsExec", "ProjectionExec: expr=[a1@0 as a3, b1@1 as b3]", "ProjectionExec: expr=[a1@1 as a1, b1@0 as b1]", @@ -3001,9 +3104,9 @@ pub(crate) mod tests { "AggregateExec: mode=Partial, gby=[b@1 as b1, a@0 as a1], aggr=[]", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", - "RepartitionExec: partitioning=Hash([b2@1, a2@0], 10), input_partitions=10, preserve_order=true, sort_exprs=b2@1 ASC,a2@0 ASC", + "RepartitionExec: partitioning=Hash([b2@1, a2@0], 10), input_partitions=10, preserve_order=true, sort_exprs=b2@1 ASC, a2@0 ASC", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "SortExec: expr=[b2@1 ASC,a2@0 ASC]", + "SortExec: expr=[b2@1 ASC, a2@0 ASC], preserve_partitioning=[false]", "CoalescePartitionsExec", "ProjectionExec: expr=[a@1 as a2, b@0 as b2]", "AggregateExec: mode=FinalPartitioned, gby=[b@0 as b, a@1 as a], aggr=[]", @@ -3021,10 +3124,10 @@ pub(crate) mod tests { fn merge_does_not_need_sort() -> Result<()> { // see https://github.com/apache/datafusion/issues/4331 let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("a", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); // Scan some sorted parquet files let exec = parquet_exec_multiple_sorted(vec![sort_key.clone()]); @@ -3050,7 +3153,7 @@ pub(crate) mod tests { // hence in this case ordering lost during CoalescePartitionsExec and re-introduced with // SortExec at the top. let expected = &[ - "SortExec: expr=[a@0 ASC]", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", "CoalescePartitionsExec", "CoalesceBatchesExec: target_batch_size=4096", "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", @@ -3097,7 +3200,67 @@ pub(crate) mod tests { "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", ]; assert_optimized!(expected, plan.clone(), true); - assert_optimized!(expected, plan, false); + assert_optimized!(expected, plan.clone(), false); + + Ok(()) + } + + #[test] + fn union_not_to_interleave() -> Result<()> { + // group by (a as a1) + let left = aggregate_exec_with_alias( + parquet_exec(), + vec![("a".to_string(), "a1".to_string())], + ); + // group by (a as a2) + let right = aggregate_exec_with_alias( + parquet_exec(), + vec![("a".to_string(), "a1".to_string())], + ); + + // Union + let plan = Arc::new(UnionExec::new(vec![left, right])); + + // final agg + let plan = + aggregate_exec_with_alias(plan, vec![("a1".to_string(), "a2".to_string())]); + + // Only two RepartitionExecs added, no final RepartitionExec required + let expected = &[ + "AggregateExec: mode=FinalPartitioned, gby=[a2@0 as a2], aggr=[]", + "RepartitionExec: partitioning=Hash([a2@0], 10), input_partitions=20", + "AggregateExec: mode=Partial, gby=[a1@0 as a2], aggr=[]", + "UnionExec", + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[]", + "RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", + "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[]", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + "AggregateExec: mode=FinalPartitioned, gby=[a1@0 as a1], aggr=[]", + "RepartitionExec: partitioning=Hash([a1@0], 10), input_partitions=10", + "AggregateExec: mode=Partial, gby=[a@0 as a1], aggr=[]", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", + ]; + // no sort in the plan but since we need it as a parameter, make it default false + let prefer_existing_sort = false; + let first_enforce_distribution = true; + let prefer_existing_union = true; + + assert_optimized!( + expected, + plan.clone(), + first_enforce_distribution, + prefer_existing_sort, + prefer_existing_union + ); + assert_optimized!( + expected, + plan, + !first_enforce_distribution, + prefer_existing_sort, + prefer_existing_union + ); Ok(()) } @@ -3163,17 +3326,17 @@ pub(crate) mod tests { #[test] fn repartition_sorted_limit() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let plan = limit_exec(sort_exec(sort_key, parquet_exec(), false)); let expected = &[ "GlobalLimitExec: skip=0, fetch=100", "LocalLimitExec: fetch=100", // data is sorted so can't repartition here - "SortExec: expr=[c@2 ASC]", + "SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", ]; assert_optimized!(expected, plan.clone(), true); @@ -3185,10 +3348,10 @@ pub(crate) mod tests { #[test] fn repartition_sorted_limit_with_filter() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let plan = sort_required_exec_with_req( filter_exec(sort_exec(sort_key.clone(), parquet_exec(), false)), sort_key, @@ -3200,7 +3363,7 @@ pub(crate) mod tests { // We can use repartition here, ordering requirement by SortRequiredExec // is still satisfied. "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "SortExec: expr=[c@2 ASC]", + "SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", ]; @@ -3264,15 +3427,15 @@ pub(crate) mod tests { fn repartition_through_sort_preserving_merge() -> Result<()> { // sort preserving merge with non-sorted input let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let plan = sort_preserving_merge_exec(sort_key, parquet_exec()); // need resort as the data was not sorted correctly let expected = &[ - "SortExec: expr=[c@2 ASC]", + "SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", ]; assert_optimized!(expected, plan.clone(), true); @@ -3285,10 +3448,10 @@ pub(crate) mod tests { fn repartition_ignores_sort_preserving_merge() -> Result<()> { // sort preserving merge already sorted input, let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let plan = sort_preserving_merge_exec( sort_key.clone(), parquet_exec_multiple_sorted(vec![sort_key]), @@ -3304,7 +3467,7 @@ pub(crate) mod tests { assert_optimized!(expected, plan.clone(), true); let expected = &[ - "SortExec: expr=[c@2 ASC]", + "SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", "CoalescePartitionsExec", "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", ]; @@ -3317,10 +3480,10 @@ pub(crate) mod tests { fn repartition_ignores_sort_preserving_merge_with_union() -> Result<()> { // 2 sorted parquet files unioned (partitions are concatenated, sort is preserved) let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let input = union_exec(vec![parquet_exec_with_sort(vec![sort_key.clone()]); 2]); let plan = sort_preserving_merge_exec(sort_key, input); @@ -3335,7 +3498,7 @@ pub(crate) mod tests { assert_optimized!(expected, plan.clone(), true); let expected = &[ - "SortExec: expr=[c@2 ASC]", + "SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", "CoalescePartitionsExec", "UnionExec", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", @@ -3351,10 +3514,10 @@ pub(crate) mod tests { // SortRequired // Parquet(sorted) let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("d", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let plan = sort_required_exec_with_req( filter_exec(parquet_exec_with_sort(vec![sort_key.clone()])), sort_key, @@ -3386,10 +3549,10 @@ pub(crate) mod tests { // Parquet(unsorted) let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let input1 = sort_required_exec_with_req( parquet_exec_with_sort(vec![sort_key.clone()]), sort_key, @@ -3428,15 +3591,15 @@ pub(crate) mod tests { )]; // non sorted input let proj = Arc::new(ProjectionExec::try_new(proj_exprs, parquet_exec())?); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("sum", &proj.schema()).unwrap(), options: SortOptions::default(), - }]; + }]); let plan = sort_preserving_merge_exec(sort_key, proj); let expected = &[ "SortPreservingMergeExec: [sum@0 ASC]", - "SortExec: expr=[sum@0 ASC]", + "SortExec: expr=[sum@0 ASC], preserve_partitioning=[true]", // Since this projection is not trivial, increasing parallelism is beneficial "ProjectionExec: expr=[a@0 + b@1 as sum]", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", @@ -3446,7 +3609,7 @@ pub(crate) mod tests { assert_optimized!(expected, plan.clone(), true); let expected_first_sort_enforcement = &[ - "SortExec: expr=[sum@0 ASC]", + "SortExec: expr=[sum@0 ASC], preserve_partitioning=[false]", "CoalescePartitionsExec", // Since this projection is not trivial, increasing parallelism is beneficial "ProjectionExec: expr=[a@0 + b@1 as sum]", @@ -3461,10 +3624,10 @@ pub(crate) mod tests { #[test] fn repartition_ignores_transitively_with_projection() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let alias = vec![ ("a".to_string(), "a".to_string()), ("b".to_string(), "b".to_string()), @@ -3494,11 +3657,15 @@ pub(crate) mod tests { #[test] fn repartition_transitively_past_sort_with_projection() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; - let alias = vec![("a".to_string(), "a".to_string())]; + }]); + let alias = vec![ + ("a".to_string(), "a".to_string()), + ("b".to_string(), "b".to_string()), + ("c".to_string(), "c".to_string()), + ]; let plan = sort_preserving_merge_exec( sort_key.clone(), sort_exec( @@ -3509,9 +3676,9 @@ pub(crate) mod tests { ); let expected = &[ - "SortExec: expr=[c@2 ASC]", + "SortExec: expr=[c@2 ASC], preserve_partitioning=[true]", // Since this projection is trivial, increasing parallelism is not beneficial - "ProjectionExec: expr=[a@0 as a]", + "ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c]", "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", ]; assert_optimized!(expected, plan.clone(), true); @@ -3523,15 +3690,15 @@ pub(crate) mod tests { #[test] fn repartition_transitively_past_sort_with_filter() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("a", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let plan = sort_exec(sort_key, filter_exec(parquet_exec()), false); let expected = &[ "SortPreservingMergeExec: [a@0 ASC]", - "SortExec: expr=[a@0 ASC]", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", // Expect repartition on the input to the sort (as it can benefit from additional parallelism) "FilterExec: c@2 = 0", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", @@ -3541,7 +3708,7 @@ pub(crate) mod tests { assert_optimized!(expected, plan.clone(), true); let expected_first_sort_enforcement = &[ - "SortExec: expr=[a@0 ASC]", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", "CoalescePartitionsExec", "FilterExec: c@2 = 0", // Expect repartition on the input of the filter (as it can benefit from additional parallelism) @@ -3557,10 +3724,10 @@ pub(crate) mod tests { #[cfg(feature = "parquet")] fn repartition_transitively_past_sort_with_projection_and_filter() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("a", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let plan = sort_exec( sort_key, projection_exec_with_alias( @@ -3577,7 +3744,7 @@ pub(crate) mod tests { let expected = &[ "SortPreservingMergeExec: [a@0 ASC]", // Expect repartition on the input to the sort (as it can benefit from additional parallelism) - "SortExec: expr=[a@0 ASC]", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", "ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c]", "FilterExec: c@2 = 0", // repartition is lowest down @@ -3588,7 +3755,7 @@ pub(crate) mod tests { assert_optimized!(expected, plan.clone(), true); let expected_first_sort_enforcement = &[ - "SortExec: expr=[a@0 ASC]", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", "CoalescePartitionsExec", "ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c]", "FilterExec: c@2 = 0", @@ -3627,10 +3794,10 @@ pub(crate) mod tests { #[test] fn parallelization_multiple_files() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("a", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let plan = filter_exec(parquet_exec_multiple_sorted(vec![sort_key.clone()])); let plan = sort_required_exec_with_req(plan, sort_key); @@ -3651,7 +3818,8 @@ pub(crate) mod tests { true, target_partitions, true, - repartition_size + repartition_size, + false ); let expected = [ @@ -3668,7 +3836,8 @@ pub(crate) mod tests { true, target_partitions, true, - repartition_size + repartition_size, + false ); Ok(()) @@ -3709,29 +3878,26 @@ pub(crate) mod tests { }; let plan = aggregate_exec_with_alias( - Arc::new(CsvExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), - file_schema: schema(), - file_groups: vec![vec![PartitionedFile::new( - "x".to_string(), - 100, - )]], - statistics: Statistics::new_unknown(&schema()), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - }, - false, - b',', - b'"', - None, - compression_type, - )), + Arc::new( + CsvExec::builder( + FileScanConfig::new( + ObjectStoreUrl::parse("test:///").unwrap(), + schema(), + ) + .with_file(PartitionedFile::new("x".to_string(), 100)), + ) + .with_has_header(false) + .with_delimeter(b',') + .with_quote(b'"') + .with_escape(None) + .with_comment(None) + .with_newlines_in_values(false) + .with_file_compression_type(compression_type) + .build(), + ), vec![("a".to_string(), "a".to_string())], ); - assert_optimized!(expected, plan, true, false, 2, true, 10); + assert_optimized!(expected, plan, true, false, 2, true, 10, false); } Ok(()) } @@ -3741,7 +3907,7 @@ pub(crate) mod tests { let alias = vec![("a".to_string(), "a".to_string())]; let plan_parquet = aggregate_exec_with_alias(parquet_exec_multiple(), alias.clone()); - let plan_csv = aggregate_exec_with_alias(csv_exec_multiple(), alias.clone()); + let plan_csv = aggregate_exec_with_alias(csv_exec_multiple(), alias); let expected_parquet = [ "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", @@ -3767,7 +3933,7 @@ pub(crate) mod tests { let alias = vec![("a".to_string(), "a".to_string())]; let plan_parquet = aggregate_exec_with_alias(parquet_exec_multiple(), alias.clone()); - let plan_csv = aggregate_exec_with_alias(csv_exec_multiple(), alias.clone()); + let plan_csv = aggregate_exec_with_alias(csv_exec_multiple(), alias); let expected_parquet = [ "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", @@ -3792,18 +3958,18 @@ pub(crate) mod tests { #[test] fn parallelization_sorted_limit() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let plan_parquet = limit_exec(sort_exec(sort_key.clone(), parquet_exec(), false)); - let plan_csv = limit_exec(sort_exec(sort_key.clone(), csv_exec(), false)); + let plan_csv = limit_exec(sort_exec(sort_key, csv_exec(), false)); let expected_parquet = &[ "GlobalLimitExec: skip=0, fetch=100", "LocalLimitExec: fetch=100", // data is sorted so can't repartition here - "SortExec: expr=[c@2 ASC]", + "SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", // Doesn't parallelize for SortExec without preserve_partitioning "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", ]; @@ -3811,7 +3977,7 @@ pub(crate) mod tests { "GlobalLimitExec: skip=0, fetch=100", "LocalLimitExec: fetch=100", // data is sorted so can't repartition here - "SortExec: expr=[c@2 ASC]", + "SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", // Doesn't parallelize for SortExec without preserve_partitioning "CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", ]; @@ -3824,17 +3990,16 @@ pub(crate) mod tests { #[test] fn parallelization_limit_with_filter() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let plan_parquet = limit_exec(filter_exec(sort_exec( sort_key.clone(), parquet_exec(), false, ))); - let plan_csv = - limit_exec(filter_exec(sort_exec(sort_key.clone(), csv_exec(), false))); + let plan_csv = limit_exec(filter_exec(sort_exec(sort_key, csv_exec(), false))); let expected_parquet = &[ "GlobalLimitExec: skip=0, fetch=100", @@ -3844,7 +4009,7 @@ pub(crate) mod tests { // even though data is sorted, we can use repartition here. Since // ordering is not used in subsequent stages anyway. "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "SortExec: expr=[c@2 ASC]", + "SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", // SortExec doesn't benefit from input partitioning "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e]", ]; @@ -3856,7 +4021,7 @@ pub(crate) mod tests { // even though data is sorted, we can use repartition here. Since // ordering is not used in subsequent stages anyway. "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - "SortExec: expr=[c@2 ASC]", + "SortExec: expr=[c@2 ASC], preserve_partitioning=[false]", // SortExec doesn't benefit from input partitioning "CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", ]; @@ -3875,7 +4040,7 @@ pub(crate) mod tests { ); let plan_csv = aggregate_exec_with_alias( limit_exec(filter_exec(limit_exec(csv_exec()))), - alias.clone(), + alias, ); let expected_parquet = &[ @@ -3948,10 +4113,10 @@ pub(crate) mod tests { #[test] fn parallelization_prior_to_sort_preserving_merge() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); // sort preserving merge already sorted input, let plan_parquet = sort_preserving_merge_exec( sort_key.clone(), @@ -3959,7 +4124,7 @@ pub(crate) mod tests { ); let plan_csv = sort_preserving_merge_exec( sort_key.clone(), - csv_exec_with_sort(vec![sort_key.clone()]), + csv_exec_with_sort(vec![sort_key]), ); // parallelization is not beneficial for SortPreservingMerge @@ -3978,16 +4143,16 @@ pub(crate) mod tests { #[test] fn parallelization_sort_preserving_merge_with_union() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); // 2 sorted parquet files unioned (partitions are concatenated, sort is preserved) let input_parquet = union_exec(vec![parquet_exec_with_sort(vec![sort_key.clone()]); 2]); let input_csv = union_exec(vec![csv_exec_with_sort(vec![sort_key.clone()]); 2]); let plan_parquet = sort_preserving_merge_exec(sort_key.clone(), input_parquet); - let plan_csv = sort_preserving_merge_exec(sort_key.clone(), input_csv); + let plan_csv = sort_preserving_merge_exec(sort_key, input_csv); // should not repartition (union doesn't benefit from increased parallelism) // should not sort (as the data was already sorted) @@ -4012,10 +4177,10 @@ pub(crate) mod tests { #[test] fn parallelization_does_not_benefit() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); // SortRequired // Parquet(sorted) let plan_parquet = sort_required_exec_with_req( @@ -4046,10 +4211,10 @@ pub(crate) mod tests { fn parallelization_ignores_transitively_with_projection_parquet() -> Result<()> { // sorted input let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); //Projection(a as a2, b as b2) let alias_pairs: Vec<(String, String)> = vec![ @@ -4057,13 +4222,13 @@ pub(crate) mod tests { ("c".to_string(), "c2".to_string()), ]; let proj_parquet = projection_exec_with_alias( - parquet_exec_with_sort(vec![sort_key.clone()]), - alias_pairs.clone(), + parquet_exec_with_sort(vec![sort_key]), + alias_pairs, ); - let sort_key_after_projection = vec![PhysicalSortExpr { + let sort_key_after_projection = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c2", &proj_parquet.schema()).unwrap(), options: SortOptions::default(), - }]; + }]); let plan_parquet = sort_preserving_merge_exec(sort_key_after_projection, proj_parquet); let expected = &[ @@ -4087,10 +4252,10 @@ pub(crate) mod tests { fn parallelization_ignores_transitively_with_projection_csv() -> Result<()> { // sorted input let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); //Projection(a as a2, b as b2) let alias_pairs: Vec<(String, String)> = vec![ @@ -4100,10 +4265,10 @@ pub(crate) mod tests { let proj_csv = projection_exec_with_alias(csv_exec_with_sort(vec![sort_key]), alias_pairs); - let sort_key_after_projection = vec![PhysicalSortExpr { + let sort_key_after_projection = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c2", &proj_csv.schema()).unwrap(), options: SortOptions::default(), - }]; + }]); let plan_csv = sort_preserving_merge_exec(sort_key_after_projection, proj_csv); let expected = &[ "SortPreservingMergeExec: [c2@1 ASC]", @@ -4150,10 +4315,10 @@ pub(crate) mod tests { #[test] fn remove_unnecessary_spm_after_filter() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = sort_preserving_merge_exec(sort_key, filter_exec(input)); @@ -4175,10 +4340,10 @@ pub(crate) mod tests { #[test] fn preserve_ordering_through_repartition() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("d", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = sort_preserving_merge_exec(sort_key, filter_exec(input)); @@ -4198,16 +4363,16 @@ pub(crate) mod tests { #[test] fn do_not_preserve_ordering_through_repartition() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("a", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = sort_preserving_merge_exec(sort_key, filter_exec(input)); let expected = &[ "SortPreservingMergeExec: [a@0 ASC]", - "SortExec: expr=[a@0 ASC]", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", "FilterExec: c@2 = 0", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", @@ -4216,7 +4381,7 @@ pub(crate) mod tests { assert_optimized!(expected, physical_plan.clone(), true); let expected = &[ - "SortExec: expr=[a@0 ASC]", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", "CoalescePartitionsExec", "FilterExec: c@2 = 0", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", @@ -4230,10 +4395,10 @@ pub(crate) mod tests { #[test] fn no_need_for_sort_after_filter() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = sort_preserving_merge_exec(sort_key, filter_exec(input)); @@ -4254,21 +4419,21 @@ pub(crate) mod tests { #[test] fn do_not_preserve_ordering_through_repartition2() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let input = parquet_exec_multiple_sorted(vec![sort_key]); - let sort_req = vec![PhysicalSortExpr { + let sort_req = LexOrdering::new(vec![PhysicalSortExpr { expr: col("a", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let physical_plan = sort_preserving_merge_exec(sort_req, filter_exec(input)); let expected = &[ "SortPreservingMergeExec: [a@0 ASC]", - "SortExec: expr=[a@0 ASC]", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", "FilterExec: c@2 = 0", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", @@ -4277,9 +4442,9 @@ pub(crate) mod tests { assert_optimized!(expected, physical_plan.clone(), true); let expected = &[ - "SortExec: expr=[a@0 ASC]", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", "CoalescePartitionsExec", - "SortExec: expr=[a@0 ASC]", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", "FilterExec: c@2 = 0", "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", "ParquetExec: file_groups={2 groups: [[x], [y]]}, projection=[a, b, c, d, e], output_ordering=[c@2 ASC]", @@ -4292,10 +4457,10 @@ pub(crate) mod tests { #[test] fn do_not_preserve_ordering_through_repartition3() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let input = parquet_exec_multiple_sorted(vec![sort_key]); let physical_plan = filter_exec(input); @@ -4313,10 +4478,10 @@ pub(crate) mod tests { #[test] fn do_not_put_sort_when_input_is_invalid() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("a", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let input = parquet_exec(); let physical_plan = sort_required_exec_with_req(filter_exec(input), sort_key); let expected = &[ @@ -4350,10 +4515,10 @@ pub(crate) mod tests { #[test] fn put_sort_when_input_is_valid() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("a", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let input = parquet_exec_multiple_sorted(vec![sort_key.clone()]); let physical_plan = sort_required_exec_with_req(filter_exec(input), sort_key); @@ -4387,13 +4552,13 @@ pub(crate) mod tests { #[test] fn do_not_add_unnecessary_hash() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let alias = vec![("a".to_string(), "a".to_string())]; let input = parquet_exec_with_sort(vec![sort_key]); - let physical_plan = aggregate_exec_with_alias(input, alias.clone()); + let physical_plan = aggregate_exec_with_alias(input, alias); let expected = &[ "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", @@ -4410,14 +4575,14 @@ pub(crate) mod tests { #[test] fn do_not_add_unnecessary_hash2() -> Result<()> { let schema = schema(); - let sort_key = vec![PhysicalSortExpr { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let alias = vec![("a".to_string(), "a".to_string())]; let input = parquet_exec_multiple_sorted(vec![sort_key]); let aggregate = aggregate_exec_with_alias(input, alias.clone()); - let physical_plan = aggregate_exec_with_alias(aggregate, alias.clone()); + let physical_plan = aggregate_exec_with_alias(aggregate, alias); let expected = &[ "AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[]", diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index a1e2f6c666dc..7b111cddc6fd 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -20,6 +20,7 @@ //! - Adds a [`SortExec`] when a requirement is not met, //! - Removes an already-existing [`SortExec`] if it is possible to prove //! that this sort is unnecessary +//! //! The rule can work on valid *and* invalid physical plans with respect to //! sorting requirements, but always produces a valid physical plan in this sense. //! @@ -49,7 +50,6 @@ use crate::physical_optimizer::utils::{ is_coalesce_partitions, is_limit, is_repartition, is_sort, is_sort_preserving_merge, is_union, is_window, }; -use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; @@ -61,7 +61,10 @@ use crate::physical_plan::{Distribution, ExecutionPlan, InputOrderMode}; use datafusion_common::plan_err; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_physical_expr::{PhysicalSortExpr, PhysicalSortRequirement}; +use datafusion_physical_expr::{Partitioning, PhysicalSortRequirement}; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexOrderingRef}; +use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::repartition::RepartitionExec; use datafusion_physical_plan::sorts::partial_sort::PartialSortExec; use datafusion_physical_plan::ExecutionPlanProperties; @@ -70,7 +73,7 @@ use itertools::izip; /// This rule inspects [`SortExec`]'s in the given physical plan and removes the /// ones it can prove unnecessary. -#[derive(Default)] +#[derive(Default, Debug)] pub struct EnforceSorting {} impl EnforceSorting { @@ -188,7 +191,7 @@ impl PhysicalOptimizerRule for EnforceSorting { // missed by the bottom-up traversal: let mut sort_pushdown = SortPushDown::new_default(updated_plan.plan); assign_initial_requirements(&mut sort_pushdown); - let adjusted = sort_pushdown.transform_down(pushdown_sorts)?.data; + let adjusted = pushdown_sorts(sort_pushdown)?; adjusted .plan @@ -229,7 +232,7 @@ fn replace_with_partial_sort( if common_prefix_length > 0 { return Ok(Arc::new( PartialSortExec::new( - sort_plan.expr().to_vec(), + LexOrdering::new(sort_plan.expr().to_vec()), sort_plan.input().clone(), common_prefix_length, ) @@ -273,14 +276,14 @@ fn parallelize_sorts( // Take the initial sort expressions and requirements let (sort_exprs, fetch) = get_sort_exprs(&requirements.plan)?; let sort_reqs = PhysicalSortRequirement::from_sort_exprs(sort_exprs); - let sort_exprs = sort_exprs.to_vec(); + let sort_exprs = LexOrdering::new(sort_exprs.to_vec()); // If there is a connection between a `CoalescePartitionsExec` and a // global sort that satisfy the requirements (i.e. intermediate // executors don't require single partition), then we can replace // the `CoalescePartitionsExec` + `SortExec` cascade with a `SortExec` // + `SortPreservingMergeExec` cascade to parallelize sorting. - requirements = remove_corresponding_coalesce_in_sub_plan(requirements)?; + requirements = remove_bottleneck_in_subplan(requirements)?; // We also need to remove the self node since `remove_corresponding_coalesce_in_sub_plan` // deals with the children and their children and so on. requirements = requirements.children.swap_remove(0); @@ -298,7 +301,7 @@ fn parallelize_sorts( } else if is_coalesce_partitions(&requirements.plan) { // There is an unnecessary `CoalescePartitionsExec` in the plan. // This will handle the recursive `CoalescePartitionsExec` plans. - requirements = remove_corresponding_coalesce_in_sub_plan(requirements)?; + requirements = remove_bottleneck_in_subplan(requirements)?; // For the removal of self node which is also a `CoalescePartitionsExec`. requirements = requirements.children.swap_remove(0); @@ -388,20 +391,33 @@ fn analyze_immediate_sort_removal( if let Some(sort_exec) = node.plan.as_any().downcast_ref::() { let sort_input = sort_exec.input(); // If this sort is unnecessary, we should remove it: - if sort_input - .equivalence_properties() - .ordering_satisfy(sort_exec.properties().output_ordering().unwrap_or(&[])) - { + if sort_input.equivalence_properties().ordering_satisfy( + sort_exec.properties().output_ordering().unwrap_or_default(), + ) { node.plan = if !sort_exec.preserve_partitioning() && sort_input.output_partitioning().partition_count() > 1 { // Replace the sort with a sort-preserving merge: - let expr = sort_exec.expr().to_vec(); + let expr = LexOrdering::new(sort_exec.expr().to_vec()); Arc::new(SortPreservingMergeExec::new(expr, sort_input.clone())) as _ } else { // Remove the sort: node.children = node.children.swap_remove(0).children; - sort_input.clone() + if let Some(fetch) = sort_exec.fetch() { + // If the sort has a fetch, we need to add a limit: + if sort_exec + .properties() + .output_partitioning() + .partition_count() + == 1 + { + Arc::new(GlobalLimitExec::new(sort_input.clone(), 0, Some(fetch))) + } else { + Arc::new(LocalLimitExec::new(sort_input.clone(), fetch)) + } + } else { + sort_input.clone() + } }; for child in node.children.iter_mut() { child.data = false; @@ -483,8 +499,11 @@ fn adjust_window_sort_removal( Ok(window_tree) } -/// Removes the [`CoalescePartitionsExec`] from the plan in `node`. -fn remove_corresponding_coalesce_in_sub_plan( +/// Removes parallelization-reducing, avoidable [`CoalescePartitionsExec`]s from +/// the plan in `node`. After the removal of such `CoalescePartitionsExec`s from +/// the plan, some of the remaining `RepartitionExec`s might become unnecessary. +/// Removes such `RepartitionExec`s from the plan as well. +fn remove_bottleneck_in_subplan( mut requirements: PlanWithCorrespondingCoalescePartitions, ) -> Result { let plan = &requirements.plan; @@ -505,15 +524,27 @@ fn remove_corresponding_coalesce_in_sub_plan( .into_iter() .map(|node| { if node.data { - remove_corresponding_coalesce_in_sub_plan(node) + remove_bottleneck_in_subplan(node) } else { Ok(node) } }) .collect::>()?; } - - requirements.update_plan_from_children() + let mut new_reqs = requirements.update_plan_from_children()?; + if let Some(repartition) = new_reqs.plan.as_any().downcast_ref::() { + let input_partitioning = repartition.input().output_partitioning(); + // We can remove this repartitioning operator if it is now a no-op: + let mut can_remove = input_partitioning.eq(repartition.partitioning()); + // We can also remove it if we ended up with an ineffective RR: + if let Partitioning::RoundRobinBatch(n_out) = repartition.partitioning() { + can_remove |= *n_out == input_partitioning.partition_count(); + } + if can_remove { + new_reqs = new_reqs.children.swap_remove(0) + } + } + Ok(new_reqs) } /// Updates child to remove the unnecessary sort below it. @@ -539,8 +570,11 @@ fn remove_corresponding_sort_from_sub_plan( requires_single_partition: bool, ) -> Result { // A `SortExec` is always at the bottom of the tree. - if is_sort(&node.plan) { - node = node.children.swap_remove(0); + if let Some(sort_exec) = node.plan.as_any().downcast_ref::() { + // Do not remove sorts with fetch: + if sort_exec.fetch().is_none() { + node = node.children.swap_remove(0); + } } else { let mut any_connection = false; let required_dist = node.plan.required_input_distribution(); @@ -567,7 +601,7 @@ fn remove_corresponding_sort_from_sub_plan( // Replace with variants that do not preserve order. if is_sort_preserving_merge(&node.plan) { node.children = node.children.swap_remove(0).children; - node.plan = node.plan.children().swap_remove(0); + node.plan = node.plan.children().swap_remove(0).clone(); } else if let Some(repartition) = node.plan.as_any().downcast_ref::() { @@ -585,7 +619,10 @@ fn remove_corresponding_sort_from_sub_plan( // `SortPreservingMergeExec` instead of a `CoalescePartitionsExec`. let plan = node.plan.clone(); let plan = if let Some(ordering) = plan.output_ordering() { - Arc::new(SortPreservingMergeExec::new(ordering.to_vec(), plan)) as _ + Arc::new(SortPreservingMergeExec::new( + LexOrdering::new(ordering.to_vec()), + plan, + )) as _ } else { Arc::new(CoalescePartitionsExec::new(plan)) as _ }; @@ -595,10 +632,10 @@ fn remove_corresponding_sort_from_sub_plan( Ok(node) } -/// Converts an [ExecutionPlan] trait object to a [PhysicalSortExpr] slice when possible. +/// Converts an [ExecutionPlan] trait object to a [LexOrderingRef] when possible. fn get_sort_exprs( sort_any: &Arc, -) -> Result<(&[PhysicalSortExpr], Option)> { +) -> Result<(LexOrderingRef, Option)> { if let Some(sort_exec) = sort_any.as_any().downcast_ref::() { Ok((sort_exec.expr(), sort_exec.fetch())) } else if let Some(spm) = sort_any.as_any().downcast_ref::() @@ -611,7 +648,6 @@ fn get_sort_exprs( #[cfg(test)] mod tests { - use super::*; use crate::physical_optimizer::enforce_distribution::EnforceDistribution; use crate::physical_optimizer::test_utils::{ @@ -620,6 +656,7 @@ mod tests { limit_exec, local_limit_exec, memory_exec, parquet_exec, parquet_exec_sorted, repartition_exec, sort_exec, sort_expr, sort_expr_options, sort_merge_join_exec, sort_preserving_merge_exec, spr_repartition_exec, union_exec, + RequirementsTestExec, }; use crate::physical_plan::{displayable, get_plan_string, Partitioning}; use crate::prelude::{SessionConfig, SessionContext}; @@ -630,6 +667,8 @@ mod tests { use datafusion_common::Result; use datafusion_expr::JoinType; use datafusion_physical_expr::expressions::{col, Column, NotExpr}; + use datafusion_physical_optimizer::PhysicalOptimizerRule; + use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use rstest::rstest; @@ -713,10 +752,7 @@ mod tests { let mut sort_pushdown = SortPushDown::new_default(updated_plan.plan); assign_initial_requirements(&mut sort_pushdown); - sort_pushdown - .transform_down(pushdown_sorts) - .data() - .and_then(check_integrity)?; + check_integrity(pushdown_sorts(sort_pushdown)?)?; // TODO: End state payloads will be checked here. } @@ -757,12 +793,12 @@ mod tests { let physical_plan = sort_exec(vec![sort_expr("nullable_col", &schema)], input); let expected_input = [ - "SortExec: expr=[nullable_col@0 ASC]", - " SortExec: expr=[non_nullable_col@1 ASC]", + "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", + " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", " MemoryExec: partitions=1, partition_sizes=[0]", ]; let expected_optimized = [ - "SortExec: expr=[nullable_col@0 ASC]", + "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " MemoryExec: partitions=1, partition_sizes=[0]", ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -811,19 +847,19 @@ mod tests { let physical_plan = bounded_window_exec("non_nullable_col", sort_exprs, filter); - let expected_input = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + let expected_input = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " FilterExec: NOT non_nullable_col@1", - " SortExec: expr=[non_nullable_col@1 ASC NULLS LAST]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[non_nullable_col@1 ASC NULLS LAST], preserve_partitioning=[false]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " CoalesceBatchesExec: target_batch_size=128", - " SortExec: expr=[non_nullable_col@1 DESC]", + " SortExec: expr=[non_nullable_col@1 DESC], preserve_partitioning=[false]", " MemoryExec: partitions=1, partition_sizes=[0]"]; - let expected_optimized = ["WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + let expected_optimized = ["WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", " FilterExec: NOT non_nullable_col@1", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " CoalesceBatchesExec: target_batch_size=128", - " SortExec: expr=[non_nullable_col@1 DESC]", + " SortExec: expr=[non_nullable_col@1 DESC], preserve_partitioning=[false]", " MemoryExec: partitions=1, partition_sizes=[0]"]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -844,7 +880,7 @@ mod tests { " MemoryExec: partitions=1, partition_sizes=[0]", ]; let expected_optimized = [ - "SortExec: expr=[nullable_col@0 ASC]", + "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " MemoryExec: partitions=1, partition_sizes=[0]", ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -865,13 +901,13 @@ mod tests { let physical_plan = sort_preserving_merge_exec(sort_exprs, sort); let expected_input = [ "SortPreservingMergeExec: [nullable_col@0 ASC]", - " SortExec: expr=[nullable_col@0 ASC]", + " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " SortPreservingMergeExec: [nullable_col@0 ASC]", - " SortExec: expr=[nullable_col@0 ASC]", + " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " MemoryExec: partitions=1, partition_sizes=[0]", ]; let expected_optimized = [ - "SortExec: expr=[nullable_col@0 ASC]", + "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " MemoryExec: partitions=1, partition_sizes=[0]", ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -901,11 +937,11 @@ mod tests { let expected_input = [ "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[nullable_col@0 ASC]", - " SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", + " SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " SortPreservingMergeExec: [non_nullable_col@1 ASC]", - " SortExec: expr=[non_nullable_col@1 ASC]", + " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", " MemoryExec: partitions=1, partition_sizes=[0]", ]; @@ -927,10 +963,10 @@ mod tests { let sort = sort_exec(sort_exprs.clone(), source); let spm = sort_preserving_merge_exec(sort_exprs, sort); - let sort_exprs = vec![ + let sort_exprs = LexOrdering::new(vec![ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; + ]); let repartition_exec = repartition_exec(spm); let sort2 = Arc::new( SortExec::new(sort_exprs.clone(), repartition_exec) @@ -945,11 +981,11 @@ mod tests { // it with a `CoalescePartitionsExec` instead of directly removing it. let expected_input = [ "AggregateExec: mode=Final, gby=[], aggr=[]", - " SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[true]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " SortPreservingMergeExec: [non_nullable_col@1 ASC]", - " SortExec: expr=[non_nullable_col@1 ASC]", + " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", " MemoryExec: partitions=1, partition_sizes=[0]", ]; @@ -972,7 +1008,7 @@ mod tests { let source2 = repartition_exec(memory_exec(&schema)); let union = union_exec(vec![source1, source2]); - let sort_exprs = vec![sort_expr("non_nullable_col", &schema)]; + let sort_exprs = LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]); // let sort = sort_exec(sort_exprs.clone(), union); let sort = Arc::new( SortExec::new(sort_exprs.clone(), union).with_preserve_partitioning(true), @@ -995,18 +1031,18 @@ mod tests { // When removing a `SortPreservingMergeExec`, make sure that partitioning // requirements are not violated. In some cases, we may need to replace // it with a `CoalescePartitionsExec` instead of directly removing it. - let expected_input = ["SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", + let expected_input = ["SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " FilterExec: NOT non_nullable_col@1", " SortPreservingMergeExec: [non_nullable_col@1 ASC]", - " SortExec: expr=[non_nullable_col@1 ASC]", + " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[true]", " UnionExec", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " MemoryExec: partitions=1, partition_sizes=[0]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " MemoryExec: partitions=1, partition_sizes=[0]"]; - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", + let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[true]", " FilterExec: NOT non_nullable_col@1", " UnionExec", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", @@ -1033,7 +1069,7 @@ mod tests { let join = hash_join_exec(left_input, right_input, on, None, &JoinType::Inner)?; let physical_plan = sort_exec(vec![sort_expr("a", &join.schema())], join); - let expected_input = ["SortExec: expr=[a@2 ASC]", + let expected_input = ["SortExec: expr=[a@2 ASC], preserve_partitioning=[false]", " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col_a@0, c@2)]", " MemoryExec: partitions=1, partition_sizes=[0]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]"]; @@ -1046,6 +1082,136 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_remove_unnecessary_sort6() -> Result<()> { + let schema = create_test_schema()?; + let source = memory_exec(&schema); + let input = Arc::new( + SortExec::new( + LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]), + source, + ) + .with_fetch(Some(2)), + ); + let physical_plan = sort_exec( + vec![ + sort_expr("non_nullable_col", &schema), + sort_expr("nullable_col", &schema), + ], + input, + ); + + let expected_input = [ + "SortExec: expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false]", + " SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", + " MemoryExec: partitions=1, partition_sizes=[0]", + ]; + let expected_optimized = [ + "SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false]", + " MemoryExec: partitions=1, partition_sizes=[0]", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); + + Ok(()) + } + + #[tokio::test] + async fn test_remove_unnecessary_sort7() -> Result<()> { + let schema = create_test_schema()?; + let source = memory_exec(&schema); + let input = Arc::new(SortExec::new( + LexOrdering::new(vec![ + sort_expr("non_nullable_col", &schema), + sort_expr("nullable_col", &schema), + ]), + source, + )); + + let physical_plan = Arc::new( + SortExec::new( + LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]), + input, + ) + .with_fetch(Some(2)), + ) as Arc; + + let expected_input = [ + "SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", + " SortExec: expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false]", + " MemoryExec: partitions=1, partition_sizes=[0]", + ]; + let expected_optimized = [ + "GlobalLimitExec: skip=0, fetch=2", + " SortExec: expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false]", + " MemoryExec: partitions=1, partition_sizes=[0]", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); + + Ok(()) + } + + #[tokio::test] + async fn test_remove_unnecessary_sort8() -> Result<()> { + let schema = create_test_schema()?; + let source = memory_exec(&schema); + let input = Arc::new(SortExec::new( + LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]), + source, + )); + let limit = Arc::new(LocalLimitExec::new(input, 2)); + let physical_plan = sort_exec( + vec![ + sort_expr("non_nullable_col", &schema), + sort_expr("nullable_col", &schema), + ], + limit, + ); + + let expected_input = [ + "SortExec: expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false]", + " LocalLimitExec: fetch=2", + " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", + " MemoryExec: partitions=1, partition_sizes=[0]", + ]; + let expected_optimized = [ + "LocalLimitExec: fetch=2", + " SortExec: TopK(fetch=2), expr=[non_nullable_col@1 ASC, nullable_col@0 ASC], preserve_partitioning=[false]", + " MemoryExec: partitions=1, partition_sizes=[0]", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); + + Ok(()) + } + + #[tokio::test] + async fn test_do_not_pushdown_through_limit() -> Result<()> { + let schema = create_test_schema()?; + let source = memory_exec(&schema); + // let input = sort_exec(vec![sort_expr("non_nullable_col", &schema)], source); + let input = Arc::new(SortExec::new( + LexOrdering::new(vec![sort_expr("non_nullable_col", &schema)]), + source, + )); + let limit = Arc::new(GlobalLimitExec::new(input, 0, Some(5))) as _; + let physical_plan = sort_exec(vec![sort_expr("nullable_col", &schema)], limit); + + let expected_input = [ + "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", + " GlobalLimitExec: skip=0, fetch=5", + " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", + " MemoryExec: partitions=1, partition_sizes=[0]", + ]; + let expected_optimized = [ + "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", + " GlobalLimitExec: skip=0, fetch=5", + " SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", + " MemoryExec: partitions=1, partition_sizes=[0]", + ]; + assert_optimized!(expected_input, expected_optimized, physical_plan, true); + + Ok(()) + } + #[tokio::test] async fn test_remove_unnecessary_spm1() -> Result<()> { let schema = create_test_schema()?; @@ -1068,7 +1234,7 @@ mod tests { " MemoryExec: partitions=1, partition_sizes=[0]", ]; let expected_optimized = [ - "SortExec: expr=[nullable_col@0 ASC]", + "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " MemoryExec: partitions=1, partition_sizes=[0]", ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -1095,24 +1261,24 @@ mod tests { let repartition = repartition_exec(union); let physical_plan = sort_preserving_merge_exec(sort_exprs, repartition); - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", " UnionExec", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", " GlobalLimitExec: skip=0, fetch=100", " LocalLimitExec: fetch=100", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; // We should keep the bottom `SortExec`. - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", + let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[true]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=2", " UnionExec", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", " GlobalLimitExec: skip=0, fetch=100", " LocalLimitExec: fetch=100", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -1130,12 +1296,12 @@ mod tests { let sort = sort_exec(vec![sort_exprs[0].clone()], source); let physical_plan = sort_preserving_merge_exec(sort_exprs, sort); let expected_input = [ - "SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", - " SortExec: expr=[nullable_col@0 ASC]", + "SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", + " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " MemoryExec: partitions=1, partition_sizes=[0]", ]; let expected_optimized = [ - "SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", + "SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " MemoryExec: partitions=1, partition_sizes=[0]", ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -1158,12 +1324,12 @@ mod tests { let expected_input = [ "SortPreservingMergeExec: [non_nullable_col@1 ASC]", - " SortExec: expr=[nullable_col@0 ASC]", - " SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", + " SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " MemoryExec: partitions=1, partition_sizes=[0]", ]; let expected_optimized = [ - "SortExec: expr=[non_nullable_col@1 ASC]", + "SortExec: expr=[non_nullable_col@1 ASC], preserve_partitioning=[false]", " MemoryExec: partitions=1, partition_sizes=[0]", ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -1189,7 +1355,7 @@ mod tests { "SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", - " SortExec: expr=[nullable_col@0 ASC]", + " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", ]; // should not add a sort at the output of the union, input plan should not be changed @@ -1221,7 +1387,7 @@ mod tests { "SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC, non_nullable_col@1 ASC]", - " SortExec: expr=[nullable_col@0 ASC]", + " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", ]; // should not add a sort at the output of the union, input plan should not be changed @@ -1251,17 +1417,17 @@ mod tests { // Input is an invalid plan. In this case rule should add required sorting in appropriate places. // First ParquetExec has output ordering(nullable_col@0 ASC). However, it doesn't satisfy the // required ordering of SortPreservingMergeExec. - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " UnionExec", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " UnionExec", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -1292,18 +1458,18 @@ mod tests { // Third input to the union is not Sorted (SortExec is matches required ordering by the SortPreservingMergeExec above). let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", - " SortExec: expr=[nullable_col@0 ASC]", + " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; // should adjust sorting in the first input of the union such that it is not unnecessarily fine let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", - " SortExec: expr=[nullable_col@0 ASC]", + " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", - " SortExec: expr=[nullable_col@0 ASC]", + " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -1332,20 +1498,20 @@ mod tests { // Should modify the plan to ensure that all three inputs to the // `UnionExec` satisfy the ordering, OR add a single sort after // the `UnionExec` (both of which are equally good for this example). - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " UnionExec", - " SortExec: expr=[nullable_col@0 ASC]", + " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", - " SortExec: expr=[nullable_col@0 ASC]", + " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; - let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " UnionExec", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -1384,15 +1550,15 @@ mod tests { // fine `SortExec`s below with required `SortExec`s that are absolutely necessary. let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 DESC NULLS LAST]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", - " SortExec: expr=[nullable_col@0 ASC]", + " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - " SortExec: expr=[nullable_col@0 ASC]", + " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -1427,20 +1593,20 @@ mod tests { // shouldn't be finer than necessary. let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", - " SortExec: expr=[nullable_col@0 ASC]", + " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", - " SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; // Should adjust the requirement in the third input of the union so // that it is not unnecessarily fine. let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", - " SortExec: expr=[nullable_col@0 ASC]", + " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", - " SortExec: expr=[nullable_col@0 ASC]", + " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[true]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -1467,16 +1633,16 @@ mod tests { // Union has unnecessarily fine ordering below it. We should be able to replace them with absolutely necessary ordering. let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; // Union preserves the inputs ordering and we should not change any of the SortExecs under UnionExec let expected_output = ["SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", - " SortExec: expr=[nullable_col@0 ASC]", + " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - " SortExec: expr=[nullable_col@0 ASC]", + " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; assert_optimized!(expected_input, expected_output, physical_plan, true); @@ -1518,9 +1684,9 @@ mod tests { // The `UnionExec` doesn't preserve any of the inputs ordering in the // example below. let expected_input = ["UnionExec", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - " SortExec: expr=[nullable_col@0 DESC NULLS LAST,non_nullable_col@1 DESC NULLS LAST]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; // Since `UnionExec` doesn't preserve ordering in the plan above. // We shouldn't keep SortExecs in the plan. @@ -1564,15 +1730,15 @@ mod tests { // corresponding SortExecs together. Also, the inputs of these `SortExec`s // are not necessarily the same to be able to remove them. let expected_input = [ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " SortPreservingMergeExec: [nullable_col@0 DESC NULLS LAST]", " UnionExec", - " SortExec: expr=[nullable_col@0 DESC NULLS LAST]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC, non_nullable_col@1 ASC]", - " SortExec: expr=[nullable_col@0 DESC NULLS LAST]", + " SortExec: expr=[nullable_col@0 DESC NULLS LAST], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]"]; let expected_optimized = [ - "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", + "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(NULL), is_causal: false }]", " SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC, non_nullable_col@1 ASC]", @@ -1586,10 +1752,10 @@ mod tests { async fn test_window_multi_path_sort2() -> Result<()> { let schema = create_test_schema()?; - let sort_exprs1 = vec![ + let sort_exprs1 = LexOrdering::new(vec![ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), - ]; + ]); let sort_exprs2 = vec![sort_expr("nullable_col", &schema)]; let source1 = parquet_exec_sorted(&schema, sort_exprs2.clone()); let source2 = parquet_exec_sorted(&schema, sort_exprs2.clone()); @@ -1602,14 +1768,14 @@ mod tests { // The `WindowAggExec` can get its required sorting from the leaf nodes directly. // The unnecessary SortExecs should be removed - let expected_input = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + let expected_input = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", " UnionExec", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]"]; - let expected_optimized = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + let expected_optimized = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col], output_ordering=[nullable_col@0 ASC]", @@ -1652,19 +1818,19 @@ mod tests { // Should not change the unnecessarily fine `SortExec`s because there is `LimitExec` let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " GlobalLimitExec: skip=0, fetch=100", " LocalLimitExec: fetch=100", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 DESC NULLS LAST]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", " UnionExec", - " SortExec: expr=[nullable_col@0 ASC]", + " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " GlobalLimitExec: skip=0, fetch=100", " LocalLimitExec: fetch=100", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 DESC NULLS LAST]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 DESC NULLS LAST], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -1709,7 +1875,7 @@ mod tests { let join_plan2 = format!( " SortMergeJoin: join_type={join_type}, on=[(nullable_col@0, col_a@0)]" ); - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC, non_nullable_col@1 ASC]", join_plan2.as_str(), " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b]"]; @@ -1721,20 +1887,20 @@ mod tests { // can push down the sort requirements and save 1 SortExec vec![ join_plan.as_str(), - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - " SortExec: expr=[col_a@0 ASC]", + " SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b]", ] } _ => { // can not push down the sort requirements vec![ - "SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", + "SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", join_plan2.as_str(), - " SortExec: expr=[nullable_col@0 ASC]", + " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - " SortExec: expr=[col_a@0 ASC]", + " SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b]", ] } @@ -1780,9 +1946,9 @@ mod tests { ); let spm_plan = match join_type { JoinType::RightAnti => { - "SortPreservingMergeExec: [col_a@0 ASC,col_b@1 ASC]" + "SortPreservingMergeExec: [col_a@0 ASC, col_b@1 ASC]" } - _ => "SortPreservingMergeExec: [col_a@2 ASC,col_b@3 ASC]", + _ => "SortPreservingMergeExec: [col_a@2 ASC, col_b@3 ASC]", }; let join_plan2 = format!( " SortMergeJoin: join_type={join_type}, on=[(nullable_col@0, col_a@0)]" @@ -1796,20 +1962,20 @@ mod tests { // can push down the sort requirements and save 1 SortExec vec![ join_plan.as_str(), - " SortExec: expr=[nullable_col@0 ASC]", + " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - " SortExec: expr=[col_a@0 ASC,col_b@1 ASC]", + " SortExec: expr=[col_a@0 ASC, col_b@1 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b]", ] } _ => { // can not push down the sort requirements for Left and Full join. vec![ - "SortExec: expr=[col_a@2 ASC,col_b@3 ASC]", + "SortExec: expr=[col_a@2 ASC, col_b@3 ASC], preserve_partitioning=[false]", join_plan2.as_str(), - " SortExec: expr=[nullable_col@0 ASC]", + " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - " SortExec: expr=[col_a@0 ASC]", + " SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b]", ] } @@ -1843,17 +2009,17 @@ mod tests { ]; let physical_plan = sort_preserving_merge_exec(sort_exprs1, join.clone()); - let expected_input = ["SortPreservingMergeExec: [col_b@3 ASC,col_a@2 ASC]", + let expected_input = ["SortPreservingMergeExec: [col_b@3 ASC, col_a@2 ASC]", " SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b]"]; // can not push down the sort requirements, need to add SortExec - let expected_optimized = ["SortExec: expr=[col_b@3 ASC,col_a@2 ASC]", + let expected_optimized = ["SortExec: expr=[col_b@3 ASC, col_a@2 ASC], preserve_partitioning=[false]", " SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", - " SortExec: expr=[nullable_col@0 ASC]", + " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - " SortExec: expr=[col_a@0 ASC]", + " SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b]"]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -1865,17 +2031,17 @@ mod tests { ]; let physical_plan = sort_preserving_merge_exec(sort_exprs2, join); - let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC,col_b@3 ASC,col_a@2 ASC]", + let expected_input = ["SortPreservingMergeExec: [nullable_col@0 ASC, col_b@3 ASC, col_a@2 ASC]", " SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b]"]; // can not push down the sort requirements, need to add SortExec - let expected_optimized = ["SortExec: expr=[nullable_col@0 ASC,col_b@3 ASC,col_a@2 ASC]", + let expected_optimized = ["SortExec: expr=[nullable_col@0 ASC, col_b@3 ASC, col_a@2 ASC], preserve_partitioning=[false]", " SortMergeJoin: join_type=Inner, on=[(nullable_col@0, col_a@0)]", - " SortExec: expr=[nullable_col@0 ASC]", + " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]", - " SortExec: expr=[col_a@0 ASC]", + " SortExec: expr=[col_a@0 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[col_a, col_b]"]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -1902,16 +2068,16 @@ mod tests { let physical_plan = bounded_window_exec("non_nullable_col", sort_exprs1, window_agg2); - let expected_input = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " SortExec: expr=[nullable_col@0 ASC]", + let expected_input = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " MemoryExec: partitions=1, partition_sizes=[0]"]; - let expected_optimized = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " SortExec: expr=[nullable_col@0 ASC,non_nullable_col@1 ASC]", + let expected_optimized = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[nullable_col@0 ASC, non_nullable_col@1 ASC], preserve_partitioning=[false]", " MemoryExec: partitions=1, partition_sizes=[0]"]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -1938,13 +2104,13 @@ mod tests { // CoalescePartitionsExec and SortExec are not directly consecutive. In this case // we should be able to parallelize Sorting also (given that executors in between don't require) // single partition. - let expected_input = ["SortExec: expr=[nullable_col@0 ASC]", + let expected_input = ["SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " FilterExec: NOT non_nullable_col@1", " CoalescePartitionsExec", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; let expected_optimized = ["SortPreservingMergeExec: [nullable_col@0 ASC]", - " SortExec: expr=[nullable_col@0 ASC]", + " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[true]", " FilterExec: NOT non_nullable_col@1", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " ParquetExec: file_groups={1 group: [[x]]}, projection=[nullable_col, non_nullable_col]"]; @@ -1966,7 +2132,7 @@ mod tests { let state = session_ctx.state(); let memory_exec = memory_exec(&schema); - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; + let sort_exprs = LexOrdering::new(vec![sort_expr("nullable_col", &schema)]); let window = bounded_window_exec("nullable_col", sort_exprs.clone(), memory_exec); let repartition = repartition_exec(window); @@ -1974,9 +2140,9 @@ mod tests { Arc::new(SortExec::new(sort_exprs, repartition)) as Arc; let actual = get_plan_string(&orig_plan); let expected_input = vec![ - "SortExec: expr=[nullable_col@0 ASC]", + "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " MemoryExec: partitions=1, partition_sizes=[0]", ]; assert_eq!( @@ -2016,7 +2182,7 @@ mod tests { let repartition = repartition_exec(source); let coalesce_partitions = Arc::new(CoalescePartitionsExec::new(repartition)); let repartition = repartition_exec(coalesce_partitions); - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; + let sort_exprs = LexOrdering::new(vec![sort_expr("nullable_col", &schema)]); // Add local sort let sort = Arc::new( SortExec::new(sort_exprs.clone(), repartition) @@ -2028,16 +2194,16 @@ mod tests { let physical_plan = sort.clone(); // Sort Parallelize rule should end Coalesce + Sort linkage when Sort is Global Sort // Also input plan is not valid as it is. We need to add SortExec before SortPreservingMergeExec. - let expected_input = ["SortExec: expr=[nullable_col@0 ASC]", + let expected_input = ["SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", " SortPreservingMergeExec: [nullable_col@0 ASC]", - " SortExec: expr=[nullable_col@0 ASC]", + " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[true]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " CoalescePartitionsExec", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " MemoryExec: partitions=1, partition_sizes=[0]"]; let expected_optimized = [ "SortPreservingMergeExec: [nullable_col@0 ASC]", - " SortExec: expr=[nullable_col@0 ASC]", + " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[true]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " MemoryExec: partitions=1, partition_sizes=[0]", ]; @@ -2059,13 +2225,13 @@ mod tests { let coalesce_partitions = coalesce_partitions_exec(repartition_hash); let physical_plan = sort_exec(vec![sort_expr("a", &schema)], coalesce_partitions); - let expected_input = ["SortExec: expr=[a@0 ASC]", + let expected_input = ["SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], has_header=false"]; let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC]", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], has_header=false"]; @@ -2097,14 +2263,14 @@ mod tests { // Expected inputs unbounded and bounded let expected_input_unbounded = vec![ - "SortExec: expr=[a@0 ASC]", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]", ]; let expected_input_bounded = vec![ - "SortExec: expr=[a@0 ASC]", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", @@ -2121,7 +2287,7 @@ mod tests { // Expected bounded results with and without flag let expected_optimized_bounded = vec![ - "SortExec: expr=[a@0 ASC]", + "SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", @@ -2129,7 +2295,7 @@ mod tests { ]; let expected_optimized_bounded_parallelize_sort = vec![ "SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC]", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", " RepartitionExec: partitioning=Hash([c@2], 10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC], has_header=true", @@ -2173,12 +2339,12 @@ mod tests { let spm = sort_preserving_merge_exec(sort_exprs, repartition_rr); let physical_plan = sort_exec(vec![sort_expr("b", &schema)], spm); - let expected_input = ["SortExec: expr=[b@1 ASC]", - " SortPreservingMergeExec: [a@0 ASC,b@1 ASC]", + let expected_input = ["SortExec: expr=[b@1 ASC], preserve_partitioning=[false]", + " SortPreservingMergeExec: [a@0 ASC, b@1 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], has_header=false",]; - let expected_optimized = ["SortExec: expr=[b@1 ASC]", - " SortPreservingMergeExec: [a@0 ASC,b@1 ASC]", + let expected_optimized = ["SortExec: expr=[b@1 ASC], preserve_partitioning=[false]", + " SortPreservingMergeExec: [a@0 ASC, b@1 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], has_header=false",]; assert_optimized!(expected_input, expected_optimized, physical_plan, false); @@ -2202,12 +2368,12 @@ mod tests { spm, ); - let expected_input = ["SortExec: expr=[a@0 ASC,b@1 ASC,c@2 ASC]", - " SortPreservingMergeExec: [a@0 ASC,b@1 ASC]", + let expected_input = ["SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[false]", + " SortPreservingMergeExec: [a@0 ASC, b@1 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], has_header=false",]; - let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC,b@1 ASC]", - " SortExec: expr=[a@0 ASC,b@1 ASC,c@2 ASC]", + let expected_optimized = ["SortPreservingMergeExec: [a@0 ASC, b@1 ASC]", + " SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[true]", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC, b@1 ASC], has_header=false",]; assert_optimized!(expected_input, expected_optimized, physical_plan, false); @@ -2228,16 +2394,16 @@ mod tests { let physical_plan = bounded_window_exec("a", sort_exprs, spm); let expected_input = [ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " SortPreservingMergeExec: [a@0 ASC,b@1 ASC]", - " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC,b@1 ASC", + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortPreservingMergeExec: [a@0 ASC, b@1 ASC]", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10, preserve_order=true, sort_exprs=a@0 ASC, b@1 ASC", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", - " SortExec: expr=[a@0 ASC,b@1 ASC]", + " SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false]", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", ]; let expected_optimized = [ - "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", - " SortExec: expr=[a@0 ASC,b@1 ASC]", + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false]", " CoalescePartitionsExec", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", @@ -2260,11 +2426,11 @@ mod tests { ); let expected_input = [ - "SortExec: expr=[a@0 ASC,c@2 ASC]", + "SortExec: expr=[a@0 ASC, c@2 ASC], preserve_partitioning=[false]", " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]" ]; let expected_optimized = [ - "PartialSortExec: expr=[a@0 ASC,c@2 ASC], common_prefix_length=[1]", + "PartialSortExec: expr=[a@0 ASC, c@2 ASC], common_prefix_length=[1]", " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC]", ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -2287,12 +2453,12 @@ mod tests { ); let expected_input = [ - "SortExec: expr=[a@0 ASC,c@2 ASC,d@3 ASC]", + "SortExec: expr=[a@0 ASC, c@2 ASC, d@3 ASC], preserve_partitioning=[false]", " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC, c@2 ASC]" ]; // let optimized let expected_optimized = [ - "PartialSortExec: expr=[a@0 ASC,c@2 ASC,d@3 ASC], common_prefix_length=[2]", + "PartialSortExec: expr=[a@0 ASC, c@2 ASC, d@3 ASC], common_prefix_length=[2]", " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[a@0 ASC, c@2 ASC]", ]; assert_optimized!(expected_input, expected_optimized, physical_plan, true); @@ -2314,7 +2480,7 @@ mod tests { parquet_input, ); let expected_input = [ - "SortExec: expr=[a@0 ASC,b@1 ASC,c@2 ASC]", + "SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[false]", " ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[b@1 ASC, c@2 ASC]" ]; let expected_no_change = expected_input; @@ -2337,11 +2503,75 @@ mod tests { unbounded_input, ); let expected_input = [ - "SortExec: expr=[a@0 ASC,b@1 ASC,c@2 ASC]", + "SortExec: expr=[a@0 ASC, b@1 ASC, c@2 ASC], preserve_partitioning=[false]", " StreamingTableExec: partition_sizes=1, projection=[a, b, c, d, e], infinite_source=true, output_ordering=[b@1 ASC, c@2 ASC]" ]; let expected_no_change = expected_input; assert_optimized!(expected_input, expected_no_change, physical_plan, true); Ok(()) } + + #[tokio::test] + async fn test_push_with_required_input_ordering_prohibited() -> Result<()> { + // SortExec: expr=[b] <-- can't push this down + // RequiredInputOrder expr=[a] <-- this requires input sorted by a, and preserves the input order + // SortExec: expr=[a] + // MemoryExec + let schema = create_test_schema3()?; + let sort_exprs_a = LexOrdering::new(vec![sort_expr("a", &schema)]); + let sort_exprs_b = LexOrdering::new(vec![sort_expr("b", &schema)]); + let plan = memory_exec(&schema); + let plan = sort_exec(sort_exprs_a.clone(), plan); + let plan = RequirementsTestExec::new(plan) + .with_required_input_ordering(sort_exprs_a) + .with_maintains_input_order(true) + .into_arc(); + let plan = sort_exec(sort_exprs_b, plan); + + let expected_input = [ + "SortExec: expr=[b@1 ASC], preserve_partitioning=[false]", + " RequiredInputOrderingExec", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " MemoryExec: partitions=1, partition_sizes=[0]", + ]; + // should not be able to push shorts + let expected_no_change = expected_input; + assert_optimized!(expected_input, expected_no_change, plan, true); + Ok(()) + } + + // test when the required input ordering is satisfied so could push through + #[tokio::test] + async fn test_push_with_required_input_ordering_allowed() -> Result<()> { + // SortExec: expr=[a,b] <-- can push this down (as it is compatible with the required input ordering) + // RequiredInputOrder expr=[a] <-- this requires input sorted by a, and preserves the input order + // SortExec: expr=[a] + // MemoryExec + let schema = create_test_schema3()?; + let sort_exprs_a = LexOrdering::new(vec![sort_expr("a", &schema)]); + let sort_exprs_ab = + LexOrdering::new(vec![sort_expr("a", &schema), sort_expr("b", &schema)]); + let plan = memory_exec(&schema); + let plan = sort_exec(sort_exprs_a.clone(), plan); + let plan = RequirementsTestExec::new(plan) + .with_required_input_ordering(sort_exprs_a) + .with_maintains_input_order(true) + .into_arc(); + let plan = sort_exec(sort_exprs_ab, plan); + + let expected_input = [ + "SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false]", + " RequiredInputOrderingExec", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " MemoryExec: partitions=1, partition_sizes=[0]", + ]; + // should able to push shorts + let expected = [ + "RequiredInputOrderingExec", + " SortExec: expr=[a@0 ASC, b@1 ASC], preserve_partitioning=[false]", + " MemoryExec: partitions=1, partition_sizes=[0]", + ]; + assert_optimized!(expected_input, expected, plan, true); + Ok(()) + } } diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index 4fefcdf7aad6..0312e362afb1 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -27,7 +27,6 @@ use std::sync::Arc; use crate::config::ConfigOptions; use crate::error::Result; -use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::joins::utils::{ColumnIndex, JoinFilter}; use crate::physical_plan::joins::{ CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, @@ -39,14 +38,16 @@ use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; use arrow_schema::Schema; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{internal_err, JoinSide, JoinType}; +use datafusion_expr::sort_properties::SortProperties; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::sort_properties::SortProperties; -use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; +use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_optimizer::PhysicalOptimizerRule; /// The [`JoinSelection`] rule tries to modify a given plan so that it can /// accommodate infinite sources and optimize joins in the plan according to /// available statistical information, if there is any. -#[derive(Default)] +#[derive(Default, Debug)] pub struct JoinSelection {} impl JoinSelection { @@ -132,6 +133,9 @@ fn swap_join_type(join_type: JoinType) -> JoinType { JoinType::RightSemi => JoinType::LeftSemi, JoinType::LeftAnti => JoinType::RightAnti, JoinType::RightAnti => JoinType::LeftAnti, + JoinType::LeftMark => { + unreachable!("LeftMark join type does not support swapping") + } } } @@ -140,24 +144,38 @@ fn swap_join_projection( left_schema_len: usize, right_schema_len: usize, projection: Option<&Vec>, + join_type: &JoinType, ) -> Option> { - projection.map(|p| { - p.iter() - .map(|i| { - // If the index is less than the left schema length, it is from the left schema, so we add the right schema length to it. - // Otherwise, it is from the right schema, so we subtract the left schema length from it. - if *i < left_schema_len { - *i + right_schema_len - } else { - *i - left_schema_len - } - }) - .collect() - }) + match join_type { + // For Anti/Semi join types, projection should remain unmodified, + // since these joins output schema remains the same after swap + JoinType::LeftAnti + | JoinType::LeftSemi + | JoinType::RightAnti + | JoinType::RightSemi => projection.cloned(), + + _ => projection.map(|p| { + p.iter() + .map(|i| { + // If the index is less than the left schema length, it is from + // the left schema, so we add the right schema length to it. + // Otherwise, it is from the right schema, so we subtract the left + // schema length from it. + if *i < left_schema_len { + *i + right_schema_len + } else { + *i - left_schema_len + } + }) + .collect() + }), + } } /// This function swaps the inputs of the given join operator. -fn swap_hash_join( +/// This function is public so other downstream projects can use it +/// to construct `HashJoinExec` with right side as the build side. +pub fn swap_hash_join( hash_join: &HashJoinExec, partition_mode: PartitionMode, ) -> Result> { @@ -177,17 +195,20 @@ fn swap_hash_join( left.schema().fields().len(), right.schema().fields().len(), hash_join.projection.as_ref(), + hash_join.join_type(), ), partition_mode, hash_join.null_equals_null(), )?; + // In case of anti / semi joins or if there is embedded projection in HashJoinExec, output column order is preserved, no need to add projection again if matches!( hash_join.join_type(), JoinType::LeftSemi | JoinType::RightSemi | JoinType::LeftAnti | JoinType::RightAnti - ) { + ) || hash_join.projection.is_some() + { Ok(Arc::new(new_join)) } else { // TODO avoid adding ProjectionExec again and again, only adding Final Projection @@ -533,7 +554,7 @@ fn hash_join_convert_symmetric_subrule( // the function concludes that no specific order is required for the SymmetricHashJoinExec. This approach // ensures that the symmetric hash join operation only imposes ordering constraints when necessary, // based on the properties of the child nodes and the filter condition. - let determine_order = |side: JoinSide| -> Option> { + let determine_order = |side: JoinSide| -> Option { hash_join .filter() .map(|filter| { @@ -556,12 +577,13 @@ fn hash_join_convert_symmetric_subrule( hash_join.right().equivalence_properties(), hash_join.right().schema(), ), + JoinSide::None => return false, }; let name = schema.field(*index).name(); let col = Arc::new(Column::new(name, *index)) as _; // Check if the column is ordered. - equivalence.get_expr_ordering(col).data + equivalence.get_expr_properties(col).sort_properties != SortProperties::Unordered }, ) @@ -571,8 +593,9 @@ fn hash_join_convert_symmetric_subrule( match side { JoinSide::Left => hash_join.left().output_ordering(), JoinSide::Right => hash_join.right().output_ordering(), + JoinSide::None => unreachable!(), } - .map(|p| p.to_vec()) + .map(|p| LexOrdering::new(p.to_vec())) }) .flatten() }; @@ -702,7 +725,6 @@ fn apply_subrules( #[cfg(test)] mod tests_statistical { - use super::*; use crate::{ physical_plan::{displayable, ColumnStatistics, Statistics}, @@ -717,7 +739,7 @@ mod tests_statistical { use rstest::rstest; - /// Return statistcs for empty table + /// Return statistics for empty table fn empty_statistics() -> Statistics { Statistics { num_rows: Precision::Absent, @@ -735,7 +757,7 @@ mod tests_statistical { ) } - /// Return statistcs for small table + /// Return statistics for small table fn small_statistics() -> Statistics { let (threshold_num_rows, threshold_byte_size) = get_thresholds(); Statistics { @@ -745,7 +767,7 @@ mod tests_statistical { } } - /// Return statistcs for big table + /// Return statistics for big table fn big_statistics() -> Statistics { let (threshold_num_rows, threshold_byte_size) = get_thresholds(); Statistics { @@ -755,7 +777,7 @@ mod tests_statistical { } } - /// Return statistcs for big table + /// Return statistics for big table fn bigger_statistics() -> Statistics { let (threshold_num_rows, threshold_byte_size) = get_thresholds(); Statistics { @@ -882,28 +904,6 @@ mod tests_statistical { (big, medium, small) } - pub(crate) fn crosscheck_plans(plan: Arc) -> Result<()> { - let subrules: Vec> = vec![ - Box::new(hash_join_convert_symmetric_subrule), - Box::new(hash_join_swap_subrule), - ]; - let new_plan = plan - .transform_up(|p| apply_subrules(p, &subrules, &ConfigOptions::new())) - .data()?; - // TODO: End state payloads will be checked here. - let config = ConfigOptions::new().optimizer; - let collect_left_threshold = config.hash_join_single_partition_threshold; - let collect_threshold_num_rows = config.hash_join_single_partition_threshold_rows; - let _ = new_plan.transform_up(|plan| { - statistical_join_selection_subrule( - plan, - collect_left_threshold, - collect_threshold_num_rows, - ) - })?; - Ok(()) - } - #[tokio::test] async fn test_join_with_swap() { let (big, small) = create_big_and_small(); @@ -928,7 +928,7 @@ mod tests_statistical { ); let optimized_join = JoinSelection::new() - .optimize(join.clone(), &ConfigOptions::new()) + .optimize(join, &ConfigOptions::new()) .unwrap(); let swapping_projection = optimized_join @@ -958,7 +958,6 @@ mod tests_statistical { swapped_join.right().statistics().unwrap().total_byte_size, Precision::Inexact(2097152) ); - crosscheck_plans(join.clone()).unwrap(); } #[tokio::test] @@ -985,7 +984,7 @@ mod tests_statistical { ); let optimized_join = JoinSelection::new() - .optimize(join.clone(), &ConfigOptions::new()) + .optimize(join, &ConfigOptions::new()) .unwrap(); let swapped_join = optimized_join @@ -1001,7 +1000,6 @@ mod tests_statistical { swapped_join.right().statistics().unwrap().total_byte_size, Precision::Inexact(2097152) ); - crosscheck_plans(join.clone()).unwrap(); } #[tokio::test] @@ -1055,7 +1053,6 @@ mod tests_statistical { Precision::Inexact(2097152) ); assert_eq!(original_schema, swapped_join.schema()); - crosscheck_plans(join).unwrap(); } } @@ -1078,7 +1075,6 @@ mod tests_statistical { "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", expected_lines, actual_lines ); - crosscheck_plans(plan).unwrap(); }; } @@ -1164,7 +1160,7 @@ mod tests_statistical { ); let optimized_join = JoinSelection::new() - .optimize(join.clone(), &ConfigOptions::new()) + .optimize(join, &ConfigOptions::new()) .unwrap(); let swapped_join = optimized_join @@ -1180,7 +1176,6 @@ mod tests_statistical { swapped_join.right().statistics().unwrap().total_byte_size, Precision::Inexact(2097152) ); - crosscheck_plans(join).unwrap(); } #[rstest( @@ -1205,7 +1200,7 @@ mod tests_statistical { ); let optimized_join = JoinSelection::new() - .optimize(join.clone(), &ConfigOptions::new()) + .optimize(join, &ConfigOptions::new()) .unwrap(); let swapping_projection = optimized_join @@ -1249,7 +1244,6 @@ mod tests_statistical { swapped_join.right().statistics().unwrap().total_byte_size, Precision::Inexact(2097152) ); - crosscheck_plans(join.clone()).unwrap(); } #[rstest( @@ -1311,7 +1305,65 @@ mod tests_statistical { swapped_join.right().statistics().unwrap().total_byte_size, Precision::Inexact(2097152) ); - crosscheck_plans(join.clone()).unwrap(); + } + + #[rstest( + join_type, projection, small_on_right, + case::inner(JoinType::Inner, vec![1], true), + case::left(JoinType::Left, vec![1], true), + case::right(JoinType::Right, vec![1], true), + case::full(JoinType::Full, vec![1], true), + case::left_anti(JoinType::LeftAnti, vec![0], false), + case::left_semi(JoinType::LeftSemi, vec![0], false), + case::right_anti(JoinType::RightAnti, vec![0], true), + case::right_semi(JoinType::RightSemi, vec![0], true), + )] + #[tokio::test] + async fn test_hash_join_swap_on_joins_with_projections( + join_type: JoinType, + projection: Vec, + small_on_right: bool, + ) -> Result<()> { + let (big, small) = create_big_and_small(); + + let left = if small_on_right { &big } else { &small }; + let right = if small_on_right { &small } else { &big }; + + let left_on = if small_on_right { + "big_col" + } else { + "small_col" + }; + let right_on = if small_on_right { + "small_col" + } else { + "big_col" + }; + + let join = Arc::new(HashJoinExec::try_new( + Arc::clone(left), + Arc::clone(right), + vec![( + Arc::new(Column::new_with_schema(left_on, &left.schema())?), + Arc::new(Column::new_with_schema(right_on, &right.schema())?), + )], + None, + &join_type, + Some(projection), + PartitionMode::Partitioned, + false, + )?); + + let swapped = swap_hash_join(&join.clone(), PartitionMode::Partitioned) + .expect("swap_hash_join must support joins with projections"); + let swapped_join = swapped.as_any().downcast_ref::().expect( + "ProjectionExec won't be added above if HashJoinExec contains embedded projection", + ); + + assert_eq!(swapped_join.projection, Some(vec![0_usize])); + assert_eq!(swapped.schema().fields.len(), 1); + assert_eq!(swapped.schema().fields[0].name(), "small_col"); + Ok(()) } #[tokio::test] @@ -1383,7 +1435,7 @@ mod tests_statistical { Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _, )]; check_join_partition_mode( - big.clone(), + big, small.clone(), join_on, true, @@ -1407,8 +1459,8 @@ mod tests_statistical { Arc::new(Column::new_with_schema("small_col", &small.schema()).unwrap()) as _, )]; check_join_partition_mode( - empty.clone(), - small.clone(), + empty, + small, join_on, true, PartitionMode::CollectLeft, @@ -1451,7 +1503,7 @@ mod tests_statistical { Arc::new(Column::new_with_schema("big_col", &big.schema()).unwrap()) as _, )]; check_join_partition_mode( - bigger.clone(), + bigger, big.clone(), join_on, true, @@ -1499,7 +1551,7 @@ mod tests_statistical { ); let optimized_join = JoinSelection::new() - .optimize(join.clone(), &ConfigOptions::new()) + .optimize(join, &ConfigOptions::new()) .unwrap(); if !is_swapped { @@ -1523,7 +1575,6 @@ mod tests_statistical { assert_eq!(*swapped_join.partition_mode(), expected_mode); } - crosscheck_plans(join).unwrap(); } } @@ -1568,15 +1619,12 @@ mod util_tests { #[cfg(test)] mod hash_join_tests { - - use self::tests_statistical::crosscheck_plans; use super::*; use crate::physical_optimizer::test_utils::SourceType; use crate::test_util::UnboundedExec; use arrow::datatypes::{DataType, Field}; use arrow::record_batch::RecordBatch; - use datafusion_common::utils::DataPtr; struct TestCase { case: String, @@ -1944,8 +1992,7 @@ mod hash_join_tests { false, )?); - let optimized_join_plan = - hash_join_swap_subrule(join.clone(), &ConfigOptions::new())?; + let optimized_join_plan = hash_join_swap_subrule(join, &ConfigOptions::new())?; // If swap did happen let projection_added = optimized_join_plan.as_any().is::(); @@ -1969,8 +2016,8 @@ mod hash_join_tests { .. }) = plan.as_any().downcast_ref::() { - let left_changed = Arc::data_ptr_eq(left, &right_exec); - let right_changed = Arc::data_ptr_eq(right, &left_exec); + let left_changed = Arc::ptr_eq(left, &right_exec); + let right_changed = Arc::ptr_eq(right, &left_exec); // If this is not equal, we have a bigger problem. assert_eq!(left_changed, right_changed); assert_eq!( @@ -2000,7 +2047,6 @@ mod hash_join_tests { ) ); }; - crosscheck_plans(plan).unwrap(); Ok(()) } } diff --git a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs deleted file mode 100644 index d211d2c8b2d7..000000000000 --- a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs +++ /dev/null @@ -1,611 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! A special-case optimizer rule that pushes limit into a grouped aggregation -//! which has no aggregate expressions or sorting requirements - -use std::sync::Arc; - -use crate::physical_optimizer::PhysicalOptimizerRule; -use crate::physical_plan::aggregates::AggregateExec; -use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; -use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; - -use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::Result; - -use itertools::Itertools; - -/// An optimizer rule that passes a `limit` hint into grouped aggregations which don't require all -/// rows in the group to be processed for correctness. Example queries fitting this description are: -/// `SELECT distinct l_orderkey FROM lineitem LIMIT 10;` -/// `SELECT l_orderkey FROM lineitem GROUP BY l_orderkey LIMIT 10;` -pub struct LimitedDistinctAggregation {} - -impl LimitedDistinctAggregation { - /// Create a new `LimitedDistinctAggregation` - pub fn new() -> Self { - Self {} - } - - fn transform_agg( - aggr: &AggregateExec, - limit: usize, - ) -> Option> { - // rules for transforming this Aggregate are held in this method - if !aggr.is_unordered_unfiltered_group_by_distinct() { - return None; - } - - // We found what we want: clone, copy the limit down, and return modified node - let new_aggr = AggregateExec::try_new( - *aggr.mode(), - aggr.group_by().clone(), - aggr.aggr_expr().to_vec(), - aggr.filter_expr().to_vec(), - aggr.input().clone(), - aggr.input_schema(), - ) - .expect("Unable to copy Aggregate!") - .with_limit(Some(limit)); - Some(Arc::new(new_aggr)) - } - - /// transform_limit matches an `AggregateExec` as the child of a `LocalLimitExec` - /// or `GlobalLimitExec` and pushes the limit into the aggregation as a soft limit when - /// there is a group by, but no sorting, no aggregate expressions, and no filters in the - /// aggregation - fn transform_limit(plan: Arc) -> Option> { - let limit: usize; - let mut global_fetch: Option = None; - let mut global_skip: usize = 0; - let children: Vec>; - let mut is_global_limit = false; - if let Some(local_limit) = plan.as_any().downcast_ref::() { - limit = local_limit.fetch(); - children = local_limit.children(); - } else if let Some(global_limit) = plan.as_any().downcast_ref::() - { - global_fetch = global_limit.fetch(); - global_fetch?; - global_skip = global_limit.skip(); - // the aggregate must read at least fetch+skip number of rows - limit = global_fetch.unwrap() + global_skip; - children = global_limit.children(); - is_global_limit = true - } else { - return None; - } - let child = children.iter().exactly_one().ok()?; - // ensure there is no output ordering; can this rule be relaxed? - if plan.output_ordering().is_some() { - return None; - } - // ensure no ordering is required on the input - if plan.required_input_ordering()[0].is_some() { - return None; - } - - // if found_match_aggr is true, match_aggr holds a parent aggregation whose group_by - // must match that of a child aggregation in order to rewrite the child aggregation - let mut match_aggr: Arc = plan; - let mut found_match_aggr = false; - - let mut rewrite_applicable = true; - let closure = |plan: Arc| { - if !rewrite_applicable { - return Ok(Transformed::no(plan)); - } - if let Some(aggr) = plan.as_any().downcast_ref::() { - if found_match_aggr { - if let Some(parent_aggr) = - match_aggr.as_any().downcast_ref::() - { - if !parent_aggr.group_by().eq(aggr.group_by()) { - // a partial and final aggregation with different groupings disqualifies - // rewriting the child aggregation - rewrite_applicable = false; - return Ok(Transformed::no(plan)); - } - } - } - // either we run into an Aggregate and transform it, or disable the rewrite - // for subsequent children - match Self::transform_agg(aggr, limit) { - None => {} - Some(new_aggr) => { - match_aggr = plan; - found_match_aggr = true; - return Ok(Transformed::yes(new_aggr)); - } - } - } - rewrite_applicable = false; - Ok(Transformed::no(plan)) - }; - let child = child.clone().transform_down(closure).data().ok()?; - if is_global_limit { - return Some(Arc::new(GlobalLimitExec::new( - child, - global_skip, - global_fetch, - ))); - } - Some(Arc::new(LocalLimitExec::new(child, limit))) - } -} - -impl Default for LimitedDistinctAggregation { - fn default() -> Self { - Self::new() - } -} - -impl PhysicalOptimizerRule for LimitedDistinctAggregation { - fn optimize( - &self, - plan: Arc, - config: &ConfigOptions, - ) -> Result> { - if config.optimizer.enable_distinct_aggregation_soft_limit { - plan.transform_down(|plan| { - Ok( - if let Some(plan) = - LimitedDistinctAggregation::transform_limit(plan.clone()) - { - Transformed::yes(plan) - } else { - Transformed::no(plan) - }, - ) - }) - .data() - } else { - Ok(plan) - } - } - - fn name(&self) -> &str { - "LimitedDistinctAggregation" - } - - fn schema_check(&self) -> bool { - true - } -} - -#[cfg(test)] -mod tests { - - use super::*; - use crate::physical_optimizer::aggregate_statistics::tests::TestAggregate; - use crate::physical_optimizer::enforce_distribution::tests::{ - parquet_exec_with_sort, schema, trim_plan_display, - }; - use crate::physical_plan::aggregates::PhysicalGroupBy; - use crate::physical_plan::collect; - use crate::physical_plan::memory::MemoryExec; - use crate::prelude::SessionContext; - - use arrow::array::Int32Array; - use arrow::compute::SortOptions; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow::record_batch::RecordBatch; - use arrow::util::pretty::pretty_format_batches; - use arrow_schema::SchemaRef; - use datafusion_execution::config::SessionConfig; - use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::{cast, col}; - use datafusion_physical_expr::{expressions, PhysicalExpr, PhysicalSortExpr}; - use datafusion_physical_plan::aggregates::AggregateMode; - use datafusion_physical_plan::displayable; - - fn mock_data() -> Result> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - ])); - - let batch = RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(Int32Array::from(vec![ - Some(1), - Some(2), - None, - Some(1), - Some(4), - Some(5), - ])), - Arc::new(Int32Array::from(vec![ - Some(1), - None, - Some(6), - Some(2), - Some(8), - Some(9), - ])), - ], - )?; - - Ok(Arc::new(MemoryExec::try_new( - &[vec![batch]], - Arc::clone(&schema), - None, - )?)) - } - - fn assert_plan_matches_expected( - plan: &Arc, - expected: &[&str], - ) -> Result<()> { - let expected_lines: Vec<&str> = expected.to_vec(); - let session_ctx = SessionContext::new(); - let state = session_ctx.state(); - - let optimized = LimitedDistinctAggregation::new() - .optimize(Arc::clone(plan), state.config_options())?; - - let optimized_result = displayable(optimized.as_ref()).indent(true).to_string(); - let actual_lines = trim_plan_display(&optimized_result); - - assert_eq!( - &expected_lines, &actual_lines, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected_lines, actual_lines - ); - - Ok(()) - } - - async fn assert_results_match_expected( - plan: Arc, - expected: &str, - ) -> Result<()> { - let cfg = SessionConfig::new().with_target_partitions(1); - let ctx = SessionContext::new_with_config(cfg); - let batches = collect(plan, ctx.task_ctx()).await?; - let actual = format!("{}", pretty_format_batches(&batches)?); - assert_eq!(actual, expected); - Ok(()) - } - - pub fn build_group_by( - input_schema: &SchemaRef, - columns: Vec, - ) -> PhysicalGroupBy { - let mut group_by_expr: Vec<(Arc, String)> = vec![]; - for column in columns.iter() { - group_by_expr.push((col(column, input_schema).unwrap(), column.to_string())); - } - PhysicalGroupBy::new_single(group_by_expr.clone()) - } - - #[tokio::test] - async fn test_partial_final() -> Result<()> { - let source = mock_data()?; - let schema = source.schema(); - - // `SELECT a FROM MemoryExec GROUP BY a LIMIT 4;`, Partial/Final AggregateExec - let partial_agg = AggregateExec::try_new( - AggregateMode::Partial, - build_group_by(&schema.clone(), vec!["a".to_string()]), - vec![], /* aggr_expr */ - vec![], /* filter_expr */ - source, /* input */ - schema.clone(), /* input_schema */ - )?; - let final_agg = AggregateExec::try_new( - AggregateMode::Final, - build_group_by(&schema.clone(), vec!["a".to_string()]), - vec![], /* aggr_expr */ - vec![], /* filter_expr */ - Arc::new(partial_agg), /* input */ - schema.clone(), /* input_schema */ - )?; - let limit_exec = LocalLimitExec::new( - Arc::new(final_agg), - 4, // fetch - ); - // expected to push the limit to the Partial and Final AggregateExecs - let expected = [ - "LocalLimitExec: fetch=4", - "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[], lim=[4]", - "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[], lim=[4]", - "MemoryExec: partitions=1, partition_sizes=[1]", - ]; - let plan: Arc = Arc::new(limit_exec); - assert_plan_matches_expected(&plan, &expected)?; - let expected = r#" -+---+ -| a | -+---+ -| 1 | -| 2 | -| | -| 4 | -+---+ -"# - .trim(); - assert_results_match_expected(plan, expected).await?; - Ok(()) - } - - #[tokio::test] - async fn test_single_local() -> Result<()> { - let source = mock_data()?; - let schema = source.schema(); - - // `SELECT a FROM MemoryExec GROUP BY a LIMIT 4;`, Single AggregateExec - let single_agg = AggregateExec::try_new( - AggregateMode::Single, - build_group_by(&schema.clone(), vec!["a".to_string()]), - vec![], /* aggr_expr */ - vec![], /* filter_expr */ - source, /* input */ - schema.clone(), /* input_schema */ - )?; - let limit_exec = LocalLimitExec::new( - Arc::new(single_agg), - 4, // fetch - ); - // expected to push the limit to the AggregateExec - let expected = [ - "LocalLimitExec: fetch=4", - "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4]", - "MemoryExec: partitions=1, partition_sizes=[1]", - ]; - let plan: Arc = Arc::new(limit_exec); - assert_plan_matches_expected(&plan, &expected)?; - let expected = r#" -+---+ -| a | -+---+ -| 1 | -| 2 | -| | -| 4 | -+---+ -"# - .trim(); - assert_results_match_expected(plan, expected).await?; - Ok(()) - } - - #[tokio::test] - async fn test_single_global() -> Result<()> { - let source = mock_data()?; - let schema = source.schema(); - - // `SELECT a FROM MemoryExec GROUP BY a LIMIT 4;`, Single AggregateExec - let single_agg = AggregateExec::try_new( - AggregateMode::Single, - build_group_by(&schema.clone(), vec!["a".to_string()]), - vec![], /* aggr_expr */ - vec![], /* filter_expr */ - source, /* input */ - schema.clone(), /* input_schema */ - )?; - let limit_exec = GlobalLimitExec::new( - Arc::new(single_agg), - 1, // skip - Some(3), // fetch - ); - // expected to push the skip+fetch limit to the AggregateExec - let expected = [ - "GlobalLimitExec: skip=1, fetch=3", - "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4]", - "MemoryExec: partitions=1, partition_sizes=[1]", - ]; - let plan: Arc = Arc::new(limit_exec); - assert_plan_matches_expected(&plan, &expected)?; - let expected = r#" -+---+ -| a | -+---+ -| 2 | -| | -| 4 | -+---+ -"# - .trim(); - assert_results_match_expected(plan, expected).await?; - Ok(()) - } - - #[tokio::test] - async fn test_distinct_cols_different_than_group_by_cols() -> Result<()> { - let source = mock_data()?; - let schema = source.schema(); - - // `SELECT distinct a FROM MemoryExec GROUP BY a, b LIMIT 4;`, Single/Single AggregateExec - let group_by_agg = AggregateExec::try_new( - AggregateMode::Single, - build_group_by(&schema.clone(), vec!["a".to_string(), "b".to_string()]), - vec![], /* aggr_expr */ - vec![], /* filter_expr */ - source, /* input */ - schema.clone(), /* input_schema */ - )?; - let distinct_agg = AggregateExec::try_new( - AggregateMode::Single, - build_group_by(&schema.clone(), vec!["a".to_string()]), - vec![], /* aggr_expr */ - vec![], /* filter_expr */ - Arc::new(group_by_agg), /* input */ - schema.clone(), /* input_schema */ - )?; - let limit_exec = LocalLimitExec::new( - Arc::new(distinct_agg), - 4, // fetch - ); - // expected to push the limit to the outer AggregateExec only - let expected = [ - "LocalLimitExec: fetch=4", - "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4]", - "AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[]", - "MemoryExec: partitions=1, partition_sizes=[1]", - ]; - let plan: Arc = Arc::new(limit_exec); - assert_plan_matches_expected(&plan, &expected)?; - let expected = r#" -+---+ -| a | -+---+ -| 1 | -| 2 | -| | -| 4 | -+---+ -"# - .trim(); - assert_results_match_expected(plan, expected).await?; - Ok(()) - } - - #[test] - fn test_no_group_by() -> Result<()> { - let source = mock_data()?; - let schema = source.schema(); - - // `SELECT FROM MemoryExec LIMIT 10;`, Single AggregateExec - let single_agg = AggregateExec::try_new( - AggregateMode::Single, - build_group_by(&schema.clone(), vec![]), - vec![], /* aggr_expr */ - vec![], /* filter_expr */ - source, /* input */ - schema.clone(), /* input_schema */ - )?; - let limit_exec = LocalLimitExec::new( - Arc::new(single_agg), - 10, // fetch - ); - // expected not to push the limit to the AggregateExec - let expected = [ - "LocalLimitExec: fetch=10", - "AggregateExec: mode=Single, gby=[], aggr=[]", - "MemoryExec: partitions=1, partition_sizes=[1]", - ]; - let plan: Arc = Arc::new(limit_exec); - assert_plan_matches_expected(&plan, &expected)?; - Ok(()) - } - - #[test] - fn test_has_aggregate_expression() -> Result<()> { - let source = mock_data()?; - let schema = source.schema(); - let agg = TestAggregate::new_count_star(); - - // `SELECT FROM MemoryExec LIMIT 10;`, Single AggregateExec - let single_agg = AggregateExec::try_new( - AggregateMode::Single, - build_group_by(&schema.clone(), vec!["a".to_string()]), - vec![agg.count_expr()], /* aggr_expr */ - vec![None], /* filter_expr */ - source, /* input */ - schema.clone(), /* input_schema */ - )?; - let limit_exec = LocalLimitExec::new( - Arc::new(single_agg), - 10, // fetch - ); - // expected not to push the limit to the AggregateExec - let expected = [ - "LocalLimitExec: fetch=10", - "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[COUNT(*)]", - "MemoryExec: partitions=1, partition_sizes=[1]", - ]; - let plan: Arc = Arc::new(limit_exec); - assert_plan_matches_expected(&plan, &expected)?; - Ok(()) - } - - #[test] - fn test_has_filter() -> Result<()> { - let source = mock_data()?; - let schema = source.schema(); - - // `SELECT a FROM MemoryExec WHERE a > 1 GROUP BY a LIMIT 10;`, Single AggregateExec - // the `a > 1` filter is applied in the AggregateExec - let filter_expr = Some(expressions::binary( - expressions::col("a", &schema)?, - Operator::Gt, - cast(expressions::lit(1u32), &schema, DataType::Int32)?, - &schema, - )?); - let agg = TestAggregate::new_count_star(); - let single_agg = AggregateExec::try_new( - AggregateMode::Single, - build_group_by(&schema.clone(), vec!["a".to_string()]), - vec![agg.count_expr()], /* aggr_expr */ - vec![filter_expr], /* filter_expr */ - source, /* input */ - schema.clone(), /* input_schema */ - )?; - let limit_exec = LocalLimitExec::new( - Arc::new(single_agg), - 10, // fetch - ); - // expected not to push the limit to the AggregateExec - // TODO(msirek): open an issue for `filter_expr` of `AggregateExec` not printing out - let expected = [ - "LocalLimitExec: fetch=10", - "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[COUNT(*)]", - "MemoryExec: partitions=1, partition_sizes=[1]", - ]; - let plan: Arc = Arc::new(limit_exec); - assert_plan_matches_expected(&plan, &expected)?; - Ok(()) - } - - #[test] - fn test_has_order_by() -> Result<()> { - let sort_key = vec![PhysicalSortExpr { - expr: expressions::col("a", &schema()).unwrap(), - options: SortOptions::default(), - }]; - let source = parquet_exec_with_sort(vec![sort_key]); - let schema = source.schema(); - - // `SELECT a FROM MemoryExec WHERE a > 1 GROUP BY a LIMIT 10;`, Single AggregateExec - // the `a > 1` filter is applied in the AggregateExec - let single_agg = AggregateExec::try_new( - AggregateMode::Single, - build_group_by(&schema.clone(), vec!["a".to_string()]), - vec![], /* aggr_expr */ - vec![], /* filter_expr */ - source, /* input */ - schema.clone(), /* input_schema */ - )?; - let limit_exec = LocalLimitExec::new( - Arc::new(single_agg), - 10, // fetch - ); - // expected not to push the limit to the AggregateExec - let expected = [ - "LocalLimitExec: fetch=10", - "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], ordering_mode=Sorted", - "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", - ]; - let plan: Arc = Arc::new(limit_exec); - assert_plan_matches_expected(&plan, &expected)?; - Ok(()) - } -} diff --git a/datafusion/core/src/physical_optimizer/mod.rs b/datafusion/core/src/physical_optimizer/mod.rs index c80668c6da74..efdd3148d03f 100644 --- a/datafusion/core/src/physical_optimizer/mod.rs +++ b/datafusion/core/src/physical_optimizer/mod.rs @@ -21,25 +21,20 @@ //! "Repartition" or "Sortedness" //! //! [`ExecutionPlan`]: crate::physical_plan::ExecutionPlan -pub mod aggregate_statistics; pub mod coalesce_batches; -pub mod combine_partial_final_agg; -mod convert_first_last; pub mod enforce_distribution; pub mod enforce_sorting; pub mod join_selection; -pub mod limited_distinct_aggregation; pub mod optimizer; -pub mod output_requirements; -pub mod pipeline_checker; -mod projection_pushdown; +pub mod projection_pushdown; pub mod pruning; pub mod replace_with_order_preserving_variants; -mod sort_pushdown; -pub mod topk_aggregation; -mod utils; - +pub mod sanity_checker; #[cfg(test)] pub mod test_utils; +pub mod update_aggr_exprs; + +mod sort_pushdown; +mod utils; -pub use optimizer::PhysicalOptimizerRule; +pub use datafusion_physical_optimizer::*; diff --git a/datafusion/core/src/physical_optimizer/optimizer.rs b/datafusion/core/src/physical_optimizer/optimizer.rs index 08cbf68fa617..7a6f991121ef 100644 --- a/datafusion/core/src/physical_optimizer/optimizer.rs +++ b/datafusion/core/src/physical_optimizer/optimizer.rs @@ -17,50 +17,25 @@ //! Physical optimizer traits +use datafusion_physical_optimizer::PhysicalOptimizerRule; use std::sync::Arc; -use super::convert_first_last::OptimizeAggregateOrder; use super::projection_pushdown::ProjectionPushdown; -use crate::config::ConfigOptions; +use super::update_aggr_exprs::OptimizeAggregateOrder; use crate::physical_optimizer::aggregate_statistics::AggregateStatistics; use crate::physical_optimizer::coalesce_batches::CoalesceBatches; use crate::physical_optimizer::combine_partial_final_agg::CombinePartialFinalAggregate; use crate::physical_optimizer::enforce_distribution::EnforceDistribution; use crate::physical_optimizer::enforce_sorting::EnforceSorting; use crate::physical_optimizer::join_selection::JoinSelection; +use crate::physical_optimizer::limit_pushdown::LimitPushdown; use crate::physical_optimizer::limited_distinct_aggregation::LimitedDistinctAggregation; use crate::physical_optimizer::output_requirements::OutputRequirements; -use crate::physical_optimizer::pipeline_checker::PipelineChecker; +use crate::physical_optimizer::sanity_checker::SanityCheckPlan; use crate::physical_optimizer::topk_aggregation::TopKAggregation; -use crate::{error::Result, physical_plan::ExecutionPlan}; - -/// `PhysicalOptimizerRule` transforms one ['ExecutionPlan'] into another which -/// computes the same results, but in a potentially more efficient way. -/// -/// Use [`SessionState::add_physical_optimizer_rule`] to register additional -/// `PhysicalOptimizerRule`s. -/// -/// [`SessionState::add_physical_optimizer_rule`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionState.html#method.add_physical_optimizer_rule -pub trait PhysicalOptimizerRule { - /// Rewrite `plan` to an optimized form - fn optimize( - &self, - plan: Arc, - config: &ConfigOptions, - ) -> Result>; - - /// A human readable name for this optimizer rule - fn name(&self) -> &str; - - /// A flag to indicate whether the physical planner should valid the rule will not - /// change the schema of the plan after the rewriting. - /// Some of the optimization rules might change the nullable properties of the schema - /// and should disable the schema check. - fn schema_check(&self) -> bool; -} /// A rule-based physical optimizer. -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct PhysicalOptimizer { /// All rules to apply pub rules: Vec>, @@ -112,11 +87,6 @@ impl PhysicalOptimizer { // Remove the ancillary output requirement operator since we are done with the planning // phase. Arc::new(OutputRequirements::new_remove_mode()), - // The PipelineChecker rule will reject non-runnable query plans that use - // pipeline-breaking operators on infinite input(s). The rule generates a - // diagnostic error message when this happens. It makes no changes to the - // given query plan; i.e. it only acts as a final gatekeeping rule. - Arc::new(PipelineChecker::new()), // The aggregation limiter will try to find situations where the accumulator count // is not tied to the cardinality, i.e. when the output of the aggregation is passed // into an `order by max(x) limit y`. In this case it will copy the limit value down @@ -129,6 +99,19 @@ impl PhysicalOptimizer { // are not present, the load of executors such as join or union will be // reduced by narrowing their input tables. Arc::new(ProjectionPushdown::new()), + // The LimitPushdown rule tries to push limits down as far as possible, + // replacing operators with fetching variants, or adding limits + // past operators that support limit pushdown. + Arc::new(LimitPushdown::new()), + // The SanityCheckPlan rule checks whether the order and + // distribution requirements of each node in the plan + // is satisfied. It will also reject non-runnable query + // plans that use pipeline-breaking operators on infinite + // input(s). The rule generates a diagnostic error + // message for invalid plans. It makes no changes to the + // given query plan; i.e. it only acts as a final + // gatekeeping rule. + Arc::new(SanityCheckPlan::new()), ]; Self::with_rules(rules) diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs deleted file mode 100644 index 5c6a0ab8ea7f..000000000000 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ /dev/null @@ -1,334 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -//http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! The [PipelineChecker] rule ensures that a given plan can accommodate its -//! infinite sources, if there are any. It will reject non-runnable query plans -//! that use pipeline-breaking operators on infinite input(s). - -use std::sync::Arc; - -use crate::config::ConfigOptions; -use crate::error::Result; -use crate::physical_optimizer::PhysicalOptimizerRule; -use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; - -use datafusion_common::config::OptimizerOptions; -use datafusion_common::plan_err; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_physical_expr::intervals::utils::{check_support, is_datatype_supported}; -use datafusion_physical_plan::joins::SymmetricHashJoinExec; - -/// The PipelineChecker rule rejects non-runnable query plans that use -/// pipeline-breaking operators on infinite input(s). -#[derive(Default)] -pub struct PipelineChecker {} - -impl PipelineChecker { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -impl PhysicalOptimizerRule for PipelineChecker { - fn optimize( - &self, - plan: Arc, - config: &ConfigOptions, - ) -> Result> { - plan.transform_up(|p| check_finiteness_requirements(p, &config.optimizer)) - .data() - } - - fn name(&self) -> &str { - "PipelineChecker" - } - - fn schema_check(&self) -> bool { - true - } -} - -/// This function propagates finiteness information and rejects any plan with -/// pipeline-breaking operators acting on infinite inputs. -pub fn check_finiteness_requirements( - input: Arc, - optimizer_options: &OptimizerOptions, -) -> Result>> { - if let Some(exec) = input.as_any().downcast_ref::() { - if !(optimizer_options.allow_symmetric_joins_without_pruning - || (exec.check_if_order_information_available()? && is_prunable(exec))) - { - return plan_err!("Join operation cannot operate on a non-prunable stream without enabling \ - the 'allow_symmetric_joins_without_pruning' configuration flag"); - } - } - if !input.execution_mode().pipeline_friendly() { - plan_err!( - "Cannot execute pipeline breaking queries, operator: {:?}", - input - ) - } else { - Ok(Transformed::no(input)) - } -} - -/// This function returns whether a given symmetric hash join is amenable to -/// data pruning. For this to be possible, it needs to have a filter where -/// all involved [`PhysicalExpr`]s, [`Operator`]s and data types support -/// interval calculations. -/// -/// [`PhysicalExpr`]: crate::physical_plan::PhysicalExpr -/// [`Operator`]: datafusion_expr::Operator -fn is_prunable(join: &SymmetricHashJoinExec) -> bool { - join.filter().map_or(false, |filter| { - check_support(filter.expression(), &join.schema()) - && filter - .schema() - .fields() - .iter() - .all(|f| is_datatype_supported(f.data_type())) - }) -} - -#[cfg(test)] -mod sql_tests { - use super::*; - use crate::physical_optimizer::test_utils::{ - BinaryTestCase, QueryCase, SourceType, UnaryTestCase, - }; - - #[tokio::test] - async fn test_hash_left_join_swap() -> Result<()> { - let test1 = BinaryTestCase { - source_types: (SourceType::Unbounded, SourceType::Bounded), - expect_fail: false, - }; - - let test2 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Unbounded), - expect_fail: true, - }; - let test3 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Bounded), - expect_fail: false, - }; - let case = QueryCase { - sql: "SELECT t2.c1 FROM left as t1 LEFT JOIN right as t2 ON t1.c1 = t2.c1" - .to_string(), - cases: vec![Arc::new(test1), Arc::new(test2), Arc::new(test3)], - error_operator: "operator: HashJoinExec".to_string(), - }; - - case.run().await?; - Ok(()) - } - - #[tokio::test] - async fn test_hash_right_join_swap() -> Result<()> { - let test1 = BinaryTestCase { - source_types: (SourceType::Unbounded, SourceType::Bounded), - expect_fail: true, - }; - let test2 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Unbounded), - expect_fail: false, - }; - let test3 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Bounded), - expect_fail: false, - }; - let case = QueryCase { - sql: "SELECT t2.c1 FROM left as t1 RIGHT JOIN right as t2 ON t1.c1 = t2.c1" - .to_string(), - cases: vec![Arc::new(test1), Arc::new(test2), Arc::new(test3)], - error_operator: "operator: HashJoinExec".to_string(), - }; - - case.run().await?; - Ok(()) - } - - #[tokio::test] - async fn test_hash_inner_join_swap() -> Result<()> { - let test1 = BinaryTestCase { - source_types: (SourceType::Unbounded, SourceType::Bounded), - expect_fail: false, - }; - let test2 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Unbounded), - expect_fail: false, - }; - let test3 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Bounded), - expect_fail: false, - }; - let case = QueryCase { - sql: "SELECT t2.c1 FROM left as t1 JOIN right as t2 ON t1.c1 = t2.c1" - .to_string(), - cases: vec![Arc::new(test1), Arc::new(test2), Arc::new(test3)], - error_operator: "Join Error".to_string(), - }; - - case.run().await?; - Ok(()) - } - - #[tokio::test] - async fn test_hash_full_outer_join_swap() -> Result<()> { - let test1 = BinaryTestCase { - source_types: (SourceType::Unbounded, SourceType::Bounded), - expect_fail: true, - }; - let test2 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Unbounded), - expect_fail: true, - }; - let test3 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Bounded), - expect_fail: false, - }; - let case = QueryCase { - sql: "SELECT t2.c1 FROM left as t1 FULL JOIN right as t2 ON t1.c1 = t2.c1" - .to_string(), - cases: vec![Arc::new(test1), Arc::new(test2), Arc::new(test3)], - error_operator: "operator: HashJoinExec".to_string(), - }; - - case.run().await?; - Ok(()) - } - - #[tokio::test] - async fn test_aggregate() -> Result<()> { - let test1 = UnaryTestCase { - source_type: SourceType::Bounded, - expect_fail: false, - }; - let test2 = UnaryTestCase { - source_type: SourceType::Unbounded, - expect_fail: true, - }; - let case = QueryCase { - sql: "SELECT c1, MIN(c4) FROM test GROUP BY c1".to_string(), - cases: vec![Arc::new(test1), Arc::new(test2)], - error_operator: "operator: AggregateExec".to_string(), - }; - - case.run().await?; - Ok(()) - } - - #[tokio::test] - async fn test_window_agg_hash_partition() -> Result<()> { - let test1 = UnaryTestCase { - source_type: SourceType::Bounded, - expect_fail: false, - }; - let test2 = UnaryTestCase { - source_type: SourceType::Unbounded, - expect_fail: true, - }; - let case = QueryCase { - sql: "SELECT - c9, - SUM(c9) OVER(PARTITION BY c1 ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) as sum1 - FROM test - LIMIT 5".to_string(), - cases: vec![Arc::new(test1), Arc::new(test2)], - error_operator: "operator: SortExec".to_string() - }; - - case.run().await?; - Ok(()) - } - - #[tokio::test] - async fn test_window_agg_single_partition() -> Result<()> { - let test1 = UnaryTestCase { - source_type: SourceType::Bounded, - expect_fail: false, - }; - let test2 = UnaryTestCase { - source_type: SourceType::Unbounded, - expect_fail: true, - }; - let case = QueryCase { - sql: "SELECT - c9, - SUM(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) as sum1 - FROM test".to_string(), - cases: vec![Arc::new(test1), Arc::new(test2)], - error_operator: "operator: SortExec".to_string() - }; - case.run().await?; - Ok(()) - } - - #[tokio::test] - async fn test_hash_cross_join() -> Result<()> { - let test1 = BinaryTestCase { - source_types: (SourceType::Unbounded, SourceType::Bounded), - expect_fail: true, - }; - let test2 = BinaryTestCase { - source_types: (SourceType::Unbounded, SourceType::Unbounded), - expect_fail: true, - }; - let test3 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Unbounded), - expect_fail: true, - }; - let test4 = BinaryTestCase { - source_types: (SourceType::Bounded, SourceType::Bounded), - expect_fail: false, - }; - let case = QueryCase { - sql: "SELECT t2.c1 FROM left as t1 CROSS JOIN right as t2".to_string(), - cases: vec![ - Arc::new(test1), - Arc::new(test2), - Arc::new(test3), - Arc::new(test4), - ], - error_operator: "operator: CrossJoinExec".to_string(), - }; - - case.run().await?; - Ok(()) - } - - #[tokio::test] - async fn test_analyzer() -> Result<()> { - let test1 = UnaryTestCase { - source_type: SourceType::Bounded, - expect_fail: false, - }; - let test2 = UnaryTestCase { - source_type: SourceType::Unbounded, - expect_fail: false, - }; - let case = QueryCase { - sql: "EXPLAIN ANALYZE SELECT * FROM test".to_string(), - cases: vec![Arc::new(test1), Arc::new(test2)], - error_operator: "Analyze Error".to_string(), - }; - - case.run().await?; - Ok(()) - } -} diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 359916de0f1e..5aecf036ce18 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -18,13 +18,12 @@ //! This file implements the `ProjectionPushdown` physical optimization rule. //! The function [`remove_unnecessary_projections`] tries to push down all //! projections one by one if the operator below is amenable to this. If a -//! projection reaches a source, it can even dissappear from the plan entirely. +//! projection reaches a source, it can even disappear from the plan entirely. use std::collections::HashMap; use std::sync::Arc; use super::output_requirements::OutputRequirementExec; -use super::PhysicalOptimizerRule; use crate::datasource::physical_plan::CsvExec; use crate::error::Result; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; @@ -46,7 +45,7 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; -use datafusion_common::{DataFusionError, JoinSide}; +use datafusion_common::{internal_err, JoinSide}; use datafusion_physical_expr::expressions::{Column, Literal}; use datafusion_physical_expr::{ utils::collect_columns, Partitioning, PhysicalExpr, PhysicalExprRef, @@ -55,11 +54,13 @@ use datafusion_physical_expr::{ use datafusion_physical_plan::streaming::StreamingTableExec; use datafusion_physical_plan::union::UnionExec; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_optimizer::PhysicalOptimizerRule; use itertools::Itertools; /// This rule inspects [`ProjectionExec`]'s in the given physical plan and tries to /// remove or swap with its child. -#[derive(Default)] +#[derive(Default, Debug)] pub struct ProjectionPushdown {} impl ProjectionPushdown { @@ -123,7 +124,10 @@ pub fn remove_unnecessary_projections( } else if input.is::() { try_swapping_with_coalesce_partitions(projection)? } else if let Some(filter) = input.downcast_ref::() { - try_swapping_with_filter(projection, filter)? + try_swapping_with_filter(projection, filter)?.map_or_else( + || try_embed_projection(projection, filter), + |e| Ok(Some(e)), + )? } else if let Some(repartition) = input.downcast_ref::() { try_swapping_with_repartition(projection, repartition)? } else if let Some(sort) = input.downcast_ref::() { @@ -134,7 +138,7 @@ pub fn remove_unnecessary_projections( try_pushdown_through_union(projection, union)? } else if let Some(hash_join) = input.downcast_ref::() { try_pushdown_through_hash_join(projection, hash_join)?.map_or_else( - || try_embed_to_hash_join(projection, hash_join), + || try_embed_projection(projection, hash_join), |e| Ok(Some(e)), )? } else if let Some(cross_join) = input.downcast_ref::() { @@ -179,14 +183,17 @@ fn try_swapping_with_csv( ); file_scan.projection = Some(new_projections); - Arc::new(CsvExec::new( - file_scan, - csv.has_header(), - csv.delimiter(), - csv.quote(), - csv.escape(), - csv.file_compression_type, - )) as _ + Arc::new( + CsvExec::builder(file_scan) + .with_has_header(csv.has_header()) + .with_delimeter(csv.delimiter()) + .with_quote(csv.quote()) + .with_escape(csv.escape()) + .with_comment(csv.comment()) + .with_newlines_in_values(csv.newlines_in_values()) + .with_file_compression_type(csv.file_compression_type) + .build(), + ) as _ }) } @@ -239,7 +246,7 @@ fn try_swapping_with_streaming_table( let mut lex_orderings = vec![]; for lex_ordering in streaming_table.projected_output_ordering().into_iter() { - let mut orderings = vec![]; + let mut orderings = LexOrdering::default(); for order in lex_ordering { let Some(new_ordering) = update_expr(&order.expr, projection.expr(), false)? else { @@ -259,6 +266,7 @@ fn try_swapping_with_streaming_table( Some(new_projections.as_ref()), lex_orderings, streaming_table.is_infinite(), + streaming_table.limit(), ) .map(|e| Some(Arc::new(e) as _)) } @@ -327,10 +335,10 @@ fn try_swapping_with_output_req( return Ok(None); } - let mut updated_sort_reqs = vec![]; + let mut updated_sort_reqs = LexRequirement::new(vec![]); // None or empty_vec can be treated in the same way. if let Some(reqs) = &output_req.required_input_ordering()[0] { - for req in reqs { + for req in &reqs.inner { let Some(new_expr) = update_expr(&req.expr, projection.expr(), false)? else { return Ok(None); }; @@ -377,7 +385,7 @@ fn try_swapping_with_coalesce_partitions( return Ok(None); } // CoalescePartitionsExec always has a single child, so zero indexing is safe. - make_with_child(projection, &projection.input().children()[0]) + make_with_child(projection, projection.input().children()[0]) .map(|e| Some(Arc::new(CoalescePartitionsExec::new(e)) as _)) } @@ -459,7 +467,7 @@ fn try_swapping_with_sort( return Ok(None); } - let mut updated_exprs = vec![]; + let mut updated_exprs = LexOrdering::default(); for sort in sort.expr() { let Some(new_expr) = update_expr(&sort.expr, projection.expr(), false)? else { return Ok(None); @@ -489,7 +497,7 @@ fn try_swapping_with_sort_preserving_merge( return Ok(None); } - let mut updated_exprs = vec![]; + let mut updated_exprs = LexOrdering::default(); for sort in spm.expr() { let Some(updated_expr) = update_expr(&sort.expr, projection.expr(), false)? else { @@ -525,17 +533,33 @@ fn try_pushdown_through_union( let new_children = union .children() .into_iter() - .map(|child| make_with_child(projection, &child)) + .map(|child| make_with_child(projection, child)) .collect::>>()?; Ok(Some(Arc::new(UnionExec::new(new_children)))) } +trait EmbeddedProjection: ExecutionPlan + Sized { + fn with_projection(&self, projection: Option>) -> Result; +} + +impl EmbeddedProjection for HashJoinExec { + fn with_projection(&self, projection: Option>) -> Result { + self.with_projection(projection) + } +} + +impl EmbeddedProjection for FilterExec { + fn with_projection(&self, projection: Option>) -> Result { + self.with_projection(projection) + } +} + /// Some projection can't be pushed down left input or right input of hash join because filter or on need may need some columns that won't be used in later. /// By embed those projection to hash join, we can reduce the cost of build_batch_from_indices in hash join (build_batch_from_indices need to can compute::take() for each column) and avoid unnecessary output creation. -fn try_embed_to_hash_join( +fn try_embed_projection( projection: &ProjectionExec, - hash_join: &HashJoinExec, + execution_plan: &Exec, ) -> Result>> { // Collect all column indices from the given projection expressions. let projection_index = collect_column_indices(projection.expr()); @@ -545,20 +569,20 @@ fn try_embed_to_hash_join( }; // If the projection indices is the same as the input columns, we don't need to embed the projection to hash join. - // Check the projection_index is 0..n-1 and the length of projection_index is the same as the length of hash_join schema fields. + // Check the projection_index is 0..n-1 and the length of projection_index is the same as the length of execution_plan schema fields. if projection_index.len() == projection_index.last().unwrap() + 1 - && projection_index.len() == hash_join.schema().fields().len() + && projection_index.len() == execution_plan.schema().fields().len() { return Ok(None); } - let new_hash_join = - Arc::new(hash_join.with_projection(Some(projection_index.to_vec()))?); + let new_execution_plan = + Arc::new(execution_plan.with_projection(Some(projection_index.to_vec()))?); - // Build projection expressions for update_expr. Zip the projection_index with the new_hash_join output schema fields. + // Build projection expressions for update_expr. Zip the projection_index with the new_execution_plan output schema fields. let embed_project_exprs = projection_index .iter() - .zip(new_hash_join.schema().fields()) + .zip(new_execution_plan.schema().fields()) .map(|(index, field)| { ( Arc::new(Column::new(field.name(), *index)) as Arc, @@ -579,10 +603,10 @@ fn try_embed_to_hash_join( // Old projection may contain some alias or expression such as `a + 1` and `CAST('true' AS BOOLEAN)`, but our projection_exprs in hash join just contain column, so we need to create the new projection to keep the original projection. let new_projection = Arc::new(ProjectionExec::try_new( new_projection_exprs, - new_hash_join.clone(), + new_execution_plan.clone(), )?); if is_projection_removable(&new_projection) { - Ok(Some(new_hash_join)) + Ok(Some(new_execution_plan)) } else { Ok(Some(new_projection)) } @@ -627,7 +651,7 @@ fn try_pushdown_through_hash_join( if !join_allows_pushdown( &projection_as_columns, - hash_join.schema(), + &hash_join.schema(), far_right_left_col_ind, far_left_right_col_ind, ) { @@ -638,6 +662,7 @@ fn try_pushdown_through_hash_join( &projection_as_columns[0..=far_right_left_col_ind as _], &projection_as_columns[far_left_right_col_ind as _..], hash_join.on(), + hash_join.left().schema().fields().len(), ) else { return Ok(None); }; @@ -647,8 +672,7 @@ fn try_pushdown_through_hash_join( &projection_as_columns[0..=far_right_left_col_ind as _], &projection_as_columns[far_left_right_col_ind as _..], filter, - hash_join.left(), - hash_join.right(), + hash_join.left().schema().fields().len(), ) { Some(updated_filter) => Some(updated_filter), None => return Ok(None), @@ -658,7 +682,7 @@ fn try_pushdown_through_hash_join( }; let (new_left, new_right) = new_join_children( - projection_as_columns, + &projection_as_columns, far_right_left_col_ind, far_left_right_col_ind, hash_join.left(), @@ -696,7 +720,7 @@ fn try_swapping_with_cross_join( if !join_allows_pushdown( &projection_as_columns, - cross_join.schema(), + &cross_join.schema(), far_right_left_col_ind, far_left_right_col_ind, ) { @@ -704,7 +728,7 @@ fn try_swapping_with_cross_join( } let (new_left, new_right) = new_join_children( - projection_as_columns, + &projection_as_columns, far_right_left_col_ind, far_left_right_col_ind, cross_join.left(), @@ -736,7 +760,7 @@ fn try_swapping_with_nested_loop_join( if !join_allows_pushdown( &projection_as_columns, - nl_join.schema(), + &nl_join.schema(), far_right_left_col_ind, far_left_right_col_ind, ) { @@ -748,8 +772,7 @@ fn try_swapping_with_nested_loop_join( &projection_as_columns[0..=far_right_left_col_ind as _], &projection_as_columns[far_left_right_col_ind as _..], filter, - nl_join.left(), - nl_join.right(), + nl_join.left().schema().fields().len(), ) { Some(updated_filter) => Some(updated_filter), None => return Ok(None), @@ -759,7 +782,7 @@ fn try_swapping_with_nested_loop_join( }; let (new_left, new_right) = new_join_children( - projection_as_columns, + &projection_as_columns, far_right_left_col_ind, far_left_right_col_ind, nl_join.left(), @@ -793,7 +816,7 @@ fn try_swapping_with_sort_merge_join( if !join_allows_pushdown( &projection_as_columns, - sm_join.schema(), + &sm_join.schema(), far_right_left_col_ind, far_left_right_col_ind, ) { @@ -804,16 +827,17 @@ fn try_swapping_with_sort_merge_join( &projection_as_columns[0..=far_right_left_col_ind as _], &projection_as_columns[far_left_right_col_ind as _..], sm_join.on(), + sm_join.left().schema().fields().len(), ) else { return Ok(None); }; let (new_left, new_right) = new_join_children( - projection_as_columns, + &projection_as_columns, far_right_left_col_ind, far_left_right_col_ind, - &sm_join.children()[0], - &sm_join.children()[1], + sm_join.children()[0], + sm_join.children()[1], )?; Ok(Some(Arc::new(SortMergeJoinExec::try_new( @@ -846,7 +870,7 @@ fn try_swapping_with_sym_hash_join( if !join_allows_pushdown( &projection_as_columns, - sym_join.schema(), + &sym_join.schema(), far_right_left_col_ind, far_left_right_col_ind, ) { @@ -857,6 +881,7 @@ fn try_swapping_with_sym_hash_join( &projection_as_columns[0..=far_right_left_col_ind as _], &projection_as_columns[far_left_right_col_ind as _..], sym_join.on(), + sym_join.left().schema().fields().len(), ) else { return Ok(None); }; @@ -866,8 +891,7 @@ fn try_swapping_with_sym_hash_join( &projection_as_columns[0..=far_right_left_col_ind as _], &projection_as_columns[far_left_right_col_ind as _..], filter, - sym_join.left(), - sym_join.right(), + sym_join.left().schema().fields().len(), ) { Some(updated_filter) => Some(updated_filter), None => return Ok(None), @@ -877,7 +901,7 @@ fn try_swapping_with_sym_hash_join( }; let (new_left, new_right) = new_join_children( - projection_as_columns, + &projection_as_columns, far_right_left_col_ind, far_left_right_col_ind, sym_join.left(), @@ -891,8 +915,14 @@ fn try_swapping_with_sym_hash_join( new_filter, sym_join.join_type(), sym_join.null_equals_null(), - sym_join.right().output_ordering().map(|p| p.to_vec()), - sym_join.left().output_ordering().map(|p| p.to_vec()), + sym_join + .right() + .output_ordering() + .map(|p| LexOrdering::new(p.to_vec())), + sym_join + .left() + .output_ordering() + .map(|p| LexOrdering::new(p.to_vec())), sym_join.partition_mode(), )?))) } @@ -1088,6 +1118,7 @@ fn update_join_on( proj_left_exprs: &[(Column, String)], proj_right_exprs: &[(Column, String)], hash_join_on: &[(PhysicalExprRef, PhysicalExprRef)], + left_field_size: usize, ) -> Option> { // TODO: Clippy wants the "map" call removed, but doing so generates // a compilation error. Remove the clippy directive once this @@ -1098,8 +1129,9 @@ fn update_join_on( .map(|(left, right)| (left, right)) .unzip(); - let new_left_columns = new_columns_for_join_on(&left_idx, proj_left_exprs); - let new_right_columns = new_columns_for_join_on(&right_idx, proj_right_exprs); + let new_left_columns = new_columns_for_join_on(&left_idx, proj_left_exprs, 0); + let new_right_columns = + new_columns_for_join_on(&right_idx, proj_right_exprs, left_field_size); match (new_left_columns, new_right_columns) { (Some(left), Some(right)) => Some(left.into_iter().zip(right).collect()), @@ -1110,9 +1142,14 @@ fn update_join_on( /// This function generates a new set of columns to be used in a hash join /// operation based on a set of equi-join conditions (`hash_join_on`) and a /// list of projection expressions (`projection_exprs`). +/// +/// Notes: Column indices in the projection expressions are based on the join schema, +/// whereas the join on expressions are based on the join child schema. `column_index_offset` +/// represents the offset between them. fn new_columns_for_join_on( hash_join_on: &[&PhysicalExprRef], projection_exprs: &[(Column, String)], + column_index_offset: usize, ) -> Option> { let new_columns = hash_join_on .iter() @@ -1128,6 +1165,8 @@ fn new_columns_for_join_on( .enumerate() .find(|(_, (proj_column, _))| { column.name() == proj_column.name() + && column.index() + column_index_offset + == proj_column.index() }) .map(|(index, (_, alias))| Column::new(alias, index)); if let Some(new_column) = new_column { @@ -1136,10 +1175,10 @@ fn new_columns_for_join_on( // If the column is not found in the projection expressions, // it means that the column is not projected. In this case, // we cannot push the projection down. - Err(DataFusionError::Internal(format!( + internal_err!( "Column {:?} not found in projection expressions", column - ))) + ) } } else { Ok(Transformed::no(expr)) @@ -1158,21 +1197,20 @@ fn update_join_filter( projection_left_exprs: &[(Column, String)], projection_right_exprs: &[(Column, String)], join_filter: &JoinFilter, - join_left: &Arc, - join_right: &Arc, + left_field_size: usize, ) -> Option { let mut new_left_indices = new_indices_for_join_filter( join_filter, JoinSide::Left, projection_left_exprs, - join_left.schema(), + 0, ) .into_iter(); let mut new_right_indices = new_indices_for_join_filter( join_filter, JoinSide::Right, projection_right_exprs, - join_right.schema(), + left_field_size, ) .into_iter(); @@ -1202,20 +1240,24 @@ fn update_join_filter( /// This function determines and returns a vector of indices representing the /// positions of columns in `projection_exprs` that are involved in `join_filter`, /// and correspond to a particular side (`join_side`) of the join operation. +/// +/// Notes: Column indices in the projection expressions are based on the join schema, +/// whereas the join filter is based on the join child schema. `column_index_offset` +/// represents the offset between them. fn new_indices_for_join_filter( join_filter: &JoinFilter, join_side: JoinSide, projection_exprs: &[(Column, String)], - join_child_schema: SchemaRef, + column_index_offset: usize, ) -> Vec { join_filter .column_indices() .iter() .filter(|col_idx| col_idx.side == join_side) .filter_map(|col_idx| { - projection_exprs.iter().position(|(col, _)| { - col.name() == join_child_schema.fields()[col_idx.index].name() - }) + projection_exprs + .iter() + .position(|(col, _)| col_idx.index + column_index_offset == col.index()) }) .collect() } @@ -1227,7 +1269,7 @@ fn new_indices_for_join_filter( /// - Left or right table is not lost after the projection. fn join_allows_pushdown( projection_as_columns: &[(Column, String)], - join_schema: SchemaRef, + join_schema: &SchemaRef, far_right_left_col_ind: i32, far_left_right_col_ind: i32, ) -> bool { @@ -1244,7 +1286,7 @@ fn join_allows_pushdown( /// this function constructs the new [`ProjectionExec`]s that will come on top /// of the original children of the join. fn new_join_children( - projection_as_columns: Vec<(Column, String)>, + projection_as_columns: &[(Column, String)], far_right_left_col_ind: i32, far_left_right_col_ind: i32, left_child: &Arc, @@ -1296,12 +1338,11 @@ mod tests { use crate::physical_plan::joins::StreamJoinPartitionMode; use arrow_schema::{DataType, Field, Schema, SortOptions}; - use datafusion_common::{JoinType, ScalarValue, Statistics}; + use datafusion_common::{JoinType, ScalarValue}; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_expr::{ - ColumnarValue, Operator, ScalarFunctionDefinition, ScalarUDF, ScalarUDFImpl, - Signature, Volatility, + ColumnarValue, Operator, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; use datafusion_physical_expr::expressions::{ BinaryExpr, CaseExpr, CastExpr, NegativeExpr, @@ -1362,9 +1403,7 @@ mod tests { Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( - DummyUDF::new(), - ))), + Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b", 1)), @@ -1378,8 +1417,6 @@ mod tests { )), ], DataType::Int32, - None, - false, )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 2))), @@ -1431,9 +1468,7 @@ mod tests { Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 5)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( - DummyUDF::new(), - ))), + Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b", 1)), @@ -1447,8 +1482,6 @@ mod tests { )), ], DataType::Int32, - None, - false, )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 3))), @@ -1503,9 +1536,7 @@ mod tests { Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( - DummyUDF::new(), - ))), + Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b", 1)), @@ -1519,8 +1550,6 @@ mod tests { )), ], DataType::Int32, - None, - false, )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 2))), @@ -1572,9 +1601,7 @@ mod tests { Arc::new(NegativeExpr::new(Arc::new(Column::new("f_new", 5)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( - DummyUDF::new(), - ))), + Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b_new", 1)), @@ -1588,8 +1615,6 @@ mod tests { )), ], DataType::Int32, - None, - false, )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d_new", 3))), @@ -1691,23 +1716,21 @@ mod tests { Field::new("d", DataType::Int32, true), Field::new("e", DataType::Int32, true), ])); - Arc::new(CsvExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), - file_schema: schema.clone(), - file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], - statistics: Statistics::new_unknown(&schema), - projection: Some(vec![0, 1, 2, 3, 4]), - limit: None, - table_partition_cols: vec![], - output_ordering: vec![vec![]], - }, - false, - 0, - 0, - None, - FileCompressionType::UNCOMPRESSED, - )) + Arc::new( + CsvExec::builder( + FileScanConfig::new(ObjectStoreUrl::parse("test:///").unwrap(), schema) + .with_file(PartitionedFile::new("x".to_string(), 100)) + .with_projection(Some(vec![0, 1, 2, 3, 4])), + ) + .with_has_header(false) + .with_delimeter(0) + .with_quote(0) + .with_escape(None) + .with_comment(None) + .with_newlines_in_values(false) + .with_file_compression_type(FileCompressionType::UNCOMPRESSED) + .build(), + ) } fn create_projecting_csv_exec() -> Arc { @@ -1717,23 +1740,21 @@ mod tests { Field::new("c", DataType::Int32, true), Field::new("d", DataType::Int32, true), ])); - Arc::new(CsvExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), - file_schema: schema.clone(), - file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], - statistics: Statistics::new_unknown(&schema), - projection: Some(vec![3, 2, 1]), - limit: None, - table_partition_cols: vec![], - output_ordering: vec![vec![]], - }, - false, - 0, - 0, - None, - FileCompressionType::UNCOMPRESSED, - )) + Arc::new( + CsvExec::builder( + FileScanConfig::new(ObjectStoreUrl::parse("test:///").unwrap(), schema) + .with_file(PartitionedFile::new("x".to_string(), 100)) + .with_projection(Some(vec![3, 2, 1])), + ) + .with_has_header(false) + .with_delimeter(0) + .with_quote(0) + .with_escape(None) + .with_comment(None) + .with_newlines_in_values(false) + .with_file_compression_type(FileCompressionType::UNCOMPRESSED) + .build(), + ) } fn create_projecting_memory_exec() -> Arc { @@ -1816,6 +1837,7 @@ mod tests { #[test] fn test_streaming_table_after_projection() -> Result<()> { + #[derive(Debug)] struct DummyStreamPartition { schema: SchemaRef, } @@ -1847,7 +1869,7 @@ mod tests { }) as _], Some(&vec![0_usize, 2, 4, 3]), vec![ - vec![ + LexOrdering::new(vec![ PhysicalSortExpr { expr: Arc::new(Column::new("e", 2)), options: SortOptions::default(), @@ -1856,14 +1878,15 @@ mod tests { expr: Arc::new(Column::new("a", 0)), options: SortOptions::default(), }, - ], - vec![PhysicalSortExpr { + ]), + LexOrdering::new(vec![PhysicalSortExpr { expr: Arc::new(Column::new("d", 3)), options: SortOptions::default(), - }], + }]), ] .into_iter(), true, + None, )?; let projection = Arc::new(ProjectionExec::try_new( vec![ @@ -1906,7 +1929,7 @@ mod tests { assert_eq!( result.projected_output_ordering().into_iter().collect_vec(), vec![ - vec![ + LexOrdering::new(vec![ PhysicalSortExpr { expr: Arc::new(Column::new("e", 1)), options: SortOptions::default(), @@ -1915,11 +1938,11 @@ mod tests { expr: Arc::new(Column::new("a", 2)), options: SortOptions::default(), }, - ], - vec![PhysicalSortExpr { + ]), + LexOrdering::new(vec![PhysicalSortExpr { expr: Arc::new(Column::new("d", 0)), options: SortOptions::default(), - }], + }]), ] ); assert!(result.is_infinite()); @@ -1980,7 +2003,7 @@ mod tests { let csv = create_simple_csv_exec(); let sort_req: Arc = Arc::new(OutputRequirementExec::new( csv.clone(), - Some(vec![ + Some(LexRequirement::new(vec![ PhysicalSortRequirement { expr: Arc::new(Column::new("b", 1)), options: Some(SortOptions::default()), @@ -1993,7 +2016,7 @@ mod tests { )), options: Some(SortOptions::default()), }, - ]), + ])), Distribution::HashPartitioned(vec![ Arc::new(Column::new("a", 0)), Arc::new(Column::new("b", 1)), @@ -2026,7 +2049,7 @@ mod tests { ]; assert_eq!(get_plan_string(&after_optimize), expected); - let expected_reqs = vec![ + let expected_reqs = LexRequirement::new(vec![ PhysicalSortRequirement { expr: Arc::new(Column::new("b", 2)), options: Some(SortOptions::default()), @@ -2039,7 +2062,7 @@ mod tests { )), options: Some(SortOptions::default()), }, - ]; + ]); assert_eq!( after_optimize .as_any() @@ -2220,9 +2243,9 @@ mod tests { )?); let initial = get_plan_string(&projection); let expected_initial = [ - "ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, a@5 as a_from_right, c@7 as c_from_right]", - " SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", - " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + "ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left, a@5 as a_from_right, c@7 as c_from_right]", + " SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" ]; assert_eq!(initial, expected_initial); @@ -2231,10 +2254,10 @@ mod tests { ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; let expected = [ - "SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b_from_left@1, c_from_right@1)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", - " ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left]", - " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", - " ProjectionExec: expr=[a@0 as a_from_right, c@2 as c_from_right]", + "SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b_from_left@1, c_from_right@1)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", + " ProjectionExec: expr=[c@2 as c_from_left, b@1 as b_from_left, a@0 as a_from_left]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " ProjectionExec: expr=[a@0 as a_from_right, c@2 as c_from_right]", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" ]; assert_eq!(get_plan_string(&after_optimize), expected); @@ -2339,9 +2362,9 @@ mod tests { )?); let initial = get_plan_string(&projection); let expected_initial = [ - "ProjectionExec: expr=[a@5 as a, b@6 as b, c@7 as c, d@8 as d, e@9 as e, a@0 as a, b@1 as b, c@2 as c, d@3 as d, e@4 as e]", - " SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", - " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + "ProjectionExec: expr=[a@5 as a, b@6 as b, c@7 as c, d@8 as d, e@9 as e, a@0 as a, b@1 as b, c@2 as c, d@3 as d, e@4 as e]", + " SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" ]; assert_eq!(initial, expected_initial); @@ -2350,9 +2373,9 @@ mod tests { ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; let expected = [ - "ProjectionExec: expr=[a@5 as a, b@6 as b, c@7 as c, d@8 as d, e@9 as e, a@0 as a, b@1 as b, c@2 as c, d@3 as d, e@4 as e]", - " SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", - " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + "ProjectionExec: expr=[a@5 as a, b@6 as b, c@7 as c, d@8 as d, e@9 as e, a@0 as a, b@1 as b, c@2 as c, d@3 as d, e@4 as e]", + " SymmetricHashJoinExec: mode=SinglePartition, join_type=Inner, on=[(b@1, c@2)], filter=b_left_inter@0 - 1 + a_right_inter@1 <= a_right_inter@1 + c_left_inter@2", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" ]; assert_eq!(get_plan_string(&after_optimize), expected); @@ -2506,8 +2529,8 @@ mod tests { ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; let expected = [ - "RepartitionExec: partitioning=Hash([a@1, b_new@0, d_new@2], 6), input_partitions=1", - " ProjectionExec: expr=[b@1 as b_new, a@0 as a, d@3 as d_new]", + "RepartitionExec: partitioning=Hash([a@1, b_new@0, d_new@2], 6), input_partitions=1", + " ProjectionExec: expr=[b@1 as b_new, a@0 as a, d@3 as d_new]", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", ]; assert_eq!(get_plan_string(&after_optimize), expected); @@ -2536,7 +2559,7 @@ mod tests { fn test_sort_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); let sort_req: Arc = Arc::new(SortExec::new( - vec![ + LexOrdering::new(vec![ PhysicalSortExpr { expr: Arc::new(Column::new("b", 1)), options: SortOptions::default(), @@ -2549,7 +2572,7 @@ mod tests { )), options: SortOptions::default(), }, - ], + ]), csv.clone(), )); let projection: Arc = Arc::new(ProjectionExec::try_new( @@ -2564,7 +2587,7 @@ mod tests { let initial = get_plan_string(&projection); let expected_initial = [ "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", - " SortExec: expr=[b@1 ASC,c@2 + a@0 ASC]", + " SortExec: expr=[b@1 ASC, c@2 + a@0 ASC], preserve_partitioning=[false]", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" ]; assert_eq!(initial, expected_initial); @@ -2573,7 +2596,7 @@ mod tests { ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; let expected = [ - "SortExec: expr=[b@2 ASC,c@0 + new_a@1 ASC]", + "SortExec: expr=[b@2 ASC, c@0 + new_a@1 ASC], preserve_partitioning=[false]", " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" ]; @@ -2586,7 +2609,7 @@ mod tests { fn test_sort_preserving_after_projection() -> Result<()> { let csv = create_simple_csv_exec(); let sort_req: Arc = Arc::new(SortPreservingMergeExec::new( - vec![ + LexOrdering::new(vec![ PhysicalSortExpr { expr: Arc::new(Column::new("b", 1)), options: SortOptions::default(), @@ -2599,7 +2622,7 @@ mod tests { )), options: SortOptions::default(), }, - ], + ]), csv.clone(), )); let projection: Arc = Arc::new(ProjectionExec::try_new( @@ -2614,7 +2637,7 @@ mod tests { let initial = get_plan_string(&projection); let expected_initial = [ "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", - " SortPreservingMergeExec: [b@1 ASC,c@2 + a@0 ASC]", + " SortPreservingMergeExec: [b@1 ASC, c@2 + a@0 ASC]", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" ]; assert_eq!(initial, expected_initial); @@ -2623,7 +2646,7 @@ mod tests { ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; let expected = [ - "SortPreservingMergeExec: [b@2 ASC,c@0 + new_a@1 ASC]", + "SortPreservingMergeExec: [b@2 ASC, c@0 + new_a@1 ASC]", " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" ]; @@ -2648,10 +2671,10 @@ mod tests { let initial = get_plan_string(&projection); let expected_initial = [ - "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", - " UnionExec", - " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", - " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + "ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " UnionExec", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" ]; assert_eq!(initial, expected_initial); @@ -2660,12 +2683,12 @@ mod tests { ProjectionPushdown::new().optimize(projection, &ConfigOptions::new())?; let expected = [ - "UnionExec", - " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", - " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", - " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", - " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", - " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + "UnionExec", + " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", + " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false", + " ProjectionExec: expr=[c@2 as c, a@0 as new_a, b@1 as b]", " CsvExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], has_header=false" ]; assert_eq!(get_plan_string(&after_optimize), expected); diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 605ef9f9023f..eb03b337779c 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -173,10 +173,10 @@ pub trait PruningStatistics { /// 1. Arbitrary expressions (including user defined functions) /// /// 2. Vectorized evaluation (provide more than one set of statistics at a time) -/// so it is suitable for pruning 1000s of containers. +/// so it is suitable for pruning 1000s of containers. /// /// 3. Any source of information that implements the [`PruningStatistics`] trait -/// (not just Parquet metadata). +/// (not just Parquet metadata). /// /// # Example /// @@ -278,17 +278,17 @@ pub trait PruningStatistics { /// 2. A predicate (expression that evaluates to a boolean) /// /// 3. [`PruningStatistics`] that provides information about columns in that -/// schema, for multiple “containers”. For each column in each container, it -/// provides optional information on contained values, min_values, max_values, -/// null_counts counts, and row_counts counts. +/// schema, for multiple “containers”. For each column in each container, it +/// provides optional information on contained values, min_values, max_values, +/// null_counts counts, and row_counts counts. /// /// **Outputs**: /// A (non null) boolean value for each container: /// * `true`: There MAY be rows that match the predicate /// /// * `false`: There are no rows that could possibly match the predicate (the -/// predicate can never possibly be true). The container can be pruned (skipped) -/// entirely. +/// predicate can never possibly be true). The container can be pruned (skipped) +/// entirely. /// /// Note that in order to be correct, `PruningPredicate` must return false /// **only** if it can determine that for all rows in the container, the @@ -458,7 +458,7 @@ pub trait PruningStatistics { /// [`Snowflake SIGMOD Paper`]: https://dl.acm.org/doi/10.1145/2882903.2903741 /// [small materialized aggregates]: https://www.vldb.org/conf/1998/p476.pdf /// [zone maps]: https://dl.acm.org/doi/10.1007/978-3-642-03730-6_10 -///[data skipping]: https://dl.acm.org/doi/10.1145/2588555.2610515 +/// [data skipping]: https://dl.acm.org/doi/10.1145/2588555.2610515 #[derive(Debug, Clone)] pub struct PruningPredicate { /// The input schema against which the predicate will be evaluated @@ -471,11 +471,43 @@ pub struct PruningPredicate { /// Original physical predicate from which this predicate expr is derived /// (required for serialization) orig_expr: Arc, - /// [`LiteralGuarantee`]s that are used to try and prove a predicate can not - /// possibly evaluate to `true`. + /// [`LiteralGuarantee`]s used to try and prove a predicate can not possibly + /// evaluate to `true`. + /// + /// See [`PruningPredicate::literal_guarantees`] for more details. literal_guarantees: Vec, } +/// Rewrites predicates that [`PredicateRewriter`] can not handle, e.g. certain +/// complex expressions or predicates that reference columns that are not in the +/// schema. +pub trait UnhandledPredicateHook { + /// Called when a predicate can not be rewritten in terms of statistics or + /// references a column that is not in the schema. + fn handle(&self, expr: &Arc) -> Arc; +} + +/// The default handling for unhandled predicates is to return a constant `true` +/// (meaning don't prune the container) +#[derive(Debug, Clone)] +struct ConstantUnhandledPredicateHook { + default: Arc, +} + +impl Default for ConstantUnhandledPredicateHook { + fn default() -> Self { + Self { + default: Arc::new(phys_expr::Literal::new(ScalarValue::from(true))), + } + } +} + +impl UnhandledPredicateHook for ConstantUnhandledPredicateHook { + fn handle(&self, _expr: &Arc) -> Arc { + self.default.clone() + } +} + impl PruningPredicate { /// Try to create a new instance of [`PruningPredicate`] /// @@ -500,10 +532,16 @@ impl PruningPredicate { /// See the struct level documentation on [`PruningPredicate`] for more /// details. pub fn try_new(expr: Arc, schema: SchemaRef) -> Result { + let unhandled_hook = Arc::new(ConstantUnhandledPredicateHook::default()) as _; + // build predicate expression once let mut required_columns = RequiredColumns::new(); - let predicate_expr = - build_predicate_expression(&expr, schema.as_ref(), &mut required_columns); + let predicate_expr = build_predicate_expression( + &expr, + schema.as_ref(), + &mut required_columns, + &unhandled_hook, + ); let literal_guarantees = LiteralGuarantee::analyze(&expr); @@ -595,6 +633,10 @@ impl PruningPredicate { } /// Returns a reference to the literal guarantees + /// + /// Note that **All** `LiteralGuarantee`s must be satisfied for the + /// expression to possibly be `true`. If any is not satisfied, the + /// expression is guaranteed to be `null` or `false`. pub fn literal_guarantees(&self) -> &[LiteralGuarantee] { &self.literal_guarantees } @@ -603,10 +645,14 @@ impl PruningPredicate { /// /// This happens if the predicate is a literal `true` and /// literal_guarantees is empty. + /// + /// This can happen when a predicate is simplified to a constant `true` pub fn always_true(&self) -> bool { is_always_true(&self.predicate_expr) && self.literal_guarantees.is_empty() } + // this is only used by `parquet` feature right now + #[allow(dead_code)] pub(crate) fn required_columns(&self) -> &RequiredColumns { &self.required_columns } @@ -617,7 +663,7 @@ impl PruningPredicate { /// /// This is useful to avoid fetching statistics for columns that will not be /// used in the predicate. For example, it can be used to avoid reading - /// uneeded bloom filters (a non trivial operation). + /// unneeded bloom filters (a non trivial operation). pub fn literal_columns(&self) -> Vec { let mut seen = HashSet::new(); self.literal_guarantees @@ -730,12 +776,27 @@ impl RequiredColumns { Self::default() } - /// Returns number of unique columns - pub(crate) fn n_columns(&self) -> usize { - self.iter() - .map(|(c, _s, _f)| c) - .collect::>() - .len() + /// Returns Some(column) if this is a single column predicate. + /// + /// Returns None if this is a multi-column predicate. + /// + /// Examples: + /// * `a > 5 OR a < 10` returns `Some(a)` + /// * `a > 5 OR b < 10` returns `None` + /// * `true` returns None + #[allow(dead_code)] + // this fn is only used by `parquet` feature right now, thus the `allow(dead_code)` + pub(crate) fn single_column(&self) -> Option<&phys_expr::Column> { + if self.columns.windows(2).all(|w| { + // check if all columns are the same (ignoring statistics and field) + let c1 = &w[0].0; + let c2 = &w[1].0; + c1 == c2 + }) { + self.columns.first().map(|r| &r.0) + } else { + None + } } /// Returns an iterator over items in columns (see doc on @@ -772,13 +833,19 @@ impl RequiredColumns { column_expr: &Arc, field: &Field, stat_type: StatisticsType, - suffix: &str, ) -> Result> { let (idx, need_to_insert) = match self.find_stat_column(column, stat_type) { Some(idx) => (idx, false), None => (self.columns.len(), true), }; + let suffix = match stat_type { + StatisticsType::Min => "min", + StatisticsType::Max => "max", + StatisticsType::NullCount => "null_count", + StatisticsType::RowCount => "row_count", + }; + let stat_column = phys_expr::Column::new(&format!("{}_{}", column.name(), suffix), idx); @@ -800,7 +867,7 @@ impl RequiredColumns { column_expr: &Arc, field: &Field, ) -> Result> { - self.stat_column_expr(column, column_expr, field, StatisticsType::Min, "min") + self.stat_column_expr(column, column_expr, field, StatisticsType::Min) } /// rewrite col --> col_max @@ -810,7 +877,7 @@ impl RequiredColumns { column_expr: &Arc, field: &Field, ) -> Result> { - self.stat_column_expr(column, column_expr, field, StatisticsType::Max, "max") + self.stat_column_expr(column, column_expr, field, StatisticsType::Max) } /// rewrite col --> col_null_count @@ -820,13 +887,7 @@ impl RequiredColumns { column_expr: &Arc, field: &Field, ) -> Result> { - self.stat_column_expr( - column, - column_expr, - field, - StatisticsType::NullCount, - "null_count", - ) + self.stat_column_expr(column, column_expr, field, StatisticsType::NullCount) } /// rewrite col --> col_row_count @@ -836,13 +897,7 @@ impl RequiredColumns { column_expr: &Arc, field: &Field, ) -> Result> { - self.stat_column_expr( - column, - column_expr, - field, - StatisticsType::RowCount, - "row_count", - ) + self.stat_column_expr(column, column_expr, field, StatisticsType::RowCount) } } @@ -1293,27 +1348,78 @@ fn build_is_null_column_expr( /// an OR chain const MAX_LIST_VALUE_SIZE_REWRITE: usize = 20; +/// Rewrite a predicate expression in terms of statistics (min/max/null_counts) +/// for use as a [`PruningPredicate`]. +pub struct PredicateRewriter { + unhandled_hook: Arc, +} + +impl Default for PredicateRewriter { + fn default() -> Self { + Self { + unhandled_hook: Arc::new(ConstantUnhandledPredicateHook::default()), + } + } +} + +impl PredicateRewriter { + /// Create a new `PredicateRewriter` + pub fn new() -> Self { + Self::default() + } + + /// Set the unhandled hook to be used when a predicate can not be rewritten + pub fn with_unhandled_hook( + self, + unhandled_hook: Arc, + ) -> Self { + Self { unhandled_hook } + } + + /// Translate logical filter expression into pruning predicate + /// expression that will evaluate to FALSE if it can be determined no + /// rows between the min/max values could pass the predicates. + /// + /// Any predicates that can not be translated will be passed to `unhandled_hook`. + /// + /// Returns the pruning predicate as an [`PhysicalExpr`] + /// + /// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will fall back to calling `unhandled_hook` + pub fn rewrite_predicate_to_statistics_predicate( + &self, + expr: &Arc, + schema: &Schema, + ) -> Arc { + let mut required_columns = RequiredColumns::new(); + build_predicate_expression( + expr, + schema, + &mut required_columns, + &self.unhandled_hook, + ) + } +} + /// Translate logical filter expression into pruning predicate /// expression that will evaluate to FALSE if it can be determined no /// rows between the min/max values could pass the predicates. /// +/// Any predicates that can not be translated will be passed to `unhandled_hook`. +/// /// Returns the pruning predicate as an [`PhysicalExpr`] /// -/// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will be rewritten to TRUE +/// Notice: Does not handle [`phys_expr::InListExpr`] greater than 20, which will fall back to calling `unhandled_hook` fn build_predicate_expression( expr: &Arc, schema: &Schema, required_columns: &mut RequiredColumns, + unhandled_hook: &Arc, ) -> Arc { - // Returned for unsupported expressions. Such expressions are - // converted to TRUE. - let unhandled = Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true)))); - // predicate expression can only be a binary expression let expr_any = expr.as_any(); if let Some(is_null) = expr_any.downcast_ref::() { return build_is_null_column_expr(is_null.arg(), schema, required_columns, false) - .unwrap_or(unhandled); + .unwrap_or_else(|| unhandled_hook.handle(expr)); } if let Some(is_not_null) = expr_any.downcast_ref::() { return build_is_null_column_expr( @@ -1322,19 +1428,19 @@ fn build_predicate_expression( required_columns, true, ) - .unwrap_or(unhandled); + .unwrap_or_else(|| unhandled_hook.handle(expr)); } if let Some(col) = expr_any.downcast_ref::() { return build_single_column_expr(col, schema, required_columns, false) - .unwrap_or(unhandled); + .unwrap_or_else(|| unhandled_hook.handle(expr)); } if let Some(not) = expr_any.downcast_ref::() { // match !col (don't do so recursively) if let Some(col) = not.arg().as_any().downcast_ref::() { return build_single_column_expr(col, schema, required_columns, true) - .unwrap_or(unhandled); + .unwrap_or_else(|| unhandled_hook.handle(expr)); } else { - return unhandled; + return unhandled_hook.handle(expr); } } if let Some(in_list) = expr_any.downcast_ref::() { @@ -1354,7 +1460,6 @@ fn build_predicate_expression( let change_expr = in_list .list() .iter() - .cloned() .map(|e| { Arc::new(phys_expr::BinaryExpr::new( in_list.expr().clone(), @@ -1364,9 +1469,14 @@ fn build_predicate_expression( }) .reduce(|a, b| Arc::new(phys_expr::BinaryExpr::new(a, re_op, b)) as _) .unwrap(); - return build_predicate_expression(&change_expr, schema, required_columns); + return build_predicate_expression( + &change_expr, + schema, + required_columns, + unhandled_hook, + ); } else { - return unhandled; + return unhandled_hook.handle(expr); } } @@ -1378,13 +1488,15 @@ fn build_predicate_expression( bin_expr.right().clone(), ) } else { - return unhandled; + return unhandled_hook.handle(expr); } }; if op == Operator::And || op == Operator::Or { - let left_expr = build_predicate_expression(&left, schema, required_columns); - let right_expr = build_predicate_expression(&right, schema, required_columns); + let left_expr = + build_predicate_expression(&left, schema, required_columns, unhandled_hook); + let right_expr = + build_predicate_expression(&right, schema, required_columns, unhandled_hook); // simplify boolean expression if applicable let expr = match (&left_expr, op, &right_expr) { (left, Operator::And, _) if is_always_true(left) => right_expr, @@ -1392,7 +1504,7 @@ fn build_predicate_expression( (left, Operator::Or, right) if is_always_true(left) || is_always_true(right) => { - unhandled + Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true)))) } _ => Arc::new(phys_expr::BinaryExpr::new(left_expr, op, right_expr)), }; @@ -1405,12 +1517,11 @@ fn build_predicate_expression( Ok(builder) => builder, // allow partial failure in predicate expression generation // this can still produce a useful predicate when multiple conditions are joined using AND - Err(_) => { - return unhandled; - } + Err(_) => return unhandled_hook.handle(expr), }; - build_statistics_expr(&mut expr_builder).unwrap_or(unhandled) + build_statistics_expr(&mut expr_builder) + .unwrap_or_else(|_| unhandled_hook.handle(expr)) } fn build_statistics_expr( @@ -1549,22 +1660,24 @@ pub(crate) enum StatisticsType { #[cfg(test)] mod tests { + use std::collections::HashMap; + use std::ops::{Not, Rem}; + use super::*; use crate::assert_batches_eq; use crate::logical_expr::{col, lit}; + use arrow::array::Decimal128Array; use arrow::{ array::{BinaryArray, Int32Array, Int64Array, StringArray}, datatypes::TimeUnit, }; use arrow_array::UInt64Array; - use datafusion_common::ToDFSchema; - use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::InList; use datafusion_expr::{cast, is_null, try_cast, Expr}; - use datafusion_physical_expr::create_physical_expr; - use std::collections::HashMap; - use std::ops::{Not, Rem}; + use datafusion_functions_nested::expr_fn::{array_has, make_array}; + use datafusion_physical_expr::expressions as phys_expr; + use datafusion_physical_expr::planner::logical2physical; #[derive(Debug, Default)] /// Mock statistic provider for tests @@ -3379,6 +3492,74 @@ mod tests { // TODO: add test for other case and op } + #[test] + fn test_rewrite_expr_to_prunable_custom_unhandled_hook() { + struct CustomUnhandledHook; + + impl UnhandledPredicateHook for CustomUnhandledHook { + /// This handles an arbitrary case of a column that doesn't exist in the schema + /// by renaming it to yet another column that doesn't exist in the schema + /// (the transformation is arbitrary, the point is that it can do whatever it wants) + fn handle(&self, _expr: &Arc) -> Arc { + Arc::new(phys_expr::Literal::new(ScalarValue::Int32(Some(42)))) + } + } + + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let schema_with_b = Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ]); + + let rewriter = PredicateRewriter::new() + .with_unhandled_hook(Arc::new(CustomUnhandledHook {})); + + let transform_expr = |expr| { + let expr = logical2physical(&expr, &schema_with_b); + rewriter.rewrite_predicate_to_statistics_predicate(&expr, &schema) + }; + + // transform an arbitrary valid expression that we know is handled + let known_expression = col("a").eq(lit(12)); + let known_expression_transformed = PredicateRewriter::new() + .rewrite_predicate_to_statistics_predicate( + &logical2physical(&known_expression, &schema), + &schema, + ); + + // an expression referencing an unknown column (that is not in the schema) gets passed to the hook + let input = col("b").eq(lit(12)); + let expected = logical2physical(&lit(42), &schema); + let transformed = transform_expr(input.clone()); + assert_eq!(transformed.to_string(), expected.to_string()); + + // more complex case with unknown column + let input = known_expression.clone().and(input.clone()); + let expected = phys_expr::BinaryExpr::new( + known_expression_transformed.clone(), + Operator::And, + logical2physical(&lit(42), &schema), + ); + let transformed = transform_expr(input.clone()); + assert_eq!(transformed.to_string(), expected.to_string()); + + // an unknown expression gets passed to the hook + let input = array_has(make_array(vec![lit(1)]), col("a")); + let expected = logical2physical(&lit(42), &schema); + let transformed = transform_expr(input.clone()); + assert_eq!(transformed.to_string(), expected.to_string()); + + // more complex case with unknown expression + let input = known_expression.and(input); + let expected = phys_expr::BinaryExpr::new( + known_expression_transformed.clone(), + Operator::And, + logical2physical(&lit(42), &schema), + ); + let transformed = transform_expr(input.clone()); + assert_eq!(transformed.to_string(), expected.to_string()); + } + #[test] fn test_rewrite_expr_to_prunable_error() { // cast string value to numeric value @@ -3868,12 +4049,7 @@ mod tests { required_columns: &mut RequiredColumns, ) -> Arc { let expr = logical2physical(expr, schema); - build_predicate_expression(&expr, schema, required_columns) - } - - fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { - let df_schema = schema.clone().to_dfschema().unwrap(); - let execution_props = ExecutionProps::new(); - create_physical_expr(expr, &df_schema, &execution_props).unwrap() + let unhandled_hook = Arc::new(ConstantUnhandledPredicateHook::default()) as _; + build_predicate_expression(&expr, schema, required_columns, &unhandled_hook) } } diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index 9b6e2076912c..930ce52e6fa2 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -33,6 +33,7 @@ use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::tree_node::PlanContext; use datafusion_physical_plan::ExecutionPlanProperties; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use itertools::izip; /// For a given `plan`, this object carries the information one needs from its @@ -131,7 +132,8 @@ fn plan_with_order_preserving_variants( if let Some(ordering) = child.output_ordering().map(Vec::from) { // When the input of a `CoalescePartitionsExec` has an ordering, // replace it with a `SortPreservingMergeExec` if appropriate: - let spm = SortPreservingMergeExec::new(ordering, child.clone()); + let spm = + SortPreservingMergeExec::new(LexOrdering::new(ordering), child.clone()); sort_input.plan = Arc::new(spm) as _; sort_input.children[0].data = true; return Ok(sort_input); @@ -255,7 +257,7 @@ pub(crate) fn replace_with_order_preserving_variants( if alternate_plan .plan .equivalence_properties() - .ordering_satisfy(requirements.plan.output_ordering().unwrap_or(&[])) + .ordering_satisfy(requirements.plan.output_ordering().unwrap_or_default()) { for child in alternate_plan.children.iter_mut() { child.data = false; @@ -291,7 +293,7 @@ mod tests { use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::tree_node::{TransformedResult, TreeNode}; - use datafusion_common::{Result, Statistics}; + use datafusion_common::Result; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_expr::{JoinType, Operator}; use datafusion_physical_expr::expressions::{self, col, Column}; @@ -426,14 +428,14 @@ mod tests { // Expected inputs unbounded and bounded let expected_input_unbounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; let expected_input_bounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", @@ -450,7 +452,7 @@ mod tests { // Expected bounded results with and without flag let expected_optimized_bounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", @@ -507,11 +509,11 @@ mod tests { // Expected inputs unbounded and bounded let expected_input_unbounded = [ "SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC]", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[a@0 ASC]", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", @@ -519,11 +521,11 @@ mod tests { ]; let expected_input_bounded = [ "SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC]", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[a@0 ASC]", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", @@ -545,11 +547,11 @@ mod tests { // Expected bounded results with and without flag let expected_optimized_bounded = [ "SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC]", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[a@0 ASC]", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", @@ -600,7 +602,7 @@ mod tests { // Expected inputs unbounded and bounded let expected_input_unbounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", @@ -608,7 +610,7 @@ mod tests { ]; let expected_input_bounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", @@ -627,7 +629,7 @@ mod tests { // Expected bounded results with and without flag let expected_optimized_bounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", @@ -676,7 +678,7 @@ mod tests { // Expected inputs unbounded and bounded let expected_input_unbounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", @@ -685,7 +687,7 @@ mod tests { ]; let expected_input_bounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", @@ -706,7 +708,7 @@ mod tests { // Expected bounded results with and without flag let expected_optimized_bounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", @@ -759,7 +761,7 @@ mod tests { // Expected inputs unbounded and bounded let expected_input_unbounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", @@ -769,7 +771,7 @@ mod tests { ]; let expected_input_bounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", @@ -792,7 +794,7 @@ mod tests { // Expected bounded results with and without flag let expected_optimized_bounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", @@ -917,7 +919,7 @@ mod tests { // Expected inputs unbounded and bounded let expected_input_unbounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@1 > 3", @@ -927,7 +929,7 @@ mod tests { ]; let expected_input_bounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@1 > 3", @@ -950,7 +952,7 @@ mod tests { // Expected bounded results with and without flag let expected_optimized_bounded = [ "SortPreservingMergeExec: [a@0 ASC NULLS LAST]", - " SortExec: expr=[a@0 ASC NULLS LAST]", + " SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " CoalesceBatchesExec: target_batch_size=8192", " FilterExec: c@1 > 3", @@ -1007,14 +1009,14 @@ mod tests { // Expected inputs unbounded and bounded let expected_input_unbounded = [ "SortPreservingMergeExec: [c@1 ASC]", - " SortExec: expr=[c@1 ASC]", + " SortExec: expr=[c@1 ASC], preserve_partitioning=[true]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; let expected_input_bounded = [ "SortPreservingMergeExec: [c@1 ASC]", - " SortExec: expr=[c@1 ASC]", + " SortExec: expr=[c@1 ASC], preserve_partitioning=[true]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", @@ -1023,7 +1025,7 @@ mod tests { // Expected unbounded result (same for with and without flag) let expected_optimized_unbounded = [ "SortPreservingMergeExec: [c@1 ASC]", - " SortExec: expr=[c@1 ASC]", + " SortExec: expr=[c@1 ASC], preserve_partitioning=[true]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", @@ -1032,7 +1034,7 @@ mod tests { // Expected bounded results same with and without flag, because ordering requirement of the executor is different than the existing ordering. let expected_optimized_bounded = [ "SortPreservingMergeExec: [c@1 ASC]", - " SortExec: expr=[c@1 ASC]", + " SortExec: expr=[c@1 ASC], preserve_partitioning=[true]", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " CsvExec: file_groups={1 group: [[file_path]]}, projection=[a, c, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true", @@ -1071,14 +1073,14 @@ mod tests { // Expected inputs unbounded and bounded let expected_input_unbounded = [ - "SortExec: expr=[a@0 ASC NULLS LAST]", + "SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false]", " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", " StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST]", ]; let expected_input_bounded = [ - "SortExec: expr=[a@0 ASC NULLS LAST]", + "SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false]", " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", @@ -1095,7 +1097,7 @@ mod tests { // Expected bounded results with and without flag let expected_optimized_bounded = [ - "SortExec: expr=[a@0 ASC NULLS LAST]", + "SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false]", " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", @@ -1153,11 +1155,11 @@ mod tests { // Expected inputs unbounded and bounded let expected_input_unbounded = [ "SortPreservingMergeExec: [c@1 ASC]", - " SortExec: expr=[c@1 ASC]", + " SortExec: expr=[c@1 ASC], preserve_partitioning=[true]", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[c@1 ASC]", + " SortExec: expr=[c@1 ASC], preserve_partitioning=[false]", " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", @@ -1165,11 +1167,11 @@ mod tests { ]; let expected_input_bounded = [ "SortPreservingMergeExec: [c@1 ASC]", - " SortExec: expr=[c@1 ASC]", + " SortExec: expr=[c@1 ASC], preserve_partitioning=[true]", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[c@1 ASC]", + " SortExec: expr=[c@1 ASC], preserve_partitioning=[false]", " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", @@ -1182,7 +1184,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=c@1 ASC", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[c@1 ASC]", + " SortExec: expr=[c@1 ASC], preserve_partitioning=[false]", " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", @@ -1192,11 +1194,11 @@ mod tests { // Expected bounded results with and without flag let expected_optimized_bounded = [ "SortPreservingMergeExec: [c@1 ASC]", - " SortExec: expr=[c@1 ASC]", + " SortExec: expr=[c@1 ASC], preserve_partitioning=[true]", " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[c@1 ASC]", + " SortExec: expr=[c@1 ASC], preserve_partitioning=[false]", " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", @@ -1207,7 +1209,7 @@ mod tests { " FilterExec: c@1 > 3", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8, preserve_order=true, sort_exprs=c@1 ASC", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", - " SortExec: expr=[c@1 ASC]", + " SortExec: expr=[c@1 ASC], preserve_partitioning=[false]", " CoalescePartitionsExec", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", @@ -1270,7 +1272,7 @@ mod tests { // Expected inputs unbounded and bounded let expected_input_unbounded = [ "SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC]", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", @@ -1283,7 +1285,7 @@ mod tests { ]; let expected_input_bounded = [ "SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC]", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", @@ -1298,7 +1300,7 @@ mod tests { // Expected unbounded result (same for with and without flag) let expected_optimized_unbounded = [ "SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC]", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", @@ -1314,7 +1316,7 @@ mod tests { // existing ordering. let expected_optimized_bounded = [ "SortPreservingMergeExec: [a@0 ASC]", - " SortExec: expr=[a@0 ASC]", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[true]", " HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c@1, c@1)]", " CoalesceBatchesExec: target_batch_size=4096", " RepartitionExec: partitioning=Hash([c@1], 8), input_partitions=8", @@ -1475,6 +1477,7 @@ mod tests { Some(&projection), vec![sort_exprs], true, + None, ) .unwrap(), ) @@ -1489,25 +1492,24 @@ mod tests { let sort_exprs = sort_exprs.into_iter().collect(); let projection: Vec = vec![0, 2, 3]; - Arc::new(CsvExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), - file_schema: schema.clone(), - file_groups: vec![vec![PartitionedFile::new( - "file_path".to_string(), - 100, - )]], - statistics: Statistics::new_unknown(schema), - projection: Some(projection), - limit: None, - table_partition_cols: vec![], - output_ordering: vec![sort_exprs], - }, - true, - 0, - b'"', - None, - FileCompressionType::UNCOMPRESSED, - )) + Arc::new( + CsvExec::builder( + FileScanConfig::new( + ObjectStoreUrl::parse("test:///").unwrap(), + schema.clone(), + ) + .with_file(PartitionedFile::new("file_path".to_string(), 100)) + .with_projection(Some(projection)) + .with_output_ordering(vec![sort_exprs]), + ) + .with_has_header(true) + .with_delimeter(0) + .with_quote(b'"') + .with_escape(None) + .with_comment(None) + .with_newlines_in_values(false) + .with_file_compression_type(FileCompressionType::UNCOMPRESSED) + .build(), + ) } } diff --git a/datafusion/core/src/physical_optimizer/sanity_checker.rs b/datafusion/core/src/physical_optimizer/sanity_checker.rs new file mode 100644 index 000000000000..4d2baf1fe1ab --- /dev/null +++ b/datafusion/core/src/physical_optimizer/sanity_checker.rs @@ -0,0 +1,671 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +//http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! The [SanityCheckPlan] rule ensures that a given plan can +//! accommodate its infinite sources, if there are any. It will reject +//! non-runnable query plans that use pipeline-breaking operators on +//! infinite input(s). In addition, it will check if all order and +//! distribution requirements of a plan are satisfied by its children. + +use std::sync::Arc; + +use crate::error::Result; +use crate::physical_plan::ExecutionPlan; + +use datafusion_common::config::{ConfigOptions, OptimizerOptions}; +use datafusion_common::plan_err; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_physical_expr::intervals::utils::{check_support, is_datatype_supported}; +use datafusion_physical_plan::joins::SymmetricHashJoinExec; +use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties}; + +use datafusion_physical_expr_common::sort_expr::format_physical_sort_requirement_list; +use datafusion_physical_optimizer::PhysicalOptimizerRule; +use itertools::izip; + +/// The SanityCheckPlan rule rejects the following query plans: +/// 1. Invalid plans containing nodes whose order and/or distribution requirements +/// are not satisfied by their children. +/// 2. Plans that use pipeline-breaking operators on infinite input(s), +/// it is impossible to execute such queries (they will never generate output nor finish) +#[derive(Default, Debug)] +pub struct SanityCheckPlan {} + +impl SanityCheckPlan { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl PhysicalOptimizerRule for SanityCheckPlan { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + plan.transform_up(|p| check_plan_sanity(p, &config.optimizer)) + .data() + } + + fn name(&self) -> &str { + "SanityCheckPlan" + } + + fn schema_check(&self) -> bool { + true + } +} + +/// This function propagates finiteness information and rejects any plan with +/// pipeline-breaking operators acting on infinite inputs. +pub fn check_finiteness_requirements( + input: Arc, + optimizer_options: &OptimizerOptions, +) -> Result>> { + if let Some(exec) = input.as_any().downcast_ref::() { + if !(optimizer_options.allow_symmetric_joins_without_pruning + || (exec.check_if_order_information_available()? && is_prunable(exec))) + { + return plan_err!("Join operation cannot operate on a non-prunable stream without enabling \ + the 'allow_symmetric_joins_without_pruning' configuration flag"); + } + } + if !input.execution_mode().pipeline_friendly() { + plan_err!( + "Cannot execute pipeline breaking queries, operator: {:?}", + input + ) + } else { + Ok(Transformed::no(input)) + } +} + +/// This function returns whether a given symmetric hash join is amenable to +/// data pruning. For this to be possible, it needs to have a filter where +/// all involved [`PhysicalExpr`]s, [`Operator`]s and data types support +/// interval calculations. +/// +/// [`PhysicalExpr`]: crate::physical_plan::PhysicalExpr +/// [`Operator`]: datafusion_expr::Operator +fn is_prunable(join: &SymmetricHashJoinExec) -> bool { + join.filter().map_or(false, |filter| { + check_support(filter.expression(), &join.schema()) + && filter + .schema() + .fields() + .iter() + .all(|f| is_datatype_supported(f.data_type())) + }) +} + +/// Ensures that the plan is pipeline friendly and the order and +/// distribution requirements from its children are satisfied. +pub fn check_plan_sanity( + plan: Arc, + optimizer_options: &OptimizerOptions, +) -> Result>> { + check_finiteness_requirements(plan.clone(), optimizer_options)?; + + for ((idx, child), sort_req, dist_req) in izip!( + plan.children().iter().enumerate(), + plan.required_input_ordering().iter(), + plan.required_input_distribution().iter() + ) { + let child_eq_props = child.equivalence_properties(); + if let Some(sort_req) = sort_req { + if !child_eq_props.ordering_satisfy_requirement(sort_req) { + let plan_str = get_plan_string(&plan); + return plan_err!( + "Plan: {:?} does not satisfy order requirements: {}. Child-{} order: {}", + plan_str, + format_physical_sort_requirement_list(sort_req), + idx, + child_eq_props.oeq_class + ); + } + } + + if !child + .output_partitioning() + .satisfy(dist_req, child_eq_props) + { + let plan_str = get_plan_string(&plan); + return plan_err!( + "Plan: {:?} does not satisfy distribution requirements: {}. Child-{} output partitioning: {}", + plan_str, + dist_req, + idx, + child.output_partitioning() + ); + } + } + + Ok(Transformed::no(plan)) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::physical_optimizer::test_utils::{ + bounded_window_exec, global_limit_exec, local_limit_exec, memory_exec, + repartition_exec, sort_exec, sort_expr_options, sort_merge_join_exec, + BinaryTestCase, QueryCase, SourceType, UnaryTestCase, + }; + + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion_common::Result; + use datafusion_expr::JoinType; + use datafusion_physical_expr::expressions::col; + use datafusion_physical_expr::Partitioning; + use datafusion_physical_plan::displayable; + use datafusion_physical_plan::repartition::RepartitionExec; + + fn create_test_schema() -> SchemaRef { + Arc::new(Schema::new(vec![Field::new("c9", DataType::Int32, true)])) + } + + fn create_test_schema2() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])) + } + + /// Check if sanity checker should accept or reject plans. + fn assert_sanity_check(plan: &Arc, is_sane: bool) { + let sanity_checker = SanityCheckPlan::new(); + let opts = ConfigOptions::default(); + assert_eq!( + sanity_checker.optimize(plan.clone(), &opts).is_ok(), + is_sane + ); + } + + /// Check if the plan we created is as expected by comparing the plan + /// formatted as a string. + fn assert_plan(plan: &dyn ExecutionPlan, expected_lines: Vec<&str>) { + let plan_str = displayable(plan).indent(true).to_string(); + let actual_lines: Vec<&str> = plan_str.trim().lines().collect(); + assert_eq!(actual_lines, expected_lines); + } + + #[tokio::test] + async fn test_hash_left_join_swap() -> Result<()> { + let test1 = BinaryTestCase { + source_types: (SourceType::Unbounded, SourceType::Bounded), + expect_fail: false, + }; + + let test2 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Unbounded), + expect_fail: true, + }; + let test3 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Bounded), + expect_fail: false, + }; + let case = QueryCase { + sql: "SELECT t2.c1 FROM left as t1 LEFT JOIN right as t2 ON t1.c1 = t2.c1" + .to_string(), + cases: vec![Arc::new(test1), Arc::new(test2), Arc::new(test3)], + error_operator: "operator: HashJoinExec".to_string(), + }; + + case.run().await?; + Ok(()) + } + + #[tokio::test] + async fn test_hash_right_join_swap() -> Result<()> { + let test1 = BinaryTestCase { + source_types: (SourceType::Unbounded, SourceType::Bounded), + expect_fail: true, + }; + let test2 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Unbounded), + expect_fail: false, + }; + let test3 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Bounded), + expect_fail: false, + }; + let case = QueryCase { + sql: "SELECT t2.c1 FROM left as t1 RIGHT JOIN right as t2 ON t1.c1 = t2.c1" + .to_string(), + cases: vec![Arc::new(test1), Arc::new(test2), Arc::new(test3)], + error_operator: "operator: HashJoinExec".to_string(), + }; + + case.run().await?; + Ok(()) + } + + #[tokio::test] + async fn test_hash_inner_join_swap() -> Result<()> { + let test1 = BinaryTestCase { + source_types: (SourceType::Unbounded, SourceType::Bounded), + expect_fail: false, + }; + let test2 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Unbounded), + expect_fail: false, + }; + let test3 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Bounded), + expect_fail: false, + }; + let case = QueryCase { + sql: "SELECT t2.c1 FROM left as t1 JOIN right as t2 ON t1.c1 = t2.c1" + .to_string(), + cases: vec![Arc::new(test1), Arc::new(test2), Arc::new(test3)], + error_operator: "Join Error".to_string(), + }; + + case.run().await?; + Ok(()) + } + + #[tokio::test] + async fn test_hash_full_outer_join_swap() -> Result<()> { + let test1 = BinaryTestCase { + source_types: (SourceType::Unbounded, SourceType::Bounded), + expect_fail: true, + }; + let test2 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Unbounded), + expect_fail: true, + }; + let test3 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Bounded), + expect_fail: false, + }; + let case = QueryCase { + sql: "SELECT t2.c1 FROM left as t1 FULL JOIN right as t2 ON t1.c1 = t2.c1" + .to_string(), + cases: vec![Arc::new(test1), Arc::new(test2), Arc::new(test3)], + error_operator: "operator: HashJoinExec".to_string(), + }; + + case.run().await?; + Ok(()) + } + + #[tokio::test] + async fn test_aggregate() -> Result<()> { + let test1 = UnaryTestCase { + source_type: SourceType::Bounded, + expect_fail: false, + }; + let test2 = UnaryTestCase { + source_type: SourceType::Unbounded, + expect_fail: true, + }; + let case = QueryCase { + sql: "SELECT c1, MIN(c4) FROM test GROUP BY c1".to_string(), + cases: vec![Arc::new(test1), Arc::new(test2)], + error_operator: "operator: AggregateExec".to_string(), + }; + + case.run().await?; + Ok(()) + } + + #[tokio::test] + async fn test_window_agg_hash_partition() -> Result<()> { + let test1 = UnaryTestCase { + source_type: SourceType::Bounded, + expect_fail: false, + }; + let test2 = UnaryTestCase { + source_type: SourceType::Unbounded, + expect_fail: true, + }; + let case = QueryCase { + sql: "SELECT + c9, + SUM(c9) OVER(PARTITION BY c1 ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) as sum1 + FROM test + LIMIT 5".to_string(), + cases: vec![Arc::new(test1), Arc::new(test2)], + error_operator: "operator: SortExec".to_string() + }; + + case.run().await?; + Ok(()) + } + + #[tokio::test] + async fn test_window_agg_single_partition() -> Result<()> { + let test1 = UnaryTestCase { + source_type: SourceType::Bounded, + expect_fail: false, + }; + let test2 = UnaryTestCase { + source_type: SourceType::Unbounded, + expect_fail: true, + }; + let case = QueryCase { + sql: "SELECT + c9, + SUM(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) as sum1 + FROM test".to_string(), + cases: vec![Arc::new(test1), Arc::new(test2)], + error_operator: "operator: SortExec".to_string() + }; + case.run().await?; + Ok(()) + } + + #[tokio::test] + async fn test_hash_cross_join() -> Result<()> { + let test1 = BinaryTestCase { + source_types: (SourceType::Unbounded, SourceType::Bounded), + expect_fail: true, + }; + let test2 = BinaryTestCase { + source_types: (SourceType::Unbounded, SourceType::Unbounded), + expect_fail: true, + }; + let test3 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Unbounded), + expect_fail: true, + }; + let test4 = BinaryTestCase { + source_types: (SourceType::Bounded, SourceType::Bounded), + expect_fail: false, + }; + let case = QueryCase { + sql: "SELECT t2.c1 FROM left as t1 CROSS JOIN right as t2".to_string(), + cases: vec![ + Arc::new(test1), + Arc::new(test2), + Arc::new(test3), + Arc::new(test4), + ], + error_operator: "operator: CrossJoinExec".to_string(), + }; + + case.run().await?; + Ok(()) + } + + #[tokio::test] + async fn test_analyzer() -> Result<()> { + let test1 = UnaryTestCase { + source_type: SourceType::Bounded, + expect_fail: false, + }; + let test2 = UnaryTestCase { + source_type: SourceType::Unbounded, + expect_fail: false, + }; + let case = QueryCase { + sql: "EXPLAIN ANALYZE SELECT * FROM test".to_string(), + cases: vec![Arc::new(test1), Arc::new(test2)], + error_operator: "Analyze Error".to_string(), + }; + + case.run().await?; + Ok(()) + } + + #[tokio::test] + /// Tests that plan is valid when the sort requirements are satisfied. + async fn test_bounded_window_agg_sort_requirement() -> Result<()> { + let schema = create_test_schema(); + let source = memory_exec(&schema); + let sort_exprs = vec![sort_expr_options( + "c9", + &source.schema(), + SortOptions { + descending: false, + nulls_first: false, + }, + )]; + let sort = sort_exec(sort_exprs.clone(), source); + let bw = bounded_window_exec("c9", sort_exprs, sort); + assert_plan(bw.as_ref(), vec![ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " SortExec: expr=[c9@0 ASC NULLS LAST], preserve_partitioning=[false]", + " MemoryExec: partitions=1, partition_sizes=[0]" + ]); + assert_sanity_check(&bw, true); + Ok(()) + } + + #[tokio::test] + /// Tests that plan is invalid when the sort requirements are not satisfied. + async fn test_bounded_window_agg_no_sort_requirement() -> Result<()> { + let schema = create_test_schema(); + let source = memory_exec(&schema); + let sort_exprs = vec![sort_expr_options( + "c9", + &source.schema(), + SortOptions { + descending: false, + nulls_first: false, + }, + )]; + let bw = bounded_window_exec("c9", sort_exprs, source); + assert_plan(bw.as_ref(), vec![ + "BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", + " MemoryExec: partitions=1, partition_sizes=[0]" + ]); + // Order requirement of the `BoundedWindowAggExec` is not satisfied. We expect to receive error during sanity check. + assert_sanity_check(&bw, false); + Ok(()) + } + + #[tokio::test] + /// A valid when a single partition requirement + /// is satisfied. + async fn test_global_limit_single_partition() -> Result<()> { + let schema = create_test_schema(); + let source = memory_exec(&schema); + let limit = global_limit_exec(source); + + assert_plan( + limit.as_ref(), + vec![ + "GlobalLimitExec: skip=0, fetch=100", + " MemoryExec: partitions=1, partition_sizes=[0]", + ], + ); + assert_sanity_check(&limit, true); + Ok(()) + } + + #[tokio::test] + /// An invalid plan when a single partition requirement + /// is not satisfied. + async fn test_global_limit_multi_partition() -> Result<()> { + let schema = create_test_schema(); + let source = memory_exec(&schema); + let limit = global_limit_exec(repartition_exec(source)); + + assert_plan( + limit.as_ref(), + vec![ + "GlobalLimitExec: skip=0, fetch=100", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " MemoryExec: partitions=1, partition_sizes=[0]", + ], + ); + // Distribution requirement of the `GlobalLimitExec` is not satisfied. We expect to receive error during sanity check. + assert_sanity_check(&limit, false); + Ok(()) + } + + #[tokio::test] + /// A plan with no requirements should satisfy. + async fn test_local_limit() -> Result<()> { + let schema = create_test_schema(); + let source = memory_exec(&schema); + let limit = local_limit_exec(source); + + assert_plan( + limit.as_ref(), + vec![ + "LocalLimitExec: fetch=100", + " MemoryExec: partitions=1, partition_sizes=[0]", + ], + ); + assert_sanity_check(&limit, true); + Ok(()) + } + + #[tokio::test] + /// Valid plan with multiple children satisfy both order and distribution. + async fn test_sort_merge_join_satisfied() -> Result<()> { + let schema1 = create_test_schema(); + let schema2 = create_test_schema2(); + let source1 = memory_exec(&schema1); + let source2 = memory_exec(&schema2); + let sort_opts = SortOptions::default(); + let sort_exprs1 = vec![sort_expr_options("c9", &source1.schema(), sort_opts)]; + let sort_exprs2 = vec![sort_expr_options("a", &source2.schema(), sort_opts)]; + let left = sort_exec(sort_exprs1, source1); + let right = sort_exec(sort_exprs2, source2); + let left_jcol = col("c9", &left.schema()).unwrap(); + let right_jcol = col("a", &right.schema()).unwrap(); + let left = Arc::new(RepartitionExec::try_new( + left, + Partitioning::Hash(vec![left_jcol.clone()], 10), + )?); + + let right = Arc::new(RepartitionExec::try_new( + right, + Partitioning::Hash(vec![right_jcol.clone()], 10), + )?); + + let join_on = vec![(left_jcol as _, right_jcol as _)]; + let join_ty = JoinType::Inner; + let smj = sort_merge_join_exec(left, right, &join_on, &join_ty); + + assert_plan( + smj.as_ref(), + vec![ + "SortMergeJoin: join_type=Inner, on=[(c9@0, a@0)]", + " RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1", + " SortExec: expr=[c9@0 ASC], preserve_partitioning=[false]", + " MemoryExec: partitions=1, partition_sizes=[0]", + " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " MemoryExec: partitions=1, partition_sizes=[0]", + ], + ); + assert_sanity_check(&smj, true); + Ok(()) + } + + #[tokio::test] + /// Invalid case when the order is not satisfied by the 2nd + /// child. + async fn test_sort_merge_join_order_missing() -> Result<()> { + let schema1 = create_test_schema(); + let schema2 = create_test_schema2(); + let source1 = memory_exec(&schema1); + let right = memory_exec(&schema2); + let sort_exprs1 = vec![sort_expr_options( + "c9", + &source1.schema(), + SortOptions::default(), + )]; + let left = sort_exec(sort_exprs1, source1); + // Missing sort of the right child here.. + let left_jcol = col("c9", &left.schema()).unwrap(); + let right_jcol = col("a", &right.schema()).unwrap(); + let left = Arc::new(RepartitionExec::try_new( + left, + Partitioning::Hash(vec![left_jcol.clone()], 10), + )?); + + let right = Arc::new(RepartitionExec::try_new( + right, + Partitioning::Hash(vec![right_jcol.clone()], 10), + )?); + + let join_on = vec![(left_jcol as _, right_jcol as _)]; + let join_ty = JoinType::Inner; + let smj = sort_merge_join_exec(left, right, &join_on, &join_ty); + + assert_plan( + smj.as_ref(), + vec![ + "SortMergeJoin: join_type=Inner, on=[(c9@0, a@0)]", + " RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1", + " SortExec: expr=[c9@0 ASC], preserve_partitioning=[false]", + " MemoryExec: partitions=1, partition_sizes=[0]", + " RepartitionExec: partitioning=Hash([a@0], 10), input_partitions=1", + " MemoryExec: partitions=1, partition_sizes=[0]", + ], + ); + // Order requirement for the `SortMergeJoin` is not satisfied for right child. We expect to receive error during sanity check. + assert_sanity_check(&smj, false); + Ok(()) + } + + #[tokio::test] + /// Invalid case when the distribution is not satisfied by the 2nd + /// child. + async fn test_sort_merge_join_dist_missing() -> Result<()> { + let schema1 = create_test_schema(); + let schema2 = create_test_schema2(); + let source1 = memory_exec(&schema1); + let source2 = memory_exec(&schema2); + let sort_opts = SortOptions::default(); + let sort_exprs1 = vec![sort_expr_options("c9", &source1.schema(), sort_opts)]; + let sort_exprs2 = vec![sort_expr_options("a", &source2.schema(), sort_opts)]; + let left = sort_exec(sort_exprs1, source1); + let right = sort_exec(sort_exprs2, source2); + let right = Arc::new(RepartitionExec::try_new( + right, + Partitioning::RoundRobinBatch(10), + )?); + let left_jcol = col("c9", &left.schema()).unwrap(); + let right_jcol = col("a", &right.schema()).unwrap(); + let left = Arc::new(RepartitionExec::try_new( + left, + Partitioning::Hash(vec![left_jcol.clone()], 10), + )?); + + // Missing hash partitioning on right child. + + let join_on = vec![(left_jcol as _, right_jcol as _)]; + let join_ty = JoinType::Inner; + let smj = sort_merge_join_exec(left, right, &join_on, &join_ty); + + assert_plan( + smj.as_ref(), + vec![ + "SortMergeJoin: join_type=Inner, on=[(c9@0, a@0)]", + " RepartitionExec: partitioning=Hash([c9@0], 10), input_partitions=1", + " SortExec: expr=[c9@0 ASC], preserve_partitioning=[false]", + " MemoryExec: partitions=1, partition_sizes=[0]", + " RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + " SortExec: expr=[a@0 ASC], preserve_partitioning=[false]", + " MemoryExec: partitions=1, partition_sizes=[0]", + ], + ); + // Distribution requirement for the `SortMergeJoin` is not satisfied for right child (has round-robin partitioning). We expect to receive error during sanity check. + assert_sanity_check(&smj, false); + Ok(()) + } +} diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index c527819e7746..9eb200f534db 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -15,64 +15,94 @@ // specific language governing permissions and limitations // under the License. +use std::fmt::Debug; use std::sync::Arc; -use super::utils::add_sort_above; -use crate::physical_optimizer::utils::{ - is_limit, is_sort_preserving_merge, is_union, is_window, -}; +use super::utils::{add_sort_above, is_sort}; +use crate::physical_optimizer::utils::{is_sort_preserving_merge, is_union, is_window}; use crate::physical_plan::filter::FilterExec; use crate::physical_plan::joins::utils::calculate_join_output_ordering; -use crate::physical_plan::joins::{HashJoinExec, SortMergeJoinExec}; +use crate::physical_plan::joins::SortMergeJoinExec; use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::tree_node::PlanContext; use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; -use datafusion_common::tree_node::Transformed; +use datafusion_common::tree_node::{ + ConcreteTreeNode, Transformed, TreeNode, TreeNodeRecursion, +}; use datafusion_common::{plan_err, JoinSide, Result}; use datafusion_expr::JoinType; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::{ - LexRequirementRef, PhysicalSortExpr, PhysicalSortRequirement, +use datafusion_physical_expr::utils::collect_columns; +use datafusion_physical_expr::{LexRequirementRef, PhysicalSortRequirement}; +use datafusion_physical_expr_common::sort_expr::{ + LexOrdering, LexOrderingRef, LexRequirement, }; +use hashbrown::HashSet; + /// This is a "data class" we use within the [`EnforceSorting`] rule to push /// down [`SortExec`] in the plan. In some cases, we can reduce the total /// computational cost by pushing down `SortExec`s through some executors. The -/// object carries the parent required ordering as its data. +/// object carries the parent required ordering and the (optional) `fetch` value +/// of the parent node as its data. /// /// [`EnforceSorting`]: crate::physical_optimizer::enforce_sorting::EnforceSorting -pub type SortPushDown = PlanContext>>; +#[derive(Default, Clone)] +pub struct ParentRequirements { + ordering_requirement: Option, + fetch: Option, +} + +pub type SortPushDown = PlanContext; /// Assigns the ordering requirement of the root node to the its children. pub fn assign_initial_requirements(node: &mut SortPushDown) { let reqs = node.plan.required_input_ordering(); for (child, requirement) in node.children.iter_mut().zip(reqs) { - child.data = requirement; + child.data = ParentRequirements { + ordering_requirement: requirement, + fetch: None, + }; } } -pub(crate) fn pushdown_sorts( +pub(crate) fn pushdown_sorts(sort_pushdown: SortPushDown) -> Result { + let mut new_node = pushdown_sorts_helper(sort_pushdown)?; + while new_node.tnr == TreeNodeRecursion::Stop { + new_node = pushdown_sorts_helper(new_node.data)?; + } + let (new_node, children) = new_node.data.take_children(); + let new_children = children + .into_iter() + .map(pushdown_sorts) + .collect::>()?; + new_node.with_new_children(new_children) +} + +fn pushdown_sorts_helper( mut requirements: SortPushDown, ) -> Result> { let plan = &requirements.plan; - let parent_reqs = requirements.data.as_deref().unwrap_or(&[]); + let parent_reqs = requirements + .data + .ordering_requirement + .as_deref() + .unwrap_or(&[]); let satisfy_parent = plan .equivalence_properties() .ordering_satisfy_requirement(parent_reqs); - - if let Some(sort_exec) = plan.as_any().downcast_ref::() { + if is_sort(plan) { let required_ordering = plan .output_ordering() .map(PhysicalSortRequirement::from_sort_exprs) .unwrap_or_default(); - if !satisfy_parent { // Make sure this `SortExec` satisfies parent requirements: - let fetch = sort_exec.fetch(); - let sort_reqs = requirements.data.unwrap_or_default(); + let sort_reqs = requirements.data.ordering_requirement.unwrap_or_default(); + let fetch = requirements.data.fetch; requirements = requirements.children.swap_remove(0); requirements = add_sort_above(requirements, sort_reqs, fetch); }; @@ -82,12 +112,24 @@ pub(crate) fn pushdown_sorts( if let Some(adjusted) = pushdown_requirement_to_children(&child.plan, &required_ordering)? { + let fetch = child.plan.fetch(); for (grand_child, order) in child.children.iter_mut().zip(adjusted) { - grand_child.data = order; + grand_child.data = ParentRequirements { + ordering_requirement: order, + fetch, + }; } // Can push down requirements - child.data = None; - return Ok(Transformed::yes(child)); + child.data = ParentRequirements { + ordering_requirement: Some(required_ordering), + fetch, + }; + + return Ok(Transformed { + data: child, + transformed: true, + tnr: TreeNodeRecursion::Stop, + }); } else { // Can not push down requirements requirements.children = vec![child]; @@ -97,19 +139,24 @@ pub(crate) fn pushdown_sorts( // For non-sort operators, immediately return if parent requirements are met: let reqs = plan.required_input_ordering(); for (child, order) in requirements.children.iter_mut().zip(reqs) { - child.data = order; + child.data.ordering_requirement = order; } } else if let Some(adjusted) = pushdown_requirement_to_children(plan, parent_reqs)? { // Can not satisfy the parent requirements, check whether we can push // requirements down: for (child, order) in requirements.children.iter_mut().zip(adjusted) { - child.data = order; + child.data.ordering_requirement = order; } - requirements.data = None; + requirements.data.ordering_requirement = None; } else { // Can not push down requirements, add new `SortExec`: - let sort_reqs = requirements.data.clone().unwrap_or_default(); - requirements = add_sort_above(requirements, sort_reqs, None); + let sort_reqs = requirements + .data + .ordering_requirement + .clone() + .unwrap_or_default(); + let fetch = requirements.data.fetch; + requirements = add_sort_above(requirements, sort_reqs, fetch); assign_initial_requirements(&mut requirements); } Ok(Transformed::yes(requirements)) @@ -118,7 +165,7 @@ pub(crate) fn pushdown_sorts( fn pushdown_requirement_to_children( plan: &Arc, parent_required: LexRequirementRef, -) -> Result>>>> { +) -> Result>>> { let maintains_input_order = plan.maintains_input_order(); if is_window(plan) { let required_input_ordering = plan.required_input_ordering(); @@ -126,16 +173,58 @@ fn pushdown_requirement_to_children( let child_plan = plan.children().swap_remove(0); match determine_children_requirement(parent_required, request_child, child_plan) { RequirementsCompatibility::Satisfy => { - let req = (!request_child.is_empty()).then(|| request_child.to_vec()); + let req = (!request_child.is_empty()) + .then(|| LexRequirement::new(request_child.to_vec())); Ok(Some(vec![req])) } RequirementsCompatibility::Compatible(adjusted) => Ok(Some(vec![adjusted])), RequirementsCompatibility::NonCompatible => Ok(None), } + } else if let Some(sort_exec) = plan.as_any().downcast_ref::() { + let sort_req = PhysicalSortRequirement::from_sort_exprs( + sort_exec.properties().output_ordering().unwrap_or(&[]), + ); + if sort_exec + .properties() + .eq_properties + .requirements_compatible(parent_required, &sort_req) + { + debug_assert!(!parent_required.is_empty()); + Ok(Some(vec![Some(LexRequirement::new( + parent_required.to_vec(), + ))])) + } else { + Ok(None) + } + } else if plan.fetch().is_some() + && plan.supports_limit_pushdown() + && plan + .maintains_input_order() + .iter() + .all(|maintain| *maintain) + { + let output_req = PhysicalSortRequirement::from_sort_exprs( + plan.properties().output_ordering().unwrap_or(&[]), + ); + // Push down through operator with fetch when: + // - requirement is aligned with output ordering + // - it preserves ordering during execution + if plan + .properties() + .eq_properties + .requirements_compatible(parent_required, &output_req) + { + let req = (!parent_required.is_empty()) + .then(|| LexRequirement::new(parent_required.to_vec())); + Ok(Some(vec![req])) + } else { + Ok(None) + } } else if is_union(plan) { // UnionExec does not have real sort requirements for its input. Here we change the adjusted_request_ordering to UnionExec's output ordering and // propagate the sort requirements down to correct the unnecessary descendant SortExec under the UnionExec - let req = (!parent_required.is_empty()).then(|| parent_required.to_vec()); + let req = (!parent_required.is_empty()) + .then(|| LexRequirement::new(parent_required.to_vec())); Ok(Some(vec![req; plan.children().len()])) } else if let Some(smj) = plan.as_any().downcast_ref::() { // If the current plan is SortMergeJoinExec @@ -146,7 +235,7 @@ fn pushdown_requirement_to_children( Some(JoinSide::Left) => try_pushdown_requirements_to_join( smj, parent_required, - parent_required_expr, + parent_required_expr.as_ref(), JoinSide::Left, ), Some(JoinSide::Right) => { @@ -159,7 +248,7 @@ fn pushdown_requirement_to_children( try_pushdown_requirements_to_join( smj, parent_required, - new_right_required_expr, + new_right_required_expr.as_ref(), JoinSide::Right, ) } @@ -174,8 +263,7 @@ fn pushdown_requirement_to_children( || plan.as_any().is::() // TODO: Add support for Projection push down || plan.as_any().is::() - || is_limit(plan) - || plan.as_any().is::() + || pushdown_would_violate_requirements(parent_required, plan.as_ref()) { // If the current plan is a leaf node or can not maintain any of the input ordering, can not pushed down requirements. // For RepartitionExec, we always choose to not push down the sort requirements even the RepartitionExec(input_partition=1) could maintain input ordering. @@ -189,28 +277,44 @@ fn pushdown_requirement_to_children( spm_eqs = spm_eqs.with_reorder(new_ordering); // Do not push-down through SortPreservingMergeExec when // ordering requirement invalidates requirement of sort preserving merge exec. - if !spm_eqs.ordering_satisfy(plan.output_ordering().unwrap_or(&[])) { + if !spm_eqs.ordering_satisfy(plan.output_ordering().unwrap_or_default()) { Ok(None) } else { // Can push-down through SortPreservingMergeExec, because parent requirement is finer // than SortPreservingMergeExec output ordering. - let req = (!parent_required.is_empty()).then(|| parent_required.to_vec()); + let req = (!parent_required.is_empty()) + .then(|| LexRequirement::new(parent_required.to_vec())); Ok(Some(vec![req])) } } else { - Ok(Some( - maintains_input_order - .into_iter() - .map(|flag| { - (flag && !parent_required.is_empty()) - .then(|| parent_required.to_vec()) - }) - .collect(), - )) + handle_custom_pushdown(plan, parent_required, maintains_input_order) } // TODO: Add support for Projection push down } +/// Return true if pushing the sort requirements through a node would violate +/// the input sorting requirements for the plan +fn pushdown_would_violate_requirements( + parent_required: LexRequirementRef, + child: &dyn ExecutionPlan, +) -> bool { + child + .required_input_ordering() + .iter() + .any(|child_required| { + let Some(child_required) = child_required.as_ref() else { + // no requirements, so pushing down would not violate anything + return false; + }; + // check if the plan's requirements would still e satisfied if we pushed + // down the parent requirements + child_required + .iter() + .zip(parent_required.iter()) + .all(|(c, p)| !c.compatible(p)) + }) +} + /// Determine children requirements: /// - If children requirements are more specific, do not push down parent /// requirements. @@ -219,7 +323,7 @@ fn pushdown_requirement_to_children( fn determine_children_requirement( parent_required: LexRequirementRef, request_child: LexRequirementRef, - child_plan: Arc, + child_plan: &Arc, ) -> RequirementsCompatibility { if child_plan .equivalence_properties() @@ -233,23 +337,55 @@ fn determine_children_requirement( { // Parent requirements are more specific, adjust child's requirements // and push down the new requirements: - let adjusted = (!parent_required.is_empty()).then(|| parent_required.to_vec()); + let adjusted = (!parent_required.is_empty()) + .then(|| LexRequirement::new(parent_required.to_vec())); RequirementsCompatibility::Compatible(adjusted) } else { RequirementsCompatibility::NonCompatible } } + fn try_pushdown_requirements_to_join( smj: &SortMergeJoinExec, parent_required: LexRequirementRef, - sort_expr: Vec, + sort_expr: LexOrderingRef, push_side: JoinSide, -) -> Result>>>> { - let left_ordering = smj.left().output_ordering().unwrap_or(&[]); - let right_ordering = smj.right().output_ordering().unwrap_or(&[]); +) -> Result>>> { + let left_eq_properties = smj.left().equivalence_properties(); + let right_eq_properties = smj.right().equivalence_properties(); + let mut smj_required_orderings = smj.required_input_ordering(); + let right_requirement = smj_required_orderings.swap_remove(1); + let left_requirement = smj_required_orderings.swap_remove(0); + let left_ordering = smj.left().output_ordering().unwrap_or_default(); + let right_ordering = smj.right().output_ordering().unwrap_or_default(); let (new_left_ordering, new_right_ordering) = match push_side { - JoinSide::Left => (sort_expr.as_slice(), right_ordering), - JoinSide::Right => (left_ordering, sort_expr.as_slice()), + JoinSide::Left => { + let left_eq_properties = left_eq_properties + .clone() + .with_reorder(LexOrdering::from_ref(sort_expr)); + if left_eq_properties + .ordering_satisfy_requirement(&left_requirement.unwrap_or_default()) + { + // After re-ordering requirement is still satisfied + (sort_expr, right_ordering) + } else { + return Ok(None); + } + } + JoinSide::Right => { + let right_eq_properties = right_eq_properties + .clone() + .with_reorder(LexOrdering::from_ref(sort_expr)); + if right_eq_properties + .ordering_satisfy_requirement(&right_requirement.unwrap_or_default()) + { + // After re-ordering requirement is still satisfied + (left_ordering, sort_expr) + } else { + return Ok(None); + } + } + JoinSide::None => return Ok(None), }; let join_type = smj.join_type(); let probe_side = SortMergeJoinExec::probe_side(&join_type); @@ -268,7 +404,7 @@ fn try_pushdown_requirements_to_join( let should_pushdown = smj_eqs.ordering_satisfy_requirement(parent_required); Ok(should_pushdown.then(|| { let mut required_input_ordering = smj.required_input_ordering(); - let new_req = Some(PhysicalSortRequirement::from_sort_exprs(&sort_expr)); + let new_req = Some(PhysicalSortRequirement::from_sort_exprs(sort_expr)); match push_side { JoinSide::Left => { required_input_ordering[0] = new_req; @@ -276,18 +412,23 @@ fn try_pushdown_requirements_to_join( JoinSide::Right => { required_input_ordering[1] = new_req; } + JoinSide::None => unreachable!(), } required_input_ordering })) } fn expr_source_side( - required_exprs: &[PhysicalSortExpr], + required_exprs: LexOrderingRef, join_type: JoinType, left_columns_len: usize, ) -> Option { match join_type { - JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { + JoinType::Inner + | JoinType::Left + | JoinType::Right + | JoinType::Full + | JoinType::LeftMark => { let all_column_sides = required_exprs .iter() .filter_map(|r| { @@ -332,7 +473,7 @@ fn expr_source_side( fn shift_right_required( parent_required: LexRequirementRef, left_columns_len: usize, -) -> Result> { +) -> Result { let new_right_required = parent_required .iter() .filter_map(|r| { @@ -344,7 +485,7 @@ fn shift_right_required( }) .collect::>(); if new_right_required.len() == parent_required.len() { - Ok(new_right_required) + Ok(LexRequirement::new(new_right_required)) } else { plan_err!( "Expect to shift all the parent required column indexes for SortMergeJoin" @@ -352,13 +493,121 @@ fn shift_right_required( } } +/// Handles the custom pushdown of parent-required sorting requirements down to +/// the child execution plans, considering whether the input order is maintained. +/// +/// # Arguments +/// +/// * `plan` - A reference to an `ExecutionPlan` for which the pushdown will be applied. +/// * `parent_required` - The sorting requirements expected by the parent node. +/// * `maintains_input_order` - A vector of booleans indicating whether each child +/// maintains the input order. +/// +/// # Returns +/// +/// Returns `Ok(Some(Vec>))` if the sorting requirements can be +/// pushed down, `Ok(None)` if not. On error, returns a `Result::Err`. +fn handle_custom_pushdown( + plan: &Arc, + parent_required: LexRequirementRef, + maintains_input_order: Vec, +) -> Result>>> { + // If there's no requirement from the parent or the plan has no children, return early + if parent_required.is_empty() || plan.children().is_empty() { + return Ok(None); + } + + // Collect all unique column indices used in the parent-required sorting expression + let all_indices: HashSet = parent_required + .iter() + .flat_map(|order| { + collect_columns(&order.expr) + .iter() + .map(|col| col.index()) + .collect::>() + }) + .collect(); + + // Get the number of fields in each child's schema + let len_of_child_schemas: Vec = plan + .children() + .iter() + .map(|c| c.schema().fields().len()) + .collect(); + + // Find the index of the child that maintains input order + let Some(maintained_child_idx) = maintains_input_order + .iter() + .enumerate() + .find(|(_, m)| **m) + .map(|pair| pair.0) + else { + return Ok(None); + }; + + // Check if all required columns come from the child that maintains input order + let start_idx = len_of_child_schemas[..maintained_child_idx] + .iter() + .sum::(); + let end_idx = start_idx + len_of_child_schemas[maintained_child_idx]; + let all_from_maintained_child = + all_indices.iter().all(|i| i >= &start_idx && i < &end_idx); + + // If all columns are from the maintained child, update the parent requirements + if all_from_maintained_child { + let sub_offset = len_of_child_schemas + .iter() + .take(maintained_child_idx) + .sum::(); + // Transform the parent-required expression for the child schema by adjusting columns + let updated_parent_req = parent_required + .iter() + .map(|req| { + let child_schema = plan.children()[maintained_child_idx].schema(); + let updated_columns = req + .expr + .clone() + .transform_up(|expr| { + if let Some(col) = expr.as_any().downcast_ref::() { + let new_index = col.index() - sub_offset; + Ok(Transformed::yes(Arc::new(Column::new( + child_schema.field(new_index).name(), + new_index, + )))) + } else { + Ok(Transformed::no(expr)) + } + })? + .data; + Ok(PhysicalSortRequirement::new(updated_columns, req.options)) + }) + .collect::>>()?; + + // Prepare the result, populating with the updated requirements for children that maintain order + let result = maintains_input_order + .iter() + .map(|&maintains_order| { + if maintains_order { + Some(LexRequirement::new(updated_parent_req.clone())) + } else { + None + } + }) + .collect(); + + Ok(Some(result)) + } else { + Ok(None) + } +} + /// Define the Requirements Compatibility #[derive(Debug)] enum RequirementsCompatibility { /// Requirements satisfy Satisfy, /// Requirements compatible - Compatible(Option>), + Compatible(Option), /// Requirements not compatible NonCompatible, } diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 7bc1eeb7c4a5..bdf16300ea87 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -17,11 +17,13 @@ //! Collection of testing utility functions that are leveraged by the query optimizer rules +use std::any::Any; +use std::fmt::Formatter; use std::sync::Arc; use crate::datasource::listing::PartitionedFile; use crate::datasource::physical_plan::{FileScanConfig, ParquetExec}; -use crate::datasource::stream::{StreamConfig, StreamTable}; +use crate::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; use crate::error::Result; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::physical_plan::coalesce_batches::CoalesceBatchesExec; @@ -41,15 +43,22 @@ use crate::prelude::{CsvReadOptions, SessionContext}; use arrow_schema::{Schema, SchemaRef, SortOptions}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{JoinType, Statistics}; +use datafusion_common::JoinType; use datafusion_execution::object_store::ObjectStoreUrl; -use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunctionDefinition}; +use datafusion_expr::{WindowFrame, WindowFunctionDefinition}; +use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; -use datafusion_physical_plan::displayable; use datafusion_physical_plan::tree_node::PlanContext; +use datafusion_physical_plan::{ + displayable, DisplayAs, DisplayFormatType, PlanProperties, +}; use async_trait::async_trait; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_physical_expr_common::sort_expr::{ + LexOrdering, LexRequirement, PhysicalSortRequirement, +}; async fn register_current_csv( ctx: &SessionContext, @@ -62,7 +71,8 @@ async fn register_current_csv( match infinite { true => { - let config = StreamConfig::new_file(schema, path.into()); + let source = FileStreamProvider::new_file(schema, path.into()); + let config = StreamConfig::new(Arc::new(source)); ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?; } false => { @@ -233,17 +243,17 @@ pub fn bounded_window_exec( sort_exprs: impl IntoIterator, input: Arc, ) -> Arc { - let sort_exprs: Vec<_> = sort_exprs.into_iter().collect(); + let sort_exprs: LexOrdering = sort_exprs.into_iter().collect(); let schema = input.schema(); Arc::new( crate::physical_plan::windows::BoundedWindowAggExec::try_new( vec![create_window_expr( - &WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + &WindowFunctionDefinition::AggregateUDF(count_udaf()), "count".to_owned(), &[col(col_name, &schema).unwrap()], &[], - &sort_exprs, + sort_exprs.as_ref(), Arc::new(WindowFrame::new(Some(false))), schema.as_ref(), false, @@ -274,21 +284,11 @@ pub fn sort_preserving_merge_exec( /// Create a non sorted parquet exec pub fn parquet_exec(schema: &SchemaRef) -> Arc { - Arc::new(ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), - file_schema: schema.clone(), - file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], - statistics: Statistics::new_unknown(schema), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - }, - None, - None, - Default::default(), - )) + ParquetExec::builder( + FileScanConfig::new(ObjectStoreUrl::parse("test:///").unwrap(), schema.clone()) + .with_file(PartitionedFile::new("x".to_string(), 100)), + ) + .build_arc() } // Created a sorted parquet exec @@ -298,21 +298,12 @@ pub fn parquet_exec_sorted( ) -> Arc { let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new(ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), - file_schema: schema.clone(), - file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], - statistics: Statistics::new_unknown(schema), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![sort_exprs], - }, - None, - None, - Default::default(), - )) + ParquetExec::builder( + FileScanConfig::new(ObjectStoreUrl::parse("test:///").unwrap(), schema.clone()) + .with_file(PartitionedFile::new("x".to_string(), 100)) + .with_output_ordering(vec![sort_exprs]), + ) + .build_arc() } pub fn union_exec(input: Vec>) -> Arc { @@ -370,6 +361,98 @@ pub fn sort_exec( Arc::new(SortExec::new(sort_exprs, input)) } +/// A test [`ExecutionPlan`] whose requirements can be configured. +#[derive(Debug)] +pub struct RequirementsTestExec { + required_input_ordering: LexOrdering, + maintains_input_order: bool, + input: Arc, +} + +impl RequirementsTestExec { + pub fn new(input: Arc) -> Self { + Self { + required_input_ordering: LexOrdering::default(), + maintains_input_order: true, + input, + } + } + + /// sets the required input ordering + pub fn with_required_input_ordering( + mut self, + required_input_ordering: LexOrdering, + ) -> Self { + self.required_input_ordering = required_input_ordering; + self + } + + /// set the maintains_input_order flag + pub fn with_maintains_input_order(mut self, maintains_input_order: bool) -> Self { + self.maintains_input_order = maintains_input_order; + self + } + + /// returns this ExecutionPlan as an Arc + pub fn into_arc(self) -> Arc { + Arc::new(self) + } +} + +impl DisplayAs for RequirementsTestExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + write!(f, "RequiredInputOrderingExec") + } +} + +impl ExecutionPlan for RequirementsTestExec { + fn name(&self) -> &str { + "RequiredInputOrderingExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + self.input.properties() + } + + fn required_input_ordering(&self) -> Vec> { + let requirement = PhysicalSortRequirement::from_sort_exprs( + self.required_input_ordering.as_ref().iter(), + ); + vec![Some(requirement)] + } + + fn maintains_input_order(&self) -> Vec { + vec![self.maintains_input_order] + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + assert_eq!(children.len(), 1); + Ok(RequirementsTestExec::new(children[0].clone()) + .with_required_input_ordering(self.required_input_ordering.clone()) + .with_maintains_input_order(self.maintains_input_order) + .into_arc()) + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + unimplemented!("Test exec does not support execution") + } +} + /// A [`PlanContext`] object is susceptible to being left in an inconsistent state after /// untested mutable operations. It is crucial that there be no discrepancies between a plan /// associated with the root node and the plan generated after traversing all nodes diff --git a/datafusion/core/src/physical_optimizer/update_aggr_exprs.rs b/datafusion/core/src/physical_optimizer/update_aggr_exprs.rs new file mode 100644 index 000000000000..d85278556cc4 --- /dev/null +++ b/datafusion/core/src/physical_optimizer/update_aggr_exprs.rs @@ -0,0 +1,185 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! An optimizer rule that checks ordering requirements of aggregate expressions +//! and modifies the expressions to work more efficiently if possible. + +use std::sync::Arc; + +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::{plan_datafusion_err, Result}; +use datafusion_physical_expr::aggregate::AggregateFunctionExpr; +use datafusion_physical_expr::{ + reverse_order_bys, EquivalenceProperties, PhysicalSortRequirement, +}; +use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_plan::aggregates::concat_slices; +use datafusion_physical_plan::windows::get_ordered_partition_by_indices; +use datafusion_physical_plan::{ + aggregates::AggregateExec, ExecutionPlan, ExecutionPlanProperties, +}; + +/// This optimizer rule checks ordering requirements of aggregate expressions. +/// +/// There are 3 kinds of aggregators in terms of ordering requirements: +/// - `AggregateOrderSensitivity::Insensitive`, meaning that ordering is not +/// important. +/// - `AggregateOrderSensitivity::HardRequirement`, meaning that the aggregator +/// requires a specific ordering. +/// - `AggregateOrderSensitivity::Beneficial`, meaning that the aggregator can +/// handle unordered input, but can run more efficiently if its input conforms +/// to a specific ordering. +/// +/// This rule analyzes aggregate expressions of type `Beneficial` to see whether +/// their input ordering requirements are satisfied. If this is the case, the +/// aggregators are modified to run in a more efficient mode. +#[derive(Default, Debug)] +pub struct OptimizeAggregateOrder {} + +impl OptimizeAggregateOrder { + #[allow(missing_docs)] + pub fn new() -> Self { + Self::default() + } +} + +impl PhysicalOptimizerRule for OptimizeAggregateOrder { + fn optimize( + &self, + plan: Arc, + _config: &ConfigOptions, + ) -> Result> { + plan.transform_up(|plan| { + if let Some(aggr_exec) = plan.as_any().downcast_ref::() { + // Final stage implementations do not rely on ordering -- those + // ordering fields may be pruned out by first stage aggregates. + // Hence, necessary information for proper merge is added during + // the first stage to the state field, which the final stage uses. + if !aggr_exec.mode().is_first_stage() { + return Ok(Transformed::no(plan)); + } + let input = aggr_exec.input(); + let mut aggr_expr = aggr_exec.aggr_expr().to_vec(); + + let groupby_exprs = aggr_exec.group_expr().input_exprs(); + // If the existing ordering satisfies a prefix of the GROUP BY + // expressions, prefix requirements with this section. In this + // case, aggregation will work more efficiently. + let indices = get_ordered_partition_by_indices(&groupby_exprs, input); + let requirement = indices + .iter() + .map(|&idx| { + PhysicalSortRequirement::new(groupby_exprs[idx].clone(), None) + }) + .collect::>(); + + aggr_expr = try_convert_aggregate_if_better( + aggr_expr, + &requirement, + input.equivalence_properties(), + )?; + + let aggr_exec = aggr_exec.with_new_aggr_exprs(aggr_expr); + + Ok(Transformed::yes(Arc::new(aggr_exec) as _)) + } else { + Ok(Transformed::no(plan)) + } + }) + .data() + } + + fn name(&self) -> &str { + "OptimizeAggregateOrder" + } + + fn schema_check(&self) -> bool { + true + } +} + +/// Tries to convert each aggregate expression to a potentially more efficient +/// version. +/// +/// # Parameters +/// +/// * `aggr_exprs` - A vector of `AggregateFunctionExpr` representing the +/// aggregate expressions to be optimized. +/// * `prefix_requirement` - An array slice representing the ordering +/// requirements preceding the aggregate expressions. +/// * `eq_properties` - A reference to the `EquivalenceProperties` object +/// containing ordering information. +/// +/// # Returns +/// +/// Returns `Ok(converted_aggr_exprs)` if the conversion process completes +/// successfully. Any errors occurring during the conversion process are +/// passed through. +fn try_convert_aggregate_if_better( + aggr_exprs: Vec>, + prefix_requirement: &[PhysicalSortRequirement], + eq_properties: &EquivalenceProperties, +) -> Result>> { + aggr_exprs + .into_iter() + .map(|aggr_expr| { + let aggr_sort_exprs = &aggr_expr.order_bys().unwrap_or_default(); + let reverse_aggr_sort_exprs = reverse_order_bys(aggr_sort_exprs); + let aggr_sort_reqs = + PhysicalSortRequirement::from_sort_exprs(aggr_sort_exprs.iter()); + let reverse_aggr_req = + PhysicalSortRequirement::from_sort_exprs(&reverse_aggr_sort_exprs.inner); + + // If the aggregate expression benefits from input ordering, and + // there is an actual ordering enabling this, try to update the + // aggregate expression to benefit from the existing ordering. + // Otherwise, leave it as is. + if aggr_expr.order_sensitivity().is_beneficial() && !aggr_sort_reqs.is_empty() + { + let reqs = concat_slices(prefix_requirement, &aggr_sort_reqs); + if eq_properties.ordering_satisfy_requirement(&reqs) { + // Existing ordering satisfies the aggregator requirements: + aggr_expr.with_beneficial_ordering(true)?.map(Arc::new) + } else if eq_properties.ordering_satisfy_requirement(&concat_slices( + prefix_requirement, + &reverse_aggr_req, + )) { + // Converting to reverse enables more efficient execution + // given the existing ordering (if possible): + aggr_expr + .reverse_expr() + .map(Arc::new) + .unwrap_or(aggr_expr) + .with_beneficial_ordering(true)? + .map(Arc::new) + } else { + // There is no beneficial ordering present -- aggregation + // will still work albeit in a less efficient mode. + aggr_expr.with_beneficial_ordering(false)?.map(Arc::new) + } + .ok_or_else(|| { + plan_datafusion_err!( + "Expects an aggregate expression that can benefit from input ordering" + ) + }) + } else { + Ok(aggr_expr) + } + }) + .collect() +} diff --git a/datafusion/core/src/physical_optimizer/utils.rs b/datafusion/core/src/physical_optimizer/utils.rs index 2c0d042281e6..8007d8cc7f00 100644 --- a/datafusion/core/src/physical_optimizer/utils.rs +++ b/datafusion/core/src/physical_optimizer/utils.rs @@ -39,7 +39,7 @@ pub fn add_sort_above( fetch: Option, ) -> PlanContext { let mut sort_expr = PhysicalSortRequirement::to_sort_exprs(sort_requirements); - sort_expr.retain(|sort_expr| { + sort_expr.inner.retain(|sort_expr| { !node .plan .equivalence_properties() diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index b7b6c20b19bb..2a96a2ad111f 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -19,16 +19,9 @@ use std::borrow::Cow; use std::collections::HashMap; -use std::fmt::Write; use std::sync::Arc; -use crate::datasource::file_format::arrow::ArrowFormat; -use crate::datasource::file_format::avro::AvroFormat; -use crate::datasource::file_format::csv::CsvFormat; -use crate::datasource::file_format::json::JsonFormat; -#[cfg(feature = "parquet")] -use crate::datasource::file_format::parquet::ParquetFormat; -use crate::datasource::file_format::FileFormat; +use crate::datasource::file_format::file_type_to_format; use crate::datasource::listing::ListingTableUrl; use crate::datasource::physical_plan::FileSinkConfig; use crate::datasource::source_as_provider; @@ -36,20 +29,18 @@ use crate::error::{DataFusionError, Result}; use crate::execution::context::{ExecutionProps, SessionState}; use crate::logical_expr::utils::generate_sort_key; use crate::logical_expr::{ - Aggregate, EmptyRelation, Join, Projection, Sort, TableScan, Unnest, Window, + Aggregate, EmptyRelation, Join, Projection, Sort, TableScan, Unnest, Values, Window, }; use crate::logical_expr::{ Expr, LogicalPlan, Partitioning as LogicalPartitioning, PlanType, Repartition, UserDefinedLogicalNode, }; -use crate::logical_expr::{Limit, Values}; use crate::physical_expr::{create_physical_expr, create_physical_exprs}; -use crate::physical_optimizer::optimizer::PhysicalOptimizerRule; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::physical_plan::analyze::AnalyzeExec; use crate::physical_plan::empty::EmptyExec; use crate::physical_plan::explain::ExplainExec; -use crate::physical_plan::expressions::{Column, PhysicalSortExpr}; +use crate::physical_plan::expressions::PhysicalSortExpr; use crate::physical_plan::filter::FilterExec; use crate::physical_plan::joins::utils as join_utils; use crate::physical_plan::joins::{ @@ -66,351 +57,44 @@ use crate::physical_plan::unnest::UnnestExec; use crate::physical_plan::values::ValuesExec; use crate::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use crate::physical_plan::{ - aggregates, displayable, udaf, windows, AggregateExpr, ExecutionPlan, - ExecutionPlanProperties, InputOrderMode, Partitioning, PhysicalExpr, WindowExpr, + displayable, windows, ExecutionPlan, ExecutionPlanProperties, InputOrderMode, + Partitioning, PhysicalExpr, WindowExpr, }; use arrow::compute::SortOptions; use arrow::datatypes::{Schema, SchemaRef}; use arrow_array::builder::StringBuilder; use arrow_array::RecordBatch; -use datafusion_common::config::FormatOptions; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, - FileType, ScalarValue, + ScalarValue, }; -use datafusion_expr::dml::CopyTo; +use datafusion_expr::dml::{CopyTo, InsertOp}; use datafusion_expr::expr::{ - self, AggregateFunction, AggregateFunctionDefinition, Alias, Between, BinaryExpr, - Cast, GetFieldAccess, GetIndexedField, GroupingSet, InList, Like, TryCast, - WindowFunction, + physical_name, AggregateFunction, Alias, GroupingSet, WindowFunction, }; use datafusion_expr::expr_rewriter::unnormalize_cols; -use datafusion_expr::expr_vec_fmt; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::{ - DescribeTable, DmlStatement, Extension, Filter, RecursiveQuery, - ScalarFunctionDefinition, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, + DescribeTable, DmlStatement, Extension, FetchType, Filter, JoinType, RecursiveQuery, + SkipType, SortExpr, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, }; +use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion_physical_expr::expressions::Literal; use datafusion_physical_expr::LexOrdering; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; +use datafusion_physical_plan::unnest::ListUnnest; use datafusion_sql::utils::window_expr_common_partition_keys; use async_trait::async_trait; +use datafusion_physical_optimizer::PhysicalOptimizerRule; use futures::{StreamExt, TryStreamExt}; use itertools::{multiunzip, Itertools}; use log::{debug, trace}; use sqlparser::ast::NullTreatment; use tokio::sync::Mutex; -fn create_function_physical_name( - fun: &str, - distinct: bool, - args: &[Expr], - order_by: Option<&Vec>, -) -> Result { - let names: Vec = args - .iter() - .map(|e| create_physical_name(e, false)) - .collect::>()?; - - let distinct_str = match distinct { - true => "DISTINCT ", - false => "", - }; - - let phys_name = format!("{}({}{})", fun, distinct_str, names.join(",")); - - Ok(order_by - .map(|order_by| format!("{} ORDER BY [{}]", phys_name, expr_vec_fmt!(order_by))) - .unwrap_or(phys_name)) -} - -fn physical_name(e: &Expr) -> Result { - create_physical_name(e, true) -} - -fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { - match e { - Expr::Unnest(_) => { - internal_err!( - "Expr::Unnest should have been converted to LogicalPlan::Unnest" - ) - } - Expr::Column(c) => { - if is_first_expr { - Ok(c.name.clone()) - } else { - Ok(c.flat_name()) - } - } - Expr::Alias(Alias { name, .. }) => Ok(name.clone()), - Expr::ScalarVariable(_, variable_names) => Ok(variable_names.join(".")), - Expr::Literal(value) => Ok(format!("{value:?}")), - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let left = create_physical_name(left, false)?; - let right = create_physical_name(right, false)?; - Ok(format!("{left} {op} {right}")) - } - Expr::Case(case) => { - let mut name = "CASE ".to_string(); - if let Some(e) = &case.expr { - let _ = write!(name, "{e} "); - } - for (w, t) in &case.when_then_expr { - let _ = write!(name, "WHEN {w} THEN {t} "); - } - if let Some(e) = &case.else_expr { - let _ = write!(name, "ELSE {e} "); - } - name += "END"; - Ok(name) - } - Expr::Cast(Cast { expr, .. }) => { - // CAST does not change the expression name - create_physical_name(expr, false) - } - Expr::TryCast(TryCast { expr, .. }) => { - // CAST does not change the expression name - create_physical_name(expr, false) - } - Expr::Not(expr) => { - let expr = create_physical_name(expr, false)?; - Ok(format!("NOT {expr}")) - } - Expr::Negative(expr) => { - let expr = create_physical_name(expr, false)?; - Ok(format!("(- {expr})")) - } - Expr::IsNull(expr) => { - let expr = create_physical_name(expr, false)?; - Ok(format!("{expr} IS NULL")) - } - Expr::IsNotNull(expr) => { - let expr = create_physical_name(expr, false)?; - Ok(format!("{expr} IS NOT NULL")) - } - Expr::IsTrue(expr) => { - let expr = create_physical_name(expr, false)?; - Ok(format!("{expr} IS TRUE")) - } - Expr::IsFalse(expr) => { - let expr = create_physical_name(expr, false)?; - Ok(format!("{expr} IS FALSE")) - } - Expr::IsUnknown(expr) => { - let expr = create_physical_name(expr, false)?; - Ok(format!("{expr} IS UNKNOWN")) - } - Expr::IsNotTrue(expr) => { - let expr = create_physical_name(expr, false)?; - Ok(format!("{expr} IS NOT TRUE")) - } - Expr::IsNotFalse(expr) => { - let expr = create_physical_name(expr, false)?; - Ok(format!("{expr} IS NOT FALSE")) - } - Expr::IsNotUnknown(expr) => { - let expr = create_physical_name(expr, false)?; - Ok(format!("{expr} IS NOT UNKNOWN")) - } - Expr::GetIndexedField(GetIndexedField { expr: _, field }) => { - match field { - GetFieldAccess::NamedStructField { name: _ } => { - unreachable!( - "NamedStructField should have been rewritten in OperatorToFunction" - ) - } - GetFieldAccess::ListIndex { key: _ } => { - unreachable!( - "ListIndex should have been rewritten in OperatorToFunction" - ) - } - GetFieldAccess::ListRange { - start: _, - stop: _, - stride: _, - } => { - unreachable!( - "ListRange should have been rewritten in OperatorToFunction" - ) - } - }; - } - Expr::ScalarFunction(fun) => { - // function should be resolved during `AnalyzerRule`s - if let ScalarFunctionDefinition::Name(_) = fun.func_def { - return internal_err!("Function `Expr` with name should be resolved."); - } - - create_function_physical_name(fun.name(), false, &fun.args, None) - } - Expr::WindowFunction(WindowFunction { - fun, - args, - order_by, - .. - }) => { - create_function_physical_name(&fun.to_string(), false, args, Some(order_by)) - } - Expr::AggregateFunction(AggregateFunction { - func_def, - distinct, - args, - filter, - order_by, - null_treatment: _, - }) => match func_def { - AggregateFunctionDefinition::BuiltIn(..) => create_function_physical_name( - func_def.name(), - *distinct, - args, - order_by.as_ref(), - ), - AggregateFunctionDefinition::UDF(fun) => { - // TODO: Add support for filter by in AggregateUDF - if filter.is_some() { - return exec_err!( - "aggregate expression with filter is not supported" - ); - } - - let names = args - .iter() - .map(|e| create_physical_name(e, false)) - .collect::>>()?; - Ok(format!("{}({})", fun.name(), names.join(","))) - } - AggregateFunctionDefinition::Name(_) => { - internal_err!("Aggregate function `Expr` with name should be resolved.") - } - }, - Expr::GroupingSet(grouping_set) => match grouping_set { - GroupingSet::Rollup(exprs) => Ok(format!( - "ROLLUP ({})", - exprs - .iter() - .map(|e| create_physical_name(e, false)) - .collect::>>()? - .join(", ") - )), - GroupingSet::Cube(exprs) => Ok(format!( - "CUBE ({})", - exprs - .iter() - .map(|e| create_physical_name(e, false)) - .collect::>>()? - .join(", ") - )), - GroupingSet::GroupingSets(lists_of_exprs) => { - let mut strings = vec![]; - for exprs in lists_of_exprs { - let exprs_str = exprs - .iter() - .map(|e| create_physical_name(e, false)) - .collect::>>()? - .join(", "); - strings.push(format!("({exprs_str})")); - } - Ok(format!("GROUPING SETS ({})", strings.join(", "))) - } - }, - - Expr::InList(InList { - expr, - list, - negated, - }) => { - let expr = create_physical_name(expr, false)?; - let list = list.iter().map(|expr| create_physical_name(expr, false)); - if *negated { - Ok(format!("{expr} NOT IN ({list:?})")) - } else { - Ok(format!("{expr} IN ({list:?})")) - } - } - Expr::Exists { .. } => { - not_impl_err!("EXISTS is not yet supported in the physical plan") - } - Expr::InSubquery(_) => { - not_impl_err!("IN subquery is not yet supported in the physical plan") - } - Expr::ScalarSubquery(_) => { - not_impl_err!("Scalar subqueries are not yet supported in the physical plan") - } - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { - let expr = create_physical_name(expr, false)?; - let low = create_physical_name(low, false)?; - let high = create_physical_name(high, false)?; - if *negated { - Ok(format!("{expr} NOT BETWEEN {low} AND {high}")) - } else { - Ok(format!("{expr} BETWEEN {low} AND {high}")) - } - } - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => { - let expr = create_physical_name(expr, false)?; - let pattern = create_physical_name(pattern, false)?; - let op_name = if *case_insensitive { "ILIKE" } else { "LIKE" }; - let escape = if let Some(char) = escape_char { - format!("CHAR '{char}'") - } else { - "".to_string() - }; - if *negated { - Ok(format!("{expr} NOT {op_name} {pattern}{escape}")) - } else { - Ok(format!("{expr} {op_name} {pattern}{escape}")) - } - } - Expr::SimilarTo(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive: _, - }) => { - let expr = create_physical_name(expr, false)?; - let pattern = create_physical_name(pattern, false)?; - let escape = if let Some(char) = escape_char { - format!("CHAR '{char}'") - } else { - "".to_string() - }; - if *negated { - Ok(format!("{expr} NOT SIMILAR TO {pattern}{escape}")) - } else { - Ok(format!("{expr} SIMILAR TO {pattern}{escape}")) - } - } - Expr::Sort { .. } => { - internal_err!("Create physical name does not support sort expression") - } - Expr::Wildcard { .. } => { - internal_err!("Create physical name does not support wildcard") - } - Expr::Placeholder(_) => { - internal_err!("Create physical name does not support placeholder") - } - Expr::OuterReferenceColumn(_, _) => { - internal_err!("Create physical name does not support OuterReferenceColumn") - } - } -} - /// Physical query planner that converts a `LogicalPlan` to an /// `ExecutionPlan` suitable for execution. #[async_trait] @@ -496,7 +180,8 @@ impl PhysicalPlanner for DefaultPhysicalPlanner { let plan = self .create_initial_plan(logical_plan, session_state) .await?; - self.optimize_internal(plan, session_state, |_, _| {}) + + self.optimize_physical_plan(plan, session_state, |_, _| {}) } } } @@ -744,7 +429,7 @@ impl DefaultPhysicalPlanner { Ok(Some(plan)) } - /// Given a single LogicalPlan node, map it to it's physical ExecutionPlan counterpart. + /// Given a single LogicalPlan node, map it to its physical ExecutionPlan counterpart. async fn map_logical_node_to_physical( &self, node: &LogicalPlan, @@ -808,7 +493,7 @@ impl DefaultPhysicalPlanner { LogicalPlan::Copy(CopyTo { input, output_url, - format_options, + file_type, partition_by, options: source_option_tuples, }) => { @@ -826,6 +511,16 @@ impl DefaultPhysicalPlanner { .map(|s| (s.to_string(), arrow_schema::DataType::Null)) .collect::>(); + let keep_partition_by_columns = match source_option_tuples + .get("execution.keep_partition_by_columns") + .map(|v| v.trim()) { + None => session_state.config().options().execution.keep_partition_by_columns, + Some("true") => true, + Some("false") => false, + Some(value) => + return Err(DataFusionError::Configuration(format!("provided value for 'execution.keep_partition_by_columns' was not recognized: \"{}\"", value))), + }; + // Set file sink related options let config = FileSinkConfig { object_store_url, @@ -833,42 +528,20 @@ impl DefaultPhysicalPlanner { file_groups: vec![], output_schema: Arc::new(schema), table_partition_cols, - overwrite: false, - }; - let mut table_options = session_state.default_table_options(); - let sink_format: Arc = match format_options { - FormatOptions::CSV(options) => { - table_options.csv = options.clone(); - table_options.set_file_format(FileType::CSV); - table_options.alter_with_string_hash_map(source_option_tuples)?; - Arc::new(CsvFormat::default().with_options(table_options.csv)) - } - FormatOptions::JSON(options) => { - table_options.json = options.clone(); - table_options.set_file_format(FileType::JSON); - table_options.alter_with_string_hash_map(source_option_tuples)?; - Arc::new(JsonFormat::default().with_options(table_options.json)) - } - #[cfg(feature = "parquet")] - FormatOptions::PARQUET(options) => { - table_options.parquet = options.clone(); - table_options.set_file_format(FileType::PARQUET); - table_options.alter_with_string_hash_map(source_option_tuples)?; - Arc::new( - ParquetFormat::default().with_options(table_options.parquet), - ) - } - FormatOptions::AVRO => Arc::new(AvroFormat {}), - FormatOptions::ARROW => Arc::new(ArrowFormat {}), + insert_op: InsertOp::Append, + keep_partition_by_columns, }; + let sink_format = file_type_to_format(file_type)? + .create(session_state, source_option_tuples)?; + sink_format .create_writer_physical_plan(input_exec, session_state, config, None) .await? } LogicalPlan::Dml(DmlStatement { table_name, - op: WriteOp::InsertInto, + op: WriteOp::Insert(insert_op), .. }) => { let name = table_name.table(); @@ -876,23 +549,7 @@ impl DefaultPhysicalPlanner { if let Some(provider) = schema.table(name).await? { let input_exec = children.one()?; provider - .insert_into(session_state, input_exec, false) - .await? - } else { - return exec_err!("Table '{table_name}' does not exist"); - } - } - LogicalPlan::Dml(DmlStatement { - table_name, - op: WriteOp::InsertOverwrite, - .. - }) => { - let name = table_name.table(); - let schema = session_state.schema_for_ref(table_name.clone())?; - if let Some(provider) = schema.table(name).await? { - let input_exec = children.one()?; - provider - .insert_into(session_state, input_exec, true) + .insert_into(session_state, input_exec, *insert_op) .await? } else { return exec_err!("Table '{table_name}' does not exist"); @@ -992,10 +649,18 @@ impl DefaultPhysicalPlanner { aggr_expr, .. }) => { + let options = session_state.config().options(); // Initially need to perform the aggregate and then merge the partitions let input_exec = children.one()?; let physical_input_schema = input_exec.schema(); let logical_input_schema = input.as_ref().schema(); + let physical_input_schema_from_logical = logical_input_schema.inner(); + + if &physical_input_schema != physical_input_schema_from_logical + && !options.execution.skip_physical_aggregate_schema_check + { + return internal_err!("Physical input schema should be the same as the one converted from logical input schema."); + } let groups = self.create_grouping_physical_expr( group_expr, @@ -1022,16 +687,12 @@ impl DefaultPhysicalPlanner { let initial_aggr = Arc::new(AggregateExec::try_new( AggregateMode::Partial, groups.clone(), - aggregates.clone(), + aggregates, filters.clone(), input_exec, physical_input_schema.clone(), )?); - // update group column indices based on partial aggregate plan evaluation - let final_group: Vec> = - initial_aggr.output_group_expr(); - let can_repartition = !groups.is_empty() && session_state.config().target_partitions() > 1 && session_state.config().repartition_aggregations(); @@ -1040,7 +701,7 @@ impl DefaultPhysicalPlanner { // optimization purposes. For example, a FIRST_VALUE may turn // into a LAST_VALUE with the reverse ordering requirement. // To reflect such changes to subsequent stages, use the updated - // `AggregateExpr`/`PhysicalSortExpr` objects. + // `AggregateFunctionExpr`/`PhysicalSortExpr` objects. let updated_aggregates = initial_aggr.aggr_expr().to_vec(); let next_partition_mode = if can_repartition { @@ -1052,13 +713,7 @@ impl DefaultPhysicalPlanner { AggregateMode::Final }; - let final_grouping_set = PhysicalGroupBy::new_single( - final_group - .iter() - .enumerate() - .map(|(i, expr)| (expr.clone(), groups.expr()[i].1.clone())) - .collect(), - ); + let final_grouping_set = initial_aggr.group_expr().as_final(); Arc::new(AggregateExec::try_new( next_partition_mode, @@ -1142,8 +797,20 @@ impl DefaultPhysicalPlanner { } LogicalPlan::Subquery(_) => todo!(), LogicalPlan::SubqueryAlias(_) => children.one()?, - LogicalPlan::Limit(Limit { skip, fetch, .. }) => { + LogicalPlan::Limit(limit) => { let input = children.one()?; + let SkipType::Literal(skip) = limit.get_skip_type()? else { + return not_impl_err!( + "Unsupported OFFSET expression: {:?}", + limit.skip + ); + }; + let FetchType::Literal(fetch) = limit.get_fetch_type()? else { + return not_impl_err!( + "Unsupported LIMIT expression: {:?}", + limit.fetch + ); + }; // GlobalLimitExec requires a single partition for input let input = if input.output_partitioning().partition_count() == 1 { @@ -1152,33 +819,34 @@ impl DefaultPhysicalPlanner { // Apply a LocalLimitExec to each partition. The optimizer will also insert // a CoalescePartitionsExec between the GlobalLimitExec and LocalLimitExec if let Some(fetch) = fetch { - Arc::new(LocalLimitExec::new(input, *fetch + skip)) + Arc::new(LocalLimitExec::new(input, fetch + skip)) } else { input } }; - Arc::new(GlobalLimitExec::new(input, *skip, *fetch)) + Arc::new(GlobalLimitExec::new(input, skip, fetch)) } LogicalPlan::Unnest(Unnest { - columns, + list_type_columns, + struct_type_columns, schema, options, .. }) => { let input = children.one()?; - let column_execs = columns + let schema = SchemaRef::new(schema.as_ref().to_owned().into()); + let list_column_indices = list_type_columns .iter() - .map(|column| { - schema - .index_of_column(column) - .map(|idx| Column::new(&column.name, idx)) + .map(|(index, unnesting)| ListUnnest { + index_in_input_schema: *index, + depth: unnesting.depth, }) - .collect::>()?; - let schema = SchemaRef::new(schema.as_ref().to_owned().into()); + .collect(); Arc::new(UnnestExec::new( input, - column_execs, + list_column_indices, + struct_type_columns.clone(), schema, options.clone(), )) @@ -1314,7 +982,7 @@ impl DefaultPhysicalPlanner { let join_filter = match filter { Some(expr) => { // Extract columns from filter expression and saved in a HashSet - let cols = expr.to_columns()?; + let cols = expr.column_refs(); // Collect left & right field indices, the field indices are sorted in ascending order let left_field_indices = cols @@ -1359,14 +1027,21 @@ impl DefaultPhysicalPlanner { }) .collect(); + let metadata: HashMap<_, _> = left_df_schema + .metadata() + .clone() + .into_iter() + .chain(right_df_schema.metadata().clone()) + .collect(); + // Construct intermediate schemas used for filtering data and // convert logical expression to physical according to filter schema let filter_df_schema = DFSchema::new_with_metadata( filter_df_fields, - HashMap::new(), + metadata.clone(), )?; let filter_schema = - Schema::new_with_metadata(filter_fields, HashMap::new()); + Schema::new_with_metadata(filter_fields, metadata); let filter_expr = create_physical_expr( expr, &filter_df_schema, @@ -1390,14 +1065,18 @@ impl DefaultPhysicalPlanner { session_state.config_options().optimizer.prefer_hash_join; let join: Arc = if join_on.is_empty() { - // there is no equal join condition, use the nested loop join - // TODO optimize the plan, and use the config of `target_partitions` and `repartition_joins` - Arc::new(NestedLoopJoinExec::try_new( - physical_left, - physical_right, - join_filter, - join_type, - )?) + if join_filter.is_none() && matches!(join_type, JoinType::Inner) { + // cross join if there is no join conditions and no join filter set + Arc::new(CrossJoinExec::new(physical_left, physical_right)) + } else { + // there is no equal join condition, use the nested loop join + Arc::new(NestedLoopJoinExec::try_new( + physical_left, + physical_right, + join_filter, + join_type, + )?) + } } else if session_state.config().target_partitions() > 1 && session_state.config().repartition_joins() && !prefer_hash_join @@ -1457,10 +1136,6 @@ impl DefaultPhysicalPlanner { join } } - LogicalPlan::CrossJoin(_) => { - let [left, right] = children.two()?; - Arc::new(CrossJoinExec::new(left, right)) - } LogicalPlan::RecursiveQuery(RecursiveQuery { name, is_distinct, .. }) => { @@ -1526,6 +1201,9 @@ impl DefaultPhysicalPlanner { // statement can be prepared) return not_impl_err!("Unsupported logical plan: Prepare"); } + LogicalPlan::Execute(_) => { + return not_impl_err!("Unsupported logical plan: Execute"); + } LogicalPlan::Dml(dml) => { // DataFusion is a read-only query engine, but also a library, so consumers may implement this return not_impl_err!("Unsupported logical plan: Dml({0})", dml.op); @@ -1821,7 +1499,8 @@ pub fn create_window_expr_with_name( window_frame, null_treatment, }) => { - let args = create_physical_exprs(args, logical_schema, execution_props)?; + let physical_args = + create_physical_exprs(args, logical_schema, execution_props)?; let partition_by = create_physical_exprs(partition_by, logical_schema, execution_props)?; let order_by = @@ -1835,15 +1514,14 @@ pub fn create_window_expr_with_name( } let window_frame = Arc::new(window_frame.clone()); - let ignore_nulls = null_treatment - .unwrap_or(sqlparser::ast::NullTreatment::RespectNulls) + let ignore_nulls = null_treatment.unwrap_or(NullTreatment::RespectNulls) == NullTreatment::IgnoreNulls; windows::create_window_expr( fun, name, - &args, + &physical_args, &partition_by, - &order_by, + order_by.as_ref(), window_frame, physical_schema, ignore_nulls, @@ -1862,37 +1540,43 @@ pub fn create_window_expr( // unpack aliased logical expressions, e.g. "sum(col) over () as total" let (name, e) = match e { Expr::Alias(Alias { expr, name, .. }) => (name.clone(), expr.as_ref()), - _ => (e.display_name()?, e), + _ => (e.schema_name().to_string(), e), }; create_window_expr_with_name(e, name, logical_schema, execution_props) } type AggregateExprWithOptionalArgs = ( - Arc, + Arc, // The filter clause, if any Option>, // Ordering requirements, if any - Option>, + Option, ); /// Create an aggregate expression with a name from a logical expression pub fn create_aggregate_expr_with_name_and_maybe_filter( e: &Expr, - name: impl Into, + name: Option, logical_input_schema: &DFSchema, physical_input_schema: &Schema, execution_props: &ExecutionProps, ) -> Result { match e { Expr::AggregateFunction(AggregateFunction { - func_def, + func, distinct, args, filter, order_by, null_treatment, }) => { - let args = + let name = if let Some(name) = name { + name + } else { + physical_name(e)? + }; + + let physical_args = create_physical_exprs(args, logical_input_schema, execution_props)?; let filter = match filter { Some(e) => Some(create_physical_expr( @@ -1903,61 +1587,35 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( None => None, }; - let ignore_nulls = null_treatment - .unwrap_or(sqlparser::ast::NullTreatment::RespectNulls) + let ignore_nulls = null_treatment.unwrap_or(NullTreatment::RespectNulls) == NullTreatment::IgnoreNulls; - let (agg_expr, filter, order_by) = match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => { - let physical_sort_exprs = match order_by { - Some(exprs) => Some(create_physical_sort_exprs( - exprs, - logical_input_schema, - execution_props, - )?), - None => None, - }; - let ordering_reqs: Vec = - physical_sort_exprs.clone().unwrap_or(vec![]); - let agg_expr = aggregates::create_aggregate_expr( - fun, - *distinct, - &args, - &ordering_reqs, - physical_input_schema, - name, - ignore_nulls, - )?; - (agg_expr, filter, physical_sort_exprs) - } - AggregateFunctionDefinition::UDF(fun) => { - let sort_exprs = order_by.clone().unwrap_or(vec![]); - let physical_sort_exprs = match order_by { - Some(exprs) => Some(create_physical_sort_exprs( - exprs, - logical_input_schema, - execution_props, - )?), - None => None, - }; - let ordering_reqs: Vec = - physical_sort_exprs.clone().unwrap_or(vec![]); - let agg_expr = udaf::create_aggregate_expr( - fun, - &args, - &sort_exprs, - &ordering_reqs, - physical_input_schema, - name, - ignore_nulls, - )?; - (agg_expr, filter, physical_sort_exprs) - } - AggregateFunctionDefinition::Name(_) => { - return internal_err!( - "Aggregate function name should have been resolved" - ) - } + + let (agg_expr, filter, order_by) = { + let physical_sort_exprs = match order_by { + Some(exprs) => Some(create_physical_sort_exprs( + exprs, + logical_input_schema, + execution_props, + )?), + None => None, + }; + + let ordering_reqs: LexOrdering = + physical_sort_exprs.clone().unwrap_or_default(); + + let agg_expr = + AggregateExprBuilder::new(func.to_owned(), physical_args.to_vec()) + .order_by(ordering_reqs) + .schema(Arc::new(physical_input_schema.to_owned())) + .alias(name) + .with_ignore_nulls(ignore_nulls) + .with_distinct(*distinct) + .build() + .map(Arc::new)?; + + (agg_expr, filter, physical_sort_exprs) }; + Ok((agg_expr, filter, order_by)) } other => internal_err!("Invalid aggregate expression '{other:?}'"), @@ -1973,8 +1631,9 @@ pub fn create_aggregate_expr_and_maybe_filter( ) -> Result { // unpack (nested) aliased logical expressions, e.g. "sum(col) as total" let (name, e) = match e { - Expr::Alias(Alias { expr, name, .. }) => (name.clone(), expr.as_ref()), - _ => (physical_name(e)?, e), + Expr::Alias(Alias { expr, name, .. }) => (Some(name.clone()), expr.as_ref()), + Expr::AggregateFunction(_) => (Some(e.schema_name().to_string()), e), + _ => (None, e), }; create_aggregate_expr_with_name_and_maybe_filter( @@ -1988,38 +1647,34 @@ pub fn create_aggregate_expr_and_maybe_filter( /// Create a physical sort expression from a logical expression pub fn create_physical_sort_expr( - e: &Expr, + e: &SortExpr, input_dfschema: &DFSchema, execution_props: &ExecutionProps, ) -> Result { - if let Expr::Sort(expr::Sort { + let SortExpr { expr, asc, nulls_first, - }) = e - { - Ok(PhysicalSortExpr { - expr: create_physical_expr(expr, input_dfschema, execution_props)?, - options: SortOptions { - descending: !asc, - nulls_first: *nulls_first, - }, - }) - } else { - internal_err!("Expects a sort expression") - } + } = e; + Ok(PhysicalSortExpr { + expr: create_physical_expr(expr, input_dfschema, execution_props)?, + options: SortOptions { + descending: !asc, + nulls_first: *nulls_first, + }, + }) } /// Create vector of physical sort expression from a vector of logical expression pub fn create_physical_sort_exprs( - exprs: &[Expr], + exprs: &[SortExpr], input_dfschema: &DFSchema, execution_props: &ExecutionProps, ) -> Result { exprs .iter() .map(|expr| create_physical_sort_expr(expr, input_dfschema, execution_props)) - .collect::>>() + .collect::>() } impl DefaultPhysicalPlanner { @@ -2040,7 +1695,7 @@ impl DefaultPhysicalPlanner { let config = &session_state.config_options().explain; if !config.physical_plan_only { - stringified_plans = e.stringified_plans.clone(); + stringified_plans.clone_from(&e.stringified_plans); if e.logical_optimization_succeeded { stringified_plans.push(e.plan.to_stringified(FinalLogicalPlan)); } @@ -2052,26 +1707,40 @@ impl DefaultPhysicalPlanner { .await { Ok(input) => { - // This plan will includes statistics if show_statistics is on + // Include statistics / schema if enabled stringified_plans.push( displayable(input.as_ref()) .set_show_statistics(config.show_statistics) + .set_show_schema(config.show_schema) .to_stringified(e.verbose, InitialPhysicalPlan), ); - // If the show_statisitcs is off, add another line to show statsitics in the case of explain verbose - if e.verbose && !config.show_statistics { - stringified_plans.push( - displayable(input.as_ref()) - .set_show_statistics(true) - .to_stringified( - e.verbose, - InitialPhysicalPlanWithStats, - ), - ); + // Show statistics + schema in verbose output even if not + // explicitly requested + if e.verbose { + if !config.show_statistics { + stringified_plans.push( + displayable(input.as_ref()) + .set_show_statistics(true) + .to_stringified( + e.verbose, + InitialPhysicalPlanWithStats, + ), + ); + } + if !config.show_schema { + stringified_plans.push( + displayable(input.as_ref()) + .set_show_schema(true) + .to_stringified( + e.verbose, + InitialPhysicalPlanWithSchema, + ), + ); + } } - let optimized_plan = self.optimize_internal( + let optimized_plan = self.optimize_physical_plan( input, session_state, |plan, optimizer| { @@ -2080,6 +1749,7 @@ impl DefaultPhysicalPlanner { stringified_plans.push( displayable(plan) .set_show_statistics(config.show_statistics) + .set_show_schema(config.show_schema) .to_stringified(e.verbose, plan_type), ); }, @@ -2090,19 +1760,33 @@ impl DefaultPhysicalPlanner { stringified_plans.push( displayable(input.as_ref()) .set_show_statistics(config.show_statistics) + .set_show_schema(config.show_schema) .to_stringified(e.verbose, FinalPhysicalPlan), ); - // If the show_statisitcs is off, add another line to show statsitics in the case of explain verbose - if e.verbose && !config.show_statistics { - stringified_plans.push( - displayable(input.as_ref()) - .set_show_statistics(true) - .to_stringified( - e.verbose, - FinalPhysicalPlanWithStats, - ), - ); + // Show statistics + schema in verbose output even if not + // explicitly requested + if e.verbose { + if !config.show_statistics { + stringified_plans.push( + displayable(input.as_ref()) + .set_show_statistics(true) + .to_stringified( + e.verbose, + FinalPhysicalPlanWithStats, + ), + ); + } + if !config.show_schema { + stringified_plans.push( + displayable(input.as_ref()) + .set_show_schema(true) + .to_stringified( + e.verbose, + FinalPhysicalPlanWithSchema, + ), + ); + } } } Err(DataFusionError::Context(optimizer_name, e)) => { @@ -2140,7 +1824,7 @@ impl DefaultPhysicalPlanner { /// Optimize a physical plan by applying each physical optimizer, /// calling observer(plan, optimizer after each one) - fn optimize_internal( + pub fn optimize_physical_plan( &self, plan: Arc, session_state: &SessionState, @@ -2239,7 +1923,6 @@ impl DefaultPhysicalPlanner { expr: &[Expr], ) -> Result> { let input_schema = input.as_ref().schema(); - let physical_exprs = expr .iter() .map(|e| { @@ -2297,6 +1980,7 @@ fn tuple_err(value: (Result, Result)) -> Result<(T, R)> { #[cfg(test)] mod tests { use std::any::Any; + use std::cmp::Ordering; use std::fmt::{self, Debug}; use std::ops::{BitAnd, Not}; @@ -2310,21 +1994,25 @@ mod tests { use crate::prelude::{SessionConfig, SessionContext}; use crate::test_util::{scan_empty, scan_empty_with_partitions}; + use crate::execution::session_state::SessionStateBuilder; use arrow::array::{ArrayRef, DictionaryArray, Int32Array}; use arrow::datatypes::{DataType, Field, Int32Type}; use datafusion_common::{assert_contains, DFSchemaRef, TableReference}; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; - use datafusion_expr::{ - col, lit, sum, LogicalPlanBuilder, UserDefinedLogicalNodeCore, - }; + use datafusion_expr::{col, lit, LogicalPlanBuilder, UserDefinedLogicalNodeCore}; + use datafusion_functions_aggregate::expr_fn::sum; use datafusion_physical_expr::EquivalenceProperties; fn make_session_state() -> SessionState { let runtime = Arc::new(RuntimeEnv::default()); let config = SessionConfig::new().with_target_partitions(4); let config = config.set_bool("datafusion.optimizer.skip_failed_rules", false); - SessionState::new_with_config_rt(config, runtime) + SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .build() } async fn plan(logical_plan: &LogicalPlan) -> Result> { @@ -2354,7 +2042,7 @@ mod tests { // verify that the plan correctly casts u8 to i64 // the cast from u8 to i64 for literal will be simplified, and get lit(int64(5)) // the cast here is implicit so has CastOptions with safe=true - let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) } }"; + let expected = "BinaryExpr { left: Column { name: \"c7\", index: 2 }, op: Lt, right: Literal { value: Int64(5) }, fail_on_overflow: false }"; assert!(format!("{exec_plan:?}").contains(expected)); Ok(()) } @@ -2490,9 +2178,6 @@ mod tests { assert!(format!("{plan:?}").contains("GlobalLimitExec")); assert!(format!("{plan:?}").contains("skip: 3, fetch: Some(5)")); - // LocalLimitExec adjusts the `fetch` - assert!(format!("{plan:?}").contains("LocalLimitExec")); - assert!(format!("{plan:?}").contains("fetch: 8")); Ok(()) } @@ -2593,7 +2278,7 @@ mod tests { let execution_plan = plan(&logical_plan).await?; // verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated. - let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") } }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") } } }"; + let expected = "expr: [(BinaryExpr { left: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"a\") }, fail_on_overflow: false }, op: Or, right: BinaryExpr { left: Column { name: \"c1\", index: 0 }, op: Eq, right: Literal { value: Utf8(\"1\") }, fail_on_overflow: false }, fail_on_overflow: false }"; let actual = format!("{execution_plan:?}"); assert!(actual.contains(expected), "{}", actual); @@ -2645,7 +2330,7 @@ mod tests { .downcast_ref::() .expect("hash aggregate"); assert_eq!( - "SUM(aggregate_test_100.c2)", + "sum(aggregate_test_100.c2)", final_hash_agg.schema().field(1).name() ); // we need access to the input to the partial aggregate so that other projects can @@ -2673,8 +2358,8 @@ mod tests { .downcast_ref::() .expect("hash aggregate"); assert_eq!( - "SUM(aggregate_test_100.c3)", - final_hash_agg.schema().field(2).name() + "sum(aggregate_test_100.c3)", + final_hash_agg.schema().field(3).name() ); // we need access to the input to the partial aggregate so that other projects can // implement serde @@ -2834,7 +2519,7 @@ mod tests { fn default() -> Self { Self { schema: DFSchemaRef::new( - DFSchema::from_unqualifed_fields( + DFSchema::from_unqualified_fields( vec![Field::new("a", DataType::Int32, false)].into(), HashMap::new(), ) @@ -2850,6 +2535,14 @@ mod tests { } } + // Implementation needed for `UserDefinedLogicalNodeCore`, since the only field is + // a schema, we can't derive `PartialOrd`, and we can't compare these. + impl PartialOrd for NoOpExtensionNode { + fn partial_cmp(&self, _other: &Self) -> Option { + None + } + } + impl UserDefinedLogicalNodeCore for NoOpExtensionNode { fn name(&self) -> &str { "NoOp" @@ -2871,9 +2564,17 @@ mod tests { write!(f, "NoOp") } - fn from_template(&self, _exprs: &[Expr], _inputs: &[LogicalPlan]) -> Self { + fn with_exprs_and_inputs( + &self, + _exprs: Vec, + _inputs: Vec, + ) -> Result { unimplemented!("NoOp"); } + + fn supports_limit_pushdown(&self) -> bool { + false // Disallow limit push-down by default + } } #[derive(Debug)] @@ -2883,7 +2584,7 @@ mod tests { impl NoOpExecutionPlan { fn new(schema: SchemaRef) -> Self { - let cache = Self::compute_properties(schema.clone()); + let cache = Self::compute_properties(schema); Self { cache } } @@ -2924,7 +2625,7 @@ mod tests { &self.cache } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { vec![] } diff --git a/datafusion/core/src/prelude.rs b/datafusion/core/src/prelude.rs index d82a5a2cc1a1..9c9fcd04bf09 100644 --- a/datafusion/core/src/prelude.rs +++ b/datafusion/core/src/prelude.rs @@ -39,8 +39,8 @@ pub use datafusion_expr::{ Expr, }; pub use datafusion_functions::expr_fn::*; -#[cfg(feature = "array_expressions")] -pub use datafusion_functions_array::expr_fn::*; +#[cfg(feature = "nested_expressions")] +pub use datafusion_functions_nested::expr_fn::*; pub use std::ops::Not; pub use std::ops::{Add, Div, Mul, Neg, Rem, Sub}; diff --git a/datafusion/core/src/test/mod.rs b/datafusion/core/src/test/mod.rs index 0042554f6c73..9ac75c8f3efb 100644 --- a/datafusion/core/src/test/mod.rs +++ b/datafusion/core/src/test/mod.rs @@ -24,9 +24,9 @@ use std::io::{BufReader, BufWriter}; use std::path::Path; use std::sync::Arc; -use crate::datasource::file_format::file_compression_type::{ - FileCompressionType, FileTypeExt, -}; +use crate::datasource::file_format::csv::CsvFormat; +use crate::datasource::file_format::file_compression_type::FileCompressionType; +use crate::datasource::file_format::FileFormat; use crate::datasource::listing::PartitionedFile; use crate::datasource::object_store::ObjectStoreUrl; use crate::datasource::physical_plan::{CsvExec, FileScanConfig}; @@ -40,7 +40,7 @@ use crate::test_util::{aggr_test_schema, arrow_test_data}; use arrow::array::{self, Array, ArrayRef, Decimal128Builder, Int32Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use datafusion_common::{DataFusionError, FileType, Statistics}; +use datafusion_common::{DataFusionError, Statistics}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_physical_expr::{EquivalenceProperties, Partitioning, PhysicalSortExpr}; use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; @@ -69,7 +69,7 @@ pub fn create_table_dual() -> Arc { let batch = RecordBatch::try_new( dual_schema.clone(), vec![ - Arc::new(array::Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![1])), Arc::new(array::StringArray::from(vec!["a"])), ], ) @@ -87,19 +87,22 @@ pub fn scan_partitioned_csv(partitions: usize, work_dir: &Path) -> Result>`] for scanning `partitions` of `filename` @@ -107,7 +110,7 @@ pub fn partitioned_file_groups( path: &str, filename: &str, partitions: usize, - file_type: FileType, + file_format: Arc, file_compression_type: FileCompressionType, work_dir: &Path, ) -> Result>> { @@ -119,9 +122,8 @@ pub fn partitioned_file_groups( let filename = format!( "partition-{}{}", i, - file_type - .to_owned() - .get_ext_with_compression(file_compression_type.to_owned()) + file_format + .get_ext_with_compression(&file_compression_type) .unwrap() ); let filename = work_dir.join(filename); @@ -166,7 +168,7 @@ pub fn partitioned_file_groups( for (i, line) in f.lines().enumerate() { let line = line.unwrap(); - if i == 0 && file_type == FileType::CSV { + if i == 0 && file_format.get_ext() == CsvFormat::default().get_ext() { // write header to all partitions for w in writers.iter_mut() { w.write_all(line.as_bytes()).unwrap(); @@ -196,17 +198,9 @@ pub fn partitioned_file_groups( pub fn partitioned_csv_config( schema: SchemaRef, file_groups: Vec>, -) -> Result { - Ok(FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_schema: schema.clone(), - file_groups, - statistics: Statistics::new_unknown(&schema), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - }) +) -> FileScanConfig { + FileScanConfig::new(ObjectStoreUrl::local_filesystem(), schema) + .with_file_groups(file_groups) } pub fn assert_fields_eq(plan: &LogicalPlan, expected: Vec<&str>) { @@ -282,26 +276,28 @@ pub fn csv_exec_sorted( ) -> Arc { let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new(CsvExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), - file_schema: schema.clone(), - file_groups: vec![vec![PartitionedFile::new("x".to_string(), 100)]], - statistics: Statistics::new_unknown(schema), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![sort_exprs], - }, - false, - 0, - 0, - None, - FileCompressionType::UNCOMPRESSED, - )) + Arc::new( + CsvExec::builder( + FileScanConfig::new( + ObjectStoreUrl::parse("test:///").unwrap(), + schema.clone(), + ) + .with_file(PartitionedFile::new("x".to_string(), 100)) + .with_output_ordering(vec![sort_exprs]), + ) + .with_has_header(false) + .with_delimeter(0) + .with_quote(0) + .with_escape(None) + .with_comment(None) + .with_newlines_in_values(false) + .with_file_compression_type(FileCompressionType::UNCOMPRESSED) + .build(), + ) } // construct a stream partition for test purposes +#[derive(Debug)] pub(crate) struct TestStreamPartition { pub schema: SchemaRef, } @@ -331,6 +327,7 @@ pub fn stream_exec_ordered( None, vec![sort_exprs], true, + None, ) .unwrap(), ) @@ -343,23 +340,24 @@ pub fn csv_exec_ordered( ) -> Arc { let sort_exprs = sort_exprs.into_iter().collect(); - Arc::new(CsvExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::parse("test:///").unwrap(), - file_schema: schema.clone(), - file_groups: vec![vec![PartitionedFile::new("file_path".to_string(), 100)]], - statistics: Statistics::new_unknown(schema), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![sort_exprs], - }, - true, - 0, - b'"', - None, - FileCompressionType::UNCOMPRESSED, - )) + Arc::new( + CsvExec::builder( + FileScanConfig::new( + ObjectStoreUrl::parse("test:///").unwrap(), + schema.clone(), + ) + .with_file(PartitionedFile::new("file_path".to_string(), 100)) + .with_output_ordering(vec![sort_exprs]), + ) + .with_has_header(true) + .with_delimeter(0) + .with_quote(b'"') + .with_escape(None) + .with_comment(None) + .with_newlines_in_values(false) + .with_file_compression_type(FileCompressionType::UNCOMPRESSED) + .build(), + ) } /// A mock execution plan that simply returns the provided statistics @@ -417,6 +415,10 @@ impl DisplayAs for StatisticsExec { } impl ExecutionPlan for StatisticsExec { + fn name(&self) -> &'static str { + Self::static_name() + } + fn as_any(&self) -> &dyn Any { self } @@ -425,7 +427,7 @@ impl ExecutionPlan for StatisticsExec { &self.cache } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { vec![] } diff --git a/datafusion/core/src/test/object_store.rs b/datafusion/core/src/test/object_store.rs index d6f324a7f1f9..6c0a2fc7bec4 100644 --- a/datafusion/core/src/test/object_store.rs +++ b/datafusion/core/src/test/object_store.rs @@ -16,9 +16,8 @@ // under the License. //! Object store implementation used for testing use crate::execution::context::SessionState; +use crate::execution::session_state::SessionStateBuilder; use crate::prelude::SessionContext; -use datafusion_execution::config::SessionConfig; -use datafusion_execution::runtime_env::RuntimeEnv; use futures::FutureExt; use object_store::{memory::InMemory, path::Path, ObjectMeta, ObjectStore}; use std::sync::Arc; @@ -27,8 +26,7 @@ use url::Url; /// Returns a test object store with the provided `ctx` pub fn register_test_store(ctx: &SessionContext, files: &[(&str, u64)]) { let url = Url::parse("test://").unwrap(); - ctx.runtime_env() - .register_object_store(&url, make_test_store_and_state(files).0); + ctx.register_object_store(&url, make_test_store_and_state(files).0); } /// Create a test object store with the provided files @@ -45,10 +43,7 @@ pub fn make_test_store_and_state(files: &[(&str, u64)]) -> (Arc, Sessi ( Arc::new(memory), - SessionState::new_with_config_rt( - SessionConfig::default(), - Arc::new(RuntimeEnv::default()), - ), + SessionStateBuilder::new().with_default_features().build(), ) } diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs index 75ef364d01fd..e03c18fec7c4 100644 --- a/datafusion/core/src/test_util/mod.rs +++ b/datafusion/core/src/test_util/mod.rs @@ -29,12 +29,12 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; +use crate::catalog::{TableProvider, TableProviderFactory}; use crate::dataframe::DataFrame; -use crate::datasource::provider::TableProviderFactory; -use crate::datasource::stream::{StreamConfig, StreamTable}; -use crate::datasource::{empty::EmptyTable, provider_as_source, TableProvider}; +use crate::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; +use crate::datasource::{empty::EmptyTable, provider_as_source}; use crate::error::Result; -use crate::execution::context::{SessionState, TaskContext}; +use crate::execution::context::TaskContext; use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE}; use crate::physical_plan::{ DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan, Partitioning, @@ -45,13 +45,16 @@ use crate::prelude::{CsvReadOptions, SessionContext}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::TableReference; -use datafusion_expr::{CreateExternalTable, Expr, TableType}; -use datafusion_physical_expr::EquivalenceProperties; +use datafusion_expr::utils::COUNT_STAR_EXPANSION; +use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType}; +use datafusion_functions_aggregate::count::count_udaf; +use datafusion_physical_expr::{expressions, EquivalenceProperties, PhysicalExpr}; use async_trait::async_trait; +use datafusion_catalog::Session; +use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use futures::Stream; use tempfile::TempDir; - // backwards compatibility #[cfg(feature = "parquet")] pub use datafusion_common::test_util::parquet_test_data; @@ -107,7 +110,7 @@ pub fn aggr_test_schema() -> SchemaRef { /// Register session context for the aggregate_test_100.csv file pub async fn register_aggregate_csv( - ctx: &mut SessionContext, + ctx: &SessionContext, table_name: &str, ) -> Result<()> { let schema = aggr_test_schema(); @@ -123,8 +126,8 @@ pub async fn register_aggregate_csv( /// Create a table from the aggregate_test_100.csv file with the specified name pub async fn test_table_with_name(name: &str) -> Result { - let mut ctx = SessionContext::new(); - register_aggregate_csv(&mut ctx, name).await?; + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx, name).await?; ctx.table(name).await } @@ -171,13 +174,14 @@ pub fn populate_csv_partitions( } /// TableFactory for tests +#[derive(Default, Debug)] pub struct TestTableFactory {} #[async_trait] impl TableProviderFactory for TestTableFactory { async fn create( &self, - _: &SessionState, + _: &dyn Session, cmd: &CreateExternalTable, ) -> Result> { Ok(Arc::new(TestTableProvider { @@ -188,6 +192,7 @@ impl TableProviderFactory for TestTableFactory { } /// TableProvider for testing purposes +#[derive(Debug)] pub struct TestTableProvider { /// URL of table files or folder pub url: String, @@ -213,7 +218,7 @@ impl TableProvider for TestTableProvider { async fn scan( &self, - _state: &SessionState, + _state: &dyn Session, _projection: Option<&Vec>, _filters: &[Expr], _limit: Option, @@ -285,6 +290,10 @@ impl DisplayAs for UnboundedExec { } impl ExecutionPlan for UnboundedExec { + fn name(&self) -> &'static str { + Self::static_name() + } + fn as_any(&self) -> &dyn Any { self } @@ -293,7 +302,7 @@ impl ExecutionPlan for UnboundedExec { &self.cache } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { vec![] } @@ -353,12 +362,102 @@ pub fn register_unbounded_file_with_ordering( schema: SchemaRef, file_path: &Path, table_name: &str, - file_sort_order: Vec>, + file_sort_order: Vec>, ) -> Result<()> { - let config = - StreamConfig::new_file(schema, file_path.into()).with_order(file_sort_order); + let source = FileStreamProvider::new_file(schema, file_path.into()); + let config = StreamConfig::new(Arc::new(source)).with_order(file_sort_order); // Register table: ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?; Ok(()) } + +struct BoundedStream { + limit: usize, + count: usize, + batch: RecordBatch, +} + +impl Stream for BoundedStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + if self.count >= self.limit { + return Poll::Ready(None); + } + self.count += 1; + Poll::Ready(Some(Ok(self.batch.clone()))) + } +} + +impl RecordBatchStream for BoundedStream { + fn schema(&self) -> SchemaRef { + self.batch.schema() + } +} + +/// Creates an bounded stream for testing purposes. +pub fn bounded_stream(batch: RecordBatch, limit: usize) -> SendableRecordBatchStream { + Box::pin(BoundedStream { + count: 0, + limit, + batch, + }) +} + +/// Describe the type of aggregate being tested +pub enum TestAggregate { + /// Testing COUNT(*) type aggregates + CountStar, + + /// Testing for COUNT(column) aggregate + ColumnA(Arc), +} + +impl TestAggregate { + /// Create a new COUNT(*) aggregate + pub fn new_count_star() -> Self { + Self::CountStar + } + + /// Create a new COUNT(column) aggregate + pub fn new_count_column(schema: &Arc) -> Self { + Self::ColumnA(schema.clone()) + } + + /// Return appropriate expr depending if COUNT is for col or table (*) + pub fn count_expr(&self, schema: &Schema) -> AggregateFunctionExpr { + AggregateExprBuilder::new(count_udaf(), vec![self.column()]) + .schema(Arc::new(schema.clone())) + .alias(self.column_name()) + .build() + .unwrap() + } + + /// what argument would this aggregate need in the plan? + fn column(&self) -> Arc { + match self { + Self::CountStar => expressions::lit(COUNT_STAR_EXPANSION), + Self::ColumnA(s) => expressions::col("a", s).unwrap(), + } + } + + /// What name would this aggregate produce in a plan? + pub fn column_name(&self) -> &'static str { + match self { + Self::CountStar => "COUNT(*)", + Self::ColumnA(_) => "COUNT(a)", + } + } + + /// What is the expected count? + pub fn expected_count(&self) -> i64 { + match self { + TestAggregate::CountStar => 3, + TestAggregate::ColumnA(_) => 2, + } + } +} diff --git a/datafusion/core/src/test_util/parquet.rs b/datafusion/core/src/test_util/parquet.rs index 8113d799a184..9f06ad9308ab 100644 --- a/datafusion/core/src/test_util/parquet.rs +++ b/datafusion/core/src/test_util/parquet.rs @@ -37,8 +37,7 @@ use crate::physical_plan::metrics::MetricsSet; use crate::physical_plan::ExecutionPlan; use crate::prelude::{Expr, SessionConfig, SessionContext}; -use datafusion_common::Statistics; - +use crate::datasource::physical_plan::parquet::ParquetExecBuilder; use object_store::path::Path; use object_store::ObjectMeta; use parquet::arrow::ArrowWriter; @@ -144,21 +143,15 @@ impl TestParquetFile { ctx: &SessionContext, maybe_filter: Option, ) -> Result> { - let scan_config = FileScanConfig { - object_store_url: self.object_store_url.clone(), - file_schema: self.schema.clone(), - file_groups: vec![vec![PartitionedFile { - object_meta: self.object_meta.clone(), - partition_values: vec![], - range: None, - extensions: None, - }]], - statistics: Statistics::new_unknown(&self.schema), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - }; + let scan_config = + FileScanConfig::new(self.object_store_url.clone(), self.schema.clone()) + .with_file(PartitionedFile { + object_meta: self.object_meta.clone(), + partition_values: vec![], + range: None, + statistics: None, + extensions: None, + }); let df_schema = self.schema.clone().to_dfschema_ref()?; @@ -168,25 +161,22 @@ impl TestParquetFile { let parquet_options = ctx.copied_table_options().parquet; if let Some(filter) = maybe_filter { let simplifier = ExprSimplifier::new(context); - let filter = simplifier.coerce(filter, df_schema.clone()).unwrap(); + let filter = simplifier.coerce(filter, &df_schema).unwrap(); let physical_filter_expr = create_physical_expr(&filter, &df_schema, &ExecutionProps::default())?; - let parquet_exec = Arc::new(ParquetExec::new( - scan_config, - Some(physical_filter_expr.clone()), - None, - parquet_options, - )); + + let parquet_exec = + ParquetExecBuilder::new_with_options(scan_config, parquet_options) + .with_predicate(physical_filter_expr.clone()) + .build_arc(); let exec = Arc::new(FilterExec::try_new(physical_filter_expr, parquet_exec)?); Ok(exec) } else { - Ok(Arc::new(ParquetExec::new( - scan_config, - None, - None, - parquet_options, - ))) + Ok( + ParquetExecBuilder::new_with_options(scan_config, parquet_options) + .build_arc(), + ) } } @@ -194,7 +184,7 @@ impl TestParquetFile { /// /// Recursively searches for ParquetExec and returns the metrics /// on the first one it finds - pub fn parquet_metrics(plan: Arc) -> Option { + pub fn parquet_metrics(plan: &Arc) -> Option { if let Some(parquet) = plan.as_any().downcast_ref::() { return parquet.metrics(); } diff --git a/datafusion/core/tests/core_integration.rs b/datafusion/core/tests/core_integration.rs index befefb1d7ec5..e0917e6cca19 100644 --- a/datafusion/core/tests/core_integration.rs +++ b/datafusion/core/tests/core_integration.rs @@ -24,6 +24,27 @@ mod dataframe; /// Run all tests that are found in the `macro_hygiene` directory mod macro_hygiene; +/// Run all tests that are found in the `execution` directory +mod execution; + +/// Run all tests that are found in the `expr_api` directory +mod expr_api; + +/// Run all tests that are found in the `fifo` directory +mod fifo; + +/// Run all tests that are found in the `memory_limit` directory +mod memory_limit; + +/// Run all tests that are found in the `custom_sources_cases` directory +mod custom_sources_cases; + +/// Run all tests that are found in the `optimizer` directory +mod optimizer; + +/// Run all tests that are found in the `physical_optimizer` directory +mod physical_optimizer; + #[cfg(test)] #[ctor::ctor] fn init() { diff --git a/datafusion/core/tests/custom_sources.rs b/datafusion/core/tests/custom_sources.rs deleted file mode 100644 index aa3f35e29541..000000000000 --- a/datafusion/core/tests/custom_sources.rs +++ /dev/null @@ -1,308 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::any::Any; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{Context, Poll}; - -use arrow::array::{Int32Array, Int64Array}; -use arrow::compute::kernels::aggregate; -use arrow::datatypes::{DataType, Field, Int32Type, Schema, SchemaRef}; -use arrow::record_batch::RecordBatch; -use datafusion::datasource::{TableProvider, TableType}; -use datafusion::error::Result; -use datafusion::execution::context::{SessionContext, SessionState, TaskContext}; -use datafusion::logical_expr::{ - col, Expr, LogicalPlan, LogicalPlanBuilder, TableScan, UNNAMED_TABLE, -}; -use datafusion::physical_plan::{ - collect, ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, - RecordBatchStream, SendableRecordBatchStream, Statistics, -}; -use datafusion::scalar::ScalarValue; -use datafusion_common::cast::as_primitive_array; -use datafusion_common::project_schema; -use datafusion_common::stats::Precision; -use datafusion_physical_expr::EquivalenceProperties; -use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; -use datafusion_physical_plan::{ExecutionMode, PlanProperties}; - -use async_trait::async_trait; -use futures::stream::Stream; - -/// Also run all tests that are found in the `custom_sources_cases` directory -mod custom_sources_cases; - -macro_rules! TEST_CUSTOM_SCHEMA_REF { - () => { - Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Int32, false), - Field::new("c2", DataType::Int32, false), - ])) - }; -} -macro_rules! TEST_CUSTOM_RECORD_BATCH { - () => { - RecordBatch::try_new( - TEST_CUSTOM_SCHEMA_REF!(), - vec![ - Arc::new(Int32Array::from(vec![1, 10, 10, 100])), - Arc::new(Int32Array::from(vec![2, 12, 12, 120])), - ], - ) - }; -} - -//--- Custom source dataframe tests ---// - -struct CustomTableProvider; -#[derive(Debug, Clone)] -struct CustomExecutionPlan { - projection: Option>, - cache: PlanProperties, -} - -impl CustomExecutionPlan { - fn new(projection: Option>) -> Self { - let schema = TEST_CUSTOM_SCHEMA_REF!(); - let schema = - project_schema(&schema, projection.as_ref()).expect("projected schema"); - let cache = Self::compute_properties(schema); - Self { projection, cache } - } - - /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. - fn compute_properties(schema: SchemaRef) -> PlanProperties { - let eq_properties = EquivalenceProperties::new(schema); - PlanProperties::new( - eq_properties, - // Output Partitioning - Partitioning::UnknownPartitioning(1), - ExecutionMode::Bounded, - ) - } -} - -struct TestCustomRecordBatchStream { - /// the nb of batches of TEST_CUSTOM_RECORD_BATCH generated - nb_batch: i32, -} - -impl RecordBatchStream for TestCustomRecordBatchStream { - fn schema(&self) -> SchemaRef { - TEST_CUSTOM_SCHEMA_REF!() - } -} - -impl Stream for TestCustomRecordBatchStream { - type Item = Result; - - fn poll_next( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - if self.nb_batch > 0 { - self.get_mut().nb_batch -= 1; - Poll::Ready(Some(TEST_CUSTOM_RECORD_BATCH!().map_err(Into::into))) - } else { - Poll::Ready(None) - } - } -} - -impl DisplayAs for CustomExecutionPlan { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - match t { - DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, "CustomExecutionPlan: projection={:#?}", self.projection) - } - } - } -} - -impl ExecutionPlan for CustomExecutionPlan { - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { - &self.cache - } - - fn children(&self) -> Vec> { - vec![] - } - - fn with_new_children( - self: Arc, - _: Vec>, - ) -> Result> { - Ok(self) - } - - fn execute( - &self, - _partition: usize, - _context: Arc, - ) -> Result { - Ok(Box::pin(TestCustomRecordBatchStream { nb_batch: 1 })) - } - - fn statistics(&self) -> Result { - let batch = TEST_CUSTOM_RECORD_BATCH!().unwrap(); - Ok(Statistics { - num_rows: Precision::Exact(batch.num_rows()), - total_byte_size: Precision::Absent, - column_statistics: self - .projection - .clone() - .unwrap_or_else(|| (0..batch.columns().len()).collect()) - .iter() - .map(|i| ColumnStatistics { - null_count: Precision::Exact(batch.column(*i).null_count()), - min_value: Precision::Exact(ScalarValue::Int32(aggregate::min( - as_primitive_array::(batch.column(*i)).unwrap(), - ))), - max_value: Precision::Exact(ScalarValue::Int32(aggregate::max( - as_primitive_array::(batch.column(*i)).unwrap(), - ))), - ..Default::default() - }) - .collect(), - }) - } -} - -#[async_trait] -impl TableProvider for CustomTableProvider { - fn as_any(&self) -> &dyn Any { - self - } - - fn schema(&self) -> SchemaRef { - TEST_CUSTOM_SCHEMA_REF!() - } - - fn table_type(&self) -> TableType { - TableType::Base - } - - async fn scan( - &self, - _state: &SessionState, - projection: Option<&Vec>, - _filters: &[Expr], - _limit: Option, - ) -> Result> { - Ok(Arc::new(CustomExecutionPlan::new(projection.cloned()))) - } -} - -#[tokio::test] -async fn custom_source_dataframe() -> Result<()> { - let ctx = SessionContext::new(); - - let table = ctx.read_table(Arc::new(CustomTableProvider))?; - let (state, plan) = table.into_parts(); - let logical_plan = LogicalPlanBuilder::from(plan) - .project(vec![col("c2")])? - .build()?; - - let optimized_plan = state.optimize(&logical_plan)?; - match &optimized_plan { - LogicalPlan::TableScan(TableScan { - source, - projected_schema, - .. - }) => { - assert_eq!(source.schema().fields().len(), 2); - assert_eq!(projected_schema.fields().len(), 1); - } - _ => panic!("input to projection should be TableScan"), - } - - let expected = format!("TableScan: {UNNAMED_TABLE} projection=[c2]"); - assert_eq!(format!("{optimized_plan:?}"), expected); - - let physical_plan = state.create_physical_plan(&optimized_plan).await?; - - assert_eq!(1, physical_plan.schema().fields().len()); - assert_eq!("c2", physical_plan.schema().field(0).name().as_str()); - - let batches = collect(physical_plan, state.task_ctx()).await?; - let origin_rec_batch = TEST_CUSTOM_RECORD_BATCH!()?; - assert_eq!(1, batches.len()); - assert_eq!(2, batches[0].num_columns()); - assert_eq!(origin_rec_batch.num_rows(), batches[0].num_rows()); - - Ok(()) -} - -#[tokio::test] -async fn optimizers_catch_all_statistics() { - let ctx = SessionContext::new(); - ctx.register_table("test", Arc::new(CustomTableProvider)) - .unwrap(); - - let df = ctx - .sql("SELECT count(*), min(c1), max(c1) from test") - .await - .unwrap(); - - let physical_plan = df.create_physical_plan().await.unwrap(); - - // when the optimization kicks in, the source is replaced by an PlaceholderRowExec - assert!( - contains_place_holder_exec(Arc::clone(&physical_plan)), - "Expected aggregate_statistics optimizations missing: {physical_plan:?}" - ); - - let expected = RecordBatch::try_new( - Arc::new(Schema::new(vec![ - Field::new("COUNT(*)", DataType::Int64, false), - Field::new("MIN(test.c1)", DataType::Int32, false), - Field::new("MAX(test.c1)", DataType::Int32, false), - ])), - vec![ - Arc::new(Int64Array::from(vec![4])), - Arc::new(Int32Array::from(vec![1])), - Arc::new(Int32Array::from(vec![100])), - ], - ) - .unwrap(); - - let task_ctx = ctx.task_ctx(); - let actual = collect(physical_plan, task_ctx).await.unwrap(); - - assert_eq!(actual.len(), 1); - assert_eq!(format!("{:?}", actual[0]), format!("{expected:?}")); -} - -fn contains_place_holder_exec(plan: Arc) -> bool { - if plan.as_any().is::() { - true - } else if plan.children().len() != 1 { - false - } else { - contains_place_holder_exec(Arc::clone(&plan.children()[0])) - } -} diff --git a/datafusion/core/tests/custom_sources_cases/mod.rs b/datafusion/core/tests/custom_sources_cases/mod.rs index d5367c77d2b9..e1bd14105e23 100644 --- a/datafusion/core/tests/custom_sources_cases/mod.rs +++ b/datafusion/core/tests/custom_sources_cases/mod.rs @@ -15,5 +15,301 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use arrow::array::{Int32Array, Int64Array}; +use arrow::compute::kernels::aggregate; +use arrow::datatypes::{DataType, Field, Int32Type, Schema, SchemaRef}; +use arrow::record_batch::RecordBatch; +use datafusion::datasource::{TableProvider, TableType}; +use datafusion::error::Result; +use datafusion::execution::context::{SessionContext, TaskContext}; +use datafusion::logical_expr::{ + col, Expr, LogicalPlan, LogicalPlanBuilder, TableScan, UNNAMED_TABLE, +}; +use datafusion::physical_plan::{ + collect, ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, + RecordBatchStream, SendableRecordBatchStream, Statistics, +}; +use datafusion::scalar::ScalarValue; +use datafusion_common::cast::as_primitive_array; +use datafusion_common::project_schema; +use datafusion_common::stats::Precision; +use datafusion_physical_expr::EquivalenceProperties; +use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; +use datafusion_physical_plan::{ExecutionMode, PlanProperties}; + +use async_trait::async_trait; +use datafusion_catalog::Session; +use futures::stream::Stream; + mod provider_filter_pushdown; mod statistics; + +macro_rules! TEST_CUSTOM_SCHEMA_REF { + () => { + Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + ])) + }; +} +macro_rules! TEST_CUSTOM_RECORD_BATCH { + () => { + RecordBatch::try_new( + TEST_CUSTOM_SCHEMA_REF!(), + vec![ + Arc::new(Int32Array::from(vec![1, 10, 10, 100])), + Arc::new(Int32Array::from(vec![2, 12, 12, 120])), + ], + ) + }; +} + +//--- Custom source dataframe tests ---// + +#[derive(Debug)] +struct CustomTableProvider; + +#[derive(Debug, Clone)] +struct CustomExecutionPlan { + projection: Option>, + cache: PlanProperties, +} + +impl CustomExecutionPlan { + fn new(projection: Option>) -> Self { + let schema = TEST_CUSTOM_SCHEMA_REF!(); + let schema = + project_schema(&schema, projection.as_ref()).expect("projected schema"); + let cache = Self::compute_properties(schema); + Self { projection, cache } + } + + /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. + fn compute_properties(schema: SchemaRef) -> PlanProperties { + let eq_properties = EquivalenceProperties::new(schema); + PlanProperties::new( + eq_properties, + // Output Partitioning + Partitioning::UnknownPartitioning(1), + ExecutionMode::Bounded, + ) + } +} + +struct TestCustomRecordBatchStream { + /// the nb of batches of TEST_CUSTOM_RECORD_BATCH generated + nb_batch: i32, +} + +impl RecordBatchStream for TestCustomRecordBatchStream { + fn schema(&self) -> SchemaRef { + TEST_CUSTOM_SCHEMA_REF!() + } +} + +impl Stream for TestCustomRecordBatchStream { + type Item = Result; + + fn poll_next( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + if self.nb_batch > 0 { + self.get_mut().nb_batch -= 1; + Poll::Ready(Some(TEST_CUSTOM_RECORD_BATCH!().map_err(Into::into))) + } else { + Poll::Ready(None) + } + } +} + +impl DisplayAs for CustomExecutionPlan { + fn fmt_as( + &self, + t: DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "CustomExecutionPlan: projection={:#?}", self.projection) + } + } + } +} + +impl ExecutionPlan for CustomExecutionPlan { + fn name(&self) -> &'static str { + Self::static_name() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> Result> { + Ok(self) + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + Ok(Box::pin(TestCustomRecordBatchStream { nb_batch: 1 })) + } + + fn statistics(&self) -> Result { + let batch = TEST_CUSTOM_RECORD_BATCH!().unwrap(); + Ok(Statistics { + num_rows: Precision::Exact(batch.num_rows()), + total_byte_size: Precision::Absent, + column_statistics: self + .projection + .clone() + .unwrap_or_else(|| (0..batch.columns().len()).collect()) + .iter() + .map(|i| ColumnStatistics { + null_count: Precision::Exact(batch.column(*i).null_count()), + min_value: Precision::Exact(ScalarValue::Int32(aggregate::min( + as_primitive_array::(batch.column(*i)).unwrap(), + ))), + max_value: Precision::Exact(ScalarValue::Int32(aggregate::max( + as_primitive_array::(batch.column(*i)).unwrap(), + ))), + ..Default::default() + }) + .collect(), + }) + } +} + +#[async_trait] +impl TableProvider for CustomTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + TEST_CUSTOM_SCHEMA_REF!() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + Ok(Arc::new(CustomExecutionPlan::new(projection.cloned()))) + } +} + +#[tokio::test] +async fn custom_source_dataframe() -> Result<()> { + let ctx = SessionContext::new(); + + let table = ctx.read_table(Arc::new(CustomTableProvider))?; + let (state, plan) = table.into_parts(); + let logical_plan = LogicalPlanBuilder::from(plan) + .project(vec![col("c2")])? + .build()?; + + let optimized_plan = state.optimize(&logical_plan)?; + match &optimized_plan { + LogicalPlan::TableScan(TableScan { + source, + projected_schema, + .. + }) => { + assert_eq!(source.schema().fields().len(), 2); + assert_eq!(projected_schema.fields().len(), 1); + } + _ => panic!("input to projection should be TableScan"), + } + + let expected = format!("TableScan: {UNNAMED_TABLE} projection=[c2]"); + assert_eq!(format!("{optimized_plan}"), expected); + + let physical_plan = state.create_physical_plan(&optimized_plan).await?; + + assert_eq!(1, physical_plan.schema().fields().len()); + assert_eq!("c2", physical_plan.schema().field(0).name().as_str()); + + let batches = collect(physical_plan, state.task_ctx()).await?; + let origin_rec_batch = TEST_CUSTOM_RECORD_BATCH!()?; + assert_eq!(1, batches.len()); + assert_eq!(2, batches[0].num_columns()); + assert_eq!(origin_rec_batch.num_rows(), batches[0].num_rows()); + + Ok(()) +} + +#[tokio::test] +async fn optimizers_catch_all_statistics() { + let ctx = SessionContext::new(); + ctx.register_table("test", Arc::new(CustomTableProvider)) + .unwrap(); + + let df = ctx + .sql("SELECT count(*), min(c1), max(c1) from test") + .await + .unwrap(); + + let physical_plan = df.create_physical_plan().await.unwrap(); + + // when the optimization kicks in, the source is replaced by an PlaceholderRowExec + assert!( + contains_place_holder_exec(Arc::clone(&physical_plan)), + "Expected aggregate_statistics optimizations missing: {physical_plan:?}" + ); + + let expected = RecordBatch::try_new( + Arc::new(Schema::new(vec![ + Field::new("count(*)", DataType::Int64, false), + Field::new("min(test.c1)", DataType::Int32, false), + Field::new("max(test.c1)", DataType::Int32, false), + ])), + vec![ + Arc::new(Int64Array::from(vec![4])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![100])), + ], + ) + .unwrap(); + + let task_ctx = ctx.task_ctx(); + let actual = collect(physical_plan, task_ctx).await.unwrap(); + + assert_eq!(actual.len(), 1); + assert_eq!(format!("{:?}", actual[0]), format!("{expected:?}")); +} + +fn contains_place_holder_exec(plan: Arc) -> bool { + if plan.as_any().is::() { + true + } else if plan.children().len() != 1 { + false + } else { + contains_place_holder_exec(Arc::clone(plan.children()[0])) + } +} diff --git a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs index 4579fe806d6f..09f7265d639a 100644 --- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs +++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs @@ -21,9 +21,10 @@ use std::sync::Arc; use arrow::array::{Int32Builder, Int64Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use datafusion::datasource::provider::{TableProvider, TableType}; +use datafusion::catalog::TableProvider; +use datafusion::datasource::provider::TableType; use datafusion::error::Result; -use datafusion::execution::context::{SessionState, TaskContext}; +use datafusion::execution::context::TaskContext; use datafusion::logical_expr::TableProviderFilterPushDown; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{ @@ -35,9 +36,11 @@ use datafusion::scalar::ScalarValue; use datafusion_common::cast::as_primitive_array; use datafusion_common::{internal_err, not_impl_err}; use datafusion_expr::expr::{BinaryExpr, Cast}; +use datafusion_functions_aggregate::expr_fn::count; use datafusion_physical_expr::EquivalenceProperties; use async_trait::async_trait; +use datafusion_catalog::Session; fn create_batch(value: i32, num_rows: usize) -> Result { let mut builder = Int32Builder::with_capacity(num_rows); @@ -93,6 +96,10 @@ impl DisplayAs for CustomPlan { } impl ExecutionPlan for CustomPlan { + fn name(&self) -> &'static str { + Self::static_name() + } + fn as_any(&self) -> &dyn std::any::Any { self } @@ -101,7 +108,7 @@ impl ExecutionPlan for CustomPlan { &self.cache } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { vec![] } @@ -135,7 +142,7 @@ impl ExecutionPlan for CustomPlan { } } -#[derive(Clone)] +#[derive(Clone, Debug)] struct CustomProvider { zero_batch: RecordBatch, one_batch: RecordBatch, @@ -157,7 +164,7 @@ impl TableProvider for CustomProvider { async fn scan( &self, - _state: &SessionState, + _state: &dyn Session, projection: Option<&Vec>, filters: &[Expr], _: Option, diff --git a/datafusion/core/tests/custom_sources_cases/statistics.rs b/datafusion/core/tests/custom_sources_cases/statistics.rs index 85ac47dc97fc..41d182a3767b 100644 --- a/datafusion/core/tests/custom_sources_cases/statistics.rs +++ b/datafusion/core/tests/custom_sources_cases/statistics.rs @@ -20,7 +20,7 @@ use std::{any::Any, sync::Arc}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use datafusion::execution::context::{SessionState, TaskContext}; +use datafusion::execution::context::TaskContext; use datafusion::{ datasource::{TableProvider, TableType}, error::Result, @@ -36,6 +36,7 @@ use datafusion_common::{project_schema, stats::Precision}; use datafusion_physical_expr::EquivalenceProperties; use async_trait::async_trait; +use datafusion_catalog::Session; /// This is a testing structure for statistics /// It will act both as a table provider and execution plan @@ -89,7 +90,7 @@ impl TableProvider for StatisticsValidation { async fn scan( &self, - _state: &SessionState, + _state: &dyn Session, projection: Option<&Vec>, filters: &[Expr], // limit is ignored because it is not mandatory for a `TableProvider` to honor it @@ -145,6 +146,10 @@ impl DisplayAs for StatisticsValidation { } impl ExecutionPlan for StatisticsValidation { + fn name(&self) -> &'static str { + Self::static_name() + } + fn as_any(&self) -> &dyn Any { self } @@ -153,7 +158,7 @@ impl ExecutionPlan for StatisticsValidation { &self.cache } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { vec![] } diff --git a/datafusion/core/tests/data/cr_terminator.csv b/datafusion/core/tests/data/cr_terminator.csv new file mode 100644 index 000000000000..f2a5d09a4c19 --- /dev/null +++ b/datafusion/core/tests/data/cr_terminator.csv @@ -0,0 +1 @@ +c1,c2 id0,value0 id1,value1 id2,value2 id3,value3 \ No newline at end of file diff --git a/datafusion/core/tests/data/double_quote.csv b/datafusion/core/tests/data/double_quote.csv new file mode 100644 index 000000000000..95a6f0c4077a --- /dev/null +++ b/datafusion/core/tests/data/double_quote.csv @@ -0,0 +1,5 @@ +c1,c2 +id0,"""value0""" +id1,"""value1""" +id2,"""value2""" +id3,"""value3""" diff --git a/datafusion/core/tests/data/example_long.csv b/datafusion/core/tests/data/example_long.csv new file mode 100644 index 000000000000..83d4cdde1ce1 --- /dev/null +++ b/datafusion/core/tests/data/example_long.csv @@ -0,0 +1,4 @@ +a,b,c +1,2,3 +4,5,6 +7,8,9 \ No newline at end of file diff --git a/datafusion/core/tests/data/newlines_in_values.csv b/datafusion/core/tests/data/newlines_in_values.csv new file mode 100644 index 000000000000..de0cdb94a5d4 --- /dev/null +++ b/datafusion/core/tests/data/newlines_in_values.csv @@ -0,0 +1,13 @@ +id,message +1,"hello +world" +2,"something +else" +3," +many +lines +make +good test +" +4,unquoted +value,end diff --git a/datafusion/core/tests/data/newlines_in_values_cr_terminator.csv b/datafusion/core/tests/data/newlines_in_values_cr_terminator.csv new file mode 100644 index 000000000000..2f6557d60ec5 --- /dev/null +++ b/datafusion/core/tests/data/newlines_in_values_cr_terminator.csv @@ -0,0 +1 @@ +id,message 1,"hello world" 2,"something else" 3," many lines make good test " 4,unquoted value,end \ No newline at end of file diff --git a/datafusion/core/tests/data/unnest.json b/datafusion/core/tests/data/unnest.json new file mode 100644 index 000000000000..5999171c2886 --- /dev/null +++ b/datafusion/core/tests/data/unnest.json @@ -0,0 +1,2 @@ +{"a":1, "b":[2.0, 1.3, -6.1], "c":[false, true],"d":{"e":1,"f":2}} +{"a":2, "b":[3.0, 2.3, -7.1], "c":[false, true]} diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index 7806461bb1ac..1bd90fce839d 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -30,9 +30,11 @@ use datafusion::error::Result; use datafusion::prelude::*; use datafusion::assert_batches_eq; -use datafusion_common::DFSchema; +use datafusion_common::{DFSchema, ScalarValue}; use datafusion_expr::expr::Alias; use datafusion_expr::ExprSchemable; +use datafusion_functions_aggregate::expr_fn::{approx_median, approx_percentile_cont}; +use datafusion_functions_nested::map::map; fn test_schema() -> SchemaRef { Arc::new(Schema::new(vec![ @@ -161,13 +163,188 @@ async fn test_fn_btrim_with_chars() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_fn_nullif() -> Result<()> { + let expr = nullif(col("a"), lit("abcDEF")); + + let expected = [ + "+-------------------------------+", + "| nullif(test.a,Utf8(\"abcDEF\")) |", + "+-------------------------------+", + "| |", + "| abc123 |", + "| CBAdef |", + "| 123AbcDef |", + "+-------------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_arrow_cast() -> Result<()> { + let expr = arrow_typeof(arrow_cast(col("b"), lit("Float64"))); + + let expected = [ + "+--------------------------------------------------+", + "| arrow_typeof(arrow_cast(test.b,Utf8(\"Float64\"))) |", + "+--------------------------------------------------+", + "| Float64 |", + "| Float64 |", + "| Float64 |", + "| Float64 |", + "+--------------------------------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_nvl() -> Result<()> { + let lit_null = lit(ScalarValue::Utf8(None)); + // nvl(CASE WHEN a = 'abcDEF' THEN NULL ELSE a END, 'TURNED_NULL') + let expr = nvl( + when(col("a").eq(lit("abcDEF")), lit_null) + .otherwise(col("a")) + .unwrap(), + lit("TURNED_NULL"), + ) + .alias("nvl_expr"); + + let expected = [ + "+-------------+", + "| nvl_expr |", + "+-------------+", + "| TURNED_NULL |", + "| abc123 |", + "| CBAdef |", + "| 123AbcDef |", + "+-------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} +#[tokio::test] +async fn test_nvl2() -> Result<()> { + let lit_null = lit(ScalarValue::Utf8(None)); + // nvl2(CASE WHEN a = 'abcDEF' THEN NULL ELSE a END, 'NON_NUll', 'TURNED_NULL') + let expr = nvl2( + when(col("a").eq(lit("abcDEF")), lit_null) + .otherwise(col("a")) + .unwrap(), + lit("NON_NULL"), + lit("TURNED_NULL"), + ) + .alias("nvl2_expr"); + + let expected = [ + "+-------------+", + "| nvl2_expr |", + "+-------------+", + "| TURNED_NULL |", + "| NON_NULL |", + "| NON_NULL |", + "| NON_NULL |", + "+-------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} +#[tokio::test] +async fn test_fn_arrow_typeof() -> Result<()> { + let expr = arrow_typeof(col("l")); + + let expected = [ + "+------------------------------------------------------------------------------------------------------------------+", + "| arrow_typeof(test.l) |", + "+------------------------------------------------------------------------------------------------------------------+", + "| List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) |", + "| List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) |", + "| List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) |", + "| List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) |", + "+------------------------------------------------------------------------------------------------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_struct() -> Result<()> { + let expr = r#struct(vec![col("a"), col("b")]); + + let expected = [ + "+--------------------------+", + "| struct(test.a,test.b) |", + "+--------------------------+", + "| {c0: abcDEF, c1: 1} |", + "| {c0: abc123, c1: 10} |", + "| {c0: CBAdef, c1: 10} |", + "| {c0: 123AbcDef, c1: 100} |", + "+--------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_named_struct() -> Result<()> { + let expr = named_struct(vec![lit("column_a"), col("a"), lit("column_b"), col("b")]); + + let expected = [ + "+---------------------------------------------------------------+", + "| named_struct(Utf8(\"column_a\"),test.a,Utf8(\"column_b\"),test.b) |", + "+---------------------------------------------------------------+", + "| {column_a: abcDEF, column_b: 1} |", + "| {column_a: abc123, column_b: 10} |", + "| {column_a: CBAdef, column_b: 10} |", + "| {column_a: 123AbcDef, column_b: 100} |", + "+---------------------------------------------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + +#[tokio::test] +async fn test_fn_coalesce() -> Result<()> { + let expr = coalesce(vec![lit(ScalarValue::Utf8(None)), lit("ab")]); + + let expected = [ + "+---------------------------------+", + "| coalesce(Utf8(NULL),Utf8(\"ab\")) |", + "+---------------------------------+", + "| ab |", + "| ab |", + "| ab |", + "| ab |", + "+---------------------------------+", + ]; + + assert_fn_batches!(expr, expected); + + Ok(()) +} + #[tokio::test] async fn test_fn_approx_median() -> Result<()> { let expr = approx_median(col("b")); let expected = [ "+-----------------------+", - "| APPROX_MEDIAN(test.b) |", + "| approx_median(test.b) |", "+-----------------------+", "| 10 |", "+-----------------------+", @@ -183,11 +360,11 @@ async fn test_fn_approx_median() -> Result<()> { #[tokio::test] async fn test_fn_approx_percentile_cont() -> Result<()> { - let expr = approx_percentile_cont(col("b"), lit(0.5)); + let expr = approx_percentile_cont(col("b"), lit(0.5), None); let expected = [ "+---------------------------------------------+", - "| APPROX_PERCENTILE_CONT(test.b,Float64(0.5)) |", + "| approx_percentile_cont(test.b,Float64(0.5)) |", "+---------------------------------------------+", "| 10 |", "+---------------------------------------------+", @@ -204,11 +381,11 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { None::<&str>, "arg_2".to_string(), )); - let expr = approx_percentile_cont(col("b"), alias_expr); + let expr = approx_percentile_cont(col("b"), alias_expr, None); let df = create_test_table().await?; let expected = [ "+--------------------------------------+", - "| APPROX_PERCENTILE_CONT(test.b,arg_2) |", + "| approx_percentile_cont(test.b,arg_2) |", "+--------------------------------------+", "| 10 |", "+--------------------------------------+", @@ -217,6 +394,21 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { assert_batches_eq!(expected, &batches); + // with number of centroids set + let expr = approx_percentile_cont(col("b"), lit(0.5), Some(lit(2))); + let expected = [ + "+------------------------------------------------------+", + "| approx_percentile_cont(test.b,Float64(0.5),Int32(2)) |", + "+------------------------------------------------------+", + "| 30 |", + "+------------------------------------------------------+", + ]; + + let df = create_test_table().await?; + let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?; + + assert_batches_eq!(expected, &batches); + Ok(()) } @@ -422,7 +614,7 @@ async fn test_fn_md5() -> Result<()> { #[tokio::test] #[cfg(feature = "unicode_expressions")] async fn test_fn_regexp_like() -> Result<()> { - let expr = regexp_like(col("a"), lit("[a-z]")); + let expr = regexp_like(col("a"), lit("[a-z]"), None); let expected = [ "+-----------------------------------+", @@ -437,13 +629,28 @@ async fn test_fn_regexp_like() -> Result<()> { assert_fn_batches!(expr, expected); + let expr = regexp_like(col("a"), lit("abc"), Some(lit("i"))); + + let expected = [ + "+-------------------------------------------+", + "| regexp_like(test.a,Utf8(\"abc\"),Utf8(\"i\")) |", + "+-------------------------------------------+", + "| true |", + "| true |", + "| false |", + "| true |", + "+-------------------------------------------+", + ]; + + assert_fn_batches!(expr, expected); + Ok(()) } #[tokio::test] #[cfg(feature = "unicode_expressions")] async fn test_fn_regexp_match() -> Result<()> { - let expr = regexp_match(col("a"), lit("[a-z]")); + let expr = regexp_match(col("a"), lit("[a-z]"), None); let expected = [ "+------------------------------------+", @@ -458,13 +665,28 @@ async fn test_fn_regexp_match() -> Result<()> { assert_fn_batches!(expr, expected); + let expr = regexp_match(col("a"), lit("[A-Z]"), Some(lit("i"))); + + let expected = [ + "+----------------------------------------------+", + "| regexp_match(test.a,Utf8(\"[A-Z]\"),Utf8(\"i\")) |", + "+----------------------------------------------+", + "| [a] |", + "| [a] |", + "| [C] |", + "| [A] |", + "+----------------------------------------------+", + ]; + + assert_fn_batches!(expr, expected); + Ok(()) } #[tokio::test] #[cfg(feature = "unicode_expressions")] async fn test_fn_regexp_replace() -> Result<()> { - let expr = regexp_replace(col("a"), lit("[a-z]"), lit("x"), lit("g")); + let expr = regexp_replace(col("a"), lit("[a-z]"), lit("x"), Some(lit("g"))); let expected = [ "+----------------------------------------------------------+", @@ -479,6 +701,21 @@ async fn test_fn_regexp_replace() -> Result<()> { assert_fn_batches!(expr, expected); + let expr = regexp_replace(col("a"), lit("[a-z]"), lit("x"), None); + + let expected = [ + "+------------------------------------------------+", + "| regexp_replace(test.a,Utf8(\"[a-z]\"),Utf8(\"x\")) |", + "+------------------------------------------------+", + "| xbcDEF |", + "| xbc123 |", + "| CBAxef |", + "| 123AxcDef |", + "+------------------------------------------------+", + ]; + + assert_fn_batches!(expr, expected); + Ok(()) } @@ -866,3 +1103,24 @@ async fn test_fn_array_to_string() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_fn_map() -> Result<()> { + let expr = map( + vec![lit("a"), lit("b"), lit("c")], + vec![lit(1), lit(2), lit(3)], + ); + let expected = [ + "+---------------------------------------------------------------------------------------+", + "| map(make_array(Utf8(\"a\"),Utf8(\"b\"),Utf8(\"c\")),make_array(Int32(1),Int32(2),Int32(3))) |", + "+---------------------------------------------------------------------------------------+", + "| {a: 1, b: 2, c: 3} |", + "| {a: 1, b: 2, c: 3} |", + "| {a: 1, b: 2, c: 3} |", + "| {a: 1, b: 2, c: 3} |", + "+---------------------------------------------------------------------------------------+", + ]; + assert_fn_batches!(expr, expected); + + Ok(()) +} diff --git a/datafusion/core/tests/dataframe/describe.rs b/datafusion/core/tests/dataframe/describe.rs index e82c06efd644..9321481efbd2 100644 --- a/datafusion/core/tests/dataframe/describe.rs +++ b/datafusion/core/tests/dataframe/describe.rs @@ -39,7 +39,7 @@ async fn describe() -> Result<()> { "| describe | id | bool_col | tinyint_col | smallint_col | int_col | bigint_col | float_col | double_col | date_string_col | string_col | timestamp_col | year | month |", "+------------+-------------------+----------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-----------------+------------+-------------------------+--------------------+-------------------+", "| count | 7300.0 | 7300 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300 | 7300 | 7300 | 7300.0 | 7300.0 |", - "| null_count | 7300.0 | 7300 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300.0 | 7300 | 7300 | 7300 | 7300.0 | 7300.0 |", + "| null_count | 0.0 | 0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0 | 0 | 0 | 0.0 | 0.0 |", "| mean | 3649.5 | null | 4.5 | 4.5 | 4.5 | 45.0 | 4.949999964237213 | 45.45 | null | null | null | 2009.5 | 6.526027397260274 |", "| std | 2107.472815166704 | null | 2.8724780750809518 | 2.8724780750809518 | 2.8724780750809518 | 28.724780750809533 | 3.1597258182544645 | 29.012028558317645 | null | null | null | 0.5000342500942125 | 3.44808750051728 |", "| min | 0.0 | null | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 01/01/09 | 0 | 2008-12-31T23:00:00 | 2009.0 | 1.0 |", @@ -69,7 +69,37 @@ async fn describe_boolean_binary() -> Result<()> { "| describe | a | b |", "+------------+------+------+", "| count | 1 | 1 |", - "| null_count | 1 | 1 |", + "| null_count | 0 | 0 |", + "| mean | null | null |", + "| std | null | null |", + "| min | a | null |", + "| max | a | null |", + "| median | null | null |", + "+------------+------+------+" + ]; + assert_batches_eq!(expected, &result); + Ok(()) +} + +#[tokio::test] +async fn describe_null() -> Result<()> { + let ctx = parquet_context().await; + + //add test case for only boolean boolean/binary column + let result = ctx + .sql("select 'a' as a, null as b") + .await? + .describe() + .await? + .collect() + .await?; + #[rustfmt::skip] + let expected = [ + "+------------+------+------+", + "| describe | a | b |", + "+------------+------+------+", + "| count | 1 | 0 |", + "| null_count | 0 | 1 |", "| mean | null | null |", "| std | null | null |", "| min | a | null |", diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index f565fba1db5b..439aa6147e9b 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -29,8 +29,13 @@ use arrow::{ }, record_batch::RecordBatch, }; -use arrow_array::Float32Array; -use arrow_schema::ArrowError; +use arrow_array::{ + Array, BooleanArray, DictionaryArray, Float32Array, Float64Array, Int8Array, + UnionArray, +}; +use arrow_buffer::ScalarBuffer; +use arrow_schema::{ArrowError, UnionFields, UnionMode}; +use datafusion_functions_aggregate::count::count_udaf; use object_store::local::LocalFileSystem; use std::fs; use std::sync::Arc; @@ -40,7 +45,8 @@ use url::Url; use datafusion::dataframe::{DataFrame, DataFrameWriteOptions}; use datafusion::datasource::MemTable; use datafusion::error::Result; -use datafusion::execution::context::{SessionContext, SessionState}; +use datafusion::execution::context::SessionContext; +use datafusion::execution::session_state::SessionStateBuilder; use datafusion::prelude::JoinType; use datafusion::prelude::{CsvReadOptions, ParquetReadOptions}; use datafusion::test_util::{parquet_test_data, populate_csv_partitions}; @@ -51,11 +57,11 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::var_provider::{VarProvider, VarType}; use datafusion_expr::{ - array_agg, avg, cast, col, count, exists, expr, in_subquery, lit, max, out_ref_col, - placeholder, scalar_subquery, sum, when, wildcard, AggregateFunction, Expr, - ExprSchemable, WindowFrame, WindowFrameBound, WindowFrameUnits, - WindowFunctionDefinition, + cast, col, exists, expr, in_subquery, lit, out_ref_col, placeholder, scalar_subquery, + when, wildcard, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, WindowFrameBound, + WindowFrameUnits, WindowFunctionDefinition, }; +use datafusion_functions_aggregate::expr_fn::{array_agg, avg, count, max, sum}; #[tokio::test] async fn test_count_wildcard_on_sort() -> Result<()> { @@ -96,7 +102,7 @@ async fn test_count_wildcard_on_where_in() -> Result<()> { // In the same SessionContext, AliasGenerator will increase subquery_alias id by 1 // https://github.com/apache/datafusion/blame/cf45eb9020092943b96653d70fafb143cc362e19/datafusion/optimizer/src/alias.rs#L40-L43 - // for compare difference betwwen sql and df logical plan, we need to create a new SessionContext here + // for compare difference between sql and df logical plan, we need to create a new SessionContext here let ctx = create_join_context()?; let df_results = ctx .table("t1") @@ -108,10 +114,7 @@ async fn test_count_wildcard_on_where_in() -> Result<()> { .await? .aggregate(vec![], vec![count(wildcard())])? .select(vec![count(wildcard())])? - .into_unoptimized_plan(), - // Usually, into_optimized_plan() should be used here, but due to - // https://github.com/apache/datafusion/issues/5771, - // subqueries in SQL cannot be optimized, resulting in differences in logical_plan. Therefore, into_unoptimized_plan() is temporarily used here. + .into_optimized_plan()?, ), ))? .select(vec![col("a"), col("b")])? @@ -169,7 +172,7 @@ async fn test_count_wildcard_on_window() -> Result<()> { let ctx = create_join_context()?; let sql_results = ctx - .sql("select COUNT(*) OVER(ORDER BY a DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from t1") + .sql("select count(*) OVER(ORDER BY a DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from t1") .await? .explain(false, false)? .collect() @@ -178,17 +181,17 @@ async fn test_count_wildcard_on_window() -> Result<()> { .table("t1") .await? .select(vec![Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], - vec![], - vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], - WindowFrame::new_bounds( - WindowFrameUnits::Range, - WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), - WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), - ), - None, - ))])? + )) + .order_by(vec![Sort::new(col("a"), false, true)]) + .window_frame(WindowFrame::new_bounds( + WindowFrameUnits::Range, + WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), + WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), + )) + .build() + .unwrap()])? .explain(false, false)? .collect() .await?; @@ -210,7 +213,7 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> { let sql_results = ctx .sql("select count(*) from t1") .await? - .select(vec![count(wildcard())])? + .select(vec![col("count(*)")])? .explain(false, false)? .collect() .await?; @@ -233,6 +236,7 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> { Ok(()) } + #[tokio::test] async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> { let ctx = create_join_context()?; @@ -348,7 +352,7 @@ async fn sort_on_unprojected_columns() -> Result<()> { .unwrap() .select(vec![col("a")]) .unwrap() - .sort(vec![Expr::Sort(Sort::new(Box::new(col("b")), false, true))]) + .sort(vec![Sort::new(col("b"), false, true)]) .unwrap(); let results = df.collect().await.unwrap(); @@ -392,7 +396,7 @@ async fn sort_on_distinct_columns() -> Result<()> { .unwrap() .distinct() .unwrap() - .sort(vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))]) + .sort(vec![Sort::new(col("a"), false, true)]) .unwrap(); let results = df.collect().await.unwrap(); @@ -431,7 +435,7 @@ async fn sort_on_distinct_unprojected_columns() -> Result<()> { .await? .select(vec![col("a")])? .distinct()? - .sort(vec![Expr::Sort(Sort::new(Box::new(col("b")), false, true))]) + .sort(vec![Sort::new(col("b"), false, true)]) .unwrap_err(); assert_eq!(err.strip_backtrace(), "Error during planning: For SELECT DISTINCT, ORDER BY expressions b must appear in select list"); Ok(()) @@ -595,15 +599,15 @@ async fn test_grouping_sets() -> Result<()> { .await? .aggregate(vec![grouping_set_expr], vec![count(col("a"))])? .sort(vec![ - Expr::Sort(Sort::new(Box::new(col("a")), false, true)), - Expr::Sort(Sort::new(Box::new(col("b")), false, true)), + Sort::new(col("a"), false, true), + Sort::new(col("b"), false, true), ])?; let results = df.collect().await?; let expected = vec![ "+-----------+-----+---------------+", - "| a | b | COUNT(test.a) |", + "| a | b | count(test.a) |", "+-----------+-----+---------------+", "| | 100 | 1 |", "| | 10 | 2 |", @@ -636,15 +640,15 @@ async fn test_grouping_sets_count() -> Result<()> { .await? .aggregate(vec![grouping_set_expr], vec![count(lit(1))])? .sort(vec![ - Expr::Sort(Sort::new(Box::new(col("c1")), false, true)), - Expr::Sort(Sort::new(Box::new(col("c2")), false, true)), + Sort::new(col("c1"), false, true), + Sort::new(col("c2"), false, true), ])?; let results = df.collect().await?; let expected = vec![ "+----+----+-----------------+", - "| c1 | c2 | COUNT(Int32(1)) |", + "| c1 | c2 | count(Int32(1)) |", "+----+----+-----------------+", "| | 5 | 14 |", "| | 4 | 23 |", @@ -683,8 +687,8 @@ async fn test_grouping_set_array_agg_with_overflow() -> Result<()> { ], )? .sort(vec![ - Expr::Sort(Sort::new(Box::new(col("c1")), false, true)), - Expr::Sort(Sort::new(Box::new(col("c2")), false, true)), + Sort::new(col("c1"), false, true), + Sort::new(col("c2"), false, true), ])?; let results = df.collect().await?; @@ -1231,11 +1235,11 @@ async fn unnest_aggregate_columns() -> Result<()> { .collect() .await?; let expected = [ - r#"+--------------------+"#, - r#"| COUNT(shapes.tags) |"#, - r#"+--------------------+"#, - r#"| 9 |"#, - r#"+--------------------+"#, + r#"+-------------+"#, + r#"| count(tags) |"#, + r#"+-------------+"#, + r#"| 9 |"#, + r#"+-------------+"#, ]; assert_batches_sorted_eq!(expected, &results); @@ -1384,8 +1388,8 @@ async fn unnest_with_redundant_columns() -> Result<()> { let optimized_plan = df.clone().into_optimized_plan()?; let expected = vec![ "Projection: shapes.shape_id [shape_id:UInt32]", - " Unnest: shape_id2 [shape_id:UInt32, shape_id2:UInt32;N]", - " Aggregate: groupBy=[[shapes.shape_id]], aggr=[[ARRAY_AGG(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]", + " Unnest: lists[shape_id2|depth=1] structs[] [shape_id:UInt32, shape_id2:UInt32;N]", + " Aggregate: groupBy=[[shapes.shape_id]], aggr=[[array_agg(shapes.shape_id) AS shape_id2]] [shape_id:UInt32, shape_id2:List(Field { name: \"item\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} });N]", " TableScan: shapes projection=[shape_id] [shape_id:UInt32]", ]; @@ -1427,9 +1431,7 @@ async fn unnest_analyze_metrics() -> Result<()> { .explain(false, true)? .collect() .await?; - let formatted = arrow::util::pretty::pretty_format_batches(&results) - .unwrap() - .to_string(); + let formatted = pretty_format_batches(&results).unwrap().to_string(); assert_contains!(&formatted, "elapsed_compute="); assert_contains!(&formatted, "input_batches=1"); assert_contains!(&formatted, "input_rows=5"); @@ -1541,7 +1543,11 @@ async fn unnest_non_nullable_list() -> Result<()> { async fn test_read_batches() -> Result<()> { let config = SessionConfig::new(); let runtime = Arc::new(RuntimeEnv::default()); - let state = SessionState::new_with_config_rt(config, runtime); + let state = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .build(); let ctx = SessionContext::new_with_state(state); let schema = Arc::new(Schema::new(vec![ @@ -1591,7 +1597,11 @@ async fn test_read_batches() -> Result<()> { async fn test_read_batches_empty() -> Result<()> { let config = SessionConfig::new(); let runtime = Arc::new(RuntimeEnv::default()); - let state = SessionState::new_with_config_rt(config, runtime); + let state = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() + .build(); let ctx = SessionContext::new_with_state(state); let batches = vec![]; @@ -1605,9 +1615,7 @@ async fn test_read_batches_empty() -> Result<()> { #[tokio::test] async fn consecutive_projection_same_schema() -> Result<()> { - let config = SessionConfig::new(); - let runtime = Arc::new(RuntimeEnv::default()); - let state = SessionState::new_with_config_rt(config, runtime); + let state = SessionStateBuilder::new().with_default_features().build(); let ctx = SessionContext::new_with_state(state); let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); @@ -1963,7 +1971,7 @@ async fn test_array_agg() -> Result<()> { let expected = [ "+-------------------------------------+", - "| ARRAY_AGG(test.a) |", + "| array_agg(test.a) |", "+-------------------------------------+", "| [abcDEF, abc123, CBAdef, 123AbcDef] |", "+-------------------------------------+", @@ -2070,7 +2078,7 @@ async fn write_partitioned_parquet_results() -> Result<()> { let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?); let local_url = Url::parse("file://local").unwrap(); - ctx.runtime_env().register_object_store(&local_url, local); + ctx.register_object_store(&local_url, local); // execute a simple query and write the results to parquet let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out/"; @@ -2140,7 +2148,7 @@ async fn write_parquet_results() -> Result<()> { // register a local file system object store for /tmp directory let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?); let local_url = Url::parse("file://local").unwrap(); - ctx.runtime_env().register_object_store(&local_url, local); + ctx.register_object_store(&local_url, local); // execute a simple query and write the results to parquet let out_dir = tmp_dir.as_ref().to_str().unwrap().to_string() + "/out/"; @@ -2193,3 +2201,265 @@ async fn write_parquet_results() -> Result<()> { Ok(()) } + +fn union_fields() -> UnionFields { + [ + (0, Arc::new(Field::new("A", DataType::Int32, true))), + (1, Arc::new(Field::new("B", DataType::Float64, true))), + (2, Arc::new(Field::new("C", DataType::Utf8, true))), + ] + .into_iter() + .collect() +} + +#[tokio::test] +async fn sparse_union_is_null() { + // union of [{A=1}, {A=}, {B=3.2}, {B=}, {C="a"}, {C=}] + let int_array = Int32Array::from(vec![Some(1), None, None, None, None, None]); + let float_array = Float64Array::from(vec![None, None, Some(3.2), None, None, None]); + let str_array = StringArray::from(vec![None, None, None, None, Some("a"), None]); + let type_ids = [0, 0, 1, 1, 2, 2].into_iter().collect::>(); + + let children = vec![ + Arc::new(int_array) as Arc, + Arc::new(float_array), + Arc::new(str_array), + ]; + + let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap(); + + let field = Field::new( + "my_union", + DataType::Union(union_fields(), UnionMode::Sparse), + true, + ); + let schema = Arc::new(Schema::new(vec![field])); + + let batch = RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap(); + + let ctx = SessionContext::new(); + + ctx.register_batch("union_batch", batch).unwrap(); + + let df = ctx.table("union_batch").await.unwrap(); + + // view_all + let expected = [ + "+----------+", + "| my_union |", + "+----------+", + "| {A=1} |", + "| {A=} |", + "| {B=3.2} |", + "| {B=} |", + "| {C=a} |", + "| {C=} |", + "+----------+", + ]; + assert_batches_sorted_eq!(expected, &df.clone().collect().await.unwrap()); + + // filter where is null + let result_df = df.clone().filter(col("my_union").is_null()).unwrap(); + let expected = [ + "+----------+", + "| my_union |", + "+----------+", + "| {A=} |", + "| {B=} |", + "| {C=} |", + "+----------+", + ]; + assert_batches_sorted_eq!(expected, &result_df.collect().await.unwrap()); + + // filter where is not null + let result_df = df.filter(col("my_union").is_not_null()).unwrap(); + let expected = [ + "+----------+", + "| my_union |", + "+----------+", + "| {A=1} |", + "| {B=3.2} |", + "| {C=a} |", + "+----------+", + ]; + assert_batches_sorted_eq!(expected, &result_df.collect().await.unwrap()); +} + +#[tokio::test] +async fn dense_union_is_null() { + // union of [{A=1}, null, {B=3.2}, {A=34}] + let int_array = Int32Array::from(vec![Some(1), None]); + let float_array = Float64Array::from(vec![Some(3.2), None]); + let str_array = StringArray::from(vec![Some("a"), None]); + let type_ids = [0, 0, 1, 1, 2, 2].into_iter().collect::>(); + let offsets = [0, 1, 0, 1, 0, 1] + .into_iter() + .collect::>(); + + let children = vec![ + Arc::new(int_array) as Arc, + Arc::new(float_array), + Arc::new(str_array), + ]; + + let array = + UnionArray::try_new(union_fields(), type_ids, Some(offsets), children).unwrap(); + + let field = Field::new( + "my_union", + DataType::Union(union_fields(), UnionMode::Dense), + true, + ); + let schema = Arc::new(Schema::new(vec![field])); + + let batch = RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap(); + + let ctx = SessionContext::new(); + + ctx.register_batch("union_batch", batch).unwrap(); + + let df = ctx.table("union_batch").await.unwrap(); + + // view_all + let expected = [ + "+----------+", + "| my_union |", + "+----------+", + "| {A=1} |", + "| {A=} |", + "| {B=3.2} |", + "| {B=} |", + "| {C=a} |", + "| {C=} |", + "+----------+", + ]; + assert_batches_sorted_eq!(expected, &df.clone().collect().await.unwrap()); + + // filter where is null + let result_df = df.clone().filter(col("my_union").is_null()).unwrap(); + let expected = [ + "+----------+", + "| my_union |", + "+----------+", + "| {A=} |", + "| {B=} |", + "| {C=} |", + "+----------+", + ]; + assert_batches_sorted_eq!(expected, &result_df.collect().await.unwrap()); + + // filter where is not null + let result_df = df.filter(col("my_union").is_not_null()).unwrap(); + let expected = [ + "+----------+", + "| my_union |", + "+----------+", + "| {A=1} |", + "| {B=3.2} |", + "| {C=a} |", + "+----------+", + ]; + assert_batches_sorted_eq!(expected, &result_df.collect().await.unwrap()); +} + +#[tokio::test] +async fn boolean_dictionary_as_filter() { + let values = vec![Some(true), Some(false), None, Some(true)]; + let keys = vec![0, 0, 1, 2, 1, 3, 1]; + let values_array = BooleanArray::from(values); + let keys_array = Int8Array::from(keys); + let array = + DictionaryArray::new(keys_array, Arc::new(values_array) as Arc); + let array = Arc::new(array); + + let field = Field::new( + "my_dict", + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Boolean)), + true, + ); + let schema = Arc::new(Schema::new(vec![field])); + + let batch = RecordBatch::try_new(schema, vec![array.clone()]).unwrap(); + + let ctx = SessionContext::new(); + + ctx.register_batch("dict_batch", batch).unwrap(); + + let df = ctx.table("dict_batch").await.unwrap(); + + // view_all + let expected = [ + "+---------+", + "| my_dict |", + "+---------+", + "| true |", + "| true |", + "| false |", + "| |", + "| false |", + "| true |", + "| false |", + "+---------+", + ]; + assert_batches_eq!(expected, &df.clone().collect().await.unwrap()); + + let result_df = df.clone().filter(col("my_dict")).unwrap(); + let expected = [ + "+---------+", + "| my_dict |", + "+---------+", + "| true |", + "| true |", + "| true |", + "+---------+", + ]; + assert_batches_eq!(expected, &result_df.collect().await.unwrap()); + + // test nested dictionary + let keys = vec![0, 2]; // 0 -> true, 2 -> false + let keys_array = Int8Array::from(keys); + let nested_array = DictionaryArray::new(keys_array, array); + + let field = Field::new( + "my_nested_dict", + DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Dictionary( + Box::new(DataType::Int8), + Box::new(DataType::Boolean), + )), + ), + true, + ); + + let schema = Arc::new(Schema::new(vec![field])); + + let batch = RecordBatch::try_new(schema, vec![Arc::new(nested_array)]).unwrap(); + + ctx.register_batch("nested_dict_batch", batch).unwrap(); + + let df = ctx.table("nested_dict_batch").await.unwrap(); + + // view_all + let expected = [ + "+----------------+", + "| my_nested_dict |", + "+----------------+", + "| true |", + "| false |", + "+----------------+", + ]; + + assert_batches_eq!(expected, &df.clone().collect().await.unwrap()); + + let result_df = df.clone().filter(col("my_nested_dict")).unwrap(); + let expected = [ + "+----------------+", + "| my_nested_dict |", + "+----------------+", + "| true |", + "+----------------+", + ]; + + assert_batches_eq!(expected, &result_df.collect().await.unwrap()); +} diff --git a/datafusion/core/tests/execution/logical_plan.rs b/datafusion/core/tests/execution/logical_plan.rs new file mode 100644 index 000000000000..168bf484e541 --- /dev/null +++ b/datafusion/core/tests/execution/logical_plan.rs @@ -0,0 +1,95 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::Int64Array; +use arrow_schema::{DataType, Field}; +use datafusion::execution::session_state::SessionStateBuilder; +use datafusion_common::{Column, DFSchema, Result, ScalarValue}; +use datafusion_execution::TaskContext; +use datafusion_expr::expr::AggregateFunction; +use datafusion_expr::logical_plan::{LogicalPlan, Values}; +use datafusion_expr::{Aggregate, AggregateUDF, Expr}; +use datafusion_functions_aggregate::count::Count; +use datafusion_physical_plan::collect; +use std::collections::HashMap; +use std::fmt::Debug; +use std::ops::Deref; +use std::sync::Arc; + +///! Logical plans need to provide stable semantics, as downstream projects +///! create them and depend on them. Test executable semantics of logical plans. + +#[tokio::test] +async fn count_only_nulls() -> Result<()> { + // Input: VALUES (NULL), (NULL), (NULL) AS _(col) + let input_schema = Arc::new(DFSchema::from_unqualified_fields( + vec![Field::new("col", DataType::Null, true)].into(), + HashMap::new(), + )?); + let input = Arc::new(LogicalPlan::Values(Values { + schema: input_schema, + values: vec![ + vec![Expr::Literal(ScalarValue::Null)], + vec![Expr::Literal(ScalarValue::Null)], + vec![Expr::Literal(ScalarValue::Null)], + ], + })); + let input_col_ref = Expr::Column(Column { + relation: None, + name: "col".to_string(), + }); + + // Aggregation: count(col) AS count + let aggregate = LogicalPlan::Aggregate(Aggregate::try_new( + input, + vec![], + vec![Expr::AggregateFunction(AggregateFunction { + func: Arc::new(AggregateUDF::new_from_impl(Count::new())), + args: vec![input_col_ref], + distinct: false, + filter: None, + order_by: None, + null_treatment: None, + })], + )?); + + // Execute and verify results + let session_state = SessionStateBuilder::new().build(); + let physical_plan = session_state.create_physical_plan(&aggregate).await?; + let result = + collect(physical_plan, Arc::new(TaskContext::from(&session_state))).await?; + + let result = only(result.as_slice()); + let result_schema = result.schema(); + let field = only(result_schema.fields().deref()); + let column = only(result.columns()); + + assert_eq!(field.data_type(), &DataType::Int64); // TODO should be UInt64 + assert_eq!(column.deref(), &Int64Array::from(vec![0])); + + Ok(()) +} + +fn only(elements: &[T]) -> &T +where + T: Debug, +{ + let [element] = elements else { + panic!("Expected exactly one element, got {:?}", elements); + }; + element +} diff --git a/datafusion/core/tests/execution/mod.rs b/datafusion/core/tests/execution/mod.rs new file mode 100644 index 000000000000..8169db1a4611 --- /dev/null +++ b/datafusion/core/tests/execution/mod.rs @@ -0,0 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod logical_plan; diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs new file mode 100644 index 000000000000..81a33361008f --- /dev/null +++ b/datafusion/core/tests/expr_api/mod.rs @@ -0,0 +1,388 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::util::pretty::{pretty_format_batches, pretty_format_columns}; +use arrow_array::builder::{ListBuilder, StringBuilder}; +use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray, StructArray}; +use arrow_schema::{DataType, Field}; +use datafusion::prelude::*; +use datafusion_common::{DFSchema, ScalarValue}; +use datafusion_expr::ExprFunctionExt; +use datafusion_functions::core::expr_ext::FieldAccessor; +use datafusion_functions_aggregate::first_last::first_value_udaf; +use datafusion_functions_aggregate::sum::sum_udaf; +use datafusion_functions_nested::expr_ext::{IndexAccessor, SliceAccessor}; +use sqlparser::ast::NullTreatment; +/// Tests of using and evaluating `Expr`s outside the context of a LogicalPlan +use std::sync::{Arc, OnceLock}; + +mod parse_sql_expr; +mod simplification; + +#[test] +fn test_octet_length() { + #[rustfmt::skip] + evaluate_expr_test( + octet_length(col("id")), + vec![ + "+------+", + "| expr |", + "+------+", + "| 1 |", + "| 1 |", + "| 1 |", + "+------+", + ], + ); +} + +#[test] +fn test_eq() { + // id = '2' + evaluate_expr_test( + col("id").eq(lit("2")), + vec![ + "+-------+", + "| expr |", + "+-------+", + "| false |", + "| true |", + "| false |", + "+-------+", + ], + ); +} + +#[test] +fn test_eq_with_coercion() { + // id = 2 (need to coerce the 2 to '2' to evaluate) + evaluate_expr_test( + col("id").eq(lit(2i32)), + vec![ + "+-------+", + "| expr |", + "+-------+", + "| false |", + "| true |", + "| false |", + "+-------+", + ], + ); +} + +#[test] +fn test_get_field() { + evaluate_expr_test( + col("props").field("a"), + vec![ + "+------------+", + "| expr |", + "+------------+", + "| 2021-02-01 |", + "| 2021-02-02 |", + "| 2021-02-03 |", + "+------------+", + ], + ); +} + +#[test] +fn test_get_field_null() { + #[rustfmt::skip] + evaluate_expr_test( + lit(ScalarValue::Null).field("a"), + vec![ + "+------+", + "| expr |", + "+------+", + "| |", + "+------+", + ], + ); +} + +#[test] +fn test_nested_get_field() { + evaluate_expr_test( + col("props") + .field("a") + .eq(lit("2021-02-02")) + .or(col("id").eq(lit(1))), + vec![ + "+-------+", + "| expr |", + "+-------+", + "| true |", + "| true |", + "| false |", + "+-------+", + ], + ); +} + +#[test] +fn test_list_index() { + #[rustfmt::skip] + evaluate_expr_test( + col("list").index(lit(1i64)), + vec![ + "+------+", + "| expr |", + "+------+", + "| one |", + "| two |", + "| five |", + "+------+", + ], + ); +} + +#[test] +fn test_list_range() { + evaluate_expr_test( + col("list").range(lit(1i64), lit(2i64)), + vec![ + "+--------------+", + "| expr |", + "+--------------+", + "| [one] |", + "| [two, three] |", + "| [five] |", + "+--------------+", + ], + ); +} + +#[tokio::test] +async fn test_aggregate_ext_order_by() { + let agg = first_value_udaf().call(vec![col("props")]); + + // ORDER BY id ASC + let agg_asc = agg + .clone() + .order_by(vec![col("id").sort(true, true)]) + .build() + .unwrap() + .alias("asc"); + + // ORDER BY id DESC + let agg_desc = agg + .order_by(vec![col("id").sort(false, true)]) + .build() + .unwrap() + .alias("desc"); + + evaluate_agg_test( + agg_asc, + vec![ + "+-----------------+", + "| asc |", + "+-----------------+", + "| {a: 2021-02-01} |", + "+-----------------+", + ], + ) + .await; + + evaluate_agg_test( + agg_desc, + vec![ + "+-----------------+", + "| desc |", + "+-----------------+", + "| {a: 2021-02-03} |", + "+-----------------+", + ], + ) + .await; +} + +#[tokio::test] +async fn test_aggregate_ext_filter() { + let agg = first_value_udaf() + .call(vec![col("i")]) + .order_by(vec![col("i").sort(true, true)]) + .filter(col("i").is_not_null()) + .build() + .unwrap() + .alias("val"); + + #[rustfmt::skip] + evaluate_agg_test( + agg, + vec![ + "+-----+", + "| val |", + "+-----+", + "| 5 |", + "+-----+", + ], + ) + .await; +} + +#[tokio::test] +async fn test_aggregate_ext_distinct() { + let agg = sum_udaf() + .call(vec![lit(5)]) + // distinct sum should be 5, not 15 + .distinct() + .build() + .unwrap() + .alias("distinct"); + + evaluate_agg_test( + agg, + vec![ + "+----------+", + "| distinct |", + "+----------+", + "| 5 |", + "+----------+", + ], + ) + .await; +} + +#[tokio::test] +async fn test_aggregate_ext_null_treatment() { + let agg = first_value_udaf() + .call(vec![col("i")]) + .order_by(vec![col("i").sort(true, true)]); + + let agg_respect = agg + .clone() + .null_treatment(NullTreatment::RespectNulls) + .build() + .unwrap() + .alias("respect"); + + let agg_ignore = agg + .null_treatment(NullTreatment::IgnoreNulls) + .build() + .unwrap() + .alias("ignore"); + + evaluate_agg_test( + agg_respect, + vec![ + "+---------+", + "| respect |", + "+---------+", + "| |", + "+---------+", + ], + ) + .await; + + evaluate_agg_test( + agg_ignore, + vec![ + "+--------+", + "| ignore |", + "+--------+", + "| 5 |", + "+--------+", + ], + ) + .await; +} + +/// Evaluates the specified expr as an aggregate and compares the result to the +/// expected result. +async fn evaluate_agg_test(expr: Expr, expected_lines: Vec<&str>) { + let batch = test_batch(); + + let ctx = SessionContext::new(); + let group_expr = vec![]; + let agg_expr = vec![expr]; + let result = ctx + .read_batch(batch) + .unwrap() + .aggregate(group_expr, agg_expr) + .unwrap() + .collect() + .await + .unwrap(); + + let result = pretty_format_batches(&result).unwrap().to_string(); + let actual_lines = result.lines().collect::>(); + + assert_eq!( + expected_lines, actual_lines, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected_lines, actual_lines + ); +} + +/// Converts the `Expr` to a `PhysicalExpr`, evaluates it against the provided +/// `RecordBatch` and compares the result to the expected result. +fn evaluate_expr_test(expr: Expr, expected_lines: Vec<&str>) { + let batch = test_batch(); + let df_schema = DFSchema::try_from(batch.schema()).unwrap(); + let physical_expr = SessionContext::new() + .create_physical_expr(expr, &df_schema) + .unwrap(); + + let result = physical_expr.evaluate(&batch).unwrap(); + let array = result.into_array(1).unwrap(); + let result = pretty_format_columns("expr", &[array]).unwrap().to_string(); + let actual_lines = result.lines().collect::>(); + + assert_eq!( + expected_lines, actual_lines, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected_lines, actual_lines + ); +} + +static TEST_BATCH: OnceLock = OnceLock::new(); + +fn test_batch() -> RecordBatch { + TEST_BATCH + .get_or_init(|| { + let string_array: ArrayRef = Arc::new(StringArray::from(vec!["1", "2", "3"])); + let int_array: ArrayRef = + Arc::new(Int64Array::from_iter(vec![Some(10), None, Some(5)])); + + // { a: "2021-02-01" } { a: "2021-02-02" } { a: "2021-02-03" } + let struct_array: ArrayRef = Arc::from(StructArray::from(vec![( + Arc::new(Field::new("a", DataType::Utf8, false)), + Arc::new(StringArray::from(vec![ + "2021-02-01", + "2021-02-02", + "2021-02-03", + ])) as _, + )])); + + // ["one"] ["two", "three", "four"] ["five"] + let mut builder = ListBuilder::new(StringBuilder::new()); + builder.append_value([Some("one")]); + builder.append_value([Some("two"), Some("three"), Some("four")]); + builder.append_value([Some("five")]); + let list_array: ArrayRef = Arc::new(builder.finish()); + + RecordBatch::try_from_iter(vec![ + ("id", string_array), + ("i", int_array), + ("props", struct_array), + ("list", list_array), + ]) + .unwrap() + }) + .clone() +} diff --git a/datafusion/core/tests/expr_api/parse_sql_expr.rs b/datafusion/core/tests/expr_api/parse_sql_expr.rs new file mode 100644 index 000000000000..cc049f0004d9 --- /dev/null +++ b/datafusion/core/tests/expr_api/parse_sql_expr.rs @@ -0,0 +1,107 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_schema::{DataType, Field, Schema}; +use datafusion::prelude::{CsvReadOptions, SessionContext}; +use datafusion_common::DFSchema; +use datafusion_common::{DFSchemaRef, Result, ToDFSchema}; +use datafusion_expr::col; +use datafusion_expr::lit; +use datafusion_expr::Expr; +use datafusion_sql::unparser::Unparser; +/// A schema like: +/// +/// a: Int32 (possibly with nulls) +/// b: Int32 +/// s: Float32 +fn schema() -> DFSchemaRef { + Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Float32, false), + ]) + .to_dfschema_ref() + .unwrap() +} + +#[tokio::test] +async fn round_trip_parse_sql_expr() -> Result<()> { + let tests = vec![ + "(a = 10)", + "((a = 10) AND (b <> 20))", + "((a = 10) OR (b <> 20))", + "(((a = 10) AND (b <> 20)) OR (c = a))", + "((a = 10) AND b IN (20, 30))", + "((a = 10) AND b NOT IN (20, 30))", + "sum(a)", + "(sum(a) + 1)", + "(min(a) + max(b))", + "(min(a) + (max(b) * sum(c)))", + "(min(a) + ((max(b) * sum(c)) / 10))", + ]; + + for test in tests { + round_trip_session_context(test)?; + round_trip_dataframe(test).await?; + } + + Ok(()) +} + +fn round_trip_session_context(sql: &str) -> Result<()> { + let ctx = SessionContext::new(); + let df_schema = schema(); + let expr = ctx.parse_sql_expr(sql, &df_schema)?; + let sql2 = unparse_sql_expr(&expr)?; + assert_eq!(sql, sql2); + + Ok(()) +} + +async fn round_trip_dataframe(sql: &str) -> Result<()> { + let ctx = SessionContext::new(); + let df = ctx + .read_csv( + &"tests/data/example.csv".to_string(), + CsvReadOptions::default(), + ) + .await?; + let expr = df.parse_sql_expr(sql)?; + let sql2 = unparse_sql_expr(&expr)?; + assert_eq!(sql, sql2); + + Ok(()) +} + +#[tokio::test] +async fn roundtrip_qualified_schema() -> Result<()> { + let sql = "a < 5 OR a = 8"; + let expr = col("t.a").lt(lit(5_i64)).or(col("t.a").eq(lit(8_i64))); + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let df_schema = DFSchema::try_from_qualified_schema("t", &schema).unwrap(); + let ctx = SessionContext::new(); + let parsed_expr = ctx.parse_sql_expr(sql, &df_schema)?; + assert_eq!(parsed_expr, expr); + Ok(()) +} + +fn unparse_sql_expr(expr: &Expr) -> Result { + let unparser = Unparser::default(); + + let round_trip_sql = unparser.expr_to_sql(expr)?.to_string(); + Ok(round_trip_sql) +} diff --git a/datafusion/core/tests/simplification.rs b/datafusion/core/tests/expr_api/simplification.rs similarity index 89% rename from datafusion/core/tests/simplification.rs rename to datafusion/core/tests/expr_api/simplification.rs index 880c294bb7aa..68785b7a5a45 100644 --- a/datafusion/core/tests/simplification.rs +++ b/datafusion/core/tests/expr_api/simplification.rs @@ -19,6 +19,7 @@ use arrow::datatypes::{DataType, Field, Schema}; use arrow_array::{ArrayRef, Int32Array}; +use arrow_buffer::IntervalDayTime; use chrono::{DateTime, TimeZone, Utc}; use datafusion::{error::Result, execution::context::ExecutionProps, prelude::*}; use datafusion_common::cast::as_int32_array; @@ -28,10 +29,10 @@ use datafusion_expr::expr::ScalarFunction; use datafusion_expr::logical_plan::builder::table_scan_with_filters; use datafusion_expr::simplify::SimplifyInfo; use datafusion_expr::{ - expr, table_scan, Cast, ColumnarValue, ExprSchemable, LogicalPlan, - LogicalPlanBuilder, ScalarUDF, Volatility, + table_scan, Cast, ColumnarValue, ExprSchemable, LogicalPlan, LogicalPlanBuilder, + ScalarUDF, Volatility, }; -use datafusion_functions::{math, string}; +use datafusion_functions::math; use datafusion_optimizer::optimizer::Optimizer; use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyExpressions}; use datafusion_optimizer::{OptimizerContext, OptimizerRule}; @@ -118,7 +119,7 @@ fn get_optimized_plan_formatted(plan: LogicalPlan, date_time: &DateTime) -> let optimizer = Optimizer::with_rules(vec![Arc::new(SimplifyExpressions::new())]); let optimized_plan = optimizer.optimize(plan, &config, observe).unwrap(); - format!("{optimized_plan:?}") + format!("{optimized_plan}") } // ------------------------------ @@ -154,7 +155,7 @@ fn test_evaluate(input_expr: Expr, expected_expr: Expr) { // Make a UDF that adds its two values together, with the specified volatility fn make_udf_add(volatility: Volatility) -> Arc { let input_types = vec![DataType::Int32, DataType::Int32]; - let return_type = Arc::new(DataType::Int32); + let return_type = DataType::Int32; let fun = Arc::new(|args: &[ColumnarValue]| { let args = ColumnarValue::values_to_arrays(args)?; @@ -281,7 +282,10 @@ fn select_date_plus_interval() -> Result<()> { let date_plus_interval_expr = to_timestamp_expr(ts_string) .cast_to(&DataType::Date32, schema)? - + Expr::Literal(ScalarValue::IntervalDayTime(Some(123i64 << 32))); + + Expr::Literal(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 123, + milliseconds: 0, + }))); let plan = LogicalPlanBuilder::from(table_scan.clone()) .project(vec![date_plus_interval_expr])? @@ -289,7 +293,7 @@ fn select_date_plus_interval() -> Result<()> { // Note that constant folder runs and folds the entire // expression down to a single constant (true) - let expected = r#"Projection: Date32("18636") AS to_timestamp(Utf8("2020-09-08T12:05:00+00:00")) + IntervalDayTime("528280977408") + let expected = r#"Projection: Date32("2021-01-09") AS to_timestamp(Utf8("2020-09-08T12:05:00+00:00")) + IntervalDayTime("IntervalDayTime { days: 123, milliseconds: 0 }") TableScan: test"#; let actual = get_optimized_plan_formatted(plan, &time); @@ -329,8 +333,8 @@ fn simplify_scan_predicate() -> Result<()> { .build()?; // before simplify: t.g = power(t.f, 1.0) - // after simplify: (t.g = t.f) as "t.g = power(t.f, 1.0)" - let expected = "TableScan: test, full_filters=[g = f AS g = power(f,Float64(1))]"; + // after simplify: t.g = t.f" + let expected = "TableScan: test, full_filters=[g = f]"; let actual = get_optimized_plan_formatted(plan, &Utc::now()); assert_eq!(expected, actual); Ok(()) @@ -364,13 +368,13 @@ fn test_const_evaluator() { #[test] fn test_const_evaluator_scalar_functions() { // concat("foo", "bar") --> "foobar" - let expr = string::expr_fn::concat(vec![lit("foo"), lit("bar")]); + let expr = concat(vec![lit("foo"), lit("bar")]); test_evaluate(expr, lit("foobar")); // ensure arguments are also constant folded // concat("foo", concat("bar", "baz")) --> "foobarbaz" - let concat1 = string::expr_fn::concat(vec![lit("bar"), lit("baz")]); - let expr = string::expr_fn::concat(vec![lit("foo"), concat1]); + let concat1 = concat(vec![lit("bar"), lit("baz")]); + let expr = concat(vec![lit("foo"), concat1]); test_evaluate(expr, lit("foobarbaz")); // Check non string arguments @@ -403,7 +407,7 @@ fn test_const_evaluator_scalar_functions() { #[test] fn test_const_evaluator_now() { let ts_nanos = 1599566400000000000i64; - let time = chrono::Utc.timestamp_nanos(ts_nanos); + let time = Utc.timestamp_nanos(ts_nanos); let ts_string = "2020-09-08T12:05:00+00:00"; // now() --> ts test_evaluate_with_start_time(now(), lit_timestamp_nano(ts_nanos), &time); @@ -425,7 +429,7 @@ fn test_evaluator_udfs() { // immutable UDF should get folded // udf_add(1+2, 30+40) --> 73 - let expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf( + let expr = Expr::ScalarFunction(ScalarFunction::new_udf( make_udf_add(Volatility::Immutable), args.clone(), )); @@ -434,21 +438,16 @@ fn test_evaluator_udfs() { // stable UDF should be entirely folded // udf_add(1+2, 30+40) --> 73 let fun = make_udf_add(Volatility::Stable); - let expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf( - Arc::clone(&fun), - args.clone(), - )); + let expr = + Expr::ScalarFunction(ScalarFunction::new_udf(Arc::clone(&fun), args.clone())); test_evaluate(expr, lit(73)); // volatile UDF should have args folded // udf_add(1+2, 30+40) --> udf_add(3, 70) let fun = make_udf_add(Volatility::Volatile); - let expr = - Expr::ScalarFunction(expr::ScalarFunction::new_udf(Arc::clone(&fun), args)); - let expected_expr = Expr::ScalarFunction(expr::ScalarFunction::new_udf( - Arc::clone(&fun), - folded_args, - )); + let expr = Expr::ScalarFunction(ScalarFunction::new_udf(Arc::clone(&fun), args)); + let expected_expr = + Expr::ScalarFunction(ScalarFunction::new_udf(Arc::clone(&fun), folded_args)); test_evaluate(expr, expected_expr); } @@ -508,6 +507,29 @@ fn test_simplify(input_expr: Expr, expected_expr: Expr) { "Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n Got:{simplified_expr}" ); } +fn test_simplify_with_cycle_count( + input_expr: Expr, + expected_expr: Expr, + expected_count: u32, +) { + let info: MyInfo = MyInfo { + schema: expr_test_schema(), + execution_props: ExecutionProps::new(), + }; + let simplifier = ExprSimplifier::new(info); + let (simplified_expr, count) = simplifier + .simplify_with_cycle_count(input_expr.clone()) + .expect("successfully evaluated"); + + assert_eq!( + simplified_expr, expected_expr, + "Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n Got:{simplified_expr}" + ); + assert_eq!( + count, expected_count, + "Mismatch simplifier cycle count\n Expected: {expected_count}\n Got:{count}" + ); +} #[test] fn test_simplify_log() { @@ -658,3 +680,11 @@ fn test_simplify_concat() { let expected = concat(vec![col("c0"), lit("hello rust"), col("c1")]); test_simplify(expr, expected) } +#[test] +fn test_simplify_cycles() { + // cast(now() as int64) < cast(to_timestamp(0) as int64) + i64::MAX + let expr = cast(now(), DataType::Int64) + .lt(cast(to_timestamp(vec![lit(0)]), DataType::Int64) + lit(i64::MAX)); + let expected = lit(true); + test_simplify_with_cycle_count(expr, expected, 3); +} diff --git a/datafusion/core/tests/fifo.rs b/datafusion/core/tests/fifo/mod.rs similarity index 73% rename from datafusion/core/tests/fifo.rs rename to datafusion/core/tests/fifo/mod.rs index 9b132f18c7a5..cb587e3510c2 100644 --- a/datafusion/core/tests/fifo.rs +++ b/datafusion/core/tests/fifo/mod.rs @@ -6,7 +6,7 @@ // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // -//http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an @@ -16,48 +16,47 @@ // under the License. //! This test demonstrates the DataFusion FIFO capabilities. -//! + #[cfg(target_family = "unix")] #[cfg(test)] mod unix_test { - use datafusion_common::instant::Instant; - use std::fs::{File, OpenOptions}; - use std::io::Write; + use std::fs::File; use std::path::PathBuf; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; - use std::thread; use std::time::Duration; use arrow::array::Array; use arrow::csv::ReaderBuilder; use arrow::datatypes::{DataType, Field, Schema}; use arrow_schema::SchemaRef; - use futures::StreamExt; - use nix::sys::stat; - use nix::unistd; - use tempfile::TempDir; - use tokio::task::{spawn_blocking, JoinHandle}; - - use datafusion::datasource::stream::{StreamConfig, StreamTable}; + use datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; use datafusion::datasource::TableProvider; use datafusion::{ prelude::{CsvReadOptions, SessionConfig, SessionContext}, test_util::{aggr_test_schema, arrow_test_data}, }; - use datafusion_common::{exec_err, DataFusionError, Result}; - use datafusion_expr::Expr; + use datafusion_common::instant::Instant; + use datafusion_common::{exec_err, Result}; + use datafusion_expr::SortExpr; + + use futures::StreamExt; + use nix::sys::stat; + use nix::unistd; + use tempfile::TempDir; + use tokio::io::AsyncWriteExt; + use tokio::task::{spawn_blocking, JoinHandle}; /// Makes a TableProvider for a fifo file fn fifo_table( schema: SchemaRef, path: impl Into, - sort: Vec>, + sort: Vec>, ) -> Arc { - let config = StreamConfig::new_file(schema, path.into()) - .with_order(sort) + let source = FileStreamProvider::new_file(schema, path.into()) .with_batch_size(TEST_BATCH_SIZE) .with_header(true); + let config = StreamConfig::new(Arc::new(source)).with_order(sort); Arc::new(StreamTable::new(Arc::new(config))) } @@ -71,8 +70,8 @@ mod unix_test { } } - fn write_to_fifo( - mut file: &File, + async fn write_to_fifo( + file: &mut tokio::fs::File, line: &str, ref_time: Instant, broken_pipe_timeout: Duration, @@ -80,11 +79,11 @@ mod unix_test { // We need to handle broken pipe error until the reader is ready. This // is why we use a timeout to limit the wait duration for the reader. // If the error is different than broken pipe, we fail immediately. - while let Err(e) = file.write_all(line.as_bytes()) { + while let Err(e) = file.write_all(line.as_bytes()).await { if e.raw_os_error().unwrap() == 32 { let interval = Instant::now().duration_since(ref_time); if interval < broken_pipe_timeout { - thread::sleep(Duration::from_millis(100)); + tokio::time::sleep(Duration::from_millis(50)).await; continue; } } @@ -93,28 +92,38 @@ mod unix_test { Ok(()) } - fn create_writing_thread( + /// This function creates a writing task for the FIFO file. To verify + /// incremental processing, it waits for a signal to continue writing after + /// a certain number of lines are written. + #[allow(clippy::disallowed_methods)] + fn create_writing_task( file_path: PathBuf, header: String, lines: Vec, - waiting_lock: Arc, - wait_until: usize, + waiting_signal: Arc, + send_before_waiting: usize, ) -> JoinHandle<()> { // Timeout for a long period of BrokenPipe error let broken_pipe_timeout = Duration::from_secs(10); - let sa = file_path.clone(); - // Spawn a new thread to write to the FIFO file - #[allow(clippy::disallowed_methods)] // spawn allowed only in tests - spawn_blocking(move || { - let file = OpenOptions::new().write(true).open(sa).unwrap(); + // Spawn a new task to write to the FIFO file + tokio::spawn(async move { + let mut file = tokio::fs::OpenOptions::new() + .write(true) + .open(file_path) + .await + .unwrap(); // Reference time to use when deciding to fail the test let execution_start = Instant::now(); - write_to_fifo(&file, &header, execution_start, broken_pipe_timeout).unwrap(); + write_to_fifo(&mut file, &header, execution_start, broken_pipe_timeout) + .await + .unwrap(); for (cnt, line) in lines.iter().enumerate() { - while waiting_lock.load(Ordering::SeqCst) && cnt > wait_until { - thread::sleep(Duration::from_millis(50)); + while waiting_signal.load(Ordering::SeqCst) && cnt > send_before_waiting { + tokio::time::sleep(Duration::from_millis(50)).await; } - write_to_fifo(&file, line, execution_start, broken_pipe_timeout).unwrap(); + write_to_fifo(&mut file, line, execution_start, broken_pipe_timeout) + .await + .unwrap(); } drop(file); }) @@ -125,6 +134,8 @@ mod unix_test { const TEST_BATCH_SIZE: usize = 20; // Number of lines written to FIFO const TEST_DATA_SIZE: usize = 20_000; + // Number of lines to write before waiting to verify incremental processing + const SEND_BEFORE_WAITING: usize = 2 * TEST_BATCH_SIZE; // Number of lines what can be joined. Each joinable key produced 20 lines with // aggregate_test_100 dataset. We will use these joinable keys for understanding // incremental execution. @@ -132,7 +143,7 @@ mod unix_test { // This test provides a relatively realistic end-to-end scenario where // we swap join sides to accommodate a FIFO source. - #[tokio::test(flavor = "multi_thread", worker_threads = 8)] + #[tokio::test] async fn unbounded_file_with_swapped_join() -> Result<()> { // Create session context let config = SessionConfig::new() @@ -162,8 +173,8 @@ mod unix_test { .zip(0..TEST_DATA_SIZE) .map(|(a1, a2)| format!("{a1},{a2}\n")) .collect::>(); - // Create writing threads for the left and right FIFO files - let task = create_writing_thread( + // Create writing tasks for the left and right FIFO files + let task = create_writing_task( fifo_path.clone(), "a1,a2\n".to_owned(), lines, @@ -190,7 +201,16 @@ mod unix_test { ) .await?; // Execute the query - let df = ctx.sql("SELECT t1.a2, t2.c1, t2.c4, t2.c5 FROM left as t1 JOIN right as t2 ON t1.a1 = t2.c1").await?; + let df = ctx + .sql( + "SELECT + t1.a2, t2.c1, t2.c4, t2.c5 + FROM + left as t1, right as t2 + WHERE + t1.a1 = t2.c1", + ) + .await?; let mut stream = df.execute_stream().await?; while (stream.next().await).is_some() { waiting.store(false, Ordering::SeqCst); @@ -199,16 +219,9 @@ mod unix_test { Ok(()) } - #[derive(Debug, PartialEq)] - enum JoinOperation { - LeftUnmatched, - RightUnmatched, - Equal, - } - - // This test provides a relatively realistic end-to-end scenario where - // we change the join into a [SymmetricHashJoin] to accommodate two - // unbounded (FIFO) sources. + /// This test provides a relatively realistic end-to-end scenario where + /// we change the join into a `SymmetricHashJoinExec` to accommodate two + /// unbounded (FIFO) sources. #[tokio::test] async fn unbounded_file_with_symmetric_join() -> Result<()> { // Create session context @@ -217,17 +230,6 @@ mod unix_test { .set_bool("datafusion.execution.coalesce_batches", false) .with_target_partitions(1); let ctx = SessionContext::new_with_config(config); - // Tasks - let mut tasks: Vec> = vec![]; - - // Join filter - let a1_iter = 0..TEST_DATA_SIZE; - // Join key - let a2_iter = (0..TEST_DATA_SIZE).map(|x| x % 10); - let lines = a1_iter - .zip(a2_iter) - .map(|(a1, a2)| format!("{a1},{a2}\n")) - .collect::>(); // Create a new temporary FIFO file let tmp_dir = TempDir::new()?; @@ -238,22 +240,6 @@ mod unix_test { // Create a mutex for tracking if the right input source is waiting for data. let waiting = Arc::new(AtomicBool::new(true)); - // Create writing threads for the left and right FIFO files - tasks.push(create_writing_thread( - left_fifo.clone(), - "a1,a2\n".to_owned(), - lines.clone(), - waiting.clone(), - TEST_BATCH_SIZE, - )); - tasks.push(create_writing_thread( - right_fifo.clone(), - "a1,a2\n".to_owned(), - lines.clone(), - waiting.clone(), - TEST_BATCH_SIZE, - )); - // Create schema let schema = Arc::new(Schema::new(vec![ Field::new("a1", DataType::UInt32, false), @@ -264,62 +250,78 @@ mod unix_test { let order = vec![vec![datafusion_expr::col("a1").sort(true, false)]]; // Set unbounded sorted files read configuration - let provider = fifo_table(schema.clone(), left_fifo, order.clone()); + let provider = fifo_table(schema.clone(), left_fifo.clone(), order.clone()); ctx.register_table("left", provider)?; - let provider = fifo_table(schema.clone(), right_fifo, order); + let provider = fifo_table(schema.clone(), right_fifo.clone(), order); ctx.register_table("right", provider)?; // Execute the query, with no matching rows. (since key is modulus 10) let df = ctx .sql( "SELECT - t1.a1, - t1.a2, - t2.a1, - t2.a2 + t1.a1, t1.a2, t2.a1, t2.a2 FROM - left as t1 FULL - JOIN right as t2 ON t1.a2 = t2.a2 - AND t1.a1 > t2.a1 + 4 - AND t1.a1 < t2.a1 + 9", + left as t1 + FULL JOIN + right as t2 + ON + t1.a2 = t2.a2 AND + t1.a1 > t2.a1 + 4 AND + t1.a1 < t2.a1 + 9", ) .await?; let mut stream = df.execute_stream().await?; - let mut operations = vec![]; - // Partial. + + // Tasks + let mut tasks: Vec> = vec![]; + + // Join filter + let a1_iter = 0..TEST_DATA_SIZE; + // Join key + let a2_iter = (0..TEST_DATA_SIZE).map(|x| x % 10); + let lines = a1_iter + .zip(a2_iter) + .map(|(a1, a2)| format!("{a1},{a2}\n")) + .collect::>(); + + // Create writing tasks for the left and right FIFO files + tasks.push(create_writing_task( + left_fifo, + "a1,a2\n".to_owned(), + lines.clone(), + waiting.clone(), + SEND_BEFORE_WAITING, + )); + tasks.push(create_writing_task( + right_fifo, + "a1,a2\n".to_owned(), + lines, + waiting.clone(), + SEND_BEFORE_WAITING, + )); + // Collect output data: + let (mut equal, mut left, mut right) = (0, 0, 0); while let Some(Ok(batch)) = stream.next().await { waiting.store(false, Ordering::SeqCst); let left_unmatched = batch.column(2).null_count(); let right_unmatched = batch.column(0).null_count(); - let op = if left_unmatched == 0 && right_unmatched == 0 { - JoinOperation::Equal - } else if right_unmatched > left_unmatched { - JoinOperation::RightUnmatched + if left_unmatched == 0 && right_unmatched == 0 { + equal += 1; + } else if right_unmatched <= left_unmatched { + left += 1; } else { - JoinOperation::LeftUnmatched + right += 1; }; - operations.push(op); } futures::future::try_join_all(tasks).await.unwrap(); - // The SymmetricHashJoin executor produces FULL join results at every - // pruning, which happens before it reaches the end of input and more - // than once. In this test, we feed partially joinable data to both - // sides in order to ensure that left or right unmatched results are - // generated more than once during the test. - assert!( - operations - .iter() - .filter(|&n| JoinOperation::RightUnmatched.eq(n)) - .count() - > 1 - && operations - .iter() - .filter(|&n| JoinOperation::LeftUnmatched.eq(n)) - .count() - > 1 - ); + // The symmetric hash join algorithm produces FULL join results at + // every pruning, which happens before it reaches the end of input and + // more than once. In this test, we feed partially joinable data to + // both sides in order to ensure that left or right unmatched results + // are generated as expected. + assert!(equal >= 0 && left > 1 && right > 1); Ok(()) } @@ -340,17 +342,14 @@ mod unix_test { (source_fifo_path.clone(), source_fifo_path.display()); // Tasks let mut tasks: Vec> = vec![]; - // TEST_BATCH_SIZE + 1 rows will be provided. However, after processing precisely - // TEST_BATCH_SIZE rows, the program will pause and wait for a batch to be read in another - // thread. This approach ensures that the pipeline remains unbroken. - tasks.push(create_writing_thread( + tasks.push(create_writing_task( source_fifo_path_thread, "a1,a2\n".to_owned(), (0..TEST_DATA_SIZE) .map(|_| "a,1\n".to_string()) .collect::>(), waiting, - TEST_BATCH_SIZE, + SEND_BEFORE_WAITING, )); // Create a new temporary FIFO file let sink_fifo_path = create_fifo_file(&tmp_dir, "sink.csv")?; @@ -369,8 +368,8 @@ mod unix_test { let mut reader = ReaderBuilder::new(schema) .with_batch_size(TEST_BATCH_SIZE) + .with_header(true) .build(file) - .map_err(|e| DataFusionError::Internal(e.to_string())) .unwrap(); while let Some(Ok(_)) = reader.next() { @@ -384,8 +383,8 @@ mod unix_test { a2 INT NOT NULL ) STORED AS CSV - WITH HEADER ROW - LOCATION '{source_display_fifo_path}'" + LOCATION '{source_display_fifo_path}' + OPTIONS ('format.has_header' 'true')" )) .await?; @@ -396,8 +395,8 @@ mod unix_test { a2 INT NOT NULL ) STORED AS CSV - WITH HEADER ROW - LOCATION '{sink_display_fifo_path}'" + LOCATION '{sink_display_fifo_path}' + OPTIONS ('format.has_header' 'true')" )) .await?; diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 8df16e7944d2..21f604e6c60f 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -25,6 +25,7 @@ use arrow::util::pretty::pretty_format_batches; use arrow_array::types::Int64Type; use datafusion::common::Result; use datafusion::datasource::MemTable; +use datafusion::physical_expr::aggregate::AggregateExprBuilder; use datafusion::physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; @@ -32,16 +33,167 @@ use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::{collect, displayable, ExecutionPlan}; use datafusion::prelude::{DataFrame, SessionConfig, SessionContext}; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; -use datafusion_physical_expr::expressions::{col, Sum}; -use datafusion_physical_expr::{AggregateExpr, PhysicalSortExpr}; +use datafusion_functions_aggregate::sum::sum_udaf; +use datafusion_physical_expr::expressions::col; +use datafusion_physical_expr::PhysicalSortExpr; use datafusion_physical_plan::InputOrderMode; use test_utils::{add_empty_batches, StringBatchGenerator}; +use crate::fuzz_cases::aggregation_fuzzer::{ + AggregationFuzzerBuilder, ColumnDescr, DatasetGeneratorConfig, QueryBuilder, +}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use hashbrown::HashMap; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use tokio::task::JoinSet; +// ======================================================================== +// The new aggregation fuzz tests based on [`AggregationFuzzer`] +// ======================================================================== +// +// Notes on tests: +// +// Since the supported types differ for each aggregation function, the tests +// below are structured so they enumerate each different aggregate function. +// +// The test framework handles varying combinations of arguments (data types), +// sortedness, and grouping parameters +// +// TODO: Test floating point values (where output needs to be compared with some +// acceptable range due to floating point rounding) +// +// TODO: test other aggregate functions +// - AVG (unstable given the wide range of inputs) +#[tokio::test(flavor = "multi_thread")] +async fn test_min() { + let data_gen_config = baseline_config(); + + // Queries like SELECT min(a) FROM fuzz_table GROUP BY b + let query_builder = QueryBuilder::new() + .with_table_name("fuzz_table") + .with_aggregate_function("min") + // min works on all column types + .with_aggregate_arguments(data_gen_config.all_columns()) + .set_group_by_columns(data_gen_config.all_columns()); + + AggregationFuzzerBuilder::from(data_gen_config) + .add_query_builder(query_builder) + .build() + .run() + .await; +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_max() { + let data_gen_config = baseline_config(); + + // Queries like SELECT max(a) FROM fuzz_table GROUP BY b + let query_builder = QueryBuilder::new() + .with_table_name("fuzz_table") + .with_aggregate_function("max") + // max works on all column types + .with_aggregate_arguments(data_gen_config.all_columns()) + .set_group_by_columns(data_gen_config.all_columns()); + + AggregationFuzzerBuilder::from(data_gen_config) + .add_query_builder(query_builder) + .build() + .run() + .await; +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_sum() { + let data_gen_config = baseline_config(); + + // Queries like SELECT sum(a), sum(distinct) FROM fuzz_table GROUP BY b + let query_builder = QueryBuilder::new() + .with_table_name("fuzz_table") + .with_aggregate_function("sum") + .with_distinct_aggregate_function("sum") + // sum only works on numeric columns + .with_aggregate_arguments(data_gen_config.numeric_columns()) + .set_group_by_columns(data_gen_config.all_columns()); + + AggregationFuzzerBuilder::from(data_gen_config) + .add_query_builder(query_builder) + .build() + .run() + .await; +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_count() { + let data_gen_config = baseline_config(); + + // Queries like SELECT count(a), count(distinct) FROM fuzz_table GROUP BY b + let query_builder = QueryBuilder::new() + .with_table_name("fuzz_table") + .with_aggregate_function("count") + .with_distinct_aggregate_function("count") + // count work for all arguments + .with_aggregate_arguments(data_gen_config.all_columns()) + .set_group_by_columns(data_gen_config.all_columns()); + + AggregationFuzzerBuilder::from(data_gen_config) + .add_query_builder(query_builder) + .build() + .run() + .await; +} + +/// Return a standard set of columns for testing data generation +/// +/// Includes numeric and string types +/// +/// Does not include: +/// 1. Floating point numbers +/// 1. structured types +fn baseline_config() -> DatasetGeneratorConfig { + let columns = vec![ + ColumnDescr::new("i8", DataType::Int8), + ColumnDescr::new("i16", DataType::Int16), + ColumnDescr::new("i32", DataType::Int32), + ColumnDescr::new("i64", DataType::Int64), + ColumnDescr::new("u8", DataType::UInt8), + ColumnDescr::new("u16", DataType::UInt16), + ColumnDescr::new("u32", DataType::UInt32), + ColumnDescr::new("u64", DataType::UInt64), + ColumnDescr::new("date32", DataType::Date32), + ColumnDescr::new("date64", DataType::Date64), + // TODO: date/time columns + // todo decimal columns + // begin string columns + ColumnDescr::new("utf8", DataType::Utf8), + ColumnDescr::new("largeutf8", DataType::LargeUtf8), + // TODO add support for utf8view in data generator + // ColumnDescr::new("utf8view", DataType::Utf8View), + // todo binary + // low cardinality columns + ColumnDescr::new("u8_low", DataType::UInt8).with_max_num_distinct(10), + ColumnDescr::new("utf8_low", DataType::Utf8).with_max_num_distinct(10), + ]; + + let min_num_rows = 512; + let max_num_rows = 1024; + + DatasetGeneratorConfig { + columns, + rows_num_range: (min_num_rows, max_num_rows), + sort_keys_set: vec![ + // low cardinality to try and get many repeated runs + vec![String::from("u8_low")], + vec![String::from("utf8_low"), String::from("u8_low")], + ], + } +} + +// ======================================================================== +// The old aggregation fuzz tests +// ======================================================================== + +/// Tracks if this stream is generating input or output /// Tests that streaming aggregate and batch (non streaming) aggregate produce /// same results #[tokio::test(flavor = "multi_thread")] @@ -56,7 +208,7 @@ async fn streaming_aggregate_test() { vec!["d", "c", "a"], vec!["d", "c", "b", "a"], ]; - let n = 300; + let n = 10; let distincts = vec![10, 20]; for distinct in distincts { let mut join_set = JoinSet::new(); @@ -82,7 +234,7 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str let schema = input1[0].schema(); let session_config = SessionConfig::new().with_batch_size(50); let ctx = SessionContext::new_with_config(session_config); - let mut sort_keys = vec![]; + let mut sort_keys = LexOrdering::default(); for ordering_col in ["a", "b", "c"] { sort_keys.push(PhysicalSortExpr { expr: col(ordering_col, &schema).unwrap(), @@ -98,14 +250,19 @@ async fn run_aggregate_test(input1: Vec, group_by_columns: Vec<&str let running_source = Arc::new( MemoryExec::try_new(&[input1.clone()], schema.clone(), None) .unwrap() - .with_sort_information(vec![sort_keys]), + .try_with_sort_information(vec![sort_keys]) + .unwrap(), ); - let aggregate_expr = vec![Arc::new(Sum::new( - col("d", &schema).unwrap(), - "sum1", - DataType::Int64, - )) as Arc]; + let aggregate_expr = + vec![ + AggregateExprBuilder::new(sum_udaf(), vec![col("d", &schema).unwrap()]) + .schema(Arc::clone(&schema)) + .alias("sum1") + .build() + .map(Arc::new) + .unwrap(), + ]; let expr = group_by_columns .iter() .map(|elem| (col(elem, &schema).unwrap(), elem.to_string())) @@ -306,16 +463,17 @@ async fn group_by_string_test( let actual = extract_result_counts(results); assert_eq!(expected, actual); } + async fn verify_ordered_aggregate(frame: &DataFrame, expected_sort: bool) { struct Visitor { expected_sort: bool, } let mut visitor = Visitor { expected_sort }; - impl TreeNodeVisitor for Visitor { + impl<'n> TreeNodeVisitor<'n> for Visitor { type Node = Arc; - fn f_down(&mut self, node: &Self::Node) -> Result { + fn f_down(&mut self, node: &'n Self::Node) -> Result { if let Some(exec) = node.as_any().downcast_ref::() { if self.expected_sort { assert!(matches!( diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs new file mode 100644 index 000000000000..af454bee7ce8 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs @@ -0,0 +1,343 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{cmp, sync::Arc}; + +use datafusion::{ + datasource::MemTable, + prelude::{SessionConfig, SessionContext}, +}; +use datafusion_catalog::TableProvider; +use datafusion_common::error::Result; +use datafusion_common::ScalarValue; +use datafusion_expr::col; +use rand::{thread_rng, Rng}; + +use crate::fuzz_cases::aggregation_fuzzer::data_generator::Dataset; + +/// SessionContext generator +/// +/// During testing, `generate_baseline` will be called firstly to generate a standard [`SessionContext`], +/// and we will run `sql` on it to get the `expected result`. Then `generate` will be called some times to +/// generate some random [`SessionContext`]s, and we will run the same `sql` on them to get `actual results`. +/// Finally, we compare the `actual results` with `expected result`, the test only success while all they are +/// same with the expected. +/// +/// Following parameters of [`SessionContext`] used in query running will be generated randomly: +/// - `batch_size` +/// - `target_partitions` +/// - `skip_partial parameters` +/// - hint `sorted` or not +/// - `spilling` or not (TODO, I think a special `MemoryPool` may be needed +/// to support this) +/// +pub struct SessionContextGenerator { + /// Current testing dataset + dataset: Arc, + + /// Table name of the test table + table_name: String, + + /// Used in generate the random `batch_size` + /// + /// The generated `batch_size` is between (0, total_rows_num] + max_batch_size: usize, + + /// Candidate `SkipPartialParams` which will be picked randomly + candidate_skip_partial_params: Vec, + + /// The upper bound of the randomly generated target partitions, + /// and the lower bound will be 1 + max_target_partitions: usize, +} + +impl SessionContextGenerator { + pub fn new(dataset_ref: Arc, table_name: &str) -> Self { + let candidate_skip_partial_params = vec![ + SkipPartialParams::ensure_trigger(), + SkipPartialParams::ensure_not_trigger(), + ]; + + let max_batch_size = cmp::max(1, dataset_ref.total_rows_num); + let max_target_partitions = num_cpus::get(); + + Self { + dataset: dataset_ref, + table_name: table_name.to_string(), + max_batch_size, + candidate_skip_partial_params, + max_target_partitions, + } + } +} + +impl SessionContextGenerator { + /// Generate the `SessionContext` for the baseline run + pub fn generate_baseline(&self) -> Result { + let schema = self.dataset.batches[0].schema(); + let batches = self.dataset.batches.clone(); + let provider = MemTable::try_new(schema, vec![batches])?; + + // The baseline context should try best to disable all optimizations, + // and pursuing the rightness. + let batch_size = self.max_batch_size; + let target_partitions = 1; + let skip_partial_params = SkipPartialParams::ensure_not_trigger(); + + let builder = GeneratedSessionContextBuilder { + batch_size, + target_partitions, + skip_partial_params, + sort_hint: false, + table_name: self.table_name.clone(), + table_provider: Arc::new(provider), + }; + + builder.build() + } + + /// Randomly generate session context + pub fn generate(&self) -> Result { + let mut rng = thread_rng(); + let schema = self.dataset.batches[0].schema(); + let batches = self.dataset.batches.clone(); + let provider = MemTable::try_new(schema, vec![batches])?; + + // We will randomly generate following options: + // - `batch_size`, from range: [1, `total_rows_num`] + // - `target_partitions`, from range: [1, cpu_num] + // - `skip_partial`, trigger or not trigger currently for simplicity + // - `sorted`, if found a sorted dataset, will or will not push down this information + // - `spilling`(TODO) + let batch_size = rng.gen_range(1..=self.max_batch_size); + + let target_partitions = rng.gen_range(1..=self.max_target_partitions); + + let skip_partial_params_idx = + rng.gen_range(0..self.candidate_skip_partial_params.len()); + let skip_partial_params = + self.candidate_skip_partial_params[skip_partial_params_idx]; + + let (provider, sort_hint) = + if rng.gen_bool(0.5) && !self.dataset.sort_keys.is_empty() { + // Sort keys exist and random to push down + let sort_exprs = self + .dataset + .sort_keys + .iter() + .map(|key| col(key).sort(true, true)) + .collect::>(); + (provider.with_sort_order(vec![sort_exprs]), true) + } else { + (provider, false) + }; + + let builder = GeneratedSessionContextBuilder { + batch_size, + target_partitions, + sort_hint, + skip_partial_params, + table_name: self.table_name.clone(), + table_provider: Arc::new(provider), + }; + + builder.build() + } +} + +/// The generated [`SessionContext`] with its params +/// +/// Storing the generated `params` is necessary for +/// reporting the broken test case. +pub struct SessionContextWithParams { + pub ctx: SessionContext, + pub params: SessionContextParams, +} + +/// Collect the generated params, and build the [`SessionContext`] +struct GeneratedSessionContextBuilder { + batch_size: usize, + target_partitions: usize, + sort_hint: bool, + skip_partial_params: SkipPartialParams, + table_name: String, + table_provider: Arc, +} + +impl GeneratedSessionContextBuilder { + fn build(self) -> Result { + // Build session context + let mut session_config = SessionConfig::default(); + session_config = session_config.set( + "datafusion.execution.batch_size", + &ScalarValue::UInt64(Some(self.batch_size as u64)), + ); + session_config = session_config.set( + "datafusion.execution.target_partitions", + &ScalarValue::UInt64(Some(self.target_partitions as u64)), + ); + session_config = session_config.set( + "datafusion.execution.skip_partial_aggregation_probe_rows_threshold", + &ScalarValue::UInt64(Some(self.skip_partial_params.rows_threshold as u64)), + ); + session_config = session_config.set( + "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold", + &ScalarValue::Float64(Some(self.skip_partial_params.ratio_threshold)), + ); + + let ctx = SessionContext::new_with_config(session_config); + ctx.register_table(self.table_name, self.table_provider)?; + + let params = SessionContextParams { + batch_size: self.batch_size, + target_partitions: self.target_partitions, + sort_hint: self.sort_hint, + skip_partial_params: self.skip_partial_params, + }; + + Ok(SessionContextWithParams { ctx, params }) + } +} + +/// The generated params for [`SessionContext`] +#[derive(Debug)] +#[allow(dead_code)] +pub struct SessionContextParams { + batch_size: usize, + target_partitions: usize, + sort_hint: bool, + skip_partial_params: SkipPartialParams, +} + +/// Partial skipping parameters +#[derive(Debug, Clone, Copy)] +pub struct SkipPartialParams { + /// Related to `skip_partial_aggregation_probe_ratio_threshold` in `ExecutionOptions` + pub ratio_threshold: f64, + + /// Related to `skip_partial_aggregation_probe_rows_threshold` in `ExecutionOptions` + pub rows_threshold: usize, +} + +impl SkipPartialParams { + /// Generate `SkipPartialParams` ensuring to trigger partial skipping + pub fn ensure_trigger() -> Self { + Self { + ratio_threshold: 0.0, + rows_threshold: 0, + } + } + + /// Generate `SkipPartialParams` ensuring not to trigger partial skipping + pub fn ensure_not_trigger() -> Self { + Self { + ratio_threshold: 1.0, + rows_threshold: usize::MAX, + } + } +} + +#[cfg(test)] +mod test { + use arrow_array::{RecordBatch, StringArray, UInt32Array}; + use arrow_schema::{DataType, Field, Schema}; + + use crate::fuzz_cases::aggregation_fuzzer::check_equality_of_batches; + + use super::*; + + #[tokio::test] + async fn test_generated_context() { + // 1. Define a test dataset firstly + let a_col: StringArray = [ + Some("rust"), + Some("java"), + Some("cpp"), + Some("go"), + Some("go1"), + Some("python"), + Some("python1"), + Some("python2"), + ] + .into_iter() + .collect(); + // Sort by "b" + let b_col: UInt32Array = [ + Some(1), + Some(2), + Some(4), + Some(8), + Some(8), + Some(16), + Some(16), + Some(16), + ] + .into_iter() + .collect(); + let schema = Schema::new(vec![ + Field::new("a", DataType::Utf8, true), + Field::new("b", DataType::UInt32, true), + ]); + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(a_col), Arc::new(b_col)], + ) + .unwrap(); + + // One row a group to create batches + let mut batches = Vec::with_capacity(batch.num_rows()); + for start in 0..batch.num_rows() { + let sub_batch = batch.slice(start, 1); + batches.push(sub_batch); + } + + let dataset = Dataset::new(batches, vec!["b".to_string()]); + + // 2. Generate baseline context, and some randomly session contexts. + // Run the same query on them, and all randoms' results should equal to baseline's + let ctx_generator = SessionContextGenerator::new(Arc::new(dataset), "fuzz_table"); + + let query = "select b, count(a) from fuzz_table group by b"; + let baseline_wrapped_ctx = ctx_generator.generate_baseline().unwrap(); + let mut random_wrapped_ctxs = Vec::with_capacity(8); + for _ in 0..8 { + let ctx = ctx_generator.generate().unwrap(); + random_wrapped_ctxs.push(ctx); + } + + let base_result = baseline_wrapped_ctx + .ctx + .sql(query) + .await + .unwrap() + .collect() + .await + .unwrap(); + + for wrapped_ctx in random_wrapped_ctxs { + let random_result = wrapped_ctx + .ctx + .sql(query) + .await + .unwrap() + .collect() + .await + .unwrap(); + check_equality_of_batches(&base_result, &random_result).unwrap(); + } + } +} diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs new file mode 100644 index 000000000000..aafa5ed7f66b --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs @@ -0,0 +1,545 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::datatypes::{ + Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, + Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; +use arrow_array::{ArrayRef, RecordBatch}; +use arrow_schema::{DataType, Field, Schema}; +use datafusion_common::{arrow_datafusion_err, DataFusionError, Result}; +use datafusion_physical_expr::{expressions::col, PhysicalSortExpr}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_plan::sorts::sort::sort_batch; +use rand::{ + rngs::{StdRng, ThreadRng}, + thread_rng, Rng, SeedableRng, +}; +use test_utils::{ + array_gen::{PrimitiveArrayGenerator, StringArrayGenerator}, + stagger_batch, +}; + +/// Config for Data sets generator +/// +/// # Parameters +/// - `columns`, you just need to define `column name`s and `column data type`s +/// fot the test datasets, and then they will be randomly generated from generator +/// when you can `generate` function +/// +/// - `rows_num_range`, the rows num of the datasets will be randomly generated +/// among this range +/// +/// - `sort_keys`, if `sort_keys` are defined, when you can `generate`, the generator +/// will generate one `base dataset` firstly. Then the `base dataset` will be sorted +/// based on each `sort_key` respectively. And finally `len(sort_keys) + 1` datasets +/// will be returned +/// +#[derive(Debug, Clone)] +pub struct DatasetGeneratorConfig { + /// Descriptions of columns in datasets, it's `required` + pub columns: Vec, + + /// Rows num range of the generated datasets, it's `required` + pub rows_num_range: (usize, usize), + + /// Additional optional sort keys + /// + /// The generated datasets always include a non-sorted copy. For each + /// element in `sort_keys_set`, an additional datasets is created that + /// is sorted by these values as well. + pub sort_keys_set: Vec>, +} + +impl DatasetGeneratorConfig { + /// return a list of all column names + pub fn all_columns(&self) -> Vec<&str> { + self.columns.iter().map(|d| d.name.as_str()).collect() + } + + /// return a list of column names that are "numeric" + pub fn numeric_columns(&self) -> Vec<&str> { + self.columns + .iter() + .filter_map(|d| { + if d.column_type.is_numeric() { + Some(d.name.as_str()) + } else { + None + } + }) + .collect() + } +} + +/// Dataset generator +/// +/// It will generate one random [`Dataset`]s when `generate` function is called. +/// +/// The generation logic in `generate`: +/// +/// - Randomly generate a base record from `batch_generator` firstly. +/// And `columns`, `rows_num_range` in `config`(detail can see `DataSetsGeneratorConfig`), +/// will be used in generation. +/// +/// - Sort the batch according to `sort_keys` in `config` to generator another +/// `len(sort_keys)` sorted batches. +/// +/// - Split each batch to multiple batches which each sub-batch in has the randomly `rows num`, +/// and this multiple batches will be used to create the `Dataset`. +/// +pub struct DatasetGenerator { + batch_generator: RecordBatchGenerator, + sort_keys_set: Vec>, +} + +impl DatasetGenerator { + pub fn new(config: DatasetGeneratorConfig) -> Self { + let batch_generator = RecordBatchGenerator::new( + config.rows_num_range.0, + config.rows_num_range.1, + config.columns, + ); + + Self { + batch_generator, + sort_keys_set: config.sort_keys_set, + } + } + + pub fn generate(&self) -> Result> { + let mut datasets = Vec::with_capacity(self.sort_keys_set.len() + 1); + + // Generate the base batch (unsorted) + let base_batch = self.batch_generator.generate()?; + let batches = stagger_batch(base_batch.clone()); + let dataset = Dataset::new(batches, Vec::new()); + datasets.push(dataset); + + // Generate the related sorted batches + let schema = base_batch.schema_ref(); + for sort_keys in self.sort_keys_set.clone() { + let sort_exprs = sort_keys + .iter() + .map(|key| { + let col_expr = col(key, schema)?; + Ok(PhysicalSortExpr::new_default(col_expr)) + }) + .collect::>()?; + let sorted_batch = sort_batch(&base_batch, sort_exprs.as_ref(), None)?; + + let batches = stagger_batch(sorted_batch); + let dataset = Dataset::new(batches, sort_keys); + datasets.push(dataset); + } + + Ok(datasets) + } +} + +/// Single test data set +#[derive(Debug)] +pub struct Dataset { + pub batches: Vec, + pub total_rows_num: usize, + pub sort_keys: Vec, +} + +impl Dataset { + pub fn new(batches: Vec, sort_keys: Vec) -> Self { + let total_rows_num = batches.iter().map(|batch| batch.num_rows()).sum::(); + + Self { + batches, + total_rows_num, + sort_keys, + } + } +} + +#[derive(Debug, Clone)] +pub struct ColumnDescr { + /// Column name + name: String, + + /// Data type of this column + column_type: DataType, + + /// The maximum number of distinct values in this column. + /// + /// See [`ColumnDescr::with_max_num_distinct`] for more information + max_num_distinct: Option, +} + +impl ColumnDescr { + #[inline] + pub fn new(name: &str, column_type: DataType) -> Self { + Self { + name: name.to_string(), + column_type, + max_num_distinct: None, + } + } + + /// set the maximum number of distinct values in this column + /// + /// If `None`, the number of distinct values is randomly selected between 1 + /// and the number of rows. + pub fn with_max_num_distinct(mut self, num_distinct: usize) -> Self { + self.max_num_distinct = Some(num_distinct); + self + } +} + +/// Record batch generator +struct RecordBatchGenerator { + min_rows_nun: usize, + + max_rows_num: usize, + + columns: Vec, + + candidate_null_pcts: Vec, +} + +macro_rules! generate_string_array { + ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $OFFSET_TYPE:ty) => {{ + let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); + let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; + let max_len = $BATCH_GEN_RNG.gen_range(1..50); + + let mut generator = StringArrayGenerator { + max_len, + num_strings: $NUM_ROWS, + num_distinct_strings: $MAX_NUM_DISTINCT, + null_pct, + rng: $ARRAY_GEN_RNG, + }; + + generator.gen_data::<$OFFSET_TYPE>() + }}; +} + +macro_rules! generate_primitive_array { + ($SELF:ident, $NUM_ROWS:ident, $MAX_NUM_DISTINCT:expr, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $ARROW_TYPE:ident) => { + paste::paste! {{ + let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); + let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; + + let mut generator = PrimitiveArrayGenerator { + num_primitives: $NUM_ROWS, + num_distinct_primitives: $MAX_NUM_DISTINCT, + null_pct, + rng: $ARRAY_GEN_RNG, + }; + + generator.gen_data::<$ARROW_TYPE>() + }}} +} + +impl RecordBatchGenerator { + fn new(min_rows_nun: usize, max_rows_num: usize, columns: Vec) -> Self { + let candidate_null_pcts = vec![0.0, 0.01, 0.1, 0.5]; + + Self { + min_rows_nun, + max_rows_num, + columns, + candidate_null_pcts, + } + } + + fn generate(&self) -> Result { + let mut rng = thread_rng(); + let num_rows = rng.gen_range(self.min_rows_nun..=self.max_rows_num); + let array_gen_rng = StdRng::from_seed(rng.gen()); + + // Build arrays + let mut arrays = Vec::with_capacity(self.columns.len()); + for col in self.columns.iter() { + let array = self.generate_array_of_type( + col, + num_rows, + &mut rng, + array_gen_rng.clone(), + ); + arrays.push(array); + } + + // Build schema + let fields = self + .columns + .iter() + .map(|col| Field::new(col.name.clone(), col.column_type.clone(), true)) + .collect::>(); + let schema = Arc::new(Schema::new(fields)); + + RecordBatch::try_new(schema, arrays).map_err(|e| arrow_datafusion_err!(e)) + } + + fn generate_array_of_type( + &self, + col: &ColumnDescr, + num_rows: usize, + batch_gen_rng: &mut ThreadRng, + array_gen_rng: StdRng, + ) -> ArrayRef { + let num_distinct = if num_rows > 1 { + batch_gen_rng.gen_range(1..num_rows) + } else { + num_rows + }; + // cap to at most the num_distinct values + let max_num_distinct = col + .max_num_distinct + .map(|max| num_distinct.min(max)) + .unwrap_or(num_distinct); + + match col.column_type { + DataType::Int8 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + Int8Type + ) + } + DataType::Int16 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + Int16Type + ) + } + DataType::Int32 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + Int32Type + ) + } + DataType::Int64 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + Int64Type + ) + } + DataType::UInt8 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + UInt8Type + ) + } + DataType::UInt16 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + UInt16Type + ) + } + DataType::UInt32 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + UInt32Type + ) + } + DataType::UInt64 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + UInt64Type + ) + } + DataType::Float32 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + Float32Type + ) + } + DataType::Float64 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + Float64Type + ) + } + DataType::Date32 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + Date32Type + ) + } + DataType::Date64 => { + generate_primitive_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + Date64Type + ) + } + DataType::Utf8 => { + generate_string_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + i32 + ) + } + DataType::LargeUtf8 => { + generate_string_array!( + self, + num_rows, + max_num_distinct, + batch_gen_rng, + array_gen_rng, + i64 + ) + } + _ => { + panic!("Unsupported data generator type: {}", col.column_type) + } + } + } +} + +#[cfg(test)] +mod test { + use arrow_array::UInt32Array; + + use crate::fuzz_cases::aggregation_fuzzer::check_equality_of_batches; + + use super::*; + + #[test] + fn test_generated_datasets() { + // The test datasets generation config + // We expect that after calling `generate` + // - Generate 2 datasets + // - They have 2 column "a" and "b", + // "a"'s type is `Utf8`, and "b"'s type is `UInt32` + // - One of them is unsorted, another is sorted by column "b" + // - Their rows num should be same and between [16, 32] + let config = DatasetGeneratorConfig { + columns: vec![ + ColumnDescr::new("a", DataType::Utf8), + ColumnDescr::new("b", DataType::UInt32), + ], + rows_num_range: (16, 32), + sort_keys_set: vec![vec!["b".to_string()]], + }; + + let gen = DatasetGenerator::new(config); + let datasets = gen.generate().unwrap(); + + // Should Generate 2 datasets + assert_eq!(datasets.len(), 2); + + // Should have 2 column "a" and "b", + // "a"'s type is `Utf8`, and "b"'s type is `UInt32` + let check_fields = |batch: &RecordBatch| { + assert_eq!(batch.num_columns(), 2); + let fields = batch.schema().fields().clone(); + assert_eq!(fields[0].name(), "a"); + assert_eq!(*fields[0].data_type(), DataType::Utf8); + assert_eq!(fields[1].name(), "b"); + assert_eq!(*fields[1].data_type(), DataType::UInt32); + }; + + let batch = &datasets[0].batches[0]; + check_fields(batch); + let batch = &datasets[1].batches[0]; + check_fields(batch); + + // One batches should be sort by "b" + let sorted_batches = &datasets[1].batches; + let b_vals = sorted_batches.iter().flat_map(|batch| { + let uint_array = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + uint_array.iter() + }); + let mut prev_b_val = u32::MIN; + for b_val in b_vals { + let b_val = b_val.unwrap_or(u32::MIN); + assert!(b_val >= prev_b_val); + prev_b_val = b_val; + } + + // Two batches should be same after sorting + check_equality_of_batches(&datasets[0].batches, &datasets[1].batches).unwrap(); + + // Rows num should between [16, 32] + let rows_num0 = datasets[0] + .batches + .iter() + .map(|batch| batch.num_rows()) + .sum::(); + let rows_num1 = datasets[1] + .batches + .iter() + .map(|batch| batch.num_rows()) + .sum::(); + assert_eq!(rows_num0, rows_num1); + assert!(rows_num0 >= 16); + assert!(rows_num0 <= 32); + } +} diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs new file mode 100644 index 000000000000..d021e73f35b2 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs @@ -0,0 +1,527 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashSet; +use std::sync::Arc; + +use arrow::util::pretty::pretty_format_batches; +use arrow_array::RecordBatch; +use datafusion_common::{DataFusionError, Result}; +use rand::{thread_rng, Rng}; +use tokio::task::JoinSet; + +use crate::fuzz_cases::aggregation_fuzzer::{ + check_equality_of_batches, + context_generator::{SessionContextGenerator, SessionContextWithParams}, + data_generator::{Dataset, DatasetGenerator, DatasetGeneratorConfig}, + run_sql, +}; + +/// Rounds to call `generate` of [`SessionContextGenerator`] +/// in [`AggregationFuzzer`], `ctx_gen_rounds` random [`SessionContext`] +/// will generated for each dataset for testing. +const CTX_GEN_ROUNDS: usize = 16; + +/// Aggregation fuzzer's builder +pub struct AggregationFuzzerBuilder { + /// See `candidate_sqls` in [`AggregationFuzzer`], no default, and required to set + candidate_sqls: Vec>, + + /// See `table_name` in [`AggregationFuzzer`], no default, and required to set + table_name: Option>, + + /// Used to generate `dataset_generator` in [`AggregationFuzzer`], + /// no default, and required to set + data_gen_config: Option, + + /// See `data_gen_rounds` in [`AggregationFuzzer`], default 16 + data_gen_rounds: usize, +} + +impl AggregationFuzzerBuilder { + fn new() -> Self { + Self { + candidate_sqls: Vec::new(), + table_name: None, + data_gen_config: None, + data_gen_rounds: 16, + } + } + + /// Adds random SQL queries to the fuzzer along with the table name + /// + /// Adds + /// - 3 random queries + /// - 3 random queries for each group by selected from the sort keys + /// - 1 random query with no grouping + pub fn add_query_builder(mut self, mut query_builder: QueryBuilder) -> Self { + const NUM_QUERIES: usize = 3; + for _ in 0..NUM_QUERIES { + let sql = query_builder.generate_query(); + self.candidate_sqls.push(Arc::from(sql)); + } + // also add several queries limited to grouping on the group by columns only, if any + // So if the data is sorted on `a,b` only group by `a,b` or`a` or `b` + if let Some(data_gen_config) = &self.data_gen_config { + for sort_keys in &data_gen_config.sort_keys_set { + let group_by_columns = sort_keys.iter().map(|s| s.as_str()); + query_builder = query_builder.set_group_by_columns(group_by_columns); + for _ in 0..NUM_QUERIES { + let sql = query_builder.generate_query(); + self.candidate_sqls.push(Arc::from(sql)); + } + } + } + // also add a query with no grouping + query_builder = query_builder.set_group_by_columns(vec![]); + let sql = query_builder.generate_query(); + self.candidate_sqls.push(Arc::from(sql)); + + self.table_name(query_builder.table_name()) + } + + pub fn table_name(mut self, table_name: &str) -> Self { + self.table_name = Some(Arc::from(table_name)); + self + } + + pub fn data_gen_config(mut self, data_gen_config: DatasetGeneratorConfig) -> Self { + self.data_gen_config = Some(data_gen_config); + self + } + + pub fn build(self) -> AggregationFuzzer { + assert!(!self.candidate_sqls.is_empty()); + let candidate_sqls = self.candidate_sqls; + let table_name = self.table_name.expect("table_name is required"); + let data_gen_config = self.data_gen_config.expect("data_gen_config is required"); + let data_gen_rounds = self.data_gen_rounds; + + let dataset_generator = DatasetGenerator::new(data_gen_config); + + AggregationFuzzer { + candidate_sqls, + table_name, + dataset_generator, + data_gen_rounds, + } + } +} + +impl Default for AggregationFuzzerBuilder { + fn default() -> Self { + Self::new() + } +} + +impl From for AggregationFuzzerBuilder { + fn from(value: DatasetGeneratorConfig) -> Self { + Self::default().data_gen_config(value) + } +} + +/// AggregationFuzzer randomly generating multiple [`AggregationFuzzTestTask`], +/// and running them to check the correctness of the optimizations +/// (e.g. sorted, partial skipping, spilling...) +pub struct AggregationFuzzer { + /// Candidate test queries represented by sqls + candidate_sqls: Vec>, + + /// The queried table name + table_name: Arc, + + /// Dataset generator used to randomly generate datasets + dataset_generator: DatasetGenerator, + + /// Rounds to call `generate` of [`DatasetGenerator`], + /// len(sort_keys_set) + 1` datasets will be generated for testing. + /// + /// It is suggested to set value 2x or more bigger than num of + /// `candidate_sqls` for better test coverage. + data_gen_rounds: usize, +} + +/// Query group including the tested dataset and its sql query +struct QueryGroup { + dataset: Dataset, + sql: Arc, +} + +impl AggregationFuzzer { + /// Run the fuzzer, printing an error and panicking if any of the tasks fail + pub async fn run(&self) { + let res = self.run_inner().await; + + if let Err(e) = res { + // Print the error via `Display` so that it displays nicely (the default `unwrap()` + // prints using `Debug` which escapes newlines, and makes multi-line messages + // hard to read + println!("{e}"); + panic!("Error!"); + } + } + + async fn run_inner(&self) -> Result<()> { + let mut join_set = JoinSet::new(); + let mut rng = thread_rng(); + + // Loop to generate datasets and its query + for _ in 0..self.data_gen_rounds { + // Generate datasets first + let datasets = self + .dataset_generator + .generate() + .expect("should success to generate dataset"); + + // Then for each of them, we random select a test sql for it + let query_groups = datasets + .into_iter() + .map(|dataset| { + let sql_idx = rng.gen_range(0..self.candidate_sqls.len()); + let sql = self.candidate_sqls[sql_idx].clone(); + + QueryGroup { dataset, sql } + }) + .collect::>(); + + for q in &query_groups { + println!(" Testing with query {}", q.sql); + } + + let tasks = self.generate_fuzz_tasks(query_groups).await; + for task in tasks { + join_set.spawn(async move { task.run().await }); + } + } + + while let Some(join_handle) = join_set.join_next().await { + // propagate errors + join_handle.map_err(|e| { + DataFusionError::Internal(format!( + "AggregationFuzzer task error: {:?}", + e + )) + })??; + } + Ok(()) + } + + async fn generate_fuzz_tasks( + &self, + query_groups: Vec, + ) -> Vec { + let mut tasks = Vec::with_capacity(query_groups.len() * CTX_GEN_ROUNDS); + for QueryGroup { dataset, sql } in query_groups { + let dataset_ref = Arc::new(dataset); + let ctx_generator = + SessionContextGenerator::new(dataset_ref.clone(), &self.table_name); + + // Generate the baseline context, and get the baseline result firstly + let baseline_ctx_with_params = ctx_generator + .generate_baseline() + .expect("should success to generate baseline session context"); + let baseline_result = run_sql(&sql, &baseline_ctx_with_params.ctx) + .await + .expect("should success to run baseline sql"); + let baseline_result = Arc::new(baseline_result); + // Generate test tasks + for _ in 0..CTX_GEN_ROUNDS { + let ctx_with_params = ctx_generator + .generate() + .expect("should success to generate session context"); + let task = AggregationFuzzTestTask { + dataset_ref: dataset_ref.clone(), + expected_result: baseline_result.clone(), + sql: sql.clone(), + ctx_with_params, + }; + + tasks.push(task); + } + } + tasks + } +} + +/// One test task generated by [`AggregationFuzzer`] +/// +/// It includes: +/// - `expected_result`, the expected result generated by baseline [`SessionContext`] +/// (disable all possible optimizations for ensuring correctness). +/// +/// - `ctx`, a randomly generated [`SessionContext`], `sql` will be run +/// on it after, and check if the result is equal to expected. +/// +/// - `sql`, the selected test sql +/// +/// - `dataset_ref`, the input dataset, store it for error reported when found +/// the inconsistency between the one for `ctx` and `expected results`. +/// +struct AggregationFuzzTestTask { + /// Generated session context in current test case + ctx_with_params: SessionContextWithParams, + + /// Expected result in current test case + /// It is generate from `query` + `baseline session context` + expected_result: Arc>, + + /// The test query + /// Use sql to represent it currently. + sql: Arc, + + /// The test dataset for error reporting + dataset_ref: Arc, +} + +impl AggregationFuzzTestTask { + async fn run(&self) -> Result<()> { + let task_result = run_sql(&self.sql, &self.ctx_with_params.ctx) + .await + .map_err(|e| e.context(self.context_error_report()))?; + self.check_result(&task_result, &self.expected_result) + } + + fn check_result( + &self, + task_result: &[RecordBatch], + expected_result: &[RecordBatch], + ) -> Result<()> { + check_equality_of_batches(task_result, expected_result).map_err(|e| { + // If we found inconsistent result, we print the test details for reproducing at first + let message = format!( + "##### AggregationFuzzer error report #####\n\ + ### Sql:\n{}\n\ + ### Schema:\n{}\n\ + ### Session context params:\n{:?}\n\ + ### Inconsistent row:\n\ + - row_idx:{}\n\ + - task_row:{}\n\ + - expected_row:{}\n\ + ### Task total result:\n{}\n\ + ### Expected total result:\n{}\n\ + ### Input:\n{}\n\ + ", + self.sql, + self.dataset_ref.batches[0].schema_ref(), + self.ctx_with_params.params, + e.row_idx, + e.lhs_row, + e.rhs_row, + format_batches_with_limit(task_result), + format_batches_with_limit(expected_result), + format_batches_with_limit(&self.dataset_ref.batches), + ); + DataFusionError::Internal(message) + }) + } + + /// Returns a formatted error message + fn context_error_report(&self) -> String { + format!( + "##### AggregationFuzzer error report #####\n\ + ### Sql:\n{}\n\ + ### Schema:\n{}\n\ + ### Session context params:\n{:?}\n\ + ### Input:\n{}\n\ + ", + self.sql, + self.dataset_ref.batches[0].schema_ref(), + self.ctx_with_params.params, + pretty_format_batches(&self.dataset_ref.batches).unwrap(), + ) + } +} + +/// Pretty prints the `RecordBatch`es, limited to the first 100 rows +fn format_batches_with_limit(batches: &[RecordBatch]) -> impl std::fmt::Display { + const MAX_ROWS: usize = 100; + let mut row_count = 0; + let to_print = batches + .iter() + .filter_map(|b| { + if row_count >= MAX_ROWS { + None + } else if row_count + b.num_rows() > MAX_ROWS { + // output last rows before limit + let slice_len = MAX_ROWS - row_count; + let b = b.slice(0, slice_len); + row_count += slice_len; + Some(b) + } else { + row_count += b.num_rows(); + Some(b.clone()) + } + }) + .collect::>(); + + pretty_format_batches(&to_print).unwrap() +} + +/// Random aggregate query builder +/// +/// Creates queries like +/// ```sql +/// SELECT AGG(..) FROM table_name GROUP BY +///``` +#[derive(Debug, Default, Clone)] +pub struct QueryBuilder { + /// The name of the table to query + table_name: String, + /// Aggregate functions to be used in the query + /// (function_name, is_distinct) + aggregate_functions: Vec<(String, bool)>, + /// Columns to be used in group by + group_by_columns: Vec, + /// Possible columns for arguments in the aggregate functions + /// + /// Assumes each + arguments: Vec, +} +impl QueryBuilder { + pub fn new() -> Self { + Default::default() + } + + /// return the table name if any + pub fn table_name(&self) -> &str { + &self.table_name + } + + /// Set the table name for the query builder + pub fn with_table_name(mut self, table_name: impl Into) -> Self { + self.table_name = table_name.into(); + self + } + + /// Add a new possible aggregate function to the query builder + pub fn with_aggregate_function( + mut self, + aggregate_function: impl Into, + ) -> Self { + self.aggregate_functions + .push((aggregate_function.into(), false)); + self + } + + /// Add a new possible `DISTINCT` aggregate function to the query + /// + /// This is different than `with_aggregate_function` because only certain + /// aggregates support `DISTINCT` + pub fn with_distinct_aggregate_function( + mut self, + aggregate_function: impl Into, + ) -> Self { + self.aggregate_functions + .push((aggregate_function.into(), true)); + self + } + + /// Set the columns to be used in the group bys clauses + pub fn set_group_by_columns<'a>( + mut self, + group_by: impl IntoIterator, + ) -> Self { + self.group_by_columns = group_by.into_iter().map(String::from).collect(); + self + } + + /// Add one or more columns to be used as an argument in the aggregate functions + pub fn with_aggregate_arguments<'a>( + mut self, + arguments: impl IntoIterator, + ) -> Self { + let arguments = arguments.into_iter().map(String::from); + self.arguments.extend(arguments); + self + } + + pub fn generate_query(&self) -> String { + let group_by = self.random_group_by(); + let mut query = String::from("SELECT "); + query.push_str(&self.random_aggregate_functions().join(", ")); + query.push_str(" FROM "); + query.push_str(&self.table_name); + if !group_by.is_empty() { + query.push_str(" GROUP BY "); + query.push_str(&group_by.join(", ")); + } + query + } + + /// Generate a some random aggregate function invocations (potentially repeating). + /// + /// Each aggregate function invocation is of the form + /// + /// ```sql + /// function_name( argument) as alias + /// ``` + /// + /// where + /// * `function_names` are randomly selected from [`Self::aggregate_functions`] + /// * ` argument` is randomly selected from [`Self::arguments`] + /// * `alias` is a unique alias `colN` for the column (to avoid duplicate column names) + fn random_aggregate_functions(&self) -> Vec { + const MAX_NUM_FUNCTIONS: usize = 5; + let mut rng = thread_rng(); + let num_aggregate_functions = rng.gen_range(1..MAX_NUM_FUNCTIONS); + + let mut alias_gen = 1; + + let mut aggregate_functions = vec![]; + while aggregate_functions.len() < num_aggregate_functions { + let idx = rng.gen_range(0..self.aggregate_functions.len()); + let (function_name, is_distinct) = &self.aggregate_functions[idx]; + let argument = self.random_argument(); + let alias = format!("col{}", alias_gen); + let distinct = if *is_distinct { "DISTINCT " } else { "" }; + alias_gen += 1; + let function = format!("{function_name}({distinct}{argument}) as {alias}"); + aggregate_functions.push(function); + } + aggregate_functions + } + + /// Pick a random aggregate function argument + fn random_argument(&self) -> String { + let mut rng = thread_rng(); + let idx = rng.gen_range(0..self.arguments.len()); + self.arguments[idx].clone() + } + + /// Pick a random number of fields to group by (non-repeating) + /// + /// Limited to 3 group by columns to ensure coverage for large groups. With + /// larger numbers of columns, each group has many fewer values. + fn random_group_by(&self) -> Vec { + let mut rng = thread_rng(); + const MAX_GROUPS: usize = 3; + let max_groups = self.group_by_columns.len().max(MAX_GROUPS); + let num_group_by = rng.gen_range(1..max_groups); + + let mut already_used = HashSet::new(); + let mut group_by = vec![]; + while group_by.len() < num_group_by + && already_used.len() != self.group_by_columns.len() + { + let idx = rng.gen_range(0..self.group_by_columns.len()); + if already_used.insert(idx) { + group_by.push(self.group_by_columns[idx].clone()); + } + } + group_by + } +} diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs new file mode 100644 index 000000000000..d93a5b7b9360 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::util::pretty::pretty_format_batches; +use arrow_array::RecordBatch; +use datafusion::prelude::SessionContext; +use datafusion_common::error::Result; + +mod context_generator; +mod data_generator; +mod fuzzer; + +pub use data_generator::{ColumnDescr, DatasetGeneratorConfig}; +pub use fuzzer::*; + +#[derive(Debug)] +pub(crate) struct InconsistentResult { + pub row_idx: usize, + pub lhs_row: String, + pub rhs_row: String, +} + +pub(crate) fn check_equality_of_batches( + lhs: &[RecordBatch], + rhs: &[RecordBatch], +) -> std::result::Result<(), InconsistentResult> { + let lhs_formatted_batches = pretty_format_batches(lhs).unwrap().to_string(); + let mut lhs_formatted_batches_sorted: Vec<&str> = + lhs_formatted_batches.trim().lines().collect(); + lhs_formatted_batches_sorted.sort_unstable(); + let rhs_formatted_batches = pretty_format_batches(rhs).unwrap().to_string(); + let mut rhs_formatted_batches_sorted: Vec<&str> = + rhs_formatted_batches.trim().lines().collect(); + rhs_formatted_batches_sorted.sort_unstable(); + + for (row_idx, (lhs_row, rhs_row)) in lhs_formatted_batches_sorted + .iter() + .zip(&rhs_formatted_batches_sorted) + .enumerate() + { + if lhs_row != rhs_row { + return Err(InconsistentResult { + row_idx, + lhs_row: lhs_row.to_string(), + rhs_row: rhs_row.to_string(), + }); + } + } + + Ok(()) +} + +pub(crate) async fn run_sql(sql: &str, ctx: &SessionContext) -> Result> { + ctx.sql(sql).await?.collect().await +} diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs b/datafusion/core/tests/fuzz_cases/equivalence/mod.rs similarity index 83% rename from datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs rename to datafusion/core/tests/fuzz_cases/equivalence/mod.rs index de090badd349..2f8a38200bf1 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs +++ b/datafusion/core/tests/fuzz_cases/equivalence/mod.rs @@ -15,10 +15,9 @@ // specific language governing permissions and limitations // under the License. -pub(crate) mod accumulate; -mod adapter; -pub use accumulate::NullState; -pub use adapter::GroupsAccumulatorAdapter; +//! `EquivalenceProperties` fuzz testing -pub(crate) mod bool_op; -pub(crate) mod prim_op; +mod ordering; +mod projection; +mod properties; +mod utils; diff --git a/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs new file mode 100644 index 000000000000..525baadd14a5 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/equivalence/ordering.rs @@ -0,0 +1,395 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::fuzz_cases::equivalence::utils::{ + convert_to_orderings, create_random_schema, create_test_params, create_test_schema_2, + generate_table_for_eq_properties, generate_table_for_orderings, + is_table_same_after_sort, TestScalarUDF, +}; +use arrow_schema::SortOptions; +use datafusion_common::{DFSchema, Result}; +use datafusion_expr::{Operator, ScalarUDF}; +use datafusion_physical_expr::expressions::{col, BinaryExpr}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use itertools::Itertools; +use std::sync::Arc; + +#[test] +fn test_ordering_satisfy_with_equivalence_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 5; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + const SORT_OPTIONS: SortOptions = SortOptions { + descending: false, + nulls_first: false, + }; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + let col_exprs = [ + col("a", &test_schema)?, + col("b", &test_schema)?, + col("c", &test_schema)?, + col("d", &test_schema)?, + col("e", &test_schema)?, + col("f", &test_schema)?, + ]; + + for n_req in 0..=col_exprs.len() { + for exprs in col_exprs.iter().combinations(n_req) { + let requirement = exprs + .into_iter() + .map(|expr| PhysicalSortExpr { + expr: Arc::clone(expr), + options: SORT_OPTIONS, + }) + .collect::(); + let expected = is_table_same_after_sort( + requirement.clone(), + table_data_with_properties.clone(), + )?; + let err_msg = format!( + "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", + requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants + ); + // Check whether ordering_satisfy API result and + // experimental result matches. + assert_eq!( + eq_properties.ordering_satisfy(requirement.as_ref()), + expected, + "{}", + err_msg + ); + } + } + } + + Ok(()) +} + +#[test] +fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 100; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + const SORT_OPTIONS: SortOptions = SortOptions { + descending: false, + nulls_first: false, + }; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + + let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); + let floor_a = datafusion_physical_expr::udf::create_physical_expr( + &test_fun, + &[col("a", &test_schema)?], + &test_schema, + &[], + &DFSchema::empty(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let exprs = [ + col("a", &test_schema)?, + col("b", &test_schema)?, + col("c", &test_schema)?, + col("d", &test_schema)?, + col("e", &test_schema)?, + col("f", &test_schema)?, + floor_a, + a_plus_b, + ]; + + for n_req in 0..=exprs.len() { + for exprs in exprs.iter().combinations(n_req) { + let requirement = exprs + .into_iter() + .map(|expr| PhysicalSortExpr { + expr: Arc::clone(expr), + options: SORT_OPTIONS, + }) + .collect::(); + let expected = is_table_same_after_sort( + requirement.clone(), + table_data_with_properties.clone(), + )?; + let err_msg = format!( + "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", + requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants + ); + // Check whether ordering_satisfy API result and + // experimental result matches. + + assert_eq!( + eq_properties.ordering_satisfy(requirement.as_ref()), + (expected | false), + "{}", + err_msg + ); + } + } + } + + Ok(()) +} + +#[test] +fn test_ordering_satisfy_with_equivalence() -> Result<()> { + // Schema satisfies following orderings: + // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] + // and + // Column [a=c] (e.g they are aliases). + let (test_schema, eq_properties) = create_test_params()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_g = &col("g", &test_schema)?; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, 625, 5)?; + + // First element in the tuple stores vector of requirement, second element is the expected return value for ordering_satisfy function + let requirements = vec![ + // `a ASC NULLS LAST`, expects `ordering_satisfy` to be `true`, since existing ordering `a ASC NULLS LAST, b ASC NULLS LAST` satisfies it + (vec![(col_a, option_asc)], true), + (vec![(col_a, option_desc)], false), + // Test whether equivalence works as expected + (vec![(col_c, option_asc)], true), + (vec![(col_c, option_desc)], false), + // Test whether ordering equivalence works as expected + (vec![(col_d, option_asc)], true), + (vec![(col_d, option_asc), (col_b, option_asc)], true), + (vec![(col_d, option_desc), (col_b, option_asc)], false), + ( + vec![ + (col_e, option_desc), + (col_f, option_asc), + (col_g, option_asc), + ], + true, + ), + (vec![(col_e, option_desc), (col_f, option_asc)], true), + (vec![(col_e, option_asc), (col_f, option_asc)], false), + (vec![(col_e, option_desc), (col_b, option_asc)], false), + (vec![(col_e, option_asc), (col_b, option_asc)], false), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_d, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_desc), + (col_f, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_desc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_d, option_desc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_asc), + (col_f, option_asc), + ], + false, + ), + ( + vec![ + (col_d, option_asc), + (col_b, option_asc), + (col_e, option_asc), + (col_b, option_asc), + ], + false, + ), + (vec![(col_d, option_asc), (col_e, option_desc)], true), + ( + vec![ + (col_d, option_asc), + (col_c, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_f, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_c, option_asc), + (col_b, option_asc), + ], + true, + ), + ( + vec![ + (col_d, option_asc), + (col_e, option_desc), + (col_b, option_asc), + (col_f, option_asc), + ], + true, + ), + ]; + + for (cols, expected) in requirements { + let err_msg = format!("Error in test case:{cols:?}"); + let required = cols + .into_iter() + .map(|(expr, options)| PhysicalSortExpr { + expr: Arc::clone(expr), + options, + }) + .collect::(); + + // Check expected result with experimental result. + assert_eq!( + is_table_same_after_sort( + required.clone(), + table_data_with_properties.clone() + )?, + expected + ); + assert_eq!( + eq_properties.ordering_satisfy(required.as_ref()), + expected, + "{err_msg}" + ); + } + + Ok(()) +} + +// This test checks given a table is ordered with `[a ASC, b ASC, c ASC, d ASC]` and `[a ASC, c ASC, b ASC, d ASC]` +// whether the table is also ordered with `[a ASC, b ASC, d ASC]` and `[a ASC, c ASC, d ASC]` +// Since these orderings cannot be deduced, these orderings shouldn't be satisfied by the table generated. +// For background see discussion: https://github.com/apache/datafusion/issues/12700#issuecomment-2411134296 +#[test] +fn test_ordering_satisfy_on_data() -> Result<()> { + let schema = create_test_schema_2()?; + let col_a = &col("a", &schema)?; + let col_b = &col("b", &schema)?; + let col_c = &col("c", &schema)?; + let col_d = &col("d", &schema)?; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + let orderings = vec![ + // [a ASC, b ASC, c ASC, d ASC] + vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_c, option_asc), + (col_d, option_asc), + ], + // [a ASC, c ASC, b ASC, d ASC] + vec![ + (col_a, option_asc), + (col_c, option_asc), + (col_b, option_asc), + (col_d, option_asc), + ], + ]; + let orderings = convert_to_orderings(&orderings); + + let batch = generate_table_for_orderings(orderings, schema, 1000, 10)?; + + // [a ASC, c ASC, d ASC] cannot be deduced + let ordering = vec![ + (col_a, option_asc), + (col_c, option_asc), + (col_d, option_asc), + ]; + let ordering = convert_to_orderings(&[ordering])[0].clone(); + assert!(!is_table_same_after_sort(ordering, batch.clone())?); + + // [a ASC, b ASC, d ASC] cannot be deduced + let ordering = vec![ + (col_a, option_asc), + (col_b, option_asc), + (col_d, option_asc), + ]; + let ordering = convert_to_orderings(&[ordering])[0].clone(); + assert!(!is_table_same_after_sort(ordering, batch.clone())?); + + // [a ASC, b ASC] can be deduced + let ordering = vec![(col_a, option_asc), (col_b, option_asc)]; + let ordering = convert_to_orderings(&[ordering])[0].clone(); + assert!(is_table_same_after_sort(ordering, batch.clone())?); + + Ok(()) +} diff --git a/datafusion/core/tests/fuzz_cases/equivalence/projection.rs b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs new file mode 100644 index 000000000000..3df3e0348e42 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/equivalence/projection.rs @@ -0,0 +1,200 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::fuzz_cases::equivalence::utils::{ + apply_projection, create_random_schema, generate_table_for_eq_properties, + is_table_same_after_sort, TestScalarUDF, +}; +use arrow_schema::SortOptions; +use datafusion_common::{DFSchema, Result}; +use datafusion_expr::{Operator, ScalarUDF}; +use datafusion_physical_expr::equivalence::ProjectionMapping; +use datafusion_physical_expr::expressions::{col, BinaryExpr}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use itertools::Itertools; +use std::sync::Arc; + +#[test] +fn project_orderings_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 20; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + // Floor(a) + let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); + let floor_a = datafusion_physical_expr::udf::create_physical_expr( + &test_fun, + &[col("a", &test_schema)?], + &test_schema, + &[], + &DFSchema::empty(), + )?; + // a + b + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let proj_exprs = vec![ + (col("a", &test_schema)?, "a_new"), + (col("b", &test_schema)?, "b_new"), + (col("c", &test_schema)?, "c_new"), + (col("d", &test_schema)?, "d_new"), + (col("e", &test_schema)?, "e_new"), + (col("f", &test_schema)?, "f_new"), + (floor_a, "floor(a)"), + (a_plus_b, "a+b"), + ]; + + for n_req in 0..=proj_exprs.len() { + for proj_exprs in proj_exprs.iter().combinations(n_req) { + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (Arc::clone(expr), name.to_string())) + .collect::>(); + let (projected_batch, projected_eq) = apply_projection( + proj_exprs.clone(), + &table_data_with_properties, + &eq_properties, + )?; + + // Make sure each ordering after projection is valid. + for ordering in projected_eq.oeq_class().iter() { + let err_msg = format!( + "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, proj_exprs: {:?}", + ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, proj_exprs + ); + // Since ordered section satisfies schema, we expect + // that result will be same after sort (e.g sort was unnecessary). + assert!( + is_table_same_after_sort( + ordering.clone(), + projected_batch.clone(), + )?, + "{}", + err_msg + ); + } + } + } + } + + Ok(()) +} + +#[test] +fn ordering_satisfy_after_projection_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 20; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + const SORT_OPTIONS: SortOptions = SortOptions { + descending: false, + nulls_first: false, + }; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + // Floor(a) + let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); + let floor_a = datafusion_physical_expr::udf::create_physical_expr( + &test_fun, + &[col("a", &test_schema)?], + &test_schema, + &[], + &DFSchema::empty(), + )?; + // a + b + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let proj_exprs = vec![ + (col("a", &test_schema)?, "a_new"), + (col("b", &test_schema)?, "b_new"), + (col("c", &test_schema)?, "c_new"), + (col("d", &test_schema)?, "d_new"), + (col("e", &test_schema)?, "e_new"), + (col("f", &test_schema)?, "f_new"), + (floor_a, "floor(a)"), + (a_plus_b, "a+b"), + ]; + + for n_req in 0..=proj_exprs.len() { + for proj_exprs in proj_exprs.iter().combinations(n_req) { + let proj_exprs = proj_exprs + .into_iter() + .map(|(expr, name)| (Arc::clone(expr), name.to_string())) + .collect::>(); + let (projected_batch, projected_eq) = apply_projection( + proj_exprs.clone(), + &table_data_with_properties, + &eq_properties, + )?; + + let projection_mapping = + ProjectionMapping::try_new(&proj_exprs, &test_schema)?; + + let projected_exprs = projection_mapping + .iter() + .map(|(_source, target)| Arc::clone(target)) + .collect::>(); + + for n_req in 0..=projected_exprs.len() { + for exprs in projected_exprs.iter().combinations(n_req) { + let requirement = exprs + .into_iter() + .map(|expr| PhysicalSortExpr { + expr: Arc::clone(expr), + options: SORT_OPTIONS, + }) + .collect::(); + let expected = is_table_same_after_sort( + requirement.clone(), + projected_batch.clone(), + )?; + let err_msg = format!( + "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, projected_eq.oeq_class: {:?}, projected_eq.eq_group: {:?}, projected_eq.constants: {:?}, projection_mapping: {:?}", + requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, projected_eq.oeq_class, projected_eq.eq_group, projected_eq.constants, projection_mapping + ); + // Check whether ordering_satisfy API result and + // experimental result matches. + assert_eq!( + projected_eq.ordering_satisfy(requirement.as_ref()), + expected, + "{}", + err_msg + ); + } + } + } + } + } + + Ok(()) +} diff --git a/datafusion/core/tests/fuzz_cases/equivalence/properties.rs b/datafusion/core/tests/fuzz_cases/equivalence/properties.rs new file mode 100644 index 000000000000..82586bd79eda --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/equivalence/properties.rs @@ -0,0 +1,105 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::fuzz_cases::equivalence::utils::{ + create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort, + TestScalarUDF, +}; +use datafusion_common::{DFSchema, Result}; +use datafusion_expr::{Operator, ScalarUDF}; +use datafusion_physical_expr::expressions::{col, BinaryExpr}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use itertools::Itertools; +use std::sync::Arc; + +#[test] +fn test_find_longest_permutation_random() -> Result<()> { + const N_RANDOM_SCHEMA: usize = 100; + const N_ELEMENTS: usize = 125; + const N_DISTINCT: usize = 5; + + for seed in 0..N_RANDOM_SCHEMA { + // Create a random schema with random properties + let (test_schema, eq_properties) = create_random_schema(seed as u64)?; + // Generate a data that satisfies properties given + let table_data_with_properties = + generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; + + let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); + let floor_a = datafusion_physical_expr::udf::create_physical_expr( + &test_fun, + &[col("a", &test_schema)?], + &test_schema, + &[], + &DFSchema::empty(), + )?; + let a_plus_b = Arc::new(BinaryExpr::new( + col("a", &test_schema)?, + Operator::Plus, + col("b", &test_schema)?, + )) as Arc; + let exprs = [ + col("a", &test_schema)?, + col("b", &test_schema)?, + col("c", &test_schema)?, + col("d", &test_schema)?, + col("e", &test_schema)?, + col("f", &test_schema)?, + floor_a, + a_plus_b, + ]; + + for n_req in 0..=exprs.len() { + for exprs in exprs.iter().combinations(n_req) { + let exprs = exprs.into_iter().cloned().collect::>(); + let (ordering, indices) = eq_properties.find_longest_permutation(&exprs); + // Make sure that find_longest_permutation return values are consistent + let ordering2 = indices + .iter() + .zip(ordering.iter()) + .map(|(&idx, sort_expr)| PhysicalSortExpr { + expr: Arc::clone(&exprs[idx]), + options: sort_expr.options, + }) + .collect::(); + assert_eq!( + ordering, ordering2, + "indices and lexicographical ordering do not match" + ); + + let err_msg = format!( + "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", + ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants + ); + assert_eq!(ordering.len(), indices.len(), "{}", err_msg); + // Since ordered section satisfies schema, we expect + // that result will be same after sort (e.g sort was unnecessary). + assert!( + is_table_same_after_sort( + ordering.clone(), + table_data_with_properties.clone(), + )?, + "{}", + err_msg + ); + } + } + } + + Ok(()) +} diff --git a/datafusion/core/tests/fuzz_cases/equivalence/utils.rs b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs new file mode 100644 index 000000000000..35da8b596380 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/equivalence/utils.rs @@ -0,0 +1,627 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// +use datafusion::physical_plan::expressions::col; +use datafusion::physical_plan::expressions::Column; +use datafusion_physical_expr::{ConstExpr, EquivalenceProperties, PhysicalSortExpr}; +use std::any::Any; +use std::cmp::Ordering; +use std::sync::Arc; + +use arrow::compute::{lexsort_to_indices, take_record_batch, SortColumn}; +use arrow::datatypes::{DataType, Field, Schema}; +use arrow_array::{ArrayRef, Float32Array, Float64Array, RecordBatch, UInt32Array}; +use arrow_schema::{SchemaRef, SortOptions}; +use datafusion_common::utils::{compare_rows, get_row_at_idx}; +use datafusion_common::{exec_err, plan_datafusion_err, DataFusionError, Result}; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_physical_expr::equivalence::{EquivalenceClass, ProjectionMapping}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexOrderingRef}; + +use itertools::izip; +use rand::prelude::*; + +pub fn output_schema( + mapping: &ProjectionMapping, + input_schema: &Arc, +) -> Result { + // Calculate output schema + let fields: Result> = mapping + .iter() + .map(|(source, target)| { + let name = target + .as_any() + .downcast_ref::() + .ok_or_else(|| plan_datafusion_err!("Expects to have column"))? + .name(); + let field = Field::new( + name, + source.data_type(input_schema)?, + source.nullable(input_schema)?, + ); + + Ok(field) + }) + .collect(); + + let output_schema = Arc::new(Schema::new_with_metadata( + fields?, + input_schema.metadata().clone(), + )); + + Ok(output_schema) +} + +// Generate a schema which consists of 6 columns (a, b, c, d, e, f) +pub fn create_test_schema_2() -> Result { + let a = Field::new("a", DataType::Float64, true); + let b = Field::new("b", DataType::Float64, true); + let c = Field::new("c", DataType::Float64, true); + let d = Field::new("d", DataType::Float64, true); + let e = Field::new("e", DataType::Float64, true); + let f = Field::new("f", DataType::Float64, true); + let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f])); + + Ok(schema) +} + +/// Construct a schema with random ordering +/// among column a, b, c, d +/// where +/// Column [a=f] (e.g they are aliases). +/// Column e is constant. +pub fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperties)> { + let test_schema = create_test_schema_2()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_exprs = [col_a, col_b, col_c, col_d, col_e, col_f]; + + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); + // Define a and f are aliases + eq_properties.add_equal_conditions(col_a, col_f)?; + // Column e has constant value. + eq_properties = eq_properties.with_constants([ConstExpr::from(col_e)]); + + // Randomly order columns for sorting + let mut rng = StdRng::seed_from_u64(seed); + let mut remaining_exprs = col_exprs[0..4].to_vec(); // only a, b, c, d are sorted + + let options_asc = SortOptions { + descending: false, + nulls_first: false, + }; + + while !remaining_exprs.is_empty() { + let n_sort_expr = rng.gen_range(0..remaining_exprs.len() + 1); + remaining_exprs.shuffle(&mut rng); + + let ordering = remaining_exprs + .drain(0..n_sort_expr) + .map(|expr| PhysicalSortExpr { + expr: Arc::clone(expr), + options: options_asc, + }) + .collect(); + + eq_properties.add_new_orderings([ordering]); + } + + Ok((test_schema, eq_properties)) +} + +// Apply projection to the input_data, return projected equivalence properties and record batch +pub fn apply_projection( + proj_exprs: Vec<(Arc, String)>, + input_data: &RecordBatch, + input_eq_properties: &EquivalenceProperties, +) -> Result<(RecordBatch, EquivalenceProperties)> { + let input_schema = input_data.schema(); + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + + let output_schema = output_schema(&projection_mapping, &input_schema)?; + let num_rows = input_data.num_rows(); + // Apply projection to the input record batch. + let projected_values = projection_mapping + .iter() + .map(|(source, _target)| source.evaluate(input_data)?.into_array(num_rows)) + .collect::>>()?; + let projected_batch = if projected_values.is_empty() { + RecordBatch::new_empty(Arc::clone(&output_schema)) + } else { + RecordBatch::try_new(Arc::clone(&output_schema), projected_values)? + }; + + let projected_eq = input_eq_properties.project(&projection_mapping, output_schema); + Ok((projected_batch, projected_eq)) +} + +#[test] +fn add_equal_conditions_test() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + Field::new("c", DataType::Int64, true), + Field::new("x", DataType::Int64, true), + Field::new("y", DataType::Int64, true), + ])); + + let mut eq_properties = EquivalenceProperties::new(schema); + let col_a_expr = Arc::new(Column::new("a", 0)) as Arc; + let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; + let col_c_expr = Arc::new(Column::new("c", 2)) as Arc; + let col_x_expr = Arc::new(Column::new("x", 3)) as Arc; + let col_y_expr = Arc::new(Column::new("y", 4)) as Arc; + + // a and b are aliases + eq_properties.add_equal_conditions(&col_a_expr, &col_b_expr)?; + assert_eq!(eq_properties.eq_group().len(), 1); + + // This new entry is redundant, size shouldn't increase + eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr)?; + assert_eq!(eq_properties.eq_group().len(), 1); + let eq_groups = &eq_properties.eq_group().classes[0]; + assert_eq!(eq_groups.len(), 2); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + + // b and c are aliases. Exising equivalence class should expand, + // however there shouldn't be any new equivalence class + eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr)?; + assert_eq!(eq_properties.eq_group().len(), 1); + let eq_groups = &eq_properties.eq_group().classes[0]; + assert_eq!(eq_groups.len(), 3); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + assert!(eq_groups.contains(&col_c_expr)); + + // This is a new set of equality. Hence equivalent class count should be 2. + eq_properties.add_equal_conditions(&col_x_expr, &col_y_expr)?; + assert_eq!(eq_properties.eq_group().len(), 2); + + // This equality bridges distinct equality sets. + // Hence equivalent class count should decrease from 2 to 1. + eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr)?; + assert_eq!(eq_properties.eq_group().len(), 1); + let eq_groups = &eq_properties.eq_group().classes[0]; + assert_eq!(eq_groups.len(), 5); + assert!(eq_groups.contains(&col_a_expr)); + assert!(eq_groups.contains(&col_b_expr)); + assert!(eq_groups.contains(&col_c_expr)); + assert!(eq_groups.contains(&col_x_expr)); + assert!(eq_groups.contains(&col_y_expr)); + + Ok(()) +} + +/// Checks if the table (RecordBatch) remains unchanged when sorted according to the provided `required_ordering`. +/// +/// The function works by adding a unique column of ascending integers to the original table. This column ensures +/// that rows that are otherwise indistinguishable (e.g., if they have the same values in all other columns) can +/// still be differentiated. When sorting the extended table, the unique column acts as a tie-breaker to produce +/// deterministic sorting results. +/// +/// If the table remains the same after sorting with the added unique column, it indicates that the table was +/// already sorted according to `required_ordering` to begin with. +pub fn is_table_same_after_sort( + mut required_ordering: LexOrdering, + batch: RecordBatch, +) -> Result { + // Clone the original schema and columns + let original_schema = batch.schema(); + let mut columns = batch.columns().to_vec(); + + // Create a new unique column + let n_row = batch.num_rows(); + let vals: Vec = (0..n_row).collect::>(); + let vals: Vec = vals.into_iter().map(|val| val as f64).collect(); + let unique_col = Arc::new(Float64Array::from_iter_values(vals)) as ArrayRef; + columns.push(Arc::clone(&unique_col)); + + // Create a new schema with the added unique column + let unique_col_name = "unique"; + let unique_field = Arc::new(Field::new(unique_col_name, DataType::Float64, false)); + let fields: Vec<_> = original_schema + .fields() + .iter() + .cloned() + .chain(std::iter::once(unique_field)) + .collect(); + let schema = Arc::new(Schema::new(fields)); + + // Create a new batch with the added column + let new_batch = RecordBatch::try_new(Arc::clone(&schema), columns)?; + + // Add the unique column to the required ordering to ensure deterministic results + required_ordering.push(PhysicalSortExpr { + expr: Arc::new(Column::new(unique_col_name, original_schema.fields().len())), + options: Default::default(), + }); + + // Convert the required ordering to a list of SortColumn + let sort_columns = required_ordering + .iter() + .map(|order_expr| { + let expr_result = order_expr.expr.evaluate(&new_batch)?; + let values = expr_result.into_array(new_batch.num_rows())?; + Ok(SortColumn { + values, + options: Some(order_expr.options), + }) + }) + .collect::>>()?; + + // Check if the indices after sorting match the initial ordering + let sorted_indices = lexsort_to_indices(&sort_columns, None)?; + let original_indices = UInt32Array::from_iter_values(0..n_row as u32); + + Ok(sorted_indices == original_indices) +} + +// If we already generated a random result for one of the +// expressions in the equivalence classes. For other expressions in the same +// equivalence class use same result. This util gets already calculated result, when available. +fn get_representative_arr( + eq_group: &EquivalenceClass, + existing_vec: &[Option], + schema: SchemaRef, +) -> Option { + for expr in eq_group.iter() { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + if let Some(res) = &existing_vec[idx] { + return Some(Arc::clone(res)); + } + } + None +} + +// Generate a schema which consists of 8 columns (a, b, c, d, e, f, g, h) +pub fn create_test_schema() -> Result { + let a = Field::new("a", DataType::Int32, true); + let b = Field::new("b", DataType::Int32, true); + let c = Field::new("c", DataType::Int32, true); + let d = Field::new("d", DataType::Int32, true); + let e = Field::new("e", DataType::Int32, true); + let f = Field::new("f", DataType::Int32, true); + let g = Field::new("g", DataType::Int32, true); + let h = Field::new("h", DataType::Int32, true); + let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f, g, h])); + + Ok(schema) +} + +/// Construct a schema with following properties +/// Schema satisfies following orderings: +/// [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] +/// and +/// Column [a=c] (e.g they are aliases). +pub fn create_test_params() -> Result<(SchemaRef, EquivalenceProperties)> { + let test_schema = create_test_schema()?; + let col_a = &col("a", &test_schema)?; + let col_b = &col("b", &test_schema)?; + let col_c = &col("c", &test_schema)?; + let col_d = &col("d", &test_schema)?; + let col_e = &col("e", &test_schema)?; + let col_f = &col("f", &test_schema)?; + let col_g = &col("g", &test_schema)?; + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); + eq_properties.add_equal_conditions(col_a, col_c)?; + + let option_asc = SortOptions { + descending: false, + nulls_first: false, + }; + let option_desc = SortOptions { + descending: true, + nulls_first: true, + }; + let orderings = vec![ + // [a ASC] + vec![(col_a, option_asc)], + // [d ASC, b ASC] + vec![(col_d, option_asc), (col_b, option_asc)], + // [e DESC, f ASC, g ASC] + vec![ + (col_e, option_desc), + (col_f, option_asc), + (col_g, option_asc), + ], + ]; + let orderings = convert_to_orderings(&orderings); + eq_properties.add_new_orderings(orderings); + Ok((test_schema, eq_properties)) +} + +// Generate a table that satisfies the given equivalence properties; i.e. +// equivalences, ordering equivalences, and constants. +pub fn generate_table_for_eq_properties( + eq_properties: &EquivalenceProperties, + n_elem: usize, + n_distinct: usize, +) -> Result { + let mut rng = StdRng::seed_from_u64(23); + + let schema = eq_properties.schema(); + let mut schema_vec = vec![None; schema.fields.len()]; + + // Utility closure to generate random array + let mut generate_random_array = |num_elems: usize, max_val: usize| -> ArrayRef { + let values: Vec = (0..num_elems) + .map(|_| rng.gen_range(0..max_val) as f64 / 2.0) + .collect(); + Arc::new(Float64Array::from_iter_values(values)) + }; + + // Fill constant columns + for constant in &eq_properties.constants { + let col = constant.expr().as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + let arr = + Arc::new(Float64Array::from_iter_values(vec![0 as f64; n_elem])) as ArrayRef; + schema_vec[idx] = Some(arr); + } + + // Fill columns based on ordering equivalences + for ordering in eq_properties.oeq_class.iter() { + let (sort_columns, indices): (Vec<_>, Vec<_>) = ordering + .iter() + .map(|PhysicalSortExpr { expr, options }| { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + let arr = generate_random_array(n_elem, n_distinct); + ( + SortColumn { + values: arr, + options: Some(*options), + }, + idx, + ) + }) + .unzip(); + + let sort_arrs = arrow::compute::lexsort(&sort_columns, None)?; + for (idx, arr) in izip!(indices, sort_arrs) { + schema_vec[idx] = Some(arr); + } + } + + // Fill columns based on equivalence groups + for eq_group in eq_properties.eq_group.iter() { + let representative_array = + get_representative_arr(eq_group, &schema_vec, Arc::clone(schema)) + .unwrap_or_else(|| generate_random_array(n_elem, n_distinct)); + + for expr in eq_group.iter() { + let col = expr.as_any().downcast_ref::().unwrap(); + let (idx, _field) = schema.column_with_name(col.name()).unwrap(); + schema_vec[idx] = Some(Arc::clone(&representative_array)); + } + } + + let res: Vec<_> = schema_vec + .into_iter() + .zip(schema.fields.iter()) + .map(|(elem, field)| { + ( + field.name(), + // Generate random values for columns that do not occur in any of the groups (equivalence, ordering equivalence, constants) + elem.unwrap_or_else(|| generate_random_array(n_elem, n_distinct)), + ) + }) + .collect(); + + Ok(RecordBatch::try_from_iter(res)?) +} + +// Generate a table that satisfies the given orderings; +pub fn generate_table_for_orderings( + mut orderings: Vec, + schema: SchemaRef, + n_elem: usize, + n_distinct: usize, +) -> Result { + let mut rng = StdRng::seed_from_u64(23); + + assert!(!orderings.is_empty()); + // Sort the inner vectors by their lengths (longest first) + orderings.sort_by_key(|v| std::cmp::Reverse(v.inner.len())); + + let arrays = schema + .fields + .iter() + .map(|field| { + ( + field.name(), + generate_random_f64_array(n_elem, n_distinct, &mut rng), + ) + }) + .collect::>(); + let batch = RecordBatch::try_from_iter(arrays)?; + + // Sort batch according to first ordering expression + let sort_columns = get_sort_columns(&batch, orderings[0].as_ref())?; + let sort_indices = lexsort_to_indices(&sort_columns, None)?; + let mut batch = take_record_batch(&batch, &sort_indices)?; + + // prune out rows that is invalid according to remaining orderings. + for ordering in orderings.iter().skip(1) { + let sort_columns = get_sort_columns(&batch, ordering.as_ref())?; + + // Collect sort options and values into separate vectors. + let (sort_options, sort_col_values): (Vec<_>, Vec<_>) = sort_columns + .into_iter() + .map(|sort_col| (sort_col.options.unwrap(), sort_col.values)) + .unzip(); + + let mut cur_idx = 0; + let mut keep_indices = vec![cur_idx as u32]; + for next_idx in 1..batch.num_rows() { + let cur_row = get_row_at_idx(&sort_col_values, cur_idx)?; + let next_row = get_row_at_idx(&sort_col_values, next_idx)?; + + if compare_rows(&cur_row, &next_row, &sort_options)? != Ordering::Greater { + // next row satisfies ordering relation given, compared to the current row. + keep_indices.push(next_idx as u32); + cur_idx = next_idx; + } + } + // Only keep valid rows, that satisfies given ordering relation. + batch = take_record_batch(&batch, &UInt32Array::from_iter_values(keep_indices))?; + } + + Ok(batch) +} + +// Convert each tuple to PhysicalSortExpr +pub fn convert_to_sort_exprs( + in_data: &[(&Arc, SortOptions)], +) -> LexOrdering { + in_data + .iter() + .map(|(expr, options)| PhysicalSortExpr { + expr: Arc::clone(*expr), + options: *options, + }) + .collect() +} + +// Convert each inner tuple to PhysicalSortExpr +pub fn convert_to_orderings( + orderings: &[Vec<(&Arc, SortOptions)>], +) -> Vec { + orderings + .iter() + .map(|sort_exprs| convert_to_sort_exprs(sort_exprs)) + .collect() +} + +// Utility function to generate random f64 array +fn generate_random_f64_array( + n_elems: usize, + n_distinct: usize, + rng: &mut StdRng, +) -> ArrayRef { + let values: Vec = (0..n_elems) + .map(|_| rng.gen_range(0..n_distinct) as f64 / 2.0) + .collect(); + Arc::new(Float64Array::from_iter_values(values)) +} + +// Helper function to get sort columns from a batch +fn get_sort_columns( + batch: &RecordBatch, + ordering: LexOrderingRef, +) -> Result> { + ordering + .iter() + .map(|expr| expr.evaluate_to_sort_column(batch)) + .collect::>>() +} + +#[derive(Debug, Clone)] +pub struct TestScalarUDF { + pub(crate) signature: Signature, +} + +impl TestScalarUDF { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Float64, Float32], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for TestScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "test-scalar-udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let arg_type = &arg_types[0]; + + match arg_type { + DataType::Float32 => Ok(DataType::Float32), + _ => Ok(DataType::Float64), + } + } + + fn output_ordering(&self, input: &[ExprProperties]) -> Result { + Ok(input[0].sort_properties) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let args = ColumnarValue::values_to_arrays(args)?; + + let arr: ArrayRef = match args[0].data_type() { + DataType::Float64 => Arc::new({ + let arg = &args[0].as_any().downcast_ref::().ok_or_else( + || { + DataFusionError::Internal(format!( + "could not cast {} to {}", + self.name(), + std::any::type_name::() + )) + }, + )?; + + arg.iter() + .map(|a| a.map(f64::floor)) + .collect::() + }), + DataType::Float32 => Arc::new({ + let arg = &args[0].as_any().downcast_ref::().ok_or_else( + || { + DataFusionError::Internal(format!( + "could not cast {} to {}", + self.name(), + std::any::type_name::() + )) + }, + )?; + + arg.iter() + .map(|a| a.map(f32::floor)) + .collect::() + }), + other => { + return exec_err!( + "Unsupported data type {other:?} for function {}", + self.name() + ); + } + }; + Ok(ColumnarValue::Array(arr)) + } +} diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index fbfa0ffc19b4..d7a3460e4987 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -15,13 +15,19 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; - use arrow::array::{ArrayRef, Int32Array}; use arrow::compute::SortOptions; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; use arrow_schema::Schema; +use std::sync::Arc; +use std::time::SystemTime; + +use datafusion_common::ScalarValue; +use datafusion_physical_expr::expressions::Literal; +use datafusion_physical_expr::PhysicalExprRef; + +use itertools::Itertools; use rand::Rng; use datafusion::common::JoinSide; @@ -38,82 +44,250 @@ use datafusion::physical_plan::memory::MemoryExec; use datafusion::prelude::{SessionConfig, SessionContext}; use test_utils::stagger_batch_with_seed; +// Determines what Fuzz tests needs to run +// Ideally all tests should match, but in reality some tests +// passes only partial cases +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +enum JoinTestType { + // compare NestedLoopJoin and HashJoin + NljHj, + // compare HashJoin and SortMergeJoin, no need to compare SortMergeJoin and NestedLoopJoin + // because if existing variants both passed that means SortMergeJoin and NestedLoopJoin also passes + HjSmj, +} + +fn col_lt_col_filter(schema1: Arc, schema2: Arc) -> JoinFilter { + let less_filter = Arc::new(BinaryExpr::new( + Arc::new(Column::new("x", 1)), + Operator::Lt, + Arc::new(Column::new("x", 0)), + )) as _; + let column_indices = vec![ + ColumnIndex { + index: 2, + side: JoinSide::Left, + }, + ColumnIndex { + index: 2, + side: JoinSide::Right, + }, + ]; + let intermediate_schema = Schema::new(vec![ + schema1 + .field_with_name("x") + .unwrap() + .clone() + .with_nullable(true), + schema2 + .field_with_name("x") + .unwrap() + .clone() + .with_nullable(true), + ]); + + JoinFilter::new(less_filter, column_indices, intermediate_schema) +} + +#[tokio::test] +async fn test_inner_join_1k_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::Inner, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .await +} + #[tokio::test] async fn test_inner_join_1k() { - run_join_test( + JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), JoinType::Inner, + None, ) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } #[tokio::test] async fn test_left_join_1k() { - run_join_test( + JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), JoinType::Left, + None, ) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .await +} + +#[tokio::test] +async fn test_left_join_1k_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::Left, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } #[tokio::test] async fn test_right_join_1k() { - run_join_test( + JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), JoinType::Right, + None, ) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .await +} + +#[tokio::test] +async fn test_right_join_1k_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::Right, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } #[tokio::test] async fn test_full_join_1k() { - run_join_test( + JoinFuzzTestCase::new( make_staggered_batches(1000), make_staggered_batches(1000), JoinType::Full, + None, ) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } #[tokio::test] -async fn test_semi_join_10k() { - run_join_test( - make_staggered_batches(10000), - make_staggered_batches(10000), +// flaky for HjSmj case +// https://github.com/apache/datafusion/issues/12359 +async fn test_full_join_1k_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::Full, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[JoinTestType::NljHj], false) + .await +} + +#[tokio::test] +async fn test_semi_join_1k() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), JoinType::LeftSemi, + None, + ) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .await +} + +#[tokio::test] +async fn test_semi_join_1k_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::LeftSemi, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .await +} + +#[tokio::test] +async fn test_anti_join_1k() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::LeftAnti, + None, ) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } #[tokio::test] -async fn test_anti_join_10k() { - run_join_test( - make_staggered_batches(10000), - make_staggered_batches(10000), +async fn test_anti_join_1k_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), JoinType::LeftAnti, + Some(Box::new(col_lt_col_filter)), ) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) .await } -/// Perform sort-merge join and hash join on same input -/// and verify two outputs are equal -async fn run_join_test( +#[tokio::test] +async fn test_left_mark_join_1k() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::LeftMark, + None, + ) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .await +} + +#[tokio::test] +async fn test_left_mark_join_1k_filtered() { + JoinFuzzTestCase::new( + make_staggered_batches(1000), + make_staggered_batches(1000), + JoinType::LeftMark, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false) + .await +} + +type JoinFilterBuilder = Box, Arc) -> JoinFilter>; + +struct JoinFuzzTestCase { + batch_sizes: &'static [usize], input1: Vec, input2: Vec, join_type: JoinType, -) { - let batch_sizes = [1, 2, 7, 49, 50, 51, 100]; - for batch_size in batch_sizes { - let session_config = SessionConfig::new().with_batch_size(batch_size); - let ctx = SessionContext::new_with_config(session_config); - let task_ctx = ctx.task_ctx(); - - let schema1 = input1[0].schema(); - let schema2 = input2[0].schema(); - let on_columns = vec![ + join_filter_builder: Option, +} + +impl JoinFuzzTestCase { + fn new( + input1: Vec, + input2: Vec, + join_type: JoinType, + join_filter_builder: Option, + ) -> Self { + Self { + batch_sizes: &[1, 2, 7, 49, 50, 51, 100], + input1, + input2, + join_type, + join_filter_builder, + } + } + + fn on_columns(&self) -> Vec<(PhysicalExprRef, PhysicalExprRef)> { + let schema1 = self.input1[0].schema(); + let schema2 = self.input2[0].schema(); + vec![ ( Arc::new(Column::new_with_schema("a", &schema1).unwrap()) as _, Arc::new(Column::new_with_schema("a", &schema2).unwrap()) as _, @@ -122,10 +296,90 @@ async fn run_join_test( Arc::new(Column::new_with_schema("b", &schema1).unwrap()) as _, Arc::new(Column::new_with_schema("b", &schema2).unwrap()) as _, ), - ]; + ] + } + + /// Helper function for building NLJoin filter, returning intermediate + /// schema as a union of origin filter intermediate schema and + /// on-condition schema + fn intermediate_schema(&self) -> Schema { + let filter_schema = if let Some(filter) = self.join_filter() { + filter.schema().to_owned() + } else { + Schema::empty() + }; + + let schema1 = self.input1[0].schema(); + let schema2 = self.input2[0].schema(); + + let on_schema = Schema::new(vec![ + schema1 + .field_with_name("a") + .unwrap() + .to_owned() + .with_nullable(true), + schema1 + .field_with_name("b") + .unwrap() + .to_owned() + .with_nullable(true), + schema2.field_with_name("a").unwrap().to_owned(), + schema2.field_with_name("b").unwrap().to_owned(), + ]); + + Schema::new( + filter_schema + .fields + .into_iter() + .cloned() + .chain(on_schema.fields.into_iter().cloned()) + .collect_vec(), + ) + } - // Nested loop join uses filter for joining records - let column_indices = vec![ + /// Helper function for building NLJoin filter, returns the union + /// of original filter expression and on-condition expression + fn composite_filter_expression(&self) -> PhysicalExprRef { + let (filter_expression, column_idx_offset) = + if let Some(filter) = self.join_filter() { + ( + filter.expression().to_owned(), + filter.schema().fields().len(), + ) + } else { + (Arc::new(Literal::new(ScalarValue::from(true))) as _, 0) + }; + + let equal_a = Arc::new(BinaryExpr::new( + Arc::new(Column::new("a", column_idx_offset)), + Operator::Eq, + Arc::new(Column::new("a", column_idx_offset + 2)), + )); + let equal_b = Arc::new(BinaryExpr::new( + Arc::new(Column::new("b", column_idx_offset + 1)), + Operator::Eq, + Arc::new(Column::new("b", column_idx_offset + 3)), + )); + let on_expression = Arc::new(BinaryExpr::new(equal_a, Operator::And, equal_b)); + + Arc::new(BinaryExpr::new( + filter_expression, + Operator::And, + on_expression, + )) + } + + /// Helper function for building NLJoin filter, returning the union + /// of original filter column indices and on-condition column indices. + /// Result must match intermediate schema. + fn column_indices(&self) -> Vec { + let mut column_indices = if let Some(filter) = self.join_filter() { + filter.column_indices().to_vec() + } else { + vec![] + }; + + let on_column_indices = vec![ ColumnIndex { index: 0, side: JoinSide::Left, @@ -143,120 +397,315 @@ async fn run_join_test( side: JoinSide::Right, }, ]; - let intermediate_schema = Schema::new(vec![ - schema1.field_with_name("a").unwrap().to_owned(), - schema1.field_with_name("b").unwrap().to_owned(), - schema2.field_with_name("a").unwrap().to_owned(), - schema2.field_with_name("b").unwrap().to_owned(), - ]); - let equal_a = Arc::new(BinaryExpr::new( - Arc::new(Column::new("a", 0)), - Operator::Eq, - Arc::new(Column::new("a", 2)), - )) as _; - let equal_b = Arc::new(BinaryExpr::new( - Arc::new(Column::new("b", 1)), - Operator::Eq, - Arc::new(Column::new("b", 3)), - )) as _; - let expression = Arc::new(BinaryExpr::new(equal_a, Operator::And, equal_b)) as _; - - let on_filter = JoinFilter::new(expression, column_indices, intermediate_schema); - - // sort-merge join - let left = Arc::new( - MemoryExec::try_new(&[input1.clone()], schema1.clone(), None).unwrap(), - ); - let right = Arc::new( - MemoryExec::try_new(&[input2.clone()], schema2.clone(), None).unwrap(), - ); - let smj = Arc::new( + column_indices.extend(on_column_indices); + column_indices + } + + fn left_right(&self) -> (Arc, Arc) { + let schema1 = self.input1[0].schema(); + let schema2 = self.input2[0].schema(); + let left = + Arc::new(MemoryExec::try_new(&[self.input1.clone()], schema1, None).unwrap()); + let right = + Arc::new(MemoryExec::try_new(&[self.input2.clone()], schema2, None).unwrap()); + (left, right) + } + + fn join_filter(&self) -> Option { + let schema1 = self.input1[0].schema(); + let schema2 = self.input2[0].schema(); + self.join_filter_builder + .as_ref() + .map(|builder| builder(schema1, schema2)) + } + + fn sort_merge_join(&self) -> Arc { + let (left, right) = self.left_right(); + Arc::new( SortMergeJoinExec::try_new( left, right, - on_columns.clone(), - None, - join_type, - vec![SortOptions::default(), SortOptions::default()], + self.on_columns().clone(), + self.join_filter(), + self.join_type, + vec![SortOptions::default(); self.on_columns().len()], false, ) .unwrap(), - ); - let smj_collected = collect(smj, task_ctx.clone()).await.unwrap(); - - // hash join - let left = Arc::new( - MemoryExec::try_new(&[input1.clone()], schema1.clone(), None).unwrap(), - ); - let right = Arc::new( - MemoryExec::try_new(&[input2.clone()], schema2.clone(), None).unwrap(), - ); - let hj = Arc::new( + ) + } + + fn hash_join(&self) -> Arc { + let (left, right) = self.left_right(); + Arc::new( HashJoinExec::try_new( left, right, - on_columns.clone(), - None, - &join_type, + self.on_columns().clone(), + self.join_filter(), + &self.join_type, None, PartitionMode::Partitioned, false, ) .unwrap(), - ); - let hj_collected = collect(hj, task_ctx.clone()).await.unwrap(); - - // nested loop join - let left = Arc::new( - MemoryExec::try_new(&[input1.clone()], schema1.clone(), None).unwrap(), - ); - let right = Arc::new( - MemoryExec::try_new(&[input2.clone()], schema2.clone(), None).unwrap(), - ); - let nlj = Arc::new( - NestedLoopJoinExec::try_new(left, right, Some(on_filter), &join_type) + ) + } + + fn nested_loop_join(&self) -> Arc { + let (left, right) = self.left_right(); + + let column_indices = self.column_indices(); + let intermediate_schema = self.intermediate_schema(); + let expression = self.composite_filter_expression(); + + let filter = JoinFilter::new(expression, column_indices, intermediate_schema); + + Arc::new( + NestedLoopJoinExec::try_new(left, right, Some(filter), &self.join_type) .unwrap(), - ); - let nlj_collected = collect(nlj, task_ctx.clone()).await.unwrap(); - - // compare - let smj_formatted = pretty_format_batches(&smj_collected).unwrap().to_string(); - let hj_formatted = pretty_format_batches(&hj_collected).unwrap().to_string(); - let nlj_formatted = pretty_format_batches(&nlj_collected).unwrap().to_string(); - - let mut smj_formatted_sorted: Vec<&str> = smj_formatted.trim().lines().collect(); - smj_formatted_sorted.sort_unstable(); - - let mut hj_formatted_sorted: Vec<&str> = hj_formatted.trim().lines().collect(); - hj_formatted_sorted.sort_unstable(); - - let mut nlj_formatted_sorted: Vec<&str> = nlj_formatted.trim().lines().collect(); - nlj_formatted_sorted.sort_unstable(); - - for (i, (smj_line, hj_line)) in smj_formatted_sorted - .iter() - .zip(&hj_formatted_sorted) - .enumerate() - { - assert_eq!( - (i, smj_line), - (i, hj_line), - "SortMergeJoinExec and HashJoinExec produced different results" - ); + ) + } + + /// Perform joins tests on same inputs and verify outputs are equal + /// `join_tests` - identifies what join types to test + /// if `debug` flag is set the test will save randomly generated inputs and outputs to user folders, + /// so it is easy to debug a test on top of the failed data + async fn run_test(&self, join_tests: &[JoinTestType], debug: bool) { + for batch_size in self.batch_sizes { + let session_config = SessionConfig::new().with_batch_size(*batch_size); + let ctx = SessionContext::new_with_config(session_config); + let task_ctx = ctx.task_ctx(); + + let hj = self.hash_join(); + let hj_collected = collect(hj, task_ctx.clone()).await.unwrap(); + + let smj = self.sort_merge_join(); + let smj_collected = collect(smj, task_ctx.clone()).await.unwrap(); + + let nlj = self.nested_loop_join(); + let nlj_collected = collect(nlj, task_ctx.clone()).await.unwrap(); + + // Get actual row counts(without formatting overhead) for HJ and SMJ + let hj_rows = hj_collected.iter().fold(0, |acc, b| acc + b.num_rows()); + let smj_rows = smj_collected.iter().fold(0, |acc, b| acc + b.num_rows()); + let nlj_rows = nlj_collected.iter().fold(0, |acc, b| acc + b.num_rows()); + + // compare + let smj_formatted = + pretty_format_batches(&smj_collected).unwrap().to_string(); + let hj_formatted = pretty_format_batches(&hj_collected).unwrap().to_string(); + let nlj_formatted = + pretty_format_batches(&nlj_collected).unwrap().to_string(); + + let mut smj_formatted_sorted: Vec<&str> = + smj_formatted.trim().lines().collect(); + smj_formatted_sorted.sort_unstable(); + + let mut hj_formatted_sorted: Vec<&str> = + hj_formatted.trim().lines().collect(); + hj_formatted_sorted.sort_unstable(); + + let mut nlj_formatted_sorted: Vec<&str> = + nlj_formatted.trim().lines().collect(); + nlj_formatted_sorted.sort_unstable(); + + if debug + && ((join_tests.contains(&JoinTestType::NljHj) && nlj_rows != hj_rows) + || (join_tests.contains(&JoinTestType::HjSmj) && smj_rows != hj_rows)) + { + let fuzz_debug = "fuzz_test_debug"; + std::fs::remove_dir_all(fuzz_debug).unwrap_or(()); + std::fs::create_dir_all(fuzz_debug).unwrap(); + let out_dir_name = &format!("{fuzz_debug}/batch_size_{batch_size}"); + println!("Test result data mismatch found. HJ rows {}, SMJ rows {}, NLJ rows {}", hj_rows, smj_rows, nlj_rows); + println!("The debug is ON. Input data will be saved to {out_dir_name}"); + + Self::save_partitioned_batches_as_parquet( + &self.input1, + out_dir_name, + "input1", + ); + Self::save_partitioned_batches_as_parquet( + &self.input2, + out_dir_name, + "input2", + ); + + if join_tests.contains(&JoinTestType::NljHj) && nlj_rows != hj_rows { + println!("=============== HashJoinExec =================="); + hj_formatted_sorted.iter().for_each(|s| println!("{}", s)); + println!("=============== NestedLoopJoinExec =================="); + nlj_formatted_sorted.iter().for_each(|s| println!("{}", s)); + + Self::save_partitioned_batches_as_parquet( + &nlj_collected, + out_dir_name, + "nlj", + ); + Self::save_partitioned_batches_as_parquet( + &hj_collected, + out_dir_name, + "hj", + ); + } + + if join_tests.contains(&JoinTestType::HjSmj) && smj_rows != hj_rows { + println!("=============== HashJoinExec =================="); + hj_formatted_sorted.iter().for_each(|s| println!("{}", s)); + println!("=============== SortMergeJoinExec =================="); + smj_formatted_sorted.iter().for_each(|s| println!("{}", s)); + + Self::save_partitioned_batches_as_parquet( + &hj_collected, + out_dir_name, + "hj", + ); + Self::save_partitioned_batches_as_parquet( + &smj_collected, + out_dir_name, + "smj", + ); + } + } + + if join_tests.contains(&JoinTestType::NljHj) { + let err_msg_rowcnt = format!("NestedLoopJoinExec and HashJoinExec produced different row counts, batch_size: {}", batch_size); + assert_eq!(nlj_rows, hj_rows, "{}", err_msg_rowcnt.as_str()); + + let err_msg_contents = format!("NestedLoopJoinExec and HashJoinExec produced different results, batch_size: {}", batch_size); + // row level compare if any of joins returns the result + // the reason is different formatting when there is no rows + for (i, (nlj_line, hj_line)) in nlj_formatted_sorted + .iter() + .zip(&hj_formatted_sorted) + .enumerate() + { + assert_eq!( + (i, nlj_line), + (i, hj_line), + "{}", + err_msg_contents.as_str() + ); + } + } + + if join_tests.contains(&JoinTestType::HjSmj) { + let err_msg_row_cnt = format!("HashJoinExec and SortMergeJoinExec produced different row counts, batch_size: {}", &batch_size); + assert_eq!(hj_rows, smj_rows, "{}", err_msg_row_cnt.as_str()); + + let err_msg_contents = format!("SortMergeJoinExec and HashJoinExec produced different results, batch_size: {}", &batch_size); + // row level compare if any of joins returns the result + // the reason is different formatting when there is no rows + if smj_rows > 0 || hj_rows > 0 { + for (i, (smj_line, hj_line)) in smj_formatted_sorted + .iter() + .zip(&hj_formatted_sorted) + .enumerate() + { + assert_eq!( + (i, smj_line), + (i, hj_line), + "{}", + err_msg_contents.as_str() + ); + } + } + } } + } - for (i, (nlj_line, hj_line)) in nlj_formatted_sorted - .iter() - .zip(&hj_formatted_sorted) - .enumerate() - { - assert_eq!( - (i, nlj_line), - (i, hj_line), - "NestedLoopJoinExec and HashJoinExec produced different results" + /// This method useful for debugging fuzz tests + /// It helps to save randomly generated input test data for both join inputs into the user folder + /// as a parquet files preserving partitioning. + /// Once the data is saved it is possible to run a custom test on top of the saved data and debug + /// + /// #[tokio::test] + /// async fn test1() { + /// let left: Vec = JoinFuzzTestCase::load_partitioned_batches_from_parquet("fuzz_test_debug/batch_size_2/input1").await.unwrap(); + /// let right: Vec = JoinFuzzTestCase::load_partitioned_batches_from_parquet("fuzz_test_debug/batch_size_2/input2").await.unwrap(); + /// + /// JoinFuzzTestCase::new( + /// left, + /// right, + /// JoinType::LeftSemi, + /// Some(Box::new(col_lt_col_filter)), + /// ) + /// .run_test(&[JoinTestType::HjSmj], false) + /// .await; + /// } + fn save_partitioned_batches_as_parquet( + input: &[RecordBatch], + output_dir: &str, + out_name: &str, + ) { + let out_path = &format!("{output_dir}/{out_name}"); + std::fs::remove_dir_all(out_path).unwrap_or(()); + std::fs::create_dir_all(out_path).unwrap(); + + input.iter().enumerate().for_each(|(idx, batch)| { + let file_path = format!("{out_path}/file_{}.parquet", idx); + let mut file = std::fs::File::create(&file_path).unwrap(); + println!( + "{}: Saving batch idx {} rows {} to parquet {}", + &out_name, + idx, + batch.num_rows(), + &file_path ); + let mut writer = parquet::arrow::ArrowWriter::try_new( + &mut file, + input.first().unwrap().schema(), + None, + ) + .expect("creating writer"); + writer.write(batch).unwrap(); + writer.close().unwrap(); + }); + } + + /// Read parquet files preserving partitions, i.e. 1 file -> 1 partition + /// Files can be of different sizes + /// The method can be useful to read partitions have been saved by `save_partitioned_batches_as_parquet` + /// for test debugging purposes + #[allow(dead_code)] + async fn load_partitioned_batches_from_parquet( + dir: &str, + ) -> std::io::Result> { + let ctx: SessionContext = SessionContext::new(); + let mut batches: Vec = vec![]; + let mut entries = std::fs::read_dir(dir)? + .map(|res| res.map(|e| e.path())) + .collect::, std::io::Error>>()?; + + // important to read files using the same order as they have been written + // sort by modification time + entries.sort_by_key(|path| { + std::fs::metadata(path) + .and_then(|metadata| metadata.modified()) + .unwrap_or(SystemTime::UNIX_EPOCH) + }); + + for entry in entries { + let path = entry.as_path(); + + if path.is_file() { + let mut batch = ctx + .read_parquet( + path.to_str().unwrap(), + datafusion::prelude::ParquetReadOptions::default(), + ) + .await + .unwrap() + .collect() + .await + .unwrap(); + + batches.append(&mut batch); + } } + Ok(batches) } } diff --git a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs index 9889ce2ae562..c52acdd82764 100644 --- a/datafusion/core/tests/fuzz_cases/limit_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/limit_fuzz.rs @@ -226,7 +226,7 @@ impl SortedData { } /// Return the sort expression to use for this data, depending on the type - fn sort_expr(&self) -> Vec { + fn sort_expr(&self) -> Vec { match self { Self::I32 { .. } | Self::F64 { .. } | Self::Str { .. } => { vec![datafusion_expr::col("x").sort(true, true)] @@ -341,7 +341,7 @@ async fn run_limit_test(fetch: usize, data: &SortedData) { /// Return random ASCII String with len fn get_random_string(len: usize) -> String { - rand::thread_rng() + thread_rng() .sample_iter(rand::distributions::Alphanumeric) .take(len) .map(char::from) diff --git a/datafusion/core/tests/fuzz_cases/merge_fuzz.rs b/datafusion/core/tests/fuzz_cases/merge_fuzz.rs index 95cd75f50a00..4e895920dd3d 100644 --- a/datafusion/core/tests/fuzz_cases/merge_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/merge_fuzz.rs @@ -16,6 +16,7 @@ // under the License. //! Fuzz Test for various corner cases merging streams of RecordBatches + use std::sync::Arc; use arrow::{ @@ -30,6 +31,7 @@ use datafusion::physical_plan::{ sorts::sort_preserving_merge::SortPreservingMergeExec, }; use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use test_utils::{batches_to_vec, partitions_to_sorted_vec, stagger_batch_with_seed}; #[tokio::test] @@ -106,13 +108,13 @@ async fn run_merge_test(input: Vec>) { .expect("at least one batch"); let schema = first_batch.schema(); - let sort = vec![PhysicalSortExpr { + let sort = LexOrdering::new(vec![PhysicalSortExpr { expr: col("x", &schema).unwrap(), options: SortOptions { descending: false, nulls_first: true, }, - }]; + }]); let exec = MemoryExec::try_new(&input, schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); diff --git a/datafusion/core/tests/fuzz_cases/mod.rs b/datafusion/core/tests/fuzz_cases/mod.rs index 69241571b4af..49db0d31a8e9 100644 --- a/datafusion/core/tests/fuzz_cases/mod.rs +++ b/datafusion/core/tests/fuzz_cases/mod.rs @@ -21,6 +21,9 @@ mod join_fuzz; mod merge_fuzz; mod sort_fuzz; +mod aggregation_fuzzer; +mod equivalence; + mod limit_fuzz; mod sort_preserving_repartition_fuzz; mod window_fuzz; diff --git a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs index f4b4f16aa160..e4acb96f4930 100644 --- a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs @@ -22,7 +22,7 @@ use arrow::{ compute::SortOptions, record_batch::RecordBatch, }; -use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::physical_plan::expressions::PhysicalSortExpr; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::sorts::sort::SortExec; @@ -30,6 +30,7 @@ use datafusion::physical_plan::{collect, ExecutionPlan}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_execution::memory_pool::GreedyMemoryPool; use datafusion_physical_expr::expressions::col; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use rand::Rng; use std::sync::Arc; use test_utils::{batches_to_vec, partitions_to_sorted_vec}; @@ -37,8 +38,8 @@ use test_utils::{batches_to_vec, partitions_to_sorted_vec}; const KB: usize = 1 << 10; #[tokio::test] #[cfg_attr(tarpaulin, ignore)] -async fn test_sort_1k_mem() { - for (batch_size, should_spill) in [(5, false), (20000, true), (1000000, true)] { +async fn test_sort_10k_mem() { + for (batch_size, should_spill) in [(5, false), (20000, true), (500000, true)] { SortTest::new() .with_int32_batches(batch_size) .with_pool_size(10 * KB) @@ -114,13 +115,13 @@ impl SortTest { .expect("at least one batch"); let schema = first_batch.schema(); - let sort = vec![PhysicalSortExpr { + let sort = LexOrdering::new(vec![PhysicalSortExpr { expr: col("x", &schema).unwrap(), options: SortOptions { descending: false, nulls_first: true, }, - }]; + }]); let exec = MemoryExec::try_new(&input, schema, None).unwrap(); let sort = Arc::new(SortExec::new(sort, Arc::new(exec))); @@ -136,9 +137,10 @@ impl SortTest { .sort_spill_reservation_bytes, ); - let runtime_config = RuntimeConfig::new() - .with_memory_pool(Arc::new(GreedyMemoryPool::new(pool_size))); - let runtime = Arc::new(RuntimeEnv::new(runtime_config).unwrap()); + let runtime = RuntimeEnvBuilder::new() + .with_memory_pool(Arc::new(GreedyMemoryPool::new(pool_size))) + .build_arc() + .unwrap(); SessionContext::new_with_config_rt(session_config, runtime) } else { SessionContext::new_with_config(session_config) diff --git a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs index 6c9c3359ebf4..73f4a569954e 100644 --- a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs @@ -29,7 +29,7 @@ mod sp_repartition_fuzz_tests { metrics::{BaselineMetrics, ExecutionPlanMetricsSet}, repartition::RepartitionExec, sorts::sort_preserving_merge::SortPreservingMergeExec, - sorts::streaming_merge::streaming_merge, + sorts::streaming_merge::StreamingMergeBuilder, stream::RecordBatchStreamAdapter, ExecutionPlan, Partitioning, }; @@ -39,12 +39,13 @@ mod sp_repartition_fuzz_tests { config::SessionConfig, memory_pool::MemoryConsumer, SendableRecordBatchStream, }; use datafusion_physical_expr::{ + equivalence::{EquivalenceClass, EquivalenceProperties}, expressions::{col, Column}, - EquivalenceProperties, PhysicalExpr, PhysicalSortExpr, + ConstExpr, PhysicalExpr, PhysicalSortExpr, }; use test_utils::add_empty_batches; - use datafusion_physical_expr::equivalence::EquivalenceClass; + use datafusion_physical_expr_common::sort_expr::LexOrdering; use itertools::izip; use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng}; @@ -78,9 +79,9 @@ mod sp_repartition_fuzz_tests { let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); // Define a and f are aliases - eq_properties.add_equal_conditions(col_a, col_f); + eq_properties.add_equal_conditions(col_a, col_f)?; // Column e has constant value. - eq_properties = eq_properties.add_constants([col_e.clone()]); + eq_properties = eq_properties.with_constants([ConstExpr::from(col_e)]); // Randomly order columns for sorting let mut rng = StdRng::seed_from_u64(seed); @@ -149,7 +150,7 @@ mod sp_repartition_fuzz_tests { // Fill constant columns for constant in eq_properties.constants() { - let col = constant.as_any().downcast_ref::().unwrap(); + let col = constant.expr().as_any().downcast_ref::().unwrap(); let (idx, _field) = schema.column_with_name(col.name()).unwrap(); let arr = Arc::new(UInt64Array::from_iter_values(vec![0; n_elem])) as ArrayRef; @@ -174,7 +175,7 @@ mod sp_repartition_fuzz_tests { }) .unzip(); - let sort_arrs = arrow::compute::lexsort(&sort_columns, None)?; + let sort_arrs = lexsort(&sort_columns, None)?; for (idx, arr) in izip!(indices, sort_arrs) { schema_vec[idx] = Some(arr); } @@ -246,15 +247,14 @@ mod sp_repartition_fuzz_tests { MemoryConsumer::new("test".to_string()).register(context.memory_pool()); // Internally SortPreservingMergeExec uses this function for merging. - let res = streaming_merge( - streams, - schema, - &exprs, - BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0), - 1, - None, - mem_reservation, - )?; + let res = StreamingMergeBuilder::new() + .with_streams(streams) + .with_schema(schema) + .with_expressions(&exprs) + .with_metrics(BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0)) + .with_batch_size(1) + .with_reservation(mem_reservation) + .build()?; let res = collect(res).await?; // Contains the merged result. let res = concat_batches(&res[0].schema(), &res)?; @@ -346,7 +346,7 @@ mod sp_repartition_fuzz_tests { let schema = input1[0].schema(); let session_config = SessionConfig::new().with_batch_size(50); let ctx = SessionContext::new_with_config(session_config); - let mut sort_keys = vec![]; + let mut sort_keys = LexOrdering::default(); for ordering_col in ["a", "b", "c"] { sort_keys.push(PhysicalSortExpr { expr: col(ordering_col, &schema).unwrap(), @@ -359,7 +359,8 @@ mod sp_repartition_fuzz_tests { let running_source = Arc::new( MemoryExec::try_new(&[input1.clone()], schema.clone(), None) .unwrap() - .with_sort_information(vec![sort_keys.clone()]), + .try_with_sort_information(vec![sort_keys.clone()]) + .unwrap(), ); let hash_exprs = vec![col("c", &schema).unwrap()]; diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs index 2514324a9541..5bfb4d97ed70 100644 --- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs @@ -17,33 +17,39 @@ use std::sync::Arc; -use arrow::array::{ArrayRef, Int32Array}; +use arrow::array::{ArrayRef, Int32Array, StringArray}; use arrow::compute::{concat_batches, SortOptions}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; -use arrow_schema::{Field, Schema}; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::windows::{ - create_window_expr, BoundedWindowAggExec, WindowAggExec, + create_window_expr, schema_add_window_field, BoundedWindowAggExec, WindowAggExec, }; use datafusion::physical_plan::InputOrderMode::{Linear, PartiallySorted, Sorted}; use datafusion::physical_plan::{collect, InputOrderMode}; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_common::{Result, ScalarValue}; use datafusion_common_runtime::SpawnedTask; -use datafusion_expr::type_coercion::aggregates::coerce_types; +use datafusion_expr::type_coercion::functions::data_types_with_aggregate_udf; use datafusion_expr::{ - AggregateFunction, BuiltInWindowFunction, WindowFrame, WindowFrameBound, - WindowFrameUnits, WindowFunctionDefinition, + BuiltInWindowFunction, WindowFrame, WindowFrameBound, WindowFrameUnits, + WindowFunctionDefinition, }; +use datafusion_functions_aggregate::count::count_udaf; +use datafusion_functions_aggregate::min_max::{max_udaf, min_udaf}; +use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_physical_expr::expressions::{cast, col, lit}; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; -use itertools::Itertools; use test_utils::add_empty_batches; +use datafusion::functions_window::row_number::row_number_udwf; +use datafusion_functions_window::lead_lag::{lag_udwf, lead_udwf}; +use datafusion_functions_window::rank::{dense_rank_udwf, rank_udwf}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use hashbrown::HashMap; +use rand::distributions::Alphanumeric; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; @@ -165,7 +171,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { // ) ( // Window function - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateUDF(count_udaf()), // its name "COUNT", // window function argument @@ -178,12 +184,10 @@ async fn bounded_window_causal_non_causal() -> Result<()> { // ROWS BETWEEN UNBOUNDED PRECEDING AND PRECEDING/FOLLOWING // ) ( - // Window function - WindowFunctionDefinition::BuiltInWindowFunction( - BuiltInWindowFunction::RowNumber, - ), + // user-defined window function + WindowFunctionDefinition::WindowUDF(row_number_udwf()), // its name - "ROW_NUMBER", + "row_number", // no argument vec![], // Expected causality, for None cases causality will be determined from window frame boundaries @@ -195,7 +199,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { // ) ( // Window function - WindowFunctionDefinition::BuiltInWindowFunction(BuiltInWindowFunction::Lag), + WindowFunctionDefinition::WindowUDF(lag_udwf()), // its name "LAG", // no argument @@ -209,7 +213,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { // ) ( // Window function - WindowFunctionDefinition::BuiltInWindowFunction(BuiltInWindowFunction::Lead), + WindowFunctionDefinition::WindowUDF(lead_udwf()), // its name "LEAD", // no argument @@ -223,9 +227,9 @@ async fn bounded_window_causal_non_causal() -> Result<()> { // ) ( // Window function - WindowFunctionDefinition::BuiltInWindowFunction(BuiltInWindowFunction::Rank), + WindowFunctionDefinition::WindowUDF(rank_udwf()), // its name - "RANK", + "rank", // no argument vec![], // Expected causality, for None cases causality will be determined from window frame boundaries @@ -237,11 +241,9 @@ async fn bounded_window_causal_non_causal() -> Result<()> { // ) ( // Window function - WindowFunctionDefinition::BuiltInWindowFunction( - BuiltInWindowFunction::DenseRank, - ), + WindowFunctionDefinition::WindowUDF(dense_rank_udwf()), // its name - "DENSE_RANK", + "dense_rank", // no argument vec![], // Expected causality, for None cases causality will be determined from window frame boundaries @@ -250,7 +252,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { ]; let partitionby_exprs = vec![]; - let orderby_exprs = vec![]; + let orderby_exprs = LexOrdering::default(); // Window frame starts with "UNBOUNDED PRECEDING": let start_bound = WindowFrameBound::Preceding(ScalarValue::UInt64(None)); @@ -276,14 +278,14 @@ async fn bounded_window_causal_non_causal() -> Result<()> { }; let extended_schema = - schema_add_window_fields(&args, &schema, &window_fn, fn_name)?; + schema_add_window_field(&args, &schema, &window_fn, fn_name)?; let window_expr = create_window_expr( &window_fn, fn_name.to_string(), &args, &partitionby_exprs, - &orderby_exprs, + orderby_exprs.as_ref(), Arc::new(window_frame), &extended_schema, false, @@ -292,7 +294,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> { vec![window_expr], memory_exec.clone(), vec![], - InputOrderMode::Linear, + Linear, )?); let task_ctx = ctx.task_ctx(); let mut collected_results = @@ -343,28 +345,28 @@ fn get_random_function( window_fn_map.insert( "sum", ( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), + WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![arg.clone()], ), ); window_fn_map.insert( "count", ( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![arg.clone()], ), ); window_fn_map.insert( "min", ( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), + WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![arg.clone()], ), ); window_fn_map.insert( "max", ( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![arg.clone()], ), ); @@ -375,36 +377,25 @@ fn get_random_function( window_fn_map.insert( "row_number", ( - WindowFunctionDefinition::BuiltInWindowFunction( - BuiltInWindowFunction::RowNumber, - ), + WindowFunctionDefinition::WindowUDF(row_number_udwf()), vec![], ), ); window_fn_map.insert( "rank", - ( - WindowFunctionDefinition::BuiltInWindowFunction( - BuiltInWindowFunction::Rank, - ), - vec![], - ), + (WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![]), ); window_fn_map.insert( "dense_rank", ( - WindowFunctionDefinition::BuiltInWindowFunction( - BuiltInWindowFunction::DenseRank, - ), + WindowFunctionDefinition::WindowUDF(dense_rank_udwf()), vec![], ), ); window_fn_map.insert( "lead", ( - WindowFunctionDefinition::BuiltInWindowFunction( - BuiltInWindowFunction::Lead, - ), + WindowFunctionDefinition::WindowUDF(lead_udwf()), vec![ arg.clone(), lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), @@ -415,9 +406,7 @@ fn get_random_function( window_fn_map.insert( "lag", ( - WindowFunctionDefinition::BuiltInWindowFunction( - BuiltInWindowFunction::Lag, - ), + WindowFunctionDefinition::WindowUDF(lag_udwf()), vec![ arg.clone(), lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))), @@ -461,13 +450,12 @@ fn get_random_function( let fn_name = window_fn_map.keys().collect::>()[rand_fn_idx]; let (window_fn, args) = window_fn_map.values().collect::>()[rand_fn_idx]; let mut args = args.clone(); - if let WindowFunctionDefinition::AggregateFunction(f) = window_fn { + if let WindowFunctionDefinition::AggregateUDF(udf) = window_fn { if !args.is_empty() { // Do type coercion first argument let a = args[0].clone(); let dt = a.data_type(schema.as_ref()).unwrap(); - let sig = f.signature(); - let coerced = coerce_types(f, &[dt], &sig).unwrap(); + let coerced = data_types_with_aggregate_udf(&[dt], udf).unwrap(); args[0] = cast(a, schema, coerced[0].clone()).unwrap(); } } @@ -596,25 +584,6 @@ fn convert_bound_to_current_row_if_applicable( } } -/// This utility determines whether a given window frame can be executed with -/// multiple ORDER BY expressions. As an example, range frames with offset (such -/// as `RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING`) cannot have ORDER BY clauses -/// of the form `\[ORDER BY a ASC, b ASC, ...]` -fn can_accept_multi_orderby(window_frame: &WindowFrame) -> bool { - match window_frame.units { - WindowFrameUnits::Rows => true, - WindowFrameUnits::Range => { - // Range can only accept multi ORDER BY clauses when bounds are - // CURRENT ROW or UNBOUNDED PRECEDING/FOLLOWING: - (window_frame.start_bound.is_unbounded() - || window_frame.start_bound == WindowFrameBound::CurrentRow) - && (window_frame.end_bound.is_unbounded() - || window_frame.end_bound == WindowFrameBound::CurrentRow) - } - WindowFrameUnits::Groups => true, - } -} - /// Perform batch and running window same input /// and verify outputs of `WindowAggExec` and `BoundedWindowAggExec` are equal async fn run_window_test( @@ -624,42 +593,42 @@ async fn run_window_test( orderby_columns: Vec<&str>, search_mode: InputOrderMode, ) -> Result<()> { - let is_linear = !matches!(search_mode, InputOrderMode::Sorted); + let is_linear = !matches!(search_mode, Sorted); let mut rng = StdRng::seed_from_u64(random_seed); let schema = input1[0].schema(); let session_config = SessionConfig::new().with_batch_size(50); let ctx = SessionContext::new_with_config(session_config); let (window_fn, args, fn_name) = get_random_function(&schema, &mut rng, is_linear); let window_frame = get_random_window_frame(&mut rng, is_linear); - let mut orderby_exprs = vec![]; + let mut orderby_exprs = LexOrdering::default(); for column in &orderby_columns { orderby_exprs.push(PhysicalSortExpr { expr: col(column, &schema)?, options: SortOptions::default(), }) } - if orderby_exprs.len() > 1 && !can_accept_multi_orderby(&window_frame) { - orderby_exprs = orderby_exprs[0..1].to_vec(); + if orderby_exprs.len() > 1 && !window_frame.can_accept_multi_orderby() { + orderby_exprs = LexOrdering::new(orderby_exprs[0..1].to_vec()); } let mut partitionby_exprs = vec![]; for column in &partition_by_columns { partitionby_exprs.push(col(column, &schema)?); } - let mut sort_keys = vec![]; + let mut sort_keys = LexOrdering::default(); for partition_by_expr in &partitionby_exprs { sort_keys.push(PhysicalSortExpr { expr: partition_by_expr.clone(), options: SortOptions::default(), }) } - for order_by_expr in &orderby_exprs { + for order_by_expr in &orderby_exprs.inner { if !sort_keys.contains(order_by_expr) { sort_keys.push(order_by_expr.clone()) } } let concat_input_record = concat_batches(&schema, &input1)?; - let source_sort_keys = vec![ + let source_sort_keys = LexOrdering::new(vec![ PhysicalSortExpr { expr: col("a", &schema)?, options: Default::default(), @@ -672,10 +641,10 @@ async fn run_window_test( expr: col("c", &schema)?, options: Default::default(), }, - ]; + ]); let mut exec1 = Arc::new( MemoryExec::try_new(&[vec![concat_input_record]], schema.clone(), None)? - .with_sort_information(vec![source_sort_keys.clone()]), + .try_with_sort_information(vec![source_sort_keys.clone()])?, ) as _; // Table is ordered according to ORDER BY a, b, c In linear test we use PARTITION BY b, ORDER BY a // For WindowAggExec to produce correct result it need table to be ordered by b,a. Hence add a sort. @@ -683,7 +652,7 @@ async fn run_window_test( exec1 = Arc::new(SortExec::new(sort_keys, exec1)) as _; } - let extended_schema = schema_add_window_fields(&args, &schema, &window_fn, &fn_name)?; + let extended_schema = schema_add_window_field(&args, &schema, &window_fn, &fn_name)?; let usual_window_exec = Arc::new(WindowAggExec::try_new( vec![create_window_expr( @@ -691,7 +660,7 @@ async fn run_window_test( fn_name.clone(), &args, &partitionby_exprs, - &orderby_exprs, + orderby_exprs.as_ref(), Arc::new(window_frame.clone()), &extended_schema, false, @@ -701,7 +670,7 @@ async fn run_window_test( )?) as _; let exec2 = Arc::new( MemoryExec::try_new(&[input1.clone()], schema.clone(), None)? - .with_sort_information(vec![source_sort_keys.clone()]), + .try_with_sort_information(vec![source_sort_keys.clone()])?, ); let running_window_exec = Arc::new(BoundedWindowAggExec::try_new( vec![create_window_expr( @@ -709,7 +678,7 @@ async fn run_window_test( fn_name, &args, &partitionby_exprs, - &orderby_exprs, + orderby_exprs.as_ref(), Arc::new(window_frame.clone()), &extended_schema, false, @@ -720,11 +689,30 @@ async fn run_window_test( )?) as _; let task_ctx = ctx.task_ctx(); let collected_usual = collect(usual_window_exec, task_ctx.clone()).await?; - let collected_running = collect(running_window_exec, task_ctx).await?; + let collected_running = collect(running_window_exec, task_ctx) + .await? + .into_iter() + .filter(|b| b.num_rows() > 0) + .collect::>(); // BoundedWindowAggExec should produce more chunk than the usual WindowAggExec. // Otherwise it means that we cannot generate result in running mode. - assert!(collected_running.len() > collected_usual.len()); + let err_msg = format!("Inconsistent result for window_frame: {window_frame:?}, window_fn: {window_fn:?}, args:{args:?}, random_seed: {random_seed:?}, search_mode: {search_mode:?}, partition_by_columns:{partition_by_columns:?}, orderby_columns: {orderby_columns:?}"); + // Below check makes sure that, streaming execution generates more chunks than the bulk execution. + // Since algorithms and operators works on sliding windows in the streaming execution. + // However, in the current test setup for some random generated window frame clauses: It is not guaranteed + // for streaming execution to generate more chunk than its non-streaming counter part in the Linear mode. + // As an example window frame `OVER(PARTITION BY d ORDER BY a RANGE BETWEEN CURRENT ROW AND 9 FOLLOWING)` + // needs to receive a=10 to generate result for the rows where a=0. If the input data generated is between the range [0, 9]. + // even in streaming mode, generated result will be single bulk as in the non-streaming version. + if search_mode != Linear { + assert!( + collected_running.len() > collected_usual.len(), + "{}", + err_msg + ); + } + // compare let usual_formatted = pretty_format_batches(&collected_usual)?.to_string(); let running_formatted = pretty_format_batches(&collected_running)?.to_string(); @@ -754,36 +742,17 @@ async fn run_window_test( Ok(()) } -// The planner has fully updated schema before calling the `create_window_expr` -// Replicate the same for this test -fn schema_add_window_fields( - args: &[Arc], - schema: &Arc, - window_fn: &WindowFunctionDefinition, - fn_name: &str, -) -> Result> { - let data_types = args - .iter() - .map(|e| e.clone().as_ref().data_type(schema)) - .collect::>>()?; - let window_expr_return_type = window_fn.return_type(&data_types)?; - let mut window_fields = schema - .fields() - .iter() - .map(|f| f.as_ref().clone()) - .collect_vec(); - window_fields.extend_from_slice(&[Field::new( - fn_name, - window_expr_return_type, - true, - )]); - Ok(Arc::new(Schema::new(window_fields))) +fn generate_random_string(rng: &mut StdRng, length: usize) -> String { + rng.sample_iter(&Alphanumeric) + .take(length) + .map(char::from) + .collect() } /// Return randomly sized record batches with: /// three sorted int32 columns 'a', 'b', 'c' ranged from 0..DISTINCT as columns /// one random int32 column x -fn make_staggered_batches( +pub(crate) fn make_staggered_batches( len: usize, n_distinct: usize, random_seed: u64, @@ -792,6 +761,7 @@ fn make_staggered_batches( let mut rng = StdRng::seed_from_u64(random_seed); let mut input123: Vec<(i32, i32, i32)> = vec![(0, 0, 0); len]; let mut input4: Vec = vec![0; len]; + let mut input5: Vec = vec!["".to_string(); len]; input123.iter_mut().for_each(|v| { *v = ( rng.gen_range(0..n_distinct) as i32, @@ -801,10 +771,15 @@ fn make_staggered_batches( }); input123.sort(); rng.fill(&mut input4[..]); + input5.iter_mut().for_each(|v| { + *v = generate_random_string(&mut rng, 1); + }); + input5.sort(); let input1 = Int32Array::from_iter_values(input123.iter().map(|k| k.0)); let input2 = Int32Array::from_iter_values(input123.iter().map(|k| k.1)); let input3 = Int32Array::from_iter_values(input123.iter().map(|k| k.2)); let input4 = Int32Array::from_iter_values(input4); + let input5 = StringArray::from_iter_values(input5); // split into several record batches let mut remainder = RecordBatch::try_from_iter(vec![ @@ -812,6 +787,7 @@ fn make_staggered_batches( ("b", Arc::new(input2) as ArrayRef), ("c", Arc::new(input3) as ArrayRef), ("x", Arc::new(input4) as ArrayRef), + ("string_field", Arc::new(input5) as ArrayRef), ]) .unwrap(); @@ -820,6 +796,7 @@ fn make_staggered_batches( while remainder.num_rows() > 0 { let batch_size = rng.gen_range(0..50); if remainder.num_rows() < batch_size { + batches.push(remainder); break; } batches.push(remainder.slice(0, batch_size)); diff --git a/datafusion/core/tests/macro_hygiene/mod.rs b/datafusion/core/tests/macro_hygiene/mod.rs index 72ac6e64fb0c..c35e46c0c558 100644 --- a/datafusion/core/tests/macro_hygiene/mod.rs +++ b/datafusion/core/tests/macro_hygiene/mod.rs @@ -37,3 +37,13 @@ mod plan_datafusion_err { plan_datafusion_err!("foo"); } } + +mod record_batch { + // NO other imports! + use datafusion_common::record_batch; + + #[test] + fn test_macro() { + record_batch!(("column_name", Int32, vec![1, 2, 3])).unwrap(); + } +} diff --git a/datafusion/core/tests/memory_limit.rs b/datafusion/core/tests/memory_limit/mod.rs similarity index 77% rename from datafusion/core/tests/memory_limit.rs rename to datafusion/core/tests/memory_limit/mod.rs index ebc2456224da..6817969580da 100644 --- a/datafusion/core/tests/memory_limit.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -26,23 +26,29 @@ use datafusion::assert_batches_eq; use datafusion::physical_optimizer::PhysicalOptimizerRule; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::streaming::PartitionStream; +use datafusion_execution::memory_pool::{ + GreedyMemoryPool, MemoryPool, TrackConsumersPool, +}; use datafusion_expr::{Expr, TableType}; use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr}; use futures::StreamExt; use std::any::Any; +use std::num::NonZeroUsize; use std::sync::{Arc, OnceLock}; +use tokio::fs::File; use datafusion::datasource::streaming::StreamingTable; use datafusion::datasource::{MemTable, TableProvider}; -use datafusion::execution::context::SessionState; use datafusion::execution::disk_manager::DiskManagerConfig; -use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::execution::runtime_env::RuntimeEnvBuilder; +use datafusion::execution::session_state::SessionStateBuilder; use datafusion::physical_optimizer::join_selection::JoinSelection; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::{ExecutionPlan, SendableRecordBatchStream}; use datafusion_common::{assert_contains, Result}; use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion_catalog::Session; use datafusion_execution::TaskContext; use test_utils::AccessLogGenerator; @@ -70,8 +76,7 @@ async fn group_by_none() { TestCase::new() .with_query("select median(request_bytes) from t") .with_expected_errors(vec![ - "Resources exhausted: Failed to allocate additional", - "AggregateStream", + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: AggregateStream" ]) .with_memory_limit(2_000) .run() @@ -83,8 +88,7 @@ async fn group_by_row_hash() { TestCase::new() .with_query("select count(*) from t GROUP BY response_bytes") .with_expected_errors(vec![ - "Resources exhausted: Failed to allocate additional", - "GroupedHashAggregateStream", + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: GroupedHashAggregateStream" ]) .with_memory_limit(2_000) .run() @@ -97,8 +101,7 @@ async fn group_by_hash() { // group by dict column .with_query("select count(*) from t GROUP BY service, host, pod, container") .with_expected_errors(vec![ - "Resources exhausted: Failed to allocate additional", - "GroupedHashAggregateStream", + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: GroupedHashAggregateStream" ]) .with_memory_limit(1_000) .run() @@ -111,8 +114,7 @@ async fn join_by_key_multiple_partitions() { TestCase::new() .with_query("select t1.* from t t1 JOIN t t2 ON t1.service = t2.service") .with_expected_errors(vec![ - "Resources exhausted: Failed to allocate additional", - "HashJoinInput[0]", + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: HashJoinInput[0]", ]) .with_memory_limit(1_000) .with_config(config) @@ -126,8 +128,7 @@ async fn join_by_key_single_partition() { TestCase::new() .with_query("select t1.* from t t1 JOIN t t2 ON t1.service = t2.service") .with_expected_errors(vec![ - "Resources exhausted: Failed to allocate additional", - "HashJoinInput", + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: HashJoinInput", ]) .with_memory_limit(1_000) .with_config(config) @@ -140,8 +141,7 @@ async fn join_by_expression() { TestCase::new() .with_query("select t1.* from t t1 JOIN t t2 ON t1.service != t2.service") .with_expected_errors(vec![ - "Resources exhausted: Failed to allocate additional", - "NestedLoopJoinLoad[0]", + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: NestedLoopJoinLoad[0]", ]) .with_memory_limit(1_000) .run() @@ -153,8 +153,7 @@ async fn cross_join() { TestCase::new() .with_query("select t1.* from t t1 CROSS JOIN t t2") .with_expected_errors(vec![ - "Resources exhausted: Failed to allocate additional", - "CrossJoinExec", + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: CrossJoinExec", ]) .with_memory_limit(1_000) .run() @@ -162,7 +161,7 @@ async fn cross_join() { } #[tokio::test] -async fn merge_join() { +async fn sort_merge_join_no_spill() { // Planner chooses MergeJoin only if number of partitions > 1 let config = SessionConfig::new() .with_target_partitions(2) @@ -173,11 +172,32 @@ async fn merge_join() { "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = t2.time", ) .with_expected_errors(vec![ - "Resources exhausted: Failed to allocate additional", + "Failed to allocate additional", "SMJStream", + "Disk spilling disabled", ]) .with_memory_limit(1_000) .with_config(config) + .with_scenario(Scenario::AccessLogStreaming) + .run() + .await +} + +#[tokio::test] +async fn sort_merge_join_spill() { + // Planner chooses MergeJoin only if number of partitions > 1 + let config = SessionConfig::new() + .with_target_partitions(2) + .set_bool("datafusion.optimizer.prefer_hash_join", false); + + TestCase::new() + .with_query( + "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = t2.time", + ) + .with_memory_limit(1_000) + .with_config(config) + .with_disk_manager_config(DiskManagerConfig::NewOs) + .with_scenario(Scenario::AccessLogStreaming) .run() .await } @@ -189,8 +209,7 @@ async fn symmetric_hash_join() { "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = t2.time", ) .with_expected_errors(vec![ - "Resources exhausted: Failed to allocate additional", - "SymmetricHashJoinStream", + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: SymmetricHashJoinStream", ]) .with_memory_limit(1_000) .with_scenario(Scenario::AccessLogStreaming) @@ -208,8 +227,7 @@ async fn sort_preserving_merge() { // so only a merge is needed .with_query("select * from t ORDER BY a ASC NULLS LAST, b ASC NULLS LAST LIMIT 10") .with_expected_errors(vec![ - "Resources exhausted: Failed to allocate additional", - "SortPreservingMergeExec", + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: SortPreservingMergeExec", ]) // provide insufficient memory to merge .with_memory_limit(partition_size / 2) @@ -220,17 +238,15 @@ async fn sort_preserving_merge() { // SortPreservingMergeExec (not a Sort which would compete // with the SortPreservingMergeExec for memory) &[ - "+---------------+-------------------------------------------------------------------------------------------------------------+", - "| plan_type | plan |", - "+---------------+-------------------------------------------------------------------------------------------------------------+", - "| logical_plan | Limit: skip=0, fetch=10 |", - "| | Sort: t.a ASC NULLS LAST, t.b ASC NULLS LAST, fetch=10 |", - "| | TableScan: t projection=[a, b] |", - "| physical_plan | GlobalLimitExec: skip=0, fetch=10 |", - "| | SortPreservingMergeExec: [a@0 ASC NULLS LAST,b@1 ASC NULLS LAST], fetch=10 |", - "| | MemoryExec: partitions=2, partition_sizes=[5, 5], output_ordering=a@0 ASC NULLS LAST,b@1 ASC NULLS LAST |", - "| | |", - "+---------------+-------------------------------------------------------------------------------------------------------------+", + "+---------------+------------------------------------------------------------------------------------------------------------+", + "| plan_type | plan |", + "+---------------+------------------------------------------------------------------------------------------------------------+", + "| logical_plan | Sort: t.a ASC NULLS LAST, t.b ASC NULLS LAST, fetch=10 |", + "| | TableScan: t projection=[a, b] |", + "| physical_plan | SortPreservingMergeExec: [a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], fetch=10 |", + "| | MemoryExec: partitions=2, partition_sizes=[5, 5], output_ordering=a@0 ASC NULLS LAST, b@1 ASC NULLS LAST |", + "| | |", + "+---------------+------------------------------------------------------------------------------------------------------------+", ] ) .run() @@ -257,7 +273,7 @@ async fn sort_spill_reservation() { .with_query("select * from t ORDER BY a , b DESC") // enough memory to sort if we don't try to merge it all at once .with_memory_limit(partition_size) - // use a single partiton so only a sort is needed + // use a single partition so only a sort is needed .with_scenario(scenario) .with_disk_manager_config(DiskManagerConfig::NewOs) .with_expected_plan( @@ -265,15 +281,15 @@ async fn sort_spill_reservation() { // also merge, so we can ensure the sort could finish // given enough merging memory &[ - "+---------------+--------------------------------------------------------------------------------------------------------+", - "| plan_type | plan |", - "+---------------+--------------------------------------------------------------------------------------------------------+", - "| logical_plan | Sort: t.a ASC NULLS LAST, t.b DESC NULLS FIRST |", - "| | TableScan: t projection=[a, b] |", - "| physical_plan | SortExec: expr=[a@0 ASC NULLS LAST,b@1 DESC] |", - "| | MemoryExec: partitions=1, partition_sizes=[5], output_ordering=a@0 ASC NULLS LAST,b@1 ASC NULLS LAST |", - "| | |", - "+---------------+--------------------------------------------------------------------------------------------------------+", + "+---------------+---------------------------------------------------------------------------------------------------------+", + "| plan_type | plan |", + "+---------------+---------------------------------------------------------------------------------------------------------+", + "| logical_plan | Sort: t.a ASC NULLS LAST, t.b DESC NULLS FIRST |", + "| | TableScan: t projection=[a, b] |", + "| physical_plan | SortExec: expr=[a@0 ASC NULLS LAST, b@1 DESC], preserve_partitioning=[false] |", + "| | MemoryExec: partitions=1, partition_sizes=[5], output_ordering=a@0 ASC NULLS LAST, b@1 ASC NULLS LAST |", + "| | |", + "+---------------+---------------------------------------------------------------------------------------------------------+", ] ); @@ -285,8 +301,7 @@ async fn sort_spill_reservation() { test.clone() .with_expected_errors(vec![ - "Resources exhausted: Failed to allocate additional", - "ExternalSorterMerge", // merging in sort fails + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: ExternalSorterMerge", ]) .with_config(config) .run() @@ -315,14 +330,70 @@ async fn oom_recursive_cte() { SELECT * FROM nodes;", ) .with_expected_errors(vec![ - "Resources exhausted: Failed to allocate additional", - "RecursiveQuery", + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: RecursiveQuery", ]) .with_memory_limit(2_000) .run() .await } +#[tokio::test] +async fn oom_parquet_sink() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.into_path().join("test.parquet"); + let _ = File::create(path.clone()).await.unwrap(); + + TestCase::new() + .with_query(format!( + " + COPY (select * from t) + TO '{}' + STORED AS PARQUET OPTIONS (compression 'uncompressed'); + ", + path.to_string_lossy() + )) + .with_expected_errors(vec![ + "Failed to allocate additional", + "for ParquetSink(ArrowColumnWriter)", + ]) + .with_memory_limit(200_000) + .run() + .await +} + +#[tokio::test] +async fn oom_with_tracked_consumer_pool() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.into_path().join("test.parquet"); + let _ = File::create(path.clone()).await.unwrap(); + + TestCase::new() + .with_config( + SessionConfig::new() + ) + .with_query(format!( + " + COPY (select * from t) + TO '{}' + STORED AS PARQUET OPTIONS (compression 'uncompressed'); + ", + path.to_string_lossy() + )) + .with_expected_errors(vec![ + "Failed to allocate additional", + "for ParquetSink(ArrowColumnWriter)", + "Additional allocation failed with top memory consumers (across reservations) as: ParquetSink(ArrowColumnWriter)" + ]) + .with_memory_pool(Arc::new( + TrackConsumersPool::new( + GreedyMemoryPool::new(200_000), + NonZeroUsize::new(1).unwrap() + ) + )) + .run() + .await +} + /// Run the query with the specified memory limit, /// and verifies the expected errors are returned #[derive(Clone, Debug)] @@ -330,12 +401,13 @@ struct TestCase { query: Option, expected_errors: Vec, memory_limit: usize, + memory_pool: Option>, config: SessionConfig, scenario: Scenario, /// How should the disk manager (that allows spilling) be /// configured? Defaults to `Disabled` disk_manager_config: DiskManagerConfig, - /// Expected explain plan, if non emptry + /// Expected explain plan, if non-empty expected_plan: Vec, /// Is the plan expected to pass? Defaults to false expected_success: bool, @@ -348,6 +420,7 @@ impl TestCase { expected_errors: vec![], memory_limit: 0, config: SessionConfig::new(), + memory_pool: None, scenario: Scenario::AccessLog, disk_manager_config: DiskManagerConfig::Disabled, expected_plan: vec![], @@ -377,6 +450,15 @@ impl TestCase { self } + /// Set the memory pool to be used + /// + /// This will override the memory_limit requested, + /// as the memory pool includes the limit. + fn with_memory_pool(mut self, memory_pool: Arc) -> Self { + self.memory_pool = Some(memory_pool); + self + } + /// Specify the configuration to use pub fn with_config(mut self, config: SessionConfig) -> Self { self.config = config; @@ -417,6 +499,7 @@ impl TestCase { query, expected_errors, memory_limit, + memory_pool, config, scenario, disk_manager_config, @@ -426,21 +509,27 @@ impl TestCase { let table = scenario.table(); - let rt_config = RuntimeConfig::new() - // do not allow spilling + let mut builder = RuntimeEnvBuilder::new() + // disk manager setting controls the spilling .with_disk_manager(disk_manager_config) .with_memory_limit(memory_limit, MEMORY_FRACTION); - let runtime = RuntimeEnv::new(rt_config).unwrap(); + if let Some(pool) = memory_pool { + builder = builder.with_memory_pool(pool); + }; + let runtime = builder.build_arc().unwrap(); // Configure execution - let state = SessionState::new_with_config_rt(config, Arc::new(runtime)); - let state = match scenario.rules() { - Some(rules) => state.with_physical_optimizer_rules(rules), - None => state, + let builder = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features(); + let builder = match scenario.rules() { + Some(rules) => builder.with_physical_optimizer_rules(rules), + None => builder, }; - let ctx = SessionContext::new_with_state(state); + let ctx = SessionContext::new_with_state(builder.build()); ctx.register_table("t", table).expect("registering table"); let query = query.expect("Test error: query not specified"); @@ -565,7 +654,7 @@ impl Scenario { descending: false, nulls_first: false, }; - let sort_information = vec![vec![ + let sort_information = vec![LexOrdering::new(vec![ PhysicalSortExpr { expr: col("a", &schema).unwrap(), options, @@ -574,7 +663,7 @@ impl Scenario { expr: col("b", &schema).unwrap(), options, }, - ]]; + ])]; let table = SortedTableProvider::new(batches, sort_information); Arc::new(table) @@ -688,6 +777,7 @@ fn batches_byte_size(batches: &[RecordBatch]) -> usize { batches.iter().map(|b| b.get_array_memory_size()).sum() } +#[derive(Debug)] struct DummyStreamPartition { schema: SchemaRef, batches: Vec, @@ -709,6 +799,7 @@ impl PartitionStream for DummyStreamPartition { } /// Wrapper over a TableProvider that can provide ordering information +#[derive(Debug)] struct SortedTableProvider { schema: SchemaRef, batches: Vec>, @@ -742,14 +833,14 @@ impl TableProvider for SortedTableProvider { async fn scan( &self, - _state: &SessionState, + _state: &dyn Session, projection: Option<&Vec>, _filters: &[Expr], _limit: Option, ) -> Result> { let mem_exec = MemoryExec::try_new(&self.batches, self.schema(), projection.cloned())? - .with_sort_information(self.sort_information.clone()); + .try_with_sort_information(self.sort_information.clone())?; Ok(Arc::new(mem_exec)) } diff --git a/datafusion/core/tests/optimizer_integration.rs b/datafusion/core/tests/optimizer/mod.rs similarity index 64% rename from datafusion/core/tests/optimizer_integration.rs rename to datafusion/core/tests/optimizer/mod.rs index 5a7870b7a01c..f17d13a42060 100644 --- a/datafusion/core/tests/optimizer_integration.rs +++ b/datafusion/core/tests/optimizer/mod.rs @@ -23,11 +23,19 @@ use std::collections::HashMap; use std::sync::Arc; use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; +use arrow_schema::{Fields, SchemaBuilder}; use datafusion_common::config::ConfigOptions; -use datafusion_common::{plan_err, Result}; -use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF}; +use datafusion_common::tree_node::{TransformedResult, TreeNode}; +use datafusion_common::{plan_err, DFSchema, Result, ScalarValue}; +use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; +use datafusion_expr::{ + col, lit, AggregateUDF, BinaryExpr, Expr, ExprSchemable, LogicalPlan, Operator, + ScalarUDF, TableSource, WindowUDF, +}; +use datafusion_functions::core::expr_ext::FieldAccessor; use datafusion_optimizer::analyzer::Analyzer; use datafusion_optimizer::optimizer::Optimizer; +use datafusion_optimizer::simplify_expressions::GuaranteeRewriter; use datafusion_optimizer::{OptimizerConfig, OptimizerContext}; use datafusion_sql::planner::{ContextProvider, SqlToRel}; use datafusion_sql::sqlparser::ast::Statement; @@ -73,14 +81,13 @@ fn timestamp_nano_ts_utc_predicates() { let sql = "SELECT col_int32 FROM test WHERE col_ts_nano_utc < (now() - interval '1 hour')"; - let plan = test_sql(sql).unwrap(); // a scan should have the now()... predicate folded to a single // constant and compared to the column without a cast so it can be // pushed down / pruned let expected = "Projection: test.col_int32\n Filter: test.col_ts_nano_utc < TimestampNanosecond(1666612093000000000, Some(\"+00:00\"))\ \n TableScan: test projection=[col_int32, col_ts_nano_utc]"; - assert_eq!(expected, format!("{plan:?}")); + quick_test(sql, expected); } #[test] @@ -109,7 +116,7 @@ fn concat_ws_literals() -> Result<()> { fn quick_test(sql: &str, expected_plan: &str) { let plan = test_sql(sql).unwrap(); - assert_eq!(expected_plan, format!("{:?}", plan)); + assert_eq!(expected_plan, format!("{}", plan)); } fn test_sql(sql: &str) -> Result { @@ -135,7 +142,7 @@ fn test_sql(sql: &str) -> Result { let analyzer = Analyzer::new(); let optimizer = Optimizer::new(); // analyze and optimize the logical plan - let plan = analyzer.execute_and_check(&plan, config.options(), |_, _| {})?; + let plan = analyzer.execute_and_check(plan, config.options(), |_, _| {})?; optimizer.optimize(plan, &config, |_, _| {}) } @@ -207,15 +214,15 @@ impl ContextProvider for MyContextProvider { &self.options } - fn udfs_names(&self) -> Vec { + fn udf_names(&self) -> Vec { Vec::new() } - fn udafs_names(&self) -> Vec { + fn udaf_names(&self) -> Vec { Vec::new() } - fn udwfs_names(&self) -> Vec { + fn udwf_names(&self) -> Vec { Vec::new() } } @@ -233,3 +240,120 @@ impl TableSource for MyTableSource { self.schema.clone() } } + +#[test] +fn test_nested_schema_nullability() { + let mut builder = SchemaBuilder::new(); + builder.push(Field::new("foo", DataType::Int32, true)); + builder.push(Field::new( + "parent", + DataType::Struct(Fields::from(vec![Field::new( + "child", + DataType::Int64, + false, + )])), + true, + )); + let schema = builder.finish(); + + let dfschema = DFSchema::from_field_specific_qualified_schema( + vec![Some("table_name".into()), None], + &Arc::new(schema), + ) + .unwrap(); + + let expr = col("parent").field("child"); + assert!(expr.nullable(&dfschema).unwrap()); +} + +#[test] +fn test_inequalities_non_null_bounded() { + let guarantees = vec![ + // x ∈ [1, 3] (not null) + ( + col("x"), + NullableInterval::NotNull { + values: Interval::make(Some(1_i32), Some(3_i32)).unwrap(), + }, + ), + // s.y ∈ [1, 3] (not null) + ( + col("s").field("y"), + NullableInterval::NotNull { + values: Interval::make(Some(1_i32), Some(3_i32)).unwrap(), + }, + ), + ]; + + let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); + + // (original_expr, expected_simplification) + let simplified_cases = &[ + (col("x").lt(lit(0)), false), + (col("s").field("y").lt(lit(0)), false), + (col("x").lt_eq(lit(3)), true), + (col("x").gt(lit(3)), false), + (col("x").gt(lit(0)), true), + (col("x").eq(lit(0)), false), + (col("x").not_eq(lit(0)), true), + (col("x").between(lit(0), lit(5)), true), + (col("x").between(lit(5), lit(10)), false), + (col("x").not_between(lit(0), lit(5)), false), + (col("x").not_between(lit(5), lit(10)), true), + ( + Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("x")), + op: Operator::IsDistinctFrom, + right: Box::new(lit(ScalarValue::Null)), + }), + true, + ), + ( + Expr::BinaryExpr(BinaryExpr { + left: Box::new(col("x")), + op: Operator::IsDistinctFrom, + right: Box::new(lit(5)), + }), + true, + ), + ]; + + validate_simplified_cases(&mut rewriter, simplified_cases); + + let unchanged_cases = &[ + col("x").gt(lit(2)), + col("x").lt_eq(lit(2)), + col("x").eq(lit(2)), + col("x").not_eq(lit(2)), + col("x").between(lit(3), lit(5)), + col("x").not_between(lit(3), lit(10)), + ]; + + validate_unchanged_cases(&mut rewriter, unchanged_cases); +} + +fn validate_simplified_cases(rewriter: &mut GuaranteeRewriter, cases: &[(Expr, T)]) +where + ScalarValue: From, + T: Clone, +{ + for (expr, expected_value) in cases { + let output = expr.clone().rewrite(rewriter).data().unwrap(); + let expected = lit(ScalarValue::from(expected_value.clone())); + assert_eq!( + output, expected, + "{} simplified to {}, but expected {}", + expr, output, expected + ); + } +} +fn validate_unchanged_cases(rewriter: &mut GuaranteeRewriter, cases: &[Expr]) { + for expr in cases { + let output = expr.clone().rewrite(rewriter).data().unwrap(); + assert_eq!( + &output, expr, + "{} was simplified to {}, but expected it to be unchanged", + expr, output + ); + } +} diff --git a/datafusion/core/tests/parquet/custom_reader.rs b/datafusion/core/tests/parquet/custom_reader.rs index 4bacc80579ed..7c1e199ceb95 100644 --- a/datafusion/core/tests/parquet/custom_reader.rs +++ b/datafusion/core/tests/parquet/custom_reader.rs @@ -30,8 +30,8 @@ use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{ FileMeta, FileScanConfig, ParquetExec, ParquetFileMetrics, ParquetFileReaderFactory, }; +use datafusion::physical_plan::collect; use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; -use datafusion::physical_plan::{collect, Statistics}; use datafusion::prelude::SessionContext; use datafusion_common::Result; @@ -63,33 +63,27 @@ async fn route_data_access_ops_to_parquet_file_reader_factory() { let file_schema = batch.schema().clone(); let (in_memory_object_store, parquet_files_meta) = store_parquet_in_memory(vec![batch]).await; - let file_groups = parquet_files_meta + let file_group = parquet_files_meta .into_iter() .map(|meta| PartitionedFile { object_meta: meta, partition_values: vec![], range: None, + statistics: None, extensions: Some(Arc::new(String::from(EXPECTED_USER_DEFINED_METADATA))), }) .collect(); // prepare the scan - let parquet_exec = ParquetExec::new( - FileScanConfig { + let parquet_exec = ParquetExec::builder( + FileScanConfig::new( // just any url that doesn't point to in memory object store - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: vec![file_groups], - statistics: Statistics::new_unknown(&file_schema), + ObjectStoreUrl::local_filesystem(), file_schema, - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - }, - None, - None, - Default::default(), + ) + .with_file_group(file_group), ) + .build() .with_parquet_file_reader_factory(Arc::new(InMemoryParquetFileReaderFactory( Arc::clone(&in_memory_object_store), ))); @@ -198,7 +192,7 @@ async fn store_parquet_in_memory( let mut objects = Vec::with_capacity(parquet_batches.len()); for (meta, bytes) in parquet_batches { in_memory - .put(&meta.location, bytes) + .put(&meta.location, bytes.into()) .await .expect("put parquet file into in memory object store"); objects.push(meta); diff --git a/datafusion/core/tests/parquet/external_access_plan.rs b/datafusion/core/tests/parquet/external_access_plan.rs new file mode 100644 index 000000000000..03afc858dfca --- /dev/null +++ b/datafusion/core/tests/parquet/external_access_plan.rs @@ -0,0 +1,418 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for passing user provided [`ParquetAccessPlan`]` to `ParquetExec`]` +use crate::parquet::utils::MetricsFinder; +use crate::parquet::{create_data_batch, Scenario}; +use arrow::util::pretty::pretty_format_batches; +use arrow_schema::SchemaRef; +use datafusion::common::Result; +use datafusion::datasource::listing::PartitionedFile; +use datafusion::datasource::physical_plan::parquet::{ParquetAccessPlan, RowGroupAccess}; +use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec}; +use datafusion::prelude::SessionContext; +use datafusion_common::{assert_contains, DFSchema}; +use datafusion_execution::object_store::ObjectStoreUrl; +use datafusion_expr::{col, lit, Expr}; +use datafusion_physical_plan::metrics::MetricsSet; +use datafusion_physical_plan::ExecutionPlan; +use parquet::arrow::arrow_reader::{RowSelection, RowSelector}; +use parquet::arrow::ArrowWriter; +use parquet::file::properties::WriterProperties; +use std::sync::{Arc, OnceLock}; +use tempfile::NamedTempFile; + +#[tokio::test] +async fn none() { + // no user defined plan + Test { + access_plan: None, + expected_rows: 10, + } + .run_success() + .await; +} + +#[tokio::test] +async fn scan_all() { + let parquet_metrics = Test { + access_plan: Some(ParquetAccessPlan::new(vec![ + RowGroupAccess::Scan, + RowGroupAccess::Scan, + ])), + expected_rows: 10, + } + .run_success() + .await; + + // Verify that some bytes were read + let bytes_scanned = metric_value(&parquet_metrics, "bytes_scanned").unwrap(); + assert_ne!(bytes_scanned, 0, "metrics : {parquet_metrics:#?}",); +} + +#[tokio::test] +async fn skip_all() { + let parquet_metrics = Test { + access_plan: Some(ParquetAccessPlan::new(vec![ + RowGroupAccess::Skip, + RowGroupAccess::Skip, + ])), + expected_rows: 0, + } + .run_success() + .await; + + // Verify that skipping all row groups skips reading any data at all + let bytes_scanned = metric_value(&parquet_metrics, "bytes_scanned").unwrap(); + assert_eq!(bytes_scanned, 0, "metrics : {parquet_metrics:#?}",); +} + +#[tokio::test] +async fn skip_one_row_group() { + let plans = vec![ + ParquetAccessPlan::new(vec![RowGroupAccess::Scan, RowGroupAccess::Skip]), + ParquetAccessPlan::new(vec![RowGroupAccess::Skip, RowGroupAccess::Scan]), + ]; + + for access_plan in plans { + Test { + access_plan: Some(access_plan), + expected_rows: 5, + } + .run_success() + .await; + } +} + +#[tokio::test] +async fn selection_scan() { + let plans = vec![ + ParquetAccessPlan::new(vec![ + RowGroupAccess::Scan, + RowGroupAccess::Selection(select_one_row()), + ]), + ParquetAccessPlan::new(vec![ + RowGroupAccess::Selection(select_one_row()), + RowGroupAccess::Scan, + ]), + ]; + + for access_plan in plans { + Test { + access_plan: Some(access_plan), + expected_rows: 6, + } + .run_success() + .await; + } +} + +#[tokio::test] +async fn skip_scan() { + let plans = vec![ + // skip one row group, scan the toehr + ParquetAccessPlan::new(vec![ + RowGroupAccess::Skip, + RowGroupAccess::Selection(select_one_row()), + ]), + ParquetAccessPlan::new(vec![ + RowGroupAccess::Selection(select_one_row()), + RowGroupAccess::Skip, + ]), + ]; + + for access_plan in plans { + Test { + access_plan: Some(access_plan), + expected_rows: 1, + } + .run_success() + .await; + } +} + +#[tokio::test] +async fn plan_and_filter() { + // show that row group pruning is applied even when an initial plan is supplied + + // No rows match this predicate + let predicate = col("utf8").eq(lit("z")); + + // user supplied access plan specifies to still read a row group + let access_plan = Some(ParquetAccessPlan::new(vec![ + // Row group 0 has values a-d + RowGroupAccess::Skip, + // Row group 1 has values e-i + RowGroupAccess::Scan, + ])); + + // initia + let parquet_metrics = TestFull { + access_plan, + expected_rows: 0, + predicate: Some(predicate), + } + .run() + .await + .unwrap(); + + // Verify that row group pruning still happens for just that group + let row_groups_pruned_statistics = + metric_value(&parquet_metrics, "row_groups_pruned_statistics").unwrap(); + assert_eq!( + row_groups_pruned_statistics, 1, + "metrics : {parquet_metrics:#?}", + ); +} + +#[tokio::test] +async fn two_selections() { + let plans = vec![ + ParquetAccessPlan::new(vec![ + RowGroupAccess::Selection(select_one_row()), + RowGroupAccess::Selection(select_two_rows()), + ]), + ParquetAccessPlan::new(vec![ + RowGroupAccess::Selection(select_two_rows()), + RowGroupAccess::Selection(select_one_row()), + ]), + ]; + + for access_plan in plans { + Test { + access_plan: Some(access_plan), + expected_rows: 3, + } + .run_success() + .await; + } +} + +#[tokio::test] +async fn bad_row_groups() { + let err = TestFull { + access_plan: Some(ParquetAccessPlan::new(vec![ + // file has only 2 row groups, but specify 3 + RowGroupAccess::Scan, + RowGroupAccess::Skip, + RowGroupAccess::Scan, + ])), + expected_rows: 0, + predicate: None, + } + .run() + .await + .unwrap_err(); + let err_string = err.to_string(); + assert_contains!(&err_string, "Invalid ParquetAccessPlan"); + assert_contains!(&err_string, "Specified 3 row groups, but file has 2"); +} + +#[tokio::test] +async fn bad_selection() { + let err = TestFull { + access_plan: Some(ParquetAccessPlan::new(vec![ + // specify fewer rows than are actually in the row group + RowGroupAccess::Selection(RowSelection::from(vec![ + RowSelector::skip(1), + RowSelector::select(3), + ])), + RowGroupAccess::Skip, + ])), + // expects that we hit an error, this should not be run + expected_rows: 10000, + predicate: None, + } + .run() + .await + .unwrap_err(); + let err_string = err.to_string(); + assert_contains!(&err_string, "Internal error: Invalid ParquetAccessPlan Selection. Row group 0 has 5 rows but selection only specifies 4 rows"); +} + +/// Return a RowSelection of 1 rows from a row group of 5 rows +fn select_one_row() -> RowSelection { + RowSelection::from(vec![ + RowSelector::skip(2), + RowSelector::select(1), + RowSelector::skip(2), + ]) +} +/// Return a RowSelection of 2 rows from a row group of 5 rows +fn select_two_rows() -> RowSelection { + RowSelection::from(vec![ + RowSelector::skip(1), + RowSelector::select(1), + RowSelector::skip(1), + RowSelector::select(1), + RowSelector::skip(1), + ]) +} + +/// Test for passing user defined ParquetAccessPlans. See [`TestFull`] for details. +#[derive(Debug)] +struct Test { + access_plan: Option, + expected_rows: usize, +} + +impl Test { + /// Runs the test case, panic'ing on error. + /// + /// Returns the `MetricsSet` from the ParqeutExec + async fn run_success(self) -> MetricsSet { + let Self { + access_plan, + expected_rows, + } = self; + TestFull { + access_plan, + expected_rows, + predicate: None, + } + .run() + .await + .unwrap() + } +} + +/// Test for passing user defined ParquetAccessPlans: +/// +/// 1. Creates a parquet file with 2 row groups, each with 5 rows +/// 2. Reads the parquet file with an optional user provided access plan +/// 3. Verifies that the expected number of rows is read +/// 4. Returns the statistics from running the plan +struct TestFull { + access_plan: Option, + expected_rows: usize, + predicate: Option, +} + +impl TestFull { + async fn run(self) -> Result { + let ctx = SessionContext::new(); + + let Self { + access_plan, + expected_rows, + predicate, + } = self; + + let TestData { + temp_file: _, + schema, + file_name, + file_size, + } = get_test_data(); + + let mut partitioned_file = PartitionedFile::new(file_name, *file_size); + + // add the access plan, if any, as an extension + if let Some(access_plan) = access_plan { + partitioned_file = partitioned_file.with_extensions(Arc::new(access_plan)); + } + + // Create a ParquetExec to read the file + let object_store_url = ObjectStoreUrl::local_filesystem(); + let config = FileScanConfig::new(object_store_url, schema.clone()) + .with_file(partitioned_file); + + let mut builder = ParquetExec::builder(config); + + // add the predicate, if requested + if let Some(predicate) = predicate { + let df_schema = DFSchema::try_from(schema.clone())?; + let predicate = ctx.create_physical_expr(predicate, &df_schema)?; + builder = builder.with_predicate(predicate); + } + + let plan: Arc = builder.build_arc(); + + // run the ParquetExec and collect the results + let results = + datafusion::physical_plan::collect(Arc::clone(&plan), ctx.task_ctx()).await?; + + // calculate the total number of rows that came out + let total_rows = results.iter().map(|b| b.num_rows()).sum::(); + assert_eq!( + total_rows, + expected_rows, + "results: \n{}", + pretty_format_batches(&results).unwrap() + ); + + Ok(MetricsFinder::find_metrics(plan.as_ref()).unwrap()) + } +} + +// Holds necessary data for these tests to reuse the same parquet file +struct TestData { + // field is present as on drop the file is deleted + #[allow(dead_code)] + temp_file: NamedTempFile, + schema: SchemaRef, + file_name: String, + file_size: u64, +} + +static TEST_DATA: OnceLock = OnceLock::new(); + +/// Return a parquet file with 2 row groups each with 5 rows +fn get_test_data() -> &'static TestData { + TEST_DATA.get_or_init(|| { + let scenario = Scenario::UTF8; + let row_per_group = 5; + + let mut temp_file = tempfile::Builder::new() + .prefix("user_access_plan") + .suffix(".parquet") + .tempfile() + .expect("tempfile creation"); + + let props = WriterProperties::builder() + .set_max_row_group_size(row_per_group) + .build(); + + let batches = create_data_batch(scenario); + let schema = batches[0].schema(); + + let mut writer = + ArrowWriter::try_new(&mut temp_file, schema.clone(), Some(props)).unwrap(); + + for batch in batches { + writer.write(&batch).expect("writing batch"); + } + writer.close().unwrap(); + + let file_name = temp_file.path().to_string_lossy().to_string(); + let file_size = temp_file.path().metadata().unwrap().len(); + + TestData { + temp_file, + schema, + file_name, + file_size, + } + }) +} + +/// Return the total value of the specified metric name +fn metric_value(parquet_metrics: &MetricsSet, metric_name: &str) -> Option { + parquet_metrics + .sum(|metric| metric.value().name() == metric_name) + .map(|v| v.as_usize()) +} diff --git a/datafusion/core/tests/parquet/file_statistics.rs b/datafusion/core/tests/parquet/file_statistics.rs index 9f94a59a3e59..4b5d22bfa71f 100644 --- a/datafusion/core/tests/parquet/file_statistics.rs +++ b/datafusion/core/tests/parquet/file_statistics.rs @@ -28,15 +28,36 @@ use datafusion::execution::context::SessionState; use datafusion::prelude::SessionContext; use datafusion_common::stats::Precision; use datafusion_execution::cache::cache_manager::CacheManagerConfig; -use datafusion_execution::cache::cache_unit; use datafusion_execution::cache::cache_unit::{ DefaultFileStatisticsCache, DefaultListFilesCache, }; use datafusion_execution::config::SessionConfig; -use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion_execution::runtime_env::RuntimeEnvBuilder; +use datafusion::execution::session_state::SessionStateBuilder; +use datafusion_expr::{col, lit, Expr}; use tempfile::tempdir; +#[tokio::test] +async fn check_stats_precision_with_filter_pushdown() { + let testdata = datafusion::test_util::parquet_test_data(); + let filename = format!("{}/{}", testdata, "alltypes_plain.parquet"); + let table_path = ListingTableUrl::parse(filename).unwrap(); + + let opt = ListingOptions::new(Arc::new(ParquetFormat::default())); + let table = get_listing_table(&table_path, None, &opt).await; + let (_, _, state) = get_cache_runtime_state(); + // Scan without filter, stats are exact + let exec = table.scan(&state, None, &[], None).await.unwrap(); + assert_eq!(exec.statistics().unwrap().num_rows, Precision::Exact(8)); + + // Scan with filter pushdown, stats are inexact + let filter = Expr::gt(col("id"), lit(1)); + + let exec = table.scan(&state, None, &[filter], None).await.unwrap(); + assert_eq!(exec.statistics().unwrap().num_rows, Precision::Inexact(8)); +} + #[tokio::test] async fn load_table_stats_with_session_level_cache() { let testdata = datafusion::test_util::parquet_test_data(); @@ -167,10 +188,7 @@ async fn get_listing_table( ) -> ListingTable { let schema = opt .infer_schema( - &SessionState::new_with_config_rt( - SessionConfig::default(), - Arc::new(RuntimeEnv::default()), - ), + &SessionStateBuilder::new().with_default_features().build(), table_path, ) .await @@ -192,16 +210,18 @@ fn get_cache_runtime_state() -> ( SessionState, ) { let cache_config = CacheManagerConfig::default(); - let file_static_cache = Arc::new(cache_unit::DefaultFileStatisticsCache::default()); - let list_file_cache = Arc::new(cache_unit::DefaultListFilesCache::default()); + let file_static_cache = Arc::new(DefaultFileStatisticsCache::default()); + let list_file_cache = Arc::new(DefaultListFilesCache::default()); let cache_config = cache_config .with_files_statistics_cache(Some(file_static_cache.clone())) .with_list_files_cache(Some(list_file_cache.clone())); - let rt = Arc::new( - RuntimeEnv::new(RuntimeConfig::new().with_cache_manager(cache_config)).unwrap(), - ); + let rt = RuntimeEnvBuilder::new() + .with_cache_manager(cache_config) + .build_arc() + .expect("could not build runtime environment"); + let state = SessionContext::new_with_config_rt(SessionConfig::default(), rt).state(); (file_static_cache, list_file_cache, state) diff --git a/datafusion/core/tests/parquet/filter_pushdown.rs b/datafusion/core/tests/parquet/filter_pushdown.rs index feb928a3a474..8def192f9331 100644 --- a/datafusion/core/tests/parquet/filter_pushdown.rs +++ b/datafusion/core/tests/parquet/filter_pushdown.rs @@ -529,7 +529,7 @@ impl<'a> TestCase<'a> { // verify expected pushdown let metrics = - TestParquetFile::parquet_metrics(exec).expect("found parquet metrics"); + TestParquetFile::parquet_metrics(&exec).expect("found parquet metrics"); let pushdown_expected = if scan_options.pushdown_filters { self.pushdown_expected @@ -538,24 +538,28 @@ impl<'a> TestCase<'a> { PushdownExpected::None }; - let pushdown_rows_filtered = get_value(&metrics, "pushdown_rows_filtered"); - println!(" pushdown_rows_filtered: {pushdown_rows_filtered}"); + let pushdown_rows_pruned = get_value(&metrics, "pushdown_rows_pruned"); + println!(" pushdown_rows_pruned: {pushdown_rows_pruned}"); + let pushdown_rows_matched = get_value(&metrics, "pushdown_rows_matched"); + println!(" pushdown_rows_matched: {pushdown_rows_matched}"); match pushdown_expected { PushdownExpected::None => { - assert_eq!(pushdown_rows_filtered, 0, "{}", self.name); + assert_eq!(pushdown_rows_pruned, 0, "{}", self.name); } PushdownExpected::Some => { assert!( - pushdown_rows_filtered > 0, + pushdown_rows_pruned > 0, "{}: Expected to filter rows via pushdown, but none were", self.name ); } }; - let page_index_rows_filtered = get_value(&metrics, "page_index_rows_filtered"); - println!(" page_index_rows_filtered: {page_index_rows_filtered}"); + let page_index_rows_pruned = get_value(&metrics, "page_index_rows_pruned"); + println!(" page_index_rows_pruned: {page_index_rows_pruned}"); + let page_index_rows_matched = get_value(&metrics, "page_index_rows_matched"); + println!(" page_index_rows_matched: {page_index_rows_matched}"); let page_index_filtering_expected = if scan_options.enable_page_index { self.page_index_filtering_expected @@ -567,11 +571,11 @@ impl<'a> TestCase<'a> { match page_index_filtering_expected { PageIndexFilteringExpected::None => { - assert_eq!(page_index_rows_filtered, 0); + assert_eq!(page_index_rows_pruned, 0); } PageIndexFilteringExpected::Some => { assert!( - page_index_rows_filtered > 0, + page_index_rows_pruned > 0, "Expected to filter rows via page index but none were", ); } diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index bb938e3af493..cfa2a3df3ba2 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -16,11 +16,13 @@ // under the License. //! Parquet integration tests +use crate::parquet::utils::MetricsFinder; use arrow::array::Decimal128Array; use arrow::{ array::{ - Array, ArrayRef, BinaryArray, Date32Array, Date64Array, FixedSizeBinaryArray, - Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, StringArray, + make_array, Array, ArrayRef, BinaryArray, Date32Array, Date64Array, + FixedSizeBinaryArray, Float64Array, Int16Array, Int32Array, Int64Array, + Int8Array, LargeBinaryArray, LargeStringArray, StringArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }, @@ -28,20 +30,22 @@ use arrow::{ record_batch::RecordBatch, util::pretty::pretty_format_batches, }; -use arrow_array::make_array; use chrono::{Datelike, Duration, TimeDelta}; use datafusion::{ - datasource::{physical_plan::ParquetExec, provider_as_source, TableProvider}, - physical_plan::{accept, metrics::MetricsSet, ExecutionPlan, ExecutionPlanVisitor}, + datasource::{provider_as_source, TableProvider}, + physical_plan::metrics::MetricsSet, prelude::{ParquetReadOptions, SessionConfig, SessionContext}, }; use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; use parquet::arrow::ArrowWriter; -use parquet::file::properties::WriterProperties; +use parquet::file::properties::{EnabledStatistics, WriterProperties}; use std::sync::Arc; use tempfile::NamedTempFile; mod custom_reader; +// Don't run on windows as tempfiles don't seem to work the same +#[cfg(not(target_os = "windows"))] +mod external_access_plan; mod file_statistics; #[cfg(not(target_family = "windows"))] mod filter_pushdown; @@ -49,6 +53,7 @@ mod page_pruning; mod row_group_pruning; mod schema; mod schema_coercion; +mod utils; #[cfg(test)] #[ctor::ctor] @@ -62,6 +67,7 @@ fn init() { // ---------------------- /// What data to use +#[derive(Debug, Clone, Copy)] enum Scenario { Timestamps, Dates, @@ -75,10 +81,12 @@ enum Scenario { DecimalBloomFilterInt64, DecimalLargePrecision, DecimalLargePrecisionBloomFilter, + /// StringArray, BinaryArray, FixedSizeBinaryArray ByteArray, PeriodsInColumnNames, WithNullValues, WithNullValuesPageLevel, + UTF8, } enum Unit { @@ -163,7 +171,7 @@ impl TestOutput { /// The number of row pages pruned fn row_pages_pruned(&self) -> Option { - self.metric_value("page_index_rows_filtered") + self.metric_value("page_index_rows_pruned") } fn description(&self) -> String { @@ -275,25 +283,8 @@ impl ContextWithParquet { .expect("Running"); // find the parquet metrics - struct MetricsFinder { - metrics: Option, - } - impl ExecutionPlanVisitor for MetricsFinder { - type Error = std::convert::Infallible; - fn pre_visit( - &mut self, - plan: &dyn ExecutionPlan, - ) -> Result { - if plan.as_any().downcast_ref::().is_some() { - self.metrics = plan.metrics(); - } - // stop searching once we have found the metrics - Ok(self.metrics.is_none()) - } - } - let mut finder = MetricsFinder { metrics: None }; - accept(physical_plan.as_ref(), &mut finder).unwrap(); - let parquet_metrics = finder.metrics.unwrap(); + let parquet_metrics = + MetricsFinder::find_metrics(physical_plan.as_ref()).unwrap(); let result_rows = results.iter().map(|b| b.num_rows()).sum(); @@ -315,9 +306,13 @@ impl ContextWithParquet { /// /// Columns are named: /// "nanos" --> TimestampNanosecondArray +/// "nanos_timezoned" --> TimestampNanosecondArray with timezone /// "micros" --> TimestampMicrosecondArray +/// "micros_timezoned" --> TimestampMicrosecondArray with timezone /// "millis" --> TimestampMillisecondArray +/// "millis_timezoned" --> TimestampMillisecondArray with timezone /// "seconds" --> TimestampSecondArray +/// "seconds_timezoned" --> TimestampSecondArray with timezone /// "names" --> StringArray fn make_timestamp_batch(offset: Duration) -> RecordBatch { let ts_strings = vec![ @@ -328,6 +323,8 @@ fn make_timestamp_batch(offset: Duration) -> RecordBatch { Some("2020-01-02T01:01:01.0000000000001"), ]; + let tz_string = "Pacific/Efate"; + let offset_nanos = offset.num_nanoseconds().expect("non overflow nanos"); let ts_nanos = ts_strings @@ -365,19 +362,47 @@ fn make_timestamp_batch(offset: Duration) -> RecordBatch { .map(|(i, _)| format!("Row {i} + {offset}")) .collect::>(); - let arr_nanos = TimestampNanosecondArray::from(ts_nanos); - let arr_micros = TimestampMicrosecondArray::from(ts_micros); - let arr_millis = TimestampMillisecondArray::from(ts_millis); - let arr_seconds = TimestampSecondArray::from(ts_seconds); + let arr_nanos = TimestampNanosecondArray::from(ts_nanos.clone()); + let arr_nanos_timezoned = + TimestampNanosecondArray::from(ts_nanos).with_timezone(tz_string); + let arr_micros = TimestampMicrosecondArray::from(ts_micros.clone()); + let arr_micros_timezoned = + TimestampMicrosecondArray::from(ts_micros).with_timezone(tz_string); + let arr_millis = TimestampMillisecondArray::from(ts_millis.clone()); + let arr_millis_timezoned = + TimestampMillisecondArray::from(ts_millis).with_timezone(tz_string); + let arr_seconds = TimestampSecondArray::from(ts_seconds.clone()); + let arr_seconds_timezoned = + TimestampSecondArray::from(ts_seconds).with_timezone(tz_string); let names = names.iter().map(|s| s.as_str()).collect::>(); let arr_names = StringArray::from(names); let schema = Schema::new(vec![ Field::new("nanos", arr_nanos.data_type().clone(), true), + Field::new( + "nanos_timezoned", + arr_nanos_timezoned.data_type().clone(), + true, + ), Field::new("micros", arr_micros.data_type().clone(), true), + Field::new( + "micros_timezoned", + arr_micros_timezoned.data_type().clone(), + true, + ), Field::new("millis", arr_millis.data_type().clone(), true), + Field::new( + "millis_timezoned", + arr_millis_timezoned.data_type().clone(), + true, + ), Field::new("seconds", arr_seconds.data_type().clone(), true), + Field::new( + "seconds_timezoned", + arr_seconds_timezoned.data_type().clone(), + true, + ), Field::new("name", arr_names.data_type().clone(), true), ]); let schema = Arc::new(schema); @@ -386,9 +411,13 @@ fn make_timestamp_batch(offset: Duration) -> RecordBatch { schema, vec![ Arc::new(arr_nanos), + Arc::new(arr_nanos_timezoned), Arc::new(arr_micros), + Arc::new(arr_micros_timezoned), Arc::new(arr_millis), + Arc::new(arr_millis_timezoned), Arc::new(arr_seconds), + Arc::new(arr_seconds_timezoned), Arc::new(arr_names), ], ) @@ -425,7 +454,7 @@ fn make_int_batches(start: i8, end: i8) -> RecordBatch { .unwrap() } -/// Return record batch with i8, i16, i32, and i64 sequences +/// Return record batch with u8, u16, u32, and u64 sequences /// /// Columns are named /// "u8" -> UInt8Array @@ -581,6 +610,8 @@ fn make_bytearray_batch( string_values: Vec<&str>, binary_values: Vec<&[u8]>, fixedsize_values: Vec<&[u8; 3]>, + // i64 offset. + large_binary_values: Vec<&[u8]>, ) -> RecordBatch { let num_rows = string_values.len(); let name: StringArray = std::iter::repeat(Some(name)).take(num_rows).collect(); @@ -591,6 +622,8 @@ fn make_bytearray_batch( .map(|value| Some(value.as_slice())) .collect::>() .into(); + let service_large_binary: LargeBinaryArray = + large_binary_values.iter().map(Some).collect(); let schema = Schema::new(vec![ Field::new("name", name.data_type().clone(), true), @@ -602,6 +635,11 @@ fn make_bytearray_batch( service_fixedsize.data_type().clone(), true, ), + Field::new( + "service_large_binary", + service_large_binary.data_type().clone(), + true, + ), ]); let schema = Arc::new(schema); @@ -612,6 +650,7 @@ fn make_bytearray_batch( Arc::new(service_string), Arc::new(service_binary), Arc::new(service_fixedsize), + Arc::new(service_large_binary), ], ) .unwrap() @@ -695,6 +734,16 @@ fn make_int_batches_with_null( .unwrap() } +fn make_utf8_batch(value: Vec>) -> RecordBatch { + let utf8 = StringArray::from(value.clone()); + let large_utf8 = LargeStringArray::from(value); + RecordBatch::try_from_iter(vec![ + ("utf8", Arc::new(utf8) as _), + ("large_utf8", Arc::new(large_utf8) as _), + ]) + .unwrap() +} + fn create_data_batch(scenario: Scenario) -> Vec { match scenario { Scenario::Timestamps => { @@ -735,6 +784,7 @@ fn create_data_batch(scenario: Scenario) -> Vec { Scenario::UInt32Range => { vec![make_uint32_range(0, 10), make_uint32_range(200000, 300000)] } + Scenario::Float64 => { vec![ make_f64_batch(vec![-5.0, -4.0, -3.0, -2.0, -1.0]), @@ -751,6 +801,7 @@ fn create_data_batch(scenario: Scenario) -> Vec { make_decimal_batch(vec![2000, 3000, 3000, 4000, 6000], 9, 2), ] } + Scenario::DecimalBloomFilterInt32 => { // decimal record batch vec![ @@ -806,6 +857,13 @@ fn create_data_batch(scenario: Scenario) -> Vec { b"frontend five", ], vec![b"fe1", b"fe2", b"fe3", b"fe7", b"fe5"], + vec![ + b"frontend one", + b"frontend two", + b"frontend three", + b"frontend seven", + b"frontend five", + ], ), make_bytearray_batch( "mixed", @@ -824,6 +882,13 @@ fn create_data_batch(scenario: Scenario) -> Vec { b"backend three", ], vec![b"fe6", b"fe4", b"be1", b"be2", b"be3"], + vec![ + b"frontend six", + b"frontend four", + b"backend one", + b"backend two", + b"backend three", + ], ), make_bytearray_batch( "all backends", @@ -842,9 +907,17 @@ fn create_data_batch(scenario: Scenario) -> Vec { b"backend eight", ], vec![b"be4", b"be5", b"be6", b"be7", b"be8"], + vec![ + b"backend four", + b"backend five", + b"backend six", + b"backend seven", + b"backend eight", + ], ), ] } + Scenario::PeriodsInColumnNames => { vec![ // all frontend @@ -879,6 +952,19 @@ fn create_data_batch(scenario: Scenario) -> Vec { make_int_batches_with_null(5, 1, 6), ] } + + Scenario::UTF8 => { + vec![ + make_utf8_batch(vec![Some("a"), Some("b"), Some("c"), Some("d"), None]), + make_utf8_batch(vec![ + Some("e"), + Some("f"), + Some("g"), + Some("h"), + Some("i"), + ]), + ] + } } } @@ -893,10 +979,10 @@ async fn make_test_file_rg(scenario: Scenario, row_per_group: usize) -> NamedTem let props = WriterProperties::builder() .set_max_row_group_size(row_per_group) .set_bloom_filter_enabled(true) + .set_statistics_enabled(EnabledStatistics::Page) .build(); let batches = create_data_batch(scenario); - let schema = batches[0].schema(); let mut writer = ArrowWriter::try_new(&mut output_file, schema, Some(props)).unwrap(); diff --git a/datafusion/core/tests/parquet/page_pruning.rs b/datafusion/core/tests/parquet/page_pruning.rs index ccaa65b7ee5f..d201ed3a841f 100644 --- a/datafusion/core/tests/parquet/page_pruning.rs +++ b/datafusion/core/tests/parquet/page_pruning.rs @@ -27,7 +27,7 @@ use datafusion::execution::context::SessionState; use datafusion::physical_plan::metrics::MetricValue; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionContext; -use datafusion_common::{ScalarValue, Statistics, ToDFSchema}; +use datafusion_common::{ScalarValue, ToDFSchema}; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::{col, lit, Expr}; use datafusion_physical_expr::create_physical_expr; @@ -62,6 +62,7 @@ async fn get_parquet_exec(state: &SessionState, filter: Expr) -> ParquetExec { object_meta: meta, partition_values: vec![], range: None, + statistics: None, extensions: None, }; @@ -69,23 +70,12 @@ async fn get_parquet_exec(state: &SessionState, filter: Expr) -> ParquetExec { let execution_props = ExecutionProps::new(); let predicate = create_physical_expr(&filter, &df_schema, &execution_props).unwrap(); - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url, - file_groups: vec![vec![partitioned_file]], - file_schema: schema.clone(), - statistics: Statistics::new_unknown(&schema), - // file has 10 cols so index 12 should be month - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - }, - Some(predicate), - None, - Default::default(), - ); - parquet_exec.with_enable_page_index(true) + ParquetExec::builder( + FileScanConfig::new(object_store_url, schema).with_file(partitioned_file), + ) + .with_predicate(predicate) + .build() + .with_enable_page_index(true) } #[tokio::test] @@ -159,8 +149,9 @@ async fn page_index_filter_one_col() { let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); - // 5.create filter date_string_col == 1; - let filter = col("date_string_col").eq(lit("01/01/09")); + // 5.create filter date_string_col == "01/01/09"`; + // Note this test doesn't apply type coercion so the literal must match the actual view type + let filter = col("date_string_col").eq(lit(ScalarValue::new_utf8view("01/01/09"))); let parquet_exec = get_parquet_exec(&state, filter).await; let mut results = parquet_exec.execute(0, task_ctx.clone()).unwrap(); let batch = results.next().await.unwrap().unwrap(); diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index 1a174a325bd5..536ac5414a9a 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -109,7 +109,7 @@ impl RowGroupPruningTest { assert_eq!( output.predicate_evaluation_errors(), self.expected_errors, - "mismatched predicate_evaluation" + "mismatched predicate_evaluation error" ); assert_eq!( output.row_groups_matched_statistics(), diff --git a/datafusion/core/tests/parquet/schema.rs b/datafusion/core/tests/parquet/schema.rs index 1b572914d7bd..e13fbad24426 100644 --- a/datafusion/core/tests/parquet/schema.rs +++ b/datafusion/core/tests/parquet/schema.rs @@ -25,7 +25,7 @@ use datafusion_common::assert_batches_sorted_eq; #[tokio::test] async fn schema_merge_ignores_metadata_by_default() { - // Create several parquet files in same directoty / table with + // Create several parquet files in same directory / table with // same schema but different metadata let tmp_dir = TempDir::new().unwrap(); let table_dir = tmp_dir.path().join("parquet_test"); @@ -103,7 +103,7 @@ async fn schema_merge_ignores_metadata_by_default() { #[tokio::test] async fn schema_merge_can_preserve_metadata() { - // Create several parquet files in same directoty / table with + // Create several parquet files in same directory / table with // same schema but different metadata let tmp_dir = TempDir::new().unwrap(); let table_dir = tmp_dir.path().join("parquet_test"); diff --git a/datafusion/core/tests/parquet/schema_coercion.rs b/datafusion/core/tests/parquet/schema_coercion.rs index 88f795d2a4fe..af9411f40ecb 100644 --- a/datafusion/core/tests/parquet/schema_coercion.rs +++ b/datafusion/core/tests/parquet/schema_coercion.rs @@ -26,7 +26,7 @@ use datafusion::assert_batches_sorted_eq; use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec}; use datafusion::physical_plan::collect; use datafusion::prelude::SessionContext; -use datafusion_common::{Result, Statistics}; +use datafusion_common::Result; use datafusion_execution::object_store::ObjectStoreUrl; use object_store::path::Path; @@ -51,7 +51,7 @@ async fn multi_parquet_coercion() { let batch2 = RecordBatch::try_from_iter(vec![("c2", c2), ("c3", c3)]).unwrap(); let (meta, _files) = store_parquet(vec![batch1, batch2]).await.unwrap(); - let file_groups = meta.into_iter().map(Into::into).collect(); + let file_group = meta.into_iter().map(Into::into).collect(); // cast c1 to utf8, c2 to int32, c3 to float64 let file_schema = Arc::new(Schema::new(vec![ @@ -59,21 +59,11 @@ async fn multi_parquet_coercion() { Field::new("c2", DataType::Int32, true), Field::new("c3", DataType::Float64, true), ])); - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: vec![file_groups], - statistics: Statistics::new_unknown(&file_schema), - file_schema, - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - }, - None, - None, - Default::default(), - ); + let parquet_exec = ParquetExec::builder( + FileScanConfig::new(ObjectStoreUrl::local_filesystem(), file_schema) + .with_file_group(file_group), + ) + .build(); let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); @@ -115,7 +105,7 @@ async fn multi_parquet_coercion_projection() { RecordBatch::try_from_iter(vec![("c2", c2), ("c1", c1s), ("c3", c3)]).unwrap(); let (meta, _files) = store_parquet(vec![batch1, batch2]).await.unwrap(); - let file_groups = meta.into_iter().map(Into::into).collect(); + let file_group = meta.into_iter().map(Into::into).collect(); // cast c1 to utf8, c2 to int32, c3 to float64 let file_schema = Arc::new(Schema::new(vec![ @@ -123,21 +113,12 @@ async fn multi_parquet_coercion_projection() { Field::new("c2", DataType::Int32, true), Field::new("c3", DataType::Float64, true), ])); - let parquet_exec = ParquetExec::new( - FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_groups: vec![file_groups], - statistics: Statistics::new_unknown(&file_schema), - file_schema, - projection: Some(vec![1, 0, 2]), - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - }, - None, - None, - Default::default(), - ); + let parquet_exec = ParquetExec::builder( + FileScanConfig::new(ObjectStoreUrl::local_filesystem(), file_schema) + .with_file_group(file_group) + .with_projection(Some(vec![1, 0, 2])), + ) + .build(); let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); diff --git a/datafusion/core/tests/parquet/utils.rs b/datafusion/core/tests/parquet/utils.rs new file mode 100644 index 000000000000..d8d2b2fbb8a5 --- /dev/null +++ b/datafusion/core/tests/parquet/utils.rs @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utilities for parquet tests + +use datafusion::datasource::physical_plan::ParquetExec; +use datafusion_physical_plan::metrics::MetricsSet; +use datafusion_physical_plan::{accept, ExecutionPlan, ExecutionPlanVisitor}; + +/// Find the metrics from the first ParquetExec encountered in the plan +#[derive(Debug)] +pub struct MetricsFinder { + metrics: Option, +} +impl MetricsFinder { + pub fn new() -> Self { + Self { metrics: None } + } + + /// Return the metrics if found + pub fn into_metrics(self) -> Option { + self.metrics + } + + pub fn find_metrics(plan: &dyn ExecutionPlan) -> Option { + let mut finder = Self::new(); + accept(plan, &mut finder).unwrap(); + finder.into_metrics() + } +} + +impl ExecutionPlanVisitor for MetricsFinder { + type Error = std::convert::Infallible; + fn pre_visit(&mut self, plan: &dyn ExecutionPlan) -> Result { + if plan.as_any().downcast_ref::().is_some() { + self.metrics = plan.metrics(); + } + // stop searching once we have found the metrics + Ok(self.metrics.is_none()) + } +} diff --git a/datafusion/core/tests/parquet_exec.rs b/datafusion/core/tests/parquet_exec.rs index 43ceb615a062..f41f82a76c67 100644 --- a/datafusion/core/tests/parquet_exec.rs +++ b/datafusion/core/tests/parquet_exec.rs @@ -15,5 +15,7 @@ // specific language governing permissions and limitations // under the License. +//! End to end test for `ParquetExec` and related components + /// Run all tests that are found in the `parquet` directory mod parquet; diff --git a/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs new file mode 100644 index 000000000000..85076abdaf29 --- /dev/null +++ b/datafusion/core/tests/physical_optimizer/combine_partial_final_agg.rs @@ -0,0 +1,292 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::datasource::listing::PartitionedFile; +use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec}; +use datafusion::physical_optimizer::combine_partial_final_agg::CombinePartialFinalAggregate; +use datafusion_common::config::ConfigOptions; +use datafusion_execution::object_store::ObjectStoreUrl; +use datafusion_functions_aggregate::count::count_udaf; +use datafusion_functions_aggregate::sum::sum_udaf; +use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; +use datafusion_physical_expr::expressions::{col, lit}; +use datafusion_physical_expr::Partitioning; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, +}; +use datafusion_physical_plan::displayable; +use datafusion_physical_plan::repartition::RepartitionExec; +use datafusion_physical_plan::ExecutionPlan; + +/// Runs the CombinePartialFinalAggregate optimizer and asserts the plan against the expected +macro_rules! assert_optimized { + ($EXPECTED_LINES: expr, $PLAN: expr) => { + let expected_lines: Vec<&str> = $EXPECTED_LINES.iter().map(|s| *s).collect(); + + // run optimizer + let optimizer = CombinePartialFinalAggregate {}; + let config = ConfigOptions::new(); + let optimized = optimizer.optimize($PLAN, &config)?; + // Now format correctly + let plan = displayable(optimized.as_ref()).indent(true).to_string(); + let actual_lines = trim_plan_display(&plan); + + assert_eq!( + &expected_lines, &actual_lines, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected_lines, actual_lines + ); + }; +} + +fn trim_plan_display(plan: &str) -> Vec<&str> { + plan.split('\n') + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .collect() +} + +fn schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + Field::new("c", DataType::Int64, true), + ])) +} + +fn parquet_exec(schema: &SchemaRef) -> Arc { + ParquetExec::builder( + FileScanConfig::new(ObjectStoreUrl::parse("test:///").unwrap(), schema.clone()) + .with_file(PartitionedFile::new("x".to_string(), 100)), + ) + .build_arc() +} + +fn partial_aggregate_exec( + input: Arc, + group_by: PhysicalGroupBy, + aggr_expr: Vec>, +) -> Arc { + let schema = input.schema(); + let n_aggr = aggr_expr.len(); + Arc::new( + AggregateExec::try_new( + AggregateMode::Partial, + group_by, + aggr_expr, + vec![None; n_aggr], + input, + schema, + ) + .unwrap(), + ) +} + +fn final_aggregate_exec( + input: Arc, + group_by: PhysicalGroupBy, + aggr_expr: Vec>, +) -> Arc { + let schema = input.schema(); + let n_aggr = aggr_expr.len(); + Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + group_by, + aggr_expr, + vec![None; n_aggr], + input, + schema, + ) + .unwrap(), + ) +} + +fn repartition_exec(input: Arc) -> Arc { + Arc::new(RepartitionExec::try_new(input, Partitioning::RoundRobinBatch(10)).unwrap()) +} + +// Return appropriate expr depending if COUNT is for col or table (*) +fn count_expr( + expr: Arc, + name: &str, + schema: &Schema, +) -> Arc { + AggregateExprBuilder::new(count_udaf(), vec![expr]) + .schema(Arc::new(schema.clone())) + .alias(name) + .build() + .map(Arc::new) + .unwrap() +} + +#[test] +fn aggregations_not_combined() -> datafusion_common::Result<()> { + let schema = schema(); + + let aggr_expr = vec![count_expr(lit(1i8), "COUNT(1)", &schema)]; + + let plan = final_aggregate_exec( + repartition_exec(partial_aggregate_exec( + parquet_exec(&schema), + PhysicalGroupBy::default(), + aggr_expr.clone(), + )), + PhysicalGroupBy::default(), + aggr_expr, + ); + // should not combine the Partial/Final AggregateExecs + let expected = &[ + "AggregateExec: mode=Final, gby=[], aggr=[COUNT(1)]", + "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=1", + "AggregateExec: mode=Partial, gby=[], aggr=[COUNT(1)]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c]", + ]; + assert_optimized!(expected, plan); + + let aggr_expr1 = vec![count_expr(lit(1i8), "COUNT(1)", &schema)]; + let aggr_expr2 = vec![count_expr(lit(1i8), "COUNT(2)", &schema)]; + + let plan = final_aggregate_exec( + partial_aggregate_exec( + parquet_exec(&schema), + PhysicalGroupBy::default(), + aggr_expr1, + ), + PhysicalGroupBy::default(), + aggr_expr2, + ); + // should not combine the Partial/Final AggregateExecs + let expected = &[ + "AggregateExec: mode=Final, gby=[], aggr=[COUNT(2)]", + "AggregateExec: mode=Partial, gby=[], aggr=[COUNT(1)]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c]", + ]; + + assert_optimized!(expected, plan); + + Ok(()) +} + +#[test] +fn aggregations_combined() -> datafusion_common::Result<()> { + let schema = schema(); + let aggr_expr = vec![count_expr(lit(1i8), "COUNT(1)", &schema)]; + + let plan = final_aggregate_exec( + partial_aggregate_exec( + parquet_exec(&schema), + PhysicalGroupBy::default(), + aggr_expr.clone(), + ), + PhysicalGroupBy::default(), + aggr_expr, + ); + // should combine the Partial/Final AggregateExecs to the Single AggregateExec + let expected = &[ + "AggregateExec: mode=Single, gby=[], aggr=[COUNT(1)]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c]", + ]; + + assert_optimized!(expected, plan); + Ok(()) +} + +#[test] +fn aggregations_with_group_combined() -> datafusion_common::Result<()> { + let schema = schema(); + let aggr_expr = vec![ + AggregateExprBuilder::new(sum_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("Sum(b)") + .build() + .map(Arc::new) + .unwrap(), + ]; + let groups: Vec<(Arc, String)> = + vec![(col("c", &schema)?, "c".to_string())]; + + let partial_group_by = PhysicalGroupBy::new_single(groups); + let partial_agg = partial_aggregate_exec( + parquet_exec(&schema), + partial_group_by, + aggr_expr.clone(), + ); + + let groups: Vec<(Arc, String)> = + vec![(col("c", &partial_agg.schema())?, "c".to_string())]; + let final_group_by = PhysicalGroupBy::new_single(groups); + + let plan = final_aggregate_exec(partial_agg, final_group_by, aggr_expr); + // should combine the Partial/Final AggregateExecs to the Single AggregateExec + let expected = &[ + "AggregateExec: mode=Single, gby=[c@2 as c], aggr=[Sum(b)]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c]", + ]; + + assert_optimized!(expected, plan); + Ok(()) +} + +#[test] +fn aggregations_with_limit_combined() -> datafusion_common::Result<()> { + let schema = schema(); + let aggr_expr = vec![]; + + let groups: Vec<(Arc, String)> = + vec![(col("c", &schema)?, "c".to_string())]; + + let partial_group_by = PhysicalGroupBy::new_single(groups); + let partial_agg = partial_aggregate_exec( + parquet_exec(&schema), + partial_group_by, + aggr_expr.clone(), + ); + + let groups: Vec<(Arc, String)> = + vec![(col("c", &partial_agg.schema())?, "c".to_string())]; + let final_group_by = PhysicalGroupBy::new_single(groups); + + let schema = partial_agg.schema(); + let final_agg = Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + final_group_by, + aggr_expr, + vec![], + partial_agg, + schema, + ) + .unwrap() + .with_limit(Some(5)), + ); + let plan: Arc = final_agg; + // should combine the Partial/Final AggregateExecs to a Single AggregateExec + // with the final limit preserved + let expected = &[ + "AggregateExec: mode=Single, gby=[c@2 as c], aggr=[], lim=[5]", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c]", + ]; + + assert_optimized!(expected, plan); + Ok(()) +} diff --git a/datafusion/core/tests/physical_optimizer/limit_pushdown.rs b/datafusion/core/tests/physical_optimizer/limit_pushdown.rs new file mode 100644 index 000000000000..1b4c28d41d19 --- /dev/null +++ b/datafusion/core/tests/physical_optimizer/limit_pushdown.rs @@ -0,0 +1,490 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_schema::{DataType, Field, Schema, SchemaRef, SortOptions}; +use datafusion_common::config::ConfigOptions; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_expr::Operator; +use datafusion_physical_expr::expressions::BinaryExpr; +use datafusion_physical_expr::expressions::{col, lit}; +use datafusion_physical_expr::Partitioning; +use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; +use datafusion_physical_optimizer::limit_pushdown::LimitPushdown; +use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; +use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion_physical_plan::empty::EmptyExec; +use datafusion_physical_plan::filter::FilterExec; +use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use datafusion_physical_plan::projection::ProjectionExec; +use datafusion_physical_plan::repartition::RepartitionExec; +use datafusion_physical_plan::sorts::sort::SortExec; +use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use datafusion_physical_plan::streaming::{PartitionStream, StreamingTableExec}; +use datafusion_physical_plan::{get_plan_string, ExecutionPlan, ExecutionPlanProperties}; +use std::sync::Arc; + +#[derive(Debug)] +struct DummyStreamPartition { + schema: SchemaRef, +} +impl PartitionStream for DummyStreamPartition { + fn schema(&self) -> &SchemaRef { + &self.schema + } + fn execute(&self, _ctx: Arc) -> SendableRecordBatchStream { + unreachable!() + } +} + +#[test] +fn transforms_streaming_table_exec_into_fetching_version_when_skip_is_zero( +) -> datafusion_common::Result<()> { + let schema = create_schema(); + let streaming_table = streaming_table_exec(schema)?; + let global_limit = global_limit_exec(streaming_table, 0, Some(5)); + + let initial = get_plan_string(&global_limit); + let expected_initial = [ + "GlobalLimitExec: skip=0, fetch=5", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; + + let expected = [ + "StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true, fetch=5" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) +} + +#[test] +fn transforms_streaming_table_exec_into_fetching_version_and_keeps_the_global_limit_when_skip_is_nonzero( +) -> datafusion_common::Result<()> { + let schema = create_schema(); + let streaming_table = streaming_table_exec(schema)?; + let global_limit = global_limit_exec(streaming_table, 2, Some(5)); + + let initial = get_plan_string(&global_limit); + let expected_initial = [ + "GlobalLimitExec: skip=2, fetch=5", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; + + let expected = [ + "GlobalLimitExec: skip=2, fetch=5", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true, fetch=7" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) +} + +#[test] +fn transforms_coalesce_batches_exec_into_fetching_version_and_removes_local_limit( +) -> datafusion_common::Result<()> { + let schema = create_schema(); + let streaming_table = streaming_table_exec(schema.clone())?; + let repartition = repartition_exec(streaming_table)?; + let filter = filter_exec(schema, repartition)?; + let coalesce_batches = coalesce_batches_exec(filter); + let local_limit = local_limit_exec(coalesce_batches, 5); + let coalesce_partitions = coalesce_partitions_exec(local_limit); + let global_limit = global_limit_exec(coalesce_partitions, 0, Some(5)); + + let initial = get_plan_string(&global_limit); + let expected_initial = [ + "GlobalLimitExec: skip=0, fetch=5", + " CoalescePartitionsExec", + " LocalLimitExec: fetch=5", + " CoalesceBatchesExec: target_batch_size=8192", + " FilterExec: c3@2 > 0", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; + + let expected = [ + "GlobalLimitExec: skip=0, fetch=5", + " CoalescePartitionsExec", + " CoalesceBatchesExec: target_batch_size=8192, fetch=5", + " FilterExec: c3@2 > 0", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) +} + +#[test] +fn pushes_global_limit_exec_through_projection_exec() -> datafusion_common::Result<()> { + let schema = create_schema(); + let streaming_table = streaming_table_exec(schema.clone())?; + let filter = filter_exec(schema.clone(), streaming_table)?; + let projection = projection_exec(schema, filter)?; + let global_limit = global_limit_exec(projection, 0, Some(5)); + + let initial = get_plan_string(&global_limit); + let expected_initial = [ + "GlobalLimitExec: skip=0, fetch=5", + " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", + " FilterExec: c3@2 > 0", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; + + let expected = [ + "ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", + " GlobalLimitExec: skip=0, fetch=5", + " FilterExec: c3@2 > 0", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) +} + +#[test] +fn pushes_global_limit_exec_through_projection_exec_and_transforms_coalesce_batches_exec_into_fetching_version( +) -> datafusion_common::Result<()> { + let schema = create_schema(); + let streaming_table = streaming_table_exec(schema.clone()).unwrap(); + let coalesce_batches = coalesce_batches_exec(streaming_table); + let projection = projection_exec(schema, coalesce_batches)?; + let global_limit = global_limit_exec(projection, 0, Some(5)); + + let initial = get_plan_string(&global_limit); + let expected_initial = [ + "GlobalLimitExec: skip=0, fetch=5", + " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", + " CoalesceBatchesExec: target_batch_size=8192", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" + ]; + + assert_eq!(initial, expected_initial); + + let after_optimize = + LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; + + let expected = [ + "ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", + " CoalesceBatchesExec: target_batch_size=8192, fetch=5", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) +} + +#[test] +fn pushes_global_limit_into_multiple_fetch_plans() -> datafusion_common::Result<()> { + let schema = create_schema(); + let streaming_table = streaming_table_exec(schema.clone()).unwrap(); + let coalesce_batches = coalesce_batches_exec(streaming_table); + let projection = projection_exec(schema.clone(), coalesce_batches)?; + let repartition = repartition_exec(projection)?; + let sort = sort_exec( + vec![PhysicalSortExpr { + expr: col("c1", &schema)?, + options: SortOptions::default(), + }], + repartition, + ); + let spm = sort_preserving_merge_exec(sort.output_ordering().unwrap().to_vec(), sort); + let global_limit = global_limit_exec(spm, 0, Some(5)); + + let initial = get_plan_string(&global_limit); + let expected_initial = [ + "GlobalLimitExec: skip=0, fetch=5", + " SortPreservingMergeExec: [c1@0 ASC]", + " SortExec: expr=[c1@0 ASC], preserve_partitioning=[false]", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", + " CoalesceBatchesExec: target_batch_size=8192", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" + ]; + + assert_eq!(initial, expected_initial); + + let after_optimize = + LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; + + let expected = [ + "SortPreservingMergeExec: [c1@0 ASC], fetch=5", + " SortExec: TopK(fetch=5), expr=[c1@0 ASC], preserve_partitioning=[false]", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3]", + " CoalesceBatchesExec: target_batch_size=8192", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) +} + +#[test] +fn keeps_pushed_local_limit_exec_when_there_are_multiple_input_partitions( +) -> datafusion_common::Result<()> { + let schema = create_schema(); + let streaming_table = streaming_table_exec(schema.clone())?; + let repartition = repartition_exec(streaming_table)?; + let filter = filter_exec(schema, repartition)?; + let coalesce_partitions = coalesce_partitions_exec(filter); + let global_limit = global_limit_exec(coalesce_partitions, 0, Some(5)); + + let initial = get_plan_string(&global_limit); + let expected_initial = [ + "GlobalLimitExec: skip=0, fetch=5", + " CoalescePartitionsExec", + " FilterExec: c3@2 > 0", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" + ]; + assert_eq!(initial, expected_initial); + + let after_optimize = + LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; + + let expected = [ + "GlobalLimitExec: skip=0, fetch=5", + " CoalescePartitionsExec", + " FilterExec: c3@2 > 0", + " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", + " StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true" + ]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) +} + +#[test] +fn merges_local_limit_with_local_limit() -> datafusion_common::Result<()> { + let schema = create_schema(); + let empty_exec = empty_exec(schema); + let child_local_limit = local_limit_exec(empty_exec, 10); + let parent_local_limit = local_limit_exec(child_local_limit, 20); + + let initial = get_plan_string(&parent_local_limit); + let expected_initial = [ + "LocalLimitExec: fetch=20", + " LocalLimitExec: fetch=10", + " EmptyExec", + ]; + + assert_eq!(initial, expected_initial); + + let after_optimize = + LimitPushdown::new().optimize(parent_local_limit, &ConfigOptions::new())?; + + let expected = ["GlobalLimitExec: skip=0, fetch=10", " EmptyExec"]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) +} + +#[test] +fn merges_global_limit_with_global_limit() -> datafusion_common::Result<()> { + let schema = create_schema(); + let empty_exec = empty_exec(schema); + let child_global_limit = global_limit_exec(empty_exec, 10, Some(30)); + let parent_global_limit = global_limit_exec(child_global_limit, 10, Some(20)); + + let initial = get_plan_string(&parent_global_limit); + let expected_initial = [ + "GlobalLimitExec: skip=10, fetch=20", + " GlobalLimitExec: skip=10, fetch=30", + " EmptyExec", + ]; + + assert_eq!(initial, expected_initial); + + let after_optimize = + LimitPushdown::new().optimize(parent_global_limit, &ConfigOptions::new())?; + + let expected = ["GlobalLimitExec: skip=20, fetch=20", " EmptyExec"]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) +} + +#[test] +fn merges_global_limit_with_local_limit() -> datafusion_common::Result<()> { + let schema = create_schema(); + let empty_exec = empty_exec(schema); + let local_limit = local_limit_exec(empty_exec, 40); + let global_limit = global_limit_exec(local_limit, 20, Some(30)); + + let initial = get_plan_string(&global_limit); + let expected_initial = [ + "GlobalLimitExec: skip=20, fetch=30", + " LocalLimitExec: fetch=40", + " EmptyExec", + ]; + + assert_eq!(initial, expected_initial); + + let after_optimize = + LimitPushdown::new().optimize(global_limit, &ConfigOptions::new())?; + + let expected = ["GlobalLimitExec: skip=20, fetch=20", " EmptyExec"]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) +} + +#[test] +fn merges_local_limit_with_global_limit() -> datafusion_common::Result<()> { + let schema = create_schema(); + let empty_exec = empty_exec(schema); + let global_limit = global_limit_exec(empty_exec, 20, Some(30)); + let local_limit = local_limit_exec(global_limit, 20); + + let initial = get_plan_string(&local_limit); + let expected_initial = [ + "LocalLimitExec: fetch=20", + " GlobalLimitExec: skip=20, fetch=30", + " EmptyExec", + ]; + + assert_eq!(initial, expected_initial); + + let after_optimize = + LimitPushdown::new().optimize(local_limit, &ConfigOptions::new())?; + + let expected = ["GlobalLimitExec: skip=20, fetch=20", " EmptyExec"]; + assert_eq!(get_plan_string(&after_optimize), expected); + + Ok(()) +} + +fn create_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Int32, true), + Field::new("c3", DataType::Int32, true), + ])) +} + +fn streaming_table_exec( + schema: SchemaRef, +) -> datafusion_common::Result> { + Ok(Arc::new(StreamingTableExec::try_new( + schema.clone(), + vec![Arc::new(DummyStreamPartition { schema }) as _], + None, + None, + true, + None, + )?)) +} + +fn global_limit_exec( + input: Arc, + skip: usize, + fetch: Option, +) -> Arc { + Arc::new(GlobalLimitExec::new(input, skip, fetch)) +} + +fn local_limit_exec( + input: Arc, + fetch: usize, +) -> Arc { + Arc::new(LocalLimitExec::new(input, fetch)) +} + +fn sort_exec( + sort_exprs: impl IntoIterator, + input: Arc, +) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect(); + Arc::new(SortExec::new(sort_exprs, input)) +} + +fn sort_preserving_merge_exec( + sort_exprs: impl IntoIterator, + input: Arc, +) -> Arc { + let sort_exprs = sort_exprs.into_iter().collect(); + Arc::new(SortPreservingMergeExec::new(sort_exprs, input)) +} + +fn projection_exec( + schema: SchemaRef, + input: Arc, +) -> datafusion_common::Result> { + Ok(Arc::new(ProjectionExec::try_new( + vec![ + (col("c1", schema.as_ref()).unwrap(), "c1".to_string()), + (col("c2", schema.as_ref()).unwrap(), "c2".to_string()), + (col("c3", schema.as_ref()).unwrap(), "c3".to_string()), + ], + input, + )?)) +} + +fn filter_exec( + schema: SchemaRef, + input: Arc, +) -> datafusion_common::Result> { + Ok(Arc::new(FilterExec::try_new( + Arc::new(BinaryExpr::new( + col("c3", schema.as_ref()).unwrap(), + Operator::Gt, + lit(0), + )), + input, + )?)) +} + +fn coalesce_batches_exec(input: Arc) -> Arc { + Arc::new(CoalesceBatchesExec::new(input, 8192)) +} + +fn coalesce_partitions_exec( + local_limit: Arc, +) -> Arc { + Arc::new(CoalescePartitionsExec::new(local_limit)) +} + +fn repartition_exec( + streaming_table: Arc, +) -> datafusion_common::Result> { + Ok(Arc::new(RepartitionExec::try_new( + streaming_table, + Partitioning::RoundRobinBatch(8), + )?)) +} + +fn empty_exec(schema: SchemaRef) -> Arc { + Arc::new(EmptyExec::new(schema)) +} diff --git a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs new file mode 100644 index 000000000000..6910db6285a3 --- /dev/null +++ b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs @@ -0,0 +1,441 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for the limited distinct aggregation optimizer rule + +use super::test_util::{parquet_exec_with_sort, schema, trim_plan_display}; + +use std::sync::Arc; + +use arrow::{ + array::Int32Array, + compute::SortOptions, + datatypes::{DataType, Field, Schema}, + record_batch::RecordBatch, + util::pretty::pretty_format_batches, +}; +use arrow_schema::SchemaRef; +use datafusion::{prelude::SessionContext, test_util::TestAggregate}; +use datafusion_common::Result; +use datafusion_execution::config::SessionConfig; +use datafusion_expr::Operator; +use datafusion_physical_expr::{ + expressions::{cast, col}, + PhysicalExpr, PhysicalSortExpr, +}; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_optimizer::{ + limited_distinct_aggregation::LimitedDistinctAggregation, PhysicalOptimizerRule, +}; +use datafusion_physical_plan::{ + aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}, + collect, displayable, expressions, + limit::{GlobalLimitExec, LocalLimitExec}, + memory::MemoryExec, + ExecutionPlan, +}; + +fn mock_data() -> Result> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + None, + Some(1), + Some(4), + Some(5), + ])), + Arc::new(Int32Array::from(vec![ + Some(1), + None, + Some(6), + Some(2), + Some(8), + Some(9), + ])), + ], + )?; + + Ok(Arc::new(MemoryExec::try_new( + &[vec![batch]], + Arc::clone(&schema), + None, + )?)) +} + +fn assert_plan_matches_expected( + plan: &Arc, + expected: &[&str], +) -> Result<()> { + let expected_lines: Vec<&str> = expected.to_vec(); + let session_ctx = SessionContext::new(); + let state = session_ctx.state(); + + let optimized = LimitedDistinctAggregation::new() + .optimize(Arc::clone(plan), state.config_options())?; + + let optimized_result = displayable(optimized.as_ref()).indent(true).to_string(); + let actual_lines = trim_plan_display(&optimized_result); + + assert_eq!( + &expected_lines, &actual_lines, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected_lines, actual_lines + ); + + Ok(()) +} + +async fn assert_results_match_expected( + plan: Arc, + expected: &str, +) -> Result<()> { + let cfg = SessionConfig::new().with_target_partitions(1); + let ctx = SessionContext::new_with_config(cfg); + let batches = collect(plan, ctx.task_ctx()).await?; + let actual = format!("{}", pretty_format_batches(&batches)?); + assert_eq!(actual, expected); + Ok(()) +} + +pub fn build_group_by(input_schema: &SchemaRef, columns: Vec) -> PhysicalGroupBy { + let mut group_by_expr: Vec<(Arc, String)> = vec![]; + for column in columns.iter() { + group_by_expr.push((col(column, input_schema).unwrap(), column.to_string())); + } + PhysicalGroupBy::new_single(group_by_expr.clone()) +} + +#[tokio::test] +async fn test_partial_final() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT a FROM MemoryExec GROUP BY a LIMIT 4;`, Partial/Final AggregateExec + let partial_agg = AggregateExec::try_new( + AggregateMode::Partial, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let final_agg = AggregateExec::try_new( + AggregateMode::Final, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![], /* filter_expr */ + Arc::new(partial_agg), /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(final_agg), + 4, // fetch + ); + // expected to push the limit to the Partial and Final AggregateExecs + let expected = [ + "LocalLimitExec: fetch=4", + "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[], lim=[4]", + "AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[], lim=[4]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + let expected = r#" ++---+ +| a | ++---+ +| 1 | +| 2 | +| | +| 4 | ++---+ +"# + .trim(); + assert_results_match_expected(plan, expected).await?; + Ok(()) +} + +#[tokio::test] +async fn test_single_local() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT a FROM MemoryExec GROUP BY a LIMIT 4;`, Single AggregateExec + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(single_agg), + 4, // fetch + ); + // expected to push the limit to the AggregateExec + let expected = [ + "LocalLimitExec: fetch=4", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + let expected = r#" ++---+ +| a | ++---+ +| 1 | +| 2 | +| | +| 4 | ++---+ +"# + .trim(); + assert_results_match_expected(plan, expected).await?; + Ok(()) +} + +#[tokio::test] +async fn test_single_global() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT a FROM MemoryExec GROUP BY a LIMIT 4;`, Single AggregateExec + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = GlobalLimitExec::new( + Arc::new(single_agg), + 1, // skip + Some(3), // fetch + ); + // expected to push the skip+fetch limit to the AggregateExec + let expected = [ + "GlobalLimitExec: skip=1, fetch=3", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + let expected = r#" ++---+ +| a | ++---+ +| 2 | +| | +| 4 | ++---+ +"# + .trim(); + assert_results_match_expected(plan, expected).await?; + Ok(()) +} + +#[tokio::test] +async fn test_distinct_cols_different_than_group_by_cols() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT distinct a FROM MemoryExec GROUP BY a, b LIMIT 4;`, Single/Single AggregateExec + let group_by_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string(), "b".to_string()]), + vec![], /* aggr_expr */ + vec![], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let distinct_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![], /* filter_expr */ + Arc::new(group_by_agg), /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(distinct_agg), + 4, // fetch + ); + // expected to push the limit to the outer AggregateExec only + let expected = [ + "LocalLimitExec: fetch=4", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], lim=[4]", + "AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + let expected = r#" ++---+ +| a | ++---+ +| 1 | +| 2 | +| | +| 4 | ++---+ +"# + .trim(); + assert_results_match_expected(plan, expected).await?; + Ok(()) +} + +#[test] +fn test_no_group_by() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT FROM MemoryExec LIMIT 10;`, Single AggregateExec + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema, vec![]), + vec![], /* aggr_expr */ + vec![], /* filter_expr */ + source, /* input */ + schema, /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(single_agg), + 10, // fetch + ); + // expected not to push the limit to the AggregateExec + let expected = [ + "LocalLimitExec: fetch=10", + "AggregateExec: mode=Single, gby=[], aggr=[]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + Ok(()) +} + +#[test] +fn test_has_aggregate_expression() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + let agg = TestAggregate::new_count_star(); + + // `SELECT FROM MemoryExec LIMIT 10;`, Single AggregateExec + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema, vec!["a".to_string()]), + vec![Arc::new(agg.count_expr(&schema))], /* aggr_expr */ + vec![None], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(single_agg), + 10, // fetch + ); + // expected not to push the limit to the AggregateExec + let expected = [ + "LocalLimitExec: fetch=10", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[COUNT(*)]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + Ok(()) +} + +#[test] +fn test_has_filter() -> Result<()> { + let source = mock_data()?; + let schema = source.schema(); + + // `SELECT a FROM MemoryExec WHERE a > 1 GROUP BY a LIMIT 10;`, Single AggregateExec + // the `a > 1` filter is applied in the AggregateExec + let filter_expr = Some(expressions::binary( + col("a", &schema)?, + Operator::Gt, + cast(expressions::lit(1u32), &schema, DataType::Int32)?, + &schema, + )?); + let agg = TestAggregate::new_count_star(); + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema.clone(), vec!["a".to_string()]), + vec![Arc::new(agg.count_expr(&schema))], /* aggr_expr */ + vec![filter_expr], /* filter_expr */ + source, /* input */ + schema.clone(), /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(single_agg), + 10, // fetch + ); + // expected not to push the limit to the AggregateExec + // TODO(msirek): open an issue for `filter_expr` of `AggregateExec` not printing out + let expected = [ + "LocalLimitExec: fetch=10", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[COUNT(*)]", + "MemoryExec: partitions=1, partition_sizes=[1]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + Ok(()) +} + +#[test] +fn test_has_order_by() -> Result<()> { + let sort_key = LexOrdering::new(vec![PhysicalSortExpr { + expr: col("a", &schema()).unwrap(), + options: SortOptions::default(), + }]); + let source = parquet_exec_with_sort(vec![sort_key]); + let schema = source.schema(); + + // `SELECT a FROM MemoryExec WHERE a > 1 GROUP BY a LIMIT 10;`, Single AggregateExec + // the `a > 1` filter is applied in the AggregateExec + let single_agg = AggregateExec::try_new( + AggregateMode::Single, + build_group_by(&schema, vec!["a".to_string()]), + vec![], /* aggr_expr */ + vec![], /* filter_expr */ + source, /* input */ + schema, /* input_schema */ + )?; + let limit_exec = LocalLimitExec::new( + Arc::new(single_agg), + 10, // fetch + ); + // expected not to push the limit to the AggregateExec + let expected = [ + "LocalLimitExec: fetch=10", + "AggregateExec: mode=Single, gby=[a@0 as a], aggr=[], ordering_mode=Sorted", + "ParquetExec: file_groups={1 group: [[x]]}, projection=[a, b, c, d, e], output_ordering=[a@0 ASC]", + ]; + let plan: Arc = Arc::new(limit_exec); + assert_plan_matches_expected(&plan, &expected)?; + Ok(()) +} diff --git a/datafusion/core/tests/physical_optimizer/mod.rs b/datafusion/core/tests/physical_optimizer/mod.rs new file mode 100644 index 000000000000..c06783aa0277 --- /dev/null +++ b/datafusion/core/tests/physical_optimizer/mod.rs @@ -0,0 +1,21 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod combine_partial_final_agg; +mod limit_pushdown; +mod limited_distinct_aggregation; +mod test_util; diff --git a/datafusion/core/tests/physical_optimizer/test_util.rs b/datafusion/core/tests/physical_optimizer/test_util.rs new file mode 100644 index 000000000000..12cd08fb3db3 --- /dev/null +++ b/datafusion/core/tests/physical_optimizer/test_util.rs @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Test utilities for physical optimizer tests + +use std::sync::Arc; + +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use datafusion::datasource::{ + listing::PartitionedFile, + physical_plan::{FileScanConfig, ParquetExec}, +}; +use datafusion_execution::object_store::ObjectStoreUrl; +use datafusion_physical_expr_common::sort_expr::LexOrdering; + +/// create a single parquet file that is sorted +pub(crate) fn parquet_exec_with_sort( + output_ordering: Vec, +) -> Arc { + ParquetExec::builder( + FileScanConfig::new(ObjectStoreUrl::parse("test:///").unwrap(), schema()) + .with_file(PartitionedFile::new("x".to_string(), 100)) + .with_output_ordering(output_ordering), + ) + .build_arc() +} + +pub(crate) fn schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + Field::new("c", DataType::Int64, true), + Field::new("d", DataType::Int32, true), + Field::new("e", DataType::Boolean, true), + ])) +} + +pub(crate) fn trim_plan_display(plan: &str) -> Vec<&str> { + plan.split('\n') + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .collect() +} diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index 84b791a3de05..1f10cb244e83 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -35,9 +35,9 @@ async fn csv_query_array_agg_distinct() -> Result<()> { assert_eq!( *actual[0].schema(), Schema::new(vec![Field::new_list( - "ARRAY_AGG(DISTINCT aggregate_test_100.c2)", + "array_agg(DISTINCT aggregate_test_100.c2)", Field::new("item", DataType::UInt32, true), - false + true ),]) ); @@ -69,12 +69,12 @@ async fn csv_query_array_agg_distinct() -> Result<()> { #[tokio::test] async fn count_partitioned() -> Result<()> { let results = - execute_with_partition("SELECT COUNT(c1), COUNT(c2) FROM test", 4).await?; + execute_with_partition("SELECT count(c1), count(c2) FROM test", 4).await?; assert_eq!(results.len(), 1); let expected = [ "+----------------+----------------+", - "| COUNT(test.c1) | COUNT(test.c2) |", + "| count(test.c1) | count(test.c2) |", "+----------------+----------------+", "| 40 | 40 |", "+----------------+----------------+", @@ -86,11 +86,11 @@ async fn count_partitioned() -> Result<()> { #[tokio::test] async fn count_aggregated() -> Result<()> { let results = - execute_with_partition("SELECT c1, COUNT(c2) FROM test GROUP BY c1", 4).await?; + execute_with_partition("SELECT c1, count(c2) FROM test GROUP BY c1", 4).await?; let expected = [ "+----+----------------+", - "| c1 | COUNT(test.c2) |", + "| c1 | count(test.c2) |", "+----+----------------+", "| 0 | 10 |", "| 1 | 10 |", @@ -105,14 +105,14 @@ async fn count_aggregated() -> Result<()> { #[tokio::test] async fn count_aggregated_cube() -> Result<()> { let results = execute_with_partition( - "SELECT c1, c2, COUNT(c3) FROM test GROUP BY CUBE (c1, c2) ORDER BY c1, c2", + "SELECT c1, c2, count(c3) FROM test GROUP BY CUBE (c1, c2) ORDER BY c1, c2", 4, ) .await?; let expected = vec![ "+----+----+----------------+", - "| c1 | c2 | COUNT(test.c3) |", + "| c1 | c2 | count(test.c3) |", "+----+----+----------------+", "| | | 40 |", "| | 1 | 4 |", @@ -222,15 +222,15 @@ async fn run_count_distinct_integers_aggregated_scenario( " SELECT c_group, - COUNT(c_uint64), - COUNT(DISTINCT c_int8), - COUNT(DISTINCT c_int16), - COUNT(DISTINCT c_int32), - COUNT(DISTINCT c_int64), - COUNT(DISTINCT c_uint8), - COUNT(DISTINCT c_uint16), - COUNT(DISTINCT c_uint32), - COUNT(DISTINCT c_uint64) + count(c_uint64), + count(DISTINCT c_int8), + count(DISTINCT c_int16), + count(DISTINCT c_int32), + count(DISTINCT c_int64), + count(DISTINCT c_uint8), + count(DISTINCT c_uint16), + count(DISTINCT c_uint32), + count(DISTINCT c_uint64) FROM test GROUP BY c_group ", @@ -260,7 +260,7 @@ async fn count_distinct_integers_aggregated_single_partition() -> Result<()> { let results = run_count_distinct_integers_aggregated_scenario(partitions).await?; let expected = ["+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+", - "| c_group | COUNT(test.c_uint64) | COUNT(DISTINCT test.c_int8) | COUNT(DISTINCT test.c_int16) | COUNT(DISTINCT test.c_int32) | COUNT(DISTINCT test.c_int64) | COUNT(DISTINCT test.c_uint8) | COUNT(DISTINCT test.c_uint16) | COUNT(DISTINCT test.c_uint32) | COUNT(DISTINCT test.c_uint64) |", + "| c_group | count(test.c_uint64) | count(DISTINCT test.c_int8) | count(DISTINCT test.c_int16) | count(DISTINCT test.c_int32) | count(DISTINCT test.c_int64) | count(DISTINCT test.c_uint8) | count(DISTINCT test.c_uint16) | count(DISTINCT test.c_uint32) | count(DISTINCT test.c_uint64) |", "+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+", "| a | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 2 | 2 |", "| b | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |", @@ -284,7 +284,7 @@ async fn count_distinct_integers_aggregated_multiple_partitions() -> Result<()> let results = run_count_distinct_integers_aggregated_scenario(partitions).await?; let expected = ["+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+", - "| c_group | COUNT(test.c_uint64) | COUNT(DISTINCT test.c_int8) | COUNT(DISTINCT test.c_int16) | COUNT(DISTINCT test.c_int32) | COUNT(DISTINCT test.c_int64) | COUNT(DISTINCT test.c_uint8) | COUNT(DISTINCT test.c_uint16) | COUNT(DISTINCT test.c_uint32) | COUNT(DISTINCT test.c_uint64) |", + "| c_group | count(test.c_uint64) | count(DISTINCT test.c_int8) | count(DISTINCT test.c_int16) | count(DISTINCT test.c_int32) | count(DISTINCT test.c_int64) | count(DISTINCT test.c_uint8) | count(DISTINCT test.c_uint16) | count(DISTINCT test.c_uint32) | count(DISTINCT test.c_uint64) |", "+---------+----------------------+-----------------------------+------------------------------+------------------------------+------------------------------+------------------------------+-------------------------------+-------------------------------+-------------------------------+", "| a | 5 | 3 | 3 | 3 | 3 | 3 | 3 | 3 | 3 |", "| b | 5 | 4 | 4 | 4 | 4 | 4 | 4 | 4 | 4 |", @@ -301,7 +301,7 @@ async fn test_accumulator_row_accumulator() -> Result<()> { let ctx = SessionContext::new_with_config(config); register_aggregate_csv(&ctx).await?; - let sql = "SELECT c1, c2, MIN(c13) as min1, MIN(c9) as min2, MAX(c13) as max1, MAX(c9) as max2, AVG(c9) as avg1, MIN(c13) as min3, COUNT(C9) as cnt1, 0.5*SUM(c9-c8) as sum1 + let sql = "SELECT c1, c2, MIN(c13) as min1, MIN(c9) as min2, MAX(c13) as max1, MAX(c9) as max2, AVG(c9) as avg1, MIN(c13) as min3, count(C9) as cnt1, 0.5*SUM(c9-c8) as sum1 FROM aggregate_test_100 GROUP BY c1, c2 ORDER BY c1, c2 diff --git a/datafusion/core/tests/sql/create_drop.rs b/datafusion/core/tests/sql/create_drop.rs index 2174009b8557..83712053b954 100644 --- a/datafusion/core/tests/sql/create_drop.rs +++ b/datafusion/core/tests/sql/create_drop.rs @@ -15,18 +15,14 @@ // specific language governing permissions and limitations // under the License. -use datafusion::execution::context::SessionState; -use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::execution::session_state::SessionStateBuilder; use datafusion::test_util::TestTableFactory; use super::*; #[tokio::test] async fn create_custom_table() -> Result<()> { - let cfg = RuntimeConfig::new(); - let env = RuntimeEnv::new(cfg).unwrap(); - let ses = SessionConfig::new(); - let mut state = SessionState::new_with_config_rt(ses, Arc::new(env)); + let mut state = SessionStateBuilder::new().with_default_features().build(); state .table_factories_mut() .insert("DELTATABLE".to_string(), Arc::new(TestTableFactory {})); @@ -45,10 +41,7 @@ async fn create_custom_table() -> Result<()> { #[tokio::test] async fn create_external_table_with_ddl() -> Result<()> { - let cfg = RuntimeConfig::new(); - let env = RuntimeEnv::new(cfg).unwrap(); - let ses = SessionConfig::new(); - let mut state = SessionState::new_with_config_rt(ses, Arc::new(env)); + let mut state = SessionStateBuilder::new().with_default_features().build(); state .table_factories_mut() .insert("MOCKTABLE".to_string(), Arc::new(TestTableFactory {})); diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index b5819dc18832..39fd492786bc 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -72,16 +72,11 @@ async fn explain_analyze_baseline_metrics() { assert_metrics!( &formatted, "GlobalLimitExec: skip=0, fetch=3, ", - "metrics=[output_rows=1, elapsed_compute=" - ); - assert_metrics!( - &formatted, - "LocalLimitExec: fetch=3", "metrics=[output_rows=3, elapsed_compute=" ); assert_metrics!( &formatted, - "ProjectionExec: expr=[COUNT(*)", + "ProjectionExec: expr=[count(*)", "metrics=[output_rows=1, elapsed_compute=" ); assert_metrics!( @@ -253,7 +248,7 @@ async fn csv_explain_plans() { // Optimized logical plan let state = ctx.state(); - let msg = format!("Optimizing logical plan for '{sql}': {plan:?}"); + let msg = format!("Optimizing logical plan for '{sql}': {plan}"); let plan = state.optimize(plan).expect(&msg); let optimized_logical_schema = plan.schema(); // Both schema has to be the same @@ -327,7 +322,7 @@ async fn csv_explain_plans() { // Physical plan // Create plan - let msg = format!("Creating physical plan for '{sql}': {plan:?}"); + let msg = format!("Creating physical plan for '{sql}': {plan}"); let plan = state.create_physical_plan(&plan).await.expect(&msg); // // Execute plan @@ -352,7 +347,7 @@ async fn csv_explain_verbose() { // flatten to a single string let actual = actual.into_iter().map(|r| r.join("\t")).collect::(); - // Don't actually test the contents of the debuging output (as + // Don't actually test the contents of the debugging output (as // that may change and keeping this test updated will be a // pain). Instead just check for a few key pieces. assert_contains!(&actual, "logical_plan"); @@ -548,7 +543,7 @@ async fn csv_explain_verbose_plans() { // Physical plan // Create plan - let msg = format!("Creating physical plan for '{sql}': {plan:?}"); + let msg = format!("Creating physical plan for '{sql}': {plan}"); let plan = state.create_physical_plan(&plan).await.expect(&msg); // // Execute plan @@ -612,18 +607,17 @@ async fn test_physical_plan_display_indent() { let dataframe = ctx.sql(sql).await.unwrap(); let physical_plan = dataframe.create_physical_plan().await.unwrap(); let expected = vec![ - "GlobalLimitExec: skip=0, fetch=10", - " SortPreservingMergeExec: [the_min@2 DESC], fetch=10", - " SortExec: TopK(fetch=10), expr=[the_min@2 DESC]", - " ProjectionExec: expr=[c1@0 as c1, MAX(aggregate_test_100.c12)@1 as MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)@2 as the_min]", - " AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)]", - " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([c1@0], 9000), input_partitions=9000", - " AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[MAX(aggregate_test_100.c12), MIN(aggregate_test_100.c12)]", - " CoalesceBatchesExec: target_batch_size=4096", - " FilterExec: c12@1 < 10", - " RepartitionExec: partitioning=RoundRobinBatch(9000), input_partitions=1", - " CsvExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1, c12], has_header=true", + "SortPreservingMergeExec: [the_min@2 DESC], fetch=10", + " SortExec: TopK(fetch=10), expr=[the_min@2 DESC], preserve_partitioning=[true]", + " ProjectionExec: expr=[c1@0 as c1, max(aggregate_test_100.c12)@1 as max(aggregate_test_100.c12), min(aggregate_test_100.c12)@2 as the_min]", + " AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[max(aggregate_test_100.c12), min(aggregate_test_100.c12)]", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([c1@0], 9000), input_partitions=9000", + " AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[max(aggregate_test_100.c12), min(aggregate_test_100.c12)]", + " CoalesceBatchesExec: target_batch_size=4096", + " FilterExec: c12@1 < 10", + " RepartitionExec: partitioning=RoundRobinBatch(9000), input_partitions=1", + " CsvExec: file_groups={1 group: [[ARROW_TEST_DATA/csv/aggregate_test_100.csv]]}, projection=[c1, c12], has_header=true", ]; let normalizer = ExplainNormalizer::new(); @@ -700,7 +694,7 @@ async fn csv_explain_analyze() { // Only test basic plumbing and try to avoid having to change too // many things. explain_analyze_baseline_metrics covers the values // in greater depth - let needle = "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[COUNT(*)], metrics=[output_rows=5"; + let needle = "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[count(*)], metrics=[output_rows=5"; assert_contains!(&formatted, needle); let verbose_needle = "Output Rows"; @@ -721,7 +715,7 @@ async fn csv_explain_analyze_order_by() { // Ensure that the ordering is not optimized away from the plan // https://github.com/apache/datafusion/issues/6379 let needle = - "SortExec: expr=[c1@0 ASC NULLS LAST], metrics=[output_rows=100, elapsed_compute"; + "SortExec: expr=[c1@0 ASC NULLS LAST], preserve_partitioning=[false], metrics=[output_rows=100, elapsed_compute"; assert_contains!(&formatted, needle); } @@ -793,7 +787,7 @@ async fn explain_logical_plan_only() { let expected = vec![ vec![ "logical_plan", - "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]]\ + "Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]]\ \n SubqueryAlias: t\ \n Projection: \ \n Values: (Utf8(\"a\"), Int64(1), Int64(100)), (Utf8(\"a\"), Int64(2), Int64(150))" @@ -812,7 +806,7 @@ async fn explain_physical_plan_only() { let expected = vec![vec![ "physical_plan", - "ProjectionExec: expr=[2 as COUNT(*)]\ + "ProjectionExec: expr=[2 as count(*)]\ \n PlaceholderRowExec\ \n", ]]; diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index f7d5205db0d3..fab92c0f9c2b 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use datafusion::datasource::stream::{StreamConfig, StreamTable}; +use datafusion::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable}; use datafusion::test_util::register_unbounded_file_with_ordering; use super::*; @@ -33,7 +33,7 @@ async fn join_change_in_planner() -> Result<()> { Field::new("a2", DataType::UInt32, false), ])); // Specify the ordering: - let file_sort_order = vec![[datafusion_expr::col("a1")] + let file_sort_order = vec![[col("a1")] .into_iter() .map(|e| { let ascending = true; @@ -101,7 +101,7 @@ async fn join_no_order_on_filter() -> Result<()> { Field::new("a3", DataType::UInt32, false), ])); // Specify the ordering: - let file_sort_order = vec![[datafusion_expr::col("a1")] + let file_sort_order = vec![[col("a1")] .into_iter() .map(|e| { let ascending = true; @@ -166,12 +166,14 @@ async fn join_change_in_planner_without_sort() -> Result<()> { Field::new("a1", DataType::UInt32, false), Field::new("a2", DataType::UInt32, false), ])); - let left = StreamConfig::new_file(schema.clone(), left_file_path); + let left_source = FileStreamProvider::new_file(schema.clone(), left_file_path); + let left = StreamConfig::new(Arc::new(left_source)); ctx.register_table("left", Arc::new(StreamTable::new(Arc::new(left))))?; let right_file_path = tmp_dir.path().join("right.csv"); File::create(right_file_path.clone())?; - let right = StreamConfig::new_file(schema, right_file_path); + let right_source = FileStreamProvider::new_file(schema, right_file_path); + let right = StreamConfig::new(Arc::new(right_source)); ctx.register_table("right", Arc::new(StreamTable::new(Arc::new(right))))?; let sql = "SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10"; let dataframe = ctx.sql(sql).await?; @@ -216,17 +218,19 @@ async fn join_change_in_planner_without_sort_not_allowed() -> Result<()> { Field::new("a1", DataType::UInt32, false), Field::new("a2", DataType::UInt32, false), ])); - let left = StreamConfig::new_file(schema.clone(), left_file_path); + let left_source = FileStreamProvider::new_file(schema.clone(), left_file_path); + let left = StreamConfig::new(Arc::new(left_source)); ctx.register_table("left", Arc::new(StreamTable::new(Arc::new(left))))?; let right_file_path = tmp_dir.path().join("right.csv"); File::create(right_file_path.clone())?; - let right = StreamConfig::new_file(schema.clone(), right_file_path); + let right_source = FileStreamProvider::new_file(schema.clone(), right_file_path); + let right = StreamConfig::new(Arc::new(right_source)); ctx.register_table("right", Arc::new(StreamTable::new(Arc::new(right))))?; let df = ctx.sql("SELECT t1.a1, t1.a2, t2.a1, t2.a2 FROM left as t1 FULL JOIN right as t2 ON t1.a2 = t2.a2 AND t1.a1 > t2.a1 + 3 AND t1.a1 < t2.a1 + 10").await?; match df.create_physical_plan().await { Ok(_) => panic!("Expecting error."), Err(e) => { - assert_eq!(e.strip_backtrace(), "PipelineChecker\ncaused by\nError during planning: Join operation cannot operate on a non-prunable stream without enabling the 'allow_symmetric_joins_without_pruning' configuration flag") + assert_eq!(e.strip_backtrace(), "SanityCheckPlan\ncaused by\nError during planning: Join operation cannot operate on a non-prunable stream without enabling the 'allow_symmetric_joins_without_pruning' configuration flag") } } Ok(()) diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 9b7828a777c8..177427b47d21 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -60,11 +60,12 @@ pub mod aggregates; pub mod create_drop; pub mod explain_analyze; pub mod joins; +mod path_partition; pub mod select; mod sql_api; async fn register_aggregate_csv_by_sql(ctx: &SessionContext) { - let testdata = datafusion::test_util::arrow_test_data(); + let testdata = test_util::arrow_test_data(); let df = ctx .sql(&format!( @@ -85,8 +86,8 @@ async fn register_aggregate_csv_by_sql(ctx: &SessionContext) { c13 VARCHAR NOT NULL ) STORED AS CSV - WITH HEADER ROW LOCATION '{testdata}/csv/aggregate_test_100.csv' + OPTIONS ('format.has_header' 'true') " )) .await @@ -102,7 +103,7 @@ async fn register_aggregate_csv_by_sql(ctx: &SessionContext) { } async fn register_aggregate_csv(ctx: &SessionContext) -> Result<()> { - let testdata = datafusion::test_util::arrow_test_data(); + let testdata = test_util::arrow_test_data(); let schema = test_util::aggr_test_schema(); ctx.register_csv( "aggregate_test_100", @@ -226,7 +227,7 @@ fn result_vec(results: &[RecordBatch]) -> Vec> { } async fn register_alltypes_parquet(ctx: &SessionContext) { - let testdata = datafusion::test_util::parquet_test_data(); + let testdata = test_util::parquet_test_data(); ctx.register_parquet( "alltypes_plain", &format!("{testdata}/alltypes_plain.parquet"), diff --git a/datafusion/core/tests/path_partition.rs b/datafusion/core/tests/sql/path_partition.rs similarity index 86% rename from datafusion/core/tests/path_partition.rs rename to datafusion/core/tests/sql/path_partition.rs index dd8eb52f67c7..919054e8330f 100644 --- a/datafusion/core/tests/path_partition.rs +++ b/datafusion/core/tests/sql/path_partition.rs @@ -25,6 +25,7 @@ use std::sync::Arc; use arrow::datatypes::DataType; use datafusion::datasource::listing::ListingTableUrl; +use datafusion::datasource::physical_plan::ParquetExec; use datafusion::{ assert_batches_sorted_eq, datasource::{ @@ -36,21 +37,71 @@ use datafusion::{ prelude::SessionContext, test_util::{self, arrow_test_data, parquet_test_data}, }; +use datafusion_catalog::TableProvider; use datafusion_common::stats::Precision; use datafusion_common::ScalarValue; +use datafusion_execution::config::SessionConfig; use async_trait::async_trait; use bytes::Bytes; use chrono::{TimeZone, Utc}; -use futures::stream; -use futures::stream::BoxStream; +use datafusion_expr::{col, lit, Expr, Operator}; +use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal}; +use datafusion_physical_expr::PhysicalExpr; +use futures::stream::{self, BoxStream}; use object_store::{ - path::Path, GetOptions, GetResult, GetResultPayload, ListResult, MultipartId, - ObjectMeta, ObjectStore, PutOptions, PutResult, + path::Path, GetOptions, GetResult, GetResultPayload, ListResult, ObjectMeta, + ObjectStore, PutOptions, PutResult, }; -use tokio::io::AsyncWrite; +use object_store::{Attributes, MultipartUpload, PutMultipartOpts, PutPayload}; use url::Url; +#[tokio::test] +async fn parquet_partition_pruning_filter() -> Result<()> { + let ctx = SessionContext::new(); + + let table = create_partitioned_alltypes_parquet_table( + &ctx, + &[ + "year=2021/month=09/day=09/file.parquet", + "year=2021/month=10/day=09/file.parquet", + "year=2021/month=10/day=28/file.parquet", + ], + &[ + ("year", DataType::Int32), + ("month", DataType::Int32), + ("day", DataType::Int32), + ], + "mirror:///", + "alltypes_plain.parquet", + ) + .await; + + // The first three filters can be resolved using only the partition columns. + let filters = [ + Expr::eq(col("year"), lit(2021)), + Expr::eq(col("month"), lit(10)), + Expr::eq(col("day"), lit(28)), + Expr::gt(col("id"), lit(1)), + ]; + let exec = table.scan(&ctx.state(), None, &filters, None).await?; + let parquet_exec = exec.as_any().downcast_ref::().unwrap(); + let pred = parquet_exec.predicate().unwrap(); + // Only the last filter should be pushdown to TableScan + let expected = Arc::new(BinaryExpr::new( + Arc::new(Column::new_with_schema("id", &exec.schema()).unwrap()), + Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))), + )); + + assert!(pred.as_any().is::()); + let pred = pred.as_any().downcast_ref::().unwrap(); + + assert_eq!(pred, expected.as_any()); + + Ok(()) +} + #[tokio::test] async fn parquet_distinct_partition_col() -> Result<()> { let ctx = SessionContext::new(); @@ -120,7 +171,7 @@ async fn parquet_distinct_partition_col() -> Result<()> { //3. limit is not contained within a single partition //The id column is included to ensure that the parquet file is actually scanned. let results = ctx - .sql("SELECT COUNT(*) as num_rows_per_month, month, MAX(id) from t group by month order by num_rows_per_month desc") + .sql("SELECT count(*) as num_rows_per_month, month, MAX(id) from t group by month order by num_rows_per_month desc") .await? .collect() .await?; @@ -202,7 +253,9 @@ fn extract_as_utf(v: &ScalarValue) -> Option { #[tokio::test] async fn csv_filter_with_file_col() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = SessionContext::new_with_config( + SessionConfig::new().set_str("datafusion.catalog.has_header", "true"), + ); register_partitioned_aggregate_csv( &ctx, @@ -238,7 +291,9 @@ async fn csv_filter_with_file_col() -> Result<()> { #[tokio::test] async fn csv_filter_with_file_nonstring_col() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = SessionContext::new_with_config( + SessionConfig::new().set_str("datafusion.catalog.has_header", "true"), + ); register_partitioned_aggregate_csv( &ctx, @@ -274,7 +329,9 @@ async fn csv_filter_with_file_nonstring_col() -> Result<()> { #[tokio::test] async fn csv_projection_on_partition() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = SessionContext::new_with_config( + SessionConfig::new().set_str("datafusion.catalog.has_header", "true"), + ); register_partitioned_aggregate_csv( &ctx, @@ -310,7 +367,9 @@ async fn csv_projection_on_partition() -> Result<()> { #[tokio::test] async fn csv_grouping_by_partition() -> Result<()> { - let ctx = SessionContext::new(); + let ctx = SessionContext::new_with_config( + SessionConfig::new().set_str("datafusion.catalog.has_header", "true"), + ); register_partitioned_aggregate_csv( &ctx, @@ -331,7 +390,7 @@ async fn csv_grouping_by_partition() -> Result<()> { let expected = [ "+------------+----------+----------------------+", - "| date | COUNT(*) | COUNT(DISTINCT t.c1) |", + "| date | count(*) | count(DISTINCT t.c1) |", "+------------+----------+----------------------+", "| 2021-10-26 | 100 | 5 |", "| 2021-10-27 | 100 | 5 |", @@ -483,7 +542,7 @@ async fn parquet_statistics() -> Result<()> { // stats for the first col are read from the parquet file assert_eq!(stat_cols[0].null_count, Precision::Exact(1)); // TODO assert partition column stats once implemented (#1186) - assert_eq!(stat_cols[1], ColumnStatistics::new_unknown(),); + assert_eq!(stat_cols[1], ColumnStatistics::new_unknown()); Ok(()) } @@ -525,7 +584,7 @@ fn register_partitioned_aggregate_csv( let csv_file_path = format!("{testdata}/csv/aggregate_test_100.csv"); let file_schema = test_util::aggr_test_schema(); let url = Url::parse("mirror://").unwrap(); - ctx.runtime_env().register_object_store( + ctx.register_object_store( &url, MirroringObjectStore::new_arc(csv_file_path, store_paths), ); @@ -555,10 +614,29 @@ async fn register_partitioned_alltypes_parquet( table_path: &str, source_file: &str, ) { + let table = create_partitioned_alltypes_parquet_table( + ctx, + store_paths, + partition_cols, + table_path, + source_file, + ) + .await; + ctx.register_table("t", table) + .expect("registering listing table failed"); +} + +async fn create_partitioned_alltypes_parquet_table( + ctx: &SessionContext, + store_paths: &[&str], + partition_cols: &[(&str, DataType)], + table_path: &str, + source_file: &str, +) -> Arc { let testdata = parquet_test_data(); let parquet_file_path = format!("{testdata}/{source_file}"); let url = Url::parse("mirror://").unwrap(); - ctx.runtime_env().register_object_store( + ctx.register_object_store( &url, MirroringObjectStore::new_arc(parquet_file_path.clone(), store_paths), ); @@ -583,11 +661,7 @@ async fn register_partitioned_alltypes_parquet( let config = ListingTableConfig::new(table_path) .with_listing_options(options) .with_schema(file_schema); - - let table = ListingTable::try_new(config).unwrap(); - - ctx.register_table("t", Arc::new(table)) - .expect("registering listing table failed"); + Arc::new(ListingTable::try_new(config).unwrap()) } #[derive(Debug)] @@ -623,24 +697,17 @@ impl ObjectStore for MirroringObjectStore { async fn put_opts( &self, _location: &Path, - _bytes: Bytes, + _put_payload: PutPayload, _opts: PutOptions, ) -> object_store::Result { unimplemented!() } - async fn put_multipart( + async fn put_multipart_opts( &self, _location: &Path, - ) -> object_store::Result<(MultipartId, Box)> { - unimplemented!() - } - - async fn abort_multipart( - &self, - _location: &Path, - _multipart_id: &MultipartId, - ) -> object_store::Result<()> { + _opts: PutMultipartOpts, + ) -> object_store::Result> { unimplemented!() } @@ -665,6 +732,7 @@ impl ObjectStore for MirroringObjectStore { range: 0..meta.size, payload: GetResultPayload::File(file, path), meta, + attributes: Attributes::default(), }) } diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index f2710e659240..dd660512f346 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -246,7 +246,37 @@ async fn test_parameter_invalid_types() -> Result<()> { .await; assert_eq!( results.unwrap_err().strip_backtrace(), - "Arrow error: Invalid argument error: Invalid comparison operation: List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) == List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} })" + "type_coercion\ncaused by\nError during planning: Cannot infer common argument type for comparison operation List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) = Int32" ); Ok(()) } + +#[tokio::test] +async fn test_version_function() { + let expected_version = format!( + "Apache DataFusion {}, {} on {}", + env!("CARGO_PKG_VERSION"), + std::env::consts::ARCH, + std::env::consts::OS, + ); + + let ctx = SessionContext::new(); + let results = ctx + .sql("select version()") + .await + .unwrap() + .collect() + .await + .unwrap(); + + // since width of columns varies between platforms, we can't compare directly + // so we just check that the version string is present + + // expect a single string column with a single row + assert_eq!(results.len(), 1); + assert_eq!(results[0].num_columns(), 1); + let version = results[0].column(0).as_string::(); + assert_eq!(version.len(), 1); + + assert_eq!(version.value(0), expected_version); +} diff --git a/datafusion/core/tests/sql/sql_api.rs b/datafusion/core/tests/sql/sql_api.rs index b3a819fbc331..48f4a66b65dc 100644 --- a/datafusion/core/tests/sql/sql_api.rs +++ b/datafusion/core/tests/sql/sql_api.rs @@ -58,6 +58,19 @@ async fn unsupported_dml_returns_error() { ctx.sql_with_options(sql, options).await.unwrap(); } +#[tokio::test] +async fn dml_output_schema() { + use arrow::datatypes::Schema; + use arrow::datatypes::{DataType, Field}; + + let ctx = SessionContext::new(); + ctx.sql("CREATE TABLE test (x int)").await.unwrap(); + let sql = "INSERT INTO test VALUES (1)"; + let df = ctx.sql(sql).await.unwrap(); + let count_schema = Schema::new(vec![Field::new("count", DataType::UInt64, false)]); + assert_eq!(Schema::from(df.schema()), count_schema); +} + #[tokio::test] async fn unsupported_copy_returns_error() { let tmpdir = TempDir::new().unwrap(); @@ -100,6 +113,40 @@ async fn unsupported_statement_returns_error() { ctx.sql_with_options(sql, options).await.unwrap(); } +#[tokio::test] +async fn empty_statement_returns_error() { + let ctx = SessionContext::new(); + ctx.sql("CREATE TABLE test (x int)").await.unwrap(); + + let state = ctx.state(); + + // Give it an empty string which contains no statements + let plan_res = state.create_logical_plan("").await; + assert_eq!( + plan_res.unwrap_err().strip_backtrace(), + "Error during planning: No SQL statements were provided in the query string" + ); +} + +#[tokio::test] +async fn multiple_statements_returns_error() { + let ctx = SessionContext::new(); + ctx.sql("CREATE TABLE test (x int)").await.unwrap(); + + let state = ctx.state(); + + // Give it a string that contains multiple statements + let plan_res = state + .create_logical_plan( + "INSERT INTO test (x) VALUES (1); INSERT INTO test (x) VALUES (2)", + ) + .await; + assert_eq!( + plan_res.unwrap_err().strip_backtrace(), + "This feature is not implemented: The context currently only supports a single SQL statement" + ); +} + #[tokio::test] async fn ddl_can_not_be_planned_by_session_state() { let ctx = SessionContext::new(); diff --git a/datafusion/core/tests/tpcds_planning.rs b/datafusion/core/tests/tpcds_planning.rs index 44fb0afff319..252d76d0f9d9 100644 --- a/datafusion/core/tests/tpcds_planning.rs +++ b/datafusion/core/tests/tpcds_planning.rs @@ -229,9 +229,6 @@ async fn tpcds_logical_q40() -> Result<()> { } #[tokio::test] -#[ignore] -// Optimizer rule 'scalar_subquery_to_join' failed: Optimizing disjunctions not supported! -// issue: https://github.com/apache/datafusion/issues/5368 async fn tpcds_logical_q41() -> Result<()> { create_logical_plan(41).await } @@ -571,7 +568,6 @@ async fn tpcds_physical_q9() -> Result<()> { create_physical_plan(9).await } -#[ignore] // Physical plan does not support logical expression Exists() #[tokio::test] async fn tpcds_physical_q10() -> Result<()> { create_physical_plan(10).await @@ -697,7 +693,6 @@ async fn tpcds_physical_q34() -> Result<()> { create_physical_plan(34).await } -#[ignore] // Physical plan does not support logical expression Exists() #[tokio::test] async fn tpcds_physical_q35() -> Result<()> { create_physical_plan(35).await @@ -728,8 +723,6 @@ async fn tpcds_physical_q40() -> Result<()> { create_physical_plan(40).await } -#[ignore] -// Context("check_analyzed_plan", Plan("Correlated column is not allowed in predicate: (..) #[tokio::test] async fn tpcds_physical_q41() -> Result<()> { create_physical_plan(41).await @@ -750,7 +743,6 @@ async fn tpcds_physical_q44() -> Result<()> { create_physical_plan(44).await } -#[ignore] // Physical plan does not support logical expression () #[tokio::test] async fn tpcds_physical_q45() -> Result<()> { create_physical_plan(45).await @@ -1044,7 +1036,10 @@ async fn regression_test(query_no: u8, create_physical: bool) -> Result<()> { for table in &tables { ctx.register_table( table.name.as_str(), - Arc::new(MemTable::try_new(Arc::new(table.schema.clone()), vec![])?), + Arc::new(MemTable::try_new( + Arc::new(table.schema.clone()), + vec![vec![]], + )?), )?; } diff --git a/datafusion/core/tests/user_defined/expr_planner.rs b/datafusion/core/tests/user_defined/expr_planner.rs new file mode 100644 index 000000000000..ad9c1280d6b1 --- /dev/null +++ b/datafusion/core/tests/user_defined/expr_planner.rs @@ -0,0 +1,124 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::RecordBatch; +use std::sync::Arc; + +use datafusion::common::{assert_batches_eq, DFSchema}; +use datafusion::error::Result; +use datafusion::execution::FunctionRegistry; +use datafusion::logical_expr::Operator; +use datafusion::prelude::*; +use datafusion::sql::sqlparser::ast::BinaryOperator; +use datafusion_common::ScalarValue; +use datafusion_expr::expr::Alias; +use datafusion_expr::planner::{ExprPlanner, PlannerResult, RawBinaryExpr}; +use datafusion_expr::BinaryExpr; + +#[derive(Debug)] +struct MyCustomPlanner; + +impl ExprPlanner for MyCustomPlanner { + fn plan_binary_op( + &self, + expr: RawBinaryExpr, + _schema: &DFSchema, + ) -> Result> { + match &expr.op { + BinaryOperator::Arrow => { + Ok(PlannerResult::Planned(Expr::BinaryExpr(BinaryExpr { + left: Box::new(expr.left.clone()), + right: Box::new(expr.right.clone()), + op: Operator::StringConcat, + }))) + } + BinaryOperator::LongArrow => { + Ok(PlannerResult::Planned(Expr::BinaryExpr(BinaryExpr { + left: Box::new(expr.left.clone()), + right: Box::new(expr.right.clone()), + op: Operator::Plus, + }))) + } + BinaryOperator::Question => { + Ok(PlannerResult::Planned(Expr::Alias(Alias::new( + Expr::Literal(ScalarValue::Boolean(Some(true))), + None::<&str>, + format!("{} ? {}", expr.left, expr.right), + )))) + } + _ => Ok(PlannerResult::Original(expr)), + } + } +} + +async fn plan_and_collect(sql: &str) -> Result> { + let config = + SessionConfig::new().set_str("datafusion.sql_parser.dialect", "postgres"); + let mut ctx = SessionContext::new_with_config(config); + ctx.register_expr_planner(Arc::new(MyCustomPlanner))?; + ctx.sql(sql).await?.collect().await +} + +#[tokio::test] +async fn test_custom_operators_arrow() { + let actual = plan_and_collect("select 'foo'->'bar';").await.unwrap(); + let expected = [ + "+----------------------------+", + "| Utf8(\"foo\") || Utf8(\"bar\") |", + "+----------------------------+", + "| foobar |", + "+----------------------------+", + ]; + assert_batches_eq!(&expected, &actual); +} + +#[tokio::test] +async fn test_custom_operators_long_arrow() { + let actual = plan_and_collect("select 1->>2;").await.unwrap(); + let expected = [ + "+---------------------+", + "| Int64(1) + Int64(2) |", + "+---------------------+", + "| 3 |", + "+---------------------+", + ]; + assert_batches_eq!(&expected, &actual); +} + +#[tokio::test] +async fn test_question_select() { + let actual = plan_and_collect("select a ? 2 from (select 1 as a);") + .await + .unwrap(); + let expected = [ + "+--------------+", + "| a ? Int64(2) |", + "+--------------+", + "| true |", + "+--------------+", + ]; + assert_batches_eq!(&expected, &actual); +} + +#[tokio::test] +async fn test_question_filter() { + let actual = plan_and_collect("select a from (select 1 as a) where a ? 2;") + .await + .unwrap(); + let expected = ["+---+", "| a |", "+---+", "| 1 |", "+---+"]; + assert_batches_eq!(&expected, &actual); +} diff --git a/datafusion/core/tests/user_defined/insert_operation.rs b/datafusion/core/tests/user_defined/insert_operation.rs new file mode 100644 index 000000000000..ff14fa0be3fb --- /dev/null +++ b/datafusion/core/tests/user_defined/insert_operation.rs @@ -0,0 +1,188 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{any::Any, sync::Arc}; + +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use async_trait::async_trait; +use datafusion::{ + error::Result, + prelude::{SessionConfig, SessionContext}, +}; +use datafusion_catalog::{Session, TableProvider}; +use datafusion_expr::{dml::InsertOp, Expr, TableType}; +use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; +use datafusion_physical_plan::{DisplayAs, ExecutionMode, ExecutionPlan, PlanProperties}; + +#[tokio::test] +async fn insert_operation_is_passed_correctly_to_table_provider() { + // Use the SQLite syntax so we can test the "INSERT OR REPLACE INTO" syntax + let ctx = session_ctx_with_dialect("SQLite"); + let table_provider = Arc::new(TestInsertTableProvider::new()); + ctx.register_table("testing", table_provider.clone()) + .unwrap(); + + let sql = "INSERT INTO testing (column) VALUES (1)"; + assert_insert_op(&ctx, sql, InsertOp::Append).await; + + let sql = "INSERT OVERWRITE testing (column) VALUES (1)"; + assert_insert_op(&ctx, sql, InsertOp::Overwrite).await; + + let sql = "REPLACE INTO testing (column) VALUES (1)"; + assert_insert_op(&ctx, sql, InsertOp::Replace).await; + + let sql = "INSERT OR REPLACE INTO testing (column) VALUES (1)"; + assert_insert_op(&ctx, sql, InsertOp::Replace).await; +} + +async fn assert_insert_op(ctx: &SessionContext, sql: &str, insert_op: InsertOp) { + let df = ctx.sql(sql).await.unwrap(); + let plan = df.create_physical_plan().await.unwrap(); + let exec = plan.as_any().downcast_ref::().unwrap(); + assert_eq!(exec.op, insert_op); +} + +fn session_ctx_with_dialect(dialect: impl Into) -> SessionContext { + let mut config = SessionConfig::new(); + let options = config.options_mut(); + options.sql_parser.dialect = dialect.into(); + SessionContext::new_with_config(config) +} + +#[derive(Debug)] +struct TestInsertTableProvider { + schema: SchemaRef, +} + +impl TestInsertTableProvider { + fn new() -> Self { + Self { + schema: SchemaRef::new(Schema::new(vec![Field::new( + "column", + DataType::Int64, + false, + )])), + } + } +} + +#[async_trait] +impl TableProvider for TestInsertTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + _projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + unimplemented!("TestInsertTableProvider is a stub for testing.") + } + + async fn insert_into( + &self, + _state: &dyn Session, + _input: Arc, + insert_op: InsertOp, + ) -> Result> { + Ok(Arc::new(TestInsertExec::new(insert_op))) + } +} + +#[derive(Debug)] +struct TestInsertExec { + op: InsertOp, + plan_properties: PlanProperties, +} + +impl TestInsertExec { + fn new(op: InsertOp) -> Self { + let eq_properties = EquivalenceProperties::new(make_count_schema()); + let plan_properties = PlanProperties::new( + eq_properties, + Partitioning::UnknownPartitioning(1), + ExecutionMode::Bounded, + ); + Self { + op, + plan_properties, + } + } +} + +impl DisplayAs for TestInsertExec { + fn fmt_as( + &self, + _t: datafusion_physical_plan::DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + write!(f, "TestInsertExec") + } +} + +impl ExecutionPlan for TestInsertExec { + fn name(&self) -> &str { + "TestInsertExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.plan_properties + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + assert!(children.is_empty()); + Ok(self) + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + unimplemented!("TestInsertExec is a stub for testing.") + } +} + +fn make_count_schema() -> SchemaRef { + Arc::new(Schema::new(vec![Field::new( + "count", + DataType::UInt64, + false, + )])) +} diff --git a/datafusion/core/tests/user_defined/mod.rs b/datafusion/core/tests/user_defined/mod.rs index 6c6d966cc3aa..5d84cdb69283 100644 --- a/datafusion/core/tests/user_defined/mod.rs +++ b/datafusion/core/tests/user_defined/mod.rs @@ -29,3 +29,9 @@ mod user_defined_window_functions; /// Tests for User Defined Table Functions mod user_defined_table_functions; + +/// Tests for Expression Planner +mod expr_planner; + +/// Tests for insert operations +mod insert_operation; diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 8f02fb30b013..497addd23094 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -18,14 +18,19 @@ //! This module contains end to end demonstrations of creating //! user defined aggregate functions -use arrow::{array::AsArray, datatypes::Fields}; -use arrow_array::{types::UInt64Type, Int32Array, PrimitiveArray, StructArray}; -use arrow_schema::Schema; +use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::{ atomic::{AtomicBool, Ordering}, Arc, }; +use arrow::{array::AsArray, datatypes::Fields}; +use arrow_array::{ + types::UInt64Type, Int32Array, PrimitiveArray, StringArray, StructArray, +}; +use arrow_schema::Schema; + +use datafusion::dataframe::DataFrame; use datafusion::datasource::MemTable; use datafusion::test_util::plan_and_collect; use datafusion::{ @@ -45,10 +50,10 @@ use datafusion::{ }; use datafusion_common::{assert_contains, cast::as_primitive_array, exec_err}; use datafusion_expr::{ - create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator, - SimpleAggregateUDF, + col, create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator, + LogicalPlanBuilder, SimpleAggregateUDF, }; -use datafusion_physical_expr::expressions::AvgAccumulator; +use datafusion_functions_aggregate::average::AvgAccumulator; /// Test to show the contents of the setup #[tokio::test] @@ -142,7 +147,7 @@ async fn test_udaf_as_window_with_frame_without_retract_batch() { let sql = "SELECT time_sum(time) OVER(ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as time_sum from t"; // Note if this query ever does start working let err = execute(&ctx, sql).await.unwrap_err(); - assert_contains!(err.to_string(), "This feature is not implemented: Aggregate can not be used as a sliding accumulator because `retract_batch` is not implemented: AggregateUDF { inner: AggregateUDF { name: \"time_sum\", signature: Signature { type_signature: Exact([Timestamp(Nanosecond, None)]), volatility: Immutable }, fun: \"\" } }(t.time) ORDER BY [t.time ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING"); + assert_contains!(err.to_string(), "This feature is not implemented: Aggregate can not be used as a sliding accumulator because `retract_batch` is not implemented: time_sum(t.time) ORDER BY [t.time ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING"); } /// Basic query for with a udaf returning a structure @@ -186,7 +191,7 @@ async fn test_udaf_shadows_builtin_fn() { // compute with builtin `sum` aggregator let expected = [ "+---------------------------------------+", - "| SUM(arrow_cast(t.time,Utf8(\"Int64\"))) |", + "| sum(arrow_cast(t.time,Utf8(\"Int64\"))) |", "+---------------------------------------+", "| 19000 |", "+---------------------------------------+", @@ -267,7 +272,7 @@ async fn deregister_udaf() -> Result<()> { Arc::new(vec![DataType::UInt64, DataType::Float64]), ); - ctx.register_udaf(my_avg.clone()); + ctx.register_udaf(my_avg); assert!(ctx.state().aggregate_functions().contains_key("my_avg")); @@ -377,6 +382,55 @@ async fn test_groups_accumulator() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_parameterized_aggregate_udf() -> Result<()> { + let batch = RecordBatch::try_from_iter([( + "text", + Arc::new(StringArray::from(vec!["foo"])) as ArrayRef, + )])?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let t = ctx.table("t").await?; + let signature = Signature::exact(vec![DataType::Utf8], Volatility::Immutable); + let udf1 = AggregateUDF::from(TestGroupsAccumulator { + signature: signature.clone(), + result: 1, + }); + let udf2 = AggregateUDF::from(TestGroupsAccumulator { + signature: signature.clone(), + result: 2, + }); + + let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?) + .aggregate( + [col("text")], + [ + udf1.call(vec![col("text")]).alias("a"), + udf2.call(vec![col("text")]).alias("b"), + ], + )? + .build()?; + + assert_eq!( + format!("{plan}"), + "Aggregate: groupBy=[[t.text]], aggr=[[geo_mean(t.text) AS a, geo_mean(t.text) AS b]]\n TableScan: t projection=[text]" + ); + + let actual = DataFrame::new(ctx.state(), plan).collect().await?; + let expected = [ + "+------+---+---+", + "| text | a | b |", + "+------+---+---+", + "| foo | 1 | 2 |", + "+------+---+---+", + ]; + assert_batches_eq!(expected, &actual); + + ctx.deregister_table("t")?; + Ok(()) +} + /// Returns an context with a table "t" and the "first" and "time_sum" /// aggregate functions registered. /// @@ -693,7 +747,7 @@ impl Accumulator for FirstSelector { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } @@ -725,13 +779,31 @@ impl AggregateUDFImpl for TestGroupsAccumulator { panic!("accumulator shouldn't invoke"); } - fn groups_accumulator_supported(&self) -> bool { + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { true } - fn create_groups_accumulator(&self) -> Result> { + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { Ok(Box::new(self.clone())) } + + fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + self.result == other.result && self.signature == other.signature + } else { + false + } + } + + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.signature.hash(hasher); + self.result.hash(hasher); + hasher.finish() + } } impl Accumulator for TestGroupsAccumulator { @@ -744,7 +816,7 @@ impl Accumulator for TestGroupsAccumulator { } fn size(&self) -> usize { - std::mem::size_of::() + size_of::() } fn state(&mut self) -> Result> { @@ -792,6 +864,6 @@ impl GroupsAccumulator for TestGroupsAccumulator { } fn size(&self) -> usize { - std::mem::size_of::() + size_of::() } } diff --git a/datafusion/core/tests/user_defined/user_defined_plan.rs b/datafusion/core/tests/user_defined/user_defined_plan.rs index 2c12e108bb47..c96256784402 100644 --- a/datafusion/core/tests/user_defined/user_defined_plan.rs +++ b/datafusion/core/tests/user_defined/user_defined_plan.rs @@ -68,6 +68,10 @@ use arrow::{ record_batch::RecordBatch, util::pretty::pretty_format_batches, }; +use async_trait::async_trait; +use futures::{Stream, StreamExt}; + +use datafusion::execution::session_state::SessionStateBuilder; use datafusion::{ common::cast::{as_int64_array, as_string_array}, common::{arrow_datafusion_err, internal_err, DFSchemaRef}, @@ -77,10 +81,10 @@ use datafusion::{ runtime_env::RuntimeEnv, }, logical_expr::{ - Expr, Extension, Limit, LogicalPlan, Sort, UserDefinedLogicalNode, + Expr, Extension, LogicalPlan, Sort, UserDefinedLogicalNode, UserDefinedLogicalNodeCore, }, - optimizer::{optimize_children, OptimizerConfig, OptimizerRule}, + optimizer::{OptimizerConfig, OptimizerRule}, physical_expr::EquivalenceProperties, physical_plan::{ DisplayAs, DisplayFormatType, Distribution, ExecutionMode, ExecutionPlan, @@ -90,13 +94,17 @@ use datafusion::{ physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner}, prelude::{SessionConfig, SessionContext}, }; - -use async_trait::async_trait; -use futures::{Stream, StreamExt}; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::ScalarValue; +use datafusion_expr::tree_node::replace_sort_expression; +use datafusion_expr::{FetchType, Projection, SortExpr}; +use datafusion_optimizer::optimizer::ApplyOrder; +use datafusion_optimizer::AnalyzerRule; /// Execute the specified sql and return the resulting record batches /// pretty printed as a String. -async fn exec_sql(ctx: &mut SessionContext, sql: &str) -> Result { +async fn exec_sql(ctx: &SessionContext, sql: &str) -> Result { let df = ctx.sql(sql).await?; let batches = df.collect().await?; pretty_format_batches(&batches) @@ -105,39 +113,48 @@ async fn exec_sql(ctx: &mut SessionContext, sql: &str) -> Result { } /// Create a test table. -async fn setup_table(mut ctx: SessionContext) -> Result { - let sql = "CREATE EXTERNAL TABLE sales(customer_id VARCHAR, revenue BIGINT) STORED AS CSV location 'tests/data/customer.csv'"; +async fn setup_table(ctx: SessionContext) -> Result { + let sql = " + CREATE EXTERNAL TABLE sales(customer_id VARCHAR, revenue BIGINT) + STORED AS CSV location 'tests/data/customer.csv' + OPTIONS('format.has_header' 'false') + "; let expected = vec!["++", "++"]; - let s = exec_sql(&mut ctx, sql).await?; + let s = exec_sql(&ctx, sql).await?; let actual = s.lines().collect::>(); assert_eq!(expected, actual, "Creating table"); Ok(ctx) } -async fn setup_table_without_schemas(mut ctx: SessionContext) -> Result { - let sql = - "CREATE EXTERNAL TABLE sales STORED AS CSV location 'tests/data/customer.csv'"; +async fn setup_table_without_schemas(ctx: SessionContext) -> Result { + let sql = " + CREATE EXTERNAL TABLE sales + STORED AS CSV location 'tests/data/customer.csv' + OPTIONS('format.has_header' 'false') + "; let expected = vec!["++", "++"]; - let s = exec_sql(&mut ctx, sql).await?; + let s = exec_sql(&ctx, sql).await?; let actual = s.lines().collect::>(); assert_eq!(expected, actual, "Creating table"); Ok(ctx) } -const QUERY1: &str = "SELECT * FROM sales limit 3"; - const QUERY: &str = "SELECT customer_id, revenue FROM sales ORDER BY revenue DESC limit 3"; +const QUERY1: &str = "SELECT * FROM sales limit 3"; + +const QUERY2: &str = "SELECT 42, arrow_typeof(42)"; + // Run the query using the specified execution context and compare it // to the known result -async fn run_and_compare_query(mut ctx: SessionContext, description: &str) -> Result<()> { +async fn run_and_compare_query(ctx: SessionContext, description: &str) -> Result<()> { let expected = vec![ "+-------------+---------+", "| customer_id | revenue |", @@ -148,7 +165,35 @@ async fn run_and_compare_query(mut ctx: SessionContext, description: &str) -> Re "+-------------+---------+", ]; - let s = exec_sql(&mut ctx, QUERY).await?; + let s = exec_sql(&ctx, QUERY).await?; + let actual = s.lines().collect::>(); + + assert_eq!( + expected, + actual, + "output mismatch for {}. Expectedn\n{}Actual:\n{}", + description, + expected.join("\n"), + s + ); + Ok(()) +} + +// Run the query using the specified execution context and compare it +// to the known result +async fn run_and_compare_query_with_analyzer_rule( + ctx: SessionContext, + description: &str, +) -> Result<()> { + let expected = vec![ + "+------------+--------------------------+", + "| UInt64(42) | arrow_typeof(UInt64(42)) |", + "+------------+--------------------------+", + "| 42 | UInt64 |", + "+------------+--------------------------+", + ]; + + let s = exec_sql(&ctx, QUERY2).await?; let actual = s.lines().collect::>(); assert_eq!( @@ -165,7 +210,7 @@ async fn run_and_compare_query(mut ctx: SessionContext, description: &str) -> Re // Run the query using the specified execution context and compare it // to the known result async fn run_and_compare_query_with_auto_schemas( - mut ctx: SessionContext, + ctx: SessionContext, description: &str, ) -> Result<()> { let expected = vec![ @@ -178,7 +223,7 @@ async fn run_and_compare_query_with_auto_schemas( "+----------+----------+", ]; - let s = exec_sql(&mut ctx, QUERY1).await?; + let s = exec_sql(&ctx, QUERY1).await?; let actual = s.lines().collect::>(); assert_eq!( @@ -206,6 +251,14 @@ async fn normal_query() -> Result<()> { run_and_compare_query(ctx, "Default context").await } +#[tokio::test] +// Run the query using default planners, optimizer and custom analyzer rule +async fn normal_query_with_analyzer() -> Result<()> { + let ctx = SessionContext::new(); + ctx.add_analyzer_rule(Arc::new(MyAnalyzerRule {})); + run_and_compare_query_with_analyzer_rule(ctx, "MyAnalyzerRule").await +} + #[tokio::test] // Run the query using topk optimization async fn topk_query() -> Result<()> { @@ -217,13 +270,13 @@ async fn topk_query() -> Result<()> { #[tokio::test] // Run EXPLAIN PLAN and show the plan was in fact rewritten async fn topk_plan() -> Result<()> { - let mut ctx = setup_table(make_topk_context()).await?; + let ctx = setup_table(make_topk_context()).await?; let mut expected = ["| logical_plan after topk | TopK: k=3 |", "| | TableScan: sales projection=[customer_id,revenue] |"].join("\n"); let explain_query = format!("EXPLAIN VERBOSE {QUERY}"); - let actual_output = exec_sql(&mut ctx, &explain_query).await?; + let actual_output = exec_sql(&ctx, &explain_query).await?; // normalize newlines (output on windows uses \r\n) let mut actual_output = actual_output.replace("\r\n", "\n"); @@ -246,14 +299,20 @@ async fn topk_plan() -> Result<()> { fn make_topk_context() -> SessionContext { let config = SessionConfig::new().with_target_partitions(48); let runtime = Arc::new(RuntimeEnv::default()); - let state = SessionState::new_with_config_rt(config, runtime) + let state = SessionStateBuilder::new() + .with_config(config) + .with_runtime_env(runtime) + .with_default_features() .with_query_planner(Arc::new(TopKQueryPlanner {})) - .add_optimizer_rule(Arc::new(TopKOptimizerRule {})); + .with_optimizer_rule(Arc::new(TopKOptimizerRule {})) + .with_analyzer_rule(Arc::new(MyAnalyzerRule {})) + .build(); SessionContext::new_with_state(state) } // ------ The implementation of the TopK code follows ----- +#[derive(Debug)] struct TopKQueryPlanner {} #[async_trait] @@ -277,61 +336,67 @@ impl QueryPlanner for TopKQueryPlanner { } } +#[derive(Default, Debug)] struct TopKOptimizerRule {} + impl OptimizerRule for TopKOptimizerRule { + fn name(&self) -> &str { + "topk" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } + + fn supports_rewrite(&self) -> bool { + true + } + // Example rewrite pass to insert a user defined LogicalPlanNode - fn try_optimize( + fn rewrite( &self, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, - ) -> Result> { + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result, DataFusionError> { // Note: this code simply looks for the pattern of a Limit followed by a // Sort and replaces it by a TopK node. It does not handle many // edge cases (e.g multiple sort columns, sort ASC / DESC), etc. - if let LogicalPlan::Limit(Limit { - fetch: Some(fetch), - input, + let LogicalPlan::Limit(ref limit) = plan else { + return Ok(Transformed::no(plan)); + }; + let FetchType::Literal(Some(fetch)) = limit.get_fetch_type()? else { + return Ok(Transformed::no(plan)); + }; + + if let LogicalPlan::Sort(Sort { + ref expr, + ref input, .. - }) = plan + }) = limit.input.as_ref() { - if let LogicalPlan::Sort(Sort { - ref expr, - ref input, - .. - }) = **input - { - if expr.len() == 1 { - // we found a sort with a single sort expr, replace with a a TopK - return Ok(Some(LogicalPlan::Extension(Extension { - node: Arc::new(TopKPlanNode { - k: *fetch, - input: self - .try_optimize(input.as_ref(), config)? - .unwrap_or_else(|| input.as_ref().clone()), - expr: expr[0].clone(), - }), - }))); - } + if expr.len() == 1 { + // we found a sort with a single sort expr, replace with a a TopK + return Ok(Transformed::yes(LogicalPlan::Extension(Extension { + node: Arc::new(TopKPlanNode { + k: fetch, + input: input.as_ref().clone(), + expr: expr[0].clone(), + }), + }))); } } - // If we didn't find the Limit/Sort combination, recurse as - // normal and build the result. - optimize_children(self, plan, config) - } - - fn name(&self) -> &str { - "topk" + Ok(Transformed::no(plan)) } } -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, PartialOrd, Hash)] struct TopKPlanNode { k: usize, input: LogicalPlan, /// The sort expression (this example only supports a single sort /// expr) - expr: Expr, + expr: SortExpr, } impl Debug for TopKPlanNode { @@ -357,7 +422,7 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode { } fn expressions(&self) -> Vec { - vec![self.expr.clone()] + vec![self.expr.expr.clone()] } /// For example: `TopK: k=10` @@ -365,14 +430,22 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode { write!(f, "TopK: k={}", self.k) } - fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self { + fn with_exprs_and_inputs( + &self, + mut exprs: Vec, + mut inputs: Vec, + ) -> Result { assert_eq!(inputs.len(), 1, "input size inconsistent"); assert_eq!(exprs.len(), 1, "expression size inconsistent"); - Self { + Ok(Self { k: self.k, - input: inputs[0].clone(), - expr: exprs[0].clone(), - } + input: inputs.swap_remove(0), + expr: replace_sort_expression(self.expr.clone(), exprs.swap_remove(0)), + }) + } + + fn supports_limit_pushdown(&self) -> bool { + false // Disallow limit push-down by default } } @@ -440,11 +513,7 @@ impl Debug for TopKExec { } impl DisplayAs for TopKExec { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { write!(f, "TopKExec: k={}", self.k) @@ -455,6 +524,10 @@ impl DisplayAs for TopKExec { #[async_trait] impl ExecutionPlan for TopKExec { + fn name(&self) -> &'static str { + Self::static_name() + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self @@ -468,8 +541,8 @@ impl ExecutionPlan for TopKExec { vec![Distribution::SinglePartition] } - fn children(&self) -> Vec> { - vec![self.input.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.input] } fn with_new_children( @@ -615,3 +688,53 @@ impl RecordBatchStream for TopKReader { self.input.schema() } } + +#[derive(Default, Debug)] +struct MyAnalyzerRule {} + +impl AnalyzerRule for MyAnalyzerRule { + fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result { + Self::analyze_plan(plan) + } + + fn name(&self) -> &str { + "my_analyzer_rule" + } +} + +impl MyAnalyzerRule { + fn analyze_plan(plan: LogicalPlan) -> Result { + plan.transform(|plan| { + Ok(match plan { + LogicalPlan::Projection(projection) => { + let expr = Self::analyze_expr(projection.expr.clone())?; + Transformed::yes(LogicalPlan::Projection(Projection::try_new( + expr, + projection.input, + )?)) + } + _ => Transformed::no(plan), + }) + }) + .data() + } + + fn analyze_expr(expr: Vec) -> Result> { + expr.into_iter() + .map(|e| { + e.transform(|e| { + Ok(match e { + Expr::Literal(ScalarValue::Int64(i)) => { + // transform to UInt64 + Transformed::yes(Expr::Literal(ScalarValue::UInt64( + i.map(|i| i as u64), + ))) + } + _ => Transformed::no(e), + }) + }) + .data() + }) + .collect() + } +} diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index c40573a8df80..f1b172862399 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -15,31 +15,39 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; +use std::collections::HashMap; +use std::hash::{DefaultHasher, Hash, Hasher}; +use std::sync::Arc; + +use arrow::array::as_string_array; use arrow::compute::kernels::numeric::add; +use arrow_array::builder::BooleanBuilder; +use arrow_array::cast::AsArray; use arrow_array::{ - Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, UInt8Array, + Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, StringArray, }; -use arrow_schema::DataType::Float64; use arrow_schema::{DataType, Field, Schema}; +use parking_lot::Mutex; +use regex::Regex; +use sqlparser::ast::Ident; + use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState}; use datafusion::prelude::*; use datafusion::{execution::registry::FunctionRegistry, test_util}; +use datafusion_common::cast::{as_float64_array, as_int32_array}; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{ - assert_batches_eq, assert_batches_sorted_eq, cast::as_float64_array, - cast::as_int32_array, not_impl_err, plan_err, ExprSchema, Result, ScalarValue, + assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, internal_err, + not_impl_err, plan_err, DFSchema, DataFusionError, ExprSchema, Result, ScalarValue, }; -use datafusion_common::{exec_err, internal_err, DataFusionError}; -use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - Accumulator, ColumnarValue, CreateFunction, ExprSchemable, LogicalPlanBuilder, - ScalarUDF, ScalarUDFImpl, Signature, Volatility, + Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, ExprSchemable, + LogicalPlanBuilder, OperateFunctionArg, ScalarUDF, ScalarUDFImpl, Signature, + Volatility, }; -use rand::{thread_rng, Rng}; -use std::any::Any; -use std::iter; -use std::sync::Arc; +use datafusion_functions_nested::range::range_udf; /// test that casting happens on udfs. /// c11 is f32, but `custom_sqrt` requires f64. Casting happens but the logical plan and @@ -52,7 +60,7 @@ async fn csv_query_custom_udf_with_cast() -> Result<()> { let actual = plan_and_collect(&ctx, sql).await.unwrap(); let expected = [ "+------------------------------------------+", - "| AVG(custom_sqrt(aggregate_test_100.c11)) |", + "| avg(custom_sqrt(aggregate_test_100.c11)) |", "+------------------------------------------+", "| 0.6584408483418835 |", "+------------------------------------------+", @@ -70,7 +78,7 @@ async fn csv_query_avg_sqrt() -> Result<()> { let actual = plan_and_collect(&ctx, sql).await.unwrap(); let expected = [ "+------------------------------------------+", - "| AVG(custom_sqrt(aggregate_test_100.c12)) |", + "| avg(custom_sqrt(aggregate_test_100.c12)) |", "+------------------------------------------+", "| 0.6706002946036459 |", "+------------------------------------------+", @@ -114,7 +122,7 @@ async fn scalar_udf() -> Result<()> { ctx.register_udf(create_udf( "my_add", vec![DataType::Int32, DataType::Int32], - Arc::new(DataType::Int32), + DataType::Int32, Volatility::Immutable, myfunc, )); @@ -133,7 +141,7 @@ async fn scalar_udf() -> Result<()> { .build()?; assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "Projection: t.a, t.b, my_add(t.a, t.b)\n TableScan: t projection=[a, b]" ); @@ -168,31 +176,105 @@ async fn scalar_udf() -> Result<()> { Ok(()) } +struct Simple0ArgsScalarUDF { + name: String, + signature: Signature, + return_type: DataType, +} + +impl std::fmt::Debug for Simple0ArgsScalarUDF { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("ScalarUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("fun", &"") + .finish() + } +} + +impl ScalarUDFImpl for Simple0ArgsScalarUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(self.return_type.clone()) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + not_impl_err!("{} function does not accept arguments", self.name()) + } + + fn invoke_no_args(&self, _number_rows: usize) -> Result { + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(100)))) + } +} + #[tokio::test] -async fn scalar_udf_zero_params() -> Result<()> { +async fn test_row_mismatch_error_in_scalar_udf() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let batch = RecordBatch::try_new( Arc::new(schema.clone()), - vec![Arc::new(Int32Array::from(vec![1, 10, 10, 100]))], + vec![Arc::new(Int32Array::from(vec![1, 2]))], )?; + let ctx = SessionContext::new(); ctx.register_batch("t", batch)?; - // create function just returns 100 regardless of inp - let myfunc = Arc::new(|_args: &[ColumnarValue]| { - Ok(ColumnarValue::Array( - Arc::new((0..1).map(|_| 100).collect::()) as ArrayRef, - )) + + // udf that always return 1 row + let buggy_udf = Arc::new(|_: &[ColumnarValue]| { + Ok(ColumnarValue::Array(Arc::new(Int32Array::from(vec![0])))) }); ctx.register_udf(create_udf( - "get_100", - vec![], - Arc::new(DataType::Int32), + "buggy_func", + vec![DataType::Int32], + DataType::Int32, Volatility::Immutable, - myfunc, + buggy_udf, )); + assert_contains!( + ctx.sql("select buggy_func(a) from t") + .await? + .show() + .await + .err() + .unwrap() + .to_string(), + "UDF returned a different number of rows than expected" + ); + Ok(()) +} + +#[tokio::test] +async fn scalar_udf_zero_params() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(Int32Array::from(vec![1, 10, 10, 100]))], + )?; + let ctx = SessionContext::new(); + + ctx.register_batch("t", batch)?; + + let get_100_udf = Simple0ArgsScalarUDF { + name: "get_100".to_string(), + signature: Signature::exact(vec![], Volatility::Immutable), + return_type: DataType::Int32, + }; + + ctx.register_udf(ScalarUDF::from(get_100_udf)); let result = plan_and_collect(&ctx, "select get_100() a from t").await?; let expected = [ @@ -241,7 +323,7 @@ async fn scalar_udf_override_built_in_scalar_function() -> Result<()> { ctx.register_udf(create_udf( "abs", vec![DataType::Int32], - Arc::new(DataType::Int32), + DataType::Int32, Volatility::Immutable, Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(1))))), )); @@ -308,12 +390,12 @@ async fn udaf_as_window_func() -> Result<()> { context.register_udaf(my_acc); let sql = "SELECT a, MY_ACC(b) OVER(PARTITION BY a) FROM my_table"; - let expected = r#"Projection: my_table.a, AggregateUDF { inner: AggregateUDF { name: "my_acc", signature: Signature { type_signature: Exact([Int32]), volatility: Immutable }, fun: "" } }(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING - WindowAggr: windowExpr=[[AggregateUDF { inner: AggregateUDF { name: "my_acc", signature: Signature { type_signature: Exact([Int32]), volatility: Immutable }, fun: "" } }(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] + let expected = r#"Projection: my_table.a, my_acc(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + WindowAggr: windowExpr=[[my_acc(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] TableScan: my_table"#; let dataframe = context.sql(sql).await.unwrap(); - assert_eq!(format!("{:?}", dataframe.logical_plan()), expected); + assert_eq!(format!("{}", dataframe.logical_plan()), expected); Ok(()) } @@ -334,7 +416,7 @@ async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> { ctx.register_udf(create_udf( "MY_FUNC", vec![DataType::Int32], - Arc::new(DataType::Int32), + DataType::Int32, Volatility::Immutable, myfunc, )); @@ -379,7 +461,7 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { let udf = create_udf( "dummy", vec![DataType::Int32], - Arc::new(DataType::Int32), + DataType::Int32, Volatility::Immutable, myfunc, ) @@ -403,26 +485,31 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { Ok(()) } +/// Volatile UDF that should append a different value to each row #[derive(Debug)] -pub struct RandomUDF { +struct AddIndexToStringVolatileScalarUDF { + name: String, signature: Signature, + return_type: DataType, } -impl RandomUDF { - pub fn new() -> Self { +impl AddIndexToStringVolatileScalarUDF { + fn new() -> Self { Self { - signature: Signature::any(0, Volatility::Volatile), + name: "add_index_to_string".to_string(), + signature: Signature::exact(vec![DataType::Utf8], Volatility::Volatile), + return_type: DataType::Utf8, } } } -impl ScalarUDFImpl for RandomUDF { - fn as_any(&self) -> &dyn std::any::Any { +impl ScalarUDFImpl for AddIndexToStringVolatileScalarUDF { + fn as_any(&self) -> &dyn Any { self } fn name(&self) -> &str { - "random_udf" + &self.name } fn signature(&self) -> &Signature { @@ -430,93 +517,150 @@ impl ScalarUDFImpl for RandomUDF { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(Float64) + Ok(self.return_type.clone()) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let len: usize = match &args[0] { - // This udf is always invoked with zero argument so its argument - // is a null array indicating the batch size. - ColumnarValue::Array(array) if array.data_type().is_null() => array.len(), - _ => { - return Err(datafusion::error::DataFusionError::Internal( - "Invalid argument type".to_string(), - )) + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + not_impl_err!("index_with_offset function does not accept arguments") + } + + fn invoke_batch( + &self, + args: &[ColumnarValue], + number_rows: usize, + ) -> Result { + let answer = match &args[0] { + // When called with static arguments, the result is returned as an array. + ColumnarValue::Scalar(ScalarValue::Utf8(Some(value))) => { + let mut answer = vec![]; + for index in 1..=number_rows { + // When calling a function with immutable arguments, the result is returned with ")". + // Example: SELECT add_index_to_string('const_value') FROM table; + answer.push(index.to_string() + ") " + value); + } + answer + } + // The result is returned as an array when called with dynamic arguments. + ColumnarValue::Array(array) => { + let string_array = as_string_array(array); + let mut counter = HashMap::<&str, u64>::new(); + string_array + .iter() + .map(|value| { + let value = value.expect("Unexpected null"); + let index = counter.get(value).unwrap_or(&0) + 1; + counter.insert(value, index); + + // When calling a function with mutable arguments, the result is returned with ".". + // Example: SELECT add_index_to_string(table.value) FROM table; + index.to_string() + ". " + value + }) + .collect() } + _ => unimplemented!(), }; - let mut rng = thread_rng(); - let values = iter::repeat_with(|| rng.gen_range(0.1..1.0)).take(len); - let array = Float64Array::from_iter_values(values); - Ok(ColumnarValue::Array(Arc::new(array))) + Ok(ColumnarValue::Array(Arc::new(StringArray::from(answer)))) } } -/// Ensure that a user defined function with zero argument will be invoked -/// with a null array indicating the batch size. #[tokio::test] -async fn test_user_defined_functions_zero_argument() -> Result<()> { - let ctx = SessionContext::new(); - - let schema = Arc::new(Schema::new(vec![Field::new( - "index", - DataType::UInt8, - false, - )])); - - let batch = RecordBatch::try_new( - schema, - vec![Arc::new(UInt8Array::from_iter_values([1, 2, 3]))], - )?; - - ctx.register_batch("data_table", batch)?; - - let random_normal_udf = ScalarUDF::from(RandomUDF::new()); - ctx.register_udf(random_normal_udf); - - let result = plan_and_collect( - &ctx, - "SELECT random_udf() AS random_udf, random() AS native_random FROM data_table", - ) - .await?; - - assert_eq!(result.len(), 1); - let batch = &result[0]; - let random_udf = batch - .column(0) - .as_any() - .downcast_ref::() - .unwrap(); - let native_random = batch - .column(1) - .as_any() - .downcast_ref::() - .unwrap(); - - assert_eq!(random_udf.len(), native_random.len()); - - let mut previous = -1.0; - for i in 0..random_udf.len() { - assert!(random_udf.value(i) >= 0.0 && random_udf.value(i) < 1.0); - assert!(random_udf.value(i) != previous); - previous = random_udf.value(i); +async fn volatile_scalar_udf_with_params() -> Result<()> { + { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(StringArray::from(vec![ + "test_1", "test_1", "test_1", "test_2", "test_2", "test_1", "test_2", + ]))], + )?; + let ctx = SessionContext::new(); + + ctx.register_batch("t", batch)?; + + let get_new_str_udf = AddIndexToStringVolatileScalarUDF::new(); + + ctx.register_udf(ScalarUDF::from(get_new_str_udf)); + + let result = + plan_and_collect(&ctx, "select add_index_to_string(t.a) AS str from t") // with dynamic function parameters + .await?; + let expected = [ + "+-----------+", + "| str |", + "+-----------+", + "| 1. test_1 |", + "| 2. test_1 |", + "| 3. test_1 |", + "| 1. test_2 |", + "| 2. test_2 |", + "| 4. test_1 |", + "| 3. test_2 |", + "+-----------+", + ]; + assert_batches_eq!(expected, &result); + + let result = + plan_and_collect(&ctx, "select add_index_to_string('test') AS str from t") // with fixed function parameters + .await?; + let expected = [ + "+---------+", + "| str |", + "+---------+", + "| 1) test |", + "| 2) test |", + "| 3) test |", + "| 4) test |", + "| 5) test |", + "| 6) test |", + "| 7) test |", + "+---------+", + ]; + assert_batches_eq!(expected, &result); + + let result = + plan_and_collect(&ctx, "select add_index_to_string('test_value') as str") // with fixed function parameters + .await?; + let expected = [ + "+---------------+", + "| str |", + "+---------------+", + "| 1) test_value |", + "+---------------+", + ]; + assert_batches_eq!(expected, &result); + } + { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(StringArray::from(vec![ + "test_1", "test_1", "test_1", + ]))], + )?; + let ctx = SessionContext::new(); + + ctx.register_batch("t", batch)?; + + let get_new_str_udf = AddIndexToStringVolatileScalarUDF::new(); + + ctx.register_udf(ScalarUDF::from(get_new_str_udf)); + + let result = + plan_and_collect(&ctx, "select add_index_to_string(t.a) AS str from t") + .await?; + let expected = [ + "+-----------+", // + "| str |", // + "+-----------+", // + "| 1. test_1 |", // + "| 2. test_1 |", // + "| 3. test_1 |", // + "+-----------+", + ]; + assert_batches_eq!(expected, &result); } - - Ok(()) -} - -#[tokio::test] -async fn deregister_udf() -> Result<()> { - let random_normal_udf = ScalarUDF::from(RandomUDF::new()); - let ctx = SessionContext::new(); - - ctx.register_udf(random_normal_udf.clone()); - - assert!(ctx.udfs().contains("random_udf")); - - ctx.deregister_udf("random_udf"); - - assert!(!ctx.udfs().contains("random_udf")); - Ok(()) } @@ -615,6 +759,33 @@ async fn test_user_defined_functions_cast_to_i64() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_user_defined_sql_functions() -> Result<()> { + let ctx = SessionContext::new(); + + let expr_planners = ctx.expr_planners(); + + assert!(!expr_planners.is_empty()); + + Ok(()) +} + +#[tokio::test] +async fn deregister_udf() -> Result<()> { + let cast2i64 = ScalarUDF::from(CastToI64UDF::new()); + let ctx = SessionContext::new(); + + ctx.register_udf(cast2i64); + + assert!(ctx.udfs().contains("cast_to_i64")); + + ctx.deregister_udf("cast_to_i64"); + + assert!(!ctx.udfs().contains("cast_to_i64")); + + Ok(()) +} + #[derive(Debug)] struct TakeUDF { signature: Signature, @@ -765,11 +936,11 @@ struct ScalarFunctionWrapper { name: String, expr: Expr, signature: Signature, - return_type: arrow_schema::DataType, + return_type: DataType, } impl ScalarUDFImpl for ScalarFunctionWrapper { - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } @@ -777,21 +948,15 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { &self.name } - fn signature(&self) -> &datafusion_expr::Signature { + fn signature(&self) -> &Signature { &self.signature } - fn return_type( - &self, - _arg_types: &[arrow_schema::DataType], - ) -> Result { + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(self.return_type.clone()) } - fn invoke( - &self, - _args: &[datafusion_expr::ColumnarValue], - ) -> Result { + fn invoke(&self, _args: &[ColumnarValue]) -> Result { internal_err!("This function should not get invoked!") } @@ -808,10 +973,6 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { fn aliases(&self) -> &[String] { &[] } - - fn monotonicity(&self) -> Result> { - Ok(None) - } } impl ScalarFunctionWrapper { @@ -863,7 +1024,7 @@ impl TryFrom for ScalarFunctionWrapper { name: definition.name, expr: definition .params - .return_ + .function_body .expect("Expression has to be defined!"), return_type: definition .return_type @@ -875,10 +1036,7 @@ impl TryFrom for ScalarFunctionWrapper { .into_iter() .map(|a| a.data_type) .collect(), - definition - .params - .behavior - .unwrap_or(datafusion_expr::Volatility::Volatile), + definition.params.behavior.unwrap_or(Volatility::Volatile), ), }) } @@ -887,15 +1045,7 @@ impl TryFrom for ScalarFunctionWrapper { #[tokio::test] async fn create_scalar_function_from_sql_statement() -> Result<()> { let function_factory = Arc::new(CustomFunctionFactory::default()); - let runtime_config = RuntimeConfig::new(); - let runtime_environment = RuntimeEnv::new(runtime_config)?; - - let session_config = SessionConfig::new(); - let state = - SessionState::new_with_config_rt(session_config, Arc::new(runtime_environment)) - .with_function_factory(function_factory.clone()); - - let ctx = SessionContext::new_with_state(state); + let ctx = SessionContext::new().with_function_factory(function_factory.clone()); let options = SQLOptions::new().with_allow_ddl(false); let sql = r#" @@ -961,13 +1111,217 @@ async fn create_scalar_function_from_sql_statement() -> Result<()> { Ok(()) } +/// Saves whatever is passed to it as a scalar function +#[derive(Debug, Default)] +struct RecordingFunctionFactory { + calls: Mutex>, +} + +impl RecordingFunctionFactory { + fn new() -> Self { + Self::default() + } + + /// return all the calls made to the factory + fn calls(&self) -> Vec { + self.calls.lock().clone() + } +} + +#[async_trait::async_trait] +impl FunctionFactory for RecordingFunctionFactory { + async fn create( + &self, + _state: &SessionState, + statement: CreateFunction, + ) -> Result { + self.calls.lock().push(statement); + + let udf = range_udf(); + Ok(RegisterFunction::Scalar(udf)) + } +} + +#[tokio::test] +async fn create_scalar_function_from_sql_statement_postgres_syntax() -> Result<()> { + let function_factory = Arc::new(RecordingFunctionFactory::new()); + let ctx = SessionContext::new().with_function_factory(function_factory.clone()); + + let sql = r#" + CREATE FUNCTION strlen(name TEXT) + RETURNS int LANGUAGE plrust AS + $$ + Ok(Some(name.unwrap().len() as i32)) + $$; + "#; + + let body = " + Ok(Some(name.unwrap().len() as i32)) + "; + + match ctx.sql(sql).await { + Ok(_) => {} + Err(e) => { + panic!("Error creating function: {}", e); + } + } + + // verify that the call was passed through + let calls = function_factory.calls(); + let schema = DFSchema::try_from(Schema::empty())?; + assert_eq!(calls.len(), 1); + let call = &calls[0]; + let expected = CreateFunction { + or_replace: false, + temporary: false, + name: "strlen".into(), + args: Some(vec![OperateFunctionArg { + name: Some(Ident { + value: "name".into(), + quote_style: None, + }), + data_type: DataType::Utf8, + default_expr: None, + }]), + return_type: Some(DataType::Int32), + params: CreateFunctionBody { + language: Some(Ident { + value: "plrust".into(), + quote_style: None, + }), + behavior: None, + function_body: Some(lit(body)), + }, + schema: Arc::new(schema), + }; + + assert_eq!(call, &expected); + + Ok(()) +} + +#[derive(Debug)] +struct MyRegexUdf { + signature: Signature, + regex: Regex, +} + +impl MyRegexUdf { + fn new(pattern: &str) -> Self { + Self { + signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable), + regex: Regex::new(pattern).expect("regex"), + } + } + + fn matches(&self, value: Option<&str>) -> Option { + Some(self.regex.is_match(value?)) + } +} + +impl ScalarUDFImpl for MyRegexUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "regex_udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, args: &[DataType]) -> Result { + if matches!(args, [DataType::Utf8]) { + Ok(DataType::Boolean) + } else { + plan_err!("regex_udf only accepts a Utf8 argument") + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args { + [ColumnarValue::Scalar(ScalarValue::Utf8(value))] => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean( + self.matches(value.as_deref()), + ))) + } + [ColumnarValue::Array(values)] => { + let mut builder = BooleanBuilder::with_capacity(values.len()); + for value in values.as_string::() { + builder.append_option(self.matches(value)) + } + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) + } + _ => exec_err!("regex_udf only accepts a Utf8 arguments"), + } + } + + fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + self.regex.as_str() == other.regex.as_str() + } else { + false + } + } + + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.regex.as_str().hash(hasher); + hasher.finish() + } +} + +#[tokio::test] +async fn test_parameterized_scalar_udf() -> Result<()> { + let batch = RecordBatch::try_from_iter([( + "text", + Arc::new(StringArray::from(vec!["foo", "bar", "foobar", "barfoo"])) as ArrayRef, + )])?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let t = ctx.table("t").await?; + let foo_udf = ScalarUDF::from(MyRegexUdf::new("fo{2}")); + let bar_udf = ScalarUDF::from(MyRegexUdf::new("[Bb]ar")); + + let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?) + .filter( + foo_udf + .call(vec![col("text")]) + .and(bar_udf.call(vec![col("text")])), + )? + .filter(col("text").is_not_null())? + .build()?; + + assert_eq!( + format!("{plan}"), + "Filter: t.text IS NOT NULL\n Filter: regex_udf(t.text) AND regex_udf(t.text)\n TableScan: t projection=[text]" + ); + + let actual = DataFrame::new(ctx.state(), plan).collect().await?; + let expected = [ + "+--------+", + "| text |", + "+--------+", + "| foobar |", + "| barfoo |", + "+--------+", + ]; + assert_batches_eq!(expected, &actual); + + ctx.deregister_table("t")?; + Ok(()) +} + fn create_udf_context() -> SessionContext { let ctx = SessionContext::new(); // register a custom UDF ctx.register_udf(create_udf( "custom_sqrt", vec![DataType::Float64], - Arc::new(DataType::Float64), + DataType::Float64, Volatility::Immutable, Arc::new(custom_sqrt), )); @@ -987,7 +1341,7 @@ fn custom_sqrt(args: &[ColumnarValue]) -> Result { } async fn register_aggregate_csv(ctx: &SessionContext) -> Result<()> { - let testdata = datafusion::test_util::arrow_test_data(); + let testdata = test_util::arrow_test_data(); let schema = test_util::aggr_test_schema(); ctx.register_csv( "aggregate_test_100", @@ -999,7 +1353,7 @@ async fn register_aggregate_csv(ctx: &SessionContext) -> Result<()> { } async fn register_alltypes_parquet(ctx: &SessionContext) -> Result<()> { - let testdata = datafusion::test_util::parquet_test_data(); + let testdata = test_util::parquet_test_data(); ctx.register_parquet( "alltypes_plain", &format!("{testdata}/alltypes_plain.parquet"), diff --git a/datafusion/core/tests/user_defined/user_defined_table_functions.rs b/datafusion/core/tests/user_defined/user_defined_table_functions.rs index b5d10b1c5b9b..0cc156866d4d 100644 --- a/datafusion/core/tests/user_defined/user_defined_table_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_table_functions.rs @@ -24,11 +24,11 @@ use datafusion::arrow::record_batch::RecordBatch; use datafusion::datasource::function::TableFunctionImpl; use datafusion::datasource::TableProvider; use datafusion::error::Result; -use datafusion::execution::context::SessionState; use datafusion::execution::TaskContext; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::{collect, ExecutionPlan}; use datafusion::prelude::SessionContext; +use datafusion_catalog::Session; use datafusion_common::{assert_batches_eq, DFSchema, ScalarValue}; use datafusion_expr::{EmptyRelation, Expr, LogicalPlan, Projection, TableType}; use std::fs::File; @@ -90,6 +90,22 @@ async fn test_simple_read_csv_udtf() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_deregister_udtf() -> Result<()> { + let ctx = SessionContext::new(); + + ctx.register_udtf("read_csv", Arc::new(SimpleCsvTableFunc {})); + + assert!(ctx.state().table_functions().contains_key("read_csv")); + + ctx.deregister_udtf("read_csv"); + + assert!(!ctx.state().table_functions().contains_key("read_csv")); + + Ok(()) +} + +#[derive(Debug)] struct SimpleCsvTable { schema: SchemaRef, exprs: Vec, @@ -112,7 +128,7 @@ impl TableProvider for SimpleCsvTable { async fn scan( &self, - state: &SessionState, + state: &dyn Session, projection: Option<&Vec>, _filters: &[Expr], _limit: Option, @@ -146,7 +162,7 @@ impl TableProvider for SimpleCsvTable { } impl SimpleCsvTable { - async fn interpreter_expr(&self, state: &SessionState) -> Result { + async fn interpreter_expr(&self, state: &dyn Session) -> Result { use datafusion::logical_expr::expr_rewriter::normalize_col; use datafusion::logical_expr::utils::columnize_expr; let plan = LogicalPlan::EmptyRelation(EmptyRelation { @@ -156,8 +172,8 @@ impl SimpleCsvTable { let logical_plan = Projection::try_new( vec![columnize_expr( normalize_col(self.exprs[0].clone(), &plan)?, - plan.schema(), - )], + &plan, + )?], Arc::new(plan), ) .map(LogicalPlan::Projection)?; @@ -176,6 +192,7 @@ impl SimpleCsvTable { } } +#[derive(Debug)] struct SimpleCsvTableFunc {} impl TableFunctionImpl for SimpleCsvTableFunc { @@ -185,7 +202,7 @@ impl TableFunctionImpl for SimpleCsvTableFunc { for expr in exprs { match expr { Expr::Literal(ScalarValue::Utf8(Some(ref path))) => { - filepath = path.clone() + filepath.clone_from(path); } expr => new_exprs.push(expr.clone()), } diff --git a/datafusion/core/tests/user_defined/user_defined_window_functions.rs b/datafusion/core/tests/user_defined/user_defined_window_functions.rs index 3c607301fc98..10ee0c5cd2dc 100644 --- a/datafusion/core/tests/user_defined/user_defined_window_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_window_functions.rs @@ -29,11 +29,19 @@ use std::{ use arrow::array::AsArray; use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray}; -use arrow_schema::DataType; +use arrow_schema::{DataType, Field, Schema}; use datafusion::{assert_batches_eq, prelude::SessionContext}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ - PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl, + PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDF, WindowUDFImpl, +}; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use datafusion_functions_window_common::{ + expr::ExpressionArgs, field::WindowUDFFieldArgs, +}; +use datafusion_physical_expr::{ + expressions::{col, lit}, + PhysicalExpr, }; /// A query with a window function evaluated over the entire partition @@ -522,7 +530,6 @@ impl OddCounter { #[derive(Debug, Clone)] struct SimpleWindowUDF { signature: Signature, - return_type: DataType, test_state: Arc, aliases: Vec, } @@ -531,10 +538,8 @@ impl OddCounter { fn new(test_state: Arc) -> Self { let signature = Signature::exact(vec![DataType::Float64], Volatility::Immutable); - let return_type = DataType::Int64; Self { signature, - return_type, test_state, aliases: vec!["odd_counter_alias".to_string()], } @@ -554,17 +559,20 @@ impl OddCounter { &self.signature } - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(self.return_type.clone()) - } - - fn partition_evaluator(&self) -> Result> { + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { Ok(Box::new(OddCounter::new(Arc::clone(&self.test_state)))) } fn aliases(&self) -> &[String] { &self.aliases } + + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::Int64, true)) + } } ctx.register_udwf(WindowUDF::from(SimpleWindowUDF::new(test_state))) @@ -591,11 +599,7 @@ impl PartitionEvaluator for OddCounter { Ok(scalar) } - fn evaluate_all( - &mut self, - values: &[arrow_array::ArrayRef], - num_rows: usize, - ) -> Result { + fn evaluate_all(&mut self, values: &[ArrayRef], num_rows: usize) -> Result { println!("evaluate_all, values: {values:#?}, num_rows: {num_rows}"); self.test_state.inc_evaluate_all_called(); @@ -639,7 +643,124 @@ fn odd_count(arr: &Int64Array) -> i64 { } /// returns an array of num_rows that has the number of odd values in `arr` -fn odd_count_arr(arr: &Int64Array, num_rows: usize) -> arrow_array::ArrayRef { +fn odd_count_arr(arr: &Int64Array, num_rows: usize) -> ArrayRef { let array: Int64Array = std::iter::repeat(odd_count(arr)).take(num_rows).collect(); Arc::new(array) } + +#[derive(Debug)] +struct VariadicWindowUDF { + signature: Signature, +} + +impl VariadicWindowUDF { + fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Any(0), + TypeSignature::Any(1), + TypeSignature::Any(2), + TypeSignature::Any(3), + ], + Volatility::Immutable, + ), + } + } +} + +impl WindowUDFImpl for VariadicWindowUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "variadic_window_udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn partition_evaluator( + &self, + _: PartitionEvaluatorArgs, + ) -> Result> { + unimplemented!("unnecessary for testing"); + } + + fn field(&self, _: WindowUDFFieldArgs) -> Result { + unimplemented!("unnecessary for testing"); + } +} + +#[test] +// Fixes: default implementation of `WindowUDFImpl::expressions` +// returns all input expressions to the user-defined window +// function unmodified. +// +// See: https://github.com/apache/datafusion/pull/13169 +fn test_default_expressions() -> Result<()> { + let udwf = WindowUDF::from(VariadicWindowUDF::new()); + + let field_a = Field::new("a", DataType::Int32, false); + let field_b = Field::new("b", DataType::Float32, false); + let field_c = Field::new("c", DataType::Boolean, false); + let schema = Schema::new(vec![field_a, field_b, field_c]); + + let test_cases = vec![ + // + // Zero arguments + // + vec![], + // + // Single argument + // + vec![col("a", &schema)?], + vec![lit(1)], + // + // Two arguments + // + vec![col("a", &schema)?, col("b", &schema)?], + vec![col("a", &schema)?, lit(2)], + vec![lit(false), col("a", &schema)?], + // + // Three arguments + // + vec![col("a", &schema)?, col("b", &schema)?, col("c", &schema)?], + vec![col("a", &schema)?, col("b", &schema)?, lit(false)], + vec![col("a", &schema)?, lit(0.5), col("c", &schema)?], + vec![lit(3), col("b", &schema)?, col("c", &schema)?], + ]; + + for input_exprs in &test_cases { + let input_types = input_exprs + .iter() + .map(|expr: &Arc| expr.data_type(&schema).unwrap()) + .collect::>(); + let expr_args = ExpressionArgs::new(input_exprs, &input_types); + + let ret_exprs = udwf.expressions(expr_args); + + // Verify same number of input expressions are returned + assert_eq!( + input_exprs.len(), + ret_exprs.len(), + "\nInput expressions: {:?}\nReturned expressions: {:?}", + input_exprs, + ret_exprs + ); + + // Compares each returned expression with original input expressions + for (expected, actual) in input_exprs.iter().zip(&ret_exprs) { + assert_eq!( + format!("{expected:?}"), + format!("{actual:?}"), + "\nInput expressions: {:?}\nReturned expressions: {:?}", + input_exprs, + ret_exprs + ); + } + } + Ok(()) +} diff --git a/datafusion/execution/Cargo.toml b/datafusion/execution/Cargo.toml index a00b3354eb73..fb2e7e914fe5 100644 --- a/datafusion/execution/Cargo.toml +++ b/datafusion/execution/Cargo.toml @@ -42,7 +42,7 @@ dashmap = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } futures = { workspace = true } -hashbrown = { version = "0.14", features = ["raw"] } +hashbrown = { workspace = true } log = { workspace = true } object_store = { workspace = true } parking_lot = { workspace = true } diff --git a/datafusion/execution/src/cache/cache_manager.rs b/datafusion/execution/src/cache/cache_manager.rs index 97529263688b..c2403e34c665 100644 --- a/datafusion/execution/src/cache/cache_manager.rs +++ b/datafusion/execution/src/cache/cache_manager.rs @@ -54,10 +54,10 @@ impl CacheManager { pub fn try_new(config: &CacheManagerConfig) -> Result> { let mut manager = CacheManager::default(); if let Some(cc) = &config.table_files_statistics_cache { - manager.file_statistic_cache = Some(cc.clone()) + manager.file_statistic_cache = Some(Arc::clone(cc)) } if let Some(lc) = &config.list_files_cache { - manager.list_files_cache = Some(lc.clone()) + manager.list_files_cache = Some(Arc::clone(lc)) } Ok(Arc::new(manager)) } diff --git a/datafusion/execution/src/cache/cache_unit.rs b/datafusion/execution/src/cache/cache_unit.rs index 25f9b9fa4d68..a9291659a3ef 100644 --- a/datafusion/execution/src/cache/cache_unit.rs +++ b/datafusion/execution/src/cache/cache_unit.rs @@ -39,7 +39,7 @@ impl CacheAccessor> for DefaultFileStatisticsCache { fn get(&self, k: &Path) -> Option> { self.statistics .get(k) - .map(|s| Some(s.value().1.clone())) + .map(|s| Some(Arc::clone(&s.value().1))) .unwrap_or(None) } @@ -55,7 +55,7 @@ impl CacheAccessor> for DefaultFileStatisticsCache { // file has changed None } else { - Some(statistics.clone()) + Some(Arc::clone(statistics)) } }) .unwrap_or(None) @@ -108,7 +108,7 @@ impl CacheAccessor>> for DefaultListFilesCache { type Extra = ObjectMeta; fn get(&self, k: &Path) -> Option>> { - self.statistics.get(k).map(|x| x.value().clone()) + self.statistics.get(k).map(|x| Arc::clone(x.value())) } fn get_with_extra( diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index 28275d484e29..53646dc5b468 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -48,7 +48,7 @@ use datafusion_common::{ /// use datafusion_common::ScalarValue; /// /// let config = SessionConfig::new() -/// .set("datafusion.execution.batch_size", ScalarValue::UInt64(Some(1234))) +/// .set("datafusion.execution.batch_size", &ScalarValue::UInt64(Some(1234))) /// .set_bool("datafusion.execution.parquet.pushdown_filters", true); /// /// assert_eq!(config.batch_size(), 1234); @@ -78,7 +78,6 @@ use datafusion_common::{ /// | --------- | ------------- | /// | `datafusion.catalog` | [CatalogOptions][datafusion_common::config::CatalogOptions] | /// | `datafusion.execution` | [ExecutionOptions][datafusion_common::config::ExecutionOptions] | -/// | `datafusion.execution.aggregate` | [AggregateOptions][datafusion_common::config::AggregateOptions] | /// | `datafusion.execution.parquet` | [ParquetOptions][datafusion_common::config::ParquetOptions] | /// | `datafusion.optimizer` | [OptimizerOptions][datafusion_common::config::OptimizerOptions] | /// | `datafusion.sql_parser` | [SqlParserOptions][datafusion_common::config::SqlParserOptions] | @@ -123,7 +122,7 @@ impl SessionConfig { } /// Create new ConfigOptions struct, taking values from a string hash map. - pub fn from_string_hash_map(settings: HashMap) -> Result { + pub fn from_string_hash_map(settings: &HashMap) -> Result { Ok(ConfigOptions::from_string_hash_map(settings)?.into()) } @@ -157,7 +156,7 @@ impl SessionConfig { } /// Set a configuration option - pub fn set(self, key: &str, value: ScalarValue) -> Self { + pub fn set(self, key: &str, value: &ScalarValue) -> Self { self.set_str(key, &value.to_string()) } @@ -331,6 +330,14 @@ impl SessionConfig { self } + /// Prefer existing union (true). See [prefer_existing_union] for more details + /// + /// [prefer_existing_union]: datafusion_common::config::OptimizerOptions::prefer_existing_union + pub fn with_prefer_existing_union(mut self, enabled: bool) -> Self { + self.options.optimizer.prefer_existing_union = enabled; + self + } + /// Enables or disables the use of pruning predicate for parquet readers to skip row groups pub fn with_parquet_pruning(mut self, enabled: bool) -> Self { self.options.execution.parquet.pruning = enabled; @@ -344,12 +351,12 @@ impl SessionConfig { /// Returns true if bloom filter should be used to skip parquet row groups pub fn parquet_bloom_filter_pruning(&self) -> bool { - self.options.execution.parquet.bloom_filter_enabled + self.options.execution.parquet.bloom_filter_on_read } /// Enables or disables the use of bloom filter for parquet readers to skip row groups pub fn with_parquet_bloom_filter_pruning(mut self, enabled: bool) -> Self { - self.options.execution.parquet.bloom_filter_enabled = enabled; + self.options.execution.parquet.bloom_filter_on_read = enabled; self } @@ -375,19 +382,6 @@ impl SessionConfig { self.options.execution.batch_size } - /// Get the currently configured scalar_update_factor for aggregate - pub fn agg_scalar_update_factor(&self) -> usize { - self.options.execution.aggregate.scalar_update_factor - } - - /// Customize scalar_update_factor for aggregate - pub fn with_agg_scalar_update_factor(mut self, n: usize) -> Self { - // scalar update factor must be greater than zero - assert!(n > 0); - self.options.execution.aggregate.scalar_update_factor = n; - self - } - /// Enables or disables the coalescence of small batches into larger batches pub fn with_coalesce_batches(mut self, enabled: bool) -> Self { self.options.execution.coalesce_batches = enabled; @@ -438,6 +432,20 @@ impl SessionConfig { self } + /// Enables or disables the enforcement of batch size in joins + pub fn with_enforce_batch_size_in_joins( + mut self, + enforce_batch_size_in_joins: bool, + ) -> Self { + self.options.execution.enforce_batch_size_in_joins = enforce_batch_size_in_joins; + self + } + + /// Returns true if the joins will be enforced to output batches of the configured size + pub fn enforce_batch_size_in_joins(&self) -> bool { + self.options.execution.enforce_batch_size_in_joins + } + /// Convert configuration options to name-value pairs with values /// converted to strings. /// diff --git a/datafusion/execution/src/disk_manager.rs b/datafusion/execution/src/disk_manager.rs index 85cc6f8499f0..38c259fcbdc8 100644 --- a/datafusion/execution/src/disk_manager.rs +++ b/datafusion/execution/src/disk_manager.rs @@ -18,7 +18,7 @@ //! Manages files generated during query execution, files are //! hashed among the directories listed in RuntimeConfig::local_dirs. -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{resources_datafusion_err, DataFusionError, Result}; use log::debug; use parking_lot::Mutex; use rand::{thread_rng, Rng}; @@ -119,9 +119,9 @@ impl DiskManager { ) -> Result { let mut guard = self.local_dirs.lock(); let local_dirs = guard.as_mut().ok_or_else(|| { - DataFusionError::ResourcesExhausted(format!( + resources_datafusion_err!( "Memory Exhausted while {request_description} (DiskManager is disabled)" - )) + ) })?; // Create a temporary directory if needed @@ -139,7 +139,7 @@ impl DiskManager { let dir_index = thread_rng().gen_range(0..local_dirs.len()); Ok(RefCountedTempFile { - parent_temp_dir: local_dirs[dir_index].clone(), + parent_temp_dir: Arc::clone(&local_dirs[dir_index]), tempfile: Builder::new() .tempfile_in(local_dirs[dir_index].as_ref()) .map_err(DataFusionError::IoError)?, @@ -173,7 +173,7 @@ fn create_local_dirs(local_dirs: Vec) -> Result>> { local_dirs .iter() .map(|root| { - if !std::path::Path::new(root).exists() { + if !Path::new(root).exists() { std::fs::create_dir(root)?; } Builder::new() diff --git a/datafusion/execution/src/lib.rs b/datafusion/execution/src/lib.rs index a1a1551c2ca6..909364fa805d 100644 --- a/datafusion/execution/src/lib.rs +++ b/datafusion/execution/src/lib.rs @@ -14,6 +14,8 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] //! DataFusion execution configuration and runtime structures @@ -22,11 +24,16 @@ pub mod config; pub mod disk_manager; pub mod memory_pool; pub mod object_store; -pub mod registry; pub mod runtime_env; mod stream; mod task; +pub mod registry { + pub use datafusion_expr::registry::{ + FunctionRegistry, MemoryFunctionRegistry, SerializerRegistry, + }; +} + pub use disk_manager::DiskManager; pub use registry::FunctionRegistry; pub use stream::{RecordBatchStream, SendableRecordBatchStream}; diff --git a/datafusion/execution/src/memory_pool/mod.rs b/datafusion/execution/src/memory_pool/mod.rs index 7816f15bc259..5bf30b724d0b 100644 --- a/datafusion/execution/src/memory_pool/mod.rs +++ b/datafusion/execution/src/memory_pool/mod.rs @@ -18,11 +18,13 @@ //! [`MemoryPool`] for memory management during query execution, [`proxy]` for //! help with allocation accounting. -use datafusion_common::Result; +use datafusion_common::{internal_err, Result}; use std::{cmp::Ordering, sync::Arc}; mod pool; -pub mod proxy; +pub mod proxy { + pub use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt}; +} pub use pool::*; @@ -66,20 +68,44 @@ pub use pool::*; /// Note that a `MemoryPool` can be shared by concurrently executing plans, /// which can be used to control memory usage in a multi-tenant system. /// +/// # How MemoryPool works by example +/// +/// Scenario 1: +/// For `Filter` operator, `RecordBatch`es will stream through it, so it +/// don't have to keep track of memory usage through [`MemoryPool`]. +/// +/// Scenario 2: +/// For `CrossJoin` operator, if the input size gets larger, the intermediate +/// state will also grow. So `CrossJoin` operator will use [`MemoryPool`] to +/// limit the memory usage. +/// 2.1 `CrossJoin` operator has read a new batch, asked memory pool for +/// additional memory. Memory pool updates the usage and returns success. +/// 2.2 `CrossJoin` has read another batch, and tries to reserve more memory +/// again, memory pool does not have enough memory. Since `CrossJoin` operator +/// has not implemented spilling, it will stop execution and return an error. +/// +/// Scenario 3: +/// For `Aggregate` operator, its intermediate states will also accumulate as +/// the input size gets larger, but with spilling capability. When it tries to +/// reserve more memory from the memory pool, and the memory pool has already +/// reached the memory limit, it will return an error. Then, `Aggregate` +/// operator will spill the intermediate buffers to disk, and release memory +/// from the memory pool, and continue to retry memory reservation. +/// /// # Implementing `MemoryPool` /// /// You can implement a custom allocation policy by implementing the /// [`MemoryPool`] trait and configuring a `SessionContext` appropriately. -/// However, mDataFusion comes with the following simple memory pool implementations that +/// However, DataFusion comes with the following simple memory pool implementations that /// handle many common cases: /// /// * [`UnboundedMemoryPool`]: no memory limits (the default) /// /// * [`GreedyMemoryPool`]: Limits memory usage to a fixed size using a "first -/// come first served" policy +/// come first served" policy /// /// * [`FairSpillPool`]: Limits memory usage to a fixed size, allocating memory -/// to all spilling operators fairly +/// to all spilling operators fairly pub trait MemoryPool: Send + Sync + std::fmt::Debug { /// Registers a new [`MemoryConsumer`] /// @@ -115,7 +141,7 @@ pub trait MemoryPool: Send + Sync + std::fmt::Debug { /// For help with allocation accounting, see the [proxy] module. /// /// [proxy]: crate::memory_pool::proxy -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, Hash, Clone)] pub struct MemoryConsumer { name: String, can_spill: bool, @@ -218,6 +244,23 @@ impl MemoryReservation { self.size = new_size } + /// Tries to free `capacity` bytes from this reservation + /// if `capacity` does not exceed [`Self::size`] + /// Returns new reservation size + /// or error if shrinking capacity is more than allocated size + pub fn try_shrink(&mut self, capacity: usize) -> Result { + if let Some(new_size) = self.size.checked_sub(capacity) { + self.registration.pool.shrink(self, capacity); + self.size = new_size; + Ok(new_size) + } else { + internal_err!( + "Cannot free the capacity {capacity} out of allocated size {}", + self.size + ) + } + } + /// Sets the size of this reservation to `capacity` pub fn resize(&mut self, capacity: usize) { match capacity.cmp(&self.size) { @@ -266,7 +309,7 @@ impl MemoryReservation { self.size = self.size.checked_sub(capacity).unwrap(); Self { size: capacity, - registration: self.registration.clone(), + registration: Arc::clone(&self.registration), } } @@ -274,7 +317,7 @@ impl MemoryReservation { pub fn new_empty(&self) -> Self { Self { size: 0, - registration: self.registration.clone(), + registration: Arc::clone(&self.registration), } } @@ -291,13 +334,17 @@ impl Drop for MemoryReservation { } } -const TB: u64 = 1 << 40; -const GB: u64 = 1 << 30; -const MB: u64 = 1 << 20; -const KB: u64 = 1 << 10; +pub mod units { + pub const TB: u64 = 1 << 40; + pub const GB: u64 = 1 << 30; + pub const MB: u64 = 1 << 20; + pub const KB: u64 = 1 << 10; +} /// Present size in human readable form pub fn human_readable_size(size: usize) -> String { + use units::*; + let size = size as u64; let (value, unit) = { if size >= 2 * TB { diff --git a/datafusion/execution/src/memory_pool/pool.rs b/datafusion/execution/src/memory_pool/pool.rs index 4a491630fe20..e169c1f319cc 100644 --- a/datafusion/execution/src/memory_pool/pool.rs +++ b/datafusion/execution/src/memory_pool/pool.rs @@ -16,10 +16,14 @@ // under the License. use crate::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{resources_datafusion_err, DataFusionError, Result}; +use hashbrown::HashMap; use log::debug; use parking_lot::Mutex; -use std::sync::atomic::{AtomicUsize, Ordering}; +use std::{ + num::NonZeroUsize, + sync::atomic::{AtomicU64, AtomicUsize, Ordering}, +}; /// A [`MemoryPool`] that enforces no limit #[derive(Debug, Default)] @@ -231,12 +235,164 @@ impl MemoryPool for FairSpillPool { } } +/// Constructs a resources error based upon the individual [`MemoryReservation`]. +/// +/// The error references the `bytes already allocated` for the reservation, +/// and not the total within the collective [`MemoryPool`], +/// nor the total across multiple reservations with the same [`MemoryConsumer`]. +#[inline(always)] fn insufficient_capacity_err( reservation: &MemoryReservation, additional: usize, available: usize, ) -> DataFusionError { - DataFusionError::ResourcesExhausted(format!("Failed to allocate additional {} bytes for {} with {} bytes already allocated - maximum available is {}", additional, reservation.registration.consumer.name, reservation.size, available)) + resources_datafusion_err!("Failed to allocate additional {} bytes for {} with {} bytes already allocated for this reservation - {} bytes remain available for the total pool", additional, reservation.registration.consumer.name, reservation.size, available) +} + +/// A [`MemoryPool`] that tracks the consumers that have +/// reserved memory within the inner memory pool. +/// +/// By tracking memory reservations more carefully this pool +/// can provide better error messages on the largest memory users +/// +/// Tracking is per hashed [`MemoryConsumer`], not per [`MemoryReservation`]. +/// The same consumer can have multiple reservations. +#[derive(Debug)] +pub struct TrackConsumersPool { + inner: I, + top: NonZeroUsize, + tracked_consumers: Mutex>, +} + +impl TrackConsumersPool { + /// Creates a new [`TrackConsumersPool`]. + /// + /// The `top` determines how many Top K [`MemoryConsumer`]s to include + /// in the reported [`DataFusionError::ResourcesExhausted`]. + pub fn new(inner: I, top: NonZeroUsize) -> Self { + Self { + inner, + top, + tracked_consumers: Default::default(), + } + } + + /// Determine if there are multiple [`MemoryConsumer`]s registered + /// which have the same name. + /// + /// This is very tied to the implementation of the memory consumer. + fn has_multiple_consumers(&self, name: &String) -> bool { + let consumer = MemoryConsumer::new(name); + let consumer_with_spill = consumer.clone().with_can_spill(true); + let guard = self.tracked_consumers.lock(); + guard.contains_key(&consumer) && guard.contains_key(&consumer_with_spill) + } + + /// The top consumers in a report string. + pub fn report_top(&self, top: usize) -> String { + let mut consumers = self + .tracked_consumers + .lock() + .iter() + .map(|(consumer, reserved)| { + ( + (consumer.name().to_owned(), consumer.can_spill()), + reserved.load(Ordering::Acquire), + ) + }) + .collect::>(); + consumers.sort_by(|a, b| b.1.cmp(&a.1)); // inverse ordering + + consumers[0..std::cmp::min(top, consumers.len())] + .iter() + .map(|((name, can_spill), size)| { + if self.has_multiple_consumers(name) { + format!("{name}(can_spill={}) consumed {:?} bytes", can_spill, size) + } else { + format!("{name} consumed {:?} bytes", size) + } + }) + .collect::>() + .join(", ") + } +} + +impl MemoryPool for TrackConsumersPool { + fn register(&self, consumer: &MemoryConsumer) { + self.inner.register(consumer); + + let mut guard = self.tracked_consumers.lock(); + if let Some(already_reserved) = guard.insert(consumer.clone(), Default::default()) + { + guard.entry_ref(consumer).and_modify(|bytes| { + bytes.fetch_add( + already_reserved.load(Ordering::Acquire), + Ordering::AcqRel, + ); + }); + } + } + + fn unregister(&self, consumer: &MemoryConsumer) { + self.inner.unregister(consumer); + self.tracked_consumers.lock().remove(consumer); + } + + fn grow(&self, reservation: &MemoryReservation, additional: usize) { + self.inner.grow(reservation, additional); + self.tracked_consumers + .lock() + .entry_ref(reservation.consumer()) + .and_modify(|bytes| { + bytes.fetch_add(additional as u64, Ordering::AcqRel); + }); + } + + fn shrink(&self, reservation: &MemoryReservation, shrink: usize) { + self.inner.shrink(reservation, shrink); + self.tracked_consumers + .lock() + .entry_ref(reservation.consumer()) + .and_modify(|bytes| { + bytes.fetch_sub(shrink as u64, Ordering::AcqRel); + }); + } + + fn try_grow(&self, reservation: &MemoryReservation, additional: usize) -> Result<()> { + self.inner + .try_grow(reservation, additional) + .map_err(|e| match e { + DataFusionError::ResourcesExhausted(e) => { + // wrap OOM message in top consumers + DataFusionError::ResourcesExhausted( + provide_top_memory_consumers_to_error_msg( + e, + self.report_top(self.top.into()), + ), + ) + } + _ => e, + })?; + + self.tracked_consumers + .lock() + .entry_ref(reservation.consumer()) + .and_modify(|bytes| { + bytes.fetch_add(additional as u64, Ordering::AcqRel); + }); + Ok(()) + } + + fn reserved(&self) -> usize { + self.inner.reserved() + } +} + +fn provide_top_memory_consumers_to_error_msg( + error_msg: String, + top_consumers: String, +) -> String { + format!("Additional allocation failed with top memory consumers (across reservations) as: {}. Error: {}", top_consumers, error_msg) } #[cfg(test)] @@ -262,10 +418,10 @@ mod tests { assert_eq!(pool.reserved(), 4000); let err = r2.try_grow(1).unwrap_err().strip_backtrace(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 1 bytes for r2 with 2000 bytes already allocated - maximum available is 0"); + assert_eq!(err, "Resources exhausted: Failed to allocate additional 1 bytes for r2 with 2000 bytes already allocated for this reservation - 0 bytes remain available for the total pool"); let err = r2.try_grow(1).unwrap_err().strip_backtrace(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 1 bytes for r2 with 2000 bytes already allocated - maximum available is 0"); + assert_eq!(err, "Resources exhausted: Failed to allocate additional 1 bytes for r2 with 2000 bytes already allocated for this reservation - 0 bytes remain available for the total pool"); r1.shrink(1990); r2.shrink(2000); @@ -290,12 +446,12 @@ mod tests { .register(&pool); let err = r3.try_grow(70).unwrap_err().strip_backtrace(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 70 bytes for r3 with 0 bytes already allocated - maximum available is 40"); + assert_eq!(err, "Resources exhausted: Failed to allocate additional 70 bytes for r3 with 0 bytes already allocated for this reservation - 40 bytes remain available for the total pool"); //Shrinking r2 to zero doesn't allow a3 to allocate more than 45 r2.free(); let err = r3.try_grow(70).unwrap_err().strip_backtrace(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 70 bytes for r3 with 0 bytes already allocated - maximum available is 40"); + assert_eq!(err, "Resources exhausted: Failed to allocate additional 70 bytes for r3 with 0 bytes already allocated for this reservation - 40 bytes remain available for the total pool"); // But dropping r2 does drop(r2); @@ -308,6 +464,226 @@ mod tests { let mut r4 = MemoryConsumer::new("s4").register(&pool); let err = r4.try_grow(30).unwrap_err().strip_backtrace(); - assert_eq!(err, "Resources exhausted: Failed to allocate additional 30 bytes for s4 with 0 bytes already allocated - maximum available is 20"); + assert_eq!(err, "Resources exhausted: Failed to allocate additional 30 bytes for s4 with 0 bytes already allocated for this reservation - 20 bytes remain available for the total pool"); + } + + #[test] + fn test_tracked_consumers_pool() { + let pool: Arc = Arc::new(TrackConsumersPool::new( + GreedyMemoryPool::new(100), + NonZeroUsize::new(3).unwrap(), + )); + + // Test: use all the different interfaces to change reservation size + + // set r1=50, using grow and shrink + let mut r1 = MemoryConsumer::new("r1").register(&pool); + r1.grow(70); + r1.shrink(20); + + // set r2=15 using try_grow + let mut r2 = MemoryConsumer::new("r2").register(&pool); + r2.try_grow(15) + .expect("should succeed in memory allotment for r2"); + + // set r3=20 using try_resize + let mut r3 = MemoryConsumer::new("r3").register(&pool); + r3.try_resize(25) + .expect("should succeed in memory allotment for r3"); + r3.try_resize(20) + .expect("should succeed in memory allotment for r3"); + + // set r4=10 + // this should not be reported in top 3 + let mut r4 = MemoryConsumer::new("r4").register(&pool); + r4.grow(10); + + // Test: reports if new reservation causes error + // using the previously set sizes for other consumers + let mut r5 = MemoryConsumer::new("r5").register(&pool); + let expected = "Additional allocation failed with top memory consumers (across reservations) as: r1 consumed 50 bytes, r3 consumed 20 bytes, r2 consumed 15 bytes. Error: Failed to allocate additional 150 bytes for r5 with 0 bytes already allocated for this reservation - 5 bytes remain available for the total pool"; + let res = r5.try_grow(150); + assert!( + matches!( + &res, + Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected) + ), + "should provide list of top memory consumers, instead found {:?}", + res + ); + } + + #[test] + fn test_tracked_consumers_pool_register() { + let pool: Arc = Arc::new(TrackConsumersPool::new( + GreedyMemoryPool::new(100), + NonZeroUsize::new(3).unwrap(), + )); + + let same_name = "foo"; + + // Test: see error message when no consumers recorded yet + let mut r0 = MemoryConsumer::new(same_name).register(&pool); + let expected = "Additional allocation failed with top memory consumers (across reservations) as: foo consumed 0 bytes. Error: Failed to allocate additional 150 bytes for foo with 0 bytes already allocated for this reservation - 100 bytes remain available for the total pool"; + let res = r0.try_grow(150); + assert!( + matches!( + &res, + Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected) + ), + "should provide proper error when no reservations have been made yet, instead found {:?}", res + ); + + // API: multiple registrations using the same hashed consumer, + // will be recognized as the same in the TrackConsumersPool. + + // Test: will be the same per Top Consumers reported. + r0.grow(10); // make r0=10, pool available=90 + let new_consumer_same_name = MemoryConsumer::new(same_name); + let mut r1 = new_consumer_same_name.register(&pool); + // TODO: the insufficient_capacity_err() message is per reservation, not per consumer. + // a followup PR will clarify this message "0 bytes already allocated for this reservation" + let expected = "Additional allocation failed with top memory consumers (across reservations) as: foo consumed 10 bytes. Error: Failed to allocate additional 150 bytes for foo with 0 bytes already allocated for this reservation - 90 bytes remain available for the total pool"; + let res = r1.try_grow(150); + assert!( + matches!( + &res, + Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected) + ), + "should provide proper error with same hashed consumer (a single foo=10 bytes, available=90), instead found {:?}", res + ); + + // Test: will accumulate size changes per consumer, not per reservation + r1.grow(20); + let expected = "Additional allocation failed with top memory consumers (across reservations) as: foo consumed 30 bytes. Error: Failed to allocate additional 150 bytes for foo with 20 bytes already allocated for this reservation - 70 bytes remain available for the total pool"; + let res = r1.try_grow(150); + assert!( + matches!( + &res, + Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected) + ), + "should provide proper error with same hashed consumer (a single foo=30 bytes, available=70), instead found {:?}", res + ); + + // Test: different hashed consumer, (even with the same name), + // will be recognized as different in the TrackConsumersPool + let consumer_with_same_name_but_different_hash = + MemoryConsumer::new(same_name).with_can_spill(true); + let mut r2 = consumer_with_same_name_but_different_hash.register(&pool); + let expected = "Additional allocation failed with top memory consumers (across reservations) as: foo(can_spill=false) consumed 30 bytes, foo(can_spill=true) consumed 0 bytes. Error: Failed to allocate additional 150 bytes for foo with 0 bytes already allocated for this reservation - 70 bytes remain available for the total pool"; + let res = r2.try_grow(150); + assert!( + matches!( + &res, + Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected) + ), + "should provide proper error with different hashed consumer (foo(can_spill=false)=30 bytes and foo(can_spill=true)=0 bytes, available=70), instead found {:?}", res + ); + } + + #[test] + fn test_tracked_consumers_pool_deregister() { + fn test_per_pool_type(pool: Arc) { + // Baseline: see the 2 memory consumers + let mut r0 = MemoryConsumer::new("r0").register(&pool); + r0.grow(10); + let r1_consumer = MemoryConsumer::new("r1"); + let mut r1 = r1_consumer.clone().register(&pool); + r1.grow(20); + let expected = "Additional allocation failed with top memory consumers (across reservations) as: r1 consumed 20 bytes, r0 consumed 10 bytes. Error: Failed to allocate additional 150 bytes for r0 with 10 bytes already allocated for this reservation - 70 bytes remain available for the total pool"; + let res = r0.try_grow(150); + assert!( + matches!( + &res, + Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected) + ), + "should provide proper error with both consumers, instead found {:?}", + res + ); + + // Test: unregister one + // only the remaining one should be listed + pool.unregister(&r1_consumer); + let expected_consumers = "Additional allocation failed with top memory consumers (across reservations) as: r0 consumed 10 bytes"; + let res = r0.try_grow(150); + assert!( + matches!( + &res, + Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected_consumers) + ), + "should provide proper error with only 1 consumer left registered, instead found {:?}", res + ); + + // Test: actual message we see is the `available is 70`. When it should be `available is 90`. + // This is because the pool.shrink() does not automatically occur within the inner_pool.deregister(). + let expected_70_available = "Failed to allocate additional 150 bytes for r0 with 10 bytes already allocated for this reservation - 70 bytes remain available for the total pool"; + let res = r0.try_grow(150); + assert!( + matches!( + &res, + Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected_70_available) + ), + "should find that the inner pool will still count all bytes for the deregistered consumer until the reservation is dropped, instead found {:?}", res + ); + + // Test: the registration needs to free itself (or be dropped), + // for the proper error message + r1.free(); + let expected_90_available = "Failed to allocate additional 150 bytes for r0 with 10 bytes already allocated for this reservation - 90 bytes remain available for the total pool"; + let res = r0.try_grow(150); + assert!( + matches!( + &res, + Err(DataFusionError::ResourcesExhausted(ref e)) if e.to_string().contains(expected_90_available) + ), + "should correctly account the total bytes after reservation is free, instead found {:?}", res + ); + } + + let tracked_spill_pool: Arc = Arc::new(TrackConsumersPool::new( + FairSpillPool::new(100), + NonZeroUsize::new(3).unwrap(), + )); + test_per_pool_type(tracked_spill_pool); + + let tracked_greedy_pool: Arc = Arc::new(TrackConsumersPool::new( + GreedyMemoryPool::new(100), + NonZeroUsize::new(3).unwrap(), + )); + test_per_pool_type(tracked_greedy_pool); + } + + #[test] + fn test_tracked_consumers_pool_use_beyond_errors() { + let upcasted: Arc = + Arc::new(TrackConsumersPool::new( + GreedyMemoryPool::new(100), + NonZeroUsize::new(3).unwrap(), + )); + let pool: Arc = Arc::clone(&upcasted) + .downcast::>() + .unwrap(); + // set r1=20 + let mut r1 = MemoryConsumer::new("r1").register(&pool); + r1.grow(20); + // set r2=15 + let mut r2 = MemoryConsumer::new("r2").register(&pool); + r2.grow(15); + // set r3=45 + let mut r3 = MemoryConsumer::new("r3").register(&pool); + r3.grow(45); + + let downcasted = upcasted + .downcast::>() + .unwrap(); + + // Test: can get runtime metrics, even without an error thrown + let expected = "r3 consumed 45 bytes, r1 consumed 20 bytes"; + let res = downcasted.report_top(2); + assert_eq!( + res, expected, + "should provide list of top memory consumers, instead found {:?}", + res + ); } } diff --git a/datafusion/execution/src/object_store.rs b/datafusion/execution/src/object_store.rs index c0c58a87dcc6..cd75c9f3c49e 100644 --- a/datafusion/execution/src/object_store.rs +++ b/datafusion/execution/src/object_store.rs @@ -27,7 +27,12 @@ use object_store::ObjectStore; use std::sync::Arc; use url::Url; -/// A parsed URL identifying a particular [`ObjectStore`] +/// A parsed URL identifying a particular [`ObjectStore`] instance +/// +/// For example: +/// * `file://` for local file system +/// * `s3://bucket` for AWS S3 bucket +/// * `oss://bucket` for Aliyun OSS bucket #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct ObjectStoreUrl { url: Url, @@ -35,6 +40,19 @@ pub struct ObjectStoreUrl { impl ObjectStoreUrl { /// Parse an [`ObjectStoreUrl`] from a string + /// + /// # Example + /// ``` + /// # use url::Url; + /// # use datafusion_execution::object_store::ObjectStoreUrl; + /// let object_store_url = ObjectStoreUrl::parse("s3://bucket").unwrap(); + /// assert_eq!(object_store_url.as_str(), "s3://bucket/"); + /// // can also access the underlying `Url` + /// let url: &Url = object_store_url.as_ref(); + /// assert_eq!(url.scheme(), "s3"); + /// assert_eq!(url.host_str(), Some("bucket")); + /// assert_eq!(url.path(), "/"); + /// ``` pub fn parse(s: impl AsRef) -> Result { let mut parsed = Url::parse(s.as_ref()).map_err(|e| DataFusionError::External(Box::new(e)))?; @@ -51,7 +69,14 @@ impl ObjectStoreUrl { Ok(Self { url: parsed }) } - /// An [`ObjectStoreUrl`] for the local filesystem + /// An [`ObjectStoreUrl`] for the local filesystem (`file://`) + /// + /// # Example + /// ``` + /// # use datafusion_execution::object_store::ObjectStoreUrl; + /// let local_fs = ObjectStoreUrl::parse("file://").unwrap(); + /// assert_eq!(local_fs, ObjectStoreUrl::local_filesystem()) + /// ``` pub fn local_filesystem() -> Self { Self::parse("file://").unwrap() } @@ -85,11 +110,11 @@ impl std::fmt::Display for ObjectStoreUrl { /// instances. For example DataFusion might be configured so that /// /// 1. `s3://my_bucket/lineitem/` mapped to the `/lineitem` path on an -/// AWS S3 object store bound to `my_bucket` +/// AWS S3 object store bound to `my_bucket` /// /// 2. `s3://my_other_bucket/lineitem/` mapped to the (same) -/// `/lineitem` path on a *different* AWS S3 object store bound to -/// `my_other_bucket` +/// `/lineitem` path on a *different* AWS S3 object store bound to +/// `my_other_bucket` /// /// When given a [`ListingTableUrl`], DataFusion tries to find an /// appropriate [`ObjectStore`]. For example @@ -102,21 +127,21 @@ impl std::fmt::Display for ObjectStoreUrl { /// [`ObjectStoreRegistry::get_store`] and one of three things will happen: /// /// - If an [`ObjectStore`] has been registered with [`ObjectStoreRegistry::register_store`] with -/// `s3://my_bucket`, that [`ObjectStore`] will be returned +/// `s3://my_bucket`, that [`ObjectStore`] will be returned /// /// - If an AWS S3 object store can be ad-hoc discovered by the url `s3://my_bucket/lineitem/`, this -/// object store will be registered with key `s3://my_bucket` and returned. +/// object store will be registered with key `s3://my_bucket` and returned. /// /// - Otherwise an error will be returned, indicating that no suitable [`ObjectStore`] could -/// be found +/// be found /// /// This allows for two different use-cases: /// /// 1. Systems where object store buckets are explicitly created using DDL, can register these -/// buckets using [`ObjectStoreRegistry::register_store`] +/// buckets using [`ObjectStoreRegistry::register_store`] /// /// 2. Systems relying on ad-hoc discovery, without corresponding DDL, can create [`ObjectStore`] -/// lazily by providing a custom implementation of [`ObjectStoreRegistry`] +/// lazily by providing a custom implementation of [`ObjectStoreRegistry`] /// /// /// [`ListingTableUrl`]: https://docs.rs/datafusion/latest/datafusion/datasource/listing/struct.ListingTableUrl.html @@ -209,10 +234,10 @@ impl ObjectStoreRegistry for DefaultObjectStoreRegistry { let s = get_url_key(url); self.object_stores .get(&s) - .map(|o| o.value().clone()) + .map(|o| Arc::clone(o.value())) .ok_or_else(|| { DataFusionError::Internal(format!( - "No suitable object store found for {url}" + "No suitable object store found for {url}. See `RuntimeEnv::register_object_store`" )) }) } diff --git a/datafusion/execution/src/runtime_env.rs b/datafusion/execution/src/runtime_env.rs index e78a9e0de9f0..4022eb07de0c 100644 --- a/datafusion/execution/src/runtime_env.rs +++ b/datafusion/execution/src/runtime_env.rs @@ -20,23 +20,28 @@ use crate::{ disk_manager::{DiskManager, DiskManagerConfig}, - memory_pool::{GreedyMemoryPool, MemoryPool, UnboundedMemoryPool}, + memory_pool::{ + GreedyMemoryPool, MemoryPool, TrackConsumersPool, UnboundedMemoryPool, + }, object_store::{DefaultObjectStoreRegistry, ObjectStoreRegistry}, }; use crate::cache::cache_manager::{CacheManager, CacheManagerConfig}; use datafusion_common::{DataFusionError, Result}; use object_store::ObjectStore; -use std::fmt::{Debug, Formatter}; use std::path::PathBuf; use std::sync::Arc; +use std::{ + fmt::{Debug, Formatter}, + num::NonZeroUsize, +}; use url::Url; #[derive(Clone)] /// Execution runtime environment that manages system resources such /// as memory, disk, cache and storage. /// -/// A [`RuntimeEnv`] is created from a [`RuntimeConfig`] and has the +/// A [`RuntimeEnv`] is created from a [`RuntimeEnvBuilder`] and has the /// following resource management functionality: /// /// * [`MemoryPool`]: Manage memory @@ -61,8 +66,12 @@ impl Debug for RuntimeEnv { } impl RuntimeEnv { - /// Create env based on configuration + #[deprecated(note = "please use `try_new` instead")] pub fn new(config: RuntimeConfig) -> Result { + Self::try_new(config) + } + /// Create env based on configuration + pub fn try_new(config: RuntimeConfig) -> Result { let RuntimeConfig { memory_pool, disk_manager, @@ -89,6 +98,39 @@ impl RuntimeEnv { /// scheme, if any. /// /// See [`ObjectStoreRegistry`] for more details + /// + /// # Example: Register local file system object store + /// ``` + /// # use std::sync::Arc; + /// # use url::Url; + /// # use datafusion_execution::runtime_env::RuntimeEnv; + /// # let runtime_env = RuntimeEnv::try_new(Default::default()).unwrap(); + /// let url = Url::try_from("file://").unwrap(); + /// let object_store = object_store::local::LocalFileSystem::new(); + /// // register the object store with the runtime environment + /// runtime_env.register_object_store(&url, Arc::new(object_store)); + /// ``` + /// + /// # Example: Register local file system object store + /// + /// To register reading from urls such as ` + /// + /// ``` + /// # use std::sync::Arc; + /// # use url::Url; + /// # use datafusion_execution::runtime_env::RuntimeEnv; + /// # let runtime_env = RuntimeEnv::try_new(Default::default()).unwrap(); + /// # // use local store for example as http feature is not enabled + /// # let http_store = object_store::local::LocalFileSystem::new(); + /// // create a new object store via object_store::http::HttpBuilder; + /// let base_url = Url::parse("https://github.com").unwrap(); + /// // let http_store = HttpBuilder::new() + /// // .with_url(base_url.clone()) + /// // .build() + /// // .unwrap(); + /// // register the object store with the runtime environment + /// runtime_env.register_object_store(&base_url, Arc::new(http_store)); + /// ``` pub fn register_object_store( &self, url: &Url, @@ -109,13 +151,17 @@ impl RuntimeEnv { impl Default for RuntimeEnv { fn default() -> Self { - RuntimeEnv::new(RuntimeConfig::new()).unwrap() + RuntimeEnvBuilder::new().build().unwrap() } } +/// Please see: +/// This a type alias for backwards compatibility. +pub type RuntimeConfig = RuntimeEnvBuilder; + #[derive(Clone)] /// Execution runtime configuration -pub struct RuntimeConfig { +pub struct RuntimeEnvBuilder { /// DiskManager to manage temporary disk file usage pub disk_manager: DiskManagerConfig, /// [`MemoryPool`] from which to allocate memory @@ -128,13 +174,13 @@ pub struct RuntimeConfig { pub object_store_registry: Arc, } -impl Default for RuntimeConfig { +impl Default for RuntimeEnvBuilder { fn default() -> Self { Self::new() } } -impl RuntimeConfig { +impl RuntimeEnvBuilder { /// New with default values pub fn new() -> Self { Self { @@ -180,11 +226,33 @@ impl RuntimeConfig { /// Note DataFusion does not yet respect this limit in all cases. pub fn with_memory_limit(self, max_memory: usize, memory_fraction: f64) -> Self { let pool_size = (max_memory as f64 * memory_fraction) as usize; - self.with_memory_pool(Arc::new(GreedyMemoryPool::new(pool_size))) + self.with_memory_pool(Arc::new(TrackConsumersPool::new( + GreedyMemoryPool::new(pool_size), + NonZeroUsize::new(5).unwrap(), + ))) } /// Use the specified path to create any needed temporary files pub fn with_temp_file_path(self, path: impl Into) -> Self { self.with_disk_manager(DiskManagerConfig::new_specified(vec![path.into()])) } + + /// Build a RuntimeEnv + pub fn build(self) -> Result { + let memory_pool = self + .memory_pool + .unwrap_or_else(|| Arc::new(UnboundedMemoryPool::default())); + + Ok(RuntimeEnv { + memory_pool, + disk_manager: DiskManager::try_new(self.disk_manager)?, + cache_manager: CacheManager::try_new(&self.cache_manager)?, + object_store_registry: self.object_store_registry, + }) + } + + /// Convenience method to create a new `Arc` + pub fn build_arc(self) -> Result> { + self.build().map(Arc::new) + } } diff --git a/datafusion/execution/src/stream.rs b/datafusion/execution/src/stream.rs index 7fc5e458b86b..f3eb7b77e03c 100644 --- a/datafusion/execution/src/stream.rs +++ b/datafusion/execution/src/stream.rs @@ -20,7 +20,9 @@ use datafusion_common::Result; use futures::Stream; use std::pin::Pin; -/// Trait for types that stream [arrow::record_batch::RecordBatch] +/// Trait for types that stream [RecordBatch] +/// +/// See [`SendableRecordBatchStream`] for more details. pub trait RecordBatchStream: Stream> { /// Returns the schema of this `RecordBatchStream`. /// @@ -29,5 +31,23 @@ pub trait RecordBatchStream: Stream> { fn schema(&self) -> SchemaRef; } -/// Trait for a [`Stream`] of [`RecordBatch`]es +/// Trait for a [`Stream`] of [`RecordBatch`]es that can be passed between threads +/// +/// This trait is used to retrieve the results of DataFusion execution plan nodes. +/// +/// The trait is a specialized Rust Async [`Stream`] that also knows the schema +/// of the data it will return (even if the stream has no data). Every +/// `RecordBatch` returned by the stream should have the same schema as returned +/// by [`schema`](`RecordBatchStream::schema`). +/// +/// # Error Handling +/// +/// Once a stream returns an error, it should not be polled again (the caller +/// should stop calling `next`) and handle the error. +/// +/// However, returning `Ready(None)` (end of stream) is likely the safest +/// behavior after an error. Like [`Stream`]s, `RecordBatchStream`s should not +/// be polled after end of stream or returning an error. However, also like +/// [`Stream`]s there is no mechanism to prevent callers polling so returning +/// `Ready(None)` is recommended. pub type SendableRecordBatchStream = Pin>; diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs index 4216ce95f35e..57fcac0ee5ab 100644 --- a/datafusion/execution/src/task.rs +++ b/datafusion/execution/src/task.rs @@ -20,21 +20,21 @@ use std::{ sync::Arc, }; -use datafusion_common::{plan_datafusion_err, DataFusionError, Result}; -use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; - use crate::{ config::SessionConfig, memory_pool::MemoryPool, registry::FunctionRegistry, - runtime_env::{RuntimeConfig, RuntimeEnv}, + runtime_env::{RuntimeEnv, RuntimeEnvBuilder}, }; +use datafusion_common::{plan_datafusion_err, DataFusionError, Result}; +use datafusion_expr::planner::ExprPlanner; +use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; /// Task Execution Context /// -/// A [`TaskContext`] contains the state available during a single -/// query's execution. Please see [`SessionContext`] for a user level -/// multi-query API. +/// A [`TaskContext`] contains the state required during a single query's +/// execution. Please see the documentation on [`SessionContext`] for more +/// information. /// /// [`SessionContext`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html #[derive(Debug)] @@ -57,8 +57,9 @@ pub struct TaskContext { impl Default for TaskContext { fn default() -> Self { - let runtime = RuntimeEnv::new(RuntimeConfig::new()) - .expect("defauly runtime created successfully"); + let runtime = RuntimeEnvBuilder::new() + .build_arc() + .expect("default runtime created successfully"); // Create a default task context, mostly useful for testing Self { @@ -68,7 +69,7 @@ impl Default for TaskContext { scalar_functions: HashMap::new(), aggregate_functions: HashMap::new(), window_functions: HashMap::new(), - runtime: Arc::new(runtime), + runtime, } } } @@ -121,7 +122,7 @@ impl TaskContext { /// Return the [RuntimeEnv] associated with this [TaskContext] pub fn runtime_env(&self) -> Arc { - self.runtime.clone() + Arc::clone(&self.runtime) } /// Update the [`SessionConfig`] @@ -172,22 +173,29 @@ impl FunctionRegistry for TaskContext { udaf: Arc, ) -> Result>> { udaf.aliases().iter().for_each(|alias| { - self.aggregate_functions.insert(alias.clone(), udaf.clone()); + self.aggregate_functions + .insert(alias.clone(), Arc::clone(&udaf)); }); Ok(self.aggregate_functions.insert(udaf.name().into(), udaf)) } fn register_udwf(&mut self, udwf: Arc) -> Result>> { udwf.aliases().iter().for_each(|alias| { - self.window_functions.insert(alias.clone(), udwf.clone()); + self.window_functions + .insert(alias.clone(), Arc::clone(&udwf)); }); Ok(self.window_functions.insert(udwf.name().into(), udwf)) } fn register_udf(&mut self, udf: Arc) -> Result>> { udf.aliases().iter().for_each(|alias| { - self.scalar_functions.insert(alias.clone(), udf.clone()); + self.scalar_functions + .insert(alias.clone(), Arc::clone(&udf)); }); Ok(self.scalar_functions.insert(udf.name().into(), udf)) } + + fn expr_planners(&self) -> Vec> { + vec![] + } } #[cfg(test)] diff --git a/datafusion/expr-common/Cargo.toml b/datafusion/expr-common/Cargo.toml new file mode 100644 index 000000000000..de11b19c3b06 --- /dev/null +++ b/datafusion/expr-common/Cargo.toml @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-expr-common" +description = "Logical plan and expression representation for DataFusion query engine" +keywords = ["datafusion", "logical", "plan", "expressions"] +readme = "README.md" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +authors = { workspace = true } +rust-version = { workspace = true } + +[lints] +workspace = true + +[lib] +name = "datafusion_expr_common" +path = "src/lib.rs" + +[features] + +[dependencies] +arrow = { workspace = true } +datafusion-common = { workspace = true } +itertools = { workspace = true } +paste = "^1.0" diff --git a/datafusion/expr/src/accumulator.rs b/datafusion/expr-common/src/accumulator.rs similarity index 53% rename from datafusion/expr/src/accumulator.rs rename to datafusion/expr-common/src/accumulator.rs index 031348269a38..7155c7993f8c 100644 --- a/datafusion/expr/src/accumulator.rs +++ b/datafusion/expr-common/src/accumulator.rs @@ -35,12 +35,12 @@ use std::fmt::Debug; /// * compute the final value from its internal state via [`evaluate`] /// /// * retract an update to its state from given inputs via -/// [`retract_batch`] (when used as a window aggregate [window -/// function]) +/// [`retract_batch`] (when used as a window aggregate [window +/// function]) /// /// * convert its internal state to a vector of aggregate values via -/// [`state`] and combine the state from multiple accumulators' -/// via [`merge_batch`], as part of efficient multi-phase grouping. +/// [`state`] and combine the state from multiple accumulators +/// via [`merge_batch`], as part of efficient multi-phase grouping. /// /// [`GroupsAccumulator`]: crate::GroupsAccumulator /// [`update_batch`]: Self::update_batch @@ -64,11 +64,11 @@ pub trait Accumulator: Send + Sync + Debug { /// For example, the `SUM` accumulator maintains a running sum, /// and `evaluate` will produce that running sum as its output. /// - /// After this call, the accumulator's internal state should be - /// equivalent to when it was first created. + /// This function should not be called twice, otherwise it will + /// result in potentially non-deterministic behavior. /// /// This function gets `&mut self` to allow for the accumulator to build - /// arrow compatible internal state that can be returned without copying + /// arrow-compatible internal state that can be returned without copying /// when possible (for example distinct strings) fn evaluate(&mut self) -> Result; @@ -85,18 +85,18 @@ pub trait Accumulator: Send + Sync + Debug { /// Returns the intermediate state of the accumulator, consuming the /// intermediate state. /// - /// After this call, the accumulator's internal state should be - /// equivalent to when it was first created. + /// This function should not be called twice, otherwise it will + /// result in potentially non-deterministic behavior. /// /// This function gets `&mut self` to allow for the accumulator to build - /// arrow compatible internal state that can be returned without copying + /// arrow-compatible internal state that can be returned without copying /// when possible (for example distinct strings). /// /// Intermediate state is used for "multi-phase" grouping in /// DataFusion, where an aggregate is computed in parallel with - /// multiple `Accumulator` instances, as illustrated below: + /// multiple `Accumulator` instances, as described below: /// - /// # MultiPhase Grouping + /// # Multi-Phase Grouping /// /// ```text /// ▲ @@ -117,8 +117,8 @@ pub trait Accumulator: Send + Sync + Debug { /// ┌─────────────────────────┐ ┌─────────────────────────┐ /// │ GroubyBy │ │ GroubyBy │ /// │(AggregateMode::Partial) │ │(AggregateMode::Partial) │ - /// └─────────────────────────┘ └────────────▲────────────┘ - /// ▲ │ + /// └─────────────────────────┘ └─────────────────────────┘ + /// ▲ ▲ /// │ │ update_batch() is called for /// │ │ each input RecordBatch /// .─────────. .─────────. @@ -130,7 +130,7 @@ pub trait Accumulator: Send + Sync + Debug { /// `───────' `───────' /// ``` /// - /// The partial state is serialied as `Arrays` and then combined + /// The partial state is serialized as `Arrays` and then combined /// with other partial states from different instances of this /// Accumulator (that ran on different partitions, for example). /// @@ -140,13 +140,114 @@ pub trait Accumulator: Send + Sync + Debug { /// to be summed together) /// /// Some accumulators can return multiple values for their - /// intermediate states. For example average, tracks `sum` and - /// `n`, and this function should return - /// a vector of two values, sum and n. + /// intermediate states. For example, the average accumulator + /// tracks `sum` and `n`, and this function should return a vector + /// of two values, sum and n. /// /// Note that [`ScalarValue::List`] can be used to pass multiple /// values if the number of intermediate values is not known at /// planning time (e.g. for `MEDIAN`) + /// + /// # Multi-phase repartitioned Grouping + /// + /// Many multi-phase grouping plans contain a Repartition operation + /// as well as shown below: + /// + /// ```text + /// ▲ ▲ + /// │ │ + /// │ │ + /// │ │ + /// │ │ + /// │ │ + /// ┌───────────────────────┐ ┌───────────────────────┐ 4. Each AggregateMode::Final + /// │GroupBy │ │GroupBy │ GroupBy has an entry for its + /// │(AggregateMode::Final) │ │(AggregateMode::Final) │ subset of groups (in this case + /// │ │ │ │ that means half the entries) + /// └───────────────────────┘ └───────────────────────┘ + /// ▲ ▲ + /// │ │ + /// └─────────────┬────────────┘ + /// │ + /// │ + /// │ + /// ┌─────────────────────────┐ 3. Repartitioning by hash(group + /// │ Repartition │ keys) ensures that each distinct + /// │ HASH(x) │ group key now appears in exactly + /// └─────────────────────────┘ one partition + /// ▲ + /// │ + /// ┌───────────────┴─────────────┐ + /// │ │ + /// │ │ + /// ┌─────────────────────────┐ ┌──────────────────────────┐ 2. Each AggregateMode::Partial + /// │ GroubyBy │ │ GroubyBy │ GroupBy has an entry for *all* + /// │(AggregateMode::Partial) │ │ (AggregateMode::Partial) │ the groups + /// └─────────────────────────┘ └──────────────────────────┘ + /// ▲ ▲ + /// │ │ + /// │ │ + /// .─────────. .─────────. + /// ,─' '─. ,─' '─. + /// ; Input : ; Input : 1. Since input data is + /// : Partition 0 ; : Partition 1 ; arbitrarily or RoundRobin + /// ╲ ╱ ╲ ╱ distributed, each partition + /// '─. ,─' '─. ,─' likely has all distinct + /// `───────' `───────' + /// ``` + /// + /// This structure is used so that the `AggregateMode::Partial` accumulators + /// reduces the cardinality of the input as soon as possible. Typically, + /// each partial accumulator sees all groups in the input as the group keys + /// are evenly distributed across the input. + /// + /// The final output is computed by repartitioning the result of + /// [`Self::state`] from each Partial aggregate and `hash(group keys)` so + /// that each distinct group key appears in exactly one of the + /// `AggregateMode::Final` GroupBy nodes. The outputs of the final nodes are + /// then unioned together to produce the overall final output. + /// + /// Here is an example that shows the distribution of groups in the + /// different phases + /// + /// ```text + /// ┌─────┐ ┌─────┐ + /// │ 1 │ │ 3 │ + /// ├─────┤ ├─────┤ + /// │ 2 │ │ 4 │ After repartitioning by + /// └─────┘ └─────┘ hash(group keys), each distinct + /// ┌─────┐ ┌─────┐ group key now appears in exactly + /// │ 1 │ │ 3 │ one partition + /// ├─────┤ ├─────┤ + /// │ 2 │ │ 4 │ + /// └─────┘ └─────┘ + /// + /// + /// ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ + /// + /// ┌─────┐ ┌─────┐ + /// │ 2 │ │ 2 │ + /// ├─────┤ ├─────┤ + /// │ 1 │ │ 2 │ + /// ├─────┤ ├─────┤ + /// │ 3 │ │ 3 │ + /// ├─────┤ ├─────┤ + /// │ 4 │ │ 1 │ + /// └─────┘ └─────┘ Input data is arbitrarily or + /// ... ... RoundRobin distributed, each + /// ┌─────┐ ┌─────┐ partition likely has all + /// │ 1 │ │ 4 │ distinct group keys + /// ├─────┤ ├─────┤ + /// │ 4 │ │ 3 │ + /// ├─────┤ ├─────┤ + /// │ 1 │ │ 1 │ + /// ├─────┤ ├─────┤ + /// │ 4 │ │ 3 │ + /// └─────┘ └─────┘ + /// + /// group values group values + /// in partition 0 in partition 1 + /// ``` fn state(&mut self) -> Result>; /// Updates the accumulator's state from an `Array` containing one diff --git a/datafusion/expr/src/columnar_value.rs b/datafusion/expr-common/src/columnar_value.rs similarity index 75% rename from datafusion/expr/src/columnar_value.rs rename to datafusion/expr-common/src/columnar_value.rs index 87c3c063b91a..4b9454ed739d 100644 --- a/datafusion/expr/src/columnar_value.rs +++ b/datafusion/expr-common/src/columnar_value.rs @@ -15,25 +15,74 @@ // specific language governing permissions and limitations // under the License. -//! Columnar value module contains a set of types that represent a columnar value. +//! [`ColumnarValue`] represents the result of evaluating an expression. -use arrow::array::ArrayRef; -use arrow::array::NullArray; +use arrow::array::{Array, ArrayRef, NullArray}; use arrow::compute::{kernels, CastOptions}; -use arrow::datatypes::{DataType, TimeUnit}; +use arrow::datatypes::DataType; use datafusion_common::format::DEFAULT_CAST_OPTIONS; use datafusion_common::{internal_err, Result, ScalarValue}; use std::sync::Arc; -/// Represents the result of evaluating an expression: either a single -/// [`ScalarValue`] or an [`ArrayRef`]. +/// The result of evaluating an expression. /// -/// While a [`ColumnarValue`] can always be converted into an array -/// for convenience, it is often much more performant to provide an -/// optimized path for scalar values. +/// [`ColumnarValue::Scalar`] represents a single value repeated any number of +/// times. This is an important performance optimization for handling values +/// that do not change across rows. /// -/// See [`ColumnarValue::values_to_arrays`] for a function that converts -/// multiple columnar values into arrays of the same length. +/// [`ColumnarValue::Array`] represents a column of data, stored as an Arrow +/// [`ArrayRef`] +/// +/// A slice of `ColumnarValue`s logically represents a table, with each column +/// having the same number of rows. This means that all `Array`s are the same +/// length. +/// +/// # Example +/// +/// A `ColumnarValue::Array` with an array of 5 elements and a +/// `ColumnarValue::Scalar` with the value 100 +/// +/// ```text +/// ┌──────────────┐ +/// │ ┌──────────┐ │ +/// │ │ "A" │ │ +/// │ ├──────────┤ │ +/// │ │ "B" │ │ +/// │ ├──────────┤ │ +/// │ │ "C" │ │ +/// │ ├──────────┤ │ +/// │ │ "D" │ │ ┌──────────────┐ +/// │ ├──────────┤ │ │ ┌──────────┐ │ +/// │ │ "E" │ │ │ │ 100 │ │ +/// │ └──────────┘ │ │ └──────────┘ │ +/// └──────────────┘ └──────────────┘ +/// +/// ColumnarValue:: ColumnarValue:: +/// Array Scalar +/// ``` +/// +/// Logically represents the following table: +/// +/// | Column 1| Column 2 | +/// | ------- | -------- | +/// | A | 100 | +/// | B | 100 | +/// | C | 100 | +/// | D | 100 | +/// | E | 100 | +/// +/// # Performance Notes +/// +/// When implementing functions or operators, it is important to consider the +/// performance implications of handling scalar values. +/// +/// Because all functions must handle [`ArrayRef`], it is +/// convenient to convert [`ColumnarValue::Scalar`]s using +/// [`Self::into_array`]. For example, [`ColumnarValue::values_to_arrays`] +/// converts multiple columnar values into arrays of the same length. +/// +/// However, it is often much more performant to provide a different, +/// implementation that handles scalar values differently #[derive(Clone, Debug)] pub enum ColumnarValue { /// Array of values @@ -80,7 +129,7 @@ impl ColumnarValue { }) } - /// null columnar values are implemented as a null array in order to pass batch + /// Null columnar values are implemented as a null array in order to pass batch /// num_rows pub fn create_null_array(num_rows: usize) -> Self { ColumnarValue::Array(Arc::new(NullArray::new(num_rows))) @@ -144,28 +193,9 @@ impl ColumnarValue { ColumnarValue::Array(array) => Ok(ColumnarValue::Array( kernels::cast::cast_with_options(array, cast_type, &cast_options)?, )), - ColumnarValue::Scalar(scalar) => { - let scalar_array = - if cast_type == &DataType::Timestamp(TimeUnit::Nanosecond, None) { - if let ScalarValue::Float64(Some(float_ts)) = scalar { - ScalarValue::Int64(Some( - (float_ts * 1_000_000_000_f64).trunc() as i64, - )) - .to_array()? - } else { - scalar.to_array()? - } - } else { - scalar.to_array()? - }; - let cast_array = kernels::cast::cast_with_options( - &scalar_array, - cast_type, - &cast_options, - )?; - let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?; - Ok(ColumnarValue::Scalar(cast_scalar)) - } + ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar( + scalar.cast_to_with_options(cast_type, &cast_options)?, + )), } } } diff --git a/datafusion/expr/src/groups_accumulator.rs b/datafusion/expr-common/src/groups_accumulator.rs similarity index 56% rename from datafusion/expr/src/groups_accumulator.rs rename to datafusion/expr-common/src/groups_accumulator.rs index 2ffbfb266e9c..2c8b126cb52c 100644 --- a/datafusion/expr/src/groups_accumulator.rs +++ b/datafusion/expr-common/src/groups_accumulator.rs @@ -17,8 +17,8 @@ //! Vectorized [`GroupsAccumulator`] -use arrow_array::{ArrayRef, BooleanArray}; -use datafusion_common::Result; +use arrow::array::{ArrayRef, BooleanArray}; +use datafusion_common::{not_impl_err, Result}; /// Describes how many rows should be emitted during grouping. #[derive(Debug, Clone, Copy)] @@ -29,7 +29,7 @@ pub enum EmitTo { /// indexes down by `n`. /// /// For example, if `n=10`, group_index `0, 1, ... 9` are emitted - /// and group indexes '`10, 11, 12, ...` become `0, 1, 2, ...`. + /// and group indexes `10, 11, 12, ...` become `0, 1, 2, ...`. First(usize), } @@ -56,9 +56,32 @@ impl EmitTo { } } -/// `GroupAccumulator` implements a single aggregate (e.g. AVG) and +/// `GroupsAccumulator` implements a single aggregate (e.g. AVG) and /// stores the state for *all* groups internally. /// +/// Logically, a [`GroupsAccumulator`] stores a mapping from each group index to +/// the state of the aggregate for that group. For example an implementation for +/// `min` might look like +/// +/// ```text +/// ┌─────┐ +/// │ 0 │───────────▶ 100 +/// ├─────┤ +/// │ 1 │───────────▶ 200 +/// └─────┘ +/// ... ... +/// ┌─────┐ +/// │ N-2 │───────────▶ 50 +/// ├─────┤ +/// │ N-1 │───────────▶ 200 +/// └─────┘ +/// +/// +/// Logical group Current Min +/// number value for that +/// group +/// ``` +/// /// # Notes on Implementing `GroupAccumulator` /// /// All aggregates must first implement the simpler [`Accumulator`] trait, which @@ -67,6 +90,11 @@ impl EmitTo { /// faster for queries with many group values. See the [Aggregating Millions of /// Groups Fast blog] for more background. /// +/// [`NullState`] can help keep the state for groups that have not seen any +/// values and produce the correct output for those groups. +/// +/// [`NullState`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/struct.NullState.html +/// /// # Details /// Each group is assigned a `group_index` by the hash table and each /// accumulator manages the specific state, one per `group_index`. @@ -75,7 +103,7 @@ impl EmitTo { /// expected that each `GroupAccumulator` will use something like `Vec<..>` /// to store the group states. /// -/// [`Accumulator`]: crate::Accumulator +/// [`Accumulator`]: crate::accumulator::Accumulator /// [Aggregating Millions of Groups Fast blog]: https://arrow.apache.org/blog/2023/08/05/datafusion_fast_grouping/ pub trait GroupsAccumulator: Send { /// Updates the accumulator's state from its arguments, encoded as @@ -83,17 +111,21 @@ pub trait GroupsAccumulator: Send { /// /// * `values`: the input arguments to the accumulator /// - /// * `group_indices`: To which groups do the rows in `values` - /// belong, group id) + /// * `group_indices`: The group indices to which each row in `values` belongs. /// /// * `opt_filter`: if present, only update aggregate state using - /// `values[i]` if `opt_filter[i]` is true + /// `values[i]` if `opt_filter[i]` is true /// /// * `total_num_groups`: the number of groups (the largest - /// group_index is thus `total_num_groups - 1`). + /// group_index is thus `total_num_groups - 1`). /// /// Note that subsequent calls to update_batch may have larger /// total_num_groups as new groups are seen. + /// + /// See [`NullState`] to help keep the state for groups that have not seen any + /// values and produce the correct output for those groups. + /// + /// [`NullState`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/struct.NullState.html fn update_batch( &mut self, values: &[ArrayRef], @@ -113,7 +145,7 @@ pub trait GroupsAccumulator: Send { /// each group, and `evaluate` will produce that running sum as /// its output for all groups, in group_index order /// - /// If `emit_to`` is [`EmitTo::All`], the accumulator should + /// If `emit_to` is [`EmitTo::All`], the accumulator should /// return all groups and release / reset its internal state /// equivalent to when it was first created. /// @@ -128,6 +160,9 @@ pub trait GroupsAccumulator: Send { /// Returns the intermediate aggregate state for this accumulator, /// used for multi-phase grouping, resetting its internal state. /// + /// See [`Accumulator::state`] for more information on multi-phase + /// aggregation. + /// /// For example, `AVG` might return two arrays: `SUM` and `COUNT` /// but the `MIN` aggregate would just return a single array. /// @@ -135,11 +170,13 @@ pub trait GroupsAccumulator: Send { /// single `StructArray` rather than multiple arrays. /// /// See [`Self::evaluate`] for details on the required output - /// order and `emit_to`. + /// order and `emit_to`. + /// + /// [`Accumulator::state`]: crate::accumulator::Accumulator::state fn state(&mut self, emit_to: EmitTo) -> Result>; /// Merges intermediate state (the output from [`Self::state`]) - /// into this accumulator's values. + /// into this accumulator's current state. /// /// For some aggregates (such as `SUM`), `merge_batch` is the same /// as `update_batch`, but for some aggregates (such as `COUNT`, @@ -147,9 +184,9 @@ pub trait GroupsAccumulator: Send { /// differ. See [`Self::state`] for more details on how state is /// used and merged. /// - /// * `values`: arrays produced from calling `state` previously to the accumulator + /// * `values`: arrays produced from previously calling `state` on other accumulators. /// - /// Other arguments are the same as for [`Self::update_batch`]; + /// Other arguments are the same as for [`Self::update_batch`]. fn merge_batch( &mut self, values: &[ArrayRef], @@ -158,8 +195,59 @@ pub trait GroupsAccumulator: Send { total_num_groups: usize, ) -> Result<()>; + /// Converts an input batch directly to the intermediate aggregate state. + /// + /// This is the equivalent of treating each input row as its own group. It + /// is invoked when the Partial phase of a multi-phase aggregation is not + /// reducing the cardinality enough to warrant spending more effort on + /// pre-aggregation (see `Background` section below), and switches to + /// passing intermediate state directly on to the next aggregation phase. + /// + /// Examples: + /// * `COUNT`: an array of 1s for each row in the input batch. + /// * `SUM/MIN/MAX`: the input values themselves. + /// + /// # Arguments + /// * `values`: the input arguments to the accumulator + /// * `opt_filter`: if present, any row where `opt_filter[i]` is false should be ignored + /// + /// # Background + /// + /// In a multi-phase aggregation (see [`Accumulator::state`]), the initial + /// Partial phase reduces the cardinality of the input data as soon as + /// possible in the plan. + /// + /// This strategy is very effective for queries with a small number of + /// groups, as most of the data is aggregated immediately and only a small + /// amount of data must be repartitioned (see [`Accumulator::state`] for + /// background) + /// + /// However, for queries with a large number of groups, the Partial phase + /// often does not reduce the cardinality enough to warrant the memory and + /// CPU cost of actually performing the aggregation. For such cases, the + /// HashAggregate operator will dynamically switch to passing intermediate + /// state directly to the next aggregation phase with minimal processing + /// using this method. + /// + /// [`Accumulator::state`]: crate::accumulator::Accumulator::state + fn convert_to_state( + &self, + _values: &[ArrayRef], + _opt_filter: Option<&BooleanArray>, + ) -> Result> { + not_impl_err!("Input batch conversion to state not implemented") + } + + /// Returns `true` if [`Self::convert_to_state`] is implemented to support + /// intermediate aggregate state conversion. + fn supports_convert_to_state(&self) -> bool { + false + } + /// Amount of memory used to store the state of this accumulator, - /// in bytes. This function is called once per batch, so it should - /// be `O(n)` to compute, not `O(num_groups)` + /// in bytes. + /// + /// This function is called once per batch, so it should be `O(n)` to + /// compute, not `O(num_groups)` fn size(&self) -> usize; } diff --git a/datafusion/expr/src/interval_arithmetic.rs b/datafusion/expr-common/src/interval_arithmetic.rs similarity index 94% rename from datafusion/expr/src/interval_arithmetic.rs rename to datafusion/expr-common/src/interval_arithmetic.rs index ca91a8c9da00..ffaa32f08075 100644 --- a/datafusion/expr/src/interval_arithmetic.rs +++ b/datafusion/expr-common/src/interval_arithmetic.rs @@ -17,16 +17,16 @@ //! Interval arithmetic library +use crate::operator::Operator; +use crate::type_coercion::binary::get_result_type; use std::borrow::Borrow; use std::fmt::{self, Display, Formatter}; use std::ops::{AddAssign, SubAssign}; -use crate::type_coercion::binary::get_result_type; -use crate::Operator; - use arrow::compute::{cast_with_options, CastOptions}; -use arrow::datatypes::DataType; -use arrow::datatypes::{IntervalUnit, TimeUnit}; +use arrow::datatypes::{ + DataType, IntervalDayTime, IntervalMonthDayNano, IntervalUnit, TimeUnit, +}; use datafusion_common::rounding::{alter_fp_rounding_mode, next_down, next_up}; use datafusion_common::{internal_err, Result, ScalarValue}; @@ -71,10 +71,10 @@ macro_rules! get_extreme_value { ScalarValue::IntervalYearMonth(Some(i32::$extreme)) } DataType::Interval(IntervalUnit::DayTime) => { - ScalarValue::IntervalDayTime(Some(i64::$extreme)) + ScalarValue::IntervalDayTime(Some(IntervalDayTime::$extreme)) } DataType::Interval(IntervalUnit::MonthDayNano) => { - ScalarValue::IntervalMonthDayNano(Some(i128::$extreme)) + ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano::$extreme)) } _ => unreachable!(), } @@ -119,8 +119,14 @@ macro_rules! value_transition { IntervalYearMonth(Some(value)) if value == i32::$bound => { IntervalYearMonth(None) } - IntervalDayTime(Some(value)) if value == i64::$bound => IntervalDayTime(None), - IntervalMonthDayNano(Some(value)) if value == i128::$bound => { + IntervalDayTime(Some(value)) + if value == arrow::datatypes::IntervalDayTime::$bound => + { + IntervalDayTime(None) + } + IntervalMonthDayNano(Some(value)) + if value == arrow::datatypes::IntervalMonthDayNano::$bound => + { IntervalMonthDayNano(None) } _ => next_value_helper::<$direction>($value), @@ -273,19 +279,34 @@ impl Interval { unreachable!(); }; // Standardize boolean interval endpoints: - Self { + return Self { lower: ScalarValue::Boolean(Some(lower_bool.unwrap_or(false))), upper: ScalarValue::Boolean(Some(upper_bool.unwrap_or(true))), - } + }; } - // Standardize floating-point endpoints: - else if lower.data_type() == DataType::Float32 { - handle_float_intervals!(Float32, f32, lower, upper) - } else if lower.data_type() == DataType::Float64 { - handle_float_intervals!(Float64, f64, lower, upper) - } else { + match lower.data_type() { + // Standardize floating-point endpoints: + DataType::Float32 => handle_float_intervals!(Float32, f32, lower, upper), + DataType::Float64 => handle_float_intervals!(Float64, f64, lower, upper), + // Unsigned null values for lower bounds are set to zero: + DataType::UInt8 if lower.is_null() => Self { + lower: ScalarValue::UInt8(Some(0)), + upper, + }, + DataType::UInt16 if lower.is_null() => Self { + lower: ScalarValue::UInt16(Some(0)), + upper, + }, + DataType::UInt32 if lower.is_null() => Self { + lower: ScalarValue::UInt32(Some(0)), + upper, + }, + DataType::UInt64 if lower.is_null() => Self { + lower: ScalarValue::UInt64(Some(0)), + upper, + }, // Other data types do not require standardization: - Self { lower, upper } + _ => Self { lower, upper }, } } @@ -299,12 +320,50 @@ impl Interval { Self::try_new(ScalarValue::from(lower), ScalarValue::from(upper)) } + /// Creates a singleton zero interval if the datatype supported. + pub fn make_zero(data_type: &DataType) -> Result { + let zero_endpoint = ScalarValue::new_zero(data_type)?; + Ok(Self::new(zero_endpoint.clone(), zero_endpoint)) + } + /// Creates an unbounded interval from both sides if the datatype supported. pub fn make_unbounded(data_type: &DataType) -> Result { let unbounded_endpoint = ScalarValue::try_from(data_type)?; Ok(Self::new(unbounded_endpoint.clone(), unbounded_endpoint)) } + /// Creates an interval between -1 to 1. + pub fn make_symmetric_unit_interval(data_type: &DataType) -> Result { + Self::try_new( + ScalarValue::new_negative_one(data_type)?, + ScalarValue::new_one(data_type)?, + ) + } + + /// Create an interval from -π to π. + pub fn make_symmetric_pi_interval(data_type: &DataType) -> Result { + Self::try_new( + ScalarValue::new_negative_pi_lower(data_type)?, + ScalarValue::new_pi_upper(data_type)?, + ) + } + + /// Create an interval from -π/2 to π/2. + pub fn make_symmetric_half_pi_interval(data_type: &DataType) -> Result { + Self::try_new( + ScalarValue::new_neg_frac_pi_2_lower(data_type)?, + ScalarValue::new_frac_pi_2_upper(data_type)?, + ) + } + + /// Create an interval from 0 to infinity. + pub fn make_non_negative_infinity_interval(data_type: &DataType) -> Result { + Self::try_new( + ScalarValue::new_zero(data_type)?, + ScalarValue::try_from(data_type)?, + ) + } + /// Returns a reference to the lower bound. pub fn lower(&self) -> &ScalarValue { &self.lower @@ -369,7 +428,7 @@ impl Interval { /// NOTE: This function only works with intervals of the same data type. /// Attempting to compare intervals of different data types will lead /// to an error. - pub(crate) fn gt>(&self, other: T) -> Result { + pub fn gt>(&self, other: T) -> Result { let rhs = other.borrow(); if self.data_type().ne(&rhs.data_type()) { internal_err!( @@ -402,7 +461,7 @@ impl Interval { /// NOTE: This function only works with intervals of the same data type. /// Attempting to compare intervals of different data types will lead /// to an error. - pub(crate) fn gt_eq>(&self, other: T) -> Result { + pub fn gt_eq>(&self, other: T) -> Result { let rhs = other.borrow(); if self.data_type().ne(&rhs.data_type()) { internal_err!( @@ -435,7 +494,7 @@ impl Interval { /// NOTE: This function only works with intervals of the same data type. /// Attempting to compare intervals of different data types will lead /// to an error. - pub(crate) fn lt>(&self, other: T) -> Result { + pub fn lt>(&self, other: T) -> Result { other.borrow().gt(self) } @@ -446,7 +505,7 @@ impl Interval { /// NOTE: This function only works with intervals of the same data type. /// Attempting to compare intervals of different data types will lead /// to an error. - pub(crate) fn lt_eq>(&self, other: T) -> Result { + pub fn lt_eq>(&self, other: T) -> Result { other.borrow().gt_eq(self) } @@ -457,7 +516,7 @@ impl Interval { /// NOTE: This function only works with intervals of the same data type. /// Attempting to compare intervals of different data types will lead /// to an error. - pub(crate) fn equal>(&self, other: T) -> Result { + pub fn equal>(&self, other: T) -> Result { let rhs = other.borrow(); if get_result_type(&self.data_type(), &Operator::Eq, &rhs.data_type()).is_err() { internal_err!( @@ -480,7 +539,7 @@ impl Interval { /// Compute the logical conjunction of this (boolean) interval with the /// given boolean interval. - pub(crate) fn and>(&self, other: T) -> Result { + pub fn and>(&self, other: T) -> Result { let rhs = other.borrow(); match (&self.lower, &self.upper, &rhs.lower, &rhs.upper) { ( @@ -501,8 +560,31 @@ impl Interval { } } + /// Compute the logical disjunction of this boolean interval with the + /// given boolean interval. + pub fn or>(&self, other: T) -> Result { + let rhs = other.borrow(); + match (&self.lower, &self.upper, &rhs.lower, &rhs.upper) { + ( + &ScalarValue::Boolean(Some(self_lower)), + &ScalarValue::Boolean(Some(self_upper)), + &ScalarValue::Boolean(Some(other_lower)), + &ScalarValue::Boolean(Some(other_upper)), + ) => { + let lower = self_lower || other_lower; + let upper = self_upper || other_upper; + + Ok(Self { + lower: ScalarValue::Boolean(Some(lower)), + upper: ScalarValue::Boolean(Some(upper)), + }) + } + _ => internal_err!("Incompatible data types for logical conjunction"), + } + } + /// Compute the logical negation of this (boolean) interval. - pub(crate) fn not(&self) -> Result { + pub fn not(&self) -> Result { if self.data_type().ne(&DataType::Boolean) { internal_err!("Cannot apply logical negation to a non-boolean interval") } else if self == &Self::CERTAINLY_TRUE { @@ -761,6 +843,18 @@ impl Interval { } .map(|result| result + 1) } + + /// Reflects an [`Interval`] around the point zero. + /// + /// This method computes the arithmetic negation of the interval, reflecting + /// it about the origin of the number line. This operation swaps and negates + /// the lower and upper bounds of the interval. + pub fn arithmetic_negate(self) -> Result { + Ok(Self { + lower: self.upper().clone().arithmetic_negate()?, + upper: self.lower().clone().arithmetic_negate()?, + }) + } } impl Display for Interval { @@ -895,7 +989,8 @@ fn div_bounds( /// results are converted to an *unbounded endpoint* if: /// - We are calculating an upper bound and we have a positive overflow. /// - We are calculating a lower bound and we have a negative overflow. -/// Otherwise; the function sets the endpoint as: +/// +/// Otherwise, the function sets the endpoint as: /// - The minimum representable number with the given datatype (`dt`) if /// we are calculating an upper bound and we have a negative overflow. /// - The maximum representable number with the given datatype (`dt`) if @@ -957,6 +1052,25 @@ macro_rules! impl_OneTrait{ } impl_OneTrait! {u8, u16, u32, u64, i8, i16, i32, i64, i128} +impl OneTrait for IntervalDayTime { + fn one() -> Self { + IntervalDayTime { + days: 0, + milliseconds: 1, + } + } +} + +impl OneTrait for IntervalMonthDayNano { + fn one() -> Self { + IntervalMonthDayNano { + months: 0, + days: 0, + nanoseconds: 1, + } + } +} + /// This function either increments or decrements its argument, depending on /// the `INC` value (where a `true` value corresponds to the increment). fn increment_decrement( @@ -1019,11 +1133,15 @@ fn next_value_helper(value: ScalarValue) -> ScalarValue { IntervalYearMonth(Some(val)) => { IntervalYearMonth(Some(increment_decrement::(val))) } - IntervalDayTime(Some(val)) => { - IntervalDayTime(Some(increment_decrement::(val))) - } + IntervalDayTime(Some(val)) => IntervalDayTime(Some(increment_decrement::< + INC, + arrow::datatypes::IntervalDayTime, + >(val))), IntervalMonthDayNano(Some(val)) => { - IntervalMonthDayNano(Some(increment_decrement::(val))) + IntervalMonthDayNano(Some(increment_decrement::< + INC, + arrow::datatypes::IntervalMonthDayNano, + >(val))) } _ => value, // Unbounded values return without change. } @@ -1059,7 +1177,7 @@ fn min_of_bounds(first: &ScalarValue, second: &ScalarValue) -> ScalarValue { /// Example usage: /// ``` /// use datafusion_common::DataFusionError; -/// use datafusion_expr::interval_arithmetic::{satisfy_greater, Interval}; +/// use datafusion_expr_common::interval_arithmetic::{satisfy_greater, Interval}; /// /// let left = Interval::make(Some(-1000.0_f32), Some(1000.0_f32))?; /// let right = Interval::make(Some(500.0_f32), Some(2000.0_f32))?; @@ -1105,8 +1223,8 @@ pub fn satisfy_greater( } } - // Only the lower bound of left hand side and the upper bound of the right - // hand side can change after propagating the greater-than operation. + // Only the lower bound of left-hand side and the upper bound of the right-hand + // side can change after propagating the greater-than operation. let new_left_lower = if left.lower.is_null() || left.lower <= right.lower { if strict { next_value(right.lower.clone()) @@ -1434,8 +1552,8 @@ fn cast_scalar_value( /// ``` /// use arrow::datatypes::DataType; /// use datafusion_common::ScalarValue; -/// use datafusion_expr::interval_arithmetic::Interval; -/// use datafusion_expr::interval_arithmetic::NullableInterval; +/// use datafusion_expr_common::interval_arithmetic::Interval; +/// use datafusion_expr_common::interval_arithmetic::NullableInterval; /// /// // [1, 2) U {NULL} /// let maybe_null = NullableInterval::MaybeNull { @@ -1556,9 +1674,9 @@ impl NullableInterval { /// /// ``` /// use datafusion_common::ScalarValue; - /// use datafusion_expr::Operator; - /// use datafusion_expr::interval_arithmetic::Interval; - /// use datafusion_expr::interval_arithmetic::NullableInterval; + /// use datafusion_expr_common::operator::Operator; + /// use datafusion_expr_common::interval_arithmetic::Interval; + /// use datafusion_expr_common::interval_arithmetic::NullableInterval; /// /// // 4 > 3 -> true /// let lhs = NullableInterval::from(ScalarValue::Int32(Some(4))); @@ -1635,7 +1753,7 @@ impl NullableInterval { } _ => Ok(Self::MaybeNull { values }), } - } else if op.is_comparison_operator() { + } else if op.supports_propagation() { Ok(Self::Null { datatype: DataType::Boolean, }) @@ -1680,8 +1798,8 @@ impl NullableInterval { /// /// ``` /// use datafusion_common::ScalarValue; - /// use datafusion_expr::interval_arithmetic::Interval; - /// use datafusion_expr::interval_arithmetic::NullableInterval; + /// use datafusion_expr_common::interval_arithmetic::Interval; + /// use datafusion_expr_common::interval_arithmetic::NullableInterval; /// /// let interval = NullableInterval::from(ScalarValue::Int32(Some(4))); /// assert_eq!(interval.single_value(), Some(ScalarValue::Int32(Some(4)))); @@ -1759,11 +1877,7 @@ mod tests { .sub(value.clone()) .unwrap() .lt(&eps)); - assert!(value - .clone() - .sub(prev_value(value.clone())) - .unwrap() - .lt(&eps)); + assert!(value.sub(prev_value(value.clone())).unwrap().lt(&eps)); assert_ne!(next_value(value.clone()), value); assert_ne!(prev_value(value.clone()), value); }); @@ -1795,11 +1909,11 @@ mod tests { min_max.into_iter().zip(inf).for_each(|((min, max), inf)| { assert_eq!(next_value(max.clone()), inf); assert_ne!(prev_value(max.clone()), max); - assert_ne!(prev_value(max.clone()), inf); + assert_ne!(prev_value(max), inf); assert_eq!(prev_value(min.clone()), inf); assert_ne!(next_value(min.clone()), min); - assert_ne!(next_value(min.clone()), inf); + assert_ne!(next_value(min), inf); assert_eq!(next_value(inf.clone()), inf); assert_eq!(prev_value(inf.clone()), inf); @@ -1885,10 +1999,10 @@ mod tests { let unbounded_cases = vec![ (DataType::Boolean, Boolean(Some(false)), Boolean(Some(true))), - (DataType::UInt8, UInt8(None), UInt8(None)), - (DataType::UInt16, UInt16(None), UInt16(None)), - (DataType::UInt32, UInt32(None), UInt32(None)), - (DataType::UInt64, UInt64(None), UInt64(None)), + (DataType::UInt8, UInt8(Some(0)), UInt8(None)), + (DataType::UInt16, UInt16(Some(0)), UInt16(None)), + (DataType::UInt32, UInt32(Some(0)), UInt32(None)), + (DataType::UInt64, UInt64(Some(0)), UInt64(None)), (DataType::Int8, Int8(None), Int8(None)), (DataType::Int16, Int16(None), Int16(None)), (DataType::Int32, Int32(None), Int32(None)), @@ -1994,6 +2108,10 @@ mod tests { Interval::make(None, Some(1000_i64))?, Interval::make(Some(1000_i64), Some(1500_i64))?, ), + ( + Interval::make(Some(0_u8), Some(0_u8))?, + Interval::make::(None, None)?, + ), ( Interval::try_new( prev_value(ScalarValue::Float32(Some(0.0_f32))), @@ -2036,6 +2154,10 @@ mod tests { Interval::make(Some(-1000_i64), Some(1000_i64))?, Interval::make(None, Some(-1500_i64))?, ), + ( + Interval::make::(None, None)?, + Interval::make(Some(0_u64), Some(0_u64))?, + ), ( Interval::make(Some(0.0_f32), Some(0.0_f32))?, Interval::make(Some(0.0_f32), Some(0.0_f32))?, diff --git a/datafusion/expr-common/src/lib.rs b/datafusion/expr-common/src/lib.rs new file mode 100644 index 000000000000..179dd75ace85 --- /dev/null +++ b/datafusion/expr-common/src/lib.rs @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Logical Expr types and traits for [DataFusion] +//! +//! This crate contains types and traits that are used by both Logical and Physical expressions. +//! They are kept in their own crate to avoid physical expressions depending on logical expressions. +//! +//! +//! [DataFusion]: + +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] + +pub mod accumulator; +pub mod columnar_value; +pub mod groups_accumulator; +pub mod interval_arithmetic; +pub mod operator; +pub mod signature; +pub mod sort_properties; +pub mod type_coercion; diff --git a/datafusion/expr/src/operator.rs b/datafusion/expr-common/src/operator.rs similarity index 63% rename from datafusion/expr/src/operator.rs rename to datafusion/expr-common/src/operator.rs index a10312e23446..6ca0f04897ac 100644 --- a/datafusion/expr/src/operator.rs +++ b/datafusion/expr-common/src/operator.rs @@ -15,14 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Operator module contains foundational types that are used to represent operators in DataFusion. - -use crate::expr_fn::binary_expr; -use crate::Expr; -use crate::Like; use std::fmt; -use std::ops; -use std::ops::Not; /// Operators applied to expressions #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Hash)] @@ -149,10 +142,11 @@ impl Operator { ) } - /// Return true if the operator is a comparison operator. + /// Return true if the comparison operator can be used in interval arithmetic and constraint + /// propagation /// - /// For example, 'Binary(a, >, b)' would be a comparison expression. - pub fn is_comparison_operator(&self) -> bool { + /// For example, 'Binary(a, >, b)' expression supports propagation. + pub fn supports_propagation(&self) -> bool { matches!( self, Operator::Eq @@ -170,6 +164,15 @@ impl Operator { ) } + /// Return true if the comparison operator can be used in interval arithmetic and constraint + /// propagation + /// + /// For example, 'Binary(a, >, b)' expression supports propagation. + #[deprecated(since = "43.0.0", note = "please use `supports_propagation` instead")] + pub fn is_comparison_operator(&self) -> bool { + self.supports_propagation() + } + /// Return true if the operator is a logic operator. /// /// For example, 'Binary(Binary(a, >, b), AND, Binary(a, <, b + 3))' would @@ -218,29 +221,23 @@ impl Operator { } /// Get the operator precedence - /// use as a reference + /// use as a reference pub fn precedence(&self) -> u8 { match self { Operator::Or => 5, Operator::And => 10, - Operator::NotEq - | Operator::Eq - | Operator::Lt - | Operator::LtEq - | Operator::Gt - | Operator::GtEq => 20, - Operator::Plus | Operator::Minus => 30, - Operator::Multiply | Operator::Divide | Operator::Modulo => 40, + Operator::Eq | Operator::NotEq | Operator::LtEq | Operator::GtEq => 15, + Operator::Lt | Operator::Gt => 20, + Operator::LikeMatch + | Operator::NotLikeMatch + | Operator::ILikeMatch + | Operator::NotILikeMatch => 25, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom | Operator::RegexMatch | Operator::RegexNotMatch | Operator::RegexIMatch | Operator::RegexNotIMatch - | Operator::LikeMatch - | Operator::ILikeMatch - | Operator::NotLikeMatch - | Operator::NotILikeMatch | Operator::BitwiseAnd | Operator::BitwiseOr | Operator::BitwiseShiftLeft @@ -248,7 +245,9 @@ impl Operator { | Operator::BitwiseXor | Operator::StringConcat | Operator::AtArrow - | Operator::ArrowAt => 0, + | Operator::ArrowAt => 30, + Operator::Plus | Operator::Minus => 40, + Operator::Multiply | Operator::Divide | Operator::Modulo => 45, } } } @@ -291,202 +290,3 @@ impl fmt::Display for Operator { write!(f, "{display}") } } - -/// Support ` + ` fluent style -impl ops::Add for Expr { - type Output = Self; - - fn add(self, rhs: Self) -> Self { - binary_expr(self, Operator::Plus, rhs) - } -} - -/// Support ` - ` fluent style -impl ops::Sub for Expr { - type Output = Self; - - fn sub(self, rhs: Self) -> Self { - binary_expr(self, Operator::Minus, rhs) - } -} - -/// Support ` * ` fluent style -impl ops::Mul for Expr { - type Output = Self; - - fn mul(self, rhs: Self) -> Self { - binary_expr(self, Operator::Multiply, rhs) - } -} - -/// Support ` / ` fluent style -impl ops::Div for Expr { - type Output = Self; - - fn div(self, rhs: Self) -> Self { - binary_expr(self, Operator::Divide, rhs) - } -} - -/// Support ` % ` fluent style -impl ops::Rem for Expr { - type Output = Self; - - fn rem(self, rhs: Self) -> Self { - binary_expr(self, Operator::Modulo, rhs) - } -} - -/// Support ` & ` fluent style -impl ops::BitAnd for Expr { - type Output = Self; - - fn bitand(self, rhs: Self) -> Self { - binary_expr(self, Operator::BitwiseAnd, rhs) - } -} - -/// Support ` | ` fluent style -impl ops::BitOr for Expr { - type Output = Self; - - fn bitor(self, rhs: Self) -> Self { - binary_expr(self, Operator::BitwiseOr, rhs) - } -} - -/// Support ` ^ ` fluent style -impl ops::BitXor for Expr { - type Output = Self; - - fn bitxor(self, rhs: Self) -> Self { - binary_expr(self, Operator::BitwiseXor, rhs) - } -} - -/// Support ` << ` fluent style -impl ops::Shl for Expr { - type Output = Self; - - fn shl(self, rhs: Self) -> Self::Output { - binary_expr(self, Operator::BitwiseShiftLeft, rhs) - } -} - -/// Support ` >> ` fluent style -impl ops::Shr for Expr { - type Output = Self; - - fn shr(self, rhs: Self) -> Self::Output { - binary_expr(self, Operator::BitwiseShiftRight, rhs) - } -} - -/// Support `- ` fluent style -impl ops::Neg for Expr { - type Output = Self; - - fn neg(self) -> Self::Output { - Expr::Negative(Box::new(self)) - } -} - -/// Support `NOT ` fluent style -impl Not for Expr { - type Output = Self; - - fn not(self) -> Self::Output { - match self { - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => Expr::Like(Like::new( - !negated, - expr, - pattern, - escape_char, - case_insensitive, - )), - Expr::SimilarTo(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => Expr::SimilarTo(Like::new( - !negated, - expr, - pattern, - escape_char, - case_insensitive, - )), - _ => Expr::Not(Box::new(self)), - } - } -} - -#[cfg(test)] -mod tests { - use crate::lit; - - #[test] - fn test_operators() { - // Add - assert_eq!( - format!("{}", lit(1u32) + lit(2u32)), - "UInt32(1) + UInt32(2)" - ); - // Sub - assert_eq!( - format!("{}", lit(1u32) - lit(2u32)), - "UInt32(1) - UInt32(2)" - ); - // Mul - assert_eq!( - format!("{}", lit(1u32) * lit(2u32)), - "UInt32(1) * UInt32(2)" - ); - // Div - assert_eq!( - format!("{}", lit(1u32) / lit(2u32)), - "UInt32(1) / UInt32(2)" - ); - // Rem - assert_eq!( - format!("{}", lit(1u32) % lit(2u32)), - "UInt32(1) % UInt32(2)" - ); - // BitAnd - assert_eq!( - format!("{}", lit(1u32) & lit(2u32)), - "UInt32(1) & UInt32(2)" - ); - // BitOr - assert_eq!( - format!("{}", lit(1u32) | lit(2u32)), - "UInt32(1) | UInt32(2)" - ); - // BitXor - assert_eq!( - format!("{}", lit(1u32) ^ lit(2u32)), - "UInt32(1) BIT_XOR UInt32(2)" - ); - // Shl - assert_eq!( - format!("{}", lit(1u32) << lit(2u32)), - "UInt32(1) << UInt32(2)" - ); - // Shr - assert_eq!( - format!("{}", lit(1u32) >> lit(2u32)), - "UInt32(1) >> UInt32(2)" - ); - // Neg - assert_eq!(format!("{}", -lit(1u32)), "(- UInt32(1))"); - // Not - assert_eq!(format!("{}", !lit(1u32)), "NOT UInt32(1)"); - } -} diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr-common/src/signature.rs similarity index 78% rename from datafusion/expr/src/signature.rs rename to datafusion/expr-common/src/signature.rs index e2505d6fd65f..24cb54f634b1 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr-common/src/signature.rs @@ -35,7 +35,7 @@ pub const TIMEZONE_WILDCARD: &str = "+TZ"; /// valid length. It exists to avoid the need to enumerate all possible fixed size list lengths. pub const FIXED_SIZE_LIST_WILDCARD: i32 = i32::MIN; -///A function's volatility, which defines the functions eligibility for certain optimizations +/// A function's volatility, which defines the functions eligibility for certain optimizations #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] pub enum Volatility { /// An immutable function will always return the same output when given the same @@ -43,8 +43,8 @@ pub enum Volatility { Immutable, /// A stable function may return different values given the same input across different /// queries but must return the same value for a given input within a query. An example of - /// this is the `Now` function. DataFusion - /// will attempt to inline `Stable` functions during planning, when possible. + /// this is the `Now` function. DataFusion will attempt to inline `Stable` functions + /// during planning, when possible. /// For query `select col1, now() from t1`, it might take a while to execute but /// `now()` column will be the same for each output row, which is evaluated /// during planning. @@ -65,7 +65,7 @@ pub enum Volatility { /// automatically coerces (add casts to) function arguments so they match the type signature. /// /// For example, a function like `cos` may only be implemented for `Float64` arguments. To support a query -/// that calles `cos` with a different argument type, such as `cos(int_column)`, type coercion automatically +/// that calls `cos` with a different argument type, such as `cos(int_column)`, type coercion automatically /// adds a cast such as `cos(CAST int_column AS DOUBLE)` during planning. /// /// # Data Types @@ -75,7 +75,7 @@ pub enum Volatility { /// /// ``` /// # use arrow::datatypes::{DataType, TimeUnit}; -/// # use datafusion_expr::{TIMEZONE_WILDCARD, TypeSignature}; +/// # use datafusion_expr_common::signature::{TIMEZONE_WILDCARD, TypeSignature}; /// let type_signature = TypeSignature::Exact(vec![ /// // A nanosecond precision timestamp with ANY timezone /// // matches Timestamp(Nanosecond, Some("+0:00")) @@ -84,22 +84,17 @@ pub enum Volatility { /// DataType::Timestamp(TimeUnit::Nanosecond, Some(TIMEZONE_WILDCARD.into())), /// ]); /// ``` -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum TypeSignature { - /// One or more arguments of an common type out of a list of valid types. + /// One or more arguments of a common type out of a list of valid types. /// /// # Examples /// A function such as `concat` is `Variadic(vec![DataType::Utf8, DataType::LargeUtf8])` Variadic(Vec), - /// One or more arguments of an arbitrary but equal type. - /// DataFusion attempts to coerce all argument types to match the first argument's type - /// - /// # Examples - /// Given types in signature should be coercible to the same final type. - /// A function such as `make_array` is `VariadicEqual`. - /// - /// `make_array(i32, i64) -> make_array(i64, i64)` - VariadicEqual, + /// The acceptable signature and coercions rules to coerce arguments to this + /// signature are special for this function. If this signature is specified, + /// DataFusion will call `ScalarUDFImpl::coerce_types` to prepare argument types. + UserDefined, /// One or more arguments with arbitrary types VariadicAny, /// Fixed number of arguments of an arbitrary but equal type out of a list of valid types. @@ -110,6 +105,11 @@ pub enum TypeSignature { Uniform(usize, Vec), /// Exact number of arguments of an exact type Exact(Vec), + /// The number of arguments that can be coerced to in order + /// For example, `Coercible(vec![DataType::Float64])` accepts + /// arguments like `vec![DataType::Int32]` or `vec![DataType::Float32]` + /// since i32 and f32 can be casted to f64 + Coercible(Vec), /// Fixed number of arguments of arbitrary types /// If a function takes 0 argument, its `TypeSignature` should be `Any(0)` Any(usize), @@ -122,9 +122,17 @@ pub enum TypeSignature { OneOf(Vec), /// Specifies Signatures for array functions ArraySignature(ArrayFunctionSignature), + /// Fixed number of arguments of numeric types. + /// See to know which type is considered numeric + Numeric(usize), + /// Fixed number of arguments of all the same string types. + /// The precedence of type from high to low is Utf8View, LargeUtf8 and Utf8. + /// Null is considerd as `Utf8` by default + /// Dictionary with string value type is also handled. + String(usize), } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum ArrayFunctionSignature { /// Specialized Signature for ArrayAppend and similar functions /// The first argument should be List/LargeList/FixedSizedList, and the second argument should be non-list or list. @@ -145,6 +153,9 @@ pub enum ArrayFunctionSignature { /// The function takes a single argument that must be a List/LargeList/FixedSizeList /// or something that can be coerced to one of those types. Array, + /// Specialized Signature for MapArray + /// The function takes a single argument that must be a MapArray + MapArray, } impl std::fmt::Display for ArrayFunctionSignature { @@ -165,12 +176,15 @@ impl std::fmt::Display for ArrayFunctionSignature { ArrayFunctionSignature::Array => { write!(f, "array") } + ArrayFunctionSignature::MapArray => { + write!(f, "map_array") + } } } } impl TypeSignature { - pub(crate) fn to_string_repr(&self) -> Vec { + pub fn to_string_repr(&self) -> Vec { match self { TypeSignature::Variadic(types) => { vec![format!("{}, ..", Self::join_types(types, "/"))] @@ -181,7 +195,13 @@ impl TypeSignature { .collect::>() .join(", ")] } - TypeSignature::Exact(types) => { + TypeSignature::String(num) => { + vec![format!("String({num})")] + } + TypeSignature::Numeric(num) => { + vec![format!("Numeric({num})")] + } + TypeSignature::Exact(types) | TypeSignature::Coercible(types) => { vec![Self::join_types(types, ", ")] } TypeSignature::Any(arg_count) => { @@ -190,8 +210,8 @@ impl TypeSignature { .collect::>() .join(", ")] } - TypeSignature::VariadicEqual => { - vec!["CoercibleT, .., CoercibleT".to_string()] + TypeSignature::UserDefined => { + vec!["UserDefined".to_string()] } TypeSignature::VariadicAny => vec!["Any, .., Any".to_string()], TypeSignature::OneOf(sigs) => { @@ -204,10 +224,7 @@ impl TypeSignature { } /// Helper function to join types with specified delimiter. - pub(crate) fn join_types( - types: &[T], - delimiter: &str, - ) -> String { + pub fn join_types(types: &[T], delimiter: &str) -> String { types .iter() .map(|t| t.to_string()) @@ -232,7 +249,7 @@ impl TypeSignature { /// /// DataFusion will automatically coerce (cast) argument types to one of the supported /// function signatures, if possible. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub struct Signature { /// The data types that the function accepts. See [TypeSignature] for more information. pub type_signature: TypeSignature, @@ -255,13 +272,30 @@ impl Signature { volatility, } } - /// An arbitrary number of arguments of the same type. - pub fn variadic_equal(volatility: Volatility) -> Self { + /// User-defined coercion rules for the function. + pub fn user_defined(volatility: Volatility) -> Self { + Self { + type_signature: TypeSignature::UserDefined, + volatility, + } + } + + /// A specified number of numeric arguments + pub fn numeric(arg_count: usize, volatility: Volatility) -> Self { Self { - type_signature: TypeSignature::VariadicEqual, + type_signature: TypeSignature::Numeric(arg_count), volatility, } } + + /// A specified number of numeric arguments + pub fn string(arg_count: usize, volatility: Volatility) -> Self { + Self { + type_signature: TypeSignature::String(arg_count), + volatility, + } + } + /// An arbitrary number of arguments of any type. pub fn variadic_any(volatility: Volatility) -> Self { Self { @@ -287,6 +321,14 @@ impl Signature { volatility, } } + /// Target coerce types in order + pub fn coercible(target_types: Vec, volatility: Volatility) -> Self { + Self { + type_signature: TypeSignature::Coercible(target_types), + volatility, + } + } + /// A specified number of arguments of any type pub fn any(arg_count: usize, volatility: Volatility) -> Self { Signature { @@ -346,14 +388,6 @@ impl Signature { } } -/// Monotonicity of the `ScalarFunctionExpr` with respect to its arguments. -/// Each element of this vector corresponds to an argument and indicates whether -/// the function's behavior is monotonic, or non-monotonic/unknown for that argument, namely: -/// - `None` signifies unknown monotonicity or non-monotonicity. -/// - `Some(true)` indicates that the function is monotonically increasing w.r.t. the argument in question. -/// - Some(false) indicates that the function is monotonically decreasing w.r.t. the argument in question. -pub type FuncMonotonicity = Vec>; - #[cfg(test)] mod tests { use super::*; @@ -400,4 +434,24 @@ mod tests { ); } } + + #[test] + fn type_signature_partial_ord() { + // Test validates that partial ord is defined for TypeSignature and Signature. + assert!(TypeSignature::UserDefined < TypeSignature::VariadicAny); + assert!(TypeSignature::UserDefined < TypeSignature::Any(1)); + + assert!( + TypeSignature::Uniform(1, vec![DataType::Null]) + < TypeSignature::Uniform(1, vec![DataType::Boolean]) + ); + assert!( + TypeSignature::Uniform(1, vec![DataType::Null]) + < TypeSignature::Uniform(2, vec![DataType::Null]) + ); + assert!( + TypeSignature::Uniform(usize::MAX, vec![DataType::Null]) + < TypeSignature::Exact(vec![DataType::Null]) + ); + } } diff --git a/datafusion/physical-expr-common/src/sort_properties.rs b/datafusion/expr-common/src/sort_properties.rs similarity index 77% rename from datafusion/physical-expr-common/src/sort_properties.rs rename to datafusion/expr-common/src/sort_properties.rs index 47a5d5ba5e3b..7778be2ecf0d 100644 --- a/datafusion/physical-expr-common/src/sort_properties.rs +++ b/datafusion/expr-common/src/sort_properties.rs @@ -17,9 +17,10 @@ use std::ops::Neg; -use arrow::compute::SortOptions; +use crate::interval_arithmetic::Interval; -use crate::tree_node::ExprContext; +use arrow::compute::SortOptions; +use arrow::datatypes::DataType; /// To propagate [`SortOptions`] across the `PhysicalExpr`, it is insufficient /// to simply use `Option`: There must be a differentiation between @@ -120,29 +121,39 @@ impl SortProperties { impl Neg for SortProperties { type Output = Self; - fn neg(self) -> Self::Output { - match self { - SortProperties::Ordered(SortOptions { - descending, - nulls_first, - }) => SortProperties::Ordered(SortOptions { - descending: !descending, - nulls_first, - }), - SortProperties::Singleton => SortProperties::Singleton, - SortProperties::Unordered => SortProperties::Unordered, + fn neg(mut self) -> Self::Output { + if let SortProperties::Ordered(SortOptions { descending, .. }) = &mut self { + *descending = !*descending; } + self } } -/// The `ExprOrdering` struct is designed to aid in the determination of ordering (represented -/// by [`SortProperties`]) for a given `PhysicalExpr`. When analyzing the orderings -/// of a `PhysicalExpr`, the process begins by assigning the ordering of its leaf nodes. -/// By propagating these leaf node orderings upwards in the expression tree, the overall -/// ordering of the entire `PhysicalExpr` can be derived. -/// -/// This struct holds the necessary state information for each expression in the `PhysicalExpr`. -/// It encapsulates the orderings (`data`) associated with the expression (`expr`), and -/// orderings of the children expressions (`children`). The [`ExprOrdering`] of a parent -/// expression is determined based on the [`ExprOrdering`] states of its children expressions. -pub type ExprOrdering = ExprContext; +/// Represents the properties of a `PhysicalExpr`, including its sorting and range attributes. +#[derive(Debug, Clone)] +pub struct ExprProperties { + pub sort_properties: SortProperties, + pub range: Interval, +} + +impl ExprProperties { + /// Creates a new `ExprProperties` instance with unknown sort properties and unknown range. + pub fn new_unknown() -> Self { + Self { + sort_properties: SortProperties::default(), + range: Interval::make_unbounded(&DataType::Null).unwrap(), + } + } + + /// Sets the sorting properties of the expression and returns the modified instance. + pub fn with_order(mut self, order: SortProperties) -> Self { + self.sort_properties = order; + self + } + + /// Sets the range of the expression and returns the modified instance. + pub fn with_range(mut self, range: Interval) -> Self { + self.range = range; + self + } +} diff --git a/datafusion/expr-common/src/type_coercion.rs b/datafusion/expr-common/src/type_coercion.rs new file mode 100644 index 000000000000..e934c6eaf35b --- /dev/null +++ b/datafusion/expr-common/src/type_coercion.rs @@ -0,0 +1,19 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod aggregates; +pub mod binary; diff --git a/datafusion/expr-common/src/type_coercion/aggregates.rs b/datafusion/expr-common/src/type_coercion/aggregates.rs new file mode 100644 index 000000000000..fee75f9e4595 --- /dev/null +++ b/datafusion/expr-common/src/type_coercion/aggregates.rs @@ -0,0 +1,360 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::signature::TypeSignature; +use arrow::datatypes::{ + DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, + DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, +}; + +use datafusion_common::{internal_err, plan_err, Result}; + +pub static STRINGS: &[DataType] = &[DataType::Utf8, DataType::LargeUtf8]; + +pub static SIGNED_INTEGERS: &[DataType] = &[ + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, +]; + +pub static UNSIGNED_INTEGERS: &[DataType] = &[ + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, +]; + +pub static INTEGERS: &[DataType] = &[ + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, +]; + +pub static NUMERICS: &[DataType] = &[ + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, + DataType::Float32, + DataType::Float64, +]; + +pub static TIMESTAMPS: &[DataType] = &[ + DataType::Timestamp(TimeUnit::Second, None), + DataType::Timestamp(TimeUnit::Millisecond, None), + DataType::Timestamp(TimeUnit::Microsecond, None), + DataType::Timestamp(TimeUnit::Nanosecond, None), +]; + +pub static DATES: &[DataType] = &[DataType::Date32, DataType::Date64]; + +pub static BINARYS: &[DataType] = &[DataType::Binary, DataType::LargeBinary]; + +pub static TIMES: &[DataType] = &[ + DataType::Time32(TimeUnit::Second), + DataType::Time32(TimeUnit::Millisecond), + DataType::Time64(TimeUnit::Microsecond), + DataType::Time64(TimeUnit::Nanosecond), +]; + +/// Validate the length of `input_types` matches the `signature` for `agg_fun`. +/// +/// This method DOES NOT validate the argument types - only that (at least one, +/// in the case of [`TypeSignature::OneOf`]) signature matches the desired +/// number of input types. +pub fn check_arg_count( + func_name: &str, + input_types: &[DataType], + signature: &TypeSignature, +) -> Result<()> { + match signature { + TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => { + if input_types.len() != *agg_count { + return plan_err!( + "The function {func_name} expects {:?} arguments, but {:?} were provided", + agg_count, + input_types.len() + ); + } + } + TypeSignature::Exact(types) => { + if types.len() != input_types.len() { + return plan_err!( + "The function {func_name} expects {:?} arguments, but {:?} were provided", + types.len(), + input_types.len() + ); + } + } + TypeSignature::OneOf(variants) => { + let ok = variants + .iter() + .any(|v| check_arg_count(func_name, input_types, v).is_ok()); + if !ok { + return plan_err!( + "The function {func_name} does not accept {:?} function arguments.", + input_types.len() + ); + } + } + TypeSignature::VariadicAny => { + if input_types.is_empty() { + return plan_err!( + "The function {func_name} expects at least one argument" + ); + } + } + TypeSignature::UserDefined + | TypeSignature::Numeric(_) + | TypeSignature::Coercible(_) => { + // User-defined signature is validated in `coerce_types` + // Numeric and Coercible signature is validated in `get_valid_types` + } + _ => { + return internal_err!( + "Aggregate functions do not support this {signature:?}" + ); + } + } + Ok(()) +} + +/// Function return type of a sum +pub fn sum_return_type(arg_type: &DataType) -> Result { + match arg_type { + DataType::Int64 => Ok(DataType::Int64), + DataType::UInt64 => Ok(DataType::UInt64), + DataType::Float64 => Ok(DataType::Float64), + DataType::Decimal128(precision, scale) => { + // In the spark, the result type is DECIMAL(min(38,precision+10), s) + // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal128(new_precision, *scale)) + } + DataType::Decimal256(precision, scale) => { + // In the spark, the result type is DECIMAL(min(38,precision+10), s) + // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal256(new_precision, *scale)) + } + other => plan_err!("SUM does not support type \"{other:?}\""), + } +} + +/// Function return type of variance +pub fn variance_return_type(arg_type: &DataType) -> Result { + if NUMERICS.contains(arg_type) { + Ok(DataType::Float64) + } else { + plan_err!("VAR does not support {arg_type:?}") + } +} + +/// Function return type of covariance +pub fn covariance_return_type(arg_type: &DataType) -> Result { + if NUMERICS.contains(arg_type) { + Ok(DataType::Float64) + } else { + plan_err!("COVAR does not support {arg_type:?}") + } +} + +/// Function return type of correlation +pub fn correlation_return_type(arg_type: &DataType) -> Result { + if NUMERICS.contains(arg_type) { + Ok(DataType::Float64) + } else { + plan_err!("CORR does not support {arg_type:?}") + } +} + +/// Function return type of an average +pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result { + match arg_type { + DataType::Decimal128(precision, scale) => { + // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). + // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 + let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 4); + let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4); + Ok(DataType::Decimal128(new_precision, new_scale)) + } + DataType::Decimal256(precision, scale) => { + // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). + // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 + let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4); + let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4); + Ok(DataType::Decimal256(new_precision, new_scale)) + } + arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64), + DataType::Dictionary(_, dict_value_type) => { + avg_return_type(func_name, dict_value_type.as_ref()) + } + other => plan_err!("{func_name} does not support {other:?}"), + } +} + +/// Internal sum type of an average +pub fn avg_sum_type(arg_type: &DataType) -> Result { + match arg_type { + DataType::Decimal128(precision, scale) => { + // In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s) + let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal128(new_precision, *scale)) + } + DataType::Decimal256(precision, scale) => { + // In Spark the sum type of avg is DECIMAL(min(38,precision+10), s) + let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal256(new_precision, *scale)) + } + arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64), + DataType::Dictionary(_, dict_value_type) => { + avg_sum_type(dict_value_type.as_ref()) + } + other => plan_err!("AVG does not support {other:?}"), + } +} + +pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool { + match arg_type { + DataType::Dictionary(_, dict_value_type) => { + is_sum_support_arg_type(dict_value_type.as_ref()) + } + _ => matches!( + arg_type, + arg_type if NUMERICS.contains(arg_type) + || matches!(arg_type, DataType::Decimal128(_, _) | DataType::Decimal256(_, _)) + ), + } +} + +pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool { + match arg_type { + DataType::Dictionary(_, dict_value_type) => { + is_avg_support_arg_type(dict_value_type.as_ref()) + } + _ => matches!( + arg_type, + arg_type if NUMERICS.contains(arg_type) + || matches!(arg_type, DataType::Decimal128(_, _)| DataType::Decimal256(_, _)) + ), + } +} + +pub fn is_variance_support_arg_type(arg_type: &DataType) -> bool { + matches!( + arg_type, + arg_type if NUMERICS.contains(arg_type) + ) +} + +pub fn is_covariance_support_arg_type(arg_type: &DataType) -> bool { + matches!( + arg_type, + arg_type if NUMERICS.contains(arg_type) + ) +} + +pub fn is_correlation_support_arg_type(arg_type: &DataType) -> bool { + matches!( + arg_type, + arg_type if NUMERICS.contains(arg_type) + ) +} + +pub fn is_integer_arg_type(arg_type: &DataType) -> bool { + arg_type.is_integer() +} + +pub fn coerce_avg_type(func_name: &str, arg_types: &[DataType]) -> Result> { + // Supported types smallint, int, bigint, real, double precision, decimal, or interval + // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc + fn coerced_type(func_name: &str, data_type: &DataType) -> Result { + return match &data_type { + DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)), + DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)), + d if d.is_numeric() => Ok(DataType::Float64), + DataType::Dictionary(_, v) => return coerced_type(func_name, v.as_ref()), + _ => { + return plan_err!( + "The function {:?} does not support inputs of type {:?}.", + func_name, + data_type + ) + } + }; + } + Ok(vec![coerced_type(func_name, &arg_types[0])?]) +} +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_variance_return_data_type() -> Result<()> { + let data_type = DataType::Float64; + let result_type = variance_return_type(&data_type)?; + assert_eq!(DataType::Float64, result_type); + + let data_type = DataType::Decimal128(36, 10); + assert!(variance_return_type(&data_type).is_err()); + Ok(()) + } + + #[test] + fn test_sum_return_data_type() -> Result<()> { + let data_type = DataType::Decimal128(10, 5); + let result_type = sum_return_type(&data_type)?; + assert_eq!(DataType::Decimal128(20, 5), result_type); + + let data_type = DataType::Decimal128(36, 10); + let result_type = sum_return_type(&data_type)?; + assert_eq!(DataType::Decimal128(38, 10), result_type); + Ok(()) + } + + #[test] + fn test_covariance_return_data_type() -> Result<()> { + let data_type = DataType::Float64; + let result_type = covariance_return_type(&data_type)?; + assert_eq!(DataType::Float64, result_type); + + let data_type = DataType::Decimal128(36, 10); + assert!(covariance_return_type(&data_type).is_err()); + Ok(()) + } + + #[test] + fn test_correlation_return_data_type() -> Result<()> { + let data_type = DataType::Float64; + let result_type = correlation_return_type(&data_type)?; + assert_eq!(DataType::Float64, result_type); + + let data_type = DataType::Decimal128(36, 10); + assert!(correlation_return_type(&data_type).is_err()); + Ok(()) + } +} diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs similarity index 62% rename from datafusion/expr/src/type_coercion/binary.rs rename to datafusion/expr-common/src/type_coercion/binary.rs index 7eec606658f4..31fe6a59baee 100644 --- a/datafusion/expr/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -17,18 +17,21 @@ //! Coercion rules for matching argument types for binary operators +use std::collections::HashSet; use std::sync::Arc; -use crate::Operator; +use crate::operator::Operator; use arrow::array::{new_empty_array, Array}; use arrow::compute::can_cast_types; use arrow::datatypes::{ - DataType, Field, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, - DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, + DataType, Field, FieldRef, Fields, TimeUnit, DECIMAL128_MAX_PRECISION, + DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, }; - -use datafusion_common::{exec_datafusion_err, plan_datafusion_err, plan_err, Result}; +use datafusion_common::{ + exec_datafusion_err, exec_err, internal_err, plan_datafusion_err, plan_err, Result, +}; +use itertools::Itertools; /// The type signature of an instantiation of binary operator expression such as /// `lhs + rhs` @@ -86,7 +89,7 @@ fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result And | Or => if matches!((lhs, rhs), (Boolean | Null, Boolean | Null)) { // Logical binary boolean operators can only be evaluated for // boolean or null arguments. - Ok(Signature::uniform(DataType::Boolean)) + Ok(Signature::uniform(Boolean)) } else { plan_err!( "Cannot infer common argument type for logical boolean operation {lhs} {op} {rhs}" @@ -154,7 +157,7 @@ fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result rhs: rhs.clone(), ret, }) - } else if let Some(coerced) = temporal_coercion(lhs, rhs) { + } else if let Some(coerced) = temporal_coercion_strict_timezone(lhs, rhs) { // Temporal arithmetic by first coercing to a common time representation // e.g. Date32 - Timestamp let ret = get_result(&coerced, &coerced).map_err(|e| { @@ -191,7 +194,7 @@ fn signature(lhs: &DataType, op: &Operator, rhs: &DataType) -> Result } } -/// returns the resulting type of a binary expression evaluating the `op` with the left and right hand types +/// Returns the resulting type of a binary expression evaluating the `op` with the left and right hand types pub fn get_result_type( lhs: &DataType, op: &Operator, @@ -289,20 +292,355 @@ fn bitwise_coercion(left_type: &DataType, right_type: &DataType) -> Option for TypeCategory { + fn from(data_type: &DataType) -> Self { + match data_type { + // Dict is a special type in arrow, we check the value type + DataType::Dictionary(_, v) => { + let v = v.as_ref(); + TypeCategory::from(v) + } + _ => { + if data_type.is_numeric() { + return TypeCategory::Numeric; + } + + if matches!(data_type, DataType::Boolean) { + return TypeCategory::Boolean; + } + + if matches!( + data_type, + DataType::List(_) + | DataType::FixedSizeList(_, _) + | DataType::LargeList(_) + ) { + return TypeCategory::Array; + } + + // String literal is possible to cast to many other types like numeric or datetime, + // therefore, it is categorized as a unknown type + if matches!( + data_type, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Null + ) { + return TypeCategory::Unknown; + } + + if matches!( + data_type, + DataType::Date32 + | DataType::Date64 + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Timestamp(_, _) + | DataType::Interval(_) + | DataType::Duration(_) + ) { + return TypeCategory::DateTime; + } + + if matches!( + data_type, + DataType::Map(_, _) | DataType::Struct(_) | DataType::Union(_, _) + ) { + return TypeCategory::Composite; + } + + TypeCategory::NotSupported + } + } + } +} + +/// Coerce dissimilar data types to a single data type. +/// UNION, INTERSECT, EXCEPT, CASE, ARRAY, VALUES, and the GREATEST and LEAST functions are +/// examples that has the similar resolution rules. +/// See for more information. +/// The rules in the document provide a clue, but adhering strictly to them doesn't precisely +/// align with the behavior of Postgres. Therefore, we've made slight adjustments to the rules +/// to better match the behavior of both Postgres and DuckDB. For example, we expect adjusted +/// decimal precision and scale when coercing decimal types. +/// +/// This function doesn't preserve correct field name and nullability for the struct type, we only care about data type. +/// +/// Returns Option because we might want to continue on the code even if the data types are not coercible to the common type +pub fn type_union_resolution(data_types: &[DataType]) -> Option { + if data_types.is_empty() { + return None; + } + + // If all the data_types is the same return first one + if data_types.iter().all(|t| t == &data_types[0]) { + return Some(data_types[0].clone()); + } + + // If all the data_types are null, return string + if data_types.iter().all(|t| t == &DataType::Null) { + return Some(DataType::Utf8); + } + + // Ignore Nulls, if any data_type category is not the same, return None + let data_types_category: Vec = data_types + .iter() + .filter(|&t| t != &DataType::Null) + .map(|t| t.into()) + .collect(); + + if data_types_category + .iter() + .any(|t| t == &TypeCategory::NotSupported) + { + return None; + } + + // Check if there is only one category excluding Unknown + let categories: HashSet = HashSet::from_iter( + data_types_category + .iter() + .filter(|&c| c != &TypeCategory::Unknown) + .cloned(), + ); + if categories.len() > 1 { + return None; + } + + // Ignore Nulls + let mut candidate_type: Option = None; + for data_type in data_types.iter() { + if data_type == &DataType::Null { + continue; + } + if let Some(ref candidate_t) = candidate_type { + // Find candidate type that all the data types can be coerced to + // Follows the behavior of Postgres and DuckDB + // Coerced type may be different from the candidate and current data type + // For example, + // i64 and decimal(7, 2) are expect to get coerced type decimal(22, 2) + // numeric string ('1') and numeric (2) are expect to get coerced type numeric (1, 2) + if let Some(t) = type_union_resolution_coercion(data_type, candidate_t) { + candidate_type = Some(t); + } else { + return None; + } + } else { + candidate_type = Some(data_type.clone()); + } + } + + candidate_type +} + +/// Coerce `lhs_type` and `rhs_type` to a common type for [type_union_resolution] +/// See [type_union_resolution] for more information. +fn type_union_resolution_coercion( + lhs_type: &DataType, + rhs_type: &DataType, +) -> Option { + if lhs_type == rhs_type { + return Some(lhs_type.clone()); + } + + match (lhs_type, rhs_type) { + ( + DataType::Dictionary(lhs_index_type, lhs_value_type), + DataType::Dictionary(rhs_index_type, rhs_value_type), + ) => { + let new_index_type = + type_union_resolution_coercion(lhs_index_type, rhs_index_type); + let new_value_type = + type_union_resolution_coercion(lhs_value_type, rhs_value_type); + if let (Some(new_index_type), Some(new_value_type)) = + (new_index_type, new_value_type) + { + Some(DataType::Dictionary( + Box::new(new_index_type), + Box::new(new_value_type), + )) + } else { + None + } + } + (DataType::Dictionary(index_type, value_type), other_type) + | (other_type, DataType::Dictionary(index_type, value_type)) => { + let new_value_type = type_union_resolution_coercion(value_type, other_type); + new_value_type.map(|t| DataType::Dictionary(index_type.clone(), Box::new(t))) + } + (DataType::List(lhs), DataType::List(rhs)) => { + let new_item_type = + type_union_resolution_coercion(lhs.data_type(), rhs.data_type()); + new_item_type.map(|t| DataType::List(Arc::new(Field::new("item", t, true)))) + } + (DataType::Struct(lhs), DataType::Struct(rhs)) => { + if lhs.len() != rhs.len() { + return None; + } + + // Search the field in the right hand side with the SAME field name + fn search_corresponding_coerced_type( + lhs_field: &FieldRef, + rhs: &Fields, + ) -> Option { + for rhs_field in rhs.iter() { + if lhs_field.name() == rhs_field.name() { + if let Some(t) = type_union_resolution_coercion( + lhs_field.data_type(), + rhs_field.data_type(), + ) { + return Some(t); + } else { + return None; + } + } + } + + None + } + + let types = lhs + .iter() + .map(|lhs_field| search_corresponding_coerced_type(lhs_field, rhs)) + .collect::>>()?; + + let fields = types + .into_iter() + .enumerate() + .map(|(i, datatype)| { + Arc::new(Field::new(format!("c{i}"), datatype, true)) + }) + .collect::>(); + Some(DataType::Struct(fields.into())) + } + _ => { + // Numeric coercion is the same as comparison coercion, both find the narrowest type + // that can accommodate both types + binary_numeric_coercion(lhs_type, rhs_type) + .or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type)) + .or_else(|| string_coercion(lhs_type, rhs_type)) + .or_else(|| numeric_string_coercion(lhs_type, rhs_type)) + } + } +} + +/// Handle type union resolution including struct type and others. +pub fn try_type_union_resolution(data_types: &[DataType]) -> Result> { + let err = match try_type_union_resolution_with_struct(data_types) { + Ok(struct_types) => return Ok(struct_types), + Err(e) => Some(e), + }; + + if let Some(new_type) = type_union_resolution(data_types) { + Ok(vec![new_type; data_types.len()]) + } else { + exec_err!("Fail to find the coerced type, errors: {:?}", err) + } +} + +// Handle struct where we only change the data type but preserve the field name and nullability. +// Since field name is the key of the struct, so it shouldn't be updated to the common column name like "c0" or "c1" +pub fn try_type_union_resolution_with_struct( + data_types: &[DataType], +) -> Result> { + let mut keys_string: Option = None; + for data_type in data_types { + if let DataType::Struct(fields) = data_type { + let keys = fields.iter().map(|f| f.name().to_owned()).join(","); + if let Some(ref k) = keys_string { + if *k != keys { + return exec_err!("Expect same keys for struct type but got mismatched pair {} and {}", *k, keys); + } + } else { + keys_string = Some(keys); + } + } else { + return exec_err!("Expect to get struct but got {}", data_type); + } + } + + let mut struct_types: Vec = if let DataType::Struct(fields) = &data_types[0] + { + fields.iter().map(|f| f.data_type().to_owned()).collect() + } else { + return internal_err!("Struct type is checked is the previous function, so this should be unreachable"); + }; + + for data_type in data_types.iter().skip(1) { + if let DataType::Struct(fields) = data_type { + let incoming_struct_types: Vec = + fields.iter().map(|f| f.data_type().to_owned()).collect(); + // The order of field is verified above + for (lhs_type, rhs_type) in + struct_types.iter_mut().zip(incoming_struct_types.iter()) + { + if let Some(coerced_type) = + type_union_resolution_coercion(lhs_type, rhs_type) + { + *lhs_type = coerced_type; + } else { + return exec_err!( + "Fail to find the coerced type for {} and {}", + lhs_type, + rhs_type + ); + } + } + } else { + return exec_err!("Expect to get struct but got {}", data_type); + } + } + + let mut final_struct_types = vec![]; + for s in data_types { + let mut new_fields = vec![]; + if let DataType::Struct(fields) = s { + for (i, f) in fields.iter().enumerate() { + let field = Arc::unwrap_or_clone(Arc::clone(f)) + .with_data_type(struct_types[i].to_owned()); + new_fields.push(Arc::new(field)); + } + } + final_struct_types.push(DataType::Struct(new_fields.into())) + } + + Ok(final_struct_types) +} + +/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a +/// comparison operation +/// +/// Example comparison operations are `lhs = rhs` and `lhs > rhs` +/// +/// Binary comparison kernels require the two arguments to be the (exact) same +/// data type. However, users can write queries where the two arguments are +/// different data types. In such cases, the data types are automatically cast +/// (coerced) to a single data type to pass to the kernels. pub fn comparison_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { if lhs_type == rhs_type { // same type => equality is possible return Some(lhs_type.clone()); } - comparison_binary_numeric_coercion(lhs_type, rhs_type) - .or_else(|| dictionary_coercion(lhs_type, rhs_type, true)) - .or_else(|| temporal_coercion(lhs_type, rhs_type)) + binary_numeric_coercion(lhs_type, rhs_type) + .or_else(|| dictionary_comparison_coercion(lhs_type, rhs_type, true)) + .or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type)) .or_else(|| string_coercion(lhs_type, rhs_type)) + .or_else(|| list_coercion(lhs_type, rhs_type)) .or_else(|| null_coercion(lhs_type, rhs_type)) .or_else(|| string_numeric_coercion(lhs_type, rhs_type)) .or_else(|| string_temporal_coercion(lhs_type, rhs_type)) .or_else(|| binary_coercion(lhs_type, rhs_type)) + .or_else(|| struct_coercion(lhs_type, rhs_type)) } /// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation @@ -319,7 +657,7 @@ fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { match (l, r) { - // Coerce Utf8/LargeUtf8 to Date32/Date64/Time32/Time64/Timestamp - (Utf8, temporal) | (LargeUtf8, temporal) => match temporal { - Date32 | Date64 => Some(temporal.clone()), - Time32(_) | Time64(_) => { - if is_time_with_valid_unit(temporal.to_owned()) { - Some(temporal.to_owned()) - } else { - None + // Coerce Utf8View/Utf8/LargeUtf8 to Date32/Date64/Time32/Time64/Timestamp + (Utf8, temporal) | (LargeUtf8, temporal) | (Utf8View, temporal) => { + match temporal { + Date32 | Date64 => Some(temporal.clone()), + Time32(_) | Time64(_) => { + if is_time_with_valid_unit(temporal.to_owned()) { + Some(temporal.to_owned()) + } else { + None + } } + Timestamp(_, tz) => Some(Timestamp(TimeUnit::Nanosecond, tz.clone())), + _ => None, } - Timestamp(_, tz) => Some(Timestamp(TimeUnit::Nanosecond, tz.clone())), - _ => None, - }, + } _ => None, } } @@ -359,9 +699,8 @@ fn string_temporal_coercion( match_rule(lhs_type, rhs_type).or_else(|| match_rule(rhs_type, lhs_type)) } -/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation -/// where one both are numeric -pub(crate) fn comparison_binary_numeric_coercion( +/// Coerce `lhs_type` and `rhs_type` to a common type where both are numeric +pub fn binary_numeric_coercion( lhs_type: &DataType, rhs_type: &DataType, ) -> Option { @@ -375,20 +714,13 @@ pub(crate) fn comparison_binary_numeric_coercion( return Some(lhs_type.clone()); } - // these are ordered from most informative to least informative so + if let Some(t) = decimal_coercion(lhs_type, rhs_type) { + return Some(t); + } + + // These are ordered from most informative to least informative so // that the coercion does not lose information via truncation match (lhs_type, rhs_type) { - // Prefer decimal data type over floating point for comparison operation - (Decimal128(_, _), Decimal128(_, _)) => { - get_wider_decimal_type(lhs_type, rhs_type) - } - (Decimal128(_, _), _) => get_comparison_common_decimal_type(lhs_type, rhs_type), - (_, Decimal128(_, _)) => get_comparison_common_decimal_type(rhs_type, lhs_type), - (Decimal256(_, _), Decimal256(_, _)) => { - get_wider_decimal_type(lhs_type, rhs_type) - } - (Decimal256(_, _), _) => get_comparison_common_decimal_type(lhs_type, rhs_type), - (_, Decimal256(_, _)) => get_comparison_common_decimal_type(rhs_type, lhs_type), (Float64, _) | (_, Float64) => Some(Float64), (_, Float32) | (Float32, _) => Some(Float32), // The following match arms encode the following logic: Given the two @@ -396,6 +728,11 @@ pub(crate) fn comparison_binary_numeric_coercion( // accommodates all values of both types. Note that some information // loss is inevitable when we have a signed type and a `UInt64`, in // which case we use `Int64`;i.e. the widest signed integral type. + + // TODO: For i64 and u64, we can use decimal or float64 + // Postgres has no unsigned type :( + // DuckDB v.0.10.0 has double (double precision floating-point number (8 bytes)) + // for largest signed (signed sixteen-byte integer) and unsigned integer (unsigned sixteen-byte integer) (Int64, _) | (_, Int64) | (UInt64, Int8) @@ -426,9 +763,28 @@ pub(crate) fn comparison_binary_numeric_coercion( } } -/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of -/// a comparison operation where one is a decimal -fn get_comparison_common_decimal_type( +/// Decimal coercion rules. +pub fn decimal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { + use arrow::datatypes::DataType::*; + + match (lhs_type, rhs_type) { + // Prefer decimal data type over floating point for comparison operation + (Decimal128(_, _), Decimal128(_, _)) => { + get_wider_decimal_type(lhs_type, rhs_type) + } + (Decimal128(_, _), _) => get_common_decimal_type(lhs_type, rhs_type), + (_, Decimal128(_, _)) => get_common_decimal_type(rhs_type, lhs_type), + (Decimal256(_, _), Decimal256(_, _)) => { + get_wider_decimal_type(lhs_type, rhs_type) + } + (Decimal256(_, _), _) => get_common_decimal_type(lhs_type, rhs_type), + (_, Decimal256(_, _)) => get_common_decimal_type(rhs_type, lhs_type), + (_, _) => None, + } +} + +/// Coerce `lhs_type` and `rhs_type` to a common type. +fn get_common_decimal_type( decimal_type: &DataType, other_type: &DataType, ) -> Option { @@ -494,7 +850,7 @@ pub fn get_wider_type(lhs: &DataType, rhs: &DataType) -> Result { (Int16 | Int32 | Int64, Int8) | (Int32 | Int64, Int16) | (Int64, Int32) | // Left Float is larger than right Float. (Float32 | Float64, Float16) | (Float64, Float32) | - // Left String is larget than right String. + // Left String is larger than right String. (LargeUtf8, Utf8) | // Any left type is wider than a right hand side Null. (_, Null) => lhs.clone(), @@ -555,6 +911,31 @@ fn coerce_numeric_type_to_decimal256(numeric_type: &DataType) -> Option Option { + use arrow::datatypes::DataType::*; + match (lhs_type, rhs_type) { + (Struct(lhs_fields), Struct(rhs_fields)) => { + if lhs_fields.len() != rhs_fields.len() { + return None; + } + + let types = std::iter::zip(lhs_fields.iter(), rhs_fields.iter()) + .map(|(lhs, rhs)| comparison_coercion(lhs.data_type(), rhs.data_type())) + .collect::>>()?; + + let fields = types + .into_iter() + .enumerate() + .map(|(i, datatype)| { + Arc::new(Field::new(format!("c{i}"), datatype, true)) + }) + .collect::>(); + Some(Struct(fields.into())) + } + _ => None, + } +} + /// Returns the output type of applying mathematics operations such as /// `+` to arguments of `lhs_type` and `rhs_type`. fn mathematics_numerical_coercion( @@ -563,12 +944,12 @@ fn mathematics_numerical_coercion( ) -> Option { use arrow::datatypes::DataType::*; - // error on any non-numeric type + // Error on any non-numeric type if !both_numeric_or_null_and_numeric(lhs_type, rhs_type) { return None; }; - // these are ordered from most informative to least informative so + // These are ordered from most informative to least informative so // that the coercion removes the least amount of information match (lhs_type, rhs_type) { (Dictionary(_, lhs_value_type), Dictionary(_, rhs_value_type)) => { @@ -632,7 +1013,7 @@ fn both_numeric_or_null_and_numeric(lhs_type: &DataType, rhs_type: &DataType) -> /// /// Not all operators support dictionaries, if `preserve_dictionaries` is true /// dictionaries will be preserved if possible -fn dictionary_coercion( +fn dictionary_comparison_coercion( lhs_type: &DataType, rhs_type: &DataType, preserve_dictionaries: bool, @@ -661,23 +1042,28 @@ fn dictionary_coercion( /// Coercion rules for string concat. /// This is a union of string coercion rules and specified rules: -/// 1. At lease one side of lhs and rhs should be string type (Utf8 / LargeUtf8) +/// 1. At least one side of lhs and rhs should be string type (Utf8 / LargeUtf8) /// 2. Data type of the other side should be able to cast to string type fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; string_coercion(lhs_type, rhs_type).or(match (lhs_type, rhs_type) { + (Utf8View, from_type) | (from_type, Utf8View) => { + string_concat_internal_coercion(from_type, &Utf8View) + } (Utf8, from_type) | (from_type, Utf8) => { string_concat_internal_coercion(from_type, &Utf8) } (LargeUtf8, from_type) | (from_type, LargeUtf8) => { string_concat_internal_coercion(from_type, &LargeUtf8) } + (Dictionary(_, lhs_value_type), Dictionary(_, rhs_value_type)) => { + string_coercion(lhs_value_type, rhs_value_type).or(None) + } _ => None, }) } fn array_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { - // TODO: cast between array elements (#6558) if lhs_type.equals_datatype(rhs_type) { Some(lhs_type.to_owned()) } else { @@ -685,6 +1071,8 @@ fn array_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option } } +/// If `from_type` can be casted to `to_type`, return `to_type`, otherwise +/// return `None`. fn string_concat_internal_coercion( from_type: &DataType, to_type: &DataType, @@ -696,20 +1084,79 @@ fn string_concat_internal_coercion( } } -/// Coercion rules for string types (Utf8/LargeUtf8): If at least one argument is -/// a string type and both arguments can be coerced into a string type, coerce -/// to string type. -fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { +/// Coercion rules for string view types (Utf8/LargeUtf8/Utf8View): +/// If at least one argument is a string view, we coerce to string view +/// based on the observation that StringArray to StringViewArray is cheap but not vice versa. +/// +/// Between Utf8 and LargeUtf8, we coerce to LargeUtf8. +pub fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { + // If Utf8View is in any side, we coerce to Utf8View. + (Utf8View, Utf8View | Utf8 | LargeUtf8) | (Utf8 | LargeUtf8, Utf8View) => { + Some(Utf8View) + } + // Then, if LargeUtf8 is in any side, we coerce to LargeUtf8. + (LargeUtf8, Utf8 | LargeUtf8) | (Utf8, LargeUtf8) => Some(LargeUtf8), + // Utf8 coerces to Utf8 (Utf8, Utf8) => Some(Utf8), - (LargeUtf8, Utf8) => Some(LargeUtf8), - (Utf8, LargeUtf8) => Some(LargeUtf8), - (LargeUtf8, LargeUtf8) => Some(LargeUtf8), - // TODO: cast between array elements (#6558) + _ => None, + } +} + +/// This will be deprecated when binary operators native support +/// for Utf8View (use `string_coercion` instead). +fn regex_comparison_string_coercion( + lhs_type: &DataType, + rhs_type: &DataType, +) -> Option { + use arrow::datatypes::DataType::*; + match (lhs_type, rhs_type) { + // If Utf8View is in any side, we coerce to Utf8. + (Utf8View, Utf8View | Utf8 | LargeUtf8) | (Utf8 | LargeUtf8, Utf8View) => { + Some(Utf8) + } + // Then, if LargeUtf8 is in any side, we coerce to LargeUtf8. + (LargeUtf8, Utf8 | LargeUtf8) | (Utf8, LargeUtf8) => Some(LargeUtf8), + // Utf8 coerces to Utf8 + (Utf8, Utf8) => Some(Utf8), + _ => None, + } +} + +fn numeric_string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { + use arrow::datatypes::DataType::*; + match (lhs_type, rhs_type) { + (Utf8 | LargeUtf8, other_type) | (other_type, Utf8 | LargeUtf8) + if other_type.is_numeric() => + { + Some(other_type.clone()) + } + _ => None, + } +} + +/// Coercion rules for list types. +fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { + use arrow::datatypes::DataType::*; + match (lhs_type, rhs_type) { (List(_), List(_)) => Some(lhs_type.clone()), - (List(_), _) => Some(lhs_type.clone()), - (_, List(_)) => Some(rhs_type.clone()), + (LargeList(_), List(_)) => Some(lhs_type.clone()), + (List(_), LargeList(_)) => Some(rhs_type.clone()), + (LargeList(_), LargeList(_)) => Some(lhs_type.clone()), + (List(_), FixedSizeList(_, _)) => Some(lhs_type.clone()), + (FixedSizeList(_, _), List(_)) => Some(rhs_type.clone()), + // Coerce to the left side FixedSizeList type if the list lengths are the same, + // otherwise coerce to list with the left type for dynamic length + (FixedSizeList(lf, ls), FixedSizeList(_, rs)) => { + if ls == rs { + Some(lhs_type.clone()) + } else { + Some(List(Arc::clone(lf))) + } + } + (LargeList(_), FixedSizeList(_, _)) => Some(lhs_type.clone()), + (FixedSizeList(_, _), LargeList(_)) => Some(rhs_type.clone()), _ => None, } } @@ -725,43 +1172,72 @@ fn binary_to_string_coercion( match (lhs_type, rhs_type) { (Binary, Utf8) => Some(Utf8), (Binary, LargeUtf8) => Some(LargeUtf8), + (BinaryView, Utf8) => Some(Utf8View), + (BinaryView, LargeUtf8) => Some(LargeUtf8), (LargeBinary, Utf8) => Some(LargeUtf8), (LargeBinary, LargeUtf8) => Some(LargeUtf8), (Utf8, Binary) => Some(Utf8), (Utf8, LargeBinary) => Some(LargeUtf8), + (Utf8, BinaryView) => Some(Utf8View), (LargeUtf8, Binary) => Some(LargeUtf8), (LargeUtf8, LargeBinary) => Some(LargeUtf8), + (LargeUtf8, BinaryView) => Some(LargeUtf8), _ => None, } } -/// Coercion rules for binary types (Binary/LargeBinary): If at least one argument is +/// Coercion rules for binary types (Binary/LargeBinary/BinaryView): If at least one argument is /// a binary type and both arguments can be coerced into a binary type, coerce /// to binary type. fn binary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { - (Binary | Utf8, Binary) | (Binary, Utf8) => Some(Binary), - (LargeBinary | Binary | Utf8 | LargeUtf8, LargeBinary) - | (LargeBinary, Binary | Utf8 | LargeUtf8) => Some(LargeBinary), + // If BinaryView is in any side, we coerce to BinaryView. + (BinaryView, BinaryView | Binary | LargeBinary | Utf8 | LargeUtf8 | Utf8View) + | (LargeBinary | Binary | Utf8 | LargeUtf8 | Utf8View, BinaryView) => { + Some(BinaryView) + } + // Prefer LargeBinary over Binary + (LargeBinary | Binary | Utf8 | LargeUtf8 | Utf8View, LargeBinary) + | (LargeBinary, Binary | Utf8 | LargeUtf8 | Utf8View) => Some(LargeBinary), + + // If Utf8View/LargeUtf8 presents need to be large Binary + (Utf8View | LargeUtf8, Binary) | (Binary, Utf8View | LargeUtf8) => { + Some(LargeBinary) + } + (Binary, Utf8) | (Utf8, Binary) => Some(Binary), _ => None, } } -/// coercion rules for like operations. +/// Coercion rules for like operations. /// This is a union of string coercion rules and dictionary coercion rules pub fn like_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { string_coercion(lhs_type, rhs_type) + .or_else(|| list_coercion(lhs_type, rhs_type)) .or_else(|| binary_to_string_coercion(lhs_type, rhs_type)) - .or_else(|| dictionary_coercion(lhs_type, rhs_type, false)) + .or_else(|| dictionary_comparison_coercion(lhs_type, rhs_type, false)) + .or_else(|| regex_null_coercion(lhs_type, rhs_type)) .or_else(|| null_coercion(lhs_type, rhs_type)) } -/// coercion rules for regular expression comparison operations. +/// Coercion rules for regular expression comparison operations with NULL input. +fn regex_null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { + use arrow::datatypes::DataType::*; + match (lhs_type, rhs_type) { + (Null, Utf8View | Utf8 | LargeUtf8) => Some(rhs_type.clone()), + (Utf8View | Utf8 | LargeUtf8, Null) => Some(lhs_type.clone()), + (Null, Null) => Some(Utf8), + _ => None, + } +} + +/// Coercion rules for regular expression comparison operations. /// This is a union of string coercion rules and dictionary coercion rules pub fn regex_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { - string_coercion(lhs_type, rhs_type) - .or_else(|| dictionary_coercion(lhs_type, rhs_type, false)) + regex_comparison_string_coercion(lhs_type, rhs_type) + .or_else(|| dictionary_comparison_coercion(lhs_type, rhs_type, false)) + .or_else(|| regex_null_coercion(lhs_type, rhs_type)) } /// Checks if the TimeUnit associated with a Time32 or Time64 type is consistent, @@ -777,16 +1253,97 @@ fn is_time_with_valid_unit(datatype: DataType) -> bool { ) } +/// Non-strict Timezone Coercion is useful in scenarios where we can guarantee +/// a stable relationship between two timestamps of different timezones. +/// +/// An example of this is binary comparisons (<, >, ==, etc). Arrow stores timestamps +/// as relative to UTC epoch, and then adds the timezone as an offset. As a result, we can always +/// do a binary comparison between the two times. +/// +/// Timezone coercion is handled by the following rules: +/// - If only one has a timezone, coerce the other to match +/// - If both have a timezone, coerce to the left type +/// - "UTC" and "+00:00" are considered equivalent +fn temporal_coercion_nonstrict_timezone( + lhs_type: &DataType, + rhs_type: &DataType, +) -> Option { + use arrow::datatypes::DataType::*; + + match (lhs_type, rhs_type) { + (Timestamp(lhs_unit, lhs_tz), Timestamp(rhs_unit, rhs_tz)) => { + let tz = match (lhs_tz, rhs_tz) { + // If both have a timezone, use the left timezone. + (Some(lhs_tz), Some(_rhs_tz)) => Some(Arc::clone(lhs_tz)), + (Some(lhs_tz), None) => Some(Arc::clone(lhs_tz)), + (None, Some(rhs_tz)) => Some(Arc::clone(rhs_tz)), + (None, None) => None, + }; + + let unit = timeunit_coercion(lhs_unit, rhs_unit); + + Some(Timestamp(unit, tz)) + } + _ => temporal_coercion(lhs_type, rhs_type), + } +} + +/// Strict Timezone coercion is useful in scenarios where we cannot guarantee a stable relationship +/// between two timestamps with different timezones or do not want implicit coercion between them. +/// +/// An example of this when attempting to coerce function arguments. Functions already have a mechanism +/// for defining which timestamp types they want to support, so we do not want to do any further coercion. +/// /// Coercion rules for Temporal columns: the type that both lhs and rhs can be /// casted to for the purpose of a date computation /// For interval arithmetic, it doesn't handle datetime type +/- interval +/// Timezone coercion is handled by the following rules: +/// - If only one has a timezone, coerce the other to match +/// - If both have a timezone, throw an error +/// - "UTC" and "+00:00" are considered equivalent +fn temporal_coercion_strict_timezone( + lhs_type: &DataType, + rhs_type: &DataType, +) -> Option { + use arrow::datatypes::DataType::*; + + match (lhs_type, rhs_type) { + (Timestamp(lhs_unit, lhs_tz), Timestamp(rhs_unit, rhs_tz)) => { + let tz = match (lhs_tz, rhs_tz) { + (Some(lhs_tz), Some(rhs_tz)) => { + match (lhs_tz.as_ref(), rhs_tz.as_ref()) { + // UTC and "+00:00" are the same by definition. Most other timezones + // do not have a 1-1 mapping between timezone and an offset from UTC + ("UTC", "+00:00") | ("+00:00", "UTC") => Some(Arc::clone(lhs_tz)), + (lhs, rhs) if lhs == rhs => Some(Arc::clone(lhs_tz)), + // can't cast across timezones + _ => { + return None; + } + } + } + (Some(lhs_tz), None) => Some(Arc::clone(lhs_tz)), + (None, Some(rhs_tz)) => Some(Arc::clone(rhs_tz)), + (None, None) => None, + }; + + let unit = timeunit_coercion(lhs_unit, rhs_unit); + + Some(Timestamp(unit, tz)) + } + _ => temporal_coercion(lhs_type, rhs_type), + } +} + fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; use arrow::datatypes::IntervalUnit::*; use arrow::datatypes::TimeUnit::*; match (lhs_type, rhs_type) { - (Interval(_), Interval(_)) => Some(Interval(MonthDayNano)), + (Interval(_) | Duration(_), Interval(_) | Duration(_)) => { + Some(Interval(MonthDayNano)) + } (Date64, Date32) | (Date32, Date64) => Some(Date64), (Timestamp(_, None), Date64) | (Date64, Timestamp(_, None)) => { Some(Timestamp(Nanosecond, None)) @@ -800,47 +1357,33 @@ fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { Some(Timestamp(Nanosecond, None)) } - (Timestamp(lhs_unit, lhs_tz), Timestamp(rhs_unit, rhs_tz)) => { - let tz = match (lhs_tz, rhs_tz) { - // can't cast across timezones - (Some(lhs_tz), Some(rhs_tz)) => { - if lhs_tz != rhs_tz { - return None; - } else { - Some(lhs_tz.clone()) - } - } - (Some(lhs_tz), None) => Some(lhs_tz.clone()), - (None, Some(rhs_tz)) => Some(rhs_tz.clone()), - (None, None) => None, - }; - - let unit = match (lhs_unit, rhs_unit) { - (Second, Millisecond) => Second, - (Second, Microsecond) => Second, - (Second, Nanosecond) => Second, - (Millisecond, Second) => Second, - (Millisecond, Microsecond) => Millisecond, - (Millisecond, Nanosecond) => Millisecond, - (Microsecond, Second) => Second, - (Microsecond, Millisecond) => Millisecond, - (Microsecond, Nanosecond) => Microsecond, - (Nanosecond, Second) => Second, - (Nanosecond, Millisecond) => Millisecond, - (Nanosecond, Microsecond) => Microsecond, - (l, r) => { - assert_eq!(l, r); - l.clone() - } - }; + _ => None, + } +} - Some(Timestamp(unit, tz)) +fn timeunit_coercion(lhs_unit: &TimeUnit, rhs_unit: &TimeUnit) -> TimeUnit { + use arrow::datatypes::TimeUnit::*; + match (lhs_unit, rhs_unit) { + (Second, Millisecond) => Second, + (Second, Microsecond) => Second, + (Second, Nanosecond) => Second, + (Millisecond, Second) => Second, + (Millisecond, Microsecond) => Millisecond, + (Millisecond, Nanosecond) => Millisecond, + (Microsecond, Second) => Second, + (Microsecond, Millisecond) => Millisecond, + (Microsecond, Nanosecond) => Microsecond, + (Nanosecond, Second) => Second, + (Nanosecond, Millisecond) => Millisecond, + (Nanosecond, Microsecond) => Microsecond, + (l, r) => { + assert_eq!(l, r); + *l } - _ => None, } } -/// coercion rules from NULL type. Since NULL can be casted to any other type in arrow, +/// Coercion rules from NULL type. Since NULL can be casted to any other type in arrow, /// either lhs or rhs is NULL, if NULL can be casted to type of the other side, the coercion is valid. fn null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { match (lhs_type, rhs_type) { @@ -953,38 +1496,50 @@ mod tests { let lhs_type = Dictionary(Box::new(Int8), Box::new(Int32)); let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16)); - assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, true), Some(Int32)); assert_eq!( - dictionary_coercion(&lhs_type, &rhs_type, false), + dictionary_comparison_coercion(&lhs_type, &rhs_type, true), + Some(Int32) + ); + assert_eq!( + dictionary_comparison_coercion(&lhs_type, &rhs_type, false), Some(Int32) ); // Since we can coerce values of Int16 to Utf8 can support this let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16)); - assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, true), Some(Utf8)); + assert_eq!( + dictionary_comparison_coercion(&lhs_type, &rhs_type, true), + Some(Utf8) + ); // Since we can coerce values of Utf8 to Binary can support this let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); let rhs_type = Dictionary(Box::new(Int8), Box::new(Binary)); assert_eq!( - dictionary_coercion(&lhs_type, &rhs_type, true), + dictionary_comparison_coercion(&lhs_type, &rhs_type, true), Some(Binary) ); let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); let rhs_type = Utf8; - assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, false), Some(Utf8)); assert_eq!( - dictionary_coercion(&lhs_type, &rhs_type, true), + dictionary_comparison_coercion(&lhs_type, &rhs_type, false), + Some(Utf8) + ); + assert_eq!( + dictionary_comparison_coercion(&lhs_type, &rhs_type, true), Some(lhs_type.clone()) ); let lhs_type = Utf8; let rhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); - assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, false), Some(Utf8)); assert_eq!( - dictionary_coercion(&lhs_type, &rhs_type, true), + dictionary_comparison_coercion(&lhs_type, &rhs_type, false), + Some(Utf8) + ); + assert_eq!( + dictionary_comparison_coercion(&lhs_type, &rhs_type, true), Some(rhs_type.clone()) ); } @@ -1464,6 +2019,90 @@ mod tests { DataType::LargeBinary ); + // Timestamps + let utc: Option> = Some("UTC".into()); + test_coercion_binary_rule!( + DataType::Timestamp(TimeUnit::Second, utc.clone()), + DataType::Timestamp(TimeUnit::Second, utc.clone()), + Operator::Eq, + DataType::Timestamp(TimeUnit::Second, utc.clone()) + ); + test_coercion_binary_rule!( + DataType::Timestamp(TimeUnit::Second, utc.clone()), + DataType::Timestamp(TimeUnit::Second, Some("Europe/Brussels".into())), + Operator::Eq, + DataType::Timestamp(TimeUnit::Second, utc.clone()) + ); + test_coercion_binary_rule!( + DataType::Timestamp(TimeUnit::Second, Some("America/New_York".into())), + DataType::Timestamp(TimeUnit::Second, Some("Europe/Brussels".into())), + Operator::Eq, + DataType::Timestamp(TimeUnit::Second, Some("America/New_York".into())) + ); + test_coercion_binary_rule!( + DataType::Timestamp(TimeUnit::Second, Some("Europe/Brussels".into())), + DataType::Timestamp(TimeUnit::Second, utc), + Operator::Eq, + DataType::Timestamp(TimeUnit::Second, Some("Europe/Brussels".into())) + ); + + // list + let inner_field = Arc::new(Field::new("item", DataType::Int64, true)); + test_coercion_binary_rule!( + DataType::List(Arc::clone(&inner_field)), + DataType::List(Arc::clone(&inner_field)), + Operator::Eq, + DataType::List(Arc::clone(&inner_field)) + ); + test_coercion_binary_rule!( + DataType::List(Arc::clone(&inner_field)), + DataType::LargeList(Arc::clone(&inner_field)), + Operator::Eq, + DataType::LargeList(Arc::clone(&inner_field)) + ); + test_coercion_binary_rule!( + DataType::LargeList(Arc::clone(&inner_field)), + DataType::List(Arc::clone(&inner_field)), + Operator::Eq, + DataType::LargeList(Arc::clone(&inner_field)) + ); + test_coercion_binary_rule!( + DataType::LargeList(Arc::clone(&inner_field)), + DataType::LargeList(Arc::clone(&inner_field)), + Operator::Eq, + DataType::LargeList(Arc::clone(&inner_field)) + ); + test_coercion_binary_rule!( + DataType::FixedSizeList(Arc::clone(&inner_field), 10), + DataType::FixedSizeList(Arc::clone(&inner_field), 10), + Operator::Eq, + DataType::FixedSizeList(Arc::clone(&inner_field), 10) + ); + test_coercion_binary_rule!( + DataType::FixedSizeList(Arc::clone(&inner_field), 10), + DataType::LargeList(Arc::clone(&inner_field)), + Operator::Eq, + DataType::LargeList(Arc::clone(&inner_field)) + ); + test_coercion_binary_rule!( + DataType::LargeList(Arc::clone(&inner_field)), + DataType::FixedSizeList(Arc::clone(&inner_field), 10), + Operator::Eq, + DataType::LargeList(Arc::clone(&inner_field)) + ); + test_coercion_binary_rule!( + DataType::List(Arc::clone(&inner_field)), + DataType::FixedSizeList(Arc::clone(&inner_field), 10), + Operator::Eq, + DataType::List(Arc::clone(&inner_field)) + ); + test_coercion_binary_rule!( + DataType::FixedSizeList(Arc::clone(&inner_field), 10), + DataType::List(Arc::clone(&inner_field)), + Operator::Eq, + DataType::List(Arc::clone(&inner_field)) + ); + // TODO add other data type Ok(()) } diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index 2759572581ea..d7dc1afe4d50 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -38,13 +38,17 @@ path = "src/lib.rs" [features] [dependencies] -ahash = { version = "0.8", default-features = false, features = [ - "runtime-rng", -] } +ahash = { workspace = true } arrow = { workspace = true } arrow-array = { workspace = true } +arrow-buffer = { workspace = true } chrono = { workspace = true } -datafusion-common = { workspace = true, default-features = true } +datafusion-common = { workspace = true } +datafusion-expr-common = { workspace = true } +datafusion-functions-aggregate-common = { workspace = true } +datafusion-functions-window-common = { workspace = true } +datafusion-physical-expr-common = { workspace = true } +indexmap = { workspace = true } paste = "^1.0" serde_json = { workspace = true } sqlparser = { workspace = true } diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs deleted file mode 100644 index 3dc9c3a01c15..000000000000 --- a/datafusion/expr/src/aggregate_function.rs +++ /dev/null @@ -1,434 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Aggregate function module contains all built-in aggregate functions definitions - -use std::sync::Arc; -use std::{fmt, str::FromStr}; - -use crate::utils; -use crate::{type_coercion::aggregates::*, Signature, TypeSignature, Volatility}; - -use arrow::datatypes::{DataType, Field}; -use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; - -use strum_macros::EnumIter; - -/// Enum of all built-in aggregate functions -// Contributor's guide for adding new aggregate functions -// https://datafusion.apache.org/contributor-guide/index.html#how-to-add-a-new-aggregate-function -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)] -pub enum AggregateFunction { - /// Count - Count, - /// Sum - Sum, - /// Minimum - Min, - /// Maximum - Max, - /// Average - Avg, - /// Median - Median, - /// Approximate distinct function - ApproxDistinct, - /// Aggregation into an array - ArrayAgg, - /// First value in a group according to some ordering - FirstValue, - /// Last value in a group according to some ordering - LastValue, - /// N'th value in a group according to some ordering - NthValue, - /// Variance (Sample) - Variance, - /// Variance (Population) - VariancePop, - /// Standard Deviation (Sample) - Stddev, - /// Standard Deviation (Population) - StddevPop, - /// Covariance (Sample) - Covariance, - /// Covariance (Population) - CovariancePop, - /// Correlation - Correlation, - /// Slope from linear regression - RegrSlope, - /// Intercept from linear regression - RegrIntercept, - /// Number of input rows in which both expressions are not null - RegrCount, - /// R-squared value from linear regression - RegrR2, - /// Average of the independent variable - RegrAvgx, - /// Average of the dependent variable - RegrAvgy, - /// Sum of squares of the independent variable - RegrSXX, - /// Sum of squares of the dependent variable - RegrSYY, - /// Sum of products of pairs of numbers - RegrSXY, - /// Approximate continuous percentile function - ApproxPercentileCont, - /// Approximate continuous percentile function with weight - ApproxPercentileContWithWeight, - /// ApproxMedian - ApproxMedian, - /// Grouping - Grouping, - /// Bit And - BitAnd, - /// Bit Or - BitOr, - /// Bit Xor - BitXor, - /// Bool And - BoolAnd, - /// Bool Or - BoolOr, - /// String aggregation - StringAgg, -} - -impl AggregateFunction { - pub fn name(&self) -> &str { - use AggregateFunction::*; - match self { - Count => "COUNT", - Sum => "SUM", - Min => "MIN", - Max => "MAX", - Avg => "AVG", - Median => "MEDIAN", - ApproxDistinct => "APPROX_DISTINCT", - ArrayAgg => "ARRAY_AGG", - FirstValue => "FIRST_VALUE", - LastValue => "LAST_VALUE", - NthValue => "NTH_VALUE", - Variance => "VAR", - VariancePop => "VAR_POP", - Stddev => "STDDEV", - StddevPop => "STDDEV_POP", - Covariance => "COVAR", - CovariancePop => "COVAR_POP", - Correlation => "CORR", - RegrSlope => "REGR_SLOPE", - RegrIntercept => "REGR_INTERCEPT", - RegrCount => "REGR_COUNT", - RegrR2 => "REGR_R2", - RegrAvgx => "REGR_AVGX", - RegrAvgy => "REGR_AVGY", - RegrSXX => "REGR_SXX", - RegrSYY => "REGR_SYY", - RegrSXY => "REGR_SXY", - ApproxPercentileCont => "APPROX_PERCENTILE_CONT", - ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT", - ApproxMedian => "APPROX_MEDIAN", - Grouping => "GROUPING", - BitAnd => "BIT_AND", - BitOr => "BIT_OR", - BitXor => "BIT_XOR", - BoolAnd => "BOOL_AND", - BoolOr => "BOOL_OR", - StringAgg => "STRING_AGG", - } - } -} - -impl fmt::Display for AggregateFunction { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}", self.name()) - } -} - -impl FromStr for AggregateFunction { - type Err = DataFusionError; - fn from_str(name: &str) -> Result { - Ok(match name { - // general - "avg" => AggregateFunction::Avg, - "bit_and" => AggregateFunction::BitAnd, - "bit_or" => AggregateFunction::BitOr, - "bit_xor" => AggregateFunction::BitXor, - "bool_and" => AggregateFunction::BoolAnd, - "bool_or" => AggregateFunction::BoolOr, - "count" => AggregateFunction::Count, - "max" => AggregateFunction::Max, - "mean" => AggregateFunction::Avg, - "median" => AggregateFunction::Median, - "min" => AggregateFunction::Min, - "sum" => AggregateFunction::Sum, - "array_agg" => AggregateFunction::ArrayAgg, - "first_value" => AggregateFunction::FirstValue, - "last_value" => AggregateFunction::LastValue, - "nth_value" => AggregateFunction::NthValue, - "string_agg" => AggregateFunction::StringAgg, - // statistical - "corr" => AggregateFunction::Correlation, - "covar" => AggregateFunction::Covariance, - "covar_pop" => AggregateFunction::CovariancePop, - "covar_samp" => AggregateFunction::Covariance, - "stddev" => AggregateFunction::Stddev, - "stddev_pop" => AggregateFunction::StddevPop, - "stddev_samp" => AggregateFunction::Stddev, - "var" => AggregateFunction::Variance, - "var_pop" => AggregateFunction::VariancePop, - "var_samp" => AggregateFunction::Variance, - "regr_slope" => AggregateFunction::RegrSlope, - "regr_intercept" => AggregateFunction::RegrIntercept, - "regr_count" => AggregateFunction::RegrCount, - "regr_r2" => AggregateFunction::RegrR2, - "regr_avgx" => AggregateFunction::RegrAvgx, - "regr_avgy" => AggregateFunction::RegrAvgy, - "regr_sxx" => AggregateFunction::RegrSXX, - "regr_syy" => AggregateFunction::RegrSYY, - "regr_sxy" => AggregateFunction::RegrSXY, - // approximate - "approx_distinct" => AggregateFunction::ApproxDistinct, - "approx_median" => AggregateFunction::ApproxMedian, - "approx_percentile_cont" => AggregateFunction::ApproxPercentileCont, - "approx_percentile_cont_with_weight" => { - AggregateFunction::ApproxPercentileContWithWeight - } - // other - "grouping" => AggregateFunction::Grouping, - _ => { - return plan_err!("There is no built-in function named {name}"); - } - }) - } -} - -impl AggregateFunction { - /// Returns the datatype of the aggregate function given its argument types - /// - /// This is used to get the returned data type for aggregate expr. - pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { - // Note that this function *must* return the same type that the respective physical expression returns - // or the execution panics. - - let coerced_data_types = coerce_types(self, input_expr_types, &self.signature()) - // original errors are all related to wrong function signature - // aggregate them for better error message - .map_err(|_| { - plan_datafusion_err!( - "{}", - utils::generate_signature_error_msg( - &format!("{self}"), - self.signature(), - input_expr_types, - ) - ) - })?; - - match self { - AggregateFunction::Count | AggregateFunction::ApproxDistinct => { - Ok(DataType::Int64) - } - AggregateFunction::Max | AggregateFunction::Min => { - // For min and max agg function, the returned type is same as input type. - // The coerced_data_types is same with input_types. - Ok(coerced_data_types[0].clone()) - } - AggregateFunction::Sum => sum_return_type(&coerced_data_types[0]), - AggregateFunction::BitAnd - | AggregateFunction::BitOr - | AggregateFunction::BitXor => Ok(coerced_data_types[0].clone()), - AggregateFunction::BoolAnd | AggregateFunction::BoolOr => { - Ok(DataType::Boolean) - } - AggregateFunction::Variance => variance_return_type(&coerced_data_types[0]), - AggregateFunction::VariancePop => { - variance_return_type(&coerced_data_types[0]) - } - AggregateFunction::Covariance => { - covariance_return_type(&coerced_data_types[0]) - } - AggregateFunction::CovariancePop => { - covariance_return_type(&coerced_data_types[0]) - } - AggregateFunction::Correlation => { - correlation_return_type(&coerced_data_types[0]) - } - AggregateFunction::Stddev => stddev_return_type(&coerced_data_types[0]), - AggregateFunction::StddevPop => stddev_return_type(&coerced_data_types[0]), - AggregateFunction::RegrSlope - | AggregateFunction::RegrIntercept - | AggregateFunction::RegrCount - | AggregateFunction::RegrR2 - | AggregateFunction::RegrAvgx - | AggregateFunction::RegrAvgy - | AggregateFunction::RegrSXX - | AggregateFunction::RegrSYY - | AggregateFunction::RegrSXY => Ok(DataType::Float64), - AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]), - AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new( - "item", - coerced_data_types[0].clone(), - true, - )))), - AggregateFunction::ApproxPercentileCont => Ok(coerced_data_types[0].clone()), - AggregateFunction::ApproxPercentileContWithWeight => { - Ok(coerced_data_types[0].clone()) - } - AggregateFunction::ApproxMedian | AggregateFunction::Median => { - Ok(coerced_data_types[0].clone()) - } - AggregateFunction::Grouping => Ok(DataType::Int32), - AggregateFunction::FirstValue - | AggregateFunction::LastValue - | AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()), - AggregateFunction::StringAgg => Ok(DataType::LargeUtf8), - } - } -} - -/// Returns the internal sum datatype of the avg aggregate function. -pub fn sum_type_of_avg(input_expr_types: &[DataType]) -> Result { - // Note that this function *must* return the same type that the respective physical expression returns - // or the execution panics. - let fun = AggregateFunction::Avg; - let coerced_data_types = crate::type_coercion::aggregates::coerce_types( - &fun, - input_expr_types, - &fun.signature(), - )?; - avg_sum_type(&coerced_data_types[0]) -} - -impl AggregateFunction { - /// the signatures supported by the function `fun`. - pub fn signature(&self) -> Signature { - // note: the physical expression must accept the type returned by this function or the execution panics. - match self { - AggregateFunction::Count => Signature::variadic_any(Volatility::Immutable), - AggregateFunction::ApproxDistinct - | AggregateFunction::Grouping - | AggregateFunction::ArrayAgg => Signature::any(1, Volatility::Immutable), - AggregateFunction::Min | AggregateFunction::Max => { - let valid = STRINGS - .iter() - .chain(NUMERICS.iter()) - .chain(TIMESTAMPS.iter()) - .chain(DATES.iter()) - .chain(TIMES.iter()) - .chain(BINARYS.iter()) - .cloned() - .collect::>(); - Signature::uniform(1, valid, Volatility::Immutable) - } - AggregateFunction::BitAnd - | AggregateFunction::BitOr - | AggregateFunction::BitXor => { - Signature::uniform(1, INTEGERS.to_vec(), Volatility::Immutable) - } - AggregateFunction::BoolAnd | AggregateFunction::BoolOr => { - Signature::uniform(1, vec![DataType::Boolean], Volatility::Immutable) - } - AggregateFunction::Avg - | AggregateFunction::Sum - | AggregateFunction::Variance - | AggregateFunction::VariancePop - | AggregateFunction::Stddev - | AggregateFunction::StddevPop - | AggregateFunction::Median - | AggregateFunction::ApproxMedian - | AggregateFunction::FirstValue - | AggregateFunction::LastValue => { - Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) - } - AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable), - AggregateFunction::Covariance - | AggregateFunction::CovariancePop - | AggregateFunction::Correlation - | AggregateFunction::RegrSlope - | AggregateFunction::RegrIntercept - | AggregateFunction::RegrCount - | AggregateFunction::RegrR2 - | AggregateFunction::RegrAvgx - | AggregateFunction::RegrAvgy - | AggregateFunction::RegrSXX - | AggregateFunction::RegrSYY - | AggregateFunction::RegrSXY => { - Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable) - } - AggregateFunction::ApproxPercentileCont => { - let mut variants = - Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1)); - // Accept any numeric value paired with a float64 percentile - for num in NUMERICS { - variants - .push(TypeSignature::Exact(vec![num.clone(), DataType::Float64])); - // Additionally accept an integer number of centroids for T-Digest - for int in INTEGERS { - variants.push(TypeSignature::Exact(vec![ - num.clone(), - DataType::Float64, - int.clone(), - ])) - } - } - - Signature::one_of(variants, Volatility::Immutable) - } - AggregateFunction::ApproxPercentileContWithWeight => Signature::one_of( - // Accept any numeric value paired with a float64 percentile - NUMERICS - .iter() - .map(|t| { - TypeSignature::Exact(vec![ - t.clone(), - t.clone(), - DataType::Float64, - ]) - }) - .collect(), - Volatility::Immutable, - ), - AggregateFunction::StringAgg => { - Signature::uniform(2, STRINGS.to_vec(), Volatility::Immutable) - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use strum::IntoEnumIterator; - - #[test] - // Test for AggregateFuncion's Display and from_str() implementations. - // For each variant in AggregateFuncion, it converts the variant to a string - // and then back to a variant. The test asserts that the original variant and - // the reconstructed variant are the same. This assertion is also necessary for - // function suggestion. See https://github.com/apache/datafusion/issues/8082 - fn test_display_and_from_str() { - for func_original in AggregateFunction::iter() { - let func_name = func_original.to_string(); - let func_from_str = - AggregateFunction::from_str(func_name.to_lowercase().as_str()).unwrap(); - assert_eq!(func_from_str, func_original); - } - } -} diff --git a/datafusion/expr/src/built_in_window_function.rs b/datafusion/expr/src/built_in_window_function.rs index 1001bbb015ed..ab41395ad371 100644 --- a/datafusion/expr/src/built_in_window_function.rs +++ b/datafusion/expr/src/built_in_window_function.rs @@ -22,7 +22,7 @@ use std::str::FromStr; use crate::type_coercion::functions::data_types; use crate::utils; -use crate::{Signature, TypeSignature, Volatility}; +use crate::{Signature, Volatility}; use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result}; use arrow::datatypes::DataType; @@ -37,53 +37,23 @@ impl fmt::Display for BuiltInWindowFunction { /// A [window function] built in to DataFusion /// -/// [window function]: https://en.wikipedia.org/wiki/Window_function_(SQL) -#[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter)] +/// [Window Function]: https://en.wikipedia.org/wiki/Window_function_(SQL) +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)] pub enum BuiltInWindowFunction { - /// number of the current row within its partition, counting from 1 - RowNumber, - /// rank of the current row with gaps; same as row_number of its first peer - Rank, - /// rank of the current row without gaps; this function counts peer groups - DenseRank, - /// relative rank of the current row: (rank - 1) / (total rows - 1) - PercentRank, - /// relative rank of the current row: (number of rows preceding or peer with current row) / (total rows) - CumeDist, - /// integer ranging from 1 to the argument value, dividing the partition as equally as possible - Ntile, - /// returns value evaluated at the row that is offset rows before the current row within the partition; - /// if there is no such row, instead return default (which must be of the same type as value). - /// Both offset and default are evaluated with respect to the current row. - /// If omitted, offset defaults to 1 and default to null - Lag, - /// returns value evaluated at the row that is offset rows after the current row within the partition; - /// if there is no such row, instead return default (which must be of the same type as value). - /// Both offset and default are evaluated with respect to the current row. - /// If omitted, offset defaults to 1 and default to null - Lead, /// returns value evaluated at the row that is the first row of the window frame FirstValue, - /// returns value evaluated at the row that is the last row of the window frame + /// Returns value evaluated at the row that is the last row of the window frame LastValue, - /// returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row + /// Returns value evaluated at the row that is the nth row of the window frame (counting from 1); returns null if no such row NthValue, } impl BuiltInWindowFunction { - fn name(&self) -> &str { + pub fn name(&self) -> &str { use BuiltInWindowFunction::*; match self { - RowNumber => "ROW_NUMBER", - Rank => "RANK", - DenseRank => "DENSE_RANK", - PercentRank => "PERCENT_RANK", - CumeDist => "CUME_DIST", - Ntile => "NTILE", - Lag => "LAG", - Lead => "LEAD", - FirstValue => "FIRST_VALUE", - LastValue => "LAST_VALUE", + FirstValue => "first_value", + LastValue => "last_value", NthValue => "NTH_VALUE", } } @@ -93,14 +63,6 @@ impl FromStr for BuiltInWindowFunction { type Err = DataFusionError; fn from_str(name: &str) -> Result { Ok(match name.to_uppercase().as_str() { - "ROW_NUMBER" => BuiltInWindowFunction::RowNumber, - "RANK" => BuiltInWindowFunction::Rank, - "DENSE_RANK" => BuiltInWindowFunction::DenseRank, - "PERCENT_RANK" => BuiltInWindowFunction::PercentRank, - "CUME_DIST" => BuiltInWindowFunction::CumeDist, - "NTILE" => BuiltInWindowFunction::Ntile, - "LAG" => BuiltInWindowFunction::Lag, - "LEAD" => BuiltInWindowFunction::Lead, "FIRST_VALUE" => BuiltInWindowFunction::FirstValue, "LAST_VALUE" => BuiltInWindowFunction::LastValue, "NTH_VALUE" => BuiltInWindowFunction::NthValue, @@ -115,10 +77,10 @@ impl BuiltInWindowFunction { // Note that this function *must* return the same type that the respective physical expression returns // or the execution panics. - // verify that this is a valid set of data types for this function + // Verify that this is a valid set of data types for this function data_types(input_expr_types, &self.signature()) - // original errors are all related to wrong function signature - // aggregate them for better error message + // Original errors are all related to wrong function signature + // Aggregate them for better error message .map_err(|_| { plan_datafusion_err!( "{}", @@ -131,57 +93,19 @@ impl BuiltInWindowFunction { })?; match self { - BuiltInWindowFunction::RowNumber - | BuiltInWindowFunction::Rank - | BuiltInWindowFunction::DenseRank - | BuiltInWindowFunction::Ntile => Ok(DataType::UInt64), - BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => { - Ok(DataType::Float64) - } - BuiltInWindowFunction::Lag - | BuiltInWindowFunction::Lead - | BuiltInWindowFunction::FirstValue + BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue | BuiltInWindowFunction::NthValue => Ok(input_expr_types[0].clone()), } } - /// the signatures supported by the built-in window function `fun`. + /// The signatures supported by the built-in window function `fun`. pub fn signature(&self) -> Signature { - // note: the physical expression must accept the type returned by this function or the execution panics. + // Note: The physical expression must accept the type returned by this function or the execution panics. match self { - BuiltInWindowFunction::RowNumber - | BuiltInWindowFunction::Rank - | BuiltInWindowFunction::DenseRank - | BuiltInWindowFunction::PercentRank - | BuiltInWindowFunction::CumeDist => Signature::any(0, Volatility::Immutable), - BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => { - Signature::one_of( - vec![ - TypeSignature::Any(1), - TypeSignature::Any(2), - TypeSignature::Any(3), - ], - Volatility::Immutable, - ) - } BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => { Signature::any(1, Volatility::Immutable) } - BuiltInWindowFunction::Ntile => Signature::uniform( - 1, - vec![ - DataType::UInt64, - DataType::UInt32, - DataType::UInt16, - DataType::UInt8, - DataType::Int64, - DataType::Int32, - DataType::Int16, - DataType::Int8, - ], - Volatility::Immutable, - ), BuiltInWindowFunction::NthValue => Signature::any(2, Volatility::Immutable), } } diff --git a/datafusion/expr/src/conditional_expressions.rs b/datafusion/expr/src/conditional_expressions.rs index 7a2bf4b6c44a..23cc88f1c0ff 100644 --- a/datafusion/expr/src/conditional_expressions.rs +++ b/datafusion/expr/src/conditional_expressions.rs @@ -64,7 +64,7 @@ impl CaseBuilder { } fn build(&self) -> Result { - // collect all "then" expressions + // Collect all "then" expressions let mut then_expr = self.then_expr.clone(); if let Some(e) = &self.else_expr { then_expr.push(e.as_ref().to_owned()); @@ -79,7 +79,7 @@ impl CaseBuilder { .collect::>>()?; if then_types.contains(&DataType::Null) { - // cannot verify types until execution type + // Cannot verify types until execution type } else { let unique_types: HashSet<&DataType> = then_types.iter().collect(); if unique_types.len() != 1 { diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 0d8e8d816b33..a9c183952fc7 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -17,39 +17,80 @@ //! Logical Expressions: [`Expr`] -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::fmt::{self, Display, Formatter, Write}; -use std::hash::Hash; +use std::hash::{Hash, Hasher}; +use std::mem; use std::str::FromStr; use std::sync::Arc; use crate::expr_fn::binary_expr; use crate::logical_plan::Subquery; use crate::utils::expr_to_columns; -use crate::window_frame; +use crate::Volatility; use crate::{ - aggregate_function, built_in_window_function, udaf, ExprSchemable, Operator, - Signature, + udaf, BuiltInWindowFunction, ExprSchemable, Operator, Signature, WindowFrame, + WindowUDF, }; use arrow::datatypes::{DataType, FieldRef}; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::cse::HashNode; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; use datafusion_common::{ - internal_err, plan_err, Column, DFSchema, Result, ScalarValue, TableReference, + plan_err, Column, DFSchema, Result, ScalarValue, TableReference, +}; +use datafusion_functions_window_common::field::WindowUDFFieldArgs; +use sqlparser::ast::{ + display_comma_separated, ExceptSelectItem, ExcludeSelectItem, IlikeSelectItem, + NullTreatment, RenameSelectItem, ReplaceSelectElement, }; -use sqlparser::ast::NullTreatment; -/// `Expr` is a central struct of DataFusion's query API, and -/// represent logical expressions such as `A + 1`, or `CAST(c1 AS -/// int)`. +/// Represents logical expressions such as `A + 1`, or `CAST(c1 AS int)`. +/// +/// For example the expression `A + 1` will be represented as +/// +///```text +/// BinaryExpr { +/// left: Expr::Column("A"), +/// op: Operator::Plus, +/// right: Expr::Literal(ScalarValue::Int32(Some(1))) +/// } +/// ``` +/// +/// # Creating Expressions +/// +/// `Expr`s can be created directly, but it is often easier and less verbose to +/// use the fluent APIs in [`crate::expr_fn`] such as [`col`] and [`lit`], or +/// methods such as [`Expr::alias`], [`Expr::cast_to`], and [`Expr::Like`]). /// -/// An `Expr` can compute its [DataType] -/// and nullability, and has functions for building up complex -/// expressions. +/// See also [`ExprFunctionExt`] for creating aggregate and window functions. +/// +/// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt +/// +/// # Schema Access +/// +/// See [`ExprSchemable::get_type`] to access the [`DataType`] and nullability +/// of an `Expr`. +/// +/// # Visiting and Rewriting `Expr`s +/// +/// The `Expr` struct implements the [`TreeNode`] trait for walking and +/// rewriting expressions. For example [`TreeNode::apply`] recursively visits an +/// `Expr` and [`TreeNode::transform`] can be used to rewrite an expression. See +/// the examples below and [`TreeNode`] for more information. /// /// # Examples /// -/// ## Create an expression `c1` referring to column named "c1" +/// ## Column references and literals +/// +/// [`Expr::Column`] refer to the values of columns and are often created with +/// the [`col`] function. For example to create an expression `c1` referring to +/// column named "c1": +/// +/// [`col`]: crate::expr_fn::col +/// /// ``` /// # use datafusion_common::Column; /// # use datafusion_expr::{lit, col, Expr}; @@ -57,11 +98,33 @@ use sqlparser::ast::NullTreatment; /// assert_eq!(expr, Expr::Column(Column::from_name("c1"))); /// ``` /// -/// ## Create the expression `c1 + c2` to add columns "c1" and "c2" together +/// [`Expr::Literal`] refer to literal, or constant, values. These are created +/// with the [`lit`] function. For example to create an expression `42`: +/// +/// [`lit`]: crate::lit +/// +/// ``` +/// # use datafusion_common::{Column, ScalarValue}; +/// # use datafusion_expr::{lit, col, Expr}; +/// // All literals are strongly typed in DataFusion. To make an `i64` 42: +/// let expr = lit(42i64); +/// assert_eq!(expr, Expr::Literal(ScalarValue::Int64(Some(42)))); +/// // To make a (typed) NULL: +/// let expr = Expr::Literal(ScalarValue::Int64(None)); +/// // to make an (untyped) NULL (the optimizer will coerce this to the correct type): +/// let expr = lit(ScalarValue::Null); +/// ``` +/// +/// ## Binary Expressions +/// +/// Exprs implement traits that allow easy to understand construction of more +/// complex expressions. For example, to create `c1 + c2` to add columns "c1" and +/// "c2" together +/// /// ``` /// # use datafusion_expr::{lit, col, Operator, Expr}; +/// // Use the `+` operator to add two columns together /// let expr = col("c1") + col("c2"); -/// /// assert!(matches!(expr, Expr::BinaryExpr { ..} )); /// if let Expr::BinaryExpr(binary_expr) = expr { /// assert_eq!(*binary_expr.left, col("c1")); @@ -70,12 +133,13 @@ use sqlparser::ast::NullTreatment; /// } /// ``` /// -/// ## Create expression `c1 = 42` to compare the value in column "c1" to the literal value `42` +/// The expression `c1 = 42` to compares the value in column "c1" to the +/// literal value `42`: +/// /// ``` /// # use datafusion_common::ScalarValue; /// # use datafusion_expr::{lit, col, Operator, Expr}; /// let expr = col("c1").eq(lit(42_i32)); -/// /// assert!(matches!(expr, Expr::BinaryExpr { .. } )); /// if let Expr::BinaryExpr(binary_expr) = expr { /// assert_eq!(*binary_expr.left, col("c1")); @@ -85,19 +149,23 @@ use sqlparser::ast::NullTreatment; /// } /// ``` /// -/// ## Return a list of [`Expr::Column`] from a schema's columns +/// Here is how to implement the equivalent of `SELECT *` to select all +/// [`Expr::Column`] from a [`DFSchema`]'s columns: +/// /// ``` /// # use arrow::datatypes::{DataType, Field, Schema}; /// # use datafusion_common::{DFSchema, Column}; /// # use datafusion_expr::Expr; -/// +/// // Create a schema c1(int, c2 float) /// let arrow_schema = Schema::new(vec![ /// Field::new("c1", DataType::Int32, false), /// Field::new("c2", DataType::Float64, false), /// ]); -/// let df_schema = DFSchema::try_from_qualified_schema("t1", &arrow_schema).unwrap(); +/// // DFSchema is a an Arrow schema with optional relation name +/// let df_schema = DFSchema::try_from_qualified_schema("t1", &arrow_schema) +/// .unwrap(); /// -/// // Form a list of expressions for each item in the schema +/// // Form Vec with an expression for each column in the schema /// let exprs: Vec<_> = df_schema.iter() /// .map(Expr::from) /// .collect(); @@ -107,11 +175,61 @@ use sqlparser::ast::NullTreatment; /// Expr::from(Column::from_qualified_name("t1.c2")), /// ]); /// ``` -#[derive(Clone, PartialEq, Eq, Hash, Debug)] +/// +/// # Visiting and Rewriting `Expr`s +/// +/// Here is an example that finds all literals in an `Expr` tree: +/// ``` +/// # use std::collections::{HashSet}; +/// use datafusion_common::ScalarValue; +/// # use datafusion_expr::{col, Expr, lit}; +/// use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +/// // Expression a = 5 AND b = 6 +/// let expr = col("a").eq(lit(5)) & col("b").eq(lit(6)); +/// // find all literals in a HashMap +/// let mut scalars = HashSet::new(); +/// // apply recursively visits all nodes in the expression tree +/// expr.apply(|e| { +/// if let Expr::Literal(scalar) = e { +/// scalars.insert(scalar); +/// } +/// // The return value controls whether to continue visiting the tree +/// Ok(TreeNodeRecursion::Continue) +/// }).unwrap(); +/// // All subtrees have been visited and literals found +/// assert_eq!(scalars.len(), 2); +/// assert!(scalars.contains(&ScalarValue::Int32(Some(5)))); +/// assert!(scalars.contains(&ScalarValue::Int32(Some(6)))); +/// ``` +/// +/// Rewrite an expression, replacing references to column "a" in an +/// to the literal `42`: +/// +/// ``` +/// # use datafusion_common::tree_node::{Transformed, TreeNode}; +/// # use datafusion_expr::{col, Expr, lit}; +/// // expression a = 5 AND b = 6 +/// let expr = col("a").eq(lit(5)).and(col("b").eq(lit(6))); +/// // rewrite all references to column "a" to the literal 42 +/// let rewritten = expr.transform(|e| { +/// if let Expr::Column(c) = &e { +/// if &c.name == "a" { +/// // return Transformed::yes to indicate the node was changed +/// return Ok(Transformed::yes(lit(42))) +/// } +/// } +/// // return Transformed::no to indicate the node was not changed +/// Ok(Transformed::no(e)) +/// }).unwrap(); +/// // The expression has been rewritten +/// assert!(rewritten.transformed); +/// // to 42 = 5 AND b = 6 +/// assert_eq!(rewritten.data, lit(42).eq(lit(5)).and(col("b").eq(lit(6)))); +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub enum Expr { /// An expression with a specific name. Alias(Alias), - /// A named reference to a qualified filed in a schema. + /// A named reference to a qualified field in a schema. Column(Column), /// A named reference to a variable in a registry. ScalarVariable(DataType, Vec), @@ -143,28 +261,29 @@ pub enum Expr { IsNotUnknown(Box), /// arithmetic negation of an expression, the operand must be of a signed numeric data type Negative(Box), - /// Returns the field of a [`arrow::array::ListArray`] or - /// [`arrow::array::StructArray`] by index or range - GetIndexedField(GetIndexedField), /// Whether an expression is between a given range. Between(Between), /// The CASE expression is similar to a series of nested if/else and there are two forms that /// can be used. The first form consists of a series of boolean "when" expressions with /// corresponding "then" expressions, and an optional "else" expression. /// + /// ```text /// CASE WHEN condition THEN result /// [WHEN ...] /// [ELSE result] /// END + /// ``` /// /// The second form uses a base expression and then a series of "when" clauses that match on a /// literal value. /// + /// ```text /// CASE expression /// WHEN value THEN result /// [WHEN ...] /// [ELSE result] /// END + /// ``` Case(Case), /// Casts the expression to a given type and will return a runtime error if the expression cannot be cast. /// This expression is guaranteed to have a fixed type. @@ -172,11 +291,14 @@ pub enum Expr { /// Casts the expression to a given type and will return a null value if the expression cannot be cast. /// This expression is guaranteed to have a fixed type. TryCast(TryCast), - /// A sort expression, that can be used to sort values. - Sort(Sort), /// Represents the call of a scalar function with a set of arguments. ScalarFunction(ScalarFunction), - /// Represents the call of an aggregate built-in function with arguments. + /// Calls an aggregate function with arguments, and optional + /// `ORDER BY`, `FILTER`, `DISTINCT` and `NULL TREATMENT`. + /// + /// See also [`ExprFunctionExt`] to set these fields. + /// + /// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt AggregateFunction(AggregateFunction), /// Represents the call of a window function with arguments. WindowFunction(WindowFunction), @@ -193,7 +315,10 @@ pub enum Expr { /// /// This expr has to be resolved to a list of columns before translating logical /// plan into physical plan. - Wildcard { qualifier: Option }, + Wildcard { + qualifier: Option, + options: WildcardOptions, + }, /// List of grouping set expressions. Only valid in the context of an aggregate /// GROUP BY expression list GroupingSet(GroupingSet), @@ -230,7 +355,8 @@ impl<'a> From<(Option<&'a TableReference>, &'a FieldRef)> for Expr { } } -#[derive(Clone, PartialEq, Eq, Hash, Debug)] +/// UNNEST expression. +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct Unnest { pub expr: Box, } @@ -242,10 +368,15 @@ impl Unnest { expr: Box::new(expr), } } + + /// Create a new Unnest expression. + pub fn new_boxed(boxed: Box) -> Self { + Self { expr: boxed } + } } /// Alias expression -#[derive(Clone, PartialEq, Eq, Hash, Debug)] +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct Alias { pub expr: Box, pub relation: Option, @@ -268,7 +399,7 @@ impl Alias { } /// Binary expression -#[derive(Clone, PartialEq, Eq, Hash, Debug)] +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct BinaryExpr { /// Left-hand side of the expression pub left: Box, @@ -319,7 +450,7 @@ impl Display for BinaryExpr { } /// CASE expression -#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Hash)] pub struct Case { /// Optional base expression that can be compared to literal values in the "when" expressions pub expr: Option>, @@ -345,7 +476,7 @@ impl Case { } /// LIKE expression -#[derive(Clone, PartialEq, Eq, Hash, Debug)] +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct Like { pub negated: bool, pub expr: Box, @@ -375,7 +506,7 @@ impl Like { } /// BETWEEN expression -#[derive(Clone, PartialEq, Eq, Hash, Debug)] +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct Between { /// The value to compare pub expr: Box, @@ -399,21 +530,11 @@ impl Between { } } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -/// Defines which implementation of a function for DataFusion to call. -pub enum ScalarFunctionDefinition { - /// Resolved to a user defined function - UDF(Arc), - /// A scalar function constructed with name. This variant can not be executed directly - /// and instead must be resolved to one of the other variants prior to physical planning. - Name(Arc), -} - /// ScalarFunction expression invokes a built-in scalar function -#[derive(Clone, PartialEq, Eq, Hash, Debug)] +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct ScalarFunction { /// The function - pub func_def: ScalarFunctionDefinition, + pub func: Arc, /// List of expressions to feed to the functions as arguments pub args: Vec, } @@ -421,47 +542,14 @@ pub struct ScalarFunction { impl ScalarFunction { // return the Function's name pub fn name(&self) -> &str { - self.func_def.name() - } -} - -impl ScalarFunctionDefinition { - /// Function's name for display - pub fn name(&self) -> &str { - match self { - ScalarFunctionDefinition::UDF(udf) => udf.name(), - ScalarFunctionDefinition::Name(func_name) => func_name.as_ref(), - } - } - - /// Whether this function is volatile, i.e. whether it can return different results - /// when evaluated multiple times with the same input. - pub fn is_volatile(&self) -> Result { - match self { - ScalarFunctionDefinition::UDF(udf) => { - Ok(udf.signature().volatility == crate::Volatility::Volatile) - } - ScalarFunctionDefinition::Name(func) => { - internal_err!( - "Cannot determine volatility of unresolved function: {func}" - ) - } - } + self.func.name() } } impl ScalarFunction { /// Create a new ScalarFunction expression with a user-defined function (UDF) pub fn new_udf(udf: Arc, args: Vec) -> Self { - Self { - func_def: ScalarFunctionDefinition::UDF(udf), - args, - } - } - - /// Create a new ScalarFunction expression with a user-defined function (UDF) - pub fn new_func_def(func_def: ScalarFunctionDefinition, args: Vec) -> Self { - Self { func_def, args } + Self { func: udf, args } } } @@ -480,26 +568,8 @@ pub enum GetFieldAccess { }, } -/// Returns the field of a [`arrow::array::ListArray`] or -/// [`arrow::array::StructArray`] by `key`. See [`GetFieldAccess`] for -/// details. -#[derive(Clone, PartialEq, Eq, Hash, Debug)] -pub struct GetIndexedField { - /// The expression to take the field from - pub expr: Box, - /// The name of the field to take - pub field: GetFieldAccess, -} - -impl GetIndexedField { - /// Create a new GetIndexedField expression - pub fn new(expr: Box, field: GetFieldAccess) -> Self { - Self { expr, field } - } -} - /// Cast expression -#[derive(Clone, PartialEq, Eq, Hash, Debug)] +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct Cast { /// The expression being cast pub expr: Box, @@ -515,7 +585,7 @@ impl Cast { } /// TryCast Expression -#[derive(Clone, PartialEq, Eq, Hash, Debug)] +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct TryCast { /// The expression being cast pub expr: Box, @@ -531,10 +601,10 @@ impl TryCast { } /// SORT expression -#[derive(Clone, PartialEq, Eq, Hash, Debug)] +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct Sort { /// The expression to sort on - pub expr: Box, + pub expr: Expr, /// The direction of the sort pub asc: bool, /// Whether to put Nulls before all other data values @@ -543,42 +613,50 @@ pub struct Sort { impl Sort { /// Create a new Sort expression - pub fn new(expr: Box, asc: bool, nulls_first: bool) -> Self { + pub fn new(expr: Expr, asc: bool, nulls_first: bool) -> Self { Self { expr, asc, nulls_first, } } -} -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -/// Defines which implementation of an aggregate function DataFusion should call. -pub enum AggregateFunctionDefinition { - BuiltIn(aggregate_function::AggregateFunction), - /// Resolved to a user defined aggregate function - UDF(Arc), - /// A aggregation function constructed with name. This variant can not be executed directly - /// and instead must be resolved to one of the other variants prior to physical planning. - Name(Arc), + /// Create a new Sort expression with the opposite sort direction + pub fn reverse(&self) -> Self { + Self { + expr: self.expr.clone(), + asc: !self.asc, + nulls_first: !self.nulls_first, + } + } } -impl AggregateFunctionDefinition { - /// Function's name for display - pub fn name(&self) -> &str { - match self { - AggregateFunctionDefinition::BuiltIn(fun) => fun.name(), - AggregateFunctionDefinition::UDF(udf) => udf.name(), - AggregateFunctionDefinition::Name(func_name) => func_name.as_ref(), +impl Display for Sort { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.expr)?; + if self.asc { + write!(f, " ASC")?; + } else { + write!(f, " DESC")?; + } + if self.nulls_first { + write!(f, " NULLS FIRST")?; + } else { + write!(f, " NULLS LAST")?; } + Ok(()) } } /// Aggregate function -#[derive(Clone, PartialEq, Eq, Hash, Debug)] +/// +/// See also [`ExprFunctionExt`] to set these fields on `Expr` +/// +/// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct AggregateFunction { /// Name of the function - pub func_def: AggregateFunctionDefinition, + pub func: Arc, /// List of expressions to feed to the functions as arguments pub args: Vec, /// Whether this is a DISTINCT aggregation or not @@ -586,40 +664,22 @@ pub struct AggregateFunction { /// Optional filter pub filter: Option>, /// Optional ordering - pub order_by: Option>, + pub order_by: Option>, pub null_treatment: Option, } impl AggregateFunction { - pub fn new( - fun: aggregate_function::AggregateFunction, - args: Vec, - distinct: bool, - filter: Option>, - order_by: Option>, - null_treatment: Option, - ) -> Self { - Self { - func_def: AggregateFunctionDefinition::BuiltIn(fun), - args, - distinct, - filter, - order_by, - null_treatment, - } - } - /// Create a new AggregateFunction expression with a user-defined function (UDF) pub fn new_udf( - udf: Arc, + func: Arc, args: Vec, distinct: bool, filter: Option>, - order_by: Option>, + order_by: Option>, null_treatment: Option, ) -> Self { Self { - func_def: AggregateFunctionDefinition::UDF(udf), + func, args, distinct, filter, @@ -630,60 +690,109 @@ impl AggregateFunction { } /// WindowFunction -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] /// Defines which implementation of an aggregate function DataFusion should call. pub enum WindowFunctionDefinition { /// A built in aggregate function that leverages an aggregate function - AggregateFunction(aggregate_function::AggregateFunction), /// A a built-in window function - BuiltInWindowFunction(built_in_window_function::BuiltInWindowFunction), + BuiltInWindowFunction(BuiltInWindowFunction), /// A user defined aggregate function AggregateUDF(Arc), /// A user defined aggregate function - WindowUDF(Arc), + WindowUDF(Arc), } impl WindowFunctionDefinition { /// Returns the datatype of the window function - pub fn return_type(&self, input_expr_types: &[DataType]) -> Result { + pub fn return_type( + &self, + input_expr_types: &[DataType], + _input_expr_nullable: &[bool], + display_name: &str, + ) -> Result { match self { - WindowFunctionDefinition::AggregateFunction(fun) => { - fun.return_type(input_expr_types) - } WindowFunctionDefinition::BuiltInWindowFunction(fun) => { fun.return_type(input_expr_types) } WindowFunctionDefinition::AggregateUDF(fun) => { fun.return_type(input_expr_types) } - WindowFunctionDefinition::WindowUDF(fun) => fun.return_type(input_expr_types), + WindowFunctionDefinition::WindowUDF(fun) => fun + .field(WindowUDFFieldArgs::new(input_expr_types, display_name)) + .map(|field| field.data_type().clone()), } } - /// the signatures supported by the function `fun`. + /// The signatures supported by the function `fun`. pub fn signature(&self) -> Signature { match self { - WindowFunctionDefinition::AggregateFunction(fun) => fun.signature(), WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.signature(), WindowFunctionDefinition::AggregateUDF(fun) => fun.signature().clone(), WindowFunctionDefinition::WindowUDF(fun) => fun.signature().clone(), } } + + /// Function's name for display + pub fn name(&self) -> &str { + match self { + WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.name(), + WindowFunctionDefinition::WindowUDF(fun) => fun.name(), + WindowFunctionDefinition::AggregateUDF(fun) => fun.name(), + } + } } -impl fmt::Display for WindowFunctionDefinition { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +impl Display for WindowFunctionDefinition { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self { - WindowFunctionDefinition::AggregateFunction(fun) => fun.fmt(f), - WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.fmt(f), - WindowFunctionDefinition::AggregateUDF(fun) => std::fmt::Debug::fmt(fun, f), - WindowFunctionDefinition::WindowUDF(fun) => fun.fmt(f), + WindowFunctionDefinition::BuiltInWindowFunction(fun) => Display::fmt(fun, f), + WindowFunctionDefinition::AggregateUDF(fun) => Display::fmt(fun, f), + WindowFunctionDefinition::WindowUDF(fun) => Display::fmt(fun, f), } } } +impl From for WindowFunctionDefinition { + fn from(value: BuiltInWindowFunction) -> Self { + Self::BuiltInWindowFunction(value) + } +} + +impl From> for WindowFunctionDefinition { + fn from(value: Arc) -> Self { + Self::AggregateUDF(value) + } +} + +impl From> for WindowFunctionDefinition { + fn from(value: Arc) -> Self { + Self::WindowUDF(value) + } +} + /// Window function -#[derive(Clone, PartialEq, Eq, Hash, Debug)] +/// +/// Holds the actual actual function to call [`WindowFunction`] as well as its +/// arguments (`args`) and the contents of the `OVER` clause: +/// +/// 1. `PARTITION BY` +/// 2. `ORDER BY` +/// 3. Window frame (e.g. `ROWS 1 PRECEDING AND 1 FOLLOWING`) +/// +/// # Example +/// ``` +/// # use datafusion_expr::{Expr, BuiltInWindowFunction, col, ExprFunctionExt}; +/// # use datafusion_expr::expr::WindowFunction; +/// // Create FIRST_VALUE(a) OVER (PARTITION BY b ORDER BY c) +/// let expr = Expr::WindowFunction( +/// WindowFunction::new(BuiltInWindowFunction::FirstValue, vec![col("a")]) +/// ) +/// .partition_by(vec![col("b")]) +/// .order_by(vec![col("b").sort(true, true)]) +/// .build() +/// .unwrap(); +/// ``` +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct WindowFunction { /// Name of the function pub fun: WindowFunctionDefinition, @@ -692,30 +801,24 @@ pub struct WindowFunction { /// List of partition by expressions pub partition_by: Vec, /// List of order by expressions - pub order_by: Vec, + pub order_by: Vec, /// Window frame - pub window_frame: window_frame::WindowFrame, + pub window_frame: WindowFrame, /// Specifies how NULL value is treated: ignore or respect pub null_treatment: Option, } impl WindowFunction { - /// Create a new Window expression - pub fn new( - fun: WindowFunctionDefinition, - args: Vec, - partition_by: Vec, - order_by: Vec, - window_frame: window_frame::WindowFrame, - null_treatment: Option, - ) -> Self { + /// Create a new Window expression with the specified argument an + /// empty `OVER` clause + pub fn new(fun: impl Into, args: Vec) -> Self { Self { - fun, + fun: fun.into(), args, - partition_by, - order_by, - window_frame, - null_treatment, + partition_by: Vec::default(), + order_by: Vec::default(), + window_frame: WindowFrame::new(None), + null_treatment: None, } } } @@ -728,25 +831,19 @@ pub fn find_df_window_func(name: &str) -> Option { // may have different implementations for these cases. If the sought // function is not found among built-in window functions, we search for // it among aggregate functions. - if let Ok(built_in_function) = - built_in_window_function::BuiltInWindowFunction::from_str(name.as_str()) - { + if let Ok(built_in_function) = BuiltInWindowFunction::from_str(name.as_str()) { Some(WindowFunctionDefinition::BuiltInWindowFunction( built_in_function, )) - } else if let Ok(aggregate) = - aggregate_function::AggregateFunction::from_str(name.as_str()) - { - Some(WindowFunctionDefinition::AggregateFunction(aggregate)) } else { None } } -// Exists expression. -#[derive(Clone, PartialEq, Eq, Hash, Debug)] +/// EXISTS expression +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct Exists { - /// subquery that will produce a single column of data + /// Subquery that will produce a single column of data pub subquery: Subquery, /// Whether the expression is negated pub negated: bool, @@ -759,6 +856,9 @@ impl Exists { } } +/// User Defined Aggregate Function +/// +/// See [`udaf::AggregateUDF`] for more information. #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct AggregateUDF { /// The function @@ -789,7 +889,7 @@ impl AggregateUDF { } /// InList expression -#[derive(Clone, PartialEq, Eq, Hash, Debug)] +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct InList { /// The expression to compare pub expr: Box, @@ -811,7 +911,7 @@ impl InList { } /// IN subquery -#[derive(Clone, PartialEq, Eq, Hash, Debug)] +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct InSubquery { /// The expression to compare pub expr: Box, @@ -836,7 +936,7 @@ impl InSubquery { /// /// The type of these parameters is inferred using [`Expr::infer_placeholder_types`] /// or can be specified directly using `PREPARE` statements. -#[derive(Clone, PartialEq, Eq, Hash, Debug)] +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct Placeholder { /// The identifier of the parameter, including the leading `$` (e.g, `"$1"` or `"$foo"`) pub id: String, @@ -852,11 +952,12 @@ impl Placeholder { } /// Grouping sets +/// /// See /// for Postgres definition. /// See /// for Apache Spark definition. -#[derive(Clone, PartialEq, Eq, Hash, Debug)] +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub enum GroupingSet { /// Rollup grouping sets Rollup(Vec), @@ -870,15 +971,16 @@ impl GroupingSet { /// Return all distinct exprs in the grouping set. For `CUBE` and `ROLLUP` this /// is just the underlying list of exprs. For `GROUPING SET` we need to deduplicate /// the exprs in the underlying sets. - pub fn distinct_expr(&self) -> Vec { + pub fn distinct_expr(&self) -> Vec<&Expr> { match self { - GroupingSet::Rollup(exprs) => exprs.clone(), - GroupingSet::Cube(exprs) => exprs.clone(), + GroupingSet::Rollup(exprs) | GroupingSet::Cube(exprs) => { + exprs.iter().collect() + } GroupingSet::GroupingSets(groups) => { - let mut exprs: Vec = vec![]; + let mut exprs: Vec<&Expr> = vec![]; for exp in groups.iter().flatten() { - if !exprs.contains(exp) { - exprs.push(exp.clone()); + if !exprs.contains(&exp) { + exprs.push(exp); } } exprs @@ -887,26 +989,136 @@ impl GroupingSet { } } -/// Fixed seed for the hashing so that Ords are consistent across runs -const SEED: ahash::RandomState = ahash::RandomState::with_seeds(0, 0, 0, 0); +/// Additional options for wildcards, e.g. Snowflake `EXCLUDE`/`RENAME` and Bigquery `EXCEPT`. +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug, Default)] +pub struct WildcardOptions { + /// `[ILIKE...]`. + /// Snowflake syntax: + pub ilike: Option, + /// `[EXCLUDE...]`. + /// Snowflake syntax: + pub exclude: Option, + /// `[EXCEPT...]`. + /// BigQuery syntax: + /// Clickhouse syntax: + pub except: Option, + /// `[REPLACE]` + /// BigQuery syntax: + /// Clickhouse syntax: + /// Snowflake syntax: + pub replace: Option, + /// `[RENAME ...]`. + /// Snowflake syntax: + pub rename: Option, +} + +impl WildcardOptions { + pub fn with_replace(self, replace: PlannedReplaceSelectItem) -> Self { + WildcardOptions { + ilike: self.ilike, + exclude: self.exclude, + except: self.except, + replace: Some(replace), + rename: self.rename, + } + } +} + +impl Display for WildcardOptions { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + if let Some(ilike) = &self.ilike { + write!(f, " {ilike}")?; + } + if let Some(exclude) = &self.exclude { + write!(f, " {exclude}")?; + } + if let Some(except) = &self.except { + write!(f, " {except}")?; + } + if let Some(replace) = &self.replace { + write!(f, " {replace}")?; + } + if let Some(rename) = &self.rename { + write!(f, " {rename}")?; + } + Ok(()) + } +} + +/// The planned expressions for `REPLACE` +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug, Default)] +pub struct PlannedReplaceSelectItem { + /// The original ast nodes + pub items: Vec, + /// The expression planned from the ast nodes. They will be used when expanding the wildcard. + pub planned_expressions: Vec, +} + +impl Display for PlannedReplaceSelectItem { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "REPLACE")?; + write!(f, " ({})", display_comma_separated(&self.items))?; + Ok(()) + } +} -impl PartialOrd for Expr { - fn partial_cmp(&self, other: &Self) -> Option { - let s = SEED.hash_one(self); - let o = SEED.hash_one(other); +impl PlannedReplaceSelectItem { + pub fn items(&self) -> &[ReplaceSelectElement] { + &self.items + } - Some(s.cmp(&o)) + pub fn expressions(&self) -> &[Expr] { + &self.planned_expressions } } impl Expr { - /// Returns the name of this expression as it should appear in a schema. This name - /// will not include any CAST expressions. + #[deprecated(since = "40.0.0", note = "use schema_name instead")] pub fn display_name(&self) -> Result { - create_name(self) + Ok(self.schema_name().to_string()) + } + + /// The name of the column (field) that this `Expr` will produce. + /// + /// For example, for a projection (e.g. `SELECT `) the resulting arrow + /// [`Schema`] will have a field with this name. + /// + /// Note that the resulting string is subtlety different from the `Display` + /// representation for certain `Expr`. Some differences: + /// + /// 1. [`Expr::Alias`], which shows only the alias itself + /// 2. [`Expr::Cast`] / [`Expr::TryCast`], which only displays the expression + /// + /// # Example + /// ``` + /// # use datafusion_expr::{col, lit}; + /// let expr = col("foo").eq(lit(42)); + /// assert_eq!("foo = Int32(42)", expr.schema_name().to_string()); + /// + /// let expr = col("foo").alias("bar").eq(lit(11)); + /// assert_eq!("bar = Int32(11)", expr.schema_name().to_string()); + /// ``` + /// + /// [`Schema`]: arrow::datatypes::Schema + pub fn schema_name(&self) -> impl Display + '_ { + SchemaDisplay(self) + } + + /// Returns the qualifier and the schema name of this expression. + /// + /// Used when the expression forms the output field of a certain plan. + /// The result is the field's qualifier and field name in the plan's + /// output schema. We can use this qualified name to reference the field. + pub fn qualified_name(&self) -> (Option, String) { + match self { + Expr::Column(Column { relation, name }) => (relation.clone(), name.clone()), + Expr::Alias(Alias { relation, name, .. }) => (relation.clone(), name.clone()), + _ => (None, self.schema_name().to_string()), + } } /// Returns a full and complete string representation of this expression. + #[deprecated(note = "use format! instead")] pub fn canonical_name(&self) -> String { format!("{self}") } @@ -924,7 +1136,6 @@ impl Expr { Expr::Column(..) => "Column", Expr::OuterReferenceColumn(_, _) => "Outer", Expr::Exists { .. } => "Exists", - Expr::GetIndexedField { .. } => "GetIndexedField", Expr::GroupingSet(..) => "GroupingSet", Expr::InList { .. } => "InList", Expr::InSubquery(..) => "InSubquery", @@ -945,7 +1156,6 @@ impl Expr { Expr::ScalarFunction(..) => "ScalarFunction", Expr::ScalarSubquery { .. } => "ScalarSubquery", Expr::ScalarVariable(..) => "ScalarVariable", - Expr::Sort { .. } => "Sort", Expr::TryCast { .. } => "TryCast", Expr::WindowFunction { .. } => "WindowFunction", Expr::Wildcard { .. } => "Wildcard", @@ -1031,21 +1241,15 @@ impl Expr { Expr::Like(Like::new(true, Box::new(self), Box::new(other), None, true)) } - /// Return the name to use for the specific Expr, recursing into - /// `Expr::Sort` as appropriate + /// Return the name to use for the specific Expr pub fn name_for_alias(&self) -> Result { - match self { - // call Expr::display_name() on a Expr::Sort will throw an error - Expr::Sort(Sort { expr, .. }) => expr.name_for_alias(), - expr => expr.display_name(), - } + Ok(self.schema_name().to_string()) } /// Ensure `expr` has the name as `original_name` by adding an /// alias if necessary. pub fn alias_if_changed(self, original_name: String) -> Result { let new_name = self.name_for_alias()?; - if new_name == original_name { return Ok(self); } @@ -1055,14 +1259,7 @@ impl Expr { /// Return `self AS name` alias expression pub fn alias(self, name: impl Into) -> Expr { - match self { - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => Expr::Sort(Sort::new(Box::new(expr.alias(name)), asc, nulls_first)), - _ => Expr::Alias(Alias::new(self, None::<&str>, name.into())), - } + Expr::Alias(Alias::new(self, None::<&str>, name.into())) } /// Return `self AS name` alias expression with a specific qualifier @@ -1071,21 +1268,29 @@ impl Expr { relation: Option>, name: impl Into, ) -> Expr { - match self { - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => Expr::Sort(Sort::new( - Box::new(expr.alias_qualified(relation, name)), - asc, - nulls_first, - )), - _ => Expr::Alias(Alias::new(self, relation, name.into())), - } + Expr::Alias(Alias::new(self, relation, name.into())) } /// Remove an alias from an expression if one exists. + /// + /// If the expression is not an alias, the expression is returned unchanged. + /// This method does not remove aliases from nested expressions. + /// + /// # Example + /// ``` + /// # use datafusion_expr::col; + /// // `foo as "bar"` is unaliased to `foo` + /// let expr = col("foo").alias("bar"); + /// assert_eq!(expr.unalias(), col("foo")); + /// + /// // `foo as "bar" + baz` is not unaliased + /// let expr = col("foo").alias("bar") + col("baz"); + /// assert_eq!(expr.clone().unalias(), expr); + /// + /// // `foo as "bar" as "baz" is unalaised to foo as "bar" + /// let expr = col("foo").alias("bar").alias("baz"); + /// assert_eq!(expr.unalias(), col("foo").alias("bar")); + /// ``` pub fn unalias(self) -> Expr { match self { Expr::Alias(alias) => *alias.expr, @@ -1093,6 +1298,55 @@ impl Expr { } } + /// Recursively removed potentially multiple aliases from an expression. + /// + /// This method removes nested aliases and returns [`Transformed`] + /// to signal if the expression was changed. + /// + /// # Example + /// ``` + /// # use datafusion_expr::col; + /// // `foo as "bar"` is unaliased to `foo` + /// let expr = col("foo").alias("bar"); + /// assert_eq!(expr.unalias_nested().data, col("foo")); + /// + /// // `foo as "bar" + baz` is unaliased + /// let expr = col("foo").alias("bar") + col("baz"); + /// assert_eq!(expr.clone().unalias_nested().data, col("foo") + col("baz")); + /// + /// // `foo as "bar" as "baz" is unalaised to foo + /// let expr = col("foo").alias("bar").alias("baz"); + /// assert_eq!(expr.unalias_nested().data, col("foo")); + /// ``` + pub fn unalias_nested(self) -> Transformed { + self.transform_down_up( + |expr| { + // f_down: skip subqueries. Check in f_down to avoid recursing into them + let recursion = if matches!( + expr, + Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::InSubquery(_) + ) { + // Subqueries could contain aliases so don't recurse into those + TreeNodeRecursion::Jump + } else { + TreeNodeRecursion::Continue + }; + Ok(Transformed::new(expr, false, recursion)) + }, + |expr| { + // f_up: unalias on up so we can remove nested aliases like + // `(x as foo) as bar` + if let Expr::Alias(Alias { expr, .. }) = expr { + Ok(Transformed::yes(*expr)) + } else { + Ok(Transformed::no(expr)) + } + }, + ) + // Unreachable code: internal closure doesn't return err + .unwrap() + } + /// Return `self IN ` if `negated` is false, otherwise /// return `self NOT IN `.a pub fn in_list(self, list: Vec, negated: bool) -> Expr { @@ -1109,14 +1363,14 @@ impl Expr { Expr::IsNotNull(Box::new(self)) } - /// Create a sort expression from an existing expression. + /// Create a sort configuration from an existing expression. /// /// ``` /// # use datafusion_expr::col; /// let sort_expr = col("foo").sort(true, true); // SORT ASC NULLS_FIRST /// ``` - pub fn sort(self, asc: bool, nulls_first: bool) -> Expr { - Expr::Sort(Sort::new(Box::new(self), asc, nulls_first)) + pub fn sort(self, asc: bool, nulls_first: bool) -> Sort { + Sort::new(self, asc, nulls_first) } /// Return `IsTrue(Box(self))` @@ -1159,7 +1413,7 @@ impl Expr { )) } - /// return `self NOT BETWEEN low AND high` + /// Return `self NOT BETWEEN low AND high` pub fn not_between(self, low: Expr, high: Expr) -> Expr { Expr::Between(Between::new( Box::new(self), @@ -1169,99 +1423,58 @@ impl Expr { )) } - /// Return access to the named field. Example `expr["name"]` - /// - /// ## Access field "my_field" from column "c1" - /// - /// For example if column "c1" holds documents like this - /// - /// ```json - /// { - /// "my_field": 123.34, - /// "other_field": "Boston", - /// } - /// ``` - /// - /// You can access column "my_field" with - /// - /// ``` - /// # use datafusion_expr::{col}; - /// let expr = col("c1") - /// .field("my_field"); - /// assert_eq!(expr.display_name().unwrap(), "c1[my_field]"); - /// ``` - pub fn field(self, name: impl Into) -> Self { - Expr::GetIndexedField(GetIndexedField { - expr: Box::new(self), - field: GetFieldAccess::NamedStructField { - name: ScalarValue::from(name.into()), - }, - }) + #[deprecated(since = "39.0.0", note = "use try_as_col instead")] + pub fn try_into_col(&self) -> Result { + match self { + Expr::Column(it) => Ok(it.clone()), + _ => plan_err!("Could not coerce '{self}' into Column!"), + } } - /// Return access to the element field. Example `expr["name"]` + /// Return a reference to the inner `Column` if any /// - /// ## Example Access element 2 from column "c1" + /// returns `None` if the expression is not a `Column` /// - /// For example if column "c1" holds documents like this + /// Note: None may be returned for expressions that are not `Column` but + /// are convertible to `Column` such as `Cast` expressions. /// - /// ```json - /// [10, 20, 30, 40] + /// Example /// ``` + /// # use datafusion_common::Column; + /// use datafusion_expr::{col, Expr}; + /// let expr = col("foo"); + /// assert_eq!(expr.try_as_col(), Some(&Column::from("foo"))); /// - /// You can access the value "30" with - /// + /// let expr = col("foo").alias("bar"); + /// assert_eq!(expr.try_as_col(), None); /// ``` - /// # use datafusion_expr::{lit, col, Expr}; - /// let expr = col("c1") - /// .index(lit(3)); - /// assert_eq!(expr.display_name().unwrap(), "c1[Int32(3)]"); - /// ``` - pub fn index(self, key: Expr) -> Self { - Expr::GetIndexedField(GetIndexedField { - expr: Box::new(self), - field: GetFieldAccess::ListIndex { key: Box::new(key) }, - }) + pub fn try_as_col(&self) -> Option<&Column> { + if let Expr::Column(it) = self { + Some(it) + } else { + None + } } - /// Return elements between `1` based `start` and `stop`, for - /// example `expr[1:3]` - /// - /// ## Example: Access element 2, 3, 4 from column "c1" - /// - /// For example if column "c1" holds documents like this - /// - /// ```json - /// [10, 20, 30, 40] - /// ``` - /// - /// You can access the value `[20, 30, 40]` with + /// Returns the inner `Column` if any. This is a specialized version of + /// [`Self::try_as_col`] that take Cast expressions into account when the + /// expression is as on condition for joins. /// - /// ``` - /// # use datafusion_expr::{lit, col}; - /// let expr = col("c1") - /// .range(lit(2), lit(4)); - /// assert_eq!(expr.display_name().unwrap(), "c1[Int32(2):Int32(4):Int64(1)]"); - /// ``` - pub fn range(self, start: Expr, stop: Expr) -> Self { - Expr::GetIndexedField(GetIndexedField { - expr: Box::new(self), - field: GetFieldAccess::ListRange { - start: Box::new(start), - stop: Box::new(stop), - stride: Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))), - }, - }) - } - - pub fn try_into_col(&self) -> Result { + /// Called this method when you are sure that the expression is a `Column` + /// or a `Cast` expression that wraps a `Column`. + pub fn get_as_join_column(&self) -> Option<&Column> { match self { - Expr::Column(it) => Ok(it.clone()), - _ => plan_err!("Could not coerce '{self}' into Column!"), + Expr::Column(c) => Some(c), + Expr::Cast(Cast { expr, .. }) => match &**expr { + Expr::Column(c) => Some(c), + _ => None, + }, + _ => None, } } /// Return all referenced columns of this expression. + #[deprecated(since = "40.0.0", note = "use Expr::column_refs instead")] pub fn to_columns(&self) -> Result> { let mut using_columns = HashSet::new(); expr_to_columns(self, &mut using_columns)?; @@ -1269,18 +1482,105 @@ impl Expr { Ok(using_columns) } - /// Return true when the expression contains out reference(correlated) expressions. - pub fn contains_outer(&self) -> bool { - self.exists(|expr| Ok(matches!(expr, Expr::OuterReferenceColumn { .. }))) - .unwrap() + /// Return all references to columns in this expression. + /// + /// # Example + /// ``` + /// # use std::collections::HashSet; + /// # use datafusion_common::Column; + /// # use datafusion_expr::col; + /// // For an expression `a + (b * a)` + /// let expr = col("a") + (col("b") * col("a")); + /// let refs = expr.column_refs(); + /// // refs contains "a" and "b" + /// assert_eq!(refs.len(), 2); + /// assert!(refs.contains(&Column::new_unqualified("a"))); + /// assert!(refs.contains(&Column::new_unqualified("b"))); + /// ``` + pub fn column_refs(&self) -> HashSet<&Column> { + let mut using_columns = HashSet::new(); + self.add_column_refs(&mut using_columns); + using_columns + } + + /// Adds references to all columns in this expression to the set + /// + /// See [`Self::column_refs`] for details + pub fn add_column_refs<'a>(&'a self, set: &mut HashSet<&'a Column>) { + self.apply(|expr| { + if let Expr::Column(col) = expr { + set.insert(col); + } + Ok(TreeNodeRecursion::Continue) + }) + .expect("traversal is infallible"); + } + + /// Return all references to columns and their occurrence counts in the expression. + /// + /// # Example + /// ``` + /// # use std::collections::HashMap; + /// # use datafusion_common::Column; + /// # use datafusion_expr::col; + /// // For an expression `a + (b * a)` + /// let expr = col("a") + (col("b") * col("a")); + /// let mut refs = expr.column_refs_counts(); + /// // refs contains "a" and "b" + /// assert_eq!(refs.len(), 2); + /// assert_eq!(*refs.get(&Column::new_unqualified("a")).unwrap(), 2); + /// assert_eq!(*refs.get(&Column::new_unqualified("b")).unwrap(), 1); + /// ``` + pub fn column_refs_counts(&self) -> HashMap<&Column, usize> { + let mut map = HashMap::new(); + self.add_column_ref_counts(&mut map); + map + } + + /// Adds references to all columns and their occurrence counts in the expression to + /// the map. + /// + /// See [`Self::column_refs_counts`] for details + pub fn add_column_ref_counts<'a>(&'a self, map: &mut HashMap<&'a Column, usize>) { + self.apply(|expr| { + if let Expr::Column(col) = expr { + *map.entry(col).or_default() += 1; + } + Ok(TreeNodeRecursion::Continue) + }) + .expect("traversal is infallible"); + } + + /// Returns true if there are any column references in this Expr + pub fn any_column_refs(&self) -> bool { + self.exists(|expr| Ok(matches!(expr, Expr::Column(_)))) + .unwrap() + } + + /// Return true when the expression contains out reference(correlated) expressions. + pub fn contains_outer(&self) -> bool { + self.exists(|expr| Ok(matches!(expr, Expr::OuterReferenceColumn { .. }))) + .unwrap() + } + + /// Returns true if the expression node is volatile, i.e. whether it can return + /// different results when evaluated multiple times with the same input. + /// Note: unlike [`Self::is_volatile`], this function does not consider inputs: + /// - `rand()` returns `true`, + /// - `a + rand()` returns `false` + pub fn is_volatile_node(&self) -> bool { + matches!(self, Expr::ScalarFunction(func) if func.func.signature().volatility == Volatility::Volatile) } /// Returns true if the expression is volatile, i.e. whether it can return different /// results when evaluated multiple times with the same input. - pub fn is_volatile(&self) -> Result { - self.exists(|expr| { - Ok(matches!(expr, Expr::ScalarFunction(func) if func.func_def.is_volatile()?)) - }) + /// + /// For example the function call `RANDOM()` is volatile as each call will + /// return a different value. + /// + /// See [`Volatility`] for more information. + pub fn is_volatile(&self) -> bool { + self.exists(|expr| Ok(expr.is_volatile_node())).unwrap() } /// Recursively find all [`Expr::Placeholder`] expressions, and @@ -1314,9 +1614,7 @@ impl Expr { /// and thus any side effects (like divide by zero) may not be encountered pub fn short_circuits(&self) -> bool { match self { - Expr::ScalarFunction(ScalarFunction { func_def, .. }) => { - matches!(func_def, ScalarFunctionDefinition::UDF(fun) if fun.short_circuits()) - } + Expr::ScalarFunction(ScalarFunction { func, .. }) => func.short_circuits(), Expr::BinaryExpr(BinaryExpr { op, .. }) => { matches!(op, Operator::And | Operator::Or) } @@ -1330,7 +1628,6 @@ impl Expr { | Expr::Cast(..) | Expr::Column(..) | Expr::Exists(..) - | Expr::GetIndexedField(..) | Expr::GroupingSet(..) | Expr::InList(..) | Expr::InSubquery(..) @@ -1354,13 +1651,167 @@ impl Expr { | Expr::Wildcard { .. } | Expr::WindowFunction(..) | Expr::Literal(..) - | Expr::Sort(..) | Expr::Placeholder(..) => false, } } } -// modifies expr if it is a placeholder with datatype of right +impl HashNode for Expr { + /// As it is pretty easy to forget changing this method when `Expr` changes the + /// implementation doesn't use wildcard patterns (`..`, `_`) to catch changes + /// compile time. + fn hash_node(&self, state: &mut H) { + mem::discriminant(self).hash(state); + match self { + Expr::Alias(Alias { + expr: _expr, + relation, + name, + }) => { + relation.hash(state); + name.hash(state); + } + Expr::Column(column) => { + column.hash(state); + } + Expr::ScalarVariable(data_type, name) => { + data_type.hash(state); + name.hash(state); + } + Expr::Literal(scalar_value) => { + scalar_value.hash(state); + } + Expr::BinaryExpr(BinaryExpr { + left: _left, + op, + right: _right, + }) => { + op.hash(state); + } + Expr::Like(Like { + negated, + expr: _expr, + pattern: _pattern, + escape_char, + case_insensitive, + }) + | Expr::SimilarTo(Like { + negated, + expr: _expr, + pattern: _pattern, + escape_char, + case_insensitive, + }) => { + negated.hash(state); + escape_char.hash(state); + case_insensitive.hash(state); + } + Expr::Not(_expr) + | Expr::IsNotNull(_expr) + | Expr::IsNull(_expr) + | Expr::IsTrue(_expr) + | Expr::IsFalse(_expr) + | Expr::IsUnknown(_expr) + | Expr::IsNotTrue(_expr) + | Expr::IsNotFalse(_expr) + | Expr::IsNotUnknown(_expr) + | Expr::Negative(_expr) => {} + Expr::Between(Between { + expr: _expr, + negated, + low: _low, + high: _high, + }) => { + negated.hash(state); + } + Expr::Case(Case { + expr: _expr, + when_then_expr: _when_then_expr, + else_expr: _else_expr, + }) => {} + Expr::Cast(Cast { + expr: _expr, + data_type, + }) + | Expr::TryCast(TryCast { + expr: _expr, + data_type, + }) => { + data_type.hash(state); + } + Expr::ScalarFunction(ScalarFunction { func, args: _args }) => { + func.hash(state); + } + Expr::AggregateFunction(AggregateFunction { + func, + args: _args, + distinct, + filter: _filter, + order_by: _order_by, + null_treatment, + }) => { + func.hash(state); + distinct.hash(state); + null_treatment.hash(state); + } + Expr::WindowFunction(WindowFunction { + fun, + args: _args, + partition_by: _partition_by, + order_by: _order_by, + window_frame, + null_treatment, + }) => { + fun.hash(state); + window_frame.hash(state); + null_treatment.hash(state); + } + Expr::InList(InList { + expr: _expr, + list: _list, + negated, + }) => { + negated.hash(state); + } + Expr::Exists(Exists { subquery, negated }) => { + subquery.hash(state); + negated.hash(state); + } + Expr::InSubquery(InSubquery { + expr: _expr, + subquery, + negated, + }) => { + subquery.hash(state); + negated.hash(state); + } + Expr::ScalarSubquery(subquery) => { + subquery.hash(state); + } + Expr::Wildcard { qualifier, options } => { + qualifier.hash(state); + options.hash(state); + } + Expr::GroupingSet(grouping_set) => { + mem::discriminant(grouping_set).hash(state); + match grouping_set { + GroupingSet::Rollup(_exprs) | GroupingSet::Cube(_exprs) => {} + GroupingSet::GroupingSets(_exprs) => {} + } + } + Expr::Placeholder(place_holder) => { + place_holder.hash(state); + } + Expr::OuterReferenceColumn(data_type, column) => { + data_type.hash(state); + column.hash(state); + } + Expr::Unnest(Unnest { expr: _expr }) => {} + }; + } +} + +// Modifies expr if it is a placeholder with datatype of right fn rewrite_placeholder(expr: &mut Expr, other: &Expr, schema: &DFSchema) -> Result<()> { if let Expr::Placeholder(Placeholder { id: _, data_type }) = expr { if data_type.is_none() { @@ -1391,10 +1842,308 @@ macro_rules! expr_vec_fmt { }}; } +struct SchemaDisplay<'a>(&'a Expr); +impl<'a> Display for SchemaDisplay<'a> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self.0 { + // The same as Display + Expr::Column(_) + | Expr::Literal(_) + | Expr::ScalarVariable(..) + | Expr::OuterReferenceColumn(..) + | Expr::Placeholder(_) + | Expr::Wildcard { .. } => write!(f, "{}", self.0), + + Expr::AggregateFunction(AggregateFunction { + func, + args, + distinct, + filter, + order_by, + null_treatment, + }) => { + write!( + f, + "{}({}{})", + func.name(), + if *distinct { "DISTINCT " } else { "" }, + schema_name_from_exprs_comma_seperated_without_space(args)? + )?; + + if let Some(null_treatment) = null_treatment { + write!(f, " {}", null_treatment)?; + } + + if let Some(filter) = filter { + write!(f, " FILTER (WHERE {filter})")?; + }; + + if let Some(order_by) = order_by { + write!(f, " ORDER BY [{}]", schema_name_from_sorts(order_by)?)?; + }; + + Ok(()) + } + // Expr is not shown since it is aliased + Expr::Alias(Alias { name, .. }) => write!(f, "{name}"), + Expr::Between(Between { + expr, + negated, + low, + high, + }) => { + if *negated { + write!( + f, + "{} NOT BETWEEN {} AND {}", + SchemaDisplay(expr), + SchemaDisplay(low), + SchemaDisplay(high), + ) + } else { + write!( + f, + "{} BETWEEN {} AND {}", + SchemaDisplay(expr), + SchemaDisplay(low), + SchemaDisplay(high), + ) + } + } + Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + write!(f, "{} {op} {}", SchemaDisplay(left), SchemaDisplay(right),) + } + Expr::Case(Case { + expr, + when_then_expr, + else_expr, + }) => { + write!(f, "CASE ")?; + + if let Some(e) = expr { + write!(f, "{} ", SchemaDisplay(e))?; + } + + for (when, then) in when_then_expr { + write!( + f, + "WHEN {} THEN {} ", + SchemaDisplay(when), + SchemaDisplay(then), + )?; + } + + if let Some(e) = else_expr { + write!(f, "ELSE {} ", SchemaDisplay(e))?; + } + + write!(f, "END") + } + // Cast expr is not shown to be consistant with Postgres and Spark + Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) => { + write!(f, "{}", SchemaDisplay(expr)) + } + Expr::InList(InList { + expr, + list, + negated, + }) => { + let inlist_name = schema_name_from_exprs(list)?; + + if *negated { + write!(f, "{} NOT IN {}", SchemaDisplay(expr), inlist_name) + } else { + write!(f, "{} IN {}", SchemaDisplay(expr), inlist_name) + } + } + Expr::Exists(Exists { negated: true, .. }) => write!(f, "NOT EXISTS"), + Expr::Exists(Exists { negated: false, .. }) => write!(f, "EXISTS"), + Expr::GroupingSet(GroupingSet::Cube(exprs)) => { + write!(f, "ROLLUP ({})", schema_name_from_exprs(exprs)?) + } + Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => { + write!(f, "GROUPING SETS (")?; + for exprs in lists_of_exprs.iter() { + write!(f, "({})", schema_name_from_exprs(exprs)?)?; + } + write!(f, ")") + } + Expr::GroupingSet(GroupingSet::Rollup(exprs)) => { + write!(f, "ROLLUP ({})", schema_name_from_exprs(exprs)?) + } + Expr::IsNull(expr) => write!(f, "{} IS NULL", SchemaDisplay(expr)), + Expr::IsNotNull(expr) => { + write!(f, "{} IS NOT NULL", SchemaDisplay(expr)) + } + Expr::IsUnknown(expr) => { + write!(f, "{} IS UNKNOWN", SchemaDisplay(expr)) + } + Expr::IsNotUnknown(expr) => { + write!(f, "{} IS NOT UNKNOWN", SchemaDisplay(expr)) + } + Expr::InSubquery(InSubquery { negated: true, .. }) => { + write!(f, "NOT IN") + } + Expr::InSubquery(InSubquery { negated: false, .. }) => write!(f, "IN"), + Expr::IsTrue(expr) => write!(f, "{} IS TRUE", SchemaDisplay(expr)), + Expr::IsFalse(expr) => write!(f, "{} IS FALSE", SchemaDisplay(expr)), + Expr::IsNotTrue(expr) => { + write!(f, "{} IS NOT TRUE", SchemaDisplay(expr)) + } + Expr::IsNotFalse(expr) => { + write!(f, "{} IS NOT FALSE", SchemaDisplay(expr)) + } + Expr::Like(Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }) => { + write!( + f, + "{} {}{} {}", + SchemaDisplay(expr), + if *negated { "NOT " } else { "" }, + if *case_insensitive { "ILIKE" } else { "LIKE" }, + SchemaDisplay(pattern), + )?; + + if let Some(char) = escape_char { + write!(f, " CHAR '{char}'")?; + } + + Ok(()) + } + Expr::Negative(expr) => write!(f, "(- {})", SchemaDisplay(expr)), + Expr::Not(expr) => write!(f, "NOT {}", SchemaDisplay(expr)), + Expr::Unnest(Unnest { expr }) => { + write!(f, "UNNEST({})", SchemaDisplay(expr)) + } + Expr::ScalarFunction(ScalarFunction { func, args }) => { + match func.schema_name(args) { + Ok(name) => { + write!(f, "{name}") + } + Err(e) => { + write!(f, "got error from schema_name {}", e) + } + } + } + Expr::ScalarSubquery(Subquery { subquery, .. }) => { + write!(f, "{}", subquery.schema().field(0).name()) + } + Expr::SimilarTo(Like { + negated, + expr, + pattern, + escape_char, + .. + }) => { + write!( + f, + "{} {} {}", + SchemaDisplay(expr), + if *negated { + "NOT SIMILAR TO" + } else { + "SIMILAR TO" + }, + SchemaDisplay(pattern), + )?; + if let Some(char) = escape_char { + write!(f, " CHAR '{char}'")?; + } + + Ok(()) + } + Expr::WindowFunction(WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + null_treatment, + }) => { + write!( + f, + "{}({})", + fun, + schema_name_from_exprs_comma_seperated_without_space(args)? + )?; + + if let Some(null_treatment) = null_treatment { + write!(f, " {}", null_treatment)?; + } + + if !partition_by.is_empty() { + write!( + f, + " PARTITION BY [{}]", + schema_name_from_exprs(partition_by)? + )?; + } + + if !order_by.is_empty() { + write!(f, " ORDER BY [{}]", schema_name_from_sorts(order_by)?)?; + }; + + write!(f, " {window_frame}") + } + } + } +} + +/// Get schema_name for Vector of expressions +/// +/// Internal usage. Please call `schema_name_from_exprs` instead +// TODO: Use ", " to standardize the formatting of Vec, +// +pub(crate) fn schema_name_from_exprs_comma_seperated_without_space( + exprs: &[Expr], +) -> Result { + schema_name_from_exprs_inner(exprs, ",") +} + +/// Get schema_name for Vector of expressions +pub fn schema_name_from_exprs(exprs: &[Expr]) -> Result { + schema_name_from_exprs_inner(exprs, ", ") +} + +fn schema_name_from_exprs_inner(exprs: &[Expr], sep: &str) -> Result { + let mut s = String::new(); + for (i, e) in exprs.iter().enumerate() { + if i > 0 { + write!(&mut s, "{sep}")?; + } + write!(&mut s, "{}", SchemaDisplay(e))?; + } + + Ok(s) +} + +pub fn schema_name_from_sorts(sorts: &[Sort]) -> Result { + let mut s = String::new(); + for (i, e) in sorts.iter().enumerate() { + if i > 0 { + write!(&mut s, ", ")?; + } + let ordering = if e.asc { "ASC" } else { "DESC" }; + let nulls_ordering = if e.nulls_first { + "NULLS FIRST" + } else { + "NULLS LAST" + }; + write!(&mut s, "{} {} {}", e.expr, ordering, nulls_ordering)?; + } + + Ok(s) +} + /// Format expressions for display as part of a logical plan. In many cases, this will produce /// similar output to `Expr.name()` except that column names will be prefixed with '#'. -impl fmt::Display for Expr { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { +impl Display for Expr { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self { Expr::Alias(Alias { expr, name, .. }) => write!(f, "{expr} AS {name}"), Expr::Column(c) => write!(f, "{c}"), @@ -1450,25 +2199,13 @@ impl fmt::Display for Expr { }) => write!(f, "{expr} IN ({subquery:?})"), Expr::ScalarSubquery(subquery) => write!(f, "({subquery:?})"), Expr::BinaryExpr(expr) => write!(f, "{expr}"), - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => { - if *asc { - write!(f, "{expr} ASC")?; - } else { - write!(f, "{expr} DESC")?; - } - if *nulls_first { - write!(f, " NULLS FIRST") - } else { - write!(f, " NULLS LAST") - } - } Expr::ScalarFunction(fun) => { fmt_function(f, fun.name(), false, &fun.args, true) } + // TODO: use udf's display_name, need to fix the seperator issue, + // Expr::ScalarFunction(ScalarFunction { func, args }) => { + // write!(f, "{}", func.display_name(args).unwrap()) + // } Expr::WindowFunction(WindowFunction { fun, args, @@ -1497,7 +2234,7 @@ impl fmt::Display for Expr { Ok(()) } Expr::AggregateFunction(AggregateFunction { - func_def, + func, distinct, ref args, filter, @@ -1505,7 +2242,7 @@ impl fmt::Display for Expr { null_treatment, .. }) => { - fmt_function(f, func_def.name(), *distinct, args, true)?; + fmt_function(f, func.name(), *distinct, args, true)?; if let Some(nt) = null_treatment { write!(f, " {}", nt)?; } @@ -1575,22 +2312,9 @@ impl fmt::Display for Expr { write!(f, "{expr} IN ([{}])", expr_vec_fmt!(list)) } } - Expr::Wildcard { qualifier } => match qualifier { - Some(qualifier) => write!(f, "{qualifier}.*"), - None => write!(f, "*"), - }, - Expr::GetIndexedField(GetIndexedField { field, expr }) => match field { - GetFieldAccess::NamedStructField { name } => { - write!(f, "({expr})[{name}]") - } - GetFieldAccess::ListIndex { key } => write!(f, "({expr})[{key}]"), - GetFieldAccess::ListRange { - start, - stop, - stride, - } => { - write!(f, "({expr})[{start}:{stop}:{stride}]") - } + Expr::Wildcard { qualifier, options } => match qualifier { + Some(qualifier) => write!(f, "{qualifier}.*{options}"), + None => write!(f, "*{options}"), }, Expr::GroupingSet(grouping_sets) => match grouping_sets { GroupingSet::Rollup(exprs) => { @@ -1616,14 +2340,14 @@ impl fmt::Display for Expr { }, Expr::Placeholder(Placeholder { id, .. }) => write!(f, "{id}"), Expr::Unnest(Unnest { expr }) => { - write!(f, "UNNEST({expr:?})") + write!(f, "UNNEST({expr})") } } } } fn fmt_function( - f: &mut fmt::Formatter, + f: &mut Formatter, fun: &str, distinct: bool, args: &[Expr], @@ -1634,7 +2358,6 @@ fn fmt_function( false => args.iter().map(|arg| format!("{arg:?}")).collect(), }; - // let args: Vec = args.iter().map(|arg| format!("{:?}", arg)).collect(); let distinct_str = match distinct { true => "DISTINCT ", false => "", @@ -1642,311 +2365,29 @@ fn fmt_function( write!(f, "{}({}{})", fun, distinct_str, args.join(", ")) } -fn create_function_name(fun: &str, distinct: bool, args: &[Expr]) -> Result { - let names: Vec = args.iter().map(create_name).collect::>()?; - let distinct_str = match distinct { - true => "DISTINCT ", - false => "", - }; - Ok(format!("{}({}{})", fun, distinct_str, names.join(","))) -} - -/// Returns a readable name of an expression based on the input schema. -/// This function recursively transverses the expression for names such as "CAST(a > 2)". -fn create_name(e: &Expr) -> Result { - match e { - Expr::Alias(Alias { name, .. }) => Ok(name.clone()), - Expr::Column(c) => Ok(c.flat_name()), - Expr::OuterReferenceColumn(_, c) => Ok(format!("outer_ref({})", c.flat_name())), - Expr::ScalarVariable(_, variable_names) => Ok(variable_names.join(".")), - Expr::Literal(value) => Ok(format!("{value:?}")), - Expr::BinaryExpr(binary_expr) => { - let left = create_name(binary_expr.left.as_ref())?; - let right = create_name(binary_expr.right.as_ref())?; - Ok(format!("{} {} {}", left, binary_expr.op, right)) - } - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive, - }) => { - let s = format!( - "{} {}{} {} {}", - expr, - if *negated { "NOT " } else { "" }, - if *case_insensitive { "ILIKE" } else { "LIKE" }, - pattern, - if let Some(char) = escape_char { - format!("CHAR '{char}'") - } else { - "".to_string() - } - ); - Ok(s) - } - Expr::SimilarTo(Like { - negated, - expr, - pattern, - escape_char, - case_insensitive: _, - }) => { - let s = format!( - "{} {} {} {}", - expr, - if *negated { - "NOT SIMILAR TO" - } else { - "SIMILAR TO" - }, - pattern, - if let Some(char) = escape_char { - format!("CHAR '{char}'") - } else { - "".to_string() - } - ); - Ok(s) - } - Expr::Case(case) => { - let mut name = "CASE ".to_string(); - if let Some(e) = &case.expr { - let e = create_name(e)?; - let _ = write!(name, "{e} "); - } - for (w, t) in &case.when_then_expr { - let when = create_name(w)?; - let then = create_name(t)?; - let _ = write!(name, "WHEN {when} THEN {then} "); - } - if let Some(e) = &case.else_expr { - let e = create_name(e)?; - let _ = write!(name, "ELSE {e} "); - } - name += "END"; - Ok(name) - } - Expr::Cast(Cast { expr, .. }) => { - // CAST does not change the expression name - create_name(expr) - } - Expr::TryCast(TryCast { expr, .. }) => { - // CAST does not change the expression name - create_name(expr) - } - Expr::Not(expr) => { - let expr = create_name(expr)?; - Ok(format!("NOT {expr}")) - } - Expr::Negative(expr) => { - let expr = create_name(expr)?; - Ok(format!("(- {expr})")) - } - Expr::IsNull(expr) => { - let expr = create_name(expr)?; - Ok(format!("{expr} IS NULL")) - } - Expr::IsNotNull(expr) => { - let expr = create_name(expr)?; - Ok(format!("{expr} IS NOT NULL")) - } - Expr::IsTrue(expr) => { - let expr = create_name(expr)?; - Ok(format!("{expr} IS TRUE")) - } - Expr::IsFalse(expr) => { - let expr = create_name(expr)?; - Ok(format!("{expr} IS FALSE")) - } - Expr::IsUnknown(expr) => { - let expr = create_name(expr)?; - Ok(format!("{expr} IS UNKNOWN")) - } - Expr::IsNotTrue(expr) => { - let expr = create_name(expr)?; - Ok(format!("{expr} IS NOT TRUE")) - } - Expr::IsNotFalse(expr) => { - let expr = create_name(expr)?; - Ok(format!("{expr} IS NOT FALSE")) - } - Expr::IsNotUnknown(expr) => { - let expr = create_name(expr)?; - Ok(format!("{expr} IS NOT UNKNOWN")) - } - Expr::Exists(Exists { negated: true, .. }) => Ok("NOT EXISTS".to_string()), - Expr::Exists(Exists { negated: false, .. }) => Ok("EXISTS".to_string()), - Expr::InSubquery(InSubquery { negated: true, .. }) => Ok("NOT IN".to_string()), - Expr::InSubquery(InSubquery { negated: false, .. }) => Ok("IN".to_string()), - Expr::ScalarSubquery(subquery) => { - Ok(subquery.subquery.schema().field(0).name().clone()) - } - Expr::GetIndexedField(GetIndexedField { expr, field }) => { - let expr = create_name(expr)?; - match field { - GetFieldAccess::NamedStructField { name } => { - Ok(format!("{expr}[{name}]")) - } - GetFieldAccess::ListIndex { key } => { - let key = create_name(key)?; - Ok(format!("{expr}[{key}]")) - } - GetFieldAccess::ListRange { - start, - stop, - stride, - } => { - let start = create_name(start)?; - let stop = create_name(stop)?; - let stride = create_name(stride)?; - Ok(format!("{expr}[{start}:{stop}:{stride}]")) - } - } - } - Expr::Unnest(Unnest { expr }) => { - let expr_name = create_name(expr)?; - Ok(format!("unnest({expr_name})")) - } - Expr::ScalarFunction(fun) => create_function_name(fun.name(), false, &fun.args), - Expr::WindowFunction(WindowFunction { - fun, - args, - window_frame, - partition_by, - order_by, - null_treatment, - }) => { - let mut parts: Vec = - vec![create_function_name(&fun.to_string(), false, args)?]; - - if let Some(nt) = null_treatment { - parts.push(format!("{}", nt)); - } - - if !partition_by.is_empty() { - parts.push(format!("PARTITION BY [{}]", expr_vec_fmt!(partition_by))); - } - - if !order_by.is_empty() { - parts.push(format!("ORDER BY [{}]", expr_vec_fmt!(order_by))); - } - - parts.push(format!("{window_frame}")); - Ok(parts.join(" ")) - } - Expr::AggregateFunction(AggregateFunction { - func_def, - distinct, - args, - filter, - order_by, - null_treatment, - }) => { - let name = match func_def { - AggregateFunctionDefinition::BuiltIn(..) - | AggregateFunctionDefinition::Name(..) => { - create_function_name(func_def.name(), *distinct, args)? - } - AggregateFunctionDefinition::UDF(..) => { - let names: Vec = - args.iter().map(create_name).collect::>()?; - names.join(",") - } - }; - let mut info = String::new(); - if let Some(fe) = filter { - info += &format!(" FILTER (WHERE {fe})"); - }; - if let Some(order_by) = order_by { - info += &format!(" ORDER BY [{}]", expr_vec_fmt!(order_by)); - }; - if let Some(nt) = null_treatment { - info += &format!(" {}", nt); - } - match func_def { - AggregateFunctionDefinition::BuiltIn(..) - | AggregateFunctionDefinition::Name(..) => { - Ok(format!("{}{}", name, info)) - } - AggregateFunctionDefinition::UDF(fun) => { - Ok(format!("{}({}){}", fun.name(), name, info)) - } - } - } - Expr::GroupingSet(grouping_set) => match grouping_set { - GroupingSet::Rollup(exprs) => { - Ok(format!("ROLLUP ({})", create_names(exprs.as_slice())?)) - } - GroupingSet::Cube(exprs) => { - Ok(format!("CUBE ({})", create_names(exprs.as_slice())?)) - } - GroupingSet::GroupingSets(lists_of_exprs) => { - let mut list_of_names = vec![]; - for exprs in lists_of_exprs { - list_of_names.push(format!("({})", create_names(exprs.as_slice())?)); - } - Ok(format!("GROUPING SETS ({})", list_of_names.join(", "))) - } - }, - Expr::InList(InList { - expr, - list, - negated, - }) => { - let expr = create_name(expr)?; - let list = list.iter().map(create_name); - if *negated { - Ok(format!("{expr} NOT IN ({list:?})")) - } else { - Ok(format!("{expr} IN ({list:?})")) - } - } - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { - let expr = create_name(expr)?; - let low = create_name(low)?; - let high = create_name(high)?; - if *negated { - Ok(format!("{expr} NOT BETWEEN {low} AND {high}")) - } else { - Ok(format!("{expr} BETWEEN {low} AND {high}")) - } - } - Expr::Sort { .. } => { - internal_err!("Create name does not support sort expression") - } - Expr::Wildcard { qualifier } => match qualifier { - Some(qualifier) => internal_err!( - "Create name does not support qualified wildcard, got {qualifier}" - ), - None => Ok("*".to_string()), - }, - Expr::Placeholder(Placeholder { id, .. }) => Ok((*id).to_string()), +/// The name of the column (field) that this `Expr` will produce in the physical plan. +/// The difference from [Expr::schema_name] is that top-level columns are unqualified. +pub fn physical_name(expr: &Expr) -> Result { + if let Expr::Column(col) = expr { + Ok(col.name.clone()) + } else { + Ok(expr.schema_name().to_string()) } } -/// Create a comma separated list of names from a list of expressions -fn create_names(exprs: &[Expr]) -> Result { - Ok(exprs - .iter() - .map(create_name) - .collect::>>()? - .join(", ")) -} - #[cfg(test)] mod test { use crate::expr_fn::col; - use crate::{case, lit, ColumnarValue, ScalarUDF, ScalarUDFImpl, Volatility}; + use crate::{ + case, lit, qualified_wildcard, wildcard, wildcard_with_options, ColumnarValue, + ScalarUDF, ScalarUDFImpl, Volatility, + }; + use sqlparser::ast; + use sqlparser::ast::{Ident, IdentWithAlias}; use std::any::Any; #[test] + #[allow(deprecated)] fn format_case_when() -> Result<()> { let expr = case(col("a")) .when(lit(1), lit(true)) @@ -1955,11 +2396,11 @@ mod test { let expected = "CASE a WHEN Int32(1) THEN Boolean(true) WHEN Int32(0) THEN Boolean(false) ELSE NULL END"; assert_eq!(expected, expr.canonical_name()); assert_eq!(expected, format!("{expr}")); - assert_eq!(expected, expr.display_name()?); Ok(()) } #[test] + #[allow(deprecated)] fn format_cast() -> Result<()> { let expr = Expr::Cast(Cast { expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)))), @@ -1968,28 +2409,23 @@ mod test { let expected_canonical = "CAST(Float32(1.23) AS Utf8)"; assert_eq!(expected_canonical, expr.canonical_name()); assert_eq!(expected_canonical, format!("{expr}")); - // note that CAST intentionally has a name that is different from its `Display` + // Note that CAST intentionally has a name that is different from its `Display` // representation. CAST does not change the name of expressions. - assert_eq!("Float32(1.23)", expr.display_name()?); + assert_eq!("Float32(1.23)", expr.schema_name().to_string()); Ok(()) } #[test] fn test_partial_ord() { - // Test validates that partial ord is defined for Expr using hashes, not + // Test validates that partial ord is defined for Expr, not // intended to exhaustively test all possibilities let exp1 = col("a") + lit(1); let exp2 = col("a") + lit(2); let exp3 = !(col("a") + lit(2)); - // Since comparisons are done using hash value of the expression - // expr < expr2 may return false, or true. There is no guaranteed result. - // The only guarantee is "<" operator should have the opposite result of ">=" operator - let greater_or_equal = exp1 >= exp2; - assert_eq!(exp1 < exp2, !greater_or_equal); - - let greater_or_equal = exp3 >= exp2; - assert_eq!(exp3 < exp2, !greater_or_equal); + assert!(exp1 < exp2); + assert!(exp3 > exp2); + assert!(exp1 < exp3) } #[test] @@ -1997,7 +2433,7 @@ mod test { // single column { let expr = &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)); - let columns = expr.to_columns()?; + let columns = expr.column_refs(); assert_eq!(1, columns.len()); assert!(columns.contains(&Column::from_name("a"))); } @@ -2005,7 +2441,7 @@ mod test { // multiple columns { let expr = col("a") + col("b") + lit(1); - let columns = expr.to_columns()?; + let columns = expr.column_refs(); assert_eq!(2, columns.len()); assert!(columns.contains(&Column::from_name("a"))); assert!(columns.contains(&Column::from_name("b"))); @@ -2051,7 +2487,7 @@ mod test { } #[test] - fn test_is_volatile_scalar_func_definition() { + fn test_is_volatile_scalar_func() { // UDF #[derive(Debug)] struct TestScalarUDF { @@ -2080,7 +2516,7 @@ mod test { let udf = Arc::new(ScalarUDF::from(TestScalarUDF { signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), })); - assert!(!ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap()); + assert_ne!(udf.signature().volatility, Volatility::Volatile); let udf = Arc::new(ScalarUDF::from(TestScalarUDF { signature: Signature::uniform( @@ -2089,35 +2525,18 @@ mod test { Volatility::Volatile, ), })); - assert!(ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap()); - - // Unresolved function - ScalarFunctionDefinition::Name(Arc::from("UnresolvedFunc")) - .is_volatile() - .expect_err("Shouldn't determine volatility of unresolved function"); + assert_eq!(udf.signature().volatility, Volatility::Volatile); } use super::*; - #[test] - fn test_count_return_type() -> Result<()> { - let fun = find_df_window_func("count").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Int64, observed); - - let observed = fun.return_type(&[DataType::UInt64])?; - assert_eq!(DataType::Int64, observed); - - Ok(()) - } - #[test] fn test_first_value_return_type() -> Result<()> { let fun = find_df_window_func("first_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; + let observed = fun.return_type(&[DataType::Utf8], &[true], "")?; assert_eq!(DataType::Utf8, observed); - let observed = fun.return_type(&[DataType::UInt64])?; + let observed = fun.return_type(&[DataType::UInt64], &[true], "")?; assert_eq!(DataType::UInt64, observed); Ok(()) @@ -2126,34 +2545,10 @@ mod test { #[test] fn test_last_value_return_type() -> Result<()> { let fun = find_df_window_func("last_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_lead_return_type() -> Result<()> { - let fun = find_df_window_func("lead").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Utf8, observed); - - let observed = fun.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_lag_return_type() -> Result<()> { - let fun = find_df_window_func("lag").unwrap(); - let observed = fun.return_type(&[DataType::Utf8])?; + let observed = fun.return_type(&[DataType::Utf8], &[true], "")?; assert_eq!(DataType::Utf8, observed); - let observed = fun.return_type(&[DataType::Float64])?; + let observed = fun.return_type(&[DataType::Float64], &[true], "")?; assert_eq!(DataType::Float64, observed); Ok(()) @@ -2162,67 +2557,29 @@ mod test { #[test] fn test_nth_value_return_type() -> Result<()> { let fun = find_df_window_func("nth_value").unwrap(); - let observed = fun.return_type(&[DataType::Utf8, DataType::UInt64])?; + let observed = + fun.return_type(&[DataType::Utf8, DataType::UInt64], &[true, true], "")?; assert_eq!(DataType::Utf8, observed); - let observed = fun.return_type(&[DataType::Float64, DataType::UInt64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_percent_rank_return_type() -> Result<()> { - let fun = find_df_window_func("percent_rank").unwrap(); - let observed = fun.return_type(&[])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_cume_dist_return_type() -> Result<()> { - let fun = find_df_window_func("cume_dist").unwrap(); - let observed = fun.return_type(&[])?; + let observed = + fun.return_type(&[DataType::Float64, DataType::UInt64], &[true, true], "")?; assert_eq!(DataType::Float64, observed); Ok(()) } - #[test] - fn test_ntile_return_type() -> Result<()> { - let fun = find_df_window_func("ntile").unwrap(); - let observed = fun.return_type(&[DataType::Int16])?; - assert_eq!(DataType::UInt64, observed); - - Ok(()) - } - #[test] fn test_window_function_case_insensitive() -> Result<()> { - let names = vec![ - "row_number", - "rank", - "dense_rank", - "percent_rank", - "cume_dist", - "ntile", - "lag", - "lead", - "first_value", - "last_value", - "nth_value", - "min", - "max", - "count", - "avg", - "sum", - ]; + let names = vec!["first_value", "last_value", "nth_value"]; for name in names { let fun = find_df_window_func(name).unwrap(); let fun2 = find_df_window_func(name.to_uppercase().as_str()).unwrap(); assert_eq!(fun, fun2); - assert_eq!(fun.to_string(), name.to_uppercase()); + if fun.to_string() == "first_value" || fun.to_string() == "last_value" { + assert_eq!(fun.to_string(), name); + } else { + assert_eq!(fun.to_string(), name.to_uppercase()); + } } Ok(()) } @@ -2230,53 +2587,122 @@ mod test { #[test] fn test_find_df_window_function() { assert_eq!( - find_df_window_func("max"), - Some(WindowFunctionDefinition::AggregateFunction( - aggregate_function::AggregateFunction::Max - )) - ); - assert_eq!( - find_df_window_func("min"), - Some(WindowFunctionDefinition::AggregateFunction( - aggregate_function::AggregateFunction::Min - )) - ); - assert_eq!( - find_df_window_func("avg"), - Some(WindowFunctionDefinition::AggregateFunction( - aggregate_function::AggregateFunction::Avg + find_df_window_func("first_value"), + Some(WindowFunctionDefinition::BuiltInWindowFunction( + BuiltInWindowFunction::FirstValue )) ); assert_eq!( - find_df_window_func("cume_dist"), + find_df_window_func("LAST_value"), Some(WindowFunctionDefinition::BuiltInWindowFunction( - built_in_window_function::BuiltInWindowFunction::CumeDist + BuiltInWindowFunction::LastValue )) ); + assert_eq!(find_df_window_func("not_exist"), None) + } + + #[test] + fn test_display_wildcard() { + assert_eq!(format!("{}", wildcard()), "*"); + assert_eq!(format!("{}", qualified_wildcard("t1")), "t1.*"); assert_eq!( - find_df_window_func("first_value"), - Some(WindowFunctionDefinition::BuiltInWindowFunction( - built_in_window_function::BuiltInWindowFunction::FirstValue - )) + format!( + "{}", + wildcard_with_options(wildcard_options( + Some(IlikeSelectItem { + pattern: "c1".to_string() + }), + None, + None, + None, + None + )) + ), + "* ILIKE 'c1'" ); assert_eq!( - find_df_window_func("LAST_value"), - Some(WindowFunctionDefinition::BuiltInWindowFunction( - built_in_window_function::BuiltInWindowFunction::LastValue - )) + format!( + "{}", + wildcard_with_options(wildcard_options( + None, + Some(ExcludeSelectItem::Multiple(vec![ + Ident::from("c1"), + Ident::from("c2") + ])), + None, + None, + None + )) + ), + "* EXCLUDE (c1, c2)" ); assert_eq!( - find_df_window_func("LAG"), - Some(WindowFunctionDefinition::BuiltInWindowFunction( - built_in_window_function::BuiltInWindowFunction::Lag - )) + format!( + "{}", + wildcard_with_options(wildcard_options( + None, + None, + Some(ExceptSelectItem { + first_element: Ident::from("c1"), + additional_elements: vec![Ident::from("c2")] + }), + None, + None + )) + ), + "* EXCEPT (c1, c2)" ); assert_eq!( - find_df_window_func("LEAD"), - Some(WindowFunctionDefinition::BuiltInWindowFunction( - built_in_window_function::BuiltInWindowFunction::Lead - )) + format!( + "{}", + wildcard_with_options(wildcard_options( + None, + None, + None, + Some(PlannedReplaceSelectItem { + items: vec![ReplaceSelectElement { + expr: ast::Expr::Identifier(Ident::from("c1")), + column_name: Ident::from("a1"), + as_keyword: false + }], + planned_expressions: vec![] + }), + None + )) + ), + "* REPLACE (c1 a1)" ); - assert_eq!(find_df_window_func("not_exist"), None) + assert_eq!( + format!( + "{}", + wildcard_with_options(wildcard_options( + None, + None, + None, + None, + Some(RenameSelectItem::Multiple(vec![IdentWithAlias { + ident: Ident::from("c1"), + alias: Ident::from("a1") + }])) + )) + ), + "* RENAME (c1 AS a1)" + ) + } + + fn wildcard_options( + opt_ilike: Option, + opt_exclude: Option, + opt_except: Option, + opt_replace: Option, + opt_rename: Option, + ) -> WildcardOptions { + WildcardOptions { + ilike: opt_ilike, + exclude: opt_exclude, + except: opt_except, + replace: opt_replace, + rename: opt_rename, + } } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 1d976a12cc4f..7fd4e64e0e62 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -19,19 +19,28 @@ use crate::expr::{ AggregateFunction, BinaryExpr, Cast, Exists, GroupingSet, InList, InSubquery, - Placeholder, TryCast, + Placeholder, TryCast, Unnest, WildcardOptions, WindowFunction, }; use crate::function::{ AccumulatorArgs, AccumulatorFactoryFunction, PartitionEvaluatorFactory, + StateFieldsArgs, }; use crate::{ - aggregate_function, conditional_expressions::CaseBuilder, logical_plan::Subquery, - AggregateUDF, Expr, LogicalPlan, Operator, ScalarFunctionImplementation, ScalarUDF, - Signature, Volatility, + conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery, + AggregateUDF, Expr, LogicalPlan, Operator, PartitionEvaluator, + ScalarFunctionImplementation, ScalarUDF, Signature, Volatility, +}; +use crate::{ + AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl, +}; +use arrow::compute::kernels::cast_utils::{ + parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month, }; -use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl}; use arrow::datatypes::{DataType, Field}; -use datafusion_common::{Column, Result}; +use datafusion_common::{plan_err, Column, Result, ScalarValue, TableReference}; +use datafusion_functions_window_common::field::WindowUDFFieldArgs; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use sqlparser::ast::NullTreatment; use std::any::Any; use std::fmt::Debug; use std::ops::Not; @@ -112,7 +121,46 @@ pub fn placeholder(id: impl Into) -> Expr { /// assert_eq!(p.to_string(), "*") /// ``` pub fn wildcard() -> Expr { - Expr::Wildcard { qualifier: None } + Expr::Wildcard { + qualifier: None, + options: WildcardOptions::default(), + } +} + +/// Create an '*' [`Expr::Wildcard`] expression with the wildcard options +pub fn wildcard_with_options(options: WildcardOptions) -> Expr { + Expr::Wildcard { + qualifier: None, + options, + } +} + +/// Create an 't.*' [`Expr::Wildcard`] expression that matches all columns from a specific table +/// +/// # Example +/// +/// ```rust +/// # use datafusion_common::TableReference; +/// # use datafusion_expr::{qualified_wildcard}; +/// let p = qualified_wildcard(TableReference::bare("t")); +/// assert_eq!(p.to_string(), "t.*") +/// ``` +pub fn qualified_wildcard(qualifier: impl Into) -> Expr { + Expr::Wildcard { + qualifier: Some(qualifier.into()), + options: WildcardOptions::default(), + } +} + +/// Create an 't.*' [`Expr::Wildcard`] expression with the wildcard options +pub fn qualified_wildcard_with_options( + qualifier: impl Into, + options: WildcardOptions, +) -> Expr { + Expr::Wildcard { + qualifier: Some(qualifier.into()), + options, + } } /// Return a new expression `left right` @@ -143,78 +191,6 @@ pub fn not(expr: Expr) -> Expr { expr.not() } -/// Create an expression to represent the min() aggregate function -pub fn min(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Min, - vec![expr], - false, - None, - None, - None, - )) -} - -/// Create an expression to represent the max() aggregate function -pub fn max(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Max, - vec![expr], - false, - None, - None, - None, - )) -} - -/// Create an expression to represent the sum() aggregate function -pub fn sum(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Sum, - vec![expr], - false, - None, - None, - None, - )) -} - -/// Create an expression to represent the array_agg() aggregate function -pub fn array_agg(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::ArrayAgg, - vec![expr], - false, - None, - None, - None, - )) -} - -/// Create an expression to represent the avg() aggregate function -pub fn avg(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Avg, - vec![expr], - false, - None, - None, - None, - )) -} - -/// Create an expression to represent the count() aggregate function -pub fn count(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Count, - vec![expr], - false, - None, - None, - None, - )) -} - /// Return a new expression with bitwise AND pub fn bitwise_and(left: Expr, right: Expr) -> Expr { Expr::BinaryExpr(BinaryExpr::new( @@ -260,93 +236,11 @@ pub fn bitwise_shift_left(left: Expr, right: Expr) -> Expr { )) } -/// Create an expression to represent the count(distinct) aggregate function -pub fn count_distinct(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Count, - vec![expr], - true, - None, - None, - None, - )) -} - /// Create an in_list expression pub fn in_list(expr: Expr, list: Vec, negated: bool) -> Expr { Expr::InList(InList::new(Box::new(expr), list, negated)) } -/// Returns the approximate number of distinct input values. -/// This function provides an approximation of count(DISTINCT x). -/// Zero is returned if all input values are null. -/// This function should produce a standard error of 0.81%, -/// which is the standard deviation of the (approximately normal) -/// error distribution over all possible sets. -/// It does not guarantee an upper bound on the error for any specific input set. -pub fn approx_distinct(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::ApproxDistinct, - vec![expr], - false, - None, - None, - None, - )) -} - -/// Calculate the median for `expr`. -pub fn median(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Median, - vec![expr], - false, - None, - None, - None, - )) -} - -/// Calculate an approximation of the median for `expr`. -pub fn approx_median(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::ApproxMedian, - vec![expr], - false, - None, - None, - None, - )) -} - -/// Calculate an approximation of the specified `percentile` for `expr`. -pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::ApproxPercentileCont, - vec![expr, percentile], - false, - None, - None, - None, - )) -} - -/// Calculate an approximation of the specified `percentile` for `expr` and `weight_expr`. -pub fn approx_percentile_cont_with_weight( - expr: Expr, - weight_expr: Expr, - percentile: Expr, -) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::ApproxPercentileContWithWeight, - vec![expr, weight_expr, percentile], - false, - None, - None, - None, - )) -} - /// Create an EXISTS subquery expression pub fn exists(subquery: Arc) -> Expr { let outer_ref_columns = subquery.all_out_ref_exprs(); @@ -406,18 +300,6 @@ pub fn scalar_subquery(subquery: Arc) -> Expr { }) } -/// Create an expression to represent the stddev() aggregate function -pub fn stddev(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::Stddev, - vec![expr], - false, - None, - None, - None, - )) -} - /// Create a grouping set pub fn grouping_set(exprs: Vec>) -> Expr { Expr::GroupingSet(GroupingSet::GroupingSets(exprs)) @@ -488,6 +370,13 @@ pub fn when(when: Expr, then: Expr) -> CaseBuilder { CaseBuilder::new(None, vec![when], vec![then], None) } +/// Create a Unnest expression +pub fn unnest(expr: Expr) -> Expr { + Expr::Unnest(Unnest { + expr: Box::new(expr), + }) +} + /// Convenience method to create a new user defined scalar function (UDF) with a /// specific signature and specific return type. /// @@ -503,11 +392,10 @@ pub fn when(when: Expr, then: Expr) -> CaseBuilder { pub fn create_udf( name: &str, input_types: Vec, - return_type: Arc, + return_type: DataType, volatility: Volatility, fun: ScalarFunctionImplementation, ) -> ScalarUDF { - let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); ScalarUDF::from(SimpleScalarUDF::new( name, input_types, @@ -589,8 +477,8 @@ pub fn create_udaf( accumulator: AccumulatorFactoryFunction, state_type: Arc>, ) -> AggregateUDF { - let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); - let state_type = Arc::try_unwrap(state_type).unwrap_or_else(|t| t.as_ref().clone()); + let return_type = Arc::unwrap_or_clone(return_type); + let state_type = Arc::unwrap_or_clone(state_type); let state_fields = state_type .into_iter() .enumerate() @@ -690,12 +578,7 @@ impl AggregateUDFImpl for SimpleAggregateUDF { (self.accumulator)(acc_args) } - fn state_fields( - &self, - _name: &str, - _value_type: DataType, - _ordering_fields: Vec, - ) -> Result> { + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { Ok(self.state_fields.clone()) } } @@ -712,7 +595,7 @@ pub fn create_udwf( volatility: Volatility, partition_evaluator_factory: PartitionEvaluatorFactory, ) -> WindowUDF { - let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); + let return_type = Arc::unwrap_or_clone(return_type); WindowUDF::from(SimpleWindowUDF::new( name, input_type, @@ -776,12 +659,292 @@ impl WindowUDFImpl for SimpleWindowUDF { &self.signature } - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(self.return_type.clone()) + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + (self.partition_evaluator_factory)() } - fn partition_evaluator(&self) -> Result> { - (self.partition_evaluator_factory)() + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new( + field_args.name(), + self.return_type.clone(), + true, + )) + } +} + +pub fn interval_year_month_lit(value: &str) -> Expr { + let interval = parse_interval_year_month(value).ok(); + Expr::Literal(ScalarValue::IntervalYearMonth(interval)) +} + +pub fn interval_datetime_lit(value: &str) -> Expr { + let interval = parse_interval_day_time(value).ok(); + Expr::Literal(ScalarValue::IntervalDayTime(interval)) +} + +pub fn interval_month_day_nano_lit(value: &str) -> Expr { + let interval = parse_interval_month_day_nano(value).ok(); + Expr::Literal(ScalarValue::IntervalMonthDayNano(interval)) +} + +/// Extensions for configuring [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] +/// +/// Adds methods to [`Expr`] that make it easy to set optional options +/// such as `ORDER BY`, `FILTER` and `DISTINCT` +/// +/// # Example +/// ```no_run +/// # use datafusion_common::Result; +/// # use datafusion_expr::test::function_stub::count; +/// # use sqlparser::ast::NullTreatment; +/// # use datafusion_expr::{ExprFunctionExt, lit, Expr, col}; +/// # // first_value is an aggregate function in another crate +/// # fn first_value(_arg: Expr) -> Expr { +/// unimplemented!() } +/// # fn main() -> Result<()> { +/// // Create an aggregate count, filtering on column y > 5 +/// let agg = count(col("x")).filter(col("y").gt(lit(5))).build()?; +/// +/// // Find the first value in an aggregate sorted by column y +/// // equivalent to: +/// // `FIRST_VALUE(x ORDER BY y ASC IGNORE NULLS)` +/// let sort_expr = col("y").sort(true, true); +/// let agg = first_value(col("x")) +/// .order_by(vec![sort_expr]) +/// .null_treatment(NullTreatment::IgnoreNulls) +/// .build()?; +/// +/// // Create a window expression for percent rank partitioned on column a +/// // equivalent to: +/// // `PERCENT_RANK() OVER (PARTITION BY a ORDER BY b ASC NULLS LAST IGNORE NULLS)` +/// // percent_rank is an udwf function in another crate +/// # fn percent_rank() -> Expr { +/// unimplemented!() } +/// let window = percent_rank() +/// .partition_by(vec![col("a")]) +/// .order_by(vec![col("b").sort(true, true)]) +/// .null_treatment(NullTreatment::IgnoreNulls) +/// .build()?; +/// # Ok(()) +/// # } +/// ``` +pub trait ExprFunctionExt { + /// Add `ORDER BY ` + fn order_by(self, order_by: Vec) -> ExprFuncBuilder; + /// Add `FILTER ` + fn filter(self, filter: Expr) -> ExprFuncBuilder; + /// Add `DISTINCT` + fn distinct(self) -> ExprFuncBuilder; + /// Add `RESPECT NULLS` or `IGNORE NULLS` + fn null_treatment( + self, + null_treatment: impl Into>, + ) -> ExprFuncBuilder; + /// Add `PARTITION BY` + fn partition_by(self, partition_by: Vec) -> ExprFuncBuilder; + /// Add appropriate window frame conditions + fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder; +} + +#[derive(Debug, Clone)] +pub enum ExprFuncKind { + Aggregate(AggregateFunction), + Window(WindowFunction), +} + +/// Implementation of [`ExprFunctionExt`]. +/// +/// See [`ExprFunctionExt`] for usage and examples +#[derive(Debug, Clone)] +pub struct ExprFuncBuilder { + fun: Option, + order_by: Option>, + filter: Option, + distinct: bool, + null_treatment: Option, + partition_by: Option>, + window_frame: Option, +} + +impl ExprFuncBuilder { + /// Create a new `ExprFuncBuilder`, see [`ExprFunctionExt`] + fn new(fun: Option) -> Self { + Self { + fun, + order_by: None, + filter: None, + distinct: false, + null_treatment: None, + partition_by: None, + window_frame: None, + } + } + + /// Updates and returns the in progress [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] + /// + /// # Errors: + /// + /// Returns an error if this builder [`ExprFunctionExt`] was used with an + /// `Expr` variant other than [`Expr::AggregateFunction`] or [`Expr::WindowFunction`] + pub fn build(self) -> Result { + let Self { + fun, + order_by, + filter, + distinct, + null_treatment, + partition_by, + window_frame, + } = self; + + let Some(fun) = fun else { + return plan_err!( + "ExprFunctionExt can only be used with Expr::AggregateFunction or Expr::WindowFunction" + ); + }; + + let fun_expr = match fun { + ExprFuncKind::Aggregate(mut udaf) => { + udaf.order_by = order_by; + udaf.filter = filter.map(Box::new); + udaf.distinct = distinct; + udaf.null_treatment = null_treatment; + Expr::AggregateFunction(udaf) + } + ExprFuncKind::Window(mut udwf) => { + let has_order_by = order_by.as_ref().map(|o| !o.is_empty()); + udwf.order_by = order_by.unwrap_or_default(); + udwf.partition_by = partition_by.unwrap_or_default(); + udwf.window_frame = + window_frame.unwrap_or(WindowFrame::new(has_order_by)); + udwf.null_treatment = null_treatment; + Expr::WindowFunction(udwf) + } + }; + + Ok(fun_expr) + } +} + +impl ExprFunctionExt for ExprFuncBuilder { + /// Add `ORDER BY ` + fn order_by(mut self, order_by: Vec) -> ExprFuncBuilder { + self.order_by = Some(order_by); + self + } + + /// Add `FILTER ` + fn filter(mut self, filter: Expr) -> ExprFuncBuilder { + self.filter = Some(filter); + self + } + + /// Add `DISTINCT` + fn distinct(mut self) -> ExprFuncBuilder { + self.distinct = true; + self + } + + /// Add `RESPECT NULLS` or `IGNORE NULLS` + fn null_treatment( + mut self, + null_treatment: impl Into>, + ) -> ExprFuncBuilder { + self.null_treatment = null_treatment.into(); + self + } + + fn partition_by(mut self, partition_by: Vec) -> ExprFuncBuilder { + self.partition_by = Some(partition_by); + self + } + + fn window_frame(mut self, window_frame: WindowFrame) -> ExprFuncBuilder { + self.window_frame = Some(window_frame); + self + } +} + +impl ExprFunctionExt for Expr { + fn order_by(self, order_by: Vec) -> ExprFuncBuilder { + let mut builder = match self { + Expr::AggregateFunction(udaf) => { + ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) + } + Expr::WindowFunction(udwf) => { + ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))) + } + _ => ExprFuncBuilder::new(None), + }; + if builder.fun.is_some() { + builder.order_by = Some(order_by); + } + builder + } + fn filter(self, filter: Expr) -> ExprFuncBuilder { + match self { + Expr::AggregateFunction(udaf) => { + let mut builder = + ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))); + builder.filter = Some(filter); + builder + } + _ => ExprFuncBuilder::new(None), + } + } + fn distinct(self) -> ExprFuncBuilder { + match self { + Expr::AggregateFunction(udaf) => { + let mut builder = + ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))); + builder.distinct = true; + builder + } + _ => ExprFuncBuilder::new(None), + } + } + fn null_treatment( + self, + null_treatment: impl Into>, + ) -> ExprFuncBuilder { + let mut builder = match self { + Expr::AggregateFunction(udaf) => { + ExprFuncBuilder::new(Some(ExprFuncKind::Aggregate(udaf))) + } + Expr::WindowFunction(udwf) => { + ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))) + } + _ => ExprFuncBuilder::new(None), + }; + if builder.fun.is_some() { + builder.null_treatment = null_treatment.into(); + } + builder + } + + fn partition_by(self, partition_by: Vec) -> ExprFuncBuilder { + match self { + Expr::WindowFunction(udwf) => { + let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))); + builder.partition_by = Some(partition_by); + builder + } + _ => ExprFuncBuilder::new(None), + } + } + + fn window_frame(self, window_frame: WindowFrame) -> ExprFuncBuilder { + match self { + Expr::WindowFunction(udwf) => { + let mut builder = ExprFuncBuilder::new(Some(ExprFuncKind::Window(udwf))); + builder.window_frame = Some(window_frame); + builder + } + _ => ExprFuncBuilder::new(None), + } } } diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 14154189a126..c86696854ca3 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -19,16 +19,15 @@ use std::collections::HashMap; use std::collections::HashSet; +use std::fmt::Debug; use std::sync::Arc; -use crate::expr::{Alias, Unnest}; +use crate::expr::{Alias, Sort, Unnest}; use crate::logical_plan::Projection; use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRewriter, -}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::TableReference; use datafusion_common::{Column, DFSchema, Result}; @@ -42,8 +41,9 @@ pub use order_by::rewrite_sort_cols_by_aggs; /// /// For example, concatenating arrays `a || b` is represented as /// `Operator::ArrowAt`, but can be implemented by calling a function -/// `array_concat` from the `functions-array` crate. -pub trait FunctionRewrite { +/// `array_concat` from the `functions-nested` crate. +// This is not used in datafusion internally, but it is still helpful for downstream project so don't remove it. +pub trait FunctionRewrite: Debug { /// Return a human readable name for this rewrite fn name(&self) -> &str; @@ -59,7 +59,7 @@ pub trait FunctionRewrite { ) -> Result>; } -/// Recursively call [`Column::normalize_with_schemas`] on all [`Column`] expressions +/// Recursively call `LogicalPlanBuilder::normalize` on all [`Column`] expressions /// in the `expr` expression tree. pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result { expr.transform(|expr| { @@ -116,6 +116,20 @@ pub fn normalize_cols( .collect() } +pub fn normalize_sorts( + sorts: impl IntoIterator>, + plan: &LogicalPlan, +) -> Result> { + sorts + .into_iter() + .map(|e| { + let sort = e.into(); + normalize_col(sort.expr, plan) + .map(|expr| Sort::new(expr, sort.asc, sort.nulls_first)) + }) + .collect() +} + /// Recursively replace all [`Column`] expressions in a given expression tree with /// `Column` expressions provided by the hash map argument. pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { @@ -154,7 +168,7 @@ pub fn unnormalize_col(expr: Expr) -> Expr { }) }) .data() - .expect("Unnormalize is infallable") + .expect("Unnormalize is infallible") } /// Create a Column from the Scalar Expr @@ -172,7 +186,7 @@ pub fn create_col_from_scalar_expr( name, )), _ => { - let scalar_column = scalar_expr.display_name()?; + let scalar_column = scalar_expr.schema_name().to_string(); Ok(Column::new( Some::(subqry_alias.into()), scalar_column, @@ -200,33 +214,31 @@ pub fn strip_outer_reference(expr: Expr) -> Expr { }) }) .data() - .expect("strip_outer_reference is infallable") + .expect("strip_outer_reference is infallible") } /// Returns plan with expressions coerced to types compatible with /// schema types pub fn coerce_plan_expr_for_schema( - plan: &LogicalPlan, + plan: LogicalPlan, schema: &DFSchema, ) -> Result { match plan { // special case Projection to avoid adding multiple projections LogicalPlan::Projection(Projection { expr, input, .. }) => { - let new_exprs = - coerce_exprs_for_schema(expr.clone(), input.schema(), schema)?; - let projection = Projection::try_new(new_exprs, input.clone())?; + let new_exprs = coerce_exprs_for_schema(expr, input.schema(), schema)?; + let projection = Projection::try_new(new_exprs, input)?; Ok(LogicalPlan::Projection(projection)) } _ => { let exprs: Vec = plan.schema().iter().map(Expr::from).collect(); - let new_exprs = coerce_exprs_for_schema(exprs, plan.schema(), schema)?; - let add_project = new_exprs.iter().any(|expr| expr.try_into_col().is_err()); + let add_project = new_exprs.iter().any(|expr| expr.try_as_col().is_none()); if add_project { - let projection = Projection::try_new(new_exprs, Arc::new(plan.clone()))?; + let projection = Projection::try_new(new_exprs, Arc::new(plan))?; Ok(LogicalPlan::Projection(projection)) } else { - Ok(plan.clone()) + Ok(plan) } } } @@ -247,10 +259,11 @@ fn coerce_exprs_for_schema( Expr::Alias(Alias { expr, name, .. }) => { Ok(expr.cast_to(new_type, src_schema)?.alias(name)) } + Expr::Wildcard { .. } => Ok(expr), _ => expr.cast_to(new_type, src_schema), } } else { - Ok(expr.clone()) + Ok(expr) } }) .collect::>() @@ -265,18 +278,79 @@ pub fn unalias(expr: Expr) -> Expr { } } -/// Rewrites `expr` using `rewriter`, ensuring that the output has the -/// same name as `expr` prior to rewrite, adding an alias if necessary. +/// Handles ensuring the name of rewritten expressions is not changed. /// /// This is important when optimizing plans to ensure the output -/// schema of plan nodes don't change after optimization -pub fn rewrite_preserving_name(expr: Expr, rewriter: &mut R) -> Result -where - R: TreeNodeRewriter, -{ - let original_name = expr.name_for_alias()?; - let expr = expr.rewrite(rewriter)?.data; - expr.alias_if_changed(original_name) +/// schema of plan nodes don't change after optimization. +/// For example, if an expression `1 + 2` is rewritten to `3`, the name of the +/// expression should be preserved: `3 as "1 + 2"` +/// +/// See for details +pub struct NamePreserver { + use_alias: bool, +} + +/// If the qualified name of an expression is remembered, it will be preserved +/// when rewriting the expression +pub enum SavedName { + /// Saved qualified name to be preserved + Saved { + relation: Option, + name: String, + }, + /// Name is not preserved + None, +} + +impl NamePreserver { + /// Create a new NamePreserver for rewriting the `expr` that is part of the specified plan + pub fn new(plan: &LogicalPlan) -> Self { + Self { + // The expressions of these plans do not contribute to their output schema, + // so there is no need to preserve expression names to prevent a schema change. + use_alias: !matches!( + plan, + LogicalPlan::Filter(_) + | LogicalPlan::Join(_) + | LogicalPlan::TableScan(_) + | LogicalPlan::Limit(_) + | LogicalPlan::Execute(_) + ), + } + } + + /// Create a new NamePreserver for rewriting the `expr`s in `Projection` + /// + /// This will use aliases + pub fn new_for_projection() -> Self { + Self { use_alias: true } + } + + pub fn save(&self, expr: &Expr) -> SavedName { + if self.use_alias { + let (relation, name) = expr.qualified_name(); + SavedName::Saved { relation, name } + } else { + SavedName::None + } + } +} + +impl SavedName { + /// Ensures the qualified name of the rewritten expression is preserved + pub fn restore(self, expr: Expr) -> Expr { + match self { + SavedName::Saved { relation, name } => { + let (new_relation, new_name) = expr.qualified_name(); + if new_relation != relation || new_name != name { + expr.alias_qualified(relation, name) + } else { + expr + } + } + SavedName::None => expr, + } + } } #[cfg(test)] @@ -284,9 +358,9 @@ mod test { use std::ops::Add; use super::*; - use crate::expr::Sort; use crate::{col, lit, Cast}; use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::tree_node::TreeNodeRewriter; use datafusion_common::ScalarValue; #[derive(Default)] @@ -446,15 +520,19 @@ mod test { // change literal type from i32 to i64 test_rewrite(col("a").add(lit(1i32)), col("a").add(lit(1i64))); - // SortExpr a+1 ==> b + 2 + // test preserve qualifier + test_rewrite( + Expr::Column(Column::new(Some("test"), "a")), + Expr::Column(Column::new_unqualified("test.a")), + ); test_rewrite( - Expr::Sort(Sort::new(Box::new(col("a").add(lit(1i32))), true, false)), - Expr::Sort(Sort::new(Box::new(col("b").add(lit(2i64))), true, false)), + Expr::Column(Column::new_unqualified("test.a")), + Expr::Column(Column::new(Some("test"), "a")), ); } - /// rewrites `expr_from` to `rewrite_to` using - /// `rewrite_preserving_name` verifying the result is `expected_expr` + /// rewrites `expr_from` to `rewrite_to` while preserving the original qualified name + /// by using the `NamePreserver` fn test_rewrite(expr_from: Expr, rewrite_to: Expr) { struct TestRewriter { rewrite_to: Expr, @@ -471,20 +549,12 @@ mod test { let mut rewriter = TestRewriter { rewrite_to: rewrite_to.clone(), }; - let expr = rewrite_preserving_name(expr_from.clone(), &mut rewriter).unwrap(); - - let original_name = match &expr_from { - Expr::Sort(Sort { expr, .. }) => expr.display_name(), - expr => expr.display_name(), - } - .unwrap(); - - let new_name = match &expr { - Expr::Sort(Sort { expr, .. }) => expr.display_name(), - expr => expr.display_name(), - } - .unwrap(); + let saved_name = NamePreserver { use_alias: true }.save(&expr_from); + let new_expr = expr_from.clone().rewrite(&mut rewriter).unwrap().data; + let new_expr = saved_name.restore(new_expr); + let original_name = expr_from.qualified_name(); + let new_name = new_expr.qualified_name(); assert_eq!( original_name, new_name, "mismatch rewriting expr_from: {expr_from} to {rewrite_to}" diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index eb38fee7cad0..f0d3d8fcd0c1 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -17,38 +17,28 @@ //! Rewrite for order by expressions -use crate::expr::{Alias, Sort}; +use crate::expr::Alias; use crate::expr_rewriter::normalize_col; -use crate::{Cast, Expr, ExprSchemable, LogicalPlan, TryCast}; +use crate::{expr::Sort, Cast, Expr, LogicalPlan, TryCast}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{Column, Result}; /// Rewrite sort on aggregate expressions to sort on the column of aggregate output -/// For example, `max(x)` is written to `col("MAX(x)")` +/// For example, `max(x)` is written to `col("max(x)")` pub fn rewrite_sort_cols_by_aggs( - exprs: impl IntoIterator>, + sorts: impl IntoIterator>, plan: &LogicalPlan, -) -> Result> { - exprs +) -> Result> { + sorts .into_iter() .map(|e| { - let expr = e.into(); - match expr { - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => { - let sort = Expr::Sort(Sort::new( - Box::new(rewrite_sort_col_by_aggs(*expr, plan)?), - asc, - nulls_first, - )); - Ok(sort) - } - expr => Ok(expr), - } + let sort = e.into(); + Ok(Sort::new( + rewrite_sort_col_by_aggs(sort.expr, plan)?, + sort.asc, + sort.nulls_first, + )) }) .collect() } @@ -87,11 +77,8 @@ fn rewrite_in_terms_of_projection( expr.transform(|expr| { // search for unnormalized names first such as "c1" (such as aliases) if let Some(found) = proj_exprs.iter().find(|a| (**a) == expr) { - let col = Expr::Column( - found - .to_field(input.schema()) - .map(|(qualifier, field)| Column::new(qualifier, field.name()))?, - ); + let (qualifier, field_name) = found.qualified_name(); + let col = Expr::Column(Column::new(qualifier, field_name)); return Ok(Transformed::yes(col)); } @@ -109,7 +96,7 @@ fn rewrite_in_terms_of_projection( // expr is an actual expr like min(t.c2), but we are looking // for a column with the same "MIN(C2)", so translate there - let name = normalized_expr.display_name()?; + let name = normalized_expr.schema_name().to_string(); let search_col = Expr::Column(Column { relation: None, @@ -156,11 +143,13 @@ mod test { use arrow::datatypes::{DataType, Field, Schema}; use crate::{ - avg, cast, col, lit, logical_plan::builder::LogicalTableSource, min, try_cast, + cast, col, lit, logical_plan::builder::LogicalTableSource, try_cast, LogicalPlanBuilder, }; use super::*; + use crate::test::function_stub::avg; + use crate::test::function_stub::min; #[test] fn rewrite_sort_cols_by_agg() { @@ -235,20 +224,20 @@ mod test { expected: sort(col("c1")), }, TestCase { - desc: r#"min(c2) --> "MIN(c2)" -- (column *named* "min(t.c2)"!)"#, + desc: r#"min(c2) --> "min(c2)" -- (column *named* "min(t.c2)"!)"#, input: sort(min(col("c2"))), - expected: sort(col("MIN(t.c2)")), + expected: sort(col("min(t.c2)")), }, TestCase { - desc: r#"c1 + min(c2) --> "c1 + MIN(c2)" -- (column *named* "min(t.c2)"!)"#, + desc: r#"c1 + min(c2) --> "c1 + min(c2)" -- (column *named* "min(t.c2)"!)"#, input: sort(col("c1") + min(col("c2"))), // should be "c1" not t.c1 - expected: sort(col("c1") + col("MIN(t.c2)")), + expected: sort(col("c1") + col("min(t.c2)")), }, TestCase { - desc: r#"avg(c3) --> "AVG(t.c3)" as average (column *named* "AVG(t.c3)", aliased)"#, + desc: r#"avg(c3) --> "avg(t.c3)" as average (column *named* "avg(t.c3)", aliased)"#, input: sort(avg(col("c3"))), - expected: sort(col("AVG(t.c3)").alias("average")), + expected: sort(col("avg(t.c3)").alias("average")), }, ]; @@ -287,8 +276,8 @@ mod test { struct TestCase { desc: &'static str, - input: Expr, - expected: Expr, + input: Sort, + expected: Sort, } impl TestCase { @@ -330,7 +319,7 @@ mod test { .unwrap() } - fn sort(expr: Expr) -> Expr { + fn sort(expr: Expr) -> Sort { let asc = true; let nulls_first = true; expr.sort(asc, nulls_first) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index c5ae0f1b831a..07a36672f272 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -17,44 +17,45 @@ use super::{Between, Expr, Like}; use crate::expr::{ - AggregateFunction, AggregateFunctionDefinition, Alias, BinaryExpr, Cast, - GetFieldAccess, GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction, - ScalarFunctionDefinition, Sort, TryCast, Unnest, WindowFunction, + AggregateFunction, Alias, BinaryExpr, Cast, InList, InSubquery, Placeholder, + ScalarFunction, TryCast, Unnest, WindowFunction, }; -use crate::field_util::GetFieldAccessSchema; use crate::type_coercion::binary::get_result_type; -use crate::type_coercion::functions::data_types; -use crate::{utils, LogicalPlan, Projection, Subquery}; +use crate::type_coercion::functions::{ + data_types_with_aggregate_udf, data_types_with_scalar_udf, data_types_with_window_udf, +}; +use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition}; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field}; use datafusion_common::{ - internal_err, not_impl_err, plan_datafusion_err, plan_err, Column, ExprSchema, - Result, TableReference, + not_impl_err, plan_datafusion_err, plan_err, Column, ExprSchema, Result, + TableReference, }; +use datafusion_functions_window_common::field::WindowUDFFieldArgs; use std::collections::HashMap; use std::sync::Arc; -/// trait to allow expr to typable with respect to a schema +/// Trait to allow expr to typable with respect to a schema pub trait ExprSchemable { - /// given a schema, return the type of the expr + /// Given a schema, return the type of the expr fn get_type(&self, schema: &dyn ExprSchema) -> Result; - /// given a schema, return the nullability of the expr + /// Given a schema, return the nullability of the expr fn nullable(&self, input_schema: &dyn ExprSchema) -> Result; - /// given a schema, return the expr's optional metadata + /// Given a schema, return the expr's optional metadata fn metadata(&self, schema: &dyn ExprSchema) -> Result>; - /// convert to a field with respect to a schema + /// Convert to a field with respect to a schema fn to_field( &self, input_schema: &dyn ExprSchema, ) -> Result<(Option, Arc)>; - /// cast to a type with respect to a schema + /// Cast to a type with respect to a schema fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result; - /// given a schema, return the type and nullability of the expr + /// Given a schema, return the type and nullability of the expr fn data_type_and_nullable(&self, schema: &dyn ExprSchema) -> Result<(DataType, bool)>; } @@ -69,8 +70,8 @@ impl ExprSchemable for Expr { /// /// # Examples /// - /// ## Get the type of an expression that adds 2 columns. Adding an Int32 - /// ## and Float32 results in Float32 type + /// Get the type of an expression that adds 2 columns. Adding an Int32 + /// and Float32 results in Float32 type /// /// ``` /// # use arrow::datatypes::{DataType, Field}; @@ -80,7 +81,7 @@ impl ExprSchemable for Expr { /// /// fn main() { /// let expr = col("c1") + col("c2"); - /// let schema = DFSchema::from_unqualifed_fields( + /// let schema = DFSchema::from_unqualified_fields( /// vec![ /// Field::new("c1", DataType::Int32, true), /// Field::new("c2", DataType::Float32, true), @@ -107,83 +108,87 @@ impl ExprSchemable for Expr { }, _ => expr.get_type(schema), }, - Expr::Sort(Sort { expr, .. }) | Expr::Negative(expr) => expr.get_type(schema), + Expr::Negative(expr) => expr.get_type(schema), Expr::Column(c) => Ok(schema.data_type(c)?.clone()), Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()), Expr::ScalarVariable(ty, _) => Ok(ty.clone()), Expr::Literal(l) => Ok(l.data_type()), - Expr::Case(case) => case.when_then_expr[0].1.get_type(schema), + Expr::Case(case) => { + for (_, then_expr) in &case.when_then_expr { + let then_type = then_expr.get_type(schema)?; + if !then_type.is_null() { + return Ok(then_type); + } + } + case.else_expr + .as_ref() + .map_or(Ok(DataType::Null), |e| e.get_type(schema)) + } Expr::Cast(Cast { data_type, .. }) | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()), Expr::Unnest(Unnest { expr }) => { let arg_data_type = expr.get_type(schema)?; // Unnest's output type is the inner type of the list - match arg_data_type{ - DataType::List(field) | DataType::LargeList(field) | DataType::FixedSizeList(field, _) =>{ - Ok(field.data_type().clone()) - } - DataType::Struct(_) => { - not_impl_err!("unnest() does not support struct yet") - } + match arg_data_type { + DataType::List(field) + | DataType::LargeList(field) + | DataType::FixedSizeList(field, _) => Ok(field.data_type().clone()), + DataType::Struct(_) => Ok(arg_data_type), DataType::Null => { not_impl_err!("unnest() does not support null yet") } _ => { - plan_err!("unnest() can only be applied to array, struct and null") + plan_err!( + "unnest() can only be applied to array, struct and null" + ) } } } - Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + Expr::ScalarFunction(ScalarFunction { func, args }) => { let arg_data_types = args .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - match func_def { - ScalarFunctionDefinition::UDF(fun) => { - // verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` - data_types(&arg_data_types, fun.signature()).map_err(|_| { - plan_datafusion_err!( - "{}", - utils::generate_signature_error_msg( - fun.name(), - fun.signature().clone(), - &arg_data_types, - ) + + // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` + let new_data_types = data_types_with_scalar_udf(&arg_data_types, func) + .map_err(|err| { + plan_datafusion_err!( + "{} {}", + err, + utils::generate_signature_error_msg( + func.name(), + func.signature().clone(), + &arg_data_types, ) - })?; + ) + })?; - // perform additional function arguments validation (due to limited - // expressiveness of `TypeSignature`), then infer return type - Ok(fun.return_type_from_exprs(args, schema, &arg_data_types)?) - } - ScalarFunctionDefinition::Name(_) => { - internal_err!("Function `Expr` with name should be resolved.") - } - } + // Perform additional function arguments validation (due to limited + // expressiveness of `TypeSignature`), then infer return type + Ok(func.return_type_from_exprs(args, schema, &new_data_types)?) } - Expr::WindowFunction(WindowFunction { fun, args, .. }) => { + Expr::WindowFunction(window_function) => self + .data_type_and_nullable_with_window_function(schema, window_function) + .map(|(return_type, _)| return_type), + Expr::AggregateFunction(AggregateFunction { func, args, .. }) => { let data_types = args .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - fun.return_type(&data_types) - } - Expr::AggregateFunction(AggregateFunction { func_def, args, .. }) => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => { - fun.return_type(&data_types) - } - AggregateFunctionDefinition::UDF(fun) => { - Ok(fun.return_type(&data_types)?) - } - AggregateFunctionDefinition::Name(_) => { - internal_err!("Function `Expr` with name should be resolved.") - } - } + let new_types = data_types_with_aggregate_udf(&data_types, func) + .map_err(|err| { + plan_datafusion_err!( + "{} {}", + err, + utils::generate_signature_error_msg( + func.name(), + func.signature().clone(), + &data_types + ) + ) + })?; + Ok(func.return_type(&new_types)?) } Expr::Not(_) | Expr::IsNull(_) @@ -209,23 +214,18 @@ impl ExprSchemable for Expr { Expr::Like { .. } | Expr::SimilarTo { .. } => Ok(DataType::Boolean), Expr::Placeholder(Placeholder { data_type, .. }) => { data_type.clone().ok_or_else(|| { - plan_datafusion_err!("Placeholder type could not be resolved. Make sure that the placeholder is bound to a concrete type, e.g. by providing parameter values.") + plan_datafusion_err!( + "Placeholder type could not be resolved. Make sure that the \ + placeholder is bound to a concrete type, e.g. by providing \ + parameter values." + ) }) } - Expr::Wildcard { qualifier } => { - // Wildcard do not really have a type and do not appear in projections - match qualifier { - Some(_) => internal_err!("QualifiedWildcard expressions are not valid in a logical query plan"), - None => Ok(DataType::Null) - } - } + Expr::Wildcard { .. } => Ok(DataType::Null), Expr::GroupingSet(_) => { - // grouping sets do not really have a type and do not appear in projections + // Grouping sets do not really have a type and do not appear in projections Ok(DataType::Null) } - Expr::GetIndexedField(GetIndexedField { expr, field }) => { - field_for_index(expr, field, schema).map(|x| x.data_type().clone()) - } } } @@ -242,10 +242,9 @@ impl ExprSchemable for Expr { /// column that does not exist in the schema. fn nullable(&self, input_schema: &dyn ExprSchema) -> Result { match self { - Expr::Alias(Alias { expr, .. }) - | Expr::Not(expr) - | Expr::Negative(expr) - | Expr::Sort(Sort { expr, .. }) => expr.nullable(input_schema), + Expr::Alias(Alias { expr, .. }) | Expr::Not(expr) | Expr::Negative(expr) => { + expr.nullable(input_schema) + } Expr::InList(InList { expr, list, .. }) => { // Avoid inspecting too many expressions. @@ -280,7 +279,7 @@ impl ExprSchemable for Expr { Expr::OuterReferenceColumn(_, _) => Ok(true), Expr::Literal(value) => Ok(value.is_null()), Expr::Case(case) => { - // this expression is nullable if any of the input expressions are nullable + // This expression is nullable if any of the input expressions are nullable let then_nullable = case .when_then_expr .iter() @@ -297,11 +296,20 @@ impl ExprSchemable for Expr { } } Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema), + Expr::ScalarFunction(ScalarFunction { func, args }) => { + Ok(func.is_nullable(args, input_schema)) + } + Expr::AggregateFunction(AggregateFunction { func, .. }) => { + Ok(func.is_nullable()) + } + Expr::WindowFunction(window_function) => self + .data_type_and_nullable_with_window_function( + input_schema, + window_function, + ) + .map(|(_, nullable)| nullable), Expr::ScalarVariable(_, _) | Expr::TryCast { .. } - | Expr::ScalarFunction(..) - | Expr::WindowFunction { .. } - | Expr::AggregateFunction { .. } | Expr::Unnest(_) | Expr::Placeholder(_) => Ok(true), Expr::IsNull(_) @@ -326,21 +334,9 @@ impl ExprSchemable for Expr { | Expr::SimilarTo(Like { expr, pattern, .. }) => { Ok(expr.nullable(input_schema)? || pattern.nullable(input_schema)?) } - Expr::Wildcard { .. } => internal_err!( - "Wildcard expressions are not valid in a logical query plan" - ), - Expr::GetIndexedField(GetIndexedField { expr, field }) => { - // If schema is nested, check if parent is nullable - // if it is, return early - if let Expr::Column(col) = expr.as_ref() { - if input_schema.nullable(col)? { - return Ok(true); - } - } - field_for_index(expr, field, input_schema).map(|x| x.is_nullable()) - } + Expr::Wildcard { .. } => Ok(false), Expr::GroupingSet(_) => { - // grouping sets do not really have the concept of nullable and do not appear + // Grouping sets do not really have the concept of nullable and do not appear // in projections Ok(true) } @@ -379,9 +375,7 @@ impl ExprSchemable for Expr { }, _ => expr.data_type_and_nullable(schema), }, - Expr::Sort(Sort { expr, .. }) | Expr::Negative(expr) => { - expr.data_type_and_nullable(schema) - } + Expr::Negative(expr) => expr.data_type_and_nullable(schema), Expr::Column(c) => schema .data_type_and_nullable(c) .map(|(d, n)| (d.clone(), n)), @@ -410,6 +404,9 @@ impl ExprSchemable for Expr { let right = right.data_type_and_nullable(schema)?; Ok((get_result_type(&left.0, op, &right.0)?, left.1 || right.1)) } + Expr::WindowFunction(window_function) => { + self.data_type_and_nullable_with_window_function(schema, window_function) + } _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), } } @@ -422,35 +419,12 @@ impl ExprSchemable for Expr { &self, input_schema: &dyn ExprSchema, ) -> Result<(Option, Arc)> { - match self { - Expr::Column(c) => { - let (data_type, nullable) = self.data_type_and_nullable(input_schema)?; - Ok(( - c.relation.clone(), - Field::new(&c.name, data_type, nullable) - .with_metadata(self.metadata(input_schema)?) - .into(), - )) - } - Expr::Alias(Alias { relation, name, .. }) => { - let (data_type, nullable) = self.data_type_and_nullable(input_schema)?; - Ok(( - relation.clone(), - Field::new(name, data_type, nullable) - .with_metadata(self.metadata(input_schema)?) - .into(), - )) - } - _ => { - let (data_type, nullable) = self.data_type_and_nullable(input_schema)?; - Ok(( - None, - Field::new(self.display_name()?, data_type, nullable) - .with_metadata(self.metadata(input_schema)?) - .into(), - )) - } - } + let (relation, schema_name) = self.qualified_name(); + let (data_type, nullable) = self.data_type_and_nullable(input_schema)?; + let field = Field::new(schema_name, data_type, nullable) + .with_metadata(self.metadata(input_schema)?) + .into(); + Ok((relation, field)) } /// Wraps this expression in a cast to a target [arrow::datatypes::DataType]. @@ -465,7 +439,7 @@ impl ExprSchemable for Expr { return Ok(self); } - // TODO(kszucs): most of the operations do not validate the type correctness + // TODO(kszucs): Most of the operations do not validate the type correctness // like all of the binary expressions below. Perhaps Expr should track the // type of the expression? @@ -482,34 +456,84 @@ impl ExprSchemable for Expr { } } -/// return the schema [`Field`] for the type referenced by `get_indexed_field` -fn field_for_index( - expr: &Expr, - field: &GetFieldAccess, - schema: &dyn ExprSchema, -) -> Result { - let expr_dt = expr.get_type(schema)?; - match field { - GetFieldAccess::NamedStructField { name } => { - GetFieldAccessSchema::NamedStructField { name: name.clone() } +impl Expr { + /// Common method for window functions that applies type coercion + /// to all arguments of the window function to check if it matches + /// its signature. + /// + /// If successful, this method returns the data type and + /// nullability of the window function's result. + /// + /// Otherwise, returns an error if there's a type mismatch between + /// the window function's signature and the provided arguments. + fn data_type_and_nullable_with_window_function( + &self, + schema: &dyn ExprSchema, + window_function: &WindowFunction, + ) -> Result<(DataType, bool)> { + let WindowFunction { fun, args, .. } = window_function; + + let data_types = args + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + match fun { + WindowFunctionDefinition::BuiltInWindowFunction(window_fun) => { + let return_type = window_fun.return_type(&data_types)?; + let nullable = + !["RANK", "NTILE", "CUME_DIST"].contains(&window_fun.name()); + Ok((return_type, nullable)) + } + WindowFunctionDefinition::AggregateUDF(udaf) => { + let new_types = data_types_with_aggregate_udf(&data_types, udaf) + .map_err(|err| { + plan_datafusion_err!( + "{} {}", + err, + utils::generate_signature_error_msg( + fun.name(), + fun.signature(), + &data_types + ) + ) + })?; + + let return_type = udaf.return_type(&new_types)?; + let nullable = udaf.is_nullable(); + + Ok((return_type, nullable)) + } + WindowFunctionDefinition::WindowUDF(udwf) => { + let new_types = + data_types_with_window_udf(&data_types, udwf).map_err(|err| { + plan_datafusion_err!( + "{} {}", + err, + utils::generate_signature_error_msg( + fun.name(), + fun.signature(), + &data_types + ) + ) + })?; + let (_, function_name) = self.qualified_name(); + let field_args = WindowUDFFieldArgs::new(&new_types, &function_name); + + udwf.field(field_args) + .map(|field| (field.data_type().clone(), field.is_nullable())) + } } - GetFieldAccess::ListIndex { key } => GetFieldAccessSchema::ListIndex { - key_dt: key.get_type(schema)?, - }, - GetFieldAccess::ListRange { - start, - stop, - stride, - } => GetFieldAccessSchema::ListRange { - start_dt: start.get_type(schema)?, - stop_dt: stop.get_type(schema)?, - stride_dt: stride.get_type(schema)?, - }, } - .get_accessed_field(&expr_dt) } -/// cast subquery in InSubquery/ScalarSubquery to a given type. +/// Cast subquery in InSubquery/ScalarSubquery to a given type. +/// +/// 1. **Projection plan**: If the subquery is a projection (i.e. a SELECT statement with specific +/// columns), it casts the first expression in the projection to the target type and creates a +/// new projection with the casted expression. +/// 2. **Non-projection plan**: If the subquery isn't a projection, it adds a projection to the plan +/// with the casted first column. +/// pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result { if subquery.subquery.schema().field(0).data_type() == cast_to_type { return Ok(subquery); @@ -523,7 +547,7 @@ pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result { @@ -545,8 +569,8 @@ pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result {{ @@ -671,7 +695,7 @@ mod tests { .unwrap() ); - let schema = DFSchema::from_unqualifed_fields( + let schema = DFSchema::from_unqualified_fields( vec![Field::new("foo", DataType::Int32, true).with_metadata(meta.clone())] .into(), HashMap::new(), @@ -682,31 +706,6 @@ mod tests { assert_eq!(&meta, expr.to_field(&schema).unwrap().1.metadata()); } - #[test] - fn test_nested_schema_nullability() { - let mut builder = SchemaBuilder::new(); - builder.push(Field::new("foo", DataType::Int32, true)); - builder.push(Field::new( - "parent", - DataType::Struct(Fields::from(vec![Field::new( - "child", - DataType::Int64, - false, - )])), - true, - )); - let schema = builder.finish(); - - let dfschema = DFSchema::from_field_specific_qualified_schema( - vec![Some("table_name".into()), None], - &Arc::new(schema), - ) - .unwrap(); - - let expr = col("parent").field("child"); - assert!(expr.nullable(&dfschema).unwrap()); - } - #[derive(Debug)] struct MockExprSchema { nullable: bool, diff --git a/datafusion/expr/src/field_util.rs b/datafusion/expr/src/field_util.rs deleted file mode 100644 index f0ce61ee9bbb..000000000000 --- a/datafusion/expr/src/field_util.rs +++ /dev/null @@ -1,100 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Utility functions for complex field access - -use arrow::datatypes::{DataType, Field}; -use datafusion_common::{plan_datafusion_err, plan_err, Result, ScalarValue}; - -/// Types of the field access expression of a nested type, such as `Field` or `List` -pub enum GetFieldAccessSchema { - /// Named field, For example `struct["name"]` - NamedStructField { name: ScalarValue }, - /// Single list index, for example: `list[i]` - ListIndex { key_dt: DataType }, - /// List stride, for example `list[i:j:k]` - ListRange { - start_dt: DataType, - stop_dt: DataType, - stride_dt: DataType, - }, -} - -impl GetFieldAccessSchema { - /// Returns the schema [`Field`] from a [`DataType::List`] or - /// [`DataType::Struct`] indexed by this structure - /// - /// # Error - /// Errors if - /// * the `data_type` is not a Struct or a List, - /// * the `data_type` of the name/index/start-stop do not match a supported index type - pub fn get_accessed_field(&self, data_type: &DataType) -> Result { - match self { - Self::NamedStructField{ name } => { - match (data_type, name) { - (DataType::Map(fields, _), _) => { - match fields.data_type() { - DataType::Struct(fields) if fields.len() == 2 => { - // Arrow's MapArray is essentially a ListArray of structs with two columns. They are - // often named "key", and "value", but we don't require any specific naming here; - // instead, we assume that the second columnis the "value" column both here and in - // execution. - let value_field = fields.get(1).expect("fields should have exactly two members"); - Ok(Field::new("map", value_field.data_type().clone(), true)) - }, - _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), - } - } - (DataType::Struct(fields), ScalarValue::Utf8(Some(s))) => { - if s.is_empty() { - plan_err!( - "Struct based indexed access requires a non empty string" - ) - } else { - let field = fields.iter().find(|f| f.name() == s); - field.ok_or(plan_datafusion_err!("Field {s} not found in struct")).map(|f| f.as_ref().clone()) - } - } - (DataType::Struct(_), _) => plan_err!( - "Only utf8 strings are valid as an indexed field in a struct" - ), - (other, _) => plan_err!("The expression to get an indexed field is only valid for `List`, `Struct`, or `Map` types, got {other}"), - } - } - Self::ListIndex{ key_dt } => { - match (data_type, key_dt) { - (DataType::List(lt), DataType::Int64) => Ok(Field::new("list", lt.data_type().clone(), true)), - (DataType::LargeList(lt), DataType::Int64) => Ok(Field::new("large_list", lt.data_type().clone(), true)), - (DataType::List(_), _) | (DataType::LargeList(_), _) => plan_err!( - "Only ints are valid as an indexed field in a List/LargeList" - ), - (other, _) => plan_err!("The expression to get an indexed field is only valid for `List`, `LargeList` or `Struct` types, got {other}"), - } - } - Self::ListRange { start_dt, stop_dt, stride_dt } => { - match (data_type, start_dt, stop_dt, stride_dt) { - (DataType::List(_), DataType::Int64, DataType::Int64, DataType::Int64) => Ok(Field::new("list", data_type.clone(), true)), - (DataType::LargeList(_), DataType::Int64, DataType::Int64, DataType::Int64) => Ok(Field::new("large_list", data_type.clone(), true)), - (DataType::List(_), _, _, _) | (DataType::LargeList(_), _, _, _)=> plan_err!( - "Only ints are valid as an indexed field in a List/LargeList" - ), - (other, _, _, _) => plan_err!("The expression to get an indexed field is only valid for `List`, `LargeList` or `Struct` types, got {other}"), - } - } - } - } -} diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 7a92a50ae15d..23ffc83e3549 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -18,11 +18,27 @@ //! Function module contains typing and signature for built-in and user defined functions. use crate::ColumnarValue; -use crate::{Accumulator, Expr, PartitionEvaluator}; -use arrow::datatypes::{DataType, Schema}; +use crate::{Expr, PartitionEvaluator}; +use arrow::datatypes::DataType; use datafusion_common::Result; use std::sync::Arc; +pub use datafusion_functions_aggregate_common::accumulator::{ + AccumulatorArgs, AccumulatorFactoryFunction, StateFieldsArgs, +}; + +pub use datafusion_functions_window_common::expr::ExpressionArgs; +pub use datafusion_functions_window_common::field::WindowUDFFieldArgs; +pub use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; + +#[derive(Debug, Clone, Copy)] +pub enum Hint { + /// Indicates the argument needs to be padded if it is scalar + Pad, + /// Indicates the argument can be converted to an array of length 1 + AcceptsSingular, +} + /// Scalar function /// /// The Fn param is the wrapped function but be aware that the function will @@ -38,56 +54,6 @@ pub type ScalarFunctionImplementation = pub type ReturnTypeFunction = Arc Result> + Send + Sync>; -/// [`AccumulatorArgs`] contains information about how an aggregate -/// function was called, including the types of its arguments and any optional -/// ordering expressions. -pub struct AccumulatorArgs<'a> { - /// The return type of the aggregate function. - pub data_type: &'a DataType, - /// The schema of the input arguments - pub schema: &'a Schema, - /// Whether to ignore nulls. - /// - /// SQL allows the user to specify `IGNORE NULLS`, for example: - /// - /// ```sql - /// SELECT FIRST_VALUE(column1) IGNORE NULLS FROM t; - /// ``` - pub ignore_nulls: bool, - - /// The expressions in the `ORDER BY` clause passed to this aggregator. - /// - /// SQL allows the user to specify the ordering of arguments to the - /// aggregate using an `ORDER BY`. For example: - /// - /// ```sql - /// SELECT FIRST_VALUE(column1 ORDER BY column2) FROM t; - /// ``` - /// - /// If no `ORDER BY` is specified, `sort_exprs`` will be empty. - pub sort_exprs: &'a [Expr], -} - -impl<'a> AccumulatorArgs<'a> { - pub fn new( - data_type: &'a DataType, - schema: &'a Schema, - ignore_nulls: bool, - sort_exprs: &'a [Expr], - ) -> Self { - Self { - data_type, - schema, - ignore_nulls, - sort_exprs, - } - } -} - -/// Factory that returns an accumulator for the given aggregate function. -pub type AccumulatorFactoryFunction = - Arc Result> + Send + Sync>; - /// Factory that creates a PartitionEvaluator for the given window /// function pub type PartitionEvaluatorFactory = @@ -97,3 +63,29 @@ pub type PartitionEvaluatorFactory = /// its state, given its return datatype. pub type StateTypeFunction = Arc Result>> + Send + Sync>; + +/// [crate::udaf::AggregateUDFImpl::simplify] simplifier closure +/// A closure with two arguments: +/// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplified has been invoked +/// * 'info': [crate::simplify::SimplifyInfo] +/// +///Cclosure returns simplified [Expr] or an error. +pub type AggregateFunctionSimplification = Box< + dyn Fn( + crate::expr::AggregateFunction, + &dyn crate::simplify::SimplifyInfo, + ) -> Result, +>; + +/// [crate::udwf::WindowUDFImpl::simplify] simplifier closure +/// A closure with two arguments: +/// * 'window_function': [crate::expr::WindowFunction] for which simplified has been invoked +/// * 'info': [crate::simplify::SimplifyInfo] +/// +/// Closure returns simplified [Expr] or an error. +pub type WindowFunctionSimplification = Box< + dyn Fn( + crate::expr::WindowFunction, + &dyn crate::simplify::SimplifyInfo, + ) -> Result, +>; diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index de4f31029293..849d9604808c 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -14,6 +14,8 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] //! [DataFusion](https://github.com/apache/datafusion) //! is an extensible query execution framework that uses @@ -25,45 +27,58 @@ //! //! The [expr_fn] module contains functions for creating expressions. -mod accumulator; mod built_in_window_function; -mod columnar_value; mod literal; -mod operator; +mod operation; mod partition_evaluator; -mod signature; mod table_source; mod udaf; mod udf; +mod udf_docs; mod udwf; -pub mod aggregate_function; pub mod conditional_expressions; pub mod execution_props; pub mod expr; pub mod expr_fn; pub mod expr_rewriter; pub mod expr_schema; -pub mod field_util; pub mod function; -pub mod groups_accumulator; -pub mod interval_arithmetic; +pub mod groups_accumulator { + pub use datafusion_expr_common::groups_accumulator::*; +} + +pub mod interval_arithmetic { + pub use datafusion_expr_common::interval_arithmetic::*; +} pub mod logical_plan; +pub mod planner; +pub mod registry; pub mod simplify; +pub mod sort_properties { + pub use datafusion_expr_common::sort_properties::*; +} +pub mod test; pub mod tree_node; pub mod type_coercion; pub mod utils; pub mod var_provider; pub mod window_frame; +pub mod window_function; pub mod window_state; -pub use accumulator::Accumulator; -pub use aggregate_function::AggregateFunction; pub use built_in_window_function::BuiltInWindowFunction; -pub use columnar_value::ColumnarValue; +pub use datafusion_expr_common::accumulator::Accumulator; +pub use datafusion_expr_common::columnar_value::ColumnarValue; +pub use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; +pub use datafusion_expr_common::operator::Operator; +pub use datafusion_expr_common::signature::{ + ArrayFunctionSignature, Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD, +}; +pub use datafusion_expr_common::type_coercion::binary; pub use expr::{ - Between, BinaryExpr, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, - Like, ScalarFunctionDefinition, TryCast, WindowFunctionDefinition, + Between, BinaryExpr, Case, Cast, Expr, GetFieldAccess, GroupingSet, Like, + Sort as SortExpr, TryCast, WindowFunctionDefinition, }; pub use expr_fn::*; pub use expr_schema::ExprSchemable; @@ -71,19 +86,17 @@ pub use function::{ AccumulatorFactoryFunction, PartitionEvaluatorFactory, ReturnTypeFunction, ScalarFunctionImplementation, StateTypeFunction, }; -pub use groups_accumulator::{EmitTo, GroupsAccumulator}; pub use literal::{lit, lit_timestamp_nano, Literal, TimestampLiteral}; pub use logical_plan::*; -pub use operator::Operator; pub use partition_evaluator::PartitionEvaluator; -pub use signature::{ - ArrayFunctionSignature, FuncMonotonicity, Signature, TypeSignature, Volatility, - TIMEZONE_WILDCARD, -}; +pub use sqlparser; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; -pub use udaf::{AggregateUDF, AggregateUDFImpl}; -pub use udf::{ScalarUDF, ScalarUDFImpl}; -pub use udwf::{WindowUDF, WindowUDFImpl}; +pub use udaf::{ + aggregate_doc_sections, AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs, +}; +pub use udf::{scalar_doc_sections, ScalarUDF, ScalarUDFImpl}; +pub use udf_docs::{DocSection, Documentation, DocumentationBuilder}; +pub use udwf::{window_doc_sections, ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; #[cfg(test)] diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 43873cb90cda..b7839c4873af 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -20,46 +20,52 @@ use std::any::Any; use std::cmp::Ordering; use std::collections::{HashMap, HashSet}; -use std::iter::zip; +use std::iter::once; use std::sync::Arc; use crate::dml::CopyTo; -use crate::expr::Alias; +use crate::expr::{Alias, Sort as SortExpr}; use crate::expr_rewriter::{ coerce_plan_expr_for_schema, normalize_col, - normalize_col_with_schemas_and_ambiguity_check, normalize_cols, + normalize_col_with_schemas_and_ambiguity_check, normalize_cols, normalize_sorts, rewrite_sort_cols_by_aggs, }; use crate::logical_plan::{ - Aggregate, Analyze, CrossJoin, Distinct, DistinctOn, EmptyRelation, Explain, Filter, - Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, + Aggregate, Analyze, Distinct, DistinctOn, EmptyRelation, Explain, Filter, Join, + JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values, Window, }; -use crate::type_coercion::binary::comparison_coercion; use crate::utils::{ - can_hash, columnize_expr, compare_sort_expr, expand_qualified_wildcard, - expand_wildcard, find_valid_equijoin_key_pair, group_window_expr_by_sort_keys, + can_hash, columnize_expr, compare_sort_expr, expr_to_columns, + find_valid_equijoin_key_pair, group_window_expr_by_sort_keys, }; use crate::{ - and, binary_expr, DmlStatement, Expr, ExprSchemable, Operator, RecursiveQuery, + and, binary_expr, lit, DmlStatement, Expr, ExprSchemable, Operator, RecursiveQuery, TableProviderFilterPushDown, TableSource, WriteOp, }; +use super::dml::InsertOp; +use super::plan::ColumnUnnestList; +use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; -use datafusion_common::config::FormatOptions; use datafusion_common::display::ToStringifiedPlan; +use datafusion_common::file_options::file_type::FileType; use datafusion_common::{ - get_target_functional_dependencies, not_impl_err, plan_datafusion_err, plan_err, - Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference, - ToDFSchema, UnnestOptions, + exec_err, get_target_functional_dependencies, internal_err, not_impl_err, + plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, + FunctionalDependencies, Result, ScalarValue, TableReference, ToDFSchema, + UnnestOptions, }; +use datafusion_expr_common::type_coercion::binary::type_union_resolution; +use indexmap::IndexSet; /// Default table name for unnamed table pub const UNNAMED_TABLE: &str = "?table?"; /// Builder for logical plans /// +/// # Example building a simple plan /// ``` /// # use datafusion_expr::{lit, col, LogicalPlanBuilder, logical_plan::table_scan}; /// # use datafusion_common::Result; @@ -81,28 +87,34 @@ pub const UNNAMED_TABLE: &str = "?table?"; /// // SELECT last_name /// // FROM employees /// // WHERE salary < 1000 -/// let plan = table_scan( -/// Some("employee"), -/// &employee_schema(), -/// None, -/// )? -/// // Keep only rows where salary < 1000 -/// .filter(col("salary").lt_eq(lit(1000)))? -/// // only show "last_name" in the final results -/// .project(vec![col("last_name")])? -/// .build()?; +/// let plan = table_scan(Some("employee"), &employee_schema(), None)? +/// // Keep only rows where salary < 1000 +/// .filter(col("salary").lt(lit(1000)))? +/// // only show "last_name" in the final results +/// .project(vec![col("last_name")])? +/// .build()?; +/// +/// // Convert from plan back to builder +/// let builder = LogicalPlanBuilder::from(plan); /// /// # Ok(()) /// # } /// ``` #[derive(Debug, Clone)] pub struct LogicalPlanBuilder { - plan: LogicalPlan, + plan: Arc, } impl LogicalPlanBuilder { /// Create a builder from an existing plan - pub fn from(plan: LogicalPlan) -> Self { + pub fn new(plan: LogicalPlan) -> Self { + Self { + plan: Arc::new(plan), + } + } + + /// Create a builder from an existing plan + pub fn new_from_arc(plan: Arc) -> Self { Self { plan } } @@ -111,11 +123,16 @@ impl LogicalPlanBuilder { self.plan.schema() } + /// Return the LogicalPlan of the plan build so far + pub fn plan(&self) -> &LogicalPlan { + &self.plan + } + /// Create an empty relation. /// /// `produce_one_row` set to true means this empty node needs to produce a placeholder row. pub fn empty(produce_one_row: bool) -> Self { - Self::from(LogicalPlan::EmptyRelation(EmptyRelation { + Self::new(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row, schema: DFSchemaRef::new(DFSchema::empty()), })) @@ -124,7 +141,7 @@ impl LogicalPlanBuilder { /// Convert a regular plan into a recursive query. /// `is_distinct` indicates whether the recursive term should be de-duplicated (`UNION`) after each iteration or not (`UNION ALL`). pub fn to_recursive_query( - &self, + self, name: String, recursive_term: LogicalPlan, is_distinct: bool, @@ -146,10 +163,10 @@ impl LogicalPlanBuilder { } // Ensure that the recursive term has the same field types as the static term let coerced_recursive_term = - coerce_plan_expr_for_schema(&recursive_term, self.plan.schema())?; + coerce_plan_expr_for_schema(recursive_term, self.plan.schema())?; Ok(Self::from(LogicalPlan::RecursiveQuery(RecursiveQuery { name, - static_term: Arc::new(self.plan.clone()), + static_term: self.plan, recursive_term: Arc::new(coerced_recursive_term), is_distinct, }))) @@ -159,12 +176,45 @@ impl LogicalPlanBuilder { /// `value`. See the [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html) /// documentation for more details. /// + /// so it's usually better to override the default names with a table alias list. + /// + /// If the values include params/binders such as $1, $2, $3, etc, then the `param_data_types` should be provided. + pub fn values(values: Vec>) -> Result { + if values.is_empty() { + return plan_err!("Values list cannot be empty"); + } + let n_cols = values[0].len(); + if n_cols == 0 { + return plan_err!("Values list cannot be zero length"); + } + for (i, row) in values.iter().enumerate() { + if row.len() != n_cols { + return plan_err!( + "Inconsistent data length across values list: got {} values in row {} but expected {}", + row.len(), + i, + n_cols + ); + } + } + + // Infer from data itself + Self::infer_data(values) + } + + /// Create a values list based relation, and the schema is inferred from data itself or table schema if provided, consuming + /// `value`. See the [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html) + /// documentation for more details. + /// /// By default, it assigns the names column1, column2, etc. to the columns of a VALUES table. /// The column names are not specified by the SQL standard and different database systems do it differently, /// so it's usually better to override the default names with a table alias list. /// /// If the values include params/binders such as $1, $2, $3, etc, then the `param_data_types` should be provided. - pub fn values(mut values: Vec>) -> Result { + pub fn values_with_schema( + values: Vec>, + schema: &DFSchemaRef, + ) -> Result { if values.is_empty() { return plan_err!("Values list cannot be empty"); } @@ -172,13 +222,6 @@ impl LogicalPlanBuilder { if n_cols == 0 { return plan_err!("Values list cannot be zero length"); } - let empty_schema = DFSchema::empty(); - let mut field_types: Vec> = Vec::with_capacity(n_cols); - for _ in 0..n_cols { - field_types.push(None); - } - // hold all the null holes so that we can correct their data types later - let mut nulls: Vec<(usize, usize)> = Vec::new(); for (i, row) in values.iter().enumerate() { if row.len() != n_cols { return plan_err!( @@ -188,24 +231,88 @@ impl LogicalPlanBuilder { n_cols ); } - field_types = row - .iter() - .enumerate() - .map(|(j, expr)| { - if let Expr::Literal(ScalarValue::Null) = expr { - nulls.push((i, j)); - Ok(field_types[j].clone()) + } + + // Check the type of value against the schema + Self::infer_values_from_schema(values, schema) + } + + fn infer_values_from_schema( + values: Vec>, + schema: &DFSchema, + ) -> Result { + let n_cols = values[0].len(); + let mut field_types: Vec = Vec::with_capacity(n_cols); + for j in 0..n_cols { + let field_type = schema.field(j).data_type(); + for row in values.iter() { + let value = &row[j]; + let data_type = value.get_type(schema)?; + + if !data_type.equals_datatype(field_type) { + if can_cast_types(&data_type, field_type) { } else { - let data_type = expr.get_type(&empty_schema)?; - if let Some(prev_data_type) = &field_types[j] { - if prev_data_type != &data_type { - return plan_err!("Inconsistent data type across values list at row {i} column {j}. Was {prev_data_type} but found {data_type}") - } - } - Ok(Some(data_type)) + return exec_err!( + "type mistmatch and can't cast to got {} and {}", + data_type, + field_type + ); } - }) - .collect::>>>()?; + } + } + field_types.push(field_type.to_owned()); + } + + Self::infer_inner(values, &field_types, schema) + } + + fn infer_data(values: Vec>) -> Result { + let n_cols = values[0].len(); + let schema = DFSchema::empty(); + + let mut field_types: Vec = Vec::with_capacity(n_cols); + for j in 0..n_cols { + let mut common_type: Option = None; + for (i, row) in values.iter().enumerate() { + let value = &row[j]; + let data_type = value.get_type(&schema)?; + if data_type == DataType::Null { + continue; + } + + if let Some(prev_type) = common_type { + // get common type of each column values. + let data_types = vec![prev_type.clone(), data_type.clone()]; + let Some(new_type) = type_union_resolution(&data_types) else { + return plan_err!("Inconsistent data type across values list at row {i} column {j}. Was {prev_type} but found {data_type}"); + }; + common_type = Some(new_type); + } else { + common_type = Some(data_type); + } + } + // assuming common_type was not set, and no error, therefore the type should be NULL + // since the code loop skips NULL + field_types.push(common_type.unwrap_or(DataType::Null)); + } + + Self::infer_inner(values, &field_types, &schema) + } + + fn infer_inner( + mut values: Vec>, + field_types: &[DataType], + schema: &DFSchema, + ) -> Result { + // wrap cast if data type is not same as common type. + for row in &mut values { + for (j, field_type) in field_types.iter().enumerate() { + if let Expr::Literal(ScalarValue::Null) = row[j] { + row[j] = Expr::Literal(ScalarValue::try_from(field_type)?); + } else { + row[j] = std::mem::take(&mut row[j]).cast_to(field_type, schema)?; + } + } } let fields = field_types .iter() @@ -213,15 +320,13 @@ impl LogicalPlanBuilder { .map(|(j, data_type)| { // naming is following convention https://www.postgresql.org/docs/current/queries-values.html let name = &format!("column{}", j + 1); - Field::new(name, data_type.clone().unwrap_or(DataType::Utf8), true) + Field::new(name, data_type.clone(), true) }) .collect::>(); - for (i, j) in nulls { - values[i][j] = Expr::Literal(ScalarValue::try_from(fields[j].data_type())?); - } - let dfschema = DFSchema::from_unqualifed_fields(fields.into(), HashMap::new())?; + let dfschema = DFSchema::from_unqualified_fields(fields.into(), HashMap::new())?; let schema = DFSchemaRef::new(dfschema); - Ok(Self::from(LogicalPlan::Values(Values { schema, values }))) + + Ok(Self::new(LogicalPlan::Values(Values { schema, values }))) } /// Convert a table provider into a builder with a TableScan @@ -268,16 +373,16 @@ impl LogicalPlanBuilder { pub fn copy_to( input: LogicalPlan, output_url: String, - format_options: FormatOptions, + file_type: Arc, options: HashMap, partition_by: Vec, ) -> Result { - Ok(Self::from(LogicalPlan::Copy(CopyTo { + Ok(Self::new(LogicalPlan::Copy(CopyTo { input: Arc::new(input), output_url, - format_options, - options, partition_by, + file_type, + options, }))) } @@ -286,22 +391,16 @@ impl LogicalPlanBuilder { input: LogicalPlan, table_name: impl Into, table_schema: &Schema, - overwrite: bool, + insert_op: InsertOp, ) -> Result { let table_schema = table_schema.clone().to_dfschema_ref()?; - let op = if overwrite { - WriteOp::InsertOverwrite - } else { - WriteOp::InsertInto - }; - - Ok(Self::from(LogicalPlan::Dml(DmlStatement { - table_name: table_name.into(), + Ok(Self::new(LogicalPlan::Dml(DmlStatement::new( + table_name.into(), table_schema, - op, - input: Arc::new(input), - }))) + WriteOp::Insert(insert_op), + Arc::new(input), + )))) } /// Convert a table provider into a builder with a TableScan @@ -313,7 +412,20 @@ impl LogicalPlanBuilder { ) -> Result { TableScan::try_new(table_name, table_source, projection, filters, None) .map(LogicalPlan::TableScan) - .map(Self::from) + .map(Self::new) + } + + /// Convert a table provider into a builder with a TableScan with filter and fetch + pub fn scan_with_filters_fetch( + table_name: impl Into, + table_source: Arc, + projection: Option>, + filters: Vec, + fetch: Option, + ) -> Result { + TableScan::try_new(table_name, table_source, projection, filters, fetch) + .map(LogicalPlan::TableScan) + .map(Self::new) } /// Wrap a plan in a window @@ -358,7 +470,7 @@ impl LogicalPlanBuilder { self, expr: impl IntoIterator>, ) -> Result { - project(self.plan, expr).map(Self::from) + project(Arc::unwrap_or_clone(self.plan), expr).map(Self::new) } /// Select the given column indices @@ -373,17 +485,25 @@ impl LogicalPlanBuilder { /// Apply a filter pub fn filter(self, expr: impl Into) -> Result { let expr = normalize_col(expr.into(), &self.plan)?; - Filter::try_new(expr, Arc::new(self.plan)) + Filter::try_new(expr, self.plan) + .map(LogicalPlan::Filter) + .map(Self::new) + } + + /// Apply a filter which is used for a having clause + pub fn having(self, expr: impl Into) -> Result { + let expr = normalize_col(expr.into(), &self.plan)?; + Filter::try_new_with_having(expr, self.plan) .map(LogicalPlan::Filter) .map(Self::from) } /// Make a builder for a prepare logical plan from the builder's plan pub fn prepare(self, name: String, data_types: Vec) -> Result { - Ok(Self::from(LogicalPlan::Prepare(Prepare { + Ok(Self::new(LogicalPlan::Prepare(Prepare { name, data_types, - input: Arc::new(self.plan), + input: self.plan, }))) } @@ -394,28 +514,41 @@ impl LogicalPlanBuilder { /// `fetch` - Maximum number of rows to fetch, after skipping `skip` rows, /// if specified. pub fn limit(self, skip: usize, fetch: Option) -> Result { - Ok(Self::from(LogicalPlan::Limit(Limit { - skip, - fetch, - input: Arc::new(self.plan), + let skip_expr = if skip == 0 { + None + } else { + Some(lit(skip as i64)) + }; + let fetch_expr = fetch.map(|f| lit(f as i64)); + self.limit_by_expr(skip_expr, fetch_expr) + } + + /// Limit the number of rows returned + /// + /// Similar to `limit` but uses expressions for `skip` and `fetch` + pub fn limit_by_expr(self, skip: Option, fetch: Option) -> Result { + Ok(Self::new(LogicalPlan::Limit(Limit { + skip: skip.map(Box::new), + fetch: fetch.map(Box::new), + input: self.plan, }))) } /// Apply an alias pub fn alias(self, alias: impl Into) -> Result { - subquery_alias(self.plan, alias).map(Self::from) + subquery_alias(Arc::unwrap_or_clone(self.plan), alias).map(Self::new) } /// Add missing sort columns to all downstream projection /// - /// Thus, if you have a LogialPlan that selects A and B and have + /// Thus, if you have a LogicalPlan that selects A and B and have /// not requested a sort by C, this code will add C recursively to /// all input projections. /// /// Adding a new column is not correct if there is a `Distinct` /// node, which produces only distinct values of its /// inputs. Adding a new column to its input will result in - /// potententially different results than with the original column. + /// potentially different results than with the original column. /// /// For example, if the input is like: /// @@ -436,7 +569,7 @@ impl LogicalPlanBuilder { /// See for more details fn add_missing_columns( curr_plan: LogicalPlan, - missing_cols: &[Column], + missing_cols: &IndexSet, is_distinct: bool, ) -> Result { match curr_plan { @@ -458,7 +591,7 @@ impl LogicalPlanBuilder { Self::ambiguous_distinct_check(&missing_exprs, missing_cols, &expr)?; } expr.extend(missing_exprs); - project((*input).clone(), expr) + project(Arc::unwrap_or_clone(input), expr) } _ => { let is_distinct = @@ -481,7 +614,7 @@ impl LogicalPlanBuilder { fn ambiguous_distinct_check( missing_exprs: &[Expr], - missing_cols: &[Column], + missing_cols: &IndexSet, projection_exprs: &[Expr], ) -> Result<()> { if missing_exprs.is_empty() { @@ -516,37 +649,55 @@ impl LogicalPlanBuilder { plan_err!("For SELECT DISTINCT, ORDER BY expressions {missing_col_names} must appear in select list") } - /// Apply a sort + /// Apply a sort by provided expressions with default direction + pub fn sort_by( + self, + expr: impl IntoIterator> + Clone, + ) -> Result { + self.sort( + expr.into_iter() + .map(|e| e.into().sort(true, false)) + .collect::>(), + ) + } + pub fn sort( self, - exprs: impl IntoIterator> + Clone, + sorts: impl IntoIterator> + Clone, ) -> Result { - let exprs = rewrite_sort_cols_by_aggs(exprs, &self.plan)?; + self.sort_with_limit(sorts, None) + } + + /// Apply a sort + pub fn sort_with_limit( + self, + sorts: impl IntoIterator> + Clone, + fetch: Option, + ) -> Result { + let sorts = rewrite_sort_cols_by_aggs(sorts, &self.plan)?; let schema = self.plan.schema(); // Collect sort columns that are missing in the input plan's schema - let mut missing_cols: Vec = vec![]; - exprs - .clone() - .into_iter() - .try_for_each::<_, Result<()>>(|expr| { - let columns = expr.to_columns()?; + let mut missing_cols: IndexSet = IndexSet::new(); + sorts.iter().try_for_each::<_, Result<()>>(|sort| { + let columns = sort.expr.column_refs(); - columns.into_iter().for_each(|c| { - if schema.field_from_column(&c).is_err() { - missing_cols.push(c); - } - }); + missing_cols.extend( + columns + .into_iter() + .filter(|c| !schema.has_column(c)) + .cloned(), + ); - Ok(()) - })?; + Ok(()) + })?; if missing_cols.is_empty() { - return Ok(Self::from(LogicalPlan::Sort(Sort { - expr: normalize_cols(exprs, &self.plan)?, - input: Arc::new(self.plan), - fetch: None, + return Ok(Self::new(LogicalPlan::Sort(Sort { + expr: normalize_sorts(sorts, &self.plan)?, + input: self.plan, + fetch, }))); } @@ -554,38 +705,40 @@ impl LogicalPlanBuilder { let new_expr = schema.columns().into_iter().map(Expr::Column).collect(); let is_distinct = false; - let plan = Self::add_missing_columns(self.plan, &missing_cols, is_distinct)?; + let plan = Self::add_missing_columns( + Arc::unwrap_or_clone(self.plan), + &missing_cols, + is_distinct, + )?; let sort_plan = LogicalPlan::Sort(Sort { - expr: normalize_cols(exprs, &plan)?, + expr: normalize_sorts(sorts, &plan)?, input: Arc::new(plan), - fetch: None, + fetch, }); Projection::try_new(new_expr, Arc::new(sort_plan)) .map(LogicalPlan::Projection) - .map(Self::from) + .map(Self::new) } /// Apply a union, preserving duplicate rows pub fn union(self, plan: LogicalPlan) -> Result { - union(self.plan, plan).map(Self::from) + union(Arc::unwrap_or_clone(self.plan), plan).map(Self::new) } /// Apply a union, removing duplicate rows pub fn union_distinct(self, plan: LogicalPlan) -> Result { - let left_plan: LogicalPlan = self.plan; + let left_plan: LogicalPlan = Arc::unwrap_or_clone(self.plan); let right_plan: LogicalPlan = plan; - Ok(Self::from(LogicalPlan::Distinct(Distinct::All(Arc::new( + Ok(Self::new(LogicalPlan::Distinct(Distinct::All(Arc::new( union(left_plan, right_plan)?, ))))) } /// Apply deduplication: Only distinct (different) values are returned) pub fn distinct(self) -> Result { - Ok(Self::from(LogicalPlan::Distinct(Distinct::All(Arc::new( - self.plan, - ))))) + Ok(Self::new(LogicalPlan::Distinct(Distinct::All(self.plan)))) } /// Project first values of the specified expression list according to the provided @@ -594,10 +747,10 @@ impl LogicalPlanBuilder { self, on_expr: Vec, select_expr: Vec, - sort_expr: Option>, + sort_expr: Option>, ) -> Result { - Ok(Self::from(LogicalPlan::Distinct(Distinct::On( - DistinctOn::try_new(on_expr, select_expr, sort_expr, Arc::new(self.plan))?, + Ok(Self::new(LogicalPlan::Distinct(Distinct::On( + DistinctOn::try_new(on_expr, select_expr, sort_expr, self.plan)?, )))) } @@ -683,7 +836,7 @@ impl LogicalPlanBuilder { pub(crate) fn normalize( plan: &LogicalPlan, - column: impl Into + Clone, + column: impl Into, ) -> Result { let schema = plan.schema(); let fallback_schemas = plan.fallback_normalize_schemas(); @@ -812,8 +965,8 @@ impl LogicalPlanBuilder { let join_schema = build_join_schema(self.plan.schema(), right.schema(), &join_type)?; - Ok(Self::from(LogicalPlan::Join(Join { - left: Arc::new(self.plan), + Ok(Self::new(LogicalPlan::Join(Join { + left: self.plan, right: Arc::new(right), on, filter, @@ -876,8 +1029,8 @@ impl LogicalPlanBuilder { DataFusionError::Internal("filters should not be None here".to_string()) })?) } else { - Ok(Self::from(LogicalPlan::Join(Join { - left: Arc::new(self.plan), + Ok(Self::new(LogicalPlan::Join(Join { + left: self.plan, right: Arc::new(right), on: join_on, filter: filters, @@ -893,17 +1046,22 @@ impl LogicalPlanBuilder { pub fn cross_join(self, right: LogicalPlan) -> Result { let join_schema = build_join_schema(self.plan.schema(), right.schema(), &JoinType::Inner)?; - Ok(Self::from(LogicalPlan::CrossJoin(CrossJoin { - left: Arc::new(self.plan), + Ok(Self::new(LogicalPlan::Join(Join { + left: self.plan, right: Arc::new(right), + on: vec![], + filter: None, + join_type: JoinType::Inner, + join_constraint: JoinConstraint::On, + null_equals_null: false, schema: DFSchemaRef::new(join_schema), }))) } /// Repartition pub fn repartition(self, partitioning_scheme: Partitioning) -> Result { - Ok(Self::from(LogicalPlan::Repartition(Repartition { - input: Arc::new(self.plan), + Ok(Self::new(LogicalPlan::Repartition(Repartition { + input: self.plan, partitioning_scheme, }))) } @@ -915,9 +1073,9 @@ impl LogicalPlanBuilder { ) -> Result { let window_expr = normalize_cols(window_expr, &self.plan)?; validate_unique_names("Windows", &window_expr)?; - Ok(Self::from(LogicalPlan::Window(Window::try_new( + Ok(Self::new(LogicalPlan::Window(Window::try_new( window_expr, - Arc::new(self.plan), + self.plan, )?))) } @@ -934,9 +1092,9 @@ impl LogicalPlanBuilder { let group_expr = add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?; - Aggregate::try_new(Arc::new(self.plan), group_expr, aggr_expr) + Aggregate::try_new(self.plan, group_expr, aggr_expr) .map(LogicalPlan::Aggregate) - .map(Self::from) + .map(Self::new) } /// Create an expression to represent the explanation of the plan @@ -950,18 +1108,18 @@ impl LogicalPlanBuilder { let schema = schema.to_dfschema_ref()?; if analyze { - Ok(Self::from(LogicalPlan::Analyze(Analyze { + Ok(Self::new(LogicalPlan::Analyze(Analyze { verbose, - input: Arc::new(self.plan), + input: self.plan, schema, }))) } else { let stringified_plans = vec![self.plan.to_stringified(PlanType::InitialLogicalPlan)]; - Ok(Self::from(LogicalPlan::Explain(Explain { + Ok(Self::new(LogicalPlan::Explain(Explain { verbose, - plan: Arc::new(self.plan), + plan: self.plan, stringified_plans, schema, logical_optimization_succeeded: false, @@ -1039,7 +1197,7 @@ impl LogicalPlanBuilder { /// Build the plan pub fn build(self) -> Result { - Ok(self.plan) + Ok(Arc::unwrap_or_clone(self.plan)) } /// Apply a join with the expression on constraint. @@ -1067,14 +1225,16 @@ impl LogicalPlanBuilder { let left_key = l.into(); let right_key = r.into(); - let left_using_columns = left_key.to_columns()?; + let mut left_using_columns = HashSet::new(); + expr_to_columns(&left_key, &mut left_using_columns)?; let normalized_left_key = normalize_col_with_schemas_and_ambiguity_check( left_key, &[&[self.plan.schema(), right.schema()]], &[left_using_columns], )?; - let right_using_columns = right_key.to_columns()?; + let mut right_using_columns = HashSet::new(); + expr_to_columns(&right_key, &mut right_using_columns)?; let normalized_right_key = normalize_col_with_schemas_and_ambiguity_check( right_key, &[&[self.plan.schema(), right.schema()]], @@ -1085,8 +1245,8 @@ impl LogicalPlanBuilder { find_valid_equijoin_key_pair( &normalized_left_key, &normalized_right_key, - self.plan.schema().clone(), - right.schema().clone(), + self.plan.schema(), + right.schema(), )?.ok_or_else(|| plan_datafusion_err!( "can't create join plan, join key should belong to one input, error key: ({normalized_left_key},{normalized_right_key})" @@ -1097,8 +1257,8 @@ impl LogicalPlanBuilder { let join_schema = build_join_schema(self.plan.schema(), right.schema(), &join_type)?; - Ok(Self::from(LogicalPlan::Join(Join { - left: Arc::new(self.plan), + Ok(Self::new(LogicalPlan::Join(Join { + left: self.plan, right: Arc::new(right), on: join_key_pairs, filter, @@ -1111,7 +1271,7 @@ impl LogicalPlanBuilder { /// Unnest the given column. pub fn unnest_column(self, column: impl Into) -> Result { - Ok(Self::from(unnest(self.plan, vec![column.into()])?)) + unnest(Arc::unwrap_or_clone(self.plan), vec![column.into()]).map(Self::new) } /// Unnest the given column given [`UnnestOptions`] @@ -1120,11 +1280,12 @@ impl LogicalPlanBuilder { column: impl Into, options: UnnestOptions, ) -> Result { - Ok(Self::from(unnest_with_options( - self.plan, + unnest_with_options( + Arc::unwrap_or_clone(self.plan), vec![column.into()], options, - )?)) + ) + .map(Self::new) } /// Unnest the given columns with the given [`UnnestOptions`] @@ -1133,11 +1294,23 @@ impl LogicalPlanBuilder { columns: Vec, options: UnnestOptions, ) -> Result { - Ok(Self::from(unnest_with_options( - self.plan, columns, options, - )?)) + unnest_with_options(Arc::unwrap_or_clone(self.plan), columns, options) + .map(Self::new) + } +} + +impl From for LogicalPlanBuilder { + fn from(plan: LogicalPlan) -> Self { + LogicalPlanBuilder::new(plan) + } +} + +impl From> for LogicalPlanBuilder { + fn from(plan: Arc) -> Self { + LogicalPlanBuilder::new_from_arc(plan) } } + pub fn change_redundant_column(fields: &Fields) -> Vec { let mut name_map = HashMap::new(); fields @@ -1154,6 +1327,25 @@ pub fn change_redundant_column(fields: &Fields) -> Vec { }) .collect() } + +fn mark_field(schema: &DFSchema) -> (Option, Arc) { + let mut table_references = schema + .iter() + .filter_map(|(qualifier, _)| qualifier) + .collect::>(); + table_references.dedup(); + let table_reference = if table_references.len() == 1 { + table_references.pop().cloned() + } else { + None + }; + + ( + table_reference, + Arc::new(Field::new("mark", DataType::Boolean, false)), + ) +} + /// Creates a schema for a join operation. /// The fields from the left side are first pub fn build_join_schema( @@ -1180,17 +1372,17 @@ pub fn build_join_schema( JoinType::Inner => { // left then right let left_fields = left_fields - .map(|(q, f)| (q.cloned(), f.clone())) + .map(|(q, f)| (q.cloned(), Arc::clone(f))) .collect::>(); let right_fields = right_fields - .map(|(q, f)| (q.cloned(), f.clone())) + .map(|(q, f)| (q.cloned(), Arc::clone(f))) .collect::>(); left_fields.into_iter().chain(right_fields).collect() } JoinType::Left => { // left then right, right set to nullable in case of not matched scenario let left_fields = left_fields - .map(|(q, f)| (q.cloned(), f.clone())) + .map(|(q, f)| (q.cloned(), Arc::clone(f))) .collect::>(); left_fields .into_iter() @@ -1200,7 +1392,7 @@ pub fn build_join_schema( JoinType::Right => { // left then right, left set to nullable in case of not matched scenario let right_fields = right_fields - .map(|(q, f)| (q.cloned(), f.clone())) + .map(|(q, f)| (q.cloned(), Arc::clone(f))) .collect::>(); nullify_fields(left_fields) .into_iter() @@ -1216,11 +1408,19 @@ pub fn build_join_schema( } JoinType::LeftSemi | JoinType::LeftAnti => { // Only use the left side for the schema - left_fields.map(|(q, f)| (q.cloned(), f.clone())).collect() + left_fields + .map(|(q, f)| (q.cloned(), Arc::clone(f))) + .collect() } + JoinType::LeftMark => left_fields + .map(|(q, f)| (q.cloned(), Arc::clone(f))) + .chain(once(mark_field(right))) + .collect(), JoinType::RightSemi | JoinType::RightAnti => { // Only use the right side for the schema - right_fields.map(|(q, f)| (q.cloned(), f.clone())).collect() + right_fields + .map(|(q, f)| (q.cloned(), Arc::clone(f))) + .collect() } }; let func_dependencies = left.functional_dependencies().join( @@ -1228,8 +1428,12 @@ pub fn build_join_schema( join_type, left.fields().len(), ); - let mut metadata = left.metadata().clone(); - metadata.extend(right.metadata().clone()); + let metadata = left + .metadata() + .clone() + .into_iter() + .chain(right.metadata().clone()) + .collect(); let dfschema = DFSchema::new_with_metadata(qualified_fields, metadata)?; dfschema.with_functional_dependencies(func_dependencies) } @@ -1243,7 +1447,7 @@ pub fn build_join_schema( /// /// This allows MySQL style selects like /// `SELECT col FROM t WHERE pk = 5` if col is unique -fn add_group_by_exprs_from_dependencies( +pub fn add_group_by_exprs_from_dependencies( mut group_expr: Vec, schema: &DFSchemaRef, ) -> Result> { @@ -1251,15 +1455,15 @@ fn add_group_by_exprs_from_dependencies( // c1 + 1` produces an output field named `"c1 + 1"` let mut group_by_field_names = group_expr .iter() - .map(|e| e.display_name()) - .collect::>>()?; + .map(|e| e.schema_name().to_string()) + .collect::>(); if let Some(target_indices) = get_target_functional_dependencies(schema, &group_by_field_names) { for idx in target_indices { let expr = Expr::Column(Column::from(schema.qualified_field(idx))); - let expr_name = expr.display_name()?; + let expr_name = expr.schema_name().to_string(); if !group_by_field_names.contains(&expr_name) { group_by_field_names.push(expr_name); group_expr.push(expr); @@ -1269,14 +1473,14 @@ fn add_group_by_exprs_from_dependencies( Ok(group_expr) } /// Errors if one or more expressions have equal names. -pub(crate) fn validate_unique_names<'a>( +pub fn validate_unique_names<'a>( node_name: &str, expressions: impl IntoIterator, ) -> Result<()> { let mut unique_names = HashMap::new(); expressions.into_iter().enumerate().try_for_each(|(position, expr)| { - let name = expr.display_name()?; + let name = expr.schema_name().to_string(); match unique_names.get(&name) { None => { unique_names.insert(name, (position, expr)); @@ -1292,95 +1496,38 @@ pub(crate) fn validate_unique_names<'a>( }) } -pub fn project_with_column_index( - expr: Vec, - input: Arc, - schema: DFSchemaRef, -) -> Result { - let alias_expr = expr - .into_iter() - .enumerate() - .map(|(i, e)| match e { - Expr::Alias(Alias { ref name, .. }) if name != schema.field(i).name() => { - e.unalias().alias(schema.field(i).name()) - } - Expr::Column(Column { - relation: _, - ref name, - }) if name != schema.field(i).name() => e.alias(schema.field(i).name()), - Expr::Alias { .. } | Expr::Column { .. } => e, - _ => e.alias(schema.field(i).name()), - }) - .collect::>(); - - Projection::try_new_with_schema(alias_expr, input, schema) - .map(LogicalPlan::Projection) -} - -/// Union two logical plans. +/// Union two [`LogicalPlan`]s. +/// +/// Constructs the UNION plan, but does not perform type-coercion. Therefore the +/// subtree expressions will not be properly typed until the optimizer pass. +/// +/// If a properly typed UNION plan is needed, refer to [`TypeCoercionRewriter::coerce_union`] +/// or alternatively, merge the union input schema using [`coerce_union_schema`] and +/// apply the expression rewrite with [`coerce_plan_expr_for_schema`]. +/// +/// [`TypeCoercionRewriter::coerce_union`]: https://docs.rs/datafusion-optimizer/latest/datafusion_optimizer/analyzer/type_coercion/struct.TypeCoercionRewriter.html#method.coerce_union +/// [`coerce_union_schema`]: https://docs.rs/datafusion-optimizer/latest/datafusion_optimizer/analyzer/type_coercion/fn.coerce_union_schema.html pub fn union(left_plan: LogicalPlan, right_plan: LogicalPlan) -> Result { - let left_col_num = left_plan.schema().fields().len(); - - // check union plan length same. - let right_col_num = right_plan.schema().fields().len(); - if right_col_num != left_col_num { + if left_plan.schema().fields().len() != right_plan.schema().fields().len() { return plan_err!( - "Union queries must have the same number of columns, (left is {left_col_num}, right is {right_col_num})"); - } - - // create union schema - let union_qualified_fields = - zip(left_plan.schema().iter(), right_plan.schema().iter()) - .map( - |((left_qualifier, left_field), (_right_qualifier, right_field))| { - let nullable = left_field.is_nullable() || right_field.is_nullable(); - let data_type = comparison_coercion( - left_field.data_type(), - right_field.data_type(), - ) - .ok_or_else(|| { - plan_datafusion_err!( - "UNION Column {} (type: {}) is not compatible with column {} (type: {})", - right_field.name(), - right_field.data_type(), - left_field.name(), - left_field.data_type() - ) - })?; - Ok(( - left_qualifier.cloned(), - Arc::new(Field::new(left_field.name(), data_type, nullable)), - )) - }, - ) - .collect::>>()?; - let union_schema = - DFSchema::new_with_metadata(union_qualified_fields, HashMap::new())?; + "UNION queries have different number of columns: \ + left has {} columns whereas right has {} columns", + left_plan.schema().fields().len(), + right_plan.schema().fields().len() + ); + } - let inputs = vec![left_plan, right_plan] - .into_iter() - .map(|p| { - let plan = coerce_plan_expr_for_schema(&p, &union_schema)?; - match plan { - LogicalPlan::Projection(Projection { expr, input, .. }) => { - Ok(Arc::new(project_with_column_index( - expr, - input, - Arc::new(union_schema.clone()), - )?)) - } - other_plan => Ok(Arc::new(other_plan)), - } - }) - .collect::>>()?; + // Temporarily use the schema from the left input and later rely on the analyzer to + // coerce the two schemas into a common one. - if inputs.is_empty() { - return plan_err!("Empty UNION"); - } + // Functional Dependencies doesn't preserve after UNION operation + let schema = (**left_plan.schema()).clone(); + let schema = + Arc::new(schema.with_functional_dependencies(FunctionalDependencies::empty())?); Ok(LogicalPlan::Union(Union { - inputs, - schema: Arc::new(union_schema), + inputs: vec![Arc::new(left_plan), Arc::new(right_plan)], + schema, })) } @@ -1393,27 +1540,14 @@ pub fn project( plan: LogicalPlan, expr: impl IntoIterator>, ) -> Result { - // TODO: move it into analyzer - let input_schema = plan.schema(); let mut projected_expr = vec![]; for e in expr { let e = e.into(); match e { - Expr::Wildcard { qualifier: None } => { - projected_expr.extend(expand_wildcard(input_schema, &plan, None)?) - } - Expr::Wildcard { - qualifier: Some(qualifier), - } => projected_expr.extend(expand_qualified_wildcard( - &qualifier, - input_schema, - None, - )?), - _ => projected_expr - .push(columnize_expr(normalize_col(e, &plan)?, input_schema)), + Expr::Wildcard { .. } => projected_expr.push(e), + _ => projected_expr.push(columnize_expr(normalize_col(e, &plan)?, &plan)?), } } - validate_unique_names("Projections", projected_expr.iter())?; Projection::try_new(projected_expr, Arc::new(plan)).map(LogicalPlan::Projection) @@ -1453,6 +1587,29 @@ pub fn table_scan_with_filters( LogicalPlanBuilder::scan_with_filters(name, table_source, projection, filters) } +/// Create a LogicalPlanBuilder representing a scan of a table with the provided name and schema, +/// filters, and inlined fetch. +/// This is mostly used for testing and documentation. +pub fn table_scan_with_filter_and_fetch( + name: Option>, + table_schema: &Schema, + projection: Option>, + filters: Vec, + fetch: Option, +) -> Result { + let table_source = table_source(table_schema); + let name = name + .map(|n| n.into()) + .unwrap_or_else(|| TableReference::bare(UNNAMED_TABLE)); + LogicalPlanBuilder::scan_with_filters_fetch( + name, + table_source, + projection, + filters, + fetch, + ) +} + fn table_source(table_schema: &Schema) -> Arc { let table_schema = Arc::new(table_schema.clone()); Arc::new(LogicalTableSource { table_schema }) @@ -1486,10 +1643,15 @@ pub fn wrap_projection_for_join_if_necessary( let need_project = join_keys.iter().any(|key| !matches!(key, Expr::Column(_))); let plan = if need_project { - let mut projection = expand_wildcard(input_schema, &input, None)?; + // Include all columns from the input and extend them with the join keys + let mut projection = input_schema + .columns() + .into_iter() + .map(Expr::Column) + .collect::>(); let join_key_items = alias_join_keys .iter() - .flat_map(|expr| expr.try_into_col().is_err().then_some(expr)) + .flat_map(|expr| expr.try_as_col().is_none().then_some(expr)) .cloned() .collect::>(); projection.extend(join_key_items); @@ -1504,8 +1666,12 @@ pub fn wrap_projection_for_join_if_necessary( let join_on = alias_join_keys .into_iter() .map(|key| { - key.try_into_col() - .or_else(|_| Ok(Column::from_name(key.display_name()?))) + if let Some(col) = key.try_as_col() { + Ok(col.clone()) + } else { + let name = key.schema_name().to_string(); + Ok(Column::from_name(name)) + } }) .collect::>>()?; @@ -1532,63 +1698,239 @@ impl TableSource for LogicalTableSource { } fn schema(&self) -> SchemaRef { - self.table_schema.clone() + Arc::clone(&self.table_schema) } fn supports_filters_pushdown( &self, filters: &[&Expr], - ) -> Result> { + ) -> Result> { Ok(vec![TableProviderFilterPushDown::Exact; filters.len()]) } } /// Create a [`LogicalPlan::Unnest`] plan pub fn unnest(input: LogicalPlan, columns: Vec) -> Result { - unnest_with_options(input, columns, UnnestOptions::new()) + unnest_with_options(input, columns, UnnestOptions::default()) +} + +// Get the data type of a multi-dimensional type after unnesting it +// with a given depth +fn get_unnested_list_datatype_recursive( + data_type: &DataType, + depth: usize, +) -> Result { + match data_type { + DataType::List(field) + | DataType::FixedSizeList(field, _) + | DataType::LargeList(field) => { + if depth == 1 { + return Ok(field.data_type().clone()); + } + return get_unnested_list_datatype_recursive(field.data_type(), depth - 1); + } + _ => {} + }; + + internal_err!("trying to unnest on invalid data type {:?}", data_type) +} + +pub fn get_struct_unnested_columns( + col_name: &String, + inner_fields: &Fields, +) -> Vec { + inner_fields + .iter() + .map(|f| Column::from_name(format!("{}.{}", col_name, f.name()))) + .collect() +} + +// Based on data type, either struct or a variant of list +// return a set of columns as the result of unnesting +// the input columns. +// For example, given a column with name "a", +// - List(Element) returns ["a"] with data type Element +// - Struct(field1, field2) returns ["a.field1","a.field2"] +// For list data type, an argument depth is used to specify +// the recursion level +pub fn get_unnested_columns( + col_name: &String, + data_type: &DataType, + depth: usize, +) -> Result)>> { + let mut qualified_columns = Vec::with_capacity(1); + + match data_type { + DataType::List(_) | DataType::FixedSizeList(_, _) | DataType::LargeList(_) => { + let data_type = get_unnested_list_datatype_recursive(data_type, depth)?; + let new_field = Arc::new(Field::new( + col_name, data_type, + // Unnesting may produce NULLs even if the list is not null. + // For example: unnset([1], []) -> 1, null + true, + )); + let column = Column::from_name(col_name); + // let column = Column::from((None, &new_field)); + qualified_columns.push((column, new_field)); + } + DataType::Struct(fields) => { + qualified_columns.extend(fields.iter().map(|f| { + let new_name = format!("{}.{}", col_name, f.name()); + let column = Column::from_name(&new_name); + let new_field = f.as_ref().clone().with_name(new_name); + // let column = Column::from((None, &f)); + (column, Arc::new(new_field)) + })) + } + _ => { + return internal_err!( + "trying to unnest on invalid data type {:?}", + data_type + ); + } + }; + Ok(qualified_columns) } /// Create a [`LogicalPlan::Unnest`] plan with options +/// This function receive a list of columns to be unnested +/// because multiple unnest can be performed on the same column (e.g unnest with different depth) +/// The new schema will contains post-unnest fields replacing the original field +/// +/// For example: +/// Input schema as +/// ```text +/// +---------------------+-------------------+ +/// | col1 | col2 | +/// +---------------------+-------------------+ +/// | Struct(INT64,INT32) | List(List(Int64)) | +/// +---------------------+-------------------+ +/// ``` +/// +/// +/// +/// Then unnesting columns with: +/// - (col1,Struct) +/// - (col2,List(\[depth=1,depth=2\])) +/// +/// will generate a new schema as +/// ```text +/// +---------+---------+---------------------+---------------------+ +/// | col1.c0 | col1.c1 | unnest_col2_depth_1 | unnest_col2_depth_2 | +/// +---------+---------+---------------------+---------------------+ +/// | Int64 | Int32 | List(Int64) | Int64 | +/// +---------+---------+---------------------+---------------------+ +/// ``` pub fn unnest_with_options( input: LogicalPlan, - columns: Vec, + columns_to_unnest: Vec, options: UnnestOptions, ) -> Result { - // Extract the type of the nested field in the list. - let mut unnested_fields: HashMap = HashMap::with_capacity(columns.len()); - // Add qualifiers to the columns. - let mut qualified_columns = Vec::with_capacity(columns.len()); - for c in &columns { - let index = input.schema().index_of_column(c)?; - let (unnest_qualifier, unnest_field) = input.schema().qualified_field(index); - let unnested_field = match unnest_field.data_type() { - DataType::List(field) - | DataType::FixedSizeList(field, _) - | DataType::LargeList(field) => Arc::new(Field::new( - unnest_field.name(), - field.data_type().clone(), - // Unnesting may produce NULLs even if the list is not null. - // For example: unnset([1], []) -> 1, null - true, - )), - _ => { - // If the unnest field is not a list type return the input plan. - return Ok(input); - } - }; - qualified_columns.push(Column::from((unnest_qualifier, &unnested_field))); - unnested_fields.insert(index, unnested_field); - } + let mut list_columns: Vec<(usize, ColumnUnnestList)> = vec![]; + let mut struct_columns = vec![]; + let indices_to_unnest = columns_to_unnest + .iter() + .map(|c| Ok((input.schema().index_of_column(c)?, c))) + .collect::>>()?; - // Update the schema with the unnest column types changed to contain the nested types. let input_schema = input.schema(); + + let mut dependency_indices = vec![]; + // Transform input schema into new schema + // Given this comprehensive example + // + // input schema: + // 1.col1_unnest_placeholder: list[list[int]], + // 2.col1: list[list[int]] + // 3.col2: list[int] + // with unnest on unnest(col1,depth=2), unnest(col1,depth=1) and unnest(col2,depth=1) + // output schema: + // 1.unnest_col1_depth_2: int + // 2.unnest_col1_depth_1: list[int] + // 3.col1: list[list[int]] + // 4.unnest_col2_depth_1: int + // Meaning the placeholder column will be replaced by its unnested variation(s), note + // the plural. let fields = input_schema .iter() .enumerate() - .map(|(index, (q, f))| match unnested_fields.get(&index) { - Some(unnested_field) => (q.cloned(), unnested_field.clone()), - None => (q.cloned(), f.clone()), + .map(|(index, (original_qualifier, original_field))| { + match indices_to_unnest.get(&index) { + Some(column_to_unnest) => { + let recursions_on_column = options + .recursions + .iter() + .filter(|p| -> bool { &p.input_column == *column_to_unnest }) + .collect::>(); + let mut transformed_columns = recursions_on_column + .iter() + .map(|r| { + list_columns.push(( + index, + ColumnUnnestList { + output_column: r.output_column.clone(), + depth: r.depth, + }, + )); + Ok(get_unnested_columns( + &r.output_column.name, + original_field.data_type(), + r.depth, + )? + .into_iter() + .next() + .unwrap()) // because unnesting a list column always result into one result + }) + .collect::)>>>()?; + if transformed_columns.is_empty() { + transformed_columns = get_unnested_columns( + &column_to_unnest.name, + original_field.data_type(), + 1, + )?; + match original_field.data_type() { + DataType::Struct(_) => { + struct_columns.push(index); + } + DataType::List(_) + | DataType::FixedSizeList(_, _) + | DataType::LargeList(_) => { + list_columns.push(( + index, + ColumnUnnestList { + output_column: Column::from_name( + &column_to_unnest.name, + ), + depth: 1, + }, + )); + } + _ => {} + }; + } + + // new columns dependent on the same original index + dependency_indices + .extend(std::iter::repeat(index).take(transformed_columns.len())); + Ok(transformed_columns + .iter() + .map(|(col, data_type)| { + (col.relation.to_owned(), data_type.to_owned()) + }) + .collect()) + } + None => { + dependency_indices.push(index); + Ok(vec![( + original_qualifier.cloned(), + Arc::clone(original_field), + )]) + } + } }) + .collect::>>()? + .into_iter() + .flatten() .collect::>(); let metadata = input_schema.metadata().clone(); @@ -1596,9 +1938,13 @@ pub fn unnest_with_options( // We can use the existing functional dependencies: let deps = input_schema.functional_dependencies().clone(); let schema = Arc::new(df_schema.with_functional_dependencies(deps)?); + Ok(LogicalPlan::Unnest(Unnest { input: Arc::new(input), - columns: qualified_columns, + exec_columns: columns_to_unnest, + list_type_columns: list_columns, + struct_type_columns: struct_columns, + dependency_indices, schema, options, })) @@ -1606,11 +1952,12 @@ pub fn unnest_with_options( #[cfg(test)] mod tests { + use super::*; use crate::logical_plan::StringifiedPlan; - use crate::{col, expr, expr_fn::exists, in_subquery, lit, scalar_subquery, sum}; + use crate::{col, expr, expr_fn::exists, in_subquery, lit, scalar_subquery}; - use datafusion_common::SchemaError; + use datafusion_common::{RecursionUnnestOption, SchemaError}; #[test] fn plan_builder_simple() -> Result<()> { @@ -1624,7 +1971,7 @@ mod tests { \n Filter: employee_csv.state = Utf8(\"CO\")\ \n TableScan: employee_csv projection=[id, state]"; - assert_eq!(expected, format!("{plan:?}")); + assert_eq!(expected, format!("{plan}")); Ok(()) } @@ -1643,7 +1990,7 @@ mod tests { .unwrap(); assert_eq!(&expected, plan.schema().as_ref()); - // Note scan of "EMPLOYEE_CSV" is treated as a SQL identifer + // Note scan of "EMPLOYEE_CSV" is treated as a SQL identifier // (and thus normalized to "employee"csv") as well let projection = None; let plan = @@ -1664,62 +2011,20 @@ mod tests { ); } - #[test] - fn plan_builder_aggregate() -> Result<()> { - let plan = - table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3, 4]))? - .aggregate( - vec![col("state")], - vec![sum(col("salary")).alias("total_salary")], - )? - .project(vec![col("state"), col("total_salary")])? - .limit(2, Some(10))? - .build()?; - - let expected = "Limit: skip=2, fetch=10\ - \n Projection: employee_csv.state, total_salary\ - \n Aggregate: groupBy=[[employee_csv.state]], aggr=[[SUM(employee_csv.salary) AS total_salary]]\ - \n TableScan: employee_csv projection=[state, salary]"; - - assert_eq!(expected, format!("{plan:?}")); - - Ok(()) - } - #[test] fn plan_builder_sort() -> Result<()> { let plan = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3, 4]))? .sort(vec![ - Expr::Sort(expr::Sort::new(Box::new(col("state")), true, true)), - Expr::Sort(expr::Sort::new(Box::new(col("salary")), false, false)), + expr::Sort::new(col("state"), true, true), + expr::Sort::new(col("salary"), false, false), ])? .build()?; let expected = "Sort: employee_csv.state ASC NULLS FIRST, employee_csv.salary DESC NULLS LAST\ \n TableScan: employee_csv projection=[state, salary]"; - assert_eq!(expected, format!("{plan:?}")); - - Ok(()) - } - - #[test] - fn plan_using_join_wildcard_projection() -> Result<()> { - let t2 = table_scan(Some("t2"), &employee_schema(), None)?.build()?; - - let plan = table_scan(Some("t1"), &employee_schema(), None)? - .join_using(t2, JoinType::Inner, vec!["id"])? - .project(vec![Expr::Wildcard { qualifier: None }])? - .build()?; - - // id column should only show up once in projection - let expected = "Projection: t1.id, t1.first_name, t1.last_name, t1.state, t1.salary, t2.first_name, t2.last_name, t2.state, t2.salary\ - \n Inner Join: Using t1.id = t2.id\ - \n TableScan: t1\ - \n TableScan: t2"; - - assert_eq!(expected, format!("{plan:?}")); + assert_eq!(expected, format!("{plan}")); Ok(()) } @@ -1744,7 +2049,7 @@ mod tests { \n TableScan: employee_csv projection=[state, salary]\ \n TableScan: employee_csv projection=[state, salary]"; - assert_eq!(expected, format!("{plan:?}")); + assert_eq!(expected, format!("{plan}")); Ok(()) } @@ -1773,24 +2078,7 @@ mod tests { \n TableScan: employee_csv projection=[state, salary]\ \n TableScan: employee_csv projection=[state, salary]"; - assert_eq!(expected, format!("{plan:?}")); - - Ok(()) - } - - #[test] - fn plan_builder_union_different_num_columns_error() -> Result<()> { - let plan1 = - table_scan(TableReference::none(), &employee_schema(), Some(vec![3]))?; - let plan2 = - table_scan(TableReference::none(), &employee_schema(), Some(vec![3, 4]))?; - - let expected = "Error during planning: Union queries must have the same number of columns, (left is 1, right is 2)"; - let err_msg1 = plan1.clone().union(plan2.clone().build()?).unwrap_err(); - let err_msg2 = plan1.union_distinct(plan2.build()?).unwrap_err(); - - assert_eq!(err_msg1.strip_backtrace(), expected); - assert_eq!(err_msg2.strip_backtrace(), expected); + assert_eq!(expected, format!("{plan}")); Ok(()) } @@ -1810,7 +2098,7 @@ mod tests { \n Filter: employee_csv.state = Utf8(\"CO\")\ \n TableScan: employee_csv projection=[id, state]"; - assert_eq!(expected, format!("{plan:?}")); + assert_eq!(expected, format!("{plan}")); Ok(()) } @@ -1837,7 +2125,7 @@ mod tests { \n TableScan: foo\ \n Projection: bar.a\ \n TableScan: bar"; - assert_eq!(expected, format!("{outer_query:?}")); + assert_eq!(expected, format!("{outer_query}")); Ok(()) } @@ -1865,7 +2153,7 @@ mod tests { \n TableScan: foo\ \n Projection: bar.a\ \n TableScan: bar"; - assert_eq!(expected, format!("{outer_query:?}")); + assert_eq!(expected, format!("{outer_query}")); Ok(()) } @@ -1891,7 +2179,7 @@ mod tests { \n Projection: foo.b\ \n TableScan: foo\ \n TableScan: bar"; - assert_eq!(expected, format!("{outer_query:?}")); + assert_eq!(expected, format!("{outer_query}")); Ok(()) } @@ -1926,36 +2214,6 @@ mod tests { } } - #[test] - fn aggregate_non_unique_names() -> Result<()> { - let plan = table_scan( - Some("employee_csv"), - &employee_schema(), - // project state and salary by column index - Some(vec![3, 4]), - )? - // two columns with the same name => error - .aggregate(vec![col("state")], vec![sum(col("salary")).alias("state")]); - - match plan { - Err(DataFusionError::SchemaError( - SchemaError::AmbiguousReference { - field: - Column { - relation: Some(TableReference::Bare { table }), - name, - }, - }, - _, - )) => { - assert_eq!(*"employee_csv", *table); - assert_eq!("state", &name); - Ok(()) - } - _ => plan_err!("Plan should have returned an DataFusionError::SchemaError"), - } - } - fn employee_schema() -> Schema { Schema::new(vec![ Field::new("id", DataType::Int32, false), @@ -2027,13 +2285,13 @@ mod tests { #[test] fn plan_builder_unnest() -> Result<()> { - // Unnesting a simple column should return the child plan. - let plan = nested_table_scan("test_table")? - .unnest_column("scalar")? - .build()?; - - let expected = "TableScan: test_table"; - assert_eq!(expected, format!("{plan:?}")); + // Cannot unnest on a scalar column + let err = nested_table_scan("test_table")? + .unnest_column("scalar") + .unwrap_err(); + assert!(err + .to_string() + .starts_with("Internal error: trying to unnest on invalid data type UInt32")); // Unnesting the strings list. let plan = nested_table_scan("test_table")? @@ -2041,46 +2299,123 @@ mod tests { .build()?; let expected = "\ - Unnest: test_table.strings\ + Unnest: lists[test_table.strings|depth=1] structs[]\ \n TableScan: test_table"; - assert_eq!(expected, format!("{plan:?}")); + assert_eq!(expected, format!("{plan}")); // Check unnested field is a scalar - let field = plan - .schema() - .field_with_name(Some(&TableReference::bare("test_table")), "strings") - .unwrap(); + let field = plan.schema().field_with_name(None, "strings").unwrap(); assert_eq!(&DataType::Utf8, field.data_type()); - // Unnesting multiple fields. + // Unnesting the singular struct column result into 2 new columns for each subfield + let plan = nested_table_scan("test_table")? + .unnest_column("struct_singular")? + .build()?; + + let expected = "\ + Unnest: lists[] structs[test_table.struct_singular]\ + \n TableScan: test_table"; + assert_eq!(expected, format!("{plan}")); + + for field_name in &["a", "b"] { + // Check unnested struct field is a scalar + let field = plan + .schema() + .field_with_name(None, &format!("struct_singular.{}", field_name)) + .unwrap(); + assert_eq!(&DataType::UInt32, field.data_type()); + } + + // Unnesting multiple fields in separate plans let plan = nested_table_scan("test_table")? .unnest_column("strings")? .unnest_column("structs")? + .unnest_column("struct_singular")? .build()?; let expected = "\ - Unnest: test_table.structs\ - \n Unnest: test_table.strings\ - \n TableScan: test_table"; - assert_eq!(expected, format!("{plan:?}")); + Unnest: lists[] structs[test_table.struct_singular]\ + \n Unnest: lists[test_table.structs|depth=1] structs[]\ + \n Unnest: lists[test_table.strings|depth=1] structs[]\ + \n TableScan: test_table"; + assert_eq!(expected, format!("{plan}")); // Check unnested struct list field should be a struct. - let field = plan - .schema() - .field_with_name(Some(&TableReference::bare("test_table")), "structs") - .unwrap(); + let field = plan.schema().field_with_name(None, "structs").unwrap(); assert!(matches!(field.data_type(), DataType::Struct(_))); + // Unnesting multiple fields at the same time, using infer syntax + let cols = vec!["strings", "structs", "struct_singular"] + .into_iter() + .map(|c| c.into()) + .collect(); + + let plan = nested_table_scan("test_table")? + .unnest_columns_with_options(cols, UnnestOptions::default())? + .build()?; + + let expected = "\ + Unnest: lists[test_table.strings|depth=1, test_table.structs|depth=1] structs[test_table.struct_singular]\ + \n TableScan: test_table"; + assert_eq!(expected, format!("{plan}")); + // Unnesting missing column should fail. let plan = nested_table_scan("test_table")?.unnest_column("missing"); assert!(plan.is_err()); + // Simultaneously unnesting a list (with different depth) and a struct column + let plan = nested_table_scan("test_table")? + .unnest_columns_with_options( + vec!["stringss".into(), "struct_singular".into()], + UnnestOptions::default() + .with_recursions(RecursionUnnestOption { + input_column: "stringss".into(), + output_column: "stringss_depth_1".into(), + depth: 1, + }) + .with_recursions(RecursionUnnestOption { + input_column: "stringss".into(), + output_column: "stringss_depth_2".into(), + depth: 2, + }), + )? + .build()?; + + let expected = "\ + Unnest: lists[test_table.stringss|depth=1, test_table.stringss|depth=2] structs[test_table.struct_singular]\ + \n TableScan: test_table"; + assert_eq!(expected, format!("{plan}")); + + // Check output columns has correct type + let field = plan + .schema() + .field_with_name(None, "stringss_depth_1") + .unwrap(); + assert_eq!( + &DataType::new_list(DataType::Utf8, false), + field.data_type() + ); + let field = plan + .schema() + .field_with_name(None, "stringss_depth_2") + .unwrap(); + assert_eq!(&DataType::Utf8, field.data_type()); + // unnesting struct is still correct + for field_name in &["a", "b"] { + let field = plan + .schema() + .field_with_name(None, &format!("struct_singular.{}", field_name)) + .unwrap(); + assert_eq!(&DataType::UInt32, field.data_type()); + } + Ok(()) } fn nested_table_scan(table_name: &str) -> Result { - // Create a schema with a scalar field, a list of strings, and a list of structs. - let struct_field = Field::new_struct( + // Create a schema with a scalar field, a list of strings, a list of structs + // and a singular struct + let struct_field_in_list = Field::new_struct( "item", vec![ Field::new("a", DataType::UInt32, false), @@ -2089,10 +2424,20 @@ mod tests { false, ); let string_field = Field::new("item", DataType::Utf8, false); + let strings_field = Field::new_list("item", string_field.clone(), false); let schema = Schema::new(vec![ Field::new("scalar", DataType::UInt32, false), Field::new_list("strings", string_field, false), - Field::new_list("structs", struct_field, false), + Field::new_list("structs", struct_field_in_list, false), + Field::new( + "struct_singular", + DataType::Struct(Fields::from(vec![ + Field::new("a", DataType::UInt32, false), + Field::new("b", DataType::UInt32, false), + ])), + false, + ), + Field::new_list("stringss", strings_field, false), ]); table_scan(Some(table_name), &schema, None) @@ -2117,6 +2462,7 @@ mod tests { Ok(()) } + #[test] fn test_change_redundant_column() -> Result<()> { let t1_field_1 = Field::new("a", DataType::Int32, false); @@ -2140,4 +2486,21 @@ mod tests { ); Ok(()) } + + #[test] + fn plan_builder_from_logical_plan() -> Result<()> { + let plan = + table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3, 4]))? + .sort(vec![ + expr::Sort::new(col("state"), true, true), + expr::Sort::new(col("salary"), false, false), + ])? + .build()?; + + let plan_expected = format!("{plan}"); + let plan_builder: LogicalPlanBuilder = Arc::new(plan).into(); + assert_eq!(plan_expected, format!("{}", plan_builder.plan)); + + Ok(()) + } } diff --git a/datafusion/expr/src/logical_plan/ddl.rs b/datafusion/expr/src/logical_plan/ddl.rs index 8d72c9a8b036..93e8b5fd045e 100644 --- a/datafusion/expr/src/logical_plan/ddl.rs +++ b/datafusion/expr/src/logical_plan/ddl.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +use crate::{Expr, LogicalPlan, SortExpr, Volatility}; +use std::cmp::Ordering; use std::collections::HashMap; use std::sync::Arc; use std::{ @@ -22,15 +24,13 @@ use std::{ hash::{Hash, Hasher}, }; -use crate::{Expr, LogicalPlan, Volatility}; - +use crate::expr::Sort; use arrow::datatypes::DataType; -use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{Constraints, DFSchemaRef, SchemaReference, TableReference}; use sqlparser::ast::Ident; /// Various types of DDL (CREATE / DROP) catalog manipulation -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum DdlStatement { /// Creates an external table. CreateExternalTable(CreateExternalTable), @@ -42,6 +42,8 @@ pub enum DdlStatement { CreateCatalogSchema(CreateCatalogSchema), /// Creates a new catalog (aka "Database"). CreateCatalog(CreateCatalog), + /// Creates a new index. + CreateIndex(CreateIndex), /// Drops a table. DropTable(DropTable), /// Drops a view. @@ -67,6 +69,7 @@ impl DdlStatement { schema } DdlStatement::CreateCatalog(CreateCatalog { schema, .. }) => schema, + DdlStatement::CreateIndex(CreateIndex { schema, .. }) => schema, DdlStatement::DropTable(DropTable { schema, .. }) => schema, DdlStatement::DropView(DropView { schema, .. }) => schema, DdlStatement::DropCatalogSchema(DropCatalogSchema { schema, .. }) => schema, @@ -84,6 +87,7 @@ impl DdlStatement { DdlStatement::CreateView(_) => "CreateView", DdlStatement::CreateCatalogSchema(_) => "CreateCatalogSchema", DdlStatement::CreateCatalog(_) => "CreateCatalog", + DdlStatement::CreateIndex(_) => "CreateIndex", DdlStatement::DropTable(_) => "DropTable", DdlStatement::DropView(_) => "DropView", DdlStatement::DropCatalogSchema(_) => "DropCatalogSchema", @@ -102,6 +106,7 @@ impl DdlStatement { vec![input] } DdlStatement::CreateView(CreateView { input, .. }) => vec![input], + DdlStatement::CreateIndex(_) => vec![], DdlStatement::DropTable(_) => vec![], DdlStatement::DropView(_) => vec![], DdlStatement::DropCatalogSchema(_) => vec![], @@ -115,7 +120,7 @@ impl DdlStatement { /// children. /// /// See [crate::LogicalPlan::display] for an example - pub fn display(&self) -> impl fmt::Display + '_ { + pub fn display(&self) -> impl Display + '_ { struct Wrapper<'a>(&'a DdlStatement); impl<'a> Display for Wrapper<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -148,6 +153,9 @@ impl DdlStatement { }) => { write!(f, "CreateCatalog: {catalog_name:?}") } + DdlStatement::CreateIndex(CreateIndex { name, .. }) => { + write!(f, "CreateIndex: {name:?}") + } DdlStatement::DropTable(DropTable { name, if_exists, .. }) => { @@ -180,7 +188,7 @@ impl DdlStatement { } /// Creates an external table. -#[derive(Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct CreateExternalTable { /// The table schema pub schema: DFSchemaRef, @@ -190,20 +198,16 @@ pub struct CreateExternalTable { pub location: String, /// The file type of physical file pub file_type: String, - /// Whether the CSV file contains a header - pub has_header: bool, - /// Delimiter for CSV - pub delimiter: char, /// Partition Columns pub table_partition_cols: Vec, /// Option to not error if table already exists pub if_not_exists: bool, + /// Whether the table is a temporary table + pub temporary: bool, /// SQL used to create the table, if available pub definition: Option, /// Order expressions supplied by user - pub order_exprs: Vec>, - /// File compression type (GZIP, BZIP2, XZ, ZSTD) - pub file_compression_type: CompressionTypeVariant, + pub order_exprs: Vec>, /// Whether the table is an infinite streams pub unbounded: bool, /// Table(provider) specific options @@ -221,20 +225,68 @@ impl Hash for CreateExternalTable { self.name.hash(state); self.location.hash(state); self.file_type.hash(state); - self.has_header.hash(state); - self.delimiter.hash(state); self.table_partition_cols.hash(state); self.if_not_exists.hash(state); self.definition.hash(state); - self.file_compression_type.hash(state); self.order_exprs.hash(state); self.unbounded.hash(state); self.options.len().hash(state); // HashMap is not hashable } } +// Manual implementation needed because of `schema`, `options`, and `column_defaults` fields. +// Comparison excludes these fields. +impl PartialOrd for CreateExternalTable { + fn partial_cmp(&self, other: &Self) -> Option { + #[derive(PartialEq, PartialOrd)] + struct ComparableCreateExternalTable<'a> { + /// The table name + pub name: &'a TableReference, + /// The physical location + pub location: &'a String, + /// The file type of physical file + pub file_type: &'a String, + /// Partition Columns + pub table_partition_cols: &'a Vec, + /// Option to not error if table already exists + pub if_not_exists: &'a bool, + /// SQL used to create the table, if available + pub definition: &'a Option, + /// Order expressions supplied by user + pub order_exprs: &'a Vec>, + /// Whether the table is an infinite streams + pub unbounded: &'a bool, + /// The list of constraints in the schema, such as primary key, unique, etc. + pub constraints: &'a Constraints, + } + let comparable_self = ComparableCreateExternalTable { + name: &self.name, + location: &self.location, + file_type: &self.file_type, + table_partition_cols: &self.table_partition_cols, + if_not_exists: &self.if_not_exists, + definition: &self.definition, + order_exprs: &self.order_exprs, + unbounded: &self.unbounded, + constraints: &self.constraints, + }; + let comparable_other = ComparableCreateExternalTable { + name: &other.name, + location: &other.location, + file_type: &other.file_type, + table_partition_cols: &other.table_partition_cols, + if_not_exists: &other.if_not_exists, + definition: &other.definition, + order_exprs: &other.order_exprs, + unbounded: &other.unbounded, + constraints: &other.constraints, + }; + comparable_self.partial_cmp(&comparable_other) + } +} + /// Creates an in memory table. -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub struct CreateMemoryTable { /// The table name pub name: TableReference, @@ -248,10 +300,12 @@ pub struct CreateMemoryTable { pub or_replace: bool, /// Default values for columns pub column_defaults: Vec<(String, Expr)>, + /// Wheter the table is `TableType::Temporary` + pub temporary: bool, } /// Creates a view. -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Hash)] pub struct CreateView { /// The table name pub name: TableReference, @@ -261,10 +315,12 @@ pub struct CreateView { pub or_replace: bool, /// SQL used to create the view, if available pub definition: Option, + /// Wheter the view is ephemeral + pub temporary: bool, } /// Creates a catalog (aka "Database"). -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct CreateCatalog { /// The catalog name pub catalog_name: String, @@ -274,8 +330,18 @@ pub struct CreateCatalog { pub schema: DFSchemaRef, } +// Manual implementation needed because of `schema` field. Comparison excludes this field. +impl PartialOrd for CreateCatalog { + fn partial_cmp(&self, other: &Self) -> Option { + match self.catalog_name.partial_cmp(&other.catalog_name) { + Some(Ordering::Equal) => self.if_not_exists.partial_cmp(&other.if_not_exists), + cmp => cmp, + } + } +} + /// Creates a schema. -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct CreateCatalogSchema { /// The table schema pub schema_name: String, @@ -285,8 +351,18 @@ pub struct CreateCatalogSchema { pub schema: DFSchemaRef, } +// Manual implementation needed because of `schema` field. Comparison excludes this field. +impl PartialOrd for CreateCatalogSchema { + fn partial_cmp(&self, other: &Self) -> Option { + match self.schema_name.partial_cmp(&other.schema_name) { + Some(Ordering::Equal) => self.if_not_exists.partial_cmp(&other.if_not_exists), + cmp => cmp, + } + } +} + /// Drops a table. -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct DropTable { /// The table name pub name: TableReference, @@ -296,8 +372,18 @@ pub struct DropTable { pub schema: DFSchemaRef, } +// Manual implementation needed because of `schema` field. Comparison excludes this field. +impl PartialOrd for DropTable { + fn partial_cmp(&self, other: &Self) -> Option { + match self.name.partial_cmp(&other.name) { + Some(Ordering::Equal) => self.if_exists.partial_cmp(&other.if_exists), + cmp => cmp, + } + } +} + /// Drops a view. -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct DropView { /// The view name pub name: TableReference, @@ -307,8 +393,18 @@ pub struct DropView { pub schema: DFSchemaRef, } +// Manual implementation needed because of `schema` field. Comparison excludes this field. +impl PartialOrd for DropView { + fn partial_cmp(&self, other: &Self) -> Option { + match self.name.partial_cmp(&other.name) { + Some(Ordering::Equal) => self.if_exists.partial_cmp(&other.if_exists), + cmp => cmp, + } + } +} + /// Drops a schema -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct DropCatalogSchema { /// The schema name pub name: SchemaReference, @@ -320,6 +416,19 @@ pub struct DropCatalogSchema { pub schema: DFSchemaRef, } +// Manual implementation needed because of `schema` field. Comparison excludes this field. +impl PartialOrd for DropCatalogSchema { + fn partial_cmp(&self, other: &Self) -> Option { + match self.name.partial_cmp(&other.name) { + Some(Ordering::Equal) => match self.if_exists.partial_cmp(&other.if_exists) { + Some(Ordering::Equal) => self.cascade.partial_cmp(&other.cascade), + cmp => cmp, + }, + cmp => cmp, + } + } +} + /// Arguments passed to `CREATE FUNCTION` /// /// Note this meant to be the same as from sqlparser's [`sqlparser::ast::Statement::CreateFunction`] @@ -337,7 +446,40 @@ pub struct CreateFunction { /// Dummy schema pub schema: DFSchemaRef, } -#[derive(Clone, PartialEq, Eq, Hash, Debug)] + +// Manual implementation needed because of `schema` field. Comparison excludes this field. +impl PartialOrd for CreateFunction { + fn partial_cmp(&self, other: &Self) -> Option { + #[derive(PartialEq, PartialOrd)] + struct ComparableCreateFunction<'a> { + pub or_replace: &'a bool, + pub temporary: &'a bool, + pub name: &'a String, + pub args: &'a Option>, + pub return_type: &'a Option, + pub params: &'a CreateFunctionBody, + } + let comparable_self = ComparableCreateFunction { + or_replace: &self.or_replace, + temporary: &self.temporary, + name: &self.name, + args: &self.args, + return_type: &self.return_type, + params: &self.params, + }; + let comparable_other = ComparableCreateFunction { + or_replace: &other.or_replace, + temporary: &other.temporary, + name: &other.name, + args: &other.args, + return_type: &other.return_type, + params: &other.params, + }; + comparable_self.partial_cmp(&comparable_other) + } +} + +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct OperateFunctionArg { // TODO: figure out how to support mode // pub mode: Option, @@ -345,40 +487,102 @@ pub struct OperateFunctionArg { pub data_type: DataType, pub default_expr: Option, } -#[derive(Clone, PartialEq, Eq, Hash, Debug)] +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub struct CreateFunctionBody { /// LANGUAGE lang_name pub language: Option, /// IMMUTABLE | STABLE | VOLATILE pub behavior: Option, - /// AS 'definition' - pub as_: Option, - /// RETURN expression - pub return_: Option, + /// RETURN or AS function body + pub function_body: Option, } #[derive(Clone, PartialEq, Eq, Hash, Debug)] -pub enum DefinitionStatement { - SingleQuotedDef(String), - DoubleDollarDef(String), +pub struct DropFunction { + pub name: String, + pub if_exists: bool, + pub schema: DFSchemaRef, } -impl From for DefinitionStatement { - fn from(value: sqlparser::ast::FunctionDefinition) -> Self { - match value { - sqlparser::ast::FunctionDefinition::SingleQuotedDef(s) => { - Self::SingleQuotedDef(s) - } - sqlparser::ast::FunctionDefinition::DoubleDollarDef(s) => { - Self::DoubleDollarDef(s) - } +impl PartialOrd for DropFunction { + fn partial_cmp(&self, other: &Self) -> Option { + match self.name.partial_cmp(&other.name) { + Some(Ordering::Equal) => self.if_exists.partial_cmp(&other.if_exists), + cmp => cmp, } } } #[derive(Clone, PartialEq, Eq, Hash, Debug)] -pub struct DropFunction { - pub name: String, - pub if_exists: bool, +pub struct CreateIndex { + pub name: Option, + pub table: TableReference, + pub using: Option, + pub columns: Vec, + pub unique: bool, + pub if_not_exists: bool, pub schema: DFSchemaRef, } + +// Manual implementation needed because of `schema` field. Comparison excludes this field. +impl PartialOrd for CreateIndex { + fn partial_cmp(&self, other: &Self) -> Option { + #[derive(PartialEq, PartialOrd)] + struct ComparableCreateIndex<'a> { + pub name: &'a Option, + pub table: &'a TableReference, + pub using: &'a Option, + pub columns: &'a Vec, + pub unique: &'a bool, + pub if_not_exists: &'a bool, + } + let comparable_self = ComparableCreateIndex { + name: &self.name, + table: &self.table, + using: &self.using, + columns: &self.columns, + unique: &self.unique, + if_not_exists: &self.if_not_exists, + }; + let comparable_other = ComparableCreateIndex { + name: &other.name, + table: &other.table, + using: &other.using, + columns: &other.columns, + unique: &other.unique, + if_not_exists: &other.if_not_exists, + }; + comparable_self.partial_cmp(&comparable_other) + } +} + +#[cfg(test)] +mod test { + use crate::{CreateCatalog, DdlStatement, DropView}; + use datafusion_common::{DFSchema, DFSchemaRef, TableReference}; + use std::cmp::Ordering; + + #[test] + fn test_partial_ord() { + let catalog = DdlStatement::CreateCatalog(CreateCatalog { + catalog_name: "name".to_string(), + if_not_exists: false, + schema: DFSchemaRef::new(DFSchema::empty()), + }); + let catalog_2 = DdlStatement::CreateCatalog(CreateCatalog { + catalog_name: "name".to_string(), + if_not_exists: true, + schema: DFSchemaRef::new(DFSchema::empty()), + }); + + assert_eq!(catalog.partial_cmp(&catalog_2), Some(Ordering::Less)); + + let drop_view = DdlStatement::DropView(DropView { + name: TableReference::from("table"), + if_exists: false, + schema: DFSchemaRef::new(DFSchema::empty()), + }); + + assert_eq!(drop_view.partial_cmp(&catalog), Some(Ordering::Greater)); + } +} diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index 3a2ed9ffc2d8..9aea7747c414 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -20,17 +20,17 @@ use std::collections::HashMap; use std::fmt; use crate::{ - expr_vec_fmt, Aggregate, DescribeTable, Distinct, DistinctOn, DmlStatement, Expr, - Filter, Join, Limit, LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery, - Repartition, Sort, Subquery, SubqueryAlias, TableProviderFilterPushDown, TableScan, - Unnest, Values, Window, + expr_vec_fmt, Aggregate, DescribeTable, Distinct, DistinctOn, DmlStatement, Execute, + Expr, Filter, Join, Limit, LogicalPlan, Partitioning, Prepare, Projection, + RecursiveQuery, Repartition, Sort, Subquery, SubqueryAlias, + TableProviderFilterPushDown, TableScan, Unnest, Values, Window, }; use crate::dml::CopyTo; use arrow::datatypes::Schema; use datafusion_common::display::GraphvizBuilder; use datafusion_common::tree_node::{TreeNodeRecursion, TreeNodeVisitor}; -use datafusion_common::DataFusionError; +use datafusion_common::{Column, DataFusionError}; use serde_json::json; /// Formats plans with a single line per node. For example: @@ -58,12 +58,12 @@ impl<'a, 'b> IndentVisitor<'a, 'b> { } } -impl<'a, 'b> TreeNodeVisitor for IndentVisitor<'a, 'b> { +impl<'n, 'a, 'b> TreeNodeVisitor<'n> for IndentVisitor<'a, 'b> { type Node = LogicalPlan; fn f_down( &mut self, - plan: &LogicalPlan, + plan: &'n LogicalPlan, ) -> datafusion_common::Result { if self.indent > 0 { writeln!(self.f)?; @@ -84,7 +84,7 @@ impl<'a, 'b> TreeNodeVisitor for IndentVisitor<'a, 'b> { fn f_up( &mut self, - _plan: &LogicalPlan, + _plan: &'n LogicalPlan, ) -> datafusion_common::Result { self.indent -= 1; Ok(TreeNodeRecursion::Continue) @@ -180,12 +180,12 @@ impl<'a, 'b> GraphvizVisitor<'a, 'b> { } } -impl<'a, 'b> TreeNodeVisitor for GraphvizVisitor<'a, 'b> { +impl<'n, 'a, 'b> TreeNodeVisitor<'n> for GraphvizVisitor<'a, 'b> { type Node = LogicalPlan; fn f_down( &mut self, - plan: &LogicalPlan, + plan: &'n LogicalPlan, ) -> datafusion_common::Result { let id = self.graphviz_builder.next_id(); @@ -338,9 +338,9 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { .collect::>() .join(", "); - let elipse = if values.len() > 5 { "..." } else { "" }; + let eclipse = if values.len() > 5 { "..." } else { "" }; - let values_str = format!("{}{}", str_values, elipse); + let values_str = format!("{}{}", str_values, eclipse); json!({ "Node Type": "Values", "Values": values_str @@ -387,19 +387,16 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { } if !full_filter.is_empty() { - object["Full Filters"] = serde_json::Value::String( - expr_vec_fmt!(full_filter).to_string(), - ); + object["Full Filters"] = + serde_json::Value::String(expr_vec_fmt!(full_filter)); }; if !partial_filter.is_empty() { - object["Partial Filters"] = serde_json::Value::String( - expr_vec_fmt!(partial_filter).to_string(), - ); + object["Partial Filters"] = + serde_json::Value::String(expr_vec_fmt!(partial_filter)); } if !unsupported_filters.is_empty() { - object["Unsupported Filters"] = serde_json::Value::String( - expr_vec_fmt!(unsupported_filters).to_string(), - ); + object["Unsupported Filters"] = + serde_json::Value::String(expr_vec_fmt!(unsupported_filters)); } } @@ -425,7 +422,7 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { LogicalPlan::Copy(CopyTo { input: _, output_url, - format_options, + file_type, partition_by: _, options, }) => { @@ -437,7 +434,7 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { json!({ "Node Type": "CopyTo", "Output URL": output_url, - "Format Options": format!("{}", format_options), + "File Type": format!("{}", file_type.get_ext()), "Options": op_str }) } @@ -507,11 +504,6 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { "Filter": format!("{}", filter_expr) }) } - LogicalPlan::CrossJoin(_) => { - json!({ - "Node Type": "Cross Join" - }) - } LogicalPlan::Repartition(Repartition { partitioning_scheme, .. @@ -552,11 +544,13 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { let mut object = serde_json::json!( { "Node Type": "Limit", - "Skip": skip, } ); + if let Some(s) = skip { + object["Skip"] = s.to_string().into() + }; if let Some(f) = fetch { - object["Fetch"] = serde_json::Value::Number((*f).into()); + object["Fetch"] = f.to_string().into() }; object } @@ -595,9 +589,8 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { "Select": expr_vec_fmt!(select_expr), }); if let Some(sort_expr) = sort_expr { - object["Sort"] = serde_json::Value::String( - expr_vec_fmt!(sort_expr).to_string(), - ); + object["Sort"] = + serde_json::Value::String(expr_vec_fmt!(sort_expr)); } object @@ -633,27 +626,57 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> { "Data Types": format!("{:?}", data_types) }) } + LogicalPlan::Execute(Execute { + name, parameters, .. + }) => { + json!({ + "Node Type": "Execute", + "Name": name, + "Parameters": expr_vec_fmt!(parameters), + }) + } LogicalPlan::DescribeTable(DescribeTable { .. }) => { json!({ "Node Type": "DescribeTable" }) } - LogicalPlan::Unnest(Unnest { columns, .. }) => { + LogicalPlan::Unnest(Unnest { + input: plan, + list_type_columns: list_col_indices, + struct_type_columns: struct_col_indices, + .. + }) => { + let input_columns = plan.schema().columns(); + let list_type_columns = list_col_indices + .iter() + .map(|(i, unnest_info)| { + format!( + "{}|depth={:?}", + &input_columns[*i].to_string(), + unnest_info.depth + ) + }) + .collect::>(); + let struct_type_columns = struct_col_indices + .iter() + .map(|i| &input_columns[*i]) + .collect::>(); json!({ "Node Type": "Unnest", - "Column": expr_vec_fmt!(columns), + "ListColumn": expr_vec_fmt!(list_type_columns), + "StructColumn": expr_vec_fmt!(struct_type_columns), }) } } } } -impl<'a, 'b> TreeNodeVisitor for PgJsonVisitor<'a, 'b> { +impl<'n, 'a, 'b> TreeNodeVisitor<'n> for PgJsonVisitor<'a, 'b> { type Node = LogicalPlan; fn f_down( &mut self, - node: &LogicalPlan, + node: &'n LogicalPlan, ) -> datafusion_common::Result { let id = self.next_id; self.next_id += 1; diff --git a/datafusion/expr/src/logical_plan/dml.rs b/datafusion/expr/src/logical_plan/dml.rs index 9c0fe0f30486..669bc8e8a7d3 100644 --- a/datafusion/expr/src/logical_plan/dml.rs +++ b/datafusion/expr/src/logical_plan/dml.rs @@ -15,12 +15,14 @@ // specific language governing permissions and limitations // under the License. +use std::cmp::Ordering; use std::collections::HashMap; -use std::fmt::{self, Display}; +use std::fmt::{self, Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use datafusion_common::config::FormatOptions; +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion_common::file_options::file_type::FileType; use datafusion_common::{DFSchemaRef, TableReference}; use crate::LogicalPlan; @@ -34,12 +36,24 @@ pub struct CopyTo { pub output_url: String, /// Determines which, if any, columns should be used for hive-style partitioned writes pub partition_by: Vec, - /// File format options. - pub format_options: FormatOptions, + /// File type trait + pub file_type: Arc, /// SQL Options that can affect the formats pub options: HashMap, } +impl Debug for CopyTo { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("CopyTo") + .field("input", &self.input) + .field("output_url", &self.output_url) + .field("partition_by", &self.partition_by) + .field("file_type", &"...") + .field("options", &self.options) + .finish_non_exhaustive() + } +} + // Implement PartialEq manually impl PartialEq for CopyTo { fn eq(&self, other: &Self) -> bool { @@ -50,6 +64,23 @@ impl PartialEq for CopyTo { // Implement Eq (no need for additional logic over PartialEq) impl Eq for CopyTo {} +// Manual implementation needed because of `file_type` and `options` fields. +// Comparison excludes these field. +impl PartialOrd for CopyTo { + fn partial_cmp(&self, other: &Self) -> Option { + match self.input.partial_cmp(&other.input) { + Some(Ordering::Equal) => match self.output_url.partial_cmp(&other.output_url) + { + Some(Ordering::Equal) => { + self.partition_by.partial_cmp(&other.partition_by) + } + cmp => cmp, + }, + cmp => cmp, + } + } +} + // Implement Hash manually impl Hash for CopyTo { fn hash(&self, state: &mut H) { @@ -60,7 +91,7 @@ impl Hash for CopyTo { /// The operator that modifies the content of a database (adapted from /// substrait WriteRel) -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct DmlStatement { /// The table name pub table_name: TableReference, @@ -70,19 +101,52 @@ pub struct DmlStatement { pub op: WriteOp, /// The relation that determines the tuples to add/remove/modify the schema must match with table_schema pub input: Arc, + /// The schema of the output relation + pub output_schema: DFSchemaRef, } impl DmlStatement { + /// Creates a new DML statement with the output schema set to a single `count` column. + pub fn new( + table_name: TableReference, + table_schema: DFSchemaRef, + op: WriteOp, + input: Arc, + ) -> Self { + Self { + table_name, + table_schema, + op, + input, + + // The output schema is always a single column with the number of rows affected + output_schema: make_count_schema(), + } + } + /// Return a descriptive name of this [`DmlStatement`] pub fn name(&self) -> &str { self.op.name() } } -#[derive(Clone, PartialEq, Eq, Hash)] +// Manual implementation needed because of `table_schema` and `output_schema` fields. +// Comparison excludes these fields. +impl PartialOrd for DmlStatement { + fn partial_cmp(&self, other: &Self) -> Option { + match self.table_name.partial_cmp(&other.table_name) { + Some(Ordering::Equal) => match self.op.partial_cmp(&other.op) { + Some(Ordering::Equal) => self.input.partial_cmp(&other.input), + cmp => cmp, + }, + cmp => cmp, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum WriteOp { - InsertOverwrite, - InsertInto, + Insert(InsertOp), Delete, Update, Ctas, @@ -92,8 +156,7 @@ impl WriteOp { /// Return a descriptive name of this [`WriteOp`] pub fn name(&self) -> &str { match self { - WriteOp::InsertOverwrite => "Insert Overwrite", - WriteOp::InsertInto => "Insert Into", + WriteOp::Insert(insert) => insert.name(), WriteOp::Delete => "Delete", WriteOp::Update => "Update", WriteOp::Ctas => "Ctas", @@ -102,7 +165,46 @@ impl WriteOp { } impl Display for WriteOp { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "{}", self.name()) } } + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)] +pub enum InsertOp { + /// Appends new rows to the existing table without modifying any + /// existing rows. This corresponds to the SQL `INSERT INTO` query. + Append, + /// Overwrites all existing rows in the table with the new rows. + /// This corresponds to the SQL `INSERT OVERWRITE` query. + Overwrite, + /// If any existing rows collides with the inserted rows (typically based + /// on a unique key or primary key), those existing rows are replaced. + /// This corresponds to the SQL `REPLACE INTO` query and its equivalents. + Replace, +} + +impl InsertOp { + /// Return a descriptive name of this [`InsertOp`] + pub fn name(&self) -> &str { + match self { + InsertOp::Append => "Insert Into", + InsertOp::Overwrite => "Insert Overwrite", + InsertOp::Replace => "Replace Into", + } + } +} + +impl Display for InsertOp { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.name()) + } +} + +fn make_count_schema() -> DFSchemaRef { + Arc::new( + Schema::new(vec![Field::new("count", DataType::UInt64, false)]) + .try_into() + .unwrap(), + ) +} diff --git a/datafusion/expr/src/logical_plan/extension.rs b/datafusion/expr/src/logical_plan/extension.rs index 7e6f07e0c509..19d4cb3db9ce 100644 --- a/datafusion/expr/src/logical_plan/extension.rs +++ b/datafusion/expr/src/logical_plan/extension.rs @@ -17,7 +17,8 @@ //! This module defines the interface for logical nodes use crate::{Expr, LogicalPlan}; -use datafusion_common::{DFSchema, DFSchemaRef}; +use datafusion_common::{DFSchema, DFSchemaRef, Result}; +use std::cmp::Ordering; use std::hash::{Hash, Hasher}; use std::{any::Any, collections::HashSet, fmt, sync::Arc}; @@ -76,27 +77,31 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync { /// For example: `TopK: k=10` fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result; - /// Create a new `ExtensionPlanNode` with the specified children + #[deprecated(since = "39.0.0", note = "use with_exprs_and_inputs instead")] + #[allow(clippy::wrong_self_convention)] + fn from_template( + &self, + exprs: &[Expr], + inputs: &[LogicalPlan], + ) -> Arc { + self.with_exprs_and_inputs(exprs.to_vec(), inputs.to_vec()) + .unwrap() + } + + /// Create a new `UserDefinedLogicalNode` with the specified children /// and expressions. This function is used during optimization /// when the plan is being rewritten and a new instance of the - /// `ExtensionPlanNode` must be created. + /// `UserDefinedLogicalNode` must be created. /// /// Note that exprs and inputs are in the same order as the result /// of self.inputs and self.exprs. /// - /// So, `self.from_template(exprs, ..).expressions() == exprs - // - // TODO(clippy): This should probably be renamed to use a `with_*` prefix. Something - // like `with_template`, or `with_exprs_and_inputs`. - // - // Also, I think `ExtensionPlanNode` has been renamed to `UserDefinedLogicalNode` - // but the doc comments have not been updated. - #[allow(clippy::wrong_self_convention)] - fn from_template( + /// So, `self.with_exprs_and_inputs(exprs, ..).expressions() == exprs + fn with_exprs_and_inputs( &self, - exprs: &[Expr], - inputs: &[LogicalPlan], - ) -> Arc; + exprs: Vec, + inputs: Vec, + ) -> Result>; /// Returns the necessary input columns for this node required to compute /// the columns in the output schema @@ -189,6 +194,17 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync { /// Note: [`UserDefinedLogicalNode`] is not constrained by [`Eq`] /// directly because it must remain object safe. fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool; + fn dyn_ord(&self, other: &dyn UserDefinedLogicalNode) -> Option; + + /// Returns `true` if a limit can be safely pushed down through this + /// `UserDefinedLogicalNode` node. + /// + /// If this method returns `true`, and the query plan contains a limit at + /// the output of this node, DataFusion will push the limit to the input + /// of this node. + fn supports_limit_pushdown(&self) -> bool { + false + } } impl Hash for dyn UserDefinedLogicalNode { @@ -197,21 +213,27 @@ impl Hash for dyn UserDefinedLogicalNode { } } -impl std::cmp::PartialEq for dyn UserDefinedLogicalNode { +impl PartialEq for dyn UserDefinedLogicalNode { fn eq(&self, other: &Self) -> bool { self.dyn_eq(other) } } +impl PartialOrd for dyn UserDefinedLogicalNode { + fn partial_cmp(&self, other: &Self) -> Option { + self.dyn_ord(other) + } +} + impl Eq for dyn UserDefinedLogicalNode {} /// This trait facilitates implementation of the [`UserDefinedLogicalNode`]. /// /// See the example in -/// [user_defined_plan.rs](../../tests/user_defined_plan.rs) for an -/// example of how to use this extension API. +/// [user_defined_plan.rs](https://github.com/apache/datafusion/blob/main/datafusion/core/tests/user_defined/user_defined_plan.rs) +/// file for an example of how to use this extension API. pub trait UserDefinedLogicalNodeCore: - fmt::Debug + Eq + Hash + Send + Sync + 'static + fmt::Debug + Eq + PartialOrd + Hash + Sized + Send + Sync + 'static { /// Return the plan's name. fn name(&self) -> &str; @@ -244,23 +266,27 @@ pub trait UserDefinedLogicalNodeCore: /// For example: `TopK: k=10` fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result; - /// Create a new `ExtensionPlanNode` with the specified children + #[deprecated(since = "39.0.0", note = "use with_exprs_and_inputs instead")] + #[allow(clippy::wrong_self_convention)] + fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self { + self.with_exprs_and_inputs(exprs.to_vec(), inputs.to_vec()) + .unwrap() + } + + /// Create a new `UserDefinedLogicalNode` with the specified children /// and expressions. This function is used during optimization /// when the plan is being rewritten and a new instance of the - /// `ExtensionPlanNode` must be created. + /// `UserDefinedLogicalNode` must be created. /// /// Note that exprs and inputs are in the same order as the result /// of self.inputs and self.exprs. /// - /// So, `self.from_template(exprs, ..).expressions() == exprs - // - // TODO(clippy): This should probably be renamed to use a `with_*` prefix. Something - // like `with_template`, or `with_exprs_and_inputs`. - // - // Also, I think `ExtensionPlanNode` has been renamed to `UserDefinedLogicalNode` - // but the doc comments have not been updated. - #[allow(clippy::wrong_self_convention)] - fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self; + /// So, `self.with_exprs_and_inputs(exprs, ..).expressions() == exprs + fn with_exprs_and_inputs( + &self, + exprs: Vec, + inputs: Vec, + ) -> Result; /// Returns the necessary input columns for this node required to compute /// the columns in the output schema @@ -279,6 +305,16 @@ pub trait UserDefinedLogicalNodeCore: ) -> Option>> { None } + + /// Returns `true` if a limit can be safely pushed down through this + /// `UserDefinedLogicalNode` node. + /// + /// If this method returns `true`, and the query plan contains a limit at + /// the output of this node, DataFusion will push the limit to the input + /// of this node. + fn supports_limit_pushdown(&self) -> bool { + false // Disallow limit push-down by default + } } /// Automatically derive UserDefinedLogicalNode to `UserDefinedLogicalNode` @@ -312,12 +348,12 @@ impl UserDefinedLogicalNode for T { self.fmt_for_explain(f) } - fn from_template( + fn with_exprs_and_inputs( &self, - exprs: &[Expr], - inputs: &[LogicalPlan], - ) -> Arc { - Arc::new(self.from_template(exprs, inputs)) + exprs: Vec, + inputs: Vec, + ) -> Result> { + Ok(Arc::new(self.with_exprs_and_inputs(exprs, inputs)?)) } fn necessary_children_exprs( @@ -338,6 +374,17 @@ impl UserDefinedLogicalNode for T { None => false, } } + + fn dyn_ord(&self, other: &dyn UserDefinedLogicalNode) -> Option { + other + .as_any() + .downcast_ref::() + .and_then(|other| self.partial_cmp(other)) + } + + fn supports_limit_pushdown(&self) -> bool { + self.supports_limit_pushdown() + } } fn get_all_columns_from_schema(schema: &DFSchema) -> HashSet { diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index 034440643e51..59654a227829 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -26,20 +26,20 @@ pub mod tree_node; pub use builder::{ build_join_schema, table_scan, union, wrap_projection_for_join_if_necessary, - LogicalPlanBuilder, UNNAMED_TABLE, + LogicalPlanBuilder, LogicalTableSource, UNNAMED_TABLE, }; pub use ddl::{ CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateFunction, - CreateFunctionBody, CreateMemoryTable, CreateView, DdlStatement, DefinitionStatement, + CreateFunctionBody, CreateIndex, CreateMemoryTable, CreateView, DdlStatement, DropCatalogSchema, DropFunction, DropTable, DropView, OperateFunctionArg, }; pub use dml::{DmlStatement, WriteOp}; pub use plan::{ - projection_schema, Aggregate, Analyze, CrossJoin, DescribeTable, Distinct, - DistinctOn, EmptyRelation, Explain, Extension, Filter, Join, JoinConstraint, - JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, Projection, - RecursiveQuery, Repartition, Sort, StringifiedPlan, Subquery, SubqueryAlias, - TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, + projection_schema, Aggregate, Analyze, ColumnUnnestList, DescribeTable, Distinct, + DistinctOn, EmptyRelation, Execute, Explain, Extension, FetchType, Filter, Join, + JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, + Projection, RecursiveQuery, Repartition, SkipType, Sort, StringifiedPlan, Subquery, + SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, }; pub use statement::{ SetVariable, Statement, TransactionAccessMode, TransactionConclusion, TransactionEnd, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 64c5b56a4080..191a42e38e3a 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -17,58 +17,183 @@ //! Logical plan types +use std::cmp::Ordering; use std::collections::{HashMap, HashSet}; use std::fmt::{self, Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use super::dml::CopyTo; use super::DdlStatement; use crate::builder::{change_redundant_column, unnest_with_options}; -use crate::expr::{Alias, Placeholder, Sort as SortExpr, WindowFunction}; -use crate::expr_rewriter::{create_col_from_scalar_expr, normalize_cols}; +use crate::expr::{Placeholder, Sort as SortExpr, WindowFunction}; +use crate::expr_rewriter::{ + create_col_from_scalar_expr, normalize_cols, normalize_sorts, NamePreserver, +}; use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::logical_plan::{DmlStatement, Statement}; use crate::utils::{ - enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs, - grouping_set_expr_count, grouping_set_to_exprlist, split_conjunction, + enumerate_grouping_sets, exprlist_len, exprlist_to_fields, find_base_plan, + find_out_reference_exprs, grouping_set_expr_count, grouping_set_to_exprlist, + split_conjunction, }; use crate::{ - build_join_schema, expr_vec_fmt, BinaryExpr, BuiltInWindowFunction, - CreateMemoryTable, CreateView, Expr, ExprSchemable, LogicalPlanBuilder, Operator, - TableProviderFilterPushDown, TableSource, WindowFunctionDefinition, + build_join_schema, expr_vec_fmt, BinaryExpr, CreateMemoryTable, CreateView, Expr, + ExprSchemable, LogicalPlanBuilder, Operator, TableProviderFilterPushDown, + TableSource, WindowFunctionDefinition, }; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, -}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_common::{ aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, DFSchema, DFSchemaRef, DataFusionError, Dependency, FunctionalDependence, - FunctionalDependencies, ParamValues, Result, TableReference, UnnestOptions, + FunctionalDependencies, ParamValues, Result, ScalarValue, TableReference, + UnnestOptions, }; +use indexmap::IndexSet; // backwards compatibility use crate::display::PgJsonVisitor; +use crate::tree_node::replace_sort_expressions; pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; pub use datafusion_common::{JoinConstraint, JoinType}; -/// A LogicalPlan represents the different types of relational -/// operators (such as Projection, Filter, etc) and can be created by -/// the SQL query planner and the DataFrame API. +/// A `LogicalPlan` is a node in a tree of relational operators (such as +/// Projection or Filter). /// -/// A LogicalPlan represents transforming an input relation (table) to -/// an output relation (table) with a (potentially) different -/// schema. A plan represents a dataflow tree where data flows -/// from leaves up to the root to produce the query result. +/// Represents transforming an input relation (table) to an output relation +/// (table) with a potentially different schema. Plans form a dataflow tree +/// where data flows from leaves up to the root to produce the query result. +/// +/// `LogicalPlan`s can be created by the SQL query planner, the DataFrame API, +/// or programmatically (for example custom query languages). /// /// # See also: -/// * [`tree_node`]: To inspect and rewrite `LogicalPlan` trees +/// * [`Expr`]: For the expressions that are evaluated by the plan +/// * [`LogicalPlanBuilder`]: For building `LogicalPlan`s +/// * [`tree_node`]: To inspect and rewrite `LogicalPlan`s /// /// [`tree_node`]: crate::logical_plan::tree_node -#[derive(Clone, PartialEq, Eq, Hash)] +/// +/// # Examples +/// +/// ## Creating a LogicalPlan from SQL: +/// +/// See [`SessionContext::sql`](https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.sql) +/// +/// ## Creating a LogicalPlan from the DataFrame API: +/// +/// See [`DataFrame::logical_plan`](https://docs.rs/datafusion/latest/datafusion/dataframe/struct.DataFrame.html#method.logical_plan) +/// +/// ## Creating a LogicalPlan programmatically: +/// +/// See [`LogicalPlanBuilder`] +/// +/// # Visiting and Rewriting `LogicalPlan`s +/// +/// Using the [`tree_node`] API, you can recursively walk all nodes in a +/// `LogicalPlan`. For example, to find all column references in a plan: +/// +/// ``` +/// # use std::collections::HashSet; +/// # use arrow::datatypes::{DataType, Field, Schema}; +/// # use datafusion_expr::{Expr, col, lit, LogicalPlan, LogicalPlanBuilder, table_scan}; +/// # use datafusion_common::tree_node::{TreeNodeRecursion, TreeNode}; +/// # use datafusion_common::{Column, Result}; +/// # fn employee_schema() -> Schema { +/// # Schema::new(vec![ +/// # Field::new("name", DataType::Utf8, false), +/// # Field::new("salary", DataType::Int32, false), +/// # ]) +/// # } +/// // Projection(name, salary) +/// // Filter(salary > 1000) +/// // TableScan(employee) +/// # fn main() -> Result<()> { +/// let plan = table_scan(Some("employee"), &employee_schema(), None)? +/// .filter(col("salary").gt(lit(1000)))? +/// .project(vec![col("name")])? +/// .build()?; +/// +/// // use apply to walk the plan and collect all expressions +/// let mut expressions = HashSet::new(); +/// plan.apply(|node| { +/// // collect all expressions in the plan +/// node.apply_expressions(|expr| { +/// expressions.insert(expr.clone()); +/// Ok(TreeNodeRecursion::Continue) // control walk of expressions +/// })?; +/// Ok(TreeNodeRecursion::Continue) // control walk of plan nodes +/// }).unwrap(); +/// +/// // we found the expression in projection and filter +/// assert_eq!(expressions.len(), 2); +/// println!("Found expressions: {:?}", expressions); +/// // found predicate in the Filter: employee.salary > 1000 +/// let salary = Expr::Column(Column::new(Some("employee"), "salary")); +/// assert!(expressions.contains(&salary.gt(lit(1000)))); +/// // found projection in the Projection: employee.name +/// let name = Expr::Column(Column::new(Some("employee"), "name")); +/// assert!(expressions.contains(&name)); +/// # Ok(()) +/// # } +/// ``` +/// +/// You can also rewrite plans using the [`tree_node`] API. For example, to +/// replace the filter predicate in a plan: +/// +/// ``` +/// # use std::collections::HashSet; +/// # use arrow::datatypes::{DataType, Field, Schema}; +/// # use datafusion_expr::{Expr, col, lit, LogicalPlan, LogicalPlanBuilder, table_scan}; +/// # use datafusion_common::tree_node::{TreeNodeRecursion, TreeNode}; +/// # use datafusion_common::{Column, Result}; +/// # fn employee_schema() -> Schema { +/// # Schema::new(vec![ +/// # Field::new("name", DataType::Utf8, false), +/// # Field::new("salary", DataType::Int32, false), +/// # ]) +/// # } +/// // Projection(name, salary) +/// // Filter(salary > 1000) +/// // TableScan(employee) +/// # fn main() -> Result<()> { +/// use datafusion_common::tree_node::Transformed; +/// let plan = table_scan(Some("employee"), &employee_schema(), None)? +/// .filter(col("salary").gt(lit(1000)))? +/// .project(vec![col("name")])? +/// .build()?; +/// +/// // use transform to rewrite the plan +/// let transformed_result = plan.transform(|node| { +/// // when we see the filter node +/// if let LogicalPlan::Filter(mut filter) = node { +/// // replace predicate with salary < 2000 +/// filter.predicate = Expr::Column(Column::new(Some("employee"), "salary")).lt(lit(2000)); +/// let new_plan = LogicalPlan::Filter(filter); +/// return Ok(Transformed::yes(new_plan)); // communicate the node was changed +/// } +/// // return the node unchanged +/// Ok(Transformed::no(node)) +/// }).unwrap(); +/// +/// // Transformed result contains rewritten plan and information about +/// // whether the plan was changed +/// assert!(transformed_result.transformed); +/// let rewritten_plan = transformed_result.data; +/// +/// // we found the filter +/// assert_eq!(rewritten_plan.display_indent().to_string(), +/// "Projection: employee.name\ +/// \n Filter: employee.salary < Int32(2000)\ +/// \n TableScan: employee"); +/// # Ok(()) +/// # } +/// ``` +/// +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum LogicalPlan { /// Evaluates an arbitrary list of expressions (essentially a /// SELECT with an expression list) on its input. @@ -96,9 +221,6 @@ pub enum LogicalPlan { /// Join two logical plans on one or more join columns. /// This is used to implement SQL `JOIN` Join(Join), - /// Apply Cross Join to two logical plans. - /// This is used to implement SQL `CROSS JOIN` - CrossJoin(CrossJoin), /// Repartitions the input based on a partitioning scheme. This is /// used to add parallelism and is sometimes referred to as an /// "exchange" operator in other systems @@ -144,7 +266,9 @@ pub enum LogicalPlan { /// Prepare a statement and find any bind parameters /// (e.g. `?`). This is used to implement SQL-prepared statements. Prepare(Prepare), - /// Data Manipulaton Language (DML): Insert / Update / Delete + /// Execute a prepared statement. This is used to implement SQL 'EXECUTE'. + Execute(Execute), + /// Data Manipulation Language (DML): Insert / Update / Delete Dml(DmlStatement), /// Data Definition Language (DDL): CREATE / DROP TABLES / VIEWS / SCHEMAS Ddl(DdlStatement), @@ -160,6 +284,15 @@ pub enum LogicalPlan { RecursiveQuery(RecursiveQuery), } +impl Default for LogicalPlan { + fn default() -> Self { + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::new(DFSchema::empty()), + }) + } +} + impl LogicalPlan { /// Get a reference to the logical plan's schema pub fn schema(&self) -> &DFSchemaRef { @@ -177,13 +310,13 @@ impl LogicalPlan { LogicalPlan::Aggregate(Aggregate { schema, .. }) => schema, LogicalPlan::Sort(Sort { input, .. }) => input.schema(), LogicalPlan::Join(Join { schema, .. }) => schema, - LogicalPlan::CrossJoin(CrossJoin { schema, .. }) => schema, LogicalPlan::Repartition(Repartition { input, .. }) => input.schema(), LogicalPlan::Limit(Limit { input, .. }) => input.schema(), LogicalPlan::Statement(statement) => statement.schema(), LogicalPlan::Subquery(Subquery { subquery, .. }) => subquery.schema(), LogicalPlan::SubqueryAlias(SubqueryAlias { schema, .. }) => schema, LogicalPlan::Prepare(Prepare { input, .. }) => input.schema(), + LogicalPlan::Execute(Execute { schema, .. }) => schema, LogicalPlan::Explain(explain) => &explain.schema, LogicalPlan::Analyze(analyze) => &analyze.schema, LogicalPlan::Extension(extension) => extension.node.schema(), @@ -191,7 +324,7 @@ impl LogicalPlan { LogicalPlan::DescribeTable(DescribeTable { output_schema, .. }) => { output_schema } - LogicalPlan::Dml(DmlStatement { table_schema, .. }) => table_schema, + LogicalPlan::Dml(DmlStatement { output_schema, .. }) => output_schema, LogicalPlan::Copy(CopyTo { input, .. }) => input.schema(), LogicalPlan::Ddl(ddl) => ddl.schema(), LogicalPlan::Unnest(Unnest { schema, .. }) => schema, @@ -210,8 +343,7 @@ impl LogicalPlan { | LogicalPlan::Projection(_) | LogicalPlan::Aggregate(_) | LogicalPlan::Unnest(_) - | LogicalPlan::Join(_) - | LogicalPlan::CrossJoin(_) => self + | LogicalPlan::Join(_) => self .inputs() .iter() .map(|input| input.schema().as_ref()) @@ -289,27 +421,6 @@ impl LogicalPlan { exprs } - #[deprecated(since = "37.0.0", note = "Use `apply_expressions` instead")] - pub fn inspect_expressions(self: &LogicalPlan, mut f: F) -> Result<(), E> - where - F: FnMut(&Expr) -> Result<(), E>, - { - let mut err = Ok(()); - self.apply_expressions(|e| { - if let Err(e) = f(e) { - // save the error for later (it may not be a DataFusionError - err = Err(e); - Ok(TreeNodeRecursion::Stop) - } else { - Ok(TreeNodeRecursion::Continue) - } - }) - // The closure always returns OK, so this will always too - .expect("no way to return error during recursion"); - - err - } - /// Returns all inputs / children of this `LogicalPlan` node. /// /// Note does not include inputs to inputs, or subqueries. @@ -322,7 +433,6 @@ impl LogicalPlan { LogicalPlan::Aggregate(Aggregate { input, .. }) => vec![input], LogicalPlan::Sort(Sort { input, .. }) => vec![input], LogicalPlan::Join(Join { left, right, .. }) => vec![left, right], - LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => vec![left, right], LogicalPlan::Limit(Limit { input, .. }) => vec![input], LogicalPlan::Subquery(Subquery { subquery, .. }) => vec![subquery], LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => vec![input], @@ -350,6 +460,7 @@ impl LogicalPlan { | LogicalPlan::Statement { .. } | LogicalPlan::EmptyRelation { .. } | LogicalPlan::Values { .. } + | LogicalPlan::Execute { .. } | LogicalPlan::DescribeTable(_) => vec![], } } @@ -368,8 +479,18 @@ impl LogicalPlan { // The join keys in using-join must be columns. let columns = on.iter().try_fold(HashSet::new(), |mut accumu, (l, r)| { - accumu.insert(l.try_into_col()?); - accumu.insert(r.try_into_col()?); + let Some(l) = l.get_as_join_column() else { + return internal_err!( + "Invalid join key. Expected column, found {l:?}" + ); + }; + let Some(r) = r.get_as_join_column() else { + return internal_err!( + "Invalid join key. Expected column, found {r:?}" + ); + }; + accumu.insert(l.to_owned()); + accumu.insert(r.to_owned()); Result::<_, DataFusionError>::Ok(accumu) })?; using_columns.push(columns); @@ -415,16 +536,11 @@ impl LogicalPlan { left.head_output_expr() } } - JoinType::LeftSemi | JoinType::LeftAnti => left.head_output_expr(), + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { + left.head_output_expr() + } JoinType::RightSemi | JoinType::RightAnti => right.head_output_expr(), }, - LogicalPlan::CrossJoin(cross) => { - if cross.left.schema().fields().is_empty() { - cross.right.head_output_expr() - } else { - cross.left.head_output_expr() - } - } LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { static_term.head_output_expr() } @@ -448,6 +564,7 @@ impl LogicalPlan { LogicalPlan::Subquery(_) => Ok(None), LogicalPlan::EmptyRelation(_) | LogicalPlan::Prepare(_) + | LogicalPlan::Execute(_) | LogicalPlan::Statement(_) | LogicalPlan::Values(_) | LogicalPlan::Explain(_) @@ -461,10 +578,160 @@ impl LogicalPlan { } } - /// Returns a copy of this `LogicalPlan` with the new inputs - #[deprecated(since = "35.0.0", note = "please use `with_new_exprs` instead")] - pub fn with_new_inputs(&self, inputs: &[LogicalPlan]) -> Result { - self.with_new_exprs(self.expressions(), inputs.to_vec()) + /// Recomputes schema and type information for this LogicalPlan if needed. + /// + /// Some `LogicalPlan`s may need to recompute their schema if the number or + /// type of expressions have been changed (for example due to type + /// coercion). For example [`LogicalPlan::Projection`]s schema depends on + /// its expressions. + /// + /// Some `LogicalPlan`s schema is unaffected by any changes to their + /// expressions. For example [`LogicalPlan::Filter`] schema is always the + /// same as its input schema. + /// + /// This is useful after modifying a plans `Expr`s (or input plans) via + /// methods such as [Self::map_children] and [Self::map_expressions]. Unlike + /// [Self::with_new_exprs], this method does not require a new set of + /// expressions or inputs plans. + /// + /// # Return value + /// Returns an error if there is some issue recomputing the schema. + /// + /// # Notes + /// + /// * Does not recursively recompute schema for input (child) plans. + pub fn recompute_schema(self) -> Result { + match self { + // Since expr may be different than the previous expr, schema of the projection + // may change. We need to use try_new method instead of try_new_with_schema method. + LogicalPlan::Projection(Projection { + expr, + input, + schema: _, + }) => Projection::try_new(expr, input).map(LogicalPlan::Projection), + LogicalPlan::Dml(_) => Ok(self), + LogicalPlan::Copy(_) => Ok(self), + LogicalPlan::Values(Values { schema, values }) => { + // todo it isn't clear why the schema is not recomputed here + Ok(LogicalPlan::Values(Values { schema, values })) + } + LogicalPlan::Filter(Filter { + predicate, + input, + having, + }) => Filter::try_new_internal(predicate, input, having) + .map(LogicalPlan::Filter), + LogicalPlan::Repartition(_) => Ok(self), + LogicalPlan::Window(Window { + input, + window_expr, + schema: _, + }) => Window::try_new(window_expr, input).map(LogicalPlan::Window), + LogicalPlan::Aggregate(Aggregate { + input, + group_expr, + aggr_expr, + schema: _, + }) => Aggregate::try_new(input, group_expr, aggr_expr) + .map(LogicalPlan::Aggregate), + LogicalPlan::Sort(_) => Ok(self), + LogicalPlan::Join(Join { + left, + right, + filter, + join_type, + join_constraint, + on, + schema: _, + null_equals_null, + }) => { + let schema = + build_join_schema(left.schema(), right.schema(), &join_type)?; + + let new_on: Vec<_> = on + .into_iter() + .map(|equi_expr| { + // SimplifyExpression rule may add alias to the equi_expr. + (equi_expr.0.unalias(), equi_expr.1.unalias()) + }) + .collect(); + + Ok(LogicalPlan::Join(Join { + left, + right, + join_type, + join_constraint, + on: new_on, + filter, + schema: DFSchemaRef::new(schema), + null_equals_null, + })) + } + LogicalPlan::Subquery(_) => Ok(self), + LogicalPlan::SubqueryAlias(SubqueryAlias { + input, + alias, + schema: _, + }) => SubqueryAlias::try_new(input, alias).map(LogicalPlan::SubqueryAlias), + LogicalPlan::Limit(_) => Ok(self), + LogicalPlan::Ddl(_) => Ok(self), + LogicalPlan::Extension(Extension { node }) => { + // todo make an API that does not require cloning + // This requires a copy of the extension nodes expressions and inputs + let expr = node.expressions(); + let inputs: Vec<_> = node.inputs().into_iter().cloned().collect(); + Ok(LogicalPlan::Extension(Extension { + node: node.with_exprs_and_inputs(expr, inputs)?, + })) + } + LogicalPlan::Union(Union { inputs, schema }) => { + let input_schema = inputs[0].schema(); + // If inputs are not pruned do not change schema + // TODO this seems wrong (shouldn't we always use the schema of the input?) + let schema = if schema.fields().len() == input_schema.fields().len() { + Arc::clone(&schema) + } else { + Arc::clone(input_schema) + }; + Ok(LogicalPlan::Union(Union { inputs, schema })) + } + LogicalPlan::Distinct(distinct) => { + let distinct = match distinct { + Distinct::All(input) => Distinct::All(input), + Distinct::On(DistinctOn { + on_expr, + select_expr, + sort_expr, + input, + schema: _, + }) => Distinct::On(DistinctOn::try_new( + on_expr, + select_expr, + sort_expr, + input, + )?), + }; + Ok(LogicalPlan::Distinct(distinct)) + } + LogicalPlan::RecursiveQuery(_) => Ok(self), + LogicalPlan::Analyze(_) => Ok(self), + LogicalPlan::Explain(_) => Ok(self), + LogicalPlan::Prepare(_) => Ok(self), + LogicalPlan::Execute(_) => Ok(self), + LogicalPlan::TableScan(_) => Ok(self), + LogicalPlan::EmptyRelation(_) => Ok(self), + LogicalPlan::Statement(_) => Ok(self), + LogicalPlan::DescribeTable(_) => Ok(self), + LogicalPlan::Unnest(Unnest { + input, + exec_columns, + options, + .. + }) => { + // Update schema with unnested column type. + unnest_with_options(Arc::unwrap_or_clone(input), exec_columns, options) + } + } } /// Returns a new `LogicalPlan` based on `self` with inputs and @@ -495,42 +762,51 @@ impl LogicalPlan { pub fn with_new_exprs( &self, mut expr: Vec, - mut inputs: Vec, + inputs: Vec, ) -> Result { match self { // Since expr may be different than the previous expr, schema of the projection // may change. We need to use try_new method instead of try_new_with_schema method. LogicalPlan::Projection(Projection { .. }) => { - Projection::try_new(expr, Arc::new(inputs.swap_remove(0))) - .map(LogicalPlan::Projection) + let input = self.only_input(inputs)?; + Projection::try_new(expr, Arc::new(input)).map(LogicalPlan::Projection) } LogicalPlan::Dml(DmlStatement { table_name, table_schema, op, .. - }) => Ok(LogicalPlan::Dml(DmlStatement { - table_name: table_name.clone(), - table_schema: table_schema.clone(), - op: op.clone(), - input: Arc::new(inputs.swap_remove(0)), - })), + }) => { + self.assert_no_expressions(expr)?; + let input = self.only_input(inputs)?; + Ok(LogicalPlan::Dml(DmlStatement::new( + table_name.clone(), + Arc::clone(table_schema), + op.clone(), + Arc::new(input), + ))) + } LogicalPlan::Copy(CopyTo { input: _, output_url, - format_options, + file_type, options, partition_by, - }) => Ok(LogicalPlan::Copy(CopyTo { - input: Arc::new(inputs.swap_remove(0)), - output_url: output_url.clone(), - format_options: format_options.clone(), - options: options.clone(), - partition_by: partition_by.clone(), - })), + }) => { + self.assert_no_expressions(expr)?; + let input = self.only_input(inputs)?; + Ok(LogicalPlan::Copy(CopyTo { + input: Arc::new(input), + output_url: output_url.clone(), + file_type: Arc::clone(file_type), + options: options.clone(), + partition_by: partition_by.clone(), + })) + } LogicalPlan::Values(Values { schema, .. }) => { + self.assert_no_inputs(inputs)?; Ok(LogicalPlan::Values(Values { - schema: schema.clone(), + schema: Arc::clone(schema), values: expr .chunks_exact(schema.fields().len()) .map(|s| s.to_vec()) @@ -538,81 +814,63 @@ impl LogicalPlan { })) } LogicalPlan::Filter { .. } => { - assert_eq!(1, expr.len()); - let predicate = expr.pop().unwrap(); - - // filter predicates should not contain aliased expressions so we remove any aliases - // before this logic was added we would have aliases within filters such as for - // benchmark q6: - // - // lineitem.l_shipdate >= Date32(\"8766\") - // AND lineitem.l_shipdate < Date32(\"9131\") - // AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS lineitem.l_discount >= - // Decimal128(Some(49999999999999),30,15) - // AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS lineitem.l_discount <= - // Decimal128(Some(69999999999999),30,15) - // AND lineitem.l_quantity < Decimal128(Some(2400),15,2) - - let predicate = predicate - .transform_down(|expr| { - match expr { - Expr::Exists { .. } - | Expr::ScalarSubquery(_) - | Expr::InSubquery(_) => { - // subqueries could contain aliases so we don't recurse into those - Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)) - } - Expr::Alias(_) => Ok(Transformed::new( - expr.unalias(), - true, - TreeNodeRecursion::Jump, - )), - _ => Ok(Transformed::no(expr)), - } - }) - .data()?; + let predicate = self.only_expr(expr)?; + let input = self.only_input(inputs)?; - Filter::try_new(predicate, Arc::new(inputs.swap_remove(0))) - .map(LogicalPlan::Filter) + Filter::try_new(predicate, Arc::new(input)).map(LogicalPlan::Filter) } LogicalPlan::Repartition(Repartition { partitioning_scheme, .. }) => match partitioning_scheme { Partitioning::RoundRobinBatch(n) => { + self.assert_no_expressions(expr)?; + let input = self.only_input(inputs)?; Ok(LogicalPlan::Repartition(Repartition { partitioning_scheme: Partitioning::RoundRobinBatch(*n), - input: Arc::new(inputs.swap_remove(0)), + input: Arc::new(input), + })) + } + Partitioning::Hash(_, n) => { + let input = self.only_input(inputs)?; + Ok(LogicalPlan::Repartition(Repartition { + partitioning_scheme: Partitioning::Hash(expr, *n), + input: Arc::new(input), })) } - Partitioning::Hash(_, n) => Ok(LogicalPlan::Repartition(Repartition { - partitioning_scheme: Partitioning::Hash(expr, *n), - input: Arc::new(inputs.swap_remove(0)), - })), Partitioning::DistributeBy(_) => { + let input = self.only_input(inputs)?; Ok(LogicalPlan::Repartition(Repartition { partitioning_scheme: Partitioning::DistributeBy(expr), - input: Arc::new(inputs.swap_remove(0)), + input: Arc::new(input), })) } }, LogicalPlan::Window(Window { window_expr, .. }) => { assert_eq!(window_expr.len(), expr.len()); - Window::try_new(expr, Arc::new(inputs.swap_remove(0))) - .map(LogicalPlan::Window) + let input = self.only_input(inputs)?; + Window::try_new(expr, Arc::new(input)).map(LogicalPlan::Window) } LogicalPlan::Aggregate(Aggregate { group_expr, .. }) => { + let input = self.only_input(inputs)?; // group exprs are the first expressions let agg_expr = expr.split_off(group_expr.len()); - Aggregate::try_new(Arc::new(inputs.swap_remove(0)), expr, agg_expr) + Aggregate::try_new(Arc::new(input), expr, agg_expr) .map(LogicalPlan::Aggregate) } - LogicalPlan::Sort(Sort { fetch, .. }) => Ok(LogicalPlan::Sort(Sort { - expr, - input: Arc::new(inputs.swap_remove(0)), - fetch: *fetch, - })), + LogicalPlan::Sort(Sort { + expr: sort_expr, + fetch, + .. + }) => { + let input = self.only_input(inputs)?; + Ok(LogicalPlan::Sort(Sort { + expr: replace_sort_expressions(sort_expr.clone(), expr), + input: Arc::new(input), + fetch: *fetch, + })) + } LogicalPlan::Join(Join { join_type, join_constraint, @@ -620,8 +878,8 @@ impl LogicalPlan { null_equals_null, .. }) => { - let schema = - build_join_schema(inputs[0].schema(), inputs[1].schema(), join_type)?; + let (left, right) = self.only_two_inputs(inputs)?; + let schema = build_join_schema(left.schema(), right.schema(), join_type)?; let equi_expr_count = on.len(); assert!(expr.len() >= equi_expr_count); @@ -650,8 +908,8 @@ impl LogicalPlan { }).collect::>>()?; Ok(LogicalPlan::Join(Join { - left: Arc::new(inputs.swap_remove(0)), - right: Arc::new(inputs.swap_remove(0)), + left: Arc::new(left), + right: Arc::new(right), join_type: *join_type, join_constraint: *join_constraint, on: new_on, @@ -660,29 +918,39 @@ impl LogicalPlan { null_equals_null: *null_equals_null, })) } - LogicalPlan::CrossJoin(_) => { - let left = inputs.swap_remove(0); - let right = inputs.swap_remove(0); - LogicalPlanBuilder::from(left).cross_join(right)?.build() - } LogicalPlan::Subquery(Subquery { outer_ref_columns, .. }) => { - let subquery = LogicalPlanBuilder::from(inputs.swap_remove(0)).build()?; + self.assert_no_expressions(expr)?; + let input = self.only_input(inputs)?; + let subquery = LogicalPlanBuilder::from(input).build()?; Ok(LogicalPlan::Subquery(Subquery { subquery: Arc::new(subquery), outer_ref_columns: outer_ref_columns.clone(), })) } LogicalPlan::SubqueryAlias(SubqueryAlias { alias, .. }) => { - SubqueryAlias::try_new(Arc::new(inputs.swap_remove(0)), alias.clone()) + self.assert_no_expressions(expr)?; + let input = self.only_input(inputs)?; + SubqueryAlias::try_new(Arc::new(input), alias.clone()) .map(LogicalPlan::SubqueryAlias) } LogicalPlan::Limit(Limit { skip, fetch, .. }) => { + let old_expr_len = skip.iter().chain(fetch.iter()).count(); + if old_expr_len != expr.len() { + return internal_err!( + "Invalid number of new Limit expressions: expected {}, got {}", + old_expr_len, + expr.len() + ); + } + let new_skip = skip.as_ref().and_then(|_| expr.pop()); + let new_fetch = fetch.as_ref().and_then(|_| expr.pop()); + let input = self.only_input(inputs)?; Ok(LogicalPlan::Limit(Limit { - skip: *skip, - fetch: *fetch, - input: Arc::new(inputs.swap_remove(0)), + skip: new_skip.map(Box::new), + fetch: new_fetch.map(Box::new), + input: Arc::new(input), })) } LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(CreateMemoryTable { @@ -690,38 +958,51 @@ impl LogicalPlan { if_not_exists, or_replace, column_defaults, + temporary, .. - })) => Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable( - CreateMemoryTable { - input: Arc::new(inputs.swap_remove(0)), - constraints: Constraints::empty(), - name: name.clone(), - if_not_exists: *if_not_exists, - or_replace: *or_replace, - column_defaults: column_defaults.clone(), - }, - ))), + })) => { + self.assert_no_expressions(expr)?; + let input = self.only_input(inputs)?; + Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable( + CreateMemoryTable { + input: Arc::new(input), + constraints: Constraints::empty(), + name: name.clone(), + if_not_exists: *if_not_exists, + or_replace: *or_replace, + column_defaults: column_defaults.clone(), + temporary: *temporary, + }, + ))) + } LogicalPlan::Ddl(DdlStatement::CreateView(CreateView { name, or_replace, definition, + temporary, .. - })) => Ok(LogicalPlan::Ddl(DdlStatement::CreateView(CreateView { - input: Arc::new(inputs.swap_remove(0)), - name: name.clone(), - or_replace: *or_replace, - definition: definition.clone(), - }))), + })) => { + self.assert_no_expressions(expr)?; + let input = self.only_input(inputs)?; + Ok(LogicalPlan::Ddl(DdlStatement::CreateView(CreateView { + input: Arc::new(input), + name: name.clone(), + or_replace: *or_replace, + temporary: *temporary, + definition: definition.clone(), + }))) + } LogicalPlan::Extension(e) => Ok(LogicalPlan::Extension(Extension { - node: e.node.from_template(&expr, &inputs), + node: e.node.with_exprs_and_inputs(expr, inputs)?, })), LogicalPlan::Union(Union { schema, .. }) => { + self.assert_no_expressions(expr)?; let input_schema = inputs[0].schema(); // If inputs are not pruned do not change schema. let schema = if schema.fields().len() == input_schema.fields().len() { - schema.clone() + Arc::clone(schema) } else { - input_schema.clone() + Arc::clone(input_schema) }; Ok(LogicalPlan::Union(Union { inputs: inputs.into_iter().map(Arc::new).collect(), @@ -730,23 +1011,25 @@ impl LogicalPlan { } LogicalPlan::Distinct(distinct) => { let distinct = match distinct { - Distinct::All(_) => Distinct::All(Arc::new(inputs.swap_remove(0))), + Distinct::All(_) => { + self.assert_no_expressions(expr)?; + let input = self.only_input(inputs)?; + Distinct::All(Arc::new(input)) + } Distinct::On(DistinctOn { on_expr, select_expr, .. }) => { + let input = self.only_input(inputs)?; let sort_expr = expr.split_off(on_expr.len() + select_expr.len()); let select_expr = expr.split_off(on_expr.len()); + assert!(sort_expr.is_empty(), "with_new_exprs for Distinct does not support sort expressions"); Distinct::On(DistinctOn::try_new( expr, select_expr, - if !sort_expr.is_empty() { - Some(sort_expr) - } else { - None - }, - Arc::new(inputs.swap_remove(0)), + None, // no sort expressions accepted + Arc::new(input), )?) } }; @@ -754,44 +1037,57 @@ impl LogicalPlan { } LogicalPlan::RecursiveQuery(RecursiveQuery { name, is_distinct, .. - }) => Ok(LogicalPlan::RecursiveQuery(RecursiveQuery { - name: name.clone(), - static_term: Arc::new(inputs.swap_remove(0)), - recursive_term: Arc::new(inputs.swap_remove(0)), - is_distinct: *is_distinct, - })), + }) => { + self.assert_no_expressions(expr)?; + let (static_term, recursive_term) = self.only_two_inputs(inputs)?; + Ok(LogicalPlan::RecursiveQuery(RecursiveQuery { + name: name.clone(), + static_term: Arc::new(static_term), + recursive_term: Arc::new(recursive_term), + is_distinct: *is_distinct, + })) + } LogicalPlan::Analyze(a) => { - assert!(expr.is_empty()); - assert_eq!(inputs.len(), 1); + self.assert_no_expressions(expr)?; + let input = self.only_input(inputs)?; Ok(LogicalPlan::Analyze(Analyze { verbose: a.verbose, - schema: a.schema.clone(), - input: Arc::new(inputs.swap_remove(0)), + schema: Arc::clone(&a.schema), + input: Arc::new(input), })) } LogicalPlan::Explain(e) => { - assert!( - expr.is_empty(), - "Invalid EXPLAIN command. Expression should empty" - ); - assert_eq!(inputs.len(), 1, "Invalid EXPLAIN command. Inputs are empty"); + self.assert_no_expressions(expr)?; + let input = self.only_input(inputs)?; Ok(LogicalPlan::Explain(Explain { verbose: e.verbose, - plan: Arc::new(inputs.swap_remove(0)), + plan: Arc::new(input), stringified_plans: e.stringified_plans.clone(), - schema: e.schema.clone(), + schema: Arc::clone(&e.schema), logical_optimization_succeeded: e.logical_optimization_succeeded, })) } LogicalPlan::Prepare(Prepare { name, data_types, .. - }) => Ok(LogicalPlan::Prepare(Prepare { - name: name.clone(), - data_types: data_types.clone(), - input: Arc::new(inputs.swap_remove(0)), - })), + }) => { + self.assert_no_expressions(expr)?; + let input = self.only_input(inputs)?; + Ok(LogicalPlan::Prepare(Prepare { + name: name.clone(), + data_types: data_types.clone(), + input: Arc::new(input), + })) + } + LogicalPlan::Execute(Execute { name, schema, .. }) => { + self.assert_no_inputs(inputs)?; + Ok(LogicalPlan::Execute(Execute { + name: name.clone(), + schema: Arc::clone(schema), + parameters: expr, + })) + } LogicalPlan::TableScan(ts) => { - assert!(inputs.is_empty(), "{self:?} should have no inputs"); + self.assert_no_inputs(inputs)?; Ok(LogicalPlan::TableScan(TableScan { filters: expr, ..ts.clone() @@ -799,22 +1095,89 @@ impl LogicalPlan { } LogicalPlan::EmptyRelation(_) | LogicalPlan::Ddl(_) - | LogicalPlan::Statement(_) => { + | LogicalPlan::Statement(_) + | LogicalPlan::DescribeTable(_) => { // All of these plan types have no inputs / exprs so should not be called - assert!(expr.is_empty(), "{self:?} should have no exprs"); - assert!(inputs.is_empty(), "{self:?} should have no inputs"); + self.assert_no_expressions(expr)?; + self.assert_no_inputs(inputs)?; Ok(self.clone()) } - LogicalPlan::DescribeTable(_) => Ok(self.clone()), LogicalPlan::Unnest(Unnest { - columns, options, .. + exec_columns: columns, + options, + .. }) => { + self.assert_no_expressions(expr)?; + let input = self.only_input(inputs)?; // Update schema with unnested column type. - let input = inputs.swap_remove(0); - unnest_with_options(input, columns.clone(), options.clone()) + let new_plan = + unnest_with_options(input, columns.clone(), options.clone())?; + Ok(new_plan) } } } + + /// Helper for [Self::with_new_exprs] to use when no expressions are expected. + #[inline] + #[allow(clippy::needless_pass_by_value)] // expr is moved intentionally to ensure it's not used again + fn assert_no_expressions(&self, expr: Vec) -> Result<()> { + if !expr.is_empty() { + return internal_err!("{self:?} should have no exprs, got {:?}", expr); + } + Ok(()) + } + + /// Helper for [Self::with_new_exprs] to use when no inputs are expected. + #[inline] + #[allow(clippy::needless_pass_by_value)] // inputs is moved intentionally to ensure it's not used again + fn assert_no_inputs(&self, inputs: Vec) -> Result<()> { + if !inputs.is_empty() { + return internal_err!("{self:?} should have no inputs, got: {:?}", inputs); + } + Ok(()) + } + + /// Helper for [Self::with_new_exprs] to use when exactly one expression is expected. + #[inline] + fn only_expr(&self, mut expr: Vec) -> Result { + if expr.len() != 1 { + return internal_err!( + "{self:?} should have exactly one expr, got {:?}", + expr + ); + } + Ok(expr.remove(0)) + } + + /// Helper for [Self::with_new_exprs] to use when exactly one input is expected. + #[inline] + fn only_input(&self, mut inputs: Vec) -> Result { + if inputs.len() != 1 { + return internal_err!( + "{self:?} should have exactly one input, got {:?}", + inputs + ); + } + Ok(inputs.remove(0)) + } + + /// Helper for [Self::with_new_exprs] to use when exactly two inputs are expected. + #[inline] + fn only_two_inputs( + &self, + mut inputs: Vec, + ) -> Result<(LogicalPlan, LogicalPlan)> { + if inputs.len() != 2 { + return internal_err!( + "{self:?} should have exactly two inputs, got {:?}", + inputs + ); + } + let right = inputs.remove(1); + let left = inputs.remove(0); + Ok((left, right)) + } + /// Replaces placeholder param values (like `$1`, `$2`) in [`LogicalPlan`] /// with the specified `param_values`. /// @@ -879,10 +1242,7 @@ impl LogicalPlan { Ok(if let LogicalPlan::Prepare(prepare_lp) = plan_with_values { param_values.verify(&prepare_lp.data_types)?; // try and take ownership of the input if is not shared, clone otherwise - match Arc::try_unwrap(prepare_lp.input) { - Ok(input) => input, - Err(arc_input) => arc_input.as_ref().clone(), - } + Arc::unwrap_or_clone(prepare_lp.input) } else { plan_with_values }) @@ -946,15 +1306,11 @@ impl LogicalPlan { _ => None, } } - JoinType::LeftSemi | JoinType::LeftAnti => left.max_rows(), + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { + left.max_rows() + } JoinType::RightSemi | JoinType::RightAnti => right.max_rows(), }, - LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => { - match (left.max_rows(), right.max_rows()) { - (Some(left_max), Some(right_max)) => Some(left_max * right_max), - _ => None, - } - } LogicalPlan::Repartition(Repartition { input, .. }) => input.max_rows(), LogicalPlan::Union(Union { inputs, .. }) => inputs .iter() @@ -972,7 +1328,10 @@ impl LogicalPlan { LogicalPlan::RecursiveQuery(_) => None, LogicalPlan::Subquery(_) => None, LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => input.max_rows(), - LogicalPlan::Limit(Limit { fetch, .. }) => *fetch, + LogicalPlan::Limit(limit) => match limit.get_fetch_type() { + Ok(FetchType::Literal(s)) => s, + _ => None, + }, LogicalPlan::Distinct( Distinct::All(input) | Distinct::On(DistinctOn { input, .. }), ) => input.max_rows(), @@ -985,6 +1344,7 @@ impl LogicalPlan { | LogicalPlan::Copy(_) | LogicalPlan::DescribeTable(_) | LogicalPlan::Prepare(_) + | LogicalPlan::Execute(_) | LogicalPlan::Statement(_) | LogicalPlan::Extension(_) => None, } @@ -1004,6 +1364,45 @@ impl LogicalPlan { .unwrap(); contains } + + /// Get the output expressions and their corresponding columns. + /// + /// The parent node may reference the output columns of the plan by expressions, such as + /// projection over aggregate or window functions. This method helps to convert the + /// referenced expressions into columns. + /// + /// See also: [`crate::utils::columnize_expr`] + pub fn columnized_output_exprs(&self) -> Result> { + match self { + LogicalPlan::Aggregate(aggregate) => Ok(aggregate + .output_expressions()? + .into_iter() + .zip(self.schema().columns()) + .collect()), + LogicalPlan::Window(Window { + window_expr, + input, + schema, + }) => { + // The input could be another Window, so the result should also include the input's. For Example: + // `EXPLAIN SELECT RANK() OVER (PARTITION BY a ORDER BY b), SUM(b) OVER (PARTITION BY a) FROM t` + // Its plan is: + // Projection: RANK() PARTITION BY [t.a] ORDER BY [t.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(t.b) PARTITION BY [t.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + // WindowAggr: windowExpr=[[SUM(CAST(t.b AS Int64)) PARTITION BY [t.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] + // WindowAggr: windowExpr=[[RANK() PARTITION BY [t.a] ORDER BY [t.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]/ + // TableScan: t projection=[a, b] + let mut output_exprs = input.columnized_output_exprs()?; + let input_len = input.schema().fields().len(); + output_exprs.extend( + window_expr + .iter() + .zip(schema.columns().into_iter().skip(input_len)), + ); + Ok(output_exprs) + } + _ => Ok(vec![]), + } + } } impl LogicalPlan { @@ -1018,16 +1417,21 @@ impl LogicalPlan { param_values: &ParamValues, ) -> Result { self.transform_up_with_subqueries(|plan| { - let schema = plan.schema().clone(); + let schema = Arc::clone(plan.schema()); + let name_preserver = NamePreserver::new(&plan); plan.map_expressions(|e| { - e.infer_placeholder_types(&schema)?.transform_up(|e| { - if let Expr::Placeholder(Placeholder { id, .. }) = e { - let value = param_values.get_placeholders_with_values(&id)?; - Ok(Transformed::yes(Expr::Literal(value))) - } else { - Ok(Transformed::no(e)) - } - }) + let original_name = name_preserver.save(&e); + let transformed_expr = + e.infer_placeholder_types(&schema)?.transform_up(|e| { + if let Expr::Placeholder(Placeholder { id, .. }) = e { + let value = param_values.get_placeholders_with_values(&id)?; + Ok(Transformed::yes(Expr::Literal(value))) + } else { + Ok(Transformed::no(e)) + } + })?; + // Preserve name to avoid breaking column references to this expression + Ok(transformed_expr.update_data(|expr| original_name.restore(expr))) }) }) .map(|res| res.data) @@ -1282,8 +1686,8 @@ impl LogicalPlan { }) .collect(); - let elipse = if values.len() > 5 { "..." } else { "" }; - write!(f, "Values: {}{}", str_values.join(", "), elipse) + let eclipse = if values.len() > 5 { "..." } else { "" }; + write!(f, "Values: {}{}", str_values.join(", "), eclipse) } LogicalPlan::TableScan(TableScan { @@ -1377,7 +1781,7 @@ impl LogicalPlan { LogicalPlan::Copy(CopyTo { input: _, output_url, - format_options, + file_type, options, .. }) => { @@ -1387,7 +1791,7 @@ impl LogicalPlan { .collect::>() .join(", "); - write!(f, "CopyTo: format={format_options} output_url={output_url} options: ({op_str})") + write!(f, "CopyTo: format={} output_url={output_url} options: ({op_str})", file_type.get_ext()) } LogicalPlan::Ddl(ddl) => { write!(f, "{}", ddl.display()) @@ -1442,6 +1846,11 @@ impl LogicalPlan { .as_ref() .map(|expr| format!(" Filter: {expr}")) .unwrap_or_else(|| "".to_string()); + let join_type = if filter.is_none() && keys.is_empty() && matches!(join_type, JoinType::Inner) { + "Cross".to_string() + } else { + join_type.to_string() + }; match join_constraint { JoinConstraint::On => { write!( @@ -1463,9 +1872,6 @@ impl LogicalPlan { } } } - LogicalPlan::CrossJoin(_) => { - write!(f, "CrossJoin:") - } LogicalPlan::Repartition(Repartition { partitioning_scheme, .. @@ -1493,16 +1899,20 @@ impl LogicalPlan { ) } }, - LogicalPlan::Limit(Limit { - ref skip, - ref fetch, - .. - }) => { + LogicalPlan::Limit(limit) => { + // Attempt to display `skip` and `fetch` as literals if possible, otherwise as expressions. + let skip_str = match limit.get_skip_type() { + Ok(SkipType::Literal(n)) => n.to_string(), + _ => limit.skip.as_ref().map_or_else(|| "None".to_string(), |x| x.to_string()), + }; + let fetch_str = match limit.get_fetch_type() { + Ok(FetchType::Literal(Some(n))) => n.to_string(), + Ok(FetchType::Literal(None)) => "None".to_string(), + _ => limit.fetch.as_ref().map_or_else(|| "None".to_string(), |x| x.to_string()) + }; write!( f, - "Limit: skip={}, fetch={}", - skip, - fetch.map_or_else(|| "None".to_string(), |x| x.to_string()) + "Limit: skip={}, fetch={}", skip_str,fetch_str, ) } LogicalPlan::Subquery(Subquery { .. }) => { @@ -1538,11 +1948,31 @@ impl LogicalPlan { }) => { write!(f, "Prepare: {name:?} {data_types:?} ") } + LogicalPlan::Execute(Execute { name, parameters, .. }) => { + write!(f, "Execute: {} params=[{}]", name, expr_vec_fmt!(parameters)) + } LogicalPlan::DescribeTable(DescribeTable { .. }) => { write!(f, "DescribeTable") } - LogicalPlan::Unnest(Unnest { columns, .. }) => { - write!(f, "Unnest: {}", expr_vec_fmt!(columns)) + LogicalPlan::Unnest(Unnest { + input: plan, + list_type_columns: list_col_indices, + struct_type_columns: struct_col_indices, .. }) => { + let input_columns = plan.schema().columns(); + let list_type_columns = list_col_indices + .iter() + .map(|(i,unnest_info)| + format!("{}|depth={}", &input_columns[*i].to_string(), + unnest_info.depth)) + .collect::>(); + let struct_type_columns = struct_col_indices + .iter() + .map(|i| &input_columns[*i]) + .collect::>(); + // get items from input_columns indexed by list_col_indices + write!(f, "Unnest: lists[{}] structs[{}]", + expr_vec_fmt!(list_type_columns), + expr_vec_fmt!(struct_type_columns)) } } } @@ -1551,7 +1981,7 @@ impl LogicalPlan { } } -impl Debug for LogicalPlan { +impl Display for LogicalPlan { fn fmt(&self, f: &mut Formatter) -> fmt::Result { self.display_indent().fmt(f) } @@ -1564,7 +1994,7 @@ impl ToStringifiedPlan for LogicalPlan { } /// Produces no rows: An empty relation with an empty schema -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct EmptyRelation { /// Whether to produce a placeholder row pub produce_one_row: bool, @@ -1572,6 +2002,13 @@ pub struct EmptyRelation { pub schema: DFSchemaRef, } +// Manual implementation needed because of `schema` field. Comparison excludes this field. +impl PartialOrd for EmptyRelation { + fn partial_cmp(&self, other: &Self) -> Option { + self.produce_one_row.partial_cmp(&other.produce_one_row) + } +} + /// A variadic query operation, Recursive CTE. /// /// # Recursive Query Evaluation @@ -1579,22 +2016,22 @@ pub struct EmptyRelation { /// From the [Postgres Docs]: /// /// 1. Evaluate the non-recursive term. For `UNION` (but not `UNION ALL`), -/// discard duplicate rows. Include all remaining rows in the result of the -/// recursive query, and also place them in a temporary working table. -// +/// discard duplicate rows. Include all remaining rows in the result of the +/// recursive query, and also place them in a temporary working table. +/// /// 2. So long as the working table is not empty, repeat these steps: /// /// * Evaluate the recursive term, substituting the current contents of the -/// working table for the recursive self-reference. For `UNION` (but not `UNION -/// ALL`), discard duplicate rows and rows that duplicate any previous result -/// row. Include all remaining rows in the result of the recursive query, and -/// also place them in a temporary intermediate table. +/// working table for the recursive self-reference. For `UNION` (but not `UNION +/// ALL`), discard duplicate rows and rows that duplicate any previous result +/// row. Include all remaining rows in the result of the recursive query, and +/// also place them in a temporary intermediate table. /// /// * Replace the contents of the working table with the contents of the -/// intermediate table, then empty the intermediate table. +/// intermediate table, then empty the intermediate table. /// /// [Postgres Docs]: https://www.postgresql.org/docs/current/queries-with.html#QUERIES-WITH-RECURSIVE -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub struct RecursiveQuery { /// Name of the query pub name: String, @@ -1611,7 +2048,7 @@ pub struct RecursiveQuery { /// Values expression. See /// [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html) /// documentation for more details. -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Values { /// The table schema pub schema: DFSchemaRef, @@ -1619,6 +2056,13 @@ pub struct Values { pub values: Vec>, } +// Manual implementation needed because of `schema` field. Comparison excludes this field. +impl PartialOrd for Values { + fn partial_cmp(&self, other: &Self) -> Option { + self.values.partial_cmp(&other.values) + } +} + /// Evaluates an arbitrary list of expressions (essentially a /// SELECT with an expression list) on its input. #[derive(Clone, PartialEq, Eq, Hash, Debug)] @@ -1633,6 +2077,16 @@ pub struct Projection { pub schema: DFSchemaRef, } +// Manual implementation needed because of `schema` field. Comparison excludes this field. +impl PartialOrd for Projection { + fn partial_cmp(&self, other: &Self) -> Option { + match self.expr.partial_cmp(&other.expr) { + Some(Ordering::Equal) => self.input.partial_cmp(&other.input), + cmp => cmp, + } + } +} + impl Projection { /// Create a new Projection pub fn try_new(expr: Vec, input: Arc) -> Result { @@ -1646,7 +2100,9 @@ impl Projection { input: Arc, schema: DFSchemaRef, ) -> Result { - if expr.len() != schema.fields().len() { + if !expr.iter().any(|e| matches!(e, Expr::Wildcard { .. })) + && expr.len() != schema.fields().len() + { return plan_err!("Projection has mismatch between number of expressions ({}) and number of fields in schema ({})", expr.len(), schema.fields().len()); } Ok(Self { @@ -1672,7 +2128,7 @@ impl Projection { /// # Arguments /// /// * `input`: A reference to the input `LogicalPlan` for which the projection schema -/// will be computed. +/// will be computed. /// * `exprs`: A slice of `Expr` expressions representing the projection operation to apply. /// /// # Returns @@ -1681,18 +2137,19 @@ impl Projection { /// produced by the projection operation. If the schema computation is successful, /// the `Result` will contain the schema; otherwise, it will contain an error. pub fn projection_schema(input: &LogicalPlan, exprs: &[Expr]) -> Result> { - let mut schema = DFSchema::new_with_metadata( - exprlist_to_fields(exprs, input)?, - input.schema().metadata().clone(), - )?; - schema = schema.with_functional_dependencies(calc_func_dependencies_for_project( - exprs, input, - )?)?; + let metadata = input.schema().metadata().clone(); + + let schema = + DFSchema::new_with_metadata(exprlist_to_fields(exprs, input)?, metadata)? + .with_functional_dependencies(calc_func_dependencies_for_project( + exprs, input, + )?)?; + Ok(Arc::new(schema)) } /// Aliased subquery -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] // mark non_exhaustive to encourage use of try_new/new() #[non_exhaustive] pub struct SubqueryAlias { @@ -1713,7 +2170,7 @@ impl SubqueryAlias { let fields = change_redundant_column(plan.schema().fields()); let meta_data = plan.schema().as_ref().metadata().clone(); let schema: Schema = - DFSchema::from_unqualifed_fields(fields.into(), meta_data)?.into(); + DFSchema::from_unqualified_fields(fields.into(), meta_data)?.into(); // Since schema is the same, other than qualifier, we can use existing // functional dependencies: let func_dependencies = plan.schema().functional_dependencies().clone(); @@ -1729,6 +2186,16 @@ impl SubqueryAlias { } } +// Manual implementation needed because of `schema` field. Comparison excludes this field. +impl PartialOrd for SubqueryAlias { + fn partial_cmp(&self, other: &Self) -> Option { + match self.input.partial_cmp(&other.input) { + Some(Ordering::Equal) => self.alias.partial_cmp(&other.alias), + cmp => cmp, + } + } +} + /// Filters rows from its input that do not match an /// expression (essentially a WHERE clause with a predicate /// expression). @@ -1740,40 +2207,65 @@ impl SubqueryAlias { /// /// Filter should not be created directly but instead use `try_new()` /// and that these fields are only pub to support pattern matching -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] #[non_exhaustive] pub struct Filter { /// The predicate expression, which must have Boolean type. pub predicate: Expr, /// The incoming logical plan pub input: Arc, + /// The flag to indicate if the filter is a having clause + pub having: bool, } impl Filter { /// Create a new filter operator. + /// + /// Notes: as Aliases have no effect on the output of a filter operator, + /// they are removed from the predicate expression. pub fn try_new(predicate: Expr, input: Arc) -> Result { + Self::try_new_internal(predicate, input, false) + } + + /// Create a new filter operator for a having clause. + /// This is similar to a filter, but its having flag is set to true. + pub fn try_new_with_having(predicate: Expr, input: Arc) -> Result { + Self::try_new_internal(predicate, input, true) + } + + fn is_allowed_filter_type(data_type: &DataType) -> bool { + match data_type { + // Interpret NULL as a missing boolean value. + DataType::Boolean | DataType::Null => true, + DataType::Dictionary(_, value_type) => { + Filter::is_allowed_filter_type(value_type.as_ref()) + } + _ => false, + } + } + + fn try_new_internal( + predicate: Expr, + input: Arc, + having: bool, + ) -> Result { // Filter predicates must return a boolean value so we try and validate that here. // Note that it is not always possible to resolve the predicate expression during plan // construction (such as with correlated subqueries) so we make a best effort here and // ignore errors resolving the expression against the schema. if let Ok(predicate_type) = predicate.get_type(input.schema()) { - if predicate_type != DataType::Boolean { + if !Filter::is_allowed_filter_type(&predicate_type) { return plan_err!( "Cannot create filter with non-boolean predicate '{predicate}' returning {predicate_type}" ); } } - // filter predicates should not be aliased - if let Expr::Alias(Alias { expr, name, .. }) = predicate { - return plan_err!( - "Attempted to create Filter predicate with \ - expression `{expr}` aliased as '{name}'. Filter predicates should not be \ - aliased." - ); - } - - Ok(Self { predicate, input }) + Ok(Self { + predicate: predicate.unalias_nested().data, + input, + having, + }) } /// Is this filter guaranteed to return 0 or 1 row in a given instantiation? @@ -1845,7 +2337,7 @@ impl Filter { } /// Window its input based on a set of window spec and window function (e.g. SUM or RANK) -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Window { /// The incoming logical plan pub input: Arc, @@ -1861,7 +2353,7 @@ impl Window { let fields: Vec<(Option, Arc)> = input .schema() .iter() - .map(|(q, f)| (q.cloned(), f.clone())) + .map(|(q, f)| (q.cloned(), Arc::clone(f))) .collect(); let input_len = fields.len(); let mut window_fields = fields; @@ -1882,18 +2374,14 @@ impl Window { .enumerate() .filter_map(|(idx, expr)| { if let Expr::WindowFunction(WindowFunction { - // Function is ROW_NUMBER - fun: - WindowFunctionDefinition::BuiltInWindowFunction( - BuiltInWindowFunction::RowNumber, - ), + fun: WindowFunctionDefinition::WindowUDF(udwf), partition_by, .. }) = expr { // When there is no PARTITION BY, row number will be unique // across the entire table. - if partition_by.is_empty() { + if udwf.name() == "row_number" && partition_by.is_empty() { return Some(idx + input_len); } } @@ -1914,17 +2402,47 @@ impl Window { window_func_dependencies.extend(new_deps); } - Ok(Window { - input, + Self::try_new_with_schema( window_expr, - schema: Arc::new( + input, + Arc::new( DFSchema::new_with_metadata(window_fields, metadata)? .with_functional_dependencies(window_func_dependencies)?, ), + ) + } + + pub fn try_new_with_schema( + window_expr: Vec, + input: Arc, + schema: DFSchemaRef, + ) -> Result { + if window_expr.len() != schema.fields().len() - input.schema().fields().len() { + return plan_err!( + "Window has mismatch between number of expressions ({}) and number of fields in schema ({})", + window_expr.len(), + schema.fields().len() - input.schema().fields().len() + ); + } + + Ok(Window { + input, + window_expr, + schema, }) } } +// Manual implementation needed because of `schema` field. Comparison excludes this field. +impl PartialOrd for Window { + fn partial_cmp(&self, other: &Self) -> Option { + match self.input.partial_cmp(&other.input) { + Some(Ordering::Equal) => self.window_expr.partial_cmp(&other.window_expr), + cmp => cmp, + } + } +} + /// Produces rows from a table provider by reference or from the context #[derive(Clone)] pub struct TableScan { @@ -1942,6 +2460,19 @@ pub struct TableScan { pub fetch: Option, } +impl Debug for TableScan { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("TableScan") + .field("table_name", &self.table_name) + .field("source", &"...") + .field("projection", &self.projection) + .field("projected_schema", &self.projected_schema) + .field("filters", &self.filters) + .field("fetch", &self.fetch) + .finish_non_exhaustive() + } +} + impl PartialEq for TableScan { fn eq(&self, other: &Self) -> bool { self.table_name == other.table_name @@ -1954,6 +2485,37 @@ impl PartialEq for TableScan { impl Eq for TableScan {} +// Manual implementation needed because of `source` and `projected_schema` fields. +// Comparison excludes these field. +impl PartialOrd for TableScan { + fn partial_cmp(&self, other: &Self) -> Option { + #[derive(PartialEq, PartialOrd)] + struct ComparableTableScan<'a> { + /// The name of the table + pub table_name: &'a TableReference, + /// Optional column indices to use as a projection + pub projection: &'a Option>, + /// Optional expressions to be used as filters by the table provider + pub filters: &'a Vec, + /// Optional number of rows to read + pub fetch: &'a Option, + } + let comparable_self = ComparableTableScan { + table_name: &self.table_name, + projection: &self.projection, + filters: &self.filters, + fetch: &self.fetch, + }; + let comparable_other = ComparableTableScan { + table_name: &other.table_name, + projection: &other.projection, + filters: &other.filters, + fetch: &other.fetch, + }; + comparable_self.partial_cmp(&comparable_other) + } +} + impl Hash for TableScan { fn hash(&self, state: &mut H) { self.table_name.hash(state); @@ -2018,19 +2580,8 @@ impl TableScan { } } -/// Apply Cross Join to two logical plans -#[derive(Clone, PartialEq, Eq, Hash)] -pub struct CrossJoin { - /// Left input - pub left: Arc, - /// Right input - pub right: Arc, - /// The output schema, containing fields from the left and right inputs - pub schema: DFSchemaRef, -} - -/// Repartition the plan based on a partitioning scheme. -#[derive(Clone, PartialEq, Eq, Hash)] +// Repartition the plan based on a partitioning scheme. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub struct Repartition { /// The incoming logical plan pub input: Arc, @@ -2039,7 +2590,7 @@ pub struct Repartition { } /// Union multiple inputs -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Union { /// Inputs to merge pub inputs: Vec>, @@ -2047,9 +2598,16 @@ pub struct Union { pub schema: DFSchemaRef, } +// Manual implementation needed because of `schema` field. Comparison excludes this field. +impl PartialOrd for Union { + fn partial_cmp(&self, other: &Self) -> Option { + self.inputs.partial_cmp(&other.inputs) + } +} + /// Prepare a statement but do not execute it. Prepare statements can have 0 or more /// `Expr::Placeholder` expressions that are filled in during execution -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub struct Prepare { /// The name of the statement pub name: String, @@ -2059,6 +2617,27 @@ pub struct Prepare { pub input: Arc, } +/// Execute a prepared statement. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Execute { + /// The name of the prepared statement to execute + pub name: String, + /// The execute parameters + pub parameters: Vec, + /// Dummy schema + pub schema: DFSchemaRef, +} + +// Comparison excludes the `schema` field. +impl PartialOrd for Execute { + fn partial_cmp(&self, other: &Self) -> Option { + match self.name.partial_cmp(&other.name) { + Some(Ordering::Equal) => self.parameters.partial_cmp(&other.parameters), + cmp => cmp, + } + } +} + /// Describe the schema of table /// /// # Example output: @@ -2081,7 +2660,7 @@ pub struct Prepare { /// | parent_span_id | Utf8 | YES | /// +--------------------+-----------------------------+-------------+ /// ``` -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct DescribeTable { /// Table schema pub schema: Arc, @@ -2089,9 +2668,18 @@ pub struct DescribeTable { pub output_schema: DFSchemaRef, } +// Manual implementation of `PartialOrd`, returning none since there are no comparable types in +// `DescribeTable`. This allows `LogicalPlan` to derive `PartialOrd`. +impl PartialOrd for DescribeTable { + fn partial_cmp(&self, _other: &Self) -> Option { + // There is no relevant comparison for schemas + None + } +} + /// Produces a relation with string representations of /// various parts of the plan -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Explain { /// Should extra (detailed, intermediate plans) be included? pub verbose: bool, @@ -2105,9 +2693,39 @@ pub struct Explain { pub logical_optimization_succeeded: bool, } +// Manual implementation needed because of `schema` field. Comparison excludes this field. +impl PartialOrd for Explain { + fn partial_cmp(&self, other: &Self) -> Option { + #[derive(PartialEq, PartialOrd)] + struct ComparableExplain<'a> { + /// Should extra (detailed, intermediate plans) be included? + pub verbose: &'a bool, + /// The logical plan that is being EXPLAIN'd + pub plan: &'a Arc, + /// Represent the various stages plans have gone through + pub stringified_plans: &'a Vec, + /// Used by physical planner to check if should proceed with planning + pub logical_optimization_succeeded: &'a bool, + } + let comparable_self = ComparableExplain { + verbose: &self.verbose, + plan: &self.plan, + stringified_plans: &self.stringified_plans, + logical_optimization_succeeded: &self.logical_optimization_succeeded, + }; + let comparable_other = ComparableExplain { + verbose: &other.verbose, + plan: &other.plan, + stringified_plans: &other.stringified_plans, + logical_optimization_succeeded: &other.logical_optimization_succeeded, + }; + comparable_self.partial_cmp(&comparable_other) + } +} + /// Runs the actual plan, and then prints the physical plan with /// with execution metrics. -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Analyze { /// Should extra detail be included? pub verbose: bool, @@ -2117,12 +2735,22 @@ pub struct Analyze { pub schema: DFSchemaRef, } +// Manual implementation needed because of `schema` field. Comparison excludes this field. +impl PartialOrd for Analyze { + fn partial_cmp(&self, other: &Self) -> Option { + match self.verbose.partial_cmp(&other.verbose) { + Some(Ordering::Equal) => self.input.partial_cmp(&other.input), + cmp => cmp, + } + } +} + /// Extension operator defined outside of DataFusion // TODO(clippy): This clippy `allow` should be removed if // the manual `PartialEq` is removed in favor of a derive. // (see `PartialEq` the impl for details.) #[allow(clippy::derived_hash_with_manual_eq)] -#[derive(Clone, Eq, Hash)] +#[derive(Debug, Clone, Eq, Hash)] pub struct Extension { /// The runtime extension operator pub node: Arc, @@ -2137,20 +2765,83 @@ impl PartialEq for Extension { } } +impl PartialOrd for Extension { + fn partial_cmp(&self, other: &Self) -> Option { + self.node.partial_cmp(&other.node) + } +} + /// Produces the first `n` tuples from its input and discards the rest. -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub struct Limit { /// Number of rows to skip before fetch - pub skip: usize, + pub skip: Option>, /// Maximum number of rows to fetch, /// None means fetching all rows - pub fetch: Option, + pub fetch: Option>, /// The logical plan pub input: Arc, } +/// Different types of skip expression in Limit plan. +pub enum SkipType { + /// The skip expression is a literal value. + Literal(usize), + /// Currently only supports expressions that can be folded into constants. + UnsupportedExpr, +} + +/// Different types of fetch expression in Limit plan. +pub enum FetchType { + /// The fetch expression is a literal value. + /// `Literal(None)` means the fetch expression is not provided. + Literal(Option), + /// Currently only supports expressions that can be folded into constants. + UnsupportedExpr, +} + +impl Limit { + /// Get the skip type from the limit plan. + pub fn get_skip_type(&self) -> Result { + match self.skip.as_deref() { + Some(expr) => match *expr { + Expr::Literal(ScalarValue::Int64(s)) => { + // `skip = NULL` is equivalent to `skip = 0` + let s = s.unwrap_or(0); + if s >= 0 { + Ok(SkipType::Literal(s as usize)) + } else { + plan_err!("OFFSET must be >=0, '{}' was provided", s) + } + } + _ => Ok(SkipType::UnsupportedExpr), + }, + // `skip = None` is equivalent to `skip = 0` + None => Ok(SkipType::Literal(0)), + } + } + + /// Get the fetch type from the limit plan. + pub fn get_fetch_type(&self) -> Result { + match self.fetch.as_deref() { + Some(expr) => match *expr { + Expr::Literal(ScalarValue::Int64(Some(s))) => { + if s >= 0 { + Ok(FetchType::Literal(Some(s as usize))) + } else { + plan_err!("LIMIT must be >= 0, '{}' was provided", s) + } + } + Expr::Literal(ScalarValue::Int64(None)) => Ok(FetchType::Literal(None)), + _ => Ok(FetchType::UnsupportedExpr), + }, + None => Ok(FetchType::Literal(None)), + } + } +} + /// Removes duplicate rows from the input -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum Distinct { /// Plain `DISTINCT` referencing all selection expressions All(Arc), @@ -2158,8 +2849,18 @@ pub enum Distinct { On(DistinctOn), } +impl Distinct { + /// return a reference to the nodes input + pub fn input(&self) -> &Arc { + match self { + Distinct::All(input) => input, + Distinct::On(DistinctOn { input, .. }) => input, + } + } +} + /// Removes duplicate rows from the input -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct DistinctOn { /// The `DISTINCT ON` clause expression list pub on_expr: Vec, @@ -2168,7 +2869,7 @@ pub struct DistinctOn { /// The `ORDER BY` clause, whose initial expressions must match those of the `ON` clause when /// present. Note that those matching expressions actually wrap the `ON` expressions with /// additional info pertaining to the sorting procedure (i.e. ASC/DESC, and NULLS FIRST/LAST). - pub sort_expr: Option>, + pub sort_expr: Option>, /// The logical plan that is being DISTINCT'd pub input: Arc, /// The schema description of the DISTINCT ON output @@ -2180,7 +2881,7 @@ impl DistinctOn { pub fn try_new( on_expr: Vec, select_expr: Vec, - sort_expr: Option>, + sort_expr: Option>, input: Arc, ) -> Result { if on_expr.is_empty() { @@ -2215,20 +2916,15 @@ impl DistinctOn { /// Try to update `self` with a new sort expressions. /// /// Validates that the sort expressions are a super-set of the `ON` expressions. - pub fn with_sort_expr(mut self, sort_expr: Vec) -> Result { - let sort_expr = normalize_cols(sort_expr, self.input.as_ref())?; + pub fn with_sort_expr(mut self, sort_expr: Vec) -> Result { + let sort_expr = normalize_sorts(sort_expr, self.input.as_ref())?; // Check that the left-most sort expressions are the same as the `ON` expressions. let mut matched = true; for (on, sort) in self.on_expr.iter().zip(sort_expr.iter()) { - match sort { - Expr::Sort(SortExpr { expr, .. }) => { - if on != &**expr { - matched = false; - break; - } - } - _ => return plan_err!("Not a sort expression: {sort}"), + if on != &sort.expr { + matched = false; + break; } } @@ -2243,6 +2939,38 @@ impl DistinctOn { } } +// Manual implementation needed because of `schema` field. Comparison excludes this field. +impl PartialOrd for DistinctOn { + fn partial_cmp(&self, other: &Self) -> Option { + #[derive(PartialEq, PartialOrd)] + struct ComparableDistinctOn<'a> { + /// The `DISTINCT ON` clause expression list + pub on_expr: &'a Vec, + /// The selected projection expression list + pub select_expr: &'a Vec, + /// The `ORDER BY` clause, whose initial expressions must match those of the `ON` clause when + /// present. Note that those matching expressions actually wrap the `ON` expressions with + /// additional info pertaining to the sorting procedure (i.e. ASC/DESC, and NULLS FIRST/LAST). + pub sort_expr: &'a Option>, + /// The logical plan that is being DISTINCT'd + pub input: &'a Arc, + } + let comparable_self = ComparableDistinctOn { + on_expr: &self.on_expr, + select_expr: &self.select_expr, + sort_expr: &self.sort_expr, + input: &self.input, + }; + let comparable_other = ComparableDistinctOn { + on_expr: &other.on_expr, + select_expr: &other.select_expr, + sort_expr: &other.sort_expr, + input: &other.input, + }; + comparable_self.partial_cmp(&comparable_other) + } +} + /// Aggregates its input based on a set of grouping and aggregate /// expressions (e.g. SUM). #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -2270,9 +2998,9 @@ impl Aggregate { let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]); - let grouping_expr: Vec = grouping_set_to_exprlist(group_expr.as_slice())?; + let grouping_expr: Vec<&Expr> = grouping_set_to_exprlist(group_expr.as_slice())?; - let mut qualified_fields = exprlist_to_fields(grouping_expr.as_slice(), &input)?; + let mut qualified_fields = exprlist_to_fields(grouping_expr, &input)?; // Even columns that cannot be null will become nullable when used in a grouping set. if is_grouping_set { @@ -2280,11 +3008,23 @@ impl Aggregate { .into_iter() .map(|(q, f)| (q, f.as_ref().clone().with_nullable(true).into())) .collect::>(); + qualified_fields.push(( + None, + Field::new( + Self::INTERNAL_GROUPING_ID, + Self::grouping_id_type(qualified_fields.len()), + false, + ) + .into(), + )); } qualified_fields.extend(exprlist_to_fields(aggr_expr.as_slice(), &input)?); - let schema = DFSchema::new_with_metadata(qualified_fields, HashMap::new())?; + let schema = DFSchema::new_with_metadata( + qualified_fields, + input.schema().metadata().clone(), + )?; Self::try_new_with_schema(input, group_expr, aggr_expr, Arc::new(schema)) } @@ -2328,12 +3068,80 @@ impl Aggregate { }) } + fn is_grouping_set(&self) -> bool { + matches!(self.group_expr.as_slice(), [Expr::GroupingSet(_)]) + } + + /// Get the output expressions. + fn output_expressions(&self) -> Result> { + static INTERNAL_ID_EXPR: OnceLock = OnceLock::new(); + let mut exprs = grouping_set_to_exprlist(self.group_expr.as_slice())?; + if self.is_grouping_set() { + exprs.push(INTERNAL_ID_EXPR.get_or_init(|| { + Expr::Column(Column::from_name(Self::INTERNAL_GROUPING_ID)) + })); + } + exprs.extend(self.aggr_expr.iter()); + debug_assert!(exprs.len() == self.schema.fields().len()); + Ok(exprs) + } + /// Get the length of the group by expression in the output schema /// This is not simply group by expression length. Expression may be /// GroupingSet, etc. In these case we need to get inner expression lengths. pub fn group_expr_len(&self) -> Result { grouping_set_expr_count(&self.group_expr) } + + /// Returns the data type of the grouping id. + /// The grouping ID value is a bitmask where each set bit + /// indicates that the corresponding grouping expression is + /// null + pub fn grouping_id_type(group_exprs: usize) -> DataType { + if group_exprs <= 8 { + DataType::UInt8 + } else if group_exprs <= 16 { + DataType::UInt16 + } else if group_exprs <= 32 { + DataType::UInt32 + } else { + DataType::UInt64 + } + } + + /// Internal column used when the aggregation is a grouping set. + /// + /// This column contains a bitmask where each bit represents a grouping + /// expression. The least significant bit corresponds to the rightmost + /// grouping expression. A bit value of 0 indicates that the corresponding + /// column is included in the grouping set, while a value of 1 means it is excluded. + /// + /// For example, for the grouping expressions CUBE(a, b), the grouping ID + /// column will have the following values: + /// 0b00: Both `a` and `b` are included + /// 0b01: `b` is excluded + /// 0b10: `a` is excluded + /// 0b11: Both `a` and `b` are excluded + /// + /// This internal column is necessary because excluded columns are replaced + /// with `NULL` values. To handle these cases correctly, we must distinguish + /// between an actual `NULL` value in a column and a column being excluded from the set. + pub const INTERNAL_GROUPING_ID: &'static str = "__grouping_id"; +} + +// Manual implementation needed because of `schema` field. Comparison excludes this field. +impl PartialOrd for Aggregate { + fn partial_cmp(&self, other: &Self) -> Option { + match self.input.partial_cmp(&other.input) { + Some(Ordering::Equal) => { + match self.group_expr.partial_cmp(&other.group_expr) { + Some(Ordering::Equal) => self.aggr_expr.partial_cmp(&other.aggr_expr), + cmp => cmp, + } + } + cmp => cmp, + } + } } /// Checks whether any expression in `group_expr` contains `Expr::GroupingSet`. @@ -2360,8 +3168,10 @@ fn calc_func_dependencies_for_aggregate( if !contains_grouping_set(group_expr) { let group_by_expr_names = group_expr .iter() - .map(|item| item.display_name()) - .collect::>>()?; + .map(|item| item.schema_name().to_string()) + .collect::>() + .into_iter() + .collect::>(); let aggregate_func_dependencies = aggregate_functional_dependencies( input.schema(), &group_by_expr_names, @@ -2383,27 +3193,61 @@ fn calc_func_dependencies_for_project( // Calculate expression indices (if present) in the input schema. let proj_indices = exprs .iter() - .filter_map(|expr| { - let expr_name = match expr { - Expr::Alias(alias) => { - format!("{}", alias.expr) - } - _ => format!("{}", expr), - }; - input_fields.iter().position(|item| *item == expr_name) + .map(|expr| match expr { + Expr::Wildcard { qualifier, options } => { + let wildcard_fields = exprlist_to_fields( + vec![&Expr::Wildcard { + qualifier: qualifier.clone(), + options: options.clone(), + }], + input, + )?; + Ok::<_, DataFusionError>( + wildcard_fields + .into_iter() + .filter_map(|(qualifier, f)| { + let flat_name = qualifier + .map(|t| format!("{}.{}", t, f.name())) + .unwrap_or_else(|| f.name().clone()); + input_fields.iter().position(|item| *item == flat_name) + }) + .collect::>(), + ) + } + Expr::Alias(alias) => { + let name = format!("{}", alias.expr); + Ok(input_fields + .iter() + .position(|item| *item == name) + .map(|i| vec![i]) + .unwrap_or(vec![])) + } + _ => { + let name = format!("{}", expr); + Ok(input_fields + .iter() + .position(|item| *item == name) + .map(|i| vec![i]) + .unwrap_or(vec![])) + } }) + .collect::>>()? + .into_iter() + .flatten() .collect::>(); + + let len = exprlist_len(exprs, input.schema(), Some(find_base_plan(input).schema()))?; Ok(input .schema() .functional_dependencies() - .project_functional_dependencies(&proj_indices, exprs.len())) + .project_functional_dependencies(&proj_indices, len)) } /// Sorts its input according to a list of sort expressions. -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub struct Sort { /// The sort expressions - pub expr: Vec, + pub expr: Vec, /// The incoming logical plan pub input: Arc, /// Optional fetch limit @@ -2411,7 +3255,7 @@ pub struct Sort { } /// Join two logical plans on one or more join columns -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Join { /// Left input pub left: Arc, @@ -2466,8 +3310,50 @@ impl Join { } } +// Manual implementation needed because of `schema` field. Comparison excludes this field. +impl PartialOrd for Join { + fn partial_cmp(&self, other: &Self) -> Option { + #[derive(PartialEq, PartialOrd)] + struct ComparableJoin<'a> { + /// Left input + pub left: &'a Arc, + /// Right input + pub right: &'a Arc, + /// Equijoin clause expressed as pairs of (left, right) join expressions + pub on: &'a Vec<(Expr, Expr)>, + /// Filters applied during join (non-equi conditions) + pub filter: &'a Option, + /// Join type + pub join_type: &'a JoinType, + /// Join constraint + pub join_constraint: &'a JoinConstraint, + /// If null_equals_null is true, null == null else null != null + pub null_equals_null: &'a bool, + } + let comparable_self = ComparableJoin { + left: &self.left, + right: &self.right, + on: &self.on, + filter: &self.filter, + join_type: &self.join_type, + join_constraint: &self.join_constraint, + null_equals_null: &self.null_equals_null, + }; + let comparable_other = ComparableJoin { + left: &other.left, + right: &other.right, + on: &other.on, + filter: &other.filter, + join_type: &other.join_type, + join_constraint: &other.join_constraint, + null_equals_null: &other.null_equals_null, + }; + comparable_self.partial_cmp(&comparable_other) + } +} + /// Subquery -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash)] pub struct Subquery { /// The subquery pub subquery: Arc, @@ -2498,8 +3384,12 @@ impl Debug for Subquery { } } -/// Logical partitioning schemes supported by the repartition operator. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +/// Logical partitioning schemes supported by [`LogicalPlan::Repartition`] +/// +/// See [`Partitioning`] for more details on partitioning +/// +/// [`Partitioning`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/enum.Partitioning.html# +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum Partitioning { /// Allocate batches using a round-robin algorithm and the specified number of partitions RoundRobinBatch(usize), @@ -2510,31 +3400,114 @@ pub enum Partitioning { DistributeBy(Vec), } +/// Represent the unnesting operation on a list column, such as the recursion depth and +/// the output column name after unnesting +/// +/// Example: given `ColumnUnnestList { output_column: "output_name", depth: 2 }` +/// +/// ```text +/// input output_name +/// ┌─────────┐ ┌─────────┐ +/// │{{1,2}} │ │ 1 │ +/// ├─────────┼─────►├─────────┤ +/// │{{3}} │ │ 2 │ +/// ├─────────┤ ├─────────┤ +/// │{{4},{5}}│ │ 3 │ +/// └─────────┘ ├─────────┤ +/// │ 4 │ +/// ├─────────┤ +/// │ 5 │ +/// └─────────┘ +/// ``` +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd)] +pub struct ColumnUnnestList { + pub output_column: Column, + pub depth: usize, +} + +impl Display for ColumnUnnestList { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}|depth={}", self.output_column, self.depth) + } +} + /// Unnest a column that contains a nested list type. See /// [`UnnestOptions`] for more details. #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Unnest { /// The incoming logical plan pub input: Arc, - /// The columns to unnest - pub columns: Vec, + /// Columns to run unnest on, can be a list of (List/Struct) columns + pub exec_columns: Vec, + /// refer to the indices(in the input schema) of columns + /// that have type list to run unnest on + pub list_type_columns: Vec<(usize, ColumnUnnestList)>, + /// refer to the indices (in the input schema) of columns + /// that have type struct to run unnest on + pub struct_type_columns: Vec, + /// Having items aligned with the output columns + /// representing which column in the input schema each output column depends on + pub dependency_indices: Vec, /// The output schema, containing the unnested field column. pub schema: DFSchemaRef, /// Options pub options: UnnestOptions, } +// Manual implementation needed because of `schema` field. Comparison excludes this field. +impl PartialOrd for Unnest { + fn partial_cmp(&self, other: &Self) -> Option { + #[derive(PartialEq, PartialOrd)] + struct ComparableUnnest<'a> { + /// The incoming logical plan + pub input: &'a Arc, + /// Columns to run unnest on, can be a list of (List/Struct) columns + pub exec_columns: &'a Vec, + /// refer to the indices(in the input schema) of columns + /// that have type list to run unnest on + pub list_type_columns: &'a Vec<(usize, ColumnUnnestList)>, + /// refer to the indices (in the input schema) of columns + /// that have type struct to run unnest on + pub struct_type_columns: &'a Vec, + /// Having items aligned with the output columns + /// representing which column in the input schema each output column depends on + pub dependency_indices: &'a Vec, + /// Options + pub options: &'a UnnestOptions, + } + let comparable_self = ComparableUnnest { + input: &self.input, + exec_columns: &self.exec_columns, + list_type_columns: &self.list_type_columns, + struct_type_columns: &self.struct_type_columns, + dependency_indices: &self.dependency_indices, + options: &self.options, + }; + let comparable_other = ComparableUnnest { + input: &other.input, + exec_columns: &other.exec_columns, + list_type_columns: &other.list_type_columns, + struct_type_columns: &other.struct_type_columns, + dependency_indices: &other.dependency_indices, + options: &other.options, + }; + comparable_self.partial_cmp(&comparable_other) + } +} + #[cfg(test)] mod tests { use super::*; use crate::builder::LogicalTableSource; use crate::logical_plan::table_scan; - use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet}; + use crate::{col, exists, in_subquery, lit, placeholder, GroupingSet}; - use datafusion_common::tree_node::TreeNodeVisitor; + use datafusion_common::tree_node::{TransformedResult, TreeNodeVisitor}; use datafusion_common::{not_impl_err, Constraint, ScalarValue}; + use crate::test::function_stub::count; + fn employee_schema() -> Schema { Schema::new(vec![ Field::new("id", DataType::Int32, false), @@ -2717,10 +3690,10 @@ digraph { strings: Vec, } - impl TreeNodeVisitor for OkVisitor { + impl<'n> TreeNodeVisitor<'n> for OkVisitor { type Node = LogicalPlan; - fn f_down(&mut self, plan: &LogicalPlan) -> Result { + fn f_down(&mut self, plan: &'n LogicalPlan) -> Result { let s = match plan { LogicalPlan::Projection { .. } => "pre_visit Projection", LogicalPlan::Filter { .. } => "pre_visit Filter", @@ -2734,7 +3707,7 @@ digraph { Ok(TreeNodeRecursion::Continue) } - fn f_up(&mut self, plan: &LogicalPlan) -> Result { + fn f_up(&mut self, plan: &'n LogicalPlan) -> Result { let s = match plan { LogicalPlan::Projection { .. } => "post_visit Projection", LogicalPlan::Filter { .. } => "post_visit Filter", @@ -2800,10 +3773,10 @@ digraph { return_false_from_post_in: OptionalCounter, } - impl TreeNodeVisitor for StoppingVisitor { + impl<'n> TreeNodeVisitor<'n> for StoppingVisitor { type Node = LogicalPlan; - fn f_down(&mut self, plan: &LogicalPlan) -> Result { + fn f_down(&mut self, plan: &'n LogicalPlan) -> Result { if self.return_false_from_pre_in.dec() { return Ok(TreeNodeRecursion::Stop); } @@ -2812,7 +3785,7 @@ digraph { Ok(TreeNodeRecursion::Continue) } - fn f_up(&mut self, plan: &LogicalPlan) -> Result { + fn f_up(&mut self, plan: &'n LogicalPlan) -> Result { if self.return_false_from_post_in.dec() { return Ok(TreeNodeRecursion::Stop); } @@ -2869,10 +3842,10 @@ digraph { return_error_from_post_in: OptionalCounter, } - impl TreeNodeVisitor for ErrorVisitor { + impl<'n> TreeNodeVisitor<'n> for ErrorVisitor { type Node = LogicalPlan; - fn f_down(&mut self, plan: &LogicalPlan) -> Result { + fn f_down(&mut self, plan: &'n LogicalPlan) -> Result { if self.return_error_from_pre_in.dec() { return not_impl_err!("Error in pre_visit"); } @@ -2880,7 +3853,7 @@ digraph { self.inner.f_down(plan) } - fn f_up(&mut self, plan: &LogicalPlan) -> Result { + fn f_up(&mut self, plan: &'n LogicalPlan) -> Result { if self.return_error_from_post_in.dec() { return not_impl_err!("Error in post_visit"); } @@ -2937,7 +3910,7 @@ digraph { vec![col("a")], Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, - schema: empty_schema.clone(), + schema: Arc::clone(&empty_schema), })), empty_schema, ); @@ -2961,54 +3934,6 @@ digraph { .unwrap() } - /// Extension plan that panic when trying to access its input plan - #[derive(Debug)] - struct NoChildExtension { - empty_schema: DFSchemaRef, - } - - impl UserDefinedLogicalNode for NoChildExtension { - fn as_any(&self) -> &dyn std::any::Any { - unimplemented!() - } - - fn name(&self) -> &str { - unimplemented!() - } - - fn inputs(&self) -> Vec<&LogicalPlan> { - panic!("Should not be called") - } - - fn schema(&self) -> &DFSchemaRef { - &self.empty_schema - } - - fn expressions(&self) -> Vec { - unimplemented!() - } - - fn fmt_for_explain(&self, _: &mut fmt::Formatter) -> fmt::Result { - unimplemented!() - } - - fn from_template( - &self, - _: &[Expr], - _: &[LogicalPlan], - ) -> Arc { - unimplemented!() - } - - fn dyn_hash(&self, _: &mut dyn Hasher) { - unimplemented!() - } - - fn dyn_eq(&self, _: &dyn UserDefinedLogicalNode) -> bool { - unimplemented!() - } - } - #[test] fn test_replace_invalid_placeholder() { // test empty placeholder @@ -3100,9 +4025,9 @@ digraph { ); let scan = Arc::new(LogicalPlan::TableScan(TableScan { table_name: TableReference::bare("tab"), - source: source.clone(), + source: Arc::clone(&source) as Arc, projection: None, - projected_schema: schema.clone(), + projected_schema: Arc::clone(&schema), filters: vec![], fetch: None, })); @@ -3132,17 +4057,14 @@ digraph { table_name: TableReference::bare("tab"), source, projection: None, - projected_schema: unique_schema.clone(), + projected_schema: Arc::clone(&unique_schema), filters: vec![], fetch: None, })); let col = schema.field_names()[0].clone(); - let filter = Filter::try_new( - Expr::Column(col.into()).eq(Expr::Literal(ScalarValue::Int32(Some(1)))), - scan, - ) - .unwrap(); + let filter = + Filter::try_new(Expr::Column(col.into()).eq(lit(1i32)), scan).unwrap(); assert!(filter.is_scalar()); } @@ -3160,8 +4082,7 @@ digraph { .build() .unwrap(); - let external_filter = - col("foo").eq(Expr::Literal(ScalarValue::Boolean(Some(true)))); + let external_filter = col("foo").eq(lit(true)); // after transformation, because plan is not the same anymore, // the parent plan is built again with call to LogicalPlan::with_new_inputs -> with_new_exprs @@ -3186,4 +4107,61 @@ digraph { let actual = format!("{}", plan.display_indent()); assert_eq!(expected.to_string(), actual) } + + #[test] + fn test_plan_partial_ord() { + let empty_relation = LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::new(DFSchema::empty()), + }); + + let describe_table = LogicalPlan::DescribeTable(DescribeTable { + schema: Arc::new(Schema::new(vec![Field::new( + "foo", + DataType::Int32, + false, + )])), + output_schema: DFSchemaRef::new(DFSchema::empty()), + }); + + let describe_table_clone = LogicalPlan::DescribeTable(DescribeTable { + schema: Arc::new(Schema::new(vec![Field::new( + "foo", + DataType::Int32, + false, + )])), + output_schema: DFSchemaRef::new(DFSchema::empty()), + }); + + assert_eq!( + empty_relation.partial_cmp(&describe_table), + Some(Ordering::Less) + ); + assert_eq!( + describe_table.partial_cmp(&empty_relation), + Some(Ordering::Greater) + ); + assert_eq!(describe_table.partial_cmp(&describe_table_clone), None); + } + + #[test] + fn test_limit_with_new_children() { + let limit = LogicalPlan::Limit(Limit { + skip: None, + fetch: Some(Box::new(Expr::Literal( + ScalarValue::new_ten(&DataType::UInt32).unwrap(), + ))), + input: Arc::new(LogicalPlan::Values(Values { + schema: Arc::new(DFSchema::empty()), + values: vec![vec![]], + })), + }); + let new_limit = limit + .with_new_exprs( + limit.expressions(), + limit.inputs().into_iter().cloned().collect(), + ) + .unwrap(); + assert_eq!(limit, new_limit); + } } diff --git a/datafusion/expr/src/logical_plan/statement.rs b/datafusion/expr/src/logical_plan/statement.rs index f294e7d3ea4c..7ad18ce7bbf7 100644 --- a/datafusion/expr/src/logical_plan/statement.rs +++ b/datafusion/expr/src/logical_plan/statement.rs @@ -16,6 +16,7 @@ // under the License. use datafusion_common::DFSchemaRef; +use std::cmp::Ordering; use std::fmt::{self, Display}; /// Various types of Statements. @@ -25,7 +26,7 @@ use std::fmt::{self, Display}; /// While DataFusion does not offer support transactions, it provides /// [`LogicalPlan`](crate::LogicalPlan) support to assist building /// database systems using DataFusion -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum Statement { // Begin a transaction TransactionStart(TransactionStart), @@ -60,7 +61,7 @@ impl Statement { /// children. /// /// See [crate::LogicalPlan::display] for an example - pub fn display(&self) -> impl fmt::Display + '_ { + pub fn display(&self) -> impl Display + '_ { struct Wrapper<'a>(&'a Statement); impl<'a> Display for Wrapper<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -92,21 +93,21 @@ impl Statement { } /// Indicates if a transaction was committed or aborted -#[derive(Clone, PartialEq, Eq, Hash, Debug)] +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub enum TransactionConclusion { Commit, Rollback, } /// Indicates if this transaction is allowed to write -#[derive(Clone, PartialEq, Eq, Hash, Debug)] +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub enum TransactionAccessMode { ReadOnly, ReadWrite, } /// Indicates ANSI transaction isolation level -#[derive(Clone, PartialEq, Eq, Hash, Debug)] +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] pub enum TransactionIsolationLevel { ReadUncommitted, ReadCommitted, @@ -115,7 +116,7 @@ pub enum TransactionIsolationLevel { } /// Indicator that the following statements should be committed or rolled back atomically -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct TransactionStart { /// indicates if transaction is allowed to write pub access_mode: TransactionAccessMode, @@ -125,8 +126,20 @@ pub struct TransactionStart { pub schema: DFSchemaRef, } +// Manual implementation needed because of `schema` field. Comparison excludes this field. +impl PartialOrd for TransactionStart { + fn partial_cmp(&self, other: &Self) -> Option { + match self.access_mode.partial_cmp(&other.access_mode) { + Some(Ordering::Equal) => { + self.isolation_level.partial_cmp(&other.isolation_level) + } + cmp => cmp, + } + } +} + /// Indicator that any current transaction should be terminated -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct TransactionEnd { /// whether the transaction committed or aborted pub conclusion: TransactionConclusion, @@ -136,9 +149,19 @@ pub struct TransactionEnd { pub schema: DFSchemaRef, } +// Manual implementation needed because of `schema` field. Comparison excludes this field. +impl PartialOrd for TransactionEnd { + fn partial_cmp(&self, other: &Self) -> Option { + match self.conclusion.partial_cmp(&other.conclusion) { + Some(Ordering::Equal) => self.chain.partial_cmp(&other.chain), + cmp => cmp, + } + } +} + /// Set a Variable's value -- value in /// [`ConfigOptions`](datafusion_common::config::ConfigOptions) -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct SetVariable { /// The variable name pub variable: String, @@ -147,3 +170,13 @@ pub struct SetVariable { /// Dummy schema pub schema: DFSchemaRef, } + +// Manual implementation needed because of `schema` field. Comparison excludes this field. +impl PartialOrd for SetVariable { + fn partial_cmp(&self, other: &Self) -> Option { + match self.variable.partial_cmp(&other.value) { + Some(Ordering::Equal) => self.value.partial_cmp(&other.value), + cmp => cmp, + } + } +} diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 37a36c36ca53..ff2c1ec1d58f 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -37,16 +37,17 @@ //! * [`LogicalPlan::with_new_exprs`]: Create a new plan with different expressions //! * [`LogicalPlan::expressions`]: Return a copy of the plan's expressions use crate::{ - dml::CopyTo, Aggregate, Analyze, CreateMemoryTable, CreateView, CrossJoin, - DdlStatement, Distinct, DistinctOn, DmlStatement, Explain, Expr, Extension, Filter, - Join, Limit, LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery, - Repartition, Sort, Subquery, SubqueryAlias, TableScan, Union, Unnest, - UserDefinedLogicalNode, Values, Window, + dml::CopyTo, Aggregate, Analyze, CreateMemoryTable, CreateView, DdlStatement, + Distinct, DistinctOn, DmlStatement, Execute, Explain, Expr, Extension, Filter, Join, + Limit, LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery, Repartition, + Sort, Subquery, SubqueryAlias, TableScan, Union, Unnest, UserDefinedLogicalNode, + Values, Window, }; +use std::ops::Deref; use std::sync::Arc; use crate::expr::{Exists, InSubquery}; -use crate::tree_node::transform_option_vec; +use crate::tree_node::{transform_sort_option_vec, transform_sort_vec}; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, @@ -56,8 +57,8 @@ use datafusion_common::{ }; impl TreeNode for LogicalPlan { - fn apply_children Result>( - &self, + fn apply_children<'n, F: FnMut(&'n Self) -> Result>( + &'n self, f: F, ) -> Result { self.inputs().into_iter().apply_until_stop(f) @@ -71,10 +72,10 @@ impl TreeNode for LogicalPlan { /// subqueries, for example such as are in [`Expr::Exists`]. /// /// [`Expr::Exists`]: crate::Expr::Exists - fn map_children(self, mut f: F) -> Result> - where - F: FnMut(Self) -> Result>, - { + fn map_children Result>>( + self, + mut f: F, + ) -> Result> { Ok(match self { LogicalPlan::Projection(Projection { expr, @@ -87,8 +88,17 @@ impl TreeNode for LogicalPlan { schema, }) }), - LogicalPlan::Filter(Filter { predicate, input }) => rewrite_arc(input, f)? - .update_data(|input| LogicalPlan::Filter(Filter { predicate, input })), + LogicalPlan::Filter(Filter { + predicate, + input, + having, + }) => rewrite_arc(input, f)?.update_data(|input| { + LogicalPlan::Filter(Filter { + predicate, + input, + having, + }) + }), LogicalPlan::Repartition(Repartition { input, partitioning_scheme, @@ -150,22 +160,6 @@ impl TreeNode for LogicalPlan { null_equals_null, }) }), - LogicalPlan::CrossJoin(CrossJoin { - left, - right, - schema, - }) => map_until_stop_and_collect!( - rewrite_arc(left, &mut f), - right, - rewrite_arc(right, &mut f) - )? - .update_data(|(left, right)| { - LogicalPlan::CrossJoin(CrossJoin { - left, - right, - schema, - }) - }), LogicalPlan::Limit(Limit { skip, fetch, input }) => rewrite_arc(input, f)? .update_data(|input| LogicalPlan::Limit(Limit { skip, fetch, input })), LogicalPlan::Subquery(Subquery { @@ -242,26 +236,28 @@ impl TreeNode for LogicalPlan { table_schema, op, input, + output_schema, }) => rewrite_arc(input, f)?.update_data(|input| { LogicalPlan::Dml(DmlStatement { table_name, table_schema, op, input, + output_schema, }) }), LogicalPlan::Copy(CopyTo { input, output_url, partition_by, - format_options, + file_type, options, }) => rewrite_arc(input, f)?.update_data(|input| { LogicalPlan::Copy(CopyTo { input, output_url, partition_by, - format_options, + file_type, options, }) }), @@ -274,6 +270,7 @@ impl TreeNode for LogicalPlan { if_not_exists, or_replace, column_defaults, + temporary, }) => rewrite_arc(input, f)?.update_data(|input| { DdlStatement::CreateMemoryTable(CreateMemoryTable { name, @@ -282,6 +279,7 @@ impl TreeNode for LogicalPlan { if_not_exists, or_replace, column_defaults, + temporary, }) }), DdlStatement::CreateView(CreateView { @@ -289,18 +287,21 @@ impl TreeNode for LogicalPlan { input, or_replace, definition, + temporary, }) => rewrite_arc(input, f)?.update_data(|input| { DdlStatement::CreateView(CreateView { name, input, or_replace, definition, + temporary, }) }), // no inputs in these statements DdlStatement::CreateExternalTable(_) | DdlStatement::CreateCatalogSchema(_) | DdlStatement::CreateCatalog(_) + | DdlStatement::CreateIndex(_) | DdlStatement::DropTable(_) | DdlStatement::DropView(_) | DdlStatement::DropCatalogSchema(_) @@ -311,13 +312,19 @@ impl TreeNode for LogicalPlan { } LogicalPlan::Unnest(Unnest { input, - columns, + exec_columns: input_columns, + list_type_columns, + struct_type_columns, + dependency_indices, schema, options, }) => rewrite_arc(input, f)?.update_data(|input| { LogicalPlan::Unnest(Unnest { input, - columns, + exec_columns: input_columns, + dependency_indices, + list_type_columns, + struct_type_columns, schema, options, }) @@ -356,39 +363,25 @@ impl TreeNode for LogicalPlan { | LogicalPlan::Statement { .. } | LogicalPlan::EmptyRelation { .. } | LogicalPlan::Values { .. } + | LogicalPlan::Execute { .. } | LogicalPlan::DescribeTable(_) => Transformed::no(self), }) } } -/// Converts a `Arc` without copying, if possible. Copies the plan -/// if there is a shared reference -pub fn unwrap_arc(plan: Arc) -> LogicalPlan { - Arc::try_unwrap(plan) - // if None is returned, there is another reference to this - // LogicalPlan, so we can not own it, and must clone instead - .unwrap_or_else(|node| node.as_ref().clone()) -} - /// Applies `f` to rewrite a `Arc` without copying, if possible -fn rewrite_arc( +fn rewrite_arc Result>>( plan: Arc, mut f: F, -) -> Result>> -where - F: FnMut(LogicalPlan) -> Result>, -{ - f(unwrap_arc(plan))?.map_data(|new_plan| Ok(Arc::new(new_plan))) +) -> Result>> { + f(Arc::unwrap_or_clone(plan))?.map_data(|new_plan| Ok(Arc::new(new_plan))) } /// rewrite a `Vec` of `Arc` without copying, if possible -fn rewrite_arcs( +fn rewrite_arcs Result>>( input_plans: Vec>, mut f: F, -) -> Result>>> -where - F: FnMut(LogicalPlan) -> Result>, -{ +) -> Result>>> { input_plans .into_iter() .map_until_stop_and_collect(|plan| rewrite_arc(plan, &mut f)) @@ -399,13 +392,10 @@ where /// /// Should be removed when we have an API for in place modifications of the /// extension to avoid these copies -fn rewrite_extension_inputs( +fn rewrite_extension_inputs Result>>( extension: Extension, f: F, -) -> Result> -where - F: FnMut(LogicalPlan) -> Result>, -{ +) -> Result> { let Extension { node } = extension; node.inputs() @@ -415,7 +405,7 @@ where .map_data(|new_inputs| { let exprs = node.expressions(); Ok(Extension { - node: node.from_template(&exprs, &new_inputs), + node: node.with_exprs_and_inputs(exprs, new_inputs)?, }) }) } @@ -481,7 +471,9 @@ impl LogicalPlan { .apply_until_stop(|e| f(&e))? .visit_sibling(|| filter.iter().apply_until_stop(f)) } - LogicalPlan::Sort(Sort { expr, .. }) => expr.iter().apply_until_stop(f), + LogicalPlan::Sort(Sort { expr, .. }) => { + expr.iter().apply_until_stop(|sort| f(&sort.expr)) + } LogicalPlan::Extension(extension) => { // would be nice to avoid this copy -- maybe can // update extension to just observer Exprs @@ -490,7 +482,9 @@ impl LogicalPlan { LogicalPlan::TableScan(TableScan { filters, .. }) => { filters.iter().apply_until_stop(f) } - LogicalPlan::Unnest(Unnest { columns, .. }) => { + LogicalPlan::Unnest(unnest) => { + let columns = unnest.exec_columns.clone(); + let exprs = columns .iter() .map(|c| Expr::Column(c.clone())) @@ -505,16 +499,22 @@ impl LogicalPlan { })) => on_expr .iter() .chain(select_expr.iter()) - .chain(sort_expr.iter().flatten()) + .chain(sort_expr.iter().flatten().map(|sort| &sort.expr)) .apply_until_stop(f), + LogicalPlan::Limit(Limit { skip, fetch, .. }) => skip + .iter() + .chain(fetch.iter()) + .map(|e| e.deref()) + .apply_until_stop(f), + LogicalPlan::Execute(Execute { parameters, .. }) => { + parameters.iter().apply_until_stop(f) + } // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Limit(_) | LogicalPlan::Statement(_) - | LogicalPlan::CrossJoin(_) | LogicalPlan::Analyze(_) | LogicalPlan::Explain(_) | LogicalPlan::Union(_) @@ -559,10 +559,17 @@ impl LogicalPlan { value.into_iter().map_until_stop_and_collect(&mut f) })? .update_data(|values| LogicalPlan::Values(Values { schema, values })), - LogicalPlan::Filter(Filter { predicate, input }) => f(predicate)? - .update_data(|predicate| { - LogicalPlan::Filter(Filter { predicate, input }) - }), + LogicalPlan::Filter(Filter { + predicate, + input, + having, + }) => f(predicate)?.update_data(|predicate| { + LogicalPlan::Filter(Filter { + predicate, + input, + having, + }) + }), LogicalPlan::Repartition(Repartition { input, partitioning_scheme, @@ -649,29 +656,25 @@ impl LogicalPlan { null_equals_null, }) }), - LogicalPlan::Sort(Sort { expr, input, fetch }) => expr - .into_iter() - .map_until_stop_and_collect(f)? - .update_data(|expr| LogicalPlan::Sort(Sort { expr, input, fetch })), + LogicalPlan::Sort(Sort { expr, input, fetch }) => { + transform_sort_vec(expr, &mut f)? + .update_data(|expr| LogicalPlan::Sort(Sort { expr, input, fetch })) + } LogicalPlan::Extension(Extension { node }) => { // would be nice to avoid this copy -- maybe can // update extension to just observer Exprs - node.expressions() + let exprs = node + .expressions() .into_iter() - .map_until_stop_and_collect(f)? - .update_data(|exprs| { - LogicalPlan::Extension(Extension { - node: UserDefinedLogicalNode::from_template( - node.as_ref(), - exprs.as_slice(), - node.inputs() - .into_iter() - .cloned() - .collect::>() - .as_slice(), - ), - }) - }) + .map_until_stop_and_collect(f)?; + let plan = LogicalPlan::Extension(Extension { + node: UserDefinedLogicalNode::with_exprs_and_inputs( + node.as_ref(), + exprs.data, + node.inputs().into_iter().cloned().collect::>(), + )?, + }); + Transformed::new(plan, exprs.transformed, exprs.tnr) } LogicalPlan::TableScan(TableScan { table_name, @@ -704,7 +707,7 @@ impl LogicalPlan { select_expr, select_expr.into_iter().map_until_stop_and_collect(&mut f), sort_expr, - transform_option_vec(sort_expr, &mut f) + transform_sort_option_vec(sort_expr, &mut f) )? .update_data(|(on_expr, select_expr, sort_expr)| { LogicalPlan::Distinct(Distinct::On(DistinctOn { @@ -715,15 +718,47 @@ impl LogicalPlan { schema, })) }), + LogicalPlan::Limit(Limit { skip, fetch, input }) => { + let skip = skip.map(|e| *e); + let fetch = fetch.map(|e| *e); + map_until_stop_and_collect!( + skip.map_or(Ok::<_, DataFusionError>(Transformed::no(None)), |e| { + Ok(f(e)?.update_data(Some)) + }), + fetch, + fetch.map_or(Ok::<_, DataFusionError>(Transformed::no(None)), |e| { + Ok(f(e)?.update_data(Some)) + }) + )? + .update_data(|(skip, fetch)| { + LogicalPlan::Limit(Limit { + skip: skip.map(Box::new), + fetch: fetch.map(Box::new), + input, + }) + }) + } + LogicalPlan::Execute(Execute { + parameters, + name, + schema, + }) => parameters + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|parameters| { + LogicalPlan::Execute(Execute { + parameters, + name, + schema, + }) + }), // plans without expressions LogicalPlan::EmptyRelation(_) | LogicalPlan::Unnest(_) | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Limit(_) | LogicalPlan::Statement(_) - | LogicalPlan::CrossJoin(_) | LogicalPlan::Analyze(_) | LogicalPlan::Explain(_) | LogicalPlan::Union(_) @@ -738,7 +773,7 @@ impl LogicalPlan { /// Visits a plan similarly to [`Self::visit`], including subqueries that /// may appear in expressions such as `IN (SELECT ...)`. - pub fn visit_with_subqueries>( + pub fn visit_with_subqueries TreeNodeVisitor<'n, Node = Self>>( &self, visitor: &mut V, ) -> Result { diff --git a/datafusion/expr/src/operation.rs b/datafusion/expr/src/operation.rs new file mode 100644 index 000000000000..6b79a8248b29 --- /dev/null +++ b/datafusion/expr/src/operation.rs @@ -0,0 +1,222 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module contains implementations of operations (unary, binary etc.) for DataFusion expressions. + +use crate::expr_fn::binary_expr; +use crate::{Expr, Like}; +use datafusion_expr_common::operator::Operator; +use std::ops::{self, Not}; + +/// Support ` + ` fluent style +impl ops::Add for Expr { + type Output = Self; + + fn add(self, rhs: Self) -> Self { + binary_expr(self, Operator::Plus, rhs) + } +} + +/// Support ` - ` fluent style +impl ops::Sub for Expr { + type Output = Self; + + fn sub(self, rhs: Self) -> Self { + binary_expr(self, Operator::Minus, rhs) + } +} + +/// Support ` * ` fluent style +impl ops::Mul for Expr { + type Output = Self; + + fn mul(self, rhs: Self) -> Self { + binary_expr(self, Operator::Multiply, rhs) + } +} + +/// Support ` / ` fluent style +impl ops::Div for Expr { + type Output = Self; + + fn div(self, rhs: Self) -> Self { + binary_expr(self, Operator::Divide, rhs) + } +} + +/// Support ` % ` fluent style +impl ops::Rem for Expr { + type Output = Self; + + fn rem(self, rhs: Self) -> Self { + binary_expr(self, Operator::Modulo, rhs) + } +} + +/// Support ` & ` fluent style +impl ops::BitAnd for Expr { + type Output = Self; + + fn bitand(self, rhs: Self) -> Self { + binary_expr(self, Operator::BitwiseAnd, rhs) + } +} + +/// Support ` | ` fluent style +impl ops::BitOr for Expr { + type Output = Self; + + fn bitor(self, rhs: Self) -> Self { + binary_expr(self, Operator::BitwiseOr, rhs) + } +} + +/// Support ` ^ ` fluent style +impl ops::BitXor for Expr { + type Output = Self; + + fn bitxor(self, rhs: Self) -> Self { + binary_expr(self, Operator::BitwiseXor, rhs) + } +} + +/// Support ` << ` fluent style +impl ops::Shl for Expr { + type Output = Self; + + fn shl(self, rhs: Self) -> Self::Output { + binary_expr(self, Operator::BitwiseShiftLeft, rhs) + } +} + +/// Support ` >> ` fluent style +impl ops::Shr for Expr { + type Output = Self; + + fn shr(self, rhs: Self) -> Self::Output { + binary_expr(self, Operator::BitwiseShiftRight, rhs) + } +} + +/// Support `- ` fluent style +impl ops::Neg for Expr { + type Output = Self; + + fn neg(self) -> Self::Output { + Expr::Negative(Box::new(self)) + } +} + +/// Support `NOT ` fluent style +impl Not for Expr { + type Output = Self; + + fn not(self) -> Self::Output { + match self { + Expr::Like(Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }) => Expr::Like(Like::new( + !negated, + expr, + pattern, + escape_char, + case_insensitive, + )), + Expr::SimilarTo(Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }) => Expr::SimilarTo(Like::new( + !negated, + expr, + pattern, + escape_char, + case_insensitive, + )), + _ => Expr::Not(Box::new(self)), + } + } +} + +#[cfg(test)] +mod tests { + use crate::lit; + + #[test] + fn test_operators() { + // Add + assert_eq!( + format!("{}", lit(1u32) + lit(2u32)), + "UInt32(1) + UInt32(2)" + ); + // Sub + assert_eq!( + format!("{}", lit(1u32) - lit(2u32)), + "UInt32(1) - UInt32(2)" + ); + // Mul + assert_eq!( + format!("{}", lit(1u32) * lit(2u32)), + "UInt32(1) * UInt32(2)" + ); + // Div + assert_eq!( + format!("{}", lit(1u32) / lit(2u32)), + "UInt32(1) / UInt32(2)" + ); + // Rem + assert_eq!( + format!("{}", lit(1u32) % lit(2u32)), + "UInt32(1) % UInt32(2)" + ); + // BitAnd + assert_eq!( + format!("{}", lit(1u32) & lit(2u32)), + "UInt32(1) & UInt32(2)" + ); + // BitOr + assert_eq!( + format!("{}", lit(1u32) | lit(2u32)), + "UInt32(1) | UInt32(2)" + ); + // BitXor + assert_eq!( + format!("{}", lit(1u32) ^ lit(2u32)), + "UInt32(1) BIT_XOR UInt32(2)" + ); + // Shl + assert_eq!( + format!("{}", lit(1u32) << lit(2u32)), + "UInt32(1) << UInt32(2)" + ); + // Shr + assert_eq!( + format!("{}", lit(1u32) >> lit(2u32)), + "UInt32(1) >> UInt32(2)" + ); + // Neg + assert_eq!(format!("{}", -lit(1u32)), "(- UInt32(1))"); + // Not + assert_eq!(format!("{}", !lit(1u32)), "NOT UInt32(1)"); + } +} diff --git a/datafusion/expr/src/partition_evaluator.rs b/datafusion/expr/src/partition_evaluator.rs index 04b6faf55ae1..a0f0988b4f4e 100644 --- a/datafusion/expr/src/partition_evaluator.rs +++ b/datafusion/expr/src/partition_evaluator.rs @@ -135,7 +135,7 @@ pub trait PartitionEvaluator: Debug + Send { /// must produce an output column with one output row for every /// input row. /// - /// `num_rows` is requied to correctly compute the output in case + /// `num_rows` is required to correctly compute the output in case /// `values.len() == 0` /// /// Implementing this function is an optimization: certain window diff --git a/datafusion/expr/src/planner.rs b/datafusion/expr/src/planner.rs new file mode 100644 index 000000000000..7dd7360e478f --- /dev/null +++ b/datafusion/expr/src/planner.rs @@ -0,0 +1,251 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ContextProvider`] and [`ExprPlanner`] APIs to customize SQL query planning + +use std::fmt::Debug; +use std::sync::Arc; + +use arrow::datatypes::{DataType, Field, SchemaRef}; +use datafusion_common::{ + config::ConfigOptions, file_options::file_type::FileType, not_impl_err, DFSchema, + Result, TableReference, +}; + +use crate::{AggregateUDF, Expr, GetFieldAccess, ScalarUDF, TableSource, WindowUDF}; + +/// Provides the `SQL` query planner meta-data about tables and +/// functions referenced in SQL statements, without a direct dependency on other +/// DataFusion structures +pub trait ContextProvider { + /// Getter for a datasource + fn get_table_source(&self, name: TableReference) -> Result>; + + fn get_file_type(&self, _ext: &str) -> Result> { + not_impl_err!("Registered file types are not supported") + } + + /// Getter for a table function + fn get_table_function_source( + &self, + _name: &str, + _args: Vec, + ) -> Result> { + not_impl_err!("Table Functions are not supported") + } + + /// This provides a worktable (an intermediate table that is used to store the results of a CTE during execution) + /// We don't directly implement this in the logical plan's ['SqlToRel`] + /// because the sql code needs access to a table that contains execution-related types that can't be a direct dependency + /// of the sql crate (namely, the `CteWorktable`). + /// The [`ContextProvider`] provides a way to "hide" this dependency. + fn create_cte_work_table( + &self, + _name: &str, + _schema: SchemaRef, + ) -> Result> { + not_impl_err!("Recursive CTE is not implemented") + } + + /// Getter for expr planners + fn get_expr_planners(&self) -> &[Arc] { + &[] + } + + /// Getter for a UDF description + fn get_function_meta(&self, name: &str) -> Option>; + /// Getter for a UDAF description + fn get_aggregate_meta(&self, name: &str) -> Option>; + /// Getter for a UDWF + fn get_window_meta(&self, name: &str) -> Option>; + /// Getter for system/user-defined variable type + fn get_variable_type(&self, variable_names: &[String]) -> Option; + + /// Get configuration options + fn options(&self) -> &ConfigOptions; + + /// Get all user defined scalar function names + fn udf_names(&self) -> Vec; + + /// Get all user defined aggregate function names + fn udaf_names(&self) -> Vec; + + /// Get all user defined window function names + fn udwf_names(&self) -> Vec; +} + +/// This trait allows users to customize the behavior of the SQL planner +pub trait ExprPlanner: Debug + Send + Sync { + /// Plan the binary operation between two expressions, returns original + /// BinaryExpr if not possible + fn plan_binary_op( + &self, + expr: RawBinaryExpr, + _schema: &DFSchema, + ) -> Result> { + Ok(PlannerResult::Original(expr)) + } + + /// Plan the field access expression + /// + /// returns original FieldAccessExpr if not possible + fn plan_field_access( + &self, + expr: RawFieldAccessExpr, + _schema: &DFSchema, + ) -> Result> { + Ok(PlannerResult::Original(expr)) + } + + /// Plan the array literal, returns OriginalArray if not possible + /// + /// Returns origin expression arguments if not possible + fn plan_array_literal( + &self, + exprs: Vec, + _schema: &DFSchema, + ) -> Result>> { + Ok(PlannerResult::Original(exprs)) + } + + // Plan the POSITION expression, e.g., POSITION( in ) + // returns origin expression arguments if not possible + fn plan_position(&self, args: Vec) -> Result>> { + Ok(PlannerResult::Original(args)) + } + + /// Plan the dictionary literal `{ key: value, ...}` + /// + /// Returns origin expression arguments if not possible + fn plan_dictionary_literal( + &self, + expr: RawDictionaryExpr, + _schema: &DFSchema, + ) -> Result> { + Ok(PlannerResult::Original(expr)) + } + + /// Plan an extract expression, e.g., `EXTRACT(month FROM foo)` + /// + /// Returns origin expression arguments if not possible + fn plan_extract(&self, args: Vec) -> Result>> { + Ok(PlannerResult::Original(args)) + } + + /// Plan an substring expression, e.g., `SUBSTRING( [FROM ] [FOR ])` + /// + /// Returns origin expression arguments if not possible + fn plan_substring(&self, args: Vec) -> Result>> { + Ok(PlannerResult::Original(args)) + } + + /// Plans a struct `struct(expression1[, ..., expression_n])` + /// literal based on the given input expressions. + /// This function takes a vector of expressions and a boolean flag indicating whether + /// the struct uses the optional name + /// + /// Returns a `PlannerResult` containing either the planned struct expressions or the original + /// input expressions if planning is not possible. + fn plan_struct_literal( + &self, + args: Vec, + _is_named_struct: bool, + ) -> Result>> { + Ok(PlannerResult::Original(args)) + } + + /// Plans an overlay expression eg `overlay(str PLACING substr FROM pos [FOR count])` + /// + /// Returns origin expression arguments if not possible + fn plan_overlay(&self, args: Vec) -> Result>> { + Ok(PlannerResult::Original(args)) + } + + /// Plan a make_map expression, e.g., `make_map(key1, value1, key2, value2, ...)` + /// + /// Returns origin expression arguments if not possible + fn plan_make_map(&self, args: Vec) -> Result>> { + Ok(PlannerResult::Original(args)) + } + + /// Plans compound identifier eg `db.schema.table` for non-empty nested names + /// + /// Note: + /// Currently compound identifier for outer query schema is not supported. + /// + /// Returns planned expression + fn plan_compound_identifier( + &self, + _field: &Field, + _qualifier: Option<&TableReference>, + _nested_names: &[String], + ) -> Result>> { + not_impl_err!( + "Default planner compound identifier hasn't been implemented for ExprPlanner" + ) + } + + /// Plans `ANY` expression, e.g., `expr = ANY(array_expr)` + /// + /// Returns origin binary expression if not possible + fn plan_any(&self, expr: RawBinaryExpr) -> Result> { + Ok(PlannerResult::Original(expr)) + } +} + +/// An operator with two arguments to plan +/// +/// Note `left` and `right` are DataFusion [`Expr`]s but the `op` is the SQL AST +/// operator. +/// +/// This structure is used by [`ExprPlanner`] to plan operators with +/// custom expressions. +#[derive(Debug, Clone)] +pub struct RawBinaryExpr { + pub op: sqlparser::ast::BinaryOperator, + pub left: Expr, + pub right: Expr, +} + +/// An expression with GetFieldAccess to plan +/// +/// This structure is used by [`ExprPlanner`] to plan operators with +/// custom expressions. +#[derive(Debug, Clone)] +pub struct RawFieldAccessExpr { + pub field_access: GetFieldAccess, + pub expr: Expr, +} + +/// A Dictionary literal expression `{ key: value, ...}` +/// +/// This structure is used by [`ExprPlanner`] to plan operators with +/// custom expressions. +#[derive(Debug, Clone)] +pub struct RawDictionaryExpr { + pub keys: Vec, + pub values: Vec, +} + +/// Result of planning a raw expr with [`ExprPlanner`] +#[derive(Debug, Clone)] +pub enum PlannerResult { + /// The raw expression was successfully planned as a new [`Expr`] + Planned(Expr), + /// The raw expression could not be planned, and is returned unmodified + Original(T), +} diff --git a/datafusion/execution/src/registry.rs b/datafusion/expr/src/registry.rs similarity index 90% rename from datafusion/execution/src/registry.rs rename to datafusion/expr/src/registry.rs index f3714a11c239..6d3457f70d4c 100644 --- a/datafusion/execution/src/registry.rs +++ b/datafusion/expr/src/registry.rs @@ -17,11 +17,13 @@ //! FunctionRegistry trait +use crate::expr_rewriter::FunctionRewrite; +use crate::planner::ExprPlanner; +use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF}; use datafusion_common::{not_impl_err, plan_datafusion_err, Result}; -use datafusion_expr::expr_rewriter::FunctionRewrite; -use datafusion_expr::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF}; -use std::collections::HashMap; -use std::{collections::HashSet, sync::Arc}; +use std::collections::{HashMap, HashSet}; +use std::fmt::Debug; +use std::sync::Arc; /// A registry knows how to build logical expressions out of user-defined function' names pub trait FunctionRegistry { @@ -108,10 +110,21 @@ pub trait FunctionRegistry { ) -> Result<()> { not_impl_err!("Registering FunctionRewrite") } + + /// Set of all registered [`ExprPlanner`]s + fn expr_planners(&self) -> Vec>; + + /// Registers a new [`ExprPlanner`] with the registry. + fn register_expr_planner( + &mut self, + _expr_planner: Arc, + ) -> Result<()> { + not_impl_err!("Registering ExprPlanner") + } } /// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode]. -pub trait SerializerRegistry: Send + Sync { +pub trait SerializerRegistry: Debug + Send + Sync { /// Serialize this node to a byte array. This serialization should not include /// input plans. fn serialize_logical_plan( @@ -183,4 +196,8 @@ impl FunctionRegistry for MemoryFunctionRegistry { fn register_udwf(&mut self, udaf: Arc) -> Result>> { Ok(self.udwfs.insert(udaf.name().into(), udaf)) } + + fn expr_planners(&self) -> Vec> { + vec![] + } } diff --git a/datafusion/expr/src/simplify.rs b/datafusion/expr/src/simplify.rs index 6fae31b4a698..e636fabf10fb 100644 --- a/datafusion/expr/src/simplify.rs +++ b/datafusion/expr/src/simplify.rs @@ -29,10 +29,10 @@ use crate::{execution_props::ExecutionProps, Expr, ExprSchemable}; /// information in without having to create `DFSchema` objects. If you /// have a [`DFSchemaRef`] you can use [`SimplifyContext`] pub trait SimplifyInfo { - /// returns true if this Expr has boolean type + /// Returns true if this Expr has boolean type fn is_boolean_type(&self, expr: &Expr) -> Result; - /// returns true of this expr is nullable (could possibly be NULL) + /// Returns true of this expr is nullable (could possibly be NULL) fn nullable(&self, expr: &Expr) -> Result; /// Returns details needed for partial expression evaluation @@ -72,9 +72,9 @@ impl<'a> SimplifyContext<'a> { } impl<'a> SimplifyInfo for SimplifyContext<'a> { - /// returns true if this Expr has boolean type + /// Returns true if this Expr has boolean type fn is_boolean_type(&self, expr: &Expr) -> Result { - for schema in &self.schema { + if let Some(schema) = &self.schema { if let Ok(DataType::Boolean) = expr.get_type(schema) { return Ok(true); } @@ -109,10 +109,11 @@ impl<'a> SimplifyInfo for SimplifyContext<'a> { } /// Was the expression simplified? +#[derive(Debug)] pub enum ExprSimplifyResult { /// The function call was simplified to an entirely new Expr Simplified(Expr), - /// the function call could not be simplified, and the arguments + /// The function call could not be simplified, and the arguments /// are return unmodified. Original(Vec), } diff --git a/datafusion/expr/src/table_source.rs b/datafusion/expr/src/table_source.rs index f662f4d9f77d..bdb602d48dee 100644 --- a/datafusion/expr/src/table_source.rs +++ b/datafusion/expr/src/table_source.rs @@ -22,7 +22,7 @@ use crate::{Expr, LogicalPlan}; use arrow::datatypes::SchemaRef; use datafusion_common::{Constraints, Result}; -use std::any::Any; +use std::{any::Any, borrow::Cow}; /// Indicates how a filter expression is handled by /// [`TableProvider::scan`]. @@ -61,14 +61,27 @@ pub enum TableType { Temporary, } -/// The TableSource trait is used during logical query planning and optimizations and -/// provides access to schema information and filter push-down capabilities. This trait -/// provides a subset of the functionality of the TableProvider trait in the core -/// datafusion crate. The TableProvider trait provides additional capabilities needed for -/// physical query execution (such as the ability to perform a scan). The reason for -/// having two separate traits is to avoid having the logical plan code be dependent -/// on the DataFusion execution engine. Other projects may want to use DataFusion's -/// logical plans and have their own execution engine. +impl std::fmt::Display for TableType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TableType::Base => write!(f, "Base"), + TableType::View => write!(f, "View"), + TableType::Temporary => write!(f, "Temporary"), + } + } +} + +/// Access schema information and filter push-down capabilities. +/// +/// The TableSource trait is used during logical query planning and +/// optimizations and provides a subset of the functionality of the +/// `TableProvider` trait in the (core) `datafusion` crate. The `TableProvider` +/// trait provides additional capabilities needed for physical query execution +/// (such as the ability to perform a scan). +/// +/// The reason for having two separate traits is to avoid having the logical +/// plan code be dependent on the DataFusion execution engine. Some projects use +/// DataFusion's logical plans and have their own execution engine. pub trait TableSource: Sync + Send { fn as_any(&self) -> &dyn Any; @@ -85,31 +98,19 @@ pub trait TableSource: Sync + Send { TableType::Base } - /// Tests whether the table provider can make use of a filter expression - /// to optimise data retrieval. - #[deprecated(since = "20.0.0", note = "use supports_filters_pushdown instead")] - fn supports_filter_pushdown( - &self, - _filter: &Expr, - ) -> Result { - Ok(TableProviderFilterPushDown::Unsupported) - } - /// Tests whether the table provider can make use of any or all filter expressions /// to optimise data retrieval. - #[allow(deprecated)] fn supports_filters_pushdown( &self, filters: &[&Expr], ) -> Result> { - filters - .iter() - .map(|f| self.supports_filter_pushdown(f)) - .collect() + Ok((0..filters.len()) + .map(|_| TableProviderFilterPushDown::Unsupported) + .collect()) } /// Get the Logical plan of this table provider, if available. - fn get_logical_plan(&self) -> Option<&LogicalPlan> { + fn get_logical_plan(&self) -> Option> { None } diff --git a/datafusion/expr/src/test/function_stub.rs b/datafusion/expr/src/test/function_stub.rs new file mode 100644 index 000000000000..262aa99e5007 --- /dev/null +++ b/datafusion/expr/src/test/function_stub.rs @@ -0,0 +1,507 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Aggregate function stubs for test in expr / optimizer. +//! +//! These are used to avoid a dependence on `datafusion-functions-aggregate` which live in a different crate + +use std::any::Any; + +use arrow::datatypes::{ + DataType, Field, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, +}; + +use datafusion_common::{exec_err, not_impl_err, Result}; + +use crate::type_coercion::aggregates::{avg_return_type, coerce_avg_type, NUMERICS}; +use crate::Volatility::Immutable; +use crate::{ + expr::AggregateFunction, + function::{AccumulatorArgs, StateFieldsArgs}, + utils::AggregateOrderSensitivity, + Accumulator, AggregateUDFImpl, Expr, GroupsAccumulator, ReversedUDAF, Signature, +}; + +macro_rules! create_func { + ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => { + paste::paste! { + /// Singleton instance of [$UDAF], ensures the UDAF is only created once + /// named STATIC_$(UDAF). For example `STATIC_FirstValue` + #[allow(non_upper_case_globals)] + static [< STATIC_ $UDAF >]: std::sync::OnceLock> = + std::sync::OnceLock::new(); + + #[doc = concat!("AggregateFunction that returns a [AggregateUDF](crate::AggregateUDF) for [`", stringify!($UDAF), "`]")] + pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc { + [< STATIC_ $UDAF >] + .get_or_init(|| { + std::sync::Arc::new(crate::AggregateUDF::from(<$UDAF>::default())) + }) + .clone() + } + } + } +} + +create_func!(Sum, sum_udaf); + +pub fn sum(expr: Expr) -> Expr { + Expr::AggregateFunction(AggregateFunction::new_udf( + sum_udaf(), + vec![expr], + false, + None, + None, + None, + )) +} + +create_func!(Count, count_udaf); + +pub fn count(expr: Expr) -> Expr { + Expr::AggregateFunction(AggregateFunction::new_udf( + count_udaf(), + vec![expr], + false, + None, + None, + None, + )) +} + +create_func!(Avg, avg_udaf); + +pub fn avg(expr: Expr) -> Expr { + Expr::AggregateFunction(AggregateFunction::new_udf( + avg_udaf(), + vec![expr], + false, + None, + None, + None, + )) +} + +/// Stub `sum` used for optimizer testing +#[derive(Debug)] +pub struct Sum { + signature: Signature, +} + +impl Sum { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Immutable), + } + } +} + +impl Default for Sum { + fn default() -> Self { + Self::new() + } +} + +impl AggregateUDFImpl for Sum { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "sum" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 1 { + return exec_err!("SUM expects exactly one argument"); + } + + // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc + // smallint, int, bigint, real, double precision, decimal, or interval. + + fn coerced_type(data_type: &DataType) -> Result { + match data_type { + DataType::Dictionary(_, v) => coerced_type(v), + // in the spark, the result type is DECIMAL(min(38,precision+10), s) + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => { + Ok(data_type.clone()) + } + dt if dt.is_signed_integer() => Ok(DataType::Int64), + dt if dt.is_unsigned_integer() => Ok(DataType::UInt64), + dt if dt.is_floating() => Ok(DataType::Float64), + _ => exec_err!("Sum not supported for {}", data_type), + } + } + + Ok(vec![coerced_type(&arg_types[0])?]) + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match &arg_types[0] { + DataType::Int64 => Ok(DataType::Int64), + DataType::UInt64 => Ok(DataType::UInt64), + DataType::Float64 => Ok(DataType::Float64), + DataType::Decimal128(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+10), s) + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal128(new_precision, *scale)) + } + DataType::Decimal256(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+10), s) + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal256(new_precision, *scale)) + } + other => { + exec_err!("[return_type] SUM not supported for {}", other) + } + } + } + + fn accumulator(&self, _args: AccumulatorArgs) -> Result> { + unreachable!("stub should not have accumulate()") + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + unreachable!("stub should not have state_fields()") + } + + fn aliases(&self) -> &[String] { + &[] + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + false + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + unreachable!("stub should not have accumulate()") + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } + + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + AggregateOrderSensitivity::Insensitive + } +} + +/// Testing stub implementation of COUNT aggregate +pub struct Count { + signature: Signature, + aliases: Vec, +} + +impl std::fmt::Debug for Count { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("Count") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for Count { + fn default() -> Self { + Self::new() + } +} + +impl Count { + pub fn new() -> Self { + Self { + aliases: vec!["count".to_string()], + signature: Signature::variadic_any(Immutable), + } + } +} + +impl AggregateUDFImpl for Count { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "COUNT" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int64) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + not_impl_err!("no impl for stub") + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + not_impl_err!("no impl for stub") + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + not_impl_err!("no impl for stub") + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } +} + +create_func!(Min, min_udaf); + +pub fn min(expr: Expr) -> Expr { + Expr::AggregateFunction(AggregateFunction::new_udf( + min_udaf(), + vec![expr], + false, + None, + None, + None, + )) +} + +/// Testing stub implementation of Min aggregate +pub struct Min { + signature: Signature, +} + +impl std::fmt::Debug for Min { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("Min") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for Min { + fn default() -> Self { + Self::new() + } +} + +impl Min { + pub fn new() -> Self { + Self { + signature: Signature::variadic_any(Immutable), + } + } +} + +impl AggregateUDFImpl for Min { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "min" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int64) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + not_impl_err!("no impl for stub") + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + not_impl_err!("no impl for stub") + } + + fn aliases(&self) -> &[String] { + &[] + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + not_impl_err!("no impl for stub") + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } + fn is_descending(&self) -> Option { + Some(false) + } +} + +create_func!(Max, max_udaf); + +pub fn max(expr: Expr) -> Expr { + Expr::AggregateFunction(AggregateFunction::new_udf( + max_udaf(), + vec![expr], + false, + None, + None, + None, + )) +} + +/// Testing stub implementation of MAX aggregate +pub struct Max { + signature: Signature, +} + +impl std::fmt::Debug for Max { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("Max") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for Max { + fn default() -> Self { + Self::new() + } +} + +impl Max { + pub fn new() -> Self { + Self { + signature: Signature::variadic_any(Immutable), + } + } +} + +impl AggregateUDFImpl for Max { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "max" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int64) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + not_impl_err!("no impl for stub") + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + not_impl_err!("no impl for stub") + } + + fn aliases(&self) -> &[String] { + &[] + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + not_impl_err!("no impl for stub") + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } + fn is_descending(&self) -> Option { + Some(true) + } +} + +/// Testing stub implementation of avg aggregate +#[derive(Debug)] +pub struct Avg { + signature: Signature, + aliases: Vec, +} + +impl Avg { + pub fn new() -> Self { + Self { + aliases: vec![String::from("mean")], + signature: Signature::uniform(1, NUMERICS.to_vec(), Immutable), + } + } +} + +impl Default for Avg { + fn default() -> Self { + Self::new() + } +} + +impl AggregateUDFImpl for Avg { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "avg" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + avg_return_type(self.name(), &arg_types[0]) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + not_impl_err!("no impl for stub") + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + not_impl_err!("no impl for stub") + } + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + coerce_avg_type(self.name(), arg_types) + } +} diff --git a/datafusion/expr/src/test/mod.rs b/datafusion/expr/src/test/mod.rs new file mode 100644 index 000000000000..04e1ccc47465 --- /dev/null +++ b/datafusion/expr/src/test/mod.rs @@ -0,0 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod function_stub; diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 471ed0b975b0..90afe5722abb 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -15,23 +15,31 @@ // specific language governing permissions and limitations // under the License. -//! Tree node implementation for logical expr +//! Tree node implementation for Logical Expressions use crate::expr::{ - AggregateFunction, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Case, - Cast, GetIndexedField, GroupingSet, InList, InSubquery, Like, Placeholder, - ScalarFunction, ScalarFunctionDefinition, Sort, TryCast, Unnest, WindowFunction, + AggregateFunction, Alias, Between, BinaryExpr, Case, Cast, GroupingSet, InList, + InSubquery, Like, Placeholder, ScalarFunction, Sort, TryCast, Unnest, WindowFunction, }; -use crate::{Expr, GetFieldAccess}; +use crate::{Expr, ExprFunctionExt}; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, }; -use datafusion_common::{internal_err, map_until_stop_and_collect, Result}; +use datafusion_common::{map_until_stop_and_collect, Result}; +/// Implementation of the [`TreeNode`] trait +/// +/// This allows logical expressions (`Expr`) to be traversed and transformed +/// Facilitates tasks such as optimization and rewriting during query +/// planning. impl TreeNode for Expr { - fn apply_children Result>( - &self, + /// Applies a function `f` to each child expression of `self`. + /// + /// The function `f` determines whether to continue traversing the tree or to stop. + /// This method collects all child expressions and applies `f` to each. + fn apply_children<'n, F: FnMut(&'n Self) -> Result>( + &'n self, f: F, ) -> Result { let children = match self { @@ -49,18 +57,7 @@ impl TreeNode for Expr { | Expr::Negative(expr) | Expr::Cast(Cast { expr, .. }) | Expr::TryCast(TryCast { expr, .. }) - | Expr::Sort(Sort { expr, .. }) | Expr::InSubquery(InSubquery{ expr, .. }) => vec![expr.as_ref()], - Expr::GetIndexedField(GetIndexedField { expr, field }) => { - let expr = expr.as_ref(); - match field { - GetFieldAccess::ListIndex {key} => vec![key.as_ref(), expr], - GetFieldAccess::ListRange {start, stop, stride} => { - vec![start.as_ref(), stop.as_ref(),stride.as_ref(), expr] - } - GetFieldAccess::NamedStructField { .. } => vec![expr], - } - } Expr::GroupingSet(GroupingSet::Rollup(exprs)) | Expr::GroupingSet(GroupingSet::Cube(exprs)) => exprs.iter().collect(), Expr::ScalarFunction (ScalarFunction{ args, .. } ) => { @@ -109,7 +106,7 @@ impl TreeNode for Expr { expr_vec.push(f.as_ref()); } if let Some(order_by) = order_by { - expr_vec.extend(order_by); + expr_vec.extend(order_by.iter().map(|sort| &sort.expr)); } expr_vec } @@ -121,7 +118,7 @@ impl TreeNode for Expr { }) => { let mut expr_vec = args.iter().collect::>(); expr_vec.extend(partition_by); - expr_vec.extend(order_by); + expr_vec.extend(order_by.iter().map(|sort| &sort.expr)); expr_vec } Expr::InList(InList { expr, list, .. }) => { @@ -134,6 +131,10 @@ impl TreeNode for Expr { children.into_iter().apply_until_stop(f) } + /// Maps each child of `self` using the provided closure `f`. + /// + /// The closure `f` takes ownership of an expression and returns a `Transformed` result, + /// indicating whether the expression was transformed or left unchanged. fn map_children Result>>( self, mut f: F, @@ -146,8 +147,9 @@ impl TreeNode for Expr { | Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::ScalarVariable(_, _) - | Expr::Unnest(_) | Expr::Literal(_) => Transformed::no(self), + Expr::Unnest(Unnest { expr, .. }) => transform_box(expr, &mut f)? + .update_data(|be| Expr::Unnest(Unnest::new_boxed(be))), Expr::Alias(Alias { expr, relation, @@ -275,20 +277,11 @@ impl TreeNode for Expr { .update_data(|be| Expr::Cast(Cast::new(be, data_type))), Expr::TryCast(TryCast { expr, data_type }) => transform_box(expr, &mut f)? .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))), - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => transform_box(expr, &mut f)? - .update_data(|be| Expr::Sort(Sort::new(be, asc, nulls_first))), - Expr::ScalarFunction(ScalarFunction { func_def, args }) => { - transform_vec(args, &mut f)?.map_data(|new_args| match func_def { - ScalarFunctionDefinition::UDF(fun) => { - Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, new_args))) - } - ScalarFunctionDefinition::Name(_) => { - internal_err!("Function `Expr` with name should be resolved.") - } + Expr::ScalarFunction(ScalarFunction { func, args }) => { + transform_vec(args, &mut f)?.map_data(|new_args| { + Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + func, new_args, + ))) })? } Expr::WindowFunction(WindowFunction { @@ -303,21 +296,20 @@ impl TreeNode for Expr { partition_by, transform_vec(partition_by, &mut f), order_by, - transform_vec(order_by, &mut f) + transform_sort_vec(order_by, &mut f) )? .update_data(|(new_args, new_partition_by, new_order_by)| { - Expr::WindowFunction(WindowFunction::new( - fun, - new_args, - new_partition_by, - new_order_by, - window_frame, - null_treatment, - )) + Expr::WindowFunction(WindowFunction::new(fun, new_args)) + .partition_by(new_partition_by) + .order_by(new_order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build() + .unwrap() }), Expr::AggregateFunction(AggregateFunction { args, - func_def, + func, distinct, filter, order_by, @@ -327,35 +319,18 @@ impl TreeNode for Expr { filter, transform_option_box(filter, &mut f), order_by, - transform_option_vec(order_by, &mut f) + transform_sort_option_vec(order_by, &mut f) )? - .map_data( - |(new_args, new_filter, new_order_by)| match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => { - Ok(Expr::AggregateFunction(AggregateFunction::new( - fun, - new_args, - distinct, - new_filter, - new_order_by, - null_treatment, - ))) - } - AggregateFunctionDefinition::UDF(fun) => { - Ok(Expr::AggregateFunction(AggregateFunction::new_udf( - fun, - new_args, - false, - new_filter, - new_order_by, - null_treatment, - ))) - } - AggregateFunctionDefinition::Name(_) => { - internal_err!("Function `Expr` with name should be resolved.") - } - }, - )?, + .map_data(|(new_args, new_filter, new_order_by)| { + Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + func, + new_args, + distinct, + new_filter, + new_order_by, + null_treatment, + ))) + })?, Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => transform_vec(exprs, &mut f)? .update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))), @@ -380,51 +355,82 @@ impl TreeNode for Expr { .update_data(|(new_expr, new_list)| { Expr::InList(InList::new(new_expr, new_list, negated)) }), - Expr::GetIndexedField(GetIndexedField { expr, field }) => { - transform_box(expr, &mut f)?.update_data(|be| { - Expr::GetIndexedField(GetIndexedField::new(be, field)) - }) - } }) } } -fn transform_box(be: Box, f: &mut F) -> Result>> -where - F: FnMut(Expr) -> Result>, -{ +/// Transforms a boxed expression by applying the provided closure `f`. +fn transform_box Result>>( + be: Box, + f: &mut F, +) -> Result>> { Ok(f(*be)?.update_data(Box::new)) } -fn transform_option_box( +/// Transforms an optional boxed expression by applying the provided closure `f`. +fn transform_option_box Result>>( obe: Option>, f: &mut F, -) -> Result>>> -where - F: FnMut(Expr) -> Result>, -{ +) -> Result>>> { obe.map_or(Ok(Transformed::no(None)), |be| { Ok(transform_box(be, f)?.update_data(Some)) }) } /// &mut transform a Option<`Vec` of `Expr`s> -pub fn transform_option_vec( +pub fn transform_option_vec Result>>( ove: Option>, f: &mut F, -) -> Result>>> -where - F: FnMut(Expr) -> Result>, -{ +) -> Result>>> { ove.map_or(Ok(Transformed::no(None)), |ve| { Ok(transform_vec(ve, f)?.update_data(Some)) }) } /// &mut transform a `Vec` of `Expr`s -fn transform_vec(ve: Vec, f: &mut F) -> Result>> -where - F: FnMut(Expr) -> Result>, -{ +fn transform_vec Result>>( + ve: Vec, + f: &mut F, +) -> Result>> { ve.into_iter().map_until_stop_and_collect(f) } + +/// Transforms an optional vector of sort expressions by applying the provided closure `f`. +pub fn transform_sort_option_vec Result>>( + sorts_option: Option>, + f: &mut F, +) -> Result>>> { + sorts_option.map_or(Ok(Transformed::no(None)), |sorts| { + Ok(transform_sort_vec(sorts, f)?.update_data(Some)) + }) +} + +/// Transforms an vector of sort expressions by applying the provided closure `f`. +pub fn transform_sort_vec Result>>( + sorts: Vec, + mut f: &mut F, +) -> Result>> { + Ok(sorts + .iter() + .map(|sort| sort.expr.clone()) + .map_until_stop_and_collect(&mut f)? + .update_data(|transformed_exprs| { + replace_sort_expressions(sorts, transformed_exprs) + })) +} + +pub fn replace_sort_expressions(sorts: Vec, new_expr: Vec) -> Vec { + assert_eq!(sorts.len(), new_expr.len()); + sorts + .into_iter() + .zip(new_expr) + .map(|(sort, expr)| replace_sort_expression(sort, expr)) + .collect() +} + +pub fn replace_sort_expression(sort: Sort, new_expr: Expr) -> Sort { + Sort { + expr: new_expr, + ..sort + } +} diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs deleted file mode 100644 index 5ffdc8f94753..000000000000 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ /dev/null @@ -1,748 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::ops::Deref; - -use super::functions::can_coerce_from; -use crate::{AggregateFunction, Signature, TypeSignature}; - -use arrow::datatypes::{ - DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, - DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, -}; -use datafusion_common::{internal_err, plan_err, Result}; - -pub static STRINGS: &[DataType] = &[DataType::Utf8, DataType::LargeUtf8]; - -pub static SIGNED_INTEGERS: &[DataType] = &[ - DataType::Int8, - DataType::Int16, - DataType::Int32, - DataType::Int64, -]; - -pub static UNSIGNED_INTEGERS: &[DataType] = &[ - DataType::UInt8, - DataType::UInt16, - DataType::UInt32, - DataType::UInt64, -]; - -pub static INTEGERS: &[DataType] = &[ - DataType::Int8, - DataType::Int16, - DataType::Int32, - DataType::Int64, - DataType::UInt8, - DataType::UInt16, - DataType::UInt32, - DataType::UInt64, -]; - -pub static NUMERICS: &[DataType] = &[ - DataType::Int8, - DataType::Int16, - DataType::Int32, - DataType::Int64, - DataType::UInt8, - DataType::UInt16, - DataType::UInt32, - DataType::UInt64, - DataType::Float32, - DataType::Float64, -]; - -pub static TIMESTAMPS: &[DataType] = &[ - DataType::Timestamp(TimeUnit::Second, None), - DataType::Timestamp(TimeUnit::Millisecond, None), - DataType::Timestamp(TimeUnit::Microsecond, None), - DataType::Timestamp(TimeUnit::Nanosecond, None), -]; - -pub static DATES: &[DataType] = &[DataType::Date32, DataType::Date64]; - -pub static BINARYS: &[DataType] = &[DataType::Binary, DataType::LargeBinary]; - -pub static TIMES: &[DataType] = &[ - DataType::Time32(TimeUnit::Second), - DataType::Time32(TimeUnit::Millisecond), - DataType::Time64(TimeUnit::Microsecond), - DataType::Time64(TimeUnit::Nanosecond), -]; - -/// Returns the coerced data type for each `input_types`. -/// Different aggregate function with different input data type will get corresponding coerced data type. -pub fn coerce_types( - agg_fun: &AggregateFunction, - input_types: &[DataType], - signature: &Signature, -) -> Result> { - use DataType::*; - // Validate input_types matches (at least one of) the func signature. - check_arg_count(agg_fun.name(), input_types, &signature.type_signature)?; - - match agg_fun { - AggregateFunction::Count | AggregateFunction::ApproxDistinct => { - Ok(input_types.to_vec()) - } - AggregateFunction::ArrayAgg => Ok(input_types.to_vec()), - AggregateFunction::Min | AggregateFunction::Max => { - // min and max support the dictionary data type - // unpack the dictionary to get the value - get_min_max_result_type(input_types) - } - AggregateFunction::Sum => { - // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc - // smallint, int, bigint, real, double precision, decimal, or interval. - let v = match &input_types[0] { - Decimal128(p, s) => Decimal128(*p, *s), - Decimal256(p, s) => Decimal256(*p, *s), - d if d.is_signed_integer() => Int64, - d if d.is_unsigned_integer() => UInt64, - d if d.is_floating() => Float64, - Dictionary(_, v) => { - return coerce_types(agg_fun, &[v.as_ref().clone()], signature) - } - _ => { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0] - ) - } - }; - Ok(vec![v]) - } - AggregateFunction::Avg => { - // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc - // smallint, int, bigint, real, double precision, decimal, or interval - let v = match &input_types[0] { - Decimal128(p, s) => Decimal128(*p, *s), - Decimal256(p, s) => Decimal256(*p, *s), - d if d.is_numeric() => Float64, - Dictionary(_, v) => { - return coerce_types(agg_fun, &[v.as_ref().clone()], signature) - } - _ => { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0] - ) - } - }; - Ok(vec![v]) - } - AggregateFunction::BitAnd - | AggregateFunction::BitOr - | AggregateFunction::BitXor => { - // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc - // smallint, int, bigint, real, double precision, decimal, or interval. - if !is_bit_and_or_xor_support_arg_type(&input_types[0]) { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0] - ); - } - Ok(input_types.to_vec()) - } - AggregateFunction::BoolAnd | AggregateFunction::BoolOr => { - // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc - // smallint, int, bigint, real, double precision, decimal, or interval. - if !is_bool_and_or_support_arg_type(&input_types[0]) { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0] - ); - } - Ok(input_types.to_vec()) - } - AggregateFunction::Variance | AggregateFunction::VariancePop => { - if !is_variance_support_arg_type(&input_types[0]) { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0] - ); - } - Ok(vec![Float64, Float64]) - } - AggregateFunction::Covariance | AggregateFunction::CovariancePop => { - if !is_covariance_support_arg_type(&input_types[0]) { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0] - ); - } - Ok(vec![Float64, Float64]) - } - AggregateFunction::Stddev | AggregateFunction::StddevPop => { - if !is_stddev_support_arg_type(&input_types[0]) { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0] - ); - } - Ok(vec![Float64]) - } - AggregateFunction::Correlation => { - if !is_correlation_support_arg_type(&input_types[0]) { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0] - ); - } - Ok(vec![Float64, Float64]) - } - AggregateFunction::RegrSlope - | AggregateFunction::RegrIntercept - | AggregateFunction::RegrCount - | AggregateFunction::RegrR2 - | AggregateFunction::RegrAvgx - | AggregateFunction::RegrAvgy - | AggregateFunction::RegrSXX - | AggregateFunction::RegrSYY - | AggregateFunction::RegrSXY => { - let valid_types = [NUMERICS.to_vec(), vec![Null]].concat(); - let input_types_valid = // number of input already checked before - valid_types.contains(&input_types[0]) && valid_types.contains(&input_types[1]); - if !input_types_valid { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0] - ); - } - Ok(vec![Float64, Float64]) - } - AggregateFunction::ApproxPercentileCont => { - if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0] - ); - } - if input_types.len() == 3 && !input_types[2].is_integer() { - return plan_err!( - "The percentile sample points count for {:?} must be integer, not {:?}.", - agg_fun, input_types[2] - ); - } - let mut result = input_types.to_vec(); - if can_coerce_from(&Float64, &input_types[1]) { - result[1] = Float64; - } else { - return plan_err!( - "Could not coerce the percent argument for {:?} to Float64. Was {:?}.", - agg_fun, input_types[1] - ); - } - Ok(result) - } - AggregateFunction::ApproxPercentileContWithWeight => { - if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0] - ); - } - if !is_approx_percentile_cont_supported_arg_type(&input_types[1]) { - return plan_err!( - "The weight argument for {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[1] - ); - } - if !matches!(input_types[2], Float64) { - return plan_err!( - "The percentile argument for {:?} must be Float64, not {:?}.", - agg_fun, - input_types[2] - ); - } - Ok(input_types.to_vec()) - } - AggregateFunction::ApproxMedian => { - if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0] - ); - } - Ok(input_types.to_vec()) - } - AggregateFunction::Median - | AggregateFunction::FirstValue - | AggregateFunction::LastValue => Ok(input_types.to_vec()), - AggregateFunction::NthValue => Ok(input_types.to_vec()), - AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]), - AggregateFunction::StringAgg => { - if !is_string_agg_supported_arg_type(&input_types[0]) { - return plan_err!( - "The function {:?} does not support inputs of type {:?}", - agg_fun, - input_types[0] - ); - } - if !is_string_agg_supported_arg_type(&input_types[1]) { - return plan_err!( - "The function {:?} does not support inputs of type {:?}", - agg_fun, - input_types[1] - ); - } - Ok(vec![LargeUtf8, input_types[1].clone()]) - } - } -} - -/// Validate the length of `input_types` matches the `signature` for `agg_fun`. -/// -/// This method DOES NOT validate the argument types - only that (at least one, -/// in the case of [`TypeSignature::OneOf`]) signature matches the desired -/// number of input types. -pub fn check_arg_count( - func_name: &str, - input_types: &[DataType], - signature: &TypeSignature, -) -> Result<()> { - match signature { - TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => { - if input_types.len() != *agg_count { - return plan_err!( - "The function {func_name} expects {:?} arguments, but {:?} were provided", - agg_count, - input_types.len() - ); - } - } - TypeSignature::Exact(types) => { - if types.len() != input_types.len() { - return plan_err!( - "The function {func_name} expects {:?} arguments, but {:?} were provided", - types.len(), - input_types.len() - ); - } - } - TypeSignature::OneOf(variants) => { - let ok = variants - .iter() - .any(|v| check_arg_count(func_name, input_types, v).is_ok()); - if !ok { - return plan_err!( - "The function {func_name} does not accept {:?} function arguments.", - input_types.len() - ); - } - } - TypeSignature::VariadicAny => { - if input_types.is_empty() { - return plan_err!( - "The function {func_name} expects at least one argument" - ); - } - } - _ => { - return internal_err!( - "Aggregate functions do not support this {signature:?}" - ); - } - } - Ok(()) -} - -fn get_min_max_result_type(input_types: &[DataType]) -> Result> { - // make sure that the input types only has one element. - assert_eq!(input_types.len(), 1); - // min and max support the dictionary data type - // unpack the dictionary to get the value - match &input_types[0] { - DataType::Dictionary(_, dict_value_type) => { - // TODO add checker, if the value type is complex data type - Ok(vec![dict_value_type.deref().clone()]) - } - // TODO add checker for datatype which min and max supported - // For example, the `Struct` and `Map` type are not supported in the MIN and MAX function - _ => Ok(input_types.to_vec()), - } -} - -/// function return type of a sum -pub fn sum_return_type(arg_type: &DataType) -> Result { - match arg_type { - DataType::Int64 => Ok(DataType::Int64), - DataType::UInt64 => Ok(DataType::UInt64), - DataType::Float64 => Ok(DataType::Float64), - DataType::Decimal128(precision, scale) => { - // in the spark, the result type is DECIMAL(min(38,precision+10), s) - // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 - let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10); - Ok(DataType::Decimal128(new_precision, *scale)) - } - DataType::Decimal256(precision, scale) => { - // in the spark, the result type is DECIMAL(min(38,precision+10), s) - // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 - let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); - Ok(DataType::Decimal256(new_precision, *scale)) - } - other => plan_err!("SUM does not support type \"{other:?}\""), - } -} - -/// function return type of variance -pub fn variance_return_type(arg_type: &DataType) -> Result { - if NUMERICS.contains(arg_type) { - Ok(DataType::Float64) - } else { - plan_err!("VAR does not support {arg_type:?}") - } -} - -/// function return type of covariance -pub fn covariance_return_type(arg_type: &DataType) -> Result { - if NUMERICS.contains(arg_type) { - Ok(DataType::Float64) - } else { - plan_err!("COVAR does not support {arg_type:?}") - } -} - -/// function return type of correlation -pub fn correlation_return_type(arg_type: &DataType) -> Result { - if NUMERICS.contains(arg_type) { - Ok(DataType::Float64) - } else { - plan_err!("CORR does not support {arg_type:?}") - } -} - -/// function return type of standard deviation -pub fn stddev_return_type(arg_type: &DataType) -> Result { - if NUMERICS.contains(arg_type) { - Ok(DataType::Float64) - } else { - plan_err!("STDDEV does not support {arg_type:?}") - } -} - -/// function return type of an average -pub fn avg_return_type(arg_type: &DataType) -> Result { - match arg_type { - DataType::Decimal128(precision, scale) => { - // in the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). - // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 - let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 4); - let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4); - Ok(DataType::Decimal128(new_precision, new_scale)) - } - DataType::Decimal256(precision, scale) => { - // in the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). - // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 - let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4); - let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4); - Ok(DataType::Decimal256(new_precision, new_scale)) - } - arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64), - DataType::Dictionary(_, dict_value_type) => { - avg_return_type(dict_value_type.as_ref()) - } - other => plan_err!("AVG does not support {other:?}"), - } -} - -/// internal sum type of an average -pub fn avg_sum_type(arg_type: &DataType) -> Result { - match arg_type { - DataType::Decimal128(precision, scale) => { - // in the spark, the sum type of avg is DECIMAL(min(38,precision+10), s) - let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10); - Ok(DataType::Decimal128(new_precision, *scale)) - } - DataType::Decimal256(precision, scale) => { - // in Spark the sum type of avg is DECIMAL(min(38,precision+10), s) - let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); - Ok(DataType::Decimal256(new_precision, *scale)) - } - arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64), - DataType::Dictionary(_, dict_value_type) => { - avg_sum_type(dict_value_type.as_ref()) - } - other => plan_err!("AVG does not support {other:?}"), - } -} - -pub fn is_bit_and_or_xor_support_arg_type(arg_type: &DataType) -> bool { - NUMERICS.contains(arg_type) -} - -pub fn is_bool_and_or_support_arg_type(arg_type: &DataType) -> bool { - matches!(arg_type, DataType::Boolean) -} - -pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool { - match arg_type { - DataType::Dictionary(_, dict_value_type) => { - is_sum_support_arg_type(dict_value_type.as_ref()) - } - _ => matches!( - arg_type, - arg_type if NUMERICS.contains(arg_type) - || matches!(arg_type, DataType::Decimal128(_, _) | DataType::Decimal256(_, _)) - ), - } -} - -pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool { - match arg_type { - DataType::Dictionary(_, dict_value_type) => { - is_avg_support_arg_type(dict_value_type.as_ref()) - } - _ => matches!( - arg_type, - arg_type if NUMERICS.contains(arg_type) - || matches!(arg_type, DataType::Decimal128(_, _)| DataType::Decimal256(_, _)) - ), - } -} - -pub fn is_variance_support_arg_type(arg_type: &DataType) -> bool { - matches!( - arg_type, - arg_type if NUMERICS.contains(arg_type) - ) -} - -pub fn is_covariance_support_arg_type(arg_type: &DataType) -> bool { - matches!( - arg_type, - arg_type if NUMERICS.contains(arg_type) - ) -} - -pub fn is_stddev_support_arg_type(arg_type: &DataType) -> bool { - matches!( - arg_type, - arg_type if NUMERICS.contains(arg_type) - ) -} - -pub fn is_correlation_support_arg_type(arg_type: &DataType) -> bool { - matches!( - arg_type, - arg_type if NUMERICS.contains(arg_type) - ) -} - -pub fn is_integer_arg_type(arg_type: &DataType) -> bool { - arg_type.is_integer() -} - -/// Return `true` if `arg_type` is of a [`DataType`] that the -/// [`AggregateFunction::ApproxPercentileCont`] aggregation can operate on. -pub fn is_approx_percentile_cont_supported_arg_type(arg_type: &DataType) -> bool { - matches!( - arg_type, - arg_type if NUMERICS.contains(arg_type) - ) -} - -/// Return `true` if `arg_type` is of a [`DataType`] that the -/// [`AggregateFunction::StringAgg`] aggregation can operate on. -pub fn is_string_agg_supported_arg_type(arg_type: &DataType) -> bool { - matches!( - arg_type, - DataType::Utf8 | DataType::LargeUtf8 | DataType::Null - ) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_aggregate_coerce_types() { - // test input args with error number input types - let fun = AggregateFunction::Min; - let input_types = vec![DataType::Int64, DataType::Int32]; - let signature = fun.signature(); - let result = coerce_types(&fun, &input_types, &signature); - assert_eq!("Error during planning: The function MIN expects 1 arguments, but 2 were provided", result.unwrap_err().strip_backtrace()); - - // test input args is invalid data type for sum or avg - let fun = AggregateFunction::Sum; - let input_types = vec![DataType::Utf8]; - let signature = fun.signature(); - let result = coerce_types(&fun, &input_types, &signature); - assert_eq!( - "Error during planning: The function Sum does not support inputs of type Utf8.", - result.unwrap_err().strip_backtrace() - ); - let fun = AggregateFunction::Avg; - let signature = fun.signature(); - let result = coerce_types(&fun, &input_types, &signature); - assert_eq!( - "Error during planning: The function Avg does not support inputs of type Utf8.", - result.unwrap_err().strip_backtrace() - ); - - // test count, array_agg, approx_distinct, min, max. - // the coerced types is same with input types - let funs = vec![ - AggregateFunction::Count, - AggregateFunction::ArrayAgg, - AggregateFunction::ApproxDistinct, - AggregateFunction::Min, - AggregateFunction::Max, - ]; - let input_types = vec![ - vec![DataType::Int32], - vec![DataType::Decimal128(10, 2)], - vec![DataType::Decimal256(1, 1)], - vec![DataType::Utf8], - ]; - for fun in funs { - for input_type in &input_types { - let signature = fun.signature(); - let result = coerce_types(&fun, input_type, &signature); - assert_eq!(*input_type, result.unwrap()); - } - } - // test sum - let fun = AggregateFunction::Sum; - let signature = fun.signature(); - let r = coerce_types(&fun, &[DataType::Int32], &signature).unwrap(); - assert_eq!(r[0], DataType::Int64); - let r = coerce_types(&fun, &[DataType::Float32], &signature).unwrap(); - assert_eq!(r[0], DataType::Float64); - let r = coerce_types(&fun, &[DataType::Decimal128(20, 3)], &signature).unwrap(); - assert_eq!(r[0], DataType::Decimal128(20, 3)); - let r = coerce_types(&fun, &[DataType::Decimal256(20, 3)], &signature).unwrap(); - assert_eq!(r[0], DataType::Decimal256(20, 3)); - - // test avg - let fun = AggregateFunction::Avg; - let signature = fun.signature(); - let r = coerce_types(&fun, &[DataType::Int32], &signature).unwrap(); - assert_eq!(r[0], DataType::Float64); - let r = coerce_types(&fun, &[DataType::Float32], &signature).unwrap(); - assert_eq!(r[0], DataType::Float64); - let r = coerce_types(&fun, &[DataType::Decimal128(20, 3)], &signature).unwrap(); - assert_eq!(r[0], DataType::Decimal128(20, 3)); - let r = coerce_types(&fun, &[DataType::Decimal256(20, 3)], &signature).unwrap(); - assert_eq!(r[0], DataType::Decimal256(20, 3)); - - // ApproxPercentileCont input types - let input_types = vec![ - vec![DataType::Int8, DataType::Float64], - vec![DataType::Int16, DataType::Float64], - vec![DataType::Int32, DataType::Float64], - vec![DataType::Int64, DataType::Float64], - vec![DataType::UInt8, DataType::Float64], - vec![DataType::UInt16, DataType::Float64], - vec![DataType::UInt32, DataType::Float64], - vec![DataType::UInt64, DataType::Float64], - vec![DataType::Float32, DataType::Float64], - vec![DataType::Float64, DataType::Float64], - ]; - for input_type in &input_types { - let signature = AggregateFunction::ApproxPercentileCont.signature(); - let result = coerce_types( - &AggregateFunction::ApproxPercentileCont, - input_type, - &signature, - ); - assert_eq!(*input_type, result.unwrap()); - } - } - - #[test] - fn test_avg_return_data_type() -> Result<()> { - let data_type = DataType::Decimal128(10, 5); - let result_type = avg_return_type(&data_type)?; - assert_eq!(DataType::Decimal128(14, 9), result_type); - - let data_type = DataType::Decimal128(36, 10); - let result_type = avg_return_type(&data_type)?; - assert_eq!(DataType::Decimal128(38, 14), result_type); - Ok(()) - } - - #[test] - fn test_variance_return_data_type() -> Result<()> { - let data_type = DataType::Float64; - let result_type = variance_return_type(&data_type)?; - assert_eq!(DataType::Float64, result_type); - - let data_type = DataType::Decimal128(36, 10); - assert!(variance_return_type(&data_type).is_err()); - Ok(()) - } - - #[test] - fn test_sum_return_data_type() -> Result<()> { - let data_type = DataType::Decimal128(10, 5); - let result_type = sum_return_type(&data_type)?; - assert_eq!(DataType::Decimal128(20, 5), result_type); - - let data_type = DataType::Decimal128(36, 10); - let result_type = sum_return_type(&data_type)?; - assert_eq!(DataType::Decimal128(38, 10), result_type); - Ok(()) - } - - #[test] - fn test_stddev_return_data_type() -> Result<()> { - let data_type = DataType::Float64; - let result_type = stddev_return_type(&data_type)?; - assert_eq!(DataType::Float64, result_type); - - let data_type = DataType::Decimal128(36, 10); - assert!(stddev_return_type(&data_type).is_err()); - Ok(()) - } - - #[test] - fn test_covariance_return_data_type() -> Result<()> { - let data_type = DataType::Float64; - let result_type = covariance_return_type(&data_type)?; - assert_eq!(DataType::Float64, result_type); - - let data_type = DataType::Decimal128(36, 10); - assert!(covariance_return_type(&data_type).is_err()); - Ok(()) - } - - #[test] - fn test_correlation_return_data_type() -> Result<()> { - let data_type = DataType::Float64; - let result_type = correlation_return_type(&data_type)?; - assert_eq!(DataType::Float64, result_type); - - let data_type = DataType::Decimal128(36, 10); - assert!(correlation_return_type(&data_type).is_err()); - Ok(()) - } -} diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 07516c1f6f53..85f8e20ba4a5 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -15,20 +15,125 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; - -use crate::signature::{ - ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD, -}; -use crate::{Signature, TypeSignature}; +use super::binary::{binary_numeric_coercion, comparison_coercion}; +use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF}; use arrow::{ compute::can_cast_types, datatypes::{DataType, TimeUnit}, }; -use datafusion_common::utils::{coerced_fixed_size_list_to_list, list_ndims}; -use datafusion_common::{internal_datafusion_err, internal_err, plan_err, Result}; +use datafusion_common::{ + exec_err, internal_datafusion_err, internal_err, plan_err, + utils::{coerced_fixed_size_list_to_list, list_ndims}, + Result, +}; +use datafusion_expr_common::{ + signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD}, + type_coercion::binary::string_coercion, +}; +use std::sync::Arc; + +/// Performs type coercion for scalar function arguments. +/// +/// Returns the data types to which each argument must be coerced to +/// match `signature`. +/// +/// For more details on coercion in general, please see the +/// [`type_coercion`](crate::type_coercion) module. +pub fn data_types_with_scalar_udf( + current_types: &[DataType], + func: &ScalarUDF, +) -> Result> { + let signature = func.signature(); + + if current_types.is_empty() { + if signature.type_signature.supports_zero_argument() { + return Ok(vec![]); + } else { + return plan_err!("{} does not support zero arguments.", func.name()); + } + } + + let valid_types = + get_valid_types_with_scalar_udf(&signature.type_signature, current_types, func)?; + + if valid_types + .iter() + .any(|data_type| data_type == current_types) + { + return Ok(current_types.to_vec()); + } + + try_coerce_types(valid_types, current_types, &signature.type_signature) +} + +/// Performs type coercion for aggregate function arguments. +/// +/// Returns the data types to which each argument must be coerced to +/// match `signature`. +/// +/// For more details on coercion in general, please see the +/// [`type_coercion`](crate::type_coercion) module. +pub fn data_types_with_aggregate_udf( + current_types: &[DataType], + func: &AggregateUDF, +) -> Result> { + let signature = func.signature(); + + if current_types.is_empty() { + if signature.type_signature.supports_zero_argument() { + return Ok(vec![]); + } else { + return plan_err!("{} does not support zero arguments.", func.name()); + } + } + + let valid_types = get_valid_types_with_aggregate_udf( + &signature.type_signature, + current_types, + func, + )?; + if valid_types + .iter() + .any(|data_type| data_type == current_types) + { + return Ok(current_types.to_vec()); + } + + try_coerce_types(valid_types, current_types, &signature.type_signature) +} + +/// Performs type coercion for window function arguments. +/// +/// Returns the data types to which each argument must be coerced to +/// match `signature`. +/// +/// For more details on coercion in general, please see the +/// [`type_coercion`](crate::type_coercion) module. +pub fn data_types_with_window_udf( + current_types: &[DataType], + func: &WindowUDF, +) -> Result> { + let signature = func.signature(); -use super::binary::{comparison_binary_numeric_coercion, comparison_coercion}; + if current_types.is_empty() { + if signature.type_signature.supports_zero_argument() { + return Ok(vec![]); + } else { + return plan_err!("{} does not support zero arguments.", func.name()); + } + } + + let valid_types = + get_valid_types_with_window_udf(&signature.type_signature, current_types, func)?; + if valid_types + .iter() + .any(|data_type| data_type == current_types) + { + return Ok(current_types.to_vec()); + } + + try_coerce_types(valid_types, current_types, &signature.type_signature) +} /// Performs type coercion for function arguments. /// @@ -46,14 +151,13 @@ pub fn data_types( return Ok(vec![]); } else { return plan_err!( - "Coercion from {:?} to the signature {:?} failed.", - current_types, + "signature {:?} does not support zero arguments.", &signature.type_signature ); } } - let valid_types = get_valid_types(&signature.type_signature, current_types)?; + let valid_types = get_valid_types(&signature.type_signature, current_types)?; if valid_types .iter() .any(|data_type| data_type == current_types) @@ -61,11 +165,46 @@ pub fn data_types( return Ok(current_types.to_vec()); } - // Try and coerce the argument types to match the signature, returning the - // coerced types from the first matching signature. - for valid_types in valid_types { - if let Some(types) = maybe_data_types(&valid_types, current_types) { - return Ok(types); + try_coerce_types(valid_types, current_types, &signature.type_signature) +} + +fn is_well_supported_signature(type_signature: &TypeSignature) -> bool { + if let TypeSignature::OneOf(signatures) = type_signature { + return signatures.iter().all(is_well_supported_signature); + } + + matches!( + type_signature, + TypeSignature::UserDefined + | TypeSignature::Numeric(_) + | TypeSignature::String(_) + | TypeSignature::Coercible(_) + | TypeSignature::Any(_) + ) +} + +fn try_coerce_types( + valid_types: Vec>, + current_types: &[DataType], + type_signature: &TypeSignature, +) -> Result> { + let mut valid_types = valid_types; + + // Well-supported signature that returns exact valid types. + if !valid_types.is_empty() && is_well_supported_signature(type_signature) { + // exact valid types + assert_eq!(valid_types.len(), 1); + let valid_types = valid_types.swap_remove(0); + if let Some(t) = maybe_data_types_without_coercion(&valid_types, current_types) { + return Ok(t); + } + } else { + // Try and coerce the argument types to match the signature, returning the + // coerced types from the first matching signature. + for valid_types in valid_types { + if let Some(types) = maybe_data_types(&valid_types, current_types) { + return Ok(types); + } } } @@ -73,10 +212,92 @@ pub fn data_types( plan_err!( "Coercion from {:?} to the signature {:?} failed.", current_types, - &signature.type_signature + type_signature ) } +fn get_valid_types_with_scalar_udf( + signature: &TypeSignature, + current_types: &[DataType], + func: &ScalarUDF, +) -> Result>> { + match signature { + TypeSignature::UserDefined => match func.coerce_types(current_types) { + Ok(coerced_types) => Ok(vec![coerced_types]), + Err(e) => exec_err!("User-defined coercion failed with {:?}", e), + }, + TypeSignature::OneOf(signatures) => { + let mut res = vec![]; + let mut errors = vec![]; + for sig in signatures { + match get_valid_types_with_scalar_udf(sig, current_types, func) { + Ok(valid_types) => { + res.extend(valid_types); + } + Err(e) => { + errors.push(e.to_string()); + } + } + } + + // Every signature failed, return the joined error + if res.is_empty() { + internal_err!( + "Failed to match any signature, errors: {}", + errors.join(",") + ) + } else { + Ok(res) + } + } + _ => get_valid_types(signature, current_types), + } +} + +fn get_valid_types_with_aggregate_udf( + signature: &TypeSignature, + current_types: &[DataType], + func: &AggregateUDF, +) -> Result>> { + let valid_types = match signature { + TypeSignature::UserDefined => match func.coerce_types(current_types) { + Ok(coerced_types) => vec![coerced_types], + Err(e) => return exec_err!("User-defined coercion failed with {:?}", e), + }, + TypeSignature::OneOf(signatures) => signatures + .iter() + .filter_map(|t| { + get_valid_types_with_aggregate_udf(t, current_types, func).ok() + }) + .flatten() + .collect::>(), + _ => get_valid_types(signature, current_types)?, + }; + + Ok(valid_types) +} + +fn get_valid_types_with_window_udf( + signature: &TypeSignature, + current_types: &[DataType], + func: &WindowUDF, +) -> Result>> { + let valid_types = match signature { + TypeSignature::UserDefined => match func.coerce_types(current_types) { + Ok(coerced_types) => vec![coerced_types], + Err(e) => return exec_err!("User-defined coercion failed with {:?}", e), + }, + TypeSignature::OneOf(signatures) => signatures + .iter() + .filter_map(|t| get_valid_types_with_window_udf(t, current_types, func).ok()) + .flatten() + .collect::>(), + _ => get_valid_types(signature, current_types)?, + }; + + Ok(valid_types) +} + /// Returns a Vec of all possible valid argument types for the given signature. fn get_valid_types( signature: &TypeSignature, @@ -179,36 +400,133 @@ fn get_valid_types( .iter() .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect()) .collect(), + TypeSignature::String(number) => { + if *number < 1 { + return plan_err!( + "The signature expected at least one argument but received {}", + current_types.len() + ); + } + if *number != current_types.len() { + return plan_err!( + "The signature expected {} arguments but received {}", + number, + current_types.len() + ); + } + + fn coercion_rule( + lhs_type: &DataType, + rhs_type: &DataType, + ) -> Result { + match (lhs_type, rhs_type) { + (DataType::Null, DataType::Null) => Ok(DataType::Utf8), + (DataType::Null, data_type) | (data_type, DataType::Null) => { + coercion_rule(data_type, &DataType::Utf8) + } + (DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => { + coercion_rule(lhs, rhs) + } + (DataType::Dictionary(_, v), other) + | (other, DataType::Dictionary(_, v)) => coercion_rule(v, other), + _ => { + if let Some(coerced_type) = string_coercion(lhs_type, rhs_type) { + Ok(coerced_type) + } else { + plan_err!( + "{} and {} are not coercible to a common string type", + lhs_type, + rhs_type + ) + } + } + } + } + + // Length checked above, safe to unwrap + let mut coerced_type = current_types.first().unwrap().to_owned(); + for t in current_types.iter().skip(1) { + coerced_type = coercion_rule(&coerced_type, t)?; + } + + fn base_type_or_default_type(data_type: &DataType) -> DataType { + if data_type.is_null() { + DataType::Utf8 + } else if let DataType::Dictionary(_, v) = data_type { + base_type_or_default_type(v) + } else { + data_type.to_owned() + } + } + + vec![vec![base_type_or_default_type(&coerced_type); *number]] + } + TypeSignature::Numeric(number) => { + if *number < 1 { + return plan_err!( + "The signature expected at least one argument but received {}", + current_types.len() + ); + } + if *number != current_types.len() { + return plan_err!( + "The signature expected {} arguments but received {}", + number, + current_types.len() + ); + } + + let mut valid_type = current_types.first().unwrap().clone(); + for t in current_types.iter().skip(1) { + if let Some(coerced_type) = binary_numeric_coercion(&valid_type, t) { + valid_type = coerced_type; + } else { + return plan_err!( + "{} and {} are not coercible to a common numeric type", + valid_type, + t + ); + } + } + + vec![vec![valid_type; *number]] + } + TypeSignature::Coercible(target_types) => { + if target_types.is_empty() { + return plan_err!( + "The signature expected at least one argument but received {}", + current_types.len() + ); + } + if target_types.len() != current_types.len() { + return plan_err!( + "The signature expected {} arguments but received {}", + target_types.len(), + current_types.len() + ); + } + + for (data_type, target_type) in current_types.iter().zip(target_types.iter()) + { + if !can_cast_types(data_type, target_type) { + return plan_err!("{data_type} is not coercible to {target_type}"); + } + } + + vec![target_types.to_owned()] + } TypeSignature::Uniform(number, valid_types) => valid_types .iter() .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) .collect(), - TypeSignature::VariadicEqual => { - let new_type = current_types.iter().skip(1).try_fold( - current_types.first().unwrap().clone(), - |acc, x| { - // The coerced types found by `comparison_coercion` are not guaranteed to be - // coercible for the arguments. `comparison_coercion` returns more loose - // types that can be coerced to both `acc` and `x` for comparison purpose. - // See `maybe_data_types` for the actual coercion. - let coerced_type = comparison_coercion(&acc, x); - if let Some(coerced_type) = coerced_type { - Ok(coerced_type) - } else { - internal_err!("Coercion from {acc:?} to {x:?} failed.") - } - }, - ); - - match new_type { - Ok(new_type) => vec![vec![new_type; current_types.len()]], - Err(e) => return Err(e), - } + TypeSignature::UserDefined => { + return internal_err!( + "User-defined signature should be handled by function-specific coerce_types." + ) } TypeSignature::VariadicAny => { vec![current_types.to_vec()] } - TypeSignature::Exact(valid_types) => vec![valid_types.clone()], TypeSignature::ArraySignature(ref function_signature) => match function_signature { @@ -238,6 +556,16 @@ fn get_valid_types( array(¤t_types[0]) .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]]) } + ArrayFunctionSignature::MapArray => { + if current_types.len() != 1 { + return Ok(vec![vec![]]); + } + + match ¤t_types[0] { + DataType::Map(_, _) => vec![vec![current_types[0].clone()]], + _ => vec![vec![]], + } + } }, TypeSignature::Any(number) => { if current_types.len() != *number { @@ -281,6 +609,8 @@ fn maybe_data_types( new_type.push(current_type.clone()) } else { // attempt to coerce. + // TODO: Replace with `can_cast_types` after failing cases are resolved + // (they need new signature that returns exactly valid types instead of list of possible valid types). if let Some(coerced_type) = coerced_from(valid_type, current_type) { new_type.push(coerced_type) } else { @@ -292,6 +622,33 @@ fn maybe_data_types( Some(new_type) } +/// Check if the current argument types can be coerced to match the given `valid_types` +/// unlike `maybe_data_types`, this function does not coerce the types. +/// TODO: I think this function should replace `maybe_data_types` after signature are well-supported. +fn maybe_data_types_without_coercion( + valid_types: &[DataType], + current_types: &[DataType], +) -> Option> { + if valid_types.len() != current_types.len() { + return None; + } + + let mut new_type = Vec::with_capacity(valid_types.len()); + for (i, valid_type) in valid_types.iter().enumerate() { + let current_type = ¤t_types[i]; + + if current_type == valid_type { + new_type.push(current_type.clone()) + } else if can_cast_types(current_type, valid_type) { + // validate the valid type is castable from the current type + new_type.push(valid_type.clone()) + } else { + return None; + } + } + Some(new_type) +} + /// Return true if a value of type `type_from` can be coerced /// (losslessly converted) into a value of `type_to` /// @@ -306,11 +663,18 @@ pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool { false } +/// Find the coerced type for the given `type_into` and `type_from`. +/// Returns `None` if coercion is not possible. +/// +/// Expect uni-directional coercion, for example, i32 is coerced to i64, but i64 is not coerced to i32. +/// +/// Unlike [comparison_coercion], the coerced type is usually `wider` for lossless conversion. fn coerced_from<'a>( type_into: &'a DataType, type_from: &'a DataType, ) -> Option { use self::DataType::*; + // match Dictionary first match (type_into, type_from) { // coerced dictionary first @@ -325,85 +689,48 @@ fn coerced_from<'a>( Some(type_into.clone()) } // coerced into type_into - (Int8, _) if matches!(type_from, Null | Int8) => Some(type_into.clone()), - (Int16, _) if matches!(type_from, Null | Int8 | Int16 | UInt8) => { - Some(type_into.clone()) - } - (Int32, _) - if matches!(type_from, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) => - { - Some(type_into.clone()) - } - (Int64, _) - if matches!( - type_from, - Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 - ) => - { - Some(type_into.clone()) - } - (UInt8, _) if matches!(type_from, Null | UInt8) => Some(type_into.clone()), - (UInt16, _) if matches!(type_from, Null | UInt8 | UInt16) => { - Some(type_into.clone()) - } - (UInt32, _) if matches!(type_from, Null | UInt8 | UInt16 | UInt32) => { - Some(type_into.clone()) - } - (UInt64, _) if matches!(type_from, Null | UInt8 | UInt16 | UInt32 | UInt64) => { - Some(type_into.clone()) - } - (Float32, _) - if matches!( - type_from, - Null | Int8 - | Int16 - | Int32 - | Int64 - | UInt8 - | UInt16 - | UInt32 - | UInt64 - | Float32 - ) => - { - Some(type_into.clone()) - } - (Float64, _) - if matches!( - type_from, - Null | Int8 - | Int16 - | Int32 - | Int64 - | UInt8 - | UInt16 - | UInt32 - | UInt64 - | Float32 - | Float64 - | Decimal128(_, _) - ) => - { - Some(type_into.clone()) - } - (Timestamp(TimeUnit::Nanosecond, None), _) - if matches!( - type_from, - Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8 - ) => - { - Some(type_into.clone()) - } - (Interval(_), _) if matches!(type_from, Utf8 | LargeUtf8) => { + (Int8, Null | Int8) => Some(type_into.clone()), + (Int16, Null | Int8 | Int16 | UInt8) => Some(type_into.clone()), + (Int32, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) => Some(type_into.clone()), + (Int64, Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32) => { Some(type_into.clone()) } + (UInt8, Null | UInt8) => Some(type_into.clone()), + (UInt16, Null | UInt8 | UInt16) => Some(type_into.clone()), + (UInt32, Null | UInt8 | UInt16 | UInt32) => Some(type_into.clone()), + (UInt64, Null | UInt8 | UInt16 | UInt32 | UInt64) => Some(type_into.clone()), + ( + Float32, + Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 + | Float32, + ) => Some(type_into.clone()), + ( + Float64, + Null + | Int8 + | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Float32 + | Float64 + | Decimal128(_, _), + ) => Some(type_into.clone()), + ( + Timestamp(TimeUnit::Nanosecond, None), + Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8, + ) => Some(type_into.clone()), + (Interval(_), Utf8 | LargeUtf8) => Some(type_into.clone()), + // We can go into a Utf8View from a Utf8 or LargeUtf8 + (Utf8View, Utf8 | LargeUtf8 | Null) => Some(type_into.clone()), // Any type can be coerced into strings (Utf8 | LargeUtf8, _) => Some(type_into.clone()), (Null, _) if can_cast_types(type_from, type_into) => Some(type_into.clone()), - (List(_), _) if matches!(type_from, FixedSizeList(_, _)) => { - Some(type_into.clone()) - } + (List(_), FixedSizeList(_, _)) => Some(type_into.clone()), // Only accept list and largelist with the same number of dimensions unless the type is Null. // List or LargeList with different dimensions should be handled in TypeSignature or other places before this @@ -414,54 +741,34 @@ fn coerced_from<'a>( Some(type_into.clone()) } // should be able to coerce wildcard fixed size list to non wildcard fixed size list - (FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD), _) => match type_from { - FixedSizeList(f_from, size_from) => { - match coerced_from(f_into.data_type(), f_from.data_type()) { - Some(data_type) if &data_type != f_into.data_type() => { - let new_field = - Arc::new(f_into.as_ref().clone().with_data_type(data_type)); - Some(FixedSizeList(new_field, *size_from)) - } - Some(_) => Some(FixedSizeList(f_into.clone(), *size_from)), - _ => None, - } + ( + FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD), + FixedSizeList(f_from, size_from), + ) => match coerced_from(f_into.data_type(), f_from.data_type()) { + Some(data_type) if &data_type != f_into.data_type() => { + let new_field = + Arc::new(f_into.as_ref().clone().with_data_type(data_type)); + Some(FixedSizeList(new_field, *size_from)) } + Some(_) => Some(FixedSizeList(Arc::clone(f_into), *size_from)), _ => None, }, - (Timestamp(unit, Some(tz)), _) if tz.as_ref() == TIMEZONE_WILDCARD => { match type_from { Timestamp(_, Some(from_tz)) => { - Some(Timestamp(unit.clone(), Some(from_tz.clone()))) + Some(Timestamp(*unit, Some(Arc::clone(from_tz)))) } Null | Date32 | Utf8 | LargeUtf8 | Timestamp(_, None) => { // In the absence of any other information assume the time zone is "+00" (UTC). - Some(Timestamp(unit.clone(), Some("+00".into()))) + Some(Timestamp(*unit, Some("+00".into()))) } _ => None, } } - (Timestamp(_, Some(_)), _) - if matches!( - type_from, - Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8 - ) => - { + (Timestamp(_, Some(_)), Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8) => { Some(type_into.clone()) } - // More coerce rules. - // Note that not all rules in `comparison_coercion` can be reused here. - // For example, all numeric types can be coerced into Utf8 for comparison, - // but not for function arguments. - _ => comparison_binary_numeric_coercion(type_into, type_from).and_then( - |coerced_type| { - if *type_into == coerced_type { - Some(coerced_type) - } else { - None - } - }, - ), + _ => None, } } @@ -473,6 +780,18 @@ mod tests { use super::*; use arrow::datatypes::Field; + #[test] + fn test_string_conversion() { + let cases = vec![ + (DataType::Utf8View, DataType::Utf8, true), + (DataType::Utf8View, DataType::LargeUtf8, true), + ]; + + for case in cases { + assert_eq!(can_coerce_from(&case.0, &case.1), case.2); + } + } + #[test] fn test_maybe_data_types() { // this vec contains: arg1, arg2, expected result @@ -483,7 +802,7 @@ mod tests { vec![DataType::UInt8, DataType::UInt16], Some(vec![DataType::UInt8, DataType::UInt16]), ), - // 2 entries, can coerse values + // 2 entries, can coerce values ( vec![DataType::UInt16, DataType::UInt16], vec![DataType::UInt8, DataType::UInt16], @@ -552,12 +871,12 @@ mod tests { fn test_fixed_list_wildcard_coerce() -> Result<()> { let inner = Arc::new(Field::new("item", DataType::Int32, false)); let current_types = vec![ - DataType::FixedSizeList(inner.clone(), 2), // able to coerce for any size + DataType::FixedSizeList(Arc::clone(&inner), 2), // able to coerce for any size ]; let signature = Signature::exact( vec![DataType::FixedSizeList( - inner.clone(), + Arc::clone(&inner), FIXED_SIZE_LIST_WILDCARD, )], Volatility::Stable, @@ -568,7 +887,7 @@ mod tests { // make sure it can't coerce to a different size let signature = Signature::exact( - vec![DataType::FixedSizeList(inner.clone(), 3)], + vec![DataType::FixedSizeList(Arc::clone(&inner), 3)], Volatility::Stable, ); let coerced_data_types = data_types(¤t_types, &signature); @@ -576,7 +895,7 @@ mod tests { // make sure it works with the same type. let signature = Signature::exact( - vec![DataType::FixedSizeList(inner.clone(), 2)], + vec![DataType::FixedSizeList(Arc::clone(&inner), 2)], Volatility::Stable, ); let coerced_data_types = data_types(¤t_types, &signature).unwrap(); diff --git a/datafusion/expr/src/type_coercion/mod.rs b/datafusion/expr/src/type_coercion/mod.rs index 86005da3dafa..3a5c65fb46ee 100644 --- a/datafusion/expr/src/type_coercion/mod.rs +++ b/datafusion/expr/src/type_coercion/mod.rs @@ -19,7 +19,7 @@ //! //! Coercion is performed automatically by DataFusion when the types //! of arguments passed to a function or needed by operators do not -//! exacty match the types required by that function / operator. In +//! exactly match the types required by that function / operator. In //! this case, DataFusion will attempt to *coerce* the arguments to //! types accepted by the function by inserting CAST operations. //! @@ -31,11 +31,14 @@ //! i64. However, i64 -> i32 is never performed as there are i64 //! values which can not be represented by i32 values. -pub mod aggregates; -pub mod binary; +pub mod aggregates { + pub use datafusion_expr_common::type_coercion::aggregates::*; +} pub mod functions; pub mod other; +pub use datafusion_expr_common::type_coercion::binary; + use arrow::datatypes::DataType; /// Determine whether the given data type `dt` represents signed numeric values. pub fn is_signed_numeric(dt: &DataType) -> bool { diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 67c3b51ca373..dbbf88447ba3 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -17,18 +17,28 @@ //! [`AggregateUDF`]: User Defined Aggregate Functions -use crate::function::AccumulatorArgs; -use crate::groups_accumulator::GroupsAccumulator; -use crate::utils::format_state_name; -use crate::{Accumulator, Expr}; -use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature}; -use arrow::datatypes::{DataType, Field}; -use datafusion_common::{not_impl_err, Result}; use std::any::Any; +use std::cmp::Ordering; use std::fmt::{self, Debug, Formatter}; +use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; use std::vec; +use arrow::datatypes::{DataType, Field}; + +use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue, Statistics}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + +use crate::expr::AggregateFunction; +use crate::function::{ + AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs, +}; +use crate::groups_accumulator::GroupsAccumulator; +use crate::utils::format_state_name; +use crate::utils::AggregateOrderSensitivity; +use crate::{Accumulator, Expr}; +use crate::{Documentation, Signature}; + /// Logical representation of a user-defined [aggregate function] (UDAF). /// /// An aggregate function combines the values from multiple input rows @@ -48,7 +58,7 @@ use std::vec; /// 1. For simple use cases, use [`create_udaf`] (examples in [`simple_udaf.rs`]). /// /// 2. For advanced use cases, use [`AggregateUDFImpl`] which provides full API -/// access (examples in [`advanced_udaf.rs`]). +/// access (examples in [`advanced_udaf.rs`]). /// /// # API Note /// This is a separate struct from `AggregateUDFImpl` to maintain backwards @@ -60,46 +70,48 @@ use std::vec; /// [`create_udaf`]: crate::expr_fn::create_udaf /// [`simple_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs /// [`advanced_udaf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialOrd)] pub struct AggregateUDF { inner: Arc, } impl PartialEq for AggregateUDF { fn eq(&self, other: &Self) -> bool { - self.name() == other.name() && self.signature() == other.signature() + self.inner.equals(other.inner.as_ref()) } } impl Eq for AggregateUDF {} -impl std::hash::Hash for AggregateUDF { - fn hash(&self, state: &mut H) { - self.name().hash(state); - self.signature().hash(state); +impl Hash for AggregateUDF { + fn hash(&self, state: &mut H) { + self.inner.hash_value().hash(state) } } -impl AggregateUDF { - /// Create a new AggregateUDF - /// - /// See [`AggregateUDFImpl`] for a more convenient way to create a - /// `AggregateUDF` using trait objects - #[deprecated(since = "34.0.0", note = "please implement AggregateUDFImpl instead")] - pub fn new( - name: &str, - signature: &Signature, - return_type: &ReturnTypeFunction, - accumulator: &AccumulatorFactoryFunction, - ) -> Self { - Self::new_from_impl(AggregateUDFLegacyWrapper { - name: name.to_owned(), - signature: signature.clone(), - return_type: return_type.clone(), - accumulator: accumulator.clone(), - }) +impl fmt::Display for AggregateUDF { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "{}", self.name()) } +} +/// Arguments passed to [`AggregateUDFImpl::value_from_stats`] +pub struct StatisticsArgs<'a> { + /// The statistics of the aggregate input + pub statistics: &'a Statistics, + /// The resolved return type of the aggregate function + pub return_type: &'a DataType, + /// Whether the aggregate function is distinct. + /// + /// ```sql + /// SELECT COUNT(DISTINCT column1) FROM t; + /// ``` + pub is_distinct: bool, + /// The physical expression of arguments the aggregate function takes. + pub exprs: &'a [Arc], +} + +impl AggregateUDF { /// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object /// /// Note this is the same as using the `From` impl (`AggregateUDF::from`) @@ -113,8 +125,8 @@ impl AggregateUDF { } /// Return the underlying [`AggregateUDFImpl`] trait object for this function - pub fn inner(&self) -> Arc { - self.inner.clone() + pub fn inner(&self) -> &Arc { + &self.inner } /// Adds additional names that can be used to invoke this function, in @@ -122,16 +134,18 @@ impl AggregateUDF { /// /// If you implement [`AggregateUDFImpl`] directly you should return aliases directly. pub fn with_aliases(self, aliases: impl IntoIterator) -> Self { - Self::new_from_impl(AliasedAggregateUDFImpl::new(self.inner.clone(), aliases)) + Self::new_from_impl(AliasedAggregateUDFImpl::new( + Arc::clone(&self.inner), + aliases, + )) } - /// creates an [`Expr`] that calls the aggregate function. + /// Creates an [`Expr`] that calls the aggregate function. /// /// This utility allows using the UDAF without requiring access to /// the registry, such as with the DataFrame API. pub fn call(&self, args: Vec) -> Expr { - // TODO: Support dictinct, filter, order by and null_treatment - Expr::AggregateFunction(crate::expr::AggregateFunction::new_udf( + Expr::AggregateFunction(AggregateFunction::new_udf( Arc::new(self.clone()), args, false, @@ -148,6 +162,10 @@ impl AggregateUDF { self.inner.name() } + pub fn is_nullable(&self) -> bool { + self.inner.is_nullable() + } + /// Returns the aliases for this function. pub fn aliases(&self) -> &[String] { self.inner.aliases() @@ -177,23 +195,93 @@ impl AggregateUDF { /// for more details. /// /// This is used to support multi-phase aggregations - pub fn state_fields( - &self, - name: &str, - value_type: DataType, - ordering_fields: Vec, - ) -> Result> { - self.inner.state_fields(name, value_type, ordering_fields) + pub fn state_fields(&self, args: StateFieldsArgs) -> Result> { + self.inner.state_fields(args) } /// See [`AggregateUDFImpl::groups_accumulator_supported`] for more details. - pub fn groups_accumulator_supported(&self) -> bool { - self.inner.groups_accumulator_supported() + pub fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + self.inner.groups_accumulator_supported(args) } /// See [`AggregateUDFImpl::create_groups_accumulator`] for more details. - pub fn create_groups_accumulator(&self) -> Result> { - self.inner.create_groups_accumulator() + pub fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + self.inner.create_groups_accumulator(args) + } + + pub fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + self.inner.create_sliding_accumulator(args) + } + + pub fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + self.inner.coerce_types(arg_types) + } + + /// See [`AggregateUDFImpl::with_beneficial_ordering`] for more details. + pub fn with_beneficial_ordering( + self, + beneficial_ordering: bool, + ) -> Result> { + self.inner + .with_beneficial_ordering(beneficial_ordering) + .map(|updated_udf| updated_udf.map(|udf| Self { inner: udf })) + } + + /// Gets the order sensitivity of the UDF. See [`AggregateOrderSensitivity`] + /// for possible options. + pub fn order_sensitivity(&self) -> AggregateOrderSensitivity { + self.inner.order_sensitivity() + } + + /// Reserves the `AggregateUDF` (e.g. returns the `AggregateUDF` that will + /// generate same result with this `AggregateUDF` when iterated in reverse + /// order, and `None` if there is no such `AggregateUDF`). + pub fn reverse_udf(&self) -> ReversedUDAF { + self.inner.reverse_expr() + } + + /// Do the function rewrite + /// + /// See [`AggregateUDFImpl::simplify`] for more details. + pub fn simplify(&self) -> Option { + self.inner.simplify() + } + + /// Returns true if the function is max, false if the function is min + /// None in all other cases, used in certain optimizations for + /// or aggregate + pub fn is_descending(&self) -> Option { + self.inner.is_descending() + } + + /// Return the value of this aggregate function if it can be determined + /// entirely from statistics and arguments. + /// + /// See [`AggregateUDFImpl::value_from_stats`] for more details. + pub fn value_from_stats( + &self, + statistics_args: &StatisticsArgs, + ) -> Option { + self.inner.value_from_stats(statistics_args) + } + + /// See [`AggregateUDFImpl::default_value`] for more details. + pub fn default_value(&self, data_type: &DataType) -> Result { + self.inner.default_value(data_type) + } + + /// Returns the documentation for this Aggregate UDF. + /// + /// Documentation can be accessed programmatically as well as + /// generating publicly facing documentation. + pub fn documentation(&self) -> Option<&Documentation> { + self.inner.documentation() } } @@ -219,25 +307,42 @@ where /// # Basic Example /// ``` /// # use std::any::Any; +/// # use std::sync::OnceLock; /// # use arrow::datatypes::DataType; /// # use datafusion_common::{DataFusionError, plan_err, Result}; -/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr}; -/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::AccumulatorArgs}; +/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility, Expr, Documentation}; +/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator, function::{AccumulatorArgs, StateFieldsArgs}}; +/// # use datafusion_expr::window_doc_sections::DOC_SECTION_AGGREGATE; /// # use arrow::datatypes::Schema; /// # use arrow::datatypes::Field; +/// /// #[derive(Debug, Clone)] /// struct GeoMeanUdf { -/// signature: Signature -/// }; +/// signature: Signature, +/// } /// /// impl GeoMeanUdf { /// fn new() -> Self { /// Self { -/// signature: Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable) +/// signature: Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable), /// } /// } /// } /// +/// static DOCUMENTATION: OnceLock = OnceLock::new(); +/// +/// fn get_doc() -> &'static Documentation { +/// DOCUMENTATION.get_or_init(|| { +/// Documentation::builder() +/// .with_doc_section(DOC_SECTION_AGGREGATE) +/// .with_description("calculates a geometric mean") +/// .with_syntax_example("geo_mean(2.0)") +/// .with_argument("arg1", "The Float64 number for the geometric mean") +/// .build() +/// .unwrap() +/// }) +/// } +/// /// /// Implement the AggregateUDFImpl trait for GeoMeanUdf /// impl AggregateUDFImpl for GeoMeanUdf { /// fn as_any(&self) -> &dyn Any { self } @@ -245,18 +350,21 @@ where /// fn signature(&self) -> &Signature { &self.signature } /// fn return_type(&self, args: &[DataType]) -> Result { /// if !matches!(args.get(0), Some(&DataType::Float64)) { -/// return plan_err!("add_one only accepts Float64 arguments"); +/// return plan_err!("geo_mean only accepts Float64 arguments"); /// } /// Ok(DataType::Float64) /// } /// // This is the accumulator factory; DataFusion uses it to create new accumulators. /// fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { unimplemented!() } -/// fn state_fields(&self, _name: &str, value_type: DataType, _ordering_fields: Vec) -> Result> { +/// fn state_fields(&self, args: StateFieldsArgs) -> Result> { /// Ok(vec![ -/// Field::new("value", value_type, true), +/// Field::new("value", args.return_type.clone(), true), /// Field::new("ordering", DataType::UInt32, true) /// ]) /// } +/// fn documentation(&self) -> Option<&Documentation> { +/// Some(get_doc()) +/// } /// } /// /// // Create a new AggregateUDF from the implementation @@ -266,6 +374,9 @@ where /// let expr = geometric_mean.call(vec![col("a")]); /// ``` pub trait AggregateUDFImpl: Debug + Send + Sync { + // Note: When adding any methods (with default implementations), remember to add them also + // into the AliasedAggregateUDFImpl below! + /// Returns this object as an [`Any`] trait object fn as_any(&self) -> &dyn Any; @@ -280,6 +391,16 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// the arguments fn return_type(&self, arg_types: &[DataType]) -> Result; + /// Whether the aggregate function is nullable. + /// + /// Nullable means that that the function could return `null` for any inputs. + /// For example, aggregate functions like `COUNT` always return a non null value + /// but others like `MIN` will return `NULL` if there is nullable input. + /// Note that if the function is declared as *not* nullable, make sure the [`AggregateUDFImpl::default_value`] is `non-null` + fn is_nullable(&self) -> bool { + true + } + /// Return a new [`Accumulator`] that aggregates values for a specific /// group during query execution. /// @@ -289,11 +410,10 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// Return the fields used to store the intermediate state of this accumulator. /// - /// # Arguments: - /// 1. `name`: the name of the expression (e.g. AVG, SUM, etc) - /// 2. `value_type`: Aggregate's aggregate's output (returned by [`Self::return_type`]) - /// 3. `ordering_fields`: the fields used to order the input arguments, if any. - /// Empty if no ordering expression is provided. + /// See [`Accumulator::state`] for background information. + /// + /// args: [`StateFieldsArgs`] contains arguments passed to the + /// aggregate function's accumulator. /// /// # Notes: /// @@ -309,19 +429,17 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// The name of the fields must be unique within the query and thus should /// be derived from `name`. See [`format_state_name`] for a utility function /// to generate a unique name. - fn state_fields( - &self, - name: &str, - value_type: DataType, - ordering_fields: Vec, - ) -> Result> { - let value_fields = vec![Field::new( - format_state_name(name, "value"), - value_type, + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let fields = vec![Field::new( + format_state_name(args.name, "value"), + args.return_type.clone(), true, )]; - Ok(value_fields.into_iter().chain(ordering_fields).collect()) + Ok(fields + .into_iter() + .chain(args.ordering_fields.to_vec()) + .collect()) } /// If the aggregate expression has a specialized @@ -331,10 +449,10 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// # Notes /// /// Even if this function returns true, DataFusion will still use - /// `Self::accumulator` for certain queries, such as when this aggregate is + /// [`Self::accumulator`] for certain queries, such as when this aggregate is /// used as a window function or when there no GROUP BY columns in the /// query. - fn groups_accumulator_supported(&self) -> bool { + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { false } @@ -343,7 +461,10 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// /// For maximum performance, a [`GroupsAccumulator`] should be /// implemented in addition to [`Accumulator`]. - fn create_groups_accumulator(&self) -> Result> { + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet") } @@ -354,6 +475,197 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { fn aliases(&self) -> &[String] { &[] } + + /// Sliding accumulator is an alternative accumulator that can be used for + /// window functions. It has retract method to revert the previous update. + /// + /// See [retract_batch] for more details. + /// + /// [retract_batch]: datafusion_expr_common::accumulator::Accumulator::retract_batch + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + self.accumulator(args) + } + + /// Sets the indicator whether ordering requirements of the AggregateUDFImpl is + /// satisfied by its input. If this is not the case, UDFs with order + /// sensitivity `AggregateOrderSensitivity::Beneficial` can still produce + /// the correct result with possibly more work internally. + /// + /// # Returns + /// + /// Returns `Ok(Some(updated_udf))` if the process completes successfully. + /// If the expression can benefit from existing input ordering, but does + /// not implement the method, returns an error. Order insensitive and hard + /// requirement aggregators return `Ok(None)`. + fn with_beneficial_ordering( + self: Arc, + _beneficial_ordering: bool, + ) -> Result>> { + if self.order_sensitivity().is_beneficial() { + return exec_err!( + "Should implement with satisfied for aggregator :{:?}", + self.name() + ); + } + Ok(None) + } + + /// Gets the order sensitivity of the UDF. See [`AggregateOrderSensitivity`] + /// for possible options. + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + // We have hard ordering requirements by default, meaning that order + // sensitive UDFs need their input orderings to satisfy their ordering + // requirements to generate correct results. + AggregateOrderSensitivity::HardRequirement + } + + /// Optionally apply per-UDaF simplification / rewrite rules. + /// + /// This can be used to apply function specific simplification rules during + /// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default + /// implementation does nothing. + /// + /// Note that DataFusion handles simplifying arguments and "constant + /// folding" (replacing a function call with constant arguments such as + /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such + /// optimizations manually for specific UDFs. + /// + /// # Returns + /// + /// [None] if simplify is not defined or, + /// + /// Or, a closure with two arguments: + /// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplified has been invoked + /// * 'info': [crate::simplify::SimplifyInfo] + /// + /// closure returns simplified [Expr] or an error. + /// + fn simplify(&self) -> Option { + None + } + + /// Returns the reverse expression of the aggregate function. + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::NotSupported + } + + /// Coerce arguments of a function call to types that the function can evaluate. + /// + /// This function is only called if [`AggregateUDFImpl::signature`] returns [`crate::TypeSignature::UserDefined`]. Most + /// UDAFs should return one of the other variants of `TypeSignature` which handle common + /// cases + /// + /// See the [type coercion module](crate::type_coercion) + /// documentation for more details on type coercion + /// + /// For example, if your function requires a floating point arguments, but the user calls + /// it like `my_func(1::int)` (aka with `1` as an integer), coerce_types could return `[DataType::Float64]` + /// to ensure the argument was cast to `1::double` + /// + /// # Parameters + /// * `arg_types`: The argument types of the arguments this function with + /// + /// # Return value + /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call + /// arguments to these specific types. + fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { + not_impl_err!("Function {} does not implement coerce_types", self.name()) + } + + /// Return true if this aggregate UDF is equal to the other. + /// + /// Allows customizing the equality of aggregate UDFs. + /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]: + /// + /// - reflexive: `a.equals(a)`; + /// - symmetric: `a.equals(b)` implies `b.equals(a)`; + /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`. + /// + /// By default, compares [`Self::name`] and [`Self::signature`]. + fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { + self.name() == other.name() && self.signature() == other.signature() + } + + /// Returns a hash value for this aggregate UDF. + /// + /// Allows customizing the hash code of aggregate UDFs. Similarly to [`Hash`] and [`Eq`], + /// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same. + /// + /// By default, hashes [`Self::name`] and [`Self::signature`]. + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.name().hash(hasher); + self.signature().hash(hasher); + hasher.finish() + } + + /// If this function is max, return true + /// If the function is min, return false + /// Otherwise return None (the default) + /// + /// + /// Note: this is used to use special aggregate implementations in certain conditions + fn is_descending(&self) -> Option { + None + } + + /// Return the value of this aggregate function if it can be determined + /// entirely from statistics and arguments. + /// + /// Using a [`ScalarValue`] rather than a runtime computation can significantly + /// improving query performance. + /// + /// For example, if the minimum value of column `x` is known to be `42` from + /// statistics, then the aggregate `MIN(x)` should return `Some(ScalarValue(42))` + fn value_from_stats(&self, _statistics_args: &StatisticsArgs) -> Option { + None + } + + /// Returns default value of the function given the input is all `null`. + /// + /// Most of the aggregate function return Null if input is Null, + /// while `count` returns 0 if input is Null + fn default_value(&self, data_type: &DataType) -> Result { + ScalarValue::try_from(data_type) + } + + /// Returns the documentation for this Aggregate UDF. + /// + /// Documentation can be accessed programmatically as well as + /// generating publicly facing documentation. + fn documentation(&self) -> Option<&Documentation> { + None + } +} + +impl PartialEq for dyn AggregateUDFImpl { + fn eq(&self, other: &Self) -> bool { + self.equals(other) + } +} + +// Manual implementation of `PartialOrd` +// There might be some wackiness with it, but this is based on the impl of eq for AggregateUDFImpl +// https://users.rust-lang.org/t/how-to-compare-two-trait-objects-for-equality/88063/5 +impl PartialOrd for dyn AggregateUDFImpl { + fn partial_cmp(&self, other: &Self) -> Option { + match self.name().partial_cmp(other.name()) { + Some(Ordering::Equal) => self.signature().partial_cmp(other.signature()), + cmp => cmp, + } + } +} + +pub enum ReversedUDAF { + /// The expression is the same as the original expression, like SUM, COUNT + Identical, + /// The expression does not support reverse calculation + NotSupported, + /// The expression is different from the original expression + Reversed(Arc), } /// AggregateUDF that adds an alias to the underlying function. It is better to @@ -400,51 +712,220 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl { fn aliases(&self) -> &[String] { &self.aliases } -} -/// Implementation of [`AggregateUDFImpl`] that wraps the function style pointers -/// of the older API -pub struct AggregateUDFLegacyWrapper { - /// name - name: String, - /// Signature (input arguments) - signature: Signature, - /// Return type - return_type: ReturnTypeFunction, - /// actual implementation - accumulator: AccumulatorFactoryFunction, -} + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + self.inner.state_fields(args) + } -impl Debug for AggregateUDFLegacyWrapper { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("AggregateUDF") - .field("name", &self.name) - .field("signature", &self.signature) - .field("fun", &"") - .finish() + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + self.inner.groups_accumulator_supported(args) + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + self.inner.create_groups_accumulator(args) + } + + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + self.inner.accumulator(args) + } + + fn with_beneficial_ordering( + self: Arc, + beneficial_ordering: bool, + ) -> Result>> { + Arc::clone(&self.inner) + .with_beneficial_ordering(beneficial_ordering) + .map(|udf| { + udf.map(|udf| { + Arc::new(AliasedAggregateUDFImpl { + inner: udf, + aliases: self.aliases.clone(), + }) as Arc + }) + }) + } + + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + self.inner.order_sensitivity() + } + + fn simplify(&self) -> Option { + self.inner.simplify() + } + + fn reverse_expr(&self) -> ReversedUDAF { + self.inner.reverse_expr() + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + self.inner.coerce_types(arg_types) + } + + fn equals(&self, other: &dyn AggregateUDFImpl) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases + } else { + false + } + } + + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.inner.hash_value().hash(hasher); + self.aliases.hash(hasher); + hasher.finish() + } + + fn is_descending(&self) -> Option { + self.inner.is_descending() + } + + fn documentation(&self) -> Option<&Documentation> { + self.inner.documentation() } } -impl AggregateUDFImpl for AggregateUDFLegacyWrapper { - fn as_any(&self) -> &dyn Any { - self +// Aggregate UDF doc sections for use in public documentation +pub mod aggregate_doc_sections { + use crate::DocSection; + + pub fn doc_sections() -> Vec { + vec![ + DOC_SECTION_GENERAL, + DOC_SECTION_STATISTICAL, + DOC_SECTION_APPROXIMATE, + ] + } + + pub const DOC_SECTION_GENERAL: DocSection = DocSection { + include: true, + label: "General Functions", + description: None, + }; + + pub const DOC_SECTION_STATISTICAL: DocSection = DocSection { + include: true, + label: "Statistical Functions", + description: None, + }; + + pub const DOC_SECTION_APPROXIMATE: DocSection = DocSection { + include: true, + label: "Approximate Functions", + description: None, + }; +} + +#[cfg(test)] +mod test { + use crate::{AggregateUDF, AggregateUDFImpl}; + use arrow::datatypes::{DataType, Field}; + use datafusion_common::Result; + use datafusion_expr_common::accumulator::Accumulator; + use datafusion_expr_common::signature::{Signature, Volatility}; + use datafusion_functions_aggregate_common::accumulator::{ + AccumulatorArgs, StateFieldsArgs, + }; + use std::any::Any; + use std::cmp::Ordering; + + #[derive(Debug, Clone)] + struct AMeanUdf { + signature: Signature, + } + + impl AMeanUdf { + fn new() -> Self { + Self { + signature: Signature::uniform( + 1, + vec![DataType::Float64], + Volatility::Immutable, + ), + } + } } - fn name(&self) -> &str { - &self.name + impl AggregateUDFImpl for AMeanUdf { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "a" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, _args: &[DataType]) -> Result { + unimplemented!() + } + fn accumulator( + &self, + _acc_args: AccumulatorArgs, + ) -> Result> { + unimplemented!() + } + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + unimplemented!() + } } - fn signature(&self) -> &Signature { - &self.signature + #[derive(Debug, Clone)] + struct BMeanUdf { + signature: Signature, + } + impl BMeanUdf { + fn new() -> Self { + Self { + signature: Signature::uniform( + 1, + vec![DataType::Float64], + Volatility::Immutable, + ), + } + } } - fn return_type(&self, arg_types: &[DataType]) -> Result { - // Old API returns an Arc of the datatype for some reason - let res = (self.return_type)(arg_types)?; - Ok(res.as_ref().clone()) + impl AggregateUDFImpl for BMeanUdf { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "b" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, _args: &[DataType]) -> Result { + unimplemented!() + } + fn accumulator( + &self, + _acc_args: AccumulatorArgs, + ) -> Result> { + unimplemented!() + } + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + unimplemented!() + } } - fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - (self.accumulator)(acc_args) + #[test] + fn test_partial_ord() { + // Test validates that partial ord is defined for AggregateUDF using the name and signature, + // not intended to exhaustively test all possibilities + let a1 = AggregateUDF::from(AMeanUdf::new()); + let a2 = AggregateUDF::from(AMeanUdf::new()); + assert_eq!(a1.partial_cmp(&a2), Some(Ordering::Equal)); + + let b1 = AggregateUDF::from(BMeanUdf::new()); + assert!(a1 < b1); + assert!(!(a1 == b1)); } } diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 4557fe60a447..003a3ed36a60 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -17,17 +17,19 @@ //! [`ScalarUDF`]: Scalar User Defined Functions +use crate::expr::schema_name_from_exprs_comma_seperated_without_space; use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; +use crate::sort_properties::{ExprProperties, SortProperties}; use crate::{ - ColumnarValue, Expr, FuncMonotonicity, ReturnTypeFunction, - ScalarFunctionImplementation, Signature, + ColumnarValue, Documentation, Expr, ScalarFunctionImplementation, Signature, }; use arrow::datatypes::DataType; -use datafusion_common::{ExprSchema, Result}; +use datafusion_common::{not_impl_err, ExprSchema, Result}; +use datafusion_expr_common::interval_arithmetic::Interval; use std::any::Any; -use std::fmt; +use std::cmp::Ordering; use std::fmt::Debug; -use std::fmt::Formatter; +use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; /// Logical representation of a Scalar User Defined Function. @@ -40,7 +42,9 @@ use std::sync::Arc; /// 1. For simple use cases, use [`create_udf`] (examples in [`simple_udf.rs`]). /// /// 2. For advanced use cases, use [`ScalarUDFImpl`] which provides full API -/// access (examples in [`advanced_udf.rs`]). +/// access (examples in [`advanced_udf.rs`]). +/// +/// See [`Self::call`] to invoke a `ScalarUDF` with arguments. /// /// # API Note /// @@ -57,39 +61,29 @@ pub struct ScalarUDF { impl PartialEq for ScalarUDF { fn eq(&self, other: &Self) -> bool { - self.name() == other.name() && self.signature() == other.signature() + self.inner.equals(other.inner.as_ref()) + } +} + +// Manual implementation based on `ScalarUDFImpl::equals` +impl PartialOrd for ScalarUDF { + fn partial_cmp(&self, other: &Self) -> Option { + match self.name().partial_cmp(other.name()) { + Some(Ordering::Equal) => self.signature().partial_cmp(other.signature()), + cmp => cmp, + } } } impl Eq for ScalarUDF {} -impl std::hash::Hash for ScalarUDF { - fn hash(&self, state: &mut H) { - self.name().hash(state); - self.signature().hash(state); +impl Hash for ScalarUDF { + fn hash(&self, state: &mut H) { + self.inner.hash_value().hash(state) } } impl ScalarUDF { - /// Create a new ScalarUDF from low level details. - /// - /// See [`ScalarUDFImpl`] for a more convenient way to create a - /// `ScalarUDF` using trait objects - #[deprecated(since = "34.0.0", note = "please implement ScalarUDFImpl instead")] - pub fn new( - name: &str, - signature: &Signature, - return_type: &ReturnTypeFunction, - fun: &ScalarFunctionImplementation, - ) -> Self { - Self::new_from_impl(ScalarUdfLegacyWrapper { - name: name.to_owned(), - signature: signature.clone(), - return_type: return_type.clone(), - fun: fun.clone(), - }) - } - /// Create a new `ScalarUDF` from a `[ScalarUDFImpl]` trait object /// /// Note this is the same as using the `From` impl (`ScalarUDF::from`) @@ -103,8 +97,8 @@ impl ScalarUDF { } /// Return the underlying [`ScalarUDFImpl`] trait object for this function - pub fn inner(&self) -> Arc { - self.inner.clone() + pub fn inner(&self) -> &Arc { + &self.inner } /// Adds additional names that can be used to invoke this function, in @@ -112,13 +106,22 @@ impl ScalarUDF { /// /// If you implement [`ScalarUDFImpl`] directly you should return aliases directly. pub fn with_aliases(self, aliases: impl IntoIterator) -> Self { - Self::new_from_impl(AliasedScalarUDFImpl::new(self.inner.clone(), aliases)) + Self::new_from_impl(AliasedScalarUDFImpl::new(Arc::clone(&self.inner), aliases)) } /// Returns a [`Expr`] logical expression to call this UDF with specified /// arguments. /// - /// This utility allows using the UDF without requiring access to the registry. + /// This utility allows easily calling UDFs + /// + /// # Example + /// ```no_run + /// use datafusion_expr::{col, lit, ScalarUDF}; + /// # fn my_udf() -> ScalarUDF { unimplemented!() } + /// let my_func: ScalarUDF = my_udf(); + /// // Create an expr for `my_func(a, 12.3)` + /// let expr = my_func.call(vec![col("a"), lit(12.3)]); + /// ``` pub fn call(&self, args: Vec) -> Expr { Expr::ScalarFunction(crate::expr::ScalarFunction::new_udf( Arc::new(self.clone()), @@ -133,6 +136,20 @@ impl ScalarUDF { self.inner.name() } + /// Returns this function's display_name. + /// + /// See [`ScalarUDFImpl::display_name`] for more details + pub fn display_name(&self, args: &[Expr]) -> Result { + self.inner.display_name(args) + } + + /// Returns this function's schema_name. + /// + /// See [`ScalarUDFImpl::schema_name`] for more details + pub fn schema_name(&self, args: &[Expr]) -> Result { + self.inner.schema_name(args) + } + /// Returns the aliases for this function. /// /// See [`ScalarUDF::with_aliases`] for more details @@ -176,28 +193,113 @@ impl ScalarUDF { /// Invoke the function on `args`, returning the appropriate result. /// /// See [`ScalarUDFImpl::invoke`] for more details. + #[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")] pub fn invoke(&self, args: &[ColumnarValue]) -> Result { + #[allow(deprecated)] self.inner.invoke(args) } + pub fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool { + self.inner.is_nullable(args, schema) + } + + /// Invoke the function with `args` and number of rows, returning the appropriate result. + /// + /// See [`ScalarUDFImpl::invoke_batch`] for more details. + pub fn invoke_batch( + &self, + args: &[ColumnarValue], + number_rows: usize, + ) -> Result { + self.inner.invoke_batch(args, number_rows) + } + + /// Invoke the function without `args` but number of rows, returning the appropriate result. + /// + /// See [`ScalarUDFImpl::invoke_no_args`] for more details. + #[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")] + pub fn invoke_no_args(&self, number_rows: usize) -> Result { + #[allow(deprecated)] + self.inner.invoke_no_args(number_rows) + } + /// Returns a `ScalarFunctionImplementation` that can invoke the function /// during execution + #[deprecated(since = "42.0.0", note = "Use `invoke_batch` instead")] pub fn fun(&self) -> ScalarFunctionImplementation { - let captured = self.inner.clone(); + let captured = Arc::clone(&self.inner); + #[allow(deprecated)] Arc::new(move |args| captured.invoke(args)) } - /// This function specifies monotonicity behaviors for User defined scalar functions. - /// - /// See [`ScalarUDFImpl::monotonicity`] for more details. - pub fn monotonicity(&self) -> Result> { - self.inner.monotonicity() - } - /// Get the circuits of inner implementation pub fn short_circuits(&self) -> bool { self.inner.short_circuits() } + + /// Computes the output interval for a [`ScalarUDF`], given the input + /// intervals. + /// + /// # Parameters + /// + /// * `inputs` are the intervals for the inputs (children) of this function. + /// + /// # Example + /// + /// If the function is `ABS(a)`, and the input interval is `a: [-3, 2]`, + /// then the output interval would be `[0, 3]`. + pub fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result { + self.inner.evaluate_bounds(inputs) + } + + /// Updates bounds for child expressions, given a known interval for this + /// function. This is used to propagate constraints down through an expression + /// tree. + /// + /// # Parameters + /// + /// * `interval` is the currently known interval for this function. + /// * `inputs` are the current intervals for the inputs (children) of this function. + /// + /// # Returns + /// + /// A `Vec` of new intervals for the children, in order. + /// + /// If constraint propagation reveals an infeasibility for any child, returns + /// [`None`]. If none of the children intervals change as a result of + /// propagation, may return an empty vector instead of cloning `children`. + /// This is the default (and conservative) return value. + /// + /// # Example + /// + /// If the function is `ABS(a)`, the current `interval` is `[4, 5]` and the + /// input `a` is given as `[-7, 3]`, then propagation would return `[-5, 3]`. + pub fn propagate_constraints( + &self, + interval: &Interval, + inputs: &[&Interval], + ) -> Result>> { + self.inner.propagate_constraints(interval, inputs) + } + + /// Calculates the [`SortProperties`] of this function based on its + /// children's properties. + pub fn output_ordering(&self, inputs: &[ExprProperties]) -> Result { + self.inner.output_ordering(inputs) + } + + /// See [`ScalarUDFImpl::coerce_types`] for more details. + pub fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + self.inner.coerce_types(arg_types) + } + + /// Returns the documentation for this Scalar UDF. + /// + /// Documentation can be accessed programmatically as well as + /// generating publicly facing documentation. + pub fn documentation(&self) -> Option<&Documentation> { + self.inner.documentation() + } } impl From for ScalarUDF @@ -222,22 +324,39 @@ where /// # Basic Example /// ``` /// # use std::any::Any; +/// # use std::sync::OnceLock; /// # use arrow::datatypes::DataType; /// # use datafusion_common::{DataFusionError, plan_err, Result}; -/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility}; +/// # use datafusion_expr::{col, ColumnarValue, Documentation, Signature, Volatility}; /// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF}; +/// # use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; +/// /// #[derive(Debug)] /// struct AddOne { -/// signature: Signature -/// }; +/// signature: Signature, +/// } /// /// impl AddOne { /// fn new() -> Self { /// Self { -/// signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable) +/// signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable), /// } /// } /// } +/// +/// static DOCUMENTATION: OnceLock = OnceLock::new(); +/// +/// fn get_doc() -> &'static Documentation { +/// DOCUMENTATION.get_or_init(|| { +/// Documentation::builder() +/// .with_doc_section(DOC_SECTION_MATH) +/// .with_description("Add one to an int32") +/// .with_syntax_example("add_one(2)") +/// .with_argument("arg1", "The int32 number to add one to") +/// .build() +/// .unwrap() +/// }) +/// } /// /// /// Implement the ScalarUDFImpl trait for AddOne /// impl ScalarUDFImpl for AddOne { @@ -252,6 +371,9 @@ where /// } /// // The actual implementation would add one to the argument /// fn invoke(&self, args: &[ColumnarValue]) -> Result { unimplemented!() } +/// fn documentation(&self) -> Option<&Documentation> { +/// Some(get_doc()) +/// } /// } /// /// // Create a new ScalarUDF from the implementation @@ -261,12 +383,33 @@ where /// let expr = add_one.call(vec![col("a")]); /// ``` pub trait ScalarUDFImpl: Debug + Send + Sync { + // Note: When adding any methods (with default implementations), remember to add them also + // into the AliasedScalarUDFImpl below! + /// Returns this object as an [`Any`] trait object fn as_any(&self) -> &dyn Any; /// Returns this function's name fn name(&self) -> &str; + /// Returns the user-defined display name of the UDF given the arguments + fn display_name(&self, args: &[Expr]) -> Result { + let names: Vec = args.iter().map(ToString::to_string).collect(); + // TODO: join with ", " to standardize the formatting of Vec, + Ok(format!("{}({})", self.name(), names.join(","))) + } + + /// Returns the name of the column this expression would create + /// + /// See [`Expr::schema_name`] for details + fn schema_name(&self, args: &[Expr]) -> Result { + Ok(format!( + "{}({})", + self.name(), + schema_name_from_exprs_comma_seperated_without_space(args)? + )) + } + /// Returns the function's [`Signature`] for information about what input /// types are accepted and the function's Volatility. fn signature(&self) -> &Signature; @@ -317,15 +460,18 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { self.return_type(arg_types) } + fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool { + true + } + /// Invoke the function on `args`, returning the appropriate result /// /// The function will be invoked passed with the slice of [`ColumnarValue`] /// (either scalar or array). /// - /// # Zero Argument Functions - /// If the function has zero parameters (e.g. `now()`) it will be passed a - /// single element slice which is a a null array to indicate the batch's row - /// count (so the function can know the resulting array size). + /// If the function does not take any arguments, please use [invoke_no_args] + /// instead and return [not_impl_err] for this function. + /// /// /// # Performance /// @@ -335,7 +481,58 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments /// to arrays, which will likely be simpler code, but be slower. - fn invoke(&self, args: &[ColumnarValue]) -> Result; + /// + /// [invoke_no_args]: ScalarUDFImpl::invoke_no_args + #[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")] + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + not_impl_err!( + "Function {} does not implement invoke but called", + self.name() + ) + } + + /// Invoke the function with `args` and the number of rows, + /// returning the appropriate result. + /// + /// The function will be invoked with the slice of [`ColumnarValue`] + /// (either scalar or array). + /// + /// # Performance + /// + /// For the best performance, the implementations should handle the common case + /// when one or more of their arguments are constant values (aka + /// [`ColumnarValue::Scalar`]). + /// + /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments + /// to arrays, which will likely be simpler code, but be slower. + fn invoke_batch( + &self, + args: &[ColumnarValue], + number_rows: usize, + ) -> Result { + match args.is_empty() { + true => + { + #[allow(deprecated)] + self.invoke_no_args(number_rows) + } + false => + { + #[allow(deprecated)] + self.invoke(args) + } + } + } + + /// Invoke the function without `args`, instead the number of rows are provided, + /// returning the appropriate result. + #[deprecated(since = "42.1.0", note = "Use `invoke_batch` instead")] + fn invoke_no_args(&self, _number_rows: usize) -> Result { + not_impl_err!( + "Function {} does not implement invoke_no_args but called", + self.name() + ) + } /// Returns any aliases (alternate names) for this function. /// @@ -350,11 +547,6 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { &[] } - /// This function specifies monotonicity behaviors for User defined scalar functions. - fn monotonicity(&self) -> Result> { - Ok(None) - } - /// Optionally apply per-UDF simplification / rewrite rules. /// /// This can be used to apply function specific simplification rules during @@ -367,8 +559,8 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// optimizations manually for specific UDFs. /// /// # Arguments - /// * 'args': The arguments of the function - /// * 'schema': The schema of the function + /// * `args`: The arguments of the function + /// * `info`: The necessary information for simplification /// /// # Returns /// [`ExprSimplifyResult`] indicating the result of the simplification NOTE @@ -388,6 +580,116 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { fn short_circuits(&self) -> bool { false } + + /// Computes the output interval for a [`ScalarUDFImpl`], given the input + /// intervals. + /// + /// # Parameters + /// + /// * `children` are the intervals for the children (inputs) of this function. + /// + /// # Example + /// + /// If the function is `ABS(a)`, and the input interval is `a: [-3, 2]`, + /// then the output interval would be `[0, 3]`. + fn evaluate_bounds(&self, _input: &[&Interval]) -> Result { + // We cannot assume the input datatype is the same of output type. + Interval::make_unbounded(&DataType::Null) + } + + /// Updates bounds for child expressions, given a known interval for this + /// function. This is used to propagate constraints down through an expression + /// tree. + /// + /// # Parameters + /// + /// * `interval` is the currently known interval for this function. + /// * `inputs` are the current intervals for the inputs (children) of this function. + /// + /// # Returns + /// + /// A `Vec` of new intervals for the children, in order. + /// + /// If constraint propagation reveals an infeasibility for any child, returns + /// [`None`]. If none of the children intervals change as a result of + /// propagation, may return an empty vector instead of cloning `children`. + /// This is the default (and conservative) return value. + /// + /// # Example + /// + /// If the function is `ABS(a)`, the current `interval` is `[4, 5]` and the + /// input `a` is given as `[-7, 3]`, then propagation would return `[-5, 3]`. + fn propagate_constraints( + &self, + _interval: &Interval, + _inputs: &[&Interval], + ) -> Result>> { + Ok(Some(vec![])) + } + + /// Calculates the [`SortProperties`] of this function based on its + /// children's properties. + fn output_ordering(&self, _inputs: &[ExprProperties]) -> Result { + Ok(SortProperties::Unordered) + } + + /// Coerce arguments of a function call to types that the function can evaluate. + /// + /// This function is only called if [`ScalarUDFImpl::signature`] returns [`crate::TypeSignature::UserDefined`]. Most + /// UDFs should return one of the other variants of `TypeSignature` which handle common + /// cases + /// + /// See the [type coercion module](crate::type_coercion) + /// documentation for more details on type coercion + /// + /// For example, if your function requires a floating point arguments, but the user calls + /// it like `my_func(1::int)` (i.e. with `1` as an integer), coerce_types can return `[DataType::Float64]` + /// to ensure the argument is converted to `1::double` + /// + /// # Parameters + /// * `arg_types`: The argument types of the arguments this function with + /// + /// # Return value + /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call + /// arguments to these specific types. + fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { + not_impl_err!("Function {} does not implement coerce_types", self.name()) + } + + /// Return true if this scalar UDF is equal to the other. + /// + /// Allows customizing the equality of scalar UDFs. + /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]: + /// + /// - reflexive: `a.equals(a)`; + /// - symmetric: `a.equals(b)` implies `b.equals(a)`; + /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`. + /// + /// By default, compares [`Self::name`] and [`Self::signature`]. + fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { + self.name() == other.name() && self.signature() == other.signature() + } + + /// Returns a hash value for this scalar UDF. + /// + /// Allows customizing the hash code of scalar UDFs. Similarly to [`Hash`] and [`Eq`], + /// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same. + /// + /// By default, hashes [`Self::name`] and [`Self::signature`]. + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.name().hash(hasher); + self.signature().hash(hasher); + hasher.finish() + } + + /// Returns the documentation for this Scalar UDF. + /// + /// Documentation can be accessed programmatically as well as + /// generating publicly facing documentation. + fn documentation(&self) -> Option<&Documentation> { + None + } } /// ScalarUDF that adds an alias to the underlying function. It is better to @@ -405,7 +707,6 @@ impl AliasedScalarUDFImpl { ) -> Self { let mut aliases = inner.aliases().to_vec(); aliases.extend(new_aliases.into_iter().map(|s| s.to_string())); - Self { inner, aliases } } } @@ -414,10 +715,19 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { self.inner.name() } + fn display_name(&self, args: &[Expr]) -> Result { + self.inner.display_name(args) + } + + fn schema_name(&self, args: &[Expr]) -> Result { + self.inner.schema_name(args) + } + fn signature(&self) -> &Signature { self.inner.signature() } @@ -426,69 +736,169 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { self.inner.return_type(arg_types) } + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn return_type_from_exprs( + &self, + args: &[Expr], + schema: &dyn ExprSchema, + arg_types: &[DataType], + ) -> Result { + self.inner.return_type_from_exprs(args, schema, arg_types) + } + fn invoke(&self, args: &[ColumnarValue]) -> Result { + #[allow(deprecated)] self.inner.invoke(args) } - fn aliases(&self) -> &[String] { - &self.aliases + fn invoke_no_args(&self, number_rows: usize) -> Result { + #[allow(deprecated)] + self.inner.invoke_no_args(number_rows) } -} -/// Implementation of [`ScalarUDFImpl`] that wraps the function style pointers -/// of the older API (see -/// for more details) -struct ScalarUdfLegacyWrapper { - /// The name of the function - name: String, - /// The signature (the types of arguments that are supported) - signature: Signature, - /// Function that returns the return type given the argument types - return_type: ReturnTypeFunction, - /// actual implementation - /// - /// The fn param is the wrapped function but be aware that the function will - /// be passed with the slice / vec of columnar values (either scalar or array) - /// with the exception of zero param function, where a singular element vec - /// will be passed. In that case the single element is a null array to indicate - /// the batch's row count (so that the generative zero-argument function can know - /// the result array size). - fun: ScalarFunctionImplementation, -} + fn simplify( + &self, + args: Vec, + info: &dyn SimplifyInfo, + ) -> Result { + self.inner.simplify(args, info) + } -impl Debug for ScalarUdfLegacyWrapper { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("ScalarUDF") - .field("name", &self.name) - .field("signature", &self.signature) - .field("fun", &"") - .finish() + fn short_circuits(&self) -> bool { + self.inner.short_circuits() } -} -impl ScalarUDFImpl for ScalarUdfLegacyWrapper { - fn as_any(&self) -> &dyn Any { - self + fn evaluate_bounds(&self, input: &[&Interval]) -> Result { + self.inner.evaluate_bounds(input) } - fn name(&self) -> &str { - &self.name + + fn propagate_constraints( + &self, + interval: &Interval, + inputs: &[&Interval], + ) -> Result>> { + self.inner.propagate_constraints(interval, inputs) } - fn signature(&self) -> &Signature { - &self.signature + fn output_ordering(&self, inputs: &[ExprProperties]) -> Result { + self.inner.output_ordering(inputs) } - fn return_type(&self, arg_types: &[DataType]) -> Result { - // Old API returns an Arc of the datatype for some reason - let res = (self.return_type)(arg_types)?; - Ok(res.as_ref().clone()) + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + self.inner.coerce_types(arg_types) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { - (self.fun)(args) + fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases + } else { + false + } } - fn aliases(&self) -> &[String] { - &[] + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.inner.hash_value().hash(hasher); + self.aliases.hash(hasher); + hasher.finish() + } + + fn documentation(&self) -> Option<&Documentation> { + self.inner.documentation() } } + +// Scalar UDF doc sections for use in public documentation +pub mod scalar_doc_sections { + use crate::DocSection; + + pub fn doc_sections() -> Vec { + vec![ + DOC_SECTION_MATH, + DOC_SECTION_CONDITIONAL, + DOC_SECTION_STRING, + DOC_SECTION_BINARY_STRING, + DOC_SECTION_REGEX, + DOC_SECTION_DATETIME, + DOC_SECTION_ARRAY, + DOC_SECTION_STRUCT, + DOC_SECTION_MAP, + DOC_SECTION_HASHING, + DOC_SECTION_OTHER, + ] + } + + pub const DOC_SECTION_MATH: DocSection = DocSection { + include: true, + label: "Math Functions", + description: None, + }; + + pub const DOC_SECTION_CONDITIONAL: DocSection = DocSection { + include: true, + label: "Conditional Functions", + description: None, + }; + + pub const DOC_SECTION_STRING: DocSection = DocSection { + include: true, + label: "String Functions", + description: None, + }; + + pub const DOC_SECTION_BINARY_STRING: DocSection = DocSection { + include: true, + label: "Binary String Functions", + description: None, + }; + + pub const DOC_SECTION_REGEX: DocSection = DocSection { + include: true, + label: "Regular Expression Functions", + description: Some( + r#"Apache DataFusion uses a [PCRE-like](https://en.wikibooks.org/wiki/Regular_Expressions/Perl-Compatible_Regular_Expressions) +regular expression [syntax](https://docs.rs/regex/latest/regex/#syntax) +(minus support for several features including look-around and backreferences). +The following regular expression functions are supported:"#, + ), + }; + + pub const DOC_SECTION_DATETIME: DocSection = DocSection { + include: true, + label: "Time and Date Functions", + description: None, + }; + + pub const DOC_SECTION_ARRAY: DocSection = DocSection { + include: true, + label: "Array Functions", + description: None, + }; + + pub const DOC_SECTION_STRUCT: DocSection = DocSection { + include: true, + label: "Struct Functions", + description: None, + }; + + pub const DOC_SECTION_MAP: DocSection = DocSection { + include: true, + label: "Map Functions", + description: None, + }; + + pub const DOC_SECTION_HASHING: DocSection = DocSection { + include: true, + label: "Hashing Functions", + description: None, + }; + + pub const DOC_SECTION_OTHER: DocSection = DocSection { + include: true, + label: "Other Functions", + description: None, + }; +} diff --git a/datafusion/expr/src/udf_docs.rs b/datafusion/expr/src/udf_docs.rs new file mode 100644 index 000000000000..a124361e42a3 --- /dev/null +++ b/datafusion/expr/src/udf_docs.rs @@ -0,0 +1,230 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::exec_err; +use datafusion_common::Result; + +/// Documentation for use by [`ScalarUDFImpl`](crate::ScalarUDFImpl), +/// [`AggregateUDFImpl`](crate::AggregateUDFImpl) and [`WindowUDFImpl`](crate::WindowUDFImpl) functions +/// that will be used to generate public documentation. +/// +/// The name of the udf will be pulled from the [`ScalarUDFImpl::name`](crate::ScalarUDFImpl::name), +/// [`AggregateUDFImpl::name`](crate::AggregateUDFImpl::name) or [`WindowUDFImpl::name`](crate::WindowUDFImpl::name) +/// function as appropriate. +/// +/// All strings in the documentation are required to be +/// in [markdown format](https://www.markdownguide.org/basic-syntax/). +/// +/// Currently, documentation only supports a single language +/// thus all text should be in English. +#[derive(Debug, Clone)] +pub struct Documentation { + /// The section in the documentation where the UDF will be documented + pub doc_section: DocSection, + /// The description for the UDF + pub description: String, + /// A brief example of the syntax. For example "ascii(str)" + pub syntax_example: String, + /// A sql example for the UDF, usually in the form of a sql prompt + /// query and output. It is strongly recommended to provide an + /// example for anything but the most basic UDF's + pub sql_example: Option, + /// Arguments for the UDF which will be displayed in array order. + /// Left member of a pair is the argument name, right is a + /// description for the argument + pub arguments: Option>, + /// A list of alternative syntax examples for a function + pub alternative_syntax: Option>, + /// Related functions if any. Values should match the related + /// udf's name exactly. Related udf's must be of the same + /// UDF type (scalar, aggregate or window) for proper linking to + /// occur + pub related_udfs: Option>, +} + +impl Documentation { + /// Returns a new [`DocumentationBuilder`] with no options set. + pub fn builder() -> DocumentationBuilder { + DocumentationBuilder::new() + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct DocSection { + /// True to include this doc section in the public + /// documentation, false otherwise + pub include: bool, + /// A display label for the doc section. For example: "Math Expressions" + pub label: &'static str, + /// An optional description for the doc section + pub description: Option<&'static str>, +} + +/// A builder to be used for building [`Documentation`]'s. +/// +/// Example: +/// +/// ```rust +/// # use datafusion_expr::Documentation; +/// # use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; +/// # use datafusion_common::Result; +/// # +/// # fn main() -> Result<()> { +/// let documentation = Documentation::builder() +/// .with_doc_section(DOC_SECTION_MATH) +/// .with_description("Add one to an int32") +/// .with_syntax_example("add_one(2)") +/// .with_argument("arg_1", "The int32 number to add one to") +/// .build()?; +/// Ok(()) +/// # } +pub struct DocumentationBuilder { + pub doc_section: Option, + pub description: Option, + pub syntax_example: Option, + pub sql_example: Option, + pub arguments: Option>, + pub alternative_syntax: Option>, + pub related_udfs: Option>, +} + +impl DocumentationBuilder { + pub fn new() -> Self { + Self { + doc_section: None, + description: None, + syntax_example: None, + sql_example: None, + arguments: None, + alternative_syntax: None, + related_udfs: None, + } + } + + pub fn with_doc_section(mut self, doc_section: DocSection) -> Self { + self.doc_section = Some(doc_section); + self + } + + pub fn with_description(mut self, description: impl Into) -> Self { + self.description = Some(description.into()); + self + } + + pub fn with_syntax_example(mut self, syntax_example: impl Into) -> Self { + self.syntax_example = Some(syntax_example.into()); + self + } + + pub fn with_sql_example(mut self, sql_example: impl Into) -> Self { + self.sql_example = Some(sql_example.into()); + self + } + + /// Adds documentation for a specific argument to the documentation. + /// + /// Arguments are displayed in the order they are added. + pub fn with_argument( + mut self, + arg_name: impl Into, + arg_description: impl Into, + ) -> Self { + let mut args = self.arguments.unwrap_or_default(); + args.push((arg_name.into(), arg_description.into())); + self.arguments = Some(args); + self + } + + /// Add a standard "expression" argument to the documentation + /// + /// The argument is rendered like below if Some() is passed through: + /// + /// ```text + /// : + /// expression to operate on. Can be a constant, column, or function, and any combination of operators. + /// ``` + /// + /// The argument is rendered like below if None is passed through: + /// + /// ```text + /// : + /// The expression to operate on. Can be a constant, column, or function, and any combination of operators. + /// ``` + pub fn with_standard_argument( + self, + arg_name: impl Into, + expression_type: Option<&str>, + ) -> Self { + let description = format!( + "{} expression to operate on. Can be a constant, column, or function, and any combination of operators.", + expression_type.unwrap_or("The") + ); + self.with_argument(arg_name, description) + } + + pub fn with_alternative_syntax(mut self, syntax_name: impl Into) -> Self { + let mut alternative_syntax_array = self.alternative_syntax.unwrap_or_default(); + alternative_syntax_array.push(syntax_name.into()); + self.alternative_syntax = Some(alternative_syntax_array); + self + } + + pub fn with_related_udf(mut self, related_udf: impl Into) -> Self { + let mut related = self.related_udfs.unwrap_or_default(); + related.push(related_udf.into()); + self.related_udfs = Some(related); + self + } + + pub fn build(self) -> Result { + let Self { + doc_section, + description, + syntax_example, + sql_example, + arguments, + alternative_syntax, + related_udfs, + } = self; + + if doc_section.is_none() { + return exec_err!("Documentation must have a doc section"); + } + if description.is_none() { + return exec_err!("Documentation must have a description"); + } + if syntax_example.is_none() { + return exec_err!("Documentation must have a syntax_example"); + } + + Ok(Documentation { + doc_section: doc_section.unwrap(), + description: description.unwrap(), + syntax_example: syntax_example.unwrap(), + sql_example, + arguments, + alternative_syntax, + related_udfs, + }) + } +} + +impl Default for DocumentationBuilder { + fn default() -> Self { + Self::new() + } +} diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 5a8373509a40..124625280670 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -17,28 +17,38 @@ //! [`WindowUDF`]: User Defined Window Functions -use crate::{ - Expr, PartitionEvaluator, PartitionEvaluatorFactory, ReturnTypeFunction, Signature, - WindowFrame, -}; -use arrow::datatypes::DataType; -use datafusion_common::Result; +use arrow::compute::SortOptions; +use std::cmp::Ordering; +use std::hash::{DefaultHasher, Hash, Hasher}; use std::{ any::Any, fmt::{self, Debug, Display, Formatter}, sync::Arc, }; +use arrow::datatypes::{DataType, Field}; + +use crate::expr::WindowFunction; +use crate::{ + function::WindowFunctionSimplification, Documentation, Expr, PartitionEvaluator, + Signature, +}; +use datafusion_common::{not_impl_err, Result}; +use datafusion_functions_window_common::expr::ExpressionArgs; +use datafusion_functions_window_common::field::WindowUDFFieldArgs; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + /// Logical representation of a user-defined window function (UDWF) /// A UDWF is different from a UDF in that it is stateful across batches. /// /// See the documentation on [`PartitionEvaluator`] for more details /// /// 1. For simple use cases, use [`create_udwf`] (examples in -/// [`simple_udwf.rs`]). +/// [`simple_udwf.rs`]). /// /// 2. For advanced use cases, use [`WindowUDFImpl`] which provides full API -/// access (examples in [`advanced_udwf.rs`]). +/// access (examples in [`advanced_udwf.rs`]). /// /// # API Note /// This is a separate struct from `WindowUDFImpl` to maintain backwards @@ -48,7 +58,7 @@ use std::{ /// [`create_udwf`]: crate::expr_fn::create_udwf /// [`simple_udwf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/simple_udwf.rs /// [`advanced_udwf.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udwf.rs -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialOrd)] pub struct WindowUDF { inner: Arc, } @@ -62,39 +72,19 @@ impl Display for WindowUDF { impl PartialEq for WindowUDF { fn eq(&self, other: &Self) -> bool { - self.name() == other.name() && self.signature() == other.signature() + self.inner.equals(other.inner.as_ref()) } } impl Eq for WindowUDF {} -impl std::hash::Hash for WindowUDF { - fn hash(&self, state: &mut H) { - self.name().hash(state); - self.signature().hash(state); +impl Hash for WindowUDF { + fn hash(&self, state: &mut H) { + self.inner.hash_value().hash(state) } } impl WindowUDF { - /// Create a new WindowUDF from low level details. - /// - /// See [`WindowUDFImpl`] for a more convenient way to create a - /// `WindowUDF` using trait objects - #[deprecated(since = "34.0.0", note = "please implement WindowUDFImpl instead")] - pub fn new( - name: &str, - signature: &Signature, - return_type: &ReturnTypeFunction, - partition_evaluator_factory: &PartitionEvaluatorFactory, - ) -> Self { - Self::new_from_impl(WindowUDFLegacyWrapper { - name: name.to_owned(), - signature: signature.clone(), - return_type: return_type.clone(), - partition_evaluator_factory: partition_evaluator_factory.clone(), - }) - } - /// Create a new `WindowUDF` from a `[WindowUDFImpl]` trait object /// /// Note this is the same as using the `From` impl (`WindowUDF::from`) @@ -108,8 +98,8 @@ impl WindowUDF { } /// Return the underlying [`WindowUDFImpl`] trait object for this function - pub fn inner(&self) -> Arc { - self.inner.clone() + pub fn inner(&self) -> &Arc { + &self.inner } /// Adds additional names that can be used to invoke this function, in @@ -117,31 +107,22 @@ impl WindowUDF { /// /// If you implement [`WindowUDFImpl`] directly you should return aliases directly. pub fn with_aliases(self, aliases: impl IntoIterator) -> Self { - Self::new_from_impl(AliasedWindowUDFImpl::new(self.inner.clone(), aliases)) + Self::new_from_impl(AliasedWindowUDFImpl::new(Arc::clone(&self.inner), aliases)) } - /// creates a [`Expr`] that calls the window function given - /// the `partition_by`, `order_by`, and `window_frame` definition + /// creates a [`Expr`] that calls the window function with default + /// values for `order_by`, `partition_by`, `window_frame`. /// - /// This utility allows using the UDWF without requiring access to - /// the registry, such as with the DataFrame API. - pub fn call( - &self, - args: Vec, - partition_by: Vec, - order_by: Vec, - window_frame: WindowFrame, - ) -> Expr { + /// See [`ExprFunctionExt`] for details on setting these values. + /// + /// This utility allows using a user defined window function without + /// requiring access to the registry, such as with the DataFrame API. + /// + /// [`ExprFunctionExt`]: crate::expr_fn::ExprFunctionExt + pub fn call(&self, args: Vec) -> Expr { let fun = crate::WindowFunctionDefinition::WindowUDF(Arc::new(self.clone())); - Expr::WindowFunction(crate::expr::WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment: None, - }) + Expr::WindowFunction(WindowFunction::new(fun, args)) } /// Returns this function's name @@ -163,16 +144,61 @@ impl WindowUDF { self.inner.signature() } - /// Return the type of the function given its input types + /// Do the function rewrite /// - /// See [`WindowUDFImpl::return_type`] for more details. - pub fn return_type(&self, args: &[DataType]) -> Result { - self.inner.return_type(args) + /// See [`WindowUDFImpl::simplify`] for more details. + pub fn simplify(&self) -> Option { + self.inner.simplify() } + /// Expressions that are passed to the [`PartitionEvaluator`]. + /// + /// See [`WindowUDFImpl::expressions`] for more details. + pub fn expressions(&self, expr_args: ExpressionArgs) -> Vec> { + self.inner.expressions(expr_args) + } /// Return a `PartitionEvaluator` for evaluating this window function - pub fn partition_evaluator_factory(&self) -> Result> { - self.inner.partition_evaluator() + pub fn partition_evaluator_factory( + &self, + partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + self.inner.partition_evaluator(partition_evaluator_args) + } + + /// Returns the field of the final result of evaluating this window function. + /// + /// See [`WindowUDFImpl::field`] for more details. + pub fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + self.inner.field(field_args) + } + + /// Returns custom result ordering introduced by this window function + /// which is used to update ordering equivalences. + /// + /// See [`WindowUDFImpl::sort_options`] for more details. + pub fn sort_options(&self) -> Option { + self.inner.sort_options() + } + + /// See [`WindowUDFImpl::coerce_types`] for more details. + pub fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + self.inner.coerce_types(arg_types) + } + + /// Returns the reversed user-defined window function when the + /// order of evaluation is reversed. + /// + /// See [`WindowUDFImpl::reverse_expr`] for more details. + pub fn reverse_expr(&self) -> ReversedUDWF { + self.inner.reverse_expr() + } + + /// Returns the documentation for this Window UDF. + /// + /// Documentation can be accessed programmatically as well as + /// generating publicly facing documentation. + pub fn documentation(&self) -> Option<&Documentation> { + self.inner.documentation() } } @@ -198,50 +224,82 @@ where /// # Basic Example /// ``` /// # use std::any::Any; -/// # use arrow::datatypes::DataType; +/// # use std::sync::OnceLock; +/// # use arrow::datatypes::{DataType, Field}; /// # use datafusion_common::{DataFusionError, plan_err, Result}; -/// # use datafusion_expr::{col, Signature, Volatility, PartitionEvaluator, WindowFrame}; +/// # use datafusion_expr::{col, Signature, Volatility, PartitionEvaluator, WindowFrame, ExprFunctionExt, Documentation}; /// # use datafusion_expr::{WindowUDFImpl, WindowUDF}; +/// # use datafusion_functions_window_common::field::WindowUDFFieldArgs; +/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +/// # use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL; +/// /// #[derive(Debug, Clone)] /// struct SmoothIt { -/// signature: Signature -/// }; +/// signature: Signature, +/// } /// /// impl SmoothIt { /// fn new() -> Self { /// Self { -/// signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable) +/// signature: Signature::uniform(1, vec![DataType::Int32], Volatility::Immutable), /// } /// } /// } /// -/// /// Implement the WindowUDFImpl trait for AddOne +/// static DOCUMENTATION: OnceLock = OnceLock::new(); +/// +/// fn get_doc() -> &'static Documentation { +/// DOCUMENTATION.get_or_init(|| { +/// Documentation::builder() +/// .with_doc_section(DOC_SECTION_ANALYTICAL) +/// .with_description("smooths the windows") +/// .with_syntax_example("smooth_it(2)") +/// .with_argument("arg1", "The int32 number to smooth by") +/// .build() +/// .unwrap() +/// }) +/// } +/// +/// /// Implement the WindowUDFImpl trait for SmoothIt /// impl WindowUDFImpl for SmoothIt { /// fn as_any(&self) -> &dyn Any { self } /// fn name(&self) -> &str { "smooth_it" } /// fn signature(&self) -> &Signature { &self.signature } -/// fn return_type(&self, args: &[DataType]) -> Result { -/// if !matches!(args.get(0), Some(&DataType::Int32)) { -/// return plan_err!("smooth_it only accepts Int32 arguments"); +/// // The actual implementation would smooth the window +/// fn partition_evaluator( +/// &self, +/// _partition_evaluator_args: PartitionEvaluatorArgs, +/// ) -> Result> { +/// unimplemented!() +/// } +/// fn field(&self, field_args: WindowUDFFieldArgs) -> Result { +/// if let Some(DataType::Int32) = field_args.get_input_type(0) { +/// Ok(Field::new(field_args.name(), DataType::Int32, false)) +/// } else { +/// plan_err!("smooth_it only accepts Int32 arguments") /// } -/// Ok(DataType::Int32) /// } -/// // The actual implementation would add one to the argument -/// fn partition_evaluator(&self) -> Result> { unimplemented!() } +/// fn documentation(&self) -> Option<&Documentation> { +/// Some(get_doc()) +/// } /// } /// /// // Create a new WindowUDF from the implementation /// let smooth_it = WindowUDF::from(SmoothIt::new()); /// /// // Call the function `add_one(col)` -/// let expr = smooth_it.call( -/// vec![col("speed")], // smooth_it(speed) -/// vec![col("car")], // PARTITION BY car -/// vec![col("time").sort(true, true)], // ORDER BY time ASC -/// WindowFrame::new(None), -/// ); +/// // smooth_it(speed) OVER (PARTITION BY car ORDER BY time ASC) +/// let expr = smooth_it.call(vec![col("speed")]) +/// .partition_by(vec![col("car")]) +/// .order_by(vec![col("time").sort(true, true)]) +/// .window_frame(WindowFrame::new(None)) +/// .build() +/// .unwrap(); /// ``` pub trait WindowUDFImpl: Debug + Send + Sync { + // Note: When adding any methods (with default implementations), remember to add them also + // into the AliasedWindowUDFImpl below! + /// Returns this object as an [`Any`] trait object fn as_any(&self) -> &dyn Any; @@ -252,12 +310,16 @@ pub trait WindowUDFImpl: Debug + Send + Sync { /// types are accepted and the function's Volatility. fn signature(&self) -> &Signature; - /// What [`DataType`] will be returned by this function, given the types of - /// the arguments - fn return_type(&self, arg_types: &[DataType]) -> Result; + /// Returns the expressions that are passed to the [`PartitionEvaluator`]. + fn expressions(&self, expr_args: ExpressionArgs) -> Vec> { + expr_args.input_exprs().into() + } /// Invoke the function, returning the [`PartitionEvaluator`] instance - fn partition_evaluator(&self) -> Result>; + fn partition_evaluator( + &self, + partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result>; /// Returns any aliases (alternate names) for this function. /// @@ -266,6 +328,136 @@ pub trait WindowUDFImpl: Debug + Send + Sync { fn aliases(&self) -> &[String] { &[] } + + /// Optionally apply per-UDWF simplification / rewrite rules. + /// + /// This can be used to apply function specific simplification rules during + /// optimization. The default implementation does nothing. + /// + /// Note that DataFusion handles simplifying arguments and "constant + /// folding" (replacing a function call with constant arguments such as + /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such + /// optimizations manually for specific UDFs. + /// + /// Example: + /// [`simplify_udwf_expression.rs`]: + /// + /// # Returns + /// [None] if simplify is not defined or, + /// + /// Or, a closure with two arguments: + /// * 'window_function': [crate::expr::WindowFunction] for which simplified has been invoked + /// * 'info': [crate::simplify::SimplifyInfo] + fn simplify(&self) -> Option { + None + } + + /// Return true if this window UDF is equal to the other. + /// + /// Allows customizing the equality of window UDFs. + /// Must be consistent with [`Self::hash_value`] and follow the same rules as [`Eq`]: + /// + /// - reflexive: `a.equals(a)`; + /// - symmetric: `a.equals(b)` implies `b.equals(a)`; + /// - transitive: `a.equals(b)` and `b.equals(c)` implies `a.equals(c)`. + /// + /// By default, compares [`Self::name`] and [`Self::signature`]. + fn equals(&self, other: &dyn WindowUDFImpl) -> bool { + self.name() == other.name() && self.signature() == other.signature() + } + + /// Returns a hash value for this window UDF. + /// + /// Allows customizing the hash code of window UDFs. Similarly to [`Hash`] and [`Eq`], + /// if [`Self::equals`] returns true for two UDFs, their `hash_value`s must be the same. + /// + /// By default, hashes [`Self::name`] and [`Self::signature`]. + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.name().hash(hasher); + self.signature().hash(hasher); + hasher.finish() + } + + /// The [`Field`] of the final result of evaluating this window function. + /// + /// Call `field_args.name()` to get the fully qualified name for defining + /// the [`Field`]. For a complete example see the implementation in the + /// [Basic Example](WindowUDFImpl#basic-example) section. + fn field(&self, field_args: WindowUDFFieldArgs) -> Result; + + /// Allows the window UDF to define a custom result ordering. + /// + /// By default, a window UDF doesn't introduce an ordering. + /// But when specified by a window UDF this is used to update + /// ordering equivalences. + fn sort_options(&self) -> Option { + None + } + + /// Coerce arguments of a function call to types that the function can evaluate. + /// + /// This function is only called if [`WindowUDFImpl::signature`] returns [`crate::TypeSignature::UserDefined`]. Most + /// UDWFs should return one of the other variants of `TypeSignature` which handle common + /// cases + /// + /// See the [type coercion module](crate::type_coercion) + /// documentation for more details on type coercion + /// + /// For example, if your function requires a floating point arguments, but the user calls + /// it like `my_func(1::int)` (aka with `1` as an integer), coerce_types could return `[DataType::Float64]` + /// to ensure the argument was cast to `1::double` + /// + /// # Parameters + /// * `arg_types`: The argument types of the arguments this function with + /// + /// # Return value + /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call + /// arguments to these specific types. + fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { + not_impl_err!("Function {} does not implement coerce_types", self.name()) + } + + /// Allows customizing the behavior of the user-defined window + /// function when it is evaluated in reverse order. + fn reverse_expr(&self) -> ReversedUDWF { + ReversedUDWF::NotSupported + } + + /// Returns the documentation for this Window UDF. + /// + /// Documentation can be accessed programmatically as well as + /// generating publicly facing documentation. + fn documentation(&self) -> Option<&Documentation> { + None + } +} + +pub enum ReversedUDWF { + /// The result of evaluating the user-defined window function + /// remains identical when reversed. + Identical, + /// A window function which does not support evaluating the result + /// in reverse order. + NotSupported, + /// Customize the user-defined window function for evaluating the + /// result in reverse order. + Reversed(Arc), +} + +impl PartialEq for dyn WindowUDFImpl { + fn eq(&self, other: &Self) -> bool { + self.equals(other) + } +} + +impl PartialOrd for dyn WindowUDFImpl { + fn partial_cmp(&self, other: &Self) -> Option { + match self.name().partial_cmp(other.name()) { + Some(Ordering::Equal) => self.signature().partial_cmp(other.signature()), + cmp => cmp, + } + } } /// WindowUDF that adds an alias to the underlying function. It is better to @@ -301,64 +493,188 @@ impl WindowUDFImpl for AliasedWindowUDFImpl { self.inner.signature() } - fn return_type(&self, arg_types: &[DataType]) -> Result { - self.inner.return_type(arg_types) + fn expressions(&self, expr_args: ExpressionArgs) -> Vec> { + expr_args + .input_exprs() + .first() + .map_or(vec![], |expr| vec![Arc::clone(expr)]) } - fn partition_evaluator(&self) -> Result> { - self.inner.partition_evaluator() + fn partition_evaluator( + &self, + partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + self.inner.partition_evaluator(partition_evaluator_args) } fn aliases(&self) -> &[String] { &self.aliases } -} -/// Implementation of [`WindowUDFImpl`] that wraps the function style pointers -/// of the older API (see -/// for more details) -pub struct WindowUDFLegacyWrapper { - /// name - name: String, - /// signature - signature: Signature, - /// Return type - return_type: ReturnTypeFunction, - /// Return the partition evaluator - partition_evaluator_factory: PartitionEvaluatorFactory, + fn simplify(&self) -> Option { + self.inner.simplify() + } + + fn equals(&self, other: &dyn WindowUDFImpl) -> bool { + if let Some(other) = other.as_any().downcast_ref::() { + self.inner.equals(other.inner.as_ref()) && self.aliases == other.aliases + } else { + false + } + } + + fn hash_value(&self) -> u64 { + let hasher = &mut DefaultHasher::new(); + self.inner.hash_value().hash(hasher); + self.aliases.hash(hasher); + hasher.finish() + } + + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + self.inner.field(field_args) + } + + fn sort_options(&self) -> Option { + self.inner.sort_options() + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + self.inner.coerce_types(arg_types) + } + + fn documentation(&self) -> Option<&Documentation> { + self.inner.documentation() + } } -impl Debug for WindowUDFLegacyWrapper { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("WindowUDF") - .field("name", &self.name) - .field("signature", &self.signature) - .field("return_type", &"") - .field("partition_evaluator_factory", &"") - .finish_non_exhaustive() +// Window UDF doc sections for use in public documentation +pub mod window_doc_sections { + use crate::DocSection; + + pub fn doc_sections() -> Vec { + vec![ + DOC_SECTION_AGGREGATE, + DOC_SECTION_RANKING, + DOC_SECTION_ANALYTICAL, + ] } + + pub const DOC_SECTION_AGGREGATE: DocSection = DocSection { + include: true, + label: "Aggregate Functions", + description: Some("All aggregate functions can be used as window functions."), + }; + + pub const DOC_SECTION_RANKING: DocSection = DocSection { + include: true, + label: "Ranking Functions", + description: None, + }; + + pub const DOC_SECTION_ANALYTICAL: DocSection = DocSection { + include: true, + label: "Analytical Functions", + description: None, + }; } -impl WindowUDFImpl for WindowUDFLegacyWrapper { - fn as_any(&self) -> &dyn Any { - self +#[cfg(test)] +mod test { + use crate::{PartitionEvaluator, WindowUDF, WindowUDFImpl}; + use arrow::datatypes::{DataType, Field}; + use datafusion_common::Result; + use datafusion_expr_common::signature::{Signature, Volatility}; + use datafusion_functions_window_common::field::WindowUDFFieldArgs; + use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; + use std::any::Any; + use std::cmp::Ordering; + + #[derive(Debug, Clone)] + struct AWindowUDF { + signature: Signature, } - fn name(&self) -> &str { - &self.name + impl AWindowUDF { + fn new() -> Self { + Self { + signature: Signature::uniform( + 1, + vec![DataType::Int32], + Volatility::Immutable, + ), + } + } } - fn signature(&self) -> &Signature { - &self.signature + /// Implement the WindowUDFImpl trait for AddOne + impl WindowUDFImpl for AWindowUDF { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "a" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + unimplemented!() + } + fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { + unimplemented!() + } } - fn return_type(&self, arg_types: &[DataType]) -> Result { - // Old API returns an Arc of the datatype for some reason - let res = (self.return_type)(arg_types)?; - Ok(res.as_ref().clone()) + #[derive(Debug, Clone)] + struct BWindowUDF { + signature: Signature, } - fn partition_evaluator(&self) -> Result> { - (self.partition_evaluator_factory)() + impl BWindowUDF { + fn new() -> Self { + Self { + signature: Signature::uniform( + 1, + vec![DataType::Int32], + Volatility::Immutable, + ), + } + } + } + + /// Implement the WindowUDFImpl trait for AddOne + impl WindowUDFImpl for BWindowUDF { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "b" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + unimplemented!() + } + fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { + unimplemented!() + } + } + + #[test] + fn test_partial_ord() { + let a1 = WindowUDF::from(AWindowUDF::new()); + let a2 = WindowUDF::from(AWindowUDF::new()); + assert_eq!(a1.partial_cmp(&a2), Some(Ordering::Equal)); + + let b1 = WindowUDF::from(BWindowUDF::new()); + assert!(a1 < b1); + assert!(!(a1 == b1)); } } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 64fe98c23b08..29c62440abb1 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -18,34 +18,39 @@ //! Expression utilities use std::cmp::Ordering; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; +use std::ops::Deref; use std::sync::Arc; -use crate::expr::{Alias, Sort, WindowFunction}; +use crate::expr::{Alias, Sort, WildcardOptions, WindowFunction}; use crate::expr_rewriter::strip_outer_reference; -use crate::logical_plan::Aggregate; -use crate::signature::{Signature, TypeSignature}; use crate::{ - and, BinaryExpr, Cast, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, - Operator, TryCast, + and, BinaryExpr, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator, }; +use datafusion_expr_common::signature::{Signature, TypeSignature}; use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; -use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; use datafusion_common::utils::get_at_indices; use datafusion_common::{ - internal_err, plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, Result, - ScalarValue, TableReference, + internal_err, plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, + DataFusionError, Result, TableReference, }; -use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem, WildcardAdditionalOptions}; +use indexmap::IndexSet; +use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem}; + +pub use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity; /// The value to which `COUNT(*)` is expanded to in /// `COUNT()` expressions -pub const COUNT_STAR_EXPANSION: ScalarValue = ScalarValue::Int64(Some(1)); +pub use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; /// Recursively walk a list of expression trees, collecting the unique set of columns /// referenced in the expression +#[deprecated(since = "40.0.0", note = "Expr::add_column_refs instead")] pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet) -> Result<()> { for e in expr { expr_to_columns(e, accum)?; @@ -62,9 +67,10 @@ pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result { "Invalid group by expressions, GroupingSet must be the only expression" ); } - Ok(grouping_set.distinct_expr().len()) + // Groupings sets have an additional interal column for the grouping id + Ok(grouping_set.distinct_expr().len() + 1) } else { - Ok(group_expr.len()) + grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len()) } } @@ -199,7 +205,7 @@ pub fn enumerate_grouping_sets(group_expr: Vec) -> Result> { if !has_grouping_set || group_expr.len() == 1 { return Ok(group_expr); } - // only process mix grouping sets + // Only process mix grouping sets let partial_sets = group_expr .iter() .map(|expr| { @@ -228,7 +234,7 @@ pub fn enumerate_grouping_sets(group_expr: Vec) -> Result> { }) .collect::>>()?; - // cross join + // Cross Join let grouping_sets = partial_sets .into_iter() .map(Ok) @@ -248,7 +254,7 @@ pub fn enumerate_grouping_sets(group_expr: Vec) -> Result> { /// Find all distinct exprs in a list of group by expressions. If the /// first element is a `GroupingSet` expression then it must be the only expr. -pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result> { +pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result> { if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() { if group_expr.len() > 1 { return plan_err!( @@ -257,7 +263,11 @@ pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result> { } Ok(grouping_set.distinct_expr()) } else { - Ok(group_expr.to_vec()) + Ok(group_expr + .iter() + .collect::>() + .into_iter() + .collect()) } } @@ -293,7 +303,6 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::Case { .. } | Expr::Cast { .. } | Expr::TryCast { .. } - | Expr::Sort { .. } | Expr::ScalarFunction(..) | Expr::WindowFunction { .. } | Expr::AggregateFunction { .. } @@ -303,7 +312,6 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::InSubquery(_) | Expr::ScalarSubquery(_) | Expr::Wildcard { .. } - | Expr::GetIndexedField { .. } | Expr::Placeholder(_) | Expr::OuterReferenceColumn { .. } => {} } @@ -318,7 +326,7 @@ fn get_excluded_columns( opt_exclude: Option<&ExcludeSelectItem>, opt_except: Option<&ExceptSelectItem>, schema: &DFSchema, - qualifier: &Option, + qualifier: Option<&TableReference>, ) -> Result> { let mut idents = vec![]; if let Some(excepts) = opt_except { @@ -334,7 +342,7 @@ fn get_excluded_columns( // Excluded columns should be unique let n_elem = idents.len(); let unique_idents = idents.into_iter().collect::>(); - // if HashSet size, and vector length are different, this means that some of the excluded columns + // If HashSet size, and vector length are different, this means that some of the excluded columns // are not unique. In this case return error. if n_elem != unique_idents.len() { return plan_err!("EXCLUDE or EXCEPT contains duplicate column names"); @@ -343,8 +351,7 @@ fn get_excluded_columns( let mut result = vec![]; for ident in unique_idents.into_iter() { let col_name = ident.value.as_str(); - let (qualifier, field) = - schema.qualified_field_with_name(qualifier.as_ref(), col_name)?; + let (qualifier, field) = schema.qualified_field_with_name(qualifier, col_name)?; result.push(Column::from((qualifier, field))); } Ok(result) @@ -376,7 +383,7 @@ fn get_exprs_except_skipped( pub fn expand_wildcard( schema: &DFSchema, plan: &LogicalPlan, - wildcard_options: Option<&WildcardAdditionalOptions>, + wildcard_options: Option<&WildcardOptions>, ) -> Result> { let using_columns = plan.using_columns()?; let mut columns_to_skip = using_columns @@ -400,13 +407,13 @@ pub fn expand_wildcard( .collect::>() }) .collect::>(); - let excluded_columns = if let Some(WildcardAdditionalOptions { - opt_exclude, - opt_except, + let excluded_columns = if let Some(WildcardOptions { + exclude: opt_exclude, + except: opt_except, .. }) = wildcard_options { - get_excluded_columns(opt_exclude.as_ref(), opt_except.as_ref(), schema, &None)? + get_excluded_columns(opt_exclude.as_ref(), opt_except.as_ref(), schema, None)? } else { vec![] }; @@ -417,12 +424,11 @@ pub fn expand_wildcard( /// Resolves an `Expr::Wildcard` to a collection of qualified `Expr::Column`'s. pub fn expand_qualified_wildcard( - qualifier: &str, + qualifier: &TableReference, schema: &DFSchema, - wildcard_options: Option<&WildcardAdditionalOptions>, + wildcard_options: Option<&WildcardOptions>, ) -> Result> { - let qualifier = TableReference::from(qualifier); - let qualified_indices = schema.fields_indices_with_qualified(&qualifier); + let qualified_indices = schema.fields_indices_with_qualified(qualifier); let projected_func_dependencies = schema .functional_dependencies() .project_functional_dependencies(&qualified_indices, qualified_indices.len()); @@ -431,13 +437,16 @@ pub fn expand_qualified_wildcard( return plan_err!("Invalid qualifier {qualifier}"); } - let qualified_schema = Arc::new(Schema::new(fields_with_qualified)); + let qualified_schema = Arc::new(Schema::new_with_metadata( + fields_with_qualified, + schema.metadata().clone(), + )); let qualified_dfschema = DFSchema::try_from_qualified_schema(qualifier.clone(), &qualified_schema)? .with_functional_dependencies(projected_func_dependencies)?; - let excluded_columns = if let Some(WildcardAdditionalOptions { - opt_exclude, - opt_except, + let excluded_columns = if let Some(WildcardOptions { + exclude: opt_exclude, + except: opt_except, .. }) = wildcard_options { @@ -445,7 +454,7 @@ pub fn expand_qualified_wildcard( opt_exclude.as_ref(), opt_except.as_ref(), schema, - &Some(qualifier), + Some(qualifier), )? } else { vec![] @@ -460,23 +469,21 @@ pub fn expand_qualified_wildcard( } /// (expr, "is the SortExpr for window (either comes from PARTITION BY or ORDER BY columns)") -/// if bool is true SortExpr comes from `PARTITION BY` column, if false comes from `ORDER BY` column -type WindowSortKey = Vec<(Expr, bool)>; +/// If bool is true SortExpr comes from `PARTITION BY` column, if false comes from `ORDER BY` column +type WindowSortKey = Vec<(Sort, bool)>; -/// Generate a sort key for a given window expr's partition_by and order_bu expr +/// Generate a sort key for a given window expr's partition_by and order_by expr pub fn generate_sort_key( partition_by: &[Expr], - order_by: &[Expr], + order_by: &[Sort], ) -> Result { let normalized_order_by_keys = order_by .iter() - .map(|e| match e { - Expr::Sort(Sort { expr, .. }) => { - Ok(Expr::Sort(Sort::new(expr.clone(), true, false))) - } - _ => plan_err!("Order by only accepts sort expressions"), + .map(|e| { + let Sort { expr, .. } = e; + Sort::new(expr.clone(), true, false) }) - .collect::>>()?; + .collect::>(); let mut final_sort_keys = vec![]; let mut is_partition_flag = vec![]; @@ -512,68 +519,64 @@ pub fn generate_sort_key( /// Compare the sort expr as PostgreSQL's common_prefix_cmp(): /// pub fn compare_sort_expr( - sort_expr_a: &Expr, - sort_expr_b: &Expr, + sort_expr_a: &Sort, + sort_expr_b: &Sort, schema: &DFSchemaRef, ) -> Ordering { - match (sort_expr_a, sort_expr_b) { - ( - Expr::Sort(Sort { - expr: expr_a, - asc: asc_a, - nulls_first: nulls_first_a, - }), - Expr::Sort(Sort { - expr: expr_b, - asc: asc_b, - nulls_first: nulls_first_b, - }), - ) => { - let ref_indexes_a = find_column_indexes_referenced_by_expr(expr_a, schema); - let ref_indexes_b = find_column_indexes_referenced_by_expr(expr_b, schema); - for (idx_a, idx_b) in ref_indexes_a.iter().zip(ref_indexes_b.iter()) { - match idx_a.cmp(idx_b) { - Ordering::Less => { - return Ordering::Less; - } - Ordering::Greater => { - return Ordering::Greater; - } - Ordering::Equal => {} - } - } - match ref_indexes_a.len().cmp(&ref_indexes_b.len()) { - Ordering::Less => return Ordering::Greater, - Ordering::Greater => { - return Ordering::Less; - } - Ordering::Equal => {} + let Sort { + expr: expr_a, + asc: asc_a, + nulls_first: nulls_first_a, + } = sort_expr_a; + + let Sort { + expr: expr_b, + asc: asc_b, + nulls_first: nulls_first_b, + } = sort_expr_b; + + let ref_indexes_a = find_column_indexes_referenced_by_expr(expr_a, schema); + let ref_indexes_b = find_column_indexes_referenced_by_expr(expr_b, schema); + for (idx_a, idx_b) in ref_indexes_a.iter().zip(ref_indexes_b.iter()) { + match idx_a.cmp(idx_b) { + Ordering::Less => { + return Ordering::Less; } - match (asc_a, asc_b) { - (true, false) => { - return Ordering::Greater; - } - (false, true) => { - return Ordering::Less; - } - _ => {} + Ordering::Greater => { + return Ordering::Greater; } - match (nulls_first_a, nulls_first_b) { - (true, false) => { - return Ordering::Less; - } - (false, true) => { - return Ordering::Greater; - } - _ => {} - } - Ordering::Equal + Ordering::Equal => {} + } + } + match ref_indexes_a.len().cmp(&ref_indexes_b.len()) { + Ordering::Less => return Ordering::Greater, + Ordering::Greater => { + return Ordering::Less; + } + Ordering::Equal => {} + } + match (asc_a, asc_b) { + (true, false) => { + return Ordering::Greater; + } + (false, true) => { + return Ordering::Less; + } + _ => {} + } + match (nulls_first_a, nulls_first_b) { + (true, false) => { + return Ordering::Less; } - _ => Ordering::Equal, + (false, true) => { + return Ordering::Greater; + } + _ => {} } + Ordering::Equal } -/// group a slice of window expression expr by their order by expressions +/// Group a slice of window expression expr by their order by expressions pub fn group_window_expr_by_sort_keys( window_expr: Vec, ) -> Result)>> { @@ -600,20 +603,12 @@ pub fn group_window_expr_by_sort_keys( /// Collect all deeply nested `Expr::AggregateFunction`. /// They are returned in order of occurrence (depth /// first), with duplicates omitted. -pub fn find_aggregate_exprs(exprs: &[Expr]) -> Vec { +pub fn find_aggregate_exprs<'a>(exprs: impl IntoIterator) -> Vec { find_exprs_in_exprs(exprs, &|nested_expr| { matches!(nested_expr, Expr::AggregateFunction { .. }) }) } -/// Collect all deeply nested `Expr::Sort`. They are returned in order of occurrence -/// (depth first), with duplicates omitted. -pub fn find_sort_exprs(exprs: &[Expr]) -> Vec { - find_exprs_in_exprs(exprs, &|nested_expr| { - matches!(nested_expr, Expr::Sort { .. }) - }) -} - /// Collect all deeply nested `Expr::WindowFunction`. They are returned in order of occurrence /// (depth first), with duplicates omitted. pub fn find_window_exprs(exprs: &[Expr]) -> Vec { @@ -633,12 +628,15 @@ pub fn find_out_reference_exprs(expr: &Expr) -> Vec { /// Search the provided `Expr`'s, and all of their nested `Expr`, for any that /// pass the provided test. The returned `Expr`'s are deduplicated and returned /// in order of appearance (depth first). -fn find_exprs_in_exprs(exprs: &[Expr], test_fn: &F) -> Vec +fn find_exprs_in_exprs<'a, F>( + exprs: impl IntoIterator, + test_fn: &F, +) -> Vec where F: Fn(&Expr) -> bool, { exprs - .iter() + .into_iter() .flat_map(|expr| find_exprs_in_expr(expr, test_fn)) .fold(vec![], |mut acc, expr| { if !acc.contains(&expr) { @@ -661,7 +659,7 @@ where if !(exprs.contains(expr)) { exprs.push(expr.clone()) } - // stop recursing down this expr once we find a match + // Stop recursing down this expr once we find a match return Ok(TreeNodeRecursion::Jump); } @@ -680,7 +678,7 @@ where let mut err = Ok(()); expr.apply(|expr| { if let Err(e) = f(expr) { - // save the error for later (it may not be a DataFusionError + // Save the error for later (it may not be a DataFusionError) err = Err(e); Ok(TreeNodeRecursion::Stop) } else { @@ -694,84 +692,175 @@ where err } -/// Returns a new logical plan based on the original one with inputs -/// and expressions replaced. -/// -/// The exprs correspond to the same order of expressions returned by -/// `LogicalPlan::expressions`. This function is used in optimizers in -/// the following way: -/// -/// ```text -/// let new_inputs = optimize_children(..., plan, props); -/// -/// // get the plans expressions to optimize -/// let exprs = plan.expressions(); -/// -/// // potentially rewrite plan expressions -/// let rewritten_exprs = rewrite_exprs(exprs); -/// -/// // create new plan using rewritten_exprs in same position -/// let new_plan = from_plan(&plan, rewritten_exprs, new_inputs); -/// ``` -/// -/// Notice: sometimes [from_plan] will use schema of original plan, it don't change schema! -/// Such as `Projection/Aggregate/Window` -#[deprecated(since = "31.0.0", note = "use LogicalPlan::with_new_exprs instead")] -pub fn from_plan( +/// Create field meta-data from an expression, for use in a result set schema +pub fn exprlist_to_fields<'a>( + exprs: impl IntoIterator, plan: &LogicalPlan, - expr: &[Expr], - inputs: &[LogicalPlan], -) -> Result { - plan.with_new_exprs(expr.to_vec(), inputs.to_vec()) -} - -/// Find all columns referenced from an aggregate query -fn agg_cols(agg: &Aggregate) -> Vec { - agg.aggr_expr - .iter() - .chain(&agg.group_expr) - .flat_map(find_columns_referenced_by_expr) - .collect() +) -> Result, Arc)>> { + // Look for exact match in plan's output schema + let wildcard_schema = find_base_plan(plan).schema(); + let input_schema = plan.schema(); + let result = exprs + .into_iter() + .map(|e| match e { + Expr::Wildcard { qualifier, options } => match qualifier { + None => { + let excluded: Vec = get_excluded_columns( + options.exclude.as_ref(), + options.except.as_ref(), + wildcard_schema, + None, + )? + .into_iter() + .map(|c| c.flat_name()) + .collect(); + Ok::<_, DataFusionError>( + wildcard_schema + .field_names() + .iter() + .enumerate() + .filter(|(_, s)| !excluded.contains(s)) + .map(|(i, _)| wildcard_schema.qualified_field(i)) + .map(|(qualifier, f)| { + (qualifier.cloned(), Arc::new(f.to_owned())) + }) + .collect::>(), + ) + } + Some(qualifier) => { + let excluded: Vec = get_excluded_columns( + options.exclude.as_ref(), + options.except.as_ref(), + wildcard_schema, + Some(qualifier), + )? + .into_iter() + .map(|c| c.flat_name()) + .collect(); + Ok(wildcard_schema + .fields_with_qualified(qualifier) + .into_iter() + .filter_map(|field| { + let flat_name = format!("{}.{}", qualifier, field.name()); + if excluded.contains(&flat_name) { + None + } else { + Some(( + Some(qualifier.clone()), + Arc::new(field.to_owned()), + )) + } + }) + .collect::>()) + } + }, + _ => Ok(vec![e.to_field(input_schema)?]), + }) + .collect::>>()? + .into_iter() + .flatten() + .collect(); + Ok(result) } -fn exprlist_to_fields_aggregate( - exprs: &[Expr], - agg: &Aggregate, -) -> Result, Arc)>> { - let agg_cols = agg_cols(agg); - let mut fields = vec![]; - for expr in exprs { - match expr { - Expr::Column(c) if agg_cols.iter().any(|x| x == c) => { - // resolve against schema of input to aggregate - fields.push(expr.to_field(agg.input.schema())?); +/// Find the suitable base plan to expand the wildcard expression recursively. +/// When planning [LogicalPlan::Window] and [LogicalPlan::Aggregate], we will generate +/// an intermediate plan based on the relation plan (e.g. [LogicalPlan::TableScan], [LogicalPlan::Subquery], ...). +/// If we expand a wildcard expression basing the intermediate plan, we could get some duplicate fields. +pub fn find_base_plan(input: &LogicalPlan) -> &LogicalPlan { + match input { + LogicalPlan::Window(window) => find_base_plan(&window.input), + LogicalPlan::Aggregate(agg) => find_base_plan(&agg.input), + // [SqlToRel::try_process_unnest] will convert Expr(Unnest(Expr)) to Projection/Unnest/Projection + // We should expand the wildcard expression based on the input plan of the inner Projection. + LogicalPlan::Unnest(unnest) => { + if let LogicalPlan::Projection(projection) = unnest.input.deref() { + find_base_plan(&projection.input) + } else { + input + } + } + LogicalPlan::Filter(filter) => { + if filter.having { + // If a filter is used for a having clause, its input plan is an aggregation. + // We should expand the wildcard expression based on the aggregation's input plan. + find_base_plan(&filter.input) + } else { + input } - _ => fields.push(expr.to_field(&agg.schema)?), } + _ => input, } - Ok(fields) } -/// Create field meta-data from an expression, for use in a result set schema -pub fn exprlist_to_fields( +/// Count the number of real fields. We should expand the wildcard expression to get the actual number. +pub fn exprlist_len( exprs: &[Expr], - plan: &LogicalPlan, -) -> Result, Arc)>> { - // when dealing with aggregate plans we cannot simply look in the aggregate output schema - // because it will contain columns representing complex expressions (such a column named - // `GROUPING(person.state)` so in order to resolve `person.state` in this case we need to - // look at the input to the aggregate instead. - let fields = match plan { - LogicalPlan::Aggregate(agg) => Some(exprlist_to_fields_aggregate(exprs, agg)), - _ => None, - }; - if let Some(fields) = fields { - fields - } else { - // look for exact match in plan's output schema - let input_schema = &plan.schema(); - exprs.iter().map(|e| e.to_field(input_schema)).collect() - } + schema: &DFSchemaRef, + wildcard_schema: Option<&DFSchemaRef>, +) -> Result { + exprs + .iter() + .map(|e| match e { + Expr::Wildcard { + qualifier: None, + options, + } => { + let excluded = get_excluded_columns( + options.exclude.as_ref(), + options.except.as_ref(), + wildcard_schema.unwrap_or(schema), + None, + )? + .into_iter() + .collect::>(); + Ok( + get_exprs_except_skipped(wildcard_schema.unwrap_or(schema), excluded) + .len(), + ) + } + Expr::Wildcard { + qualifier: Some(qualifier), + options, + } => { + let related_wildcard_schema = wildcard_schema.as_ref().map_or_else( + || Ok(Arc::clone(schema)), + |schema| { + // Eliminate the fields coming from other tables. + let qualified_fields = schema + .fields() + .iter() + .enumerate() + .filter_map(|(idx, field)| { + let (maybe_table_ref, _) = schema.qualified_field(idx); + if maybe_table_ref.map_or(true, |q| q == qualifier) { + Some((maybe_table_ref.cloned(), Arc::clone(field))) + } else { + None + } + }) + .collect::>(); + let metadata = schema.metadata().clone(); + DFSchema::new_with_metadata(qualified_fields, metadata) + .map(Arc::new) + }, + )?; + let excluded = get_excluded_columns( + options.exclude.as_ref(), + options.except.as_ref(), + related_wildcard_schema.as_ref(), + Some(qualifier), + )? + .into_iter() + .collect::>(); + Ok( + get_exprs_except_skipped(related_wildcard_schema.as_ref(), excluded) + .len(), + ) + } + _ => Ok(1), + }) + .sum() } /// Convert an expression into Column expression if it's already provided as input plan. @@ -789,37 +878,21 @@ pub fn exprlist_to_fields( /// .aggregate(vec![col("c1")], vec![sum(col("c2"))])? /// .project(vec![col("c1"), col("SUM(c2)")? /// ``` -pub fn columnize_expr(e: Expr, input_schema: &DFSchema) -> Expr { - match e { - Expr::Column(_) => e, - Expr::OuterReferenceColumn(_, _) => e, - Expr::Alias(Alias { - expr, - relation, - name, - }) => columnize_expr(*expr, input_schema).alias_qualified(relation, name), - Expr::Cast(Cast { expr, data_type }) => Expr::Cast(Cast { - expr: Box::new(columnize_expr(*expr, input_schema)), - data_type, - }), - Expr::TryCast(TryCast { expr, data_type }) => Expr::TryCast(TryCast::new( - Box::new(columnize_expr(*expr, input_schema)), - data_type, +pub fn columnize_expr(e: Expr, input: &LogicalPlan) -> Result { + let output_exprs = match input.columnized_output_exprs() { + Ok(exprs) if !exprs.is_empty() => exprs, + _ => return Ok(e), + }; + let exprs_map: HashMap<&Expr, Column> = output_exprs.into_iter().collect(); + e.transform_down(|node: Expr| match exprs_map.get(&node) { + Some(column) => Ok(Transformed::new( + Expr::Column(column.clone()), + true, + TreeNodeRecursion::Jump, )), - Expr::ScalarSubquery(_) => e.clone(), - _ => match e.display_name() { - Ok(name) => { - match input_schema.qualified_field_with_unqualified_name(&name) { - Ok((qualifier, field)) => { - Expr::Column(Column::from((qualifier, field))) - } - // expression not provided as input, do not convert to a column reference - Err(_) => e, - } - } - Err(_) => e, - }, - } + None => Ok(Transformed::no(node)), + }) + .data() } /// Collect all deeply nested `Expr::Column`'s. They are returned in order of @@ -852,7 +925,9 @@ pub fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result { let (qualifier, field) = plan.schema().qualified_field_from_column(col)?; Ok(Expr::from(Column::from((qualifier, field)))) } - _ => Ok(Expr::Column(Column::from_name(expr.display_name()?))), + _ => Ok(Expr::Column(Column::from_name( + expr.schema_name().to_string(), + ))), } } @@ -871,7 +946,7 @@ pub(crate) fn find_column_indexes_referenced_by_expr( } } Expr::Literal(_) => { - indexes.push(std::usize::MAX); + indexes.push(usize::MAX); } _ => {} } @@ -881,8 +956,8 @@ pub(crate) fn find_column_indexes_referenced_by_expr( indexes } -/// can this data type be used in hash join equal conditions?? -/// data types here come from function 'equal_rows', if more data types are supported +/// Can this data type be used in hash join equal conditions?? +/// Data types here come from function 'equal_rows', if more data types are supported /// in equal_rows(hash join), add those data types here to generate join logical plan. pub fn can_hash(data_type: &DataType) -> bool { match data_type { @@ -906,6 +981,7 @@ pub fn can_hash(data_type: &DataType) -> bool { }, DataType::Utf8 => true, DataType::LargeUtf8 => true, + DataType::Utf8View => true, DataType::Decimal128(_, _) => true, DataType::Date32 => true, DataType::Date64 => true, @@ -918,14 +994,15 @@ pub fn can_hash(data_type: &DataType) -> bool { DataType::List(_) => true, DataType::LargeList(_) => true, DataType::FixedSizeList(_, _) => true, + DataType::Struct(fields) => fields.iter().all(|f| can_hash(f.data_type())), _ => false, } } /// Check whether all columns are from the schema. pub fn check_all_columns_from_schema( - columns: &HashSet, - schema: DFSchemaRef, + columns: &HashSet<&Column>, + schema: &DFSchema, ) -> Result { for col in columns.iter() { let exist = schema.is_column_from_schema(col); @@ -949,19 +1026,19 @@ pub fn check_all_columns_from_schema( pub fn find_valid_equijoin_key_pair( left_key: &Expr, right_key: &Expr, - left_schema: DFSchemaRef, - right_schema: DFSchemaRef, + left_schema: &DFSchema, + right_schema: &DFSchema, ) -> Result> { - let left_using_columns = left_key.to_columns()?; - let right_using_columns = right_key.to_columns()?; + let left_using_columns = left_key.column_refs(); + let right_using_columns = right_key.column_refs(); // Conditions like a = 10, will be added to non-equijoin. if left_using_columns.is_empty() || right_using_columns.is_empty() { return Ok(None); } - if check_all_columns_from_schema(&left_using_columns, left_schema.clone())? - && check_all_columns_from_schema(&right_using_columns, right_schema.clone())? + if check_all_columns_from_schema(&left_using_columns, left_schema)? + && check_all_columns_from_schema(&right_using_columns, right_schema)? { return Ok(Some((left_key.clone(), right_key.clone()))); } else if check_all_columns_from_schema(&right_using_columns, left_schema)? @@ -1028,6 +1105,54 @@ fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<& } } +/// Iteratate parts in a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` +/// +/// See [`split_conjunction_owned`] for more details and an example. +pub fn iter_conjunction(expr: &Expr) -> impl Iterator { + let mut stack = vec![expr]; + std::iter::from_fn(move || { + while let Some(expr) = stack.pop() { + match expr { + Expr::BinaryExpr(BinaryExpr { + right, + op: Operator::And, + left, + }) => { + stack.push(right); + stack.push(left); + } + Expr::Alias(Alias { expr, .. }) => stack.push(expr), + other => return Some(other), + } + } + None + }) +} + +/// Iteratate parts in a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` +/// +/// See [`split_conjunction_owned`] for more details and an example. +pub fn iter_conjunction_owned(expr: Expr) -> impl Iterator { + let mut stack = vec![expr]; + std::iter::from_fn(move || { + while let Some(expr) = stack.pop() { + match expr { + Expr::BinaryExpr(BinaryExpr { + right, + op: Operator::And, + left, + }) => { + stack.push(*right); + stack.push(*left); + } + Expr::Alias(Alias { expr, .. }) => stack.push(*expr), + other => return Some(other), + } + } + None + }) +} + /// Splits an owned conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` /// /// This is often used to "split" filter expressions such as `col1 = 5 @@ -1147,7 +1272,7 @@ fn split_binary_impl<'a>( /// assert_eq!(conjunction(split), Some(expr)); /// ``` pub fn conjunction(filters: impl IntoIterator) -> Option { - filters.into_iter().reduce(|accum, expr| accum.and(expr)) + filters.into_iter().reduce(Expr::and) } /// Combines an array of filter expressions into a single filter @@ -1155,12 +1280,41 @@ pub fn conjunction(filters: impl IntoIterator) -> Option { /// logical OR. /// /// Returns None if the filters array is empty. +/// +/// # Example +/// ``` +/// # use datafusion_expr::{col, lit}; +/// # use datafusion_expr::utils::disjunction; +/// // a=1 OR b=2 +/// let expr = col("a").eq(lit(1)).or(col("b").eq(lit(2))); +/// +/// // [a=1, b=2] +/// let split = vec![ +/// col("a").eq(lit(1)), +/// col("b").eq(lit(2)), +/// ]; +/// +/// // use disjuncton to join them together with `OR` +/// assert_eq!(disjunction(split), Some(expr)); +/// ``` pub fn disjunction(filters: impl IntoIterator) -> Option { - filters.into_iter().reduce(|accum, expr| accum.or(expr)) + filters.into_iter().reduce(Expr::or) } -/// returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] with -/// its predicate be all `predicates` ANDed. +/// Returns a new [LogicalPlan] that filters the output of `plan` with a +/// [LogicalPlan::Filter] with all `predicates` ANDed. +/// +/// # Example +/// Before: +/// ```text +/// plan +/// ``` +/// +/// After: +/// ```text +/// Filter(predicate) +/// plan +/// ``` pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result { // reduce filters to a single filter with an AND let predicate = predicates @@ -1222,9 +1376,9 @@ pub fn only_or_err(slice: &[T]) -> Result<&T> { } /// merge inputs schema into a single schema. -pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema { +pub fn merge_schema(inputs: &[&LogicalPlan]) -> DFSchema { if inputs.len() == 1 { - inputs[0].schema().clone().as_ref().clone() + inputs[0].schema().as_ref().clone() } else { inputs.iter().map(|input| input.schema()).fold( DFSchema::empty(), @@ -1236,7 +1390,7 @@ pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema { } } -/// Build state name. State is the intermidiate state of the aggregate function. +/// Build state name. State is the intermediate state of the aggregate function. pub fn format_state_name(name: &str, state_name: &str) -> String { format!("{name}[{state_name}]") } @@ -1245,8 +1399,9 @@ pub fn format_state_name(name: &str, state_name: &str) -> String { mod tests { use super::*; use crate::{ - col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, AggregateFunction, - WindowFrame, WindowFunctionDefinition, + col, cube, expr_vec_fmt, grouping_set, lit, rollup, + test::function_stub::max_udaf, test::function_stub::min_udaf, + test::function_stub::sum_udaf, Cast, ExprFunctionExt, WindowFunctionDefinition, }; #[test] @@ -1259,37 +1414,21 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> { - let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + let max1 = Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], - vec![], - vec![], - WindowFrame::new(None), - None, )); - let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + let max2 = Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], - vec![], - vec![], - WindowFrame::new(None), - None, )); - let min3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), + let min3 = Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], - vec![], - vec![], - WindowFrame::new(None), - None, )); - let sum4 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), + let sum4 = Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], - vec![], - vec![], - WindowFrame::new(None), - None, )); let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; let result = group_window_expr_by_sort_keys(exprs.to_vec())?; @@ -1302,42 +1441,38 @@ mod tests { #[test] fn test_group_window_expr_by_sort_keys() -> Result<()> { - let age_asc = Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)); - let name_desc = Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)); - let created_at_desc = - Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)); - let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + let age_asc = Sort::new(col("age"), true, true); + let name_desc = Sort::new(col("name"), false, true); + let created_at_desc = Sort::new(col("created_at"), false, true); + let max1 = Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], - vec![], - vec![age_asc.clone(), name_desc.clone()], - WindowFrame::new(Some(false)), - None, - )); - let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + )) + .order_by(vec![age_asc.clone(), name_desc.clone()]) + .build() + .unwrap(); + let max2 = Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("name")], - vec![], - vec![], - WindowFrame::new(None), - None, )); - let min3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Min), + let min3 = Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(min_udaf()), vec![col("name")], - vec![], - vec![age_asc.clone(), name_desc.clone()], - WindowFrame::new(Some(false)), - None, - )); - let sum4 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), + )) + .order_by(vec![age_asc.clone(), name_desc.clone()]) + .build() + .unwrap(); + let sum4 = Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(sum_udaf()), vec![col("age")], - vec![], - vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()], - WindowFrame::new(Some(false)), - None, - )); + )) + .order_by(vec![ + name_desc.clone(), + age_asc.clone(), + created_at_desc.clone(), + ]) + .build() + .unwrap(); // FIXME use as_ref let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; let result = group_window_expr_by_sort_keys(exprs.to_vec())?; @@ -1359,43 +1494,6 @@ mod tests { Ok(()) } - #[test] - fn test_find_sort_exprs() -> Result<()> { - let exprs = &[ - Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), - vec![col("name")], - vec![], - vec![ - Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), - Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), - ], - WindowFrame::new(Some(false)), - None, - )), - Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Sum), - vec![col("age")], - vec![], - vec![ - Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), - Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), - Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)), - ], - WindowFrame::new(Some(false)), - None, - )), - ]; - let expected = vec![ - Expr::Sort(expr::Sort::new(Box::new(col("age")), true, true)), - Expr::Sort(expr::Sort::new(Box::new(col("name")), false, true)), - Expr::Sort(expr::Sort::new(Box::new(col("created_at")), false, true)), - ]; - let result = find_sort_exprs(exprs); - assert_eq!(expected, result); - Ok(()) - } - #[test] fn avoid_generate_duplicate_sort_keys() -> Result<()> { let asc_or_desc = [true, false]; @@ -1404,41 +1502,41 @@ mod tests { for asc_ in asc_or_desc { for nulls_first_ in nulls_first_or_last { let order_by = &[ - Expr::Sort(Sort { - expr: Box::new(col("age")), + Sort { + expr: col("age"), asc: asc_, nulls_first: nulls_first_, - }), - Expr::Sort(Sort { - expr: Box::new(col("name")), + }, + Sort { + expr: col("name"), asc: asc_, nulls_first: nulls_first_, - }), + }, ]; let expected = vec![ ( - Expr::Sort(Sort { - expr: Box::new(col("age")), + Sort { + expr: col("age"), asc: asc_, nulls_first: nulls_first_, - }), + }, true, ), ( - Expr::Sort(Sort { - expr: Box::new(col("name")), + Sort { + expr: col("name"), asc: asc_, nulls_first: nulls_first_, - }), + }, true, ), ( - Expr::Sort(Sort { - expr: Box::new(col("created_at")), + Sort { + expr: col("created_at"), asc: true, nulls_first: false, - }), + }, true, ), ]; diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index c0617eaf4ed4..222914315d70 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -23,12 +23,11 @@ //! - An ending frame boundary, //! - An EXCLUDE clause. +use crate::{expr::Sort, lit}; +use arrow::datatypes::DataType; use std::fmt::{self, Formatter}; use std::hash::Hash; -use crate::expr::Sort; -use crate::Expr; - use datafusion_common::{plan_err, sql_err, DataFusionError, Result, ScalarValue}; use sqlparser::ast; use sqlparser::parser::ParserError::ParserError; @@ -37,7 +36,7 @@ use sqlparser::parser::ParserError::ParserError; /// window function. The ending frame boundary can be omitted if the `BETWEEN` /// and `AND` keywords that surround the starting frame boundary are also omitted, /// in which case the ending frame boundary defaults to `CURRENT ROW`. -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash)] pub struct WindowFrame { /// Frame type - either `ROWS`, `RANGE` or `GROUPS` pub units: WindowFrameUnits, @@ -95,7 +94,7 @@ pub struct WindowFrame { } impl fmt::Display for WindowFrame { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { write!( f, "{} BETWEEN {} AND {}", @@ -120,9 +119,9 @@ impl TryFrom for WindowFrame { type Error = DataFusionError; fn try_from(value: ast::WindowFrame) -> Result { - let start_bound = value.start_bound.try_into()?; + let start_bound = WindowFrameBound::try_parse(value.start_bound, &value.units)?; let end_bound = match value.end_bound { - Some(value) => value.try_into()?, + Some(bound) => WindowFrameBound::try_parse(bound, &value.units)?, None => WindowFrameBound::CurrentRow, }; @@ -139,6 +138,7 @@ impl TryFrom for WindowFrame { )? } }; + let units = value.units.into(); Ok(Self::new_bounds(units, start_bound, end_bound)) } @@ -246,59 +246,51 @@ impl WindowFrame { causal, } } -} -/// Regularizes ORDER BY clause for window definition for implicit corner cases. -pub fn regularize_window_order_by( - frame: &WindowFrame, - order_by: &mut Vec, -) -> Result<()> { - if frame.units == WindowFrameUnits::Range && order_by.len() != 1 { - // Normally, RANGE frames require an ORDER BY clause with exactly one - // column. However, an ORDER BY clause may be absent or present but with - // more than one column in two edge cases: - // 1. start bound is UNBOUNDED or CURRENT ROW - // 2. end bound is CURRENT ROW or UNBOUNDED. - // In these cases, we regularize the ORDER BY clause if the ORDER BY clause - // is absent. If an ORDER BY clause is present but has more than one column, - // the ORDER BY clause is unchanged. Note that this follows Postgres behavior. - if (frame.start_bound.is_unbounded() - || frame.start_bound == WindowFrameBound::CurrentRow) - && (frame.end_bound == WindowFrameBound::CurrentRow - || frame.end_bound.is_unbounded()) - { - // If an ORDER BY clause is absent, it is equivalent to a ORDER BY clause - // with constant value as sort key. - // If an ORDER BY clause is present but has more than one column, it is - // unchanged. - if order_by.is_empty() { - order_by.push(Expr::Sort(Sort::new( - Box::new(Expr::Literal(ScalarValue::UInt64(Some(1)))), - true, - false, - ))); + /// Regularizes the ORDER BY clause of the window frame. + pub fn regularize_order_bys(&self, order_by: &mut Vec) -> Result<()> { + match self.units { + // Normally, RANGE frames require an ORDER BY clause with exactly + // one column. However, an ORDER BY clause may be absent or have + // more than one column when the start/end bounds are UNBOUNDED or + // CURRENT ROW. + WindowFrameUnits::Range if self.free_range() => { + // If an ORDER BY clause is absent, it is equivalent to an + // ORDER BY clause with constant value as sort key. If an + // ORDER BY clause is present but has more than one column, + // it is unchanged. Note that this follows PostgreSQL behavior. + if order_by.is_empty() { + order_by.push(lit(1u64).sort(true, false)); + } + } + WindowFrameUnits::Range if order_by.len() != 1 => { + return plan_err!("RANGE requires exactly one ORDER BY column"); + } + WindowFrameUnits::Groups if order_by.is_empty() => { + return plan_err!("GROUPS requires an ORDER BY clause"); } + _ => {} } + Ok(()) } - Ok(()) -} -/// Checks if given window frame is valid. In particular, if the frame is RANGE -/// with offset PRECEDING/FOLLOWING, it must have exactly one ORDER BY column. -pub fn check_window_frame(frame: &WindowFrame, order_bys: usize) -> Result<()> { - if frame.units == WindowFrameUnits::Range && order_bys != 1 { - // See `regularize_window_order_by`. - if !(frame.start_bound.is_unbounded() - || frame.start_bound == WindowFrameBound::CurrentRow) - || !(frame.end_bound == WindowFrameBound::CurrentRow - || frame.end_bound.is_unbounded()) - { - plan_err!("RANGE requires exactly one ORDER BY column")? + /// Returns whether the window frame can accept multiple ORDER BY expressons. + pub fn can_accept_multi_orderby(&self) -> bool { + match self.units { + WindowFrameUnits::Rows => true, + WindowFrameUnits::Range => self.free_range(), + WindowFrameUnits::Groups => true, } - } else if frame.units == WindowFrameUnits::Groups && order_bys == 0 { - plan_err!("GROUPS requires an ORDER BY clause")? - }; - Ok(()) + } + + /// Returns whether the window frame is "free range"; i.e. its start/end + /// bounds are UNBOUNDED or CURRENT ROW. + fn free_range(&self) -> bool { + (self.start_bound.is_unbounded() + || self.start_bound == WindowFrameBound::CurrentRow) + && (self.end_bound.is_unbounded() + || self.end_bound == WindowFrameBound::CurrentRow) + } } /// There are five ways to describe starting and ending frame boundaries: @@ -309,14 +301,14 @@ pub fn check_window_frame(frame: &WindowFrame, order_bys: usize) -> Result<()> { /// 4. `` FOLLOWING /// 5. UNBOUNDED FOLLOWING /// -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum WindowFrameBound { /// 1. UNBOUNDED PRECEDING - /// The frame boundary is the first row in the partition. + /// The frame boundary is the first row in the partition. /// /// 2. `` PRECEDING - /// `` must be a non-negative constant numeric expression. The boundary is a row that - /// is `` "units" prior to the current row. + /// `` must be a non-negative constant numeric expression. The boundary is a row that + /// is `` "units" prior to the current row. Preceding(ScalarValue), /// 3. The current row. /// @@ -326,10 +318,10 @@ pub enum WindowFrameBound { /// boundary. CurrentRow, /// 4. This is the same as "`` PRECEDING" except that the boundary is `` units after the - /// current rather than before the current row. + /// current rather than before the current row. /// /// 5. UNBOUNDED FOLLOWING - /// The frame boundary is the last row in the partition. + /// The frame boundary is the last row in the partition. Following(ScalarValue), } @@ -343,17 +335,18 @@ impl WindowFrameBound { } } -impl TryFrom for WindowFrameBound { - type Error = DataFusionError; - - fn try_from(value: ast::WindowFrameBound) -> Result { +impl WindowFrameBound { + fn try_parse( + value: ast::WindowFrameBound, + units: &ast::WindowFrameUnits, + ) -> Result { Ok(match value { ast::WindowFrameBound::Preceding(Some(v)) => { - Self::Preceding(convert_frame_bound_to_scalar_value(*v)?) + Self::Preceding(convert_frame_bound_to_scalar_value(*v, units)?) } ast::WindowFrameBound::Preceding(None) => Self::Preceding(ScalarValue::Null), ast::WindowFrameBound::Following(Some(v)) => { - Self::Following(convert_frame_bound_to_scalar_value(*v)?) + Self::Following(convert_frame_bound_to_scalar_value(*v, units)?) } ast::WindowFrameBound::Following(None) => Self::Following(ScalarValue::Null), ast::WindowFrameBound::CurrentRow => Self::CurrentRow, @@ -361,37 +354,69 @@ impl TryFrom for WindowFrameBound { } } -pub fn convert_frame_bound_to_scalar_value(v: ast::Expr) -> Result { - Ok(ScalarValue::Utf8(Some(match v { - ast::Expr::Value(ast::Value::Number(value, false)) - | ast::Expr::Value(ast::Value::SingleQuotedString(value)) => value, - ast::Expr::Interval(ast::Interval { - value, - leading_field, - .. - }) => { - let result = match *value { - ast::Expr::Value(ast::Value::SingleQuotedString(item)) => item, - e => { - return sql_err!(ParserError(format!( - "INTERVAL expression cannot be {e:?}" - ))); +fn convert_frame_bound_to_scalar_value( + v: ast::Expr, + units: &ast::WindowFrameUnits, +) -> Result { + match units { + // For ROWS and GROUPS we are sure that the ScalarValue must be a non-negative integer ... + ast::WindowFrameUnits::Rows | ast::WindowFrameUnits::Groups => match v { + ast::Expr::Value(ast::Value::Number(value, false)) => { + Ok(ScalarValue::try_from_string(value, &DataType::UInt64)?) + }, + ast::Expr::Interval(ast::Interval { + value, + leading_field: None, + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }) => { + let value = match *value { + ast::Expr::Value(ast::Value::SingleQuotedString(item)) => item, + e => { + return sql_err!(ParserError(format!( + "INTERVAL expression cannot be {e:?}" + ))); + } + }; + Ok(ScalarValue::try_from_string(value, &DataType::UInt64)?) + } + _ => plan_err!( + "Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers" + ), + }, + // ... instead for RANGE it could be anything depending on the type of the ORDER BY clause, + // so we use a ScalarValue::Utf8. + ast::WindowFrameUnits::Range => Ok(ScalarValue::Utf8(Some(match v { + ast::Expr::Value(ast::Value::Number(value, false)) => value, + ast::Expr::Interval(ast::Interval { + value, + leading_field, + .. + }) => { + let result = match *value { + ast::Expr::Value(ast::Value::SingleQuotedString(item)) => item, + e => { + return sql_err!(ParserError(format!( + "INTERVAL expression cannot be {e:?}" + ))); + } + }; + if let Some(leading_field) = leading_field { + format!("{result} {leading_field}") + } else { + result } - }; - if let Some(leading_field) = leading_field { - format!("{result} {leading_field}") - } else { - result } - } - _ => plan_err!( - "Invalid window frame: frame offsets must be non negative integers" - )?, - }))) + _ => plan_err!( + "Invalid window frame: frame offsets for RANGE must be either a numeric value, a string value or an interval" + )?, + }))), + } } impl fmt::Display for WindowFrameBound { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self { WindowFrameBound::Preceding(n) => { if n.is_null() { @@ -432,7 +457,7 @@ pub enum WindowFrameUnits { } impl fmt::Display for WindowFrameUnits { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { f.write_str(match self { WindowFrameUnits::Rows => "ROWS", WindowFrameUnits::Range => "RANGE", @@ -488,8 +513,91 @@ mod tests { ast::Expr::Value(ast::Value::Number("1".to_string(), false)), )))), }; - let result = WindowFrame::try_from(window_frame); - assert!(result.is_ok()); + + let window_frame = WindowFrame::try_from(window_frame)?; + assert_eq!(window_frame.units, WindowFrameUnits::Rows); + assert_eq!( + window_frame.start_bound, + WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))) + ); + assert_eq!( + window_frame.end_bound, + WindowFrameBound::Preceding(ScalarValue::UInt64(Some(1))) + ); + + Ok(()) + } + + macro_rules! test_bound { + ($unit:ident, $value:expr, $expected:expr) => { + let preceding = WindowFrameBound::try_parse( + ast::WindowFrameBound::Preceding($value), + &ast::WindowFrameUnits::$unit, + )?; + assert_eq!(preceding, WindowFrameBound::Preceding($expected)); + let following = WindowFrameBound::try_parse( + ast::WindowFrameBound::Following($value), + &ast::WindowFrameUnits::$unit, + )?; + assert_eq!(following, WindowFrameBound::Following($expected)); + }; + } + + macro_rules! test_bound_err { + ($unit:ident, $value:expr, $expected:expr) => { + let err = WindowFrameBound::try_parse( + ast::WindowFrameBound::Preceding($value), + &ast::WindowFrameUnits::$unit, + ) + .unwrap_err(); + assert_eq!(err.strip_backtrace(), $expected); + let err = WindowFrameBound::try_parse( + ast::WindowFrameBound::Following($value), + &ast::WindowFrameUnits::$unit, + ) + .unwrap_err(); + assert_eq!(err.strip_backtrace(), $expected); + }; + } + + #[test] + fn test_window_frame_bound_creation() -> Result<()> { + // Unbounded + test_bound!(Rows, None, ScalarValue::Null); + test_bound!(Groups, None, ScalarValue::Null); + test_bound!(Range, None, ScalarValue::Null); + + // Number + let number = Some(Box::new(ast::Expr::Value(ast::Value::Number( + "42".to_string(), + false, + )))); + test_bound!(Rows, number.clone(), ScalarValue::UInt64(Some(42))); + test_bound!(Groups, number.clone(), ScalarValue::UInt64(Some(42))); + test_bound!( + Range, + number.clone(), + ScalarValue::Utf8(Some("42".to_string())) + ); + + // Interval + let number = Some(Box::new(ast::Expr::Interval(ast::Interval { + value: Box::new(ast::Expr::Value(ast::Value::SingleQuotedString( + "1".to_string(), + ))), + leading_field: Some(ast::DateTimeField::Day), + fractional_seconds_precision: None, + last_field: None, + leading_precision: None, + }))); + test_bound_err!(Rows, number.clone(), "Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers"); + test_bound_err!(Groups, number.clone(), "Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers"); + test_bound!( + Range, + number.clone(), + ScalarValue::Utf8(Some("1 DAY".to_string())) + ); + Ok(()) } } diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs new file mode 100644 index 000000000000..be2b6575e2e9 --- /dev/null +++ b/datafusion/expr/src/window_function.rs @@ -0,0 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{expr::WindowFunction, BuiltInWindowFunction, Expr, Literal}; + +/// Create an expression to represent the `nth_value` window function +pub fn nth_value(arg: Expr, n: i64) -> Expr { + Expr::WindowFunction(WindowFunction::new( + BuiltInWindowFunction::NthValue, + vec![arg, n.lit()], + )) +} diff --git a/datafusion/expr/src/window_state.rs b/datafusion/expr/src/window_state.rs index e7f31bbfbf2b..f1d0ead23ab1 100644 --- a/datafusion/expr/src/window_state.rs +++ b/datafusion/expr/src/window_state.rs @@ -48,7 +48,7 @@ pub struct WindowAggState { /// Keeps track of how many rows should be generated to be in sync with input record_batch. // (For each row in the input record batch we need to generate a window result). pub n_row_result_missing: usize, - /// flag indicating whether we have received all data for this partition + /// Flag indicating whether we have received all data for this partition pub is_end: bool, } diff --git a/datafusion/ffi/Cargo.toml b/datafusion/ffi/Cargo.toml new file mode 100644 index 000000000000..119747342515 --- /dev/null +++ b/datafusion/ffi/Cargo.toml @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-ffi" +description = "Foreign Function Interface implementation for DataFusion" +readme = "README.md" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +authors = { workspace = true } +# Specify MSRV here as `cargo msrv` doesn't support workspace version +rust-version = "1.76" + +[lints] +workspace = true + +[lib] +name = "datafusion_ffi" +path = "src/lib.rs" + +[dependencies] +abi_stable = "0.11.3" +arrow = { workspace = true, features = ["ffi"] } +async-ffi = { version = "0.5.0", features = ["abi_stable"] } +async-trait = { workspace = true } +datafusion = { workspace = true, default-features = false } +datafusion-proto = { workspace = true } +doc-comment = { workspace = true } +futures = { workspace = true } +log = { workspace = true } +prost = { workspace = true } + +[dev-dependencies] +tokio = { workspace = true } diff --git a/datafusion/ffi/README.md b/datafusion/ffi/README.md new file mode 100644 index 000000000000..ba4bb8b961a1 --- /dev/null +++ b/datafusion/ffi/README.md @@ -0,0 +1,81 @@ + + +# `datafusion-ffi`: Apache DataFusion Foreign Function Interface + +This crate contains code to allow interoperability of Apache [DataFusion] +with functions from other languages using a stable interface. + +See [API Docs] for details and examples. + +We expect this crate may be used by both sides of the FFI. This allows users +to create modules that can interoperate with the necessity of using the same +version of DataFusion. The driving use case has been the `datafusion-python` +repository, but many other use cases may exist. We envision at least two +use cases. + +1. `datafusion-python` which will use the FFI to provide external services such + as a `TableProvider` without needing to re-export the entire `datafusion-python` + code base. With `datafusion-ffi` these packages do not need `datafusion-python` + as a dependency at all. +2. Users may want to create a modular interface that allows runtime loading of + libraries. + +## Struct Layout + +In this crate we have a variety of structs which closely mimic the behavior of +their internal counterparts. In the following example, we will refer to the +`TableProvider`, but the same pattern exists for other structs. + +Each of the exposted structs in this crate is provided with a variant prefixed +with `Foreign`. This variant is designed to be used by the consumer of the +foreign code. The `Foreign` structs should _never_ access the `private_data` +fields. Instead they should only access the data returned through the function +calls defined on the `FFI_` structs. The second purpose of the `Foreign` +structs is to contain additional data that may be needed by the traits that +are implemented on them. Some of these traits require borrowing data which +can be far more convienent to be locally stored. + +For example, we have a struct `FFI_TableProvider` to give access to the +`TableProvider` functions like `table_type()` and `scan()`. If we write a +library that wishes to expose it's `TableProvider`, then we can access the +private data that contains the Arc reference to the `TableProvider` via +`FFI_TableProvider`. This data is local to the library. + +If we have a program that accesses a `TableProvider` via FFI, then it +will use `ForeignTableProvider`. When using `ForeignTableProvider` we **must** +not attempt to access the `private_data` field in `FFI_TableProvider`. If a +user is testing locally, you may be able to successfully access this field, but +it will only work if you are building against the exact same version of +`DataFusion` for both libraries **and** the same compiler. It will not work +in general. + +It is worth noting that which library is the `local` and which is `foreign` +depends on which interface we are considering. For example, suppose we have a +Python library called `my_provider` that exposes a `TableProvider` called +`MyProvider` via `FFI_TableProvider`. Within the library `my_provider` we can +access the `private_data` via `FFI_TableProvider`. We connect this to +`datafusion-python`, where we access it as a `ForeignTableProvider`. Now when +we call `scan()` on this interface, we have to pass it a `FFI_SessionConfig`. +The `SessionConfig` is local to `datafusion-python` and **not** `my_provider`. +It is important to be careful when expanding these functions to be certain which +side of the interface each object refers to. + +[datafusion]: https://datafusion.apache.org +[api docs]: http://docs.rs/datafusion-ffi/latest diff --git a/datafusion/ffi/src/arrow_wrappers.rs b/datafusion/ffi/src/arrow_wrappers.rs new file mode 100644 index 000000000000..c5add8782c51 --- /dev/null +++ b/datafusion/ffi/src/arrow_wrappers.rs @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use abi_stable::StableAbi; +use arrow::{ + datatypes::{Schema, SchemaRef}, + ffi::{FFI_ArrowArray, FFI_ArrowSchema}, +}; +use log::error; + +/// This is a wrapper struct around FFI_ArrowSchema simply to indicate +/// to the StableAbi macros that the underlying struct is FFI safe. +#[repr(C)] +#[derive(Debug, StableAbi)] +pub struct WrappedSchema(#[sabi(unsafe_opaque_field)] pub FFI_ArrowSchema); + +impl From for WrappedSchema { + fn from(value: SchemaRef) -> Self { + let ffi_schema = match FFI_ArrowSchema::try_from(value.as_ref()) { + Ok(s) => s, + Err(e) => { + error!("Unable to convert DataFusion Schema to FFI_ArrowSchema in FFI_PlanProperties. {}", e); + FFI_ArrowSchema::empty() + } + }; + + WrappedSchema(ffi_schema) + } +} + +impl From for SchemaRef { + fn from(value: WrappedSchema) -> Self { + let schema = match Schema::try_from(&value.0) { + Ok(s) => s, + Err(e) => { + error!("Unable to convert from FFI_ArrowSchema to DataFusion Schema in FFI_PlanProperties. {}", e); + Schema::empty() + } + }; + Arc::new(schema) + } +} + +/// This is a wrapper struct for FFI_ArrowArray to indicate to StableAbi +/// that the struct is FFI Safe. For convenience, we also include the +/// schema needed to create a record batch from the array. +#[repr(C)] +#[derive(Debug, StableAbi)] +pub struct WrappedArray { + #[sabi(unsafe_opaque_field)] + pub array: FFI_ArrowArray, + + pub schema: WrappedSchema, +} diff --git a/datafusion/ffi/src/execution_plan.rs b/datafusion/ffi/src/execution_plan.rs new file mode 100644 index 000000000000..d10eda8990b8 --- /dev/null +++ b/datafusion/ffi/src/execution_plan.rs @@ -0,0 +1,361 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ffi::c_void, pin::Pin, sync::Arc}; + +use abi_stable::{ + std_types::{RResult, RString, RVec}, + StableAbi, +}; +use datafusion::error::Result; +use datafusion::{ + error::DataFusionError, + execution::{SendableRecordBatchStream, TaskContext}, + physical_plan::{DisplayAs, ExecutionPlan, PlanProperties}, +}; + +use crate::{ + plan_properties::FFI_PlanProperties, record_batch_stream::FFI_RecordBatchStream, +}; + +/// A stable struct for sharing a [`ExecutionPlan`] across FFI boundaries. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_ExecutionPlan { + /// Return the plan properties + pub properties: unsafe extern "C" fn(plan: &Self) -> FFI_PlanProperties, + + /// Return a vector of children plans + pub children: unsafe extern "C" fn(plan: &Self) -> RVec, + + /// Return the plan name. + pub name: unsafe extern "C" fn(plan: &Self) -> RString, + + /// Execute the plan and return a record batch stream. Errors + /// will be returned as a string. + pub execute: unsafe extern "C" fn( + plan: &Self, + partition: usize, + ) -> RResult, + + /// Used to create a clone on the provider of the execution plan. This should + /// only need to be called by the receiver of the plan. + pub clone: unsafe extern "C" fn(plan: &Self) -> Self, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(arg: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the plan. + /// A [`ForeignExecutionPlan`] should never attempt to access this data. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_ExecutionPlan {} +unsafe impl Sync for FFI_ExecutionPlan {} + +pub struct ExecutionPlanPrivateData { + pub plan: Arc, + pub context: Arc, +} + +unsafe extern "C" fn properties_fn_wrapper( + plan: &FFI_ExecutionPlan, +) -> FFI_PlanProperties { + let private_data = plan.private_data as *const ExecutionPlanPrivateData; + let plan = &(*private_data).plan; + + plan.properties().into() +} + +unsafe extern "C" fn children_fn_wrapper( + plan: &FFI_ExecutionPlan, +) -> RVec { + let private_data = plan.private_data as *const ExecutionPlanPrivateData; + let plan = &(*private_data).plan; + let ctx = &(*private_data).context; + + let children: Vec<_> = plan + .children() + .into_iter() + .map(|child| FFI_ExecutionPlan::new(Arc::clone(child), Arc::clone(ctx))) + .collect(); + + children.into() +} + +unsafe extern "C" fn execute_fn_wrapper( + plan: &FFI_ExecutionPlan, + partition: usize, +) -> RResult { + let private_data = plan.private_data as *const ExecutionPlanPrivateData; + let plan = &(*private_data).plan; + let ctx = &(*private_data).context; + + match plan.execute(partition, Arc::clone(ctx)) { + Ok(rbs) => RResult::ROk(rbs.into()), + Err(e) => RResult::RErr( + format!("Error occurred during FFI_ExecutionPlan execute: {}", e).into(), + ), + } +} +unsafe extern "C" fn name_fn_wrapper(plan: &FFI_ExecutionPlan) -> RString { + let private_data = plan.private_data as *const ExecutionPlanPrivateData; + let plan = &(*private_data).plan; + + plan.name().into() +} + +unsafe extern "C" fn release_fn_wrapper(plan: &mut FFI_ExecutionPlan) { + let private_data = Box::from_raw(plan.private_data as *mut ExecutionPlanPrivateData); + drop(private_data); +} + +unsafe extern "C" fn clone_fn_wrapper(plan: &FFI_ExecutionPlan) -> FFI_ExecutionPlan { + let private_data = plan.private_data as *const ExecutionPlanPrivateData; + let plan_data = &(*private_data); + + FFI_ExecutionPlan::new(Arc::clone(&plan_data.plan), Arc::clone(&plan_data.context)) +} + +impl Clone for FFI_ExecutionPlan { + fn clone(&self) -> Self { + unsafe { (self.clone)(self) } + } +} + +impl FFI_ExecutionPlan { + /// This function is called on the provider's side. + pub fn new(plan: Arc, context: Arc) -> Self { + let private_data = Box::new(ExecutionPlanPrivateData { plan, context }); + + Self { + properties: properties_fn_wrapper, + children: children_fn_wrapper, + name: name_fn_wrapper, + execute: execute_fn_wrapper, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + } + } +} + +impl Drop for FFI_ExecutionPlan { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +/// This struct is used to access an execution plan provided by a foreign +/// library across a FFI boundary. +/// +/// The ForeignExecutionPlan is to be used by the caller of the plan, so it has +/// no knowledge or access to the private data. All interaction with the plan +/// must occur through the functions defined in FFI_ExecutionPlan. +#[derive(Debug)] +pub struct ForeignExecutionPlan { + name: String, + plan: FFI_ExecutionPlan, + properties: PlanProperties, + children: Vec>, +} + +unsafe impl Send for ForeignExecutionPlan {} +unsafe impl Sync for ForeignExecutionPlan {} + +impl DisplayAs for ForeignExecutionPlan { + fn fmt_as( + &self, + _t: datafusion::physical_plan::DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + write!( + f, + "FFI_ExecutionPlan(number_of_children={})", + self.children.len(), + ) + } +} + +impl TryFrom<&FFI_ExecutionPlan> for ForeignExecutionPlan { + type Error = DataFusionError; + + fn try_from(plan: &FFI_ExecutionPlan) -> Result { + unsafe { + let name = (plan.name)(plan).into(); + + let properties: PlanProperties = (plan.properties)(plan).try_into()?; + + let children_rvec = (plan.children)(plan); + let children: Result> = children_rvec + .iter() + .map(ForeignExecutionPlan::try_from) + .map(|child| child.map(|c| Arc::new(c) as Arc)) + .collect(); + + Ok(Self { + name, + plan: plan.clone(), + properties, + children: children?, + }) + } + } +} + +impl ExecutionPlan for ForeignExecutionPlan { + fn name(&self) -> &str { + &self.name + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.properties + } + + fn children(&self) -> Vec<&Arc> { + self.children + .iter() + .map(|p| p as &Arc) + .collect() + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(ForeignExecutionPlan { + plan: self.plan.clone(), + name: self.name.clone(), + children, + properties: self.properties.clone(), + })) + } + + fn execute( + &self, + partition: usize, + _context: Arc, + ) -> Result { + unsafe { + match (self.plan.execute)(&self.plan, partition) { + RResult::ROk(stream) => { + let stream = Pin::new(Box::new(stream)) as SendableRecordBatchStream; + Ok(stream) + } + RResult::RErr(e) => Err(DataFusionError::Execution(format!( + "Error occurred during FFI call to FFI_ExecutionPlan execute. {}", + e + ))), + } + } + } +} + +#[cfg(test)] +mod tests { + use datafusion::{physical_plan::Partitioning, prelude::SessionContext}; + + use super::*; + + #[derive(Debug)] + pub struct EmptyExec { + props: PlanProperties, + } + + impl EmptyExec { + pub fn new(schema: arrow::datatypes::SchemaRef) -> Self { + Self { + props: PlanProperties::new( + datafusion::physical_expr::EquivalenceProperties::new(schema), + Partitioning::UnknownPartitioning(3), + datafusion::physical_plan::ExecutionMode::Unbounded, + ), + } + } + } + + impl DisplayAs for EmptyExec { + fn fmt_as( + &self, + _t: datafusion::physical_plan::DisplayFormatType, + _f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + unimplemented!() + } + } + + impl ExecutionPlan for EmptyExec { + fn name(&self) -> &'static str { + "empty-exec" + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.props + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> Result> { + unimplemented!() + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + unimplemented!() + } + + fn statistics(&self) -> Result { + unimplemented!() + } + } + + #[test] + fn test_round_trip_ffi_execution_plan() -> Result<()> { + use arrow::datatypes::{DataType, Field, Schema}; + let schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); + let ctx = SessionContext::new(); + + let original_plan = Arc::new(EmptyExec::new(schema)); + let original_name = original_plan.name().to_string(); + + let local_plan = FFI_ExecutionPlan::new(original_plan, ctx.task_ctx()); + + let foreign_plan: ForeignExecutionPlan = (&local_plan).try_into()?; + + assert!(original_name == foreign_plan.name()); + + Ok(()) + } +} diff --git a/datafusion/ffi/src/lib.rs b/datafusion/ffi/src/lib.rs new file mode 100644 index 000000000000..4a74e65dc671 --- /dev/null +++ b/datafusion/ffi/src/lib.rs @@ -0,0 +1,29 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] + +pub mod arrow_wrappers; +pub mod execution_plan; +pub mod plan_properties; +pub mod record_batch_stream; +pub mod session_config; +pub mod table_provider; +pub mod table_source; + +#[cfg(doctest)] +doc_comment::doctest!("../README.md", readme_example_test); diff --git a/datafusion/ffi/src/plan_properties.rs b/datafusion/ffi/src/plan_properties.rs new file mode 100644 index 000000000000..722681ae4a1d --- /dev/null +++ b/datafusion/ffi/src/plan_properties.rs @@ -0,0 +1,297 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ffi::c_void, sync::Arc}; + +use abi_stable::{ + std_types::{ + RResult::{self, RErr, ROk}, + RStr, RVec, + }, + StableAbi, +}; +use arrow::datatypes::SchemaRef; +use datafusion::{ + error::{DataFusionError, Result}, + physical_expr::EquivalenceProperties, + physical_plan::{ExecutionMode, PlanProperties}, + prelude::SessionContext, +}; +use datafusion_proto::{ + physical_plan::{ + from_proto::{parse_physical_sort_exprs, parse_protobuf_partitioning}, + to_proto::{serialize_partitioning, serialize_physical_sort_exprs}, + DefaultPhysicalExtensionCodec, + }, + protobuf::{Partitioning, PhysicalSortExprNodeCollection}, +}; +use prost::Message; + +use crate::arrow_wrappers::WrappedSchema; + +/// A stable struct for sharing [`PlanProperties`] across FFI boundaries. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_PlanProperties { + /// The output partitioning is a [`Partitioning`] protobuf message serialized + /// into bytes to pass across the FFI boundary. + pub output_partitioning: + unsafe extern "C" fn(plan: &Self) -> RResult, RStr<'static>>, + + /// Return the execution mode of the plan. + pub execution_mode: unsafe extern "C" fn(plan: &Self) -> FFI_ExecutionMode, + + /// The output ordering is a [`PhysicalSortExprNodeCollection`] protobuf message + /// serialized into bytes to pass across the FFI boundary. + pub output_ordering: + unsafe extern "C" fn(plan: &Self) -> RResult, RStr<'static>>, + + /// Return the schema of the plan. + pub schema: unsafe extern "C" fn(plan: &Self) -> WrappedSchema, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(arg: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the plan. + /// The foreign library should never attempt to access this data. + pub private_data: *mut c_void, +} + +struct PlanPropertiesPrivateData { + props: PlanProperties, +} + +unsafe extern "C" fn output_partitioning_fn_wrapper( + properties: &FFI_PlanProperties, +) -> RResult, RStr<'static>> { + let private_data = properties.private_data as *const PlanPropertiesPrivateData; + let props = &(*private_data).props; + + let codec = DefaultPhysicalExtensionCodec {}; + let partitioning_data = + match serialize_partitioning(props.output_partitioning(), &codec) { + Ok(p) => p, + Err(_) => { + return RErr( + "unable to serialize output_partitioning in FFI_PlanProperties" + .into(), + ) + } + }; + let output_partitioning = partitioning_data.encode_to_vec(); + + ROk(output_partitioning.into()) +} + +unsafe extern "C" fn execution_mode_fn_wrapper( + properties: &FFI_PlanProperties, +) -> FFI_ExecutionMode { + let private_data = properties.private_data as *const PlanPropertiesPrivateData; + let props = &(*private_data).props; + props.execution_mode().into() +} + +unsafe extern "C" fn output_ordering_fn_wrapper( + properties: &FFI_PlanProperties, +) -> RResult, RStr<'static>> { + let private_data = properties.private_data as *const PlanPropertiesPrivateData; + let props = &(*private_data).props; + + let codec = DefaultPhysicalExtensionCodec {}; + let output_ordering = + match props.output_ordering() { + Some(ordering) => { + let physical_sort_expr_nodes = + match serialize_physical_sort_exprs(ordering.to_owned(), &codec) { + Ok(v) => v, + Err(_) => return RErr( + "unable to serialize output_ordering in FFI_PlanProperties" + .into(), + ), + }; + + let ordering_data = PhysicalSortExprNodeCollection { + physical_sort_expr_nodes, + }; + + ordering_data.encode_to_vec() + } + None => Vec::default(), + }; + ROk(output_ordering.into()) +} + +unsafe extern "C" fn schema_fn_wrapper(properties: &FFI_PlanProperties) -> WrappedSchema { + let private_data = properties.private_data as *const PlanPropertiesPrivateData; + let props = &(*private_data).props; + + let schema: SchemaRef = Arc::clone(props.eq_properties.schema()); + schema.into() +} + +unsafe extern "C" fn release_fn_wrapper(props: &mut FFI_PlanProperties) { + let private_data = + Box::from_raw(props.private_data as *mut PlanPropertiesPrivateData); + drop(private_data); +} + +impl Drop for FFI_PlanProperties { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +impl From<&PlanProperties> for FFI_PlanProperties { + fn from(props: &PlanProperties) -> Self { + let private_data = Box::new(PlanPropertiesPrivateData { + props: props.clone(), + }); + + FFI_PlanProperties { + output_partitioning: output_partitioning_fn_wrapper, + execution_mode: execution_mode_fn_wrapper, + output_ordering: output_ordering_fn_wrapper, + schema: schema_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + } + } +} + +impl TryFrom for PlanProperties { + type Error = DataFusionError; + + fn try_from(ffi_props: FFI_PlanProperties) -> Result { + let ffi_schema = unsafe { (ffi_props.schema)(&ffi_props) }; + let schema = (&ffi_schema.0).try_into()?; + + // TODO Extend FFI to get the registry and codex + let default_ctx = SessionContext::new(); + let codex = DefaultPhysicalExtensionCodec {}; + + let ffi_orderings = unsafe { (ffi_props.output_ordering)(&ffi_props) }; + let orderings = match ffi_orderings { + ROk(ordering_vec) => { + let proto_output_ordering = + PhysicalSortExprNodeCollection::decode(ordering_vec.as_ref()) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + Some(parse_physical_sort_exprs( + &proto_output_ordering.physical_sort_expr_nodes, + &default_ctx, + &schema, + &codex, + )?) + } + RErr(e) => return Err(DataFusionError::Plan(e.to_string())), + }; + + let ffi_partitioning = unsafe { (ffi_props.output_partitioning)(&ffi_props) }; + let partitioning = match ffi_partitioning { + ROk(partitioning_vec) => { + let proto_output_partitioning = + Partitioning::decode(partitioning_vec.as_ref()) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + parse_protobuf_partitioning( + Some(&proto_output_partitioning), + &default_ctx, + &schema, + &codex, + )? + .ok_or(DataFusionError::Plan( + "Unable to deserialize partitioning protobuf in FFI_PlanProperties" + .to_string(), + )) + } + RErr(e) => Err(DataFusionError::Plan(e.to_string())), + }?; + + let execution_mode: ExecutionMode = + unsafe { (ffi_props.execution_mode)(&ffi_props).into() }; + + let eq_properties = match orderings { + Some(ordering) => { + EquivalenceProperties::new_with_orderings(Arc::new(schema), &[ordering]) + } + None => EquivalenceProperties::new(Arc::new(schema)), + }; + + Ok(PlanProperties::new( + eq_properties, + partitioning, + execution_mode, + )) + } +} + +/// FFI safe version of [`ExecutionMode`]. +#[repr(C)] +#[allow(non_camel_case_types)] +#[derive(Clone, StableAbi)] +pub enum FFI_ExecutionMode { + Bounded, + Unbounded, + PipelineBreaking, +} + +impl From for FFI_ExecutionMode { + fn from(value: ExecutionMode) -> Self { + match value { + ExecutionMode::Bounded => FFI_ExecutionMode::Bounded, + ExecutionMode::Unbounded => FFI_ExecutionMode::Unbounded, + ExecutionMode::PipelineBreaking => FFI_ExecutionMode::PipelineBreaking, + } + } +} + +impl From for ExecutionMode { + fn from(value: FFI_ExecutionMode) -> Self { + match value { + FFI_ExecutionMode::Bounded => ExecutionMode::Bounded, + FFI_ExecutionMode::Unbounded => ExecutionMode::Unbounded, + FFI_ExecutionMode::PipelineBreaking => ExecutionMode::PipelineBreaking, + } + } +} + +#[cfg(test)] +mod tests { + use datafusion::physical_plan::Partitioning; + + use super::*; + + #[test] + fn test_round_trip_ffi_plan_properties() -> Result<()> { + use arrow::datatypes::{DataType, Field, Schema}; + let schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); + + let original_props = PlanProperties::new( + EquivalenceProperties::new(schema), + Partitioning::UnknownPartitioning(3), + ExecutionMode::Unbounded, + ); + + let local_props_ptr = FFI_PlanProperties::from(&original_props); + + let foreign_props: PlanProperties = local_props_ptr.try_into()?; + + assert!(format!("{:?}", foreign_props) == format!("{:?}", original_props)); + + Ok(()) + } +} diff --git a/datafusion/ffi/src/record_batch_stream.rs b/datafusion/ffi/src/record_batch_stream.rs new file mode 100644 index 000000000000..c944e56c5cde --- /dev/null +++ b/datafusion/ffi/src/record_batch_stream.rs @@ -0,0 +1,176 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ffi::c_void, task::Poll}; + +use abi_stable::{ + std_types::{ROption, RResult, RString}, + StableAbi, +}; +use arrow::array::{Array, RecordBatch}; +use arrow::{ + array::{make_array, StructArray}, + ffi::{from_ffi, to_ffi}, +}; +use async_ffi::{ContextExt, FfiContext, FfiPoll}; +use datafusion::error::Result; +use datafusion::{ + error::DataFusionError, + execution::{RecordBatchStream, SendableRecordBatchStream}, +}; +use futures::{Stream, TryStreamExt}; + +use crate::arrow_wrappers::{WrappedArray, WrappedSchema}; + +/// A stable struct for sharing [`RecordBatchStream`] across FFI boundaries. +/// We use the async-ffi crate for handling async calls across libraries. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_RecordBatchStream { + /// This mirrors the `poll_next` of [`RecordBatchStream`] but does so + /// in a FFI safe manner. + pub poll_next: + unsafe extern "C" fn( + stream: &Self, + cx: &mut FfiContext, + ) -> FfiPoll>>, + + /// Return the schema of the record batch + pub schema: unsafe extern "C" fn(stream: &Self) -> WrappedSchema, + + /// Internal data. This is only to be accessed by the provider of the plan. + /// The foreign library should never attempt to access this data. + pub private_data: *mut c_void, +} + +impl From for FFI_RecordBatchStream { + fn from(stream: SendableRecordBatchStream) -> Self { + FFI_RecordBatchStream { + poll_next: poll_next_fn_wrapper, + schema: schema_fn_wrapper, + private_data: Box::into_raw(Box::new(stream)) as *mut c_void, + } + } +} + +unsafe impl Send for FFI_RecordBatchStream {} + +unsafe extern "C" fn schema_fn_wrapper(stream: &FFI_RecordBatchStream) -> WrappedSchema { + let stream = stream.private_data as *const SendableRecordBatchStream; + + (*stream).schema().into() +} + +fn record_batch_to_wrapped_array( + record_batch: RecordBatch, +) -> RResult { + let struct_array = StructArray::from(record_batch); + match to_ffi(&struct_array.to_data()) { + Ok((array, schema)) => RResult::ROk(WrappedArray { + array, + schema: WrappedSchema(schema), + }), + Err(e) => RResult::RErr(e.to_string().into()), + } +} + +// probably want to use pub unsafe fn from_ffi(array: FFI_ArrowArray, schema: &FFI_ArrowSchema) -> Result { +fn maybe_record_batch_to_wrapped_stream( + record_batch: Option>, +) -> ROption> { + match record_batch { + Some(Ok(record_batch)) => { + ROption::RSome(record_batch_to_wrapped_array(record_batch)) + } + Some(Err(e)) => ROption::RSome(RResult::RErr(e.to_string().into())), + None => ROption::RNone, + } +} + +unsafe extern "C" fn poll_next_fn_wrapper( + stream: &FFI_RecordBatchStream, + cx: &mut FfiContext, +) -> FfiPoll>> { + let stream = stream.private_data as *mut SendableRecordBatchStream; + + let poll_result = cx.with_context(|std_cx| { + (*stream) + .try_poll_next_unpin(std_cx) + .map(maybe_record_batch_to_wrapped_stream) + }); + + poll_result.into() +} + +impl RecordBatchStream for FFI_RecordBatchStream { + fn schema(&self) -> arrow::datatypes::SchemaRef { + let wrapped_schema = unsafe { (self.schema)(self) }; + wrapped_schema.into() + } +} + +fn wrapped_array_to_record_batch(array: WrappedArray) -> Result { + let array_data = + unsafe { from_ffi(array.array, &array.schema.0).map_err(DataFusionError::from)? }; + let array = make_array(array_data); + let struct_array = array + .as_any() + .downcast_ref::() + .ok_or(DataFusionError::Execution( + "Unexpected array type during record batch collection in FFI_RecordBatchStream" + .to_string(), + ))?; + + Ok(struct_array.into()) +} + +fn maybe_wrapped_array_to_record_batch( + array: ROption>, +) -> Option> { + match array { + ROption::RSome(RResult::ROk(wrapped_array)) => { + Some(wrapped_array_to_record_batch(wrapped_array)) + } + ROption::RSome(RResult::RErr(e)) => { + Some(Err(DataFusionError::Execution(e.to_string()))) + } + ROption::RNone => None, + } +} + +impl Stream for FFI_RecordBatchStream { + type Item = Result; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let poll_result = + unsafe { cx.with_ffi_context(|ffi_cx| (self.poll_next)(&self, ffi_cx)) }; + + match poll_result { + FfiPoll::Ready(array) => { + Poll::Ready(maybe_wrapped_array_to_record_batch(array)) + } + FfiPoll::Pending => Poll::Pending, + FfiPoll::Panicked => Poll::Ready(Some(Err(DataFusionError::Execution( + "Error occurred during poll_next on FFI_RecordBatchStream".to_string(), + )))), + } + } +} diff --git a/datafusion/ffi/src/session_config.rs b/datafusion/ffi/src/session_config.rs new file mode 100644 index 000000000000..aea03cf94e0a --- /dev/null +++ b/datafusion/ffi/src/session_config.rs @@ -0,0 +1,187 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ + collections::HashMap, + ffi::{c_char, c_void, CString}, +}; + +use abi_stable::{ + std_types::{RHashMap, RString}, + StableAbi, +}; +use datafusion::{config::ConfigOptions, error::Result}; +use datafusion::{error::DataFusionError, prelude::SessionConfig}; + +/// A stable struct for sharing [`SessionConfig`] across FFI boundaries. +/// Instead of attempting to expose the entire SessionConfig interface, we +/// convert the config options into a map from a string to string and pass +/// those values across the FFI boundary. On the receiver side, we +/// reconstruct a SessionConfig from those values. +/// +/// It is possible that using different versions of DataFusion across the +/// FFI boundary could have differing expectations of the config options. +/// This is a limitation of this approach, but exposing the entire +/// SessionConfig via a FFI interface would be extensive and provide limited +/// value over this version. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_SessionConfig { + /// Return a hash map from key to value of the config options represented + /// by string values. + pub config_options: unsafe extern "C" fn(config: &Self) -> RHashMap, + + /// Used to create a clone on the provider of the execution plan. This should + /// only need to be called by the receiver of the plan. + pub clone: unsafe extern "C" fn(plan: &Self) -> Self, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(arg: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the plan. + /// A [`ForeignSessionConfig`] should never attempt to access this data. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_SessionConfig {} +unsafe impl Sync for FFI_SessionConfig {} + +unsafe extern "C" fn config_options_fn_wrapper( + config: &FFI_SessionConfig, +) -> RHashMap { + let private_data = config.private_data as *mut SessionConfigPrivateData; + let config_options = &(*private_data).config; + + let mut options = RHashMap::default(); + for config_entry in config_options.entries() { + if let Some(value) = config_entry.value { + options.insert(config_entry.key.into(), value.into()); + } + } + + options +} + +unsafe extern "C" fn release_fn_wrapper(config: &mut FFI_SessionConfig) { + let private_data = + Box::from_raw(config.private_data as *mut SessionConfigPrivateData); + drop(private_data); +} + +unsafe extern "C" fn clone_fn_wrapper(config: &FFI_SessionConfig) -> FFI_SessionConfig { + let old_private_data = config.private_data as *mut SessionConfigPrivateData; + let old_config = &(*old_private_data).config; + + let private_data = Box::new(SessionConfigPrivateData { + config: old_config.clone(), + }); + + FFI_SessionConfig { + config_options: config_options_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + } +} + +struct SessionConfigPrivateData { + pub config: ConfigOptions, +} + +impl From<&SessionConfig> for FFI_SessionConfig { + fn from(session: &SessionConfig) -> Self { + let mut config_keys = Vec::new(); + let mut config_values = Vec::new(); + for config_entry in session.options().entries() { + if let Some(value) = config_entry.value { + let key_cstr = CString::new(config_entry.key).unwrap_or_default(); + let key_ptr = key_cstr.into_raw() as *const c_char; + config_keys.push(key_ptr); + + config_values + .push(CString::new(value).unwrap_or_default().into_raw() + as *const c_char); + } + } + + let private_data = Box::new(SessionConfigPrivateData { + config: session.options().clone(), + }); + + Self { + config_options: config_options_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + } + } +} + +impl Clone for FFI_SessionConfig { + fn clone(&self) -> Self { + unsafe { (self.clone)(self) } + } +} + +impl Drop for FFI_SessionConfig { + fn drop(&mut self) { + unsafe { (self.release)(self) }; + } +} + +/// A wrapper struct for accessing [`SessionConfig`] across a FFI boundary. +/// The [`SessionConfig`] will be generated from a hash map of the config +/// options in the provider and will be reconstructed on this side of the +/// interface.s +pub struct ForeignSessionConfig(pub SessionConfig); + +impl TryFrom<&FFI_SessionConfig> for ForeignSessionConfig { + type Error = DataFusionError; + + fn try_from(config: &FFI_SessionConfig) -> Result { + let config_options = unsafe { (config.config_options)(config) }; + + let mut options_map = HashMap::new(); + config_options.iter().for_each(|kv_pair| { + options_map.insert(kv_pair.0.to_string(), kv_pair.1.to_string()); + }); + + Ok(Self(SessionConfig::from_string_hash_map(&options_map)?)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_round_trip_ffi_session_config() -> Result<()> { + let session_config = SessionConfig::new(); + let original_options = session_config.options().entries(); + + let ffi_config: FFI_SessionConfig = (&session_config).into(); + + let foreign_config: ForeignSessionConfig = (&ffi_config).try_into()?; + + let returned_options = foreign_config.0.options().entries(); + + assert!(original_options.len() == returned_options.len()); + + Ok(()) + } +} diff --git a/datafusion/ffi/src/table_provider.rs b/datafusion/ffi/src/table_provider.rs new file mode 100644 index 000000000000..011ad96e423d --- /dev/null +++ b/datafusion/ffi/src/table_provider.rs @@ -0,0 +1,443 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{any::Any, ffi::c_void, sync::Arc}; + +use abi_stable::{ + std_types::{ROption, RResult, RString, RVec}, + StableAbi, +}; +use arrow::datatypes::SchemaRef; +use async_ffi::{FfiFuture, FutureExt}; +use async_trait::async_trait; +use datafusion::{ + catalog::{Session, TableProvider}, + datasource::TableType, + error::DataFusionError, + execution::session_state::SessionStateBuilder, + logical_expr::TableProviderFilterPushDown, + physical_plan::ExecutionPlan, + prelude::{Expr, SessionContext}, +}; +use datafusion_proto::{ + logical_plan::{ + from_proto::parse_exprs, to_proto::serialize_exprs, DefaultLogicalExtensionCodec, + }, + protobuf::LogicalExprList, +}; +use prost::Message; + +use crate::{ + arrow_wrappers::WrappedSchema, + session_config::ForeignSessionConfig, + table_source::{FFI_TableProviderFilterPushDown, FFI_TableType}, +}; + +use super::{ + execution_plan::{FFI_ExecutionPlan, ForeignExecutionPlan}, + session_config::FFI_SessionConfig, +}; +use datafusion::error::Result; + +/// A stable struct for sharing [`TableProvider`] across FFI boundaries. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_TableProvider { + /// Return the table schema + pub schema: unsafe extern "C" fn(provider: &Self) -> WrappedSchema, + + /// Perform a scan on the table. See [`TableProvider`] for detailed usage information. + /// + /// # Arguments + /// + /// * `provider` - the table provider + /// * `session_config` - session configuration + /// * `projections` - if specified, only a subset of the columns are returned + /// * `filters_serialized` - filters to apply to the scan, which are a + /// [`LogicalExprList`] protobuf message serialized into bytes to pass + /// across the FFI boundary. + /// * `limit` - if specified, limit the number of rows returned + pub scan: unsafe extern "C" fn( + provider: &Self, + session_config: &FFI_SessionConfig, + projections: RVec, + filters_serialized: RVec, + limit: ROption, + ) -> FfiFuture>, + + /// Return the type of table. See [`TableType`] for options. + pub table_type: unsafe extern "C" fn(provider: &Self) -> FFI_TableType, + + /// Based upon the input filters, identify which are supported. The filters + /// are a [`LogicalExprList`] protobuf message serialized into bytes to pass + /// across the FFI boundary. + pub supports_filters_pushdown: Option< + unsafe extern "C" fn( + provider: &FFI_TableProvider, + filters_serialized: RVec, + ) + -> RResult, RString>, + >, + + /// Used to create a clone on the provider of the execution plan. This should + /// only need to be called by the receiver of the plan. + pub clone: unsafe extern "C" fn(plan: &Self) -> Self, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(arg: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the plan. + /// A [`ForeignExecutionPlan`] should never attempt to access this data. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_TableProvider {} +unsafe impl Sync for FFI_TableProvider {} + +struct ProviderPrivateData { + provider: Arc, +} + +unsafe extern "C" fn schema_fn_wrapper(provider: &FFI_TableProvider) -> WrappedSchema { + let private_data = provider.private_data as *const ProviderPrivateData; + let provider = &(*private_data).provider; + + provider.schema().into() +} + +unsafe extern "C" fn table_type_fn_wrapper( + provider: &FFI_TableProvider, +) -> FFI_TableType { + let private_data = provider.private_data as *const ProviderPrivateData; + let provider = &(*private_data).provider; + + provider.table_type().into() +} + +fn supports_filters_pushdown_internal( + provider: &Arc, + filters_serialized: &[u8], +) -> Result> { + let default_ctx = SessionContext::new(); + let codec = DefaultLogicalExtensionCodec {}; + + let filters = match filters_serialized.is_empty() { + true => vec![], + false => { + let proto_filters = LogicalExprList::decode(filters_serialized) + .map_err(|e| DataFusionError::Plan(e.to_string()))?; + + parse_exprs(proto_filters.expr.iter(), &default_ctx, &codec)? + } + }; + let filters_borrowed: Vec<&Expr> = filters.iter().collect(); + + let results: RVec<_> = provider + .supports_filters_pushdown(&filters_borrowed)? + .iter() + .map(|v| v.into()) + .collect(); + + Ok(results) +} + +unsafe extern "C" fn supports_filters_pushdown_fn_wrapper( + provider: &FFI_TableProvider, + filters_serialized: RVec, +) -> RResult, RString> { + let private_data = provider.private_data as *const ProviderPrivateData; + let provider = &(*private_data).provider; + + supports_filters_pushdown_internal(provider, &filters_serialized) + .map_err(|e| e.to_string().into()) + .into() +} + +unsafe extern "C" fn scan_fn_wrapper( + provider: &FFI_TableProvider, + session_config: &FFI_SessionConfig, + projections: RVec, + filters_serialized: RVec, + limit: ROption, +) -> FfiFuture> { + let private_data = provider.private_data as *mut ProviderPrivateData; + let internal_provider = &(*private_data).provider; + let session_config = session_config.clone(); + + async move { + let config = match ForeignSessionConfig::try_from(&session_config) { + Ok(c) => c, + Err(e) => return RResult::RErr(e.to_string().into()), + }; + let session = SessionStateBuilder::new() + .with_default_features() + .with_config(config.0) + .build(); + let ctx = SessionContext::new_with_state(session); + + let filters = match filters_serialized.is_empty() { + true => vec![], + false => { + let default_ctx = SessionContext::new(); + let codec = DefaultLogicalExtensionCodec {}; + + let proto_filters = + match LogicalExprList::decode(filters_serialized.as_ref()) { + Ok(f) => f, + Err(e) => return RResult::RErr(e.to_string().into()), + }; + + match parse_exprs(proto_filters.expr.iter(), &default_ctx, &codec) { + Ok(f) => f, + Err(e) => return RResult::RErr(e.to_string().into()), + } + } + }; + + let projections: Vec<_> = projections.into_iter().collect(); + let maybe_projections = match projections.is_empty() { + true => None, + false => Some(&projections), + }; + + let plan = match internal_provider + .scan(&ctx.state(), maybe_projections, &filters, limit.into()) + .await + { + Ok(p) => p, + Err(e) => return RResult::RErr(e.to_string().into()), + }; + + RResult::ROk(FFI_ExecutionPlan::new(plan, ctx.task_ctx())) + } + .into_ffi() +} + +unsafe extern "C" fn release_fn_wrapper(provider: &mut FFI_TableProvider) { + let private_data = Box::from_raw(provider.private_data as *mut ProviderPrivateData); + drop(private_data); +} + +unsafe extern "C" fn clone_fn_wrapper(provider: &FFI_TableProvider) -> FFI_TableProvider { + let old_private_data = provider.private_data as *const ProviderPrivateData; + + let private_data = Box::into_raw(Box::new(ProviderPrivateData { + provider: Arc::clone(&(*old_private_data).provider), + })) as *mut c_void; + + FFI_TableProvider { + schema: schema_fn_wrapper, + scan: scan_fn_wrapper, + table_type: table_type_fn_wrapper, + supports_filters_pushdown: provider.supports_filters_pushdown, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data, + } +} + +impl Drop for FFI_TableProvider { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +impl FFI_TableProvider { + /// Creates a new [`FFI_TableProvider`]. + pub fn new( + provider: Arc, + can_support_pushdown_filters: bool, + ) -> Self { + let private_data = Box::new(ProviderPrivateData { provider }); + + Self { + schema: schema_fn_wrapper, + scan: scan_fn_wrapper, + table_type: table_type_fn_wrapper, + supports_filters_pushdown: match can_support_pushdown_filters { + true => Some(supports_filters_pushdown_fn_wrapper), + false => None, + }, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + } + } +} + +/// This wrapper struct exists on the reciever side of the FFI interface, so it has +/// no guarantees about being able to access the data in `private_data`. Any functions +/// defined on this struct must only use the stable functions provided in +/// FFI_TableProvider to interact with the foreign table provider. +#[derive(Debug)] +pub struct ForeignTableProvider(FFI_TableProvider); + +unsafe impl Send for ForeignTableProvider {} +unsafe impl Sync for ForeignTableProvider {} + +impl From<&FFI_TableProvider> for ForeignTableProvider { + fn from(provider: &FFI_TableProvider) -> Self { + Self(provider.clone()) + } +} + +impl Clone for FFI_TableProvider { + fn clone(&self) -> Self { + unsafe { (self.clone)(self) } + } +} + +#[async_trait] +impl TableProvider for ForeignTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + let wrapped_schema = unsafe { (self.0.schema)(&self.0) }; + wrapped_schema.into() + } + + fn table_type(&self) -> TableType { + unsafe { (self.0.table_type)(&self.0).into() } + } + + async fn scan( + &self, + session: &dyn Session, + projection: Option<&Vec>, + filters: &[Expr], + limit: Option, + ) -> Result> { + let session_config: FFI_SessionConfig = session.config().into(); + + let projections: Option> = + projection.map(|p| p.iter().map(|v| v.to_owned()).collect()); + + let codec = DefaultLogicalExtensionCodec {}; + let filter_list = LogicalExprList { + expr: serialize_exprs(filters, &codec)?, + }; + let filters_serialized = filter_list.encode_to_vec().into(); + + let plan = unsafe { + let maybe_plan = (self.0.scan)( + &self.0, + &session_config, + projections.unwrap_or_default(), + filters_serialized, + limit.into(), + ) + .await; + + match maybe_plan { + RResult::ROk(p) => ForeignExecutionPlan::try_from(&p)?, + RResult::RErr(_) => { + return Err(DataFusionError::Internal( + "Unable to perform scan via FFI".to_string(), + )) + } + } + }; + + Ok(Arc::new(plan)) + } + + /// Tests whether the table provider can make use of a filter expression + /// to optimise data retrieval. + fn supports_filters_pushdown( + &self, + filters: &[&Expr], + ) -> Result> { + unsafe { + let pushdown_fn = match self.0.supports_filters_pushdown { + Some(func) => func, + None => { + return Ok(vec![ + TableProviderFilterPushDown::Unsupported; + filters.len() + ]) + } + }; + + let codec = DefaultLogicalExtensionCodec {}; + + let expr_list = LogicalExprList { + expr: serialize_exprs(filters.iter().map(|f| f.to_owned()), &codec)?, + }; + let serialized_filters = expr_list.encode_to_vec(); + + let pushdowns = pushdown_fn(&self.0, serialized_filters.into()); + + match pushdowns { + RResult::ROk(p) => Ok(p.iter().map(|v| v.into()).collect()), + RResult::RErr(e) => Err(DataFusionError::Plan(e.to_string())), + } + } + } +} + +#[cfg(test)] +mod tests { + use arrow::datatypes::Schema; + use datafusion::prelude::{col, lit}; + + use super::*; + + #[tokio::test] + async fn test_round_trip_ffi_table_provider() -> Result<()> { + use arrow::datatypes::Field; + use datafusion::arrow::{ + array::Float32Array, datatypes::DataType, record_batch::RecordBatch, + }; + use datafusion::datasource::MemTable; + + let schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); + + // define data in two partitions + let batch1 = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0]))], + )?; + let batch2 = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Float32Array::from(vec![64.0]))], + )?; + + let ctx = SessionContext::new(); + + let provider = + Arc::new(MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?); + + let ffi_provider = FFI_TableProvider::new(provider, true); + + let foreign_table_provider: ForeignTableProvider = (&ffi_provider).into(); + + ctx.register_table("t", Arc::new(foreign_table_provider))?; + + let df = ctx.table("t").await?; + + df.select(vec![col("a")])? + .filter(col("a").gt(lit(3.0)))? + .show() + .await?; + + Ok(()) + } +} diff --git a/datafusion/ffi/src/table_source.rs b/datafusion/ffi/src/table_source.rs new file mode 100644 index 000000000000..a59836622ee6 --- /dev/null +++ b/datafusion/ffi/src/table_source.rs @@ -0,0 +1,87 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use abi_stable::StableAbi; +use datafusion::{datasource::TableType, logical_expr::TableProviderFilterPushDown}; + +/// FFI safe version of [`TableProviderFilterPushDown`]. +#[repr(C)] +#[derive(StableAbi)] +#[allow(non_camel_case_types)] +pub enum FFI_TableProviderFilterPushDown { + Unsupported, + Inexact, + Exact, +} + +impl From<&FFI_TableProviderFilterPushDown> for TableProviderFilterPushDown { + fn from(value: &FFI_TableProviderFilterPushDown) -> Self { + match value { + FFI_TableProviderFilterPushDown::Unsupported => { + TableProviderFilterPushDown::Unsupported + } + FFI_TableProviderFilterPushDown::Inexact => { + TableProviderFilterPushDown::Inexact + } + FFI_TableProviderFilterPushDown::Exact => TableProviderFilterPushDown::Exact, + } + } +} + +impl From<&TableProviderFilterPushDown> for FFI_TableProviderFilterPushDown { + fn from(value: &TableProviderFilterPushDown) -> Self { + match value { + TableProviderFilterPushDown::Unsupported => { + FFI_TableProviderFilterPushDown::Unsupported + } + TableProviderFilterPushDown::Inexact => { + FFI_TableProviderFilterPushDown::Inexact + } + TableProviderFilterPushDown::Exact => FFI_TableProviderFilterPushDown::Exact, + } + } +} + +/// FFI safe version of [`TableType`]. +#[repr(C)] +#[allow(non_camel_case_types)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, StableAbi)] +pub enum FFI_TableType { + Base, + View, + Temporary, +} + +impl From for TableType { + fn from(value: FFI_TableType) -> Self { + match value { + FFI_TableType::Base => TableType::Base, + FFI_TableType::View => TableType::View, + FFI_TableType::Temporary => TableType::Temporary, + } + } +} + +impl From for FFI_TableType { + fn from(value: TableType) -> Self { + match value { + TableType::Base => FFI_TableType::Base, + TableType::View => FFI_TableType::View, + TableType::Temporary => FFI_TableType::Temporary, + } + } +} diff --git a/datafusion/functions-aggregate-common/Cargo.toml b/datafusion/functions-aggregate-common/Cargo.toml new file mode 100644 index 000000000000..a8296ce11f30 --- /dev/null +++ b/datafusion/functions-aggregate-common/Cargo.toml @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-functions-aggregate-common" +description = "Utility functions for implementing aggregate functions for the DataFusion query engine" +keywords = ["datafusion", "logical", "plan", "expressions"] +readme = "README.md" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +authors = { workspace = true } +rust-version = { workspace = true } + +[lints] +workspace = true + +[lib] +name = "datafusion_functions_aggregate_common" +path = "src/lib.rs" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +ahash = { workspace = true } +arrow = { workspace = true } +datafusion-common = { workspace = true } +datafusion-expr-common = { workspace = true } +datafusion-physical-expr-common = { workspace = true } +rand = { workspace = true } diff --git a/datafusion/functions-aggregate-common/src/accumulator.rs b/datafusion/functions-aggregate-common/src/accumulator.rs new file mode 100644 index 000000000000..67ada562800b --- /dev/null +++ b/datafusion/functions-aggregate-common/src/accumulator.rs @@ -0,0 +1,95 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion_common::Result; +use datafusion_expr_common::accumulator::Accumulator; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::LexOrderingRef; +use std::sync::Arc; + +/// [`AccumulatorArgs`] contains information about how an aggregate +/// function was called, including the types of its arguments and any optional +/// ordering expressions. +#[derive(Debug)] +pub struct AccumulatorArgs<'a> { + /// The return type of the aggregate function. + pub return_type: &'a DataType, + + /// The schema of the input arguments + pub schema: &'a Schema, + + /// Whether to ignore nulls. + /// + /// SQL allows the user to specify `IGNORE NULLS`, for example: + /// + /// ```sql + /// SELECT FIRST_VALUE(column1) IGNORE NULLS FROM t; + /// ``` + pub ignore_nulls: bool, + + /// The expressions in the `ORDER BY` clause passed to this aggregator. + /// + /// SQL allows the user to specify the ordering of arguments to the + /// aggregate using an `ORDER BY`. For example: + /// + /// ```sql + /// SELECT FIRST_VALUE(column1 ORDER BY column2) FROM t; + /// ``` + /// + /// If no `ORDER BY` is specified, `ordering_req` will be empty. + pub ordering_req: LexOrderingRef<'a>, + + /// Whether the aggregation is running in reverse order + pub is_reversed: bool, + + /// The name of the aggregate expression + pub name: &'a str, + + /// Whether the aggregate function is distinct. + /// + /// ```sql + /// SELECT COUNT(DISTINCT column1) FROM t; + /// ``` + pub is_distinct: bool, + + /// The physical expression of arguments the aggregate function takes. + pub exprs: &'a [Arc], +} + +/// Factory that returns an accumulator for the given aggregate function. +pub type AccumulatorFactoryFunction = + Arc Result> + Send + Sync>; + +/// [`StateFieldsArgs`] contains information about the fields that an +/// aggregate function's accumulator should have. Used for `AggregateUDFImpl::state_fields`. +pub struct StateFieldsArgs<'a> { + /// The name of the aggregate function. + pub name: &'a str, + + /// The input types of the aggregate function. + pub input_types: &'a [DataType], + + /// The return type of the aggregate function. + pub return_type: &'a DataType, + + /// The ordering fields of the aggregate function. + pub ordering_fields: &'a [Field], + + /// Whether the aggregate function is distinct. + pub is_distinct: bool, +} diff --git a/datafusion/functions-aggregate-common/src/aggregate.rs b/datafusion/functions-aggregate-common/src/aggregate.rs new file mode 100644 index 000000000000..c9cbaa8396fc --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate.rs @@ -0,0 +1,19 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod count_distinct; +pub mod groups_accumulator; diff --git a/datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs b/datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs new file mode 100644 index 000000000000..7d772f7c649d --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/count_distinct.rs @@ -0,0 +1,24 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod bytes; +mod native; + +pub use bytes::BytesDistinctCountAccumulator; +pub use bytes::BytesViewDistinctCountAccumulator; +pub use native::FloatDistinctCountAccumulator; +pub use native::PrimitiveDistinctCountAccumulator; diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/bytes.rs b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/bytes.rs similarity index 53% rename from datafusion/physical-expr/src/aggregate/count_distinct/bytes.rs rename to datafusion/functions-aggregate-common/src/aggregate/count_distinct/bytes.rs index 2ed9b002c841..07fa4efc990e 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct/bytes.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/bytes.rs @@ -17,13 +17,15 @@ //! [`BytesDistinctCountAccumulator`] for Utf8/LargeUtf8/Binary/LargeBinary values -use crate::binary_map::{ArrowBytesSet, OutputType}; -use arrow_array::{ArrayRef, OffsetSizeTrait}; +use arrow::array::{ArrayRef, OffsetSizeTrait}; use datafusion_common::cast::as_list_array; -use datafusion_common::utils::array_into_list_array; +use datafusion_common::utils::array_into_list_array_nullable; use datafusion_common::ScalarValue; -use datafusion_expr::Accumulator; +use datafusion_expr_common::accumulator::Accumulator; +use datafusion_physical_expr_common::binary_map::{ArrowBytesSet, OutputType}; +use datafusion_physical_expr_common::binary_view_map::ArrowBytesViewSet; use std::fmt::Debug; +use std::mem::size_of_val; use std::sync::Arc; /// Specialized implementation of @@ -35,10 +37,10 @@ use std::sync::Arc; /// [`BinaryArray`]: arrow::array::BinaryArray /// [`LargeBinaryArray`]: arrow::array::LargeBinaryArray #[derive(Debug)] -pub(super) struct BytesDistinctCountAccumulator(ArrowBytesSet); +pub struct BytesDistinctCountAccumulator(ArrowBytesSet); impl BytesDistinctCountAccumulator { - pub(super) fn new(output_type: OutputType) -> Self { + pub fn new(output_type: OutputType) -> Self { Self(ArrowBytesSet::new(output_type)) } } @@ -47,7 +49,7 @@ impl Accumulator for BytesDistinctCountAccumulator { fn state(&mut self) -> datafusion_common::Result> { let set = self.0.take(); let arr = set.into_state(); - let list = Arc::new(array_into_list_array(arr)); + let list = Arc::new(array_into_list_array_nullable(arr)); Ok(vec![ScalarValue::List(list)]) } @@ -85,6 +87,66 @@ impl Accumulator for BytesDistinctCountAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + self.0.size() + size_of_val(self) + self.0.size() + } +} + +/// Specialized implementation of +/// `COUNT DISTINCT` for [`StringViewArray`] and [`BinaryViewArray`]. +/// +/// [`StringViewArray`]: arrow::array::StringViewArray +/// [`BinaryViewArray`]: arrow::array::BinaryViewArray +#[derive(Debug)] +pub struct BytesViewDistinctCountAccumulator(ArrowBytesViewSet); + +impl BytesViewDistinctCountAccumulator { + pub fn new(output_type: OutputType) -> Self { + Self(ArrowBytesViewSet::new(output_type)) + } +} + +impl Accumulator for BytesViewDistinctCountAccumulator { + fn state(&mut self) -> datafusion_common::Result> { + let set = self.0.take(); + let arr = set.into_state(); + let list = Arc::new(array_into_list_array_nullable(arr)); + Ok(vec![ScalarValue::List(list)]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> { + if values.is_empty() { + return Ok(()); + } + + self.0.insert(&values[0]); + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> { + if states.is_empty() { + return Ok(()); + } + assert_eq!( + states.len(), + 1, + "count_distinct states must be single array" + ); + + let arr = as_list_array(&states[0])?; + arr.iter().try_for_each(|maybe_list| { + if let Some(list) = maybe_list { + self.0.insert(&list); + }; + Ok(()) + }) + } + + fn evaluate(&mut self) -> datafusion_common::Result { + Ok(ScalarValue::Int64(Some(self.0.non_null_len() as i64))) + } + + fn size(&self) -> usize { + size_of_val(self) + self.0.size() } } diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/native.rs b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/native.rs similarity index 77% rename from datafusion/physical-expr/src/aggregate/count_distinct/native.rs rename to datafusion/functions-aggregate-common/src/aggregate/count_distinct/native.rs index 95d8662e0f6e..405b2c2db7bd 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct/native.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/count_distinct/native.rs @@ -23,23 +23,25 @@ use std::collections::HashSet; use std::fmt::Debug; use std::hash::Hash; +use std::mem::size_of_val; use std::sync::Arc; use ahash::RandomState; +use arrow::array::types::ArrowPrimitiveType; use arrow::array::ArrayRef; -use arrow_array::types::ArrowPrimitiveType; -use arrow_array::PrimitiveArray; -use arrow_schema::DataType; +use arrow::array::PrimitiveArray; +use arrow::datatypes::DataType; use datafusion_common::cast::{as_list_array, as_primitive_array}; -use datafusion_common::utils::array_into_list_array; +use datafusion_common::utils::array_into_list_array_nullable; +use datafusion_common::utils::memory::estimate_memory_size; use datafusion_common::ScalarValue; -use datafusion_expr::Accumulator; +use datafusion_expr_common::accumulator::Accumulator; -use crate::aggregate::utils::Hashable; +use crate::utils::Hashable; #[derive(Debug)] -pub(super) struct PrimitiveDistinctCountAccumulator +pub struct PrimitiveDistinctCountAccumulator where T: ArrowPrimitiveType + Send, T::Native: Eq + Hash, @@ -53,7 +55,7 @@ where T: ArrowPrimitiveType + Send, T::Native: Eq + Hash, { - pub(super) fn new(data_type: &DataType) -> Self { + pub fn new(data_type: &DataType) -> Self { Self { values: HashSet::default(), data_type: data_type.clone(), @@ -71,7 +73,7 @@ where PrimitiveArray::::from_iter_values(self.values.iter().cloned()) .with_data_type(self.data_type.clone()), ); - let list = Arc::new(array_into_list_array(arr)); + let list = Arc::new(array_into_list_array_nullable(arr)); Ok(vec![ScalarValue::List(list)]) } @@ -115,23 +117,15 @@ where } fn size(&self) -> usize { - let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX) - / 7) - .next_power_of_two(); - - // Size of accumulator - // + size of entry * number of buckets - // + 1 byte for each bucket - // + fixed size of HashSet - std::mem::size_of_val(self) - + std::mem::size_of::() * estimated_buckets - + estimated_buckets - + std::mem::size_of_val(&self.values) + let num_elements = self.values.len(); + let fixed_size = size_of_val(self) + size_of_val(&self.values); + + estimate_memory_size::(num_elements, fixed_size).unwrap() } } #[derive(Debug)] -pub(super) struct FloatDistinctCountAccumulator +pub struct FloatDistinctCountAccumulator where T: ArrowPrimitiveType + Send, { @@ -142,13 +136,22 @@ impl FloatDistinctCountAccumulator where T: ArrowPrimitiveType + Send, { - pub(super) fn new() -> Self { + pub fn new() -> Self { Self { values: HashSet::default(), } } } +impl Default for FloatDistinctCountAccumulator +where + T: ArrowPrimitiveType + Send, +{ + fn default() -> Self { + Self::new() + } +} + impl Accumulator for FloatDistinctCountAccumulator where T: ArrowPrimitiveType + Send + Debug, @@ -157,7 +160,7 @@ where let arr = Arc::new(PrimitiveArray::::from_iter_values( self.values.iter().map(|v| v.0), )) as ArrayRef; - let list = Arc::new(array_into_list_array(arr)); + let list = Arc::new(array_into_list_array_nullable(arr)); Ok(vec![ScalarValue::List(list)]) } @@ -202,17 +205,9 @@ where } fn size(&self) -> usize { - let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX) - / 7) - .next_power_of_two(); - - // Size of accumulator - // + size of entry * number of buckets - // + 1 byte for each bucket - // + fixed size of HashSet - std::mem::size_of_val(self) - + std::mem::size_of::() * estimated_buckets - + estimated_buckets - + std::mem::size_of_val(&self.values) + let num_elements = self.values.len(); + let fixed_size = size_of_val(self) + size_of_val(&self.values); + + estimate_memory_size::(num_elements, fixed_size).unwrap() } } diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs similarity index 67% rename from datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs rename to datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs index 9856e1c989b3..03e4ef557269 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/adapter.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs @@ -15,19 +15,26 @@ // specific language governing permissions and limitations // under the License. +//! Utilities for implementing GroupsAccumulator //! Adapter that makes [`GroupsAccumulator`] out of [`Accumulator`] +pub mod accumulate; +pub mod bool_op; +pub mod nulls; +pub mod prim_op; + +use std::mem::{size_of, size_of_val}; + +use arrow::array::new_empty_array; use arrow::{ - array::{AsArray, UInt32Builder}, + array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray}, compute, + compute::take_arrays, datatypes::UInt32Type, }; -use arrow_array::{ArrayRef, BooleanArray, PrimitiveArray}; -use datafusion_common::{ - arrow_datafusion_err, utils::get_arrayref_at_indices, DataFusionError, Result, - ScalarValue, -}; -use datafusion_expr::{Accumulator, EmitTo, GroupsAccumulator}; +use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr_common::accumulator::Accumulator; +use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; /// An adapter that implements [`GroupsAccumulator`] for any [`Accumulator`] /// @@ -36,6 +43,52 @@ use datafusion_expr::{Accumulator, EmitTo, GroupsAccumulator}; /// they are not as fast as a specialized `GroupsAccumulator`. This /// interface bridges the gap so the group by operator only operates /// in terms of [`Accumulator`]. +/// +/// Internally, this adapter creates a new [`Accumulator`] for each group which +/// stores the state for that group. This both requires an allocation for each +/// Accumulator, internal indices, as well as whatever internal allocations the +/// Accumulator itself requires. +/// +/// For example, a `MinAccumulator` that computes the minimum string value with +/// a [`ScalarValue::Utf8`]. That will require at least two allocations per group +/// (one for the `MinAccumulator` and one for the `ScalarValue::Utf8`). +/// +/// ```text +/// ┌─────────────────────────────────┐ +/// │MinAccumulator { │ +/// ┌─────▶│ min: ScalarValue::Utf8("A") │───────┐ +/// │ │} │ │ +/// │ └─────────────────────────────────┘ └───────▶ "A" +/// ┌─────┐ │ ┌─────────────────────────────────┐ +/// │ 0 │─────┘ │MinAccumulator { │ +/// ├─────┤ ┌─────▶│ min: ScalarValue::Utf8("Z") │───────────────▶ "Z" +/// │ 1 │─────┘ │} │ +/// └─────┘ └─────────────────────────────────┘ ... +/// ... ... +/// ┌─────┐ ┌────────────────────────────────┐ +/// │ N-2 │ │MinAccumulator { │ +/// ├─────┤ │ min: ScalarValue::Utf8("A") │────────────────▶ "A" +/// │ N-1 │─────┐ │} │ +/// └─────┘ │ └────────────────────────────────┘ +/// │ ┌────────────────────────────────┐ ┌───────▶ "Q" +/// │ │MinAccumulator { │ │ +/// └─────▶│ min: ScalarValue::Utf8("Q") │────────┘ +/// │} │ +/// └────────────────────────────────┘ +/// +/// +/// Logical group Current Min/Max value for that group stored +/// number as a ScalarValue which points to an +/// indivdually allocated String +/// +///``` +/// +/// # Optimizations +/// +/// The adapter minimizes the number of calls to [`Accumulator::update_batch`] +/// by first collecting the input rows for each group into a contiguous array +/// using [`compute::take`] +/// pub struct GroupsAccumulatorAdapter { factory: Box Result> + Send>, @@ -55,9 +108,9 @@ struct AccumulatorState { /// [`Accumulator`] that stores the per-group state accumulator: Box, - // scratch space: indexes in the input array that will be fed to - // this accumulator. Stores indexes as `u32` to match the arrow - // `take` kernel input. + /// scratch space: indexes in the input array that will be fed to + /// this accumulator. Stores indexes as `u32` to match the arrow + /// `take` kernel input. indices: Vec, } @@ -69,11 +122,9 @@ impl AccumulatorState { } } - /// Returns the amount of memory taken by this structre and its accumulator + /// Returns the amount of memory taken by this structure and its accumulator fn size(&self) -> usize { - self.accumulator.size() - + std::mem::size_of_val(self) - + self.indices.allocated_size() + self.accumulator.size() + size_of_val(self) + self.indices.allocated_size() } } @@ -164,7 +215,7 @@ impl GroupsAccumulatorAdapter { let mut groups_with_rows = vec![]; // batch_indices holds indices into values, each group is contiguous - let mut batch_indices = UInt32Builder::with_capacity(0); + let mut batch_indices = vec![]; // offsets[i] is index into batch_indices where the rows for // group_index i starts @@ -178,16 +229,16 @@ impl GroupsAccumulatorAdapter { } groups_with_rows.push(group_index); - batch_indices.append_slice(indices); + batch_indices.extend_from_slice(indices); offset_so_far += indices.len(); offsets.push(offset_so_far); } - let batch_indices = batch_indices.finish(); + let batch_indices = batch_indices.into(); // reorder the values and opt_filter by batch_indices so that // all values for each group are contiguous, then invoke the // accumulator once per group with values - let values = get_arrayref_at_indices(values, &batch_indices)?; + let values = take_arrays(values, &batch_indices, None)?; let opt_filter = get_filter_at_indices(opt_filter, &batch_indices)?; // invoke each accumulator with the appropriate rows, first @@ -201,9 +252,12 @@ impl GroupsAccumulatorAdapter { let state = &mut self.states[group_idx]; sizes_pre += state.size(); - let values_to_accumulate = - slice_and_maybe_filter(&values, opt_filter.as_ref(), offsets)?; - (f)(state.accumulator.as_mut(), &values_to_accumulate)?; + let values_to_accumulate = slice_and_maybe_filter( + &values, + opt_filter.as_ref().map(|f| f.as_boolean()), + offsets, + )?; + f(state.accumulator.as_mut(), &values_to_accumulate)?; // clear out the state so they are empty for next // iteration @@ -284,6 +338,7 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter { result } + // filtered_null_mask(opt_filter, &values); fn state(&mut self, emit_to: EmitTo) -> Result> { let vec_size_pre = self.states.allocated_size(); let states = emit_to.take_needed(&mut self.states); @@ -342,6 +397,58 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter { fn size(&self) -> usize { self.allocation_bytes } + + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + let num_rows = values[0].len(); + + // If there are no rows, return empty arrays + if num_rows == 0 { + // create empty accumulator to get the state types + let empty_state = (self.factory)()?.state()?; + let empty_arrays = empty_state + .into_iter() + .map(|state_val| new_empty_array(&state_val.data_type())) + .collect::>(); + + return Ok(empty_arrays); + } + + // Each row has its respective group + let mut results = vec![]; + for row_idx in 0..num_rows { + // Create the empty accumulator for converting + let mut converted_accumulator = (self.factory)()?; + + // Convert row to states + let values_to_accumulate = + slice_and_maybe_filter(values, opt_filter, &[row_idx, row_idx + 1])?; + converted_accumulator.update_batch(&values_to_accumulate)?; + let states = converted_accumulator.state()?; + + // Resize results to have enough columns according to the converted states + results.resize_with(states.len(), || Vec::with_capacity(num_rows)); + + // Add the states to results + for (idx, state_val) in states.into_iter().enumerate() { + results[idx].push(state_val); + } + } + + let arrays = results + .into_iter() + .map(ScalarValue::iter_to_array) + .collect::>>()?; + + Ok(arrays) + } + + fn supports_convert_to_state(&self) -> bool { + true + } } /// Extension trait for [`Vec`] to account for allocations. @@ -357,7 +464,7 @@ pub trait VecAllocExt { impl VecAllocExt for Vec { type T = T; fn allocated_size(&self) -> usize { - std::mem::size_of::() * self.capacity() + size_of::() * self.capacity() } } @@ -378,7 +485,7 @@ fn get_filter_at_indices( // Copied from physical-plan pub(crate) fn slice_and_maybe_filter( aggr_array: &[ArrayRef], - filter_opt: Option<&ArrayRef>, + filter_opt: Option<&BooleanArray>, offsets: &[usize], ) -> Result> { let (offset, length) = (offsets[0], offsets[1] - offsets[0]); @@ -388,13 +495,12 @@ pub(crate) fn slice_and_maybe_filter( .collect(); if let Some(f) = filter_opt { - let filter_array = f.slice(offset, length); - let filter_array = filter_array.as_boolean(); + let filter = f.slice(offset, length); sliced_arrays .iter() .map(|array| { - compute::filter(array, filter_array).map_err(|e| arrow_datafusion_err!(e)) + compute::filter(&array, &filter).map_err(|e| arrow_datafusion_err!(e)) }) .collect() } else { diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs similarity index 80% rename from datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs rename to datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs index 9850b002e40e..3efd348937ed 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs @@ -17,13 +17,13 @@ //! [`GroupsAccumulator`] helpers: [`NullState`] and [`accumulate_indices`] //! -//! [`GroupsAccumulator`]: datafusion_expr::GroupsAccumulator +//! [`GroupsAccumulator`]: datafusion_expr_common::groups_accumulator::GroupsAccumulator +use arrow::array::{Array, BooleanArray, BooleanBufferBuilder, PrimitiveArray}; +use arrow::buffer::{BooleanBuffer, NullBuffer}; use arrow::datatypes::ArrowPrimitiveType; -use arrow_array::{Array, BooleanArray, PrimitiveArray}; -use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer}; -use datafusion_expr::EmitTo; +use datafusion_expr_common::groups_accumulator::EmitTo; /// Track the accumulator null state per row: if any values for that /// group were null and if any values have been seen at all for that group. /// @@ -48,7 +48,7 @@ use datafusion_expr::EmitTo; /// had at least one value to accumulate so they do not need to track /// if they have seen values for a particular group. /// -/// [`GroupsAccumulator`]: datafusion_expr::GroupsAccumulator +/// [`GroupsAccumulator`]: datafusion_expr_common::groups_accumulator::GroupsAccumulator #[derive(Debug)] pub struct NullState { /// Have we seen any non-filtered input values for `group_index`? @@ -91,38 +91,11 @@ impl NullState { /// * `opt_filter`: if present, only rows for which is Some(true) are included /// * `value_fn`: function invoked for (group_index, value) where value is non null /// - /// # Example + /// See [`accumulate`], for more details on how value_fn is called /// - /// ```text - /// ┌─────────┐ ┌─────────┐ ┌ ─ ─ ─ ─ ┐ - /// │ ┌─────┐ │ │ ┌─────┐ │ ┌─────┐ - /// │ │ 2 │ │ │ │ 200 │ │ │ │ t │ │ - /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ - /// │ │ 2 │ │ │ │ 100 │ │ │ │ f │ │ - /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ - /// │ │ 0 │ │ │ │ 200 │ │ │ │ t │ │ - /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ - /// │ │ 1 │ │ │ │ 200 │ │ │ │NULL │ │ - /// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ - /// │ │ 0 │ │ │ │ 300 │ │ │ │ t │ │ - /// │ └─────┘ │ │ └─────┘ │ └─────┘ - /// └─────────┘ └─────────┘ └ ─ ─ ─ ─ ┘ + /// When value_fn is called it also sets /// - /// group_indices values opt_filter - /// ``` - /// - /// In the example above, `value_fn` is invoked for each (group_index, - /// value) pair where `opt_filter[i]` is true and values is non null - /// - /// ```text - /// value_fn(2, 200) - /// value_fn(0, 200) - /// value_fn(0, 300) - /// ``` - /// - /// It also sets - /// - /// 1. `self.seen_values[group_index]` to true for all rows that had a non null vale + /// 1. `self.seen_values[group_index]` to true for all rows that had a non null value pub fn accumulate( &mut self, group_indices: &[usize], @@ -134,105 +107,14 @@ impl NullState { T: ArrowPrimitiveType + Send, F: FnMut(usize, T::Native) + Send, { - let data: &[T::Native] = values.values(); - assert_eq!(data.len(), group_indices.len()); - // ensure the seen_values is big enough (start everything at // "not seen" valid) let seen_values = initialize_builder(&mut self.seen_values, total_num_groups, false); - - match (values.null_count() > 0, opt_filter) { - // no nulls, no filter, - (false, None) => { - let iter = group_indices.iter().zip(data.iter()); - for (&group_index, &new_value) in iter { - seen_values.set_bit(group_index, true); - value_fn(group_index, new_value); - } - } - // nulls, no filter - (true, None) => { - let nulls = values.nulls().unwrap(); - // This is based on (ahem, COPY/PASTE) arrow::compute::aggregate::sum - // iterate over in chunks of 64 bits for more efficient null checking - let group_indices_chunks = group_indices.chunks_exact(64); - let data_chunks = data.chunks_exact(64); - let bit_chunks = nulls.inner().bit_chunks(); - - let group_indices_remainder = group_indices_chunks.remainder(); - let data_remainder = data_chunks.remainder(); - - group_indices_chunks - .zip(data_chunks) - .zip(bit_chunks.iter()) - .for_each(|((group_index_chunk, data_chunk), mask)| { - // index_mask has value 1 << i in the loop - let mut index_mask = 1; - group_index_chunk.iter().zip(data_chunk.iter()).for_each( - |(&group_index, &new_value)| { - // valid bit was set, real value - let is_valid = (mask & index_mask) != 0; - if is_valid { - seen_values.set_bit(group_index, true); - value_fn(group_index, new_value); - } - index_mask <<= 1; - }, - ) - }); - - // handle any remaining bits (after the initial 64) - let remainder_bits = bit_chunks.remainder_bits(); - group_indices_remainder - .iter() - .zip(data_remainder.iter()) - .enumerate() - .for_each(|(i, (&group_index, &new_value))| { - let is_valid = remainder_bits & (1 << i) != 0; - if is_valid { - seen_values.set_bit(group_index, true); - value_fn(group_index, new_value); - } - }); - } - // no nulls, but a filter - (false, Some(filter)) => { - assert_eq!(filter.len(), group_indices.len()); - // The performance with a filter could be improved by - // iterating over the filter in chunks, rather than a single - // iterator. TODO file a ticket - group_indices - .iter() - .zip(data.iter()) - .zip(filter.iter()) - .for_each(|((&group_index, &new_value), filter_value)| { - if let Some(true) = filter_value { - seen_values.set_bit(group_index, true); - value_fn(group_index, new_value); - } - }) - } - // both null values and filters - (true, Some(filter)) => { - assert_eq!(filter.len(), group_indices.len()); - // The performance with a filter could be improved by - // iterating over the filter in chunks, rather than using - // iterators. TODO file a ticket - filter - .iter() - .zip(group_indices.iter()) - .zip(values.iter()) - .for_each(|((filter_value, &group_index), new_value)| { - if let Some(true) = filter_value { - if let Some(new_value) = new_value { - seen_values.set_bit(group_index, true); - value_fn(group_index, new_value) - } - } - }) - } - } + accumulate(group_indices, values, opt_filter, |group_index, value| { + seen_values.set_bit(group_index, true); + value_fn(group_index, value); + }); } /// Invokes `value_fn(group_index, value)` for each non null, non @@ -351,6 +233,144 @@ impl NullState { } } +/// Invokes `value_fn(group_index, value)` for each non null, non +/// filtered value of `value`, +/// +/// # Arguments: +/// +/// * `group_indices`: To which groups do the rows in `values` belong, (aka group_index) +/// * `values`: the input arguments to the accumulator +/// * `opt_filter`: if present, only rows for which is Some(true) are included +/// * `value_fn`: function invoked for (group_index, value) where value is non null +/// +/// # Example +/// +/// ```text +/// ┌─────────┐ ┌─────────┐ ┌ ─ ─ ─ ─ ┐ +/// │ ┌─────┐ │ │ ┌─────┐ │ ┌─────┐ +/// │ │ 2 │ │ │ │ 200 │ │ │ │ t │ │ +/// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ +/// │ │ 2 │ │ │ │ 100 │ │ │ │ f │ │ +/// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ +/// │ │ 0 │ │ │ │ 200 │ │ │ │ t │ │ +/// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ +/// │ │ 1 │ │ │ │ 200 │ │ │ │NULL │ │ +/// │ ├─────┤ │ │ ├─────┤ │ ├─────┤ +/// │ │ 0 │ │ │ │ 300 │ │ │ │ t │ │ +/// │ └─────┘ │ │ └─────┘ │ └─────┘ +/// └─────────┘ └─────────┘ └ ─ ─ ─ ─ ┘ +/// +/// group_indices values opt_filter +/// ``` +/// +/// In the example above, `value_fn` is invoked for each (group_index, +/// value) pair where `opt_filter[i]` is true and values is non null +/// +/// ```text +/// value_fn(2, 200) +/// value_fn(0, 200) +/// value_fn(0, 300) +/// ``` +pub fn accumulate( + group_indices: &[usize], + values: &PrimitiveArray, + opt_filter: Option<&BooleanArray>, + mut value_fn: F, +) where + T: ArrowPrimitiveType + Send, + F: FnMut(usize, T::Native) + Send, +{ + let data: &[T::Native] = values.values(); + assert_eq!(data.len(), group_indices.len()); + + match (values.null_count() > 0, opt_filter) { + // no nulls, no filter, + (false, None) => { + let iter = group_indices.iter().zip(data.iter()); + for (&group_index, &new_value) in iter { + value_fn(group_index, new_value); + } + } + // nulls, no filter + (true, None) => { + let nulls = values.nulls().unwrap(); + // This is based on (ahem, COPY/PASTE) arrow::compute::aggregate::sum + // iterate over in chunks of 64 bits for more efficient null checking + let group_indices_chunks = group_indices.chunks_exact(64); + let data_chunks = data.chunks_exact(64); + let bit_chunks = nulls.inner().bit_chunks(); + + let group_indices_remainder = group_indices_chunks.remainder(); + let data_remainder = data_chunks.remainder(); + + group_indices_chunks + .zip(data_chunks) + .zip(bit_chunks.iter()) + .for_each(|((group_index_chunk, data_chunk), mask)| { + // index_mask has value 1 << i in the loop + let mut index_mask = 1; + group_index_chunk.iter().zip(data_chunk.iter()).for_each( + |(&group_index, &new_value)| { + // valid bit was set, real value + let is_valid = (mask & index_mask) != 0; + if is_valid { + value_fn(group_index, new_value); + } + index_mask <<= 1; + }, + ) + }); + + // handle any remaining bits (after the initial 64) + let remainder_bits = bit_chunks.remainder_bits(); + group_indices_remainder + .iter() + .zip(data_remainder.iter()) + .enumerate() + .for_each(|(i, (&group_index, &new_value))| { + let is_valid = remainder_bits & (1 << i) != 0; + if is_valid { + value_fn(group_index, new_value); + } + }); + } + // no nulls, but a filter + (false, Some(filter)) => { + assert_eq!(filter.len(), group_indices.len()); + // The performance with a filter could be improved by + // iterating over the filter in chunks, rather than a single + // iterator. TODO file a ticket + group_indices + .iter() + .zip(data.iter()) + .zip(filter.iter()) + .for_each(|((&group_index, &new_value), filter_value)| { + if let Some(true) = filter_value { + value_fn(group_index, new_value); + } + }) + } + // both null values and filters + (true, Some(filter)) => { + assert_eq!(filter.len(), group_indices.len()); + // The performance with a filter could be improved by + // iterating over the filter in chunks, rather than using + // iterators. TODO file a ticket + filter + .iter() + .zip(group_indices.iter()) + .zip(values.iter()) + .for_each(|((filter_value, &group_index), new_value)| { + if let Some(true) = filter_value { + if let Some(new_value) = new_value { + value_fn(group_index, new_value) + } + } + }) + } + } +} + /// This function is called to update the accumulator state per row /// when the value is not needed (e.g. COUNT) /// @@ -410,7 +430,7 @@ pub fn accumulate_indices( }, ); - // handle any remaining bits (after the intial 64) + // handle any remaining bits (after the initial 64) let remainder_bits = bit_chunks.remainder_bits(); group_indices_remainder .iter() @@ -462,9 +482,9 @@ fn initialize_builder( mod test { use super::*; - use arrow_array::UInt32Array; - use hashbrown::HashSet; + use arrow::array::UInt32Array; use rand::{rngs::ThreadRng, Rng}; + use std::collections::HashSet; #[test] fn accumulate() { @@ -835,7 +855,7 @@ mod test { } } - /// Parallel implementaiton of NullState to check expected values + /// Parallel implementation of NullState to check expected values #[derive(Debug, Default)] struct MockNullState { /// group indices that had values that passed the filter diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/bool_op.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs similarity index 75% rename from datafusion/physical-expr/src/aggregate/groups_accumulator/bool_op.rs rename to datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs index f40c661a7a2f..149312e5a9c0 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/bool_op.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/bool_op.rs @@ -17,11 +17,11 @@ use std::sync::Arc; -use arrow::array::AsArray; -use arrow_array::{ArrayRef, BooleanArray}; -use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder}; +use crate::aggregate::groups_accumulator::nulls::filtered_null_mask; +use arrow::array::{ArrayRef, AsArray, BooleanArray, BooleanBufferBuilder}; +use arrow::buffer::BooleanBuffer; use datafusion_common::Result; -use datafusion_expr::{EmitTo, GroupsAccumulator}; +use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; use super::accumulate::NullState; @@ -47,17 +47,22 @@ where /// Function that computes the output bool_fn: F, + + /// The identity element for the boolean operation. + /// Any value combined with this returns the original value. + identity: bool, } impl BooleanGroupsAccumulator where F: Fn(bool, bool) -> bool + Send + Sync, { - pub fn new(bitop_fn: F) -> Self { + pub fn new(bool_fn: F, identity: bool) -> Self { Self { values: BooleanBufferBuilder::new(0), null_state: NullState::new(), - bool_fn: bitop_fn, + bool_fn, + identity, } } } @@ -78,7 +83,9 @@ where if self.values.len() < total_num_groups { let new_groups = total_num_groups - self.values.len(); - self.values.append_n(new_groups, Default::default()); + // Fill with the identity element, so that when the first non-null value is encountered, + // it will combine with the identity and the result will be the first non-null value itself. + self.values.append_n(new_groups, self.identity); } // NullState dispatches / handles tracking nulls and groups that saw no values @@ -136,4 +143,22 @@ where // capacity is in bits, so convert to bytes self.values.capacity() / 8 + self.null_state.size() } + + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + let values = values[0].as_boolean().clone(); + + let values_null_buffer_filtered = filtered_null_mask(opt_filter, &values); + let (values_buf, _) = values.into_parts(); + let values_filtered = BooleanArray::new(values_buf, values_null_buffer_filtered); + + Ok(vec![Arc::new(values_filtered)]) + } + + fn supports_convert_to_state(&self) -> bool { + true + } } diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs new file mode 100644 index 000000000000..6a8946034cbc --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs @@ -0,0 +1,204 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`set_nulls`], other utilities for working with nulls + +use arrow::array::{ + Array, ArrayRef, ArrowNumericType, AsArray, BinaryArray, BinaryViewArray, + BooleanArray, LargeBinaryArray, LargeStringArray, PrimitiveArray, StringArray, + StringViewArray, +}; +use arrow::buffer::NullBuffer; +use arrow::datatypes::DataType; +use datafusion_common::{not_impl_err, Result}; +use std::sync::Arc; + +/// Sets the validity mask for a `PrimitiveArray` to `nulls` +/// replacing any existing null mask +/// +/// See [`set_nulls_dyn`] for a version that works with `Array` +pub fn set_nulls( + array: PrimitiveArray, + nulls: Option, +) -> PrimitiveArray { + let (dt, values, _old_nulls) = array.into_parts(); + PrimitiveArray::::new(values, nulls).with_data_type(dt) +} + +/// Converts a `BooleanBuffer` representing a filter to a `NullBuffer. +/// +/// The `NullBuffer` is +/// * `true` (representing valid) for values that were `true` in filter +/// * `false` (representing null) for values that were `false` or `null` in filter +fn filter_to_nulls(filter: &BooleanArray) -> Option { + let (filter_bools, filter_nulls) = filter.clone().into_parts(); + let filter_bools = NullBuffer::from(filter_bools); + NullBuffer::union(Some(&filter_bools), filter_nulls.as_ref()) +} + +/// Compute an output validity mask for an array that has been filtered +/// +/// This can be used to compute nulls for the output of +/// [`GroupsAccumulator::convert_to_state`], which quickly applies an optional +/// filter to the input rows by setting any filtered rows to NULL in the output. +/// Subsequent applications of aggregate functions that ignore NULLs (most of +/// them) will thus ignore the filtered rows as well. +/// +/// # Output element is `true` (and thus output is non-null) +/// +/// A `true` in the output represents non null output for all values that were *both*: +/// +/// * `true` in any `opt_filter` (aka values that passed the filter) +/// +/// * `non null` in `input` +/// +/// # Output element is `false` (and thus output is null) +/// +/// A `false` in the output represents an input that was *either*: +/// +/// * `null` +/// +/// * filtered (aka the value was `false` or `null` in the filter) +/// +/// # Example +/// +/// ```text +/// ┌─────┐ ┌─────┐ ┌─────┐ +/// │true │ │NULL │ │false│ +/// │true │ │ │true │ │true │ +/// │true │ ───┼─── │false│ ────────▶ │false│ filtered_nulls +/// │false│ │ │NULL │ │false│ +/// │false│ │true │ │false│ +/// └─────┘ └─────┘ └─────┘ +/// array opt_filter output +/// .nulls() +/// +/// false = NULL true = pass false = NULL Meanings +/// true = valid false = filter true = valid +/// NULL = filter +/// ``` +/// +/// [`GroupsAccumulator::convert_to_state`]: datafusion_expr_common::groups_accumulator::GroupsAccumulator +pub fn filtered_null_mask( + opt_filter: Option<&BooleanArray>, + input: &dyn Array, +) -> Option { + let opt_filter = opt_filter.and_then(filter_to_nulls); + NullBuffer::union(opt_filter.as_ref(), input.nulls()) +} + +/// Applies optional filter to input, returning a new array of the same type +/// with the same data, but with any values that were filtered out set to null +pub fn apply_filter_as_nulls( + input: &dyn Array, + opt_filter: Option<&BooleanArray>, +) -> Result { + let nulls = filtered_null_mask(opt_filter, input); + set_nulls_dyn(input, nulls) +} + +/// Replaces the nulls in the input array with the given `NullBuffer` +/// +/// TODO: replace when upstreamed in arrow-rs: +pub fn set_nulls_dyn(input: &dyn Array, nulls: Option) -> Result { + if let Some(nulls) = nulls.as_ref() { + assert_eq!(nulls.len(), input.len()); + } + + let output: ArrayRef = match input.data_type() { + DataType::Utf8 => { + let input = input.as_string::(); + // safety: values / offsets came from a valid string array, so are valid utf8 + // and we checked nulls has the same length as values + unsafe { + Arc::new(StringArray::new_unchecked( + input.offsets().clone(), + input.values().clone(), + nulls, + )) + } + } + DataType::LargeUtf8 => { + let input = input.as_string::(); + // safety: values / offsets came from a valid string array, so are valid utf8 + // and we checked nulls has the same length as values + unsafe { + Arc::new(LargeStringArray::new_unchecked( + input.offsets().clone(), + input.values().clone(), + nulls, + )) + } + } + DataType::Utf8View => { + let input = input.as_string_view(); + // safety: values / views came from a valid string view array, so are valid utf8 + // and we checked nulls has the same length as values + unsafe { + Arc::new(StringViewArray::new_unchecked( + input.views().clone(), + input.data_buffers().to_vec(), + nulls, + )) + } + } + + DataType::Binary => { + let input = input.as_binary::(); + // safety: values / offsets came from a valid binary array + // and we checked nulls has the same length as values + unsafe { + Arc::new(BinaryArray::new_unchecked( + input.offsets().clone(), + input.values().clone(), + nulls, + )) + } + } + DataType::LargeBinary => { + let input = input.as_binary::(); + // safety: values / offsets came from a valid large binary array + // and we checked nulls has the same length as values + unsafe { + Arc::new(LargeBinaryArray::new_unchecked( + input.offsets().clone(), + input.values().clone(), + nulls, + )) + } + } + DataType::BinaryView => { + let input = input.as_binary_view(); + // safety: values / views came from a valid binary view array + // and we checked nulls has the same length as values + unsafe { + Arc::new(BinaryViewArray::new_unchecked( + input.views().clone(), + input.data_buffers().to_vec(), + nulls, + )) + } + } + _ => { + return not_impl_err!("Applying nulls {:?}", input.data_type()); + } + }; + assert_eq!(input.len(), output.len()); + assert_eq!(input.data_type(), output.data_type()); + + Ok(output) +} diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs similarity index 61% rename from datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs rename to datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs index 994f5447d7c0..078982c983fc 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/prim_op.rs @@ -15,13 +15,16 @@ // specific language governing permissions and limitations // under the License. +use std::mem::size_of; use std::sync::Arc; -use arrow::{array::AsArray, datatypes::ArrowPrimitiveType}; -use arrow_array::{ArrayRef, BooleanArray, PrimitiveArray}; -use arrow_schema::DataType; -use datafusion_common::Result; -use datafusion_expr::{EmitTo, GroupsAccumulator}; +use arrow::array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray}; +use arrow::buffer::NullBuffer; +use arrow::compute; +use arrow::datatypes::ArrowPrimitiveType; +use arrow::datatypes::DataType; +use datafusion_common::{internal_datafusion_err, DataFusionError, Result}; +use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; use super::accumulate::NullState; @@ -134,7 +137,65 @@ where self.update_batch(values, group_indices, opt_filter, total_num_groups) } + /// Converts an input batch directly to a state batch + /// + /// The state is: + /// - self.prim_fn for all non null, non filtered values + /// - null otherwise + /// + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + let values = values[0].as_primitive::().clone(); + + // Initializing state with starting values + let initial_state = + PrimitiveArray::::from_value(self.starting_value, values.len()); + + // Recalculating values in case there is filter + let values = match opt_filter { + None => values, + Some(filter) => { + let (filter_values, filter_nulls) = filter.clone().into_parts(); + // Calculating filter mask as a result of bitand of filter, and converting it to null buffer + let filter_bool = match filter_nulls { + Some(filter_nulls) => filter_nulls.inner() & &filter_values, + None => filter_values, + }; + let filter_nulls = NullBuffer::from(filter_bool); + + // Rebuilding input values with a new nulls mask, which is equal to + // the union of original nulls and filter mask + let (dt, values_buf, original_nulls) = values.into_parts(); + let nulls_buf = + NullBuffer::union(original_nulls.as_ref(), Some(&filter_nulls)); + PrimitiveArray::::new(values_buf, nulls_buf).with_data_type(dt) + } + }; + + let state_values = compute::binary_mut(initial_state, &values, |mut x, y| { + (self.prim_fn)(&mut x, y); + x + }); + let state_values = state_values + .map_err(|_| { + internal_datafusion_err!( + "initial_values underlying buffer must not be shared" + ) + })? + .map_err(DataFusionError::from)? + .with_data_type(self.data_type.clone()); + + Ok(vec![Arc::new(state_values)]) + } + + fn supports_convert_to_state(&self) -> bool { + true + } + fn size(&self) -> usize { - self.values.capacity() * std::mem::size_of::() + self.null_state.size() + self.values.capacity() * size_of::() + self.null_state.size() } } diff --git a/datafusion/functions-aggregate-common/src/lib.rs b/datafusion/functions-aggregate-common/src/lib.rs new file mode 100644 index 000000000000..cc50ff70913b --- /dev/null +++ b/datafusion/functions-aggregate-common/src/lib.rs @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Common Aggregate functionality for [DataFusion] +//! +//! This crate contains traits and utilities commonly used to implement aggregate functions +//! They are kept in their own crate to avoid physical expressions depending on logical expressions. +//! +//! [DataFusion]: + +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] + +pub mod accumulator; +pub mod aggregate; +pub mod merge_arrays; +pub mod order; +pub mod stats; +pub mod tdigest; +pub mod utils; diff --git a/datafusion/functions-aggregate-common/src/merge_arrays.rs b/datafusion/functions-aggregate-common/src/merge_arrays.rs new file mode 100644 index 000000000000..544bdc182829 --- /dev/null +++ b/datafusion/functions-aggregate-common/src/merge_arrays.rs @@ -0,0 +1,195 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::compute::SortOptions; +use datafusion_common::utils::compare_rows; +use datafusion_common::{exec_err, ScalarValue}; +use std::cmp::Ordering; +use std::collections::{BinaryHeap, VecDeque}; + +/// This is a wrapper struct to be able to correctly merge `ARRAY_AGG` data from +/// multiple partitions using `BinaryHeap`. When used inside `BinaryHeap`, this +/// struct returns smallest `CustomElement`, where smallest is determined by +/// `ordering` values (`Vec`) according to `sort_options`. +#[derive(Debug, PartialEq, Eq)] +struct CustomElement<'a> { + /// Stores the partition this entry came from + branch_idx: usize, + /// Values to merge + value: ScalarValue, + // Comparison "key" + ordering: Vec, + /// Options defining the ordering semantics + sort_options: &'a [SortOptions], +} + +impl<'a> CustomElement<'a> { + fn new( + branch_idx: usize, + value: ScalarValue, + ordering: Vec, + sort_options: &'a [SortOptions], + ) -> Self { + Self { + branch_idx, + value, + ordering, + sort_options, + } + } + + fn ordering( + &self, + current: &[ScalarValue], + target: &[ScalarValue], + ) -> datafusion_common::Result { + // Calculate ordering according to `sort_options` + compare_rows(current, target, self.sort_options) + } +} + +// Overwrite ordering implementation such that +// - `self.ordering` values are used for comparison, +// - When used inside `BinaryHeap` it is a min-heap. +impl<'a> Ord for CustomElement<'a> { + fn cmp(&self, other: &Self) -> Ordering { + // Compares according to custom ordering + self.ordering(&self.ordering, &other.ordering) + // Convert max heap to min heap + .map(|ordering| ordering.reverse()) + // This function return error, when `self.ordering` and `other.ordering` + // have different types (such as one is `ScalarValue::Int64`, other is `ScalarValue::Float32`) + // Here this case won't happen, because data from each partition will have same type + .unwrap() + } +} + +impl<'a> PartialOrd for CustomElement<'a> { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +/// This functions merges `values` array (`&[Vec]`) into single array `Vec` +/// Merging done according to ordering values stored inside `ordering_values` (`&[Vec>]`) +/// Inner `Vec` in the `ordering_values` can be thought as ordering information for the +/// each `ScalarValue` in the `values` array. +/// Desired ordering specified by `sort_options` argument (Should have same size with inner `Vec` +/// of the `ordering_values` array). +/// +/// As an example +/// values can be \[ +/// \[1, 2, 3, 4, 5\], +/// \[1, 2, 3, 4\], +/// \[1, 2, 3, 4, 5, 6\], +/// \] +/// In this case we will be merging three arrays (doesn't have to be same size) +/// and produce a merged array with size 15 (sum of 5+4+6) +/// Merging will be done according to ordering at `ordering_values` vector. +/// As an example `ordering_values` can be [ +/// \[(1, a), (2, b), (3, b), (4, a), (5, b) \], +/// \[(1, a), (2, b), (3, b), (4, a) \], +/// \[(1, b), (2, c), (3, d), (4, e), (5, a), (6, b) \], +/// ] +/// For each ScalarValue in the `values` we have a corresponding `Vec` (like timestamp of it) +/// for the example above `sort_options` will have size two, that defines ordering requirement of the merge. +/// Inner `Vec`s of the `ordering_values` will be compared according `sort_options` (Their sizes should match) +pub fn merge_ordered_arrays( + // We will merge values into single `Vec`. + values: &mut [VecDeque], + // `values` will be merged according to `ordering_values`. + // Inner `Vec` can be thought as ordering information for the + // each `ScalarValue` in the values`. + ordering_values: &mut [VecDeque>], + // Defines according to which ordering comparisons should be done. + sort_options: &[SortOptions], +) -> datafusion_common::Result<(Vec, Vec>)> { + // Keep track the most recent data of each branch, in binary heap data structure. + let mut heap = BinaryHeap::::new(); + + if values.len() != ordering_values.len() + || values + .iter() + .zip(ordering_values.iter()) + .any(|(vals, ordering_vals)| vals.len() != ordering_vals.len()) + { + return exec_err!( + "Expects values arguments and/or ordering_values arguments to have same size" + ); + } + let n_branch = values.len(); + let mut merged_values = vec![]; + let mut merged_orderings = vec![]; + // Continue iterating the loop until consuming data of all branches. + loop { + let minimum = if let Some(minimum) = heap.pop() { + minimum + } else { + // Heap is empty, fill it with the next entries from each branch. + for branch_idx in 0..n_branch { + if let Some(orderings) = ordering_values[branch_idx].pop_front() { + // Their size should be same, we can safely .unwrap here. + let value = values[branch_idx].pop_front().unwrap(); + // Push the next element to the heap: + heap.push(CustomElement::new( + branch_idx, + value, + orderings, + sort_options, + )); + } + // If None, we consumed this branch, skip it. + } + + // Now we have filled the heap, get the largest entry (this will be + // the next element in merge). + if let Some(minimum) = heap.pop() { + minimum + } else { + // Heap is empty, this means that all indices are same with + // `end_indices`. We have consumed all of the branches, merge + // is completed, exit from the loop: + break; + } + }; + let CustomElement { + branch_idx, + value, + ordering, + .. + } = minimum; + // Add minimum value in the heap to the result + merged_values.push(value); + merged_orderings.push(ordering); + + // If there is an available entry, push next entry in the most + // recently consumed branch to the heap. + if let Some(orderings) = ordering_values[branch_idx].pop_front() { + // Their size should be same, we can safely .unwrap here. + let value = values[branch_idx].pop_front().unwrap(); + // Push the next element to the heap: + heap.push(CustomElement::new( + branch_idx, + value, + orderings, + sort_options, + )); + } + } + + Ok((merged_values, merged_orderings)) +} diff --git a/datafusion/functions-aggregate-common/src/order.rs b/datafusion/functions-aggregate-common/src/order.rs new file mode 100644 index 000000000000..bfa6e39138f9 --- /dev/null +++ b/datafusion/functions-aggregate-common/src/order.rs @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Represents the sensitivity of an aggregate expression to ordering. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum AggregateOrderSensitivity { + /// Indicates that the aggregate expression is insensitive to ordering. + /// Ordering at the input is not important for the result of the aggregator. + Insensitive, + /// Indicates that the aggregate expression has a hard requirement on ordering. + /// The aggregator can not produce a correct result unless its ordering + /// requirement is satisfied. + HardRequirement, + /// Indicates that ordering is beneficial for the aggregate expression in terms + /// of evaluation efficiency. The aggregator can produce its result efficiently + /// when its required ordering is satisfied; however, it can still produce the + /// correct result (albeit less efficiently) when its required ordering is not met. + Beneficial, +} + +impl AggregateOrderSensitivity { + pub fn is_insensitive(&self) -> bool { + self.eq(&AggregateOrderSensitivity::Insensitive) + } + + pub fn is_beneficial(&self) -> bool { + self.eq(&AggregateOrderSensitivity::Beneficial) + } + + pub fn hard_requires(&self) -> bool { + self.eq(&AggregateOrderSensitivity::HardRequirement) + } +} diff --git a/datafusion/physical-expr/src/aggregate/stats.rs b/datafusion/functions-aggregate-common/src/stats.rs similarity index 91% rename from datafusion/physical-expr/src/aggregate/stats.rs rename to datafusion/functions-aggregate-common/src/stats.rs index 98baaccffe81..bcd004db7831 100644 --- a/datafusion/physical-expr/src/aggregate/stats.rs +++ b/datafusion/functions-aggregate-common/src/stats.rs @@ -15,8 +15,9 @@ // specific language governing permissions and limitations // under the License. +/// TODO: Move this to functions-aggregate module /// Enum used for differentiating population and sample for statistical functions -#[derive(Debug, Clone, Copy)] +#[derive(PartialEq, Eq, Debug, Clone, Copy)] pub enum StatsType { /// Population Population, diff --git a/datafusion/physical-expr/src/aggregate/tdigest.rs b/datafusion/functions-aggregate-common/src/tdigest.rs similarity index 88% rename from datafusion/physical-expr/src/aggregate/tdigest.rs rename to datafusion/functions-aggregate-common/src/tdigest.rs index e3b23b91d0ff..786d7ea3e361 100644 --- a/datafusion/physical-expr/src/aggregate/tdigest.rs +++ b/datafusion/functions-aggregate-common/src/tdigest.rs @@ -28,11 +28,12 @@ //! [Facebook's Folly TDigest]: https://github.com/facebook/folly/blob/main/folly/stats/TDigest.h use arrow::datatypes::DataType; -use arrow_array::types::Float64Type; +use arrow::datatypes::Float64Type; use datafusion_common::cast::as_primitive_array; use datafusion_common::Result; use datafusion_common::ScalarValue; use std::cmp::Ordering; +use std::mem::{size_of, size_of_val}; pub const DEFAULT_MAX_SIZE: usize = 100; @@ -47,10 +48,21 @@ macro_rules! cast_scalar_f64 { }; } +// Cast a non-null [`ScalarValue::UInt64`] to an [`u64`], or +// panic. +macro_rules! cast_scalar_u64 { + ($value:expr ) => { + match &$value { + ScalarValue::UInt64(Some(v)) => *v, + v => panic!("invalid type {:?}", v), + } + }; +} + /// This trait is implemented for each type a [`TDigest`] can operate on, /// allowing it to support both numerical rust types (obtained from /// `PrimitiveArray` instances), and [`ScalarValue`] instances. -pub(crate) trait TryIntoF64 { +pub trait TryIntoF64 { /// A fallible conversion of a possibly null `self` into a [`f64`]. /// /// If `self` is null, this method must return `Ok(None)`. @@ -84,7 +96,7 @@ impl_try_ordered_f64!(u8); /// Centroid implementation to the cluster mentioned in the paper. #[derive(Debug, PartialEq, Clone)] -pub(crate) struct Centroid { +pub struct Centroid { mean: f64, weight: f64, } @@ -104,21 +116,21 @@ impl Ord for Centroid { } impl Centroid { - pub(crate) fn new(mean: f64, weight: f64) -> Self { + pub fn new(mean: f64, weight: f64) -> Self { Centroid { mean, weight } } #[inline] - pub(crate) fn mean(&self) -> f64 { + pub fn mean(&self) -> f64 { self.mean } #[inline] - pub(crate) fn weight(&self) -> f64 { + pub fn weight(&self) -> f64 { self.weight } - pub(crate) fn add(&mut self, sum: f64, weight: f64) -> f64 { + pub fn add(&mut self, sum: f64, weight: f64) -> f64 { let new_sum = sum + self.weight * self.mean; let new_weight = self.weight + weight; self.weight = new_weight; @@ -138,62 +150,61 @@ impl Default for Centroid { /// T-Digest to be operated on. #[derive(Debug, PartialEq, Clone)] -pub(crate) struct TDigest { +pub struct TDigest { centroids: Vec, max_size: usize, sum: f64, - count: f64, + count: u64, max: f64, min: f64, } impl TDigest { - pub(crate) fn new(max_size: usize) -> Self { + pub fn new(max_size: usize) -> Self { TDigest { centroids: Vec::new(), max_size, sum: 0_f64, - count: 0_f64, + count: 0, max: f64::NAN, min: f64::NAN, } } - pub(crate) fn new_with_centroid(max_size: usize, centroid: Centroid) -> Self { + pub fn new_with_centroid(max_size: usize, centroid: Centroid) -> Self { TDigest { centroids: vec![centroid.clone()], max_size, sum: centroid.mean * centroid.weight, - count: 1_f64, + count: 1, max: centroid.mean, min: centroid.mean, } } #[inline] - pub(crate) fn count(&self) -> f64 { + pub fn count(&self) -> u64 { self.count } #[inline] - pub(crate) fn max(&self) -> f64 { + pub fn max(&self) -> f64 { self.max } #[inline] - pub(crate) fn min(&self) -> f64 { + pub fn min(&self) -> f64 { self.min } #[inline] - pub(crate) fn max_size(&self) -> usize { + pub fn max_size(&self) -> usize { self.max_size } /// Size in bytes including `Self`. - pub(crate) fn size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.centroids.capacity()) + pub fn size(&self) -> usize { + size_of_val(self) + (size_of::() * self.centroids.capacity()) } } @@ -203,7 +214,7 @@ impl Default for TDigest { centroids: Vec::new(), max_size: 100, sum: 0_f64, - count: 0_f64, + count: 0, max: f64::NAN, min: f64::NAN, } @@ -211,8 +222,8 @@ impl Default for TDigest { } impl TDigest { - fn k_to_q(k: f64, d: f64) -> f64 { - let k_div_d = k / d; + fn k_to_q(k: u64, d: usize) -> f64 { + let k_div_d = k as f64 / d as f64; if k_div_d >= 0.5 { let base = 1.0 - k_div_d; 1.0 - 2.0 * base * base @@ -222,20 +233,20 @@ impl TDigest { } fn clamp(v: f64, lo: f64, hi: f64) -> f64 { - if lo.is_nan() && hi.is_nan() { + if lo.is_nan() || hi.is_nan() { return v; } v.clamp(lo, hi) } - #[cfg(test)] - pub(crate) fn merge_unsorted_f64(&self, unsorted_values: Vec) -> TDigest { + // public for testing in other modules + pub fn merge_unsorted_f64(&self, unsorted_values: Vec) -> TDigest { let mut values = unsorted_values; values.sort_by(|a, b| a.total_cmp(b)); self.merge_sorted_f64(&values) } - pub(crate) fn merge_sorted_f64(&self, sorted_values: &[f64]) -> TDigest { + pub fn merge_sorted_f64(&self, sorted_values: &[f64]) -> TDigest { #[cfg(debug_assertions)] debug_assert!(is_sorted(sorted_values), "unsorted input to TDigest"); @@ -244,12 +255,12 @@ impl TDigest { } let mut result = TDigest::new(self.max_size()); - result.count = self.count() + (sorted_values.len() as f64); + result.count = self.count() + sorted_values.len() as u64; let maybe_min = *sorted_values.first().unwrap(); let maybe_max = *sorted_values.last().unwrap(); - if self.count() > 0.0 { + if self.count() > 0 { result.min = self.min.min(maybe_min); result.max = self.max.max(maybe_max); } else { @@ -259,10 +270,10 @@ impl TDigest { let mut compressed: Vec = Vec::with_capacity(self.max_size); - let mut k_limit: f64 = 1.0; + let mut k_limit: u64 = 1; let mut q_limit_times_count = - Self::k_to_q(k_limit, self.max_size as f64) * result.count(); - k_limit += 1.0; + Self::k_to_q(k_limit, self.max_size) * result.count() as f64; + k_limit += 1; let mut iter_centroids = self.centroids.iter().peekable(); let mut iter_sorted_values = sorted_values.iter().peekable(); @@ -309,8 +320,8 @@ impl TDigest { compressed.push(curr.clone()); q_limit_times_count = - Self::k_to_q(k_limit, self.max_size as f64) * result.count(); - k_limit += 1.0; + Self::k_to_q(k_limit, self.max_size) * result.count() as f64; + k_limit += 1; curr = next; } } @@ -370,9 +381,7 @@ impl TDigest { } // Merge multiple T-Digests - pub(crate) fn merge_digests<'a>( - digests: impl IntoIterator, - ) -> TDigest { + pub fn merge_digests<'a>(digests: impl IntoIterator) -> TDigest { let digests = digests.into_iter().collect::>(); let n_centroids: usize = digests.iter().map(|d| d.centroids.len()).sum(); if n_centroids == 0 { @@ -383,7 +392,7 @@ impl TDigest { let mut centroids: Vec = Vec::with_capacity(n_centroids); let mut starts: Vec = Vec::with_capacity(digests.len()); - let mut count: f64 = 0.0; + let mut count = 0; let mut min = f64::INFINITY; let mut max = f64::NEG_INFINITY; @@ -391,8 +400,8 @@ impl TDigest { for digest in digests.iter() { starts.push(start); - let curr_count: f64 = digest.count(); - if curr_count > 0.0 { + let curr_count = digest.count(); + if curr_count > 0 { min = min.min(digest.min); max = max.max(digest.max); count += curr_count; @@ -426,8 +435,8 @@ impl TDigest { let mut result = TDigest::new(max_size); let mut compressed: Vec = Vec::with_capacity(max_size); - let mut k_limit: f64 = 1.0; - let mut q_limit_times_count = Self::k_to_q(k_limit, max_size as f64) * (count); + let mut k_limit = 1; + let mut q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64; let mut iter_centroids = centroids.iter_mut(); let mut curr = iter_centroids.next().unwrap(); @@ -446,8 +455,8 @@ impl TDigest { sums_to_merge = 0_f64; weights_to_merge = 0_f64; compressed.push(curr.clone()); - q_limit_times_count = Self::k_to_q(k_limit, max_size as f64) * (count); - k_limit += 1.0; + q_limit_times_count = Self::k_to_q(k_limit, max_size) * count as f64; + k_limit += 1; curr = centroid; } } @@ -465,13 +474,12 @@ impl TDigest { } /// To estimate the value located at `q` quantile - pub(crate) fn estimate_quantile(&self, q: f64) -> f64 { + pub fn estimate_quantile(&self, q: f64) -> f64 { if self.centroids.is_empty() { return 0.0; } - let count_ = self.count; - let rank = q * count_; + let rank = q * self.count as f64; let mut pos: usize; let mut t; @@ -481,7 +489,7 @@ impl TDigest { } pos = 0; - t = count_; + t = self.count as f64; for (k, centroid) in self.centroids.iter().enumerate().rev() { t -= centroid.weight(); @@ -531,6 +539,18 @@ impl TDigest { let value = self.centroids[pos].mean() + ((rank - t) / self.centroids[pos].weight() - 0.5) * delta; + // In `merge_digests()`: `min` is initialized to Inf, `max` is initialized to -Inf + // and gets updated according to different `TDigest`s + // However, `min`/`max` won't get updated if there is only one `NaN` within `TDigest` + // The following two checks is for such edge case + if !min.is_finite() && min.is_sign_positive() { + min = f64::NEG_INFINITY; + } + + if !max.is_finite() && max.is_sign_negative() { + max = f64::INFINITY; + } + Self::clamp(value, min, max) } @@ -569,7 +589,7 @@ impl TDigest { /// The [`TDigest::from_scalar_state()`] method reverses this processes, /// consuming the output of this method and returning an unpacked /// [`TDigest`]. - pub(crate) fn to_scalar_state(&self) -> Vec { + pub fn to_scalar_state(&self) -> Vec { // Gather up all the centroids let centroids: Vec = self .centroids @@ -578,12 +598,12 @@ impl TDigest { .map(|v| ScalarValue::Float64(Some(v))) .collect(); - let arr = ScalarValue::new_list(¢roids, &DataType::Float64); + let arr = ScalarValue::new_list_nullable(¢roids, &DataType::Float64); vec![ ScalarValue::UInt64(Some(self.max_size as u64)), ScalarValue::Float64(Some(self.sum)), - ScalarValue::Float64(Some(self.count)), + ScalarValue::UInt64(Some(self.count)), ScalarValue::Float64(Some(self.max)), ScalarValue::Float64(Some(self.min)), ScalarValue::List(arr), @@ -598,7 +618,7 @@ impl TDigest { /// Providing input to this method that was not obtained from /// [`Self::to_scalar_state()`] results in undefined behaviour and may /// panic. - pub(crate) fn from_scalar_state(state: &[ScalarValue]) -> Self { + pub fn from_scalar_state(state: &[ScalarValue]) -> Self { assert_eq!(state.len(), 6, "invalid TDigest state"); let max_size = match &state[0] { @@ -624,12 +644,14 @@ impl TDigest { let max = cast_scalar_f64!(&state[3]); let min = cast_scalar_f64!(&state[4]); - assert!(max.total_cmp(&min).is_ge()); + if min.is_finite() && max.is_finite() { + assert!(max.total_cmp(&min).is_ge()); + } Self { max_size, sum: cast_scalar_f64!(state[1]), - count: cast_scalar_f64!(&state[2]), + count: cast_scalar_u64!(&state[2]), max, min, centroids, diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/functions-aggregate-common/src/utils.rs similarity index 76% rename from datafusion/physical-expr/src/aggregate/utils.rs rename to datafusion/functions-aggregate-common/src/utils.rs index 6d97ad3da6de..f55e5ec9a41d 100644 --- a/datafusion/physical-expr/src/aggregate/utils.rs +++ b/datafusion/functions-aggregate-common/src/utils.rs @@ -15,25 +15,22 @@ // specific language governing permissions and limitations // under the License. -//! Utilities used in aggregates - use std::sync::Arc; -// For backwards compatibility -pub use datafusion_physical_expr_common::aggregate::utils::{ - down_cast_any_ref, get_sort_options, ordering_fields, -}; - -use arrow::array::{ArrayRef, ArrowNativeTypeOp}; -use arrow_array::cast::AsArray; -use arrow_array::types::{ - Decimal128Type, DecimalType, TimestampMicrosecondType, TimestampMillisecondType, - TimestampNanosecondType, TimestampSecondType, +use arrow::array::{ArrayRef, AsArray}; +use arrow::datatypes::ArrowNativeType; +use arrow::{ + array::ArrowNativeTypeOp, + compute::SortOptions, + datatypes::{ + DataType, Decimal128Type, DecimalType, Field, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, + ToByteSlice, + }, }; -use arrow_buffer::{ArrowNativeType, ToByteSlice}; -use arrow_schema::DataType; use datafusion_common::{exec_err, DataFusionError, Result}; -use datafusion_expr::Accumulator; +use datafusion_expr_common::accumulator::Accumulator; +use datafusion_physical_expr_common::sort_expr::LexOrderingRef; /// Convert scalar values from an accumulator into arrays. pub fn get_accum_scalar_values_as_arrays( @@ -46,6 +43,92 @@ pub fn get_accum_scalar_values_as_arrays( .collect() } +/// Adjust array type metadata if needed +/// +/// Since `Decimal128Arrays` created from `Vec` have +/// default precision and scale, this function adjusts the output to +/// match `data_type`, if necessary +pub fn adjust_output_array(data_type: &DataType, array: ArrayRef) -> Result { + let array = match data_type { + DataType::Decimal128(p, s) => Arc::new( + array + .as_primitive::() + .clone() + .with_precision_and_scale(*p, *s)?, + ) as ArrayRef, + DataType::Timestamp(TimeUnit::Nanosecond, tz) => Arc::new( + array + .as_primitive::() + .clone() + .with_timezone_opt(tz.clone()), + ), + DataType::Timestamp(TimeUnit::Microsecond, tz) => Arc::new( + array + .as_primitive::() + .clone() + .with_timezone_opt(tz.clone()), + ), + DataType::Timestamp(TimeUnit::Millisecond, tz) => Arc::new( + array + .as_primitive::() + .clone() + .with_timezone_opt(tz.clone()), + ), + DataType::Timestamp(TimeUnit::Second, tz) => Arc::new( + array + .as_primitive::() + .clone() + .with_timezone_opt(tz.clone()), + ), + // no adjustment needed for other arrays + _ => array, + }; + Ok(array) +} + +/// Construct corresponding fields for lexicographical ordering requirement expression +pub fn ordering_fields( + ordering_req: LexOrderingRef, + // Data type of each expression in the ordering requirement + data_types: &[DataType], +) -> Vec { + ordering_req + .iter() + .zip(data_types.iter()) + .map(|(sort_expr, dtype)| { + Field::new( + sort_expr.expr.to_string().as_str(), + dtype.clone(), + // Multi partitions may be empty hence field should be nullable. + true, + ) + }) + .collect() +} + +/// Selects the sort option attribute from all the given `PhysicalSortExpr`s. +pub fn get_sort_options(ordering_req: LexOrderingRef) -> Vec { + ordering_req.iter().map(|item| item.options).collect() +} + +/// A wrapper around a type to provide hash for floats +#[derive(Copy, Clone, Debug)] +pub struct Hashable(pub T); + +impl std::hash::Hash for Hashable { + fn hash(&self, state: &mut H) { + self.0.to_byte_slice().hash(state) + } +} + +impl PartialEq for Hashable { + fn eq(&self, other: &Self) -> bool { + self.0.is_eq(other.0) + } +} + +impl Eq for Hashable {} + /// Computes averages for `Decimal128`/`Decimal256` values, checking for overflow /// /// This is needed because different precisions for Decimal128/Decimal256 can @@ -54,7 +137,7 @@ pub fn get_accum_scalar_values_as_arrays( /// /// For example, the precision is 3, the max of value is `999` and the min /// value is `-999` -pub(crate) struct DecimalAverager { +pub struct DecimalAverager { /// scale factor for sum values (10^sum_scale) sum_mul: T::Native, /// scale factor for target (10^target_scale) @@ -104,7 +187,7 @@ impl DecimalAverager { /// target_scale and target_precision and reporting overflow. /// /// * sum: The total sum value stored as Decimal128 with sum_scale - /// (passed to `Self::try_new`) + /// (passed to `Self::try_new`) /// * count: total count, stored as a i128/i256 (*NOT* a Decimal128/Decimal256 value) #[inline(always)] pub fn avg(&self, sum: T::Native, count: T::Native) -> Result { @@ -125,67 +208,3 @@ impl DecimalAverager { } } } - -/// Adjust array type metadata if needed -/// -/// Since `Decimal128Arrays` created from `Vec` have -/// default precision and scale, this function adjusts the output to -/// match `data_type`, if necessary -pub fn adjust_output_array( - data_type: &DataType, - array: ArrayRef, -) -> Result { - let array = match data_type { - DataType::Decimal128(p, s) => Arc::new( - array - .as_primitive::() - .clone() - .with_precision_and_scale(*p, *s)?, - ) as ArrayRef, - DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, tz) => Arc::new( - array - .as_primitive::() - .clone() - .with_timezone_opt(tz.clone()), - ), - DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, tz) => Arc::new( - array - .as_primitive::() - .clone() - .with_timezone_opt(tz.clone()), - ), - DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, tz) => Arc::new( - array - .as_primitive::() - .clone() - .with_timezone_opt(tz.clone()), - ), - DataType::Timestamp(arrow_schema::TimeUnit::Second, tz) => Arc::new( - array - .as_primitive::() - .clone() - .with_timezone_opt(tz.clone()), - ), - // no adjustment needed for other arrays - _ => array, - }; - Ok(array) -} - -/// A wrapper around a type to provide hash for floats -#[derive(Copy, Clone, Debug)] -pub(crate) struct Hashable(pub T); - -impl std::hash::Hash for Hashable { - fn hash(&self, state: &mut H) { - self.0.to_byte_slice().hash(state) - } -} - -impl PartialEq for Hashable { - fn eq(&self, other: &Self) -> bool { - self.0.is_eq(other.0) - } -} - -impl Eq for Hashable {} diff --git a/datafusion/functions-aggregate/COMMENTS.md b/datafusion/functions-aggregate/COMMENTS.md new file mode 100644 index 000000000000..e669e1355711 --- /dev/null +++ b/datafusion/functions-aggregate/COMMENTS.md @@ -0,0 +1,77 @@ + + +# Why Is List Item Always Nullable? + +## Motivation + +There were independent proposals to make the `nullable` setting of list +items in accumulator state be configurable. This meant adding additional +fields which captured the `nullable` setting from schema in planning for +the first argument to the aggregation function, and the returned value. + +These fields were to be added to `StateFieldArgs`. But then we found out +that aggregate computation does not depend on it, and it can be avoided. + +This document exists to make that reasoning explicit. + +## Background + +The list data type is used in the accumulator state for a few aggregate +functions like: + +- `sum` +- `count` +- `array_agg` +- `bit_and`, `bit_or` and `bit_xor` +- `nth_value` + +In all of the above cases the data type of the list item is equivalent +to either the first argument of the aggregate function or the returned +value. + +For example, in `array_agg` the data type of item is equivalent to the +first argument and the definition looks like this: + +```rust +// `args` : `StateFieldArgs` +// `input_type` : data type of the first argument +let mut fields = vec![Field::new_list( + format_state_name(self.name(), "nth_value"), + Field::new("item", args.input_types[0].clone(), true /* nullable of list item */ ), + false, // nullable of list itself +)]; +``` + +For all the aggregates listed above, the list item is always defined as +nullable. + +## Computing Intermediate State + +By setting `nullable` (of list item) to be always `true` like this we +ensure that the aggregate computation works even when nulls are +present. The advantage of doing it this way is that it eliminates the +need for additional code and special treatment of nulls in the +accumulator state. + +## Nullable Of List Itself + +The `nullable` of list itself depends on the aggregate. In the case of +`array_agg` the list is nullable(`true`), meanwhile for `sum` the list +is not nullable(`false`). diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index f97647565364..37e4c7f4a5ad 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -17,7 +17,7 @@ [package] name = "datafusion-functions-aggregate" -description = "Aggregate function packages for the DataFusion query engine" +description = "Traits and types for logical plans and expressions for DataFusion query engine" keywords = ["datafusion", "logical", "plan", "expressions"] readme = "README.md" version = { workspace = true } @@ -38,11 +38,29 @@ path = "src/lib.rs" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +ahash = { workspace = true } arrow = { workspace = true } +arrow-schema = { workspace = true } datafusion-common = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } +datafusion-functions-aggregate-common = { workspace = true } +datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } +half = { workspace = true } +indexmap = { workspace = true } log = { workspace = true } paste = "1.0.14" -sqlparser = { workspace = true } + +[dev-dependencies] +arrow = { workspace = true, features = ["test_utils"] } +criterion = "0.5" +rand = { workspace = true } + +[[bench]] +name = "count" +harness = false + +[[bench]] +name = "sum" +harness = false diff --git a/datafusion/functions-aggregate/README.md b/datafusion/functions-aggregate/README.md new file mode 100644 index 000000000000..29b313d2a903 --- /dev/null +++ b/datafusion/functions-aggregate/README.md @@ -0,0 +1,27 @@ + + +# DataFusion Aggregate Function Library + +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. + +This crate contains packages of function that can be used to customize the +functionality of DataFusion. + +[df]: https://crates.io/crates/datafusion diff --git a/datafusion/functions-aggregate/benches/count.rs b/datafusion/functions-aggregate/benches/count.rs new file mode 100644 index 000000000000..1c8266ed5b89 --- /dev/null +++ b/datafusion/functions-aggregate/benches/count.rs @@ -0,0 +1,96 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, BooleanArray}; +use arrow::datatypes::Int32Type; +use arrow::util::bench_util::{create_boolean_array, create_primitive_array}; +use arrow_schema::{DataType, Field, Schema}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::{function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator}; +use datafusion_functions_aggregate::count::Count; +use datafusion_physical_expr::expressions::col; +use datafusion_physical_expr_common::sort_expr::LexOrderingRef; +use std::sync::Arc; + +fn prepare_accumulator() -> Box { + let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Int32, true)])); + let accumulator_args = AccumulatorArgs { + return_type: &DataType::Int64, + schema: &schema, + ignore_nulls: false, + ordering_req: LexOrderingRef::default(), + is_reversed: false, + name: "COUNT(f)", + is_distinct: false, + exprs: &[col("f", &schema).unwrap()], + }; + let count_fn = Count::new(); + + count_fn + .create_groups_accumulator(accumulator_args) + .unwrap() +} + +fn convert_to_state_bench( + c: &mut Criterion, + name: &str, + values: ArrayRef, + opt_filter: Option<&BooleanArray>, +) { + let accumulator = prepare_accumulator(); + c.bench_function(name, |b| { + b.iter(|| { + black_box( + accumulator + .convert_to_state(&[values.clone()], opt_filter) + .unwrap(), + ) + }) + }); +} + +fn count_benchmark(c: &mut Criterion) { + let values = Arc::new(create_primitive_array::(8192, 0.0)) as ArrayRef; + convert_to_state_bench(c, "count convert state no nulls, no filter", values, None); + + let values = Arc::new(create_primitive_array::(8192, 0.3)) as ArrayRef; + convert_to_state_bench(c, "count convert state 30% nulls, no filter", values, None); + + let values = Arc::new(create_primitive_array::(8192, 0.3)) as ArrayRef; + convert_to_state_bench(c, "count convert state 70% nulls, no filter", values, None); + + let values = Arc::new(create_primitive_array::(8192, 0.0)) as ArrayRef; + let filter = create_boolean_array(8192, 0.0, 0.5); + convert_to_state_bench( + c, + "count convert state no nulls, filter", + values, + Some(&filter), + ); + + let values = Arc::new(create_primitive_array::(8192, 0.3)) as ArrayRef; + let filter = create_boolean_array(8192, 0.0, 0.5); + convert_to_state_bench( + c, + "count convert state nulls, filter", + values, + Some(&filter), + ); +} + +criterion_group!(benches, count_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-aggregate/benches/sum.rs b/datafusion/functions-aggregate/benches/sum.rs new file mode 100644 index 000000000000..1e9493280ed2 --- /dev/null +++ b/datafusion/functions-aggregate/benches/sum.rs @@ -0,0 +1,104 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, BooleanArray}; +use arrow::datatypes::Int64Type; +use arrow::util::bench_util::{create_boolean_array, create_primitive_array}; +use arrow_schema::{DataType, Field, Schema}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::{function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator}; +use datafusion_functions_aggregate::sum::Sum; +use datafusion_physical_expr::expressions::col; +use datafusion_physical_expr_common::sort_expr::LexOrderingRef; +use std::sync::Arc; + +fn prepare_accumulator(data_type: &DataType) -> Box { + let schema = Arc::new(Schema::new(vec![Field::new("f", data_type.clone(), true)])); + let accumulator_args = AccumulatorArgs { + return_type: data_type, + schema: &schema, + ignore_nulls: false, + ordering_req: LexOrderingRef::default(), + is_reversed: false, + name: "SUM(f)", + is_distinct: false, + exprs: &[col("f", &schema).unwrap()], + }; + let sum_fn = Sum::new(); + + sum_fn.create_groups_accumulator(accumulator_args).unwrap() +} + +fn convert_to_state_bench( + c: &mut Criterion, + name: &str, + values: ArrayRef, + opt_filter: Option<&BooleanArray>, +) { + let accumulator = prepare_accumulator(values.data_type()); + c.bench_function(name, |b| { + b.iter(|| { + black_box( + accumulator + .convert_to_state(&[values.clone()], opt_filter) + .unwrap(), + ) + }) + }); +} + +fn count_benchmark(c: &mut Criterion) { + let values = Arc::new(create_primitive_array::(8192, 0.0)) as ArrayRef; + convert_to_state_bench(c, "sum i64 convert state no nulls, no filter", values, None); + + let values = Arc::new(create_primitive_array::(8192, 0.3)) as ArrayRef; + convert_to_state_bench( + c, + "sum i64 convert state 30% nulls, no filter", + values, + None, + ); + + let values = Arc::new(create_primitive_array::(8192, 0.3)) as ArrayRef; + convert_to_state_bench( + c, + "sum i64 convert state 70% nulls, no filter", + values, + None, + ); + + let values = Arc::new(create_primitive_array::(8192, 0.0)) as ArrayRef; + let filter = create_boolean_array(8192, 0.0, 0.5); + convert_to_state_bench( + c, + "sum i64 convert state no nulls, filter", + values, + Some(&filter), + ); + + let values = Arc::new(create_primitive_array::(8192, 0.3)) as ArrayRef; + let filter = create_boolean_array(8192, 0.0, 0.5); + convert_to_state_bench( + c, + "sum i64 convert state nulls, filter", + values, + Some(&filter), + ); +} + +criterion_group!(benches, count_benchmark); +criterion_main!(benches); diff --git a/datafusion/physical-expr/src/aggregate/approx_distinct.rs b/datafusion/functions-aggregate/src/approx_distinct.rs similarity index 71% rename from datafusion/physical-expr/src/aggregate/approx_distinct.rs rename to datafusion/functions-aggregate/src/approx_distinct.rs index c0bce3ac2774..1df106feb4d3 100644 --- a/datafusion/physical-expr/src/aggregate/approx_distinct.rs +++ b/datafusion/functions-aggregate/src/approx_distinct.rs @@ -17,134 +17,89 @@ //! Defines physical expressions that can evaluated at runtime during query execution -use super::hyperloglog::HyperLogLog; -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; +use crate::hyperloglog::HyperLogLog; +use arrow::array::BinaryArray; use arrow::array::{ - ArrayRef, BinaryArray, GenericBinaryArray, GenericStringArray, OffsetSizeTrait, - PrimitiveArray, + GenericBinaryArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray, }; use arrow::datatypes::{ - ArrowPrimitiveType, DataType, Field, Int16Type, Int32Type, Int64Type, Int8Type, - UInt16Type, UInt32Type, UInt64Type, UInt8Type, + ArrowPrimitiveType, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, + UInt32Type, UInt64Type, UInt8Type, }; +use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; +use datafusion_common::ScalarValue; use datafusion_common::{ - downcast_value, internal_err, not_impl_err, DataFusionError, Result, ScalarValue, + downcast_value, internal_err, not_impl_err, DataFusionError, Result, +}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_APPROXIMATE; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, }; -use datafusion_expr::Accumulator; use std::any::Any; +use std::fmt::{Debug, Formatter}; use std::hash::Hash; use std::marker::PhantomData; -use std::sync::Arc; +use std::sync::OnceLock; +make_udaf_expr_and_func!( + ApproxDistinct, + approx_distinct, + expression, + "approximate number of distinct input values", + approx_distinct_udaf +); -/// APPROX_DISTINCT aggregate expression -#[derive(Debug)] -pub struct ApproxDistinct { - name: String, - input_data_type: DataType, - expr: Arc, -} - -impl ApproxDistinct { - /// Create a new ApproxDistinct aggregate function. - pub fn new( - expr: Arc, - name: impl Into, - input_data_type: DataType, - ) -> Self { - Self { - name: name.into(), - input_data_type, - expr, - } +impl From<&HyperLogLog> for ScalarValue { + fn from(v: &HyperLogLog) -> ScalarValue { + let values = v.as_ref().to_vec(); + ScalarValue::Binary(Some(values)) } } -impl AggregateExpr for ApproxDistinct { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::UInt64, false)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "hll_registers"), - DataType::Binary, - false, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn create_accumulator(&self) -> Result> { - let accumulator: Box = match &self.input_data_type { - // TODO u8, i8, u16, i16 shall really be done using bitmap, not HLL - // TODO support for boolean (trivial case) - // https://github.com/apache/datafusion/issues/1109 - DataType::UInt8 => Box::new(NumericHLLAccumulator::::new()), - DataType::UInt16 => Box::new(NumericHLLAccumulator::::new()), - DataType::UInt32 => Box::new(NumericHLLAccumulator::::new()), - DataType::UInt64 => Box::new(NumericHLLAccumulator::::new()), - DataType::Int8 => Box::new(NumericHLLAccumulator::::new()), - DataType::Int16 => Box::new(NumericHLLAccumulator::::new()), - DataType::Int32 => Box::new(NumericHLLAccumulator::::new()), - DataType::Int64 => Box::new(NumericHLLAccumulator::::new()), - DataType::Utf8 => Box::new(StringHLLAccumulator::::new()), - DataType::LargeUtf8 => Box::new(StringHLLAccumulator::::new()), - DataType::Binary => Box::new(BinaryHLLAccumulator::::new()), - DataType::LargeBinary => Box::new(BinaryHLLAccumulator::::new()), - other => { - return not_impl_err!( - "Support for 'approx_distinct' for data type {other} is not implemented" +impl TryFrom<&[u8]> for HyperLogLog { + type Error = DataFusionError; + fn try_from(v: &[u8]) -> Result> { + let arr: [u8; 16384] = v.try_into().map_err(|_| { + DataFusionError::Internal( + "Impossibly got invalid binary array from states".into(), ) - } - }; - Ok(accumulator) - } - - fn name(&self) -> &str { - &self.name + })?; + Ok(HyperLogLog::::new_with_registers(arr)) } } -impl PartialEq for ApproxDistinct { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.input_data_type == x.input_data_type - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) +impl TryFrom<&ScalarValue> for HyperLogLog { + type Error = DataFusionError; + fn try_from(v: &ScalarValue) -> Result> { + if let ScalarValue::Binary(Some(slice)) = v { + slice.as_slice().try_into() + } else { + internal_err!( + "Impossibly got invalid scalar value while converting to HyperLogLog" + ) + } } } #[derive(Debug)] -struct BinaryHLLAccumulator +struct NumericHLLAccumulator where - T: OffsetSizeTrait, + T: ArrowPrimitiveType, + T::Native: Hash, { - hll: HyperLogLog>, - phantom_data: PhantomData, + hll: HyperLogLog, } -impl BinaryHLLAccumulator +impl NumericHLLAccumulator where - T: OffsetSizeTrait, + T: ArrowPrimitiveType, + T::Native: Hash, { /// new approx_distinct accumulator pub fn new() -> Self { Self { hll: HyperLogLog::new(), - phantom_data: PhantomData, } } } @@ -172,55 +127,23 @@ where } #[derive(Debug)] -struct NumericHLLAccumulator +struct BinaryHLLAccumulator where - T: ArrowPrimitiveType, - T::Native: Hash, + T: OffsetSizeTrait, { - hll: HyperLogLog, + hll: HyperLogLog>, + phantom_data: PhantomData, } -impl NumericHLLAccumulator +impl BinaryHLLAccumulator where - T: ArrowPrimitiveType, - T::Native: Hash, + T: OffsetSizeTrait, { /// new approx_distinct accumulator pub fn new() -> Self { Self { hll: HyperLogLog::new(), - } - } -} - -impl From<&HyperLogLog> for ScalarValue { - fn from(v: &HyperLogLog) -> ScalarValue { - let values = v.as_ref().to_vec(); - ScalarValue::Binary(Some(values)) - } -} - -impl TryFrom<&[u8]> for HyperLogLog { - type Error = DataFusionError; - fn try_from(v: &[u8]) -> Result> { - let arr: [u8; 16384] = v.try_into().map_err(|_| { - DataFusionError::Internal( - "Impossibly got invalid binary array from states".into(), - ) - })?; - Ok(HyperLogLog::::new_with_registers(arr)) - } -} - -impl TryFrom<&ScalarValue> for HyperLogLog { - type Error = DataFusionError; - fn try_from(v: &ScalarValue) -> Result> { - if let ScalarValue::Binary(Some(slice)) = v { - slice.as_slice().try_into() - } else { - internal_err!( - "Impossibly got invalid scalar value while converting to HyperLogLog" - ) + phantom_data: PhantomData, } } } @@ -292,7 +215,7 @@ where impl Accumulator for NumericHLLAccumulator where - T: ArrowPrimitiveType + std::fmt::Debug, + T: ArrowPrimitiveType + Debug, T::Native: Hash, { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { @@ -304,3 +227,113 @@ where default_accumulator_impl!(); } + +impl Debug for ApproxDistinct { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ApproxDistinct") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for ApproxDistinct { + fn default() -> Self { + Self::new() + } +} + +pub struct ApproxDistinct { + signature: Signature, +} + +impl ApproxDistinct { + pub fn new() -> Self { + Self { + signature: Signature::any(1, Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for ApproxDistinct { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "approx_distinct" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _: &[DataType]) -> Result { + Ok(DataType::UInt64) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + Ok(vec![Field::new( + format_state_name(args.name, "hll_registers"), + DataType::Binary, + false, + )]) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let data_type = acc_args.exprs[0].data_type(acc_args.schema)?; + + let accumulator: Box = match data_type { + // TODO u8, i8, u16, i16 shall really be done using bitmap, not HLL + // TODO support for boolean (trivial case) + // https://github.com/apache/datafusion/issues/1109 + DataType::UInt8 => Box::new(NumericHLLAccumulator::::new()), + DataType::UInt16 => Box::new(NumericHLLAccumulator::::new()), + DataType::UInt32 => Box::new(NumericHLLAccumulator::::new()), + DataType::UInt64 => Box::new(NumericHLLAccumulator::::new()), + DataType::Int8 => Box::new(NumericHLLAccumulator::::new()), + DataType::Int16 => Box::new(NumericHLLAccumulator::::new()), + DataType::Int32 => Box::new(NumericHLLAccumulator::::new()), + DataType::Int64 => Box::new(NumericHLLAccumulator::::new()), + DataType::Utf8 => Box::new(StringHLLAccumulator::::new()), + DataType::LargeUtf8 => Box::new(StringHLLAccumulator::::new()), + DataType::Binary => Box::new(BinaryHLLAccumulator::::new()), + DataType::LargeBinary => Box::new(BinaryHLLAccumulator::::new()), + other => { + return not_impl_err!( + "Support for 'approx_distinct' for data type {other} is not implemented" + ) + } + }; + Ok(accumulator) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_approx_distinct_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_approx_distinct_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_APPROXIMATE) + .with_description( + "Returns the approximate number of distinct input values calculated using the HyperLogLog algorithm.", + ) + .with_syntax_example("approx_distinct(expression)") + .with_sql_example(r#"```sql +> SELECT approx_distinct(column_name) FROM table_name; ++-----------------------------------+ +| approx_distinct(column_name) | ++-----------------------------------+ +| 42 | ++-----------------------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) +} diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs new file mode 100644 index 000000000000..96609622a51e --- /dev/null +++ b/datafusion/functions-aggregate/src/approx_median.rs @@ -0,0 +1,152 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines physical expressions for APPROX_MEDIAN that can be evaluated MEDIAN at runtime during query execution + +use std::any::Any; +use std::fmt::Debug; +use std::sync::OnceLock; + +use arrow::{datatypes::DataType, datatypes::Field}; +use arrow_schema::DataType::{Float64, UInt64}; + +use datafusion_common::{not_impl_err, plan_err, Result}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_APPROXIMATE; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::type_coercion::aggregates::NUMERICS; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, +}; + +use crate::approx_percentile_cont::ApproxPercentileAccumulator; + +make_udaf_expr_and_func!( + ApproxMedian, + approx_median, + expression, + "Computes the approximate median of a set of numbers", + approx_median_udaf +); + +/// APPROX_MEDIAN aggregate expression +pub struct ApproxMedian { + signature: Signature, +} + +impl Debug for ApproxMedian { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("ApproxMedian") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for ApproxMedian { + fn default() -> Self { + Self::new() + } +} + +impl ApproxMedian { + /// Create a new APPROX_MEDIAN aggregate function + pub fn new() -> Self { + Self { + signature: Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for ApproxMedian { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + Ok(vec![ + Field::new(format_state_name(args.name, "max_size"), UInt64, false), + Field::new(format_state_name(args.name, "sum"), Float64, false), + Field::new(format_state_name(args.name, "count"), UInt64, false), + Field::new(format_state_name(args.name, "max"), Float64, false), + Field::new(format_state_name(args.name, "min"), Float64, false), + Field::new_list( + format_state_name(args.name, "centroids"), + Field::new("item", Float64, true), + false, + ), + ]) + } + + fn name(&self) -> &str { + "approx_median" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if !arg_types[0].is_numeric() { + return plan_err!("ApproxMedian requires numeric input types"); + } + Ok(arg_types[0].clone()) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + if acc_args.is_distinct { + return not_impl_err!( + "APPROX_MEDIAN(DISTINCT) aggregations are not available" + ); + } + + Ok(Box::new(ApproxPercentileAccumulator::new( + 0.5_f64, + acc_args.exprs[0].data_type(acc_args.schema)?, + ))) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_approx_median_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_approx_median_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_APPROXIMATE) + .with_description( + "Returns the approximate median (50th percentile) of input values. It is an alias of `approx_percentile_cont(x, 0.5)`.", + ) + .with_syntax_example("approx_median(expression)") + .with_sql_example(r#"```sql +> SELECT approx_median(column_name) FROM table_name; ++-----------------------------------+ +| approx_median(column_name) | ++-----------------------------------+ +| 23.5 | ++-----------------------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) +} diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs similarity index 54% rename from datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs rename to datafusion/functions-aggregate/src/approx_percentile_cont.rs index 63a4c85f9e80..53fcfd641ddf 100644 --- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -15,81 +15,108 @@ // specific language governing permissions and limitations // under the License. -use crate::aggregate::tdigest::TryIntoF64; -use crate::aggregate::tdigest::{TDigest, DEFAULT_MAX_SIZE}; -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; +use std::any::Any; +use std::fmt::{Debug, Formatter}; +use std::mem::size_of_val; +use std::sync::{Arc, OnceLock}; + +use arrow::array::{Array, RecordBatch}; +use arrow::compute::{filter, is_not_null}; use arrow::{ array::{ ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }, - datatypes::{DataType, Field}, + datatypes::DataType, }; -use arrow_array::RecordBatch; -use arrow_schema::Schema; +use arrow_schema::{Field, Schema}; + use datafusion_common::{ - downcast_value, internal_err, not_impl_err, plan_err, DataFusionError, Result, - ScalarValue, + downcast_value, internal_err, not_impl_datafusion_err, not_impl_err, plan_err, + DataFusionError, Result, ScalarValue, +}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_APPROXIMATE; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::type_coercion::aggregates::{INTEGERS, NUMERICS}; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, ColumnarValue, Documentation, Expr, Signature, + TypeSignature, Volatility, +}; +use datafusion_functions_aggregate_common::tdigest::{ + TDigest, TryIntoF64, DEFAULT_MAX_SIZE, }; -use datafusion_expr::{Accumulator, ColumnarValue}; -use std::{any::Any, sync::Arc}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + +create_func!(ApproxPercentileCont, approx_percentile_cont_udaf); + +/// Computes the approximate percentile continuous of a set of numbers +pub fn approx_percentile_cont( + expression: Expr, + percentile: Expr, + centroids: Option, +) -> Expr { + let args = if let Some(centroids) = centroids { + vec![expression, percentile, centroids] + } else { + vec![expression, percentile] + }; + approx_percentile_cont_udaf().call(args) +} -/// APPROX_PERCENTILE_CONT aggregate expression -#[derive(Debug)] pub struct ApproxPercentileCont { - name: String, - input_data_type: DataType, - expr: Vec>, - percentile: f64, - tdigest_max_size: Option, + signature: Signature, } -impl ApproxPercentileCont { - /// Create a new [`ApproxPercentileCont`] aggregate function. - pub fn new( - expr: Vec>, - name: impl Into, - input_data_type: DataType, - ) -> Result { - // Arguments should be [ColumnExpr, DesiredPercentileLiteral] - debug_assert_eq!(expr.len(), 2); - - let percentile = validate_input_percentile_expr(&expr[1])?; - - Ok(Self { - name: name.into(), - input_data_type, - // The physical expr to evaluate during accumulation - expr, - percentile, - tdigest_max_size: None, - }) +impl Debug for ApproxPercentileCont { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + f.debug_struct("ApproxPercentileCont") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for ApproxPercentileCont { + fn default() -> Self { + Self::new() } +} +impl ApproxPercentileCont { /// Create a new [`ApproxPercentileCont`] aggregate function. - pub fn new_with_max_size( - expr: Vec>, - name: impl Into, - input_data_type: DataType, - ) -> Result { - // Arguments should be [ColumnExpr, DesiredPercentileLiteral, TDigestMaxSize] - debug_assert_eq!(expr.len(), 3); - let percentile = validate_input_percentile_expr(&expr[1])?; - let max_size = validate_input_max_size_expr(&expr[2])?; - Ok(Self { - name: name.into(), - input_data_type, - // The physical expr to evaluate during accumulation - expr, - percentile, - tdigest_max_size: Some(max_size), - }) + pub fn new() -> Self { + let mut variants = Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1)); + // Accept any numeric value paired with a float64 percentile + for num in NUMERICS { + variants.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64])); + // Additionally accept an integer number of centroids for T-Digest + for int in INTEGERS { + variants.push(TypeSignature::Exact(vec![ + num.clone(), + DataType::Float64, + int.clone(), + ])) + } + } + Self { + signature: Signature::one_of(variants, Volatility::Immutable), + } } - pub(crate) fn create_plain_accumulator(&self) -> Result { - let accumulator: ApproxPercentileAccumulator = match &self.input_data_type { + pub(crate) fn create_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result { + let percentile = validate_input_percentile_expr(&args.exprs[1])?; + let tdigest_max_size = if args.exprs.len() == 3 { + Some(validate_input_max_size_expr(&args.exprs[2])?) + } else { + None + }; + + let data_type = args.exprs[0].data_type(args.schema)?; + let accumulator: ApproxPercentileAccumulator = match data_type { t @ (DataType::UInt8 | DataType::UInt16 | DataType::UInt32 @@ -100,11 +127,10 @@ impl ApproxPercentileCont { | DataType::Int64 | DataType::Float32 | DataType::Float64) => { - if let Some(max_size) = self.tdigest_max_size { - ApproxPercentileAccumulator::new_with_max_size(self.percentile, t.clone(), max_size) - + if let Some(max_size) = tdigest_max_size { + ApproxPercentileAccumulator::new_with_max_size(percentile, t, max_size) }else{ - ApproxPercentileAccumulator::new(self.percentile, t.clone()) + ApproxPercentileAccumulator::new(percentile, t) } } @@ -114,47 +140,36 @@ impl ApproxPercentileCont { ) } }; - Ok(accumulator) - } -} -impl PartialEq for ApproxPercentileCont { - fn eq(&self, other: &ApproxPercentileCont) -> bool { - self.name == other.name - && self.input_data_type == other.input_data_type - && self.percentile == other.percentile - && self.tdigest_max_size == other.tdigest_max_size - && self.expr.len() == other.expr.len() - && self - .expr - .iter() - .zip(other.expr.iter()) - .all(|(this, other)| this.eq(other)) + Ok(accumulator) } } -fn get_lit_value(expr: &Arc) -> Result { - let empty_schema = Schema::empty(); - let empty_batch = RecordBatch::new_empty(Arc::new(empty_schema)); - let result = expr.evaluate(&empty_batch)?; - match result { - ColumnarValue::Array(_) => Err(DataFusionError::Internal(format!( - "The expr {:?} can't be evaluated to scalar value", - expr - ))), - ColumnarValue::Scalar(scalar_value) => Ok(scalar_value), +fn get_scalar_value(expr: &Arc) -> Result { + let empty_schema = Arc::new(Schema::empty()); + let batch = RecordBatch::new_empty(Arc::clone(&empty_schema)); + if let ColumnarValue::Scalar(s) = expr.evaluate(&batch)? { + Ok(s) + } else { + internal_err!("Didn't expect ColumnarValue::Array") } } fn validate_input_percentile_expr(expr: &Arc) -> Result { - let lit = get_lit_value(expr)?; - let percentile = match &lit { - ScalarValue::Float32(Some(q)) => *q as f64, - ScalarValue::Float64(Some(q)) => *q, - got => return not_impl_err!( - "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})", - got.data_type() - ) + let percentile = match get_scalar_value(expr) + .map_err(|_| not_impl_datafusion_err!("Percentile value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"))? { + ScalarValue::Float32(Some(value)) => { + value as f64 + } + ScalarValue::Float64(Some(value)) => { + value + } + sv => { + return not_impl_err!( + "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})", + sv.data_type() + ) + } }; // Ensure the percentile is between 0 and 1. @@ -167,94 +182,126 @@ fn validate_input_percentile_expr(expr: &Arc) -> Result { } fn validate_input_max_size_expr(expr: &Arc) -> Result { - let lit = get_lit_value(expr)?; - let max_size = match &lit { - ScalarValue::UInt8(Some(q)) => *q as usize, - ScalarValue::UInt16(Some(q)) => *q as usize, - ScalarValue::UInt32(Some(q)) => *q as usize, - ScalarValue::UInt64(Some(q)) => *q as usize, - ScalarValue::Int32(Some(q)) if *q > 0 => *q as usize, - ScalarValue::Int64(Some(q)) if *q > 0 => *q as usize, - ScalarValue::Int16(Some(q)) if *q > 0 => *q as usize, - ScalarValue::Int8(Some(q)) if *q > 0 => *q as usize, - got => return not_impl_err!( - "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).", - got.data_type() - ) + let max_size = match get_scalar_value(expr) + .map_err(|_| not_impl_datafusion_err!("Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be a literal, got: {expr}"))? { + ScalarValue::UInt8(Some(q)) => q as usize, + ScalarValue::UInt16(Some(q)) => q as usize, + ScalarValue::UInt32(Some(q)) => q as usize, + ScalarValue::UInt64(Some(q)) => q as usize, + ScalarValue::Int32(Some(q)) if q > 0 => q as usize, + ScalarValue::Int64(Some(q)) if q > 0 => q as usize, + ScalarValue::Int16(Some(q)) if q > 0 => q as usize, + ScalarValue::Int8(Some(q)) if q > 0 => q as usize, + sv => { + return not_impl_err!( + "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).", + sv.data_type() + ) + }, }; + Ok(max_size) } -impl AggregateExpr for ApproxPercentileCont { +impl AggregateUDFImpl for ApproxPercentileCont { fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.input_data_type.clone(), false)) - } - #[allow(rustdoc::private_intra_doc_links)] /// See [`TDigest::to_scalar_state()`] for a description of the serialised /// state. - fn state_fields(&self) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( - format_state_name(&self.name, "max_size"), + format_state_name(args.name, "max_size"), DataType::UInt64, false, ), Field::new( - format_state_name(&self.name, "sum"), + format_state_name(args.name, "sum"), DataType::Float64, false, ), Field::new( - format_state_name(&self.name, "count"), - DataType::Float64, + format_state_name(args.name, "count"), + DataType::UInt64, false, ), Field::new( - format_state_name(&self.name, "max"), + format_state_name(args.name, "max"), DataType::Float64, false, ), Field::new( - format_state_name(&self.name, "min"), + format_state_name(args.name, "min"), DataType::Float64, false, ), Field::new_list( - format_state_name(&self.name, "centroids"), + format_state_name(args.name, "centroids"), Field::new("item", DataType::Float64, true), false, ), ]) } - fn expressions(&self) -> Vec> { - self.expr.clone() + fn name(&self) -> &str { + "approx_percentile_cont" } - fn create_accumulator(&self) -> Result> { - let accumulator = self.create_plain_accumulator()?; - Ok(Box::new(accumulator)) + fn signature(&self) -> &Signature { + &self.signature } - fn name(&self) -> &str { - &self.name + #[inline] + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(self.create_accumulator(acc_args)?)) + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if !arg_types[0].is_numeric() { + return plan_err!("approx_percentile_cont requires numeric input types"); + } + if arg_types.len() == 3 && !arg_types[2].is_integer() { + return plan_err!( + "approx_percentile_cont requires integer max_size input types" + ); + } + Ok(arg_types[0].clone()) } -} -impl PartialEq for ApproxPercentileCont { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| self.eq(x)) - .unwrap_or(false) + fn documentation(&self) -> Option<&Documentation> { + Some(get_approx_percentile_cont_doc()) } } +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_approx_percentile_cont_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_APPROXIMATE) + .with_description( + "Returns the approximate percentile of input values using the t-digest algorithm.", + ) + .with_syntax_example("approx_percentile_cont(expression, percentile, centroids)") + .with_sql_example(r#"```sql +> SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name; ++-------------------------------------------------+ +| approx_percentile_cont(column_name, 0.75, 100) | ++-------------------------------------------------+ +| 65.0 | ++-------------------------------------------------+ +```"#) + .with_standard_argument("expression", None) + .with_argument("percentile", "Percentile to compute. Must be a float value between 0 and 1 (inclusive).") + .with_argument("centroids", "Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory.") + .build() + .unwrap() + }) +} + #[derive(Debug)] pub struct ApproxPercentileAccumulator { digest: TDigest, @@ -283,12 +330,14 @@ impl ApproxPercentileAccumulator { } } - pub(crate) fn merge_digests(&mut self, digests: &[TDigest]) { + // public for approx_percentile_cont_with_weight + pub fn merge_digests(&mut self, digests: &[TDigest]) { let digests = digests.iter().chain(std::iter::once(&self.digest)); self.digest = TDigest::merge_digests(digests) } - pub(crate) fn convert_to_float(values: &ArrayRef) -> Result> { + // public for approx_percentile_cont_with_weight + pub fn convert_to_float(values: &ArrayRef) -> Result> { match values.data_type() { DataType::Float64 => { let array = downcast_value!(values, Float64Array); @@ -383,15 +432,19 @@ impl Accumulator for ApproxPercentileAccumulator { } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; - let sorted_values = &arrow::compute::sort(values, None)?; + // Remove any nulls before computing the percentile + let mut values = Arc::clone(&values[0]); + if values.nulls().is_some() { + values = filter(&values, &is_not_null(&values)?)?; + } + let sorted_values = &arrow::compute::sort(&values, None)?; let sorted_values = ApproxPercentileAccumulator::convert_to_float(sorted_values)?; self.digest = self.digest.merge_sorted_f64(&sorted_values); Ok(()) } fn evaluate(&mut self) -> Result { - if self.digest.count() == 0.0 { + if self.digest.count() == 0 { return ScalarValue::try_from(self.return_type.clone()); } let q = self.digest.estimate_quantile(self.percentile); @@ -434,19 +487,20 @@ impl Accumulator for ApproxPercentileAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + self.digest.size() - - std::mem::size_of_val(&self.digest) + size_of_val(self) + self.digest.size() - size_of_val(&self.digest) + self.return_type.size() - - std::mem::size_of_val(&self.return_type) + - size_of_val(&self.return_type) } } #[cfg(test)] mod tests { - use crate::aggregate::approx_percentile_cont::ApproxPercentileAccumulator; - use crate::aggregate::tdigest::TDigest; use arrow_schema::DataType; + use datafusion_functions_aggregate_common::tdigest::TDigest; + + use crate::approx_percentile_cont::ApproxPercentileAccumulator; + #[test] fn test_combine_approx_percentile_accumulator() { let mut digests: Vec = Vec::new(); @@ -466,8 +520,8 @@ mod tests { ApproxPercentileAccumulator::new_with_max_size(0.5, DataType::Float64, 100); accumulator.merge_digests(&[t1]); - assert_eq!(accumulator.digest.count(), 50_000.0); + assert_eq!(accumulator.digest.count(), 50_000); accumulator.merge_digests(&[t2]); - assert_eq!(accumulator.digest.count(), 100_000.0); + assert_eq!(accumulator.digest.count(), 100_000); } } diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs new file mode 100644 index 000000000000..5458d0f792b9 --- /dev/null +++ b/datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs @@ -0,0 +1,246 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::fmt::{Debug, Formatter}; +use std::mem::size_of_val; +use std::sync::{Arc, OnceLock}; + +use arrow::{ + array::ArrayRef, + datatypes::{DataType, Field}, +}; + +use datafusion_common::ScalarValue; +use datafusion_common::{not_impl_err, plan_err, Result}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_APPROXIMATE; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::type_coercion::aggregates::NUMERICS; +use datafusion_expr::Volatility::Immutable; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, +}; +use datafusion_functions_aggregate_common::tdigest::{ + Centroid, TDigest, DEFAULT_MAX_SIZE, +}; + +use crate::approx_percentile_cont::{ApproxPercentileAccumulator, ApproxPercentileCont}; + +make_udaf_expr_and_func!( + ApproxPercentileContWithWeight, + approx_percentile_cont_with_weight, + expression weight percentile, + "Computes the approximate percentile continuous with weight of a set of numbers", + approx_percentile_cont_with_weight_udaf +); + +/// APPROX_PERCENTILE_CONT_WITH_WEIGHT aggregate expression +pub struct ApproxPercentileContWithWeight { + signature: Signature, + approx_percentile_cont: ApproxPercentileCont, +} + +impl Debug for ApproxPercentileContWithWeight { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ApproxPercentileContWithWeight") + .field("signature", &self.signature) + .finish() + } +} + +impl Default for ApproxPercentileContWithWeight { + fn default() -> Self { + Self::new() + } +} + +impl ApproxPercentileContWithWeight { + /// Create a new [`ApproxPercentileContWithWeight`] aggregate function. + pub fn new() -> Self { + Self { + signature: Signature::one_of( + // Accept any numeric value paired with a float64 percentile + NUMERICS + .iter() + .map(|t| { + TypeSignature::Exact(vec![ + t.clone(), + t.clone(), + DataType::Float64, + ]) + }) + .collect(), + Immutable, + ), + approx_percentile_cont: ApproxPercentileCont::new(), + } + } +} + +impl AggregateUDFImpl for ApproxPercentileContWithWeight { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "approx_percentile_cont_with_weight" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if !arg_types[0].is_numeric() { + return plan_err!( + "approx_percentile_cont_with_weight requires numeric input types" + ); + } + if !arg_types[1].is_numeric() { + return plan_err!( + "approx_percentile_cont_with_weight requires numeric weight input types" + ); + } + if arg_types[2] != DataType::Float64 { + return plan_err!("approx_percentile_cont_with_weight requires float64 percentile input types"); + } + Ok(arg_types[0].clone()) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + if acc_args.is_distinct { + return not_impl_err!( + "approx_percentile_cont_with_weight(DISTINCT) aggregations are not available" + ); + } + + if acc_args.exprs.len() != 3 { + return plan_err!( + "approx_percentile_cont_with_weight requires three arguments: value, weight, percentile" + ); + } + + let sub_args = AccumulatorArgs { + exprs: &[ + Arc::clone(&acc_args.exprs[0]), + Arc::clone(&acc_args.exprs[2]), + ], + ..acc_args + }; + let approx_percentile_cont_accumulator = + self.approx_percentile_cont.create_accumulator(sub_args)?; + let accumulator = ApproxPercentileWithWeightAccumulator::new( + approx_percentile_cont_accumulator, + ); + Ok(Box::new(accumulator)) + } + + #[allow(rustdoc::private_intra_doc_links)] + /// See [`TDigest::to_scalar_state()`] for a description of the serialised + /// state. + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + self.approx_percentile_cont.state_fields(args) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_approx_percentile_cont_with_weight_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_approx_percentile_cont_with_weight_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_APPROXIMATE) + .with_description( + "Returns the weighted approximate percentile of input values using the t-digest algorithm.", + ) + .with_syntax_example("approx_percentile_cont_with_weight(expression, weight, percentile)") + .with_sql_example(r#"```sql +> SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name; ++----------------------------------------------------------------------+ +| approx_percentile_cont_with_weight(column_name, weight_column, 0.90) | ++----------------------------------------------------------------------+ +| 78.5 | ++----------------------------------------------------------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .with_argument("weight", "Expression to use as weight. Can be a constant, column, or function, and any combination of arithmetic operators.") + .with_argument("percentile", "Percentile to compute. Must be a float value between 0 and 1 (inclusive).") + .build() + .unwrap() + }) +} + +#[derive(Debug)] +pub struct ApproxPercentileWithWeightAccumulator { + approx_percentile_cont_accumulator: ApproxPercentileAccumulator, +} + +impl ApproxPercentileWithWeightAccumulator { + pub fn new(approx_percentile_cont_accumulator: ApproxPercentileAccumulator) -> Self { + Self { + approx_percentile_cont_accumulator, + } + } +} + +impl Accumulator for ApproxPercentileWithWeightAccumulator { + fn state(&mut self) -> Result> { + self.approx_percentile_cont_accumulator.state() + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let means = &values[0]; + let weights = &values[1]; + debug_assert_eq!( + means.len(), + weights.len(), + "invalid number of values in means and weights" + ); + let means_f64 = ApproxPercentileAccumulator::convert_to_float(means)?; + let weights_f64 = ApproxPercentileAccumulator::convert_to_float(weights)?; + let mut digests: Vec = vec![]; + for (mean, weight) in means_f64.iter().zip(weights_f64.iter()) { + digests.push(TDigest::new_with_centroid( + DEFAULT_MAX_SIZE, + Centroid::new(*mean, *weight), + )) + } + self.approx_percentile_cont_accumulator + .merge_digests(&digests); + Ok(()) + } + + fn evaluate(&mut self) -> Result { + self.approx_percentile_cont_accumulator.evaluate() + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.approx_percentile_cont_accumulator + .merge_batch(states)?; + + Ok(()) + } + + fn size(&self) -> usize { + size_of_val(self) - size_of_val(&self.approx_percentile_cont_accumulator) + + self.approx_percentile_cont_accumulator.size() + } +} diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/functions-aggregate/src/array_agg.rs similarity index 53% rename from datafusion/physical-expr/src/aggregate/array_agg_ordered.rs rename to datafusion/functions-aggregate/src/array_agg.rs index 7e2c7bb27144..7c22c21e38c9 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -15,154 +15,320 @@ // specific language governing permissions and limitations // under the License. -//! Defines physical expressions which specify ordering requirement -//! that can evaluated at runtime during query execution - -use std::any::Any; -use std::cmp::Ordering; -use std::collections::{BinaryHeap, VecDeque}; -use std::fmt::Debug; -use std::sync::Arc; - -use crate::aggregate::utils::{down_cast_any_ref, ordering_fields}; -use crate::expressions::format_state_name; -use crate::{ - reverse_order_bys, AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr, -}; - -use arrow::array::{Array, ArrayRef}; -use arrow::datatypes::{DataType, Field}; -use arrow_array::cast::AsArray; -use arrow_array::{new_empty_array, StructArray}; -use arrow_schema::{Fields, SortOptions}; - -use datafusion_common::utils::array_into_list_array; -use datafusion_common::utils::{compare_rows, get_row_at_idx}; -use datafusion_common::{exec_err, Result, ScalarValue}; -use datafusion_expr::Accumulator; - -/// Expression for a `ARRAY_AGG(... ORDER BY ..., ...)` aggregation. In a multi -/// partition setting, partial aggregations are computed for every partition, -/// and then their results are merged. +//! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`] + +use arrow::array::{new_empty_array, Array, ArrayRef, AsArray, StructArray}; +use arrow::datatypes::DataType; + +use arrow_schema::{Field, Fields}; +use datafusion_common::cast::as_list_array; +use datafusion_common::utils::{array_into_list_array_nullable, get_row_at_idx}; +use datafusion_common::{exec_err, ScalarValue}; +use datafusion_common::{internal_err, Result}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{Accumulator, Signature, Volatility}; +use datafusion_expr::{AggregateUDFImpl, Documentation}; +use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays; +use datafusion_functions_aggregate_common::utils::ordering_fields; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use std::collections::{HashSet, VecDeque}; +use std::mem::{size_of, size_of_val}; +use std::sync::{Arc, OnceLock}; + +make_udaf_expr_and_func!( + ArrayAgg, + array_agg, + expression, + "input values, including nulls, concatenated into an array", + array_agg_udaf +); + #[derive(Debug)] -pub struct OrderSensitiveArrayAgg { - /// Column name - name: String, - /// The `DataType` for the input expression - input_data_type: DataType, - /// The input expression - expr: Arc, - /// If the input expression can have `NULL`s - nullable: bool, - /// Ordering data types - order_by_data_types: Vec, - /// Ordering requirement - ordering_req: LexOrdering, - /// Whether the aggregation is running in reverse - reverse: bool, +/// ARRAY_AGG aggregate expression +pub struct ArrayAgg { + signature: Signature, } -impl OrderSensitiveArrayAgg { - /// Create a new `OrderSensitiveArrayAgg` aggregate function - pub fn new( - expr: Arc, - name: impl Into, - input_data_type: DataType, - nullable: bool, - order_by_data_types: Vec, - ordering_req: LexOrdering, - ) -> Self { +impl Default for ArrayAgg { + fn default() -> Self { Self { - name: name.into(), - input_data_type, - expr, - nullable, - order_by_data_types, - ordering_req, - reverse: false, + signature: Signature::any(1, Volatility::Immutable), } } } -impl AggregateExpr for OrderSensitiveArrayAgg { - fn as_any(&self) -> &dyn Any { +impl AggregateUDFImpl for ArrayAgg { + fn as_any(&self) -> &dyn std::any::Any { self } - fn field(&self) -> Result { - Ok(Field::new_list( - &self.name, - // This should be the same as return type of AggregateFunction::ArrayAgg - Field::new("item", self.input_data_type.clone(), true), - self.nullable, - )) + fn name(&self) -> &str { + "array_agg" } - fn create_accumulator(&self) -> Result> { - OrderSensitiveArrayAggAccumulator::try_new( - &self.input_data_type, - &self.order_by_data_types, - self.ordering_req.clone(), - self.reverse, - ) - .map(|acc| Box::new(acc) as _) + fn aliases(&self) -> &[String] { + &[] + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(DataType::List(Arc::new(Field::new( + "item", + arg_types[0].clone(), + true, + )))) } - fn state_fields(&self) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + if args.is_distinct { + return Ok(vec![Field::new_list( + format_state_name(args.name, "distinct_array_agg"), + // See COMMENTS.md to understand why nullable is set to true + Field::new("item", args.input_types[0].clone(), true), + true, + )]); + } + let mut fields = vec![Field::new_list( - format_state_name(&self.name, "array_agg"), - Field::new("item", self.input_data_type.clone(), true), - self.nullable, // This should be the same as field() + format_state_name(args.name, "array_agg"), + // See COMMENTS.md to understand why nullable is set to true + Field::new("item", args.input_types[0].clone(), true), + true, )]; - let orderings = ordering_fields(&self.ordering_req, &self.order_by_data_types); + + if args.ordering_fields.is_empty() { + return Ok(fields); + } + + let orderings = args.ordering_fields.to_vec(); fields.push(Field::new_list( - format_state_name(&self.name, "array_agg_orderings"), + format_state_name(args.name, "array_agg_orderings"), Field::new("item", DataType::Struct(Fields::from(orderings)), true), - self.nullable, + false, )); + Ok(fields) } - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let data_type = acc_args.exprs[0].data_type(acc_args.schema)?; + + if acc_args.is_distinct { + return Ok(Box::new(DistinctArrayAggAccumulator::try_new(&data_type)?)); + } + + if acc_args.ordering_req.is_empty() { + return Ok(Box::new(ArrayAggAccumulator::try_new(&data_type)?)); + } + + let ordering_dtypes = acc_args + .ordering_req + .iter() + .map(|e| e.expr.data_type(acc_args.schema)) + .collect::>>()?; + + OrderSensitiveArrayAggAccumulator::try_new( + &data_type, + &ordering_dtypes, + LexOrdering::from_ref(acc_args.ordering_req), + acc_args.is_reversed, + ) + .map(|acc| Box::new(acc) as _) } - fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - (!self.ordering_req.is_empty()).then_some(&self.ordering_req) + fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { + datafusion_expr::ReversedUDAF::Reversed(array_agg_udaf()) } - fn name(&self) -> &str { - &self.name + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_agg_doc()) } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_array_agg_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description( + "Returns an array created from the expression elements. If ordering is required, elements are inserted in the specified order.", + ) + .with_syntax_example("array_agg(expression [ORDER BY expression])") + .with_sql_example(r#"```sql +> SELECT array_agg(column_name ORDER BY other_column) FROM table_name; ++-----------------------------------------------+ +| array_agg(column_name ORDER BY other_column) | ++-----------------------------------------------+ +| [element1, element2, element3] | ++-----------------------------------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) +} - fn reverse_expr(&self) -> Option> { - Some(Arc::new(Self { - name: self.name.to_string(), - input_data_type: self.input_data_type.clone(), - expr: self.expr.clone(), - nullable: self.nullable, - order_by_data_types: self.order_by_data_types.clone(), - // Reverse requirement: - ordering_req: reverse_order_bys(&self.ordering_req), - reverse: !self.reverse, - })) +#[derive(Debug)] +pub struct ArrayAggAccumulator { + values: Vec, + datatype: DataType, +} + +impl ArrayAggAccumulator { + /// new array_agg accumulator based on given item data type + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + values: vec![], + datatype: datatype.clone(), + }) } } -impl PartialEq for OrderSensitiveArrayAgg { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.input_data_type == x.input_data_type - && self.order_by_data_types == x.order_by_data_types - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) +impl Accumulator for ArrayAggAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + // Append value like Int64Array(1,2,3) + if values.is_empty() { + return Ok(()); + } + + if values.len() != 1 { + return internal_err!("expects single batch"); + } + + let val = Arc::clone(&values[0]); + if val.len() > 0 { + self.values.push(val); + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + // Append value like ListArray(Int64Array(1,2,3), Int64Array(4,5,6)) + if states.is_empty() { + return Ok(()); + } + + if states.len() != 1 { + return internal_err!("expects single state"); + } + + let list_arr = as_list_array(&states[0])?; + for arr in list_arr.iter().flatten() { + self.values.push(arr); + } + Ok(()) + } + + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn evaluate(&mut self) -> Result { + // Transform Vec to ListArr + let element_arrays: Vec<&dyn Array> = + self.values.iter().map(|a| a.as_ref()).collect(); + + if element_arrays.is_empty() { + return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1)); + } + + let concated_array = arrow::compute::concat(&element_arrays)?; + let list_array = array_into_list_array_nullable(concated_array); + + Ok(ScalarValue::List(Arc::new(list_array))) + } + + fn size(&self) -> usize { + size_of_val(self) + + (size_of::() * self.values.capacity()) + + self + .values + .iter() + .map(|arr| arr.get_array_memory_size()) + .sum::() + + self.datatype.size() + - size_of_val(&self.datatype) + } +} + +#[derive(Debug)] +struct DistinctArrayAggAccumulator { + values: HashSet, + datatype: DataType, +} + +impl DistinctArrayAggAccumulator { + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + values: HashSet::new(), + datatype: datatype.clone(), + }) + } +} + +impl Accumulator for DistinctArrayAggAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.len() != 1 { + return internal_err!("expects single batch"); + } + + let array = &values[0]; + + for i in 0..array.len() { + let scalar = ScalarValue::try_from_array(&array, i)?; + self.values.insert(scalar); + } + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + + if states.len() != 1 { + return internal_err!("expects single state"); + } + + states[0] + .as_list::() + .iter() + .flatten() + .try_for_each(|val| self.update_batch(&[val])) + } + + fn evaluate(&mut self) -> Result { + let values: Vec = self.values.iter().cloned().collect(); + if values.is_empty() { + return Ok(ScalarValue::new_null_list(self.datatype.clone(), true, 1)); + } + let arr = ScalarValue::new_list(&values, &self.datatype, true); + Ok(ScalarValue::List(arr)) + } + + fn size(&self) -> usize { + size_of_val(self) + ScalarValue::size_of_hashset(&self.values) + - size_of_val(&self.values) + + self.datatype.size() + - size_of_val(&self.datatype) } } +/// Accumulator for a `ARRAY_AGG(... ORDER BY ..., ...)` aggregation. In a multi +/// partition setting, partial aggregations are computed for every partition, +/// and then their results are merged. #[derive(Debug)] pub(crate) struct OrderSensitiveArrayAggAccumulator { /// Stores entries in the `ARRAY_AGG` result. @@ -294,39 +460,50 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { fn state(&mut self) -> Result> { let mut result = vec![self.evaluate()?]; result.push(self.evaluate_orderings()?); + Ok(result) } fn evaluate(&mut self) -> Result { + if self.values.is_empty() { + return Ok(ScalarValue::new_null_list( + self.datatypes[0].clone(), + true, + 1, + )); + } + let values = self.values.clone(); let array = if self.reverse { - ScalarValue::new_list_from_iter(values.into_iter().rev(), &self.datatypes[0]) + ScalarValue::new_list_from_iter( + values.into_iter().rev(), + &self.datatypes[0], + true, + ) } else { - ScalarValue::new_list_from_iter(values.into_iter(), &self.datatypes[0]) + ScalarValue::new_list_from_iter(values.into_iter(), &self.datatypes[0], true) }; Ok(ScalarValue::List(array)) } fn size(&self) -> usize { - let mut total = std::mem::size_of_val(self) - + ScalarValue::size_of_vec(&self.values) - - std::mem::size_of_val(&self.values); + let mut total = size_of_val(self) + ScalarValue::size_of_vec(&self.values) + - size_of_val(&self.values); // Add size of the `self.ordering_values` - total += - std::mem::size_of::>() * self.ordering_values.capacity(); + total += size_of::>() * self.ordering_values.capacity(); for row in &self.ordering_values { - total += ScalarValue::size_of_vec(row) - std::mem::size_of_val(row); + total += ScalarValue::size_of_vec(row) - size_of_val(row); } // Add size of the `self.datatypes` - total += std::mem::size_of::() * self.datatypes.capacity(); + total += size_of::() * self.datatypes.capacity(); for dtype in &self.datatypes { - total += dtype.size() - std::mem::size_of_val(dtype); + total += dtype.size() - size_of_val(dtype); } // Add size of the `self.ordering_req` - total += std::mem::size_of::() * self.ordering_req.capacity(); + total += size_of::() * self.ordering_req.capacity(); // TODO: Calculate size of each `PhysicalSortExpr` more accurately. total } @@ -334,7 +511,7 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { impl OrderSensitiveArrayAggAccumulator { fn evaluate_orderings(&self) -> Result { - let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]); + let fields = ordering_fields(self.ordering_req.as_ref(), &self.datatypes[1..]); let num_columns = fields.len(); let struct_field = Fields::from(fields.clone()); @@ -353,199 +530,24 @@ impl OrderSensitiveArrayAggAccumulator { column_wise_ordering_values.push(array); } - let ordering_array = StructArray::try_new( - struct_field.clone(), - column_wise_ordering_values, - None, - )?; - Ok(ScalarValue::List(Arc::new(array_into_list_array( + let ordering_array = + StructArray::try_new(struct_field, column_wise_ordering_values, None)?; + Ok(ScalarValue::List(Arc::new(array_into_list_array_nullable( Arc::new(ordering_array), )))) } } -/// This is a wrapper struct to be able to correctly merge `ARRAY_AGG` data from -/// multiple partitions using `BinaryHeap`. When used inside `BinaryHeap`, this -/// struct returns smallest `CustomElement`, where smallest is determined by -/// `ordering` values (`Vec`) according to `sort_options`. -#[derive(Debug, PartialEq, Eq)] -struct CustomElement<'a> { - /// Stores the partition this entry came from - branch_idx: usize, - /// Values to merge - value: ScalarValue, - // Comparison "key" - ordering: Vec, - /// Options defining the ordering semantics - sort_options: &'a [SortOptions], -} - -impl<'a> CustomElement<'a> { - fn new( - branch_idx: usize, - value: ScalarValue, - ordering: Vec, - sort_options: &'a [SortOptions], - ) -> Self { - Self { - branch_idx, - value, - ordering, - sort_options, - } - } - - fn ordering( - &self, - current: &[ScalarValue], - target: &[ScalarValue], - ) -> Result { - // Calculate ordering according to `sort_options` - compare_rows(current, target, self.sort_options) - } -} - -// Overwrite ordering implementation such that -// - `self.ordering` values are used for comparison, -// - When used inside `BinaryHeap` it is a min-heap. -impl<'a> Ord for CustomElement<'a> { - fn cmp(&self, other: &Self) -> Ordering { - // Compares according to custom ordering - self.ordering(&self.ordering, &other.ordering) - // Convert max heap to min heap - .map(|ordering| ordering.reverse()) - // This function return error, when `self.ordering` and `other.ordering` - // have different types (such as one is `ScalarValue::Int64`, other is `ScalarValue::Float32`) - // Here this case won't happen, because data from each partition will have same type - .unwrap() - } -} - -impl<'a> PartialOrd for CustomElement<'a> { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -/// This functions merges `values` array (`&[Vec]`) into single array `Vec` -/// Merging done according to ordering values stored inside `ordering_values` (`&[Vec>]`) -/// Inner `Vec` in the `ordering_values` can be thought as ordering information for the -/// each `ScalarValue` in the `values` array. -/// Desired ordering specified by `sort_options` argument (Should have same size with inner `Vec` -/// of the `ordering_values` array). -/// -/// As an example -/// values can be \[ -/// \[1, 2, 3, 4, 5\], -/// \[1, 2, 3, 4\], -/// \[1, 2, 3, 4, 5, 6\], -/// \] -/// In this case we will be merging three arrays (doesn't have to be same size) -/// and produce a merged array with size 15 (sum of 5+4+6) -/// Merging will be done according to ordering at `ordering_values` vector. -/// As an example `ordering_values` can be [ -/// \[(1, a), (2, b), (3, b), (4, a), (5, b) \], -/// \[(1, a), (2, b), (3, b), (4, a) \], -/// \[(1, b), (2, c), (3, d), (4, e), (5, a), (6, b) \], -/// ] -/// For each ScalarValue in the `values` we have a corresponding `Vec` (like timestamp of it) -/// for the example above `sort_options` will have size two, that defines ordering requirement of the merge. -/// Inner `Vec`s of the `ordering_values` will be compared according `sort_options` (Their sizes should match) -pub(crate) fn merge_ordered_arrays( - // We will merge values into single `Vec`. - values: &mut [VecDeque], - // `values` will be merged according to `ordering_values`. - // Inner `Vec` can be thought as ordering information for the - // each `ScalarValue` in the values`. - ordering_values: &mut [VecDeque>], - // Defines according to which ordering comparisons should be done. - sort_options: &[SortOptions], -) -> Result<(Vec, Vec>)> { - // Keep track the most recent data of each branch, in binary heap data structure. - let mut heap = BinaryHeap::::new(); - - if values.len() != ordering_values.len() - || values - .iter() - .zip(ordering_values.iter()) - .any(|(vals, ordering_vals)| vals.len() != ordering_vals.len()) - { - return exec_err!( - "Expects values arguments and/or ordering_values arguments to have same size" - ); - } - let n_branch = values.len(); - let mut merged_values = vec![]; - let mut merged_orderings = vec![]; - // Continue iterating the loop until consuming data of all branches. - loop { - let minimum = if let Some(minimum) = heap.pop() { - minimum - } else { - // Heap is empty, fill it with the next entries from each branch. - for branch_idx in 0..n_branch { - if let Some(orderings) = ordering_values[branch_idx].pop_front() { - // Their size should be same, we can safely .unwrap here. - let value = values[branch_idx].pop_front().unwrap(); - // Push the next element to the heap: - heap.push(CustomElement::new( - branch_idx, - value, - orderings, - sort_options, - )); - } - // If None, we consumed this branch, skip it. - } - - // Now we have filled the heap, get the largest entry (this will be - // the next element in merge). - if let Some(minimum) = heap.pop() { - minimum - } else { - // Heap is empty, this means that all indices are same with - // `end_indices`. We have consumed all of the branches, merge - // is completed, exit from the loop: - break; - } - }; - let CustomElement { - branch_idx, - value, - ordering, - .. - } = minimum; - // Add minimum value in the heap to the result - merged_values.push(value); - merged_orderings.push(ordering); - - // If there is an available entry, push next entry in the most - // recently consumed branch to the heap. - if let Some(orderings) = ordering_values[branch_idx].pop_front() { - // Their size should be same, we can safely .unwrap here. - let value = values[branch_idx].pop_front().unwrap(); - // Push the next element to the heap: - heap.push(CustomElement::new( - branch_idx, - value, - orderings, - sort_options, - )); - } - } - - Ok((merged_values, merged_orderings)) -} - #[cfg(test)] mod tests { + use super::*; + use std::collections::VecDeque; use std::sync::Arc; - use crate::aggregate::array_agg_ordered::merge_ordered_arrays; - - use arrow_array::{Array, ArrayRef, Int64Array}; + use arrow::array::Int64Array; use arrow_schema::SortOptions; + use datafusion_common::utils::get_row_at_idx; use datafusion_common::{Result, ScalarValue}; diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/functions-aggregate/src/average.rs similarity index 65% rename from datafusion/physical-expr/src/aggregate/average.rs rename to datafusion/functions-aggregate/src/average.rs index 065c2179f4c5..710b7e69ac5c 100644 --- a/datafusion/physical-expr/src/aggregate/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -15,77 +15,96 @@ // specific language governing permissions and limitations // under the License. -//! Defines physical expressions that can evaluated at runtime during query execution +//! Defines `Avg` & `Mean` aggregate & accumulators -use arrow::array::{AsArray, PrimitiveBuilder}; -use log::debug; - -use std::any::Any; -use std::fmt::Debug; -use std::sync::Arc; +use arrow::array::{ + Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, AsArray, + BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array, +}; -use crate::aggregate::groups_accumulator::accumulate::NullState; -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; use arrow::compute::sum; -use arrow::datatypes::{DataType, Decimal128Type, Float64Type, UInt64Type}; -use arrow::{ - array::{ArrayRef, UInt64Array}, - datatypes::Field, +use arrow::datatypes::{ + i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, Field, + Float64Type, UInt64Type, +}; +use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::type_coercion::aggregates::{avg_return_type, coerce_avg_type}; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::Volatility::Immutable; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, EmitTo, GroupsAccumulator, + ReversedUDAF, Signature, }; -use arrow_array::types::{Decimal256Type, DecimalType}; -use arrow_array::{ - Array, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, PrimitiveArray, + +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{ + filtered_null_mask, set_nulls, }; -use arrow_buffer::{i256, ArrowNativeType}; -use datafusion_common::{not_impl_err, Result, ScalarValue}; -use datafusion_expr::type_coercion::aggregates::avg_return_type; -use datafusion_expr::{Accumulator, EmitTo, GroupsAccumulator}; -use super::utils::DecimalAverager; +use datafusion_functions_aggregate_common::utils::DecimalAverager; +use log::debug; +use std::any::Any; +use std::fmt::Debug; +use std::mem::{size_of, size_of_val}; +use std::sync::{Arc, OnceLock}; + +make_udaf_expr_and_func!( + Avg, + avg, + expression, + "Returns the avg of a group of values.", + avg_udaf +); -/// AVG aggregate expression -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct Avg { - name: String, - expr: Arc, - input_data_type: DataType, - result_data_type: DataType, + signature: Signature, + aliases: Vec, } impl Avg { - /// Create a new AVG aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - let result_data_type = avg_return_type(&data_type).unwrap(); - + pub fn new() -> Self { Self { - name: name.into(), - expr, - input_data_type: data_type, - result_data_type, + signature: Signature::user_defined(Immutable), + aliases: vec![String::from("mean")], } } } -impl AggregateExpr for Avg { - /// Return a reference to Any that can be used for downcasting +impl Default for Avg { + fn default() -> Self { + Self::new() + } +} + +impl AggregateUDFImpl for Avg { fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.result_data_type.clone(), true)) + fn name(&self) -> &str { + "avg" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + avg_return_type(self.name(), &arg_types[0]) } - fn create_accumulator(&self) -> Result> { + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + if acc_args.is_distinct { + return exec_err!("avg(DISTINCT) aggregations are not available"); + } use DataType::*; + + let data_type = acc_args.exprs[0].data_type(acc_args.schema)?; // instantiate specialized accumulator based for the type - match (&self.input_data_type, &self.result_data_type) { + match (&data_type, acc_args.return_type) { (Float64, Float64) => Ok(Box::::default()), ( Decimal128(sum_precision, sum_scale), @@ -110,59 +129,49 @@ impl AggregateExpr for Avg { target_precision: *target_precision, target_scale: *target_scale, })), - _ => not_impl_err!( + _ => exec_err!( "AvgAccumulator for ({} --> {})", - self.input_data_type, - self.result_data_type + &data_type, + acc_args.return_type ), } } - fn state_fields(&self) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( - format_state_name(&self.name, "count"), + format_state_name(args.name, "count"), DataType::UInt64, true, ), Field::new( - format_state_name(&self.name, "sum"), - self.input_data_type.clone(), + format_state_name(args.name, "sum"), + args.input_types[0].clone(), true, ), ]) } - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } - - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone())) - } - - fn create_sliding_accumulator(&self) -> Result> { - self.create_accumulator() + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + matches!( + args.return_type, + DataType::Float64 | DataType::Decimal128(_, _) + ) } - fn groups_accumulator_supported(&self) -> bool { + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { use DataType::*; - matches!(&self.result_data_type, Float64 | Decimal128(_, _)) - } - - fn create_groups_accumulator(&self) -> Result> { - use DataType::*; + let data_type = args.exprs[0].data_type(args.schema)?; // instantiate specialized accumulator based for the type - match (&self.input_data_type, &self.result_data_type) { + match (&data_type, args.return_type) { (Float64, Float64) => { Ok(Box::new(AvgGroupsAccumulator::::new( - &self.input_data_type, - &self.result_data_type, + &data_type, + args.return_type, |sum: f64, count: u64| Ok(sum / count as f64), ))) } @@ -180,8 +189,8 @@ impl AggregateExpr for Avg { move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128); Ok(Box::new(AvgGroupsAccumulator::::new( - &self.input_data_type, - &self.result_data_type, + &data_type, + args.return_type, avg_fn, ))) } @@ -201,35 +210,66 @@ impl AggregateExpr for Avg { }; Ok(Box::new(AvgGroupsAccumulator::::new( - &self.input_data_type, - &self.result_data_type, + &data_type, + args.return_type, avg_fn, ))) } _ => not_impl_err!( "AvgGroupsAccumulator for ({} --> {})", - self.input_data_type, - self.result_data_type + &data_type, + args.return_type ), } } -} -impl PartialEq for Avg { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.input_data_type == x.input_data_type - && self.result_data_type == x.result_data_type - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 1 { + return exec_err!("{} expects exactly one argument.", self.name()); + } + coerce_avg_type(self.name(), arg_types) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_avg_doc()) } } +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_avg_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description( + "Returns the average of numeric values in the specified column.", + ) + .with_syntax_example("avg(expression)") + .with_sql_example( + r#"```sql +> SELECT avg(column_name) FROM table_name; ++---------------------------+ +| avg(column_name) | ++---------------------------+ +| 42.75 | ++---------------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) +} + /// An accumulator to compute the average #[derive(Debug, Default)] pub struct AvgAccumulator { @@ -238,13 +278,6 @@ pub struct AvgAccumulator { } impl Accumulator for AvgAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![ - ScalarValue::from(self.count), - ScalarValue::Float64(self.sum), - ]) - } - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = values[0].as_primitive::(); self.count += (values.len() - values.null_count()) as u64; @@ -255,13 +288,21 @@ impl Accumulator for AvgAccumulator { Ok(()) } - fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = values[0].as_primitive::(); - self.count -= (values.len() - values.null_count()) as u64; - if let Some(x) = sum(values) { - self.sum = Some(self.sum.unwrap() - x); - } - Ok(()) + fn evaluate(&mut self) -> Result { + Ok(ScalarValue::Float64( + self.sum.map(|f| f / self.count as f64), + )) + } + + fn size(&self) -> usize { + size_of_val(self) + } + + fn state(&mut self) -> Result> { + Ok(vec![ + ScalarValue::from(self.count), + ScalarValue::Float64(self.sum), + ]) } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { @@ -275,23 +316,23 @@ impl Accumulator for AvgAccumulator { } Ok(()) } - - fn evaluate(&mut self) -> Result { - Ok(ScalarValue::Float64( - self.sum.map(|f| f / self.count as f64), - )) + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = values[0].as_primitive::(); + self.count -= (values.len() - values.null_count()) as u64; + if let Some(x) = sum(values) { + self.sum = Some(self.sum.unwrap() - x); + } + Ok(()) } + fn supports_retract_batch(&self) -> bool { true } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } } /// An accumulator to compute the average for decimals -struct DecimalAvgAccumulator { +#[derive(Debug)] +struct DecimalAvgAccumulator { sum: Option, count: u64, sum_scale: i8, @@ -300,56 +341,12 @@ struct DecimalAvgAccumulator { target_scale: i8, } -impl Debug for DecimalAvgAccumulator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("DecimalAvgAccumulator") - .field("sum", &self.sum) - .field("count", &self.count) - .field("sum_scale", &self.sum_scale) - .field("sum_precision", &self.sum_precision) - .field("target_precision", &self.target_precision) - .field("target_scale", &self.target_scale) - .finish() - } -} - -impl Accumulator for DecimalAvgAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![ - ScalarValue::from(self.count), - ScalarValue::new_primitive::( - self.sum, - &T::TYPE_CONSTRUCTOR(self.sum_precision, self.sum_scale), - )?, - ]) - } - +impl Accumulator for DecimalAvgAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = values[0].as_primitive::(); - self.count += (values.len() - values.null_count()) as u64; - if let Some(x) = sum(values) { - let v = self.sum.get_or_insert(T::Native::default()); - self.sum = Some(v.add_wrapping(x)); - } - Ok(()) - } - fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = values[0].as_primitive::(); - self.count -= (values.len() - values.null_count()) as u64; if let Some(x) = sum(values) { - self.sum = Some(self.sum.unwrap().sub_wrapping(x)); - } - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - // counts are summed - self.count += sum(states[0].as_primitive::()).unwrap_or_default(); - - // sums are summed - if let Some(x) = sum(states[1].as_primitive::()) { let v = self.sum.get_or_insert(T::Native::default()); self.sum = Some(v.add_wrapping(x)); } @@ -374,12 +371,43 @@ impl Accumulator for DecimalAvgAccumulator &T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale), ) } - fn supports_retract_batch(&self) -> bool { - true - } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) + } + + fn state(&mut self) -> Result> { + Ok(vec![ + ScalarValue::from(self.count), + ScalarValue::new_primitive::( + self.sum, + &T::TYPE_CONSTRUCTOR(self.sum_precision, self.sum_scale), + )?, + ]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + // counts are summed + self.count += sum(states[0].as_primitive::()).unwrap_or_default(); + + // sums are summed + if let Some(x) = sum(states[1].as_primitive::()) { + let v = self.sum.get_or_insert(T::Native::default()); + self.sum = Some(v.add_wrapping(x)); + } + Ok(()) + } + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = values[0].as_primitive::(); + self.count -= (values.len() - values.null_count()) as u64; + if let Some(x) = sum(values) { + self.sum = Some(self.sum.unwrap().sub_wrapping(x)); + } + Ok(()) + } + + fn supports_retract_batch(&self) -> bool { + true } } @@ -444,7 +472,7 @@ where &mut self, values: &[ArrayRef], group_indices: &[usize], - opt_filter: Option<&arrow_array::BooleanArray>, + opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 1, "single argument to update_batch"); @@ -469,45 +497,6 @@ where Ok(()) } - fn merge_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - opt_filter: Option<&arrow_array::BooleanArray>, - total_num_groups: usize, - ) -> Result<()> { - assert_eq!(values.len(), 2, "two arguments to merge_batch"); - // first batch is counts, second is partial sums - let partial_counts = values[0].as_primitive::(); - let partial_sums = values[1].as_primitive::(); - // update counts with partial counts - self.counts.resize(total_num_groups, 0); - self.null_state.accumulate( - group_indices, - partial_counts, - opt_filter, - total_num_groups, - |group_index, partial_count| { - self.counts[group_index] += partial_count; - }, - ); - - // update sums - self.sums.resize(total_num_groups, T::default_value()); - self.null_state.accumulate( - group_indices, - partial_sums, - opt_filter, - total_num_groups, - |group_index, new_value: ::Native| { - let sum = &mut self.sums[group_index]; - *sum = sum.add_wrapping(new_value); - }, - ); - - Ok(()) - } - fn evaluate(&mut self, emit_to: EmitTo) -> Result { let counts = emit_to.take_needed(&mut self.counts); let sums = emit_to.take_needed(&mut self.sums); @@ -562,115 +551,70 @@ where ]) } - fn size(&self) -> usize { - self.counts.capacity() * std::mem::size_of::() - + self.sums.capacity() * std::mem::size_of::() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::expressions::tests::assert_aggregate; - use arrow::array::*; - use datafusion_expr::AggregateFunction; - - #[test] - fn avg_decimal() { - // test agg - let array: ArrayRef = Arc::new( - (1..7) - .map(Some) - .collect::() - .with_precision_and_scale(10, 0) - .unwrap(), - ); - - assert_aggregate( - array, - AggregateFunction::Avg, - false, - ScalarValue::Decimal128(Some(35000), 14, 4), + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 2, "two arguments to merge_batch"); + // first batch is counts, second is partial sums + let partial_counts = values[0].as_primitive::(); + let partial_sums = values[1].as_primitive::(); + // update counts with partial counts + self.counts.resize(total_num_groups, 0); + self.null_state.accumulate( + group_indices, + partial_counts, + opt_filter, + total_num_groups, + |group_index, partial_count| { + self.counts[group_index] += partial_count; + }, ); - } - #[test] - fn avg_decimal_with_nulls() { - let array: ArrayRef = Arc::new( - (1..6) - .map(|i| if i == 2 { None } else { Some(i) }) - .collect::() - .with_precision_and_scale(10, 0) - .unwrap(), - ); - assert_aggregate( - array, - AggregateFunction::Avg, - false, - ScalarValue::Decimal128(Some(32500), 14, 4), + // update sums + self.sums.resize(total_num_groups, T::default_value()); + self.null_state.accumulate( + group_indices, + partial_sums, + opt_filter, + total_num_groups, + |group_index, new_value: ::Native| { + let sum = &mut self.sums[group_index]; + *sum = sum.add_wrapping(new_value); + }, ); - } - #[test] - fn avg_decimal_all_nulls() { - // test agg - let array: ArrayRef = Arc::new( - std::iter::repeat::>(None) - .take(6) - .collect::() - .with_precision_and_scale(10, 0) - .unwrap(), - ); - assert_aggregate( - array, - AggregateFunction::Avg, - false, - ScalarValue::Decimal128(None, 14, 4), - ); + Ok(()) } - #[test] - fn avg_i32() { - let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3_f64)); - } + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + let sums = values[0] + .as_primitive::() + .clone() + .with_data_type(self.sum_data_type.clone()); + let counts = UInt64Array::from_value(1, sums.len()); - #[test] - fn avg_i32_with_nulls() { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - None, - Some(3), - Some(4), - Some(5), - ])); - assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3.25f64)); - } + let nulls = filtered_null_mask(opt_filter, &sums); - #[test] - fn avg_i32_all_nulls() { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::Float64(None)); - } + // set nulls on the arrays + let counts = set_nulls(counts, nulls.clone()); + let sums = set_nulls(sums, nulls); - #[test] - fn avg_u32() { - let a: ArrayRef = - Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); - assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3.0f64)); + Ok(vec![Arc::new(counts) as ArrayRef, Arc::new(sums)]) } - #[test] - fn avg_f32() { - let a: ArrayRef = - Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); - assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3_f64)); + fn supports_convert_to_state(&self) -> bool { + true } - #[test] - fn avg_f64() { - let a: ArrayRef = - Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - assert_aggregate(a, AggregateFunction::Avg, false, ScalarValue::from(3_f64)); + fn size(&self) -> usize { + self.counts.capacity() * size_of::() + self.sums.capacity() * size_of::() } } diff --git a/datafusion/functions-aggregate/src/bit_and_or_xor.rs b/datafusion/functions-aggregate/src/bit_and_or_xor.rs new file mode 100644 index 000000000000..249ff02e7222 --- /dev/null +++ b/datafusion/functions-aggregate/src/bit_and_or_xor.rs @@ -0,0 +1,579 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines `BitAnd`, `BitOr`, `BitXor` and `BitXor DISTINCT` aggregate accumulators + +use std::any::Any; +use std::collections::HashSet; +use std::fmt::{Display, Formatter}; +use std::mem::{size_of, size_of_val}; + +use ahash::RandomState; +use arrow::array::{downcast_integer, Array, ArrayRef, AsArray}; +use arrow::datatypes::{ + ArrowNativeType, ArrowNumericType, DataType, Int16Type, Int32Type, Int64Type, + Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; +use arrow_schema::Field; + +use datafusion_common::cast::as_list_array; +use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::type_coercion::aggregates::INTEGERS; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF, + Signature, Volatility, +}; + +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; +use std::ops::{BitAndAssign, BitOrAssign, BitXorAssign}; +use std::sync::OnceLock; + +/// This macro helps create group accumulators based on bitwise operations typically used internally +/// and might not be necessary for users to call directly. +macro_rules! group_accumulator_helper { + ($t:ty, $dt:expr, $opr:expr) => { + match $opr { + BitwiseOperationType::And => Ok(Box::new( + PrimitiveGroupsAccumulator::<$t, _>::new($dt, |x, y| x.bitand_assign(y)) + .with_starting_value(!0), + )), + BitwiseOperationType::Or => Ok(Box::new( + PrimitiveGroupsAccumulator::<$t, _>::new($dt, |x, y| x.bitor_assign(y)), + )), + BitwiseOperationType::Xor => Ok(Box::new( + PrimitiveGroupsAccumulator::<$t, _>::new($dt, |x, y| x.bitxor_assign(y)), + )), + } + }; +} + +/// `accumulator_helper` is a macro accepting (ArrowPrimitiveType, BitwiseOperationType, bool) +macro_rules! accumulator_helper { + ($t:ty, $opr:expr, $is_distinct: expr) => { + match $opr { + BitwiseOperationType::And => Ok(Box::>::default()), + BitwiseOperationType::Or => Ok(Box::>::default()), + BitwiseOperationType::Xor => { + if $is_distinct { + Ok(Box::>::default()) + } else { + Ok(Box::>::default()) + } + } + } + }; +} + +/// AND, OR and XOR only supports a subset of numeric types +/// +/// `args` is [AccumulatorArgs] +/// `opr` is [BitwiseOperationType] +/// `is_distinct` is boolean value indicating whether the operation is distinct or not. +macro_rules! downcast_bitwise_accumulator { + ($args:ident, $opr:expr, $is_distinct: expr) => { + match $args.return_type { + DataType::Int8 => accumulator_helper!(Int8Type, $opr, $is_distinct), + DataType::Int16 => accumulator_helper!(Int16Type, $opr, $is_distinct), + DataType::Int32 => accumulator_helper!(Int32Type, $opr, $is_distinct), + DataType::Int64 => accumulator_helper!(Int64Type, $opr, $is_distinct), + DataType::UInt8 => accumulator_helper!(UInt8Type, $opr, $is_distinct), + DataType::UInt16 => accumulator_helper!(UInt16Type, $opr, $is_distinct), + DataType::UInt32 => accumulator_helper!(UInt32Type, $opr, $is_distinct), + DataType::UInt64 => accumulator_helper!(UInt64Type, $opr, $is_distinct), + _ => { + not_impl_err!( + "{} not supported for {}: {}", + stringify!($opr), + $args.name, + $args.return_type + ) + } + } + }; +} + +/// Simplifies the creation of User-Defined Aggregate Functions (UDAFs) for performing bitwise operations in a declarative manner. +/// +/// `EXPR_FN` identifier used to name the generated expression function. +/// `AGGREGATE_UDF_FN` is an identifier used to name the underlying UDAF function. +/// `OPR_TYPE` is an expression that evaluates to the type of bitwise operation to be performed. +/// `DOCUMENTATION` documentation for the UDAF +macro_rules! make_bitwise_udaf_expr_and_func { + ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $OPR_TYPE:expr, $DOCUMENTATION:expr) => { + make_udaf_expr!( + $EXPR_FN, + expr_x, + concat!( + "Returns the bitwise", + stringify!($OPR_TYPE), + "of a group of values" + ), + $AGGREGATE_UDF_FN + ); + create_func!( + $EXPR_FN, + $AGGREGATE_UDF_FN, + BitwiseOperation::new($OPR_TYPE, stringify!($EXPR_FN), $DOCUMENTATION) + ); + }; +} + +static BIT_AND_DOC: OnceLock = OnceLock::new(); + +fn get_bit_and_doc() -> &'static Documentation { + BIT_AND_DOC.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description("Computes the bitwise AND of all non-null input values.") + .with_syntax_example("bit_and(expression)") + .with_standard_argument("expression", Some("Integer")) + .build() + .unwrap() + }) +} + +static BIT_OR_DOC: OnceLock = OnceLock::new(); + +fn get_bit_or_doc() -> &'static Documentation { + BIT_OR_DOC.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description("Computes the bitwise OR of all non-null input values.") + .with_syntax_example("bit_or(expression)") + .with_standard_argument("expression", Some("Integer")) + .build() + .unwrap() + }) +} + +static BIT_XOR_DOC: OnceLock = OnceLock::new(); + +fn get_bit_xor_doc() -> &'static Documentation { + BIT_XOR_DOC.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description( + "Computes the bitwise exclusive OR of all non-null input values.", + ) + .with_syntax_example("bit_xor(expression)") + .with_standard_argument("expression", Some("Integer")) + .build() + .unwrap() + }) +} + +make_bitwise_udaf_expr_and_func!( + bit_and, + bit_and_udaf, + BitwiseOperationType::And, + get_bit_and_doc() +); +make_bitwise_udaf_expr_and_func!( + bit_or, + bit_or_udaf, + BitwiseOperationType::Or, + get_bit_or_doc() +); +make_bitwise_udaf_expr_and_func!( + bit_xor, + bit_xor_udaf, + BitwiseOperationType::Xor, + get_bit_xor_doc() +); + +/// The different types of bitwise operations that can be performed. +#[derive(Debug, Clone, Eq, PartialEq)] +enum BitwiseOperationType { + And, + Or, + Xor, +} + +impl Display for BitwiseOperationType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + +/// [BitwiseOperation] struct encapsulates information about a bitwise operation. +#[derive(Debug)] +struct BitwiseOperation { + signature: Signature, + /// `operation` indicates the type of bitwise operation to be performed. + operation: BitwiseOperationType, + func_name: &'static str, + documentation: &'static Documentation, +} + +impl BitwiseOperation { + pub fn new( + operator: BitwiseOperationType, + func_name: &'static str, + documentation: &'static Documentation, + ) -> Self { + Self { + operation: operator, + signature: Signature::uniform(1, INTEGERS.to_vec(), Volatility::Immutable), + func_name, + documentation, + } + } +} + +impl AggregateUDFImpl for BitwiseOperation { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.func_name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let arg_type = &arg_types[0]; + if !arg_type.is_integer() { + return exec_err!( + "[return_type] {} not supported for {}", + self.name(), + arg_type + ); + } + Ok(arg_type.clone()) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + downcast_bitwise_accumulator!(acc_args, self.operation, acc_args.is_distinct) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + if self.operation == BitwiseOperationType::Xor && args.is_distinct { + Ok(vec![Field::new_list( + format_state_name( + args.name, + format!("{} distinct", self.name()).as_str(), + ), + // See COMMENTS.md to understand why nullable is set to true + Field::new("item", args.return_type.clone(), true), + false, + )]) + } else { + Ok(vec![Field::new( + format_state_name(args.name, self.name()), + args.return_type.clone(), + true, + )]) + } + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + let data_type = args.return_type; + let operation = &self.operation; + downcast_integer! { + data_type => (group_accumulator_helper, data_type, operation), + _ => not_impl_err!( + "GroupsAccumulator not supported for {} with {}", + self.name(), + data_type + ), + } + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } + + fn documentation(&self) -> Option<&Documentation> { + Some(self.documentation) + } +} + +struct BitAndAccumulator { + value: Option, +} + +impl std::fmt::Debug for BitAndAccumulator { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "BitAndAccumulator({})", T::DATA_TYPE) + } +} + +impl Default for BitAndAccumulator { + fn default() -> Self { + Self { value: None } + } +} + +impl Accumulator for BitAndAccumulator +where + T::Native: std::ops::BitAnd, +{ + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if let Some(x) = arrow::compute::bit_and(values[0].as_primitive::()) { + let v = self.value.get_or_insert(x); + *v = *v & x; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + ScalarValue::new_primitive::(self.value, &T::DATA_TYPE) + } + + fn size(&self) -> usize { + size_of_val(self) + } + + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } +} + +struct BitOrAccumulator { + value: Option, +} + +impl std::fmt::Debug for BitOrAccumulator { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "BitOrAccumulator({})", T::DATA_TYPE) + } +} + +impl Default for BitOrAccumulator { + fn default() -> Self { + Self { value: None } + } +} + +impl Accumulator for BitOrAccumulator +where + T::Native: std::ops::BitOr, +{ + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if let Some(x) = arrow::compute::bit_or(values[0].as_primitive::()) { + let v = self.value.get_or_insert(T::Native::usize_as(0)); + *v = *v | x; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + ScalarValue::new_primitive::(self.value, &T::DATA_TYPE) + } + + fn size(&self) -> usize { + size_of_val(self) + } + + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } +} + +struct BitXorAccumulator { + value: Option, +} + +impl std::fmt::Debug for BitXorAccumulator { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "BitXorAccumulator({})", T::DATA_TYPE) + } +} + +impl Default for BitXorAccumulator { + fn default() -> Self { + Self { value: None } + } +} + +impl Accumulator for BitXorAccumulator +where + T::Native: std::ops::BitXor, +{ + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if let Some(x) = arrow::compute::bit_xor(values[0].as_primitive::()) { + let v = self.value.get_or_insert(T::Native::usize_as(0)); + *v = *v ^ x; + } + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + // XOR is it's own inverse + self.update_batch(values) + } + + fn supports_retract_batch(&self) -> bool { + true + } + + fn evaluate(&mut self) -> Result { + ScalarValue::new_primitive::(self.value, &T::DATA_TYPE) + } + + fn size(&self) -> usize { + size_of_val(self) + } + + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } +} + +struct DistinctBitXorAccumulator { + values: HashSet, +} + +impl std::fmt::Debug for DistinctBitXorAccumulator { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "DistinctBitXorAccumulator({})", T::DATA_TYPE) + } +} + +impl Default for DistinctBitXorAccumulator { + fn default() -> Self { + Self { + values: HashSet::default(), + } + } +} + +impl Accumulator for DistinctBitXorAccumulator +where + T::Native: std::ops::BitXor + std::hash::Hash + Eq, +{ + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + let array = values[0].as_primitive::(); + match array.nulls().filter(|x| x.null_count() > 0) { + Some(n) => { + for idx in n.valid_indices() { + self.values.insert(array.value(idx)); + } + } + None => array.values().iter().for_each(|x| { + self.values.insert(*x); + }), + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let mut acc = T::Native::usize_as(0); + for distinct_value in self.values.iter() { + acc = acc ^ *distinct_value; + } + let v = (!self.values.is_empty()).then_some(acc); + ScalarValue::new_primitive::(v, &T::DATA_TYPE) + } + + fn size(&self) -> usize { + size_of_val(self) + self.values.capacity() * size_of::() + } + + fn state(&mut self) -> Result> { + // 1. Stores aggregate state in `ScalarValue::List` + // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set + let state_out = { + let values = self + .values + .iter() + .map(|x| ScalarValue::new_primitive::(Some(*x), &T::DATA_TYPE)) + .collect::>>()?; + + let arr = ScalarValue::new_list_nullable(&values, &T::DATA_TYPE); + vec![ScalarValue::List(arr)] + }; + Ok(state_out) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if let Some(state) = states.first() { + let list_arr = as_list_array(state)?; + for arr in list_arr.iter().flatten() { + self.update_batch(&[arr])?; + } + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::array::{ArrayRef, UInt64Array}; + use arrow::datatypes::UInt64Type; + use datafusion_common::ScalarValue; + + use crate::bit_and_or_xor::BitXorAccumulator; + use datafusion_expr::Accumulator; + + #[test] + fn test_bit_xor_accumulator() { + let mut accumulator = BitXorAccumulator:: { value: None }; + let batches: Vec<_> = vec![vec![1, 2], vec![1]] + .into_iter() + .map(|b| Arc::new(b.into_iter().collect::()) as ArrayRef) + .collect(); + + let added = &[Arc::clone(&batches[0])]; + let retracted = &[Arc::clone(&batches[1])]; + + // XOR of 1..3 is 3 + accumulator.update_batch(added).unwrap(); + assert_eq!( + accumulator.evaluate().unwrap(), + ScalarValue::UInt64(Some(3)) + ); + + // Removing [1] ^ 3 = 2 + accumulator.retract_batch(retracted).unwrap(); + assert_eq!( + accumulator.evaluate().unwrap(), + ScalarValue::UInt64(Some(2)) + ); + } +} diff --git a/datafusion/functions-aggregate/src/bool_and_or.rs b/datafusion/functions-aggregate/src/bool_and_or.rs new file mode 100644 index 000000000000..87293ccfa21f --- /dev/null +++ b/datafusion/functions-aggregate/src/bool_and_or.rs @@ -0,0 +1,392 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines physical expressions that can evaluated at runtime during query execution + +use std::any::Any; +use std::mem::size_of_val; +use std::sync::OnceLock; + +use arrow::array::ArrayRef; +use arrow::array::BooleanArray; +use arrow::compute::bool_and as compute_bool_and; +use arrow::compute::bool_or as compute_bool_or; +use arrow::datatypes::DataType; +use arrow::datatypes::Field; + +use datafusion_common::internal_err; +use datafusion_common::{downcast_value, not_impl_err}; +use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF, + Signature, Volatility, +}; + +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::bool_op::BooleanGroupsAccumulator; + +// returns the new value after bool_and/bool_or with the new values, taking nullability into account +macro_rules! typed_bool_and_or_batch { + ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ + let array = downcast_value!($VALUES, $ARRAYTYPE); + let delta = $OP(array); + Ok(ScalarValue::$SCALAR(delta)) + }}; +} + +// bool_and/bool_or the array and returns a ScalarValue of its corresponding type. +macro_rules! bool_and_or_batch { + ($VALUES:expr, $OP:ident) => {{ + match $VALUES.data_type() { + DataType::Boolean => { + typed_bool_and_or_batch!($VALUES, BooleanArray, Boolean, $OP) + } + e => { + return internal_err!( + "Bool and/Bool or is not expected to receive the type {e:?}" + ); + } + } + }}; +} + +/// dynamically-typed bool_and(array) -> ScalarValue +fn bool_and_batch(values: &ArrayRef) -> Result { + bool_and_or_batch!(values, compute_bool_and) +} + +/// dynamically-typed bool_or(array) -> ScalarValue +fn bool_or_batch(values: &ArrayRef) -> Result { + bool_and_or_batch!(values, compute_bool_or) +} + +make_udaf_expr_and_func!( + BoolAnd, + bool_and, + expression, + "The values to combine with `AND`", + bool_and_udaf +); + +make_udaf_expr_and_func!( + BoolOr, + bool_or, + expression, + "The values to combine with `OR`", + bool_or_udaf +); + +/// BOOL_AND aggregate expression +#[derive(Debug)] +pub struct BoolAnd { + signature: Signature, +} + +impl BoolAnd { + fn new() -> Self { + Self { + signature: Signature::uniform( + 1, + vec![DataType::Boolean], + Volatility::Immutable, + ), + } + } +} + +impl Default for BoolAnd { + fn default() -> Self { + Self::new() + } +} + +impl AggregateUDFImpl for BoolAnd { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "bool_and" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _: &[DataType]) -> Result { + Ok(DataType::Boolean) + } + + fn accumulator(&self, _: AccumulatorArgs) -> Result> { + Ok(Box::::default()) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + Ok(vec![Field::new( + format_state_name(args.name, self.name()), + DataType::Boolean, + true, + )]) + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + match args.return_type { + DataType::Boolean => { + Ok(Box::new(BooleanGroupsAccumulator::new(|x, y| x && y, true))) + } + _ => not_impl_err!( + "GroupsAccumulator not supported for {} with {}", + args.name, + args.return_type + ), + } + } + + fn aliases(&self) -> &[String] { + &[] + } + + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + AggregateOrderSensitivity::Insensitive + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_bool_and_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_bool_and_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description( + "Returns true if all non-null input values are true, otherwise false.", + ) + .with_syntax_example("bool_and(expression)") + .with_sql_example( + r#"```sql +> SELECT bool_and(column_name) FROM table_name; ++----------------------------+ +| bool_and(column_name) | ++----------------------------+ +| true | ++----------------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) +} + +#[derive(Debug, Default)] +struct BoolAndAccumulator { + acc: Option, +} + +impl Accumulator for BoolAndAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &values[0]; + self.acc = match (self.acc, bool_and_batch(values)?) { + (None, ScalarValue::Boolean(v)) => v, + (Some(v), ScalarValue::Boolean(None)) => Some(v), + (Some(a), ScalarValue::Boolean(Some(b))) => Some(a && b), + _ => unreachable!(), + }; + Ok(()) + } + + fn evaluate(&mut self) -> Result { + Ok(ScalarValue::Boolean(self.acc)) + } + + fn size(&self) -> usize { + size_of_val(self) + } + + fn state(&mut self) -> Result> { + Ok(vec![ScalarValue::Boolean(self.acc)]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } +} + +/// BOOL_OR aggregate expression +#[derive(Debug, Clone)] +pub struct BoolOr { + signature: Signature, +} + +impl BoolOr { + fn new() -> Self { + Self { + signature: Signature::uniform( + 1, + vec![DataType::Boolean], + Volatility::Immutable, + ), + } + } +} + +impl Default for BoolOr { + fn default() -> Self { + Self::new() + } +} + +impl AggregateUDFImpl for BoolOr { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "bool_or" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _: &[DataType]) -> Result { + Ok(DataType::Boolean) + } + + fn accumulator(&self, _: AccumulatorArgs) -> Result> { + Ok(Box::::default()) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + Ok(vec![Field::new( + format_state_name(args.name, self.name()), + DataType::Boolean, + true, + )]) + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + match args.return_type { + DataType::Boolean => Ok(Box::new(BooleanGroupsAccumulator::new( + |x, y| x || y, + false, + ))), + _ => not_impl_err!( + "GroupsAccumulator not supported for {} with {}", + args.name, + args.return_type + ), + } + } + + fn aliases(&self) -> &[String] { + &[] + } + + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + AggregateOrderSensitivity::Insensitive + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_bool_or_doc()) + } +} + +fn get_bool_or_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description( + "Returns true if any non-null input value is true, otherwise false.", + ) + .with_syntax_example("bool_or(expression)") + .with_sql_example( + r#"```sql +> SELECT bool_or(column_name) FROM table_name; ++----------------------------+ +| bool_or(column_name) | ++----------------------------+ +| true | ++----------------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) +} + +#[derive(Debug, Default)] +struct BoolOrAccumulator { + acc: Option, +} + +impl Accumulator for BoolOrAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &values[0]; + self.acc = match (self.acc, bool_or_batch(values)?) { + (None, ScalarValue::Boolean(v)) => v, + (Some(v), ScalarValue::Boolean(None)) => Some(v), + (Some(a), ScalarValue::Boolean(Some(b))) => Some(a || b), + _ => unreachable!(), + }; + Ok(()) + } + + fn evaluate(&mut self) -> Result { + Ok(ScalarValue::Boolean(self.acc)) + } + + fn size(&self) -> usize { + size_of_val(self) + } + + fn state(&mut self) -> Result> { + Ok(vec![ScalarValue::Boolean(self.acc)]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } +} diff --git a/datafusion/functions-aggregate/src/correlation.rs b/datafusion/functions-aggregate/src/correlation.rs new file mode 100644 index 000000000000..187a43ecbea3 --- /dev/null +++ b/datafusion/functions-aggregate/src/correlation.rs @@ -0,0 +1,266 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`Correlation`]: correlation sample aggregations. + +use std::any::Any; +use std::fmt::Debug; +use std::mem::size_of_val; +use std::sync::{Arc, OnceLock}; + +use arrow::compute::{and, filter, is_not_null}; +use arrow::{ + array::ArrayRef, + datatypes::{DataType, Field}, +}; + +use crate::covariance::CovarianceAccumulator; +use crate::stddev::StddevAccumulator; +use datafusion_common::{plan_err, Result, ScalarValue}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_STATISTICAL; +use datafusion_expr::{ + function::{AccumulatorArgs, StateFieldsArgs}, + type_coercion::aggregates::NUMERICS, + utils::format_state_name, + Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, +}; +use datafusion_functions_aggregate_common::stats::StatsType; + +make_udaf_expr_and_func!( + Correlation, + corr, + y x, + "Correlation between two numeric values.", + corr_udaf +); + +#[derive(Debug)] +pub struct Correlation { + signature: Signature, +} + +impl Default for Correlation { + fn default() -> Self { + Self::new() + } +} + +impl Correlation { + /// Create a new COVAR_POP aggregate function + pub fn new() -> Self { + Self { + signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for Correlation { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "corr" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if !arg_types[0].is_numeric() { + return plan_err!("Correlation requires numeric input types"); + } + + Ok(DataType::Float64) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(CorrelationAccumulator::try_new()?)) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let name = args.name; + Ok(vec![ + Field::new(format_state_name(name, "count"), DataType::UInt64, true), + Field::new(format_state_name(name, "mean1"), DataType::Float64, true), + Field::new(format_state_name(name, "m2_1"), DataType::Float64, true), + Field::new(format_state_name(name, "mean2"), DataType::Float64, true), + Field::new(format_state_name(name, "m2_2"), DataType::Float64, true), + Field::new( + format_state_name(name, "algo_const"), + DataType::Float64, + true, + ), + ]) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_corr_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_corr_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Returns the coefficient of correlation between two numeric values.", + ) + .with_syntax_example("corr(expression1, expression2)") + .with_sql_example( + r#"```sql +> SELECT corr(column1, column2) FROM table_name; ++--------------------------------+ +| corr(column1, column2) | ++--------------------------------+ +| 0.85 | ++--------------------------------+ +```"#, + ) + .with_standard_argument("expression1", Some("First")) + .with_standard_argument("expression2", Some("Second")) + .build() + .unwrap() + }) +} + +/// An accumulator to compute correlation +#[derive(Debug)] +pub struct CorrelationAccumulator { + covar: CovarianceAccumulator, + stddev1: StddevAccumulator, + stddev2: StddevAccumulator, +} + +impl CorrelationAccumulator { + /// Creates a new `CorrelationAccumulator` + pub fn try_new() -> Result { + Ok(Self { + covar: CovarianceAccumulator::try_new(StatsType::Population)?, + stddev1: StddevAccumulator::try_new(StatsType::Population)?, + stddev2: StddevAccumulator::try_new(StatsType::Population)?, + }) + } +} + +impl Accumulator for CorrelationAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + // TODO: null input skipping logic duplicated across Correlation + // and its children accumulators. + // This could be simplified by splitting up input filtering and + // calculation logic in children accumulators, and calling only + // calculation part from Correlation + let values = if values[0].null_count() != 0 || values[1].null_count() != 0 { + let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?; + let values1 = filter(&values[0], &mask)?; + let values2 = filter(&values[1], &mask)?; + + vec![values1, values2] + } else { + values.to_vec() + }; + + self.covar.update_batch(&values)?; + self.stddev1.update_batch(&values[0..1])?; + self.stddev2.update_batch(&values[1..2])?; + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let covar = self.covar.evaluate()?; + let stddev1 = self.stddev1.evaluate()?; + let stddev2 = self.stddev2.evaluate()?; + + if let ScalarValue::Float64(Some(c)) = covar { + if let ScalarValue::Float64(Some(s1)) = stddev1 { + if let ScalarValue::Float64(Some(s2)) = stddev2 { + if s1 == 0_f64 || s2 == 0_f64 { + return Ok(ScalarValue::Float64(Some(0_f64))); + } else { + return Ok(ScalarValue::Float64(Some(c / s1 / s2))); + } + } + } + } + + Ok(ScalarValue::Float64(None)) + } + + fn size(&self) -> usize { + size_of_val(self) - size_of_val(&self.covar) + self.covar.size() + - size_of_val(&self.stddev1) + + self.stddev1.size() + - size_of_val(&self.stddev2) + + self.stddev2.size() + } + + fn state(&mut self) -> Result> { + Ok(vec![ + ScalarValue::from(self.covar.get_count()), + ScalarValue::from(self.covar.get_mean1()), + ScalarValue::from(self.stddev1.get_m2()), + ScalarValue::from(self.covar.get_mean2()), + ScalarValue::from(self.stddev2.get_m2()), + ScalarValue::from(self.covar.get_algo_const()), + ]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let states_c = [ + Arc::clone(&states[0]), + Arc::clone(&states[1]), + Arc::clone(&states[3]), + Arc::clone(&states[5]), + ]; + let states_s1 = [ + Arc::clone(&states[0]), + Arc::clone(&states[1]), + Arc::clone(&states[2]), + ]; + let states_s2 = [ + Arc::clone(&states[0]), + Arc::clone(&states[3]), + Arc::clone(&states[4]), + ]; + + self.covar.merge_batch(&states_c)?; + self.stddev1.merge_batch(&states_s1)?; + self.stddev2.merge_batch(&states_s2)?; + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = if values[0].null_count() != 0 || values[1].null_count() != 0 { + let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?; + let values1 = filter(&values[0], &mask)?; + let values2 = filter(&values[1], &mask)?; + + vec![values1, values2] + } else { + values.to_vec() + }; + + self.covar.retract_batch(&values)?; + self.stddev1.retract_batch(&values[0..1])?; + self.stddev2.retract_batch(&values[1..2])?; + Ok(()) + } +} diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs new file mode 100644 index 000000000000..bade589a908a --- /dev/null +++ b/datafusion/functions-aggregate/src/count.rs @@ -0,0 +1,732 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use ahash::RandomState; +use datafusion_common::stats::Precision; +use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator; +use datafusion_physical_expr::expressions; +use std::collections::HashSet; +use std::fmt::Debug; +use std::mem::{size_of, size_of_val}; +use std::ops::BitAnd; +use std::sync::{Arc, OnceLock}; + +use arrow::{ + array::{ArrayRef, AsArray}, + compute, + datatypes::{ + DataType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Field, + Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, + Time32MillisecondType, Time32SecondType, Time64MicrosecondType, + Time64NanosecondType, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, + UInt16Type, UInt32Type, UInt64Type, UInt8Type, + }, +}; + +use arrow::{ + array::{Array, BooleanArray, Int64Array, PrimitiveArray}, + buffer::BooleanBuffer, +}; +use datafusion_common::{ + downcast_value, internal_err, not_impl_err, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; +use datafusion_expr::function::StateFieldsArgs; +use datafusion_expr::{ + function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, + Documentation, EmitTo, GroupsAccumulator, Signature, Volatility, +}; +use datafusion_expr::{Expr, ReversedUDAF, StatisticsArgs, TypeSignature}; +use datafusion_functions_aggregate_common::aggregate::count_distinct::{ + BytesDistinctCountAccumulator, FloatDistinctCountAccumulator, + PrimitiveDistinctCountAccumulator, +}; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices; +use datafusion_physical_expr_common::binary_map::OutputType; + +use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; +make_udaf_expr_and_func!( + Count, + count, + expr, + "Count the number of non-null values in the column", + count_udaf +); + +pub fn count_distinct(expr: Expr) -> Expr { + Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + count_udaf(), + vec![expr], + true, + None, + None, + None, + )) +} + +pub struct Count { + signature: Signature, +} + +impl Debug for Count { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("Count") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for Count { + fn default() -> Self { + Self::new() + } +} + +impl Count { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + // TypeSignature::Any(0) is required to handle `Count()` with no args + vec![TypeSignature::VariadicAny, TypeSignature::Any(0)], + Volatility::Immutable, + ), + } + } +} + +impl AggregateUDFImpl for Count { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "count" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int64) + } + + fn is_nullable(&self) -> bool { + false + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + if args.is_distinct { + Ok(vec![Field::new_list( + format_state_name(args.name, "count distinct"), + // See COMMENTS.md to understand why nullable is set to true + Field::new("item", args.input_types[0].clone(), true), + false, + )]) + } else { + Ok(vec![Field::new( + format_state_name(args.name, "count"), + DataType::Int64, + false, + )]) + } + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + if !acc_args.is_distinct { + return Ok(Box::new(CountAccumulator::new())); + } + + if acc_args.exprs.len() > 1 { + return not_impl_err!("COUNT DISTINCT with multiple arguments"); + } + + let data_type = &acc_args.exprs[0].data_type(acc_args.schema)?; + Ok(match data_type { + // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator + DataType::Int8 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Int16 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Int32 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Int64 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::UInt8 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::UInt16 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::UInt32 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::UInt64 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< + Decimal128Type, + >::new(data_type)), + DataType::Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< + Decimal256Type, + >::new(data_type)), + + DataType::Date32 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Date64 => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Time32(TimeUnit::Millisecond) => Box::new( + PrimitiveDistinctCountAccumulator::::new( + data_type, + ), + ), + DataType::Time32(TimeUnit::Second) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Time64(TimeUnit::Microsecond) => Box::new( + PrimitiveDistinctCountAccumulator::::new( + data_type, + ), + ), + DataType::Time64(TimeUnit::Nanosecond) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new( + PrimitiveDistinctCountAccumulator::::new( + data_type, + ), + ), + DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new( + PrimitiveDistinctCountAccumulator::::new( + data_type, + ), + ), + DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new( + PrimitiveDistinctCountAccumulator::::new( + data_type, + ), + ), + DataType::Timestamp(TimeUnit::Second, _) => Box::new( + PrimitiveDistinctCountAccumulator::::new(data_type), + ), + + DataType::Float16 => { + Box::new(FloatDistinctCountAccumulator::::new()) + } + DataType::Float32 => { + Box::new(FloatDistinctCountAccumulator::::new()) + } + DataType::Float64 => { + Box::new(FloatDistinctCountAccumulator::::new()) + } + + DataType::Utf8 => { + Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)) + } + DataType::Utf8View => { + Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View)) + } + DataType::LargeUtf8 => { + Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)) + } + DataType::Binary => Box::new(BytesDistinctCountAccumulator::::new( + OutputType::Binary, + )), + DataType::BinaryView => Box::new(BytesViewDistinctCountAccumulator::new( + OutputType::BinaryView, + )), + DataType::LargeBinary => Box::new(BytesDistinctCountAccumulator::::new( + OutputType::Binary, + )), + + // Use the generic accumulator based on `ScalarValue` for all other types + _ => Box::new(DistinctCountAccumulator { + values: HashSet::default(), + state_data_type: data_type.clone(), + }), + }) + } + + fn aliases(&self) -> &[String] { + &[] + } + + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + // groups accumulator only supports `COUNT(c1)`, not + // `COUNT(c1, c2)`, etc + if args.is_distinct { + return false; + } + args.exprs.len() == 1 + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + // instantiate specialized accumulator + Ok(Box::new(CountGroupsAccumulator::new())) + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } + + fn default_value(&self, _data_type: &DataType) -> Result { + Ok(ScalarValue::Int64(Some(0))) + } + + fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option { + if statistics_args.is_distinct { + return None; + } + if let Precision::Exact(num_rows) = statistics_args.statistics.num_rows { + if statistics_args.exprs.len() == 1 { + // TODO optimize with exprs other than Column + if let Some(col_expr) = statistics_args.exprs[0] + .as_any() + .downcast_ref::() + { + let current_val = &statistics_args.statistics.column_statistics + [col_expr.index()] + .null_count; + if let &Precision::Exact(val) = current_val { + return Some(ScalarValue::Int64(Some((num_rows - val) as i64))); + } + } else if let Some(lit_expr) = statistics_args.exprs[0] + .as_any() + .downcast_ref::() + { + if lit_expr.value() == &COUNT_STAR_EXPANSION { + return Some(ScalarValue::Int64(Some(num_rows as i64))); + } + } + } + } + None + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_count_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_count_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description( + "Returns the number of non-null values in the specified column. To include null values in the total count, use `count(*)`.", + ) + .with_syntax_example("count(expression)") + .with_sql_example(r#"```sql +> SELECT count(column_name) FROM table_name; ++-----------------------+ +| count(column_name) | ++-----------------------+ +| 100 | ++-----------------------+ + +> SELECT count(*) FROM table_name; ++------------------+ +| count(*) | ++------------------+ +| 120 | ++------------------+ +```"#) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) +} + +#[derive(Debug)] +struct CountAccumulator { + count: i64, +} + +impl CountAccumulator { + /// new count accumulator + pub fn new() -> Self { + Self { count: 0 } + } +} + +impl Accumulator for CountAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![ScalarValue::Int64(Some(self.count))]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let array = &values[0]; + self.count += (array.len() - null_count_for_multiple_cols(values)) as i64; + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let array = &values[0]; + self.count -= (array.len() - null_count_for_multiple_cols(values)) as i64; + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let counts = downcast_value!(states[0], Int64Array); + let delta = &compute::sum(counts); + if let Some(d) = delta { + self.count += *d; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + Ok(ScalarValue::Int64(Some(self.count))) + } + + fn supports_retract_batch(&self) -> bool { + true + } + + fn size(&self) -> usize { + size_of_val(self) + } +} + +/// An accumulator to compute the counts of [`PrimitiveArray`]. +/// Stores values as native types, and does overflow checking +/// +/// Unlike most other accumulators, COUNT never produces NULLs. If no +/// non-null values are seen in any group the output is 0. Thus, this +/// accumulator has no additional null or seen filter tracking. +#[derive(Debug)] +struct CountGroupsAccumulator { + /// Count per group. + /// + /// Note this is an i64 and not a u64 (or usize) because the + /// output type of count is `DataType::Int64`. Thus by using `i64` + /// for the counts, the output [`Int64Array`] can be created + /// without copy. + counts: Vec, +} + +impl CountGroupsAccumulator { + pub fn new() -> Self { + Self { counts: vec![] } + } +} + +impl GroupsAccumulator for CountGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to update_batch"); + let values = &values[0]; + + // Add one to each group's counter for each non null, non + // filtered value + self.counts.resize(total_num_groups, 0); + accumulate_indices( + group_indices, + values.logical_nulls().as_ref(), + opt_filter, + |group_index| { + self.counts[group_index] += 1; + }, + ); + + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "one argument to merge_batch"); + // first batch is counts, second is partial sums + let partial_counts = values[0].as_primitive::(); + + // intermediate counts are always created as non null + assert_eq!(partial_counts.null_count(), 0); + let partial_counts = partial_counts.values(); + + // Adds the counts with the partial counts + self.counts.resize(total_num_groups, 0); + match opt_filter { + Some(filter) => filter + .iter() + .zip(group_indices.iter()) + .zip(partial_counts.iter()) + .for_each(|((filter_value, &group_index), partial_count)| { + if let Some(true) = filter_value { + self.counts[group_index] += partial_count; + } + }), + None => group_indices.iter().zip(partial_counts.iter()).for_each( + |(&group_index, partial_count)| { + self.counts[group_index] += partial_count; + }, + ), + } + + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let counts = emit_to.take_needed(&mut self.counts); + + // Count is always non null (null inputs just don't contribute to the overall values) + let nulls = None; + let array = PrimitiveArray::::new(counts.into(), nulls); + + Ok(Arc::new(array)) + } + + // return arrays for counts + fn state(&mut self, emit_to: EmitTo) -> Result> { + let counts = emit_to.take_needed(&mut self.counts); + let counts: PrimitiveArray = Int64Array::from(counts); // zero copy, no nulls + Ok(vec![Arc::new(counts) as ArrayRef]) + } + + /// Converts an input batch directly to a state batch + /// + /// The state of `COUNT` is always a single Int64Array: + /// * `1` (for non-null, non filtered values) + /// * `0` (for null values) + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + let values = &values[0]; + + let state_array = match (values.logical_nulls(), opt_filter) { + (None, None) => { + // In case there is no nulls in input and no filter, returning array of 1 + Arc::new(Int64Array::from_value(1, values.len())) + } + (Some(nulls), None) => { + // If there are any nulls in input values -- casting `nulls` (true for values, false for nulls) + // of input array to Int64 + let nulls = BooleanArray::new(nulls.into_inner(), None); + compute::cast(&nulls, &DataType::Int64)? + } + (None, Some(filter)) => { + // If there is only filter + // - applying filter null mask to filter values by bitand filter values and nulls buffers + // (using buffers guarantees absence of nulls in result) + // - casting result of bitand to Int64 array + let (filter_values, filter_nulls) = filter.clone().into_parts(); + + let state_buf = match filter_nulls { + Some(filter_nulls) => &filter_values & filter_nulls.inner(), + None => filter_values, + }; + + let boolean_state = BooleanArray::new(state_buf, None); + compute::cast(&boolean_state, &DataType::Int64)? + } + (Some(nulls), Some(filter)) => { + // For both input nulls and filter + // - applying filter null mask to filter values by bitand filter values and nulls buffers + // (using buffers guarantees absence of nulls in result) + // - applying values null mask to filter buffer by another bitand on filter result and + // nulls from input values + // - casting result to Int64 array + let (filter_values, filter_nulls) = filter.clone().into_parts(); + + let filter_buf = match filter_nulls { + Some(filter_nulls) => &filter_values & filter_nulls.inner(), + None => filter_values, + }; + let state_buf = &filter_buf & nulls.inner(); + + let boolean_state = BooleanArray::new(state_buf, None); + compute::cast(&boolean_state, &DataType::Int64)? + } + }; + + Ok(vec![state_array]) + } + + fn supports_convert_to_state(&self) -> bool { + true + } + + fn size(&self) -> usize { + self.counts.capacity() * size_of::() + } +} + +/// count null values for multiple columns +/// for each row if one column value is null, then null_count + 1 +fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize { + if values.len() > 1 { + let result_bool_buf: Option = values + .iter() + .map(|a| a.logical_nulls()) + .fold(None, |acc, b| match (acc, b) { + (Some(acc), Some(b)) => Some(acc.bitand(b.inner())), + (Some(acc), None) => Some(acc), + (None, Some(b)) => Some(b.into_inner()), + _ => None, + }); + result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits()) + } else { + values[0] + .logical_nulls() + .map_or(0, |nulls| nulls.null_count()) + } +} + +/// General purpose distinct accumulator that works for any DataType by using +/// [`ScalarValue`]. +/// +/// It stores intermediate results as a `ListArray` +/// +/// Note that many types have specialized accumulators that are (much) +/// more efficient such as [`PrimitiveDistinctCountAccumulator`] and +/// [`BytesDistinctCountAccumulator`] +#[derive(Debug)] +struct DistinctCountAccumulator { + values: HashSet, + state_data_type: DataType, +} + +impl DistinctCountAccumulator { + // calculating the size for fixed length values, taking first batch size * + // number of batches This method is faster than .full_size(), however it is + // not suitable for variable length values like strings or complex types + fn fixed_size(&self) -> usize { + size_of_val(self) + + (size_of::() * self.values.capacity()) + + self + .values + .iter() + .next() + .map(|vals| ScalarValue::size(vals) - size_of_val(vals)) + .unwrap_or(0) + + size_of::() + } + + // calculates the size as accurately as possible. Note that calling this + // method is expensive + fn full_size(&self) -> usize { + size_of_val(self) + + (size_of::() * self.values.capacity()) + + self + .values + .iter() + .map(|vals| ScalarValue::size(vals) - size_of_val(vals)) + .sum::() + + size_of::() + } +} + +impl Accumulator for DistinctCountAccumulator { + /// Returns the distinct values seen so far as (one element) ListArray. + fn state(&mut self) -> Result> { + let scalars = self.values.iter().cloned().collect::>(); + let arr = + ScalarValue::new_list_nullable(scalars.as_slice(), &self.state_data_type); + Ok(vec![ScalarValue::List(arr)]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + let arr = &values[0]; + if arr.data_type() == &DataType::Null { + return Ok(()); + } + + (0..arr.len()).try_for_each(|index| { + if !arr.is_null(index) { + let scalar = ScalarValue::try_from_array(arr, index)?; + self.values.insert(scalar); + } + Ok(()) + }) + } + + /// Merges multiple sets of distinct values into the current set. + /// + /// The input to this function is a `ListArray` with **multiple** rows, + /// where each row contains the values from a partial aggregate's phase (e.g. + /// the result of calling `Self::state` on multiple accumulators). + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + assert_eq!(states.len(), 1, "array_agg states must be singleton!"); + let array = &states[0]; + let list_array = array.as_list::(); + for inner_array in list_array.iter() { + let Some(inner_array) = inner_array else { + return internal_err!( + "Intermediate results of COUNT DISTINCT should always be non null" + ); + }; + self.update_batch(&[inner_array])?; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + Ok(ScalarValue::Int64(Some(self.values.len() as i64))) + } + + fn size(&self) -> usize { + match &self.state_data_type { + DataType::Boolean | DataType::Null => self.fixed_size(), + d if d.is_primitive() => self.fixed_size(), + _ => self.full_size(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::NullArray; + + #[test] + fn count_accumulator_nulls() -> Result<()> { + let mut accumulator = CountAccumulator::new(); + accumulator.update_batch(&[Arc::new(NullArray::new(10))])?; + assert_eq!(accumulator.evaluate()?, ScalarValue::Int64(Some(0))); + Ok(()) + } +} diff --git a/datafusion/functions-aggregate/src/covariance.rs b/datafusion/functions-aggregate/src/covariance.rs new file mode 100644 index 000000000000..063aaa92059d --- /dev/null +++ b/datafusion/functions-aggregate/src/covariance.rs @@ -0,0 +1,454 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`CovarianceSample`]: covariance sample aggregations. + +use std::fmt::Debug; +use std::mem::size_of_val; +use std::sync::OnceLock; + +use arrow::{ + array::{ArrayRef, Float64Array, UInt64Array}, + compute::kernels::cast, + datatypes::{DataType, Field}, +}; + +use datafusion_common::{ + downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, Result, + ScalarValue, +}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_STATISTICAL; +use datafusion_expr::{ + function::{AccumulatorArgs, StateFieldsArgs}, + type_coercion::aggregates::NUMERICS, + utils::format_state_name, + Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, +}; +use datafusion_functions_aggregate_common::stats::StatsType; + +make_udaf_expr_and_func!( + CovarianceSample, + covar_samp, + y x, + "Computes the sample covariance.", + covar_samp_udaf +); + +make_udaf_expr_and_func!( + CovariancePopulation, + covar_pop, + y x, + "Computes the population covariance.", + covar_pop_udaf +); + +pub struct CovarianceSample { + signature: Signature, + aliases: Vec, +} + +impl Debug for CovarianceSample { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("CovarianceSample") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for CovarianceSample { + fn default() -> Self { + Self::new() + } +} + +impl CovarianceSample { + pub fn new() -> Self { + Self { + aliases: vec![String::from("covar")], + signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for CovarianceSample { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "covar_samp" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if !arg_types[0].is_numeric() { + return plan_err!("Covariance requires numeric input types"); + } + + Ok(DataType::Float64) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let name = args.name; + Ok(vec![ + Field::new(format_state_name(name, "count"), DataType::UInt64, true), + Field::new(format_state_name(name, "mean1"), DataType::Float64, true), + Field::new(format_state_name(name, "mean2"), DataType::Float64, true), + Field::new( + format_state_name(name, "algo_const"), + DataType::Float64, + true, + ), + ]) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(CovarianceAccumulator::try_new(StatsType::Sample)?)) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_covar_samp_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_covar_samp_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description("Returns the sample covariance of a set of number pairs.") + .with_syntax_example("covar_samp(expression1, expression2)") + .with_sql_example( + r#"```sql +> SELECT covar_samp(column1, column2) FROM table_name; ++-----------------------------------+ +| covar_samp(column1, column2) | ++-----------------------------------+ +| 8.25 | ++-----------------------------------+ +```"#, + ) + .with_standard_argument("expression1", Some("First")) + .with_standard_argument("expression2", Some("Second")) + .build() + .unwrap() + }) +} + +pub struct CovariancePopulation { + signature: Signature, +} + +impl Debug for CovariancePopulation { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("CovariancePopulation") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for CovariancePopulation { + fn default() -> Self { + Self::new() + } +} + +impl CovariancePopulation { + pub fn new() -> Self { + Self { + signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for CovariancePopulation { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "covar_pop" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if !arg_types[0].is_numeric() { + return plan_err!("Covariance requires numeric input types"); + } + + Ok(DataType::Float64) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let name = args.name; + Ok(vec![ + Field::new(format_state_name(name, "count"), DataType::UInt64, true), + Field::new(format_state_name(name, "mean1"), DataType::Float64, true), + Field::new(format_state_name(name, "mean2"), DataType::Float64, true), + Field::new( + format_state_name(name, "algo_const"), + DataType::Float64, + true, + ), + ]) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(CovarianceAccumulator::try_new( + StatsType::Population, + )?)) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_covar_pop_doc()) + } +} + +fn get_covar_pop_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Returns the population covariance of a set of number pairs.", + ) + .with_syntax_example("covar_pop(expression1, expression2)") + .with_sql_example( + r#"```sql +> SELECT covar_pop(column1, column2) FROM table_name; ++-----------------------------------+ +| covar_pop(column1, column2) | ++-----------------------------------+ +| 7.63 | ++-----------------------------------+ +```"#, + ) + .with_standard_argument("expression1", Some("First")) + .with_standard_argument("expression2", Some("Second")) + .build() + .unwrap() + }) +} + +/// An accumulator to compute covariance +/// The algorithm used is an online implementation and numerically stable. It is derived from the following paper +/// for calculating variance: +/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products". +/// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577. +/// +/// The algorithm has been analyzed here: +/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances". +/// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154. +/// +/// Though it is not covered in the original paper but is based on the same idea, as a result the algorithm is online, +/// parallelizable and numerically stable. + +#[derive(Debug)] +pub struct CovarianceAccumulator { + algo_const: f64, + mean1: f64, + mean2: f64, + count: u64, + stats_type: StatsType, +} + +impl CovarianceAccumulator { + /// Creates a new `CovarianceAccumulator` + pub fn try_new(s_type: StatsType) -> Result { + Ok(Self { + algo_const: 0_f64, + mean1: 0_f64, + mean2: 0_f64, + count: 0_u64, + stats_type: s_type, + }) + } + + pub fn get_count(&self) -> u64 { + self.count + } + + pub fn get_mean1(&self) -> f64 { + self.mean1 + } + + pub fn get_mean2(&self) -> f64 { + self.mean2 + } + + pub fn get_algo_const(&self) -> f64 { + self.algo_const + } +} + +impl Accumulator for CovarianceAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![ + ScalarValue::from(self.count), + ScalarValue::from(self.mean1), + ScalarValue::from(self.mean2), + ScalarValue::from(self.algo_const), + ]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values1 = &cast(&values[0], &DataType::Float64)?; + let values2 = &cast(&values[1], &DataType::Float64)?; + + let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten(); + let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten(); + + for i in 0..values1.len() { + let value1 = if values1.is_valid(i) { + arr1.next() + } else { + None + }; + let value2 = if values2.is_valid(i) { + arr2.next() + } else { + None + }; + + if value1.is_none() || value2.is_none() { + continue; + } + + let value1 = unwrap_or_internal_err!(value1); + let value2 = unwrap_or_internal_err!(value2); + let new_count = self.count + 1; + let delta1 = value1 - self.mean1; + let new_mean1 = delta1 / new_count as f64 + self.mean1; + let delta2 = value2 - self.mean2; + let new_mean2 = delta2 / new_count as f64 + self.mean2; + let new_c = delta1 * (value2 - new_mean2) + self.algo_const; + + self.count += 1; + self.mean1 = new_mean1; + self.mean2 = new_mean2; + self.algo_const = new_c; + } + + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values1 = &cast(&values[0], &DataType::Float64)?; + let values2 = &cast(&values[1], &DataType::Float64)?; + let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten(); + let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten(); + + for i in 0..values1.len() { + let value1 = if values1.is_valid(i) { + arr1.next() + } else { + None + }; + let value2 = if values2.is_valid(i) { + arr2.next() + } else { + None + }; + + if value1.is_none() || value2.is_none() { + continue; + } + + let value1 = unwrap_or_internal_err!(value1); + let value2 = unwrap_or_internal_err!(value2); + + let new_count = self.count - 1; + let delta1 = self.mean1 - value1; + let new_mean1 = delta1 / new_count as f64 + self.mean1; + let delta2 = self.mean2 - value2; + let new_mean2 = delta2 / new_count as f64 + self.mean2; + let new_c = self.algo_const - delta1 * (new_mean2 - value2); + + self.count -= 1; + self.mean1 = new_mean1; + self.mean2 = new_mean2; + self.algo_const = new_c; + } + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let counts = downcast_value!(states[0], UInt64Array); + let means1 = downcast_value!(states[1], Float64Array); + let means2 = downcast_value!(states[2], Float64Array); + let cs = downcast_value!(states[3], Float64Array); + + for i in 0..counts.len() { + let c = counts.value(i); + if c == 0_u64 { + continue; + } + let new_count = self.count + c; + let new_mean1 = self.mean1 * self.count as f64 / new_count as f64 + + means1.value(i) * c as f64 / new_count as f64; + let new_mean2 = self.mean2 * self.count as f64 / new_count as f64 + + means2.value(i) * c as f64 / new_count as f64; + let delta1 = self.mean1 - means1.value(i); + let delta2 = self.mean2 - means2.value(i); + let new_c = self.algo_const + + cs.value(i) + + delta1 * delta2 * self.count as f64 * c as f64 / new_count as f64; + + self.count = new_count; + self.mean1 = new_mean1; + self.mean2 = new_mean2; + self.algo_const = new_c; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let count = match self.stats_type { + StatsType::Population => self.count, + StatsType::Sample => { + if self.count > 0 { + self.count - 1 + } else { + self.count + } + } + }; + + if count == 0 { + Ok(ScalarValue::Float64(None)) + } else { + Ok(ScalarValue::Float64(Some(self.algo_const / count as f64))) + } + } + + fn size(&self) -> usize { + size_of_val(self) + } +} diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 8dc4cee87a3b..0b05713499a9 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -17,43 +17,47 @@ //! Defines the FIRST_VALUE/LAST_VALUE aggregations. +use std::any::Any; +use std::fmt::Debug; +use std::mem::size_of_val; +use std::sync::{Arc, OnceLock}; + use arrow::array::{ArrayRef, AsArray, BooleanArray}; -use arrow::compute::{self, lexsort_to_indices, SortColumn, SortOptions}; +use arrow::compute::{self, lexsort_to_indices, take_arrays, SortColumn}; use arrow::datatypes::{DataType, Field}; -use datafusion_common::utils::{compare_rows, get_arrayref_at_indices, get_row_at_idx}; +use datafusion_common::utils::{compare_rows, get_row_at_idx}; use datafusion_common::{ arrow_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::function::AccumulatorArgs; -use datafusion_expr::type_coercion::aggregates::NUMERICS; -use datafusion_expr::utils::format_state_name; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::utils::{format_state_name, AggregateOrderSensitivity}; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Expr, Signature, - TypeSignature, Volatility, + Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Documentation, Expr, + ExprFunctionExt, Signature, SortExpr, TypeSignature, Volatility, }; -use datafusion_physical_expr_common::aggregate::utils::{ - down_cast_any_ref, get_sort_options, ordering_fields, -}; -use datafusion_physical_expr_common::aggregate::AggregateExpr; -use datafusion_physical_expr_common::expressions; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; -use datafusion_physical_expr_common::utils::reverse_order_bys; -use sqlparser::ast::NullTreatment; -use std::any::Any; -use std::fmt::Debug; -use std::sync::Arc; - -make_udaf_function!( - FirstValue, - first_value, - "Returns the first value in a group of values.", - first_value_udaf -); +use datafusion_functions_aggregate_common::utils::get_sort_options; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexOrderingRef}; + +create_func!(FirstValue, first_value_udaf); + +/// Returns the first value in a group of values. +pub fn first_value(expression: Expr, order_by: Option>) -> Expr { + if let Some(order_by) = order_by { + first_value_udaf() + .call(vec![expression]) + .order_by(order_by) + .build() + // guaranteed to be `Expr::AggregateFunction` + .unwrap() + } else { + first_value_udaf().call(vec![expression]) + } +} pub struct FirstValue { signature: Signature, - aliases: Vec, + requirement_satisfied: bool, } impl Debug for FirstValue { @@ -75,17 +79,23 @@ impl Default for FirstValue { impl FirstValue { pub fn new() -> Self { Self { - aliases: vec![String::from("FIRST_VALUE")], signature: Signature::one_of( vec![ // TODO: we can introduce more strict signature that only numeric of array types are allowed TypeSignature::ArraySignature(ArrayFunctionSignature::Array), - TypeSignature::Uniform(1, NUMERICS.to_vec()), + TypeSignature::Numeric(1), + TypeSignature::Uniform(1, vec![DataType::Utf8]), ], Volatility::Immutable, ), + requirement_satisfied: false, } } + + fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { + self.requirement_satisfied = requirement_satisfied; + self + } } impl AggregateUDFImpl for FirstValue { @@ -94,7 +104,7 @@ impl AggregateUDFImpl for FirstValue { } fn name(&self) -> &str { - "FIRST_VALUE" + "first_value" } fn signature(&self) -> &Signature { @@ -106,68 +116,88 @@ impl AggregateUDFImpl for FirstValue { } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let mut all_sort_orders = vec![]; - - // Construct PhysicalSortExpr objects from Expr objects: - let mut sort_exprs = vec![]; - for expr in acc_args.sort_exprs { - if let Expr::Sort(sort) = expr { - if let Expr::Column(col) = sort.expr.as_ref() { - let name = &col.name; - let e = expressions::column::col(name, acc_args.schema)?; - sort_exprs.push(PhysicalSortExpr { - expr: e, - options: SortOptions { - descending: !sort.asc, - nulls_first: sort.nulls_first, - }, - }); - } - } - } - if !sort_exprs.is_empty() { - all_sort_orders.extend(sort_exprs); - } - - let ordering_req = all_sort_orders; - - let ordering_dtypes = ordering_req + let ordering_dtypes = acc_args + .ordering_req .iter() .map(|e| e.expr.data_type(acc_args.schema)) .collect::>>()?; - let requirement_satisfied = ordering_req.is_empty(); + // When requirement is empty, or it is signalled by outside caller that + // the ordering requirement is/will be satisfied. + let requirement_satisfied = + acc_args.ordering_req.is_empty() || self.requirement_satisfied; FirstValueAccumulator::try_new( - acc_args.data_type, + acc_args.return_type, &ordering_dtypes, - ordering_req, + LexOrdering::from_ref(acc_args.ordering_req), acc_args.ignore_nulls, ) .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) } - fn state_fields( - &self, - name: &str, - value_type: DataType, - ordering_fields: Vec, - ) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let mut fields = vec![Field::new( - format_state_name(name, "first_value"), - value_type, + format_state_name(args.name, "first_value"), + args.return_type.clone(), true, )]; - fields.extend(ordering_fields); + fields.extend(args.ordering_fields.to_vec()); fields.push(Field::new("is_set", DataType::Boolean, true)); Ok(fields) } fn aliases(&self) -> &[String] { - &self.aliases + &[] + } + + fn with_beneficial_ordering( + self: Arc, + beneficial_ordering: bool, + ) -> Result>> { + Ok(Some(Arc::new( + FirstValue::new().with_requirement_satisfied(beneficial_ordering), + ))) + } + + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + AggregateOrderSensitivity::Beneficial + } + + fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { + datafusion_expr::ReversedUDAF::Reversed(last_value_udaf()) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_first_value_doc()) } } +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_first_value_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description( + "Returns the first element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.", + ) + .with_syntax_example("first_value(expression [ORDER BY expression])") + .with_sql_example(r#"```sql +> SELECT first_value(column_name ORDER BY other_column) FROM table_name; ++-----------------------------------------------+ +| first_value(column_name ORDER BY other_column)| ++-----------------------------------------------+ +| first_element | ++-----------------------------------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) +} + #[derive(Debug)] pub struct FirstValueAccumulator { first: ScalarValue, @@ -243,7 +273,7 @@ impl FirstValueAccumulator { .iter() .zip(self.ordering_req.iter()) .map(|(values, req)| SortColumn { - values: values.clone(), + values: Arc::clone(values), options: Some(req.options), }) .collect::>(); @@ -285,7 +315,7 @@ impl Accumulator for FirstValueAccumulator { if compare_rows( &self.orderings, orderings, - &get_sort_options(&self.ordering_req), + &get_sort_options(self.ordering_req.as_ref()), )? .is_gt() { @@ -303,21 +333,23 @@ impl Accumulator for FirstValueAccumulator { let flags = states[is_set_idx].as_boolean(); let filtered_states = filter_states_according_to_is_set(states, flags)?; // 1..is_set_idx range corresponds to ordering section - let sort_cols = - convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req); + let sort_cols = convert_to_sort_cols( + &filtered_states[1..is_set_idx], + self.ordering_req.as_ref(), + ); let ordered_states = if sort_cols.is_empty() { // When no ordering is given, use the existing state as is: filtered_states } else { let indices = lexsort_to_indices(&sort_cols, None)?; - get_arrayref_at_indices(&filtered_states, &indices)? + take_arrays(&filtered_states, &indices, None)? }; if !ordered_states[0].is_empty() { let first_row = get_row_at_idx(&ordered_states, 0)?; // When collecting orderings, we exclude the is_set flag from the state. let first_ordering = &first_row[1..is_set_idx]; - let sort_options = get_sort_options(&self.ordering_req); + let sort_options = get_sort_options(self.ordering_req.as_ref()); // Either there is no existing value, or there is an earlier version in new data. if !self.is_set || compare_rows(&self.orderings, first_ordering, &sort_options)?.is_gt() @@ -336,363 +368,164 @@ impl Accumulator for FirstValueAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.first) + size_of_val(self) - size_of_val(&self.first) + self.first.size() + ScalarValue::size_of_vec(&self.orderings) - - std::mem::size_of_val(&self.orderings) + - size_of_val(&self.orderings) } } -/// TO BE DEPRECATED: Builtin FIRST_VALUE physical aggregate expression will be replaced by udf in the future -#[derive(Debug, Clone)] -pub struct FirstValuePhysicalExpr { - name: String, - input_data_type: DataType, - order_by_data_types: Vec, - expr: Arc, - ordering_req: LexOrdering, +make_udaf_expr_and_func!( + LastValue, + last_value, + "Returns the last value in a group of values.", + last_value_udaf +); + +pub struct LastValue { + signature: Signature, requirement_satisfied: bool, - ignore_nulls: bool, - state_fields: Vec, } -impl FirstValuePhysicalExpr { - /// Creates a new FIRST_VALUE aggregation function. - pub fn new( - expr: Arc, - name: impl Into, - input_data_type: DataType, - ordering_req: LexOrdering, - order_by_data_types: Vec, - state_fields: Vec, - ) -> Self { - let requirement_satisfied = ordering_req.is_empty(); - Self { - name: name.into(), - input_data_type, - order_by_data_types, - expr, - ordering_req, - requirement_satisfied, - ignore_nulls: false, - state_fields, - } - } - - pub fn with_ignore_nulls(mut self, ignore_nulls: bool) -> Self { - self.ignore_nulls = ignore_nulls; - self - } - - /// Returns the name of the aggregate expression. - pub fn name(&self) -> &str { - &self.name - } - - /// Returns the input data type of the aggregate expression. - pub fn input_data_type(&self) -> &DataType { - &self.input_data_type - } - - /// Returns the data types of the order-by columns. - pub fn order_by_data_types(&self) -> &Vec { - &self.order_by_data_types +impl Debug for LastValue { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("LastValue") + .field("name", &self.name()) + .field("signature", &self.signature) + .field("accumulator", &"") + .finish() } +} - /// Returns the expression associated with the aggregate function. - pub fn expr(&self) -> &Arc { - &self.expr +impl Default for LastValue { + fn default() -> Self { + Self::new() } +} - /// Returns the lexical ordering requirements of the aggregate expression. - pub fn ordering_req(&self) -> &LexOrdering { - &self.ordering_req +impl LastValue { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + // TODO: we can introduce more strict signature that only numeric of array types are allowed + TypeSignature::ArraySignature(ArrayFunctionSignature::Array), + TypeSignature::Numeric(1), + TypeSignature::Uniform(1, vec![DataType::Utf8]), + ], + Volatility::Immutable, + ), + requirement_satisfied: false, + } } - pub fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { + fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { self.requirement_satisfied = requirement_satisfied; self } - - pub fn convert_to_last(self) -> LastValuePhysicalExpr { - let mut name = format!("LAST{}", &self.name[5..]); - replace_order_by_clause(&mut name); - - let FirstValuePhysicalExpr { - expr, - input_data_type, - ordering_req, - order_by_data_types, - .. - } = self; - LastValuePhysicalExpr::new( - expr, - name, - input_data_type, - reverse_order_bys(&ordering_req), - order_by_data_types, - ) - } } -impl AggregateExpr for FirstValuePhysicalExpr { - /// Return a reference to Any that can be used for downcasting +impl AggregateUDFImpl for LastValue { fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.input_data_type.clone(), true)) - } - - fn create_accumulator(&self) -> Result> { - FirstValueAccumulator::try_new( - &self.input_data_type, - &self.order_by_data_types, - self.ordering_req.clone(), - self.ignore_nulls, - ) - .map(|acc| { - Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ - }) - } - - fn state_fields(&self) -> Result> { - if !self.state_fields.is_empty() { - return Ok(self.state_fields.clone()); - } - - let mut fields = vec![Field::new( - format_state_name(&self.name, "first_value"), - self.input_data_type.clone(), - true, - )]; - fields.extend(ordering_fields( - &self.ordering_req, - &self.order_by_data_types, - )); - fields.push(Field::new( - format_state_name(&self.name, "is_set"), - DataType::Boolean, - true, - )); - Ok(fields) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - (!self.ordering_req.is_empty()).then_some(&self.ordering_req) - } - fn name(&self) -> &str { - &self.name - } - - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone().convert_to_last())) - } - - fn create_sliding_accumulator(&self) -> Result> { - FirstValueAccumulator::try_new( - &self.input_data_type, - &self.order_by_data_types, - self.ordering_req.clone(), - self.ignore_nulls, - ) - .map(|acc| { - Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ - }) - } -} - -impl PartialEq for FirstValuePhysicalExpr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.input_data_type == x.input_data_type - && self.order_by_data_types == x.order_by_data_types - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -/// TO BE DEPRECATED: Builtin LAST_VALUE physical aggregate expression will be replaced by udf in the future -#[derive(Debug, Clone)] -pub struct LastValuePhysicalExpr { - name: String, - input_data_type: DataType, - order_by_data_types: Vec, - expr: Arc, - ordering_req: LexOrdering, - requirement_satisfied: bool, - ignore_nulls: bool, -} - -impl LastValuePhysicalExpr { - /// Creates a new LAST_VALUE aggregation function. - pub fn new( - expr: Arc, - name: impl Into, - input_data_type: DataType, - ordering_req: LexOrdering, - order_by_data_types: Vec, - ) -> Self { - let requirement_satisfied = ordering_req.is_empty(); - Self { - name: name.into(), - input_data_type, - order_by_data_types, - expr, - ordering_req, - requirement_satisfied, - ignore_nulls: false, - } - } - - pub fn with_ignore_nulls(mut self, ignore_nulls: bool) -> Self { - self.ignore_nulls = ignore_nulls; - self - } - - /// Returns the name of the aggregate expression. - pub fn name(&self) -> &str { - &self.name - } - - /// Returns the input data type of the aggregate expression. - pub fn input_data_type(&self) -> &DataType { - &self.input_data_type - } - - /// Returns the data types of the order-by columns. - pub fn order_by_data_types(&self) -> &Vec { - &self.order_by_data_types - } - - /// Returns the expression associated with the aggregate function. - pub fn expr(&self) -> &Arc { - &self.expr - } - - /// Returns the lexical ordering requirements of the aggregate expression. - pub fn ordering_req(&self) -> &LexOrdering { - &self.ordering_req + "last_value" } - pub fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { - self.requirement_satisfied = requirement_satisfied; - self + fn signature(&self) -> &Signature { + &self.signature } - pub fn convert_to_first(self) -> FirstValuePhysicalExpr { - let mut name = format!("FIRST{}", &self.name[4..]); - replace_order_by_clause(&mut name); - - let LastValuePhysicalExpr { - expr, - input_data_type, - ordering_req, - order_by_data_types, - .. - } = self; - FirstValuePhysicalExpr::new( - expr, - name, - input_data_type, - reverse_order_bys(&ordering_req), - order_by_data_types, - vec![], - ) + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) } -} -impl AggregateExpr for LastValuePhysicalExpr { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let ordering_dtypes = acc_args + .ordering_req + .iter() + .map(|e| e.expr.data_type(acc_args.schema)) + .collect::>>()?; - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.input_data_type.clone(), true)) - } + let requirement_satisfied = + acc_args.ordering_req.is_empty() || self.requirement_satisfied; - fn create_accumulator(&self) -> Result> { LastValueAccumulator::try_new( - &self.input_data_type, - &self.order_by_data_types, - self.ordering_req.clone(), - self.ignore_nulls, + acc_args.return_type, + &ordering_dtypes, + LexOrdering::from_ref(acc_args.ordering_req), + acc_args.ignore_nulls, ) - .map(|acc| { - Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ - }) + .map(|acc| Box::new(acc.with_requirement_satisfied(requirement_satisfied)) as _) } - fn state_fields(&self) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let StateFieldsArgs { + name, + input_types, + return_type: _, + ordering_fields, + is_distinct: _, + } = args; let mut fields = vec![Field::new( - format_state_name(&self.name, "last_value"), - self.input_data_type.clone(), + format_state_name(name, "last_value"), + input_types[0].clone(), true, )]; - fields.extend(ordering_fields( - &self.ordering_req, - &self.order_by_data_types, - )); - fields.push(Field::new( - format_state_name(&self.name, "is_set"), - DataType::Boolean, - true, - )); + fields.extend(ordering_fields.to_vec()); + fields.push(Field::new("is_set", DataType::Boolean, true)); Ok(fields) } - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] + fn aliases(&self) -> &[String] { + &[] } - fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - (!self.ordering_req.is_empty()).then_some(&self.ordering_req) + fn with_beneficial_ordering( + self: Arc, + beneficial_ordering: bool, + ) -> Result>> { + Ok(Some(Arc::new( + LastValue::new().with_requirement_satisfied(beneficial_ordering), + ))) } - fn name(&self) -> &str { - &self.name + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + AggregateOrderSensitivity::Beneficial } - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone().convert_to_first())) + fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { + datafusion_expr::ReversedUDAF::Reversed(first_value_udaf()) } - fn create_sliding_accumulator(&self) -> Result> { - LastValueAccumulator::try_new( - &self.input_data_type, - &self.order_by_data_types, - self.ordering_req.clone(), - self.ignore_nulls, - ) - .map(|acc| { - Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ - }) + fn documentation(&self) -> Option<&Documentation> { + Some(get_last_value_doc()) } } -impl PartialEq for LastValuePhysicalExpr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.input_data_type == x.input_data_type - && self.order_by_data_types == x.order_by_data_types - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } +fn get_last_value_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description( + "Returns the last element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group.", + ) + .with_syntax_example("last_value(expression [ORDER BY expression])") + .with_sql_example(r#"```sql +> SELECT last_value(column_name ORDER BY other_column) FROM table_name; ++-----------------------------------------------+ +| last_value(column_name ORDER BY other_column) | ++-----------------------------------------------+ +| last_element | ++-----------------------------------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) } #[derive(Debug)] @@ -766,7 +599,7 @@ impl LastValueAccumulator { // Take the reverse ordering requirement. This enables us to // use "fetch = 1" to get the last value. SortColumn { - values: values.clone(), + values: Arc::clone(values), options: Some(!req.options), } }) @@ -814,7 +647,7 @@ impl Accumulator for LastValueAccumulator { if compare_rows( &self.orderings, orderings, - &get_sort_options(&self.ordering_req), + &get_sort_options(self.ordering_req.as_ref()), )? .is_lt() { @@ -832,15 +665,17 @@ impl Accumulator for LastValueAccumulator { let flags = states[is_set_idx].as_boolean(); let filtered_states = filter_states_according_to_is_set(states, flags)?; // 1..is_set_idx range corresponds to ordering section - let sort_cols = - convert_to_sort_cols(&filtered_states[1..is_set_idx], &self.ordering_req); + let sort_cols = convert_to_sort_cols( + &filtered_states[1..is_set_idx], + self.ordering_req.as_ref(), + ); let ordered_states = if sort_cols.is_empty() { // When no ordering is given, use existing state as is: filtered_states } else { let indices = lexsort_to_indices(&sort_cols, None)?; - get_arrayref_at_indices(&filtered_states, &indices)? + take_arrays(&filtered_states, &indices, None)? }; if !ordered_states[0].is_empty() { @@ -848,7 +683,7 @@ impl Accumulator for LastValueAccumulator { let last_row = get_row_at_idx(&ordered_states, last_idx)?; // When collecting orderings, we exclude the is_set flag from the state. let last_ordering = &last_row[1..is_set_idx]; - let sort_options = get_sort_options(&self.ordering_req); + let sort_options = get_sort_options(self.ordering_req.as_ref()); // Either there is no existing value, or there is a newer (latest) // version in the new data: if !self.is_set @@ -868,10 +703,10 @@ impl Accumulator for LastValueAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.last) + size_of_val(self) - size_of_val(&self.last) + self.last.size() + ScalarValue::size_of_vec(&self.orderings) - - std::mem::size_of_val(&self.orderings) + - size_of_val(&self.orderings) } } @@ -890,42 +725,17 @@ fn filter_states_according_to_is_set( /// Combines array refs and their corresponding orderings to construct `SortColumn`s. fn convert_to_sort_cols( arrs: &[ArrayRef], - sort_exprs: &[PhysicalSortExpr], + sort_exprs: LexOrderingRef, ) -> Vec { arrs.iter() .zip(sort_exprs.iter()) .map(|(item, sort_expr)| SortColumn { - values: item.clone(), + values: Arc::clone(item), options: Some(sort_expr.options), }) .collect::>() } -fn replace_order_by_clause(order_by: &mut String) { - let suffixes = [ - (" DESC NULLS FIRST]", " ASC NULLS LAST]"), - (" ASC NULLS FIRST]", " DESC NULLS LAST]"), - (" DESC NULLS LAST]", " ASC NULLS FIRST]"), - (" ASC NULLS LAST]", " DESC NULLS FIRST]"), - ]; - - if let Some(start) = order_by.find("ORDER BY [") { - if let Some(end) = order_by[start..].find(']') { - let order_by_start = start + 9; - let order_by_end = start + end; - - let column_order = &order_by[order_by_start..=order_by_end]; - for &(suffix, replacement) in &suffixes { - if column_order.ends_with(suffix) { - let new_order = column_order.replace(suffix, replacement); - order_by.replace_range(order_by_start..=order_by_end, &new_order); - break; - } - } - } - } -} - #[cfg(test)] mod tests { use arrow::array::Int64Array; @@ -934,10 +744,18 @@ mod tests { #[test] fn test_first_last_value_value() -> Result<()> { - let mut first_accumulator = - FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; - let mut last_accumulator = - LastValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; + let mut first_accumulator = FirstValueAccumulator::try_new( + &DataType::Int64, + &[], + LexOrdering::default(), + false, + )?; + let mut last_accumulator = LastValueAccumulator::try_new( + &DataType::Int64, + &[], + LexOrdering::default(), + false, + )?; // first value in the tuple is start of the range (inclusive), // second value in the tuple is end of the range (exclusive) let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)]; @@ -951,7 +769,7 @@ mod tests { for arr in arrs { // Once first_value is set, accumulator should remember it. // It shouldn't update first_value for each new batch - first_accumulator.update_batch(&[arr.clone()])?; + first_accumulator.update_batch(&[Arc::clone(&arr)])?; // last_value should be updated for each new batch. last_accumulator.update_batch(&[arr])?; } @@ -974,15 +792,23 @@ mod tests { .collect::>(); // FirstValueAccumulator - let mut first_accumulator = - FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; - - first_accumulator.update_batch(&[arrs[0].clone()])?; + let mut first_accumulator = FirstValueAccumulator::try_new( + &DataType::Int64, + &[], + LexOrdering::default(), + false, + )?; + + first_accumulator.update_batch(&[Arc::clone(&arrs[0])])?; let state1 = first_accumulator.state()?; - let mut first_accumulator = - FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; - first_accumulator.update_batch(&[arrs[1].clone()])?; + let mut first_accumulator = FirstValueAccumulator::try_new( + &DataType::Int64, + &[], + LexOrdering::default(), + false, + )?; + first_accumulator.update_batch(&[Arc::clone(&arrs[1])])?; let state2 = first_accumulator.state()?; assert_eq!(state1.len(), state2.len()); @@ -990,29 +816,41 @@ mod tests { let mut states = vec![]; for idx in 0..state1.len() { - states.push(arrow::compute::concat(&[ + states.push(compute::concat(&[ &state1[idx].to_array()?, &state2[idx].to_array()?, ])?); } - let mut first_accumulator = - FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; + let mut first_accumulator = FirstValueAccumulator::try_new( + &DataType::Int64, + &[], + LexOrdering::default(), + false, + )?; first_accumulator.merge_batch(&states)?; let merged_state = first_accumulator.state()?; assert_eq!(merged_state.len(), state1.len()); // LastValueAccumulator - let mut last_accumulator = - LastValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; - - last_accumulator.update_batch(&[arrs[0].clone()])?; + let mut last_accumulator = LastValueAccumulator::try_new( + &DataType::Int64, + &[], + LexOrdering::default(), + false, + )?; + + last_accumulator.update_batch(&[Arc::clone(&arrs[0])])?; let state1 = last_accumulator.state()?; - let mut last_accumulator = - LastValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; - last_accumulator.update_batch(&[arrs[1].clone()])?; + let mut last_accumulator = LastValueAccumulator::try_new( + &DataType::Int64, + &[], + LexOrdering::default(), + false, + )?; + last_accumulator.update_batch(&[Arc::clone(&arrs[1])])?; let state2 = last_accumulator.state()?; assert_eq!(state1.len(), state2.len()); @@ -1020,14 +858,18 @@ mod tests { let mut states = vec![]; for idx in 0..state1.len() { - states.push(arrow::compute::concat(&[ + states.push(compute::concat(&[ &state1[idx].to_array()?, &state2[idx].to_array()?, ])?); } - let mut last_accumulator = - LastValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; + let mut last_accumulator = LastValueAccumulator::try_new( + &DataType::Int64, + &[], + LexOrdering::default(), + false, + )?; last_accumulator.merge_batch(&states)?; let merged_state = last_accumulator.state()?; diff --git a/datafusion/functions-aggregate/src/grouping.rs b/datafusion/functions-aggregate/src/grouping.rs new file mode 100644 index 000000000000..27949aa3df27 --- /dev/null +++ b/datafusion/functions-aggregate/src/grouping.rs @@ -0,0 +1,134 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines physical expressions that can evaluated at runtime during query execution + +use std::any::Any; +use std::fmt; +use std::sync::OnceLock; + +use arrow::datatypes::DataType; +use arrow::datatypes::Field; +use datafusion_common::{not_impl_err, Result}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; +use datafusion_expr::function::AccumulatorArgs; +use datafusion_expr::function::StateFieldsArgs; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, +}; + +make_udaf_expr_and_func!( + Grouping, + grouping, + expression, + "Returns 1 if the data is aggregated across the specified column or 0 for not aggregated in the result set.", + grouping_udaf +); + +pub struct Grouping { + signature: Signature, +} + +impl fmt::Debug for Grouping { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("Grouping") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for Grouping { + fn default() -> Self { + Self::new() + } +} + +impl Grouping { + /// Create a new GROUPING aggregate function. + pub fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for Grouping { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "grouping" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + Ok(vec![Field::new( + format_state_name(args.name, "grouping"), + DataType::Int32, + true, + )]) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + not_impl_err!( + "physical plan is not yet implemented for GROUPING aggregate function" + ) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_grouping_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_grouping_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description( + "Returns 1 if the data is aggregated across the specified column, or 0 if it is not aggregated in the result set.", + ) + .with_syntax_example("grouping(expression)") + .with_sql_example(r#"```sql +> SELECT column_name, GROUPING(column_name) AS group_column + FROM table_name + GROUP BY GROUPING SETS ((column_name), ()); ++-------------+-------------+ +| column_name | group_column | ++-------------+-------------+ +| value1 | 0 | +| value2 | 0 | +| NULL | 1 | ++-------------+-------------+ +```"#, + ) + .with_argument("expression", "Expression to evaluate whether data is aggregated across the specified column. Can be a constant, column, or function.") + .build() + .unwrap() + }) +} diff --git a/datafusion/physical-expr/src/aggregate/hyperloglog.rs b/datafusion/functions-aggregate/src/hyperloglog.rs similarity index 99% rename from datafusion/physical-expr/src/aggregate/hyperloglog.rs rename to datafusion/functions-aggregate/src/hyperloglog.rs index 657a7b9f7f21..3074889eab23 100644 --- a/datafusion/physical-expr/src/aggregate/hyperloglog.rs +++ b/datafusion/functions-aggregate/src/hyperloglog.rs @@ -20,7 +20,7 @@ //! `hyperloglog` is a module that contains a modified version //! of [redis's implementation](https://github.com/redis/redis/blob/4930d19e70c391750479951022e207e19111eb55/src/hyperloglog.c) //! with some modification based on strong assumption of usage -//! within datafusion, so that [`datafusion_expr::approx_distinct`] function can +//! within datafusion, so that function can //! be efficiently implemented. //! //! Specifically, like Redis's version, this HLL structure uses diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 8016b76889f7..ca0276d326a4 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -14,6 +14,8 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] //! Aggregate Function packages for [DataFusion]. //! @@ -50,13 +52,37 @@ //! 3. Add a new feature to `Cargo.toml`, with any optional dependencies //! //! 4. Use the `make_package!` macro to expose the module when the -//! feature is enabled. +//! feature is enabled. #[macro_use] pub mod macros; +pub mod approx_distinct; +pub mod array_agg; +pub mod correlation; +pub mod count; +pub mod covariance; pub mod first_last; +pub mod hyperloglog; +pub mod median; +pub mod min_max; +pub mod regr; +pub mod stddev; +pub mod sum; +pub mod variance; +pub mod approx_median; +pub mod approx_percentile_cont; +pub mod approx_percentile_cont_with_weight; +pub mod average; +pub mod bit_and_or_xor; +pub mod bool_and_or; +pub mod grouping; +pub mod nth_value; +pub mod string_agg; + +use crate::approx_percentile_cont::approx_percentile_cont_udaf; +use crate::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight_udaf; use datafusion_common::Result; use datafusion_execution::FunctionRegistry; use datafusion_expr::AggregateUDF; @@ -65,12 +91,91 @@ use std::sync::Arc; /// Fluent-style API for creating `Expr`s pub mod expr_fn { + pub use super::approx_distinct::approx_distinct; + pub use super::approx_median::approx_median; + pub use super::approx_percentile_cont::approx_percentile_cont; + pub use super::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight; + pub use super::array_agg::array_agg; + pub use super::average::avg; + pub use super::bit_and_or_xor::bit_and; + pub use super::bit_and_or_xor::bit_or; + pub use super::bit_and_or_xor::bit_xor; + pub use super::bool_and_or::bool_and; + pub use super::bool_and_or::bool_or; + pub use super::correlation::corr; + pub use super::count::count; + pub use super::count::count_distinct; + pub use super::covariance::covar_pop; + pub use super::covariance::covar_samp; pub use super::first_last::first_value; + pub use super::first_last::last_value; + pub use super::grouping::grouping; + pub use super::median::median; + pub use super::min_max::max; + pub use super::min_max::min; + pub use super::nth_value::nth_value; + pub use super::regr::regr_avgx; + pub use super::regr::regr_avgy; + pub use super::regr::regr_count; + pub use super::regr::regr_intercept; + pub use super::regr::regr_r2; + pub use super::regr::regr_slope; + pub use super::regr::regr_sxx; + pub use super::regr::regr_sxy; + pub use super::regr::regr_syy; + pub use super::stddev::stddev; + pub use super::stddev::stddev_pop; + pub use super::sum::sum; + pub use super::variance::var_pop; + pub use super::variance::var_sample; +} + +/// Returns all default aggregate functions +pub fn all_default_aggregate_functions() -> Vec> { + vec![ + array_agg::array_agg_udaf(), + first_last::first_value_udaf(), + first_last::last_value_udaf(), + covariance::covar_samp_udaf(), + covariance::covar_pop_udaf(), + correlation::corr_udaf(), + sum::sum_udaf(), + min_max::max_udaf(), + min_max::min_udaf(), + median::median_udaf(), + count::count_udaf(), + regr::regr_slope_udaf(), + regr::regr_intercept_udaf(), + regr::regr_count_udaf(), + regr::regr_r2_udaf(), + regr::regr_avgx_udaf(), + regr::regr_avgy_udaf(), + regr::regr_sxx_udaf(), + regr::regr_syy_udaf(), + regr::regr_sxy_udaf(), + variance::var_samp_udaf(), + variance::var_pop_udaf(), + stddev::stddev_udaf(), + stddev::stddev_pop_udaf(), + approx_median::approx_median_udaf(), + approx_distinct::approx_distinct_udaf(), + approx_percentile_cont_udaf(), + approx_percentile_cont_with_weight_udaf(), + string_agg::string_agg_udaf(), + bit_and_or_xor::bit_and_udaf(), + bit_and_or_xor::bit_or_udaf(), + bit_and_or_xor::bit_xor_udaf(), + bool_and_or::bool_and_udaf(), + bool_and_or::bool_or_udaf(), + average::avg_udaf(), + grouping::grouping_udaf(), + nth_value::nth_value_udaf(), + ] } /// Registers all enabled packages with a [`FunctionRegistry`] pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { - let functions: Vec> = vec![first_last::first_value_udaf()]; + let functions: Vec> = all_default_aggregate_functions(); functions.into_iter().try_for_each(|udf| { let existing_udaf = registry.register_udaf(udf)?; @@ -82,3 +187,36 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { Ok(()) } + +#[cfg(test)] +mod tests { + use crate::all_default_aggregate_functions; + use datafusion_common::Result; + use std::collections::HashSet; + + #[test] + fn test_no_duplicate_name() -> Result<()> { + let mut names = HashSet::new(); + let migrated_functions = ["array_agg", "count", "max", "min"]; + for func in all_default_aggregate_functions() { + // TODO: remove this + // These functions are in intermediate migration state, skip them + if migrated_functions.contains(&func.name().to_lowercase().as_str()) { + continue; + } + assert!( + names.insert(func.name().to_string().to_lowercase()), + "duplicate function name: {}", + func.name() + ); + for alias in func.aliases() { + assert!( + names.insert(alias.to_string().to_lowercase()), + "duplicate function name: {}", + alias + ); + } + } + Ok(()) + } +} diff --git a/datafusion/functions-aggregate/src/macros.rs b/datafusion/functions-aggregate/src/macros.rs index 04f9fecb8b19..ffb5183278e6 100644 --- a/datafusion/functions-aggregate/src/macros.rs +++ b/datafusion/functions-aggregate/src/macros.rs @@ -15,41 +15,67 @@ // specific language governing permissions and limitations // under the License. -macro_rules! make_udaf_function { +macro_rules! make_udaf_expr { + ($EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { + // "fluent expr_fn" style function + #[doc = $DOC] + pub fn $EXPR_FN( + $($arg: datafusion_expr::Expr,)* + ) -> datafusion_expr::Expr { + datafusion_expr::Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + $AGGREGATE_UDF_FN(), + vec![$($arg),*], + false, + None, + None, + None, + )) + } + }; +} + +macro_rules! make_udaf_expr_and_func { + ($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { + make_udaf_expr!($EXPR_FN, $($arg)*, $DOC, $AGGREGATE_UDF_FN); + create_func!($UDAF, $AGGREGATE_UDF_FN); + }; ($UDAF:ty, $EXPR_FN:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { - paste::paste! { - // "fluent expr_fn" style function - #[doc = $DOC] - pub fn $EXPR_FN( - args: Vec, - distinct: bool, - filter: Option>, - order_by: Option>, - null_treatment: Option - ) -> Expr { - Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( - $AGGREGATE_UDF_FN(), - args, - distinct, - filter, - order_by, - null_treatment, - )) - } + // "fluent expr_fn" style function + #[doc = $DOC] + pub fn $EXPR_FN( + args: Vec, + ) -> datafusion_expr::Expr { + datafusion_expr::Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + $AGGREGATE_UDF_FN(), + args, + false, + None, + None, + None, + )) + } + + create_func!($UDAF, $AGGREGATE_UDF_FN); + }; +} +macro_rules! create_func { + ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => { + create_func!($UDAF, $AGGREGATE_UDF_FN, <$UDAF>::default()); + }; + ($UDAF:ty, $AGGREGATE_UDF_FN:ident, $CREATE:expr) => { + paste::paste! { /// Singleton instance of [$UDAF], ensures the UDAF is only created once /// named STATIC_$(UDAF). For example `STATIC_FirstValue` #[allow(non_upper_case_globals)] static [< STATIC_ $UDAF >]: std::sync::OnceLock> = std::sync::OnceLock::new(); - /// AggregateFunction that returns a [AggregateUDF] for [$UDAF] - /// - /// [AggregateUDF]: datafusion_expr::AggregateUDF + #[doc = concat!("AggregateFunction that returns a [`AggregateUDF`](datafusion_expr::AggregateUDF) for [`", stringify!($UDAF), "`]")] pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc { [< STATIC_ $UDAF >] .get_or_init(|| { - std::sync::Arc::new(datafusion_expr::AggregateUDF::from(<$UDAF>::default())) + std::sync::Arc::new(datafusion_expr::AggregateUDF::from($CREATE)) }) .clone() } diff --git a/datafusion/functions-aggregate/src/median.rs b/datafusion/functions-aggregate/src/median.rs new file mode 100644 index 000000000000..ff0a930d490b --- /dev/null +++ b/datafusion/functions-aggregate/src/median.rs @@ -0,0 +1,331 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashSet; +use std::fmt::{Debug, Formatter}; +use std::mem::{size_of, size_of_val}; +use std::sync::{Arc, OnceLock}; + +use arrow::array::{downcast_integer, ArrowNumericType}; +use arrow::{ + array::{ArrayRef, AsArray}, + datatypes::{ + DataType, Decimal128Type, Decimal256Type, Field, Float16Type, Float32Type, + Float64Type, + }, +}; + +use arrow::array::Array; +use arrow::array::ArrowNativeTypeOp; +use arrow::datatypes::ArrowNativeType; + +use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; +use datafusion_expr::function::StateFieldsArgs; +use datafusion_expr::{ + function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, + Documentation, Signature, Volatility, +}; +use datafusion_functions_aggregate_common::utils::Hashable; + +make_udaf_expr_and_func!( + Median, + median, + expression, + "Computes the median of a set of numbers", + median_udaf +); + +/// MEDIAN aggregate expression. If using the non-distinct variation, then this uses a +/// lot of memory because all values need to be stored in memory before a result can be +/// computed. If an approximation is sufficient then APPROX_MEDIAN provides a much more +/// efficient solution. +/// +/// If using the distinct variation, the memory usage will be similarly high if the +/// cardinality is high as it stores all distinct values in memory before computing the +/// result, but if cardinality is low then memory usage will also be lower. +pub struct Median { + signature: Signature, +} + +impl Debug for Median { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + f.debug_struct("Median") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for Median { + fn default() -> Self { + Self::new() + } +} + +impl Median { + pub fn new() -> Self { + Self { + signature: Signature::numeric(1, Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for Median { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "median" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + //Intermediate state is a list of the elements we have collected so far + let field = Field::new("item", args.input_types[0].clone(), true); + let state_name = if args.is_distinct { + "distinct_median" + } else { + "median" + }; + + Ok(vec![Field::new( + format_state_name(args.name, state_name), + DataType::List(Arc::new(field)), + true, + )]) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + macro_rules! helper { + ($t:ty, $dt:expr) => { + if acc_args.is_distinct { + Ok(Box::new(DistinctMedianAccumulator::<$t> { + data_type: $dt.clone(), + distinct_values: HashSet::new(), + })) + } else { + Ok(Box::new(MedianAccumulator::<$t> { + data_type: $dt.clone(), + all_values: vec![], + })) + } + }; + } + + let dt = acc_args.exprs[0].data_type(acc_args.schema)?; + downcast_integer! { + dt => (helper, dt), + DataType::Float16 => helper!(Float16Type, dt), + DataType::Float32 => helper!(Float32Type, dt), + DataType::Float64 => helper!(Float64Type, dt), + DataType::Decimal128(_, _) => helper!(Decimal128Type, dt), + DataType::Decimal256(_, _) => helper!(Decimal256Type, dt), + _ => Err(DataFusionError::NotImplemented(format!( + "MedianAccumulator not supported for {} with {}", + acc_args.name, + dt, + ))), + } + } + + fn aliases(&self) -> &[String] { + &[] + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_median_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_median_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description("Returns the median value in the specified column.") + .with_syntax_example("median(expression)") + .with_sql_example( + r#"```sql +> SELECT median(column_name) FROM table_name; ++----------------------+ +| median(column_name) | ++----------------------+ +| 45.5 | ++----------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) +} + +/// The median accumulator accumulates the raw input values +/// as `ScalarValue`s +/// +/// The intermediate state is represented as a List of scalar values updated by +/// `merge_batch` and a `Vec` of `ArrayRef` that are converted to scalar values +/// in the final evaluation step so that we avoid expensive conversions and +/// allocations during `update_batch`. +struct MedianAccumulator { + data_type: DataType, + all_values: Vec, +} + +impl Debug for MedianAccumulator { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "MedianAccumulator({})", self.data_type) + } +} + +impl Accumulator for MedianAccumulator { + fn state(&mut self) -> Result> { + let all_values = self + .all_values + .iter() + .map(|x| ScalarValue::new_primitive::(Some(*x), &self.data_type)) + .collect::>>()?; + + let arr = ScalarValue::new_list_nullable(&all_values, &self.data_type); + Ok(vec![ScalarValue::List(arr)]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = values[0].as_primitive::(); + self.all_values.reserve(values.len() - values.null_count()); + self.all_values.extend(values.iter().flatten()); + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let array = states[0].as_list::(); + for v in array.iter().flatten() { + self.update_batch(&[v])? + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let d = std::mem::take(&mut self.all_values); + let median = calculate_median::(d); + ScalarValue::new_primitive::(median, &self.data_type) + } + + fn size(&self) -> usize { + size_of_val(self) + self.all_values.capacity() * size_of::() + } +} + +/// The distinct median accumulator accumulates the raw input values +/// as `ScalarValue`s +/// +/// The intermediate state is represented as a List of scalar values updated by +/// `merge_batch` and a `Vec` of `ArrayRef` that are converted to scalar values +/// in the final evaluation step so that we avoid expensive conversions and +/// allocations during `update_batch`. +struct DistinctMedianAccumulator { + data_type: DataType, + distinct_values: HashSet>, +} + +impl Debug for DistinctMedianAccumulator { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "DistinctMedianAccumulator({})", self.data_type) + } +} + +impl Accumulator for DistinctMedianAccumulator { + fn state(&mut self) -> Result> { + let all_values = self + .distinct_values + .iter() + .map(|x| ScalarValue::new_primitive::(Some(x.0), &self.data_type)) + .collect::>>()?; + + let arr = ScalarValue::new_list_nullable(&all_values, &self.data_type); + Ok(vec![ScalarValue::List(arr)]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + let array = values[0].as_primitive::(); + match array.nulls().filter(|x| x.null_count() > 0) { + Some(n) => { + for idx in n.valid_indices() { + self.distinct_values.insert(Hashable(array.value(idx))); + } + } + None => array.values().iter().for_each(|x| { + self.distinct_values.insert(Hashable(*x)); + }), + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let array = states[0].as_list::(); + for v in array.iter().flatten() { + self.update_batch(&[v])? + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let d = std::mem::take(&mut self.distinct_values) + .into_iter() + .map(|v| v.0) + .collect::>(); + let median = calculate_median::(d); + ScalarValue::new_primitive::(median, &self.data_type) + } + + fn size(&self) -> usize { + size_of_val(self) + self.distinct_values.capacity() * size_of::() + } +} + +fn calculate_median( + mut values: Vec, +) -> Option { + let cmp = |x: &T::Native, y: &T::Native| x.compare(*y); + + let len = values.len(); + if len == 0 { + None + } else if len % 2 == 0 { + let (low, high, _) = values.select_nth_unstable_by(len / 2, cmp); + let (_, low, _) = low.select_nth_unstable_by(low.len() - 1, cmp); + let median = low.add_wrapping(*high).div_wrapping(T::Native::usize_as(2)); + Some(median) + } else { + let (_, median, _) = values.select_nth_unstable_by(len / 2, cmp); + Some(*median) + } +} diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs new file mode 100644 index 000000000000..b4256508e351 --- /dev/null +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -0,0 +1,1805 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`Max`] and [`MaxAccumulator`] accumulator for the `max` function +//! [`Min`] and [`MinAccumulator`] accumulator for the `min` function + +mod min_max_bytes; + +use arrow::array::{ + ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, + Decimal128Array, Decimal256Array, Float16Array, Float32Array, Float64Array, + Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray, + IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeBinaryArray, + LargeStringArray, StringArray, StringViewArray, Time32MillisecondArray, + Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, + TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, + TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, +}; +use arrow::compute; +use arrow::datatypes::{ + DataType, Decimal128Type, Decimal256Type, Float16Type, Float32Type, Float64Type, + Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, + UInt8Type, +}; +use arrow_schema::IntervalUnit; +use datafusion_common::stats::Precision; +use datafusion_common::{ + downcast_value, exec_err, internal_err, ColumnStatistics, DataFusionError, Result, +}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; +use datafusion_physical_expr::expressions; +use std::fmt::Debug; + +use arrow::datatypes::i256; +use arrow::datatypes::{ + Date32Type, Date64Type, Time32MillisecondType, Time32SecondType, + Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, +}; + +use crate::min_max::min_max_bytes::MinMaxBytesAccumulator; +use datafusion_common::ScalarValue; +use datafusion_expr::{ + function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Documentation, Signature, + Volatility, +}; +use datafusion_expr::{GroupsAccumulator, StatisticsArgs}; +use half::f16; +use std::mem::size_of_val; +use std::ops::Deref; +use std::sync::OnceLock; + +fn get_min_max_result_type(input_types: &[DataType]) -> Result> { + // make sure that the input types only has one element. + if input_types.len() != 1 { + return exec_err!( + "min/max was called with {} arguments. It requires only 1.", + input_types.len() + ); + } + // min and max support the dictionary data type + // unpack the dictionary to get the value + match &input_types[0] { + DataType::Dictionary(_, dict_value_type) => { + // TODO add checker, if the value type is complex data type + Ok(vec![dict_value_type.deref().clone()]) + } + // TODO add checker for datatype which min and max supported + // For example, the `Struct` and `Map` type are not supported in the MIN and MAX function + _ => Ok(input_types.to_vec()), + } +} + +// MAX aggregate UDF +#[derive(Debug)] +pub struct Max { + signature: Signature, +} + +impl Max { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl Default for Max { + fn default() -> Self { + Self::new() + } +} +/// Creates a [`PrimitiveGroupsAccumulator`] for computing `MAX` +/// the specified [`ArrowPrimitiveType`]. +/// +/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType +macro_rules! primitive_max_accumulator { + ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{ + Ok(Box::new( + PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new($DATA_TYPE, |cur, new| { + if *cur < new { + *cur = new + } + }) + // Initialize each accumulator to $NATIVE::MIN + .with_starting_value($NATIVE::MIN), + )) + }}; +} + +/// Creates a [`PrimitiveGroupsAccumulator`] for computing `MIN` +/// the specified [`ArrowPrimitiveType`]. +/// +/// +/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType +macro_rules! primitive_min_accumulator { + ($DATA_TYPE:ident, $NATIVE:ident, $PRIMTYPE:ident) => {{ + Ok(Box::new( + PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(&$DATA_TYPE, |cur, new| { + if *cur > new { + *cur = new + } + }) + // Initialize each accumulator to $NATIVE::MAX + .with_starting_value($NATIVE::MAX), + )) + }}; +} + +trait FromColumnStatistics { + fn value_from_column_statistics( + &self, + stats: &ColumnStatistics, + ) -> Option; + + fn value_from_statistics( + &self, + statistics_args: &StatisticsArgs, + ) -> Option { + if let Precision::Exact(num_rows) = &statistics_args.statistics.num_rows { + match *num_rows { + 0 => return ScalarValue::try_from(statistics_args.return_type).ok(), + value if value > 0 => { + let col_stats = &statistics_args.statistics.column_statistics; + if statistics_args.exprs.len() == 1 { + // TODO optimize with exprs other than Column + if let Some(col_expr) = statistics_args.exprs[0] + .as_any() + .downcast_ref::() + { + return self.value_from_column_statistics( + &col_stats[col_expr.index()], + ); + } + } + } + _ => {} + } + } + None + } +} + +impl FromColumnStatistics for Max { + fn value_from_column_statistics( + &self, + col_stats: &ColumnStatistics, + ) -> Option { + if let Precision::Exact(ref val) = col_stats.max_value { + if !val.is_null() { + return Some(val.clone()); + } + } + None + } +} + +impl AggregateUDFImpl for Max { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "max" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].to_owned()) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(MaxAccumulator::try_new(acc_args.return_type)?)) + } + + fn aliases(&self) -> &[String] { + &[] + } + + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + use DataType::*; + matches!( + args.return_type, + Int8 | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Float16 + | Float32 + | Float64 + | Decimal128(_, _) + | Decimal256(_, _) + | Date32 + | Date64 + | Time32(_) + | Time64(_) + | Timestamp(_, _) + | Utf8 + | LargeUtf8 + | Utf8View + | Binary + | LargeBinary + | BinaryView + ) + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + use DataType::*; + use TimeUnit::*; + let data_type = args.return_type; + match data_type { + Int8 => primitive_max_accumulator!(data_type, i8, Int8Type), + Int16 => primitive_max_accumulator!(data_type, i16, Int16Type), + Int32 => primitive_max_accumulator!(data_type, i32, Int32Type), + Int64 => primitive_max_accumulator!(data_type, i64, Int64Type), + UInt8 => primitive_max_accumulator!(data_type, u8, UInt8Type), + UInt16 => primitive_max_accumulator!(data_type, u16, UInt16Type), + UInt32 => primitive_max_accumulator!(data_type, u32, UInt32Type), + UInt64 => primitive_max_accumulator!(data_type, u64, UInt64Type), + Float16 => { + primitive_max_accumulator!(data_type, f16, Float16Type) + } + Float32 => { + primitive_max_accumulator!(data_type, f32, Float32Type) + } + Float64 => { + primitive_max_accumulator!(data_type, f64, Float64Type) + } + Date32 => primitive_max_accumulator!(data_type, i32, Date32Type), + Date64 => primitive_max_accumulator!(data_type, i64, Date64Type), + Time32(Second) => { + primitive_max_accumulator!(data_type, i32, Time32SecondType) + } + Time32(Millisecond) => { + primitive_max_accumulator!(data_type, i32, Time32MillisecondType) + } + Time64(Microsecond) => { + primitive_max_accumulator!(data_type, i64, Time64MicrosecondType) + } + Time64(Nanosecond) => { + primitive_max_accumulator!(data_type, i64, Time64NanosecondType) + } + Timestamp(Second, _) => { + primitive_max_accumulator!(data_type, i64, TimestampSecondType) + } + Timestamp(Millisecond, _) => { + primitive_max_accumulator!(data_type, i64, TimestampMillisecondType) + } + Timestamp(Microsecond, _) => { + primitive_max_accumulator!(data_type, i64, TimestampMicrosecondType) + } + Timestamp(Nanosecond, _) => { + primitive_max_accumulator!(data_type, i64, TimestampNanosecondType) + } + Decimal128(_, _) => { + primitive_max_accumulator!(data_type, i128, Decimal128Type) + } + Decimal256(_, _) => { + primitive_max_accumulator!(data_type, i256, Decimal256Type) + } + Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => { + Ok(Box::new(MinMaxBytesAccumulator::new_max(data_type.clone()))) + } + + // This is only reached if groups_accumulator_supported is out of sync + _ => internal_err!("GroupsAccumulator not supported for max({})", data_type), + } + } + + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + Ok(Box::new(SlidingMaxAccumulator::try_new(args.return_type)?)) + } + + fn is_descending(&self) -> Option { + Some(true) + } + + fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity { + datafusion_expr::utils::AggregateOrderSensitivity::Insensitive + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + get_min_max_result_type(arg_types) + } + fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { + datafusion_expr::ReversedUDAF::Identical + } + fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option { + self.value_from_statistics(statistics_args) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_max_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_max_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description("Returns the maximum value in the specified column.") + .with_syntax_example("max(expression)") + .with_sql_example( + r#"```sql +> SELECT max(column_name) FROM table_name; ++----------------------+ +| max(column_name) | ++----------------------+ +| 150 | ++----------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) +} + +// Statically-typed version of min/max(array) -> ScalarValue for string types +macro_rules! typed_min_max_batch_string { + ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ + let array = downcast_value!($VALUES, $ARRAYTYPE); + let value = compute::$OP(array); + let value = value.and_then(|e| Some(e.to_string())); + ScalarValue::$SCALAR(value) + }}; +} +// Statically-typed version of min/max(array) -> ScalarValue for binay types. +macro_rules! typed_min_max_batch_binary { + ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ + let array = downcast_value!($VALUES, $ARRAYTYPE); + let value = compute::$OP(array); + let value = value.and_then(|e| Some(e.to_vec())); + ScalarValue::$SCALAR(value) + }}; +} + +// Statically-typed version of min/max(array) -> ScalarValue for non-string types. +macro_rules! typed_min_max_batch { + ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{ + let array = downcast_value!($VALUES, $ARRAYTYPE); + let value = compute::$OP(array); + ScalarValue::$SCALAR(value, $($EXTRA_ARGS.clone()),*) + }}; +} + +// Statically-typed version of min/max(array) -> ScalarValue for non-string types. +// this is a macro to support both operations (min and max). +macro_rules! min_max_batch { + ($VALUES:expr, $OP:ident) => {{ + match $VALUES.data_type() { + DataType::Null => ScalarValue::Null, + DataType::Decimal128(precision, scale) => { + typed_min_max_batch!( + $VALUES, + Decimal128Array, + Decimal128, + $OP, + precision, + scale + ) + } + DataType::Decimal256(precision, scale) => { + typed_min_max_batch!( + $VALUES, + Decimal256Array, + Decimal256, + $OP, + precision, + scale + ) + } + // all types that have a natural order + DataType::Float64 => { + typed_min_max_batch!($VALUES, Float64Array, Float64, $OP) + } + DataType::Float32 => { + typed_min_max_batch!($VALUES, Float32Array, Float32, $OP) + } + DataType::Float16 => { + typed_min_max_batch!($VALUES, Float16Array, Float16, $OP) + } + DataType::Int64 => typed_min_max_batch!($VALUES, Int64Array, Int64, $OP), + DataType::Int32 => typed_min_max_batch!($VALUES, Int32Array, Int32, $OP), + DataType::Int16 => typed_min_max_batch!($VALUES, Int16Array, Int16, $OP), + DataType::Int8 => typed_min_max_batch!($VALUES, Int8Array, Int8, $OP), + DataType::UInt64 => typed_min_max_batch!($VALUES, UInt64Array, UInt64, $OP), + DataType::UInt32 => typed_min_max_batch!($VALUES, UInt32Array, UInt32, $OP), + DataType::UInt16 => typed_min_max_batch!($VALUES, UInt16Array, UInt16, $OP), + DataType::UInt8 => typed_min_max_batch!($VALUES, UInt8Array, UInt8, $OP), + DataType::Timestamp(TimeUnit::Second, tz_opt) => { + typed_min_max_batch!( + $VALUES, + TimestampSecondArray, + TimestampSecond, + $OP, + tz_opt + ) + } + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => typed_min_max_batch!( + $VALUES, + TimestampMillisecondArray, + TimestampMillisecond, + $OP, + tz_opt + ), + DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => typed_min_max_batch!( + $VALUES, + TimestampMicrosecondArray, + TimestampMicrosecond, + $OP, + tz_opt + ), + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => typed_min_max_batch!( + $VALUES, + TimestampNanosecondArray, + TimestampNanosecond, + $OP, + tz_opt + ), + DataType::Date32 => typed_min_max_batch!($VALUES, Date32Array, Date32, $OP), + DataType::Date64 => typed_min_max_batch!($VALUES, Date64Array, Date64, $OP), + DataType::Time32(TimeUnit::Second) => { + typed_min_max_batch!($VALUES, Time32SecondArray, Time32Second, $OP) + } + DataType::Time32(TimeUnit::Millisecond) => { + typed_min_max_batch!( + $VALUES, + Time32MillisecondArray, + Time32Millisecond, + $OP + ) + } + DataType::Time64(TimeUnit::Microsecond) => { + typed_min_max_batch!( + $VALUES, + Time64MicrosecondArray, + Time64Microsecond, + $OP + ) + } + DataType::Time64(TimeUnit::Nanosecond) => { + typed_min_max_batch!( + $VALUES, + Time64NanosecondArray, + Time64Nanosecond, + $OP + ) + } + DataType::Interval(IntervalUnit::YearMonth) => { + typed_min_max_batch!( + $VALUES, + IntervalYearMonthArray, + IntervalYearMonth, + $OP + ) + } + DataType::Interval(IntervalUnit::DayTime) => { + typed_min_max_batch!($VALUES, IntervalDayTimeArray, IntervalDayTime, $OP) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + typed_min_max_batch!( + $VALUES, + IntervalMonthDayNanoArray, + IntervalMonthDayNano, + $OP + ) + } + other => { + // This should have been handled before + return internal_err!( + "Min/Max accumulator not implemented for type {:?}", + other + ); + } + } + }}; +} + +/// dynamically-typed min(array) -> ScalarValue +fn min_batch(values: &ArrayRef) -> Result { + Ok(match values.data_type() { + DataType::Utf8 => { + typed_min_max_batch_string!(values, StringArray, Utf8, min_string) + } + DataType::LargeUtf8 => { + typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, min_string) + } + DataType::Utf8View => { + typed_min_max_batch_string!( + values, + StringViewArray, + Utf8View, + min_string_view + ) + } + DataType::Boolean => { + typed_min_max_batch!(values, BooleanArray, Boolean, min_boolean) + } + DataType::Binary => { + typed_min_max_batch_binary!(&values, BinaryArray, Binary, min_binary) + } + DataType::LargeBinary => { + typed_min_max_batch_binary!( + &values, + LargeBinaryArray, + LargeBinary, + min_binary + ) + } + DataType::BinaryView => { + typed_min_max_batch_binary!( + &values, + BinaryViewArray, + BinaryView, + min_binary_view + ) + } + _ => min_max_batch!(values, min), + }) +} + +/// dynamically-typed max(array) -> ScalarValue +fn max_batch(values: &ArrayRef) -> Result { + Ok(match values.data_type() { + DataType::Utf8 => { + typed_min_max_batch_string!(values, StringArray, Utf8, max_string) + } + DataType::LargeUtf8 => { + typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, max_string) + } + DataType::Utf8View => { + typed_min_max_batch_string!( + values, + StringViewArray, + Utf8View, + max_string_view + ) + } + DataType::Boolean => { + typed_min_max_batch!(values, BooleanArray, Boolean, max_boolean) + } + DataType::Binary => { + typed_min_max_batch_binary!(&values, BinaryArray, Binary, max_binary) + } + DataType::BinaryView => { + typed_min_max_batch_binary!( + &values, + BinaryViewArray, + BinaryView, + max_binary_view + ) + } + DataType::LargeBinary => { + typed_min_max_batch_binary!( + &values, + LargeBinaryArray, + LargeBinary, + max_binary + ) + } + _ => min_max_batch!(values, max), + }) +} + +// min/max of two non-string scalar values. +macro_rules! typed_min_max { + ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{ + ScalarValue::$SCALAR( + match ($VALUE, $DELTA) { + (None, None) => None, + (Some(a), None) => Some(*a), + (None, Some(b)) => Some(*b), + (Some(a), Some(b)) => Some((*a).$OP(*b)), + }, + $($EXTRA_ARGS.clone()),* + ) + }}; +} +macro_rules! typed_min_max_float { + ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{ + ScalarValue::$SCALAR(match ($VALUE, $DELTA) { + (None, None) => None, + (Some(a), None) => Some(*a), + (None, Some(b)) => Some(*b), + (Some(a), Some(b)) => match a.total_cmp(b) { + choose_min_max!($OP) => Some(*b), + _ => Some(*a), + }, + }) + }}; +} + +// min/max of two scalar string values. +macro_rules! typed_min_max_string { + ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{ + ScalarValue::$SCALAR(match ($VALUE, $DELTA) { + (None, None) => None, + (Some(a), None) => Some(a.clone()), + (None, Some(b)) => Some(b.clone()), + (Some(a), Some(b)) => Some((a).$OP(b).clone()), + }) + }}; +} + +macro_rules! choose_min_max { + (min) => { + std::cmp::Ordering::Greater + }; + (max) => { + std::cmp::Ordering::Less + }; +} + +macro_rules! interval_min_max { + ($OP:tt, $LHS:expr, $RHS:expr) => {{ + match $LHS.partial_cmp(&$RHS) { + Some(choose_min_max!($OP)) => $RHS.clone(), + Some(_) => $LHS.clone(), + None => { + return internal_err!("Comparison error while computing interval min/max") + } + } + }}; +} + +// min/max of two scalar values of the same type +macro_rules! min_max { + ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ + Ok(match ($VALUE, $DELTA) { + (ScalarValue::Null, ScalarValue::Null) => ScalarValue::Null, + ( + lhs @ ScalarValue::Decimal128(lhsv, lhsp, lhss), + rhs @ ScalarValue::Decimal128(rhsv, rhsp, rhss) + ) => { + if lhsp.eq(rhsp) && lhss.eq(rhss) { + typed_min_max!(lhsv, rhsv, Decimal128, $OP, lhsp, lhss) + } else { + return internal_err!( + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", + (lhs, rhs) + ); + } + } + ( + lhs @ ScalarValue::Decimal256(lhsv, lhsp, lhss), + rhs @ ScalarValue::Decimal256(rhsv, rhsp, rhss) + ) => { + if lhsp.eq(rhsp) && lhss.eq(rhss) { + typed_min_max!(lhsv, rhsv, Decimal256, $OP, lhsp, lhss) + } else { + return internal_err!( + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", + (lhs, rhs) + ); + } + } + (ScalarValue::Boolean(lhs), ScalarValue::Boolean(rhs)) => { + typed_min_max!(lhs, rhs, Boolean, $OP) + } + (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => { + typed_min_max_float!(lhs, rhs, Float64, $OP) + } + (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => { + typed_min_max_float!(lhs, rhs, Float32, $OP) + } + (ScalarValue::Float16(lhs), ScalarValue::Float16(rhs)) => { + typed_min_max_float!(lhs, rhs, Float16, $OP) + } + (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => { + typed_min_max!(lhs, rhs, UInt64, $OP) + } + (ScalarValue::UInt32(lhs), ScalarValue::UInt32(rhs)) => { + typed_min_max!(lhs, rhs, UInt32, $OP) + } + (ScalarValue::UInt16(lhs), ScalarValue::UInt16(rhs)) => { + typed_min_max!(lhs, rhs, UInt16, $OP) + } + (ScalarValue::UInt8(lhs), ScalarValue::UInt8(rhs)) => { + typed_min_max!(lhs, rhs, UInt8, $OP) + } + (ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => { + typed_min_max!(lhs, rhs, Int64, $OP) + } + (ScalarValue::Int32(lhs), ScalarValue::Int32(rhs)) => { + typed_min_max!(lhs, rhs, Int32, $OP) + } + (ScalarValue::Int16(lhs), ScalarValue::Int16(rhs)) => { + typed_min_max!(lhs, rhs, Int16, $OP) + } + (ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => { + typed_min_max!(lhs, rhs, Int8, $OP) + } + (ScalarValue::Utf8(lhs), ScalarValue::Utf8(rhs)) => { + typed_min_max_string!(lhs, rhs, Utf8, $OP) + } + (ScalarValue::LargeUtf8(lhs), ScalarValue::LargeUtf8(rhs)) => { + typed_min_max_string!(lhs, rhs, LargeUtf8, $OP) + } + (ScalarValue::Utf8View(lhs), ScalarValue::Utf8View(rhs)) => { + typed_min_max_string!(lhs, rhs, Utf8View, $OP) + } + (ScalarValue::Binary(lhs), ScalarValue::Binary(rhs)) => { + typed_min_max_string!(lhs, rhs, Binary, $OP) + } + (ScalarValue::LargeBinary(lhs), ScalarValue::LargeBinary(rhs)) => { + typed_min_max_string!(lhs, rhs, LargeBinary, $OP) + } + (ScalarValue::BinaryView(lhs), ScalarValue::BinaryView(rhs)) => { + typed_min_max_string!(lhs, rhs, BinaryView, $OP) + } + (ScalarValue::TimestampSecond(lhs, l_tz), ScalarValue::TimestampSecond(rhs, _)) => { + typed_min_max!(lhs, rhs, TimestampSecond, $OP, l_tz) + } + ( + ScalarValue::TimestampMillisecond(lhs, l_tz), + ScalarValue::TimestampMillisecond(rhs, _), + ) => { + typed_min_max!(lhs, rhs, TimestampMillisecond, $OP, l_tz) + } + ( + ScalarValue::TimestampMicrosecond(lhs, l_tz), + ScalarValue::TimestampMicrosecond(rhs, _), + ) => { + typed_min_max!(lhs, rhs, TimestampMicrosecond, $OP, l_tz) + } + ( + ScalarValue::TimestampNanosecond(lhs, l_tz), + ScalarValue::TimestampNanosecond(rhs, _), + ) => { + typed_min_max!(lhs, rhs, TimestampNanosecond, $OP, l_tz) + } + ( + ScalarValue::Date32(lhs), + ScalarValue::Date32(rhs), + ) => { + typed_min_max!(lhs, rhs, Date32, $OP) + } + ( + ScalarValue::Date64(lhs), + ScalarValue::Date64(rhs), + ) => { + typed_min_max!(lhs, rhs, Date64, $OP) + } + ( + ScalarValue::Time32Second(lhs), + ScalarValue::Time32Second(rhs), + ) => { + typed_min_max!(lhs, rhs, Time32Second, $OP) + } + ( + ScalarValue::Time32Millisecond(lhs), + ScalarValue::Time32Millisecond(rhs), + ) => { + typed_min_max!(lhs, rhs, Time32Millisecond, $OP) + } + ( + ScalarValue::Time64Microsecond(lhs), + ScalarValue::Time64Microsecond(rhs), + ) => { + typed_min_max!(lhs, rhs, Time64Microsecond, $OP) + } + ( + ScalarValue::Time64Nanosecond(lhs), + ScalarValue::Time64Nanosecond(rhs), + ) => { + typed_min_max!(lhs, rhs, Time64Nanosecond, $OP) + } + ( + ScalarValue::IntervalYearMonth(lhs), + ScalarValue::IntervalYearMonth(rhs), + ) => { + typed_min_max!(lhs, rhs, IntervalYearMonth, $OP) + } + ( + ScalarValue::IntervalMonthDayNano(lhs), + ScalarValue::IntervalMonthDayNano(rhs), + ) => { + typed_min_max!(lhs, rhs, IntervalMonthDayNano, $OP) + } + ( + ScalarValue::IntervalDayTime(lhs), + ScalarValue::IntervalDayTime(rhs), + ) => { + typed_min_max!(lhs, rhs, IntervalDayTime, $OP) + } + ( + ScalarValue::IntervalYearMonth(_), + ScalarValue::IntervalMonthDayNano(_), + ) | ( + ScalarValue::IntervalYearMonth(_), + ScalarValue::IntervalDayTime(_), + ) | ( + ScalarValue::IntervalMonthDayNano(_), + ScalarValue::IntervalDayTime(_), + ) | ( + ScalarValue::IntervalMonthDayNano(_), + ScalarValue::IntervalYearMonth(_), + ) | ( + ScalarValue::IntervalDayTime(_), + ScalarValue::IntervalYearMonth(_), + ) | ( + ScalarValue::IntervalDayTime(_), + ScalarValue::IntervalMonthDayNano(_), + ) => { + interval_min_max!($OP, $VALUE, $DELTA) + } + ( + ScalarValue::DurationSecond(lhs), + ScalarValue::DurationSecond(rhs), + ) => { + typed_min_max!(lhs, rhs, DurationSecond, $OP) + } + ( + ScalarValue::DurationMillisecond(lhs), + ScalarValue::DurationMillisecond(rhs), + ) => { + typed_min_max!(lhs, rhs, DurationMillisecond, $OP) + } + ( + ScalarValue::DurationMicrosecond(lhs), + ScalarValue::DurationMicrosecond(rhs), + ) => { + typed_min_max!(lhs, rhs, DurationMicrosecond, $OP) + } + ( + ScalarValue::DurationNanosecond(lhs), + ScalarValue::DurationNanosecond(rhs), + ) => { + typed_min_max!(lhs, rhs, DurationNanosecond, $OP) + } + e => { + return internal_err!( + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", + e + ) + } + }) + }}; +} + +/// An accumulator to compute the maximum value +#[derive(Debug)] +pub struct MaxAccumulator { + max: ScalarValue, +} + +impl MaxAccumulator { + /// new max accumulator + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + max: ScalarValue::try_from(datatype)?, + }) + } +} + +impl Accumulator for MaxAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &values[0]; + let delta = &max_batch(values)?; + let new_max: Result = + min_max!(&self.max, delta, max); + self.max = new_max?; + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } + + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + fn evaluate(&mut self) -> Result { + Ok(self.max.clone()) + } + + fn size(&self) -> usize { + size_of_val(self) - size_of_val(&self.max) + self.max.size() + } +} + +#[derive(Debug)] +pub struct SlidingMaxAccumulator { + max: ScalarValue, + moving_max: MovingMax, +} + +impl SlidingMaxAccumulator { + /// new max accumulator + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + max: ScalarValue::try_from(datatype)?, + moving_max: MovingMax::::new(), + }) + } +} + +impl Accumulator for SlidingMaxAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + for idx in 0..values[0].len() { + let val = ScalarValue::try_from_array(&values[0], idx)?; + self.moving_max.push(val); + } + if let Some(res) = self.moving_max.max() { + self.max = res.clone(); + } + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + for _idx in 0..values[0].len() { + (self.moving_max).pop(); + } + if let Some(res) = self.moving_max.max() { + self.max = res.clone(); + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } + + fn state(&mut self) -> Result> { + Ok(vec![self.max.clone()]) + } + + fn evaluate(&mut self) -> Result { + Ok(self.max.clone()) + } + + fn supports_retract_batch(&self) -> bool { + true + } + + fn size(&self) -> usize { + size_of_val(self) - size_of_val(&self.max) + self.max.size() + } +} + +#[derive(Debug)] +pub struct Min { + signature: Signature, +} + +impl Min { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl Default for Min { + fn default() -> Self { + Self::new() + } +} + +impl FromColumnStatistics for Min { + fn value_from_column_statistics( + &self, + col_stats: &ColumnStatistics, + ) -> Option { + if let Precision::Exact(ref val) = col_stats.min_value { + if !val.is_null() { + return Some(val.clone()); + } + } + None + } +} + +impl AggregateUDFImpl for Min { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "min" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].to_owned()) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(MinAccumulator::try_new(acc_args.return_type)?)) + } + + fn aliases(&self) -> &[String] { + &[] + } + + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + use DataType::*; + matches!( + args.return_type, + Int8 | Int16 + | Int32 + | Int64 + | UInt8 + | UInt16 + | UInt32 + | UInt64 + | Float16 + | Float32 + | Float64 + | Decimal128(_, _) + | Decimal256(_, _) + | Date32 + | Date64 + | Time32(_) + | Time64(_) + | Timestamp(_, _) + | Utf8 + | LargeUtf8 + | Utf8View + | Binary + | LargeBinary + | BinaryView + ) + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + use DataType::*; + use TimeUnit::*; + let data_type = args.return_type; + match data_type { + Int8 => primitive_min_accumulator!(data_type, i8, Int8Type), + Int16 => primitive_min_accumulator!(data_type, i16, Int16Type), + Int32 => primitive_min_accumulator!(data_type, i32, Int32Type), + Int64 => primitive_min_accumulator!(data_type, i64, Int64Type), + UInt8 => primitive_min_accumulator!(data_type, u8, UInt8Type), + UInt16 => primitive_min_accumulator!(data_type, u16, UInt16Type), + UInt32 => primitive_min_accumulator!(data_type, u32, UInt32Type), + UInt64 => primitive_min_accumulator!(data_type, u64, UInt64Type), + Float16 => { + primitive_min_accumulator!(data_type, f16, Float16Type) + } + Float32 => { + primitive_min_accumulator!(data_type, f32, Float32Type) + } + Float64 => { + primitive_min_accumulator!(data_type, f64, Float64Type) + } + Date32 => primitive_min_accumulator!(data_type, i32, Date32Type), + Date64 => primitive_min_accumulator!(data_type, i64, Date64Type), + Time32(Second) => { + primitive_min_accumulator!(data_type, i32, Time32SecondType) + } + Time32(Millisecond) => { + primitive_min_accumulator!(data_type, i32, Time32MillisecondType) + } + Time64(Microsecond) => { + primitive_min_accumulator!(data_type, i64, Time64MicrosecondType) + } + Time64(Nanosecond) => { + primitive_min_accumulator!(data_type, i64, Time64NanosecondType) + } + Timestamp(Second, _) => { + primitive_min_accumulator!(data_type, i64, TimestampSecondType) + } + Timestamp(Millisecond, _) => { + primitive_min_accumulator!(data_type, i64, TimestampMillisecondType) + } + Timestamp(Microsecond, _) => { + primitive_min_accumulator!(data_type, i64, TimestampMicrosecondType) + } + Timestamp(Nanosecond, _) => { + primitive_min_accumulator!(data_type, i64, TimestampNanosecondType) + } + Decimal128(_, _) => { + primitive_min_accumulator!(data_type, i128, Decimal128Type) + } + Decimal256(_, _) => { + primitive_min_accumulator!(data_type, i256, Decimal256Type) + } + Utf8 | LargeUtf8 | Utf8View | Binary | LargeBinary | BinaryView => { + Ok(Box::new(MinMaxBytesAccumulator::new_min(data_type.clone()))) + } + + // This is only reached if groups_accumulator_supported is out of sync + _ => internal_err!("GroupsAccumulator not supported for min({})", data_type), + } + } + + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + Ok(Box::new(SlidingMinAccumulator::try_new(args.return_type)?)) + } + + fn is_descending(&self) -> Option { + Some(false) + } + + fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option { + self.value_from_statistics(statistics_args) + } + fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity { + datafusion_expr::utils::AggregateOrderSensitivity::Insensitive + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + get_min_max_result_type(arg_types) + } + + fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { + datafusion_expr::ReversedUDAF::Identical + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_min_doc()) + } +} + +fn get_min_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description("Returns the minimum value in the specified column.") + .with_syntax_example("min(expression)") + .with_sql_example( + r#"```sql +> SELECT min(column_name) FROM table_name; ++----------------------+ +| min(column_name) | ++----------------------+ +| 12 | ++----------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) +} + +/// An accumulator to compute the minimum value +#[derive(Debug)] +pub struct MinAccumulator { + min: ScalarValue, +} + +impl MinAccumulator { + /// new min accumulator + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + min: ScalarValue::try_from(datatype)?, + }) + } +} + +impl Accumulator for MinAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &values[0]; + let delta = &min_batch(values)?; + let new_min: Result = + min_max!(&self.min, delta, min); + self.min = new_min?; + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } + + fn evaluate(&mut self) -> Result { + Ok(self.min.clone()) + } + + fn size(&self) -> usize { + size_of_val(self) - size_of_val(&self.min) + self.min.size() + } +} + +#[derive(Debug)] +pub struct SlidingMinAccumulator { + min: ScalarValue, + moving_min: MovingMin, +} + +impl SlidingMinAccumulator { + pub fn try_new(datatype: &DataType) -> Result { + Ok(Self { + min: ScalarValue::try_from(datatype)?, + moving_min: MovingMin::::new(), + }) + } +} + +impl Accumulator for SlidingMinAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![self.min.clone()]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + for idx in 0..values[0].len() { + let val = ScalarValue::try_from_array(&values[0], idx)?; + if !val.is_null() { + self.moving_min.push(val); + } + } + if let Some(res) = self.moving_min.min() { + self.min = res.clone(); + } + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + for idx in 0..values[0].len() { + let val = ScalarValue::try_from_array(&values[0], idx)?; + if !val.is_null() { + (self.moving_min).pop(); + } + } + if let Some(res) = self.moving_min.min() { + self.min = res.clone(); + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } + + fn evaluate(&mut self) -> Result { + Ok(self.min.clone()) + } + + fn supports_retract_batch(&self) -> bool { + true + } + + fn size(&self) -> usize { + size_of_val(self) - size_of_val(&self.min) + self.min.size() + } +} + +/// Keep track of the minimum value in a sliding window. +/// +/// The implementation is taken from +/// +/// `moving min max` provides one data structure for keeping track of the +/// minimum value and one for keeping track of the maximum value in a sliding +/// window. +/// +/// Each element is stored with the current min/max. One stack to push and another one for pop. If pop stack is empty, +/// push to this stack all elements popped from first stack while updating their current min/max. Now pop from +/// the second stack (MovingMin/Max struct works as a queue). To find the minimum element of the queue, +/// look at the smallest/largest two elements of the individual stacks, then take the minimum of those two values. +/// +/// The complexity of the operations are +/// - O(1) for getting the minimum/maximum +/// - O(1) for push +/// - amortized O(1) for pop +/// +/// ``` +/// # use datafusion_functions_aggregate::min_max::MovingMin; +/// let mut moving_min = MovingMin::::new(); +/// moving_min.push(2); +/// moving_min.push(1); +/// moving_min.push(3); +/// +/// assert_eq!(moving_min.min(), Some(&1)); +/// assert_eq!(moving_min.pop(), Some(2)); +/// +/// assert_eq!(moving_min.min(), Some(&1)); +/// assert_eq!(moving_min.pop(), Some(1)); +/// +/// assert_eq!(moving_min.min(), Some(&3)); +/// assert_eq!(moving_min.pop(), Some(3)); +/// +/// assert_eq!(moving_min.min(), None); +/// assert_eq!(moving_min.pop(), None); +/// ``` +#[derive(Debug)] +pub struct MovingMin { + push_stack: Vec<(T, T)>, + pop_stack: Vec<(T, T)>, +} + +impl Default for MovingMin { + fn default() -> Self { + Self { + push_stack: Vec::new(), + pop_stack: Vec::new(), + } + } +} + +impl MovingMin { + /// Creates a new `MovingMin` to keep track of the minimum in a sliding + /// window. + #[inline] + pub fn new() -> Self { + Self::default() + } + + /// Creates a new `MovingMin` to keep track of the minimum in a sliding + /// window with `capacity` allocated slots. + #[inline] + pub fn with_capacity(capacity: usize) -> Self { + Self { + push_stack: Vec::with_capacity(capacity), + pop_stack: Vec::with_capacity(capacity), + } + } + + /// Returns the minimum of the sliding window or `None` if the window is + /// empty. + #[inline] + pub fn min(&self) -> Option<&T> { + match (self.push_stack.last(), self.pop_stack.last()) { + (None, None) => None, + (Some((_, min)), None) => Some(min), + (None, Some((_, min))) => Some(min), + (Some((_, a)), Some((_, b))) => Some(if a < b { a } else { b }), + } + } + + /// Pushes a new element into the sliding window. + #[inline] + pub fn push(&mut self, val: T) { + self.push_stack.push(match self.push_stack.last() { + Some((_, min)) => { + if val > *min { + (val, min.clone()) + } else { + (val.clone(), val) + } + } + None => (val.clone(), val), + }); + } + + /// Removes and returns the last value of the sliding window. + #[inline] + pub fn pop(&mut self) -> Option { + if self.pop_stack.is_empty() { + match self.push_stack.pop() { + Some((val, _)) => { + let mut last = (val.clone(), val); + self.pop_stack.push(last.clone()); + while let Some((val, _)) = self.push_stack.pop() { + let min = if last.1 < val { + last.1.clone() + } else { + val.clone() + }; + last = (val.clone(), min); + self.pop_stack.push(last.clone()); + } + } + None => return None, + } + } + self.pop_stack.pop().map(|(val, _)| val) + } + + /// Returns the number of elements stored in the sliding window. + #[inline] + pub fn len(&self) -> usize { + self.push_stack.len() + self.pop_stack.len() + } + + /// Returns `true` if the moving window contains no elements. + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +/// Keep track of the maximum value in a sliding window. +/// +/// See [`MovingMin`] for more details. +/// +/// ``` +/// # use datafusion_functions_aggregate::min_max::MovingMax; +/// let mut moving_max = MovingMax::::new(); +/// moving_max.push(2); +/// moving_max.push(3); +/// moving_max.push(1); +/// +/// assert_eq!(moving_max.max(), Some(&3)); +/// assert_eq!(moving_max.pop(), Some(2)); +/// +/// assert_eq!(moving_max.max(), Some(&3)); +/// assert_eq!(moving_max.pop(), Some(3)); +/// +/// assert_eq!(moving_max.max(), Some(&1)); +/// assert_eq!(moving_max.pop(), Some(1)); +/// +/// assert_eq!(moving_max.max(), None); +/// assert_eq!(moving_max.pop(), None); +/// ``` +#[derive(Debug)] +pub struct MovingMax { + push_stack: Vec<(T, T)>, + pop_stack: Vec<(T, T)>, +} + +impl Default for MovingMax { + fn default() -> Self { + Self { + push_stack: Vec::new(), + pop_stack: Vec::new(), + } + } +} + +impl MovingMax { + /// Creates a new `MovingMax` to keep track of the maximum in a sliding window. + #[inline] + pub fn new() -> Self { + Self::default() + } + + /// Creates a new `MovingMax` to keep track of the maximum in a sliding window with + /// `capacity` allocated slots. + #[inline] + pub fn with_capacity(capacity: usize) -> Self { + Self { + push_stack: Vec::with_capacity(capacity), + pop_stack: Vec::with_capacity(capacity), + } + } + + /// Returns the maximum of the sliding window or `None` if the window is empty. + #[inline] + pub fn max(&self) -> Option<&T> { + match (self.push_stack.last(), self.pop_stack.last()) { + (None, None) => None, + (Some((_, max)), None) => Some(max), + (None, Some((_, max))) => Some(max), + (Some((_, a)), Some((_, b))) => Some(if a > b { a } else { b }), + } + } + + /// Pushes a new element into the sliding window. + #[inline] + pub fn push(&mut self, val: T) { + self.push_stack.push(match self.push_stack.last() { + Some((_, max)) => { + if val < *max { + (val, max.clone()) + } else { + (val.clone(), val) + } + } + None => (val.clone(), val), + }); + } + + /// Removes and returns the last value of the sliding window. + #[inline] + pub fn pop(&mut self) -> Option { + if self.pop_stack.is_empty() { + match self.push_stack.pop() { + Some((val, _)) => { + let mut last = (val.clone(), val); + self.pop_stack.push(last.clone()); + while let Some((val, _)) = self.push_stack.pop() { + let max = if last.1 > val { + last.1.clone() + } else { + val.clone() + }; + last = (val.clone(), max); + self.pop_stack.push(last.clone()); + } + } + None => return None, + } + } + self.pop_stack.pop().map(|(val, _)| val) + } + + /// Returns the number of elements stored in the sliding window. + #[inline] + pub fn len(&self) -> usize { + self.push_stack.len() + self.pop_stack.len() + } + + /// Returns `true` if the moving window contains no elements. + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +make_udaf_expr_and_func!( + Max, + max, + expression, + "Returns the maximum of a group of values.", + max_udaf +); + +make_udaf_expr_and_func!( + Min, + min, + expression, + "Returns the minimum of a group of values.", + min_udaf +); + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::{ + IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType, + }; + use std::sync::Arc; + + #[test] + fn interval_min_max() { + // IntervalYearMonth + let b = IntervalYearMonthArray::from(vec![ + IntervalYearMonthType::make_value(0, 1), + IntervalYearMonthType::make_value(5, 34), + IntervalYearMonthType::make_value(-2, 4), + IntervalYearMonthType::make_value(7, -4), + IntervalYearMonthType::make_value(0, 1), + ]); + let b: ArrayRef = Arc::new(b); + + let mut min = + MinAccumulator::try_new(&DataType::Interval(IntervalUnit::YearMonth)) + .unwrap(); + min.update_batch(&[Arc::clone(&b)]).unwrap(); + let min_res = min.evaluate().unwrap(); + assert_eq!( + min_res, + ScalarValue::IntervalYearMonth(Some(IntervalYearMonthType::make_value( + -2, 4 + ))) + ); + + let mut max = + MaxAccumulator::try_new(&DataType::Interval(IntervalUnit::YearMonth)) + .unwrap(); + max.update_batch(&[Arc::clone(&b)]).unwrap(); + let max_res = max.evaluate().unwrap(); + assert_eq!( + max_res, + ScalarValue::IntervalYearMonth(Some(IntervalYearMonthType::make_value( + 5, 34 + ))) + ); + + // IntervalDayTime + let b = IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(0, 0), + IntervalDayTimeType::make_value(5, 454000), + IntervalDayTimeType::make_value(-34, 0), + IntervalDayTimeType::make_value(7, -4000), + IntervalDayTimeType::make_value(1, 0), + ]); + let b: ArrayRef = Arc::new(b); + + let mut min = + MinAccumulator::try_new(&DataType::Interval(IntervalUnit::DayTime)).unwrap(); + min.update_batch(&[Arc::clone(&b)]).unwrap(); + let min_res = min.evaluate().unwrap(); + assert_eq!( + min_res, + ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(-34, 0))) + ); + + let mut max = + MaxAccumulator::try_new(&DataType::Interval(IntervalUnit::DayTime)).unwrap(); + max.update_batch(&[Arc::clone(&b)]).unwrap(); + let max_res = max.evaluate().unwrap(); + assert_eq!( + max_res, + ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(7, -4000))) + ); + + // IntervalMonthDayNano + let b = IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNanoType::make_value(1, 0, 0), + IntervalMonthDayNanoType::make_value(344, 34, -43_000_000_000), + IntervalMonthDayNanoType::make_value(-593, -33, 13_000_000_000), + IntervalMonthDayNanoType::make_value(5, 2, 493_000_000_000), + IntervalMonthDayNanoType::make_value(1, 0, 0), + ]); + let b: ArrayRef = Arc::new(b); + + let mut min = + MinAccumulator::try_new(&DataType::Interval(IntervalUnit::MonthDayNano)) + .unwrap(); + min.update_batch(&[Arc::clone(&b)]).unwrap(); + let min_res = min.evaluate().unwrap(); + assert_eq!( + min_res, + ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNanoType::make_value(-593, -33, 13_000_000_000) + )) + ); + + let mut max = + MaxAccumulator::try_new(&DataType::Interval(IntervalUnit::MonthDayNano)) + .unwrap(); + max.update_batch(&[Arc::clone(&b)]).unwrap(); + let max_res = max.evaluate().unwrap(); + assert_eq!( + max_res, + ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNanoType::make_value(344, 34, -43_000_000_000) + )) + ); + } + + #[test] + fn float_min_max_with_nans() { + let pos_nan = f32::NAN; + let zero = 0_f32; + let neg_inf = f32::NEG_INFINITY; + + let check = |acc: &mut dyn Accumulator, values: &[&[f32]], expected: f32| { + for batch in values.iter() { + let batch = + Arc::new(Float32Array::from_iter_values(batch.iter().copied())); + acc.update_batch(&[batch]).unwrap(); + } + let result = acc.evaluate().unwrap(); + assert_eq!(result, ScalarValue::Float32(Some(expected))); + }; + + // This test checks both comparison between batches (which uses the min_max macro + // defined above) and within a batch (which uses the arrow min/max compute function + // and verifies both respect the total order comparison for floats) + + let min = || MinAccumulator::try_new(&DataType::Float32).unwrap(); + let max = || MaxAccumulator::try_new(&DataType::Float32).unwrap(); + + check(&mut min(), &[&[zero], &[pos_nan]], zero); + check(&mut min(), &[&[zero, pos_nan]], zero); + check(&mut min(), &[&[zero], &[neg_inf]], neg_inf); + check(&mut min(), &[&[zero, neg_inf]], neg_inf); + check(&mut max(), &[&[zero], &[pos_nan]], pos_nan); + check(&mut max(), &[&[zero, pos_nan]], pos_nan); + check(&mut max(), &[&[zero], &[neg_inf]], zero); + check(&mut max(), &[&[zero, neg_inf]], zero); + } + + use datafusion_common::Result; + use rand::Rng; + + fn get_random_vec_i32(len: usize) -> Vec { + let mut rng = rand::thread_rng(); + let mut input = Vec::with_capacity(len); + for _i in 0..len { + input.push(rng.gen_range(0..100)); + } + input + } + + fn moving_min_i32(len: usize, n_sliding_window: usize) -> Result<()> { + let data = get_random_vec_i32(len); + let mut expected = Vec::with_capacity(len); + let mut moving_min = MovingMin::::new(); + let mut res = Vec::with_capacity(len); + for i in 0..len { + let start = i.saturating_sub(n_sliding_window); + expected.push(*data[start..i + 1].iter().min().unwrap()); + + moving_min.push(data[i]); + if i > n_sliding_window { + moving_min.pop(); + } + res.push(*moving_min.min().unwrap()); + } + assert_eq!(res, expected); + Ok(()) + } + + fn moving_max_i32(len: usize, n_sliding_window: usize) -> Result<()> { + let data = get_random_vec_i32(len); + let mut expected = Vec::with_capacity(len); + let mut moving_max = MovingMax::::new(); + let mut res = Vec::with_capacity(len); + for i in 0..len { + let start = i.saturating_sub(n_sliding_window); + expected.push(*data[start..i + 1].iter().max().unwrap()); + + moving_max.push(data[i]); + if i > n_sliding_window { + moving_max.pop(); + } + res.push(*moving_max.max().unwrap()); + } + assert_eq!(res, expected); + Ok(()) + } + + #[test] + fn moving_min_tests() -> Result<()> { + moving_min_i32(100, 10)?; + moving_min_i32(100, 20)?; + moving_min_i32(100, 50)?; + moving_min_i32(100, 100)?; + Ok(()) + } + + #[test] + fn moving_max_tests() -> Result<()> { + moving_max_i32(100, 10)?; + moving_max_i32(100, 20)?; + moving_max_i32(100, 50)?; + moving_max_i32(100, 100)?; + Ok(()) + } + + #[test] + fn test_min_max_coerce_types() { + // the coerced types is same with input types + let funs: Vec> = + vec![Box::new(Min::new()), Box::new(Max::new())]; + let input_types = vec![ + vec![DataType::Int32], + vec![DataType::Decimal128(10, 2)], + vec![DataType::Decimal256(1, 1)], + vec![DataType::Utf8], + ]; + for fun in funs { + for input_type in &input_types { + let result = fun.coerce_types(input_type); + assert_eq!(*input_type, result.unwrap()); + } + } + } + + #[test] + fn test_get_min_max_return_type_coerce_dictionary() -> Result<()> { + let data_type = + DataType::Dictionary(Box::new(DataType::Utf8), Box::new(DataType::Int32)); + let result = get_min_max_result_type(&[data_type])?; + assert_eq!(result, vec![DataType::Int32]); + Ok(()) + } +} diff --git a/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs new file mode 100644 index 000000000000..501454edf77c --- /dev/null +++ b/datafusion/functions-aggregate/src/min_max/min_max_bytes.rs @@ -0,0 +1,515 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ + Array, ArrayRef, AsArray, BinaryBuilder, BinaryViewBuilder, BooleanArray, + LargeBinaryBuilder, LargeStringBuilder, StringBuilder, StringViewBuilder, +}; +use arrow_schema::DataType; +use datafusion_common::{internal_err, Result}; +use datafusion_expr::{EmitTo, GroupsAccumulator}; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::apply_filter_as_nulls; +use std::mem::size_of; +use std::sync::Arc; + +/// Implements fast Min/Max [`GroupsAccumulator`] for "bytes" types ([`StringArray`], +/// [`BinaryArray`], [`StringViewArray`], etc) +/// +/// This implementation dispatches to the appropriate specialized code in +/// [`MinMaxBytesState`] based on data type and comparison function +/// +/// [`StringArray`]: arrow::array::StringArray +/// [`BinaryArray`]: arrow::array::BinaryArray +/// [`StringViewArray`]: arrow::array::StringViewArray +#[derive(Debug)] +pub(crate) struct MinMaxBytesAccumulator { + /// Inner data storage. + inner: MinMaxBytesState, + /// if true, is `MIN` otherwise is `MAX` + is_min: bool, +} + +impl MinMaxBytesAccumulator { + /// Create a new accumulator for computing `min(val)` + pub fn new_min(data_type: DataType) -> Self { + Self { + inner: MinMaxBytesState::new(data_type), + is_min: true, + } + } + + /// Create a new accumulator fo computing `max(val)` + pub fn new_max(data_type: DataType) -> Self { + Self { + inner: MinMaxBytesState::new(data_type), + is_min: false, + } + } +} + +impl GroupsAccumulator for MinMaxBytesAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + let array = &values[0]; + assert_eq!(array.len(), group_indices.len()); + assert_eq!(array.data_type(), &self.inner.data_type); + + // apply filter if needed + let array = apply_filter_as_nulls(array, opt_filter)?; + + // dispatch to appropriate kernel / specialized implementation + fn string_min(a: &[u8], b: &[u8]) -> bool { + // safety: only called from this function, which ensures a and b come + // from an array with valid utf8 data + unsafe { + let a = std::str::from_utf8_unchecked(a); + let b = std::str::from_utf8_unchecked(b); + a < b + } + } + fn string_max(a: &[u8], b: &[u8]) -> bool { + // safety: only called from this function, which ensures a and b come + // from an array with valid utf8 data + unsafe { + let a = std::str::from_utf8_unchecked(a); + let b = std::str::from_utf8_unchecked(b); + a > b + } + } + fn binary_min(a: &[u8], b: &[u8]) -> bool { + a < b + } + + fn binary_max(a: &[u8], b: &[u8]) -> bool { + a > b + } + + fn str_to_bytes<'a>( + it: impl Iterator>, + ) -> impl Iterator> { + it.map(|s| s.map(|s| s.as_bytes())) + } + + match (self.is_min, &self.inner.data_type) { + // Utf8/LargeUtf8/Utf8View Min + (true, &DataType::Utf8) => self.inner.update_batch( + str_to_bytes(array.as_string::().iter()), + group_indices, + total_num_groups, + string_min, + ), + (true, &DataType::LargeUtf8) => self.inner.update_batch( + str_to_bytes(array.as_string::().iter()), + group_indices, + total_num_groups, + string_min, + ), + (true, &DataType::Utf8View) => self.inner.update_batch( + str_to_bytes(array.as_string_view().iter()), + group_indices, + total_num_groups, + string_min, + ), + + // Utf8/LargeUtf8/Utf8View Max + (false, &DataType::Utf8) => self.inner.update_batch( + str_to_bytes(array.as_string::().iter()), + group_indices, + total_num_groups, + string_max, + ), + (false, &DataType::LargeUtf8) => self.inner.update_batch( + str_to_bytes(array.as_string::().iter()), + group_indices, + total_num_groups, + string_max, + ), + (false, &DataType::Utf8View) => self.inner.update_batch( + str_to_bytes(array.as_string_view().iter()), + group_indices, + total_num_groups, + string_max, + ), + + // Binary/LargeBinary/BinaryView Min + (true, &DataType::Binary) => self.inner.update_batch( + array.as_binary::().iter(), + group_indices, + total_num_groups, + binary_min, + ), + (true, &DataType::LargeBinary) => self.inner.update_batch( + array.as_binary::().iter(), + group_indices, + total_num_groups, + binary_min, + ), + (true, &DataType::BinaryView) => self.inner.update_batch( + array.as_binary_view().iter(), + group_indices, + total_num_groups, + binary_min, + ), + + // Binary/LargeBinary/BinaryView Max + (false, &DataType::Binary) => self.inner.update_batch( + array.as_binary::().iter(), + group_indices, + total_num_groups, + binary_max, + ), + (false, &DataType::LargeBinary) => self.inner.update_batch( + array.as_binary::().iter(), + group_indices, + total_num_groups, + binary_max, + ), + (false, &DataType::BinaryView) => self.inner.update_batch( + array.as_binary_view().iter(), + group_indices, + total_num_groups, + binary_max, + ), + + _ => internal_err!( + "Unexpected combination for MinMaxBytesAccumulator: ({:?}, {:?})", + self.is_min, + self.inner.data_type + ), + } + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let (data_capacity, min_maxes) = self.inner.emit_to(emit_to); + + // Convert the Vec of bytes to a vec of Strings (at no cost) + fn bytes_to_str( + min_maxes: Vec>>, + ) -> impl Iterator> { + min_maxes.into_iter().map(|opt| { + opt.map(|bytes| { + // Safety: only called on data added from update_batch which ensures + // the input type matched the output type + unsafe { String::from_utf8_unchecked(bytes) } + }) + }) + } + + let result: ArrayRef = match self.inner.data_type { + DataType::Utf8 => { + let mut builder = + StringBuilder::with_capacity(min_maxes.len(), data_capacity); + for opt in bytes_to_str(min_maxes) { + match opt { + None => builder.append_null(), + Some(s) => builder.append_value(s.as_str()), + } + } + Arc::new(builder.finish()) + } + DataType::LargeUtf8 => { + let mut builder = + LargeStringBuilder::with_capacity(min_maxes.len(), data_capacity); + for opt in bytes_to_str(min_maxes) { + match opt { + None => builder.append_null(), + Some(s) => builder.append_value(s.as_str()), + } + } + Arc::new(builder.finish()) + } + DataType::Utf8View => { + let block_size = capacity_to_view_block_size(data_capacity); + + let mut builder = StringViewBuilder::with_capacity(min_maxes.len()) + .with_fixed_block_size(block_size); + for opt in bytes_to_str(min_maxes) { + match opt { + None => builder.append_null(), + Some(s) => builder.append_value(s.as_str()), + } + } + Arc::new(builder.finish()) + } + DataType::Binary => { + let mut builder = + BinaryBuilder::with_capacity(min_maxes.len(), data_capacity); + for opt in min_maxes { + match opt { + None => builder.append_null(), + Some(s) => builder.append_value(s.as_ref() as &[u8]), + } + } + Arc::new(builder.finish()) + } + DataType::LargeBinary => { + let mut builder = + LargeBinaryBuilder::with_capacity(min_maxes.len(), data_capacity); + for opt in min_maxes { + match opt { + None => builder.append_null(), + Some(s) => builder.append_value(s.as_ref() as &[u8]), + } + } + Arc::new(builder.finish()) + } + DataType::BinaryView => { + let block_size = capacity_to_view_block_size(data_capacity); + + let mut builder = BinaryViewBuilder::with_capacity(min_maxes.len()) + .with_fixed_block_size(block_size); + for opt in min_maxes { + match opt { + None => builder.append_null(), + Some(s) => builder.append_value(s.as_ref() as &[u8]), + } + } + Arc::new(builder.finish()) + } + _ => { + return internal_err!( + "Unexpected data type for MinMaxBytesAccumulator: {:?}", + self.inner.data_type + ); + } + }; + + assert_eq!(&self.inner.data_type, result.data_type()); + Ok(result) + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + // min/max are their own states (no transition needed) + self.evaluate(emit_to).map(|arr| vec![arr]) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + // min/max are their own states (no transition needed) + self.update_batch(values, group_indices, opt_filter, total_num_groups) + } + + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + // Min/max do not change the values as they are their own states + // apply the filter by combining with the null mask, if any + let output = apply_filter_as_nulls(&values[0], opt_filter)?; + Ok(vec![output]) + } + + fn supports_convert_to_state(&self) -> bool { + true + } + + fn size(&self) -> usize { + self.inner.size() + } +} + +/// Returns the block size in (contiguous buffer size) to use +/// for a given data capacity (total string length) +/// +/// This is a heuristic to avoid allocating too many small buffers +fn capacity_to_view_block_size(data_capacity: usize) -> u32 { + let max_block_size = 2 * 1024 * 1024; + if let Ok(block_size) = u32::try_from(data_capacity) { + block_size.min(max_block_size) + } else { + max_block_size + } +} + +/// Stores internal Min/Max state for "bytes" types. +/// +/// This implementation is general and stores the minimum/maximum for each +/// groups in an individual byte array, which balances allocations and memory +/// fragmentation (aka garbage). +/// +/// ```text +/// ┌─────────────────────────────────┐ +/// ┌─────┐ ┌────▶│Option> (["A"]) │───────────▶ "A" +/// │ 0 │────┘ └─────────────────────────────────┘ +/// ├─────┤ ┌─────────────────────────────────┐ +/// │ 1 │─────────▶│Option> (["Z"]) │───────────▶ "Z" +/// └─────┘ └─────────────────────────────────┘ ... +/// ... ... +/// ┌─────┐ ┌────────────────────────────────┐ +/// │ N-2 │─────────▶│Option> (["A"]) │────────────▶ "A" +/// ├─────┤ └────────────────────────────────┘ +/// │ N-1 │────┐ ┌────────────────────────────────┐ +/// └─────┘ └────▶│Option> (["Q"]) │────────────▶ "Q" +/// └────────────────────────────────┘ +/// +/// min_max: Vec> +/// ``` +/// +/// Note that for `StringViewArray` and `BinaryViewArray`, there are potentially +/// more efficient implementations (e.g. by managing a string data buffer +/// directly), but then garbage collection, memory management, and final array +/// construction becomes more complex. +/// +/// See discussion on +#[derive(Debug)] +struct MinMaxBytesState { + /// The minimum/maximum value for each group + min_max: Vec>>, + /// The data type of the array + data_type: DataType, + /// The total bytes of the string data (for pre-allocating the final array, + /// and tracking memory usage) + total_data_bytes: usize, +} + +#[derive(Debug, Clone, Copy)] +enum MinMaxLocation<'a> { + /// the min/max value is stored in the existing `min_max` array + ExistingMinMax, + /// the min/max value is stored in the input array at the given index + Input(&'a [u8]), +} + +/// Implement the MinMaxBytesAccumulator with a comparison function +/// for comparing strings +impl MinMaxBytesState { + /// Create a new MinMaxBytesAccumulator + /// + /// # Arguments: + /// * `data_type`: The data type of the arrays that will be passed to this accumulator + fn new(data_type: DataType) -> Self { + Self { + min_max: vec![], + data_type, + total_data_bytes: 0, + } + } + + /// Set the specified group to the given value, updating memory usage appropriately + fn set_value(&mut self, group_index: usize, new_val: &[u8]) { + match self.min_max[group_index].as_mut() { + None => { + self.min_max[group_index] = Some(new_val.to_vec()); + self.total_data_bytes += new_val.len(); + } + Some(existing_val) => { + // Copy data over to avoid re-allocating + self.total_data_bytes -= existing_val.len(); + self.total_data_bytes += new_val.len(); + existing_val.clear(); + existing_val.extend_from_slice(new_val); + } + } + } + + /// Updates the min/max values for the given string values + /// + /// `cmp` is the comparison function to use, called like `cmp(new_val, existing_val)` + /// returns true if the `new_val` should replace `existing_val` + fn update_batch<'a, F, I>( + &mut self, + iter: I, + group_indices: &[usize], + total_num_groups: usize, + mut cmp: F, + ) -> Result<()> + where + F: FnMut(&[u8], &[u8]) -> bool + Send + Sync, + I: IntoIterator>, + { + self.min_max.resize(total_num_groups, None); + // Minimize value copies by calculating the new min/maxes for each group + // in this batch (either the existing min/max or the new input value) + // and updating the owne values in `self.min_maxes` at most once + let mut locations = vec![MinMaxLocation::ExistingMinMax; total_num_groups]; + + // Figure out the new min value for each group + for (new_val, group_index) in iter.into_iter().zip(group_indices.iter()) { + let group_index = *group_index; + let Some(new_val) = new_val else { + continue; // skip nulls + }; + + let existing_val = match locations[group_index] { + // previous input value was the min/max, so compare it + MinMaxLocation::Input(existing_val) => existing_val, + MinMaxLocation::ExistingMinMax => { + let Some(exising_val) = self.min_max[group_index].as_ref() else { + // no existing min/max, so this is the new min/max + locations[group_index] = MinMaxLocation::Input(new_val); + continue; + }; + exising_val.as_ref() + } + }; + + // Compare the new value to the existing value, replacing if necessary + if cmp(new_val, existing_val) { + locations[group_index] = MinMaxLocation::Input(new_val); + } + } + + // Update self.min_max with any new min/max values we found in the input + for (group_index, location) in locations.iter().enumerate() { + match location { + MinMaxLocation::ExistingMinMax => {} + MinMaxLocation::Input(new_val) => self.set_value(group_index, new_val), + } + } + Ok(()) + } + + /// Emits the specified min_max values + /// + /// Returns (data_capacity, min_maxes), updating the current value of total_data_bytes + /// + /// - `data_capacity`: the total length of all strings and their contents, + /// - `min_maxes`: the actual min/max values for each group + fn emit_to(&mut self, emit_to: EmitTo) -> (usize, Vec>>) { + match emit_to { + EmitTo::All => { + ( + std::mem::take(&mut self.total_data_bytes), // reset total bytes and min_max + std::mem::take(&mut self.min_max), + ) + } + EmitTo::First(n) => { + let first_min_maxes: Vec<_> = self.min_max.drain(..n).collect(); + let first_data_capacity: usize = first_min_maxes + .iter() + .map(|opt| opt.as_ref().map(|s| s.len()).unwrap_or(0)) + .sum(); + self.total_data_bytes -= first_data_capacity; + (first_data_capacity, first_min_maxes) + } + } + } + + fn size(&self) -> usize { + self.total_data_bytes + self.min_max.len() * size_of::>>() + } +} diff --git a/datafusion/physical-expr/src/aggregate/nth_value.rs b/datafusion/functions-aggregate/src/nth_value.rs similarity index 65% rename from datafusion/physical-expr/src/aggregate/nth_value.rs rename to datafusion/functions-aggregate/src/nth_value.rs index dba259a507fd..5f3a8cf2f161 100644 --- a/datafusion/physical-expr/src/aggregate/nth_value.rs +++ b/datafusion/functions-aggregate/src/nth_value.rs @@ -20,146 +20,188 @@ use std::any::Any; use std::collections::VecDeque; -use std::sync::Arc; +use std::mem::{size_of, size_of_val}; +use std::sync::{Arc, OnceLock}; -use crate::aggregate::array_agg_ordered::merge_ordered_arrays; -use crate::aggregate::utils::{down_cast_any_ref, ordering_fields}; -use crate::expressions::{format_state_name, Literal}; -use crate::{ - reverse_order_bys, AggregateExpr, LexOrdering, PhysicalExpr, PhysicalSortExpr, +use arrow::array::{new_empty_array, ArrayRef, AsArray, StructArray}; +use arrow_schema::{DataType, Field, Fields}; + +use datafusion_common::utils::{array_into_list_array_nullable, get_row_at_idx}; +use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_STATISTICAL; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{ + lit, Accumulator, AggregateUDFImpl, Documentation, ExprFunctionExt, ReversedUDAF, + Signature, SortExpr, Volatility, }; +use datafusion_functions_aggregate_common::merge_arrays::merge_ordered_arrays; +use datafusion_functions_aggregate_common::utils::ordering_fields; +use datafusion_physical_expr::expressions::Literal; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; -use arrow_array::cast::AsArray; -use arrow_array::{new_empty_array, ArrayRef, StructArray}; -use arrow_schema::{DataType, Field, Fields}; -use datafusion_common::utils::{array_into_list_array, get_row_at_idx}; -use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; -use datafusion_expr::Accumulator; +create_func!(NthValueAgg, nth_value_udaf); + +/// Returns the nth value in a group of values. +pub fn nth_value( + expr: datafusion_expr::Expr, + n: i64, + order_by: Vec, +) -> datafusion_expr::Expr { + let args = vec![expr, lit(n)]; + if !order_by.is_empty() { + nth_value_udaf() + .call(args) + .order_by(order_by) + .build() + .unwrap() + } else { + nth_value_udaf().call(args) + } +} /// Expression for a `NTH_VALUE(... ORDER BY ..., ...)` aggregation. In a multi /// partition setting, partial aggregations are computed for every partition, /// and then their results are merged. #[derive(Debug)] pub struct NthValueAgg { - /// Column name - name: String, - /// The `DataType` for the input expression - input_data_type: DataType, - /// The input expression - expr: Arc, - /// The `N` value. - n: i64, - /// If the input expression can have `NULL`s - nullable: bool, - /// Ordering data types - order_by_data_types: Vec, - /// Ordering requirement - ordering_req: LexOrdering, + signature: Signature, } impl NthValueAgg { /// Create a new `NthValueAgg` aggregate function - pub fn new( - expr: Arc, - n: i64, - name: impl Into, - input_data_type: DataType, - nullable: bool, - order_by_data_types: Vec, - ordering_req: LexOrdering, - ) -> Self { + pub fn new() -> Self { Self { - name: name.into(), - input_data_type, - expr, - n, - nullable, - order_by_data_types, - ordering_req, + signature: Signature::any(2, Volatility::Immutable), } } } -impl AggregateExpr for NthValueAgg { +impl Default for NthValueAgg { + fn default() -> Self { + Self::new() + } +} + +impl AggregateUDFImpl for NthValueAgg { fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.input_data_type.clone(), true)) + fn name(&self) -> &str { + "nth_value" } - fn create_accumulator(&self) -> Result> { - Ok(Box::new(NthValueAccumulator::try_new( - self.n, - &self.input_data_type, - &self.order_by_data_types, - self.ordering_req.clone(), - )?)) + fn signature(&self) -> &Signature { + &self.signature } - fn state_fields(&self) -> Result> { + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let n = match acc_args.exprs[1] + .as_any() + .downcast_ref::() + .map(|lit| lit.value()) + { + Some(ScalarValue::Int64(Some(value))) => { + if acc_args.is_reversed { + -*value + } else { + *value + } + } + _ => { + return not_impl_err!( + "{} not supported for n: {}", + self.name(), + &acc_args.exprs[1] + ) + } + }; + + let ordering_dtypes = acc_args + .ordering_req + .iter() + .map(|e| e.expr.data_type(acc_args.schema)) + .collect::>>()?; + + let data_type = acc_args.exprs[0].data_type(acc_args.schema)?; + NthValueAccumulator::try_new( + n, + &data_type, + &ordering_dtypes, + LexOrdering::from_ref(acc_args.ordering_req), + ) + .map(|acc| Box::new(acc) as _) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { let mut fields = vec![Field::new_list( - format_state_name(&self.name, "nth_value"), - Field::new("item", self.input_data_type.clone(), true), - self.nullable, // This should be the same as field() + format_state_name(self.name(), "nth_value"), + // See COMMENTS.md to understand why nullable is set to true + Field::new("item", args.input_types[0].clone(), true), + false, )]; - if !self.ordering_req.is_empty() { - let orderings = - ordering_fields(&self.ordering_req, &self.order_by_data_types); + let orderings = args.ordering_fields.to_vec(); + if !orderings.is_empty() { fields.push(Field::new_list( - format_state_name(&self.name, "nth_value_orderings"), + format_state_name(self.name(), "nth_value_orderings"), Field::new("item", DataType::Struct(Fields::from(orderings)), true), - self.nullable, + false, )); } Ok(fields) } - fn expressions(&self) -> Vec> { - let n = Arc::new(Literal::new(ScalarValue::Int64(Some(self.n)))) as _; - vec![self.expr.clone(), n] - } - - fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - (!self.ordering_req.is_empty()).then_some(&self.ordering_req) + fn aliases(&self) -> &[String] { + &[] } - fn name(&self) -> &str { - &self.name + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Reversed(nth_value_udaf()) } - fn reverse_expr(&self) -> Option> { - Some(Arc::new(Self { - name: self.name.to_string(), - input_data_type: self.input_data_type.clone(), - expr: self.expr.clone(), - // index should be from the opposite side - n: -self.n, - nullable: self.nullable, - order_by_data_types: self.order_by_data_types.clone(), - // reverse requirement - ordering_req: reverse_order_bys(&self.ordering_req), - }) as _) + fn documentation(&self) -> Option<&Documentation> { + Some(get_nth_value_doc()) } } -impl PartialEq for NthValueAgg { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.input_data_type == x.input_data_type - && self.order_by_data_types == x.order_by_data_types - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_nth_value_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Returns the nth value in a group of values.", + ) + .with_syntax_example("nth_value(expression, n ORDER BY expression)") + .with_sql_example(r#"```sql +> SELECT dept_id, salary, NTH_VALUE(salary, 2) OVER (PARTITION BY dept_id ORDER BY salary ASC) AS second_salary_by_dept + FROM employee; ++---------+--------+-------------------------+ +| dept_id | salary | second_salary_by_dept | ++---------+--------+-------------------------+ +| 1 | 30000 | NULL | +| 1 | 40000 | 40000 | +| 1 | 50000 | 40000 | +| 2 | 35000 | NULL | +| 2 | 45000 | 45000 | ++---------+--------+-------------------------+ +```"#) + .with_argument("expression", "The column or expression to retrieve the nth value from.") + .with_argument("n", "The position (nth) of the value to retrieve, based on the ordering.") + .build() + .unwrap() + }) } #[derive(Debug)] -pub(crate) struct NthValueAccumulator { +pub struct NthValueAccumulator { + /// The `N` value. n: i64, /// Stores entries in the `NTH_VALUE` result. values: VecDeque, @@ -337,25 +379,23 @@ impl Accumulator for NthValueAccumulator { } fn size(&self) -> usize { - let mut total = std::mem::size_of_val(self) - + ScalarValue::size_of_vec_deque(&self.values) - - std::mem::size_of_val(&self.values); + let mut total = size_of_val(self) + ScalarValue::size_of_vec_deque(&self.values) + - size_of_val(&self.values); // Add size of the `self.ordering_values` - total += - std::mem::size_of::>() * self.ordering_values.capacity(); + total += size_of::>() * self.ordering_values.capacity(); for row in &self.ordering_values { - total += ScalarValue::size_of_vec(row) - std::mem::size_of_val(row); + total += ScalarValue::size_of_vec(row) - size_of_val(row); } // Add size of the `self.datatypes` - total += std::mem::size_of::() * self.datatypes.capacity(); + total += size_of::() * self.datatypes.capacity(); for dtype in &self.datatypes { - total += dtype.size() - std::mem::size_of_val(dtype); + total += dtype.size() - size_of_val(dtype); } // Add size of the `self.ordering_req` - total += std::mem::size_of::() * self.ordering_req.capacity(); + total += size_of::() * self.ordering_req.capacity(); // TODO: Calculate size of each `PhysicalSortExpr` more accurately. total } @@ -363,7 +403,7 @@ impl Accumulator for NthValueAccumulator { impl NthValueAccumulator { fn evaluate_orderings(&self) -> Result { - let fields = ordering_fields(&self.ordering_req, &self.datatypes[1..]); + let fields = ordering_fields(self.ordering_req.as_ref(), &self.datatypes[1..]); let struct_field = Fields::from(fields.clone()); let mut column_wise_ordering_values = vec![]; @@ -382,13 +422,10 @@ impl NthValueAccumulator { column_wise_ordering_values.push(array); } - let ordering_array = StructArray::try_new( - struct_field.clone(), - column_wise_ordering_values, - None, - )?; + let ordering_array = + StructArray::try_new(struct_field, column_wise_ordering_values, None)?; - Ok(ScalarValue::List(Arc::new(array_into_list_array( + Ok(ScalarValue::List(Arc::new(array_into_list_array_nullable( Arc::new(ordering_array), )))) } @@ -396,7 +433,10 @@ impl NthValueAccumulator { fn evaluate_values(&self) -> ScalarValue { let mut values_cloned = self.values.clone(); let values_slice = values_cloned.make_contiguous(); - ScalarValue::List(ScalarValue::new_list(values_slice, &self.datatypes[0])) + ScalarValue::List(ScalarValue::new_list_nullable( + values_slice, + &self.datatypes[0], + )) } /// Updates state, with the `values`. Fetch contains missing number of entries for state to be complete diff --git a/datafusion/physical-expr/src/aggregate/regr.rs b/datafusion/functions-aggregate/src/regr.rs similarity index 61% rename from datafusion/physical-expr/src/aggregate/regr.rs rename to datafusion/functions-aggregate/src/regr.rs index 36e7b7c9b3e4..bf1e81949d23 100644 --- a/datafusion/physical-expr/src/aggregate/regr.rs +++ b/datafusion/functions-aggregate/src/regr.rs @@ -17,10 +17,6 @@ //! Defines physical expressions that can evaluated at runtime during query execution -use std::any::Any; -use std::sync::Arc; - -use crate::{AggregateExpr, PhysicalExpr}; use arrow::array::Float64Array; use arrow::{ array::{ArrayRef, UInt64Array}, @@ -28,28 +24,64 @@ use arrow::{ datatypes::DataType, datatypes::Field, }; -use datafusion_common::{downcast_value, unwrap_or_internal_err, ScalarValue}; +use datafusion_common::{downcast_value, plan_err, unwrap_or_internal_err, ScalarValue}; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::Accumulator; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_STATISTICAL; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::type_coercion::aggregates::NUMERICS; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, Signature, Volatility, +}; +use std::any::Any; +use std::collections::HashMap; +use std::fmt::Debug; +use std::mem::size_of_val; +use std::sync::OnceLock; + +macro_rules! make_regr_udaf_expr_and_func { + ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $REGR_TYPE:expr) => { + make_udaf_expr!($EXPR_FN, expr_y expr_x, concat!("Compute a linear regression of type [", stringify!($REGR_TYPE), "]"), $AGGREGATE_UDF_FN); + create_func!($EXPR_FN, $AGGREGATE_UDF_FN, Regr::new($REGR_TYPE, stringify!($EXPR_FN))); + } +} -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; +make_regr_udaf_expr_and_func!(regr_slope, regr_slope_udaf, RegrType::Slope); +make_regr_udaf_expr_and_func!(regr_intercept, regr_intercept_udaf, RegrType::Intercept); +make_regr_udaf_expr_and_func!(regr_count, regr_count_udaf, RegrType::Count); +make_regr_udaf_expr_and_func!(regr_r2, regr_r2_udaf, RegrType::R2); +make_regr_udaf_expr_and_func!(regr_avgx, regr_avgx_udaf, RegrType::AvgX); +make_regr_udaf_expr_and_func!(regr_avgy, regr_avgy_udaf, RegrType::AvgY); +make_regr_udaf_expr_and_func!(regr_sxx, regr_sxx_udaf, RegrType::SXX); +make_regr_udaf_expr_and_func!(regr_syy, regr_syy_udaf, RegrType::SYY); +make_regr_udaf_expr_and_func!(regr_sxy, regr_sxy_udaf, RegrType::SXY); -#[derive(Debug)] pub struct Regr { - name: String, + signature: Signature, regr_type: RegrType, - expr_y: Arc, - expr_x: Arc, + func_name: &'static str, +} + +impl Debug for Regr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("regr") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } } impl Regr { - pub fn get_regr_type(&self) -> RegrType { - self.regr_type.clone() + pub fn new(regr_type: RegrType, func_name: &'static str) -> Self { + Self { + signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + regr_type, + func_name, + } } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Hash, Eq)] #[allow(clippy::upper_case_acronyms)] pub enum RegrType { /// Variant for `regr_slope` aggregate expression @@ -92,96 +124,214 @@ pub enum RegrType { SXY, } -impl Regr { - pub fn new( - expr_y: Arc, - expr_x: Arc, - name: impl Into, - regr_type: RegrType, - return_type: DataType, - ) -> Self { - // the result of regr_slope only support FLOAT64 data type. - assert!(matches!(return_type, DataType::Float64)); - Self { - name: name.into(), - regr_type, - expr_y, - expr_x, - } +impl RegrType { + /// return the documentation for the `RegrType` + fn documentation(&self) -> Option<&Documentation> { + get_regr_docs().get(self) } } -impl AggregateExpr for Regr { +static DOCUMENTATION: OnceLock> = OnceLock::new(); +fn get_regr_docs() -> &'static HashMap { + DOCUMENTATION.get_or_init(|| { + let mut hash_map = HashMap::new(); + hash_map.insert( + RegrType::Slope, + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Returns the slope of the linear regression line for non-null pairs in aggregate columns. \ + Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k*X + b) using minimal RSS fitting.", + ) + .with_syntax_example("regr_slope(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + .unwrap() + ); + + hash_map.insert( + RegrType::Intercept, + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Computes the y-intercept of the linear regression line. For the equation (y = kx + b), \ + this function returns b.", + ) + .with_syntax_example("regr_intercept(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + .unwrap() + ); + + hash_map.insert( + RegrType::Count, + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Counts the number of non-null paired data points.", + ) + .with_syntax_example("regr_count(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + .unwrap() + ); + + hash_map.insert( + RegrType::R2, + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Computes the square of the correlation coefficient between the independent and dependent variables.", + ) + .with_syntax_example("regr_r2(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + .unwrap() + ); + + hash_map.insert( + RegrType::AvgX, + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Computes the average of the independent variable (input) expression_x for the non-null paired data points.", + ) + .with_syntax_example("regr_avgx(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + .unwrap() + ); + + hash_map.insert( + RegrType::AvgY, + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Computes the average of the dependent variable (output) expression_y for the non-null paired data points.", + ) + .with_syntax_example("regr_avgy(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + .unwrap() + ); + + hash_map.insert( + RegrType::SXX, + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Computes the sum of squares of the independent variable.", + ) + .with_syntax_example("regr_sxx(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + .unwrap() + ); + + hash_map.insert( + RegrType::SYY, + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Computes the sum of squares of the dependent variable.", + ) + .with_syntax_example("regr_syy(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + .unwrap() + ); + + hash_map.insert( + RegrType::SXY, + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Computes the sum of products of paired data points.", + ) + .with_syntax_example("regr_sxy(expression_y, expression_x)") + .with_standard_argument("expression_y", Some("Dependent variable")) + .with_standard_argument("expression_x", Some("Independent variable")) + .build() + .unwrap() + ); + hash_map + }) +} + +impl AggregateUDFImpl for Regr { fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64, true)) + fn name(&self) -> &str { + self.func_name } - fn create_accumulator(&self) -> Result> { - Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?)) + fn signature(&self) -> &Signature { + &self.signature } - fn create_sliding_accumulator(&self) -> Result> { + fn return_type(&self, arg_types: &[DataType]) -> Result { + if !arg_types[0].is_numeric() { + return plan_err!("Covariance requires numeric input types"); + } + + if matches!(self.regr_type, RegrType::Count) { + Ok(DataType::UInt64) + } else { + Ok(DataType::Float64) + } + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?)) } - fn state_fields(&self) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( - format_state_name(&self.name, "count"), + format_state_name(args.name, "count"), DataType::UInt64, true, ), Field::new( - format_state_name(&self.name, "mean_x"), + format_state_name(args.name, "mean_x"), DataType::Float64, true, ), Field::new( - format_state_name(&self.name, "mean_y"), + format_state_name(args.name, "mean_y"), DataType::Float64, true, ), Field::new( - format_state_name(&self.name, "m2_x"), + format_state_name(args.name, "m2_x"), DataType::Float64, true, ), Field::new( - format_state_name(&self.name, "m2_y"), + format_state_name(args.name, "m2_y"), DataType::Float64, true, ), Field::new( - format_state_name(&self.name, "algo_const"), + format_state_name(args.name, "algo_const"), DataType::Float64, true, ), ]) } - fn expressions(&self) -> Vec> { - vec![self.expr_y.clone(), self.expr_x.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for Regr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.expr_y.eq(&x.expr_y) - && self.expr_x.eq(&x.expr_x) - }) - .unwrap_or(false) + fn documentation(&self) -> Option<&Documentation> { + self.regr_type.documentation() } } @@ -305,6 +455,10 @@ impl Accumulator for RegrAccumulator { Ok(()) } + fn supports_retract_batch(&self) -> bool { + true + } + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values_y = &cast(&values[0], &DataType::Float64)?; let values_x = &cast(&values[1], &DataType::Float64)?; @@ -443,7 +597,7 @@ impl Accumulator for RegrAccumulator { let nullif_cond = self.count <= 1 || var_pop_x == 0.0; nullif_or_stat(nullif_cond, self.mean_y - slope * self.mean_x) } - RegrType::Count => Ok(ScalarValue::Float64(Some(self.count as f64))), + RegrType::Count => Ok(ScalarValue::UInt64(Some(self.count))), RegrType::R2 => { // Only 0/1 point or all x(or y) is the same let nullif_cond = self.count <= 1 || var_pop_x == 0.0 || var_pop_y == 0.0; @@ -461,6 +615,6 @@ impl Accumulator for RegrAccumulator { } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } diff --git a/datafusion/functions-aggregate/src/stddev.rs b/datafusion/functions-aggregate/src/stddev.rs new file mode 100644 index 000000000000..95269ed8217c --- /dev/null +++ b/datafusion/functions-aggregate/src/stddev.rs @@ -0,0 +1,500 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines physical expressions that can evaluated at runtime during query execution + +use std::any::Any; +use std::fmt::{Debug, Formatter}; +use std::mem::align_of_val; +use std::sync::{Arc, OnceLock}; + +use arrow::array::Float64Array; +use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; + +use datafusion_common::{internal_err, not_impl_err, Result}; +use datafusion_common::{plan_err, ScalarValue}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_STATISTICAL; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature, + Volatility, +}; +use datafusion_functions_aggregate_common::stats::StatsType; + +use crate::variance::{VarianceAccumulator, VarianceGroupsAccumulator}; + +make_udaf_expr_and_func!( + Stddev, + stddev, + expression, + "Compute the standard deviation of a set of numbers", + stddev_udaf +); + +/// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression +pub struct Stddev { + signature: Signature, + alias: Vec, +} + +impl Debug for Stddev { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Stddev") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for Stddev { + fn default() -> Self { + Self::new() + } +} + +impl Stddev { + /// Create a new STDDEV aggregate function + pub fn new() -> Self { + Self { + signature: Signature::coercible( + vec![DataType::Float64], + Volatility::Immutable, + ), + alias: vec!["stddev_samp".to_string()], + } + } +} + +impl AggregateUDFImpl for Stddev { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "stddev" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + Ok(vec![ + Field::new( + format_state_name(args.name, "count"), + DataType::UInt64, + true, + ), + Field::new( + format_state_name(args.name, "mean"), + DataType::Float64, + true, + ), + Field::new(format_state_name(args.name, "m2"), DataType::Float64, true), + ]) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + if acc_args.is_distinct { + return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available"); + } + Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?)) + } + + fn aliases(&self) -> &[String] { + &self.alias + } + + fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool { + !acc_args.is_distinct + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + Ok(Box::new(StddevGroupsAccumulator::new(StatsType::Sample))) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_stddev_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_stddev_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description("Returns the standard deviation of a set of numbers.") + .with_syntax_example("stddev(expression)") + .with_sql_example( + r#"```sql +> SELECT stddev(column_name) FROM table_name; ++----------------------+ +| stddev(column_name) | ++----------------------+ +| 12.34 | ++----------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) +} + +make_udaf_expr_and_func!( + StddevPop, + stddev_pop, + expression, + "Compute the population standard deviation of a set of numbers", + stddev_pop_udaf +); + +/// STDDEV_POP population aggregate expression +pub struct StddevPop { + signature: Signature, +} + +impl Debug for StddevPop { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("StddevPop") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for StddevPop { + fn default() -> Self { + Self::new() + } +} + +impl StddevPop { + /// Create a new STDDEV_POP aggregate function + pub fn new() -> Self { + Self { + signature: Signature::numeric(1, Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for StddevPop { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "stddev_pop" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + Ok(vec![ + Field::new( + format_state_name(args.name, "count"), + DataType::UInt64, + true, + ), + Field::new( + format_state_name(args.name, "mean"), + DataType::Float64, + true, + ), + Field::new(format_state_name(args.name, "m2"), DataType::Float64, true), + ]) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + if acc_args.is_distinct { + return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available"); + } + Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?)) + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if !arg_types[0].is_numeric() { + return plan_err!("StddevPop requires numeric input types"); + } + + Ok(DataType::Float64) + } + + fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool { + !acc_args.is_distinct + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + Ok(Box::new(StddevGroupsAccumulator::new( + StatsType::Population, + ))) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_stddev_pop_doc()) + } +} + +fn get_stddev_pop_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STATISTICAL) + .with_description( + "Returns the population standard deviation of a set of numbers.", + ) + .with_syntax_example("stddev_pop(expression)") + .with_sql_example( + r#"```sql +> SELECT stddev_pop(column_name) FROM table_name; ++--------------------------+ +| stddev_pop(column_name) | ++--------------------------+ +| 10.56 | ++--------------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) +} + +/// An accumulator to compute the average +#[derive(Debug)] +pub struct StddevAccumulator { + variance: VarianceAccumulator, +} + +impl StddevAccumulator { + /// Creates a new `StddevAccumulator` + pub fn try_new(s_type: StatsType) -> Result { + Ok(Self { + variance: VarianceAccumulator::try_new(s_type)?, + }) + } + + pub fn get_m2(&self) -> f64 { + self.variance.get_m2() + } +} + +impl Accumulator for StddevAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![ + ScalarValue::from(self.variance.get_count()), + ScalarValue::from(self.variance.get_mean()), + ScalarValue::from(self.variance.get_m2()), + ]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.variance.update_batch(values) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.variance.retract_batch(values) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.variance.merge_batch(states) + } + + fn evaluate(&mut self) -> Result { + let variance = self.variance.evaluate()?; + match variance { + ScalarValue::Float64(e) => { + if e.is_none() { + Ok(ScalarValue::Float64(None)) + } else { + Ok(ScalarValue::Float64(e.map(|f| f.sqrt()))) + } + } + _ => internal_err!("Variance should be f64"), + } + } + + fn size(&self) -> usize { + align_of_val(self) - align_of_val(&self.variance) + self.variance.size() + } + + fn supports_retract_batch(&self) -> bool { + self.variance.supports_retract_batch() + } +} + +#[derive(Debug)] +pub struct StddevGroupsAccumulator { + variance: VarianceGroupsAccumulator, +} + +impl StddevGroupsAccumulator { + pub fn new(s_type: StatsType) -> Self { + Self { + variance: VarianceGroupsAccumulator::new(s_type), + } + } +} + +impl GroupsAccumulator for StddevGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&arrow::array::BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + self.variance + .update_batch(values, group_indices, opt_filter, total_num_groups) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&arrow::array::BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + self.variance + .merge_batch(values, group_indices, opt_filter, total_num_groups) + } + + fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result { + let (mut variances, nulls) = self.variance.variance(emit_to); + variances.iter_mut().for_each(|v| *v = v.sqrt()); + Ok(Arc::new(Float64Array::new(variances.into(), Some(nulls)))) + } + + fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result> { + self.variance.state(emit_to) + } + + fn size(&self) -> usize { + self.variance.size() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::{array::*, datatypes::*}; + use datafusion_expr::AggregateUDF; + use datafusion_functions_aggregate_common::utils::get_accum_scalar_values_as_arrays; + use datafusion_physical_expr::expressions::col; + use datafusion_physical_expr_common::sort_expr::LexOrderingRef; + use std::sync::Arc; + + #[test] + fn stddev_f64_merge_1() -> Result<()> { + let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64])); + let b = Arc::new(Float64Array::from(vec![4_f64, 5_f64])); + + let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); + + let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; + let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![b])?; + + let agg1 = stddev_pop_udaf(); + let agg2 = stddev_pop_udaf(); + + let actual = merge(&batch1, &batch2, agg1, agg2, &schema)?; + assert_eq!(actual, ScalarValue::from(std::f64::consts::SQRT_2)); + + Ok(()) + } + + #[test] + fn stddev_f64_merge_2() -> Result<()> { + let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); + let b = Arc::new(Float64Array::from(vec![None])); + + let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); + + let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; + let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![b])?; + + let agg1 = stddev_pop_udaf(); + let agg2 = stddev_pop_udaf(); + + let actual = merge(&batch1, &batch2, agg1, agg2, &schema)?; + assert_eq!(actual, ScalarValue::from(std::f64::consts::SQRT_2)); + + Ok(()) + } + + fn merge( + batch1: &RecordBatch, + batch2: &RecordBatch, + agg1: Arc, + agg2: Arc, + schema: &Schema, + ) -> Result { + let args1 = AccumulatorArgs { + return_type: &DataType::Float64, + schema, + ignore_nulls: false, + ordering_req: LexOrderingRef::default(), + name: "a", + is_distinct: false, + is_reversed: false, + exprs: &[col("a", schema)?], + }; + + let args2 = AccumulatorArgs { + return_type: &DataType::Float64, + schema, + ignore_nulls: false, + ordering_req: LexOrderingRef::default(), + name: "a", + is_distinct: false, + is_reversed: false, + exprs: &[col("a", schema)?], + }; + + let mut accum1 = agg1.accumulator(args1)?; + let mut accum2 = agg2.accumulator(args2)?; + + let value1 = vec![col("a", schema)? + .evaluate(batch1) + .and_then(|v| v.into_array(batch1.num_rows()))?]; + let value2 = vec![col("a", schema)? + .evaluate(batch2) + .and_then(|v| v.into_array(batch2.num_rows()))?]; + + accum1.update_batch(&value1)?; + accum2.update_batch(&value2)?; + let state2 = get_accum_scalar_values_as_arrays(accum2.as_mut())?; + accum1.merge_batch(&state2)?; + let result = accum1.evaluate()?; + Ok(result) + } +} diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs new file mode 100644 index 000000000000..68267b9f72c7 --- /dev/null +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -0,0 +1,187 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`StringAgg`] accumulator for the `string_agg` function + +use arrow::array::ArrayRef; +use arrow_schema::DataType; +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::Result; +use datafusion_common::{not_impl_err, ScalarValue}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; +use datafusion_expr::function::AccumulatorArgs; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility, +}; +use datafusion_physical_expr::expressions::Literal; +use std::any::Any; +use std::mem::size_of_val; +use std::sync::OnceLock; + +make_udaf_expr_and_func!( + StringAgg, + string_agg, + expr delimiter, + "Concatenates the values of string expressions and places separator values between them", + string_agg_udaf +); + +/// STRING_AGG aggregate expression +#[derive(Debug)] +pub struct StringAgg { + signature: Signature, +} + +impl StringAgg { + /// Create a new StringAgg aggregate function + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]), + TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]), + TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Null]), + ], + Volatility::Immutable, + ), + } + } +} + +impl Default for StringAgg { + fn default() -> Self { + Self::new() + } +} + +impl AggregateUDFImpl for StringAgg { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "string_agg" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::LargeUtf8) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + if let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::() { + return match lit.value() { + ScalarValue::Utf8(Some(delimiter)) + | ScalarValue::LargeUtf8(Some(delimiter)) => { + Ok(Box::new(StringAggAccumulator::new(delimiter.as_str()))) + } + ScalarValue::Utf8(None) + | ScalarValue::LargeUtf8(None) + | ScalarValue::Null => Ok(Box::new(StringAggAccumulator::new(""))), + e => not_impl_err!("StringAgg not supported for delimiter {}", e), + }; + } + + not_impl_err!("expect literal") + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_string_agg_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_string_agg_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description( + "Concatenates the values of string expressions and places separator values between them." + ) + .with_syntax_example("string_agg(expression, delimiter)") + .with_sql_example(r#"```sql +> SELECT string_agg(name, ', ') AS names_list + FROM employee; ++--------------------------+ +| names_list | ++--------------------------+ +| Alice, Bob, Charlie | ++--------------------------+ +```"#, + ) + .with_argument("expression", "The string expression to concatenate. Can be a column or any valid string expression.") + .with_argument("delimiter", "A literal string used as a separator between the concatenated values.") + .build() + .unwrap() + }) +} + +#[derive(Debug)] +pub(crate) struct StringAggAccumulator { + values: Option, + delimiter: String, +} + +impl StringAggAccumulator { + pub fn new(delimiter: &str) -> Self { + Self { + values: None, + delimiter: delimiter.to_string(), + } + } +} + +impl Accumulator for StringAggAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let string_array: Vec<_> = as_generic_string_array::(&values[0])? + .iter() + .filter_map(|v| v.as_ref().map(ToString::to_string)) + .collect(); + if !string_array.is_empty() { + let s = string_array.join(self.delimiter.as_str()); + let v = self.values.get_or_insert("".to_string()); + if !v.is_empty() { + v.push_str(self.delimiter.as_str()); + } + v.push_str(s.as_str()); + } + Ok(()) + } + + fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.update_batch(values)?; + Ok(()) + } + + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn evaluate(&mut self) -> Result { + Ok(ScalarValue::LargeUtf8(self.values.clone())) + } + + fn size(&self) -> usize { + size_of_val(self) + + self.values.as_ref().map(|v| v.capacity()).unwrap_or(0) + + self.delimiter.capacity() + } +} diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs new file mode 100644 index 000000000000..6ad376db4fb9 --- /dev/null +++ b/datafusion/functions-aggregate/src/sum.rs @@ -0,0 +1,470 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines `SUM` and `SUM DISTINCT` aggregate accumulators + +use ahash::RandomState; +use datafusion_expr::utils::AggregateOrderSensitivity; +use std::any::Any; +use std::collections::HashSet; +use std::mem::{size_of, size_of_val}; +use std::sync::OnceLock; + +use arrow::array::Array; +use arrow::array::ArrowNativeTypeOp; +use arrow::array::{ArrowNumericType, AsArray}; +use arrow::datatypes::ArrowNativeType; +use arrow::datatypes::ArrowPrimitiveType; +use arrow::datatypes::{ + DataType, Decimal128Type, Decimal256Type, Float64Type, Int64Type, UInt64Type, + DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, +}; +use arrow::{array::ArrayRef, datatypes::Field}; +use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; +use datafusion_expr::function::AccumulatorArgs; +use datafusion_expr::function::StateFieldsArgs; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, ReversedUDAF, + Signature, Volatility, +}; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; +use datafusion_functions_aggregate_common::utils::Hashable; + +make_udaf_expr_and_func!( + Sum, + sum, + expression, + "Returns the sum of a group of values.", + sum_udaf +); + +/// Sum only supports a subset of numeric types, instead relying on type coercion +/// +/// This macro is similar to [downcast_primitive](arrow::array::downcast_primitive) +/// +/// `args` is [AccumulatorArgs] +/// `helper` is a macro accepting (ArrowPrimitiveType, DataType) +macro_rules! downcast_sum { + ($args:ident, $helper:ident) => { + match $args.return_type { + DataType::UInt64 => $helper!(UInt64Type, $args.return_type), + DataType::Int64 => $helper!(Int64Type, $args.return_type), + DataType::Float64 => $helper!(Float64Type, $args.return_type), + DataType::Decimal128(_, _) => $helper!(Decimal128Type, $args.return_type), + DataType::Decimal256(_, _) => $helper!(Decimal256Type, $args.return_type), + _ => { + not_impl_err!( + "Sum not supported for {}: {}", + $args.name, + $args.return_type + ) + } + } + }; +} + +#[derive(Debug)] +pub struct Sum { + signature: Signature, +} + +impl Sum { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl Default for Sum { + fn default() -> Self { + Self::new() + } +} + +impl AggregateUDFImpl for Sum { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "sum" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 1 { + return exec_err!("SUM expects exactly one argument"); + } + + // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc + // smallint, int, bigint, real, double precision, decimal, or interval. + + fn coerced_type(data_type: &DataType) -> Result { + match data_type { + DataType::Dictionary(_, v) => coerced_type(v), + // in the spark, the result type is DECIMAL(min(38,precision+10), s) + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => { + Ok(data_type.clone()) + } + dt if dt.is_signed_integer() => Ok(DataType::Int64), + dt if dt.is_unsigned_integer() => Ok(DataType::UInt64), + dt if dt.is_floating() => Ok(DataType::Float64), + _ => exec_err!("Sum not supported for {}", data_type), + } + } + + Ok(vec![coerced_type(&arg_types[0])?]) + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match &arg_types[0] { + DataType::Int64 => Ok(DataType::Int64), + DataType::UInt64 => Ok(DataType::UInt64), + DataType::Float64 => Ok(DataType::Float64), + DataType::Decimal128(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+10), s) + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal128(new_precision, *scale)) + } + DataType::Decimal256(precision, scale) => { + // in the spark, the result type is DECIMAL(min(38,precision+10), s) + // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 + let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); + Ok(DataType::Decimal256(new_precision, *scale)) + } + other => { + exec_err!("[return_type] SUM not supported for {}", other) + } + } + } + + fn accumulator(&self, args: AccumulatorArgs) -> Result> { + if args.is_distinct { + macro_rules! helper { + ($t:ty, $dt:expr) => { + Ok(Box::new(DistinctSumAccumulator::<$t>::try_new(&$dt)?)) + }; + } + downcast_sum!(args, helper) + } else { + macro_rules! helper { + ($t:ty, $dt:expr) => { + Ok(Box::new(SumAccumulator::<$t>::new($dt.clone()))) + }; + } + downcast_sum!(args, helper) + } + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + if args.is_distinct { + Ok(vec![Field::new_list( + format_state_name(args.name, "sum distinct"), + // See COMMENTS.md to understand why nullable is set to true + Field::new("item", args.return_type.clone(), true), + false, + )]) + } else { + Ok(vec![Field::new( + format_state_name(args.name, "sum"), + args.return_type.clone(), + true, + )]) + } + } + + fn aliases(&self) -> &[String] { + &[] + } + + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + !args.is_distinct + } + + fn create_groups_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + macro_rules! helper { + ($t:ty, $dt:expr) => { + Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new( + &$dt, + |x, y| *x = x.add_wrapping(y), + ))) + }; + } + downcast_sum!(args, helper) + } + + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + macro_rules! helper { + ($t:ty, $dt:expr) => { + Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt.clone()))) + }; + } + downcast_sum!(args, helper) + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } + + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + AggregateOrderSensitivity::Insensitive + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_sum_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_sum_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description("Returns the sum of all values in the specified column.") + .with_syntax_example("sum(expression)") + .with_sql_example( + r#"```sql +> SELECT sum(column_name) FROM table_name; ++-----------------------+ +| sum(column_name) | ++-----------------------+ +| 12345 | ++-----------------------+ +```"#, + ) + .with_standard_argument("expression", None) + .build() + .unwrap() + }) +} + +/// This accumulator computes SUM incrementally +struct SumAccumulator { + sum: Option, + data_type: DataType, +} + +impl std::fmt::Debug for SumAccumulator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "SumAccumulator({})", self.data_type) + } +} + +impl SumAccumulator { + fn new(data_type: DataType) -> Self { + Self { + sum: None, + data_type, + } + } +} + +impl Accumulator for SumAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = values[0].as_primitive::(); + if let Some(x) = arrow::compute::sum(values) { + let v = self.sum.get_or_insert(T::Native::usize_as(0)); + *v = v.add_wrapping(x); + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.update_batch(states) + } + + fn evaluate(&mut self) -> Result { + ScalarValue::new_primitive::(self.sum, &self.data_type) + } + + fn size(&self) -> usize { + size_of_val(self) + } +} + +/// This accumulator incrementally computes sums over a sliding window +/// +/// This is separate from [`SumAccumulator`] as requires additional state +struct SlidingSumAccumulator { + sum: T::Native, + count: u64, + data_type: DataType, +} + +impl std::fmt::Debug for SlidingSumAccumulator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "SlidingSumAccumulator({})", self.data_type) + } +} + +impl SlidingSumAccumulator { + fn new(data_type: DataType) -> Self { + Self { + sum: T::Native::usize_as(0), + count: 0, + data_type, + } + } +} + +impl Accumulator for SlidingSumAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?, self.count.into()]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = values[0].as_primitive::(); + self.count += (values.len() - values.null_count()) as u64; + if let Some(x) = arrow::compute::sum(values) { + self.sum = self.sum.add_wrapping(x) + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let values = states[0].as_primitive::(); + if let Some(x) = arrow::compute::sum(values) { + self.sum = self.sum.add_wrapping(x) + } + if let Some(x) = arrow::compute::sum(states[1].as_primitive::()) { + self.count += x; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let v = (self.count != 0).then_some(self.sum); + ScalarValue::new_primitive::(v, &self.data_type) + } + + fn size(&self) -> usize { + size_of_val(self) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = values[0].as_primitive::(); + if let Some(x) = arrow::compute::sum(values) { + self.sum = self.sum.sub_wrapping(x) + } + self.count -= (values.len() - values.null_count()) as u64; + Ok(()) + } + + fn supports_retract_batch(&self) -> bool { + true + } +} + +struct DistinctSumAccumulator { + values: HashSet, RandomState>, + data_type: DataType, +} + +impl std::fmt::Debug for DistinctSumAccumulator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "DistinctSumAccumulator({})", self.data_type) + } +} + +impl DistinctSumAccumulator { + pub fn try_new(data_type: &DataType) -> Result { + Ok(Self { + values: HashSet::default(), + data_type: data_type.clone(), + }) + } +} + +impl Accumulator for DistinctSumAccumulator { + fn state(&mut self) -> Result> { + // 1. Stores aggregate state in `ScalarValue::List` + // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set + let state_out = { + let distinct_values = self + .values + .iter() + .map(|value| { + ScalarValue::new_primitive::(Some(value.0), &self.data_type) + }) + .collect::>>()?; + + vec![ScalarValue::List(ScalarValue::new_list_nullable( + &distinct_values, + &self.data_type, + ))] + }; + Ok(state_out) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + let array = values[0].as_primitive::(); + match array.nulls().filter(|x| x.null_count() > 0) { + Some(n) => { + for idx in n.valid_indices() { + self.values.insert(Hashable(array.value(idx))); + } + } + None => array.values().iter().for_each(|x| { + self.values.insert(Hashable(*x)); + }), + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + for x in states[0].as_list::().iter().flatten() { + self.update_batch(&[x])? + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let mut acc = T::Native::usize_as(0); + for distinct_value in self.values.iter() { + acc = acc.add_wrapping(distinct_value.0) + } + let v = (!self.values.is_empty()).then_some(acc); + ScalarValue::new_primitive::(v, &self.data_type) + } + + fn size(&self) -> usize { + size_of_val(self) + self.values.capacity() * size_of::() + } +} diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs new file mode 100644 index 000000000000..810247a2884a --- /dev/null +++ b/datafusion/functions-aggregate/src/variance.rs @@ -0,0 +1,614 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`VarianceSample`]: variance sample aggregations. +//! [`VariancePopulation`]: variance population aggregations. + +use arrow::{ + array::{Array, ArrayRef, BooleanArray, Float64Array, UInt64Array}, + buffer::NullBuffer, + compute::kernels::cast, + datatypes::{DataType, Field}, +}; +use std::mem::{size_of, size_of_val}; +use std::sync::OnceLock; +use std::{fmt::Debug, sync::Arc}; + +use datafusion_common::{ + downcast_value, not_impl_err, plan_err, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr::aggregate_doc_sections::DOC_SECTION_GENERAL; +use datafusion_expr::{ + function::{AccumulatorArgs, StateFieldsArgs}, + utils::format_state_name, + Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature, + Volatility, +}; +use datafusion_functions_aggregate_common::{ + aggregate::groups_accumulator::accumulate::accumulate, stats::StatsType, +}; + +make_udaf_expr_and_func!( + VarianceSample, + var_sample, + expression, + "Computes the sample variance.", + var_samp_udaf +); + +make_udaf_expr_and_func!( + VariancePopulation, + var_pop, + expression, + "Computes the population variance.", + var_pop_udaf +); + +pub struct VarianceSample { + signature: Signature, + aliases: Vec, +} + +impl Debug for VarianceSample { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("VarianceSample") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for VarianceSample { + fn default() -> Self { + Self::new() + } +} + +impl VarianceSample { + pub fn new() -> Self { + Self { + aliases: vec![String::from("var_sample"), String::from("var_samp")], + signature: Signature::coercible( + vec![DataType::Float64], + Volatility::Immutable, + ), + } + } +} + +impl AggregateUDFImpl for VarianceSample { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "var" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let name = args.name; + Ok(vec![ + Field::new(format_state_name(name, "count"), DataType::UInt64, true), + Field::new(format_state_name(name, "mean"), DataType::Float64, true), + Field::new(format_state_name(name, "m2"), DataType::Float64, true), + ]) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + if acc_args.is_distinct { + return not_impl_err!("VAR(DISTINCT) aggregations are not available"); + } + + Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?)) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool { + !acc_args.is_distinct + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + Ok(Box::new(VarianceGroupsAccumulator::new(StatsType::Sample))) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_variance_sample_doc()) + } +} + +static VARIANCE_SAMPLE_DOC: OnceLock = OnceLock::new(); + +fn get_variance_sample_doc() -> &'static Documentation { + VARIANCE_SAMPLE_DOC.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description( + "Returns the statistical sample variance of a set of numbers.", + ) + .with_syntax_example("var(expression)") + .with_standard_argument("expression", Some("Numeric")) + .build() + .unwrap() + }) +} + +pub struct VariancePopulation { + signature: Signature, + aliases: Vec, +} + +impl Debug for VariancePopulation { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("VariancePopulation") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for VariancePopulation { + fn default() -> Self { + Self::new() + } +} + +impl VariancePopulation { + pub fn new() -> Self { + Self { + aliases: vec![String::from("var_population")], + signature: Signature::numeric(1, Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for VariancePopulation { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "var_pop" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if !arg_types[0].is_numeric() { + return plan_err!("Variance requires numeric input types"); + } + + Ok(DataType::Float64) + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + let name = args.name; + Ok(vec![ + Field::new(format_state_name(name, "count"), DataType::UInt64, true), + Field::new(format_state_name(name, "mean"), DataType::Float64, true), + Field::new(format_state_name(name, "m2"), DataType::Float64, true), + ]) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + if acc_args.is_distinct { + return not_impl_err!("VAR_POP(DISTINCT) aggregations are not available"); + } + + Ok(Box::new(VarianceAccumulator::try_new( + StatsType::Population, + )?)) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn groups_accumulator_supported(&self, acc_args: AccumulatorArgs) -> bool { + !acc_args.is_distinct + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + Ok(Box::new(VarianceGroupsAccumulator::new( + StatsType::Population, + ))) + } + fn documentation(&self) -> Option<&Documentation> { + Some(get_variance_population_doc()) + } +} + +static VARIANCE_POPULATION_DOC: OnceLock = OnceLock::new(); + +fn get_variance_population_doc() -> &'static Documentation { + VARIANCE_POPULATION_DOC.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_GENERAL) + .with_description( + "Returns the statistical population variance of a set of numbers.", + ) + .with_syntax_example("var_pop(expression)") + .with_standard_argument("expression", Some("Numeric")) + .build() + .unwrap() + }) +} + +/// An accumulator to compute variance +/// The algorithm used is an online implementation and numerically stable. It is based on this paper: +/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products". +/// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577. +/// +/// The algorithm has been analyzed here: +/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances". +/// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154. + +#[derive(Debug)] +pub struct VarianceAccumulator { + m2: f64, + mean: f64, + count: u64, + stats_type: StatsType, +} + +impl VarianceAccumulator { + /// Creates a new `VarianceAccumulator` + pub fn try_new(s_type: StatsType) -> Result { + Ok(Self { + m2: 0_f64, + mean: 0_f64, + count: 0_u64, + stats_type: s_type, + }) + } + + pub fn get_count(&self) -> u64 { + self.count + } + + pub fn get_mean(&self) -> f64 { + self.mean + } + + pub fn get_m2(&self) -> f64 { + self.m2 + } +} + +#[inline] +fn merge( + count: u64, + mean: f64, + m2: f64, + count2: u64, + mean2: f64, + m22: f64, +) -> (u64, f64, f64) { + let new_count = count + count2; + let new_mean = + mean * count as f64 / new_count as f64 + mean2 * count2 as f64 / new_count as f64; + let delta = mean - mean2; + let new_m2 = + m2 + m22 + delta * delta * count as f64 * count2 as f64 / new_count as f64; + + (new_count, new_mean, new_m2) +} + +#[inline] +fn update(count: u64, mean: f64, m2: f64, value: f64) -> (u64, f64, f64) { + let new_count = count + 1; + let delta1 = value - mean; + let new_mean = delta1 / new_count as f64 + mean; + let delta2 = value - new_mean; + let new_m2 = m2 + delta1 * delta2; + + (new_count, new_mean, new_m2) +} + +impl Accumulator for VarianceAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![ + ScalarValue::from(self.count), + ScalarValue::from(self.mean), + ScalarValue::from(self.m2), + ]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &cast(&values[0], &DataType::Float64)?; + let arr = downcast_value!(values, Float64Array).iter().flatten(); + + for value in arr { + (self.count, self.mean, self.m2) = + update(self.count, self.mean, self.m2, value) + } + + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = &cast(&values[0], &DataType::Float64)?; + let arr = downcast_value!(values, Float64Array).iter().flatten(); + + for value in arr { + let new_count = self.count - 1; + let delta1 = self.mean - value; + let new_mean = delta1 / new_count as f64 + self.mean; + let delta2 = new_mean - value; + let new_m2 = self.m2 - delta1 * delta2; + + self.count -= 1; + self.mean = new_mean; + self.m2 = new_m2; + } + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let counts = downcast_value!(states[0], UInt64Array); + let means = downcast_value!(states[1], Float64Array); + let m2s = downcast_value!(states[2], Float64Array); + + for i in 0..counts.len() { + let c = counts.value(i); + if c == 0_u64 { + continue; + } + (self.count, self.mean, self.m2) = merge( + self.count, + self.mean, + self.m2, + c, + means.value(i), + m2s.value(i), + ) + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let count = match self.stats_type { + StatsType::Population => self.count, + StatsType::Sample => { + if self.count > 0 { + self.count - 1 + } else { + self.count + } + } + }; + + Ok(ScalarValue::Float64(match self.count { + 0 => None, + 1 => { + if let StatsType::Population = self.stats_type { + Some(0.0) + } else { + None + } + } + _ => Some(self.m2 / count as f64), + })) + } + + fn size(&self) -> usize { + size_of_val(self) + } + + fn supports_retract_batch(&self) -> bool { + true + } +} + +#[derive(Debug)] +pub struct VarianceGroupsAccumulator { + m2s: Vec, + means: Vec, + counts: Vec, + stats_type: StatsType, +} + +impl VarianceGroupsAccumulator { + pub fn new(s_type: StatsType) -> Self { + Self { + m2s: Vec::new(), + means: Vec::new(), + counts: Vec::new(), + stats_type: s_type, + } + } + + fn resize(&mut self, total_num_groups: usize) { + self.m2s.resize(total_num_groups, 0.0); + self.means.resize(total_num_groups, 0.0); + self.counts.resize(total_num_groups, 0); + } + + fn merge( + group_indices: &[usize], + counts: &UInt64Array, + means: &Float64Array, + m2s: &Float64Array, + opt_filter: Option<&BooleanArray>, + mut value_fn: F, + ) where + F: FnMut(usize, u64, f64, f64) + Send, + { + assert_eq!(counts.null_count(), 0); + assert_eq!(means.null_count(), 0); + assert_eq!(m2s.null_count(), 0); + + match opt_filter { + None => { + group_indices + .iter() + .zip(counts.values().iter()) + .zip(means.values().iter()) + .zip(m2s.values().iter()) + .for_each(|(((&group_index, &count), &mean), &m2)| { + value_fn(group_index, count, mean, m2); + }); + } + Some(filter) => { + group_indices + .iter() + .zip(counts.values().iter()) + .zip(means.values().iter()) + .zip(m2s.values().iter()) + .zip(filter.iter()) + .for_each( + |((((&group_index, &count), &mean), &m2), filter_value)| { + if let Some(true) = filter_value { + value_fn(group_index, count, mean, m2); + } + }, + ); + } + } + } + + pub fn variance( + &mut self, + emit_to: datafusion_expr::EmitTo, + ) -> (Vec, NullBuffer) { + let mut counts = emit_to.take_needed(&mut self.counts); + // means are only needed for updating m2s and are not needed for the final result. + // But we still need to take them to ensure the internal state is consistent. + let _ = emit_to.take_needed(&mut self.means); + let m2s = emit_to.take_needed(&mut self.m2s); + + if let StatsType::Sample = self.stats_type { + counts.iter_mut().for_each(|count| { + *count = count.saturating_sub(1); + }); + } + let nulls = NullBuffer::from_iter(counts.iter().map(|&count| count != 0)); + let variance = m2s + .iter() + .zip(counts) + .map(|(m2, count)| m2 / count as f64) + .collect(); + (variance, nulls) + } +} + +impl GroupsAccumulator for VarianceGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to update_batch"); + let values = &cast(&values[0], &DataType::Float64)?; + let values = downcast_value!(values, Float64Array); + + self.resize(total_num_groups); + accumulate(group_indices, values, opt_filter, |group_index, value| { + let (new_count, new_mean, new_m2) = update( + self.counts[group_index], + self.means[group_index], + self.m2s[group_index], + value, + ); + self.counts[group_index] = new_count; + self.means[group_index] = new_mean; + self.m2s[group_index] = new_m2; + }); + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 3, "two arguments to merge_batch"); + // first batch is counts, second is partial means, third is partial m2s + let partial_counts = downcast_value!(values[0], UInt64Array); + let partial_means = downcast_value!(values[1], Float64Array); + let partial_m2s = downcast_value!(values[2], Float64Array); + + self.resize(total_num_groups); + Self::merge( + group_indices, + partial_counts, + partial_means, + partial_m2s, + opt_filter, + |group_index, partial_count, partial_mean, partial_m2| { + let (new_count, new_mean, new_m2) = merge( + self.counts[group_index], + self.means[group_index], + self.m2s[group_index], + partial_count, + partial_mean, + partial_m2, + ); + self.counts[group_index] = new_count; + self.means[group_index] = new_mean; + self.m2s[group_index] = new_m2; + }, + ); + Ok(()) + } + + fn evaluate(&mut self, emit_to: datafusion_expr::EmitTo) -> Result { + let (variances, nulls) = self.variance(emit_to); + Ok(Arc::new(Float64Array::new(variances.into(), Some(nulls)))) + } + + fn state(&mut self, emit_to: datafusion_expr::EmitTo) -> Result> { + let counts = emit_to.take_needed(&mut self.counts); + let means = emit_to.take_needed(&mut self.means); + let m2s = emit_to.take_needed(&mut self.m2s); + + Ok(vec![ + Arc::new(UInt64Array::new(counts.into(), None)), + Arc::new(Float64Array::new(means.into(), None)), + Arc::new(Float64Array::new(m2s.into(), None)), + ]) + } + + fn size(&self) -> usize { + self.m2s.capacity() * size_of::() + + self.means.capacity() * size_of::() + + self.counts.capacity() * size_of::() + } +} diff --git a/datafusion/functions-array/src/array_has.rs b/datafusion/functions-array/src/array_has.rs deleted file mode 100644 index ee064335c1cc..000000000000 --- a/datafusion/functions-array/src/array_has.rs +++ /dev/null @@ -1,324 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! [`ScalarUDFImpl`] definitions for array_has, array_has_all and array_has_any functions. - -use arrow::array::{Array, ArrayRef, BooleanArray, OffsetSizeTrait}; -use arrow::datatypes::DataType; -use arrow::row::{RowConverter, SortField}; -use datafusion_common::cast::as_generic_list_array; -use datafusion_common::{exec_err, Result}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::Expr; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; - -use itertools::Itertools; - -use crate::utils::check_datatypes; - -use std::any::Any; -use std::sync::Arc; - -// Create static instances of ScalarUDFs for each function -make_udf_function!(ArrayHas, - array_has, - first_array second_array, // arg name - "returns true, if the element appears in the first array, otherwise false.", // doc - array_has_udf // internal function name -); -make_udf_function!(ArrayHasAll, - array_has_all, - first_array second_array, // arg name - "returns true if each element of the second array appears in the first array; otherwise, it returns false.", // doc - array_has_all_udf // internal function name -); -make_udf_function!(ArrayHasAny, - array_has_any, - first_array second_array, // arg name - "returns true if at least one element of the second array appears in the first array; otherwise, it returns false.", // doc - array_has_any_udf // internal function name -); - -#[derive(Debug)] -pub struct ArrayHas { - signature: Signature, - aliases: Vec, -} - -impl Default for ArrayHas { - fn default() -> Self { - Self::new() - } -} - -impl ArrayHas { - pub fn new() -> Self { - Self { - signature: Signature::array_and_element(Volatility::Immutable), - aliases: vec![ - String::from("array_has"), - String::from("list_has"), - String::from("array_contains"), - String::from("list_contains"), - ], - } - } -} - -impl ScalarUDFImpl for ArrayHas { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_has" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Boolean) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - - if args.len() != 2 { - return exec_err!("array_has needs two arguments"); - } - - let array_type = args[0].data_type(); - - match array_type { - DataType::List(_) => general_array_has_dispatch::( - &args[0], - &args[1], - ComparisonType::Single, - ) - .map(ColumnarValue::Array), - DataType::LargeList(_) => general_array_has_dispatch::( - &args[0], - &args[1], - ComparisonType::Single, - ) - .map(ColumnarValue::Array), - _ => exec_err!("array_has does not support type '{array_type:?}'."), - } - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -#[derive(Debug)] -pub struct ArrayHasAll { - signature: Signature, - aliases: Vec, -} - -impl Default for ArrayHasAll { - fn default() -> Self { - Self::new() - } -} - -impl ArrayHasAll { - pub fn new() -> Self { - Self { - signature: Signature::any(2, Volatility::Immutable), - aliases: vec![String::from("array_has_all"), String::from("list_has_all")], - } - } -} - -impl ScalarUDFImpl for ArrayHasAll { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_has_all" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Boolean) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - if args.len() != 2 { - return exec_err!("array_has_all needs two arguments"); - } - - let array_type = args[0].data_type(); - - match array_type { - DataType::List(_) => { - general_array_has_dispatch::(&args[0], &args[1], ComparisonType::All) - .map(ColumnarValue::Array) - } - DataType::LargeList(_) => { - general_array_has_dispatch::(&args[0], &args[1], ComparisonType::All) - .map(ColumnarValue::Array) - } - _ => exec_err!("array_has_all does not support type '{array_type:?}'."), - } - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -#[derive(Debug)] -pub struct ArrayHasAny { - signature: Signature, - aliases: Vec, -} - -impl Default for ArrayHasAny { - fn default() -> Self { - Self::new() - } -} - -impl ArrayHasAny { - pub fn new() -> Self { - Self { - signature: Signature::any(2, Volatility::Immutable), - aliases: vec![String::from("array_has_any"), String::from("list_has_any")], - } - } -} - -impl ScalarUDFImpl for ArrayHasAny { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_has_any" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, _: &[DataType]) -> Result { - Ok(DataType::Boolean) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - - if args.len() != 2 { - return exec_err!("array_has_any needs two arguments"); - } - - let array_type = args[0].data_type(); - - match array_type { - DataType::List(_) => { - general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Any) - .map(ColumnarValue::Array) - } - DataType::LargeList(_) => { - general_array_has_dispatch::(&args[0], &args[1], ComparisonType::Any) - .map(ColumnarValue::Array) - } - _ => exec_err!("array_has_any does not support type '{array_type:?}'."), - } - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -/// Represents the type of comparison for array_has. -#[derive(Debug, PartialEq)] -enum ComparisonType { - // array_has_all - All, - // array_has_any - Any, - // array_has - Single, -} - -fn general_array_has_dispatch( - array: &ArrayRef, - sub_array: &ArrayRef, - comparison_type: ComparisonType, -) -> Result { - let array = if comparison_type == ComparisonType::Single { - let arr = as_generic_list_array::(array)?; - check_datatypes("array_has", &[arr.values(), sub_array])?; - arr - } else { - check_datatypes("array_has", &[array, sub_array])?; - as_generic_list_array::(array)? - }; - - let mut boolean_builder = BooleanArray::builder(array.len()); - - let converter = RowConverter::new(vec![SortField::new(array.value_type())])?; - - let element = sub_array.clone(); - let sub_array = if comparison_type != ComparisonType::Single { - as_generic_list_array::(sub_array)? - } else { - array - }; - - for (row_idx, (arr, sub_arr)) in array.iter().zip(sub_array.iter()).enumerate() { - if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { - let arr_values = converter.convert_columns(&[arr])?; - let sub_arr_values = if comparison_type != ComparisonType::Single { - converter.convert_columns(&[sub_arr])? - } else { - converter.convert_columns(&[element.clone()])? - }; - - let mut res = match comparison_type { - ComparisonType::All => sub_arr_values - .iter() - .dedup() - .all(|elem| arr_values.iter().dedup().any(|x| x == elem)), - ComparisonType::Any => sub_arr_values - .iter() - .dedup() - .any(|elem| arr_values.iter().dedup().any(|x| x == elem)), - ComparisonType::Single => arr_values - .iter() - .dedup() - .any(|x| x == sub_arr_values.row(row_idx)), - }; - - if comparison_type == ComparisonType::Any { - res |= res; - } - - boolean_builder.append_value(res); - } - } - Ok(Arc::new(boolean_builder.finish())) -} diff --git a/datafusion/functions-array/src/range.rs b/datafusion/functions-array/src/range.rs deleted file mode 100644 index 1c9e0c878e6e..000000000000 --- a/datafusion/functions-array/src/range.rs +++ /dev/null @@ -1,328 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! [`ScalarUDFImpl`] definitions for range and gen_series functions. - -use arrow::array::{Array, ArrayRef, Int64Array, ListArray}; -use arrow::datatypes::{DataType, Field}; -use arrow_buffer::{BooleanBufferBuilder, NullBuffer, OffsetBuffer}; -use std::any::Any; - -use crate::utils::make_scalar_function; -use arrow_array::types::{Date32Type, IntervalMonthDayNanoType}; -use arrow_array::Date32Array; -use arrow_schema::DataType::{Date32, Int64, Interval, List}; -use arrow_schema::IntervalUnit::MonthDayNano; -use datafusion_common::cast::{as_date32_array, as_int64_array, as_interval_mdn_array}; -use datafusion_common::{exec_err, not_impl_datafusion_err, Result}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::Expr; -use datafusion_expr::{ - ColumnarValue, ScalarUDFImpl, Signature, TypeSignature, Volatility, -}; -use std::sync::Arc; - -make_udf_function!( - Range, - range, - start stop step, - "create a list of values in the range between start and stop", - range_udf -); -#[derive(Debug)] -pub(super) struct Range { - signature: Signature, - aliases: Vec, -} -impl Range { - pub fn new() -> Self { - Self { - signature: Signature::one_of( - vec![ - TypeSignature::Exact(vec![Int64]), - TypeSignature::Exact(vec![Int64, Int64]), - TypeSignature::Exact(vec![Int64, Int64, Int64]), - TypeSignature::Exact(vec![Date32, Date32, Interval(MonthDayNano)]), - ], - Volatility::Immutable, - ), - aliases: vec![String::from("range")], - } - } -} -impl ScalarUDFImpl for Range { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "range" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(List(Arc::new(Field::new( - "item", - arg_types[0].clone(), - true, - )))) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - match args[0].data_type() { - Int64 => make_scalar_function(|args| gen_range_inner(args, false))(args), - Date32 => make_scalar_function(|args| gen_range_date(args, false))(args), - _ => { - exec_err!("unsupported type for range") - } - } - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - GenSeries, - gen_series, - start stop step, - "create a list of values in the range between start and stop, include upper bound", - gen_series_udf -); -#[derive(Debug)] -pub(super) struct GenSeries { - signature: Signature, - aliases: Vec, -} -impl GenSeries { - pub fn new() -> Self { - Self { - signature: Signature::one_of( - vec![ - TypeSignature::Exact(vec![Int64]), - TypeSignature::Exact(vec![Int64, Int64]), - TypeSignature::Exact(vec![Int64, Int64, Int64]), - TypeSignature::Exact(vec![Date32, Date32, Interval(MonthDayNano)]), - ], - Volatility::Immutable, - ), - aliases: vec![String::from("generate_series")], - } - } -} -impl ScalarUDFImpl for GenSeries { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "generate_series" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(List(Arc::new(Field::new( - "item", - arg_types[0].clone(), - true, - )))) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - match args[0].data_type() { - Int64 => make_scalar_function(|args| gen_range_inner(args, true))(args), - Date32 => make_scalar_function(|args| gen_range_date(args, true))(args), - _ => { - exec_err!("unsupported type for range") - } - } - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -/// Generates an array of integers from start to stop with a given step. -/// -/// This function takes 1 to 3 ArrayRefs as arguments, representing start, stop, and step values. -/// It returns a `Result` representing the resulting ListArray after the operation. -/// -/// # Arguments -/// -/// * `args` - An array of 1 to 3 ArrayRefs representing start, stop, and step(step value can not be zero.) values. -/// -/// # Examples -/// -/// gen_range(3) => [0, 1, 2] -/// gen_range(1, 4) => [1, 2, 3] -/// gen_range(1, 7, 2) => [1, 3, 5] -pub(super) fn gen_range_inner( - args: &[ArrayRef], - include_upper: bool, -) -> Result { - let (start_array, stop_array, step_array) = match args.len() { - 1 => (None, as_int64_array(&args[0])?, None), - 2 => ( - Some(as_int64_array(&args[0])?), - as_int64_array(&args[1])?, - None, - ), - 3 => ( - Some(as_int64_array(&args[0])?), - as_int64_array(&args[1])?, - Some(as_int64_array(&args[2])?), - ), - _ => return exec_err!("gen_range expects 1 to 3 arguments"), - }; - - let mut values = vec![]; - let mut offsets = vec![0]; - let mut valid = BooleanBufferBuilder::new(stop_array.len()); - for (idx, stop) in stop_array.iter().enumerate() { - match retrieve_range_args(start_array, stop, step_array, idx) { - Some((_, _, 0)) => { - return exec_err!( - "step can't be 0 for function {}(start [, stop, step])", - if include_upper { - "generate_series" - } else { - "range" - } - ); - } - Some((start, stop, step)) => { - // Below, we utilize `usize` to represent steps. - // On 32-bit targets, the absolute value of `i64` may fail to fit into `usize`. - let step_abs = usize::try_from(step.unsigned_abs()).map_err(|_| { - not_impl_datafusion_err!("step {} can't fit into usize", step) - })?; - values.extend( - gen_range_iter(start, stop, step < 0, include_upper) - .step_by(step_abs), - ); - offsets.push(values.len() as i32); - valid.append(true); - } - // If any of the arguments is NULL, append a NULL value to the result. - None => { - offsets.push(values.len() as i32); - valid.append(false); - } - }; - } - let arr = Arc::new(ListArray::try_new( - Arc::new(Field::new("item", Int64, true)), - OffsetBuffer::new(offsets.into()), - Arc::new(Int64Array::from(values)), - Some(NullBuffer::new(valid.finish())), - )?); - Ok(arr) -} - -/// Get the (start, stop, step) args for the range and generate_series function. -/// If any of the arguments is NULL, returns None. -fn retrieve_range_args( - start_array: Option<&Int64Array>, - stop: Option, - step_array: Option<&Int64Array>, - idx: usize, -) -> Option<(i64, i64, i64)> { - // Default start value is 0 if not provided - let start = - start_array.map_or(Some(0), |arr| arr.is_valid(idx).then(|| arr.value(idx)))?; - let stop = stop?; - // Default step value is 1 if not provided - let step = - step_array.map_or(Some(1), |arr| arr.is_valid(idx).then(|| arr.value(idx)))?; - Some((start, stop, step)) -} - -/// Returns an iterator of i64 values from start to stop -fn gen_range_iter( - start: i64, - stop: i64, - decreasing: bool, - include_upper: bool, -) -> Box> { - match (decreasing, include_upper) { - // Decreasing range, stop is inclusive - (true, true) => Box::new((stop..=start).rev()), - // Decreasing range, stop is exclusive - (true, false) => { - if stop == i64::MAX { - // start is never greater than stop, and stop is exclusive, - // so the decreasing range must be empty. - Box::new(std::iter::empty()) - } else { - // Increase the stop value by one to exclude it. - // Since stop is not i64::MAX, `stop + 1` will not overflow. - Box::new((stop + 1..=start).rev()) - } - } - // Increasing range, stop is inclusive - (false, true) => Box::new(start..=stop), - // Increasing range, stop is exclusive - (false, false) => Box::new(start..stop), - } -} - -fn gen_range_date(args: &[ArrayRef], include_upper: bool) -> Result { - if args.len() != 3 { - return exec_err!("arguments length does not match"); - } - let (start_array, stop_array, step_array) = ( - Some(as_date32_array(&args[0])?), - as_date32_array(&args[1])?, - Some(as_interval_mdn_array(&args[2])?), - ); - - let mut values = vec![]; - let mut offsets = vec![0]; - for (idx, stop) in stop_array.iter().enumerate() { - let mut stop = stop.unwrap_or(0); - let start = start_array.as_ref().map(|x| x.value(idx)).unwrap_or(0); - let step = step_array.as_ref().map(|arr| arr.value(idx)).unwrap_or(1); - let (months, days, _) = IntervalMonthDayNanoType::to_parts(step); - let neg = months < 0 || days < 0; - if !include_upper { - stop = Date32Type::subtract_month_day_nano(stop, step); - } - let mut new_date = start; - loop { - if neg && new_date < stop || !neg && new_date > stop { - break; - } - values.push(new_date); - new_date = Date32Type::add_month_day_nano(new_date, step); - } - offsets.push(values.len() as i32); - } - - let arr = Arc::new(ListArray::try_new( - Arc::new(Field::new("item", Date32, true)), - OffsetBuffer::new(offsets.into()), - Arc::new(Date32Array::from(values)), - None, - )?); - Ok(arr) -} diff --git a/datafusion/functions-array/src/rewrite.rs b/datafusion/functions-array/src/rewrite.rs deleted file mode 100644 index 32d15b5563a5..000000000000 --- a/datafusion/functions-array/src/rewrite.rs +++ /dev/null @@ -1,208 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Rewrites for using Array Functions - -use crate::array_has::array_has_all; -use crate::concat::{array_append, array_concat, array_prepend}; -use crate::extract::{array_element, array_slice}; -use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::Transformed; -use datafusion_common::utils::list_ndims; -use datafusion_common::Result; -use datafusion_common::{Column, DFSchema}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::expr_rewriter::FunctionRewrite; -use datafusion_expr::{BinaryExpr, Expr, GetFieldAccess, GetIndexedField, Operator}; -use datafusion_functions::expr_fn::get_field; - -/// Rewrites expressions into function calls to array functions -pub(crate) struct ArrayFunctionRewriter {} - -impl FunctionRewrite for ArrayFunctionRewriter { - fn name(&self) -> &str { - "FunctionRewrite" - } - - fn rewrite( - &self, - expr: Expr, - schema: &DFSchema, - _config: &ConfigOptions, - ) -> Result> { - let transformed = match expr { - // array1 @> array2 -> array_has_all(array1, array2) - Expr::BinaryExpr(BinaryExpr { left, op, right }) - if op == Operator::AtArrow - && is_func(&left, "make_array") - && is_func(&right, "make_array") => - { - Transformed::yes(array_has_all(*left, *right)) - } - - // array1 <@ array2 -> array_has_all(array2, array1) - Expr::BinaryExpr(BinaryExpr { left, op, right }) - if op == Operator::ArrowAt - && is_func(&left, "make_array") - && is_func(&right, "make_array") => - { - Transformed::yes(array_has_all(*right, *left)) - } - - // Column cases: - // 1) array_prepend/append/concat || column - Expr::BinaryExpr(BinaryExpr { left, op, right }) - if op == Operator::StringConcat - && is_one_of_func( - &left, - &["array_append", "array_prepend", "array_concat"], - ) - && as_col(&right).is_some() => - { - let c = as_col(&right).unwrap(); - let d = schema.field_from_column(c)?.data_type(); - let ndim = list_ndims(d); - match ndim { - 0 => Transformed::yes(array_append(*left, *right)), - _ => Transformed::yes(array_concat(vec![*left, *right])), - } - } - // 2) select column1 || column2 - Expr::BinaryExpr(BinaryExpr { left, op, right }) - if op == Operator::StringConcat - && as_col(&left).is_some() - && as_col(&right).is_some() => - { - let c1 = as_col(&left).unwrap(); - let c2 = as_col(&right).unwrap(); - let d1 = schema.field_from_column(c1)?.data_type(); - let d2 = schema.field_from_column(c2)?.data_type(); - let ndim1 = list_ndims(d1); - let ndim2 = list_ndims(d2); - match (ndim1, ndim2) { - (0, _) => Transformed::yes(array_prepend(*left, *right)), - (_, 0) => Transformed::yes(array_append(*left, *right)), - _ => Transformed::yes(array_concat(vec![*left, *right])), - } - } - - // Chain concat operator (a || b) || array, - // (array_concat, array_append, array_prepend) || array -> array concat - Expr::BinaryExpr(BinaryExpr { left, op, right }) - if op == Operator::StringConcat - && is_one_of_func( - &left, - &["array_append", "array_prepend", "array_concat"], - ) - && is_func(&right, "make_array") => - { - Transformed::yes(array_concat(vec![*left, *right])) - } - - // Chain concat operator (a || b) || scalar, - // (array_concat, array_append, array_prepend) || scalar -> array append - Expr::BinaryExpr(BinaryExpr { left, op, right }) - if op == Operator::StringConcat - && is_one_of_func( - &left, - &["array_append", "array_prepend", "array_concat"], - ) => - { - Transformed::yes(array_append(*left, *right)) - } - - // array || array -> array concat - Expr::BinaryExpr(BinaryExpr { left, op, right }) - if op == Operator::StringConcat - && is_func(&left, "make_array") - && is_func(&right, "make_array") => - { - Transformed::yes(array_concat(vec![*left, *right])) - } - - // array || scalar -> array append - Expr::BinaryExpr(BinaryExpr { left, op, right }) - if op == Operator::StringConcat && is_func(&left, "make_array") => - { - Transformed::yes(array_append(*left, *right)) - } - - // scalar || array -> array prepend - Expr::BinaryExpr(BinaryExpr { left, op, right }) - if op == Operator::StringConcat && is_func(&right, "make_array") => - { - Transformed::yes(array_prepend(*left, *right)) - } - - Expr::GetIndexedField(GetIndexedField { - expr, - field: GetFieldAccess::NamedStructField { name }, - }) => { - let name = Expr::Literal(name); - Transformed::yes(get_field(*expr, name)) - } - - // expr[idx] ==> array_element(expr, idx) - Expr::GetIndexedField(GetIndexedField { - expr, - field: GetFieldAccess::ListIndex { key }, - }) => Transformed::yes(array_element(*expr, *key)), - - // expr[start, stop, stride] ==> array_slice(expr, start, stop, stride) - Expr::GetIndexedField(GetIndexedField { - expr, - field: - GetFieldAccess::ListRange { - start, - stop, - stride, - }, - }) => Transformed::yes(array_slice(*expr, *start, *stop, *stride)), - - _ => Transformed::no(expr), - }; - Ok(transformed) - } -} - -/// Returns true if expr is a function call to the specified named function. -/// Returns false otherwise. -fn is_func(expr: &Expr, func_name: &str) -> bool { - let Expr::ScalarFunction(ScalarFunction { func_def, args: _ }) = expr else { - return false; - }; - - func_def.name() == func_name -} - -/// Returns true if expr is a function call with one of the specified names -fn is_one_of_func(expr: &Expr, func_names: &[&str]) -> bool { - let Expr::ScalarFunction(ScalarFunction { func_def, args: _ }) = expr else { - return false; - }; - - func_names.contains(&func_def.name()) -} - -/// returns Some(col) if this is Expr::Column -fn as_col(expr: &Expr) -> Option<&Column> { - if let Expr::Column(c) = expr { - Some(c) - } else { - None - } -} diff --git a/datafusion/functions-array/src/udf.rs b/datafusion/functions-array/src/udf.rs deleted file mode 100644 index 1462b3efad33..000000000000 --- a/datafusion/functions-array/src/udf.rs +++ /dev/null @@ -1,828 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! [`ScalarUDFImpl`] definitions for array functions. - -use arrow::array::{NullArray, StringArray}; -use arrow::datatypes::DataType; -use arrow::datatypes::Field; -use arrow::datatypes::IntervalUnit::MonthDayNano; -use arrow_schema::DataType::{LargeUtf8, List, Utf8}; -use datafusion_common::exec_err; -use datafusion_common::plan_err; -use datafusion_common::Result; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::Expr; -use datafusion_expr::TypeSignature; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; -use std::any::Any; -use std::sync::Arc; - -// Create static instances of ScalarUDFs for each function -make_udf_function!(ArrayToString, - array_to_string, - array delimiter, // arg name - "converts each element to its text representation.", // doc - array_to_string_udf // internal function name -); -#[derive(Debug)] -pub struct ArrayToString { - signature: Signature, - aliases: Vec, -} - -impl ArrayToString { - pub fn new() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), - aliases: vec![ - String::from("array_to_string"), - String::from("list_to_string"), - String::from("array_join"), - String::from("list_join"), - ], - } - } -} - -impl ScalarUDFImpl for ArrayToString { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_to_string" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => Utf8, - _ => { - return plan_err!("The array_to_string function can only accept List/LargeList/FixedSizeList."); - } - }) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_to_string(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!(StringToArray, - string_to_array, - string delimiter null_string, // arg name - "splits a `string` based on a `delimiter` and returns an array of parts. Any parts matching the optional `null_string` will be replaced with `NULL`", // doc - string_to_array_udf // internal function name -); -#[derive(Debug)] -pub struct StringToArray { - signature: Signature, - aliases: Vec, -} - -impl StringToArray { - pub fn new() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), - aliases: vec![ - String::from("string_to_array"), - String::from("string_to_list"), - ], - } - } -} - -impl ScalarUDFImpl for StringToArray { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "string_to_array" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match arg_types[0] { - Utf8 | LargeUtf8 => { - List(Arc::new(Field::new("item", arg_types[0].clone(), true))) - } - _ => { - return plan_err!( - "The string_to_array function can only accept Utf8 or LargeUtf8." - ); - } - }) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let mut args = ColumnarValue::values_to_arrays(args)?; - // Case: delimiter is NULL, needs to be handled as well. - if args[1].as_any().is::() { - args[1] = Arc::new(StringArray::new_null(args[1].len())); - }; - - match args[0].data_type() { - Utf8 => { - crate::kernels::string_to_array::(&args).map(ColumnarValue::Array) - } - LargeUtf8 => { - crate::kernels::string_to_array::(&args).map(ColumnarValue::Array) - } - other => { - exec_err!("unsupported type for string_to_array function as {other}") - } - } - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - Range, - range, - start stop step, - "create a list of values in the range between start and stop", - range_udf -); -#[derive(Debug)] -pub struct Range { - signature: Signature, - aliases: Vec, -} -impl Range { - pub fn new() -> Self { - use DataType::*; - Self { - signature: Signature::one_of( - vec![ - TypeSignature::Exact(vec![Int64]), - TypeSignature::Exact(vec![Int64, Int64]), - TypeSignature::Exact(vec![Int64, Int64, Int64]), - TypeSignature::Exact(vec![Date32, Date32, Interval(MonthDayNano)]), - ], - Volatility::Immutable, - ), - aliases: vec![String::from("range")], - } - } -} -impl ScalarUDFImpl for Range { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "range" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(List(Arc::new(Field::new( - "item", - arg_types[0].clone(), - true, - )))) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - match args[0].data_type() { - arrow::datatypes::DataType::Int64 => { - crate::kernels::gen_range(&args, false).map(ColumnarValue::Array) - } - arrow::datatypes::DataType::Date32 => { - crate::kernels::gen_range_date(&args, false).map(ColumnarValue::Array) - } - _ => { - exec_err!("unsupported type for range") - } - } - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - GenSeries, - gen_series, - start stop step, - "create a list of values in the range between start and stop, include upper bound", - gen_series_udf -); -#[derive(Debug)] -pub struct GenSeries { - signature: Signature, - aliases: Vec, -} -impl GenSeries { - pub fn new() -> Self { - use DataType::*; - Self { - signature: Signature::one_of( - vec![ - TypeSignature::Exact(vec![Int64]), - TypeSignature::Exact(vec![Int64, Int64]), - TypeSignature::Exact(vec![Int64, Int64, Int64]), - TypeSignature::Exact(vec![Date32, Date32, Interval(MonthDayNano)]), - ], - Volatility::Immutable, - ), - aliases: vec![String::from("generate_series")], - } - } -} -impl ScalarUDFImpl for GenSeries { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "generate_series" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(List(Arc::new(Field::new( - "item", - arg_types[0].clone(), - true, - )))) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - match args[0].data_type() { - arrow::datatypes::DataType::Int64 => { - crate::kernels::gen_range(&args, true).map(ColumnarValue::Array) - } - arrow::datatypes::DataType::Date32 => { - crate::kernels::gen_range_date(&args, true).map(ColumnarValue::Array) - } - _ => { - exec_err!("unsupported type for range") - } - } - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - ArrayDims, - array_dims, - array, - "returns an array of the array's dimensions.", - array_dims_udf -); - -#[derive(Debug)] -pub struct ArrayDims { - signature: Signature, - aliases: Vec, -} - -impl ArrayDims { - pub fn new() -> Self { - Self { - signature: Signature::array(Volatility::Immutable), - aliases: vec!["array_dims".to_string(), "list_dims".to_string()], - } - } -} - -impl ScalarUDFImpl for ArrayDims { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_dims" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => { - List(Arc::new(Field::new("item", UInt64, true))) - } - _ => { - return plan_err!("The array_dims function can only accept List/LargeList/FixedSizeList."); - } - }) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_dims(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - ArraySort, - array_sort, - array desc null_first, - "returns sorted array.", - array_sort_udf -); - -#[derive(Debug)] -pub struct ArraySort { - signature: Signature, - aliases: Vec, -} - -impl ArraySort { - pub fn new() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), - aliases: vec!["array_sort".to_string(), "list_sort".to_string()], - } - } -} - -impl ScalarUDFImpl for ArraySort { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_sort" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - match &arg_types[0] { - List(field) | FixedSizeList(field, _) => Ok(List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - LargeList(field) => Ok(LargeList(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => exec_err!( - "Not reachable, data_type should be List, LargeList or FixedSizeList" - ), - } - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_sort(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - Cardinality, - cardinality, - array, - "returns the total number of elements in the array.", - cardinality_udf -); - -impl Cardinality { - pub fn new() -> Self { - Self { - signature: Signature::array(Volatility::Immutable), - aliases: vec![String::from("cardinality")], - } - } -} - -#[derive(Debug)] -pub struct Cardinality { - signature: Signature, - aliases: Vec, -} -impl ScalarUDFImpl for Cardinality { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "cardinality" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64, - _ => { - return plan_err!("The cardinality function can only accept List/LargeList/FixedSizeList."); - } - }) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::cardinality(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - ArrayNdims, - array_ndims, - array, - "returns the number of dimensions of the array.", - array_ndims_udf -); - -#[derive(Debug)] -pub struct ArrayNdims { - signature: Signature, - aliases: Vec, -} -impl ArrayNdims { - pub fn new() -> Self { - Self { - signature: Signature::array(Volatility::Immutable), - aliases: vec![String::from("array_ndims"), String::from("list_ndims")], - } - } -} - -impl ScalarUDFImpl for ArrayNdims { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_ndims" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64, - _ => { - return plan_err!("The array_ndims function can only accept List/LargeList/FixedSizeList."); - } - }) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_ndims(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - ArrayEmpty, - array_empty, - array, - "returns true for an empty array or false for a non-empty array.", - array_empty_udf -); - -#[derive(Debug)] -pub struct ArrayEmpty { - signature: Signature, - aliases: Vec, -} -impl ArrayEmpty { - pub fn new() -> Self { - Self { - signature: Signature::array(Volatility::Immutable), - aliases: vec![String::from("empty")], - } - } -} - -impl ScalarUDFImpl for ArrayEmpty { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "empty" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => Boolean, - _ => { - return plan_err!("The array_empty function can only accept List/LargeList/FixedSizeList."); - } - }) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_empty(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - ArrayRepeat, - array_repeat, - element count, // arg name - "returns an array containing element `count` times.", // doc - array_repeat_udf // internal function name -); -#[derive(Debug)] -pub struct ArrayRepeat { - signature: Signature, - aliases: Vec, -} - -impl ArrayRepeat { - pub fn new() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), - aliases: vec![String::from("array_repeat"), String::from("list_repeat")], - } - } -} - -impl ScalarUDFImpl for ArrayRepeat { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_repeat" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(List(Arc::new(Field::new( - "item", - arg_types[0].clone(), - true, - )))) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_repeat(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - ArrayLength, - array_length, - array, - "returns the length of the array dimension.", - array_length_udf -); - -#[derive(Debug)] -pub struct ArrayLength { - signature: Signature, - aliases: Vec, -} -impl ArrayLength { - pub fn new() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), - aliases: vec![String::from("array_length"), String::from("list_length")], - } - } -} - -impl ScalarUDFImpl for ArrayLength { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_length" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64, - _ => { - return plan_err!("The array_length function can only accept List/LargeList/FixedSizeList."); - } - }) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_length(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - Flatten, - flatten, - array, - "flattens an array of arrays into a single array.", - flatten_udf -); - -#[derive(Debug)] -pub struct Flatten { - signature: Signature, - aliases: Vec, -} -impl Flatten { - pub fn new() -> Self { - Self { - signature: Signature::array(Volatility::Immutable), - aliases: vec![String::from("flatten")], - } - } -} - -impl ScalarUDFImpl for Flatten { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "flatten" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - fn get_base_type(data_type: &DataType) -> Result { - match data_type { - List(field) | FixedSizeList(field, _) - if matches!(field.data_type(), List(_) | FixedSizeList(_, _)) => - { - get_base_type(field.data_type()) - } - LargeList(field) if matches!(field.data_type(), LargeList(_)) => { - get_base_type(field.data_type()) - } - Null | List(_) | LargeList(_) => Ok(data_type.to_owned()), - FixedSizeList(field, _) => Ok(List(field.clone())), - _ => exec_err!( - "Not reachable, data_type should be List, LargeList or FixedSizeList" - ), - } - } - - let data_type = get_base_type(&arg_types[0])?; - Ok(data_type) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::flatten(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - ArrayDistinct, - array_distinct, - array, - "return distinct values from the array after removing duplicates.", - array_distinct_udf -); - -#[derive(Debug)] -pub struct ArrayDistinct { - signature: Signature, - aliases: Vec, -} - -impl crate::udf::ArrayDistinct { - pub fn new() -> Self { - Self { - signature: Signature::array(Volatility::Immutable), - aliases: vec!["array_distinct".to_string(), "list_distinct".to_string()], - } - } -} - -impl ScalarUDFImpl for crate::udf::ArrayDistinct { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_distinct" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - match &arg_types[0] { - List(field) | FixedSizeList(field, _) => Ok(List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - LargeList(field) => Ok(LargeList(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => exec_err!( - "Not reachable, data_type should be List, LargeList or FixedSizeList" - ), - } - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_distinct(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} diff --git a/datafusion/functions-array/Cargo.toml b/datafusion/functions-nested/Cargo.toml similarity index 82% rename from datafusion/functions-array/Cargo.toml rename to datafusion/functions-nested/Cargo.toml index eb1ef9e03f31..bdfb07031b8c 100644 --- a/datafusion/functions-array/Cargo.toml +++ b/datafusion/functions-nested/Cargo.toml @@ -16,8 +16,8 @@ # under the License. [package] -name = "datafusion-functions-array" -description = "Array Function packages for the DataFusion query engine" +name = "datafusion-functions-nested" +description = "Nested Type Function packages for the DataFusion query engine" keywords = ["datafusion", "logical", "plan", "expressions"] readme = "README.md" version = { workspace = true } @@ -34,7 +34,7 @@ workspace = true [features] [lib] -name = "datafusion_functions_array" +name = "datafusion_functions_nested" path = "src/lib.rs" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -49,9 +49,12 @@ datafusion-common = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-functions = { workspace = true } -itertools = { version = "0.12", features = ["use_std"] } +datafusion-functions-aggregate = { workspace = true } +datafusion-physical-expr-common = { workspace = true } +itertools = { workspace = true, features = ["use_std"] } log = { workspace = true } paste = "1.0.14" +rand = "0.8.5" [dev-dependencies] criterion = { version = "0.5", features = ["async_tokio"] } @@ -59,3 +62,7 @@ criterion = { version = "0.5", features = ["async_tokio"] } [[bench]] harness = false name = "array_expression" + +[[bench]] +harness = false +name = "map" diff --git a/datafusion/functions-array/README.md b/datafusion/functions-nested/README.md similarity index 87% rename from datafusion/functions-array/README.md rename to datafusion/functions-nested/README.md index 25deca8e1c77..8a5047c838ab 100644 --- a/datafusion/functions-array/README.md +++ b/datafusion/functions-nested/README.md @@ -17,11 +17,11 @@ under the License. --> -# DataFusion Array Function Library +# DataFusion Nested Type Function Library [DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. -This crate contains functions for working with arrays, such as `array_append` that work with +This crate contains functions for working with arrays, maps and structs, such as `array_append` that work with `ListArray`, `LargeListArray` and `FixedListArray` types from the `arrow` crate. [df]: https://crates.io/crates/datafusion diff --git a/datafusion/functions-array/benches/array_expression.rs b/datafusion/functions-nested/benches/array_expression.rs similarity index 95% rename from datafusion/functions-array/benches/array_expression.rs rename to datafusion/functions-nested/benches/array_expression.rs index 48b829793cef..0e3ecbc72641 100644 --- a/datafusion/functions-array/benches/array_expression.rs +++ b/datafusion/functions-nested/benches/array_expression.rs @@ -21,7 +21,7 @@ extern crate arrow; use crate::criterion::Criterion; use datafusion_expr::lit; -use datafusion_functions_array::expr_fn::{array_replace_all, make_array}; +use datafusion_functions_nested::expr_fn::{array_replace_all, make_array}; fn criterion_benchmark(c: &mut Criterion) { // Construct large arrays for benchmarking diff --git a/datafusion/functions-nested/benches/map.rs b/datafusion/functions-nested/benches/map.rs new file mode 100644 index 000000000000..3c4a09c65992 --- /dev/null +++ b/datafusion/functions-nested/benches/map.rs @@ -0,0 +1,109 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow_array::{Int32Array, ListArray, StringArray}; +use arrow_buffer::{OffsetBuffer, ScalarBuffer}; +use arrow_schema::{DataType, Field}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use rand::prelude::ThreadRng; +use rand::Rng; +use std::collections::HashSet; +use std::sync::Arc; + +use datafusion_common::ScalarValue; +use datafusion_expr::planner::ExprPlanner; +use datafusion_expr::{ColumnarValue, Expr}; +use datafusion_functions_nested::map::map_udf; +use datafusion_functions_nested::planner::NestedFunctionPlanner; + +fn keys(rng: &mut ThreadRng) -> Vec { + let mut keys = HashSet::with_capacity(1000); + + while keys.len() < 1000 { + keys.insert(rng.gen_range(0..10000).to_string()); + } + + keys.into_iter().collect() +} + +fn values(rng: &mut ThreadRng) -> Vec { + let mut values = HashSet::with_capacity(1000); + + while values.len() < 1000 { + values.insert(rng.gen_range(0..10000)); + } + values.into_iter().collect() +} + +fn criterion_benchmark(c: &mut Criterion) { + c.bench_function("make_map_1000", |b| { + let mut rng = rand::thread_rng(); + let keys = keys(&mut rng); + let values = values(&mut rng); + let mut buffer = Vec::new(); + for i in 0..1000 { + buffer.push(Expr::Literal(ScalarValue::Utf8(Some(keys[i].clone())))); + buffer.push(Expr::Literal(ScalarValue::Int32(Some(values[i])))); + } + + let planner = NestedFunctionPlanner {}; + + b.iter(|| { + black_box( + planner + .plan_make_map(buffer.clone()) + .expect("map should work on valid values"), + ); + }); + }); + + c.bench_function("map_1000", |b| { + let mut rng = rand::thread_rng(); + let field = Arc::new(Field::new("item", DataType::Utf8, true)); + let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 1000])); + let key_list = ListArray::new( + field, + offsets, + Arc::new(StringArray::from(keys(&mut rng))), + None, + ); + let field = Arc::new(Field::new("item", DataType::Int32, true)); + let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, 1000])); + let value_list = ListArray::new( + field, + offsets, + Arc::new(Int32Array::from(values(&mut rng))), + None, + ); + let keys = ColumnarValue::Scalar(ScalarValue::List(Arc::new(key_list))); + let values = ColumnarValue::Scalar(ScalarValue::List(Arc::new(value_list))); + + b.iter(|| { + black_box( + #[allow(deprecated)] // TODO use invoke_batch + map_udf() + .invoke(&[keys.clone(), values.clone()]) + .expect("map should work on valid values"), + ); + }); + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions-nested/src/array_has.rs b/datafusion/functions-nested/src/array_has.rs new file mode 100644 index 000000000000..fe1d05199e80 --- /dev/null +++ b/datafusion/functions-nested/src/array_has.rs @@ -0,0 +1,561 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_has, array_has_all and array_has_any functions. + +use arrow::array::{Array, ArrayRef, BooleanArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use arrow::row::{RowConverter, Rows, SortField}; +use arrow_array::{Datum, GenericListArray, Scalar}; +use arrow_buffer::BooleanBuffer; +use datafusion_common::cast::as_generic_list_array; +use datafusion_common::utils::string_utils::string_array_to_vec; +use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_physical_expr_common::datum::compare_with_eq; +use itertools::Itertools; + +use crate::utils::make_scalar_function; + +use std::any::Any; +use std::sync::{Arc, OnceLock}; + +// Create static instances of ScalarUDFs for each function +make_udf_expr_and_func!(ArrayHas, + array_has, + haystack_array element, // arg names + "returns true, if the element appears in the first array, otherwise false.", // doc + array_has_udf // internal function name +); +make_udf_expr_and_func!(ArrayHasAll, + array_has_all, + haystack_array needle_array, // arg names + "returns true if each element of the second array appears in the first array; otherwise, it returns false.", // doc + array_has_all_udf // internal function name +); +make_udf_expr_and_func!(ArrayHasAny, + array_has_any, + haystack_array needle_array, // arg names + "returns true if at least one element of the second array appears in the first array; otherwise, it returns false.", // doc + array_has_any_udf // internal function name +); + +#[derive(Debug)] +pub struct ArrayHas { + signature: Signature, + aliases: Vec, +} + +impl Default for ArrayHas { + fn default() -> Self { + Self::new() + } +} + +impl ArrayHas { + pub fn new() -> Self { + Self { + signature: Signature::array_and_element(Volatility::Immutable), + aliases: vec![ + String::from("list_has"), + String::from("array_contains"), + String::from("list_contains"), + ], + } + } +} + +impl ScalarUDFImpl for ArrayHas { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "array_has" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _: &[DataType]) -> Result { + Ok(DataType::Boolean) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match &args[1] { + ColumnarValue::Array(array_needle) => { + // the needle is already an array, convert the haystack to an array of the same length + let haystack = args[0].to_owned().into_array(array_needle.len())?; + let array = array_has_inner_for_array(&haystack, array_needle)?; + Ok(ColumnarValue::Array(array)) + } + ColumnarValue::Scalar(scalar_needle) => { + // Always return null if the second argument is null + // i.e. array_has(array, null) -> null + if scalar_needle.is_null() { + return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))); + } + + // since the needle is a scalar, convert it to an array of size 1 + let haystack = args[0].to_owned().into_array(1)?; + let needle = scalar_needle.to_array_of_size(1)?; + let needle = Scalar::new(needle); + let array = array_has_inner_for_scalar(&haystack, &needle)?; + if let ColumnarValue::Scalar(_) = &args[0] { + // If both inputs are scalar, keeps output as scalar + let scalar_value = ScalarValue::try_from_array(&array, 0)?; + Ok(ColumnarValue::Scalar(scalar_value)) + } else { + Ok(ColumnarValue::Array(array)) + } + } + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_has_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_array_has_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns true if the array contains the element.", + ) + .with_syntax_example("array_has(array, element)") + .with_sql_example( + r#"```sql +> select array_has([1, 2, 3], 2); ++-----------------------------+ +| array_has(List([1,2,3]), 2) | ++-----------------------------+ +| true | ++-----------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "element", + "Scalar or Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .build() + .unwrap() + }) +} + +fn array_has_inner_for_scalar( + haystack: &ArrayRef, + needle: &dyn Datum, +) -> Result { + match haystack.data_type() { + DataType::List(_) => array_has_dispatch_for_scalar::(haystack, needle), + DataType::LargeList(_) => array_has_dispatch_for_scalar::(haystack, needle), + _ => exec_err!( + "array_has does not support type '{:?}'.", + haystack.data_type() + ), + } +} + +fn array_has_inner_for_array(haystack: &ArrayRef, needle: &ArrayRef) -> Result { + match haystack.data_type() { + DataType::List(_) => array_has_dispatch_for_array::(haystack, needle), + DataType::LargeList(_) => array_has_dispatch_for_array::(haystack, needle), + _ => exec_err!( + "array_has does not support type '{:?}'.", + haystack.data_type() + ), + } +} + +fn array_has_dispatch_for_array( + haystack: &ArrayRef, + needle: &ArrayRef, +) -> Result { + let haystack = as_generic_list_array::(haystack)?; + let mut boolean_builder = BooleanArray::builder(haystack.len()); + + for (i, arr) in haystack.iter().enumerate() { + if arr.is_none() || needle.is_null(i) { + boolean_builder.append_null(); + continue; + } + let arr = arr.unwrap(); + let is_nested = arr.data_type().is_nested(); + let needle_row = Scalar::new(needle.slice(i, 1)); + let eq_array = compare_with_eq(&arr, &needle_row, is_nested)?; + let is_contained = eq_array.true_count() > 0; + boolean_builder.append_value(is_contained) + } + + Ok(Arc::new(boolean_builder.finish())) +} + +fn array_has_dispatch_for_scalar( + haystack: &ArrayRef, + needle: &dyn Datum, +) -> Result { + let haystack = as_generic_list_array::(haystack)?; + let values = haystack.values(); + let is_nested = values.data_type().is_nested(); + let offsets = haystack.value_offsets(); + // If first argument is empty list (second argument is non-null), return false + // i.e. array_has([], non-null element) -> false + if values.len() == 0 { + return Ok(Arc::new(BooleanArray::new( + BooleanBuffer::new_unset(haystack.len()), + None, + ))); + } + let eq_array = compare_with_eq(values, needle, is_nested)?; + let mut final_contained = vec![None; haystack.len()]; + for (i, offset) in offsets.windows(2).enumerate() { + let start = offset[0].to_usize().unwrap(); + let end = offset[1].to_usize().unwrap(); + let length = end - start; + // For non-nested list, length is 0 for null + if length == 0 { + continue; + } + let sliced_array = eq_array.slice(start, length); + // For nested list, check number of nulls + if sliced_array.null_count() != length { + final_contained[i] = Some(sliced_array.true_count() > 0); + } + } + + Ok(Arc::new(BooleanArray::from(final_contained))) +} + +fn array_has_all_inner(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::List(_) => { + array_has_all_and_any_dispatch::(&args[0], &args[1], ComparisonType::All) + } + DataType::LargeList(_) => { + array_has_all_and_any_dispatch::(&args[0], &args[1], ComparisonType::All) + } + _ => exec_err!( + "array_has does not support type '{:?}'.", + args[0].data_type() + ), + } +} + +fn array_has_any_inner(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::List(_) => { + array_has_all_and_any_dispatch::(&args[0], &args[1], ComparisonType::Any) + } + DataType::LargeList(_) => { + array_has_all_and_any_dispatch::(&args[0], &args[1], ComparisonType::Any) + } + _ => exec_err!( + "array_has does not support type '{:?}'.", + args[0].data_type() + ), + } +} + +#[derive(Debug)] +pub struct ArrayHasAll { + signature: Signature, + aliases: Vec, +} + +impl Default for ArrayHasAll { + fn default() -> Self { + Self::new() + } +} + +impl ArrayHasAll { + pub fn new() -> Self { + Self { + signature: Signature::any(2, Volatility::Immutable), + aliases: vec![String::from("list_has_all")], + } + } +} + +impl ScalarUDFImpl for ArrayHasAll { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "array_has_all" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _: &[DataType]) -> Result { + Ok(DataType::Boolean) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_has_all_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_has_all_doc()) + } +} + +fn get_array_has_all_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns true if all elements of sub-array exist in array.", + ) + .with_syntax_example("array_has_all(array, sub-array)") + .with_sql_example( + r#"```sql +> select array_has_all([1, 2, 3, 4], [2, 3]); ++--------------------------------------------+ +| array_has_all(List([1,2,3,4]), List([2,3])) | ++--------------------------------------------+ +| true | ++--------------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "sub-array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .build() + .unwrap() + }) +} + +#[derive(Debug)] +pub struct ArrayHasAny { + signature: Signature, + aliases: Vec, +} + +impl Default for ArrayHasAny { + fn default() -> Self { + Self::new() + } +} + +impl ArrayHasAny { + pub fn new() -> Self { + Self { + signature: Signature::any(2, Volatility::Immutable), + aliases: vec![String::from("list_has_any")], + } + } +} + +impl ScalarUDFImpl for ArrayHasAny { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "array_has_any" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _: &[DataType]) -> Result { + Ok(DataType::Boolean) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_has_any_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_has_any_doc()) + } +} + +fn get_array_has_any_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns true if any elements exist in both arrays.", + ) + .with_syntax_example("array_has_any(array, sub-array)") + .with_sql_example( + r#"```sql +> select array_has_any([1, 2, 3], [3, 4]); ++------------------------------------------+ +| array_has_any(List([1,2,3]), List([3,4])) | ++------------------------------------------+ +| true | ++------------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "sub-array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .build() + .unwrap() + }) +} + +/// Represents the type of comparison for array_has. +#[derive(Debug, PartialEq, Clone, Copy)] +enum ComparisonType { + // array_has_all + All, + // array_has_any + Any, +} + +fn array_has_all_and_any_dispatch( + haystack: &ArrayRef, + needle: &ArrayRef, + comparison_type: ComparisonType, +) -> Result { + let haystack = as_generic_list_array::(haystack)?; + let needle = as_generic_list_array::(needle)?; + match needle.data_type() { + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => { + array_has_all_and_any_string_internal::(haystack, needle, comparison_type) + } + _ => general_array_has_for_all_and_any::(haystack, needle, comparison_type), + } +} + +// String comparison for array_has_all and array_has_any +fn array_has_all_and_any_string_internal( + array: &GenericListArray, + needle: &GenericListArray, + comparison_type: ComparisonType, +) -> Result { + let mut boolean_builder = BooleanArray::builder(array.len()); + for (arr, sub_arr) in array.iter().zip(needle.iter()) { + match (arr, sub_arr) { + (Some(arr), Some(sub_arr)) => { + let haystack_array = string_array_to_vec(&arr); + let needle_array = string_array_to_vec(&sub_arr); + boolean_builder.append_value(array_has_string_kernel( + haystack_array, + needle_array, + comparison_type, + )); + } + (_, _) => { + boolean_builder.append_null(); + } + } + } + + Ok(Arc::new(boolean_builder.finish())) +} + +fn array_has_string_kernel( + haystack: Vec>, + needle: Vec>, + comparison_type: ComparisonType, +) -> bool { + match comparison_type { + ComparisonType::All => needle + .iter() + .dedup() + .all(|x| haystack.iter().dedup().any(|y| y == x)), + ComparisonType::Any => needle + .iter() + .dedup() + .any(|x| haystack.iter().dedup().any(|y| y == x)), + } +} + +// General row comparison for array_has_all and array_has_any +fn general_array_has_for_all_and_any( + haystack: &GenericListArray, + needle: &GenericListArray, + comparison_type: ComparisonType, +) -> Result { + let mut boolean_builder = BooleanArray::builder(haystack.len()); + let converter = RowConverter::new(vec![SortField::new(haystack.value_type())])?; + + for (arr, sub_arr) in haystack.iter().zip(needle.iter()) { + if let (Some(arr), Some(sub_arr)) = (arr, sub_arr) { + let arr_values = converter.convert_columns(&[arr])?; + let sub_arr_values = converter.convert_columns(&[sub_arr])?; + boolean_builder.append_value(general_array_has_all_and_any_kernel( + arr_values, + sub_arr_values, + comparison_type, + )); + } else { + boolean_builder.append_null(); + } + } + + Ok(Arc::new(boolean_builder.finish())) +} + +fn general_array_has_all_and_any_kernel( + haystack_rows: Rows, + needle_rows: Rows, + comparison_type: ComparisonType, +) -> bool { + match comparison_type { + ComparisonType::All => needle_rows.iter().all(|needle_row| { + haystack_rows + .iter() + .any(|haystack_row| haystack_row == needle_row) + }), + ComparisonType::Any => needle_rows.iter().any(|needle_row| { + haystack_rows + .iter() + .any(|haystack_row| haystack_row == needle_row) + }), + } +} diff --git a/datafusion/functions-array/src/cardinality.rs b/datafusion/functions-nested/src/cardinality.rs similarity index 55% rename from datafusion/functions-array/src/cardinality.rs rename to datafusion/functions-nested/src/cardinality.rs index ed9f8d01f973..b6661e0807f4 100644 --- a/datafusion/functions-array/src/cardinality.rs +++ b/datafusion/functions-nested/src/cardinality.rs @@ -18,30 +18,41 @@ //! [`ScalarUDFImpl`] definitions for cardinality function. use crate::utils::make_scalar_function; -use arrow_array::{ArrayRef, GenericListArray, OffsetSizeTrait, UInt64Array}; +use arrow_array::{ + Array, ArrayRef, GenericListArray, MapArray, OffsetSizeTrait, UInt64Array, +}; use arrow_schema::DataType; -use arrow_schema::DataType::{FixedSizeList, LargeList, List, UInt64}; -use datafusion_common::cast::{as_large_list_array, as_list_array}; +use arrow_schema::DataType::{FixedSizeList, LargeList, List, Map, UInt64}; +use datafusion_common::cast::{as_large_list_array, as_list_array, as_map_array}; use datafusion_common::Result; use datafusion_common::{exec_err, plan_err}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; +use datafusion_expr::{ + ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, + TypeSignature, Volatility, +}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; -make_udf_function!( +make_udf_expr_and_func!( Cardinality, cardinality, array, - "returns the total number of elements in the array.", + "returns the total number of elements in the array or map.", cardinality_udf ); impl Cardinality { pub fn new() -> Self { Self { - signature: Signature::array(Volatility::Immutable), - aliases: vec![String::from("cardinality")], + signature: Signature::one_of( + vec![ + TypeSignature::ArraySignature(ArrayFunctionSignature::Array), + TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray), + ], + Volatility::Immutable, + ), + aliases: vec![], } } } @@ -65,9 +76,9 @@ impl ScalarUDFImpl for Cardinality { fn return_type(&self, arg_types: &[DataType]) -> Result { Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64, + List(_) | LargeList(_) | FixedSizeList(_, _) | Map(_, _) => UInt64, _ => { - return plan_err!("The cardinality function can only accept List/LargeList/FixedSizeList."); + return plan_err!("The cardinality function can only accept List/LargeList/FixedSizeList/Map."); } }) } @@ -79,6 +90,39 @@ impl ScalarUDFImpl for Cardinality { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_cardinality_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_cardinality_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns the total number of elements in the array.", + ) + .with_syntax_example("cardinality(array)") + .with_sql_example( + r#"```sql +> select cardinality([[1, 2, 3, 4], [5, 6, 7, 8]]); ++--------------------------------------+ +| cardinality(List([1,2,3,4,5,6,7,8])) | ++--------------------------------------+ +| 8 | ++--------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .build() + .unwrap() + }) } /// Cardinality SQL function @@ -96,12 +140,24 @@ pub fn cardinality_inner(args: &[ArrayRef]) -> Result { let list_array = as_large_list_array(&args[0])?; generic_list_cardinality::(list_array) } + Map(_, _) => { + let map_array = as_map_array(&args[0])?; + generic_map_cardinality(map_array) + } other => { exec_err!("cardinality does not support type '{:?}'", other) } } } +fn generic_map_cardinality(array: &MapArray) -> Result { + let result: UInt64Array = array + .iter() + .map(|opt_arr| opt_arr.map(|arr| arr.len() as u64)) + .collect(); + Ok(Arc::new(result)) +} + fn generic_list_cardinality( array: &GenericListArray, ) -> Result { diff --git a/datafusion/functions-array/src/concat.rs b/datafusion/functions-nested/src/concat.rs similarity index 76% rename from datafusion/functions-array/src/concat.rs rename to datafusion/functions-nested/src/concat.rs index f9d9bf4356ff..1bdcf74aee2a 100644 --- a/datafusion/functions-array/src/concat.rs +++ b/datafusion/functions-nested/src/concat.rs @@ -17,7 +17,8 @@ //! [`ScalarUDFImpl`] definitions for `array_append`, `array_prepend` and `array_concat` functions. -use std::{any::Any, cmp::Ordering, sync::Arc}; +use std::sync::{Arc, OnceLock}; +use std::{any::Any, cmp::Ordering}; use arrow::array::{Capacities, MutableArrayData}; use arrow_array::{Array, ArrayRef, GenericListArray, OffsetSizeTrait}; @@ -27,16 +28,15 @@ use datafusion_common::Result; use datafusion_common::{ cast::as_generic_list_array, exec_err, not_impl_err, plan_err, utils::list_ndims, }; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::Expr; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; use datafusion_expr::{ - type_coercion::binary::get_wider_type, ColumnarValue, ScalarUDFImpl, Signature, - Volatility, + type_coercion::binary::get_wider_type, ColumnarValue, Documentation, ScalarUDFImpl, + Signature, Volatility, }; use crate::utils::{align_array_dimensions, check_datatypes, make_scalar_function}; -make_udf_function!( +make_udf_expr_and_func!( ArrayAppend, array_append, array element, // arg name @@ -61,7 +61,6 @@ impl ArrayAppend { Self { signature: Signature::array_and_element(Volatility::Immutable), aliases: vec![ - String::from("array_append"), String::from("list_append"), String::from("array_push_back"), String::from("list_push_back"), @@ -94,9 +93,46 @@ impl ScalarUDFImpl for ArrayAppend { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_append_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_array_append_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Appends an element to the end of an array.", + ) + .with_syntax_example("array_append(array, element)") + .with_sql_example( + r#"```sql +> select array_append([1, 2, 3], 4); ++--------------------------------------+ +| array_append(List([1,2,3]),Int64(4)) | ++--------------------------------------+ +| [1, 2, 3, 4] | ++--------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "element", + "Element to append to the array.", + ) + .build() + .unwrap() + }) } -make_udf_function!( +make_udf_expr_and_func!( ArrayPrepend, array_prepend, element array, @@ -121,7 +157,6 @@ impl ArrayPrepend { Self { signature: Signature::element_and_array(Volatility::Immutable), aliases: vec![ - String::from("array_prepend"), String::from("list_prepend"), String::from("array_push_front"), String::from("list_push_front"), @@ -154,9 +189,44 @@ impl ScalarUDFImpl for ArrayPrepend { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_prepend_doc()) + } } -make_udf_function!( +fn get_array_prepend_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Prepends an element to the beginning of an array.", + ) + .with_syntax_example("array_prepend(element, array)") + .with_sql_example( + r#"```sql +> select array_prepend(1, [2, 3, 4]); ++---------------------------------------+ +| array_prepend(Int64(1),List([2,3,4])) | ++---------------------------------------+ +| [1, 2, 3, 4] | ++---------------------------------------+ +```"#, + ) + .with_argument( + "element", + "Element to prepend to the array.", + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .build() + .unwrap() + }) +} + +make_udf_expr_and_func!( ArrayConcat, array_concat, "Concatenates arrays.", @@ -180,7 +250,6 @@ impl ArrayConcat { Self { signature: Signature::variadic_any(Volatility::Immutable), aliases: vec![ - String::from("array_concat"), String::from("array_cat"), String::from("list_concat"), String::from("list_cat"), @@ -238,6 +307,41 @@ impl ScalarUDFImpl for ArrayConcat { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_concat_doc()) + } +} + +fn get_array_concat_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Concatenates arrays.", + ) + .with_syntax_example("array_concat(array[, ..., array_n])") + .with_sql_example( + r#"```sql +> select array_concat([1, 2], [3, 4], [5, 6]); ++---------------------------------------------------+ +| array_concat(List([1,2]),List([3,4]),List([5,6])) | ++---------------------------------------------------+ +| [1, 2, 3, 4, 5, 6] | ++---------------------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression to concatenate. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "array_n", + "Subsequent array column or literal array to concatenate.", + ) + .build() + .unwrap() + }) } /// Array_concat/Array_cat SQL function @@ -254,7 +358,7 @@ pub(crate) fn array_concat_inner(args: &[ArrayRef]) -> Result { return not_impl_err!("Array is not type '{base_type:?}'."); } if !base_type.eq(&DataType::Null) { - new_args.push(arg.clone()); + new_args.push(Arc::clone(arg)); } } diff --git a/datafusion/functions-array/src/dimension.rs b/datafusion/functions-nested/src/dimension.rs similarity index 69% rename from datafusion/functions-array/src/dimension.rs rename to datafusion/functions-nested/src/dimension.rs index 569eff66f7f4..7df0ed2b40bd 100644 --- a/datafusion/functions-array/src/dimension.rs +++ b/datafusion/functions-nested/src/dimension.rs @@ -29,11 +29,13 @@ use datafusion_common::{exec_err, plan_err, Result}; use crate::utils::{compute_array_dims, make_scalar_function}; use arrow_schema::DataType::{FixedSizeList, LargeList, List, UInt64}; use arrow_schema::Field; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; -use std::sync::Arc; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use std::sync::{Arc, OnceLock}; -make_udf_function!( +make_udf_expr_and_func!( ArrayDims, array_dims, array, @@ -51,7 +53,7 @@ impl ArrayDims { pub fn new() -> Self { Self { signature: Signature::array(Volatility::Immutable), - aliases: vec!["array_dims".to_string(), "list_dims".to_string()], + aliases: vec!["list_dims".to_string()], } } } @@ -86,9 +88,42 @@ impl ScalarUDFImpl for ArrayDims { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_dims_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_array_dims_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns an array of the array's dimensions.", + ) + .with_syntax_example("array_dims(array)") + .with_sql_example( + r#"```sql +> select array_dims([[1, 2, 3], [4, 5, 6]]); ++---------------------------------+ +| array_dims(List([1,2,3,4,5,6])) | ++---------------------------------+ +| [2, 3] | ++---------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .build() + .unwrap() + }) } -make_udf_function!( +make_udf_expr_and_func!( ArrayNdims, array_ndims, array, @@ -105,7 +140,7 @@ impl ArrayNdims { pub fn new() -> Self { Self { signature: Signature::array(Volatility::Immutable), - aliases: vec![String::from("array_ndims"), String::from("list_ndims")], + aliases: vec![String::from("list_ndims")], } } } @@ -138,6 +173,41 @@ impl ScalarUDFImpl for ArrayNdims { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_ndims_doc()) + } +} + +fn get_array_ndims_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns the number of dimensions of the array.", + ) + .with_syntax_example("array_ndims(array, element)") + .with_sql_example( + r#"```sql +> select array_ndims([[1, 2, 3], [4, 5, 6]]); ++----------------------------------+ +| array_ndims(List([1,2,3,4,5,6])) | ++----------------------------------+ +| 2 | ++----------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "element", + "Array element.", + ) + .build() + .unwrap() + }) } /// Array_dims SQL function diff --git a/datafusion/functions-nested/src/distance.rs b/datafusion/functions-nested/src/distance.rs new file mode 100644 index 000000000000..4f890e4166e9 --- /dev/null +++ b/datafusion/functions-nested/src/distance.rs @@ -0,0 +1,271 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [ScalarUDFImpl] definitions for array_distance function. + +use crate::utils::{downcast_arg, make_scalar_function}; +use arrow_array::{ + Array, ArrayRef, Float64Array, LargeListArray, ListArray, OffsetSizeTrait, +}; +use arrow_schema::DataType; +use arrow_schema::DataType::{FixedSizeList, Float64, LargeList, List}; +use core::any::type_name; +use datafusion_common::cast::{ + as_float32_array, as_float64_array, as_generic_list_array, as_int32_array, + as_int64_array, +}; +use datafusion_common::utils::coerced_fixed_size_list_to_list; +use datafusion_common::DataFusionError; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use std::any::Any; +use std::sync::{Arc, OnceLock}; + +make_udf_expr_and_func!( + ArrayDistance, + array_distance, + array, + "returns the Euclidean distance between two numeric arrays.", + array_distance_udf +); + +#[derive(Debug)] +pub(super) struct ArrayDistance { + signature: Signature, + aliases: Vec, +} + +impl ArrayDistance { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec!["list_distance".to_string()], + } + } +} + +impl ScalarUDFImpl for ArrayDistance { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "array_distance" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match arg_types[0] { + List(_) | LargeList(_) | FixedSizeList(_, _) => Ok(Float64), + _ => exec_err!("The array_distance function can only accept List/LargeList/FixedSizeList."), + } + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 2 { + return exec_err!("array_distance expects exactly two arguments"); + } + let mut result = Vec::new(); + for arg_type in arg_types { + match arg_type { + List(_) | LargeList(_) | FixedSizeList(_, _) => result.push(coerced_fixed_size_list_to_list(arg_type)), + _ => return exec_err!("The array_distance function can only accept List/LargeList/FixedSizeList."), + } + } + + Ok(result) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_distance_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_distance_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_array_distance_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns the Euclidean distance between two input arrays of equal length.", + ) + .with_syntax_example("array_distance(array1, array2)") + .with_sql_example( + r#"```sql +> select array_distance([1, 2], [1, 4]); ++------------------------------------+ +| array_distance(List([1,2], [1,4])) | ++------------------------------------+ +| 2.0 | ++------------------------------------+ +```"#, + ) + .with_argument( + "array1", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "array2", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .build() + .unwrap() + }) +} + +pub fn array_distance_inner(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("array_distance expects exactly two arguments"); + } + + match (&args[0].data_type(), &args[1].data_type()) { + (List(_), List(_)) => general_array_distance::(args), + (LargeList(_), LargeList(_)) => general_array_distance::(args), + (array_type1, array_type2) => { + exec_err!("array_distance does not support types '{array_type1:?}' and '{array_type2:?}'") + } + } +} + +fn general_array_distance(arrays: &[ArrayRef]) -> Result { + let list_array1 = as_generic_list_array::(&arrays[0])?; + let list_array2 = as_generic_list_array::(&arrays[1])?; + + let result = list_array1 + .iter() + .zip(list_array2.iter()) + .map(|(arr1, arr2)| compute_array_distance(arr1, arr2)) + .collect::>()?; + + Ok(Arc::new(result) as ArrayRef) +} + +/// Computes the Euclidean distance between two arrays +fn compute_array_distance( + arr1: Option, + arr2: Option, +) -> Result> { + let value1 = match arr1 { + Some(arr) => arr, + None => return Ok(None), + }; + let value2 = match arr2 { + Some(arr) => arr, + None => return Ok(None), + }; + + let mut value1 = value1; + let mut value2 = value2; + + loop { + match value1.data_type() { + List(_) => { + if downcast_arg!(value1, ListArray).null_count() > 0 { + return Ok(None); + } + value1 = downcast_arg!(value1, ListArray).value(0); + } + LargeList(_) => { + if downcast_arg!(value1, LargeListArray).null_count() > 0 { + return Ok(None); + } + value1 = downcast_arg!(value1, LargeListArray).value(0); + } + _ => break, + } + + match value2.data_type() { + List(_) => { + if downcast_arg!(value2, ListArray).null_count() > 0 { + return Ok(None); + } + value2 = downcast_arg!(value2, ListArray).value(0); + } + LargeList(_) => { + if downcast_arg!(value2, LargeListArray).null_count() > 0 { + return Ok(None); + } + value2 = downcast_arg!(value2, LargeListArray).value(0); + } + _ => break, + } + } + + // Check for NULL values inside the arrays + if value1.null_count() != 0 || value2.null_count() != 0 { + return Ok(None); + } + + let values1 = convert_to_f64_array(&value1)?; + let values2 = convert_to_f64_array(&value2)?; + + if values1.len() != values2.len() { + return exec_err!("Both arrays must have the same length"); + } + + let sum_squares: f64 = values1 + .iter() + .zip(values2.iter()) + .map(|(v1, v2)| { + let diff = v1.unwrap_or(0.0) - v2.unwrap_or(0.0); + diff * diff + }) + .sum(); + + Ok(Some(sum_squares.sqrt())) +} + +/// Converts an array of any numeric type to a Float64Array. +fn convert_to_f64_array(array: &ArrayRef) -> Result { + match array.data_type() { + Float64 => Ok(as_float64_array(array)?.clone()), + DataType::Float32 => { + let array = as_float32_array(array)?; + let converted: Float64Array = + array.iter().map(|v| v.map(|v| v as f64)).collect(); + Ok(converted) + } + DataType::Int64 => { + let array = as_int64_array(array)?; + let converted: Float64Array = + array.iter().map(|v| v.map(|v| v as f64)).collect(); + Ok(converted) + } + DataType::Int32 => { + let array = as_int32_array(array)?; + let converted: Float64Array = + array.iter().map(|v| v.map(|v| v as f64)).collect(); + Ok(converted) + } + _ => exec_err!("Unsupported array type for conversion to Float64Array"), + } +} diff --git a/datafusion/functions-array/src/empty.rs b/datafusion/functions-nested/src/empty.rs similarity index 69% rename from datafusion/functions-array/src/empty.rs rename to datafusion/functions-nested/src/empty.rs index d5fa174eee5f..5d310eb23952 100644 --- a/datafusion/functions-array/src/empty.rs +++ b/datafusion/functions-nested/src/empty.rs @@ -21,14 +21,16 @@ use crate::utils::make_scalar_function; use arrow_array::{ArrayRef, BooleanArray, OffsetSizeTrait}; use arrow_schema::DataType; use arrow_schema::DataType::{Boolean, FixedSizeList, LargeList, List}; -use datafusion_common::cast::{as_generic_list_array, as_null_array}; +use datafusion_common::cast::as_generic_list_array; use datafusion_common::{exec_err, plan_err, Result}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; -make_udf_function!( +make_udf_expr_and_func!( ArrayEmpty, array_empty, array, @@ -45,11 +47,7 @@ impl ArrayEmpty { pub fn new() -> Self { Self { signature: Signature::array(Volatility::Immutable), - aliases: vec![ - "empty".to_string(), - "array_empty".to_string(), - "list_empty".to_string(), - ], + aliases: vec!["array_empty".to_string(), "list_empty".to_string()], } } } @@ -82,6 +80,39 @@ impl ScalarUDFImpl for ArrayEmpty { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_empty_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_empty_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns 1 for an empty array or 0 for a non-empty array.", + ) + .with_syntax_example("empty(array)") + .with_sql_example( + r#"```sql +> select empty([1]); ++------------------+ +| empty(List([1])) | ++------------------+ +| 0 | ++------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .build() + .unwrap() + }) } /// Array_empty SQL function @@ -90,12 +121,7 @@ pub fn array_empty_inner(args: &[ArrayRef]) -> Result { return exec_err!("array_empty expects one argument"); } - if as_null_array(&args[0]).is_ok() { - // Make sure to return Boolean type. - return Ok(Arc::new(BooleanArray::new_null(args[0].len()))); - } let array_type = args[0].data_type(); - match array_type { List(_) => general_array_empty::(&args[0]), LargeList(_) => general_array_empty::(&args[0]), @@ -105,9 +131,10 @@ pub fn array_empty_inner(args: &[ArrayRef]) -> Result { fn general_array_empty(array: &ArrayRef) -> Result { let array = as_generic_list_array::(array)?; + let builder = array .iter() - .map(|arr| arr.map(|arr| arr.len() == arr.null_count())) + .map(|arr| arr.map(|arr| arr.is_empty())) .collect::(); Ok(Arc::new(builder)) } diff --git a/datafusion/functions-array/src/except.rs b/datafusion/functions-nested/src/except.rs similarity index 72% rename from datafusion/functions-array/src/except.rs rename to datafusion/functions-nested/src/except.rs index 444c7c758771..947d3c018221 100644 --- a/datafusion/functions-array/src/except.rs +++ b/datafusion/functions-nested/src/except.rs @@ -24,14 +24,15 @@ use arrow_array::{Array, ArrayRef, GenericListArray, OffsetSizeTrait}; use arrow_buffer::OffsetBuffer; use arrow_schema::{DataType, FieldRef}; use datafusion_common::{exec_err, internal_err, Result}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::Expr; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; use std::collections::HashSet; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; -make_udf_function!( +make_udf_expr_and_func!( ArrayExcept, array_except, first_array second_array, @@ -49,7 +50,7 @@ impl ArrayExcept { pub fn new() -> Self { Self { signature: Signature::any(2, Volatility::Immutable), - aliases: vec!["array_except".to_string(), "list_except".to_string()], + aliases: vec!["list_except".to_string()], } } } @@ -80,6 +81,49 @@ impl ScalarUDFImpl for ArrayExcept { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_except_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_array_except_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns an array of the elements that appear in the first array but not in the second.", + ) + .with_syntax_example("array_except(array1, array2)") + .with_sql_example( + r#"```sql +> select array_except([1, 2, 3, 4], [5, 6, 3, 4]); ++----------------------------------------------------+ +| array_except([1, 2, 3, 4], [5, 6, 3, 4]); | ++----------------------------------------------------+ +| [1, 2] | ++----------------------------------------------------+ +> select array_except([1, 2, 3, 4], [3, 4, 5, 6]); ++----------------------------------------------------+ +| array_except([1, 2, 3, 4], [3, 4, 5, 6]); | ++----------------------------------------------------+ +| [1, 2] | ++----------------------------------------------------+ +```"#, + ) + .with_argument( + "array1", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "array2", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .build() + .unwrap() + }) } /// Array_except SQL function diff --git a/datafusion/functions-nested/src/expr_ext.rs b/datafusion/functions-nested/src/expr_ext.rs new file mode 100644 index 000000000000..4da4a3f583b7 --- /dev/null +++ b/datafusion/functions-nested/src/expr_ext.rs @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Extension methods for Expr. + +use datafusion_expr::Expr; + +use crate::extract::{array_element, array_slice}; + +/// Return access to the element field. Example `expr["name"]` +/// +/// ## Example Access element 2 from column "c1" +/// +/// For example if column "c1" holds documents like this +/// +/// ```json +/// [10, 20, 30, 40] +/// ``` +/// +/// You can access the value "30" with +/// +/// ``` +/// # use datafusion_expr::{lit, col, Expr}; +/// # use datafusion_functions_nested::expr_ext::IndexAccessor; +/// let expr = col("c1") +/// .index(lit(3)); +/// assert_eq!(expr.schema_name().to_string(), "c1[Int32(3)]"); +/// ``` +pub trait IndexAccessor { + fn index(self, key: Expr) -> Expr; +} + +impl IndexAccessor for Expr { + fn index(self, key: Expr) -> Expr { + array_element(self, key) + } +} + +/// Return elements between `1` based `start` and `stop`, for +/// example `expr[1:3]` +/// +/// ## Example: Access element 2, 3, 4 from column "c1" +/// +/// For example if column "c1" holds documents like this +/// +/// ```json +/// [10, 20, 30, 40] +/// ``` +/// +/// You can access the value `[20, 30, 40]` with +/// +/// ``` +/// # use datafusion_expr::{lit, col}; +/// # use datafusion_functions_nested::expr_ext::SliceAccessor; +/// let expr = col("c1") +/// .range(lit(2), lit(4)); +/// assert_eq!(expr.schema_name().to_string(), "c1[Int32(2):Int32(4)]"); +/// ``` +pub trait SliceAccessor { + fn range(self, start: Expr, stop: Expr) -> Expr; +} + +impl SliceAccessor for Expr { + fn range(self, start: Expr, stop: Expr) -> Expr { + array_slice(self, start, stop, None) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use datafusion_expr::{col, lit}; + + #[test] + fn test_index() { + let expr1 = col("a").index(lit(1)); + let expr2 = array_element(col("a"), lit(1)); + assert_eq!(expr1, expr2); + } + + #[test] + fn test_range() { + let expr1 = col("a").range(lit(1), lit(2)); + let expr2 = array_slice(col("a"), lit(1), lit(2), None); + assert_eq!(expr1, expr2); + } +} diff --git a/datafusion/functions-array/src/extract.rs b/datafusion/functions-nested/src/extract.rs similarity index 61% rename from datafusion/functions-array/src/extract.rs rename to datafusion/functions-nested/src/extract.rs index 0dbd106b6f18..275095832edb 100644 --- a/datafusion/functions-array/src/extract.rs +++ b/datafusion/functions-nested/src/extract.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! [`ScalarUDFImpl`] definitions for array_element, array_slice, array_pop_front and array_pop_back functions. +//! [`ScalarUDFImpl`] definitions for array_element, array_slice, array_pop_front, array_pop_back, and array_any_value functions. use arrow::array::Array; use arrow::array::ArrayRef; @@ -35,16 +35,18 @@ use datafusion_common::cast::as_list_array; use datafusion_common::{ exec_err, internal_datafusion_err, plan_err, DataFusionError, Result, }; -use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; use datafusion_expr::Expr; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use crate::utils::make_scalar_function; // Create static instances of ScalarUDFs for each function -make_udf_function!( +make_udf_expr_and_func!( ArrayElement, array_element, array element, @@ -52,15 +54,9 @@ make_udf_function!( array_element_udf ); -make_udf_function!( - ArraySlice, - array_slice, - array begin end stride, - "returns a slice of the array.", - array_slice_udf -); +create_func!(ArraySlice, array_slice_udf); -make_udf_function!( +make_udf_expr_and_func!( ArrayPopFront, array_pop_front, array, @@ -68,7 +64,7 @@ make_udf_function!( array_pop_front_udf ); -make_udf_function!( +make_udf_expr_and_func!( ArrayPopBack, array_pop_back, array, @@ -76,6 +72,14 @@ make_udf_function!( array_pop_back_udf ); +make_udf_expr_and_func!( + ArrayAnyValue, + array_any_value, + array, + "returns the first non-null element in the array.", + array_any_value_udf +); + #[derive(Debug)] pub(super) struct ArrayElement { signature: Signature, @@ -87,7 +91,6 @@ impl ArrayElement { Self { signature: Signature::array_and_index(Volatility::Immutable), aliases: vec![ - String::from("array_element"), String::from("array_extract"), String::from("list_element"), String::from("list_extract"), @@ -104,6 +107,27 @@ impl ScalarUDFImpl for ArrayElement { "array_element" } + fn display_name(&self, args: &[Expr]) -> Result { + let args_name = args.iter().map(ToString::to_string).collect::>(); + if args_name.len() != 2 { + return exec_err!("expect 2 args, got {}", args_name.len()); + } + + Ok(format!("{}[{}]", args_name[0], args_name[1])) + } + + fn schema_name(&self, args: &[Expr]) -> Result { + let args_name = args + .iter() + .map(|e| e.schema_name().to_string()) + .collect::>(); + if args_name.len() != 2 { + return exec_err!("expect 2 args, got {}", args_name.len()); + } + + Ok(format!("{}[{}]", args_name[0], args_name[1])) + } + fn signature(&self) -> &Signature { &self.signature } @@ -126,6 +150,43 @@ impl ScalarUDFImpl for ArrayElement { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_element_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_array_element_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Extracts the element with the index n from the array.", + ) + .with_syntax_example("array_element(array, index)") + .with_sql_example( + r#"```sql +> select array_element([1, 2, 3, 4], 3); ++-----------------------------------------+ +| array_element(List([1,2,3,4]),Int64(3)) | ++-----------------------------------------+ +| 3 | ++-----------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "index", + "Index to extract the element from the array.", + ) + .build() + .unwrap() + }) } /// array_element SQL function @@ -224,6 +285,15 @@ where Ok(arrow::array::make_array(data)) } +#[doc = "returns a slice of the array."] +pub fn array_slice(array: Expr, begin: Expr, end: Expr, stride: Option) -> Expr { + let args = match stride { + Some(stride) => vec![array, begin, end, stride], + None => vec![array, begin, end], + }; + array_slice_udf().call(args) +} + #[derive(Debug)] pub(super) struct ArraySlice { signature: Signature, @@ -234,7 +304,7 @@ impl ArraySlice { pub fn new() -> Self { Self { signature: Signature::variadic_any(Volatility::Immutable), - aliases: vec![String::from("array_slice"), String::from("list_slice")], + aliases: vec![String::from("list_slice")], } } } @@ -243,6 +313,28 @@ impl ScalarUDFImpl for ArraySlice { fn as_any(&self) -> &dyn Any { self } + + fn display_name(&self, args: &[Expr]) -> Result { + let args_name = args.iter().map(ToString::to_string).collect::>(); + if let Some((arr, indexes)) = args_name.split_first() { + Ok(format!("{arr}[{}]", indexes.join(":"))) + } else { + exec_err!("no argument") + } + } + + fn schema_name(&self, args: &[Expr]) -> Result { + let args_name = args + .iter() + .map(|e| e.schema_name().to_string()) + .collect::>(); + if let Some((arr, indexes)) = args_name.split_first() { + Ok(format!("{arr}[{}]", indexes.join(":"))) + } else { + exec_err!("no argument") + } + } + fn name(&self) -> &str { "array_slice" } @@ -262,6 +354,49 @@ impl ScalarUDFImpl for ArraySlice { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_slice_doc()) + } +} + +fn get_array_slice_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns a slice of the array based on 1-indexed start and end positions.", + ) + .with_syntax_example("array_slice(array, begin, end)") + .with_sql_example( + r#"```sql +> select array_slice([1, 2, 3, 4, 5, 6, 7, 8], 3, 6); ++--------------------------------------------------------+ +| array_slice(List([1,2,3,4,5,6,7,8]),Int64(3),Int64(6)) | ++--------------------------------------------------------+ +| [3, 4, 5, 6] | ++--------------------------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "begin", + "Index of the first element. If negative, it counts backward from the end of the array.", + ) + .with_argument( + "end", + "Index of the last element. If negative, it counts backward from the end of the array.", + ) + .with_argument( + "stride", + "Stride of the array slice. The default is 1.", + ) + .build() + .unwrap() + }) } /// array_slice SQL function @@ -416,19 +551,16 @@ where if let (Some(from), Some(to)) = (from_index, to_index) { let stride = stride.map(|s| s.value(row_index)); - // array_slice with stride in duckdb, return empty array if stride is not supported and from > to. - if stride.is_none() && from > to { - // return empty array - offsets.push(offsets[row_index]); - continue; - } + // Default stride is 1 if not provided let stride = stride.unwrap_or(1); if stride.is_zero() { return exec_err!( "array_slice got invalid stride: {:?}, it cannot be 0", stride ); - } else if from <= to && stride.is_negative() { + } else if (from <= to && stride.is_negative()) + || (from > to && stride.is_positive()) + { // return empty array offsets.push(offsets[row_index]); continue; @@ -503,10 +635,7 @@ impl ArrayPopFront { pub fn new() -> Self { Self { signature: Signature::array(Volatility::Immutable), - aliases: vec![ - String::from("array_pop_front"), - String::from("list_pop_front"), - ], + aliases: vec![String::from("list_pop_front")], } } } @@ -534,6 +663,37 @@ impl ScalarUDFImpl for ArrayPopFront { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_pop_front_doc()) + } +} + +fn get_array_pop_front_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns the array without the first element.", + ) + .with_syntax_example("array_pop_front(array)") + .with_sql_example( + r#"```sql +> select array_pop_front([1, 2, 3]); ++-------------------------------+ +| array_pop_front(List([1,2,3])) | ++-------------------------------+ +| [2, 3] | ++-------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .build() + .unwrap() + }) } /// array_pop_front SQL function @@ -581,10 +741,7 @@ impl ArrayPopBack { pub fn new() -> Self { Self { signature: Signature::array(Volatility::Immutable), - aliases: vec![ - String::from("array_pop_back"), - String::from("list_pop_back"), - ], + aliases: vec![String::from("list_pop_back")], } } } @@ -612,6 +769,37 @@ impl ScalarUDFImpl for ArrayPopBack { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_pop_back_doc()) + } +} + +fn get_array_pop_back_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns the array without the last element.", + ) + .with_syntax_example("array_pop_back(array)") + .with_sql_example( + r#"```sql +> select array_pop_back([1, 2, 3]); ++-------------------------------+ +| array_pop_back(List([1,2,3])) | ++-------------------------------+ +| [1, 2] | ++-------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .build() + .unwrap() + }) } /// array_pop_back SQL function @@ -652,3 +840,146 @@ where ); general_array_slice::(array, &from_array, &to_array, None) } + +#[derive(Debug)] +pub(super) struct ArrayAnyValue { + signature: Signature, + aliases: Vec, +} + +impl ArrayAnyValue { + pub fn new() -> Self { + Self { + signature: Signature::array(Volatility::Immutable), + aliases: vec![String::from("list_any_value")], + } + } +} + +impl ScalarUDFImpl for ArrayAnyValue { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "array_any_value" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, arg_types: &[DataType]) -> Result { + match &arg_types[0] { + List(field) + | LargeList(field) + | FixedSizeList(field, _) => Ok(field.data_type().clone()), + _ => plan_err!( + "array_any_value can only accept List, LargeList or FixedSizeList as the argument" + ), + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_any_value_inner)(args) + } + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_any_value_doc()) + } +} + +fn get_array_any_value_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns the first non-null element in the array.", + ) + .with_syntax_example("array_any_value(array)") + .with_sql_example( + r#"```sql +> select array_any_value([NULL, 1, 2, 3]); ++-------------------------------+ +| array_any_value(List([NULL,1,2,3])) | ++-------------------------------------+ +| 1 | ++-------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .build() + .unwrap() + }) +} + +fn array_any_value_inner(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("array_any_value expects one argument"); + } + + match &args[0].data_type() { + List(_) => { + let array = as_list_array(&args[0])?; + general_array_any_value::(array) + } + LargeList(_) => { + let array = as_large_list_array(&args[0])?; + general_array_any_value::(array) + } + data_type => exec_err!("array_any_value does not support type: {:?}", data_type), + } +} + +fn general_array_any_value( + array: &GenericListArray, +) -> Result +where + i64: TryInto, +{ + let values = array.values(); + let original_data = values.to_data(); + let capacity = Capacities::Array(array.len()); + + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], true, capacity); + + for (row_index, offset_window) in array.offsets().windows(2).enumerate() { + let start = offset_window[0]; + let end = offset_window[1]; + let len = end - start; + + // array is null + if len == O::usize_as(0) { + mutable.extend_nulls(1); + continue; + } + + let row_value = array.value(row_index); + match row_value.nulls() { + Some(row_nulls_buffer) => { + // nulls are present in the array so try to take the first valid element + if let Some(first_non_null_index) = + row_nulls_buffer.valid_indices().next() + { + let index = start.as_usize() + first_non_null_index; + mutable.extend(0, index, index + 1) + } else { + // all the elements in the array are null + mutable.extend_nulls(1); + } + } + None => { + // no nulls are present in the array so take the first element + let index = start.as_usize(); + mutable.extend(0, index, index + 1); + } + } + } + + let data = mutable.freeze(); + Ok(arrow::array::make_array(data)) +} diff --git a/datafusion/functions-array/src/flatten.rs b/datafusion/functions-nested/src/flatten.rs similarity index 76% rename from datafusion/functions-array/src/flatten.rs rename to datafusion/functions-nested/src/flatten.rs index e2b50c6c02cc..4fe631517b09 100644 --- a/datafusion/functions-array/src/flatten.rs +++ b/datafusion/functions-nested/src/flatten.rs @@ -26,12 +26,14 @@ use datafusion_common::cast::{ as_generic_list_array, as_large_list_array, as_list_array, }; use datafusion_common::{exec_err, Result}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; -make_udf_function!( +make_udf_expr_and_func!( Flatten, flatten, array, @@ -48,7 +50,7 @@ impl Flatten { pub fn new() -> Self { Self { signature: Signature::array(Volatility::Immutable), - aliases: vec![String::from("flatten")], + aliases: vec![], } } } @@ -78,7 +80,7 @@ impl ScalarUDFImpl for Flatten { get_base_type(field.data_type()) } Null | List(_) | LargeList(_) => Ok(data_type.to_owned()), - FixedSizeList(field, _) => Ok(List(field.clone())), + FixedSizeList(field, _) => Ok(List(Arc::clone(field))), _ => exec_err!( "Not reachable, data_type should be List, LargeList or FixedSizeList" ), @@ -96,6 +98,38 @@ impl ScalarUDFImpl for Flatten { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_flatten_doc()) + } +} +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_flatten_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Converts an array of arrays to a flat array.\n\n- Applies to any depth of nested arrays\n- Does not change arrays that are already flat\n\nThe flattened array contains all the elements from all source arrays.", + ) + .with_syntax_example("flatten(array)") + .with_sql_example( + r#"```sql +> select flatten([[1, 2], [3, 4]]); ++------------------------------+ +| flatten(List([1,2], [3,4])) | ++------------------------------+ +| [1, 2, 3, 4] | ++------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .build() + .unwrap() + }) } /// Flatten SQL function @@ -116,7 +150,7 @@ pub fn flatten_inner(args: &[ArrayRef]) -> Result { let flattened_array = flatten_internal::(list_arr.clone(), None)?; Ok(Arc::new(flattened_array) as ArrayRef) } - Null => Ok(args[0].clone()), + Null => Ok(Arc::clone(&args[0])), _ => { exec_err!("flatten does not support type '{array_type:?}'") } @@ -148,7 +182,7 @@ fn flatten_internal( let list_arr = GenericListArray::::new(field, offsets, values, None); Ok(list_arr) } else { - Ok(list_arr.clone()) + Ok(list_arr) } } } diff --git a/datafusion/functions-array/src/length.rs b/datafusion/functions-nested/src/length.rs similarity index 75% rename from datafusion/functions-array/src/length.rs rename to datafusion/functions-nested/src/length.rs index 9bbd11950d21..3e039f286421 100644 --- a/datafusion/functions-array/src/length.rs +++ b/datafusion/functions-nested/src/length.rs @@ -27,12 +27,14 @@ use core::any::type_name; use datafusion_common::cast::{as_generic_list_array, as_int64_array}; use datafusion_common::DataFusionError; use datafusion_common::{exec_err, plan_err, Result}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; -make_udf_function!( +make_udf_expr_and_func!( ArrayLength, array_length, array, @@ -49,7 +51,7 @@ impl ArrayLength { pub fn new() -> Self { Self { signature: Signature::variadic_any(Volatility::Immutable), - aliases: vec![String::from("array_length"), String::from("list_length")], + aliases: vec![String::from("list_length")], } } } @@ -82,6 +84,43 @@ impl ScalarUDFImpl for ArrayLength { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_length_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_array_length_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns the length of the array dimension.", + ) + .with_syntax_example("array_length(array, dimension)") + .with_sql_example( + r#"```sql +> select array_length([1, 2, 3, 4, 5], 1); ++-------------------------------------------+ +| array_length(List([1,2,3,4,5]), 1) | ++-------------------------------------------+ +| 5 | ++-------------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "dimension", + "Array dimension.", + ) + .build() + .unwrap() + }) } /// Array_length SQL function diff --git a/datafusion/functions-array/src/lib.rs b/datafusion/functions-nested/src/lib.rs similarity index 73% rename from datafusion/functions-array/src/lib.rs rename to datafusion/functions-nested/src/lib.rs index 5914736773b7..301ddb36fc56 100644 --- a/datafusion/functions-array/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -14,10 +14,12 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] -//! Array Functions for [DataFusion]. +//! Nested type Functions for [DataFusion]. //! -//! This crate contains a collection of array functions implemented using the +//! This crate contains a collection of nested type functions implemented using the //! extension API. //! //! [DataFusion]: https://crates.io/crates/datafusion @@ -32,12 +34,19 @@ pub mod array_has; pub mod cardinality; pub mod concat; pub mod dimension; +pub mod distance; pub mod empty; pub mod except; +pub mod expr_ext; pub mod extract; pub mod flatten; pub mod length; pub mod make_array; +pub mod map; +pub mod map_extract; +pub mod map_keys; +pub mod map_values; +pub mod planner; pub mod position; pub mod range; pub mod remove; @@ -45,7 +54,6 @@ pub mod repeat; pub mod replace; pub mod resize; pub mod reverse; -pub mod rewrite; pub mod set_ops; pub mod sort; pub mod string; @@ -68,8 +76,10 @@ pub mod expr_fn { pub use super::concat::array_prepend; pub use super::dimension::array_dims; pub use super::dimension::array_ndims; + pub use super::distance::array_distance; pub use super::empty::array_empty; pub use super::except::array_except; + pub use super::extract::array_any_value; pub use super::extract::array_element; pub use super::extract::array_pop_back; pub use super::extract::array_pop_front; @@ -77,6 +87,9 @@ pub mod expr_fn { pub use super::flatten::flatten; pub use super::length::array_length; pub use super::make_array::make_array; + pub use super::map_extract::map_extract; + pub use super::map_keys::map_keys; + pub use super::map_values::map_values; pub use super::position::array_position; pub use super::position::array_positions; pub use super::range::gen_series; @@ -98,9 +111,9 @@ pub mod expr_fn { pub use super::string::string_to_array; } -/// Registers all enabled packages with a [`FunctionRegistry`] -pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { - let functions: Vec> = vec![ +/// Return all default nested type functions +pub fn all_default_nested_functions() -> Vec> { + vec![ string::array_to_string_udf(), string::string_to_array_udf(), range::range_udf(), @@ -116,12 +129,14 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { extract::array_pop_back_udf(), extract::array_pop_front_udf(), extract::array_slice_udf(), + extract::array_any_value_udf(), make_array::make_array_udf(), array_has::array_has_udf(), array_has::array_has_all_udf(), array_has::array_has_any_udf(), empty::array_empty_udf(), length::array_length_udf(), + distance::array_distance_udf(), flatten::flatten_udf(), sort::array_sort_udf(), repeat::array_repeat_udf(), @@ -138,7 +153,16 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { replace::array_replace_n_udf(), replace::array_replace_all_udf(), replace::array_replace_udf(), - ]; + map::map_udf(), + map_extract::map_extract_udf(), + map_keys::map_keys_udf(), + map_values::map_values_udf(), + ] +} + +/// Registers all enabled packages with a [`FunctionRegistry`] +pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { + let functions: Vec> = all_default_nested_functions(); functions.into_iter().try_for_each(|udf| { let existing_udf = registry.register_udf(udf)?; if let Some(existing_udf) = existing_udf { @@ -146,7 +170,33 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { } Ok(()) as Result<()> })?; - registry.register_function_rewrite(Arc::new(rewrite::ArrayFunctionRewriter {}))?; Ok(()) } + +#[cfg(test)] +mod tests { + use crate::all_default_nested_functions; + use datafusion_common::Result; + use std::collections::HashSet; + + #[test] + fn test_no_duplicate_name() -> Result<()> { + let mut names = HashSet::new(); + for func in all_default_nested_functions() { + assert!( + names.insert(func.name().to_string().to_lowercase()), + "duplicate function name: {}", + func.name() + ); + for alias in func.aliases() { + assert!( + names.insert(alias.to_string().to_lowercase()), + "duplicate function name: {}", + alias + ); + } + } + Ok(()) + } +} diff --git a/datafusion/functions-array/src/macros.rs b/datafusion/functions-nested/src/macros.rs similarity index 67% rename from datafusion/functions-array/src/macros.rs rename to datafusion/functions-nested/src/macros.rs index c49f5830b8d5..00247f39ac10 100644 --- a/datafusion/functions-array/src/macros.rs +++ b/datafusion/functions-nested/src/macros.rs @@ -19,8 +19,8 @@ /// /// 1. Single `ScalarUDF` instance /// -/// Creates a singleton `ScalarUDF` of the `$UDF` function named `$GNAME` and a -/// function named `$NAME` which returns that function named $NAME. +/// Creates a singleton `ScalarUDF` of the `$UDF` function named `STATIC_$(UDF)` and a +/// function named `$SCALAR_UDF_FUNC` which returns that function named `STATIC_$(UDF)`. /// /// This is used to ensure creating the list of `ScalarUDF` only happens once. /// @@ -41,60 +41,58 @@ /// * `arg`: 0 or more named arguments for the function /// * `DOC`: documentation string for the function /// * `SCALAR_UDF_FUNC`: name of the function to create (just) the `ScalarUDF` -/// * `GNAME`: name for the single static instance of the `ScalarUDF` /// /// [`ScalarUDFImpl`]: datafusion_expr::ScalarUDFImpl -macro_rules! make_udf_function { +macro_rules! make_udf_expr_and_func { ($UDF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr , $SCALAR_UDF_FN:ident) => { paste::paste! { // "fluent expr_fn" style function #[doc = $DOC] - pub fn $EXPR_FN($($arg: Expr),*) -> Expr { - Expr::ScalarFunction(ScalarFunction::new_udf( + pub fn $EXPR_FN($($arg: datafusion_expr::Expr),*) -> datafusion_expr::Expr { + datafusion_expr::Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction::new_udf( $SCALAR_UDF_FN(), vec![$($arg),*], )) } - - /// Singleton instance of [`$UDF`], ensures the UDF is only created once - /// named STATIC_$(UDF). For example `STATIC_ArrayToString` - #[allow(non_upper_case_globals)] - static [< STATIC_ $UDF >]: std::sync::OnceLock> = - std::sync::OnceLock::new(); - - /// ScalarFunction that returns a [`ScalarUDF`] for [`$UDF`] - /// - /// [`ScalarUDF`]: datafusion_expr::ScalarUDF - pub fn $SCALAR_UDF_FN() -> std::sync::Arc { - [< STATIC_ $UDF >] - .get_or_init(|| { - std::sync::Arc::new(datafusion_expr::ScalarUDF::new_from_impl( - <$UDF>::new(), - )) - }) - .clone() - } + create_func!($UDF, $SCALAR_UDF_FN); } }; ($UDF:ty, $EXPR_FN:ident, $DOC:expr , $SCALAR_UDF_FN:ident) => { paste::paste! { // "fluent expr_fn" style function #[doc = $DOC] - pub fn $EXPR_FN(arg: Vec) -> Expr { - Expr::ScalarFunction(ScalarFunction::new_udf( + pub fn $EXPR_FN(arg: Vec) -> datafusion_expr::Expr { + datafusion_expr::Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction::new_udf( $SCALAR_UDF_FN(), arg, )) } + create_func!($UDF, $SCALAR_UDF_FN); + } + }; +} +/// Creates a singleton `ScalarUDF` of the `$UDF` function named `STATIC_$(UDF)` and a +/// function named `$SCALAR_UDF_FUNC` which returns that function named `STATIC_$(UDF)`. +/// +/// This is used to ensure creating the list of `ScalarUDF` only happens once. +/// +/// # Arguments +/// * `UDF`: name of the [`ScalarUDFImpl`] +/// * `SCALAR_UDF_FUNC`: name of the function to create (just) the `ScalarUDF` +/// +/// [`ScalarUDFImpl`]: datafusion_expr::ScalarUDFImpl +macro_rules! create_func { + ($UDF:ty, $SCALAR_UDF_FN:ident) => { + paste::paste! { /// Singleton instance of [`$UDF`], ensures the UDF is only created once /// named STATIC_$(UDF). For example `STATIC_ArrayToString` #[allow(non_upper_case_globals)] static [< STATIC_ $UDF >]: std::sync::OnceLock> = std::sync::OnceLock::new(); - /// ScalarFunction that returns a [`ScalarUDF`] for [`$UDF`] - /// - /// [`ScalarUDF`]: datafusion_expr::ScalarUDF + + #[doc = concat!("ScalarFunction that returns a [`ScalarUDF`](datafusion_expr::ScalarUDF) for ")] + #[doc = stringify!($UDF)] pub fn $SCALAR_UDF_FN() -> std::sync::Arc { [< STATIC_ $UDF >] .get_or_init(|| { diff --git a/datafusion/functions-array/src/make_array.rs b/datafusion/functions-nested/src/make_array.rs similarity index 68% rename from datafusion/functions-array/src/make_array.rs rename to datafusion/functions-nested/src/make_array.rs index 0439a736ee42..c2c6f24948b8 100644 --- a/datafusion/functions-array/src/make_array.rs +++ b/datafusion/functions-nested/src/make_array.rs @@ -17,7 +17,9 @@ //! [`ScalarUDFImpl`] definitions for `make_array` function. -use std::{any::Any, sync::Arc}; +use std::any::Any; +use std::sync::{Arc, OnceLock}; +use std::vec; use arrow::array::{ArrayData, Capacities, MutableArrayData}; use arrow_array::{ @@ -26,16 +28,19 @@ use arrow_array::{ use arrow_buffer::OffsetBuffer; use arrow_schema::DataType::{LargeList, List, Null}; use arrow_schema::{DataType, Field}; -use datafusion_common::{plan_err, utils::array_into_list_array, Result}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::Expr; +use datafusion_common::{plan_err, utils::array_into_list_array_nullable, Result}; +use datafusion_expr::binary::{ + try_type_union_resolution_with_struct, type_union_resolution, +}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; +use datafusion_expr::TypeSignature; use datafusion_expr::{ - ColumnarValue, ScalarUDFImpl, Signature, TypeSignature, Volatility, + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, }; use crate::utils::make_scalar_function; -make_udf_function!( +make_udf_expr_and_func!( MakeArray, make_array, "Returns an Arrow array using the specified input expressions.", @@ -58,10 +63,10 @@ impl MakeArray { pub fn new() -> Self { Self { signature: Signature::one_of( - vec![TypeSignature::VariadicEqual, TypeSignature::Any(0)], + vec![TypeSignature::UserDefined, TypeSignature::Any(0)], Volatility::Immutable, ), - aliases: vec![String::from("make_array"), String::from("make_list")], + aliases: vec![String::from("make_list")], } } } @@ -81,21 +86,14 @@ impl ScalarUDFImpl for MakeArray { fn return_type(&self, arg_types: &[DataType]) -> Result { match arg_types.len() { - 0 => Ok(DataType::List(Arc::new(Field::new( - "item", - DataType::Null, - true, - )))), + 0 => Ok(empty_array_type()), _ => { - let mut expr_type = DataType::Null; - for arg_type in arg_types { - if !arg_type.equals_datatype(&DataType::Null) { - expr_type = arg_type.clone(); - break; - } - } - - Ok(List(Arc::new(Field::new("item", expr_type, true)))) + // At this point, all the type in array should be coerced to the same one + Ok(List(Arc::new(Field::new( + "item", + arg_types[0].to_owned(), + true, + )))) } } } @@ -104,9 +102,79 @@ impl ScalarUDFImpl for MakeArray { make_scalar_function(make_array_inner)(args) } + fn invoke_no_args(&self, _number_rows: usize) -> Result { + make_scalar_function(make_array_inner)(&[]) + } + fn aliases(&self) -> &[String] { &self.aliases } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let mut errors = vec![]; + match try_type_union_resolution_with_struct(arg_types) { + Ok(r) => return Ok(r), + Err(e) => { + errors.push(e); + } + } + + if let Some(new_type) = type_union_resolution(arg_types) { + // TODO: Move FixedSizeList to List in type_union_resolution + if let DataType::FixedSizeList(field, _) = new_type { + Ok(vec![List(field); arg_types.len()]) + } else if new_type.is_null() { + Ok(vec![DataType::Int64; arg_types.len()]) + } else { + Ok(vec![new_type; arg_types.len()]) + } + } else { + plan_err!( + "Fail to find the valid type between {:?} for {}, errors are {:?}", + arg_types, + self.name(), + errors + ) + } + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_make_array_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_make_array_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns an array using the specified input expressions.", + ) + .with_syntax_example("make_array(expression1[, ..., expression_n])") + .with_sql_example( + r#"```sql +> select make_array(1, 2, 3, 4, 5); ++----------------------------------------------------------+ +| make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)) | ++----------------------------------------------------------+ +| [1, 2, 3, 4, 5] | ++----------------------------------------------------------+ +```"#, + ) + .with_argument( + "expression_n", + "Expression to include in the output array. Can be a constant, column, or function, and any combination of arithmetic or string operators.", + ) + .build() + .unwrap() + }) +} + +// Empty array is a special case that is useful for many other array functions +pub(super) fn empty_array_type() -> DataType { + List(Arc::new(Field::new("item", DataType::Int64, true))) } /// `make_array_inner` is the implementation of the `make_array` function. @@ -125,8 +193,10 @@ pub(crate) fn make_array_inner(arrays: &[ArrayRef]) -> Result { match data_type { // Either an empty array or all nulls: Null => { - let array = new_null_array(&Null, arrays.iter().map(|a| a.len()).sum()); - Ok(Arc::new(array_into_list_array(array))) + let length = arrays.iter().map(|a| a.len()).sum(); + // By default Int64 + let array = new_null_array(&DataType::Int64, length); + Ok(Arc::new(array_into_list_array_nullable(array))) } LargeList(..) => array_array::(arrays, data_type), _ => array_array::(arrays, data_type), diff --git a/datafusion/functions-nested/src/map.rs b/datafusion/functions-nested/src/map.rs new file mode 100644 index 000000000000..d7dce3bacbe1 --- /dev/null +++ b/datafusion/functions-nested/src/map.rs @@ -0,0 +1,429 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::collections::{HashSet, VecDeque}; +use std::sync::{Arc, OnceLock}; + +use arrow::array::ArrayData; +use arrow_array::{Array, ArrayRef, MapArray, OffsetSizeTrait, StructArray}; +use arrow_buffer::{Buffer, ToByteSlice}; +use arrow_schema::{DataType, Field, SchemaBuilder}; + +use datafusion_common::utils::{fixed_size_list_to_arrays, list_to_arrays}; +use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MAP; +use datafusion_expr::{ + ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, +}; + +use crate::make_array::make_array; + +/// Returns a map created from a key list and a value list +pub fn map(keys: Vec, values: Vec) -> Expr { + let keys = make_array(keys); + let values = make_array(values); + Expr::ScalarFunction(ScalarFunction::new_udf(map_udf(), vec![keys, values])) +} + +create_func!(MapFunc, map_udf); + +/// Check if we can evaluate the expr to constant directly. +/// +/// # Example +/// ```sql +/// SELECT make_map('type', 'test') from test +/// ``` +/// We can evaluate the result of `make_map` directly. +fn can_evaluate_to_const(args: &[ColumnarValue]) -> bool { + args.iter() + .all(|arg| matches!(arg, ColumnarValue::Scalar(_))) +} + +fn make_map_batch(args: &[ColumnarValue]) -> Result { + if args.len() != 2 { + return exec_err!( + "make_map requires exactly 2 arguments, got {} instead", + args.len() + ); + } + + let can_evaluate_to_const = can_evaluate_to_const(args); + + // check the keys array is unique + let keys = get_first_array_ref(&args[0])?; + if keys.null_count() > 0 { + return exec_err!("map key cannot be null"); + } + let key_array = keys.as_ref(); + + match &args[0] { + ColumnarValue::Array(_) => { + let row_keys = match key_array.data_type() { + DataType::List(_) => list_to_arrays::(&keys), + DataType::LargeList(_) => list_to_arrays::(&keys), + DataType::FixedSizeList(_, _) => fixed_size_list_to_arrays(&keys), + data_type => { + return exec_err!( + "Expected list, large_list or fixed_size_list, got {:?}", + data_type + ); + } + }; + + row_keys + .iter() + .try_for_each(|key| check_unique_keys(key.as_ref()))?; + } + ColumnarValue::Scalar(_) => { + check_unique_keys(key_array)?; + } + } + + let values = get_first_array_ref(&args[1])?; + make_map_batch_internal(keys, values, can_evaluate_to_const, args[0].data_type()) +} + +fn check_unique_keys(array: &dyn Array) -> Result<()> { + let mut seen_keys = HashSet::with_capacity(array.len()); + + for i in 0..array.len() { + let key = ScalarValue::try_from_array(array, i)?; + if seen_keys.contains(&key) { + return exec_err!("map key must be unique, duplicate key found: {}", key); + } + seen_keys.insert(key); + } + Ok(()) +} + +fn get_first_array_ref(columnar_value: &ColumnarValue) -> Result { + match columnar_value { + ColumnarValue::Scalar(value) => match value { + ScalarValue::List(array) => Ok(array.value(0)), + ScalarValue::LargeList(array) => Ok(array.value(0)), + ScalarValue::FixedSizeList(array) => Ok(array.value(0)), + _ => exec_err!("Expected array, got {:?}", value), + }, + ColumnarValue::Array(array) => Ok(array.to_owned()), + } +} + +fn make_map_batch_internal( + keys: ArrayRef, + values: ArrayRef, + can_evaluate_to_const: bool, + data_type: DataType, +) -> Result { + if keys.len() != values.len() { + return exec_err!("map requires key and value lists to have the same length"); + } + + if !can_evaluate_to_const { + return if let DataType::LargeList(..) = data_type { + make_map_array_internal::(keys, values) + } else { + make_map_array_internal::(keys, values) + }; + } + + let key_field = Arc::new(Field::new("key", keys.data_type().clone(), false)); + let value_field = Arc::new(Field::new("value", values.data_type().clone(), true)); + let mut entry_struct_buffer: VecDeque<(Arc, ArrayRef)> = VecDeque::new(); + let mut entry_offsets_buffer = VecDeque::new(); + entry_offsets_buffer.push_back(0); + + entry_struct_buffer.push_back((Arc::clone(&key_field), Arc::clone(&keys))); + entry_struct_buffer.push_back((Arc::clone(&value_field), Arc::clone(&values))); + entry_offsets_buffer.push_back(keys.len() as u32); + + let entry_struct: Vec<(Arc, ArrayRef)> = entry_struct_buffer.into(); + let entry_struct = StructArray::from(entry_struct); + + let map_data_type = DataType::Map( + Arc::new(Field::new( + "entries", + entry_struct.data_type().clone(), + false, + )), + false, + ); + + let entry_offsets: Vec = entry_offsets_buffer.into(); + let entry_offsets_buffer = Buffer::from(entry_offsets.to_byte_slice()); + + let map_data = ArrayData::builder(map_data_type) + .len(entry_offsets.len() - 1) + .add_buffer(entry_offsets_buffer) + .add_child_data(entry_struct.to_data()) + .build()?; + let map_array = Arc::new(MapArray::from(map_data)); + + Ok(if can_evaluate_to_const { + ColumnarValue::Scalar(ScalarValue::try_from_array(map_array.as_ref(), 0)?) + } else { + ColumnarValue::Array(map_array) + }) +} + +#[derive(Debug)] +pub struct MapFunc { + signature: Signature, +} + +impl Default for MapFunc { + fn default() -> Self { + Self::new() + } +} + +impl MapFunc { + pub fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for MapFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "map" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types.len() % 2 != 0 { + return exec_err!( + "map requires an even number of arguments, got {} instead", + arg_types.len() + ); + } + let mut builder = SchemaBuilder::new(); + builder.push(Field::new( + "key", + get_element_type(&arg_types[0])?.clone(), + false, + )); + builder.push(Field::new( + "value", + get_element_type(&arg_types[1])?.clone(), + true, + )); + let fields = builder.finish().fields; + Ok(DataType::Map( + Arc::new(Field::new("entries", DataType::Struct(fields), false)), + false, + )) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_map_batch(args) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_map_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_map_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MAP) + .with_description( + "Returns an Arrow map with the specified key-value pairs.\n\n\ + The `make_map` function creates a map from two lists: one for keys and one for values. Each key must be unique and non-null." + ) + .with_syntax_example( + "map(key, value)\nmap(key: value)\nmake_map(['key1', 'key2'], ['value1', 'value2'])" + ) + .with_sql_example( + r#"```sql + -- Using map function + SELECT MAP('type', 'test'); + ---- + {type: test} + + SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33, null]); + ---- + {POST: 41, HEAD: 33, PATCH: } + + SELECT MAP([[1,2], [3,4]], ['a', 'b']); + ---- + {[1, 2]: a, [3, 4]: b} + + SELECT MAP { 'a': 1, 'b': 2 }; + ---- + {a: 1, b: 2} + + -- Using make_map function + SELECT MAKE_MAP(['POST', 'HEAD'], [41, 33]); + ---- + {POST: 41, HEAD: 33} + + SELECT MAKE_MAP(['key1', 'key2'], ['value1', null]); + ---- + {key1: value1, key2: } + ```"# + ) + .with_argument( + "key", + "For `map`: Expression to be used for key. Can be a constant, column, function, or any combination of arithmetic or string operators.\n\ + For `make_map`: The list of keys to be used in the map. Each key must be unique and non-null." + ) + .with_argument( + "value", + "For `map`: Expression to be used for value. Can be a constant, column, function, or any combination of arithmetic or string operators.\n\ + For `make_map`: The list of values to be mapped to the corresponding keys." + ) + .build() + .unwrap() + }) +} + +fn get_element_type(data_type: &DataType) -> Result<&DataType> { + match data_type { + DataType::List(element) => Ok(element.data_type()), + DataType::LargeList(element) => Ok(element.data_type()), + DataType::FixedSizeList(element, _) => Ok(element.data_type()), + _ => exec_err!( + "Expected list, large_list or fixed_size_list, got {:?}", + data_type + ), + } +} + +/// Helper function to create MapArray from array of values to support arrays for Map scalar function +/// +/// ``` text +/// Format of input KEYS and VALUES column +/// keys values +/// +---------------------+ +---------------------+ +/// | +-----------------+ | | +-----------------+ | +/// | | [k11, k12, k13] | | | | [v11, v12, v13] | | +/// | +-----------------+ | | +-----------------+ | +/// | | | | +/// | +-----------------+ | | +-----------------+ | +/// | | [k21, k22, k23] | | | | [v21, v22, v23] | | +/// | +-----------------+ | | +-----------------+ | +/// | | | | +/// | +-----------------+ | | +-----------------+ | +/// | |[k31, k32, k33] | | | |[v31, v32, v33] | | +/// | +-----------------+ | | +-----------------+ | +/// +---------------------+ +---------------------+ +/// ``` +/// Flattened keys and values array to user create `StructArray`, +/// which serves as inner child for `MapArray` +/// +/// ``` text +/// Flattened Flattened +/// Keys Values +/// +-----------+ +-----------+ +/// | +-------+ | | +-------+ | +/// | | k11 | | | | v11 | | +/// | +-------+ | | +-------+ | +/// | +-------+ | | +-------+ | +/// | | k12 | | | | v12 | | +/// | +-------+ | | +-------+ | +/// | +-------+ | | +-------+ | +/// | | k13 | | | | v13 | | +/// | +-------+ | | +-------+ | +/// | +-------+ | | +-------+ | +/// | | k21 | | | | v21 | | +/// | +-------+ | | +-------+ | +/// | +-------+ | | +-------+ | +/// | | k22 | | | | v22 | | +/// | +-------+ | | +-------+ | +/// | +-------+ | | +-------+ | +/// | | k23 | | | | v23 | | +/// | +-------+ | | +-------+ | +/// | +-------+ | | +-------+ | +/// | | k31 | | | | v31 | | +/// | +-------+ | | +-------+ | +/// | +-------+ | | +-------+ | +/// | | k32 | | | | v32 | | +/// | +-------+ | | +-------+ | +/// | +-------+ | | +-------+ | +/// | | k33 | | | | v33 | | +/// | +-------+ | | +-------+ | +/// +-----------+ +-----------+ +/// ```text + +fn make_map_array_internal( + keys: ArrayRef, + values: ArrayRef, +) -> Result { + let mut offset_buffer = vec![O::zero()]; + let mut running_offset = O::zero(); + + let keys = list_to_arrays::(&keys); + let values = list_to_arrays::(&values); + + let mut key_array_vec = vec![]; + let mut value_array_vec = vec![]; + for (k, v) in keys.iter().zip(values.iter()) { + running_offset = running_offset.add(O::usize_as(k.len())); + offset_buffer.push(running_offset); + key_array_vec.push(k.as_ref()); + value_array_vec.push(v.as_ref()); + } + + // concatenate all the arrays + let flattened_keys = arrow::compute::concat(key_array_vec.as_ref())?; + if flattened_keys.null_count() > 0 { + return exec_err!("keys cannot be null"); + } + let flattened_values = arrow::compute::concat(value_array_vec.as_ref())?; + + let fields = vec![ + Arc::new(Field::new("key", flattened_keys.data_type().clone(), false)), + Arc::new(Field::new( + "value", + flattened_values.data_type().clone(), + true, + )), + ]; + + let struct_data = ArrayData::builder(DataType::Struct(fields.into())) + .len(flattened_keys.len()) + .add_child_data(flattened_keys.to_data()) + .add_child_data(flattened_values.to_data()) + .build()?; + + let map_data = ArrayData::builder(DataType::Map( + Arc::new(Field::new( + "entries", + struct_data.data_type().clone(), + false, + )), + false, + )) + .len(keys.len()) + .add_child_data(struct_data) + .add_buffer(Buffer::from_slice_ref(offset_buffer.as_slice())) + .build()?; + Ok(ColumnarValue::Array(Arc::new(MapArray::from(map_data)))) +} diff --git a/datafusion/functions-nested/src/map_extract.rs b/datafusion/functions-nested/src/map_extract.rs new file mode 100644 index 000000000000..d2bb6595fe76 --- /dev/null +++ b/datafusion/functions-nested/src/map_extract.rs @@ -0,0 +1,217 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for map_extract functions. + +use arrow::array::{ArrayRef, Capacities, MutableArrayData}; +use arrow_array::{make_array, ListArray}; + +use arrow::datatypes::DataType; +use arrow_array::{Array, MapArray}; +use arrow_buffer::OffsetBuffer; +use arrow_schema::Field; + +use datafusion_common::{cast::as_map_array, exec_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MAP; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use std::any::Any; +use std::sync::{Arc, OnceLock}; +use std::vec; + +use crate::utils::{get_map_entry_field, make_scalar_function}; + +// Create static instances of ScalarUDFs for each function +make_udf_expr_and_func!( + MapExtract, + map_extract, + map key, + "Return a list containing the value for a given key or an empty list if the key is not contained in the map.", + map_extract_udf +); + +#[derive(Debug)] +pub(super) struct MapExtract { + signature: Signature, + aliases: Vec, +} + +impl MapExtract { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec![String::from("element_at")], + } + } +} + +impl ScalarUDFImpl for MapExtract { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "map_extract" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types.len() != 2 { + return exec_err!("map_extract expects two arguments"); + } + let map_type = &arg_types[0]; + let map_fields = get_map_entry_field(map_type)?; + Ok(DataType::List(Arc::new(Field::new( + "item", + map_fields.last().unwrap().data_type().clone(), + true, + )))) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(map_extract_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 2 { + return exec_err!("map_extract expects two arguments"); + } + + let field = get_map_entry_field(&arg_types[0])?; + Ok(vec![ + arg_types[0].clone(), + field.first().unwrap().data_type().clone(), + ]) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_map_extract_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_map_extract_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MAP) + .with_description( + "Returns a list containing the value for the given key or an empty list if the key is not present in the map.", + ) + .with_syntax_example("map_extract(map, key)") + .with_sql_example( + r#"```sql +SELECT map_extract(MAP {'a': 1, 'b': NULL, 'c': 3}, 'a'); +---- +[1] + +SELECT map_extract(MAP {1: 'one', 2: 'two'}, 2); +---- +['two'] + +SELECT map_extract(MAP {'x': 10, 'y': NULL, 'z': 30}, 'y'); +---- +[] +```"#, + ) + .with_argument( + "map", + "Map expression. Can be a constant, column, or function, and any combination of map operators.", + ) + .with_argument( + "key", + "Key to extract from the map. Can be a constant, column, or function, any combination of arithmetic or string operators, or a named expression of the previously listed.", + ) + .build() + .unwrap() + }) +} + +fn general_map_extract_inner( + map_array: &MapArray, + query_keys_array: &dyn Array, +) -> Result { + let keys = map_array.keys(); + let mut offsets = vec![0_i32]; + + let values = map_array.values(); + let original_data = values.to_data(); + let capacity = Capacities::Array(original_data.len()); + + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], true, capacity); + + for (row_index, offset_window) in map_array.value_offsets().windows(2).enumerate() { + let start = offset_window[0] as usize; + let end = offset_window[1] as usize; + let len = end - start; + + let query_key = query_keys_array.slice(row_index, 1); + + let value_index = + (0..len).find(|&i| keys.slice(start + i, 1).as_ref() == query_key.as_ref()); + + match value_index { + Some(index) => { + mutable.extend(0, start + index, start + index + 1); + } + None => { + mutable.extend_nulls(1); + } + } + offsets.push(offsets[row_index] + 1); + } + + let data = mutable.freeze(); + + Ok(Arc::new(ListArray::new( + Arc::new(Field::new("item", map_array.value_type().clone(), true)), + OffsetBuffer::::new(offsets.into()), + Arc::new(make_array(data)), + None, + ))) +} + +fn map_extract_inner(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!("map_extract expects two arguments"); + } + + let map_array = match args[0].data_type() { + DataType::Map(_, _) => as_map_array(&args[0])?, + _ => return exec_err!("The first argument in map_extract must be a map"), + }; + + let key_type = map_array.key_type(); + + if key_type != args[1].data_type() { + return exec_err!( + "The key type {} does not match the map key type {}", + args[1].data_type(), + key_type + ); + } + + general_map_extract_inner(map_array, &args[1]) +} diff --git a/datafusion/functions-nested/src/map_keys.rs b/datafusion/functions-nested/src/map_keys.rs new file mode 100644 index 000000000000..03e381e372f6 --- /dev/null +++ b/datafusion/functions-nested/src/map_keys.rs @@ -0,0 +1,137 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for map_keys function. + +use crate::utils::{get_map_entry_field, make_scalar_function}; +use arrow_array::{Array, ArrayRef, ListArray}; +use arrow_schema::{DataType, Field}; +use datafusion_common::{cast::as_map_array, exec_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MAP; +use datafusion_expr::{ + ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, + TypeSignature, Volatility, +}; +use std::any::Any; +use std::sync::{Arc, OnceLock}; + +make_udf_expr_and_func!( + MapKeysFunc, + map_keys, + map, + "Return a list of all keys in the map.", + map_keys_udf +); + +#[derive(Debug)] +pub(crate) struct MapKeysFunc { + signature: Signature, +} + +impl MapKeysFunc { + pub fn new() -> Self { + Self { + signature: Signature::new( + TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray), + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for MapKeysFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "map_keys" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types.len() != 1 { + return exec_err!("map_keys expects single argument"); + } + let map_type = &arg_types[0]; + let map_fields = get_map_entry_field(map_type)?; + Ok(DataType::List(Arc::new(Field::new( + "item", + map_fields.first().unwrap().data_type().clone(), + false, + )))) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(map_keys_inner)(args) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_map_keys_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_map_keys_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MAP) + .with_description( + "Returns a list of all keys in the map." + ) + .with_syntax_example("map_keys(map)") + .with_sql_example( + r#"```sql +SELECT map_keys(MAP {'a': 1, 'b': NULL, 'c': 3}); +---- +[a, b, c] + +SELECT map_keys(map([100, 5], [42, 43])); +---- +[100, 5] +```"#, + ) + .with_argument( + "map", + "Map expression. Can be a constant, column, or function, and any combination of map operators." + ) + .build() + .unwrap() + }) +} + +fn map_keys_inner(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("map_keys expects single argument"); + } + + let map_array = match args[0].data_type() { + DataType::Map(_, _) => as_map_array(&args[0])?, + _ => return exec_err!("Argument for map_keys should be a map"), + }; + + Ok(Arc::new(ListArray::new( + Arc::new(Field::new("item", map_array.key_type().clone(), false)), + map_array.offsets().clone(), + Arc::clone(map_array.keys()), + None, + ))) +} diff --git a/datafusion/functions-nested/src/map_values.rs b/datafusion/functions-nested/src/map_values.rs new file mode 100644 index 000000000000..dc7d9c9db8ee --- /dev/null +++ b/datafusion/functions-nested/src/map_values.rs @@ -0,0 +1,137 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for map_values function. + +use crate::utils::{get_map_entry_field, make_scalar_function}; +use arrow_array::{Array, ArrayRef, ListArray}; +use arrow_schema::{DataType, Field}; +use datafusion_common::{cast::as_map_array, exec_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MAP; +use datafusion_expr::{ + ArrayFunctionSignature, ColumnarValue, Documentation, ScalarUDFImpl, Signature, + TypeSignature, Volatility, +}; +use std::any::Any; +use std::sync::{Arc, OnceLock}; + +make_udf_expr_and_func!( + MapValuesFunc, + map_values, + map, + "Return a list of all values in the map.", + map_values_udf +); + +#[derive(Debug)] +pub(crate) struct MapValuesFunc { + signature: Signature, +} + +impl MapValuesFunc { + pub fn new() -> Self { + Self { + signature: Signature::new( + TypeSignature::ArraySignature(ArrayFunctionSignature::MapArray), + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for MapValuesFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "map_values" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types.len() != 1 { + return exec_err!("map_values expects single argument"); + } + let map_type = &arg_types[0]; + let map_fields = get_map_entry_field(map_type)?; + Ok(DataType::List(Arc::new(Field::new( + "item", + map_fields.last().unwrap().data_type().clone(), + true, + )))) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(map_values_inner)(args) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_map_values_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_map_values_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MAP) + .with_description( + "Returns a list of all values in the map." + ) + .with_syntax_example("map_values(map)") + .with_sql_example( + r#"```sql +SELECT map_values(MAP {'a': 1, 'b': NULL, 'c': 3}); +---- +[1, , 3] + +SELECT map_values(map([100, 5], [42, 43])); +---- +[42, 43] +```"#, + ) + .with_argument( + "map", + "Map expression. Can be a constant, column, or function, and any combination of map operators." + ) + .build() + .unwrap() + }) +} + +fn map_values_inner(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("map_values expects single argument"); + } + + let map_array = match args[0].data_type() { + DataType::Map(_, _) => as_map_array(&args[0])?, + _ => return exec_err!("Argument for map_values should be a map"), + }; + + Ok(Arc::new(ListArray::new( + Arc::new(Field::new("item", map_array.value_type().clone(), true)), + map_array.offsets().clone(), + Arc::clone(map_array.values()), + None, + ))) +} diff --git a/datafusion/functions-nested/src/planner.rs b/datafusion/functions-nested/src/planner.rs new file mode 100644 index 000000000000..9ae2fa781d87 --- /dev/null +++ b/datafusion/functions-nested/src/planner.rs @@ -0,0 +1,190 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! SQL planning extensions like [`NestedFunctionPlanner`] and [`FieldAccessPlanner`] + +use datafusion_common::{plan_err, utils::list_ndims, DFSchema, Result}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::{ + planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr}, + sqlparser, Expr, ExprSchemable, GetFieldAccess, +}; +use datafusion_functions::expr_fn::get_field; +use datafusion_functions_aggregate::nth_value::nth_value_udaf; + +use crate::map::map_udf; +use crate::{ + array_has::{array_has_all, array_has_udf}, + expr_fn::{array_append, array_concat, array_prepend}, + extract::{array_element, array_slice}, + make_array::make_array, +}; + +#[derive(Debug)] +pub struct NestedFunctionPlanner; + +impl ExprPlanner for NestedFunctionPlanner { + fn plan_binary_op( + &self, + expr: RawBinaryExpr, + schema: &DFSchema, + ) -> Result> { + let RawBinaryExpr { op, left, right } = expr; + + if op == sqlparser::ast::BinaryOperator::StringConcat { + let left_type = left.get_type(schema)?; + let right_type = right.get_type(schema)?; + let left_list_ndims = list_ndims(&left_type); + let right_list_ndims = list_ndims(&right_type); + + // Rewrite string concat operator to function based on types + // if we get list || list then we rewrite it to array_concat() + // if we get list || non-list then we rewrite it to array_append() + // if we get non-list || list then we rewrite it to array_prepend() + // if we get string || string then we rewrite it to concat() + + // We determine the target function to rewrite based on the list n-dimension, the check is not exact but sufficient. + // The exact validity check is handled in the actual function, so even if there is 3d list appended with 1d list, it is also fine to rewrite. + if left_list_ndims + right_list_ndims == 0 { + // TODO: concat function ignore null, but string concat takes null into consideration + // we can rewrite it to concat if we can configure the behaviour of concat function to the one like `string concat operator` + } else if left_list_ndims == right_list_ndims { + return Ok(PlannerResult::Planned(array_concat(vec![left, right]))); + } else if left_list_ndims > right_list_ndims { + return Ok(PlannerResult::Planned(array_append(left, right))); + } else if left_list_ndims < right_list_ndims { + return Ok(PlannerResult::Planned(array_prepend(left, right))); + } + } else if matches!( + op, + sqlparser::ast::BinaryOperator::AtArrow + | sqlparser::ast::BinaryOperator::ArrowAt + ) { + let left_type = left.get_type(schema)?; + let right_type = right.get_type(schema)?; + let left_list_ndims = list_ndims(&left_type); + let right_list_ndims = list_ndims(&right_type); + // if both are list + if left_list_ndims > 0 && right_list_ndims > 0 { + if op == sqlparser::ast::BinaryOperator::AtArrow { + // array1 @> array2 -> array_has_all(array1, array2) + return Ok(PlannerResult::Planned(array_has_all(left, right))); + } else { + // array1 <@ array2 -> array_has_all(array2, array1) + return Ok(PlannerResult::Planned(array_has_all(right, left))); + } + } + } + + Ok(PlannerResult::Original(RawBinaryExpr { op, left, right })) + } + + fn plan_array_literal( + &self, + exprs: Vec, + _schema: &DFSchema, + ) -> Result>> { + Ok(PlannerResult::Planned(make_array(exprs))) + } + + fn plan_make_map(&self, args: Vec) -> Result>> { + if args.len() % 2 != 0 { + return plan_err!("make_map requires an even number of arguments"); + } + + let (keys, values): (Vec<_>, Vec<_>) = + args.into_iter().enumerate().partition(|(i, _)| i % 2 == 0); + let keys = make_array(keys.into_iter().map(|(_, e)| e).collect()); + let values = make_array(values.into_iter().map(|(_, e)| e).collect()); + + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf(map_udf(), vec![keys, values]), + ))) + } + + fn plan_any(&self, expr: RawBinaryExpr) -> Result> { + if expr.op == sqlparser::ast::BinaryOperator::Eq { + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf( + array_has_udf(), + // left and right are reversed here so `needle=any(haystack)` -> `array_has(haystack, needle)` + vec![expr.right, expr.left], + ), + ))) + } else { + plan_err!("Unsupported AnyOp: '{}', only '=' is supported", expr.op) + } + } +} + +#[derive(Debug)] +pub struct FieldAccessPlanner; + +impl ExprPlanner for FieldAccessPlanner { + fn plan_field_access( + &self, + expr: RawFieldAccessExpr, + _schema: &DFSchema, + ) -> Result> { + let RawFieldAccessExpr { expr, field_access } = expr; + + match field_access { + // expr["field"] => get_field(expr, "field") + GetFieldAccess::NamedStructField { name } => { + Ok(PlannerResult::Planned(get_field(expr, name))) + } + // expr[idx] ==> array_element(expr, idx) + GetFieldAccess::ListIndex { key: index } => { + match expr { + // Special case for array_agg(expr)[index] to NTH_VALUE(expr, index) + Expr::AggregateFunction(agg_func) if is_array_agg(&agg_func) => { + Ok(PlannerResult::Planned(Expr::AggregateFunction( + datafusion_expr::expr::AggregateFunction::new_udf( + nth_value_udaf(), + agg_func + .args + .into_iter() + .chain(std::iter::once(*index)) + .collect(), + agg_func.distinct, + agg_func.filter, + agg_func.order_by, + agg_func.null_treatment, + ), + ))) + } + _ => Ok(PlannerResult::Planned(array_element(expr, *index))), + } + } + // expr[start, stop, stride] ==> array_slice(expr, start, stop, stride) + GetFieldAccess::ListRange { + start, + stop, + stride, + } => Ok(PlannerResult::Planned(array_slice( + expr, + *start, + *stop, + Some(*stride), + ))), + } + } +} + +fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool { + return agg_func.func.name() == "array_agg"; +} diff --git a/datafusion/functions-array/src/position.rs b/datafusion/functions-nested/src/position.rs similarity index 70% rename from datafusion/functions-array/src/position.rs rename to datafusion/functions-nested/src/position.rs index a5a7a7405aa9..adb45141601d 100644 --- a/datafusion/functions-array/src/position.rs +++ b/datafusion/functions-nested/src/position.rs @@ -19,11 +19,12 @@ use arrow_schema::DataType::{LargeList, List, UInt64}; use arrow_schema::{DataType, Field}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::Expr; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow_array::types::UInt64Type; use arrow_array::{ @@ -37,7 +38,7 @@ use itertools::Itertools; use crate::utils::{compare_element_to_list, make_scalar_function}; -make_udf_function!( +make_udf_expr_and_func!( ArrayPosition, array_position, array element index, @@ -57,7 +58,6 @@ impl ArrayPosition { Volatility::Immutable, ), aliases: vec![ - String::from("array_position"), String::from("list_position"), String::from("array_indexof"), String::from("list_indexof"), @@ -89,6 +89,53 @@ impl ScalarUDFImpl for ArrayPosition { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_position_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_array_position_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns the position of the first occurrence of the specified element in the array.", + ) + .with_syntax_example("array_position(array, element)\narray_position(array, element, index)") + .with_sql_example( + r#"```sql +> select array_position([1, 2, 2, 3, 1, 4], 2); ++----------------------------------------------+ +| array_position(List([1,2,2,3,1,4]),Int64(2)) | ++----------------------------------------------+ +| 2 | ++----------------------------------------------+ +> select array_position([1, 2, 2, 3, 1, 4], 2, 3); ++----------------------------------------------------+ +| array_position(List([1,2,2,3,1,4]),Int64(2), Int64(3)) | ++----------------------------------------------------+ +| 3 | ++----------------------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "element", + "Element to search for position in the array.", + ) + .with_argument( + "index", + "Index at which to start searching.", + ) + .build() + .unwrap() + }) } /// Array_position SQL function @@ -168,7 +215,7 @@ fn generic_position( Ok(Arc::new(UInt64Array::from(data))) } -make_udf_function!( +make_udf_expr_and_func!( ArrayPositions, array_positions, array element, // arg name @@ -185,10 +232,7 @@ impl ArrayPositions { pub fn new() -> Self { Self { signature: Signature::array_and_element(Volatility::Immutable), - aliases: vec![ - String::from("array_positions"), - String::from("list_positions"), - ], + aliases: vec![String::from("list_positions")], } } } @@ -216,6 +260,41 @@ impl ScalarUDFImpl for ArrayPositions { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_positions_doc()) + } +} + +fn get_array_positions_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Searches for an element in the array, returns all occurrences.", + ) + .with_syntax_example("array_positions(array, element)") + .with_sql_example( + r#"```sql +> select array_positions([1, 2, 2, 3, 1, 4], 2); ++-----------------------------------------------+ +| array_positions(List([1,2,2,3,1,4]),Int64(2)) | ++-----------------------------------------------+ +| [2, 3] | ++-----------------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "element", + "Element to search for positions in the array.", + ) + .build() + .unwrap() + }) } /// Array_positions SQL function diff --git a/datafusion/functions-nested/src/range.rs b/datafusion/functions-nested/src/range.rs new file mode 100644 index 000000000000..ddc56b1e4ee8 --- /dev/null +++ b/datafusion/functions-nested/src/range.rs @@ -0,0 +1,639 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for range and gen_series functions. + +use crate::utils::make_scalar_function; +use arrow::array::{Array, ArrayRef, Int64Array, ListArray, ListBuilder}; +use arrow::datatypes::{DataType, Field}; +use arrow_array::builder::{Date32Builder, TimestampNanosecondBuilder}; +use arrow_array::temporal_conversions::as_datetime_with_timezone; +use arrow_array::timezone::Tz; +use arrow_array::types::{ + Date32Type, IntervalMonthDayNanoType, TimestampNanosecondType as TSNT, +}; +use arrow_array::{NullArray, TimestampNanosecondArray}; +use arrow_buffer::{BooleanBufferBuilder, NullBuffer, OffsetBuffer}; +use arrow_schema::DataType::*; +use arrow_schema::IntervalUnit::MonthDayNano; +use arrow_schema::TimeUnit::Nanosecond; +use datafusion_common::cast::{ + as_date32_array, as_int64_array, as_interval_mdn_array, as_timestamp_nanosecond_array, +}; +use datafusion_common::{ + exec_datafusion_err, exec_err, internal_err, not_impl_datafusion_err, Result, +}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use itertools::Itertools; +use std::any::Any; +use std::cmp::Ordering; +use std::iter::from_fn; +use std::str::FromStr; +use std::sync::{Arc, OnceLock}; + +make_udf_expr_and_func!( + Range, + range, + start stop step, + "create a list of values in the range between start and stop", + range_udf +); +#[derive(Debug)] +pub(super) struct Range { + signature: Signature, + aliases: Vec, +} +impl Range { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec![], + } + } +} +impl ScalarUDFImpl for Range { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "range" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + arg_types + .iter() + .map(|arg_type| match arg_type { + Null => Ok(Null), + Int8 => Ok(Int64), + Int16 => Ok(Int64), + Int32 => Ok(Int64), + Int64 => Ok(Int64), + UInt8 => Ok(Int64), + UInt16 => Ok(Int64), + UInt32 => Ok(Int64), + UInt64 => Ok(Int64), + Timestamp(_, tz) => Ok(Timestamp(Nanosecond, tz.clone())), + Date32 => Ok(Date32), + Date64 => Ok(Date32), + Utf8 => Ok(Date32), + LargeUtf8 => Ok(Date32), + Utf8View => Ok(Date32), + Interval(_) => Ok(Interval(MonthDayNano)), + _ => exec_err!("Unsupported DataType"), + }) + .try_collect() + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types.iter().any(|t| t.is_null()) { + Ok(Null) + } else { + Ok(List(Arc::new(Field::new( + "item", + arg_types[0].clone(), + true, + )))) + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + if args.iter().any(|arg| arg.data_type().is_null()) { + return Ok(ColumnarValue::Array(Arc::new(NullArray::new(1)))); + } + match args[0].data_type() { + Int64 => make_scalar_function(|args| gen_range_inner(args, false))(args), + Date32 => make_scalar_function(|args| gen_range_date(args, false))(args), + Timestamp(_, _) => { + make_scalar_function(|args| gen_range_timestamp(args, false))(args) + } + dt => { + exec_err!("unsupported type for RANGE. Expected Int64, Date32 or Timestamp, got: {dt}") + } + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_range_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_range_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns an Arrow array between start and stop with step. The range start..end contains all values with start <= x < end. It is empty if start >= end. Step cannot be 0.", + ) + .with_syntax_example("range(start, stop, step)") + .with_sql_example( + r#"```sql +> select range(2, 10, 3); ++-----------------------------------+ +| range(Int64(2),Int64(10),Int64(3))| ++-----------------------------------+ +| [2, 5, 8] | ++-----------------------------------+ + +> select range(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH); ++--------------------------------------------------------------+ +| range(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH) | ++--------------------------------------------------------------+ +| [1992-09-01, 1992-10-01, 1992-11-01, 1992-12-01, 1993-01-01, 1993-02-01] | ++--------------------------------------------------------------+ +```"#, + ) + .with_argument( + "start", + "Start of the range. Ints, timestamps, dates or string types that can be coerced to Date32 are supported.", + ) + .with_argument( + "end", + "End of the range (not included). Type must be the same as start.", + ) + .with_argument( + "step", + "Increase by step (cannot be 0). Steps less than a day are supported only for timestamp ranges.", + ) + .build() + .unwrap() + }) +} + +make_udf_expr_and_func!( + GenSeries, + gen_series, + start stop step, + "create a list of values in the range between start and stop, include upper bound", + gen_series_udf +); +#[derive(Debug)] +pub(super) struct GenSeries { + signature: Signature, + aliases: Vec, +} +impl GenSeries { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec![], + } + } +} +impl ScalarUDFImpl for GenSeries { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "generate_series" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + arg_types + .iter() + .map(|arg_type| match arg_type { + Null => Ok(Null), + Int8 => Ok(Int64), + Int16 => Ok(Int64), + Int32 => Ok(Int64), + Int64 => Ok(Int64), + UInt8 => Ok(Int64), + UInt16 => Ok(Int64), + UInt32 => Ok(Int64), + UInt64 => Ok(Int64), + Timestamp(_, tz) => Ok(Timestamp(Nanosecond, tz.clone())), + Date32 => Ok(Date32), + Date64 => Ok(Date32), + Utf8 => Ok(Date32), + LargeUtf8 => Ok(Date32), + Utf8View => Ok(Date32), + Interval(_) => Ok(Interval(MonthDayNano)), + _ => exec_err!("Unsupported DataType"), + }) + .try_collect() + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types.iter().any(|t| t.is_null()) { + Ok(Null) + } else { + Ok(List(Arc::new(Field::new( + "item", + arg_types[0].clone(), + true, + )))) + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + if args.iter().any(|arg| arg.data_type().is_null()) { + return Ok(ColumnarValue::Array(Arc::new(NullArray::new(1)))); + } + match args[0].data_type() { + Int64 => make_scalar_function(|args| gen_range_inner(args, true))(args), + Date32 => make_scalar_function(|args| gen_range_date(args, true))(args), + Timestamp(_, _) => { + make_scalar_function(|args| gen_range_timestamp(args, true))(args) + } + dt => { + exec_err!( + "unsupported type for GENERATE_SERIES. Expected Int64, Date32 or Timestamp, got: {}", + dt + ) + } + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_generate_series_doc()) + } +} + +static GENERATE_SERIES_DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_generate_series_doc() -> &'static Documentation { + GENERATE_SERIES_DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Similar to the range function, but it includes the upper bound.", + ) + .with_syntax_example("generate_series(start, stop, step)") + .with_sql_example( + r#"```sql +> select generate_series(1,3); ++------------------------------------+ +| generate_series(Int64(1),Int64(3)) | ++------------------------------------+ +| [1, 2, 3] | ++------------------------------------+ +```"#, + ) + .with_argument( + "start", + "start of the series. Ints, timestamps, dates or string types that can be coerced to Date32 are supported.", + ) + .with_argument( + "end", + "end of the series (included). Type must be the same as start.", + ) + .with_argument( + "step", + "increase by step (can not be 0). Steps less than a day are supported only for timestamp ranges.", + ) + .build() + .unwrap() + }) +} + +/// Generates an array of integers from start to stop with a given step. +/// +/// This function takes 1 to 3 ArrayRefs as arguments, representing start, stop, and step values. +/// It returns a `Result` representing the resulting ListArray after the operation. +/// +/// # Arguments +/// +/// * `args` - An array of 1 to 3 ArrayRefs representing start, stop, and step(step value can not be zero.) values. +/// +/// # Examples +/// +/// gen_range(3) => [0, 1, 2] +/// gen_range(1, 4) => [1, 2, 3] +/// gen_range(1, 7, 2) => [1, 3, 5] +pub(super) fn gen_range_inner( + args: &[ArrayRef], + include_upper: bool, +) -> Result { + let (start_array, stop_array, step_array) = match args.len() { + 1 => (None, as_int64_array(&args[0])?, None), + 2 => ( + Some(as_int64_array(&args[0])?), + as_int64_array(&args[1])?, + None, + ), + 3 => ( + Some(as_int64_array(&args[0])?), + as_int64_array(&args[1])?, + Some(as_int64_array(&args[2])?), + ), + _ => return exec_err!("gen_range expects 1 to 3 arguments"), + }; + + let mut values = vec![]; + let mut offsets = vec![0]; + let mut valid = BooleanBufferBuilder::new(stop_array.len()); + for (idx, stop) in stop_array.iter().enumerate() { + match retrieve_range_args(start_array, stop, step_array, idx) { + Some((_, _, 0)) => { + return exec_err!( + "step can't be 0 for function {}(start [, stop, step])", + if include_upper { + "generate_series" + } else { + "range" + } + ); + } + Some((start, stop, step)) => { + // Below, we utilize `usize` to represent steps. + // On 32-bit targets, the absolute value of `i64` may fail to fit into `usize`. + let step_abs = usize::try_from(step.unsigned_abs()).map_err(|_| { + not_impl_datafusion_err!("step {} can't fit into usize", step) + })?; + values.extend( + gen_range_iter(start, stop, step < 0, include_upper) + .step_by(step_abs), + ); + offsets.push(values.len() as i32); + valid.append(true); + } + // If any of the arguments is NULL, append a NULL value to the result. + None => { + offsets.push(values.len() as i32); + valid.append(false); + } + }; + } + let arr = Arc::new(ListArray::try_new( + Arc::new(Field::new("item", Int64, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(Int64Array::from(values)), + Some(NullBuffer::new(valid.finish())), + )?); + Ok(arr) +} + +/// Get the (start, stop, step) args for the range and generate_series function. +/// If any of the arguments is NULL, returns None. +fn retrieve_range_args( + start_array: Option<&Int64Array>, + stop: Option, + step_array: Option<&Int64Array>, + idx: usize, +) -> Option<(i64, i64, i64)> { + // Default start value is 0 if not provided + let start = + start_array.map_or(Some(0), |arr| arr.is_valid(idx).then(|| arr.value(idx)))?; + let stop = stop?; + // Default step value is 1 if not provided + let step = + step_array.map_or(Some(1), |arr| arr.is_valid(idx).then(|| arr.value(idx)))?; + Some((start, stop, step)) +} + +/// Returns an iterator of i64 values from start to stop +fn gen_range_iter( + start: i64, + stop: i64, + decreasing: bool, + include_upper: bool, +) -> Box> { + match (decreasing, include_upper) { + // Decreasing range, stop is inclusive + (true, true) => Box::new((stop..=start).rev()), + // Decreasing range, stop is exclusive + (true, false) => { + if stop == i64::MAX { + // start is never greater than stop, and stop is exclusive, + // so the decreasing range must be empty. + Box::new(std::iter::empty()) + } else { + // Increase the stop value by one to exclude it. + // Since stop is not i64::MAX, `stop + 1` will not overflow. + Box::new((stop + 1..=start).rev()) + } + } + // Increasing range, stop is inclusive + (false, true) => Box::new(start..=stop), + // Increasing range, stop is exclusive + (false, false) => Box::new(start..stop), + } +} + +fn gen_range_date(args: &[ArrayRef], include_upper_bound: bool) -> Result { + if args.len() != 3 { + return exec_err!("arguments length does not match"); + } + let (start_array, stop_array, step_array) = ( + Some(as_date32_array(&args[0])?), + as_date32_array(&args[1])?, + Some(as_interval_mdn_array(&args[2])?), + ); + + // values are date32s + let values_builder = Date32Builder::new(); + let mut list_builder = ListBuilder::new(values_builder); + + for (idx, stop) in stop_array.iter().enumerate() { + let mut stop = stop.unwrap_or(0); + + let start = if let Some(start_array_values) = start_array { + start_array_values.value(idx) + } else { + list_builder.append_null(); + continue; + }; + + let step = if let Some(step) = step_array { + step.value(idx) + } else { + list_builder.append_null(); + continue; + }; + + let (months, days, _) = IntervalMonthDayNanoType::to_parts(step); + + if months == 0 && days == 0 { + return exec_err!("Cannot generate date range less than 1 day."); + } + + let neg = months < 0 || days < 0; + if !include_upper_bound { + stop = Date32Type::subtract_month_day_nano(stop, step); + } + let mut new_date = start; + + let values = from_fn(|| { + if (neg && new_date < stop) || (!neg && new_date > stop) { + None + } else { + let current_date = new_date; + new_date = Date32Type::add_month_day_nano(new_date, step); + Some(Some(current_date)) + } + }); + + list_builder.append_value(values); + } + + let arr = Arc::new(list_builder.finish()); + + Ok(arr) +} + +fn gen_range_timestamp(args: &[ArrayRef], include_upper_bound: bool) -> Result { + if args.len() != 3 { + return exec_err!( + "Arguments length must be 3 for {}", + if include_upper_bound { + "GENERATE_SERIES" + } else { + "RANGE" + } + ); + } + + // coerce_types fn should coerce all types to Timestamp(Nanosecond, tz) + let (start_arr, start_tz_opt) = cast_timestamp_arg(&args[0], include_upper_bound)?; + let (stop_arr, stop_tz_opt) = cast_timestamp_arg(&args[1], include_upper_bound)?; + let step_arr = as_interval_mdn_array(&args[2])?; + let start_tz = parse_tz(start_tz_opt)?; + let stop_tz = parse_tz(stop_tz_opt)?; + + // values are timestamps + let values_builder = start_tz_opt + .clone() + .map_or_else(TimestampNanosecondBuilder::new, |start_tz_str| { + TimestampNanosecondBuilder::new().with_timezone(start_tz_str) + }); + let mut list_builder = ListBuilder::new(values_builder); + + for idx in 0..start_arr.len() { + if start_arr.is_null(idx) || stop_arr.is_null(idx) || step_arr.is_null(idx) { + list_builder.append_null(); + continue; + } + + let start = start_arr.value(idx); + let stop = stop_arr.value(idx); + let step = step_arr.value(idx); + + let (months, days, ns) = IntervalMonthDayNanoType::to_parts(step); + if months == 0 && days == 0 && ns == 0 { + return exec_err!( + "Interval argument to {} must not be 0", + if include_upper_bound { + "GENERATE_SERIES" + } else { + "RANGE" + } + ); + } + + let neg = TSNT::add_month_day_nano(start, step, start_tz) + .ok_or(exec_datafusion_err!( + "Cannot generate timestamp range where start + step overflows" + ))? + .cmp(&start) + == Ordering::Less; + + let stop_dt = as_datetime_with_timezone::(stop, stop_tz).ok_or( + exec_datafusion_err!( + "Cannot generate timestamp for stop: {}: {:?}", + stop, + stop_tz + ), + )?; + + let mut current = start; + let mut current_dt = as_datetime_with_timezone::(current, start_tz).ok_or( + exec_datafusion_err!( + "Cannot generate timestamp for start: {}: {:?}", + current, + start_tz + ), + )?; + + let values = from_fn(|| { + if (include_upper_bound + && ((neg && current_dt < stop_dt) || (!neg && current_dt > stop_dt))) + || (!include_upper_bound + && ((neg && current_dt <= stop_dt) + || (!neg && current_dt >= stop_dt))) + { + return None; + } + + let prev_current = current; + + if let Some(ts) = TSNT::add_month_day_nano(current, step, start_tz) { + current = ts; + current_dt = as_datetime_with_timezone::(current, start_tz)?; + + Some(Some(prev_current)) + } else { + // we failed to parse the timestamp here so terminate the series + None + } + }); + + list_builder.append_value(values); + } + + let arr = Arc::new(list_builder.finish()); + + Ok(arr) +} + +fn cast_timestamp_arg( + arg: &ArrayRef, + include_upper: bool, +) -> Result<(&TimestampNanosecondArray, &Option>)> { + match arg.data_type() { + Timestamp(Nanosecond, tz_opt) => { + Ok((as_timestamp_nanosecond_array(arg)?, tz_opt)) + } + _ => { + internal_err!( + "Unexpected argument type for {} : {}", + if include_upper { + "GENERATE_SERIES" + } else { + "RANGE" + }, + arg.data_type() + ) + } + } +} + +fn parse_tz(tz: &Option>) -> Result { + let tz = tz.as_ref().map_or_else(|| "+00", |s| s); + + Tz::from_str(tz) + .map_err(|op| exec_datafusion_err!("failed to parse timezone {tz}: {:?}", op)) +} diff --git a/datafusion/functions-array/src/remove.rs b/datafusion/functions-nested/src/remove.rs similarity index 67% rename from datafusion/functions-array/src/remove.rs rename to datafusion/functions-nested/src/remove.rs index 21e373081054..dc1ed4833c67 100644 --- a/datafusion/functions-array/src/remove.rs +++ b/datafusion/functions-nested/src/remove.rs @@ -27,12 +27,14 @@ use arrow_buffer::OffsetBuffer; use arrow_schema::{DataType, Field}; use datafusion_common::cast::as_int64_array; use datafusion_common::{exec_err, Result}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; -make_udf_function!( +make_udf_expr_and_func!( ArrayRemove, array_remove, array element, @@ -50,7 +52,7 @@ impl ArrayRemove { pub fn new() -> Self { Self { signature: Signature::array_and_element(Volatility::Immutable), - aliases: vec!["array_remove".to_string(), "list_remove".to_string()], + aliases: vec!["list_remove".to_string()], } } } @@ -79,9 +81,46 @@ impl ScalarUDFImpl for ArrayRemove { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_remove_doc()) + } } -make_udf_function!( +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_array_remove_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Removes the first element from the array equal to the given value.", + ) + .with_syntax_example("array_remove(array, element)") + .with_sql_example( + r#"```sql +> select array_remove([1, 2, 2, 3, 2, 1, 4], 2); ++----------------------------------------------+ +| array_remove(List([1,2,2,3,2,1,4]),Int64(2)) | ++----------------------------------------------+ +| [1, 2, 3, 2, 1, 4] | ++----------------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "element", + "Element to be removed from the array.", + ) + .build() + .unwrap() + }) +} + +make_udf_expr_and_func!( ArrayRemoveN, array_remove_n, array element max, @@ -99,7 +138,7 @@ impl ArrayRemoveN { pub fn new() -> Self { Self { signature: Signature::any(3, Volatility::Immutable), - aliases: vec!["array_remove_n".to_string(), "list_remove_n".to_string()], + aliases: vec!["list_remove_n".to_string()], } } } @@ -128,9 +167,48 @@ impl ScalarUDFImpl for ArrayRemoveN { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_remove_n_doc()) + } } -make_udf_function!( +fn get_array_remove_n_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Removes the first `max` elements from the array equal to the given value.", + ) + .with_syntax_example("array_remove_n(array, element, max)") + .with_sql_example( + r#"```sql +> select array_remove_n([1, 2, 2, 3, 2, 1, 4], 2, 2); ++---------------------------------------------------------+ +| array_remove_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(2)) | ++---------------------------------------------------------+ +| [1, 3, 2, 1, 4] | ++---------------------------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "element", + "Element to be removed from the array.", + ) + .with_argument( + "max", + "Number of first occurrences to remove.", + ) + .build() + .unwrap() + }) +} + +make_udf_expr_and_func!( ArrayRemoveAll, array_remove_all, array element, @@ -148,10 +226,7 @@ impl ArrayRemoveAll { pub fn new() -> Self { Self { signature: Signature::array_and_element(Volatility::Immutable), - aliases: vec![ - "array_remove_all".to_string(), - "list_remove_all".to_string(), - ], + aliases: vec!["list_remove_all".to_string()], } } } @@ -180,6 +255,41 @@ impl ScalarUDFImpl for ArrayRemoveAll { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_remove_all_doc()) + } +} + +fn get_array_remove_all_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Removes all elements from the array equal to the given value.", + ) + .with_syntax_example("array_remove_all(array, element)") + .with_sql_example( + r#"```sql +> select array_remove_all([1, 2, 2, 3, 2, 1, 4], 2); ++--------------------------------------------------+ +| array_remove_all(List([1,2,2,3,2,1,4]),Int64(2)) | ++--------------------------------------------------+ +| [1, 3, 1, 4] | ++--------------------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "element", + "Element to be removed from the array.", + ) + .build() + .unwrap() + }) } /// Array_remove SQL function @@ -232,7 +342,7 @@ fn array_remove_internal( } } -/// For each element of `list_array[i]`, removed up to `arr_n[i]` occurences +/// For each element of `list_array[i]`, removed up to `arr_n[i]` occurrences /// of `element_array[i]`. /// /// The type of each **element** in `list_array` must be the same as the type of diff --git a/datafusion/functions-array/src/repeat.rs b/datafusion/functions-nested/src/repeat.rs similarity index 80% rename from datafusion/functions-array/src/repeat.rs rename to datafusion/functions-nested/src/repeat.rs index 89b766bdcdfc..55584c143a54 100644 --- a/datafusion/functions-array/src/repeat.rs +++ b/datafusion/functions-nested/src/repeat.rs @@ -29,12 +29,14 @@ use arrow_schema::DataType::{LargeList, List}; use arrow_schema::{DataType, Field}; use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array}; use datafusion_common::{exec_err, Result}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; -make_udf_function!( +make_udf_expr_and_func!( ArrayRepeat, array_repeat, element count, // arg name @@ -51,7 +53,7 @@ impl ArrayRepeat { pub fn new() -> Self { Self { signature: Signature::variadic_any(Volatility::Immutable), - aliases: vec![String::from("array_repeat"), String::from("list_repeat")], + aliases: vec![String::from("list_repeat")], } } } @@ -84,6 +86,49 @@ impl ScalarUDFImpl for ArrayRepeat { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_repeat_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_array_repeat_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns an array containing element `count` times.", + ) + .with_syntax_example("array_repeat(element, count)") + .with_sql_example( + r#"```sql +> select array_repeat(1, 3); ++---------------------------------+ +| array_repeat(Int64(1),Int64(3)) | ++---------------------------------+ +| [1, 1, 1] | ++---------------------------------+ +> select array_repeat([1, 2], 2); ++------------------------------------+ +| array_repeat(List([1,2]),Int64(2)) | ++------------------------------------+ +| [[1, 2], [1, 2]] | ++------------------------------------+ +```"#, + ) + .with_argument( + "element", + "Element expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "count", + "Value of how many times to repeat the element.", + ) + .build() + .unwrap() + }) } /// Array_repeat SQL function diff --git a/datafusion/functions-array/src/replace.rs b/datafusion/functions-nested/src/replace.rs similarity index 68% rename from datafusion/functions-array/src/replace.rs rename to datafusion/functions-nested/src/replace.rs index c32305bb454b..1d0a1d1f2815 100644 --- a/datafusion/functions-array/src/replace.rs +++ b/datafusion/functions-nested/src/replace.rs @@ -27,30 +27,31 @@ use arrow_buffer::{BooleanBufferBuilder, NullBuffer, OffsetBuffer}; use arrow_schema::Field; use datafusion_common::cast::as_int64_array; use datafusion_common::{exec_err, Result}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::Expr; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use crate::utils::compare_element_to_list; use crate::utils::make_scalar_function; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; // Create static instances of ScalarUDFs for each function -make_udf_function!(ArrayReplace, +make_udf_expr_and_func!(ArrayReplace, array_replace, array from to, "replaces the first occurrence of the specified element with another specified element.", array_replace_udf ); -make_udf_function!(ArrayReplaceN, +make_udf_expr_and_func!(ArrayReplaceN, array_replace_n, array from to max, "replaces the first `max` occurrences of the specified element with another specified element.", array_replace_n_udf ); -make_udf_function!(ArrayReplaceAll, +make_udf_expr_and_func!(ArrayReplaceAll, array_replace_all, array from to, "replaces all occurrences of the specified element with another specified element.", @@ -67,7 +68,7 @@ impl ArrayReplace { pub fn new() -> Self { Self { signature: Signature::any(3, Volatility::Immutable), - aliases: vec![String::from("array_replace"), String::from("list_replace")], + aliases: vec![String::from("list_replace")], } } } @@ -96,6 +97,47 @@ impl ScalarUDFImpl for ArrayReplace { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_replace_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_array_replace_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Replaces the first occurrence of the specified element with another specified element.", + ) + .with_syntax_example("array_replace(array, from, to)") + .with_sql_example( + r#"```sql +> select array_replace([1, 2, 2, 3, 2, 1, 4], 2, 5); ++--------------------------------------------------------+ +| array_replace(List([1,2,2,3,2,1,4]),Int64(2),Int64(5)) | ++--------------------------------------------------------+ +| [1, 5, 2, 3, 2, 1, 4] | ++--------------------------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "from", + "Initial element.", + ) + .with_argument( + "to", + "Final element.", + ) + .build() + .unwrap() + }) } #[derive(Debug)] @@ -108,10 +150,7 @@ impl ArrayReplaceN { pub fn new() -> Self { Self { signature: Signature::any(4, Volatility::Immutable), - aliases: vec![ - String::from("array_replace_n"), - String::from("list_replace_n"), - ], + aliases: vec![String::from("list_replace_n")], } } } @@ -140,6 +179,49 @@ impl ScalarUDFImpl for ArrayReplaceN { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_replace_n_doc()) + } +} + +fn get_array_replace_n_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Replaces the first `max` occurrences of the specified element with another specified element.", + ) + .with_syntax_example("array_replace_n(array, from, to, max)") + .with_sql_example( + r#"```sql +> select array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2); ++-------------------------------------------------------------------+ +| array_replace_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(5),Int64(2)) | ++-------------------------------------------------------------------+ +| [1, 5, 5, 3, 2, 1, 4] | ++-------------------------------------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "from", + "Initial element.", + ) + .with_argument( + "to", + "Final element.", + ) + .with_argument( + "max", + "Number of first occurrences to replace.", + ) + .build() + .unwrap() + }) } #[derive(Debug)] @@ -152,10 +234,7 @@ impl ArrayReplaceAll { pub fn new() -> Self { Self { signature: Signature::any(3, Volatility::Immutable), - aliases: vec![ - String::from("array_replace_all"), - String::from("list_replace_all"), - ], + aliases: vec![String::from("list_replace_all")], } } } @@ -184,6 +263,45 @@ impl ScalarUDFImpl for ArrayReplaceAll { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_replace_all_doc()) + } +} + +fn get_array_replace_all_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Replaces all occurrences of the specified element with another specified element.", + ) + .with_syntax_example("array_replace_all(array, from, to)") + .with_sql_example( + r#"```sql +> select array_replace_all([1, 2, 2, 3, 2, 1, 4], 2, 5); ++------------------------------------------------------------+ +| array_replace_all(List([1,2,2,3,2,1,4]),Int64(2),Int64(5)) | ++------------------------------------------------------------+ +| [1, 5, 5, 3, 5, 1, 4] | ++------------------------------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "from", + "Initial element.", + ) + .with_argument( + "to", + "Final element.", + ) + .build() + .unwrap() + }) } /// For each element of `list_array[i]`, replaces up to `arr_n[i]` occurrences diff --git a/datafusion/functions-array/src/resize.rs b/datafusion/functions-nested/src/resize.rs similarity index 64% rename from datafusion/functions-array/src/resize.rs rename to datafusion/functions-nested/src/resize.rs index c5855d054494..b0255e7be2a3 100644 --- a/datafusion/functions-array/src/resize.rs +++ b/datafusion/functions-nested/src/resize.rs @@ -19,18 +19,22 @@ use crate::utils::make_scalar_function; use arrow::array::{Capacities, MutableArrayData}; -use arrow_array::{ArrayRef, GenericListArray, Int64Array, OffsetSizeTrait}; -use arrow_buffer::{ArrowNativeType, OffsetBuffer}; +use arrow_array::{ + new_null_array, Array, ArrayRef, GenericListArray, Int64Array, OffsetSizeTrait, +}; +use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, NullBuffer, OffsetBuffer}; use arrow_schema::DataType::{FixedSizeList, LargeList, List}; use arrow_schema::{DataType, FieldRef}; use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array}; use datafusion_common::{exec_err, internal_datafusion_err, Result, ScalarValue}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; -make_udf_function!( +make_udf_expr_and_func!( ArrayResize, array_resize, array size value, @@ -48,7 +52,7 @@ impl ArrayResize { pub fn new() -> Self { Self { signature: Signature::variadic_any(Volatility::Immutable), - aliases: vec!["array_resize".to_string(), "list_resize".to_string()], + aliases: vec!["list_resize".to_string()], } } } @@ -68,8 +72,8 @@ impl ScalarUDFImpl for ArrayResize { fn return_type(&self, arg_types: &[DataType]) -> Result { match &arg_types[0] { - List(field) | FixedSizeList(field, _) => Ok(List(field.clone())), - LargeList(field) => Ok(LargeList(field.clone())), + List(field) | FixedSizeList(field, _) => Ok(List(Arc::clone(field))), + LargeList(field) => Ok(LargeList(Arc::clone(field))), _ => exec_err!( "Not reachable, data_type should be List, LargeList or FixedSizeList" ), @@ -83,6 +87,47 @@ impl ScalarUDFImpl for ArrayResize { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_resize_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_array_resize_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Resizes the list to contain size elements. Initializes new elements with value or empty if value is not set.", + ) + .with_syntax_example("array_resize(array, size, value)") + .with_sql_example( + r#"```sql +> select array_resize([1, 2, 3], 5, 0); ++-------------------------------------+ +| array_resize(List([1,2,3],5,0)) | ++-------------------------------------+ +| [1, 2, 3, 0, 0] | ++-------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "size", + "New size of given array.", + ) + .with_argument( + "value", + "Defines new elements' value or empty if value is not set.", + ) + .build() + .unwrap() + }) } /// array_resize SQL function @@ -91,9 +136,26 @@ pub(crate) fn array_resize_inner(arg: &[ArrayRef]) -> Result { return exec_err!("array_resize needs two or three arguments"); } + let array = &arg[0]; + + // Checks if entire array is null + if array.null_count() == array.len() { + let return_type = match array.data_type() { + List(field) => List(Arc::clone(field)), + LargeList(field) => LargeList(Arc::clone(field)), + _ => { + return exec_err!( + "array_resize does not support type '{:?}'.", + array.data_type() + ) + } + }; + return Ok(new_null_array(&return_type, array.len())); + } + let new_len = as_int64_array(&arg[1])?; let new_element = if arg.len() == 3 { - Some(arg[2].clone()) + Some(Arc::clone(&arg[2])) } else { None }; @@ -112,15 +174,12 @@ pub(crate) fn array_resize_inner(arg: &[ArrayRef]) -> Result { } /// array_resize keep the original array and append the default element to the end -fn general_list_resize( +fn general_list_resize>( array: &GenericListArray, count_array: &Int64Array, field: &FieldRef, default_element: Option, -) -> Result -where - O: TryInto, -{ +) -> Result { let data_type = array.value_type(); let values = array.values(); @@ -144,7 +203,16 @@ where capacity, ); + let mut null_builder = BooleanBufferBuilder::new(array.len()); + for (row_index, offset_window) in array.offsets().windows(2).enumerate() { + if array.is_null(row_index) { + null_builder.append(false); + offsets.push(offsets[row_index]); + continue; + } + null_builder.append(true); + let count = count_array.value(row_index).to_usize().ok_or_else(|| { internal_datafusion_err!("array_resize: failed to convert size to usize") })?; @@ -171,10 +239,12 @@ where } let data = mutable.freeze(); + let null_bit_buffer: NullBuffer = null_builder.finish().into(); + Ok(Arc::new(GenericListArray::::try_new( - field.clone(), + Arc::clone(field), OffsetBuffer::::new(offsets.into()), arrow_array::make_array(data), - None, + Some(null_bit_buffer), )?)) } diff --git a/datafusion/functions-array/src/reverse.rs b/datafusion/functions-nested/src/reverse.rs similarity index 71% rename from datafusion/functions-array/src/reverse.rs rename to datafusion/functions-nested/src/reverse.rs index 8324c407bd86..1ecf7f848468 100644 --- a/datafusion/functions-array/src/reverse.rs +++ b/datafusion/functions-nested/src/reverse.rs @@ -25,12 +25,14 @@ use arrow_schema::DataType::{LargeList, List, Null}; use arrow_schema::{DataType, FieldRef}; use datafusion_common::cast::{as_large_list_array, as_list_array}; use datafusion_common::{exec_err, Result}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; -make_udf_function!( +make_udf_expr_and_func!( ArrayReverse, array_reverse, array, @@ -48,7 +50,7 @@ impl ArrayReverse { pub fn new() -> Self { Self { signature: Signature::any(1, Volatility::Immutable), - aliases: vec!["array_reverse".to_string(), "list_reverse".to_string()], + aliases: vec!["list_reverse".to_string()], } } } @@ -77,6 +79,39 @@ impl ScalarUDFImpl for ArrayReverse { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_reverse_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_array_reverse_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns the array with the order of the elements reversed.", + ) + .with_syntax_example("array_reverse(array)") + .with_sql_example( + r#"```sql +> select array_reverse([1, 2, 3, 4]); ++------------------------------------------------------------+ +| array_reverse(List([1, 2, 3, 4])) | ++------------------------------------------------------------+ +| [4, 3, 2, 1] | ++------------------------------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .build() + .unwrap() + }) } /// array_reverse SQL function @@ -94,18 +129,15 @@ pub fn array_reverse_inner(arg: &[ArrayRef]) -> Result { let array = as_large_list_array(&arg[0])?; general_array_reverse::(array, field) } - Null => Ok(arg[0].clone()), + Null => Ok(Arc::clone(&arg[0])), array_type => exec_err!("array_reverse does not support type '{array_type:?}'."), } } -fn general_array_reverse( +fn general_array_reverse>( array: &GenericListArray, field: &FieldRef, -) -> Result -where - O: TryFrom, -{ +) -> Result { let values = array.values(); let original_data = values.to_data(); let capacity = Capacities::Array(original_data.len()); @@ -141,7 +173,7 @@ where let data = mutable.freeze(); Ok(Arc::new(GenericListArray::::try_new( - field.clone(), + Arc::clone(field), OffsetBuffer::::new(offsets.into()), arrow_array::make_array(data), Some(nulls.into()), diff --git a/datafusion/functions-array/src/set_ops.rs b/datafusion/functions-nested/src/set_ops.rs similarity index 69% rename from datafusion/functions-array/src/set_ops.rs rename to datafusion/functions-nested/src/set_ops.rs index 5f3087fafd6f..ce8d248319fe 100644 --- a/datafusion/functions-array/src/set_ops.rs +++ b/datafusion/functions-nested/src/set_ops.rs @@ -17,7 +17,7 @@ //! [`ScalarUDFImpl`] definitions for array_union, array_intersect and array_distinct functions. -use crate::make_array::make_array_inner; +use crate::make_array::{empty_array_type, make_array_inner}; use crate::utils::make_scalar_function; use arrow::array::{new_empty_array, Array, ArrayRef, GenericListArray, OffsetSizeTrait}; use arrow::buffer::OffsetBuffer; @@ -27,17 +27,18 @@ use arrow::row::{RowConverter, SortField}; use arrow_schema::DataType::{FixedSizeList, LargeList, List, Null}; use datafusion_common::cast::{as_large_list_array, as_list_array}; use datafusion_common::{exec_err, internal_err, Result}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::Expr; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use itertools::Itertools; use std::any::Any; use std::collections::HashSet; use std::fmt::{Display, Formatter}; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; // Create static instances of ScalarUDFs for each function -make_udf_function!( +make_udf_expr_and_func!( ArrayUnion, array_union, array1 array2, @@ -45,7 +46,7 @@ make_udf_function!( array_union_udf ); -make_udf_function!( +make_udf_expr_and_func!( ArrayIntersect, array_intersect, first_array second_array, @@ -53,7 +54,7 @@ make_udf_function!( array_intersect_udf ); -make_udf_function!( +make_udf_expr_and_func!( ArrayDistinct, array_distinct, array, @@ -71,7 +72,7 @@ impl ArrayUnion { pub fn new() -> Self { Self { signature: Signature::any(2, Volatility::Immutable), - aliases: vec![String::from("array_union"), String::from("list_union")], + aliases: vec![String::from("list_union")], } } } @@ -104,6 +105,49 @@ impl ScalarUDFImpl for ArrayUnion { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_union_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_array_union_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns an array of elements that are present in both arrays (all elements from both arrays) with out duplicates.", + ) + .with_syntax_example("array_union(array1, array2)") + .with_sql_example( + r#"```sql +> select array_union([1, 2, 3, 4], [5, 6, 3, 4]); ++----------------------------------------------------+ +| array_union([1, 2, 3, 4], [5, 6, 3, 4]); | ++----------------------------------------------------+ +| [1, 2, 3, 4, 5, 6] | ++----------------------------------------------------+ +> select array_union([1, 2, 3, 4], [5, 6, 7, 8]); ++----------------------------------------------------+ +| array_union([1, 2, 3, 4], [5, 6, 7, 8]); | ++----------------------------------------------------+ +| [1, 2, 3, 4, 5, 6, 7, 8] | ++----------------------------------------------------+ +```"#, + ) + .with_argument( + "array1", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "array2", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .build() + .unwrap() + }) } #[derive(Debug)] @@ -116,10 +160,7 @@ impl ArrayIntersect { pub fn new() -> Self { Self { signature: Signature::any(2, Volatility::Immutable), - aliases: vec![ - String::from("array_intersect"), - String::from("list_intersect"), - ], + aliases: vec![String::from("list_intersect")], } } } @@ -140,7 +181,7 @@ impl ScalarUDFImpl for ArrayIntersect { fn return_type(&self, arg_types: &[DataType]) -> Result { match (arg_types[0].clone(), arg_types[1].clone()) { (Null, Null) | (Null, _) => Ok(Null), - (_, Null) => Ok(List(Arc::new(Field::new("item", Null, true)))), + (_, Null) => Ok(empty_array_type()), (dt, _) => Ok(dt), } } @@ -152,6 +193,47 @@ impl ScalarUDFImpl for ArrayIntersect { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_intersect_doc()) + } +} + +fn get_array_intersect_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns an array of elements in the intersection of array1 and array2.", + ) + .with_syntax_example("array_intersect(array1, array2)") + .with_sql_example( + r#"```sql +> select array_intersect([1, 2, 3, 4], [5, 6, 3, 4]); ++----------------------------------------------------+ +| array_intersect([1, 2, 3, 4], [5, 6, 3, 4]); | ++----------------------------------------------------+ +| [3, 4] | ++----------------------------------------------------+ +> select array_intersect([1, 2, 3, 4], [5, 6, 7, 8]); ++----------------------------------------------------+ +| array_intersect([1, 2, 3, 4], [5, 6, 7, 8]); | ++----------------------------------------------------+ +| [] | ++----------------------------------------------------+ +```"#, + ) + .with_argument( + "array1", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "array2", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .build() + .unwrap() + }) } #[derive(Debug)] @@ -164,7 +246,7 @@ impl ArrayDistinct { pub fn new() -> Self { Self { signature: Signature::array(Volatility::Immutable), - aliases: vec!["array_distinct".to_string(), "list_distinct".to_string()], + aliases: vec!["list_distinct".to_string()], } } } @@ -207,6 +289,37 @@ impl ScalarUDFImpl for ArrayDistinct { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_distinct_doc()) + } +} + +fn get_array_distinct_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Returns distinct values from the array after removing duplicates.", + ) + .with_syntax_example("array_distinct(array)") + .with_sql_example( + r#"```sql +> select array_distinct([1, 3, 2, 3, 1, 2, 4]); ++---------------------------------+ +| array_distinct(List([1,2,3,4])) | ++---------------------------------+ +| [1, 2, 3, 4] | ++---------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .build() + .unwrap() + }) } /// array_distinct SQL function @@ -218,7 +331,7 @@ fn array_distinct_inner(args: &[ArrayRef]) -> Result { // handle null if args[0].data_type() == &Null { - return Ok(args[0].clone()); + return Ok(Arc::clone(&args[0])); } // handle for list & largelist @@ -264,6 +377,17 @@ fn generic_set_lists( return general_array_distinct::(l, &field); } + // Handle empty array at rhs case + // array_union(arr, []) -> arr; + // array_intersect(arr, []) -> []; + if r.value_length(0).is_zero() { + if set_op == SetOp::Union { + return Ok(Arc::new(l.clone()) as ArrayRef); + } else { + return Ok(Arc::new(r.clone()) as ArrayRef); + } + } + if l.value_type() != r.value_type() { return internal_err!("{set_op:?} is not implemented for '{l:?}' and '{r:?}'"); } @@ -308,7 +432,7 @@ fn generic_set_lists( offsets.push(last_offset + OffsetSize::usize_as(rows.len())); let arrays = converter.convert_rows(rows)?; let array = match arrays.first() { - Some(array) => array.clone(), + Some(array) => Arc::clone(array), None => { return internal_err!("{set_op}: failed to get array from rows"); } @@ -364,12 +488,12 @@ fn general_set_op( (List(field), List(_)) => { let array1 = as_list_array(&array1)?; let array2 = as_list_array(&array2)?; - generic_set_lists::(array1, array2, field.clone(), set_op) + generic_set_lists::(array1, array2, Arc::clone(field), set_op) } (LargeList(field), LargeList(_)) => { let array1 = as_large_list_array(&array1)?; let array2 = as_large_list_array(&array2)?; - generic_set_lists::(array1, array2, field.clone(), set_op) + generic_set_lists::(array1, array2, Arc::clone(field), set_op) } (data_type1, data_type2) => { internal_err!( @@ -420,7 +544,7 @@ fn general_array_distinct( offsets.push(last_offset + OffsetSize::usize_as(rows.len())); let arrays = converter.convert_rows(rows)?; let array = match arrays.first() { - Some(array) => array.clone(), + Some(array) => Arc::clone(array), None => { return internal_err!("array_distinct: failed to get array from rows") } @@ -431,7 +555,7 @@ fn general_array_distinct( let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::>(); let values = compute::concat(&new_arrays_ref)?; Ok(Arc::new(GenericListArray::::try_new( - field.clone(), + Arc::clone(field), offsets, values, None, diff --git a/datafusion/functions-array/src/sort.rs b/datafusion/functions-nested/src/sort.rs similarity index 76% rename from datafusion/functions-array/src/sort.rs rename to datafusion/functions-nested/src/sort.rs index af78712065fc..b29c187f0679 100644 --- a/datafusion/functions-array/src/sort.rs +++ b/datafusion/functions-nested/src/sort.rs @@ -25,12 +25,14 @@ use arrow_schema::DataType::{FixedSizeList, LargeList, List}; use arrow_schema::{DataType, Field, SortOptions}; use datafusion_common::cast::{as_list_array, as_string_array}; use datafusion_common::{exec_err, Result}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; -make_udf_function!( +make_udf_expr_and_func!( ArraySort, array_sort, array desc null_first, @@ -48,7 +50,7 @@ impl ArraySort { pub fn new() -> Self { Self { signature: Signature::variadic_any(Volatility::Immutable), - aliases: vec!["array_sort".to_string(), "list_sort".to_string()], + aliases: vec!["list_sort".to_string()], } } } @@ -91,6 +93,47 @@ impl ScalarUDFImpl for ArraySort { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_sort_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_array_sort_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Sort array.", + ) + .with_syntax_example("array_sort(array, desc, nulls_first)") + .with_sql_example( + r#"```sql +> select array_sort([3, 1, 2]); ++-----------------------------+ +| array_sort(List([3,1,2])) | ++-----------------------------+ +| [1, 2, 3] | ++-----------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "desc", + "Whether to sort in descending order(`ASC` or `DESC`).", + ) + .with_argument( + "nulls_first", + "Whether to sort nulls first(`NULLS FIRST` or `NULLS LAST`).", + ) + .build() + .unwrap() + }) } /// Array_sort SQL function @@ -121,6 +164,9 @@ pub fn array_sort_inner(args: &[ArrayRef]) -> Result { let list_array = as_list_array(&args[0])?; let row_count = list_array.len(); + if row_count == 0 { + return Ok(Arc::clone(&args[0])); + } let mut array_lengths = vec![]; let mut arrays = vec![]; diff --git a/datafusion/functions-array/src/string.rs b/datafusion/functions-nested/src/string.rs similarity index 78% rename from datafusion/functions-array/src/string.rs rename to datafusion/functions-nested/src/string.rs index 38059035005b..30f3845215fc 100644 --- a/datafusion/functions-array/src/string.rs +++ b/datafusion/functions-nested/src/string.rs @@ -24,21 +24,26 @@ use arrow::array::{ UInt8Array, }; use arrow::datatypes::{DataType, Field}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::{Expr, TypeSignature}; +use datafusion_expr::TypeSignature; -use datafusion_common::{plan_err, DataFusionError, Result}; +use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; use std::any::{type_name, Any}; use crate::utils::{downcast_arg, make_scalar_function}; -use arrow_schema::DataType::{FixedSizeList, LargeList, LargeUtf8, List, Null, Utf8}; +use arrow::compute::cast; +use arrow_schema::DataType::{ + Dictionary, FixedSizeList, LargeList, LargeUtf8, List, Null, Utf8, +}; use datafusion_common::cast::{ as_generic_string_array, as_large_list_array, as_list_array, as_string_array, }; use datafusion_common::exec_err; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; -use std::sync::Arc; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_ARRAY; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use std::sync::{Arc, OnceLock}; macro_rules! to_string { ($ARG:expr, $ARRAY:expr, $DELIMITER:expr, $NULL_STRING:expr, $WITH_NULL_STRING:expr, $ARRAY_TYPE:ident) => {{ @@ -77,7 +82,7 @@ macro_rules! call_array_function { DataType::UInt16 => array_function!(UInt16Array), DataType::UInt32 => array_function!(UInt32Array), DataType::UInt64 => array_function!(UInt64Array), - _ => unreachable!(), + dt => not_impl_err!("Unsupported data type in array_to_string: {dt}"), } }; ($DATATYPE:expr, $INCLUDE_LIST:expr) => {{ @@ -96,13 +101,13 @@ macro_rules! call_array_function { DataType::UInt16 => array_function!(UInt16Array), DataType::UInt32 => array_function!(UInt32Array), DataType::UInt64 => array_function!(UInt64Array), - _ => unreachable!(), + dt => not_impl_err!("Unsupported data type in array_to_string: {dt}"), } }}; } // Create static instances of ScalarUDFs for each function -make_udf_function!( +make_udf_expr_and_func!( ArrayToString, array_to_string, array delimiter, // arg name @@ -120,7 +125,6 @@ impl ArrayToString { Self { signature: Signature::variadic_any(Volatility::Immutable), aliases: vec![ - String::from("array_to_string"), String::from("list_to_string"), String::from("array_join"), String::from("list_join"), @@ -158,9 +162,46 @@ impl ScalarUDFImpl for ArrayToString { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_array_to_string_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_array_to_string_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Converts each element to its text representation.", + ) + .with_syntax_example("array_to_string(array, delimiter)") + .with_sql_example( + r#"```sql +> select array_to_string([[1, 2, 3, 4], [5, 6, 7, 8]], ','); ++----------------------------------------------------+ +| array_to_string(List([1,2,3,4,5,6,7,8]),Utf8(",")) | ++----------------------------------------------------+ +| 1,2,3,4,5,6,7,8 | ++----------------------------------------------------+ +```"#, + ) + .with_argument( + "array", + "Array expression. Can be a constant, column, or function, and any combination of array operators.", + ) + .with_argument( + "delimiter", + "Array element separator.", + ) + .build() + .unwrap() + }) } -make_udf_function!( +make_udf_expr_and_func!( StringToArray, string_to_array, string delimiter null_string, // arg name @@ -183,10 +224,7 @@ impl StringToArray { ], Volatility::Immutable, ), - aliases: vec![ - String::from("string_to_array"), - String::from("string_to_list"), - ], + aliases: vec![String::from("string_to_list")], } } } @@ -230,6 +268,51 @@ impl ScalarUDFImpl for StringToArray { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_string_to_array_doc()) + } +} + +fn get_string_to_array_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ARRAY) + .with_description( + "Splits a string into an array of substrings based on a delimiter. Any substrings matching the optional `null_str` argument are replaced with NULL.", + ) + .with_syntax_example("string_to_array(str, delimiter[, null_str])") + .with_sql_example( + r#"```sql +> select string_to_array('abc##def', '##'); ++-----------------------------------+ +| string_to_array(Utf8('abc##def')) | ++-----------------------------------+ +| ['abc', 'def'] | ++-----------------------------------+ +> select string_to_array('abc def', ' ', 'def'); ++---------------------------------------------+ +| string_to_array(Utf8('abc def'), Utf8(' '), Utf8('def')) | ++---------------------------------------------+ +| ['abc', NULL] | ++---------------------------------------------+ +```"#, + ) + .with_argument( + "str", + "String expression to split.", + ) + .with_argument( + "delimiter", + "Delimiter string to split on.", + ) + .with_argument( + "null_str", + "Substring values to be replaced with `NULL`.", + ) + .build() + .unwrap() + }) } /// Array_to_string SQL function @@ -250,6 +333,8 @@ pub(super) fn array_to_string_inner(args: &[ArrayRef]) -> Result { with_null_string = true; } + /// Creates a single string from single element of a ListArray (which is + /// itself another Array) fn compute_array_to_string( arg: &mut String, arr: ArrayRef, @@ -286,6 +371,22 @@ pub(super) fn array_to_string_inner(args: &[ArrayRef]) -> Result { Ok(arg) } + Dictionary(_key_type, value_type) => { + // Call cast to unwrap the dictionary. This could be optimized if we wanted + // to accept the overhead of extra code + let values = cast(&arr, value_type.as_ref()).map_err(|e| { + DataFusionError::from(e).context( + "Casting dictionary to values in compute_array_to_string", + ) + })?; + compute_array_to_string( + arg, + values, + delimiter, + null_string, + with_null_string, + ) + } Null => Ok(arg), data_type => { macro_rules! array_function { @@ -365,7 +466,7 @@ pub(super) fn array_to_string_inner(args: &[ArrayRef]) -> Result { let delimiter = delimiters[0].unwrap(); let s = compute_array_to_string( &mut arg, - arr.clone(), + Arc::clone(arr), delimiter.to_string(), null_string, with_null_string, diff --git a/datafusion/functions-array/src/utils.rs b/datafusion/functions-nested/src/utils.rs similarity index 87% rename from datafusion/functions-array/src/utils.rs rename to datafusion/functions-nested/src/utils.rs index 86fd281b5845..b9a75724bcde 100644 --- a/datafusion/functions-array/src/utils.rs +++ b/datafusion/functions-nested/src/utils.rs @@ -26,13 +26,13 @@ use arrow_array::{ UInt32Array, }; use arrow_buffer::OffsetBuffer; -use arrow_schema::Field; +use arrow_schema::{Field, Fields}; use datafusion_common::cast::{as_large_list_array, as_list_array}; -use datafusion_common::{exec_err, plan_err, Result, ScalarValue}; +use datafusion_common::{exec_err, internal_err, plan_err, Result, ScalarValue}; use core::any::type_name; use datafusion_common::DataFusionError; -use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; +use datafusion_expr::ColumnarValue; macro_rules! downcast_arg { ($ARG:expr, $ARRAY_TYPE:ident) => {{ @@ -60,11 +60,13 @@ pub(crate) fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> { } /// array function wrapper that differentiates between scalar (length 1) and array. -pub(crate) fn make_scalar_function(inner: F) -> ScalarFunctionImplementation +pub(crate) fn make_scalar_function( + inner: F, +) -> impl Fn(&[ColumnarValue]) -> Result where - F: Fn(&[ArrayRef]) -> Result + Sync + Send + 'static, + F: Fn(&[ArrayRef]) -> Result, { - Arc::new(move |args: &[ColumnarValue]| { + move |args: &[ColumnarValue]| { // first, identify if any of the arguments is an Array. If yes, store its `len`, // as any scalar will need to be converted to an array of len `len`. let len = args @@ -87,7 +89,7 @@ where } else { result.map(ColumnarValue::Array) } - }) + } } pub(crate) fn align_array_dimensions( @@ -105,7 +107,7 @@ pub(crate) fn align_array_dimensions( .zip(args_ndim.iter()) .map(|(array, ndim)| { if ndim < max_ndim { - let mut aligned_array = array.clone(); + let mut aligned_array = Arc::clone(&array); for _ in 0..(max_ndim - ndim) { let data_type = aligned_array.data_type().to_owned(); let array_lengths = vec![1; aligned_array.len()]; @@ -120,7 +122,7 @@ pub(crate) fn align_array_dimensions( } Ok(aligned_array) } else { - Ok(array.clone()) + Ok(Arc::clone(&array)) } }) .collect(); @@ -253,11 +255,26 @@ pub(crate) fn compute_array_dims( } } +pub(crate) fn get_map_entry_field(data_type: &DataType) -> Result<&Fields> { + match data_type { + DataType::Map(field, _) => { + let field_data_type = field.data_type(); + match field_data_type { + DataType::Struct(fields) => Ok(fields), + _ => { + internal_err!("Expected a Struct type, got {:?}", field_data_type) + } + } + } + _ => internal_err!("Expected a Map type, got {:?}", data_type), + } +} + #[cfg(test)] mod tests { use super::*; use arrow::datatypes::Int64Type; - use datafusion_common::utils::array_into_list_array; + use datafusion_common::utils::array_into_list_array_nullable; /// Only test internal functions, array-related sql functions will be tested in sqllogictest `array.slt` #[test] @@ -272,8 +289,12 @@ mod tests { Some(vec![Some(6), Some(7), Some(8)]), ])); - let array2d_1 = Arc::new(array_into_list_array(array1d_1.clone())) as ArrayRef; - let array2d_2 = Arc::new(array_into_list_array(array1d_2.clone())) as ArrayRef; + let array2d_1 = Arc::new(array_into_list_array_nullable( + Arc::clone(&array1d_1) as ArrayRef + )) as ArrayRef; + let array2d_2 = Arc::new(array_into_list_array_nullable( + Arc::clone(&array1d_2) as ArrayRef + )) as ArrayRef; let res = align_array_dimensions::(vec![ array1d_1.to_owned(), @@ -289,11 +310,10 @@ mod tests { expected_dim ); - let array3d_1 = Arc::new(array_into_list_array(array2d_1)) as ArrayRef; - let array3d_2 = array_into_list_array(array2d_2.to_owned()); + let array3d_1 = Arc::new(array_into_list_array_nullable(array2d_1)) as ArrayRef; + let array3d_2 = array_into_list_array_nullable(array2d_2.to_owned()); let res = - align_array_dimensions::(vec![array1d_1, Arc::new(array3d_2.clone())]) - .unwrap(); + align_array_dimensions::(vec![array1d_1, Arc::new(array3d_2)]).unwrap(); let expected = as_list_array(&array3d_1).unwrap(); let expected_dim = datafusion_common::utils::list_ndims(array3d_1.data_type()); diff --git a/datafusion/functions-window-common/Cargo.toml b/datafusion/functions-window-common/Cargo.toml new file mode 100644 index 000000000000..b5df212b7d2a --- /dev/null +++ b/datafusion/functions-window-common/Cargo.toml @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-functions-window-common" +description = "Common functions for implementing user-defined window functions for the DataFusion query engine" +keywords = ["datafusion", "logical", "plan", "expressions"] +readme = "README.md" +authors = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } +repository = { workspace = true } +rust-version = { workspace = true } +version = { workspace = true } + +[lints] +workspace = true + +[lib] +name = "datafusion_functions_window_common" +path = "src/lib.rs" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +datafusion-common = { workspace = true } +datafusion-physical-expr-common = { workspace = true } diff --git a/datafusion/functions-window-common/README.md b/datafusion/functions-window-common/README.md new file mode 100644 index 000000000000..de12d25f9731 --- /dev/null +++ b/datafusion/functions-window-common/README.md @@ -0,0 +1,26 @@ + + +# DataFusion Window Function Common Library + +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. + +This crate contains common functions for implementing user-defined window functions. + +[df]: https://crates.io/crates/datafusion diff --git a/datafusion/functions-window-common/src/expr.rs b/datafusion/functions-window-common/src/expr.rs new file mode 100644 index 000000000000..1d99fe7acf15 --- /dev/null +++ b/datafusion/functions-window-common/src/expr.rs @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::arrow::datatypes::DataType; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::sync::Arc; + +/// Arguments passed to user-defined window function +#[derive(Debug, Default)] +pub struct ExpressionArgs<'a> { + /// The expressions passed as arguments to the user-defined window + /// function. + input_exprs: &'a [Arc], + /// The corresponding data types of expressions passed as arguments + /// to the user-defined window function. + input_types: &'a [DataType], +} + +impl<'a> ExpressionArgs<'a> { + /// Create an instance of [`ExpressionArgs`]. + /// + /// # Arguments + /// + /// * `input_exprs` - The expressions passed as arguments + /// to the user-defined window function. + /// * `input_types` - The data types corresponding to the + /// arguments to the user-defined window function. + /// + pub fn new( + input_exprs: &'a [Arc], + input_types: &'a [DataType], + ) -> Self { + Self { + input_exprs, + input_types, + } + } + + /// Returns the expressions passed as arguments to the user-defined + /// window function. + pub fn input_exprs(&self) -> &'a [Arc] { + self.input_exprs + } + + /// Returns the [`DataType`]s corresponding to the input expressions + /// to the user-defined window function. + pub fn input_types(&self) -> &'a [DataType] { + self.input_types + } +} diff --git a/datafusion/functions-window-common/src/field.rs b/datafusion/functions-window-common/src/field.rs new file mode 100644 index 000000000000..8011b7b0f05f --- /dev/null +++ b/datafusion/functions-window-common/src/field.rs @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::arrow::datatypes::DataType; + +/// Metadata for defining the result field from evaluating a +/// user-defined window function. +pub struct WindowUDFFieldArgs<'a> { + /// The data types corresponding to the arguments to the + /// user-defined window function. + input_types: &'a [DataType], + /// The display name of the user-defined window function. + display_name: &'a str, +} + +impl<'a> WindowUDFFieldArgs<'a> { + /// Create an instance of [`WindowUDFFieldArgs`]. + /// + /// # Arguments + /// + /// * `input_types` - The data types corresponding to the + /// arguments to the user-defined window function. + /// * `function_name` - The qualified schema name of the + /// user-defined window function expression. + /// + pub fn new(input_types: &'a [DataType], display_name: &'a str) -> Self { + WindowUDFFieldArgs { + input_types, + display_name, + } + } + + /// Returns the data type of input expressions passed as arguments + /// to the user-defined window function. + pub fn input_types(&self) -> &[DataType] { + self.input_types + } + + /// Returns the name for the field of the final result of evaluating + /// the user-defined window function. + pub fn name(&self) -> &str { + self.display_name + } + + /// Returns `Some(DataType)` of input expression at index, otherwise + /// returns `None` if the index is out of bounds. + pub fn get_input_type(&self, index: usize) -> Option { + self.input_types.get(index).cloned() + } +} diff --git a/datafusion/functions-window-common/src/lib.rs b/datafusion/functions-window-common/src/lib.rs new file mode 100644 index 000000000000..da8d096da562 --- /dev/null +++ b/datafusion/functions-window-common/src/lib.rs @@ -0,0 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Common user-defined window functionality for [DataFusion] +//! +//! [DataFusion]: +pub mod expr; +pub mod field; +pub mod partition; diff --git a/datafusion/functions-window-common/src/partition.rs b/datafusion/functions-window-common/src/partition.rs new file mode 100644 index 000000000000..64786d2fe7c7 --- /dev/null +++ b/datafusion/functions-window-common/src/partition.rs @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::arrow::datatypes::DataType; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::sync::Arc; + +/// Arguments passed to created user-defined window function state +/// during physical execution. +#[derive(Debug, Default)] +pub struct PartitionEvaluatorArgs<'a> { + /// The expressions passed as arguments to the user-defined window + /// function. + input_exprs: &'a [Arc], + /// The corresponding data types of expressions passed as arguments + /// to the user-defined window function. + input_types: &'a [DataType], + /// Set to `true` if the user-defined window function is reversed. + is_reversed: bool, + /// Set to `true` if `IGNORE NULLS` is specified. + ignore_nulls: bool, +} + +impl<'a> PartitionEvaluatorArgs<'a> { + /// Create an instance of [`PartitionEvaluatorArgs`]. + /// + /// # Arguments + /// + /// * `input_exprs` - The expressions passed as arguments + /// to the user-defined window function. + /// * `input_types` - The data types corresponding to the + /// arguments to the user-defined window function. + /// * `is_reversed` - Set to `true` if and only if the user-defined + /// window function is reversible and is reversed. + /// * `ignore_nulls` - Set to `true` when `IGNORE NULLS` is + /// specified. + /// + pub fn new( + input_exprs: &'a [Arc], + input_types: &'a [DataType], + is_reversed: bool, + ignore_nulls: bool, + ) -> Self { + Self { + input_exprs, + input_types, + is_reversed, + ignore_nulls, + } + } + + /// Returns the expressions passed as arguments to the user-defined + /// window function. + pub fn input_exprs(&self) -> &'a [Arc] { + self.input_exprs + } + + /// Returns the [`DataType`]s corresponding to the input expressions + /// to the user-defined window function. + pub fn input_types(&self) -> &'a [DataType] { + self.input_types + } + + /// Returns `true` when the user-defined window function is + /// reversed, otherwise returns `false`. + pub fn is_reversed(&self) -> bool { + self.is_reversed + } + + /// Returns `true` when `IGNORE NULLS` is specified, otherwise + /// returns `false`. + pub fn ignore_nulls(&self) -> bool { + self.ignore_nulls + } +} diff --git a/datafusion/functions-window/Cargo.toml b/datafusion/functions-window/Cargo.toml new file mode 100644 index 000000000000..262c21fcec65 --- /dev/null +++ b/datafusion/functions-window/Cargo.toml @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-functions-window" +description = "Window function packages for the DataFusion query engine" +keywords = ["datafusion", "logical", "plan", "expressions"] +readme = "README.md" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +authors = { workspace = true } +rust-version = { workspace = true } + +[lints] +workspace = true + +[lib] +name = "datafusion_functions_window" +path = "src/lib.rs" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +datafusion-common = { workspace = true } +datafusion-expr = { workspace = true } +datafusion-functions-window-common = { workspace = true } +datafusion-physical-expr = { workspace = true } +datafusion-physical-expr-common = { workspace = true } +log = { workspace = true } +paste = "1.0.15" + +[dev-dependencies] +arrow = { workspace = true } diff --git a/datafusion/functions-window/README.md b/datafusion/functions-window/README.md new file mode 100644 index 000000000000..18590983ca47 --- /dev/null +++ b/datafusion/functions-window/README.md @@ -0,0 +1,26 @@ + + +# DataFusion Window Function Library + +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. + +This crate contains user-defined window functions. + +[df]: https://crates.io/crates/datafusion diff --git a/datafusion/functions-window/src/cume_dist.rs b/datafusion/functions-window/src/cume_dist.rs new file mode 100644 index 000000000000..9e30c672fee5 --- /dev/null +++ b/datafusion/functions-window/src/cume_dist.rs @@ -0,0 +1,170 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! `cume_dist` window function implementation + +use datafusion_common::arrow::array::{ArrayRef, Float64Array}; +use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::Field; +use datafusion_common::Result; +use datafusion_expr::window_doc_sections::DOC_SECTION_RANKING; +use datafusion_expr::{ + Documentation, PartitionEvaluator, Signature, Volatility, WindowUDFImpl, +}; +use datafusion_functions_window_common::field; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use field::WindowUDFFieldArgs; +use std::any::Any; +use std::fmt::Debug; +use std::iter; +use std::ops::Range; +use std::sync::{Arc, OnceLock}; + +define_udwf_and_expr!( + CumeDist, + cume_dist, + "Calculates the cumulative distribution of a value in a group of values." +); + +/// CumeDist calculates the cume_dist in the window function with order by +#[derive(Debug)] +pub struct CumeDist { + signature: Signature, +} + +impl CumeDist { + pub fn new() -> Self { + Self { + signature: Signature::any(0, Volatility::Immutable), + } + } +} + +impl Default for CumeDist { + fn default() -> Self { + Self::new() + } +} + +impl WindowUDFImpl for CumeDist { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "cume_dist" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + Ok(Box::::default()) + } + + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::Float64, false)) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_cume_dist_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_cume_dist_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_RANKING) + .with_description( + "Relative rank of the current row: (number of rows preceding or peer with current row) / (total rows).", + ) + .with_syntax_example("cume_dist()") + .build() + .unwrap() + }) +} + +#[derive(Debug, Default)] +pub(crate) struct CumeDistEvaluator; + +impl PartitionEvaluator for CumeDistEvaluator { + /// Computes the cumulative distribution for all rows in the partition + fn evaluate_all_with_rank( + &self, + num_rows: usize, + ranks_in_partition: &[Range], + ) -> Result { + let scalar = num_rows as f64; + let result = Float64Array::from_iter_values( + ranks_in_partition + .iter() + .scan(0_u64, |acc, range| { + let len = range.end - range.start; + *acc += len as u64; + let value: f64 = (*acc as f64) / scalar; + let result = iter::repeat(value).take(len); + Some(result) + }) + .flatten(), + ); + Ok(Arc::new(result)) + } + + fn include_rank(&self) -> bool { + true + } +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion_common::cast::as_float64_array; + + fn test_f64_result( + num_rows: usize, + ranks: Vec>, + expected: Vec, + ) -> Result<()> { + let evaluator = CumeDistEvaluator; + let result = evaluator.evaluate_all_with_rank(num_rows, &ranks)?; + let result = as_float64_array(&result)?; + let result = result.values().to_vec(); + assert_eq!(expected, result); + Ok(()) + } + + #[test] + #[allow(clippy::single_range_in_vec_init)] + fn test_cume_dist() -> Result<()> { + test_f64_result(0, vec![], vec![])?; + + test_f64_result(1, vec![0..1], vec![1.0])?; + + test_f64_result(2, vec![0..2], vec![1.0, 1.0])?; + + test_f64_result(4, vec![0..2, 2..4], vec![0.5, 0.5, 1.0, 1.0])?; + + Ok(()) + } +} diff --git a/datafusion/functions-window/src/lead_lag.rs b/datafusion/functions-window/src/lead_lag.rs new file mode 100644 index 000000000000..bbe50cbbdc8a --- /dev/null +++ b/datafusion/functions-window/src/lead_lag.rs @@ -0,0 +1,746 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! `lead` and `lag` window function implementations + +use crate::utils::{get_scalar_value_from_args, get_signed_integer}; +use datafusion_common::arrow::array::ArrayRef; +use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::Field; +use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::window_doc_sections::DOC_SECTION_ANALYTICAL; +use datafusion_expr::{ + Documentation, Literal, PartitionEvaluator, ReversedUDWF, Signature, TypeSignature, + Volatility, WindowUDFImpl, +}; +use datafusion_functions_window_common::expr::ExpressionArgs; +use datafusion_functions_window_common::field::WindowUDFFieldArgs; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::any::Any; +use std::cmp::min; +use std::collections::VecDeque; +use std::ops::{Neg, Range}; +use std::sync::{Arc, OnceLock}; + +get_or_init_udwf!( + Lag, + lag, + "Returns the row value that precedes the current row by a specified \ + offset within partition. If no such row exists, then returns the \ + default value.", + WindowShift::lag +); +get_or_init_udwf!( + Lead, + lead, + "Returns the value from a row that follows the current row by a \ + specified offset within the partition. If no such row exists, then \ + returns the default value.", + WindowShift::lead +); + +/// Create an expression to represent the `lag` window function +/// +/// returns value evaluated at the row that is offset rows before the current row within the partition; +/// if there is no such row, instead return default (which must be of the same type as value). +/// Both offset and default are evaluated with respect to the current row. +/// If omitted, offset defaults to 1 and default to null +pub fn lag( + arg: datafusion_expr::Expr, + shift_offset: Option, + default_value: Option, +) -> datafusion_expr::Expr { + let shift_offset_lit = shift_offset + .map(|v| v.lit()) + .unwrap_or(ScalarValue::Null.lit()); + let default_lit = default_value.unwrap_or(ScalarValue::Null).lit(); + + lag_udwf().call(vec![arg, shift_offset_lit, default_lit]) +} + +/// Create an expression to represent the `lead` window function +/// +/// returns value evaluated at the row that is offset rows after the current row within the partition; +/// if there is no such row, instead return default (which must be of the same type as value). +/// Both offset and default are evaluated with respect to the current row. +/// If omitted, offset defaults to 1 and default to null +pub fn lead( + arg: datafusion_expr::Expr, + shift_offset: Option, + default_value: Option, +) -> datafusion_expr::Expr { + let shift_offset_lit = shift_offset + .map(|v| v.lit()) + .unwrap_or(ScalarValue::Null.lit()); + let default_lit = default_value.unwrap_or(ScalarValue::Null).lit(); + + lead_udwf().call(vec![arg, shift_offset_lit, default_lit]) +} + +#[derive(Debug)] +enum WindowShiftKind { + Lag, + Lead, +} + +impl WindowShiftKind { + fn name(&self) -> &'static str { + match self { + WindowShiftKind::Lag => "lag", + WindowShiftKind::Lead => "lead", + } + } + + /// In [`WindowShiftEvaluator`] a positive offset is used to signal + /// computation of `lag()`. So here we negate the input offset + /// value when computing `lead()`. + fn shift_offset(&self, value: Option) -> i64 { + match self { + WindowShiftKind::Lag => value.unwrap_or(1), + WindowShiftKind::Lead => value.map(|v| v.neg()).unwrap_or(-1), + } + } +} + +/// window shift expression +#[derive(Debug)] +pub struct WindowShift { + signature: Signature, + kind: WindowShiftKind, +} + +impl WindowShift { + fn new(kind: WindowShiftKind) -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Any(1), + TypeSignature::Any(2), + TypeSignature::Any(3), + ], + Volatility::Immutable, + ), + kind, + } + } + + pub fn lag() -> Self { + Self::new(WindowShiftKind::Lag) + } + + pub fn lead() -> Self { + Self::new(WindowShiftKind::Lead) + } +} + +static LAG_DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_lag_doc() -> &'static Documentation { + LAG_DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ANALYTICAL) + .with_description( + "Returns value evaluated at the row that is offset rows before the \ + current row within the partition; if there is no such row, instead return default \ + (which must be of the same type as value).", + ) + .with_syntax_example("lag(expression, offset, default)") + .with_argument("expression", "Expression to operate on") + .with_argument("offset", "Integer. Specifies how many rows back \ + the value of expression should be retrieved. Defaults to 1.") + .with_argument("default", "The default value if the offset is \ + not within the partition. Must be of the same type as expression.") + .build() + .unwrap() + }) +} + +static LEAD_DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_lead_doc() -> &'static Documentation { + LEAD_DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_ANALYTICAL) + .with_description( + "Returns value evaluated at the row that is offset rows after the \ + current row within the partition; if there is no such row, instead return default \ + (which must be of the same type as value).", + ) + .with_syntax_example("lead(expression, offset, default)") + .with_argument("expression", "Expression to operate on") + .with_argument("offset", "Integer. Specifies how many rows \ + forward the value of expression should be retrieved. Defaults to 1.") + .with_argument("default", "The default value if the offset is \ + not within the partition. Must be of the same type as expression.") + .build() + .unwrap() + }) +} + +impl WindowUDFImpl for WindowShift { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.kind.name() + } + + fn signature(&self) -> &Signature { + &self.signature + } + + /// Handles the case where `NULL` expression is passed as an + /// argument to `lead`/`lag`. The type is refined depending + /// on the default value argument. + /// + /// For more details see: + fn expressions(&self, expr_args: ExpressionArgs) -> Vec> { + parse_expr(expr_args.input_exprs(), expr_args.input_types()) + .into_iter() + .collect::>() + } + + fn partition_evaluator( + &self, + partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + let shift_offset = + get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 1)? + .map(get_signed_integer) + .map_or(Ok(None), |v| v.map(Some)) + .map(|n| self.kind.shift_offset(n)) + .map(|offset| { + if partition_evaluator_args.is_reversed() { + -offset + } else { + offset + } + })?; + let default_value = parse_default_value( + partition_evaluator_args.input_exprs(), + partition_evaluator_args.input_types(), + )?; + + Ok(Box::new(WindowShiftEvaluator { + shift_offset, + default_value, + ignore_nulls: partition_evaluator_args.ignore_nulls(), + non_null_offsets: VecDeque::new(), + })) + } + + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + let return_type = parse_expr_type(field_args.input_types())?; + + Ok(Field::new(field_args.name(), return_type, true)) + } + + fn reverse_expr(&self) -> ReversedUDWF { + match self.kind { + WindowShiftKind::Lag => ReversedUDWF::Reversed(lag_udwf()), + WindowShiftKind::Lead => ReversedUDWF::Reversed(lead_udwf()), + } + } + + fn documentation(&self) -> Option<&Documentation> { + match self.kind { + WindowShiftKind::Lag => Some(get_lag_doc()), + WindowShiftKind::Lead => Some(get_lead_doc()), + } + } +} + +/// When `lead`/`lag` is evaluated on a `NULL` expression we attempt to +/// refine it by matching it with the type of the default value. +/// +/// For e.g. in `lead(NULL, 1, false)` the generic `ScalarValue::Null` +/// is refined into `ScalarValue::Boolean(None)`. Only the type is +/// refined, the expression value remains `NULL`. +/// +/// When the window function is evaluated with `NULL` expression +/// this guarantees that the type matches with that of the default +/// value. +/// +/// For more details see: +fn parse_expr( + input_exprs: &[Arc], + input_types: &[DataType], +) -> Result> { + assert!(!input_exprs.is_empty()); + assert!(!input_types.is_empty()); + + let expr = Arc::clone(input_exprs.first().unwrap()); + let expr_type = input_types.first().unwrap(); + + // Handles the most common case where NULL is unexpected + if !expr_type.is_null() { + return Ok(expr); + } + + let default_value = get_scalar_value_from_args(input_exprs, 2)?; + default_value.map_or(Ok(expr), |value| { + ScalarValue::try_from(&value.data_type()).map(|v| { + Arc::new(datafusion_physical_expr::expressions::Literal::new(v)) + as Arc + }) + }) +} + +/// Returns the data type of the default value(if provided) when the +/// expression is `NULL`. +/// +/// Otherwise, returns the expression type unchanged. +fn parse_expr_type(input_types: &[DataType]) -> Result { + assert!(!input_types.is_empty()); + let expr_type = input_types.first().unwrap_or(&DataType::Null); + + // Handles the most common case where NULL is unexpected + if !expr_type.is_null() { + return Ok(expr_type.clone()); + } + + let default_value_type = input_types.get(2).unwrap_or(&DataType::Null); + Ok(default_value_type.clone()) +} + +/// Handles type coercion and null value refinement for default value +/// argument depending on the data type of the input expression. +fn parse_default_value( + input_exprs: &[Arc], + input_types: &[DataType], +) -> Result { + let expr_type = parse_expr_type(input_types)?; + let unparsed = get_scalar_value_from_args(input_exprs, 2)?; + + unparsed + .filter(|v| !v.data_type().is_null()) + .map(|v| v.cast_to(&expr_type)) + .unwrap_or(ScalarValue::try_from(expr_type)) +} + +#[derive(Debug)] +struct WindowShiftEvaluator { + shift_offset: i64, + default_value: ScalarValue, + ignore_nulls: bool, + // VecDeque contains offset values that between non-null entries + non_null_offsets: VecDeque, +} + +impl WindowShiftEvaluator { + fn is_lag(&self) -> bool { + // Mode is LAG, when shift_offset is positive + self.shift_offset > 0 + } +} + +// implement ignore null for evaluate_all +fn evaluate_all_with_ignore_null( + array: &ArrayRef, + offset: i64, + default_value: &ScalarValue, + is_lag: bool, +) -> Result { + let valid_indices: Vec = + array.nulls().unwrap().valid_indices().collect::>(); + let direction = !is_lag; + let new_array_results: Result, DataFusionError> = (0..array.len()) + .map(|id| { + let result_index = match valid_indices.binary_search(&id) { + Ok(pos) => if direction { + pos.checked_add(offset as usize) + } else { + pos.checked_sub(offset.unsigned_abs() as usize) + } + .and_then(|new_pos| { + if new_pos < valid_indices.len() { + Some(valid_indices[new_pos]) + } else { + None + } + }), + Err(pos) => if direction { + pos.checked_add(offset as usize) + } else if pos > 0 { + pos.checked_sub(offset.unsigned_abs() as usize) + } else { + None + } + .and_then(|new_pos| { + if new_pos < valid_indices.len() { + Some(valid_indices[new_pos]) + } else { + None + } + }), + }; + + match result_index { + Some(index) => ScalarValue::try_from_array(array, index), + None => Ok(default_value.clone()), + } + }) + .collect(); + + let new_array = new_array_results?; + ScalarValue::iter_to_array(new_array) +} +// TODO: change the original arrow::compute::kernels::window::shift impl to support an optional default value +fn shift_with_default_value( + array: &ArrayRef, + offset: i64, + default_value: &ScalarValue, +) -> Result { + use datafusion_common::arrow::compute::concat; + + let value_len = array.len() as i64; + if offset == 0 { + Ok(Arc::clone(array)) + } else if offset == i64::MIN || offset.abs() >= value_len { + default_value.to_array_of_size(value_len as usize) + } else { + let slice_offset = (-offset).clamp(0, value_len) as usize; + let length = array.len() - offset.unsigned_abs() as usize; + let slice = array.slice(slice_offset, length); + + // Generate array with remaining `null` items + let nulls = offset.unsigned_abs() as usize; + let default_values = default_value.to_array_of_size(nulls)?; + + // Concatenate both arrays, add nulls after if shift > 0 else before + if offset > 0 { + concat(&[default_values.as_ref(), slice.as_ref()]) + .map_err(|e| arrow_datafusion_err!(e)) + } else { + concat(&[slice.as_ref(), default_values.as_ref()]) + .map_err(|e| arrow_datafusion_err!(e)) + } + } +} + +impl PartitionEvaluator for WindowShiftEvaluator { + fn get_range(&self, idx: usize, n_rows: usize) -> Result> { + if self.is_lag() { + let start = if self.non_null_offsets.len() == self.shift_offset as usize { + // How many rows needed previous than the current row to get necessary lag result + let offset: usize = self.non_null_offsets.iter().sum(); + idx.saturating_sub(offset) + } else if !self.ignore_nulls { + let offset = self.shift_offset as usize; + idx.saturating_sub(offset) + } else { + 0 + }; + let end = idx + 1; + Ok(Range { start, end }) + } else { + let end = if self.non_null_offsets.len() == (-self.shift_offset) as usize { + // How many rows needed further than the current row to get necessary lead result + let offset: usize = self.non_null_offsets.iter().sum(); + min(idx + offset + 1, n_rows) + } else if !self.ignore_nulls { + let offset = (-self.shift_offset) as usize; + min(idx + offset, n_rows) + } else { + n_rows + }; + Ok(Range { start: idx, end }) + } + } + + fn is_causal(&self) -> bool { + // Lagging windows are causal by definition: + self.is_lag() + } + + fn evaluate( + &mut self, + values: &[ArrayRef], + range: &Range, + ) -> Result { + let array = &values[0]; + let len = array.len(); + + // LAG mode + let i = if self.is_lag() { + (range.end as i64 - self.shift_offset - 1) as usize + } else { + // LEAD mode + (range.start as i64 - self.shift_offset) as usize + }; + + let mut idx: Option = if i < len { Some(i) } else { None }; + + // LAG with IGNORE NULLS calculated as the current row index - offset, but only for non-NULL rows + // If current row index points to NULL value the row is NOT counted + if self.ignore_nulls && self.is_lag() { + // LAG when NULLS are ignored. + // Find the nonNULL row index that shifted by offset comparing to current row index + idx = if self.non_null_offsets.len() == self.shift_offset as usize { + let total_offset: usize = self.non_null_offsets.iter().sum(); + Some(range.end - 1 - total_offset) + } else { + None + }; + + // Keep track of offset values between non-null entries + if array.is_valid(range.end - 1) { + // Non-null add new offset + self.non_null_offsets.push_back(1); + if self.non_null_offsets.len() > self.shift_offset as usize { + // WE do not need to keep track of more than `lag number of offset` values. + self.non_null_offsets.pop_front(); + } + } else if !self.non_null_offsets.is_empty() { + // Entry is null, increment offset value of the last entry. + let end_idx = self.non_null_offsets.len() - 1; + self.non_null_offsets[end_idx] += 1; + } + } else if self.ignore_nulls && !self.is_lag() { + // LEAD when NULLS are ignored. + // Stores the necessary non-null entry number further than the current row. + let non_null_row_count = (-self.shift_offset) as usize; + + if self.non_null_offsets.is_empty() { + // When empty, fill non_null offsets with the data further than the current row. + let mut offset_val = 1; + for idx in range.start + 1..range.end { + if array.is_valid(idx) { + self.non_null_offsets.push_back(offset_val); + offset_val = 1; + } else { + offset_val += 1; + } + // It is enough to keep track of `non_null_row_count + 1` non-null offset. + // further data is unnecessary for the result. + if self.non_null_offsets.len() == non_null_row_count + 1 { + break; + } + } + } else if range.end < len && array.is_valid(range.end) { + // Update `non_null_offsets` with the new end data. + if array.is_valid(range.end) { + // When non-null, append a new offset. + self.non_null_offsets.push_back(1); + } else { + // When null, increment offset count of the last entry + let last_idx = self.non_null_offsets.len() - 1; + self.non_null_offsets[last_idx] += 1; + } + } + + // Find the nonNULL row index that shifted by offset comparing to current row index + idx = if self.non_null_offsets.len() >= non_null_row_count { + let total_offset: usize = + self.non_null_offsets.iter().take(non_null_row_count).sum(); + Some(range.start + total_offset) + } else { + None + }; + // Prune `self.non_null_offsets` from the start. so that at next iteration + // start of the `self.non_null_offsets` matches with current row. + if !self.non_null_offsets.is_empty() { + self.non_null_offsets[0] -= 1; + if self.non_null_offsets[0] == 0 { + // When offset is 0. Remove it. + self.non_null_offsets.pop_front(); + } + } + } + + // Set the default value if + // - index is out of window bounds + // OR + // - ignore nulls mode and current value is null and is within window bounds + // .unwrap() is safe here as there is a none check in front + #[allow(clippy::unnecessary_unwrap)] + if !(idx.is_none() || (self.ignore_nulls && array.is_null(idx.unwrap()))) { + ScalarValue::try_from_array(array, idx.unwrap()) + } else { + Ok(self.default_value.clone()) + } + } + + fn evaluate_all( + &mut self, + values: &[ArrayRef], + _num_rows: usize, + ) -> Result { + // LEAD, LAG window functions take single column, values will have size 1 + let value = &values[0]; + if !self.ignore_nulls { + shift_with_default_value(value, self.shift_offset, &self.default_value) + } else { + evaluate_all_with_ignore_null( + value, + self.shift_offset, + &self.default_value, + self.is_lag(), + ) + } + } + + fn supports_bounded_execution(&self) -> bool { + true + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::*; + use datafusion_common::cast::as_int32_array; + use datafusion_physical_expr::expressions::{Column, Literal}; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + + fn test_i32_result( + expr: WindowShift, + partition_evaluator_args: PartitionEvaluatorArgs, + expected: Int32Array, + ) -> Result<()> { + let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8])); + let values = vec![arr]; + let num_rows = values.len(); + let result = expr + .partition_evaluator(partition_evaluator_args)? + .evaluate_all(&values, num_rows)?; + let result = as_int32_array(&result)?; + assert_eq!(expected, *result); + Ok(()) + } + + #[test] + fn lead_lag_get_range() -> Result<()> { + // LAG(2) + let lag_fn = WindowShiftEvaluator { + shift_offset: 2, + default_value: ScalarValue::Null, + ignore_nulls: false, + non_null_offsets: Default::default(), + }; + assert_eq!(lag_fn.get_range(6, 10)?, Range { start: 4, end: 7 }); + assert_eq!(lag_fn.get_range(0, 10)?, Range { start: 0, end: 1 }); + + // LAG(2 ignore nulls) + let lag_fn = WindowShiftEvaluator { + shift_offset: 2, + default_value: ScalarValue::Null, + ignore_nulls: true, + // models data received [, , , NULL, , NULL, , ...] + non_null_offsets: vec![2, 2].into(), // [1, 1, 2, 2] actually, just last 2 is used + }; + assert_eq!(lag_fn.get_range(6, 10)?, Range { start: 2, end: 7 }); + + // LEAD(2) + let lead_fn = WindowShiftEvaluator { + shift_offset: -2, + default_value: ScalarValue::Null, + ignore_nulls: false, + non_null_offsets: Default::default(), + }; + assert_eq!(lead_fn.get_range(6, 10)?, Range { start: 6, end: 8 }); + assert_eq!(lead_fn.get_range(9, 10)?, Range { start: 9, end: 10 }); + + // LEAD(2 ignore nulls) + let lead_fn = WindowShiftEvaluator { + shift_offset: -2, + default_value: ScalarValue::Null, + ignore_nulls: true, + // models data received [..., , NULL, , NULL, , ..] + non_null_offsets: vec![2, 2].into(), + }; + assert_eq!(lead_fn.get_range(4, 10)?, Range { start: 4, end: 9 }); + + Ok(()) + } + + #[test] + fn test_lead_window_shift() -> Result<()> { + let expr = Arc::new(Column::new("c3", 0)) as Arc; + + test_i32_result( + WindowShift::lead(), + PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false), + [ + Some(-2), + Some(3), + Some(-4), + Some(5), + Some(-6), + Some(7), + Some(8), + None, + ] + .iter() + .collect::(), + ) + } + + #[test] + fn test_lag_window_shift() -> Result<()> { + let expr = Arc::new(Column::new("c3", 0)) as Arc; + + test_i32_result( + WindowShift::lag(), + PartitionEvaluatorArgs::new(&[expr], &[DataType::Int32], false, false), + [ + None, + Some(1), + Some(-2), + Some(3), + Some(-4), + Some(5), + Some(-6), + Some(7), + ] + .iter() + .collect::(), + ) + } + + #[test] + fn test_lag_with_default() -> Result<()> { + let expr = Arc::new(Column::new("c3", 0)) as Arc; + let shift_offset = + Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; + let default_value = Arc::new(Literal::new(ScalarValue::Int32(Some(100)))) + as Arc; + + let input_exprs = &[expr, shift_offset, default_value]; + let input_types: &[DataType] = + &[DataType::Int32, DataType::Int32, DataType::Int32]; + + test_i32_result( + WindowShift::lag(), + PartitionEvaluatorArgs::new(input_exprs, input_types, false, false), + [ + Some(100), + Some(1), + Some(-2), + Some(3), + Some(-4), + Some(5), + Some(-6), + Some(7), + ] + .iter() + .collect::(), + ) + } +} diff --git a/datafusion/functions-window/src/lib.rs b/datafusion/functions-window/src/lib.rs new file mode 100644 index 000000000000..ff8542838df9 --- /dev/null +++ b/datafusion/functions-window/src/lib.rs @@ -0,0 +1,80 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Window Function packages for [DataFusion]. +//! +//! This crate contains a collection of various window function packages for DataFusion, +//! implemented using the extension API. +//! +//! [DataFusion]: https://crates.io/crates/datafusion +//! +use std::sync::Arc; + +use log::debug; + +use datafusion_expr::registry::FunctionRegistry; +use datafusion_expr::WindowUDF; + +#[macro_use] +pub mod macros; + +pub mod cume_dist; +pub mod lead_lag; +pub mod ntile; +pub mod rank; +pub mod row_number; +mod utils; + +/// Fluent-style API for creating `Expr`s +pub mod expr_fn { + pub use super::cume_dist::cume_dist; + pub use super::lead_lag::lag; + pub use super::lead_lag::lead; + pub use super::ntile::ntile; + pub use super::rank::{dense_rank, percent_rank, rank}; + pub use super::row_number::row_number; +} + +/// Returns all default window functions +pub fn all_default_window_functions() -> Vec> { + vec![ + cume_dist::cume_dist_udwf(), + row_number::row_number_udwf(), + lead_lag::lead_udwf(), + lead_lag::lag_udwf(), + rank::rank_udwf(), + rank::dense_rank_udwf(), + rank::percent_rank_udwf(), + ntile::ntile_udwf(), + ] +} +/// Registers all enabled packages with a [`FunctionRegistry`] +pub fn register_all( + registry: &mut dyn FunctionRegistry, +) -> datafusion_common::Result<()> { + let functions: Vec> = all_default_window_functions(); + + functions.into_iter().try_for_each(|fun| { + let existing_udwf = registry.register_udwf(fun)?; + if let Some(existing_udwf) = existing_udwf { + debug!("Overwrite existing UDWF: {}", existing_udwf.name()); + } + Ok(()) as datafusion_common::Result<()> + })?; + + Ok(()) +} diff --git a/datafusion/functions-window/src/macros.rs b/datafusion/functions-window/src/macros.rs new file mode 100644 index 000000000000..2905ccf4c204 --- /dev/null +++ b/datafusion/functions-window/src/macros.rs @@ -0,0 +1,689 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Convenience macros for defining a user-defined window function +//! and associated expression API (fluent style). +//! +//! See [`define_udwf_and_expr!`] for usage examples. +//! +//! [`define_udwf_and_expr!`]: crate::define_udwf_and_expr! + +/// Lazily initializes a user-defined window function exactly once +/// when called concurrently. Repeated calls return a reference to the +/// same instance. +/// +/// # Parameters +/// +/// * `$UDWF`: The struct which defines the [`Signature`](datafusion_expr::Signature) +/// of the user-defined window function. +/// * `$OUT_FN_NAME`: The basename to generate a unique function name like +/// `$OUT_FN_NAME_udwf`. +/// * `$DOC`: Doc comments for UDWF. +/// * (optional) `$CTOR`: Pass a custom constructor. When omitted it +/// automatically resolves to `$UDWF::default()`. +/// +/// # Example +/// +/// ``` +/// # use std::any::Any; +/// # use datafusion_common::arrow::datatypes::{DataType, Field}; +/// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; +/// # +/// # use datafusion_functions_window_common::field::WindowUDFFieldArgs; +/// # use datafusion_functions_window::get_or_init_udwf; +/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +/// # +/// /// Defines the `simple_udwf()` user-defined window function. +/// get_or_init_udwf!( +/// SimpleUDWF, +/// simple, +/// "Simple user-defined window function doc comment." +/// ); +/// # +/// # assert_eq!(simple_udwf().name(), "simple_user_defined_window_function"); +/// # +/// # #[derive(Debug)] +/// # struct SimpleUDWF { +/// # signature: Signature, +/// # } +/// # +/// # impl Default for SimpleUDWF { +/// # fn default() -> Self { +/// # Self { +/// # signature: Signature::any(0, Volatility::Immutable), +/// # } +/// # } +/// # } +/// # +/// # impl WindowUDFImpl for SimpleUDWF { +/// # fn as_any(&self) -> &dyn Any { +/// # self +/// # } +/// # fn name(&self) -> &str { +/// # "simple_user_defined_window_function" +/// # } +/// # fn signature(&self) -> &Signature { +/// # &self.signature +/// # } +/// # fn partition_evaluator( +/// # &self, +/// # _partition_evaluator_args: PartitionEvaluatorArgs, +/// # ) -> datafusion_common::Result> { +/// # unimplemented!() +/// # } +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new(field_args.name(), DataType::Int64, false)) +/// # } +/// # } +/// # +/// ``` +#[macro_export] +macro_rules! get_or_init_udwf { + ($UDWF:ident, $OUT_FN_NAME:ident, $DOC:expr) => { + get_or_init_udwf!($UDWF, $OUT_FN_NAME, $DOC, $UDWF::default); + }; + + ($UDWF:ident, $OUT_FN_NAME:ident, $DOC:expr, $CTOR:path) => { + paste::paste! { + #[doc = concat!(" Singleton instance of [`", stringify!($OUT_FN_NAME), "`], ensures the user-defined")] + #[doc = concat!(" window function is only created once.")] + #[allow(non_upper_case_globals)] + static []: std::sync::OnceLock> = + std::sync::OnceLock::new(); + + #[doc = concat!(" Returns a [`WindowUDF`](datafusion_expr::WindowUDF) for [`", stringify!($OUT_FN_NAME), "`].")] + #[doc = ""] + #[doc = concat!(" ", $DOC)] + pub fn [<$OUT_FN_NAME _udwf>]() -> std::sync::Arc { + [] + .get_or_init(|| { + std::sync::Arc::new(datafusion_expr::WindowUDF::from($CTOR())) + }) + .clone() + } + } + }; +} + +/// Create a [`WindowFunction`] expression that exposes a fluent API +/// which you can use to build more complex expressions. +/// +/// [`WindowFunction`]: datafusion_expr::Expr::WindowFunction +/// +/// # Parameters +/// +/// * `$UDWF`: The struct which defines the [`Signature`] of the +/// user-defined window function. +/// * `$OUT_FN_NAME`: The basename to generate a unique function name like +/// `$OUT_FN_NAME_udwf`. +/// * `$DOC`: Doc comments for UDWF. +/// * (optional) `[$($PARAM:ident),+]`: An array of 1 or more parameters +/// for the generated function. The type of parameters is [`Expr`]. +/// When omitted this creates a function with zero parameters. +/// +/// [`Signature`]: datafusion_expr::Signature +/// [`Expr`]: datafusion_expr::Expr +/// +/// # Example +/// +/// 1. With Zero Parameters +/// ``` +/// # use std::any::Any; +/// # use datafusion_common::arrow::datatypes::{DataType, Field}; +/// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; +/// # use datafusion_functions_window::{create_udwf_expr, get_or_init_udwf}; +/// # use datafusion_functions_window_common::field::WindowUDFFieldArgs; +/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +/// +/// # get_or_init_udwf!( +/// # RowNumber, +/// # row_number, +/// # "Returns a unique row number for each row in window partition beginning at 1." +/// # ); +/// /// Creates `row_number()` API which has zero parameters: +/// /// +/// /// ``` +/// /// /// Returns a unique row number for each row in window partition +/// /// /// beginning at 1. +/// /// pub fn row_number() -> datafusion_expr::Expr { +/// /// row_number_udwf().call(vec![]) +/// /// } +/// /// ``` +/// create_udwf_expr!( +/// RowNumber, +/// row_number, +/// "Returns a unique row number for each row in window partition beginning at 1." +/// ); +/// # +/// # assert_eq!( +/// # row_number().name_for_alias().unwrap(), +/// # "row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" +/// # ); +/// # +/// # #[derive(Debug)] +/// # struct RowNumber { +/// # signature: Signature, +/// # } +/// # impl Default for RowNumber { +/// # fn default() -> Self { +/// # Self { +/// # signature: Signature::any(0, Volatility::Immutable), +/// # } +/// # } +/// # } +/// # impl WindowUDFImpl for RowNumber { +/// # fn as_any(&self) -> &dyn Any { +/// # self +/// # } +/// # fn name(&self) -> &str { +/// # "row_number" +/// # } +/// # fn signature(&self) -> &Signature { +/// # &self.signature +/// # } +/// # fn partition_evaluator( +/// # &self, +/// # _partition_evaluator_args: PartitionEvaluatorArgs, +/// # ) -> datafusion_common::Result> { +/// # unimplemented!() +/// # } +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new(field_args.name(), DataType::UInt64, false)) +/// # } +/// # } +/// ``` +/// +/// 2. With Multiple Parameters +/// ``` +/// # use std::any::Any; +/// # +/// # use datafusion_expr::{ +/// # PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDFImpl, +/// # }; +/// # +/// # use datafusion_functions_window::{create_udwf_expr, get_or_init_udwf}; +/// # use datafusion_functions_window_common::field::WindowUDFFieldArgs; +/// # +/// # use datafusion_common::arrow::datatypes::Field; +/// # use datafusion_common::ScalarValue; +/// # use datafusion_expr::{col, lit}; +/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +/// # +/// # get_or_init_udwf!(Lead, lead, "user-defined window function"); +/// # +/// /// Creates `lead(expr, offset, default)` with 3 parameters: +/// /// +/// /// ``` +/// /// /// Returns a value evaluated at the row that is offset rows +/// /// /// after the current row within the partition. +/// /// pub fn lead( +/// /// expr: datafusion_expr::Expr, +/// /// offset: datafusion_expr::Expr, +/// /// default: datafusion_expr::Expr, +/// /// ) -> datafusion_expr::Expr { +/// /// lead_udwf().call(vec![expr, offset, default]) +/// /// } +/// /// ``` +/// create_udwf_expr!( +/// Lead, +/// lead, +/// [expr, offset, default], +/// "Returns a value evaluated at the row that is offset rows after the current row within the partition." +/// ); +/// # +/// # assert_eq!( +/// # lead(col("a"), lit(1i64), lit(ScalarValue::Null)) +/// # .name_for_alias() +/// # .unwrap(), +/// # "lead(a,Int64(1),NULL) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" +/// # ); +/// # +/// # #[derive(Debug)] +/// # struct Lead { +/// # signature: Signature, +/// # } +/// # +/// # impl Default for Lead { +/// # fn default() -> Self { +/// # Self { +/// # signature: Signature::one_of( +/// # vec![ +/// # TypeSignature::Any(1), +/// # TypeSignature::Any(2), +/// # TypeSignature::Any(3), +/// # ], +/// # Volatility::Immutable, +/// # ), +/// # } +/// # } +/// # } +/// # +/// # impl WindowUDFImpl for Lead { +/// # fn as_any(&self) -> &dyn Any { +/// # self +/// # } +/// # fn name(&self) -> &str { +/// # "lead" +/// # } +/// # fn signature(&self) -> &Signature { +/// # &self.signature +/// # } +/// # fn partition_evaluator( +/// # &self, +/// # partition_evaluator_args: PartitionEvaluatorArgs, +/// # ) -> datafusion_common::Result> { +/// # unimplemented!() +/// # } +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new( +/// # field_args.name(), +/// # field_args.get_input_type(0).unwrap(), +/// # false, +/// # )) +/// # } +/// # } +/// ``` +#[macro_export] +macro_rules! create_udwf_expr { + // zero arguments + ($UDWF:ident, $OUT_FN_NAME:ident, $DOC:expr) => { + paste::paste! { + #[doc = " Create a [`WindowFunction`](datafusion_expr::Expr::WindowFunction) expression for"] + #[doc = concat!(" `", stringify!($UDWF), "` user-defined window function.")] + #[doc = ""] + #[doc = concat!(" ", $DOC)] + pub fn $OUT_FN_NAME() -> datafusion_expr::Expr { + [<$OUT_FN_NAME _udwf>]().call(vec![]) + } + } + }; + + // 1 or more arguments + ($UDWF:ident, $OUT_FN_NAME:ident, [$($PARAM:ident),+], $DOC:expr) => { + paste::paste! { + #[doc = " Create a [`WindowFunction`](datafusion_expr::Expr::WindowFunction) expression for"] + #[doc = concat!(" `", stringify!($UDWF), "` user-defined window function.")] + #[doc = ""] + #[doc = concat!(" ", $DOC)] + pub fn $OUT_FN_NAME( + $($PARAM: datafusion_expr::Expr),+ + ) -> datafusion_expr::Expr { + [<$OUT_FN_NAME _udwf>]() + .call(vec![$($PARAM),+]) + } + } + }; +} + +/// Defines a user-defined window function. +/// +/// Combines [`get_or_init_udwf!`] and [`create_udwf_expr!`] into a +/// single macro for convenience. +/// +/// # Arguments +/// +/// * `$UDWF`: The struct which defines the [`Signature`] of the +/// user-defined window function. +/// * `$OUT_FN_NAME`: The basename to generate a unique function name like +/// `$OUT_FN_NAME_udwf`. +/// * (optional) `[$($PARAM:ident),+]`: An array of 1 or more parameters +/// for the generated function. The type of parameters is [`Expr`]. +/// When omitted this creates a function with zero parameters. +/// * `$DOC`: Doc comments for UDWF. +/// * (optional) `$CTOR`: Pass a custom constructor. When omitted it +/// automatically resolves to `$UDWF::default()`. +/// +/// [`Signature`]: datafusion_expr::Signature +/// [`Expr`]: datafusion_expr::Expr +/// +/// # Usage +/// +/// ## Expression API With Zero parameters +/// 1. Uses default constructor for UDWF. +/// +/// ``` +/// # use std::any::Any; +/// # use datafusion_common::arrow::datatypes::{DataType, Field}; +/// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; +/// # +/// # use datafusion_functions_window_common::field::WindowUDFFieldArgs; +/// # use datafusion_functions_window::{define_udwf_and_expr, get_or_init_udwf, create_udwf_expr}; +/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +/// # +/// /// 1. Defines the `simple_udwf()` user-defined window function. +/// /// +/// /// 2. Defines the expression API: +/// /// ``` +/// /// pub fn simple() -> datafusion_expr::Expr { +/// /// simple_udwf().call(vec![]) +/// /// } +/// /// ``` +/// define_udwf_and_expr!( +/// SimpleUDWF, +/// simple, +/// "a simple user-defined window function" +/// ); +/// # +/// # assert_eq!(simple_udwf().name(), "simple_user_defined_window_function"); +/// # +/// # #[derive(Debug)] +/// # struct SimpleUDWF { +/// # signature: Signature, +/// # } +/// # +/// # impl Default for SimpleUDWF { +/// # fn default() -> Self { +/// # Self { +/// # signature: Signature::any(0, Volatility::Immutable), +/// # } +/// # } +/// # } +/// # +/// # impl WindowUDFImpl for SimpleUDWF { +/// # fn as_any(&self) -> &dyn Any { +/// # self +/// # } +/// # fn name(&self) -> &str { +/// # "simple_user_defined_window_function" +/// # } +/// # fn signature(&self) -> &Signature { +/// # &self.signature +/// # } +/// # fn partition_evaluator( +/// # &self, +/// # partition_evaluator_args: PartitionEvaluatorArgs, +/// # ) -> datafusion_common::Result> { +/// # unimplemented!() +/// # } +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new(field_args.name(), DataType::Int64, false)) +/// # } +/// # } +/// # +/// ``` +/// +/// 2. Uses a custom constructor for UDWF. +/// +/// ``` +/// # use std::any::Any; +/// # use datafusion_common::arrow::datatypes::{DataType, Field}; +/// # use datafusion_expr::{PartitionEvaluator, Signature, Volatility, WindowUDFImpl}; +/// # use datafusion_functions_window::{create_udwf_expr, define_udwf_and_expr, get_or_init_udwf}; +/// # use datafusion_functions_window_common::field::WindowUDFFieldArgs; +/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +/// # +/// /// 1. Defines the `row_number_udwf()` user-defined window function. +/// /// +/// /// 2. Defines the expression API: +/// /// ``` +/// /// pub fn row_number() -> datafusion_expr::Expr { +/// /// row_number_udwf().call(vec![]) +/// /// } +/// /// ``` +/// define_udwf_and_expr!( +/// RowNumber, +/// row_number, +/// "Returns a unique row number for each row in window partition beginning at 1.", +/// RowNumber::new // <-- custom constructor +/// ); +/// # +/// # assert_eq!( +/// # row_number().name_for_alias().unwrap(), +/// # "row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" +/// # ); +/// # +/// # #[derive(Debug)] +/// # struct RowNumber { +/// # signature: Signature, +/// # } +/// # impl RowNumber { +/// # fn new() -> Self { +/// # Self { +/// # signature: Signature::any(0, Volatility::Immutable), +/// # } +/// # } +/// # } +/// # impl WindowUDFImpl for RowNumber { +/// # fn as_any(&self) -> &dyn Any { +/// # self +/// # } +/// # fn name(&self) -> &str { +/// # "row_number" +/// # } +/// # fn signature(&self) -> &Signature { +/// # &self.signature +/// # } +/// # fn partition_evaluator( +/// # &self, +/// # _partition_evaluator_args: PartitionEvaluatorArgs, +/// # ) -> datafusion_common::Result> { +/// # unimplemented!() +/// # } +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new(field_args.name(), DataType::UInt64, false)) +/// # } +/// # } +/// ``` +/// +/// ## Expression API With Multiple Parameters +/// 3. Uses default constructor for UDWF +/// +/// ``` +/// # use std::any::Any; +/// # +/// # use datafusion_expr::{ +/// # PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDFImpl, +/// # }; +/// # +/// # use datafusion_functions_window::{create_udwf_expr, define_udwf_and_expr, get_or_init_udwf}; +/// # use datafusion_functions_window_common::field::WindowUDFFieldArgs; +/// # +/// # use datafusion_common::arrow::datatypes::Field; +/// # use datafusion_common::ScalarValue; +/// # use datafusion_expr::{col, lit}; +/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +/// # +/// /// 1. Defines the `lead_udwf()` user-defined window function. +/// /// +/// /// 2. Defines the expression API: +/// /// ``` +/// /// pub fn lead( +/// /// expr: datafusion_expr::Expr, +/// /// offset: datafusion_expr::Expr, +/// /// default: datafusion_expr::Expr, +/// /// ) -> datafusion_expr::Expr { +/// /// lead_udwf().call(vec![expr, offset, default]) +/// /// } +/// /// ``` +/// define_udwf_and_expr!( +/// Lead, +/// lead, +/// [expr, offset, default], // <- 3 parameters +/// "user-defined window function" +/// ); +/// # +/// # assert_eq!( +/// # lead(col("a"), lit(1i64), lit(ScalarValue::Null)) +/// # .name_for_alias() +/// # .unwrap(), +/// # "lead(a,Int64(1),NULL) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" +/// # ); +/// # +/// # #[derive(Debug)] +/// # struct Lead { +/// # signature: Signature, +/// # } +/// # +/// # impl Default for Lead { +/// # fn default() -> Self { +/// # Self { +/// # signature: Signature::one_of( +/// # vec![ +/// # TypeSignature::Any(1), +/// # TypeSignature::Any(2), +/// # TypeSignature::Any(3), +/// # ], +/// # Volatility::Immutable, +/// # ), +/// # } +/// # } +/// # } +/// # +/// # impl WindowUDFImpl for Lead { +/// # fn as_any(&self) -> &dyn Any { +/// # self +/// # } +/// # fn name(&self) -> &str { +/// # "lead" +/// # } +/// # fn signature(&self) -> &Signature { +/// # &self.signature +/// # } +/// # fn partition_evaluator( +/// # &self, +/// # _partition_evaluator_args: PartitionEvaluatorArgs, +/// # ) -> datafusion_common::Result> { +/// # unimplemented!() +/// # } +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new( +/// # field_args.name(), +/// # field_args.get_input_type(0).unwrap(), +/// # false, +/// # )) +/// # } +/// # } +/// ``` +/// 4. Uses custom constructor for UDWF +/// +/// ``` +/// # use std::any::Any; +/// # +/// # use datafusion_expr::{ +/// # PartitionEvaluator, Signature, TypeSignature, Volatility, WindowUDFImpl, +/// # }; +/// # +/// # use datafusion_functions_window::{create_udwf_expr, define_udwf_and_expr, get_or_init_udwf}; +/// # use datafusion_functions_window_common::field::WindowUDFFieldArgs; +/// # +/// # use datafusion_common::arrow::datatypes::Field; +/// # use datafusion_common::ScalarValue; +/// # use datafusion_expr::{col, lit}; +/// # use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +/// # +/// /// 1. Defines the `lead_udwf()` user-defined window function. +/// /// +/// /// 2. Defines the expression API: +/// /// ``` +/// /// pub fn lead( +/// /// expr: datafusion_expr::Expr, +/// /// offset: datafusion_expr::Expr, +/// /// default: datafusion_expr::Expr, +/// /// ) -> datafusion_expr::Expr { +/// /// lead_udwf().call(vec![expr, offset, default]) +/// /// } +/// /// ``` +/// define_udwf_and_expr!( +/// Lead, +/// lead, +/// [expr, offset, default], // <- 3 parameters +/// "user-defined window function", +/// Lead::new // <- Custom constructor +/// ); +/// # +/// # assert_eq!( +/// # lead(col("a"), lit(1i64), lit(ScalarValue::Null)) +/// # .name_for_alias() +/// # .unwrap(), +/// # "lead(a,Int64(1),NULL) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" +/// # ); +/// # +/// # #[derive(Debug)] +/// # struct Lead { +/// # signature: Signature, +/// # } +/// # +/// # impl Lead { +/// # fn new() -> Self { +/// # Self { +/// # signature: Signature::one_of( +/// # vec![ +/// # TypeSignature::Any(1), +/// # TypeSignature::Any(2), +/// # TypeSignature::Any(3), +/// # ], +/// # Volatility::Immutable, +/// # ), +/// # } +/// # } +/// # } +/// # +/// # impl WindowUDFImpl for Lead { +/// # fn as_any(&self) -> &dyn Any { +/// # self +/// # } +/// # fn name(&self) -> &str { +/// # "lead" +/// # } +/// # fn signature(&self) -> &Signature { +/// # &self.signature +/// # } +/// # fn partition_evaluator( +/// # &self, +/// # _partition_evaluator_args: PartitionEvaluatorArgs, +/// # ) -> datafusion_common::Result> { +/// # unimplemented!() +/// # } +/// # fn field(&self, field_args: WindowUDFFieldArgs) -> datafusion_common::Result { +/// # Ok(Field::new( +/// # field_args.name(), +/// # field_args.get_input_type(0).unwrap(), +/// # false, +/// # )) +/// # } +/// # } +/// ``` +#[macro_export] +macro_rules! define_udwf_and_expr { + // Defines UDWF with default constructor + // Defines expression API with zero parameters + ($UDWF:ident, $OUT_FN_NAME:ident, $DOC:expr) => { + get_or_init_udwf!($UDWF, $OUT_FN_NAME, $DOC); + create_udwf_expr!($UDWF, $OUT_FN_NAME, $DOC); + }; + + // Defines UDWF by passing a custom constructor + // Defines expression API with zero parameters + ($UDWF:ident, $OUT_FN_NAME:ident, $DOC:expr, $CTOR:path) => { + get_or_init_udwf!($UDWF, $OUT_FN_NAME, $DOC, $CTOR); + create_udwf_expr!($UDWF, $OUT_FN_NAME, $DOC); + }; + + // Defines UDWF with default constructor + // Defines expression API with multiple parameters + ($UDWF:ident, $OUT_FN_NAME:ident, [$($PARAM:ident),+], $DOC:expr) => { + get_or_init_udwf!($UDWF, $OUT_FN_NAME, $DOC); + create_udwf_expr!($UDWF, $OUT_FN_NAME, [$($PARAM),+], $DOC); + }; + + // Defines UDWF by passing a custom constructor + // Defines expression API with multiple parameters + ($UDWF:ident, $OUT_FN_NAME:ident, [$($PARAM:ident),+], $DOC:expr, $CTOR:path) => { + get_or_init_udwf!($UDWF, $OUT_FN_NAME, $DOC, $CTOR); + create_udwf_expr!($UDWF, $OUT_FN_NAME, [$($PARAM),+], $DOC); + }; +} diff --git a/datafusion/functions-window/src/ntile.rs b/datafusion/functions-window/src/ntile.rs new file mode 100644 index 000000000000..b0a7241f24cd --- /dev/null +++ b/datafusion/functions-window/src/ntile.rs @@ -0,0 +1,168 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! `ntile` window function implementation + +use std::any::Any; +use std::fmt::Debug; +use std::sync::{Arc, OnceLock}; + +use crate::utils::{ + get_scalar_value_from_args, get_signed_integer, get_unsigned_integer, +}; +use datafusion_common::arrow::array::{ArrayRef, UInt64Array}; +use datafusion_common::arrow::datatypes::{DataType, Field}; +use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_expr::window_doc_sections::DOC_SECTION_RANKING; +use datafusion_expr::{ + Documentation, Expr, PartitionEvaluator, Signature, Volatility, WindowUDFImpl, +}; +use datafusion_functions_window_common::field; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use field::WindowUDFFieldArgs; + +get_or_init_udwf!( + Ntile, + ntile, + "integer ranging from 1 to the argument value, dividing the partition as equally as possible" +); + +pub fn ntile(arg: Expr) -> Expr { + ntile_udwf().call(vec![arg]) +} + +#[derive(Debug)] +pub struct Ntile { + signature: Signature, +} + +impl Ntile { + /// Create a new `ntile` function + pub fn new() -> Self { + Self { + signature: Signature::uniform( + 1, + vec![ + DataType::UInt64, + DataType::UInt32, + DataType::UInt16, + DataType::UInt8, + DataType::Int64, + DataType::Int32, + DataType::Int16, + DataType::Int8, + ], + Volatility::Immutable, + ), + } + } +} + +impl Default for Ntile { + fn default() -> Self { + Self::new() + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_ntile_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_RANKING) + .with_description( + "Integer ranging from 1 to the argument value, dividing the partition as equally as possible", + ) + .with_syntax_example("ntile(expression)") + .with_argument("expression","An integer describing the number groups the partition should be split into") + .build() + .unwrap() + }) +} + +impl WindowUDFImpl for Ntile { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "ntile" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn partition_evaluator( + &self, + partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + let scalar_n = + get_scalar_value_from_args(partition_evaluator_args.input_exprs(), 0)? + .ok_or_else(|| { + DataFusionError::Execution( + "NTILE requires a positive integer".to_string(), + ) + })?; + + if scalar_n.is_null() { + return exec_err!("NTILE requires a positive integer, but finds NULL"); + } + + if scalar_n.is_unsigned() { + let n = get_unsigned_integer(scalar_n)?; + Ok(Box::new(NtileEvaluator { n })) + } else { + let n: i64 = get_signed_integer(scalar_n)?; + if n <= 0 { + return exec_err!("NTILE requires a positive integer"); + } + Ok(Box::new(NtileEvaluator { n: n as u64 })) + } + } + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + let nullable = false; + + Ok(Field::new(field_args.name(), DataType::UInt64, nullable)) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_ntile_doc()) + } +} + +#[derive(Debug)] +struct NtileEvaluator { + n: u64, +} + +impl PartitionEvaluator for NtileEvaluator { + fn evaluate_all( + &mut self, + _values: &[ArrayRef], + num_rows: usize, + ) -> Result { + let num_rows = num_rows as u64; + let mut vec: Vec = Vec::new(); + let n = u64::min(self.n, num_rows); + for i in 0..num_rows { + let res = i * n / num_rows; + vec.push(res + 1) + } + Ok(Arc::new(UInt64Array::from(vec))) + } +} diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/functions-window/src/rank.rs similarity index 55% rename from datafusion/physical-expr/src/window/rank.rs rename to datafusion/functions-window/src/rank.rs index fa3d4e487f14..06c3f49055a5 100644 --- a/datafusion/physical-expr/src/window/rank.rs +++ b/datafusion/functions-window/src/rank.rs @@ -15,40 +15,83 @@ // specific language governing permissions and limitations // under the License. -//! Defines physical expression for `rank`, `dense_rank`, and `percent_rank` that can evaluated -//! at runtime during query execution - -use crate::expressions::Column; -use crate::window::window_expr::RankState; -use crate::window::BuiltInWindowFunctionExpr; -use crate::{PhysicalExpr, PhysicalSortExpr}; - -use arrow::array::ArrayRef; -use arrow::array::{Float64Array, UInt64Array}; -use arrow::datatypes::{DataType, Field}; -use arrow_schema::{SchemaRef, SortOptions}; -use datafusion_common::utils::get_row_at_idx; -use datafusion_common::{exec_err, Result, ScalarValue}; -use datafusion_expr::PartitionEvaluator; +//! Implementation of `rank`, `dense_rank`, and `percent_rank` window functions, +//! which can be evaluated at runtime during query execution. use std::any::Any; +use std::fmt::Debug; use std::iter; use std::ops::Range; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; + +use crate::define_udwf_and_expr; +use datafusion_common::arrow::array::ArrayRef; +use datafusion_common::arrow::array::{Float64Array, UInt64Array}; +use datafusion_common::arrow::compute::SortOptions; +use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::Field; +use datafusion_common::utils::get_row_at_idx; +use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_expr::window_doc_sections::DOC_SECTION_RANKING; +use datafusion_expr::{ + Documentation, PartitionEvaluator, Signature, Volatility, WindowUDFImpl, +}; +use datafusion_functions_window_common::field; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use field::WindowUDFFieldArgs; + +define_udwf_and_expr!( + Rank, + rank, + "Returns rank of the current row with gaps. Same as `row_number` of its first peer", + Rank::basic +); + +define_udwf_and_expr!( + DenseRank, + dense_rank, + "Returns rank of the current row without gaps. This function counts peer groups", + Rank::dense_rank +); + +define_udwf_and_expr!( + PercentRank, + percent_rank, + "Returns the relative rank of the current row: (rank - 1) / (total rows - 1)", + Rank::percent_rank +); /// Rank calculates the rank in the window function with order by #[derive(Debug)] pub struct Rank { name: String, + signature: Signature, rank_type: RankType, - /// Output data type - data_type: DataType, } impl Rank { - /// Get rank_type of the rank in window function with order by - pub fn get_type(&self) -> RankType { - self.rank_type + /// Create a new `rank` function with the specified name and rank type + pub fn new(name: String, rank_type: RankType) -> Self { + Self { + name, + signature: Signature::any(0, Volatility::Immutable), + rank_type, + } + } + + /// Create a `rank` window function + pub fn basic() -> Self { + Rank::new("rank".to_string(), RankType::Basic) + } + + /// Create a `dense_rank` window function + pub fn dense_rank() -> Self { + Rank::new("dense_rank".to_string(), RankType::Dense) + } + + /// Create a `percent_rank` window function + pub fn percent_rank() -> Self { + Rank::new("percent_rank".to_string(), RankType::Percent) } } @@ -59,74 +102,121 @@ pub enum RankType { Percent, } -/// Create a rank window function -pub fn rank(name: String, data_type: &DataType) -> Rank { - Rank { - name, - rank_type: RankType::Basic, - data_type: data_type.clone(), - } +static RANK_DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_rank_doc() -> &'static Documentation { + RANK_DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_RANKING) + .with_description( + "Returns the rank of the current row within its partition, allowing \ + gaps between ranks. This function provides a ranking similar to `row_number`, but \ + skips ranks for identical values.", + ) + .with_syntax_example("rank()") + .build() + .unwrap() + }) } -/// Create a dense rank window function -pub fn dense_rank(name: String, data_type: &DataType) -> Rank { - Rank { - name, - rank_type: RankType::Dense, - data_type: data_type.clone(), - } +static DENSE_RANK_DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_dense_rank_doc() -> &'static Documentation { + DENSE_RANK_DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_RANKING) + .with_description( + "Returns the rank of the current row without gaps. This function ranks \ + rows in a dense manner, meaning consecutive ranks are assigned even for identical \ + values.", + ) + .with_syntax_example("dense_rank()") + .build() + .unwrap() + }) } -/// Create a percent rank window function -pub fn percent_rank(name: String, data_type: &DataType) -> Rank { - Rank { - name, - rank_type: RankType::Percent, - data_type: data_type.clone(), - } +static PERCENT_RANK_DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_percent_rank_doc() -> &'static Documentation { + PERCENT_RANK_DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_RANKING) + .with_description( + "Returns the percentage rank of the current row within its partition. \ + The value ranges from 0 to 1 and is computed as `(rank - 1) / (total_rows - 1)`.", + ) + .with_syntax_example("percent_rank()") + .build() + .unwrap() + }) } -impl BuiltInWindowFunctionExpr for Rank { - /// Return a reference to Any that can be used for downcasting +impl WindowUDFImpl for Rank { fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result { - let nullable = false; - Ok(Field::new(self.name(), self.data_type.clone(), nullable)) - } - - fn expressions(&self) -> Vec> { - vec![] - } - fn name(&self) -> &str { &self.name } - fn create_evaluator(&self) -> Result> { + fn signature(&self) -> &Signature { + &self.signature + } + + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { Ok(Box::new(RankEvaluator { state: RankState::default(), rank_type: self.rank_type, })) } - fn get_result_ordering(&self, schema: &SchemaRef) -> Option { - // The built-in RANK window function (in all modes) introduces a new ordering: - schema.column_with_name(self.name()).map(|(idx, field)| { - let expr = Arc::new(Column::new(field.name(), idx)); - let options = SortOptions { - descending: false, - nulls_first: false, - }; // ASC, NULLS LAST - PhysicalSortExpr { expr, options } + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + let return_type = match self.rank_type { + RankType::Basic | RankType::Dense => DataType::UInt64, + RankType::Percent => DataType::Float64, + }; + + let nullable = false; + Ok(Field::new(field_args.name(), return_type, nullable)) + } + + fn sort_options(&self) -> Option { + Some(SortOptions { + descending: false, + nulls_first: false, }) } + + fn documentation(&self) -> Option<&Documentation> { + match self.rank_type { + RankType::Basic => Some(get_rank_doc()), + RankType::Dense => Some(get_dense_rank_doc()), + RankType::Percent => Some(get_percent_rank_doc()), + } + } +} + +/// State for the RANK(rank) built-in window function. +#[derive(Debug, Clone, Default)] +pub struct RankState { + /// The last values for rank as these values change, we increase n_rank + pub last_rank_data: Option>, + /// The index where last_rank_boundary is started + pub last_rank_boundary: usize, + /// Keep the number of entries in current rank + pub current_group_count: usize, + /// Rank number kept from the start + pub n_rank: usize, } +/// State for the `rank` built-in window function. #[derive(Debug)] -pub(crate) struct RankEvaluator { +struct RankEvaluator { state: RankState, rank_type: RankType, } @@ -136,7 +226,6 @@ impl PartitionEvaluator for RankEvaluator { matches!(self.rank_type, RankType::Basic | RankType::Dense) } - /// Evaluates the window function inside the given range. fn evaluate( &mut self, values: &[ArrayRef], @@ -163,6 +252,7 @@ impl PartitionEvaluator for RankEvaluator { // data is still in the same rank self.state.current_group_count += 1; } + match self.rank_type { RankType::Basic => Ok(ScalarValue::UInt64(Some( self.state.last_rank_boundary as u64 + 1, @@ -179,8 +269,19 @@ impl PartitionEvaluator for RankEvaluator { num_rows: usize, ranks_in_partition: &[Range], ) -> Result { - // see https://www.postgresql.org/docs/current/functions-window.html let result: ArrayRef = match self.rank_type { + RankType::Basic => Arc::new(UInt64Array::from_iter_values( + ranks_in_partition + .iter() + .scan(1_u64, |acc, range| { + let len = range.end - range.start; + let result = iter::repeat(*acc).take(len); + *acc += len as u64; + Some(result) + }) + .flatten(), + )), + RankType::Dense => Arc::new(UInt64Array::from_iter_values( ranks_in_partition .iter() @@ -190,9 +291,10 @@ impl PartitionEvaluator for RankEvaluator { iter::repeat(rank).take(len) }), )), + RankType::Percent => { - // Returns the relative rank of the current row, that is (rank - 1) / (total partition rows - 1). The value thus ranges from 0 to 1 inclusive. let denominator = num_rows as f64; + Arc::new(Float64Array::from_iter_values( ranks_in_partition .iter() @@ -206,18 +308,8 @@ impl PartitionEvaluator for RankEvaluator { .flatten(), )) } - RankType::Basic => Arc::new(UInt64Array::from_iter_values( - ranks_in_partition - .iter() - .scan(1_u64, |acc, range| { - let len = range.end - range.start; - let result = iter::repeat(*acc).take(len); - *acc += len as u64; - Some(result) - }) - .flatten(), - )), }; + Ok(result) } @@ -244,53 +336,57 @@ mod tests { test_i32_result(expr, vec![0..8], expected) } - fn test_f64_result( + fn test_i32_result( expr: &Rank, - num_rows: usize, ranks: Vec>, - expected: Vec, + expected: Vec, ) -> Result<()> { + let args = PartitionEvaluatorArgs::default(); let result = expr - .create_evaluator()? - .evaluate_all_with_rank(num_rows, &ranks)?; - let result = as_float64_array(&result)?; + .partition_evaluator(args)? + .evaluate_all_with_rank(8, &ranks)?; + let result = as_uint64_array(&result)?; let result = result.values(); assert_eq!(expected, *result); Ok(()) } - fn test_i32_result( + fn test_f64_result( expr: &Rank, + num_rows: usize, ranks: Vec>, - expected: Vec, + expected: Vec, ) -> Result<()> { - let result = expr.create_evaluator()?.evaluate_all_with_rank(8, &ranks)?; - let result = as_uint64_array(&result)?; + let args = PartitionEvaluatorArgs::default(); + let result = expr + .partition_evaluator(args)? + .evaluate_all_with_rank(num_rows, &ranks)?; + let result = as_float64_array(&result)?; let result = result.values(); assert_eq!(expected, *result); Ok(()) } #[test] - fn test_dense_rank() -> Result<()> { - let r = dense_rank("arr".into(), &DataType::UInt64); + fn test_rank() -> Result<()> { + let r = Rank::basic(); test_without_rank(&r, vec![1; 8])?; - test_with_rank(&r, vec![1, 1, 2, 3, 3, 3, 4, 5])?; + test_with_rank(&r, vec![1, 1, 3, 4, 4, 4, 7, 8])?; Ok(()) } #[test] - fn test_rank() -> Result<()> { - let r = rank("arr".into(), &DataType::UInt64); + fn test_dense_rank() -> Result<()> { + let r = Rank::dense_rank(); test_without_rank(&r, vec![1; 8])?; - test_with_rank(&r, vec![1, 1, 3, 4, 4, 4, 7, 8])?; + test_with_rank(&r, vec![1, 1, 2, 3, 3, 3, 4, 5])?; Ok(()) } #[test] #[allow(clippy::single_range_in_vec_init)] fn test_percent_rank() -> Result<()> { - let r = percent_rank("arr".into(), &DataType::Float64); + let r = Rank::percent_rank(); // empty case let expected = vec![0.0; 0]; diff --git a/datafusion/functions-window/src/row_number.rs b/datafusion/functions-window/src/row_number.rs new file mode 100644 index 000000000000..56af14fb84ae --- /dev/null +++ b/datafusion/functions-window/src/row_number.rs @@ -0,0 +1,192 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! `row_number` window function implementation + +use datafusion_common::arrow::array::ArrayRef; +use datafusion_common::arrow::array::UInt64Array; +use datafusion_common::arrow::compute::SortOptions; +use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::arrow::datatypes::Field; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::window_doc_sections::DOC_SECTION_RANKING; +use datafusion_expr::{ + Documentation, PartitionEvaluator, Signature, Volatility, WindowUDFImpl, +}; +use datafusion_functions_window_common::field; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use field::WindowUDFFieldArgs; +use std::any::Any; +use std::fmt::Debug; +use std::ops::Range; +use std::sync::OnceLock; + +define_udwf_and_expr!( + RowNumber, + row_number, + "Returns a unique row number for each row in window partition beginning at 1." +); + +/// row_number expression +#[derive(Debug)] +pub struct RowNumber { + signature: Signature, +} + +impl RowNumber { + /// Create a new `row_number` function + pub fn new() -> Self { + Self { + signature: Signature::any(0, Volatility::Immutable), + } + } +} + +impl Default for RowNumber { + fn default() -> Self { + Self::new() + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_row_number_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_RANKING) + .with_description( + "Number of the current row within its partition, counting from 1.", + ) + .with_syntax_example("row_number()") + .build() + .unwrap() + }) +} + +impl WindowUDFImpl for RowNumber { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "row_number" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + Ok(Box::::default()) + } + + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + Ok(Field::new(field_args.name(), DataType::UInt64, false)) + } + + fn sort_options(&self) -> Option { + Some(SortOptions { + descending: false, + nulls_first: false, + }) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_row_number_doc()) + } +} + +/// State for the `row_number` built-in window function. +#[derive(Debug, Default)] +struct NumRowsEvaluator { + n_rows: usize, +} + +impl PartitionEvaluator for NumRowsEvaluator { + fn is_causal(&self) -> bool { + // The row_number function doesn't need "future" values to emit results: + true + } + + fn evaluate_all( + &mut self, + _values: &[ArrayRef], + num_rows: usize, + ) -> Result { + Ok(std::sync::Arc::new(UInt64Array::from_iter_values( + 1..(num_rows as u64) + 1, + ))) + } + + fn evaluate( + &mut self, + _values: &[ArrayRef], + _range: &Range, + ) -> Result { + self.n_rows += 1; + Ok(ScalarValue::UInt64(Some(self.n_rows as u64))) + } + + fn supports_bounded_execution(&self) -> bool { + true + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use datafusion_common::arrow::array::{Array, BooleanArray}; + use datafusion_common::cast::as_uint64_array; + + use super::*; + + #[test] + fn row_number_all_null() -> Result<()> { + let values: ArrayRef = Arc::new(BooleanArray::from(vec![ + None, None, None, None, None, None, None, None, + ])); + let num_rows = values.len(); + + let actual = RowNumber::default() + .partition_evaluator(PartitionEvaluatorArgs::default())? + .evaluate_all(&[values], num_rows)?; + let actual = as_uint64_array(&actual)?; + + assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], *actual.values()); + Ok(()) + } + + #[test] + fn row_number_all_values() -> Result<()> { + let values: ArrayRef = Arc::new(BooleanArray::from(vec![ + true, false, true, false, false, true, false, true, + ])); + let num_rows = values.len(); + + let actual = RowNumber::default() + .partition_evaluator(PartitionEvaluatorArgs::default())? + .evaluate_all(&[values], num_rows)?; + let actual = as_uint64_array(&actual)?; + + assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], *actual.values()); + Ok(()) + } +} diff --git a/datafusion/functions-window/src/utils.rs b/datafusion/functions-window/src/utils.rs new file mode 100644 index 000000000000..3f8061dbea3e --- /dev/null +++ b/datafusion/functions-window/src/utils.rs @@ -0,0 +1,65 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::arrow::datatypes::DataType; +use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; +use datafusion_physical_expr::expressions::Literal; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::sync::Arc; + +pub(crate) fn get_signed_integer(value: ScalarValue) -> Result { + if value.is_null() { + return Ok(0); + } + + if !value.data_type().is_integer() { + return exec_err!("Expected an integer value"); + } + + value.cast_to(&DataType::Int64)?.try_into() +} + +pub(crate) fn get_scalar_value_from_args( + args: &[Arc], + index: usize, +) -> Result> { + Ok(if let Some(field) = args.get(index) { + let tmp = field + .as_any() + .downcast_ref::() + .ok_or_else(|| DataFusionError::NotImplemented( + format!("There is only support Literal types for field at idx: {index} in Window Function"), + ))? + .value() + .clone(); + Some(tmp) + } else { + None + }) +} + +pub(crate) fn get_unsigned_integer(value: ScalarValue) -> Result { + if value.is_null() { + return Ok(0); + } + + if !value.data_type().is_integer() { + return exec_err!("Expected an integer value"); + } + + value.cast_to(&DataType::UInt64)?.try_into() +} diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 0886dee03479..70a988dbfefb 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -66,6 +66,7 @@ path = "src/lib.rs" [dependencies] arrow = { workspace = true } +arrow-buffer = { workspace = true } base64 = { version = "0.22", optional = true } blake2 = { version = "^0.10.2", optional = true } blake3 = { version = "1.0", optional = true } @@ -73,14 +74,13 @@ chrono = { workspace = true } datafusion-common = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } -datafusion-physical-expr = { workspace = true, default-features = true } -hashbrown = { version = "0.14", features = ["raw"], optional = true } +hashbrown = { workspace = true, optional = true } hex = { version = "0.4", optional = true } itertools = { workspace = true } log = { workspace = true } md-5 = { version = "^0.10.0", optional = true } rand = { workspace = true } -regex = { version = "1.8", optional = true } +regex = { workspace = true, optional = true } sha2 = { version = "^0.10.1", optional = true } unicode-segmentation = { version = "^1.7.1", optional = true } uuid = { version = "1.7", features = ["v4"], optional = true } @@ -102,6 +102,11 @@ harness = false name = "to_timestamp" required-features = ["datetime_expressions"] +[[bench]] +harness = false +name = "encoding" +required-features = ["encoding_expressions"] + [[bench]] harness = false name = "regx" @@ -112,6 +117,16 @@ harness = false name = "make_date" required-features = ["datetime_expressions"] +[[bench]] +harness = false +name = "iszero" +required-features = ["math_expressions"] + +[[bench]] +harness = false +name = "nullif" +required-features = ["core_expressions"] + [[bench]] harness = false name = "date_bin" @@ -122,6 +137,16 @@ harness = false name = "to_char" required-features = ["datetime_expressions"] +[[bench]] +harness = false +name = "isnan" +required-features = ["math_expressions"] + +[[bench]] +harness = false +name = "signum" +required-features = ["math_expressions"] + [[bench]] harness = false name = "substr_index" @@ -141,3 +166,43 @@ required-features = ["string_expressions"] harness = false name = "upper" required-features = ["string_expressions"] + +[[bench]] +harness = false +name = "pad" +required-features = ["unicode_expressions"] + +[[bench]] +harness = false +name = "repeat" +required-features = ["string_expressions"] + +[[bench]] +harness = false +name = "random" +required-features = ["math_expressions"] + +[[bench]] +harness = false +name = "substr" +required-features = ["unicode_expressions"] + +[[bench]] +harness = false +name = "character_length" +required-features = ["unicode_expressions"] + +[[bench]] +harness = false +name = "cot" +required-features = ["math_expressions"] + +[[bench]] +harness = false +name = "strpos" +required-features = ["unicode_expressions"] + +[[bench]] +harness = false +name = "trunc" +required-features = ["math_expressions"] diff --git a/datafusion/functions/benches/character_length.rs b/datafusion/functions/benches/character_length.rs new file mode 100644 index 000000000000..9ba16807de01 --- /dev/null +++ b/datafusion/functions/benches/character_length.rs @@ -0,0 +1,134 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::array::{StringArray, StringViewArray}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::ColumnarValue; +use rand::distributions::Alphanumeric; +use rand::{rngs::StdRng, Rng, SeedableRng}; +use std::sync::Arc; + +/// gen_arr(4096, 128, 0.1, 0.1, true) will generate a StringViewArray with +/// 4096 rows, each row containing a string with 128 random characters. +/// around 10% of the rows are null, around 10% of the rows are non-ASCII. +fn gen_string_array( + n_rows: usize, + str_len_chars: usize, + null_density: f32, + utf8_density: f32, + is_string_view: bool, // false -> StringArray, true -> StringViewArray +) -> Vec { + let mut rng = StdRng::seed_from_u64(42); + let rng_ref = &mut rng; + + let corpus = "DataFusionДатаФусион数据融合📊🔥"; // includes utf8 encoding with 1~4 bytes + let corpus_char_count = corpus.chars().count(); + + let mut output_string_vec: Vec> = Vec::with_capacity(n_rows); + for _ in 0..n_rows { + let rand_num = rng_ref.gen::(); // [0.0, 1.0) + if rand_num < null_density { + output_string_vec.push(None); + } else if rand_num < null_density + utf8_density { + // Generate random UTF8 string + let mut generated_string = String::with_capacity(str_len_chars); + for _ in 0..str_len_chars { + let idx = rng_ref.gen_range(0..corpus_char_count); + let char = corpus.chars().nth(idx).unwrap(); + generated_string.push(char); + } + output_string_vec.push(Some(generated_string)); + } else { + // Generate random ASCII-only string + let value = rng_ref + .sample_iter(&Alphanumeric) + .take(str_len_chars) + .collect(); + let value = String::from_utf8(value).unwrap(); + output_string_vec.push(Some(value)); + } + } + + if is_string_view { + let string_view_array: StringViewArray = output_string_vec.into_iter().collect(); + vec![ColumnarValue::Array(Arc::new(string_view_array))] + } else { + let string_array: StringArray = output_string_vec.clone().into_iter().collect(); + vec![ColumnarValue::Array(Arc::new(string_array))] + } +} + +fn criterion_benchmark(c: &mut Criterion) { + // All benches are single batch run with 8192 rows + let character_length = datafusion_functions::unicode::character_length(); + + let n_rows = 8192; + for str_len in [8, 32, 128, 4096] { + // StringArray ASCII only + let args_string_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, false); + c.bench_function( + &format!("character_length_StringArray_ascii_str_len_{}", str_len), + |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(character_length.invoke(&args_string_ascii)) + }) + }, + ); + + // StringArray UTF8 + let args_string_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, false); + c.bench_function( + &format!("character_length_StringArray_utf8_str_len_{}", str_len), + |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(character_length.invoke(&args_string_utf8)) + }) + }, + ); + + // StringViewArray ASCII only + let args_string_view_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, true); + c.bench_function( + &format!("character_length_StringViewArray_ascii_str_len_{}", str_len), + |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(character_length.invoke(&args_string_view_ascii)) + }) + }, + ); + + // StringViewArray UTF8 + let args_string_view_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, true); + c.bench_function( + &format!("character_length_StringViewArray_utf8_str_len_{}", str_len), + |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(character_length.invoke(&args_string_view_utf8)) + }) + }, + ); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/concat.rs b/datafusion/functions/benches/concat.rs index e7b00a6d540a..280819778f93 100644 --- a/datafusion/functions/benches/concat.rs +++ b/datafusion/functions/benches/concat.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use arrow::array::ArrayRef; use arrow::util::bench_util::create_string_array_with_len; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use datafusion_common::ScalarValue; @@ -26,7 +27,7 @@ fn create_args(size: usize, str_len: usize) -> Vec { let array = Arc::new(create_string_array_with_len::(size, 0.2, str_len)); let scalar = ScalarValue::Utf8(Some(", ".to_string())); vec![ - ColumnarValue::Array(array.clone()), + ColumnarValue::Array(Arc::clone(&array) as ArrayRef), ColumnarValue::Scalar(scalar), ColumnarValue::Array(array), ] @@ -37,7 +38,10 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args(size, 32); let mut group = c.benchmark_group("concat function"); group.bench_function(BenchmarkId::new("concat", size), |b| { - b.iter(|| criterion::black_box(concat().invoke(&args).unwrap())) + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + criterion::black_box(concat().invoke(&args).unwrap()) + }) }); group.finish(); } diff --git a/datafusion/functions/benches/cot.rs b/datafusion/functions/benches/cot.rs new file mode 100644 index 000000000000..a33f00b4b73e --- /dev/null +++ b/datafusion/functions/benches/cot.rs @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::{ + datatypes::{Float32Type, Float64Type}, + util::bench_util::create_primitive_array, +}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::ColumnarValue; +use datafusion_functions::math::cot; + +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let cot_fn = cot(); + for size in [1024, 4096, 8192] { + let f32_array = Arc::new(create_primitive_array::(size, 0.2)); + let f32_args = vec![ColumnarValue::Array(f32_array)]; + c.bench_function(&format!("cot f32 array: {}", size), |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(cot_fn.invoke(&f32_args).unwrap()) + }) + }); + let f64_array = Arc::new(create_primitive_array::(size, 0.2)); + let f64_args = vec![ColumnarValue::Array(f64_array)]; + c.bench_function(&format!("cot f64 array: {}", size), |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(cot_fn.invoke(&f64_args).unwrap()) + }) + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/date_bin.rs b/datafusion/functions/benches/date_bin.rs index c881947354fd..4a8682c42f94 100644 --- a/datafusion/functions/benches/date_bin.rs +++ b/datafusion/functions/benches/date_bin.rs @@ -45,6 +45,7 @@ fn criterion_benchmark(c: &mut Criterion) { let udf = date_bin(); b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch black_box( udf.invoke(&[interval.clone(), timestamps.clone()]) .expect("date_bin should work on valid values"), diff --git a/datafusion/functions/benches/encoding.rs b/datafusion/functions/benches/encoding.rs new file mode 100644 index 000000000000..0615091e90d4 --- /dev/null +++ b/datafusion/functions/benches/encoding.rs @@ -0,0 +1,61 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::util::bench_util::create_string_array_with_len; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::ColumnarValue; +use datafusion_functions::encoding; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let decode = encoding::decode(); + for size in [1024, 4096, 8192] { + let str_array = Arc::new(create_string_array_with_len::(size, 0.2, 32)); + c.bench_function(&format!("base64_decode/{size}"), |b| { + let method = ColumnarValue::Scalar("base64".into()); + #[allow(deprecated)] // TODO use invoke_batch + let encoded = encoding::encode() + .invoke(&[ColumnarValue::Array(str_array.clone()), method.clone()]) + .unwrap(); + + let args = vec![encoded, method]; + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(decode.invoke(&args).unwrap()) + }) + }); + + c.bench_function(&format!("hex_decode/{size}"), |b| { + let method = ColumnarValue::Scalar("hex".into()); + #[allow(deprecated)] // TODO use invoke_batch + let encoded = encoding::encode() + .invoke(&[ColumnarValue::Array(str_array.clone()), method.clone()]) + .unwrap(); + + let args = vec![encoded, method]; + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(decode.invoke(&args).unwrap()) + }) + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/isnan.rs b/datafusion/functions/benches/isnan.rs new file mode 100644 index 000000000000..3e50de658b36 --- /dev/null +++ b/datafusion/functions/benches/isnan.rs @@ -0,0 +1,52 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::{ + datatypes::{Float32Type, Float64Type}, + util::bench_util::create_primitive_array, +}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::ColumnarValue; +use datafusion_functions::math::isnan; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let isnan = isnan(); + for size in [1024, 4096, 8192] { + let f32_array = Arc::new(create_primitive_array::(size, 0.2)); + let f32_args = vec![ColumnarValue::Array(f32_array)]; + c.bench_function(&format!("isnan f32 array: {}", size), |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(isnan.invoke(&f32_args).unwrap()) + }) + }); + let f64_array = Arc::new(create_primitive_array::(size, 0.2)); + let f64_args = vec![ColumnarValue::Array(f64_array)]; + c.bench_function(&format!("isnan f64 array: {}", size), |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(isnan.invoke(&f64_args).unwrap()) + }) + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/iszero.rs b/datafusion/functions/benches/iszero.rs new file mode 100644 index 000000000000..3e6ac97063ca --- /dev/null +++ b/datafusion/functions/benches/iszero.rs @@ -0,0 +1,52 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::{ + datatypes::{Float32Type, Float64Type}, + util::bench_util::create_primitive_array, +}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::ColumnarValue; +use datafusion_functions::math::iszero; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let iszero = iszero(); + for size in [1024, 4096, 8192] { + let f32_array = Arc::new(create_primitive_array::(size, 0.2)); + let f32_args = vec![ColumnarValue::Array(f32_array)]; + c.bench_function(&format!("iszero f32 array: {}", size), |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(iszero.invoke(&f32_args).unwrap()) + }) + }); + let f64_array = Arc::new(create_primitive_array::(size, 0.2)); + let f64_args = vec![ColumnarValue::Array(f64_array)]; + c.bench_function(&format!("iszero f64 array: {}", size), |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(iszero.invoke(&f64_args).unwrap()) + }) + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/lower.rs b/datafusion/functions/benches/lower.rs index fa963f174e46..6cc67791464f 100644 --- a/datafusion/functions/benches/lower.rs +++ b/datafusion/functions/benches/lower.rs @@ -17,8 +17,10 @@ extern crate criterion; -use arrow::array::{ArrayRef, StringArray}; -use arrow::util::bench_util::create_string_array_with_len; +use arrow::array::{ArrayRef, StringArray, StringViewBuilder}; +use arrow::util::bench_util::{ + create_string_array_with_len, create_string_view_array_with_len, +}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::ColumnarValue; use datafusion_functions::string; @@ -65,26 +67,134 @@ fn create_args3(size: usize) -> Vec { vec![ColumnarValue::Array(array)] } +/// Create an array of args containing StringViews, where all the values in the +/// StringViews are ASCII. +/// * `size` - the length of the StringViews, and +/// * `str_len` - the length of the strings within the array. +/// * `null_density` - the density of null values in the array. +/// * `mixed` - whether the array is mixed between inlined and referenced strings. +fn create_args4( + size: usize, + str_len: usize, + null_density: f32, + mixed: bool, +) -> Vec { + let array = Arc::new(create_string_view_array_with_len( + size, + null_density, + str_len, + mixed, + )); + + vec![ColumnarValue::Array(array)] +} + +/// Create an array of args containing a StringViewArray, where some of the values in the +/// array are non-ASCII. +/// * `size` - the length of the StringArray, and +/// * `non_ascii_density` - the density of non-ASCII values in the array. +/// * `null_density` - the density of null values in the array. +fn create_args5( + size: usize, + non_ascii_density: f32, + null_density: f32, +) -> Vec { + let mut string_view_builder = StringViewBuilder::with_capacity(size); + for _ in 0..size { + // sample null_density to determine if the value should be null + if rand::random::() < null_density { + string_view_builder.append_null(); + continue; + } + + // sample non_ascii_density to determine if the value should be non-ASCII + if rand::random::() < non_ascii_density { + string_view_builder.append_value("农历新年农历新年农历新年农历新年农历新年"); + } else { + string_view_builder.append_value("DATAFUSIONDATAFUSIONDATAFUSION"); + } + } + + let array = Arc::new(string_view_builder.finish()) as ArrayRef; + vec![ColumnarValue::Array(array)] +} + fn criterion_benchmark(c: &mut Criterion) { let lower = string::lower(); for size in [1024, 4096, 8192] { let args = create_args1(size, 32); c.bench_function(&format!("lower_all_values_are_ascii: {}", size), |b| { - b.iter(|| black_box(lower.invoke(&args))) + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(lower.invoke(&args)) + }) }); let args = create_args2(size); c.bench_function( &format!("lower_the_first_value_is_nonascii: {}", size), - |b| b.iter(|| black_box(lower.invoke(&args))), + |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(lower.invoke(&args)) + }) + }, ); let args = create_args3(size); c.bench_function( &format!("lower_the_middle_value_is_nonascii: {}", size), - |b| b.iter(|| black_box(lower.invoke(&args))), + |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(lower.invoke(&args)) + }) + }, ); } + + let sizes = [4096, 8192]; + let str_lens = [10, 64, 128]; + let mixes = [true, false]; + let null_densities = [0.0f32, 0.1f32]; + + for null_density in &null_densities { + for &mixed in &mixes { + for &str_len in &str_lens { + for &size in &sizes { + let args = create_args4(size, str_len, *null_density, mixed); + c.bench_function( + &format!("lower_all_values_are_ascii_string_views: size: {}, str_len: {}, null_density: {}, mixed: {}", + size, str_len, null_density, mixed), + |b| b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(lower.invoke(&args)) + }), + ); + + let args = create_args4(size, str_len, *null_density, mixed); + c.bench_function( + &format!("lower_all_values_are_ascii_string_views: size: {}, str_len: {}, null_density: {}, mixed: {}", + size, str_len, null_density, mixed), + |b| b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(lower.invoke(&args)) + }), + ); + + let args = create_args5(size, 0.1, *null_density); + c.bench_function( + &format!("lower_some_values_are_nonascii_string_views: size: {}, str_len: {}, non_ascii_density: {}, null_density: {}, mixed: {}", + size, str_len, 0.1, null_density, mixed), + |b| b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(lower.invoke(&args)) + }), + ); + } + } + } + } } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/functions/benches/ltrim.rs b/datafusion/functions/benches/ltrim.rs index 01acb9de3381..4f94729b6fef 100644 --- a/datafusion/functions/benches/ltrim.rs +++ b/datafusion/functions/benches/ltrim.rs @@ -17,32 +17,226 @@ extern crate criterion; -use arrow::array::{ArrayRef, StringArray}; -use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use arrow::array::{ArrayRef, LargeStringArray, StringArray, StringViewArray}; +use criterion::{ + black_box, criterion_group, criterion_main, measurement::Measurement, BenchmarkGroup, + Criterion, SamplingMode, +}; use datafusion_common::ScalarValue; -use datafusion_expr::ColumnarValue; +use datafusion_expr::{ColumnarValue, ScalarUDF}; use datafusion_functions::string; -use std::sync::Arc; +use rand::{distributions::Alphanumeric, rngs::StdRng, Rng, SeedableRng}; +use std::{fmt, sync::Arc}; -fn create_args(size: usize, characters: &str) -> Vec { - let iter = - std::iter::repeat(format!("{}datafusion{}", characters, characters)).take(size); - let array = Arc::new(StringArray::from_iter_values(iter)) as ArrayRef; +pub fn seedable_rng() -> StdRng { + StdRng::seed_from_u64(42) +} + +#[derive(Clone, Copy)] +pub enum StringArrayType { + Utf8View, + Utf8, + LargeUtf8, +} + +impl fmt::Display for StringArrayType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + StringArrayType::Utf8View => f.write_str("string_view"), + StringArrayType::Utf8 => f.write_str("string"), + StringArrayType::LargeUtf8 => f.write_str("large_string"), + } + } +} + +/// returns an array of strings, and `characters` as a ScalarValue +pub fn create_string_array_and_characters( + size: usize, + characters: &str, + trimmed: &str, + remaining_len: usize, + string_array_type: StringArrayType, +) -> (ArrayRef, ScalarValue) { + let rng = &mut seedable_rng(); + + // Create `size` rows: + // - 10% rows will be `None` + // - Other 90% will be strings with same `remaining_len` lengths + // We will build the string array on it later. + let string_iter = (0..size).map(|_| { + if rng.gen::() < 0.1 { + None + } else { + let mut value = trimmed.as_bytes().to_vec(); + let generated = rng.sample_iter(&Alphanumeric).take(remaining_len); + value.extend(generated); + Some(String::from_utf8(value).unwrap()) + } + }); + + // Build the target `string array` and `characters` according to `string_array_type` + match string_array_type { + StringArrayType::Utf8View => ( + Arc::new(string_iter.collect::()), + ScalarValue::Utf8View(Some(characters.to_string())), + ), + StringArrayType::Utf8 => ( + Arc::new(string_iter.collect::()), + ScalarValue::Utf8(Some(characters.to_string())), + ), + StringArrayType::LargeUtf8 => ( + Arc::new(string_iter.collect::()), + ScalarValue::LargeUtf8(Some(characters.to_string())), + ), + } +} + +/// Create args for the ltrim benchmark +/// Inputs: +/// - size: rows num of the test array +/// - characters: the characters we need to trim +/// - trimmed: the part in the testing string that will be trimmed +/// - remaining_len: the len of the remaining part of testing string after trimming +/// - string_array_type: the method used to store the testing strings +/// +/// Outputs: +/// - testing string array +/// - trimmed characters +/// +fn create_args( + size: usize, + characters: &str, + trimmed: &str, + remaining_len: usize, + string_array_type: StringArrayType, +) -> Vec { + let (string_array, pattern) = create_string_array_and_characters( + size, + characters, + trimmed, + remaining_len, + string_array_type, + ); vec![ - ColumnarValue::Array(array), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(characters.to_string()))), + ColumnarValue::Array(string_array), + ColumnarValue::Scalar(pattern), ] } +#[allow(clippy::too_many_arguments)] +fn run_with_string_type( + group: &mut BenchmarkGroup<'_, M>, + ltrim: &ScalarUDF, + size: usize, + len: usize, + characters: &str, + trimmed: &str, + remaining_len: usize, + string_type: StringArrayType, +) { + let args = create_args(size, characters, trimmed, remaining_len, string_type); + group.bench_function( + format!( + "{string_type} [size={size}, len_before={len}, len_after={remaining_len}]", + ), + |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(ltrim.invoke(&args)) + }) + }, + ); +} + +#[allow(clippy::too_many_arguments)] +fn run_one_group( + c: &mut Criterion, + group_name: &str, + ltrim: &ScalarUDF, + string_types: &[StringArrayType], + size: usize, + len: usize, + characters: &str, + trimmed: &str, + remaining_len: usize, +) { + let mut group = c.benchmark_group(group_name); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + + for string_type in string_types { + run_with_string_type( + &mut group, + ltrim, + size, + len, + characters, + trimmed, + remaining_len, + *string_type, + ); + } + + group.finish(); +} + fn criterion_benchmark(c: &mut Criterion) { let ltrim = string::ltrim(); - for char in ["\"", "Header:"] { - for size in [1024, 4096, 8192] { - let args = create_args(size, char); - c.bench_function(&format!("ltrim {}: {}", char, size), |b| { - b.iter(|| black_box(ltrim.invoke(&args))) - }); - } + let characters = ",!()"; + + let string_types = [ + StringArrayType::Utf8View, + StringArrayType::Utf8, + StringArrayType::LargeUtf8, + ]; + for size in [1024, 4096, 8192] { + // len=12, trimmed_len=4, len_after_ltrim=8 + let len = 12; + let trimmed = characters; + let remaining_len = len - trimmed.len(); + run_one_group( + c, + "INPUT LEN <= 12", + <rim, + &string_types, + size, + len, + characters, + trimmed, + remaining_len, + ); + + // len=64, trimmed_len=4, len_after_ltrim=60 + let len = 64; + let trimmed = characters; + let remaining_len = len - trimmed.len(); + run_one_group( + c, + "INPUT LEN > 12, OUTPUT LEN > 12", + <rim, + &string_types, + size, + len, + characters, + trimmed, + remaining_len, + ); + + // len=64, trimmed_len=56, len_after_ltrim=8 + let len = 64; + let trimmed = characters.repeat(15); + let remaining_len = len - trimmed.len(); + run_one_group( + c, + "INPUT LEN > 12, OUTPUT LEN <= 12", + <rim, + &string_types, + size, + len, + characters, + &trimmed, + remaining_len, + ); } } diff --git a/datafusion/functions/benches/make_date.rs b/datafusion/functions/benches/make_date.rs index cb8f1abe6d5d..a9844e4b2541 100644 --- a/datafusion/functions/benches/make_date.rs +++ b/datafusion/functions/benches/make_date.rs @@ -62,6 +62,7 @@ fn criterion_benchmark(c: &mut Criterion) { let days = ColumnarValue::Array(Arc::new(days(&mut rng)) as ArrayRef); b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch black_box( make_date() .invoke(&[years.clone(), months.clone(), days.clone()]) @@ -77,6 +78,7 @@ fn criterion_benchmark(c: &mut Criterion) { let days = ColumnarValue::Array(Arc::new(days(&mut rng)) as ArrayRef); b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch black_box( make_date() .invoke(&[year.clone(), months.clone(), days.clone()]) @@ -92,6 +94,7 @@ fn criterion_benchmark(c: &mut Criterion) { let days = ColumnarValue::Array(Arc::new(days(&mut rng)) as ArrayRef); b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch black_box( make_date() .invoke(&[year.clone(), month.clone(), days.clone()]) @@ -106,6 +109,7 @@ fn criterion_benchmark(c: &mut Criterion) { let day = ColumnarValue::Scalar(ScalarValue::Int32(Some(26))); b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch black_box( make_date() .invoke(&[year.clone(), month.clone(), day.clone()]) diff --git a/datafusion/functions/benches/nullif.rs b/datafusion/functions/benches/nullif.rs new file mode 100644 index 000000000000..6e1154cf182a --- /dev/null +++ b/datafusion/functions/benches/nullif.rs @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::util::bench_util::create_string_array_with_len; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_common::ScalarValue; +use datafusion_expr::ColumnarValue; +use datafusion_functions::core::nullif; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let nullif = nullif(); + for size in [1024, 4096, 8192] { + let array = Arc::new(create_string_array_with_len::(size, 0.2, 32)); + let args = vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(Some("abcd".to_string()))), + ColumnarValue::Array(array), + ]; + c.bench_function(&format!("nullif scalar array: {}", size), |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(nullif.invoke(&args).unwrap()) + }) + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/pad.rs b/datafusion/functions/benches/pad.rs new file mode 100644 index 000000000000..4b21ca373047 --- /dev/null +++ b/datafusion/functions/benches/pad.rs @@ -0,0 +1,160 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, ArrowPrimitiveType, OffsetSizeTrait, PrimitiveArray}; +use arrow::datatypes::Int64Type; +use arrow::util::bench_util::{ + create_string_array_with_len, create_string_view_array_with_len, +}; +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use datafusion_expr::ColumnarValue; +use datafusion_functions::unicode::{lpad, rpad}; +use rand::distributions::{Distribution, Uniform}; +use rand::Rng; +use std::sync::Arc; + +struct Filter { + dist: Dist, +} + +impl Distribution for Filter +where + Dist: Distribution, +{ + fn sample(&self, rng: &mut R) -> T { + self.dist.sample(rng) + } +} + +pub fn create_primitive_array( + size: usize, + null_density: f32, + len: usize, +) -> PrimitiveArray +where + T: ArrowPrimitiveType, +{ + let dist = Filter { + dist: Uniform::new_inclusive::(0, len as i64), + }; + + let mut rng = rand::thread_rng(); + (0..size) + .map(|_| { + if rng.gen::() < null_density { + None + } else { + Some(rng.sample(&dist)) + } + }) + .collect() +} + +fn create_args( + size: usize, + str_len: usize, + force_view_types: bool, +) -> Vec { + let length_array = Arc::new(create_primitive_array::(size, 0.0, str_len)); + + if !force_view_types { + let string_array = + Arc::new(create_string_array_with_len::(size, 0.1, str_len)); + let fill_array = Arc::new(create_string_array_with_len::(size, 0.1, str_len)); + + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(Arc::clone(&length_array) as ArrayRef), + ColumnarValue::Array(fill_array), + ] + } else { + let string_array = + Arc::new(create_string_view_array_with_len(size, 0.1, str_len, false)); + let fill_array = + Arc::new(create_string_view_array_with_len(size, 0.1, str_len, false)); + + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(Arc::clone(&length_array) as ArrayRef), + ColumnarValue::Array(fill_array), + ] + } +} + +fn criterion_benchmark(c: &mut Criterion) { + for size in [1024, 2048] { + let mut group = c.benchmark_group("lpad function"); + + let args = create_args::(size, 32, false); + group.bench_function(BenchmarkId::new("utf8 type", size), |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + criterion::black_box(lpad().invoke(&args).unwrap()) + }) + }); + + let args = create_args::(size, 32, false); + group.bench_function(BenchmarkId::new("largeutf8 type", size), |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + criterion::black_box(lpad().invoke(&args).unwrap()) + }) + }); + + let args = create_args::(size, 32, true); + group.bench_function(BenchmarkId::new("stringview type", size), |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + criterion::black_box(lpad().invoke(&args).unwrap()) + }) + }); + + group.finish(); + + let mut group = c.benchmark_group("rpad function"); + + let args = create_args::(size, 32, false); + group.bench_function(BenchmarkId::new("utf8 type", size), |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + criterion::black_box(rpad().invoke(&args).unwrap()) + }) + }); + + let args = create_args::(size, 32, false); + group.bench_function(BenchmarkId::new("largeutf8 type", size), |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + criterion::black_box(rpad().invoke(&args).unwrap()) + }) + }); + + // rpad for stringview type + let args = create_args::(size, 32, true); + group.bench_function(BenchmarkId::new("stringview type", size), |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + criterion::black_box(rpad().invoke(&args).unwrap()) + }) + }); + + group.finish(); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/random.rs b/datafusion/functions/benches/random.rs new file mode 100644 index 000000000000..5df5d9c7dee2 --- /dev/null +++ b/datafusion/functions/benches/random.rs @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::ScalarUDFImpl; +use datafusion_functions::math::random::RandomFunc; + +fn criterion_benchmark(c: &mut Criterion) { + let random_func = RandomFunc::new(); + + // Benchmark to evaluate 1M rows in batch size 8192 + let iterations = 1_000_000 / 8192; // Calculate how many iterations are needed to reach approximately 1M rows + c.bench_function("random_1M_rows_batch_8192", |b| { + b.iter(|| { + for _ in 0..iterations { + black_box(random_func.invoke_batch(&[], 8192).unwrap()); + } + }) + }); + + // Benchmark to evaluate 1M rows in batch size 128 + let iterations_128 = 1_000_000 / 128; // Calculate how many iterations are needed to reach approximately 1M rows with batch size 128 + c.bench_function("random_1M_rows_batch_128", |b| { + b.iter(|| { + for _ in 0..iterations_128 { + black_box(random_func.invoke_batch(&[], 128).unwrap()); + } + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/regx.rs b/datafusion/functions/benches/regx.rs index da4882381e76..468d3d548bcf 100644 --- a/datafusion/functions/benches/regx.rs +++ b/datafusion/functions/benches/regx.rs @@ -18,8 +18,11 @@ extern crate criterion; use arrow::array::builder::StringBuilder; -use arrow::array::{ArrayRef, StringArray}; +use arrow::array::{ArrayRef, AsArray, Int64Array, StringArray}; +use arrow::compute::cast; +use arrow::datatypes::DataType; use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_functions::regex::regexpcount::regexp_count_func; use datafusion_functions::regex::regexplike::regexp_like; use datafusion_functions::regex::regexpmatch::regexp_match; use datafusion_functions::regex::regexpreplace::regexp_replace; @@ -59,6 +62,15 @@ fn regex(rng: &mut ThreadRng) -> StringArray { StringArray::from(data) } +fn start(rng: &mut ThreadRng) -> Int64Array { + let mut data: Vec = vec![]; + for _ in 0..1000 { + data.push(rng.gen_range(1..5)); + } + + Int64Array::from(data) +} + fn flags(rng: &mut ThreadRng) -> StringArray { let samples = [Some("i".to_string()), Some("im".to_string()), None]; let mut sb = StringBuilder::new(); @@ -75,6 +87,46 @@ fn flags(rng: &mut ThreadRng) -> StringArray { } fn criterion_benchmark(c: &mut Criterion) { + c.bench_function("regexp_count_1000 string", |b| { + let mut rng = rand::thread_rng(); + let data = Arc::new(data(&mut rng)) as ArrayRef; + let regex = Arc::new(regex(&mut rng)) as ArrayRef; + let start = Arc::new(start(&mut rng)) as ArrayRef; + let flags = Arc::new(flags(&mut rng)) as ArrayRef; + + b.iter(|| { + black_box( + regexp_count_func(&[ + Arc::clone(&data), + Arc::clone(®ex), + Arc::clone(&start), + Arc::clone(&flags), + ]) + .expect("regexp_count should work on utf8"), + ) + }) + }); + + c.bench_function("regexp_count_1000 utf8view", |b| { + let mut rng = rand::thread_rng(); + let data = cast(&data(&mut rng), &DataType::Utf8View).unwrap(); + let regex = cast(®ex(&mut rng), &DataType::Utf8View).unwrap(); + let start = Arc::new(start(&mut rng)) as ArrayRef; + let flags = cast(&flags(&mut rng), &DataType::Utf8View).unwrap(); + + b.iter(|| { + black_box( + regexp_count_func(&[ + Arc::clone(&data), + Arc::clone(®ex), + Arc::clone(&start), + Arc::clone(&flags), + ]) + .expect("regexp_count should work on utf8view"), + ) + }) + }); + c.bench_function("regexp_like_1000", |b| { let mut rng = rand::thread_rng(); let data = Arc::new(data(&mut rng)) as ArrayRef; @@ -83,7 +135,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box( - regexp_like::(&[data.clone(), regex.clone(), flags.clone()]) + regexp_like(&[Arc::clone(&data), Arc::clone(®ex), Arc::clone(&flags)]) .expect("regexp_like should work on valid values"), ) }) @@ -97,8 +149,12 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box( - regexp_match::(&[data.clone(), regex.clone(), flags.clone()]) - .expect("regexp_match should work on valid values"), + regexp_match::(&[ + Arc::clone(&data), + Arc::clone(®ex), + Arc::clone(&flags), + ]) + .expect("regexp_match should work on valid values"), ) }) }); @@ -114,12 +170,12 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box( - regexp_replace::(&[ - data.clone(), - regex.clone(), - replacement.clone(), - flags.clone(), - ]) + regexp_replace::( + data.as_string::(), + regex.as_string::(), + replacement.as_string::(), + Some(&flags), + ) .expect("regexp_replace should work on valid values"), ) }) diff --git a/datafusion/functions/benches/repeat.rs b/datafusion/functions/benches/repeat.rs new file mode 100644 index 000000000000..6e54c92b9b26 --- /dev/null +++ b/datafusion/functions/benches/repeat.rs @@ -0,0 +1,166 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::array::{ArrayRef, Int64Array, OffsetSizeTrait}; +use arrow::util::bench_util::{ + create_string_array_with_len, create_string_view_array_with_len, +}; +use criterion::{black_box, criterion_group, criterion_main, Criterion, SamplingMode}; +use datafusion_expr::ColumnarValue; +use datafusion_functions::string; +use std::sync::Arc; +use std::time::Duration; + +fn create_args( + size: usize, + str_len: usize, + repeat_times: i64, + force_view_types: bool, +) -> Vec { + let number_array = Arc::new(Int64Array::from( + (0..size).map(|_| repeat_times).collect::>(), + )); + + if force_view_types { + let string_array = + Arc::new(create_string_view_array_with_len(size, 0.1, str_len, false)); + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(number_array), + ] + } else { + let string_array = + Arc::new(create_string_array_with_len::(size, 0.1, str_len)); + + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(Arc::clone(&number_array) as ArrayRef), + ] + } +} + +fn criterion_benchmark(c: &mut Criterion) { + let repeat = string::repeat(); + for size in [1024, 4096] { + // REPEAT 3 TIMES + let repeat_times = 3; + let mut group = c.benchmark_group(format!("repeat {} times", repeat_times)); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + let args = create_args::(size, 32, repeat_times, true); + group.bench_function( + format!( + "repeat_string_view [size={}, repeat_times={}]", + size, repeat_times + ), + |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(repeat.invoke(&args)) + }) + }, + ); + + let args = create_args::(size, 32, repeat_times, false); + group.bench_function( + format!( + "repeat_string [size={}, repeat_times={}]", + size, repeat_times + ), + |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(repeat.invoke(&args)) + }) + }, + ); + + let args = create_args::(size, 32, repeat_times, false); + group.bench_function( + format!( + "repeat_large_string [size={}, repeat_times={}]", + size, repeat_times + ), + |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(repeat.invoke(&args)) + }) + }, + ); + + group.finish(); + + // REPEAT 30 TIMES + let repeat_times = 30; + let mut group = c.benchmark_group(format!("repeat {} times", repeat_times)); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + group.measurement_time(Duration::from_secs(10)); + + let args = create_args::(size, 32, repeat_times, true); + group.bench_function( + format!( + "repeat_string_view [size={}, repeat_times={}]", + size, repeat_times + ), + |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(repeat.invoke(&args)) + }) + }, + ); + + let args = create_args::(size, 32, repeat_times, false); + group.bench_function( + format!( + "repeat_string [size={}, repeat_times={}]", + size, repeat_times + ), + |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(repeat.invoke(&args)) + }) + }, + ); + + let args = create_args::(size, 32, repeat_times, false); + group.bench_function( + format!( + "repeat_large_string [size={}, repeat_times={}]", + size, repeat_times + ), + |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(repeat.invoke(&args)) + }) + }, + ); + + group.finish(); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/signum.rs b/datafusion/functions/benches/signum.rs new file mode 100644 index 000000000000..ea1f5433df4e --- /dev/null +++ b/datafusion/functions/benches/signum.rs @@ -0,0 +1,52 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::{ + datatypes::{Float32Type, Float64Type}, + util::bench_util::create_primitive_array, +}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::ColumnarValue; +use datafusion_functions::math::signum; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let signum = signum(); + for size in [1024, 4096, 8192] { + let f32_array = Arc::new(create_primitive_array::(size, 0.2)); + let f32_args = vec![ColumnarValue::Array(f32_array)]; + c.bench_function(&format!("signum f32 array: {}", size), |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(signum.invoke(&f32_args).unwrap()) + }) + }); + let f64_array = Arc::new(create_primitive_array::(size, 0.2)); + let f64_args = vec![ColumnarValue::Array(f64_array)]; + c.bench_function(&format!("signum f64 array: {}", size), |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(signum.invoke(&f64_args).unwrap()) + }) + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/strpos.rs b/datafusion/functions/benches/strpos.rs new file mode 100644 index 000000000000..31ca61e34c3a --- /dev/null +++ b/datafusion/functions/benches/strpos.rs @@ -0,0 +1,162 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::array::{StringArray, StringViewArray}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::ColumnarValue; +use rand::distributions::Alphanumeric; +use rand::prelude::StdRng; +use rand::{Rng, SeedableRng}; +use std::str::Chars; +use std::sync::Arc; + +/// gen_arr(4096, 128, 0.1, 0.1, true) will generate a StringViewArray with +/// 4096 rows, each row containing a string with 128 random characters. +/// around 10% of the rows are null, around 10% of the rows are non-ASCII. +fn gen_string_array( + n_rows: usize, + str_len_chars: usize, + null_density: f32, + utf8_density: f32, + is_string_view: bool, // false -> StringArray, true -> StringViewArray +) -> Vec { + let mut rng = StdRng::seed_from_u64(42); + let rng_ref = &mut rng; + + let utf8 = "DatafusionДатаФусион数据融合📊🔥"; // includes utf8 encoding with 1~4 bytes + let corpus_char_count = utf8.chars().count(); + + let mut output_string_vec: Vec> = Vec::with_capacity(n_rows); + let mut output_sub_string_vec: Vec> = Vec::with_capacity(n_rows); + for _ in 0..n_rows { + let rand_num = rng_ref.gen::(); // [0.0, 1.0) + if rand_num < null_density { + output_sub_string_vec.push(None); + output_string_vec.push(None); + } else if rand_num < null_density + utf8_density { + // Generate random UTF8 string + let mut generated_string = String::with_capacity(str_len_chars); + for _ in 0..str_len_chars { + let idx = rng_ref.gen_range(0..corpus_char_count); + let char = utf8.chars().nth(idx).unwrap(); + generated_string.push(char); + } + output_sub_string_vec.push(Some(random_substring(generated_string.chars()))); + output_string_vec.push(Some(generated_string)); + } else { + // Generate random ASCII-only string + let value = rng_ref + .sample_iter(&Alphanumeric) + .take(str_len_chars) + .collect(); + let value = String::from_utf8(value).unwrap(); + output_sub_string_vec.push(Some(random_substring(value.chars()))); + output_string_vec.push(Some(value)); + } + } + + if is_string_view { + let string_view_array: StringViewArray = output_string_vec.into_iter().collect(); + let sub_string_view_array: StringViewArray = + output_sub_string_vec.into_iter().collect(); + vec![ + ColumnarValue::Array(Arc::new(string_view_array)), + ColumnarValue::Array(Arc::new(sub_string_view_array)), + ] + } else { + let string_array: StringArray = output_string_vec.clone().into_iter().collect(); + let sub_string_array: StringArray = output_sub_string_vec.into_iter().collect(); + vec![ + ColumnarValue::Array(Arc::new(string_array)), + ColumnarValue::Array(Arc::new(sub_string_array)), + ] + } +} + +fn random_substring(chars: Chars) -> String { + // get the substring of a random length from the input string by byte unit + let mut rng = StdRng::seed_from_u64(44); + let count = chars.clone().count(); + let start = rng.gen_range(0..count - 1); + let end = rng.gen_range(start + 1..count); + chars + .enumerate() + .filter(|(i, _)| *i >= start && *i < end) + .map(|(_, c)| c) + .collect() +} + +fn criterion_benchmark(c: &mut Criterion) { + // All benches are single batch run with 8192 rows + let strpos = datafusion_functions::unicode::strpos(); + + let n_rows = 8192; + for str_len in [8, 32, 128, 4096] { + // StringArray ASCII only + let args_string_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, false); + c.bench_function( + &format!("strpos_StringArray_ascii_str_len_{}", str_len), + |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(strpos.invoke(&args_string_ascii)) + }) + }, + ); + + // StringArray UTF8 + let args_string_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, false); + c.bench_function( + &format!("strpos_StringArray_utf8_str_len_{}", str_len), + |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(strpos.invoke(&args_string_utf8)) + }) + }, + ); + + // StringViewArray ASCII only + let args_string_view_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, true); + c.bench_function( + &format!("strpos_StringViewArray_ascii_str_len_{}", str_len), + |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(strpos.invoke(&args_string_view_ascii)) + }) + }, + ); + + // StringViewArray UTF8 + let args_string_view_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, true); + c.bench_function( + &format!("strpos_StringViewArray_utf8_str_len_{}", str_len), + |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(strpos.invoke(&args_string_view_utf8)) + }) + }, + ); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/substr.rs b/datafusion/functions/benches/substr.rs new file mode 100644 index 000000000000..21020dad31a4 --- /dev/null +++ b/datafusion/functions/benches/substr.rs @@ -0,0 +1,247 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::array::{ArrayRef, Int64Array, OffsetSizeTrait}; +use arrow::util::bench_util::{ + create_string_array_with_len, create_string_view_array_with_len, +}; +use criterion::{black_box, criterion_group, criterion_main, Criterion, SamplingMode}; +use datafusion_expr::ColumnarValue; +use datafusion_functions::unicode; +use std::sync::Arc; + +fn create_args_without_count( + size: usize, + str_len: usize, + start_half_way: bool, + force_view_types: bool, +) -> Vec { + let start_array = Arc::new(Int64Array::from( + (0..size) + .map(|_| { + if start_half_way { + (str_len / 2) as i64 + } else { + 1i64 + } + }) + .collect::>(), + )); + + if force_view_types { + let string_array = + Arc::new(create_string_view_array_with_len(size, 0.1, str_len, false)); + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(start_array), + ] + } else { + let string_array = + Arc::new(create_string_array_with_len::(size, 0.1, str_len)); + + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(Arc::clone(&start_array) as ArrayRef), + ] + } +} + +fn create_args_with_count( + size: usize, + str_len: usize, + count_max: usize, + force_view_types: bool, +) -> Vec { + let start_array = + Arc::new(Int64Array::from((0..size).map(|_| 1).collect::>())); + let count = count_max.min(str_len) as i64; + let count_array = Arc::new(Int64Array::from( + (0..size).map(|_| count).collect::>(), + )); + + if force_view_types { + let string_array = + Arc::new(create_string_view_array_with_len(size, 0.1, str_len, false)); + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(start_array), + ColumnarValue::Array(count_array), + ] + } else { + let string_array = + Arc::new(create_string_array_with_len::(size, 0.1, str_len)); + + vec![ + ColumnarValue::Array(string_array), + ColumnarValue::Array(Arc::clone(&start_array) as ArrayRef), + ColumnarValue::Array(Arc::clone(&count_array) as ArrayRef), + ] + } +} + +fn criterion_benchmark(c: &mut Criterion) { + let substr = unicode::substr(); + for size in [1024, 4096] { + // string_len = 12, substring_len=6 (see `create_args_without_count`) + let len = 12; + let mut group = c.benchmark_group("SHORTER THAN 12"); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + + let args = create_args_without_count::(size, len, true, true); + group.bench_function( + format!("substr_string_view [size={}, strlen={}]", size, len), + |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(substr.invoke(&args)) + }) + }, + ); + + let args = create_args_without_count::(size, len, false, false); + group.bench_function( + format!("substr_string [size={}, strlen={}]", size, len), + |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(substr.invoke(&args)) + }) + }, + ); + + let args = create_args_without_count::(size, len, true, false); + group.bench_function( + format!("substr_large_string [size={}, strlen={}]", size, len), + |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(substr.invoke(&args)) + }) + }, + ); + + group.finish(); + + // string_len = 128, start=1, count=64, substring_len=64 + let len = 128; + let count = 64; + let mut group = c.benchmark_group("LONGER THAN 12"); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + + let args = create_args_with_count::(size, len, count, true); + group.bench_function( + format!( + "substr_string_view [size={}, count={}, strlen={}]", + size, count, len, + ), + |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(substr.invoke(&args)) + }) + }, + ); + + let args = create_args_with_count::(size, len, count, false); + group.bench_function( + format!( + "substr_string [size={}, count={}, strlen={}]", + size, count, len, + ), + |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(substr.invoke(&args)) + }) + }, + ); + + let args = create_args_with_count::(size, len, count, false); + group.bench_function( + format!( + "substr_large_string [size={}, count={}, strlen={}]", + size, count, len, + ), + |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(substr.invoke(&args)) + }) + }, + ); + + group.finish(); + + // string_len = 128, start=1, count=6, substring_len=6 + let len = 128; + let count = 6; + let mut group = c.benchmark_group("SRC_LEN > 12, SUB_LEN < 12"); + group.sampling_mode(SamplingMode::Flat); + group.sample_size(10); + + let args = create_args_with_count::(size, len, count, true); + group.bench_function( + format!( + "substr_string_view [size={}, count={}, strlen={}]", + size, count, len, + ), + |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(substr.invoke(&args)) + }) + }, + ); + + let args = create_args_with_count::(size, len, count, false); + group.bench_function( + format!( + "substr_string [size={}, count={}, strlen={}]", + size, count, len, + ), + |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(substr.invoke(&args)) + }) + }, + ); + + let args = create_args_with_count::(size, len, count, false); + group.bench_function( + format!( + "substr_large_string [size={}, count={}, strlen={}]", + size, count, len, + ), + |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(substr.invoke(&args)) + }) + }, + ); + + group.finish(); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/substr_index.rs b/datafusion/functions/benches/substr_index.rs index bb9a5b809eee..1e793cf4db8c 100644 --- a/datafusion/functions/benches/substr_index.rs +++ b/datafusion/functions/benches/substr_index.rs @@ -90,6 +90,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args = [strings, delimiters, counts]; b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch black_box( substr_index() .invoke(&args) diff --git a/datafusion/functions/benches/to_char.rs b/datafusion/functions/benches/to_char.rs index d9a153e64abc..09032fdf2de1 100644 --- a/datafusion/functions/benches/to_char.rs +++ b/datafusion/functions/benches/to_char.rs @@ -86,6 +86,7 @@ fn criterion_benchmark(c: &mut Criterion) { let patterns = ColumnarValue::Array(Arc::new(patterns(&mut rng)) as ArrayRef); b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch black_box( to_char() .invoke(&[data.clone(), patterns.clone()]) @@ -101,6 +102,7 @@ fn criterion_benchmark(c: &mut Criterion) { ColumnarValue::Scalar(ScalarValue::Utf8(Some("%Y-%m-%d".to_string()))); b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch black_box( to_char() .invoke(&[data.clone(), patterns.clone()]) @@ -124,6 +126,7 @@ fn criterion_benchmark(c: &mut Criterion) { ))); b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch black_box( to_char() .invoke(&[data.clone(), pattern.clone()]) diff --git a/datafusion/functions/benches/to_timestamp.rs b/datafusion/functions/benches/to_timestamp.rs index e734b6832f29..11816fe9c64f 100644 --- a/datafusion/functions/benches/to_timestamp.rs +++ b/datafusion/functions/benches/to_timestamp.rs @@ -20,29 +20,128 @@ extern crate criterion; use std::sync::Arc; use arrow::array::builder::StringBuilder; -use arrow::array::ArrayRef; +use arrow::array::{ArrayRef, StringArray}; +use arrow::compute::cast; +use arrow::datatypes::DataType; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::ColumnarValue; use datafusion_functions::datetime::to_timestamp; +fn data() -> StringArray { + let data: Vec<&str> = vec![ + "1997-01-31T09:26:56.123Z", + "1997-01-31T09:26:56.123-05:00", + "1997-01-31 09:26:56.123-05:00", + "2023-01-01 04:05:06.789 -08", + "1997-01-31T09:26:56.123", + "1997-01-31 09:26:56.123", + "1997-01-31 09:26:56", + "1997-01-31 13:26:56", + "1997-01-31 13:26:56+04:00", + "1997-01-31", + ]; + + StringArray::from(data) +} + +fn data_with_formats() -> (StringArray, StringArray, StringArray, StringArray) { + let mut inputs = StringBuilder::new(); + let mut format1_builder = StringBuilder::with_capacity(2, 10); + let mut format2_builder = StringBuilder::with_capacity(2, 10); + let mut format3_builder = StringBuilder::with_capacity(2, 10); + + inputs.append_value("1997-01-31T09:26:56.123Z"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%dT%H:%M:%S%.f%Z"); + + inputs.append_value("1997-01-31T09:26:56.123-05:00"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%dT%H:%M:%S%.f%z"); + + inputs.append_value("1997-01-31 09:26:56.123-05:00"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%d %H:%M:%S%.f%Z"); + + inputs.append_value("2023-01-01 04:05:06.789 -08"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%d %H:%M:%S%.f %#z"); + + inputs.append_value("1997-01-31T09:26:56.123"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%dT%H:%M:%S%.f"); + + inputs.append_value("1997-01-31 09:26:56.123"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%d %H:%M:%S%.f"); + + inputs.append_value("1997-01-31 09:26:56"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%d %H:%M:%S"); + + inputs.append_value("1997-01-31 092656"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%d %H%M%S"); + + inputs.append_value("1997-01-31 092656+04:00"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%d %H%M%S%:z"); + + inputs.append_value("Sun Jul 8 00:34:60 2001"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%d 00:00:00"); + + ( + inputs.finish(), + format1_builder.finish(), + format2_builder.finish(), + format3_builder.finish(), + ) +} fn criterion_benchmark(c: &mut Criterion) { - c.bench_function("to_timestamp_no_formats", |b| { - let mut inputs = StringBuilder::new(); - inputs.append_value("1997-01-31T09:26:56.123Z"); - inputs.append_value("1997-01-31T09:26:56.123-05:00"); - inputs.append_value("1997-01-31 09:26:56.123-05:00"); - inputs.append_value("2023-01-01 04:05:06.789 -08"); - inputs.append_value("1997-01-31T09:26:56.123"); - inputs.append_value("1997-01-31 09:26:56.123"); - inputs.append_value("1997-01-31 09:26:56"); - inputs.append_value("1997-01-31 13:26:56"); - inputs.append_value("1997-01-31 13:26:56+04:00"); - inputs.append_value("1997-01-31"); - - let string_array = ColumnarValue::Array(Arc::new(inputs.finish()) as ArrayRef); + c.bench_function("to_timestamp_no_formats_utf8", |b| { + let string_array = ColumnarValue::Array(Arc::new(data()) as ArrayRef); + + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box( + to_timestamp() + .invoke(&[string_array.clone()]) + .expect("to_timestamp should work on valid values"), + ) + }) + }); + + c.bench_function("to_timestamp_no_formats_largeutf8", |b| { + let data = cast(&data(), &DataType::LargeUtf8).unwrap(); + let string_array = ColumnarValue::Array(Arc::new(data) as ArrayRef); + + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box( + to_timestamp() + .invoke(&[string_array.clone()]) + .expect("to_timestamp should work on valid values"), + ) + }) + }); + + c.bench_function("to_timestamp_no_formats_utf8view", |b| { + let data = cast(&data(), &DataType::Utf8View).unwrap(); + let string_array = ColumnarValue::Array(Arc::new(data) as ArrayRef); b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch black_box( to_timestamp() .invoke(&[string_array.clone()]) @@ -51,69 +150,71 @@ fn criterion_benchmark(c: &mut Criterion) { }) }); - c.bench_function("to_timestamp_with_formats", |b| { - let mut inputs = StringBuilder::new(); - let mut format1_builder = StringBuilder::with_capacity(2, 10); - let mut format2_builder = StringBuilder::with_capacity(2, 10); - let mut format3_builder = StringBuilder::with_capacity(2, 10); - - inputs.append_value("1997-01-31T09:26:56.123Z"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - format3_builder.append_value("%Y-%m-%dT%H:%M:%S%.f%Z"); - - inputs.append_value("1997-01-31T09:26:56.123-05:00"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - format3_builder.append_value("%Y-%m-%dT%H:%M:%S%.f%z"); - - inputs.append_value("1997-01-31 09:26:56.123-05:00"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - format3_builder.append_value("%Y-%m-%d %H:%M:%S%.f%Z"); - - inputs.append_value("2023-01-01 04:05:06.789 -08"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - format3_builder.append_value("%Y-%m-%d %H:%M:%S%.f %#z"); - - inputs.append_value("1997-01-31T09:26:56.123"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - format3_builder.append_value("%Y-%m-%dT%H:%M:%S%.f"); - - inputs.append_value("1997-01-31 09:26:56.123"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - format3_builder.append_value("%Y-%m-%d %H:%M:%S%.f"); - - inputs.append_value("1997-01-31 09:26:56"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - format3_builder.append_value("%Y-%m-%d %H:%M:%S"); - - inputs.append_value("1997-01-31 092656"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - format3_builder.append_value("%Y-%m-%d %H%M%S"); - - inputs.append_value("1997-01-31 092656+04:00"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - format3_builder.append_value("%Y-%m-%d %H%M%S%:z"); - - inputs.append_value("Sun Jul 8 00:34:60 2001"); - format1_builder.append_value("%+"); - format2_builder.append_value("%c"); - format3_builder.append_value("%Y-%m-%d 00:00:00"); + c.bench_function("to_timestamp_with_formats_utf8", |b| { + let (inputs, format1, format2, format3) = data_with_formats(); + + let args = [ + ColumnarValue::Array(Arc::new(inputs) as ArrayRef), + ColumnarValue::Array(Arc::new(format1) as ArrayRef), + ColumnarValue::Array(Arc::new(format2) as ArrayRef), + ColumnarValue::Array(Arc::new(format3) as ArrayRef), + ]; + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box( + to_timestamp() + .invoke(&args.clone()) + .expect("to_timestamp should work on valid values"), + ) + }) + }); + + c.bench_function("to_timestamp_with_formats_largeutf8", |b| { + let (inputs, format1, format2, format3) = data_with_formats(); + + let args = [ + ColumnarValue::Array( + Arc::new(cast(&inputs, &DataType::LargeUtf8).unwrap()) as ArrayRef + ), + ColumnarValue::Array( + Arc::new(cast(&format1, &DataType::LargeUtf8).unwrap()) as ArrayRef + ), + ColumnarValue::Array( + Arc::new(cast(&format2, &DataType::LargeUtf8).unwrap()) as ArrayRef + ), + ColumnarValue::Array( + Arc::new(cast(&format3, &DataType::LargeUtf8).unwrap()) as ArrayRef + ), + ]; + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box( + to_timestamp() + .invoke(&args.clone()) + .expect("to_timestamp should work on valid values"), + ) + }) + }); + + c.bench_function("to_timestamp_with_formats_utf8view", |b| { + let (inputs, format1, format2, format3) = data_with_formats(); let args = [ - ColumnarValue::Array(Arc::new(inputs.finish()) as ArrayRef), - ColumnarValue::Array(Arc::new(format1_builder.finish()) as ArrayRef), - ColumnarValue::Array(Arc::new(format2_builder.finish()) as ArrayRef), - ColumnarValue::Array(Arc::new(format3_builder.finish()) as ArrayRef), + ColumnarValue::Array( + Arc::new(cast(&inputs, &DataType::Utf8View).unwrap()) as ArrayRef + ), + ColumnarValue::Array( + Arc::new(cast(&format1, &DataType::Utf8View).unwrap()) as ArrayRef + ), + ColumnarValue::Array( + Arc::new(cast(&format2, &DataType::Utf8View).unwrap()) as ArrayRef + ), + ColumnarValue::Array( + Arc::new(cast(&format3, &DataType::Utf8View).unwrap()) as ArrayRef + ), ]; b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch black_box( to_timestamp() .invoke(&args.clone()) diff --git a/datafusion/functions/benches/trunc.rs b/datafusion/functions/benches/trunc.rs new file mode 100644 index 000000000000..07ce522eb913 --- /dev/null +++ b/datafusion/functions/benches/trunc.rs @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +extern crate criterion; + +use arrow::{ + datatypes::{Float32Type, Float64Type}, + util::bench_util::create_primitive_array, +}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr::ColumnarValue; +use datafusion_functions::math::trunc; + +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let trunc = trunc(); + for size in [1024, 4096, 8192] { + let f32_array = Arc::new(create_primitive_array::(size, 0.2)); + let f32_args = vec![ColumnarValue::Array(f32_array)]; + c.bench_function(&format!("trunc f32 array: {}", size), |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(trunc.invoke(&f32_args).unwrap()) + }) + }); + let f64_array = Arc::new(create_primitive_array::(size, 0.2)); + let f64_args = vec![ColumnarValue::Array(f64_array)]; + c.bench_function(&format!("trunc f64 array: {}", size), |b| { + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(trunc.invoke(&f64_args).unwrap()) + }) + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/benches/upper.rs b/datafusion/functions/benches/upper.rs index a3e5fbd7a433..ac4ecacff941 100644 --- a/datafusion/functions/benches/upper.rs +++ b/datafusion/functions/benches/upper.rs @@ -37,7 +37,10 @@ fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 4096, 8192] { let args = create_args(size, 32); c.bench_function("upper_all_values_are_ascii", |b| { - b.iter(|| black_box(upper.invoke(&args))) + b.iter(|| { + #[allow(deprecated)] // TODO use invoke_batch + black_box(upper.invoke(&args)) + }) }); } } diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index d641389e0ae3..a3e3feaa17e3 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -17,17 +17,20 @@ //! [`ArrowCastFunc`]: Implementation of the `arrow_cast` -use std::any::Any; -use std::{fmt::Display, iter::Peekable, str::Chars, sync::Arc}; - -use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit}; +use arrow::datatypes::DataType; use datafusion_common::{ - internal_err, plan_datafusion_err, plan_err, DataFusionError, ExprSchema, Result, - ScalarValue, + arrow_datafusion_err, internal_err, plan_datafusion_err, plan_err, DataFusionError, + ExprSchema, Result, ScalarValue, }; +use std::any::Any; +use std::sync::OnceLock; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_OTHER; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, Expr, ExprSchemable, ScalarUDFImpl, Signature, + Volatility, +}; /// Implements casting to arbitrary arrow types (rather than SQL types) /// @@ -44,7 +47,7 @@ use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility} /// select cast(column_x as int) ... /// ``` /// -/// You can use the `arrow_cast` functiont to cast to a specific arrow type +/// Use the `arrow_cast` function to cast to a specific arrow type /// /// For example /// ```sql @@ -88,6 +91,10 @@ impl ScalarUDFImpl for ArrowCastFunc { internal_err!("arrow_cast should return type from exprs") } + fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool { + args.iter().any(|e| e.nullable(schema).ok().unwrap_or(true)) + } + fn return_type_from_exprs( &self, args: &[Expr], @@ -126,6 +133,39 @@ impl ScalarUDFImpl for ArrowCastFunc { // return the newly written argument to DataFusion Ok(ExprSimplifyResult::Simplified(new_expr)) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_arrow_cast_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_arrow_cast_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_OTHER) + .with_description("Casts a value to a specific Arrow data type.") + .with_syntax_example("arrow_cast(expression, datatype)") + .with_sql_example( + r#"```sql +> select arrow_cast(-5, 'Int8') as a, + arrow_cast('foo', 'Dictionary(Int32, Utf8)') as b, + arrow_cast('bar', 'LargeUtf8') as c, + arrow_cast('2023-01-02T12:53:02', 'Timestamp(Microsecond, Some("+08:00"))') as d + ; ++----+-----+-----+---------------------------+ +| a | b | c | d | ++----+-----+-----+---------------------------+ +| -5 | foo | bar | 2023-01-02T12:53:02+08:00 | ++----+-----+-----+---------------------------+ +```"#, + ) + .with_argument("expression", "Expression to cast. The expression can be a constant, column, or function, and any combination of operators.") + .with_argument("datatype", "[Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) name to cast to, as a string. The format is the same as that returned by [`arrow_typeof`]") + .build() + .unwrap() + }) } /// Returns the requested type from the arguments @@ -139,763 +179,11 @@ fn data_type_from_args(args: &[Expr]) -> Result { &args[1] ); }; - parse_data_type(val) -} - -/// Parses `str` into a `DataType`. -/// -/// `parse_data_type` is the reverse of [`DataType`]'s `Display` -/// impl, and maintains the invariant that -/// `parse_data_type(data_type.to_string()) == data_type` -/// -/// Remove if added to arrow: -fn parse_data_type(val: &str) -> Result { - Parser::new(val).parse() -} - -fn make_error(val: &str, msg: &str) -> DataFusionError { - plan_datafusion_err!("Unsupported type '{val}'. Must be a supported arrow type name such as 'Int32' or 'Timestamp(Nanosecond, None)'. Error {msg}" ) -} - -fn make_error_expected(val: &str, expected: &Token, actual: &Token) -> DataFusionError { - make_error(val, &format!("Expected '{expected}', got '{actual}'")) -} - -#[derive(Debug)] -/// Implementation of `parse_data_type`, modeled after -struct Parser<'a> { - val: &'a str, - tokenizer: Tokenizer<'a>, -} - -impl<'a> Parser<'a> { - fn new(val: &'a str) -> Self { - Self { - val, - tokenizer: Tokenizer::new(val), - } - } - - fn parse(mut self) -> Result { - let data_type = self.parse_next_type()?; - // ensure that there is no trailing content - if self.tokenizer.next().is_some() { - Err(make_error( - self.val, - &format!("checking trailing content after parsing '{data_type}'"), - )) - } else { - Ok(data_type) - } - } - - /// parses the next full DataType - fn parse_next_type(&mut self) -> Result { - match self.next_token()? { - Token::SimpleType(data_type) => Ok(data_type), - Token::Timestamp => self.parse_timestamp(), - Token::Time32 => self.parse_time32(), - Token::Time64 => self.parse_time64(), - Token::Duration => self.parse_duration(), - Token::Interval => self.parse_interval(), - Token::FixedSizeBinary => self.parse_fixed_size_binary(), - Token::Decimal128 => self.parse_decimal_128(), - Token::Decimal256 => self.parse_decimal_256(), - Token::Dictionary => self.parse_dictionary(), - Token::List => self.parse_list(), - Token::LargeList => self.parse_large_list(), - Token::FixedSizeList => self.parse_fixed_size_list(), - tok => Err(make_error( - self.val, - &format!("finding next type, got unexpected '{tok}'"), - )), - } - } - - /// Parses the List type - fn parse_list(&mut self) -> Result { - self.expect_token(Token::LParen)?; - let data_type = self.parse_next_type()?; - self.expect_token(Token::RParen)?; - Ok(DataType::List(Arc::new(Field::new( - "item", data_type, true, - )))) - } - - /// Parses the LargeList type - fn parse_large_list(&mut self) -> Result { - self.expect_token(Token::LParen)?; - let data_type = self.parse_next_type()?; - self.expect_token(Token::RParen)?; - Ok(DataType::LargeList(Arc::new(Field::new( - "item", data_type, true, - )))) - } - - /// Parses the FixedSizeList type - fn parse_fixed_size_list(&mut self) -> Result { - self.expect_token(Token::LParen)?; - let length = self.parse_i32("FixedSizeList")?; - self.expect_token(Token::Comma)?; - let data_type = self.parse_next_type()?; - self.expect_token(Token::RParen)?; - Ok(DataType::FixedSizeList( - Arc::new(Field::new("item", data_type, true)), - length, - )) - } - - /// Parses the next timeunit - fn parse_time_unit(&mut self, context: &str) -> Result { - match self.next_token()? { - Token::TimeUnit(time_unit) => Ok(time_unit), - tok => Err(make_error( - self.val, - &format!("finding TimeUnit for {context}, got {tok}"), - )), - } - } - - /// Parses the next timezone - fn parse_timezone(&mut self, context: &str) -> Result> { - match self.next_token()? { - Token::None => Ok(None), - Token::Some => { - self.expect_token(Token::LParen)?; - let timezone = self.parse_double_quoted_string("Timezone")?; - self.expect_token(Token::RParen)?; - Ok(Some(timezone)) - } - tok => Err(make_error( - self.val, - &format!("finding Timezone for {context}, got {tok}"), - )), - } - } - - /// Parses the next double quoted string - fn parse_double_quoted_string(&mut self, context: &str) -> Result { - match self.next_token()? { - Token::DoubleQuotedString(s) => Ok(s), - tok => Err(make_error( - self.val, - &format!("finding double quoted string for {context}, got '{tok}'"), - )), - } - } - - /// Parses the next integer value - fn parse_i64(&mut self, context: &str) -> Result { - match self.next_token()? { - Token::Integer(v) => Ok(v), - tok => Err(make_error( - self.val, - &format!("finding i64 for {context}, got '{tok}'"), - )), - } - } - - /// Parses the next i32 integer value - fn parse_i32(&mut self, context: &str) -> Result { - let length = self.parse_i64(context)?; - length.try_into().map_err(|e| { - make_error( - self.val, - &format!("converting {length} into i32 for {context}: {e}"), - ) - }) - } - - /// Parses the next i8 integer value - fn parse_i8(&mut self, context: &str) -> Result { - let length = self.parse_i64(context)?; - length.try_into().map_err(|e| { - make_error( - self.val, - &format!("converting {length} into i8 for {context}: {e}"), - ) - }) - } - - /// Parses the next u8 integer value - fn parse_u8(&mut self, context: &str) -> Result { - let length = self.parse_i64(context)?; - length.try_into().map_err(|e| { - make_error( - self.val, - &format!("converting {length} into u8 for {context}: {e}"), - ) - }) - } - - /// Parses the next timestamp (called after `Timestamp` has been consumed) - fn parse_timestamp(&mut self) -> Result { - self.expect_token(Token::LParen)?; - let time_unit = self.parse_time_unit("Timestamp")?; - self.expect_token(Token::Comma)?; - let timezone = self.parse_timezone("Timestamp")?; - self.expect_token(Token::RParen)?; - Ok(DataType::Timestamp(time_unit, timezone.map(Into::into))) - } - - /// Parses the next Time32 (called after `Time32` has been consumed) - fn parse_time32(&mut self) -> Result { - self.expect_token(Token::LParen)?; - let time_unit = self.parse_time_unit("Time32")?; - self.expect_token(Token::RParen)?; - Ok(DataType::Time32(time_unit)) - } - - /// Parses the next Time64 (called after `Time64` has been consumed) - fn parse_time64(&mut self) -> Result { - self.expect_token(Token::LParen)?; - let time_unit = self.parse_time_unit("Time64")?; - self.expect_token(Token::RParen)?; - Ok(DataType::Time64(time_unit)) - } - - /// Parses the next Duration (called after `Duration` has been consumed) - fn parse_duration(&mut self) -> Result { - self.expect_token(Token::LParen)?; - let time_unit = self.parse_time_unit("Duration")?; - self.expect_token(Token::RParen)?; - Ok(DataType::Duration(time_unit)) - } - - /// Parses the next Interval (called after `Interval` has been consumed) - fn parse_interval(&mut self) -> Result { - self.expect_token(Token::LParen)?; - let interval_unit = match self.next_token()? { - Token::IntervalUnit(interval_unit) => interval_unit, - tok => { - return Err(make_error( - self.val, - &format!("finding IntervalUnit for Interval, got {tok}"), - )) - } - }; - self.expect_token(Token::RParen)?; - Ok(DataType::Interval(interval_unit)) - } - - /// Parses the next FixedSizeBinary (called after `FixedSizeBinary` has been consumed) - fn parse_fixed_size_binary(&mut self) -> Result { - self.expect_token(Token::LParen)?; - let length = self.parse_i32("FixedSizeBinary")?; - self.expect_token(Token::RParen)?; - Ok(DataType::FixedSizeBinary(length)) - } - - /// Parses the next Decimal128 (called after `Decimal128` has been consumed) - fn parse_decimal_128(&mut self) -> Result { - self.expect_token(Token::LParen)?; - let precision = self.parse_u8("Decimal128")?; - self.expect_token(Token::Comma)?; - let scale = self.parse_i8("Decimal128")?; - self.expect_token(Token::RParen)?; - Ok(DataType::Decimal128(precision, scale)) - } - - /// Parses the next Decimal256 (called after `Decimal256` has been consumed) - fn parse_decimal_256(&mut self) -> Result { - self.expect_token(Token::LParen)?; - let precision = self.parse_u8("Decimal256")?; - self.expect_token(Token::Comma)?; - let scale = self.parse_i8("Decimal256")?; - self.expect_token(Token::RParen)?; - Ok(DataType::Decimal256(precision, scale)) - } - - /// Parses the next Dictionary (called after `Dictionary` has been consumed) - fn parse_dictionary(&mut self) -> Result { - self.expect_token(Token::LParen)?; - let key_type = self.parse_next_type()?; - self.expect_token(Token::Comma)?; - let value_type = self.parse_next_type()?; - self.expect_token(Token::RParen)?; - Ok(DataType::Dictionary( - Box::new(key_type), - Box::new(value_type), - )) - } - /// return the next token, or an error if there are none left - fn next_token(&mut self) -> Result { - match self.tokenizer.next() { - None => Err(make_error(self.val, "finding next token")), - Some(token) => token, - } - } - - /// consume the next token, returning OK(()) if it matches tok, and Err if not - fn expect_token(&mut self, tok: Token) -> Result<()> { - let next_token = self.next_token()?; - if next_token == tok { - Ok(()) - } else { - Err(make_error_expected(self.val, &tok, &next_token)) - } - } -} - -/// returns true if this character is a separator -fn is_separator(c: char) -> bool { - c == '(' || c == ')' || c == ',' || c == ' ' -} - -#[derive(Debug)] -/// Splits a strings like Dictionary(Int32, Int64) into tokens sutable for parsing -/// -/// For example the string "Timestamp(Nanosecond, None)" would be parsed into: -/// -/// * Token::Timestamp -/// * Token::Lparen -/// * Token::IntervalUnit(IntervalUnit::Nanosecond) -/// * Token::Comma, -/// * Token::None, -/// * Token::Rparen, -struct Tokenizer<'a> { - val: &'a str, - chars: Peekable>, - // temporary buffer for parsing words - word: String, -} - -impl<'a> Tokenizer<'a> { - fn new(val: &'a str) -> Self { - Self { - val, - chars: val.chars().peekable(), - word: String::new(), - } - } - - /// returns the next char, without consuming it - fn peek_next_char(&mut self) -> Option { - self.chars.peek().copied() - } - - /// returns the next char, and consuming it - fn next_char(&mut self) -> Option { - self.chars.next() - } - - /// parse the characters in val starting at pos, until the next - /// `,`, `(`, or `)` or end of line - fn parse_word(&mut self) -> Result { - // reset temp space - self.word.clear(); - loop { - match self.peek_next_char() { - None => break, - Some(c) if is_separator(c) => break, - Some(c) => { - self.next_char(); - self.word.push(c); - } - } - } - - if let Some(c) = self.word.chars().next() { - // if it started with a number, try parsing it as an integer - if c == '-' || c.is_numeric() { - let val: i64 = self.word.parse().map_err(|e| { - make_error( - self.val, - &format!("parsing {} as integer: {e}", self.word), - ) - })?; - return Ok(Token::Integer(val)); - } - // if it started with a double quote `"`, try parsing it as a double quoted string - else if c == '"' { - let len = self.word.chars().count(); - - // to verify it's double quoted - if let Some(last_c) = self.word.chars().last() { - if last_c != '"' || len < 2 { - return Err(make_error( - self.val, - &format!("parsing {} as double quoted string: last char must be \"", self.word), - )); - } - } - - if len == 2 { - return Err(make_error( - self.val, - &format!("parsing {} as double quoted string: empty string isn't supported", self.word), - )); - } - - let val: String = self.word.parse().map_err(|e| { - make_error( - self.val, - &format!("parsing {} as double quoted string: {e}", self.word), - ) - })?; - - let s = val[1..len - 1].to_string(); - if s.contains('"') { - return Err(make_error( - self.val, - &format!("parsing {} as double quoted string: escaped double quote isn't supported", self.word), - )); - } - - return Ok(Token::DoubleQuotedString(s)); - } - } - - // figure out what the word was - let token = match self.word.as_str() { - "Null" => Token::SimpleType(DataType::Null), - "Boolean" => Token::SimpleType(DataType::Boolean), - - "Int8" => Token::SimpleType(DataType::Int8), - "Int16" => Token::SimpleType(DataType::Int16), - "Int32" => Token::SimpleType(DataType::Int32), - "Int64" => Token::SimpleType(DataType::Int64), - - "UInt8" => Token::SimpleType(DataType::UInt8), - "UInt16" => Token::SimpleType(DataType::UInt16), - "UInt32" => Token::SimpleType(DataType::UInt32), - "UInt64" => Token::SimpleType(DataType::UInt64), - - "Utf8" => Token::SimpleType(DataType::Utf8), - "LargeUtf8" => Token::SimpleType(DataType::LargeUtf8), - "Binary" => Token::SimpleType(DataType::Binary), - "LargeBinary" => Token::SimpleType(DataType::LargeBinary), - - "Float16" => Token::SimpleType(DataType::Float16), - "Float32" => Token::SimpleType(DataType::Float32), - "Float64" => Token::SimpleType(DataType::Float64), - - "Date32" => Token::SimpleType(DataType::Date32), - "Date64" => Token::SimpleType(DataType::Date64), - - "List" => Token::List, - "LargeList" => Token::LargeList, - "FixedSizeList" => Token::FixedSizeList, - - "Second" => Token::TimeUnit(TimeUnit::Second), - "Millisecond" => Token::TimeUnit(TimeUnit::Millisecond), - "Microsecond" => Token::TimeUnit(TimeUnit::Microsecond), - "Nanosecond" => Token::TimeUnit(TimeUnit::Nanosecond), - - "Timestamp" => Token::Timestamp, - "Time32" => Token::Time32, - "Time64" => Token::Time64, - "Duration" => Token::Duration, - "Interval" => Token::Interval, - "Dictionary" => Token::Dictionary, - - "FixedSizeBinary" => Token::FixedSizeBinary, - "Decimal128" => Token::Decimal128, - "Decimal256" => Token::Decimal256, - - "YearMonth" => Token::IntervalUnit(IntervalUnit::YearMonth), - "DayTime" => Token::IntervalUnit(IntervalUnit::DayTime), - "MonthDayNano" => Token::IntervalUnit(IntervalUnit::MonthDayNano), - - "Some" => Token::Some, - "None" => Token::None, - - _ => { - return Err(make_error( - self.val, - &format!("unrecognized word: {}", self.word), - )) - } - }; - Ok(token) - } -} - -impl<'a> Iterator for Tokenizer<'a> { - type Item = Result; - - fn next(&mut self) -> Option { - loop { - match self.peek_next_char()? { - ' ' => { - // skip whitespace - self.next_char(); - continue; - } - '(' => { - self.next_char(); - return Some(Ok(Token::LParen)); - } - ')' => { - self.next_char(); - return Some(Ok(Token::RParen)); - } - ',' => { - self.next_char(); - return Some(Ok(Token::Comma)); - } - _ => return Some(self.parse_word()), - } - } - } -} - -/// Grammar is -/// -#[derive(Debug, PartialEq)] -enum Token { - // Null, or Int32 - SimpleType(DataType), - Timestamp, - Time32, - Time64, - Duration, - Interval, - FixedSizeBinary, - Decimal128, - Decimal256, - Dictionary, - TimeUnit(TimeUnit), - IntervalUnit(IntervalUnit), - LParen, - RParen, - Comma, - Some, - None, - Integer(i64), - DoubleQuotedString(String), - List, - LargeList, - FixedSizeList, -} - -impl Display for Token { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Token::SimpleType(t) => write!(f, "{t}"), - Token::List => write!(f, "List"), - Token::LargeList => write!(f, "LargeList"), - Token::FixedSizeList => write!(f, "FixedSizeList"), - Token::Timestamp => write!(f, "Timestamp"), - Token::Time32 => write!(f, "Time32"), - Token::Time64 => write!(f, "Time64"), - Token::Duration => write!(f, "Duration"), - Token::Interval => write!(f, "Interval"), - Token::TimeUnit(u) => write!(f, "TimeUnit({u:?})"), - Token::IntervalUnit(u) => write!(f, "IntervalUnit({u:?})"), - Token::LParen => write!(f, "("), - Token::RParen => write!(f, ")"), - Token::Comma => write!(f, ","), - Token::Some => write!(f, "Some"), - Token::None => write!(f, "None"), - Token::FixedSizeBinary => write!(f, "FixedSizeBinary"), - Token::Decimal128 => write!(f, "Decimal128"), - Token::Decimal256 => write!(f, "Decimal256"), - Token::Dictionary => write!(f, "Dictionary"), - Token::Integer(v) => write!(f, "Integer({v})"), - Token::DoubleQuotedString(s) => write!(f, "DoubleQuotedString({s})"), - } - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_parse_data_type() { - // this ensures types can be parsed correctly from their string representations - for dt in list_datatypes() { - round_trip(dt) - } - } - - /// convert data_type to a string, and then parse it as a type - /// verifying it is the same - fn round_trip(data_type: DataType) { - let data_type_string = data_type.to_string(); - println!("Input '{data_type_string}' ({data_type:?})"); - let parsed_type = parse_data_type(&data_type_string).unwrap(); - assert_eq!( - data_type, parsed_type, - "Mismatch parsing {data_type_string}" - ); - } - - fn list_datatypes() -> Vec { - vec![ - // --------- - // Non Nested types - // --------- - DataType::Null, - DataType::Boolean, - DataType::Int8, - DataType::Int16, - DataType::Int32, - DataType::Int64, - DataType::UInt8, - DataType::UInt16, - DataType::UInt32, - DataType::UInt64, - DataType::Float16, - DataType::Float32, - DataType::Float64, - DataType::Timestamp(TimeUnit::Second, None), - DataType::Timestamp(TimeUnit::Millisecond, None), - DataType::Timestamp(TimeUnit::Microsecond, None), - DataType::Timestamp(TimeUnit::Nanosecond, None), - // we can't cover all possible timezones, here we only test utc and +08:00 - DataType::Timestamp(TimeUnit::Nanosecond, Some("+00:00".into())), - DataType::Timestamp(TimeUnit::Microsecond, Some("+00:00".into())), - DataType::Timestamp(TimeUnit::Millisecond, Some("+00:00".into())), - DataType::Timestamp(TimeUnit::Second, Some("+00:00".into())), - DataType::Timestamp(TimeUnit::Nanosecond, Some("+08:00".into())), - DataType::Timestamp(TimeUnit::Microsecond, Some("+08:00".into())), - DataType::Timestamp(TimeUnit::Millisecond, Some("+08:00".into())), - DataType::Timestamp(TimeUnit::Second, Some("+08:00".into())), - DataType::Date32, - DataType::Date64, - DataType::Time32(TimeUnit::Second), - DataType::Time32(TimeUnit::Millisecond), - DataType::Time32(TimeUnit::Microsecond), - DataType::Time32(TimeUnit::Nanosecond), - DataType::Time64(TimeUnit::Second), - DataType::Time64(TimeUnit::Millisecond), - DataType::Time64(TimeUnit::Microsecond), - DataType::Time64(TimeUnit::Nanosecond), - DataType::Duration(TimeUnit::Second), - DataType::Duration(TimeUnit::Millisecond), - DataType::Duration(TimeUnit::Microsecond), - DataType::Duration(TimeUnit::Nanosecond), - DataType::Interval(IntervalUnit::YearMonth), - DataType::Interval(IntervalUnit::DayTime), - DataType::Interval(IntervalUnit::MonthDayNano), - DataType::Binary, - DataType::FixedSizeBinary(0), - DataType::FixedSizeBinary(1234), - DataType::FixedSizeBinary(-432), - DataType::LargeBinary, - DataType::Utf8, - DataType::LargeUtf8, - DataType::Decimal128(7, 12), - DataType::Decimal256(6, 13), - // --------- - // Nested types - // --------- - DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), - DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)), - DataType::Dictionary( - Box::new(DataType::Int8), - Box::new(DataType::Timestamp(TimeUnit::Nanosecond, None)), - ), - DataType::Dictionary( - Box::new(DataType::Int8), - Box::new(DataType::FixedSizeBinary(23)), - ), - DataType::Dictionary( - Box::new(DataType::Int8), - Box::new( - // nested dictionaries are probably a bad idea but they are possible - DataType::Dictionary( - Box::new(DataType::Int8), - Box::new(DataType::Utf8), - ), - ), - ), - // TODO support more structured types (List, LargeList, Struct, Union, Map, RunEndEncoded, etc) - ] - } - - #[test] - fn test_parse_data_type_whitespace_tolerance() { - // (string to parse, expected DataType) - let cases = [ - ("Int8", DataType::Int8), - ( - "Timestamp (Nanosecond, None)", - DataType::Timestamp(TimeUnit::Nanosecond, None), - ), - ( - "Timestamp (Nanosecond, None) ", - DataType::Timestamp(TimeUnit::Nanosecond, None), - ), - ( - " Timestamp (Nanosecond, None )", - DataType::Timestamp(TimeUnit::Nanosecond, None), - ), - ( - "Timestamp (Nanosecond, None ) ", - DataType::Timestamp(TimeUnit::Nanosecond, None), - ), - ]; - - for (data_type_string, expected_data_type) in cases { - println!("Parsing '{data_type_string}', expecting '{expected_data_type:?}'"); - let parsed_data_type = parse_data_type(data_type_string).unwrap(); - assert_eq!(parsed_data_type, expected_data_type); - } - } - - #[test] - fn parse_data_type_errors() { - // (string to parse, expected error message) - let cases = [ - ("", "Unsupported type ''"), - ("", "Error finding next token"), - ("null", "Unsupported type 'null'"), - ("Nu", "Unsupported type 'Nu'"), - ( - r#"Timestamp(Nanosecond, Some(+00:00))"#, - "Error unrecognized word: +00:00", - ), - ( - r#"Timestamp(Nanosecond, Some("+00:00))"#, - r#"parsing "+00:00 as double quoted string: last char must be ""#, - ), - ( - r#"Timestamp(Nanosecond, Some(""))"#, - r#"parsing "" as double quoted string: empty string isn't supported"#, - ), - ( - r#"Timestamp(Nanosecond, Some("+00:00""))"#, - r#"parsing "+00:00"" as double quoted string: escaped double quote isn't supported"#, - ), - ("Timestamp(Nanosecond, ", "Error finding next token"), - ( - "Float32 Float32", - "trailing content after parsing 'Float32'", - ), - ("Int32, ", "trailing content after parsing 'Int32'"), - ("Int32(3), ", "trailing content after parsing 'Int32'"), - ("FixedSizeBinary(Int32), ", "Error finding i64 for FixedSizeBinary, got 'Int32'"), - ("FixedSizeBinary(3.0), ", "Error parsing 3.0 as integer: invalid digit found in string"), - // too large for i32 - ("FixedSizeBinary(4000000000), ", "Error converting 4000000000 into i32 for FixedSizeBinary: out of range integral type conversion attempted"), - // can't have negative precision - ("Decimal128(-3, 5)", "Error converting -3 into u8 for Decimal128: out of range integral type conversion attempted"), - ("Decimal256(-3, 5)", "Error converting -3 into u8 for Decimal256: out of range integral type conversion attempted"), - ("Decimal128(3, 500)", "Error converting 500 into i8 for Decimal128: out of range integral type conversion attempted"), - ("Decimal256(3, 500)", "Error converting 500 into i8 for Decimal256: out of range integral type conversion attempted"), - - ]; - - for (data_type_string, expected_message) in cases { - print!("Parsing '{data_type_string}', expecting '{expected_message}'"); - match parse_data_type(data_type_string) { - Ok(d) => panic!( - "Expected error while parsing '{data_type_string}', but got '{d}'" - ), - Err(e) => { - let message = e.to_string(); - assert!( - message.contains(expected_message), - "\n\ndid not find expected in actual.\n\nexpected: {expected_message}\nactual:{message}\n" - ); - // errors should also contain a help message - assert!(message.contains("Must be a supported arrow type name such as 'Int32' or 'Timestamp(Nanosecond, None)'")); - } - } - } - } + val.parse().map_err(|e| match e { + // If the data type cannot be parsed, return a Plan error to signal an + // error in the input rather than a more general ArrowError + arrow::error::ArrowError::ParseError(e) => plan_datafusion_err!("{e}"), + e => arrow_datafusion_err!(e), + }) } diff --git a/datafusion/functions/src/core/arrowtypeof.rs b/datafusion/functions/src/core/arrowtypeof.rs index cc5e7e619bd8..a425aff6caad 100644 --- a/datafusion/functions/src/core/arrowtypeof.rs +++ b/datafusion/functions/src/core/arrowtypeof.rs @@ -17,9 +17,11 @@ use arrow::datatypes::DataType; use datafusion_common::{exec_err, Result, ScalarValue}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_OTHER; +use datafusion_expr::{ColumnarValue, Documentation}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; +use std::sync::OnceLock; #[derive(Debug)] pub struct ArrowTypeOfFunc { @@ -69,4 +71,35 @@ impl ScalarUDFImpl for ArrowTypeOfFunc { "{input_data_type}" )))) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_arrowtypeof_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_arrowtypeof_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_OTHER) + .with_description( + "Returns the name of the underlying [Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) of the expression.", + ) + .with_syntax_example("arrow_typeof(expression)") + .with_sql_example( + r#"```sql +> select arrow_typeof('foo'), arrow_typeof(1); ++---------------------------+------------------------+ +| arrow_typeof(Utf8("foo")) | arrow_typeof(Int64(1)) | ++---------------------------+------------------------+ +| Utf8 | Int64 | ++---------------------------+------------------------+ +``` +"#, + ) + .with_argument("expression", "Expression to evaluate. The expression can be a constant, column, or function, and any combination of operators.") + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/core/coalesce.rs b/datafusion/functions/src/core/coalesce.rs index 76f2a3ed741b..a05f3f08232c 100644 --- a/datafusion/functions/src/core/coalesce.rs +++ b/datafusion/functions/src/core/coalesce.rs @@ -15,17 +15,18 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; - use arrow::array::{new_null_array, BooleanArray}; use arrow::compute::kernels::zip::zip; use arrow::compute::{and, is_not_null, is_null}; use arrow::datatypes::DataType; - -use datafusion_common::{exec_err, Result}; -use datafusion_expr::type_coercion::functions::data_types; -use datafusion_expr::ColumnarValue; +use datafusion_common::{exec_err, ExprSchema, Result}; +use datafusion_expr::binary::try_type_union_resolution; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_CONDITIONAL; +use datafusion_expr::{ColumnarValue, Documentation, Expr, ExprSchemable}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use itertools::Itertools; +use std::any::Any; +use std::sync::OnceLock; #[derive(Debug)] pub struct CoalesceFunc { @@ -41,7 +42,7 @@ impl Default for CoalesceFunc { impl CoalesceFunc { pub fn new() -> Self { Self { - signature: Signature::variadic_equal(Volatility::Immutable), + signature: Signature::user_defined(Volatility::Immutable), } } } @@ -60,9 +61,16 @@ impl ScalarUDFImpl for CoalesceFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - // COALESCE has multiple args and they might get coerced, get a preview of this - let coerced_types = data_types(arg_types, self.signature()); - coerced_types.map(|types| types[0].clone()) + Ok(arg_types + .iter() + .find_or_first(|d| !d.is_null()) + .unwrap() + .clone()) + } + + // If any the arguments in coalesce is non-null, the result is non-null + fn is_nullable(&self, args: &[Expr], schema: &dyn ExprSchema) -> bool { + args.iter().all(|e| e.nullable(schema).ok().unwrap_or(true)) } /// coalesce evaluates to the first value which is not NULL @@ -124,6 +132,44 @@ impl ScalarUDFImpl for CoalesceFunc { fn short_circuits(&self) -> bool { true } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.is_empty() { + return exec_err!("coalesce must have at least one argument"); + } + + try_type_union_resolution(arg_types) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_coalesce_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_coalesce_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_CONDITIONAL) + .with_description("Returns the first of its arguments that is not _null_. Returns _null_ if all arguments are _null_. This function is often used to substitute a default value for _null_ values.") + .with_syntax_example("coalesce(expression1[, ..., expression_n])") + .with_sql_example(r#"```sql +> select coalesce(null, null, 'datafusion'); ++----------------------------------------+ +| coalesce(NULL,NULL,Utf8("datafusion")) | ++----------------------------------------+ +| datafusion | ++----------------------------------------+ +```"#, + ) + .with_argument( + "expression1, expression_n", + "Expression to use if previous expressions are _null_. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary." + ) + .build() + .unwrap() + }) } #[cfg(test)] @@ -142,4 +188,22 @@ mod test { .unwrap(); assert_eq!(return_type, DataType::Date32); } + + #[test] + fn test_coalesce_return_types_with_nulls_first() { + let coalesce = core::coalesce::CoalesceFunc::new(); + let return_type = coalesce + .return_type(&[DataType::Null, DataType::Date32]) + .unwrap(); + assert_eq!(return_type, DataType::Date32); + } + + #[test] + fn test_coalesce_return_types_with_nulls_last() { + let coalesce = core::coalesce::CoalesceFunc::new(); + let return_type = coalesce + .return_type(&[DataType::Int64, DataType::Null]) + .unwrap(); + assert_eq!(return_type, DataType::Int64); + } } diff --git a/datafusion/functions/src/core/expr_ext.rs b/datafusion/functions/src/core/expr_ext.rs new file mode 100644 index 000000000000..af05f447f1c1 --- /dev/null +++ b/datafusion/functions/src/core/expr_ext.rs @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Extension methods for Expr. + +use datafusion_expr::{Expr, Literal}; + +use super::expr_fn::get_field; + +/// Return access to the named field. Example `expr["name"]` +/// +/// ## Access field "my_field" from column "c1" +/// +/// For example if column "c1" holds documents like this +/// +/// ```json +/// { +/// "my_field": 123.34, +/// "other_field": "Boston", +/// } +/// ``` +/// +/// You can access column "my_field" with +/// +/// ``` +/// # use datafusion_expr::{col}; +/// # use datafusion_functions::core::expr_ext::FieldAccessor; +/// let expr = col("c1") +/// .field("my_field"); +/// assert_eq!(expr.schema_name().to_string(), "c1[my_field]"); +/// ``` +pub trait FieldAccessor { + fn field(self, name: impl Literal) -> Expr; +} + +impl FieldAccessor for Expr { + fn field(self, name: impl Literal) -> Expr { + get_field(self, name) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use datafusion_expr::col; + + #[test] + fn test_field() { + let expr1 = col("a").field("b"); + let expr2 = get_field(col("a"), "b"); + assert_eq!(expr1, expr2); + } +} diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index b00b8ea553f2..c0af4d35966b 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -15,14 +15,19 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Scalar, StringArray}; +use arrow::array::{ + make_array, Array, Capacities, MutableArrayData, Scalar, StringArray, +}; use arrow::datatypes::DataType; use datafusion_common::cast::{as_map_array, as_struct_array}; -use datafusion_common::{exec_err, ExprSchema, Result, ScalarValue}; -use datafusion_expr::field_util::GetFieldAccessSchema; -use datafusion_expr::{ColumnarValue, Expr, ExprSchemable}; +use datafusion_common::{ + exec_err, plan_datafusion_err, plan_err, ExprSchema, Result, ScalarValue, +}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_OTHER; +use datafusion_expr::{ColumnarValue, Documentation, Expr, ExprSchemable}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; +use std::sync::{Arc, OnceLock}; #[derive(Debug)] pub struct GetFieldFunc { @@ -48,10 +53,51 @@ impl ScalarUDFImpl for GetFieldFunc { fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { "get_field" } + fn display_name(&self, args: &[Expr]) -> Result { + if args.len() != 2 { + return exec_err!( + "get_field function requires 2 arguments, got {}", + args.len() + ); + } + + let name = match &args[1] { + Expr::Literal(name) => name, + _ => { + return exec_err!( + "get_field function requires the argument field_name to be a string" + ); + } + }; + + Ok(format!("{}[{}]", args[0], name)) + } + + fn schema_name(&self, args: &[Expr]) -> Result { + if args.len() != 2 { + return exec_err!( + "get_field function requires 2 arguments, got {}", + args.len() + ); + } + + let name = match &args[1] { + Expr::Literal(name) => name, + _ => { + return exec_err!( + "get_field function requires the argument field_name to be a string" + ); + } + }; + + Ok(format!("{}[{}]", args[0].schema_name(), name)) + } + fn signature(&self) -> &Signature { &self.signature } @@ -81,11 +127,37 @@ impl ScalarUDFImpl for GetFieldFunc { ); } }; - let access_schema = GetFieldAccessSchema::NamedStructField { name: name.clone() }; - let arg_dt = args[0].get_type(schema)?; - access_schema - .get_accessed_field(&arg_dt) - .map(|f| f.data_type().clone()) + let data_type = args[0].get_type(schema)?; + match (data_type, name) { + (DataType::Map(fields, _), _) => { + match fields.data_type() { + DataType::Struct(fields) if fields.len() == 2 => { + // Arrow's MapArray is essentially a ListArray of structs with two columns. They are + // often named "key", and "value", but we don't require any specific naming here; + // instead, we assume that the second column is the "value" column both here and in + // execution. + let value_field = fields.get(1).expect("fields should have exactly two members"); + Ok(value_field.data_type().clone()) + }, + _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), + } + } + (DataType::Struct(fields), ScalarValue::Utf8(Some(s))) => { + if s.is_empty() { + plan_err!( + "Struct based indexed access requires a non empty string" + ) + } else { + let field = fields.iter().find(|f| f.name() == s); + field.ok_or(plan_datafusion_err!("Field {s} not found in struct")).map(|f| f.data_type().clone()) + } + } + (DataType::Struct(_), _) => plan_err!( + "Only UTF8 strings are valid as an indexed field in a struct" + ), + (DataType::Null, _) => Ok(DataType::Null), + (other, _) => plan_err!("The expression to get an indexed field is only valid for `Struct`, `Map` or `Null` types, got {other}"), + } } fn invoke(&self, args: &[ColumnarValue]) -> Result { @@ -96,8 +168,12 @@ impl ScalarUDFImpl for GetFieldFunc { ); } + if args[0].data_type().is_null() { + return Ok(ColumnarValue::Scalar(ScalarValue::Null)); + } + let arrays = ColumnarValue::values_to_arrays(args)?; - let array = arrays[0].clone(); + let array = Arc::clone(&arrays[0]); let name = match &args[1] { ColumnarValue::Scalar(name) => name, @@ -107,29 +183,109 @@ impl ScalarUDFImpl for GetFieldFunc { ); } }; + match (array.data_type(), name) { - (DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => { - let map_array = as_map_array(array.as_ref())?; - let key_scalar = Scalar::new(StringArray::from(vec![k.clone()])); - let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?; - let entries = arrow::compute::filter(map_array.entries(), &keys)?; - let entries_struct_array = as_struct_array(entries.as_ref())?; - Ok(ColumnarValue::Array(entries_struct_array.column(1).clone())) - } - (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => { - let as_struct_array = as_struct_array(&array)?; - match as_struct_array.column_by_name(k) { - None => exec_err!( - "get indexed field {k} not found in struct"), - Some(col) => Ok(ColumnarValue::Array(col.clone())) + (DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => { + let map_array = as_map_array(array.as_ref())?; + let key_scalar: Scalar>> = Scalar::new(StringArray::from(vec![k.clone()])); + let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?; + + // note that this array has more entries than the expected output/input size + // because map_array is flattened + let original_data = map_array.entries().column(1).to_data(); + let capacity = Capacities::Array(original_data.len()); + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], true, + capacity); + + for entry in 0..map_array.len(){ + let start = map_array.value_offsets()[entry] as usize; + let end = map_array.value_offsets()[entry + 1] as usize; + + let maybe_matched = + keys.slice(start, end-start). + iter().enumerate(). + find(|(_, t)| t.unwrap()); + if maybe_matched.is_none() { + mutable.extend_nulls(1); + continue } + let (match_offset,_) = maybe_matched.unwrap(); + mutable.extend(0, start + match_offset, start + match_offset + 1); + } + let data = mutable.freeze(); + let data = make_array(data); + Ok(ColumnarValue::Array(data)) + } + (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => { + let as_struct_array = as_struct_array(&array)?; + match as_struct_array.column_by_name(k) { + None => exec_err!("get indexed field {k} not found in struct"), + Some(col) => Ok(ColumnarValue::Array(Arc::clone(col))), } - (DataType::Struct(_), name) => exec_err!( - "get indexed field is only possible on struct with utf8 indexes. \ - Tried with {name:?} index"), - (dt, name) => exec_err!( - "get indexed field is only possible on lists with int64 indexes or struct \ - with utf8 indexes. Tried {dt:?} with {name:?} index"), } + (DataType::Struct(_), name) => exec_err!( + "get_field is only possible on struct with utf8 indexes. \ + Received with {name:?} index" + ), + (DataType::Null, _) => Ok(ColumnarValue::Scalar(ScalarValue::Null)), + (dt, name) => exec_err!( + "get_field is only possible on maps with utf8 indexes or struct \ + with utf8 indexes. Received {dt:?} with {name:?} index" + ), + } + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_getfield_doc()) } } + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_getfield_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_OTHER) + .with_description(r#"Returns a field within a map or a struct with the given key. +Note: most users invoke `get_field` indirectly via field access +syntax such as `my_struct_col['field_name']` which results in a call to +`get_field(my_struct_col, 'field_name')`."#) + .with_syntax_example("get_field(expression1, expression2)") + .with_sql_example(r#"```sql +> create table t (idx varchar, v varchar) as values ('data','fusion'), ('apache', 'arrow'); +> select struct(idx, v) from t as c; ++-------------------------+ +| struct(c.idx,c.v) | ++-------------------------+ +| {c0: data, c1: fusion} | +| {c0: apache, c1: arrow} | ++-------------------------+ +> select get_field((select struct(idx, v) from t), 'c0'); ++-----------------------+ +| struct(t.idx,t.v)[c0] | ++-----------------------+ +| data | +| apache | ++-----------------------+ +> select get_field((select struct(idx, v) from t), 'c1'); ++-----------------------+ +| struct(t.idx,t.v)[c1] | ++-----------------------+ +| fusion | +| arrow | ++-----------------------+ +``` + "#) + .with_argument( + "expression1", + "The map or struct to retrieve a field for." + ) + .with_argument( + "expression2", + "The field name in the map or struct to retrieve data for. Must evaluate to a string." + ) + .build() + .unwrap() + }) +} diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index 753134bdfdc2..cf64c03766cb 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -17,15 +17,21 @@ //! "core" DataFusion functions +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + pub mod arrow_cast; pub mod arrowtypeof; pub mod coalesce; +pub mod expr_ext; pub mod getfield; pub mod named_struct; pub mod nullif; pub mod nvl; pub mod nvl2; +pub mod planner; pub mod r#struct; +pub mod version; // create UDFs make_udf_function!(arrow_cast::ArrowCastFunc, ARROW_CAST, arrow_cast); @@ -37,16 +43,70 @@ make_udf_function!(r#struct::StructFunc, STRUCT, r#struct); make_udf_function!(named_struct::NamedStructFunc, NAMED_STRUCT, named_struct); make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field); make_udf_function!(coalesce::CoalesceFunc, COALESCE, coalesce); +make_udf_function!(version::VersionFunc, VERSION, version); + +pub mod expr_fn { + use datafusion_expr::{Expr, Literal}; + + export_functions!(( + nullif, + "Returns NULL if value1 equals value2; otherwise it returns value1. This can be used to perform the inverse operation of the COALESCE expression", + arg1 arg2 + ),( + arrow_cast, + "Returns value2 if value1 is NULL; otherwise it returns value1", + arg1 arg2 + ),( + nvl, + "Returns value2 if value1 is NULL; otherwise it returns value1", + arg1 arg2 + ),( + nvl2, + "Returns value2 if value1 is not NULL; otherwise, it returns value3.", + arg1 arg2 arg3 + ),( + arrow_typeof, + "Returns the Arrow type of the input expression.", + arg1 + ),( + r#struct, + "Returns a struct with the given arguments", + args, + ),( + named_struct, + "Returns a struct with the given names and arguments pairs", + args, + ),( + coalesce, + "Returns `coalesce(args...)`, which evaluates to the value of the first expr which is not NULL", + args, + )); + + #[doc = "Returns the value of the field with the given name from the struct"] + pub fn get_field(arg1: Expr, arg2: impl Literal) -> Expr { + super::get_field().call(vec![arg1, arg2.lit()]) + } +} -// Export the functions out of this package, both as expr_fn as well as a list of functions -export_functions!( - (nullif, arg_1 arg_2, "returns NULL if value1 equals value2; otherwise it returns value1. This can be used to perform the inverse operation of the COALESCE expression."), - (arrow_cast, arg_1 arg_2, "returns arg_1 cast to the `arrow_type` given the second argument. This can be used to cast to a specific `arrow_type`."), - (nvl, arg_1 arg_2, "returns value2 if value1 is NULL; otherwise it returns value1"), - (nvl2, arg_1 arg_2 arg_3, "Returns value2 if value1 is not NULL; otherwise, it returns value3."), - (arrow_typeof, arg_1, "Returns the Arrow type of the input expression."), - (r#struct, args, "Returns a struct with the given arguments"), - (named_struct, args, "Returns a struct with the given names and arguments pairs"), - (get_field, arg_1 arg_2, "Returns the value of the field with the given name from the struct"), - (coalesce, args, "Returns `coalesce(args...)`, which evaluates to the value of the first expr which is not NULL") -); +/// Returns all DataFusion functions defined in this package +pub fn functions() -> Vec> { + vec![ + nullif(), + arrow_cast(), + nvl(), + nvl2(), + arrow_typeof(), + named_struct(), + // Note: most users invoke `get_field` indirectly via field access + // syntax like `my_struct_col['field_name']`, which results in a call to + // `get_field(my_struct_col, "field_name")`. + // + // However, it is also exposed directly for use cases such as + // serializing / deserializing plans with the field access desugared to + // calls to `get_field` + get_field(), + coalesce(), + version(), + r#struct(), + ] +} diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index 8ccda977f3a4..b2c7f06d5868 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -18,10 +18,12 @@ use arrow::array::StructArray; use arrow::datatypes::{DataType, Field, Fields}; use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; -use datafusion_expr::{ColumnarValue, Expr, ExprSchemable}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRUCT; +use datafusion_expr::{ColumnarValue, Documentation, Expr, ExprSchemable}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use hashbrown::HashSet; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; /// put values in a struct array. fn named_struct_expr(args: &[ColumnarValue]) -> Result { @@ -45,7 +47,6 @@ fn named_struct_expr(args: &[ColumnarValue]) -> Result { .map(|(i, chunk)| { let name_column = &chunk[0]; - let name = match name_column { ColumnarValue::Scalar(ScalarValue::Utf8(Some(name_scalar))) => name_scalar, _ => return exec_err!("named_struct even arguments must be string literals, got {name_column:?} instead at position {}", i * 2) @@ -57,20 +58,30 @@ fn named_struct_expr(args: &[ColumnarValue]) -> Result { .into_iter() .unzip(); - let arrays = ColumnarValue::values_to_arrays(&values)?; + { + // Check to enforce the uniqueness of struct field name + let mut unique_field_names = HashSet::new(); + for name in names.iter() { + if unique_field_names.contains(name) { + return exec_err!( + "named_struct requires unique field names. Field {name} is used more than once." + ); + } + unique_field_names.insert(name); + } + } - let fields = names + let fields: Fields = names .into_iter() - .zip(arrays) - .map(|(name, value)| { - ( - Arc::new(Field::new(name, value.data_type().clone(), true)), - value, - ) - }) - .collect::>(); + .zip(&values) + .map(|(name, value)| Arc::new(Field::new(name, value.data_type().clone(), true))) + .collect::>() + .into(); - Ok(ColumnarValue::Array(Arc::new(StructArray::from(fields)))) + let arrays = ColumnarValue::values_to_arrays(&values)?; + + let struct_array = StructArray::new(fields, arrays, None); + Ok(ColumnarValue::Array(Arc::new(struct_array))) } #[derive(Debug)] @@ -113,7 +124,7 @@ impl ScalarUDFImpl for NamedStructFunc { fn return_type_from_exprs( &self, - args: &[datafusion_expr::Expr], + args: &[Expr], schema: &dyn datafusion_common::ExprSchema, _arg_types: &[DataType], ) -> Result { @@ -151,4 +162,46 @@ impl ScalarUDFImpl for NamedStructFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { named_struct_expr(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_named_struct_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_named_struct_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRUCT) + .with_description("Returns an Arrow struct using the specified name and input expressions pairs.") + .with_syntax_example("named_struct(expression1_name, expression1_input[, ..., expression_n_name, expression_n_input])") + .with_sql_example(r#" +For example, this query converts two columns `a` and `b` to a single column with +a struct type of fields `field_a` and `field_b`: +```sql +> select * from t; ++---+---+ +| a | b | ++---+---+ +| 1 | 2 | +| 3 | 4 | ++---+---+ +> select named_struct('field_a', a, 'field_b', b) from t; ++-------------------------------------------------------+ +| named_struct(Utf8("field_a"),t.a,Utf8("field_b"),t.b) | ++-------------------------------------------------------+ +| {field_a: 1, field_b: 2} | +| {field_a: 3, field_b: 4} | ++-------------------------------------------------------+ +``` +"#) + .with_argument( + "expression_n_name", + "Name of the column field. Must be a constant string." + ) + .with_argument("expression_n_input", "Expression to include in the output struct. Can be a constant, column, or function, and any combination of arithmetic or string operators.") + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/core/nullif.rs b/datafusion/functions/src/core/nullif.rs index e8bf2db514c3..f96ee1ea7a12 100644 --- a/datafusion/functions/src/core/nullif.rs +++ b/datafusion/functions/src/core/nullif.rs @@ -17,14 +17,15 @@ use arrow::datatypes::DataType; use datafusion_common::{exec_err, Result}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::{ColumnarValue, Documentation}; -use arrow::array::Array; use arrow::compute::kernels::cmp::eq; use arrow::compute::kernels::nullif::nullif; use datafusion_common::ScalarValue; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_CONDITIONAL; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; +use std::sync::OnceLock; #[derive(Debug)] pub struct NullIfFunc { @@ -94,6 +95,47 @@ impl ScalarUDFImpl for NullIfFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { nullif_func(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_nullif_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_nullif_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_CONDITIONAL) + .with_description("Returns _null_ if _expression1_ equals _expression2_; otherwise it returns _expression1_. +This can be used to perform the inverse operation of [`coalesce`](#coalesce).") + .with_syntax_example("nullif(expression1, expression2)") + .with_sql_example(r#"```sql +> select nullif('datafusion', 'data'); ++-----------------------------------------+ +| nullif(Utf8("datafusion"),Utf8("data")) | ++-----------------------------------------+ +| datafusion | ++-----------------------------------------+ +> select nullif('datafusion', 'datafusion'); ++-----------------------------------------------+ +| nullif(Utf8("datafusion"),Utf8("datafusion")) | ++-----------------------------------------------+ +| | ++-----------------------------------------------+ +``` +"#) + .with_argument( + "expression1", + "Expression to compare and return if equal to expression2. Can be a constant, column, or function, and any combination of operators." + ) + .with_argument( + "expression2", + "Expression to compare to expression1. Can be a constant, column, or function, and any combination of operators." + ) + .build() + .unwrap() + }) } /// Implements NULLIF(expr1, expr2) @@ -122,8 +164,13 @@ fn nullif_func(args: &[ColumnarValue]) -> Result { Ok(ColumnarValue::Array(array)) } (ColumnarValue::Scalar(lhs), ColumnarValue::Array(rhs)) => { - let lhs = lhs.to_array_of_size(rhs.len())?; - let array = nullif(&lhs, &eq(&lhs, &rhs)?)?; + let lhs_s = lhs.to_scalar()?; + let lhs_a = lhs.to_array_of_size(rhs.len())?; + let array = nullif( + // nullif in arrow-select does not support Datum, so we need to convert to array + lhs_a.as_ref(), + &eq(&lhs_s, &rhs)?, + )?; Ok(ColumnarValue::Array(array)) } (ColumnarValue::Scalar(lhs), ColumnarValue::Scalar(rhs)) => { diff --git a/datafusion/functions/src/core/nvl.rs b/datafusion/functions/src/core/nvl.rs index 05515c6e925c..16438e1b6254 100644 --- a/datafusion/functions/src/core/nvl.rs +++ b/datafusion/functions/src/core/nvl.rs @@ -20,7 +20,11 @@ use arrow::compute::is_not_null; use arrow::compute::kernels::zip::zip; use arrow::datatypes::DataType; use datafusion_common::{internal_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_CONDITIONAL; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use std::sync::{Arc, OnceLock}; #[derive(Debug)] pub struct NVLFunc { @@ -90,6 +94,46 @@ impl ScalarUDFImpl for NVLFunc { fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_nvl_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_nvl_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_CONDITIONAL) + .with_description("Returns _expression2_ if _expression1_ is NULL otherwise it returns _expression1_.") + .with_syntax_example("nvl(expression1, expression2)") + .with_sql_example(r#"```sql +> select nvl(null, 'a'); ++---------------------+ +| nvl(NULL,Utf8("a")) | ++---------------------+ +| a | ++---------------------+\ +> select nvl('b', 'a'); ++--------------------------+ +| nvl(Utf8("b"),Utf8("a")) | ++--------------------------+ +| b | ++--------------------------+ +``` +"#) + .with_argument( + "expression1", + "Expression to return if not null. Can be a constant, column, or function, and any combination of operators." + ) + .with_argument( + "expression2", + "Expression to return if expr1 is null. Can be a constant, column, or function, and any combination of operators." + ) + .build() + .unwrap() + }) } fn nvl_func(args: &[ColumnarValue]) -> Result { @@ -101,13 +145,13 @@ fn nvl_func(args: &[ColumnarValue]) -> Result { } let (lhs_array, rhs_array) = match (&args[0], &args[1]) { (ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => { - (lhs.clone(), rhs.to_array_of_size(lhs.len())?) + (Arc::clone(lhs), rhs.to_array_of_size(lhs.len())?) } (ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)) => { - (lhs.clone(), rhs.clone()) + (Arc::clone(lhs), Arc::clone(rhs)) } (ColumnarValue::Scalar(lhs), ColumnarValue::Array(rhs)) => { - (lhs.to_array_of_size(rhs.len())?, rhs.clone()) + (lhs.to_array_of_size(rhs.len())?, Arc::clone(rhs)) } (ColumnarValue::Scalar(lhs), ColumnarValue::Scalar(rhs)) => { let mut current_value = lhs; diff --git a/datafusion/functions/src/core/nvl2.rs b/datafusion/functions/src/core/nvl2.rs index 66b9ef566a78..cfcdb4480787 100644 --- a/datafusion/functions/src/core/nvl2.rs +++ b/datafusion/functions/src/core/nvl2.rs @@ -19,8 +19,13 @@ use arrow::array::Array; use arrow::compute::is_not_null; use arrow::compute::kernels::zip::zip; use arrow::datatypes::DataType; -use datafusion_common::{internal_err, plan_datafusion_err, Result}; -use datafusion_expr::{utils, ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_common::{exec_err, internal_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_CONDITIONAL; +use datafusion_expr::{ + type_coercion::binary::comparison_coercion, ColumnarValue, Documentation, + ScalarUDFImpl, Signature, Volatility, +}; +use std::sync::{Arc, OnceLock}; #[derive(Debug)] pub struct NVL2Func { @@ -36,7 +41,7 @@ impl Default for NVL2Func { impl NVL2Func { pub fn new() -> Self { Self { - signature: Signature::variadic_equal(Volatility::Immutable), + signature: Signature::user_defined(Volatility::Immutable), } } } @@ -55,22 +60,81 @@ impl ScalarUDFImpl for NVL2Func { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types.len() != 3 { - return Err(plan_datafusion_err!( - "{}", - utils::generate_signature_error_msg( - self.name(), - self.signature().clone(), - arg_types, - ) - )); - } Ok(arg_types[1].clone()) } fn invoke(&self, args: &[ColumnarValue]) -> Result { nvl2_func(args) } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 3 { + return exec_err!( + "NVL2 takes exactly three arguments, but got {}", + arg_types.len() + ); + } + let new_type = arg_types.iter().skip(1).try_fold( + arg_types.first().unwrap().clone(), + |acc, x| { + // The coerced types found by `comparison_coercion` are not guaranteed to be + // coercible for the arguments. `comparison_coercion` returns more loose + // types that can be coerced to both `acc` and `x` for comparison purpose. + // See `maybe_data_types` for the actual coercion. + let coerced_type = comparison_coercion(&acc, x); + if let Some(coerced_type) = coerced_type { + Ok(coerced_type) + } else { + internal_err!("Coercion from {acc:?} to {x:?} failed.") + } + }, + )?; + Ok(vec![new_type; arg_types.len()]) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_nvl2_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_nvl2_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_CONDITIONAL) + .with_description("Returns _expression2_ if _expression1_ is not NULL; otherwise it returns _expression3_.") + .with_syntax_example("nvl2(expression1, expression2, expression3)") + .with_sql_example(r#"```sql +> select nvl2(null, 'a', 'b'); ++--------------------------------+ +| nvl2(NULL,Utf8("a"),Utf8("b")) | ++--------------------------------+ +| b | ++--------------------------------+ +> select nvl2('data', 'a', 'b'); ++----------------------------------------+ +| nvl2(Utf8("data"),Utf8("a"),Utf8("b")) | ++----------------------------------------+ +| a | ++----------------------------------------+ +``` +"#) + .with_argument( + "expression1", + "Expression to test for null. Can be a constant, column, or function, and any combination of operators." + ) + .with_argument( + "expression2", + "Expression to return if expr1 is not null. Can be a constant, column, or function, and any combination of operators." + ) + .with_argument( + "expression3", + "Expression to return if expr1 is null. Can be a constant, column, or function, and any combination of operators." + ) + .build() + .unwrap() + }) } fn nvl2_func(args: &[ColumnarValue]) -> Result { @@ -94,7 +158,7 @@ fn nvl2_func(args: &[ColumnarValue]) -> Result { .iter() .map(|arg| match arg { ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(len), - ColumnarValue::Array(array) => Ok(array.clone()), + ColumnarValue::Array(array) => Ok(Arc::clone(array)), }) .collect::>>()?; let to_apply = is_not_null(&args[0])?; diff --git a/datafusion/functions/src/core/planner.rs b/datafusion/functions/src/core/planner.rs new file mode 100644 index 000000000000..717a74797c0b --- /dev/null +++ b/datafusion/functions/src/core/planner.rs @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::Field; +use datafusion_common::Result; +use datafusion_common::{Column, DFSchema, ScalarValue, TableReference}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::planner::{ExprPlanner, PlannerResult, RawDictionaryExpr}; +use datafusion_expr::{lit, Expr}; + +use super::named_struct; + +#[derive(Default, Debug)] +pub struct CoreFunctionPlanner {} + +impl ExprPlanner for CoreFunctionPlanner { + fn plan_dictionary_literal( + &self, + expr: RawDictionaryExpr, + _schema: &DFSchema, + ) -> Result> { + let mut args = vec![]; + for (k, v) in expr.keys.into_iter().zip(expr.values.into_iter()) { + args.push(k); + args.push(v); + } + Ok(PlannerResult::Planned(named_struct().call(args))) + } + + fn plan_struct_literal( + &self, + args: Vec, + is_named_struct: bool, + ) -> Result>> { + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf( + if is_named_struct { + named_struct() + } else { + crate::core::r#struct() + }, + args, + ), + ))) + } + + fn plan_overlay(&self, args: Vec) -> Result>> { + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf(crate::string::overlay(), args), + ))) + } + + fn plan_compound_identifier( + &self, + field: &Field, + qualifier: Option<&TableReference>, + nested_names: &[String], + ) -> Result>> { + let col = Expr::Column(Column::from((qualifier, field))); + + // Start with the base column expression + let mut expr = col; + + // Iterate over nested_names and create nested get_field expressions + for nested_name in nested_names { + let get_field_args = vec![expr, lit(ScalarValue::from(nested_name.clone()))]; + expr = Expr::ScalarFunction(ScalarFunction::new_udf( + crate::core::get_field(), + get_field_args, + )); + } + + Ok(PlannerResult::Planned(expr)) + } +} diff --git a/datafusion/functions/src/core/struct.rs b/datafusion/functions/src/core/struct.rs index 9d4b2e4a0b8b..75d1d4eca698 100644 --- a/datafusion/functions/src/core/struct.rs +++ b/datafusion/functions/src/core/struct.rs @@ -18,10 +18,11 @@ use arrow::array::{ArrayRef, StructArray}; use arrow::datatypes::{DataType, Field, Fields}; use datafusion_common::{exec_err, Result}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRUCT; +use datafusion_expr::{ColumnarValue, Documentation}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; fn array_struct(args: &[ArrayRef]) -> Result { // do not accept 0 arguments. @@ -29,23 +30,23 @@ fn array_struct(args: &[ArrayRef]) -> Result { return exec_err!("struct requires at least one argument"); } - let vec: Vec<_> = args + let fields = args .iter() .enumerate() .map(|(i, arg)| { let field_name = format!("c{i}"); - Ok(( - Arc::new(Field::new( - field_name.as_str(), - arg.data_type().clone(), - true, - )), - arg.clone(), - )) + Ok(Arc::new(Field::new( + field_name.as_str(), + arg.data_type().clone(), + true, + ))) }) - .collect::>>()?; + .collect::>>()? + .into(); - Ok(Arc::new(StructArray::from(vec))) + let arrays = args.to_vec(); + + Ok(Arc::new(StructArray::new(fields, arrays, None))) } /// put values in a struct array. @@ -53,9 +54,11 @@ fn struct_expr(args: &[ColumnarValue]) -> Result { let arrays = ColumnarValue::values_to_arrays(args)?; Ok(ColumnarValue::Array(array_struct(arrays.as_slice())?)) } + #[derive(Debug)] pub struct StructFunc { signature: Signature, + aliases: Vec, } impl Default for StructFunc { @@ -68,6 +71,7 @@ impl StructFunc { pub fn new() -> Self { Self { signature: Signature::variadic_any(Volatility::Immutable), + aliases: vec![String::from("row")], } } } @@ -80,6 +84,10 @@ impl ScalarUDFImpl for StructFunc { "struct" } + fn aliases(&self) -> &[String] { + &self.aliases + } + fn signature(&self) -> &Signature { &self.signature } @@ -96,58 +104,56 @@ impl ScalarUDFImpl for StructFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { struct_expr(args) } -} -#[cfg(test)] -mod tests { - use super::*; - use arrow::array::Int64Array; - use datafusion_common::cast::as_struct_array; - use datafusion_common::ScalarValue; - - #[test] - fn test_struct() { - // struct(1, 2, 3) = {"c0": 1, "c1": 2, "c2": 3} - let args = [ - ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), - ]; - let struc = struct_expr(&args) - .expect("failed to initialize function struct") - .into_array(1) - .expect("Failed to convert to array"); - let result = - as_struct_array(&struc).expect("failed to initialize function struct"); - assert_eq!( - &Int64Array::from(vec![1]), - result - .column_by_name("c0") - .unwrap() - .clone() - .as_any() - .downcast_ref::() - .unwrap() - ); - assert_eq!( - &Int64Array::from(vec![2]), - result - .column_by_name("c1") - .unwrap() - .clone() - .as_any() - .downcast_ref::() - .unwrap() - ); - assert_eq!( - &Int64Array::from(vec![3]), - result - .column_by_name("c2") - .unwrap() - .clone() - .as_any() - .downcast_ref::() - .unwrap() - ); + fn documentation(&self) -> Option<&Documentation> { + Some(get_struct_doc()) } } + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_struct_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRUCT) + .with_description("Returns an Arrow struct using the specified input expressions optionally named. +Fields in the returned struct use the optional name or the `cN` naming convention. +For example: `c0`, `c1`, `c2`, etc.") + .with_syntax_example("struct(expression1[, ..., expression_n])") + .with_sql_example(r#"For example, this query converts two columns `a` and `b` to a single column with +a struct type of fields `field_a` and `c1`: +```sql +> select * from t; ++---+---+ +| a | b | ++---+---+ +| 1 | 2 | +| 3 | 4 | ++---+---+ + +-- use default names `c0`, `c1` +> select struct(a, b) from t; ++-----------------+ +| struct(t.a,t.b) | ++-----------------+ +| {c0: 1, c1: 2} | +| {c0: 3, c1: 4} | ++-----------------+ + +-- name the first field `field_a` +select struct(a as field_a, b) from t; ++--------------------------------------------------+ +| named_struct(Utf8("field_a"),t.a,Utf8("c1"),t.b) | ++--------------------------------------------------+ +| {field_a: 1, c1: 2} | +| {field_a: 3, c1: 4} | ++--------------------------------------------------+ +``` +"#) + .with_argument( + "expression1, expression_n", + "Expression to include in the output struct. Can be a constant, column, or function, any combination of arithmetic or string operators.") + .build() + .unwrap() + }) +} diff --git a/datafusion/functions/src/core/version.rs b/datafusion/functions/src/core/version.rs new file mode 100644 index 000000000000..e7ac749ddddc --- /dev/null +++ b/datafusion/functions/src/core/version.rs @@ -0,0 +1,129 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`VersionFunc`]: Implementation of the `version` function. + +use arrow::datatypes::DataType; +use datafusion_common::{not_impl_err, plan_err, Result, ScalarValue}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_OTHER; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use std::any::Any; +use std::sync::OnceLock; + +#[derive(Debug)] +pub struct VersionFunc { + signature: Signature, +} + +impl Default for VersionFunc { + fn default() -> Self { + Self::new() + } +} + +impl VersionFunc { + pub fn new() -> Self { + Self { + signature: Signature::exact(vec![], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for VersionFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "version" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, args: &[DataType]) -> Result { + if args.is_empty() { + Ok(DataType::Utf8) + } else { + plan_err!("version expects no arguments") + } + } + + fn invoke(&self, _: &[ColumnarValue]) -> Result { + not_impl_err!("version does not take any arguments") + } + + fn invoke_no_args(&self, _: usize) -> Result { + // TODO it would be great to add rust version and arrow version, + // but that requires a `build.rs` script and/or adding a version const to arrow-rs + let version = format!( + "Apache DataFusion {}, {} on {}", + env!("CARGO_PKG_VERSION"), + std::env::consts::ARCH, + std::env::consts::OS, + ); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(version)))) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_version_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_version_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_OTHER) + .with_description("Returns the version of DataFusion.") + .with_syntax_example("version()") + .with_sql_example( + r#"```sql +> select version(); ++--------------------------------------------+ +| version() | ++--------------------------------------------+ +| Apache DataFusion 42.0.0, aarch64 on macos | ++--------------------------------------------+ +```"#, + ) + .build() + .unwrap() + }) +} + +#[cfg(test)] +mod test { + use super::*; + use datafusion_expr::ScalarUDF; + + #[tokio::test] + async fn test_version_udf() { + let version_udf = ScalarUDF::from(VersionFunc::new()); + let version = version_udf.invoke_batch(&[], 1).unwrap(); + + if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(version))) = version { + assert!(version.starts_with("Apache DataFusion")); + } else { + panic!("Expected version string"); + } + } +} diff --git a/datafusion/functions/src/crypto/digest.rs b/datafusion/functions/src/crypto/digest.rs index c9dd3c1f56a2..0e43fb7785df 100644 --- a/datafusion/functions/src/crypto/digest.rs +++ b/datafusion/functions/src/crypto/digest.rs @@ -19,10 +19,12 @@ use super::basic::{digest, utf8_or_binary_to_binary_type}; use arrow::datatypes::DataType; use datafusion_common::Result; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_HASHING; use datafusion_expr::{ - ColumnarValue, ScalarUDFImpl, Signature, TypeSignature::*, Volatility, + ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature::*, Volatility, }; use std::any::Any; +use std::sync::OnceLock; #[derive(Debug)] pub struct DigestFunc { @@ -69,4 +71,48 @@ impl ScalarUDFImpl for DigestFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { digest(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_digest_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_digest_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_HASHING) + .with_description( + "Computes the binary hash of an expression using the specified algorithm.", + ) + .with_syntax_example("digest(expression, algorithm)") + .with_sql_example( + r#"```sql +> select digest('foo', 'sha256'); ++------------------------------------------+ +| digest(Utf8("foo"), Utf8("sha256")) | ++------------------------------------------+ +| | ++------------------------------------------+ +```"#, + ) + .with_standard_argument( + "expression", Some("String")) + .with_argument( + "algorithm", + "String expression specifying algorithm to use. Must be one of: + +- md5 +- sha224 +- sha256 +- sha384 +- sha512 +- blake2s +- blake2b +- blake3", + ) + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/crypto/md5.rs b/datafusion/functions/src/crypto/md5.rs index ccb6fbba80aa..062d63bcc018 100644 --- a/datafusion/functions/src/crypto/md5.rs +++ b/datafusion/functions/src/crypto/md5.rs @@ -19,8 +19,12 @@ use crate::crypto::basic::md5; use arrow::datatypes::DataType; use datafusion_common::{plan_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_HASHING; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; +use std::sync::OnceLock; #[derive(Debug)] pub struct Md5Func { @@ -84,4 +88,32 @@ impl ScalarUDFImpl for Md5Func { fn invoke(&self, args: &[ColumnarValue]) -> Result { md5(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_md5_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_md5_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_HASHING) + .with_description("Computes an MD5 128-bit checksum for a string expression.") + .with_syntax_example("md5(expression)") + .with_sql_example( + r#"```sql +> select md5('foo'); ++-------------------------------------+ +| md5(Utf8("foo")) | ++-------------------------------------+ +| | ++-------------------------------------+ +```"#, + ) + .with_standard_argument("expression", Some("String")) + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/crypto/mod.rs b/datafusion/functions/src/crypto/mod.rs index a879fdb45b35..46177fc22b60 100644 --- a/datafusion/functions/src/crypto/mod.rs +++ b/datafusion/functions/src/crypto/mod.rs @@ -17,6 +17,9 @@ //! "crypto" DataFusion functions +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + pub mod basic; pub mod digest; pub mod md5; @@ -30,28 +33,36 @@ make_udf_function!(sha224::SHA224Func, SHA224, sha224); make_udf_function!(sha256::SHA256Func, SHA256, sha256); make_udf_function!(sha384::SHA384Func, SHA384, sha384); make_udf_function!(sha512::SHA512Func, SHA512, sha512); -export_functions!(( - digest, - input_arg1 input_arg2, - "Computes the binary hash of an expression using the specified algorithm." -),( - md5, - input_arg, - "Computes an MD5 128-bit checksum for a string expression." -),( - sha224, - input_arg1, - "Computes the SHA-224 hash of a binary string." -),( - sha256, - input_arg1, - "Computes the SHA-256 hash of a binary string." -),( - sha384, - input_arg1, - "Computes the SHA-384 hash of a binary string." -),( - sha512, - input_arg1, - "Computes the SHA-512 hash of a binary string." -)); + +pub mod expr_fn { + export_functions!(( + digest, + "Computes the binary hash of an expression using the specified algorithm.", + input_arg1 input_arg2 + ),( + md5, + "Computes an MD5 128-bit checksum for a string expression.", + input_arg + ),( + sha224, + "Computes the SHA-224 hash of a binary string.", + input_arg1 + ),( + sha256, + "Computes the SHA-256 hash of a binary string.", + input_arg1 + ),( + sha384, + "Computes the SHA-384 hash of a binary string.", + input_arg1 + ),( + sha512, + "Computes the SHA-512 hash of a binary string.", + input_arg1 + )); +} + +/// Returns all DataFusion functions defined in this package +pub fn functions() -> Vec> { + vec![digest(), md5(), sha224(), sha256(), sha384(), sha512()] +} diff --git a/datafusion/functions/src/crypto/sha224.rs b/datafusion/functions/src/crypto/sha224.rs index 2795c4a25004..39202d5bf691 100644 --- a/datafusion/functions/src/crypto/sha224.rs +++ b/datafusion/functions/src/crypto/sha224.rs @@ -19,13 +19,18 @@ use super::basic::{sha224, utf8_or_binary_to_binary_type}; use arrow::datatypes::DataType; use datafusion_common::Result; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_HASHING; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; +use std::sync::OnceLock; #[derive(Debug)] pub struct SHA224Func { signature: Signature, } + impl Default for SHA224Func { fn default() -> Self { Self::new() @@ -44,6 +49,31 @@ impl SHA224Func { } } } + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_sha224_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_HASHING) + .with_description("Computes the SHA-224 hash of a binary string.") + .with_syntax_example("sha224(expression)") + .with_sql_example( + r#"```sql +> select sha224('foo'); ++------------------------------------------+ +| sha224(Utf8("foo")) | ++------------------------------------------+ +| | ++------------------------------------------+ +```"#, + ) + .with_standard_argument("expression", Some("String")) + .build() + .unwrap() + }) +} + impl ScalarUDFImpl for SHA224Func { fn as_any(&self) -> &dyn Any { self @@ -60,7 +90,12 @@ impl ScalarUDFImpl for SHA224Func { fn return_type(&self, arg_types: &[DataType]) -> Result { utf8_or_binary_to_binary_type(&arg_types[0], self.name()) } + fn invoke(&self, args: &[ColumnarValue]) -> Result { sha224(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_sha224_doc()) + } } diff --git a/datafusion/functions/src/crypto/sha256.rs b/datafusion/functions/src/crypto/sha256.rs index 0a3f3b26e431..74deb3fc6caa 100644 --- a/datafusion/functions/src/crypto/sha256.rs +++ b/datafusion/functions/src/crypto/sha256.rs @@ -19,8 +19,12 @@ use super::basic::{sha256, utf8_or_binary_to_binary_type}; use arrow::datatypes::DataType; use datafusion_common::Result; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_HASHING; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; +use std::sync::OnceLock; #[derive(Debug)] pub struct SHA256Func { @@ -60,7 +64,36 @@ impl ScalarUDFImpl for SHA256Func { fn return_type(&self, arg_types: &[DataType]) -> Result { utf8_or_binary_to_binary_type(&arg_types[0], self.name()) } + fn invoke(&self, args: &[ColumnarValue]) -> Result { sha256(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_sha256_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_sha256_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_HASHING) + .with_description("Computes the SHA-256 hash of a binary string.") + .with_syntax_example("sha256(expression)") + .with_sql_example( + r#"```sql +> select sha256('foo'); ++--------------------------------------+ +| sha256(Utf8("foo")) | ++--------------------------------------+ +| | ++--------------------------------------+ +```"#, + ) + .with_standard_argument("expression", Some("String")) + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/crypto/sha384.rs b/datafusion/functions/src/crypto/sha384.rs index c3f7845ce7bd..9b1e1ba9ec3c 100644 --- a/datafusion/functions/src/crypto/sha384.rs +++ b/datafusion/functions/src/crypto/sha384.rs @@ -19,8 +19,12 @@ use super::basic::{sha384, utf8_or_binary_to_binary_type}; use arrow::datatypes::DataType; use datafusion_common::Result; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_HASHING; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; +use std::sync::OnceLock; #[derive(Debug)] pub struct SHA384Func { @@ -60,7 +64,36 @@ impl ScalarUDFImpl for SHA384Func { fn return_type(&self, arg_types: &[DataType]) -> Result { utf8_or_binary_to_binary_type(&arg_types[0], self.name()) } + fn invoke(&self, args: &[ColumnarValue]) -> Result { sha384(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_sha384_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_sha384_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_HASHING) + .with_description("Computes the SHA-384 hash of a binary string.") + .with_syntax_example("sha384(expression)") + .with_sql_example( + r#"```sql +> select sha384('foo'); ++-----------------------------------------+ +| sha384(Utf8("foo")) | ++-----------------------------------------+ +| | ++-----------------------------------------+ +```"#, + ) + .with_standard_argument("expression", Some("String")) + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/crypto/sha512.rs b/datafusion/functions/src/crypto/sha512.rs index dc3bfac9d8bd..c88579fd08ee 100644 --- a/datafusion/functions/src/crypto/sha512.rs +++ b/datafusion/functions/src/crypto/sha512.rs @@ -19,8 +19,12 @@ use super::basic::{sha512, utf8_or_binary_to_binary_type}; use arrow::datatypes::DataType; use datafusion_common::Result; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_HASHING; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; +use std::sync::OnceLock; #[derive(Debug)] pub struct SHA512Func { @@ -60,7 +64,36 @@ impl ScalarUDFImpl for SHA512Func { fn return_type(&self, arg_types: &[DataType]) -> Result { utf8_or_binary_to_binary_type(&arg_types[0], self.name()) } + fn invoke(&self, args: &[ColumnarValue]) -> Result { sha512(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_sha512_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_sha512_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_HASHING) + .with_description("Computes the SHA-512 hash of a binary string.") + .with_syntax_example("sha512(expression)") + .with_sql_example( + r#"```sql +> select sha512('foo'); ++-------------------------------------------+ +| sha512(Utf8("foo")) | ++-------------------------------------------+ +| | ++-------------------------------------------+ +```"#, + ) + .with_argument("expression", "String") + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/datetime/common.rs b/datafusion/functions/src/datetime/common.rs index f0689ffd64e9..6e3106a5bce6 100644 --- a/datafusion/functions/src/datetime/common.rs +++ b/datafusion/functions/src/datetime/common.rs @@ -18,17 +18,20 @@ use std::sync::Arc; use arrow::array::{ - Array, ArrowPrimitiveType, GenericStringArray, OffsetSizeTrait, PrimitiveArray, + Array, ArrowPrimitiveType, AsArray, GenericStringArray, PrimitiveArray, + StringViewArray, }; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; use arrow::datatypes::DataType; use chrono::format::{parse, Parsed, StrftimeItems}; use chrono::LocalResult::Single; use chrono::{DateTime, TimeZone, Utc}; -use itertools::Either; +use crate::strings::StringArrayType; use datafusion_common::cast::as_generic_string_array; -use datafusion_common::{exec_err, DataFusionError, Result, ScalarType, ScalarValue}; +use datafusion_common::{ + exec_err, unwrap_or_internal_err, DataFusionError, Result, ScalarType, ScalarValue, +}; use datafusion_expr::ColumnarValue; /// Error message if nanosecond conversion request beyond supported interval @@ -39,14 +42,15 @@ pub(crate) fn string_to_timestamp_nanos_shim(s: &str) -> Result { string_to_timestamp_nanos(s).map_err(|e| e.into()) } -/// Checks that all the arguments from the second are of type [Utf8] or [LargeUtf8] +/// Checks that all the arguments from the second are of type [Utf8], [LargeUtf8] or [Utf8View] /// /// [Utf8]: DataType::Utf8 /// [LargeUtf8]: DataType::LargeUtf8 +/// [Utf8View]: DataType::Utf8View pub(crate) fn validate_data_types(args: &[ColumnarValue], name: &str) -> Result<()> { for (idx, a) in args.iter().skip(1).enumerate() { match a.data_type() { - DataType::Utf8 | DataType::LargeUtf8 => { + DataType::Utf8View | DataType::LargeUtf8 | DataType::Utf8 => { // all good } _ => { @@ -93,7 +97,9 @@ pub(crate) fn string_to_datetime_formatted( if let Err(e) = &dt { // no timezone or other failure, try without a timezone - let ndt = parsed.to_naive_datetime_with_offset(0); + let ndt = parsed + .to_naive_datetime_with_offset(0) + .or_else(|_| parsed.to_naive_date().map(|nd| nd.into())); if let Err(e) = &ndt { return Err(err(&e.to_string())); } @@ -149,26 +155,68 @@ pub(crate) fn string_to_timestamp_nanos_formatted( }) } -pub(crate) fn handle<'a, O, F, S>( - args: &'a [ColumnarValue], +/// Accepts a string with a `chrono` format and converts it to a +/// millisecond precision timestamp. +/// +/// See [`chrono::format::strftime`] for the full set of supported formats. +/// +/// Internally, this function uses the `chrono` library for the +/// datetime parsing +/// +/// ## Timezone / Offset Handling +/// +/// Numerical values of timestamps are stored compared to offset UTC. +/// +/// Any timestamp in the formatting string is handled according to the rules +/// defined by `chrono`. +/// +/// [`chrono::format::strftime`]: https://docs.rs/chrono/latest/chrono/format/strftime/index.html +/// +#[inline] +pub(crate) fn string_to_timestamp_millis_formatted(s: &str, format: &str) -> Result { + Ok(string_to_datetime_formatted(&Utc, s, format)? + .naive_utc() + .and_utc() + .timestamp_millis()) +} + +pub(crate) fn handle( + args: &[ColumnarValue], op: F, name: &str, ) -> Result where O: ArrowPrimitiveType, S: ScalarType, - F: Fn(&'a str) -> Result, + F: Fn(&str) -> Result, { match &args[0] { ColumnarValue::Array(a) => match a.data_type() { - DataType::Utf8 | DataType::LargeUtf8 => Ok(ColumnarValue::Array(Arc::new( - unary_string_to_primitive_function::(&[a.as_ref()], op, name)?, + DataType::Utf8View => Ok(ColumnarValue::Array(Arc::new( + unary_string_to_primitive_function::<&StringViewArray, O, _>( + a.as_ref().as_string_view(), + op, + )?, + ))), + DataType::LargeUtf8 => Ok(ColumnarValue::Array(Arc::new( + unary_string_to_primitive_function::<&GenericStringArray, O, _>( + a.as_ref().as_string::(), + op, + )?, + ))), + DataType::Utf8 => Ok(ColumnarValue::Array(Arc::new( + unary_string_to_primitive_function::<&GenericStringArray, O, _>( + a.as_ref().as_string::(), + op, + )?, ))), other => exec_err!("Unsupported data type {other:?} for function {name}"), }, ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => { - let result = a.as_ref().map(|x| (op)(x)).transpose()?; + ScalarValue::Utf8View(a) + | ScalarValue::LargeUtf8(a) + | ScalarValue::Utf8(a) => { + let result = a.as_ref().map(|x| op(x)).transpose()?; Ok(ColumnarValue::Scalar(S::scalar(result))) } other => exec_err!("Unsupported data type {other:?} for function {name}"), @@ -176,11 +224,11 @@ where } } -// given an function that maps a `&str`, `&str` to an arrow native type, +// Given a function that maps a `&str`, `&str` to an arrow native type, // returns a `ColumnarValue` where the function is applied to either a `ArrayRef` or `ScalarValue` // depending on the `args`'s variant. -pub(crate) fn handle_multiple<'a, O, F, S, M>( - args: &'a [ColumnarValue], +pub(crate) fn handle_multiple( + args: &[ColumnarValue], op: F, op2: M, name: &str, @@ -188,24 +236,24 @@ pub(crate) fn handle_multiple<'a, O, F, S, M>( where O: ArrowPrimitiveType, S: ScalarType, - F: Fn(&'a str, &'a str) -> Result, + F: Fn(&str, &str) -> Result, M: Fn(O::Native) -> O::Native, { match &args[0] { ColumnarValue::Array(a) => match a.data_type() { - DataType::Utf8 | DataType::LargeUtf8 => { + DataType::Utf8View | DataType::LargeUtf8 | DataType::Utf8 => { // validate the column types for (pos, arg) in args.iter().enumerate() { match arg { ColumnarValue::Array(arg) => match arg.data_type() { - DataType::Utf8 | DataType::LargeUtf8 => { + DataType::Utf8View | DataType::LargeUtf8 | DataType::Utf8 => { // all good } other => return exec_err!("Unsupported data type {other:?} for function {name}, arg # {pos}"), }, ColumnarValue::Scalar(arg) => { match arg.data_type() { - DataType::Utf8 | DataType::LargeUtf8 => { + DataType::Utf8View| DataType::LargeUtf8 | DataType::Utf8 => { // all good } other => return exec_err!("Unsupported data type {other:?} for function {name}, arg # {pos}"), @@ -215,7 +263,7 @@ where } Ok(ColumnarValue::Array(Arc::new( - strings_to_primitive_function::(args, op, op2, name)?, + strings_to_primitive_function::(args, op, op2, name)?, ))) } other => { @@ -224,47 +272,39 @@ where }, // if the first argument is a scalar utf8 all arguments are expected to be scalar utf8 ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => { - let mut val: Option> = None; - let mut err: Option = None; + ScalarValue::Utf8View(a) + | ScalarValue::LargeUtf8(a) + | ScalarValue::Utf8(a) => { + let a = a.as_ref(); + // ASK: Why do we trust `a` to be non-null at this point? + let a = unwrap_or_internal_err!(a); - match a { - Some(a) => { - // enumerate all the values finding the first one that returns an Ok result - for (pos, v) in args.iter().enumerate().skip(1) { - if let ColumnarValue::Scalar(s) = v { - if let ScalarValue::Utf8(x) | ScalarValue::LargeUtf8(x) = - s - { - if let Some(s) = x { - match op(a.as_str(), s.as_str()) { - Ok(r) => { - val = Some(Ok(ColumnarValue::Scalar( - S::scalar(Some(op2(r))), - ))); - break; - } - Err(e) => { - err = Some(e); - } - } - } - } else { - return exec_err!("Unsupported data type {s:?} for function {name}, arg # {pos}"); - } - } else { - return exec_err!("Unsupported data type {v:?} for function {name}, arg # {pos}"); + let mut ret = None; + + for (pos, v) in args.iter().enumerate().skip(1) { + let ColumnarValue::Scalar( + ScalarValue::Utf8View(x) + | ScalarValue::LargeUtf8(x) + | ScalarValue::Utf8(x), + ) = v + else { + return exec_err!("Unsupported data type {v:?} for function {name}, arg # {pos}"); + }; + + if let Some(s) = x { + match op(a.as_str(), s.as_str()) { + Ok(r) => { + ret = Some(Ok(ColumnarValue::Scalar(S::scalar(Some( + op2(r), + ))))); + break; } + Err(e) => ret = Some(Err(e)), } } - None => (), } - if let Some(v) = val { - v - } else { - Err(err.unwrap()) - } + unwrap_or_internal_err!(ret) } other => { exec_err!("Unsupported data type {other:?} for function {name}") @@ -282,18 +322,16 @@ where /// # Errors /// This function errors iff: /// * the number of arguments is not > 1 or -/// * the array arguments are not castable to a `GenericStringArray` or /// * the function `op` errors for all input -pub(crate) fn strings_to_primitive_function<'a, T, O, F, F2>( - args: &'a [ColumnarValue], +pub(crate) fn strings_to_primitive_function( + args: &[ColumnarValue], op: F, op2: F2, name: &str, ) -> Result> where O: ArrowPrimitiveType, - T: OffsetSizeTrait, - F: Fn(&'a str, &'a str) -> Result, + F: Fn(&str, &str) -> Result, F2: Fn(O::Native) -> O::Native, { if args.len() < 2 { @@ -304,50 +342,90 @@ where ); } - // this will throw the error if any of the array args are not castable to GenericStringArray - let data = args - .iter() - .map(|a| match a { - ColumnarValue::Array(a) => { - Ok(Either::Left(as_generic_string_array::(a.as_ref())?)) + match &args[0] { + ColumnarValue::Array(a) => match a.data_type() { + DataType::Utf8View => { + let string_array = a.as_string_view(); + handle_array_op::( + &string_array, + &args[1..], + op, + op2, + ) } - ColumnarValue::Scalar(s) => match s { - ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => Ok(Either::Right(a)), - other => exec_err!( - "Unexpected scalar type encountered '{other}' for function '{name}'" - ), - }, - }) - .collect::, &Option>>>>()?; - - let first_arg = &data.first().unwrap().left().unwrap(); + DataType::LargeUtf8 => { + let string_array = as_generic_string_array::(&a)?; + handle_array_op::, F, F2>( + &string_array, + &args[1..], + op, + op2, + ) + } + DataType::Utf8 => { + let string_array = as_generic_string_array::(&a)?; + handle_array_op::, F, F2>( + &string_array, + &args[1..], + op, + op2, + ) + } + other => exec_err!( + "Unsupported data type {other:?} for function substr,\ + expected Utf8View, Utf8 or LargeUtf8." + ), + }, + other => exec_err!( + "Received {} data type, expected only array", + other.data_type() + ), + } +} - first_arg +fn handle_array_op<'a, O, V, F, F2>( + first: &V, + args: &[ColumnarValue], + op: F, + op2: F2, +) -> Result> +where + V: StringArrayType<'a>, + O: ArrowPrimitiveType, + F: Fn(&str, &str) -> Result, + F2: Fn(O::Native) -> O::Native, +{ + first .iter() .enumerate() .map(|(pos, x)| { let mut val = None; - if let Some(x) = x { - let param_args = data.iter().skip(1); - - // go through the args and find the first successful result. Only the last - // failure will be returned if no successful result was received. - for param_arg in param_args { - // param_arg is an array, use the corresponding index into the array as the arg - // we're currently parsing - let p = *param_arg; - let r = if p.is_left() { - let p = p.left().unwrap(); - op(x, p.value(pos)) - } - // args is a scalar, use it directly - else if let Some(p) = p.right().unwrap() { - op(x, p.as_str()) - } else { - continue; - }; + for arg in args { + let v = match arg { + ColumnarValue::Array(a) => match a.data_type() { + DataType::Utf8View => Ok(a.as_string_view().value(pos)), + DataType::LargeUtf8 => Ok(a.as_string::().value(pos)), + DataType::Utf8 => Ok(a.as_string::().value(pos)), + other => exec_err!("Unexpected type encountered '{other}'"), + }, + ColumnarValue::Scalar(s) => match s { + ScalarValue::Utf8View(a) + | ScalarValue::LargeUtf8(a) + | ScalarValue::Utf8(a) => { + if let Some(v) = a { + Ok(v.as_str()) + } else { + continue; + } + } + other => { + exec_err!("Unexpected scalar type encountered '{other}'") + } + }, + }?; + let r = op(x, v); if r.is_ok() { val = Some(Ok(op2(r.unwrap()))); break; @@ -368,28 +446,16 @@ where /// # Errors /// This function errors iff: /// * the number of arguments is not 1 or -/// * the first argument is not castable to a `GenericStringArray` or /// * the function `op` errors -fn unary_string_to_primitive_function<'a, T, O, F>( - args: &[&'a dyn Array], +fn unary_string_to_primitive_function<'a, StringArrType, O, F>( + array: StringArrType, op: F, - name: &str, ) -> Result> where + StringArrType: StringArrayType<'a>, O: ArrowPrimitiveType, - T: OffsetSizeTrait, F: Fn(&'a str) -> Result, { - if args.len() != 1 { - return exec_err!( - "{:?} args were supplied but {} takes exactly one argument", - args.len(), - name - ); - } - - let array = as_generic_string_array::(args[0])?; - // first map is the iterator, second is for the `Option<_>` array.iter().map(|x| x.map(&op).transpose()).collect() } diff --git a/datafusion/functions/src/datetime/current_date.rs b/datafusion/functions/src/datetime/current_date.rs index 8b180ff41b91..24046611a71f 100644 --- a/datafusion/functions/src/datetime/current_date.rs +++ b/datafusion/functions/src/datetime/current_date.rs @@ -22,8 +22,12 @@ use arrow::datatypes::DataType::Date32; use chrono::{Datelike, NaiveDate}; use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, +}; +use std::sync::OnceLock; #[derive(Debug)] pub struct CurrentDateFunc { @@ -95,4 +99,25 @@ impl ScalarUDFImpl for CurrentDateFunc { ScalarValue::Date32(days), ))) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_current_date_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_current_date_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description(r#" +Returns the current UTC date. + +The `current_date()` return value is determined at query time and will return the same date, no matter when in the query plan the function executes. +"#) + .with_syntax_example("current_date()") + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/datetime/current_time.rs b/datafusion/functions/src/datetime/current_time.rs index 803759d4e904..4122b54b07e8 100644 --- a/datafusion/functions/src/datetime/current_time.rs +++ b/datafusion/functions/src/datetime/current_time.rs @@ -15,15 +15,18 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; - use arrow::datatypes::DataType; use arrow::datatypes::DataType::Time64; use arrow::datatypes::TimeUnit::Nanosecond; +use std::any::Any; +use std::sync::OnceLock; use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct CurrentTimeFunc { @@ -84,4 +87,25 @@ impl ScalarUDFImpl for CurrentTimeFunc { ScalarValue::Time64Nanosecond(nano), ))) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_current_time_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_current_time_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description(r#" +Returns the current UTC time. + +The `current_time()` return value is determined at query time and will return the same time, no matter when in the query plan the function executes. +"#) + .with_syntax_example("current_time()") + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/datetime/date_bin.rs b/datafusion/functions/src/datetime/date_bin.rs index da1797cdae81..065201e1caa3 100644 --- a/datafusion/functions/src/datetime/date_bin.rs +++ b/datafusion/functions/src/datetime/date_bin.rs @@ -16,7 +16,7 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::temporal_conversions::NANOSECONDS; use arrow::array::types::{ @@ -29,16 +29,18 @@ use arrow::datatypes::DataType::{Null, Timestamp, Utf8}; use arrow::datatypes::IntervalUnit::{DayTime, MonthDayNano}; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; use arrow::datatypes::{DataType, TimeUnit}; -use chrono::{DateTime, Datelike, Duration, Months, TimeDelta, Utc}; use datafusion_common::cast::as_primitive_array; use datafusion_common::{exec_err, not_impl_err, plan_err, Result, ScalarValue}; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, FuncMonotonicity, ScalarUDFImpl, Signature, Volatility, - TIMEZONE_WILDCARD, + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, }; +use chrono::{DateTime, Datelike, Duration, Months, TimeDelta, Utc}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; + #[derive(Debug)] pub struct DateBinFunc { signature: Signature, @@ -56,35 +58,35 @@ impl DateBinFunc { vec![ Exact(vec![ DataType::Interval(MonthDayNano), - Timestamp(array_type.clone(), None), + Timestamp(array_type, None), Timestamp(Nanosecond, None), ]), Exact(vec![ DataType::Interval(MonthDayNano), - Timestamp(array_type.clone(), Some(TIMEZONE_WILDCARD.into())), + Timestamp(array_type, Some(TIMEZONE_WILDCARD.into())), Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), ]), Exact(vec![ DataType::Interval(DayTime), - Timestamp(array_type.clone(), None), + Timestamp(array_type, None), Timestamp(Nanosecond, None), ]), Exact(vec![ DataType::Interval(DayTime), - Timestamp(array_type.clone(), Some(TIMEZONE_WILDCARD.into())), + Timestamp(array_type, Some(TIMEZONE_WILDCARD.into())), Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), ]), Exact(vec![ DataType::Interval(MonthDayNano), - Timestamp(array_type.clone(), None), + Timestamp(array_type, None), ]), Exact(vec![ DataType::Interval(MonthDayNano), - Timestamp(array_type.clone(), Some(TIMEZONE_WILDCARD.into())), + Timestamp(array_type, Some(TIMEZONE_WILDCARD.into())), ]), Exact(vec![ DataType::Interval(DayTime), - Timestamp(array_type.clone(), None), + Timestamp(array_type, None), ]), Exact(vec![ DataType::Interval(DayTime), @@ -146,9 +148,60 @@ impl ScalarUDFImpl for DateBinFunc { } } - fn monotonicity(&self) -> Result> { - Ok(Some(vec![None, Some(true)])) + fn output_ordering(&self, input: &[ExprProperties]) -> Result { + // The DATE_BIN function preserves the order of its second argument. + let step = &input[0]; + let date_value = &input[1]; + let reference = input.get(2); + + if step.sort_properties.eq(&SortProperties::Singleton) + && reference + .map(|r| r.sort_properties.eq(&SortProperties::Singleton)) + .unwrap_or(true) + { + Ok(date_value.sort_properties) + } else { + Ok(SortProperties::Unordered) + } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_date_bin_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_date_bin_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description(r#" +Calculates time intervals and returns the start of the interval nearest to the specified timestamp. Use `date_bin` to downsample time series data by grouping rows into time-based "bins" or "windows" and applying an aggregate or selector function to each window. + +For example, if you "bin" or "window" data into 15 minute intervals, an input timestamp of `2023-01-01T18:18:18Z` will be updated to the start time of the 15 minute bin it is in: `2023-01-01T18:15:00Z`. +"#) + .with_syntax_example("date_bin(interval, expression, origin-timestamp)") + .with_argument("interval", "Bin interval.") + .with_argument("expression", "Time expression to operate on. Can be a constant, column, or function.") + .with_argument("origin-timestamp", "Optional. Starting point used to determine bin boundaries. If not specified defaults 1970-01-01T00:00:00Z (the UNIX epoch in UTC). + +The following intervals are supported: + +- nanoseconds +- microseconds +- milliseconds +- seconds +- minutes +- hours +- days +- weeks +- months +- years +- century +") + .build() + .unwrap() + }) } enum Interval { @@ -187,7 +240,7 @@ fn date_bin_nanos_interval(stride_nanos: i64, source: i64, origin: i64) -> i64 { fn compute_distance(time_diff: i64, stride: i64) -> i64 { let time_delta = time_diff - (time_diff % stride); - if time_diff < 0 && stride > 1 { + if time_diff < 0 && stride > 1 && time_delta != time_diff { // The origin is later than the source timestamp, round down to the previous bin time_delta - stride } else { @@ -425,21 +478,26 @@ fn date_bin_impl( mod tests { use std::sync::Arc; + use crate::datetime::date_bin::{date_bin_nanos_interval, DateBinFunc}; use arrow::array::types::TimestampNanosecondType; use arrow::array::{IntervalDayTimeArray, TimestampNanosecondArray}; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; use arrow::datatypes::{DataType, TimeUnit}; - use chrono::TimeDelta; + use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; - use crate::datetime::date_bin::{date_bin_nanos_interval, DateBinFunc}; + use chrono::TimeDelta; #[test] + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch fn test_date_bin() { let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(1))), + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ]); @@ -447,21 +505,33 @@ mod tests { let timestamps = Arc::new((1..6).map(Some).collect::()); let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(1))), + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), ColumnarValue::Array(timestamps), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert!(res.is_ok()); let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(1))), + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert!(res.is_ok()); // stride supports month-day-nano let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(Some(1))), + ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNano { + months: 0, + days: 0, + nanoseconds: 1, + }, + ))), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ]); @@ -472,8 +542,12 @@ mod tests { // // invalid number of arguments - let res = DateBinFunc::new() - .invoke(&[ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(1)))]); + let res = DateBinFunc::new().invoke(&[ColumnarValue::Scalar( + ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + })), + )]); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN expected two or three arguments" @@ -492,7 +566,10 @@ mod tests { // stride: invalid value let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(0))), + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 0, + }))), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ]); @@ -503,7 +580,9 @@ mod tests { // stride: overflow of day-time interval let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(i64::MAX))), + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( + IntervalDayTime::MAX, + ))), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ]); @@ -536,7 +615,10 @@ mod tests { // origin: invalid type let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(1))), + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), ]); @@ -546,14 +628,26 @@ mod tests { ); let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(1))), + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert!(res.is_ok()); // unsupported array type for stride - let intervals = Arc::new((1..6).map(Some).collect::()); + let intervals = Arc::new( + (1..6) + .map(|x| { + Some(IntervalDayTime { + days: 0, + milliseconds: x, + }) + }) + .collect::(), + ); let res = DateBinFunc::new().invoke(&[ ColumnarValue::Array(intervals), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), @@ -567,7 +661,10 @@ mod tests { // unsupported array type for origin let timestamps = Arc::new((1..6).map(Some).collect::()); let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(1))), + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Array(timestamps), ]); @@ -685,6 +782,7 @@ mod tests { .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) .collect::() .with_timezone_opt(tz_opt.clone()); + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = DateBinFunc::new() .invoke(&[ ColumnarValue::Scalar(ScalarValue::new_interval_dt(1, 0)), @@ -768,4 +866,32 @@ mod tests { assert_eq!(result, expected1, "{source} = {expected}"); }) } + + #[test] + fn test_date_bin_before_epoch() { + let cases = [ + ( + (TimeDelta::try_minutes(15), "1969-12-31T23:44:59.999999999"), + "1969-12-31T23:30:00", + ), + ( + (TimeDelta::try_minutes(15), "1969-12-31T23:45:00"), + "1969-12-31T23:45:00", + ), + ( + (TimeDelta::try_minutes(15), "1969-12-31T23:45:00.000000001"), + "1969-12-31T23:45:00", + ), + ]; + + cases.iter().for_each(|((stride, source), expected)| { + let stride = stride.unwrap(); + let stride1 = stride.num_nanoseconds().unwrap(); + let source1 = string_to_timestamp_nanos(source).unwrap(); + + let expected1 = string_to_timestamp_nanos(expected).unwrap(); + let result = date_bin_nanos_interval(stride1, source1, 0); + assert_eq!(result, expected1, "{source} = {expected}"); + }) + } } diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index 111cdabe2bfb..01e094bc4e0b 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -16,13 +16,17 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::str::FromStr; +use std::sync::{Arc, OnceLock}; use arrow::array::{Array, ArrayRef, Float64Array}; +use arrow::compute::kernels::cast_utils::IntervalUnit; use arrow::compute::{binary, cast, date_part, DatePart}; use arrow::datatypes::DataType::{ - Date32, Date64, Float64, Time32, Time64, Timestamp, Utf8, + Date32, Date64, Duration, Float64, Interval, Time32, Time64, Timestamp, Utf8, + Utf8View, }; +use arrow::datatypes::IntervalUnit::{DayTime, MonthDayNano, YearMonth}; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; use arrow::datatypes::{DataType, TimeUnit}; @@ -33,9 +37,10 @@ use datafusion_common::cast::{ as_timestamp_nanosecond_array, as_timestamp_second_array, }; use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, }; #[derive(Debug)] @@ -56,31 +61,71 @@ impl DatePartFunc { signature: Signature::one_of( vec![ Exact(vec![Utf8, Timestamp(Nanosecond, None)]), + Exact(vec![Utf8View, Timestamp(Nanosecond, None)]), Exact(vec![ Utf8, Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), ]), + Exact(vec![ + Utf8View, + Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), + ]), Exact(vec![Utf8, Timestamp(Millisecond, None)]), + Exact(vec![Utf8View, Timestamp(Millisecond, None)]), Exact(vec![ Utf8, Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())), ]), + Exact(vec![ + Utf8View, + Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())), + ]), Exact(vec![Utf8, Timestamp(Microsecond, None)]), + Exact(vec![Utf8View, Timestamp(Microsecond, None)]), Exact(vec![ Utf8, Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())), ]), + Exact(vec![ + Utf8View, + Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())), + ]), Exact(vec![Utf8, Timestamp(Second, None)]), + Exact(vec![Utf8View, Timestamp(Second, None)]), Exact(vec![ Utf8, Timestamp(Second, Some(TIMEZONE_WILDCARD.into())), ]), + Exact(vec![ + Utf8View, + Timestamp(Second, Some(TIMEZONE_WILDCARD.into())), + ]), Exact(vec![Utf8, Date64]), + Exact(vec![Utf8View, Date64]), Exact(vec![Utf8, Date32]), + Exact(vec![Utf8View, Date32]), Exact(vec![Utf8, Time32(Second)]), + Exact(vec![Utf8View, Time32(Second)]), Exact(vec![Utf8, Time32(Millisecond)]), + Exact(vec![Utf8View, Time32(Millisecond)]), Exact(vec![Utf8, Time64(Microsecond)]), + Exact(vec![Utf8View, Time64(Microsecond)]), Exact(vec![Utf8, Time64(Nanosecond)]), + Exact(vec![Utf8View, Time64(Nanosecond)]), + Exact(vec![Utf8, Interval(YearMonth)]), + Exact(vec![Utf8View, Interval(YearMonth)]), + Exact(vec![Utf8, Interval(DayTime)]), + Exact(vec![Utf8View, Interval(DayTime)]), + Exact(vec![Utf8, Interval(MonthDayNano)]), + Exact(vec![Utf8View, Interval(MonthDayNano)]), + Exact(vec![Utf8, Duration(Second)]), + Exact(vec![Utf8View, Duration(Second)]), + Exact(vec![Utf8, Duration(Millisecond)]), + Exact(vec![Utf8View, Duration(Millisecond)]), + Exact(vec![Utf8, Duration(Microsecond)]), + Exact(vec![Utf8View, Duration(Microsecond)]), + Exact(vec![Utf8, Duration(Nanosecond)]), + Exact(vec![Utf8View, Duration(Nanosecond)]), ], Volatility::Immutable, ), @@ -114,6 +159,8 @@ impl ScalarUDFImpl for DatePartFunc { let part = if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = part { v + } else if let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(v))) = part { + v } else { return exec_err!( "First argument of `DATE_PART` must be non-null scalar Utf8" @@ -123,26 +170,42 @@ impl ScalarUDFImpl for DatePartFunc { let is_scalar = matches!(array, ColumnarValue::Scalar(_)); let array = match array { - ColumnarValue::Array(array) => array.clone(), + ColumnarValue::Array(array) => Arc::clone(array), ColumnarValue::Scalar(scalar) => scalar.to_array()?, }; - let arr = match part.to_lowercase().as_str() { - "year" => date_part_f64(array.as_ref(), DatePart::Year)?, - "quarter" => date_part_f64(array.as_ref(), DatePart::Quarter)?, - "month" => date_part_f64(array.as_ref(), DatePart::Month)?, - "week" => date_part_f64(array.as_ref(), DatePart::Week)?, - "day" => date_part_f64(array.as_ref(), DatePart::Day)?, - "doy" => date_part_f64(array.as_ref(), DatePart::DayOfYear)?, - "dow" => date_part_f64(array.as_ref(), DatePart::DayOfWeekSunday0)?, - "hour" => date_part_f64(array.as_ref(), DatePart::Hour)?, - "minute" => date_part_f64(array.as_ref(), DatePart::Minute)?, - "second" => seconds(array.as_ref(), Second)?, - "millisecond" => seconds(array.as_ref(), Millisecond)?, - "microsecond" => seconds(array.as_ref(), Microsecond)?, - "nanosecond" => seconds(array.as_ref(), Nanosecond)?, - "epoch" => epoch(array.as_ref())?, - _ => return exec_err!("Date part '{part}' not supported"), + // to remove quotes at most 2 characters + let part_trim = part.trim_matches(|c| c == '\'' || c == '\"'); + if ![2, 0].contains(&(part.len() - part_trim.len())) { + return exec_err!("Date part '{part}' not supported"); + } + + // using IntervalUnit here means we hand off all the work of supporting plurals (like "seconds") + // and synonyms ( like "ms,msec,msecond,millisecond") to Arrow + let arr = if let Ok(interval_unit) = IntervalUnit::from_str(part_trim) { + match interval_unit { + IntervalUnit::Year => date_part_f64(array.as_ref(), DatePart::Year)?, + IntervalUnit::Month => date_part_f64(array.as_ref(), DatePart::Month)?, + IntervalUnit::Week => date_part_f64(array.as_ref(), DatePart::Week)?, + IntervalUnit::Day => date_part_f64(array.as_ref(), DatePart::Day)?, + IntervalUnit::Hour => date_part_f64(array.as_ref(), DatePart::Hour)?, + IntervalUnit::Minute => date_part_f64(array.as_ref(), DatePart::Minute)?, + IntervalUnit::Second => seconds(array.as_ref(), Second)?, + IntervalUnit::Millisecond => seconds(array.as_ref(), Millisecond)?, + IntervalUnit::Microsecond => seconds(array.as_ref(), Microsecond)?, + IntervalUnit::Nanosecond => seconds(array.as_ref(), Nanosecond)?, + // century and decade are not supported by `DatePart`, although they are supported in postgres + _ => return exec_err!("Date part '{part}' not supported"), + } + } else { + // special cases that can be extracted (in postgres) but are not interval units + match part_trim.to_lowercase().as_str() { + "qtr" | "quarter" => date_part_f64(array.as_ref(), DatePart::Quarter)?, + "doy" => date_part_f64(array.as_ref(), DatePart::DayOfYear)?, + "dow" => date_part_f64(array.as_ref(), DatePart::DayOfWeekSunday0)?, + "epoch" => epoch(array.as_ref())?, + _ => return exec_err!("Date part '{part}' not supported"), + } }; Ok(if is_scalar { @@ -155,6 +218,47 @@ impl ScalarUDFImpl for DatePartFunc { fn aliases(&self) -> &[String] { &self.aliases } + fn documentation(&self) -> Option<&Documentation> { + Some(get_date_part_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_date_part_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Returns the specified part of the date as an integer.") + .with_syntax_example("date_part(part, expression)") + .with_argument( + "part", + r#"Part of the date to return. The following date parts are supported: + + - year + - quarter (emits value in inclusive range [1, 4] based on which quartile of the year the date is in) + - month + - week (week of the year) + - day (day of the month) + - hour + - minute + - second + - millisecond + - microsecond + - nanosecond + - dow (day of the week) + - doy (day of the year) + - epoch (seconds since Unix epoch) +"#, + ) + .with_argument( + "expression", + "Time expression to operate on. Can be a constant, column, or function.", + ) + .with_alternative_syntax("extract(field FROM source)") + .build() + .unwrap() + }) } /// Invoke [`date_part`] and cast the result to Float64 @@ -178,10 +282,28 @@ fn seconds(array: &dyn Array, unit: TimeUnit) -> Result { let subsecs = date_part(array, DatePart::Nanosecond)?; let subsecs = as_int32_array(subsecs.as_ref())?; - let r: Float64Array = binary(secs, subsecs, |secs, subsecs| { - (secs as f64 + (subsecs as f64 / 1_000_000_000_f64)) * sf - })?; - Ok(Arc::new(r)) + // Special case where there are no nulls. + if subsecs.null_count() == 0 { + let r: Float64Array = binary(secs, subsecs, |secs, subsecs| { + (secs as f64 + ((subsecs % 1_000_000_000) as f64 / 1_000_000_000_f64)) * sf + })?; + Ok(Arc::new(r)) + } else { + // Nulls in secs are preserved, nulls in subsecs are treated as zero to account for the case + // where the number of nanoseconds overflows. + let r: Float64Array = secs + .iter() + .zip(subsecs) + .map(|(secs, subsecs)| { + secs.map(|secs| { + let subsecs = subsecs.unwrap_or(0); + (secs as f64 + ((subsecs % 1_000_000_000) as f64 / 1_000_000_000_f64)) + * sf + }) + }) + .collect(); + Ok(Arc::new(r)) + } } fn epoch(array: &dyn Array) -> Result { @@ -210,7 +332,8 @@ fn epoch(array: &dyn Array) -> Result { Time64(Nanosecond) => { as_time64_nanosecond_array(array)?.unary(|x| x as f64 / 1_000_000_000_f64) } - d => return exec_err!("Can not convert {d:?} to epoch"), + Interval(_) | Duration(_) => return seconds(array, Second), + d => return exec_err!("Cannot convert {d:?} to epoch"), }; Ok(Arc::new(f)) } diff --git a/datafusion/functions/src/datetime/date_trunc.rs b/datafusion/functions/src/datetime/date_trunc.rs index 0414bf9c2a26..f8abef601f70 100644 --- a/datafusion/functions/src/datetime/date_trunc.rs +++ b/datafusion/functions/src/datetime/date_trunc.rs @@ -18,7 +18,7 @@ use std::any::Any; use std::ops::{Add, Sub}; use std::str::FromStr; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::temporal_conversions::{ as_datetime_with_timezone, timestamp_ns_to_datetime, @@ -29,21 +29,21 @@ use arrow::array::types::{ TimestampNanosecondType, TimestampSecondType, }; use arrow::array::{Array, PrimitiveArray}; -use arrow::datatypes::DataType::{Null, Timestamp, Utf8}; -use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; -use arrow::datatypes::{DataType, TimeUnit}; -use chrono::{ - DateTime, Datelike, Duration, LocalResult, NaiveDateTime, Offset, TimeDelta, Timelike, -}; - +use arrow::datatypes::DataType::{self, Null, Timestamp, Utf8, Utf8View}; +use arrow::datatypes::TimeUnit::{self, Microsecond, Millisecond, Nanosecond, Second}; use datafusion_common::cast::as_primitive_array; use datafusion_common::{exec_err, plan_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, FuncMonotonicity, ScalarUDFImpl, Signature, Volatility, - TIMEZONE_WILDCARD, + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, }; +use chrono::{ + DateTime, Datelike, Duration, LocalResult, NaiveDateTime, Offset, TimeDelta, Timelike, +}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; + #[derive(Debug)] pub struct DateTruncFunc { signature: Signature, @@ -62,25 +62,45 @@ impl DateTruncFunc { signature: Signature::one_of( vec![ Exact(vec![Utf8, Timestamp(Nanosecond, None)]), + Exact(vec![Utf8View, Timestamp(Nanosecond, None)]), Exact(vec![ Utf8, Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), ]), + Exact(vec![ + Utf8View, + Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), + ]), Exact(vec![Utf8, Timestamp(Microsecond, None)]), + Exact(vec![Utf8View, Timestamp(Microsecond, None)]), Exact(vec![ Utf8, Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())), ]), + Exact(vec![ + Utf8View, + Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())), + ]), Exact(vec![Utf8, Timestamp(Millisecond, None)]), + Exact(vec![Utf8View, Timestamp(Millisecond, None)]), Exact(vec![ Utf8, Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())), ]), + Exact(vec![ + Utf8View, + Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())), + ]), Exact(vec![Utf8, Timestamp(Second, None)]), + Exact(vec![Utf8View, Timestamp(Second, None)]), Exact(vec![ Utf8, Timestamp(Second, Some(TIMEZONE_WILDCARD.into())), ]), + Exact(vec![ + Utf8View, + Timestamp(Second, Some(TIMEZONE_WILDCARD.into())), + ]), ], Volatility::Immutable, ), @@ -104,7 +124,9 @@ impl ScalarUDFImpl for DateTruncFunc { fn return_type(&self, arg_types: &[DataType]) -> Result { match &arg_types[1] { - Timestamp(Nanosecond, None) | Utf8 | Null => Ok(Timestamp(Nanosecond, None)), + Timestamp(Nanosecond, None) | Utf8 | DataType::Date32 | Null => { + Ok(Timestamp(Nanosecond, None)) + } Timestamp(Nanosecond, tz_opt) => Ok(Timestamp(Nanosecond, tz_opt.clone())), Timestamp(Microsecond, tz_opt) => Ok(Timestamp(Microsecond, tz_opt.clone())), Timestamp(Millisecond, tz_opt) => Ok(Timestamp(Millisecond, tz_opt.clone())), @@ -120,6 +142,9 @@ impl ScalarUDFImpl for DateTruncFunc { let granularity = if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = granularity + { + v.to_lowercase() + } else if let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(v))) = granularity { v.to_lowercase() } else { @@ -167,36 +192,37 @@ impl ScalarUDFImpl for DateTruncFunc { } ColumnarValue::Array(array) => { let array_type = array.data_type(); - match array_type { - Timestamp(Second, tz_opt) => { - process_array::(array, granularity, tz_opt)? + if let Timestamp(unit, tz_opt) = array_type { + match unit { + Second => process_array::( + array, + granularity, + tz_opt, + )?, + Millisecond => process_array::( + array, + granularity, + tz_opt, + )?, + Microsecond => process_array::( + array, + granularity, + tz_opt, + )?, + Nanosecond => process_array::( + array, + granularity, + tz_opt, + )?, } - Timestamp(Millisecond, tz_opt) => process_array::< - TimestampMillisecondType, - >( - array, granularity, tz_opt - )?, - Timestamp(Microsecond, tz_opt) => process_array::< - TimestampMicrosecondType, - >( - array, granularity, tz_opt - )?, - Timestamp(Nanosecond, tz_opt) => process_array::< - TimestampNanosecondType, - >( - array, granularity, tz_opt - )?, - _ => process_array::( - array, - granularity, - &None, - )?, + } else { + return exec_err!("second argument of `date_trunc` is an unsupported array type: {array_type}"); } } _ => { return exec_err!( - "second argument of `date_trunc` must be nanosecond timestamp scalar or array" - ); + "second argument of `date_trunc` must be timestamp scalar or array" + ); } }) } @@ -205,11 +231,53 @@ impl ScalarUDFImpl for DateTruncFunc { &self.aliases } - fn monotonicity(&self) -> Result> { - Ok(Some(vec![None, Some(true)])) + fn output_ordering(&self, input: &[ExprProperties]) -> Result { + // The DATE_TRUNC function preserves the order of its second argument. + let precision = &input[0]; + let date_value = &input[1]; + + if precision.sort_properties.eq(&SortProperties::Singleton) { + Ok(date_value.sort_properties) + } else { + Ok(SortProperties::Unordered) + } + } + fn documentation(&self) -> Option<&Documentation> { + Some(get_date_trunc_doc()) } } +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_date_trunc_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Truncates a timestamp value to a specified precision.") + .with_syntax_example("date_trunc(precision, expression)") + .with_argument( + "precision", + r#"Time precision to truncate to. The following precisions are supported: + + - year / YEAR + - quarter / QUARTER + - month / MONTH + - week / WEEK + - day / DAY + - hour / HOUR + - minute / MINUTE + - second / SECOND +"#, + ) + .with_argument( + "expression", + "Time expression to operate on. Can be a constant, column, or function.", + ) + .build() + .unwrap() + }) +} + fn _date_trunc_coarse(granularity: &str, value: Option) -> Result> where T: Datelike + Timelike + Sub + Copy, @@ -410,7 +478,10 @@ fn parse_tz(tz: &Option>) -> Result> { #[cfg(test)] mod tests { + use std::sync::Arc; + use crate::datetime::date_trunc::{date_trunc_coarse, DateTruncFunc}; + use arrow::array::cast::as_primitive_array; use arrow::array::types::TimestampNanosecondType; use arrow::array::TimestampNanosecondArray; @@ -418,7 +489,6 @@ mod tests { use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; - use std::sync::Arc; #[test] fn date_trunc_test() { @@ -654,6 +724,7 @@ mod tests { .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) .collect::() .with_timezone_opt(tz_opt.clone()); + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = DateTruncFunc::new() .invoke(&[ ColumnarValue::Scalar(ScalarValue::from("day")), @@ -812,6 +883,7 @@ mod tests { .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) .collect::() .with_timezone_opt(tz_opt.clone()); + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = DateTruncFunc::new() .invoke(&[ ColumnarValue::Scalar(ScalarValue::from("hour")), diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index d36ebe735ee7..29b2f29b14c2 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -16,13 +16,17 @@ // under the License. use std::any::Any; +use std::sync::{Arc, OnceLock}; use arrow::datatypes::DataType; -use arrow::datatypes::DataType::{Int64, Timestamp}; +use arrow::datatypes::DataType::{Int64, Timestamp, Utf8}; use arrow::datatypes::TimeUnit::Second; - -use datafusion_common::{exec_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_common::{exec_err, internal_err, ExprSchema, Result, ScalarValue}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ + ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct FromUnixtimeFunc { @@ -38,7 +42,10 @@ impl Default for FromUnixtimeFunc { impl FromUnixtimeFunc { pub fn new() -> Self { Self { - signature: Signature::uniform(1, vec![Int64], Volatility::Immutable), + signature: Signature::one_of( + vec![Exact(vec![Int64, Utf8]), Exact(vec![Int64])], + Volatility::Immutable, + ), } } } @@ -56,26 +63,134 @@ impl ScalarUDFImpl for FromUnixtimeFunc { &self.signature } + fn return_type_from_exprs( + &self, + args: &[Expr], + _schema: &dyn ExprSchema, + arg_types: &[DataType], + ) -> Result { + match arg_types.len() { + 1 => Ok(Timestamp(Second, None)), + 2 => match &args[1] { + Expr::Literal(ScalarValue::Utf8(Some(tz))) => Ok(Timestamp(Second, Some(Arc::from(tz.to_string())))), + _ => exec_err!( + "Second argument for `from_unixtime` must be non-null utf8, received {:?}", + arg_types[1]), + }, + _ => exec_err!( + "from_unixtime function requires 1 or 2 arguments, got {}", + arg_types.len() + ), + } + } + fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(Timestamp(Second, None)) + internal_err!("call return_type_from_exprs instead") } fn invoke(&self, args: &[ColumnarValue]) -> Result { - if args.len() != 1 { + let len = args.len(); + if len != 1 && len != 2 { return exec_err!( - "from_unixtime function requires 1 argument, got {}", + "from_unixtime function requires 1 or 2 argument, got {}", args.len() ); } - match args[0].data_type() { - Int64 => args[0].cast_to(&Timestamp(Second, None), None), - other => { - exec_err!( - "Unsupported data type {:?} for function from_unixtime", - other - ) + if args[0].data_type() != Int64 { + return exec_err!( + "Unsupported data type {:?} for function from_unixtime", + args[0].data_type() + ); + } + + match len { + 1 => args[0].cast_to(&Timestamp(Second, None), None), + 2 => match &args[1] { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(tz))) => args[0] + .cast_to(&Timestamp(Second, Some(Arc::from(tz.to_string()))), None), + _ => { + exec_err!( + "Unsupported data type {:?} for function from_unixtime", + args[1].data_type() + ) + } + }, + _ => unreachable!(), + } + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_from_unixtime_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_from_unixtime_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Converts an integer to RFC3339 timestamp format (`YYYY-MM-DDT00:00:00.000000000Z`). Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`) return the corresponding timestamp.") + .with_syntax_example("from_unixtime(expression[, timezone])") + .with_standard_argument("expression", None) + .with_argument( + "timezone", + "Optional timezone to use when converting the integer to a timestamp. If not provided, the default timezone is UTC.", + ) + .with_sql_example(r#"```sql +> select from_unixtime(1599572549, 'America/New_York'); ++-----------------------------------------------------------+ +| from_unixtime(Int64(1599572549),Utf8("America/New_York")) | ++-----------------------------------------------------------+ +| 2020-09-08T09:42:29-04:00 | ++-----------------------------------------------------------+ +```"#) + .build() + .unwrap() + }) +} + +#[cfg(test)] +mod test { + use crate::datetime::from_unixtime::FromUnixtimeFunc; + use datafusion_common::ScalarValue; + use datafusion_common::ScalarValue::Int64; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + #[test] + fn test_without_timezone() { + let args = [ColumnarValue::Scalar(Int64(Some(1729900800)))]; + + #[allow(deprecated)] // TODO use invoke_batch + let result = FromUnixtimeFunc::new().invoke(&args).unwrap(); + + match result { + ColumnarValue::Scalar(ScalarValue::TimestampSecond(Some(sec), None)) => { + assert_eq!(sec, 1729900800); + } + _ => panic!("Expected scalar value"), + } + } + + #[test] + fn test_with_timezone() { + let args = [ + ColumnarValue::Scalar(Int64(Some(1729900800))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some( + "America/New_York".to_string(), + ))), + ]; + + #[allow(deprecated)] // TODO use invoke_batch + let result = FromUnixtimeFunc::new().invoke(&args).unwrap(); + + match result { + ColumnarValue::Scalar(ScalarValue::TimestampSecond(Some(sec), Some(tz))) => { + assert_eq!(sec, 1729900800); + assert_eq!(tz.to_string(), "America/New_York"); } + _ => panic!("Expected scalar value"), } } } diff --git a/datafusion/functions/src/datetime/make_date.rs b/datafusion/functions/src/datetime/make_date.rs index 6aa72572bc4d..6b246cb088a2 100644 --- a/datafusion/functions/src/datetime/make_date.rs +++ b/datafusion/functions/src/datetime/make_date.rs @@ -16,18 +16,21 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::builder::PrimitiveBuilder; use arrow::array::cast::AsArray; use arrow::array::types::{Date32Type, Int32Type}; use arrow::array::PrimitiveArray; use arrow::datatypes::DataType; -use arrow::datatypes::DataType::{Date32, Int32, Int64, UInt32, UInt64, Utf8}; +use arrow::datatypes::DataType::{Date32, Int32, Int64, UInt32, UInt64, Utf8, Utf8View}; use chrono::prelude::*; use datafusion_common::{exec_err, Result, ScalarValue}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct MakeDateFunc { @@ -45,7 +48,7 @@ impl MakeDateFunc { Self { signature: Signature::uniform( 3, - vec![Int32, Int64, UInt32, UInt64, Utf8], + vec![Int32, Int64, UInt32, UInt64, Utf8, Utf8View], Volatility::Immutable, ), } @@ -86,9 +89,9 @@ impl ScalarUDFImpl for MakeDateFunc { ColumnarValue::Array(a) => Some(a.len()), }); - let years = args[0].cast_to(&DataType::Int32, None)?; - let months = args[1].cast_to(&DataType::Int32, None)?; - let days = args[2].cast_to(&DataType::Int32, None)?; + let years = args[0].cast_to(&Int32, None)?; + let months = args[1].cast_to(&Int32, None)?; + let days = args[2].cast_to(&Int32, None)?; let scalar_value_fn = |col: &ColumnarValue| -> Result { let ColumnarValue::Scalar(s) = col else { @@ -148,6 +151,47 @@ impl ScalarUDFImpl for MakeDateFunc { Ok(value) } + fn documentation(&self) -> Option<&Documentation> { + Some(get_make_date_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_make_date_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Make a date from year/month/day component parts.") + .with_syntax_example("make_date(year, month, day)") + .with_argument( + "year", + " Year to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators.", ) + .with_argument( + "month", + "Month to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators.", + ) + .with_argument("day", "Day to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators.") + .with_sql_example(r#"```sql +> select make_date(2023, 1, 31); ++-------------------------------------------+ +| make_date(Int64(2023),Int64(1),Int64(31)) | ++-------------------------------------------+ +| 2023-01-31 | ++-------------------------------------------+ +> select make_date('2023', '01', '31'); ++-----------------------------------------------+ +| make_date(Utf8("2023"),Utf8("01"),Utf8("31")) | ++-----------------------------------------------+ +| 2023-01-31 | ++-----------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/make_date.rs) +"#) + .build() + .unwrap() + }) } /// Converts the year/month/day fields to an `i32` representing the days from @@ -190,6 +234,7 @@ mod tests { #[test] fn test_make_date() { + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let res = MakeDateFunc::new() .invoke(&[ ColumnarValue::Scalar(ScalarValue::Int32(Some(2024))), @@ -204,6 +249,7 @@ mod tests { panic!("Expected a scalar value") } + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let res = MakeDateFunc::new() .invoke(&[ ColumnarValue::Scalar(ScalarValue::Int64(Some(2024))), @@ -218,6 +264,7 @@ mod tests { panic!("Expected a scalar value") } + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let res = MakeDateFunc::new() .invoke(&[ ColumnarValue::Scalar(ScalarValue::Utf8(Some("2024".to_string()))), @@ -235,6 +282,7 @@ mod tests { let years = Arc::new((2021..2025).map(Some).collect::()); let months = Arc::new((1..5).map(Some).collect::()); let days = Arc::new((11..15).map(Some).collect::()); + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let res = MakeDateFunc::new() .invoke(&[ ColumnarValue::Array(years), @@ -260,6 +308,7 @@ mod tests { // // invalid number of arguments + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let res = MakeDateFunc::new() .invoke(&[ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))]); assert_eq!( @@ -268,6 +317,7 @@ mod tests { ); // invalid type + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let res = MakeDateFunc::new().invoke(&[ ColumnarValue::Scalar(ScalarValue::IntervalYearMonth(Some(1))), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), @@ -279,6 +329,7 @@ mod tests { ); // overflow of month + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let res = MakeDateFunc::new().invoke(&[ ColumnarValue::Scalar(ScalarValue::Int32(Some(2023))), ColumnarValue::Scalar(ScalarValue::UInt64(Some(u64::MAX))), @@ -290,6 +341,7 @@ mod tests { ); // overflow of day + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let res = MakeDateFunc::new().invoke(&[ ColumnarValue::Scalar(ScalarValue::Int32(Some(2023))), ColumnarValue::Scalar(ScalarValue::Int32(Some(22))), diff --git a/datafusion/functions/src/datetime/mod.rs b/datafusion/functions/src/datetime/mod.rs index c6939976eb02..db4e365267dd 100644 --- a/datafusion/functions/src/datetime/mod.rs +++ b/datafusion/functions/src/datetime/mod.rs @@ -32,6 +32,7 @@ pub mod make_date; pub mod now; pub mod to_char; pub mod to_date; +pub mod to_local_time; pub mod to_timestamp; pub mod to_unixtime; @@ -50,6 +51,7 @@ make_udf_function!( make_udf_function!(now::NowFunc, NOW, now); make_udf_function!(to_char::ToCharFunc, TO_CHAR, to_char); make_udf_function!(to_date::ToDateFunc, TO_DATE, to_date); +make_udf_function!(to_local_time::ToLocalTimeFunc, TO_LOCAL_TIME, to_local_time); make_udf_function!(to_unixtime::ToUnixtimeFunc, TO_UNIXTIME, to_unixtime); make_udf_function!(to_timestamp::ToTimestampFunc, TO_TIMESTAMP, to_timestamp); make_udf_function!( @@ -79,45 +81,66 @@ make_udf_function!( pub mod expr_fn { use datafusion_expr::Expr; - #[doc = "returns current UTC date as a Date32 value"] - pub fn current_date() -> Expr { - super::current_date().call(vec![]) - } - - #[doc = "returns current UTC time as a Time64 value"] - pub fn current_time() -> Expr { - super::current_time().call(vec![]) - } - - #[doc = "coerces an arbitrary timestamp to the start of the nearest specified interval"] - pub fn date_bin(stride: Expr, source: Expr, origin: Expr) -> Expr { - super::date_bin().call(vec![stride, source, origin]) - } - - #[doc = "extracts a subfield from the date"] - pub fn date_part(part: Expr, date: Expr) -> Expr { - super::date_part().call(vec![part, date]) - } - - #[doc = "truncates the date to a specified level of precision"] - pub fn date_trunc(part: Expr, date: Expr) -> Expr { - super::date_trunc().call(vec![part, date]) - } - - #[doc = "converts an integer to RFC3339 timestamp format string"] - pub fn from_unixtime(unixtime: Expr) -> Expr { - super::from_unixtime().call(vec![unixtime]) - } - - #[doc = "make a date from year, month and day component parts"] - pub fn make_date(year: Expr, month: Expr, day: Expr) -> Expr { - super::make_date().call(vec![year, month, day]) - } - - #[doc = "returns the current timestamp in nanoseconds, using the same value for all instances of now() in same statement"] - pub fn now() -> Expr { - super::now().call(vec![]) - } + export_functions!(( + current_date, + "returns current UTC date as a Date32 value", + ),( + current_time, + "returns current UTC time as a Time64 value", + ),( + from_unixtime, + "converts an integer to RFC3339 timestamp format string", + unixtime + ),( + date_bin, + "coerces an arbitrary timestamp to the start of the nearest specified interval", + stride source origin + ),( + date_part, + "extracts a subfield from the date", + part date + ),( + date_trunc, + "truncates the date to a specified level of precision", + part date + ),( + make_date, + "make a date from year, month and day component parts", + year month day + ),( + now, + "returns the current timestamp in nanoseconds, using the same value for all instances of now() in same statement", + ), + ( + to_local_time, + "converts a timezone-aware timestamp to local time (with no offset or timezone information), i.e. strips off the timezone from the timestamp", + args, + ), + ( + to_unixtime, + "converts a string and optional formats to a Unixtime", + args, + ),( + to_timestamp, + "converts a string and optional formats to a `Timestamp(Nanoseconds, None)`", + args, + ),( + to_timestamp_seconds, + "converts a string and optional formats to a `Timestamp(Seconds, None)`", + args, + ),( + to_timestamp_millis, + "converts a string and optional formats to a `Timestamp(Milliseconds, None)`", + args, + ),( + to_timestamp_micros, + "converts a string and optional formats to a `Timestamp(Microseconds, None)`", + args, + ),( + to_timestamp_nanos, + "converts a string and optional formats to a `Timestamp(Nanoseconds, None)`", + args, + )); /// Returns a string representation of a date, time, timestamp or duration based /// on a Chrono pattern. @@ -247,39 +270,9 @@ pub mod expr_fn { pub fn to_date(args: Vec) -> Expr { super::to_date().call(args) } - - #[doc = "converts a string and optional formats to a Unixtime"] - pub fn to_unixtime(args: Vec) -> Expr { - super::to_unixtime().call(args) - } - - #[doc = "converts a string and optional formats to a `Timestamp(Nanoseconds, None)`"] - pub fn to_timestamp(args: Vec) -> Expr { - super::to_timestamp().call(args) - } - - #[doc = "converts a string and optional formats to a `Timestamp(Seconds, None)`"] - pub fn to_timestamp_seconds(args: Vec) -> Expr { - super::to_timestamp_seconds().call(args) - } - - #[doc = "converts a string and optional formats to a `Timestamp(Milliseconds, None)`"] - pub fn to_timestamp_millis(args: Vec) -> Expr { - super::to_timestamp_millis().call(args) - } - - #[doc = "converts a string and optional formats to a `Timestamp(Microseconds, None)`"] - pub fn to_timestamp_micros(args: Vec) -> Expr { - super::to_timestamp_micros().call(args) - } - - #[doc = "converts a string and optional formats to a `Timestamp(Nanoseconds, None)`"] - pub fn to_timestamp_nanos(args: Vec) -> Expr { - super::to_timestamp_nanos().call(args) - } } -/// Return a list of all functions in this package +/// Returns all DataFusion functions defined in this package pub fn functions() -> Vec> { vec![ current_date(), @@ -292,6 +285,7 @@ pub fn functions() -> Vec> { now(), to_char(), to_date(), + to_local_time(), to_unixtime(), to_timestamp(), to_timestamp_seconds(), diff --git a/datafusion/functions/src/datetime/now.rs b/datafusion/functions/src/datetime/now.rs index b2221215b94b..c13bbfb18105 100644 --- a/datafusion/functions/src/datetime/now.rs +++ b/datafusion/functions/src/datetime/now.rs @@ -15,19 +15,23 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; - use arrow::datatypes::DataType; use arrow::datatypes::DataType::Timestamp; use arrow::datatypes::TimeUnit::Nanosecond; +use std::any::Any; +use std::sync::OnceLock; -use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_common::{internal_err, ExprSchema, Result, ScalarValue}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, Expr, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct NowFunc { signature: Signature, + aliases: Vec, } impl Default for NowFunc { @@ -40,6 +44,7 @@ impl NowFunc { pub fn new() -> Self { Self { signature: Signature::uniform(0, vec![], Volatility::Stable), + aliases: vec!["current_timestamp".to_string()], } } } @@ -84,4 +89,32 @@ impl ScalarUDFImpl for NowFunc { ScalarValue::TimestampNanosecond(now_ts, Some("+00:00".into())), ))) } + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_unixtime_doc()) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool { + false + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_to_unixtime_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description(r#" +Returns the current UTC timestamp. + +The `now()` return value is determined at query time and will return the same timestamp, no matter when in the query plan the function executes. +"#) + .with_syntax_example("now()") + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs index f2e5af978ca0..ef5d6a4f6990 100644 --- a/datafusion/functions/src/datetime/to_char.rs +++ b/datafusion/functions/src/datetime/to_char.rs @@ -16,7 +16,7 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::cast::AsArray; use arrow::array::{new_null_array, Array, ArrayRef, StringArray}; @@ -29,9 +29,10 @@ use arrow::error::ArrowError; use arrow::util::display::{ArrayFormatter, DurationFormat, FormatOptions}; use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, TIMEZONE_WILDCARD, }; #[derive(Debug)] @@ -53,34 +54,34 @@ impl ToCharFunc { vec![ Exact(vec![Date32, Utf8]), Exact(vec![Date64, Utf8]), + Exact(vec![Time64(Nanosecond), Utf8]), + Exact(vec![Time64(Microsecond), Utf8]), Exact(vec![Time32(Millisecond), Utf8]), Exact(vec![Time32(Second), Utf8]), - Exact(vec![Time64(Microsecond), Utf8]), - Exact(vec![Time64(Nanosecond), Utf8]), - Exact(vec![Timestamp(Second, None), Utf8]), Exact(vec![ - Timestamp(Second, Some(TIMEZONE_WILDCARD.into())), + Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), Utf8, ]), - Exact(vec![Timestamp(Millisecond, None), Utf8]), + Exact(vec![Timestamp(Nanosecond, None), Utf8]), Exact(vec![ - Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())), + Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())), Utf8, ]), Exact(vec![Timestamp(Microsecond, None), Utf8]), Exact(vec![ - Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())), + Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())), Utf8, ]), - Exact(vec![Timestamp(Nanosecond, None), Utf8]), + Exact(vec![Timestamp(Millisecond, None), Utf8]), Exact(vec![ - Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())), + Timestamp(Second, Some(TIMEZONE_WILDCARD.into())), Utf8, ]), - Exact(vec![Duration(Second), Utf8]), - Exact(vec![Duration(Millisecond), Utf8]), - Exact(vec![Duration(Microsecond), Utf8]), + Exact(vec![Timestamp(Second, None), Utf8]), Exact(vec![Duration(Nanosecond), Utf8]), + Exact(vec![Duration(Microsecond), Utf8]), + Exact(vec![Duration(Millisecond), Utf8]), + Exact(vec![Duration(Second), Utf8]), ], Volatility::Immutable, ), @@ -137,6 +138,42 @@ impl ScalarUDFImpl for ToCharFunc { fn aliases(&self) -> &[String] { &self.aliases } + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_char_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_to_char_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Returns a string representation of a date, time, timestamp or duration based on a [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html). Unlike the PostgreSQL equivalent of this function numerical formatting is not supported.") + .with_syntax_example("to_char(expression, format)") + .with_argument( + "expression", + " Expression to operate on. Can be a constant, column, or function that results in a date, time, timestamp or duration." + ) + .with_argument( + "format", + "A [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) string to use to convert the expression.", + ) + .with_argument("day", "Day to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators.") + .with_sql_example(r#"```sql +> select to_char('2023-03-01'::date, '%d-%m-%Y'); ++----------------------------------------------+ +| to_char(Utf8("2023-03-01"),Utf8("%d-%m-%Y")) | ++----------------------------------------------+ +| 01-03-2023 | ++----------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_char.rs) +"#) + .build() + .unwrap() + }) } fn _build_format_options<'a>( @@ -185,10 +222,7 @@ fn _to_char_scalar( if is_scalar_expression { return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); } else { - return Ok(ColumnarValue::Array(new_null_array( - &DataType::Utf8, - array.len(), - ))); + return Ok(ColumnarValue::Array(new_null_array(&Utf8, array.len()))); } } @@ -350,6 +384,7 @@ mod tests { ]; for (value, format, expected) in scalar_data { + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = ToCharFunc::new() .invoke(&[ColumnarValue::Scalar(value), ColumnarValue::Scalar(format)]) .expect("that to_char parsed values without error"); @@ -424,6 +459,7 @@ mod tests { ]; for (value, format, expected) in scalar_array_data { + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = ToCharFunc::new() .invoke(&[ ColumnarValue::Scalar(value), @@ -549,6 +585,7 @@ mod tests { ]; for (value, format, expected) in array_scalar_data { + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = ToCharFunc::new() .invoke(&[ ColumnarValue::Array(value as ArrayRef), @@ -565,6 +602,7 @@ mod tests { } for (value, format, expected) in array_array_data { + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = ToCharFunc::new() .invoke(&[ ColumnarValue::Array(value), @@ -585,6 +623,7 @@ mod tests { // // invalid number of arguments + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = ToCharFunc::new() .invoke(&[ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))]); assert_eq!( @@ -593,6 +632,7 @@ mod tests { ); // invalid type + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = ToCharFunc::new().invoke(&[ ColumnarValue::Scalar(ScalarValue::Int32(Some(1))), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), diff --git a/datafusion/functions/src/datetime/to_date.rs b/datafusion/functions/src/datetime/to_date.rs index e491c0b55508..8f72100416e8 100644 --- a/datafusion/functions/src/datetime/to_date.rs +++ b/datafusion/functions/src/datetime/to_date.rs @@ -15,15 +15,19 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; - -use arrow::array::types::Date32Type; -use arrow::datatypes::DataType; -use arrow::datatypes::DataType::Date32; - use crate::datetime::common::*; -use datafusion_common::{exec_err, internal_datafusion_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::*; +use arrow::error::ArrowError::ParseError; +use arrow::{array::types::Date32Type, compute::kernels::cast_utils::Parser}; +use datafusion_common::error::DataFusionError; +use datafusion_common::{arrow_err, exec_err, internal_datafusion_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use std::any::Any; +use std::sync::OnceLock; #[derive(Debug)] pub struct ToDateFunc { @@ -47,22 +51,20 @@ impl ToDateFunc { match args.len() { 1 => handle::( args, - |s| { - string_to_timestamp_nanos_shim(s) - .map(|n| n / (1_000_000 * 24 * 60 * 60 * 1_000)) - .and_then(|v| { - v.try_into().map_err(|_| { - internal_datafusion_err!("Unable to cast to Date32 for converting from i64 to i32 failed") - }) - }) + |s| match Date32Type::parse(s) { + Some(v) => Ok(v), + None => arrow_err!(ParseError( + "Unable to cast to Date32 for converting from i64 to i32 failed" + .to_string() + )), }, "to_date", ), - n if n >= 2 => handle_multiple::( + 2.. => handle_multiple::( args, |s, format| { - string_to_timestamp_nanos_formatted(s, format) - .map(|n| n / (1_000_000 * 24 * 60 * 60 * 1_000)) + string_to_timestamp_millis_formatted(s, format) + .map(|n| n / (24 * 60 * 60 * 1_000)) .and_then(|v| { v.try_into().map_err(|_| { internal_datafusion_err!("Unable to cast to Date32 for converting from i64 to i32 failed") @@ -72,11 +74,55 @@ impl ToDateFunc { |n| n, "to_date", ), - _ => exec_err!("Unsupported 0 argument count for function to_date"), + 0 => exec_err!("Unsupported 0 argument count for function to_date"), } } } +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_to_date_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description(r#"Converts a value to a date (`YYYY-MM-DD`). +Supports strings, integer and double types as input. +Strings are parsed as YYYY-MM-DD (e.g. '2023-07-20') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. +Integers and doubles are interpreted as days since the unix epoch (`1970-01-01T00:00:00Z`). +Returns the corresponding date. + +Note: `to_date` returns Date32, which represents its values as the number of days since unix epoch(`1970-01-01`) stored as signed 32 bit value. The largest supported date value is `9999-12-31`. +"#) + .with_syntax_example("to_date('2017-05-31', '%Y-%m-%d')") + .with_sql_example(r#"```sql +> select to_date('2023-01-31'); ++-----------------------------+ +| to_date(Utf8("2023-01-31")) | ++-----------------------------+ +| 2023-01-31 | ++-----------------------------+ +> select to_date('2023/01/31', '%Y-%m-%d', '%Y/%m/%d'); ++---------------------------------------------------------------+ +| to_date(Utf8("2023/01/31"),Utf8("%Y-%m-%d"),Utf8("%Y/%m/%d")) | ++---------------------------------------------------------------+ +| 2023-01-31 | ++---------------------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_date.rs) +"#) + .with_standard_argument("expression", Some("String")) + .with_argument( + "format_n", + "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order + they appear with the first successful one being returned. If none of the formats successfully parse the expression + an error will be returned.", + ) + .build() + .unwrap() + }) +} + impl ScalarUDFImpl for ToDateFunc { fn as_any(&self) -> &dyn Any { self @@ -105,16 +151,320 @@ impl ScalarUDFImpl for ToDateFunc { } match args[0].data_type() { - DataType::Int32 - | DataType::Int64 - | DataType::Null - | DataType::Float64 - | DataType::Date32 - | DataType::Date64 => args[0].cast_to(&DataType::Date32, None), - DataType::Utf8 => self.to_date(args), + Int32 | Int64 | Null | Float64 | Date32 | Date64 => { + args[0].cast_to(&Date32, None) + } + Utf8View | LargeUtf8 | Utf8 => self.to_date(args), other => { exec_err!("Unsupported data type {:?} for function to_date", other) } } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_date_doc()) + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, Date32Array, GenericStringArray, StringViewArray}; + use arrow::{compute::kernels::cast_utils::Parser, datatypes::Date32Type}; + use datafusion_common::ScalarValue; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use std::sync::Arc; + + use super::ToDateFunc; + + #[test] + fn test_to_date_without_format() { + struct TestCase { + name: &'static str, + date_str: &'static str, + } + + let test_cases = vec![ + TestCase { + name: "Largest four-digit year (9999)", + date_str: "9999-12-31", + }, + TestCase { + name: "Year 1 (0001)", + date_str: "0001-12-31", + }, + TestCase { + name: "Year before epoch (1969)", + date_str: "1969-01-01", + }, + TestCase { + name: "Switch Julian/Gregorian calendar (1582-10-10)", + date_str: "1582-10-10", + }, + ]; + + for tc in &test_cases { + test_scalar(ScalarValue::Utf8(Some(tc.date_str.to_string())), tc); + test_scalar(ScalarValue::LargeUtf8(Some(tc.date_str.to_string())), tc); + test_scalar(ScalarValue::Utf8View(Some(tc.date_str.to_string())), tc); + + test_array::>(tc); + test_array::>(tc); + test_array::(tc); + } + + fn test_scalar(sv: ScalarValue, tc: &TestCase) { + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch + let to_date_result = ToDateFunc::new().invoke(&[ColumnarValue::Scalar(sv)]); + + match to_date_result { + Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { + let expected = Date32Type::parse_formatted(tc.date_str, "%Y-%m-%d"); + assert_eq!( + date_val, expected, + "{}: to_date created wrong value", + tc.name + ); + } + _ => panic!("Could not convert '{}' to Date", tc.date_str), + } + } + + fn test_array(tc: &TestCase) + where + A: From> + Array + 'static, + { + let date_array = A::from(vec![tc.date_str]); + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch + let to_date_result = + ToDateFunc::new().invoke(&[ColumnarValue::Array(Arc::new(date_array))]); + + match to_date_result { + Ok(ColumnarValue::Array(a)) => { + assert_eq!(a.len(), 1); + + let expected = Date32Type::parse_formatted(tc.date_str, "%Y-%m-%d"); + let mut builder = Date32Array::builder(4); + builder.append_value(expected.unwrap()); + + assert_eq!( + &builder.finish() as &dyn Array, + a.as_ref(), + "{}: to_date created wrong value", + tc.name + ); + } + _ => panic!("Could not convert '{}' to Date", tc.date_str), + } + } + } + + #[test] + fn test_to_date_with_format() { + struct TestCase { + name: &'static str, + date_str: &'static str, + format_str: &'static str, + formatted_date: &'static str, + } + + let test_cases = vec![ + TestCase { + name: "Largest four-digit year (9999)", + date_str: "9999-12-31", + format_str: "%Y%m%d", + formatted_date: "99991231", + }, + TestCase { + name: "Smallest four-digit year (-9999)", + date_str: "-9999-12-31", + format_str: "%Y/%m/%d", + formatted_date: "-9999/12/31", + }, + TestCase { + name: "Year 1 (0001)", + date_str: "0001-12-31", + format_str: "%Y%m%d", + formatted_date: "00011231", + }, + TestCase { + name: "Year before epoch (1969)", + date_str: "1969-01-01", + format_str: "%Y%m%d", + formatted_date: "19690101", + }, + TestCase { + name: "Switch Julian/Gregorian calendar (1582-10-10)", + date_str: "1582-10-10", + format_str: "%Y%m%d", + formatted_date: "15821010", + }, + TestCase { + name: "Negative Year, BC (-42-01-01)", + date_str: "-42-01-01", + format_str: "%Y/%m/%d", + formatted_date: "-42/01/01", + }, + ]; + + for tc in &test_cases { + test_scalar(ScalarValue::Utf8(Some(tc.formatted_date.to_string())), tc); + test_scalar( + ScalarValue::LargeUtf8(Some(tc.formatted_date.to_string())), + tc, + ); + test_scalar( + ScalarValue::Utf8View(Some(tc.formatted_date.to_string())), + tc, + ); + + test_array::>(tc); + test_array::>(tc); + test_array::(tc); + } + + fn test_scalar(sv: ScalarValue, tc: &TestCase) { + let format_scalar = ScalarValue::Utf8(Some(tc.format_str.to_string())); + + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch + let to_date_result = ToDateFunc::new().invoke(&[ + ColumnarValue::Scalar(sv), + ColumnarValue::Scalar(format_scalar), + ]); + + match to_date_result { + Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { + let expected = Date32Type::parse_formatted(tc.date_str, "%Y-%m-%d"); + assert_eq!(date_val, expected, "{}: to_date created wrong value for date '{}' with format string '{}'", tc.name, tc.formatted_date, tc.format_str); + } + _ => panic!( + "Could not convert '{}' with format string '{}'to Date", + tc.date_str, tc.format_str + ), + } + } + + fn test_array(tc: &TestCase) + where + A: From> + Array + 'static, + { + let date_array = A::from(vec![tc.formatted_date]); + let format_array = A::from(vec![tc.format_str]); + + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch + let to_date_result = ToDateFunc::new().invoke(&[ + ColumnarValue::Array(Arc::new(date_array)), + ColumnarValue::Array(Arc::new(format_array)), + ]); + + match to_date_result { + Ok(ColumnarValue::Array(a)) => { + assert_eq!(a.len(), 1); + + let expected = Date32Type::parse_formatted(tc.date_str, "%Y-%m-%d"); + let mut builder = Date32Array::builder(4); + builder.append_value(expected.unwrap()); + + assert_eq!( + &builder.finish() as &dyn Array, a.as_ref(), + "{}: to_date created wrong value for date '{}' with format string '{}'", + tc.name, + tc.formatted_date, + tc.format_str + ); + } + _ => panic!( + "Could not convert '{}' with format string '{}'to Date: {:?}", + tc.formatted_date, tc.format_str, to_date_result + ), + } + } + } + + #[test] + fn test_to_date_multiple_format_strings() { + let formatted_date_scalar = ScalarValue::Utf8(Some("2023/01/31".into())); + let format1_scalar = ScalarValue::Utf8(Some("%Y-%m-%d".into())); + let format2_scalar = ScalarValue::Utf8(Some("%Y/%m/%d".into())); + + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch + let to_date_result = ToDateFunc::new().invoke(&[ + ColumnarValue::Scalar(formatted_date_scalar), + ColumnarValue::Scalar(format1_scalar), + ColumnarValue::Scalar(format2_scalar), + ]); + + match to_date_result { + Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { + let expected = Date32Type::parse_formatted("2023-01-31", "%Y-%m-%d"); + assert_eq!( + date_val, expected, + "to_date created wrong value for date with 2 format strings" + ); + } + _ => panic!("Conversion failed",), + } + } + + #[test] + fn test_to_date_from_timestamp() { + let test_cases = vec![ + "2020-09-08T13:42:29Z", + "2020-09-08T13:42:29.190855-05:00", + "2020-09-08 12:13:29", + ]; + for date_str in test_cases { + let formatted_date_scalar = ScalarValue::Utf8(Some(date_str.into())); + + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch + let to_date_result = + ToDateFunc::new().invoke(&[ColumnarValue::Scalar(formatted_date_scalar)]); + + match to_date_result { + Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { + let expected = Date32Type::parse_formatted("2020-09-08", "%Y-%m-%d"); + assert_eq!(date_val, expected, "to_date created wrong value"); + } + _ => panic!("Conversion of {} failed", date_str), + } + } + } + + #[test] + fn test_to_date_string_with_valid_number() { + let date_str = "20241231"; + let date_scalar = ScalarValue::Utf8(Some(date_str.into())); + + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch + let to_date_result = + ToDateFunc::new().invoke(&[ColumnarValue::Scalar(date_scalar)]); + + match to_date_result { + Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { + let expected = Date32Type::parse_formatted("2024-12-31", "%Y-%m-%d"); + assert_eq!( + date_val, expected, + "to_date created wrong value for {}", + date_str + ); + } + _ => panic!("Conversion of {} failed", date_str), + } + } + + #[test] + fn test_to_date_string_with_invalid_number() { + let date_str = "202412311"; + let date_scalar = ScalarValue::Utf8(Some(date_str.into())); + + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch + let to_date_result = + ToDateFunc::new().invoke(&[ColumnarValue::Scalar(date_scalar)]); + + if let Ok(ColumnarValue::Scalar(ScalarValue::Date32(_))) = to_date_result { + panic!( + "Conversion of {} succeded, but should have failed, ", + date_str + ); + } + } } diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs new file mode 100644 index 000000000000..fef1eb9a60c8 --- /dev/null +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -0,0 +1,637 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::ops::Add; +use std::sync::{Arc, OnceLock}; + +use arrow::array::timezone::Tz; +use arrow::array::{Array, ArrayRef, PrimitiveBuilder}; +use arrow::datatypes::DataType::Timestamp; +use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; +use arrow::datatypes::{ + ArrowTimestampType, DataType, TimestampMicrosecondType, TimestampMillisecondType, + TimestampNanosecondType, TimestampSecondType, +}; + +use chrono::{DateTime, MappedLocalTime, Offset, TimeDelta, TimeZone, Utc}; +use datafusion_common::cast::as_primitive_array; +use datafusion_common::{exec_err, plan_err, DataFusionError, Result, ScalarValue}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; + +/// A UDF function that converts a timezone-aware timestamp to local time (with no offset or +/// timezone information). In other words, this function strips off the timezone from the timestamp, +/// while keep the display value of the timestamp the same. +#[derive(Debug)] +pub struct ToLocalTimeFunc { + signature: Signature, +} + +impl Default for ToLocalTimeFunc { + fn default() -> Self { + Self::new() + } +} + +impl ToLocalTimeFunc { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } + + fn to_local_time(&self, args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return exec_err!( + "to_local_time function requires 1 argument, got {}", + args.len() + ); + } + + let time_value = &args[0]; + let arg_type = time_value.data_type(); + match arg_type { + Timestamp(_, None) => { + // if no timezone specified, just return the input + Ok(time_value.clone()) + } + // If has timezone, adjust the underlying time value. The current time value + // is stored as i64 in UTC, even though the timezone may not be in UTC. Therefore, + // we need to adjust the time value to the local time. See [`adjust_to_local_time`] + // for more details. + // + // Then remove the timezone in return type, i.e. return None + Timestamp(_, Some(timezone)) => { + let tz: Tz = timezone.parse()?; + + match time_value { + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + Some(ts), + Some(_), + )) => { + let adjusted_ts = + adjust_to_local_time::(*ts, tz)?; + Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + Some(adjusted_ts), + None, + ))) + } + ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond( + Some(ts), + Some(_), + )) => { + let adjusted_ts = + adjust_to_local_time::(*ts, tz)?; + Ok(ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond( + Some(adjusted_ts), + None, + ))) + } + ColumnarValue::Scalar(ScalarValue::TimestampMillisecond( + Some(ts), + Some(_), + )) => { + let adjusted_ts = + adjust_to_local_time::(*ts, tz)?; + Ok(ColumnarValue::Scalar(ScalarValue::TimestampMillisecond( + Some(adjusted_ts), + None, + ))) + } + ColumnarValue::Scalar(ScalarValue::TimestampSecond( + Some(ts), + Some(_), + )) => { + let adjusted_ts = + adjust_to_local_time::(*ts, tz)?; + Ok(ColumnarValue::Scalar(ScalarValue::TimestampSecond( + Some(adjusted_ts), + None, + ))) + } + ColumnarValue::Array(array) => { + fn transform_array( + array: &ArrayRef, + tz: Tz, + ) -> Result { + let mut builder = PrimitiveBuilder::::new(); + + let primitive_array = as_primitive_array::(array)?; + for ts_opt in primitive_array.iter() { + match ts_opt { + None => builder.append_null(), + Some(ts) => { + let adjusted_ts: i64 = + adjust_to_local_time::(ts, tz)?; + builder.append_value(adjusted_ts) + } + } + } + + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) + } + + match array.data_type() { + Timestamp(_, None) => { + // if no timezone specified, just return the input + Ok(time_value.clone()) + } + Timestamp(Nanosecond, Some(_)) => { + transform_array::(array, tz) + } + Timestamp(Microsecond, Some(_)) => { + transform_array::(array, tz) + } + Timestamp(Millisecond, Some(_)) => { + transform_array::(array, tz) + } + Timestamp(Second, Some(_)) => { + transform_array::(array, tz) + } + _ => { + exec_err!("to_local_time function requires timestamp argument in array, got {:?}", array.data_type()) + } + } + } + _ => { + exec_err!( + "to_local_time function requires timestamp argument, got {:?}", + time_value.data_type() + ) + } + } + } + _ => { + exec_err!( + "to_local_time function requires timestamp argument, got {:?}", + arg_type + ) + } + } + } +} + +/// This function converts a timestamp with a timezone to a timestamp without a timezone. +/// The display value of the adjusted timestamp remain the same, but the underlying timestamp +/// representation is adjusted according to the relative timezone offset to UTC. +/// +/// This function uses chrono to handle daylight saving time changes. +/// +/// For example, +/// +/// ```text +/// '2019-03-31T01:00:00Z'::timestamp at time zone 'Europe/Brussels' +/// ``` +/// +/// is displayed as follows in datafusion-cli: +/// +/// ```text +/// 2019-03-31T01:00:00+01:00 +/// ``` +/// +/// and is represented in DataFusion as: +/// +/// ```text +/// TimestampNanosecond(Some(1_553_990_400_000_000_000), Some("Europe/Brussels")) +/// ``` +/// +/// To strip off the timezone while keeping the display value the same, we need to +/// adjust the underlying timestamp with the timezone offset value using `adjust_to_local_time()` +/// +/// ```text +/// adjust_to_local_time(1_553_990_400_000_000_000, "Europe/Brussels") --> 1_553_994_000_000_000_000 +/// ``` +/// +/// The difference between `1_553_990_400_000_000_000` and `1_553_994_000_000_000_000` is +/// `3600_000_000_000` ns, which corresponds to 1 hour. This matches with the timezone +/// offset for "Europe/Brussels" for this date. +/// +/// Note that the offset varies with daylight savings time (DST), which makes this tricky! For +/// example, timezone "Europe/Brussels" has a 2-hour offset during DST and a 1-hour offset +/// when DST ends. +/// +/// Consequently, DataFusion can represent the timestamp in local time (with no offset or +/// timezone information) as +/// +/// ```text +/// TimestampNanosecond(Some(1_553_994_000_000_000_000), None) +/// ``` +/// +/// which is displayed as follows in datafusion-cli: +/// +/// ```text +/// 2019-03-31T01:00:00 +/// ``` +/// +/// See `test_adjust_to_local_time()` for example +fn adjust_to_local_time(ts: i64, tz: Tz) -> Result { + fn convert_timestamp(ts: i64, converter: F) -> Result> + where + F: Fn(i64) -> MappedLocalTime>, + { + match converter(ts) { + MappedLocalTime::Ambiguous(earliest, latest) => exec_err!( + "Ambiguous timestamp. Do you mean {:?} or {:?}", + earliest, + latest + ), + MappedLocalTime::None => exec_err!( + "The local time does not exist because there is a gap in the local time." + ), + MappedLocalTime::Single(date_time) => Ok(date_time), + } + } + + let date_time = match T::UNIT { + Nanosecond => Utc.timestamp_nanos(ts), + Microsecond => convert_timestamp(ts, |ts| Utc.timestamp_micros(ts))?, + Millisecond => convert_timestamp(ts, |ts| Utc.timestamp_millis_opt(ts))?, + Second => convert_timestamp(ts, |ts| Utc.timestamp_opt(ts, 0))?, + }; + + let offset_seconds: i64 = tz + .offset_from_utc_datetime(&date_time.naive_utc()) + .fix() + .local_minus_utc() as i64; + + let adjusted_date_time = date_time.add( + // This should not fail under normal circumstances as the + // maximum possible offset is 26 hours (93,600 seconds) + TimeDelta::try_seconds(offset_seconds) + .ok_or(DataFusionError::Internal("Offset seconds should be less than i64::MAX / 1_000 or greater than -i64::MAX / 1_000".to_string()))?, + ); + + // convert the naive datetime back to i64 + match T::UNIT { + Nanosecond => adjusted_date_time.timestamp_nanos_opt().ok_or( + DataFusionError::Internal( + "Failed to convert DateTime to timestamp in nanosecond. This error may occur if the date is out of range. The supported date ranges are between 1677-09-21T00:12:43.145224192 and 2262-04-11T23:47:16.854775807".to_string(), + ), + ), + Microsecond => Ok(adjusted_date_time.timestamp_micros()), + Millisecond => Ok(adjusted_date_time.timestamp_millis()), + Second => Ok(adjusted_date_time.timestamp()), + } +} + +impl ScalarUDFImpl for ToLocalTimeFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "to_local_time" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if arg_types.len() != 1 { + return exec_err!( + "to_local_time function requires 1 argument, got {:?}", + arg_types.len() + ); + } + + match &arg_types[0] { + Timestamp(timeunit, _) => Ok(Timestamp(*timeunit, None)), + _ => exec_err!( + "The to_local_time function can only accept timestamp as the arg, got {:?}", arg_types[0] + ) + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return exec_err!( + "to_local_time function requires 1 argument, got {:?}", + args.len() + ); + } + + self.to_local_time(args) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 1 { + return plan_err!( + "to_local_time function requires 1 argument, got {:?}", + arg_types.len() + ); + } + + let first_arg = arg_types[0].clone(); + match &first_arg { + Timestamp(Nanosecond, timezone) => { + Ok(vec![Timestamp(Nanosecond, timezone.clone())]) + } + Timestamp(Microsecond, timezone) => { + Ok(vec![Timestamp(Microsecond, timezone.clone())]) + } + Timestamp(Millisecond, timezone) => { + Ok(vec![Timestamp(Millisecond, timezone.clone())]) + } + Timestamp(Second, timezone) => Ok(vec![Timestamp(Second, timezone.clone())]), + _ => plan_err!("The to_local_time function can only accept Timestamp as the arg got {first_arg}"), + } + } + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_local_time_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_to_local_time_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Converts a timestamp with a timezone to a timestamp without a timezone (with no offset or timezone information). This function handles daylight saving time changes.") + .with_syntax_example("to_local_time(expression)") + .with_argument( + "expression", + "Time expression to operate on. Can be a constant, column, or function." + ) + .with_sql_example(r#"```sql +> SELECT to_local_time('2024-04-01T00:00:20Z'::timestamp); ++---------------------------------------------+ +| to_local_time(Utf8("2024-04-01T00:00:20Z")) | ++---------------------------------------------+ +| 2024-04-01T00:00:20 | ++---------------------------------------------+ + +> SELECT to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels'); ++---------------------------------------------+ +| to_local_time(Utf8("2024-04-01T00:00:20Z")) | ++---------------------------------------------+ +| 2024-04-01T00:00:20 | ++---------------------------------------------+ + +> SELECT + time, + arrow_typeof(time) as type, + to_local_time(time) as to_local_time, + arrow_typeof(to_local_time(time)) as to_local_time_type +FROM ( + SELECT '2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels' AS time +); ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ +| time | type | to_local_time | to_local_time_type | ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ +| 2024-04-01T00:00:20+02:00 | Timestamp(Nanosecond, Some("Europe/Brussels")) | 2024-04-01T00:00:20 | Timestamp(Nanosecond, None) | ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ + +# combine `to_local_time()` with `date_bin()` to bin on boundaries in the timezone rather +# than UTC boundaries + +> SELECT date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AS date_bin; ++---------------------+ +| date_bin | ++---------------------+ +| 2024-04-01T00:00:00 | ++---------------------+ + +> SELECT date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AT TIME ZONE 'Europe/Brussels' AS date_bin_with_timezone; ++---------------------------+ +| date_bin_with_timezone | ++---------------------------+ +| 2024-04-01T00:00:00+02:00 | ++---------------------------+ +```"#) + .build() + .unwrap() + }) +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::array::{types::TimestampNanosecondType, TimestampNanosecondArray}; + use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; + use arrow::datatypes::{DataType, TimeUnit}; + use chrono::NaiveDateTime; + use datafusion_common::ScalarValue; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use super::{adjust_to_local_time, ToLocalTimeFunc}; + + #[test] + fn test_adjust_to_local_time() { + let timestamp_str = "2020-03-31T13:40:00"; + let tz: arrow::array::timezone::Tz = + "America/New_York".parse().expect("Invalid timezone"); + + let timestamp = timestamp_str + .parse::() + .unwrap() + .and_local_timezone(tz) // this is in a local timezone + .unwrap() + .timestamp_nanos_opt() + .unwrap(); + + let expected_timestamp = timestamp_str + .parse::() + .unwrap() + .and_utc() // this is in UTC + .timestamp_nanos_opt() + .unwrap(); + + let res = adjust_to_local_time::(timestamp, tz).unwrap(); + assert_eq!(res, expected_timestamp); + } + + #[test] + fn test_to_local_time_scalar() { + let timezone = Some("Europe/Brussels".into()); + let timestamps_with_timezone = vec![ + ( + ScalarValue::TimestampNanosecond( + Some(1_123_123_000_000_000_000), + timezone.clone(), + ), + ScalarValue::TimestampNanosecond(Some(1_123_130_200_000_000_000), None), + ), + ( + ScalarValue::TimestampMicrosecond( + Some(1_123_123_000_000_000), + timezone.clone(), + ), + ScalarValue::TimestampMicrosecond(Some(1_123_130_200_000_000), None), + ), + ( + ScalarValue::TimestampMillisecond( + Some(1_123_123_000_000), + timezone.clone(), + ), + ScalarValue::TimestampMillisecond(Some(1_123_130_200_000), None), + ), + ( + ScalarValue::TimestampSecond(Some(1_123_123_000), timezone), + ScalarValue::TimestampSecond(Some(1_123_130_200), None), + ), + ]; + + for (input, expected) in timestamps_with_timezone { + test_to_local_time_helper(input, expected); + } + } + + #[test] + fn test_timezone_with_daylight_savings() { + let timezone_str = "America/New_York"; + let tz: arrow::array::timezone::Tz = + timezone_str.parse().expect("Invalid timezone"); + + // Test data: + // ( + // the string display of the input timestamp, + // the i64 representation of the timestamp before adjustment in nanosecond, + // the i64 representation of the timestamp after adjustment in nanosecond, + // ) + let test_cases = vec![ + ( + // DST time + "2020-03-31T13:40:00", + 1_585_676_400_000_000_000, + 1_585_662_000_000_000_000, + ), + ( + // End of DST + "2020-11-04T14:06:40", + 1_604_516_800_000_000_000, + 1_604_498_800_000_000_000, + ), + ]; + + for ( + input_timestamp_str, + expected_input_timestamp, + expected_adjusted_timestamp, + ) in test_cases + { + let input_timestamp = input_timestamp_str + .parse::() + .unwrap() + .and_local_timezone(tz) // this is in a local timezone + .unwrap() + .timestamp_nanos_opt() + .unwrap(); + assert_eq!(input_timestamp, expected_input_timestamp); + + let expected_timestamp = input_timestamp_str + .parse::() + .unwrap() + .and_utc() // this is in UTC + .timestamp_nanos_opt() + .unwrap(); + assert_eq!(expected_timestamp, expected_adjusted_timestamp); + + let input = ScalarValue::TimestampNanosecond( + Some(input_timestamp), + Some(timezone_str.into()), + ); + let expected = + ScalarValue::TimestampNanosecond(Some(expected_timestamp), None); + test_to_local_time_helper(input, expected) + } + } + + fn test_to_local_time_helper(input: ScalarValue, expected: ScalarValue) { + let res = ToLocalTimeFunc::new() + .invoke_batch(&[ColumnarValue::Scalar(input)], 1) + .unwrap(); + match res { + ColumnarValue::Scalar(res) => { + assert_eq!(res, expected); + } + _ => panic!("unexpected return type"), + } + } + + #[test] + fn test_to_local_time_timezones_array() { + let cases = [ + ( + vec![ + "2020-09-08T00:00:00", + "2020-09-08T01:00:00", + "2020-09-08T02:00:00", + "2020-09-08T03:00:00", + "2020-09-08T04:00:00", + ], + None::>, + vec![ + "2020-09-08T00:00:00", + "2020-09-08T01:00:00", + "2020-09-08T02:00:00", + "2020-09-08T03:00:00", + "2020-09-08T04:00:00", + ], + ), + ( + vec![ + "2020-09-08T00:00:00", + "2020-09-08T01:00:00", + "2020-09-08T02:00:00", + "2020-09-08T03:00:00", + "2020-09-08T04:00:00", + ], + Some("+01:00".into()), + vec![ + "2020-09-08T00:00:00", + "2020-09-08T01:00:00", + "2020-09-08T02:00:00", + "2020-09-08T03:00:00", + "2020-09-08T04:00:00", + ], + ), + ]; + + cases.iter().for_each(|(source, _tz_opt, expected)| { + let input = source + .iter() + .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) + .collect::(); + let right = expected + .iter() + .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) + .collect::(); + let batch_size = input.len(); + let result = ToLocalTimeFunc::new() + .invoke_batch(&[ColumnarValue::Array(Arc::new(input))], batch_size) + .unwrap(); + if let ColumnarValue::Array(result) = result { + assert_eq!( + result.data_type(), + &DataType::Timestamp(TimeUnit::Nanosecond, None) + ); + let left = arrow::array::cast::as_primitive_array::< + TimestampNanosecondType, + >(&result); + assert_eq!(left, &right); + } else { + panic!("unexpected column type"); + } + }); + } +} diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index a7bcca62944c..f15fad701c55 100644 --- a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -16,18 +16,21 @@ // under the License. use std::any::Any; +use std::sync::{Arc, OnceLock}; -use arrow::datatypes::DataType::Timestamp; +use arrow::datatypes::DataType::*; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; use arrow::datatypes::{ - ArrowTimestampType, DataType, TimestampMicrosecondType, TimestampMillisecondType, - TimestampNanosecondType, TimestampSecondType, + ArrowTimestampType, DataType, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; -use datafusion_common::{exec_err, Result, ScalarType}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; - use crate::datetime::common::*; +use datafusion_common::{exec_err, Result, ScalarType}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct ToTimestampFunc { @@ -143,8 +146,8 @@ impl ScalarUDFImpl for ToTimestampFunc { &self.signature } - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(Timestamp(Nanosecond, None)) + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(return_type_for(&arg_types[0], Nanosecond)) } fn invoke(&self, args: &[ColumnarValue]) -> Result { @@ -161,13 +164,16 @@ impl ScalarUDFImpl for ToTimestampFunc { } match args[0].data_type() { - DataType::Int32 | DataType::Int64 => args[0] + Int32 | Int64 => args[0] .cast_to(&Timestamp(Second, None), None)? .cast_to(&Timestamp(Nanosecond, None), None), - DataType::Null | DataType::Float64 | Timestamp(_, None) => { + Null | Float64 | Timestamp(_, None) => { args[0].cast_to(&Timestamp(Nanosecond, None), None) } - DataType::Utf8 => { + Timestamp(_, Some(tz)) => { + args[0].cast_to(&Timestamp(Nanosecond, Some(tz)), None) + } + Utf8View | LargeUtf8 | Utf8 => { to_timestamp_impl::(args, "to_timestamp") } other => { @@ -178,6 +184,50 @@ impl ScalarUDFImpl for ToTimestampFunc { } } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_timestamp_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_to_timestamp_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description(r#" +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00Z`). Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats] are provided. Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp. + +Note: `to_timestamp` returns `Timestamp(Nanosecond)`. The supported range for integer input is between `-9223372037` and `9223372036`. Supported range for string input is between `1677-09-21T00:12:44.0` and `2262-04-11T23:47:16.0`. Please use `to_timestamp_seconds` for the input outside of supported bounds. +"#) + .with_syntax_example("to_timestamp(expression[, ..., format_n])") + .with_argument( + "expression", + "Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators." + ) + .with_argument( + "format_n", + "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned.", + ) + .with_sql_example(r#"```sql +> select to_timestamp('2023-01-31T09:26:56.123456789-05:00'); ++-----------------------------------------------------------+ +| to_timestamp(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++-----------------------------------------------------------+ +| 2023-01-31T14:26:56.123456789 | ++-----------------------------------------------------------+ +> select to_timestamp('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++--------------------------------------------------------------------------------------------------------+ +| to_timestamp(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++--------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00.123456789 | ++--------------------------------------------------------------------------------------------------------+ +``` +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +"#) + .build() + .unwrap() + }) } impl ScalarUDFImpl for ToTimestampSecondsFunc { @@ -193,8 +243,8 @@ impl ScalarUDFImpl for ToTimestampSecondsFunc { &self.signature } - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(Timestamp(Second, None)) + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(return_type_for(&arg_types[0], Second)) } fn invoke(&self, args: &[ColumnarValue]) -> Result { @@ -211,10 +261,11 @@ impl ScalarUDFImpl for ToTimestampSecondsFunc { } match args[0].data_type() { - DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => { + Null | Int32 | Int64 | Timestamp(_, None) => { args[0].cast_to(&Timestamp(Second, None), None) } - DataType::Utf8 => { + Timestamp(_, Some(tz)) => args[0].cast_to(&Timestamp(Second, Some(tz)), None), + Utf8View | LargeUtf8 | Utf8 => { to_timestamp_impl::(args, "to_timestamp_seconds") } other => { @@ -225,6 +276,46 @@ impl ScalarUDFImpl for ToTimestampSecondsFunc { } } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_timestamp_seconds_doc()) + } +} + +static TO_TIMESTAMP_SECONDS_DOC: OnceLock = OnceLock::new(); + +fn get_to_timestamp_seconds_doc() -> &'static Documentation { + TO_TIMESTAMP_SECONDS_DOC.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp.") + .with_syntax_example("to_timestamp_seconds(expression[, ..., format_n])") + .with_argument( + "expression", + "Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators." + ) + .with_argument( + "format_n", + "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned.", + ) + .with_sql_example(r#"```sql +> select to_timestamp_seconds('2023-01-31T09:26:56.123456789-05:00'); ++-------------------------------------------------------------------+ +| to_timestamp_seconds(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++-------------------------------------------------------------------+ +| 2023-01-31T14:26:56 | ++-------------------------------------------------------------------+ +> select to_timestamp_seconds('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++----------------------------------------------------------------------------------------------------------------+ +| to_timestamp_seconds(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++----------------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00 | ++----------------------------------------------------------------------------------------------------------------+ +``` +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +"#) + .build() + .unwrap() + }) } impl ScalarUDFImpl for ToTimestampMillisFunc { @@ -240,8 +331,8 @@ impl ScalarUDFImpl for ToTimestampMillisFunc { &self.signature } - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(Timestamp(Millisecond, None)) + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(return_type_for(&arg_types[0], Millisecond)) } fn invoke(&self, args: &[ColumnarValue]) -> Result { @@ -258,10 +349,13 @@ impl ScalarUDFImpl for ToTimestampMillisFunc { } match args[0].data_type() { - DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => { + Null | Int32 | Int64 | Timestamp(_, None) => { args[0].cast_to(&Timestamp(Millisecond, None), None) } - DataType::Utf8 => { + Timestamp(_, Some(tz)) => { + args[0].cast_to(&Timestamp(Millisecond, Some(tz)), None) + } + Utf8View | LargeUtf8 | Utf8 => { to_timestamp_impl::(args, "to_timestamp_millis") } other => { @@ -272,6 +366,46 @@ impl ScalarUDFImpl for ToTimestampMillisFunc { } } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_timestamp_millis_doc()) + } +} + +static TO_TIMESTAMP_MILLIS_DOC: OnceLock = OnceLock::new(); + +fn get_to_timestamp_millis_doc() -> &'static Documentation { + TO_TIMESTAMP_MILLIS_DOC.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. Integers and unsigned integers are interpreted as milliseconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp.") + .with_syntax_example("to_timestamp_millis(expression[, ..., format_n])") + .with_argument( + "expression", + "Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators." + ) + .with_argument( + "format_n", + "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned.", + ) + .with_sql_example(r#"```sql +> select to_timestamp_millis('2023-01-31T09:26:56.123456789-05:00'); ++------------------------------------------------------------------+ +| to_timestamp_millis(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++------------------------------------------------------------------+ +| 2023-01-31T14:26:56.123 | ++------------------------------------------------------------------+ +> select to_timestamp_millis('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++---------------------------------------------------------------------------------------------------------------+ +| to_timestamp_millis(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++---------------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00.123 | ++---------------------------------------------------------------------------------------------------------------+ +``` +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +"#) + .build() + .unwrap() + }) } impl ScalarUDFImpl for ToTimestampMicrosFunc { @@ -287,8 +421,8 @@ impl ScalarUDFImpl for ToTimestampMicrosFunc { &self.signature } - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(Timestamp(Microsecond, None)) + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(return_type_for(&arg_types[0], Microsecond)) } fn invoke(&self, args: &[ColumnarValue]) -> Result { @@ -305,10 +439,13 @@ impl ScalarUDFImpl for ToTimestampMicrosFunc { } match args[0].data_type() { - DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => { + Null | Int32 | Int64 | Timestamp(_, None) => { args[0].cast_to(&Timestamp(Microsecond, None), None) } - DataType::Utf8 => { + Timestamp(_, Some(tz)) => { + args[0].cast_to(&Timestamp(Microsecond, Some(tz)), None) + } + Utf8View | LargeUtf8 | Utf8 => { to_timestamp_impl::(args, "to_timestamp_micros") } other => { @@ -319,6 +456,46 @@ impl ScalarUDFImpl for ToTimestampMicrosFunc { } } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_timestamp_micros_doc()) + } +} + +static TO_TIMESTAMP_MICROS_DOC: OnceLock = OnceLock::new(); + +fn get_to_timestamp_micros_doc() -> &'static Documentation { + TO_TIMESTAMP_MICROS_DOC.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and unsigned integers are interpreted as microseconds since the unix epoch (`1970-01-01T00:00:00Z`) Returns the corresponding timestamp.") + .with_syntax_example("to_timestamp_micros(expression[, ..., format_n])") + .with_argument( + "expression", + "Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators." + ) + .with_argument( + "format_n", + "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned.", + ) + .with_sql_example(r#"```sql +> select to_timestamp_micros('2023-01-31T09:26:56.123456789-05:00'); ++------------------------------------------------------------------+ +| to_timestamp_micros(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++------------------------------------------------------------------+ +| 2023-01-31T14:26:56.123456 | ++------------------------------------------------------------------+ +> select to_timestamp_micros('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++---------------------------------------------------------------------------------------------------------------+ +| to_timestamp_micros(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++---------------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00.123456 | ++---------------------------------------------------------------------------------------------------------------+ +``` +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +"#) + .build() + .unwrap() + }) } impl ScalarUDFImpl for ToTimestampNanosFunc { @@ -334,8 +511,8 @@ impl ScalarUDFImpl for ToTimestampNanosFunc { &self.signature } - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(Timestamp(Nanosecond, None)) + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(return_type_for(&arg_types[0], Nanosecond)) } fn invoke(&self, args: &[ColumnarValue]) -> Result { @@ -352,10 +529,13 @@ impl ScalarUDFImpl for ToTimestampNanosFunc { } match args[0].data_type() { - DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => { + Null | Int32 | Int64 | Timestamp(_, None) => { args[0].cast_to(&Timestamp(Nanosecond, None), None) } - DataType::Utf8 => { + Timestamp(_, Some(tz)) => { + args[0].cast_to(&Timestamp(Nanosecond, Some(tz)), None) + } + Utf8View | LargeUtf8 | Utf8 => { to_timestamp_impl::(args, "to_timestamp_nanos") } other => { @@ -366,6 +546,55 @@ impl ScalarUDFImpl for ToTimestampNanosFunc { } } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_timestamp_nanos_doc()) + } +} + +static TO_TIMESTAMP_NANOS_DOC: OnceLock = OnceLock::new(); + +fn get_to_timestamp_nanos_doc() -> &'static Documentation { + TO_TIMESTAMP_NANOS_DOC.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp.") + .with_syntax_example("to_timestamp_nanos(expression[, ..., format_n])") + .with_argument( + "expression", + "Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators." + ) + .with_argument( + "format_n", + "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned.", + ) + .with_sql_example(r#"```sql +> select to_timestamp_nanos('2023-01-31T09:26:56.123456789-05:00'); ++-----------------------------------------------------------------+ +| to_timestamp_nanos(Utf8("2023-01-31T09:26:56.123456789-05:00")) | ++-----------------------------------------------------------------+ +| 2023-01-31T14:26:56.123456789 | ++-----------------------------------------------------------------+ +> select to_timestamp_nanos('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); ++--------------------------------------------------------------------------------------------------------------+ +| to_timestamp_nanos(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | ++--------------------------------------------------------------------------------------------------------------+ +| 2023-05-17T03:59:00.123456789 | ++---------------------------------------------------------------------------------------------------------------+ +``` +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) +"#) + .build() + .unwrap() + }) +} + +/// Returns the return type for the to_timestamp_* function, preserving +/// the timezone if it exists. +fn return_type_for(arg: &DataType, unit: TimeUnit) -> DataType { + match arg { + Timestamp(_, Some(tz)) => Timestamp(unit, Some(Arc::clone(tz))), + _ => Timestamp(unit, None), + } } fn to_timestamp_impl>( @@ -407,7 +636,6 @@ mod tests { use arrow::array::{ArrayRef, Int64Array, StringBuilder}; use arrow::datatypes::TimeUnit; use chrono::Utc; - use datafusion_common::{assert_contains, DataFusionError, ScalarValue}; use datafusion_expr::ScalarFunctionImplementation; @@ -591,7 +819,7 @@ mod tests { ColumnarValue::Array(Arc::new(date_string_builder.finish()) as ArrayRef); let expected_err = - "Arrow error: Parser error: Invalid timezone \"ZZ\": 'ZZ' is not a valid timezone"; + "Arrow error: Parser error: Invalid timezone \"ZZ\": failed to parse timezone"; match to_timestamp(&[string_array]) { Ok(_) => panic!("Expected error but got success"), Err(e) => { @@ -670,6 +898,10 @@ mod tests { parse_timestamp_formatted("09-08-2020 13/42/29", "%m-%d-%Y %H/%M/%S") .unwrap() ); + assert_eq!( + 1642896000000000000, + parse_timestamp_formatted("2022-01-23", "%Y-%m-%d").unwrap() + ); } fn parse_timestamp_formatted(s: &str, format: &str) -> Result { @@ -736,6 +968,103 @@ mod tests { } } + #[test] + fn test_tz() { + let udfs: Vec> = vec![ + Box::new(ToTimestampFunc::new()), + Box::new(ToTimestampSecondsFunc::new()), + Box::new(ToTimestampMillisFunc::new()), + Box::new(ToTimestampNanosFunc::new()), + Box::new(ToTimestampSecondsFunc::new()), + ]; + + let mut nanos_builder = TimestampNanosecondArray::builder(2); + let mut millis_builder = TimestampMillisecondArray::builder(2); + let mut micros_builder = TimestampMicrosecondArray::builder(2); + let mut sec_builder = TimestampSecondArray::builder(2); + + nanos_builder.append_value(1599572549190850000); + millis_builder.append_value(1599572549190); + micros_builder.append_value(1599572549190850); + sec_builder.append_value(1599572549); + + let nanos_timestamps = + Arc::new(nanos_builder.finish().with_timezone("UTC")) as ArrayRef; + let millis_timestamps = + Arc::new(millis_builder.finish().with_timezone("UTC")) as ArrayRef; + let micros_timestamps = + Arc::new(micros_builder.finish().with_timezone("UTC")) as ArrayRef; + let sec_timestamps = + Arc::new(sec_builder.finish().with_timezone("UTC")) as ArrayRef; + + let arrays = &[ + ColumnarValue::Array(Arc::clone(&nanos_timestamps)), + ColumnarValue::Array(Arc::clone(&millis_timestamps)), + ColumnarValue::Array(Arc::clone(µs_timestamps)), + ColumnarValue::Array(Arc::clone(&sec_timestamps)), + ]; + + for udf in &udfs { + for array in arrays { + let rt = udf.return_type(&[array.data_type()]).unwrap(); + assert!(matches!(rt, Timestamp(_, Some(_)))); + + let res = udf + .invoke_batch(&[array.clone()], 1) + .expect("that to_timestamp parsed values without error"); + let array = match res { + ColumnarValue::Array(res) => res, + _ => panic!("Expected a columnar array"), + }; + let ty = array.data_type(); + assert!(matches!(ty, Timestamp(_, Some(_)))); + } + } + + let mut nanos_builder = TimestampNanosecondArray::builder(2); + let mut millis_builder = TimestampMillisecondArray::builder(2); + let mut micros_builder = TimestampMicrosecondArray::builder(2); + let mut sec_builder = TimestampSecondArray::builder(2); + let mut i64_builder = Int64Array::builder(2); + + nanos_builder.append_value(1599572549190850000); + millis_builder.append_value(1599572549190); + micros_builder.append_value(1599572549190850); + sec_builder.append_value(1599572549); + i64_builder.append_value(1599572549); + + let nanos_timestamps = Arc::new(nanos_builder.finish()) as ArrayRef; + let millis_timestamps = Arc::new(millis_builder.finish()) as ArrayRef; + let micros_timestamps = Arc::new(micros_builder.finish()) as ArrayRef; + let sec_timestamps = Arc::new(sec_builder.finish()) as ArrayRef; + let i64_timestamps = Arc::new(i64_builder.finish()) as ArrayRef; + + let arrays = &[ + ColumnarValue::Array(Arc::clone(&nanos_timestamps)), + ColumnarValue::Array(Arc::clone(&millis_timestamps)), + ColumnarValue::Array(Arc::clone(µs_timestamps)), + ColumnarValue::Array(Arc::clone(&sec_timestamps)), + ColumnarValue::Array(Arc::clone(&i64_timestamps)), + ]; + + for udf in &udfs { + for array in arrays { + let rt = udf.return_type(&[array.data_type()]).unwrap(); + assert!(matches!(rt, Timestamp(_, None))); + + let res = udf + .invoke_batch(&[array.clone()], 1) + .expect("that to_timestamp parsed values without error"); + let array = match res { + ColumnarValue::Array(res) => res, + _ => panic!("Expected a columnar array"), + }; + let ty = array.data_type(); + assert!(matches!(ty, Timestamp(_, None))); + } + } + } + #[test] fn test_to_timestamp_arg_validation() { let mut date_string_builder = StringBuilder::with_capacity(2, 1024); @@ -807,6 +1136,8 @@ mod tests { .expect("that to_timestamp with format args parsed values without error"); if let ColumnarValue::Array(parsed_array) = parsed_timestamps { assert_eq!(parsed_array.len(), 1); + assert!(matches!(parsed_array.data_type(), Timestamp(_, None))); + match time_unit { Nanosecond => { assert_eq!(nanos_expected_timestamps, parsed_array.as_ref()) diff --git a/datafusion/functions/src/datetime/to_unixtime.rs b/datafusion/functions/src/datetime/to_unixtime.rs index 396dadccb4b3..dd90ce6a6c96 100644 --- a/datafusion/functions/src/datetime/to_unixtime.rs +++ b/datafusion/functions/src/datetime/to_unixtime.rs @@ -15,15 +15,16 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; - -use arrow::datatypes::{DataType, TimeUnit}; - +use super::to_timestamp::ToTimestampSecondsFunc; use crate::datetime::common::*; +use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::{exec_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; - -use super::to_timestamp::ToTimestampSecondsFunc; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_DATETIME; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use std::any::Any; +use std::sync::OnceLock; #[derive(Debug)] pub struct ToUnixtimeFunc { @@ -61,7 +62,11 @@ impl ScalarUDFImpl for ToUnixtimeFunc { Ok(DataType::Int64) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { + fn invoke_batch( + &self, + args: &[ColumnarValue], + batch_size: usize, + ) -> Result { if args.is_empty() { return exec_err!("to_unixtime function requires 1 or more arguments, got 0"); } @@ -79,11 +84,49 @@ impl ScalarUDFImpl for ToUnixtimeFunc { .cast_to(&DataType::Timestamp(TimeUnit::Second, None), None)? .cast_to(&DataType::Int64, None), DataType::Utf8 => ToTimestampSecondsFunc::new() - .invoke(args)? + .invoke_batch(args, batch_size)? .cast_to(&DataType::Int64, None), other => { exec_err!("Unsupported data type {:?} for function to_unixtime", other) } } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_unixtime_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_to_unixtime_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_DATETIME) + .with_description("Converts a value to seconds since the unix epoch (`1970-01-01T00:00:00Z`). Supports strings, dates, timestamps and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided.") + .with_syntax_example("to_unixtime(expression[, ..., format_n])") + .with_argument( + "expression", + "Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators." + ).with_argument( + "format_n", + "Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned.") + .with_sql_example(r#" +```sql +> select to_unixtime('2020-09-08T12:00:00+00:00'); ++------------------------------------------------+ +| to_unixtime(Utf8("2020-09-08T12:00:00+00:00")) | ++------------------------------------------------+ +| 1599566400 | ++------------------------------------------------+ +> select to_unixtime('01-14-2023 01:01:30+05:30', '%q', '%d-%m-%Y %H/%M/%S', '%+', '%m-%d-%Y %H:%M:%S%#z'); ++-----------------------------------------------------------------------------------------------------------------------------+ +| to_unixtime(Utf8("01-14-2023 01:01:30+05:30"),Utf8("%q"),Utf8("%d-%m-%Y %H/%M/%S"),Utf8("%+"),Utf8("%m-%d-%Y %H:%M:%S%#z")) | ++-----------------------------------------------------------------------------------------------------------------------------+ +| 1673638290 | ++-----------------------------------------------------------------------------------------------------------------------------+ +``` +"#) + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/encoding/inner.rs b/datafusion/functions/src/encoding/inner.rs index d9ce299a2602..4f91879f94db 100644 --- a/datafusion/functions/src/encoding/inner.rs +++ b/datafusion/functions/src/encoding/inner.rs @@ -18,9 +18,12 @@ //! Encoding expressions use arrow::{ - array::{Array, ArrayRef, BinaryArray, OffsetSizeTrait, StringArray}, - datatypes::DataType, + array::{ + Array, ArrayRef, BinaryArray, GenericByteArray, OffsetSizeTrait, StringArray, + }, + datatypes::{ByteArrayType, DataType}, }; +use arrow_buffer::{Buffer, OffsetBufferBuilder}; use base64::{engine::general_purpose, Engine as _}; use datafusion_common::{ cast::{as_generic_binary_array, as_generic_string_array}, @@ -28,11 +31,11 @@ use datafusion_common::{ }; use datafusion_common::{exec_err, ScalarValue}; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::ColumnarValue; -use std::sync::Arc; +use datafusion_expr::{ColumnarValue, Documentation}; +use std::sync::{Arc, OnceLock}; use std::{fmt, str::FromStr}; -use datafusion_expr::TypeSignature::*; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_BINARY_STRING; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; @@ -49,21 +52,28 @@ impl Default for EncodeFunc { impl EncodeFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![Binary, Utf8]), - Exact(vec![LargeBinary, Utf8]), - ], - Volatility::Immutable, - ), + signature: Signature::user_defined(Volatility::Immutable), } } } +static ENCODE_DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_encode_doc() -> &'static Documentation { + ENCODE_DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_BINARY_STRING) + .with_description("Encode binary data into a textual representation.") + .with_syntax_example("encode(expression, format)") + .with_argument("expression", "Expression containing string or binary data") + .with_argument("format", "Supported formats are: `base64`, `hex`") + .with_related_udf("decode") + .build() + .unwrap() + }) +} + impl ScalarUDFImpl for EncodeFunc { fn as_any(&self) -> &dyn Any { self @@ -77,23 +87,43 @@ impl ScalarUDFImpl for EncodeFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - - Ok(match arg_types[0] { - Utf8 => Utf8, - LargeUtf8 => LargeUtf8, - Binary => Utf8, - LargeBinary => LargeUtf8, - Null => Null, - _ => { - return plan_err!("The encode function can only accept utf8 or binary."); - } - }) + Ok(arg_types[0].to_owned()) } fn invoke(&self, args: &[ColumnarValue]) -> Result { encode(args) } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 2 { + return plan_err!( + "{} expects to get 2 arguments, but got {}", + self.name(), + arg_types.len() + ); + } + + if arg_types[1] != DataType::Utf8 { + return Err(DataFusionError::Plan("2nd argument should be Utf8".into())); + } + + match arg_types[0] { + DataType::Utf8 | DataType::Binary | DataType::Null => { + Ok(vec![DataType::Utf8; 2]) + } + DataType::LargeUtf8 | DataType::LargeBinary => { + Ok(vec![DataType::LargeUtf8, DataType::Utf8]) + } + _ => plan_err!( + "1st argument should be Utf8 or Binary or Null, got {:?}", + arg_types[0] + ), + } + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_encode_doc()) + } } #[derive(Debug)] @@ -109,21 +139,28 @@ impl Default for DecodeFunc { impl DecodeFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![Binary, Utf8]), - Exact(vec![LargeBinary, Utf8]), - ], - Volatility::Immutable, - ), + signature: Signature::user_defined(Volatility::Immutable), } } } +static DECODE_DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_decode_doc() -> &'static Documentation { + DECODE_DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_BINARY_STRING) + .with_description("Decode binary data from textual representation in string.") + .with_syntax_example("decode(expression, format)") + .with_argument("expression", "Expression containing encoded string data") + .with_argument("format", "Same arguments as [encode](#encode)") + .with_related_udf("encode") + .build() + .unwrap() + }) +} + impl ScalarUDFImpl for DecodeFunc { fn as_any(&self) -> &dyn Any { self @@ -137,23 +174,43 @@ impl ScalarUDFImpl for DecodeFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - - Ok(match arg_types[0] { - Utf8 => Binary, - LargeUtf8 => LargeBinary, - Binary => Binary, - LargeBinary => LargeBinary, - Null => Null, - _ => { - return plan_err!("The decode function can only accept utf8 or binary."); - } - }) + Ok(arg_types[0].to_owned()) } fn invoke(&self, args: &[ColumnarValue]) -> Result { decode(args) } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 2 { + return plan_err!( + "{} expects to get 2 arguments, but got {}", + self.name(), + arg_types.len() + ); + } + + if arg_types[1] != DataType::Utf8 { + return plan_err!("2nd argument should be Utf8"); + } + + match arg_types[0] { + DataType::Utf8 | DataType::Binary | DataType::Null => { + Ok(vec![DataType::Binary, DataType::Utf8]) + } + DataType::LargeUtf8 | DataType::LargeBinary => { + Ok(vec![DataType::LargeBinary, DataType::Utf8]) + } + _ => plan_err!( + "1st argument should be Utf8 or Binary or Null, got {:?}", + arg_types[0] + ), + } + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_decode_doc()) + } } #[derive(Debug, Copy, Clone)] @@ -232,16 +289,22 @@ fn base64_encode(input: &[u8]) -> String { general_purpose::STANDARD_NO_PAD.encode(input) } -fn hex_decode(input: &[u8]) -> Result> { - hex::decode(input).map_err(|e| { +fn hex_decode(input: &[u8], buf: &mut [u8]) -> Result { + // only write input / 2 bytes to buf + let out_len = input.len() / 2; + let buf = &mut buf[..out_len]; + hex::decode_to_slice(input, buf).map_err(|e| { DataFusionError::Internal(format!("Failed to decode from hex: {}", e)) - }) + })?; + Ok(out_len) } -fn base64_decode(input: &[u8]) -> Result> { - general_purpose::STANDARD_NO_PAD.decode(input).map_err(|e| { - DataFusionError::Internal(format!("Failed to decode from base64: {}", e)) - }) +fn base64_decode(input: &[u8], buf: &mut [u8]) -> Result { + general_purpose::STANDARD_NO_PAD + .decode_slice(input, buf) + .map_err(|e| { + DataFusionError::Internal(format!("Failed to decode from base64: {}", e)) + }) } macro_rules! encode_to_array { @@ -254,14 +317,35 @@ macro_rules! encode_to_array { }}; } -macro_rules! decode_to_array { - ($METHOD: ident, $INPUT:expr) => {{ - let binary_array: BinaryArray = $INPUT - .iter() - .map(|x| x.map(|x| $METHOD(x.as_ref())).transpose()) - .collect::>()?; - Arc::new(binary_array) - }}; +fn decode_to_array( + method: F, + input: &GenericByteArray, + conservative_upper_bound_size: usize, +) -> Result +where + F: Fn(&[u8], &mut [u8]) -> Result, +{ + let mut values = vec![0; conservative_upper_bound_size]; + let mut offsets = OffsetBufferBuilder::new(input.len()); + let mut total_bytes_decoded = 0; + for v in input { + if let Some(v) = v { + let cursor = &mut values[total_bytes_decoded..]; + let decoded = method(v.as_ref(), cursor)?; + total_bytes_decoded += decoded; + offsets.push_length(decoded); + } else { + offsets.push_length(0); + } + } + // We reserved an upper bound size for the values buffer, but we only use the actual size + values.truncate(total_bytes_decoded); + let binary_array = BinaryArray::try_new( + offsets.finish(), + Buffer::from_vec(values), + input.nulls().cloned(), + )?; + Ok(Arc::new(binary_array)) } impl Encoding { @@ -368,10 +452,7 @@ impl Encoding { T: OffsetSizeTrait, { let input_value = as_generic_binary_array::(value)?; - let array: ArrayRef = match self { - Self::Base64 => decode_to_array!(base64_decode, input_value), - Self::Hex => decode_to_array!(hex_decode, input_value), - }; + let array = self.decode_byte_array(input_value)?; Ok(ColumnarValue::Array(array)) } @@ -380,12 +461,29 @@ impl Encoding { T: OffsetSizeTrait, { let input_value = as_generic_string_array::(value)?; - let array: ArrayRef = match self { - Self::Base64 => decode_to_array!(base64_decode, input_value), - Self::Hex => decode_to_array!(hex_decode, input_value), - }; + let array = self.decode_byte_array(input_value)?; Ok(ColumnarValue::Array(array)) } + + fn decode_byte_array( + &self, + input_value: &GenericByteArray, + ) -> Result { + match self { + Self::Base64 => { + let upper_bound = + base64::decoded_len_estimate(input_value.values().len()); + decode_to_array(base64_decode, input_value, upper_bound) + } + Self::Hex => { + // Calculate the upper bound for decoded byte size + // For hex encoding, each pair of hex characters (2 bytes) represents 1 byte when decoded + // So the upper bound is half the length of the input values. + let upper_bound = input_value.values().len() / 2; + decode_to_array(hex_decode, input_value, upper_bound) + } + } + } } impl fmt::Display for Encoding { diff --git a/datafusion/functions/src/encoding/mod.rs b/datafusion/functions/src/encoding/mod.rs index 49f914a68774..48171370ad58 100644 --- a/datafusion/functions/src/encoding/mod.rs +++ b/datafusion/functions/src/encoding/mod.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. +use datafusion_expr::ScalarUDF; +use std::sync::Arc; + pub mod inner; // create `encode` and `decode` UDFs @@ -22,7 +25,19 @@ make_udf_function!(inner::EncodeFunc, ENCODE, encode); make_udf_function!(inner::DecodeFunc, DECODE, decode); // Export the functions out of this package, both as expr_fn as well as a list of functions -export_functions!( - (encode, input encoding, "encode the `input`, using the `encoding`. encoding can be base64 or hex"), - (decode, input encoding, "decode the `input`, using the `encoding`. encoding can be base64 or hex") -); +pub mod expr_fn { + export_functions!( ( + encode, + "encode the `input`, using the `encoding`. encoding can be base64 or hex", + input encoding + ),( + decode, + "decode the `input`, using the `encoding`. encoding can be base64 or hex", + input encoding + )); +} + +/// Returns all DataFusion functions defined in this package +pub fn functions() -> Vec> { + vec![encode(), decode()] +} diff --git a/datafusion/functions/src/lib.rs b/datafusion/functions/src/lib.rs index 2a00839dc532..91f9449953e9 100644 --- a/datafusion/functions/src/lib.rs +++ b/datafusion/functions/src/lib.rs @@ -14,6 +14,8 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] //! Function packages for [DataFusion]. //! @@ -74,12 +76,14 @@ //! 3. Add a new feature to `Cargo.toml`, with any optional dependencies //! //! 4. Use the `make_package!` macro to expose the module when the -//! feature is enabled. +//! feature is enabled. //! //! [`ScalarUDF`]: datafusion_expr::ScalarUDF use datafusion_common::Result; use datafusion_execution::FunctionRegistry; +use datafusion_expr::ScalarUDF; use log::debug; +use std::sync::Arc; #[macro_use] pub mod macros; @@ -128,6 +132,11 @@ make_stub_package!(crypto, "crypto_expressions"); pub mod unicode; make_stub_package!(unicode, "unicode_expressions"); +#[cfg(any(feature = "datetime_expressions", feature = "unicode_expressions"))] +pub mod planner; + +pub mod strings; + mod utils; /// Fluent-style API for creating `Expr`s @@ -150,9 +159,9 @@ pub mod expr_fn { pub use super::unicode::expr_fn::*; } -/// Registers all enabled packages with a [`FunctionRegistry`] -pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { - let mut all_functions = core::functions() +/// Return all default functions +pub fn all_default_functions() -> Vec> { + core::functions() .into_iter() .chain(datetime::functions()) .chain(encoding::functions()) @@ -160,9 +169,15 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { .chain(regex::functions()) .chain(crypto::functions()) .chain(unicode::functions()) - .chain(string::functions()); + .chain(string::functions()) + .collect::>() +} + +/// Registers all enabled packages with a [`FunctionRegistry`] +pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { + let all_functions = all_default_functions(); - all_functions.try_for_each(|udf| { + all_functions.into_iter().try_for_each(|udf| { let existing_udf = registry.register_udf(udf)?; if let Some(existing_udf) = existing_udf { debug!("Overwrite existing UDF: {}", existing_udf.name()); @@ -171,3 +186,30 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { })?; Ok(()) } + +#[cfg(test)] +mod tests { + use crate::all_default_functions; + use datafusion_common::Result; + use std::collections::HashSet; + + #[test] + fn test_no_duplicate_name() -> Result<()> { + let mut names = HashSet::new(); + for func in all_default_functions() { + assert!( + names.insert(func.name().to_string().to_lowercase()), + "duplicate function name: {}", + func.name() + ); + for alias in func.aliases() { + assert!( + names.insert(alias.to_string().to_lowercase()), + "duplicate function name: {}", + alias + ); + } + } + Ok(()) + } +} diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index 5ee47bd3e8eb..9bc038e71edc 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -36,31 +36,37 @@ /// ] /// } /// ``` +/// +/// Exported functions accept: +/// - `Vec` argument (single argument followed by a comma) +/// - Variable number of `Expr` arguments (zero or more arguments, must be without commas) macro_rules! export_functions { - ($(($FUNC:ident, $($arg:ident)*, $DOC:expr)),*) => { - pub mod expr_fn { - $( - #[doc = $DOC] - /// Return $name(arg) - pub fn $FUNC($($arg: datafusion_expr::Expr),*) -> datafusion_expr::Expr { - super::$FUNC().call(vec![$($arg),*],) - } - )* + ($(($FUNC:ident, $DOC:expr, $($arg:tt)*)),*) => { + $( + // switch to single-function cases below + export_functions!(single $FUNC, $DOC, $($arg)*); + )* + }; + + // single vector argument (a single argument followed by a comma) + (single $FUNC:ident, $DOC:expr, $arg:ident,) => { + #[doc = $DOC] + pub fn $FUNC($arg: Vec) -> datafusion_expr::Expr { + super::$FUNC().call($arg) } + }; - /// Return a list of all functions in this package - pub fn functions() -> Vec> { - vec![ - $( - $FUNC(), - )* - ] + // variadic arguments (zero or more arguments, without commas) + (single $FUNC:ident, $DOC:expr, $($arg:ident)*) => { + #[doc = $DOC] + pub fn $FUNC($($arg: datafusion_expr::Expr),*) -> datafusion_expr::Expr { + super::$FUNC().call(vec![$($arg),*]) } }; } /// Creates a singleton `ScalarUDF` of the `$UDF` function named `$GNAME` and a -/// function named `$NAME` which returns that function named $NAME. +/// function named `$NAME` which returns that singleton. /// /// This is used to ensure creating the list of `ScalarUDF` only happens once. macro_rules! make_udf_function { @@ -69,9 +75,8 @@ macro_rules! make_udf_function { static $GNAME: std::sync::OnceLock> = std::sync::OnceLock::new(); - /// Return a [`ScalarUDF`] for [`$UDF`] - /// - /// [`ScalarUDF`]: datafusion_expr::ScalarUDF + #[doc = "Return a [`ScalarUDF`](datafusion_expr::ScalarUDF) implementation "] + #[doc = stringify!($UDF)] pub fn $NAME() -> std::sync::Arc { $GNAME .get_or_init(|| { @@ -89,7 +94,6 @@ macro_rules! make_udf_function { /// The rationale for providing stub functions is to help users to configure datafusion /// properly (so they get an error telling them why a function is not available) /// instead of getting a cryptic "no function found" message at runtime. - macro_rules! make_stub_package { ($name:ident, $feature:literal) => { #[cfg(not(feature = $feature))] @@ -108,27 +112,6 @@ macro_rules! make_stub_package { }; } -/// Invokes a function on each element of an array and returns the result as a new array -/// -/// $ARG: ArrayRef -/// $NAME: name of the function (for error messages) -/// $ARGS_TYPE: the type of array to cast the argument to -/// $RETURN_TYPE: the type of array to return -/// $FUNC: the function to apply to each element of $ARG -/// -macro_rules! make_function_scalar_inputs_return_type { - ($ARG: expr, $NAME:expr, $ARG_TYPE:ident, $RETURN_TYPE:ident, $FUNC: block) => {{ - let arg = downcast_arg!($ARG, $NAME, $ARG_TYPE); - - arg.iter() - .map(|a| match a { - Some(a) => Some($FUNC(a)), - _ => None, - }) - .collect::<$RETURN_TYPE>() - }}; -} - /// Downcast an argument to a specific array type, returning an internal error /// if the cast fails /// @@ -156,21 +139,24 @@ macro_rules! downcast_arg { /// $GNAME: a singleton instance of the UDF /// $NAME: the name of the function /// $UNARY_FUNC: the unary function to apply to the argument -/// $MONOTONIC_FUNC: the monotonicity of the function +/// $OUTPUT_ORDERING: the output ordering calculation method of the function macro_rules! make_math_unary_udf { - ($UDF:ident, $GNAME:ident, $NAME:ident, $UNARY_FUNC:ident, $MONOTONICITY:expr) => { + ($UDF:ident, $GNAME:ident, $NAME:ident, $UNARY_FUNC:ident, $OUTPUT_ORDERING:expr, $EVALUATE_BOUNDS:expr, $GET_DOC:expr) => { make_udf_function!($NAME::$UDF, $GNAME, $NAME); mod $NAME { - use arrow::array::{ArrayRef, Float32Array, Float64Array}; - use arrow::datatypes::DataType; - use datafusion_common::{exec_err, DataFusionError, Result}; - use datafusion_expr::{ - ColumnarValue, FuncMonotonicity, ScalarUDFImpl, Signature, Volatility, - }; use std::any::Any; use std::sync::Arc; + use arrow::array::{ArrayRef, AsArray}; + use arrow::datatypes::{DataType, Float32Type, Float64Type}; + use datafusion_common::{exec_err, Result}; + use datafusion_expr::interval_arithmetic::Interval; + use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; + use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + }; + #[derive(Debug)] pub struct $UDF { signature: Signature, @@ -211,32 +197,30 @@ macro_rules! make_math_unary_udf { } } - fn monotonicity(&self) -> Result> { - Ok($MONOTONICITY) + fn output_ordering( + &self, + input: &[ExprProperties], + ) -> Result { + $OUTPUT_ORDERING(input) + } + + fn evaluate_bounds(&self, inputs: &[&Interval]) -> Result { + $EVALUATE_BOUNDS(inputs) } fn invoke(&self, args: &[ColumnarValue]) -> Result { let args = ColumnarValue::values_to_arrays(args)?; - let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => { - Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - self.name(), - Float64Array, - Float64Array, - { f64::$UNARY_FUNC } - )) - } - DataType::Float32 => { - Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - self.name(), - Float32Array, - Float32Array, - { f32::$UNARY_FUNC } - )) - } + DataType::Float64 => Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float64Type>(|x: f64| f64::$UNARY_FUNC(x)), + ) as ArrayRef, + DataType::Float32 => Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float32Type>(|x: f32| f32::$UNARY_FUNC(x)), + ) as ArrayRef, other => { return exec_err!( "Unsupported data type {other:?} for function {}", @@ -244,8 +228,13 @@ macro_rules! make_math_unary_udf { ) } }; + Ok(ColumnarValue::Array(arr)) } + + fn documentation(&self) -> Option<&Documentation> { + Some($GET_DOC()) + } } } }; @@ -260,22 +249,24 @@ macro_rules! make_math_unary_udf { /// $GNAME: a singleton instance of the UDF /// $NAME: the name of the function /// $BINARY_FUNC: the binary function to apply to the argument -/// $MONOTONIC_FUNC: the monotonicity of the function +/// $OUTPUT_ORDERING: the output ordering calculation method of the function macro_rules! make_math_binary_udf { - ($UDF:ident, $GNAME:ident, $NAME:ident, $BINARY_FUNC:ident, $MONOTONICITY:expr) => { + ($UDF:ident, $GNAME:ident, $NAME:ident, $BINARY_FUNC:ident, $OUTPUT_ORDERING:expr, $GET_DOC:expr) => { make_udf_function!($NAME::$UDF, $GNAME, $NAME); mod $NAME { - use arrow::array::{ArrayRef, Float32Array, Float64Array}; - use arrow::datatypes::DataType; - use datafusion_common::{exec_err, DataFusionError, Result}; - use datafusion_expr::TypeSignature::*; - use datafusion_expr::{ - ColumnarValue, FuncMonotonicity, ScalarUDFImpl, Signature, Volatility, - }; use std::any::Any; use std::sync::Arc; + use arrow::array::{ArrayRef, AsArray}; + use arrow::datatypes::{DataType, Float32Type, Float64Type}; + use datafusion_common::{exec_err, Result}; + use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; + use datafusion_expr::TypeSignature; + use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + }; + #[derive(Debug)] pub struct $UDF { signature: Signature, @@ -287,8 +278,8 @@ macro_rules! make_math_binary_udf { Self { signature: Signature::one_of( vec![ - Exact(vec![Float32, Float32]), - Exact(vec![Float64, Float64]), + TypeSignature::Exact(vec![Float32, Float32]), + TypeSignature::Exact(vec![Float64, Float64]), ], Volatility::Immutable, ), @@ -318,31 +309,36 @@ macro_rules! make_math_binary_udf { } } - fn monotonicity(&self) -> Result> { - Ok($MONOTONICITY) + fn output_ordering( + &self, + input: &[ExprProperties], + ) -> Result { + $OUTPUT_ORDERING(input) } fn invoke(&self, args: &[ColumnarValue]) -> Result { let args = ColumnarValue::values_to_arrays(args)?; - let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "y", - "x", - Float64Array, - { f64::$BINARY_FUNC } - )), - - DataType::Float32 => Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "y", - "x", - Float32Array, - { f32::$BINARY_FUNC } - )), + DataType::Float64 => { + let y = args[0].as_primitive::(); + let x = args[1].as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float64Type>( + y, + x, + |y, x| f64::$BINARY_FUNC(y, x), + )?; + Arc::new(result) as _ + } + DataType::Float32 => { + let y = args[0].as_primitive::(); + let x = args[1].as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float32Type>( + y, + x, + |y, x| f32::$BINARY_FUNC(y, x), + )?; + Arc::new(result) as _ + } other => { return exec_err!( "Unsupported data type {other:?} for function {}", @@ -350,49 +346,14 @@ macro_rules! make_math_binary_udf { ) } }; + Ok(ColumnarValue::Array(arr)) } + + fn documentation(&self) -> Option<&Documentation> { + Some($GET_DOC()) + } } } }; } - -macro_rules! make_function_scalar_inputs { - ($ARG: expr, $NAME:expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ - let arg = downcast_arg!($ARG, $NAME, $ARRAY_TYPE); - - arg.iter() - .map(|a| match a { - Some(a) => Some($FUNC(a)), - _ => None, - }) - .collect::<$ARRAY_TYPE>() - }}; -} - -macro_rules! make_function_inputs2 { - ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ - let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE); - let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE); - - arg1.iter() - .zip(arg2.iter()) - .map(|(a1, a2)| match (a1, a2) { - (Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)), - _ => None, - }) - .collect::<$ARRAY_TYPE>() - }}; - ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE1:ident, $ARRAY_TYPE2:ident, $FUNC: block) => {{ - let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE1); - let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE2); - - arg1.iter() - .zip(arg2.iter()) - .map(|(a1, a2)| match (a1, a2) { - (Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)), - _ => None, - }) - .collect::<$ARRAY_TYPE1>() - }}; -} diff --git a/datafusion/functions/src/math/abs.rs b/datafusion/functions/src/math/abs.rs index e05dc8665285..5511a57d8566 100644 --- a/datafusion/functions/src/math/abs.rs +++ b/datafusion/functions/src/math/abs.rs @@ -17,22 +17,22 @@ //! math expressions -use arrow::array::Decimal128Array; -use arrow::array::Decimal256Array; -use arrow::array::Int16Array; -use arrow::array::Int32Array; -use arrow::array::Int64Array; -use arrow::array::Int8Array; -use arrow::datatypes::DataType; -use datafusion_common::{exec_err, not_impl_err}; -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::ColumnarValue; +use std::any::Any; +use std::sync::{Arc, OnceLock}; -use arrow::array::{ArrayRef, Float32Array, Float64Array}; +use arrow::array::{ + ArrayRef, Decimal128Array, Decimal256Array, Float32Array, Float64Array, Int16Array, + Int32Array, Int64Array, Int8Array, +}; +use arrow::datatypes::DataType; use arrow::error::ArrowError; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use std::any::Any; -use std::sync::Arc; +use datafusion_common::{exec_err, not_impl_err, DataFusionError, Result}; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; type MathArrayFunction = fn(&Vec) -> Result; @@ -94,7 +94,7 @@ fn create_abs_function(input_data_type: &DataType) -> Result | DataType::UInt8 | DataType::UInt16 | DataType::UInt32 - | DataType::UInt64 => Ok(|args: &Vec| Ok(args[0].clone())), + | DataType::UInt64 => Ok(|args: &Vec| Ok(Arc::clone(&args[0]))), // Decimal types DataType::Decimal128(_, _) => Ok(make_decimal_abs_function!(Decimal128Array)), @@ -170,7 +170,39 @@ impl ScalarUDFImpl for AbsFunc { let input_data_type = args[0].data_type(); let abs_fun = create_abs_function(input_data_type)?; - let arr = abs_fun(&args)?; - Ok(ColumnarValue::Array(arr)) + abs_fun(&args).map(ColumnarValue::Array) } + + fn output_ordering(&self, input: &[ExprProperties]) -> Result { + // Non-decreasing for x ≥ 0 and symmetrically non-increasing for x ≤ 0. + let arg = &input[0]; + let range = &arg.range; + let zero_point = Interval::make_zero(&range.lower().data_type())?; + + if range.gt_eq(&zero_point)? == Interval::CERTAINLY_TRUE { + Ok(arg.sort_properties) + } else if range.lt_eq(&zero_point)? == Interval::CERTAINLY_TRUE { + Ok(-arg.sort_properties) + } else { + Ok(SortProperties::Unordered) + } + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_abs_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_abs_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the absolute value of a number.") + .with_syntax_example("abs(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/math/bounds.rs b/datafusion/functions/src/math/bounds.rs new file mode 100644 index 000000000000..894d2bded5eb --- /dev/null +++ b/datafusion/functions/src/math/bounds.rs @@ -0,0 +1,108 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::ScalarValue; +use datafusion_expr::interval_arithmetic::Interval; + +pub(super) fn unbounded_bounds(input: &[&Interval]) -> crate::Result { + let data_type = input[0].data_type(); + + Interval::make_unbounded(&data_type) +} + +pub(super) fn sin_bounds(input: &[&Interval]) -> crate::Result { + // sin(x) is bounded by [-1, 1] + let data_type = input[0].data_type(); + + Interval::make_symmetric_unit_interval(&data_type) +} + +pub(super) fn asin_bounds(input: &[&Interval]) -> crate::Result { + // asin(x) is bounded by [-π/2, π/2] + let data_type = input[0].data_type(); + + Interval::make_symmetric_half_pi_interval(&data_type) +} + +pub(super) fn atan_bounds(input: &[&Interval]) -> crate::Result { + // atan(x) is bounded by [-π/2, π/2] + let data_type = input[0].data_type(); + + Interval::make_symmetric_half_pi_interval(&data_type) +} + +pub(super) fn acos_bounds(input: &[&Interval]) -> crate::Result { + // acos(x) is bounded by [0, π] + let data_type = input[0].data_type(); + + Interval::try_new( + ScalarValue::new_zero(&data_type)?, + ScalarValue::new_pi_upper(&data_type)?, + ) +} + +pub(super) fn acosh_bounds(input: &[&Interval]) -> crate::Result { + // acosh(x) is bounded by [0, ∞) + let data_type = input[0].data_type(); + + Interval::make_non_negative_infinity_interval(&data_type) +} + +pub(super) fn cos_bounds(input: &[&Interval]) -> crate::Result { + // cos(x) is bounded by [-1, 1] + let data_type = input[0].data_type(); + + Interval::make_symmetric_unit_interval(&data_type) +} + +pub(super) fn cosh_bounds(input: &[&Interval]) -> crate::Result { + // cosh(x) is bounded by [1, ∞) + let data_type = input[0].data_type(); + + Interval::try_new( + ScalarValue::new_one(&data_type)?, + ScalarValue::try_from(&data_type)?, + ) +} + +pub(super) fn exp_bounds(input: &[&Interval]) -> crate::Result { + // exp(x) is bounded by [0, ∞) + let data_type = input[0].data_type(); + + Interval::make_non_negative_infinity_interval(&data_type) +} + +pub(super) fn radians_bounds(input: &[&Interval]) -> crate::Result { + // radians(x) is bounded by (-π, π) + let data_type = input[0].data_type(); + + Interval::make_symmetric_pi_interval(&data_type) +} + +pub(super) fn sqrt_bounds(input: &[&Interval]) -> crate::Result { + // sqrt(x) is bounded by [0, ∞) + let data_type = input[0].data_type(); + + Interval::make_non_negative_infinity_interval(&data_type) +} + +pub(super) fn tanh_bounds(input: &[&Interval]) -> crate::Result { + // tanh(x) is bounded by (-1, 1) + let data_type = input[0].data_type(); + + Interval::make_symmetric_unit_interval(&data_type) +} diff --git a/datafusion/functions/src/math/cot.rs b/datafusion/functions/src/math/cot.rs index 66219960d9a2..eded50a20d8d 100644 --- a/datafusion/functions/src/math/cot.rs +++ b/datafusion/functions/src/math/cot.rs @@ -16,17 +16,17 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; -use arrow::array::{ArrayRef, Float32Array, Float64Array}; -use arrow::datatypes::DataType; +use arrow::array::{ArrayRef, AsArray}; use arrow::datatypes::DataType::{Float32, Float64}; - -use datafusion_common::{exec_err, DataFusionError, Result}; -use datafusion_expr::ColumnarValue; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use arrow::datatypes::{DataType, Float32Type, Float64Type}; use crate::utils::make_scalar_function; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; +use datafusion_expr::{ColumnarValue, Documentation}; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; #[derive(Debug)] pub struct CotFunc { @@ -39,6 +39,20 @@ impl Default for CotFunc { } } +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_cot_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the cotangent of a number.") + .with_syntax_example(r#"cot(numeric_expression)"#) + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + impl CotFunc { pub fn new() -> Self { use DataType::*; @@ -77,6 +91,10 @@ impl ScalarUDFImpl for CotFunc { } } + fn documentation(&self) -> Option<&Documentation> { + Some(get_cot_doc()) + } + fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(cot, vec![])(args) } @@ -85,18 +103,16 @@ impl ScalarUDFImpl for CotFunc { ///cot SQL function fn cot(args: &[ArrayRef]) -> Result { match args[0].data_type() { - Float64 => Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "x", - Float64Array, - { compute_cot64 } - )) as ArrayRef), - Float32 => Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "x", - Float32Array, - { compute_cot32 } - )) as ArrayRef), + Float64 => Ok(Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float64Type>(|x: f64| compute_cot64(x)), + ) as ArrayRef), + Float32 => Ok(Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float32Type>(|x: f32| compute_cot32(x)), + ) as ArrayRef), other => exec_err!("Unsupported data type {other:?} for function cot"), } } diff --git a/datafusion/functions/src/math/factorial.rs b/datafusion/functions/src/math/factorial.rs index dc481da79069..bacdf47524f4 100644 --- a/datafusion/functions/src/math/factorial.rs +++ b/datafusion/functions/src/math/factorial.rs @@ -15,16 +15,22 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{ArrayRef, Int64Array}; +use arrow::{ + array::{ArrayRef, Int64Array}, + error::ArrowError, +}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Int64; use crate::utils::make_scalar_function; -use datafusion_common::{exec_err, DataFusionError, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct FactorialFunc { @@ -65,30 +71,47 @@ impl ScalarUDFImpl for FactorialFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(factorial, vec![])(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_factorial_doc()) + } } -macro_rules! make_function_scalar_inputs { - ($ARG: expr, $NAME:expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ - let arg = downcast_arg!($ARG, $NAME, $ARRAY_TYPE); - - arg.iter() - .map(|a| match a { - Some(a) => Some($FUNC(a)), - _ => None, - }) - .collect::<$ARRAY_TYPE>() - }}; +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_factorial_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Factorial. Returns 1 if value is less than 2.") + .with_syntax_example("factorial(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) } /// Factorial SQL function fn factorial(args: &[ArrayRef]) -> Result { match args[0].data_type() { - DataType::Int64 => Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "value", - Int64Array, - { |value: i64| { (1..=value).product() } } - )) as ArrayRef), + Int64 => { + let arg = downcast_arg!((&args[0]), "value", Int64Array); + Ok(arg + .iter() + .map(|a| match a { + Some(a) => (2..=a) + .try_fold(1i64, i64::checked_mul) + .ok_or_else(|| { + arrow_datafusion_err!(ArrowError::ComputeError(format!( + "Overflow happened on FACTORIAL({a})" + ))) + }) + .map(Some), + _ => Ok(None), + }) + .collect::>() + .map(Arc::new)? as ArrayRef) + } other => exec_err!("Unsupported data type {other:?} for function factorial."), } } diff --git a/datafusion/functions/src/math/gcd.rs b/datafusion/functions/src/math/gcd.rs index 41c9e4e23314..f4edef3acca3 100644 --- a/datafusion/functions/src/math/gcd.rs +++ b/datafusion/functions/src/math/gcd.rs @@ -16,17 +16,20 @@ // under the License. use arrow::array::{ArrayRef, Int64Array}; +use arrow::error::ArrowError; use std::any::Any; use std::mem::swap; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Int64; use crate::utils::make_scalar_function; -use datafusion_common::{exec_err, DataFusionError, Result}; -use datafusion_expr::ColumnarValue; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct GcdFunc { @@ -68,29 +71,52 @@ impl ScalarUDFImpl for GcdFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(gcd, vec![])(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_gcd_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_gcd_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description( + "Returns the greatest common divisor of `expression_x` and `expression_y`. Returns 0 if both inputs are zero.", + ) + .with_syntax_example("gcd(expression_x, expression_y)") + .with_standard_argument("expression_x", Some("First numeric")) + .with_standard_argument("expression_y", Some("Second numeric")) + .build() + .unwrap() + }) } /// Gcd SQL function fn gcd(args: &[ArrayRef]) -> Result { match args[0].data_type() { - Int64 => Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "x", - "y", - Int64Array, - Int64Array, - { compute_gcd } - )) as ArrayRef), + Int64 => { + let arg1 = downcast_arg!(&args[0], "x", Int64Array); + let arg2 = downcast_arg!(&args[1], "y", Int64Array); + + Ok(arg1 + .iter() + .zip(arg2.iter()) + .map(|(a1, a2)| match (a1, a2) { + (Some(a1), Some(a2)) => Ok(Some(compute_gcd(a1, a2)?)), + _ => Ok(None), + }) + .collect::>() + .map(Arc::new)? as ArrayRef) + } other => exec_err!("Unsupported data type {other:?} for function gcd"), } } -/// Computes greatest common divisor using Binary GCD algorithm. -pub fn compute_gcd(x: i64, y: i64) -> i64 { - let mut a = x.wrapping_abs(); - let mut b = y.wrapping_abs(); - +/// Computes gcd of two unsigned integers using Binary GCD algorithm. +pub(super) fn unsigned_gcd(mut a: u64, mut b: u64) -> u64 { if a == 0 { return b; } @@ -99,32 +125,43 @@ pub fn compute_gcd(x: i64, y: i64) -> i64 { } let shift = (a | b).trailing_zeros(); - a >>= shift; - b >>= shift; a >>= a.trailing_zeros(); - loop { b >>= b.trailing_zeros(); if a > b { swap(&mut a, &mut b); } - b -= a; - if b == 0 { return a << shift; } } } +/// Computes greatest common divisor using Binary GCD algorithm. +pub fn compute_gcd(x: i64, y: i64) -> Result { + let a = x.unsigned_abs(); + let b = y.unsigned_abs(); + let r = unsigned_gcd(a, b); + // gcd(i64::MIN, i64::MIN) = i64::MIN.unsigned_abs() cannot fit into i64 + r.try_into().map_err(|_| { + arrow_datafusion_err!(ArrowError::ComputeError(format!( + "Signed integer overflow in GCD({x}, {y})" + ))) + }) +} + #[cfg(test)] mod test { use std::sync::Arc; - use arrow::array::{ArrayRef, Int64Array}; + use arrow::{ + array::{ArrayRef, Int64Array}, + error::ArrowError, + }; use crate::math::gcd::gcd; - use datafusion_common::cast::as_int64_array; + use datafusion_common::{cast::as_int64_array, DataFusionError}; #[test] fn test_gcd_i64() { @@ -142,4 +179,21 @@ mod test { assert_eq!(ints.value(2), 5); assert_eq!(ints.value(3), 8); } + + #[test] + fn overflow_on_both_param_i64_min() { + let args: Vec = vec![ + Arc::new(Int64Array::from(vec![i64::MIN])), // x + Arc::new(Int64Array::from(vec![i64::MIN])), // y + ]; + + match gcd(&args) { + // we expect a overflow + Err(DataFusionError::ArrowError(ArrowError::ComputeError(_), _)) => {} + Err(_) => { + panic!("failed to initialize function gcd") + } + Ok(_) => panic!("GCD({0}, {0}) should have overflown", i64::MIN), + }; + } } diff --git a/datafusion/functions/src/math/iszero.rs b/datafusion/functions/src/math/iszero.rs index e6a728053359..7e5d4fe77ffa 100644 --- a/datafusion/functions/src/math/iszero.rs +++ b/datafusion/functions/src/math/iszero.rs @@ -16,16 +16,18 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; -use arrow::array::{ArrayRef, BooleanArray, Float32Array, Float64Array}; -use arrow::datatypes::DataType; +use arrow::array::{ArrayRef, AsArray, BooleanArray}; use arrow::datatypes::DataType::{Boolean, Float32, Float64}; +use arrow::datatypes::{DataType, Float32Type, Float64Type}; -use datafusion_common::{exec_err, DataFusionError, Result}; -use datafusion_expr::ColumnarValue; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use crate::utils::make_scalar_function; @@ -72,25 +74,39 @@ impl ScalarUDFImpl for IsZeroFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(iszero, vec![])(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_iszero_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_iszero_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description( + "Returns true if a given number is +0.0 or -0.0 otherwise returns false.", + ) + .with_syntax_example("iszero(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) } /// Iszero SQL function pub fn iszero(args: &[ArrayRef]) -> Result { match args[0].data_type() { - Float64 => Ok(Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - "x", - Float64Array, - BooleanArray, - { |x: f64| { x == 0_f64 } } + Float64 => Ok(Arc::new(BooleanArray::from_unary( + args[0].as_primitive::(), + |x| x == 0.0, )) as ArrayRef), - Float32 => Ok(Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - "x", - Float32Array, - BooleanArray, - { |x: f32| { x == 0_f32 } } + Float32 => Ok(Arc::new(BooleanArray::from_unary( + args[0].as_primitive::(), + |x| x == 0.0, )) as ArrayRef), other => exec_err!("Unsupported data type {other:?} for function iszero"), diff --git a/datafusion/functions/src/math/lcm.rs b/datafusion/functions/src/math/lcm.rs index 3674f7371de2..64b07ce606f2 100644 --- a/datafusion/functions/src/math/lcm.rs +++ b/datafusion/functions/src/math/lcm.rs @@ -16,17 +16,20 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::{ArrayRef, Int64Array}; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Int64; -use datafusion_common::{exec_err, DataFusionError, Result}; -use datafusion_expr::ColumnarValue; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use arrow::error::ArrowError; +use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; -use crate::math::gcd::compute_gcd; +use super::gcd::unsigned_gcd; use crate::utils::make_scalar_function; #[derive(Debug)] @@ -69,30 +72,66 @@ impl ScalarUDFImpl for LcmFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(lcm, vec![])(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_lcm_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_lcm_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description( + "Returns the least common multiple of `expression_x` and `expression_y`. Returns 0 if either input is zero.", + ) + .with_syntax_example("lcm(expression_x, expression_y)") + .with_standard_argument("expression_x", Some("First numeric")) + .with_standard_argument("expression_y", Some("Second numeric")) + .build() + .unwrap() + }) } /// Lcm SQL function fn lcm(args: &[ArrayRef]) -> Result { let compute_lcm = |x: i64, y: i64| { - let a = x.wrapping_abs(); - let b = y.wrapping_abs(); - - if a == 0 || b == 0 { - return 0; + if x == 0 || y == 0 { + return Ok(0); } - a / compute_gcd(a, b) * b + + // lcm(x, y) = |x| * |y| / gcd(|x|, |y|) + let a = x.unsigned_abs(); + let b = y.unsigned_abs(); + let gcd = unsigned_gcd(a, b); + // gcd is not zero since both a and b are not zero, so the division is safe. + (a / gcd) + .checked_mul(b) + .and_then(|v| i64::try_from(v).ok()) + .ok_or_else(|| { + arrow_datafusion_err!(ArrowError::ComputeError(format!( + "Signed integer overflow in LCM({x}, {y})" + ))) + }) }; match args[0].data_type() { - Int64 => Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "x", - "y", - Int64Array, - Int64Array, - { compute_lcm } - )) as ArrayRef), + Int64 => { + let arg1 = downcast_arg!(&args[0], "x", Int64Array); + let arg2 = downcast_arg!(&args[1], "y", Int64Array); + + Ok(arg1 + .iter() + .zip(arg2.iter()) + .map(|(a1, a2)| match (a1, a2) { + (Some(a1), Some(a2)) => Ok(Some(compute_lcm(a1, a2)?)), + _ => Ok(None), + }) + .collect::>() + .map(Arc::new)? as ArrayRef) + } other => exec_err!("Unsupported data type {other:?} for function lcm"), } } diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index 0a29e1ecfc12..9110f9f532d8 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -17,24 +17,24 @@ //! Math function: `log()`. -use arrow::datatypes::DataType; +use std::any::Any; +use std::sync::{Arc, OnceLock}; + +use super::power::PowerFunc; + +use arrow::array::{ArrayRef, AsArray}; +use arrow::datatypes::{DataType, Float32Type, Float64Type}; use datafusion_common::{ - exec_err, internal_err, plan_datafusion_err, plan_err, DataFusionError, Result, - ScalarValue, + exec_err, internal_err, plan_datafusion_err, plan_err, Result, ScalarValue, }; use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::{ - lit, ColumnarValue, Expr, FuncMonotonicity, ScalarFunctionDefinition, + lit, ColumnarValue, Documentation, Expr, ScalarUDF, TypeSignature::*, }; - -use arrow::array::{ArrayRef, Float32Array, Float64Array}; -use datafusion_expr::TypeSignature::*; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use std::any::Any; -use std::sync::Arc; - -use super::power::PowerFunc; #[derive(Debug)] pub struct LogFunc { @@ -47,6 +47,22 @@ impl Default for LogFunc { } } +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_log_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the base-x logarithm of a number. Can either provide a specified base, or if omitted then takes the base-10 of a number.") + .with_syntax_example(r#"log(base, numeric_expression) +log(numeric_expression)"#) + .with_standard_argument("base", Some("Base numeric")) + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + impl LogFunc { pub fn new() -> Self { use DataType::*; @@ -83,8 +99,29 @@ impl ScalarUDFImpl for LogFunc { } } - fn monotonicity(&self) -> Result> { - Ok(Some(vec![Some(true), Some(false)])) + fn output_ordering(&self, input: &[ExprProperties]) -> Result { + let (base_sort_properties, num_sort_properties) = if input.len() == 1 { + // log(x) defaults to log(10, x) + (SortProperties::Singleton, input[0].sort_properties) + } else { + (input[0].sort_properties, input[1].sort_properties) + }; + match (num_sort_properties, base_sort_properties) { + (first @ SortProperties::Ordered(num), SortProperties::Ordered(base)) + if num.descending != base.descending + && num.nulls_first == base.nulls_first => + { + Ok(first) + } + ( + first @ (SortProperties::Ordered(_) | SortProperties::Singleton), + SortProperties::Singleton, + ) => Ok(first), + (SortProperties::Singleton, second @ SortProperties::Ordered(_)) => { + Ok(-second) + } + _ => Ok(SortProperties::Unordered), + } } // Support overloaded log(base, x) and log(x) which defaults to log(10, x) @@ -96,43 +133,46 @@ impl ScalarUDFImpl for LogFunc { let mut x = &args[0]; if args.len() == 2 { x = &args[1]; - base = ColumnarValue::Array(args[0].clone()); + base = ColumnarValue::Array(Arc::clone(&args[0])); } // note in f64::log params order is different than in sql. e.g in sql log(base, x) == f64::log(x, base) let arr: ArrayRef = match args[0].data_type() { DataType::Float64 => match base { ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => { - Arc::new(make_function_scalar_inputs!(x, "x", Float64Array, { - |value: f64| f64::log(value, base as f64) - })) + Arc::new(x.as_primitive::().unary::<_, Float64Type>( + |value: f64| f64::log(value, base as f64), + )) + } + ColumnarValue::Array(base) => { + let x = x.as_primitive::(); + let base = base.as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float64Type>( + x, + base, + f64::log, + )?; + Arc::new(result) as _ } - ColumnarValue::Array(base) => Arc::new(make_function_inputs2!( - x, - base, - "x", - "base", - Float64Array, - { f64::log } - )), _ => { return exec_err!("log function requires a scalar or array for base") } }, DataType::Float32 => match base { - ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => { - Arc::new(make_function_scalar_inputs!(x, "x", Float32Array, { - |value: f32| f32::log(value, base) - })) + ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => Arc::new( + x.as_primitive::() + .unary::<_, Float32Type>(|value: f32| f32::log(value, base)), + ), + ColumnarValue::Array(base) => { + let x = x.as_primitive::(); + let base = base.as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float32Type>( + x, + base, + f32::log, + )?; + Arc::new(result) as _ } - ColumnarValue::Array(base) => Arc::new(make_function_inputs2!( - x, - base, - "x", - "base", - Float32Array, - { f32::log } - )), _ => { return exec_err!("log function requires a scalar or array for base") } @@ -145,6 +185,10 @@ impl ScalarUDFImpl for LogFunc { Ok(ColumnarValue::Array(arr)) } + fn documentation(&self) -> Option<&Documentation> { + Some(get_log_doc()) + } + /// Simplify the `log` function by the relevant rules: /// 1. Log(a, 1) ===> 0 /// 2. Log(a, Power(a, b)) ===> b @@ -178,8 +222,8 @@ impl ScalarUDFImpl for LogFunc { &info.get_data_type(&base)?, )?))) } - Expr::ScalarFunction(ScalarFunction { func_def, mut args }) - if is_pow(&func_def) && args.len() == 2 && base == args[0] => + Expr::ScalarFunction(ScalarFunction { func, mut args }) + if is_pow(&func) && args.len() == 2 && base == args[0] => { let b = args.pop().unwrap(); // length checked above Ok(ExprSimplifyResult::Simplified(b)) @@ -192,7 +236,7 @@ impl ScalarUDFImpl for LogFunc { } else { let args = match num_args { 1 => vec![number], - 2 => vec![number, base], + 2 => vec![base, number], _ => { return internal_err!( "Unexpected number of arguments in log::simplify" @@ -207,24 +251,202 @@ impl ScalarUDFImpl for LogFunc { } /// Returns true if the function is `PowerFunc` -fn is_pow(func_def: &ScalarFunctionDefinition) -> bool { - if let ScalarFunctionDefinition::UDF(fun) = func_def { - fun.as_ref() - .inner() - .as_any() - .downcast_ref::() - .is_some() - } else { - false - } +fn is_pow(func: &ScalarUDF) -> bool { + func.inner().as_any().downcast_ref::().is_some() } #[cfg(test)] mod tests { - use datafusion_common::cast::{as_float32_array, as_float64_array}; + use std::collections::HashMap; use super::*; + use arrow::array::{Float32Array, Float64Array, Int64Array}; + use arrow::compute::SortOptions; + use datafusion_common::cast::{as_float32_array, as_float64_array}; + use datafusion_common::DFSchema; + use datafusion_expr::execution_props::ExecutionProps; + use datafusion_expr::simplify::SimplifyContext; + + #[test] + #[should_panic] + fn test_log_invalid_base_type() { + let args = [ + ColumnarValue::Array(Arc::new(Float64Array::from(vec![ + 10.0, 100.0, 1000.0, 10000.0, + ]))), // num + ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))), + ]; + + let _ = LogFunc::new().invoke_batch(&args, 4); + } + + #[test] + fn test_log_invalid_value() { + let args = [ + ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num + ]; + + let result = LogFunc::new().invoke_batch(&args, 1); + result.expect_err("expected error"); + } + + #[test] + fn test_log_scalar_f32_unary() { + let args = [ + ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))), // num + ]; + + let result = LogFunc::new() + .invoke_batch(&args, 1) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float32_array(&arr) + .expect("failed to convert result to a Float32Array"); + + assert_eq!(floats.len(), 1); + assert!((floats.value(0) - 1.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_log_scalar_f64_unary() { + let args = [ + ColumnarValue::Scalar(ScalarValue::Float64(Some(10.0))), // num + ]; + + let result = LogFunc::new() + .invoke_batch(&args, 1) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + + assert_eq!(floats.len(), 1); + assert!((floats.value(0) - 1.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_log_scalar_f32() { + let args = [ + ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // num + ColumnarValue::Scalar(ScalarValue::Float32(Some(32.0))), // num + ]; + + let result = LogFunc::new() + .invoke_batch(&args, 1) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float32_array(&arr) + .expect("failed to convert result to a Float32Array"); + + assert_eq!(floats.len(), 1); + assert!((floats.value(0) - 5.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_log_scalar_f64() { + let args = [ + ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // num + ColumnarValue::Scalar(ScalarValue::Float64(Some(64.0))), // num + ]; + + let result = LogFunc::new() + .invoke_batch(&args, 1) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + + assert_eq!(floats.len(), 1); + assert!((floats.value(0) - 6.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_log_f64_unary() { + let args = [ + ColumnarValue::Array(Arc::new(Float64Array::from(vec![ + 10.0, 100.0, 1000.0, 10000.0, + ]))), // num + ]; + + let result = LogFunc::new() + .invoke_batch(&args, 4) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float64Array"); + + assert_eq!(floats.len(), 4); + assert!((floats.value(0) - 1.0).abs() < 1e-10); + assert!((floats.value(1) - 2.0).abs() < 1e-10); + assert!((floats.value(2) - 3.0).abs() < 1e-10); + assert!((floats.value(3) - 4.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_log_f32_unary() { + let args = [ + ColumnarValue::Array(Arc::new(Float32Array::from(vec![ + 10.0, 100.0, 1000.0, 10000.0, + ]))), // num + ]; + + let result = LogFunc::new() + .invoke_batch(&args, 4) + .expect("failed to initialize function log"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float32_array(&arr) + .expect("failed to convert result to a Float64Array"); + + assert_eq!(floats.len(), 4); + assert!((floats.value(0) - 1.0).abs() < 1e-10); + assert!((floats.value(1) - 2.0).abs() < 1e-10); + assert!((floats.value(2) - 3.0).abs() < 1e-10); + assert!((floats.value(3) - 4.0).abs() < 1e-10); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + #[test] fn test_log_f64() { let args = [ @@ -235,7 +457,7 @@ mod tests { ]; let result = LogFunc::new() - .invoke(&args) + .invoke_batch(&args, 4) .expect("failed to initialize function log"); match result { @@ -244,10 +466,10 @@ mod tests { .expect("failed to convert result to a Float64Array"); assert_eq!(floats.len(), 4); - assert_eq!(floats.value(0), 3.0); - assert_eq!(floats.value(1), 2.0); - assert_eq!(floats.value(2), 4.0); - assert_eq!(floats.value(3), 4.0); + assert!((floats.value(0) - 3.0).abs() < 1e-10); + assert!((floats.value(1) - 2.0).abs() < 1e-10); + assert!((floats.value(2) - 4.0).abs() < 1e-10); + assert!((floats.value(3) - 4.0).abs() < 1e-10); } ColumnarValue::Scalar(_) => { panic!("Expected an array value") @@ -265,7 +487,7 @@ mod tests { ]; let result = LogFunc::new() - .invoke(&args) + .invoke_batch(&args, 4) .expect("failed to initialize function log"); match result { @@ -274,14 +496,162 @@ mod tests { .expect("failed to convert result to a Float32Array"); assert_eq!(floats.len(), 4); - assert_eq!(floats.value(0), 3.0); - assert_eq!(floats.value(1), 2.0); - assert_eq!(floats.value(2), 4.0); - assert_eq!(floats.value(3), 4.0); + assert!((floats.value(0) - 3.0).abs() < f32::EPSILON); + assert!((floats.value(1) - 2.0).abs() < f32::EPSILON); + assert!((floats.value(2) - 4.0).abs() < f32::EPSILON); + assert!((floats.value(3) - 4.0).abs() < f32::EPSILON); } ColumnarValue::Scalar(_) => { panic!("Expected an array value") } } } + #[test] + // Test log() simplification errors + fn test_log_simplify_errors() { + let props = ExecutionProps::new(); + let schema = + Arc::new(DFSchema::new_with_metadata(vec![], HashMap::new()).unwrap()); + let context = SimplifyContext::new(&props).with_schema(schema); + // Expect 0 args to error + let _ = LogFunc::new().simplify(vec![], &context).unwrap_err(); + // Expect 3 args to error + let _ = LogFunc::new() + .simplify(vec![lit(1), lit(2), lit(3)], &context) + .unwrap_err(); + } + + #[test] + // Test that non-simplifiable log() expressions are unchanged after simplification + fn test_log_simplify_original() { + let props = ExecutionProps::new(); + let schema = + Arc::new(DFSchema::new_with_metadata(vec![], HashMap::new()).unwrap()); + let context = SimplifyContext::new(&props).with_schema(schema); + // One argument with no simplifications + let result = LogFunc::new().simplify(vec![lit(2)], &context).unwrap(); + let ExprSimplifyResult::Original(args) = result else { + panic!("Expected ExprSimplifyResult::Original") + }; + assert_eq!(args.len(), 1); + assert_eq!(args[0], lit(2)); + // Two arguments with no simplifications + let result = LogFunc::new() + .simplify(vec![lit(2), lit(3)], &context) + .unwrap(); + let ExprSimplifyResult::Original(args) = result else { + panic!("Expected ExprSimplifyResult::Original") + }; + assert_eq!(args.len(), 2); + assert_eq!(args[0], lit(2)); + assert_eq!(args[1], lit(3)); + } + + #[test] + fn test_log_output_ordering() { + // [Unordered, Ascending, Descending, Literal] + let orders = vec![ + ExprProperties::new_unknown(), + ExprProperties::new_unknown().with_order(SortProperties::Ordered( + SortOptions { + descending: false, + nulls_first: true, + }, + )), + ExprProperties::new_unknown().with_order(SortProperties::Ordered( + SortOptions { + descending: true, + nulls_first: true, + }, + )), + ExprProperties::new_unknown().with_order(SortProperties::Singleton), + ]; + + let log = LogFunc::new(); + + // Test log(num) + for order in orders.iter().cloned() { + let result = log.output_ordering(&[order.clone()]).unwrap(); + assert_eq!(result, order.sort_properties); + } + + // Test log(base, num), where `nulls_first` is the same + let mut results = Vec::with_capacity(orders.len() * orders.len()); + for base_order in orders.iter() { + for num_order in orders.iter().cloned() { + let result = log + .output_ordering(&[base_order.clone(), num_order]) + .unwrap(); + results.push(result); + } + } + let expected = vec![ + // base: Unordered + SortProperties::Unordered, + SortProperties::Unordered, + SortProperties::Unordered, + SortProperties::Unordered, + // base: Ascending, num: Unordered + SortProperties::Unordered, + // base: Ascending, num: Ascending + SortProperties::Unordered, + // base: Ascending, num: Descending + SortProperties::Ordered(SortOptions { + descending: true, + nulls_first: true, + }), + // base: Ascending, num: Literal + SortProperties::Ordered(SortOptions { + descending: true, + nulls_first: true, + }), + // base: Descending, num: Unordered + SortProperties::Unordered, + // base: Descending, num: Ascending + SortProperties::Ordered(SortOptions { + descending: false, + nulls_first: true, + }), + // base: Descending, num: Descending + SortProperties::Unordered, + // base: Descending, num: Literal + SortProperties::Ordered(SortOptions { + descending: false, + nulls_first: true, + }), + // base: Literal, num: Unordered + SortProperties::Unordered, + // base: Literal, num: Ascending + SortProperties::Ordered(SortOptions { + descending: false, + nulls_first: true, + }), + // base: Literal, num: Descending + SortProperties::Ordered(SortOptions { + descending: true, + nulls_first: true, + }), + // base: Literal, num: Literal + SortProperties::Singleton, + ]; + assert_eq!(results, expected); + + // Test with different `nulls_first` + let base_order = ExprProperties::new_unknown().with_order( + SortProperties::Ordered(SortOptions { + descending: true, + nulls_first: true, + }), + ); + let num_order = ExprProperties::new_unknown().with_order( + SortProperties::Ordered(SortOptions { + descending: false, + nulls_first: false, + }), + ); + assert_eq!( + log.output_ordering(&[base_order, num_order]).unwrap(), + SortProperties::Unordered + ); + } } diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs index b6e8d26b6460..1452bfdee5a0 100644 --- a/datafusion/functions/src/math/mod.rs +++ b/datafusion/functions/src/math/mod.rs @@ -17,259 +17,295 @@ //! "math" DataFusion functions +use crate::math::monotonicity::*; use datafusion_expr::ScalarUDF; use std::sync::Arc; pub mod abs; +pub mod bounds; pub mod cot; pub mod factorial; pub mod gcd; pub mod iszero; pub mod lcm; pub mod log; +pub mod monotonicity; pub mod nans; pub mod nanvl; pub mod pi; pub mod power; pub mod random; pub mod round; +pub mod signum; pub mod trunc; // Create UDFs make_udf_function!(abs::AbsFunc, ABS, abs); -make_math_unary_udf!(AcosFunc, ACOS, acos, acos, None); -make_math_unary_udf!(AcoshFunc, ACOSH, acosh, acosh, Some(vec![Some(true)])); -make_math_unary_udf!(AsinFunc, ASIN, asin, asin, None); -make_math_unary_udf!(AsinhFunc, ASINH, asinh, asinh, Some(vec![Some(true)])); -make_math_unary_udf!(AtanFunc, ATAN, atan, atan, Some(vec![Some(true)])); -make_math_unary_udf!(AtanhFunc, ATANH, atanh, atanh, Some(vec![Some(true)])); -make_math_binary_udf!(Atan2, ATAN2, atan2, atan2, Some(vec![Some(true)])); -make_math_unary_udf!(CbrtFunc, CBRT, cbrt, cbrt, None); -make_math_unary_udf!(CeilFunc, CEIL, ceil, ceil, Some(vec![Some(true)])); -make_math_unary_udf!(CosFunc, COS, cos, cos, None); -make_math_unary_udf!(CoshFunc, COSH, cosh, cosh, None); +make_math_unary_udf!( + AcosFunc, + ACOS, + acos, + acos, + super::acos_order, + super::bounds::acos_bounds, + super::get_acos_doc +); +make_math_unary_udf!( + AcoshFunc, + ACOSH, + acosh, + acosh, + super::acosh_order, + super::bounds::acosh_bounds, + super::get_acosh_doc +); +make_math_unary_udf!( + AsinFunc, + ASIN, + asin, + asin, + super::asin_order, + super::bounds::asin_bounds, + super::get_asin_doc +); +make_math_unary_udf!( + AsinhFunc, + ASINH, + asinh, + asinh, + super::asinh_order, + super::bounds::unbounded_bounds, + super::get_asinh_doc +); +make_math_unary_udf!( + AtanFunc, + ATAN, + atan, + atan, + super::atan_order, + super::bounds::atan_bounds, + super::get_atan_doc +); +make_math_unary_udf!( + AtanhFunc, + ATANH, + atanh, + atanh, + super::atanh_order, + super::bounds::unbounded_bounds, + super::get_atanh_doc +); +make_math_binary_udf!( + Atan2, + ATAN2, + atan2, + atan2, + super::atan2_order, + super::get_atan2_doc +); +make_math_unary_udf!( + CbrtFunc, + CBRT, + cbrt, + cbrt, + super::cbrt_order, + super::bounds::unbounded_bounds, + super::get_cbrt_doc +); +make_math_unary_udf!( + CeilFunc, + CEIL, + ceil, + ceil, + super::ceil_order, + super::bounds::unbounded_bounds, + super::get_ceil_doc +); +make_math_unary_udf!( + CosFunc, + COS, + cos, + cos, + super::cos_order, + super::bounds::cos_bounds, + super::get_cos_doc +); +make_math_unary_udf!( + CoshFunc, + COSH, + cosh, + cosh, + super::cosh_order, + super::bounds::cosh_bounds, + super::get_cosh_doc +); make_udf_function!(cot::CotFunc, COT, cot); -make_math_unary_udf!(DegreesFunc, DEGREES, degrees, to_degrees, None); -make_math_unary_udf!(ExpFunc, EXP, exp, exp, Some(vec![Some(true)])); +make_math_unary_udf!( + DegreesFunc, + DEGREES, + degrees, + to_degrees, + super::degrees_order, + super::bounds::unbounded_bounds, + super::get_degrees_doc +); +make_math_unary_udf!( + ExpFunc, + EXP, + exp, + exp, + super::exp_order, + super::bounds::exp_bounds, + super::get_exp_doc +); make_udf_function!(factorial::FactorialFunc, FACTORIAL, factorial); -make_math_unary_udf!(FloorFunc, FLOOR, floor, floor, Some(vec![Some(true)])); +make_math_unary_udf!( + FloorFunc, + FLOOR, + floor, + floor, + super::floor_order, + super::bounds::unbounded_bounds, + super::get_floor_doc +); make_udf_function!(log::LogFunc, LOG, log); make_udf_function!(gcd::GcdFunc, GCD, gcd); make_udf_function!(nans::IsNanFunc, ISNAN, isnan); make_udf_function!(iszero::IsZeroFunc, ISZERO, iszero); make_udf_function!(lcm::LcmFunc, LCM, lcm); -make_math_unary_udf!(LnFunc, LN, ln, ln, Some(vec![Some(true)])); -make_math_unary_udf!(Log2Func, LOG2, log2, log2, Some(vec![Some(true)])); -make_math_unary_udf!(Log10Func, LOG10, log10, log10, Some(vec![Some(true)])); +make_math_unary_udf!( + LnFunc, + LN, + ln, + ln, + super::ln_order, + super::bounds::unbounded_bounds, + super::get_ln_doc +); +make_math_unary_udf!( + Log2Func, + LOG2, + log2, + log2, + super::log2_order, + super::bounds::unbounded_bounds, + super::get_log2_doc +); +make_math_unary_udf!( + Log10Func, + LOG10, + log10, + log10, + super::log10_order, + super::bounds::unbounded_bounds, + super::get_log10_doc +); make_udf_function!(nanvl::NanvlFunc, NANVL, nanvl); make_udf_function!(pi::PiFunc, PI, pi); make_udf_function!(power::PowerFunc, POWER, power); -make_math_unary_udf!(RadiansFunc, RADIANS, radians, to_radians, None); +make_math_unary_udf!( + RadiansFunc, + RADIANS, + radians, + to_radians, + super::radians_order, + super::bounds::radians_bounds, + super::get_radians_doc +); make_udf_function!(random::RandomFunc, RANDOM, random); make_udf_function!(round::RoundFunc, ROUND, round); -make_math_unary_udf!(SignumFunc, SIGNUM, signum, signum, None); -make_math_unary_udf!(SinFunc, SIN, sin, sin, None); -make_math_unary_udf!(SinhFunc, SINH, sinh, sinh, None); -make_math_unary_udf!(SqrtFunc, SQRT, sqrt, sqrt, None); -make_math_unary_udf!(TanFunc, TAN, tan, tan, None); -make_math_unary_udf!(TanhFunc, TANH, tanh, tanh, None); +make_udf_function!(signum::SignumFunc, SIGNUM, signum); +make_math_unary_udf!( + SinFunc, + SIN, + sin, + sin, + super::sin_order, + super::bounds::sin_bounds, + super::get_sin_doc +); +make_math_unary_udf!( + SinhFunc, + SINH, + sinh, + sinh, + super::sinh_order, + super::bounds::unbounded_bounds, + super::get_sinh_doc +); +make_math_unary_udf!( + SqrtFunc, + SQRT, + sqrt, + sqrt, + super::sqrt_order, + super::bounds::sqrt_bounds, + super::get_sqrt_doc +); +make_math_unary_udf!( + TanFunc, + TAN, + tan, + tan, + super::tan_order, + super::bounds::unbounded_bounds, + super::get_tan_doc +); +make_math_unary_udf!( + TanhFunc, + TANH, + tanh, + tanh, + super::tanh_order, + super::bounds::tanh_bounds, + super::get_tanh_doc +); make_udf_function!(trunc::TruncFunc, TRUNC, trunc); pub mod expr_fn { - use datafusion_expr::Expr; - - #[doc = "returns the absolute value of a given number"] - pub fn abs(num: Expr) -> Expr { - super::abs().call(vec![num]) - } - - #[doc = "returns the arc cosine or inverse cosine of a number"] - pub fn acos(num: Expr) -> Expr { - super::acos().call(vec![num]) - } - - #[doc = "returns inverse hyperbolic cosine"] - pub fn acosh(num: Expr) -> Expr { - super::acosh().call(vec![num]) - } - - #[doc = "returns the arc sine or inverse sine of a number"] - pub fn asin(num: Expr) -> Expr { - super::asin().call(vec![num]) - } - - #[doc = "returns inverse hyperbolic sine"] - pub fn asinh(num: Expr) -> Expr { - super::asinh().call(vec![num]) - } - - #[doc = "returns inverse tangent"] - pub fn atan(num: Expr) -> Expr { - super::atan().call(vec![num]) - } - - #[doc = "returns inverse tangent of a division given in the argument"] - pub fn atan2(y: Expr, x: Expr) -> Expr { - super::atan2().call(vec![y, x]) - } - - #[doc = "returns inverse hyperbolic tangent"] - pub fn atanh(num: Expr) -> Expr { - super::atanh().call(vec![num]) - } - - #[doc = "cube root of a number"] - pub fn cbrt(num: Expr) -> Expr { - super::cbrt().call(vec![num]) - } - - #[doc = "nearest integer greater than or equal to argument"] - pub fn ceil(num: Expr) -> Expr { - super::ceil().call(vec![num]) - } - - #[doc = "cosine"] - pub fn cos(num: Expr) -> Expr { - super::cos().call(vec![num]) - } - - #[doc = "hyperbolic cosine"] - pub fn cosh(num: Expr) -> Expr { - super::cosh().call(vec![num]) - } - - #[doc = "cotangent of a number"] - pub fn cot(num: Expr) -> Expr { - super::cot().call(vec![num]) - } - - #[doc = "converts radians to degrees"] - pub fn degrees(num: Expr) -> Expr { - super::degrees().call(vec![num]) - } - - #[doc = "exponential"] - pub fn exp(num: Expr) -> Expr { - super::exp().call(vec![num]) - } - - #[doc = "factorial"] - pub fn factorial(num: Expr) -> Expr { - super::factorial().call(vec![num]) - } - - #[doc = "nearest integer less than or equal to argument"] - pub fn floor(num: Expr) -> Expr { - super::floor().call(vec![num]) - } - - #[doc = "greatest common divisor"] - pub fn gcd(x: Expr, y: Expr) -> Expr { - super::gcd().call(vec![x, y]) - } - - #[doc = "returns true if a given number is +NaN or -NaN otherwise returns false"] - pub fn isnan(num: Expr) -> Expr { - super::isnan().call(vec![num]) - } - - #[doc = "returns true if a given number is +0.0 or -0.0 otherwise returns false"] - pub fn iszero(num: Expr) -> Expr { - super::iszero().call(vec![num]) - } - - #[doc = "least common multiple"] - pub fn lcm(x: Expr, y: Expr) -> Expr { - super::lcm().call(vec![x, y]) - } - - #[doc = "natural logarithm (base e) of a number"] - pub fn ln(num: Expr) -> Expr { - super::ln().call(vec![num]) - } - - #[doc = "logarithm of a number for a particular `base`"] - pub fn log(base: Expr, num: Expr) -> Expr { - super::log().call(vec![base, num]) - } - - #[doc = "base 2 logarithm of a number"] - pub fn log2(num: Expr) -> Expr { - super::log2().call(vec![num]) - } - - #[doc = "base 10 logarithm of a number"] - pub fn log10(num: Expr) -> Expr { - super::log10().call(vec![num]) - } - - #[doc = "returns x if x is not NaN otherwise returns y"] - pub fn nanvl(x: Expr, y: Expr) -> Expr { - super::nanvl().call(vec![x, y]) - } - - #[doc = "Returns an approximate value of π"] - pub fn pi() -> Expr { - super::pi().call(vec![]) - } - - #[doc = "`base` raised to the power of `exponent`"] - pub fn power(base: Expr, exponent: Expr) -> Expr { - super::power().call(vec![base, exponent]) - } - - #[doc = "converts degrees to radians"] - pub fn radians(num: Expr) -> Expr { - super::radians().call(vec![num]) - } - - #[doc = "Returns a random value in the range 0.0 <= x < 1.0"] - pub fn random() -> Expr { - super::random().call(vec![]) - } - - #[doc = "round to nearest integer"] - pub fn round(args: Vec) -> Expr { - super::round().call(args) - } - - #[doc = "sign of the argument (-1, 0, +1)"] - pub fn signum(num: Expr) -> Expr { - super::signum().call(vec![num]) - } - - #[doc = "sine"] - pub fn sin(num: Expr) -> Expr { - super::sin().call(vec![num]) - } - - #[doc = "hyperbolic sine"] - pub fn sinh(num: Expr) -> Expr { - super::sinh().call(vec![num]) - } - - #[doc = "square root of a number"] - pub fn sqrt(num: Expr) -> Expr { - super::sqrt().call(vec![num]) - } - - #[doc = "returns the tangent of a number"] - pub fn tan(num: Expr) -> Expr { - super::tan().call(vec![num]) - } - - #[doc = "returns the hyperbolic tangent of a number"] - pub fn tanh(num: Expr) -> Expr { - super::tanh().call(vec![num]) - } - - #[doc = "truncate toward zero, with optional precision"] - pub fn trunc(args: Vec) -> Expr { - super::trunc().call(args) - } + export_functions!( + (abs, "returns the absolute value of a given number", num), + (acos, "returns the arc cosine or inverse cosine of a number", num), + (acosh, "returns inverse hyperbolic cosine", num), + (asin, "returns the arc sine or inverse sine of a number", num), + (asinh, "returns inverse hyperbolic sine", num), + (atan, "returns inverse tangent", num), + (atan2, "returns inverse tangent of a division given in the argument", y x), + (atanh, "returns inverse hyperbolic tangent", num), + (cbrt, "cube root of a number", num), + (ceil, "nearest integer greater than or equal to argument", num), + (cos, "cosine", num), + (cosh, "hyperbolic cosine", num), + (cot, "cotangent of a number", num), + (degrees, "converts radians to degrees", num), + (exp, "exponential", num), + (factorial, "factorial", num), + (floor, "nearest integer less than or equal to argument", num), + (gcd, "greatest common divisor", x y), + (isnan, "returns true if a given number is +NaN or -NaN otherwise returns false", num), + (iszero, "returns true if a given number is +0.0 or -0.0 otherwise returns false", num), + (lcm, "least common multiple", x y), + (ln, "natural logarithm (base e) of a number", num), + (log, "logarithm of a number for a particular `base`", base num), + (log2, "base 2 logarithm of a number", num), + (log10, "base 10 logarithm of a number", num), + (nanvl, "returns x if x is not NaN otherwise returns y", x y), + (pi, "Returns an approximate value of π",), + (power, "`base` raised to the power of `exponent`", base exponent), + (radians, "converts degrees to radians", num), + (random, "Returns a random value in the range 0.0 <= x < 1.0",), + (signum, "sign of the argument (-1, 0, +1)", num), + (sin, "sine", num), + (sinh, "hyperbolic sine", num), + (sqrt, "square root of a number", num), + (tan, "returns the tangent of a number", num), + (tanh, "returns the hyperbolic tangent of a number", num), + (round, "round to nearest integer", args,), + (trunc, "truncate toward zero, with optional precision", args,) + ); } -/// Return a list of all functions in this package +/// Returns all DataFusion functions defined in this package pub fn functions() -> Vec> { vec![ abs(), @@ -302,13 +338,128 @@ pub fn functions() -> Vec> { power(), radians(), random(), - round(), signum(), sin(), sinh(), sqrt(), tan(), tanh(), + round(), trunc(), ] } + +#[cfg(test)] +mod tests { + use arrow::datatypes::DataType; + use datafusion_common::ScalarValue; + use datafusion_expr::interval_arithmetic::Interval; + + fn unbounded_interval(data_type: &DataType) -> Interval { + Interval::make_unbounded(data_type).unwrap() + } + + fn one_to_inf_interval(data_type: &DataType) -> Interval { + Interval::try_new( + ScalarValue::new_one(data_type).unwrap(), + ScalarValue::try_from(data_type).unwrap(), + ) + .unwrap() + } + + fn zero_to_pi_interval(data_type: &DataType) -> Interval { + Interval::try_new( + ScalarValue::new_zero(data_type).unwrap(), + ScalarValue::new_pi_upper(data_type).unwrap(), + ) + .unwrap() + } + + fn assert_udf_evaluates_to_bounds( + udf: &datafusion_expr::ScalarUDF, + interval: Interval, + expected: Interval, + ) { + let input = vec![&interval]; + let result = udf.evaluate_bounds(&input).unwrap(); + assert_eq!( + result, + expected, + "Bounds check failed on UDF: {:?}", + udf.name() + ); + } + + #[test] + fn test_cases() -> crate::Result<()> { + let datatypes = [DataType::Float32, DataType::Float64]; + let cases = datatypes + .iter() + .flat_map(|data_type| { + vec![ + ( + super::acos(), + unbounded_interval(data_type), + zero_to_pi_interval(data_type), + ), + ( + super::acosh(), + unbounded_interval(data_type), + Interval::make_non_negative_infinity_interval(data_type).unwrap(), + ), + ( + super::asin(), + unbounded_interval(data_type), + Interval::make_symmetric_half_pi_interval(data_type).unwrap(), + ), + ( + super::atan(), + unbounded_interval(data_type), + Interval::make_symmetric_half_pi_interval(data_type).unwrap(), + ), + ( + super::cos(), + unbounded_interval(data_type), + Interval::make_symmetric_unit_interval(data_type).unwrap(), + ), + ( + super::cosh(), + unbounded_interval(data_type), + one_to_inf_interval(data_type), + ), + ( + super::sin(), + unbounded_interval(data_type), + Interval::make_symmetric_unit_interval(data_type).unwrap(), + ), + ( + super::exp(), + unbounded_interval(data_type), + Interval::make_non_negative_infinity_interval(data_type).unwrap(), + ), + ( + super::sqrt(), + unbounded_interval(data_type), + Interval::make_non_negative_infinity_interval(data_type).unwrap(), + ), + ( + super::radians(), + unbounded_interval(data_type), + Interval::make_symmetric_pi_interval(data_type).unwrap(), + ), + ( + super::sqrt(), + unbounded_interval(data_type), + Interval::make_non_negative_infinity_interval(data_type).unwrap(), + ), + ] + }) + .collect::>(); + + for (udf, interval, expected) in cases { + assert_udf_evaluates_to_bounds(&udf, interval, expected); + } + + Ok(()) + } +} diff --git a/datafusion/functions/src/math/monotonicity.rs b/datafusion/functions/src/math/monotonicity.rs new file mode 100644 index 000000000000..19c85f4b6e3c --- /dev/null +++ b/datafusion/functions/src/math/monotonicity.rs @@ -0,0 +1,572 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::OnceLock; + +use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_expr::Documentation; + +/// Non-increasing on the interval \[−1, 1\], undefined otherwise. +pub fn acos_order(input: &[ExprProperties]) -> Result { + let arg = &input[0]; + let range = &arg.range; + + let valid_domain = + Interval::make_symmetric_unit_interval(&range.lower().data_type())?; + + if valid_domain.contains(range)? == Interval::CERTAINLY_TRUE { + Ok(-arg.sort_properties) + } else { + exec_err!("Input range of ACOS contains out-of-domain values") + } +} + +static DOCUMENTATION_ACOS: OnceLock = OnceLock::new(); + +pub fn get_acos_doc() -> &'static Documentation { + DOCUMENTATION_ACOS.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the arc cosine or inverse cosine of a number.") + .with_syntax_example("acos(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + +/// Non-decreasing for x ≥ 1, undefined otherwise. +pub fn acosh_order(input: &[ExprProperties]) -> Result { + let arg = &input[0]; + let range = &arg.range; + + let valid_domain = Interval::try_new( + ScalarValue::new_one(&range.lower().data_type())?, + ScalarValue::try_from(&range.upper().data_type())?, + )?; + + if valid_domain.contains(range)? == Interval::CERTAINLY_TRUE { + Ok(arg.sort_properties) + } else { + exec_err!("Input range of ACOSH contains out-of-domain values") + } +} + +static DOCUMENTATION_ACOSH: OnceLock = OnceLock::new(); + +pub fn get_acosh_doc() -> &'static Documentation { + DOCUMENTATION_ACOSH.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description( + "Returns the area hyperbolic cosine or inverse hyperbolic cosine of a number.", + ) + .with_syntax_example("acosh(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + +/// Non-decreasing on the interval \[−1, 1\], undefined otherwise. +pub fn asin_order(input: &[ExprProperties]) -> Result { + let arg = &input[0]; + let range = &arg.range; + + let valid_domain = + Interval::make_symmetric_unit_interval(&range.lower().data_type())?; + + if valid_domain.contains(range)? == Interval::CERTAINLY_TRUE { + Ok(arg.sort_properties) + } else { + exec_err!("Input range of ASIN contains out-of-domain values") + } +} + +static DOCUMENTATION_ASIN: OnceLock = OnceLock::new(); + +pub fn get_asin_doc() -> &'static Documentation { + DOCUMENTATION_ASIN.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the arc sine or inverse sine of a number.") + .with_syntax_example("asin(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + +/// Non-decreasing for all real numbers. +pub fn asinh_order(input: &[ExprProperties]) -> Result { + Ok(input[0].sort_properties) +} + +static DOCUMENTATION_ASINH: OnceLock = OnceLock::new(); + +pub fn get_asinh_doc() -> &'static Documentation { + DOCUMENTATION_ASINH.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description( + "Returns the area hyperbolic sine or inverse hyperbolic sine of a number.", + ) + .with_syntax_example("asinh(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + +/// Non-decreasing for all real numbers. +pub fn atan_order(input: &[ExprProperties]) -> Result { + Ok(input[0].sort_properties) +} + +static DOCUMENTATION_ATAN: OnceLock = OnceLock::new(); + +pub fn get_atan_doc() -> &'static Documentation { + DOCUMENTATION_ATAN.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the arc tangent or inverse tangent of a number.") + .with_syntax_example("atan(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + +/// Non-decreasing on the interval \[−1, 1\], undefined otherwise. +pub fn atanh_order(input: &[ExprProperties]) -> Result { + let arg = &input[0]; + let range = &arg.range; + + let valid_domain = + Interval::make_symmetric_unit_interval(&range.lower().data_type())?; + + if valid_domain.contains(range)? == Interval::CERTAINLY_TRUE { + Ok(arg.sort_properties) + } else { + exec_err!("Input range of ATANH contains out-of-domain values") + } +} + +static DOCUMENTATION_ATANH: OnceLock = OnceLock::new(); + +pub fn get_atanh_doc() -> &'static Documentation { + DOCUMENTATION_ATANH.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description( + "Returns the area hyperbolic tangent or inverse hyperbolic tangent of a number.", + ) + .with_syntax_example("atanh(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + +/// Order depends on the quadrant. +// TODO: Implement ordering rule of the ATAN2 function. +pub fn atan2_order(_input: &[ExprProperties]) -> Result { + Ok(SortProperties::Unordered) +} + +static DOCUMENTATION_ATANH2: OnceLock = OnceLock::new(); + +pub fn get_atan2_doc() -> &'static Documentation { + DOCUMENTATION_ATANH2.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description( + "Returns the arc tangent or inverse tangent of `expression_y / expression_x`.", + ) + .with_syntax_example("atan2(expression_y, expression_x)") + .with_argument("expression_y", r#"First numeric expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators."#) + .with_argument("expression_x", r#"Second numeric expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators."#) + .build() + .unwrap() + }) +} + +/// Non-decreasing for all real numbers. +pub fn cbrt_order(input: &[ExprProperties]) -> Result { + Ok(input[0].sort_properties) +} + +static DOCUMENTATION_CBRT: OnceLock = OnceLock::new(); + +pub fn get_cbrt_doc() -> &'static Documentation { + DOCUMENTATION_CBRT.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the cube root of a number.") + .with_syntax_example("cbrt(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + +/// Non-decreasing for all real numbers. +pub fn ceil_order(input: &[ExprProperties]) -> Result { + Ok(input[0].sort_properties) +} + +static DOCUMENTATION_CEIL: OnceLock = OnceLock::new(); + +pub fn get_ceil_doc() -> &'static Documentation { + DOCUMENTATION_CEIL.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description( + "Returns the nearest integer greater than or equal to a number.", + ) + .with_syntax_example("ceil(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + +/// Non-increasing on \[0, π\] and then non-decreasing on \[π, 2π\]. +/// This pattern repeats periodically with a period of 2π. +// TODO: Implement ordering rule of the ATAN2 function. +pub fn cos_order(_input: &[ExprProperties]) -> Result { + Ok(SortProperties::Unordered) +} + +static DOCUMENTATION_COS: OnceLock = OnceLock::new(); + +pub fn get_cos_doc() -> &'static Documentation { + DOCUMENTATION_COS.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the cosine of a number.") + .with_syntax_example("cos(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + +/// Non-decreasing for x ≥ 0 and symmetrically non-increasing for x ≤ 0. +pub fn cosh_order(input: &[ExprProperties]) -> Result { + let arg = &input[0]; + let range = &arg.range; + + let zero_point = Interval::make_zero(&range.lower().data_type())?; + + if range.gt_eq(&zero_point)? == Interval::CERTAINLY_TRUE { + Ok(arg.sort_properties) + } else if range.lt_eq(&zero_point)? == Interval::CERTAINLY_TRUE { + Ok(-arg.sort_properties) + } else { + Ok(SortProperties::Unordered) + } +} + +static DOCUMENTATION_COSH: OnceLock = OnceLock::new(); + +pub fn get_cosh_doc() -> &'static Documentation { + DOCUMENTATION_COSH.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the hyperbolic cosine of a number.") + .with_syntax_example("cosh(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + +/// Non-decreasing function that converts radians to degrees. +pub fn degrees_order(input: &[ExprProperties]) -> Result { + Ok(input[0].sort_properties) +} + +static DOCUMENTATION_DEGREES: OnceLock = OnceLock::new(); + +pub fn get_degrees_doc() -> &'static Documentation { + DOCUMENTATION_DEGREES.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Converts radians to degrees.") + .with_syntax_example("degrees(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + +/// Non-decreasing for all real numbers. +pub fn exp_order(input: &[ExprProperties]) -> Result { + Ok(input[0].sort_properties) +} + +static DOCUMENTATION_EXP: OnceLock = OnceLock::new(); + +pub fn get_exp_doc() -> &'static Documentation { + DOCUMENTATION_EXP.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the base-e exponential of a number.") + .with_syntax_example("exp(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + +/// Non-decreasing for all real numbers. +pub fn floor_order(input: &[ExprProperties]) -> Result { + Ok(input[0].sort_properties) +} + +static DOCUMENTATION_FLOOR: OnceLock = OnceLock::new(); + +pub fn get_floor_doc() -> &'static Documentation { + DOCUMENTATION_FLOOR.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description( + "Returns the nearest integer less than or equal to a number.", + ) + .with_syntax_example("floor(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + +/// Non-decreasing for x ≥ 0, undefined otherwise. +pub fn ln_order(input: &[ExprProperties]) -> Result { + let arg = &input[0]; + let range = &arg.range; + + let zero_point = Interval::make_zero(&range.lower().data_type())?; + + if range.gt_eq(&zero_point)? == Interval::CERTAINLY_TRUE { + Ok(arg.sort_properties) + } else { + exec_err!("Input range of LN contains out-of-domain values") + } +} + +static DOCUMENTATION_LN: OnceLock = OnceLock::new(); + +pub fn get_ln_doc() -> &'static Documentation { + DOCUMENTATION_LN.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the natural logarithm of a number.") + .with_syntax_example("ln(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + +/// Non-decreasing for x ≥ 0, undefined otherwise. +pub fn log2_order(input: &[ExprProperties]) -> Result { + let arg = &input[0]; + let range = &arg.range; + + let zero_point = Interval::make_zero(&range.lower().data_type())?; + + if range.gt_eq(&zero_point)? == Interval::CERTAINLY_TRUE { + Ok(arg.sort_properties) + } else { + exec_err!("Input range of LOG2 contains out-of-domain values") + } +} + +static DOCUMENTATION_LOG2: OnceLock = OnceLock::new(); + +pub fn get_log2_doc() -> &'static Documentation { + DOCUMENTATION_LOG2.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the base-2 logarithm of a number.") + .with_syntax_example("log2(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + +/// Non-decreasing for x ≥ 0, undefined otherwise. +pub fn log10_order(input: &[ExprProperties]) -> Result { + let arg = &input[0]; + let range = &arg.range; + + let zero_point = Interval::make_zero(&range.lower().data_type())?; + + if range.gt_eq(&zero_point)? == Interval::CERTAINLY_TRUE { + Ok(arg.sort_properties) + } else { + exec_err!("Input range of LOG10 contains out-of-domain values") + } +} + +static DOCUMENTATION_LOG10: OnceLock = OnceLock::new(); + +pub fn get_log10_doc() -> &'static Documentation { + DOCUMENTATION_LOG10.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the base-10 logarithm of a number.") + .with_syntax_example("log10(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + +/// Non-decreasing for all real numbers x. +pub fn radians_order(input: &[ExprProperties]) -> Result { + Ok(input[0].sort_properties) +} + +static DOCUMENTATION_RADIONS: OnceLock = OnceLock::new(); + +pub fn get_radians_doc() -> &'static Documentation { + DOCUMENTATION_RADIONS.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Converts degrees to radians.") + .with_syntax_example("radians(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + +/// Non-decreasing on \[0, π\] and then non-increasing on \[π, 2π\]. +/// This pattern repeats periodically with a period of 2π. +// TODO: Implement ordering rule of the SIN function. +pub fn sin_order(_input: &[ExprProperties]) -> Result { + Ok(SortProperties::Unordered) +} + +static DOCUMENTATION_SIN: OnceLock = OnceLock::new(); + +pub fn get_sin_doc() -> &'static Documentation { + DOCUMENTATION_SIN.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the sine of a number.") + .with_syntax_example("sin(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + +/// Non-decreasing for all real numbers. +pub fn sinh_order(input: &[ExprProperties]) -> Result { + Ok(input[0].sort_properties) +} + +static DOCUMENTATION_SINH: OnceLock = OnceLock::new(); + +pub fn get_sinh_doc() -> &'static Documentation { + DOCUMENTATION_SINH.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the hyperbolic sine of a number.") + .with_syntax_example("sinh(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + +/// Non-decreasing for x ≥ 0, undefined otherwise. +pub fn sqrt_order(input: &[ExprProperties]) -> Result { + let arg = &input[0]; + let range = &arg.range; + + let zero_point = Interval::make_zero(&range.lower().data_type())?; + + if range.gt_eq(&zero_point)? == Interval::CERTAINLY_TRUE { + Ok(arg.sort_properties) + } else { + exec_err!("Input range of SQRT contains out-of-domain values") + } +} + +static DOCUMENTATION_SQRT: OnceLock = OnceLock::new(); + +pub fn get_sqrt_doc() -> &'static Documentation { + DOCUMENTATION_SQRT.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the square root of a number.") + .with_syntax_example("sqrt(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + +/// Non-decreasing between vertical asymptotes at x = k * π ± π / 2 for any +/// integer k. +// TODO: Implement ordering rule of the TAN function. +pub fn tan_order(_input: &[ExprProperties]) -> Result { + Ok(SortProperties::Unordered) +} + +static DOCUMENTATION_TAN: OnceLock = OnceLock::new(); + +pub fn get_tan_doc() -> &'static Documentation { + DOCUMENTATION_TAN.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the tangent of a number.") + .with_syntax_example("tan(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + +/// Non-decreasing for all real numbers. +pub fn tanh_order(input: &[ExprProperties]) -> Result { + Ok(input[0].sort_properties) +} + +static DOCUMENTATION_TANH: OnceLock = OnceLock::new(); + +pub fn get_tanh_doc() -> &'static Documentation { + DOCUMENTATION_TANH.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns the hyperbolic tangent of a number.") + .with_syntax_example("tanh(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} diff --git a/datafusion/functions/src/math/nans.rs b/datafusion/functions/src/math/nans.rs index 2bd704a7de2e..c1dd1aacc35a 100644 --- a/datafusion/functions/src/math/nans.rs +++ b/datafusion/functions/src/math/nans.rs @@ -17,15 +17,15 @@ //! Math function: `isnan()`. -use arrow::datatypes::DataType; -use datafusion_common::{exec_err, DataFusionError, Result}; -use datafusion_expr::ColumnarValue; +use arrow::datatypes::{DataType, Float32Type, Float64Type}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::{ColumnarValue, TypeSignature}; -use arrow::array::{ArrayRef, BooleanArray, Float32Array, Float64Array}; -use datafusion_expr::TypeSignature::*; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use arrow::array::{ArrayRef, AsArray, BooleanArray}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; +use datafusion_expr::{Documentation, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; #[derive(Debug)] pub struct IsNanFunc { @@ -43,7 +43,10 @@ impl IsNanFunc { use DataType::*; Self { signature: Signature::one_of( - vec![Exact(vec![Float32]), Exact(vec![Float64])], + vec![ + TypeSignature::Exact(vec![Float32]), + TypeSignature::Exact(vec![Float64]), + ], Volatility::Immutable, ), } @@ -70,20 +73,15 @@ impl ScalarUDFImpl for IsNanFunc { let args = ColumnarValue::values_to_arrays(args)?; let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - self.name(), - Float64Array, - BooleanArray, - { f64::is_nan } - )), - DataType::Float32 => Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - self.name(), - Float32Array, - BooleanArray, - { f32::is_nan } - )), + DataType::Float64 => Arc::new(BooleanArray::from_unary( + args[0].as_primitive::(), + f64::is_nan, + )) as ArrayRef, + + DataType::Float32 => Arc::new(BooleanArray::from_unary( + args[0].as_primitive::(), + f32::is_nan, + )) as ArrayRef, other => { return exec_err!( "Unsupported data type {other:?} for function {}", @@ -93,4 +91,24 @@ impl ScalarUDFImpl for IsNanFunc { }; Ok(ColumnarValue::Array(arr)) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_isnan_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_isnan_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description( + "Returns true if a given number is +NaN or -NaN otherwise returns false.", + ) + .with_syntax_example("isnan(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/math/nanvl.rs b/datafusion/functions/src/math/nanvl.rs index d81a690843b6..cfd21256dd96 100644 --- a/datafusion/functions/src/math/nanvl.rs +++ b/datafusion/functions/src/math/nanvl.rs @@ -16,18 +16,19 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; -use arrow::array::{ArrayRef, Float32Array, Float64Array}; -use arrow::datatypes::DataType; -use arrow::datatypes::DataType::{Float32, Float64}; +use crate::utils::make_scalar_function; +use arrow::array::{ArrayRef, AsArray, Float32Array, Float64Array}; +use arrow::datatypes::DataType::{Float32, Float64}; +use arrow::datatypes::{DataType, Float32Type, Float64Type}; use datafusion_common::{exec_err, DataFusionError, Result}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; - -use crate::utils::make_scalar_function; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct NanvlFunc { @@ -75,6 +76,28 @@ impl ScalarUDFImpl for NanvlFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(nanvl, vec![])(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_nanvl_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_nanvl_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description( + r#"Returns the first argument if it's not _NaN_. +Returns the second argument otherwise."#, + ) + .with_syntax_example("nanvl(expression_x, expression_y)") + .with_argument("expression_x", "Numeric expression to return if it's not _NaN_. Can be a constant, column, or function, and any combination of arithmetic operators.") + .with_argument("expression_y", "Numeric expression to return if the first expression is _NaN_. Can be a constant, column, or function, and any combination of arithmetic operators.") + .build() + .unwrap() + }) } /// Nanvl SQL function @@ -89,14 +112,11 @@ fn nanvl(args: &[ArrayRef]) -> Result { } }; - Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "x", - "y", - Float64Array, - { compute_nanvl } - )) as ArrayRef) + let x = args[0].as_primitive() as &Float64Array; + let y = args[1].as_primitive() as &Float64Array; + arrow::compute::binary::<_, _, _, Float64Type>(x, y, compute_nanvl) + .map(|res| Arc::new(res) as _) + .map_err(DataFusionError::from) } Float32 => { let compute_nanvl = |x: f32, y: f32| { @@ -107,14 +127,11 @@ fn nanvl(args: &[ArrayRef]) -> Result { } }; - Ok(Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "x", - "y", - Float32Array, - { compute_nanvl } - )) as ArrayRef) + let x = args[0].as_primitive() as &Float32Array; + let y = args[1].as_primitive() as &Float32Array; + arrow::compute::binary::<_, _, _, Float32Type>(x, y, compute_nanvl) + .map(|res| Arc::new(res) as _) + .map_err(DataFusionError::from) } other => exec_err!("Unsupported data type {other:?} for function nanvl"), } @@ -122,10 +139,12 @@ fn nanvl(args: &[ArrayRef]) -> Result { #[cfg(test)] mod test { + use std::sync::Arc; + use crate::math::nanvl::nanvl; + use arrow::array::{ArrayRef, Float32Array, Float64Array}; use datafusion_common::cast::{as_float32_array, as_float64_array}; - use std::sync::Arc; #[test] fn test_nanvl_f64() { diff --git a/datafusion/functions/src/math/pi.rs b/datafusion/functions/src/math/pi.rs index 0801e797511b..ea0f33161772 100644 --- a/datafusion/functions/src/math/pi.rs +++ b/datafusion/functions/src/math/pi.rs @@ -16,15 +16,16 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::OnceLock; -use arrow::array::Float64Array; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Float64; - -use datafusion_common::{exec_err, Result}; -use datafusion_expr::{ColumnarValue, FuncMonotonicity, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_common::{not_impl_err, Result, ScalarValue}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct PiFunc { @@ -62,15 +63,35 @@ impl ScalarUDFImpl for PiFunc { Ok(Float64) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { - if !matches!(&args[0], ColumnarValue::Array(_)) { - return exec_err!("Expect pi function to take no param"); - } - let array = Float64Array::from_value(std::f64::consts::PI, 1); - Ok(ColumnarValue::Array(Arc::new(array))) + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + not_impl_err!("{} function does not accept arguments", self.name()) } - fn monotonicity(&self) -> Result> { - Ok(Some(vec![Some(true)])) + fn invoke_no_args(&self, _number_rows: usize) -> Result { + Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some( + std::f64::consts::PI, + )))) } + + fn output_ordering(&self, _input: &[ExprProperties]) -> Result { + // This function returns a constant value. + Ok(SortProperties::Singleton) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_pi_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_pi_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Returns an approximate value of π.") + .with_syntax_example("pi()") + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 29caa63a9422..a24c613f5259 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -16,22 +16,22 @@ // under the License. //! Math function: `power()`. +use std::any::Any; +use std::sync::{Arc, OnceLock}; + +use super::log::LogFunc; -use arrow::datatypes::DataType; +use arrow::array::{ArrayRef, AsArray, Int64Array}; +use arrow::datatypes::{ArrowNativeTypeOp, DataType, Float64Type}; use datafusion_common::{ - exec_err, plan_datafusion_err, DataFusionError, Result, ScalarValue, + arrow_datafusion_err, exec_datafusion_err, exec_err, plan_datafusion_err, + DataFusionError, Result, ScalarValue, }; use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{ColumnarValue, Expr, ScalarFunctionDefinition}; - -use arrow::array::{ArrayRef, Float64Array, Int64Array}; -use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ColumnarValue, Documentation, Expr, ScalarUDF, TypeSignature}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use std::any::Any; -use std::sync::Arc; - -use super::log::LogFunc; #[derive(Debug)] pub struct PowerFunc { @@ -50,7 +50,10 @@ impl PowerFunc { use DataType::*; Self { signature: Signature::one_of( - vec![Exact(vec![Int64, Int64]), Exact(vec![Float64, Float64])], + vec![ + TypeSignature::Exact(vec![Int64, Int64]), + TypeSignature::Exact(vec![Float64, Float64]), + ], Volatility::Immutable, ), aliases: vec![String::from("pow")], @@ -85,23 +88,35 @@ impl ScalarUDFImpl for PowerFunc { let args = ColumnarValue::values_to_arrays(args)?; let arr: ArrayRef = match args[0].data_type() { - DataType::Float64 => Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "base", - "exponent", - Float64Array, - { f64::powf } - )), - - DataType::Int64 => Arc::new(make_function_inputs2!( - &args[0], - &args[1], - "base", - "exponent", - Int64Array, - { i64::pow } - )), + DataType::Float64 => { + let bases = args[0].as_primitive::(); + let exponents = args[1].as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float64Type>( + bases, + exponents, + f64::powf, + )?; + Arc::new(result) as _ + } + DataType::Int64 => { + let bases = downcast_arg!(&args[0], "base", Int64Array); + let exponents = downcast_arg!(&args[1], "exponent", Int64Array); + bases + .iter() + .zip(exponents.iter()) + .map(|(base, exp)| match (base, exp) { + (Some(base), Some(exp)) => Ok(Some(base.pow_checked( + exp.try_into().map_err(|_| { + exec_datafusion_err!( + "Can't use negative exponents: {exp} in integer computation, please use Float." + ) + })?, + ).map_err(|e| arrow_datafusion_err!(e))?)), + _ => Ok(None), + }) + .collect::>() + .map(Arc::new)? as _ + } other => { return exec_err!( @@ -140,8 +155,8 @@ impl ScalarUDFImpl for PowerFunc { Expr::Literal(value) if value == ScalarValue::new_one(&exponent_type)? => { Ok(ExprSimplifyResult::Simplified(base)) } - Expr::ScalarFunction(ScalarFunction { func_def, mut args }) - if is_log(&func_def) && args.len() == 2 && base == args[0] => + Expr::ScalarFunction(ScalarFunction { func, mut args }) + if is_log(&func) && args.len() == 2 && base == args[0] => { let b = args.pop().unwrap(); // length checked above Ok(ExprSimplifyResult::Simplified(b)) @@ -149,23 +164,37 @@ impl ScalarUDFImpl for PowerFunc { _ => Ok(ExprSimplifyResult::Original(vec![base, exponent])), } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_power_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_power_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description( + "Returns a base expression raised to the power of an exponent.", + ) + .with_syntax_example("power(base, exponent)") + .with_standard_argument("base", Some("Numeric")) + .with_standard_argument("exponent", Some("Exponent numeric")) + .build() + .unwrap() + }) } /// Return true if this function call is a call to `Log` -fn is_log(func_def: &ScalarFunctionDefinition) -> bool { - if let ScalarFunctionDefinition::UDF(fun) = func_def { - fun.as_ref() - .inner() - .as_any() - .downcast_ref::() - .is_some() - } else { - false - } +fn is_log(func: &ScalarUDF) -> bool { + func.inner().as_any().downcast_ref::().is_some() } #[cfg(test)] mod tests { + use arrow::array::Float64Array; use datafusion_common::cast::{as_float64_array, as_int64_array}; use super::*; @@ -178,7 +207,7 @@ mod tests { ]; let result = PowerFunc::new() - .invoke(&args) + .invoke_batch(&args, 4) .expect("failed to initialize function power"); match result { @@ -205,7 +234,7 @@ mod tests { ]; let result = PowerFunc::new() - .invoke(&args) + .invoke_batch(&args, 4) .expect("failed to initialize function power"); match result { diff --git a/datafusion/functions/src/math/random.rs b/datafusion/functions/src/math/random.rs index 2c1ad4136702..cf564e5328a5 100644 --- a/datafusion/functions/src/math/random.rs +++ b/datafusion/functions/src/math/random.rs @@ -16,17 +16,17 @@ // under the License. use std::any::Any; -use std::iter; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::Float64Array; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Float64; use rand::{thread_rng, Rng}; -use datafusion_common::{exec_err, Result}; +use datafusion_common::{not_impl_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; use datafusion_expr::ColumnarValue; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{Documentation, ScalarUDFImpl, Signature, Volatility}; #[derive(Debug)] pub struct RandomFunc { @@ -64,45 +64,37 @@ impl ScalarUDFImpl for RandomFunc { Ok(Float64) } - fn invoke(&self, args: &[ColumnarValue]) -> Result { - random(args) + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + not_impl_err!("{} function does not accept arguments", self.name()) } -} - -/// Random SQL function -fn random(args: &[ColumnarValue]) -> Result { - let len: usize = match &args[0] { - ColumnarValue::Array(array) => array.len(), - _ => return exec_err!("Expect random function to take no param"), - }; - let mut rng = thread_rng(); - let values = iter::repeat_with(|| rng.gen_range(0.0..1.0)).take(len); - let array = Float64Array::from_iter_values(values); - Ok(ColumnarValue::Array(Arc::new(array))) -} - -#[cfg(test)] -mod test { - use std::sync::Arc; - - use arrow::array::NullArray; - use datafusion_common::cast::as_float64_array; - use datafusion_expr::ColumnarValue; + fn invoke_no_args(&self, num_rows: usize) -> Result { + let mut rng = thread_rng(); + let mut values = vec![0.0; num_rows]; + // Equivalent to set each element with rng.gen_range(0.0..1.0), but more efficient + rng.fill(&mut values[..]); + let array = Float64Array::from(values); - use crate::math::random::random; - - #[test] - fn test_random_expression() { - let args = vec![ColumnarValue::Array(Arc::new(NullArray::new(1)))]; - let array = random(&args) - .expect("failed to initialize function random") - .into_array(1) - .expect("Failed to convert to array"); - let floats = - as_float64_array(&array).expect("failed to initialize function random"); + Ok(ColumnarValue::Array(Arc::new(array))) + } - assert_eq!(floats.len(), 1); - assert!(0.0 <= floats.value(0) && floats.value(0) < 1.0); + fn documentation(&self) -> Option<&Documentation> { + Some(get_random_doc()) } } + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_random_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description( + r#"Returns a random float value in the range [0, 1). +The random seed is unique to each row."#, + ) + .with_syntax_example("random()") + .build() + .unwrap() + }) +} diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index f4a163137a35..6000e5d765de 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -16,17 +16,21 @@ // under the License. use std::any::Any; -use std::sync::Arc; - -use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array}; -use arrow::datatypes::DataType; -use arrow::datatypes::DataType::{Float32, Float64}; +use std::sync::{Arc, OnceLock}; use crate::utils::make_scalar_function; -use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; + +use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; +use arrow::compute::{cast_with_options, CastOptions}; +use arrow::datatypes::DataType::{Float32, Float64, Int32}; +use arrow::datatypes::{DataType, Float32Type, Float64Type, Int32Type}; +use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ColumnarValue, FuncMonotonicity}; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct RoundFunc { @@ -80,11 +84,44 @@ impl ScalarUDFImpl for RoundFunc { make_scalar_function(round, vec![])(args) } - fn monotonicity(&self) -> Result> { - Ok(Some(vec![Some(true)])) + fn output_ordering(&self, input: &[ExprProperties]) -> Result { + // round preserves the order of the first argument + let value = &input[0]; + let precision = input.get(1); + + if precision + .map(|r| r.sort_properties.eq(&SortProperties::Singleton)) + .unwrap_or(true) + { + Ok(value.sort_properties) + } else { + Ok(SortProperties::Unordered) + } + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_round_doc()) } } +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_round_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description("Rounds a number to the nearest integer.") + .with_syntax_example("round(numeric_expression[, decimal_places])") + .with_standard_argument("numeric_expression", Some("Numeric")) + .with_argument( + "decimal_places", + "Optional. The number of decimal places to round to. Defaults to 0.", + ) + .build() + .unwrap() + }) +} + /// Round SQL function pub fn round(args: &[ArrayRef]) -> Result { if args.len() != 1 && args.len() != 2 { @@ -97,77 +134,89 @@ pub fn round(args: &[ArrayRef]) -> Result { let mut decimal_places = ColumnarValue::Scalar(ScalarValue::Int64(Some(0))); if args.len() == 2 { - decimal_places = ColumnarValue::Array(args[1].clone()); + decimal_places = ColumnarValue::Array(Arc::clone(&args[1])); } match args[0].data_type() { - DataType::Float64 => match decimal_places { + Float64 => match decimal_places { ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => { - let decimal_places = decimal_places.try_into().unwrap(); - - Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "value", - Float64Array, - { - |value: f64| { - (value * 10.0_f64.powi(decimal_places)).round() - / 10.0_f64.powi(decimal_places) - } - } - )) as ArrayRef) + let decimal_places: i32 = decimal_places.try_into().map_err(|e| { + exec_datafusion_err!( + "Invalid value for decimal places: {decimal_places}: {e}" + ) + })?; + + let result = args[0] + .as_primitive::() + .unary::<_, Float64Type>(|value: f64| { + (value * 10.0_f64.powi(decimal_places)).round() + / 10.0_f64.powi(decimal_places) + }); + Ok(Arc::new(result) as _) + } + ColumnarValue::Array(decimal_places) => { + let options = CastOptions { + safe: false, // raise error if the cast is not possible + ..Default::default() + }; + let decimal_places = cast_with_options(&decimal_places, &Int32, &options) + .map_err(|e| { + exec_datafusion_err!("Invalid values for decimal places: {e}") + })?; + + let values = args[0].as_primitive::(); + let decimal_places = decimal_places.as_primitive::(); + let result = arrow::compute::binary::<_, _, _, Float64Type>( + values, + decimal_places, + |value, decimal_places| { + (value * 10.0_f64.powi(decimal_places)).round() + / 10.0_f64.powi(decimal_places) + }, + )?; + Ok(Arc::new(result) as _) } - ColumnarValue::Array(decimal_places) => Ok(Arc::new(make_function_inputs2!( - &args[0], - decimal_places, - "value", - "decimal_places", - Float64Array, - Int64Array, - { - |value: f64, decimal_places: i64| { - (value * 10.0_f64.powi(decimal_places.try_into().unwrap())) - .round() - / 10.0_f64.powi(decimal_places.try_into().unwrap()) - } - } - )) as ArrayRef), _ => { exec_err!("round function requires a scalar or array for decimal_places") } }, - DataType::Float32 => match decimal_places { + Float32 => match decimal_places { ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => { - let decimal_places = decimal_places.try_into().unwrap(); - - Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "value", - Float32Array, - { - |value: f32| { - (value * 10.0_f32.powi(decimal_places)).round() - / 10.0_f32.powi(decimal_places) - } - } - )) as ArrayRef) + let decimal_places: i32 = decimal_places.try_into().map_err(|e| { + exec_datafusion_err!( + "Invalid value for decimal places: {decimal_places}: {e}" + ) + })?; + let result = args[0] + .as_primitive::() + .unary::<_, Float32Type>(|value: f32| { + (value * 10.0_f32.powi(decimal_places)).round() + / 10.0_f32.powi(decimal_places) + }); + Ok(Arc::new(result) as _) + } + ColumnarValue::Array(_) => { + let ColumnarValue::Array(decimal_places) = + decimal_places.cast_to(&Int32, None).map_err(|e| { + exec_datafusion_err!("Invalid values for decimal places: {e}") + })? + else { + panic!("Unexpected result of ColumnarValue::Array.cast") + }; + + let values = args[0].as_primitive::(); + let decimal_places = decimal_places.as_primitive::(); + let result: PrimitiveArray = arrow::compute::binary( + values, + decimal_places, + |value, decimal_places| { + (value * 10.0_f32.powi(decimal_places)).round() + / 10.0_f32.powi(decimal_places) + }, + )?; + Ok(Arc::new(result) as _) } - ColumnarValue::Array(decimal_places) => Ok(Arc::new(make_function_inputs2!( - &args[0], - decimal_places, - "value", - "decimal_places", - Float32Array, - Int64Array, - { - |value: f32, decimal_places: i64| { - (value * 10.0_f32.powi(decimal_places.try_into().unwrap())) - .round() - / 10.0_f32.powi(decimal_places.try_into().unwrap()) - } - } - )) as ArrayRef), _ => { exec_err!("round function requires a scalar or array for decimal_places") } @@ -179,10 +228,13 @@ pub fn round(args: &[ArrayRef]) -> Result { #[cfg(test)] mod test { + use std::sync::Arc; + use crate::math::round::round; + use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array}; use datafusion_common::cast::{as_float32_array, as_float64_array}; - use std::sync::Arc; + use datafusion_common::DataFusionError; #[test] fn test_round_f32() { @@ -249,4 +301,17 @@ mod test { assert_eq!(floats, &expected); } + + #[test] + fn test_round_f32_cast_fail() { + let args: Vec = vec![ + Arc::new(Float64Array::from(vec![125.2345])), // input + Arc::new(Int64Array::from(vec![2147483648])), // decimal_places + ]; + + let result = round(&args); + + assert!(result.is_err()); + assert!(matches!(result, Err(DataFusionError::Execution { .. }))); + } } diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs new file mode 100644 index 000000000000..7f21297712c7 --- /dev/null +++ b/datafusion/functions/src/math/signum.rs @@ -0,0 +1,235 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::{Arc, OnceLock}; + +use arrow::array::{ArrayRef, AsArray}; +use arrow::datatypes::DataType::{Float32, Float64}; +use arrow::datatypes::{DataType, Float32Type, Float64Type}; + +use datafusion_common::{exec_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; + +use crate::utils::make_scalar_function; + +#[derive(Debug)] +pub struct SignumFunc { + signature: Signature, +} + +impl Default for SignumFunc { + fn default() -> Self { + SignumFunc::new() + } +} + +impl SignumFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Float64, Float32], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SignumFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "signum" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match &arg_types[0] { + Float32 => Ok(Float32), + _ => Ok(Float64), + } + } + + fn output_ordering(&self, input: &[ExprProperties]) -> Result { + // Non-decreasing for all real numbers x. + Ok(input[0].sort_properties) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(signum, vec![])(args) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_signum_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_signum_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description( + r#"Returns the sign of a number. +Negative numbers return `-1`. +Zero and positive numbers return `1`."#, + ) + .with_syntax_example("signum(numeric_expression)") + .with_standard_argument("numeric_expression", Some("Numeric")) + .build() + .unwrap() + }) +} + +/// signum SQL function +pub fn signum(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + Float64 => Ok(Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float64Type>( + |x: f64| { + if x == 0_f64 { + 0_f64 + } else { + x.signum() + } + }, + ), + ) as ArrayRef), + + Float32 => Ok(Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float32Type>( + |x: f32| { + if x == 0_f32 { + 0_f32 + } else { + x.signum() + } + }, + ), + ) as ArrayRef), + + other => exec_err!("Unsupported data type {other:?} for function signum"), + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::array::{Float32Array, Float64Array}; + + use datafusion_common::cast::{as_float32_array, as_float64_array}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::math::signum::SignumFunc; + + #[test] + fn test_signum_f32() { + let array = Arc::new(Float32Array::from(vec![ + -1.0, + -0.0, + 0.0, + 1.0, + -0.01, + 0.01, + f32::NAN, + f32::INFINITY, + f32::NEG_INFINITY, + ])); + let batch_size = array.len(); + let result = SignumFunc::new() + .invoke_batch(&[ColumnarValue::Array(array)], batch_size) + .expect("failed to initialize function signum"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float32_array(&arr) + .expect("failed to convert result to a Float32Array"); + + assert_eq!(floats.len(), 9); + assert_eq!(floats.value(0), -1.0); + assert_eq!(floats.value(1), 0.0); + assert_eq!(floats.value(2), 0.0); + assert_eq!(floats.value(3), 1.0); + assert_eq!(floats.value(4), -1.0); + assert_eq!(floats.value(5), 1.0); + assert!(floats.value(6).is_nan()); + assert_eq!(floats.value(7), 1.0); + assert_eq!(floats.value(8), -1.0); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } + + #[test] + fn test_signum_f64() { + let array = Arc::new(Float64Array::from(vec![ + -1.0, + -0.0, + 0.0, + 1.0, + -0.01, + 0.01, + f64::NAN, + f64::INFINITY, + f64::NEG_INFINITY, + ])); + let batch_size = array.len(); + let result = SignumFunc::new() + .invoke_batch(&[ColumnarValue::Array(array)], batch_size) + .expect("failed to initialize function signum"); + + match result { + ColumnarValue::Array(arr) => { + let floats = as_float64_array(&arr) + .expect("failed to convert result to a Float32Array"); + + assert_eq!(floats.len(), 9); + assert_eq!(floats.value(0), -1.0); + assert_eq!(floats.value(1), 0.0); + assert_eq!(floats.value(2), 0.0); + assert_eq!(floats.value(3), 1.0); + assert_eq!(floats.value(4), -1.0); + assert_eq!(floats.value(5), 1.0); + assert!(floats.value(6).is_nan()); + assert_eq!(floats.value(7), 1.0); + assert_eq!(floats.value(8), -1.0); + } + ColumnarValue::Scalar(_) => { + panic!("Expected an array value") + } + } + } +} diff --git a/datafusion/functions/src/math/trunc.rs b/datafusion/functions/src/math/trunc.rs index 6f88099889cc..9a05684d238e 100644 --- a/datafusion/functions/src/math/trunc.rs +++ b/datafusion/functions/src/math/trunc.rs @@ -16,18 +16,21 @@ // under the License. use std::any::Any; -use std::sync::Arc; - -use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array}; -use arrow::datatypes::DataType; -use arrow::datatypes::DataType::{Float32, Float64}; +use std::sync::{Arc, OnceLock}; use crate::utils::make_scalar_function; + +use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; +use arrow::datatypes::DataType::{Float32, Float64}; +use arrow::datatypes::{DataType, Float32Type, Float64Type, Int64Type}; use datafusion_common::ScalarValue::Int64; -use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ColumnarValue, FuncMonotonicity}; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct TruncFunc { @@ -86,11 +89,47 @@ impl ScalarUDFImpl for TruncFunc { make_scalar_function(trunc, vec![])(args) } - fn monotonicity(&self) -> Result> { - Ok(Some(vec![Some(true)])) + fn output_ordering(&self, input: &[ExprProperties]) -> Result { + // trunc preserves the order of the first argument + let value = &input[0]; + let precision = input.get(1); + + if precision + .map(|r| r.sort_properties.eq(&SortProperties::Singleton)) + .unwrap_or(true) + { + Ok(value.sort_properties) + } else { + Ok(SortProperties::Unordered) + } + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_trunc_doc()) } } +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_trunc_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_MATH) + .with_description( + "Truncates a number to a whole number or truncated to the specified decimal places.", + ) + .with_syntax_example("trunc(numeric_expression[, decimal_places])") + .with_standard_argument("numeric_expression", Some("Numeric")) + .with_argument("decimal_places", r#"Optional. The number of decimal places to + truncate to. Defaults to 0 (truncate to a whole number). If + `decimal_places` is a positive integer, truncates digits to the + right of the decimal point. If `decimal_places` is a negative + integer, replaces digits to the left of the decimal point with `0`."#) + .build() + .unwrap() + }) +} + /// Truncate(numeric, decimalPrecision) and trunc(numeric) SQL function fn trunc(args: &[ArrayRef]) -> Result { if args.len() != 1 && args.len() != 2 { @@ -100,44 +139,66 @@ fn trunc(args: &[ArrayRef]) -> Result { ); } - //if only one arg then invoke toolchain trunc(num) and precision = 0 by default - //or then invoke the compute_truncate method to process precision + // If only one arg then invoke toolchain trunc(num) and precision = 0 by default + // or then invoke the compute_truncate method to process precision let num = &args[0]; let precision = if args.len() == 1 { ColumnarValue::Scalar(Int64(Some(0))) } else { - ColumnarValue::Array(args[1].clone()) + ColumnarValue::Array(Arc::clone(&args[1])) }; - match args[0].data_type() { + match num.data_type() { Float64 => match precision { - ColumnarValue::Scalar(Int64(Some(0))) => Ok(Arc::new( - make_function_scalar_inputs!(num, "num", Float64Array, { f64::trunc }), - ) as ArrayRef), - ColumnarValue::Array(precision) => Ok(Arc::new(make_function_inputs2!( - num, - precision, - "x", - "y", - Float64Array, - Int64Array, - { compute_truncate64 } - )) as ArrayRef), + ColumnarValue::Scalar(Int64(Some(0))) => { + Ok(Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float64Type>(|x: f64| { + if x == 0_f64 { + 0_f64 + } else { + x.trunc() + } + }), + ) as ArrayRef) + } + ColumnarValue::Array(precision) => { + let num_array = num.as_primitive::(); + let precision_array = precision.as_primitive::(); + let result: PrimitiveArray = + arrow::compute::binary(num_array, precision_array, |x, y| { + compute_truncate64(x, y) + })?; + + Ok(Arc::new(result) as ArrayRef) + } _ => exec_err!("trunc function requires a scalar or array for precision"), }, Float32 => match precision { - ColumnarValue::Scalar(Int64(Some(0))) => Ok(Arc::new( - make_function_scalar_inputs!(num, "num", Float32Array, { f32::trunc }), - ) as ArrayRef), - ColumnarValue::Array(precision) => Ok(Arc::new(make_function_inputs2!( - num, - precision, - "x", - "y", - Float32Array, - Int64Array, - { compute_truncate32 } - )) as ArrayRef), + ColumnarValue::Scalar(Int64(Some(0))) => { + Ok(Arc::new( + args[0] + .as_primitive::() + .unary::<_, Float32Type>(|x: f32| { + if x == 0_f32 { + 0_f32 + } else { + x.trunc() + } + }), + ) as ArrayRef) + } + ColumnarValue::Array(precision) => { + let num_array = num.as_primitive::(); + let precision_array = precision.as_primitive::(); + let result: PrimitiveArray = + arrow::compute::binary(num_array, precision_array, |x, y| { + compute_truncate32(x, y) + })?; + + Ok(Arc::new(result) as ArrayRef) + } _ => exec_err!("trunc function requires a scalar or array for precision"), }, other => exec_err!("Unsupported data type {other:?} for function trunc"), @@ -156,10 +217,12 @@ fn compute_truncate64(x: f64, y: i64) -> f64 { #[cfg(test)] mod test { + use std::sync::Arc; + use crate::math::trunc::trunc; + use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array}; use datafusion_common::cast::{as_float32_array, as_float64_array}; - use std::sync::Arc; #[test] fn test_truncate_32() { diff --git a/datafusion/functions/src/planner.rs b/datafusion/functions/src/planner.rs new file mode 100644 index 000000000000..93edec7ece30 --- /dev/null +++ b/datafusion/functions/src/planner.rs @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! SQL planning extensions like [`UserDefinedFunctionPlanner`] + +use datafusion_common::Result; +use datafusion_expr::{ + expr::ScalarFunction, + planner::{ExprPlanner, PlannerResult}, + Expr, +}; + +#[derive(Default, Debug)] +pub struct UserDefinedFunctionPlanner; + +impl ExprPlanner for UserDefinedFunctionPlanner { + #[cfg(feature = "datetime_expressions")] + fn plan_extract(&self, args: Vec) -> Result>> { + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf(crate::datetime::date_part(), args), + ))) + } + + #[cfg(feature = "unicode_expressions")] + fn plan_position(&self, args: Vec) -> Result>> { + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf(crate::unicode::strpos(), args), + ))) + } + + #[cfg(feature = "unicode_expressions")] + fn plan_substring(&self, args: Vec) -> Result>> { + Ok(PlannerResult::Planned(Expr::ScalarFunction( + ScalarFunction::new_udf(crate::unicode::substr(), args), + ))) + } +} diff --git a/datafusion/functions/src/regex/mod.rs b/datafusion/functions/src/regex/mod.rs index 5c12d4559e74..803f51e915a9 100644 --- a/datafusion/functions/src/regex/mod.rs +++ b/datafusion/functions/src/regex/mod.rs @@ -15,12 +15,17 @@ // specific language governing permissions and limitations // under the License. -//! "regx" DataFusion functions +//! "regex" DataFusion functions +use std::sync::Arc; + +pub mod regexpcount; pub mod regexplike; pub mod regexpmatch; pub mod regexpreplace; + // create UDFs +make_udf_function!(regexpcount::RegexpCountFunc, REGEXP_COUNT, regexp_count); make_udf_function!(regexpmatch::RegexpMatchFunc, REGEXP_MATCH, regexp_match); make_udf_function!(regexplike::RegexpLikeFunc, REGEXP_LIKE, regexp_like); make_udf_function!( @@ -28,12 +33,67 @@ make_udf_function!( REGEXP_REPLACE, regexp_replace ); -export_functions!(( - regexp_match, - input_arg1 input_arg2, - "returns a list of regular expression matches in a string. " -),( - regexp_like, - input_arg1 input_arg2, - "Returns true if a has at least one match in a string,false otherwise." -),(regexp_replace, arg1 arg2 arg3 arg4, "Replaces substrings in a string that match")); + +pub mod expr_fn { + use datafusion_expr::Expr; + + /// Returns the number of consecutive occurrences of a regular expression in a string. + pub fn regexp_count( + values: Expr, + regex: Expr, + start: Option, + flags: Option, + ) -> Expr { + let mut args = vec![values, regex]; + if let Some(start) = start { + args.push(start); + }; + + if let Some(flags) = flags { + args.push(flags); + }; + super::regexp_count().call(args) + } + + /// Returns a list of regular expression matches in a string. + pub fn regexp_match(values: Expr, regex: Expr, flags: Option) -> Expr { + let mut args = vec![values, regex]; + if let Some(flags) = flags { + args.push(flags); + }; + super::regexp_match().call(args) + } + + /// Returns true if a has at least one match in a string, false otherwise. + pub fn regexp_like(values: Expr, regex: Expr, flags: Option) -> Expr { + let mut args = vec![values, regex]; + if let Some(flags) = flags { + args.push(flags); + }; + super::regexp_like().call(args) + } + + /// Replaces substrings in a string that match. + pub fn regexp_replace( + string: Expr, + pattern: Expr, + replacement: Expr, + flags: Option, + ) -> Expr { + let mut args = vec![string, pattern, replacement]; + if let Some(flags) = flags { + args.push(flags); + }; + super::regexp_replace().call(args) + } +} + +/// Returns all DataFusion functions defined in this package +pub fn functions() -> Vec> { + vec![ + regexp_count(), + regexp_match(), + regexp_like(), + regexp_replace(), + ] +} diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs new file mode 100644 index 000000000000..7c4313effffb --- /dev/null +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -0,0 +1,984 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::strings::StringArrayType; +use arrow::array::{Array, ArrayRef, AsArray, Datum, Int64Array}; +use arrow::datatypes::{DataType, Int64Type}; +use arrow::datatypes::{ + DataType::Int64, DataType::LargeUtf8, DataType::Utf8, DataType::Utf8View, +}; +use arrow::error::ArrowError; +use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_REGEX; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature::Exact, + TypeSignature::Uniform, Volatility, +}; +use itertools::izip; +use regex::Regex; +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::sync::{Arc, OnceLock}; + +#[derive(Debug)] +pub struct RegexpCountFunc { + signature: Signature, +} + +impl Default for RegexpCountFunc { + fn default() -> Self { + Self::new() + } +} + +impl RegexpCountFunc { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + Uniform(2, vec![Utf8View, LargeUtf8, Utf8]), + Exact(vec![Utf8View, Utf8View, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![Utf8View, Utf8View, Int64, Utf8View]), + Exact(vec![LargeUtf8, LargeUtf8, Int64, LargeUtf8]), + Exact(vec![Utf8, Utf8, Int64, Utf8]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for RegexpCountFunc { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "regexp_count" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Int64) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + + let is_scalar = len.is_none(); + let inferred_length = len.unwrap_or(1); + let args = args + .iter() + .map(|arg| arg.clone().into_array(inferred_length)) + .collect::>>()?; + + let result = regexp_count_func(&args); + if is_scalar { + // If all inputs are scalar, keeps output as scalar + let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + result.map(ColumnarValue::Scalar) + } else { + result.map(ColumnarValue::Array) + } + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_regexp_count_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_regexp_count_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_REGEX) + .with_description("Returns the number of matches that a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has in a string.") + .with_syntax_example("regexp_count(str, regexp[, start, flags])") + .with_sql_example(r#"```sql +> select regexp_count('abcAbAbc', 'abc', 2, 'i'); ++---------------------------------------------------------------+ +| regexp_count(Utf8("abcAbAbc"),Utf8("abc"),Int64(2),Utf8("i")) | ++---------------------------------------------------------------+ +| 1 | ++---------------------------------------------------------------+ +```"#) + .with_standard_argument("str", Some("String")) + .with_standard_argument("regexp",Some("Regular")) + .with_argument("start", "- **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function.") + .with_argument("flags", + r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*?"#) + .build() + .unwrap() + }) +} + +pub fn regexp_count_func(args: &[ArrayRef]) -> Result { + let args_len = args.len(); + if !(2..=4).contains(&args_len) { + return exec_err!("regexp_count was called with {args_len} arguments. It requires at least 2 and at most 4."); + } + + let values = &args[0]; + match values.data_type() { + Utf8 | LargeUtf8 | Utf8View => (), + other => { + return internal_err!( + "Unsupported data type {other:?} for function regexp_count" + ); + } + } + + regexp_count( + values, + &args[1], + if args_len > 2 { Some(&args[2]) } else { None }, + if args_len > 3 { Some(&args[3]) } else { None }, + ) + .map_err(|e| e.into()) +} + +/// `arrow-rs` style implementation of `regexp_count` function. +/// This function `regexp_count` is responsible for counting the occurrences of a regular expression pattern +/// within a string array. It supports optional start positions and flags for case insensitivity. +/// +/// The function accepts a variable number of arguments: +/// - `values`: The array of strings to search within. +/// - `regex_array`: The array of regular expression patterns to search for. +/// - `start_array` (optional): The array of start positions for the search. +/// - `flags_array` (optional): The array of flags to modify the search behavior (e.g., case insensitivity). +/// +/// The function handles different combinations of scalar and array inputs for the regex patterns, start positions, +/// and flags. It uses a cache to store compiled regular expressions for efficiency. +/// +/// # Errors +/// Returns an error if the input arrays have mismatched lengths or if the regular expression fails to compile. +pub fn regexp_count( + values: &dyn Array, + regex_array: &dyn Datum, + start_array: Option<&dyn Datum>, + flags_array: Option<&dyn Datum>, +) -> Result { + let (regex_array, is_regex_scalar) = regex_array.get(); + let (start_array, is_start_scalar) = start_array.map_or((None, true), |start| { + let (start, is_start_scalar) = start.get(); + (Some(start), is_start_scalar) + }); + let (flags_array, is_flags_scalar) = flags_array.map_or((None, true), |flags| { + let (flags, is_flags_scalar) = flags.get(); + (Some(flags), is_flags_scalar) + }); + + match (values.data_type(), regex_array.data_type(), flags_array) { + (Utf8, Utf8, None) => regexp_count_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + None, + is_flags_scalar, + ), + (Utf8, Utf8, Some(flags_array)) if *flags_array.data_type() == Utf8 => regexp_count_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + Some(flags_array.as_string::()), + is_flags_scalar, + ), + (LargeUtf8, LargeUtf8, None) => regexp_count_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + None, + is_flags_scalar, + ), + (LargeUtf8, LargeUtf8, Some(flags_array)) if *flags_array.data_type() == LargeUtf8 => regexp_count_inner( + values.as_string::(), + regex_array.as_string::(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + Some(flags_array.as_string::()), + is_flags_scalar, + ), + (Utf8View, Utf8View, None) => regexp_count_inner( + values.as_string_view(), + regex_array.as_string_view(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + None, + is_flags_scalar, + ), + (Utf8View, Utf8View, Some(flags_array)) if *flags_array.data_type() == Utf8View => regexp_count_inner( + values.as_string_view(), + regex_array.as_string_view(), + is_regex_scalar, + start_array.map(|start| start.as_primitive::()), + is_start_scalar, + Some(flags_array.as_string_view()), + is_flags_scalar, + ), + _ => Err(ArrowError::ComputeError( + "regexp_count() expected the input arrays to be of type Utf8, LargeUtf8, or Utf8View and the data types of the values, regex_array, and flags_array to match".to_string(), + )), + } +} + +pub fn regexp_count_inner<'a, S>( + values: S, + regex_array: S, + is_regex_scalar: bool, + start_array: Option<&Int64Array>, + is_start_scalar: bool, + flags_array: Option, + is_flags_scalar: bool, +) -> Result +where + S: StringArrayType<'a>, +{ + let (regex_scalar, is_regex_scalar) = if is_regex_scalar || regex_array.len() == 1 { + (Some(regex_array.value(0)), true) + } else { + (None, false) + }; + + let (start_array, start_scalar, is_start_scalar) = + if let Some(start_array) = start_array { + if is_start_scalar || start_array.len() == 1 { + (None, Some(start_array.value(0)), true) + } else { + (Some(start_array), None, false) + } + } else { + (None, Some(1), true) + }; + + let (flags_array, flags_scalar, is_flags_scalar) = + if let Some(flags_array) = flags_array { + if is_flags_scalar || flags_array.len() == 1 { + (None, Some(flags_array.value(0)), true) + } else { + (Some(flags_array), None, false) + } + } else { + (None, None, true) + }; + + let mut regex_cache = HashMap::new(); + + match (is_regex_scalar, is_start_scalar, is_flags_scalar) { + (true, true, true) => { + let regex = match regex_scalar { + None | Some("") => { + return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))) + } + Some(regex) => regex, + }; + + let pattern = compile_regex(regex, flags_scalar)?; + + Ok(Arc::new(Int64Array::from_iter_values( + values + .iter() + .map(|value| count_matches(value, &pattern, start_scalar)) + .collect::, ArrowError>>()?, + ))) + } + (true, true, false) => { + let regex = match regex_scalar { + None | Some("") => { + return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))) + } + Some(regex) => regex, + }; + + let flags_array = flags_array.unwrap(); + if values.len() != flags_array.len() { + return Err(ArrowError::ComputeError(format!( + "flags_array must be the same length as values array; got {} and {}", + flags_array.len(), + values.len(), + ))); + } + + Ok(Arc::new(Int64Array::from_iter_values( + values + .iter() + .zip(flags_array.iter()) + .map(|(value, flags)| { + let pattern = + compile_and_cache_regex(regex, flags, &mut regex_cache)?; + count_matches(value, &pattern, start_scalar) + }) + .collect::, ArrowError>>()?, + ))) + } + (true, false, true) => { + let regex = match regex_scalar { + None | Some("") => { + return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))) + } + Some(regex) => regex, + }; + + let pattern = compile_regex(regex, flags_scalar)?; + + let start_array = start_array.unwrap(); + + Ok(Arc::new(Int64Array::from_iter_values( + values + .iter() + .zip(start_array.iter()) + .map(|(value, start)| count_matches(value, &pattern, start)) + .collect::, ArrowError>>()?, + ))) + } + (true, false, false) => { + let regex = match regex_scalar { + None | Some("") => { + return Ok(Arc::new(Int64Array::from(vec![0; values.len()]))) + } + Some(regex) => regex, + }; + + let flags_array = flags_array.unwrap(); + if values.len() != flags_array.len() { + return Err(ArrowError::ComputeError(format!( + "flags_array must be the same length as values array; got {} and {}", + flags_array.len(), + values.len(), + ))); + } + + Ok(Arc::new(Int64Array::from_iter_values( + izip!( + values.iter(), + start_array.unwrap().iter(), + flags_array.iter() + ) + .map(|(value, start, flags)| { + let pattern = + compile_and_cache_regex(regex, flags, &mut regex_cache)?; + + count_matches(value, &pattern, start) + }) + .collect::, ArrowError>>()?, + ))) + } + (false, true, true) => { + if values.len() != regex_array.len() { + return Err(ArrowError::ComputeError(format!( + "regex_array must be the same length as values array; got {} and {}", + regex_array.len(), + values.len(), + ))); + } + + Ok(Arc::new(Int64Array::from_iter_values( + values + .iter() + .zip(regex_array.iter()) + .map(|(value, regex)| { + let regex = match regex { + None | Some("") => return Ok(0), + Some(regex) => regex, + }; + + let pattern = compile_and_cache_regex( + regex, + flags_scalar, + &mut regex_cache, + )?; + count_matches(value, &pattern, start_scalar) + }) + .collect::, ArrowError>>()?, + ))) + } + (false, true, false) => { + if values.len() != regex_array.len() { + return Err(ArrowError::ComputeError(format!( + "regex_array must be the same length as values array; got {} and {}", + regex_array.len(), + values.len(), + ))); + } + + let flags_array = flags_array.unwrap(); + if values.len() != flags_array.len() { + return Err(ArrowError::ComputeError(format!( + "flags_array must be the same length as values array; got {} and {}", + flags_array.len(), + values.len(), + ))); + } + + Ok(Arc::new(Int64Array::from_iter_values( + izip!(values.iter(), regex_array.iter(), flags_array.iter()) + .map(|(value, regex, flags)| { + let regex = match regex { + None | Some("") => return Ok(0), + Some(regex) => regex, + }; + + let pattern = + compile_and_cache_regex(regex, flags, &mut regex_cache)?; + + count_matches(value, &pattern, start_scalar) + }) + .collect::, ArrowError>>()?, + ))) + } + (false, false, true) => { + if values.len() != regex_array.len() { + return Err(ArrowError::ComputeError(format!( + "regex_array must be the same length as values array; got {} and {}", + regex_array.len(), + values.len(), + ))); + } + + let start_array = start_array.unwrap(); + if values.len() != start_array.len() { + return Err(ArrowError::ComputeError(format!( + "start_array must be the same length as values array; got {} and {}", + start_array.len(), + values.len(), + ))); + } + + Ok(Arc::new(Int64Array::from_iter_values( + izip!(values.iter(), regex_array.iter(), start_array.iter()) + .map(|(value, regex, start)| { + let regex = match regex { + None | Some("") => return Ok(0), + Some(regex) => regex, + }; + + let pattern = compile_and_cache_regex( + regex, + flags_scalar, + &mut regex_cache, + )?; + count_matches(value, &pattern, start) + }) + .collect::, ArrowError>>()?, + ))) + } + (false, false, false) => { + if values.len() != regex_array.len() { + return Err(ArrowError::ComputeError(format!( + "regex_array must be the same length as values array; got {} and {}", + regex_array.len(), + values.len(), + ))); + } + + let start_array = start_array.unwrap(); + if values.len() != start_array.len() { + return Err(ArrowError::ComputeError(format!( + "start_array must be the same length as values array; got {} and {}", + start_array.len(), + values.len(), + ))); + } + + let flags_array = flags_array.unwrap(); + if values.len() != flags_array.len() { + return Err(ArrowError::ComputeError(format!( + "flags_array must be the same length as values array; got {} and {}", + flags_array.len(), + values.len(), + ))); + } + + Ok(Arc::new(Int64Array::from_iter_values( + izip!( + values.iter(), + regex_array.iter(), + start_array.iter(), + flags_array.iter() + ) + .map(|(value, regex, start, flags)| { + let regex = match regex { + None | Some("") => return Ok(0), + Some(regex) => regex, + }; + + let pattern = + compile_and_cache_regex(regex, flags, &mut regex_cache)?; + count_matches(value, &pattern, start) + }) + .collect::, ArrowError>>()?, + ))) + } + } +} + +fn compile_and_cache_regex( + regex: &str, + flags: Option<&str>, + regex_cache: &mut HashMap, +) -> Result { + match regex_cache.entry(regex.to_string()) { + Entry::Vacant(entry) => { + let compiled = compile_regex(regex, flags)?; + entry.insert(compiled.clone()); + Ok(compiled) + } + Entry::Occupied(entry) => Ok(entry.get().to_owned()), + } +} + +fn compile_regex(regex: &str, flags: Option<&str>) -> Result { + let pattern = match flags { + None | Some("") => regex.to_string(), + Some(flags) => { + if flags.contains("g") { + return Err(ArrowError::ComputeError( + "regexp_count() does not support global flag".to_string(), + )); + } + format!("(?{}){}", flags, regex) + } + }; + + Regex::new(&pattern).map_err(|_| { + ArrowError::ComputeError(format!( + "Regular expression did not compile: {}", + pattern + )) + }) +} + +fn count_matches( + value: Option<&str>, + pattern: &Regex, + start: Option, +) -> Result { + let value = match value { + None | Some("") => return Ok(0), + Some(value) => value, + }; + + if let Some(start) = start { + if start < 1 { + return Err(ArrowError::ComputeError( + "regexp_count() requires start to be 1 based".to_string(), + )); + } + + let find_slice = value.chars().skip(start as usize - 1).collect::(); + let count = pattern.find_iter(find_slice.as_str()).count(); + Ok(count as i64) + } else { + let count = pattern.find_iter(value).count(); + Ok(count as i64) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{GenericStringArray, StringViewArray}; + + #[test] + fn test_regexp_count() { + test_case_sensitive_regexp_count_scalar(); + test_case_sensitive_regexp_count_scalar_start(); + test_case_insensitive_regexp_count_scalar_flags(); + test_case_sensitive_regexp_count_start_scalar_complex(); + + test_case_sensitive_regexp_count_array::>(); + test_case_sensitive_regexp_count_array::>(); + test_case_sensitive_regexp_count_array::(); + + test_case_sensitive_regexp_count_array_start::>(); + test_case_sensitive_regexp_count_array_start::>(); + test_case_sensitive_regexp_count_array_start::(); + + test_case_insensitive_regexp_count_array_flags::>(); + test_case_insensitive_regexp_count_array_flags::>(); + test_case_insensitive_regexp_count_array_flags::(); + + test_case_sensitive_regexp_count_array_complex::>(); + test_case_sensitive_regexp_count_array_complex::>(); + test_case_sensitive_regexp_count_array_complex::(); + } + + fn test_case_sensitive_regexp_count_scalar() { + let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"]; + let regex = "abc"; + let expected: Vec = vec![0, 1, 2, 1, 3]; + + values.iter().enumerate().for_each(|(pos, &v)| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); + let expected = expected.get(pos).cloned(); + + let re = RegexpCountFunc::new().invoke_batch( + &[ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], + 1, + ); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); + + let re = RegexpCountFunc::new().invoke_batch( + &[ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], + 1, + ); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); + + let re = RegexpCountFunc::new().invoke_batch( + &[ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], + 1, + ); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_sensitive_regexp_count_scalar_start() { + let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"]; + let regex = "abc"; + let start = 2; + let expected: Vec = vec![0, 1, 1, 0, 2]; + + values.iter().enumerate().for_each(|(pos, &v)| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); + let start_sv = ScalarValue::Int64(Some(start)); + let expected = expected.get(pos).cloned(); + + let re = RegexpCountFunc::new().invoke_batch( + &[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ], + 1, + ); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); + + let re = RegexpCountFunc::new().invoke_batch( + &[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ], + 1, + ); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); + + let re = RegexpCountFunc::new().invoke_batch( + &[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv), + ], + 1, + ); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_insensitive_regexp_count_scalar_flags() { + let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"]; + let regex = "abc"; + let start = 1; + let flags = "i"; + let expected: Vec = vec![0, 1, 2, 2, 3]; + + values.iter().enumerate().for_each(|(pos, &v)| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); + let start_sv = ScalarValue::Int64(Some(start)); + let flags_sv = ScalarValue::Utf8(Some(flags.to_string())); + let expected = expected.get(pos).cloned(); + + let re = RegexpCountFunc::new().invoke_batch( + &[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ColumnarValue::Scalar(flags_sv.clone()), + ], + 1, + ); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); + let flags_sv = ScalarValue::LargeUtf8(Some(flags.to_string())); + + let re = RegexpCountFunc::new().invoke_batch( + &[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ColumnarValue::Scalar(flags_sv.clone()), + ], + 1, + ); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); + let flags_sv = ScalarValue::Utf8View(Some(flags.to_string())); + + let re = RegexpCountFunc::new().invoke_batch( + &[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv), + ColumnarValue::Scalar(flags_sv.clone()), + ], + 1, + ); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_sensitive_regexp_count_array() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["", "aabca", "abcabc", "abcAbcab", "abcabcAbc"]); + let regex = A::from(vec!["", "abc", "a", "bc", "ab"]); + + let expected = Int64Array::from(vec![0, 1, 2, 2, 2]); + + let re = regexp_count_func(&[Arc::new(values), Arc::new(regex)]).unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_sensitive_regexp_count_array_start() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]); + let regex = A::from(vec!["", "abc", "a", "bc", "ab"]); + let start = Int64Array::from(vec![1, 2, 3, 4, 5]); + + let expected = Int64Array::from(vec![0, 0, 1, 1, 0]); + + let re = regexp_count_func(&[Arc::new(values), Arc::new(regex), Arc::new(start)]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_insensitive_regexp_count_array_flags() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]); + let regex = A::from(vec!["", "abc", "a", "bc", "ab"]); + let start = Int64Array::from(vec![1]); + let flags = A::from(vec!["", "i", "", "", "i"]); + + let expected = Int64Array::from(vec![0, 1, 2, 2, 3]); + + let re = regexp_count_func(&[ + Arc::new(values), + Arc::new(regex), + Arc::new(start), + Arc::new(flags), + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } + + fn test_case_sensitive_regexp_count_start_scalar_complex() { + let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"]; + let regex = ["", "abc", "a", "bc", "ab"]; + let start = 5; + let flags = ["", "i", "", "", "i"]; + let expected: Vec = vec![0, 0, 0, 1, 1]; + + values.iter().enumerate().for_each(|(pos, &v)| { + // utf8 + let v_sv = ScalarValue::Utf8(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8(regex.get(pos).map(|s| s.to_string())); + let start_sv = ScalarValue::Int64(Some(start)); + let flags_sv = ScalarValue::Utf8(flags.get(pos).map(|f| f.to_string())); + let expected = expected.get(pos).cloned(); + + let re = RegexpCountFunc::new().invoke_batch( + &[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ColumnarValue::Scalar(flags_sv.clone()), + ], + 1, + ); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // largeutf8 + let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); + let regex_sv = ScalarValue::LargeUtf8(regex.get(pos).map(|s| s.to_string())); + let flags_sv = ScalarValue::LargeUtf8(flags.get(pos).map(|f| f.to_string())); + + let re = RegexpCountFunc::new().invoke_batch( + &[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv.clone()), + ColumnarValue::Scalar(flags_sv.clone()), + ], + 1, + ); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + + // utf8view + let v_sv = ScalarValue::Utf8View(Some(v.to_string())); + let regex_sv = ScalarValue::Utf8View(regex.get(pos).map(|s| s.to_string())); + let flags_sv = ScalarValue::Utf8View(flags.get(pos).map(|f| f.to_string())); + + let re = RegexpCountFunc::new().invoke_batch( + &[ + ColumnarValue::Scalar(v_sv), + ColumnarValue::Scalar(regex_sv), + ColumnarValue::Scalar(start_sv), + ColumnarValue::Scalar(flags_sv.clone()), + ], + 1, + ); + match re { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { + assert_eq!(v, expected, "regexp_count scalar test failed"); + } + _ => panic!("Unexpected result"), + } + }); + } + + fn test_case_sensitive_regexp_count_array_complex() + where + A: From> + Array + 'static, + { + let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]); + let regex = A::from(vec!["", "abc", "a", "bc", "ab"]); + let start = Int64Array::from(vec![1, 2, 3, 4, 5]); + let flags = A::from(vec!["", "i", "", "", "i"]); + + let expected = Int64Array::from(vec![0, 1, 1, 1, 1]); + + let re = regexp_count_func(&[ + Arc::new(values), + Arc::new(regex), + Arc::new(start), + Arc::new(flags), + ]) + .unwrap(); + assert_eq!(re.as_ref(), &expected); + } +} diff --git a/datafusion/functions/src/regex/regexplike.rs b/datafusion/functions/src/regex/regexplike.rs index 09b96a28c107..13de7888aa5f 100644 --- a/datafusion/functions/src/regex/regexplike.rs +++ b/datafusion/functions/src/regex/regexplike.rs @@ -15,42 +15,95 @@ // specific language governing permissions and limitations // under the License. -//! Regx expressions -use arrow::array::{Array, ArrayRef, OffsetSizeTrait}; +//! Regex expressions + +use arrow::array::{Array, ArrayRef, AsArray, GenericStringArray}; use arrow::compute::kernels::regexp; use arrow::datatypes::DataType; +use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View}; use datafusion_common::exec_err; use datafusion_common::ScalarValue; use datafusion_common::{arrow_datafusion_err, plan_err}; -use datafusion_common::{ - cast::as_generic_string_array, internal_err, DataFusionError, Result, -}; -use datafusion_expr::ColumnarValue; -use datafusion_expr::TypeSignature::*; +use datafusion_common::{internal_err, DataFusionError, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_REGEX; +use datafusion_expr::{ColumnarValue, Documentation, TypeSignature}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; + use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; #[derive(Debug)] pub struct RegexpLikeFunc { signature: Signature, } + impl Default for RegexpLikeFunc { fn default() -> Self { Self::new() } } +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_regexp_like_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_REGEX) + .with_description("Returns true if a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has at least one match in a string, false otherwise.") + .with_syntax_example("regexp_like(str, regexp[, flags])") + .with_sql_example(r#"```sql +select regexp_like('Köln', '[a-zA-Z]ö[a-zA-Z]{2}'); ++--------------------------------------------------------+ +| regexp_like(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) | ++--------------------------------------------------------+ +| true | ++--------------------------------------------------------+ +SELECT regexp_like('aBc', '(b|d)', 'i'); ++--------------------------------------------------+ +| regexp_like(Utf8("aBc"),Utf8("(b|d)"),Utf8("i")) | ++--------------------------------------------------+ +| true | ++--------------------------------------------------+ +``` +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) +"#) + .with_standard_argument("str", Some("String")) + .with_standard_argument("regexp", Some("Regular")) + .with_argument("flags", + r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*?"#) + .build() + .unwrap() + }) +} + impl RegexpLikeFunc { pub fn new() -> Self { - use DataType::*; Self { signature: Signature::one_of( vec![ - Exact(vec![Utf8, Utf8]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![Utf8, Utf8, Utf8]), - Exact(vec![LargeUtf8, Utf8, Utf8]), + TypeSignature::Exact(vec![Utf8View, Utf8]), + TypeSignature::Exact(vec![Utf8View, Utf8View]), + TypeSignature::Exact(vec![Utf8View, LargeUtf8]), + TypeSignature::Exact(vec![Utf8, Utf8]), + TypeSignature::Exact(vec![Utf8, Utf8View]), + TypeSignature::Exact(vec![Utf8, LargeUtf8]), + TypeSignature::Exact(vec![LargeUtf8, Utf8]), + TypeSignature::Exact(vec![LargeUtf8, Utf8View]), + TypeSignature::Exact(vec![LargeUtf8, LargeUtf8]), + TypeSignature::Exact(vec![Utf8View, Utf8, Utf8]), + TypeSignature::Exact(vec![Utf8View, Utf8View, Utf8]), + TypeSignature::Exact(vec![Utf8View, LargeUtf8, Utf8]), + TypeSignature::Exact(vec![Utf8, Utf8, Utf8]), + TypeSignature::Exact(vec![Utf8, Utf8View, Utf8]), + TypeSignature::Exact(vec![Utf8, LargeUtf8, Utf8]), + TypeSignature::Exact(vec![LargeUtf8, Utf8, Utf8]), + TypeSignature::Exact(vec![LargeUtf8, Utf8View, Utf8]), + TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Utf8]), ], Volatility::Immutable, ), @@ -75,15 +128,13 @@ impl ScalarUDFImpl for RegexpLikeFunc { use DataType::*; Ok(match &arg_types[0] { - LargeUtf8 | Utf8 => Boolean, Null => Null, - other => { - return plan_err!( - "The regexp_like function can only accept strings. Got {other}" - ); - } + // Type coercion is done by DataFusion based on signature, so if we + // get here, the first argument is always a string + _ => Boolean, }) } + fn invoke(&self, args: &[ColumnarValue]) -> Result { let len = args .iter() @@ -99,7 +150,7 @@ impl ScalarUDFImpl for RegexpLikeFunc { .map(|arg| arg.clone().into_array(inferred_length)) .collect::>>()?; - let result = regexp_like_func(&args); + let result = regexp_like(&args); if is_scalar { // If all inputs are scalar, keeps output as scalar let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); @@ -108,16 +159,12 @@ impl ScalarUDFImpl for RegexpLikeFunc { result.map(ColumnarValue::Array) } } -} -fn regexp_like_func(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - DataType::Utf8 => regexp_like::(args), - DataType::LargeUtf8 => regexp_like::(args), - other => { - internal_err!("Unsupported data type {other:?} for function regexp_like") - } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_regexp_like_doc()) } } + /// Tests a string using a regular expression returning true if at /// least one match, false otherwise. /// @@ -160,46 +207,114 @@ fn regexp_like_func(args: &[ArrayRef]) -> Result { /// # Ok(()) /// # } /// ``` -pub fn regexp_like(args: &[ArrayRef]) -> Result { +pub fn regexp_like(args: &[ArrayRef]) -> Result { match args.len() { - 2 => { - let values = as_generic_string_array::(&args[0])?; - let regex = as_generic_string_array::(&args[1])?; - let array = regexp::regexp_is_match_utf8(values, regex, None) - .map_err(|e| arrow_datafusion_err!(e))?; - - Ok(Arc::new(array) as ArrayRef) - } + 2 => handle_regexp_like(&args[0], &args[1], None), 3 => { - let values = as_generic_string_array::(&args[0])?; - let regex = as_generic_string_array::(&args[1])?; - let flags = as_generic_string_array::(&args[2])?; + let flags = args[2].as_string::(); if flags.iter().any(|s| s == Some("g")) { return plan_err!("regexp_like() does not support the \"global\" option"); } - let array = regexp::regexp_is_match_utf8(values, regex, Some(flags)) - .map_err(|e| arrow_datafusion_err!(e))?; - - Ok(Arc::new(array) as ArrayRef) - } + handle_regexp_like(&args[0], &args[1], Some(flags)) + }, other => exec_err!( - "regexp_like was called with {other} arguments. It requires at least 2 and at most 3." + "`regexp_like` was called with {other} arguments. It requires at least 2 and at most 3." ), } } + +fn handle_regexp_like( + values: &ArrayRef, + patterns: &ArrayRef, + flags: Option<&GenericStringArray>, +) -> Result { + let array = match (values.data_type(), patterns.data_type()) { + (Utf8View, Utf8) => { + let value = values.as_string_view(); + let pattern = patterns.as_string::(); + + regexp::regexp_is_match(value, pattern, flags) + .map_err(|e| arrow_datafusion_err!(e))? + } + (Utf8View, Utf8View) => { + let value = values.as_string_view(); + let pattern = patterns.as_string_view(); + + regexp::regexp_is_match(value, pattern, flags) + .map_err(|e| arrow_datafusion_err!(e))? + } + (Utf8View, LargeUtf8) => { + let value = values.as_string_view(); + let pattern = patterns.as_string::(); + + regexp::regexp_is_match(value, pattern, flags) + .map_err(|e| arrow_datafusion_err!(e))? + } + (Utf8, Utf8) => { + let value = values.as_string::(); + let pattern = patterns.as_string::(); + + regexp::regexp_is_match(value, pattern, flags) + .map_err(|e| arrow_datafusion_err!(e))? + } + (Utf8, Utf8View) => { + let value = values.as_string::(); + let pattern = patterns.as_string_view(); + + regexp::regexp_is_match(value, pattern, flags) + .map_err(|e| arrow_datafusion_err!(e))? + } + (Utf8, LargeUtf8) => { + let value = values.as_string_view(); + let pattern = patterns.as_string::(); + + regexp::regexp_is_match(value, pattern, flags) + .map_err(|e| arrow_datafusion_err!(e))? + } + (LargeUtf8, Utf8) => { + let value = values.as_string::(); + let pattern = patterns.as_string::(); + + regexp::regexp_is_match(value, pattern, flags) + .map_err(|e| arrow_datafusion_err!(e))? + } + (LargeUtf8, Utf8View) => { + let value = values.as_string::(); + let pattern = patterns.as_string_view(); + + regexp::regexp_is_match(value, pattern, flags) + .map_err(|e| arrow_datafusion_err!(e))? + } + (LargeUtf8, LargeUtf8) => { + let value = values.as_string::(); + let pattern = patterns.as_string::(); + + regexp::regexp_is_match(value, pattern, flags) + .map_err(|e| arrow_datafusion_err!(e))? + } + other => { + return internal_err!( + "Unsupported data type {other:?} for function `regexp_like`" + ) + } + }; + + Ok(Arc::new(array) as ArrayRef) +} + #[cfg(test)] mod tests { use std::sync::Arc; - use arrow::array::BooleanBuilder; use arrow::array::StringArray; + use arrow::array::{BooleanBuilder, StringViewArray}; use crate::regex::regexplike::regexp_like; #[test] - fn test_case_sensitive_regexp_like() { + fn test_case_sensitive_regexp_like_utf8() { let values = StringArray::from(vec!["abc"; 5]); let patterns = @@ -213,13 +328,33 @@ mod tests { expected_builder.append_value(false); let expected = expected_builder.finish(); - let re = regexp_like::(&[Arc::new(values), Arc::new(patterns)]).unwrap(); + let re = regexp_like(&[Arc::new(values), Arc::new(patterns)]).unwrap(); assert_eq!(re.as_ref(), &expected); } #[test] - fn test_case_insensitive_regexp_like() { + fn test_case_sensitive_regexp_like_utf8view() { + let values = StringViewArray::from(vec!["abc"; 5]); + + let patterns = + StringArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]); + + let mut expected_builder: BooleanBuilder = BooleanBuilder::new(); + expected_builder.append_value(true); + expected_builder.append_value(false); + expected_builder.append_value(true); + expected_builder.append_value(false); + expected_builder.append_value(false); + let expected = expected_builder.finish(); + + let re = regexp_like(&[Arc::new(values), Arc::new(patterns)]).unwrap(); + + assert_eq!(re.as_ref(), &expected); + } + + #[test] + fn test_case_insensitive_regexp_like_utf8() { let values = StringArray::from(vec!["abc"; 5]); let patterns = StringArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]); @@ -233,9 +368,29 @@ mod tests { expected_builder.append_value(false); let expected = expected_builder.finish(); - let re = - regexp_like::(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)]) - .unwrap(); + let re = regexp_like(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)]) + .unwrap(); + + assert_eq!(re.as_ref(), &expected); + } + + #[test] + fn test_case_insensitive_regexp_like_utf8view() { + let values = StringViewArray::from(vec!["abc"; 5]); + let patterns = + StringViewArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]); + let flags = StringArray::from(vec!["i"; 5]); + + let mut expected_builder: BooleanBuilder = BooleanBuilder::new(); + expected_builder.append_value(true); + expected_builder.append_value(true); + expected_builder.append_value(true); + expected_builder.append_value(true); + expected_builder.append_value(false); + let expected = expected_builder.finish(); + + let re = regexp_like(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)]) + .unwrap(); assert_eq!(re.as_ref(), &expected); } @@ -247,7 +402,7 @@ mod tests { let flags = StringArray::from(vec!["g"]); let re_err = - regexp_like::(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)]) + regexp_like(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)]) .expect_err("unsupported flag should have failed"); assert_eq!( diff --git a/datafusion/functions/src/regex/regexpmatch.rs b/datafusion/functions/src/regex/regexpmatch.rs index 73228e608143..019666bd7b2d 100644 --- a/datafusion/functions/src/regex/regexpmatch.rs +++ b/datafusion/functions/src/regex/regexpmatch.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! Regx expressions +//! Regex expressions use arrow::array::{Array, ArrayRef, OffsetSizeTrait}; use arrow::compute::kernels::regexp; use arrow::datatypes::DataType; @@ -26,16 +26,17 @@ use datafusion_common::{arrow_datafusion_err, plan_err}; use datafusion_common::{ cast::as_generic_string_array, internal_err, DataFusionError, Result, }; -use datafusion_expr::ColumnarValue; -use datafusion_expr::TypeSignature::*; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_REGEX; +use datafusion_expr::{ColumnarValue, Documentation, TypeSignature}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; #[derive(Debug)] pub struct RegexpMatchFunc { signature: Signature, } + impl Default for RegexpMatchFunc { fn default() -> Self { Self::new() @@ -48,10 +49,14 @@ impl RegexpMatchFunc { Self { signature: Signature::one_of( vec![ - Exact(vec![Utf8, Utf8]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![Utf8, Utf8, Utf8]), - Exact(vec![LargeUtf8, Utf8, Utf8]), + // Planner attempts coercion to the target type starting with the most preferred candidate. + // For example, given input `(Utf8View, Utf8)`, it first tries coercing to `(Utf8, Utf8)`. + // If that fails, it proceeds to `(LargeUtf8, Utf8)`. + // TODO: Native support Utf8View for regexp_match. + TypeSignature::Exact(vec![Utf8, Utf8]), + TypeSignature::Exact(vec![LargeUtf8, LargeUtf8]), + TypeSignature::Exact(vec![Utf8, Utf8, Utf8]), + TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, LargeUtf8]), ], Volatility::Immutable, ), @@ -73,17 +78,9 @@ impl ScalarUDFImpl for RegexpMatchFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match &arg_types[0] { - LargeUtf8 => List(Arc::new(Field::new("item", LargeUtf8, true))), - Utf8 => List(Arc::new(Field::new("item", Utf8, true))), - Null => Null, - other => { - return plan_err!( - "The regexp_match function can only accept strings. Got {other}" - ); - } + DataType::Null => DataType::Null, + other => DataType::List(Arc::new(Field::new("item", other.clone(), true))), }) } fn invoke(&self, args: &[ColumnarValue]) -> Result { @@ -110,7 +107,51 @@ impl ScalarUDFImpl for RegexpMatchFunc { result.map(ColumnarValue::Array) } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_regexp_match_doc()) + } } + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_regexp_match_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_REGEX) + .with_description("Returns the first [regular expression](https://docs.rs/regex/latest/regex/#syntax) matches in a string.") + .with_syntax_example("regexp_match(str, regexp[, flags])") + .with_sql_example(r#"```sql + > select regexp_match('Köln', '[a-zA-Z]ö[a-zA-Z]{2}'); + +---------------------------------------------------------+ + | regexp_match(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) | + +---------------------------------------------------------+ + | [Köln] | + +---------------------------------------------------------+ + SELECT regexp_match('aBc', '(b|d)', 'i'); + +---------------------------------------------------+ + | regexp_match(Utf8("aBc"),Utf8("(b|d)"),Utf8("i")) | + +---------------------------------------------------+ + | [B] | + +---------------------------------------------------+ +``` +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) +"#) + .with_standard_argument("str", Some("String")) + .with_argument("regexp","Regular expression to match against. + Can be a constant, column, or function.") + .with_argument("flags", + r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*?"#) + .build() + .unwrap() + }) +} + fn regexp_match_func(args: &[ArrayRef]) -> Result { match args[0].data_type() { DataType::Utf8 => regexp_match::(args), @@ -134,7 +175,7 @@ pub fn regexp_match(args: &[ArrayRef]) -> Result { let flags = as_generic_string_array::(&args[2])?; if flags.iter().any(|s| s == Some("g")) { - return plan_err!("regexp_match() does not support the \"global\" option") + return plan_err!("regexp_match() does not support the \"global\" option"); } regexp::regexp_match(values, regex, Some(flags)) diff --git a/datafusion/functions/src/regex/regexpreplace.rs b/datafusion/functions/src/regex/regexpreplace.rs index 4e21883c9752..4d8e5e5fe3e3 100644 --- a/datafusion/functions/src/regex/regexpreplace.rs +++ b/datafusion/functions/src/regex/regexpreplace.rs @@ -16,27 +16,31 @@ // under the License. //! Regx expressions -use arrow::array::new_null_array; use arrow::array::ArrayDataBuilder; use arrow::array::BufferBuilder; use arrow::array::GenericStringArray; +use arrow::array::StringViewBuilder; +use arrow::array::{new_null_array, ArrayIter, AsArray}; use arrow::array::{Array, ArrayRef, OffsetSizeTrait}; +use arrow::array::{ArrayAccessor, StringViewArray}; use arrow::datatypes::DataType; +use datafusion_common::cast::as_string_view_array; use datafusion_common::exec_err; use datafusion_common::plan_err; use datafusion_common::ScalarValue; use datafusion_common::{ cast::as_generic_string_array, internal_err, DataFusionError, Result, }; +use datafusion_expr::function::Hint; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_REGEX; use datafusion_expr::ColumnarValue; -use datafusion_expr::TypeSignature::*; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use datafusion_physical_expr::functions::Hint; +use datafusion_expr::TypeSignature; +use datafusion_expr::{Documentation, ScalarUDFImpl, Signature, Volatility}; use regex::Regex; use std::any::Any; use std::collections::HashMap; -use std::sync::Arc; -use std::sync::OnceLock; +use std::sync::{Arc, OnceLock}; + #[derive(Debug)] pub struct RegexpReplaceFunc { signature: Signature, @@ -53,8 +57,10 @@ impl RegexpReplaceFunc { Self { signature: Signature::one_of( vec![ - Exact(vec![Utf8, Utf8, Utf8]), - Exact(vec![Utf8, Utf8, Utf8, Utf8]), + TypeSignature::Exact(vec![Utf8, Utf8, Utf8]), + TypeSignature::Exact(vec![Utf8View, Utf8, Utf8]), + TypeSignature::Exact(vec![Utf8, Utf8, Utf8, Utf8]), + TypeSignature::Exact(vec![Utf8View, Utf8, Utf8, Utf8]), ], Volatility::Immutable, ), @@ -80,6 +86,7 @@ impl ScalarUDFImpl for RegexpReplaceFunc { Ok(match &arg_types[0] { LargeUtf8 | LargeBinary => LargeUtf8, Utf8 | Binary => Utf8, + Utf8View | BinaryView => Utf8View, Null => Null, Dictionary(_, t) => match **t { LargeUtf8 | LargeBinary => LargeUtf8, @@ -117,16 +124,64 @@ impl ScalarUDFImpl for RegexpReplaceFunc { result.map(ColumnarValue::Array) } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_regexp_replace_doc()) + } } + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_regexp_replace_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_REGEX) + .with_description("Replaces substrings in a string that match a [regular expression](https://docs.rs/regex/latest/regex/#syntax).") + .with_syntax_example("regexp_replace(str, regexp, replacement[, flags])") + .with_sql_example(r#"```sql +> select regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', 'g'); ++------------------------------------------------------------------------+ +| regexp_replace(Utf8("foobarbaz"),Utf8("b(..)"),Utf8("X\1Y"),Utf8("g")) | ++------------------------------------------------------------------------+ +| fooXarYXazY | ++------------------------------------------------------------------------+ +SELECT regexp_replace('aBc', '(b|d)', 'Ab\\1a', 'i'); ++-------------------------------------------------------------------+ +| regexp_replace(Utf8("aBc"),Utf8("(b|d)"),Utf8("Ab\1a"),Utf8("i")) | ++-------------------------------------------------------------------+ +| aAbBac | ++-------------------------------------------------------------------+ +``` +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) +"#) + .with_standard_argument("str", Some("String")) + .with_argument("regexp","Regular expression to match against. + Can be a constant, column, or function.") + .with_standard_argument("replacement", Some("Replacement string")) + .with_argument("flags", + r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: +- **g**: (global) Search globally and don't return after the first match +- **i**: case-insensitive: letters match both upper and lower case +- **m**: multi-line mode: ^ and $ match begin/end of line +- **s**: allow . to match \n +- **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used +- **U**: swap the meaning of x* and x*?"#) + .build() + .unwrap() +}) +} + fn regexp_replace_func(args: &[ColumnarValue]) -> Result { match args[0].data_type() { DataType::Utf8 => specialize_regexp_replace::(args), DataType::LargeUtf8 => specialize_regexp_replace::(args), + DataType::Utf8View => specialize_regexp_replace::(args), other => { internal_err!("Unsupported data type {other:?} for function regexp_replace") } } } + /// replace POSIX capture groups (like \1) with Rust Regex group (like ${1}) /// used by regexp_replace fn regex_replace_posix_groups(replacement: &str) -> String { @@ -179,125 +234,170 @@ fn regex_replace_posix_groups(replacement: &str) -> String { /// # Ok(()) /// # } /// ``` -pub fn regexp_replace(args: &[ArrayRef]) -> Result { +pub fn regexp_replace<'a, T: OffsetSizeTrait, V, B>( + string_array: V, + pattern_array: B, + replacement_array: B, + flags: Option<&ArrayRef>, +) -> Result +where + V: ArrayAccessor, + B: ArrayAccessor, +{ // Default implementation for regexp_replace, assumes all args are arrays // and args is a sequence of 3 or 4 elements. // creating Regex is expensive so create hashmap for memoization let mut patterns: HashMap = HashMap::new(); - match args.len() { - 3 => { - let string_array = as_generic_string_array::(&args[0])?; - let pattern_array = as_generic_string_array::(&args[1])?; - let replacement_array = as_generic_string_array::(&args[2])?; - - let result = string_array - .iter() - .zip(pattern_array.iter()) - .zip(replacement_array.iter()) - .map(|((string, pattern), replacement)| match (string, pattern, replacement) { - (Some(string), Some(pattern), Some(replacement)) => { - let replacement = regex_replace_posix_groups(replacement); - - // if patterns hashmap already has regexp then use else create and return - let re = match patterns.get(pattern) { - Some(re) => Ok(re), - None => { - match Regex::new(pattern) { - Ok(re) => { - patterns.insert(pattern.to_string(), re); - Ok(patterns.get(pattern).unwrap()) + let datatype = string_array.data_type().to_owned(); + + let string_array_iter = ArrayIter::new(string_array); + let pattern_array_iter = ArrayIter::new(pattern_array); + let replacement_array_iter = ArrayIter::new(replacement_array); + + match flags { + None => { + let result_iter = string_array_iter + .zip(pattern_array_iter) + .zip(replacement_array_iter) + .map(|((string, pattern), replacement)| { + match (string, pattern, replacement) { + (Some(string), Some(pattern), Some(replacement)) => { + let replacement = regex_replace_posix_groups(replacement); + // if patterns hashmap already has regexp then use else create and return + let re = match patterns.get(pattern) { + Some(re) => Ok(re), + None => match Regex::new(pattern) { + Ok(re) => { + patterns.insert(pattern.to_string(), re); + Ok(patterns.get(pattern).unwrap()) + } + Err(err) => { + Err(DataFusionError::External(Box::new(err))) + } }, - Err(err) => Err(DataFusionError::External(Box::new(err))), - } - } - }; + }; - Some(re.map(|re| re.replace(string, replacement.as_str()))).transpose() + Some(re.map(|re| re.replace(string, replacement.as_str()))) + .transpose() + } + _ => Ok(None), + } + }); + + match datatype { + DataType::Utf8 | DataType::LargeUtf8 => { + let result = + result_iter.collect::>>()?; + Ok(Arc::new(result) as ArrayRef) } - _ => Ok(None) - }) - .collect::>>()?; - - Ok(Arc::new(result) as ArrayRef) + DataType::Utf8View => { + let result = result_iter.collect::>()?; + Ok(Arc::new(result) as ArrayRef) + } + other => { + exec_err!( + "Unsupported data type {other:?} for function regex_replace" + ) + } + } } - 4 => { - let string_array = as_generic_string_array::(&args[0])?; - let pattern_array = as_generic_string_array::(&args[1])?; - let replacement_array = as_generic_string_array::(&args[2])?; - let flags_array = as_generic_string_array::(&args[3])?; - - let result = string_array - .iter() - .zip(pattern_array.iter()) - .zip(replacement_array.iter()) - .zip(flags_array.iter()) - .map(|(((string, pattern), replacement), flags)| match (string, pattern, replacement, flags) { - (Some(string), Some(pattern), Some(replacement), Some(flags)) => { - let replacement = regex_replace_posix_groups(replacement); - - // format flags into rust pattern - let (pattern, replace_all) = if flags == "g" { - (pattern.to_string(), true) - } else if flags.contains('g') { - (format!("(?{}){}", flags.to_string().replace('g', ""), pattern), true) - } else { - (format!("(?{flags}){pattern}"), false) - }; - - // if patterns hashmap already has regexp then use else create and return - let re = match patterns.get(&pattern) { - Some(re) => Ok(re), - None => { - match Regex::new(pattern.as_str()) { - Ok(re) => { - patterns.insert(pattern.clone(), re); - Ok(patterns.get(&pattern).unwrap()) + Some(flags) => { + let flags_array = as_generic_string_array::(flags)?; + + let result_iter = string_array_iter + .zip(pattern_array_iter) + .zip(replacement_array_iter) + .zip(flags_array.iter()) + .map(|(((string, pattern), replacement), flags)| { + match (string, pattern, replacement, flags) { + (Some(string), Some(pattern), Some(replacement), Some(flags)) => { + let replacement = regex_replace_posix_groups(replacement); + + // format flags into rust pattern + let (pattern, replace_all) = if flags == "g" { + (pattern.to_string(), true) + } else if flags.contains('g') { + ( + format!( + "(?{}){}", + flags.to_string().replace('g', ""), + pattern + ), + true, + ) + } else { + (format!("(?{flags}){pattern}"), false) + }; + + // if patterns hashmap already has regexp then use else create and return + let re = match patterns.get(&pattern) { + Some(re) => Ok(re), + None => match Regex::new(pattern.as_str()) { + Ok(re) => { + patterns.insert(pattern.clone(), re); + Ok(patterns.get(&pattern).unwrap()) + } + Err(err) => { + Err(DataFusionError::External(Box::new(err))) + } }, - Err(err) => Err(DataFusionError::External(Box::new(err))), - } - } - }; - - Some(re.map(|re| { - if replace_all { - re.replace_all(string, replacement.as_str()) - } else { - re.replace(string, replacement.as_str()) + }; + + Some(re.map(|re| { + if replace_all { + re.replace_all(string, replacement.as_str()) + } else { + re.replace(string, replacement.as_str()) + } + })) + .transpose() } - })).transpose() + _ => Ok(None), + } + }); + + match datatype { + DataType::Utf8 | DataType::LargeUtf8 => { + let result = + result_iter.collect::>>()?; + Ok(Arc::new(result) as ArrayRef) } - _ => Ok(None) - }) - .collect::>>()?; - - Ok(Arc::new(result) as ArrayRef) + DataType::Utf8View => { + let result = result_iter.collect::>()?; + Ok(Arc::new(result) as ArrayRef) + } + other => { + exec_err!( + "Unsupported data type {other:?} for function regex_replace" + ) + } + } } - other => exec_err!( - "regexp_replace was called with {other} arguments. It requires at least 3 and at most 4." - ), } } -fn _regexp_replace_early_abort( - input_array: &GenericStringArray, +fn _regexp_replace_early_abort( + input_array: T, + sz: usize, ) -> Result { // Mimicking the existing behavior of regexp_replace, if any of the scalar arguments - // are actually null, then the result will be an array of the same size but with nulls. + // are actually null, then the result will be an array of the same size as the first argument with all nulls. // // Also acts like an early abort mechanism when the input array is empty. - Ok(new_null_array(input_array.data_type(), input_array.len())) + Ok(new_null_array(input_array.data_type(), sz)) } + /// Get the first argument from the given string array. /// /// Note: If the array is empty or the first argument is null, /// then calls the given early abort function. macro_rules! fetch_string_arg { - ($ARG:expr, $NAME:expr, $T:ident, $EARLY_ABORT:ident) => {{ - let array = as_generic_string_array::($ARG)?; + ($ARG:expr, $NAME:expr, $T:ident, $EARLY_ABORT:ident, $ARRAY_SIZE:expr) => {{ + let array = as_generic_string_array::<$T>($ARG)?; if array.len() == 0 || array.is_null(0) { - return $EARLY_ABORT(array); + return $EARLY_ABORT(array, $ARRAY_SIZE); } else { array.value(0) } @@ -312,13 +412,24 @@ macro_rules! fetch_string_arg { fn _regexp_replace_static_pattern_replace( args: &[ArrayRef], ) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - let pattern = fetch_string_arg!(&args[1], "pattern", T, _regexp_replace_early_abort); - let replacement = - fetch_string_arg!(&args[2], "replacement", T, _regexp_replace_early_abort); + let array_size = args[0].len(); + let pattern = fetch_string_arg!( + &args[1], + "pattern", + i32, + _regexp_replace_early_abort, + array_size + ); + let replacement = fetch_string_arg!( + &args[2], + "replacement", + i32, + _regexp_replace_early_abort, + array_size + ); let flags = match args.len() { 3 => None, - 4 => Some(fetch_string_arg!(&args[3], "flags", T, _regexp_replace_early_abort)), + 4 => Some(fetch_string_arg!(&args[3], "flags", i32, _regexp_replace_early_abort, array_size)), other => { return exec_err!( "regexp_replace was called with {other} arguments. It requires at least 3 and at most 4." @@ -345,32 +456,60 @@ fn _regexp_replace_static_pattern_replace( // with rust ones. let replacement = regex_replace_posix_groups(replacement); - // We are going to create the underlying string buffer from its parts - // to be able to re-use the existing null buffer for sparse arrays. - let mut vals = BufferBuilder::::new({ - let offsets = string_array.value_offsets(); - (offsets[string_array.len()] - offsets[0]) - .to_usize() - .unwrap() - }); - let mut new_offsets = BufferBuilder::::new(string_array.len() + 1); - new_offsets.append(T::zero()); - - string_array.iter().for_each(|val| { - if let Some(val) = val { - let result = re.replacen(val, limit, replacement.as_str()); - vals.append_slice(result.as_bytes()); + let string_array_type = args[0].data_type(); + match string_array_type { + DataType::Utf8 | DataType::LargeUtf8 => { + let string_array = as_generic_string_array::(&args[0])?; + + // We are going to create the underlying string buffer from its parts + // to be able to re-use the existing null buffer for sparse arrays. + let mut vals = BufferBuilder::::new({ + let offsets = string_array.value_offsets(); + (offsets[string_array.len()] - offsets[0]) + .to_usize() + .unwrap() + }); + let mut new_offsets = BufferBuilder::::new(string_array.len() + 1); + new_offsets.append(T::zero()); + + string_array.iter().for_each(|val| { + if let Some(val) = val { + let result = re.replacen(val, limit, replacement.as_str()); + vals.append_slice(result.as_bytes()); + } + new_offsets.append(T::from_usize(vals.len()).unwrap()); + }); + + let data = ArrayDataBuilder::new(GenericStringArray::::DATA_TYPE) + .len(string_array.len()) + .nulls(string_array.nulls().cloned()) + .buffers(vec![new_offsets.finish(), vals.finish()]) + .build()?; + let result_array = GenericStringArray::::from(data); + Ok(Arc::new(result_array) as ArrayRef) + } + DataType::Utf8View => { + let string_view_array = as_string_view_array(&args[0])?; + + let mut builder = StringViewBuilder::with_capacity(string_view_array.len()); + + for val in string_view_array.iter() { + if let Some(val) = val { + let result = re.replacen(val, limit, replacement.as_str()); + builder.append_value(result); + } else { + builder.append_null(); + } + } + + let result = builder.finish(); + Ok(Arc::new(result) as ArrayRef) } - new_offsets.append(T::from_usize(vals.len()).unwrap()); - }); - - let data = ArrayDataBuilder::new(GenericStringArray::::DATA_TYPE) - .len(string_array.len()) - .nulls(string_array.nulls().cloned()) - .buffers(vec![new_offsets.finish(), vals.finish()]) - .build()?; - let result_array = GenericStringArray::::from(data); - Ok(Arc::new(result_array) as ArrayRef) + _ => unreachable!( + "Invalid data type for regexp_replace: {}", + string_array_type + ), + } } /// Determine which implementation of the regexp_replace to use based @@ -446,7 +585,47 @@ pub fn specialize_regexp_replace( .iter() .map(|arg| arg.clone().into_array(inferred_length)) .collect::>>()?; - regexp_replace::(&args) + + match args[0].data_type() { + DataType::Utf8View => { + let string_array = args[0].as_string_view(); + let pattern_array = args[1].as_string::(); + let replacement_array = args[2].as_string::(); + regexp_replace::( + string_array, + pattern_array, + replacement_array, + args.get(3), + ) + } + DataType::Utf8 => { + let string_array = args[0].as_string::(); + let pattern_array = args[1].as_string::(); + let replacement_array = args[2].as_string::(); + regexp_replace::( + string_array, + pattern_array, + replacement_array, + args.get(3), + ) + } + DataType::LargeUtf8 => { + let string_array = args[0].as_string::(); + let pattern_array = args[1].as_string::(); + let replacement_array = args[2].as_string::(); + regexp_replace::( + string_array, + pattern_array, + replacement_array, + args.get(3), + ) + } + other => { + exec_err!( + "Unsupported data type {other:?} for function regex_replace" + ) + } + } } } } @@ -456,43 +635,91 @@ mod tests { use super::*; - #[test] - fn test_static_pattern_regexp_replace() { - let values = StringArray::from(vec!["abc"; 5]); - let patterns = StringArray::from(vec!["b"; 5]); - let replacements = StringArray::from(vec!["foo"; 5]); - let expected = StringArray::from(vec!["afooc"; 5]); - - let re = _regexp_replace_static_pattern_replace::(&[ - Arc::new(values), - Arc::new(patterns), - Arc::new(replacements), - ]) - .unwrap(); - - assert_eq!(re.as_ref(), &expected); + macro_rules! static_pattern_regexp_replace { + ($name:ident, $T:ty, $O:ty) => { + #[test] + fn $name() { + let values = vec!["abc", "acd", "abcd1234567890123", "123456789012abc"]; + let patterns = vec!["b"; 4]; + let replacement = vec!["foo"; 4]; + let expected = + vec!["afooc", "acd", "afoocd1234567890123", "123456789012afooc"]; + + let values = <$T>::from(values); + let patterns = StringArray::from(patterns); + let replacements = StringArray::from(replacement); + let expected = <$T>::from(expected); + + let re = _regexp_replace_static_pattern_replace::<$O>(&[ + Arc::new(values), + Arc::new(patterns), + Arc::new(replacements), + ]) + .unwrap(); + + assert_eq!(re.as_ref(), &expected); + } + }; } - #[test] - fn test_static_pattern_regexp_replace_with_flags() { - let values = StringArray::from(vec!["abc", "ABC", "aBc", "AbC", "aBC"]); - let patterns = StringArray::from(vec!["b"; 5]); - let replacements = StringArray::from(vec!["foo"; 5]); - let flags = StringArray::from(vec!["i"; 5]); - let expected = - StringArray::from(vec!["afooc", "AfooC", "afooc", "AfooC", "afooC"]); - - let re = _regexp_replace_static_pattern_replace::(&[ - Arc::new(values), - Arc::new(patterns), - Arc::new(replacements), - Arc::new(flags), - ]) - .unwrap(); - - assert_eq!(re.as_ref(), &expected); + static_pattern_regexp_replace!(string_array, StringArray, i32); + static_pattern_regexp_replace!(string_view_array, StringViewArray, i32); + static_pattern_regexp_replace!(large_string_array, LargeStringArray, i64); + + macro_rules! static_pattern_regexp_replace_with_flags { + ($name:ident, $T:ty, $O: ty) => { + #[test] + fn $name() { + let values = vec![ + "abc", + "aBc", + "acd", + "abcd1234567890123", + "aBcd1234567890123", + "123456789012abc", + "123456789012aBc", + ]; + let expected = vec![ + "afooc", + "afooc", + "acd", + "afoocd1234567890123", + "afoocd1234567890123", + "123456789012afooc", + "123456789012afooc", + ]; + + let values = <$T>::from(values); + let patterns = StringArray::from(vec!["b"; 7]); + let replacements = StringArray::from(vec!["foo"; 7]); + let flags = StringArray::from(vec!["i"; 5]); + let expected = <$T>::from(expected); + + let re = _regexp_replace_static_pattern_replace::<$O>(&[ + Arc::new(values), + Arc::new(patterns), + Arc::new(replacements), + Arc::new(flags), + ]) + .unwrap(); + + assert_eq!(re.as_ref(), &expected); + } + }; } + static_pattern_regexp_replace_with_flags!(string_array_with_flags, StringArray, i32); + static_pattern_regexp_replace_with_flags!( + string_view_array_with_flags, + StringViewArray, + i32 + ); + static_pattern_regexp_replace_with_flags!( + large_string_array_with_flags, + LargeStringArray, + i64 + ); + #[test] fn test_static_pattern_regexp_replace_early_abort() { let values = StringArray::from(vec!["abc"; 5]); @@ -549,7 +776,7 @@ mod tests { #[test] fn test_static_pattern_regexp_replace_pattern_error() { let values = StringArray::from(vec!["abc"; 5]); - // Delibaretely using an invalid pattern to see how the single pattern + // Deliberately using an invalid pattern to see how the single pattern // error is propagated on regexp_replace. let patterns = StringArray::from(vec!["["; 5]); let replacements = StringArray::from(vec!["foo"; 5]); diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs index 15a3c2391ac6..b76d70d7e9d2 100644 --- a/datafusion/functions/src/string/ascii.rs +++ b/datafusion/functions/src/string/ascii.rs @@ -16,46 +16,31 @@ // under the License. use crate::utils::make_scalar_function; -use arrow::array::Int32Array; -use arrow::array::{ArrayRef, OffsetSizeTrait}; +use arrow::array::{ArrayAccessor, ArrayIter, ArrayRef, AsArray, Int32Array}; use arrow::datatypes::DataType; -use datafusion_common::{cast::as_generic_string_array, internal_err, Result}; -use datafusion_expr::ColumnarValue; +use arrow::error::ArrowError; +use datafusion_common::{internal_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ColumnarValue, Documentation}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; -use std::sync::Arc; - -/// Returns the numeric code of the first character of the argument. -/// ascii('x') = 120 -pub fn ascii(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - - let result = string_array - .iter() - .map(|string| { - string.map(|string: &str| { - let mut chars = string.chars(); - chars.next().map_or(0, |v| v as i32) - }) - }) - .collect::(); - - Ok(Arc::new(result) as ArrayRef) -} +use std::sync::{Arc, OnceLock}; #[derive(Debug)] pub struct AsciiFunc { signature: Signature, } + +impl Default for AsciiFunc { + fn default() -> Self { + Self::new() + } +} + impl AsciiFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::uniform( - 1, - vec![Utf8, LargeUtf8], - Volatility::Immutable, - ), + signature: Signature::string(1, Volatility::Immutable), } } } @@ -80,12 +65,129 @@ impl ScalarUDFImpl for AsciiFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - match args[0].data_type() { - DataType::Utf8 => make_scalar_function(ascii::, vec![])(args), - DataType::LargeUtf8 => { - return make_scalar_function(ascii::, vec![])(args); - } - _ => internal_err!("Unsupported data type"), + make_scalar_function(ascii, vec![])(args) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_ascii_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_ascii_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description( + "Returns the Unicode character code of the first character in a string.", + ) + .with_syntax_example("ascii(str)") + .with_sql_example( + r#"```sql +> select ascii('abc'); ++--------------------+ +| ascii(Utf8("abc")) | ++--------------------+ +| 97 | ++--------------------+ +> select ascii('🚀'); ++-------------------+ +| ascii(Utf8("🚀")) | ++-------------------+ +| 128640 | ++-------------------+ +```"#, + ) + .with_standard_argument("str", Some("String")) + .with_related_udf("chr") + .build() + .unwrap() + }) +} + +fn calculate_ascii<'a, V>(array: V) -> Result +where + V: ArrayAccessor, +{ + let iter = ArrayIter::new(array); + let result = iter + .map(|string| { + string.map(|s| { + let mut chars = s.chars(); + chars.next().map_or(0, |v| v as i32) + }) + }) + .collect::(); + + Ok(Arc::new(result) as ArrayRef) +} + +/// Returns the numeric code of the first character of the argument. +pub fn ascii(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::Utf8 => { + let string_array = args[0].as_string::(); + Ok(calculate_ascii(string_array)?) + } + DataType::LargeUtf8 => { + let string_array = args[0].as_string::(); + Ok(calculate_ascii(string_array)?) + } + DataType::Utf8View => { + let string_array = args[0].as_string_view(); + Ok(calculate_ascii(string_array)?) } + _ => internal_err!("Unsupported data type"), + } +} + +#[cfg(test)] +mod tests { + use crate::string::ascii::AsciiFunc; + use crate::utils::test::test_function; + use arrow::array::{Array, Int32Array}; + use arrow::datatypes::DataType::Int32; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + macro_rules! test_ascii { + ($INPUT:expr, $EXPECTED:expr) => { + test_function!( + AsciiFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], + $EXPECTED, + i32, + Int32, + Int32Array + ); + + test_function!( + AsciiFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], + $EXPECTED, + i32, + Int32, + Int32Array + ); + + test_function!( + AsciiFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], + $EXPECTED, + i32, + Int32, + Int32Array + ); + }; + } + + #[test] + fn test_functions() -> Result<()> { + test_ascii!(Some(String::from("x")), Ok(Some(120))); + test_ascii!(Some(String::from("a")), Ok(Some(97))); + test_ascii!(Some(String::from("")), Ok(Some(0))); + test_ascii!(None, Ok(None)); + Ok(()) } } diff --git a/datafusion/functions/src/string/bit_length.rs b/datafusion/functions/src/string/bit_length.rs index 17c49216553b..25b56341fcaa 100644 --- a/datafusion/functions/src/string/bit_length.rs +++ b/datafusion/functions/src/string/bit_length.rs @@ -15,31 +15,32 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; - use arrow::compute::kernels::length::bit_length; use arrow::datatypes::DataType; +use std::any::Any; +use std::sync::OnceLock; +use crate::utils::utf8_to_int_type; use datafusion_common::{exec_err, Result, ScalarValue}; -use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::utils::utf8_to_int_type; - #[derive(Debug)] pub struct BitLengthFunc { signature: Signature, } +impl Default for BitLengthFunc { + fn default() -> Self { + Self::new() + } +} + impl BitLengthFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::uniform( - 1, - vec![Utf8, LargeUtf8], - Volatility::Immutable, - ), + signature: Signature::string(1, Volatility::Immutable), } } } @@ -82,4 +83,34 @@ impl ScalarUDFImpl for BitLengthFunc { }, } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_bit_length_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_bit_length_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Returns the bit length of a string.") + .with_syntax_example("bit_length(str)") + .with_sql_example( + r#"```sql +> select bit_length('datafusion'); ++--------------------------------+ +| bit_length(Utf8("datafusion")) | ++--------------------------------+ +| 80 | ++--------------------------------+ +```"#, + ) + .with_standard_argument("str", Some("String")) + .with_related_udf("length") + .with_related_udf("octet_length") + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/string/btrim.rs b/datafusion/functions/src/string/btrim.rs index 971f7bbd4d92..e215b18d9c3c 100644 --- a/datafusion/functions/src/string/btrim.rs +++ b/datafusion/functions/src/string/btrim.rs @@ -15,24 +15,24 @@ // specific language governing permissions and limitations // under the License. +use crate::string::common::*; +use crate::utils::{make_scalar_function, utf8_to_str_type}; use arrow::array::{ArrayRef, OffsetSizeTrait}; -use std::any::Any; - use arrow::datatypes::DataType; - use datafusion_common::{exec_err, Result}; -use datafusion_expr::TypeSignature::*; -use datafusion_expr::{ColumnarValue, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; -use datafusion_physical_expr::functions::Hint; - -use crate::string::common::*; -use crate::utils::{make_scalar_function, utf8_to_str_type}; +use datafusion_expr::function::Hint; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, Volatility, +}; +use std::any::Any; +use std::sync::OnceLock; /// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed. /// btrim('xyxtrimyyx', 'xyz') = 'trim' fn btrim(args: &[ArrayRef]) -> Result { - general_trim::(args, TrimType::Both) + let use_string_view = args[0].data_type() == &DataType::Utf8View; + general_trim::(args, TrimType::Both, use_string_view) } #[derive(Debug)] @@ -41,12 +41,17 @@ pub struct BTrimFunc { aliases: Vec, } +impl Default for BTrimFunc { + fn default() -> Self { + Self::new() + } +} + impl BTrimFunc { pub fn new() -> Self { - use DataType::*; Self { signature: Signature::one_of( - vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])], + vec![TypeSignature::String(2), TypeSignature::String(1)], Volatility::Immutable, ), aliases: vec![String::from("trim")], @@ -68,12 +73,16 @@ impl ScalarUDFImpl for BTrimFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - utf8_to_str_type(&arg_types[0], "btrim") + if arg_types[0] == DataType::Utf8View { + Ok(DataType::Utf8View) + } else { + utf8_to_str_type(&arg_types[0], "btrim") + } } fn invoke(&self, args: &[ColumnarValue]) -> Result { match args[0].data_type() { - DataType::Utf8 => make_scalar_function( + DataType::Utf8 | DataType::Utf8View => make_scalar_function( btrim::, vec![Hint::Pad, Hint::AcceptsSingular], )(args), @@ -81,11 +90,191 @@ impl ScalarUDFImpl for BTrimFunc { btrim::, vec![Hint::Pad, Hint::AcceptsSingular], )(args), - other => exec_err!("Unsupported data type {other:?} for function btrim"), + other => exec_err!( + "Unsupported data type {other:?} for function btrim,\ + expected Utf8, LargeUtf8 or Utf8View." + ), } } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_btrim_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_btrim_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Trims the specified trim string from the start and end of a string. If no trim string is provided, all whitespace is removed from the start and end of the input string.") + .with_syntax_example("btrim(str[, trim_str])") + .with_sql_example(r#"```sql +> select btrim('__datafusion____', '_'); ++-------------------------------------------+ +| btrim(Utf8("__datafusion____"),Utf8("_")) | ++-------------------------------------------+ +| datafusion | ++-------------------------------------------+ +```"#) + .with_standard_argument("str", Some("String")) + .with_argument("trim_str", "String expression to operate on. Can be a constant, column, or function, and any combination of operators. _Default is whitespace characters._") + .with_alternative_syntax("trim(BOTH trim_str FROM str)") + .with_alternative_syntax("trim(trim_str FROM str)") + .with_related_udf("ltrim") + .with_related_udf("rtrim") + .build() + .unwrap() + }) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray, StringViewArray}; + use arrow::datatypes::DataType::{Utf8, Utf8View}; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::string::btrim::BTrimFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() { + // String view cases for checking normal logic + test_function!( + BTrimFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + String::from("alphabet ") + ))),], + Ok(Some("alphabet")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + BTrimFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + String::from(" alphabet ") + ))),], + Ok(Some("alphabet")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + BTrimFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "alphabet" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("t")))), + ], + Ok(Some("alphabe")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + BTrimFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "alphabet" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "alphabe" + )))), + ], + Ok(Some("t")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + BTrimFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "alphabet" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8View(None)), + ], + Ok(None), + &str, + Utf8View, + StringViewArray + ); + // Special string view case for checking unlined output(len > 12) + test_function!( + BTrimFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "xxxalphabetalphabetxxx" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("x")))), + ], + Ok(Some("alphabetalphabet")), + &str, + Utf8View, + StringViewArray + ); + // String cases + test_function!( + BTrimFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("alphabet ") + ))),], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + BTrimFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("alphabet ") + ))),], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + BTrimFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("t")))), + ], + Ok(Some("alphabe")), + &str, + Utf8, + StringArray + ); + test_function!( + BTrimFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabe")))), + ], + Ok(Some("t")), + &str, + Utf8, + StringArray + ); + test_function!( + BTrimFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + } } diff --git a/datafusion/functions/src/string/chr.rs b/datafusion/functions/src/string/chr.rs index 21d79cf6b0f1..0d94cab08d91 100644 --- a/datafusion/functions/src/string/chr.rs +++ b/datafusion/functions/src/string/chr.rs @@ -16,7 +16,7 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::ArrayRef; use arrow::array::StringArray; @@ -24,13 +24,13 @@ use arrow::datatypes::DataType; use arrow::datatypes::DataType::Int64; use arrow::datatypes::DataType::Utf8; +use crate::utils::make_scalar_function; use datafusion_common::cast::as_int64_array; use datafusion_common::{exec_err, Result}; -use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::utils::make_scalar_function; - /// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character. /// chr(65) = 'A' pub fn chr(args: &[ArrayRef]) -> Result { @@ -65,6 +65,12 @@ pub struct ChrFunc { signature: Signature, } +impl Default for ChrFunc { + fn default() -> Self { + Self::new() + } +} + impl ChrFunc { pub fn new() -> Self { Self { @@ -93,4 +99,35 @@ impl ScalarUDFImpl for ChrFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(chr, vec![])(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_chr_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_chr_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description( + "Returns the character with the specified ASCII or Unicode code value.", + ) + .with_syntax_example("chr(expression)") + .with_sql_example( + r#"```sql +> select chr(128640); ++--------------------+ +| chr(Int64(128640)) | ++--------------------+ +| 🚀 | ++--------------------+ +```"#, + ) + .with_standard_argument("expression", Some("String")) + .with_related_udf("ascii") + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index d36bd5cecc47..0d1f90eb22b9 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -15,17 +15,20 @@ // specific language governing permissions and limitations // under the License. +//! Common utilities for implementing string functions + use std::fmt::{Display, Formatter}; use std::sync::Arc; +use crate::strings::make_and_append_view; use arrow::array::{ - new_null_array, Array, ArrayDataBuilder, ArrayRef, GenericStringArray, - GenericStringBuilder, OffsetSizeTrait, StringArray, + new_null_array, Array, ArrayRef, GenericStringArray, GenericStringBuilder, + OffsetSizeTrait, StringBuilder, StringViewArray, }; -use arrow::buffer::{Buffer, MutableBuffer, NullBuffer}; +use arrow::buffer::Buffer; use arrow::datatypes::DataType; - -use datafusion_common::cast::as_generic_string_array; +use arrow_buffer::{NullBufferBuilder, ScalarBuffer}; +use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; use datafusion_common::Result; use datafusion_common::{exec_err, ScalarValue}; use datafusion_expr::ColumnarValue; @@ -49,32 +52,213 @@ impl Display for TrimType { pub(crate) fn general_trim( args: &[ArrayRef], trim_type: TrimType, + use_string_view: bool, ) -> Result { let func = match trim_type { TrimType::Left => |input, pattern: &str| { let pattern = pattern.chars().collect::>(); - str::trim_start_matches::<&[char]>(input, pattern.as_ref()) + let ltrimmed_str = + str::trim_start_matches::<&[char]>(input, pattern.as_ref()); + // `ltrimmed_str` is actually `input`[start_offset..], + // so `start_offset` = len(`input`) - len(`ltrimmed_str`) + let start_offset = input.as_bytes().len() - ltrimmed_str.as_bytes().len(); + + (ltrimmed_str, start_offset as u32) }, TrimType::Right => |input, pattern: &str| { let pattern = pattern.chars().collect::>(); - str::trim_end_matches::<&[char]>(input, pattern.as_ref()) + let rtrimmed_str = str::trim_end_matches::<&[char]>(input, pattern.as_ref()); + + // `ltrimmed_str` is actually `input`[0..new_len], so `start_offset` is 0 + (rtrimmed_str, 0) }, TrimType::Both => |input, pattern: &str| { let pattern = pattern.chars().collect::>(); - str::trim_end_matches::<&[char]>( - str::trim_start_matches::<&[char]>(input, pattern.as_ref()), - pattern.as_ref(), - ) + let ltrimmed_str = + str::trim_start_matches::<&[char]>(input, pattern.as_ref()); + // `btrimmed_str` can be got by rtrim(ltrim(`input`)), + // so its `start_offset` should be same as ltrim situation above + let start_offset = input.as_bytes().len() - ltrimmed_str.as_bytes().len(); + let btrimmed_str = + str::trim_end_matches::<&[char]>(ltrimmed_str, pattern.as_ref()); + + (btrimmed_str, start_offset as u32) }, }; + if use_string_view { + string_view_trim(func, args) + } else { + string_trim::(func, args) + } +} + +/// Applies the trim function to the given string view array(s) +/// and returns a new string view array with the trimmed values. +/// +/// # `trim_func`: The function to apply to each string view. +/// +/// ## Arguments +/// - The original string +/// - the pattern to trim +/// +/// ## Returns +/// - trimmed str (must be a substring of the first argument) +/// - start offset, needed in `string_view_trim` +/// +/// ## Examples +/// +/// For `ltrim`: +/// - `fn(" abc", " ") -> ("abc", 2)` +/// - `fn("abd", " ") -> ("abd", 0)` +/// +/// For `btrim`: +/// - `fn(" abc ", " ") -> ("abc", 2)` +/// - `fn("abd", " ") -> ("abd", 0)` +// removing 'a will cause compiler complaining lifetime of `func` +fn string_view_trim<'a>( + trim_func: fn(&'a str, &'a str) -> (&'a str, u32), + args: &'a [ArrayRef], +) -> Result { + let string_view_array = as_string_view_array(&args[0])?; + let mut views_buf = Vec::with_capacity(string_view_array.len()); + let mut null_builder = NullBufferBuilder::new(string_view_array.len()); + + match args.len() { + 1 => { + let array_iter = string_view_array.iter(); + let views_iter = string_view_array.views().iter(); + for (src_str_opt, raw_view) in array_iter.zip(views_iter) { + trim_and_append_str( + src_str_opt, + Some(" "), + trim_func, + &mut views_buf, + &mut null_builder, + raw_view, + ); + } + } + 2 => { + let characters_array = as_string_view_array(&args[1])?; + + if characters_array.len() == 1 { + // Only one `trim characters` exist + if characters_array.is_null(0) { + return Ok(new_null_array( + // The schema is expecting utf8 as null + &DataType::Utf8View, + string_view_array.len(), + )); + } + + let characters = characters_array.value(0); + let array_iter = string_view_array.iter(); + let views_iter = string_view_array.views().iter(); + for (src_str_opt, raw_view) in array_iter.zip(views_iter) { + trim_and_append_str( + src_str_opt, + Some(characters), + trim_func, + &mut views_buf, + &mut null_builder, + raw_view, + ); + } + } else { + // A specific `trim characters` for a row in the string view array + let characters_iter = characters_array.iter(); + let array_iter = string_view_array.iter(); + let views_iter = string_view_array.views().iter(); + for ((src_str_opt, raw_view), characters_opt) in + array_iter.zip(views_iter).zip(characters_iter) + { + trim_and_append_str( + src_str_opt, + characters_opt, + trim_func, + &mut views_buf, + &mut null_builder, + raw_view, + ); + } + } + } + other => { + return exec_err!( + "Function TRIM was called with {other} arguments. It requires at least 1 and at most 2." + ); + } + } + + let views_buf = ScalarBuffer::from(views_buf); + let nulls_buf = null_builder.finish(); + + // Safety: + // (1) The blocks of the given views are all provided + // (2) Each of the range `view.offset+start..end` of view in views_buf is within + // the bounds of each of the blocks + unsafe { + let array = StringViewArray::new_unchecked( + views_buf, + string_view_array.data_buffers().to_vec(), + nulls_buf, + ); + Ok(Arc::new(array) as ArrayRef) + } +} + +/// Trims the given string and appends the trimmed string to the views buffer +/// and the null buffer. +/// +/// Calls `trim_func` on the string value in `original_view`, for non_null +/// values and appends the updated view to the views buffer / null_builder. +/// +/// Arguments +/// - `src_str_opt`: The original string value (represented by the view) +/// - `trim_characters_opt`: The characters to trim from the string +/// - `trim_func`: The function to apply to the string (see [`string_view_trim`] for details) +/// - `views_buf`: The buffer to append the updated views to +/// - `null_builder`: The buffer to append the null values to +/// - `original_view`: The original view value (that contains src_str_opt) +fn trim_and_append_str<'a>( + src_str_opt: Option<&'a str>, + trim_characters_opt: Option<&'a str>, + trim_func: fn(&'a str, &'a str) -> (&'a str, u32), + views_buf: &mut Vec, + null_builder: &mut NullBufferBuilder, + original_view: &u128, +) { + if let (Some(src_str), Some(characters)) = (src_str_opt, trim_characters_opt) { + let (trim_str, start_offset) = trim_func(src_str, characters); + make_and_append_view( + views_buf, + null_builder, + original_view, + trim_str, + start_offset, + ); + } else { + null_builder.append_null(); + views_buf.push(0); + } +} + +/// Applies the trim function to the given string array(s) +/// and returns a new string array with the trimmed values. +/// +/// See [`string_view_trim`] for details on `func` +fn string_trim<'a, T: OffsetSizeTrait>( + func: fn(&'a str, &'a str) -> (&'a str, u32), + args: &'a [ArrayRef], +) -> Result { let string_array = as_generic_string_array::(&args[0])?; match args.len() { 1 => { let result = string_array .iter() - .map(|string| string.map(|string: &str| func(string, " "))) + .map(|string| string.map(|string: &str| func(string, " ").0)) .collect::>(); Ok(Arc::new(result) as ArrayRef) @@ -84,13 +268,16 @@ pub(crate) fn general_trim( if characters_array.len() == 1 { if characters_array.is_null(0) { - return Ok(new_null_array(args[0].data_type(), args[0].len())); + return Ok(new_null_array( + string_array.data_type(), + string_array.len(), + )); } let characters = characters_array.value(0); let result = string_array .iter() - .map(|item| item.map(|string| func(string, characters))) + .map(|item| item.map(|string| func(string, characters).0)) .collect::>(); return Ok(Arc::new(result) as ArrayRef); } @@ -99,7 +286,7 @@ pub(crate) fn general_trim( .iter() .zip(characters_array.iter()) .map(|(string, characters)| match (string, characters) { - (Some(string), Some(characters)) => Some(func(string, characters)), + (Some(string), Some(characters)) => Some(func(string, characters).0), _ => None, }) .collect::>(); @@ -108,8 +295,8 @@ pub(crate) fn general_trim( } other => { exec_err!( - "{trim_type} was called with {other} arguments. It requires at least 1 and at most 2." - ) + "Function TRIM was called with {other} arguments. It requires at least 1 and at most 2." + ) } } } @@ -139,6 +326,23 @@ where i64, _, >(array, op)?)), + DataType::Utf8View => { + let string_array = as_string_view_array(array)?; + let mut string_builder = StringBuilder::with_capacity( + string_array.len(), + string_array.get_array_memory_size(), + ); + + for str in string_array.iter() { + if let Some(str) = str { + string_builder.append_value(op(str)); + } else { + string_builder.append_null(); + } + } + + Ok(ColumnarValue::Array(Arc::new(string_builder.finish()))) + } other => exec_err!("Unsupported data type {other:?} for function {name}"), }, ColumnarValue::Scalar(scalar) => match scalar { @@ -150,100 +354,15 @@ where let result = a.as_ref().map(|x| op(x)); Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result))) } + ScalarValue::Utf8View(a) => { + let result = a.as_ref().map(|x| op(x)); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) + } other => exec_err!("Unsupported data type {other:?} for function {name}"), }, } } -pub(crate) enum ColumnarValueRef<'a> { - Scalar(&'a [u8]), - NullableArray(&'a StringArray), - NonNullableArray(&'a StringArray), -} - -impl<'a> ColumnarValueRef<'a> { - #[inline] - pub fn is_valid(&self, i: usize) -> bool { - match &self { - Self::Scalar(_) | Self::NonNullableArray(_) => true, - Self::NullableArray(array) => array.is_valid(i), - } - } - - #[inline] - pub fn nulls(&self) -> Option { - match &self { - Self::Scalar(_) | Self::NonNullableArray(_) => None, - Self::NullableArray(array) => array.nulls().cloned(), - } - } -} - -/// Optimized version of the StringBuilder in Arrow that: -/// 1. Precalculating the expected length of the result, avoiding reallocations. -/// 2. Avoids creating / incrementally creating a `NullBufferBuilder` -pub(crate) struct StringArrayBuilder { - offsets_buffer: MutableBuffer, - value_buffer: MutableBuffer, -} - -impl StringArrayBuilder { - pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self { - let mut offsets_buffer = MutableBuffer::with_capacity( - (item_capacity + 1) * std::mem::size_of::(), - ); - // SAFETY: the first offset value is definitely not going to exceed the bounds. - unsafe { offsets_buffer.push_unchecked(0_i32) }; - Self { - offsets_buffer, - value_buffer: MutableBuffer::with_capacity(data_capacity), - } - } - - pub fn write( - &mut self, - column: &ColumnarValueRef, - i: usize, - ) { - match column { - ColumnarValueRef::Scalar(s) => { - self.value_buffer.extend_from_slice(s); - } - ColumnarValueRef::NullableArray(array) => { - if !CHECK_VALID || array.is_valid(i) { - self.value_buffer - .extend_from_slice(array.value(i).as_bytes()); - } - } - ColumnarValueRef::NonNullableArray(array) => { - self.value_buffer - .extend_from_slice(array.value(i).as_bytes()); - } - } - } - - pub fn append_offset(&mut self) { - let next_offset: i32 = self - .value_buffer - .len() - .try_into() - .expect("byte array offset overflow"); - unsafe { self.offsets_buffer.push_unchecked(next_offset) }; - } - - pub fn finish(self, null_buffer: Option) -> StringArray { - let array_builder = ArrayDataBuilder::new(DataType::Utf8) - .len(self.offsets_buffer.len() / std::mem::size_of::() - 1) - .add_buffer(self.offsets_buffer.into()) - .add_buffer(self.value_buffer.into()) - .nulls(null_buffer); - // SAFETY: all data that was appended was valid UTF8 and the values - // and offsets were created correctly - let array_data = unsafe { array_builder.build_unchecked() }; - StringArray::from(array_data) - } -} - fn case_conversion_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result where O: OffsetSizeTrait, diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 55b7c2f22249..e3834b291896 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -15,22 +15,23 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; -use std::sync::Arc; - +use arrow::array::{as_largestring_array, Array}; use arrow::datatypes::DataType; -use arrow::datatypes::DataType::Utf8; +use std::any::Any; +use std::sync::{Arc, OnceLock}; -use datafusion_common::cast::as_string_array; -use datafusion_common::{internal_err, Result, ScalarValue}; +use crate::string::concat; +use crate::strings::{ + ColumnarValueRef, LargeStringArrayBuilder, StringArrayBuilder, StringViewArrayBuilder, +}; +use datafusion_common::cast::{as_string_array, as_string_view_array}; +use datafusion_common::{internal_err, plan_err, Result, ScalarValue}; use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{lit, ColumnarValue, Expr, ScalarFunctionDefinition, Volatility}; +use datafusion_expr::{lit, ColumnarValue, Documentation, Expr, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::string::common::*; -use crate::string::concat; - #[derive(Debug)] pub struct ConcatFunc { signature: Signature, @@ -46,7 +47,10 @@ impl ConcatFunc { pub fn new() -> Self { use DataType::*; Self { - signature: Signature::variadic(vec![Utf8], Volatility::Immutable), + signature: Signature::variadic( + vec![Utf8, Utf8View, LargeUtf8], + Volatility::Immutable, + ), } } } @@ -64,13 +68,36 @@ impl ScalarUDFImpl for ConcatFunc { &self.signature } - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(Utf8) + fn return_type(&self, arg_types: &[DataType]) -> Result { + use DataType::*; + let mut dt = &Utf8; + arg_types.iter().for_each(|data_type| { + if data_type == &Utf8View { + dt = data_type; + } + if data_type == &LargeUtf8 && dt != &Utf8View { + dt = data_type; + } + }); + + Ok(dt.to_owned()) } /// Concatenates the text representations of all the arguments. NULL arguments are ignored. /// concat('abcde', 2, NULL, 22) = 'abcde222' fn invoke(&self, args: &[ColumnarValue]) -> Result { + let mut return_datatype = DataType::Utf8; + args.iter().for_each(|col| { + if col.data_type() == DataType::Utf8View { + return_datatype = col.data_type(); + } + if col.data_type() == DataType::LargeUtf8 + && return_datatype != DataType::Utf8View + { + return_datatype = col.data_type(); + } + }); + let array_len = args .iter() .filter_map(|x| match x { @@ -87,7 +114,21 @@ impl ScalarUDFImpl for ConcatFunc { result.push_str(v); } } - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))); + + return match return_datatype { + DataType::Utf8View => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(result)))) + } + DataType::Utf8 => { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))) + } + DataType::LargeUtf8 => { + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(result)))) + } + other => { + plan_err!("Concat function does not support datatype of {other}") + } + }; } // Array @@ -97,34 +138,97 @@ impl ScalarUDFImpl for ConcatFunc { for arg in args { match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) => { + ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value)) + | ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => { if let Some(s) = maybe_value { data_size += s.len() * len; columns.push(ColumnarValueRef::Scalar(s.as_bytes())); } } ColumnarValue::Array(array) => { - let string_array = as_string_array(array)?; - data_size += string_array.values().len(); - let column = if array.is_nullable() { - ColumnarValueRef::NullableArray(string_array) - } else { - ColumnarValueRef::NonNullableArray(string_array) + match array.data_type() { + DataType::Utf8 => { + let string_array = as_string_array(array)?; + + data_size += string_array.values().len(); + let column = if array.is_nullable() { + ColumnarValueRef::NullableArray(string_array) + } else { + ColumnarValueRef::NonNullableArray(string_array) + }; + columns.push(column); + }, + DataType::LargeUtf8 => { + let string_array = as_largestring_array(array); + + data_size += string_array.values().len(); + let column = if array.is_nullable() { + ColumnarValueRef::NullableLargeStringArray(string_array) + } else { + ColumnarValueRef::NonNullableLargeStringArray(string_array) + }; + columns.push(column); + }, + DataType::Utf8View => { + let string_array = as_string_view_array(array)?; + + data_size += string_array.len(); + let column = if array.is_nullable() { + ColumnarValueRef::NullableStringViewArray(string_array) + } else { + ColumnarValueRef::NonNullableStringViewArray(string_array) + }; + columns.push(column); + }, + other => { + return plan_err!("Input was {other} which is not a supported datatype for concat function") + } }; - columns.push(column); } _ => unreachable!(), } } - let mut builder = StringArrayBuilder::with_capacity(len, data_size); - for i in 0..len { - columns - .iter() - .for_each(|column| builder.write::(column, i)); - builder.append_offset(); + match return_datatype { + DataType::Utf8 => { + let mut builder = StringArrayBuilder::with_capacity(len, data_size); + for i in 0..len { + columns + .iter() + .for_each(|column| builder.write::(column, i)); + builder.append_offset(); + } + + let string_array = builder.finish(None); + Ok(ColumnarValue::Array(Arc::new(string_array))) + } + DataType::Utf8View => { + let mut builder = StringViewArrayBuilder::with_capacity(len, data_size); + for i in 0..len { + columns + .iter() + .for_each(|column| builder.write::(column, i)); + builder.append_offset(); + } + + let string_array = builder.finish(); + Ok(ColumnarValue::Array(Arc::new(string_array))) + } + DataType::LargeUtf8 => { + let mut builder = LargeStringArrayBuilder::with_capacity(len, data_size); + for i in 0..len { + columns + .iter() + .for_each(|column| builder.write::(column, i)); + builder.append_offset(); + } + + let string_array = builder.finish(None); + Ok(ColumnarValue::Array(Arc::new(string_array))) + } + _ => unreachable!(), } - Ok(ColumnarValue::Array(Arc::new(builder.finish(None)))) } /// Simplify the `concat` function by @@ -142,6 +246,36 @@ impl ScalarUDFImpl for ConcatFunc { ) -> Result { simplify_concat(args) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_concat_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_concat_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Concatenates multiple strings together.") + .with_syntax_example("concat(str[, ..., str_n])") + .with_sql_example( + r#"```sql +> select concat('data', 'f', 'us', 'ion'); ++-------------------------------------------------------+ +| concat(Utf8("data"),Utf8("f"),Utf8("us"),Utf8("ion")) | ++-------------------------------------------------------+ +| datafusion | ++-------------------------------------------------------+ +```"#, + ) + .with_standard_argument("str", Some("String")) + .with_argument("str_n", "Subsequent string expressions to concatenate.") + .with_related_udf("concat_ws") + .build() + .unwrap() + }) } pub fn simplify_concat(args: Vec) -> Result { @@ -151,11 +285,11 @@ pub fn simplify_concat(args: Vec) -> Result { for arg in args.clone() { match arg { // filter out `null` args - Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None)) => {} + Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None)) => {} // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. // Concatenate it with the `contiguous_scalar`. Expr::Literal( - ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)), + ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v)), ) => contiguous_scalar += &v, Expr::Literal(x) => { return internal_err!( @@ -182,7 +316,7 @@ pub fn simplify_concat(args: Vec) -> Result { if !args.eq(&new_args) { Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction( ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(concat()), + func: concat(), args: new_args, }, ))) @@ -195,8 +329,9 @@ pub fn simplify_concat(args: Vec) -> Result { mod tests { use super::*; use crate::utils::test::test_function; - use arrow::array::Array; + use arrow::array::{Array, LargeStringArray, StringViewArray}; use arrow::array::{ArrayRef, StringArray}; + use DataType::*; #[test] fn test_functions() -> Result<()> { @@ -232,6 +367,31 @@ mod tests { Utf8, StringArray ); + test_function!( + ConcatFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("aa")), + ColumnarValue::Scalar(ScalarValue::Utf8View(None)), + ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)), + ColumnarValue::Scalar(ScalarValue::from("cc")), + ], + Ok(Some("aacc")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + ConcatFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("aa")), + ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)), + ColumnarValue::Scalar(ScalarValue::from("cc")), + ], + Ok(Some("aacc")), + &str, + LargeUtf8, + LargeStringArray + ); Ok(()) } @@ -248,6 +408,7 @@ mod tests { ]))); let args = &[c0, c1, c2]; + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = ConcatFunc::new().invoke(args)?; let expected = Arc::new(StringArray::from(vec!["foo,x", "bar,", "baz,z"])) as ArrayRef; diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 1d27712b2c93..811939c1699b 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -15,24 +15,23 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::StringArray; +use arrow::array::{as_largestring_array, Array, StringArray}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::datatypes::DataType; -use arrow::datatypes::DataType::Utf8; -use datafusion_common::cast::as_string_array; -use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; +use crate::string::concat::simplify_concat; +use crate::string::concat_ws; +use crate::strings::{ColumnarValueRef, StringArrayBuilder}; +use datafusion_common::cast::{as_string_array, as_string_view_array}; +use datafusion_common::{exec_err, internal_err, plan_err, Result, ScalarValue}; use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{lit, ColumnarValue, Expr, ScalarFunctionDefinition, Volatility}; +use datafusion_expr::{lit, ColumnarValue, Documentation, Expr, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::string::common::*; -use crate::string::concat::simplify_concat; -use crate::string::concat_ws; - #[derive(Debug)] pub struct ConcatWsFunc { signature: Signature, @@ -48,7 +47,10 @@ impl ConcatWsFunc { pub fn new() -> Self { use DataType::*; Self { - signature: Signature::variadic(vec![Utf8], Volatility::Immutable), + signature: Signature::variadic( + vec![Utf8View, Utf8, LargeUtf8], + Volatility::Immutable, + ), } } } @@ -67,13 +69,14 @@ impl ScalarUDFImpl for ConcatWsFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { + use DataType::*; Ok(Utf8) } /// Concatenates all but the first argument, with separators. The first argument is used as the separator string, and should not be NULL. Other NULL arguments are ignored. /// concat_ws(',', 'abcde', 2, NULL, 22) = 'abcde,2,22' fn invoke(&self, args: &[ColumnarValue]) -> Result { - // do not accept 0 or 1 arguments. + // do not accept 0 arguments. if args.len() < 2 { return exec_err!( "concat_ws was called with {} arguments. It requires at least 2.", @@ -92,8 +95,12 @@ impl ScalarUDFImpl for ConcatWsFunc { // Scalar if array_len.is_none() { let sep = match &args[0] { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s, - ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) + | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s))) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => s, + ColumnarValue::Scalar(ScalarValue::Utf8(None)) + | ColumnarValue::Scalar(ScalarValue::Utf8View(None)) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => { return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); } _ => unreachable!(), @@ -104,22 +111,30 @@ impl ScalarUDFImpl for ConcatWsFunc { for arg in iter.by_ref() { match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) + | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s))) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => { result.push_str(s); break; } - ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {} + ColumnarValue::Scalar(ScalarValue::Utf8(None)) + | ColumnarValue::Scalar(ScalarValue::Utf8View(None)) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {} _ => unreachable!(), } } for arg in iter.by_ref() { match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) + | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s))) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => { result.push_str(sep); result.push_str(s); } - ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {} + ColumnarValue::Scalar(ScalarValue::Utf8(None)) + | ColumnarValue::Scalar(ScalarValue::Utf8View(None)) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {} _ => unreachable!(), } } @@ -155,21 +170,53 @@ impl ScalarUDFImpl for ConcatWsFunc { let mut columns = Vec::with_capacity(args.len() - 1); for arg in &args[1..] { match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) => { + ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) + | ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value)) + | ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => { if let Some(s) = maybe_value { data_size += s.len() * len; columns.push(ColumnarValueRef::Scalar(s.as_bytes())); } } ColumnarValue::Array(array) => { - let string_array = as_string_array(array)?; - data_size += string_array.values().len(); - let column = if array.is_nullable() { - ColumnarValueRef::NullableArray(string_array) - } else { - ColumnarValueRef::NonNullableArray(string_array) + match array.data_type() { + DataType::Utf8 => { + let string_array = as_string_array(array)?; + + data_size += string_array.values().len(); + let column = if array.is_nullable() { + ColumnarValueRef::NullableArray(string_array) + } else { + ColumnarValueRef::NonNullableArray(string_array) + }; + columns.push(column); + }, + DataType::LargeUtf8 => { + let string_array = as_largestring_array(array); + + data_size += string_array.values().len(); + let column = if array.is_nullable() { + ColumnarValueRef::NullableLargeStringArray(string_array) + } else { + ColumnarValueRef::NonNullableLargeStringArray(string_array) + }; + columns.push(column); + }, + DataType::Utf8View => { + let string_array = as_string_view_array(array)?; + + data_size += string_array.data_buffers().iter().map(|buf| buf.len()).sum::(); + let column = if array.is_nullable() { + ColumnarValueRef::NullableStringViewArray(string_array) + } else { + ColumnarValueRef::NonNullableStringViewArray(string_array) + }; + columns.push(column); + }, + other => { + return plan_err!("Input was {other} which is not a supported datatype for concat_ws function.") + } }; - columns.push(column); } _ => unreachable!(), } @@ -218,12 +265,50 @@ impl ScalarUDFImpl for ConcatWsFunc { _ => Ok(ExprSimplifyResult::Original(args)), } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_concat_ws_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_concat_ws_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description( + "Concatenates multiple strings together with a specified separator.", + ) + .with_syntax_example("concat_ws(separator, str[, ..., str_n])") + .with_sql_example( + r#"```sql +> select concat_ws('_', 'data', 'fusion'); ++--------------------------------------------------+ +| concat_ws(Utf8("_"),Utf8("data"),Utf8("fusion")) | ++--------------------------------------------------+ +| data_fusion | ++--------------------------------------------------+ +```"#, + ) + .with_argument( + "separator", + "Separator to insert between concatenated strings.", + ) + .with_standard_argument("str", Some("String")) + .with_argument("str_n", "Subsequent string expressions to concatenate.") + .with_related_udf("concat") + .build() + .unwrap() + }) } fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { match delimiter { Expr::Literal( - ScalarValue::Utf8(delimiter) | ScalarValue::LargeUtf8(delimiter), + ScalarValue::Utf8(delimiter) + | ScalarValue::LargeUtf8(delimiter) + | ScalarValue::Utf8View(delimiter), ) => { match delimiter { // when the delimiter is an empty string, @@ -236,8 +321,8 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result {} - Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v))) => { + Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None)) => {} + Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | ScalarValue::Utf8View(Some(v))) => { match contiguous_scalar { None => contiguous_scalar = Some(v.to_string()), Some(mut pre) => { @@ -266,7 +351,7 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result Self { + ContainsFunc::new() + } +} + +impl ContainsFunc { + pub fn new() -> Self { + Self { + signature: Signature::string(2, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for ContainsFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "contains" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _: &[DataType]) -> Result { + Ok(Boolean) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(contains, vec![])(args) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_contains_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_contains_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description( + "Return true if search_str is found within string (case-sensitive).", + ) + .with_syntax_example("contains(str, search_str)") + .with_sql_example( + r#"```sql +> select contains('the quick brown fox', 'row'); ++---------------------------------------------------+ +| contains(Utf8("the quick brown fox"),Utf8("row")) | ++---------------------------------------------------+ +| true | ++---------------------------------------------------+ +```"#, + ) + .with_standard_argument("str", Some("String")) + .with_argument("search_str", "The string to search for in str.") + .build() + .unwrap() + }) +} + +/// use `arrow::compute::contains` to do the calculation for contains +pub fn contains(args: &[ArrayRef]) -> Result { + match (args[0].data_type(), args[1].data_type()) { + (Utf8View, Utf8View) => { + let mod_str = args[0].as_string_view(); + let match_str = args[1].as_string_view(); + let res = arrow_contains(mod_str, match_str)?; + Ok(Arc::new(res) as ArrayRef) + } + (Utf8, Utf8) => { + let mod_str = args[0].as_string::(); + let match_str = args[1].as_string::(); + let res = arrow_contains(mod_str, match_str)?; + Ok(Arc::new(res) as ArrayRef) + } + (LargeUtf8, LargeUtf8) => { + let mod_str = args[0].as_string::(); + let match_str = args[1].as_string::(); + let res = arrow_contains(mod_str, match_str)?; + Ok(Arc::new(res) as ArrayRef) + } + other => { + exec_err!("Unsupported data type {other:?} for function `contains`.") + } + } +} + +#[cfg(test)] +mod test { + use super::ContainsFunc; + use arrow::array::{BooleanArray, StringArray}; + use datafusion_common::ScalarValue; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use std::sync::Arc; + + #[test] + fn test_contains_udf() { + let udf = ContainsFunc::new(); + let array = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("xxx?()"), + Some("yyy?()"), + ]))); + let scalar = ColumnarValue::Scalar(ScalarValue::Utf8(Some("x?(".to_string()))); + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch + let actual = udf.invoke(&[array, scalar]).unwrap(); + let expect = ColumnarValue::Array(Arc::new(BooleanArray::from(vec![ + Some(true), + Some(false), + ]))); + assert_eq!( + *actual.into_array(2).unwrap(), + *expect.into_array(2).unwrap() + ); + } +} diff --git a/datafusion/functions/src/string/ends_with.rs b/datafusion/functions/src/string/ends_with.rs index b72cf0f66fa6..88978a35c0b7 100644 --- a/datafusion/functions/src/string/ends_with.rs +++ b/datafusion/functions/src/string/ends_with.rs @@ -16,19 +16,16 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; -use arrow::array::{ArrayRef, OffsetSizeTrait}; +use arrow::array::ArrayRef; use arrow::datatypes::DataType; -use arrow::datatypes::DataType::Boolean; - -use datafusion_common::cast::as_generic_string_array; -use datafusion_common::{exec_err, Result}; -use datafusion_expr::TypeSignature::*; -use datafusion_expr::{ColumnarValue, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; use crate::utils::make_scalar_function; +use datafusion_common::{internal_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ColumnarValue, Documentation, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; #[derive(Debug)] pub struct EndsWithFunc { @@ -43,17 +40,8 @@ impl Default for EndsWithFunc { impl EndsWithFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8]), - Exact(vec![Utf8, LargeUtf8]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![LargeUtf8, LargeUtf8]), - ], - Volatility::Immutable, - ), + signature: Signature::string(2, Volatility::Immutable), } } } @@ -72,27 +60,60 @@ impl ScalarUDFImpl for EndsWithFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(Boolean) + Ok(DataType::Boolean) } fn invoke(&self, args: &[ColumnarValue]) -> Result { match args[0].data_type() { - DataType::Utf8 => make_scalar_function(ends_with::, vec![])(args), - DataType::LargeUtf8 => make_scalar_function(ends_with::, vec![])(args), + DataType::Utf8View | DataType::Utf8 | DataType::LargeUtf8 => { + make_scalar_function(ends_with, vec![])(args) + } other => { - exec_err!("Unsupported data type {other:?} for function ends_with") + internal_err!("Unsupported data type {other:?} for function ends_with. Expected Utf8, LargeUtf8 or Utf8View")? } } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_ends_with_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_ends_with_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Tests if a string ends with a substring.") + .with_syntax_example("ends_with(str, substr)") + .with_sql_example( + r#"```sql +> select ends_with('datafusion', 'soin'); ++--------------------------------------------+ +| ends_with(Utf8("datafusion"),Utf8("soin")) | ++--------------------------------------------+ +| false | ++--------------------------------------------+ +> select ends_with('datafusion', 'sion'); ++--------------------------------------------+ +| ends_with(Utf8("datafusion"),Utf8("sion")) | ++--------------------------------------------+ +| true | ++--------------------------------------------+ +```"#, + ) + .with_standard_argument("str", Some("String")) + .with_argument("substr", "Substring to test for.") + .build() + .unwrap() + }) } /// Returns true if string ends with suffix. /// ends_with('alphabet', 'abet') = 't' -pub fn ends_with(args: &[ArrayRef]) -> Result { - let left = as_generic_string_array::(&args[0])?; - let right = as_generic_string_array::(&args[1])?; - - let result = arrow::compute::kernels::comparison::ends_with(left, right)?; +pub fn ends_with(args: &[ArrayRef]) -> Result { + let result = arrow::compute::kernels::comparison::ends_with(&args[0], &args[1])?; Ok(Arc::new(result) as ArrayRef) } diff --git a/datafusion/functions/src/string/initcap.rs b/datafusion/functions/src/string/initcap.rs index 864179d130fd..5fd1e7929881 100644 --- a/datafusion/functions/src/string/initcap.rs +++ b/datafusion/functions/src/string/initcap.rs @@ -16,18 +16,18 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, StringArray}; use arrow::datatypes::DataType; -use datafusion_common::cast::as_generic_string_array; +use crate::utils::{make_scalar_function, utf8_to_str_type}; +use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; use datafusion_common::{exec_err, Result}; -use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::utils::{make_scalar_function, utf8_to_str_type}; - #[derive(Debug)] pub struct InitcapFunc { signature: Signature, @@ -41,13 +41,8 @@ impl Default for InitcapFunc { impl InitcapFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::uniform( - 1, - vec![Utf8, LargeUtf8], - Volatility::Immutable, - ), + signature: Signature::string(1, Volatility::Immutable), } } } @@ -73,11 +68,40 @@ impl ScalarUDFImpl for InitcapFunc { match args[0].data_type() { DataType::Utf8 => make_scalar_function(initcap::, vec![])(args), DataType::LargeUtf8 => make_scalar_function(initcap::, vec![])(args), + DataType::Utf8View => make_scalar_function(initcap_utf8view, vec![])(args), other => { exec_err!("Unsupported data type {other:?} for function initcap") } } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_initcap_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_initcap_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Capitalizes the first character in each word in the input string. Words are delimited by non-alphanumeric characters.") + .with_syntax_example("initcap(str)") + .with_sql_example(r#"```sql +> select initcap('apache datafusion'); ++------------------------------------+ +| initcap(Utf8("apache datafusion")) | ++------------------------------------+ +| Apache Datafusion | ++------------------------------------+ +```"#) + .with_standard_argument("str", Some("String")) + .with_related_udf("lower") + .with_related_udf("upper") + .build() + .unwrap() + }) } /// Converts the first letter of each word to upper case and the rest to lower case. Words are sequences of alphanumeric characters separated by non-alphanumeric characters. @@ -88,28 +112,41 @@ fn initcap(args: &[ArrayRef]) -> Result { // first map is the iterator, second is for the `Option<_>` let result = string_array .iter() - .map(|string| { - string.map(|string: &str| { - let mut char_vector = Vec::::new(); - let mut previous_character_letter_or_number = false; - for c in string.chars() { - if previous_character_letter_or_number { - char_vector.push(c.to_ascii_lowercase()); - } else { - char_vector.push(c.to_ascii_uppercase()); - } - previous_character_letter_or_number = c.is_ascii_uppercase() - || c.is_ascii_lowercase() - || c.is_ascii_digit(); - } - char_vector.iter().collect::() - }) - }) + .map(initcap_string) .collect::>(); Ok(Arc::new(result) as ArrayRef) } +fn initcap_utf8view(args: &[ArrayRef]) -> Result { + let string_view_array = as_string_view_array(&args[0])?; + + let result = string_view_array + .iter() + .map(initcap_string) + .collect::(); + + Ok(Arc::new(result) as ArrayRef) +} + +fn initcap_string(string: Option<&str>) -> Option { + let mut char_vector = Vec::::new(); + string.map(|string: &str| { + char_vector.clear(); + let mut previous_character_letter_or_number = false; + for c in string.chars() { + if previous_character_letter_or_number { + char_vector.push(c.to_ascii_lowercase()); + } else { + char_vector.push(c.to_ascii_uppercase()); + } + previous_character_letter_or_number = + c.is_ascii_uppercase() || c.is_ascii_lowercase() || c.is_ascii_digit(); + } + char_vector.iter().collect::() + }) +} + #[cfg(test)] mod tests { use crate::string::initcap::InitcapFunc; @@ -153,6 +190,44 @@ mod tests { Utf8, StringArray ); + test_function!( + InitcapFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + "hi THOMAS".to_string() + )))], + Ok(Some("Hi Thomas")), + &str, + Utf8, + StringArray + ); + test_function!( + InitcapFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + "hi THOMAS wIth M0re ThAN 12 ChaRs".to_string() + )))], + Ok(Some("Hi Thomas With M0re Than 12 Chars")), + &str, + Utf8, + StringArray + ); + test_function!( + InitcapFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + "".to_string() + )))], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + InitcapFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8View(None))], + Ok(None), + &str, + Utf8, + StringArray + ); Ok(()) } diff --git a/datafusion/functions/src/string/levenshtein.rs b/datafusion/functions/src/string/levenshtein.rs index ec22b0a4a480..558e71239f84 100644 --- a/datafusion/functions/src/string/levenshtein.rs +++ b/datafusion/functions/src/string/levenshtein.rs @@ -16,17 +16,17 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::{ArrayRef, Int32Array, Int64Array, OffsetSizeTrait}; use arrow::datatypes::DataType; use crate::utils::{make_scalar_function, utf8_to_int_type}; -use datafusion_common::cast::as_generic_string_array; +use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; use datafusion_common::utils::datafusion_strsim; use datafusion_common::{exec_err, Result}; -use datafusion_expr::ColumnarValue; -use datafusion_expr::TypeSignature::*; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ColumnarValue, Documentation}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; #[derive(Debug)] @@ -34,14 +34,16 @@ pub struct LevenshteinFunc { signature: Signature, } +impl Default for LevenshteinFunc { + fn default() -> Self { + Self::new() + } +} + impl LevenshteinFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::one_of( - vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])], - Volatility::Immutable, - ), + signature: Signature::string(2, Volatility::Immutable), } } } @@ -65,13 +67,42 @@ impl ScalarUDFImpl for LevenshteinFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { match args[0].data_type() { - DataType::Utf8 => make_scalar_function(levenshtein::, vec![])(args), + DataType::Utf8View | DataType::Utf8 => { + make_scalar_function(levenshtein::, vec![])(args) + } DataType::LargeUtf8 => make_scalar_function(levenshtein::, vec![])(args), other => { exec_err!("Unsupported data type {other:?} for function levenshtein") } } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_levenshtein_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_levenshtein_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Returns the [`Levenshtein distance`](https://en.wikipedia.org/wiki/Levenshtein_distance) between the two given strings.") + .with_syntax_example("levenshtein(str1, str2)") + .with_sql_example(r#"```sql +> select levenshtein('kitten', 'sitting'); ++---------------------------------------------+ +| levenshtein(Utf8("kitten"),Utf8("sitting")) | ++---------------------------------------------+ +| 3 | ++---------------------------------------------+ +```"#) + .with_argument("str1", "String expression to compute Levenshtein distance with str2.") + .with_argument("str2", "String expression to compute Levenshtein distance with str1.") + .build() + .unwrap() + }) } ///Returns the Levenshtein distance between the two given strings. @@ -83,10 +114,26 @@ pub fn levenshtein(args: &[ArrayRef]) -> Result { args.len() ); } - let str1_array = as_generic_string_array::(&args[0])?; - let str2_array = as_generic_string_array::(&args[1])?; + match args[0].data_type() { + DataType::Utf8View => { + let str1_array = as_string_view_array(&args[0])?; + let str2_array = as_string_view_array(&args[1])?; + let result = str1_array + .iter() + .zip(str2_array.iter()) + .map(|(string1, string2)| match (string1, string2) { + (Some(string1), Some(string2)) => { + Some(datafusion_strsim::levenshtein(string1, string2) as i32) + } + _ => None, + }) + .collect::(); + Ok(Arc::new(result) as ArrayRef) + } DataType::Utf8 => { + let str1_array = as_generic_string_array::(&args[0])?; + let str2_array = as_generic_string_array::(&args[1])?; let result = str1_array .iter() .zip(str2_array.iter()) @@ -100,6 +147,8 @@ pub fn levenshtein(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } DataType::LargeUtf8 => { + let str1_array = as_generic_string_array::(&args[0])?; + let str2_array = as_generic_string_array::(&args[1])?; let result = str1_array .iter() .zip(str2_array.iter()) @@ -114,7 +163,7 @@ pub fn levenshtein(args: &[ArrayRef]) -> Result { } other => { exec_err!( - "levenshtein was called with {other} datatype arguments. It requires Utf8 or LargeUtf8." + "levenshtein was called with {other} datatype arguments. It requires Utf8View, Utf8 or LargeUtf8." ) } } diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs index b9b3840252c5..ef56120c582a 100644 --- a/datafusion/functions/src/string/lower.rs +++ b/datafusion/functions/src/string/lower.rs @@ -15,31 +15,32 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; - use arrow::datatypes::DataType; - -use datafusion_common::Result; -use datafusion_expr::ColumnarValue; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::OnceLock; use crate::string::common::to_lower; use crate::utils::utf8_to_str_type; +use datafusion_common::Result; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ColumnarValue, Documentation}; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; #[derive(Debug)] pub struct LowerFunc { signature: Signature, } +impl Default for LowerFunc { + fn default() -> Self { + Self::new() + } +} + impl LowerFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::uniform( - 1, - vec![Utf8, LargeUtf8], - Volatility::Immutable, - ), + signature: Signature::string(1, Volatility::Immutable), } } } @@ -64,8 +65,37 @@ impl ScalarUDFImpl for LowerFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { to_lower(args, "lower") } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_lower_doc()) + } } +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_lower_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Converts a string to lower-case.") + .with_syntax_example("lower(str)") + .with_sql_example( + r#"```sql +> select lower('Ångström'); ++-------------------------+ +| lower(Utf8("Ångström")) | ++-------------------------+ +| ångström | ++-------------------------+ +```"#, + ) + .with_standard_argument("str", Some("String")) + .with_related_udf("initcap") + .with_related_udf("upper") + .build() + .unwrap() + }) +} #[cfg(test)] mod tests { use super::*; @@ -75,6 +105,7 @@ mod tests { fn to_lower(input: ArrayRef, expected: ArrayRef) -> Result<()> { let func = LowerFunc::new(); let args = vec![ColumnarValue::Array(input)]; + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = match func.invoke(&args)? { ColumnarValue::Array(result) => result, _ => unreachable!(), diff --git a/datafusion/functions/src/string/ltrim.rs b/datafusion/functions/src/string/ltrim.rs index 1a6a9d497f66..0b4c197646b6 100644 --- a/datafusion/functions/src/string/ltrim.rs +++ b/datafusion/functions/src/string/ltrim.rs @@ -15,24 +15,24 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; - use arrow::array::{ArrayRef, OffsetSizeTrait}; use arrow::datatypes::DataType; - -use datafusion_common::{exec_err, Result}; -use datafusion_expr::TypeSignature::*; -use datafusion_expr::{ColumnarValue, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; -use datafusion_physical_expr::functions::Hint; +use std::any::Any; +use std::sync::OnceLock; use crate::string::common::*; use crate::utils::{make_scalar_function, utf8_to_str_type}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::function::Hint; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; /// Returns the longest string with leading characters removed. If the characters are not specified, whitespace is removed. /// ltrim('zzzytest', 'xyz') = 'test' fn ltrim(args: &[ArrayRef]) -> Result { - general_trim::(args, TrimType::Left) + let use_string_view = args[0].data_type() == &DataType::Utf8View; + general_trim::(args, TrimType::Left, use_string_view) } #[derive(Debug)] @@ -40,12 +40,17 @@ pub struct LtrimFunc { signature: Signature, } +impl Default for LtrimFunc { + fn default() -> Self { + Self::new() + } +} + impl LtrimFunc { pub fn new() -> Self { - use DataType::*; Self { signature: Signature::one_of( - vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])], + vec![TypeSignature::String(2), TypeSignature::String(1)], Volatility::Immutable, ), } @@ -66,12 +71,16 @@ impl ScalarUDFImpl for LtrimFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - utf8_to_str_type(&arg_types[0], "ltrim") + if arg_types[0] == DataType::Utf8View { + Ok(DataType::Utf8View) + } else { + utf8_to_str_type(&arg_types[0], "ltrim") + } } fn invoke(&self, args: &[ColumnarValue]) -> Result { match args[0].data_type() { - DataType::Utf8 => make_scalar_function( + DataType::Utf8 | DataType::Utf8View => make_scalar_function( ltrim::, vec![Hint::Pad, Hint::AcceptsSingular], )(args), @@ -79,7 +88,192 @@ impl ScalarUDFImpl for LtrimFunc { ltrim::, vec![Hint::Pad, Hint::AcceptsSingular], )(args), - other => exec_err!("Unsupported data type {other:?} for function ltrim"), + other => exec_err!( + "Unsupported data type {other:?} for function ltrim,\ + expected Utf8, LargeUtf8 or Utf8View." + ), } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_ltrim_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_ltrim_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Trims the specified trim string from the beginning of a string. If no trim string is provided, all whitespace is removed from the start of the input string.") + .with_syntax_example("ltrim(str[, trim_str])") + .with_sql_example(r#"```sql +> select ltrim(' datafusion '); ++-------------------------------+ +| ltrim(Utf8(" datafusion ")) | ++-------------------------------+ +| datafusion | ++-------------------------------+ +> select ltrim('___datafusion___', '_'); ++-------------------------------------------+ +| ltrim(Utf8("___datafusion___"),Utf8("_")) | ++-------------------------------------------+ +| datafusion___ | ++-------------------------------------------+ +```"#) + .with_standard_argument("str", Some("String")) + .with_argument("trim_str", "String expression to trim from the beginning of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. _Default is whitespace characters._") + .with_alternative_syntax("trim(LEADING trim_str FROM str)") + .with_related_udf("btrim") + .with_related_udf("rtrim") + .build() + .unwrap() + }) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray, StringViewArray}; + use arrow::datatypes::DataType::{Utf8, Utf8View}; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::string::ltrim::LtrimFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() { + // String view cases for checking normal logic + test_function!( + LtrimFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + String::from("alphabet ") + ))),], + Ok(Some("alphabet ")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + LtrimFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + String::from(" alphabet ") + ))),], + Ok(Some("alphabet ")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + LtrimFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "alphabet" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("t")))), + ], + Ok(Some("alphabet")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + LtrimFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "alphabet" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "alphabe" + )))), + ], + Ok(Some("t")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + LtrimFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "alphabet" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8View(None)), + ], + Ok(None), + &str, + Utf8View, + StringViewArray + ); + // Special string view case for checking unlined output(len > 12) + test_function!( + LtrimFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "xxxalphabetalphabet" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("x")))), + ], + Ok(Some("alphabetalphabet")), + &str, + Utf8View, + StringViewArray + ); + // String cases + test_function!( + LtrimFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("alphabet ") + ))),], + Ok(Some("alphabet ")), + &str, + Utf8, + StringArray + ); + test_function!( + LtrimFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("alphabet ") + ))),], + Ok(Some("alphabet ")), + &str, + Utf8, + StringArray + ); + test_function!( + LtrimFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("t")))), + ], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + LtrimFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabe")))), + ], + Ok(Some("t")), + &str, + Utf8, + StringArray + ); + test_function!( + LtrimFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + } } diff --git a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs index 9eb2a7426fba..622802f0142b 100644 --- a/datafusion/functions/src/string/mod.rs +++ b/datafusion/functions/src/string/mod.rs @@ -21,29 +21,29 @@ use std::sync::Arc; use datafusion_expr::ScalarUDF; -mod ascii; -mod bit_length; -mod btrim; -mod chr; -mod common; -mod concat; -mod concat_ws; -mod ends_with; -mod initcap; -mod levenshtein; -mod lower; -mod ltrim; -mod octet_length; -mod overlay; -mod repeat; -mod replace; -mod rtrim; -mod split_part; -mod starts_with; -mod to_hex; -mod upper; -mod uuid; - +pub mod ascii; +pub mod bit_length; +pub mod btrim; +pub mod chr; +pub mod common; +pub mod concat; +pub mod concat_ws; +pub mod contains; +pub mod ends_with; +pub mod initcap; +pub mod levenshtein; +pub mod lower; +pub mod ltrim; +pub mod octet_length; +pub mod overlay; +pub mod repeat; +pub mod replace; +pub mod rtrim; +pub mod split_part; +pub mod starts_with; +pub mod to_hex; +pub mod upper; +pub mod uuid; // create UDFs make_udf_function!(ascii::AsciiFunc, ASCII, ascii); make_udf_function!(bit_length::BitLengthFunc, BIT_LENGTH, bit_length); @@ -66,124 +66,108 @@ make_udf_function!(split_part::SplitPartFunc, SPLIT_PART, split_part); make_udf_function!(to_hex::ToHexFunc, TO_HEX, to_hex); make_udf_function!(upper::UpperFunc, UPPER, upper); make_udf_function!(uuid::UuidFunc, UUID, uuid); - +make_udf_function!(contains::ContainsFunc, CONTAINS, contains); pub mod expr_fn { use datafusion_expr::Expr; - #[doc = "Returns the numeric code of the first character of the argument."] - pub fn ascii(arg1: Expr) -> Expr { - super::ascii().call(vec![arg1]) - } - - #[doc = "Returns the number of bits in the `string`"] - pub fn bit_length(arg: Expr) -> Expr { - super::bit_length().call(vec![arg]) - } + export_functions!(( + ascii, + "Returns the numeric code of the first character of the argument.", + arg1 + ),( + bit_length, + "Returns the number of bits in the `string`", + arg1 + ),( + btrim, + "Removes all characters, spaces by default, from both sides of a string", + args, + ),( + chr, + "Converts the Unicode code point to a UTF8 character", + arg1 + ),( + concat, + "Concatenates the text representations of all the arguments. NULL arguments are ignored", + args, + ),( + ends_with, + "Returns true if the `string` ends with the `suffix`, false otherwise.", + string suffix + ),( + initcap, + "Converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase", + string + ),( + levenshtein, + "Returns the Levenshtein distance between the two given strings", + arg1 arg2 + ),( + lower, + "Converts a string to lowercase.", + arg1 + ),( + ltrim, + "Removes all characters, spaces by default, from the beginning of a string", + args, + ),( + octet_length, + "returns the number of bytes of a string", + args + ),( + overlay, + "replace the substring of string that starts at the start'th character and extends for count characters with new substring", + args, + ),( + repeat, + "Repeats the `string` to `n` times", + string n + ),( + replace, + "Replaces all occurrences of `from` with `to` in the `string`", + string from to + ),( + rtrim, + "Removes all characters, spaces by default, from the end of a string", + args, + ),( + split_part, + "Splits a string based on a delimiter and picks out the desired field based on the index.", + string delimiter index + ),( + starts_with, + "Returns true if string starts with prefix.", + arg1 arg2 + ),( + to_hex, + "Converts an integer to a hexadecimal string.", + arg1 + ),( + upper, + "Converts a string to uppercase.", + arg1 + ),( + uuid, + "returns uuid v4 as a string value", + ), ( + contains, + "Return true if search_string is found within string. treated it like a reglike", + )); #[doc = "Removes all characters, spaces by default, from both sides of a string"] - pub fn btrim(args: Vec) -> Expr { + pub fn trim(args: Vec) -> Expr { super::btrim().call(args) } - #[doc = "Converts the Unicode code point to a UTF8 character"] - pub fn chr(arg: Expr) -> Expr { - super::chr().call(vec![arg]) - } - - #[doc = "Concatenates the text representations of all the arguments. NULL arguments are ignored"] - pub fn concat(args: Vec) -> Expr { - super::concat().call(args) - } - #[doc = "Concatenates all but the first argument, with separators. The first argument is used as the separator string, and should not be NULL. Other NULL arguments are ignored."] pub fn concat_ws(delimiter: Expr, args: Vec) -> Expr { let mut args = args; args.insert(0, delimiter); super::concat_ws().call(args) } - - #[doc = "Returns true if the `string` ends with the `suffix`, false otherwise."] - pub fn ends_with(string: Expr, suffix: Expr) -> Expr { - super::ends_with().call(vec![string, suffix]) - } - - #[doc = "Converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"] - pub fn initcap(string: Expr) -> Expr { - super::initcap().call(vec![string]) - } - - #[doc = "Returns the Levenshtein distance between the two given strings"] - pub fn levenshtein(arg1: Expr, arg2: Expr) -> Expr { - super::levenshtein().call(vec![arg1, arg2]) - } - - #[doc = "Converts a string to lowercase."] - pub fn lower(arg1: Expr) -> Expr { - super::lower().call(vec![arg1]) - } - - #[doc = "Removes all characters, spaces by default, from the beginning of a string"] - pub fn ltrim(args: Vec) -> Expr { - super::ltrim().call(args) - } - - #[doc = "returns the number of bytes of a string"] - pub fn octet_length(args: Vec) -> Expr { - super::octet_length().call(args) - } - - #[doc = "replace the substring of string that starts at the start'th character and extends for count characters with new substring"] - pub fn overlay(args: Vec) -> Expr { - super::overlay().call(args) - } - - #[doc = "Repeats the `string` to `n` times"] - pub fn repeat(string: Expr, n: Expr) -> Expr { - super::repeat().call(vec![string, n]) - } - - #[doc = "Replaces all occurrences of `from` with `to` in the `string`"] - pub fn replace(string: Expr, from: Expr, to: Expr) -> Expr { - super::replace().call(vec![string, from, to]) - } - - #[doc = "Removes all characters, spaces by default, from the end of a string"] - pub fn rtrim(args: Vec) -> Expr { - super::rtrim().call(args) - } - - #[doc = "Splits a string based on a delimiter and picks out the desired field based on the index."] - pub fn split_part(string: Expr, delimiter: Expr, index: Expr) -> Expr { - super::split_part().call(vec![string, delimiter, index]) - } - - #[doc = "Returns true if string starts with prefix."] - pub fn starts_with(arg1: Expr, arg2: Expr) -> Expr { - super::starts_with().call(vec![arg1, arg2]) - } - - #[doc = "Converts an integer to a hexadecimal string."] - pub fn to_hex(arg1: Expr) -> Expr { - super::to_hex().call(vec![arg1]) - } - - #[doc = "Removes all characters, spaces by default, from both sides of a string"] - pub fn trim(args: Vec) -> Expr { - super::btrim().call(args) - } - - #[doc = "Converts a string to uppercase."] - pub fn upper(arg1: Expr) -> Expr { - super::upper().call(vec![arg1]) - } - - #[doc = "returns uuid v4 as a string value"] - pub fn uuid() -> Expr { - super::uuid().call(vec![]) - } } -/// Return a list of all functions in this package +/// Returns all DataFusion functions defined in this package pub fn functions() -> Vec> { vec![ ascii(), @@ -198,7 +182,6 @@ pub fn functions() -> Vec> { lower(), ltrim(), octet_length(), - overlay(), repeat(), replace(), rtrim(), @@ -207,5 +190,6 @@ pub fn functions() -> Vec> { to_hex(), upper(), uuid(), + contains(), ] } diff --git a/datafusion/functions/src/string/octet_length.rs b/datafusion/functions/src/string/octet_length.rs index bdd262b7e37e..2ac2bf70da23 100644 --- a/datafusion/functions/src/string/octet_length.rs +++ b/datafusion/functions/src/string/octet_length.rs @@ -15,31 +15,32 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; - use arrow::compute::kernels::length::length; use arrow::datatypes::DataType; +use std::any::Any; +use std::sync::OnceLock; +use crate::utils::utf8_to_int_type; use datafusion_common::{exec_err, Result, ScalarValue}; -use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::utils::utf8_to_int_type; - #[derive(Debug)] pub struct OctetLengthFunc { signature: Signature, } +impl Default for OctetLengthFunc { + fn default() -> Self { + Self::new() + } +} + impl OctetLengthFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::uniform( - 1, - vec![Utf8, LargeUtf8], - Volatility::Immutable, - ), + signature: Signature::string(1, Volatility::Immutable), } } } @@ -78,10 +79,43 @@ impl ScalarUDFImpl for OctetLengthFunc { ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar( ScalarValue::Int64(v.as_ref().map(|x| x.len() as i64)), )), + ScalarValue::Utf8View(v) => Ok(ColumnarValue::Scalar( + ScalarValue::Int32(v.as_ref().map(|x| x.len() as i32)), + )), _ => unreachable!(), }, } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_octet_length_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_octet_length_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Returns the length of a string in bytes.") + .with_syntax_example("octet_length(str)") + .with_sql_example( + r#"```sql +> select octet_length('Ångström'); ++--------------------------------+ +| octet_length(Utf8("Ångström")) | ++--------------------------------+ +| 10 | ++--------------------------------+ +```"#, + ) + .with_standard_argument("str", Some("String")) + .with_related_udf("bit_length") + .with_related_udf("length") + .build() + .unwrap() + }) } #[cfg(test)] @@ -170,6 +204,36 @@ mod tests { Int32, Int32Array ); + test_function!( + OctetLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + String::from("joséjoséjoséjosé") + )))], + Ok(Some(20)), + i32, + Int32, + Int32Array + ); + test_function!( + OctetLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + String::from("josé") + )))], + Ok(Some(5)), + i32, + Int32, + Int32Array + ); + test_function!( + OctetLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + String::from("") + )))], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); Ok(()) } diff --git a/datafusion/functions/src/string/overlay.rs b/datafusion/functions/src/string/overlay.rs index 3f92a73c1af9..796776304f4a 100644 --- a/datafusion/functions/src/string/overlay.rs +++ b/datafusion/functions/src/string/overlay.rs @@ -16,34 +16,43 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; use arrow::datatypes::DataType; -use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use crate::utils::{make_scalar_function, utf8_to_str_type}; +use datafusion_common::cast::{ + as_generic_string_array, as_int64_array, as_string_view_array, +}; use datafusion_common::{exec_err, Result}; -use datafusion_expr::TypeSignature::*; -use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::utils::{make_scalar_function, utf8_to_str_type}; - #[derive(Debug)] pub struct OverlayFunc { signature: Signature, } +impl Default for OverlayFunc { + fn default() -> Self { + Self::new() + } +} + impl OverlayFunc { pub fn new() -> Self { use DataType::*; Self { signature: Signature::one_of( vec![ - Exact(vec![Utf8, Utf8, Int64, Int64]), - Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]), - Exact(vec![Utf8, Utf8, Int64]), - Exact(vec![LargeUtf8, LargeUtf8, Int64]), + TypeSignature::Exact(vec![Utf8View, Utf8View, Int64, Int64]), + TypeSignature::Exact(vec![Utf8, Utf8, Int64, Int64]), + TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]), + TypeSignature::Exact(vec![Utf8View, Utf8View, Int64]), + TypeSignature::Exact(vec![Utf8, Utf8, Int64]), + TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Int64]), ], Volatility::Immutable, ), @@ -70,54 +79,136 @@ impl ScalarUDFImpl for OverlayFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { match args[0].data_type() { - DataType::Utf8 => make_scalar_function(overlay::, vec![])(args), + DataType::Utf8View | DataType::Utf8 => { + make_scalar_function(overlay::, vec![])(args) + } DataType::LargeUtf8 => make_scalar_function(overlay::, vec![])(args), other => exec_err!("Unsupported data type {other:?} for function overlay"), } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_overlay_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_overlay_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Returns the string which is replaced by another string from the specified position and specified count length.") + .with_syntax_example("overlay(str PLACING substr FROM pos [FOR count])") + .with_sql_example(r#"```sql +> select overlay('Txxxxas' placing 'hom' from 2 for 4); ++--------------------------------------------------------+ +| overlay(Utf8("Txxxxas"),Utf8("hom"),Int64(2),Int64(4)) | ++--------------------------------------------------------+ +| Thomas | ++--------------------------------------------------------+ +```"#) + .with_standard_argument("str", Some("String")) + .with_argument("substr", "Substring to replace in str.") + .with_argument("pos", "The start position to start the replace in str.") + .with_argument("count", "The count of characters to be replaced from start position of str. If not specified, will use substr length instead.") + .build() + .unwrap() + }) +} + +macro_rules! process_overlay { + // For the three-argument case + ($string_array:expr, $characters_array:expr, $pos_num:expr) => {{ + $string_array + .iter() + .zip($characters_array.iter()) + .zip($pos_num.iter()) + .map(|((string, characters), start_pos)| { + match (string, characters, start_pos) { + (Some(string), Some(characters), Some(start_pos)) => { + let string_len = string.chars().count(); + let characters_len = characters.chars().count(); + let replace_len = characters_len as i64; + let mut res = + String::with_capacity(string_len.max(characters_len)); + + //as sql replace index start from 1 while string index start from 0 + if start_pos > 1 && start_pos - 1 < string_len as i64 { + let start = (start_pos - 1) as usize; + res.push_str(&string[..start]); + } + res.push_str(characters); + // if start + replace_len - 1 >= string_length, just to string end + if start_pos + replace_len - 1 < string_len as i64 { + let end = (start_pos + replace_len - 1) as usize; + res.push_str(&string[end..]); + } + Ok(Some(res)) + } + _ => Ok(None), + } + }) + .collect::>>() + }}; + + // For the four-argument case + ($string_array:expr, $characters_array:expr, $pos_num:expr, $len_num:expr) => {{ + $string_array + .iter() + .zip($characters_array.iter()) + .zip($pos_num.iter()) + .zip($len_num.iter()) + .map(|(((string, characters), start_pos), len)| { + match (string, characters, start_pos, len) { + (Some(string), Some(characters), Some(start_pos), Some(len)) => { + let string_len = string.chars().count(); + let characters_len = characters.chars().count(); + let replace_len = len.min(string_len as i64); + let mut res = + String::with_capacity(string_len.max(characters_len)); + + //as sql replace index start from 1 while string index start from 0 + if start_pos > 1 && start_pos - 1 < string_len as i64 { + let start = (start_pos - 1) as usize; + res.push_str(&string[..start]); + } + res.push_str(characters); + // if start + replace_len - 1 >= string_length, just to string end + if start_pos + replace_len - 1 < string_len as i64 { + let end = (start_pos + replace_len - 1) as usize; + res.push_str(&string[end..]); + } + Ok(Some(res)) + } + _ => Ok(None), + } + }) + .collect::>>() + }}; } /// OVERLAY(string1 PLACING string2 FROM integer FOR integer2) /// Replaces a substring of string1 with string2 starting at the integer bit /// pgsql overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas /// overlay('Txxxxas' placing 'hom' from 2) -> Thomxas, without for option, str2's len is instead -pub fn overlay(args: &[ArrayRef]) -> Result { +fn overlay(args: &[ArrayRef]) -> Result { + let use_string_view = args[0].data_type() == &DataType::Utf8View; + if use_string_view { + string_view_overlay::(args) + } else { + string_overlay::(args) + } +} + +pub fn string_overlay(args: &[ArrayRef]) -> Result { match args.len() { 3 => { let string_array = as_generic_string_array::(&args[0])?; let characters_array = as_generic_string_array::(&args[1])?; let pos_num = as_int64_array(&args[2])?; - let result = string_array - .iter() - .zip(characters_array.iter()) - .zip(pos_num.iter()) - .map(|((string, characters), start_pos)| { - match (string, characters, start_pos) { - (Some(string), Some(characters), Some(start_pos)) => { - let string_len = string.chars().count(); - let characters_len = characters.chars().count(); - let replace_len = characters_len as i64; - let mut res = - String::with_capacity(string_len.max(characters_len)); - - //as sql replace index start from 1 while string index start from 0 - if start_pos > 1 && start_pos - 1 < string_len as i64 { - let start = (start_pos - 1) as usize; - res.push_str(&string[..start]); - } - res.push_str(characters); - // if start + replace_len - 1 >= string_length, just to string end - if start_pos + replace_len - 1 < string_len as i64 { - let end = (start_pos + replace_len - 1) as usize; - res.push_str(&string[end..]); - } - Ok(Some(res)) - } - _ => Ok(None), - } - }) - .collect::>>()?; + let result = process_overlay!(string_array, characters_array, pos_num)?; Ok(Arc::new(result) as ArrayRef) } 4 => { @@ -126,37 +217,34 @@ pub fn overlay(args: &[ArrayRef]) -> Result { let pos_num = as_int64_array(&args[2])?; let len_num = as_int64_array(&args[3])?; - let result = string_array - .iter() - .zip(characters_array.iter()) - .zip(pos_num.iter()) - .zip(len_num.iter()) - .map(|(((string, characters), start_pos), len)| { - match (string, characters, start_pos, len) { - (Some(string), Some(characters), Some(start_pos), Some(len)) => { - let string_len = string.chars().count(); - let characters_len = characters.chars().count(); - let replace_len = len.min(string_len as i64); - let mut res = - String::with_capacity(string_len.max(characters_len)); - - //as sql replace index start from 1 while string index start from 0 - if start_pos > 1 && start_pos - 1 < string_len as i64 { - let start = (start_pos - 1) as usize; - res.push_str(&string[..start]); - } - res.push_str(characters); - // if start + replace_len - 1 >= string_length, just to string end - if start_pos + replace_len - 1 < string_len as i64 { - let end = (start_pos + replace_len - 1) as usize; - res.push_str(&string[end..]); - } - Ok(Some(res)) - } - _ => Ok(None), - } - }) - .collect::>>()?; + let result = + process_overlay!(string_array, characters_array, pos_num, len_num)?; + Ok(Arc::new(result) as ArrayRef) + } + other => { + exec_err!("overlay was called with {other} arguments. It requires 3 or 4.") + } + } +} + +pub fn string_view_overlay(args: &[ArrayRef]) -> Result { + match args.len() { + 3 => { + let string_array = as_string_view_array(&args[0])?; + let characters_array = as_string_view_array(&args[1])?; + let pos_num = as_int64_array(&args[2])?; + + let result = process_overlay!(string_array, characters_array, pos_num)?; + Ok(Arc::new(result) as ArrayRef) + } + 4 => { + let string_array = as_string_view_array(&args[0])?; + let characters_array = as_string_view_array(&args[1])?; + let pos_num = as_int64_array(&args[2])?; + let len_num = as_int64_array(&args[3])?; + + let result = + process_overlay!(string_array, characters_array, pos_num, len_num)?; Ok(Arc::new(result) as ArrayRef) } other => { diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index 77521120d9d8..aa69f9c6609a 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -16,30 +16,45 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use crate::strings::StringArrayType; +use crate::utils::{make_scalar_function, utf8_to_str_type}; +use arrow::array::{ + ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array, + OffsetSizeTrait, StringViewArray, +}; use arrow::datatypes::DataType; - -use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use arrow::datatypes::DataType::{Int64, LargeUtf8, Utf8, Utf8View}; +use datafusion_common::cast::as_int64_array; use datafusion_common::{exec_err, Result}; -use datafusion_expr::TypeSignature::*; -use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::utils::{make_scalar_function, utf8_to_str_type}; - #[derive(Debug)] pub struct RepeatFunc { signature: Signature, } +impl Default for RepeatFunc { + fn default() -> Self { + Self::new() + } +} + impl RepeatFunc { pub fn new() -> Self { - use DataType::*; Self { signature: Signature::one_of( - vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])], + vec![ + // Planner attempts coercion to the target type starting with the most preferred candidate. + // For example, given input `(Utf8View, Int64)`, it first tries coercing to `(Utf8View, Int64)`. + // If that fails, it proceeds to `(Utf8, Int64)`. + TypeSignature::Exact(vec![Utf8View, Int64]), + TypeSignature::Exact(vec![Utf8, Int64]), + TypeSignature::Exact(vec![LargeUtf8, Int64]), + ], Volatility::Immutable, ), } @@ -64,30 +79,84 @@ impl ScalarUDFImpl for RepeatFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - match args[0].data_type() { - DataType::Utf8 => make_scalar_function(repeat::, vec![])(args), - DataType::LargeUtf8 => make_scalar_function(repeat::, vec![])(args), - other => exec_err!("Unsupported data type {other:?} for function repeat"), - } + make_scalar_function(repeat, vec![])(args) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_repeat_doc()) } } +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_repeat_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description( + "Returns a string with an input string repeated a specified number.", + ) + .with_syntax_example("repeat(str, n)") + .with_sql_example( + r#"```sql +> select repeat('data', 3); ++-------------------------------+ +| repeat(Utf8("data"),Int64(3)) | ++-------------------------------+ +| datadatadata | ++-------------------------------+ +```"#, + ) + .with_standard_argument("str", Some("String")) + .with_argument("n", "Number of times to repeat the input string.") + .build() + .unwrap() + }) +} + /// Repeats string the specified number of times. /// repeat('Pg', 4) = 'PgPgPgPg' -fn repeat(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; +fn repeat(args: &[ArrayRef]) -> Result { let number_array = as_int64_array(&args[1])?; + match args[0].data_type() { + Utf8View => { + let string_view_array = args[0].as_string_view(); + repeat_impl::(string_view_array, number_array) + } + Utf8 => { + let string_array = args[0].as_string::(); + repeat_impl::>(string_array, number_array) + } + LargeUtf8 => { + let string_array = args[0].as_string::(); + repeat_impl::>(string_array, number_array) + } + other => exec_err!( + "Unsupported data type {other:?} for function repeat. \ + Expected Utf8, Utf8View or LargeUtf8." + ), + } +} - let result = string_array +fn repeat_impl<'a, T, S>(string_array: S, number_array: &Int64Array) -> Result +where + T: OffsetSizeTrait, + S: StringArrayType<'a>, +{ + let mut builder: GenericStringBuilder = GenericStringBuilder::new(); + string_array .iter() .zip(number_array.iter()) - .map(|(string, number)| match (string, number) { - (Some(string), Some(number)) => Some(string.repeat(number as usize)), - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) + .for_each(|(string, number)| match (string, number) { + (Some(string), Some(number)) if number >= 0 => { + builder.append_value(string.repeat(number as usize)) + } + (Some(_), Some(_)) => builder.append_value(""), + _ => builder.append_null(), + }); + let array = builder.finish(); + + Ok(Arc::new(array) as ArrayRef) } #[cfg(test)] @@ -115,7 +184,6 @@ mod tests { Utf8, StringArray ); - test_function!( RepeatFunc::new(), &[ @@ -139,6 +207,40 @@ mod tests { StringArray ); + test_function!( + RepeatFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), + ], + Ok(Some("PgPgPgPg")), + &str, + Utf8, + StringArray + ); + test_function!( + RepeatFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(None)), + ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + RepeatFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + Ok(()) } } diff --git a/datafusion/functions/src/string/replace.rs b/datafusion/functions/src/string/replace.rs index 01a3762acaf4..91abc39da058 100644 --- a/datafusion/functions/src/string/replace.rs +++ b/datafusion/functions/src/string/replace.rs @@ -16,32 +16,33 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, StringArray}; use arrow::datatypes::DataType; -use datafusion_common::cast::as_generic_string_array; +use crate::utils::{make_scalar_function, utf8_to_str_type}; +use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; use datafusion_common::{exec_err, Result}; -use datafusion_expr::TypeSignature::*; -use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::utils::{make_scalar_function, utf8_to_str_type}; - #[derive(Debug)] pub struct ReplaceFunc { signature: Signature, } +impl Default for ReplaceFunc { + fn default() -> Self { + Self::new() + } +} + impl ReplaceFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::one_of( - vec![Exact(vec![Utf8, Utf8, Utf8])], - Volatility::Immutable, - ), + signature: Signature::string(3, Volatility::Immutable), } } } @@ -67,13 +68,59 @@ impl ScalarUDFImpl for ReplaceFunc { match args[0].data_type() { DataType::Utf8 => make_scalar_function(replace::, vec![])(args), DataType::LargeUtf8 => make_scalar_function(replace::, vec![])(args), + DataType::Utf8View => make_scalar_function(replace_view, vec![])(args), other => { exec_err!("Unsupported data type {other:?} for function replace") } } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_replace_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_replace_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Replaces all occurrences of a specified substring in a string with a new substring.") + .with_syntax_example("replace(str, substr, replacement)") + .with_sql_example(r#"```sql +> select replace('ABabbaBA', 'ab', 'cd'); ++-------------------------------------------------+ +| replace(Utf8("ABabbaBA"),Utf8("ab"),Utf8("cd")) | ++-------------------------------------------------+ +| ABcdbaBA | ++-------------------------------------------------+ +```"#) + .with_standard_argument("str", Some("String")) + .with_standard_argument("substr", Some("Substring expression to replace in the input string. Substring")) + .with_standard_argument("replacement", Some("Replacement substring")) + .build() + .unwrap() + }) } +fn replace_view(args: &[ArrayRef]) -> Result { + let string_array = as_string_view_array(&args[0])?; + let from_array = as_string_view_array(&args[1])?; + let to_array = as_string_view_array(&args[2])?; + + let result = string_array + .iter() + .zip(from_array.iter()) + .zip(to_array.iter()) + .map(|((string, from), to)| match (string, from, to) { + (Some(string), Some(from), Some(to)) => Some(string.replace(from, to)), + _ => None, + }) + .collect::(); + + Ok(Arc::new(result) as ArrayRef) +} /// Replaces all occurrences in string of substring from with substring to. /// replace('abcdefabcdef', 'cd', 'XX') = 'abXXefabXXef' fn replace(args: &[ArrayRef]) -> Result { @@ -94,4 +141,60 @@ fn replace(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } -mod test {} +#[cfg(test)] +mod tests { + use super::*; + use crate::utils::test::test_function; + use arrow::array::Array; + use arrow::array::LargeStringArray; + use arrow::array::StringArray; + use arrow::datatypes::DataType::{LargeUtf8, Utf8}; + use datafusion_common::ScalarValue; + #[test] + fn test_functions() -> Result<()> { + test_function!( + ReplaceFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("aabbdqcbb")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("bb")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ccc")))), + ], + Ok(Some("aacccdqcccc")), + &str, + Utf8, + StringArray + ); + + test_function!( + ReplaceFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from( + "aabbb" + )))), + ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("bbb")))), + ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("cc")))), + ], + Ok(Some("aacc")), + &str, + LargeUtf8, + LargeStringArray + ); + + test_function!( + ReplaceFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "aabbbcw" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("bb")))), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("cc")))), + ], + Ok(Some("aaccbcw")), + &str, + Utf8, + StringArray + ); + + Ok(()) + } +} diff --git a/datafusion/functions/src/string/rtrim.rs b/datafusion/functions/src/string/rtrim.rs index e6e93e38c966..e934147efbbe 100644 --- a/datafusion/functions/src/string/rtrim.rs +++ b/datafusion/functions/src/string/rtrim.rs @@ -16,23 +16,23 @@ // under the License. use arrow::array::{ArrayRef, OffsetSizeTrait}; -use std::any::Any; - use arrow::datatypes::DataType; - -use datafusion_common::{exec_err, Result}; -use datafusion_expr::TypeSignature::*; -use datafusion_expr::{ColumnarValue, Volatility}; -use datafusion_expr::{ScalarUDFImpl, Signature}; -use datafusion_physical_expr::functions::Hint; +use std::any::Any; +use std::sync::OnceLock; use crate::string::common::*; use crate::utils::{make_scalar_function, utf8_to_str_type}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::function::Hint; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; /// Returns the longest string with trailing characters removed. If the characters are not specified, whitespace is removed. /// rtrim('testxxzx', 'xyz') = 'test' fn rtrim(args: &[ArrayRef]) -> Result { - general_trim::(args, TrimType::Right) + let use_string_view = args[0].data_type() == &DataType::Utf8View; + general_trim::(args, TrimType::Right, use_string_view) } #[derive(Debug)] @@ -40,12 +40,17 @@ pub struct RtrimFunc { signature: Signature, } +impl Default for RtrimFunc { + fn default() -> Self { + Self::new() + } +} + impl RtrimFunc { pub fn new() -> Self { - use DataType::*; Self { signature: Signature::one_of( - vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])], + vec![TypeSignature::String(2), TypeSignature::String(1)], Volatility::Immutable, ), } @@ -66,12 +71,16 @@ impl ScalarUDFImpl for RtrimFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - utf8_to_str_type(&arg_types[0], "rtrim") + if arg_types[0] == DataType::Utf8View { + Ok(DataType::Utf8View) + } else { + utf8_to_str_type(&arg_types[0], "rtrim") + } } fn invoke(&self, args: &[ColumnarValue]) -> Result { match args[0].data_type() { - DataType::Utf8 => make_scalar_function( + DataType::Utf8 | DataType::Utf8View => make_scalar_function( rtrim::, vec![Hint::Pad, Hint::AcceptsSingular], )(args), @@ -79,7 +88,192 @@ impl ScalarUDFImpl for RtrimFunc { rtrim::, vec![Hint::Pad, Hint::AcceptsSingular], )(args), - other => exec_err!("Unsupported data type {other:?} for function rtrim"), + other => exec_err!( + "Unsupported data type {other:?} for function rtrim,\ + expected Utf8, LargeUtf8 or Utf8View." + ), } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_rtrim_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_rtrim_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Trims the specified trim string from the end of a string. If no trim string is provided, all whitespace is removed from the end of the input string.") + .with_syntax_example("rtrim(str[, trim_str])") + .with_sql_example(r#"```sql +> select rtrim(' datafusion '); ++-------------------------------+ +| rtrim(Utf8(" datafusion ")) | ++-------------------------------+ +| datafusion | ++-------------------------------+ +> select rtrim('___datafusion___', '_'); ++-------------------------------------------+ +| rtrim(Utf8("___datafusion___"),Utf8("_")) | ++-------------------------------------------+ +| ___datafusion | ++-------------------------------------------+ +```"#) + .with_standard_argument("str", Some("String")) + .with_argument("trim_str", "String expression to trim from the end of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. _Default is whitespace characters._") + .with_alternative_syntax("trim(TRAILING trim_str FROM str)") + .with_related_udf("btrim") + .with_related_udf("ltrim") + .build() + .unwrap() + }) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray, StringViewArray}; + use arrow::datatypes::DataType::{Utf8, Utf8View}; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::string::rtrim::RtrimFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() { + // String view cases for checking normal logic + test_function!( + RtrimFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + String::from("alphabet ") + ))),], + Ok(Some("alphabet")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + RtrimFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( + String::from(" alphabet ") + ))),], + Ok(Some(" alphabet")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + RtrimFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "alphabet" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("t ")))), + ], + Ok(Some("alphabe")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + RtrimFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "alphabet" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "alphabe" + )))), + ], + Ok(Some("alphabet")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + RtrimFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "alphabet" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8View(None)), + ], + Ok(None), + &str, + Utf8View, + StringViewArray + ); + // Special string view case for checking unlined output(len > 12) + test_function!( + RtrimFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "alphabetalphabetxxx" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("x")))), + ], + Ok(Some("alphabetalphabet")), + &str, + Utf8View, + StringViewArray + ); + // String cases + test_function!( + RtrimFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("alphabet ") + ))),], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + RtrimFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from(" alphabet ") + ))),], + Ok(Some(" alphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + RtrimFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("t ")))), + ], + Ok(Some("alphabe")), + &str, + Utf8, + StringArray + ); + test_function!( + RtrimFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabe")))), + ], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + RtrimFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + } } diff --git a/datafusion/functions/src/string/split_part.rs b/datafusion/functions/src/string/split_part.rs index 4396386afff5..ea01cb1f56f9 100644 --- a/datafusion/functions/src/string/split_part.rs +++ b/datafusion/functions/src/string/split_part.rs @@ -15,35 +15,48 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; -use std::sync::Arc; - -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use crate::strings::StringArrayType; +use crate::utils::utf8_to_str_type; +use arrow::array::{ + ArrayRef, GenericStringArray, Int64Array, OffsetSizeTrait, StringViewArray, +}; +use arrow::array::{AsArray, GenericStringBuilder}; use arrow::datatypes::DataType; - -use datafusion_common::cast::{as_generic_string_array, as_int64_array}; -use datafusion_common::{exec_err, Result}; -use datafusion_expr::TypeSignature::*; -use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_common::cast::as_int64_array; +use datafusion_common::ScalarValue; +use datafusion_common::{exec_err, DataFusionError, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; - -use crate::utils::{make_scalar_function, utf8_to_str_type}; +use std::any::Any; +use std::sync::{Arc, OnceLock}; #[derive(Debug)] pub struct SplitPartFunc { signature: Signature, } +impl Default for SplitPartFunc { + fn default() -> Self { + Self::new() + } +} + impl SplitPartFunc { pub fn new() -> Self { use DataType::*; Self { signature: Signature::one_of( vec![ - Exact(vec![Utf8, Utf8, Int64]), - Exact(vec![LargeUtf8, Utf8, Int64]), - Exact(vec![Utf8, LargeUtf8, Int64]), - Exact(vec![LargeUtf8, LargeUtf8, Int64]), + TypeSignature::Exact(vec![Utf8View, Utf8View, Int64]), + TypeSignature::Exact(vec![Utf8View, Utf8, Int64]), + TypeSignature::Exact(vec![Utf8View, LargeUtf8, Int64]), + TypeSignature::Exact(vec![Utf8, Utf8View, Int64]), + TypeSignature::Exact(vec![Utf8, Utf8, Int64]), + TypeSignature::Exact(vec![LargeUtf8, Utf8View, Int64]), + TypeSignature::Exact(vec![LargeUtf8, Utf8, Int64]), + TypeSignature::Exact(vec![Utf8, LargeUtf8, Int64]), + TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Int64]), ], Volatility::Immutable, ), @@ -69,43 +82,173 @@ impl ScalarUDFImpl for SplitPartFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - match args[0].data_type() { - DataType::Utf8 => make_scalar_function(split_part::, vec![])(args), - DataType::LargeUtf8 => make_scalar_function(split_part::, vec![])(args), - other => { - exec_err!("Unsupported data type {other:?} for function split_part") + // First, determine if any of the arguments is an Array + let len = args.iter().find_map(|arg| match arg { + ColumnarValue::Array(a) => Some(a.len()), + _ => None, + }); + + let inferred_length = len.unwrap_or(1); + let is_scalar = len.is_none(); + + // Convert all ColumnarValues to ArrayRefs + let args = args + .iter() + .map(|arg| match arg { + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(inferred_length), + ColumnarValue::Array(array) => Ok(Arc::clone(array)), + }) + .collect::>>()?; + + // Unpack the ArrayRefs from the arguments + let n_array = as_int64_array(&args[2])?; + let result = match (args[0].data_type(), args[1].data_type()) { + (DataType::Utf8View, DataType::Utf8View) => { + split_part_impl::<&StringViewArray, &StringViewArray, i32>( + args[0].as_string_view(), + args[1].as_string_view(), + n_array, + ) + } + (DataType::Utf8View, DataType::Utf8) => { + split_part_impl::<&StringViewArray, &GenericStringArray, i32>( + args[0].as_string_view(), + args[1].as_string::(), + n_array, + ) + } + (DataType::Utf8View, DataType::LargeUtf8) => { + split_part_impl::<&StringViewArray, &GenericStringArray, i32>( + args[0].as_string_view(), + args[1].as_string::(), + n_array, + ) + } + (DataType::Utf8, DataType::Utf8View) => { + split_part_impl::<&GenericStringArray, &StringViewArray, i32>( + args[0].as_string::(), + args[1].as_string_view(), + n_array, + ) + } + (DataType::LargeUtf8, DataType::Utf8View) => { + split_part_impl::<&GenericStringArray, &StringViewArray, i64>( + args[0].as_string::(), + args[1].as_string_view(), + n_array, + ) } + (DataType::Utf8, DataType::Utf8) => { + split_part_impl::<&GenericStringArray, &GenericStringArray, i32>( + args[0].as_string::(), + args[1].as_string::(), + n_array, + ) + } + (DataType::LargeUtf8, DataType::LargeUtf8) => { + split_part_impl::<&GenericStringArray, &GenericStringArray, i64>( + args[0].as_string::(), + args[1].as_string::(), + n_array, + ) + } + (DataType::Utf8, DataType::LargeUtf8) => { + split_part_impl::<&GenericStringArray, &GenericStringArray, i32>( + args[0].as_string::(), + args[1].as_string::(), + n_array, + ) + } + (DataType::LargeUtf8, DataType::Utf8) => { + split_part_impl::<&GenericStringArray, &GenericStringArray, i64>( + args[0].as_string::(), + args[1].as_string::(), + n_array, + ) + } + _ => exec_err!("Unsupported combination of argument types for split_part"), + }; + if is_scalar { + // If all inputs are scalar, keep the output as scalar + let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + result.map(ColumnarValue::Scalar) + } else { + result.map(ColumnarValue::Array) } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_split_part_doc()) + } } -/// Splits string at occurrences of delimiter and returns the n'th field (counting from one). -/// split_part('abc~@~def~@~ghi', '~@~', 2) = 'def' -fn split_part(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - let delimiter_array = as_generic_string_array::(&args[1])?; - let n_array = as_int64_array(&args[2])?; - let result = string_array +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_split_part_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Splits a string based on a specified delimiter and returns the substring in the specified position.") + .with_syntax_example("split_part(str, delimiter, pos)") + .with_sql_example(r#"```sql +> select split_part('1.2.3.4.5', '.', 3); ++--------------------------------------------------+ +| split_part(Utf8("1.2.3.4.5"),Utf8("."),Int64(3)) | ++--------------------------------------------------+ +| 3 | ++--------------------------------------------------+ +```"#) + .with_standard_argument("str", Some("String")) + .with_argument("delimiter", "String or character to split on.") + .with_argument("pos", "Position of the part to return.") + .build() + .unwrap() + }) +} + +/// impl +pub fn split_part_impl<'a, StringArrType, DelimiterArrType, StringArrayLen>( + string_array: StringArrType, + delimiter_array: DelimiterArrType, + n_array: &Int64Array, +) -> Result +where + StringArrType: StringArrayType<'a>, + DelimiterArrType: StringArrayType<'a>, + StringArrayLen: OffsetSizeTrait, +{ + let mut builder: GenericStringBuilder = GenericStringBuilder::new(); + + string_array .iter() .zip(delimiter_array.iter()) .zip(n_array.iter()) - .map(|((string, delimiter), n)| match (string, delimiter, n) { - (Some(string), Some(delimiter), Some(n)) => { - if n <= 0 { - exec_err!("field position must be greater than zero") - } else { + .try_for_each(|((string, delimiter), n)| -> Result<(), DataFusionError> { + match (string, delimiter, n) { + (Some(string), Some(delimiter), Some(n)) => { let split_string: Vec<&str> = string.split(delimiter).collect(); - match split_string.get(n as usize - 1) { - Some(s) => Ok(Some(*s)), - None => Ok(Some("")), + let len = split_string.len(); + + let index = match n.cmp(&0) { + std::cmp::Ordering::Less => len as i64 + n, + std::cmp::Ordering::Equal => { + return exec_err!("field position must not be zero"); + } + std::cmp::Ordering::Greater => n - 1, + } as usize; + + if index < len { + builder.append_value(split_string[index]); + } else { + builder.append_value(""); } } + _ => builder.append_null(), } - _ => Ok(None), - }) - .collect::>>()?; + Ok(()) + })?; - Ok(Arc::new(result) as ArrayRef) + Ok(Arc::new(builder.finish()) as ArrayRef) } #[cfg(test)] @@ -159,7 +302,21 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))), ColumnarValue::Scalar(ScalarValue::Int64(Some(-1))), ], - exec_err!("field position must be greater than zero"), + Ok(Some("ghi")), + &str, + Utf8, + StringArray + ); + test_function!( + SplitPartFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( + "abc~@~def~@~ghi" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(0))), + ], + exec_err!("field position must not be zero"), &str, Utf8, StringArray diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs index edbf5c9217a7..dce161a2e14b 100644 --- a/datafusion/functions/src/string/starts_with.rs +++ b/datafusion/functions/src/string/starts_with.rs @@ -16,26 +16,21 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; -use arrow::array::{ArrayRef, OffsetSizeTrait}; +use arrow::array::ArrayRef; use arrow::datatypes::DataType; -use datafusion_common::{cast::as_generic_string_array, internal_err, Result}; -use datafusion_expr::ColumnarValue; -use datafusion_expr::TypeSignature::*; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; - use crate::utils::make_scalar_function; +use datafusion_common::{internal_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ColumnarValue, Documentation}; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; /// Returns true if string starts with prefix. /// starts_with('alphabet', 'alph') = 't' -pub fn starts_with(args: &[ArrayRef]) -> Result { - let left = as_generic_string_array::(&args[0])?; - let right = as_generic_string_array::(&args[1])?; - - let result = arrow::compute::kernels::comparison::starts_with(left, right)?; - +pub fn starts_with(args: &[ArrayRef]) -> Result { + let result = arrow::compute::kernels::comparison::starts_with(&args[0], &args[1])?; Ok(Arc::new(result) as ArrayRef) } @@ -43,19 +38,17 @@ pub fn starts_with(args: &[ArrayRef]) -> Result { pub struct StartsWithFunc { signature: Signature, } + +impl Default for StartsWithFunc { + fn default() -> Self { + Self::new() + } +} + impl StartsWithFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8]), - Exact(vec![Utf8, LargeUtf8]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![LargeUtf8, LargeUtf8]), - ], - Volatility::Immutable, - ), + signature: Signature::string(2, Volatility::Immutable), } } } @@ -74,18 +67,102 @@ impl ScalarUDFImpl for StartsWithFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - use DataType::*; - - Ok(Boolean) + Ok(DataType::Boolean) } fn invoke(&self, args: &[ColumnarValue]) -> Result { match args[0].data_type() { - DataType::Utf8 => make_scalar_function(starts_with::, vec![])(args), - DataType::LargeUtf8 => { - return make_scalar_function(starts_with::, vec![])(args); + DataType::Utf8View | DataType::Utf8 | DataType::LargeUtf8 => { + make_scalar_function(starts_with, vec![])(args) } - _ => internal_err!("Unsupported data type"), + _ => internal_err!("Unsupported data types for starts_with. Expected Utf8, LargeUtf8 or Utf8View")?, + } + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_starts_with_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_starts_with_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Tests if a string starts with a substring.") + .with_syntax_example("starts_with(str, substr)") + .with_sql_example( + r#"```sql +> select starts_with('datafusion','data'); ++----------------------------------------------+ +| starts_with(Utf8("datafusion"),Utf8("data")) | ++----------------------------------------------+ +| true | ++----------------------------------------------+ +```"#, + ) + .with_standard_argument("str", Some("String")) + .with_argument("substr", "Substring to test for.") + .build() + .unwrap() + }) +} + +#[cfg(test)] +mod tests { + use crate::utils::test::test_function; + use arrow::array::{Array, BooleanArray}; + use arrow::datatypes::DataType::Boolean; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use super::*; + + #[test] + fn test_functions() -> Result<()> { + // Generate test cases for starts_with + let test_cases = vec![ + (Some("alphabet"), Some("alph"), Some(true)), + (Some("alphabet"), Some("bet"), Some(false)), + ( + Some("somewhat large string"), + Some("somewhat large"), + Some(true), + ), + (Some("somewhat large string"), Some("large"), Some(false)), + ] + .into_iter() + .flat_map(|(a, b, c)| { + let utf_8_args = vec![ + ColumnarValue::Scalar(ScalarValue::Utf8(a.map(|s| s.to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(b.map(|s| s.to_string()))), + ]; + + let large_utf_8_args = vec![ + ColumnarValue::Scalar(ScalarValue::LargeUtf8(a.map(|s| s.to_string()))), + ColumnarValue::Scalar(ScalarValue::LargeUtf8(b.map(|s| s.to_string()))), + ]; + + let utf_8_view_args = vec![ + ColumnarValue::Scalar(ScalarValue::Utf8View(a.map(|s| s.to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8View(b.map(|s| s.to_string()))), + ]; + + vec![(utf_8_args, c), (large_utf_8_args, c), (utf_8_view_args, c)] + }); + + for (args, expected) in test_cases { + test_function!( + StartsWithFunc::new(), + &args, + Ok(expected), + bool, + Boolean, + BooleanArray + ); } + + Ok(()) } } diff --git a/datafusion/functions/src/string/to_hex.rs b/datafusion/functions/src/string/to_hex.rs index 5050d8bab3e9..e0033d2d1cb0 100644 --- a/datafusion/functions/src/string/to_hex.rs +++ b/datafusion/functions/src/string/to_hex.rs @@ -16,21 +16,21 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; use arrow::datatypes::{ ArrowNativeType, ArrowPrimitiveType, DataType, Int32Type, Int64Type, }; +use crate::utils::make_scalar_function; use datafusion_common::cast::as_primitive_array; use datafusion_common::Result; use datafusion_common::{exec_err, plan_err}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ColumnarValue, Documentation}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use crate::utils::make_scalar_function; - /// Converts the number to its equivalent hexadecimal representation. /// to_hex(2147483647) = '7fffffff' pub fn to_hex(args: &[ArrayRef]) -> Result @@ -63,6 +63,13 @@ where pub struct ToHexFunc { signature: Signature, } + +impl Default for ToHexFunc { + fn default() -> Self { + Self::new() + } +} + impl ToHexFunc { pub fn new() -> Self { use DataType::*; @@ -103,6 +110,34 @@ impl ScalarUDFImpl for ToHexFunc { other => exec_err!("Unsupported data type {other:?} for function to_hex"), } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_to_hex_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_to_hex_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Converts an integer to a hexadecimal string.") + .with_syntax_example("to_hex(int)") + .with_sql_example( + r#"```sql +> select to_hex(12345689); ++-------------------------+ +| to_hex(Int64(12345689)) | ++-------------------------+ +| bc6159 | ++-------------------------+ +```"#, + ) + .with_standard_argument("int", Some("Integer")) + .build() + .unwrap() + }) } #[cfg(test)] diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index 8f03d7dc6bbc..68a9d60a1663 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -19,24 +19,27 @@ use crate::string::common::to_upper; use crate::utils::utf8_to_str_type; use arrow::datatypes::DataType; use datafusion_common::Result; -use datafusion_expr::ColumnarValue; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ColumnarValue, Documentation}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; +use std::sync::OnceLock; #[derive(Debug)] pub struct UpperFunc { signature: Signature, } +impl Default for UpperFunc { + fn default() -> Self { + Self::new() + } +} + impl UpperFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::uniform( - 1, - vec![Utf8, LargeUtf8], - Volatility::Immutable, - ), + signature: Signature::string(1, Volatility::Immutable), } } } @@ -61,6 +64,36 @@ impl ScalarUDFImpl for UpperFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { to_upper(args, "upper") } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_upper_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_upper_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Converts a string to upper-case.") + .with_syntax_example("upper(str)") + .with_sql_example( + r#"```sql +> select upper('dataFusion'); ++---------------------------+ +| upper(Utf8("dataFusion")) | ++---------------------------+ +| DATAFUSION | ++---------------------------+ +```"#, + ) + .with_standard_argument("str", Some("String")) + .with_related_udf("initcap") + .with_related_udf("lower") + .build() + .unwrap() + }) } #[cfg(test)] @@ -72,6 +105,7 @@ mod tests { fn to_upper(input: ArrayRef, expected: ArrayRef) -> Result<()> { let func = UpperFunc::new(); let args = vec![ColumnarValue::Array(input)]; + #[allow(deprecated)] // TODO migrate UDF invoke to invoke_batch let result = match func.invoke(&args)? { ColumnarValue::Array(result) => result, _ => unreachable!(), diff --git a/datafusion/functions/src/string/uuid.rs b/datafusion/functions/src/string/uuid.rs index c68871d42e9f..0fbdce16ccd1 100644 --- a/datafusion/functions/src/string/uuid.rs +++ b/datafusion/functions/src/string/uuid.rs @@ -16,16 +16,16 @@ // under the License. use std::any::Any; -use std::iter; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::GenericStringArray; use arrow::datatypes::DataType; use arrow::datatypes::DataType::Utf8; use uuid::Uuid; -use datafusion_common::{exec_err, Result}; -use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_common::{not_impl_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ColumnarValue, Documentation, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; #[derive(Debug)] @@ -33,6 +33,12 @@ pub struct UuidFunc { signature: Signature, } +impl Default for UuidFunc { + fn default() -> Self { + Self::new() + } +} + impl UuidFunc { pub fn new() -> Self { Self { @@ -58,16 +64,40 @@ impl ScalarUDFImpl for UuidFunc { Ok(Utf8) } + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + not_impl_err!("{} function does not accept arguments", self.name()) + } + /// Prints random (v4) uuid values per row /// uuid() = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11' - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let len: usize = match &args[0] { - ColumnarValue::Array(array) => array.len(), - _ => return exec_err!("Expect uuid function to take no param"), - }; - - let values = iter::repeat_with(|| Uuid::new_v4().to_string()).take(len); + fn invoke_no_args(&self, num_rows: usize) -> Result { + let values = std::iter::repeat_with(|| Uuid::new_v4().to_string()).take(num_rows); let array = GenericStringArray::::from_iter_values(values); Ok(ColumnarValue::Array(Arc::new(array))) } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_uuid_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_uuid_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Returns [`UUID v4`](https://en.wikipedia.org/wiki/Universally_unique_identifier#Version_4_(random)) string value which is unique per row.") + .with_syntax_example("uuid()") + .with_sql_example(r#"```sql +> select uuid(); ++--------------------------------------+ +| uuid() | ++--------------------------------------+ +| 6ec17ef8-1934-41cc-8d59-d0c8f9eea1f0 | ++--------------------------------------+ +```"#) + .build() + .unwrap() + }) } diff --git a/datafusion/functions/src/strings.rs b/datafusion/functions/src/strings.rs new file mode 100644 index 000000000000..e0cec3cb5756 --- /dev/null +++ b/datafusion/functions/src/strings.rs @@ -0,0 +1,424 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::mem::size_of; + +use arrow::array::{ + make_view, Array, ArrayAccessor, ArrayDataBuilder, ArrayIter, ByteView, + GenericStringArray, LargeStringArray, OffsetSizeTrait, StringArray, StringViewArray, + StringViewBuilder, +}; +use arrow::datatypes::DataType; +use arrow_buffer::{MutableBuffer, NullBuffer, NullBufferBuilder}; + +/// Abstracts iteration over different types of string arrays. +/// +/// The [`StringArrayType`] trait helps write generic code for string functions that can work with +/// different types of string arrays. +/// +/// Currently three types are supported: +/// - [`StringArray`] +/// - [`LargeStringArray`] +/// - [`StringViewArray`] +/// +/// It is inspired / copied from [arrow-rs]. +/// +/// [arrow-rs]: https://github.com/apache/arrow-rs/blob/bf0ea9129e617e4a3cf915a900b747cc5485315f/arrow-string/src/like.rs#L151-L157 +/// +/// # Examples +/// Generic function that works for [`StringArray`], [`LargeStringArray`] +/// and [`StringViewArray`]: +/// ``` +/// # use arrow::array::{StringArray, LargeStringArray, StringViewArray}; +/// # use datafusion_functions::strings::StringArrayType; +/// +/// /// Combines string values for any StringArrayType type. It can be invoked on +/// /// and combination of `StringArray`, `LargeStringArray` or `StringViewArray` +/// fn combine_values<'a, S1, S2>(array1: S1, array2: S2) -> Vec +/// where S1: StringArrayType<'a>, S2: StringArrayType<'a> +/// { +/// // iterate over the elements of the 2 arrays in parallel +/// array1 +/// .iter() +/// .zip(array2.iter()) +/// .map(|(s1, s2)| { +/// // if both values are non null, combine them +/// if let (Some(s1), Some(s2)) = (s1, s2) { +/// format!("{s1}{s2}") +/// } else { +/// "None".to_string() +/// } +/// }) +/// .collect() +/// } +/// +/// let string_array = StringArray::from(vec!["foo", "bar"]); +/// let large_string_array = LargeStringArray::from(vec!["foo2", "bar2"]); +/// let string_view_array = StringViewArray::from(vec!["foo3", "bar3"]); +/// +/// // can invoke this function a string array and large string array +/// assert_eq!( +/// combine_values(&string_array, &large_string_array), +/// vec![String::from("foofoo2"), String::from("barbar2")] +/// ); +/// +/// // Can call the same function with string array and string view array +/// assert_eq!( +/// combine_values(&string_array, &string_view_array), +/// vec![String::from("foofoo3"), String::from("barbar3")] +/// ); +/// ``` +/// +/// [`LargeStringArray`]: arrow::array::LargeStringArray +pub trait StringArrayType<'a>: ArrayAccessor + Sized { + /// Return an [`ArrayIter`] over the values of the array. + /// + /// This iterator iterates returns `Option<&str>` for each item in the array. + fn iter(&self) -> ArrayIter; + + /// Check if the array is ASCII only. + fn is_ascii(&self) -> bool; +} + +impl<'a, T: OffsetSizeTrait> StringArrayType<'a> for &'a GenericStringArray { + fn iter(&self) -> ArrayIter { + GenericStringArray::::iter(self) + } + + fn is_ascii(&self) -> bool { + GenericStringArray::::is_ascii(self) + } +} + +impl<'a> StringArrayType<'a> for &'a StringViewArray { + fn iter(&self) -> ArrayIter { + StringViewArray::iter(self) + } + + fn is_ascii(&self) -> bool { + StringViewArray::is_ascii(self) + } +} + +/// Optimized version of the StringBuilder in Arrow that: +/// 1. Precalculating the expected length of the result, avoiding reallocations. +/// 2. Avoids creating / incrementally creating a `NullBufferBuilder` +pub struct StringArrayBuilder { + offsets_buffer: MutableBuffer, + value_buffer: MutableBuffer, +} + +impl StringArrayBuilder { + pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self { + let mut offsets_buffer = + MutableBuffer::with_capacity((item_capacity + 1) * size_of::()); + // SAFETY: the first offset value is definitely not going to exceed the bounds. + unsafe { offsets_buffer.push_unchecked(0_i32) }; + Self { + offsets_buffer, + value_buffer: MutableBuffer::with_capacity(data_capacity), + } + } + + pub fn write( + &mut self, + column: &ColumnarValueRef, + i: usize, + ) { + match column { + ColumnarValueRef::Scalar(s) => { + self.value_buffer.extend_from_slice(s); + } + ColumnarValueRef::NullableArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + } + ColumnarValueRef::NullableLargeStringArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + } + ColumnarValueRef::NullableStringViewArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + } + ColumnarValueRef::NonNullableArray(array) => { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + ColumnarValueRef::NonNullableLargeStringArray(array) => { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + ColumnarValueRef::NonNullableStringViewArray(array) => { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + } + } + + pub fn append_offset(&mut self) { + let next_offset: i32 = self + .value_buffer + .len() + .try_into() + .expect("byte array offset overflow"); + unsafe { self.offsets_buffer.push_unchecked(next_offset) }; + } + + pub fn finish(self, null_buffer: Option) -> StringArray { + let array_builder = ArrayDataBuilder::new(DataType::Utf8) + .len(self.offsets_buffer.len() / size_of::() - 1) + .add_buffer(self.offsets_buffer.into()) + .add_buffer(self.value_buffer.into()) + .nulls(null_buffer); + // SAFETY: all data that was appended was valid UTF8 and the values + // and offsets were created correctly + let array_data = unsafe { array_builder.build_unchecked() }; + StringArray::from(array_data) + } +} + +pub struct StringViewArrayBuilder { + builder: StringViewBuilder, + block: String, +} + +impl StringViewArrayBuilder { + pub fn with_capacity(_item_capacity: usize, data_capacity: usize) -> Self { + let builder = StringViewBuilder::with_capacity(data_capacity); + Self { + builder, + block: String::new(), + } + } + + pub fn write( + &mut self, + column: &ColumnarValueRef, + i: usize, + ) { + match column { + ColumnarValueRef::Scalar(s) => { + self.block.push_str(std::str::from_utf8(s).unwrap()); + } + ColumnarValueRef::NullableArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.block.push_str( + std::str::from_utf8(array.value(i).as_bytes()).unwrap(), + ); + } + } + ColumnarValueRef::NullableLargeStringArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.block.push_str( + std::str::from_utf8(array.value(i).as_bytes()).unwrap(), + ); + } + } + ColumnarValueRef::NullableStringViewArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.block.push_str( + std::str::from_utf8(array.value(i).as_bytes()).unwrap(), + ); + } + } + ColumnarValueRef::NonNullableArray(array) => { + self.block + .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap()); + } + ColumnarValueRef::NonNullableLargeStringArray(array) => { + self.block + .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap()); + } + ColumnarValueRef::NonNullableStringViewArray(array) => { + self.block + .push_str(std::str::from_utf8(array.value(i).as_bytes()).unwrap()); + } + } + } + + pub fn append_offset(&mut self) { + self.builder.append_value(&self.block); + self.block = String::new(); + } + + pub fn finish(mut self) -> StringViewArray { + self.builder.finish() + } +} + +pub struct LargeStringArrayBuilder { + offsets_buffer: MutableBuffer, + value_buffer: MutableBuffer, +} + +impl LargeStringArrayBuilder { + pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self { + let mut offsets_buffer = + MutableBuffer::with_capacity((item_capacity + 1) * size_of::()); + // SAFETY: the first offset value is definitely not going to exceed the bounds. + unsafe { offsets_buffer.push_unchecked(0_i64) }; + Self { + offsets_buffer, + value_buffer: MutableBuffer::with_capacity(data_capacity), + } + } + + pub fn write( + &mut self, + column: &ColumnarValueRef, + i: usize, + ) { + match column { + ColumnarValueRef::Scalar(s) => { + self.value_buffer.extend_from_slice(s); + } + ColumnarValueRef::NullableArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + } + ColumnarValueRef::NullableLargeStringArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + } + ColumnarValueRef::NullableStringViewArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + } + ColumnarValueRef::NonNullableArray(array) => { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + ColumnarValueRef::NonNullableLargeStringArray(array) => { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + ColumnarValueRef::NonNullableStringViewArray(array) => { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + } + } + + pub fn append_offset(&mut self) { + let next_offset: i64 = self + .value_buffer + .len() + .try_into() + .expect("byte array offset overflow"); + unsafe { self.offsets_buffer.push_unchecked(next_offset) }; + } + + pub fn finish(self, null_buffer: Option) -> LargeStringArray { + let array_builder = ArrayDataBuilder::new(DataType::LargeUtf8) + .len(self.offsets_buffer.len() / size_of::() - 1) + .add_buffer(self.offsets_buffer.into()) + .add_buffer(self.value_buffer.into()) + .nulls(null_buffer); + // SAFETY: all data that was appended was valid Large UTF8 and the values + // and offsets were created correctly + let array_data = unsafe { array_builder.build_unchecked() }; + LargeStringArray::from(array_data) + } +} + +/// Append a new view to the views buffer with the given substr +/// +/// # Safety +/// +/// original_view must be a valid view (the format described on +/// [`GenericByteViewArray`](arrow::array::GenericByteViewArray). +/// +/// # Arguments +/// - views_buffer: The buffer to append the new view to +/// - null_builder: The buffer to append the null value to +/// - original_view: The original view value +/// - substr: The substring to append. Must be a valid substring of the original view +/// - start_offset: The start offset of the substring in the view +pub fn make_and_append_view( + views_buffer: &mut Vec, + null_builder: &mut NullBufferBuilder, + original_view: &u128, + substr: &str, + start_offset: u32, +) { + let substr_len = substr.len(); + let sub_view = if substr_len > 12 { + let view = ByteView::from(*original_view); + make_view( + substr.as_bytes(), + view.buffer_index, + view.offset + start_offset, + ) + } else { + // inline value does not need block id or offset + make_view(substr.as_bytes(), 0, 0) + }; + views_buffer.push(sub_view); + null_builder.append_non_null(); +} + +#[derive(Debug)] +pub enum ColumnarValueRef<'a> { + Scalar(&'a [u8]), + NullableArray(&'a StringArray), + NonNullableArray(&'a StringArray), + NullableLargeStringArray(&'a LargeStringArray), + NonNullableLargeStringArray(&'a LargeStringArray), + NullableStringViewArray(&'a StringViewArray), + NonNullableStringViewArray(&'a StringViewArray), +} + +impl<'a> ColumnarValueRef<'a> { + #[inline] + pub fn is_valid(&self, i: usize) -> bool { + match &self { + Self::Scalar(_) + | Self::NonNullableArray(_) + | Self::NonNullableLargeStringArray(_) + | Self::NonNullableStringViewArray(_) => true, + Self::NullableArray(array) => array.is_valid(i), + Self::NullableStringViewArray(array) => array.is_valid(i), + Self::NullableLargeStringArray(array) => array.is_valid(i), + } + } + + #[inline] + pub fn nulls(&self) -> Option { + match &self { + Self::Scalar(_) + | Self::NonNullableArray(_) + | Self::NonNullableStringViewArray(_) + | Self::NonNullableLargeStringArray(_) => None, + Self::NullableArray(array) => array.nulls().cloned(), + Self::NullableStringViewArray(array) => array.nulls().cloned(), + Self::NullableLargeStringArray(array) => array.nulls().cloned(), + } + } +} diff --git a/datafusion/functions/src/unicode/character_length.rs b/datafusion/functions/src/unicode/character_length.rs index 7e2723771ff2..7858a59664d3 100644 --- a/datafusion/functions/src/unicode/character_length.rs +++ b/datafusion/functions/src/unicode/character_length.rs @@ -15,17 +15,19 @@ // specific language governing permissions and limitations // under the License. +use crate::strings::StringArrayType; use crate::utils::{make_scalar_function, utf8_to_int_type}; use arrow::array::{ - ArrayRef, ArrowPrimitiveType, GenericStringArray, OffsetSizeTrait, PrimitiveArray, + Array, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, PrimitiveArray, }; use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; -use datafusion_common::cast::as_generic_string_array; -use datafusion_common::exec_err; use datafusion_common::Result; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; #[derive(Debug)] pub struct CharacterLengthFunc { @@ -33,13 +35,19 @@ pub struct CharacterLengthFunc { aliases: Vec, } +impl Default for CharacterLengthFunc { + fn default() -> Self { + Self::new() + } +} + impl CharacterLengthFunc { pub fn new() -> Self { use DataType::*; Self { signature: Signature::uniform( 1, - vec![Utf8, LargeUtf8], + vec![Utf8, LargeUtf8, Utf8View], Volatility::Immutable, ), aliases: vec![String::from("length"), String::from("char_length")], @@ -65,40 +73,85 @@ impl ScalarUDFImpl for CharacterLengthFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function(character_length::, vec![])(args) - } - DataType::LargeUtf8 => { - make_scalar_function(character_length::, vec![])(args) - } - other => { - exec_err!("Unsupported data type {other:?} for function character_length") - } - } + make_scalar_function(character_length, vec![])(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_character_length_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_character_length_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Returns the number of characters in a string.") + .with_syntax_example("character_length(str)") + .with_sql_example( + r#"```sql +> select character_length('Ångström'); ++------------------------------------+ +| character_length(Utf8("Ångström")) | ++------------------------------------+ +| 8 | ++------------------------------------+ +```"#, + ) + .with_standard_argument("str", Some("String")) + .with_related_udf("bit_length") + .with_related_udf("octet_length") + .build() + .unwrap() + }) } /// Returns number of characters in the string. /// character_length('josé') = 4 /// The implementation counts UTF-8 code points to count the number of characters -fn character_length(args: &[ArrayRef]) -> Result +fn character_length(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::Utf8 => { + let string_array = args[0].as_string::(); + character_length_general::(string_array) + } + DataType::LargeUtf8 => { + let string_array = args[0].as_string::(); + character_length_general::(string_array) + } + DataType::Utf8View => { + let string_array = args[0].as_string_view(); + character_length_general::(string_array) + } + _ => unreachable!(), + } +} + +fn character_length_general<'a, T: ArrowPrimitiveType, V: StringArrayType<'a>>( + array: V, +) -> Result where T::Native: OffsetSizeTrait, { - let string_array: &GenericStringArray = - as_generic_string_array::(&args[0])?; - - let result = string_array - .iter() + // String characters are variable length encoded in UTF-8, counting the + // number of chars requires expensive decoding, however checking if the + // string is ASCII only is relatively cheap. + // If strings are ASCII only, count bytes instead. + let is_array_ascii_only = array.is_ascii(); + let iter = array.iter(); + let result = iter .map(|string| { string.map(|string: &str| { - T::Native::from_usize(string.chars().count()) - .expect("should not fail as string.chars will always return integer") + if is_array_ascii_only { + T::Native::usize_as(string.len()) + } else { + T::Native::usize_as(string.chars().count()) + } }) }) .collect::>(); @@ -110,55 +163,54 @@ where mod tests { use crate::unicode::character_length::CharacterLengthFunc; use crate::utils::test::test_function; - use arrow::array::{Array, Int32Array}; - use arrow::datatypes::DataType::Int32; + use arrow::array::{Array, Int32Array, Int64Array}; + use arrow::datatypes::DataType::{Int32, Int64}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + macro_rules! test_character_length { + ($INPUT:expr, $EXPECTED:expr) => { + test_function!( + CharacterLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], + $EXPECTED, + i32, + Int32, + Int32Array + ); + + test_function!( + CharacterLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], + $EXPECTED, + i64, + Int64, + Int64Array + ); + + test_function!( + CharacterLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], + $EXPECTED, + i32, + Int32, + Int32Array + ); + }; + } + #[test] fn test_functions() -> Result<()> { #[cfg(feature = "unicode_expressions")] - test_function!( - CharacterLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( - String::from("chars") - )))], - Ok(Some(5)), - i32, - Int32, - Int32Array - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - CharacterLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( - String::from("josé") - )))], - Ok(Some(4)), - i32, - Int32, - Int32Array - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - CharacterLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( - String::from("") - )))], - Ok(Some(0)), - i32, - Int32, - Int32Array - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - CharacterLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], - Ok(None), - i32, - Int32, - Int32Array - ); + { + test_character_length!(Some(String::from("chars")), Ok(Some(5))); + test_character_length!(Some(String::from("josé")), Ok(Some(4))); + // test long strings (more than 12 bytes for StringView) + test_character_length!(Some(String::from("joséjoséjoséjosé")), Ok(Some(16))); + test_character_length!(Some(String::from("")), Ok(Some(0))); + test_character_length!(None, Ok(None)); + } + #[cfg(not(feature = "unicode_expressions"))] test_function!( CharacterLengthFunc::new(), diff --git a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs index fc45f897c5f4..cad860e41088 100644 --- a/datafusion/functions/src/unicode/find_in_set.rs +++ b/datafusion/functions/src/unicode/find_in_set.rs @@ -16,31 +16,43 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; use arrow::array::{ - ArrayRef, ArrowPrimitiveType, GenericStringArray, OffsetSizeTrait, PrimitiveArray, + ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, + PrimitiveArray, }; use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; -use datafusion_common::cast::as_generic_string_array; +use crate::utils::{make_scalar_function, utf8_to_int_type}; use datafusion_common::{exec_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; - -use crate::utils::{make_scalar_function, utf8_to_int_type}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct FindInSetFunc { signature: Signature, } +impl Default for FindInSetFunc { + fn default() -> Self { + Self::new() + } +} + impl FindInSetFunc { pub fn new() -> Self { use DataType::*; Self { signature: Signature::one_of( - vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])], + vec![ + Exact(vec![Utf8View, Utf8View]), + Exact(vec![Utf8, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8]), + ], Volatility::Immutable, ), } @@ -65,41 +77,79 @@ impl ScalarUDFImpl for FindInSetFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function(find_in_set::, vec![])(args) - } - DataType::LargeUtf8 => { - make_scalar_function(find_in_set::, vec![])(args) - } - other => { - exec_err!("Unsupported data type {other:?} for function find_in_set") - } - } + make_scalar_function(find_in_set, vec![])(args) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_find_in_set_doc()) } } +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_find_in_set_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings.") + .with_syntax_example("find_in_set(str, strlist)") + .with_sql_example(r#"```sql +> select find_in_set('b', 'a,b,c,d'); ++----------------------------------------+ +| find_in_set(Utf8("b"),Utf8("a,b,c,d")) | ++----------------------------------------+ +| 2 | ++----------------------------------------+ +```"#) + .with_argument("str", "String expression to find in strlist.") + .with_argument("strlist", "A string list is a string composed of substrings separated by , characters.") + .build() + .unwrap() + }) +} + ///Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings ///A string list is a string composed of substrings separated by , characters. -pub fn find_in_set(args: &[ArrayRef]) -> Result -where - T::Native: OffsetSizeTrait, -{ +fn find_in_set(args: &[ArrayRef]) -> Result { if args.len() != 2 { return exec_err!( "find_in_set was called with {} arguments. It requires 2.", args.len() ); } + match args[0].data_type() { + DataType::Utf8 => { + let string_array = args[0].as_string::(); + let str_list_array = args[1].as_string::(); + find_in_set_general::(string_array, str_list_array) + } + DataType::LargeUtf8 => { + let string_array = args[0].as_string::(); + let str_list_array = args[1].as_string::(); + find_in_set_general::(string_array, str_list_array) + } + DataType::Utf8View => { + let string_array = args[0].as_string_view(); + let str_list_array = args[1].as_string_view(); + find_in_set_general::(string_array, str_list_array) + } + other => { + exec_err!("Unsupported data type {other:?} for function find_in_set") + } + } +} - let str_array: &GenericStringArray = - as_generic_string_array::(&args[0])?; - let str_list_array: &GenericStringArray = - as_generic_string_array::(&args[1])?; - - let result = str_array - .iter() - .zip(str_list_array.iter()) +pub fn find_in_set_general<'a, T: ArrowPrimitiveType, V: ArrayAccessor>( + string_array: V, + str_list_array: V, +) -> Result +where + T::Native: OffsetSizeTrait, +{ + let string_iter = ArrayIter::new(string_array); + let str_list_iter = ArrayIter::new(str_list_array); + let result = string_iter + .zip(str_list_iter) .map(|(string, str_list)| match (string, str_list) { (Some(string), Some(str_list)) => { let mut res = 0; diff --git a/datafusion/functions/src/unicode/left.rs b/datafusion/functions/src/unicode/left.rs index 24ea2d5a8f25..a6c2b9768f0b 100644 --- a/datafusion/functions/src/unicode/left.rs +++ b/datafusion/functions/src/unicode/left.rs @@ -17,30 +17,47 @@ use std::any::Any; use std::cmp::Ordering; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::array::{ + Array, ArrayAccessor, ArrayIter, ArrayRef, GenericStringArray, Int64Array, + OffsetSizeTrait, +}; use arrow::datatypes::DataType; -use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use crate::utils::{make_scalar_function, utf8_to_str_type}; +use datafusion_common::cast::{ + as_generic_string_array, as_int64_array, as_string_view_array, +}; use datafusion_common::exec_err; use datafusion_common::Result; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; - -use crate::utils::{make_scalar_function, utf8_to_str_type}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct LeftFunc { signature: Signature, } +impl Default for LeftFunc { + fn default() -> Self { + Self::new() + } +} + impl LeftFunc { pub fn new() -> Self { use DataType::*; Self { signature: Signature::one_of( - vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])], + vec![ + Exact(vec![Utf8View, Int64]), + Exact(vec![Utf8, Int64]), + Exact(vec![LargeUtf8, Int64]), + ], Volatility::Immutable, ), } @@ -66,21 +83,67 @@ impl ScalarUDFImpl for LeftFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { match args[0].data_type() { - DataType::Utf8 => make_scalar_function(left::, vec![])(args), + DataType::Utf8 | DataType::Utf8View => { + make_scalar_function(left::, vec![])(args) + } DataType::LargeUtf8 => make_scalar_function(left::, vec![])(args), - other => exec_err!("Unsupported data type {other:?} for function left"), + other => exec_err!( + "Unsupported data type {other:?} for function left,\ + expected Utf8View, Utf8 or LargeUtf8." + ), } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_left_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_left_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Returns a specified number of characters from the left side of a string.") + .with_syntax_example("left(str, n)") + .with_sql_example(r#"```sql +> select left('datafusion', 4); ++-----------------------------------+ +| left(Utf8("datafusion"),Int64(4)) | ++-----------------------------------+ +| data | ++-----------------------------------+ +```"#) + .with_standard_argument("str", Some("String")) + .with_argument("n", "Number of characters to return.") + .with_related_udf("right") + .build() + .unwrap() + }) } /// Returns first n characters in the string, or when n is negative, returns all but last |n| characters. /// left('abcde', 2) = 'ab' /// The implementation uses UTF-8 code points as characters pub fn left(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; let n_array = as_int64_array(&args[1])?; - let result = string_array - .iter() + + if args[0].data_type() == &DataType::Utf8View { + let string_array = as_string_view_array(&args[0])?; + left_impl::(string_array, n_array) + } else { + let string_array = as_generic_string_array::(&args[0])?; + left_impl::(string_array, n_array) + } +} + +fn left_impl<'a, T: OffsetSizeTrait, V: ArrayAccessor>( + string_array: V, + n_array: &Int64Array, +) -> Result { + let iter = ArrayIter::new(string_array); + let result = iter .zip(n_array.iter()) .map(|(string, n)| match (string, n) { (Some(string), Some(n)) => match n.cmp(&0) { diff --git a/datafusion/functions/src/unicode/lpad.rs b/datafusion/functions/src/unicode/lpad.rs index 47208903bcef..767eda203c8f 100644 --- a/datafusion/functions/src/unicode/lpad.rs +++ b/datafusion/functions/src/unicode/lpad.rs @@ -16,34 +16,55 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::fmt::Write; +use std::sync::{Arc, OnceLock}; -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::array::{ + Array, ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array, + OffsetSizeTrait, StringViewArray, +}; use arrow::datatypes::DataType; -use datafusion_common::cast::{as_generic_string_array, as_int64_array}; use unicode_segmentation::UnicodeSegmentation; +use DataType::{LargeUtf8, Utf8, Utf8View}; +use crate::strings::StringArrayType; use crate::utils::{make_scalar_function, utf8_to_str_type}; +use datafusion_common::cast::as_int64_array; use datafusion_common::{exec_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct LPadFunc { signature: Signature, } +impl Default for LPadFunc { + fn default() -> Self { + Self::new() + } +} + impl LPadFunc { pub fn new() -> Self { use DataType::*; Self { signature: Signature::one_of( vec![ + Exact(vec![Utf8View, Int64]), + Exact(vec![Utf8View, Int64, Utf8View]), + Exact(vec![Utf8View, Int64, Utf8]), + Exact(vec![Utf8View, Int64, LargeUtf8]), Exact(vec![Utf8, Int64]), - Exact(vec![LargeUtf8, Int64]), + Exact(vec![Utf8, Int64, Utf8View]), Exact(vec![Utf8, Int64, Utf8]), - Exact(vec![LargeUtf8, Int64, Utf8]), Exact(vec![Utf8, Int64, LargeUtf8]), + Exact(vec![LargeUtf8, Int64]), + Exact(vec![LargeUtf8, Int64, Utf8View]), + Exact(vec![LargeUtf8, Int64, Utf8]), Exact(vec![LargeUtf8, Int64, LargeUtf8]), ], Volatility::Immutable, @@ -71,299 +92,446 @@ impl ScalarUDFImpl for LPadFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { match args[0].data_type() { - DataType::Utf8 => make_scalar_function(lpad::, vec![])(args), - DataType::LargeUtf8 => make_scalar_function(lpad::, vec![])(args), + Utf8 | Utf8View => make_scalar_function(lpad::, vec![])(args), + LargeUtf8 => make_scalar_function(lpad::, vec![])(args), other => exec_err!("Unsupported data type {other:?} for function lpad"), } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_lpad_doc()) + } } -/// Extends the string to length 'length' by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right). +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_lpad_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Pads the left side of a string with another string to a specified string length.") + .with_syntax_example("lpad(str, n[, padding_str])") + .with_sql_example(r#"```sql +> select lpad('Dolly', 10, 'hello'); ++---------------------------------------------+ +| lpad(Utf8("Dolly"),Int64(10),Utf8("hello")) | ++---------------------------------------------+ +| helloDolly | ++---------------------------------------------+ +```"#) + .with_standard_argument("str", Some("String")) + .with_argument("n", "String length to pad to.") + .with_argument("padding_str", "Optional string expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._") + .with_related_udf("rpad") + .build() + .unwrap() + }) +} + +/// Extends the string to length 'length' by prepending the characters fill (a space by default). +/// If the string is already longer than length then it is truncated (on the right). /// lpad('hi', 5, 'xy') = 'xyxhi' pub fn lpad(args: &[ArrayRef]) -> Result { - match args.len() { - 2 => { - let string_array = as_generic_string_array::(&args[0])?; - let length_array = as_int64_array(&args[1])?; - - let result = string_array - .iter() - .zip(length_array.iter()) - .map(|(string, length)| match (string, length) { - (Some(string), Some(length)) => { - if length > i32::MAX as i64 { - return exec_err!( - "lpad requested length {length} too large" - ); - } - - let length = if length < 0 { 0 } else { length as usize }; - if length == 0 { - Ok(Some("".to_string())) - } else { - let graphemes = string.graphemes(true).collect::>(); - if length < graphemes.len() { - Ok(Some(graphemes[..length].concat())) - } else { - let mut s: String = " ".repeat(length - graphemes.len()); - s.push_str(string); - Ok(Some(s)) - } - } - } - _ => Ok(None), - }) - .collect::>>()?; + if args.len() <= 1 || args.len() > 3 { + return exec_err!( + "lpad was called with {} arguments. It requires at least 2 and at most 3.", + args.len() + ); + } - Ok(Arc::new(result) as ArrayRef) - } - 3 => { - let string_array = as_generic_string_array::(&args[0])?; - let length_array = as_int64_array(&args[1])?; - let fill_array = as_generic_string_array::(&args[2])?; - - let result = string_array - .iter() - .zip(length_array.iter()) - .zip(fill_array.iter()) - .map(|((string, length), fill)| match (string, length, fill) { - (Some(string), Some(length), Some(fill)) => { - if length > i32::MAX as i64 { - return exec_err!( - "lpad requested length {length} too large" - ); - } - - let length = if length < 0 { 0 } else { length as usize }; - if length == 0 { - Ok(Some("".to_string())) - } else { - let graphemes = string.graphemes(true).collect::>(); - let fill_chars = fill.chars().collect::>(); - - if length < graphemes.len() { - Ok(Some(graphemes[..length].concat())) - } else if fill_chars.is_empty() { - Ok(Some(string.to_string())) - } else { - let mut s = string.to_string(); - let mut char_vector = - Vec::::with_capacity(length - graphemes.len()); - for l in 0..length - graphemes.len() { - char_vector.push( - *fill_chars.get(l % fill_chars.len()).unwrap(), - ); - } - s.insert_str( - 0, - char_vector.iter().collect::().as_str(), - ); - Ok(Some(s)) - } - } - } - _ => Ok(None), - }) - .collect::>>()?; + let length_array = as_int64_array(&args[1])?; - Ok(Arc::new(result) as ArrayRef) - } - other => exec_err!( - "lpad was called with {other} arguments. It requires at least 2 and at most 3." + match (args.len(), args[0].data_type()) { + (2, Utf8View) => lpad_impl::<&StringViewArray, &GenericStringArray, T>( + args[0].as_string_view(), + length_array, + None, ), + (2, Utf8 | LargeUtf8) => lpad_impl::< + &GenericStringArray, + &GenericStringArray, + T, + >(args[0].as_string::(), length_array, None), + (3, Utf8View) => lpad_with_replace::<&StringViewArray, T>( + args[0].as_string_view(), + length_array, + &args[2], + ), + (3, Utf8 | LargeUtf8) => lpad_with_replace::<&GenericStringArray, T>( + args[0].as_string::(), + length_array, + &args[2], + ), + (_, _) => unreachable!(), + } +} + +fn lpad_with_replace<'a, V, T: OffsetSizeTrait>( + string_array: V, + length_array: &Int64Array, + fill_array: &'a ArrayRef, +) -> Result +where + V: StringArrayType<'a>, +{ + match fill_array.data_type() { + Utf8View => lpad_impl::( + string_array, + length_array, + Some(fill_array.as_string_view()), + ), + LargeUtf8 => lpad_impl::, T>( + string_array, + length_array, + Some(fill_array.as_string::()), + ), + Utf8 => lpad_impl::, T>( + string_array, + length_array, + Some(fill_array.as_string::()), + ), + other => { + exec_err!("Unsupported data type {other:?} for function lpad") + } } } +fn lpad_impl<'a, V, V2, T>( + string_array: V, + length_array: &Int64Array, + fill_array: Option, +) -> Result +where + V: StringArrayType<'a>, + V2: StringArrayType<'a>, + T: OffsetSizeTrait, +{ + let array = if fill_array.is_none() { + let mut builder: GenericStringBuilder = GenericStringBuilder::new(); + + for (string, length) in string_array.iter().zip(length_array.iter()) { + if let (Some(string), Some(length)) = (string, length) { + if length > i32::MAX as i64 { + return exec_err!("lpad requested length {length} too large"); + } + + let length = if length < 0 { 0 } else { length as usize }; + if length == 0 { + builder.append_value(""); + continue; + } + + let graphemes = string.graphemes(true).collect::>(); + if length < graphemes.len() { + builder.append_value(graphemes[..length].concat()); + } else { + builder.write_str(" ".repeat(length - graphemes.len()).as_str())?; + builder.write_str(string)?; + builder.append_value(""); + } + } else { + builder.append_null(); + } + } + + builder.finish() + } else { + let mut builder: GenericStringBuilder = GenericStringBuilder::new(); + + for ((string, length), fill) in string_array + .iter() + .zip(length_array.iter()) + .zip(fill_array.unwrap().iter()) + { + if let (Some(string), Some(length), Some(fill)) = (string, length, fill) { + if length > i32::MAX as i64 { + return exec_err!("lpad requested length {length} too large"); + } + + let length = if length < 0 { 0 } else { length as usize }; + if length == 0 { + builder.append_value(""); + continue; + } + + let graphemes = string.graphemes(true).collect::>(); + let fill_chars = fill.chars().collect::>(); + + if length < graphemes.len() { + builder.append_value(graphemes[..length].concat()); + } else if fill_chars.is_empty() { + builder.append_value(string); + } else { + for l in 0..length - graphemes.len() { + let c = *fill_chars.get(l % fill_chars.len()).unwrap(); + builder.write_char(c)?; + } + builder.write_str(string)?; + builder.append_value(""); + } + } else { + builder.append_null(); + } + } + + builder.finish() + }; + + Ok(Arc::new(array) as ArrayRef) +} + #[cfg(test)] mod tests { - use arrow::array::{Array, StringArray}; - use arrow::datatypes::DataType::Utf8; + use crate::unicode::lpad::LPadFunc; + use crate::utils::test::test_function; + + use arrow::array::{Array, LargeStringArray, StringArray}; + use arrow::datatypes::DataType::{LargeUtf8, Utf8}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; - use crate::unicode::lpad::LPadFunc; - use crate::utils::test::test_function; + macro_rules! test_lpad { + ($INPUT:expr, $LENGTH:expr, $EXPECTED:expr) => { + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), + ColumnarValue::Scalar($LENGTH) + ], + $EXPECTED, + &str, + Utf8, + StringArray + ); + + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), + ColumnarValue::Scalar($LENGTH) + ], + $EXPECTED, + &str, + LargeUtf8, + LargeStringArray + ); + + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), + ColumnarValue::Scalar($LENGTH) + ], + $EXPECTED, + &str, + Utf8, + StringArray + ); + }; + + ($INPUT:expr, $LENGTH:expr, $REPLACE:expr, $EXPECTED:expr) => { + // utf8, utf8 + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), + ColumnarValue::Scalar($LENGTH), + ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE)) + ], + $EXPECTED, + &str, + Utf8, + StringArray + ); + // utf8, largeutf8 + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), + ColumnarValue::Scalar($LENGTH), + ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE)) + ], + $EXPECTED, + &str, + Utf8, + StringArray + ); + // utf8, utf8view + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), + ColumnarValue::Scalar($LENGTH), + ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE)) + ], + $EXPECTED, + &str, + Utf8, + StringArray + ); + + // largeutf8, utf8 + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), + ColumnarValue::Scalar($LENGTH), + ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE)) + ], + $EXPECTED, + &str, + LargeUtf8, + LargeStringArray + ); + // largeutf8, largeutf8 + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), + ColumnarValue::Scalar($LENGTH), + ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE)) + ], + $EXPECTED, + &str, + LargeUtf8, + LargeStringArray + ); + // largeutf8, utf8view + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), + ColumnarValue::Scalar($LENGTH), + ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE)) + ], + $EXPECTED, + &str, + LargeUtf8, + LargeStringArray + ); + + // utf8view, utf8 + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), + ColumnarValue::Scalar($LENGTH), + ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE)) + ], + $EXPECTED, + &str, + Utf8, + StringArray + ); + // utf8view, largeutf8 + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), + ColumnarValue::Scalar($LENGTH), + ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE)) + ], + $EXPECTED, + &str, + Utf8, + StringArray + ); + // utf8view, utf8view + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), + ColumnarValue::Scalar($LENGTH), + ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE)) + ], + $EXPECTED, + &str, + Utf8, + StringArray + ); + }; + } #[test] fn test_functions() -> Result<()> { - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("josé")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ], - Ok(Some(" josé")), - &str, - Utf8, - StringArray - ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ], - Ok(Some(" hi")), - &str, - Utf8, - StringArray - ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(0i64)), - ], - Ok(Some("")), - &str, - Utf8, - StringArray + test_lpad!( + Some("josé".into()), + ScalarValue::Int64(Some(5i64)), + Ok(Some(" josé")) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::Int64(None)), - ], - Ok(None), - &str, - Utf8, - StringArray + test_lpad!( + Some("hi".into()), + ScalarValue::Int64(Some(5i64)), + Ok(Some(" hi")) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ], - Ok(None), - &str, - Utf8, - StringArray + test_lpad!( + Some("hi".into()), + ScalarValue::Int64(Some(0i64)), + Ok(Some("")) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ColumnarValue::Scalar(ScalarValue::from("xy")), - ], - Ok(Some("xyxhi")), - &str, - Utf8, - StringArray + test_lpad!(Some("hi".into()), ScalarValue::Int64(None), Ok(None)); + test_lpad!(None, ScalarValue::Int64(Some(5i64)), Ok(None)); + test_lpad!( + Some("hi".into()), + ScalarValue::Int64(Some(5i64)), + Some("xy".into()), + Ok(Some("xyxhi")) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(21i64)), - ColumnarValue::Scalar(ScalarValue::from("abcdef")), - ], - Ok(Some("abcdefabcdefabcdefahi")), - &str, - Utf8, - StringArray + test_lpad!( + Some("hi".into()), + ScalarValue::Int64(Some(21i64)), + Some("abcdef".into()), + Ok(Some("abcdefabcdefabcdefahi")) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ColumnarValue::Scalar(ScalarValue::from(" ")), - ], - Ok(Some(" hi")), - &str, - Utf8, - StringArray + test_lpad!( + Some("hi".into()), + ScalarValue::Int64(Some(5i64)), + Some(" ".into()), + Ok(Some(" hi")) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ColumnarValue::Scalar(ScalarValue::from("")), - ], - Ok(Some("hi")), - &str, - Utf8, - StringArray + test_lpad!( + Some("hi".into()), + ScalarValue::Int64(Some(5i64)), + Some("".into()), + Ok(Some("hi")) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ColumnarValue::Scalar(ScalarValue::from("xy")), - ], - Ok(None), - &str, - Utf8, - StringArray + test_lpad!( + None, + ScalarValue::Int64(Some(5i64)), + Some("xy".into()), + Ok(None) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::Int64(None)), - ColumnarValue::Scalar(ScalarValue::from("xy")), - ], - Ok(None), - &str, - Utf8, - StringArray + test_lpad!( + Some("hi".into()), + ScalarValue::Int64(None), + Some("xy".into()), + Ok(None) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ], - Ok(None), - &str, - Utf8, - StringArray + test_lpad!( + Some("hi".into()), + ScalarValue::Int64(Some(5i64)), + None, + Ok(None) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("josé")), - ColumnarValue::Scalar(ScalarValue::from(10i64)), - ColumnarValue::Scalar(ScalarValue::from("xy")), - ], - Ok(Some("xyxyxyjosé")), - &str, - Utf8, - StringArray + test_lpad!( + Some("josé".into()), + ScalarValue::Int64(Some(10i64)), + Some("xy".into()), + Ok(Some("xyxyxyjosé")) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("josé")), - ColumnarValue::Scalar(ScalarValue::from(10i64)), - ColumnarValue::Scalar(ScalarValue::from("éñ")), - ], - Ok(Some("éñéñéñjosé")), - &str, - Utf8, - StringArray + test_lpad!( + Some("josé".into()), + ScalarValue::Int64(Some(10i64)), + Some("éñ".into()), + Ok(Some("éñéñéñjosé")) ); + #[cfg(not(feature = "unicode_expressions"))] - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("josé")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ], - internal_err!( + test_lpad!(Some("josé".into()), ScalarValue::Int64(Some(5i64)), internal_err!( "function lpad requires compilation with feature flag: unicode_expressions." - ), - &str, - Utf8, - StringArray - ); + )); + Ok(()) } } diff --git a/datafusion/functions/src/unicode/mod.rs b/datafusion/functions/src/unicode/mod.rs index eba4cd5048eb..40915bc9efde 100644 --- a/datafusion/functions/src/unicode/mod.rs +++ b/datafusion/functions/src/unicode/mod.rs @@ -21,17 +21,17 @@ use std::sync::Arc; use datafusion_expr::ScalarUDF; -mod character_length; -mod find_in_set; -mod left; -mod lpad; -mod reverse; -mod right; -mod rpad; -mod strpos; -mod substr; -mod substrindex; -mod translate; +pub mod character_length; +pub mod find_in_set; +pub mod left; +pub mod lpad; +pub mod reverse; +pub mod right; +pub mod rpad; +pub mod strpos; +pub mod substr; +pub mod substrindex; +pub mod translate; // create UDFs make_udf_function!( @@ -47,27 +47,68 @@ make_udf_function!(reverse::ReverseFunc, REVERSE, reverse); make_udf_function!(rpad::RPadFunc, RPAD, rpad); make_udf_function!(strpos::StrposFunc, STRPOS, strpos); make_udf_function!(substr::SubstrFunc, SUBSTR, substr); +make_udf_function!(substr::SubstrFunc, SUBSTRING, substring); make_udf_function!(substrindex::SubstrIndexFunc, SUBSTR_INDEX, substr_index); make_udf_function!(translate::TranslateFunc, TRANSLATE, translate); pub mod expr_fn { use datafusion_expr::Expr; + export_functions!(( + character_length, + "the number of characters in the `string`", + string + ),( + lpad, + "fill up a string to the length by prepending the characters", + args, + ),( + rpad, + "fill up a string to the length by appending the characters", + args, + ),( + reverse, + "reverses the `string`", + string + ),( + substr, + "substring from the `position` to the end", + string position + ),( + substr_index, + "Returns the substring from str before count occurrences of the delimiter", + string delimiter count + ),( + strpos, + "finds the position from where the `substring` matches the `string`", + string substring + ),( + substring, + "substring from the `position` with `length` characters", + string position length + ),( + translate, + "replaces the characters in `from` with the counterpart in `to`", + string from to + ),( + right, + "returns the last `n` characters in the `string`", + string n + ),( + left, + "returns the first `n` characters in the `string`", + string n + ),( + find_in_set, + "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings", + string strlist + )); + #[doc = "the number of characters in the `string`"] pub fn char_length(string: Expr) -> Expr { character_length(string) } - #[doc = "the number of characters in the `string`"] - pub fn character_length(string: Expr) -> Expr { - super::character_length().call(vec![string]) - } - - #[doc = "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings"] - pub fn find_in_set(string: Expr, strlist: Expr) -> Expr { - super::find_in_set().call(vec![string, strlist]) - } - #[doc = "finds the position from where the `substring` matches the `string`"] pub fn instr(string: Expr, substring: Expr) -> Expr { strpos(string, substring) @@ -78,63 +119,13 @@ pub mod expr_fn { character_length(string) } - #[doc = "returns the first `n` characters in the `string`"] - pub fn left(string: Expr, n: Expr) -> Expr { - super::left().call(vec![string, n]) - } - - #[doc = "fill up a string to the length by prepending the characters"] - pub fn lpad(args: Vec) -> Expr { - super::lpad().call(args) - } - #[doc = "finds the position from where the `substring` matches the `string`"] pub fn position(string: Expr, substring: Expr) -> Expr { strpos(string, substring) } - - #[doc = "reverses the `string`"] - pub fn reverse(string: Expr) -> Expr { - super::reverse().call(vec![string]) - } - - #[doc = "returns the last `n` characters in the `string`"] - pub fn right(string: Expr, n: Expr) -> Expr { - super::right().call(vec![string, n]) - } - - #[doc = "fill up a string to the length by appending the characters"] - pub fn rpad(args: Vec) -> Expr { - super::rpad().call(args) - } - - #[doc = "finds the position from where the `substring` matches the `string`"] - pub fn strpos(string: Expr, substring: Expr) -> Expr { - super::strpos().call(vec![string, substring]) - } - - #[doc = "substring from the `position` to the end"] - pub fn substr(string: Expr, position: Expr) -> Expr { - super::substr().call(vec![string, position]) - } - - #[doc = "substring from the `position` with `length` characters"] - pub fn substring(string: Expr, position: Expr, length: Expr) -> Expr { - super::substr().call(vec![string, position, length]) - } - - #[doc = "Returns the substring from str before count occurrences of the delimiter"] - pub fn substr_index(string: Expr, delimiter: Expr, count: Expr) -> Expr { - super::substr_index().call(vec![string, delimiter, count]) - } - - #[doc = "replaces the characters in `from` with the counterpart in `to`"] - pub fn translate(string: Expr, from: Expr, to: Expr) -> Expr { - super::translate().call(vec![string, from, to]) - } } -/// Return a list of all functions in this package +/// Returns all DataFusion functions defined in this package pub fn functions() -> Vec> { vec![ character_length(), diff --git a/datafusion/functions/src/unicode/reverse.rs b/datafusion/functions/src/unicode/reverse.rs index 6b24c2336810..baf3b56636e2 100644 --- a/datafusion/functions/src/unicode/reverse.rs +++ b/datafusion/functions/src/unicode/reverse.rs @@ -16,29 +16,39 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use crate::utils::{make_scalar_function, utf8_to_str_type}; +use arrow::array::{ + Array, ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray, + OffsetSizeTrait, +}; use arrow::datatypes::DataType; - -use datafusion_common::cast::as_generic_string_array; use datafusion_common::{exec_err, Result}; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; - -use crate::utils::{make_scalar_function, utf8_to_str_type}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use DataType::{LargeUtf8, Utf8, Utf8View}; #[derive(Debug)] pub struct ReverseFunc { signature: Signature, } +impl Default for ReverseFunc { + fn default() -> Self { + Self::new() + } +} + impl ReverseFunc { pub fn new() -> Self { use DataType::*; Self { signature: Signature::uniform( 1, - vec![Utf8, LargeUtf8], + vec![Utf8View, Utf8, LargeUtf8], Volatility::Immutable, ), } @@ -64,23 +74,58 @@ impl ScalarUDFImpl for ReverseFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { match args[0].data_type() { - DataType::Utf8 => make_scalar_function(reverse::, vec![])(args), - DataType::LargeUtf8 => make_scalar_function(reverse::, vec![])(args), + Utf8 | Utf8View => make_scalar_function(reverse::, vec![])(args), + LargeUtf8 => make_scalar_function(reverse::, vec![])(args), other => { exec_err!("Unsupported data type {other:?} for function reverse") } } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_reverse_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_reverse_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Reverses the character order of a string.") + .with_syntax_example("reverse(str)") + .with_sql_example( + r#"```sql +> select reverse('datafusion'); ++-----------------------------+ +| reverse(Utf8("datafusion")) | ++-----------------------------+ +| noisufatad | ++-----------------------------+ +```"#, + ) + .with_standard_argument("str", Some("String")) + .build() + .unwrap() + }) } /// Reverses the order of the characters in the string. /// reverse('abcde') = 'edcba' /// The implementation uses UTF-8 code points as characters pub fn reverse(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; + if args[0].data_type() == &Utf8View { + reverse_impl::(args[0].as_string_view()) + } else { + reverse_impl::(args[0].as_string::()) + } +} - let result = string_array - .iter() +fn reverse_impl<'a, T: OffsetSizeTrait, V: ArrayAccessor>( + string_array: V, +) -> Result { + let result = ArrayIter::new(string_array) .map(|string| string.map(|string: &str| string.chars().rev().collect::())) .collect::>(); @@ -89,8 +134,8 @@ pub fn reverse(args: &[ArrayRef]) -> Result { #[cfg(test)] mod tests { - use arrow::array::{Array, StringArray}; - use arrow::datatypes::DataType::Utf8; + use arrow::array::{Array, LargeStringArray, StringArray}; + use arrow::datatypes::DataType::{LargeUtf8, Utf8}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -98,50 +143,49 @@ mod tests { use crate::unicode::reverse::ReverseFunc; use crate::utils::test::test_function; + macro_rules! test_reverse { + ($INPUT:expr, $EXPECTED:expr) => { + test_function!( + ReverseFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], + $EXPECTED, + &str, + Utf8, + StringArray + ); + + test_function!( + ReverseFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], + $EXPECTED, + &str, + LargeUtf8, + LargeStringArray + ); + + test_function!( + ReverseFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], + $EXPECTED, + &str, + Utf8, + StringArray + ); + }; + } + #[test] fn test_functions() -> Result<()> { - test_function!( - ReverseFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::from("abcde"))], - Ok(Some("edcba")), - &str, - Utf8, - StringArray - ); - test_function!( - ReverseFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::from("loẅks"))], - Ok(Some("sk̈wol")), - &str, - Utf8, - StringArray - ); - test_function!( - ReverseFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::from("loẅks"))], - Ok(Some("sk̈wol")), - &str, - Utf8, - StringArray - ); - test_function!( - ReverseFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], - Ok(None), - &str, - Utf8, - StringArray - ); + test_reverse!(Some("abcde".into()), Ok(Some("edcba"))); + test_reverse!(Some("loẅks".into()), Ok(Some("sk̈wol"))); + test_reverse!(Some("loẅks".into()), Ok(Some("sk̈wol"))); + test_reverse!(None, Ok(None)); #[cfg(not(feature = "unicode_expressions"))] - test_function!( - ReverseFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::from("abcde"))], + test_reverse!( + Some("abcde".into()), internal_err!( "function reverse requires compilation with feature flag: unicode_expressions." ), - &str, - Utf8, - StringArray ); Ok(()) diff --git a/datafusion/functions/src/unicode/right.rs b/datafusion/functions/src/unicode/right.rs index dddbf31e721b..ab3b7ba1a27e 100644 --- a/datafusion/functions/src/unicode/right.rs +++ b/datafusion/functions/src/unicode/right.rs @@ -17,30 +17,47 @@ use std::any::Any; use std::cmp::{max, Ordering}; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::array::{ + Array, ArrayAccessor, ArrayIter, ArrayRef, GenericStringArray, Int64Array, + OffsetSizeTrait, +}; use arrow::datatypes::DataType; -use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use crate::utils::{make_scalar_function, utf8_to_str_type}; +use datafusion_common::cast::{ + as_generic_string_array, as_int64_array, as_string_view_array, +}; use datafusion_common::exec_err; use datafusion_common::Result; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; - -use crate::utils::{make_scalar_function, utf8_to_str_type}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct RightFunc { signature: Signature, } +impl Default for RightFunc { + fn default() -> Self { + Self::new() + } +} + impl RightFunc { pub fn new() -> Self { use DataType::*; Self { signature: Signature::one_of( - vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])], + vec![ + Exact(vec![Utf8View, Int64]), + Exact(vec![Utf8, Int64]), + Exact(vec![LargeUtf8, Int64]), + ], Volatility::Immutable, ), } @@ -66,22 +83,70 @@ impl ScalarUDFImpl for RightFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { match args[0].data_type() { - DataType::Utf8 => make_scalar_function(right::, vec![])(args), + DataType::Utf8 | DataType::Utf8View => { + make_scalar_function(right::, vec![])(args) + } DataType::LargeUtf8 => make_scalar_function(right::, vec![])(args), - other => exec_err!("Unsupported data type {other:?} for function right"), + other => exec_err!( + "Unsupported data type {other:?} for function right,\ + expected Utf8View, Utf8 or LargeUtf8." + ), } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_right_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_right_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Returns a specified number of characters from the right side of a string.") + .with_syntax_example("right(str, n)") + .with_sql_example(r#"```sql +> select right('datafusion', 6); ++------------------------------------+ +| right(Utf8("datafusion"),Int64(6)) | ++------------------------------------+ +| fusion | ++------------------------------------+ +```"#) + .with_standard_argument("str", Some("String")) + .with_argument("n", "Number of characters to return") + .with_related_udf("left") + .build() + .unwrap() + }) } /// Returns last n characters in the string, or when n is negative, returns all but first |n| characters. /// right('abcde', 2) = 'de' /// The implementation uses UTF-8 code points as characters pub fn right(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; let n_array = as_int64_array(&args[1])?; + if args[0].data_type() == &DataType::Utf8View { + // string_view_right(args) + let string_array = as_string_view_array(&args[0])?; + right_impl::(&mut string_array.iter(), n_array) + } else { + // string_right::(args) + let string_array = &as_generic_string_array::(&args[0])?; + right_impl::(&mut string_array.iter(), n_array) + } +} - let result = string_array - .iter() +// Currently the return type can only be Utf8 or LargeUtf8, to reach fully support, we need +// to edit the `get_optimal_return_type` in utils.rs to make the udfs be able to return Utf8View +// See https://github.com/apache/datafusion/issues/11790#issuecomment-2283777166 +fn right_impl<'a, T: OffsetSizeTrait, V: ArrayAccessor>( + string_array_iter: &mut ArrayIter, + n_array: &Int64Array, +) -> Result { + let result = string_array_iter .zip(n_array.iter()) .map(|(string, n)| match (string, n) { (Some(string), Some(n)) => match n.cmp(&0) { diff --git a/datafusion/functions/src/unicode/rpad.rs b/datafusion/functions/src/unicode/rpad.rs index 8946f07006b7..bd9d625105e9 100644 --- a/datafusion/functions/src/unicode/rpad.rs +++ b/datafusion/functions/src/unicode/rpad.rs @@ -15,35 +15,55 @@ // specific language governing permissions and limitations // under the License. -use std::any::Any; -use std::sync::Arc; - -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; -use arrow::datatypes::DataType; -use datafusion_common::cast::{as_generic_string_array, as_int64_array}; -use unicode_segmentation::UnicodeSegmentation; - +use crate::strings::StringArrayType; use crate::utils::{make_scalar_function, utf8_to_str_type}; +use arrow::array::{ + ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array, + OffsetSizeTrait, StringViewArray, +}; +use arrow::datatypes::DataType; +use datafusion_common::cast::as_int64_array; +use datafusion_common::DataFusionError; use datafusion_common::{exec_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; +use std::any::Any; +use std::fmt::Write; +use std::sync::{Arc, OnceLock}; +use unicode_segmentation::UnicodeSegmentation; +use DataType::{LargeUtf8, Utf8, Utf8View}; #[derive(Debug)] pub struct RPadFunc { signature: Signature, } +impl Default for RPadFunc { + fn default() -> Self { + Self::new() + } +} + impl RPadFunc { pub fn new() -> Self { use DataType::*; Self { signature: Signature::one_of( vec![ + Exact(vec![Utf8View, Int64]), + Exact(vec![Utf8View, Int64, Utf8View]), + Exact(vec![Utf8View, Int64, Utf8]), + Exact(vec![Utf8View, Int64, LargeUtf8]), Exact(vec![Utf8, Int64]), - Exact(vec![LargeUtf8, Int64]), + Exact(vec![Utf8, Int64, Utf8View]), Exact(vec![Utf8, Int64, Utf8]), - Exact(vec![LargeUtf8, Int64, Utf8]), Exact(vec![Utf8, Int64, LargeUtf8]), + Exact(vec![LargeUtf8, Int64]), + Exact(vec![LargeUtf8, Int64, Utf8View]), + Exact(vec![LargeUtf8, Int64, Utf8]), Exact(vec![LargeUtf8, Int64, LargeUtf8]), ], Volatility::Immutable, @@ -70,99 +90,215 @@ impl ScalarUDFImpl for RPadFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - match args[0].data_type() { - DataType::Utf8 => make_scalar_function(rpad::, vec![])(args), - DataType::LargeUtf8 => make_scalar_function(rpad::, vec![])(args), - other => exec_err!("Unsupported data type {other:?} for function rpad"), + match ( + args.len(), + args[0].data_type(), + args.get(2).map(|arg| arg.data_type()), + ) { + (2, Utf8 | Utf8View, _) => { + make_scalar_function(rpad::, vec![])(args) + } + (2, LargeUtf8, _) => make_scalar_function(rpad::, vec![])(args), + (3, Utf8 | Utf8View, Some(Utf8 | Utf8View)) => { + make_scalar_function(rpad::, vec![])(args) + } + (3, LargeUtf8, Some(LargeUtf8)) => { + make_scalar_function(rpad::, vec![])(args) + } + (3, Utf8 | Utf8View, Some(LargeUtf8)) => { + make_scalar_function(rpad::, vec![])(args) + } + (3, LargeUtf8, Some(Utf8 | Utf8View)) => { + make_scalar_function(rpad::, vec![])(args) + } + (_, _, _) => { + exec_err!("Unsupported combination of data types for function rpad") + } } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_rpad_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_rpad_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Pads the right side of a string with another string to a specified string length.") + .with_syntax_example("rpad(str, n[, padding_str])") + .with_sql_example(r#"```sql +> select rpad('datafusion', 20, '_-'); ++-----------------------------------------------+ +| rpad(Utf8("datafusion"),Int64(20),Utf8("_-")) | ++-----------------------------------------------+ +| datafusion_-_-_-_-_- | ++-----------------------------------------------+ +```"#) + .with_standard_argument( + "str", + Some("String"), + ) + .with_argument("n", "String length to pad to.") + .with_argument("padding_str", + "String expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._") + .with_related_udf("lpad") + .build() + .unwrap() + }) +} + +pub fn rpad( + args: &[ArrayRef], +) -> Result { + if args.len() < 2 || args.len() > 3 { + return exec_err!( + "rpad was called with {} arguments. It requires 2 or 3 arguments.", + args.len() + ); + } + + let length_array = as_int64_array(&args[1])?; + match ( + args.len(), + args[0].data_type(), + args.get(2).map(|arg| arg.data_type()), + ) { + (2, Utf8View, _) => { + rpad_impl::<&StringViewArray, &StringViewArray, StringArrayLen>( + args[0].as_string_view(), + length_array, + None, + ) + } + (3, Utf8View, Some(Utf8View)) => { + rpad_impl::<&StringViewArray, &StringViewArray, StringArrayLen>( + args[0].as_string_view(), + length_array, + Some(args[2].as_string_view()), + ) + } + (3, Utf8View, Some(Utf8 | LargeUtf8)) => { + rpad_impl::<&StringViewArray, &GenericStringArray, StringArrayLen>( + args[0].as_string_view(), + length_array, + Some(args[2].as_string::()), + ) + } + (3, Utf8 | LargeUtf8, Some(Utf8View)) => rpad_impl::< + &GenericStringArray, + &StringViewArray, + StringArrayLen, + >( + args[0].as_string::(), + length_array, + Some(args[2].as_string_view()), + ), + (_, _, _) => rpad_impl::< + &GenericStringArray, + &GenericStringArray, + StringArrayLen, + >( + args[0].as_string::(), + length_array, + args.get(2).map(|arg| arg.as_string::()), + ), + } } /// Extends the string to length 'length' by appending the characters fill (a space by default). If the string is already longer than length then it is truncated. /// rpad('hi', 5, 'xy') = 'hixyx' -pub fn rpad(args: &[ArrayRef]) -> Result { - match args.len() { - 2 => { - let string_array = as_generic_string_array::(&args[0])?; - let length_array = as_int64_array(&args[1])?; +pub fn rpad_impl<'a, StringArrType, FillArrType, StringArrayLen>( + string_array: StringArrType, + length_array: &Int64Array, + fill_array: Option, +) -> Result +where + StringArrType: StringArrayType<'a>, + FillArrType: StringArrayType<'a>, + StringArrayLen: OffsetSizeTrait, +{ + let mut builder: GenericStringBuilder = GenericStringBuilder::new(); - let result = string_array - .iter() - .zip(length_array.iter()) - .map(|(string, length)| match (string, length) { - (Some(string), Some(length)) => { - if length > i32::MAX as i64 { - return exec_err!( - "rpad requested length {length} too large" - ); - } - - let length = if length < 0 { 0 } else { length as usize }; - if length == 0 { - Ok(Some("".to_string())) - } else { - let graphemes = string.graphemes(true).collect::>(); - if length < graphemes.len() { - Ok(Some(graphemes[..length].concat())) + match fill_array { + None => { + string_array.iter().zip(length_array.iter()).try_for_each( + |(string, length)| -> Result<(), DataFusionError> { + match (string, length) { + (Some(string), Some(length)) => { + if length > i32::MAX as i64 { + return exec_err!( + "rpad requested length {} too large", + length + ); + } + let length = if length < 0 { 0 } else { length as usize }; + if length == 0 { + builder.append_value(""); } else { - let mut s = string.to_string(); - s.push_str(" ".repeat(length - graphemes.len()).as_str()); - Ok(Some(s)) + let graphemes = + string.graphemes(true).collect::>(); + if length < graphemes.len() { + builder.append_value(graphemes[..length].concat()); + } else { + builder.write_str(string)?; + builder.write_str( + &" ".repeat(length - graphemes.len()), + )?; + builder.append_value(""); + } } } + _ => builder.append_null(), } - _ => Ok(None), - }) - .collect::>>()?; - Ok(Arc::new(result) as ArrayRef) + Ok(()) + }, + )?; } - 3 => { - let string_array = as_generic_string_array::(&args[0])?; - let length_array = as_int64_array(&args[1])?; - let fill_array = as_generic_string_array::(&args[2])?; - - let result = string_array + Some(fill_array) => { + string_array .iter() .zip(length_array.iter()) .zip(fill_array.iter()) - .map(|((string, length), fill)| match (string, length, fill) { - (Some(string), Some(length), Some(fill)) => { - if length > i32::MAX as i64 { - return exec_err!( - "rpad requested length {length} too large" - ); - } + .try_for_each( + |((string, length), fill)| -> Result<(), DataFusionError> { + match (string, length, fill) { + (Some(string), Some(length), Some(fill)) => { + if length > i32::MAX as i64 { + return exec_err!( + "rpad requested length {} too large", + length + ); + } + let length = if length < 0 { 0 } else { length as usize }; + let graphemes = + string.graphemes(true).collect::>(); - let length = if length < 0 { 0 } else { length as usize }; - let graphemes = string.graphemes(true).collect::>(); - let fill_chars = fill.chars().collect::>(); - - if length < graphemes.len() { - Ok(Some(graphemes[..length].concat())) - } else if fill_chars.is_empty() { - Ok(Some(string.to_string())) - } else { - let mut s = string.to_string(); - let mut char_vector = - Vec::::with_capacity(length - graphemes.len()); - for l in 0..length - graphemes.len() { - char_vector - .push(*fill_chars.get(l % fill_chars.len()).unwrap()); + if length < graphemes.len() { + builder.append_value(graphemes[..length].concat()); + } else if fill.is_empty() { + builder.append_value(string); + } else { + builder.write_str(string)?; + fill.chars() + .cycle() + .take(length - graphemes.len()) + .for_each(|ch| builder.write_char(ch).unwrap()); + builder.append_value(""); + } } - s.push_str(char_vector.iter().collect::().as_str()); - Ok(Some(s)) + _ => builder.append_null(), } - } - _ => Ok(None), - }) - .collect::>>()?; - - Ok(Arc::new(result) as ArrayRef) + Ok(()) + }, + )?; } - other => exec_err!( - "rpad was called with {other} arguments. It requires at least 2 and at most 3." - ), } + + Ok(Arc::new(builder.finish()) as ArrayRef) } #[cfg(test)] diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs index 4ebdd9d58623..9c84590f7f94 100644 --- a/datafusion/functions/src/unicode/strpos.rs +++ b/datafusion/functions/src/unicode/strpos.rs @@ -16,19 +16,17 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; -use arrow::array::{ - ArrayRef, ArrowPrimitiveType, GenericStringArray, OffsetSizeTrait, PrimitiveArray, -}; +use crate::strings::StringArrayType; +use crate::utils::{make_scalar_function, utf8_to_int_type}; +use arrow::array::{ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray}; use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; - -use datafusion_common::cast::as_generic_string_array; use datafusion_common::{exec_err, Result}; -use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; - -use crate::utils::{make_scalar_function, utf8_to_int_type}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct StrposFunc { @@ -36,19 +34,16 @@ pub struct StrposFunc { aliases: Vec, } +impl Default for StrposFunc { + fn default() -> Self { + Self::new() + } +} + impl StrposFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8]), - Exact(vec![Utf8, LargeUtf8]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![LargeUtf8, LargeUtf8]), - ], - Volatility::Immutable, - ), + signature: Signature::string(2, Volatility::Immutable), aliases: vec![String::from("instr"), String::from("position")], } } @@ -72,46 +67,131 @@ impl ScalarUDFImpl for StrposFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - match args[0].data_type() { - DataType::Utf8 => make_scalar_function(strpos::, vec![])(args), - DataType::LargeUtf8 => { - make_scalar_function(strpos::, vec![])(args) - } - other => exec_err!("Unsupported data type {other:?} for function strpos"), - } + make_scalar_function(strpos, vec![])(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_strpos_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_strpos_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Returns the starting position of a specified substring in a string. Positions begin at 1. If the substring does not exist in the string, the function returns 0.") + .with_syntax_example("strpos(str, substr)") + .with_sql_example(r#"```sql +> select strpos('datafusion', 'fus'); ++----------------------------------------+ +| strpos(Utf8("datafusion"),Utf8("fus")) | ++----------------------------------------+ +| 5 | ++----------------------------------------+ +```"#) + .with_standard_argument("str", Some("String")) + .with_argument("substr", "Substring expression to search for.") + .with_alternative_syntax("position(substr in origstr)") + .build() + .unwrap() + }) +} + +fn strpos(args: &[ArrayRef]) -> Result { + match (args[0].data_type(), args[1].data_type()) { + (DataType::Utf8, DataType::Utf8) => { + let string_array = args[0].as_string::(); + let substring_array = args[1].as_string::(); + calculate_strpos::<_, _, Int32Type>(string_array, substring_array) + } + (DataType::Utf8, DataType::LargeUtf8) => { + let string_array = args[0].as_string::(); + let substring_array = args[1].as_string::(); + calculate_strpos::<_, _, Int32Type>(string_array, substring_array) + } + (DataType::LargeUtf8, DataType::Utf8) => { + let string_array = args[0].as_string::(); + let substring_array = args[1].as_string::(); + calculate_strpos::<_, _, Int64Type>(string_array, substring_array) + } + (DataType::LargeUtf8, DataType::LargeUtf8) => { + let string_array = args[0].as_string::(); + let substring_array = args[1].as_string::(); + calculate_strpos::<_, _, Int64Type>(string_array, substring_array) + } + (DataType::Utf8View, DataType::Utf8View) => { + let string_array = args[0].as_string_view(); + let substring_array = args[1].as_string_view(); + calculate_strpos::<_, _, Int32Type>(string_array, substring_array) + } + (DataType::Utf8View, DataType::Utf8) => { + let string_array = args[0].as_string_view(); + let substring_array = args[1].as_string::(); + calculate_strpos::<_, _, Int32Type>(string_array, substring_array) + } + (DataType::Utf8View, DataType::LargeUtf8) => { + let string_array = args[0].as_string_view(); + let substring_array = args[1].as_string::(); + calculate_strpos::<_, _, Int32Type>(string_array, substring_array) + } + + other => { + exec_err!("Unsupported data type combination {other:?} for function strpos") + } + } } /// Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.) /// strpos('high', 'ig') = 2 /// The implementation uses UTF-8 code points as characters -fn strpos(args: &[ArrayRef]) -> Result +fn calculate_strpos<'a, V1, V2, T: ArrowPrimitiveType>( + string_array: V1, + substring_array: V2, +) -> Result where - T::Native: OffsetSizeTrait, + V1: StringArrayType<'a, Item = &'a str>, + V2: StringArrayType<'a, Item = &'a str>, { - let string_array: &GenericStringArray = - as_generic_string_array::(&args[0])?; + let ascii_only = substring_array.is_ascii() && string_array.is_ascii(); + let string_iter = string_array.iter(); + let substring_iter = substring_array.iter(); - let substring_array: &GenericStringArray = - as_generic_string_array::(&args[1])?; - - let result = string_array - .iter() - .zip(substring_array.iter()) + let result = string_iter + .zip(substring_iter) .map(|(string, substring)| match (string, substring) { (Some(string), Some(substring)) => { - // the find method returns the byte index of the substring - // Next, we count the number of the chars until that byte - T::Native::from_usize( - string - .find(substring) - .map(|x| string[..x].chars().count() + 1) - .unwrap_or(0), - ) + // If only ASCII characters are present, we can use the slide window method to find + // the sub vector in the main vector. This is faster than string.find() method. + if ascii_only { + // If the substring is empty, the result is 1. + if substring.as_bytes().is_empty() { + T::Native::from_usize(1) + } else { + T::Native::from_usize( + string + .as_bytes() + .windows(substring.as_bytes().len()) + .position(|w| w == substring.as_bytes()) + .map(|x| x + 1) + .unwrap_or(0), + ) + } + } else { + // The `find` method returns the byte index of the substring. + // We count the number of chars up to that byte index. + T::Native::from_usize( + string + .find(substring) + .map(|x| string[..x].chars().count() + 1) + .unwrap_or(0), + ) + } } _ => None, }) @@ -119,3 +199,97 @@ where Ok(Arc::new(result) as ArrayRef) } + +#[cfg(test)] +mod tests { + use arrow::array::{Array, Int32Array, Int64Array}; + use arrow::datatypes::DataType::{Int32, Int64}; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::strpos::StrposFunc; + use crate::utils::test::test_function; + + macro_rules! test_strpos { + ($lhs:literal, $rhs:literal -> $result:literal; $t1:ident $t2:ident $t3:ident $t4:ident $t5:ident) => { + test_function!( + StrposFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::$t1(Some($lhs.to_owned()))), + ColumnarValue::Scalar(ScalarValue::$t2(Some($rhs.to_owned()))), + ], + Ok(Some($result)), + $t3, + $t4, + $t5 + ) + }; + } + + #[test] + fn test_strpos_functions() { + // Utf8 and Utf8 combinations + test_strpos!("alphabet", "ph" -> 3; Utf8 Utf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "a" -> 1; Utf8 Utf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "z" -> 0; Utf8 Utf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "" -> 1; Utf8 Utf8 i32 Int32 Int32Array); + test_strpos!("", "a" -> 0; Utf8 Utf8 i32 Int32 Int32Array); + test_strpos!("", "" -> 1; Utf8 Utf8 i32 Int32 Int32Array); + test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8 Utf8 i32 Int32 Int32Array); + + // LargeUtf8 and LargeUtf8 combinations + test_strpos!("alphabet", "ph" -> 3; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); + test_strpos!("alphabet", "a" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); + test_strpos!("alphabet", "z" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); + test_strpos!("alphabet", "" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); + test_strpos!("", "a" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); + test_strpos!("", "" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); + test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; LargeUtf8 LargeUtf8 i64 Int64 Int64Array); + + // Utf8 and LargeUtf8 combinations + test_strpos!("alphabet", "ph" -> 3; Utf8 LargeUtf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "a" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "z" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array); + test_strpos!("", "a" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array); + test_strpos!("", "" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array); + test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8 LargeUtf8 i32 Int32 Int32Array); + + // LargeUtf8 and Utf8 combinations + test_strpos!("alphabet", "ph" -> 3; LargeUtf8 Utf8 i64 Int64 Int64Array); + test_strpos!("alphabet", "a" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array); + test_strpos!("alphabet", "z" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array); + test_strpos!("alphabet", "" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array); + test_strpos!("", "a" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array); + test_strpos!("", "" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array); + test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; LargeUtf8 Utf8 i64 Int64 Int64Array); + + // Utf8View and Utf8View combinations + test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8View i32 Int32 Int32Array); + test_strpos!("alphabet", "a" -> 1; Utf8View Utf8View i32 Int32 Int32Array); + test_strpos!("alphabet", "z" -> 0; Utf8View Utf8View i32 Int32 Int32Array); + test_strpos!("alphabet", "" -> 1; Utf8View Utf8View i32 Int32 Int32Array); + test_strpos!("", "a" -> 0; Utf8View Utf8View i32 Int32 Int32Array); + test_strpos!("", "" -> 1; Utf8View Utf8View i32 Int32 Int32Array); + test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View Utf8View i32 Int32 Int32Array); + + // Utf8View and Utf8 combinations + test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "a" -> 1; Utf8View Utf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "z" -> 0; Utf8View Utf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "" -> 1; Utf8View Utf8 i32 Int32 Int32Array); + test_strpos!("", "a" -> 0; Utf8View Utf8 i32 Int32 Int32Array); + test_strpos!("", "" -> 1; Utf8View Utf8 i32 Int32 Int32Array); + test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View Utf8 i32 Int32 Int32Array); + + // Utf8View and LargeUtf8 combinations + test_strpos!("alphabet", "ph" -> 3; Utf8View LargeUtf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "a" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "z" -> 0; Utf8View LargeUtf8 i32 Int32 Int32Array); + test_strpos!("alphabet", "" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array); + test_strpos!("", "a" -> 0; Utf8View LargeUtf8 i32 Int32 Int32Array); + test_strpos!("", "" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array); + test_strpos!("ДатаФусион数据融合📊🔥", "📊" -> 15; Utf8View LargeUtf8 i32 Int32 Int32Array); + } +} diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index 260937a01a74..edfe57210b71 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -16,37 +16,40 @@ // under the License. use std::any::Any; -use std::cmp::max; -use std::sync::Arc; - -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; -use arrow::datatypes::DataType; - -use datafusion_common::cast::{as_generic_string_array, as_int64_array}; -use datafusion_common::{exec_err, Result}; -use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use std::sync::{Arc, OnceLock}; +use crate::strings::{make_and_append_view, StringArrayType}; use crate::utils::{make_scalar_function, utf8_to_str_type}; +use arrow::array::{ + Array, ArrayIter, ArrayRef, AsArray, GenericStringArray, Int64Array, OffsetSizeTrait, + StringViewArray, +}; +use arrow::datatypes::DataType; +use arrow_buffer::{NullBufferBuilder, ScalarBuffer}; +use datafusion_common::cast::as_int64_array; +use datafusion_common::{exec_err, plan_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct SubstrFunc { signature: Signature, + aliases: Vec, +} + +impl Default for SubstrFunc { + fn default() -> Self { + Self::new() + } } impl SubstrFunc { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::one_of( - vec![ - Exact(vec![Utf8, Int64]), - Exact(vec![LargeUtf8, Int64]), - Exact(vec![Utf8, Int64, Int64]), - Exact(vec![LargeUtf8, Int64, Int64]), - ], - Volatility::Immutable, - ), + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec![String::from("substring")], } } } @@ -65,67 +68,432 @@ impl ScalarUDFImpl for SubstrFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - utf8_to_str_type(&arg_types[0], "substr") + if arg_types[0] == DataType::Utf8View { + Ok(DataType::Utf8View) + } else { + utf8_to_str_type(&arg_types[0], "substr") + } } fn invoke(&self, args: &[ColumnarValue]) -> Result { - match args[0].data_type() { - DataType::Utf8 => make_scalar_function(substr::, vec![])(args), - DataType::LargeUtf8 => make_scalar_function(substr::, vec![])(args), - other => exec_err!("Unsupported data type {other:?} for function substr"), + make_scalar_function(substr, vec![])(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() < 2 || arg_types.len() > 3 { + return plan_err!( + "The {} function requires 2 or 3 arguments, but got {}.", + self.name(), + arg_types.len() + ); + } + let first_data_type = match &arg_types[0] { + DataType::Null => Ok(DataType::Utf8), + DataType::LargeUtf8 | DataType::Utf8View | DataType::Utf8 => Ok(arg_types[0].clone()), + DataType::Dictionary(key_type, value_type) => { + if key_type.is_integer() { + match value_type.as_ref() { + DataType::Null => Ok(DataType::Utf8), + DataType::LargeUtf8 | DataType::Utf8View | DataType::Utf8 => Ok(*value_type.clone()), + _ => plan_err!( + "The first argument of the {} function can only be a string, but got {:?}.", + self.name(), + arg_types[0] + ), + } + } else { + plan_err!( + "The first argument of the {} function can only be a string, but got {:?}.", + self.name(), + arg_types[0] + ) + } + } + _ => plan_err!( + "The first argument of the {} function can only be a string, but got {:?}.", + self.name(), + arg_types[0] + ) + }?; + + if ![DataType::Int64, DataType::Int32, DataType::Null].contains(&arg_types[1]) { + return plan_err!( + "The second argument of the {} function can only be an integer, but got {:?}.", + self.name(), + arg_types[1] + ); + } + + if arg_types.len() == 3 + && ![DataType::Int64, DataType::Int32, DataType::Null].contains(&arg_types[2]) + { + return plan_err!( + "The third argument of the {} function can only be an integer, but got {:?}.", + self.name(), + arg_types[2] + ); + } + + if arg_types.len() == 2 { + Ok(vec![first_data_type.to_owned(), DataType::Int64]) + } else { + Ok(vec![ + first_data_type.to_owned(), + DataType::Int64, + DataType::Int64, + ]) } } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_substr_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_substr_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Extracts a substring of a specified number of characters from a specific starting position in a string.") + .with_syntax_example("substr(str, start_pos[, length])") + .with_sql_example(r#"```sql +> select substr('datafusion', 5, 3); ++----------------------------------------------+ +| substr(Utf8("datafusion"),Int64(5),Int64(3)) | ++----------------------------------------------+ +| fus | ++----------------------------------------------+ +```"#) + .with_standard_argument("str", Some("String")) + .with_argument("start_pos", "Character position to start the substring at. The first character in the string has a position of 1.") + .with_argument("length", "Number of characters to extract. If not specified, returns the rest of the string after the start position.") + .with_alternative_syntax("substring(str from start_pos for length)") + .build() + .unwrap() + }) } /// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).) /// substr('alphabet', 3) = 'phabet' /// substr('alphabet', 3, 2) = 'ph' /// The implementation uses UTF-8 code points as characters -pub fn substr(args: &[ArrayRef]) -> Result { +pub fn substr(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::Utf8 => { + let string_array = args[0].as_string::(); + string_substr::<_, i32>(string_array, &args[1..]) + } + DataType::LargeUtf8 => { + let string_array = args[0].as_string::(); + string_substr::<_, i64>(string_array, &args[1..]) + } + DataType::Utf8View => { + let string_array = args[0].as_string_view(); + string_view_substr(string_array, &args[1..]) + } + other => exec_err!( + "Unsupported data type {other:?} for function substr,\ + expected Utf8View, Utf8 or LargeUtf8." + ), + } +} + +// Convert the given `start` and `count` to valid byte indices within `input` string +// +// Input `start` and `count` are equivalent to PostgreSQL's `substr(s, start, count)` +// `start` is 1-based, if `count` is not provided count to the end of the string +// Input indices are character-based, and return values are byte indices +// The input bounds can be outside string bounds, this function will return +// the intersection between input bounds and valid string bounds +// `input_ascii_only` is used to optimize this function if `input` is ASCII-only +// +// * Example +// 'Hi🌏' in-mem (`[]` for one char, `x` for one byte): [x][x][xxxx] +// `get_true_start_end('Hi🌏', 1, None) -> (0, 6)` +// `get_true_start_end('Hi🌏', 1, 1) -> (0, 1)` +// `get_true_start_end('Hi🌏', -10, 2) -> (0, 0)` +fn get_true_start_end( + input: &str, + start: i64, + count: Option, + is_input_ascii_only: bool, +) -> (usize, usize) { + let start = start.checked_sub(1).unwrap_or(start); + + let end = match count { + Some(count) => start + count as i64, + None => input.len() as i64, + }; + let count_to_end = count.is_some(); + + let start = start.clamp(0, input.len() as i64) as usize; + let end = end.clamp(0, input.len() as i64) as usize; + let count = end - start; + + // If input is ASCII-only, byte-based indices equals to char-based indices + if is_input_ascii_only { + return (start, end); + } + + // Otherwise, calculate byte indices from char indices + // Note this decoding is relatively expensive for this simple `substr` function,, + // so the implementation attempts to decode in one pass (and caused the complexity) + let (mut st, mut ed) = (input.len(), input.len()); + let mut start_counting = false; + let mut cnt = 0; + for (char_cnt, (byte_cnt, _)) in input.char_indices().enumerate() { + if char_cnt == start { + st = byte_cnt; + if count_to_end { + start_counting = true; + } else { + break; + } + } + if start_counting { + if cnt == count { + ed = byte_cnt; + break; + } + cnt += 1; + } + } + (st, ed) +} + +// String characters are variable length encoded in UTF-8, `substr()` function's +// arguments are character-based, converting them into byte-based indices +// requires expensive decoding. +// However, checking if a string is ASCII-only is relatively cheap. +// If strings are ASCII only, use byte-based indices instead. +// +// A common pattern to call `substr()` is taking a small prefix of a long +// string, such as `substr(long_str_with_1k_chars, 1, 32)`. +// In such case the overhead of ASCII-validation may not be worth it, so +// skip the validation for short prefix for now. +fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>( + string_array: &V, + start: &Int64Array, + count: Option<&Int64Array>, +) -> bool { + let is_short_prefix = match count { + Some(count) => { + let short_prefix_threshold = 32.0; + let n_sample = 10; + + // HACK: can be simplified if function has specialized + // implementation for `ScalarValue` (implement without `make_scalar_function()`) + let avg_prefix_len = start + .iter() + .zip(count.iter()) + .take(n_sample) + .map(|(start, count)| { + let start = start.unwrap_or(0); + let count = count.unwrap_or(0); + // To get substring, need to decode from 0 to start+count instead of start to start+count + start + count + }) + .sum::(); + + avg_prefix_len as f64 / n_sample as f64 <= short_prefix_threshold + } + None => false, + }; + + if is_short_prefix { + // Skip ASCII validation for short prefix + false + } else { + string_array.is_ascii() + } +} + +// The decoding process refs the trait at: arrow/arrow-data/src/byte_view.rs:44 +// From for ByteView +fn string_view_substr( + string_view_array: &StringViewArray, + args: &[ArrayRef], +) -> Result { + let mut views_buf = Vec::with_capacity(string_view_array.len()); + let mut null_builder = NullBufferBuilder::new(string_view_array.len()); + + let start_array = as_int64_array(&args[0])?; + let count_array_opt = if args.len() == 2 { + Some(as_int64_array(&args[1])?) + } else { + None + }; + + let enable_ascii_fast_path = + enable_ascii_fast_path(&string_view_array, start_array, count_array_opt); + + // In either case of `substr(s, i)` or `substr(s, i, cnt)` + // If any of input argument is `NULL`, the result is `NULL` match args.len() { - 2 => { - let string_array = as_generic_string_array::(&args[0])?; - let start_array = as_int64_array(&args[1])?; + 1 => { + for ((str_opt, raw_view), start_opt) in string_view_array + .iter() + .zip(string_view_array.views().iter()) + .zip(start_array.iter()) + { + if let (Some(str), Some(start)) = (str_opt, start_opt) { + let (start, end) = + get_true_start_end(str, start, None, enable_ascii_fast_path); + let substr = &str[start..end]; - let result = string_array + make_and_append_view( + &mut views_buf, + &mut null_builder, + raw_view, + substr, + start as u32, + ); + } else { + null_builder.append_null(); + views_buf.push(0); + } + } + } + 2 => { + let count_array = count_array_opt.unwrap(); + for (((str_opt, raw_view), start_opt), count_opt) in string_view_array .iter() + .zip(string_view_array.views().iter()) + .zip(start_array.iter()) + .zip(count_array.iter()) + { + if let (Some(str), Some(start), Some(count)) = + (str_opt, start_opt, count_opt) + { + if count < 0 { + return exec_err!( + "negative substring length not allowed: substr(, {start}, {count})" + ); + } else { + if start == i64::MIN { + return exec_err!( + "negative overflow when calculating skip value" + ); + } + let (start, end) = get_true_start_end( + str, + start, + Some(count as u64), + enable_ascii_fast_path, + ); + let substr = &str[start..end]; + + make_and_append_view( + &mut views_buf, + &mut null_builder, + raw_view, + substr, + start as u32, + ); + } + } else { + null_builder.append_null(); + views_buf.push(0); + } + } + } + other => { + return exec_err!( + "substr was called with {other} arguments. It requires 2 or 3." + ) + } + } + + let views_buf = ScalarBuffer::from(views_buf); + let nulls_buf = null_builder.finish(); + + // Safety: + // (1) The blocks of the given views are all provided + // (2) Each of the range `view.offset+start..end` of view in views_buf is within + // the bounds of each of the blocks + unsafe { + let array = StringViewArray::new_unchecked( + views_buf, + string_view_array.data_buffers().to_vec(), + nulls_buf, + ); + Ok(Arc::new(array) as ArrayRef) + } +} + +fn string_substr<'a, V, T>(string_array: V, args: &[ArrayRef]) -> Result +where + V: StringArrayType<'a>, + T: OffsetSizeTrait, +{ + let start_array = as_int64_array(&args[0])?; + let count_array_opt = if args.len() == 2 { + Some(as_int64_array(&args[1])?) + } else { + None + }; + + let enable_ascii_fast_path = + enable_ascii_fast_path(&string_array, start_array, count_array_opt); + + match args.len() { + 1 => { + let iter = ArrayIter::new(string_array); + + let result = iter .zip(start_array.iter()) .map(|(string, start)| match (string, start) { (Some(string), Some(start)) => { - if start <= 0 { - Some(string.to_string()) - } else { - Some(string.chars().skip(start as usize - 1).collect()) - } + let (start, end) = get_true_start_end( + string, + start, + None, + enable_ascii_fast_path, + ); // start, end is byte-based + let substr = &string[start..end]; + Some(substr.to_string()) } _ => None, }) .collect::>(); - Ok(Arc::new(result) as ArrayRef) } - 3 => { - let string_array = as_generic_string_array::(&args[0])?; - let start_array = as_int64_array(&args[1])?; - let count_array = as_int64_array(&args[2])?; + 2 => { + let iter = ArrayIter::new(string_array); + let count_array = count_array_opt.unwrap(); - let result = string_array - .iter() + let result = iter .zip(start_array.iter()) .zip(count_array.iter()) - .map(|((string, start), count)| match (string, start, count) { - (Some(string), Some(start), Some(count)) => { - if count < 0 { - exec_err!( + .map(|((string, start), count)| { + match (string, start, count) { + (Some(string), Some(start), Some(count)) => { + if count < 0 { + exec_err!( "negative substring length not allowed: substr(, {start}, {count})" ) - } else { - let skip = max(0, start - 1); - let count = max(0, count + (if start < 1 {start - 1} else {0})); - Ok(Some(string.chars().skip(skip as usize).take(count as usize).collect::())) + } else { + if start == i64::MIN { + return exec_err!("negative overflow when calculating skip value"); + } + let (start, end) = get_true_start_end( + string, + start, + Some(count as u64), + enable_ascii_fast_path, + ); // start, end is byte-based + let substr = &string[start..end]; + Ok(Some(substr.to_string())) + } } + _ => Ok(None), } - _ => Ok(None), }) .collect::>>()?; @@ -139,8 +507,8 @@ pub fn substr(args: &[ArrayRef]) -> Result { #[cfg(test)] mod tests { - use arrow::array::{Array, StringArray}; - use arrow::datatypes::DataType::Utf8; + use arrow::array::{Array, StringArray, StringViewArray}; + use arrow::datatypes::DataType::{Utf8, Utf8View}; use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -150,6 +518,98 @@ mod tests { #[test] fn test_functions() -> Result<()> { + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(None)), + ColumnarValue::Scalar(ScalarValue::from(1i64)), + ], + Ok(None), + &str, + Utf8View, + StringViewArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "alphabet" + )))), + ColumnarValue::Scalar(ScalarValue::from(0i64)), + ], + Ok(Some("alphabet")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "this és longer than 12B" + )))), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ], + Ok(Some(" é")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "this is longer than 12B" + )))), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ], + Ok(Some(" is longer than 12B")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "joséésoj" + )))), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ], + Ok(Some("ésoj")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "alphabet" + )))), + ColumnarValue::Scalar(ScalarValue::from(3i64)), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ], + Ok(Some("ph")), + &str, + Utf8View, + StringViewArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "alphabet" + )))), + ColumnarValue::Scalar(ScalarValue::from(3i64)), + ColumnarValue::Scalar(ScalarValue::from(20i64)), + ], + Ok(Some("phabet")), + &str, + Utf8View, + StringViewArray + ); test_function!( SubstrFunc::new(), &[ @@ -386,6 +846,29 @@ mod tests { Utf8, StringArray ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("abc")), + ColumnarValue::Scalar(ScalarValue::from(-9223372036854775808i64)), + ], + Ok(Some("abc")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("overflow")), + ColumnarValue::Scalar(ScalarValue::from(-9223372036854775808i64)), + ColumnarValue::Scalar(ScalarValue::from(1i64)), + ], + exec_err!("negative overflow when calculating skip value"), + &str, + Utf8, + StringArray + ); Ok(()) } diff --git a/datafusion/functions/src/unicode/substrindex.rs b/datafusion/functions/src/unicode/substrindex.rs index da4ff55828e9..c04839783f58 100644 --- a/datafusion/functions/src/unicode/substrindex.rs +++ b/datafusion/functions/src/unicode/substrindex.rs @@ -16,17 +16,21 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; -use arrow::array::{ArrayRef, OffsetSizeTrait, StringBuilder}; -use arrow::datatypes::DataType; +use arrow::array::{ + ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait, + PrimitiveArray, StringBuilder, +}; +use arrow::datatypes::{DataType, Int32Type, Int64Type}; -use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::{exec_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; - -use crate::utils::{make_scalar_function, utf8_to_str_type}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct SubstrIndexFunc { @@ -34,12 +38,19 @@ pub struct SubstrIndexFunc { aliases: Vec, } +impl Default for SubstrIndexFunc { + fn default() -> Self { + Self::new() + } +} + impl SubstrIndexFunc { pub fn new() -> Self { use DataType::*; Self { signature: Signature::one_of( vec![ + Exact(vec![Utf8View, Utf8View, Int64]), Exact(vec![Utf8, Utf8, Int64]), Exact(vec![LargeUtf8, LargeUtf8, Int64]), ], @@ -68,20 +79,48 @@ impl ScalarUDFImpl for SubstrIndexFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - match args[0].data_type() { - DataType::Utf8 => make_scalar_function(substr_index::, vec![])(args), - DataType::LargeUtf8 => { - make_scalar_function(substr_index::, vec![])(args) - } - other => { - exec_err!("Unsupported data type {other:?} for function substr_index") - } - } + make_scalar_function(substr_index, vec![])(args) } fn aliases(&self) -> &[String] { &self.aliases } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_substr_index_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_substr_index_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description(r#"Returns the substring from str before count occurrences of the delimiter delim. +If count is positive, everything to the left of the final delimiter (counting from the left) is returned. +If count is negative, everything to the right of the final delimiter (counting from the right) is returned."#) + .with_syntax_example("substr_index(str, delim, count)") + .with_sql_example(r#"```sql +> select substr_index('www.apache.org', '.', 1); ++---------------------------------------------------------+ +| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(1)) | ++---------------------------------------------------------+ +| www | ++---------------------------------------------------------+ +> select substr_index('www.apache.org', '.', -1); ++----------------------------------------------------------+ +| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(-1)) | ++----------------------------------------------------------+ +| org | ++----------------------------------------------------------+ +```"#) + .with_standard_argument("str", Some("String")) + .with_argument("delim", "The string to find in str to split str.") + .with_argument("count", "The number of times to search for the delimiter. Can be either a positive or negative number.") + .build() + .unwrap() + }) } /// Returns the substring from str before count occurrences of the delimiter delim. If count is positive, everything to the left of the final delimiter (counting from the left) is returned. If count is negative, everything to the right of the final delimiter (counting from the right) is returned. @@ -89,7 +128,7 @@ impl ScalarUDFImpl for SubstrIndexFunc { /// SUBSTRING_INDEX('www.apache.org', '.', 2) = www.apache /// SUBSTRING_INDEX('www.apache.org', '.', -2) = apache.org /// SUBSTRING_INDEX('www.apache.org', '.', -1) = org -pub fn substr_index(args: &[ArrayRef]) -> Result { +fn substr_index(args: &[ArrayRef]) -> Result { if args.len() != 3 { return exec_err!( "substr_index was called with {} arguments. It requires 3.", @@ -97,15 +136,63 @@ pub fn substr_index(args: &[ArrayRef]) -> Result { ); } - let string_array = as_generic_string_array::(&args[0])?; - let delimiter_array = as_generic_string_array::(&args[1])?; - let count_array = as_int64_array(&args[2])?; + match args[0].data_type() { + DataType::Utf8 => { + let string_array = args[0].as_string::(); + let delimiter_array = args[1].as_string::(); + let count_array: &PrimitiveArray = args[2].as_primitive(); + substr_index_general::( + string_array, + delimiter_array, + count_array, + ) + } + DataType::LargeUtf8 => { + let string_array = args[0].as_string::(); + let delimiter_array = args[1].as_string::(); + let count_array: &PrimitiveArray = args[2].as_primitive(); + substr_index_general::( + string_array, + delimiter_array, + count_array, + ) + } + DataType::Utf8View => { + let string_array = args[0].as_string_view(); + let delimiter_array = args[1].as_string_view(); + let count_array: &PrimitiveArray = args[2].as_primitive(); + substr_index_general::( + string_array, + delimiter_array, + count_array, + ) + } + other => { + exec_err!("Unsupported data type {other:?} for function substr_index") + } + } +} +pub fn substr_index_general< + 'a, + T: ArrowPrimitiveType, + V: ArrayAccessor, + P: ArrayAccessor, +>( + string_array: V, + delimiter_array: V, + count_array: P, +) -> Result +where + T::Native: OffsetSizeTrait, +{ let mut builder = StringBuilder::new(); - string_array - .iter() - .zip(delimiter_array.iter()) - .zip(count_array.iter()) + let string_iter = ArrayIter::new(string_array); + let delimiter_array_iter = ArrayIter::new(delimiter_array); + let count_array_iter = ArrayIter::new(count_array); + string_iter + .zip(delimiter_array_iter) + .zip(count_array_iter) .for_each(|((string, delimiter), n)| match (string, delimiter, n) { (Some(string), Some(delimiter), Some(n)) => { // In MySQL, these cases will return an empty string. @@ -116,15 +203,15 @@ pub fn substr_index(args: &[ArrayRef]) -> Result { let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX); let length = if n > 0 { - let splitted = string.split(delimiter); - splitted + let split = string.split(delimiter); + split .take(occurrences) .map(|s| s.len() + delimiter.len()) .sum::() - delimiter.len() } else { - let splitted = string.rsplit(delimiter); - splitted + let split = string.rsplit(delimiter); + split .take(occurrences) .map(|s| s.len() + delimiter.len()) .sum::() diff --git a/datafusion/functions/src/unicode/translate.rs b/datafusion/functions/src/unicode/translate.rs index 25daf8738b21..fa626b396b3b 100644 --- a/datafusion/functions/src/unicode/translate.rs +++ b/datafusion/functions/src/unicode/translate.rs @@ -16,31 +16,43 @@ // under the License. use std::any::Any; -use std::sync::Arc; +use std::sync::{Arc, OnceLock}; -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::array::{ + ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray, OffsetSizeTrait, +}; use arrow::datatypes::DataType; use hashbrown::HashMap; use unicode_segmentation::UnicodeSegmentation; -use datafusion_common::cast::as_generic_string_array; +use crate::utils::{make_scalar_function, utf8_to_str_type}; use datafusion_common::{exec_err, Result}; +use datafusion_expr::scalar_doc_sections::DOC_SECTION_STRING; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; - -use crate::utils::{make_scalar_function, utf8_to_str_type}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct TranslateFunc { signature: Signature, } +impl Default for TranslateFunc { + fn default() -> Self { + Self::new() + } +} + impl TranslateFunc { pub fn new() -> Self { use DataType::*; Self { signature: Signature::one_of( - vec![Exact(vec![Utf8, Utf8, Utf8])], + vec![ + Exact(vec![Utf8View, Utf8, Utf8]), + Exact(vec![Utf8, Utf8, Utf8]), + ], Volatility::Immutable, ), } @@ -65,27 +77,82 @@ impl ScalarUDFImpl for TranslateFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - match args[0].data_type() { - DataType::Utf8 => make_scalar_function(translate::, vec![])(args), - DataType::LargeUtf8 => make_scalar_function(translate::, vec![])(args), - other => { - exec_err!("Unsupported data type {other:?} for function translate") - } + make_scalar_function(invoke_translate, vec![])(args) + } + + fn documentation(&self) -> Option<&Documentation> { + Some(get_translate_doc()) + } +} + +static DOCUMENTATION: OnceLock = OnceLock::new(); + +fn get_translate_doc() -> &'static Documentation { + DOCUMENTATION.get_or_init(|| { + Documentation::builder() + .with_doc_section(DOC_SECTION_STRING) + .with_description("Translates characters in a string to specified translation characters.") + .with_syntax_example("translate(str, chars, translation)") + .with_sql_example(r#"```sql +> select translate('twice', 'wic', 'her'); ++--------------------------------------------------+ +| translate(Utf8("twice"),Utf8("wic"),Utf8("her")) | ++--------------------------------------------------+ +| there | ++--------------------------------------------------+ +```"#) + .with_standard_argument("str", Some("String")) + .with_argument("chars", "Characters to translate.") + .with_argument("translation", "Translation characters. Translation characters replace only characters at the same position in the **chars** string.") + .build() + .unwrap() + }) +} + +fn invoke_translate(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::Utf8View => { + let string_array = args[0].as_string_view(); + let from_array = args[1].as_string::(); + let to_array = args[2].as_string::(); + translate::(string_array, from_array, to_array) + } + DataType::Utf8 => { + let string_array = args[0].as_string::(); + let from_array = args[1].as_string::(); + let to_array = args[2].as_string::(); + translate::(string_array, from_array, to_array) + } + DataType::LargeUtf8 => { + let string_array = args[0].as_string::(); + let from_array = args[1].as_string::(); + let to_array = args[2].as_string::(); + translate::(string_array, from_array, to_array) + } + other => { + exec_err!("Unsupported data type {other:?} for function translate") } } } /// Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted. /// translate('12345', '143', 'ax') = 'a2x5' -fn translate(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - let from_array = as_generic_string_array::(&args[1])?; - let to_array = as_generic_string_array::(&args[2])?; - - let result = string_array - .iter() - .zip(from_array.iter()) - .zip(to_array.iter()) +fn translate<'a, T: OffsetSizeTrait, V, B>( + string_array: V, + from_array: B, + to_array: B, +) -> Result +where + V: ArrayAccessor, + B: ArrayAccessor, +{ + let string_array_iter = ArrayIter::new(string_array); + let from_array_iter = ArrayIter::new(from_array); + let to_array_iter = ArrayIter::new(to_array); + + let result = string_array_iter + .zip(from_array_iter) + .zip(to_array_iter) .map(|((string, from), to)| match (string, from, to) { (Some(string), Some(from), Some(to)) => { // create a hashmap of [char, index] to change from O(n) to O(1) for from list diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index d14844c4a445..87180cb77de7 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -17,10 +17,10 @@ use arrow::array::ArrayRef; use arrow::datatypes::DataType; + use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; -use datafusion_physical_expr::functions::Hint; -use std::sync::Arc; +use datafusion_expr::function::Hint; +use datafusion_expr::ColumnarValue; /// Creates a function to identify the optimal return type of a string function given /// the type of its first argument. @@ -29,6 +29,8 @@ use std::sync::Arc; /// `$largeUtf8Type`, /// /// If the input type is `Utf8` or `Binary` the return type is `$utf8Type`, +/// +/// If the input type is `Utf8View` the return type is $utf8Type, macro_rules! get_optimal_return_type { ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => { pub(crate) fn $FUNC(arg_type: &DataType, name: &str) -> Result { @@ -37,6 +39,8 @@ macro_rules! get_optimal_return_type { DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, // Binary inputs are automatically coerced to Utf8 DataType::Utf8 | DataType::Binary => $utf8Type, + // Utf8View max offset size is u32::MAX, the same as UTF8 + DataType::Utf8View | DataType::BinaryView => $utf8Type, DataType::Null => DataType::Null, DataType::Dictionary(_, value_type) => match **value_type { DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, @@ -74,11 +78,11 @@ get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32); pub(super) fn make_scalar_function( inner: F, hints: Vec, -) -> ScalarFunctionImplementation +) -> impl Fn(&[ColumnarValue]) -> Result where - F: Fn(&[ArrayRef]) -> Result + Sync + Send + 'static, + F: Fn(&[ArrayRef]) -> Result, { - Arc::new(move |args: &[ColumnarValue]| { + move |args: &[ColumnarValue]| { // first, identify if any of the arguments is an Array. If yes, store its `len`, // as any scalar will need to be converted to an array of len `len`. let len = args @@ -103,7 +107,7 @@ where }; arg.clone().into_array(expansion_len) }) - .collect::>>()?; + .collect::>>()?; let result = (inner)(&args); if is_scalar { @@ -113,7 +117,7 @@ where } else { result.map(ColumnarValue::Array) } - }) + } } #[cfg(test)] @@ -130,6 +134,13 @@ pub mod test { let func = $FUNC; let type_array = $ARGS.iter().map(|arg| arg.data_type()).collect::>(); + let cardinality = $ARGS + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }) + .unwrap_or(1); let return_type = func.return_type(&type_array); match expected { @@ -137,17 +148,10 @@ pub mod test { assert_eq!(return_type.is_ok(), true); assert_eq!(return_type.unwrap(), $EXPECTED_DATA_TYPE); - let result = func.invoke($ARGS); - assert_eq!(result.is_ok(), true); - - let len = $ARGS - .iter() - .fold(Option::::None, |acc, arg| match arg { - ColumnarValue::Scalar(_) => acc, - ColumnarValue::Array(a) => Some(a.len()), - }); - let inferred_length = len.unwrap_or(1); - let result = result.unwrap().clone().into_array(inferred_length).expect("Failed to convert to array"); + let result = func.invoke_batch($ARGS, cardinality); + assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err()); + + let result = result.unwrap().clone().into_array(cardinality).expect("Failed to convert to array"); let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().expect("Failed to convert to type"); // value is correct @@ -165,7 +169,7 @@ pub mod test { } else { // invoke is expected error - cannot use .expect_err() due to Debug not being implemented - match func.invoke($ARGS) { + match func.invoke_batch($ARGS, cardinality) { Ok(_) => assert!(false, "expected error"), Err(error) => { assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace())); @@ -177,6 +181,21 @@ pub mod test { }; } + use arrow::datatypes::DataType; #[allow(unused_imports)] pub(crate) use test_function; + + use super::*; + + #[test] + fn string_to_int_type() { + let v = utf8_to_int_type(&DataType::Utf8, "test").unwrap(); + assert_eq!(v, DataType::Int32); + + let v = utf8_to_int_type(&DataType::Utf8View, "test").unwrap(); + assert_eq!(v, DataType::Int32); + + let v = utf8_to_int_type(&DataType::LargeUtf8, "test").unwrap(); + assert_eq!(v, DataType::Int64); + } } diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index b1a6953501a6..79a5bb24e918 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -35,10 +35,6 @@ workspace = true name = "datafusion_optimizer" path = "src/lib.rs" -[features] -default = ["regex_expressions"] -regex_expressions = ["datafusion-physical-expr/regex_expressions"] - [dependencies] arrow = { workspace = true } async-trait = { workspace = true } @@ -46,12 +42,17 @@ chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } datafusion-physical-expr = { workspace = true } -hashbrown = { version = "0.14", features = ["raw"] } +hashbrown = { workspace = true } +indexmap = { workspace = true } itertools = { workspace = true } log = { workspace = true } +paste = "1.0.14" regex-syntax = "0.8.0" [dev-dependencies] +arrow-buffer = { workspace = true } ctor = { workspace = true } +datafusion-functions-aggregate = { workspace = true } +datafusion-functions-window-common = { workspace = true } datafusion-sql = { workspace = true } env_logger = { workspace = true } diff --git a/datafusion/optimizer/README.md b/datafusion/optimizer/README.md index 2f1f85e3a57a..61bc1cd70145 100644 --- a/datafusion/optimizer/README.md +++ b/datafusion/optimizer/README.md @@ -17,318 +17,6 @@ under the License. --> -# DataFusion Query Optimizer +Please see [Query Optimizer] in the Library User Guide -[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory -format. - -DataFusion has modular design, allowing individual crates to be re-used in other projects. - -This crate is a submodule of DataFusion that provides a query optimizer for logical plans, and -contains an extensive set of OptimizerRules that may rewrite the plan and/or its expressions so -they execute more quickly while still computing the same result. - -## Running the Optimizer - -The following code demonstrates the basic flow of creating the optimizer with a default set of optimization rules -and applying it to a logical plan to produce an optimized logical plan. - -```rust - -// We need a logical plan as the starting point. There are many ways to build a logical plan: -// -// The `datafusion-expr` crate provides a LogicalPlanBuilder -// The `datafusion-sql` crate provides a SQL query planner that can create a LogicalPlan from SQL -// The `datafusion` crate provides a DataFrame API that can create a LogicalPlan -let logical_plan = ... - -let mut config = OptimizerContext::default(); -let optimizer = Optimizer::new(&config); -let optimized_plan = optimizer.optimize(&logical_plan, &config, observe)?; - -fn observe(plan: &LogicalPlan, rule: &dyn OptimizerRule) { - println!( - "After applying rule '{}':\n{}", - rule.name(), - plan.display_indent() - ) -} -``` - -## Providing Custom Rules - -The optimizer can be created with a custom set of rules. - -```rust -let optimizer = Optimizer::with_rules(vec![ - Arc::new(MyRule {}) -]); -``` - -## Writing Optimization Rules - -Please refer to the [rewrite_expr example](../../datafusion-examples/examples/rewrite_expr.rs) to learn more about -the general approach to writing optimizer rules and then move onto studying the existing rules. - -All rules must implement the `OptimizerRule` trait. - -```rust -/// `OptimizerRule` transforms one ['LogicalPlan'] into another which -/// computes the same results, but in a potentially more efficient -/// way. If there are no suitable transformations for the input plan, -/// the optimizer can simply return it as is. -pub trait OptimizerRule { - /// Rewrite `plan` to an optimized form - fn optimize( - &self, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, - ) -> Result; - - /// A human readable name for this optimizer rule - fn name(&self) -> &str; -} -``` - -### General Guidelines - -Rules typical walk the logical plan and walk the expression trees inside operators and selectively mutate -individual operators or expressions. - -Sometimes there is an initial pass that visits the plan and builds state that is used in a second pass that performs -the actual optimization. This approach is used in projection push down and filter push down. - -### Expression Naming - -Every expression in DataFusion has a name, which is used as the column name. For example, in this example the output -contains a single column with the name `"COUNT(aggregate_test_100.c9)"`: - -```text -> select count(c9) from aggregate_test_100; -+------------------------------+ -| COUNT(aggregate_test_100.c9) | -+------------------------------+ -| 100 | -+------------------------------+ -``` - -These names are used to refer to the columns in both subqueries as well as internally from one stage of the LogicalPlan -to another. For example: - -```text -> select "COUNT(aggregate_test_100.c9)" + 1 from (select count(c9) from aggregate_test_100) as sq; -+--------------------------------------------+ -| sq.COUNT(aggregate_test_100.c9) + Int64(1) | -+--------------------------------------------+ -| 101 | -+--------------------------------------------+ -``` - -### Implication - -Because DataFusion identifies columns using a string name, it means it is critical that the names of expressions are -not changed by the optimizer when it rewrites expressions. This is typically accomplished by renaming a rewritten -expression by adding an alias. - -Here is a simple example of such a rewrite. The expression `1 + 2` can be internally simplified to 3 but must still be -displayed the same as `1 + 2`: - -```text -> select 1 + 2; -+---------------------+ -| Int64(1) + Int64(2) | -+---------------------+ -| 3 | -+---------------------+ -``` - -Looking at the `EXPLAIN` output we can see that the optimizer has effectively rewritten `1 + 2` into effectively -`3 as "1 + 2"`: - -```text -> explain select 1 + 2; -+---------------+-------------------------------------------------+ -| plan_type | plan | -+---------------+-------------------------------------------------+ -| logical_plan | Projection: Int64(3) AS Int64(1) + Int64(2) | -| | EmptyRelation | -| physical_plan | ProjectionExec: expr=[3 as Int64(1) + Int64(2)] | -| | PlaceholderRowExec | -| | | -+---------------+-------------------------------------------------+ -``` - -If the expression name is not preserved, bugs such as [#3704](https://github.com/apache/datafusion/issues/3704) -and [#3555](https://github.com/apache/datafusion/issues/3555) occur where the expected columns can not be found. - -### Building Expression Names - -There are currently two ways to create a name for an expression in the logical plan. - -```rust -impl Expr { - /// Returns the name of this expression as it should appear in a schema. This name - /// will not include any CAST expressions. - pub fn display_name(&self) -> Result { - create_name(self) - } - - /// Returns a full and complete string representation of this expression. - pub fn canonical_name(&self) -> String { - format!("{}", self) - } -} -``` - -When comparing expressions to determine if they are equivalent, `canonical_name` should be used, and when creating a -name to be used in a schema, `display_name` should be used. - -### Utilities - -There are a number of utility methods provided that take care of some common tasks. - -### ExprVisitor - -The `ExprVisitor` and `ExprVisitable` traits provide a mechanism for applying a visitor pattern to an expression tree. - -Here is an example that demonstrates this. - -```rust -fn extract_subquery_filters(expression: &Expr, extracted: &mut Vec) -> Result<()> { - struct InSubqueryVisitor<'a> { - accum: &'a mut Vec, - } - - impl ExpressionVisitor for InSubqueryVisitor<'_> { - fn pre_visit(self, expr: &Expr) -> Result> { - if let Expr::InSubquery(_) = expr { - self.accum.push(expr.to_owned()); - } - Ok(Recursion::Continue(self)) - } - } - - expression.accept(InSubqueryVisitor { accum: extracted })?; - Ok(()) -} -``` - -### Rewriting Expressions - -The `MyExprRewriter` trait can be implemented to provide a way to rewrite expressions. This rule can then be applied -to an expression by calling `Expr::rewrite` (from the `ExprRewritable` trait). - -The `rewrite` method will perform a depth first walk of the expression and its children to rewrite an expression, -consuming `self` producing a new expression. - -```rust -let mut expr_rewriter = MyExprRewriter {}; -let expr = expr.rewrite(&mut expr_rewriter)?; -``` - -Here is an example implementation which will rewrite `expr BETWEEN a AND b` as `expr >= a AND expr <= b`. Note that the -implementation does not need to perform any recursion since this is handled by the `rewrite` method. - -```rust -struct MyExprRewriter {} - -impl ExprRewriter for MyExprRewriter { - fn mutate(&mut self, expr: Expr) -> Result { - match expr { - Expr::Between { - negated, - expr, - low, - high, - } => { - let expr: Expr = expr.as_ref().clone(); - let low: Expr = low.as_ref().clone(); - let high: Expr = high.as_ref().clone(); - if negated { - Ok(expr.clone().lt(low).or(expr.clone().gt(high))) - } else { - Ok(expr.clone().gt_eq(low).and(expr.clone().lt_eq(high))) - } - } - _ => Ok(expr.clone()), - } - } -} -``` - -### optimize_children - -Typically a rule is applied recursively to all operators within a query plan. Rather than duplicate -that logic in each rule, an `optimize_children` method is provided. This recursively invokes the `optimize` method on -the plan's children and then returns a node of the same type. - -```rust -fn optimize( - &self, - plan: &LogicalPlan, - _config: &mut OptimizerConfig, -) -> Result { - // recurse down and optimize children first - let plan = utils::optimize_children(self, plan, _config)?; - - ... -} -``` - -### Writing Tests - -There should be unit tests in the same file as the new rule that test the effect of the rule being applied to a plan -in isolation (without any other rule being applied). - -There should also be a test in `integration-tests.rs` that tests the rule as part of the overall optimization process. - -### Debugging - -The `EXPLAIN VERBOSE` command can be used to show the effect of each optimization rule on a query. - -In the following example, the `type_coercion` and `simplify_expressions` passes have simplified the plan so that it returns the constant `"3.2"` rather than doing a computation at execution time. - -```text -> explain verbose select cast(1 + 2.2 as string) as foo; -+------------------------------------------------------------+---------------------------------------------------------------------------+ -| plan_type | plan | -+------------------------------------------------------------+---------------------------------------------------------------------------+ -| initial_logical_plan | Projection: CAST(Int64(1) + Float64(2.2) AS Utf8) AS foo | -| | EmptyRelation | -| logical_plan after type_coercion | Projection: CAST(CAST(Int64(1) AS Float64) + Float64(2.2) AS Utf8) AS foo | -| | EmptyRelation | -| logical_plan after simplify_expressions | Projection: Utf8("3.2") AS foo | -| | EmptyRelation | -| logical_plan after unwrap_cast_in_comparison | SAME TEXT AS ABOVE | -| logical_plan after decorrelate_where_exists | SAME TEXT AS ABOVE | -| logical_plan after decorrelate_where_in | SAME TEXT AS ABOVE | -| logical_plan after scalar_subquery_to_join | SAME TEXT AS ABOVE | -| logical_plan after subquery_filter_to_join | SAME TEXT AS ABOVE | -| logical_plan after simplify_expressions | SAME TEXT AS ABOVE | -| logical_plan after eliminate_filter | SAME TEXT AS ABOVE | -| logical_plan after reduce_cross_join | SAME TEXT AS ABOVE | -| logical_plan after common_sub_expression_eliminate | SAME TEXT AS ABOVE | -| logical_plan after eliminate_limit | SAME TEXT AS ABOVE | -| logical_plan after projection_push_down | SAME TEXT AS ABOVE | -| logical_plan after rewrite_disjunctive_predicate | SAME TEXT AS ABOVE | -| logical_plan after reduce_outer_join | SAME TEXT AS ABOVE | -| logical_plan after filter_push_down | SAME TEXT AS ABOVE | -| logical_plan after limit_push_down | SAME TEXT AS ABOVE | -| logical_plan after single_distinct_aggregation_to_group_by | SAME TEXT AS ABOVE | -| logical_plan | Projection: Utf8("3.2") AS foo | -| | EmptyRelation | -| initial_physical_plan | ProjectionExec: expr=[3.2 as foo] | -| | PlaceholderRowExec | -| | | -| physical_plan after aggregate_statistics | SAME TEXT AS ABOVE | -| physical_plan after join_selection | SAME TEXT AS ABOVE | -| physical_plan after coalesce_batches | SAME TEXT AS ABOVE | -| physical_plan after repartition | SAME TEXT AS ABOVE | -| physical_plan after add_merge_exec | SAME TEXT AS ABOVE | -| physical_plan | ProjectionExec: expr=[3.2 as foo] | -| | PlaceholderRowExec | -| | | -+------------------------------------------------------------+---------------------------------------------------------------------------+ -``` - -[df]: https://crates.io/crates/datafusion +[query optimizer]: https://datafusion.apache.org/library-user-guide/query-optimizer.html diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 1ab3d1a81038..454afa24b628 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -21,21 +21,19 @@ use crate::utils::NamePreserver; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; -use datafusion_expr::expr::{ - AggregateFunction, AggregateFunctionDefinition, WindowFunction, -}; +use datafusion_expr::expr::{AggregateFunction, WindowFunction}; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition}; /// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`. /// /// Resolves issue: -#[derive(Default)] +#[derive(Default, Debug)] pub struct CountWildcardRule {} impl CountWildcardRule { pub fn new() -> Self { - CountWildcardRule {} + Self {} } } @@ -50,33 +48,29 @@ impl AnalyzerRule for CountWildcardRule { } fn is_wildcard(expr: &Expr) -> bool { - matches!(expr, Expr::Wildcard { qualifier: None }) + matches!(expr, Expr::Wildcard { .. }) } fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool { - matches!( - &aggregate_function.func_def, - AggregateFunctionDefinition::BuiltIn( - datafusion_expr::aggregate_function::AggregateFunction::Count, - ) - ) && aggregate_function.args.len() == 1 - && is_wildcard(&aggregate_function.args[0]) + matches!(aggregate_function, + AggregateFunction { + func, + args, + .. + } if func.name() == "count" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty())) } fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool { - matches!( - &window_function.fun, - WindowFunctionDefinition::AggregateFunction( - datafusion_expr::aggregate_function::AggregateFunction::Count, - ) - ) && window_function.args.len() == 1 - && is_wildcard(&window_function.args[0]) + let args = &window_function.args; + matches!(window_function.fun, + WindowFunctionDefinition::AggregateUDF(ref udaf) + if udaf.name() == "count" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty())) } fn analyze_internal(plan: LogicalPlan) -> Result> { let name_preserver = NamePreserver::new(&plan); plan.map_expressions(|expr| { - let original_name = name_preserver.save(&expr)?; + let original_name = name_preserver.save(&expr); let transformed_expr = expr.transform_up(|expr| match expr { Expr::WindowFunction(mut window_function) if is_count_star_window_aggregate(&window_function) => @@ -94,7 +88,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { } _ => Ok(Transformed::no(expr)), })?; - transformed_expr.map_data(|data| original_name.restore(data)) + Ok(transformed_expr.update_data(|data| original_name.restore(data))) }) } @@ -105,14 +99,18 @@ mod tests { use arrow::datatypes::DataType; use datafusion_common::ScalarValue; use datafusion_expr::expr::Sort; + use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ - col, count, exists, expr, in_subquery, logical_plan::LogicalPlanBuilder, max, - out_ref_col, scalar_subquery, sum, wildcard, AggregateFunction, WindowFrame, - WindowFrameBound, WindowFrameUnits, + col, exists, in_subquery, logical_plan::LogicalPlanBuilder, out_ref_col, + scalar_subquery, wildcard, WindowFrame, WindowFrameBound, WindowFrameUnits, }; + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_functions_aggregate::expr_fn::max; use std::sync::Arc; - fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + use datafusion_functions_aggregate::expr_fn::{count, sum}; + + fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { assert_analyzed_plan_eq_display_indent( Arc::new(CountWildcardRule::new()), plan, @@ -128,11 +126,11 @@ mod tests { .project(vec![count(wildcard())])? .sort(vec![count(wildcard()).sort(true, false)])? .build()?; - let expected = "Sort: COUNT(*) ASC NULLS LAST [COUNT(*):Int64;N]\ - \n Projection: COUNT(*) [COUNT(*):Int64;N]\ - \n Aggregate: groupBy=[[test.b]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] [b:UInt32, COUNT(*):Int64;N]\ + let expected = "Sort: count(*) ASC NULLS LAST [count(*):Int64]\ + \n Projection: count(*) [count(*):Int64]\ + \n Aggregate: groupBy=[[test.b]], aggr=[[count(Int64(1)) AS count(*)]] [b:UInt32, count(*):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -153,12 +151,12 @@ mod tests { .build()?; let expected = "Filter: t1.a IN () [a:UInt32, b:UInt32, c:UInt32]\ - \n Subquery: [COUNT(*):Int64;N]\ - \n Projection: COUNT(*) [COUNT(*):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] [COUNT(*):Int64;N]\ + \n Subquery: [count(*):Int64]\ + \n Projection: count(*) [count(*):Int64]\ + \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -176,12 +174,12 @@ mod tests { .build()?; let expected = "Filter: EXISTS () [a:UInt32, b:UInt32, c:UInt32]\ - \n Subquery: [COUNT(*):Int64;N]\ - \n Projection: COUNT(*) [COUNT(*):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] [COUNT(*):Int64;N]\ + \n Subquery: [count(*):Int64]\ + \n Projection: count(*) [count(*):Int64]\ + \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -208,38 +206,37 @@ mod tests { let expected = "Projection: t1.a, t1.b [a:UInt32, b:UInt32]\ \n Filter: () > UInt8(0) [a:UInt32, b:UInt32, c:UInt32]\ - \n Subquery: [COUNT(Int64(1)):Int64;N]\ - \n Projection: COUNT(Int64(1)) [COUNT(Int64(1)):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]] [COUNT(Int64(1)):Int64;N]\ + \n Subquery: [count(Int64(1)):Int64]\ + \n Projection: count(Int64(1)) [count(Int64(1)):Int64]\ + \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] [count(Int64(1)):Int64]\ \n Filter: outer_ref(t1.a) = t2.a [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] fn test_count_wildcard_on_window() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .window(vec![Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + .window(vec![Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(count_udaf()), vec![wildcard()], - vec![], - vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], - WindowFrame::new_bounds( - WindowFrameUnits::Range, - WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), - WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), - ), - None, - ))])? + )) + .order_by(vec![Sort::new(col("a"), false, true)]) + .window_frame(WindowFrame::new_bounds( + WindowFrameUnits::Range, + WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), + WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), + )) + .build()?])? .project(vec![count(wildcard())])? .build()?; - let expected = "Projection: COUNT(Int64(1)) AS COUNT(*) [COUNT(*):Int64;N]\ - \n WindowAggr: windowExpr=[[COUNT(Int64(1)) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS COUNT(*) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] [a:UInt32, b:UInt32, c:UInt32, COUNT(*) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING:Int64;N]\ + let expected = "Projection: count(Int64(1)) AS count(*) [count(*):Int64]\ + \n WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS count(*) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] [a:UInt32, b:UInt32, c:UInt32, count(*) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING:Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -250,20 +247,18 @@ mod tests { .project(vec![count(wildcard())])? .build()?; - let expected = "Projection: COUNT(*) [COUNT(*):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] [COUNT(*):Int64;N]\ + let expected = "Projection: count(*) [count(*):Int64]\ + \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] [count(*):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] fn test_count_wildcard_on_non_count_aggregate() -> Result<()> { let table_scan = test_table_scan()?; - let err = LogicalPlanBuilder::from(table_scan) - .aggregate(Vec::::new(), vec![sum(wildcard())]) - .unwrap_err() - .to_string(); - assert!(err.contains("Error during planning: No function matches the given name and argument types 'SUM(Null)'."), "{err}"); + let res = LogicalPlanBuilder::from(table_scan) + .aggregate(Vec::::new(), vec![sum(wildcard())]); + assert!(res.is_err()); Ok(()) } @@ -275,9 +270,9 @@ mod tests { .project(vec![count(wildcard())])? .build()?; - let expected = "Projection: COUNT(Int64(1)) AS COUNT(*) [COUNT(*):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[MAX(COUNT(Int64(1))) AS MAX(COUNT(*))]] [MAX(COUNT(*)):Int64;N]\ + let expected = "Projection: count(Int64(1)) AS count(*) [count(*):Int64]\ + \n Aggregate: groupBy=[[]], aggr=[[max(count(Int64(1))) AS max(count(*))]] [max(count(*)):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } } diff --git a/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs new file mode 100644 index 000000000000..9fbe54e1ccb9 --- /dev/null +++ b/datafusion/optimizer/src/analyzer/expand_wildcard_rule.rs @@ -0,0 +1,330 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use crate::AnalyzerRule; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TransformedResult}; +use datafusion_common::{Column, Result}; +use datafusion_expr::builder::validate_unique_names; +use datafusion_expr::expr::PlannedReplaceSelectItem; +use datafusion_expr::utils::{ + expand_qualified_wildcard, expand_wildcard, find_base_plan, +}; +use datafusion_expr::{ + Distinct, DistinctOn, Expr, LogicalPlan, Projection, SubqueryAlias, +}; + +#[derive(Default, Debug)] +pub struct ExpandWildcardRule {} + +impl ExpandWildcardRule { + pub fn new() -> Self { + Self {} + } +} + +impl AnalyzerRule for ExpandWildcardRule { + fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { + // Because the wildcard expansion is based on the schema of the input plan, + // using `transform_up_with_subqueries` here. + plan.transform_up_with_subqueries(expand_internal).data() + } + + fn name(&self) -> &str { + "expand_wildcard_rule" + } +} + +fn expand_internal(plan: LogicalPlan) -> Result> { + match plan { + LogicalPlan::Projection(Projection { expr, input, .. }) => { + let projected_expr = expand_exprlist(&input, expr)?; + validate_unique_names("Projections", projected_expr.iter())?; + Ok(Transformed::yes( + Projection::try_new(projected_expr, Arc::clone(&input)) + .map(LogicalPlan::Projection)?, + )) + } + // The schema of the plan should also be updated if the child plan is transformed. + LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => { + Ok(Transformed::yes( + SubqueryAlias::try_new(input, alias).map(LogicalPlan::SubqueryAlias)?, + )) + } + LogicalPlan::Distinct(Distinct::On(distinct_on)) => { + let projected_expr = + expand_exprlist(&distinct_on.input, distinct_on.select_expr)?; + validate_unique_names("Distinct", projected_expr.iter())?; + Ok(Transformed::yes(LogicalPlan::Distinct(Distinct::On( + DistinctOn::try_new( + distinct_on.on_expr, + projected_expr, + distinct_on.sort_expr, + distinct_on.input, + )?, + )))) + } + _ => Ok(Transformed::no(plan)), + } +} + +fn expand_exprlist(input: &LogicalPlan, expr: Vec) -> Result> { + let mut projected_expr = vec![]; + let input = find_base_plan(input); + for e in expr { + match e { + Expr::Wildcard { qualifier, options } => { + if let Some(qualifier) = qualifier { + let expanded = expand_qualified_wildcard( + &qualifier, + input.schema(), + Some(&options), + )?; + // If there is a REPLACE statement, replace that column with the given + // replace expression. Column name remains the same. + let replaced = if let Some(replace) = options.replace { + replace_columns(expanded, &replace)? + } else { + expanded + }; + projected_expr.extend(replaced); + } else { + let expanded = + expand_wildcard(input.schema(), input, Some(&options))?; + // If there is a REPLACE statement, replace that column with the given + // replace expression. Column name remains the same. + let replaced = if let Some(replace) = options.replace { + replace_columns(expanded, &replace)? + } else { + expanded + }; + projected_expr.extend(replaced); + } + } + // A workaround to handle the case when the column name is "*". + // We transform the expression to a Expr::Column through [Column::from_name] in many places. + // It would also convert the wildcard expression to a column expression with name "*". + Expr::Column(Column { + ref relation, + ref name, + }) => { + if name.eq("*") { + if let Some(qualifier) = relation { + projected_expr.extend(expand_qualified_wildcard( + qualifier, + input.schema(), + None, + )?); + } else { + projected_expr.extend(expand_wildcard( + input.schema(), + input, + None, + )?); + } + } else { + projected_expr.push(e.clone()); + } + } + _ => projected_expr.push(e), + } + } + Ok(projected_expr) +} + +/// If there is a REPLACE statement in the projected expression in the form of +/// "REPLACE (some_column_within_an_expr AS some_column)", this function replaces +/// that column with the given replace expression. Column name remains the same. +/// Multiple REPLACEs are also possible with comma separations. +fn replace_columns( + mut exprs: Vec, + replace: &PlannedReplaceSelectItem, +) -> Result> { + for expr in exprs.iter_mut() { + if let Expr::Column(Column { name, .. }) = expr { + if let Some((_, new_expr)) = replace + .items() + .iter() + .zip(replace.expressions().iter()) + .find(|(item, _)| item.column_name.value == *name) + { + *expr = new_expr.clone().alias(name.clone()) + } + } + } + Ok(exprs) +} + +#[cfg(test)] +mod tests { + use arrow::datatypes::{DataType, Field, Schema}; + + use crate::test::{assert_analyzed_plan_eq_display_indent, test_table_scan}; + use crate::Analyzer; + use datafusion_common::{JoinType, TableReference}; + use datafusion_expr::{ + col, in_subquery, qualified_wildcard, table_scan, wildcard, LogicalPlanBuilder, + }; + + use super::*; + + fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { + assert_analyzed_plan_eq_display_indent( + Arc::new(ExpandWildcardRule::new()), + plan, + expected, + ) + } + + #[test] + fn test_expand_wildcard() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![wildcard()])? + .build()?; + let expected = + "Projection: test.a, test.b, test.c [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + assert_plan_eq(plan, expected) + } + + #[test] + fn test_expand_qualified_wildcard() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![qualified_wildcard(TableReference::bare("test"))])? + .build()?; + let expected = + "Projection: test.a, test.b, test.c [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + assert_plan_eq(plan, expected) + } + + #[test] + fn test_expand_qualified_wildcard_in_subquery() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![qualified_wildcard(TableReference::bare("test"))])? + .build()?; + let plan = LogicalPlanBuilder::from(plan) + .project(vec![wildcard()])? + .build()?; + let expected = + "Projection: test.a, test.b, test.c [a:UInt32, b:UInt32, c:UInt32]\ + \n Projection: test.a, test.b, test.c [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + assert_plan_eq(plan, expected) + } + + #[test] + fn test_expand_wildcard_in_subquery() -> Result<()> { + let projection_a = LogicalPlanBuilder::from(test_table_scan()?) + .project(vec![col("a")])? + .build()?; + let subquery = LogicalPlanBuilder::from(projection_a) + .project(vec![wildcard()])? + .build()?; + let plan = LogicalPlanBuilder::from(test_table_scan()?) + .filter(in_subquery(col("a"), Arc::new(subquery)))? + .project(vec![wildcard()])? + .build()?; + let expected = "\ + Projection: test.a, test.b, test.c [a:UInt32, b:UInt32, c:UInt32]\ + \n Filter: test.a IN () [a:UInt32, b:UInt32, c:UInt32]\ + \n Subquery: [a:UInt32]\ + \n Projection: test.a [a:UInt32]\ + \n Projection: test.a [a:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + assert_plan_eq(plan, expected) + } + + #[test] + fn test_expand_wildcard_in_distinct_on() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .distinct_on(vec![col("a")], vec![wildcard()], None)? + .build()?; + let expected = "\ + DistinctOn: on_expr=[[test.a]], select_expr=[[test.a, test.b, test.c]], sort_expr=[[]] [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + assert_plan_eq(plan, expected) + } + + #[test] + fn test_subquery_schema() -> Result<()> { + let analyzer = Analyzer::with_rules(vec![Arc::new(ExpandWildcardRule::new())]); + let options = ConfigOptions::default(); + let subquery = LogicalPlanBuilder::from(test_table_scan()?) + .project(vec![wildcard()])? + .build()?; + let plan = LogicalPlanBuilder::from(subquery) + .alias("sub")? + .project(vec![wildcard()])? + .build()?; + let analyzed_plan = analyzer.execute_and_check(plan, &options, |_, _| {})?; + for x in analyzed_plan.inputs() { + for field in x.schema().fields() { + assert_ne!(field.name(), "*"); + } + } + Ok(()) + } + + fn employee_schema() -> Schema { + Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("first_name", DataType::Utf8, false), + Field::new("last_name", DataType::Utf8, false), + Field::new("state", DataType::Utf8, false), + Field::new("salary", DataType::Int32, false), + ]) + } + + #[test] + fn plan_using_join_wildcard_projection() -> Result<()> { + let t2 = table_scan(Some("t2"), &employee_schema(), None)?.build()?; + + let plan = table_scan(Some("t1"), &employee_schema(), None)? + .join_using(t2, JoinType::Inner, vec!["id"])? + .project(vec![wildcard()])? + .build()?; + + let expected = "Projection: *\ + \n Inner Join: Using t1.id = t2.id\ + \n TableScan: t1\ + \n TableScan: t2"; + + assert_eq!(expected, format!("{plan}")); + + let analyzer = Analyzer::with_rules(vec![Arc::new(ExpandWildcardRule::new())]); + let options = ConfigOptions::default(); + + let analyzed_plan = analyzer.execute_and_check(plan, &options, |_, _| {})?; + + // id column should only show up once in projection + let expected = "Projection: t1.id, t1.first_name, t1.last_name, t1.state, t1.salary, t2.first_name, t2.last_name, t2.state, t2.salary\ + \n Inner Join: Using t1.id = t2.id\ + \n TableScan: t1\ + \n TableScan: t2"; + assert_eq!(expected, format!("{analyzed_plan}")); + + Ok(()) + } +} diff --git a/datafusion/optimizer/src/analyzer/function_rewrite.rs b/datafusion/optimizer/src/analyzer/function_rewrite.rs index 098c934bf7e1..c6bf14ebce2e 100644 --- a/datafusion/optimizer/src/analyzer/function_rewrite.rs +++ b/datafusion/optimizer/src/analyzer/function_rewrite.rs @@ -29,7 +29,7 @@ use datafusion_expr::LogicalPlan; use std::sync::Arc; /// Analyzer rule that invokes [`FunctionRewrite`]s on expressions -#[derive(Default)] +#[derive(Default, Debug)] pub struct ApplyFunctionRewrites { /// Expr --> Function writes to apply function_rewrites: Vec>, @@ -48,7 +48,7 @@ impl ApplyFunctionRewrites { ) -> Result> { // get schema representing all available input fields. This is used for data type // resolution only, so order does not matter here - let mut schema = merge_schema(plan.inputs()); + let mut schema = merge_schema(&plan.inputs()); if let LogicalPlan::TableScan(ts) = &plan { let source_schema = DFSchema::try_from_qualified_schema( @@ -61,7 +61,7 @@ impl ApplyFunctionRewrites { let name_preserver = NamePreserver::new(&plan); plan.map_expressions(|expr| { - let original_name = name_preserver.save(&expr)?; + let original_name = name_preserver.save(&expr); // recursively transform the expression, applying the rewrites at each step let transformed_expr = expr.transform_up(|expr| { @@ -74,7 +74,7 @@ impl ApplyFunctionRewrites { Ok(result) })?; - transformed_expr.map_data(|expr| original_name.restore(expr)) + Ok(transformed_expr.update_data(|expr| original_name.restore(expr))) }) } } diff --git a/datafusion/optimizer/src/analyzer/inline_table_scan.rs b/datafusion/optimizer/src/analyzer/inline_table_scan.rs index db1ce18e86f5..342d85a915b4 100644 --- a/datafusion/optimizer/src/analyzer/inline_table_scan.rs +++ b/datafusion/optimizer/src/analyzer/inline_table_scan.rs @@ -23,11 +23,12 @@ use crate::analyzer::AnalyzerRule; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{Column, Result}; -use datafusion_expr::{logical_plan::LogicalPlan, Expr, LogicalPlanBuilder, TableScan}; +use datafusion_expr::expr::WildcardOptions; +use datafusion_expr::{logical_plan::LogicalPlan, Expr, LogicalPlanBuilder}; /// Analyzed rule that inlines TableScan that provide a [`LogicalPlan`] /// (DataFrame / ViewTable) -#[derive(Default)] +#[derive(Default, Debug)] pub struct InlineTableScan; impl InlineTableScan { @@ -55,24 +56,23 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { match plan { // Match only on scans without filter / projection / fetch // Views and DataFrames won't have those added - // during the early stage of planning - LogicalPlan::TableScan(TableScan { - table_name, - source, - projection, - filters, - .. - }) if filters.is_empty() && source.get_logical_plan().is_some() => { - let sub_plan = source.get_logical_plan().unwrap(); - let projection_exprs = generate_projection_expr(&projection, sub_plan)?; - LogicalPlanBuilder::from(sub_plan.clone()) - .project(projection_exprs)? - // Ensures that the reference to the inlined table remains the - // same, meaning we don't have to change any of the parent nodes - // that reference this table. - .alias(table_name)? - .build() - .map(Transformed::yes) + // during the early stage of planning. + LogicalPlan::TableScan(table_scan) if table_scan.filters.is_empty() => { + if let Some(sub_plan) = table_scan.source.get_logical_plan() { + let sub_plan = sub_plan.into_owned(); + let projection_exprs = + generate_projection_expr(&table_scan.projection, &sub_plan)?; + LogicalPlanBuilder::from(sub_plan) + .project(projection_exprs)? + // Ensures that the reference to the inlined table remains the + // same, meaning we don't have to change any of the parent nodes + // that reference this table. + .alias(table_scan.table_name)? + .build() + .map(Transformed::yes) + } else { + Ok(Transformed::no(LogicalPlan::TableScan(table_scan))) + } } _ => Ok(Transformed::no(plan)), } @@ -93,20 +93,23 @@ fn generate_projection_expr( ))); } } else { - exprs.push(Expr::Wildcard { qualifier: None }); + exprs.push(Expr::Wildcard { + qualifier: None, + options: WildcardOptions::default(), + }); } Ok(exprs) } #[cfg(test)] mod tests { - use std::{sync::Arc, vec}; + use std::{borrow::Cow, sync::Arc, vec}; use crate::analyzer::inline_table_scan::InlineTableScan; use crate::test::assert_analyzed_plan_eq; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder, TableSource}; + use datafusion_expr::{col, lit, Expr, LogicalPlan, LogicalPlanBuilder, TableSource}; pub struct RawTableSource {} @@ -122,12 +125,14 @@ mod tests { ])) } - fn supports_filter_pushdown( + fn supports_filters_pushdown( &self, - _filter: &datafusion_expr::Expr, - ) -> datafusion_common::Result + filters: &[&Expr], + ) -> datafusion_common::Result> { - Ok(datafusion_expr::TableProviderFilterPushDown::Inexact) + Ok((0..filters.len()) + .map(|_| datafusion_expr::TableProviderFilterPushDown::Inexact) + .collect()) } } @@ -151,20 +156,22 @@ mod tests { self } - fn supports_filter_pushdown( + fn supports_filters_pushdown( &self, - _filter: &datafusion_expr::Expr, - ) -> datafusion_common::Result + filters: &[&Expr], + ) -> datafusion_common::Result> { - Ok(datafusion_expr::TableProviderFilterPushDown::Exact) + Ok((0..filters.len()) + .map(|_| datafusion_expr::TableProviderFilterPushDown::Exact) + .collect()) } fn schema(&self) -> arrow::datatypes::SchemaRef { Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])) } - fn get_logical_plan(&self) -> Option<&LogicalPlan> { - Some(&self.plan) + fn get_logical_plan(&self) -> Option> { + Some(Cow::Borrowed(&self.plan)) } } @@ -178,10 +185,10 @@ mod tests { let plan = scan.filter(col("x.a").eq(lit(1)))?.build()?; let expected = "Filter: x.a = Int32(1)\ \n SubqueryAlias: x\ - \n Projection: y.a, y.b\ + \n Projection: *\ \n TableScan: y"; - assert_analyzed_plan_eq(Arc::new(InlineTableScan::new()), &plan, expected) + assert_analyzed_plan_eq(Arc::new(InlineTableScan::new()), plan, expected) } #[test] @@ -197,6 +204,6 @@ mod tests { \n Projection: y.a\ \n TableScan: y"; - assert_analyzed_plan_eq(Arc::new(InlineTableScan::new()), &plan, expected) + assert_analyzed_plan_eq(Arc::new(InlineTableScan::new()), plan, expected) } } diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index fb0eb14da659..a9fd4900b2f4 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -16,6 +16,8 @@ // under the License. //! [`Analyzer`] and [`AnalyzerRule`] + +use std::fmt::Debug; use std::sync::Arc; use log::debug; @@ -30,7 +32,9 @@ use datafusion_expr::expr_rewriter::FunctionRewrite; use datafusion_expr::{Expr, LogicalPlan}; use crate::analyzer::count_wildcard_rule::CountWildcardRule; +use crate::analyzer::expand_wildcard_rule::ExpandWildcardRule; use crate::analyzer::inline_table_scan::InlineTableScan; +use crate::analyzer::resolve_grouping_function::ResolveGroupingFunction; use crate::analyzer::subquery::check_subquery_expr; use crate::analyzer::type_coercion::TypeCoercion; use crate::utils::log_plan; @@ -38,8 +42,10 @@ use crate::utils::log_plan; use self::function_rewrite::ApplyFunctionRewrites; pub mod count_wildcard_rule; +pub mod expand_wildcard_rule; pub mod function_rewrite; pub mod inline_table_scan; +pub mod resolve_grouping_function; pub mod subquery; pub mod type_coercion; @@ -57,8 +63,8 @@ pub mod type_coercion; /// Use [`SessionState::add_analyzer_rule`] to register additional /// `AnalyzerRule`s. /// -/// [`SessionState::add_analyzer_rule`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionState.html#method.add_analyzer_rule -pub trait AnalyzerRule { +/// [`SessionState::add_analyzer_rule`]: https://docs.rs/datafusion/latest/datafusion/execution/session_state/struct.SessionState.html#method.add_analyzer_rule +pub trait AnalyzerRule: Debug { /// Rewrite `plan` fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result; @@ -70,7 +76,7 @@ pub trait AnalyzerRule { /// /// An `Analyzer` transforms a `LogicalPlan` /// prior to the rest of the DataFusion optimization process. -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Analyzer { /// Expr --> Function writes to apply prior to analysis passes pub function_rewrites: Vec>, @@ -89,6 +95,10 @@ impl Analyzer { pub fn new() -> Self { let rules: Vec> = vec![ Arc::new(InlineTableScan::new()), + // Every rule that will generate [Expr::Wildcard] should be placed in front of [ExpandWildcardRule]. + Arc::new(ExpandWildcardRule::new()), + // [Expr::Wildcard] should be expanded before [TypeCoercion] + Arc::new(ResolveGroupingFunction::new()), Arc::new(TypeCoercion::new()), Arc::new(CountWildcardRule::new()), ]; @@ -111,11 +121,16 @@ impl Analyzer { self.function_rewrites.push(rewrite); } + /// return the list of function rewrites in this analyzer + pub fn function_rewrites(&self) -> &[Arc] { + &self.function_rewrites + } + /// Analyze the logical plan by applying analyzer rules, and /// do necessary check and fail the invalid plans pub fn execute_and_check( &self, - plan: &LogicalPlan, + plan: LogicalPlan, config: &ConfigOptions, mut observer: F, ) -> Result @@ -123,7 +138,7 @@ impl Analyzer { F: FnMut(&LogicalPlan, &dyn AnalyzerRule), { let start_time = Instant::now(); - let mut new_plan = plan.clone(); + let mut new_plan = plan; // Create an analyzer pass that rewrites `Expr`s to function_calls, as // appropriate. @@ -131,9 +146,15 @@ impl Analyzer { // Note this is run before all other rules since it rewrites based on // the argument types (List or Scalar), and TypeCoercion may cast the // argument types from Scalar to List. - let expr_to_function: Arc = - Arc::new(ApplyFunctionRewrites::new(self.function_rewrites.clone())); - let rules = std::iter::once(&expr_to_function).chain(self.rules.iter()); + let expr_to_function: Option> = + if self.function_rewrites.is_empty() { + None + } else { + Some(Arc::new(ApplyFunctionRewrites::new( + self.function_rewrites.clone(), + ))) + }; + let rules = expr_to_function.iter().chain(self.rules.iter()); // TODO add common rule executor for Analyzer and Optimizer for rule in rules { diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs new file mode 100644 index 000000000000..16ebb8cd3972 --- /dev/null +++ b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs @@ -0,0 +1,247 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Analyzed rule to replace TableScan references +//! such as DataFrames and Views and inlines the LogicalPlan. + +use std::cmp::Ordering; +use std::collections::HashMap; +use std::sync::Arc; + +use crate::analyzer::AnalyzerRule; + +use arrow::datatypes::DataType; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::{ + internal_datafusion_err, plan_err, Column, DFSchemaRef, Result, ScalarValue, +}; +use datafusion_expr::expr::{AggregateFunction, Alias}; +use datafusion_expr::logical_plan::LogicalPlan; +use datafusion_expr::utils::grouping_set_to_exprlist; +use datafusion_expr::{ + bitwise_and, bitwise_or, bitwise_shift_left, bitwise_shift_right, cast, Aggregate, + Expr, Projection, +}; +use itertools::Itertools; + +/// Replaces grouping aggregation function with value derived from internal grouping id +#[derive(Default, Debug)] +pub struct ResolveGroupingFunction; + +impl ResolveGroupingFunction { + pub fn new() -> Self { + Self {} + } +} + +impl AnalyzerRule for ResolveGroupingFunction { + fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { + plan.transform_up(analyze_internal).data() + } + + fn name(&self) -> &str { + "resolve_grouping_function" + } +} + +/// Create a map from grouping expr to index in the internal grouping id. +/// +/// For more details on how the grouping id bitmap works the documentation for +/// [[Aggregate::INTERNAL_GROUPING_ID]] +fn group_expr_to_bitmap_index(group_expr: &[Expr]) -> Result> { + Ok(grouping_set_to_exprlist(group_expr)? + .into_iter() + .rev() + .enumerate() + .map(|(idx, v)| (v, idx)) + .collect::>()) +} + +fn replace_grouping_exprs( + input: Arc, + schema: DFSchemaRef, + group_expr: Vec, + aggr_expr: Vec, +) -> Result { + // Create HashMap from Expr to index in the grouping_id bitmap + let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]); + let group_expr_to_bitmap_index = group_expr_to_bitmap_index(&group_expr)?; + let columns = schema.columns(); + let mut new_agg_expr = Vec::new(); + let mut projection_exprs = Vec::new(); + let grouping_id_len = if is_grouping_set { 1 } else { 0 }; + let group_expr_len = columns.len() - aggr_expr.len() - grouping_id_len; + projection_exprs.extend( + columns + .iter() + .take(group_expr_len) + .map(|column| Expr::Column(column.clone())), + ); + for (expr, column) in aggr_expr + .into_iter() + .zip(columns.into_iter().skip(group_expr_len + grouping_id_len)) + { + match expr { + Expr::AggregateFunction(ref function) if is_grouping_function(&expr) => { + let grouping_expr = grouping_function_on_id( + function, + &group_expr_to_bitmap_index, + is_grouping_set, + )?; + projection_exprs.push(Expr::Alias(Alias::new( + grouping_expr, + column.relation, + column.name, + ))); + } + _ => { + projection_exprs.push(Expr::Column(column)); + new_agg_expr.push(expr); + } + } + } + // Recreate aggregate without grouping functions + let new_aggregate = + LogicalPlan::Aggregate(Aggregate::try_new(input, group_expr, new_agg_expr)?); + // Create projection with grouping functions calculations + let projection = LogicalPlan::Projection(Projection::try_new( + projection_exprs, + new_aggregate.into(), + )?); + Ok(projection) +} + +fn analyze_internal(plan: LogicalPlan) -> Result> { + // rewrite any subqueries in the plan first + let transformed_plan = + plan.map_subqueries(|plan| plan.transform_up(analyze_internal))?; + + let transformed_plan = transformed_plan.transform_data(|plan| match plan { + LogicalPlan::Aggregate(Aggregate { + input, + group_expr, + aggr_expr, + schema, + .. + }) if contains_grouping_function(&aggr_expr) => Ok(Transformed::yes( + replace_grouping_exprs(input, schema, group_expr, aggr_expr)?, + )), + _ => Ok(Transformed::no(plan)), + })?; + + Ok(transformed_plan) +} + +fn is_grouping_function(expr: &Expr) -> bool { + // TODO: Do something better than name here should grouping be a built + // in expression? + matches!(expr, Expr::AggregateFunction(AggregateFunction { ref func, .. }) if func.name() == "grouping") +} + +fn contains_grouping_function(exprs: &[Expr]) -> bool { + exprs.iter().any(is_grouping_function) +} + +/// Validate that the arguments to the grouping function are in the group by clause. +fn validate_args( + function: &AggregateFunction, + group_by_expr: &HashMap<&Expr, usize>, +) -> Result<()> { + let expr_not_in_group_by = function + .args + .iter() + .find(|expr| !group_by_expr.contains_key(expr)); + if let Some(expr) = expr_not_in_group_by { + plan_err!( + "Argument {} to grouping function is not in grouping columns {}", + expr, + group_by_expr.keys().map(|e| e.to_string()).join(", ") + ) + } else { + Ok(()) + } +} + +fn grouping_function_on_id( + function: &AggregateFunction, + group_by_expr: &HashMap<&Expr, usize>, + is_grouping_set: bool, +) -> Result { + validate_args(function, group_by_expr)?; + let args = &function.args; + + // Postgres allows grouping function for group by without grouping sets, the result is then + // always 0 + if !is_grouping_set { + return Ok(Expr::Literal(ScalarValue::from(0i32))); + } + + let group_by_expr_count = group_by_expr.len(); + let literal = |value: usize| { + if group_by_expr_count < 8 { + Expr::Literal(ScalarValue::from(value as u8)) + } else if group_by_expr_count < 16 { + Expr::Literal(ScalarValue::from(value as u16)) + } else if group_by_expr_count < 32 { + Expr::Literal(ScalarValue::from(value as u32)) + } else { + Expr::Literal(ScalarValue::from(value as u64)) + } + }; + + let grouping_id_column = Expr::Column(Column::from(Aggregate::INTERNAL_GROUPING_ID)); + // The grouping call is exactly our internal grouping id + if args.len() == group_by_expr_count + && args + .iter() + .rev() + .enumerate() + .all(|(idx, expr)| group_by_expr.get(expr) == Some(&idx)) + { + return Ok(cast(grouping_id_column, DataType::Int32)); + } + + args.iter() + .rev() + .enumerate() + .map(|(arg_idx, expr)| { + group_by_expr.get(expr).map(|group_by_idx| { + let group_by_bit = + bitwise_and(grouping_id_column.clone(), literal(1 << group_by_idx)); + match group_by_idx.cmp(&arg_idx) { + Ordering::Less => { + bitwise_shift_left(group_by_bit, literal(arg_idx - group_by_idx)) + } + Ordering::Greater => { + bitwise_shift_right(group_by_bit, literal(group_by_idx - arg_idx)) + } + Ordering::Equal => group_by_bit, + } + }) + }) + .collect::>>() + .and_then(|bit_exprs| { + bit_exprs + .into_iter() + .reduce(bitwise_or) + .map(|expr| cast(expr, DataType::Int32)) + }) + .ok_or_else(|| { + internal_datafusion_err!("Grouping sets should contains at least one element") + }) +} diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index b46516017ae9..fa04835f0967 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::ops::Deref; - use crate::analyzer::check_plan; use crate::utils::collect_subquery_cols; @@ -24,18 +22,15 @@ use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{plan_err, Result}; use datafusion_expr::expr_rewriter::strip_outer_reference; use datafusion_expr::utils::split_conjunction; -use datafusion_expr::{ - Aggregate, BinaryExpr, Cast, Expr, Filter, Join, JoinType, LogicalPlan, Operator, - Window, -}; +use datafusion_expr::{Aggregate, Expr, Filter, Join, JoinType, LogicalPlan, Window}; /// Do necessary check on subquery expressions and fail the invalid plan /// 1) Check whether the outer plan is in the allowed outer plans list to use subquery expressions, /// the allowed while list: [Projection, Filter, Window, Aggregate, Join]. /// 2) Check whether the inner plan is in the allowed inner plans list to use correlated(outer) expressions. /// 3) Check and validate unsupported cases to use the correlated(outer) expressions inside the subquery(inner) plans/inner expressions. -/// For example, we do not want to support to use correlated expressions as the Join conditions in the subquery plan when the Join -/// is a Full Out Join +/// For example, we do not want to support to use correlated expressions as the Join conditions in the subquery plan when the Join +/// is a Full Out Join pub fn check_subquery_expr( outer_plan: &LogicalPlan, inner_plan: &LogicalPlan, @@ -98,7 +93,7 @@ pub fn check_subquery_expr( ) }?; } - check_correlations_in_subquery(inner_plan, true) + check_correlations_in_subquery(inner_plan) } else { if let Expr::InSubquery(subquery) = expr { // InSubquery should only return one column @@ -118,28 +113,22 @@ pub fn check_subquery_expr( | LogicalPlan::Join(_) => Ok(()), _ => plan_err!( "In/Exist subquery can only be used in \ - Projection, Filter, Window functions, Aggregate and Join plan nodes" + Projection, Filter, Window functions, Aggregate and Join plan nodes, \ + but was used in [{}]", + outer_plan.display() ), }?; - check_correlations_in_subquery(inner_plan, false) + check_correlations_in_subquery(inner_plan) } } // Recursively check the unsupported outer references in the sub query plan. -fn check_correlations_in_subquery( - inner_plan: &LogicalPlan, - is_scalar: bool, -) -> Result<()> { - check_inner_plan(inner_plan, is_scalar, false, true) +fn check_correlations_in_subquery(inner_plan: &LogicalPlan) -> Result<()> { + check_inner_plan(inner_plan, true) } // Recursively check the unsupported outer references in the sub query plan. -fn check_inner_plan( - inner_plan: &LogicalPlan, - is_scalar: bool, - is_aggregate: bool, - can_contain_outer_ref: bool, -) -> Result<()> { +fn check_inner_plan(inner_plan: &LogicalPlan, can_contain_outer_ref: bool) -> Result<()> { if !can_contain_outer_ref && inner_plan.contains_outer_reference() { return plan_err!("Accessing outer reference columns is not allowed in the plan"); } @@ -147,32 +136,18 @@ fn check_inner_plan( match inner_plan { LogicalPlan::Aggregate(_) => { inner_plan.apply_children(|plan| { - check_inner_plan(plan, is_scalar, true, can_contain_outer_ref)?; + check_inner_plan(plan, can_contain_outer_ref)?; Ok(TreeNodeRecursion::Continue) })?; Ok(()) } - LogicalPlan::Filter(Filter { - predicate, input, .. - }) => { - let (correlated, _): (Vec<_>, Vec<_>) = split_conjunction(predicate) - .into_iter() - .partition(|e| e.contains_outer()); - let maybe_unsupport = correlated - .into_iter() - .filter(|expr| !can_pullup_over_aggregation(expr)) - .collect::>(); - if is_aggregate && is_scalar && !maybe_unsupport.is_empty() { - return plan_err!( - "Correlated column is not allowed in predicate: {predicate}" - ); - } - check_inner_plan(input, is_scalar, is_aggregate, can_contain_outer_ref) + LogicalPlan::Filter(Filter { input, .. }) => { + check_inner_plan(input, can_contain_outer_ref) } LogicalPlan::Window(window) => { check_mixed_out_refer_in_window(window)?; inner_plan.apply_children(|plan| { - check_inner_plan(plan, is_scalar, is_aggregate, can_contain_outer_ref)?; + check_inner_plan(plan, can_contain_outer_ref)?; Ok(TreeNodeRecursion::Continue) })?; Ok(()) @@ -180,7 +155,6 @@ fn check_inner_plan( LogicalPlan::Projection(_) | LogicalPlan::Distinct(_) | LogicalPlan::Sort(_) - | LogicalPlan::CrossJoin(_) | LogicalPlan::Union(_) | LogicalPlan::TableScan(_) | LogicalPlan::EmptyRelation(_) @@ -189,7 +163,7 @@ fn check_inner_plan( | LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_) => { inner_plan.apply_children(|plan| { - check_inner_plan(plan, is_scalar, is_aggregate, can_contain_outer_ref)?; + check_inner_plan(plan, can_contain_outer_ref)?; Ok(TreeNodeRecursion::Continue) })?; Ok(()) @@ -202,27 +176,25 @@ fn check_inner_plan( }) => match join_type { JoinType::Inner => { inner_plan.apply_children(|plan| { - check_inner_plan( - plan, - is_scalar, - is_aggregate, - can_contain_outer_ref, - )?; + check_inner_plan(plan, can_contain_outer_ref)?; Ok(TreeNodeRecursion::Continue) })?; Ok(()) } - JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => { - check_inner_plan(left, is_scalar, is_aggregate, can_contain_outer_ref)?; - check_inner_plan(right, is_scalar, is_aggregate, false) + JoinType::Left + | JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::LeftMark => { + check_inner_plan(left, can_contain_outer_ref)?; + check_inner_plan(right, false) } JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { - check_inner_plan(left, is_scalar, is_aggregate, false)?; - check_inner_plan(right, is_scalar, is_aggregate, can_contain_outer_ref) + check_inner_plan(left, false)?; + check_inner_plan(right, can_contain_outer_ref) } JoinType::Full => { inner_plan.apply_children(|plan| { - check_inner_plan(plan, is_scalar, is_aggregate, false)?; + check_inner_plan(plan, false)?; Ok(TreeNodeRecursion::Continue) })?; Ok(()) @@ -245,11 +217,11 @@ fn check_aggregation_in_scalar_subquery( if !agg.group_expr.is_empty() { let correlated_exprs = get_correlated_expressions(inner_plan)?; let inner_subquery_cols = - collect_subquery_cols(&correlated_exprs, agg.input.schema().clone())?; + collect_subquery_cols(&correlated_exprs, agg.input.schema())?; let mut group_columns = agg .group_expr .iter() - .map(|group| Ok(group.to_columns()?.into_iter().collect::>())) + .map(|group| Ok(group.column_refs().into_iter().cloned().collect::>())) .collect::>>()? .into_iter() .flatten(); @@ -291,41 +263,12 @@ fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result> { Ok(exprs) } -/// Check whether the expression can pull up over the aggregation without change the result of the query -fn can_pullup_over_aggregation(expr: &Expr) -> bool { - if let Expr::BinaryExpr(BinaryExpr { - left, - op: Operator::Eq, - right, - }) = expr - { - match (left.deref(), right.deref()) { - (Expr::Column(_), right) if right.to_columns().unwrap().is_empty() => true, - (left, Expr::Column(_)) if left.to_columns().unwrap().is_empty() => true, - (Expr::Cast(Cast { expr, .. }), right) - if matches!(expr.deref(), Expr::Column(_)) - && right.to_columns().unwrap().is_empty() => - { - true - } - (left, Expr::Cast(Cast { expr, .. })) - if matches!(expr.deref(), Expr::Column(_)) - && left.to_columns().unwrap().is_empty() => - { - true - } - (_, _) => false, - } - } else { - false - } -} - /// Check whether the window expressions contain a mixture of out reference columns and inner columns fn check_mixed_out_refer_in_window(window: &Window) -> Result<()> { - let mixed = window.window_expr.iter().any(|win_expr| { - win_expr.contains_outer() && !win_expr.to_columns().unwrap().is_empty() - }); + let mixed = window + .window_expr + .iter() + .any(|win_expr| win_expr.contains_outer() && win_expr.any_column_refs()); if mixed { plan_err!( "Window expressions should not contain a mixed of outer references and inner columns" @@ -337,6 +280,7 @@ fn check_mixed_out_refer_in_window(window: &Window) -> Result<()> { #[cfg(test)] mod test { + use std::cmp::Ordering; use std::sync::Arc; use datafusion_common::{DFSchema, DFSchemaRef}; @@ -349,6 +293,12 @@ mod test { empty_schema: DFSchemaRef, } + impl PartialOrd for MockUserDefinedLogicalPlan { + fn partial_cmp(&self, _other: &Self) -> Option { + None + } + } + impl UserDefinedLogicalNodeCore for MockUserDefinedLogicalPlan { fn name(&self) -> &str { "MockUserDefinedLogicalPlan" @@ -358,7 +308,7 @@ mod test { vec![] } - fn schema(&self) -> &datafusion_common::DFSchemaRef { + fn schema(&self) -> &DFSchemaRef { &self.empty_schema } @@ -370,10 +320,18 @@ mod test { write!(f, "MockUserDefinedLogicalPlan") } - fn from_template(&self, _exprs: &[Expr], _inputs: &[LogicalPlan]) -> Self { - Self { - empty_schema: self.empty_schema.clone(), - } + fn with_exprs_and_inputs( + &self, + _exprs: Vec, + _inputs: Vec, + ) -> Result { + Ok(Self { + empty_schema: Arc::clone(&self.empty_schema), + }) + } + + fn supports_limit_pushdown(&self) -> bool { + false // Disallow limit push-down by default } } @@ -385,6 +343,6 @@ mod test { }), }); - check_inner_plan(&plan, false, false, true).unwrap(); + check_inner_plan(&plan, true).unwrap(); } } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index f96a359f9d47..9793c4c5490f 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -19,25 +19,31 @@ use std::sync::Arc; -use arrow::datatypes::{DataType, IntervalUnit}; +use itertools::izip; +use arrow::datatypes::{DataType, Field, IntervalUnit, Schema}; + +use crate::analyzer::AnalyzerRule; +use crate::utils::NamePreserver; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::{ - exec_err, internal_err, plan_datafusion_err, plan_err, DFSchema, DFSchemaRef, - DataFusionError, Result, ScalarValue, + exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, Column, + DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference, }; use datafusion_expr::expr::{ - self, AggregateFunctionDefinition, Between, BinaryExpr, Case, Exists, InList, - InSubquery, Like, ScalarFunction, WindowFunction, + self, Alias, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, + ScalarFunction, Sort, WindowFunction, }; -use datafusion_expr::expr_rewriter::rewrite_preserving_name; +use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; use datafusion_expr::expr_schema::cast_subquery; use datafusion_expr::logical_plan::Subquery; use datafusion_expr::type_coercion::binary::{ comparison_coercion, get_input_types, like_coercion, }; -use datafusion_expr::type_coercion::functions::data_types; +use datafusion_expr::type_coercion::functions::{ + data_types_with_aggregate_udf, data_types_with_scalar_udf, +}; use datafusion_expr::type_coercion::other::{ get_coerce_type_for_case_expression, get_coerce_type_for_list, }; @@ -45,14 +51,14 @@ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8}; use datafusion_expr::utils::merge_schema; use datafusion_expr::{ is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not, - type_coercion, AggregateFunction, Expr, ExprSchemable, LogicalPlan, Operator, - ScalarFunctionDefinition, ScalarUDF, Signature, WindowFrame, WindowFrameBound, + AggregateUDF, Expr, ExprFunctionExt, ExprSchemable, Join, Limit, LogicalPlan, + Operator, Projection, ScalarUDF, Union, WindowFrame, WindowFrameBound, WindowFrameUnits, }; -use crate::analyzer::AnalyzerRule; - -#[derive(Default)] +/// Performs type coercion by determining the schema +/// and performing the expression rewrites. +#[derive(Default, Debug)] pub struct TypeCoercion {} impl TypeCoercion { @@ -61,32 +67,54 @@ impl TypeCoercion { } } +/// Coerce output schema based upon optimizer config. +fn coerce_output(plan: LogicalPlan, config: &ConfigOptions) -> Result { + if !config.optimizer.expand_views_at_output { + return Ok(plan); + } + + let outer_refs = plan.expressions(); + if outer_refs.is_empty() { + return Ok(plan); + } + + if let Some(dfschema) = transform_schema_to_nonview(plan.schema()) { + coerce_plan_expr_for_schema(plan, &dfschema?) + } else { + Ok(plan) + } +} + impl AnalyzerRule for TypeCoercion { fn name(&self) -> &str { "type_coercion" } - fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - analyze_internal(&DFSchema::empty(), &plan) + fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result { + let empty_schema = DFSchema::empty(); + + // recurse + let transformed_plan = plan + .transform_up_with_subqueries(|plan| analyze_internal(&empty_schema, plan))? + .data; + + // finish + coerce_output(transformed_plan, config) } } +/// use the external schema to handle the correlated subqueries case +/// +/// Assumes that children have already been optimized fn analyze_internal( - // use the external schema to handle the correlated subqueries case external_schema: &DFSchema, - plan: &LogicalPlan, -) -> Result { - // optimize child plans first - let new_inputs = plan - .inputs() - .iter() - .map(|p| analyze_internal(external_schema, p)) - .collect::>>()?; + plan: LogicalPlan, +) -> Result> { // get schema representing all available input fields. This is used for data type // resolution only, so order does not matter here - let mut schema = merge_schema(new_inputs.iter().collect()); + let mut schema = merge_schema(&plan.inputs()); - if let LogicalPlan::TableScan(ts) = plan { + if let LogicalPlan::TableScan(ts) = &plan { let source_schema = DFSchema::try_from_qualified_schema( ts.table_name.clone(), &ts.source.schema(), @@ -99,47 +127,194 @@ fn analyze_internal( // select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where t2.c2=t1.c3) schema.merge(external_schema); - let mut expr_rewrite = TypeCoercionRewriter { - schema: Arc::new(schema), + // Coerce filter predicates to boolean (handles `WHERE NULL`) + let plan = if let LogicalPlan::Filter(mut filter) = plan { + filter.predicate = filter.predicate.cast_to(&DataType::Boolean, &schema)?; + LogicalPlan::Filter(filter) + } else { + plan }; - let new_expr = plan - .expressions() - .into_iter() - .map(|expr| { - // ensure aggregate names don't change: - // https://github.com/apache/datafusion/issues/3555 - rewrite_preserving_name(expr, &mut expr_rewrite) - }) - .collect::>>()?; + let mut expr_rewrite = TypeCoercionRewriter::new(&schema); + + let name_preserver = NamePreserver::new(&plan); + // apply coercion rewrite all expressions in the plan individually + plan.map_expressions(|expr| { + let original_name = name_preserver.save(&expr); + expr.rewrite(&mut expr_rewrite) + .map(|transformed| transformed.update_data(|e| original_name.restore(e))) + })? + // some plans need extra coercion after their expressions are coerced + .map_data(|plan| expr_rewrite.coerce_plan(plan))? + // recompute the schema after the expressions have been rewritten as the types may have changed + .map_data(|plan| plan.recompute_schema()) +} - plan.with_new_exprs(new_expr, new_inputs) +/// Rewrite expressions to apply type coercion. +pub struct TypeCoercionRewriter<'a> { + pub(crate) schema: &'a DFSchema, } -pub(crate) struct TypeCoercionRewriter { - pub(crate) schema: DFSchemaRef, +impl<'a> TypeCoercionRewriter<'a> { + /// Create a new [`TypeCoercionRewriter`] with a provided schema + /// representing both the inputs and output of the [`LogicalPlan`] node. + fn new(schema: &'a DFSchema) -> Self { + Self { schema } + } + + /// Coerce the [`LogicalPlan`]. + /// + /// Refer to [`TypeCoercionRewriter::coerce_join`] and [`TypeCoercionRewriter::coerce_union`] + /// for type-coercion approach. + pub fn coerce_plan(&mut self, plan: LogicalPlan) -> Result { + match plan { + LogicalPlan::Join(join) => self.coerce_join(join), + LogicalPlan::Union(union) => Self::coerce_union(union), + LogicalPlan::Limit(limit) => Self::coerce_limit(limit), + _ => Ok(plan), + } + } + + /// Coerce join equality expressions and join filter + /// + /// Joins must be treated specially as their equality expressions are stored + /// as a parallel list of left and right expressions, rather than a single + /// equality expression + /// + /// For example, on_exprs like `t1.a = t2.b AND t1.x = t2.y` will be stored + /// as a list of `(t1.a, t2.b), (t1.x, t2.y)` + pub fn coerce_join(&mut self, mut join: Join) -> Result { + join.on = join + .on + .into_iter() + .map(|(lhs, rhs)| { + // coerce the arguments as though they were a single binary equality + // expression + let (lhs, rhs) = self.coerce_binary_op(lhs, Operator::Eq, rhs)?; + Ok((lhs, rhs)) + }) + .collect::>>()?; + + // Join filter must be boolean + join.filter = join + .filter + .map(|expr| self.coerce_join_filter(expr)) + .transpose()?; + + Ok(LogicalPlan::Join(join)) + } + + /// Coerce the union’s inputs to a common schema compatible with all inputs. + /// This occurs after wildcard expansion and the coercion of the input expressions. + pub fn coerce_union(union_plan: Union) -> Result { + let union_schema = Arc::new(coerce_union_schema(&union_plan.inputs)?); + let new_inputs = union_plan + .inputs + .into_iter() + .map(|p| { + let plan = + coerce_plan_expr_for_schema(Arc::unwrap_or_clone(p), &union_schema)?; + match plan { + LogicalPlan::Projection(Projection { expr, input, .. }) => { + Ok(Arc::new(project_with_column_index( + expr, + input, + Arc::clone(&union_schema), + )?)) + } + other_plan => Ok(Arc::new(other_plan)), + } + }) + .collect::>>()?; + Ok(LogicalPlan::Union(Union { + inputs: new_inputs, + schema: union_schema, + })) + } + + /// Coerce the fetch and skip expression to Int64 type. + fn coerce_limit(limit: Limit) -> Result { + fn coerce_limit_expr( + expr: Expr, + schema: &DFSchema, + expr_name: &str, + ) -> Result { + let dt = expr.get_type(schema)?; + if dt.is_integer() || dt.is_null() { + expr.cast_to(&DataType::Int64, schema) + } else { + plan_err!("Expected {expr_name} to be an integer or null, but got {dt:?}") + } + } + + let empty_schema = DFSchema::empty(); + let new_fetch = limit + .fetch + .map(|expr| coerce_limit_expr(*expr, &empty_schema, "LIMIT")) + .transpose()?; + let new_skip = limit + .skip + .map(|expr| coerce_limit_expr(*expr, &empty_schema, "OFFSET")) + .transpose()?; + Ok(LogicalPlan::Limit(Limit { + input: limit.input, + fetch: new_fetch.map(Box::new), + skip: new_skip.map(Box::new), + })) + } + + fn coerce_join_filter(&self, expr: Expr) -> Result { + let expr_type = expr.get_type(self.schema)?; + match expr_type { + DataType::Boolean => Ok(expr), + DataType::Null => expr.cast_to(&DataType::Boolean, self.schema), + other => plan_err!("Join condition must be boolean type, but got {other:?}"), + } + } + + fn coerce_binary_op( + &self, + left: Expr, + op: Operator, + right: Expr, + ) -> Result<(Expr, Expr)> { + let (left_type, right_type) = get_input_types( + &left.get_type(self.schema)?, + &op, + &right.get_type(self.schema)?, + )?; + Ok(( + left.cast_to(&left_type, self.schema)?, + right.cast_to(&right_type, self.schema)?, + )) + } } -impl TreeNodeRewriter for TypeCoercionRewriter { +impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { type Node = Expr; fn f_up(&mut self, expr: Expr) -> Result> { match expr { - Expr::Unnest(_) => internal_err!( + Expr::Unnest(_) => not_impl_err!( "Unnest should be rewritten to LogicalPlan::Unnest before type coercion" ), Expr::ScalarSubquery(Subquery { subquery, outer_ref_columns, }) => { - let new_plan = analyze_internal(&self.schema, &subquery)?; + let new_plan = + analyze_internal(self.schema, Arc::unwrap_or_clone(subquery))?.data; Ok(Transformed::yes(Expr::ScalarSubquery(Subquery { subquery: Arc::new(new_plan), outer_ref_columns, }))) } Expr::Exists(Exists { subquery, negated }) => { - let new_plan = analyze_internal(&self.schema, &subquery.subquery)?; + let new_plan = analyze_internal( + self.schema, + Arc::unwrap_or_clone(subquery.subquery), + )? + .data; Ok(Transformed::yes(Expr::Exists(Exists { subquery: Subquery { subquery: Arc::new(new_plan), @@ -153,8 +328,12 @@ impl TreeNodeRewriter for TypeCoercionRewriter { subquery, negated, }) => { - let new_plan = analyze_internal(&self.schema, &subquery.subquery)?; - let expr_type = expr.get_type(&self.schema)?; + let new_plan = analyze_internal( + self.schema, + Arc::unwrap_or_clone(subquery.subquery), + )? + .data; + let expr_type = expr.get_type(self.schema)?; let subquery_type = new_plan.schema().field(0).data_type(); let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(plan_datafusion_err!( "expr type {expr_type:?} can't cast to {subquery_type:?} in InSubquery" @@ -165,32 +344,32 @@ impl TreeNodeRewriter for TypeCoercionRewriter { outer_ref_columns: subquery.outer_ref_columns, }; Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( - Box::new(expr.cast_to(&common_type, &self.schema)?), + Box::new(expr.cast_to(&common_type, self.schema)?), cast_subquery(new_subquery, &common_type)?, negated, )))) } Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op( *expr, - &self.schema, + self.schema, )?))), Expr::IsTrue(expr) => Ok(Transformed::yes(is_true( - get_casted_expr_for_bool_op(*expr, &self.schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true( - get_casted_expr_for_bool_op(*expr, &self.schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::IsFalse(expr) => Ok(Transformed::yes(is_false( - get_casted_expr_for_bool_op(*expr, &self.schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false( - get_casted_expr_for_bool_op(*expr, &self.schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown( - get_casted_expr_for_bool_op(*expr, &self.schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown( - get_casted_expr_for_bool_op(*expr, &self.schema)?, + get_casted_expr_for_bool_op(*expr, self.schema)?, ))), Expr::Like(Like { negated, @@ -199,8 +378,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter { escape_char, case_insensitive, }) => { - let left_type = expr.get_type(&self.schema)?; - let right_type = pattern.get_type(&self.schema)?; + let left_type = expr.get_type(self.schema)?; + let right_type = pattern.get_type(self.schema)?; let coerced_type = like_coercion(&left_type, &right_type).ok_or_else(|| { let op_name = if case_insensitive { "ILIKE" @@ -211,8 +390,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter { "There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression" ) })?; - let expr = Box::new(expr.cast_to(&coerced_type, &self.schema)?); - let pattern = Box::new(pattern.cast_to(&coerced_type, &self.schema)?); + let expr = match left_type { + DataType::Dictionary(_, inner) if *inner == DataType::Utf8 => expr, + _ => Box::new(expr.cast_to(&coerced_type, self.schema)?), + }; + let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?); Ok(Transformed::yes(Expr::Like(Like::new( negated, expr, @@ -222,15 +404,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter { )))) } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let (left_type, right_type) = get_input_types( - &left.get_type(&self.schema)?, - &op, - &right.get_type(&self.schema)?, - )?; + let (left, right) = self.coerce_binary_op(*left, op, *right)?; Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new( - Box::new(left.cast_to(&left_type, &self.schema)?), + Box::new(left), op, - Box::new(right.cast_to(&right_type, &self.schema)?), + Box::new(right), )))) } Expr::Between(Between { @@ -239,15 +417,15 @@ impl TreeNodeRewriter for TypeCoercionRewriter { low, high, }) => { - let expr_type = expr.get_type(&self.schema)?; - let low_type = low.get_type(&self.schema)?; + let expr_type = expr.get_type(self.schema)?; + let low_type = low.get_type(self.schema)?; let low_coerced_type = comparison_coercion(&expr_type, &low_type) .ok_or_else(|| { DataFusionError::Internal(format!( "Failed to coerce types {expr_type} and {low_type} in BETWEEN expression" )) })?; - let high_type = high.get_type(&self.schema)?; + let high_type = high.get_type(self.schema)?; let high_coerced_type = comparison_coercion(&expr_type, &low_type) .ok_or_else(|| { DataFusionError::Internal(format!( @@ -262,10 +440,10 @@ impl TreeNodeRewriter for TypeCoercionRewriter { )) })?; Ok(Transformed::yes(Expr::Between(Between::new( - Box::new(expr.cast_to(&coercion_type, &self.schema)?), + Box::new(expr.cast_to(&coercion_type, self.schema)?), negated, - Box::new(low.cast_to(&coercion_type, &self.schema)?), - Box::new(high.cast_to(&coercion_type, &self.schema)?), + Box::new(low.cast_to(&coercion_type, self.schema)?), + Box::new(high.cast_to(&coercion_type, self.schema)?), )))) } Expr::InList(InList { @@ -273,10 +451,10 @@ impl TreeNodeRewriter for TypeCoercionRewriter { list, negated, }) => { - let expr_data_type = expr.get_type(&self.schema)?; + let expr_data_type = expr.get_type(self.schema)?; let list_data_types = list .iter() - .map(|list_expr| list_expr.get_type(&self.schema)) + .map(|list_expr| list_expr.get_type(self.schema)) .collect::>>()?; let result_type = get_coerce_type_for_list(&expr_data_type, &list_data_types); @@ -286,11 +464,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter { ), Some(coerced_type) => { // find the coerced type - let cast_expr = expr.cast_to(&coerced_type, &self.schema)?; + let cast_expr = expr.cast_to(&coerced_type, self.schema)?; let cast_list_expr = list .into_iter() .map(|list_expr| { - list_expr.cast_to(&coerced_type, &self.schema) + list_expr.cast_to(&coerced_type, self.schema) }) .collect::>>()?; Ok(Transformed::yes(Expr::InList(InList ::new( @@ -302,73 +480,43 @@ impl TreeNodeRewriter for TypeCoercionRewriter { } } Expr::Case(case) => { - let case = coerce_case_expression(case, &self.schema)?; + let case = coerce_case_expression(case, self.schema)?; Ok(Transformed::yes(Expr::Case(case))) } - Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { - ScalarFunctionDefinition::UDF(fun) => { - let new_expr = coerce_arguments_for_signature( - args, - &self.schema, - fun.signature(), - )?; - let new_expr = - coerce_arguments_for_fun(new_expr, &self.schema, &fun)?; - Ok(Transformed::yes(Expr::ScalarFunction( - ScalarFunction::new_udf(fun, new_expr), - ))) - } - ScalarFunctionDefinition::Name(_) => { - internal_err!("Function `Expr` with name should be resolved.") - } - }, + Expr::ScalarFunction(ScalarFunction { func, args }) => { + let new_expr = coerce_arguments_for_signature_with_scalar_udf( + args, + self.schema, + &func, + )?; + Ok(Transformed::yes(Expr::ScalarFunction( + ScalarFunction::new_udf(func, new_expr), + ))) + } Expr::AggregateFunction(expr::AggregateFunction { - func_def, + func, args, distinct, filter, order_by, null_treatment, - }) => match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => { - let new_expr = coerce_agg_exprs_for_signature( - &fun, - args, - &self.schema, - &fun.signature(), - )?; - Ok(Transformed::yes(Expr::AggregateFunction( - expr::AggregateFunction::new( - fun, - new_expr, - distinct, - filter, - order_by, - null_treatment, - ), - ))) - } - AggregateFunctionDefinition::UDF(fun) => { - let new_expr = coerce_arguments_for_signature( - args, - &self.schema, - fun.signature(), - )?; - Ok(Transformed::yes(Expr::AggregateFunction( - expr::AggregateFunction::new_udf( - fun, - new_expr, - false, - filter, - order_by, - null_treatment, - ), - ))) - } - AggregateFunctionDefinition::Name(_) => { - internal_err!("Function `Expr` with name should be resolved.") - } - }, + }) => { + let new_expr = coerce_arguments_for_signature_with_aggregate_udf( + args, + self.schema, + &func, + )?; + Ok(Transformed::yes(Expr::AggregateFunction( + expr::AggregateFunction::new_udf( + func, + new_expr, + distinct, + filter, + order_by, + null_treatment, + ), + ))) + } Expr::WindowFunction(WindowFunction { fun, args, @@ -378,28 +526,27 @@ impl TreeNodeRewriter for TypeCoercionRewriter { null_treatment, }) => { let window_frame = - coerce_window_frame(window_frame, &self.schema, &order_by)?; + coerce_window_frame(window_frame, self.schema, &order_by)?; let args = match &fun { - expr::WindowFunctionDefinition::AggregateFunction(fun) => { - coerce_agg_exprs_for_signature( - fun, + expr::WindowFunctionDefinition::AggregateUDF(udf) => { + coerce_arguments_for_signature_with_aggregate_udf( args, - &self.schema, - &fun.signature(), + self.schema, + udf, )? } _ => args, }; - Ok(Transformed::yes(Expr::WindowFunction(WindowFunction::new( - fun, - args, - partition_by, - order_by, - window_frame, - null_treatment, - )))) + Ok(Transformed::yes( + Expr::WindowFunction(WindowFunction::new(fun, args)) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build()?, + )) } Expr::Alias(_) | Expr::Column(_) @@ -409,10 +556,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter { | Expr::IsNotNull(_) | Expr::IsNull(_) | Expr::Negative(_) - | Expr::GetIndexedField(_) | Expr::Cast(_) | Expr::TryCast(_) - | Expr::Sort(_) | Expr::Wildcard { .. } | Expr::GroupingSet(_) | Expr::Placeholder(_) @@ -421,6 +566,55 @@ impl TreeNodeRewriter for TypeCoercionRewriter { } } +/// Transform a schema to use non-view types for Utf8View and BinaryView +fn transform_schema_to_nonview(dfschema: &DFSchemaRef) -> Option> { + let metadata = dfschema.as_arrow().metadata.clone(); + let mut transformed = false; + + let (qualifiers, transformed_fields): (Vec>, Vec>) = + dfschema + .iter() + .map(|(qualifier, field)| match field.data_type() { + DataType::Utf8View => { + transformed = true; + ( + qualifier.cloned() as Option, + Arc::new(Field::new( + field.name(), + DataType::LargeUtf8, + field.is_nullable(), + )), + ) + } + DataType::BinaryView => { + transformed = true; + ( + qualifier.cloned() as Option, + Arc::new(Field::new( + field.name(), + DataType::LargeBinary, + field.is_nullable(), + )), + ) + } + _ => ( + qualifier.cloned() as Option, + Arc::clone(field), + ), + }) + .unzip(); + + if !transformed { + return None; + } + + let schema = Schema::new_with_metadata(transformed_fields, metadata); + Some(DFSchema::from_field_specific_qualified_schema( + qualifiers, + &Arc::new(schema), + )) +} + /// Casts the given `value` to `target_type`. Note that this function /// only considers `Null` or `Utf8` values. fn coerce_scalar(target_type: &DataType, value: &ScalarValue) -> Result { @@ -451,12 +645,12 @@ fn coerce_scalar(target_type: &DataType, value: &ScalarValue) -> Result Result { - coerce_scalar(target_type, &value).or_else(|err| { + coerce_scalar(target_type, value).or_else(|err| { // If type coercion fails, check if the largest type in family works: if let Some(largest_type) = get_widest_type_in_family(target_type) { - coerce_scalar(largest_type, &value).map_or_else( + coerce_scalar(largest_type, value).map_or_else( |_| exec_err!("Cannot cast {value:?} to {target_type:?}"), |_| ScalarValue::try_from(target_type), ) @@ -485,56 +679,61 @@ fn coerce_frame_bound( ) -> Result { match bound { WindowFrameBound::Preceding(v) => { - coerce_scalar_range_aware(target_type, v).map(WindowFrameBound::Preceding) + coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Preceding) } WindowFrameBound::CurrentRow => Ok(WindowFrameBound::CurrentRow), WindowFrameBound::Following(v) => { - coerce_scalar_range_aware(target_type, v).map(WindowFrameBound::Following) + coerce_scalar_range_aware(target_type, &v).map(WindowFrameBound::Following) } } } +fn extract_window_frame_target_type(col_type: &DataType) -> Result { + if col_type.is_numeric() + || is_utf8_or_large_utf8(col_type) + || matches!(col_type, DataType::Null) + { + Ok(col_type.clone()) + } else if is_datetime(col_type) { + Ok(DataType::Interval(IntervalUnit::MonthDayNano)) + } else if let DataType::Dictionary(_, value_type) = col_type { + extract_window_frame_target_type(value_type) + } else { + return internal_err!("Cannot run range queries on datatype: {col_type:?}"); + } +} + // Coerces the given `window_frame` to use appropriate natural types. // For example, ROWS and GROUPS frames use `UInt64` during calculations. fn coerce_window_frame( window_frame: WindowFrame, - schema: &DFSchemaRef, - expressions: &[Expr], + schema: &DFSchema, + expressions: &[Sort], ) -> Result { let mut window_frame = window_frame; - let current_types = expressions - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; let target_type = match window_frame.units { WindowFrameUnits::Range => { - if let Some(col_type) = current_types.first() { - if col_type.is_numeric() - || is_utf8_or_large_utf8(col_type) - || matches!(col_type, DataType::Null) - { - col_type - } else if is_datetime(col_type) { - &DataType::Interval(IntervalUnit::MonthDayNano) - } else { - return internal_err!( - "Cannot run range queries on datatype: {col_type:?}" - ); - } + let current_types = expressions + .first() + .map(|s| s.expr.get_type(schema)) + .transpose()?; + if let Some(col_type) = current_types { + extract_window_frame_target_type(&col_type)? } else { return internal_err!("ORDER BY column cannot be empty"); } } - WindowFrameUnits::Rows | WindowFrameUnits::Groups => &DataType::UInt64, + WindowFrameUnits::Rows | WindowFrameUnits::Groups => DataType::UInt64, }; - window_frame.start_bound = coerce_frame_bound(target_type, window_frame.start_bound)?; - window_frame.end_bound = coerce_frame_bound(target_type, window_frame.end_bound)?; + window_frame.start_bound = + coerce_frame_bound(&target_type, window_frame.start_bound)?; + window_frame.end_bound = coerce_frame_bound(&target_type, window_frame.end_bound)?; Ok(window_frame) } // Support the `IsTrue` `IsNotTrue` `IsFalse` `IsNotFalse` type coercion. // The above op will be rewrite to the binary op when creating the physical op. -fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchemaRef) -> Result { +fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result { let left_type = expr.get_type(schema)?; get_input_types(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)?; expr.cast_to(&DataType::Boolean, schema) @@ -544,10 +743,10 @@ fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchemaRef) -> Result /// `signature`, if possible. /// /// See the module level documentation for more detail on coercion. -fn coerce_arguments_for_signature( +fn coerce_arguments_for_signature_with_scalar_udf( expressions: Vec, schema: &DFSchema, - signature: &Signature, + func: &ScalarUDF, ) -> Result> { if expressions.is_empty() { return Ok(expressions); @@ -558,7 +757,7 @@ fn coerce_arguments_for_signature( .map(|e| e.get_type(schema)) .collect::>>()?; - let new_types = data_types(¤t_types, signature)?; + let new_types = data_types_with_scalar_udf(¤t_types, func)?; expressions .into_iter() @@ -567,58 +766,34 @@ fn coerce_arguments_for_signature( .collect() } -fn coerce_arguments_for_fun( +/// Returns `expressions` coerced to types compatible with +/// `signature`, if possible. +/// +/// See the module level documentation for more detail on coercion. +fn coerce_arguments_for_signature_with_aggregate_udf( expressions: Vec, schema: &DFSchema, - fun: &Arc, + func: &AggregateUDF, ) -> Result> { - // Cast Fixedsizelist to List for array functions - if fun.name() == "make_array" { - expressions - .into_iter() - .map(|expr| { - let data_type = expr.get_type(schema).unwrap(); - if let DataType::FixedSizeList(field, _) = data_type { - let to_type = DataType::List(field.clone()); - expr.cast_to(&to_type, schema) - } else { - Ok(expr) - } - }) - .collect() - } else { - Ok(expressions) + if expressions.is_empty() { + return Ok(expressions); } -} -/// Returns the coerced exprs for each `input_exprs`. -/// Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the -/// data type of `input_exprs` need to be coerced. -fn coerce_agg_exprs_for_signature( - agg_fun: &AggregateFunction, - input_exprs: Vec, - schema: &DFSchema, - signature: &Signature, -) -> Result> { - if input_exprs.is_empty() { - return Ok(input_exprs); - } - let current_types = input_exprs + let current_types = expressions .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - let coerced_types = - type_coercion::aggregates::coerce_types(agg_fun, ¤t_types, signature)?; + let new_types = data_types_with_aggregate_udf(¤t_types, func)?; - input_exprs + expressions .into_iter() .enumerate() - .map(|(i, expr)| expr.cast_to(&coerced_types[i], schema)) + .map(|(i, expr)| expr.cast_to(&new_types[i], schema)) .collect() } -fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result { +fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result { // Given expressions like: // // CASE a1 @@ -729,6 +904,111 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result { Ok(Case::new(case_expr, when_then, else_expr)) } +/// Get a common schema that is compatible with all inputs of UNION. +/// +/// This method presumes that the wildcard expansion is unneeded, or has already +/// been applied. +pub fn coerce_union_schema(inputs: &[Arc]) -> Result { + let base_schema = inputs[0].schema(); + let mut union_datatypes = base_schema + .fields() + .iter() + .map(|f| f.data_type().clone()) + .collect::>(); + let mut union_nullabilities = base_schema + .fields() + .iter() + .map(|f| f.is_nullable()) + .collect::>(); + let mut union_field_meta = base_schema + .fields() + .iter() + .map(|f| f.metadata().clone()) + .collect::>(); + + let mut metadata = base_schema.metadata().clone(); + + for (i, plan) in inputs.iter().enumerate().skip(1) { + let plan_schema = plan.schema(); + metadata.extend(plan_schema.metadata().clone()); + + if plan_schema.fields().len() != base_schema.fields().len() { + return plan_err!( + "Union schemas have different number of fields: \ + query 1 has {} fields whereas query {} has {} fields", + base_schema.fields().len(), + i + 1, + plan_schema.fields().len() + ); + } + + // coerce data type and nullablity for each field + for (union_datatype, union_nullable, union_field_map, plan_field) in izip!( + union_datatypes.iter_mut(), + union_nullabilities.iter_mut(), + union_field_meta.iter_mut(), + plan_schema.fields().iter() + ) { + let coerced_type = + comparison_coercion(union_datatype, plan_field.data_type()).ok_or_else( + || { + plan_datafusion_err!( + "Incompatible inputs for Union: Previous inputs were \ + of type {}, but got incompatible type {} on column '{}'", + union_datatype, + plan_field.data_type(), + plan_field.name() + ) + }, + )?; + + *union_datatype = coerced_type; + *union_nullable = *union_nullable || plan_field.is_nullable(); + union_field_map.extend(plan_field.metadata().clone()); + } + } + let union_qualified_fields = izip!( + base_schema.iter(), + union_datatypes.into_iter(), + union_nullabilities, + union_field_meta.into_iter() + ) + .map(|((qualifier, field), datatype, nullable, metadata)| { + let mut field = Field::new(field.name().clone(), datatype, nullable); + field.set_metadata(metadata); + (qualifier.cloned(), field.into()) + }) + .collect::>(); + + DFSchema::new_with_metadata(union_qualified_fields, metadata) +} + +/// See `` +fn project_with_column_index( + expr: Vec, + input: Arc, + schema: DFSchemaRef, +) -> Result { + let alias_expr = expr + .into_iter() + .enumerate() + .map(|(i, e)| match e { + Expr::Alias(Alias { ref name, .. }) if name != schema.field(i).name() => { + e.unalias().alias(schema.field(i).name()) + } + Expr::Column(Column { + relation: _, + ref name, + }) if name != schema.field(i).name() => e.alias(schema.field(i).name()), + Expr::Alias { .. } | Expr::Column { .. } => e, + _ => e.alias(schema.field(i).name()), + }) + .collect::>(); + + Projection::try_new_with_schema(alias_expr, input, schema) + .map(LogicalPlan::Projection) +} + #[cfg(test)] mod test { use std::any::Any; @@ -737,22 +1017,24 @@ mod test { use arrow::datatypes::DataType::Utf8; use arrow::datatypes::{DataType, Field, TimeUnit}; + use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction}; - use datafusion_expr::logical_plan::{EmptyRelation, Projection}; + use datafusion_expr::logical_plan::{EmptyRelation, Projection, Sort}; + use datafusion_expr::test::function_stub::avg_udaf; use datafusion_expr::{ - cast, col, create_udaf, is_true, lit, AccumulatorFactoryFunction, - AggregateFunction, AggregateUDF, BinaryExpr, Case, ColumnarValue, Expr, - ExprSchemable, Filter, LogicalPlan, Operator, ScalarUDF, ScalarUDFImpl, - Signature, SimpleAggregateUDF, Subquery, Volatility, + cast, col, create_udaf, is_true, lit, AccumulatorFactoryFunction, AggregateUDF, + BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Filter, LogicalPlan, + Operator, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, Subquery, + Volatility, }; - use datafusion_physical_expr::expressions::AvgAccumulator; + use datafusion_functions_aggregate::average::AvgAccumulator; use crate::analyzer::type_coercion::{ coerce_case_expression, TypeCoercion, TypeCoercionRewriter, }; - use crate::test::assert_analyzed_plan_eq; + use crate::test::{assert_analyzed_plan_eq, assert_analyzed_plan_with_config_eq}; fn empty() -> Arc { Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { @@ -765,7 +1047,7 @@ mod test { Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, schema: Arc::new( - DFSchema::from_unqualifed_fields( + DFSchema::from_unqualified_fields( vec![Field::new("a", data_type, true)].into(), std::collections::HashMap::new(), ) @@ -780,7 +1062,156 @@ mod test { let empty = empty_with_type(DataType::Float64); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: a < CAST(UInt32(2) AS Float64)\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected) + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) + } + + fn coerce_on_output_if_viewtype(plan: LogicalPlan, expected: &str) -> Result<()> { + let mut options = ConfigOptions::default(); + options.optimizer.expand_views_at_output = true; + + assert_analyzed_plan_with_config_eq( + options, + Arc::new(TypeCoercion::new()), + plan.clone(), + expected, + ) + } + + fn do_not_coerce_on_output(plan: LogicalPlan, expected: &str) -> Result<()> { + assert_analyzed_plan_with_config_eq( + ConfigOptions::default(), + Arc::new(TypeCoercion::new()), + plan.clone(), + expected, + ) + } + + #[test] + fn coerce_utf8view_output() -> Result<()> { + // Plan A + // scenario: outermost utf8view projection + let expr = col("a"); + let empty = empty_with_type(DataType::Utf8View); + let plan = LogicalPlan::Projection(Projection::try_new( + vec![expr.clone()], + Arc::clone(&empty), + )?); + // Plan A: no coerce + let if_not_coerced = "Projection: a\n EmptyRelation"; + do_not_coerce_on_output(plan.clone(), if_not_coerced)?; + // Plan A: coerce requested: Utf8View => LargeUtf8 + let if_coerced = "Projection: CAST(a AS LargeUtf8)\n EmptyRelation"; + coerce_on_output_if_viewtype(plan.clone(), if_coerced)?; + + // Plan B + // scenario: outermost bool projection + let bool_expr = col("a").lt(lit("foo")); + let bool_plan = LogicalPlan::Projection(Projection::try_new( + vec![bool_expr], + Arc::clone(&empty), + )?); + // Plan B: no coerce + let if_not_coerced = + "Projection: a < CAST(Utf8(\"foo\") AS Utf8View)\n EmptyRelation"; + do_not_coerce_on_output(bool_plan.clone(), if_not_coerced)?; + // Plan B: coerce requested: no coercion applied + let if_coerced = if_not_coerced; + coerce_on_output_if_viewtype(bool_plan, if_coerced)?; + + // Plan C + // scenario: with a non-projection root logical plan node + let sort_expr = expr.sort(true, true); + let sort_plan = LogicalPlan::Sort(Sort { + expr: vec![sort_expr], + input: Arc::new(plan), + fetch: None, + }); + // Plan C: no coerce + let if_not_coerced = + "Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; + do_not_coerce_on_output(sort_plan.clone(), if_not_coerced)?; + // Plan C: coerce requested: Utf8View => LargeUtf8 + let if_coerced = "Projection: CAST(a AS LargeUtf8)\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; + coerce_on_output_if_viewtype(sort_plan.clone(), if_coerced)?; + + // Plan D + // scenario: two layers of projections with view types + let plan = LogicalPlan::Projection(Projection::try_new( + vec![col("a")], + Arc::new(sort_plan), + )?); + // Plan D: no coerce + let if_not_coerced = "Projection: a\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; + do_not_coerce_on_output(plan.clone(), if_not_coerced)?; + // Plan B: coerce requested: Utf8View => LargeUtf8 only on outermost + let if_coerced = "Projection: CAST(a AS LargeUtf8)\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; + coerce_on_output_if_viewtype(plan.clone(), if_coerced)?; + + Ok(()) + } + + #[test] + fn coerce_binaryview_output() -> Result<()> { + // Plan A + // scenario: outermost binaryview projection + let expr = col("a"); + let empty = empty_with_type(DataType::BinaryView); + let plan = LogicalPlan::Projection(Projection::try_new( + vec![expr.clone()], + Arc::clone(&empty), + )?); + // Plan A: no coerce + let if_not_coerced = "Projection: a\n EmptyRelation"; + do_not_coerce_on_output(plan.clone(), if_not_coerced)?; + // Plan A: coerce requested: BinaryView => LargeBinary + let if_coerced = "Projection: CAST(a AS LargeBinary)\n EmptyRelation"; + coerce_on_output_if_viewtype(plan.clone(), if_coerced)?; + + // Plan B + // scenario: outermost bool projection + let bool_expr = col("a").lt(lit(vec![8, 1, 8, 1])); + let bool_plan = LogicalPlan::Projection(Projection::try_new( + vec![bool_expr], + Arc::clone(&empty), + )?); + // Plan B: no coerce + let if_not_coerced = + "Projection: a < CAST(Binary(\"8,1,8,1\") AS BinaryView)\n EmptyRelation"; + do_not_coerce_on_output(bool_plan.clone(), if_not_coerced)?; + // Plan B: coerce requested: no coercion applied + let if_coerced = if_not_coerced; + coerce_on_output_if_viewtype(bool_plan, if_coerced)?; + + // Plan C + // scenario: with a non-projection root logical plan node + let sort_expr = expr.sort(true, true); + let sort_plan = LogicalPlan::Sort(Sort { + expr: vec![sort_expr], + input: Arc::new(plan), + fetch: None, + }); + // Plan C: no coerce + let if_not_coerced = + "Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; + do_not_coerce_on_output(sort_plan.clone(), if_not_coerced)?; + // Plan C: coerce requested: BinaryView => LargeBinary + let if_coerced = "Projection: CAST(a AS LargeBinary)\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; + coerce_on_output_if_viewtype(sort_plan.clone(), if_coerced)?; + + // Plan D + // scenario: two layers of projections with view types + let plan = LogicalPlan::Projection(Projection::try_new( + vec![col("a")], + Arc::new(sort_plan), + )?); + // Plan D: no coerce + let if_not_coerced = "Projection: a\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; + do_not_coerce_on_output(plan.clone(), if_not_coerced)?; + // Plan B: coerce requested: BinaryView => LargeBinary only on outermost + let if_coerced = "Projection: CAST(a AS LargeBinary)\n Sort: a ASC NULLS FIRST\n Projection: a\n EmptyRelation"; + coerce_on_output_if_viewtype(plan.clone(), if_coerced)?; + + Ok(()) } #[test] @@ -794,7 +1225,7 @@ mod test { )?); let expected = "Projection: a < CAST(UInt32(2) AS Float64) OR a < CAST(UInt32(2) AS Float64)\ \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected) + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) } #[derive(Debug, Clone)] @@ -816,7 +1247,7 @@ mod test { } fn return_type(&self, _args: &[DataType]) -> Result { - Ok(DataType::Utf8) + Ok(Utf8) } fn invoke(&self, _args: &[ColumnarValue]) -> Result { @@ -835,7 +1266,7 @@ mod test { let plan = LogicalPlan::Projection(Projection::try_new(vec![udf], empty)?); let expected = "Projection: TestScalarUDF(CAST(Int32(123) AS Float32))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected) + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) } #[test] @@ -845,12 +1276,9 @@ mod test { signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), }) .call(vec![lit("Apple")]); - let plan_err = Projection::try_new(vec![udf], empty) + Projection::try_new(vec![udf], empty) .expect_err("Expected an error due to incorrect function input"); - let expected_error = "Error during planning: No function matches the given name and argument types 'TestScalarUDF(Utf8)'. You might need to add explicit type casts."; - - assert!(plan_err.to_string().starts_with(expected_error)); Ok(()) } @@ -870,7 +1298,7 @@ mod test { )?); let expected = "Projection: TestScalarUDF(CAST(Int64(10) AS Float32))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected) + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) } #[test] @@ -894,7 +1322,7 @@ mod test { )); let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty)?); let expected = "Projection: MY_AVG(CAST(Int64(10) AS Float64))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected) + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) } #[test] @@ -921,13 +1349,10 @@ mod test { None, None, )); - let plan = LogicalPlan::Projection(Projection::try_new(vec![udaf], empty)?); - let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, "") - .err() - .unwrap(); - assert_eq!( - "type_coercion\ncaused by\nError during planning: Coercion from [Utf8] to the signature Uniform(1, [Float64]) failed.", - err.strip_backtrace() + + let err = Projection::try_new(vec![udaf], empty).err().unwrap(); + assert!( + err.strip_backtrace().starts_with("Error during planning: Error during planning: Coercion from [Utf8] to the signature Uniform(1, [Float64]) failed") ); Ok(()) } @@ -935,41 +1360,38 @@ mod test { #[test] fn agg_function_case() -> Result<()> { let empty = empty(); - let fun: AggregateFunction = AggregateFunction::Avg; - let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new( - fun, - vec![lit(12i64)], + let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + avg_udaf(), + vec![lit(12f64)], false, None, None, None, )); let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?); - let expected = "Projection: AVG(CAST(Int64(12) AS Float64))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; + let expected = "Projection: avg(Float64(12))\n EmptyRelation"; + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; let empty = empty_with_type(DataType::Int32); - let fun: AggregateFunction = AggregateFunction::Avg; - let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new( - fun, - vec![col("a")], + let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + avg_udaf(), + vec![cast(col("a"), DataType::Float64)], false, None, None, None, )); let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?); - let expected = "Projection: AVG(CAST(a AS Float64))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; + let expected = "Projection: avg(CAST(a AS Float64))\n EmptyRelation"; + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; Ok(()) } #[test] fn agg_function_invalid_input_avg() -> Result<()> { let empty = empty(); - let fun: AggregateFunction = AggregateFunction::Avg; - let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new( - fun, + let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + avg_udaf(), vec![lit("1")], false, None, @@ -980,48 +1402,20 @@ mod test { .err() .unwrap() .strip_backtrace(); - assert_eq!( - "Error during planning: No function matches the given name and argument types 'AVG(Utf8)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tAVG(Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64)", - err - ); + assert!(err.starts_with("Error during planning: Error during planning: Coercion from [Utf8] to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64]) failed.")); Ok(()) } - #[test] - fn agg_function_invalid_input_percentile() { - let empty = empty(); - let fun: AggregateFunction = AggregateFunction::ApproxPercentileCont; - let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new( - fun, - vec![lit(0.95), lit(42.0), lit(100.0)], - false, - None, - None, - None, - )); - - let err = Projection::try_new(vec![agg_expr], empty) - .err() - .unwrap() - .strip_backtrace(); - - let prefix = "Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT(Float64, Float64, Float64)'. You might need to add explicit type casts.\n\tCandidate functions:"; - assert!(!err - .strip_prefix(prefix) - .unwrap() - .contains("APPROX_PERCENTILE_CONT(Float64, Float64, Float64)")); - } - #[test] fn binary_op_date32_op_interval() -> Result<()> { - //CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("386547056640") + // CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("...") let expr = cast(lit("1998-03-18"), DataType::Date32) - + lit(ScalarValue::IntervalDayTime(Some(386547056640))); + + lit(ScalarValue::new_interval_dt(123, 456)); let empty = empty(); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); let expected = - "Projection: CAST(Utf8(\"1998-03-18\") AS Date32) + IntervalDayTime(\"386547056640\")\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; + "Projection: CAST(Utf8(\"1998-03-18\") AS Date32) + IntervalDayTime(\"IntervalDayTime { days: 123, milliseconds: 456 }\")\n EmptyRelation"; + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; Ok(()) } @@ -1031,25 +1425,21 @@ mod test { let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false); let empty = empty_with_type(DataType::Int64); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = - "Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)]) AS a IN (Map { iter: Iter([Literal(Int32(1)), Literal(Int8(4)), Literal(Int64(8))]) })\ - \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; + let expected = "Projection: a IN ([CAST(Int32(1) AS Int64), CAST(Int8(4) AS Int64), Int64(8)])\n EmptyRelation"; + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; // a in (1,4,8), a is decimal let expr = col("a").in_list(vec![lit(1_i32), lit(4_i8), lit(8_i64)], false); let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, - schema: Arc::new(DFSchema::from_unqualifed_fields( + schema: Arc::new(DFSchema::from_unqualified_fields( vec![Field::new("a", DataType::Decimal128(12, 4), true)].into(), std::collections::HashMap::new(), )?), })); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let expected = - "Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))]) AS a IN (Map { iter: Iter([Literal(Int32(1)), Literal(Int8(4)), Literal(Int64(8))]) })\ - \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected) + let expected = "Projection: CAST(a AS Decimal128(24, 4)) IN ([CAST(Int32(1) AS Decimal128(24, 4)), CAST(Int8(4) AS Decimal128(24, 4)), CAST(Int64(8) AS Decimal128(24, 4))])\n EmptyRelation"; + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) } #[test] @@ -1060,12 +1450,12 @@ mod test { cast(lit("2002-05-08"), DataType::Date32) + lit(ScalarValue::new_interval_ym(0, 1)), ); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?); let expected = "Filter: a BETWEEN Utf8(\"2002-05-08\") AND CAST(CAST(Utf8(\"2002-05-08\") AS Date32) + IntervalYearMonth(\"1\") AS Utf8)\ \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected) + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) } #[test] @@ -1076,13 +1466,13 @@ mod test { + lit(ScalarValue::new_interval_ym(0, 1)), lit("2002-12-08"), ); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let plan = LogicalPlan::Filter(Filter::try_new(expr, empty)?); // TODO: we should cast col(a). let expected = "Filter: CAST(a AS Date32) BETWEEN CAST(Utf8(\"2002-05-08\") AS Date32) + IntervalYearMonth(\"1\") AND CAST(Utf8(\"2002-12-08\") AS Date32)\ \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected) + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected) } #[test] @@ -1093,11 +1483,11 @@ mod test { let plan = LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?); let expected = "Projection: a IS TRUE\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; let empty = empty_with_type(DataType::Int64); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, ""); + let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, ""); let err = ret.unwrap_err().to_string(); assert!(err.contains("Cannot infer common argument type for comparison operation Int64 IS DISTINCT FROM Boolean"), "{err}"); @@ -1106,21 +1496,21 @@ mod test { let empty = empty_with_type(DataType::Boolean); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: a IS NOT TRUE\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; // is false let expr = col("a").is_false(); let empty = empty_with_type(DataType::Boolean); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: a IS FALSE\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; // is not false let expr = col("a").is_not_false(); let empty = empty_with_type(DataType::Boolean); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: a IS NOT FALSE\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; Ok(()) } @@ -1131,26 +1521,25 @@ mod test { let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); let expected = "Projection: a LIKE Utf8(\"abc\")\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::Null)); let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); - let expected = "Projection: a LIKE CAST(NULL AS Utf8) AS a LIKE NULL \ - \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; + let expected = "Projection: a LIKE CAST(NULL AS Utf8)\n EmptyRelation"; + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); let like_expr = Expr::Like(Like::new(false, expr, pattern, None, false)); let empty = empty_with_type(DataType::Int64); let plan = LogicalPlan::Projection(Projection::try_new(vec![like_expr], empty)?); - let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected); + let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected); assert!(err.is_err()); assert!(err.unwrap_err().to_string().contains( "There isn't a common type to coerce Int64 and Utf8 in LIKE expression" @@ -1160,26 +1549,25 @@ mod test { let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); let expected = "Projection: a ILIKE Utf8(\"abc\")\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::Null)); let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); - let expected = "Projection: a ILIKE CAST(NULL AS Utf8) AS a ILIKE NULL \ - \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; + let expected = "Projection: a ILIKE CAST(NULL AS Utf8)\n EmptyRelation"; + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; let expr = Box::new(col("a")); let pattern = Box::new(lit(ScalarValue::new_utf8("abc"))); let ilike_expr = Expr::Like(Like::new(false, expr, pattern, None, true)); let empty = empty_with_type(DataType::Int64); let plan = LogicalPlan::Projection(Projection::try_new(vec![ilike_expr], empty)?); - let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected); + let err = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected); assert!(err.is_err()); assert!(err.unwrap_err().to_string().contains( "There isn't a common type to coerce Int64 and Utf8 in ILIKE expression" @@ -1195,11 +1583,11 @@ mod test { let plan = LogicalPlan::Projection(Projection::try_new(vec![expr.clone()], empty)?); let expected = "Projection: a IS UNKNOWN\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected); + let ret = assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected); let err = ret.unwrap_err().to_string(); assert!(err.contains("Cannot infer common argument type for comparison operation Utf8 IS DISTINCT FROM Boolean"), "{err}"); @@ -1208,14 +1596,14 @@ mod test { let empty = empty_with_type(DataType::Boolean); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: a IS NOT UNKNOWN\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; Ok(()) } #[test] fn concat_for_type_coercion() -> Result<()> { - let empty = empty_with_type(DataType::Utf8); + let empty = empty_with_type(Utf8); let args = [col("a"), lit("b"), lit(true), lit(false), lit(13)]; // concat-type signature @@ -1224,11 +1612,13 @@ mod test { signature: Signature::variadic(vec![Utf8], Volatility::Immutable), }) .call(args.to_vec()); - let plan = - LogicalPlan::Projection(Projection::try_new(vec![expr], empty.clone())?); + let plan = LogicalPlan::Projection(Projection::try_new( + vec![expr], + Arc::clone(&empty), + )?); let expected = "Projection: TestScalarUDF(a, Utf8(\"b\"), CAST(Boolean(true) AS Utf8), CAST(Boolean(false) AS Utf8), CAST(Int32(13) AS Utf8))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; } Ok(()) @@ -1237,33 +1627,33 @@ mod test { #[test] fn test_type_coercion_rewrite() -> Result<()> { // gt - let schema = Arc::new(DFSchema::from_unqualifed_fields( + let schema = Arc::new(DFSchema::from_unqualified_fields( vec![Field::new("a", DataType::Int64, true)].into(), std::collections::HashMap::new(), )?); - let mut rewriter = TypeCoercionRewriter { schema }; + let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).gt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64))); let result = expr.rewrite(&mut rewriter).data()?; assert_eq!(expected, result); // eq - let schema = Arc::new(DFSchema::from_unqualifed_fields( + let schema = Arc::new(DFSchema::from_unqualified_fields( vec![Field::new("a", DataType::Int64, true)].into(), std::collections::HashMap::new(), )?); - let mut rewriter = TypeCoercionRewriter { schema }; + let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).eq(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64))); let result = expr.rewrite(&mut rewriter).data()?; assert_eq!(expected, result); // lt - let schema = Arc::new(DFSchema::from_unqualifed_fields( + let schema = Arc::new(DFSchema::from_unqualified_fields( vec![Field::new("a", DataType::Int64, true)].into(), std::collections::HashMap::new(), )?); - let mut rewriter = TypeCoercionRewriter { schema }; + let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).lt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64))); let result = expr.rewrite(&mut rewriter).data()?; @@ -1281,10 +1671,9 @@ mod test { .eq(cast(lit("1998-03-18"), DataType::Date32)); let empty = empty(); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - dbg!(&plan); let expected = "Projection: CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond, None)) = CAST(CAST(Utf8(\"1998-03-18\") AS Date32) AS Timestamp(Nanosecond, None))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; Ok(()) } @@ -1302,26 +1691,26 @@ mod test { fn cast_helper( case: Case, - case_when_type: DataType, - then_else_type: DataType, + case_when_type: &DataType, + then_else_type: &DataType, schema: &DFSchemaRef, ) -> Case { let expr = case .expr - .map(|e| cast_if_not_same_type(e, &case_when_type, schema)); + .map(|e| cast_if_not_same_type(e, case_when_type, schema)); let when_then_expr = case .when_then_expr .into_iter() .map(|(when, then)| { ( - cast_if_not_same_type(when, &case_when_type, schema), - cast_if_not_same_type(then, &then_else_type, schema), + cast_if_not_same_type(when, case_when_type, schema), + cast_if_not_same_type(then, then_else_type, schema), ) }) .collect::>(); let else_expr = case .else_expr - .map(|e| cast_if_not_same_type(e, &then_else_type, schema)); + .map(|e| cast_if_not_same_type(e, then_else_type, schema)); Case { expr, @@ -1332,7 +1721,7 @@ mod test { #[test] fn test_case_expression_coercion() -> Result<()> { - let schema = Arc::new(DFSchema::from_unqualifed_fields( + let schema = Arc::new(DFSchema::from_unqualified_fields( vec![ Field::new("boolean", DataType::Boolean, true), Field::new("integer", DataType::Int32, true), @@ -1349,7 +1738,7 @@ mod test { true, ), Field::new("binary", DataType::Binary, true), - Field::new("string", DataType::Utf8, true), + Field::new("string", Utf8, true), Field::new("decimal", DataType::Decimal128(10, 10), true), ] .into(), @@ -1366,11 +1755,11 @@ mod test { else_expr: None, }; let case_when_common_type = DataType::Boolean; - let then_else_common_type = DataType::Utf8; + let then_else_common_type = Utf8; let expected = cast_helper( case.clone(), - case_when_common_type, - then_else_common_type, + &case_when_common_type, + &then_else_common_type, &schema, ); let actual = coerce_case_expression(case, &schema)?; @@ -1385,12 +1774,12 @@ mod test { ], else_expr: Some(Box::new(col("string"))), }; - let case_when_common_type = DataType::Utf8; - let then_else_common_type = DataType::Utf8; + let case_when_common_type = Utf8; + let then_else_common_type = Utf8; let expected = cast_helper( case.clone(), - case_when_common_type, - then_else_common_type, + &case_when_common_type, + &then_else_common_type, &schema, ); let actual = coerce_case_expression(case, &schema)?; @@ -1435,6 +1824,186 @@ mod test { Ok(()) } + macro_rules! test_case_expression { + ($expr:expr, $when_then:expr, $case_when_type:expr, $then_else_type:expr, $schema:expr) => { + let case = Case { + expr: $expr.map(|e| Box::new(col(e))), + when_then_expr: $when_then, + else_expr: None, + }; + + let expected = + cast_helper(case.clone(), &$case_when_type, &$then_else_type, &$schema); + + let actual = coerce_case_expression(case, &$schema)?; + assert_eq!(expected, actual); + }; + } + + #[test] + fn tes_case_when_list() -> Result<()> { + let inner_field = Arc::new(Field::new("item", DataType::Int64, true)); + let schema = Arc::new(DFSchema::from_unqualified_fields( + vec![ + Field::new( + "large_list", + DataType::LargeList(Arc::clone(&inner_field)), + true, + ), + Field::new( + "fixed_list", + DataType::FixedSizeList(Arc::clone(&inner_field), 3), + true, + ), + Field::new("list", DataType::List(inner_field), true), + ] + .into(), + std::collections::HashMap::new(), + )?); + + test_case_expression!( + Some("list"), + vec![(Box::new(col("large_list")), Box::new(lit("1")))], + DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))), + Utf8, + schema + ); + + test_case_expression!( + Some("large_list"), + vec![(Box::new(col("list")), Box::new(lit("1")))], + DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))), + Utf8, + schema + ); + + test_case_expression!( + Some("list"), + vec![(Box::new(col("fixed_list")), Box::new(lit("1")))], + DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), + Utf8, + schema + ); + + test_case_expression!( + Some("fixed_list"), + vec![(Box::new(col("list")), Box::new(lit("1")))], + DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), + Utf8, + schema + ); + + test_case_expression!( + Some("fixed_list"), + vec![(Box::new(col("large_list")), Box::new(lit("1")))], + DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))), + Utf8, + schema + ); + + test_case_expression!( + Some("large_list"), + vec![(Box::new(col("fixed_list")), Box::new(lit("1")))], + DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))), + Utf8, + schema + ); + Ok(()) + } + + #[test] + fn test_then_else_list() -> Result<()> { + let inner_field = Arc::new(Field::new("item", DataType::Int64, true)); + let schema = Arc::new(DFSchema::from_unqualified_fields( + vec![ + Field::new("boolean", DataType::Boolean, true), + Field::new( + "large_list", + DataType::LargeList(Arc::clone(&inner_field)), + true, + ), + Field::new( + "fixed_list", + DataType::FixedSizeList(Arc::clone(&inner_field), 3), + true, + ), + Field::new("list", DataType::List(inner_field), true), + ] + .into(), + std::collections::HashMap::new(), + )?); + + // large list and list + test_case_expression!( + None::, + vec![ + (Box::new(col("boolean")), Box::new(col("large_list"))), + (Box::new(col("boolean")), Box::new(col("list"))) + ], + DataType::Boolean, + DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))), + schema + ); + + test_case_expression!( + None::, + vec![ + (Box::new(col("boolean")), Box::new(col("list"))), + (Box::new(col("boolean")), Box::new(col("large_list"))) + ], + DataType::Boolean, + DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))), + schema + ); + + // fixed list and list + test_case_expression!( + None::, + vec![ + (Box::new(col("boolean")), Box::new(col("fixed_list"))), + (Box::new(col("boolean")), Box::new(col("list"))) + ], + DataType::Boolean, + DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), + schema + ); + + test_case_expression!( + None::, + vec![ + (Box::new(col("boolean")), Box::new(col("list"))), + (Box::new(col("boolean")), Box::new(col("fixed_list"))) + ], + DataType::Boolean, + DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), + schema + ); + + // fixed list and large list + test_case_expression!( + None::, + vec![ + (Box::new(col("boolean")), Box::new(col("fixed_list"))), + (Box::new(col("boolean")), Box::new(col("large_list"))) + ], + DataType::Boolean, + DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))), + schema + ); + + test_case_expression!( + None::, + vec![ + (Box::new(col("boolean")), Box::new(col("large_list"))), + (Box::new(col("boolean")), Box::new(col("fixed_list"))) + ], + DataType::Boolean, + DataType::LargeList(Arc::new(Field::new("item", DataType::Int64, true))), + schema + ); + Ok(()) + } + #[test] fn interval_plus_timestamp() -> Result<()> { // SELECT INTERVAL '1' YEAR + '2000-01-01T00:00:00'::timestamp; @@ -1449,7 +2018,7 @@ mod test { let empty = empty(); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); let expected = "Projection: IntervalYearMonth(\"12\") + CAST(Utf8(\"2000-01-01T00:00:00\") AS Timestamp(Nanosecond, None))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; Ok(()) } @@ -1468,10 +2037,9 @@ mod test { )); let empty = empty(); let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty)?); - dbg!(&plan); let expected = "Projection: CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond, None)) - CAST(Utf8(\"1998-03-18\") AS Timestamp(Nanosecond, None))\n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; Ok(()) } @@ -1496,7 +2064,7 @@ mod test { \n Projection: CAST(a AS Int64)\ \n EmptyRelation\ \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; Ok(()) } @@ -1520,7 +2088,7 @@ mod test { \n Subquery:\ \n EmptyRelation\ \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; Ok(()) } @@ -1544,7 +2112,7 @@ mod test { \n Projection: CAST(a AS Decimal128(13, 8))\ \n EmptyRelation\ \n EmptyRelation"; - assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?; + assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), plan, expected)?; Ok(()) } } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index cb3b4accf35d..4fe22d252744 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -17,106 +17,27 @@ //! [`CommonSubexprEliminate`] to avoid redundant computation of common sub-expressions -use std::collections::hash_map::Entry; -use std::collections::{BTreeSet, HashMap}; +use std::collections::BTreeSet; +use std::fmt::Debug; use std::sync::Arc; -use crate::{utils, OptimizerConfig, OptimizerRule}; +use crate::{OptimizerConfig, OptimizerRule}; -use arrow::datatypes::{DataType, Field}; -use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, - TreeNodeVisitor, -}; -use datafusion_common::{ - internal_err, qualified_name, Column, DFSchema, DFSchemaRef, DataFusionError, Result, -}; -use datafusion_expr::expr::Alias; -use datafusion_expr::logical_plan::{Aggregate, LogicalPlan, Projection, Window}; -use datafusion_expr::{col, Expr, ExprSchemable}; - -/// Set of expressions generated by the [`ExprIdentifierVisitor`] -/// and consumed by the [`CommonSubexprRewriter`]. -#[derive(Default)] -struct ExprSet { - /// A map from expression's identifier (stringified expr) to tuple including: - /// - the expression itself (cloned) - /// - counter - /// - DataType of this expression. - /// - symbol used as the identifier in the alias. - map: HashMap, -} - -impl ExprSet { - fn expr_identifier(expr: &Expr) -> Identifier { - format!("{expr}") - } - - fn get(&self, key: &Identifier) -> Option<&(Expr, usize, DataType, Identifier)> { - self.map.get(key) - } - - fn entry( - &mut self, - key: Identifier, - ) -> Entry<'_, Identifier, (Expr, usize, DataType, Identifier)> { - self.map.entry(key) - } - - fn populate_expr_set( - &mut self, - expr: &[Expr], - input_schema: DFSchemaRef, - expr_mask: ExprMask, - ) -> Result<()> { - expr.iter().try_for_each(|e| { - self.expr_to_identifier(e, Arc::clone(&input_schema), expr_mask)?; - - Ok(()) - }) - } +use crate::optimizer::ApplyOrder; +use crate::utils::NamePreserver; +use datafusion_common::alias::AliasGenerator; - /// Go through an expression tree and generate identifier for every node in this tree. - fn expr_to_identifier( - &mut self, - expr: &Expr, - input_schema: DFSchemaRef, - expr_mask: ExprMask, - ) -> Result<()> { - expr.visit(&mut ExprIdentifierVisitor { - expr_set: self, - input_schema, - visit_stack: vec![], - node_count: 0, - expr_mask, - })?; - - Ok(()) - } -} - -impl From> for ExprSet { - fn from(entries: Vec<(Identifier, (Expr, usize, DataType, Identifier))>) -> Self { - let mut expr_set = Self::default(); - entries.into_iter().for_each(|(k, v)| { - expr_set.map.insert(k, v); - }); - expr_set - } -} +use datafusion_common::cse::{CSEController, FoundCommonNodes, CSE}; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{qualified_name, Column, DFSchema, DFSchemaRef, Result}; +use datafusion_expr::expr::{Alias, ScalarFunction}; +use datafusion_expr::logical_plan::{ + Aggregate, Filter, LogicalPlan, Projection, Sort, Window, +}; +use datafusion_expr::tree_node::replace_sort_expressions; +use datafusion_expr::{col, BinaryExpr, Case, Expr, Operator}; -/// Identifier for each subexpression. -/// -/// Note that the current implementation uses the `Display` of an expression -/// (a `String`) as `Identifier`. -/// -/// An identifier should (ideally) be able to "hash", "accumulate", "equal" and "have no -/// collision (as low as possible)" -/// -/// Since an identifier is likely to be copied many times, it is better that an identifier -/// is small or "copy". otherwise some kinds of reference count is needed. String description -/// here is not such a good choose. -type Identifier = String; +const CSE_PREFIX: &str = "__common_expr"; /// Performs Common Sub-expression Elimination optimization. /// @@ -144,271 +65,475 @@ type Identifier = String; /// ProjectionExec(exprs=[extract (day from new_col), extract (year from new_col)]) <-- reuse here /// ProjectionExec(exprs=[to_date(c1) as new_col]) <-- compute to_date once /// ``` +#[derive(Debug)] pub struct CommonSubexprEliminate {} impl CommonSubexprEliminate { - /// Rewrites `exprs_list` with common sub-expressions replaced with a new - /// column. - /// - /// `affected_id` is updated with any sub expressions that were replaced. - /// - /// Returns the rewritten expressions - fn rewrite_exprs_list( + pub fn new() -> Self { + Self {} + } + + fn try_optimize_proj( &self, - exprs_list: &[&[Expr]], - expr_set: &ExprSet, - affected_id: &mut BTreeSet, - ) -> Result>> { - exprs_list - .iter() - .map(|exprs| { - exprs - .iter() - .cloned() - .map(|expr| replace_common_expr(expr, expr_set, affected_id)) - .collect::>>() + projection: Projection, + config: &dyn OptimizerConfig, + ) -> Result> { + let Projection { + expr, + input, + schema, + .. + } = projection; + let input = Arc::unwrap_or_clone(input); + self.try_unary_plan(expr, input, config)? + .map_data(|(new_expr, new_input)| { + Projection::try_new_with_schema(new_expr, Arc::new(new_input), schema) + .map(LogicalPlan::Projection) }) - .collect::>>() } - - /// Rewrites the expression in `exprs_list` with common sub-expressions - /// replaced with a new colum and adds a ProjectionExec on top of `input` - /// which computes any replaced common sub-expressions. - /// - /// Returns a tuple of: - /// 1. The rewritten expressions - /// 2. A `LogicalPlan::Projection` with input of `input` that computes any - /// common sub-expressions that were used - fn rewrite_expr( + fn try_optimize_sort( &self, - exprs_list: &[&[Expr]], - input: &LogicalPlan, - expr_set: &ExprSet, + sort: Sort, config: &dyn OptimizerConfig, - ) -> Result<(Vec>, LogicalPlan)> { - let mut affected_id = BTreeSet::::new(); - - let rewrite_exprs = - self.rewrite_exprs_list(exprs_list, expr_set, &mut affected_id)?; - - let mut new_input = self - .try_optimize(input, config)? - .unwrap_or_else(|| input.clone()); - if !affected_id.is_empty() { - new_input = build_common_expr_project_plan(new_input, affected_id, expr_set)?; - } + ) -> Result> { + let Sort { expr, input, fetch } = sort; + let input = Arc::unwrap_or_clone(input); + let sort_expressions = expr.iter().map(|sort| sort.expr.clone()).collect(); + let new_sort = self + .try_unary_plan(sort_expressions, input, config)? + .update_data(|(new_expr, new_input)| { + LogicalPlan::Sort(Sort { + expr: replace_sort_expressions(expr, new_expr), + input: Arc::new(new_input), + fetch, + }) + }); + Ok(new_sort) + } - Ok((rewrite_exprs, new_input)) + fn try_optimize_filter( + &self, + filter: Filter, + config: &dyn OptimizerConfig, + ) -> Result> { + let Filter { + predicate, input, .. + } = filter; + let input = Arc::unwrap_or_clone(input); + let expr = vec![predicate]; + self.try_unary_plan(expr, input, config)? + .map_data(|(mut new_expr, new_input)| { + assert_eq!(new_expr.len(), 1); // passed in vec![predicate] + let new_predicate = new_expr.pop().unwrap(); + Filter::try_new(new_predicate, Arc::new(new_input)) + .map(LogicalPlan::Filter) + }) } fn try_optimize_window( &self, - window: &Window, + window: Window, config: &dyn OptimizerConfig, - ) -> Result { - let mut window_exprs = vec![]; - let mut expr_set = ExprSet::default(); - - // Get all window expressions inside the consecutive window operators. - // Consecutive window expressions may refer to same complex expression. - // If same complex expression is referred more than once by subsequent `WindowAggr`s, - // we can cache complex expression by evaluating it with a projection before the - // first WindowAggr. - // This enables us to cache complex expression "c3+c4" for following plan: - // WindowAggr: windowExpr=[[SUM(c9) ORDER BY [c3 + c4] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] - // --WindowAggr: windowExpr=[[SUM(c9) ORDER BY [c3 + c4] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] - // where, it is referred once by each `WindowAggr` (total of 2) in the plan. - let mut plan = LogicalPlan::Window(window.clone()); - while let LogicalPlan::Window(window) = plan { - let Window { - input, window_expr, .. - } = window; - plan = input.as_ref().clone(); - - let input_schema = Arc::clone(input.schema()); - expr_set.populate_expr_set(&window_expr, input_schema, ExprMask::Normal)?; - - window_exprs.push(window_expr); - } + ) -> Result> { + // Collects window expressions from consecutive `LogicalPlan::Window` nodes into + // a list. + let (window_expr_list, window_schemas, input) = + get_consecutive_window_exprs(window); - let mut window_exprs = window_exprs - .iter() - .map(|expr| expr.as_slice()) - .collect::>(); - - let (mut new_expr, new_input) = - self.rewrite_expr(&window_exprs, &plan, &expr_set, config)?; - assert_eq!(window_exprs.len(), new_expr.len()); - - // Construct consecutive window operator, with their corresponding new window expressions. - plan = new_input; - while let Some(new_window_expr) = new_expr.pop() { - // Since `new_expr` and `window_exprs` length are same. We can safely `.unwrap` here. - let orig_window_expr = window_exprs.pop().unwrap(); - assert_eq!(new_window_expr.len(), orig_window_expr.len()); - - // Rename new re-written window expressions with original name (by giving alias) - // Otherwise we may receive schema error, in subsequent operators. - let new_window_expr = new_window_expr - .into_iter() - .zip(orig_window_expr.iter()) - .map(|(new_window_expr, window_expr)| { - let original_name = window_expr.name_for_alias()?; - new_window_expr.alias_if_changed(original_name) - }) - .collect::>>()?; - plan = LogicalPlan::Window(Window::try_new(new_window_expr, Arc::new(plan))?); - } + // Extract common sub-expressions from the list. - Ok(plan) + match CSE::new(ExprCSEController::new( + config.alias_generator().as_ref(), + ExprMask::Normal, + )) + .extract_common_nodes(window_expr_list)? + { + // If there are common sub-expressions, then the insert a projection node + // with the common expressions between the new window nodes and the + // original input. + FoundCommonNodes::Yes { + common_nodes: common_exprs, + new_nodes_list: new_exprs_list, + original_nodes_list: original_exprs_list, + } => build_common_expr_project_plan(input, common_exprs).map(|new_input| { + Transformed::yes((new_exprs_list, new_input, Some(original_exprs_list))) + }), + FoundCommonNodes::No { + original_nodes_list: original_exprs_list, + } => Ok(Transformed::no((original_exprs_list, input, None))), + }? + // Recurse into the new input. + // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) + .transform_data(|(new_window_expr_list, new_input, window_expr_list)| { + self.rewrite(new_input, config)?.map_data(|new_input| { + Ok((new_window_expr_list, new_input, window_expr_list)) + }) + })? + // Rebuild the consecutive window nodes. + .map_data(|(new_window_expr_list, new_input, window_expr_list)| { + // If there were common expressions extracted, then we need to make sure + // we restore the original column names. + // TODO: Although `find_common_exprs()` inserts aliases around extracted + // common expressions this doesn't mean that the original column names + // (schema) are preserved due to the inserted aliases are not always at + // the top of the expression. + // Let's consider improving `find_common_exprs()` to always keep column + // names and get rid of additional name preserving logic here. + if let Some(window_expr_list) = window_expr_list { + let name_preserver = NamePreserver::new_for_projection(); + let saved_names = window_expr_list + .iter() + .map(|exprs| { + exprs + .iter() + .map(|expr| name_preserver.save(expr)) + .collect::>() + }) + .collect::>(); + new_window_expr_list.into_iter().zip(saved_names).try_rfold( + new_input, + |plan, (new_window_expr, saved_names)| { + let new_window_expr = new_window_expr + .into_iter() + .zip(saved_names) + .map(|(new_window_expr, saved_name)| { + saved_name.restore(new_window_expr) + }) + .collect::>(); + Window::try_new(new_window_expr, Arc::new(plan)) + .map(LogicalPlan::Window) + }, + ) + } else { + new_window_expr_list + .into_iter() + .zip(window_schemas) + .try_rfold(new_input, |plan, (new_window_expr, schema)| { + Window::try_new_with_schema( + new_window_expr, + Arc::new(plan), + schema, + ) + .map(LogicalPlan::Window) + }) + } + }) } fn try_optimize_aggregate( &self, - aggregate: &Aggregate, + aggregate: Aggregate, config: &dyn OptimizerConfig, - ) -> Result { + ) -> Result> { let Aggregate { group_expr, aggr_expr, input, + schema, .. } = aggregate; - let mut expr_set = ExprSet::default(); - - // build expr_set, with groupby and aggr - let input_schema = Arc::clone(input.schema()); - expr_set.populate_expr_set( - group_expr, - Arc::clone(&input_schema), + let input = Arc::unwrap_or_clone(input); + // Extract common sub-expressions from the aggregate and grouping expressions. + match CSE::new(ExprCSEController::new( + config.alias_generator().as_ref(), ExprMask::Normal, - )?; - expr_set.populate_expr_set(aggr_expr, input_schema, ExprMask::Normal)?; - - // rewrite inputs - let (mut new_expr, new_input) = - self.rewrite_expr(&[group_expr, aggr_expr], input, &expr_set, config)?; - // note the reversed pop order. - let new_aggr_expr = pop_expr(&mut new_expr)?; - let new_group_expr = pop_expr(&mut new_expr)?; - - // create potential projection on top - let mut expr_set = ExprSet::default(); - let new_input_schema = Arc::clone(new_input.schema()); - expr_set.populate_expr_set( - &new_aggr_expr, - new_input_schema.clone(), - ExprMask::NormalAndAggregates, - )?; - - let mut affected_id = BTreeSet::::new(); - let mut rewritten = - self.rewrite_exprs_list(&[&new_aggr_expr], &expr_set, &mut affected_id)?; - let rewritten = pop_expr(&mut rewritten)?; - - if affected_id.is_empty() { - // Alias aggregation expressions if they have changed - let new_aggr_expr = new_aggr_expr - .iter() - .zip(aggr_expr.iter()) - .map(|(new_expr, old_expr)| { - new_expr.clone().alias_if_changed(old_expr.display_name()?) + )) + .extract_common_nodes(vec![group_expr, aggr_expr])? + { + // If there are common sub-expressions, then insert a projection node + // with the common expressions between the new aggregate node and the + // original input. + FoundCommonNodes::Yes { + common_nodes: common_exprs, + new_nodes_list: mut new_exprs_list, + original_nodes_list: mut original_exprs_list, + } => { + let new_aggr_expr = new_exprs_list.pop().unwrap(); + let new_group_expr = new_exprs_list.pop().unwrap(); + + build_common_expr_project_plan(input, common_exprs).map(|new_input| { + let aggr_expr = original_exprs_list.pop().unwrap(); + Transformed::yes(( + new_aggr_expr, + new_group_expr, + new_input, + Some(aggr_expr), + )) }) - .collect::>>()?; - // Since group_epxr changes, schema changes also. Use try_new method. - Aggregate::try_new(Arc::new(new_input), new_group_expr, new_aggr_expr) - .map(LogicalPlan::Aggregate) - } else { - let mut agg_exprs = vec![]; - - for id in affected_id { - match expr_set.get(&id) { - Some((expr, _, _, symbol)) => { - // todo: check `nullable` - agg_exprs.push(expr.clone().alias(symbol.as_str())); - } - _ => { - return internal_err!("expr_set invalid state"); - } - } } - let mut proj_exprs = vec![]; - for expr in &new_group_expr { - extract_expressions(expr, &new_input_schema, &mut proj_exprs)? + FoundCommonNodes::No { + original_nodes_list: mut original_exprs_list, + } => { + let new_aggr_expr = original_exprs_list.pop().unwrap(); + let new_group_expr = original_exprs_list.pop().unwrap(); + + Ok(Transformed::no(( + new_aggr_expr, + new_group_expr, + input, + None, + ))) } - for (expr_rewritten, expr_orig) in rewritten.into_iter().zip(new_aggr_expr) { - if expr_rewritten == expr_orig { - if let Expr::Alias(Alias { expr, name, .. }) = expr_rewritten { - agg_exprs.push(expr.alias(&name)); - proj_exprs.push(Expr::Column(Column::from_name(name))); - } else { - let id = ExprSet::expr_identifier(&expr_rewritten); - let (qualifier, field) = - expr_rewritten.to_field(&new_input_schema)?; - let out_name = qualified_name(qualifier.as_ref(), field.name()); - - agg_exprs.push(expr_rewritten.alias(&id)); - proj_exprs - .push(Expr::Column(Column::from_name(id)).alias(out_name)); + }? + // Recurse into the new input. + // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) + .transform_data(|(new_aggr_expr, new_group_expr, new_input, aggr_expr)| { + self.rewrite(new_input, config)?.map_data(|new_input| { + Ok(( + new_aggr_expr, + new_group_expr, + aggr_expr, + Arc::new(new_input), + )) + }) + })? + // Try extracting common aggregate expressions and rebuild the aggregate node. + .transform_data( + |(new_aggr_expr, new_group_expr, aggr_expr, new_input)| { + // Extract common aggregate sub-expressions from the aggregate expressions. + match CSE::new(ExprCSEController::new( + config.alias_generator().as_ref(), + ExprMask::NormalAndAggregates, + )) + .extract_common_nodes(vec![new_aggr_expr])? + { + FoundCommonNodes::Yes { + common_nodes: common_exprs, + new_nodes_list: mut new_exprs_list, + original_nodes_list: mut original_exprs_list, + } => { + let rewritten_aggr_expr = new_exprs_list.pop().unwrap(); + let new_aggr_expr = original_exprs_list.pop().unwrap(); + + let mut agg_exprs = common_exprs + .into_iter() + .map(|(expr, expr_alias)| expr.alias(expr_alias)) + .collect::>(); + + let mut proj_exprs = vec![]; + for expr in &new_group_expr { + extract_expressions(expr, &mut proj_exprs) + } + for (expr_rewritten, expr_orig) in + rewritten_aggr_expr.into_iter().zip(new_aggr_expr) + { + if expr_rewritten == expr_orig { + if let Expr::Alias(Alias { expr, name, .. }) = + expr_rewritten + { + agg_exprs.push(expr.alias(&name)); + proj_exprs + .push(Expr::Column(Column::from_name(name))); + } else { + let expr_alias = + config.alias_generator().next(CSE_PREFIX); + let (qualifier, field_name) = + expr_rewritten.qualified_name(); + let out_name = + qualified_name(qualifier.as_ref(), &field_name); + + agg_exprs.push(expr_rewritten.alias(&expr_alias)); + proj_exprs.push( + Expr::Column(Column::from_name(expr_alias)) + .alias(out_name), + ); + } + } else { + proj_exprs.push(expr_rewritten); + } + } + + let agg = LogicalPlan::Aggregate(Aggregate::try_new( + new_input, + new_group_expr, + agg_exprs, + )?); + Projection::try_new(proj_exprs, Arc::new(agg)) + .map(|p| Transformed::yes(LogicalPlan::Projection(p))) } - } else { - proj_exprs.push(expr_rewritten); - } - } - let agg = LogicalPlan::Aggregate(Aggregate::try_new( - Arc::new(new_input), - new_group_expr, - agg_exprs, - )?); - - Ok(LogicalPlan::Projection(Projection::try_new( - proj_exprs, - Arc::new(agg), - )?)) - } + // If there aren't any common aggregate sub-expressions, then just + // rebuild the aggregate node. + FoundCommonNodes::No { + original_nodes_list: mut original_exprs_list, + } => { + let rewritten_aggr_expr = original_exprs_list.pop().unwrap(); + + // If there were common expressions extracted, then we need to + // make sure we restore the original column names. + // TODO: Although `find_common_exprs()` inserts aliases around + // extracted common expressions this doesn't mean that the + // original column names (schema) are preserved due to the + // inserted aliases are not always at the top of the + // expression. + // Let's consider improving `find_common_exprs()` to always + // keep column names and get rid of additional name + // preserving logic here. + if let Some(aggr_expr) = aggr_expr { + let name_perserver = NamePreserver::new_for_projection(); + let saved_names = aggr_expr + .iter() + .map(|expr| name_perserver.save(expr)) + .collect::>(); + let new_aggr_expr = rewritten_aggr_expr + .into_iter() + .zip(saved_names) + .map(|(new_expr, saved_name)| { + saved_name.restore(new_expr) + }) + .collect::>(); + + // Since `group_expr` may have changed, schema may also. + // Use `try_new()` method. + Aggregate::try_new(new_input, new_group_expr, new_aggr_expr) + .map(LogicalPlan::Aggregate) + .map(Transformed::no) + } else { + Aggregate::try_new_with_schema( + new_input, + new_group_expr, + rewritten_aggr_expr, + schema, + ) + .map(LogicalPlan::Aggregate) + .map(Transformed::no) + } + } + } + }, + ) } + /// Rewrites the expr list and input to remove common subexpressions + /// + /// # Parameters + /// + /// * `exprs`: List of expressions in the node + /// * `input`: input plan (that produces the columns referred to in `exprs`) + /// + /// # Return value + /// + /// Returns `(rewritten_exprs, new_input)`. `new_input` is either: + /// + /// 1. The original `input` of no common subexpressions were extracted + /// 2. A newly added projection on top of the original input + /// that computes the common subexpressions fn try_unary_plan( &self, - plan: &LogicalPlan, + exprs: Vec, + input: LogicalPlan, config: &dyn OptimizerConfig, - ) -> Result { - let expr = plan.expressions(); - let inputs = plan.inputs(); - let input = inputs[0]; - let input_schema = Arc::clone(input.schema()); - let mut expr_set = ExprSet::default(); - - // Visit expr list and build expr identifier to occuring count map (`expr_set`). - expr_set.populate_expr_set(&expr, input_schema, ExprMask::Normal)?; - - let (mut new_expr, new_input) = - self.rewrite_expr(&[&expr], input, &expr_set, config)?; + ) -> Result, LogicalPlan)>> { + // Extract common sub-expressions from the expressions. + match CSE::new(ExprCSEController::new( + config.alias_generator().as_ref(), + ExprMask::Normal, + )) + .extract_common_nodes(vec![exprs])? + { + FoundCommonNodes::Yes { + common_nodes: common_exprs, + new_nodes_list: mut new_exprs_list, + original_nodes_list: _, + } => { + let new_exprs = new_exprs_list.pop().unwrap(); + build_common_expr_project_plan(input, common_exprs) + .map(|new_input| Transformed::yes((new_exprs, new_input))) + } + FoundCommonNodes::No { + original_nodes_list: mut original_exprs_list, + } => { + let new_exprs = original_exprs_list.pop().unwrap(); + Ok(Transformed::no((new_exprs, input))) + } + }? + // Recurse into the new input. + // (This is similar to what a `ApplyOrder::TopDown` optimizer rule would do.) + .transform_data(|(new_exprs, new_input)| { + self.rewrite(new_input, config)? + .map_data(|new_input| Ok((new_exprs, new_input))) + }) + } +} - plan.with_new_exprs(pop_expr(&mut new_expr)?, vec![new_input]) +/// Get all window expressions inside the consecutive window operators. +/// +/// Returns the window expressions, and the input to the deepest child +/// LogicalPlan. +/// +/// For example, if the input window looks like +/// +/// ```text +/// LogicalPlan::Window(exprs=[a, b, c]) +/// LogicalPlan::Window(exprs=[d]) +/// InputPlan +/// ``` +/// +/// Returns: +/// * `window_exprs`: `[[a, b, c], [d]]` +/// * InputPlan +/// +/// Consecutive window expressions may refer to same complex expression. +/// +/// If same complex expression is referred more than once by subsequent +/// `WindowAggr`s, we can cache complex expression by evaluating it with a +/// projection before the first WindowAggr. +/// +/// This enables us to cache complex expression "c3+c4" for following plan: +/// +/// ```text +/// WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +/// --WindowAggr: windowExpr=[[sum(c9) ORDER BY [c3 + c4] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +/// ``` +/// +/// where, it is referred once by each `WindowAggr` (total of 2) in the plan. +fn get_consecutive_window_exprs( + window: Window, +) -> (Vec>, Vec, LogicalPlan) { + let mut window_expr_list = vec![]; + let mut window_schemas = vec![]; + let mut plan = LogicalPlan::Window(window); + while let LogicalPlan::Window(Window { + input, + window_expr, + schema, + }) = plan + { + window_expr_list.push(window_expr); + window_schemas.push(schema); + + plan = Arc::unwrap_or_clone(input); } + (window_expr_list, window_schemas, plan) } impl OptimizerRule for CommonSubexprEliminate { - fn try_optimize( + fn supports_rewrite(&self) -> bool { + true + } + + fn apply_order(&self) -> Option { + // This rule handles recursion itself in a `ApplyOrder::TopDown` like manner. + // This is because in some cases adjacent nodes are collected (e.g. `Window`) and + // CSEd as a group, which can't be done in a simple `ApplyOrder::TopDown` rule. + None + } + + fn rewrite( &self, - plan: &LogicalPlan, + plan: LogicalPlan, config: &dyn OptimizerConfig, - ) -> Result> { + ) -> Result> { + let original_schema = Arc::clone(plan.schema()); + let optimized_plan = match plan { - LogicalPlan::Projection(_) - | LogicalPlan::Sort(_) - | LogicalPlan::Filter(_) => Some(self.try_unary_plan(plan, config)?), - LogicalPlan::Window(window) => { - Some(self.try_optimize_window(window, config)?) - } - LogicalPlan::Aggregate(aggregate) => { - Some(self.try_optimize_aggregate(aggregate, config)?) - } + LogicalPlan::Projection(proj) => self.try_optimize_proj(proj, config)?, + LogicalPlan::Sort(sort) => self.try_optimize_sort(sort, config)?, + LogicalPlan::Filter(filter) => self.try_optimize_filter(filter, config)?, + LogicalPlan::Window(window) => self.try_optimize_window(window, config)?, + LogicalPlan::Aggregate(agg) => self.try_optimize_aggregate(agg, config)?, LogicalPlan::Join(_) - | LogicalPlan::CrossJoin(_) | LogicalPlan::Repartition(_) | LogicalPlan::Union(_) | LogicalPlan::TableScan(_) @@ -428,22 +553,22 @@ impl OptimizerRule for CommonSubexprEliminate { | LogicalPlan::Copy(_) | LogicalPlan::Unnest(_) | LogicalPlan::RecursiveQuery(_) - | LogicalPlan::Prepare(_) => { - // apply the optimization to all inputs of the plan - utils::optimize_children(self, plan, config)? + | LogicalPlan::Prepare(_) + | LogicalPlan::Execute(_) => { + // This rule handles recursion itself in a `ApplyOrder::TopDown` like + // manner. + plan.map_children(|c| self.rewrite(c, config))? } }; - let original_schema = plan.schema().clone(); - match optimized_plan { - Some(optimized_plan) if optimized_plan.schema() != &original_schema => { - // add an additional projection if the output schema changed. - Ok(Some(build_recover_project_plan( - &original_schema, - optimized_plan, - )?)) - } - plan => Ok(plan), + // If we rewrote the plan, ensure the schema stays the same + if optimized_plan.transformed && optimized_plan.data.schema() != &original_schema + { + optimized_plan.map_data(|optimized_plan| { + build_recover_project_plan(&original_schema, optimized_plan) + }) + } else { + Ok(optimized_plan) } } @@ -452,23 +577,144 @@ impl OptimizerRule for CommonSubexprEliminate { } } -impl Default for CommonSubexprEliminate { - fn default() -> Self { - Self::new() +/// Which type of [expressions](Expr) should be considered for rewriting? +#[derive(Debug, Clone, Copy)] +enum ExprMask { + /// Ignores: + /// + /// - [`Literal`](Expr::Literal) + /// - [`Columns`](Expr::Column) + /// - [`ScalarVariable`](Expr::ScalarVariable) + /// - [`Alias`](Expr::Alias) + /// - [`Wildcard`](Expr::Wildcard) + /// - [`AggregateFunction`](Expr::AggregateFunction) + Normal, + + /// Like [`Normal`](Self::Normal), but includes [`AggregateFunction`](Expr::AggregateFunction). + NormalAndAggregates, +} + +struct ExprCSEController<'a> { + alias_generator: &'a AliasGenerator, + mask: ExprMask, + + // how many aliases have we seen so far + alias_counter: usize, +} + +impl<'a> ExprCSEController<'a> { + fn new(alias_generator: &'a AliasGenerator, mask: ExprMask) -> Self { + Self { + alias_generator, + mask, + alias_counter: 0, + } } } -impl CommonSubexprEliminate { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} +impl CSEController for ExprCSEController<'_> { + type Node = Expr; + + fn conditional_children(node: &Expr) -> Option<(Vec<&Expr>, Vec<&Expr>)> { + match node { + // In case of `ScalarFunction`s we don't know which children are surely + // executed so start visiting all children conditionally and stop the + // recursion with `TreeNodeRecursion::Jump`. + Expr::ScalarFunction(ScalarFunction { func, args }) + if func.short_circuits() => + { + Some((vec![], args.iter().collect())) + } + + // In case of `And` and `Or` the first child is surely executed, but we + // account subexpressions as conditional in the second. + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::And | Operator::Or, + right, + }) => Some((vec![left.as_ref()], vec![right.as_ref()])), + + // In case of `Case` the optional base expression and the first when + // expressions are surely executed, but we account subexpressions as + // conditional in the others. + Expr::Case(Case { + expr, + when_then_expr, + else_expr, + }) => Some(( + expr.iter() + .map(|e| e.as_ref()) + .chain(when_then_expr.iter().take(1).map(|(when, _)| when.as_ref())) + .collect(), + when_then_expr + .iter() + .take(1) + .map(|(_, then)| then.as_ref()) + .chain( + when_then_expr + .iter() + .skip(1) + .flat_map(|(when, then)| [when.as_ref(), then.as_ref()]), + ) + .chain(else_expr.iter().map(|e| e.as_ref())) + .collect(), + )), + _ => None, + } + } + + fn is_valid(node: &Expr) -> bool { + !node.is_volatile_node() + } + + fn is_ignored(&self, node: &Expr) -> bool { + let is_normal_minus_aggregates = matches!( + node, + Expr::Literal(..) + | Expr::Column(..) + | Expr::ScalarVariable(..) + | Expr::Alias(..) + | Expr::Wildcard { .. } + ); + + let is_aggr = matches!(node, Expr::AggregateFunction(..)); + + match self.mask { + ExprMask::Normal => is_normal_minus_aggregates || is_aggr, + ExprMask::NormalAndAggregates => is_normal_minus_aggregates, + } + } + + fn generate_alias(&self) -> String { + self.alias_generator.next(CSE_PREFIX) + } + + fn rewrite(&mut self, node: &Self::Node, alias: &str) -> Self::Node { + // alias the expressions without an `Alias` ancestor node + if self.alias_counter > 0 { + col(alias) + } else { + self.alias_counter += 1; + col(alias).alias(node.schema_name().to_string()) + } + } + + fn rewrite_f_down(&mut self, node: &Expr) { + if matches!(node, Expr::Alias(_)) { + self.alias_counter += 1; + } + } + fn rewrite_f_up(&mut self, node: &Expr) { + if matches!(node, Expr::Alias(_)) { + self.alias_counter -= 1 + } } } -fn pop_expr(new_expr: &mut Vec>) -> Result> { - new_expr - .pop() - .ok_or_else(|| DataFusionError::Internal("Failed to pop expression".to_string())) +impl Default for CommonSubexprEliminate { + fn default() -> Self { + Self::new() + } } /// Build the "intermediate" projection plan that evaluates the extracted common @@ -477,31 +723,22 @@ fn pop_expr(new_expr: &mut Vec>) -> Result> { /// # Arguments /// input: the input plan /// -/// affected_id: which common subexpressions were used (and thus are added to +/// common_exprs: which common subexpressions were used (and thus are added to /// intermediate projection) /// -/// expr_set: the set of common subexpressions +/// expr_stats: the set of common subexpressions fn build_common_expr_project_plan( input: LogicalPlan, - affected_id: BTreeSet, - expr_set: &ExprSet, + common_exprs: Vec<(Expr, String)>, ) -> Result { - let mut project_exprs = vec![]; let mut fields_set = BTreeSet::new(); - - for id in affected_id { - match expr_set.get(&id) { - Some((expr, _, data_type, symbol)) => { - // todo: check `nullable` - let field = Field::new(&id, data_type.clone(), true); - fields_set.insert(field.name().to_owned()); - project_exprs.push(expr.clone().alias(symbol.as_str())); - } - _ => { - return internal_err!("expr_set invalid state"); - } - } - } + let mut project_exprs = common_exprs + .into_iter() + .map(|(expr, expr_alias)| { + fields_set.insert(expr_alias.clone()); + Ok(expr.alias(expr_alias)) + }) + .collect::>>()?; for (qualifier, field) in input.schema().iter() { if fields_set.insert(qualified_name(qualifier, field.name())) { @@ -509,10 +746,7 @@ fn build_common_expr_project_plan( } } - Ok(LogicalPlan::Projection(Projection::try_new( - project_exprs, - Arc::new(input), - )?)) + Projection::try_new(project_exprs, Arc::new(input)).map(LogicalPlan::Projection) } /// Build the projection plan to eliminate unnecessary columns produced by @@ -525,267 +759,53 @@ fn build_recover_project_plan( input: LogicalPlan, ) -> Result { let col_exprs = schema.iter().map(Expr::from).collect(); - Ok(LogicalPlan::Projection(Projection::try_new( - col_exprs, - Arc::new(input), - )?)) + Projection::try_new(col_exprs, Arc::new(input)).map(LogicalPlan::Projection) } -fn extract_expressions( - expr: &Expr, - schema: &DFSchema, - result: &mut Vec, -) -> Result<()> { +fn extract_expressions(expr: &Expr, result: &mut Vec) { if let Expr::GroupingSet(groupings) = expr { for e in groupings.distinct_expr() { - let (qualifier, field) = e.to_field(schema)?; - let col = Column::new(qualifier, field.name()); + let (qualifier, field_name) = e.qualified_name(); + let col = Column::new(qualifier, field_name); result.push(Expr::Column(col)) } } else { - let (qualifier, field) = expr.to_field(schema)?; - let col = Column::new(qualifier, field.name()); + let (qualifier, field_name) = expr.qualified_name(); + let col = Column::new(qualifier, field_name); result.push(Expr::Column(col)); } - - Ok(()) -} - -/// Which type of [expressions](Expr) should be considered for rewriting? -#[derive(Debug, Clone, Copy)] -enum ExprMask { - /// Ignores: - /// - /// - [`Literal`](Expr::Literal) - /// - [`Columns`](Expr::Column) - /// - [`ScalarVariable`](Expr::ScalarVariable) - /// - [`Alias`](Expr::Alias) - /// - [`Sort`](Expr::Sort) - /// - [`Wildcard`](Expr::Wildcard) - /// - [`AggregateFunction`](Expr::AggregateFunction) - Normal, - - /// Like [`Normal`](Self::Normal), but includes [`AggregateFunction`](Expr::AggregateFunction). - NormalAndAggregates, -} - -impl ExprMask { - fn ignores(&self, expr: &Expr) -> bool { - let is_normal_minus_aggregates = matches!( - expr, - Expr::Literal(..) - | Expr::Column(..) - | Expr::ScalarVariable(..) - | Expr::Alias(..) - | Expr::Sort { .. } - | Expr::Wildcard { .. } - ); - - let is_aggr = matches!(expr, Expr::AggregateFunction(..)); - - match self { - Self::Normal => is_normal_minus_aggregates || is_aggr, - Self::NormalAndAggregates => is_normal_minus_aggregates, - } - } -} - -/// Go through an expression tree and generate identifiers for each subexpression. -/// -/// An identifier contains information of the expression itself and its sub-expression. -/// This visitor implementation use a stack `visit_stack` to track traversal, which -/// lets us know when a sub-tree's visiting is finished. When `pre_visit` is called -/// (traversing to a new node), an `EnterMark` and an `ExprItem` will be pushed into stack. -/// And try to pop out a `EnterMark` on leaving a node (`f_up()`). All `ExprItem` -/// before the first `EnterMark` is considered to be sub-tree of the leaving node. -/// -/// This visitor also records identifier in `id_array`. Makes the following traverse -/// pass can get the identifier of a node without recalculate it. We assign each node -/// in the expr tree a series number, start from 1, maintained by `series_number`. -/// Series number represents the order we left (`f_up()`) a node. Has the property -/// that child node's series number always smaller than parent's. While `id_array` is -/// organized in the order we enter (`f_down()`) a node. `node_count` helps us to -/// get the index of `id_array` for each node. -/// -/// `Expr` without sub-expr (column, literal etc.) will not have identifier -/// because they should not be recognized as common sub-expr. -struct ExprIdentifierVisitor<'a> { - // param - expr_set: &'a mut ExprSet, - /// input schema for the node that we're optimizing, so we can determine the correct datatype - /// for each subexpression - input_schema: DFSchemaRef, - // inner states - visit_stack: Vec, - /// increased in fn_down, start from 0. - node_count: usize, - /// which expression should be skipped? - expr_mask: ExprMask, -} - -/// Record item that used when traversing a expression tree. -enum VisitRecord { - /// `usize` is the monotone increasing series number assigned in pre_visit(). - /// Starts from 0. Is used to index the identifier array `id_array` in post_visit(). - EnterMark(usize), - /// the node's children were skipped => jump to f_up on same node - JumpMark(usize), - /// Accumulated identifier of sub expression. - ExprItem(Identifier), -} - -impl ExprIdentifierVisitor<'_> { - /// Find the first `EnterMark` in the stack, and accumulates every `ExprItem` - /// before it. - fn pop_enter_mark(&mut self) -> (usize, Identifier) { - let mut desc = String::new(); - - while let Some(item) = self.visit_stack.pop() { - match item { - VisitRecord::EnterMark(idx) | VisitRecord::JumpMark(idx) => { - return (idx, desc); - } - VisitRecord::ExprItem(id) => { - desc.push_str(&id); - } - } - } - unreachable!("Enter mark should paired with node number"); - } -} - -impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { - type Node = Expr; - - fn f_down(&mut self, expr: &Expr) -> Result { - // related to https://github.com/apache/datafusion/issues/8814 - // If the expr contain volatile expression or is a short-circuit expression, skip it. - if expr.short_circuits() || expr.is_volatile()? { - self.visit_stack - .push(VisitRecord::JumpMark(self.node_count)); - return Ok(TreeNodeRecursion::Jump); // go to f_up - } - - self.visit_stack - .push(VisitRecord::EnterMark(self.node_count)); - self.node_count += 1; - - Ok(TreeNodeRecursion::Continue) - } - - fn f_up(&mut self, expr: &Expr) -> Result { - let (_idx, sub_expr_identifier) = self.pop_enter_mark(); - - // skip exprs should not be recognize. - if self.expr_mask.ignores(expr) { - let curr_expr_identifier = ExprSet::expr_identifier(expr); - self.visit_stack - .push(VisitRecord::ExprItem(curr_expr_identifier)); - return Ok(TreeNodeRecursion::Continue); - } - let curr_expr_identifier = ExprSet::expr_identifier(expr); - let alias_symbol = format!("{curr_expr_identifier}{sub_expr_identifier}"); - - self.visit_stack - .push(VisitRecord::ExprItem(alias_symbol.clone())); - - let data_type = expr.get_type(&self.input_schema)?; - - self.expr_set - .entry(curr_expr_identifier) - .or_insert_with(|| (expr.clone(), 0, data_type, alias_symbol)) - .1 += 1; - Ok(TreeNodeRecursion::Continue) - } -} - -/// Rewrite expression by common sub-expression with a corresponding temporary -/// column name that will compute the subexpression. -/// -/// `affected_id` is updated with any sub expressions that were replaced -struct CommonSubexprRewriter<'a> { - expr_set: &'a ExprSet, - /// Which identifier is replaced. - affected_id: &'a mut BTreeSet, -} - -impl TreeNodeRewriter for CommonSubexprRewriter<'_> { - type Node = Expr; - - fn f_down(&mut self, expr: Expr) -> Result> { - // The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate - // the `id_array`, which records the expr's identifier used to rewrite expr. So if we - // skip an expr in `ExprIdentifierVisitor`, we should skip it here, too. - if expr.short_circuits() || expr.is_volatile()? { - return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)); - } - - let curr_id = &ExprSet::expr_identifier(&expr); - - // lookup previously visited expression - match self.expr_set.get(curr_id) { - Some((_, counter, _, symbol)) => { - // if has a commonly used (a.k.a. 1+ use) expr - if *counter > 1 { - self.affected_id.insert(curr_id.clone()); - - let expr_name = expr.display_name()?; - // Alias this `Column` expr to it original "expr name", - // `projection_push_down` optimizer use "expr name" to eliminate useless - // projections. - Ok(Transformed::new( - col(symbol).alias(expr_name), - true, - TreeNodeRecursion::Jump, - )) - } else { - Ok(Transformed::no(expr)) - } - } - None => Ok(Transformed::no(expr)), - } - } -} - -/// Replace common sub-expression in `expr` with the corresponding temporary -/// column name, updating `affected_id` with any replaced expressions -fn replace_common_expr( - expr: Expr, - expr_set: &ExprSet, - affected_id: &mut BTreeSet, -) -> Result { - expr.rewrite(&mut CommonSubexprRewriter { - expr_set, - affected_id, - }) - .data() } #[cfg(test)] mod test { + use std::any::Any; use std::iter; - use arrow::datatypes::Schema; - + use arrow::datatypes::{DataType, Field, Schema}; use datafusion_expr::logical_plan::{table_scan, JoinType}; - use datafusion_expr::{avg, lit, logical_plan::builder::LogicalPlanBuilder, sum}; use datafusion_expr::{ - grouping_set, AccumulatorFactoryFunction, AggregateUDF, Signature, - SimpleAggregateUDF, Volatility, + grouping_set, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF, + ScalarUDFImpl, Signature, SimpleAggregateUDF, Volatility, }; + use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; + use super::*; use crate::optimizer::OptimizerContext; use crate::test::*; - - use super::*; - - fn assert_optimized_plan_eq(expected: &str, plan: &LogicalPlan) { - let optimizer = CommonSubexprEliminate {}; - let optimized_plan = optimizer - .try_optimize(plan, &OptimizerContext::new()) - .unwrap() - .expect("failed to optimize plan"); - let formatted_plan = format!("{optimized_plan:?}"); + use crate::Optimizer; + use datafusion_expr::test::function_stub::{avg, sum}; + + fn assert_optimized_plan_eq( + expected: &str, + plan: LogicalPlan, + config: Option<&dyn OptimizerConfig>, + ) { + let optimizer = + Optimizer::with_rules(vec![Arc::new(CommonSubexprEliminate::new())]); + let default_config = OptimizerContext::new(); + let config = config.unwrap_or(&default_config); + let optimized_plan = optimizer.optimize(plan, config, |_, _| ()).unwrap(); + let formatted_plan = format!("{optimized_plan}"); assert_eq!(expected, formatted_plan); } @@ -811,11 +831,31 @@ mod test { )? .build()?; - let expected = "Aggregate: groupBy=[[]], aggr=[[SUM(test.a * (Int32(1) - test.b)Int32(1) - test.btest.bInt32(1)test.a AS test.a * Int32(1) - test.b), SUM(test.a * (Int32(1) - test.b)Int32(1) - test.btest.bInt32(1)test.a AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]]\ - \n Projection: test.a * (Int32(1) - test.b) AS test.a * (Int32(1) - test.b)Int32(1) - test.btest.bInt32(1)test.a, test.a, test.b, test.c\ + let expected = "Aggregate: groupBy=[[]], aggr=[[sum(__common_expr_1 AS test.a * Int32(1) - test.b), sum(__common_expr_1 AS test.a * Int32(1) - test.b * (Int32(1) + test.c))]]\ + \n Projection: test.a * (Int32(1) - test.b) AS __common_expr_1, test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[test] + fn nested_aliases() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![ + (col("a") + col("b") - col("c")).alias("alias1") * (col("a") + col("b")), + col("a") + col("b"), + ])? + .build()?; + + let expected = "Projection: __common_expr_1 - test.c AS alias1 * __common_expr_1 AS test.a + test.b, __common_expr_1 AS test.a + test.b\ + \n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + + assert_optimized_plan_eq(expected, plan, None); Ok(()) } @@ -832,7 +872,7 @@ mod test { "my_agg", Signature::exact(vec![DataType::UInt32], Volatility::Stable), return_type.clone(), - accumulator.clone(), + Arc::clone(&accumulator), vec![Field::new("value", DataType::UInt32, true)], ))), vec![inner], @@ -864,11 +904,11 @@ mod test { )? .build()?; - let expected = "Projection: AVG(test.a)test.a AS AVG(test.a) AS col1, AVG(test.a)test.a AS AVG(test.a) AS col2, col3, AVG(test.c) AS AVG(test.c), my_agg(test.a)test.a AS my_agg(test.a) AS col4, my_agg(test.a)test.a AS my_agg(test.a) AS col5, col6, my_agg(test.c) AS my_agg(test.c)\ - \n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS AVG(test.a)test.a, my_agg(test.a) AS my_agg(test.a)test.a, AVG(test.b) AS col3, AVG(test.c) AS AVG(test.c), my_agg(test.b) AS col6, my_agg(test.c) AS my_agg(test.c)]]\ + let expected = "Projection: __common_expr_1 AS col1, __common_expr_1 AS col2, col3, __common_expr_3 AS avg(test.c), __common_expr_2 AS col4, __common_expr_2 AS col5, col6, __common_expr_4 AS my_agg(test.c)\ + \n Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2, avg(test.b) AS col3, avg(test.c) AS __common_expr_3, my_agg(test.b) AS col6, my_agg(test.c) AS __common_expr_4]]\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan, None); // test: trafo after aggregate let plan = LogicalPlanBuilder::from(table_scan.clone()) @@ -883,11 +923,11 @@ mod test { )? .build()?; - let expected = "Projection: Int32(1) + AVG(test.a)test.a AS AVG(test.a), Int32(1) - AVG(test.a)test.a AS AVG(test.a), Int32(1) + my_agg(test.a)test.a AS my_agg(test.a), Int32(1) - my_agg(test.a)test.a AS my_agg(test.a)\ - \n Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS AVG(test.a)test.a, my_agg(test.a) AS my_agg(test.a)test.a]]\ + let expected = "Projection: Int32(1) + __common_expr_1 AS avg(test.a), Int32(1) - __common_expr_1 AS avg(test.a), Int32(1) + __common_expr_2 AS my_agg(test.a), Int32(1) - __common_expr_2 AS my_agg(test.a)\ + \n Aggregate: groupBy=[[]], aggr=[[avg(test.a) AS __common_expr_1, my_agg(test.a) AS __common_expr_2]]\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan, None); // test: transformation before aggregate let plan = LogicalPlanBuilder::from(table_scan.clone()) @@ -900,11 +940,11 @@ mod test { )? .build()?; - let expected = "Aggregate: groupBy=[[]], aggr=[[AVG(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a) AS col1, my_agg(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a) AS col2]]\ - \n Projection: UInt32(1) + test.a AS UInt32(1) + test.atest.aUInt32(1), test.a, test.b, test.c\ + let expected ="Aggregate: groupBy=[[]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]\ + \n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan, None); // test: common between agg and group let plan = LogicalPlanBuilder::from(table_scan.clone()) @@ -917,11 +957,11 @@ mod test { )? .build()?; - let expected = "Aggregate: groupBy=[[UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a]], aggr=[[AVG(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a) AS col1, my_agg(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a) AS col2]]\ - \n Projection: UInt32(1) + test.a AS UInt32(1) + test.atest.aUInt32(1), test.a, test.b, test.c\ + let expected = "Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS col1, my_agg(__common_expr_1) AS col2]]\ + \n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan, None); // test: all mixed let plan = LogicalPlanBuilder::from(table_scan) @@ -938,18 +978,18 @@ mod test { )? .build()?; - let expected = "Projection: UInt32(1) + test.a, UInt32(1) + AVG(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a)UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a AS AVG(UInt32(1) + test.a) AS col1, UInt32(1) - AVG(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a)UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a AS AVG(UInt32(1) + test.a) AS col2, AVG(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a)UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a AS AVG(UInt32(1) + test.a), UInt32(1) + my_agg(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a)UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a AS my_agg(UInt32(1) + test.a) AS col3, UInt32(1) - my_agg(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a)UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a AS my_agg(UInt32(1) + test.a) AS col4, my_agg(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a)UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a AS my_agg(UInt32(1) + test.a)\ - \n Aggregate: groupBy=[[UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a]], aggr=[[AVG(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a) AS AVG(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a)UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a, my_agg(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a) AS my_agg(UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a)UInt32(1) + test.atest.aUInt32(1) AS UInt32(1) + test.a]]\ - \n Projection: UInt32(1) + test.a AS UInt32(1) + test.atest.aUInt32(1), test.a, test.b, test.c\ + let expected = "Projection: UInt32(1) + test.a, UInt32(1) + __common_expr_2 AS col1, UInt32(1) - __common_expr_2 AS col2, __common_expr_4 AS avg(UInt32(1) + test.a), UInt32(1) + __common_expr_3 AS col3, UInt32(1) - __common_expr_3 AS col4, __common_expr_5 AS my_agg(UInt32(1) + test.a)\ + \n Aggregate: groupBy=[[__common_expr_1 AS UInt32(1) + test.a]], aggr=[[avg(__common_expr_1) AS __common_expr_2, my_agg(__common_expr_1) AS __common_expr_3, avg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_4, my_agg(__common_expr_1 AS UInt32(1) + test.a) AS __common_expr_5]]\ + \n Projection: UInt32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan, None); Ok(()) } #[test] - fn aggregate_with_releations_and_dots() -> Result<()> { + fn aggregate_with_relations_and_dots() -> Result<()> { let schema = Schema::new(vec![Field::new("col.a", DataType::UInt32, false)]); let table_scan = table_scan(Some("table.test"), &schema, None)?.build()?; @@ -965,12 +1005,12 @@ mod test { )? .build()?; - let expected = "Projection: table.test.col.a, UInt32(1) + AVG(UInt32(1) + table.test.col.atable.test.col.aUInt32(1) AS UInt32(1) + table.test.col.a)UInt32(1) + table.test.col.atable.test.col.aUInt32(1) AS UInt32(1) + table.test.col.a AS AVG(UInt32(1) + table.test.col.a), AVG(UInt32(1) + table.test.col.atable.test.col.aUInt32(1) AS UInt32(1) + table.test.col.a)UInt32(1) + table.test.col.atable.test.col.aUInt32(1) AS UInt32(1) + table.test.col.a AS AVG(UInt32(1) + table.test.col.a)\ - \n Aggregate: groupBy=[[table.test.col.a]], aggr=[[AVG(UInt32(1) + table.test.col.atable.test.col.aUInt32(1) AS UInt32(1) + table.test.col.a) AS AVG(UInt32(1) + table.test.col.atable.test.col.aUInt32(1) AS UInt32(1) + table.test.col.a)UInt32(1) + table.test.col.atable.test.col.aUInt32(1) AS UInt32(1) + table.test.col.a]]\ - \n Projection: UInt32(1) + table.test.col.a AS UInt32(1) + table.test.col.atable.test.col.aUInt32(1), table.test.col.a\ + let expected = "Projection: table.test.col.a, UInt32(1) + __common_expr_2 AS avg(UInt32(1) + table.test.col.a), __common_expr_2 AS avg(UInt32(1) + table.test.col.a)\ + \n Aggregate: groupBy=[[table.test.col.a]], aggr=[[avg(__common_expr_1 AS UInt32(1) + table.test.col.a) AS __common_expr_2]]\ + \n Projection: UInt32(1) + table.test.col.a AS __common_expr_1, table.test.col.a\ \n TableScan: table.test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan, None); Ok(()) } @@ -986,11 +1026,11 @@ mod test { ])? .build()?; - let expected = "Projection: Int32(1) + test.atest.aInt32(1) AS Int32(1) + test.a AS first, Int32(1) + test.atest.aInt32(1) AS Int32(1) + test.a AS second\ - \n Projection: Int32(1) + test.a AS Int32(1) + test.atest.aInt32(1), test.a, test.b, test.c\ + let expected = "Projection: __common_expr_1 AS first, __common_expr_1 AS second\ + \n Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan, None); Ok(()) } @@ -1006,7 +1046,7 @@ mod test { let expected = "Projection: Int32(1) + test.a, test.a + Int32(1)\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan, None); Ok(()) } @@ -1016,50 +1056,35 @@ mod test { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![lit(1) + col("a")])? + .project(vec![lit(1) + col("a"), col("a")])? .project(vec![lit(1) + col("a")])? .build()?; let expected = "Projection: Int32(1) + test.a\ - \n Projection: Int32(1) + test.a\ + \n Projection: Int32(1) + test.a, test.a\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan, None); Ok(()) } #[test] fn redundant_project_fields() { let table_scan = test_table_scan().unwrap(); - let affected_id: BTreeSet = - ["c+a".to_string(), "b+a".to_string()].into_iter().collect(); - let expr_set_1 = vec![ - ( - "c+a".to_string(), - (col("c") + col("a"), 1, DataType::UInt32, "c+a".to_string()), - ), - ( - "b+a".to_string(), - (col("b") + col("a"), 1, DataType::UInt32, "b+a".to_string()), - ), - ] - .into(); - let expr_set_2 = vec![ - ( - "c+a".to_string(), - (col("c+a"), 1, DataType::UInt32, "c+a".to_string()), - ), - ( - "b+a".to_string(), - (col("b+a"), 1, DataType::UInt32, "b+a".to_string()), - ), - ] - .into(); - let project = - build_common_expr_project_plan(table_scan, affected_id.clone(), &expr_set_1) - .unwrap(); - let project_2 = - build_common_expr_project_plan(project, affected_id, &expr_set_2).unwrap(); + let c_plus_a = col("c") + col("a"); + let b_plus_a = col("b") + col("a"); + let common_exprs_1 = vec![ + (c_plus_a, format!("{CSE_PREFIX}_1")), + (b_plus_a, format!("{CSE_PREFIX}_2")), + ]; + let c_plus_a_2 = col(format!("{CSE_PREFIX}_1")); + let b_plus_a_2 = col(format!("{CSE_PREFIX}_2")); + let common_exprs_2 = vec![ + (c_plus_a_2, format!("{CSE_PREFIX}_3")), + (b_plus_a_2, format!("{CSE_PREFIX}_4")), + ]; + let project = build_common_expr_project_plan(table_scan, common_exprs_1).unwrap(); + let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap(); let mut field_set = BTreeSet::new(); for name in project_2.schema().field_names() { @@ -1076,57 +1101,20 @@ mod test { .unwrap() .build() .unwrap(); - let affected_id: BTreeSet = - ["test1.c+test1.a".to_string(), "test1.b+test1.a".to_string()] - .into_iter() - .collect(); - let expr_set_1 = vec![ - ( - "test1.c+test1.a".to_string(), - ( - col("test1.c") + col("test1.a"), - 1, - DataType::UInt32, - "test1.c+test1.a".to_string(), - ), - ), - ( - "test1.b+test1.a".to_string(), - ( - col("test1.b") + col("test1.a"), - 1, - DataType::UInt32, - "test1.b+test1.a".to_string(), - ), - ), - ] - .into(); - let expr_set_2 = vec![ - ( - "test1.c+test1.a".to_string(), - ( - col("test1.c+test1.a"), - 1, - DataType::UInt32, - "test1.c+test1.a".to_string(), - ), - ), - ( - "test1.b+test1.a".to_string(), - ( - col("test1.b+test1.a"), - 1, - DataType::UInt32, - "test1.b+test1.a".to_string(), - ), - ), - ] - .into(); - let project = - build_common_expr_project_plan(join, affected_id.clone(), &expr_set_1) - .unwrap(); - let project_2 = - build_common_expr_project_plan(project, affected_id, &expr_set_2).unwrap(); + let c_plus_a = col("test1.c") + col("test1.a"); + let b_plus_a = col("test1.b") + col("test1.a"); + let common_exprs_1 = vec![ + (c_plus_a, format!("{CSE_PREFIX}_1")), + (b_plus_a, format!("{CSE_PREFIX}_2")), + ]; + let c_plus_a_2 = col(format!("{CSE_PREFIX}_1")); + let b_plus_a_2 = col(format!("{CSE_PREFIX}_2")); + let common_exprs_2 = vec![ + (c_plus_a_2, format!("{CSE_PREFIX}_3")), + (b_plus_a_2, format!("{CSE_PREFIX}_4")), + ]; + let project = build_common_expr_project_plan(join, common_exprs_1).unwrap(); + let project_2 = build_common_expr_project_plan(project, common_exprs_2).unwrap(); let mut field_set = BTreeSet::new(); for name in project_2.schema().field_names() { @@ -1154,11 +1142,10 @@ mod test { .unwrap() .build() .unwrap(); - let rule = CommonSubexprEliminate {}; - let optimized_plan = rule - .try_optimize(&plan, &OptimizerContext::new()) - .unwrap() - .unwrap(); + let rule = CommonSubexprEliminate::new(); + let optimized_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap(); + assert!(optimized_plan.transformed); + let optimized_plan = optimized_plan.data; let schema = optimized_plan.schema(); let fields_with_datatypes: Vec<_> = schema @@ -1193,11 +1180,11 @@ mod test { .build()?; let expected = "Projection: test.a, test.b, test.c\ - \n Filter: Int32(1) + test.atest.aInt32(1) - Int32(10) > Int32(1) + test.atest.aInt32(1)\ - \n Projection: Int32(1) + test.a AS Int32(1) + test.atest.aInt32(1), test.a, test.b, test.c\ + \n Filter: __common_expr_1 - Int32(10) > __common_expr_1\ + \n Projection: Int32(1) + test.a AS __common_expr_1, test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(expected, &plan); + assert_optimized_plan_eq(expected, plan, None); Ok(()) } @@ -1206,16 +1193,7 @@ mod test { fn test_extract_expressions_from_grouping_set() -> Result<()> { let mut result = Vec::with_capacity(3); let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("c")]]); - let schema = DFSchema::from_unqualifed_fields( - vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - Field::new("c", DataType::Int32, false), - ] - .into(), - HashMap::default(), - )?; - extract_expressions(&grouping, &schema, &mut result)?; + extract_expressions(&grouping, &mut result); assert!(result.len() == 3); Ok(()) @@ -1225,30 +1203,244 @@ mod test { fn test_extract_expressions_from_grouping_set_with_identical_expr() -> Result<()> { let mut result = Vec::with_capacity(2); let grouping = grouping_set(vec![vec![col("a"), col("b")], vec![col("a")]]); - let schema = DFSchema::from_unqualifed_fields( - vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Int32, false), - ] - .into(), - HashMap::default(), - )?; - extract_expressions(&grouping, &schema, &mut result)?; - + extract_expressions(&grouping, &mut result); assert!(result.len() == 2); Ok(()) } + #[test] + fn test_alias_collision() -> Result<()> { + let table_scan = test_table_scan()?; + + let config = &OptimizerContext::new(); + let common_expr_1 = config.alias_generator().next(CSE_PREFIX); + let plan = LogicalPlanBuilder::from(table_scan.clone()) + .project(vec![ + (col("a") + col("b")).alias(common_expr_1.clone()), + col("c"), + ])? + .project(vec![ + col(common_expr_1.clone()).alias("c1"), + col(common_expr_1).alias("c2"), + (col("c") + lit(2)).alias("c3"), + (col("c") + lit(2)).alias("c4"), + ])? + .build()?; + + let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 AS c3, __common_expr_2 AS c4\ + \n Projection: test.c + Int32(2) AS __common_expr_2, __common_expr_1, test.c\ + \n Projection: test.a + test.b AS __common_expr_1, test.c\ + \n TableScan: test"; + + assert_optimized_plan_eq(expected, plan, Some(config)); + + let config = &OptimizerContext::new(); + let _common_expr_1 = config.alias_generator().next(CSE_PREFIX); + let common_expr_2 = config.alias_generator().next(CSE_PREFIX); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![ + (col("a") + col("b")).alias(common_expr_2.clone()), + col("c"), + ])? + .project(vec![ + col(common_expr_2.clone()).alias("c1"), + col(common_expr_2).alias("c2"), + (col("c") + lit(2)).alias("c3"), + (col("c") + lit(2)).alias("c4"), + ])? + .build()?; + + let expected = "Projection: __common_expr_2 AS c1, __common_expr_2 AS c2, __common_expr_3 AS c3, __common_expr_3 AS c4\ + \n Projection: test.c + Int32(2) AS __common_expr_3, __common_expr_2, test.c\ + \n Projection: test.a + test.b AS __common_expr_2, test.c\ + \n TableScan: test"; + + assert_optimized_plan_eq(expected, plan, Some(config)); + + Ok(()) + } + #[test] fn test_extract_expressions_from_col() -> Result<()> { let mut result = Vec::with_capacity(1); - let schema = DFSchema::from_unqualifed_fields( - vec![Field::new("a", DataType::Int32, false)].into(), - HashMap::default(), - )?; - extract_expressions(&col("a"), &schema, &mut result)?; - + extract_expressions(&col("a"), &mut result); assert!(result.len() == 1); Ok(()) } + + #[test] + fn test_short_circuits() -> Result<()> { + let table_scan = test_table_scan()?; + + let extracted_short_circuit = col("a").eq(lit(0)).or(col("b").eq(lit(0))); + let extracted_short_circuit_leg_1 = (col("a") + col("b")).eq(lit(0)); + let not_extracted_short_circuit_leg_2 = (col("a") - col("b")).eq(lit(0)); + let extracted_short_circuit_leg_3 = (col("a") * col("b")).eq(lit(0)); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![ + extracted_short_circuit.clone().alias("c1"), + extracted_short_circuit.alias("c2"), + extracted_short_circuit_leg_1 + .clone() + .or(not_extracted_short_circuit_leg_2.clone()) + .alias("c3"), + extracted_short_circuit_leg_1 + .and(not_extracted_short_circuit_leg_2) + .alias("c4"), + extracted_short_circuit_leg_3 + .clone() + .or(extracted_short_circuit_leg_3) + .alias("c5"), + ])? + .build()?; + + let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 OR test.a - test.b = Int32(0) AS c3, __common_expr_2 AND test.a - test.b = Int32(0) AS c4, __common_expr_3 OR __common_expr_3 AS c5\ + \n Projection: test.a = Int32(0) OR test.b = Int32(0) AS __common_expr_1, test.a + test.b = Int32(0) AS __common_expr_2, test.a * test.b = Int32(0) AS __common_expr_3, test.a, test.b, test.c\ + \n TableScan: test"; + + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[test] + fn test_volatile() -> Result<()> { + let table_scan = test_table_scan()?; + + let extracted_child = col("a") + col("b"); + let rand = rand_func().call(vec![]); + let not_extracted_volatile = extracted_child + rand; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![ + not_extracted_volatile.clone().alias("c1"), + not_extracted_volatile.alias("c2"), + ])? + .build()?; + + let expected = "Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2\ + \n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[test] + fn test_volatile_short_circuits() -> Result<()> { + let table_scan = test_table_scan()?; + + let rand = rand_func().call(vec![]); + let extracted_short_circuit_leg_1 = col("a").eq(lit(0)); + let not_extracted_volatile_short_circuit_1 = + extracted_short_circuit_leg_1.or(rand.clone().eq(lit(0))); + let not_extracted_short_circuit_leg_2 = col("b").eq(lit(0)); + let not_extracted_volatile_short_circuit_2 = + rand.eq(lit(0)).or(not_extracted_short_circuit_leg_2); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![ + not_extracted_volatile_short_circuit_1.clone().alias("c1"), + not_extracted_volatile_short_circuit_1.alias("c2"), + not_extracted_volatile_short_circuit_2.clone().alias("c3"), + not_extracted_volatile_short_circuit_2.alias("c4"), + ])? + .build()?; + + let expected = "Projection: __common_expr_1 OR random() = Int32(0) AS c1, __common_expr_1 OR random() = Int32(0) AS c2, random() = Int32(0) OR test.b = Int32(0) AS c3, random() = Int32(0) OR test.b = Int32(0) AS c4\ + \n Projection: test.a = Int32(0) AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[test] + fn test_non_top_level_common_expression() -> Result<()> { + let table_scan = test_table_scan()?; + + let common_expr = col("a") + col("b"); + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![ + common_expr.clone().alias("c1"), + common_expr.alias("c2"), + ])? + .project(vec![col("c1"), col("c2")])? + .build()?; + + let expected = "Projection: c1, c2\ + \n Projection: __common_expr_1 AS c1, __common_expr_1 AS c2\ + \n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + #[test] + fn test_nested_common_expression() -> Result<()> { + let table_scan = test_table_scan()?; + + let nested_common_expr = col("a") + col("b"); + let common_expr = nested_common_expr.clone() * nested_common_expr; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![ + common_expr.clone().alias("c1"), + common_expr.alias("c2"), + ])? + .build()?; + + let expected = "Projection: __common_expr_1 AS c1, __common_expr_1 AS c2\ + \n Projection: __common_expr_2 * __common_expr_2 AS __common_expr_1, test.a, test.b, test.c\ + \n Projection: test.a + test.b AS __common_expr_2, test.a, test.b, test.c\ + \n TableScan: test"; + + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } + + /// returns a "random" function that is marked volatile (aka each invocation + /// returns a different value) + /// + /// Does not use datafusion_functions::rand to avoid introducing a + /// dependency on that crate. + fn rand_func() -> ScalarUDF { + ScalarUDF::new_from_impl(RandomStub::new()) + } + + #[derive(Debug)] + struct RandomStub { + signature: Signature, + } + + impl RandomStub { + fn new() -> Self { + Self { + signature: Signature::exact(vec![], Volatility::Volatile), + } + } + } + impl ScalarUDFImpl for RandomStub { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "random" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + unimplemented!() + } + } } diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index a6abec9efd8c..6aa59b77f7f9 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -19,6 +19,7 @@ use std::collections::{BTreeSet, HashMap}; use std::ops::Deref; +use std::sync::Arc; use crate::simplify_expressions::ExprSimplifier; use crate::utils::collect_subquery_cols; @@ -27,10 +28,13 @@ use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; use datafusion_common::{plan_err, Column, DFSchemaRef, Result, ScalarValue}; -use datafusion_expr::expr::{AggregateFunctionDefinition, Alias}; +use datafusion_expr::expr::Alias; use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::utils::{conjunction, find_join_exprs, split_conjunction}; -use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; +use datafusion_expr::{ + expr, lit, BinaryExpr, Cast, EmptyRelation, Expr, FetchType, LogicalPlan, + LogicalPlanBuilder, Operator, +}; use datafusion_physical_expr::execution_props::ExecutionProps; /// This struct rewrite the sub query plan by pull up the correlated @@ -38,25 +42,75 @@ use datafusion_physical_expr::execution_props::ExecutionProps; /// 'Filter'. It adds the inner reference columns to the 'Projection' or /// 'Aggregate' of the subquery if they are missing, so that they can be /// evaluated by the parent operator as the join condition. +#[derive(Debug)] pub struct PullUpCorrelatedExpr { pub join_filters: Vec, - // mapping from the plan to its holding correlated columns + /// mapping from the plan to its holding correlated columns pub correlated_subquery_cols_map: HashMap>, pub in_predicate_opt: Option, - // indicate whether it is Exists(Not Exists) SubQuery + /// Is this an Exists(Not Exists) SubQuery. Defaults to **FALSE** pub exists_sub_query: bool, - // indicate whether the correlated expressions can pull up or not + /// Can the correlated expressions be pulled up. Defaults to **TRUE** pub can_pull_up: bool, - // indicate whether need to handle the Count bug during the pull up process + /// Indicates if we encounter any correlated expression that can not be pulled up + /// above a aggregation without changing the meaning of the query. + can_pull_over_aggregation: bool, + /// Do we need to handle [the Count bug] during the pull up process + /// + /// [the Count bug]: https://github.com/apache/datafusion/pull/10500 pub need_handle_count_bug: bool, - // mapping from the plan to its expressions' evaluation result on empty batch + /// mapping from the plan to its expressions' evaluation result on empty batch pub collected_count_expr_map: HashMap, - // pull up having expr, which must be evaluated after the Join + /// pull up having expr, which must be evaluated after the Join pub pull_up_having_expr: Option, } +impl Default for PullUpCorrelatedExpr { + fn default() -> Self { + Self::new() + } +} + +impl PullUpCorrelatedExpr { + pub fn new() -> Self { + Self { + join_filters: vec![], + correlated_subquery_cols_map: HashMap::new(), + in_predicate_opt: None, + exists_sub_query: false, + can_pull_up: true, + can_pull_over_aggregation: true, + need_handle_count_bug: false, + collected_count_expr_map: HashMap::new(), + pull_up_having_expr: None, + } + } + + /// Set if we need to handle [the Count bug] during the pull up process + /// + /// [the Count bug]: https://github.com/apache/datafusion/pull/10500 + pub fn with_need_handle_count_bug(mut self, need_handle_count_bug: bool) -> Self { + self.need_handle_count_bug = need_handle_count_bug; + self + } + + /// Set the in_predicate_opt + pub fn with_in_predicate_opt(mut self, in_predicate_opt: Option) -> Self { + self.in_predicate_opt = in_predicate_opt; + self + } + + /// Set if this is an Exists(Not Exists) SubQuery + pub fn with_exists_sub_query(mut self, exists_sub_query: bool) -> Self { + self.exists_sub_query = exists_sub_query; + self + } +} + /// Used to indicate the unmatched rows from the inner(subquery) table after the left out Join -/// This is used to handle the Count bug +/// This is used to handle [the Count bug] +/// +/// [the Count bug]: https://github.com/apache/datafusion/pull/10500 pub const UN_MATCHED_ROW_INDICATOR: &str = "__always_true"; /// Mapping from expr display name to its evaluation result on empty record @@ -101,10 +155,15 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { } fn f_up(&mut self, plan: LogicalPlan) -> Result> { - let subquery_schema = plan.schema().clone(); + let subquery_schema = plan.schema(); match &plan { LogicalPlan::Filter(plan_filter) => { let subquery_filter_exprs = split_conjunction(&plan_filter.predicate); + self.can_pull_over_aggregation = self.can_pull_over_aggregation + && subquery_filter_exprs + .iter() + .filter(|e| e.contains_outer()) + .all(|&e| can_pullup_over_aggregation(e)); let (mut join_filters, subquery_filters) = find_join_exprs(subquery_filter_exprs)?; if let Some(in_predicate) = &self.in_predicate_opt { @@ -126,7 +185,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { if let Some(expr) = conjunction(subquery_filters.clone()) { filter_exprs_evaluation_result_on_empty_batch( &expr, - plan_filter.input.schema().clone(), + Arc::clone(plan_filter.input.schema()), expr_result_map, &mut expr_result_map_for_count_bug, )? @@ -184,7 +243,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { { proj_exprs_evaluation_result_on_empty_batch( &projection.expr, - projection.input.schema().clone(), + projection.input.schema(), expr_result_map, &mut expr_result_map_for_count_bug, )?; @@ -210,6 +269,12 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { LogicalPlan::Aggregate(aggregate) if self.in_predicate_opt.is_some() || !self.join_filters.is_empty() => { + // If the aggregation is from a distinct it will not change the result for + // exists/in subqueries so we can still pull up all predicates. + let is_distinct = aggregate.aggr_expr.is_empty(); + if !is_distinct { + self.can_pull_up = self.can_pull_up && self.can_pull_over_aggregation; + } let mut local_correlated_cols = BTreeSet::new(); collect_local_correlated_cols( &plan, @@ -230,14 +295,12 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { { agg_exprs_evaluation_result_on_empty_batch( &aggregate.aggr_expr, - aggregate.input.schema().clone(), + aggregate.input.schema(), &mut expr_result_map_for_count_bug, )?; if !expr_result_map_for_count_bug.is_empty() { // has count bug - let un_matched_row = - Expr::Literal(ScalarValue::Boolean(Some(true))) - .alias(UN_MATCHED_ROW_INDICATOR); + let un_matched_row = lit(true).alias(UN_MATCHED_ROW_INDICATOR); // add the unmatched rows indicator to the Aggregation's group expressions missing_exprs.push(un_matched_row); } @@ -282,16 +345,15 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { let new_plan = match (self.exists_sub_query, self.join_filters.is_empty()) { // Correlated exist subquery, remove the limit(so that correlated expressions can pull up) - (true, false) => Transformed::yes( - if limit.fetch.filter(|limit_row| *limit_row == 0).is_some() { + (true, false) => Transformed::yes(match limit.get_fetch_type()? { + FetchType::Literal(Some(0)) => { LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, - schema: limit.input.schema().clone(), + schema: Arc::clone(limit.input.schema()), }) - } else { - LogicalPlanBuilder::from((*limit.input).clone()).build()? - }, - ), + } + _ => LogicalPlanBuilder::from((*limit.input).clone()).build()?, + }), _ => Transformed::no(plan), }; if let Some(input_map) = input_expr_map { @@ -324,11 +386,14 @@ impl PullUpCorrelatedExpr { } } if let Some(pull_up_having) = &self.pull_up_having_expr { - let filter_apply_columns = pull_up_having.to_columns()?; + let filter_apply_columns = pull_up_having.column_refs(); for col in filter_apply_columns { - let col_expr = Expr::Column(col); - if !missing_exprs.contains(&col_expr) { - missing_exprs.push(col_expr) + // add to missing_exprs if not already there + let contains = missing_exprs + .iter() + .any(|expr| matches!(expr, Expr::Column(c) if c == col)); + if !contains { + missing_exprs.push(Expr::Column(col.clone())) } } } @@ -336,6 +401,33 @@ impl PullUpCorrelatedExpr { } } +fn can_pullup_over_aggregation(expr: &Expr) -> bool { + if let Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Eq, + right, + }) = expr + { + match (left.deref(), right.deref()) { + (Expr::Column(_), right) => !right.any_column_refs(), + (left, Expr::Column(_)) => !left.any_column_refs(), + (Expr::Cast(Cast { expr, .. }), right) + if matches!(expr.deref(), Expr::Column(_)) => + { + !right.any_column_refs() + } + (left, Expr::Cast(Cast { expr, .. })) + if matches!(expr.deref(), Expr::Column(_)) => + { + !left.any_column_refs() + } + (_, _) => false, + } + } else { + false + } +} + fn collect_local_correlated_cols( plan: &LogicalPlan, all_cols_map: &HashMap>, @@ -375,7 +467,7 @@ fn remove_duplicated_filter(filters: Vec, in_predicate: &Expr) -> Vec Result<()> { for e in agg_expr.iter() { @@ -383,25 +475,13 @@ fn agg_exprs_evaluation_result_on_empty_batch( .clone() .transform_up(|expr| { let new_expr = match expr { - Expr::AggregateFunction(expr::AggregateFunction { - func_def, .. - }) => match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => { - if matches!(fun, datafusion_expr::AggregateFunction::Count) { - Transformed::yes(Expr::Literal(ScalarValue::Int64(Some( - 0, - )))) - } else { - Transformed::yes(Expr::Literal(ScalarValue::Null)) - } - } - AggregateFunctionDefinition::UDF { .. } => { - Transformed::yes(Expr::Literal(ScalarValue::Null)) - } - AggregateFunctionDefinition::Name(_) => { + Expr::AggregateFunction(expr::AggregateFunction { func, .. }) => { + if func.name() == "count" { + Transformed::yes(Expr::Literal(ScalarValue::Int64(Some(0)))) + } else { Transformed::yes(Expr::Literal(ScalarValue::Null)) } - }, + } _ => Transformed::no(expr), }; Ok(new_expr) @@ -410,11 +490,12 @@ fn agg_exprs_evaluation_result_on_empty_batch( let result_expr = result_expr.unalias(); let props = ExecutionProps::new(); - let info = SimplifyContext::new(&props).with_schema(schema.clone()); + let info = SimplifyContext::new(&props).with_schema(Arc::clone(schema)); let simplifier = ExprSimplifier::new(info); let result_expr = simplifier.simplify(result_expr)?; if matches!(result_expr, Expr::Literal(ScalarValue::Int64(_))) { - expr_result_map_for_count_bug.insert(e.display_name()?, result_expr); + expr_result_map_for_count_bug + .insert(e.schema_name().to_string(), result_expr); } } Ok(()) @@ -422,7 +503,7 @@ fn agg_exprs_evaluation_result_on_empty_batch( fn proj_exprs_evaluation_result_on_empty_batch( proj_expr: &[Expr], - schema: DFSchemaRef, + schema: &DFSchemaRef, input_expr_result_map_for_count_bug: &ExprResultMap, expr_result_map_for_count_bug: &mut ExprResultMap, ) -> Result<()> { @@ -446,13 +527,13 @@ fn proj_exprs_evaluation_result_on_empty_batch( if result_expr.ne(expr) { let props = ExecutionProps::new(); - let info = SimplifyContext::new(&props).with_schema(schema.clone()); + let info = SimplifyContext::new(&props).with_schema(Arc::clone(schema)); let simplifier = ExprSimplifier::new(info); let result_expr = simplifier.simplify(result_expr)?; let expr_name = match expr { Expr::Alias(Alias { name, .. }) => name.to_string(), Expr::Column(Column { relation: _, name }) => name.to_string(), - _ => expr.display_name()?, + _ => expr.schema_name().to_string(), }; expr_result_map_for_count_bug.insert(expr_name, result_expr); } @@ -508,8 +589,8 @@ fn filter_exprs_evaluation_result_on_empty_batch( )], else_expr: Some(Box::new(Expr::Literal(ScalarValue::Null))), }); - expr_result_map_for_count_bug - .insert(new_expr.display_name()?, new_expr); + let expr_key = new_expr.schema_name().to_string(); + expr_result_map_for_count_bug.insert(expr_key, new_expr); } None } diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index 2e726321704f..7fdad5ba4b6e 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -26,21 +26,21 @@ use crate::utils::replace_qualified_name; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::alias::AliasGenerator; -use datafusion_common::tree_node::{TransformedResult, TreeNode}; -use datafusion_common::{plan_err, Result}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::{internal_err, plan_err, Column, Result}; use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; -use datafusion_expr::utils::{conjunction, split_conjunction}; +use datafusion_expr::utils::{conjunction, split_conjunction_owned}; use datafusion_expr::{ - exists, in_subquery, not_exists, not_in_subquery, BinaryExpr, Expr, Filter, + exists, in_subquery, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Operator, }; use log::debug; /// Optimizer rule for rewriting predicate(IN/EXISTS) subquery to left semi/anti joins -#[derive(Default)] +#[derive(Default, Debug)] pub struct DecorrelatePredicateSubquery {} impl DecorrelatePredicateSubquery { @@ -48,117 +48,72 @@ impl DecorrelatePredicateSubquery { pub fn new() -> Self { Self::default() } +} + +impl OptimizerRule for DecorrelatePredicateSubquery { + fn supports_rewrite(&self) -> bool { + true + } - /// Finds expressions that have the predicate subqueries (and recurses when found) - /// - /// # Arguments - /// - /// * `predicate` - A conjunction to split and search - /// * `optimizer_config` - For generating unique subquery aliases - /// - /// Returns a tuple (subqueries, non-subquery expressions) - fn extract_subquery_exprs( + fn rewrite( &self, - predicate: &Expr, + plan: LogicalPlan, config: &dyn OptimizerConfig, - ) -> Result<(Vec, Vec)> { - let filters = split_conjunction(predicate); // TODO: add ExistenceJoin to support disjunctions - - let mut subqueries = vec![]; - let mut others = vec![]; - for it in filters.iter() { - match it { - Expr::InSubquery(InSubquery { - expr, - subquery, - negated, - }) => { - let subquery_plan = self - .try_optimize(&subquery.subquery, config)? - .map(Arc::new) - .unwrap_or_else(|| subquery.subquery.clone()); - let new_subquery = subquery.with_plan(subquery_plan); - subqueries.push(SubqueryInfo::new_with_in_expr( - new_subquery, - (**expr).clone(), - *negated, - )); - } - Expr::Exists(Exists { subquery, negated }) => { - let subquery_plan = self - .try_optimize(&subquery.subquery, config)? - .map(Arc::new) - .unwrap_or_else(|| subquery.subquery.clone()); - let new_subquery = subquery.with_plan(subquery_plan); - subqueries.push(SubqueryInfo::new(new_subquery, *negated)); - } - _ => others.push((*it).clone()), - } + ) -> Result> { + let plan = plan + .map_subqueries(|subquery| { + subquery.transform_down(|p| self.rewrite(p, config)) + })? + .data; + + let LogicalPlan::Filter(filter) = plan else { + return Ok(Transformed::no(plan)); + }; + + if !has_subquery(&filter.predicate) { + return Ok(Transformed::no(LogicalPlan::Filter(filter))); } - Ok((subqueries, others)) - } -} + let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) = + split_conjunction_owned(filter.predicate) + .into_iter() + .partition(has_subquery); -impl OptimizerRule for DecorrelatePredicateSubquery { - fn try_optimize( - &self, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, - ) -> Result> { - match plan { - LogicalPlan::Filter(filter) => { - let (subqueries, mut other_exprs) = - self.extract_subquery_exprs(&filter.predicate, config)?; - if subqueries.is_empty() { - // regular filter, no subquery exists clause here - return Ok(None); - } + if with_subqueries.is_empty() { + return internal_err!( + "can not find expected subqueries in DecorrelatePredicateSubquery" + ); + } - // iterate through all exists clauses in predicate, turning each into a join - let mut cur_input = filter.input.as_ref().clone(); - for subquery in subqueries { - if let Some(plan) = - build_join(&subquery, &cur_input, config.alias_generator())? + // iterate through all exists clauses in predicate, turning each into a join + let mut cur_input = Arc::unwrap_or_clone(filter.input); + for subquery_expr in with_subqueries { + match extract_subquery_info(subquery_expr) { + // The subquery expression is at the top level of the filter + SubqueryPredicate::Top(subquery) => { + match build_join_top(&subquery, &cur_input, config.alias_generator())? { - cur_input = plan; - } else { + Some(plan) => cur_input = plan, // If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter - let sub_query_expr = match subquery { - SubqueryInfo { - query, - where_in_expr: Some(expr), - negated: false, - } => in_subquery(expr, query.subquery.clone()), - SubqueryInfo { - query, - where_in_expr: Some(expr), - negated: true, - } => not_in_subquery(expr, query.subquery.clone()), - SubqueryInfo { - query, - where_in_expr: None, - negated: false, - } => exists(query.subquery.clone()), - SubqueryInfo { - query, - where_in_expr: None, - negated: true, - } => not_exists(query.subquery.clone()), - }; - other_exprs.push(sub_query_expr); + None => other_exprs.push(subquery.expr()), } } - - let expr = conjunction(other_exprs); - if let Some(expr) = expr { - let new_filter = Filter::try_new(expr, Arc::new(cur_input))?; - cur_input = LogicalPlan::Filter(new_filter); + // The subquery expression is embedded within another expression + SubqueryPredicate::Embedded(expr) => { + let (plan, expr_without_subqueries) = + rewrite_inner_subqueries(cur_input, expr, config)?; + cur_input = plan; + other_exprs.push(expr_without_subqueries); } - Ok(Some(cur_input)) } - _ => Ok(None), } + + let expr = conjunction(other_exprs); + if let Some(expr) = expr { + let new_filter = Filter::try_new(expr, Arc::new(cur_input))?; + cur_input = LogicalPlan::Filter(new_filter); + } + Ok(Transformed::yes(cur_input)) } fn name(&self) -> &str { @@ -170,6 +125,101 @@ impl OptimizerRule for DecorrelatePredicateSubquery { } } +fn rewrite_inner_subqueries( + outer: LogicalPlan, + expr: Expr, + config: &dyn OptimizerConfig, +) -> Result<(LogicalPlan, Expr)> { + let mut cur_input = outer; + let alias = config.alias_generator(); + let expr_without_subqueries = expr.transform(|e| match e { + Expr::Exists(Exists { + subquery: Subquery { subquery, .. }, + negated, + }) => match mark_join(&cur_input, Arc::clone(&subquery), None, negated, alias)? { + Some((plan, exists_expr)) => { + cur_input = plan; + Ok(Transformed::yes(exists_expr)) + } + None if negated => Ok(Transformed::no(not_exists(subquery))), + None => Ok(Transformed::no(exists(subquery))), + }, + Expr::InSubquery(InSubquery { + expr, + subquery: Subquery { subquery, .. }, + negated, + }) => { + let in_predicate = subquery + .head_output_expr()? + .map_or(plan_err!("single expression required."), |output_expr| { + Ok(Expr::eq(*expr.clone(), output_expr)) + })?; + match mark_join( + &cur_input, + Arc::clone(&subquery), + Some(in_predicate), + negated, + alias, + )? { + Some((plan, exists_expr)) => { + cur_input = plan; + Ok(Transformed::yes(exists_expr)) + } + None if negated => Ok(Transformed::no(not_in_subquery(*expr, subquery))), + None => Ok(Transformed::no(in_subquery(*expr, subquery))), + } + } + _ => Ok(Transformed::no(e)), + })?; + Ok((cur_input, expr_without_subqueries.data)) +} + +enum SubqueryPredicate { + // The subquery expression is at the top level of the filter and can be fully replaced by a + // semi/anti join + Top(SubqueryInfo), + // The subquery expression is embedded within another expression and is replaced using an + // existence join + Embedded(Expr), +} + +fn extract_subquery_info(expr: Expr) -> SubqueryPredicate { + match expr { + Expr::Not(not_expr) => match *not_expr { + Expr::InSubquery(InSubquery { + expr, + subquery, + negated, + }) => SubqueryPredicate::Top(SubqueryInfo::new_with_in_expr( + subquery, *expr, !negated, + )), + Expr::Exists(Exists { subquery, negated }) => { + SubqueryPredicate::Top(SubqueryInfo::new(subquery, !negated)) + } + expr => SubqueryPredicate::Embedded(not(expr)), + }, + Expr::InSubquery(InSubquery { + expr, + subquery, + negated, + }) => SubqueryPredicate::Top(SubqueryInfo::new_with_in_expr( + subquery, *expr, negated, + )), + Expr::Exists(Exists { subquery, negated }) => { + SubqueryPredicate::Top(SubqueryInfo::new(subquery, negated)) + } + expr => SubqueryPredicate::Embedded(expr), + } +} + +fn has_subquery(expr: &Expr) -> bool { + expr.exists(|e| match e { + Expr::InSubquery(_) | Expr::Exists(_) => Ok(true), + _ => Ok(false), + }) + .unwrap() +} + /// Optimize the subquery to left-anti/left-semi join. /// If the subquery is a correlated subquery, we need extract the join predicate from the subquery. /// @@ -200,10 +250,10 @@ impl OptimizerRule for DecorrelatePredicateSubquery { /// Projection: t2.id /// TableScan: t2 /// ``` -fn build_join( +fn build_join_top( query_info: &SubqueryInfo, left: &LogicalPlan, - alias: Arc, + alias: &Arc, ) -> Result> { let where_in_expr_opt = &query_info.where_in_expr; let in_predicate_opt = where_in_expr_opt @@ -219,26 +269,66 @@ fn build_join( }) .map_or(Ok(None), |v| v.map(Some))?; + let join_type = match query_info.negated { + true => JoinType::LeftAnti, + false => JoinType::LeftSemi, + }; let subquery = query_info.query.subquery.as_ref(); let subquery_alias = alias.next("__correlated_sq"); + build_join(left, subquery, in_predicate_opt, join_type, subquery_alias) +} + +/// This is used to handle the case when the subquery is embedded in a more complex boolean +/// expression like and OR. For example +/// +/// `select t1.id from t1 where t1.id < 0 OR exists(SELECT t2.id FROM t2 WHERE t1.id = t2.id)` +/// +/// The optimized plan will be: +/// +/// ```text +/// Projection: t1.id +/// Filter: t1.id < 0 OR __correlated_sq_1.mark +/// LeftMark Join: Filter: t1.id = __correlated_sq_1.id +/// TableScan: t1 +/// SubqueryAlias: __correlated_sq_1 +/// Projection: t2.id +/// TableScan: t2 +fn mark_join( + left: &LogicalPlan, + subquery: Arc, + in_predicate_opt: Option, + negated: bool, + alias_generator: &Arc, +) -> Result> { + let alias = alias_generator.next("__correlated_sq"); + + let exists_col = Expr::Column(Column::new(Some(alias.clone()), "mark")); + let exists_expr = if negated { !exists_col } else { exists_col }; + + Ok( + build_join(left, &subquery, in_predicate_opt, JoinType::LeftMark, alias)? + .map(|plan| (plan, exists_expr)), + ) +} + +fn build_join( + left: &LogicalPlan, + subquery: &LogicalPlan, + in_predicate_opt: Option, + join_type: JoinType, + alias: String, +) -> Result> { + let mut pull_up = PullUpCorrelatedExpr::new() + .with_in_predicate_opt(in_predicate_opt.clone()) + .with_exists_sub_query(in_predicate_opt.is_none()); - let mut pull_up = PullUpCorrelatedExpr { - join_filters: vec![], - correlated_subquery_cols_map: Default::default(), - in_predicate_opt: in_predicate_opt.clone(), - exists_sub_query: in_predicate_opt.is_none(), - can_pull_up: true, - need_handle_count_bug: false, - collected_count_expr_map: Default::default(), - pull_up_having_expr: None, - }; let new_plan = subquery.clone().rewrite(&mut pull_up).data()?; if !pull_up.can_pull_up { return Ok(None); } let sub_query_alias = LogicalPlanBuilder::from(new_plan) - .alias(subquery_alias.to_string())? + .alias(alias.to_string())? .build()?; let mut all_correlated_cols = BTreeSet::new(); pull_up @@ -247,10 +337,9 @@ fn build_join( .for_each(|cols| all_correlated_cols.extend(cols.clone())); // alias the join filter - let join_filter_opt = - conjunction(pull_up.join_filters).map_or(Ok(None), |filter| { - replace_qualified_name(filter, &all_correlated_cols, &subquery_alias) - .map(Option::Some) + let join_filter_opt = conjunction(pull_up.join_filters) + .map_or(Ok(None), |filter| { + replace_qualified_name(filter, &all_correlated_cols, &alias).map(Some) })?; if let Some(join_filter) = match (join_filter_opt, in_predicate_opt) { @@ -262,7 +351,7 @@ fn build_join( right, })), ) => { - let right_col = create_col_from_scalar_expr(right.deref(), subquery_alias)?; + let right_col = create_col_from_scalar_expr(right.deref(), alias)?; let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col)); Some(in_predicate.and(join_filter)) } @@ -275,17 +364,13 @@ fn build_join( right, })), ) => { - let right_col = create_col_from_scalar_expr(right.deref(), subquery_alias)?; + let right_col = create_col_from_scalar_expr(right.deref(), alias)?; let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col)); Some(in_predicate) } _ => None, } { // join our sub query into the main plan - let join_type = match query_info.negated { - true => JoinType::LeftAnti, - false => JoinType::LeftSemi, - }; let new_plan = LogicalPlanBuilder::from(left.clone()) .join_on(sub_query_alias, join_type, Some(join_filter))? .build()?; @@ -321,6 +406,19 @@ impl SubqueryInfo { negated, } } + + pub fn expr(self) -> Expr { + match self.where_in_expr { + Some(expr) => match self.negated { + true => not_in_subquery(expr, self.query.subquery), + false => in_subquery(expr, self.query.subquery), + }, + None => match self.negated { + true => not_exists(self.query.subquery), + false => exists(self.query.subquery), + }, + } + } } #[cfg(test)] @@ -330,8 +428,8 @@ mod tests { use super::*; use crate::test::*; - use arrow::datatypes::DataType; - use datafusion_expr::{and, binary_expr, col, lit, or, out_ref_col}; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_expr::{and, binary_expr, col, lit, not, out_ref_col, table_scan}; fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( @@ -402,60 +500,6 @@ mod tests { assert_optimized_plan_equal(plan, expected) } - /// Test for IN subquery with additional OR filter - /// filter expression not modified - #[test] - fn in_subquery_with_or_filters() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan) - .filter(or( - and( - binary_expr(col("a"), Operator::Eq, lit(1_u32)), - binary_expr(col("b"), Operator::Lt, lit(30_u32)), - ), - in_subquery(col("c"), test_subquery_with_name("sq")?), - ))? - .project(vec![col("test.b")])? - .build()?; - - let expected = "Projection: test.b [b:UInt32]\ - \n Filter: test.a = UInt32(1) AND test.b < UInt32(30) OR test.c IN () [a:UInt32, b:UInt32, c:UInt32]\ - \n Subquery: [c:UInt32]\ - \n Projection: sq.c [c:UInt32]\ - \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) - } - - #[test] - fn in_subquery_with_and_or_filters() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan) - .filter(and( - or( - binary_expr(col("a"), Operator::Eq, lit(1_u32)), - in_subquery(col("b"), test_subquery_with_name("sq1")?), - ), - in_subquery(col("c"), test_subquery_with_name("sq2")?), - ))? - .project(vec![col("test.b")])? - .build()?; - - let expected = "Projection: test.b [b:UInt32]\ - \n Filter: test.a = UInt32(1) OR test.b IN () [a:UInt32, b:UInt32, c:UInt32]\ - \n Subquery: [c:UInt32]\ - \n Projection: sq1.c [c:UInt32]\ - \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq2.c [c:UInt32]\ - \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) - } - /// Test for nested IN subqueries #[test] fn in_subquery_nested() -> Result<()> { @@ -472,51 +516,19 @@ mod tests { .build()?; let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: Filter: test.b = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.b = __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [a:UInt32]\ + \n SubqueryAlias: __correlated_sq_2 [a:UInt32]\ \n Projection: sq.a [a:UInt32]\ - \n LeftSemi Join: Filter: sq.a = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: sq.a = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_2 [c:UInt32]\ + \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ \n Projection: sq_nested.c [c:UInt32]\ \n TableScan: sq_nested [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) } - /// Test for filter input modification in case filter not supported - /// Outer filter expression not modified while inner converted to join - #[test] - fn in_subquery_input_modified() -> Result<()> { - let table_scan = test_table_scan()?; - let plan = LogicalPlanBuilder::from(table_scan) - .filter(in_subquery(col("c"), test_subquery_with_name("sq_inner")?))? - .project(vec![col("b"), col("c")])? - .alias("wrapped")? - .filter(or( - binary_expr(col("b"), Operator::Lt, lit(30_u32)), - in_subquery(col("c"), test_subquery_with_name("sq_outer")?), - ))? - .project(vec![col("b")])? - .build()?; - - let expected = "Projection: wrapped.b [b:UInt32]\ - \n Filter: wrapped.b < UInt32(30) OR wrapped.c IN () [b:UInt32, c:UInt32]\ - \n Subquery: [c:UInt32]\ - \n Projection: sq_outer.c [c:UInt32]\ - \n TableScan: sq_outer [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: wrapped [b:UInt32, c:UInt32]\ - \n Projection: test.b, test.c [b:UInt32, c:UInt32]\ - \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ - \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ - \n Projection: sq_inner.c [c:UInt32]\ - \n TableScan: sq_inner [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_equal(plan, expected) - } - /// Test multiple correlated subqueries /// See subqueries.rs where_in_multiple() #[test] @@ -532,7 +544,7 @@ mod tests { ); let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) .filter( - in_subquery(col("customer.c_custkey"), orders.clone()) + in_subquery(col("customer.c_custkey"), Arc::clone(&orders)) .and(in_subquery(col("customer.c_custkey"), orders)), )? .project(vec![col("customer.c_custkey")])? @@ -590,13 +602,13 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ + \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\ \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n LeftSemi Join: Filter: orders.o_orderkey = __correlated_sq_2.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n LeftSemi Join: Filter: orders.o_orderkey = __correlated_sq_1.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __correlated_sq_2 [l_orderkey:Int64]\ + \n SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64]\ \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\ \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; @@ -826,7 +838,7 @@ mod tests { let expected = "check_analyzed_plan\ \ncaused by\ \nError during planning: InSubquery should only return one column, but found 4"; - assert_analyzer_check_err(vec![], &plan, expected); + assert_analyzer_check_err(vec![], plan, expected); Ok(()) } @@ -921,7 +933,7 @@ mod tests { let expected = "check_analyzed_plan\ \ncaused by\ \nError during planning: InSubquery should only return one column"; - assert_analyzer_check_err(vec![], &plan, expected); + assert_analyzer_check_err(vec![], plan, expected); Ok(()) } @@ -963,44 +975,6 @@ mod tests { Ok(()) } - /// Test for correlated IN subquery filter with disjustions - #[test] - fn in_subquery_disjunction() -> Result<()> { - let sq = Arc::new( - LogicalPlanBuilder::from(scan_tpch_table("orders")) - .filter( - out_ref_col(DataType::Int64, "customer.c_custkey") - .eq(col("orders.o_custkey")), - )? - .project(vec![col("orders.o_custkey")])? - .build()?, - ); - - let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) - .filter( - in_subquery(col("customer.c_custkey"), sq) - .or(col("customer.c_custkey").eq(lit(1))), - )? - .project(vec![col("customer.c_custkey")])? - .build()?; - - // TODO: support disjunction - for now expect unaltered plan - let expected = r#"Projection: customer.c_custkey [c_custkey:Int64] - Filter: customer.c_custkey IN () OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8] - Subquery: [o_custkey:Int64] - Projection: orders.o_custkey [o_custkey:Int64] - Filter: outer_ref(customer.c_custkey) = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] - TableScan: customer [c_custkey:Int64, c_name:Utf8]"#; - - assert_optimized_plan_eq_display_indent( - Arc::new(DecorrelatePredicateSubquery::new()), - plan, - expected, - ); - Ok(()) - } - /// Test for correlated IN subquery filter #[test] fn in_subquery_correlated() -> Result<()> { @@ -1079,6 +1053,55 @@ mod tests { Ok(()) } + #[test] + fn wrapped_not_in_subquery() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(not(in_subquery(col("c"), test_subquery_with_name("sq")?)))? + .project(vec![col("test.b")])? + .build()?; + + let expected = "Projection: test.b [b:UInt32]\ + \n LeftAnti Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ + \n Projection: sq.c [c:UInt32]\ + \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_eq_display_indent( + Arc::new(DecorrelatePredicateSubquery::new()), + plan, + expected, + ); + Ok(()) + } + + #[test] + fn wrapped_not_not_in_subquery() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .filter(not(not_in_subquery( + col("c"), + test_subquery_with_name("sq")?, + )))? + .project(vec![col("test.b")])? + .build()?; + + let expected = "Projection: test.b [b:UInt32]\ + \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ + \n Projection: sq.c [c:UInt32]\ + \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_eq_display_indent( + Arc::new(DecorrelatePredicateSubquery::new()), + plan, + expected, + ); + Ok(()) + } + #[test] fn in_subquery_both_side_expr() -> Result<()> { let table_scan = test_table_scan()?; @@ -1144,7 +1167,7 @@ mod tests { } #[test] - fn in_subquery_muti_project_subquery_cols() -> Result<()> { + fn in_subquery_multi_project_subquery_cols() -> Result<()> { let table_scan = test_table_scan()?; let subquery_scan = test_table_scan_with_name("sq")?; @@ -1270,7 +1293,7 @@ mod tests { ); let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) - .filter(exists(orders.clone()).and(exists(orders)))? + .filter(exists(Arc::clone(&orders)).and(exists(orders)))? .project(vec![col("customer.c_custkey")])? .build()?; @@ -1318,13 +1341,13 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: Filter: __correlated_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: __correlated_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ + \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\ \n Projection: orders.o_custkey [o_custkey:Int64]\ - \n LeftSemi Join: Filter: __correlated_sq_2.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n LeftSemi Join: Filter: __correlated_sq_1.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __correlated_sq_2 [l_orderkey:Int64]\ + \n SubqueryAlias: __correlated_sq_1 [l_orderkey:Int64]\ \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\ \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; assert_optimized_plan_equal(plan, expected) @@ -1820,4 +1843,35 @@ mod tests { assert_optimized_plan_equal(plan, expected) } + + #[test] + fn upper_case_ident() -> Result<()> { + let fields = vec![ + Field::new("A", DataType::UInt32, false), + Field::new("B", DataType::UInt32, false), + ]; + + let schema = Schema::new(fields); + let table_scan_a = table_scan(Some("\"TEST_A\""), &schema, None)?.build()?; + let table_scan_b = table_scan(Some("\"TEST_B\""), &schema, None)?.build()?; + + let subquery = LogicalPlanBuilder::from(table_scan_b) + .filter(col("\"A\"").eq(out_ref_col(DataType::UInt32, "\"TEST_A\".\"A\"")))? + .project(vec![lit(1)])? + .build()?; + + let plan = LogicalPlanBuilder::from(table_scan_a) + .filter(exists(Arc::new(subquery)))? + .project(vec![col("\"TEST_A\".\"B\"")])? + .build()?; + + let expected = "Projection: TEST_A.B [B:UInt32]\ + \n LeftSemi Join: Filter: __correlated_sq_1.A = TEST_A.A [A:UInt32, B:UInt32]\ + \n TableScan: TEST_A [A:UInt32, B:UInt32]\ + \n SubqueryAlias: __correlated_sq_1 [Int32(1):Int32, A:UInt32]\ + \n Projection: Int32(1), TEST_B.A [Int32(1):Int32, A:UInt32]\ + \n TableScan: TEST_B [A:UInt32, B:UInt32]"; + + assert_optimized_plan_equal(plan, expected) + } } diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index ae6c1b339d5f..65ebac2106ad 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -16,20 +16,21 @@ // under the License. //! [`EliminateCrossJoin`] converts `CROSS JOIN` to `INNER JOIN` if join predicates are available. -use std::collections::HashSet; use std::sync::Arc; -use crate::{utils, OptimizerConfig, OptimizerRule}; +use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{plan_err, Result}; +use crate::join_key_set::JoinKeySet; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::Result; use datafusion_expr::expr::{BinaryExpr, Expr}; use datafusion_expr::logical_plan::{ - CrossJoin, Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection, + Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection, }; use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair}; -use datafusion_expr::{build_join_schema, ExprSchemable, Operator}; +use datafusion_expr::{and, build_join_schema, ExprSchemable, Operator}; -#[derive(Default)] +#[derive(Default, Debug)] pub struct EliminateCrossJoin; impl EliminateCrossJoin { @@ -39,102 +40,152 @@ impl EliminateCrossJoin { } } -/// Attempt to reorder join to eliminate cross joins to inner joins. -/// for queries: -/// 'select ... from a, b where a.x = b.y and b.xx = 100;' -/// 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and b.xx = 200);' -/// 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z) -/// or (a.x = b.y and b.xx = 200 and a.z=c.z);' -/// 'select ... from a, b where a.x > b.y' +/// Eliminate cross joins by rewriting them to inner joins when possible. +/// +/// # Example +/// The initial plan for this query: +/// ```sql +/// select ... from a, b where a.x = b.y and b.xx = 100; +/// ``` +/// +/// Looks like this: +/// ```text +/// Filter(a.x = b.y AND b.xx = 100) +/// Cross Join +/// TableScan a +/// TableScan b +/// ``` +/// +/// After the rule is applied, the plan will look like this: +/// ```text +/// Filter(b.xx = 100) +/// InnerJoin(a.x = b.y) +/// TableScan a +/// TableScan b +/// ``` +/// +/// # Other Examples +/// * 'select ... from a, b where a.x = b.y and b.xx = 100;' +/// * 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and b.xx = 200);' +/// * 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z) +/// * or (a.x = b.y and b.xx = 200 and a.z=c.z);' +/// * 'select ... from a, b where a.x > b.y' +/// /// For above queries, the join predicate is available in filters and they are moved to /// join nodes appropriately +/// /// This fix helps to improve the performance of TPCH Q19. issue#78 impl OptimizerRule for EliminateCrossJoin { - fn try_optimize( + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( &self, - plan: &LogicalPlan, + plan: LogicalPlan, config: &dyn OptimizerConfig, - ) -> Result> { - let mut possible_join_keys: Vec<(Expr, Expr)> = vec![]; + ) -> Result> { + let plan_schema = Arc::clone(plan.schema()); + let mut possible_join_keys = JoinKeySet::new(); let mut all_inputs: Vec = vec![]; - let parent_predicate = match plan { - LogicalPlan::Filter(filter) => { - let input = filter.input.as_ref(); - match input { - LogicalPlan::Join(Join { - join_type: JoinType::Inner, - .. - }) - | LogicalPlan::CrossJoin(_) => { - if !try_flatten_join_inputs( - input, - &mut possible_join_keys, - &mut all_inputs, - )? { - return Ok(None); - } - extract_possible_join_keys( - &filter.predicate, - &mut possible_join_keys, - )?; - Some(&filter.predicate) - } - _ => { - return utils::optimize_children(self, plan, config); - } - } + let mut all_filters: Vec = vec![]; + + let parent_predicate = if let LogicalPlan::Filter(filter) = plan { + // if input isn't a join that can potentially be rewritten + // avoid unwrapping the input + let rewriteable = matches!( + filter.input.as_ref(), + LogicalPlan::Join(Join { + join_type: JoinType::Inner, + .. + }) + ); + + if !rewriteable { + // recursively try to rewrite children + return rewrite_children(self, LogicalPlan::Filter(filter), config); + } + + if !can_flatten_join_inputs(&filter.input) { + return Ok(Transformed::no(LogicalPlan::Filter(filter))); } + + let Filter { + input, predicate, .. + } = filter; + flatten_join_inputs( + Arc::unwrap_or_clone(input), + &mut possible_join_keys, + &mut all_inputs, + &mut all_filters, + )?; + + extract_possible_join_keys(&predicate, &mut possible_join_keys); + Some(predicate) + } else if matches!( + plan, LogicalPlan::Join(Join { join_type: JoinType::Inner, .. - }) => { - if !try_flatten_join_inputs( - plan, - &mut possible_join_keys, - &mut all_inputs, - )? { - return Ok(None); - } - None + }) + ) { + if !can_flatten_join_inputs(&plan) { + return Ok(Transformed::no(plan)); } - _ => return utils::optimize_children(self, plan, config), + flatten_join_inputs( + plan, + &mut possible_join_keys, + &mut all_inputs, + &mut all_filters, + )?; + None + } else { + // recursively try to rewrite children + return rewrite_children(self, plan, config); }; // Join keys are handled locally: - let mut all_join_keys = HashSet::<(Expr, Expr)>::new(); + let mut all_join_keys = JoinKeySet::new(); let mut left = all_inputs.remove(0); while !all_inputs.is_empty() { left = find_inner_join( - &left, + left, &mut all_inputs, - &mut possible_join_keys, + &possible_join_keys, &mut all_join_keys, )?; } - left = utils::optimize_children(self, &left, config)?.unwrap_or(left); + left = rewrite_children(self, left, config)?.data; - if plan.schema() != left.schema() { + if &plan_schema != left.schema() { left = LogicalPlan::Projection(Projection::new_from_schema( Arc::new(left), - plan.schema().clone(), + Arc::clone(&plan_schema), )); } + if !all_filters.is_empty() { + // Add any filters on top - PushDownFilter can push filters down to applicable join + let first = all_filters.swap_remove(0); + let predicate = all_filters.into_iter().fold(first, and); + left = LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(left))?); + } + let Some(predicate) = parent_predicate else { - return Ok(Some(left)); + return Ok(Transformed::yes(left)); }; // If there are no join keys then do nothing: if all_join_keys.is_empty() { - Filter::try_new(predicate.clone(), Arc::new(left)) - .map(|f| Some(LogicalPlan::Filter(f))) + Filter::try_new(predicate, Arc::new(left)) + .map(|filter| Transformed::yes(LogicalPlan::Filter(filter))) } else { // Remove join expressions from filter: - match remove_join_expressions(predicate, &all_join_keys)? { + match remove_join_expressions(predicate, &all_join_keys) { Some(filter_expr) => Filter::try_new(filter_expr, Arc::new(left)) - .map(|f| Some(LogicalPlan::Filter(f))), - _ => Ok(Some(left)), + .map(|filter| Transformed::yes(LogicalPlan::Filter(filter))), + _ => Ok(Transformed::yes(left)), } } } @@ -144,67 +195,111 @@ impl OptimizerRule for EliminateCrossJoin { } } -/// Recursively accumulate possible_join_keys and inputs from inner joins (including cross joins). -/// Returns a boolean indicating whether the flattening was successful. -fn try_flatten_join_inputs( - plan: &LogicalPlan, - possible_join_keys: &mut Vec<(Expr, Expr)>, +fn rewrite_children( + optimizer: &impl OptimizerRule, + plan: LogicalPlan, + config: &dyn OptimizerConfig, +) -> Result> { + let transformed_plan = plan.map_children(|input| optimizer.rewrite(input, config))?; + + // recompute schema if the plan was transformed + if transformed_plan.transformed { + transformed_plan.map_data(|plan| plan.recompute_schema()) + } else { + Ok(transformed_plan) + } +} + +/// Recursively accumulate possible_join_keys and inputs from inner joins +/// (including cross joins). +/// +/// Assumes can_flatten_join_inputs has returned true and thus the plan can be +/// flattened. Adds all leaf inputs to `all_inputs` and join_keys to +/// possible_join_keys +fn flatten_join_inputs( + plan: LogicalPlan, + possible_join_keys: &mut JoinKeySet, all_inputs: &mut Vec, -) -> Result { - let children = match plan { + all_filters: &mut Vec, +) -> Result<()> { + match plan { LogicalPlan::Join(join) if join.join_type == JoinType::Inner => { - if join.filter.is_some() { - // The filter of inner join will lost, skip this rule. - // issue: https://github.com/apache/datafusion/issues/4844 - return Ok(false); + if let Some(filter) = join.filter { + all_filters.push(filter); } - possible_join_keys.extend(join.on.clone()); - let left = &*(join.left); - let right = &*(join.right); - vec![left, right] - } - LogicalPlan::CrossJoin(join) => { - let left = &*(join.left); - let right = &*(join.right); - vec![left, right] + possible_join_keys.insert_all_owned(join.on); + flatten_join_inputs( + Arc::unwrap_or_clone(join.left), + possible_join_keys, + all_inputs, + all_filters, + )?; + flatten_join_inputs( + Arc::unwrap_or_clone(join.right), + possible_join_keys, + all_inputs, + all_filters, + )?; } _ => { - return plan_err!("flatten_join_inputs just can call join/cross_join"); + all_inputs.push(plan); } }; + Ok(()) +} - for child in children.iter() { - match *child { - LogicalPlan::Join(Join { - join_type: JoinType::Inner, - .. - }) - | LogicalPlan::CrossJoin(_) => { - if !try_flatten_join_inputs(child, possible_join_keys, all_inputs)? { - return Ok(false); - } +/// Returns true if the plan is a Join or Cross join could be flattened with +/// `flatten_join_inputs` +/// +/// Must stay in sync with `flatten_join_inputs` +fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool { + // can only flatten inner / cross joins + match plan { + LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {} + _ => return false, + }; + + for child in plan.inputs() { + if let LogicalPlan::Join(Join { + join_type: JoinType::Inner, + .. + }) = child + { + if !can_flatten_join_inputs(child) { + return false; } - _ => all_inputs.push((*child).clone()), } } - Ok(true) + true } +/// Finds the next to join with the left input plan, +/// +/// Finds the next `right` from `rights` that can be joined with `left_input` +/// plan based on the join keys in `possible_join_keys`. +/// +/// If such a matching `right` is found: +/// 1. Adds the matching join keys to `all_join_keys`. +/// 2. Returns `left_input JOIN right ON (all join keys)`. +/// +/// If no matching `right` is found: +/// 1. Removes the first plan from `rights` +/// 2. Returns `left_input CROSS JOIN right`. fn find_inner_join( - left_input: &LogicalPlan, + left_input: LogicalPlan, rights: &mut Vec, - possible_join_keys: &mut Vec<(Expr, Expr)>, - all_join_keys: &mut HashSet<(Expr, Expr)>, + possible_join_keys: &JoinKeySet, + all_join_keys: &mut JoinKeySet, ) -> Result { for (i, right_input) in rights.iter().enumerate() { let mut join_keys = vec![]; - for (l, r) in &mut *possible_join_keys { + for (l, r) in possible_join_keys.iter() { let key_pair = find_valid_equijoin_key_pair( l, r, - left_input.schema().clone(), - right_input.schema().clone(), + left_input.schema(), + right_input.schema(), )?; // Save join keys @@ -215,8 +310,9 @@ fn find_inner_join( } } + // Found one or more matching join keys if !join_keys.is_empty() { - all_join_keys.extend(join_keys.clone()); + all_join_keys.insert_all(join_keys.iter()); let right_input = rights.remove(i); let join_schema = Arc::new(build_join_schema( left_input.schema(), @@ -225,7 +321,7 @@ fn find_inner_join( )?); return Ok(LogicalPlan::Join(Join { - left: Arc::new(left_input.clone()), + left: Arc::new(left_input), right: Arc::new(right_input), join_type: JoinType::Inner, join_constraint: JoinConstraint::On, @@ -236,6 +332,9 @@ fn find_inner_join( })); } } + + // no matching right plan had any join keys, cross join with the first right + // plan let right = rights.remove(0); let join_schema = Arc::new(build_join_schema( left_input.schema(), @@ -243,97 +342,90 @@ fn find_inner_join( &JoinType::Inner, )?); - Ok(LogicalPlan::CrossJoin(CrossJoin { - left: Arc::new(left_input.clone()), + Ok(LogicalPlan::Join(Join { + left: Arc::new(left_input), right: Arc::new(right), schema: join_schema, + on: vec![], + filter: None, + join_type: JoinType::Inner, + join_constraint: JoinConstraint::On, + null_equals_null: false, })) } -fn intersect( - accum: &mut Vec<(Expr, Expr)>, - vec1: &[(Expr, Expr)], - vec2: &[(Expr, Expr)], -) { - if !(vec1.is_empty() || vec2.is_empty()) { - for x1 in vec1.iter() { - for x2 in vec2.iter() { - if x1.0 == x2.0 && x1.1 == x2.1 || x1.1 == x2.0 && x1.0 == x2.1 { - accum.push((x1.0.clone(), x1.1.clone())); - } - } - } - } -} - /// Extract join keys from a WHERE clause -fn extract_possible_join_keys(expr: &Expr, accum: &mut Vec<(Expr, Expr)>) -> Result<()> { +fn extract_possible_join_keys(expr: &Expr, join_keys: &mut JoinKeySet) { if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr { match op { Operator::Eq => { - // Ensure that we don't add the same Join keys multiple times - if !(accum.contains(&(*left.clone(), *right.clone())) - || accum.contains(&(*right.clone(), *left.clone()))) - { - accum.push((*left.clone(), *right.clone())); - } + // insert handles ensuring we don't add the same Join keys multiple times + join_keys.insert(left, right); } Operator::And => { - extract_possible_join_keys(left, accum)?; - extract_possible_join_keys(right, accum)? + extract_possible_join_keys(left, join_keys); + extract_possible_join_keys(right, join_keys) } // Fix for issue#78 join predicates from inside of OR expr also pulled up properly. Operator::Or => { - let mut left_join_keys = vec![]; - let mut right_join_keys = vec![]; + let mut left_join_keys = JoinKeySet::new(); + let mut right_join_keys = JoinKeySet::new(); - extract_possible_join_keys(left, &mut left_join_keys)?; - extract_possible_join_keys(right, &mut right_join_keys)?; + extract_possible_join_keys(left, &mut left_join_keys); + extract_possible_join_keys(right, &mut right_join_keys); - intersect(accum, &left_join_keys, &right_join_keys) + join_keys.insert_intersection(&left_join_keys, &right_join_keys) } _ => (), }; } - Ok(()) } /// Remove join expressions from a filter expression -/// Returns Some() when there are few remaining predicates in filter_expr -/// Returns None otherwise -fn remove_join_expressions( - expr: &Expr, - join_keys: &HashSet<(Expr, Expr)>, -) -> Result> { +/// +/// # Returns +/// * `Some()` when there are few remaining predicates in filter_expr +/// * `None` otherwise +fn remove_join_expressions(expr: Expr, join_keys: &JoinKeySet) -> Option { match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - match op { - Operator::Eq => { - if join_keys.contains(&(*left.clone(), *right.clone())) - || join_keys.contains(&(*right.clone(), *left.clone())) - { - Ok(None) - } else { - Ok(Some(expr.clone())) - } - } - // Fix for issue#78 join predicates from inside of OR expr also pulled up properly. - Operator::And | Operator::Or => { - let l = remove_join_expressions(left, join_keys)?; - let r = remove_join_expressions(right, join_keys)?; - match (l, r) { - (Some(ll), Some(rr)) => Ok(Some(Expr::BinaryExpr( - BinaryExpr::new(Box::new(ll), *op, Box::new(rr)), - ))), - (Some(ll), _) => Ok(Some(ll)), - (_, Some(rr)) => Ok(Some(rr)), - _ => Ok(None), - } - } - _ => Ok(Some(expr.clone())), + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Eq, + right, + }) if join_keys.contains(&left, &right) => { + // was a join key, so remove it + None + } + // Fix for issue#78 join predicates from inside of OR expr also pulled up properly. + Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::And => { + let l = remove_join_expressions(*left, join_keys); + let r = remove_join_expressions(*right, join_keys); + match (l, r) { + (Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new( + Box::new(ll), + op, + Box::new(rr), + ))), + (Some(ll), _) => Some(ll), + (_, Some(rr)) => Some(rr), + _ => None, } } - _ => Ok(Some(expr.clone())), + Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == Operator::Or => { + let l = remove_join_expressions(*left, join_keys); + let r = remove_join_expressions(*right, join_keys); + match (l, r) { + (Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new( + Box::new(ll), + op, + Box::new(rr), + ))), + // When either `left` or `right` is empty, it means they are `true` + // so OR'ing anything with them will also be true + _ => None, + } + } + _ => Some(expr), } } @@ -349,12 +441,12 @@ mod tests { Operator::{And, Or}, }; - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: Vec<&str>) { + fn assert_optimized_plan_eq(plan: LogicalPlan, expected: Vec<&str>) { + let starting_schema = Arc::clone(plan.schema()); let rule = EliminateCrossJoin::new(); - let optimized_plan = rule - .try_optimize(plan, &OptimizerContext::new()) - .unwrap() - .expect("failed to optimize plan"); + let transformed_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap(); + assert!(transformed_plan.transformed, "failed to optimize plan"); + let optimized_plan = transformed_plan.data; let formatted = optimized_plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -363,13 +455,7 @@ mod tests { "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" ); - assert_eq!(plan.schema(), optimized_plan.schema()) - } - - fn assert_optimization_rule_fails(plan: &LogicalPlan) { - let rule = EliminateCrossJoin::new(); - let optimized_plan = rule.try_optimize(plan, &OptimizerContext::new()).unwrap(); - assert!(optimized_plan.is_none()); + assert_eq!(&starting_schema, optimized_plan.schema()) } #[test] @@ -394,7 +480,7 @@ mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -417,12 +503,12 @@ mod tests { let expected = vec![ "Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -449,7 +535,7 @@ mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -479,7 +565,7 @@ mod tests { " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -505,11 +591,11 @@ mod tests { let expected = vec![ "Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -531,18 +617,17 @@ mod tests { let expected = vec![ "Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } #[test] - /// See https://github.com/apache/datafusion/issues/7530 - fn eliminate_cross_not_possible_nested_inner_join_with_filter() -> Result<()> { + fn eliminate_cross_possible_nested_inner_join_with_filter() -> Result<()> { let t1 = test_table_scan_with_name("t1")?; let t2 = test_table_scan_with_name("t2")?; let t3 = test_table_scan_with_name("t3")?; @@ -559,7 +644,17 @@ mod tests { .filter(col("t1.a").gt(lit(15u32)))? .build()?; - assert_optimization_rule_fails(&plan); + let expected = vec![ + "Filter: t1.a > UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Filter: t1.a > UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]" + ]; + + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -606,7 +701,7 @@ mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -683,7 +778,7 @@ mod tests { " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -747,7 +842,7 @@ mod tests { let expected = vec![ "Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", @@ -758,7 +853,7 @@ mod tests { " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -828,12 +923,12 @@ mod tests { " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", " Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -903,7 +998,7 @@ mod tests { "Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", " Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", @@ -912,7 +1007,7 @@ mod tests { " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -925,6 +1020,7 @@ mod tests { let t4 = test_table_scan_with_name("t4")?; // could eliminate to inner join + // filter: (t1.a = t2.a OR t2.c < 15) AND (t1.a = t2.a AND tc.2 = 688) let plan1 = LogicalPlanBuilder::from(t1) .cross_join(t2)? .filter(binary_expr( @@ -942,6 +1038,10 @@ mod tests { let plan2 = LogicalPlanBuilder::from(t3).cross_join(t4)?.build()?; // could eliminate to inner join + // filter: + // ((t3.a = t1.a AND t4.c < 15) OR (t3.a = t1.a AND t4.c = 688)) + // AND + // ((t3.a = t4.a AND t4.c < 15) OR (t3.a = t4.a AND t3.c = 688) OR (t3.a = t4.a AND t3.b = t4.b)) let plan = LogicalPlanBuilder::from(plan1) .cross_join(plan2)? .filter(binary_expr( @@ -987,7 +1087,7 @@ mod tests { "Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t2.c < UInt32(15) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Filter: t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", @@ -995,7 +1095,7 @@ mod tests { " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -1014,6 +1114,12 @@ mod tests { let plan2 = LogicalPlanBuilder::from(t3).cross_join(t4)?.build()?; // could eliminate to inner join + // Filter: + // ((t3.a = t1.a AND t4.c < 15) OR (t3.a = t1.a AND t4.c = 688)) + // AND + // ((t3.a = t4.a AND t4.c < 15) OR (t3.a = t4.a AND t3.c = 688) OR (t3.a = t4.a AND t3.b = t4.b)) + // AND + // ((t1.a = t2.a OR t2.c < 15) AND (t1.a = t2.a AND t2.c = 688)) let plan = LogicalPlanBuilder::from(plan1) .cross_join(plan2)? .filter(binary_expr( @@ -1072,7 +1178,7 @@ mod tests { .build()?; let expected = vec![ - "Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c < UInt32(15) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + "Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", @@ -1082,7 +1188,7 @@ mod tests { " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -1108,7 +1214,7 @@ mod tests { " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -1131,12 +1237,12 @@ mod tests { let expected = vec![ "Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Cross Join: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -1164,7 +1270,7 @@ mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -1192,7 +1298,7 @@ mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -1232,7 +1338,7 @@ mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index ee44a328f8b3..554985667fdf 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -19,14 +19,14 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_common::tree_node::Transformed; use datafusion_common::Result; -use datafusion_expr::expr::Sort as ExprSort; use datafusion_expr::logical_plan::LogicalPlan; -use datafusion_expr::{Aggregate, Expr, Sort}; -use hashbrown::HashSet; - +use datafusion_expr::{Aggregate, Expr, Sort, SortExpr}; +use indexmap::IndexSet; +use std::hash::{Hash, Hasher}; /// Optimization rule that eliminate duplicated expr. -#[derive(Default)] +#[derive(Default, Debug)] pub struct EliminateDuplicatedExpr; impl EliminateDuplicatedExpr { @@ -35,78 +35,82 @@ impl EliminateDuplicatedExpr { Self {} } } - +// use this structure to avoid initial clone +#[derive(Eq, Clone, Debug)] +struct SortExprWrapper(SortExpr); +impl PartialEq for SortExprWrapper { + fn eq(&self, other: &Self) -> bool { + self.0.expr == other.0.expr + } +} +impl Hash for SortExprWrapper { + fn hash(&self, state: &mut H) { + self.0.expr.hash(state); + } +} impl OptimizerRule for EliminateDuplicatedExpr { - fn try_optimize( + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } + + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( &self, - plan: &LogicalPlan, + plan: LogicalPlan, _config: &dyn OptimizerConfig, - ) -> Result> { + ) -> Result> { match plan { LogicalPlan::Sort(sort) => { - let normalized_sort_keys = sort + let len = sort.expr.len(); + let unique_exprs: Vec<_> = sort .expr - .iter() - .map(|e| match e { - Expr::Sort(ExprSort { expr, .. }) => { - Expr::Sort(ExprSort::new(expr.clone(), true, false)) - } - _ => e.clone(), - }) - .collect::>(); + .into_iter() + .map(SortExprWrapper) + .collect::>() + .into_iter() + .map(|wrapper| wrapper.0) + .collect(); - // dedup sort.expr and keep order - let mut dedup_expr = Vec::new(); - let mut dedup_set = HashSet::new(); - sort.expr.iter().zip(normalized_sort_keys.iter()).for_each( - |(expr, normalized_expr)| { - if !dedup_set.contains(normalized_expr) { - dedup_expr.push(expr); - dedup_set.insert(normalized_expr); - } - }, - ); - if dedup_expr.len() == sort.expr.len() { - Ok(None) + let transformed = if len != unique_exprs.len() { + Transformed::yes } else { - Ok(Some(LogicalPlan::Sort(Sort { - expr: dedup_expr.into_iter().cloned().collect::>(), - input: sort.input.clone(), - fetch: sort.fetch, - }))) - } + Transformed::no + }; + + Ok(transformed(LogicalPlan::Sort(Sort { + expr: unique_exprs, + input: sort.input, + fetch: sort.fetch, + }))) } LogicalPlan::Aggregate(agg) => { - // dedup agg.groupby and keep order - let mut dedup_expr = Vec::new(); - let mut dedup_set = HashSet::new(); - agg.group_expr.iter().for_each(|expr| { - if !dedup_set.contains(expr) { - dedup_expr.push(expr.clone()); - dedup_set.insert(expr); - } - }); - if dedup_expr.len() == agg.group_expr.len() { - Ok(None) + let len = agg.group_expr.len(); + + let unique_exprs: Vec = agg + .group_expr + .into_iter() + .collect::>() + .into_iter() + .collect(); + + let transformed = if len != unique_exprs.len() { + Transformed::yes } else { - Ok(Some(LogicalPlan::Aggregate(Aggregate::try_new( - agg.input.clone(), - dedup_expr, - agg.aggr_expr.clone(), - )?))) - } + Transformed::no + }; + + Aggregate::try_new(agg.input, unique_exprs, agg.aggr_expr) + .map(|f| transformed(LogicalPlan::Aggregate(f))) } - _ => Ok(None), + _ => Ok(Transformed::no(plan)), } } - fn name(&self) -> &str { "eliminate_duplicated_expr" } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } } #[cfg(test)] @@ -128,11 +132,11 @@ mod tests { fn eliminate_sort_expr() -> Result<()> { let table_scan = test_table_scan().unwrap(); let plan = LogicalPlanBuilder::from(table_scan) - .sort(vec![col("a"), col("a"), col("b"), col("c")])? + .sort_by(vec![col("a"), col("a"), col("b"), col("c")])? .limit(5, Some(10))? .build()?; let expected = "Limit: skip=5, fetch=10\ - \n Sort: test.a, test.b, test.c\ + \n Sort: test.a ASC NULLS LAST, test.b ASC NULLS LAST, test.c ASC NULLS LAST\ \n TableScan: test"; assert_optimized_plan_eq(plan, expected) } diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index 2bf5cfa30390..4ed2ac8ba1a4 100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -17,13 +17,12 @@ //! [`EliminateFilter`] replaces `where false` or `where null` with an empty relation. -use crate::optimizer::ApplyOrder; +use datafusion_common::tree_node::Transformed; use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::{ - logical_plan::{EmptyRelation, LogicalPlan}, - Expr, Filter, -}; +use datafusion_expr::{EmptyRelation, Expr, Filter, LogicalPlan}; +use std::sync::Arc; +use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; /// Optimization rule that eliminate the scalar value (true/false/null) filter @@ -31,7 +30,7 @@ use crate::{OptimizerConfig, OptimizerRule}; /// /// This saves time in planning and executing the query. /// Note that this rule should be applied after simplify expressions optimizer rule. -#[derive(Default)] +#[derive(Default, Debug)] pub struct EliminateFilter; impl EliminateFilter { @@ -42,54 +41,54 @@ impl EliminateFilter { } impl OptimizerRule for EliminateFilter { - fn try_optimize( + fn name(&self) -> &str { + "eliminate_filter" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } + + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( &self, - plan: &LogicalPlan, + plan: LogicalPlan, _config: &dyn OptimizerConfig, - ) -> Result> { + ) -> Result> { match plan { LogicalPlan::Filter(Filter { predicate: Expr::Literal(ScalarValue::Boolean(v)), input, .. - }) => { - match *v { - // input also can be filter, apply again - Some(true) => Ok(Some( - self.try_optimize(input, _config)? - .unwrap_or_else(|| input.as_ref().clone()), - )), - Some(false) | None => { - Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: false, - schema: input.schema().clone(), - }))) - } - } - } - _ => Ok(None), + }) => match v { + Some(true) => Ok(Transformed::yes(Arc::unwrap_or_clone(input))), + Some(false) | None => Ok(Transformed::yes(LogicalPlan::EmptyRelation( + EmptyRelation { + produce_one_row: false, + schema: Arc::clone(input.schema()), + }, + ))), + }, + _ => Ok(Transformed::no(plan)), } } - - fn name(&self) -> &str { - "eliminate_filter" - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } } #[cfg(test)] mod tests { - use crate::eliminate_filter::EliminateFilter; + use std::sync::Arc; + use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ - col, lit, logical_plan::builder::LogicalPlanBuilder, sum, Expr, LogicalPlan, + col, lit, logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan, }; - use std::sync::Arc; + use crate::eliminate_filter::EliminateFilter; use crate::test::*; + use datafusion_expr::test::function_stub::sum; fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(EliminateFilter::new()), plan, expected) @@ -97,7 +96,7 @@ mod tests { #[test] fn filter_false() -> Result<()> { - let filter_expr = Expr::Literal(ScalarValue::Boolean(Some(false))); + let filter_expr = lit(false); let table_scan = test_table_scan().unwrap(); let plan = LogicalPlanBuilder::from(table_scan) @@ -127,7 +126,7 @@ mod tests { #[test] fn filter_false_nested() -> Result<()> { - let filter_expr = Expr::Literal(ScalarValue::Boolean(Some(false))); + let filter_expr = lit(false); let table_scan = test_table_scan()?; let plan1 = LogicalPlanBuilder::from(table_scan.clone()) @@ -142,14 +141,14 @@ mod tests { // Left side is removed let expected = "Union\ \n EmptyRelation\ - \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ \n TableScan: test"; assert_optimized_plan_equal(plan, expected) } #[test] fn filter_true() -> Result<()> { - let filter_expr = Expr::Literal(ScalarValue::Boolean(Some(true))); + let filter_expr = lit(true); let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) @@ -157,14 +156,14 @@ mod tests { .filter(filter_expr)? .build()?; - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ + let expected = "Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ \n TableScan: test"; assert_optimized_plan_equal(plan, expected) } #[test] fn filter_true_nested() -> Result<()> { - let filter_expr = Expr::Literal(ScalarValue::Boolean(Some(true))); + let filter_expr = lit(true); let table_scan = test_table_scan()?; let plan1 = LogicalPlanBuilder::from(table_scan.clone()) @@ -178,9 +177,9 @@ mod tests { // Filter is removed let expected = "Union\ - \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ \n TableScan: test\ - \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ \n TableScan: test"; assert_optimized_plan_equal(plan, expected) } diff --git a/datafusion/optimizer/src/eliminate_group_by_constant.rs b/datafusion/optimizer/src/eliminate_group_by_constant.rs new file mode 100644 index 000000000000..13d03d647fe2 --- /dev/null +++ b/datafusion/optimizer/src/eliminate_group_by_constant.rs @@ -0,0 +1,312 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`EliminateGroupByConstant`] removes constant expressions from `GROUP BY` clause +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; + +use datafusion_common::tree_node::Transformed; +use datafusion_common::Result; +use datafusion_expr::{Aggregate, Expr, LogicalPlan, LogicalPlanBuilder, Volatility}; + +/// Optimizer rule that removes constant expressions from `GROUP BY` clause +/// and places additional projection on top of aggregation, to preserve +/// original schema +#[derive(Default, Debug)] +pub struct EliminateGroupByConstant {} + +impl EliminateGroupByConstant { + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for EliminateGroupByConstant { + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + match plan { + LogicalPlan::Aggregate(aggregate) => { + let (const_group_expr, nonconst_group_expr): (Vec<_>, Vec<_>) = aggregate + .group_expr + .iter() + .partition(|expr| is_constant_expression(expr)); + + // If no constant expressions found (nothing to optimize) or + // constant expression is the only expression in aggregate, + // optimization is skipped + if const_group_expr.is_empty() + || (!const_group_expr.is_empty() + && nonconst_group_expr.is_empty() + && aggregate.aggr_expr.is_empty()) + { + return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate))); + } + + let simplified_aggregate = LogicalPlan::Aggregate(Aggregate::try_new( + aggregate.input, + nonconst_group_expr.into_iter().cloned().collect(), + aggregate.aggr_expr.clone(), + )?); + + let projection_expr = + aggregate.group_expr.into_iter().chain(aggregate.aggr_expr); + + let projection = LogicalPlanBuilder::from(simplified_aggregate) + .project(projection_expr)? + .build()?; + + Ok(Transformed::yes(projection)) + } + _ => Ok(Transformed::no(plan)), + } + } + + fn name(&self) -> &str { + "eliminate_group_by_constant" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::BottomUp) + } +} + +/// Checks if expression is constant, and can be eliminated from group by. +/// +/// Intended to be used only within this rule, helper function, which heavily +/// reiles on `SimplifyExpressions` result. +fn is_constant_expression(expr: &Expr) -> bool { + match expr { + Expr::Alias(e) => is_constant_expression(&e.expr), + Expr::BinaryExpr(e) => { + is_constant_expression(&e.left) && is_constant_expression(&e.right) + } + Expr::Literal(_) => true, + Expr::ScalarFunction(e) => { + matches!( + e.func.signature().volatility, + Volatility::Immutable | Volatility::Stable + ) && e.args.iter().all(is_constant_expression) + } + _ => false, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test::*; + + use arrow::datatypes::DataType; + use datafusion_common::Result; + use datafusion_expr::expr::ScalarFunction; + use datafusion_expr::{ + col, lit, ColumnarValue, LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, + TypeSignature, + }; + + use datafusion_functions_aggregate::expr_fn::count; + + use std::sync::Arc; + + #[derive(Debug)] + struct ScalarUDFMock { + signature: Signature, + } + + impl ScalarUDFMock { + fn new_with_volatility(volatility: Volatility) -> Self { + Self { + signature: Signature::new(TypeSignature::Any(1), volatility), + } + } + } + + impl ScalarUDFImpl for ScalarUDFMock { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &str { + "scalar_fn_mock" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, _args: &[DataType]) -> Result { + Ok(DataType::Int32) + } + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + unimplemented!() + } + } + + #[test] + fn test_eliminate_gby_literal() -> Result<()> { + let scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(scan) + .aggregate(vec![col("a"), lit(1u32)], vec![count(col("c"))])? + .build()?; + + let expected = "\ + Projection: test.a, UInt32(1), count(test.c)\ + \n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\ + \n TableScan: test\ + "; + + assert_optimized_plan_eq( + Arc::new(EliminateGroupByConstant::new()), + plan, + expected, + ) + } + + #[test] + fn test_eliminate_constant() -> Result<()> { + let scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(scan) + .aggregate(vec![lit("test"), lit(123u32)], vec![count(col("c"))])? + .build()?; + + let expected = "\ + Projection: Utf8(\"test\"), UInt32(123), count(test.c)\ + \n Aggregate: groupBy=[[]], aggr=[[count(test.c)]]\ + \n TableScan: test\ + "; + + assert_optimized_plan_eq( + Arc::new(EliminateGroupByConstant::new()), + plan, + expected, + ) + } + + #[test] + fn test_no_op_no_constants() -> Result<()> { + let scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(scan) + .aggregate(vec![col("a"), col("b")], vec![count(col("c"))])? + .build()?; + + let expected = "\ + Aggregate: groupBy=[[test.a, test.b]], aggr=[[count(test.c)]]\ + \n TableScan: test\ + "; + + assert_optimized_plan_eq( + Arc::new(EliminateGroupByConstant::new()), + plan, + expected, + ) + } + + #[test] + fn test_no_op_only_constant() -> Result<()> { + let scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(scan) + .aggregate(vec![lit(123u32)], Vec::::new())? + .build()?; + + let expected = "\ + Aggregate: groupBy=[[UInt32(123)]], aggr=[[]]\ + \n TableScan: test\ + "; + + assert_optimized_plan_eq( + Arc::new(EliminateGroupByConstant::new()), + plan, + expected, + ) + } + + #[test] + fn test_eliminate_constant_with_alias() -> Result<()> { + let scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(scan) + .aggregate( + vec![lit(123u32).alias("const"), col("a")], + vec![count(col("c"))], + )? + .build()?; + + let expected = "\ + Projection: UInt32(123) AS const, test.a, count(test.c)\ + \n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\ + \n TableScan: test\ + "; + + assert_optimized_plan_eq( + Arc::new(EliminateGroupByConstant::new()), + plan, + expected, + ) + } + + #[test] + fn test_eliminate_scalar_fn_with_constant_arg() -> Result<()> { + let udf = ScalarUDF::new_from_impl(ScalarUDFMock::new_with_volatility( + Volatility::Immutable, + )); + let udf_expr = + Expr::ScalarFunction(ScalarFunction::new_udf(udf.into(), vec![lit(123u32)])); + let scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(scan) + .aggregate(vec![udf_expr, col("a")], vec![count(col("c"))])? + .build()?; + + let expected = "\ + Projection: scalar_fn_mock(UInt32(123)), test.a, count(test.c)\ + \n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\ + \n TableScan: test\ + "; + + assert_optimized_plan_eq( + Arc::new(EliminateGroupByConstant::new()), + plan, + expected, + ) + } + + #[test] + fn test_no_op_volatile_scalar_fn_with_constant_arg() -> Result<()> { + let udf = ScalarUDF::new_from_impl(ScalarUDFMock::new_with_volatility( + Volatility::Volatile, + )); + let udf_expr = + Expr::ScalarFunction(ScalarFunction::new_udf(udf.into(), vec![lit(123u32)])); + let scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(scan) + .aggregate(vec![udf_expr, col("a")], vec![count(col("c"))])? + .build()?; + + let expected = "\ + Aggregate: groupBy=[[scalar_fn_mock(UInt32(123)), test.a]], aggr=[[count(test.c)]]\ + \n TableScan: test\ + "; + + assert_optimized_plan_eq( + Arc::new(EliminateGroupByConstant::new()), + plan, + expected, + ) + } +} diff --git a/datafusion/optimizer/src/eliminate_join.rs b/datafusion/optimizer/src/eliminate_join.rs index fea87e758790..789235595dab 100644 --- a/datafusion/optimizer/src/eliminate_join.rs +++ b/datafusion/optimizer/src/eliminate_join.rs @@ -19,16 +19,16 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; -use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::JoinType::Inner; use datafusion_expr::{ logical_plan::{EmptyRelation, LogicalPlan}, - CrossJoin, Expr, + Expr, }; /// Eliminates joins when join condition is false. /// Replaces joins when inner join condition is true with a cross join. -#[derive(Default)] +#[derive(Default, Debug)] pub struct EliminateJoin; impl EliminateJoin { @@ -38,14 +38,6 @@ impl EliminateJoin { } impl OptimizerRule for EliminateJoin { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called EliminateJoin::rewrite") - } - fn name(&self) -> &str { "eliminate_join" } @@ -62,13 +54,6 @@ impl OptimizerRule for EliminateJoin { match plan { LogicalPlan::Join(join) if join.join_type == Inner && join.on.is_empty() => { match join.filter { - Some(Expr::Literal(ScalarValue::Boolean(Some(true)))) => { - Ok(Transformed::yes(LogicalPlan::CrossJoin(CrossJoin { - left: join.left, - right: join.right, - schema: join.schema, - }))) - } Some(Expr::Literal(ScalarValue::Boolean(Some(false)))) => Ok( Transformed::yes(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, @@ -91,9 +76,9 @@ impl OptimizerRule for EliminateJoin { mod tests { use crate::eliminate_join::EliminateJoin; use crate::test::*; - use datafusion_common::{Result, ScalarValue}; + use datafusion_common::Result; use datafusion_expr::JoinType::Inner; - use datafusion_expr::{logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan}; + use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder, LogicalPlan}; use std::sync::Arc; fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { @@ -106,28 +91,11 @@ mod tests { .join_on( LogicalPlanBuilder::empty(false).build()?, Inner, - Some(Expr::Literal(ScalarValue::Boolean(Some(false)))), + Some(lit(false)), )? .build()?; let expected = "EmptyRelation"; assert_optimized_plan_equal(plan, expected) } - - #[test] - fn join_on_true() -> Result<()> { - let plan = LogicalPlanBuilder::empty(false) - .join_on( - LogicalPlanBuilder::empty(false).build()?, - Inner, - Some(Expr::Literal(ScalarValue::Boolean(Some(true)))), - )? - .build()?; - - let expected = "\ - CrossJoin:\ - \n EmptyRelation\ - \n EmptyRelation"; - assert_optimized_plan_equal(plan, expected) - } } diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index 39231d784e00..267615c3e0d9 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -18,8 +18,10 @@ //! [`EliminateLimit`] eliminates `LIMIT` when possible use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_common::tree_node::Transformed; use datafusion_common::Result; -use datafusion_expr::logical_plan::{EmptyRelation, LogicalPlan}; +use datafusion_expr::logical_plan::{EmptyRelation, FetchType, LogicalPlan, SkipType}; +use std::sync::Arc; /// Optimizer rule to replace `LIMIT 0` or `LIMIT` whose ancestor LIMIT's skip is /// greater than or equal to current's fetch @@ -28,7 +30,7 @@ use datafusion_expr::logical_plan::{EmptyRelation, LogicalPlan}; /// plan with an empty relation. /// /// This rule also removes OFFSET 0 from the [LogicalPlan] -#[derive(Default)] +#[derive(Default, Debug)] pub struct EliminateLimit; impl EliminateLimit { @@ -39,36 +41,6 @@ impl EliminateLimit { } impl OptimizerRule for EliminateLimit { - fn try_optimize( - &self, - plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - if let LogicalPlan::Limit(limit) = plan { - match limit.fetch { - Some(fetch) => { - if fetch == 0 { - return Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: false, - schema: limit.input.schema().clone(), - }))); - } - } - None => { - if limit.skip == 0 { - let input = limit.input.as_ref(); - // input also can be Limit, so we should apply again. - return Ok(Some( - self.try_optimize(input, _config)? - .unwrap_or_else(|| input.clone()), - )); - } - } - } - } - Ok(None) - } - fn name(&self) -> &str { "eliminate_limit" } @@ -76,6 +48,42 @@ impl OptimizerRule for EliminateLimit { fn apply_order(&self) -> Option { Some(ApplyOrder::BottomUp) } + + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result, datafusion_common::DataFusionError> { + match plan { + LogicalPlan::Limit(limit) => { + // Only supports rewriting for literal fetch + let FetchType::Literal(fetch) = limit.get_fetch_type()? else { + return Ok(Transformed::no(LogicalPlan::Limit(limit))); + }; + + if let Some(v) = fetch { + if v == 0 { + return Ok(Transformed::yes(LogicalPlan::EmptyRelation( + EmptyRelation { + produce_one_row: false, + schema: Arc::clone(limit.input.schema()), + }, + ))); + } + } else if matches!(limit.get_skip_type()?, SkipType::Literal(0)) { + // If fetch is `None` and skip is 0, then Limit takes no effect and + // we can remove it. Its input also can be Limit, so we should apply again. + return self.rewrite(Arc::unwrap_or_clone(limit.input), _config); + } + Ok(Transformed::no(LogicalPlan::Limit(limit))) + } + _ => Ok(Transformed::no(plan)), + } + } } #[cfg(test)] @@ -88,11 +96,11 @@ mod tests { use datafusion_expr::{ col, logical_plan::{builder::LogicalPlanBuilder, JoinType}, - sum, }; use std::sync::Arc; use crate::push_down_limit::PushDownLimit; + use datafusion_expr::test::function_stub::sum; fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { @@ -100,7 +108,7 @@ mod tests { let optimized_plan = optimizer.optimize(plan, &OptimizerContext::new(), observe)?; - let formatted_plan = format!("{optimized_plan:?}"); + let formatted_plan = format!("{optimized_plan}"); assert_eq!(formatted_plan, expected); Ok(()) } @@ -118,7 +126,7 @@ mod tests { let optimized_plan = optimizer .optimize(plan, &config, observe) .expect("failed to optimize plan"); - let formatted_plan = format!("{optimized_plan:?}"); + let formatted_plan = format!("{optimized_plan}"); assert_eq!(formatted_plan, expected); Ok(()) } @@ -150,7 +158,7 @@ mod tests { // Left side is removed let expected = "Union\ \n EmptyRelation\ - \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ \n TableScan: test"; assert_optimized_plan_eq(plan, expected) } @@ -175,16 +183,16 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("a")], vec![sum(col("b"))])? .limit(0, Some(2))? - .sort(vec![col("a")])? + .sort_by(vec![col("a")])? .limit(2, Some(1))? .build()?; // After remove global-state, we don't record the parent // So, bottom don't know parent info, so can't eliminate. let expected = "Limit: skip=2, fetch=1\ - \n Sort: test.a, fetch=3\ + \n Sort: test.a ASC NULLS LAST, fetch=3\ \n Limit: skip=0, fetch=2\ - \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ \n TableScan: test"; assert_optimized_plan_eq_with_pushdown(plan, expected) } @@ -195,14 +203,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("a")], vec![sum(col("b"))])? .limit(0, Some(2))? - .sort(vec![col("a")])? + .sort_by(vec![col("a")])? .limit(0, Some(1))? .build()?; let expected = "Limit: skip=0, fetch=1\ - \n Sort: test.a\ + \n Sort: test.a ASC NULLS LAST\ \n Limit: skip=0, fetch=2\ - \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ \n TableScan: test"; assert_optimized_plan_eq(plan, expected) } @@ -213,14 +221,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("a")], vec![sum(col("b"))])? .limit(2, Some(1))? - .sort(vec![col("a")])? + .sort_by(vec![col("a")])? .limit(3, Some(1))? .build()?; let expected = "Limit: skip=3, fetch=1\ - \n Sort: test.a\ + \n Sort: test.a ASC NULLS LAST\ \n Limit: skip=2, fetch=1\ - \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ \n TableScan: test"; assert_optimized_plan_eq(plan, expected) } @@ -255,7 +263,7 @@ mod tests { .limit(0, None)? .build()?; - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ + let expected = "Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\ \n TableScan: test"; assert_optimized_plan_eq(plan, expected) } diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs index da2a6a17214e..94da08243d78 100644 --- a/datafusion/optimizer/src/eliminate_nested_union.rs +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -18,12 +18,14 @@ //! [`EliminateNestedUnion`]: flattens nested `Union` to a single `Union` use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_common::tree_node::Transformed; use datafusion_common::Result; use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; use datafusion_expr::{Distinct, LogicalPlan, Union}; +use itertools::Itertools; use std::sync::Arc; -#[derive(Default)] +#[derive(Default, Debug)] /// An optimization rule that replaces nested unions with a single union. pub struct EliminateNestedUnion; @@ -35,75 +37,88 @@ impl EliminateNestedUnion { } impl OptimizerRule for EliminateNestedUnion { - fn try_optimize( + fn name(&self) -> &str { + "eliminate_nested_union" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::BottomUp) + } + + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( &self, - plan: &LogicalPlan, + plan: LogicalPlan, _config: &dyn OptimizerConfig, - ) -> Result> { + ) -> Result> { match plan { LogicalPlan::Union(Union { inputs, schema }) => { let inputs = inputs - .iter() + .into_iter() .flat_map(extract_plans_from_union) - .collect::>(); + .map(|plan| coerce_plan_expr_for_schema(plan, &schema)) + .collect::>>()?; - Ok(Some(LogicalPlan::Union(Union { - inputs, - schema: schema.clone(), + Ok(Transformed::yes(LogicalPlan::Union(Union { + inputs: inputs.into_iter().map(Arc::new).collect_vec(), + schema, }))) } - LogicalPlan::Distinct(Distinct::All(plan)) => match plan.as_ref() { - LogicalPlan::Union(Union { inputs, schema }) => { - let inputs = inputs - .iter() - .map(extract_plan_from_distinct) - .flat_map(extract_plans_from_union) - .collect::>(); - - Ok(Some(LogicalPlan::Distinct(Distinct::All(Arc::new( - LogicalPlan::Union(Union { - inputs, - schema: schema.clone(), - }), - ))))) + LogicalPlan::Distinct(Distinct::All(nested_plan)) => { + match Arc::unwrap_or_clone(nested_plan) { + LogicalPlan::Union(Union { inputs, schema }) => { + let inputs = inputs + .into_iter() + .map(extract_plan_from_distinct) + .flat_map(extract_plans_from_union) + .map(|plan| coerce_plan_expr_for_schema(plan, &schema)) + .collect::>>()?; + + Ok(Transformed::yes(LogicalPlan::Distinct(Distinct::All( + Arc::new(LogicalPlan::Union(Union { + inputs: inputs.into_iter().map(Arc::new).collect_vec(), + schema: Arc::clone(&schema), + })), + )))) + } + nested_plan => Ok(Transformed::no(LogicalPlan::Distinct( + Distinct::All(Arc::new(nested_plan)), + ))), } - _ => Ok(None), - }, - _ => Ok(None), + } + _ => Ok(Transformed::no(plan)), } } - - fn name(&self) -> &str { - "eliminate_nested_union" - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::BottomUp) - } } -fn extract_plans_from_union(plan: &Arc) -> Vec> { - match plan.as_ref() { - LogicalPlan::Union(Union { inputs, schema }) => inputs - .iter() - .map(|plan| Arc::new(coerce_plan_expr_for_schema(plan, schema).unwrap())) +fn extract_plans_from_union(plan: Arc) -> Vec { + match Arc::unwrap_or_clone(plan) { + LogicalPlan::Union(Union { inputs, .. }) => inputs + .into_iter() + .map(Arc::unwrap_or_clone) .collect::>(), - _ => vec![plan.clone()], + plan => vec![plan], } } -fn extract_plan_from_distinct(plan: &Arc) -> &Arc { - match plan.as_ref() { +fn extract_plan_from_distinct(plan: Arc) -> Arc { + match Arc::unwrap_or_clone(plan) { LogicalPlan::Distinct(Distinct::All(plan)) => plan, - _ => plan, + plan => Arc::new(plan), } } #[cfg(test)] mod tests { use super::*; + use crate::analyzer::type_coercion::TypeCoercion; + use crate::analyzer::Analyzer; use crate::test::*; use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::config::ConfigOptions; use datafusion_expr::{col, logical_plan::table_scan}; fn schema() -> Schema { @@ -115,17 +130,21 @@ mod tests { } fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq(Arc::new(EliminateNestedUnion::new()), plan, expected) + let options = ConfigOptions::default(); + let analyzed_plan = Analyzer::with_rules(vec![Arc::new(TypeCoercion::new())]) + .execute_and_check(plan, &options, |_, _| {})?; + assert_optimized_plan_eq( + Arc::new(EliminateNestedUnion::new()), + analyzed_plan, + expected, + ) } #[test] fn eliminate_nothing() -> Result<()> { let plan_builder = table_scan(Some("table"), &schema(), None)?; - let plan = plan_builder - .clone() - .union(plan_builder.clone().build()?)? - .build()?; + let plan = plan_builder.clone().union(plan_builder.build()?)?.build()?; let expected = "\ Union\ @@ -140,7 +159,7 @@ mod tests { let plan = plan_builder .clone() - .union_distinct(plan_builder.clone().build()?)? + .union_distinct(plan_builder.build()?)? .build()?; let expected = "Distinct:\ @@ -158,7 +177,7 @@ mod tests { .clone() .union(plan_builder.clone().build()?)? .union(plan_builder.clone().build()?)? - .union(plan_builder.clone().build()?)? + .union(plan_builder.build()?)? .build()?; let expected = "\ @@ -178,7 +197,7 @@ mod tests { .clone() .union_distinct(plan_builder.clone().build()?)? .union(plan_builder.clone().build()?)? - .union(plan_builder.clone().build()?)? + .union(plan_builder.build()?)? .build()?; let expected = "Union\ @@ -200,7 +219,7 @@ mod tests { .union(plan_builder.clone().build()?)? .union_distinct(plan_builder.clone().build()?)? .union(plan_builder.clone().build()?)? - .union_distinct(plan_builder.clone().build()?)? + .union_distinct(plan_builder.build()?)? .build()?; let expected = "Distinct:\ @@ -221,7 +240,7 @@ mod tests { .clone() .union_distinct(plan_builder.clone().distinct()?.build()?)? .union(plan_builder.clone().distinct()?.build()?)? - .union_distinct(plan_builder.clone().build()?)? + .union_distinct(plan_builder.build()?)? .build()?; let expected = "Distinct:\ @@ -249,7 +268,6 @@ mod tests { )? .union( plan_builder - .clone() .project(vec![col("id").alias("_id"), col("key"), col("value")])? .build()?, )? @@ -278,7 +296,6 @@ mod tests { )? .union_distinct( plan_builder - .clone() .project(vec![col("id").alias("_id"), col("key"), col("value")])? .build()?, )? diff --git a/datafusion/optimizer/src/eliminate_one_union.rs b/datafusion/optimizer/src/eliminate_one_union.rs index 68d0ddba8b20..3e027811420c 100644 --- a/datafusion/optimizer/src/eliminate_one_union.rs +++ b/datafusion/optimizer/src/eliminate_one_union.rs @@ -16,13 +16,15 @@ // under the License. //! [`EliminateOneUnion`] eliminates single element `Union` + use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{internal_err, tree_node::Transformed, Result}; -use datafusion_expr::logical_plan::{tree_node::unwrap_arc, LogicalPlan, Union}; +use datafusion_common::{tree_node::Transformed, Result}; +use datafusion_expr::logical_plan::{LogicalPlan, Union}; +use std::sync::Arc; use crate::optimizer::ApplyOrder; -#[derive(Default)] +#[derive(Default, Debug)] /// An optimization rule that eliminates union with one element. pub struct EliminateOneUnion; @@ -34,14 +36,6 @@ impl EliminateOneUnion { } impl OptimizerRule for EliminateOneUnion { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called EliminateOneUnion::rewrite") - } - fn name(&self) -> &str { "eliminate_one_union" } @@ -56,9 +50,9 @@ impl OptimizerRule for EliminateOneUnion { _config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Union(Union { mut inputs, .. }) if inputs.len() == 1 => { - Ok(Transformed::yes(unwrap_arc(inputs.pop().unwrap()))) - } + LogicalPlan::Union(Union { mut inputs, .. }) if inputs.len() == 1 => Ok( + Transformed::yes(Arc::unwrap_or_clone(inputs.pop().unwrap())), + ), _ => Ok(Transformed::no(plan)), } } @@ -88,10 +82,11 @@ mod tests { } fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { - assert_optimized_plan_eq_with_rules( + assert_optimized_plan_with_rules( vec![Arc::new(EliminateOneUnion::new())], plan, expected, + true, ) } @@ -99,10 +94,7 @@ mod tests { fn eliminate_nothing() -> Result<()> { let plan_builder = table_scan(Some("table"), &schema(), None)?; - let plan = plan_builder - .clone() - .union(plan_builder.clone().build()?)? - .build()?; + let plan = plan_builder.clone().union(plan_builder.build()?)?.build()?; let expected = "\ Union\ @@ -114,10 +106,10 @@ mod tests { #[test] fn eliminate_one_union() -> Result<()> { let table_plan = coerce_plan_expr_for_schema( - &table_scan(Some("table"), &schema(), None)?.build()?, + table_scan(Some("table"), &schema(), None)?.build()?, &schema().to_dfschema()?, )?; - let schema = table_plan.schema().clone(); + let schema = Arc::clone(table_plan.schema()); let single_union_plan = LogicalPlan::Union(Union { inputs: vec![Arc::new(table_plan)], schema, diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index c3c5d80922f9..1ecb32ca2a43 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -17,7 +17,7 @@ //! [`EliminateOuterJoin`] converts `LEFT/RIGHT/FULL` joins to `INNER` joins use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{internal_err, Column, DFSchema, Result}; +use datafusion_common::{Column, DFSchema, Result}; use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan}; use datafusion_expr::{Expr, Filter, Operator}; @@ -26,7 +26,6 @@ use datafusion_common::tree_node::Transformed; use datafusion_expr::expr::{BinaryExpr, Cast, TryCast}; use std::sync::Arc; -#[derive(Default)] /// /// Attempt to replace outer joins with inner joins. /// @@ -49,6 +48,7 @@ use std::sync::Arc; /// filters from the WHERE clause return false while any inputs are /// null and columns of those quals are come from nullable side of /// outer join. +#[derive(Default, Debug)] pub struct EliminateOuterJoin; impl EliminateOuterJoin { @@ -60,14 +60,6 @@ impl EliminateOuterJoin { /// Attempt to eliminate outer joins. impl OptimizerRule for EliminateOuterJoin { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called EliminateOuterJoin::rewrite") - } - fn name(&self) -> &str { "eliminate_outer_join" } @@ -86,7 +78,7 @@ impl OptimizerRule for EliminateOuterJoin { _config: &dyn OptimizerConfig, ) -> Result> { match plan { - LogicalPlan::Filter(filter) => match filter.input.as_ref() { + LogicalPlan::Filter(mut filter) => match Arc::unwrap_or_clone(filter.input) { LogicalPlan::Join(join) => { let mut non_nullable_cols: Vec = vec![]; @@ -117,20 +109,24 @@ impl OptimizerRule for EliminateOuterJoin { } else { join.join_type }; + let new_join = Arc::new(LogicalPlan::Join(Join { - left: Arc::new((*join.left).clone()), - right: Arc::new((*join.right).clone()), + left: join.left, + right: join.right, join_type: new_join_type, join_constraint: join.join_constraint, on: join.on.clone(), filter: join.filter.clone(), - schema: join.schema.clone(), + schema: Arc::clone(&join.schema), null_equals_null: join.null_equals_null, })); Filter::try_new(filter.predicate, new_join) .map(|f| Transformed::yes(LogicalPlan::Filter(f))) } - _ => Ok(Transformed::no(LogicalPlan::Filter(filter))), + filter_input => { + filter.input = Arc::new(filter_input); + Ok(Transformed::no(LogicalPlan::Filter(filter))) + } }, _ => Ok(Transformed::no(plan)), } diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index c47a86974cd2..48191ec20631 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -19,12 +19,11 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; +use datafusion_common::DFSchema; use datafusion_common::Result; -use datafusion_common::{internal_err, DFSchema}; use datafusion_expr::utils::split_conjunction_owned; use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair}; use datafusion_expr::{BinaryExpr, Expr, ExprSchemable, Join, LogicalPlan, Operator}; -use std::sync::Arc; // equijoin predicate type EquijoinPredicate = (Expr, Expr); @@ -39,7 +38,7 @@ type EquijoinPredicate = (Expr, Expr); /// has one equijoin predicate (`A.x = B.y`) and one filter predicate (`B.z > 50`). /// See [find_valid_equijoin_key_pair] for more information on what predicates /// are considered equijoins. -#[derive(Default)] +#[derive(Default, Debug)] pub struct ExtractEquijoinPredicate; impl ExtractEquijoinPredicate { @@ -50,13 +49,6 @@ impl ExtractEquijoinPredicate { } impl OptimizerRule for ExtractEquijoinPredicate { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called ExtractEquijoinPredicate::rewrite") - } fn supports_rewrite(&self) -> bool { true } @@ -122,8 +114,8 @@ impl OptimizerRule for ExtractEquijoinPredicate { fn split_eq_and_noneq_join_predicate( filter: Expr, - left_schema: &Arc, - right_schema: &Arc, + left_schema: &DFSchema, + right_schema: &DFSchema, ) -> Result<(Vec, Option)> { let exprs = split_conjunction_owned(filter); @@ -136,12 +128,8 @@ fn split_eq_and_noneq_join_predicate( op: Operator::Eq, ref right, }) => { - let join_key_pair = find_valid_equijoin_key_pair( - left, - right, - left_schema.clone(), - right_schema.clone(), - )?; + let join_key_pair = + find_valid_equijoin_key_pair(left, right, left_schema, right_schema)?; if let Some((left_expr, right_expr)) = join_key_pair { let left_expr_type = left_expr.get_type(left_schema)?; @@ -172,6 +160,7 @@ mod tests { use datafusion_expr::{ col, lit, logical_plan::builder::LogicalPlanBuilder, JoinType, }; + use std::sync::Arc; fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( @@ -368,8 +357,8 @@ mod tests { let t1 = test_table_scan_with_name("t1")?; let t2 = test_table_scan_with_name("t2")?; - let t1_schema = t1.schema().clone(); - let t2_schema = t2.schema().clone(); + let t1_schema = Arc::clone(t1.schema()); + let t2_schema = Arc::clone(t2.schema()); // filter: t1.a + CAST(Int64(1), UInt32) = t2.a + CAST(Int64(2), UInt32) as t1.a + 1 = t2.a + 2 let filter = Expr::eq( diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs index ecd1901abe58..66c7463c3d5d 100644 --- a/datafusion/optimizer/src/filter_null_join_keys.rs +++ b/datafusion/optimizer/src/filter_null_join_keys.rs @@ -18,31 +18,21 @@ //! [`FilterNullJoinKeys`] adds filters to join inputs when input isn't nullable use crate::optimizer::ApplyOrder; +use crate::push_down_filter::on_lr_is_preserved; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; -use datafusion_common::{internal_err, Result}; +use datafusion_common::Result; use datafusion_expr::utils::conjunction; -use datafusion_expr::{ - logical_plan::Filter, logical_plan::JoinType, Expr, ExprSchemable, LogicalPlan, -}; +use datafusion_expr::{logical_plan::Filter, Expr, ExprSchemable, LogicalPlan}; use std::sync::Arc; -/// The FilterNullJoinKeys rule will identify inner joins with equi-join conditions -/// where the join key is nullable on one side and non-nullable on the other side -/// and then insert an `IsNotNull` filter on the nullable side since null values +/// The FilterNullJoinKeys rule will identify joins with equi-join conditions +/// where the join key is nullable and then insert an `IsNotNull` filter on the nullable side since null values /// can never match. -#[derive(Default)] +#[derive(Default, Debug)] pub struct FilterNullJoinKeys {} impl OptimizerRule for FilterNullJoinKeys { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called FilterNullJoinKeys::rewrite") - } - fn supports_rewrite(&self) -> bool { true } @@ -59,9 +49,13 @@ impl OptimizerRule for FilterNullJoinKeys { if !config.options().optimizer.filter_null_join_keys { return Ok(Transformed::no(plan)); } - match plan { - LogicalPlan::Join(mut join) if join.join_type == JoinType::Inner => { + LogicalPlan::Join(mut join) + if !join.on.is_empty() && !join.null_equals_null => + { + let (left_preserved, right_preserved) = + on_lr_is_preserved(join.join_type); + let left_schema = join.left.schema(); let right_schema = join.right.schema(); @@ -69,11 +63,11 @@ impl OptimizerRule for FilterNullJoinKeys { let mut right_filters = vec![]; for (l, r) in &join.on { - if l.nullable(left_schema)? { + if left_preserved && l.nullable(left_schema)? { left_filters.push(l.clone()); } - if r.nullable(right_schema)? { + if right_preserved && r.nullable(right_schema)? { right_filters.push(r.clone()); } } @@ -117,7 +111,7 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::Column; use datafusion_expr::logical_plan::table_scan; - use datafusion_expr::{col, lit, LogicalPlanBuilder}; + use datafusion_expr::{col, lit, JoinType, LogicalPlanBuilder}; fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(FilterNullJoinKeys {}), plan, expected) @@ -126,7 +120,7 @@ mod tests { #[test] fn left_nullable() -> Result<()> { let (t1, t2) = test_tables()?; - let plan = build_plan(t1, t2, "t1.optional_id", "t2.id")?; + let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Inner)?; let expected = "Inner Join: t1.optional_id = t2.id\ \n Filter: t1.optional_id IS NOT NULL\ \n TableScan: t1\ @@ -134,10 +128,33 @@ mod tests { assert_optimized_plan_equal(plan, expected) } + #[test] + fn left_nullable_left_join() -> Result<()> { + let (t1, t2) = test_tables()?; + let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Left)?; + let expected = "Left Join: t1.optional_id = t2.id\ + \n TableScan: t1\ + \n TableScan: t2"; + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn left_nullable_left_join_reordered() -> Result<()> { + let (t_left, t_right) = test_tables()?; + // Note: order of tables is reversed + let plan = + build_plan(t_right, t_left, "t2.id", "t1.optional_id", JoinType::Left)?; + let expected = "Left Join: t2.id = t1.optional_id\ + \n TableScan: t2\ + \n Filter: t1.optional_id IS NOT NULL\ + \n TableScan: t1"; + assert_optimized_plan_equal(plan, expected) + } + #[test] fn left_nullable_on_condition_reversed() -> Result<()> { let (t1, t2) = test_tables()?; - let plan = build_plan(t1, t2, "t2.id", "t1.optional_id")?; + let plan = build_plan(t1, t2, "t2.id", "t1.optional_id", JoinType::Inner)?; let expected = "Inner Join: t1.optional_id = t2.id\ \n Filter: t1.optional_id IS NOT NULL\ \n TableScan: t1\ @@ -148,7 +165,7 @@ mod tests { #[test] fn nested_join_multiple_filter_expr() -> Result<()> { let (t1, t2) = test_tables()?; - let plan = build_plan(t1, t2, "t1.optional_id", "t2.id")?; + let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Inner)?; let schema = Schema::new(vec![ Field::new("id", DataType::UInt32, false), Field::new("t1_id", DataType::UInt32, true), @@ -252,11 +269,12 @@ mod tests { right_table: LogicalPlan, left_key: &str, right_key: &str, + join_type: JoinType, ) -> Result { LogicalPlanBuilder::from(left_table) .join( right_table, - JoinType::Inner, + join_type, ( vec![Column::from_qualified_name(left_key)], vec![Column::from_qualified_name(right_key)], diff --git a/datafusion/optimizer/src/join_key_set.rs b/datafusion/optimizer/src/join_key_set.rs new file mode 100644 index 000000000000..c0eec78b183d --- /dev/null +++ b/datafusion/optimizer/src/join_key_set.rs @@ -0,0 +1,291 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [JoinKeySet] for tracking the set of join keys in a plan. + +use datafusion_expr::Expr; +use indexmap::{Equivalent, IndexSet}; + +/// Tracks a set of equality Join keys +/// +/// A join key is an expression that is used to join two tables via an equality +/// predicate such as `a.x = b.y` +/// +/// This struct models `a.x + 5 = b.y AND a.z = b.z` as two join keys +/// 1. `(a.x + 5, b.y)` +/// 2. `(a.z, b.z)` +/// +/// # Important properties: +/// +/// 1. Retains insert order +/// 2. Can quickly look up if a pair of expressions are in the set. +#[derive(Debug)] +pub struct JoinKeySet { + inner: IndexSet<(Expr, Expr)>, +} + +impl JoinKeySet { + /// Create a new empty set + pub fn new() -> Self { + Self { + inner: IndexSet::new(), + } + } + + /// Return true if the set contains a join pair + /// where left = right or right = left + pub fn contains(&self, left: &Expr, right: &Expr) -> bool { + self.inner.contains(&ExprPair::new(left, right)) + || self.inner.contains(&ExprPair::new(right, left)) + } + + /// Insert the join key `(left = right)` into the set if join pair `(right = + /// left)` is not already in the set + /// + /// returns true if the pair was inserted + pub fn insert(&mut self, left: &Expr, right: &Expr) -> bool { + if self.contains(left, right) { + false + } else { + self.inner.insert((left.clone(), right.clone())); + true + } + } + + /// Same as [`Self::insert`] but avoids cloning expression if they + /// are owned + pub fn insert_owned(&mut self, left: Expr, right: Expr) -> bool { + if self.contains(&left, &right) { + false + } else { + self.inner.insert((left, right)); + true + } + } + + /// Inserts potentially many join keys into the set, copying only when necessary + /// + /// returns true if any of the pairs were inserted + pub fn insert_all<'a>( + &mut self, + iter: impl IntoIterator, + ) -> bool { + let mut inserted = false; + for (left, right) in iter.into_iter() { + inserted |= self.insert(left, right); + } + inserted + } + + /// Same as [`Self::insert_all`] but avoids cloning expressions if they are + /// already owned + /// + /// returns true if any of the pairs were inserted + pub fn insert_all_owned( + &mut self, + iter: impl IntoIterator, + ) -> bool { + let mut inserted = false; + for (left, right) in iter.into_iter() { + inserted |= self.insert_owned(left, right); + } + inserted + } + + /// Inserts any join keys that are common to both `s1` and `s2` into self + pub fn insert_intersection(&mut self, s1: &JoinKeySet, s2: &JoinKeySet) { + // note can't use inner.intersection as we need to consider both (l, r) + // and (r, l) in equality + for (left, right) in s1.inner.iter() { + if s2.contains(left, right) { + self.insert(left, right); + } + } + } + + /// returns true if this set is empty + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + /// Return the length of this set + #[cfg(test)] + pub fn len(&self) -> usize { + self.inner.len() + } + + /// Return an iterator over the join keys in this set + pub fn iter(&self) -> impl Iterator { + self.inner.iter().map(|(l, r)| (l, r)) + } +} + +/// Custom comparison operation to avoid copying owned values +/// +/// This behaves like a `(Expr, Expr)` tuple for hashing and comparison, but +/// avoids copying the values simply to comparing them. + +#[derive(Debug, Eq, PartialEq, Hash)] +struct ExprPair<'a>(&'a Expr, &'a Expr); + +impl<'a> ExprPair<'a> { + fn new(left: &'a Expr, right: &'a Expr) -> Self { + Self(left, right) + } +} + +impl<'a> Equivalent<(Expr, Expr)> for ExprPair<'a> { + fn equivalent(&self, other: &(Expr, Expr)) -> bool { + self.0 == &other.0 && self.1 == &other.1 + } +} + +#[cfg(test)] +mod test { + use crate::join_key_set::JoinKeySet; + use datafusion_expr::{col, Expr}; + + #[test] + fn test_insert() { + let mut set = JoinKeySet::new(); + // new sets should be empty + assert!(set.is_empty()); + + // insert (a = b) + assert!(set.insert(&col("a"), &col("b"))); + assert!(!set.is_empty()); + + // insert (a=b) again returns false + assert!(!set.insert(&col("a"), &col("b"))); + assert_eq!(set.len(), 1); + + // insert (b = a) , should be considered equivalent + assert!(!set.insert(&col("b"), &col("a"))); + assert_eq!(set.len(), 1); + + // insert (a = c) should be considered different + assert!(set.insert(&col("a"), &col("c"))); + assert_eq!(set.len(), 2); + } + + #[test] + fn test_insert_owned() { + let mut set = JoinKeySet::new(); + assert!(set.insert_owned(col("a"), col("b"))); + assert!(set.contains(&col("a"), &col("b"))); + assert!(set.contains(&col("b"), &col("a"))); + assert!(!set.contains(&col("a"), &col("c"))); + } + + #[test] + fn test_contains() { + let mut set = JoinKeySet::new(); + assert!(set.insert(&col("a"), &col("b"))); + assert!(set.contains(&col("a"), &col("b"))); + assert!(set.contains(&col("b"), &col("a"))); + assert!(!set.contains(&col("a"), &col("c"))); + + assert!(set.insert(&col("a"), &col("c"))); + assert!(set.contains(&col("a"), &col("c"))); + assert!(set.contains(&col("c"), &col("a"))); + } + + #[test] + fn test_iterator() { + // put in c = a and + let mut set = JoinKeySet::new(); + // put in c = a , b = c, and a = c and expect to get only the first 2 + set.insert(&col("c"), &col("a")); + set.insert(&col("b"), &col("c")); + set.insert(&col("a"), &col("c")); + assert_contents(&set, vec![(&col("c"), &col("a")), (&col("b"), &col("c"))]); + } + + #[test] + fn test_insert_intersection() { + // a = b, b = c, c = d + let mut set1 = JoinKeySet::new(); + set1.insert(&col("a"), &col("b")); + set1.insert(&col("b"), &col("c")); + set1.insert(&col("c"), &col("d")); + + // a = a, b = b, b = c, d = c + // should only intersect on b = c and c = d + let mut set2 = JoinKeySet::new(); + set2.insert(&col("a"), &col("a")); + set2.insert(&col("b"), &col("b")); + set2.insert(&col("b"), &col("c")); + set2.insert(&col("d"), &col("c")); + + let mut set = JoinKeySet::new(); + // put something in there already + set.insert(&col("x"), &col("y")); + set.insert_intersection(&set1, &set2); + + assert_contents( + &set, + vec![ + (&col("x"), &col("y")), + (&col("b"), &col("c")), + (&col("c"), &col("d")), + ], + ); + } + + fn assert_contents(set: &JoinKeySet, expected: Vec<(&Expr, &Expr)>) { + let contents: Vec<_> = set.iter().collect(); + assert_eq!(contents, expected); + } + + #[test] + fn test_insert_all() { + let mut set = JoinKeySet::new(); + + // insert (a=b), (b=c), (b=a) + set.insert_all(vec![ + &(col("a"), col("b")), + &(col("b"), col("c")), + &(col("b"), col("a")), + ]); + assert_eq!(set.len(), 2); + assert!(set.contains(&col("a"), &col("b"))); + assert!(set.contains(&col("b"), &col("c"))); + assert!(set.contains(&col("b"), &col("a"))); + + // should not contain (a=c) + assert!(!set.contains(&col("a"), &col("c"))); + } + + #[test] + fn test_insert_all_owned() { + let mut set = JoinKeySet::new(); + + // insert (a=b), (b=c), (b=a) + set.insert_all_owned(vec![ + (col("a"), col("b")), + (col("b"), col("c")), + (col("b"), col("a")), + ]); + assert_eq!(set.len(), 2); + assert!(set.contains(&col("a"), &col("b"))); + assert!(set.contains(&col("b"), &col("c"))); + assert!(set.contains(&col("b"), &col("a"))); + + // should not contain (a=c) + assert!(!set.contains(&col("a"), &col("c"))); + } +} diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 9176d67c1d18..f31083831125 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -14,17 +14,19 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] //! # DataFusion Optimizer //! //! Contains rules for rewriting [`LogicalPlan`]s //! //! 1. [`Analyzer`] applies [`AnalyzerRule`]s to transform `LogicalPlan`s -//! to make the plan valid prior to the rest of the DataFusion optimization -//! process (for example, [`TypeCoercion`]). +//! to make the plan valid prior to the rest of the DataFusion optimization +//! process (for example, [`TypeCoercion`]). //! //! 2. [`Optimizer`] applies [`OptimizerRule`]s to transform `LogicalPlan`s -//! into equivalent, but more efficient plans. +//! into equivalent, but more efficient plans. //! //! [`LogicalPlan`]: datafusion_expr::LogicalPlan //! [`TypeCoercion`]: analyzer::type_coercion::TypeCoercion @@ -35,6 +37,7 @@ pub mod decorrelate_predicate_subquery; pub mod eliminate_cross_join; pub mod eliminate_duplicated_expr; pub mod eliminate_filter; +pub mod eliminate_group_by_constant; pub mod eliminate_join; pub mod eliminate_limit; pub mod eliminate_nested_union; @@ -48,7 +51,6 @@ pub mod propagate_empty_relation; pub mod push_down_filter; pub mod push_down_limit; pub mod replace_distinct_aggregate; -pub mod rewrite_disjunctive_predicate; pub mod scalar_subquery_to_join; pub mod simplify_expressions; pub mod single_distinct_to_groupby; @@ -60,8 +62,10 @@ pub mod test; pub use analyzer::{Analyzer, AnalyzerRule}; pub use optimizer::{Optimizer, OptimizerConfig, OptimizerContext, OptimizerRule}; +#[allow(deprecated)] pub use utils::optimize_children; +pub(crate) mod join_key_set; mod plan_signature; #[cfg(test)] diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections/mod.rs similarity index 70% rename from datafusion/optimizer/src/optimize_projections.rs rename to datafusion/optimizer/src/optimize_projections/mod.rs index 70ffd8f24498..ec2225bbc042 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -17,26 +17,30 @@ //! [`OptimizeProjections`] identifies and eliminates unused columns -use std::collections::HashSet; +mod required_indices; + +use std::collections::{HashMap, HashSet}; use std::sync::Arc; use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; -use arrow::datatypes::SchemaRef; use datafusion_common::{ - get_required_group_by_exprs_indices, internal_err, Column, DFSchema, DFSchemaRef, + get_required_group_by_exprs_indices, internal_datafusion_err, internal_err, Column, JoinType, Result, }; -use datafusion_expr::expr::{Alias, ScalarFunction}; +use datafusion_expr::expr::Alias; +use datafusion_expr::Unnest; use datafusion_expr::{ - logical_plan::LogicalPlan, projection_schema, Aggregate, BinaryExpr, Cast, Distinct, - Expr, Projection, TableScan, Window, + logical_plan::LogicalPlan, projection_schema, Aggregate, Distinct, Expr, Projection, + TableScan, Window, }; -use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; -use hashbrown::HashMap; -use itertools::{izip, Itertools}; +use crate::optimize_projections::required_indices::RequiredIndicies; +use crate::utils::NamePreserver; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, +}; /// Optimizer rule to prune unnecessary columns from intermediate schemas /// inside the [`LogicalPlan`]. This rule: @@ -53,7 +57,7 @@ use itertools::{izip, Itertools}; /// The rule analyzes the input logical plan, determines the necessary column /// indices, and then removes any unnecessary columns. It also removes any /// unnecessary projections from the plan tree. -#[derive(Default)] +#[derive(Default, Debug)] pub struct OptimizeProjections {} impl OptimizeProjections { @@ -64,16 +68,6 @@ impl OptimizeProjections { } impl OptimizerRule for OptimizeProjections { - fn try_optimize( - &self, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, - ) -> Result> { - // All output fields are necessary: - let indices = (0..plan.schema().fields().len()).collect::>(); - optimize_projections(plan, config, &indices) - } - fn name(&self) -> &str { "optimize_projections" } @@ -81,6 +75,20 @@ impl OptimizerRule for OptimizeProjections { fn apply_order(&self) -> Option { None } + + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + // All output fields are necessary: + let indices = RequiredIndicies::new_for_all_exprs(&plan); + optimize_projections(plan, config, indices) + } } /// Removes unnecessary columns (e.g. columns that do not appear in the output @@ -92,7 +100,7 @@ impl OptimizerRule for OptimizeProjections { /// - `plan`: A reference to the input `LogicalPlan` to optimize. /// - `config`: A reference to the optimizer configuration. /// - `indices`: A slice of column indices that represent the necessary column -/// indices for downstream operations. +/// indices for downstream (parent) plan nodes. /// /// # Returns /// @@ -101,21 +109,175 @@ impl OptimizerRule for OptimizeProjections { /// - `Ok(Some(LogicalPlan))`: An optimized `LogicalPlan` without unnecessary /// columns. /// - `Ok(None)`: Signal that the given logical plan did not require any change. -/// - `Err(error)`: An error occured during the optimization process. +/// - `Err(error)`: An error occurred during the optimization process. fn optimize_projections( - plan: &LogicalPlan, + plan: LogicalPlan, config: &dyn OptimizerConfig, - indices: &[usize], -) -> Result> { - // `child_required_indices` stores - // - indices of the columns required for each child - // - a flag indicating whether putting a projection above children is beneficial for the parent. - // As an example LogicalPlan::Filter benefits from small tables. Hence for filter child this flag would be `true`. - let child_required_indices: Vec<(Vec, bool)> = match plan { + indices: RequiredIndicies, +) -> Result> { + // Recursively rewrite any nodes that may be able to avoid computation given + // their parents' required indices. + match plan { + LogicalPlan::Projection(proj) => { + return merge_consecutive_projections(proj)?.transform_data(|proj| { + rewrite_projection_given_requirements(proj, config, &indices) + }) + } + LogicalPlan::Aggregate(aggregate) => { + // Split parent requirements to GROUP BY and aggregate sections: + let n_group_exprs = aggregate.group_expr_len()?; + // Offset aggregate indices so that they point to valid indices at + // `aggregate.aggr_expr`: + let (group_by_reqs, aggregate_reqs) = indices.split_off(n_group_exprs); + + // Get absolutely necessary GROUP BY fields: + let group_by_expr_existing = aggregate + .group_expr + .iter() + .map(|group_by_expr| group_by_expr.schema_name().to_string()) + .collect::>(); + + let new_group_bys = if let Some(simplest_groupby_indices) = + get_required_group_by_exprs_indices( + aggregate.input.schema(), + &group_by_expr_existing, + ) { + // Some of the fields in the GROUP BY may be required by the + // parent even if these fields are unnecessary in terms of + // functional dependency. + group_by_reqs + .append(&simplest_groupby_indices) + .get_at_indices(&aggregate.group_expr) + } else { + aggregate.group_expr + }; + + // Only use the absolutely necessary aggregate expressions required + // by the parent: + let mut new_aggr_expr = aggregate_reqs.get_at_indices(&aggregate.aggr_expr); + + // Aggregations always need at least one aggregate expression. + // With a nested count, we don't require any column as input, but + // still need to create a correct aggregate, which may be optimized + // out later. As an example, consider the following query: + // + // SELECT count(*) FROM (SELECT count(*) FROM [...]) + // + // which always returns 1. + if new_aggr_expr.is_empty() + && new_group_bys.is_empty() + && !aggregate.aggr_expr.is_empty() + { + // take the old, first aggregate expression + new_aggr_expr = aggregate.aggr_expr; + new_aggr_expr.resize_with(1, || unreachable!()); + } + + let all_exprs_iter = new_group_bys.iter().chain(new_aggr_expr.iter()); + let schema = aggregate.input.schema(); + let necessary_indices = + RequiredIndicies::new().with_exprs(schema, all_exprs_iter); + let necessary_exprs = necessary_indices.get_required_exprs(schema); + + return optimize_projections( + Arc::unwrap_or_clone(aggregate.input), + config, + necessary_indices, + )? + .transform_data(|aggregate_input| { + // Simplify the input of the aggregation by adding a projection so + // that its input only contains absolutely necessary columns for + // the aggregate expressions. Note that necessary_indices refer to + // fields in `aggregate.input.schema()`. + add_projection_on_top_if_helpful(aggregate_input, necessary_exprs) + })? + .map_data(|aggregate_input| { + // Create a new aggregate plan with the updated input and only the + // absolutely necessary fields: + Aggregate::try_new( + Arc::new(aggregate_input), + new_group_bys, + new_aggr_expr, + ) + .map(LogicalPlan::Aggregate) + }); + } + LogicalPlan::Window(window) => { + let input_schema = Arc::clone(window.input.schema()); + // Split parent requirements to child and window expression sections: + let n_input_fields = input_schema.fields().len(); + // Offset window expression indices so that they point to valid + // indices at `window.window_expr`: + let (child_reqs, window_reqs) = indices.split_off(n_input_fields); + + // Only use window expressions that are absolutely necessary according + // to parent requirements: + let new_window_expr = window_reqs.get_at_indices(&window.window_expr); + + // Get all the required column indices at the input, either by the + // parent or window expression requirements. + let required_indices = child_reqs.with_exprs(&input_schema, &new_window_expr); + + return optimize_projections( + Arc::unwrap_or_clone(window.input), + config, + required_indices.clone(), + )? + .transform_data(|window_child| { + if new_window_expr.is_empty() { + // When no window expression is necessary, use the input directly: + Ok(Transformed::no(window_child)) + } else { + // Calculate required expressions at the input of the window. + // Please note that we use `input_schema`, because `required_indices` + // refers to that schema + let required_exprs = + required_indices.get_required_exprs(&input_schema); + let window_child = + add_projection_on_top_if_helpful(window_child, required_exprs)? + .data; + Window::try_new(new_window_expr, Arc::new(window_child)) + .map(LogicalPlan::Window) + .map(Transformed::yes) + } + }); + } + LogicalPlan::TableScan(table_scan) => { + let TableScan { + table_name, + source, + projection, + filters, + fetch, + projected_schema: _, + } = table_scan; + + // Get indices referred to in the original (schema with all fields) + // given projected indices. + let projection = match &projection { + Some(projection) => indices.into_mapped_indices(|idx| projection[idx]), + None => indices.into_inner(), + }; + return TableScan::try_new( + table_name, + source, + Some(projection), + filters, + fetch, + ) + .map(LogicalPlan::TableScan) + .map(Transformed::yes); + } + // Other node types are handled below + _ => {} + }; + + // For other plan node types, calculate indices for columns they use and + // try to rewrite their children + let mut child_required_indices: Vec = match &plan { LogicalPlan::Sort(_) | LogicalPlan::Filter(_) | LogicalPlan::Repartition(_) - | LogicalPlan::Unnest(_) | LogicalPlan::Union(_) | LogicalPlan::SubqueryAlias(_) | LogicalPlan::Distinct(Distinct::On(_)) => { @@ -123,12 +285,13 @@ fn optimize_projections( // that appear in this plan's expressions to its child. All these // operators benefit from "small" inputs, so the projection_beneficial // flag is `true`. - let exprs = plan.expressions(); plan.inputs() .into_iter() .map(|input| { - get_all_required_indices(indices, input, exprs.iter()) - .map(|idxs| (idxs, true)) + indices + .clone() + .with_projection_beneficial() + .with_plan_exprs(&plan, input.schema()) }) .collect::>()? } @@ -137,13 +300,9 @@ fn optimize_projections( // that appear in this plan's expressions to its child. These operators // do not benefit from "small" inputs, so the projection_beneficial // flag is `false`. - let exprs = plan.expressions(); plan.inputs() .into_iter() - .map(|input| { - get_all_required_indices(indices, input, exprs.iter()) - .map(|idxs| (idxs, false)) - }) + .map(|input| indices.clone().with_plan_exprs(&plan, input.schema())) .collect::>()? } LogicalPlan::Copy(_) @@ -159,18 +318,16 @@ fn optimize_projections( // TODO: For some subquery variants (e.g. a subquery arising from an // EXISTS expression), we may not need to require all indices. plan.inputs() - .iter() - .map(|input| ((0..input.schema().fields().len()).collect_vec(), false)) - .collect::>() + .into_iter() + .map(RequiredIndicies::new_for_all_exprs) + .collect() } LogicalPlan::Extension(extension) => { - let necessary_children_indices = if let Some(necessary_children_indices) = - extension.node.necessary_children_exprs(indices) - { - necessary_children_indices - } else { + let Some(necessary_children_indices) = + extension.node.necessary_children_exprs(indices.indices()) + else { // Requirements from parent cannot be routed down to user defined logical plan safely - return Ok(None); + return Ok(Transformed::no(plan)); }; let children = extension.node.inputs(); if children.len() != necessary_children_indices.len() { @@ -178,16 +335,12 @@ fn optimize_projections( Make sure `.necessary_children_exprs` implementation of the `UserDefinedLogicalNode` is \ consistent with actual children length for the node."); } - // Expressions used by node. - let exprs = plan.expressions(); children .into_iter() .zip(necessary_children_indices) .map(|(child, necessary_indices)| { - let child_schema = child.schema(); - let child_req_indices = - indices_referred_by_exprs(child_schema, exprs.iter())?; - Ok((merge_slices(&necessary_indices, &child_req_indices), false)) + RequiredIndicies::new_from_indices(necessary_indices) + .with_plan_exprs(&plan, child.schema()) }) .collect::>>()? } @@ -195,236 +348,80 @@ fn optimize_projections( | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Statement(_) | LogicalPlan::Values(_) - | LogicalPlan::DescribeTable(_) => { + | LogicalPlan::DescribeTable(_) + | LogicalPlan::Execute(_) => { // These operators have no inputs, so stop the optimization process. - return Ok(None); - } - LogicalPlan::Projection(proj) => { - return if let Some(proj) = merge_consecutive_projections(proj)? { - Ok(Some( - rewrite_projection_given_requirements(&proj, config, indices)? - // Even if we cannot optimize the projection, merge if possible: - .unwrap_or_else(|| LogicalPlan::Projection(proj)), - )) - } else { - rewrite_projection_given_requirements(proj, config, indices) - }; - } - LogicalPlan::Aggregate(aggregate) => { - // Split parent requirements to GROUP BY and aggregate sections: - let n_group_exprs = aggregate.group_expr_len()?; - let (group_by_reqs, mut aggregate_reqs): (Vec, Vec) = - indices.iter().partition(|&&idx| idx < n_group_exprs); - // Offset aggregate indices so that they point to valid indices at - // `aggregate.aggr_expr`: - for idx in aggregate_reqs.iter_mut() { - *idx -= n_group_exprs; - } - - // Get absolutely necessary GROUP BY fields: - let group_by_expr_existing = aggregate - .group_expr - .iter() - .map(|group_by_expr| group_by_expr.display_name()) - .collect::>>()?; - let new_group_bys = if let Some(simplest_groupby_indices) = - get_required_group_by_exprs_indices( - aggregate.input.schema(), - &group_by_expr_existing, - ) { - // Some of the fields in the GROUP BY may be required by the - // parent even if these fields are unnecessary in terms of - // functional dependency. - let required_indices = - merge_slices(&simplest_groupby_indices, &group_by_reqs); - get_at_indices(&aggregate.group_expr, &required_indices) - } else { - aggregate.group_expr.clone() - }; - - // Only use the absolutely necessary aggregate expressions required - // by the parent: - let mut new_aggr_expr = get_at_indices(&aggregate.aggr_expr, &aggregate_reqs); - - // Aggregations always need at least one aggregate expression. - // With a nested count, we don't require any column as input, but - // still need to create a correct aggregate, which may be optimized - // out later. As an example, consider the following query: - // - // SELECT COUNT(*) FROM (SELECT COUNT(*) FROM [...]) - // - // which always returns 1. - if new_aggr_expr.is_empty() - && new_group_bys.is_empty() - && !aggregate.aggr_expr.is_empty() - { - new_aggr_expr = vec![aggregate.aggr_expr[0].clone()]; - } - - let all_exprs_iter = new_group_bys.iter().chain(new_aggr_expr.iter()); - let schema = aggregate.input.schema(); - let necessary_indices = indices_referred_by_exprs(schema, all_exprs_iter)?; - - let aggregate_input = if let Some(input) = - optimize_projections(&aggregate.input, config, &necessary_indices)? - { - input - } else { - aggregate.input.as_ref().clone() - }; - - // Simplify the input of the aggregation by adding a projection so - // that its input only contains absolutely necessary columns for - // the aggregate expressions. Note that necessary_indices refer to - // fields in `aggregate.input.schema()`. - let necessary_exprs = get_required_exprs(schema, &necessary_indices); - let (aggregate_input, _) = - add_projection_on_top_if_helpful(aggregate_input, necessary_exprs)?; - - // Create a new aggregate plan with the updated input and only the - // absolutely necessary fields: - return Aggregate::try_new( - Arc::new(aggregate_input), - new_group_bys, - new_aggr_expr, - ) - .map(|aggregate| Some(LogicalPlan::Aggregate(aggregate))); - } - LogicalPlan::Window(window) => { - // Split parent requirements to child and window expression sections: - let n_input_fields = window.input.schema().fields().len(); - let (child_reqs, mut window_reqs): (Vec, Vec) = - indices.iter().partition(|&&idx| idx < n_input_fields); - // Offset window expression indices so that they point to valid - // indices at `window.window_expr`: - for idx in window_reqs.iter_mut() { - *idx -= n_input_fields; - } - - // Only use window expressions that are absolutely necessary according - // to parent requirements: - let new_window_expr = get_at_indices(&window.window_expr, &window_reqs); - - // Get all the required column indices at the input, either by the - // parent or window expression requirements. - let required_indices = get_all_required_indices( - &child_reqs, - &window.input, - new_window_expr.iter(), - )?; - let window_child = if let Some(new_window_child) = - optimize_projections(&window.input, config, &required_indices)? - { - new_window_child - } else { - window.input.as_ref().clone() - }; - - return if new_window_expr.is_empty() { - // When no window expression is necessary, use the input directly: - Ok(Some(window_child)) - } else { - // Calculate required expressions at the input of the window. - // Please note that we use `old_child`, because `required_indices` - // refers to `old_child`. - let required_exprs = - get_required_exprs(window.input.schema(), &required_indices); - let (window_child, _) = - add_projection_on_top_if_helpful(window_child, required_exprs)?; - Window::try_new(new_window_expr, Arc::new(window_child)) - .map(|window| Some(LogicalPlan::Window(window))) - }; + return Ok(Transformed::no(plan)); } LogicalPlan::Join(join) => { let left_len = join.left.schema().fields().len(); let (left_req_indices, right_req_indices) = split_join_requirements(left_len, indices, &join.join_type); - let exprs = plan.expressions(); let left_indices = - get_all_required_indices(&left_req_indices, &join.left, exprs.iter())?; + left_req_indices.with_plan_exprs(&plan, join.left.schema())?; let right_indices = - get_all_required_indices(&right_req_indices, &join.right, exprs.iter())?; + right_req_indices.with_plan_exprs(&plan, join.right.schema())?; // Joins benefit from "small" input tables (lower memory usage). // Therefore, each child benefits from projection: - vec![(left_indices, true), (right_indices, true)] + vec![ + left_indices.with_projection_beneficial(), + right_indices.with_projection_beneficial(), + ] } - LogicalPlan::CrossJoin(cross_join) => { - let left_len = cross_join.left.schema().fields().len(); - let (left_child_indices, right_child_indices) = - split_join_requirements(left_len, indices, &JoinType::Inner); - // Joins benefit from "small" input tables (lower memory usage). - // Therefore, each child benefits from projection: - vec![(left_child_indices, true), (right_child_indices, true)] + // these nodes are explicitly rewritten in the match statement above + LogicalPlan::Projection(_) + | LogicalPlan::Aggregate(_) + | LogicalPlan::Window(_) + | LogicalPlan::TableScan(_) => { + return internal_err!( + "OptimizeProjection: should have handled in the match statement above" + ); } - LogicalPlan::TableScan(table_scan) => { - let schema = table_scan.source.schema(); - // Get indices referred to in the original (schema with all fields) - // given projected indices. - let projection = with_indices(&table_scan.projection, schema, |map| { - indices.iter().map(|&idx| map[idx]).collect() - }); - - return TableScan::try_new( - table_scan.table_name.clone(), - table_scan.source.clone(), - Some(projection), - table_scan.filters.clone(), - table_scan.fetch, - ) - .map(|table| Some(LogicalPlan::TableScan(table))); + LogicalPlan::Unnest(Unnest { + dependency_indices, .. + }) => { + vec![RequiredIndicies::new_from_indices( + dependency_indices.clone(), + )] } }; - let new_inputs = izip!(child_required_indices, plan.inputs().into_iter()) - .map(|((required_indices, projection_beneficial), child)| { - let (input, is_changed) = if let Some(new_input) = - optimize_projections(child, config, &required_indices)? - { - (new_input, true) - } else { - (child.clone(), false) - }; - let project_exprs = get_required_exprs(child.schema(), &required_indices); - let (input, proj_added) = if projection_beneficial { - add_projection_on_top_if_helpful(input, project_exprs)? - } else { - (input, false) - }; - Ok((is_changed || proj_added).then_some(input)) - }) - .collect::>>()?; - if new_inputs.iter().all(|child| child.is_none()) { - // All children are the same in this case, no need to change the plan: - Ok(None) - } else { - // At least one of the children is changed: - let new_inputs = izip!(new_inputs, plan.inputs()) - // If new_input is `None`, this means child is not changed, so use - // `old_child` during construction: - .map(|(new_input, old_child)| new_input.unwrap_or_else(|| old_child.clone())) - .collect(); - let exprs = plan.expressions(); - plan.with_new_exprs(exprs, new_inputs).map(Some) + // Required indices are currently ordered (child0, child1, ...) + // but the loop pops off the last element, so we need to reverse the order + child_required_indices.reverse(); + if child_required_indices.len() != plan.inputs().len() { + return internal_err!( + "OptimizeProjection: child_required_indices length mismatch with plan inputs" + ); } -} -/// This function applies the given function `f` to the projection indices -/// `proj_indices` if they exist. Otherwise, applies `f` to a default set -/// of indices according to `schema`. -fn with_indices( - proj_indices: &Option>, - schema: SchemaRef, - mut f: F, -) -> Vec -where - F: FnMut(&[usize]) -> Vec, -{ - match proj_indices { - Some(indices) => f(indices.as_slice()), - None => { - let range: Vec = (0..schema.fields.len()).collect(); - f(range.as_slice()) - } + // Rewrite children of the plan + let transformed_plan = plan.map_children(|child| { + let required_indices = child_required_indices.pop().ok_or_else(|| { + internal_datafusion_err!( + "Unexpected number of required_indices in OptimizeProjections rule" + ) + })?; + + let projection_beneficial = required_indices.projection_beneficial(); + let project_exprs = required_indices.get_required_exprs(child.schema()); + + optimize_projections(child, config, required_indices)?.transform_data( + |new_input| { + if projection_beneficial { + add_projection_on_top_if_helpful(new_input, project_exprs) + } else { + Ok(Transformed::no(new_input)) + } + }, + ) + })?; + + // If any of the children are transformed, we need to potentially update the plan's schema + if transformed_plan.transformed { + transformed_plan.map_data(|plan| plan.recompute_schema()) + } else { + Ok(transformed_plan) } } @@ -448,56 +445,104 @@ where /// merged projection. /// - `Ok(None)`: Signals that merge is not beneficial (and has not taken place). /// - `Err(error)`: An error occured during the function call. -fn merge_consecutive_projections(proj: &Projection) -> Result> { - let LogicalPlan::Projection(prev_projection) = proj.input.as_ref() else { - return Ok(None); +fn merge_consecutive_projections(proj: Projection) -> Result> { + let Projection { + expr, + input, + schema, + .. + } = proj; + let LogicalPlan::Projection(prev_projection) = input.as_ref() else { + return Projection::try_new_with_schema(expr, input, schema).map(Transformed::no); }; // Count usages (referrals) of each projection expression in its input fields: - let mut column_referral_map = HashMap::::new(); - for columns in proj.expr.iter().flat_map(|expr| expr.to_columns()) { - for col in columns.into_iter() { - *column_referral_map.entry(col.clone()).or_default() += 1; - } - } + let mut column_referral_map = HashMap::<&Column, usize>::new(); + expr.iter() + .for_each(|expr| expr.add_column_ref_counts(&mut column_referral_map)); - // If an expression is non-trivial and appears more than once, consecutive - // projections will benefit from a compute-once approach. For details, see: - // https://github.com/apache/datafusion/issues/8296 + // If an expression is non-trivial and appears more than once, do not merge + // them as consecutive projections will benefit from a compute-once approach. + // For details, see: https://github.com/apache/datafusion/issues/8296 if column_referral_map.into_iter().any(|(col, usage)| { usage > 1 && !is_expr_trivial( &prev_projection.expr - [prev_projection.schema.index_of_column(&col).unwrap()], + [prev_projection.schema.index_of_column(col).unwrap()], ) }) { - return Ok(None); + // no change + return Projection::try_new_with_schema(expr, input, schema).map(Transformed::no); } - // If all the expression of the top projection can be rewritten, do so and - // create a new projection: - let new_exprs = proj - .expr - .iter() - .map(|expr| rewrite_expr(expr, prev_projection)) - .collect::>>>()?; - if let Some(new_exprs) = new_exprs { + let LogicalPlan::Projection(prev_projection) = Arc::unwrap_or_clone(input) else { + // We know it is a `LogicalPlan::Projection` from check above + unreachable!(); + }; + + // Try to rewrite the expressions in the current projection using the + // previous projection as input: + let name_preserver = NamePreserver::new_for_projection(); + let mut original_names = vec![]; + let new_exprs = expr.into_iter().map_until_stop_and_collect(|expr| { + original_names.push(name_preserver.save(&expr)); + + // do not rewrite top level Aliases (rewriter will remove all aliases within exprs) + match expr { + Expr::Alias(Alias { + expr, + relation, + name, + }) => rewrite_expr(*expr, &prev_projection).map(|result| { + result.update_data(|expr| Expr::Alias(Alias::new(expr, relation, name))) + }), + e => rewrite_expr(e, &prev_projection), + } + })?; + + // if the expressions could be rewritten, create a new projection with the + // new expressions + if new_exprs.transformed { + // Add any needed aliases back to the expressions let new_exprs = new_exprs + .data .into_iter() - .zip(proj.expr.iter()) - .map(|(new_expr, old_expr)| { - new_expr.alias_if_changed(old_expr.name_for_alias()?) - }) - .collect::>>()?; - Projection::try_new(new_exprs, prev_projection.input.clone()).map(Some) + .zip(original_names) + .map(|(expr, original_name)| original_name.restore(expr)) + .collect::>(); + Projection::try_new(new_exprs, prev_projection.input).map(Transformed::yes) } else { - Ok(None) + // not rewritten, so put the projection back together + let input = Arc::new(LogicalPlan::Projection(prev_projection)); + Projection::try_new_with_schema(new_exprs.data, input, schema) + .map(Transformed::no) } } -/// Trim the given expression by removing any unnecessary layers of aliasing. -/// If the expression is an alias, the function returns the underlying expression. -/// Otherwise, it returns the given expression as is. +// Check whether `expr` is trivial; i.e. it doesn't imply any computation. +fn is_expr_trivial(expr: &Expr) -> bool { + matches!(expr, Expr::Column(_) | Expr::Literal(_)) +} + +/// Rewrites a projection expression using the projection before it (i.e. its input) +/// This is a subroutine to the `merge_consecutive_projections` function. +/// +/// # Parameters +/// +/// * `expr` - A reference to the expression to rewrite. +/// * `input` - A reference to the input of the projection expression (itself +/// a projection). +/// +/// # Returns +/// +/// A `Result` object with the following semantics: +/// +/// - `Ok(Some(Expr))`: Rewrite was successful. Contains the rewritten result. +/// - `Ok(None)`: Signals that `expr` can not be rewritten. +/// - `Err(error)`: An error occurred during the function call. +/// +/// # Notes +/// This rewrite also removes any unnecessary layers of aliasing. /// /// Without trimming, we can end up with unnecessary indirections inside expressions /// during projection merges. @@ -523,84 +568,28 @@ fn merge_consecutive_projections(proj: &Projection) -> Result /// Projection((a as a1 + b as b1) as sum1) /// --Source(a, b) /// ``` -fn trim_expr(expr: Expr) -> Expr { - match expr { - Expr::Alias(alias) => trim_expr(*alias.expr), - _ => expr, - } -} - -// Check whether `expr` is trivial; i.e. it doesn't imply any computation. -fn is_expr_trivial(expr: &Expr) -> bool { - matches!(expr, Expr::Column(_) | Expr::Literal(_)) -} - -// Exit early when there is no rewrite to do. -macro_rules! rewrite_expr_with_check { - ($expr:expr, $input:expr) => { - if let Some(value) = rewrite_expr($expr, $input)? { - value - } else { - return Ok(None); - } - }; -} - -/// Rewrites a projection expression using the projection before it (i.e. its input) -/// This is a subroutine to the `merge_consecutive_projections` function. -/// -/// # Parameters -/// -/// * `expr` - A reference to the expression to rewrite. -/// * `input` - A reference to the input of the projection expression (itself -/// a projection). -/// -/// # Returns -/// -/// A `Result` object with the following semantics: -/// -/// - `Ok(Some(Expr))`: Rewrite was successful. Contains the rewritten result. -/// - `Ok(None)`: Signals that `expr` can not be rewritten. -/// - `Err(error)`: An error occurred during the function call. -fn rewrite_expr(expr: &Expr, input: &Projection) -> Result> { - let result = match expr { - Expr::Column(col) => { - // Find index of column: - let idx = input.schema.index_of_column(col)?; - input.expr[idx].clone() - } - Expr::BinaryExpr(binary) => Expr::BinaryExpr(BinaryExpr::new( - Box::new(trim_expr(rewrite_expr_with_check!(&binary.left, input))), - binary.op, - Box::new(trim_expr(rewrite_expr_with_check!(&binary.right, input))), - )), - Expr::Alias(alias) => Expr::Alias(Alias::new( - trim_expr(rewrite_expr_with_check!(&alias.expr, input)), - alias.relation.clone(), - alias.name.clone(), - )), - Expr::Literal(_) => expr.clone(), - Expr::Cast(cast) => { - let new_expr = rewrite_expr_with_check!(&cast.expr, input); - Expr::Cast(Cast::new(Box::new(new_expr), cast.data_type.clone())) - } - Expr::ScalarFunction(scalar_fn) => { - return Ok(scalar_fn - .args - .iter() - .map(|expr| rewrite_expr(expr, input)) - .collect::>>()? - .map(|new_args| { - Expr::ScalarFunction(ScalarFunction::new_func_def( - scalar_fn.func_def.clone(), - new_args, - )) - })); +fn rewrite_expr(expr: Expr, input: &Projection) -> Result> { + expr.transform_up(|expr| { + match expr { + // remove any intermediate aliases + Expr::Alias(alias) => Ok(Transformed::yes(*alias.expr)), + Expr::Column(col) => { + // Find index of column: + let idx = input.schema.index_of_column(&col)?; + // get the corresponding unaliased input expression + // + // For example: + // * the input projection is [`a + b` as c, `d + e` as f] + // * the current column is an expression "f" + // + // return the expression `d + e` (not `d + e` as f) + let input_expr = input.expr[idx].clone().unalias_nested().data; + Ok(Transformed::yes(input_expr)) + } + // Unsupported type for consecutive projection merge analysis. + _ => Ok(Transformed::no(expr)), } - // Unsupported type for consecutive projection merge analysis. - _ => return Ok(None), - }; - Ok(Some(result)) + }) } /// Accumulates outer-referenced columns by the @@ -611,12 +600,12 @@ fn rewrite_expr(expr: &Expr, input: &Projection) -> Result> { /// * `expr` - The expression to analyze for outer-referenced columns. /// * `columns` - A mutable reference to a `HashSet` where detected /// columns are collected. -fn outer_columns(expr: &Expr, columns: &mut HashSet) { +fn outer_columns<'a>(expr: &'a Expr, columns: &mut HashSet<&'a Column>) { // inspect_expr_pre doesn't handle subquery references, so find them explicitly expr.apply(|expr| { match expr { Expr::OuterReferenceColumn(_, col) => { - columns.insert(col.clone()); + columns.insert(col); } Expr::ScalarSubquery(subquery) => { outer_columns_helper_multi(&subquery.outer_ref_columns, columns); @@ -646,139 +635,13 @@ fn outer_columns(expr: &Expr, columns: &mut HashSet) { /// * `exprs` - The expressions to analyze for outer-referenced columns. /// * `columns` - A mutable reference to a `HashSet` where detected /// columns are collected. -fn outer_columns_helper_multi<'a>( +fn outer_columns_helper_multi<'a, 'b>( exprs: impl IntoIterator, - columns: &mut HashSet, + columns: &'b mut HashSet<&'a Column>, ) { exprs.into_iter().for_each(|e| outer_columns(e, columns)); } -/// Generates the required expressions (columns) that reside at `indices` of -/// the given `input_schema`. -/// -/// # Arguments -/// -/// * `input_schema` - A reference to the input schema. -/// * `indices` - A slice of `usize` indices specifying required columns. -/// -/// # Returns -/// -/// A vector of `Expr::Column` expressions residing at `indices` of the `input_schema`. -fn get_required_exprs(input_schema: &Arc, indices: &[usize]) -> Vec { - indices - .iter() - .map(|&idx| Expr::Column(Column::from(input_schema.qualified_field(idx)))) - .collect() -} - -/// Get indices of the fields referred to by any expression in `exprs` within -/// the given schema (`input_schema`). -/// -/// # Arguments -/// -/// * `input_schema`: The input schema to analyze for index requirements. -/// * `exprs`: An iterator of expressions for which we want to find necessary -/// field indices. -/// -/// # Returns -/// -/// A [`Result`] object containing the indices of all required fields in -/// `input_schema` to calculate all `exprs` successfully. -fn indices_referred_by_exprs<'a>( - input_schema: &DFSchemaRef, - exprs: impl Iterator, -) -> Result> { - let indices = exprs - .map(|expr| indices_referred_by_expr(input_schema, expr)) - .collect::>>()?; - Ok(indices - .into_iter() - .flatten() - // Make sure no duplicate entries exist and indices are ordered: - .sorted() - .dedup() - .collect()) -} - -/// Get indices of the fields referred to by the given expression `expr` within -/// the given schema (`input_schema`). -/// -/// # Parameters -/// -/// * `input_schema`: The input schema to analyze for index requirements. -/// * `expr`: An expression for which we want to find necessary field indices. -/// -/// # Returns -/// -/// A [`Result`] object containing the indices of all required fields in -/// `input_schema` to calculate `expr` successfully. -fn indices_referred_by_expr( - input_schema: &DFSchemaRef, - expr: &Expr, -) -> Result> { - let mut cols = expr.to_columns()?; - // Get outer-referenced (subquery) columns: - outer_columns(expr, &mut cols); - Ok(cols - .iter() - .flat_map(|col| input_schema.index_of_column(col)) - .collect()) -} - -/// Gets all required indices for the input; i.e. those required by the parent -/// and those referred to by `exprs`. -/// -/// # Parameters -/// -/// * `parent_required_indices` - A slice of indices required by the parent plan. -/// * `input` - The input logical plan to analyze for index requirements. -/// * `exprs` - An iterator of expressions used to determine required indices. -/// -/// # Returns -/// -/// A `Result` containing a vector of `usize` indices containing all the required -/// indices. -fn get_all_required_indices<'a>( - parent_required_indices: &[usize], - input: &LogicalPlan, - exprs: impl Iterator, -) -> Result> { - indices_referred_by_exprs(input.schema(), exprs) - .map(|indices| merge_slices(parent_required_indices, &indices)) -} - -/// Retrieves the expressions at specified indices within the given slice. Ignores -/// any invalid indices. -/// -/// # Parameters -/// -/// * `exprs` - A slice of expressions to index into. -/// * `indices` - A slice of indices specifying the positions of expressions sought. -/// -/// # Returns -/// -/// A vector of expressions corresponding to specified indices. -fn get_at_indices(exprs: &[Expr], indices: &[usize]) -> Vec { - indices - .iter() - // Indices may point to further places than `exprs` len. - .filter_map(|&idx| exprs.get(idx).cloned()) - .collect() -} - -/// Merges two slices into a single vector with sorted (ascending) and -/// deduplicated elements. For example, merging `[3, 2, 4]` and `[3, 6, 1]` -/// will produce `[1, 2, 3, 6]`. -fn merge_slices(left: &[T], right: &[T]) -> Vec { - // Make sure to sort before deduping, which removes the duplicates: - left.iter() - .cloned() - .chain(right.iter().cloned()) - .sorted() - .dedup() - .collect() -} - /// Splits requirement indices for a join into left and right children based on /// the join type. /// @@ -810,26 +673,25 @@ fn merge_slices(left: &[T], right: &[T]) -> Vec { /// adjusted based on the join type. fn split_join_requirements( left_len: usize, - indices: &[usize], + indices: RequiredIndicies, join_type: &JoinType, -) -> (Vec, Vec) { +) -> (RequiredIndicies, RequiredIndicies) { match join_type { // In these cases requirements are split between left/right children: - JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { - let (left_reqs, mut right_reqs): (Vec, Vec) = - indices.iter().partition(|&&idx| idx < left_len); + JoinType::Inner + | JoinType::Left + | JoinType::Right + | JoinType::Full + | JoinType::LeftMark => { // Decrease right side indices by `left_len` so that they point to valid // positions within the right child: - for idx in right_reqs.iter_mut() { - *idx -= left_len; - } - (left_reqs, right_reqs) + indices.split_off(left_len) } // All requirements can be re-routed to left child directly. - JoinType::LeftAnti | JoinType::LeftSemi => (indices.to_vec(), vec![]), + JoinType::LeftAnti | JoinType::LeftSemi => (indices, RequiredIndicies::new()), // All requirements can be re-routed to right side directly. // No need to change index, join schema is right child schema. - JoinType::RightSemi | JoinType::RightAnti => (vec![], indices.to_vec()), + JoinType::RightSemi | JoinType::RightAnti => (RequiredIndicies::new(), indices), } } @@ -849,19 +711,18 @@ fn split_join_requirements( /// /// # Returns /// -/// A `Result` containing a tuple with two values: The resulting `LogicalPlan` -/// (with or without the added projection) and a `bool` flag indicating if a -/// projection was added (`true`) or not (`false`). +/// A `Transformed` indicating if a projection was added fn add_projection_on_top_if_helpful( plan: LogicalPlan, project_exprs: Vec, -) -> Result<(LogicalPlan, bool)> { +) -> Result> { // Make sure projection decreases the number of columns, otherwise it is unnecessary. if project_exprs.len() >= plan.schema().fields().len() { - Ok((plan, false)) + Ok(Transformed::no(plan)) } else { Projection::try_new(project_exprs, Arc::new(plan)) - .map(|proj| (LogicalPlan::Projection(proj), true)) + .map(LogicalPlan::Projection) + .map(Transformed::yes) } } @@ -883,50 +744,45 @@ fn add_projection_on_top_if_helpful( /// - `Ok(None)`: No rewrite necessary. /// - `Err(error)`: An error occured during the function call. fn rewrite_projection_given_requirements( - proj: &Projection, + proj: Projection, config: &dyn OptimizerConfig, - indices: &[usize], -) -> Result> { - let exprs_used = get_at_indices(&proj.expr, indices); + indices: &RequiredIndicies, +) -> Result> { + let Projection { expr, input, .. } = proj; + + let exprs_used = indices.get_at_indices(&expr); + let required_indices = - indices_referred_by_exprs(proj.input.schema(), exprs_used.iter())?; - return if let Some(input) = - optimize_projections(&proj.input, config, &required_indices)? - { - if is_projection_unnecessary(&input, &exprs_used)? { - Ok(Some(input)) - } else { - Projection::try_new(exprs_used, Arc::new(input)) - .map(|proj| Some(LogicalPlan::Projection(proj))) - } - } else if exprs_used.len() < proj.expr.len() { - // Projection expression used is different than the existing projection. - // In this case, even if the child doesn't change, we should update the - // projection to use fewer columns: - if is_projection_unnecessary(&proj.input, &exprs_used)? { - Ok(Some(proj.input.as_ref().clone())) - } else { - Projection::try_new(exprs_used, proj.input.clone()) - .map(|proj| Some(LogicalPlan::Projection(proj))) - } - } else { - // Projection doesn't change. - Ok(None) - }; + RequiredIndicies::new().with_exprs(input.schema(), exprs_used.iter()); + + // rewrite the children projection, and if they are changed rewrite the + // projection down + optimize_projections(Arc::unwrap_or_clone(input), config, required_indices)? + .transform_data(|input| { + if is_projection_unnecessary(&input, &exprs_used)? { + Ok(Transformed::yes(input)) + } else { + Projection::try_new(exprs_used, Arc::new(input)) + .map(LogicalPlan::Projection) + .map(Transformed::yes) + } + }) } /// Projection is unnecessary, when /// - input schema of the projection, output schema of the projection are same, and /// - all projection expressions are either Column or Literal fn is_projection_unnecessary(input: &LogicalPlan, proj_exprs: &[Expr]) -> Result { - Ok(&projection_schema(input, proj_exprs)? == input.schema() - && proj_exprs.iter().all(is_expr_trivial)) + let proj_schema = projection_schema(input, proj_exprs)?; + Ok(&proj_schema == input.schema() && proj_exprs.iter().all(is_expr_trivial)) } #[cfg(test)] mod tests { + use std::cmp::Ordering; use std::collections::HashMap; use std::fmt::Formatter; + use std::ops::Add; use std::sync::Arc; use std::vec; @@ -941,18 +797,22 @@ mod tests { use datafusion_common::{ Column, DFSchema, DFSchemaRef, JoinType, Result, TableReference, }; + use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ binary_expr, build_join_schema, builder::table_scan_with_filters, - col, count, + col, expr::{self, Cast}, lit, logical_plan::{builder::LogicalPlanBuilder, table_scan}, - max, min, not, try_cast, when, AggregateFunction, BinaryExpr, Expr, Extension, - Like, LogicalPlan, Operator, Projection, UserDefinedLogicalNodeCore, WindowFrame, - WindowFunctionDefinition, + not, try_cast, when, BinaryExpr, Expr, Extension, Like, LogicalPlan, Operator, + Projection, UserDefinedLogicalNodeCore, WindowFunctionDefinition, }; + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_functions_aggregate::expr_fn::{count, max, min}; + use datafusion_functions_aggregate::min_max::max_udaf; + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected) } @@ -979,6 +839,16 @@ mod tests { } } + // Manual implementation needed because of `schema` field. Comparison excludes this field. + impl PartialOrd for NoOpUserDefined { + fn partial_cmp(&self, other: &Self) -> Option { + match self.exprs.partial_cmp(&other.exprs) { + Some(Ordering::Equal) => self.input.partial_cmp(&other.input), + cmp => cmp, + } + } + } + impl UserDefinedLogicalNodeCore for NoOpUserDefined { fn name(&self) -> &str { "NoOpUserDefined" @@ -1000,12 +870,16 @@ mod tests { write!(f, "NoOpUserDefined") } - fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self { - Self { - exprs: exprs.to_vec(), - input: Arc::new(inputs[0].clone()), - schema: self.schema.clone(), - } + fn with_exprs_and_inputs( + &self, + exprs: Vec, + mut inputs: Vec, + ) -> Result { + Ok(Self { + exprs, + input: Arc::new(inputs.swap_remove(0)), + schema: Arc::clone(&self.schema), + }) } fn necessary_children_exprs( @@ -1015,6 +889,10 @@ mod tests { // Since schema is same. Output columns requires their corresponding version in the input columns. Some(vec![output_columns.to_vec()]) } + + fn supports_limit_pushdown(&self) -> bool { + false // Disallow limit push-down by default + } } #[derive(Debug, Hash, PartialEq, Eq)] @@ -1041,6 +919,23 @@ mod tests { } } + // Manual implementation needed because of `schema` field. Comparison excludes this field. + impl PartialOrd for UserDefinedCrossJoin { + fn partial_cmp(&self, other: &Self) -> Option { + match self.exprs.partial_cmp(&other.exprs) { + Some(Ordering::Equal) => { + match self.left_child.partial_cmp(&other.left_child) { + Some(Ordering::Equal) => { + self.right_child.partial_cmp(&other.right_child) + } + cmp => cmp, + } + } + cmp => cmp, + } + } + } + impl UserDefinedLogicalNodeCore for UserDefinedCrossJoin { fn name(&self) -> &str { "UserDefinedCrossJoin" @@ -1062,14 +957,18 @@ mod tests { write!(f, "UserDefinedCrossJoin") } - fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self { + fn with_exprs_and_inputs( + &self, + exprs: Vec, + mut inputs: Vec, + ) -> Result { assert_eq!(inputs.len(), 2); - Self { - exprs: exprs.to_vec(), - left_child: Arc::new(inputs[0].clone()), - right_child: Arc::new(inputs[1].clone()), - schema: self.schema.clone(), - } + Ok(Self { + exprs, + left_child: Arc::new(inputs.remove(0)), + right_child: Arc::new(inputs.remove(0)), + schema: Arc::clone(&self.schema), + }) } fn necessary_children_exprs( @@ -1090,6 +989,10 @@ mod tests { } Some(vec![left_reqs, right_reqs]) } + + fn supports_limit_pushdown(&self) -> bool { + false // Disallow limit push-down by default + } } #[test] @@ -1160,36 +1063,13 @@ mod tests { .build() .unwrap(); - let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(Int32(1))]]\ + let expected = "Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]]\ \n Projection: \ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(Int32(1))]]\ + \n Aggregate: groupBy=[[]], aggr=[[count(Int32(1))]]\ \n TableScan: ?table? projection=[]"; assert_optimized_plan_equal(plan, expected) } - #[test] - fn test_struct_field_push_down() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int64, false), - Field::new_struct( - "s", - vec![ - Field::new("x", DataType::Int64, false), - Field::new("y", DataType::Int64, false), - ], - false, - ), - ])); - - let table_scan = table_scan(TableReference::none(), &schema, None)?.build()?; - let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("s").field("x")])? - .build()?; - let expected = "Projection: (?table?.s)[x]\ - \n TableScan: ?table? projection=[s]"; - assert_optimized_plan_equal(plan, expected) - } - #[test] fn test_neg_push_down() -> Result<()> { let table_scan = test_table_scan()?; @@ -1350,13 +1230,32 @@ mod tests { assert_optimized_plan_equal(plan, expected) } + // Test Case expression + #[test] + fn test_case_merged() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), lit(0).alias("d")])? + .project(vec![ + col("a"), + when(col("a").eq(lit(1)), lit(10)) + .otherwise(col("d"))? + .alias("d"), + ])? + .build()?; + + let expected = "Projection: test.a, CASE WHEN test.a = Int32(1) THEN Int32(10) ELSE Int32(0) END AS d\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(plan, expected) + } + // Test outer projection isn't discarded despite the same schema as inner // https://github.com/apache/datafusion/issues/8942 #[test] fn test_derived_column() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("a"), lit(0).alias("d")])? + .project(vec![col("a").add(lit(1)).alias("a"), lit(0).alias("d")])? .project(vec![ col("a"), when(col("a").eq(lit(1)), lit(10)) @@ -1365,8 +1264,9 @@ mod tests { ])? .build()?; - let expected = "Projection: test.a, CASE WHEN test.a = Int32(1) THEN Int32(10) ELSE d END AS d\ - \n Projection: test.a, Int32(0) AS d\ + let expected = + "Projection: a, CASE WHEN a = Int32(1) THEN Int32(10) ELSE d END AS d\ + \n Projection: test.a + Int32(1) AS a, Int32(0) AS d\ \n TableScan: test projection=[a]"; assert_optimized_plan_equal(plan, expected) } @@ -1378,7 +1278,7 @@ mod tests { let table_scan = test_table_scan()?; let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoOpUserDefined::new( - table_scan.schema().clone(), + Arc::clone(table_scan.schema()), Arc::new(table_scan.clone()), )), }); @@ -1403,7 +1303,7 @@ mod tests { let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new( NoOpUserDefined::new( - table_scan.schema().clone(), + Arc::clone(table_scan.schema()), Arc::new(table_scan.clone()), ) .with_exprs(exprs), @@ -1438,7 +1338,7 @@ mod tests { let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new( NoOpUserDefined::new( - table_scan.schema().clone(), + Arc::clone(table_scan.schema()), Arc::new(table_scan.clone()), ) .with_exprs(exprs), @@ -1464,8 +1364,8 @@ mod tests { let right_table = test_table_scan_with_name("r")?; let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(UserDefinedCrossJoin::new( - Arc::new(left_table.clone()), - Arc::new(right_table.clone()), + Arc::new(left_table), + Arc::new(right_table), )), }); let plan = LogicalPlanBuilder::from(custom_plan) @@ -1487,7 +1387,7 @@ mod tests { .aggregate(Vec::::new(), vec![max(col("b"))])? .build()?; - let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(test.b)]]\ + let expected = "Aggregate: groupBy=[[]], aggr=[[max(test.b)]]\ \n TableScan: test projection=[b]"; assert_optimized_plan_equal(plan, expected) @@ -1501,7 +1401,7 @@ mod tests { .aggregate(vec![col("c")], vec![max(col("b"))])? .build()?; - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[MAX(test.b)]]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[max(test.b)]]\ \n TableScan: test projection=[b, c]"; assert_optimized_plan_equal(plan, expected) @@ -1516,7 +1416,7 @@ mod tests { .aggregate(vec![col("c")], vec![max(col("b"))])? .build()?; - let expected = "Aggregate: groupBy=[[a.c]], aggr=[[MAX(a.b)]]\ + let expected = "Aggregate: groupBy=[[a.c]], aggr=[[max(a.b)]]\ \n SubqueryAlias: a\ \n TableScan: test projection=[b, c]"; @@ -1532,7 +1432,7 @@ mod tests { .aggregate(Vec::::new(), vec![max(col("b"))])? .build()?; - let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(test.b)]]\ + let expected = "Aggregate: groupBy=[[]], aggr=[[max(test.b)]]\ \n Projection: test.b\ \n Filter: test.c > Int32(1)\ \n TableScan: test projection=[b, c]"; @@ -1548,7 +1448,7 @@ mod tests { // "tag.one", not a column named "one" in a table named "tag"): // // Projection: tag.one - // Aggregate: groupBy=[], aggr=[MAX("tag.one") AS "tag.one"] + // Aggregate: groupBy=[], aggr=[max("tag.one") AS "tag.one"] // TableScan let plan = table_scan(Some("m4"), &schema, None)? .aggregate( @@ -1559,7 +1459,7 @@ mod tests { .build()?; let expected = "\ - Aggregate: groupBy=[[]], aggr=[[MAX(m4.tag.one) AS tag.one]]\ + Aggregate: groupBy=[[]], aggr=[[max(m4.tag.one) AS tag.one]]\ \n TableScan: m4 projection=[tag.one]"; assert_optimized_plan_equal(plan, expected) @@ -1655,7 +1555,7 @@ mod tests { \n TableScan: test2 projection=[c1]"; let optimized_plan = optimize(plan)?; - let formatted_plan = format!("{optimized_plan:?}"); + let formatted_plan = format!("{optimized_plan}"); assert_eq!(formatted_plan, expected); // make sure schema for join node include both join columns @@ -1707,7 +1607,7 @@ mod tests { \n TableScan: test2 projection=[c1]"; let optimized_plan = optimize(plan)?; - let formatted_plan = format!("{optimized_plan:?}"); + let formatted_plan = format!("{optimized_plan}"); assert_eq!(formatted_plan, expected); // make sure schema for join node include both join columns @@ -1757,7 +1657,7 @@ mod tests { \n TableScan: test2 projection=[a]"; let optimized_plan = optimize(plan)?; - let formatted_plan = format!("{optimized_plan:?}"); + let formatted_plan = format!("{optimized_plan}"); assert_eq!(formatted_plan, expected); // make sure schema for join node include both join columns @@ -1894,11 +1794,11 @@ mod tests { .aggregate(vec![col("c")], vec![max(col("a"))])? .build()?; - assert_fields_eq(&plan, vec!["c", "MAX(test.a)"]); + assert_fields_eq(&plan, vec!["c", "max(test.a)"]); let plan = optimize(plan).expect("failed to optimize plan"); let expected = "\ - Aggregate: groupBy=[[test.c]], aggr=[[MAX(test.a)]]\ + Aggregate: groupBy=[[test.c]], aggr=[[max(test.a)]]\ \n Filter: test.c > Int32(1)\ \n Projection: test.c, test.a\ \n TableScan: test projection=[a, c]"; @@ -1988,14 +1888,14 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("a"), col("c")], vec![max(col("b")), min(col("b"))])? .filter(col("c").gt(lit(1)))? - .project(vec![col("c"), col("a"), col("MAX(test.b)")])? + .project(vec![col("c"), col("a"), col("max(test.b)")])? .build()?; - assert_fields_eq(&plan, vec!["c", "a", "MAX(test.b)"]); + assert_fields_eq(&plan, vec!["c", "a", "max(test.b)"]); - let expected = "Projection: test.c, test.a, MAX(test.b)\ + let expected = "Projection: test.c, test.a, max(test.b)\ \n Filter: test.c > Int32(1)\ - \n Aggregate: groupBy=[[test.a, test.c]], aggr=[[MAX(test.b)]]\ + \n Aggregate: groupBy=[[test.a, test.c]], aggr=[[max(test.b)]]\ \n TableScan: test projection=[a, b, c]"; assert_optimized_plan_equal(plan, expected) @@ -2004,16 +1904,10 @@ mod tests { #[test] fn aggregate_filter_pushdown() -> Result<()> { let table_scan = test_table_scan()?; - - let aggr_with_filter = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("b")], - false, - Some(Box::new(col("c").gt(lit(42)))), - None, - None, - )); - + let aggr_with_filter = count_udaf() + .call(vec![col("b")]) + .filter(col("c").gt(lit(42))) + .build()?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate( vec![col("a")], @@ -2021,7 +1915,7 @@ mod tests { )? .build()?; - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(test.b), COUNT(test.b) FILTER (WHERE test.c > Int32(42)) AS count2]]\ + let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(test.b), count(test.b) FILTER (WHERE test.c > Int32(42)) AS count2]]\ \n TableScan: test projection=[a, b, c]"; assert_optimized_plan_equal(plan, expected) @@ -2049,24 +1943,19 @@ mod tests { let table_scan = test_table_scan()?; let max1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("test.a")], - vec![col("test.b")], - vec![], - WindowFrame::new(None), - None, - )); + )) + .partition_by(vec![col("test.b")]) + .build() + .unwrap(); let max2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("test.b")], - vec![], - vec![], - WindowFrame::new(None), - None, )); - let col1 = col(max1.display_name()?); - let col2 = col(max2.display_name()?); + let col1 = col(max1.schema_name().to_string()); + let col2 = col(max2.schema_name().to_string()); let plan = LogicalPlanBuilder::from(table_scan) .window(vec![max1])? @@ -2074,10 +1963,10 @@ mod tests { .project(vec![col1, col2])? .build()?; - let expected = "Projection: MAX(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, MAX(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ - \n WindowAggr: windowExpr=[[MAX(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ - \n Projection: test.b, MAX(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ - \n WindowAggr: windowExpr=[[MAX(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ + let expected = "Projection: max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, max(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ + \n WindowAggr: windowExpr=[[max(test.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ + \n Projection: test.b, max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ + \n WindowAggr: windowExpr=[[max(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ \n TableScan: test projection=[a, b]"; assert_optimized_plan_equal(plan, expected) diff --git a/datafusion/optimizer/src/optimize_projections/required_indices.rs b/datafusion/optimizer/src/optimize_projections/required_indices.rs new file mode 100644 index 000000000000..60d8ef1a8e6c --- /dev/null +++ b/datafusion/optimizer/src/optimize_projections/required_indices.rs @@ -0,0 +1,226 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`RequiredIndicies`] helper for OptimizeProjection + +use crate::optimize_projections::outer_columns; +use datafusion_common::tree_node::TreeNodeRecursion; +use datafusion_common::{Column, DFSchemaRef, Result}; +use datafusion_expr::{Expr, LogicalPlan}; + +/// Represents columns in a schema which are required (used) by a plan node +/// +/// Also carries a flag indicating if putting a projection above children is +/// beneficial for the parent. For example `LogicalPlan::Filter` benefits from +/// small tables. Hence for filter child this flag would be `true`. Defaults to +/// `false` +/// +/// # Invariant +/// +/// Indices are always in order and without duplicates. For example, if these +/// indices were added `[3, 2, 4, 3, 6, 1]`, the instance would be represented +/// by `[1, 2, 3, 6]`. +#[derive(Debug, Clone, Default)] +pub(super) struct RequiredIndicies { + /// The indices of the required columns in the + indices: Vec, + /// If putting a projection above children is beneficial for the parent. + /// Defaults to false. + projection_beneficial: bool, +} + +impl RequiredIndicies { + /// Create a new, empty instance + pub fn new() -> Self { + Self::default() + } + + /// Create a new instance that requires all columns from the specified plan + pub fn new_for_all_exprs(plan: &LogicalPlan) -> Self { + Self { + indices: (0..plan.schema().fields().len()).collect(), + projection_beneficial: false, + } + } + + /// Create a new instance with the specified indices as required + pub fn new_from_indices(indices: Vec) -> Self { + Self { + indices, + projection_beneficial: false, + } + .compact() + } + + /// Convert the instance to its inner indices + pub fn into_inner(self) -> Vec { + self.indices + } + + /// Set the projection beneficial flag + pub fn with_projection_beneficial(mut self) -> Self { + self.projection_beneficial = true; + self + } + + /// Return the value of projection beneficial flag + pub fn projection_beneficial(&self) -> bool { + self.projection_beneficial + } + + /// Return a reference to the underlying indices + pub fn indices(&self) -> &[usize] { + &self.indices + } + + /// Add required indices for all `exprs` used in plan + pub fn with_plan_exprs( + mut self, + plan: &LogicalPlan, + schema: &DFSchemaRef, + ) -> Result { + // Add indices of the child fields referred to by the expressions in the + // parent + plan.apply_expressions(|e| { + self.add_expr(schema, e); + Ok(TreeNodeRecursion::Continue) + })?; + Ok(self.compact()) + } + + /// Adds the indices of the fields referred to by the given expression + /// `expr` within the given schema (`input_schema`). + /// + /// Self is NOT compacted (and thus this method is not pub) + /// + /// # Parameters + /// + /// * `input_schema`: The input schema to analyze for index requirements. + /// * `expr`: An expression for which we want to find necessary field indices. + fn add_expr(&mut self, input_schema: &DFSchemaRef, expr: &Expr) { + // TODO could remove these clones (and visit the expression directly) + let mut cols = expr.column_refs(); + // Get outer-referenced (subquery) columns: + outer_columns(expr, &mut cols); + self.indices.reserve(cols.len()); + for col in cols { + if let Some(idx) = input_schema.maybe_index_of_column(col) { + self.indices.push(idx); + } + } + } + + /// Adds the indices of the fields referred to by the given expressions + /// `within the given schema. + /// + /// # Parameters + /// + /// * `input_schema`: The input schema to analyze for index requirements. + /// * `exprs`: the expressions for which we want to find field indices. + pub fn with_exprs<'a>( + self, + schema: &DFSchemaRef, + exprs: impl IntoIterator, + ) -> Self { + exprs + .into_iter() + .fold(self, |mut acc, expr| { + acc.add_expr(schema, expr); + acc + }) + .compact() + } + + /// Adds all `indices` into this instance. + pub fn append(mut self, indices: &[usize]) -> Self { + self.indices.extend_from_slice(indices); + self.compact() + } + + /// Splits this instance into a tuple with two instances: + /// * The first `n` indices + /// * The remaining indices, adjusted down by n + pub fn split_off(self, n: usize) -> (Self, Self) { + let (l, r) = self.partition(|idx| idx < n); + (l, r.map_indices(|idx| idx - n)) + } + + /// Partitions the indices in this instance into two groups based on the + /// given predicate function `f`. + fn partition(&self, f: F) -> (Self, Self) + where + F: Fn(usize) -> bool, + { + let (l, r): (Vec, Vec) = + self.indices.iter().partition(|&&idx| f(idx)); + let projection_beneficial = self.projection_beneficial; + + ( + Self { + indices: l, + projection_beneficial, + }, + Self { + indices: r, + projection_beneficial, + }, + ) + } + + /// Map the indices in this instance to a new set of indices based on the + /// given function `f`, returning the mapped indices + /// + /// Not `pub` as it might not preserve the invariant of compacted indices + fn map_indices(mut self, f: F) -> Self + where + F: Fn(usize) -> usize, + { + self.indices.iter_mut().for_each(|idx| *idx = f(*idx)); + self + } + + /// Apply the given function `f` to each index in this instance, returning + /// the mapped indices + pub fn into_mapped_indices(self, f: F) -> Vec + where + F: Fn(usize) -> usize, + { + self.map_indices(f).into_inner() + } + + /// Returns the `Expr`s from `exprs` that are at the indices in this instance + pub fn get_at_indices(&self, exprs: &[Expr]) -> Vec { + self.indices.iter().map(|&idx| exprs[idx].clone()).collect() + } + + /// Generates the required expressions (columns) that reside at `indices` of + /// the given `input_schema`. + pub fn get_required_exprs(&self, input_schema: &DFSchemaRef) -> Vec { + self.indices + .iter() + .map(|&idx| Expr::from(Column::from(input_schema.qualified_field(idx)))) + .collect() + } + + /// Compacts the indices of this instance so they are sorted + /// (ascending) and deduplicated. + fn compact(mut self) -> Self { + self.indices.sort_unstable(); + self.indices.dedup(); + self + } +} diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index e787f56587f7..90a790a0e841 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -18,15 +18,17 @@ //! [`Optimizer`] and [`OptimizerRule`] use std::collections::HashSet; +use std::fmt::Debug; use std::sync::Arc; use chrono::{DateTime, Utc}; +use datafusion_expr::registry::FunctionRegistry; use log::{debug, warn}; use datafusion_common::alias::AliasGenerator; use datafusion_common::config::ConfigOptions; use datafusion_common::instant::Instant; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; +use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::{internal_err, DFSchema, DataFusionError, Result}; use datafusion_expr::logical_plan::LogicalPlan; @@ -35,6 +37,7 @@ use crate::decorrelate_predicate_subquery::DecorrelatePredicateSubquery; use crate::eliminate_cross_join::EliminateCrossJoin; use crate::eliminate_duplicated_expr::EliminateDuplicatedExpr; use crate::eliminate_filter::EliminateFilter; +use crate::eliminate_group_by_constant::EliminateGroupByConstant; use crate::eliminate_join::EliminateJoin; use crate::eliminate_limit::EliminateLimit; use crate::eliminate_nested_union::EliminateNestedUnion; @@ -48,7 +51,6 @@ use crate::propagate_empty_relation::PropagateEmptyRelation; use crate::push_down_filter::PushDownFilter; use crate::push_down_limit::PushDownLimit; use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; -use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; use crate::scalar_subquery_to_join::ScalarSubqueryToJoin; use crate::simplify_expressions::SimplifyExpressions; use crate::single_distinct_to_groupby::SingleDistinctToGroupBy; @@ -66,20 +68,26 @@ use crate::utils::log_plan; /// `OptimizerRule`s. /// /// [`AnalyzerRule`]: crate::analyzer::AnalyzerRule -/// [`SessionState::add_optimizer_rule`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionState.html#method.add_optimizer_rule +/// [`SessionState::add_optimizer_rule`]: https://docs.rs/datafusion/latest/datafusion/execution/session_state/struct.SessionState.html#method.add_optimizer_rule -pub trait OptimizerRule { +pub trait OptimizerRule: Debug { /// Try and rewrite `plan` to an optimized form, returning None if the plan /// cannot be optimized by this rule. /// /// Note this API will be deprecated in the future as it requires `clone`ing /// the input plan, which can be expensive. OptimizerRules should implement /// [`Self::rewrite`] instead. + #[deprecated( + since = "40.0.0", + note = "please implement supports_rewrite and rewrite instead" + )] fn try_optimize( &self, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, - ) -> Result>; + _plan: &LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + internal_err!("Should have called rewrite") + } /// A human readable name for this optimizer rule fn name(&self) -> &str; @@ -94,7 +102,7 @@ pub trait OptimizerRule { /// Does this rule support rewriting owned plans (rather than by reference)? fn supports_rewrite(&self) -> bool { - false + true } /// Try to rewrite `plan` to an optimized form, returning `Transformed::yes` @@ -118,9 +126,13 @@ pub trait OptimizerConfig { fn query_execution_start_time(&self) -> DateTime; /// Return alias generator used to generate unique aliases for subqueries - fn alias_generator(&self) -> Arc; + fn alias_generator(&self) -> &Arc; fn options(&self) -> &ConfigOptions; + + fn function_registry(&self) -> Option<&dyn FunctionRegistry> { + None + } } /// A standalone [`OptimizerConfig`] that can be used independently @@ -192,8 +204,8 @@ impl OptimizerConfig for OptimizerContext { self.query_execution_start_time } - fn alias_generator(&self) -> Arc { - self.alias_generator.clone() + fn alias_generator(&self) -> &Arc { + &self.alias_generator } fn options(&self) -> &ConfigOptions { @@ -202,7 +214,7 @@ impl OptimizerConfig for OptimizerContext { } /// A rule-based optimizer. -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Optimizer { /// All optimizer rules to apply pub rules: Vec>, @@ -238,11 +250,6 @@ impl Optimizer { Arc::new(DecorrelatePredicateSubquery::new()), Arc::new(ScalarSubqueryToJoin::new()), Arc::new(ExtractEquijoinPredicate::new()), - // simplify expressions does not simplify expressions in subqueries, so we - // run it again after running the optimizations that potentially converted - // subqueries to joins - Arc::new(SimplifyExpressions::new()), - Arc::new(RewriteDisjunctivePredicate::new()), Arc::new(EliminateDuplicatedExpr::new()), Arc::new(EliminateFilter::new()), Arc::new(EliminateCrossJoin::new()), @@ -262,6 +269,7 @@ impl Optimizer { Arc::new(SimplifyExpressions::new()), Arc::new(UnwrapCastInComparison::new()), Arc::new(CommonSubexprEliminate::new()), + Arc::new(EliminateGroupByConstant::new()), Arc::new(OptimizeProjections::new()), ]; @@ -325,6 +333,7 @@ fn optimize_plan_node( return rule.rewrite(plan, config); } + #[allow(deprecated)] rule.try_optimize(&plan, config).map(|maybe_plan| { match maybe_plan { Some(new_plan) => { @@ -367,15 +376,13 @@ impl Optimizer { .skip_failed_rules .then(|| new_plan.clone()); - let starting_schema = new_plan.schema().clone(); + let starting_schema = Arc::clone(new_plan.schema()); let result = match rule.apply_order() { // optimizer handles recursion - Some(apply_order) => new_plan.rewrite(&mut Rewriter::new( - apply_order, - rule.as_ref(), - config, - )), + Some(apply_order) => new_plan.rewrite_with_subqueries( + &mut Rewriter::new(apply_order, rule.as_ref(), config), + ), // rule handles recursion itself None => optimize_plan_node(new_plan, rule.as_ref(), config), } @@ -475,7 +482,8 @@ pub(crate) fn assert_schema_is_the_same( mod tests { use std::sync::{Arc, Mutex}; - use datafusion_common::{plan_err, DFSchema, DFSchemaRef, Result}; + use datafusion_common::tree_node::Transformed; + use datafusion_common::{plan_err, DFSchema, DFSchemaRef, DataFusionError, Result}; use datafusion_expr::logical_plan::EmptyRelation; use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder, Projection}; @@ -564,7 +572,7 @@ mod tests { let config = OptimizerContext::new().with_skip_failing_rules(false); let input = Arc::new(test_table_scan()?); - let input_schema = input.schema().clone(); + let input_schema = Arc::clone(input.schema()); let plan = LogicalPlan::Projection(Projection::try_new_with_schema( vec![col("a"), col("b"), col("c")], @@ -651,43 +659,56 @@ mod tests { fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} + #[derive(Default, Debug)] struct BadRule {} impl OptimizerRule for BadRule { - fn try_optimize( - &self, - _: &LogicalPlan, - _: &dyn OptimizerConfig, - ) -> Result> { - plan_err!("rule failed") - } - fn name(&self) -> &str { "bad rule" } + + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + _plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result, DataFusionError> { + plan_err!("rule failed") + } } /// Replaces whatever plan with a single table scan + #[derive(Default, Debug)] struct GetTableScanRule {} impl OptimizerRule for GetTableScanRule { - fn try_optimize( - &self, - _: &LogicalPlan, - _: &dyn OptimizerConfig, - ) -> Result> { - let table_scan = test_table_scan()?; - Ok(Some(LogicalPlanBuilder::from(table_scan).build()?)) - } - fn name(&self) -> &str { "get table_scan rule" } + + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + _plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + let table_scan = test_table_scan()?; + Ok(Transformed::yes( + LogicalPlanBuilder::from(table_scan).build()?, + )) + } } /// A goofy rule doing rotation of columns in all projections. /// /// Useful to test cycle detection. + #[derive(Default, Debug)] struct RotateProjectionRule { // reverse exprs instead of rotating on the first pass reverse_on_first_pass: Mutex, @@ -702,14 +723,26 @@ mod tests { } impl OptimizerRule for RotateProjectionRule { - fn try_optimize( + fn name(&self) -> &str { + "rotate_projection" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } + + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( &self, - plan: &LogicalPlan, - _: &dyn OptimizerConfig, - ) -> Result> { + plan: LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { let projection = match plan { LogicalPlan::Projection(p) if p.expr.len() >= 2 => p, - _ => return Ok(None), + _ => return Ok(Transformed::no(plan)), }; let mut exprs = projection.expr.clone(); @@ -722,18 +755,9 @@ mod tests { exprs.rotate_left(1); } - Ok(Some(LogicalPlan::Projection(Projection::try_new( - exprs, - projection.input.clone(), - )?))) - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } - - fn name(&self) -> &str { - "rotate_projection" + Ok(Transformed::yes(LogicalPlan::Projection( + Projection::try_new(exprs, Arc::clone(&projection.input))?, + ))) } } } diff --git a/datafusion/optimizer/src/plan_signature.rs b/datafusion/optimizer/src/plan_signature.rs index d22795797478..73e6b418272a 100644 --- a/datafusion/optimizer/src/plan_signature.rs +++ b/datafusion/optimizer/src/plan_signature.rs @@ -100,7 +100,7 @@ mod tests { let one_node_plan = Arc::new(LogicalPlan::EmptyRelation(datafusion_expr::EmptyRelation { produce_one_row: false, - schema: schema.clone(), + schema: Arc::clone(&schema), })); assert_eq!(1, get_node_number(&one_node_plan).get()); @@ -112,7 +112,7 @@ mod tests { assert_eq!(2, get_node_number(&two_node_plan).get()); let five_node_plan = Arc::new(LogicalPlan::Union(datafusion_expr::Union { - inputs: vec![two_node_plan.clone(), two_node_plan], + inputs: vec![Arc::clone(&two_node_plan), two_node_plan], schema, })); diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs index 4003acaa7d65..d26df073dc6f 100644 --- a/datafusion/optimizer/src/propagate_empty_relation.rs +++ b/datafusion/optimizer/src/propagate_empty_relation.rs @@ -16,16 +16,20 @@ // under the License. //! [`PropagateEmptyRelation`] eliminates nodes fed by `EmptyRelation` + +use std::sync::Arc; + +use datafusion_common::tree_node::Transformed; +use datafusion_common::JoinType; use datafusion_common::{plan_err, Result}; use datafusion_expr::logical_plan::LogicalPlan; -use datafusion_expr::{EmptyRelation, JoinType, Projection, Union}; -use std::sync::Arc; +use datafusion_expr::{EmptyRelation, Projection, Union}; use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; /// Optimization rule that bottom-up to eliminate plan by propagating empty_relation. -#[derive(Default)] +#[derive(Default, Debug)] pub struct PropagateEmptyRelation; impl PropagateEmptyRelation { @@ -36,13 +40,25 @@ impl PropagateEmptyRelation { } impl OptimizerRule for PropagateEmptyRelation { - fn try_optimize( + fn name(&self) -> &str { + "propagate_empty_relation" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::BottomUp) + } + + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( &self, - plan: &LogicalPlan, + plan: LogicalPlan, _config: &dyn OptimizerConfig, - ) -> Result> { + ) -> Result> { match plan { - LogicalPlan::EmptyRelation(_) => {} + LogicalPlan::EmptyRelation(_) => Ok(Transformed::no(plan)), LogicalPlan::Projection(_) | LogicalPlan::Filter(_) | LogicalPlan::Window(_) @@ -50,43 +66,88 @@ impl OptimizerRule for PropagateEmptyRelation { | LogicalPlan::SubqueryAlias(_) | LogicalPlan::Repartition(_) | LogicalPlan::Limit(_) => { - if let Some(empty) = empty_child(plan)? { - return Ok(Some(empty)); - } - } - LogicalPlan::CrossJoin(_) => { - let (left_empty, right_empty) = binary_plan_children_is_empty(plan)?; - if left_empty || right_empty { - return Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: false, - schema: plan.schema().clone(), - }))); + let empty = empty_child(&plan)?; + if let Some(empty_plan) = empty { + return Ok(Transformed::yes(empty_plan)); } + Ok(Transformed::no(plan)) } - LogicalPlan::Join(join) => { + LogicalPlan::Join(ref join) => { // TODO: For Join, more join type need to be careful: - // For LeftOuter/LeftSemi/LeftAnti Join, only the left side is empty, the Join result is empty. - // For LeftSemi Join, if the right side is empty, the Join result is empty. - // For LeftAnti Join, if the right side is empty, the Join result is left side(should exclude null ??). - // For RightOuter/RightSemi/RightAnti Join, only the right side is empty, the Join result is empty. - // For RightSemi Join, if the left side is empty, the Join result is empty. - // For RightAnti Join, if the left side is empty, the Join result is right side(should exclude null ??). - // For Full Join, only both sides are empty, the Join result is empty. // For LeftOut/Full Join, if the right side is empty, the Join can be eliminated with a Projection with left side // columns + right side columns replaced with null values. // For RightOut/Full Join, if the left side is empty, the Join can be eliminated with a Projection with right side // columns + left side columns replaced with null values. - if join.join_type == JoinType::Inner { - let (left_empty, right_empty) = binary_plan_children_is_empty(plan)?; - if left_empty || right_empty { - return Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation { + let (left_empty, right_empty) = binary_plan_children_is_empty(&plan)?; + + match join.join_type { + // For Full Join, only both sides are empty, the Join result is empty. + JoinType::Full if left_empty && right_empty => Ok(Transformed::yes( + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::clone(&join.schema), + }), + )), + JoinType::Inner if left_empty || right_empty => Ok(Transformed::yes( + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::clone(&join.schema), + }), + )), + JoinType::Left if left_empty => Ok(Transformed::yes( + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::clone(&join.schema), + }), + )), + JoinType::Right if right_empty => Ok(Transformed::yes( + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::clone(&join.schema), + }), + )), + JoinType::LeftSemi if left_empty || right_empty => Ok( + Transformed::yes(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::clone(&join.schema), + })), + ), + JoinType::RightSemi if left_empty || right_empty => Ok( + Transformed::yes(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::clone(&join.schema), + })), + ), + JoinType::LeftAnti if left_empty => Ok(Transformed::yes( + LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, - schema: plan.schema().clone(), - }))); + schema: Arc::clone(&join.schema), + }), + )), + JoinType::LeftAnti if right_empty => { + Ok(Transformed::yes((*join.left).clone())) + } + JoinType::RightAnti if left_empty => { + Ok(Transformed::yes((*join.right).clone())) + } + JoinType::RightAnti if right_empty => Ok(Transformed::yes( + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::clone(&join.schema), + }), + )), + _ => Ok(Transformed::no(plan)), + } + } + LogicalPlan::Aggregate(ref agg) => { + if !agg.group_expr.is_empty() { + if let Some(empty_plan) = empty_child(&plan)? { + return Ok(Transformed::yes(empty_plan)); } } + Ok(Transformed::no(LogicalPlan::Aggregate(agg.clone()))) } - LogicalPlan::Union(union) => { + LogicalPlan::Union(ref union) => { let new_inputs = union .inputs .iter() @@ -98,49 +159,38 @@ impl OptimizerRule for PropagateEmptyRelation { .collect::>(); if new_inputs.len() == union.inputs.len() { - return Ok(None); + Ok(Transformed::no(plan)) } else if new_inputs.is_empty() { - return Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: false, - schema: plan.schema().clone(), - }))); + Ok(Transformed::yes(LogicalPlan::EmptyRelation( + EmptyRelation { + produce_one_row: false, + schema: Arc::clone(plan.schema()), + }, + ))) } else if new_inputs.len() == 1 { - let child = (*new_inputs[0]).clone(); + let mut new_inputs = new_inputs; + let input_plan = new_inputs.pop().unwrap(); // length checked + let child = Arc::unwrap_or_clone(input_plan); if child.schema().eq(plan.schema()) { - return Ok(Some(child)); + Ok(Transformed::yes(child)) } else { - return Ok(Some(LogicalPlan::Projection( + Ok(Transformed::yes(LogicalPlan::Projection( Projection::new_from_schema( Arc::new(child), - plan.schema().clone(), + Arc::clone(plan.schema()), ), - ))); + ))) } } else { - return Ok(Some(LogicalPlan::Union(Union { + Ok(Transformed::yes(LogicalPlan::Union(Union { inputs: new_inputs, - schema: union.schema.clone(), - }))); - } - } - LogicalPlan::Aggregate(agg) => { - if !agg.group_expr.is_empty() { - if let Some(empty) = empty_child(plan)? { - return Ok(Some(empty)); - } + schema: Arc::clone(&union.schema), + }))) } } - _ => {} - } - Ok(None) - } - - fn name(&self) -> &str { - "propagate_empty_relation" - } - fn apply_order(&self) -> Option { - Some(ApplyOrder::BottomUp) + _ => Ok(Transformed::no(plan)), + } } } @@ -168,7 +218,7 @@ fn empty_child(plan: &LogicalPlan) -> Result> { if !empty.produce_one_row { Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, - schema: plan.schema().clone(), + schema: Arc::clone(plan.schema()), }))) } else { Ok(None) @@ -182,18 +232,22 @@ fn empty_child(plan: &LogicalPlan) -> Result> { #[cfg(test)] mod tests { + use std::sync::Arc; + + use arrow::datatypes::{DataType, Field, Schema}; + + use datafusion_common::{Column, DFSchema, JoinType}; + use datafusion_expr::logical_plan::table_scan; + use datafusion_expr::{ + binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, Operator, + }; + use crate::eliminate_filter::EliminateFilter; use crate::eliminate_nested_union::EliminateNestedUnion; use crate::test::{ - assert_optimized_plan_eq, assert_optimized_plan_eq_with_rules, test_table_scan, + assert_optimized_plan_eq, assert_optimized_plan_with_rules, test_table_scan, test_table_scan_fields, test_table_scan_with_name, }; - use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::{Column, DFSchema, ScalarValue}; - use datafusion_expr::logical_plan::table_scan; - use datafusion_expr::{ - binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, Expr, Operator, - }; use super::*; @@ -201,11 +255,12 @@ mod tests { assert_optimized_plan_eq(Arc::new(PropagateEmptyRelation::new()), plan, expected) } - fn assert_together_optimized_plan_eq( + fn assert_together_optimized_plan( plan: LogicalPlan, expected: &str, + eq: bool, ) -> Result<()> { - assert_optimized_plan_eq_with_rules( + assert_optimized_plan_with_rules( vec![ Arc::new(EliminateFilter::new()), Arc::new(EliminateNestedUnion::new()), @@ -213,13 +268,14 @@ mod tests { ], plan, expected, + eq, ) } #[test] fn propagate_empty() -> Result<()> { let plan = LogicalPlanBuilder::empty(false) - .filter(Expr::Literal(ScalarValue::Boolean(Some(true))))? + .filter(lit(true))? .limit(10, None)? .project(vec![binary_expr(lit(1), Operator::Plus, lit(1))])? .build()?; @@ -235,7 +291,7 @@ mod tests { let right_table_scan = test_table_scan_with_name("test2")?; let right = LogicalPlanBuilder::from(right_table_scan) .project(vec![col("a")])? - .filter(Expr::Literal(ScalarValue::Boolean(Some(false))))? + .filter(lit(false))? .build()?; let plan = LogicalPlanBuilder::from(left) @@ -248,20 +304,20 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_together_optimized_plan_eq(plan, expected) + assert_together_optimized_plan(plan, expected, true) } #[test] fn propagate_union_empty() -> Result<()> { let left = LogicalPlanBuilder::from(test_table_scan()?).build()?; let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?) - .filter(Expr::Literal(ScalarValue::Boolean(Some(false))))? + .filter(lit(false))? .build()?; let plan = LogicalPlanBuilder::from(left).union(right)?.build()?; let expected = "TableScan: test"; - assert_together_optimized_plan_eq(plan, expected) + assert_together_optimized_plan(plan, expected, true) } #[test] @@ -269,10 +325,10 @@ mod tests { let one = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?).build()?; let two = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?) - .filter(Expr::Literal(ScalarValue::Boolean(Some(false))))? + .filter(lit(false))? .build()?; let three = LogicalPlanBuilder::from(test_table_scan_with_name("test3")?) - .filter(Expr::Literal(ScalarValue::Boolean(Some(false))))? + .filter(lit(false))? .build()?; let four = LogicalPlanBuilder::from(test_table_scan_with_name("test4")?).build()?; @@ -286,22 +342,22 @@ mod tests { let expected = "Union\ \n TableScan: test1\ \n TableScan: test4"; - assert_together_optimized_plan_eq(plan, expected) + assert_together_optimized_plan(plan, expected, true) } #[test] fn propagate_union_all_empty() -> Result<()> { let one = LogicalPlanBuilder::from(test_table_scan_with_name("test1")?) - .filter(Expr::Literal(ScalarValue::Boolean(Some(false))))? + .filter(lit(false))? .build()?; let two = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?) - .filter(Expr::Literal(ScalarValue::Boolean(Some(false))))? + .filter(lit(false))? .build()?; let three = LogicalPlanBuilder::from(test_table_scan_with_name("test3")?) - .filter(Expr::Literal(ScalarValue::Boolean(Some(false))))? + .filter(lit(false))? .build()?; let four = LogicalPlanBuilder::from(test_table_scan_with_name("test4")?) - .filter(Expr::Literal(ScalarValue::Boolean(Some(false))))? + .filter(lit(false))? .build()?; let plan = LogicalPlanBuilder::from(one) @@ -311,7 +367,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_together_optimized_plan_eq(plan, expected) + assert_together_optimized_plan(plan, expected, true) } #[test] @@ -319,7 +375,7 @@ mod tests { let one_schema = Schema::new(vec![Field::new("t1a", DataType::UInt32, false)]); let t1_scan = table_scan(Some("test1"), &one_schema, None)?.build()?; let one = LogicalPlanBuilder::from(t1_scan) - .filter(Expr::Literal(ScalarValue::Boolean(Some(false))))? + .filter(lit(false))? .build()?; let two_schema = Schema::new(vec![Field::new("t2a", DataType::UInt32, false)]); @@ -338,20 +394,20 @@ mod tests { let expected = "Union\ \n TableScan: test2\ \n TableScan: test3"; - assert_together_optimized_plan_eq(plan, expected) + assert_together_optimized_plan(plan, expected, true) } #[test] fn propagate_union_alias() -> Result<()> { let left = LogicalPlanBuilder::from(test_table_scan()?).build()?; let right = LogicalPlanBuilder::from(test_table_scan_with_name("test2")?) - .filter(Expr::Literal(ScalarValue::Boolean(Some(false))))? + .filter(lit(false))? .build()?; let plan = LogicalPlanBuilder::from(left).union(right)?.build()?; let expected = "TableScan: test"; - assert_together_optimized_plan_eq(plan, expected) + assert_together_optimized_plan(plan, expected, true) } #[test] @@ -366,7 +422,140 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_together_optimized_plan_eq(plan, expected) + assert_together_optimized_plan(plan, expected, true) + } + + fn assert_empty_left_empty_right_lp( + left_empty: bool, + right_empty: bool, + join_type: JoinType, + eq: bool, + ) -> Result<()> { + let left_lp = if left_empty { + let left_table_scan = test_table_scan()?; + + LogicalPlanBuilder::from(left_table_scan) + .filter(lit(false))? + .build() + } else { + let scan = test_table_scan_with_name("left").unwrap(); + LogicalPlanBuilder::from(scan).build() + }?; + + let right_lp = if right_empty { + let right_table_scan = test_table_scan_with_name("right")?; + + LogicalPlanBuilder::from(right_table_scan) + .filter(lit(false))? + .build() + } else { + let scan = test_table_scan_with_name("right").unwrap(); + LogicalPlanBuilder::from(scan).build() + }?; + + let plan = LogicalPlanBuilder::from(left_lp) + .join_using( + right_lp, + join_type, + vec![Column::from_name("a".to_string())], + )? + .build()?; + + let expected = "EmptyRelation"; + assert_together_optimized_plan(plan, expected, eq) + } + + // TODO: fix this long name + fn assert_anti_join_empty_join_table_is_base_table( + anti_left_join: bool, + ) -> Result<()> { + // if we have an anti join with an empty join table, then the result is the base_table + let (left, right, join_type, expected) = if anti_left_join { + let left = test_table_scan()?; + let right = LogicalPlanBuilder::from(test_table_scan()?) + .filter(lit(false))? + .build()?; + let expected = left.display_indent().to_string(); + (left, right, JoinType::LeftAnti, expected) + } else { + let right = test_table_scan()?; + let left = LogicalPlanBuilder::from(test_table_scan()?) + .filter(lit(false))? + .build()?; + let expected = right.display_indent().to_string(); + (left, right, JoinType::RightAnti, expected) + }; + + let plan = LogicalPlanBuilder::from(left) + .join_using(right, join_type, vec![Column::from_name("a".to_string())])? + .build()?; + + assert_together_optimized_plan(plan, &expected, true) + } + + #[test] + fn test_join_empty_propagation_rules() -> Result<()> { + // test full join with empty left and empty right + assert_empty_left_empty_right_lp(true, true, JoinType::Full, true)?; + + // test left join with empty left + assert_empty_left_empty_right_lp(true, false, JoinType::Left, true)?; + + // test right join with empty right + assert_empty_left_empty_right_lp(false, true, JoinType::Right, true)?; + + // test left semi join with empty left + assert_empty_left_empty_right_lp(true, false, JoinType::LeftSemi, true)?; + + // test left semi join with empty right + assert_empty_left_empty_right_lp(false, true, JoinType::LeftSemi, true)?; + + // test right semi join with empty left + assert_empty_left_empty_right_lp(true, false, JoinType::RightSemi, true)?; + + // test right semi join with empty right + assert_empty_left_empty_right_lp(false, true, JoinType::RightSemi, true)?; + + // test left anti join empty left + assert_empty_left_empty_right_lp(true, false, JoinType::LeftAnti, true)?; + + // test right anti join empty right + assert_empty_left_empty_right_lp(false, true, JoinType::RightAnti, true)?; + + // test left anti join empty right + assert_anti_join_empty_join_table_is_base_table(true)?; + + // test right anti join empty left + assert_anti_join_empty_join_table_is_base_table(false) + } + + #[test] + fn test_join_empty_propagation_rules_noop() -> Result<()> { + // these cases should not result in an empty relation + + // test left join with empty right + assert_empty_left_empty_right_lp(false, true, JoinType::Left, false)?; + + // test right join with empty left + assert_empty_left_empty_right_lp(true, false, JoinType::Right, false)?; + + // test left semi with non-empty left and right + assert_empty_left_empty_right_lp(false, false, JoinType::LeftSemi, false)?; + + // test right semi with non-empty left and right + assert_empty_left_empty_right_lp(false, false, JoinType::RightSemi, false)?; + + // test left anti join with non-empty left and right + assert_empty_left_empty_right_lp(false, false, JoinType::LeftAnti, false)?; + + // test left anti with non-empty left and empty right + assert_empty_left_empty_right_lp(false, true, JoinType::LeftAnti, false)?; + + // test right anti join with non-empty left and right + assert_empty_left_empty_right_lp(false, false, JoinType::RightAnti, false)?; + + // test right anti with empty left and non-empty right + assert_empty_left_empty_right_lp(true, false, JoinType::RightAnti, false) } #[test] @@ -377,7 +566,7 @@ mod tests { let empty = LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, - schema: Arc::new(DFSchema::from_unqualifed_fields( + schema: Arc::new(DFSchema::from_unqualified_fields( fields.into(), Default::default(), )?), @@ -399,6 +588,6 @@ mod tests { let expected = "Projection: a, b, c\ \n TableScan: test"; - assert_together_optimized_plan_eq(plan, expected) + assert_together_optimized_plan(plan, expected, true) } } diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 0572dc5ea4f1..acb7ba0fa757 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -14,31 +14,30 @@ //! [`PushDownFilter`] applies filters as early as possible +use indexmap::IndexSet; use std::collections::{HashMap, HashSet}; use std::sync::Arc; -use crate::optimizer::ApplyOrder; -use crate::{OptimizerConfig, OptimizerRule}; +use itertools::Itertools; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; use datafusion_common::{ - internal_err, plan_datafusion_err, qualified_name, Column, DFSchema, DFSchemaRef, - JoinConstraint, Result, + internal_err, plan_err, qualified_name, Column, DFSchema, Result, }; -use datafusion_expr::expr::Alias; use datafusion_expr::expr_rewriter::replace_col; -use datafusion_expr::logical_plan::{ - CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union, +use datafusion_expr::logical_plan::{Join, JoinType, LogicalPlan, TableScan, Union}; +use datafusion_expr::utils::{ + conjunction, expr_to_columns, split_conjunction, split_conjunction_owned, }; -use datafusion_expr::utils::{conjunction, split_conjunction, split_conjunction_owned}; use datafusion_expr::{ - and, build_join_schema, or, BinaryExpr, Expr, Filter, LogicalPlanBuilder, Operator, - ScalarFunctionDefinition, TableProviderFilterPushDown, + and, or, BinaryExpr, Expr, Filter, Operator, Projection, TableProviderFilterPushDown, }; -use itertools::Itertools; +use crate::optimizer::ApplyOrder; +use crate::utils::{has_all_column_refs, is_restrict_null_predicate}; +use crate::{OptimizerConfig, OptimizerRule}; /// Optimizer rule for pushing (moving) filter expressions down in a plan so /// they are applied as early as possible. @@ -79,8 +78,8 @@ use itertools::Itertools; /// satisfies `filter(op(data)) = op(filter(data))`. /// /// The filter-commutative property is plan and column-specific. A filter on `a` -/// can be pushed through a `Aggregate(group_by = [a], agg=[SUM(b))`. However, a -/// filter on `SUM(b)` can not be pushed through the same aggregate. +/// can be pushed through a `Aggregate(group_by = [a], agg=[sum(b))`. However, a +/// filter on `sum(b)` can not be pushed through the same aggregate. /// /// # Handling Conjunctions /// @@ -90,16 +89,16 @@ use itertools::Itertools; /// For example, given the following plan: /// /// ```text -/// Filter(a > 10 AND SUM(b) < 5) -/// Aggregate(group_by = [a], agg = [SUM(b)) +/// Filter(a > 10 AND sum(b) < 5) +/// Aggregate(group_by = [a], agg = [sum(b)) /// ``` /// -/// The `a > 10` is commutative with the `Aggregate` but `SUM(b) < 5` is not. +/// The `a > 10` is commutative with the `Aggregate` but `sum(b) < 5` is not. /// Therefore it is possible to only push part of the expression, resulting in: /// /// ```text -/// Filter(SUM(b) < 5) -/// Aggregate(group_by = [a], agg = [SUM(b)) +/// Filter(sum(b) < 5) +/// Aggregate(group_by = [a], agg = [sum(b)) /// Filter(a > 10) /// ``` /// @@ -127,75 +126,114 @@ use itertools::Itertools; /// reaches a plan node that does not commute with that filter, it adds the /// filter to that place. When it passes through a projection, it re-writes the /// filter's expression taking into account that projection. -#[derive(Default)] +#[derive(Default, Debug)] pub struct PushDownFilter {} -// For a given JOIN logical plan, determine whether each side of the join is preserved. -// We say a join side is preserved if the join returns all or a subset of the rows from -// the relevant side, such that each row of the output table directly maps to a row of -// the preserved input table. If a table is not preserved, it can provide extra null rows. -// That is, there may be rows in the output table that don't directly map to a row in the -// input table. -// -// For example: -// - In an inner join, both sides are preserved, because each row of the output -// maps directly to a row from each side. -// - In a left join, the left side is preserved and the right is not, because -// there may be rows in the output that don't directly map to a row in the -// right input (due to nulls filling where there is no match on the right). -// -// This is important because we can always push down post-join filters to a preserved -// side of the join, assuming the filter only references columns from that side. For the -// non-preserved side it can be more tricky. -// -// Returns a tuple of booleans - (left_preserved, right_preserved). -fn lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> { - match plan { - LogicalPlan::Join(Join { join_type, .. }) => match join_type { - JoinType::Inner => Ok((true, true)), - JoinType::Left => Ok((true, false)), - JoinType::Right => Ok((false, true)), - JoinType::Full => Ok((false, false)), - // No columns from the right side of the join can be referenced in output - // predicates for semi/anti joins, so whether we specify t/f doesn't matter. - JoinType::LeftSemi | JoinType::LeftAnti => Ok((true, false)), - // No columns from the left side of the join can be referenced in output - // predicates for semi/anti joins, so whether we specify t/f doesn't matter. - JoinType::RightSemi | JoinType::RightAnti => Ok((false, true)), - }, - LogicalPlan::CrossJoin(_) => Ok((true, true)), - _ => internal_err!("lr_is_preserved only valid for JOIN nodes"), +/// For a given JOIN type, determine whether each input of the join is preserved +/// for post-join (`WHERE` clause) filters. +/// +/// It is only correct to push filters below a join for preserved inputs. +/// +/// # Return Value +/// A tuple of booleans - (left_preserved, right_preserved). +/// +/// # "Preserved" input definition +/// +/// We say a join side is preserved if the join returns all or a subset of the rows from +/// the relevant side, such that each row of the output table directly maps to a row of +/// the preserved input table. If a table is not preserved, it can provide extra null rows. +/// That is, there may be rows in the output table that don't directly map to a row in the +/// input table. +/// +/// For example: +/// - In an inner join, both sides are preserved, because each row of the output +/// maps directly to a row from each side. +/// +/// - In a left join, the left side is preserved (we can push predicates) but +/// the right is not, because there may be rows in the output that don't +/// directly map to a row in the right input (due to nulls filling where there +/// is no match on the right). +pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) { + match join_type { + JoinType::Inner => (true, true), + JoinType::Left => (true, false), + JoinType::Right => (false, true), + JoinType::Full => (false, false), + // No columns from the right side of the join can be referenced in output + // predicates for semi/anti joins, so whether we specify t/f doesn't matter. + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => (true, false), + // No columns from the left side of the join can be referenced in output + // predicates for semi/anti joins, so whether we specify t/f doesn't matter. + JoinType::RightSemi | JoinType::RightAnti => (false, true), + } +} + +/// For a given JOIN type, determine whether each input of the join is preserved +/// for the join condition (`ON` clause filters). +/// +/// It is only correct to push filters below a join for preserved inputs. +/// +/// # Return Value +/// A tuple of booleans - (left_preserved, right_preserved). +/// +/// See [`lr_is_preserved`] for a definition of "preserved". +pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) { + match join_type { + JoinType::Inner => (true, true), + JoinType::Left => (false, true), + JoinType::Right => (true, false), + JoinType::Full => (false, false), + JoinType::LeftSemi | JoinType::RightSemi => (true, true), + JoinType::LeftAnti => (false, true), + JoinType::RightAnti => (true, false), + JoinType::LeftMark => (false, true), } } -// For a given JOIN logical plan, determine whether each side of the join is preserved -// in terms on join filtering. -// Predicates from join filter can only be pushed to preserved join side. -fn on_lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> { - match plan { - LogicalPlan::Join(Join { join_type, .. }) => match join_type { - JoinType::Inner => Ok((true, true)), - JoinType::Left => Ok((false, true)), - JoinType::Right => Ok((true, false)), - JoinType::Full => Ok((false, false)), - JoinType::LeftSemi | JoinType::RightSemi => Ok((true, true)), - JoinType::LeftAnti => Ok((false, true)), - JoinType::RightAnti => Ok((true, false)), - }, - LogicalPlan::CrossJoin(_) => { - internal_err!("on_lr_is_preserved cannot be applied to CROSSJOIN nodes") +/// Evaluates the columns referenced in the given expression to see if they refer +/// only to the left or right columns +#[derive(Debug)] +struct ColumnChecker<'a> { + /// schema of left join input + left_schema: &'a DFSchema, + /// columns in left_schema, computed on demand + left_columns: Option>, + /// schema of right join input + right_schema: &'a DFSchema, + /// columns in left_schema, computed on demand + right_columns: Option>, +} + +impl<'a> ColumnChecker<'a> { + fn new(left_schema: &'a DFSchema, right_schema: &'a DFSchema) -> Self { + Self { + left_schema, + left_columns: None, + right_schema, + right_columns: None, + } + } + + /// Return true if the expression references only columns from the left side of the join + fn is_left_only(&mut self, predicate: &Expr) -> bool { + if self.left_columns.is_none() { + self.left_columns = Some(schema_columns(self.left_schema)); + } + has_all_column_refs(predicate, self.left_columns.as_ref().unwrap()) + } + + /// Return true if the expression references only columns from the right side of the join + fn is_right_only(&mut self, predicate: &Expr) -> bool { + if self.right_columns.is_none() { + self.right_columns = Some(schema_columns(self.right_schema)); } - _ => internal_err!("on_lr_is_preserved only valid for JOIN nodes"), + has_all_column_refs(predicate, self.right_columns.as_ref().unwrap()) } } -// Determine which predicates in state can be pushed down to a given side of a join. -// To determine this, we need to know the schema of the relevant join side and whether -// or not the side's rows are preserved when joining. If the side is not preserved, we -// do not push down anything. Otherwise we can push down predicates where all of the -// relevant columns are contained on the relevant join side's schema. -fn can_pushdown_join_predicate(predicate: &Expr, schema: &DFSchema) -> Result { - let schema_columns = schema +/// Returns all columns in the schema +fn schema_columns(schema: &DFSchema) -> HashSet { + schema .iter() .flat_map(|(qualifier, field)| { [ @@ -204,17 +242,10 @@ fn can_pushdown_join_predicate(predicate: &Expr, schema: &DFSchema) -> Result>(); - let columns = predicate.to_columns()?; - - Ok(schema_columns - .intersection(&columns) .collect::>() - .len() - == columns.len()) } -// Determine whether the predicate can evaluate as the join conditions +/// Determine whether the predicate can evaluate as the join conditions fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { let mut is_evaluate = true; predicate.apply(|expr| match expr { @@ -226,11 +257,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::InSubquery(_) | Expr::ScalarSubquery(_) | Expr::OuterReferenceColumn(_, _) - | Expr::Unnest(_) - | Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(_), - .. - }) => { + | Expr::Unnest(_) => { is_evaluate = false; Ok(TreeNodeRecursion::Stop) } @@ -248,15 +275,13 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::IsNotFalse(_) | Expr::IsNotUnknown(_) | Expr::Negative(_) - | Expr::GetIndexedField(_) | Expr::Between(_) | Expr::Case(_) | Expr::Cast(_) | Expr::TryCast(_) - | Expr::ScalarFunction(..) - | Expr::InList { .. } => Ok(TreeNodeRecursion::Continue), - Expr::Sort(_) - | Expr::AggregateFunction(_) + | Expr::InList { .. } + | Expr::ScalarFunction(_) => Ok(TreeNodeRecursion::Continue), + Expr::AggregateFunction(_) | Expr::WindowFunction(_) | Expr::Wildcard { .. } | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"), @@ -264,53 +289,44 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { Ok(is_evaluate) } -// examine OR clause to see if any useful clauses can be extracted and push down. -// extract at least one qual from each sub clauses of OR clause, then form the quals -// to new OR clause as predicate. -// -// Filter: (a = c and a < 20) or (b = d and b > 10) -// join/crossjoin: -// TableScan: projection=[a, b] -// TableScan: projection=[c, d] -// -// is optimized to -// -// Filter: (a = c and a < 20) or (b = d and b > 10) -// join/crossjoin: -// Filter: (a < 20) or (b > 10) -// TableScan: projection=[a, b] -// TableScan: projection=[c, d] -// -// In general, predicates of this form: -// -// (A AND B) OR (C AND D) -// -// will be transformed to -// -// ((A AND B) OR (C AND D)) AND (A OR C) -// -// OR -// -// ((A AND B) OR (C AND D)) AND ((A AND B) OR C) -// -// OR -// -// do nothing. -// +/// examine OR clause to see if any useful clauses can be extracted and push down. +/// extract at least one qual from each sub clauses of OR clause, then form the quals +/// to new OR clause as predicate. +/// +/// # Example +/// ```text +/// Filter: (a = c and a < 20) or (b = d and b > 10) +/// join/crossjoin: +/// TableScan: projection=[a, b] +/// TableScan: projection=[c, d] +/// ``` +/// +/// is optimized to +/// +/// ```text +/// Filter: (a = c and a < 20) or (b = d and b > 10) +/// join/crossjoin: +/// Filter: (a < 20) or (b > 10) +/// TableScan: projection=[a, b] +/// TableScan: projection=[c, d] +/// ``` +/// +/// In general, predicates of this form: +/// +/// ```sql +/// (A AND B) OR (C AND D) +/// ``` +/// +/// will be transformed to one of: +/// +/// * `((A AND B) OR (C AND D)) AND (A OR C)` +/// * `((A AND B) OR (C AND D)) AND ((A AND B) OR C)` +/// * do nothing. fn extract_or_clauses_for_join<'a>( filters: &'a [Expr], schema: &'a DFSchema, ) -> impl Iterator + 'a { - let schema_columns = schema - .iter() - .flat_map(|(qualifier, field)| { - [ - Column::new(qualifier.cloned(), field.name()), - // we need to push down filter using unqualified column as well - Column::new_unqualified(field.name()), - ] - }) - .collect::>(); + let schema_columns = schema_columns(schema); // new formed OR clauses and their column references filters.iter().filter_map(move |expr| { @@ -332,17 +348,17 @@ fn extract_or_clauses_for_join<'a>( }) } -// extract qual from OR sub-clause. -// -// A qual is extracted if it only contains set of column references in schema_columns. -// -// For AND clause, we extract from both sub-clauses, then make new AND clause by extracted -// clauses if both extracted; Otherwise, use the extracted clause from any sub-clauses or None. -// -// For OR clause, we extract from both sub-clauses, then make new OR clause by extracted clauses if both extracted; -// Otherwise, return None. -// -// For other clause, apply the rule above to extract clause. +/// extract qual from OR sub-clause. +/// +/// A qual is extracted if it only contains set of column references in schema_columns. +/// +/// For AND clause, we extract from both sub-clauses, then make new AND clause by extracted +/// clauses if both extracted; Otherwise, use the extracted clause from any sub-clauses or None. +/// +/// For OR clause, we extract from both sub-clauses, then make new OR clause by extracted clauses if both extracted; +/// Otherwise, return None. +/// +/// For other clause, apply the rule above to extract clause. fn extract_or_clause(expr: &Expr, schema_columns: &HashSet) -> Option { let mut predicate = None; @@ -383,14 +399,7 @@ fn extract_or_clause(expr: &Expr, schema_columns: &HashSet) -> Option { - let columns = expr.to_columns().ok().unwrap(); - - if schema_columns - .intersection(&columns) - .collect::>() - .len() - == columns.len() - { + if has_all_column_refs(expr, schema_columns) { predicate = Some(expr.clone()); } } @@ -399,36 +408,32 @@ fn extract_or_clause(expr: &Expr, schema_columns: &HashSet) -> Option, - infer_predicates: Vec, - join_plan: &LogicalPlan, - left: &LogicalPlan, - right: &LogicalPlan, + inferred_join_predicates: Vec, + mut join: Join, on_filter: Vec, - is_inner_join: bool, -) -> Result { - let on_filter_empty = on_filter.is_empty(); +) -> Result> { + let is_inner_join = join.join_type == JoinType::Inner; // Get pushable predicates from current optimizer state - let (left_preserved, right_preserved) = lr_is_preserved(join_plan)?; + let (left_preserved, right_preserved) = lr_is_preserved(join.join_type); // The predicates can be divided to three categories: // 1) can push through join to its children(left or right) // 2) can be converted to join conditions if the join type is Inner // 3) should be kept as filter conditions - let left_schema = left.schema(); - let right_schema = right.schema(); + let left_schema = join.left.schema(); + let right_schema = join.right.schema(); let mut left_push = vec![]; let mut right_push = vec![]; let mut keep_predicates = vec![]; let mut join_conditions = vec![]; + let mut checker = ColumnChecker::new(left_schema, right_schema); for predicate in predicates { - if left_preserved && can_pushdown_join_predicate(&predicate, left_schema)? { + if left_preserved && checker.is_left_only(&predicate) { left_push.push(predicate); - } else if right_preserved - && can_pushdown_join_predicate(&predicate, right_schema)? - { + } else if right_preserved && checker.is_right_only(&predicate) { right_push.push(predicate); } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? { // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate @@ -440,27 +445,25 @@ fn push_down_all_join( } // For infer predicates, if they can not push through join, just drop them - for predicate in infer_predicates { - if left_preserved && can_pushdown_join_predicate(&predicate, left_schema)? { + for predicate in inferred_join_predicates { + if left_preserved && checker.is_left_only(&predicate) { left_push.push(predicate); - } else if right_preserved - && can_pushdown_join_predicate(&predicate, right_schema)? - { + } else if right_preserved && checker.is_right_only(&predicate) { right_push.push(predicate); } } + let mut on_filter_join_conditions = vec![]; + let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type); + if !on_filter.is_empty() { - let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join_plan)?; for on in on_filter { - if on_left_preserved && can_pushdown_join_predicate(&on, left_schema)? { + if on_left_preserved && checker.is_left_only(&on) { left_push.push(on) - } else if on_right_preserved - && can_pushdown_join_predicate(&on, right_schema)? - { + } else if on_right_preserved && checker.is_right_only(&on) { right_push.push(on) } else { - join_conditions.push(on) + on_filter_join_conditions.push(on) } } } @@ -476,138 +479,269 @@ fn push_down_all_join( right_push.extend(extract_or_clauses_for_join(&join_conditions, right_schema)); } - let left = match conjunction(left_push) { - Some(predicate) => { - LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(left.clone()))?) - } - None => left.clone(), - }; - let right = match conjunction(right_push) { - Some(predicate) => { - LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(right.clone()))?) - } - None => right.clone(), - }; - // Create a new Join with the new `left` and `right` - // - // expressions() output for Join is a vector consisting of - // 1. join keys - columns mentioned in ON clause - // 2. optional predicate - in case join filter is not empty, - // it always will be the last element, otherwise result - // vector will contain only join keys (without additional - // element representing filter). - let mut exprs = join_plan.expressions(); - if !on_filter_empty { - exprs.pop(); - } - exprs.extend(join_conditions.into_iter().reduce(Expr::and)); - let plan = join_plan.with_new_exprs(exprs, vec![left, right])?; - - // wrap the join on the filter whose predicates must be kept - match conjunction(keep_predicates) { - Some(predicate) => { - Filter::try_new(predicate, Arc::new(plan)).map(LogicalPlan::Filter) - } - None => Ok(plan), + // For predicates from join filter, we should check with if a join side is preserved + // in term of join filtering. + if on_left_preserved { + left_push.extend(extract_or_clauses_for_join( + &on_filter_join_conditions, + left_schema, + )); + } + if on_right_preserved { + right_push.extend(extract_or_clauses_for_join( + &on_filter_join_conditions, + right_schema, + )); } + + if let Some(predicate) = conjunction(left_push) { + join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?)); + } + if let Some(predicate) = conjunction(right_push) { + join.right = + Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.right)?)); + } + + // Add any new join conditions as the non join predicates + join_conditions.extend(on_filter_join_conditions); + join.filter = conjunction(join_conditions); + + // wrap the join on the filter whose predicates must be kept, if any + let plan = LogicalPlan::Join(join); + let plan = if let Some(predicate) = conjunction(keep_predicates) { + LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?) + } else { + plan + }; + Ok(Transformed::yes(plan)) } fn push_down_join( - plan: &LogicalPlan, - join: &Join, + join: Join, parent_predicate: Option<&Expr>, -) -> Result> { - let predicates = match parent_predicate { - Some(parent_predicate) => split_conjunction_owned(parent_predicate.clone()), - None => vec![], - }; +) -> Result> { + // Split the parent predicate into individual conjunctive parts. + let predicates = parent_predicate + .map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone())); - // Convert JOIN ON predicate to Predicates + // Extract conjunctions from the JOIN's ON filter, if present. let on_filters = join .filter .as_ref() - .map(|e| split_conjunction_owned(e.clone())) - .unwrap_or_default(); - - let mut is_inner_join = false; - let infer_predicates = if join.join_type == JoinType::Inner { - is_inner_join = true; - // Only allow both side key is column. - let join_col_keys = join - .on - .iter() - .flat_map(|(l, r)| match (l.try_into_col(), r.try_into_col()) { - (Ok(l_col), Ok(r_col)) => Some((l_col, r_col)), - _ => None, - }) - .collect::>(); - // TODO refine the logic, introduce EquivalenceProperties to logical plan and infer additional filters to push down - // For inner joins, duplicate filters for joined columns so filters can be pushed down - // to both sides. Take the following query as an example: - // - // ```sql - // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1 - // ``` - // - // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while - // `t2.uid > 1` predicate needs to be pushed down to t2 table scan. - // - // Join clauses with `Using` constraints also take advantage of this logic to make sure - // predicates reference the shared join columns are pushed to both sides. - // This logic should also been applied to conditions in JOIN ON clause - predicates - .iter() - .chain(on_filters.iter()) - .filter_map(|predicate| { - let mut join_cols_to_replace = HashMap::new(); - let columns = match predicate.to_columns() { - Ok(columns) => columns, - Err(e) => return Some(Err(e)), - }; + .map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone())); - for col in columns.iter() { - for (l, r) in join_col_keys.iter() { - if col == l { - join_cols_to_replace.insert(col, r); - break; - } else if col == r { - join_cols_to_replace.insert(col, l); - break; - } - } - } + // Are there any new join predicates that can be inferred from the filter expressions? + let inferred_join_predicates = + infer_join_predicates(&join, &predicates, &on_filters)?; - if join_cols_to_replace.is_empty() { - return None; - } + if on_filters.is_empty() + && predicates.is_empty() + && inferred_join_predicates.is_empty() + { + return Ok(Transformed::no(LogicalPlan::Join(join))); + } - let join_side_predicate = - match replace_col(predicate.clone(), &join_cols_to_replace) { - Ok(p) => p, - Err(e) => { - return Some(Err(e)); - } - }; + push_down_all_join(predicates, inferred_join_predicates, join, on_filters) +} - Some(Ok(join_side_predicate)) - }) - .collect::>>()? - } else { - vec![] - }; +/// Extracts any equi-join join predicates from the given filter expressions. +/// +/// Parameters +/// * `join` the join in question +/// +/// * `predicates` the pushed down filter expression +/// +/// * `on_filters` filters from the join ON clause that have not already been +/// identified as join predicates +/// +fn infer_join_predicates( + join: &Join, + predicates: &[Expr], + on_filters: &[Expr], +) -> Result> { + // Only allow both side key is column. + let join_col_keys = join + .on + .iter() + .filter_map(|(l, r)| { + let left_col = l.try_as_col()?; + let right_col = r.try_as_col()?; + Some((left_col, right_col)) + }) + .collect::>(); - if on_filters.is_empty() && predicates.is_empty() && infer_predicates.is_empty() { - return Ok(None); - } - Ok(Some(push_down_all_join( + let join_type = join.join_type; + + let mut inferred_predicates = InferredPredicates::new(join_type); + + infer_join_predicates_from_predicates( + &join_col_keys, predicates, - infer_predicates, - plan, - &join.left, - &join.right, + &mut inferred_predicates, + )?; + + infer_join_predicates_from_on_filters( + &join_col_keys, + join_type, on_filters, - is_inner_join, - )?)) + &mut inferred_predicates, + )?; + + Ok(inferred_predicates.predicates) +} + +/// Inferred predicates collector. +/// When the JoinType is not Inner, we need to detect whether the inferred predicate can strictly +/// filter out NULL, otherwise ignore it. e.g. +/// ```text +/// SELECT * FROM t1 LEFT JOIN t2 ON t1.c0 = t2.c0 WHERE t2.c0 IS NULL; +/// ``` +/// We cannot infer the predicate `t1.c0 IS NULL`, otherwise the predicate will be pushed down to +/// the left side, resulting in the wrong result. +struct InferredPredicates { + predicates: Vec, + is_inner_join: bool, +} + +impl InferredPredicates { + fn new(join_type: JoinType) -> Self { + Self { + predicates: vec![], + is_inner_join: matches!(join_type, JoinType::Inner), + } + } + + fn try_build_predicate( + &mut self, + predicate: Expr, + replace_map: &HashMap<&Column, &Column>, + ) -> Result<()> { + if self.is_inner_join + || matches!( + is_restrict_null_predicate( + predicate.clone(), + replace_map.keys().cloned() + ), + Ok(true) + ) + { + self.predicates.push(replace_col(predicate, replace_map)?); + } + + Ok(()) + } +} + +/// Infer predicates from the pushed down predicates. +/// +/// Parameters +/// * `join_col_keys` column pairs from the join ON clause +/// +/// * `predicates` the pushed down predicates +/// +/// * `inferred_predicates` the inferred results +/// +fn infer_join_predicates_from_predicates( + join_col_keys: &[(&Column, &Column)], + predicates: &[Expr], + inferred_predicates: &mut InferredPredicates, +) -> Result<()> { + infer_join_predicates_impl::( + join_col_keys, + predicates, + inferred_predicates, + ) +} + +/// Infer predicates from the join filter. +/// +/// Parameters +/// * `join_col_keys` column pairs from the join ON clause +/// +/// * `join_type` the JoinType of Join +/// +/// * `on_filters` filters from the join ON clause that have not already been +/// identified as join predicates +/// +/// * `inferred_predicates` the inferred results +/// +fn infer_join_predicates_from_on_filters( + join_col_keys: &[(&Column, &Column)], + join_type: JoinType, + on_filters: &[Expr], + inferred_predicates: &mut InferredPredicates, +) -> Result<()> { + match join_type { + JoinType::Full | JoinType::LeftAnti | JoinType::RightAnti => Ok(()), + JoinType::Inner => infer_join_predicates_impl::( + join_col_keys, + on_filters, + inferred_predicates, + ), + JoinType::Left | JoinType::LeftSemi | JoinType::LeftMark => { + infer_join_predicates_impl::( + join_col_keys, + on_filters, + inferred_predicates, + ) + } + JoinType::Right | JoinType::RightSemi => { + infer_join_predicates_impl::( + join_col_keys, + on_filters, + inferred_predicates, + ) + } + } +} + +/// Infer predicates from the given predicates. +/// +/// Parameters +/// * `join_col_keys` column pairs from the join ON clause +/// +/// * `input_predicates` the given predicates. It can be the pushed down predicates, +/// or it can be the filters of the Join +/// +/// * `inferred_predicates` the inferred results +/// +/// * `ENABLE_LEFT_TO_RIGHT` indicates that the right table related predicate can +/// be inferred from the left table related predicate +/// +/// * `ENABLE_RIGHT_TO_LEFT` indicates that the left table related predicate can +/// be inferred from the right table related predicate +/// +fn infer_join_predicates_impl< + const ENABLE_LEFT_TO_RIGHT: bool, + const ENABLE_RIGHT_TO_LEFT: bool, +>( + join_col_keys: &[(&Column, &Column)], + input_predicates: &[Expr], + inferred_predicates: &mut InferredPredicates, +) -> Result<()> { + for predicate in input_predicates { + let mut join_cols_to_replace = HashMap::new(); + + for &col in &predicate.column_refs() { + for (l, r) in join_col_keys.iter() { + if ENABLE_LEFT_TO_RIGHT && col == *l { + join_cols_to_replace.insert(col, *r); + break; + } + if ENABLE_RIGHT_TO_LEFT && col == *r { + join_cols_to_replace.insert(col, *l); + break; + } + } + } + if join_cols_to_replace.is_empty() { + continue; + } + + inferred_predicates + .try_build_predicate(predicate.clone(), &join_cols_to_replace)?; + } + Ok(()) } impl OptimizerRule for PushDownFilter { @@ -619,52 +753,65 @@ impl OptimizerRule for PushDownFilter { Some(ApplyOrder::TopDown) } - fn try_optimize( + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( &self, - plan: &LogicalPlan, + plan: LogicalPlan, _config: &dyn OptimizerConfig, - ) -> Result> { - let filter = match plan { - LogicalPlan::Filter(filter) => filter, - // we also need to pushdown filter in Join. - LogicalPlan::Join(join) => return push_down_join(plan, join, None), - _ => return Ok(None), + ) -> Result> { + if let LogicalPlan::Join(join) = plan { + return push_down_join(join, None); + }; + + let plan_schema = Arc::clone(plan.schema()); + + let LogicalPlan::Filter(mut filter) = plan else { + return Ok(Transformed::no(plan)); }; - let child_plan = filter.input.as_ref(); - let new_plan = match child_plan { + match Arc::unwrap_or_clone(filter.input) { LogicalPlan::Filter(child_filter) => { - let parents_predicates = split_conjunction(&filter.predicate); - let set: HashSet<&&Expr> = parents_predicates.iter().collect(); + let parents_predicates = split_conjunction_owned(filter.predicate); + // remove duplicated filters + let child_predicates = split_conjunction_owned(child_filter.predicate); let new_predicates = parents_predicates - .iter() - .chain( - split_conjunction(&child_filter.predicate) - .iter() - .filter(|e| !set.contains(e)), - ) - .map(|e| (*e).clone()) + .into_iter() + .chain(child_predicates) + // use IndexSet to remove dupes while preserving predicate order + .collect::>() + .into_iter() .collect::>(); - let new_predicate = conjunction(new_predicates).ok_or_else(|| { - plan_datafusion_err!("at least one expression exists") - })?; + + let Some(new_predicate) = conjunction(new_predicates) else { + return plan_err!("at least one expression exists"); + }; let new_filter = LogicalPlan::Filter(Filter::try_new( new_predicate, - child_filter.input.clone(), + child_filter.input, )?); - self.try_optimize(&new_filter, _config)? - .unwrap_or(new_filter) + self.rewrite(new_filter, _config) } - LogicalPlan::Repartition(_) - | LogicalPlan::Distinct(_) - | LogicalPlan::Sort(_) => { - // commutable - let new_filter = plan.with_new_exprs( - plan.expressions(), - vec![child_plan.inputs()[0].clone()], - )?; - child_plan.with_new_exprs(child_plan.expressions(), vec![new_filter])? + LogicalPlan::Repartition(repartition) => { + let new_filter = + Filter::try_new(filter.predicate, Arc::clone(&repartition.input)) + .map(LogicalPlan::Filter)?; + insert_below(LogicalPlan::Repartition(repartition), new_filter) + } + LogicalPlan::Distinct(distinct) => { + let new_filter = + Filter::try_new(filter.predicate, Arc::clone(distinct.input())) + .map(LogicalPlan::Filter)?; + insert_below(LogicalPlan::Distinct(distinct), new_filter) + } + LogicalPlan::Sort(sort) => { + let new_filter = + Filter::try_new(filter.predicate, Arc::clone(&sort.input)) + .map(LogicalPlan::Filter)?; + insert_below(LogicalPlan::Sort(sort), new_filter) } LogicalPlan::SubqueryAlias(subquery_alias) => { let mut replace_map = HashMap::new(); @@ -678,75 +825,85 @@ impl OptimizerRule for PushDownFilter { Expr::Column(Column::new(qualifier.cloned(), field.name())), ); } - let new_predicate = - replace_cols_by_name(filter.predicate.clone(), &replace_map)?; + let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?; + let new_filter = LogicalPlan::Filter(Filter::try_new( new_predicate, - subquery_alias.input.clone(), + Arc::clone(&subquery_alias.input), )?); - child_plan.with_new_exprs(child_plan.expressions(), vec![new_filter])? + insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter) } LogicalPlan::Projection(projection) => { - // A projection is filter-commutable if it do not contain volatile predicates or contain volatile - // predicates that are not used in the filter. However, we should re-writes all predicate expressions. - // collect projection. - let (volatile_map, non_volatile_map): (HashMap<_, _>, HashMap<_, _>) = - projection - .schema - .iter() - .enumerate() - .map(|(i, (qualifier, field))| { - // strip alias, as they should not be part of filters - let expr = match &projection.expr[i] { - Expr::Alias(Alias { expr, .. }) => expr.as_ref().clone(), - expr => expr.clone(), - }; - - (qualified_name(qualifier, field.name()), expr) - }) - .partition(|(_, value)| value.is_volatile().unwrap_or(true)); - - let mut push_predicates = vec![]; - let mut keep_predicates = vec![]; - for expr in split_conjunction_owned(filter.predicate.clone()).into_iter() - { - if contain(&expr, &volatile_map) { - keep_predicates.push(expr); + let predicates = split_conjunction_owned(filter.predicate.clone()); + let (new_projection, keep_predicate) = + rewrite_projection(predicates, projection)?; + if new_projection.transformed { + match keep_predicate { + None => Ok(new_projection), + Some(keep_predicate) => new_projection.map_data(|child_plan| { + Filter::try_new(keep_predicate, Arc::new(child_plan)) + .map(LogicalPlan::Filter) + }), + } + } else { + filter.input = Arc::new(new_projection.data); + Ok(Transformed::no(LogicalPlan::Filter(filter))) + } + } + LogicalPlan::Unnest(mut unnest) => { + let predicates = split_conjunction_owned(filter.predicate.clone()); + let mut non_unnest_predicates = vec![]; + let mut unnest_predicates = vec![]; + for predicate in predicates { + // collect all the Expr::Column in predicate recursively + let mut accum: HashSet = HashSet::new(); + expr_to_columns(&predicate, &mut accum)?; + + if unnest.list_type_columns.iter().any(|(_, unnest_list)| { + accum.contains(&unnest_list.output_column) + }) { + unnest_predicates.push(predicate); } else { - push_predicates.push(expr); + non_unnest_predicates.push(predicate); } } - match conjunction(push_predicates) { - Some(expr) => { - // re-write all filters based on this projection - // E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1" - let new_filter = LogicalPlan::Filter(Filter::try_new( - replace_cols_by_name(expr, &non_volatile_map)?, - projection.input.clone(), - )?); - - match conjunction(keep_predicates) { - None => child_plan.with_new_exprs( - child_plan.expressions(), - vec![new_filter], - )?, - Some(keep_predicate) => { - let child_plan = child_plan.with_new_exprs( - child_plan.expressions(), - vec![new_filter], - )?; - LogicalPlan::Filter(Filter::try_new( - keep_predicate, - Arc::new(child_plan), - )?) - } - } - } - None => return Ok(None), + // Unnest predicates should not be pushed down. + // If no non-unnest predicates exist, early return + if non_unnest_predicates.is_empty() { + filter.input = Arc::new(LogicalPlan::Unnest(unnest)); + return Ok(Transformed::no(LogicalPlan::Filter(filter))); + } + + // Push down non-unnest filter predicate + // Unnest + // Unnest Input (Projection) + // -> rewritten to + // Unnest + // Filter + // Unnest Input (Projection) + + let unnest_input = std::mem::take(&mut unnest.input); + + let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new( + conjunction(non_unnest_predicates).unwrap(), // Safe to unwrap since non_unnest_predicates is not empty. + unnest_input, + )?); + + // Directly assign new filter plan as the new unnest's input. + // The new filter plan will go through another rewrite pass since the rule itself + // is applied recursively to all the child from top to down + let unnest_plan = + insert_below(LogicalPlan::Unnest(unnest), filter_with_unnest_input)?; + + match conjunction(unnest_predicates) { + None => Ok(unnest_plan), + Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter( + Filter::try_new(predicate, Arc::new(unnest_plan.data))?, + ))), } } - LogicalPlan::Union(union) => { + LogicalPlan::Union(ref union) => { let mut inputs = Vec::with_capacity(union.inputs.len()); for input in &union.inputs { let mut replace_map = HashMap::new(); @@ -763,28 +920,28 @@ impl OptimizerRule for PushDownFilter { replace_cols_by_name(filter.predicate.clone(), &replace_map)?; inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new( push_predicate, - input.clone(), + Arc::clone(input), )?))) } - LogicalPlan::Union(Union { + Ok(Transformed::yes(LogicalPlan::Union(Union { inputs, - schema: plan.schema().clone(), - }) + schema: Arc::clone(&plan_schema), + }))) } LogicalPlan::Aggregate(agg) => { // We can push down Predicate which in groupby_expr. let group_expr_columns = agg .group_expr .iter() - .map(|e| Ok(Column::from_qualified_name(e.display_name()?))) + .map(|e| Ok(Column::from_qualified_name(e.schema_name().to_string()))) .collect::>>()?; - let predicates = split_conjunction_owned(filter.predicate.clone()); + let predicates = split_conjunction_owned(filter.predicate); let mut keep_predicates = vec![]; let mut push_predicates = vec![]; for expr in predicates { - let cols = expr.to_columns()?; + let cols = expr.column_refs(); if cols.iter().all(|c| group_expr_columns.contains(c)) { push_predicates.push(expr); } else { @@ -797,55 +954,35 @@ impl OptimizerRule for PushDownFilter { // So we need create a replace_map, add {`a+b` --> Expr(Column(a)+Column(b))} let mut replace_map = HashMap::new(); for expr in &agg.group_expr { - replace_map.insert(expr.display_name()?, expr.clone()); + replace_map.insert(expr.schema_name().to_string(), expr.clone()); } let replaced_push_predicates = push_predicates - .iter() - .map(|expr| replace_cols_by_name(expr.clone(), &replace_map)) + .into_iter() + .map(|expr| replace_cols_by_name(expr, &replace_map)) .collect::>>()?; - let child = match conjunction(replaced_push_predicates) { - Some(predicate) => LogicalPlan::Filter(Filter::try_new( - predicate, - agg.input.clone(), - )?), - None => (*agg.input).clone(), - }; - let new_agg = filter - .input - .with_new_exprs(filter.input.expressions(), vec![child])?; - match conjunction(keep_predicates) { - Some(predicate) => LogicalPlan::Filter(Filter::try_new( - predicate, - Arc::new(new_agg), - )?), - None => new_agg, - } - } - LogicalPlan::Join(join) => { - match push_down_join(&filter.input, join, Some(&filter.predicate))? { - Some(optimized_plan) => optimized_plan, - None => return Ok(None), - } - } - LogicalPlan::CrossJoin(cross_join) => { - let predicates = split_conjunction_owned(filter.predicate.clone()); - let join = convert_cross_join_to_inner_join(cross_join.clone())?; - let join_plan = LogicalPlan::Join(join); - let inputs = join_plan.inputs(); - let left = inputs[0]; - let right = inputs[1]; - let plan = push_down_all_join( - predicates, - vec![], - &join_plan, - left, - right, - vec![], - true, - )?; - convert_to_cross_join_if_beneficial(plan)? + let agg_input = Arc::clone(&agg.input); + Transformed::yes(LogicalPlan::Aggregate(agg)) + .transform_data(|new_plan| { + // If we have a filter to push, we push it down to the input of the aggregate + if let Some(predicate) = conjunction(replaced_push_predicates) { + let new_filter = make_filter(predicate, agg_input)?; + insert_below(new_plan, new_filter) + } else { + Ok(Transformed::no(new_plan)) + } + })? + .map_data(|child_plan| { + // if there are any remaining predicates we can't push, add them + // back as a filter + if let Some(predicate) = conjunction(keep_predicates) { + make_filter(predicate, Arc::new(child_plan)) + } else { + Ok(child_plan) + } + }) } + LogicalPlan::Join(join) => push_down_join(join, Some(&filter.predicate)), LogicalPlan::TableScan(scan) => { let filter_predicates = split_conjunction(&filter.predicate); let results = scan @@ -858,12 +995,12 @@ impl OptimizerRule for PushDownFilter { filter_predicates.len()); } - let zip = filter_predicates.iter().zip(results); + let zip = filter_predicates.into_iter().zip(results); let new_scan_filters = zip .clone() .filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported) - .map(|(pred, _)| *pred); + .map(|(pred, _)| pred); let new_scan_filters: Vec = scan .filters .iter() @@ -873,37 +1010,55 @@ impl OptimizerRule for PushDownFilter { .collect(); let new_predicate: Vec = zip .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact) - .map(|(pred, _)| (*pred).clone()) + .map(|(pred, _)| pred.clone()) .collect(); let new_scan = LogicalPlan::TableScan(TableScan { - source: scan.source.clone(), - projection: scan.projection.clone(), - projected_schema: scan.projected_schema.clone(), - table_name: scan.table_name.clone(), filters: new_scan_filters, - fetch: scan.fetch, + ..scan }); - match conjunction(new_predicate) { - Some(predicate) => LogicalPlan::Filter(Filter::try_new( - predicate, - Arc::new(new_scan), - )?), - None => new_scan, - } + Transformed::yes(new_scan).transform_data(|new_scan| { + if let Some(predicate) = conjunction(new_predicate) { + make_filter(predicate, Arc::new(new_scan)).map(Transformed::yes) + } else { + Ok(Transformed::no(new_scan)) + } + }) } LogicalPlan::Extension(extension_plan) => { let prevent_cols = extension_plan.node.prevent_predicate_push_down_columns(); - let predicates = split_conjunction_owned(filter.predicate.clone()); + // determine if we can push any predicates down past the extension node + + // each element is true for push, false to keep + let predicate_push_or_keep = split_conjunction(&filter.predicate) + .iter() + .map(|expr| { + let cols = expr.column_refs(); + if cols.iter().any(|c| prevent_cols.contains(&c.name)) { + Ok(false) // No push (keep) + } else { + Ok(true) // push + } + }) + .collect::>>()?; + + // all predicates are kept, no changes needed + if predicate_push_or_keep.iter().all(|&x| !x) { + filter.input = Arc::new(LogicalPlan::Extension(extension_plan)); + return Ok(Transformed::no(LogicalPlan::Filter(filter))); + } + // going to push some predicates down, so split the predicates let mut keep_predicates = vec![]; let mut push_predicates = vec![]; - for expr in predicates { - let cols = expr.to_columns()?; - if cols.iter().any(|c| prevent_cols.contains(&c.name)) { + for (push, expr) in predicate_push_or_keep + .into_iter() + .zip(split_conjunction_owned(filter.predicate).into_iter()) + { + if !push { keep_predicates.push(expr); } else { push_predicates.push(expr); @@ -925,64 +1080,148 @@ impl OptimizerRule for PushDownFilter { None => extension_plan.node.inputs().into_iter().cloned().collect(), }; // extension with new inputs. + let child_plan = LogicalPlan::Extension(extension_plan); let new_extension = child_plan.with_new_exprs(child_plan.expressions(), new_children)?; - match conjunction(keep_predicates) { + let new_plan = match conjunction(keep_predicates) { Some(predicate) => LogicalPlan::Filter(Filter::try_new( predicate, Arc::new(new_extension), )?), None => new_extension, - } + }; + Ok(Transformed::yes(new_plan)) } - _ => return Ok(None), - }; - Ok(Some(new_plan)) + child => { + filter.input = Arc::new(child); + Ok(Transformed::no(LogicalPlan::Filter(filter))) + } + } } } -impl PushDownFilter { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} +/// Attempts to push `predicate` into a `FilterExec` below `projection +/// +/// # Returns +/// (plan, remaining_predicate) +/// +/// `plan` is a LogicalPlan for `projection` with possibly a new FilterExec below it. +/// `remaining_predicate` is any part of the predicate that could not be pushed down +/// +/// # Args +/// - predicates: Split predicates like `[foo=5, bar=6]` +/// - projection: The target projection plan to push down the predicates +/// +/// # Example +/// +/// Pushing a predicate like `foo=5 AND bar=6` with an input plan like this: +/// +/// ```text +/// Projection(foo, c+d as bar) +/// ``` +/// +/// Might result in returning `remaining_predicate` of `bar=6` and a plan like +/// +/// ```text +/// Projection(foo, c+d as bar) +/// Filter(foo=5) +/// ... +/// ``` +fn rewrite_projection( + predicates: Vec, + mut projection: Projection, +) -> Result<(Transformed, Option)> { + // A projection is filter-commutable if it do not contain volatile predicates or contain volatile + // predicates that are not used in the filter. However, we should re-writes all predicate expressions. + // collect projection. + let (volatile_map, non_volatile_map): (HashMap<_, _>, HashMap<_, _>) = projection + .schema + .iter() + .zip(projection.expr.iter()) + .map(|((qualifier, field), expr)| { + // strip alias, as they should not be part of filters + let expr = expr.clone().unalias(); + + (qualified_name(qualifier, field.name()), expr) + }) + .partition(|(_, value)| value.is_volatile()); + + let mut push_predicates = vec![]; + let mut keep_predicates = vec![]; + for expr in predicates { + if contain(&expr, &volatile_map) { + keep_predicates.push(expr); + } else { + push_predicates.push(expr); + } + } + + match conjunction(push_predicates) { + Some(expr) => { + // re-write all filters based on this projection + // E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1" + let new_filter = LogicalPlan::Filter(Filter::try_new( + replace_cols_by_name(expr, &non_volatile_map)?, + std::mem::take(&mut projection.input), + )?); + + projection.input = Arc::new(new_filter); + + Ok(( + Transformed::yes(LogicalPlan::Projection(projection)), + conjunction(keep_predicates), + )) + } + None => Ok((Transformed::no(LogicalPlan::Projection(projection)), None)), } } -/// Converts the given cross join to an inner join with an empty equality -/// predicate and an empty filter condition. -fn convert_cross_join_to_inner_join(cross_join: CrossJoin) -> Result { - let CrossJoin { left, right, .. } = cross_join; - let join_schema = build_join_schema(left.schema(), right.schema(), &JoinType::Inner)?; - Ok(Join { - left, - right, - join_type: JoinType::Inner, - join_constraint: JoinConstraint::On, - on: vec![], - filter: None, - schema: DFSchemaRef::new(join_schema), - null_equals_null: true, - }) +/// Creates a new LogicalPlan::Filter node. +pub fn make_filter(predicate: Expr, input: Arc) -> Result { + Filter::try_new(predicate, input).map(LogicalPlan::Filter) } -/// Converts the given inner join with an empty equality predicate and an -/// empty filter condition to a cross join. -fn convert_to_cross_join_if_beneficial(plan: LogicalPlan) -> Result { - if let LogicalPlan::Join(join) = &plan { - // Can be converted back to cross join - if join.on.is_empty() && join.filter.is_none() { - return LogicalPlanBuilder::from(join.left.as_ref().clone()) - .cross_join(join.right.as_ref().clone())? - .build(); +/// Replace the existing child of the single input node with `new_child`. +/// +/// Starting: +/// ```text +/// plan +/// child +/// ``` +/// +/// Ending: +/// ```text +/// plan +/// new_child +/// ``` +fn insert_below( + plan: LogicalPlan, + new_child: LogicalPlan, +) -> Result> { + let mut new_child = Some(new_child); + let transformed_plan = plan.map_children(|_child| { + if let Some(new_child) = new_child.take() { + Ok(Transformed::yes(new_child)) + } else { + // already took the new child + internal_err!("node had more than one input") } - } else if let LogicalPlan::Filter(filter) = &plan { - let new_input = - convert_to_cross_join_if_beneficial(filter.input.as_ref().clone())?; - return Filter::try_new(filter.predicate.clone(), Arc::new(new_input)) - .map(LogicalPlan::Filter); + })?; + + // make sure we did the actual replacement + if new_child.is_some() { + return internal_err!("node had no inputs"); + } + + Ok(transformed_plan) +} + +impl PushDownFilter { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} } - Ok(plan) } /// replaces columns by its name on the projection. @@ -1025,26 +1264,30 @@ fn contain(e: &Expr, check_map: &HashMap) -> bool { #[cfg(test)] mod tests { - use super::*; use std::any::Any; + use std::cmp::Ordering; use std::fmt::{Debug, Formatter}; - use crate::optimizer::Optimizer; - use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; - use crate::test::*; - use crate::OptimizerContext; - use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use datafusion_common::ScalarValue; + use async_trait::async_trait; + + use datafusion_common::{DFSchemaRef, ScalarValue}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{ - col, in_list, in_subquery, lit, sum, ColumnarValue, Extension, ScalarUDF, - ScalarUDFImpl, Signature, TableSource, TableType, UserDefinedLogicalNodeCore, - Volatility, + col, in_list, in_subquery, lit, ColumnarValue, Extension, LogicalPlanBuilder, + ScalarUDF, ScalarUDFImpl, Signature, TableSource, TableType, + UserDefinedLogicalNodeCore, Volatility, }; - use async_trait::async_trait; + use crate::optimizer::Optimizer; + use crate::simplify_expressions::SimplifyExpressions; + use crate::test::*; + use crate::OptimizerContext; + use datafusion_expr::test::function_stub::sum; + + use super::*; + fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { @@ -1060,13 +1303,13 @@ mod tests { expected: &str, ) -> Result<()> { let optimizer = Optimizer::with_rules(vec![ - Arc::new(RewriteDisjunctivePredicate::new()), + Arc::new(SimplifyExpressions::new()), Arc::new(PushDownFilter::new()), ]); let optimized_plan = optimizer.optimize(plan, &OptimizerContext::new(), observe)?; - let formatted_plan = format!("{optimized_plan:?}"); + let formatted_plan = format!("{optimized_plan}"); assert_eq!(expected, formatted_plan); Ok(()) } @@ -1137,7 +1380,7 @@ mod tests { .build()?; // filter of key aggregation is commutative let expected = "\ - Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS total_salary]]\ + Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS total_salary]]\ \n TableScan: test, full_filters=[test.a > Int64(10)]"; assert_optimized_plan_eq(plan, expected) } @@ -1150,7 +1393,7 @@ mod tests { .filter(col("b").gt(lit(10i64)))? .build()?; let expected = "Filter: test.b > Int64(10)\ - \n Aggregate: groupBy=[[test.b + test.a]], aggr=[[SUM(test.a), test.b]]\ + \n Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]\ \n TableScan: test"; assert_optimized_plan_eq(plan, expected) } @@ -1162,7 +1405,7 @@ mod tests { .filter(col("test.b + test.a").gt(lit(10i64)))? .build()?; let expected = - "Aggregate: groupBy=[[test.b + test.a]], aggr=[[SUM(test.a), test.b]]\ + "Aggregate: groupBy=[[test.b + test.a]], aggr=[[sum(test.a), test.b]]\ \n TableScan: test, full_filters=[test.b + test.a > Int64(10)]"; assert_optimized_plan_eq(plan, expected) } @@ -1177,7 +1420,7 @@ mod tests { // filter of aggregate is after aggregation since they are non-commutative let expected = "\ Filter: b > Int64(10)\ - \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b) AS b]]\ \n TableScan: test"; assert_optimized_plan_eq(plan, expected) } @@ -1227,7 +1470,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "\ Filter: b = Int64(1)\ \n Projection: test.a * Int32(2) + test.c AS b, test.c\ @@ -1257,7 +1500,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "\ Filter: a = Int64(1)\ \n Projection: b * Int32(3) AS a, test.c\ @@ -1279,6 +1522,13 @@ mod tests { schema: DFSchemaRef, } + // Manual implementation needed because of `schema` field. Comparison excludes this field. + impl PartialOrd for NoopPlan { + fn partial_cmp(&self, other: &Self) -> Option { + self.input.partial_cmp(&other.input) + } + } + impl UserDefinedLogicalNodeCore for NoopPlan { fn name(&self) -> &str { "NoopPlan" @@ -1307,11 +1557,19 @@ mod tests { write!(f, "NoopPlan") } - fn from_template(&self, _exprs: &[Expr], inputs: &[LogicalPlan]) -> Self { - Self { - input: inputs.to_vec(), - schema: self.schema.clone(), - } + fn with_exprs_and_inputs( + &self, + _exprs: Vec, + inputs: Vec, + ) -> Result { + Ok(Self { + input: inputs, + schema: Arc::clone(&self.schema), + }) + } + + fn supports_limit_pushdown(&self) -> bool { + false // Disallow limit push-down by default } } @@ -1322,7 +1580,7 @@ mod tests { let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoopPlan { input: vec![table_scan.clone()], - schema: table_scan.schema().clone(), + schema: Arc::clone(table_scan.schema()), }), }); let plan = LogicalPlanBuilder::from(custom_plan) @@ -1338,7 +1596,7 @@ mod tests { let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoopPlan { input: vec![table_scan.clone()], - schema: table_scan.schema().clone(), + schema: Arc::clone(table_scan.schema()), }), }); let plan = LogicalPlanBuilder::from(custom_plan) @@ -1355,7 +1613,7 @@ mod tests { let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoopPlan { input: vec![table_scan.clone(), table_scan.clone()], - schema: table_scan.schema().clone(), + schema: Arc::clone(table_scan.schema()), }), }); let plan = LogicalPlanBuilder::from(custom_plan) @@ -1372,7 +1630,7 @@ mod tests { let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoopPlan { input: vec![table_scan.clone(), table_scan.clone()], - schema: table_scan.schema().clone(), + schema: Arc::clone(table_scan.schema()), }), }); let plan = LogicalPlanBuilder::from(custom_plan) @@ -1392,30 +1650,30 @@ mod tests { /// and the other not. #[test] fn multi_filter() -> Result<()> { - // the aggregation allows one filter to pass (b), and the other one to not pass (SUM(c)) + // the aggregation allows one filter to pass (b), and the other one to not pass (sum(c)) let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) .project(vec![col("a").alias("b"), col("c")])? .aggregate(vec![col("b")], vec![sum(col("c"))])? .filter(col("b").gt(lit(10i64)))? - .filter(col("SUM(test.c)").gt(lit(10i64)))? + .filter(col("sum(test.c)").gt(lit(10i64)))? .build()?; // not part of the test, just good to know: assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "\ - Filter: SUM(test.c) > Int64(10)\ + Filter: sum(test.c) > Int64(10)\ \n Filter: b > Int64(10)\ - \n Aggregate: groupBy=[[b]], aggr=[[SUM(test.c)]]\ + \n Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]\ \n Projection: test.a AS b, test.c\ \n TableScan: test" ); // filter is before the projections let expected = "\ - Filter: SUM(test.c) > Int64(10)\ - \n Aggregate: groupBy=[[b]], aggr=[[SUM(test.c)]]\ + Filter: sum(test.c) > Int64(10)\ + \n Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]\ \n Projection: test.a AS b, test.c\ \n TableScan: test, full_filters=[test.a > Int64(10)]"; assert_optimized_plan_eq(plan, expected) @@ -1425,31 +1683,31 @@ mod tests { /// and the other not. #[test] fn split_filter() -> Result<()> { - // the aggregation allows one filter to pass (b), and the other one to not pass (SUM(c)) + // the aggregation allows one filter to pass (b), and the other one to not pass (sum(c)) let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) .project(vec![col("a").alias("b"), col("c")])? .aggregate(vec![col("b")], vec![sum(col("c"))])? .filter(and( - col("SUM(test.c)").gt(lit(10i64)), - and(col("b").gt(lit(10i64)), col("SUM(test.c)").lt(lit(20i64))), + col("sum(test.c)").gt(lit(10i64)), + and(col("b").gt(lit(10i64)), col("sum(test.c)").lt(lit(20i64))), ))? .build()?; // not part of the test, just good to know: assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "\ - Filter: SUM(test.c) > Int64(10) AND b > Int64(10) AND SUM(test.c) < Int64(20)\ - \n Aggregate: groupBy=[[b]], aggr=[[SUM(test.c)]]\ + Filter: sum(test.c) > Int64(10) AND b > Int64(10) AND sum(test.c) < Int64(20)\ + \n Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]\ \n Projection: test.a AS b, test.c\ \n TableScan: test" ); // filter is before the projections let expected = "\ - Filter: SUM(test.c) > Int64(10) AND SUM(test.c) < Int64(20)\ - \n Aggregate: groupBy=[[b]], aggr=[[SUM(test.c)]]\ + Filter: sum(test.c) > Int64(10) AND sum(test.c) < Int64(20)\ + \n Aggregate: groupBy=[[b]], aggr=[[sum(test.c)]]\ \n Projection: test.a AS b, test.c\ \n TableScan: test, full_filters=[test.a > Int64(10)]"; assert_optimized_plan_eq(plan, expected) @@ -1537,7 +1795,7 @@ mod tests { .build()?; let expected = "Projection: test.a, test1.d\ - \n CrossJoin:\ + \n Cross Join: \ \n Projection: test.a, test.b, test.c\ \n TableScan: test, full_filters=[test.a = Int32(1)]\ \n Projection: test1.d, test1.e, test1.f\ @@ -1564,7 +1822,7 @@ mod tests { .build()?; let expected = "Projection: test.a, test1.a\ - \n CrossJoin:\ + \n Cross Join: \ \n Projection: test.a, test.b, test.c\ \n TableScan: test, full_filters=[test.a = Int32(1)]\ \n Projection: test1.a, test1.b, test1.c\ @@ -1587,7 +1845,7 @@ mod tests { // not part of the test assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "Filter: test.a >= Int64(1)\ \n Projection: test.a\ \n Limit: skip=0, fetch=1\ @@ -1619,7 +1877,7 @@ mod tests { // not part of the test assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "Projection: test.a\ \n Filter: test.a >= Int64(1)\ \n Filter: test.a <= Int64(1)\ @@ -1653,7 +1911,7 @@ mod tests { \n TableScan: test"; // not part of the test - assert_eq!(format!("{plan:?}"), expected); + assert_eq!(format!("{plan}"), expected); let expected = "\ TestUserDefined\ @@ -1683,7 +1941,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "Filter: test.a <= Int64(1)\ \n Inner Join: test.a = test2.a\ \n TableScan: test\ @@ -1720,7 +1978,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "Filter: test.a <= Int64(1)\ \n Inner Join: Using test.a = test2.a\ \n TableScan: test\ @@ -1737,7 +1995,7 @@ mod tests { assert_optimized_plan_eq(plan, expected) } - /// post-join predicates with columns from both sides are converted to join filterss + /// post-join predicates with columns from both sides are converted to join filters #[test] fn filter_join_on_common_dependent() -> Result<()> { let table_scan = test_table_scan()?; @@ -1760,7 +2018,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "Filter: test.c <= test2.b\ \n Inner Join: test.a = test2.a\ \n Projection: test.a, test.c\ @@ -1803,7 +2061,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "Filter: test.b <= Int64(1)\ \n Inner Join: test.a = test2.a\ \n Projection: test.a, test.b\ @@ -1842,7 +2100,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "Filter: test2.a <= Int64(1)\ \n Left Join: Using test.a = test2.a\ \n TableScan: test\ @@ -1854,7 +2112,7 @@ mod tests { let expected = "\ Filter: test2.a <= Int64(1)\ \n Left Join: Using test.a = test2.a\ - \n TableScan: test\ + \n TableScan: test, full_filters=[test.a <= Int64(1)]\ \n Projection: test2.a\ \n TableScan: test2"; assert_optimized_plan_eq(plan, expected) @@ -1880,7 +2138,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "Filter: test.a <= Int64(1)\ \n Right Join: Using test.a = test2.a\ \n TableScan: test\ @@ -1894,7 +2152,7 @@ mod tests { \n Right Join: Using test.a = test2.a\ \n TableScan: test\ \n Projection: test2.a\ - \n TableScan: test2"; + \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; assert_optimized_plan_eq(plan, expected) } @@ -1919,7 +2177,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "Filter: test.a <= Int64(1)\ \n Left Join: Using test.a = test2.a\ \n TableScan: test\ @@ -1957,7 +2215,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "Filter: test2.a <= Int64(1)\ \n Right Join: Using test.a = test2.a\ \n TableScan: test\ @@ -2000,7 +2258,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "Inner Join: test.a = test2.a Filter: test.c > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\ \n Projection: test.a, test.b, test.c\ \n TableScan: test\ @@ -2042,7 +2300,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "Inner Join: test.a = test2.a Filter: test.b > UInt32(1) AND test2.c > UInt32(4)\ \n Projection: test.a, test.b, test.c\ \n TableScan: test\ @@ -2082,7 +2340,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "Inner Join: test.a = test2.b Filter: test.a > UInt32(1)\ \n Projection: test.a\ \n TableScan: test\ @@ -2125,7 +2383,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "Left Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\ \n Projection: test.a, test.b, test.c\ \n TableScan: test\ @@ -2168,7 +2426,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "Right Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\ \n Projection: test.a, test.b, test.c\ \n TableScan: test\ @@ -2211,7 +2469,7 @@ mod tests { // not part of the test, just good to know: assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "Full Join: test.a = test2.a Filter: test.a > UInt32(1) AND test.b < test2.b AND test2.c > UInt32(4)\ \n Projection: test.a, test.b, test.c\ \n TableScan: test\ @@ -2219,7 +2477,7 @@ mod tests { \n TableScan: test2" ); - let expected = &format!("{plan:?}"); + let expected = &format!("{plan}"); assert_optimized_plan_eq(plan, expected) } @@ -2240,14 +2498,16 @@ mod tests { TableType::Base } - fn supports_filter_pushdown( + fn supports_filters_pushdown( &self, - _e: &Expr, - ) -> Result { - Ok(self.filter_support.clone()) + filters: &[&Expr], + ) -> Result> { + Ok((0..filters.len()) + .map(|_| self.filter_support.clone()) + .collect()) } - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } } @@ -2299,9 +2559,9 @@ mod tests { table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?; let optimized_plan = PushDownFilter::new() - .try_optimize(&plan, &OptimizerContext::new()) + .rewrite(plan, &OptimizerContext::new()) .expect("failed to optimize plan") - .unwrap(); + .data; let expected = "\ Filter: a = Int64(1)\ @@ -2396,7 +2656,7 @@ Projection: a, b // filter on col b assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "Filter: b > Int64(10) AND test.c > Int64(10)\ \n Projection: test.a AS b, test.c\ \n TableScan: test" @@ -2425,7 +2685,7 @@ Projection: a, b // filter on col b assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "Filter: b > Int64(10) AND test.c > Int64(10)\ \n Projection: b, test.c\ \n Projection: test.a AS b, test.c\ @@ -2453,7 +2713,7 @@ Projection: a, b // filter on col b and d assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "Filter: b > Int64(10) AND d > Int64(10)\ \n Projection: test.a AS b, test.c AS d\ \n TableScan: test\ @@ -2490,7 +2750,7 @@ Projection: a, b .build()?; assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "Inner Join: c = d Filter: c > UInt32(1)\ \n Projection: test.a AS c\ \n TableScan: test\ @@ -2522,7 +2782,7 @@ Projection: a, b // filter on col b assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\ \n Projection: test.a AS b, test.c\ \n TableScan: test\ @@ -2552,7 +2812,7 @@ Projection: a, b // filter on col b assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "Filter: b IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])\ \n Projection: b, test.c\ \n Projection: test.a AS b, test.c\ @@ -2593,7 +2853,7 @@ Projection: a, b \n TableScan: sq\ \n Projection: test.a AS b, test.c\ \n TableScan: test"; - assert_eq!(format!("{plan:?}"), expected_before); + assert_eq!(format!("{plan}"), expected_before); // rewrite filter col b to test.a let expected_after = "\ @@ -2624,7 +2884,7 @@ Projection: a, b \n SubqueryAlias: b\ \n Projection: Int64(0) AS a\ \n EmptyRelation"; - assert_eq!(format!("{plan:?}"), expected_before); + assert_eq!(format!("{plan}"), expected_before); // Ensure that the predicate without any columns (0 = 1) is // still there. @@ -2668,11 +2928,52 @@ Projection: a, b // Originally global state which can help to avoid duplicate Filters been generated and pushed down. // Now the global state is removed. Need to double confirm that avoid duplicate Filters. let optimized_plan = PushDownFilter::new() - .try_optimize(&plan, &OptimizerContext::new())? - .expect("failed to optimize plan"); + .rewrite(plan, &OptimizerContext::new()) + .expect("failed to optimize plan") + .data; assert_optimized_plan_eq(optimized_plan, expected) } + #[test] + fn left_semi_join() -> Result<()> { + let left = test_table_scan_with_name("test1")?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) + .project(vec![col("a"), col("b")])? + .build()?; + let plan = LogicalPlanBuilder::from(left) + .join( + right, + JoinType::LeftSemi, + ( + vec![Column::from_qualified_name("test1.a")], + vec![Column::from_qualified_name("test2.a")], + ), + None, + )? + .filter(col("test2.a").lt_eq(lit(1i64)))? + .build()?; + + // not part of the test, just good to know: + assert_eq!( + format!("{plan}"), + "Filter: test2.a <= Int64(1)\ + \n LeftSemi Join: test1.a = test2.a\ + \n TableScan: test1\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2" + ); + + // Inferred the predicate `test1.a <= Int64(1)` and push it down to the left side. + let expected = "\ + Filter: test2.a <= Int64(1)\ + \n LeftSemi Join: test1.a = test2.a\ + \n TableScan: test1, full_filters=[test1.a <= Int64(1)]\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2"; + assert_optimized_plan_eq(plan, expected) + } + #[test] fn left_semi_join_with_filters() -> Result<()> { let left = test_table_scan_with_name("test1")?; @@ -2698,7 +2999,7 @@ Projection: a, b // not part of the test, just good to know: assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "LeftSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)\ \n TableScan: test1\ \n Projection: test2.a, test2.b\ @@ -2714,6 +3015,46 @@ Projection: a, b assert_optimized_plan_eq(plan, expected) } + #[test] + fn right_semi_join() -> Result<()> { + let left = test_table_scan_with_name("test1")?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) + .project(vec![col("a"), col("b")])? + .build()?; + let plan = LogicalPlanBuilder::from(left) + .join( + right, + JoinType::RightSemi, + ( + vec![Column::from_qualified_name("test1.a")], + vec![Column::from_qualified_name("test2.a")], + ), + None, + )? + .filter(col("test1.a").lt_eq(lit(1i64)))? + .build()?; + + // not part of the test, just good to know: + assert_eq!( + format!("{plan}"), + "Filter: test1.a <= Int64(1)\ + \n RightSemi Join: test1.a = test2.a\ + \n TableScan: test1\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2", + ); + + // Inferred the predicate `test2.a <= Int64(1)` and push it down to the right side. + let expected = "\ + Filter: test1.a <= Int64(1)\ + \n RightSemi Join: test1.a = test2.a\ + \n TableScan: test1\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; + assert_optimized_plan_eq(plan, expected) + } + #[test] fn right_semi_join_with_filters() -> Result<()> { let left = test_table_scan_with_name("test1")?; @@ -2739,7 +3080,7 @@ Projection: a, b // not part of the test, just good to know: assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "RightSemi Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)\ \n TableScan: test1\ \n Projection: test2.a, test2.b\ @@ -2755,6 +3096,51 @@ Projection: a, b assert_optimized_plan_eq(plan, expected) } + #[test] + fn left_anti_join() -> Result<()> { + let table_scan = test_table_scan_with_name("test1")?; + let left = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("b")])? + .build()?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) + .project(vec![col("a"), col("b")])? + .build()?; + let plan = LogicalPlanBuilder::from(left) + .join( + right, + JoinType::LeftAnti, + ( + vec![Column::from_qualified_name("test1.a")], + vec![Column::from_qualified_name("test2.a")], + ), + None, + )? + .filter(col("test2.a").gt(lit(2u32)))? + .build()?; + + // not part of the test, just good to know: + assert_eq!( + format!("{plan}"), + "Filter: test2.a > UInt32(2)\ + \n LeftAnti Join: test1.a = test2.a\ + \n Projection: test1.a, test1.b\ + \n TableScan: test1\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2", + ); + + // For left anti, filter of the right side filter can be pushed down. + let expected = "\ + Filter: test2.a > UInt32(2)\ + \n LeftAnti Join: test1.a = test2.a\ + \n Projection: test1.a, test1.b\ + \n TableScan: test1, full_filters=[test1.a > UInt32(2)]\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2"; + assert_optimized_plan_eq(plan, expected) + } + #[test] fn left_anti_join_with_filters() -> Result<()> { let table_scan = test_table_scan_with_name("test1")?; @@ -2783,7 +3169,7 @@ Projection: a, b // not part of the test, just good to know: assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "LeftAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)\ \n Projection: test1.a, test1.b\ \n TableScan: test1\ @@ -2801,6 +3187,51 @@ Projection: a, b assert_optimized_plan_eq(plan, expected) } + #[test] + fn right_anti_join() -> Result<()> { + let table_scan = test_table_scan_with_name("test1")?; + let left = LogicalPlanBuilder::from(table_scan) + .project(vec![col("a"), col("b")])? + .build()?; + let right_table_scan = test_table_scan_with_name("test2")?; + let right = LogicalPlanBuilder::from(right_table_scan) + .project(vec![col("a"), col("b")])? + .build()?; + let plan = LogicalPlanBuilder::from(left) + .join( + right, + JoinType::RightAnti, + ( + vec![Column::from_qualified_name("test1.a")], + vec![Column::from_qualified_name("test2.a")], + ), + None, + )? + .filter(col("test1.a").gt(lit(2u32)))? + .build()?; + + // not part of the test, just good to know: + assert_eq!( + format!("{plan}"), + "Filter: test1.a > UInt32(2)\ + \n RightAnti Join: test1.a = test2.a\ + \n Projection: test1.a, test1.b\ + \n TableScan: test1\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2", + ); + + // For right anti, filter of the left side can be pushed down. + let expected = "\ + Filter: test1.a > UInt32(2)\ + \n RightAnti Join: test1.a = test2.a\ + \n Projection: test1.a, test1.b\ + \n TableScan: test1\ + \n Projection: test2.a, test2.b\ + \n TableScan: test2, full_filters=[test2.a > UInt32(2)]"; + assert_optimized_plan_eq(plan, expected) + } + #[test] fn right_anti_join_with_filters() -> Result<()> { let table_scan = test_table_scan_with_name("test1")?; @@ -2829,7 +3260,7 @@ Projection: a, b // not part of the test, just good to know: assert_eq!( - format!("{plan:?}"), + format!("{plan}"), "RightAnti Join: test1.a = test2.a Filter: test1.b > UInt32(1) AND test2.b > UInt32(2)\ \n Projection: test1.a, test1.b\ \n TableScan: test1\ @@ -2874,7 +3305,7 @@ Projection: a, b #[test] fn test_push_down_volatile_function_in_aggregate() -> Result<()> { - // SELECT t.a, t.r FROM (SELECT a, SUM(b), TestScalarUDF()+1 AS r FROM test1 GROUP BY a) AS t WHERE t.a > 5 AND t.r > 0.5; + // SELECT t.a, t.r FROM (SELECT a, sum(b), TestScalarUDF()+1 AS r FROM test1 GROUP BY a) AS t WHERE t.a > 5 AND t.r > 0.5; let table_scan = test_table_scan_with_name("test1")?; let fun = ScalarUDF::new_from_impl(TestScalarUDF { signature: Signature::exact(vec![], Volatility::Volatile), @@ -2892,16 +3323,16 @@ Projection: a, b let expected_before = "Projection: t.a, t.r\ \n Filter: t.a > Int32(5) AND t.r > Float64(0.5)\ \n SubqueryAlias: t\ - \n Projection: test1.a, SUM(test1.b), TestScalarUDF() + Int32(1) AS r\ - \n Aggregate: groupBy=[[test1.a]], aggr=[[SUM(test1.b)]]\ + \n Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r\ + \n Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]\ \n TableScan: test1"; - assert_eq!(format!("{plan:?}"), expected_before); + assert_eq!(format!("{plan}"), expected_before); let expected_after = "Projection: t.a, t.r\ \n SubqueryAlias: t\ \n Filter: r > Float64(0.5)\ - \n Projection: test1.a, SUM(test1.b), TestScalarUDF() + Int32(1) AS r\ - \n Aggregate: groupBy=[[test1.a]], aggr=[[SUM(test1.b)]]\ + \n Projection: test1.a, sum(test1.b), TestScalarUDF() + Int32(1) AS r\ + \n Aggregate: groupBy=[[test1.a]], aggr=[[sum(test1.b)]]\ \n TableScan: test1, full_filters=[test1.a > Int32(5)]"; assert_optimized_plan_eq(plan, expected_after) } @@ -2940,7 +3371,7 @@ Projection: a, b \n Inner Join: test1.a = test2.a\ \n TableScan: test1\ \n TableScan: test2"; - assert_eq!(format!("{plan:?}"), expected_before); + assert_eq!(format!("{plan}"), expected_before); let expected = "Projection: t.a, t.r\ \n SubqueryAlias: t\ diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 1af246fc556d..8a3aa4bb8459 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -17,21 +17,22 @@ //! [`PushDownLimit`] pushes `LIMIT` earlier in the query plan +use std::cmp::min; use std::sync::Arc; use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_common::tree_node::Transformed; +use datafusion_common::utils::combine_limit; use datafusion_common::Result; -use datafusion_expr::logical_plan::{ - Join, JoinType, Limit, LogicalPlan, Sort, TableScan, Union, -}; -use datafusion_expr::CrossJoin; +use datafusion_expr::logical_plan::{Join, JoinType, Limit, LogicalPlan}; +use datafusion_expr::{lit, FetchType, SkipType}; /// Optimization rule that tries to push down `LIMIT`. /// //. It will push down through projection, limits (taking the smaller limit) -#[derive(Default)] +#[derive(Default, Debug)] pub struct PushDownLimit {} impl PushDownLimit { @@ -43,168 +44,140 @@ impl PushDownLimit { /// Push down Limit. impl OptimizerRule for PushDownLimit { - fn try_optimize( + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( &self, - plan: &LogicalPlan, + plan: LogicalPlan, _config: &dyn OptimizerConfig, - ) -> Result> { - use std::cmp::min; + ) -> Result> { + let LogicalPlan::Limit(mut limit) = plan else { + return Ok(Transformed::no(plan)); + }; - let LogicalPlan::Limit(limit) = plan else { - return Ok(None); + // Currently only rewrite if skip and fetch are both literals + let SkipType::Literal(skip) = limit.get_skip_type()? else { + return Ok(Transformed::no(LogicalPlan::Limit(limit))); + }; + let FetchType::Literal(fetch) = limit.get_fetch_type()? else { + return Ok(Transformed::no(LogicalPlan::Limit(limit))); }; - if let LogicalPlan::Limit(child) = &*limit.input { - // Merge the Parent Limit and the Child Limit. - - // Case 0: Parent and Child are disjoint. (child_fetch <= skip) - // Before merging: - // |........skip........|---fetch-->| Parent Limit - // |...child_skip...|---child_fetch-->| Child Limit - // After merging: - // |.........(child_skip + skip).........| - // Before merging: - // |...skip...|------------fetch------------>| Parent Limit - // |...child_skip...|-------------child_fetch------------>| Child Limit - // After merging: - // |....(child_skip + skip)....|---(child_fetch - skip)-->| - - // Case 1: Parent is beyond the range of Child. (skip < child_fetch <= skip + fetch) - // Before merging: - // |...skip...|------------fetch------------>| Parent Limit - // |...child_skip...|-------------child_fetch------------>| Child Limit - // After merging: - // |....(child_skip + skip)....|---(child_fetch - skip)-->| - - // Case 2: Parent is in the range of Child. (skip + fetch < child_fetch) - // Before merging: - // |...skip...|---fetch-->| Parent Limit - // |...child_skip...|-------------child_fetch------------>| Child Limit - // After merging: - // |....(child_skip + skip)....|---fetch-->| - let parent_skip = limit.skip; - let new_fetch = match (limit.fetch, child.fetch) { - (Some(fetch), Some(child_fetch)) => { - Some(min(fetch, child_fetch.saturating_sub(parent_skip))) - } - (Some(fetch), None) => Some(fetch), - (None, Some(child_fetch)) => { - Some(child_fetch.saturating_sub(parent_skip)) - } - (None, None) => None, + // Merge the Parent Limit and the Child Limit. + if let LogicalPlan::Limit(child) = limit.input.as_ref() { + let SkipType::Literal(child_skip) = child.get_skip_type()? else { + return Ok(Transformed::no(LogicalPlan::Limit(limit))); + }; + let FetchType::Literal(child_fetch) = child.get_fetch_type()? else { + return Ok(Transformed::no(LogicalPlan::Limit(limit))); }; + let (skip, fetch) = combine_limit(skip, fetch, child_skip, child_fetch); let plan = LogicalPlan::Limit(Limit { - skip: child.skip + parent_skip, - fetch: new_fetch, - input: Arc::new((*child.input).clone()), + skip: Some(Box::new(lit(skip as i64))), + fetch: fetch.map(|f| Box::new(lit(f as i64))), + input: Arc::clone(&child.input), }); - return self - .try_optimize(&plan, _config) - .map(|opt_plan| opt_plan.or_else(|| Some(plan))); + + // recursively reapply the rule on the new plan + return self.rewrite(plan, _config); } - let Some(fetch) = limit.fetch else { - return Ok(None); + // no fetch to push, so return the original plan + let Some(fetch) = fetch else { + return Ok(Transformed::no(LogicalPlan::Limit(limit))); }; - let skip = limit.skip; - match limit.input.as_ref() { - LogicalPlan::TableScan(scan) => { - let limit = if fetch != 0 { fetch + skip } else { 0 }; - let new_fetch = scan.fetch.map(|x| min(x, limit)).or(Some(limit)); + match Arc::unwrap_or_clone(limit.input) { + LogicalPlan::TableScan(mut scan) => { + let rows_needed = if fetch != 0 { fetch + skip } else { 0 }; + let new_fetch = scan + .fetch + .map(|x| min(x, rows_needed)) + .or(Some(rows_needed)); if new_fetch == scan.fetch { - Ok(None) + original_limit(skip, fetch, LogicalPlan::TableScan(scan)) } else { - let new_input = LogicalPlan::TableScan(TableScan { - table_name: scan.table_name.clone(), - source: scan.source.clone(), - projection: scan.projection.clone(), - filters: scan.filters.clone(), - fetch: scan.fetch.map(|x| min(x, limit)).or(Some(limit)), - projected_schema: scan.projected_schema.clone(), - }); - plan.with_new_exprs(plan.expressions(), vec![new_input]) - .map(Some) + // push limit into the table scan itself + scan.fetch = scan + .fetch + .map(|x| min(x, rows_needed)) + .or(Some(rows_needed)); + transformed_limit(skip, fetch, LogicalPlan::TableScan(scan)) } } - LogicalPlan::Union(union) => { - let new_inputs = union + LogicalPlan::Union(mut union) => { + // push limits to each input of the union + union.inputs = union .inputs - .iter() - .map(|x| { - Ok(Arc::new(LogicalPlan::Limit(Limit { - skip: 0, - fetch: Some(fetch + skip), - input: x.clone(), - }))) - }) - .collect::>()?; - let union = LogicalPlan::Union(Union { - inputs: new_inputs, - schema: union.schema.clone(), - }); - plan.with_new_exprs(plan.expressions(), vec![union]) - .map(Some) + .into_iter() + .map(|input| make_arc_limit(0, fetch + skip, input)) + .collect(); + transformed_limit(skip, fetch, LogicalPlan::Union(union)) } - LogicalPlan::CrossJoin(cross_join) => { - let new_left = LogicalPlan::Limit(Limit { - skip: 0, - fetch: Some(fetch + skip), - input: cross_join.left.clone(), - }); - let new_right = LogicalPlan::Limit(Limit { - skip: 0, - fetch: Some(fetch + skip), - input: cross_join.right.clone(), - }); - let new_cross_join = LogicalPlan::CrossJoin(CrossJoin { - left: Arc::new(new_left), - right: Arc::new(new_right), - schema: plan.schema().clone(), - }); - plan.with_new_exprs(plan.expressions(), vec![new_cross_join]) - .map(Some) - } - - LogicalPlan::Join(join) => { - if let Some(new_join) = push_down_join(join, fetch + skip) { - let inputs = vec![LogicalPlan::Join(new_join)]; - plan.with_new_exprs(plan.expressions(), inputs).map(Some) - } else { - Ok(None) - } - } + LogicalPlan::Join(join) => Ok(push_down_join(join, fetch + skip) + .update_data(|join| { + make_limit(skip, fetch, Arc::new(LogicalPlan::Join(join))) + })), - LogicalPlan::Sort(sort) => { + LogicalPlan::Sort(mut sort) => { let new_fetch = { let sort_fetch = skip + fetch; Some(sort.fetch.map(|f| f.min(sort_fetch)).unwrap_or(sort_fetch)) }; if new_fetch == sort.fetch { - Ok(None) + if skip > 0 { + original_limit(skip, fetch, LogicalPlan::Sort(sort)) + } else { + Ok(Transformed::yes(LogicalPlan::Sort(sort))) + } } else { - let new_sort = LogicalPlan::Sort(Sort { - expr: sort.expr.clone(), - input: sort.input.clone(), - fetch: new_fetch, - }); - plan.with_new_exprs(plan.expressions(), vec![new_sort]) - .map(Some) + sort.fetch = new_fetch; + limit.input = Arc::new(LogicalPlan::Sort(sort)); + Ok(Transformed::yes(LogicalPlan::Limit(limit))) } } - child_plan @ (LogicalPlan::Projection(_) | LogicalPlan::SubqueryAlias(_)) => { + LogicalPlan::Projection(mut proj) => { + // commute + limit.input = Arc::clone(&proj.input); + let new_limit = LogicalPlan::Limit(limit); + proj.input = Arc::new(new_limit); + Ok(Transformed::yes(LogicalPlan::Projection(proj))) + } + LogicalPlan::SubqueryAlias(mut subquery_alias) => { // commute - let new_limit = plan.with_new_exprs( - plan.expressions(), - vec![child_plan.inputs()[0].clone()], - )?; - child_plan - .with_new_exprs(child_plan.expressions(), vec![new_limit]) - .map(Some) + limit.input = Arc::clone(&subquery_alias.input); + let new_limit = LogicalPlan::Limit(limit); + subquery_alias.input = Arc::new(new_limit); + Ok(Transformed::yes(LogicalPlan::SubqueryAlias(subquery_alias))) } - _ => Ok(None), + LogicalPlan::Extension(extension_plan) + if extension_plan.node.supports_limit_pushdown() => + { + let new_children = extension_plan + .node + .inputs() + .into_iter() + .map(|child| { + LogicalPlan::Limit(Limit { + skip: None, + fetch: Some(Box::new(lit((fetch + skip) as i64))), + input: Arc::new(child.clone()), + }) + }) + .collect::>(); + + // Create a new extension node with updated inputs + let child_plan = LogicalPlan::Extension(extension_plan); + let new_extension = + child_plan.with_new_exprs(child_plan.expressions(), new_children)?; + + transformed_limit(skip, fetch, new_extension) + } + input => original_limit(skip, fetch, input), } } @@ -217,74 +190,325 @@ impl OptimizerRule for PushDownLimit { } } -fn push_down_join(join: &Join, limit: usize) -> Option { +/// Wrap the input plan with a limit node +/// +/// Original: +/// ```text +/// input +/// ``` +/// +/// Return +/// ```text +/// Limit: skip=skip, fetch=fetch +/// input +/// ``` +fn make_limit(skip: usize, fetch: usize, input: Arc) -> LogicalPlan { + LogicalPlan::Limit(Limit { + skip: Some(Box::new(lit(skip as i64))), + fetch: Some(Box::new(lit(fetch as i64))), + input, + }) +} + +/// Wrap the input plan with a limit node +fn make_arc_limit( + skip: usize, + fetch: usize, + input: Arc, +) -> Arc { + Arc::new(make_limit(skip, fetch, input)) +} + +/// Returns the original limit (non transformed) +fn original_limit( + skip: usize, + fetch: usize, + input: LogicalPlan, +) -> Result> { + Ok(Transformed::no(make_limit(skip, fetch, Arc::new(input)))) +} + +/// Returns the a transformed limit +fn transformed_limit( + skip: usize, + fetch: usize, + input: LogicalPlan, +) -> Result> { + Ok(Transformed::yes(make_limit(skip, fetch, Arc::new(input)))) +} + +/// Adds a limit to the inputs of a join, if possible +fn push_down_join(mut join: Join, limit: usize) -> Transformed { use JoinType::*; fn is_no_join_condition(join: &Join) -> bool { join.on.is_empty() && join.filter.is_none() } - let (left_limit, right_limit) = if is_no_join_condition(join) { + let (left_limit, right_limit) = if is_no_join_condition(&join) { match join.join_type { - Left | Right | Full => (Some(limit), Some(limit)), - LeftAnti | LeftSemi => (Some(limit), None), + Left | Right | Full | Inner => (Some(limit), Some(limit)), + LeftAnti | LeftSemi | LeftMark => (Some(limit), None), RightAnti | RightSemi => (None, Some(limit)), - Inner => (None, None), } } else { match join.join_type { Left => (Some(limit), None), Right => (None, Some(limit)), + Full => (Some(limit), Some(limit)), _ => (None, None), } }; - match (left_limit, right_limit) { - (None, None) => None, - _ => { - let left = match left_limit { - Some(limit) => Arc::new(LogicalPlan::Limit(Limit { - skip: 0, - fetch: Some(limit), - input: join.left.clone(), - })), - None => join.left.clone(), - }; - let right = match right_limit { - Some(limit) => Arc::new(LogicalPlan::Limit(Limit { - skip: 0, - fetch: Some(limit), - input: join.right.clone(), - })), - None => join.right.clone(), - }; - Some(Join { - left, - right, - on: join.on.clone(), - filter: join.filter.clone(), - join_type: join.join_type, - join_constraint: join.join_constraint, - schema: join.schema.clone(), - null_equals_null: join.null_equals_null, - }) - } + if left_limit.is_none() && right_limit.is_none() { + return Transformed::no(join); + } + if let Some(limit) = left_limit { + join.left = make_arc_limit(0, limit, join.left); } + if let Some(limit) = right_limit { + join.right = make_arc_limit(0, limit, join.right); + } + Transformed::yes(join) } #[cfg(test)] mod test { + use std::cmp::Ordering; + use std::fmt::{Debug, Formatter}; use std::vec; use super::*; use crate::test::*; - use datafusion_expr::{col, exists, logical_plan::builder::LogicalPlanBuilder, max}; + use datafusion_common::DFSchemaRef; + use datafusion_expr::{ + col, exists, logical_plan::builder::LogicalPlanBuilder, Expr, Extension, + UserDefinedLogicalNodeCore, + }; + use datafusion_functions_aggregate::expr_fn::max; fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(PushDownLimit::new()), plan, expected) } + #[derive(Debug, PartialEq, Eq, Hash)] + pub struct NoopPlan { + input: Vec, + schema: DFSchemaRef, + } + + // Manual implementation needed because of `schema` field. Comparison excludes this field. + impl PartialOrd for NoopPlan { + fn partial_cmp(&self, other: &Self) -> Option { + self.input.partial_cmp(&other.input) + } + } + + impl UserDefinedLogicalNodeCore for NoopPlan { + fn name(&self) -> &str { + "NoopPlan" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + self.input.iter().collect() + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn expressions(&self) -> Vec { + self.input + .iter() + .flat_map(|child| child.expressions()) + .collect() + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "NoopPlan") + } + + fn with_exprs_and_inputs( + &self, + _exprs: Vec, + inputs: Vec, + ) -> Result { + Ok(Self { + input: inputs, + schema: Arc::clone(&self.schema), + }) + } + + fn supports_limit_pushdown(&self) -> bool { + true // Allow limit push-down + } + } + + #[derive(Debug, PartialEq, Eq, Hash)] + struct NoLimitNoopPlan { + input: Vec, + schema: DFSchemaRef, + } + + // Manual implementation needed because of `schema` field. Comparison excludes this field. + impl PartialOrd for NoLimitNoopPlan { + fn partial_cmp(&self, other: &Self) -> Option { + self.input.partial_cmp(&other.input) + } + } + + impl UserDefinedLogicalNodeCore for NoLimitNoopPlan { + fn name(&self) -> &str { + "NoLimitNoopPlan" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + self.input.iter().collect() + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn expressions(&self) -> Vec { + self.input + .iter() + .flat_map(|child| child.expressions()) + .collect() + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "NoLimitNoopPlan") + } + + fn with_exprs_and_inputs( + &self, + _exprs: Vec, + inputs: Vec, + ) -> Result { + Ok(Self { + input: inputs, + schema: Arc::clone(&self.schema), + }) + } + + fn supports_limit_pushdown(&self) -> bool { + false // Disallow limit push-down by default + } + } + #[test] + fn limit_pushdown_basic() -> Result<()> { + let table_scan = test_table_scan()?; + let noop_plan = LogicalPlan::Extension(Extension { + node: Arc::new(NoopPlan { + input: vec![table_scan.clone()], + schema: Arc::clone(table_scan.schema()), + }), + }); + + let plan = LogicalPlanBuilder::from(noop_plan) + .limit(0, Some(1000))? + .build()?; + + let expected = "Limit: skip=0, fetch=1000\ + \n NoopPlan\ + \n Limit: skip=0, fetch=1000\ + \n TableScan: test, fetch=1000"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn limit_pushdown_with_skip() -> Result<()> { + let table_scan = test_table_scan()?; + let noop_plan = LogicalPlan::Extension(Extension { + node: Arc::new(NoopPlan { + input: vec![table_scan.clone()], + schema: Arc::clone(table_scan.schema()), + }), + }); + + let plan = LogicalPlanBuilder::from(noop_plan) + .limit(10, Some(1000))? + .build()?; + + let expected = "Limit: skip=10, fetch=1000\ + \n NoopPlan\ + \n Limit: skip=0, fetch=1010\ + \n TableScan: test, fetch=1010"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn limit_pushdown_multiple_limits() -> Result<()> { + let table_scan = test_table_scan()?; + let noop_plan = LogicalPlan::Extension(Extension { + node: Arc::new(NoopPlan { + input: vec![table_scan.clone()], + schema: Arc::clone(table_scan.schema()), + }), + }); + + let plan = LogicalPlanBuilder::from(noop_plan) + .limit(10, Some(1000))? + .limit(20, Some(500))? + .build()?; + + let expected = "Limit: skip=30, fetch=500\ + \n NoopPlan\ + \n Limit: skip=0, fetch=530\ + \n TableScan: test, fetch=530"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn limit_pushdown_multiple_inputs() -> Result<()> { + let table_scan = test_table_scan()?; + let noop_plan = LogicalPlan::Extension(Extension { + node: Arc::new(NoopPlan { + input: vec![table_scan.clone(), table_scan.clone()], + schema: Arc::clone(table_scan.schema()), + }), + }); + + let plan = LogicalPlanBuilder::from(noop_plan) + .limit(0, Some(1000))? + .build()?; + + let expected = "Limit: skip=0, fetch=1000\ + \n NoopPlan\ + \n Limit: skip=0, fetch=1000\ + \n TableScan: test, fetch=1000\ + \n Limit: skip=0, fetch=1000\ + \n TableScan: test, fetch=1000"; + + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn limit_pushdown_disallowed_noop_plan() -> Result<()> { + let table_scan = test_table_scan()?; + let no_limit_noop_plan = LogicalPlan::Extension(Extension { + node: Arc::new(NoLimitNoopPlan { + input: vec![table_scan.clone()], + schema: Arc::clone(table_scan.schema()), + }), + }); + + let plan = LogicalPlanBuilder::from(no_limit_noop_plan) + .limit(0, Some(1000))? + .build()?; + + let expected = "Limit: skip=0, fetch=1000\ + \n NoLimitNoopPlan\ + \n TableScan: test"; + + assert_optimized_plan_equal(plan, expected) + } + #[test] fn limit_pushdown_projection_table_provider() -> Result<()> { let table_scan = test_table_scan()?; @@ -332,7 +556,7 @@ mod test { // Limit should *not* push down aggregate node let expected = "Limit: skip=0, fetch=1000\ - \n Aggregate: groupBy=[[test.a]], aggr=[[MAX(test.b)]]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]]\ \n TableScan: test"; assert_optimized_plan_equal(plan, expected) @@ -363,13 +587,13 @@ mod test { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .sort(vec![col("a")])? + .sort_by(vec![col("a")])? .limit(0, Some(10))? .build()?; // Should push down limit to sort let expected = "Limit: skip=0, fetch=10\ - \n Sort: test.a, fetch=10\ + \n Sort: test.a ASC NULLS LAST, fetch=10\ \n TableScan: test"; assert_optimized_plan_equal(plan, expected) @@ -380,13 +604,13 @@ mod test { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .sort(vec![col("a")])? + .sort_by(vec![col("a")])? .limit(5, Some(10))? .build()?; // Should push down limit to sort let expected = "Limit: skip=5, fetch=10\ - \n Sort: test.a, fetch=15\ + \n Sort: test.a ASC NULLS LAST, fetch=15\ \n TableScan: test"; assert_optimized_plan_equal(plan, expected) @@ -404,7 +628,7 @@ mod test { // Limit should use deeper LIMIT 1000, but Limit 10 shouldn't push down aggregation let expected = "Limit: skip=0, fetch=10\ - \n Aggregate: groupBy=[[test.a]], aggr=[[MAX(test.b)]]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]]\ \n Limit: skip=0, fetch=1000\ \n TableScan: test, fetch=1000"; @@ -505,7 +729,7 @@ mod test { // Limit should *not* push down aggregate node let expected = "Limit: skip=10, fetch=1000\ - \n Aggregate: groupBy=[[test.a]], aggr=[[MAX(test.b)]]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[max(test.b)]]\ \n TableScan: test"; assert_optimized_plan_equal(plan, expected) @@ -884,7 +1108,7 @@ mod test { .build()?; let expected = "Limit: skip=0, fetch=1000\ - \n CrossJoin:\ + \n Cross Join: \ \n Limit: skip=0, fetch=1000\ \n TableScan: test, fetch=1000\ \n Limit: skip=0, fetch=1000\ @@ -904,7 +1128,7 @@ mod test { .build()?; let expected = "Limit: skip=1000, fetch=1000\ - \n CrossJoin:\ + \n Cross Join: \ \n Limit: skip=0, fetch=2000\ \n TableScan: test, fetch=2000\ \n Limit: skip=0, fetch=2000\ diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index 4f68e2623f40..f3e1673e7211 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -16,15 +16,16 @@ // under the License. //! [`ReplaceDistinctWithAggregate`] replaces `DISTINCT ...` with `GROUP BY ...` + use crate::optimizer::{ApplyOrder, ApplyOrder::BottomUp}; use crate::{OptimizerConfig, OptimizerRule}; +use std::sync::Arc; +use datafusion_common::tree_node::Transformed; use datafusion_common::{Column, Result}; +use datafusion_expr::expr_rewriter::normalize_cols; use datafusion_expr::utils::expand_wildcard; -use datafusion_expr::{ - aggregate_function::AggregateFunction as AggregateFunctionFunc, col, - expr::AggregateFunction, LogicalPlanBuilder, -}; +use datafusion_expr::{col, ExprFunctionExt, LogicalPlanBuilder}; use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan}; /// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]] @@ -55,7 +56,7 @@ use datafusion_expr::{Aggregate, Distinct, DistinctOn, Expr, LogicalPlan}; /// ``` /// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]] -#[derive(Default)] +#[derive(Default, Debug)] pub struct ReplaceDistinctWithAggregate {} impl ReplaceDistinctWithAggregate { @@ -66,20 +67,40 @@ impl ReplaceDistinctWithAggregate { } impl OptimizerRule for ReplaceDistinctWithAggregate { - fn try_optimize( + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( &self, - plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { match plan { LogicalPlan::Distinct(Distinct::All(input)) => { - let group_expr = expand_wildcard(input.schema(), input, None)?; - let aggregate = LogicalPlan::Aggregate(Aggregate::try_new( - input.clone(), + let group_expr = expand_wildcard(input.schema(), &input, None)?; + + let field_count = input.schema().fields().len(); + for dep in input.schema().functional_dependencies().iter() { + // If distinct is exactly the same with a previous GROUP BY, we can + // simply remove it: + if dep.source_indices.len() >= field_count + && dep.source_indices[..field_count] + .iter() + .enumerate() + .all(|(idx, f_idx)| idx == *f_idx) + { + return Ok(Transformed::yes(input.as_ref().clone())); + } + } + + // Replace with aggregation: + let aggr_plan = LogicalPlan::Aggregate(Aggregate::try_new( + input, group_expr, vec![], )?); - Ok(Some(aggregate)) + Ok(Transformed::yes(aggr_plan)) } LogicalPlan::Distinct(Distinct::On(DistinctOn { select_expr, @@ -88,57 +109,69 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { input, schema, })) => { + let expr_cnt = on_expr.len(); + // Construct the aggregation expression to be used to fetch the selected expressions. - let aggr_expr = select_expr - .iter() - .map(|e| { - Expr::AggregateFunction(AggregateFunction::new( - AggregateFunctionFunc::FirstValue, - vec![e.clone()], - false, - None, - sort_expr.clone(), - None, - )) - }) - .collect::>(); + let first_value_udaf: Arc = + config.function_registry().unwrap().udaf("first_value")?; + let aggr_expr = select_expr.into_iter().map(|e| { + if let Some(order_by) = &sort_expr { + first_value_udaf + .call(vec![e]) + .order_by(order_by.clone()) + .build() + // guaranteed to be `Expr::AggregateFunction` + .unwrap() + } else { + first_value_udaf.call(vec![e]) + } + }); + + let aggr_expr = normalize_cols(aggr_expr, input.as_ref())?; + let group_expr = normalize_cols(on_expr, input.as_ref())?; // Build the aggregation plan - let plan = LogicalPlanBuilder::from(input.as_ref().clone()) - .aggregate(on_expr.clone(), aggr_expr.to_vec())? - .build()?; + let plan = LogicalPlan::Aggregate(Aggregate::try_new( + input, group_expr, aggr_expr, + )?); + // TODO use LogicalPlanBuilder directly rather than recreating the Aggregate + // when https://github.com/apache/datafusion/issues/10485 is available + let lpb = LogicalPlanBuilder::from(plan); - let plan = if let Some(sort_expr) = sort_expr { + let plan = if let Some(mut sort_expr) = sort_expr { // While sort expressions were used in the `FIRST_VALUE` aggregation itself above, // this on it's own isn't enough to guarantee the proper output order of the grouping // (`ON`) expression, so we need to sort those as well. - LogicalPlanBuilder::from(plan) - .sort(sort_expr[..on_expr.len()].to_vec())? - .build()? + + // truncate the sort_expr to the length of on_expr + sort_expr.truncate(expr_cnt); + + lpb.sort(sort_expr)?.build()? } else { - plan + lpb.build()? }; // Whereas the aggregation plan by default outputs both the grouping and the aggregation // expressions, for `DISTINCT ON` we only need to emit the original selection expressions. + let project_exprs = plan .schema() .iter() - .skip(on_expr.len()) + .skip(expr_cnt) .zip(schema.iter()) .map(|((new_qualifier, new_field), (old_qualifier, old_field))| { - Ok(col(Column::from((new_qualifier, new_field))) - .alias_qualified(old_qualifier.cloned(), old_field.name())) + col(Column::from((new_qualifier, new_field))) + .alias_qualified(old_qualifier.cloned(), old_field.name()) }) - .collect::>>()?; + .collect::>(); let plan = LogicalPlanBuilder::from(plan) .project(project_exprs)? .build()?; - Ok(Some(plan)) + Ok(Transformed::yes(plan)) } - _ => Ok(None), + _ => Ok(Transformed::no(plan)), } } @@ -153,50 +186,75 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { #[cfg(test)] mod tests { - use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; - use crate::test::{assert_optimized_plan_eq, test_table_scan}; - use datafusion_expr::{col, LogicalPlanBuilder}; use std::sync::Arc; + use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate; + use crate::test::*; + + use datafusion_common::Result; + use datafusion_expr::{ + col, logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan, + }; + use datafusion_functions_aggregate::sum::sum; + + fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + assert_optimized_plan_eq( + Arc::new(ReplaceDistinctWithAggregate::new()), + plan.clone(), + expected, + ) + } + #[test] - fn replace_distinct() -> datafusion_common::Result<()> { + fn eliminate_redundant_distinct_simple() -> Result<()> { let table_scan = test_table_scan().unwrap(); let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("a"), col("b")])? + .aggregate(vec![col("c")], Vec::::new())? + .project(vec![col("c")])? .distinct()? .build()?; - let expected = "Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\ - \n Projection: test.a, test.b\ - \n TableScan: test"; + let expected = "Projection: test.c\n Aggregate: groupBy=[[test.c]], aggr=[[]]\n TableScan: test"; + assert_optimized_plan_equal(&plan, expected) + } - assert_optimized_plan_eq( - Arc::new(ReplaceDistinctWithAggregate::new()), - plan, - expected, - ) + #[test] + fn eliminate_redundant_distinct_pair() -> Result<()> { + let table_scan = test_table_scan().unwrap(); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("a"), col("b")], Vec::::new())? + .project(vec![col("a"), col("b")])? + .distinct()? + .build()?; + + let expected = + "Projection: test.a, test.b\n Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n TableScan: test"; + assert_optimized_plan_equal(&plan, expected) } #[test] - fn replace_distinct_on() -> datafusion_common::Result<()> { + fn do_not_eliminate_distinct() -> Result<()> { let table_scan = test_table_scan().unwrap(); let plan = LogicalPlanBuilder::from(table_scan) - .distinct_on( - vec![col("a")], - vec![col("b")], - Some(vec![col("a").sort(false, true), col("c").sort(true, false)]), - )? + .project(vec![col("a"), col("b")])? + .distinct()? .build()?; - let expected = "Projection: FIRST_VALUE(test.b) ORDER BY [test.a DESC NULLS FIRST, test.c ASC NULLS LAST] AS b\ - \n Sort: test.a DESC NULLS FIRST\ - \n Aggregate: groupBy=[[test.a]], aggr=[[FIRST_VALUE(test.b) ORDER BY [test.a DESC NULLS FIRST, test.c ASC NULLS LAST]]]\ - \n TableScan: test"; + let expected = "Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n Projection: test.a, test.b\n TableScan: test"; + assert_optimized_plan_equal(&plan, expected) + } - assert_optimized_plan_eq( - Arc::new(ReplaceDistinctWithAggregate::new()), - plan, - expected, - ) + #[test] + fn do_not_eliminate_distinct_with_aggr() -> Result<()> { + let table_scan = test_table_scan().unwrap(); + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("a"), col("b"), col("c")], vec![sum(col("c"))])? + .project(vec![col("a"), col("b")])? + .distinct()? + .build()?; + + let expected = + "Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\n Projection: test.a, test.b\n Aggregate: groupBy=[[test.a, test.b, test.c]], aggr=[[sum(test.c)]]\n TableScan: test"; + assert_optimized_plan_equal(&plan, expected) } } diff --git a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs deleted file mode 100644 index 059b1452ff3d..000000000000 --- a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs +++ /dev/null @@ -1,429 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! [`RewriteDisjunctivePredicate`] rewrites predicates to reduce redundancy - -use crate::optimizer::ApplyOrder; -use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::Result; -use datafusion_expr::expr::BinaryExpr; -use datafusion_expr::logical_plan::Filter; -use datafusion_expr::{Expr, LogicalPlan, Operator}; - -/// Optimizer pass that rewrites predicates of the form -/// -/// ```text -/// (A = B AND ) OR (A = B AND ) OR ... (A = B AND ) -/// ``` -/// -/// Into -/// ```text -/// (A = B) AND ( OR OR ... ) -/// ``` -/// -/// Predicates connected by `OR` typically not able to be broken down -/// and distributed as well as those connected by `AND`. -/// -/// The idea is to rewrite predicates into `good_predicate1 AND -/// good_predicate2 AND ...` where `good_predicate` means the -/// predicate has special support in the execution engine. -/// -/// Equality join predicates (e.g. `col1 = col2`), or single column -/// expressions (e.g. `col = 5`) are examples of predicates with -/// special support. -/// -/// # TPCH Q19 -/// -/// This optimization is admittedly somewhat of a niche usecase. It's -/// main use is that it appears in TPCH Q19 and is required to avoid a -/// CROSS JOIN. -/// -/// Specifically, Q19 has a WHERE clause that looks like -/// -/// ```sql -/// where -/// p_partkey = l_partkey -/// and l_shipmode in (‘AIR’, ‘AIR REG’) -/// and l_shipinstruct = ‘DELIVER IN PERSON’ -/// and ( -/// ( -/// and p_brand = ‘[BRAND1]’ -/// and p_container in ( ‘SM CASE’, ‘SM BOX’, ‘SM PACK’, ‘SM PKG’) -/// and l_quantity >= [QUANTITY1] and l_quantity <= [QUANTITY1] + 10 -/// and p_size between 1 and 5 -/// ) -/// or -/// ( -/// and p_brand = ‘[BRAND2]’ -/// and p_container in (‘MED BAG’, ‘MED BOX’, ‘MED PKG’, ‘MED PACK’) -/// and l_quantity >= [QUANTITY2] and l_quantity <= [QUANTITY2] + 10 -/// and p_size between 1 and 10 -/// ) -/// or -/// ( -/// and p_brand = ‘[BRAND3]’ -/// and p_container in ( ‘LG CASE’, ‘LG BOX’, ‘LG PACK’, ‘LG PKG’) -/// and l_quantity >= [QUANTITY3] and l_quantity <= [QUANTITY3] + 10 -/// and p_size between 1 and 15 -/// ) -/// ) -/// ``` -/// -/// Naively planning this query will result in a CROSS join with that -/// single large OR filter. However, rewriting it using the rewrite in -/// this pass results in a proper join predicate, `p_partkey = l_partkey`: -/// -/// ```sql -/// where -/// p_partkey = l_partkey -/// and l_shipmode in (‘AIR’, ‘AIR REG’) -/// and l_shipinstruct = ‘DELIVER IN PERSON’ -/// and ( -/// ( -/// and p_brand = ‘[BRAND1]’ -/// and p_container in ( ‘SM CASE’, ‘SM BOX’, ‘SM PACK’, ‘SM PKG’) -/// and l_quantity >= [QUANTITY1] and l_quantity <= [QUANTITY1] + 10 -/// and p_size between 1 and 5 -/// ) -/// or -/// ( -/// and p_brand = ‘[BRAND2]’ -/// and p_container in (‘MED BAG’, ‘MED BOX’, ‘MED PKG’, ‘MED PACK’) -/// and l_quantity >= [QUANTITY2] and l_quantity <= [QUANTITY2] + 10 -/// and p_size between 1 and 10 -/// ) -/// or -/// ( -/// and p_brand = ‘[BRAND3]’ -/// and p_container in ( ‘LG CASE’, ‘LG BOX’, ‘LG PACK’, ‘LG PKG’) -/// and l_quantity >= [QUANTITY3] and l_quantity <= [QUANTITY3] + 10 -/// and p_size between 1 and 15 -/// ) -/// ) -/// ``` -/// -#[derive(Default)] -pub struct RewriteDisjunctivePredicate; - -impl RewriteDisjunctivePredicate { - pub fn new() -> Self { - Self - } -} - -impl OptimizerRule for RewriteDisjunctivePredicate { - fn try_optimize( - &self, - plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - match plan { - LogicalPlan::Filter(filter) => { - let predicate = predicate(&filter.predicate)?; - let rewritten_predicate = rewrite_predicate(predicate); - let rewritten_expr = normalize_predicate(rewritten_predicate); - Ok(Some(LogicalPlan::Filter(Filter::try_new( - rewritten_expr, - filter.input.clone(), - )?))) - } - _ => Ok(None), - } - } - - fn name(&self) -> &str { - "rewrite_disjunctive_predicate" - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } -} - -#[derive(Clone, PartialEq, Debug)] -enum Predicate { - And { args: Vec }, - Or { args: Vec }, - Other { expr: Box }, -} - -fn predicate(expr: &Expr) -> Result { - match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op { - Operator::And => { - let args = vec![predicate(left)?, predicate(right)?]; - Ok(Predicate::And { args }) - } - Operator::Or => { - let args = vec![predicate(left)?, predicate(right)?]; - Ok(Predicate::Or { args }) - } - _ => Ok(Predicate::Other { - expr: Box::new(Expr::BinaryExpr(BinaryExpr::new( - left.clone(), - *op, - right.clone(), - ))), - }), - }, - _ => Ok(Predicate::Other { - expr: Box::new(expr.clone()), - }), - } -} - -fn normalize_predicate(predicate: Predicate) -> Expr { - match predicate { - Predicate::And { args } => { - assert!(args.len() >= 2); - args.into_iter() - .map(normalize_predicate) - .reduce(Expr::and) - .expect("had more than one arg") - } - Predicate::Or { args } => { - assert!(args.len() >= 2); - args.into_iter() - .map(normalize_predicate) - .reduce(Expr::or) - .expect("had more than one arg") - } - Predicate::Other { expr } => *expr, - } -} - -fn rewrite_predicate(predicate: Predicate) -> Predicate { - match predicate { - Predicate::And { args } => { - let mut rewritten_args = Vec::with_capacity(args.len()); - for arg in args.iter() { - rewritten_args.push(rewrite_predicate(arg.clone())); - } - rewritten_args = flatten_and_predicates(rewritten_args); - Predicate::And { - args: rewritten_args, - } - } - Predicate::Or { args } => { - let mut rewritten_args = vec![]; - for arg in args.iter() { - rewritten_args.push(rewrite_predicate(arg.clone())); - } - rewritten_args = flatten_or_predicates(rewritten_args); - delete_duplicate_predicates(&rewritten_args) - } - Predicate::Other { expr } => Predicate::Other { - expr: Box::new(*expr), - }, - } -} - -fn flatten_and_predicates( - and_predicates: impl IntoIterator, -) -> Vec { - let mut flattened_predicates = vec![]; - for predicate in and_predicates { - match predicate { - Predicate::And { args } => { - flattened_predicates - .extend_from_slice(flatten_and_predicates(args).as_slice()); - } - _ => { - flattened_predicates.push(predicate); - } - } - } - flattened_predicates -} - -fn flatten_or_predicates( - or_predicates: impl IntoIterator, -) -> Vec { - let mut flattened_predicates = vec![]; - for predicate in or_predicates { - match predicate { - Predicate::Or { args } => { - flattened_predicates - .extend_from_slice(flatten_or_predicates(args).as_slice()); - } - _ => { - flattened_predicates.push(predicate); - } - } - } - flattened_predicates -} - -fn delete_duplicate_predicates(or_predicates: &[Predicate]) -> Predicate { - let mut shortest_exprs: Vec = vec![]; - let mut shortest_exprs_len = 0; - // choose the shortest AND predicate - for or_predicate in or_predicates.iter() { - match or_predicate { - Predicate::And { args } => { - let args_num = args.len(); - if shortest_exprs.is_empty() || args_num < shortest_exprs_len { - shortest_exprs = (*args).clone(); - shortest_exprs_len = args_num; - } - } - _ => { - // if there is no AND predicate, it must be the shortest expression. - shortest_exprs = vec![or_predicate.clone()]; - break; - } - } - } - - // dedup shortest_exprs - shortest_exprs.dedup(); - - // Check each element in shortest_exprs to see if it's in all the OR arguments. - let mut exist_exprs: Vec = vec![]; - for expr in shortest_exprs.iter() { - let found = or_predicates.iter().all(|or_predicate| match or_predicate { - Predicate::And { args } => args.contains(expr), - _ => or_predicate == expr, - }); - if found { - exist_exprs.push((*expr).clone()); - } - } - if exist_exprs.is_empty() { - return Predicate::Or { - args: or_predicates.to_vec(), - }; - } - - // Rebuild the OR predicate. - // (A AND B) OR A will be optimized to A. - let mut new_or_predicates = vec![]; - for or_predicate in or_predicates.iter() { - match or_predicate { - Predicate::And { args } => { - let mut new_args = (*args).clone(); - new_args.retain(|expr| !exist_exprs.contains(expr)); - if !new_args.is_empty() { - if new_args.len() == 1 { - new_or_predicates.push(new_args[0].clone()); - } else { - new_or_predicates.push(Predicate::And { args: new_args }); - } - } else { - new_or_predicates.clear(); - break; - } - } - _ => { - if exist_exprs.contains(or_predicate) { - new_or_predicates.clear(); - break; - } - } - } - } - if !new_or_predicates.is_empty() { - if new_or_predicates.len() == 1 { - exist_exprs.push(new_or_predicates[0].clone()); - } else { - exist_exprs.push(Predicate::Or { - args: flatten_or_predicates(new_or_predicates), - }); - } - } - - if exist_exprs.len() == 1 { - exist_exprs[0].clone() - } else { - Predicate::And { - args: flatten_and_predicates(exist_exprs), - } - } -} - -#[cfg(test)] -mod tests { - use crate::rewrite_disjunctive_predicate::{ - normalize_predicate, predicate, rewrite_predicate, Predicate, - }; - - use datafusion_common::{Result, ScalarValue}; - use datafusion_expr::{and, col, lit, or}; - - #[test] - fn test_rewrite_predicate() -> Result<()> { - let equi_expr = col("t1.a").eq(col("t2.b")); - let gt_expr = col("t1.c").gt(lit(ScalarValue::Int8(Some(1)))); - let lt_expr = col("t1.d").lt(lit(ScalarValue::Int8(Some(2)))); - let expr = or( - and(equi_expr.clone(), gt_expr.clone()), - and(equi_expr.clone(), lt_expr.clone()), - ); - let predicate = predicate(&expr)?; - assert_eq!( - predicate, - Predicate::Or { - args: vec![ - Predicate::And { - args: vec![ - Predicate::Other { - expr: Box::new(equi_expr.clone()) - }, - Predicate::Other { - expr: Box::new(gt_expr.clone()) - }, - ] - }, - Predicate::And { - args: vec![ - Predicate::Other { - expr: Box::new(equi_expr.clone()) - }, - Predicate::Other { - expr: Box::new(lt_expr.clone()) - }, - ] - }, - ] - } - ); - let rewritten_predicate = rewrite_predicate(predicate); - assert_eq!( - rewritten_predicate, - Predicate::And { - args: vec![ - Predicate::Other { - expr: Box::new(equi_expr.clone()) - }, - Predicate::Or { - args: vec![ - Predicate::Other { - expr: Box::new(gt_expr.clone()) - }, - Predicate::Other { - expr: Box::new(lt_expr.clone()) - }, - ] - }, - ] - } - ); - let rewritten_expr = normalize_predicate(rewritten_predicate); - assert_eq!(rewritten_expr, and(equi_expr, or(gt_expr, lt_expr))); - Ok(()) - } -} diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 3ee6af415e08..2e2c8fb1d6f8 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -29,14 +29,14 @@ use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; -use datafusion_common::{plan_err, Column, Result, ScalarValue}; +use datafusion_common::{internal_err, plan_err, Column, Result, ScalarValue}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; use datafusion_expr::utils::conjunction; use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; /// Optimizer rule for rewriting subquery filters to joins -#[derive(Default)] +#[derive(Default, Debug)] pub struct ScalarSubqueryToJoin {} impl ScalarSubqueryToJoin { @@ -50,11 +50,11 @@ impl ScalarSubqueryToJoin { /// # Arguments /// * `predicate` - A conjunction to split and search /// - /// Returns a tuple (subqueries, rewrite expression) + /// Returns a tuple (subqueries, alias) fn extract_subquery_exprs( &self, predicate: &Expr, - alias_gen: Arc, + alias_gen: &Arc, ) -> Result<(Vec<(Subquery, String)>, Expr)> { let mut extract = ExtractScalarSubQuery { sub_query_info: vec![], @@ -69,21 +69,30 @@ impl ScalarSubqueryToJoin { } impl OptimizerRule for ScalarSubqueryToJoin { - fn try_optimize( + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( &self, - plan: &LogicalPlan, + plan: LogicalPlan, config: &dyn OptimizerConfig, - ) -> Result> { + ) -> Result> { match plan { LogicalPlan::Filter(filter) => { + // Optimization: skip the rest of the rule and its copies if + // there are no scalar subqueries + if !contains_scalar_subquery(&filter.predicate) { + return Ok(Transformed::no(LogicalPlan::Filter(filter))); + } + let (subqueries, mut rewrite_expr) = self.extract_subquery_exprs( &filter.predicate, config.alias_generator(), )?; if subqueries.is_empty() { - // regular filter, no subquery exists clause here - return Ok(None); + return internal_err!("Expected subqueries not found in filter"); } // iterate through all subqueries in predicate, turning each into a left join @@ -94,16 +103,13 @@ impl OptimizerRule for ScalarSubqueryToJoin { { if !expr_check_map.is_empty() { rewrite_expr = rewrite_expr - .clone() .transform_up(|expr| { - if let Expr::Column(col) = &expr { - if let Some(map_expr) = - expr_check_map.get(&col.name) - { - Ok(Transformed::yes(map_expr.clone())) - } else { - Ok(Transformed::no(expr)) - } + // replace column references with entry in map, if it exists + if let Some(map_expr) = expr + .try_as_col() + .and_then(|col| expr_check_map.get(&col.name)) + { + Ok(Transformed::yes(map_expr.clone())) } else { Ok(Transformed::no(expr)) } @@ -113,15 +119,21 @@ impl OptimizerRule for ScalarSubqueryToJoin { cur_input = optimized_subquery; } else { // if we can't handle all of the subqueries then bail for now - return Ok(None); + return Ok(Transformed::no(LogicalPlan::Filter(filter))); } } let new_plan = LogicalPlanBuilder::from(cur_input) .filter(rewrite_expr)? .build()?; - Ok(Some(new_plan)) + Ok(Transformed::yes(new_plan)) } LogicalPlan::Projection(projection) => { + // Optimization: skip the rest of the rule and its copies if + // there are no scalar subqueries + if !projection.expr.iter().any(contains_scalar_subquery) { + return Ok(Transformed::no(LogicalPlan::Projection(projection))); + } + let mut all_subqueryies = vec![]; let mut expr_to_rewrite_expr_map = HashMap::new(); let mut subquery_to_expr_map = HashMap::new(); @@ -135,8 +147,7 @@ impl OptimizerRule for ScalarSubqueryToJoin { expr_to_rewrite_expr_map.insert(expr, rewrite_exprs); } if all_subqueryies.is_empty() { - // regular projection, no subquery exists clause here - return Ok(None); + return internal_err!("Expected subqueries not found in projection"); } // iterate through all subqueries in predicate, turning each into a left join let mut cur_input = projection.input.as_ref().clone(); @@ -153,14 +164,13 @@ impl OptimizerRule for ScalarSubqueryToJoin { let new_expr = rewrite_expr .clone() .transform_up(|expr| { - if let Expr::Column(col) = &expr { - if let Some(map_expr) = + // replace column references with entry in map, if it exists + if let Some(map_expr) = + expr.try_as_col().and_then(|col| { expr_check_map.get(&col.name) - { - Ok(Transformed::yes(map_expr.clone())) - } else { - Ok(Transformed::no(expr)) - } + }) + { + Ok(Transformed::yes(map_expr.clone())) } else { Ok(Transformed::no(expr)) } @@ -172,15 +182,15 @@ impl OptimizerRule for ScalarSubqueryToJoin { } } else { // if we can't handle all of the subqueries then bail for now - return Ok(None); + return Ok(Transformed::no(LogicalPlan::Projection(projection))); } } let mut proj_exprs = vec![]; for expr in projection.expr.iter() { - let old_expr_name = expr.display_name()?; + let old_expr_name = expr.schema_name().to_string(); let new_expr = expr_to_rewrite_expr_map.get(expr).unwrap(); - let new_expr_name = new_expr.display_name()?; + let new_expr_name = new_expr.schema_name().to_string(); if new_expr_name != old_expr_name { proj_exprs.push(new_expr.clone().alias(old_expr_name)) } else { @@ -190,10 +200,10 @@ impl OptimizerRule for ScalarSubqueryToJoin { let new_plan = LogicalPlanBuilder::from(cur_input) .project(proj_exprs)? .build()?; - Ok(Some(new_plan)) + Ok(Transformed::yes(new_plan)) } - _ => Ok(None), + plan => Ok(Transformed::no(plan)), } } @@ -206,12 +216,19 @@ impl OptimizerRule for ScalarSubqueryToJoin { } } -struct ExtractScalarSubQuery { +/// Returns true if the expression has a scalar subquery somewhere in it +/// false otherwise +fn contains_scalar_subquery(expr: &Expr) -> bool { + expr.exists(|expr| Ok(matches!(expr, Expr::ScalarSubquery(_)))) + .expect("Inner is always Ok") +} + +struct ExtractScalarSubQuery<'a> { sub_query_info: Vec<(Subquery, String)>, - alias_gen: Arc, + alias_gen: &'a Arc, } -impl TreeNodeRewriter for ExtractScalarSubQuery { +impl TreeNodeRewriter for ExtractScalarSubQuery<'_> { type Node = Expr; fn f_down(&mut self, expr: Expr) -> Result> { @@ -280,16 +297,7 @@ fn build_join( subquery_alias: &str, ) -> Result)>> { let subquery_plan = subquery.subquery.as_ref(); - let mut pull_up = PullUpCorrelatedExpr { - join_filters: vec![], - correlated_subquery_cols_map: Default::default(), - in_predicate_opt: None, - exists_sub_query: false, - can_pull_up: true, - need_handle_count_bug: true, - collected_count_expr_map: Default::default(), - pull_up_having_expr: None, - }; + let mut pull_up = PullUpCorrelatedExpr::new().with_need_handle_count_bug(true); let new_plan = subquery_plan.clone().rewrite(&mut pull_up).data()?; if !pull_up.can_pull_up { return Ok(None); @@ -310,8 +318,7 @@ fn build_join( // alias the join filter let join_filter_opt = conjunction(pull_up.join_filters).map_or(Ok(None), |filter| { - replace_qualified_name(filter, &all_correlated_cols, subquery_alias) - .map(Option::Some) + replace_qualified_name(filter, &all_correlated_cols, subquery_alias).map(Some) })?; // join our sub query into the main plan @@ -385,9 +392,10 @@ mod tests { use crate::test::*; use arrow::datatypes::DataType; - use datafusion_expr::{ - col, lit, max, min, out_ref_col, scalar_subquery, sum, Between, - }; + use datafusion_expr::test::function_stub::sum; + + use datafusion_expr::{col, lit, out_ref_col, scalar_subquery, Between}; + use datafusion_functions_aggregate::min_max::{max, min}; /// Test multiple correlated subqueries #[test] @@ -406,24 +414,24 @@ mod tests { let plan = LogicalPlanBuilder::from(scan_tpch_table("customer")) .filter( lit(1) - .lt(scalar_subquery(orders.clone())) + .lt(scalar_subquery(Arc::clone(&orders))) .and(lit(1).lt(scalar_subquery(orders))), )? .project(vec![col("customer.c_custkey")])? .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: Int32(1) < __scalar_sq_1.MAX(orders.o_custkey) AND Int32(1) < __scalar_sq_2.MAX(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ - \n Left Join: Filter: __scalar_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ - \n Left Join: Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ + \n Filter: Int32(1) < __scalar_sq_1.max(orders.o_custkey) AND Int32(1) < __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ + \n Left Join: Filter: __scalar_sq_2.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ + \n Left Join: Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [MAX(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Projection: MAX(orders.o_custkey), orders.o_custkey [MAX(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ + \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64]\ + \n Projection: max(orders.o_custkey), orders.o_custkey [max(orders.o_custkey):Int64;N, o_custkey:Int64]\ + \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __scalar_sq_2 [MAX(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Projection: MAX(orders.o_custkey), orders.o_custkey [MAX(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ + \n SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64]\ + \n Projection: max(orders.o_custkey), orders.o_custkey [max(orders.o_custkey):Int64;N, o_custkey:Int64]\ + \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], @@ -468,18 +476,18 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_acctbal < __scalar_sq_1.SUM(orders.o_totalprice) [c_custkey:Int64, c_name:Utf8, SUM(orders.o_totalprice):Float64;N, o_custkey:Int64;N]\ - \n Left Join: Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, SUM(orders.o_totalprice):Float64;N, o_custkey:Int64;N]\ + \n Filter: customer.c_acctbal < __scalar_sq_1.sum(orders.o_totalprice) [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N]\ + \n Left Join: Filter: __scalar_sq_1.o_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8, sum(orders.o_totalprice):Float64;N, o_custkey:Int64;N]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [SUM(orders.o_totalprice):Float64;N, o_custkey:Int64]\ - \n Projection: SUM(orders.o_totalprice), orders.o_custkey [SUM(orders.o_totalprice):Float64;N, o_custkey:Int64]\ - \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[SUM(orders.o_totalprice)]] [o_custkey:Int64, SUM(orders.o_totalprice):Float64;N]\ - \n Filter: orders.o_totalprice < __scalar_sq_2.SUM(lineitem.l_extendedprice) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, SUM(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N]\ - \n Left Join: Filter: __scalar_sq_2.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, SUM(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N]\ + \n SubqueryAlias: __scalar_sq_1 [sum(orders.o_totalprice):Float64;N, o_custkey:Int64]\ + \n Projection: sum(orders.o_totalprice), orders.o_custkey [sum(orders.o_totalprice):Float64;N, o_custkey:Int64]\ + \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[sum(orders.o_totalprice)]] [o_custkey:Int64, sum(orders.o_totalprice):Float64;N]\ + \n Filter: orders.o_totalprice < __scalar_sq_2.sum(lineitem.l_extendedprice) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N]\ + \n Left Join: Filter: __scalar_sq_2.l_orderkey = orders.o_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __scalar_sq_2 [SUM(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64]\ - \n Projection: SUM(lineitem.l_extendedprice), lineitem.l_orderkey [SUM(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64]\ - \n Aggregate: groupBy=[[lineitem.l_orderkey]], aggr=[[SUM(lineitem.l_extendedprice)]] [l_orderkey:Int64, SUM(lineitem.l_extendedprice):Float64;N]\ + \n SubqueryAlias: __scalar_sq_2 [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64]\ + \n Projection: sum(lineitem.l_extendedprice), lineitem.l_orderkey [sum(lineitem.l_extendedprice):Float64;N, l_orderkey:Int64]\ + \n Aggregate: groupBy=[[lineitem.l_orderkey]], aggr=[[sum(lineitem.l_extendedprice)]] [l_orderkey:Int64, sum(lineitem.l_extendedprice):Float64;N]\ \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], @@ -510,12 +518,12 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.MAX(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ + \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ + \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [MAX(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Projection: MAX(orders.o_custkey), orders.o_custkey [MAX(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ + \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64]\ + \n Projection: max(orders.o_custkey), orders.o_custkey [max(orders.o_custkey):Int64;N, o_custkey:Int64]\ + \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\ \n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; @@ -548,12 +556,12 @@ mod tests { // it will optimize, but fail for the same reason the unoptimized query would let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.MAX(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N]\ - \n Left Join: [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N]\ + \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ + \n Left Join: [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [MAX(orders.o_custkey):Int64;N]\ - \n Projection: MAX(orders.o_custkey) [MAX(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[MAX(orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N]\ + \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\ + \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], @@ -580,12 +588,12 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.MAX(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N]\ - \n Left Join: [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N]\ + \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ + \n Left Join: [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [MAX(orders.o_custkey):Int64;N]\ - \n Projection: MAX(orders.o_custkey) [MAX(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[MAX(orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N]\ + \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\ + \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ \n Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; @@ -616,11 +624,21 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "check_analyzed_plan\ - \ncaused by\ - \nError during planning: Correlated column is not allowed in predicate: outer_ref(customer.c_custkey) != orders.o_custkey"; + // Unsupported predicate, subquery should not be decorrelated + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8]\ + \n Subquery: [max(orders.o_custkey):Int64;N]\ + \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ + \n Filter: outer_ref(customer.c_custkey) != orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]"; - assert_analyzer_check_err(vec![], &plan, expected); + assert_multi_rules_optimized_plan_eq_display_indent( + vec![Arc::new(ScalarSubqueryToJoin::new())], + plan, + expected, + ); Ok(()) } @@ -643,11 +661,21 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "check_analyzed_plan\ - \ncaused by\ - \nError during planning: Correlated column is not allowed in predicate: outer_ref(customer.c_custkey) < orders.o_custkey"; + // Unsupported predicate, subquery should not be decorrelated + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8]\ + \n Subquery: [max(orders.o_custkey):Int64;N]\ + \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ + \n Filter: outer_ref(customer.c_custkey) < orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]"; - assert_analyzer_check_err(vec![], &plan, expected); + assert_multi_rules_optimized_plan_eq_display_indent( + vec![Arc::new(ScalarSubqueryToJoin::new())], + plan, + expected, + ); Ok(()) } @@ -671,11 +699,21 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - let expected = "check_analyzed_plan\ - \ncaused by\ - \nError during planning: Correlated column is not allowed in predicate: outer_ref(customer.c_custkey) = orders.o_custkey OR orders.o_orderkey = Int32(1)"; + // Unsupported predicate, subquery should not be decorrelated + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n Filter: customer.c_custkey = () [c_custkey:Int64, c_name:Utf8]\ + \n Subquery: [max(orders.o_custkey):Int64;N]\ + \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ + \n Filter: outer_ref(customer.c_custkey) = orders.o_custkey OR orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]"; - assert_analyzer_check_err(vec![], &plan, expected); + assert_multi_rules_optimized_plan_eq_display_indent( + vec![Arc::new(ScalarSubqueryToJoin::new())], + plan, + expected, + ); Ok(()) } @@ -696,7 +734,7 @@ mod tests { let expected = "check_analyzed_plan\ \ncaused by\ \nError during planning: Scalar subquery should only return one column"; - assert_analyzer_check_err(vec![], &plan, expected); + assert_analyzer_check_err(vec![], plan, expected); Ok(()) } @@ -710,7 +748,7 @@ mod tests { .eq(col("orders.o_custkey")), )? .aggregate(Vec::::new(), vec![max(col("orders.o_custkey"))])? - .project(vec![col("MAX(orders.o_custkey)").add(lit(1))])? + .project(vec![col("max(orders.o_custkey)").add(lit(1))])? .build()?, ); @@ -720,12 +758,12 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.MAX(orders.o_custkey) + Int32(1) [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N]\ + \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) + Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N]\ + \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64;N]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [MAX(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64]\ - \n Projection: MAX(orders.o_custkey) + Int32(1), orders.o_custkey [MAX(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64]\ - \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ + \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64]\ + \n Projection: max(orders.o_custkey) + Int32(1), orders.o_custkey [max(orders.o_custkey) + Int32(1):Int64;N, o_custkey:Int64]\ + \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( @@ -758,7 +796,7 @@ mod tests { let expected = "check_analyzed_plan\ \ncaused by\ \nError during planning: Scalar subquery should only return one column"; - assert_analyzer_check_err(vec![], &plan, expected); + assert_analyzer_check_err(vec![], plan, expected); Ok(()) } @@ -786,12 +824,12 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey >= __scalar_sq_1.MAX(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ + \n Filter: customer.c_custkey >= __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ + \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [MAX(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Projection: MAX(orders.o_custkey), orders.o_custkey [MAX(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ + \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64]\ + \n Projection: max(orders.o_custkey), orders.o_custkey [max(orders.o_custkey):Int64;N, o_custkey:Int64]\ + \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( @@ -825,12 +863,12 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.MAX(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ + \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) AND customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ + \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [MAX(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Projection: MAX(orders.o_custkey), orders.o_custkey [MAX(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ + \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64]\ + \n Projection: max(orders.o_custkey), orders.o_custkey [max(orders.o_custkey):Int64;N, o_custkey:Int64]\ + \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( @@ -865,12 +903,12 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.MAX(orders.o_custkey) OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ + \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) OR customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ + \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [MAX(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Projection: MAX(orders.o_custkey), orders.o_custkey [MAX(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ + \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N, o_custkey:Int64]\ + \n Projection: max(orders.o_custkey), orders.o_custkey [max(orders.o_custkey):Int64;N, o_custkey:Int64]\ + \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( @@ -898,12 +936,12 @@ mod tests { .build()?; let expected = "Projection: test.c [c:UInt32]\ - \n Filter: test.c < __scalar_sq_1.MIN(sq.c) [a:UInt32, b:UInt32, c:UInt32, MIN(sq.c):UInt32;N, a:UInt32;N]\ - \n Left Join: Filter: test.a = __scalar_sq_1.a [a:UInt32, b:UInt32, c:UInt32, MIN(sq.c):UInt32;N, a:UInt32;N]\ + \n Filter: test.c < __scalar_sq_1.min(sq.c) [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N]\ + \n Left Join: Filter: test.a = __scalar_sq_1.a [a:UInt32, b:UInt32, c:UInt32, min(sq.c):UInt32;N, a:UInt32;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ - \n SubqueryAlias: __scalar_sq_1 [MIN(sq.c):UInt32;N, a:UInt32]\ - \n Projection: MIN(sq.c), sq.a [MIN(sq.c):UInt32;N, a:UInt32]\ - \n Aggregate: groupBy=[[sq.a]], aggr=[[MIN(sq.c)]] [a:UInt32, MIN(sq.c):UInt32;N]\ + \n SubqueryAlias: __scalar_sq_1 [min(sq.c):UInt32;N, a:UInt32]\ + \n Projection: min(sq.c), sq.a [min(sq.c):UInt32;N, a:UInt32]\ + \n Aggregate: groupBy=[[sq.a]], aggr=[[min(sq.c)]] [a:UInt32, min(sq.c):UInt32;N]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; assert_multi_rules_optimized_plan_eq_display_indent( @@ -930,12 +968,12 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey < __scalar_sq_1.MAX(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N]\ - \n Left Join: [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N]\ + \n Filter: customer.c_custkey < __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ + \n Left Join: [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [MAX(orders.o_custkey):Int64;N]\ - \n Projection: MAX(orders.o_custkey) [MAX(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[MAX(orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N]\ + \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\ + \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( @@ -961,12 +999,12 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey = __scalar_sq_1.MAX(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N]\ - \n Left Join: [c_custkey:Int64, c_name:Utf8, MAX(orders.o_custkey):Int64;N]\ + \n Filter: customer.c_custkey = __scalar_sq_1.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ + \n Left Join: [c_custkey:Int64, c_name:Utf8, max(orders.o_custkey):Int64;N]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [MAX(orders.o_custkey):Int64;N]\ - \n Projection: MAX(orders.o_custkey) [MAX(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[MAX(orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N]\ + \n SubqueryAlias: __scalar_sq_1 [max(orders.o_custkey):Int64;N]\ + \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( @@ -1013,17 +1051,17 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey BETWEEN __scalar_sq_1.MIN(orders.o_custkey) AND __scalar_sq_2.MAX(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, MIN(orders.o_custkey):Int64;N, o_custkey:Int64;N, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, MIN(orders.o_custkey):Int64;N, o_custkey:Int64;N, MAX(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ - \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, MIN(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ + \n Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ + \n Left Join: Filter: customer.c_custkey = __scalar_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N, max(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ + \n Left Join: Filter: customer.c_custkey = __scalar_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, o_custkey:Int64;N]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [MIN(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Projection: MIN(orders.o_custkey), orders.o_custkey [MIN(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MIN(orders.o_custkey)]] [o_custkey:Int64, MIN(orders.o_custkey):Int64;N]\ + \n SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N, o_custkey:Int64]\ + \n Projection: min(orders.o_custkey), orders.o_custkey [min(orders.o_custkey):Int64;N, o_custkey:Int64]\ + \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[min(orders.o_custkey)]] [o_custkey:Int64, min(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __scalar_sq_2 [MAX(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Projection: MAX(orders.o_custkey), orders.o_custkey [MAX(orders.o_custkey):Int64;N, o_custkey:Int64]\ - \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ + \n SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N, o_custkey:Int64]\ + \n Projection: max(orders.o_custkey), orders.o_custkey [max(orders.o_custkey):Int64;N, o_custkey:Int64]\ + \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[max(orders.o_custkey)]] [o_custkey:Int64, max(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( @@ -1062,17 +1100,17 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_custkey BETWEEN __scalar_sq_1.MIN(orders.o_custkey) AND __scalar_sq_2.MAX(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, MIN(orders.o_custkey):Int64;N, MAX(orders.o_custkey):Int64;N]\ - \n Left Join: [c_custkey:Int64, c_name:Utf8, MIN(orders.o_custkey):Int64;N, MAX(orders.o_custkey):Int64;N]\ - \n Left Join: [c_custkey:Int64, c_name:Utf8, MIN(orders.o_custkey):Int64;N]\ + \n Filter: customer.c_custkey BETWEEN __scalar_sq_1.min(orders.o_custkey) AND __scalar_sq_2.max(orders.o_custkey) [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N]\ + \n Left Join: [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N, max(orders.o_custkey):Int64;N]\ + \n Left Join: [c_custkey:Int64, c_name:Utf8, min(orders.o_custkey):Int64;N]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __scalar_sq_1 [MIN(orders.o_custkey):Int64;N]\ - \n Projection: MIN(orders.o_custkey) [MIN(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[MIN(orders.o_custkey)]] [MIN(orders.o_custkey):Int64;N]\ + \n SubqueryAlias: __scalar_sq_1 [min(orders.o_custkey):Int64;N]\ + \n Projection: min(orders.o_custkey) [min(orders.o_custkey):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[min(orders.o_custkey)]] [min(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __scalar_sq_2 [MAX(orders.o_custkey):Int64;N]\ - \n Projection: MAX(orders.o_custkey) [MAX(orders.o_custkey):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[MAX(orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N]\ + \n SubqueryAlias: __scalar_sq_2 [max(orders.o_custkey):Int64;N]\ + \n Projection: max(orders.o_custkey) [max(orders.o_custkey):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[max(orders.o_custkey)]] [max(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 2fb06e659d70..40be1f85391d 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -31,17 +31,19 @@ use datafusion_common::{ cast::{as_large_list_array, as_list_array}, tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; -use datafusion_common::{ - internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, -}; -use datafusion_expr::expr::{InList, InSubquery}; +use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue}; use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ - and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, - ScalarFunctionDefinition, Volatility, + and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility, + WindowFunctionDefinition, }; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; +use datafusion_expr::{ + expr::{InList, InSubquery, WindowFunction}, + utils::{iter_conjunction, iter_conjunction_owned}, +}; use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; +use indexmap::IndexSet; use crate::analyzer::type_coercion::TypeCoercionRewriter; use crate::simplify_expressions::guarantees::GuaranteeRewriter; @@ -94,9 +96,12 @@ pub struct ExprSimplifier { /// Should expressions be canonicalized before simplification? Defaults to /// true canonicalize: bool, + /// Maximum number of simplifier cycles + max_simplifier_cycles: u32, } pub const THRESHOLD_INLINE_INLIST: usize = 3; +pub const DEFAULT_MAX_SIMPLIFIER_CYCLES: u32 = 3; impl ExprSimplifier { /// Create a new `ExprSimplifier` with the given `info` such as an @@ -109,10 +114,11 @@ impl ExprSimplifier { info, guarantees: vec![], canonicalize: true, + max_simplifier_cycles: DEFAULT_MAX_SIMPLIFIER_CYCLES, } } - /// Simplifies this [`Expr`]`s as much as possible, evaluating + /// Simplifies this [`Expr`] as much as possible, evaluating /// constants and applying algebraic simplifications. /// /// The types of the expression must match what operators expect, @@ -173,7 +179,18 @@ impl ExprSimplifier { /// let expr = simplifier.simplify(expr).unwrap(); /// assert_eq!(expr, b_lt_2); /// ``` - pub fn simplify(&self, mut expr: Expr) -> Result { + pub fn simplify(&self, expr: Expr) -> Result { + Ok(self.simplify_with_cycle_count(expr)?.0) + } + + /// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating + /// constants and applying algebraic simplifications. Additionally returns a `u32` + /// representing the number of simplification cycles performed, which can be useful for testing + /// optimizations. + /// + /// See [Self::simplify] for details and usage examples. + /// + pub fn simplify_with_cycle_count(&self, mut expr: Expr) -> Result<(Expr, u32)> { let mut simplifier = Simplifier::new(&self.info); let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?; let mut shorten_in_list_simplifier = ShortenInListSimplifier::new(); @@ -183,24 +200,26 @@ impl ExprSimplifier { expr = expr.rewrite(&mut Canonicalizer::new()).data()? } - // TODO iterate until no changes are made during rewrite - // (evaluating constants can enable new simplifications and - // simplifications can enable new constant evaluation) - // https://github.com/apache/datafusion/issues/1160 - expr.rewrite(&mut const_evaluator) - .data()? - .rewrite(&mut simplifier) - .data()? - .rewrite(&mut guarantee_rewriter) - .data()? - // run both passes twice to try an minimize simplifications that we missed - .rewrite(&mut const_evaluator) - .data()? - .rewrite(&mut simplifier) - .data()? - // shorten inlist should be started after other inlist rules are applied - .rewrite(&mut shorten_in_list_simplifier) - .data() + // Evaluating constants can enable new simplifications and + // simplifications can enable new constant evaluation + // see `Self::with_max_cycles` + let mut num_cycles = 0; + loop { + let Transformed { + data, transformed, .. + } = expr + .rewrite(&mut const_evaluator)? + .transform_data(|expr| expr.rewrite(&mut simplifier))? + .transform_data(|expr| expr.rewrite(&mut guarantee_rewriter))?; + expr = data; + num_cycles += 1; + if !transformed || num_cycles >= self.max_simplifier_cycles { + break; + } + } + // shorten inlist should be started after other inlist rules are applied + expr = expr.rewrite(&mut shorten_in_list_simplifier).data()?; + Ok((expr, num_cycles)) } /// Apply type coercion to an [`Expr`] so that it can be @@ -208,14 +227,8 @@ impl ExprSimplifier { /// /// See the [type coercion module](datafusion_expr::type_coercion) /// documentation for more details on type coercion - /// - // Would be nice if this API could use the SimplifyInfo - // rather than creating an DFSchemaRef coerces rather than doing - // it manually. - // https://github.com/apache/datafusion/issues/3793 - pub fn coerce(&self, expr: Expr, schema: DFSchemaRef) -> Result { + pub fn coerce(&self, expr: Expr, schema: &DFSchema) -> Result { let mut expr_rewrite = TypeCoercionRewriter { schema }; - expr.rewrite(&mut expr_rewrite).data() } @@ -278,7 +291,7 @@ impl ExprSimplifier { self } - /// Should [`Canonicalizer`] be applied before simplification? + /// Should `Canonicalizer` be applied before simplification? /// /// If true (the default), the expression will be rewritten to canonical /// form before simplification. This is useful to ensure that the simplifier @@ -331,6 +344,63 @@ impl ExprSimplifier { self.canonicalize = canonicalize; self } + + /// Specifies the maximum number of simplification cycles to run. + /// + /// The simplifier can perform multiple passes of simplification. This is + /// because the output of one simplification step can allow more optimizations + /// in another simplification step. For example, constant evaluation can allow more + /// expression simplifications, and expression simplifications can allow more constant + /// evaluations. + /// + /// This method specifies the maximum number of allowed iteration cycles before the simplifier + /// returns an [Expr] output. However, it does not always perform the maximum number of cycles. + /// The simplifier will attempt to detect when an [Expr] is unchanged by all the simplification + /// passes, and return early. This avoids wasting time on unnecessary [Expr] tree traversals. + /// + /// If no maximum is specified, the value of [DEFAULT_MAX_SIMPLIFIER_CYCLES] is used + /// instead. + /// + /// ```rust + /// use arrow::datatypes::{DataType, Field, Schema}; + /// use datafusion_expr::{col, lit, Expr}; + /// use datafusion_common::{Result, ScalarValue, ToDFSchema}; + /// use datafusion_expr::execution_props::ExecutionProps; + /// use datafusion_expr::simplify::SimplifyContext; + /// use datafusion_optimizer::simplify_expressions::ExprSimplifier; + /// + /// let schema = Schema::new(vec![ + /// Field::new("a", DataType::Int64, false), + /// ]) + /// .to_dfschema_ref().unwrap(); + /// + /// // Create the simplifier + /// let props = ExecutionProps::new(); + /// let context = SimplifyContext::new(&props) + /// .with_schema(schema); + /// let simplifier = ExprSimplifier::new(context); + /// + /// // Expression: a IS NOT NULL + /// let expr = col("a").is_not_null(); + /// + /// // When using default maximum cycles, 2 cycles will be performed. + /// let (simplified_expr, count) = simplifier.simplify_with_cycle_count(expr.clone()).unwrap(); + /// assert_eq!(simplified_expr, lit(true)); + /// // 2 cycles were executed, but only 1 was needed + /// assert_eq!(count, 2); + /// + /// // Only 1 simplification pass is necessary here, so we can set the maximum cycles to 1. + /// let (simplified_expr, count) = simplifier.with_max_cycles(1).simplify_with_cycle_count(expr.clone()).unwrap(); + /// // Expression has been rewritten to: (c = a AND b = 1) + /// assert_eq!(simplified_expr, lit(true)); + /// // Only 1 cycle was executed + /// assert_eq!(count, 1); + /// + /// ``` + pub fn with_max_cycles(mut self, max_simplifier_cycles: u32) -> Self { + self.max_simplifier_cycles = max_simplifier_cycles; + self + } } /// Canonicalize any BinaryExprs that are not in canonical form @@ -410,8 +480,10 @@ struct ConstEvaluator<'a> { #[allow(dead_code)] /// The simplify result of ConstEvaluator enum ConstSimplifyResult { - // Expr was simplifed and contains the new expression + // Expr was simplified and contains the new expression Simplified(ScalarValue), + // Expr was not simplified and original value is returned + NotSimplified(ScalarValue), // Evaluation encountered an error, contains the original expression SimplifyRuntimeError(DataFusionError, Expr), } @@ -449,7 +521,7 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { fn f_up(&mut self, expr: Expr) -> Result> { match self.can_evaluate.pop() { // Certain expressions such as `CASE` and `COALESCE` are short circuiting - // and may not evalute all their sub expressions. Thus if + // and may not evaluate all their sub expressions. Thus if // if any error is countered during simplification, return the original // so that normal evaluation can occur Some(true) => { @@ -458,6 +530,9 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { ConstSimplifyResult::Simplified(s) => { Ok(Transformed::yes(Expr::Literal(s))) } + ConstSimplifyResult::NotSimplified(s) => { + Ok(Transformed::no(Expr::Literal(s))) + } ConstSimplifyResult::SimplifyRuntimeError(_, expr) => { Ok(Transformed::yes(expr)) } @@ -520,16 +595,12 @@ impl<'a> ConstEvaluator<'a> { | Expr::InSubquery(_) | Expr::ScalarSubquery(_) | Expr::WindowFunction { .. } - | Expr::Sort { .. } | Expr::GroupingSet(_) | Expr::Wildcard { .. } | Expr::Placeholder(_) => false, - Expr::ScalarFunction(ScalarFunction { func_def, .. }) => match func_def { - ScalarFunctionDefinition::UDF(fun) => { - Self::volatility_ok(fun.signature().volatility) - } - ScalarFunctionDefinition::Name(_) => false, - }, + Expr::ScalarFunction(ScalarFunction { func, .. }) => { + Self::volatility_ok(func.signature().volatility) + } Expr::Literal(_) | Expr::Unnest(_) | Expr::BinaryExpr { .. } @@ -549,15 +620,14 @@ impl<'a> ConstEvaluator<'a> { | Expr::Case(_) | Expr::Cast { .. } | Expr::TryCast { .. } - | Expr::InList { .. } - | Expr::GetIndexedField { .. } => true, + | Expr::InList { .. } => true, } } /// Internal helper to evaluates an Expr pub(crate) fn evaluate_to_scalar(&mut self, expr: Expr) -> ConstSimplifyResult { if let Expr::Literal(s) = expr { - return ConstSimplifyResult::Simplified(s); + return ConstSimplifyResult::NotSimplified(s); } let phys_expr = @@ -587,12 +657,35 @@ impl<'a> ConstEvaluator<'a> { } else { // Non-ListArray match ScalarValue::try_from_array(&a, 0) { - Ok(s) => ConstSimplifyResult::Simplified(s), + Ok(s) => { + // TODO: support the optimization for `Map` type after support impl hash for it + if matches!(&s, ScalarValue::Map(_)) { + ConstSimplifyResult::SimplifyRuntimeError( + DataFusionError::NotImplemented("Const evaluate for Map type is still not supported".to_string()), + expr, + ) + } else { + ConstSimplifyResult::Simplified(s) + } + } Err(err) => ConstSimplifyResult::SimplifyRuntimeError(err, expr), } } } - ColumnarValue::Scalar(s) => ConstSimplifyResult::Simplified(s), + ColumnarValue::Scalar(s) => { + // TODO: support the optimization for `Map` type after support impl hash for it + if matches!(&s, ScalarValue::Map(_)) { + ConstSimplifyResult::SimplifyRuntimeError( + DataFusionError::NotImplemented( + "Const evaluate for Map type is still not supported" + .to_string(), + ), + expr, + ) + } else { + ConstSimplifyResult::Simplified(s) + } + } } } } @@ -641,7 +734,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Eq, right, }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => { - Transformed::yes(match as_bool_lit(*left)? { + Transformed::yes(match as_bool_lit(&left)? { Some(true) => *right, Some(false) => Expr::Not(right), None => lit_bool_null(), @@ -655,7 +748,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Eq, right, }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => { - Transformed::yes(match as_bool_lit(*right)? { + Transformed::yes(match as_bool_lit(&right)? { Some(true) => *left, Some(false) => Expr::Not(left), None => lit_bool_null(), @@ -672,7 +765,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: NotEq, right, }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => { - Transformed::yes(match as_bool_lit(*left)? { + Transformed::yes(match as_bool_lit(&left)? { Some(true) => Expr::Not(right), Some(false) => *right, None => lit_bool_null(), @@ -686,7 +779,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: NotEq, right, }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => { - Transformed::yes(match as_bool_lit(*right)? { + Transformed::yes(match as_bool_lit(&right)? { Some(true) => Expr::Not(left), Some(false) => *left, None => lit_bool_null(), @@ -727,7 +820,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Or, right, }) if is_not_of(&right, &left) && !info.nullable(&left)? => { - Transformed::yes(Expr::Literal(ScalarValue::Boolean(Some(true)))) + Transformed::yes(lit(true)) } // !A OR A ---> true (if A not nullable) Expr::BinaryExpr(BinaryExpr { @@ -735,7 +828,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Or, right, }) if is_not_of(&left, &right) && !info.nullable(&right)? => { - Transformed::yes(Expr::Literal(ScalarValue::Boolean(Some(true)))) + Transformed::yes(lit(true)) } // (..A..) OR A --> (..A..) Expr::BinaryExpr(BinaryExpr { @@ -749,21 +842,38 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Or, right, }) if expr_contains(&right, &left, Or) => Transformed::yes(*right), - // A OR (A AND B) --> A (if B not null) + // A OR (A AND B) --> A Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if !info.nullable(&right)? && is_op_with(And, &right, &left) => { - Transformed::yes(*left) - } - // (A AND B) OR A --> A (if B not null) + }) if is_op_with(And, &right, &left) => Transformed::yes(*left), + // (A AND B) OR A --> A Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if !info.nullable(&left)? && is_op_with(And, &left, &right) => { - Transformed::yes(*right) + }) if is_op_with(And, &left, &right) => Transformed::yes(*right), + // Eliminate common factors in conjunctions e.g + // (A AND B) OR (A AND C) -> A AND (B OR C) + Expr::BinaryExpr(BinaryExpr { + left, + op: Or, + right, + }) if has_common_conjunction(&left, &right) => { + let lhs: IndexSet = iter_conjunction_owned(*left).collect(); + let (common, rhs): (Vec<_>, Vec<_>) = iter_conjunction_owned(*right) + .partition(|e| lhs.contains(e) && !e.is_volatile()); + + let new_rhs = rhs.into_iter().reduce(and); + let new_lhs = lhs.into_iter().filter(|e| !common.contains(e)).reduce(and); + let common_conjunction = common.into_iter().reduce(and).unwrap(); + + let new_expr = match (new_lhs, new_rhs) { + (Some(lhs), Some(rhs)) => and(common_conjunction, or(lhs, rhs)), + (_, _) => common_conjunction, + }; + Transformed::yes(new_expr) } // @@ -800,7 +910,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: And, right, }) if is_not_of(&right, &left) && !info.nullable(&left)? => { - Transformed::yes(Expr::Literal(ScalarValue::Boolean(Some(false)))) + Transformed::yes(lit(false)) } // !A AND A ---> false (if A not nullable) Expr::BinaryExpr(BinaryExpr { @@ -808,7 +918,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: And, right, }) if is_not_of(&left, &right) && !info.nullable(&right)? => { - Transformed::yes(Expr::Literal(ScalarValue::Boolean(Some(false)))) + Transformed::yes(lit(false)) } // (..A..) AND A --> (..A..) Expr::BinaryExpr(BinaryExpr { @@ -822,22 +932,18 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: And, right, }) if expr_contains(&right, &left, And) => Transformed::yes(*right), - // A AND (A OR B) --> A (if B not null) + // A AND (A OR B) --> A Expr::BinaryExpr(BinaryExpr { left, op: And, right, - }) if !info.nullable(&right)? && is_op_with(Or, &right, &left) => { - Transformed::yes(*left) - } - // (A OR B) AND A --> A (if B not null) + }) if is_op_with(Or, &right, &left) => Transformed::yes(*left), + // (A OR B) AND A --> A Expr::BinaryExpr(BinaryExpr { left, op: And, right, - }) if !info.nullable(&left)? && is_op_with(Or, &left, &right) => { - Transformed::yes(*right) - } + }) if is_op_with(Or, &left, &right) => Transformed::yes(*right), // // Rules for Multiply @@ -939,7 +1045,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { && !info.get_data_type(&left)?.is_floating() && is_one(&right) => { - Transformed::yes(lit(0)) + Transformed::yes(Expr::Literal(ScalarValue::new_zero( + &info.get_data_type(&left)?, + )?)) } // @@ -1303,17 +1411,36 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // Do a first pass at simplification out_expr.rewrite(self)? } - Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(udf), - args, - }) => match udf.simplify(args, info)? { - ExprSimplifyResult::Original(args) => { - Transformed::no(Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(udf), - args, - })) + Expr::ScalarFunction(ScalarFunction { func: udf, args }) => { + match udf.simplify(args, info)? { + ExprSimplifyResult::Original(args) => { + Transformed::no(Expr::ScalarFunction(ScalarFunction { + func: udf, + args, + })) + } + ExprSimplifyResult::Simplified(expr) => Transformed::yes(expr), + } + } + + Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction { + ref func, + .. + }) => match (func.simplify(), expr) { + (Some(simplify_function), Expr::AggregateFunction(af)) => { + Transformed::yes(simplify_function(af, info)?) + } + (_, expr) => Transformed::no(expr), + }, + + Expr::WindowFunction(WindowFunction { + fun: WindowFunctionDefinition::WindowUDF(ref udwf), + .. + }) => match (udwf.simplify(), expr) { + (Some(simplify_function), Expr::WindowFunction(wf)) => { + Transformed::yes(simplify_function(wf, info)?) } - ExprSimplifyResult::Simplified(expr) => Transformed::yes(expr), + (_, expr) => Transformed::no(expr), }, // @@ -1410,7 +1537,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // i.e. `a = 1 OR a = 2 OR a = 3` -> `a IN (1, 2, 3)` Expr::BinaryExpr(BinaryExpr { left, - op: Operator::Or, + op: Or, right, }) if are_inlist_and_eq(left.as_ref(), right.as_ref()) => { let lhs = to_inlist(*left).unwrap(); @@ -1450,7 +1577,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // 8. `a in (1,2,3,4) AND a not in (5,6,7,8) -> a in (1,2,3,4)` Expr::BinaryExpr(BinaryExpr { left, - op: Operator::And, + op: And, right, }) if are_inlist_and_eq_and_match_neg( left.as_ref(), @@ -1461,7 +1588,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { { match (*left, *right) { (Expr::InList(l1), Expr::InList(l2)) => { - return inlist_intersection(l1, l2, false).map(Transformed::yes); + return inlist_intersection(l1, &l2, false).map(Transformed::yes); } // Matched previously once _ => unreachable!(), @@ -1470,7 +1597,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { Expr::BinaryExpr(BinaryExpr { left, - op: Operator::And, + op: And, right, }) if are_inlist_and_eq_and_match_neg( left.as_ref(), @@ -1490,7 +1617,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { Expr::BinaryExpr(BinaryExpr { left, - op: Operator::And, + op: And, right, }) if are_inlist_and_eq_and_match_neg( left.as_ref(), @@ -1501,7 +1628,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { { match (*left, *right) { (Expr::InList(l1), Expr::InList(l2)) => { - return inlist_except(l1, l2).map(Transformed::yes); + return inlist_except(l1, &l2).map(Transformed::yes); } // Matched previously once _ => unreachable!(), @@ -1510,7 +1637,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { Expr::BinaryExpr(BinaryExpr { left, - op: Operator::And, + op: And, right, }) if are_inlist_and_eq_and_match_neg( left.as_ref(), @@ -1521,7 +1648,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { { match (*left, *right) { (Expr::InList(l1), Expr::InList(l2)) => { - return inlist_except(l2, l1).map(Transformed::yes); + return inlist_except(l2, &l1).map(Transformed::yes); } // Matched previously once _ => unreachable!(), @@ -1530,7 +1657,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { Expr::BinaryExpr(BinaryExpr { left, - op: Operator::Or, + op: Or, right, }) if are_inlist_and_eq_and_match_neg( left.as_ref(), @@ -1541,7 +1668,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { { match (*left, *right) { (Expr::InList(l1), Expr::InList(l2)) => { - return inlist_intersection(l1, l2, true).map(Transformed::yes); + return inlist_intersection(l1, &l2, true).map(Transformed::yes); } // Matched previously once _ => unreachable!(), @@ -1554,6 +1681,11 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { } } +fn has_common_conjunction(lhs: &Expr, rhs: &Expr) -> bool { + let lhs_set: HashSet<&Expr> = iter_conjunction(lhs).collect(); + iter_conjunction(rhs).any(|e| lhs_set.contains(&e) && !e.is_volatile()) +} + // TODO: We might not need this after defer pattern for Box is stabilized. https://github.com/rust-lang/rust/issues/87121 fn are_inlist_and_eq_and_match_neg( left: &Expr, @@ -1651,7 +1783,7 @@ fn inlist_union(mut l1: InList, l2: InList, negated: bool) -> Result { /// Return the intersection of two inlist expressions /// maintaining the order of the elements in the two lists -fn inlist_intersection(mut l1: InList, l2: InList, negated: bool) -> Result { +fn inlist_intersection(mut l1: InList, l2: &InList, negated: bool) -> Result { let l2_items = l2.list.iter().collect::>(); // remove all items from l1 that are not in l2 @@ -1667,7 +1799,7 @@ fn inlist_intersection(mut l1: InList, l2: InList, negated: bool) -> Result Result { +fn inlist_except(mut l1: InList, l2: &InList) -> Result { let l2_items = l2.list.iter().collect::>(); // keep only items from l1 that are not in l2 @@ -1681,18 +1813,25 @@ fn inlist_except(mut l1: InList, l2: InList) -> Result { #[cfg(test)] mod tests { + use crate::simplify_expressions::SimplifyContext; + use crate::test::test_table_scan_with_name; + use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema}; + use datafusion_expr::{ + function::{ + AccumulatorArgs, AggregateFunctionSimplification, + WindowFunctionSimplification, + }, + interval_arithmetic::Interval, + *, + }; + use datafusion_functions_window_common::field::WindowUDFFieldArgs; + use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use std::{ collections::HashMap, ops::{BitAnd, BitOr, BitXor}, sync::Arc, }; - use datafusion_common::{assert_contains, ToDFSchema}; - use datafusion_expr::{interval_arithmetic::Interval, *}; - - use crate::simplify_expressions::SimplifyContext; - use crate::test::test_table_scan_with_name; - use super::*; // ------------------------------ @@ -1713,8 +1852,9 @@ mod tests { fn basic_coercion() { let schema = test_schema(); let props = ExecutionProps::new(); - let simplifier = - ExprSimplifier::new(SimplifyContext::new(&props).with_schema(schema.clone())); + let simplifier = ExprSimplifier::new( + SimplifyContext::new(&props).with_schema(Arc::clone(&schema)), + ); // Note expr type is int32 (not int64) // (1i64 + 2i32) < i @@ -1722,11 +1862,7 @@ mod tests { // should fully simplify to 3 < i (though i has been coerced to i64) let expected = lit(3i64).lt(col("i")); - // Would be nice if this API could use the SimplifyInfo - // rather than creating an DFSchemaRef coerces rather than doing - // it manually. - // https://github.com/apache/datafusion/issues/3793 - let expr = simplifier.coerce(expr, schema).unwrap(); + let expr = simplifier.coerce(expr, &schema).unwrap(); assert_eq!(expected, simplifier.simplify(expr).unwrap()); } @@ -2059,11 +2195,11 @@ mod tests { #[test] fn test_simplify_modulo_by_one_non_null() { - let expr = col("c2_non_null") % lit(1); - let expected = lit(0); + let expr = col("c3_non_null") % lit(1); + let expected = lit(0_i64); assert_eq!(simplify(expr), expected); let expr = - col("c2_non_null") % lit(ScalarValue::Decimal128(Some(10000000000), 31, 10)); + col("c3_non_null") % lit(ScalarValue::Decimal128(Some(10000000000), 31, 10)); assert_eq!(simplify(expr), expected); } @@ -2497,15 +2633,11 @@ mod tests { // (c2 > 5) OR ((c1 < 6) AND (c2 > 5)) let expr = or(l.clone(), r.clone()); - // no rewrites if c1 can be null - let expected = expr.clone(); + let expected = l.clone(); assert_eq!(simplify(expr), expected); // ((c1 < 6) AND (c2 > 5)) OR (c2 > 5) - let expr = or(l, r); - - // no rewrites if c1 can be null - let expected = expr.clone(); + let expr = or(r, l); assert_eq!(simplify(expr), expected); } @@ -2536,13 +2668,11 @@ mod tests { // (c2 > 5) AND ((c1 < 6) OR (c2 > 5)) --> c2 > 5 let expr = and(l.clone(), r.clone()); - // no rewrites if c1 can be null - let expected = expr.clone(); + let expected = l.clone(); assert_eq!(simplify(expr), expected); // ((c1 < 6) OR (c2 > 5)) AND (c2 > 5) --> c2 > 5 - let expr = and(l, r); - let expected = expr.clone(); + let expr = and(r, l); assert_eq!(simplify(expr), expected); } @@ -2640,11 +2770,10 @@ mod tests { // unsupported cases assert_no_change(regex_match(col("c1"), lit("foo.*"))); assert_no_change(regex_match(col("c1"), lit("(foo)"))); - assert_no_change(regex_match(col("c1"), lit("^foo"))); - assert_no_change(regex_match(col("c1"), lit("foo$"))); assert_no_change(regex_match(col("c1"), lit("%"))); assert_no_change(regex_match(col("c1"), lit("_"))); assert_no_change(regex_match(col("c1"), lit("f%o"))); + assert_no_change(regex_match(col("c1"), lit("^f%o"))); assert_no_change(regex_match(col("c1"), lit("f_o"))); // empty cases @@ -2737,13 +2866,20 @@ mod tests { assert_no_change(regex_match(col("c1"), lit("(foo|ba_r)*"))); assert_no_change(regex_match(col("c1"), lit("(fo_o|ba_r)*"))); assert_no_change(regex_match(col("c1"), lit("^(foo|bar)*"))); - assert_no_change(regex_match(col("c1"), lit("^foo|bar$"))); assert_no_change(regex_match(col("c1"), lit("^(foo)(bar)$"))); assert_no_change(regex_match(col("c1"), lit("^"))); assert_no_change(regex_match(col("c1"), lit("$"))); assert_no_change(regex_match(col("c1"), lit("$^"))); assert_no_change(regex_match(col("c1"), lit("$foo^"))); + // regular expressions that match a partial literal + assert_change(regex_match(col("c1"), lit("^foo")), like(col("c1"), "foo%")); + assert_change(regex_match(col("c1"), lit("foo$")), like(col("c1"), "%foo")); + assert_change( + regex_match(col("c1"), lit("^foo|bar$")), + like(col("c1"), "foo%").or(like(col("c1"), "%bar")), + ); + // OR-chain assert_change( regex_match(col("c1"), lit("foo|bar|baz")), @@ -2881,6 +3017,19 @@ mod tests { try_simplify(expr).unwrap() } + fn try_simplify_with_cycle_count(expr: Expr) -> Result<(Expr, u32)> { + let schema = expr_test_schema(); + let execution_props = ExecutionProps::new(); + let simplifier = ExprSimplifier::new( + SimplifyContext::new(&execution_props).with_schema(schema), + ); + simplifier.simplify_with_cycle_count(expr) + } + + fn simplify_with_cycle_count(expr: Expr) -> (Expr, u32) { + try_simplify_with_cycle_count(expr).unwrap() + } + fn simplify_with_guarantee( expr: Expr, guarantees: Vec<(Expr, NullableInterval)>, @@ -2896,7 +3045,7 @@ mod tests { fn expr_test_schema() -> DFSchemaRef { Arc::new( - DFSchema::from_unqualifed_fields( + DFSchema::from_unqualified_fields( vec![ Field::new("c1", DataType::Utf8, true), Field::new("c2", DataType::Boolean, true), @@ -3092,7 +3241,7 @@ mod tests { )], Some(Box::new(col("c2").eq(lit(true)))), )))), - col("c2").or(col("c2").not().and(col("c2"))) // #1716 + col("c2") ); // CASE WHEN ISNULL(c2) THEN true ELSE c2 @@ -3231,15 +3380,15 @@ mod tests { assert_eq!( simplify(in_list( col("c1"), - vec![scalar_subquery(subquery.clone())], + vec![scalar_subquery(Arc::clone(&subquery))], false )), - in_subquery(col("c1"), subquery.clone()) + in_subquery(col("c1"), Arc::clone(&subquery)) ); assert_eq!( simplify(in_list( col("c1"), - vec![scalar_subquery(subquery.clone())], + vec![scalar_subquery(Arc::clone(&subquery))], true )), not_in_subquery(col("c1"), subquery) @@ -3276,32 +3425,32 @@ mod tests { let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).and( in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], false), ); - assert_eq!(simplify(expr.clone()), lit(false)); + assert_eq!(simplify(expr), lit(false)); // 2. c1 IN (1,2,3,4) AND c1 IN (4,5,6,7) -> c1 = 4 let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).and( in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], false), ); - assert_eq!(simplify(expr.clone()), col("c1").eq(lit(4))); + assert_eq!(simplify(expr), col("c1").eq(lit(4))); // 3. c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (5, 6, 7, 8) -> true let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or( in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], true), ); - assert_eq!(simplify(expr.clone()), lit(true)); + assert_eq!(simplify(expr), lit(true)); // 3.5 c1 NOT IN (1, 2, 3, 4) OR c1 NOT IN (4, 5, 6, 7) -> c1 != 4 (4 overlaps) let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).or( in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], true), ); - assert_eq!(simplify(expr.clone()), col("c1").not_eq(lit(4))); + assert_eq!(simplify(expr), col("c1").not_eq(lit(4))); // 4. c1 NOT IN (1,2,3,4) AND c1 NOT IN (4,5,6,7) -> c1 NOT IN (1,2,3,4,5,6,7) let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], true).and( in_list(col("c1"), vec![lit(4), lit(5), lit(6), lit(7)], true), ); assert_eq!( - simplify(expr.clone()), + simplify(expr), in_list( col("c1"), vec![lit(1), lit(2), lit(3), lit(4), lit(5), lit(6), lit(7)], @@ -3314,7 +3463,7 @@ mod tests { in_list(col("c1"), vec![lit(2), lit(3), lit(4), lit(5)], false), ); assert_eq!( - simplify(expr.clone()), + simplify(expr), in_list( col("c1"), vec![lit(1), lit(2), lit(3), lit(4), lit(5)], @@ -3328,7 +3477,7 @@ mod tests { vec![lit(1), lit(2), lit(3), lit(4), lit(5)], true, )); - assert_eq!(simplify(expr.clone()), lit(false)); + assert_eq!(simplify(expr), lit(false)); // 7. c1 NOT IN (1,2,3,4) AND c1 IN (1,2,3,4,5) -> c1 = 5 let expr = @@ -3337,14 +3486,14 @@ mod tests { vec![lit(1), lit(2), lit(3), lit(4), lit(5)], false, )); - assert_eq!(simplify(expr.clone()), col("c1").eq(lit(5))); + assert_eq!(simplify(expr), col("c1").eq(lit(5))); // 8. c1 IN (1,2,3,4) AND c1 NOT IN (5,6,7,8) -> c1 IN (1,2,3,4) let expr = in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false).and( in_list(col("c1"), vec![lit(5), lit(6), lit(7), lit(8)], true), ); assert_eq!( - simplify(expr.clone()), + simplify(expr), in_list(col("c1"), vec![lit(1), lit(2), lit(3), lit(4)], false) ); @@ -3362,7 +3511,7 @@ mod tests { )) .and(in_list(col("c1"), vec![lit(3), lit(6)], false)); assert_eq!( - simplify(expr.clone()), + simplify(expr), col("c1").eq(lit(3)).or(col("c1").eq(lit(6))) ); @@ -3376,7 +3525,7 @@ mod tests { )) .and(in_list(col("c1"), vec![lit(8), lit(9), lit(10)], false)), ); - assert_eq!(simplify(expr.clone()), col("c1").eq(lit(8))); + assert_eq!(simplify(expr), col("c1").eq(lit(8))); // Contains non-InList expression // c1 NOT IN (1,2,3,4) OR c1 != 5 OR c1 NOT IN (6,7,8,9) -> c1 NOT IN (1,2,3,4) OR c1 != 5 OR c1 NOT IN (6,7,8,9) @@ -3491,7 +3640,7 @@ mod tests { let expr_x = col("c3").gt(lit(3_i64)); let expr_y = (col("c4") + lit(2_u32)).lt(lit(10_u32)); let expr_z = col("c1").in_list(vec![lit("a"), lit("b")], true); - let expr = expr_x.clone().and(expr_y.clone().or(expr_z)); + let expr = expr_x.clone().and(expr_y.or(expr_z)); // All guaranteed null let guarantees = vec![ @@ -3567,7 +3716,7 @@ mod tests { col("c4"), NullableInterval::from(ScalarValue::UInt32(Some(3))), )]; - let output = simplify_with_guarantee(expr.clone(), guarantees); + let output = simplify_with_guarantee(expr, guarantees); assert_eq!(&output, &expr_x); } @@ -3588,4 +3737,310 @@ mod tests { assert_eq!(simplify(expr), expected); } + + #[test] + fn test_simplify_cycles() { + // TRUE + let expr = lit(true); + let expected = lit(true); + let (expr, num_iter) = simplify_with_cycle_count(expr); + assert_eq!(expr, expected); + assert_eq!(num_iter, 1); + + // (true != NULL) OR (5 > 10) + let expr = lit(true).not_eq(lit_bool_null()).or(lit(5).gt(lit(10))); + let expected = lit_bool_null(); + let (expr, num_iter) = simplify_with_cycle_count(expr); + assert_eq!(expr, expected); + assert_eq!(num_iter, 2); + + // NOTE: this currently does not simplify + // (((c4 - 10) + 10) *100) / 100 + let expr = (((col("c4") - lit(10)) + lit(10)) * lit(100)) / lit(100); + let expected = expr.clone(); + let (expr, num_iter) = simplify_with_cycle_count(expr); + assert_eq!(expr, expected); + assert_eq!(num_iter, 1); + + // ((c4<1 or c3<2) and c3_non_null<3) and false + let expr = col("c4") + .lt(lit(1)) + .or(col("c3").lt(lit(2))) + .and(col("c3_non_null").lt(lit(3))) + .and(lit(false)); + let expected = lit(false); + let (expr, num_iter) = simplify_with_cycle_count(expr); + assert_eq!(expr, expected); + assert_eq!(num_iter, 2); + } + + fn boolean_test_schema() -> DFSchemaRef { + Schema::new(vec![ + Field::new("A", DataType::Boolean, false), + Field::new("B", DataType::Boolean, false), + Field::new("C", DataType::Boolean, false), + Field::new("D", DataType::Boolean, false), + ]) + .to_dfschema_ref() + .unwrap() + } + + #[test] + fn simplify_common_factor_conjuction_in_disjunction() { + let props = ExecutionProps::new(); + let schema = boolean_test_schema(); + let simplifier = + ExprSimplifier::new(SimplifyContext::new(&props).with_schema(schema)); + + let a = || col("A"); + let b = || col("B"); + let c = || col("C"); + let d = || col("D"); + + // (A AND B) OR (A AND C) -> A AND (B OR C) + let expr = a().and(b()).or(a().and(c())); + let expected = a().and(b().or(c())); + + assert_eq!(expected, simplifier.simplify(expr).unwrap()); + + // (A AND B) OR (A AND C) OR (A AND D) -> A AND (B OR C OR D) + let expr = a().and(b()).or(a().and(c())).or(a().and(d())); + let expected = a().and(b().or(c()).or(d())); + assert_eq!(expected, simplifier.simplify(expr).unwrap()); + + // A OR (B AND C AND A) -> A + let expr = a().or(b().and(c().and(a()))); + let expected = a(); + assert_eq!(expected, simplifier.simplify(expr).unwrap()); + } + + #[test] + fn test_simplify_udaf() { + let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_with_simplify()); + let aggregate_function_expr = + Expr::AggregateFunction(expr::AggregateFunction::new_udf( + udaf.into(), + vec![], + false, + None, + None, + None, + )); + + let expected = col("result_column"); + assert_eq!(simplify(aggregate_function_expr), expected); + + let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_without_simplify()); + let aggregate_function_expr = + Expr::AggregateFunction(expr::AggregateFunction::new_udf( + udaf.into(), + vec![], + false, + None, + None, + None, + )); + + let expected = aggregate_function_expr.clone(); + assert_eq!(simplify(aggregate_function_expr), expected); + } + + /// A Mock UDAF which defines `simplify` to be used in tests + /// related to UDAF simplification + #[derive(Debug, Clone)] + struct SimplifyMockUdaf { + simplify: bool, + } + + impl SimplifyMockUdaf { + /// make simplify method return new expression + fn new_with_simplify() -> Self { + Self { simplify: true } + } + /// make simplify method return no change + fn new_without_simplify() -> Self { + Self { simplify: false } + } + } + + impl AggregateUDFImpl for SimplifyMockUdaf { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "mock_simplify" + } + + fn signature(&self) -> &Signature { + unimplemented!() + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + unimplemented!("not needed for tests") + } + + fn accumulator( + &self, + _acc_args: AccumulatorArgs, + ) -> Result> { + unimplemented!("not needed for tests") + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + unimplemented!("not needed for testing") + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + unimplemented!("not needed for testing") + } + + fn simplify(&self) -> Option { + if self.simplify { + Some(Box::new(|_, _| Ok(col("result_column")))) + } else { + None + } + } + } + + #[test] + fn test_simplify_udwf() { + let udwf = WindowFunctionDefinition::WindowUDF( + WindowUDF::new_from_impl(SimplifyMockUdwf::new_with_simplify()).into(), + ); + let window_function_expr = + Expr::WindowFunction(WindowFunction::new(udwf, vec![])); + + let expected = col("result_column"); + assert_eq!(simplify(window_function_expr), expected); + + let udwf = WindowFunctionDefinition::WindowUDF( + WindowUDF::new_from_impl(SimplifyMockUdwf::new_without_simplify()).into(), + ); + let window_function_expr = + Expr::WindowFunction(WindowFunction::new(udwf, vec![])); + + let expected = window_function_expr.clone(); + assert_eq!(simplify(window_function_expr), expected); + } + + /// A Mock UDWF which defines `simplify` to be used in tests + /// related to UDWF simplification + #[derive(Debug, Clone)] + struct SimplifyMockUdwf { + simplify: bool, + } + + impl SimplifyMockUdwf { + /// make simplify method return new expression + fn new_with_simplify() -> Self { + Self { simplify: true } + } + /// make simplify method return no change + fn new_without_simplify() -> Self { + Self { simplify: false } + } + } + + impl WindowUDFImpl for SimplifyMockUdwf { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "mock_simplify" + } + + fn signature(&self) -> &Signature { + unimplemented!() + } + + fn simplify(&self) -> Option { + if self.simplify { + Some(Box::new(|_, _| Ok(col("result_column")))) + } else { + None + } + } + + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + unimplemented!("not needed for tests") + } + + fn field(&self, _field_args: WindowUDFFieldArgs) -> Result { + unimplemented!("not needed for tests") + } + } + #[derive(Debug)] + struct VolatileUdf { + signature: Signature, + } + + impl VolatileUdf { + pub fn new() -> Self { + Self { + signature: Signature::exact(vec![], Volatility::Volatile), + } + } + } + impl ScalarUDFImpl for VolatileUdf { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "VolatileUdf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int16) + } + } + #[test] + fn test_optimize_volatile_conditions() { + let fun = Arc::new(ScalarUDF::new_from_impl(VolatileUdf::new())); + let rand = Expr::ScalarFunction(ScalarFunction::new_udf(fun, vec![])); + { + let expr = rand + .clone() + .eq(lit(0)) + .or(col("column1").eq(lit(2)).and(rand.clone().eq(lit(0)))); + + assert_eq!(simplify(expr.clone()), expr); + } + + { + let expr = col("column1") + .eq(lit(2)) + .or(col("column1").eq(lit(2)).and(rand.clone().eq(lit(0)))); + + assert_eq!(simplify(expr), col("column1").eq(lit(2))); + } + + { + let expr = (col("column1").eq(lit(2)).and(rand.clone().eq(lit(0)))).or(col( + "column1", + ) + .eq(lit(2)) + .and(rand.clone().eq(lit(0)))); + + assert_eq!( + simplify(expr), + col("column1") + .eq(lit(2)) + .and((rand.clone().eq(lit(0))).or(rand.clone().eq(lit(0)))) + ); + } + } } diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index 9d8e3fecebc1..afcbe528083b 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -39,7 +39,7 @@ use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr}; /// See a full example in [`ExprSimplifier::with_guarantees()`]. /// /// [`ExprSimplifier::with_guarantees()`]: crate::simplify_expressions::expr_simplifier::ExprSimplifier::with_guarantees -pub(crate) struct GuaranteeRewriter<'a> { +pub struct GuaranteeRewriter<'a> { guarantees: HashMap<&'a Expr, &'a NullableInterval>, } @@ -170,7 +170,7 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { .filter_map(|expr| { if let Expr::Literal(item) = expr { match interval - .contains(&NullableInterval::from(item.clone())) + .contains(NullableInterval::from(item.clone())) { // If we know for certain the value isn't in the column's interval, // we can skip checking it. @@ -225,12 +225,12 @@ mod tests { // x IS NULL => guaranteed false let expr = col("x").is_null(); - let output = expr.clone().rewrite(&mut rewriter).data().unwrap(); + let output = expr.rewrite(&mut rewriter).data().unwrap(); assert_eq!(output, lit(false)); // x IS NOT NULL => guaranteed true let expr = col("x").is_not_null(); - let output = expr.clone().rewrite(&mut rewriter).data().unwrap(); + let output = expr.rewrite(&mut rewriter).data().unwrap(); assert_eq!(output, lit(true)); } @@ -261,72 +261,6 @@ mod tests { } } - #[test] - fn test_inequalities_non_null_bounded() { - let guarantees = vec![ - // x ∈ [1, 3] (not null) - ( - col("x"), - NullableInterval::NotNull { - values: Interval::make(Some(1_i32), Some(3_i32)).unwrap(), - }, - ), - // s.y ∈ [1, 3] (not null) - ( - col("s").field("y"), - NullableInterval::NotNull { - values: Interval::make(Some(1_i32), Some(3_i32)).unwrap(), - }, - ), - ]; - - let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); - - // (original_expr, expected_simplification) - let simplified_cases = &[ - (col("x").lt(lit(0)), false), - (col("s").field("y").lt(lit(0)), false), - (col("x").lt_eq(lit(3)), true), - (col("x").gt(lit(3)), false), - (col("x").gt(lit(0)), true), - (col("x").eq(lit(0)), false), - (col("x").not_eq(lit(0)), true), - (col("x").between(lit(0), lit(5)), true), - (col("x").between(lit(5), lit(10)), false), - (col("x").not_between(lit(0), lit(5)), false), - (col("x").not_between(lit(5), lit(10)), true), - ( - Expr::BinaryExpr(BinaryExpr { - left: Box::new(col("x")), - op: Operator::IsDistinctFrom, - right: Box::new(lit(ScalarValue::Null)), - }), - true, - ), - ( - Expr::BinaryExpr(BinaryExpr { - left: Box::new(col("x")), - op: Operator::IsDistinctFrom, - right: Box::new(lit(5)), - }), - true, - ), - ]; - - validate_simplified_cases(&mut rewriter, simplified_cases); - - let unchanged_cases = &[ - col("x").gt(lit(2)), - col("x").lt_eq(lit(2)), - col("x").eq(lit(2)), - col("x").not_eq(lit(2)), - col("x").between(lit(3), lit(5)), - col("x").not_between(lit(3), lit(10)), - ]; - - validate_unchanged_cases(&mut rewriter, unchanged_cases); - } - #[test] fn test_inequalities_non_null_unbounded() { let guarantees = vec![ diff --git a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs index 9dcb8ed15563..c8638eb72395 100644 --- a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs @@ -52,7 +52,7 @@ impl TreeNodeRewriter for ShortenInListSimplifier { // expressions list.len() == 1 || list.len() <= THRESHOLD_INLINE_INLIST - && expr.try_into_col().is_ok() + && expr.try_as_col().is_some() ) { let first_val = list[0].clone(); diff --git a/datafusion/optimizer/src/simplify_expressions/mod.rs b/datafusion/optimizer/src/simplify_expressions/mod.rs index d0399fef07e6..46c066c11c0f 100644 --- a/datafusion/optimizer/src/simplify_expressions/mod.rs +++ b/datafusion/optimizer/src/simplify_expressions/mod.rs @@ -30,3 +30,6 @@ pub use datafusion_expr::simplify::{SimplifyContext, SimplifyInfo}; pub use expr_simplifier::*; pub use simplify_exprs::*; + +// Export for test in datafusion/core/tests/optimizer_integration.rs +pub use guarantees::GuaranteeRewriter; diff --git a/datafusion/optimizer/src/simplify_expressions/regex.rs b/datafusion/optimizer/src/simplify_expressions/regex.rs index 175b70f2b10e..6c99f18ab0f6 100644 --- a/datafusion/optimizer/src/simplify_expressions/regex.rs +++ b/datafusion/optimizer/src/simplify_expressions/regex.rs @@ -22,6 +22,19 @@ use regex_syntax::hir::{Capture, Hir, HirKind, Literal, Look}; /// Maximum number of regex alternations (`foo|bar|...`) that will be expanded into multiple `LIKE` expressions. const MAX_REGEX_ALTERNATIONS_EXPANSION: usize = 4; +/// Tries to convert a regexp expression to a `LIKE` or `Eq`/`NotEq` expression. +/// +/// This function also validates the regex pattern. And will return error if the +/// pattern is invalid. +/// +/// Typical cases this function can simplify: +/// - empty regex pattern to `LIKE '%'` +/// - literal regex patterns to `LIKE '%foo%'` +/// - full anchored regex patterns (e.g. `^foo$`) to `= 'foo'` +/// - partial anchored regex patterns (e.g. `^foo`) to `LIKE 'foo%'` +/// - combinations (alternatives) of the above, will be concatenated with `OR` or `AND` +/// +/// Dev note: unit tests of this function are in `expr_simplifier.rs`, case `test_simplify_regex`. pub fn simplify_regex_expr( left: Box, op: Operator, @@ -53,13 +66,15 @@ pub fn simplify_regex_expr( } } - // leave untouched if optimization didn't work + // Leave untouched if optimization didn't work Ok(Expr::BinaryExpr(BinaryExpr { left, op, right })) } #[derive(Debug)] struct OperatorMode { + /// Negative match. not: bool, + /// Ignore case (`true` for case-insensitive). i: bool, } @@ -80,6 +95,7 @@ impl OperatorMode { Self { not, i } } + /// Creates an [`LIKE`](Expr::Like) from the given `LIKE` pattern. fn expr(&self, expr: Box, pattern: String) -> Expr { let like = Like { negated: self.not, @@ -92,6 +108,7 @@ impl OperatorMode { Expr::Like(like) } + /// Creates an [`Expr::BinaryExpr`] of "`left` = `right`" or "`left` != `right`". fn expr_matches_literal(&self, left: Box, right: Box) -> Expr { let op = if self.not { Operator::NotEq @@ -118,7 +135,7 @@ fn collect_concat_to_like_string(parts: &[Hir]) -> Option { Some(s) } -/// returns a str represented by `Literal` if it contains a valid utf8 +/// Returns a str represented by `Literal` if it contains a valid utf8 /// sequence and is safe for like (has no '%' and '_') fn like_str_from_literal(l: &Literal) -> Option<&str> { // if not utf8, no good @@ -131,7 +148,7 @@ fn like_str_from_literal(l: &Literal) -> Option<&str> { } } -/// returns a str represented by `Literal` if it contains a valid utf8 +/// Returns a str represented by `Literal` if it contains a valid utf8 fn str_from_literal(l: &Literal) -> Option<&str> { // if not utf8, no good let s = std::str::from_utf8(&l.0).ok()?; @@ -143,7 +160,7 @@ fn is_safe_for_like(c: char) -> bool { (c != '%') && (c != '_') } -/// returns true if the elements in a `Concat` pattern are: +/// Returns true if the elements in a `Concat` pattern are: /// - `[Look::Start, Look::End]` /// - `[Look::Start, Literal(_), Look::End]` fn is_anchored_literal(v: &[Hir]) -> bool { @@ -157,10 +174,9 @@ fn is_anchored_literal(v: &[Hir]) -> bool { v.last().expect("length checked"), ); if !matches!(first_last, - (s, e) if s.kind() == &HirKind::Look(Look::Start) + (s, e) if s.kind() == &HirKind::Look(Look::Start) && e.kind() == &HirKind::Look(Look::End) - ) - { + ) { return false; } @@ -170,7 +186,7 @@ fn is_anchored_literal(v: &[Hir]) -> bool { .all(|h| matches!(h.kind(), HirKind::Literal(_))) } -/// returns true if the elements in a `Concat` pattern are: +/// Returns true if the elements in a `Concat` pattern are: /// - `[Look::Start, Capture(Alternation(Literals...)), Look::End]` fn is_anchored_capture(v: &[Hir]) -> bool { if v.len() != 3 @@ -197,7 +213,34 @@ fn is_anchored_capture(v: &[Hir]) -> bool { true } -/// extracts a string literal expression assuming that [`is_anchored_literal`] +/// Returns the `LIKE` pattern if the `Concat` pattern is partial anchored: +/// - `[Look::Start, Literal(_)]` +/// - `[Literal(_), Look::End]` +/// +/// Full anchored patterns are handled by [`anchored_literal_to_expr`]. +fn partial_anchored_literal_to_like(v: &[Hir]) -> Option { + if v.len() != 2 { + return None; + } + + let (lit, match_begin) = match (&v[0].kind(), &v[1].kind()) { + (HirKind::Look(Look::Start), HirKind::Literal(l)) => { + (like_str_from_literal(l)?, true) + } + (HirKind::Literal(l), HirKind::Look(Look::End)) => { + (like_str_from_literal(l)?, false) + } + _ => return None, + }; + + if match_begin { + Some(format!("{}%", lit)) + } else { + Some(format!("%{}", lit)) + } +} + +/// Extracts a string literal expression assuming that [`is_anchored_literal`] /// returned true. fn anchored_literal_to_expr(v: &[Hir]) -> Option { match v.len() { @@ -246,6 +289,7 @@ fn anchored_alternation_to_exprs(v: &[Hir]) -> Option> { None } +/// Tries to lower (transform) a simple regex pattern to a LIKE expression. fn lower_simple(mode: &OperatorMode, left: &Expr, hir: &Hir) -> Option { match hir.kind() { HirKind::Empty => { @@ -265,7 +309,9 @@ fn lower_simple(mode: &OperatorMode, left: &Expr, hir: &Hir) -> Option { .map(|right| left.clone().in_list(right, mode.not)); } HirKind::Concat(inner) => { - if let Some(pattern) = collect_concat_to_like_string(inner) { + if let Some(pattern) = partial_anchored_literal_to_like(inner) + .or(collect_concat_to_like_string(inner)) + { return Some(mode.expr(Box::new(left.clone()), pattern)); } } @@ -274,6 +320,9 @@ fn lower_simple(mode: &OperatorMode, left: &Expr, hir: &Hir) -> Option { None } +/// Calls [`lower_simple`] for each alternative and combine the results with `or` or `and` +/// based on [`OperatorMode`]. Any fail attempt to lower an alternative will makes this +/// function to return `None`. fn lower_alt(mode: &OperatorMode, left: &Expr, alts: &[Hir]) -> Option { let mut accu: Option = None; diff --git a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs index d15d12b690da..200f1f159d81 100644 --- a/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs +++ b/datafusion/optimizer/src/simplify_expressions/simplify_exprs.rs @@ -20,13 +20,14 @@ use std::sync::Arc; use datafusion_common::tree_node::Transformed; -use datafusion_common::{internal_err, DFSchema, DFSchemaRef, DataFusionError, Result}; +use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result}; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::simplify::SimplifyContext; use datafusion_expr::utils::merge_schema; use crate::optimizer::ApplyOrder; +use crate::utils::NamePreserver; use crate::{OptimizerConfig, OptimizerRule}; use super::ExprSimplifier; @@ -44,18 +45,10 @@ use super::ExprSimplifier; /// `Filter: b > 2` /// /// [`Expr`]: datafusion_expr::Expr -#[derive(Default)] +#[derive(Default, Debug)] pub struct SimplifyExpressions {} impl OptimizerRule for SimplifyExpressions { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called SimplifyExpressions::rewrite") - } - fn name(&self) -> &str { "simplify_expressions" } @@ -87,7 +80,7 @@ impl SimplifyExpressions { execution_props: &ExecutionProps, ) -> Result> { let schema = if !plan.inputs().is_empty() { - DFSchemaRef::new(merge_schema(plan.inputs())) + DFSchemaRef::new(merge_schema(&plan.inputs())) } else if let LogicalPlan::TableScan(scan) = &plan { // When predicates are pushed into a table scan, there is no input // schema to resolve predicates against, so it must be handled specially @@ -127,18 +120,13 @@ impl SimplifyExpressions { simplifier }; - // the output schema of a filter or join is the input schema. Thus they - // can't handle aliased expressions - let use_alias = !matches!(plan, LogicalPlan::Filter(_) | LogicalPlan::Join(_)); + // Preserve expression names to avoid changing the schema of the plan. + let name_preserver = NamePreserver::new(&plan); plan.map_expressions(|e| { - let new_e = if use_alias { - // TODO: unify with `rewrite_preserving_name` - let original_name = e.name_for_alias()?; - simplifier.simplify(e)?.alias_if_changed(original_name) - } else { - simplifier.simplify(e) - }?; - + let original_name = name_preserver.save(&e); + let new_e = simplifier + .simplify(e) + .map(|expr| original_name.restore(expr))?; // TODO it would be nice to have a way to know if the expression was simplified // or not. For now conservatively return Transformed::yes Ok(Transformed::yes(new_e)) @@ -168,6 +156,7 @@ mod tests { ExprSchemable, JoinType, }; use datafusion_expr::{or, BinaryExpr, Cast, Operator}; + use datafusion_functions_aggregate::expr_fn::{max, min}; use crate::test::{assert_fields_eq, test_table_scan_with_name}; use crate::OptimizerContext; @@ -194,7 +183,7 @@ mod tests { let optimizer = Optimizer::with_rules(vec![Arc::new(SimplifyExpressions::new())]); let optimized_plan = optimizer.optimize(plan, &OptimizerContext::new(), observe)?; - let formatted_plan = format!("{optimized_plan:?}"); + let formatted_plan = format!("{optimized_plan}"); assert_eq!(formatted_plan, expected); Ok(()) } @@ -219,7 +208,7 @@ mod tests { assert_eq!(1, table_scan.schema().fields().len()); assert_fields_eq(&table_scan, vec!["a"]); - let expected = "TableScan: test projection=[a], full_filters=[Boolean(true) AS b IS NOT NULL]"; + let expected = "TableScan: test projection=[a], full_filters=[Boolean(true)]"; assert_optimized_plan_eq(table_scan, expected) } @@ -403,15 +392,12 @@ mod tests { .project(vec![col("a"), col("c"), col("b")])? .aggregate( vec![col("a"), col("c")], - vec![ - datafusion_expr::max(col("b").eq(lit(true))), - datafusion_expr::min(col("b")), - ], + vec![max(col("b").eq(lit(true))), min(col("b"))], )? .build()?; let expected = "\ - Aggregate: groupBy=[[test.a, test.c]], aggr=[[MAX(test.b) AS MAX(test.b = Boolean(true)), MIN(test.b)]]\ + Aggregate: groupBy=[[test.a, test.c]], aggr=[[max(test.b) AS max(test.b = Boolean(true)), min(test.b)]]\ \n Projection: test.a, test.c, test.b\ \n TableScan: test"; @@ -447,7 +433,7 @@ mod tests { let rule = SimplifyExpressions::new(); let optimized_plan = rule.rewrite(plan, &config).unwrap().data; - format!("{optimized_plan:?}") + format!("{optimized_plan}") } #[test] diff --git a/datafusion/optimizer/src/simplify_expressions/utils.rs b/datafusion/optimizer/src/simplify_expressions/utils.rs index 5da727cb5990..c30c3631c193 100644 --- a/datafusion/optimizer/src/simplify_expressions/utils.rs +++ b/datafusion/optimizer/src/simplify_expressions/utils.rs @@ -67,16 +67,21 @@ pub static POWS_OF_TEN: [i128; 38] = [ /// returns true if `needle` is found in a chain of search_op /// expressions. Such as: (A AND B) AND C -pub fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool { +fn expr_contains_inner(expr: &Expr, needle: &Expr, search_op: Operator) -> bool { match expr { Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == search_op => { - expr_contains(left, needle, search_op) - || expr_contains(right, needle, search_op) + expr_contains_inner(left, needle, search_op) + || expr_contains_inner(right, needle, search_op) } _ => expr == needle, } } +/// check volatile calls and return if expr contains needle +pub fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool { + expr_contains_inner(expr, needle, search_op) && !needle.is_volatile() +} + /// Deletes all 'needles' or remains one 'needle' that are found in a chain of xor /// expressions. Such as: A ^ (A ^ (B ^ A)) pub fn delete_xor_in_complex_expr(expr: &Expr, needle: &Expr, is_left: bool) -> Expr { @@ -206,7 +211,7 @@ pub fn is_false(expr: &Expr) -> bool { /// returns true if `haystack` looks like (needle OP X) or (X OP needle) pub fn is_op_with(target_op: Operator, haystack: &Expr, needle: &Expr) -> bool { - matches!(haystack, Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == &target_op && (needle == left.as_ref() || needle == right.as_ref())) + matches!(haystack, Expr::BinaryExpr(BinaryExpr { left, op, right }) if op == &target_op && (needle == left.as_ref() || needle == right.as_ref()) && !needle.is_volatile()) } /// returns true if `not_expr` is !`expr` (not) @@ -221,9 +226,9 @@ pub fn is_negative_of(not_expr: &Expr, expr: &Expr) -> bool { /// returns the contained boolean value in `expr` as /// `Expr::Literal(ScalarValue::Boolean(v))`. -pub fn as_bool_lit(expr: Expr) -> Result> { +pub fn as_bool_lit(expr: &Expr) -> Result> { match expr { - Expr::Literal(ScalarValue::Boolean(v)) => Ok(v), + Expr::Literal(ScalarValue::Boolean(v)) => Ok(*v), _ => internal_err!("Expected boolean literal, got {expr:?}"), } } diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index aaf4667fb000..01875349c922 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -22,15 +22,13 @@ use std::sync::Arc; use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{qualified_name, DFSchema, Result}; -use datafusion_expr::expr::AggregateFunctionDefinition; +use datafusion_common::{internal_err, tree_node::Transformed, DataFusionError, Result}; +use datafusion_expr::builder::project; use datafusion_expr::{ - aggregate_function::AggregateFunction::{Max, Min, Sum}, col, expr::AggregateFunction, - logical_plan::{Aggregate, LogicalPlan, Projection}, - utils::columnize_expr, - Expr, ExprSchemable, + logical_plan::{Aggregate, LogicalPlan}, + Expr, }; use hashbrown::HashSet; @@ -38,20 +36,20 @@ use hashbrown::HashSet; /// single distinct to group by optimizer rule /// ```text /// Before: -/// SELECT a, COUNT(DINSTINCT b), SUM(c) +/// SELECT a, count(DISTINCT b), sum(c) /// FROM t /// GROUP BY a /// /// After: -/// SELECT a, COUNT(alias1), SUM(alias2) +/// SELECT a, count(alias1), sum(alias2) /// FROM ( -/// SELECT a, b as alias1, SUM(c) as alias2 +/// SELECT a, b as alias1, sum(c) as alias2 /// FROM t /// GROUP BY a, b /// ) /// GROUP BY a /// ``` -#[derive(Default)] +#[derive(Default, Debug)] pub struct SingleDistinctToGroupBy {} const SINGLE_DISTINCT_ALIAS: &str = "alias1"; @@ -64,38 +62,38 @@ impl SingleDistinctToGroupBy { } /// Check whether all aggregate exprs are distinct on a single field. -fn is_single_distinct_agg(plan: &LogicalPlan) -> Result { - match plan { - LogicalPlan::Aggregate(Aggregate { aggr_expr, .. }) => { - let mut fields_set = HashSet::new(); - let mut aggregate_count = 0; - for expr in aggr_expr { - if let Expr::AggregateFunction(AggregateFunction { - func_def: AggregateFunctionDefinition::BuiltIn(fun), - distinct, - args, - filter, - order_by, - null_treatment: _, - }) = expr - { - if filter.is_some() || order_by.is_some() { - return Ok(false); - } - aggregate_count += 1; - if *distinct { - for e in args { - fields_set.insert(e.canonical_name()); - } - } else if !matches!(fun, Sum | Min | Max) { - return Ok(false); - } +fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { + let mut fields_set = HashSet::new(); + let mut aggregate_count = 0; + for expr in aggr_expr { + if let Expr::AggregateFunction(AggregateFunction { + func, + distinct, + args, + filter, + order_by, + null_treatment: _, + }) = expr + { + if filter.is_some() || order_by.is_some() { + return Ok(false); + } + aggregate_count += 1; + if *distinct { + for e in args { + fields_set.insert(e); } + } else if func.name() != "sum" + && func.name().to_lowercase() != "min" + && func.name().to_lowercase() != "max" + { + return Ok(false); } - Ok(aggregate_count == aggr_expr.len() && fields_set.len() == 1) + } else { + return Ok(false); } - _ => Ok(false), } + Ok(aggregate_count == aggr_expr.len() && fields_set.len() == 1) } /// Check if the first expr is [Expr::GroupingSet]. @@ -104,11 +102,23 @@ fn contains_grouping_set(expr: &[Expr]) -> bool { } impl OptimizerRule for SingleDistinctToGroupBy { - fn try_optimize( + fn name(&self) -> &str { + "single_distinct_aggregation_to_group_by" + } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } + + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( &self, - plan: &LogicalPlan, + plan: LogicalPlan, _config: &dyn OptimizerConfig, - ) -> Result> { + ) -> Result, DataFusionError> { match plan { LogicalPlan::Aggregate(Aggregate { input, @@ -116,202 +126,177 @@ impl OptimizerRule for SingleDistinctToGroupBy { schema, group_expr, .. - }) => { - if is_single_distinct_agg(plan)? && !contains_grouping_set(group_expr) { - // alias all original group_by exprs - let (mut inner_group_exprs, out_group_expr_with_alias): ( - Vec, - Vec<(Expr, Option)>, - ) = group_expr - .iter() - .enumerate() - .map(|(i, group_expr)| { - if let Expr::Column(_) = group_expr { - // For Column expressions we can use existing expression as is. - (group_expr.clone(), (group_expr.clone(), None)) - } else { - // For complex expression write is as alias, to be able to refer - // if from parent operators successfully. - // Consider plan below. - // - // Aggregate: groupBy=[[group_alias_0]], aggr=[[COUNT(alias1)]] [group_alias_0:Int32, COUNT(alias1):Int64;N]\ - // --Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ - // ----TableScan: test [a:UInt32, b:UInt32, c:UInt32] - // - // First aggregate(from bottom) refers to `test.a` column. - // Second aggregate refers to the `group_alias_0` column, Which is a valid field in the first aggregate. - // If we were to write plan above as below without alias - // - // Aggregate: groupBy=[[test.a + Int32(1)]], aggr=[[COUNT(alias1)]] [group_alias_0:Int32, COUNT(alias1):Int64;N]\ - // --Aggregate: groupBy=[[test.a + Int32(1), test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ - // ----TableScan: test [a:UInt32, b:UInt32, c:UInt32] - // - // Second aggregate refers to the `test.a + Int32(1)` expression However, its input do not have `test.a` expression in it. - let alias_str = format!("group_alias_{i}"); - let alias_expr = group_expr.clone().alias(&alias_str); - let (qualifier, field) = schema.qualified_field(i); - ( - alias_expr, - ( - col(alias_str), - Some(qualified_name(qualifier, field.name())), - ), - ) - } - }) - .unzip(); - - // and they can be referenced by the alias in the outer aggr plan - let outer_group_exprs = out_group_expr_with_alias - .iter() - .map(|(out_group_expr, _)| out_group_expr.clone()) - .collect::>(); - - // replace the distinct arg with alias - let mut index = 1; - let mut group_fields_set = HashSet::new(); - let mut inner_aggr_exprs = vec![]; - let outer_aggr_exprs = aggr_expr - .iter() - .map(|aggr_expr| match aggr_expr { - Expr::AggregateFunction(AggregateFunction { - func_def: AggregateFunctionDefinition::BuiltIn(fun), - args, - distinct, - .. - }) => { - // is_single_distinct_agg ensure args.len=1 - if *distinct - && group_fields_set.insert(args[0].display_name()?) - { - inner_group_exprs.push( - args[0].clone().alias(SINGLE_DISTINCT_ALIAS), - ); + }) if is_single_distinct_agg(&aggr_expr)? + && !contains_grouping_set(&group_expr) => + { + let group_size = group_expr.len(); + // alias all original group_by exprs + let (mut inner_group_exprs, out_group_expr_with_alias): ( + Vec, + Vec<(Expr, _)>, + ) = group_expr + .into_iter() + .enumerate() + .map(|(i, group_expr)| { + if let Expr::Column(_) = group_expr { + // For Column expressions we can use existing expression as is. + (group_expr.clone(), (group_expr, None)) + } else { + // For complex expression write is as alias, to be able to refer + // if from parent operators successfully. + // Consider plan below. + // + // Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int32, count(alias1):Int64;N]\ + // --Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ + // ----TableScan: test [a:UInt32, b:UInt32, c:UInt32] + // + // First aggregate(from bottom) refers to `test.a` column. + // Second aggregate refers to the `group_alias_0` column, Which is a valid field in the first aggregate. + + // If we were to write plan above as below without alias + // + // Aggregate: groupBy=[[test.a + Int32(1)]], aggr=[[count(alias1)]] [group_alias_0:Int32, count(alias1):Int64;N]\ + // --Aggregate: groupBy=[[test.a + Int32(1), test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ + // ----TableScan: test [a:UInt32, b:UInt32, c:UInt32] + // + // Second aggregate refers to the `test.a + Int32(1)` expression However, its input do not have `test.a` expression in it. + let alias_str = format!("group_alias_{i}"); + let (qualifier, field) = schema.qualified_field(i); + ( + group_expr.alias(alias_str.clone()), + (col(alias_str), Some((qualifier, field.name()))), + ) + } + }) + .unzip(); + + // replace the distinct arg with alias + let mut index = 1; + let mut group_fields_set = HashSet::new(); + let mut inner_aggr_exprs = vec![]; + let outer_aggr_exprs = aggr_expr + .into_iter() + .map(|aggr_expr| match aggr_expr { + Expr::AggregateFunction(AggregateFunction { + func, + mut args, + distinct, + .. + }) => { + if distinct { + if args.len() != 1 { + return internal_err!("DISTINCT aggregate should have exactly one argument"); } + let arg = args.swap_remove(0); + if group_fields_set.insert(arg.schema_name().to_string()) { + inner_group_exprs + .push(arg.alias(SINGLE_DISTINCT_ALIAS)); + } + Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + func, + vec![col(SINGLE_DISTINCT_ALIAS)], + false, // intentional to remove distinct here + None, + None, + None, + ))) // if the aggregate function is not distinct, we need to rewrite it like two phase aggregation - if !(*distinct) { - index += 1; - let alias_str = format!("alias{}", index); - inner_aggr_exprs.push( - Expr::AggregateFunction(AggregateFunction::new( - fun.clone(), - args.clone(), - false, - None, - None, - None, - )) - .alias(&alias_str), - ); - Ok(Expr::AggregateFunction(AggregateFunction::new( - fun.clone(), - vec![col(&alias_str)], + } else { + index += 1; + let alias_str = format!("alias{}", index); + inner_aggr_exprs.push( + Expr::AggregateFunction(AggregateFunction::new_udf( + Arc::clone(&func), + args, false, None, None, None, - ))) - } else { - Ok(Expr::AggregateFunction(AggregateFunction::new( - fun.clone(), - vec![col(SINGLE_DISTINCT_ALIAS)], - false, // intentional to remove distinct here - None, - None, - None, - ))) - } - } - _ => Ok(aggr_expr.clone()), - }) - .collect::>>()?; - - // construct the inner AggrPlan - let inner_fields = inner_group_exprs - .iter() - .chain(inner_aggr_exprs.iter()) - .map(|expr| expr.to_field(input.schema())) - .collect::>>()?; - let inner_schema = DFSchema::new_with_metadata( - inner_fields, - input.schema().metadata().clone(), - )?; - let inner_agg = LogicalPlan::Aggregate(Aggregate::try_new( - input.clone(), - inner_group_exprs, - inner_aggr_exprs, - )?); - - let outer_fields = outer_group_exprs - .iter() - .chain(outer_aggr_exprs.iter()) - .map(|expr| expr.to_field(&inner_schema)) - .collect::>>()?; - let outer_aggr_schema = Arc::new(DFSchema::new_with_metadata( - outer_fields, - input.schema().metadata().clone(), - )?); - - // so the aggregates are displayed in the same way even after the rewrite - // this optimizer has two kinds of alias: - // - group_by aggr - // - aggr expr - let group_size = group_expr.len(); - let alias_expr = out_group_expr_with_alias - .into_iter() - .map(|(group_expr, original_field)| { - if let Some(name) = original_field { - group_expr.alias(name) - } else { - group_expr + )) + .alias(&alias_str), + ); + Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + func, + vec![col(&alias_str)], + false, + None, + None, + None, + ))) } - }) - .chain(outer_aggr_exprs.iter().enumerate().map(|(idx, expr)| { + } + _ => Ok(aggr_expr), + }) + .collect::>>()?; + + // construct the inner AggrPlan + let inner_agg = LogicalPlan::Aggregate(Aggregate::try_new( + input, + inner_group_exprs, + inner_aggr_exprs, + )?); + + let outer_group_exprs = out_group_expr_with_alias + .iter() + .map(|(expr, _)| expr.clone()) + .collect(); + + // so the aggregates are displayed in the same way even after the rewrite + // this optimizer has two kinds of alias: + // - group_by aggr + // - aggr expr + let alias_expr: Vec<_> = out_group_expr_with_alias + .into_iter() + .map(|(group_expr, original_name)| match original_name { + Some((qualifier, name)) => { + group_expr.alias_qualified(qualifier.cloned(), name) + } + None => group_expr, + }) + .chain(outer_aggr_exprs.iter().cloned().enumerate().map( + |(idx, expr)| { let idx = idx + group_size; let (qualifier, field) = schema.qualified_field(idx); - let name = qualified_name(qualifier, field.name()); - columnize_expr(expr.clone().alias(name), &outer_aggr_schema) - })) - .collect(); - - let outer_aggr = LogicalPlan::Aggregate(Aggregate::try_new( - Arc::new(inner_agg), - outer_group_exprs, - outer_aggr_exprs, - )?); - - Ok(Some(LogicalPlan::Projection(Projection::try_new( - alias_expr, - Arc::new(outer_aggr), - )?))) - } else { - Ok(None) - } + expr.alias_qualified(qualifier.cloned(), field.name()) + }, + )) + .collect(); + + let outer_aggr = LogicalPlan::Aggregate(Aggregate::try_new( + Arc::new(inner_agg), + outer_group_exprs, + outer_aggr_exprs, + )?); + Ok(Transformed::yes(project(outer_aggr, alias_expr)?)) } - _ => Ok(None), + _ => Ok(Transformed::no(plan)), } } - - fn name(&self) -> &str { - "single_distinct_aggregation_to_group_by" - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::TopDown) - } } #[cfg(test)] mod tests { use super::*; use crate::test::*; - use datafusion_expr::expr; use datafusion_expr::expr::GroupingSet; - use datafusion_expr::{ - count, count_distinct, lit, logical_plan::builder::LogicalPlanBuilder, max, min, - sum, AggregateFunction, - }; + use datafusion_expr::ExprFunctionExt; + use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_functions_aggregate::expr_fn::{count, count_distinct, max, min, sum}; + use datafusion_functions_aggregate::min_max::max_udaf; + use datafusion_functions_aggregate::sum::sum_udaf; + + fn max_distinct(expr: Expr) -> Expr { + Expr::AggregateFunction(AggregateFunction::new_udf( + max_udaf(), + vec![expr], + true, + None, + None, + None, + )) + } fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( @@ -332,7 +317,7 @@ mod tests { // Do nothing let expected = - "Aggregate: groupBy=[[]], aggr=[[MAX(test.b)]] [MAX(test.b):UInt32;N]\ + "Aggregate: groupBy=[[]], aggr=[[max(test.b)]] [max(test.b):UInt32;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -347,8 +332,8 @@ mod tests { .build()?; // Should work - let expected = "Projection: COUNT(alias1) AS COUNT(DISTINCT test.b) [COUNT(DISTINCT test.b):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(alias1)]] [COUNT(alias1):Int64;N]\ + let expected = "Projection: count(alias1) AS count(DISTINCT test.b) [count(DISTINCT test.b):Int64]\ + \n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]\ \n Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; @@ -370,7 +355,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -388,7 +373,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -407,7 +392,7 @@ mod tests { .build()?; // Should not be optimized - let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -421,8 +406,8 @@ mod tests { .aggregate(Vec::::new(), vec![count_distinct(lit(2) * col("b"))])? .build()?; - let expected = "Projection: COUNT(alias1) AS COUNT(DISTINCT Int32(2) * test.b) [COUNT(DISTINCT Int32(2) * test.b):Int64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(alias1)]] [COUNT(alias1):Int64;N]\ + let expected = "Projection: count(alias1) AS count(DISTINCT Int32(2) * test.b) [count(DISTINCT Int32(2) * test.b):Int64]\ + \n Aggregate: groupBy=[[]], aggr=[[count(alias1)]] [count(alias1):Int64]\ \n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; @@ -438,8 +423,8 @@ mod tests { .build()?; // Should work - let expected = "Projection: test.a, COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[COUNT(alias1)]] [a:UInt32, COUNT(alias1):Int64;N]\ + let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[count(alias1)]] [a:UInt32, count(alias1):Int64]\ \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; @@ -458,7 +443,7 @@ mod tests { .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(DISTINCT test.b), COUNT(DISTINCT test.c)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, COUNT(DISTINCT test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(DISTINCT test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(DISTINCT test.c):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -471,22 +456,12 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan) .aggregate( vec![col("a")], - vec![ - count_distinct(col("b")), - Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Max, - vec![col("b")], - true, - None, - None, - None, - )), - ], + vec![count_distinct(col("b")), max_distinct(col("b"))], )? .build()?; // Should work - let expected = "Projection: test.a, COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[COUNT(alias1), MAX(alias1)]] [a:UInt32, COUNT(alias1):Int64;N, MAX(alias1):UInt32;N]\ + let expected = "Projection: test.a, count(alias1) AS count(DISTINCT test.b), max(alias1) AS max(DISTINCT test.b) [a:UInt32, count(DISTINCT test.b):Int64, max(DISTINCT test.b):UInt32;N]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[count(alias1), max(alias1)]] [a:UInt32, count(alias1):Int64, max(alias1):UInt32;N]\ \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; @@ -505,7 +480,7 @@ mod tests { .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(DISTINCT test.b), COUNT(test.c)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, COUNT(test.c):Int64;N]\ + let expected = "Aggregate: groupBy=[[test.a]], aggr=[[count(DISTINCT test.b), count(test.c)]] [a:UInt32, count(DISTINCT test.b):Int64, count(test.c):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -520,8 +495,8 @@ mod tests { .build()?; // Should work - let expected = "Projection: group_alias_0 AS test.a + Int32(1), COUNT(alias1) AS COUNT(DISTINCT test.c) [test.a + Int32(1):Int32, COUNT(DISTINCT test.c):Int64;N]\ - \n Aggregate: groupBy=[[group_alias_0]], aggr=[[COUNT(alias1)]] [group_alias_0:Int32, COUNT(alias1):Int64;N]\ + let expected = "Projection: group_alias_0 AS test.a + Int32(1), count(alias1) AS count(DISTINCT test.c) [test.a + Int32(1):Int32, count(DISTINCT test.c):Int64]\ + \n Aggregate: groupBy=[[group_alias_0]], aggr=[[count(alias1)]] [group_alias_0:Int32, count(alias1):Int64]\ \n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; @@ -538,28 +513,21 @@ mod tests { vec![ sum(col("c")), count_distinct(col("b")), - Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Max, - vec![col("b")], - true, - None, - None, - None, - )), + max_distinct(col("b")), ], )? .build()?; // Should work - let expected = "Projection: test.a, SUM(alias2) AS SUM(test.c), COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, SUM(test.c):UInt64;N, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(alias2), COUNT(alias1), MAX(alias1)]] [a:UInt32, SUM(alias2):UInt64;N, COUNT(alias1):Int64;N, MAX(alias1):UInt32;N]\ - \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]\ + let expected = "Projection: test.a, sum(alias2) AS sum(test.c), count(alias1) AS count(DISTINCT test.b), max(alias1) AS max(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, count(DISTINCT test.b):Int64, max(DISTINCT test.b):UInt32;N]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), count(alias1), max(alias1)]] [a:UInt32, sum(alias2):UInt64;N, count(alias1):Int64, max(alias1):UInt32;N]\ + \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) } #[test] - fn one_distinctand_and_two_common() -> Result<()> { + fn one_distinct_and_two_common() -> Result<()> { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) @@ -569,9 +537,9 @@ mod tests { )? .build()?; // Should work - let expected = "Projection: test.a, SUM(alias2) AS SUM(test.c), MAX(alias3) AS MAX(test.c), COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, SUM(test.c):UInt64;N, MAX(test.c):UInt32;N, COUNT(DISTINCT test.b):Int64;N]\ - \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(alias2), MAX(alias3), COUNT(alias1)]] [a:UInt32, SUM(alias2):UInt64;N, MAX(alias3):UInt32;N, COUNT(alias1):Int64;N]\ - \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2, MAX(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\ + let expected = "Projection: test.a, sum(alias2) AS sum(test.c), max(alias3) AS max(test.c), count(alias1) AS count(DISTINCT test.b) [a:UInt32, sum(test.c):UInt64;N, max(test.c):UInt32;N, count(DISTINCT test.b):Int64]\ + \n Aggregate: groupBy=[[test.a]], aggr=[[sum(alias2), max(alias3), count(alias1)]] [a:UInt32, sum(alias2):UInt64;N, max(alias3):UInt32;N, count(alias1):Int64]\ + \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[sum(test.c) AS alias2, max(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -588,9 +556,9 @@ mod tests { )? .build()?; // Should work - let expected = "Projection: test.c, MIN(alias2) AS MIN(test.a), COUNT(alias1) AS COUNT(DISTINCT test.b) [c:UInt32, MIN(test.a):UInt32;N, COUNT(DISTINCT test.b):Int64;N]\ - \n Aggregate: groupBy=[[test.c]], aggr=[[MIN(alias2), COUNT(alias1)]] [c:UInt32, MIN(alias2):UInt32;N, COUNT(alias1):Int64;N]\ - \n Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[MIN(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N]\ + let expected = "Projection: test.c, min(alias2) AS min(test.a), count(alias1) AS count(DISTINCT test.b) [c:UInt32, min(test.a):UInt32;N, count(DISTINCT test.b):Int64]\ + \n Aggregate: groupBy=[[test.c]], aggr=[[min(alias2), count(alias1)]] [c:UInt32, min(alias2):UInt32;N, count(alias1):Int64]\ + \n Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[min(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -600,9 +568,9 @@ mod tests { fn common_with_filter() -> Result<()> { let table_scan = test_table_scan()?; - // SUM(a) FILTER (WHERE a > 5) - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Sum, + // sum(a) FILTER (WHERE a > 5) + let expr = Expr::AggregateFunction(AggregateFunction::new_udf( + sum_udaf(), vec![col("a")], false, Some(Box::new(col("a").gt(lit(5)))), @@ -613,7 +581,7 @@ mod tests { .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])? .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a) FILTER (WHERE test.a > Int32(5)), COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, COUNT(DISTINCT test.b):Int64;N]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) FILTER (WHERE test.a > Int32(5)), count(DISTINCT test.b)]] [c:UInt32, sum(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, count(DISTINCT test.b):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -623,20 +591,17 @@ mod tests { fn distinct_with_filter() -> Result<()> { let table_scan = test_table_scan()?; - // COUNT(DISTINCT a) FILTER (WHERE a > 5) - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("a")], - true, - Some(Box::new(col("a").gt(lit(5)))), - None, - None, - )); + // count(DISTINCT a) FILTER (WHERE a > 5) + let expr = count_udaf() + .call(vec![col("a")]) + .distinct() + .filter(col("a").gt(lit(5))) + .build()?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64;N]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -647,19 +612,19 @@ mod tests { let table_scan = test_table_scan()?; // SUM(a ORDER BY a) - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Sum, + let expr = Expr::AggregateFunction(AggregateFunction::new_udf( + sum_udaf(), vec![col("a")], false, None, - Some(vec![col("a")]), + Some(vec![col("a").sort(true, false)]), None, )); let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])? .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a) ORDER BY [test.a], COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) ORDER BY [test.a]:UInt64;N, COUNT(DISTINCT test.b):Int64;N]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a) ORDER BY [test.a ASC NULLS LAST], count(DISTINCT test.b)]] [c:UInt32, sum(test.a) ORDER BY [test.a ASC NULLS LAST]:UInt64;N, count(DISTINCT test.b):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -669,20 +634,17 @@ mod tests { fn distinct_with_order_by() -> Result<()> { let table_scan = test_table_scan()?; - // COUNT(DISTINCT a ORDER BY a) - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("a")], - true, - None, - Some(vec![col("a")]), - None, - )); + // count(DISTINCT a ORDER BY a) + let expr = count_udaf() + .call(vec![col("a")]) + .distinct() + .order_by(vec![col("a").sort(true, false)]) + .build()?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) ORDER BY [test.a]:Int64;N]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) ORDER BY [test.a ASC NULLS LAST]:Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) @@ -692,20 +654,18 @@ mod tests { fn aggregate_with_filter_and_order_by() -> Result<()> { let table_scan = test_table_scan()?; - // COUNT(DISTINCT a ORDER BY a) FILTER (WHERE a > 5) - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("a")], - true, - Some(Box::new(col("a").gt(lit(5)))), - Some(vec![col("a")]), - None, - )); + // count(DISTINCT a ORDER BY a) FILTER (WHERE a > 5) + let expr = count_udaf() + .call(vec![col("a")]) + .distinct() + .filter(col("a").gt(lit(5))) + .order_by(vec![col("a").sort(true, false)]) + .build()?; let plan = LogicalPlanBuilder::from(table_scan) .aggregate(vec![col("c")], vec![sum(col("a")), expr])? .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]:Int64;N]\ + let expected = "Aggregate: groupBy=[[test.c]], aggr=[[sum(test.a), count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]]] [c:UInt32, sum(test.a):UInt64;N, count(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a ASC NULLS LAST]:Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_equal(plan, expected) diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index cafda8359aa3..94d07a0791b3 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -110,20 +110,32 @@ pub fn get_tpch_table_schema(table: &str) -> Schema { pub fn assert_analyzed_plan_eq( rule: Arc, - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) -> Result<()> { let options = ConfigOptions::default(); + assert_analyzed_plan_with_config_eq(options, rule, plan, expected)?; + + Ok(()) +} + +pub fn assert_analyzed_plan_with_config_eq( + options: ConfigOptions, + rule: Arc, + plan: LogicalPlan, + expected: &str, +) -> Result<()> { let analyzed_plan = Analyzer::with_rules(vec![rule]).execute_and_check(plan, &options, |_, _| {})?; - let formatted_plan = format!("{analyzed_plan:?}"); + let formatted_plan = format!("{analyzed_plan}"); assert_eq!(formatted_plan, expected); Ok(()) } + pub fn assert_analyzed_plan_eq_display_indent( rule: Arc, - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) -> Result<()> { let options = ConfigOptions::default(); @@ -137,7 +149,7 @@ pub fn assert_analyzed_plan_eq_display_indent( pub fn assert_analyzer_check_err( rules: Vec>, - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) { let options = ConfigOptions::default(); @@ -161,29 +173,41 @@ pub fn assert_optimized_plan_eq( // Apply the rule once let opt_context = OptimizerContext::new().with_max_passes(1); - let optimizer = Optimizer::with_rules(vec![rule.clone()]); + let optimizer = Optimizer::with_rules(vec![Arc::clone(&rule)]); let optimized_plan = optimizer.optimize(plan, &opt_context, observe)?; - let formatted_plan = format!("{optimized_plan:?}"); + let formatted_plan = format!("{optimized_plan}"); assert_eq!(formatted_plan, expected); Ok(()) } -pub fn assert_optimized_plan_eq_with_rules( +fn generate_optimized_plan_with_rules( rules: Vec>, plan: LogicalPlan, - expected: &str, -) -> Result<()> { +) -> LogicalPlan { fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} let config = &mut OptimizerContext::new() .with_max_passes(1) .with_skip_failing_rules(false); let optimizer = Optimizer::with_rules(rules); - let optimized_plan = optimizer + optimizer .optimize(plan, config, observe) - .expect("failed to optimize plan"); - let formatted_plan = format!("{optimized_plan:?}"); - assert_eq!(formatted_plan, expected); + .expect("failed to optimize plan") +} + +pub fn assert_optimized_plan_with_rules( + rules: Vec>, + plan: LogicalPlan, + expected: &str, + eq: bool, +) -> Result<()> { + let optimized_plan = generate_optimized_plan_with_rules(rules, plan); + let formatted_plan = format!("{optimized_plan}"); + if eq { + assert_eq!(formatted_plan, expected); + } else { + assert_ne!(formatted_plan, expected); + } Ok(()) } diff --git a/datafusion/optimizer/src/test/user_defined.rs b/datafusion/optimizer/src/test/user_defined.rs index c60342fa002e..a39f90b5da5d 100644 --- a/datafusion/optimizer/src/test/user_defined.rs +++ b/datafusion/optimizer/src/test/user_defined.rs @@ -33,7 +33,7 @@ pub fn new(input: LogicalPlan) -> LogicalPlan { LogicalPlan::Extension(Extension { node }) } -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, PartialOrd, Hash)] struct TestUserDefinedPlanNode { input: LogicalPlan, } @@ -65,11 +65,19 @@ impl UserDefinedLogicalNodeCore for TestUserDefinedPlanNode { write!(f, "TestUserDefined") } - fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self { + fn with_exprs_and_inputs( + &self, + exprs: Vec, + mut inputs: Vec, + ) -> datafusion_common::Result { assert_eq!(inputs.len(), 1, "input size inconsistent"); assert_eq!(exprs.len(), 0, "expression size inconsistent"); - Self { - input: inputs[0].clone(), - } + Ok(Self { + input: inputs.swap_remove(0), + }) + } + + fn supports_limit_pushdown(&self) -> bool { + false // Disallow limit push-down by default } } diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 138769674dd1..31e21d08b569 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -33,7 +33,7 @@ use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::{internal_err, DFSchema, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::{BinaryExpr, Cast, InList, TryCast}; use datafusion_expr::utils::merge_schema; -use datafusion_expr::{lit, Expr, ExprSchemable, LogicalPlan, Operator}; +use datafusion_expr::{lit, Expr, ExprSchemable, LogicalPlan}; /// [`UnwrapCastInComparison`] attempts to remove casts from /// comparisons to literals ([`ScalarValue`]s) by applying the casts @@ -72,7 +72,7 @@ use datafusion_expr::{lit, Expr, ExprSchemable, LogicalPlan, Operator}; /// Filter: c1 > INT32(10) /// ``` /// -#[derive(Default)] +#[derive(Default, Debug)] pub struct UnwrapCastInComparison {} impl UnwrapCastInComparison { @@ -82,14 +82,6 @@ impl UnwrapCastInComparison { } impl OptimizerRule for UnwrapCastInComparison { - fn try_optimize( - &self, - _plan: &LogicalPlan, - _config: &dyn OptimizerConfig, - ) -> Result> { - internal_err!("Should have called UnwrapCastInComparison::rewrite") - } - fn name(&self) -> &str { "unwrap_cast_in_comparison" } @@ -107,7 +99,7 @@ impl OptimizerRule for UnwrapCastInComparison { plan: LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - let mut schema = merge_schema(plan.inputs()); + let mut schema = merge_schema(&plan.inputs()); if let LogicalPlan::TableScan(ts) = &plan { let source_schema = DFSchema::try_from_qualified_schema( @@ -125,9 +117,9 @@ impl OptimizerRule for UnwrapCastInComparison { let name_preserver = NamePreserver::new(&plan); plan.map_expressions(|expr| { - let original_name = name_preserver.save(&expr)?; - expr.rewrite(&mut expr_rewriter)? - .map_data(|expr| original_name.restore(expr)) + let original_name = name_preserver.save(&expr); + expr.rewrite(&mut expr_rewriter) + .map(|transformed| transformed.update_data(|e| original_name.restore(e))) }) } } @@ -152,9 +144,9 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { let Ok(right_type) = right.get_type(&self.schema) else { return Ok(Transformed::no(expr)); }; - is_support_data_type(&left_type) - && is_support_data_type(&right_type) - && is_comparison_op(op) + is_supported_type(&left_type) + && is_supported_type(&right_type) + && op.supports_propagation() } => { match (left.as_mut(), right.as_mut()) { @@ -167,20 +159,26 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { expr: right_expr, .. }), ) => { - // if the left_lit_value can be casted to the type of expr + // if the left_lit_value can be cast to the type of expr // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal let Ok(expr_type) = right_expr.get_type(&self.schema) else { return Ok(Transformed::no(expr)); }; - let Ok(Some(value)) = - try_cast_literal_to_type(left_lit_value, &expr_type) - else { - return Ok(Transformed::no(expr)); - }; - **left = lit(value); - // unwrap the cast/try_cast for the right expr - **right = mem::take(right_expr); - Ok(Transformed::yes(expr)) + match expr_type { + // https://github.com/apache/datafusion/issues/12180 + DataType::Utf8View => Ok(Transformed::no(expr)), + _ => { + let Some(value) = + try_cast_literal_to_type(left_lit_value, &expr_type) + else { + return Ok(Transformed::no(expr)); + }; + **left = lit(value); + // unwrap the cast/try_cast for the right expr + **right = mem::take(right_expr); + Ok(Transformed::yes(expr)) + } + } } ( Expr::TryCast(TryCast { @@ -191,20 +189,26 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { }), Expr::Literal(right_lit_value), ) => { - // if the right_lit_value can be casted to the type of expr + // if the right_lit_value can be cast to the type of expr // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal let Ok(expr_type) = left_expr.get_type(&self.schema) else { return Ok(Transformed::no(expr)); }; - let Ok(Some(value)) = - try_cast_literal_to_type(right_lit_value, &expr_type) - else { - return Ok(Transformed::no(expr)); - }; - // unwrap the cast/try_cast for the left expr - **left = mem::take(left_expr); - **right = lit(value); - Ok(Transformed::yes(expr)) + match expr_type { + // https://github.com/apache/datafusion/issues/12180 + DataType::Utf8View => Ok(Transformed::no(expr)), + _ => { + let Some(value) = + try_cast_literal_to_type(right_lit_value, &expr_type) + else { + return Ok(Transformed::no(expr)); + }; + // unwrap the cast/try_cast for the left expr + **left = mem::take(left_expr); + **right = lit(value); + Ok(Transformed::yes(expr)) + } + } } _ => Ok(Transformed::no(expr)), } @@ -226,14 +230,14 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { let Ok(expr_type) = left_expr.get_type(&self.schema) else { return Ok(Transformed::no(expr)); }; - if !is_support_data_type(&expr_type) { + if !is_supported_type(&expr_type) { return Ok(Transformed::no(expr)); } let Ok(right_exprs) = list .iter() .map(|right| { let right_type = right.get_type(&self.schema)?; - if !is_support_data_type(&right_type) { + if !is_supported_type(&right_type) { internal_err!( "The type of list expr {} is not supported", &right_type @@ -243,7 +247,7 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { Expr::Literal(right_lit_value) => { // if the right_lit_value can be casted to the type of internal_left_expr // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal - let Ok(Some(value)) = try_cast_literal_to_type(right_lit_value, &expr_type) else { + let Some(value) = try_cast_literal_to_type(right_lit_value, &expr_type) else { internal_err!( "Can't cast the list expr {:?} to type {:?}", right_lit_value, &expr_type @@ -270,19 +274,15 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { } } -fn is_comparison_op(op: &Operator) -> bool { - matches!( - op, - Operator::Eq - | Operator::NotEq - | Operator::Gt - | Operator::GtEq - | Operator::Lt - | Operator::LtEq - ) +/// Returns true if [UnwrapCastExprRewriter] supports this data type +fn is_supported_type(data_type: &DataType) -> bool { + is_supported_numeric_type(data_type) + || is_supported_string_type(data_type) + || is_supported_dictionary_type(data_type) } -fn is_support_data_type(data_type: &DataType) -> bool { +/// Returns true if [[UnwrapCastExprRewriter]] suppors this numeric type +fn is_supported_numeric_type(data_type: &DataType) -> bool { matches!( data_type, DataType::UInt8 @@ -298,19 +298,50 @@ fn is_support_data_type(data_type: &DataType) -> bool { ) } +/// Returns true if [UnwrapCastExprRewriter] supports casting this value as a string +fn is_supported_string_type(data_type: &DataType) -> bool { + matches!( + data_type, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View + ) +} + +/// Returns true if [UnwrapCastExprRewriter] supports casting this value as a dictionary +fn is_supported_dictionary_type(data_type: &DataType) -> bool { + matches!(data_type, + DataType::Dictionary(_, inner) if is_supported_type(inner)) +} + +/// Convert a literal value from one data type to another fn try_cast_literal_to_type( lit_value: &ScalarValue, target_type: &DataType, -) -> Result> { +) -> Option { let lit_data_type = lit_value.data_type(); - // the rule just support the signed numeric data type now - if !is_support_data_type(&lit_data_type) || !is_support_data_type(target_type) { - return Ok(None); + if !is_supported_type(&lit_data_type) || !is_supported_type(target_type) { + return None; } if lit_value.is_null() { // null value can be cast to any type of null value - return Ok(Some(ScalarValue::try_from(target_type)?)); + return ScalarValue::try_from(target_type).ok(); } + try_cast_numeric_literal(lit_value, target_type) + .or_else(|| try_cast_string_literal(lit_value, target_type)) + .or_else(|| try_cast_dictionary(lit_value, target_type)) +} + +/// Convert a numeric value from one numeric data type to another +fn try_cast_numeric_literal( + lit_value: &ScalarValue, + target_type: &DataType, +) -> Option { + let lit_data_type = lit_value.data_type(); + if !is_supported_numeric_type(&lit_data_type) + || !is_supported_numeric_type(target_type) + { + return None; + } + let mul = match target_type { DataType::UInt8 | DataType::UInt16 @@ -322,9 +353,7 @@ fn try_cast_literal_to_type( | DataType::Int64 => 1_i128, DataType::Timestamp(_, _) => 1_i128, DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32), - other_type => { - return internal_err!("Error target data type {other_type:?}"); - } + _ => return None, }; let (target_min, target_max) = match target_type { DataType::UInt8 => (u8::MIN as i128, u8::MAX as i128), @@ -343,9 +372,7 @@ fn try_cast_literal_to_type( MIN_DECIMAL_FOR_EACH_PRECISION[*precision as usize - 1], MAX_DECIMAL_FOR_EACH_PRECISION[*precision as usize - 1], ), - other_type => { - return internal_err!("Error target data type {other_type:?}"); - } + _ => return None, }; let lit_value_target_type = match lit_value { ScalarValue::Int8(Some(v)) => (*v as i128).checked_mul(mul), @@ -379,13 +406,11 @@ fn try_cast_literal_to_type( None } } - other_value => { - return internal_err!("Invalid literal value {other_value:?}"); - } + _ => None, }; match lit_value_target_type { - None => Ok(None), + None => None, Some(value) => { if value >= target_min && value <= target_max { // the value casted from lit to the target type is in the range of target type. @@ -401,32 +426,32 @@ fn try_cast_literal_to_type( DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)), DataType::Timestamp(TimeUnit::Second, tz) => { let value = cast_between_timestamp( - lit_data_type, - DataType::Timestamp(TimeUnit::Second, tz.clone()), + &lit_data_type, + &DataType::Timestamp(TimeUnit::Second, tz.clone()), value, ); ScalarValue::TimestampSecond(value, tz.clone()) } DataType::Timestamp(TimeUnit::Millisecond, tz) => { let value = cast_between_timestamp( - lit_data_type, - DataType::Timestamp(TimeUnit::Millisecond, tz.clone()), + &lit_data_type, + &DataType::Timestamp(TimeUnit::Millisecond, tz.clone()), value, ); ScalarValue::TimestampMillisecond(value, tz.clone()) } DataType::Timestamp(TimeUnit::Microsecond, tz) => { let value = cast_between_timestamp( - lit_data_type, - DataType::Timestamp(TimeUnit::Microsecond, tz.clone()), + &lit_data_type, + &DataType::Timestamp(TimeUnit::Microsecond, tz.clone()), value, ); ScalarValue::TimestampMicrosecond(value, tz.clone()) } DataType::Timestamp(TimeUnit::Nanosecond, tz) => { let value = cast_between_timestamp( - lit_data_type, - DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()), + &lit_data_type, + &DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()), value, ); ScalarValue::TimestampNanosecond(value, tz.clone()) @@ -434,20 +459,65 @@ fn try_cast_literal_to_type( DataType::Decimal128(p, s) => { ScalarValue::Decimal128(Some(value), *p, *s) } - other_type => { - return internal_err!("Error target data type {other_type:?}"); + _ => { + return None; } }; - Ok(Some(result_scalar)) + Some(result_scalar) } else { - Ok(None) + None } } } } +fn try_cast_string_literal( + lit_value: &ScalarValue, + target_type: &DataType, +) -> Option { + let string_value = match lit_value { + ScalarValue::Utf8(s) | ScalarValue::LargeUtf8(s) | ScalarValue::Utf8View(s) => { + s.clone() + } + _ => return None, + }; + let scalar_value = match target_type { + DataType::Utf8 => ScalarValue::Utf8(string_value), + DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value), + DataType::Utf8View => ScalarValue::Utf8View(string_value), + _ => return None, + }; + Some(scalar_value) +} + +/// Attempt to cast to/from a dictionary type by wrapping/unwrapping the dictionary +fn try_cast_dictionary( + lit_value: &ScalarValue, + target_type: &DataType, +) -> Option { + let lit_value_type = lit_value.data_type(); + let result_scalar = match (lit_value, target_type) { + // Unwrap dictionary when inner type matches target type + (ScalarValue::Dictionary(_, inner_value), _) + if inner_value.data_type() == *target_type => + { + (**inner_value).clone() + } + // Wrap type when target type is dictionary + (_, DataType::Dictionary(index_type, inner_type)) + if **inner_type == lit_value_type => + { + ScalarValue::Dictionary(index_type.clone(), Box::new(lit_value.clone())) + } + _ => { + return None; + } + }; + Some(result_scalar) +} + /// Cast a timestamp value from one unit to another -fn cast_between_timestamp(from: DataType, to: DataType, value: i128) -> Option { +fn cast_between_timestamp(from: &DataType, to: &DataType, value: i128) -> Option { let value = value as i64; let from_scale = match from { DataType::Timestamp(TimeUnit::Second, _) => 1, @@ -536,6 +606,45 @@ mod tests { assert_eq!(optimize_test(expr_input, &schema), expected); } + #[test] + fn test_unwrap_cast_comparison_string() { + let schema = expr_test_schema(); + let dict = ScalarValue::Dictionary( + Box::new(DataType::Int32), + Box::new(ScalarValue::from("value")), + ); + + // cast(str1 as Dictionary) = arrow_cast('value', 'Dictionary') => str1 = Utf8('value1') + let expr_input = cast(col("str1"), dict.data_type()).eq(lit(dict.clone())); + let expected = col("str1").eq(lit("value")); + assert_eq!(optimize_test(expr_input, &schema), expected); + + // cast(tag as Utf8) = Utf8('value') => tag = arrow_cast('value', 'Dictionary') + let expr_input = cast(col("tag"), DataType::Utf8).eq(lit("value")); + let expected = col("tag").eq(lit(dict.clone())); + assert_eq!(optimize_test(expr_input, &schema), expected); + + // Verify reversed argument order + // arrow_cast('value', 'Dictionary') = cast(str1 as Dictionary) => Utf8('value1') = str1 + let expr_input = lit(dict.clone()).eq(cast(col("str1"), dict.data_type())); + let expected = lit("value").eq(col("str1")); + assert_eq!(optimize_test(expr_input, &schema), expected); + } + + #[test] + fn test_unwrap_cast_comparison_large_string() { + let schema = expr_test_schema(); + // cast(largestr as Dictionary) = arrow_cast('value', 'Dictionary') => str1 = LargeUtf8('value1') + let dict = ScalarValue::Dictionary( + Box::new(DataType::Int32), + Box::new(ScalarValue::LargeUtf8(Some("value".to_owned()))), + ); + let expr_input = cast(col("largestr"), dict.data_type()).eq(lit(dict)); + let expected = + col("largestr").eq(lit(ScalarValue::LargeUtf8(Some("value".to_owned())))); + assert_eq!(optimize_test(expr_input, &schema), expected); + } + #[test] fn test_not_unwrap_cast_with_decimal_comparison() { let schema = expr_test_schema(); @@ -729,14 +838,14 @@ mod tests { fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr { let mut expr_rewriter = UnwrapCastExprRewriter { - schema: schema.clone(), + schema: Arc::clone(schema), }; expr.rewrite(&mut expr_rewriter).data().unwrap() } fn expr_test_schema() -> DFSchemaRef { Arc::new( - DFSchema::from_unqualifed_fields( + DFSchema::from_unqualified_fields( vec![ Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Int64, false), @@ -746,6 +855,9 @@ mod tests { Field::new("c6", DataType::UInt32, false), Field::new("ts_nano_none", timestamp_nano_none_type(), false), Field::new("ts_nano_utf", timestamp_nano_utc_type(), false), + Field::new("str1", DataType::Utf8, false), + Field::new("largestr", DataType::LargeUtf8, false), + Field::new("tag", dictionary_tag_type(), false), ] .into(), HashMap::new(), @@ -793,6 +905,11 @@ mod tests { DataType::Timestamp(TimeUnit::Nanosecond, utc) } + // a dictionary type for storing string tags + fn dictionary_tag_type() -> DataType { + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)) + } + #[test] fn test_try_cast_to_type_nulls() { // test that nulls can be cast to/from all integer types @@ -807,6 +924,8 @@ mod tests { ScalarValue::UInt64(None), ScalarValue::Decimal128(None, 3, 0), ScalarValue::Decimal128(None, 8, 2), + ScalarValue::Utf8(None), + ScalarValue::LargeUtf8(None), ]; for s1 in &scalars { @@ -967,7 +1086,7 @@ mod tests { ), }; - // Datafusion ignores timezones for comparisons of ScalarValue + // DataFusion ignores timezones for comparisons of ScalarValue // so double check it here assert_eq!(lit_tz_none, lit_tz_utc); @@ -1061,18 +1180,17 @@ mod tests { target_type: DataType, expected_result: ExpectedCast, ) { - let actual_result = try_cast_literal_to_type(&literal, &target_type); + let actual_value = try_cast_literal_to_type(&literal, &target_type); println!("expect_cast: "); println!(" {literal:?} --> {target_type:?}"); println!(" expected_result: {expected_result:?}"); - println!(" actual_result: {actual_result:?}"); + println!(" actual_result: {actual_value:?}"); match expected_result { ExpectedCast::Value(expected_value) => { - let actual_value = actual_result - .expect("Expected success but got error") - .expect("Expected cast value but got None"); + let actual_value = + actual_value.expect("Expected cast value but got None"); assert_eq!(actual_value, expected_value); @@ -1094,7 +1212,7 @@ mod tests { assert_eq!( &expected_array, &cast_array, - "Result of casing {literal:?} with arrow was\n {cast_array:#?}\nbut expected\n{expected_array:#?}" + "Result of casting {literal:?} with arrow was\n {cast_array:#?}\nbut expected\n{expected_array:#?}" ); // Verify that for timestamp types the timezones are the same @@ -1109,8 +1227,6 @@ mod tests { } } ExpectedCast::NoValue => { - let actual_value = actual_result.expect("Expected success but got error"); - assert!( actual_value.is_none(), "Expected no cast value, but got {actual_value:?}" @@ -1126,7 +1242,6 @@ mod tests { &ScalarValue::TimestampNanosecond(Some(123456), None), &DataType::Timestamp(TimeUnit::Nanosecond, None), ) - .unwrap() .unwrap(); assert_eq!( @@ -1139,7 +1254,6 @@ mod tests { &ScalarValue::TimestampNanosecond(Some(123456), None), &DataType::Timestamp(TimeUnit::Microsecond, None), ) - .unwrap() .unwrap(); assert_eq!( @@ -1152,7 +1266,6 @@ mod tests { &ScalarValue::TimestampNanosecond(Some(123456), None), &DataType::Timestamp(TimeUnit::Millisecond, None), ) - .unwrap() .unwrap(); assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None)); @@ -1162,7 +1275,6 @@ mod tests { &ScalarValue::TimestampNanosecond(Some(123456), None), &DataType::Timestamp(TimeUnit::Second, None), ) - .unwrap() .unwrap(); assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(0), None)); @@ -1172,7 +1284,6 @@ mod tests { &ScalarValue::TimestampMicrosecond(Some(123), None), &DataType::Timestamp(TimeUnit::Nanosecond, None), ) - .unwrap() .unwrap(); assert_eq!( @@ -1185,7 +1296,6 @@ mod tests { &ScalarValue::TimestampMicrosecond(Some(123), None), &DataType::Timestamp(TimeUnit::Millisecond, None), ) - .unwrap() .unwrap(); assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None)); @@ -1195,7 +1305,6 @@ mod tests { &ScalarValue::TimestampMicrosecond(Some(123456789), None), &DataType::Timestamp(TimeUnit::Second, None), ) - .unwrap() .unwrap(); assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123), None)); @@ -1204,7 +1313,6 @@ mod tests { &ScalarValue::TimestampMillisecond(Some(123), None), &DataType::Timestamp(TimeUnit::Nanosecond, None), ) - .unwrap() .unwrap(); assert_eq!( new_scalar, @@ -1216,7 +1324,6 @@ mod tests { &ScalarValue::TimestampMillisecond(Some(123), None), &DataType::Timestamp(TimeUnit::Microsecond, None), ) - .unwrap() .unwrap(); assert_eq!( new_scalar, @@ -1227,7 +1334,6 @@ mod tests { &ScalarValue::TimestampMillisecond(Some(123456789), None), &DataType::Timestamp(TimeUnit::Second, None), ) - .unwrap() .unwrap(); assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123456), None)); @@ -1236,7 +1342,6 @@ mod tests { &ScalarValue::TimestampSecond(Some(123), None), &DataType::Timestamp(TimeUnit::Nanosecond, None), ) - .unwrap() .unwrap(); assert_eq!( new_scalar, @@ -1248,7 +1353,6 @@ mod tests { &ScalarValue::TimestampSecond(Some(123), None), &DataType::Timestamp(TimeUnit::Microsecond, None), ) - .unwrap() .unwrap(); assert_eq!( new_scalar, @@ -1260,7 +1364,6 @@ mod tests { &ScalarValue::TimestampSecond(Some(123), None), &DataType::Timestamp(TimeUnit::Millisecond, None), ) - .unwrap() .unwrap(); assert_eq!( new_scalar, @@ -1272,8 +1375,48 @@ mod tests { &ScalarValue::TimestampSecond(Some(i64::MAX), None), &DataType::Timestamp(TimeUnit::Millisecond, None), ) - .unwrap() .unwrap(); assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(None, None)); } + + #[test] + fn test_try_cast_to_string_type() { + let scalars = vec![ + ScalarValue::from("string"), + ScalarValue::LargeUtf8(Some("string".to_owned())), + ]; + + for s1 in &scalars { + for s2 in &scalars { + let expected_value = ExpectedCast::Value(s2.clone()); + + expect_cast(s1.clone(), s2.data_type(), expected_value); + } + } + } + #[test] + fn test_try_cast_to_dictionary_type() { + fn dictionary_type(t: DataType) -> DataType { + DataType::Dictionary(Box::new(DataType::Int32), Box::new(t)) + } + fn dictionary_value(value: ScalarValue) -> ScalarValue { + ScalarValue::Dictionary(Box::new(DataType::Int32), Box::new(value)) + } + let scalars = vec![ + ScalarValue::from("string"), + ScalarValue::LargeUtf8(Some("string".to_owned())), + ]; + for s in &scalars { + expect_cast( + s.clone(), + dictionary_type(s.data_type()), + ExpectedCast::Value(dictionary_value(s.clone())), + ); + expect_cast( + dictionary_value(s.clone()), + s.data_type(), + ExpectedCast::Value(s.clone()), + ) + } + } } diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 1c20501da53a..9f325bc01b1d 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -17,16 +17,26 @@ //! Utility functions leveraged by the query optimizer rules -use std::collections::{BTreeSet, HashMap}; +use std::collections::{BTreeSet, HashMap, HashSet}; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{Column, DFSchema, DFSchemaRef, Result}; +use crate::analyzer::type_coercion::TypeCoercionRewriter; +use arrow::array::{new_null_array, Array, RecordBatch}; +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion_common::cast::as_boolean_array; +use datafusion_common::tree_node::{TransformedResult, TreeNode}; +use datafusion_common::{Column, DFSchema, Result, ScalarValue}; +use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr_rewriter::replace_col; -use datafusion_expr::utils as expr_utils; -use datafusion_expr::{logical_plan::LogicalPlan, Expr, Operator}; - +use datafusion_expr::{logical_plan::LogicalPlan, ColumnarValue, Expr}; +use datafusion_physical_expr::create_physical_expr; use log::{debug, trace}; +use std::sync::Arc; + +/// Re-export of `NamesPreserver` for backwards compatibility, +/// as it was initially placed here and then moved elsewhere. +pub use datafusion_expr::expr_rewriter::NamePreserver; /// Convenience rule for writing optimizers: recursively invoke /// optimize on plan's children and then return a node of the same @@ -35,6 +45,10 @@ use log::{debug, trace}; /// This also handles the case when the `plan` is a [`LogicalPlan::Explain`]. /// /// Returning `Ok(None)` indicates that the plan can't be optimized by the `optimizer`. +#[deprecated( + since = "40.0.0", + note = "please use OptimizerRule::apply_order with ApplyOrder::BottomUp instead" +)] pub fn optimize_children( optimizer: &impl OptimizerRule, plan: &LogicalPlan, @@ -43,9 +57,16 @@ pub fn optimize_children( let mut new_inputs = Vec::with_capacity(plan.inputs().len()); let mut plan_is_changed = false; for input in plan.inputs() { - let new_input = optimizer.try_optimize(input, config)?; - plan_is_changed = plan_is_changed || new_input.is_some(); - new_inputs.push(new_input.unwrap_or_else(|| input.clone())) + if optimizer.supports_rewrite() { + let new_input = optimizer.rewrite(input.clone(), config)?; + plan_is_changed = plan_is_changed || new_input.transformed; + new_inputs.push(new_input.data); + } else { + #[allow(deprecated)] + let new_input = optimizer.try_optimize(input, config)?; + plan_is_changed = plan_is_changed || new_input.is_some(); + new_inputs.push(new_input.unwrap_or_else(|| input.clone())) + } } if plan_is_changed { let exprs = plan.expressions(); @@ -55,15 +76,26 @@ pub fn optimize_children( } } +/// Returns true if `expr` contains all columns in `schema_cols` +pub(crate) fn has_all_column_refs(expr: &Expr, schema_cols: &HashSet) -> bool { + let column_refs = expr.column_refs(); + // note can't use HashSet::intersect because of different types (owned vs References) + schema_cols + .iter() + .filter(|c| column_refs.contains(c)) + .count() + == column_refs.len() +} + pub(crate) fn collect_subquery_cols( exprs: &[Expr], - subquery_schema: DFSchemaRef, + subquery_schema: &DFSchema, ) -> Result> { exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| { let mut using_cols: Vec = vec![]; - for col in expr.to_columns()?.into_iter() { - if subquery_schema.has_column(&col) { - using_cols.push(col); + for col in expr.column_refs().into_iter() { + if subquery_schema.has_column(col) { + using_cols.push(col.clone()); } } @@ -79,9 +111,7 @@ pub(crate) fn replace_qualified_name( ) -> Result { let alias_cols: Vec = cols .iter() - .map(|col| { - Column::from_qualified_name(format!("{}.{}", subquery_alias, col.name)) - }) + .map(|col| Column::new(Some(subquery_alias), &col.name)) .collect(); let replace_map: HashMap<&Column, &Column> = cols.iter().zip(alias_cols.iter()).collect(); @@ -95,224 +125,160 @@ pub fn log_plan(description: &str, plan: &LogicalPlan) { trace!("{description}::\n{}\n", plan.display_indent_schema()); } -/// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` -/// -/// See [`split_conjunction_owned`] for more details and an example. -#[deprecated( - since = "34.0.0", - note = "use `datafusion_expr::utils::split_conjunction` instead" -)] -pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> { - expr_utils::split_conjunction(expr) -} - -/// Splits an owned conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` -/// -/// This is often used to "split" filter expressions such as `col1 = 5 -/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`]; -/// -/// # Example -/// ``` -/// # use datafusion_expr::{col, lit}; -/// # use datafusion_optimizer::utils::split_conjunction_owned; -/// // a=1 AND b=2 -/// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2))); -/// -/// // [a=1, b=2] -/// let split = vec![ -/// col("a").eq(lit(1)), -/// col("b").eq(lit(2)), -/// ]; -/// -/// // use split_conjunction_owned to split them -/// assert_eq!(split_conjunction_owned(expr), split); -/// ``` -#[deprecated( - since = "34.0.0", - note = "use `datafusion_expr::utils::split_conjunction_owned` instead" -)] -pub fn split_conjunction_owned(expr: Expr) -> Vec { - expr_utils::split_conjunction_owned(expr) -} - -/// Splits an owned binary operator tree [`Expr`] such as `A B C` => `[A, B, C]` -/// -/// This is often used to "split" expressions such as `col1 = 5 -/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`]; -/// -/// # Example -/// ``` -/// # use datafusion_expr::{col, lit, Operator}; -/// # use datafusion_optimizer::utils::split_binary_owned; -/// # use std::ops::Add; -/// // a=1 + b=2 -/// let expr = col("a").eq(lit(1)).add(col("b").eq(lit(2))); -/// -/// // [a=1, b=2] -/// let split = vec![ -/// col("a").eq(lit(1)), -/// col("b").eq(lit(2)), -/// ]; -/// -/// // use split_binary_owned to split them -/// assert_eq!(split_binary_owned(expr, Operator::Plus), split); -/// ``` -#[deprecated( - since = "34.0.0", - note = "use `datafusion_expr::utils::split_binary_owned` instead" -)] -pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec { - expr_utils::split_binary_owned(expr, op) -} - -/// Splits an binary operator tree [`Expr`] such as `A B C` => `[A, B, C]` -/// -/// See [`split_binary_owned`] for more details and an example. -#[deprecated( - since = "34.0.0", - note = "use `datafusion_expr::utils::split_binary` instead" -)] -pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> { - expr_utils::split_binary(expr, op) -} +/// Determine whether a predicate can restrict NULLs. e.g. +/// `c0 > 8` return true; +/// `c0 IS NULL` return false. +pub fn is_restrict_null_predicate<'a>( + predicate: Expr, + join_cols_of_predicate: impl IntoIterator, +) -> Result { + if matches!(predicate, Expr::Column(_)) { + return Ok(true); + } -/// Combines an array of filter expressions into a single filter -/// expression consisting of the input filter expressions joined with -/// logical AND. -/// -/// Returns None if the filters array is empty. -/// -/// # Example -/// ``` -/// # use datafusion_expr::{col, lit}; -/// # use datafusion_optimizer::utils::conjunction; -/// // a=1 AND b=2 -/// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2))); -/// -/// // [a=1, b=2] -/// let split = vec![ -/// col("a").eq(lit(1)), -/// col("b").eq(lit(2)), -/// ]; -/// -/// // use conjunction to join them together with `AND` -/// assert_eq!(conjunction(split), Some(expr)); -/// ``` -#[deprecated( - since = "34.0.0", - note = "use `datafusion_expr::utils::conjunction` instead" -)] -pub fn conjunction(filters: impl IntoIterator) -> Option { - expr_utils::conjunction(filters) -} + static DUMMY_COL_NAME: &str = "?"; + let schema = Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Null, true)]); + let input_schema = DFSchema::try_from(schema.clone())?; + let column = new_null_array(&DataType::Null, 1); + let input_batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![column])?; + let execution_props = ExecutionProps::default(); + let null_column = Column::from_name(DUMMY_COL_NAME); -/// Combines an array of filter expressions into a single filter -/// expression consisting of the input filter expressions joined with -/// logical OR. -/// -/// Returns None if the filters array is empty. -#[deprecated( - since = "34.0.0", - note = "use `datafusion_expr::utils::disjunction` instead" -)] -pub fn disjunction(filters: impl IntoIterator) -> Option { - expr_utils::disjunction(filters) -} + let join_cols_to_replace = join_cols_of_predicate + .into_iter() + .map(|column| (column, &null_column)) + .collect::>(); -/// returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] with -/// its predicate be all `predicates` ANDed. -#[deprecated( - since = "34.0.0", - note = "use `datafusion_expr::utils::add_filter` instead" -)] -pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result { - expr_utils::add_filter(plan, predicates) -} + let replaced_predicate = replace_col(predicate, &join_cols_to_replace)?; + let coerced_predicate = coerce(replaced_predicate, &input_schema)?; + let phys_expr = + create_physical_expr(&coerced_predicate, &input_schema, &execution_props)?; -/// Looks for correlating expressions: for example, a binary expression with one field from the subquery, and -/// one not in the subquery (closed upon from outer scope) -/// -/// # Arguments -/// -/// * `exprs` - List of expressions that may or may not be joins -/// -/// # Return value -/// -/// Tuple of (expressions containing joins, remaining non-join expressions) -#[deprecated( - since = "34.0.0", - note = "use `datafusion_expr::utils::find_join_exprs` instead" -)] -pub fn find_join_exprs(exprs: Vec<&Expr>) -> Result<(Vec, Vec)> { - expr_utils::find_join_exprs(exprs) -} + let result_type = phys_expr.data_type(&schema)?; + if !matches!(&result_type, DataType::Boolean) { + return Ok(false); + } -/// Returns the first (and only) element in a slice, or an error -/// -/// # Arguments -/// -/// * `slice` - The slice to extract from -/// -/// # Return value -/// -/// The first element, or an error -#[deprecated( - since = "34.0.0", - note = "use `datafusion_expr::utils::only_or_err` instead" -)] -pub fn only_or_err(slice: &[T]) -> Result<&T> { - expr_utils::only_or_err(slice) + // If result is single `true`, return false; + // If result is single `NULL` or `false`, return true; + Ok(match phys_expr.evaluate(&input_batch)? { + ColumnarValue::Array(array) => { + if array.len() == 1 { + let boolean_array = as_boolean_array(&array)?; + boolean_array.is_null(0) || !boolean_array.value(0) + } else { + false + } + } + ColumnarValue::Scalar(scalar) => matches!( + scalar, + ScalarValue::Boolean(None) | ScalarValue::Boolean(Some(false)) + ), + }) } -/// merge inputs schema into a single schema. -#[deprecated( - since = "34.0.0", - note = "use `datafusion_expr::utils::merge_schema` instead" -)] -pub fn merge_schema(inputs: Vec<&LogicalPlan>) -> DFSchema { - expr_utils::merge_schema(inputs) +fn coerce(expr: Expr, schema: &DFSchema) -> Result { + let mut expr_rewrite = TypeCoercionRewriter { schema }; + expr.rewrite(&mut expr_rewrite).data() } -/// Handles ensuring the name of rewritten expressions is not changed. -/// -/// For example, if an expression `1 + 2` is rewritten to `3`, the name of the -/// expression should be preserved: `3 as "1 + 2"` -/// -/// See for details -pub struct NamePreserver { - use_alias: bool, -} +#[cfg(test)] +mod tests { + use super::*; + use datafusion_expr::{binary_expr, case, col, in_list, is_null, lit, Operator}; -/// If the name of an expression is remembered, it will be preserved when -/// rewriting the expression -pub struct SavedName(Option); + #[test] + fn expr_is_restrict_null_predicate() -> Result<()> { + let test_cases = vec![ + // a + (col("a"), true), + // a IS NULL + (is_null(col("a")), false), + // a IS NOT NULL + (Expr::IsNotNull(Box::new(col("a"))), true), + // a = NULL + ( + binary_expr(col("a"), Operator::Eq, Expr::Literal(ScalarValue::Null)), + true, + ), + // a > 8 + (binary_expr(col("a"), Operator::Gt, lit(8i64)), true), + // a <= 8 + (binary_expr(col("a"), Operator::LtEq, lit(8i32)), true), + // CASE a WHEN 1 THEN true WHEN 0 THEN false ELSE NULL END + ( + case(col("a")) + .when(lit(1i64), lit(true)) + .when(lit(0i64), lit(false)) + .otherwise(lit(ScalarValue::Null))?, + true, + ), + // CASE a WHEN 1 THEN true ELSE false END + ( + case(col("a")) + .when(lit(1i64), lit(true)) + .otherwise(lit(false))?, + true, + ), + // CASE a WHEN 0 THEN false ELSE true END + ( + case(col("a")) + .when(lit(0i64), lit(false)) + .otherwise(lit(true))?, + false, + ), + // (CASE a WHEN 0 THEN false ELSE true END) OR false + ( + binary_expr( + case(col("a")) + .when(lit(0i64), lit(false)) + .otherwise(lit(true))?, + Operator::Or, + lit(false), + ), + false, + ), + // (CASE a WHEN 0 THEN true ELSE false END) OR false + ( + binary_expr( + case(col("a")) + .when(lit(0i64), lit(true)) + .otherwise(lit(false))?, + Operator::Or, + lit(false), + ), + true, + ), + // a IN (1, 2, 3) + ( + in_list(col("a"), vec![lit(1i64), lit(2i64), lit(3i64)], false), + true, + ), + // a NOT IN (1, 2, 3) + ( + in_list(col("a"), vec![lit(1i64), lit(2i64), lit(3i64)], true), + true, + ), + // a IN (NULL) + ( + in_list(col("a"), vec![Expr::Literal(ScalarValue::Null)], false), + true, + ), + // a NOT IN (NULL) + ( + in_list(col("a"), vec![Expr::Literal(ScalarValue::Null)], true), + true, + ), + ]; -impl NamePreserver { - /// Create a new NamePreserver for rewriting the `expr` that is part of the specified plan - pub fn new(plan: &LogicalPlan) -> Self { - Self { - use_alias: !matches!(plan, LogicalPlan::Filter(_) | LogicalPlan::Join(_)), + let column_a = Column::from_name("a"); + for (predicate, expected) in test_cases { + let join_cols_of_predicate = std::iter::once(&column_a); + let actual = + is_restrict_null_predicate(predicate.clone(), join_cols_of_predicate)?; + assert_eq!(actual, expected, "{}", predicate); } - } - - pub fn save(&self, expr: &Expr) -> Result { - let original_name = if self.use_alias { - Some(expr.name_for_alias()?) - } else { - None - }; - Ok(SavedName(original_name)) - } -} - -impl SavedName { - /// Ensures the name of the rewritten expression is preserved - pub fn restore(self, expr: Expr) -> Result { - let Self(original_name) = self; - match original_name { - Some(name) => expr.alias_if_changed(name), - None => Ok(expr), - } + Ok(()) } } diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 180c79206664..236167985790 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -20,9 +20,13 @@ use std::collections::HashMap; use std::sync::Arc; use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; + use datafusion_common::config::ConfigOptions; use datafusion_common::{plan_err, Result}; +use datafusion_expr::test::function_stub::sum_udaf; use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource, WindowUDF}; +use datafusion_functions_aggregate::average::avg_udaf; +use datafusion_functions_aggregate::count::count_udaf; use datafusion_optimizer::analyzer::Analyzer; use datafusion_optimizer::optimizer::Optimizer; use datafusion_optimizer::{OptimizerConfig, OptimizerContext, OptimizerRule}; @@ -46,13 +50,13 @@ fn case_when() -> Result<()> { let expected = "Projection: CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END AS CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END\ \n TableScan: test projection=[col_int32]"; - assert_eq!(expected, format!("{plan:?}")); + assert_eq!(expected, format!("{plan}")); let sql = "SELECT CASE WHEN col_uint32 > 0 THEN 1 ELSE 0 END FROM test"; let plan = test_sql(sql)?; let expected = "Projection: CASE WHEN test.col_uint32 > UInt32(0) THEN Int64(1) ELSE Int64(0) END AS CASE WHEN test.col_uint32 > Int64(0) THEN Int64(1) ELSE Int64(0) END\ \n TableScan: test projection=[col_uint32]"; - assert_eq!(expected, format!("{plan:?}")); + assert_eq!(expected, format!("{plan}")); Ok(()) } @@ -61,31 +65,31 @@ fn subquery_filter_with_cast() -> Result<()> { // regression test for https://github.com/apache/datafusion/issues/3760 let sql = "SELECT col_int32 FROM test \ WHERE col_int32 > (\ - SELECT AVG(col_int32) FROM test \ + SELECT avg(col_int32) FROM test \ WHERE col_utf8 BETWEEN '2002-05-08' \ AND (cast('2002-05-08' as date) + interval '5 days')\ )"; let plan = test_sql(sql)?; let expected = "Projection: test.col_int32\ - \n Inner Join: Filter: CAST(test.col_int32 AS Float64) > __scalar_sq_1.AVG(test.col_int32)\ + \n Inner Join: Filter: CAST(test.col_int32 AS Float64) > __scalar_sq_1.avg(test.col_int32)\ \n TableScan: test projection=[col_int32]\ \n SubqueryAlias: __scalar_sq_1\ - \n Aggregate: groupBy=[[]], aggr=[[AVG(CAST(test.col_int32 AS Float64))]]\ + \n Aggregate: groupBy=[[]], aggr=[[avg(CAST(test.col_int32 AS Float64))]]\ \n Projection: test.col_int32\ \n Filter: test.col_utf8 >= Utf8(\"2002-05-08\") AND test.col_utf8 <= Utf8(\"2002-05-13\")\ \n TableScan: test projection=[col_int32, col_utf8]"; - assert_eq!(expected, format!("{plan:?}")); + assert_eq!(expected, format!("{plan}")); Ok(()) } #[test] fn case_when_aggregate() -> Result<()> { - let sql = "SELECT col_utf8, SUM(CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END) AS n FROM test GROUP BY col_utf8"; + let sql = "SELECT col_utf8, sum(CASE WHEN col_int32 > 0 THEN 1 ELSE 0 END) AS n FROM test GROUP BY col_utf8"; let plan = test_sql(sql)?; - let expected = "Projection: test.col_utf8, SUM(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END) AS n\ - \n Aggregate: groupBy=[[test.col_utf8]], aggr=[[SUM(CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END) AS SUM(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END)]]\ + let expected = "Projection: test.col_utf8, sum(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END) AS n\ + \n Aggregate: groupBy=[[test.col_utf8]], aggr=[[sum(CASE WHEN test.col_int32 > Int32(0) THEN Int64(1) ELSE Int64(0) END) AS sum(CASE WHEN test.col_int32 > Int64(0) THEN Int64(1) ELSE Int64(0) END)]]\ \n TableScan: test projection=[col_int32, col_utf8]"; - assert_eq!(expected, format!("{plan:?}")); + assert_eq!(expected, format!("{plan}")); Ok(()) } @@ -96,7 +100,7 @@ fn unsigned_target_type() -> Result<()> { let expected = "Projection: test.col_utf8\ \n Filter: test.col_uint32 > UInt32(0)\ \n TableScan: test projection=[col_uint32, col_utf8]"; - assert_eq!(expected, format!("{plan:?}")); + assert_eq!(expected, format!("{plan}")); Ok(()) } @@ -105,9 +109,9 @@ fn distribute_by() -> Result<()> { // regression test for https://github.com/apache/datafusion/issues/3234 let sql = "SELECT col_int32, col_utf8 FROM test DISTRIBUTE BY (col_utf8)"; let plan = test_sql(sql)?; - let expected = "Repartition: DistributeBy(col_utf8)\ + let expected = "Repartition: DistributeBy(test.col_utf8)\ \n TableScan: test projection=[col_int32, col_utf8]"; - assert_eq!(expected, format!("{plan:?}")); + assert_eq!(expected, format!("{plan}")); Ok(()) } @@ -120,11 +124,13 @@ fn semi_join_with_join_filter() -> Result<()> { let plan = test_sql(sql)?; let expected = "Projection: test.col_utf8\ \n LeftSemi Join: test.col_int32 = __correlated_sq_1.col_int32 Filter: test.col_uint32 != __correlated_sq_1.col_uint32\ - \n TableScan: test projection=[col_int32, col_uint32, col_utf8]\ + \n Filter: test.col_int32 IS NOT NULL\ + \n TableScan: test projection=[col_int32, col_uint32, col_utf8]\ \n SubqueryAlias: __correlated_sq_1\ \n SubqueryAlias: t2\ - \n TableScan: test projection=[col_int32, col_uint32]"; - assert_eq!(expected, format!("{plan:?}")); + \n Filter: test.col_int32 IS NOT NULL\ + \n TableScan: test projection=[col_int32, col_uint32]"; + assert_eq!(expected, format!("{plan}")); Ok(()) } @@ -140,8 +146,9 @@ fn anti_join_with_join_filter() -> Result<()> { \n TableScan: test projection=[col_int32, col_uint32, col_utf8]\ \n SubqueryAlias: __correlated_sq_1\ \n SubqueryAlias: t2\ - \n TableScan: test projection=[col_int32, col_uint32]"; - assert_eq!(expected, format!("{plan:?}")); + \n Filter: test.col_int32 IS NOT NULL\ + \n TableScan: test projection=[col_int32, col_uint32]"; + assert_eq!(expected, format!("{plan}")); Ok(()) } @@ -151,12 +158,14 @@ fn where_exists_distinct() -> Result<()> { SELECT DISTINCT col_int32 FROM test t2 WHERE test.col_int32 = t2.col_int32)"; let plan = test_sql(sql)?; let expected = "LeftSemi Join: test.col_int32 = __correlated_sq_1.col_int32\ - \n TableScan: test projection=[col_int32]\ + \n Filter: test.col_int32 IS NOT NULL\ + \n TableScan: test projection=[col_int32]\ \n SubqueryAlias: __correlated_sq_1\ \n Aggregate: groupBy=[[t2.col_int32]], aggr=[[]]\ \n SubqueryAlias: t2\ - \n TableScan: test projection=[col_int32]"; - assert_eq!(expected, format!("{plan:?}")); + \n Filter: test.col_int32 IS NOT NULL\ + \n TableScan: test projection=[col_int32]"; + assert_eq!(expected, format!("{plan}")); Ok(()) } @@ -168,13 +177,13 @@ fn intersect() -> Result<()> { let plan = test_sql(sql)?; let expected = "LeftSemi Join: test.col_int32 = test.col_int32, test.col_utf8 = test.col_utf8\ - \n Aggregate: groupBy=[[test.col_int32, test.col_utf8]], aggr=[[]]\ - \n LeftSemi Join: test.col_int32 = test.col_int32, test.col_utf8 = test.col_utf8\ - \n Aggregate: groupBy=[[test.col_int32, test.col_utf8]], aggr=[[]]\ - \n TableScan: test projection=[col_int32, col_utf8]\ - \n TableScan: test projection=[col_int32, col_utf8]\ - \n TableScan: test projection=[col_int32, col_utf8]"; - assert_eq!(expected, format!("{plan:?}")); + \n Aggregate: groupBy=[[test.col_int32, test.col_utf8]], aggr=[[]]\ + \n LeftSemi Join: test.col_int32 = test.col_int32, test.col_utf8 = test.col_utf8\ + \n Aggregate: groupBy=[[test.col_int32, test.col_utf8]], aggr=[[]]\ + \n TableScan: test projection=[col_int32, col_utf8]\ + \n TableScan: test projection=[col_int32, col_utf8]\ + \n TableScan: test projection=[col_int32, col_utf8]"; + assert_eq!(expected, format!("{plan}")); Ok(()) } @@ -184,11 +193,11 @@ fn between_date32_plus_interval() -> Result<()> { WHERE col_date32 between '1998-03-18' AND cast('1998-03-18' as date) + INTERVAL '90 days'"; let plan = test_sql(sql)?; let expected = - "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\ + "Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\ \n Projection: \ - \n Filter: test.col_date32 >= Date32(\"10303\") AND test.col_date32 <= Date32(\"10393\")\ + \n Filter: test.col_date32 >= Date32(\"1998-03-18\") AND test.col_date32 <= Date32(\"1998-06-16\")\ \n TableScan: test projection=[col_date32]"; - assert_eq!(expected, format!("{plan:?}")); + assert_eq!(expected, format!("{plan}")); Ok(()) } @@ -198,11 +207,11 @@ fn between_date64_plus_interval() -> Result<()> { WHERE col_date64 between '1998-03-18T00:00:00' AND cast('1998-03-18' as date) + INTERVAL '90 days'"; let plan = test_sql(sql)?; let expected = - "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\ + "Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\ \n Projection: \ - \n Filter: test.col_date64 >= Date64(\"890179200000\") AND test.col_date64 <= Date64(\"897955200000\")\ + \n Filter: test.col_date64 >= Date64(\"1998-03-18\") AND test.col_date64 <= Date64(\"1998-06-16\")\ \n TableScan: test projection=[col_date64]"; - assert_eq!(expected, format!("{plan:?}")); + assert_eq!(expected, format!("{plan}")); Ok(()) } @@ -212,7 +221,7 @@ fn propagate_empty_relation() { let plan = test_sql(sql).unwrap(); // when children exist EmptyRelation, it will bottom-up propagate. let expected = "EmptyRelation"; - assert_eq!(expected, format!("{plan:?}")); + assert_eq!(expected, format!("{plan}")); } #[test] @@ -228,7 +237,7 @@ fn join_keys_in_subquery_alias() { \n Filter: test.col_int32 IS NOT NULL\ \n TableScan: test projection=[col_int32]"; - assert_eq!(expected, format!("{plan:?}")); + assert_eq!(expected, format!("{plan}")); } #[test] @@ -247,18 +256,18 @@ fn join_keys_in_subquery_alias_1() { \n SubqueryAlias: c\ \n Filter: test.col_int32 IS NOT NULL\ \n TableScan: test projection=[col_int32]"; - assert_eq!(expected, format!("{plan:?}")); + assert_eq!(expected, format!("{plan}")); } #[test] fn push_down_filter_groupby_expr_contains_alias() { let sql = "SELECT * FROM (SELECT (col_int32 + col_uint32) AS c, count(*) FROM test GROUP BY 1) where c > 3"; let plan = test_sql(sql).unwrap(); - let expected = "Projection: test.col_int32 + test.col_uint32 AS c, COUNT(*)\ - \n Aggregate: groupBy=[[test.col_int32 + CAST(test.col_uint32 AS Int32)]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]]\ + let expected = "Projection: test.col_int32 + test.col_uint32 AS c, count(*)\ + \n Aggregate: groupBy=[[test.col_int32 + CAST(test.col_uint32 AS Int32)]], aggr=[[count(Int64(1)) AS count(*)]]\ \n Filter: test.col_int32 + CAST(test.col_uint32 AS Int32) > Int32(3)\ \n TableScan: test projection=[col_int32, col_uint32]"; - assert_eq!(expected, format!("{plan:?}")); + assert_eq!(expected, format!("{plan}")); } #[test] @@ -272,7 +281,109 @@ fn test_same_name_but_not_ambiguous() { \n TableScan: test projection=[col_int32]\ \n SubqueryAlias: t2\ \n TableScan: test projection=[col_int32]"; - assert_eq!(expected, format!("{plan:?}")); + assert_eq!(expected, format!("{plan}")); +} + +#[test] +fn eliminate_nested_filters() { + let sql = "\ + SELECT col_int32 FROM test \ + WHERE (1=1) AND (col_int32 > 0) \ + AND (1=1) AND (1=0 OR 1=1)"; + + let plan = test_sql(sql).unwrap(); + let expected = "\ + Filter: test.col_int32 > Int32(0)\ + \n TableScan: test projection=[col_int32]"; + + assert_eq!(expected, format!("{plan}")); +} + +#[test] +fn eliminate_redundant_null_check_on_count() { + let sql = "\ + SELECT col_int32, count(*) c + FROM test + GROUP BY col_int32 + HAVING c IS NOT NULL"; + let plan = test_sql(sql).unwrap(); + let expected = "\ + Projection: test.col_int32, count(*) AS c\ + \n Aggregate: groupBy=[[test.col_int32]], aggr=[[count(Int64(1)) AS count(*)]]\ + \n TableScan: test projection=[col_int32]"; + assert_eq!(expected, format!("{plan}")); +} + +#[test] +fn test_propagate_empty_relation_inner_join_and_unions() { + let sql = "\ + SELECT A.col_int32 FROM test AS A \ + INNER JOIN ( \ + SELECT col_int32 FROM test WHERE 1 = 0 \ + ) AS B ON A.col_int32 = B.col_int32 \ + UNION ALL \ + SELECT test.col_int32 FROM test WHERE 1 = 1 \ + UNION ALL \ + SELECT test.col_int32 FROM test WHERE 0 = 0 \ + UNION ALL \ + SELECT test.col_int32 FROM test WHERE test.col_int32 < 0 \ + UNION ALL \ + SELECT test.col_int32 FROM test WHERE 1 = 0"; + + let plan = test_sql(sql).unwrap(); + let expected = "\ + Union\ + \n TableScan: test projection=[col_int32]\ + \n TableScan: test projection=[col_int32]\ + \n Filter: test.col_int32 < Int32(0)\ + \n TableScan: test projection=[col_int32]"; + assert_eq!(expected, format!("{plan}")); +} + +#[test] +fn select_wildcard_with_repeated_column() { + let sql = "SELECT *, col_int32 FROM test"; + let err = test_sql(sql).expect_err("query should have failed"); + assert_eq!( + "Schema error: Schema contains duplicate qualified field name test.col_int32", + err.strip_backtrace() + ); +} + +#[test] +fn select_wildcard_with_repeated_column_but_is_aliased() { + let sql = "SELECT *, col_int32 as col_32 FROM test"; + + let plan = test_sql(sql).unwrap(); + let expected = "Projection: test.col_int32, test.col_uint32, test.col_utf8, test.col_date32, test.col_date64, test.col_ts_nano_none, test.col_ts_nano_utc, test.col_int32 AS col_32\ + \n TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64, col_ts_nano_none, col_ts_nano_utc]"; + + assert_eq!(expected, format!("{plan}")); +} + +#[test] +fn select_correlated_predicate_subquery_with_uppercase_ident() { + let sql = r#" + SELECT * + FROM + test + WHERE + EXISTS ( + SELECT 1 + FROM (SELECT col_int32 as "COL_INT32", col_uint32 as "COL_UINT32" FROM test) "T1" + WHERE "T1"."COL_INT32" = test.col_int32 + ) + "#; + let plan = test_sql(sql).unwrap(); + let expected = "LeftSemi Join: test.col_int32 = __correlated_sq_1.COL_INT32\ + \n Filter: test.col_int32 IS NOT NULL\ + \n TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64, col_ts_nano_none, col_ts_nano_utc]\ + \n SubqueryAlias: __correlated_sq_1\ + \n SubqueryAlias: T1\ + \n Projection: test.col_int32 AS COL_INT32\ + \n Filter: test.col_int32 IS NOT NULL\ + \n TableScan: test projection=[col_int32]"; + assert_eq!(expected, format!("{plan}")); } fn test_sql(sql: &str) -> Result { @@ -280,15 +391,18 @@ fn test_sql(sql: &str) -> Result { let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ... let ast: Vec = Parser::parse_sql(&dialect, sql).unwrap(); let statement = &ast[0]; - let context_provider = MyContextProvider::default(); + let context_provider = MyContextProvider::default() + .with_udaf(sum_udaf()) + .with_udaf(count_udaf()) + .with_udaf(avg_udaf()); let sql_to_rel = SqlToRel::new(&context_provider); - let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap(); + let plan = sql_to_rel.sql_statement_to_plan(statement.clone())?; let config = OptimizerContext::new().with_skip_failing_rules(false); let analyzer = Analyzer::new(); let optimizer = Optimizer::new(); // analyze and optimize the logical plan - let plan = analyzer.execute_and_check(&plan, config.options(), |_, _| {})?; + let plan = analyzer.execute_and_check(plan, config.options(), |_, _| {})?; optimizer.optimize(plan, &config, observe) } @@ -297,6 +411,15 @@ fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} #[derive(Default)] struct MyContextProvider { options: ConfigOptions, + udafs: HashMap>, +} + +impl MyContextProvider { + fn with_udaf(mut self, udaf: Arc) -> Self { + // TODO: change to to_string() if all the function name is converted to lowercase + self.udafs.insert(udaf.name().to_lowercase(), udaf); + self + } } impl ContextProvider for MyContextProvider { @@ -338,8 +461,8 @@ impl ContextProvider for MyContextProvider { None } - fn get_aggregate_meta(&self, _name: &str) -> Option> { - None + fn get_aggregate_meta(&self, name: &str) -> Option> { + self.udafs.get(name).cloned() } fn get_variable_type(&self, _variable_names: &[String]) -> Option { @@ -354,15 +477,15 @@ impl ContextProvider for MyContextProvider { &self.options } - fn udfs_names(&self) -> Vec { + fn udf_names(&self) -> Vec { Vec::new() } - fn udafs_names(&self) -> Vec { + fn udaf_names(&self) -> Vec { Vec::new() } - fn udwfs_names(&self) -> Vec { + fn udwf_names(&self) -> Vec { Vec::new() } } diff --git a/datafusion/physical-expr-common/Cargo.toml b/datafusion/physical-expr-common/Cargo.toml index d1202c83d526..45ccb08e52e9 100644 --- a/datafusion/physical-expr-common/Cargo.toml +++ b/datafusion/physical-expr-common/Cargo.toml @@ -36,6 +36,9 @@ name = "datafusion_physical_expr_common" path = "src/lib.rs" [dependencies] +ahash = { workspace = true } arrow = { workspace = true } datafusion-common = { workspace = true, default-features = true } -datafusion-expr = { workspace = true } +datafusion-expr-common = { workspace = true } +hashbrown = { workspace = true } +rand = { workspace = true } diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs deleted file mode 100644 index 448af634176a..000000000000 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ /dev/null @@ -1,295 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -pub mod utils; - -use arrow::datatypes::{DataType, Field, Schema}; -use datafusion_common::{not_impl_err, Result}; -use datafusion_expr::type_coercion::aggregates::check_arg_count; -use datafusion_expr::{ - function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator, -}; -use std::fmt::Debug; -use std::{any::Any, sync::Arc}; - -use crate::physical_expr::PhysicalExpr; -use crate::sort_expr::{LexOrdering, PhysicalSortExpr}; - -use self::utils::{down_cast_any_ref, ordering_fields}; - -/// Creates a physical expression of the UDAF, that includes all necessary type coercion. -/// This function errors when `args`' can't be coerced to a valid argument type of the UDAF. -pub fn create_aggregate_expr( - fun: &AggregateUDF, - input_phy_exprs: &[Arc], - sort_exprs: &[Expr], - ordering_req: &[PhysicalSortExpr], - schema: &Schema, - name: impl Into, - ignore_nulls: bool, -) -> Result> { - let input_exprs_types = input_phy_exprs - .iter() - .map(|arg| arg.data_type(schema)) - .collect::>>()?; - - check_arg_count( - fun.name(), - &input_exprs_types, - &fun.signature().type_signature, - )?; - - let ordering_types = ordering_req - .iter() - .map(|e| e.expr.data_type(schema)) - .collect::>>()?; - - let ordering_fields = ordering_fields(ordering_req, &ordering_types); - - Ok(Arc::new(AggregateFunctionExpr { - fun: fun.clone(), - args: input_phy_exprs.to_vec(), - data_type: fun.return_type(&input_exprs_types)?, - name: name.into(), - schema: schema.clone(), - sort_exprs: sort_exprs.to_vec(), - ordering_req: ordering_req.to_vec(), - ignore_nulls, - ordering_fields, - })) -} - -/// An aggregate expression that: -/// * knows its resulting field -/// * knows how to create its accumulator -/// * knows its accumulator's state's field -/// * knows the expressions from whose its accumulator will receive values -/// -/// Any implementation of this trait also needs to implement the -/// `PartialEq` to allows comparing equality between the -/// trait objects. -pub trait AggregateExpr: Send + Sync + Debug + PartialEq { - /// Returns the aggregate expression as [`Any`] so that it can be - /// downcast to a specific implementation. - fn as_any(&self) -> &dyn Any; - - /// the field of the final result of this aggregation. - fn field(&self) -> Result; - - /// the accumulator used to accumulate values from the expressions. - /// the accumulator expects the same number of arguments as `expressions` and must - /// return states with the same description as `state_fields` - fn create_accumulator(&self) -> Result>; - - /// the fields that encapsulate the Accumulator's state - /// the number of fields here equals the number of states that the accumulator contains - fn state_fields(&self) -> Result>; - - /// expressions that are passed to the Accumulator. - /// Single-column aggregations such as `sum` return a single value, others (e.g. `cov`) return many. - fn expressions(&self) -> Vec>; - - /// Order by requirements for the aggregate function - /// By default it is `None` (there is no requirement) - /// Order-sensitive aggregators, such as `FIRST_VALUE(x ORDER BY y)` should implement this - fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - None - } - - /// Human readable name such as `"MIN(c2)"`. The default - /// implementation returns placeholder text. - fn name(&self) -> &str { - "AggregateExpr: default name" - } - - /// If the aggregate expression has a specialized - /// [`GroupsAccumulator`] implementation. If this returns true, - /// `[Self::create_groups_accumulator`] will be called. - fn groups_accumulator_supported(&self) -> bool { - false - } - - /// Return a specialized [`GroupsAccumulator`] that manages state - /// for all groups. - /// - /// For maximum performance, a [`GroupsAccumulator`] should be - /// implemented in addition to [`Accumulator`]. - fn create_groups_accumulator(&self) -> Result> { - not_impl_err!("GroupsAccumulator hasn't been implemented for {self:?} yet") - } - - /// Construct an expression that calculates the aggregate in reverse. - /// Typically the "reverse" expression is itself (e.g. SUM, COUNT). - /// For aggregates that do not support calculation in reverse, - /// returns None (which is the default value). - fn reverse_expr(&self) -> Option> { - None - } - - /// Creates accumulator implementation that supports retract - fn create_sliding_accumulator(&self) -> Result> { - not_impl_err!("Retractable Accumulator hasn't been implemented for {self:?} yet") - } -} - -/// Physical aggregate expression of a UDAF. -#[derive(Debug)] -pub struct AggregateFunctionExpr { - fun: AggregateUDF, - args: Vec>, - /// Output / return type of this aggregate - data_type: DataType, - name: String, - schema: Schema, - // The logical order by expressions - sort_exprs: Vec, - // The physical order by expressions - ordering_req: LexOrdering, - ignore_nulls: bool, - ordering_fields: Vec, -} - -impl AggregateFunctionExpr { - /// Return the `AggregateUDF` used by this `AggregateFunctionExpr` - pub fn fun(&self) -> &AggregateUDF { - &self.fun - } -} - -impl AggregateExpr for AggregateFunctionExpr { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn expressions(&self) -> Vec> { - self.args.clone() - } - - fn state_fields(&self) -> Result> { - self.fun.state_fields( - self.name(), - self.data_type.clone(), - self.ordering_fields.clone(), - ) - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.data_type.clone(), true)) - } - - fn create_accumulator(&self) -> Result> { - let acc_args = AccumulatorArgs::new( - &self.data_type, - &self.schema, - self.ignore_nulls, - &self.sort_exprs, - ); - - self.fun.accumulator(acc_args) - } - - fn create_sliding_accumulator(&self) -> Result> { - let accumulator = self.create_accumulator()?; - - // Accumulators that have window frame startings different - // than `UNBOUNDED PRECEDING`, such as `1 PRECEEDING`, need to - // implement retract_batch method in order to run correctly - // currently in DataFusion. - // - // If this `retract_batches` is not present, there is no way - // to calculate result correctly. For example, the query - // - // ```sql - // SELECT - // SUM(a) OVER(ORDER BY a ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS sum_a - // FROM - // t - // ``` - // - // 1. First sum value will be the sum of rows between `[0, 1)`, - // - // 2. Second sum value will be the sum of rows between `[0, 2)` - // - // 3. Third sum value will be the sum of rows between `[1, 3)`, etc. - // - // Since the accumulator keeps the running sum: - // - // 1. First sum we add to the state sum value between `[0, 1)` - // - // 2. Second sum we add to the state sum value between `[1, 2)` - // (`[0, 1)` is already in the state sum, hence running sum will - // cover `[0, 2)` range) - // - // 3. Third sum we add to the state sum value between `[2, 3)` - // (`[0, 2)` is already in the state sum). Also we need to - // retract values between `[0, 1)` by this way we can obtain sum - // between [1, 3) which is indeed the apropriate range. - // - // When we use `UNBOUNDED PRECEDING` in the query starting - // index will always be 0 for the desired range, and hence the - // `retract_batch` method will not be called. In this case - // having retract_batch is not a requirement. - // - // This approach is a a bit different than window function - // approach. In window function (when they use a window frame) - // they get all the desired range during evaluation. - if !accumulator.supports_retract_batch() { - return not_impl_err!( - "Aggregate can not be used as a sliding accumulator because \ - `retract_batch` is not implemented: {}", - self.name - ); - } - Ok(accumulator) - } - - fn name(&self) -> &str { - &self.name - } - - fn groups_accumulator_supported(&self) -> bool { - self.fun.groups_accumulator_supported() - } - - fn create_groups_accumulator(&self) -> Result> { - self.fun.create_groups_accumulator() - } - - fn order_bys(&self) -> Option<&[PhysicalSortExpr]> { - (!self.ordering_req.is_empty()).then_some(&self.ordering_req) - } -} - -impl PartialEq for AggregateFunctionExpr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.fun == x.fun - && self.args.len() == x.args.len() - && self - .args - .iter() - .zip(x.args.iter()) - .all(|(this_arg, other_arg)| this_arg.eq(other_arg)) - }) - .unwrap_or(false) - } -} diff --git a/datafusion/physical-expr-common/src/aggregate/utils.rs b/datafusion/physical-expr-common/src/aggregate/utils.rs deleted file mode 100644 index 9821ba626b18..000000000000 --- a/datafusion/physical-expr-common/src/aggregate/utils.rs +++ /dev/null @@ -1,69 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::{any::Any, sync::Arc}; - -use arrow::{ - compute::SortOptions, - datatypes::{DataType, Field}, -}; - -use crate::sort_expr::PhysicalSortExpr; - -use super::AggregateExpr; - -/// Downcast a `Box` or `Arc` -/// and return the inner trait object as [`Any`] so -/// that it can be downcast to a specific implementation. -/// -/// This method is used when implementing the `PartialEq` -/// for [`AggregateExpr`] aggregation expressions and allows comparing the equality -/// between the trait objects. -pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any { - if let Some(obj) = any.downcast_ref::>() { - obj.as_any() - } else if let Some(obj) = any.downcast_ref::>() { - obj.as_any() - } else { - any - } -} - -/// Construct corresponding fields for lexicographical ordering requirement expression -pub fn ordering_fields( - ordering_req: &[PhysicalSortExpr], - // Data type of each expression in the ordering requirement - data_types: &[DataType], -) -> Vec { - ordering_req - .iter() - .zip(data_types.iter()) - .map(|(sort_expr, dtype)| { - Field::new( - sort_expr.expr.to_string().as_str(), - dtype.clone(), - // Multi partitions may be empty hence field should be nullable. - true, - ) - }) - .collect() -} - -/// Selects the sort option attribute from all the given `PhysicalSortExpr`s. -pub fn get_sort_options(ordering_req: &[PhysicalSortExpr]) -> Vec { - ordering_req.iter().map(|item| item.options).collect() -} diff --git a/datafusion/physical-expr/src/binary_map.rs b/datafusion/physical-expr-common/src/binary_map.rs similarity index 95% rename from datafusion/physical-expr/src/binary_map.rs rename to datafusion/physical-expr-common/src/binary_map.rs index 0923fcdaeb91..80c4963ae035 100644 --- a/datafusion/physical-expr/src/binary_map.rs +++ b/datafusion/physical-expr-common/src/binary_map.rs @@ -19,20 +19,19 @@ //! StringArray / LargeStringArray / BinaryArray / LargeBinaryArray. use ahash::RandomState; -use arrow_array::cast::AsArray; -use arrow_array::types::{ByteArrayType, GenericBinaryType, GenericStringType}; -use arrow_array::{ - Array, ArrayRef, GenericBinaryArray, GenericStringArray, OffsetSizeTrait, +use arrow::array::cast::AsArray; +use arrow::array::types::{ByteArrayType, GenericBinaryType, GenericStringType}; +use arrow::array::{ + Array, ArrayRef, BooleanBufferBuilder, BufferBuilder, GenericBinaryArray, + GenericStringArray, OffsetSizeTrait, }; -use arrow_buffer::{ - BooleanBufferBuilder, BufferBuilder, NullBuffer, OffsetBuffer, ScalarBuffer, -}; -use arrow_schema::DataType; +use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; +use arrow::datatypes::DataType; use datafusion_common::hash_utils::create_hashes; -use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; +use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt}; use std::any::type_name; use std::fmt::Debug; -use std::mem; +use std::mem::{size_of, swap}; use std::ops::Range; use std::sync::Arc; @@ -41,8 +40,12 @@ use std::sync::Arc; pub enum OutputType { /// `StringArray` or `LargeStringArray` Utf8, + /// `StringViewArray` + Utf8View, /// `BinaryArray` or `LargeBinaryArray` Binary, + /// `BinaryViewArray` + BinaryView, } /// HashSet optimized for storing string or binary values that can produce that @@ -57,7 +60,7 @@ impl ArrowBytesSet { /// Return the contents of this set and replace it with a new empty /// set with the same output type - pub(super) fn take(&mut self) -> Self { + pub fn take(&mut self) -> Self { Self(self.0.take()) } @@ -101,8 +104,9 @@ impl ArrowBytesSet { /// `Binary`, and `LargeBinary`) values that can produce the set of keys on /// output as `GenericBinaryArray` without copies. /// -/// Equivalent to `HashSet` but with better performance for arrow -/// data. +/// Equivalent to `HashSet` but with better performance if you need +/// to emit the keys as an Arrow `StringArray` / `BinaryArray`. For other +/// purposes it is the same as a `HashMap` /// /// # Generic Arguments /// @@ -114,13 +118,13 @@ impl ArrowBytesSet { /// This is a specialized HashMap with the following properties: /// /// 1. Optimized for storing and emitting Arrow byte types (e.g. -/// `StringArray` / `BinaryArray`) very efficiently by minimizing copying of -/// the string values themselves, both when inserting and when emitting the -/// final array. +/// `StringArray` / `BinaryArray`) very efficiently by minimizing copying of +/// the string values themselves, both when inserting and when emitting the +/// final array. /// /// /// 2. Retains the insertion order of entries in the final array. The values are -/// in the same order as they were inserted. +/// in the same order as they were inserted. /// /// Note this structure can be used as a `HashSet` by specifying the value type /// as `()`, as is done by [`ArrowBytesSet`]. @@ -135,18 +139,18 @@ impl ArrowBytesSet { /// "Foo", NULL, "Bar", "TheQuickBrownFox": /// /// * `hashtable` stores entries for each distinct string that has been -/// inserted. The entries contain the payload as well as information about the -/// value (either an offset or the actual bytes, see `Entry` docs for more -/// details) +/// inserted. The entries contain the payload as well as information about the +/// value (either an offset or the actual bytes, see `Entry` docs for more +/// details) /// /// * `offsets` stores offsets into `buffer` for each distinct string value, -/// following the same convention as the offsets in a `StringArray` or -/// `LargeStringArray`. +/// following the same convention as the offsets in a `StringArray` or +/// `LargeStringArray`. /// /// * `buffer` stores the actual byte data /// /// * `null`: stores the index and payload of the null value, in this case the -/// second value (index 1) +/// second value (index 1) /// /// ```text /// ┌───────────────────────────────────┐ ┌─────┐ ┌────┐ @@ -234,7 +238,7 @@ where /// The size, in number of entries, of the initial hash table const INITIAL_MAP_CAPACITY: usize = 128; /// The initial size, in bytes, of the string data -const INITIAL_BUFFER_CAPACITY: usize = 8 * 1024; +pub const INITIAL_BUFFER_CAPACITY: usize = 8 * 1024; impl ArrowBytesMap where V: Debug + PartialEq + Eq + Clone + Copy + Default, @@ -256,7 +260,7 @@ where /// the same output type pub fn take(&mut self) -> Self { let mut new_self = Self::new(self.output_type); - std::mem::swap(self, &mut new_self); + swap(self, &mut new_self); new_self } @@ -319,6 +323,7 @@ where observe_payload_fn, ) } + _ => unreachable!("View types should use `ArrowBytesViewMap`"), }; } @@ -356,7 +361,7 @@ where assert_eq!(values.len(), batch_hashes.len()); for (value, &hash) in values.iter().zip(batch_hashes.iter()) { - // hande null value + // handle null value let Some(value) = value else { let payload = if let Some(&(payload, _offset)) = self.null.as_ref() { payload @@ -440,7 +445,7 @@ where // Put the small values into buffer and offsets so it // appears the output array, and store that offset // so the bytes can be compared if needed - let offset = self.buffer.len(); // offset of start fof data + let offset = self.buffer.len(); // offset of start for data self.buffer.append_slice(value); self.offsets.push(O::usize_as(self.buffer.len())); @@ -517,6 +522,7 @@ where GenericStringArray::new_unchecked(offsets, values, nulls) }) } + _ => unreachable!("View types should use `ArrowBytesViewMap`"), } } @@ -539,7 +545,7 @@ where /// this set, not including `self` pub fn size(&self) -> usize { self.map_size - + self.buffer.capacity() * std::mem::size_of::() + + self.buffer.capacity() * size_of::() + self.offsets.allocated_size() + self.hashes_buffer.allocated_size() } @@ -569,7 +575,7 @@ where } /// Maximum size of a value that can be inlined in the hash table -const SHORT_VALUE_LEN: usize = mem::size_of::(); +const SHORT_VALUE_LEN: usize = size_of::(); /// Entry in the hash table -- see [`ArrowBytesMap`] for more details #[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] @@ -605,8 +611,8 @@ where #[cfg(test)] mod tests { use super::*; - use arrow_array::{BinaryArray, LargeBinaryArray, StringArray}; - use hashbrown::HashMap; + use arrow::array::{BinaryArray, LargeBinaryArray, StringArray}; + use std::collections::HashMap; #[test] fn string_set_empty() { diff --git a/datafusion/physical-expr-common/src/binary_view_map.rs b/datafusion/physical-expr-common/src/binary_view_map.rs new file mode 100644 index 000000000000..c6768a19d30e --- /dev/null +++ b/datafusion/physical-expr-common/src/binary_view_map.rs @@ -0,0 +1,691 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ArrowBytesViewMap`] and [`ArrowBytesViewSet`] for storing maps/sets of values from +//! `StringViewArray`/`BinaryViewArray`. +//! Much of the code is from `binary_map.rs`, but with simpler implementation because we directly use the +//! [`GenericByteViewBuilder`]. +use ahash::RandomState; +use arrow::array::cast::AsArray; +use arrow::array::{Array, ArrayBuilder, ArrayRef, GenericByteViewBuilder}; +use arrow::datatypes::{BinaryViewType, ByteViewType, DataType, StringViewType}; +use datafusion_common::hash_utils::create_hashes; +use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt}; +use std::fmt::Debug; +use std::sync::Arc; + +use crate::binary_map::OutputType; + +/// HashSet optimized for storing string or binary values that can produce that +/// the final set as a `GenericBinaryViewArray` with minimal copies. +#[derive(Debug)] +pub struct ArrowBytesViewSet(ArrowBytesViewMap<()>); + +impl ArrowBytesViewSet { + pub fn new(output_type: OutputType) -> Self { + Self(ArrowBytesViewMap::new(output_type)) + } + + /// Inserts each value from `values` into the set + pub fn insert(&mut self, values: &ArrayRef) { + fn make_payload_fn(_value: Option<&[u8]>) {} + fn observe_payload_fn(_payload: ()) {} + self.0 + .insert_if_new(values, make_payload_fn, observe_payload_fn); + } + + /// Return the contents of this map and replace it with a new empty map with + /// the same output type + pub fn take(&mut self) -> Self { + let mut new_self = Self::new(self.0.output_type); + std::mem::swap(self, &mut new_self); + new_self + } + + /// Converts this set into a `StringViewArray` or `BinaryViewArray` + /// containing each distinct value that was interned. + /// This is done without copying the values. + pub fn into_state(self) -> ArrayRef { + self.0.into_state() + } + + /// Returns the total number of distinct values (including nulls) seen so far + pub fn len(&self) -> usize { + self.0.len() + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + /// returns the total number of distinct values (not including nulls) seen so far + pub fn non_null_len(&self) -> usize { + self.0.non_null_len() + } + + /// Return the total size, in bytes, of memory used to store the data in + /// this set, not including `self` + pub fn size(&self) -> usize { + self.0.size() + } +} + +/// Optimized map for storing Arrow "byte view" types (`StringView`, `BinaryView`) +/// values that can produce the set of keys on +/// output as `GenericBinaryViewArray` without copies. +/// +/// Equivalent to `HashSet` but with better performance if you need +/// to emit the keys as an Arrow `StringViewArray` / `BinaryViewArray`. For other +/// purposes it is the same as a `HashMap` +/// +/// # Generic Arguments +/// +/// * `V`: payload type +/// +/// # Description +/// +/// This is a specialized HashMap with the following properties: +/// +/// 1. Optimized for storing and emitting Arrow byte types (e.g. +/// `StringViewArray` / `BinaryViewArray`) very efficiently by minimizing copying of +/// the string values themselves, both when inserting and when emitting the +/// final array. +/// +/// 2. Retains the insertion order of entries in the final array. The values are +/// in the same order as they were inserted. +/// +/// Note this structure can be used as a `HashSet` by specifying the value type +/// as `()`, as is done by [`ArrowBytesViewSet`]. +/// +/// This map is used by the special `COUNT DISTINCT` aggregate function to +/// store the distinct values, and by the `GROUP BY` operator to store +/// group values when they are a single string array. + +pub struct ArrowBytesViewMap +where + V: Debug + PartialEq + Eq + Clone + Copy + Default, +{ + /// Should the output be StringView or BinaryView? + output_type: OutputType, + /// Underlying hash set for each distinct value + map: hashbrown::raw::RawTable>, + /// Total size of the map in bytes + map_size: usize, + + /// Builder for output array + builder: GenericByteViewBuilder, + /// random state used to generate hashes + random_state: RandomState, + /// buffer that stores hash values (reused across batches to save allocations) + hashes_buffer: Vec, + /// `(payload, null_index)` for the 'null' value, if any + /// NOTE null_index is the logical index in the final array, not the index + /// in the buffer + null: Option<(V, usize)>, +} + +/// The size, in number of entries, of the initial hash table +const INITIAL_MAP_CAPACITY: usize = 512; + +impl ArrowBytesViewMap +where + V: Debug + PartialEq + Eq + Clone + Copy + Default, +{ + pub fn new(output_type: OutputType) -> Self { + Self { + output_type, + map: hashbrown::raw::RawTable::with_capacity(INITIAL_MAP_CAPACITY), + map_size: 0, + builder: GenericByteViewBuilder::new(), + random_state: RandomState::new(), + hashes_buffer: vec![], + null: None, + } + } + + /// Return the contents of this map and replace it with a new empty map with + /// the same output type + pub fn take(&mut self) -> Self { + let mut new_self = Self::new(self.output_type); + std::mem::swap(self, &mut new_self); + new_self + } + + /// Inserts each value from `values` into the map, invoking `payload_fn` for + /// each value if *not* already present, deferring the allocation of the + /// payload until it is needed. + /// + /// Note that this is different than a normal map that would replace the + /// existing entry + /// + /// # Arguments: + /// + /// `values`: array whose values are inserted + /// + /// `make_payload_fn`: invoked for each value that is not already present + /// to create the payload, in order of the values in `values` + /// + /// `observe_payload_fn`: invoked once, for each value in `values`, that was + /// already present in the map, with corresponding payload value. + /// + /// # Returns + /// + /// The payload value for the entry, either the existing value or + /// the newly inserted value + /// + /// # Safety: + /// + /// Note that `make_payload_fn` and `observe_payload_fn` are only invoked + /// with valid values from `values`, not for the `NULL` value. + pub fn insert_if_new( + &mut self, + values: &ArrayRef, + make_payload_fn: MP, + observe_payload_fn: OP, + ) where + MP: FnMut(Option<&[u8]>) -> V, + OP: FnMut(V), + { + // Sanity check array type + match self.output_type { + OutputType::BinaryView => { + assert!(matches!(values.data_type(), DataType::BinaryView)); + self.insert_if_new_inner::( + values, + make_payload_fn, + observe_payload_fn, + ) + } + OutputType::Utf8View => { + assert!(matches!(values.data_type(), DataType::Utf8View)); + self.insert_if_new_inner::( + values, + make_payload_fn, + observe_payload_fn, + ) + } + _ => unreachable!("Utf8/Binary should use `ArrowBytesSet`"), + }; + } + + /// Generic version of [`Self::insert_if_new`] that handles `ByteViewType` + /// (both StringView and BinaryView) + /// + /// Note this is the only function that is generic on [`ByteViewType`], which + /// avoids having to template the entire structure, making the code + /// simpler and understand and reducing code bloat due to duplication. + /// + /// See comments on `insert_if_new` for more details + fn insert_if_new_inner( + &mut self, + values: &ArrayRef, + mut make_payload_fn: MP, + mut observe_payload_fn: OP, + ) where + MP: FnMut(Option<&[u8]>) -> V, + OP: FnMut(V), + B: ByteViewType, + { + // step 1: compute hashes + let batch_hashes = &mut self.hashes_buffer; + batch_hashes.clear(); + batch_hashes.resize(values.len(), 0); + create_hashes(&[values.clone()], &self.random_state, batch_hashes) + // hash is supported for all types and create_hashes only + // returns errors for unsupported types + .unwrap(); + + // step 2: insert each value into the set, if not already present + let values = values.as_byte_view::(); + + // Ensure lengths are equivalent + assert_eq!(values.len(), batch_hashes.len()); + + for (value, &hash) in values.iter().zip(batch_hashes.iter()) { + // handle null value + let Some(value) = value else { + let payload = if let Some(&(payload, _offset)) = self.null.as_ref() { + payload + } else { + let payload = make_payload_fn(None); + let null_index = self.builder.len(); + self.builder.append_null(); + self.null = Some((payload, null_index)); + payload + }; + observe_payload_fn(payload); + continue; + }; + + // get the value as bytes + let value: &[u8] = value.as_ref(); + + let entry = self.map.get_mut(hash, |header| { + let v = self.builder.get_value(header.view_idx); + + if v.len() != value.len() { + return false; + } + + v == value + }); + + let payload = if let Some(entry) = entry { + entry.payload + } else { + // no existing value, make a new one. + let payload = make_payload_fn(Some(value)); + + let inner_view_idx = self.builder.len(); + let new_header = Entry { + view_idx: inner_view_idx, + hash, + payload, + }; + + self.builder.append_value(value); + + self.map + .insert_accounted(new_header, |h| h.hash, &mut self.map_size); + payload + }; + observe_payload_fn(payload); + } + } + + /// Converts this set into a `StringViewArray`, or `BinaryViewArray`, + /// containing each distinct value + /// that was inserted. This is done without copying the values. + /// + /// The values are guaranteed to be returned in the same order in which + /// they were first seen. + pub fn into_state(self) -> ArrayRef { + let mut builder = self.builder; + match self.output_type { + OutputType::BinaryView => { + let array = builder.finish(); + + Arc::new(array) + } + OutputType::Utf8View => { + // SAFETY: + // we asserted the input arrays were all the correct type and + // thus since all the values that went in were valid (e.g. utf8) + // so are all the values that come out + let array = builder.finish(); + let array = unsafe { array.to_string_view_unchecked() }; + Arc::new(array) + } + _ => { + unreachable!("Utf8/Binary should use `ArrowBytesMap`") + } + } + } + + /// Total number of entries (including null, if present) + pub fn len(&self) -> usize { + self.non_null_len() + self.null.map(|_| 1).unwrap_or(0) + } + + /// Is the set empty? + pub fn is_empty(&self) -> bool { + self.map.is_empty() && self.null.is_none() + } + + /// Number of non null entries + pub fn non_null_len(&self) -> usize { + self.map.len() + } + + /// Return the total size, in bytes, of memory used to store the data in + /// this set, not including `self` + pub fn size(&self) -> usize { + self.map_size + + self.builder.allocated_size() + + self.hashes_buffer.allocated_size() + } +} + +impl Debug for ArrowBytesViewMap +where + V: Debug + PartialEq + Eq + Clone + Copy + Default, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ArrowBytesMap") + .field("map", &"") + .field("map_size", &self.map_size) + .field("view_builder", &self.builder) + .field("random_state", &self.random_state) + .field("hashes_buffer", &self.hashes_buffer) + .finish() + } +} + +/// Entry in the hash table -- see [`ArrowBytesViewMap`] for more details +#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] +struct Entry +where + V: Debug + PartialEq + Eq + Clone + Copy + Default, +{ + /// The idx into the views array + view_idx: usize, + + hash: u64, + + /// value stored by the entry + payload: V, +} + +#[cfg(test)] +mod tests { + use arrow::array::{BinaryViewArray, GenericByteViewArray, StringViewArray}; + use hashbrown::HashMap; + + use super::*; + + // asserts that the set contains the expected strings, in the same order + fn assert_set(set: ArrowBytesViewSet, expected: &[Option<&str>]) { + let strings = set.into_state(); + let strings = strings.as_string_view(); + let state = strings.into_iter().collect::>(); + assert_eq!(state, expected); + } + + #[test] + fn string_view_set_empty() { + let mut set = ArrowBytesViewSet::new(OutputType::Utf8View); + let array: ArrayRef = Arc::new(StringViewArray::new_null(0)); + set.insert(&array); + assert_eq!(set.len(), 0); + assert_eq!(set.non_null_len(), 0); + assert_set(set, &[]); + } + + #[test] + fn string_view_set_one_null() { + let mut set = ArrowBytesViewSet::new(OutputType::Utf8View); + let array: ArrayRef = Arc::new(StringViewArray::new_null(1)); + set.insert(&array); + assert_eq!(set.len(), 1); + assert_eq!(set.non_null_len(), 0); + assert_set(set, &[None]); + } + + #[test] + fn string_view_set_many_null() { + let mut set = ArrowBytesViewSet::new(OutputType::Utf8View); + let array: ArrayRef = Arc::new(StringViewArray::new_null(11)); + set.insert(&array); + assert_eq!(set.len(), 1); + assert_eq!(set.non_null_len(), 0); + assert_set(set, &[None]); + } + + #[test] + fn test_string_view_set_basic() { + // basic test for mixed small and large string values + let values = GenericByteViewArray::from(vec![ + Some("a"), + Some("b"), + Some("CXCCCCCCCCAABB"), // 14 bytes + Some(""), + Some("cbcxx"), // 5 bytes + None, + Some("AAAAAAAA"), // 8 bytes + Some("BBBBBQBBBAAA"), // 12 bytes + Some("a"), + Some("cbcxx"), + Some("b"), + Some("cbcxx"), + Some(""), + None, + Some("BBBBBQBBBAAA"), + Some("BBBBBQBBBAAA"), + Some("AAAAAAAA"), + Some("CXCCCCCCCCAABB"), + ]); + + let mut set = ArrowBytesViewSet::new(OutputType::Utf8View); + let array: ArrayRef = Arc::new(values); + set.insert(&array); + // values mut appear be in the order they were inserted + assert_set( + set, + &[ + Some("a"), + Some("b"), + Some("CXCCCCCCCCAABB"), + Some(""), + Some("cbcxx"), + None, + Some("AAAAAAAA"), + Some("BBBBBQBBBAAA"), + ], + ); + } + + #[test] + fn test_string_set_non_utf8() { + // basic test for mixed small and large string values + let values = GenericByteViewArray::from(vec![ + Some("a"), + Some("✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥"), + Some("🔥"), + Some("✨✨✨"), + Some("foobarbaz"), + Some("🔥"), + Some("✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥"), + ]); + + let mut set = ArrowBytesViewSet::new(OutputType::Utf8View); + let array: ArrayRef = Arc::new(values); + set.insert(&array); + // strings mut appear be in the order they were inserted + assert_set( + set, + &[ + Some("a"), + Some("✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥"), + Some("🔥"), + Some("✨✨✨"), + Some("foobarbaz"), + ], + ); + } + + // Test use of binary output type + #[test] + fn test_binary_set() { + let v: Vec> = vec![ + Some(b"a"), + Some(b"CXCCCCCCCCCCCCC"), + None, + Some(b"CXCCCCCCCCCCCCC"), + ]; + let values: ArrayRef = Arc::new(BinaryViewArray::from(v)); + + let expected: Vec> = + vec![Some(b"a"), Some(b"CXCCCCCCCCCCCCC"), None]; + let expected: ArrayRef = Arc::new(GenericByteViewArray::from(expected)); + + let mut set = ArrowBytesViewSet::new(OutputType::BinaryView); + set.insert(&values); + assert_eq!(&set.into_state(), &expected); + } + + // inserting strings into the set does not increase reported memory + #[test] + fn test_string_set_memory_usage() { + let strings1 = StringViewArray::from(vec![ + Some("a"), + Some("b"), + Some("CXCCCCCCCCCCC"), // 13 bytes + Some("AAAAAAAA"), // 8 bytes + Some("BBBBBQBBB"), // 9 bytes + ]); + let total_strings1_len = strings1 + .iter() + .map(|s| s.map(|s| s.len()).unwrap_or(0)) + .sum::(); + let values1: ArrayRef = Arc::new(StringViewArray::from(strings1)); + + // Much larger strings in strings2 + let strings2 = StringViewArray::from(vec![ + "FOO".repeat(1000), + "BAR larger than 12 bytes.".repeat(100_000), + "more unique.".repeat(1000), + "more unique2.".repeat(1000), + "FOO".repeat(3000), + ]); + let total_strings2_len = strings2 + .iter() + .map(|s| s.map(|s| s.len()).unwrap_or(0)) + .sum::(); + let values2: ArrayRef = Arc::new(StringViewArray::from(strings2)); + + let mut set = ArrowBytesViewSet::new(OutputType::Utf8View); + let size_empty = set.size(); + + set.insert(&values1); + let size_after_values1 = set.size(); + assert!(size_empty < size_after_values1); + assert!( + size_after_values1 > total_strings1_len, + "expect {size_after_values1} to be more than {total_strings1_len}" + ); + assert!(size_after_values1 < total_strings1_len + total_strings2_len); + + // inserting the same strings should not affect the size + set.insert(&values1); + assert_eq!(set.size(), size_after_values1); + assert_eq!(set.len(), 5); + + // inserting the large strings should increase the reported size + set.insert(&values2); + let size_after_values2 = set.size(); + assert!(size_after_values2 > size_after_values1); + + assert_eq!(set.len(), 10); + } + + #[derive(Debug, PartialEq, Eq, Default, Clone, Copy)] + struct TestPayload { + // store the string value to check against input + index: usize, // store the index of the string (each new string gets the next sequential input) + } + + /// Wraps an [`ArrowBytesViewMap`], validating its invariants + struct TestMap { + map: ArrowBytesViewMap, + // stores distinct strings seen, in order + strings: Vec>, + // map strings to index in strings + indexes: HashMap, usize>, + } + + impl TestMap { + /// creates a map with TestPayloads for the given strings and then + /// validates the payloads + fn new() -> Self { + Self { + map: ArrowBytesViewMap::new(OutputType::Utf8View), + strings: vec![], + indexes: HashMap::new(), + } + } + + /// Inserts strings into the map + fn insert(&mut self, strings: &[Option<&str>]) { + let string_array = StringViewArray::from(strings.to_vec()); + let arr: ArrayRef = Arc::new(string_array); + + let mut next_index = self.indexes.len(); + let mut actual_new_strings = vec![]; + let mut actual_seen_indexes = vec![]; + // update self with new values, keeping track of newly added values + for str in strings { + let str = str.map(|s| s.to_string()); + let index = self.indexes.get(&str).cloned().unwrap_or_else(|| { + actual_new_strings.push(str.clone()); + let index = self.strings.len(); + self.strings.push(str.clone()); + self.indexes.insert(str, index); + index + }); + actual_seen_indexes.push(index); + } + + // insert the values into the map, recording what we did + let mut seen_new_strings = vec![]; + let mut seen_indexes = vec![]; + self.map.insert_if_new( + &arr, + |s| { + let value = s + .map(|s| String::from_utf8(s.to_vec()).expect("Non utf8 string")); + let index = next_index; + next_index += 1; + seen_new_strings.push(value); + TestPayload { index } + }, + |payload| { + seen_indexes.push(payload.index); + }, + ); + + assert_eq!(actual_seen_indexes, seen_indexes); + assert_eq!(actual_new_strings, seen_new_strings); + } + + /// Call `self.map.into_array()` validating that the strings are in the same + /// order as they were inserted + fn into_array(self) -> ArrayRef { + let Self { + map, + strings, + indexes: _, + } = self; + + let arr = map.into_state(); + let expected: ArrayRef = Arc::new(StringViewArray::from(strings)); + assert_eq!(&arr, &expected); + arr + } + } + + #[test] + fn test_map() { + let input = vec![ + // Note mix of short/long strings + Some("A"), + Some("bcdefghijklmnop1234567"), + Some("X"), + Some("Y"), + None, + Some("qrstuvqxyzhjwya"), + Some("✨🔥"), + Some("🔥"), + Some("🔥🔥🔥🔥🔥🔥"), + ]; + + let mut test_map = TestMap::new(); + test_map.insert(&input); + test_map.insert(&input); // put it in twice + let expected_output: ArrayRef = Arc::new(StringViewArray::from(input)); + assert_eq!(&test_map.into_array(), &expected_output); + } +} diff --git a/datafusion/physical-expr-common/src/datum.rs b/datafusion/physical-expr-common/src/datum.rs new file mode 100644 index 000000000000..c47ec9d75d50 --- /dev/null +++ b/datafusion/physical-expr-common/src/datum.rs @@ -0,0 +1,162 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// UnLt required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::BooleanArray; +use arrow::array::{make_comparator, ArrayRef, Datum}; +use arrow::buffer::NullBuffer; +use arrow::compute::SortOptions; +use arrow::error::ArrowError; +use datafusion_common::DataFusionError; +use datafusion_common::{arrow_datafusion_err, internal_err}; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr_common::columnar_value::ColumnarValue; +use datafusion_expr_common::operator::Operator; +use std::sync::Arc; + +/// Applies a binary [`Datum`] kernel `f` to `lhs` and `rhs` +/// +/// This maps arrow-rs' [`Datum`] kernels to DataFusion's [`ColumnarValue`] abstraction +pub fn apply( + lhs: &ColumnarValue, + rhs: &ColumnarValue, + f: impl Fn(&dyn Datum, &dyn Datum) -> Result, +) -> Result { + match (&lhs, &rhs) { + (ColumnarValue::Array(left), ColumnarValue::Array(right)) => { + Ok(ColumnarValue::Array(f(&left.as_ref(), &right.as_ref())?)) + } + (ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => Ok( + ColumnarValue::Array(f(&left.to_scalar()?, &right.as_ref())?), + ), + (ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => Ok( + ColumnarValue::Array(f(&left.as_ref(), &right.to_scalar()?)?), + ), + (ColumnarValue::Scalar(left), ColumnarValue::Scalar(right)) => { + let array = f(&left.to_scalar()?, &right.to_scalar()?)?; + let scalar = ScalarValue::try_from_array(array.as_ref(), 0)?; + Ok(ColumnarValue::Scalar(scalar)) + } + } +} + +/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs` +pub fn apply_cmp( + lhs: &ColumnarValue, + rhs: &ColumnarValue, + f: impl Fn(&dyn Datum, &dyn Datum) -> Result, +) -> Result { + apply(lhs, rhs, |l, r| Ok(Arc::new(f(l, r)?))) +} + +/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs` for nested type like +/// List, FixedSizeList, LargeList, Struct, Union, Map, or a dictionary of a nested type +pub fn apply_cmp_for_nested( + op: Operator, + lhs: &ColumnarValue, + rhs: &ColumnarValue, +) -> Result { + if matches!( + op, + Operator::Eq + | Operator::NotEq + | Operator::Lt + | Operator::Gt + | Operator::LtEq + | Operator::GtEq + | Operator::IsDistinctFrom + | Operator::IsNotDistinctFrom + ) { + apply(lhs, rhs, |l, r| { + Ok(Arc::new(compare_op_for_nested(op, l, r)?)) + }) + } else { + internal_err!("invalid operator for nested") + } +} + +/// Compare with eq with either nested or non-nested +pub fn compare_with_eq( + lhs: &dyn Datum, + rhs: &dyn Datum, + is_nested: bool, +) -> Result { + if is_nested { + compare_op_for_nested(Operator::Eq, lhs, rhs) + } else { + arrow::compute::kernels::cmp::eq(lhs, rhs).map_err(|e| arrow_datafusion_err!(e)) + } +} + +/// Compare on nested type List, Struct, and so on +pub fn compare_op_for_nested( + op: Operator, + lhs: &dyn Datum, + rhs: &dyn Datum, +) -> Result { + let (l, is_l_scalar) = lhs.get(); + let (r, is_r_scalar) = rhs.get(); + let l_len = l.len(); + let r_len = r.len(); + + if l_len != r_len && !is_l_scalar && !is_r_scalar { + return internal_err!("len mismatch"); + } + + let len = match is_l_scalar { + true => r_len, + false => l_len, + }; + + // fast path, if compare with one null and operator is not 'distinct', then we can return null array directly + if !matches!(op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom) + && (is_l_scalar && l.null_count() == 1 || is_r_scalar && r.null_count() == 1) + { + return Ok(BooleanArray::new_null(len)); + } + + // TODO: make SortOptions configurable + // we choose the default behaviour from arrow-rs which has null-first that follow spark's behaviour + let cmp = make_comparator(l, r, SortOptions::default())?; + + let cmp_with_op = |i, j| match op { + Operator::Eq | Operator::IsNotDistinctFrom => cmp(i, j).is_eq(), + Operator::Lt => cmp(i, j).is_lt(), + Operator::Gt => cmp(i, j).is_gt(), + Operator::LtEq => !cmp(i, j).is_gt(), + Operator::GtEq => !cmp(i, j).is_lt(), + Operator::NotEq | Operator::IsDistinctFrom => !cmp(i, j).is_eq(), + _ => unreachable!("unexpected operator found"), + }; + + let values = match (is_l_scalar, is_r_scalar) { + (false, false) => (0..len).map(|i| cmp_with_op(i, i)).collect(), + (true, false) => (0..len).map(|i| cmp_with_op(0, i)).collect(), + (false, true) => (0..len).map(|i| cmp_with_op(i, 0)).collect(), + (true, true) => std::iter::once(cmp_with_op(0, 0)).collect(), + }; + + // Distinct understand how to compare with NULL + // i.e NULL is distinct from NULL -> false + if matches!(op, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom) { + Ok(BooleanArray::new(values, None)) + } else { + // If one of the side is NULL, we returns NULL + // i.e. NULL eq NULL -> NULL + let nulls = NullBuffer::union(l.nulls(), r.nulls()); + Ok(BooleanArray::new(values, nulls)) + } +} diff --git a/datafusion/physical-expr-common/src/expressions/column.rs b/datafusion/physical-expr-common/src/expressions/column.rs deleted file mode 100644 index 2cd52d6332fb..000000000000 --- a/datafusion/physical-expr-common/src/expressions/column.rs +++ /dev/null @@ -1,137 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Column expression - -use std::any::Any; -use std::hash::{Hash, Hasher}; -use std::sync::Arc; - -use arrow::{ - datatypes::{DataType, Schema}, - record_batch::RecordBatch, -}; -use datafusion_common::{internal_err, Result}; -use datafusion_expr::ColumnarValue; - -use crate::physical_expr::{down_cast_any_ref, PhysicalExpr}; - -/// Represents the column at a given index in a RecordBatch -#[derive(Debug, Hash, PartialEq, Eq, Clone)] -pub struct Column { - name: String, - index: usize, -} - -impl Column { - /// Create a new column expression - pub fn new(name: &str, index: usize) -> Self { - Self { - name: name.to_owned(), - index, - } - } - - /// Create a new column expression based on column name and schema - pub fn new_with_schema(name: &str, schema: &Schema) -> Result { - Ok(Column::new(name, schema.index_of(name)?)) - } - - /// Get the column name - pub fn name(&self) -> &str { - &self.name - } - - /// Get the column index - pub fn index(&self) -> usize { - self.index - } -} - -impl std::fmt::Display for Column { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}@{}", self.name, self.index) - } -} - -impl PhysicalExpr for Column { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn std::any::Any { - self - } - - /// Get the data type of this expression, given the schema of the input - fn data_type(&self, input_schema: &Schema) -> Result { - self.bounds_check(input_schema)?; - Ok(input_schema.field(self.index).data_type().clone()) - } - - /// Decide whehter this expression is nullable, given the schema of the input - fn nullable(&self, input_schema: &Schema) -> Result { - self.bounds_check(input_schema)?; - Ok(input_schema.field(self.index).is_nullable()) - } - - /// Evaluate the expression - fn evaluate(&self, batch: &RecordBatch) -> Result { - self.bounds_check(batch.schema().as_ref())?; - Ok(ColumnarValue::Array(batch.column(self.index).clone())) - } - - fn children(&self) -> Vec> { - vec![] - } - - fn with_new_children( - self: Arc, - _children: Vec>, - ) -> Result> { - Ok(self) - } - - fn dyn_hash(&self, state: &mut dyn Hasher) { - let mut s = state; - self.hash(&mut s); - } -} - -impl PartialEq for Column { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| self == x) - .unwrap_or(false) - } -} - -impl Column { - fn bounds_check(&self, input_schema: &Schema) -> Result<()> { - if self.index < input_schema.fields.len() { - Ok(()) - } else { - internal_err!( - "PhysicalExpr Column references column '{}' at index {} (zero-based) but input schema only has {} columns: {:?}", - self.name, - self.index, input_schema.fields.len(), input_schema.fields().iter().map(|f| f.name().clone()).collect::>()) - } - } -} - -/// Create a column expression -pub fn col(name: &str, schema: &Schema) -> Result> { - Ok(Arc::new(Column::new_with_schema(name, schema)?)) -} diff --git a/datafusion/physical-expr-common/src/lib.rs b/datafusion/physical-expr-common/src/lib.rs index 53e3134a1b05..7e2ea0c49397 100644 --- a/datafusion/physical-expr-common/src/lib.rs +++ b/datafusion/physical-expr-common/src/lib.rs @@ -15,10 +15,15 @@ // specific language governing permissions and limitations // under the License. -pub mod aggregate; -pub mod expressions; +//! Physical Expr Common packages for [DataFusion] +//! This package contains high level PhysicalExpr trait +//! +//! [DataFusion]: + +pub mod binary_map; +pub mod binary_view_map; +pub mod datum; pub mod physical_expr; pub mod sort_expr; -pub mod sort_properties; pub mod tree_node; pub mod utils; diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index be6358e73c99..cc725cf2cefb 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -16,24 +16,42 @@ // under the License. use std::any::Any; -use std::fmt::{Debug, Display}; +use std::fmt::{Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; use std::sync::Arc; +use crate::utils::scatter; + use arrow::array::BooleanArray; use arrow::compute::filter_record_batch; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use datafusion_common::utils::DataPtr; use datafusion_common::{internal_err, not_impl_err, Result}; -use datafusion_expr::interval_arithmetic::Interval; -use datafusion_expr::ColumnarValue; +use datafusion_expr_common::columnar_value::ColumnarValue; +use datafusion_expr_common::interval_arithmetic::Interval; +use datafusion_expr_common::sort_properties::ExprProperties; -use crate::sort_properties::SortProperties; -use crate::utils::scatter; - -/// See [create_physical_expr](https://docs.rs/datafusion/latest/datafusion/physical_expr/fn.create_physical_expr.html) -/// for examples of creating `PhysicalExpr` from `Expr` +/// [`PhysicalExpr`]s represent expressions such as `A + 1` or `CAST(c1 AS int)`. +/// +/// `PhysicalExpr` knows its type, nullability and can be evaluated directly on +/// a [`RecordBatch`] (see [`Self::evaluate`]). +/// +/// `PhysicalExpr` are the physical counterpart to [`Expr`] used in logical +/// planning. They are typically created from [`Expr`] by a [`PhysicalPlanner`] +/// invoked from a higher level API +/// +/// Some important examples of `PhysicalExpr` are: +/// * [`Column`]: Represents a column at a given index in a RecordBatch +/// +/// To create `PhysicalExpr` from `Expr`, see +/// * [`SessionContext::create_physical_expr`]: A high level API +/// * [`create_physical_expr`]: A low level API +/// +/// [`SessionContext::create_physical_expr`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.create_physical_expr +/// [`PhysicalPlanner`]: https://docs.rs/datafusion/latest/datafusion/physical_planner/trait.PhysicalPlanner.html +/// [`Expr`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Expr.html +/// [`create_physical_expr`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/fn.create_physical_expr.html +/// [`Column`]: https://docs.rs/datafusion/latest/datafusion/physical_expr/expressions/struct.Column.html pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq { /// Returns the physical expression as [`Any`] so that it can be /// downcast to a specific implementation. @@ -66,7 +84,7 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq { } /// Get a list of child PhysicalExpr that provide the input for this expr. - fn children(&self) -> Vec>; + fn children(&self) -> Vec<&Arc>; /// Returns a new PhysicalExpr where all children were replaced by new exprs. fn with_new_children( @@ -80,7 +98,7 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq { /// # Arguments /// /// * `children` are the intervals for the children (inputs) of this - /// expression. + /// expression. /// /// # Example /// @@ -113,8 +131,8 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq { /// /// If the expression is `a + b`, the current `interval` is `[4, 5]` and the /// inputs `a` and `b` are respectively given as `[0, 2]` and `[-∞, 4]`, then - /// propagation would would return `[0, 2]` and `[2, 4]` as `b` must be at - /// least `2` to make the output at least `4`. + /// propagation would return `[0, 2]` and `[2, 4]` as `b` must be at least + /// `2` to make the output at least `4`. fn propagate_constraints( &self, _interval: &Interval, @@ -155,17 +173,13 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq { /// directly because it must remain object safe. fn dyn_hash(&self, _state: &mut dyn Hasher); - /// The order information of a PhysicalExpr can be estimated from its children. - /// This is especially helpful for projection expressions. If we can ensure that the - /// order of a PhysicalExpr to project matches with the order of SortExec, we can - /// eliminate that SortExecs. - /// - /// By recursively calling this function, we can obtain the overall order - /// information of the PhysicalExpr. Since `SortOptions` cannot fully handle - /// the propagation of unordered columns and literals, the `SortProperties` - /// struct is used. - fn get_ordering(&self, _children: &[SortProperties]) -> SortProperties { - SortProperties::Unordered + /// Calculates the properties of this [`PhysicalExpr`] based on its + /// children's properties (i.e. order and range), recursively aggregating + /// the information from its children. In cases where the [`PhysicalExpr`] + /// has no children (e.g., `Literal` or `Column`), these properties should + /// be specified externally, as the function defaults to unknown properties. + fn get_properties(&self, _children: &[ExprProperties]) -> Result { + Ok(ExprProperties::new_unknown()) } } @@ -188,7 +202,7 @@ pub fn with_new_children_if_necessary( || children .iter() .zip(old_children.iter()) - .any(|(c1, c2)| !Arc::data_ptr_eq(c1, c2)) + .any(|(c1, c2)| !Arc::ptr_eq(c1, c2)) { Ok(expr.with_new_children(children)?) } else { @@ -209,3 +223,25 @@ pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any { any } } + +/// Returns [`Display`] able a list of [`PhysicalExpr`] +/// +/// Example output: `[a + 1, b]` +pub fn format_physical_expr_list(exprs: &[Arc]) -> impl Display + '_ { + struct DisplayWrapper<'a>(&'a [Arc]); + impl<'a> Display for DisplayWrapper<'a> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let mut iter = self.0.iter(); + write!(f, "[")?; + if let Some(expr) = iter.next() { + write!(f, "{}", expr)?; + } + for expr in iter { + write!(f, ", {}", expr)?; + } + write!(f, "]")?; + Ok(()) + } + } + DisplayWrapper(exprs) +} diff --git a/datafusion/physical-expr-common/src/sort_expr.rs b/datafusion/physical-expr-common/src/sort_expr.rs index 1e1187212d96..addf2fbfca0c 100644 --- a/datafusion/physical-expr-common/src/sort_expr.rs +++ b/datafusion/physical-expr-common/src/sort_expr.rs @@ -17,19 +17,66 @@ //! Sort expressions -use std::fmt::Display; +use crate::physical_expr::PhysicalExpr; +use std::fmt; +use std::fmt::{Display, Formatter}; use std::hash::{Hash, Hasher}; +use std::ops::{Deref, Index, Range, RangeFrom, RangeTo}; use std::sync::Arc; +use std::vec::IntoIter; use arrow::compute::kernels::sort::{SortColumn, SortOptions}; use arrow::datatypes::Schema; use arrow::record_batch::RecordBatch; use datafusion_common::Result; -use datafusion_expr::ColumnarValue; - -use crate::physical_expr::PhysicalExpr; +use datafusion_expr_common::columnar_value::ColumnarValue; /// Represents Sort operation for a column in a RecordBatch +/// +/// Example: +/// ``` +/// # use std::any::Any; +/// # use std::fmt::Display; +/// # use std::hash::Hasher; +/// # use std::sync::Arc; +/// # use arrow::array::RecordBatch; +/// # use datafusion_common::Result; +/// # use arrow::compute::SortOptions; +/// # use arrow::datatypes::{DataType, Schema}; +/// # use datafusion_expr_common::columnar_value::ColumnarValue; +/// # use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +/// # use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; +/// # // this crate doesn't have a physical expression implementation +/// # // so make a really simple one +/// # #[derive(Clone, Debug, PartialEq, Eq, Hash)] +/// # struct MyPhysicalExpr; +/// # impl PhysicalExpr for MyPhysicalExpr { +/// # fn as_any(&self) -> &dyn Any {todo!() } +/// # fn data_type(&self, input_schema: &Schema) -> Result {todo!()} +/// # fn nullable(&self, input_schema: &Schema) -> Result {todo!() } +/// # fn evaluate(&self, batch: &RecordBatch) -> Result {todo!() } +/// # fn children(&self) -> Vec<&Arc> {todo!()} +/// # fn with_new_children(self: Arc, children: Vec>) -> Result> {todo!()} +/// # fn dyn_hash(&self, _state: &mut dyn Hasher) {todo!()} +/// # } +/// # impl Display for MyPhysicalExpr { +/// # fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "a") } +/// # } +/// # impl PartialEq for MyPhysicalExpr { +/// # fn eq(&self, _other: &dyn Any) -> bool { true } +/// # } +/// # fn col(name: &str) -> Arc { Arc::new(MyPhysicalExpr) } +/// // Sort by a ASC +/// let options = SortOptions::default(); +/// let sort_expr = PhysicalSortExpr::new(col("a"), options); +/// assert_eq!(sort_expr.to_string(), "a ASC"); +/// +/// // Sort by a DESC NULLS LAST +/// let sort_expr = PhysicalSortExpr::new_default(col("a")) +/// .desc() +/// .nulls_last(); +/// assert_eq!(sort_expr.to_string(), "a DESC NULLS LAST"); +/// ``` #[derive(Clone, Debug)] pub struct PhysicalSortExpr { /// Physical expression representing the column to sort @@ -38,6 +85,49 @@ pub struct PhysicalSortExpr { pub options: SortOptions, } +impl PhysicalSortExpr { + /// Create a new PhysicalSortExpr + pub fn new(expr: Arc, options: SortOptions) -> Self { + Self { expr, options } + } + + /// Create a new PhysicalSortExpr with default [`SortOptions`] + pub fn new_default(expr: Arc) -> Self { + Self::new(expr, SortOptions::default()) + } + + /// Set the sort sort options to ASC + pub fn asc(mut self) -> Self { + self.options.descending = false; + self + } + + /// Set the sort sort options to DESC + pub fn desc(mut self) -> Self { + self.options.descending = true; + self + } + + /// Set the sort sort options to NULLS FIRST + pub fn nulls_first(mut self) -> Self { + self.options.nulls_first = true; + self + } + + /// Set the sort sort options to NULLS LAST + pub fn nulls_last(mut self) -> Self { + self.options.nulls_first = false; + self + } +} + +/// Access the PhysicalSortExpr as a PhysicalExpr +impl AsRef for PhysicalSortExpr { + fn as_ref(&self) -> &(dyn PhysicalExpr + 'static) { + self.expr.as_ref() + } +} + impl PartialEq for PhysicalSortExpr { fn eq(&self, other: &PhysicalSortExpr) -> bool { self.options == other.options && self.expr.eq(&other.expr) @@ -53,8 +143,8 @@ impl Hash for PhysicalSortExpr { } } -impl std::fmt::Display for PhysicalSortExpr { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { +impl Display for PhysicalSortExpr { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { write!(f, "{} {}", self.expr, to_str(&self.options)) } } @@ -94,26 +184,6 @@ impl PhysicalSortExpr { .map_or(true, |opts| self.options.descending == opts.descending) } } - - /// Returns a [`Display`]able list of `PhysicalSortExpr`. - pub fn format_list(input: &[PhysicalSortExpr]) -> impl Display + '_ { - struct DisplayableList<'a>(&'a [PhysicalSortExpr]); - impl<'a> Display for DisplayableList<'a> { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - let mut first = true; - for sort_expr in self.0 { - if first { - first = false; - } else { - write!(f, ",")?; - } - write!(f, "{}", sort_expr)?; - } - Ok(()) - } - } - DisplayableList(input) - } } /// Represents sort requirement associated with a plan @@ -154,10 +224,7 @@ impl From for PhysicalSortExpr { descending: false, nulls_first: false, }); - PhysicalSortExpr { - expr: value.expr, - options, - } + PhysicalSortExpr::new(value.expr, options) } } @@ -173,13 +240,37 @@ impl PartialEq for PhysicalSortRequirement { } } -impl std::fmt::Display for PhysicalSortRequirement { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { +impl Display for PhysicalSortRequirement { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { let opts_string = self.options.as_ref().map_or("NA", to_str); write!(f, "{} {}", self.expr, opts_string) } } +/// Writes a list of [`PhysicalSortRequirement`]s to a `std::fmt::Formatter`. +/// +/// Example output: `[a + 1, b]` +pub fn format_physical_sort_requirement_list( + exprs: &[PhysicalSortRequirement], +) -> impl Display + '_ { + struct DisplayWrapper<'a>(&'a [PhysicalSortRequirement]); + impl<'a> Display for DisplayWrapper<'a> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let mut iter = self.0.iter(); + write!(f, "[")?; + if let Some(expr) = iter.next() { + write!(f, "{}", expr)?; + } + for expr in iter { + write!(f, ", {}", expr)?; + } + write!(f, "]")?; + Ok(()) + } + } + DisplayWrapper(exprs) +} + impl PhysicalSortRequirement { /// Creates a new requirement. /// @@ -217,12 +308,14 @@ impl PhysicalSortRequirement { /// [`ExecutionPlan::required_input_ordering`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html#method.required_input_ordering pub fn from_sort_exprs<'a>( ordering: impl IntoIterator, - ) -> Vec { - ordering - .into_iter() - .cloned() - .map(PhysicalSortRequirement::from) - .collect() + ) -> LexRequirement { + LexRequirement::new( + ordering + .into_iter() + .cloned() + .map(PhysicalSortRequirement::from) + .collect(), + ) } /// Converts an iterator of [`PhysicalSortRequirement`] into a Vec @@ -233,7 +326,7 @@ impl PhysicalSortRequirement { /// default ordering `ASC, NULLS LAST` if given (see the `PhysicalSortExpr::from`). pub fn to_sort_exprs( requirements: impl IntoIterator, - ) -> Vec { + ) -> LexOrdering { requirements .into_iter() .map(PhysicalSortExpr::from) @@ -252,17 +345,205 @@ fn to_str(options: &SortOptions) -> &str { } } -///`LexOrdering` is an alias for the type `Vec`, which represents +///`LexOrdering` contains a `Vec`, which represents /// a lexicographical ordering. -pub type LexOrdering = Vec; +#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] +pub struct LexOrdering { + pub inner: Vec, +} + +impl LexOrdering { + // Creates a new [`LexOrdering`] from a vector + pub fn new(inner: Vec) -> Self { + Self { inner } + } + + pub fn as_ref(&self) -> LexOrderingRef { + &self.inner + } + + pub fn capacity(&self) -> usize { + self.inner.capacity() + } + + pub fn clear(&mut self) { + self.inner.clear() + } + + pub fn contains(&self, expr: &PhysicalSortExpr) -> bool { + self.inner.contains(expr) + } + + pub fn extend>(&mut self, iter: I) { + self.inner.extend(iter) + } + + pub fn from_ref(lex_ordering_ref: LexOrderingRef) -> Self { + Self::new(lex_ordering_ref.to_vec()) + } + + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + pub fn iter(&self) -> impl Iterator { + self.inner.iter() + } + + pub fn len(&self) -> usize { + self.inner.len() + } + + pub fn pop(&mut self) -> Option { + self.inner.pop() + } + + pub fn push(&mut self, physical_sort_expr: PhysicalSortExpr) { + self.inner.push(physical_sort_expr) + } + + pub fn retain(&mut self, f: impl FnMut(&PhysicalSortExpr) -> bool) { + self.inner.retain(f) + } + + pub fn truncate(&mut self, len: usize) { + self.inner.truncate(len) + } +} + +impl Deref for LexOrdering { + type Target = [PhysicalSortExpr]; + + fn deref(&self) -> &Self::Target { + self.inner.as_slice() + } +} + +impl Display for LexOrdering { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + let mut first = true; + for sort_expr in &self.inner { + if first { + first = false; + } else { + write!(f, ", ")?; + } + write!(f, "{}", sort_expr)?; + } + Ok(()) + } +} + +impl FromIterator for LexOrdering { + fn from_iter>(iter: T) -> Self { + let mut lex_ordering = LexOrdering::default(); + + for i in iter { + lex_ordering.push(i); + } + + lex_ordering + } +} + +impl Index for LexOrdering { + type Output = PhysicalSortExpr; + + fn index(&self, index: usize) -> &Self::Output { + &self.inner[index] + } +} + +impl Index> for LexOrdering { + type Output = [PhysicalSortExpr]; + + fn index(&self, range: Range) -> &Self::Output { + &self.inner[range] + } +} + +impl Index> for LexOrdering { + type Output = [PhysicalSortExpr]; + + fn index(&self, range_from: RangeFrom) -> &Self::Output { + &self.inner[range_from] + } +} + +impl Index> for LexOrdering { + type Output = [PhysicalSortExpr]; + + fn index(&self, range_to: RangeTo) -> &Self::Output { + &self.inner[range_to] + } +} + +impl IntoIterator for LexOrdering { + type Item = PhysicalSortExpr; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.inner.into_iter() + } +} ///`LexOrderingRef` is an alias for the type &`[PhysicalSortExpr]`, which represents /// a reference to a lexicographical ordering. pub type LexOrderingRef<'a> = &'a [PhysicalSortExpr]; -///`LexRequirement` is an alias for the type `Vec`, which +///`LexRequirement` is an struct containing a `Vec`, which /// represents a lexicographical ordering requirement. -pub type LexRequirement = Vec; +#[derive(Debug, Default, Clone, PartialEq)] +pub struct LexRequirement { + pub inner: Vec, +} + +impl LexRequirement { + pub fn new(inner: Vec) -> Self { + Self { inner } + } + + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + pub fn iter(&self) -> impl Iterator { + self.inner.iter() + } + + pub fn push(&mut self, physical_sort_requirement: PhysicalSortRequirement) { + self.inner.push(physical_sort_requirement) + } +} + +impl Deref for LexRequirement { + type Target = [PhysicalSortRequirement]; + + fn deref(&self) -> &Self::Target { + self.inner.as_slice() + } +} + +impl FromIterator for LexRequirement { + fn from_iter>(iter: T) -> Self { + let mut lex_requirement = LexRequirement::new(vec![]); + + for i in iter { + lex_requirement.inner.push(i); + } + + lex_requirement + } +} + +impl IntoIterator for LexRequirement { + type Item = PhysicalSortRequirement; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.inner.into_iter() + } +} ///`LexRequirementRef` is an alias for the type &`[PhysicalSortRequirement]`, which /// represents a reference to a lexicographical ordering requirement. diff --git a/datafusion/physical-expr-common/src/tree_node.rs b/datafusion/physical-expr-common/src/tree_node.rs index 42dc6673af6a..d9892ce55509 100644 --- a/datafusion/physical-expr-common/src/tree_node.rs +++ b/datafusion/physical-expr-common/src/tree_node.rs @@ -26,7 +26,7 @@ use datafusion_common::tree_node::{ConcreteTreeNode, DynTreeNode}; use datafusion_common::Result; impl DynTreeNode for dyn PhysicalExpr { - fn arc_children(&self) -> Vec> { + fn arc_children(&self) -> Vec<&Arc> { self.children() } @@ -70,7 +70,12 @@ impl ExprContext { impl ExprContext { pub fn new_default(plan: Arc) -> Self { - let children = plan.children().into_iter().map(Self::new_default).collect(); + let children = plan + .children() + .into_iter() + .cloned() + .map(Self::new_default) + .collect(); Self::new(plan, Default::default(), children) } } @@ -84,8 +89,8 @@ impl Display for ExprContext { } impl ConcreteTreeNode for ExprContext { - fn children(&self) -> Vec<&Self> { - self.children.iter().collect() + fn children(&self) -> &[Self] { + &self.children } fn take_children(mut self) -> (Self, Vec) { diff --git a/datafusion/physical-expr-common/src/utils.rs b/datafusion/physical-expr-common/src/utils.rs index 459b5a4849cb..26293b1a76a2 100644 --- a/datafusion/physical-expr-common/src/utils.rs +++ b/datafusion/physical-expr-common/src/utils.rs @@ -15,13 +15,40 @@ // specific language governing permissions and limitations // under the License. -use arrow::{ - array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}, - compute::{and_kleene, is_not_null, SlicesIterator}, -}; -use datafusion_common::Result; +use std::sync::Arc; -use crate::sort_expr::PhysicalSortExpr; +use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; +use arrow::compute::{and_kleene, is_not_null, SlicesIterator}; + +use datafusion_common::Result; +use datafusion_expr_common::sort_properties::ExprProperties; + +use crate::physical_expr::PhysicalExpr; +use crate::sort_expr::{LexOrdering, LexOrderingRef, PhysicalSortExpr}; +use crate::tree_node::ExprContext; + +/// Represents a [`PhysicalExpr`] node with associated properties (order and +/// range) in a context where properties are tracked. +pub type ExprPropertiesNode = ExprContext; + +impl ExprPropertiesNode { + /// Constructs a new `ExprPropertiesNode` with unknown properties for a + /// given physical expression. This node initializes with default properties + /// and recursively applies this to all child expressions. + pub fn new_unknown(expr: Arc) -> Self { + let children = expr + .children() + .into_iter() + .cloned() + .map(Self::new_unknown) + .collect(); + Self { + expr, + data: ExprProperties::new_unknown(), + children, + } + } +} /// Scatter `truthy` array by boolean mask. When the mask evaluates `true`, next values of `truthy` /// are taken, when the mask evaluates `false` values null values are filled. @@ -69,13 +96,10 @@ pub fn scatter(mask: &BooleanArray, truthy: &dyn Array) -> Result { /// Reverses the ORDER BY expression, which is useful during equivalent window /// expression construction. For instance, 'ORDER BY a ASC, NULLS LAST' turns into /// 'ORDER BY a DESC, NULLS FIRST'. -pub fn reverse_order_bys(order_bys: &[PhysicalSortExpr]) -> Vec { +pub fn reverse_order_bys(order_bys: LexOrderingRef) -> LexOrdering { order_bys .iter() - .map(|e| PhysicalSortExpr { - expr: e.expr.clone(), - options: !e.options, - }) + .map(|e| PhysicalSortExpr::new(e.expr.clone(), !e.options)) .collect() } @@ -84,6 +108,7 @@ mod tests { use std::sync::Arc; use arrow::array::Int32Array; + use datafusion_common::cast::{as_boolean_array, as_int32_array}; use super::*; diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 5261f1c8968d..4195e684381f 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -35,40 +35,27 @@ workspace = true name = "datafusion_physical_expr" path = "src/lib.rs" -[features] -default = [ - "regex_expressions", - "encoding_expressions", -] -encoding_expressions = ["base64", "hex"] -regex_expressions = ["regex"] - [dependencies] -ahash = { version = "0.8", default-features = false, features = [ - "runtime-rng", -] } +ahash = { workspace = true } arrow = { workspace = true } arrow-array = { workspace = true } arrow-buffer = { workspace = true } arrow-ord = { workspace = true } arrow-schema = { workspace = true } arrow-string = { workspace = true } -base64 = { version = "0.22", optional = true } chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } -datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } -datafusion-functions-aggregate = { workspace = true } +datafusion-expr-common = { workspace = true } +datafusion-functions-aggregate-common = { workspace = true } datafusion-physical-expr-common = { workspace = true } half = { workspace = true } -hashbrown = { version = "0.14", features = ["raw"] } -hex = { version = "0.4", optional = true } +hashbrown = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true, features = ["use_std"] } log = { workspace = true } paste = "^1.0" petgraph = "0.6.2" -regex = { version = "1.8", optional = true } [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } @@ -80,3 +67,11 @@ tokio = { workspace = true, features = ["rt-multi-thread"] } [[bench]] harness = false name = "in_list" + +[[bench]] +harness = false +name = "case_when" + +[[bench]] +harness = false +name = "is_null" diff --git a/datafusion/physical-expr/benches/case_when.rs b/datafusion/physical-expr/benches/case_when.rs new file mode 100644 index 000000000000..9eda1277c263 --- /dev/null +++ b/datafusion/physical-expr/benches/case_when.rs @@ -0,0 +1,125 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{Field, Schema}; +use arrow::record_batch::RecordBatch; +use arrow_array::builder::{Int32Builder, StringBuilder}; +use arrow_schema::DataType; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_common::ScalarValue; +use datafusion_expr::Operator; +use datafusion_physical_expr::expressions::{BinaryExpr, CaseExpr, Column, Literal}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::sync::Arc; + +fn make_col(name: &str, index: usize) -> Arc { + Arc::new(Column::new(name, index)) +} + +fn make_lit_i32(n: i32) -> Arc { + Arc::new(Literal::new(ScalarValue::Int32(Some(n)))) +} + +fn criterion_benchmark(c: &mut Criterion) { + // create input data + let mut c1 = Int32Builder::new(); + let mut c2 = StringBuilder::new(); + let mut c3 = StringBuilder::new(); + for i in 0..1000 { + c1.append_value(i); + if i % 7 == 0 { + c2.append_null(); + } else { + c2.append_value(format!("string {i}")); + } + if i % 9 == 0 { + c3.append_null(); + } else { + c3.append_value(format!("other string {i}")); + } + } + let c1 = Arc::new(c1.finish()); + let c2 = Arc::new(c2.finish()); + let c3 = Arc::new(c3.finish()); + let schema = Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Utf8, true), + Field::new("c3", DataType::Utf8, true), + ]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2, c3]).unwrap(); + + // use same predicate for all benchmarks + let predicate = Arc::new(BinaryExpr::new( + make_col("c1", 0), + Operator::LtEq, + make_lit_i32(500), + )); + + // CASE WHEN c1 <= 500 THEN 1 ELSE 0 END + c.bench_function("case_when: scalar or scalar", |b| { + let expr = Arc::new( + CaseExpr::try_new( + None, + vec![(predicate.clone(), make_lit_i32(1))], + Some(make_lit_i32(0)), + ) + .unwrap(), + ); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + + // CASE WHEN c1 <= 500 THEN c2 [ELSE NULL] END + c.bench_function("case_when: column or null", |b| { + let expr = Arc::new( + CaseExpr::try_new(None, vec![(predicate.clone(), make_col("c2", 1))], None) + .unwrap(), + ); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + + // CASE WHEN c1 <= 500 THEN c2 ELSE c3 END + c.bench_function("case_when: expr or expr", |b| { + let expr = Arc::new( + CaseExpr::try_new( + None, + vec![(predicate.clone(), make_col("c2", 1))], + Some(make_col("c3", 2)), + ) + .unwrap(), + ); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + + // CASE c1 WHEN 1 THEN c2 WHEN 2 THEN c3 END + c.bench_function("case_when: CASE expr", |b| { + let expr = Arc::new( + CaseExpr::try_new( + Some(make_col("c1", 0)), + vec![ + (make_lit_i32(1), make_col("c2", 1)), + (make_lit_i32(2), make_col("c3", 2)), + ], + None, + ) + .unwrap(), + ); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/physical-expr/benches/is_null.rs b/datafusion/physical-expr/benches/is_null.rs new file mode 100644 index 000000000000..7d26557afb1b --- /dev/null +++ b/datafusion/physical-expr/benches/is_null.rs @@ -0,0 +1,94 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{Field, Schema}; +use arrow::record_batch::RecordBatch; +use arrow_array::builder::Int32Builder; +use arrow_schema::DataType; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_physical_expr::expressions::{Column, IsNotNullExpr, IsNullExpr}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + // create input data + let mut c1 = Int32Builder::new(); + let mut c2 = Int32Builder::new(); + let mut c3 = Int32Builder::new(); + for i in 0..1000 { + // c1 is always null + c1.append_null(); + // c2 is never null + c2.append_value(i); + // c3 is a mix of values and nulls + if i % 7 == 0 { + c3.append_null(); + } else { + c3.append_value(i); + } + } + let c1 = Arc::new(c1.finish()); + let c2 = Arc::new(c2.finish()); + let c3 = Arc::new(c3.finish()); + let schema = Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Int32, false), + Field::new("c3", DataType::Int32, true), + ]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2, c3]).unwrap(); + + c.bench_function("is_null: column is all nulls", |b| { + let expr = is_null("c1", 0); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + + c.bench_function("is_null: column is never null", |b| { + let expr = is_null("c2", 1); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + + c.bench_function("is_null: column is mix of values and nulls", |b| { + let expr = is_null("c3", 2); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + + c.bench_function("is_not_null: column is all nulls", |b| { + let expr = is_not_null("c1", 0); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + + c.bench_function("is_not_null: column is never null", |b| { + let expr = is_not_null("c2", 1); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + + c.bench_function("is_not_null: column is mix of values and nulls", |b| { + let expr = is_not_null("c3", 2); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); +} + +fn is_null(name: &str, index: usize) -> Arc { + Arc::new(IsNullExpr::new(Arc::new(Column::new(name, index)))) +} + +fn is_not_null(name: &str, index: usize) -> Arc { + Arc::new(IsNotNullExpr::new(Arc::new(Column::new(name, index)))) +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/physical-expr/src/aggregate.rs b/datafusion/physical-expr/src/aggregate.rs new file mode 100644 index 000000000000..e446776affc0 --- /dev/null +++ b/datafusion/physical-expr/src/aggregate.rs @@ -0,0 +1,583 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub(crate) mod groups_accumulator { + #[allow(unused_imports)] + pub(crate) mod accumulate { + pub use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState; + } + pub use datafusion_functions_aggregate_common::aggregate::groups_accumulator::{ + accumulate::NullState, GroupsAccumulatorAdapter, + }; +} +pub(crate) mod stats { + pub use datafusion_functions_aggregate_common::stats::StatsType; +} +pub mod utils { + pub use datafusion_functions_aggregate_common::utils::{ + adjust_output_array, get_accum_scalar_values_as_arrays, get_sort_options, + ordering_fields, DecimalAverager, Hashable, + }; +} + +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion_common::ScalarValue; +use datafusion_common::{internal_err, not_impl_err, Result}; +use datafusion_expr::AggregateUDF; +use datafusion_expr::ReversedUDAF; +use datafusion_expr_common::accumulator::Accumulator; +use datafusion_expr_common::type_coercion::aggregates::check_arg_count; +use datafusion_functions_aggregate_common::accumulator::AccumulatorArgs; +use datafusion_functions_aggregate_common::accumulator::StateFieldsArgs; +use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexOrderingRef}; +use datafusion_physical_expr_common::utils::reverse_order_bys; + +use datafusion_expr_common::groups_accumulator::GroupsAccumulator; +use std::fmt::Debug; +use std::sync::Arc; + +/// Builder for physical [`AggregateFunctionExpr`] +/// +/// `AggregateFunctionExpr` contains the information necessary to call +/// an aggregate expression. +#[derive(Debug, Clone)] +pub struct AggregateExprBuilder { + fun: Arc, + /// Physical expressions of the aggregate function + args: Vec>, + alias: Option, + /// Arrow Schema for the aggregate function + schema: SchemaRef, + /// The physical order by expressions + ordering_req: LexOrdering, + /// Whether to ignore null values + ignore_nulls: bool, + /// Whether is distinct aggregate function + is_distinct: bool, + /// Whether the expression is reversed + is_reversed: bool, +} + +impl AggregateExprBuilder { + pub fn new(fun: Arc, args: Vec>) -> Self { + Self { + fun, + args, + alias: None, + schema: Arc::new(Schema::empty()), + ordering_req: LexOrdering::default(), + ignore_nulls: false, + is_distinct: false, + is_reversed: false, + } + } + + pub fn build(self) -> Result { + let Self { + fun, + args, + alias, + schema, + ordering_req, + ignore_nulls, + is_distinct, + is_reversed, + } = self; + if args.is_empty() { + return internal_err!("args should not be empty"); + } + + let mut ordering_fields = vec![]; + + if !ordering_req.is_empty() { + let ordering_types = ordering_req + .iter() + .map(|e| e.expr.data_type(&schema)) + .collect::>>()?; + + ordering_fields = + utils::ordering_fields(ordering_req.as_ref(), &ordering_types); + } + + let input_exprs_types = args + .iter() + .map(|arg| arg.data_type(&schema)) + .collect::>>()?; + + check_arg_count( + fun.name(), + &input_exprs_types, + &fun.signature().type_signature, + )?; + + let data_type = fun.return_type(&input_exprs_types)?; + let is_nullable = fun.is_nullable(); + let name = match alias { + None => return internal_err!("alias should be provided"), + Some(alias) => alias, + }; + + Ok(AggregateFunctionExpr { + fun: Arc::unwrap_or_clone(fun), + args, + data_type, + name, + schema: Arc::unwrap_or_clone(schema), + ordering_req, + ignore_nulls, + ordering_fields, + is_distinct, + input_types: input_exprs_types, + is_reversed, + is_nullable, + }) + } + + pub fn alias(mut self, alias: impl Into) -> Self { + self.alias = Some(alias.into()); + self + } + + pub fn schema(mut self, schema: SchemaRef) -> Self { + self.schema = schema; + self + } + + pub fn order_by(mut self, order_by: LexOrdering) -> Self { + self.ordering_req = order_by; + self + } + + pub fn reversed(mut self) -> Self { + self.is_reversed = true; + self + } + + pub fn with_reversed(mut self, is_reversed: bool) -> Self { + self.is_reversed = is_reversed; + self + } + + pub fn distinct(mut self) -> Self { + self.is_distinct = true; + self + } + + pub fn with_distinct(mut self, is_distinct: bool) -> Self { + self.is_distinct = is_distinct; + self + } + + pub fn ignore_nulls(mut self) -> Self { + self.ignore_nulls = true; + self + } + + pub fn with_ignore_nulls(mut self, ignore_nulls: bool) -> Self { + self.ignore_nulls = ignore_nulls; + self + } +} + +/// Physical aggregate expression of a UDAF. +#[derive(Debug, Clone)] +pub struct AggregateFunctionExpr { + fun: AggregateUDF, + args: Vec>, + /// Output / return type of this aggregate + data_type: DataType, + name: String, + schema: Schema, + // The physical order by expressions + ordering_req: LexOrdering, + // Whether to ignore null values + ignore_nulls: bool, + // fields used for order sensitive aggregation functions + ordering_fields: Vec, + is_distinct: bool, + is_reversed: bool, + input_types: Vec, + is_nullable: bool, +} + +impl AggregateFunctionExpr { + /// Return the `AggregateUDF` used by this `AggregateFunctionExpr` + pub fn fun(&self) -> &AggregateUDF { + &self.fun + } + + /// expressions that are passed to the Accumulator. + /// Single-column aggregations such as `sum` return a single value, others (e.g. `cov`) return many. + pub fn expressions(&self) -> Vec> { + self.args.clone() + } + + /// Human readable name such as `"MIN(c2)"`. + pub fn name(&self) -> &str { + &self.name + } + + /// Return if the aggregation is distinct + pub fn is_distinct(&self) -> bool { + self.is_distinct + } + + /// Return if the aggregation ignores nulls + pub fn ignore_nulls(&self) -> bool { + self.ignore_nulls + } + + /// Return if the aggregation is reversed + pub fn is_reversed(&self) -> bool { + self.is_reversed + } + + /// Return if the aggregation is nullable + pub fn is_nullable(&self) -> bool { + self.is_nullable + } + + /// the field of the final result of this aggregation. + pub fn field(&self) -> Field { + Field::new(&self.name, self.data_type.clone(), self.is_nullable) + } + + /// the accumulator used to accumulate values from the expressions. + /// the accumulator expects the same number of arguments as `expressions` and must + /// return states with the same description as `state_fields` + pub fn create_accumulator(&self) -> Result> { + let acc_args = AccumulatorArgs { + return_type: &self.data_type, + schema: &self.schema, + ignore_nulls: self.ignore_nulls, + ordering_req: self.ordering_req.as_ref(), + is_distinct: self.is_distinct, + name: &self.name, + is_reversed: self.is_reversed, + exprs: &self.args, + }; + + self.fun.accumulator(acc_args) + } + + /// the field of the final result of this aggregation. + pub fn state_fields(&self) -> Result> { + let args = StateFieldsArgs { + name: &self.name, + input_types: &self.input_types, + return_type: &self.data_type, + ordering_fields: &self.ordering_fields, + is_distinct: self.is_distinct, + }; + + self.fun.state_fields(args) + } + + /// Order by requirements for the aggregate function + /// By default it is `None` (there is no requirement) + /// Order-sensitive aggregators, such as `FIRST_VALUE(x ORDER BY y)` should implement this + pub fn order_bys(&self) -> Option { + if self.ordering_req.is_empty() { + return None; + } + + if !self.order_sensitivity().is_insensitive() { + return Some(self.ordering_req.as_ref()); + } + + None + } + + /// Indicates whether aggregator can produce the correct result with any + /// arbitrary input ordering. By default, we assume that aggregate expressions + /// are order insensitive. + pub fn order_sensitivity(&self) -> AggregateOrderSensitivity { + if !self.ordering_req.is_empty() { + // If there is requirement, use the sensitivity of the implementation + self.fun.order_sensitivity() + } else { + // If no requirement, aggregator is order insensitive + AggregateOrderSensitivity::Insensitive + } + } + + /// Sets the indicator whether ordering requirements of the aggregator is + /// satisfied by its input. If this is not the case, aggregators with order + /// sensitivity `AggregateOrderSensitivity::Beneficial` can still produce + /// the correct result with possibly more work internally. + /// + /// # Returns + /// + /// Returns `Ok(Some(updated_expr))` if the process completes successfully. + /// If the expression can benefit from existing input ordering, but does + /// not implement the method, returns an error. Order insensitive and hard + /// requirement aggregators return `Ok(None)`. + pub fn with_beneficial_ordering( + self: Arc, + beneficial_ordering: bool, + ) -> Result> { + let Some(updated_fn) = self + .fun + .clone() + .with_beneficial_ordering(beneficial_ordering)? + else { + return Ok(None); + }; + + AggregateExprBuilder::new(Arc::new(updated_fn), self.args.to_vec()) + .order_by(self.ordering_req.clone()) + .schema(Arc::new(self.schema.clone())) + .alias(self.name().to_string()) + .with_ignore_nulls(self.ignore_nulls) + .with_distinct(self.is_distinct) + .with_reversed(self.is_reversed) + .build() + .map(Some) + } + + /// Creates accumulator implementation that supports retract + pub fn create_sliding_accumulator(&self) -> Result> { + let args = AccumulatorArgs { + return_type: &self.data_type, + schema: &self.schema, + ignore_nulls: self.ignore_nulls, + ordering_req: self.ordering_req.as_ref(), + is_distinct: self.is_distinct, + name: &self.name, + is_reversed: self.is_reversed, + exprs: &self.args, + }; + + let accumulator = self.fun.create_sliding_accumulator(args)?; + + // Accumulators that have window frame startings different + // than `UNBOUNDED PRECEDING`, such as `1 PRECEDING`, need to + // implement retract_batch method in order to run correctly + // currently in DataFusion. + // + // If this `retract_batches` is not present, there is no way + // to calculate result correctly. For example, the query + // + // ```sql + // SELECT + // SUM(a) OVER(ORDER BY a ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS sum_a + // FROM + // t + // ``` + // + // 1. First sum value will be the sum of rows between `[0, 1)`, + // + // 2. Second sum value will be the sum of rows between `[0, 2)` + // + // 3. Third sum value will be the sum of rows between `[1, 3)`, etc. + // + // Since the accumulator keeps the running sum: + // + // 1. First sum we add to the state sum value between `[0, 1)` + // + // 2. Second sum we add to the state sum value between `[1, 2)` + // (`[0, 1)` is already in the state sum, hence running sum will + // cover `[0, 2)` range) + // + // 3. Third sum we add to the state sum value between `[2, 3)` + // (`[0, 2)` is already in the state sum). Also we need to + // retract values between `[0, 1)` by this way we can obtain sum + // between [1, 3) which is indeed the appropriate range. + // + // When we use `UNBOUNDED PRECEDING` in the query starting + // index will always be 0 for the desired range, and hence the + // `retract_batch` method will not be called. In this case + // having retract_batch is not a requirement. + // + // This approach is a a bit different than window function + // approach. In window function (when they use a window frame) + // they get all the desired range during evaluation. + if !accumulator.supports_retract_batch() { + return not_impl_err!( + "Aggregate can not be used as a sliding accumulator because \ + `retract_batch` is not implemented: {}", + self.name + ); + } + Ok(accumulator) + } + + /// If the aggregate expression has a specialized + /// [`GroupsAccumulator`] implementation. If this returns true, + /// `[Self::create_groups_accumulator`] will be called. + pub fn groups_accumulator_supported(&self) -> bool { + let args = AccumulatorArgs { + return_type: &self.data_type, + schema: &self.schema, + ignore_nulls: self.ignore_nulls, + ordering_req: self.ordering_req.as_ref(), + is_distinct: self.is_distinct, + name: &self.name, + is_reversed: self.is_reversed, + exprs: &self.args, + }; + self.fun.groups_accumulator_supported(args) + } + + /// Return a specialized [`GroupsAccumulator`] that manages state + /// for all groups. + /// + /// For maximum performance, a [`GroupsAccumulator`] should be + /// implemented in addition to [`Accumulator`]. + pub fn create_groups_accumulator(&self) -> Result> { + let args = AccumulatorArgs { + return_type: &self.data_type, + schema: &self.schema, + ignore_nulls: self.ignore_nulls, + ordering_req: self.ordering_req.as_ref(), + is_distinct: self.is_distinct, + name: &self.name, + is_reversed: self.is_reversed, + exprs: &self.args, + }; + self.fun.create_groups_accumulator(args) + } + + /// Construct an expression that calculates the aggregate in reverse. + /// Typically the "reverse" expression is itself (e.g. SUM, COUNT). + /// For aggregates that do not support calculation in reverse, + /// returns None (which is the default value). + pub fn reverse_expr(&self) -> Option { + match self.fun.reverse_udf() { + ReversedUDAF::NotSupported => None, + ReversedUDAF::Identical => Some(self.clone()), + ReversedUDAF::Reversed(reverse_udf) => { + let reverse_ordering_req = reverse_order_bys(self.ordering_req.as_ref()); + let mut name = self.name().to_string(); + // If the function is changed, we need to reverse order_by clause as well + // i.e. First(a order by b asc null first) -> Last(a order by b desc null last) + if self.fun().name() == reverse_udf.name() { + } else { + replace_order_by_clause(&mut name); + } + replace_fn_name_clause(&mut name, self.fun.name(), reverse_udf.name()); + + AggregateExprBuilder::new(reverse_udf, self.args.to_vec()) + .order_by(reverse_ordering_req) + .schema(Arc::new(self.schema.clone())) + .alias(name) + .with_ignore_nulls(self.ignore_nulls) + .with_distinct(self.is_distinct) + .with_reversed(!self.is_reversed) + .build() + .ok() + } + } + } + + /// Returns all expressions used in the [`AggregateFunctionExpr`]. + /// These expressions are (1)function arguments, (2) order by expressions. + pub fn all_expressions(&self) -> AggregatePhysicalExpressions { + let args = self.expressions(); + let order_bys = self.order_bys().unwrap_or_default(); + let order_by_exprs = order_bys + .iter() + .map(|sort_expr| Arc::clone(&sort_expr.expr)) + .collect::>(); + AggregatePhysicalExpressions { + args, + order_by_exprs, + } + } + + /// Rewrites [`AggregateFunctionExpr`], with new expressions given. The argument should be consistent + /// with the return value of the [`AggregateFunctionExpr::all_expressions`] method. + /// Returns `Some(Arc)` if re-write is supported, otherwise returns `None`. + pub fn with_new_expressions( + &self, + _args: Vec>, + _order_by_exprs: Vec>, + ) -> Option { + None + } + + /// If this function is max, return (output_field, true) + /// if the function is min, return (output_field, false) + /// otherwise return None (the default) + /// + /// output_field is the name of the column produced by this aggregate + /// + /// Note: this is used to use special aggregate implementations in certain conditions + pub fn get_minmax_desc(&self) -> Option<(Field, bool)> { + self.fun.is_descending().map(|flag| (self.field(), flag)) + } + + /// Returns default value of the function given the input is Null + /// Most of the aggregate function return Null if input is Null, + /// while `count` returns 0 if input is Null + pub fn default_value(&self, data_type: &DataType) -> Result { + self.fun.default_value(data_type) + } +} + +/// Stores the physical expressions used inside the `AggregateExpr`. +pub struct AggregatePhysicalExpressions { + /// Aggregate function arguments + pub args: Vec>, + /// Order by expressions + pub order_by_exprs: Vec>, +} + +impl PartialEq for AggregateFunctionExpr { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + && self.data_type == other.data_type + && self.fun == other.fun + && self.args.len() == other.args.len() + && self + .args + .iter() + .zip(other.args.iter()) + .all(|(this_arg, other_arg)| this_arg.eq(other_arg)) + } +} + +fn replace_order_by_clause(order_by: &mut String) { + let suffixes = [ + (" DESC NULLS FIRST]", " ASC NULLS LAST]"), + (" ASC NULLS FIRST]", " DESC NULLS LAST]"), + (" DESC NULLS LAST]", " ASC NULLS FIRST]"), + (" ASC NULLS LAST]", " DESC NULLS FIRST]"), + ]; + + if let Some(start) = order_by.find("ORDER BY [") { + if let Some(end) = order_by[start..].find(']') { + let order_by_start = start + 9; + let order_by_end = start + end; + + let column_order = &order_by[order_by_start..=order_by_end]; + for (suffix, replacement) in suffixes { + if column_order.ends_with(suffix) { + let new_order = column_order.replace(suffix, replacement); + order_by.replace_range(order_by_start..=order_by_end, &new_order); + break; + } + } + } + } +} + +fn replace_fn_name_clause(aggr_name: &mut String, fn_name_old: &str, fn_name_new: &str) { + *aggr_name = aggr_name.replace(fn_name_old, fn_name_new); +} diff --git a/datafusion/physical-expr/src/aggregate/approx_median.rs b/datafusion/physical-expr/src/aggregate/approx_median.rs deleted file mode 100644 index cbbfef5a8919..000000000000 --- a/datafusion/physical-expr/src/aggregate/approx_median.rs +++ /dev/null @@ -1,99 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expressions for APPROX_MEDIAN that can be evaluated MEDIAN at runtime during query execution - -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::{lit, ApproxPercentileCont}; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::{datatypes::DataType, datatypes::Field}; -use datafusion_common::Result; -use datafusion_expr::Accumulator; -use std::any::Any; -use std::sync::Arc; - -/// MEDIAN aggregate expression -#[derive(Debug)] -pub struct ApproxMedian { - name: String, - expr: Arc, - data_type: DataType, - approx_percentile: ApproxPercentileCont, -} - -impl ApproxMedian { - /// Create a new APPROX_MEDIAN aggregate function - pub fn try_new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Result { - let name: String = name.into(); - let approx_percentile = ApproxPercentileCont::new( - vec![expr.clone(), lit(0.5_f64)], - name.clone(), - data_type.clone(), - )?; - Ok(Self { - name, - expr, - data_type, - approx_percentile, - }) - } -} - -impl AggregateExpr for ApproxMedian { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.data_type.clone(), true)) - } - - fn create_accumulator(&self) -> Result> { - self.approx_percentile.create_accumulator() - } - - fn state_fields(&self) -> Result> { - self.approx_percentile.state_fields() - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for ApproxMedian { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.expr.eq(&x.expr) - && self.approx_percentile == x.approx_percentile - }) - .unwrap_or(false) - } -} diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs deleted file mode 100644 index 3fa715a59238..000000000000 --- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs +++ /dev/null @@ -1,174 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::aggregate::approx_percentile_cont::ApproxPercentileAccumulator; -use crate::aggregate::tdigest::{Centroid, TDigest, DEFAULT_MAX_SIZE}; -use crate::expressions::ApproxPercentileCont; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::{ - array::ArrayRef, - datatypes::{DataType, Field}, -}; - -use datafusion_common::Result; -use datafusion_common::ScalarValue; -use datafusion_expr::Accumulator; - -use crate::aggregate::utils::down_cast_any_ref; -use std::{any::Any, sync::Arc}; - -/// APPROX_PERCENTILE_CONT_WITH_WEIGTH aggregate expression -#[derive(Debug)] -pub struct ApproxPercentileContWithWeight { - approx_percentile_cont: ApproxPercentileCont, - column_expr: Arc, - weight_expr: Arc, - percentile_expr: Arc, -} - -impl ApproxPercentileContWithWeight { - /// Create a new [`ApproxPercentileContWithWeight`] aggregate function. - pub fn new( - expr: Vec>, - name: impl Into, - return_type: DataType, - ) -> Result { - // Arguments should be [ColumnExpr, WeightExpr, DesiredPercentileLiteral] - debug_assert_eq!(expr.len(), 3); - - let sub_expr = vec![expr[0].clone(), expr[2].clone()]; - let approx_percentile_cont = - ApproxPercentileCont::new(sub_expr, name, return_type)?; - - Ok(Self { - approx_percentile_cont, - column_expr: expr[0].clone(), - weight_expr: expr[1].clone(), - percentile_expr: expr[2].clone(), - }) - } -} - -impl AggregateExpr for ApproxPercentileContWithWeight { - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - self.approx_percentile_cont.field() - } - - #[allow(rustdoc::private_intra_doc_links)] - /// See [`TDigest::to_scalar_state()`] for a description of the serialised - /// state. - fn state_fields(&self) -> Result> { - self.approx_percentile_cont.state_fields() - } - - fn expressions(&self) -> Vec> { - vec![ - self.column_expr.clone(), - self.weight_expr.clone(), - self.percentile_expr.clone(), - ] - } - - fn create_accumulator(&self) -> Result> { - let approx_percentile_cont_accumulator = - self.approx_percentile_cont.create_plain_accumulator()?; - let accumulator = ApproxPercentileWithWeightAccumulator::new( - approx_percentile_cont_accumulator, - ); - Ok(Box::new(accumulator)) - } - - fn name(&self) -> &str { - self.approx_percentile_cont.name() - } -} - -impl PartialEq for ApproxPercentileContWithWeight { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.approx_percentile_cont == x.approx_percentile_cont - && self.column_expr.eq(&x.column_expr) - && self.weight_expr.eq(&x.weight_expr) - && self.percentile_expr.eq(&x.percentile_expr) - }) - .unwrap_or(false) - } -} - -#[derive(Debug)] -pub struct ApproxPercentileWithWeightAccumulator { - approx_percentile_cont_accumulator: ApproxPercentileAccumulator, -} - -impl ApproxPercentileWithWeightAccumulator { - pub fn new(approx_percentile_cont_accumulator: ApproxPercentileAccumulator) -> Self { - Self { - approx_percentile_cont_accumulator, - } - } -} - -impl Accumulator for ApproxPercentileWithWeightAccumulator { - fn state(&mut self) -> Result> { - self.approx_percentile_cont_accumulator.state() - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let means = &values[0]; - let weights = &values[1]; - debug_assert_eq!( - means.len(), - weights.len(), - "invalid number of values in means and weights" - ); - let means_f64 = ApproxPercentileAccumulator::convert_to_float(means)?; - let weights_f64 = ApproxPercentileAccumulator::convert_to_float(weights)?; - let mut digests: Vec = vec![]; - for (mean, weight) in means_f64.iter().zip(weights_f64.iter()) { - digests.push(TDigest::new_with_centroid( - DEFAULT_MAX_SIZE, - Centroid::new(*mean, *weight), - )) - } - self.approx_percentile_cont_accumulator - .merge_digests(&digests); - Ok(()) - } - - fn evaluate(&mut self) -> Result { - self.approx_percentile_cont_accumulator.evaluate() - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - self.approx_percentile_cont_accumulator - .merge_batch(states)?; - - Ok(()) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - - std::mem::size_of_val(&self.approx_percentile_cont_accumulator) - + self.approx_percentile_cont_accumulator.size() - } -} diff --git a/datafusion/physical-expr/src/aggregate/array_agg.rs b/datafusion/physical-expr/src/aggregate/array_agg.rs deleted file mode 100644 index 23d916103204..000000000000 --- a/datafusion/physical-expr/src/aggregate/array_agg.rs +++ /dev/null @@ -1,307 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expressions that can evaluated at runtime during query execution - -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field}; -use arrow_array::Array; -use datafusion_common::cast::as_list_array; -use datafusion_common::utils::array_into_list_array; -use datafusion_common::Result; -use datafusion_common::ScalarValue; -use datafusion_expr::Accumulator; -use std::any::Any; -use std::sync::Arc; - -/// ARRAY_AGG aggregate expression -#[derive(Debug)] -pub struct ArrayAgg { - /// Column name - name: String, - /// The DataType for the input expression - input_data_type: DataType, - /// The input expression - expr: Arc, - /// If the input expression can have NULLs - nullable: bool, -} - -impl ArrayAgg { - /// Create a new ArrayAgg aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - nullable: bool, - ) -> Self { - Self { - name: name.into(), - input_data_type: data_type, - expr, - nullable, - } - } -} - -impl AggregateExpr for ArrayAgg { - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new_list( - &self.name, - // This should be the same as return type of AggregateFunction::ArrayAgg - Field::new("item", self.input_data_type.clone(), true), - self.nullable, - )) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(ArrayAggAccumulator::try_new( - &self.input_data_type, - )?)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new_list( - format_state_name(&self.name, "array_agg"), - Field::new("item", self.input_data_type.clone(), true), - self.nullable, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for ArrayAgg { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.input_data_type == x.input_data_type - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -#[derive(Debug)] -pub(crate) struct ArrayAggAccumulator { - values: Vec, - datatype: DataType, -} - -impl ArrayAggAccumulator { - /// new array_agg accumulator based on given item data type - pub fn try_new(datatype: &DataType) -> Result { - Ok(Self { - values: vec![], - datatype: datatype.clone(), - }) - } -} - -impl Accumulator for ArrayAggAccumulator { - // Append value like Int64Array(1,2,3) - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - assert!(values.len() == 1, "array_agg can only take 1 param!"); - let val = values[0].clone(); - self.values.push(val); - Ok(()) - } - - // Append value like ListArray(Int64Array(1,2,3), Int64Array(4,5,6)) - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - assert!(states.len() == 1, "array_agg states must be singleton!"); - - let list_arr = as_list_array(&states[0])?; - for arr in list_arr.iter().flatten() { - self.values.push(arr); - } - Ok(()) - } - - fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?]) - } - - fn evaluate(&mut self) -> Result { - // Transform Vec to ListArr - - let element_arrays: Vec<&dyn Array> = - self.values.iter().map(|a| a.as_ref()).collect(); - - if element_arrays.is_empty() { - let arr = ScalarValue::new_list(&[], &self.datatype); - return Ok(ScalarValue::List(arr)); - } - - let concated_array = arrow::compute::concat(&element_arrays)?; - let list_array = array_into_list_array(concated_array); - - Ok(ScalarValue::List(Arc::new(list_array))) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.values.capacity()) - + self - .values - .iter() - .map(|arr| arr.get_array_memory_size()) - .sum::() - + self.datatype.size() - - std::mem::size_of_val(&self.datatype) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::expressions::col; - use crate::expressions::tests::aggregate; - use arrow::array::Int32Array; - use arrow::datatypes::*; - use arrow::record_batch::RecordBatch; - use arrow_array::ListArray; - use arrow_buffer::OffsetBuffer; - use datafusion_common::DataFusionError; - - macro_rules! test_op { - ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr) => { - test_op!($ARRAY, $DATATYPE, $OP, $EXPECTED, $EXPECTED.data_type()) - }; - ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ - let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); - - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; - - let agg = Arc::new(<$OP>::new( - col("a", &schema)?, - "bla".to_string(), - $EXPECTED_DATATYPE, - true, - )); - let actual = aggregate(&batch, agg)?; - let expected = ScalarValue::from($EXPECTED); - - assert_eq!(expected, actual); - - Ok(()) as Result<(), DataFusionError> - }}; - } - - #[test] - fn array_agg_i32() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - - let list = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(1), - Some(2), - Some(3), - Some(4), - Some(5), - ])]); - let list = ScalarValue::List(Arc::new(list)); - - test_op!(a, DataType::Int32, ArrayAgg, list, DataType::Int32) - } - - #[test] - fn array_agg_nested() -> Result<()> { - let a1 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(1), - Some(2), - Some(3), - ])]); - let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(4), - Some(5), - ])]); - let l1 = ListArray::new( - Arc::new(Field::new("item", a1.data_type().to_owned(), true)), - OffsetBuffer::from_lengths([a1.len() + a2.len()]), - arrow::compute::concat(&[&a1, &a2])?, - None, - ); - - let a1 = - ListArray::from_iter_primitive::(vec![Some(vec![Some(6)])]); - let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(7), - Some(8), - ])]); - let l2 = ListArray::new( - Arc::new(Field::new("item", a1.data_type().to_owned(), true)), - OffsetBuffer::from_lengths([a1.len() + a2.len()]), - arrow::compute::concat(&[&a1, &a2])?, - None, - ); - - let a1 = - ListArray::from_iter_primitive::(vec![Some(vec![Some(9)])]); - let l3 = ListArray::new( - Arc::new(Field::new("item", a1.data_type().to_owned(), true)), - OffsetBuffer::from_lengths([a1.len()]), - arrow::compute::concat(&[&a1])?, - None, - ); - - let list = ListArray::new( - Arc::new(Field::new("item", l1.data_type().to_owned(), true)), - OffsetBuffer::from_lengths([l1.len() + l2.len() + l3.len()]), - arrow::compute::concat(&[&l1, &l2, &l3])?, - None, - ); - let list = ScalarValue::List(Arc::new(list)); - let l1 = ScalarValue::List(Arc::new(l1)); - let l2 = ScalarValue::List(Arc::new(l2)); - let l3 = ScalarValue::List(Arc::new(l3)); - - let array = ScalarValue::iter_to_array(vec![l1, l2, l3]).unwrap(); - - test_op!( - array, - DataType::List(Arc::new(Field::new_list( - "item", - Field::new("item", DataType::Int32, true), - true, - ))), - ArrayAgg, - list, - DataType::List(Arc::new(Field::new("item", DataType::Int32, true,))) - ) - } -} diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs deleted file mode 100644 index b8671c39a943..000000000000 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ /dev/null @@ -1,437 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Implementations for DISTINCT expressions, e.g. `COUNT(DISTINCT c)` - -use std::any::Any; -use std::collections::HashSet; -use std::fmt::Debug; -use std::sync::Arc; - -use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field}; -use arrow_array::cast::AsArray; - -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; - -use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::Accumulator; - -/// Expression for a ARRAY_AGG(DISTINCT) aggregation. -#[derive(Debug)] -pub struct DistinctArrayAgg { - /// Column name - name: String, - /// The DataType for the input expression - input_data_type: DataType, - /// The input expression - expr: Arc, - /// If the input expression can have NULLs - nullable: bool, -} - -impl DistinctArrayAgg { - /// Create a new DistinctArrayAgg aggregate function - pub fn new( - expr: Arc, - name: impl Into, - input_data_type: DataType, - nullable: bool, - ) -> Self { - let name = name.into(); - Self { - name, - input_data_type, - expr, - nullable, - } - } -} - -impl AggregateExpr for DistinctArrayAgg { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new_list( - &self.name, - // This should be the same as return type of AggregateFunction::ArrayAgg - Field::new("item", self.input_data_type.clone(), true), - self.nullable, - )) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(DistinctArrayAggAccumulator::try_new( - &self.input_data_type, - )?)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new_list( - format_state_name(&self.name, "distinct_array_agg"), - Field::new("item", self.input_data_type.clone(), true), - self.nullable, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for DistinctArrayAgg { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.input_data_type == x.input_data_type - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -#[derive(Debug)] -struct DistinctArrayAggAccumulator { - values: HashSet, - datatype: DataType, -} - -impl DistinctArrayAggAccumulator { - pub fn try_new(datatype: &DataType) -> Result { - Ok(Self { - values: HashSet::new(), - datatype: datatype.clone(), - }) - } -} - -impl Accumulator for DistinctArrayAggAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - assert_eq!(values.len(), 1, "batch input should only include 1 column!"); - - let array = &values[0]; - - for i in 0..array.len() { - let scalar = ScalarValue::try_from_array(&array, i)?; - self.values.insert(scalar); - } - - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - - let array = &states[0]; - - assert_eq!(array.len(), 1, "state array should only include 1 row!"); - // Unwrap outer ListArray then do update batch - let inner_array = array.as_list::().value(0); - self.update_batch(&[inner_array]) - } - - fn evaluate(&mut self) -> Result { - let values: Vec = self.values.iter().cloned().collect(); - let arr = ScalarValue::new_list(&values, &self.datatype); - Ok(ScalarValue::List(arr)) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) + ScalarValue::size_of_hashset(&self.values) - - std::mem::size_of_val(&self.values) - + self.datatype.size() - - std::mem::size_of_val(&self.datatype) - } -} - -#[cfg(test)] -mod tests { - - use super::*; - use crate::expressions::col; - use crate::expressions::tests::aggregate; - use arrow::array::Int32Array; - use arrow::datatypes::Schema; - use arrow::record_batch::RecordBatch; - use arrow_array::types::Int32Type; - use arrow_array::Array; - use arrow_array::ListArray; - use arrow_buffer::OffsetBuffer; - use datafusion_common::internal_err; - - // arrow::compute::sort can't sort nested ListArray directly, so we compare the scalar values pair-wise. - fn compare_list_contents( - expected: Vec, - actual: ScalarValue, - ) -> Result<()> { - let array = actual.to_array()?; - let list_array = array.as_list::(); - let inner_array = list_array.value(0); - let mut actual_scalars = vec![]; - for index in 0..inner_array.len() { - let sv = ScalarValue::try_from_array(&inner_array, index)?; - actual_scalars.push(sv); - } - - if actual_scalars.len() != expected.len() { - return internal_err!( - "Expected and actual list lengths differ: expected={}, actual={}", - expected.len(), - actual_scalars.len() - ); - } - - let mut seen = vec![false; expected.len()]; - for v in expected { - let mut found = false; - for (i, sv) in actual_scalars.iter().enumerate() { - if sv == &v { - seen[i] = true; - found = true; - break; - } - } - if !found { - return internal_err!( - "Expected value {:?} not found in actual values {:?}", - v, - actual_scalars - ); - } - } - - Ok(()) - } - - fn check_distinct_array_agg( - input: ArrayRef, - expected: Vec, - datatype: DataType, - ) -> Result<()> { - let schema = Schema::new(vec![Field::new("a", datatype.clone(), false)]); - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![input])?; - - let agg = Arc::new(DistinctArrayAgg::new( - col("a", &schema)?, - "bla".to_string(), - datatype, - true, - )); - let actual = aggregate(&batch, agg)?; - compare_list_contents(expected, actual) - } - - fn check_merge_distinct_array_agg( - input1: ArrayRef, - input2: ArrayRef, - expected: Vec, - datatype: DataType, - ) -> Result<()> { - let schema = Schema::new(vec![Field::new("a", datatype.clone(), false)]); - let agg = Arc::new(DistinctArrayAgg::new( - col("a", &schema)?, - "bla".to_string(), - datatype, - true, - )); - - let mut accum1 = agg.create_accumulator()?; - let mut accum2 = agg.create_accumulator()?; - - accum1.update_batch(&[input1])?; - accum2.update_batch(&[input2])?; - - let array = accum2.state()?[0].raw_data()?; - accum1.merge_batch(&[array])?; - - let actual = accum1.evaluate()?; - compare_list_contents(expected, actual) - } - - #[test] - fn distinct_array_agg_i32() -> Result<()> { - let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2])); - - let expected = vec![ - ScalarValue::Int32(Some(1)), - ScalarValue::Int32(Some(2)), - ScalarValue::Int32(Some(4)), - ScalarValue::Int32(Some(5)), - ScalarValue::Int32(Some(7)), - ]; - - check_distinct_array_agg(col, expected, DataType::Int32) - } - - #[test] - fn merge_distinct_array_agg_i32() -> Result<()> { - let col1: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 4, 5, 2])); - let col2: ArrayRef = Arc::new(Int32Array::from(vec![1, 3, 7, 8, 4])); - - let expected = vec![ - ScalarValue::Int32(Some(1)), - ScalarValue::Int32(Some(2)), - ScalarValue::Int32(Some(3)), - ScalarValue::Int32(Some(4)), - ScalarValue::Int32(Some(5)), - ScalarValue::Int32(Some(7)), - ScalarValue::Int32(Some(8)), - ]; - - check_merge_distinct_array_agg(col1, col2, expected, DataType::Int32) - } - - #[test] - fn distinct_array_agg_nested() -> Result<()> { - // [[1, 2, 3], [4, 5]] - let a1 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(1), - Some(2), - Some(3), - ])]); - let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(4), - Some(5), - ])]); - let l1 = ListArray::new( - Arc::new(Field::new("item", a1.data_type().to_owned(), true)), - OffsetBuffer::from_lengths([2]), - arrow::compute::concat(&[&a1, &a2]).unwrap(), - None, - ); - - // [[6], [7, 8]] - let a1 = - ListArray::from_iter_primitive::(vec![Some(vec![Some(6)])]); - let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(7), - Some(8), - ])]); - let l2 = ListArray::new( - Arc::new(Field::new("item", a1.data_type().to_owned(), true)), - OffsetBuffer::from_lengths([2]), - arrow::compute::concat(&[&a1, &a2]).unwrap(), - None, - ); - - // [[9]] - let a1 = - ListArray::from_iter_primitive::(vec![Some(vec![Some(9)])]); - let l3 = ListArray::new( - Arc::new(Field::new("item", a1.data_type().to_owned(), true)), - OffsetBuffer::from_lengths([1]), - Arc::new(a1), - None, - ); - - let l1 = ScalarValue::List(Arc::new(l1)); - let l2 = ScalarValue::List(Arc::new(l2)); - let l3 = ScalarValue::List(Arc::new(l3)); - - // Duplicate l1 and l3 in the input array and check that it is deduped in the output. - let array = ScalarValue::iter_to_array(vec![ - l1.clone(), - l2.clone(), - l3.clone(), - l3.clone(), - l1.clone(), - ]) - .unwrap(); - let expected = vec![l1, l2, l3]; - - check_distinct_array_agg( - array, - expected, - DataType::List(Arc::new(Field::new_list( - "item", - Field::new("item", DataType::Int32, true), - true, - ))), - ) - } - - #[test] - fn merge_distinct_array_agg_nested() -> Result<()> { - // [[1, 2], [3, 4]] - let a1 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(1), - Some(2), - ])]); - let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(3), - Some(4), - ])]); - let l1 = ListArray::new( - Arc::new(Field::new("item", a1.data_type().to_owned(), true)), - OffsetBuffer::from_lengths([2]), - arrow::compute::concat(&[&a1, &a2]).unwrap(), - None, - ); - - let a1 = - ListArray::from_iter_primitive::(vec![Some(vec![Some(5)])]); - let l2 = ListArray::new( - Arc::new(Field::new("item", a1.data_type().to_owned(), true)), - OffsetBuffer::from_lengths([1]), - Arc::new(a1), - None, - ); - - // [[6, 7], [8]] - let a1 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(6), - Some(7), - ])]); - let a2 = - ListArray::from_iter_primitive::(vec![Some(vec![Some(8)])]); - let l3 = ListArray::new( - Arc::new(Field::new("item", a1.data_type().to_owned(), true)), - OffsetBuffer::from_lengths([2]), - arrow::compute::concat(&[&a1, &a2]).unwrap(), - None, - ); - - let l1 = ScalarValue::List(Arc::new(l1)); - let l2 = ScalarValue::List(Arc::new(l2)); - let l3 = ScalarValue::List(Arc::new(l3)); - - // Duplicate l1 in the input array and check that it is deduped in the output. - let input1 = ScalarValue::iter_to_array(vec![l1.clone(), l2.clone()]).unwrap(); - let input2 = ScalarValue::iter_to_array(vec![l1.clone(), l3.clone()]).unwrap(); - - let expected = vec![l1, l2, l3]; - - check_merge_distinct_array_agg(input1, input2, expected, DataType::Int32) - } -} diff --git a/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs b/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs deleted file mode 100644 index 7244686a5195..000000000000 --- a/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs +++ /dev/null @@ -1,822 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines BitAnd, BitOr, and BitXor Aggregate accumulators - -use ahash::RandomState; -use datafusion_common::cast::as_list_array; -use std::any::Any; -use std::sync::Arc; - -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::datatypes::DataType; -use arrow::{array::ArrayRef, datatypes::Field}; -use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; -use datafusion_expr::{Accumulator, GroupsAccumulator}; -use std::collections::HashSet; - -use crate::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use arrow::array::Array; -use arrow::compute::{bit_and, bit_or, bit_xor}; -use arrow_array::cast::AsArray; -use arrow_array::{downcast_integer, ArrowNumericType}; -use arrow_buffer::ArrowNativeType; - -/// BIT_AND aggregate expression -#[derive(Debug, Clone)] -pub struct BitAnd { - name: String, - pub data_type: DataType, - expr: Arc, - nullable: bool, -} - -impl BitAnd { - /// Create a new BIT_AND aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - expr, - data_type, - nullable: true, - } - } -} - -impl AggregateExpr for BitAnd { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new( - &self.name, - self.data_type.clone(), - self.nullable, - )) - } - - fn create_accumulator(&self) -> Result> { - macro_rules! helper { - ($t:ty) => { - Ok(Box::>::default()) - }; - } - downcast_integer! { - &self.data_type => (helper), - _ => Err(DataFusionError::NotImplemented(format!( - "BitAndAccumulator not supported for {} with {}", - self.name(), - self.data_type - ))), - } - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "bit_and"), - self.data_type.clone(), - self.nullable, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } - - fn groups_accumulator_supported(&self) -> bool { - true - } - - fn create_groups_accumulator(&self) -> Result> { - use std::ops::BitAndAssign; - - // Note the default value for BitAnd should be all set, i.e. `!0` - macro_rules! helper { - ($t:ty, $dt:expr) => { - Ok(Box::new( - PrimitiveGroupsAccumulator::<$t, _>::new($dt, |x, y| { - x.bitand_assign(y) - }) - .with_starting_value(!0), - )) - }; - } - - let data_type = &self.data_type; - downcast_integer! { - data_type => (helper, data_type), - _ => not_impl_err!( - "GroupsAccumulator not supported for {} with {}", - self.name(), - self.data_type - ), - } - } - - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone())) - } -} - -impl PartialEq for BitAnd { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.nullable == x.nullable - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -struct BitAndAccumulator { - value: Option, -} - -impl std::fmt::Debug for BitAndAccumulator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "BitAndAccumulator({})", T::DATA_TYPE) - } -} - -impl Default for BitAndAccumulator { - fn default() -> Self { - Self { value: None } - } -} - -impl Accumulator for BitAndAccumulator -where - T::Native: std::ops::BitAnd, -{ - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if let Some(x) = bit_and(values[0].as_primitive::()) { - let v = self.value.get_or_insert(x); - *v = *v & x; - } - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - self.update_batch(states) - } - - fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?]) - } - - fn evaluate(&mut self) -> Result { - ScalarValue::new_primitive::(self.value, &T::DATA_TYPE) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } -} - -/// BIT_OR aggregate expression -#[derive(Debug, Clone)] -pub struct BitOr { - name: String, - pub data_type: DataType, - expr: Arc, - nullable: bool, -} - -impl BitOr { - /// Create a new BIT_OR aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - expr, - data_type, - nullable: true, - } - } -} - -impl AggregateExpr for BitOr { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new( - &self.name, - self.data_type.clone(), - self.nullable, - )) - } - - fn create_accumulator(&self) -> Result> { - macro_rules! helper { - ($t:ty) => { - Ok(Box::>::default()) - }; - } - downcast_integer! { - &self.data_type => (helper), - _ => Err(DataFusionError::NotImplemented(format!( - "BitOrAccumulator not supported for {} with {}", - self.name(), - self.data_type - ))), - } - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "bit_or"), - self.data_type.clone(), - self.nullable, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } - - fn groups_accumulator_supported(&self) -> bool { - true - } - - fn create_groups_accumulator(&self) -> Result> { - use std::ops::BitOrAssign; - macro_rules! helper { - ($t:ty, $dt:expr) => { - Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new( - $dt, - |x, y| x.bitor_assign(y), - ))) - }; - } - - let data_type = &self.data_type; - downcast_integer! { - data_type => (helper, data_type), - _ => not_impl_err!( - "GroupsAccumulator not supported for {} with {}", - self.name(), - self.data_type - ), - } - } - - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone())) - } -} - -impl PartialEq for BitOr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.nullable == x.nullable - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -struct BitOrAccumulator { - value: Option, -} - -impl std::fmt::Debug for BitOrAccumulator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "BitOrAccumulator({})", T::DATA_TYPE) - } -} - -impl Default for BitOrAccumulator { - fn default() -> Self { - Self { value: None } - } -} - -impl Accumulator for BitOrAccumulator -where - T::Native: std::ops::BitOr, -{ - fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if let Some(x) = bit_or(values[0].as_primitive::()) { - let v = self.value.get_or_insert(T::Native::usize_as(0)); - *v = *v | x; - } - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - self.update_batch(states) - } - - fn evaluate(&mut self) -> Result { - ScalarValue::new_primitive::(self.value, &T::DATA_TYPE) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } -} - -/// BIT_XOR aggregate expression -#[derive(Debug, Clone)] -pub struct BitXor { - name: String, - pub data_type: DataType, - expr: Arc, - nullable: bool, -} - -impl BitXor { - /// Create a new BIT_XOR aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - expr, - data_type, - nullable: true, - } - } -} - -impl AggregateExpr for BitXor { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new( - &self.name, - self.data_type.clone(), - self.nullable, - )) - } - - fn create_accumulator(&self) -> Result> { - macro_rules! helper { - ($t:ty) => { - Ok(Box::>::default()) - }; - } - downcast_integer! { - &self.data_type => (helper), - _ => Err(DataFusionError::NotImplemented(format!( - "BitXor not supported for {} with {}", - self.name(), - self.data_type - ))), - } - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "bit_xor"), - self.data_type.clone(), - self.nullable, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } - - fn groups_accumulator_supported(&self) -> bool { - true - } - - fn create_groups_accumulator(&self) -> Result> { - use std::ops::BitXorAssign; - macro_rules! helper { - ($t:ty, $dt:expr) => { - Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new( - $dt, - |x, y| x.bitxor_assign(y), - ))) - }; - } - - let data_type = &self.data_type; - downcast_integer! { - data_type => (helper, data_type), - _ => not_impl_err!( - "GroupsAccumulator not supported for {} with {}", - self.name(), - self.data_type - ), - } - } - - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone())) - } -} - -impl PartialEq for BitXor { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.nullable == x.nullable - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -struct BitXorAccumulator { - value: Option, -} - -impl std::fmt::Debug for BitXorAccumulator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "BitXorAccumulator({})", T::DATA_TYPE) - } -} - -impl Default for BitXorAccumulator { - fn default() -> Self { - Self { value: None } - } -} - -impl Accumulator for BitXorAccumulator -where - T::Native: std::ops::BitXor, -{ - fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if let Some(x) = bit_xor(values[0].as_primitive::()) { - let v = self.value.get_or_insert(T::Native::usize_as(0)); - *v = *v ^ x; - } - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - self.update_batch(states) - } - - fn evaluate(&mut self) -> Result { - ScalarValue::new_primitive::(self.value, &T::DATA_TYPE) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } -} - -/// Expression for a BIT_XOR(DISTINCT) aggregation. -#[derive(Debug, Clone)] -pub struct DistinctBitXor { - name: String, - pub data_type: DataType, - expr: Arc, - nullable: bool, -} - -impl DistinctBitXor { - /// Create a new DistinctBitXor aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - expr, - data_type, - nullable: true, - } - } -} - -impl AggregateExpr for DistinctBitXor { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new( - &self.name, - self.data_type.clone(), - self.nullable, - )) - } - - fn create_accumulator(&self) -> Result> { - macro_rules! helper { - ($t:ty) => { - Ok(Box::>::default()) - }; - } - downcast_integer! { - &self.data_type => (helper), - _ => Err(DataFusionError::NotImplemented(format!( - "DistinctBitXorAccumulator not supported for {} with {}", - self.name(), - self.data_type - ))), - } - } - - fn state_fields(&self) -> Result> { - // State field is a List which stores items to rebuild hash set. - Ok(vec![Field::new_list( - format_state_name(&self.name, "bit_xor distinct"), - Field::new("item", self.data_type.clone(), true), - false, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for DistinctBitXor { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.nullable == x.nullable - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -struct DistinctBitXorAccumulator { - values: HashSet, -} - -impl std::fmt::Debug for DistinctBitXorAccumulator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "DistinctBitXorAccumulator({})", T::DATA_TYPE) - } -} - -impl Default for DistinctBitXorAccumulator { - fn default() -> Self { - Self { - values: HashSet::default(), - } - } -} - -impl Accumulator for DistinctBitXorAccumulator -where - T::Native: std::ops::BitXor + std::hash::Hash + Eq, -{ - fn state(&mut self) -> Result> { - // 1. Stores aggregate state in `ScalarValue::List` - // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set - let state_out = { - let values = self - .values - .iter() - .map(|x| ScalarValue::new_primitive::(Some(*x), &T::DATA_TYPE)) - .collect::>>()?; - - let arr = ScalarValue::new_list(&values, &T::DATA_TYPE); - vec![ScalarValue::List(arr)] - }; - Ok(state_out) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - - let array = values[0].as_primitive::(); - match array.nulls().filter(|x| x.null_count() > 0) { - Some(n) => { - for idx in n.valid_indices() { - self.values.insert(array.value(idx)); - } - } - None => array.values().iter().for_each(|x| { - self.values.insert(*x); - }), - } - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if let Some(state) = states.first() { - let list_arr = as_list_array(state)?; - for arr in list_arr.iter().flatten() { - self.update_batch(&[arr])?; - } - } - Ok(()) - } - - fn evaluate(&mut self) -> Result { - let mut acc = T::Native::usize_as(0); - for distinct_value in self.values.iter() { - acc = acc ^ *distinct_value; - } - let v = (!self.values.is_empty()).then_some(acc); - ScalarValue::new_primitive::(v, &T::DATA_TYPE) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - + self.values.capacity() * std::mem::size_of::() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::expressions::col; - use crate::expressions::tests::aggregate; - use crate::generic_test_op; - use arrow::array::*; - use arrow::datatypes::*; - - #[test] - fn bit_and_i32() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![4, 7, 15])); - generic_test_op!(a, DataType::Int32, BitAnd, ScalarValue::from(4i32)) - } - - #[test] - fn bit_and_i32_with_nulls() -> Result<()> { - let a: ArrayRef = - Arc::new(Int32Array::from(vec![Some(1), None, Some(3), Some(5)])); - generic_test_op!(a, DataType::Int32, BitAnd, ScalarValue::from(1i32)) - } - - #[test] - fn bit_and_i32_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - generic_test_op!(a, DataType::Int32, BitAnd, ScalarValue::Int32(None)) - } - - #[test] - fn bit_and_u32() -> Result<()> { - let a: ArrayRef = Arc::new(UInt32Array::from(vec![4_u32, 7_u32, 15_u32])); - generic_test_op!(a, DataType::UInt32, BitAnd, ScalarValue::from(4u32)) - } - - #[test] - fn bit_or_i32() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![4, 7, 15])); - generic_test_op!(a, DataType::Int32, BitOr, ScalarValue::from(15i32)) - } - - #[test] - fn bit_or_i32_with_nulls() -> Result<()> { - let a: ArrayRef = - Arc::new(Int32Array::from(vec![Some(1), None, Some(3), Some(5)])); - generic_test_op!(a, DataType::Int32, BitOr, ScalarValue::from(7i32)) - } - - #[test] - fn bit_or_i32_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - generic_test_op!(a, DataType::Int32, BitOr, ScalarValue::Int32(None)) - } - - #[test] - fn bit_or_u32() -> Result<()> { - let a: ArrayRef = Arc::new(UInt32Array::from(vec![4_u32, 7_u32, 15_u32])); - generic_test_op!(a, DataType::UInt32, BitOr, ScalarValue::from(15u32)) - } - - #[test] - fn bit_xor_i32() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![4, 7, 4, 7, 15])); - generic_test_op!(a, DataType::Int32, BitXor, ScalarValue::from(15i32)) - } - - #[test] - fn bit_xor_i32_with_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - Some(1), - None, - Some(3), - Some(5), - ])); - generic_test_op!(a, DataType::Int32, BitXor, ScalarValue::from(6i32)) - } - - #[test] - fn bit_xor_i32_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - generic_test_op!(a, DataType::Int32, BitXor, ScalarValue::Int32(None)) - } - - #[test] - fn bit_xor_u32() -> Result<()> { - let a: ArrayRef = - Arc::new(UInt32Array::from(vec![4_u32, 7_u32, 4_u32, 7_u32, 15_u32])); - generic_test_op!(a, DataType::UInt32, BitXor, ScalarValue::from(15u32)) - } - - #[test] - fn bit_xor_distinct_i32() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![4, 7, 4, 7, 15])); - generic_test_op!(a, DataType::Int32, DistinctBitXor, ScalarValue::from(12i32)) - } - - #[test] - fn bit_xor_distinct_i32_with_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - Some(1), - None, - Some(3), - Some(5), - ])); - generic_test_op!(a, DataType::Int32, DistinctBitXor, ScalarValue::from(7i32)) - } - - #[test] - fn bit_xor_distinct_i32_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - generic_test_op!(a, DataType::Int32, DistinctBitXor, ScalarValue::Int32(None)) - } - - #[test] - fn bit_xor_distinct_u32() -> Result<()> { - let a: ArrayRef = - Arc::new(UInt32Array::from(vec![4_u32, 7_u32, 4_u32, 7_u32, 15_u32])); - generic_test_op!( - a, - DataType::UInt32, - DistinctBitXor, - ScalarValue::from(12u32) - ) - } -} diff --git a/datafusion/physical-expr/src/aggregate/bool_and_or.rs b/datafusion/physical-expr/src/aggregate/bool_and_or.rs deleted file mode 100644 index 341932bd77a4..000000000000 --- a/datafusion/physical-expr/src/aggregate/bool_and_or.rs +++ /dev/null @@ -1,394 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expressions that can evaluated at runtime during query execution - -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::datatypes::DataType; -use arrow::{ - array::{ArrayRef, BooleanArray}, - datatypes::Field, -}; -use datafusion_common::{ - downcast_value, internal_err, not_impl_err, DataFusionError, Result, ScalarValue, -}; -use datafusion_expr::{Accumulator, GroupsAccumulator}; -use std::any::Any; -use std::sync::Arc; - -use crate::aggregate::groups_accumulator::bool_op::BooleanGroupsAccumulator; -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use arrow::array::Array; -use arrow::compute::{bool_and, bool_or}; - -// returns the new value after bool_and/bool_or with the new values, taking nullability into account -macro_rules! typed_bool_and_or_batch { - ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ - let array = downcast_value!($VALUES, $ARRAYTYPE); - let delta = $OP(array); - Ok(ScalarValue::$SCALAR(delta)) - }}; -} - -// bool_and/bool_or the array and returns a ScalarValue of its corresponding type. -macro_rules! bool_and_or_batch { - ($VALUES:expr, $OP:ident) => {{ - match $VALUES.data_type() { - DataType::Boolean => { - typed_bool_and_or_batch!($VALUES, BooleanArray, Boolean, $OP) - } - e => { - return internal_err!( - "Bool and/Bool or is not expected to receive the type {e:?}" - ); - } - } - }}; -} - -/// dynamically-typed bool_and(array) -> ScalarValue -fn bool_and_batch(values: &ArrayRef) -> Result { - bool_and_or_batch!(values, bool_and) -} - -/// dynamically-typed bool_or(array) -> ScalarValue -fn bool_or_batch(values: &ArrayRef) -> Result { - bool_and_or_batch!(values, bool_or) -} - -/// BOOL_AND aggregate expression -#[derive(Debug, Clone)] -pub struct BoolAnd { - name: String, - pub data_type: DataType, - expr: Arc, - nullable: bool, -} - -impl BoolAnd { - /// Create a new BOOL_AND aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - expr, - data_type, - nullable: true, - } - } -} - -impl AggregateExpr for BoolAnd { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new( - &self.name, - self.data_type.clone(), - self.nullable, - )) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::::default()) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "bool_and"), - self.data_type.clone(), - self.nullable, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } - - fn groups_accumulator_supported(&self) -> bool { - true - } - - fn create_groups_accumulator(&self) -> Result> { - match self.data_type { - DataType::Boolean => { - Ok(Box::new(BooleanGroupsAccumulator::new(|x, y| x && y))) - } - _ => not_impl_err!( - "GroupsAccumulator not supported for {} with {}", - self.name(), - self.data_type - ), - } - } - - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone())) - } - - fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::::default()) - } -} - -impl PartialEq for BoolAnd { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.nullable == x.nullable - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -#[derive(Debug, Default)] -struct BoolAndAccumulator { - acc: Option, -} - -impl Accumulator for BoolAndAccumulator { - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; - self.acc = match (self.acc, bool_and_batch(values)?) { - (None, ScalarValue::Boolean(v)) => v, - (Some(v), ScalarValue::Boolean(None)) => Some(v), - (Some(a), ScalarValue::Boolean(Some(b))) => Some(a && b), - _ => unreachable!(), - }; - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - self.update_batch(states) - } - - fn state(&mut self) -> Result> { - Ok(vec![ScalarValue::Boolean(self.acc)]) - } - - fn evaluate(&mut self) -> Result { - Ok(ScalarValue::Boolean(self.acc)) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } -} - -/// BOOL_OR aggregate expression -#[derive(Debug, Clone)] -pub struct BoolOr { - name: String, - pub data_type: DataType, - expr: Arc, - nullable: bool, -} - -impl BoolOr { - /// Create a new BOOL_OR aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - expr, - data_type, - nullable: true, - } - } -} - -impl AggregateExpr for BoolOr { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new( - &self.name, - self.data_type.clone(), - self.nullable, - )) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::::default()) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "bool_or"), - self.data_type.clone(), - self.nullable, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } - - fn groups_accumulator_supported(&self) -> bool { - true - } - - fn create_groups_accumulator(&self) -> Result> { - match self.data_type { - DataType::Boolean => { - Ok(Box::new(BooleanGroupsAccumulator::new(|x, y| x || y))) - } - _ => not_impl_err!( - "GroupsAccumulator not supported for {} with {}", - self.name(), - self.data_type - ), - } - } - - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone())) - } - - fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::::default()) - } -} - -impl PartialEq for BoolOr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.nullable == x.nullable - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -#[derive(Debug, Default)] -struct BoolOrAccumulator { - acc: Option, -} - -impl Accumulator for BoolOrAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![ScalarValue::Boolean(self.acc)]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; - self.acc = match (self.acc, bool_or_batch(values)?) { - (None, ScalarValue::Boolean(v)) => v, - (Some(v), ScalarValue::Boolean(None)) => Some(v), - (Some(a), ScalarValue::Boolean(Some(b))) => Some(a || b), - _ => unreachable!(), - }; - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - self.update_batch(states) - } - - fn evaluate(&mut self) -> Result { - Ok(ScalarValue::Boolean(self.acc)) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::expressions::col; - use crate::expressions::tests::aggregate; - use crate::generic_test_op; - use arrow::datatypes::*; - use arrow::record_batch::RecordBatch; - - #[test] - fn test_bool_and() -> Result<()> { - let a: ArrayRef = Arc::new(BooleanArray::from(vec![true, true, false])); - generic_test_op!(a, DataType::Boolean, BoolAnd, ScalarValue::from(false)) - } - - #[test] - fn bool_and_with_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(BooleanArray::from(vec![ - Some(true), - None, - Some(true), - Some(true), - ])); - generic_test_op!(a, DataType::Boolean, BoolAnd, ScalarValue::from(true)) - } - - #[test] - fn bool_and_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(BooleanArray::from(vec![None, None])); - generic_test_op!(a, DataType::Boolean, BoolAnd, ScalarValue::Boolean(None)) - } - - #[test] - fn test_bool_or() -> Result<()> { - let a: ArrayRef = Arc::new(BooleanArray::from(vec![true, true, false])); - generic_test_op!(a, DataType::Boolean, BoolOr, ScalarValue::from(true)) - } - - #[test] - fn bool_or_with_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(BooleanArray::from(vec![ - Some(false), - None, - Some(false), - Some(false), - ])); - generic_test_op!(a, DataType::Boolean, BoolOr, ScalarValue::from(false)) - } - - #[test] - fn bool_or_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(BooleanArray::from(vec![None, None])); - generic_test_op!(a, DataType::Boolean, BoolOr, ScalarValue::Boolean(None)) - } -} diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs deleted file mode 100644 index c549e6219375..000000000000 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ /dev/null @@ -1,1358 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Declaration of built-in (aggregate) functions. -//! This module contains built-in aggregates' enumeration and metadata. -//! -//! Generally, an aggregate has: -//! * a signature -//! * a return type, that is a function of the incoming argument's types -//! * the computation, that must accept each valid signature -//! -//! * Signature: see `Signature` -//! * Return type: a function `(arg_types) -> return_type`. E.g. for min, ([f32]) -> f32, ([f64]) -> f64. - -use std::sync::Arc; - -use arrow::datatypes::Schema; - -use datafusion_common::{exec_err, not_impl_err, Result}; -use datafusion_expr::AggregateFunction; - -use crate::aggregate::regr::RegrType; -use crate::expressions::{self, Literal}; -use crate::{AggregateExpr, PhysicalExpr, PhysicalSortExpr}; - -/// Create a physical aggregation expression. -/// This function errors when `input_phy_exprs`' can't be coerced to a valid argument type of the aggregation function. -pub fn create_aggregate_expr( - fun: &AggregateFunction, - distinct: bool, - input_phy_exprs: &[Arc], - ordering_req: &[PhysicalSortExpr], - input_schema: &Schema, - name: impl Into, - ignore_nulls: bool, -) -> Result> { - let name = name.into(); - // get the result data type for this aggregate function - let input_phy_types = input_phy_exprs - .iter() - .map(|e| e.data_type(input_schema)) - .collect::>>()?; - let data_type = input_phy_types[0].clone(); - let ordering_types = ordering_req - .iter() - .map(|e| e.expr.data_type(input_schema)) - .collect::>>()?; - let input_phy_exprs = input_phy_exprs.to_vec(); - Ok(match (fun, distinct) { - (AggregateFunction::Count, false) => Arc::new( - expressions::Count::new_with_multiple_exprs(input_phy_exprs, name, data_type), - ), - (AggregateFunction::Count, true) => Arc::new(expressions::DistinctCount::new( - data_type, - input_phy_exprs[0].clone(), - name, - )), - (AggregateFunction::Grouping, _) => Arc::new(expressions::Grouping::new( - input_phy_exprs[0].clone(), - name, - data_type, - )), - (AggregateFunction::BitAnd, _) => Arc::new(expressions::BitAnd::new( - input_phy_exprs[0].clone(), - name, - data_type, - )), - (AggregateFunction::BitOr, _) => Arc::new(expressions::BitOr::new( - input_phy_exprs[0].clone(), - name, - data_type, - )), - (AggregateFunction::BitXor, false) => Arc::new(expressions::BitXor::new( - input_phy_exprs[0].clone(), - name, - data_type, - )), - (AggregateFunction::BitXor, true) => Arc::new(expressions::DistinctBitXor::new( - input_phy_exprs[0].clone(), - name, - data_type, - )), - (AggregateFunction::BoolAnd, _) => Arc::new(expressions::BoolAnd::new( - input_phy_exprs[0].clone(), - name, - data_type, - )), - (AggregateFunction::BoolOr, _) => Arc::new(expressions::BoolOr::new( - input_phy_exprs[0].clone(), - name, - data_type, - )), - (AggregateFunction::Sum, false) => Arc::new(expressions::Sum::new( - input_phy_exprs[0].clone(), - name, - input_phy_types[0].clone(), - )), - (AggregateFunction::Sum, true) => Arc::new(expressions::DistinctSum::new( - vec![input_phy_exprs[0].clone()], - name, - data_type, - )), - (AggregateFunction::ApproxDistinct, _) => Arc::new( - expressions::ApproxDistinct::new(input_phy_exprs[0].clone(), name, data_type), - ), - (AggregateFunction::ArrayAgg, false) => { - let expr = input_phy_exprs[0].clone(); - let nullable = expr.nullable(input_schema)?; - - if ordering_req.is_empty() { - Arc::new(expressions::ArrayAgg::new(expr, name, data_type, nullable)) - } else { - Arc::new(expressions::OrderSensitiveArrayAgg::new( - expr, - name, - data_type, - nullable, - ordering_types, - ordering_req.to_vec(), - )) - } - } - (AggregateFunction::ArrayAgg, true) => { - if !ordering_req.is_empty() { - return not_impl_err!( - "ARRAY_AGG(DISTINCT ORDER BY a ASC) order-sensitive aggregations are not available" - ); - } - let expr = input_phy_exprs[0].clone(); - let is_expr_nullable = expr.nullable(input_schema)?; - Arc::new(expressions::DistinctArrayAgg::new( - expr, - name, - data_type, - is_expr_nullable, - )) - } - (AggregateFunction::Min, _) => Arc::new(expressions::Min::new( - input_phy_exprs[0].clone(), - name, - data_type, - )), - (AggregateFunction::Max, _) => Arc::new(expressions::Max::new( - input_phy_exprs[0].clone(), - name, - data_type, - )), - (AggregateFunction::Avg, false) => Arc::new(expressions::Avg::new( - input_phy_exprs[0].clone(), - name, - data_type, - )), - (AggregateFunction::Avg, true) => { - return not_impl_err!("AVG(DISTINCT) aggregations are not available"); - } - (AggregateFunction::Variance, false) => Arc::new(expressions::Variance::new( - input_phy_exprs[0].clone(), - name, - data_type, - )), - (AggregateFunction::Variance, true) => { - return not_impl_err!("VAR(DISTINCT) aggregations are not available"); - } - (AggregateFunction::VariancePop, false) => Arc::new( - expressions::VariancePop::new(input_phy_exprs[0].clone(), name, data_type), - ), - (AggregateFunction::VariancePop, true) => { - return not_impl_err!("VAR_POP(DISTINCT) aggregations are not available"); - } - (AggregateFunction::Covariance, false) => Arc::new(expressions::Covariance::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - data_type, - )), - (AggregateFunction::Covariance, true) => { - return not_impl_err!("COVAR(DISTINCT) aggregations are not available"); - } - (AggregateFunction::CovariancePop, false) => { - Arc::new(expressions::CovariancePop::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - data_type, - )) - } - (AggregateFunction::CovariancePop, true) => { - return not_impl_err!("COVAR_POP(DISTINCT) aggregations are not available"); - } - (AggregateFunction::Stddev, false) => Arc::new(expressions::Stddev::new( - input_phy_exprs[0].clone(), - name, - data_type, - )), - (AggregateFunction::Stddev, true) => { - return not_impl_err!("STDDEV(DISTINCT) aggregations are not available"); - } - (AggregateFunction::StddevPop, false) => Arc::new(expressions::StddevPop::new( - input_phy_exprs[0].clone(), - name, - data_type, - )), - (AggregateFunction::StddevPop, true) => { - return not_impl_err!("STDDEV_POP(DISTINCT) aggregations are not available"); - } - (AggregateFunction::Correlation, false) => { - Arc::new(expressions::Correlation::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - data_type, - )) - } - (AggregateFunction::Correlation, true) => { - return not_impl_err!("CORR(DISTINCT) aggregations are not available"); - } - (AggregateFunction::RegrSlope, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::Slope, - data_type, - )), - (AggregateFunction::RegrIntercept, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::Intercept, - data_type, - )), - (AggregateFunction::RegrCount, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::Count, - data_type, - )), - (AggregateFunction::RegrR2, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::R2, - data_type, - )), - (AggregateFunction::RegrAvgx, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::AvgX, - data_type, - )), - (AggregateFunction::RegrAvgy, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::AvgY, - data_type, - )), - (AggregateFunction::RegrSXX, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::SXX, - data_type, - )), - (AggregateFunction::RegrSYY, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::SYY, - data_type, - )), - (AggregateFunction::RegrSXY, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::SXY, - data_type, - )), - ( - AggregateFunction::RegrSlope - | AggregateFunction::RegrIntercept - | AggregateFunction::RegrCount - | AggregateFunction::RegrR2 - | AggregateFunction::RegrAvgx - | AggregateFunction::RegrAvgy - | AggregateFunction::RegrSXX - | AggregateFunction::RegrSYY - | AggregateFunction::RegrSXY, - true, - ) => { - return not_impl_err!("{}(DISTINCT) aggregations are not available", fun); - } - (AggregateFunction::ApproxPercentileCont, false) => { - if input_phy_exprs.len() == 2 { - Arc::new(expressions::ApproxPercentileCont::new( - // Pass in the desired percentile expr - input_phy_exprs, - name, - data_type, - )?) - } else { - Arc::new(expressions::ApproxPercentileCont::new_with_max_size( - // Pass in the desired percentile expr - input_phy_exprs, - name, - data_type, - )?) - } - } - (AggregateFunction::ApproxPercentileCont, true) => { - return not_impl_err!( - "approx_percentile_cont(DISTINCT) aggregations are not available" - ); - } - (AggregateFunction::ApproxPercentileContWithWeight, false) => { - Arc::new(expressions::ApproxPercentileContWithWeight::new( - // Pass in the desired percentile expr - input_phy_exprs, - name, - data_type, - )?) - } - (AggregateFunction::ApproxPercentileContWithWeight, true) => { - return not_impl_err!( - "approx_percentile_cont_with_weight(DISTINCT) aggregations are not available" - ); - } - (AggregateFunction::ApproxMedian, false) => { - Arc::new(expressions::ApproxMedian::try_new( - input_phy_exprs[0].clone(), - name, - data_type, - )?) - } - (AggregateFunction::ApproxMedian, true) => { - return not_impl_err!( - "APPROX_MEDIAN(DISTINCT) aggregations are not available" - ); - } - (AggregateFunction::Median, false) => Arc::new(expressions::Median::new( - input_phy_exprs[0].clone(), - name, - data_type, - )), - (AggregateFunction::Median, true) => { - return not_impl_err!("MEDIAN(DISTINCT) aggregations are not available"); - } - (AggregateFunction::FirstValue, _) => Arc::new( - expressions::FirstValue::new( - input_phy_exprs[0].clone(), - name, - input_phy_types[0].clone(), - ordering_req.to_vec(), - ordering_types, - vec![], - ) - .with_ignore_nulls(ignore_nulls), - ), - (AggregateFunction::LastValue, _) => Arc::new( - expressions::LastValue::new( - input_phy_exprs[0].clone(), - name, - input_phy_types[0].clone(), - ordering_req.to_vec(), - ordering_types, - ) - .with_ignore_nulls(ignore_nulls), - ), - (AggregateFunction::NthValue, _) => { - let expr = &input_phy_exprs[0]; - let Some(n) = input_phy_exprs[1] - .as_any() - .downcast_ref::() - .map(|literal| literal.value()) - else { - return exec_err!("Second argument of NTH_VALUE needs to be a literal"); - }; - let nullable = expr.nullable(input_schema)?; - Arc::new(expressions::NthValueAgg::new( - expr.clone(), - n.clone().try_into()?, - name, - input_phy_types[0].clone(), - nullable, - ordering_types, - ordering_req.to_vec(), - )) - } - (AggregateFunction::StringAgg, false) => { - if !ordering_req.is_empty() { - return not_impl_err!( - "STRING_AGG(ORDER BY a ASC) order-sensitive aggregations are not available" - ); - } - Arc::new(expressions::StringAgg::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - data_type, - )) - } - (AggregateFunction::StringAgg, true) => { - return not_impl_err!("STRING_AGG(DISTINCT) aggregations are not available"); - } - }) -} - -#[cfg(test)] -mod tests { - use arrow::datatypes::{DataType, Field}; - - use datafusion_common::{plan_err, DataFusionError, ScalarValue}; - use datafusion_expr::type_coercion::aggregates::NUMERICS; - use datafusion_expr::{type_coercion, Signature}; - - use crate::expressions::{ - try_cast, ApproxDistinct, ApproxMedian, ApproxPercentileCont, ArrayAgg, Avg, - BitAnd, BitOr, BitXor, BoolAnd, BoolOr, Correlation, Count, Covariance, - DistinctArrayAgg, DistinctCount, Max, Min, Stddev, Sum, Variance, - }; - - use super::*; - - #[test] - fn test_count_arragg_approx_expr() -> Result<()> { - let funcs = vec![ - AggregateFunction::Count, - AggregateFunction::ArrayAgg, - AggregateFunction::ApproxDistinct, - ]; - let data_types = vec![ - DataType::UInt32, - DataType::Int32, - DataType::Float32, - DataType::Float64, - DataType::Decimal128(10, 2), - DataType::Utf8, - ]; - for fun in funcs { - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - match fun { - AggregateFunction::Count => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", DataType::Int64, true), - result_agg_phy_exprs.field().unwrap() - ); - } - AggregateFunction::ApproxDistinct => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", DataType::UInt64, false), - result_agg_phy_exprs.field().unwrap() - ); - } - AggregateFunction::ArrayAgg => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new_list( - "c1", - Field::new("item", data_type.clone(), true), - true, - ), - result_agg_phy_exprs.field().unwrap() - ); - } - _ => {} - }; - - let result_distinct = create_physical_agg_expr_for_test( - &fun, - true, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - match fun { - AggregateFunction::Count => { - assert!(result_distinct.as_any().is::()); - assert_eq!("c1", result_distinct.name()); - assert_eq!( - Field::new("c1", DataType::Int64, true), - result_distinct.field().unwrap() - ); - } - AggregateFunction::ApproxDistinct => { - assert!(result_distinct.as_any().is::()); - assert_eq!("c1", result_distinct.name()); - assert_eq!( - Field::new("c1", DataType::UInt64, false), - result_distinct.field().unwrap() - ); - } - AggregateFunction::ArrayAgg => { - assert!(result_distinct.as_any().is::()); - assert_eq!("c1", result_distinct.name()); - assert_eq!( - Field::new_list( - "c1", - Field::new("item", data_type.clone(), true), - true, - ), - result_agg_phy_exprs.field().unwrap() - ); - } - _ => {} - }; - } - } - Ok(()) - } - - #[test] - fn test_agg_approx_percentile_phy_expr() { - for data_type in NUMERICS { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![ - Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - ), - Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.2)))), - ]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &AggregateFunction::ApproxPercentileCont, - false, - &input_phy_exprs[..], - &input_schema, - "c1", - ) - .expect("failed to create aggregate expr"); - - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", data_type.clone(), false), - result_agg_phy_exprs.field().unwrap() - ); - } - } - - #[test] - fn test_agg_approx_percentile_invalid_phy_expr() { - for data_type in NUMERICS { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![ - Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - ), - Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(4.2)))), - ]; - let err = create_physical_agg_expr_for_test( - &AggregateFunction::ApproxPercentileCont, - false, - &input_phy_exprs[..], - &input_schema, - "c1", - ) - .expect_err("should fail due to invalid percentile"); - - assert!(matches!(err, DataFusionError::Plan(_))); - } - } - - #[test] - fn test_min_max_expr() -> Result<()> { - let funcs = vec![AggregateFunction::Min, AggregateFunction::Max]; - let data_types = vec![ - DataType::UInt32, - DataType::Int32, - DataType::Float32, - DataType::Float64, - DataType::Decimal128(10, 2), - DataType::Utf8, - ]; - for fun in funcs { - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - match fun { - AggregateFunction::Min => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", data_type.clone(), true), - result_agg_phy_exprs.field().unwrap() - ); - } - AggregateFunction::Max => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", data_type.clone(), true), - result_agg_phy_exprs.field().unwrap() - ); - } - _ => {} - }; - } - } - Ok(()) - } - - #[test] - fn test_bit_and_or_xor_expr() -> Result<()> { - let funcs = vec![ - AggregateFunction::BitAnd, - AggregateFunction::BitOr, - AggregateFunction::BitXor, - ]; - let data_types = vec![DataType::UInt64, DataType::Int64]; - for fun in funcs { - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - match fun { - AggregateFunction::BitAnd => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", data_type.clone(), true), - result_agg_phy_exprs.field().unwrap() - ); - } - AggregateFunction::BitOr => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", data_type.clone(), true), - result_agg_phy_exprs.field().unwrap() - ); - } - AggregateFunction::BitXor => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", data_type.clone(), true), - result_agg_phy_exprs.field().unwrap() - ); - } - _ => {} - }; - } - } - Ok(()) - } - - #[test] - fn test_bool_and_or_expr() -> Result<()> { - let funcs = vec![AggregateFunction::BoolAnd, AggregateFunction::BoolOr]; - let data_types = vec![DataType::Boolean]; - for fun in funcs { - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - match fun { - AggregateFunction::BoolAnd => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", data_type.clone(), true), - result_agg_phy_exprs.field().unwrap() - ); - } - AggregateFunction::BoolOr => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", data_type.clone(), true), - result_agg_phy_exprs.field().unwrap() - ); - } - _ => {} - }; - } - } - Ok(()) - } - - #[test] - fn test_sum_avg_expr() -> Result<()> { - let funcs = vec![AggregateFunction::Sum, AggregateFunction::Avg]; - let data_types = vec![ - DataType::UInt32, - DataType::UInt64, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, - ]; - for fun in funcs { - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - match fun { - AggregateFunction::Sum => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - let expect_type = match data_type { - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 => DataType::UInt64, - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 => DataType::Int64, - DataType::Float32 | DataType::Float64 => DataType::Float64, - _ => data_type.clone(), - }; - - assert_eq!( - Field::new("c1", expect_type.clone(), true), - result_agg_phy_exprs.field().unwrap() - ); - } - AggregateFunction::Avg => { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", DataType::Float64, true), - result_agg_phy_exprs.field().unwrap() - ); - } - _ => {} - }; - } - } - Ok(()) - } - - #[test] - fn test_variance_expr() -> Result<()> { - let funcs = vec![AggregateFunction::Variance]; - let data_types = vec![ - DataType::UInt32, - DataType::UInt64, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, - ]; - for fun in funcs { - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - if fun == AggregateFunction::Variance { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", DataType::Float64, true), - result_agg_phy_exprs.field().unwrap() - ) - } - } - } - Ok(()) - } - - #[test] - fn test_var_pop_expr() -> Result<()> { - let funcs = vec![AggregateFunction::VariancePop]; - let data_types = vec![ - DataType::UInt32, - DataType::UInt64, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, - ]; - for fun in funcs { - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - if fun == AggregateFunction::Variance { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", DataType::Float64, true), - result_agg_phy_exprs.field().unwrap() - ) - } - } - } - Ok(()) - } - - #[test] - fn test_stddev_expr() -> Result<()> { - let funcs = vec![AggregateFunction::Stddev]; - let data_types = vec![ - DataType::UInt32, - DataType::UInt64, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, - ]; - for fun in funcs { - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - if fun == AggregateFunction::Variance { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", DataType::Float64, true), - result_agg_phy_exprs.field().unwrap() - ) - } - } - } - Ok(()) - } - - #[test] - fn test_stddev_pop_expr() -> Result<()> { - let funcs = vec![AggregateFunction::StddevPop]; - let data_types = vec![ - DataType::UInt32, - DataType::UInt64, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, - ]; - for fun in funcs { - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - if fun == AggregateFunction::Variance { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", DataType::Float64, true), - result_agg_phy_exprs.field().unwrap() - ) - } - } - } - Ok(()) - } - - #[test] - fn test_covar_expr() -> Result<()> { - let funcs = vec![AggregateFunction::Covariance]; - let data_types = vec![ - DataType::UInt32, - DataType::UInt64, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, - ]; - for fun in funcs { - for data_type in &data_types { - let input_schema = Schema::new(vec![ - Field::new("c1", data_type.clone(), true), - Field::new("c2", data_type.clone(), true), - ]); - let input_phy_exprs: Vec> = vec![ - Arc::new( - expressions::Column::new_with_schema("c1", &input_schema) - .unwrap(), - ), - Arc::new( - expressions::Column::new_with_schema("c2", &input_schema) - .unwrap(), - ), - ]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..2], - &input_schema, - "c1", - )?; - if fun == AggregateFunction::Covariance { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", DataType::Float64, true), - result_agg_phy_exprs.field().unwrap() - ) - } - } - } - Ok(()) - } - - #[test] - fn test_covar_pop_expr() -> Result<()> { - let funcs = vec![AggregateFunction::CovariancePop]; - let data_types = vec![ - DataType::UInt32, - DataType::UInt64, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, - ]; - for fun in funcs { - for data_type in &data_types { - let input_schema = Schema::new(vec![ - Field::new("c1", data_type.clone(), true), - Field::new("c2", data_type.clone(), true), - ]); - let input_phy_exprs: Vec> = vec![ - Arc::new( - expressions::Column::new_with_schema("c1", &input_schema) - .unwrap(), - ), - Arc::new( - expressions::Column::new_with_schema("c2", &input_schema) - .unwrap(), - ), - ]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..2], - &input_schema, - "c1", - )?; - if fun == AggregateFunction::Covariance { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", DataType::Float64, true), - result_agg_phy_exprs.field().unwrap() - ) - } - } - } - Ok(()) - } - - #[test] - fn test_corr_expr() -> Result<()> { - let funcs = vec![AggregateFunction::Correlation]; - let data_types = vec![ - DataType::UInt32, - DataType::UInt64, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, - ]; - for fun in funcs { - for data_type in &data_types { - let input_schema = Schema::new(vec![ - Field::new("c1", data_type.clone(), true), - Field::new("c2", data_type.clone(), true), - ]); - let input_phy_exprs: Vec> = vec![ - Arc::new( - expressions::Column::new_with_schema("c1", &input_schema) - .unwrap(), - ), - Arc::new( - expressions::Column::new_with_schema("c2", &input_schema) - .unwrap(), - ), - ]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..2], - &input_schema, - "c1", - )?; - if fun == AggregateFunction::Covariance { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", DataType::Float64, true), - result_agg_phy_exprs.field().unwrap() - ) - } - } - } - Ok(()) - } - - #[test] - fn test_median_expr() -> Result<()> { - let funcs = vec![AggregateFunction::ApproxMedian]; - let data_types = vec![ - DataType::UInt32, - DataType::UInt64, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, - ]; - for fun in funcs { - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - - if fun == AggregateFunction::ApproxMedian { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", data_type.clone(), true), - result_agg_phy_exprs.field().unwrap() - ); - } - } - } - Ok(()) - } - - #[test] - fn test_median() -> Result<()> { - let observed = AggregateFunction::ApproxMedian.return_type(&[DataType::Utf8]); - assert!(observed.is_err()); - - let observed = AggregateFunction::ApproxMedian.return_type(&[DataType::Int32])?; - assert_eq!(DataType::Int32, observed); - - let observed = - AggregateFunction::ApproxMedian.return_type(&[DataType::Decimal128(10, 6)]); - assert!(observed.is_err()); - - Ok(()) - } - - #[test] - fn test_min_max() -> Result<()> { - let observed = AggregateFunction::Min.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Utf8, observed); - - let observed = AggregateFunction::Max.return_type(&[DataType::Int32])?; - assert_eq!(DataType::Int32, observed); - - // test decimal for min - let observed = - AggregateFunction::Min.return_type(&[DataType::Decimal128(10, 6)])?; - assert_eq!(DataType::Decimal128(10, 6), observed); - - // test decimal for max - let observed = - AggregateFunction::Max.return_type(&[DataType::Decimal128(28, 13)])?; - assert_eq!(DataType::Decimal128(28, 13), observed); - - Ok(()) - } - - #[test] - fn test_sum_return_type() -> Result<()> { - let observed = AggregateFunction::Sum.return_type(&[DataType::Int32])?; - assert_eq!(DataType::Int64, observed); - - let observed = AggregateFunction::Sum.return_type(&[DataType::UInt8])?; - assert_eq!(DataType::UInt64, observed); - - let observed = AggregateFunction::Sum.return_type(&[DataType::Float32])?; - assert_eq!(DataType::Float64, observed); - - let observed = AggregateFunction::Sum.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - let observed = - AggregateFunction::Sum.return_type(&[DataType::Decimal128(10, 5)])?; - assert_eq!(DataType::Decimal128(20, 5), observed); - - let observed = - AggregateFunction::Sum.return_type(&[DataType::Decimal128(35, 5)])?; - assert_eq!(DataType::Decimal128(38, 5), observed); - - Ok(()) - } - - #[test] - fn test_sum_no_utf8() { - let observed = AggregateFunction::Sum.return_type(&[DataType::Utf8]); - assert!(observed.is_err()); - } - - #[test] - fn test_sum_upcasts() -> Result<()> { - let observed = AggregateFunction::Sum.return_type(&[DataType::UInt32])?; - assert_eq!(DataType::UInt64, observed); - Ok(()) - } - - #[test] - fn test_count_return_type() -> Result<()> { - let observed = AggregateFunction::Count.return_type(&[DataType::Utf8])?; - assert_eq!(DataType::Int64, observed); - - let observed = AggregateFunction::Count.return_type(&[DataType::Int8])?; - assert_eq!(DataType::Int64, observed); - - let observed = - AggregateFunction::Count.return_type(&[DataType::Decimal128(28, 13)])?; - assert_eq!(DataType::Int64, observed); - Ok(()) - } - - #[test] - fn test_avg_return_type() -> Result<()> { - let observed = AggregateFunction::Avg.return_type(&[DataType::Float32])?; - assert_eq!(DataType::Float64, observed); - - let observed = AggregateFunction::Avg.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - let observed = AggregateFunction::Avg.return_type(&[DataType::Int32])?; - assert_eq!(DataType::Float64, observed); - - let observed = - AggregateFunction::Avg.return_type(&[DataType::Decimal128(10, 6)])?; - assert_eq!(DataType::Decimal128(14, 10), observed); - - let observed = - AggregateFunction::Avg.return_type(&[DataType::Decimal128(36, 6)])?; - assert_eq!(DataType::Decimal128(38, 10), observed); - Ok(()) - } - - #[test] - fn test_avg_no_utf8() { - let observed = AggregateFunction::Avg.return_type(&[DataType::Utf8]); - assert!(observed.is_err()); - } - - #[test] - fn test_variance_return_type() -> Result<()> { - let observed = AggregateFunction::Variance.return_type(&[DataType::Float32])?; - assert_eq!(DataType::Float64, observed); - - let observed = AggregateFunction::Variance.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - let observed = AggregateFunction::Variance.return_type(&[DataType::Int32])?; - assert_eq!(DataType::Float64, observed); - - let observed = AggregateFunction::Variance.return_type(&[DataType::UInt32])?; - assert_eq!(DataType::Float64, observed); - - let observed = AggregateFunction::Variance.return_type(&[DataType::Int64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_variance_no_utf8() { - let observed = AggregateFunction::Variance.return_type(&[DataType::Utf8]); - assert!(observed.is_err()); - } - - #[test] - fn test_stddev_return_type() -> Result<()> { - let observed = AggregateFunction::Stddev.return_type(&[DataType::Float32])?; - assert_eq!(DataType::Float64, observed); - - let observed = AggregateFunction::Stddev.return_type(&[DataType::Float64])?; - assert_eq!(DataType::Float64, observed); - - let observed = AggregateFunction::Stddev.return_type(&[DataType::Int32])?; - assert_eq!(DataType::Float64, observed); - - let observed = AggregateFunction::Stddev.return_type(&[DataType::UInt32])?; - assert_eq!(DataType::Float64, observed); - - let observed = AggregateFunction::Stddev.return_type(&[DataType::Int64])?; - assert_eq!(DataType::Float64, observed); - - Ok(()) - } - - #[test] - fn test_stddev_no_utf8() { - let observed = AggregateFunction::Stddev.return_type(&[DataType::Utf8]); - assert!(observed.is_err()); - } - - // Helper function - // Create aggregate expr with type coercion - fn create_physical_agg_expr_for_test( - fun: &AggregateFunction, - distinct: bool, - input_phy_exprs: &[Arc], - input_schema: &Schema, - name: impl Into, - ) -> Result> { - let name = name.into(); - let coerced_phy_exprs = - coerce_exprs_for_test(fun, input_phy_exprs, input_schema, &fun.signature())?; - if coerced_phy_exprs.is_empty() { - return plan_err!( - "Invalid or wrong number of arguments passed to aggregate: '{name}'" - ); - } - create_aggregate_expr( - fun, - distinct, - &coerced_phy_exprs, - &[], - input_schema, - name, - false, - ) - } - - // Returns the coerced exprs for each `input_exprs`. - // Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the - // data type of `input_exprs` need to be coerced. - fn coerce_exprs_for_test( - agg_fun: &AggregateFunction, - input_exprs: &[Arc], - schema: &Schema, - signature: &Signature, - ) -> Result>> { - if input_exprs.is_empty() { - return Ok(vec![]); - } - let input_types = input_exprs - .iter() - .map(|e| e.data_type(schema)) - .collect::>>()?; - - // get the coerced data types - let coerced_types = - type_coercion::aggregates::coerce_types(agg_fun, &input_types, signature)?; - - // try cast if need - input_exprs - .iter() - .zip(coerced_types) - .map(|(expr, coerced_type)| try_cast(expr.clone(), schema, coerced_type)) - .collect::>>() - } -} diff --git a/datafusion/physical-expr/src/aggregate/correlation.rs b/datafusion/physical-expr/src/aggregate/correlation.rs deleted file mode 100644 index a47d35053208..000000000000 --- a/datafusion/physical-expr/src/aggregate/correlation.rs +++ /dev/null @@ -1,524 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expressions that can evaluated at runtime during query execution - -use crate::aggregate::covariance::CovarianceAccumulator; -use crate::aggregate::stats::StatsType; -use crate::aggregate::stddev::StddevAccumulator; -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::{ - array::ArrayRef, - compute::{and, filter, is_not_null}, - datatypes::{DataType, Field}, -}; -use datafusion_common::Result; -use datafusion_common::ScalarValue; -use datafusion_expr::Accumulator; -use std::any::Any; -use std::sync::Arc; - -/// CORR aggregate expression -#[derive(Debug)] -pub struct Correlation { - name: String, - expr1: Arc, - expr2: Arc, -} - -impl Correlation { - /// Create a new COVAR_POP aggregate function - pub fn new( - expr1: Arc, - expr2: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - // the result of correlation just support FLOAT64 data type. - assert!(matches!(data_type, DataType::Float64)); - Self { - name: name.into(), - expr1, - expr2, - } - } -} - -impl AggregateExpr for Correlation { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64, true)) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(CorrelationAccumulator::try_new()?)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![ - Field::new( - format_state_name(&self.name, "count"), - DataType::UInt64, - true, - ), - Field::new( - format_state_name(&self.name, "mean1"), - DataType::Float64, - true, - ), - Field::new( - format_state_name(&self.name, "m2_1"), - DataType::Float64, - true, - ), - Field::new( - format_state_name(&self.name, "mean2"), - DataType::Float64, - true, - ), - Field::new( - format_state_name(&self.name, "m2_2"), - DataType::Float64, - true, - ), - Field::new( - format_state_name(&self.name, "algo_const"), - DataType::Float64, - true, - ), - ]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr1.clone(), self.expr2.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for Correlation { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name && self.expr1.eq(&x.expr1) && self.expr2.eq(&x.expr2) - }) - .unwrap_or(false) - } -} - -/// An accumulator to compute correlation -#[derive(Debug)] -pub struct CorrelationAccumulator { - covar: CovarianceAccumulator, - stddev1: StddevAccumulator, - stddev2: StddevAccumulator, -} - -impl CorrelationAccumulator { - /// Creates a new `CorrelationAccumulator` - pub fn try_new() -> Result { - Ok(Self { - covar: CovarianceAccumulator::try_new(StatsType::Population)?, - stddev1: StddevAccumulator::try_new(StatsType::Population)?, - stddev2: StddevAccumulator::try_new(StatsType::Population)?, - }) - } -} - -impl Accumulator for CorrelationAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![ - ScalarValue::from(self.covar.get_count()), - ScalarValue::from(self.covar.get_mean1()), - ScalarValue::from(self.stddev1.get_m2()), - ScalarValue::from(self.covar.get_mean2()), - ScalarValue::from(self.stddev2.get_m2()), - ScalarValue::from(self.covar.get_algo_const()), - ]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - // TODO: null input skipping logic duplicated across Correlation - // and its children accumulators. - // This could be simplified by splitting up input filtering and - // calculation logic in children accumulators, and calling only - // calculation part from Correlation - let values = if values[0].null_count() != 0 || values[1].null_count() != 0 { - let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?; - let values1 = filter(&values[0], &mask)?; - let values2 = filter(&values[1], &mask)?; - - vec![values1, values2] - } else { - values.to_vec() - }; - - self.covar.update_batch(&values)?; - self.stddev1.update_batch(&values[0..1])?; - self.stddev2.update_batch(&values[1..2])?; - Ok(()) - } - - fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = if values[0].null_count() != 0 || values[1].null_count() != 0 { - let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?; - let values1 = filter(&values[0], &mask)?; - let values2 = filter(&values[1], &mask)?; - - vec![values1, values2] - } else { - values.to_vec() - }; - - self.covar.retract_batch(&values)?; - self.stddev1.retract_batch(&values[0..1])?; - self.stddev2.retract_batch(&values[1..2])?; - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let states_c = [ - states[0].clone(), - states[1].clone(), - states[3].clone(), - states[5].clone(), - ]; - let states_s1 = [states[0].clone(), states[1].clone(), states[2].clone()]; - let states_s2 = [states[0].clone(), states[3].clone(), states[4].clone()]; - - self.covar.merge_batch(&states_c)?; - self.stddev1.merge_batch(&states_s1)?; - self.stddev2.merge_batch(&states_s2)?; - Ok(()) - } - - fn evaluate(&mut self) -> Result { - let covar = self.covar.evaluate()?; - let stddev1 = self.stddev1.evaluate()?; - let stddev2 = self.stddev2.evaluate()?; - - if let ScalarValue::Float64(Some(c)) = covar { - if let ScalarValue::Float64(Some(s1)) = stddev1 { - if let ScalarValue::Float64(Some(s2)) = stddev2 { - if s1 == 0_f64 || s2 == 0_f64 { - return Ok(ScalarValue::Float64(Some(0_f64))); - } else { - return Ok(ScalarValue::Float64(Some(c / s1 / s2))); - } - } - } - } - - Ok(ScalarValue::Float64(None)) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.covar) - + self.covar.size() - - std::mem::size_of_val(&self.stddev1) - + self.stddev1.size() - - std::mem::size_of_val(&self.stddev2) - + self.stddev2.size() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::aggregate::utils::get_accum_scalar_values_as_arrays; - use crate::expressions::col; - use crate::expressions::tests::aggregate; - use crate::generic_test_op2; - use arrow::{array::*, datatypes::*}; - - #[test] - fn correlation_f64_1() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64])); - let b: ArrayRef = Arc::new(Float64Array::from(vec![4_f64, 5_f64, 7_f64])); - - generic_test_op2!( - a, - b, - DataType::Float64, - DataType::Float64, - Correlation, - ScalarValue::from(0.9819805060619659_f64) - ) - } - - #[test] - fn correlation_f64_2() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64])); - let b: ArrayRef = Arc::new(Float64Array::from(vec![4_f64, -5_f64, 6_f64])); - - generic_test_op2!( - a, - b, - DataType::Float64, - DataType::Float64, - Correlation, - ScalarValue::from(0.17066403719657236_f64) - ) - } - - #[test] - fn correlation_f64_4() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64])); - let b: ArrayRef = Arc::new(Float64Array::from(vec![4.1_f64, 5_f64, 6_f64])); - - generic_test_op2!( - a, - b, - DataType::Float64, - DataType::Float64, - Correlation, - ScalarValue::from(1_f64) - ) - } - - #[test] - fn correlation_f64_6() -> Result<()> { - let a = Arc::new(Float64Array::from(vec![ - 1_f64, 2_f64, 3_f64, 1.1_f64, 2.2_f64, 3.3_f64, - ])); - let b = Arc::new(Float64Array::from(vec![ - 4_f64, 5_f64, 6_f64, 4.4_f64, 5.5_f64, 6.6_f64, - ])); - - generic_test_op2!( - a, - b, - DataType::Float64, - DataType::Float64, - Correlation, - ScalarValue::from(0.9860135594710389_f64) - ) - } - - #[test] - fn correlation_i32() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); - let b: ArrayRef = Arc::new(Int32Array::from(vec![4, 5, 6])); - - generic_test_op2!( - a, - b, - DataType::Int32, - DataType::Int32, - Correlation, - ScalarValue::from(1_f64) - ) - } - - #[test] - fn correlation_u32() -> Result<()> { - let a: ArrayRef = Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32])); - let b: ArrayRef = Arc::new(UInt32Array::from(vec![4_u32, 5_u32, 6_u32])); - generic_test_op2!( - a, - b, - DataType::UInt32, - DataType::UInt32, - Correlation, - ScalarValue::from(1_f64) - ) - } - - #[test] - fn correlation_f32() -> Result<()> { - let a: ArrayRef = Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32])); - let b: ArrayRef = Arc::new(Float32Array::from(vec![4_f32, 5_f32, 6_f32])); - generic_test_op2!( - a, - b, - DataType::Float32, - DataType::Float32, - Correlation, - ScalarValue::from(1_f64) - ) - } - - #[test] - fn correlation_i32_with_nulls_1() -> Result<()> { - let a: ArrayRef = - Arc::new(Int32Array::from(vec![Some(1), None, Some(3), Some(3)])); - let b: ArrayRef = - Arc::new(Int32Array::from(vec![Some(4), None, Some(6), Some(3)])); - - generic_test_op2!( - a, - b, - DataType::Int32, - DataType::Int32, - Correlation, - ScalarValue::from(0.1889822365046137_f64) - ) - } - - #[test] - fn correlation_i32_with_nulls_2() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - None, - Some(2), - Some(9), - Some(3), - ])); - let b: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(4), - Some(5), - Some(5), - None, - Some(6), - ])); - - generic_test_op2!( - a, - b, - DataType::Int32, - DataType::Int32, - Correlation, - ScalarValue::from(1_f64) - ) - } - - #[test] - fn correlation_i32_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - let b: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - - generic_test_op2!( - a, - b, - DataType::Int32, - DataType::Int32, - Correlation, - ScalarValue::Float64(None) - ) - } - - #[test] - fn correlation_f64_merge_1() -> Result<()> { - let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64])); - let b = Arc::new(Float64Array::from(vec![4_f64, 5_f64, 6_f64])); - let c = Arc::new(Float64Array::from(vec![1.1_f64, 2.2_f64, 3.3_f64])); - let d = Arc::new(Float64Array::from(vec![4.4_f64, 5.5_f64, 9.9_f64])); - - let schema = Schema::new(vec![ - Field::new("a", DataType::Float64, true), - Field::new("b", DataType::Float64, true), - ]); - - let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a, b])?; - let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![c, d])?; - - let agg1 = Arc::new(Correlation::new( - col("a", &schema)?, - col("b", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - - let agg2 = Arc::new(Correlation::new( - col("a", &schema)?, - col("b", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - - let actual = merge(&batch1, &batch2, agg1, agg2)?; - assert!(actual == ScalarValue::from(0.8443707186481967)); - - Ok(()) - } - - #[test] - fn correlation_f64_merge_2() -> Result<()> { - let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64])); - let b = Arc::new(Float64Array::from(vec![4_f64, 5_f64, 6_f64])); - let c = Arc::new(Float64Array::from(vec![None])); - let d = Arc::new(Float64Array::from(vec![None])); - - let schema = Schema::new(vec![ - Field::new("a", DataType::Float64, true), - Field::new("b", DataType::Float64, true), - ]); - - let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a, b])?; - let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![c, d])?; - - let agg1 = Arc::new(Correlation::new( - col("a", &schema)?, - col("b", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - - let agg2 = Arc::new(Correlation::new( - col("a", &schema)?, - col("b", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - - let actual = merge(&batch1, &batch2, agg1, agg2)?; - assert!(actual == ScalarValue::from(1_f64)); - - Ok(()) - } - - fn merge( - batch1: &RecordBatch, - batch2: &RecordBatch, - agg1: Arc, - agg2: Arc, - ) -> Result { - let mut accum1 = agg1.create_accumulator()?; - let mut accum2 = agg2.create_accumulator()?; - let expr1 = agg1.expressions(); - let expr2 = agg2.expressions(); - - let values1 = expr1 - .iter() - .map(|e| { - e.evaluate(batch1) - .and_then(|v| v.into_array(batch1.num_rows())) - }) - .collect::>>()?; - let values2 = expr2 - .iter() - .map(|e| { - e.evaluate(batch2) - .and_then(|v| v.into_array(batch2.num_rows())) - }) - .collect::>>()?; - accum1.update_batch(&values1)?; - accum2.update_batch(&values2)?; - let state2 = get_accum_scalar_values_as_arrays(accum2.as_mut())?; - accum1.merge_batch(&state2)?; - accum1.evaluate() - } -} diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs deleted file mode 100644 index 567a5589cb8b..000000000000 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ /dev/null @@ -1,453 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expressions that can evaluated at runtime during query execution - -use std::any::Any; -use std::fmt::Debug; -use std::ops::BitAnd; -use std::sync::Arc; - -use crate::aggregate::utils::down_cast_any_ref; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::array::{Array, Int64Array}; -use arrow::compute; -use arrow::datatypes::DataType; -use arrow::{array::ArrayRef, datatypes::Field}; -use arrow_array::cast::AsArray; -use arrow_array::types::Int64Type; -use arrow_array::PrimitiveArray; -use arrow_buffer::BooleanBuffer; -use datafusion_common::{downcast_value, ScalarValue}; -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::{Accumulator, EmitTo, GroupsAccumulator}; - -use crate::expressions::format_state_name; - -use super::groups_accumulator::accumulate::accumulate_indices; - -/// COUNT aggregate expression -/// Returns the amount of non-null values of the given expression. -#[derive(Debug, Clone)] -pub struct Count { - name: String, - data_type: DataType, - nullable: bool, - /// Input exprs - /// - /// For `COUNT(c1)` this is `[c1]` - /// For `COUNT(c1, c2)` this is `[c1, c2]` - exprs: Vec>, -} - -impl Count { - /// Create a new COUNT aggregate function. - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - exprs: vec![expr], - data_type, - nullable: true, - } - } - - pub fn new_with_multiple_exprs( - exprs: Vec>, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - exprs, - data_type, - nullable: true, - } - } -} - -/// An accumulator to compute the counts of [`PrimitiveArray`]. -/// Stores values as native types, and does overflow checking -/// -/// Unlike most other accumulators, COUNT never produces NULLs. If no -/// non-null values are seen in any group the output is 0. Thus, this -/// accumulator has no additional null or seen filter tracking. -#[derive(Debug)] -struct CountGroupsAccumulator { - /// Count per group. - /// - /// Note this is an i64 and not a u64 (or usize) because the - /// output type of count is `DataType::Int64`. Thus by using `i64` - /// for the counts, the output [`Int64Array`] can be created - /// without copy. - counts: Vec, -} - -impl CountGroupsAccumulator { - pub fn new() -> Self { - Self { counts: vec![] } - } -} - -impl GroupsAccumulator for CountGroupsAccumulator { - fn update_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - opt_filter: Option<&arrow_array::BooleanArray>, - total_num_groups: usize, - ) -> Result<()> { - assert_eq!(values.len(), 1, "single argument to update_batch"); - let values = &values[0]; - - // Add one to each group's counter for each non null, non - // filtered value - self.counts.resize(total_num_groups, 0); - accumulate_indices( - group_indices, - values.logical_nulls().as_ref(), - opt_filter, - |group_index| { - self.counts[group_index] += 1; - }, - ); - - Ok(()) - } - - fn merge_batch( - &mut self, - values: &[ArrayRef], - group_indices: &[usize], - opt_filter: Option<&arrow_array::BooleanArray>, - total_num_groups: usize, - ) -> Result<()> { - assert_eq!(values.len(), 1, "one argument to merge_batch"); - // first batch is counts, second is partial sums - let partial_counts = values[0].as_primitive::(); - - // intermediate counts are always created as non null - assert_eq!(partial_counts.null_count(), 0); - let partial_counts = partial_counts.values(); - - // Adds the counts with the partial counts - self.counts.resize(total_num_groups, 0); - match opt_filter { - Some(filter) => filter - .iter() - .zip(group_indices.iter()) - .zip(partial_counts.iter()) - .for_each(|((filter_value, &group_index), partial_count)| { - if let Some(true) = filter_value { - self.counts[group_index] += partial_count; - } - }), - None => group_indices.iter().zip(partial_counts.iter()).for_each( - |(&group_index, partial_count)| { - self.counts[group_index] += partial_count; - }, - ), - } - - Ok(()) - } - - fn evaluate(&mut self, emit_to: EmitTo) -> Result { - let counts = emit_to.take_needed(&mut self.counts); - - // Count is always non null (null inputs just don't contribute to the overall values) - let nulls = None; - let array = PrimitiveArray::::new(counts.into(), nulls); - - Ok(Arc::new(array)) - } - - // return arrays for counts - fn state(&mut self, emit_to: EmitTo) -> Result> { - let counts = emit_to.take_needed(&mut self.counts); - let counts: PrimitiveArray = Int64Array::from(counts); // zero copy, no nulls - Ok(vec![Arc::new(counts) as ArrayRef]) - } - - fn size(&self) -> usize { - self.counts.capacity() * std::mem::size_of::() - } -} - -/// count null values for multiple columns -/// for each row if one column value is null, then null_count + 1 -fn null_count_for_multiple_cols(values: &[ArrayRef]) -> usize { - if values.len() > 1 { - let result_bool_buf: Option = values - .iter() - .map(|a| a.logical_nulls()) - .fold(None, |acc, b| match (acc, b) { - (Some(acc), Some(b)) => Some(acc.bitand(b.inner())), - (Some(acc), None) => Some(acc), - (None, Some(b)) => Some(b.into_inner()), - _ => None, - }); - result_bool_buf.map_or(0, |b| values[0].len() - b.count_set_bits()) - } else { - values[0] - .logical_nulls() - .map_or(0, |nulls| nulls.null_count()) - } -} - -impl AggregateExpr for Count { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Int64, self.nullable)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "count"), - DataType::Int64, - true, - )]) - } - - fn expressions(&self) -> Vec> { - self.exprs.clone() - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(CountAccumulator::new())) - } - - fn name(&self) -> &str { - &self.name - } - - fn groups_accumulator_supported(&self) -> bool { - // groups accumulator only supports `COUNT(c1)`, not - // `COUNT(c1, c2)`, etc - self.exprs.len() == 1 - } - - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone())) - } - - fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(CountAccumulator::new())) - } - - fn create_groups_accumulator(&self) -> Result> { - // instantiate specialized accumulator - Ok(Box::new(CountGroupsAccumulator::new())) - } -} - -impl PartialEq for Count { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.nullable == x.nullable - && self.exprs.len() == x.exprs.len() - && self - .exprs - .iter() - .zip(x.exprs.iter()) - .all(|(expr1, expr2)| expr1.eq(expr2)) - }) - .unwrap_or(false) - } -} - -#[derive(Debug)] -struct CountAccumulator { - count: i64, -} - -impl CountAccumulator { - /// new count accumulator - pub fn new() -> Self { - Self { count: 0 } - } -} - -impl Accumulator for CountAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![ScalarValue::Int64(Some(self.count))]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let array = &values[0]; - self.count += (array.len() - null_count_for_multiple_cols(values)) as i64; - Ok(()) - } - - fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let array = &values[0]; - self.count -= (array.len() - null_count_for_multiple_cols(values)) as i64; - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let counts = downcast_value!(states[0], Int64Array); - let delta = &compute::sum(counts); - if let Some(d) = delta { - self.count += *d; - } - Ok(()) - } - - fn evaluate(&mut self) -> Result { - Ok(ScalarValue::Int64(Some(self.count))) - } - - fn supports_retract_batch(&self) -> bool { - true - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::expressions::tests::aggregate; - use crate::expressions::{col, lit}; - use crate::generic_test_op; - use arrow::{array::*, datatypes::*}; - - #[test] - fn count_elements() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!(a, DataType::Int32, Count, ScalarValue::from(5i64)) - } - - #[test] - fn count_with_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - Some(2), - None, - None, - Some(3), - None, - ])); - generic_test_op!(a, DataType::Int32, Count, ScalarValue::from(3i64)) - } - - #[test] - fn count_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(BooleanArray::from(vec![ - None, None, None, None, None, None, None, None, - ])); - generic_test_op!(a, DataType::Boolean, Count, ScalarValue::from(0i64)) - } - - #[test] - fn count_empty() -> Result<()> { - let a: Vec = vec![]; - let a: ArrayRef = Arc::new(BooleanArray::from(a)); - generic_test_op!(a, DataType::Boolean, Count, ScalarValue::from(0i64)) - } - - #[test] - fn count_utf8() -> Result<()> { - let a: ArrayRef = - Arc::new(StringArray::from(vec!["a", "bb", "ccc", "dddd", "ad"])); - generic_test_op!(a, DataType::Utf8, Count, ScalarValue::from(5i64)) - } - - #[test] - fn count_large_utf8() -> Result<()> { - let a: ArrayRef = - Arc::new(LargeStringArray::from(vec!["a", "bb", "ccc", "dddd", "ad"])); - generic_test_op!(a, DataType::LargeUtf8, Count, ScalarValue::from(5i64)) - } - - #[test] - fn count_multi_cols() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - Some(2), - None, - None, - Some(3), - None, - ])); - let b: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - None, - Some(2), - None, - Some(3), - Some(4), - ])); - let schema = Schema::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - ]); - - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a, b])?; - - let agg = Arc::new(Count::new_with_multiple_exprs( - vec![col("a", &schema)?, col("b", &schema)?], - "bla".to_string(), - DataType::Int64, - )); - let actual = aggregate(&batch, agg)?; - let expected = ScalarValue::from(2i64); - - assert_eq!(expected, actual); - Ok(()) - } - - #[test] - fn count_eq() -> Result<()> { - let count = Count::new(lit(1i8), "COUNT(1)".to_string(), DataType::Int64); - let arc_count: Arc = Arc::new(Count::new( - lit(1i8), - "COUNT(1)".to_string(), - DataType::Int64, - )); - let box_count: Box = Box::new(Count::new( - lit(1i8), - "COUNT(1)".to_string(), - DataType::Int64, - )); - let count2 = Count::new(lit(1i8), "COUNT(2)".to_string(), DataType::Int64); - - assert!(arc_count.eq(&box_count)); - assert!(box_count.eq(&arc_count)); - assert!(arc_count.eq(&count)); - assert!(count.eq(&box_count)); - assert!(count.eq(&arc_count)); - - assert!(count2.ne(&arc_count)); - - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs deleted file mode 100644 index 52f1c5c0f9a0..000000000000 --- a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs +++ /dev/null @@ -1,718 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -mod bytes; -mod native; - -use std::any::Any; -use std::collections::HashSet; -use std::fmt::Debug; -use std::sync::Arc; - -use ahash::RandomState; -use arrow::array::{Array, ArrayRef}; -use arrow::datatypes::{DataType, Field, TimeUnit}; -use arrow_array::cast::AsArray; -use arrow_array::types::{ - Date32Type, Date64Type, Decimal128Type, Decimal256Type, Float16Type, Float32Type, - Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, Time32MillisecondType, - Time32SecondType, Time64MicrosecondType, Time64NanosecondType, - TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, - TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, -}; - -use datafusion_common::{internal_err, Result, ScalarValue}; -use datafusion_expr::Accumulator; - -use crate::aggregate::count_distinct::bytes::BytesDistinctCountAccumulator; -use crate::aggregate::count_distinct::native::{ - FloatDistinctCountAccumulator, PrimitiveDistinctCountAccumulator, -}; -use crate::aggregate::utils::down_cast_any_ref; -use crate::binary_map::OutputType; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; - -/// Expression for a `COUNT(DISTINCT)` aggregation. -#[derive(Debug)] -pub struct DistinctCount { - /// Column name - name: String, - /// The DataType used to hold the state for each input - state_data_type: DataType, - /// The input arguments - expr: Arc, -} - -impl DistinctCount { - /// Create a new COUNT(DISTINCT) aggregate function. - pub fn new( - input_data_type: DataType, - expr: Arc, - name: impl Into, - ) -> Self { - Self { - name: name.into(), - state_data_type: input_data_type, - expr, - } - } -} - -impl AggregateExpr for DistinctCount { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Int64, true)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new_list( - format_state_name(&self.name, "count distinct"), - Field::new("item", self.state_data_type.clone(), true), - false, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn create_accumulator(&self) -> Result> { - use DataType::*; - use TimeUnit::*; - - let data_type = &self.state_data_type; - Ok(match data_type { - // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator - Int8 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Int16 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Int32 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Int64 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - UInt8 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - UInt16 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - UInt32 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - UInt64 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Decimal128(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< - Decimal128Type, - >::new(data_type)), - Decimal256(_, _) => Box::new(PrimitiveDistinctCountAccumulator::< - Decimal256Type, - >::new(data_type)), - - Date32 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Date64 => Box::new(PrimitiveDistinctCountAccumulator::::new( - data_type, - )), - Time32(Millisecond) => Box::new(PrimitiveDistinctCountAccumulator::< - Time32MillisecondType, - >::new(data_type)), - Time32(Second) => Box::new(PrimitiveDistinctCountAccumulator::< - Time32SecondType, - >::new(data_type)), - Time64(Microsecond) => Box::new(PrimitiveDistinctCountAccumulator::< - Time64MicrosecondType, - >::new(data_type)), - Time64(Nanosecond) => Box::new(PrimitiveDistinctCountAccumulator::< - Time64NanosecondType, - >::new(data_type)), - Timestamp(Microsecond, _) => Box::new(PrimitiveDistinctCountAccumulator::< - TimestampMicrosecondType, - >::new(data_type)), - Timestamp(Millisecond, _) => Box::new(PrimitiveDistinctCountAccumulator::< - TimestampMillisecondType, - >::new(data_type)), - Timestamp(Nanosecond, _) => Box::new(PrimitiveDistinctCountAccumulator::< - TimestampNanosecondType, - >::new(data_type)), - Timestamp(Second, _) => Box::new(PrimitiveDistinctCountAccumulator::< - TimestampSecondType, - >::new(data_type)), - - Float16 => Box::new(FloatDistinctCountAccumulator::::new()), - Float32 => Box::new(FloatDistinctCountAccumulator::::new()), - Float64 => Box::new(FloatDistinctCountAccumulator::::new()), - - Utf8 => Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)), - LargeUtf8 => { - Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)) - } - Binary => Box::new(BytesDistinctCountAccumulator::::new( - OutputType::Binary, - )), - LargeBinary => Box::new(BytesDistinctCountAccumulator::::new( - OutputType::Binary, - )), - - // Use the generic accumulator based on `ScalarValue` for all other types - _ => Box::new(DistinctCountAccumulator { - values: HashSet::default(), - state_data_type: self.state_data_type.clone(), - }), - }) - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for DistinctCount { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.state_data_type == x.state_data_type - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -/// General purpose distinct accumulator that works for any DataType by using -/// [`ScalarValue`]. -/// -/// It stores intermediate results as a `ListArray` -/// -/// Note that many types have specialized accumulators that are (much) -/// more efficient such as [`PrimitiveDistinctCountAccumulator`] and -/// [`BytesDistinctCountAccumulator`] -#[derive(Debug)] -struct DistinctCountAccumulator { - values: HashSet, - state_data_type: DataType, -} - -impl DistinctCountAccumulator { - // calculating the size for fixed length values, taking first batch size * - // number of batches This method is faster than .full_size(), however it is - // not suitable for variable length values like strings or complex types - fn fixed_size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.values.capacity()) - + self - .values - .iter() - .next() - .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals)) - .unwrap_or(0) - + std::mem::size_of::() - } - - // calculates the size as accurately as possible. Note that calling this - // method is expensive - fn full_size(&self) -> usize { - std::mem::size_of_val(self) - + (std::mem::size_of::() * self.values.capacity()) - + self - .values - .iter() - .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals)) - .sum::() - + std::mem::size_of::() - } -} - -impl Accumulator for DistinctCountAccumulator { - /// Returns the distinct values seen so far as (one element) ListArray. - fn state(&mut self) -> Result> { - let scalars = self.values.iter().cloned().collect::>(); - let arr = ScalarValue::new_list(scalars.as_slice(), &self.state_data_type); - Ok(vec![ScalarValue::List(arr)]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - - let arr = &values[0]; - if arr.data_type() == &DataType::Null { - return Ok(()); - } - - (0..arr.len()).try_for_each(|index| { - if !arr.is_null(index) { - let scalar = ScalarValue::try_from_array(arr, index)?; - self.values.insert(scalar); - } - Ok(()) - }) - } - - /// Merges multiple sets of distinct values into the current set. - /// - /// The input to this function is a `ListArray` with **multiple** rows, - /// where each row contains the values from a partial aggregate's phase (e.g. - /// the result of calling `Self::state` on multiple accumulators). - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - assert_eq!(states.len(), 1, "array_agg states must be singleton!"); - let array = &states[0]; - let list_array = array.as_list::(); - for inner_array in list_array.iter() { - let Some(inner_array) = inner_array else { - return internal_err!( - "Intermediate results of COUNT DISTINCT should always be non null" - ); - }; - self.update_batch(&[inner_array])?; - } - Ok(()) - } - - fn evaluate(&mut self) -> Result { - Ok(ScalarValue::Int64(Some(self.values.len() as i64))) - } - - fn size(&self) -> usize { - match &self.state_data_type { - DataType::Boolean | DataType::Null => self.fixed_size(), - d if d.is_primitive() => self.fixed_size(), - _ => self.full_size(), - } - } -} - -#[cfg(test)] -mod tests { - use arrow::array::{ - BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, - Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, - }; - use arrow_array::Decimal256Array; - use arrow_buffer::i256; - - use datafusion_common::cast::{as_boolean_array, as_list_array, as_primitive_array}; - use datafusion_common::internal_err; - use datafusion_common::DataFusionError; - - use crate::expressions::NoOp; - - use super::*; - - macro_rules! state_to_vec_primitive { - ($LIST:expr, $DATA_TYPE:ident) => {{ - let arr = ScalarValue::raw_data($LIST).unwrap(); - let list_arr = as_list_array(&arr).unwrap(); - let arr = list_arr.values(); - let arr = as_primitive_array::<$DATA_TYPE>(arr)?; - arr.values().iter().cloned().collect::>() - }}; - } - - macro_rules! test_count_distinct_update_batch_numeric { - ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{ - let values: Vec> = vec![ - Some(1), - Some(1), - None, - Some(3), - Some(2), - None, - Some(2), - Some(3), - Some(1), - ]; - - let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef]; - - let (states, result) = run_update_batch(&arrays)?; - - let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE); - state_vec.sort(); - - assert_eq!(states.len(), 1); - assert_eq!(state_vec, vec![1, 2, 3]); - assert_eq!(result, ScalarValue::Int64(Some(3))); - - Ok(()) - }}; - } - - fn state_to_vec_bool(sv: &ScalarValue) -> Result> { - let arr = ScalarValue::raw_data(sv)?; - let list_arr = as_list_array(&arr)?; - let arr = list_arr.values(); - let bool_arr = as_boolean_array(arr)?; - Ok(bool_arr.iter().flatten().collect()) - } - - fn run_update_batch(arrays: &[ArrayRef]) -> Result<(Vec, ScalarValue)> { - let agg = DistinctCount::new( - arrays[0].data_type().clone(), - Arc::new(NoOp::new()), - String::from("__col_name__"), - ); - - let mut accum = agg.create_accumulator()?; - accum.update_batch(arrays)?; - - Ok((accum.state()?, accum.evaluate()?)) - } - - fn run_update( - data_types: &[DataType], - rows: &[Vec], - ) -> Result<(Vec, ScalarValue)> { - let agg = DistinctCount::new( - data_types[0].clone(), - Arc::new(NoOp::new()), - String::from("__col_name__"), - ); - - let mut accum = agg.create_accumulator()?; - - let cols = (0..rows[0].len()) - .map(|i| { - rows.iter() - .map(|inner| inner[i].clone()) - .collect::>() - }) - .collect::>(); - - let arrays: Vec = cols - .iter() - .map(|c| ScalarValue::iter_to_array(c.clone())) - .collect::>>()?; - - accum.update_batch(&arrays)?; - - Ok((accum.state()?, accum.evaluate()?)) - } - - // Used trait to create associated constant for f32 and f64 - trait SubNormal: 'static { - const SUBNORMAL: Self; - } - - impl SubNormal for f64 { - const SUBNORMAL: Self = 1.0e-308_f64; - } - - impl SubNormal for f32 { - const SUBNORMAL: Self = 1.0e-38_f32; - } - - macro_rules! test_count_distinct_update_batch_floating_point { - ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{ - let values: Vec> = vec![ - Some(<$PRIM_TYPE>::INFINITY), - Some(<$PRIM_TYPE>::NAN), - Some(1.0), - Some(<$PRIM_TYPE as SubNormal>::SUBNORMAL), - Some(1.0), - Some(<$PRIM_TYPE>::INFINITY), - None, - Some(3.0), - Some(-4.5), - Some(2.0), - None, - Some(2.0), - Some(3.0), - Some(<$PRIM_TYPE>::NEG_INFINITY), - Some(1.0), - Some(<$PRIM_TYPE>::NAN), - Some(<$PRIM_TYPE>::NEG_INFINITY), - ]; - - let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef]; - - let (states, result) = run_update_batch(&arrays)?; - - let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE); - - dbg!(&state_vec); - state_vec.sort_by(|a, b| match (a, b) { - (lhs, rhs) => lhs.total_cmp(rhs), - }); - - let nan_idx = state_vec.len() - 1; - assert_eq!(states.len(), 1); - assert_eq!( - &state_vec[..nan_idx], - vec![ - <$PRIM_TYPE>::NEG_INFINITY, - -4.5, - <$PRIM_TYPE as SubNormal>::SUBNORMAL, - 1.0, - 2.0, - 3.0, - <$PRIM_TYPE>::INFINITY - ] - ); - assert!(state_vec[nan_idx].is_nan()); - assert_eq!(result, ScalarValue::Int64(Some(8))); - - Ok(()) - }}; - } - - macro_rules! test_count_distinct_update_batch_bigint { - ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{ - let values: Vec> = vec![ - Some(i256::from(1)), - Some(i256::from(1)), - None, - Some(i256::from(3)), - Some(i256::from(2)), - None, - Some(i256::from(2)), - Some(i256::from(3)), - Some(i256::from(1)), - ]; - - let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef]; - - let (states, result) = run_update_batch(&arrays)?; - - let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE); - state_vec.sort(); - - assert_eq!(states.len(), 1); - assert_eq!(state_vec, vec![i256::from(1), i256::from(2), i256::from(3)]); - assert_eq!(result, ScalarValue::Int64(Some(3))); - - Ok(()) - }}; - } - - #[test] - fn count_distinct_update_batch_i8() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int8Array, Int8Type, i8) - } - - #[test] - fn count_distinct_update_batch_i16() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int16Array, Int16Type, i16) - } - - #[test] - fn count_distinct_update_batch_i32() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int32Array, Int32Type, i32) - } - - #[test] - fn count_distinct_update_batch_i64() -> Result<()> { - test_count_distinct_update_batch_numeric!(Int64Array, Int64Type, i64) - } - - #[test] - fn count_distinct_update_batch_u8() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt8Array, UInt8Type, u8) - } - - #[test] - fn count_distinct_update_batch_u16() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt16Array, UInt16Type, u16) - } - - #[test] - fn count_distinct_update_batch_u32() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt32Array, UInt32Type, u32) - } - - #[test] - fn count_distinct_update_batch_u64() -> Result<()> { - test_count_distinct_update_batch_numeric!(UInt64Array, UInt64Type, u64) - } - - #[test] - fn count_distinct_update_batch_f32() -> Result<()> { - test_count_distinct_update_batch_floating_point!(Float32Array, Float32Type, f32) - } - - #[test] - fn count_distinct_update_batch_f64() -> Result<()> { - test_count_distinct_update_batch_floating_point!(Float64Array, Float64Type, f64) - } - - #[test] - fn count_distinct_update_batch_i256() -> Result<()> { - test_count_distinct_update_batch_bigint!(Decimal256Array, Decimal256Type, i256) - } - - #[test] - fn count_distinct_update_batch_boolean() -> Result<()> { - let get_count = |data: BooleanArray| -> Result<(Vec, i64)> { - let arrays = vec![Arc::new(data) as ArrayRef]; - let (states, result) = run_update_batch(&arrays)?; - let mut state_vec = state_to_vec_bool(&states[0])?; - state_vec.sort(); - - let count = match result { - ScalarValue::Int64(c) => c.ok_or_else(|| { - DataFusionError::Internal("Found None count".to_string()) - }), - scalar => { - internal_err!("Found non int64 scalar value from count: {scalar}") - } - }?; - Ok((state_vec, count)) - }; - - let zero_count_values = BooleanArray::from(Vec::::new()); - - let one_count_values = BooleanArray::from(vec![false, false]); - let one_count_values_with_null = - BooleanArray::from(vec![Some(true), Some(true), None, None]); - - let two_count_values = BooleanArray::from(vec![true, false, true, false, true]); - let two_count_values_with_null = BooleanArray::from(vec![ - Some(true), - Some(false), - None, - None, - Some(true), - Some(false), - ]); - - assert_eq!(get_count(zero_count_values)?, (Vec::::new(), 0)); - assert_eq!(get_count(one_count_values)?, (vec![false], 1)); - assert_eq!(get_count(one_count_values_with_null)?, (vec![true], 1)); - assert_eq!(get_count(two_count_values)?, (vec![false, true], 2)); - assert_eq!( - get_count(two_count_values_with_null)?, - (vec![false, true], 2) - ); - Ok(()) - } - - #[test] - fn count_distinct_update_batch_all_nulls() -> Result<()> { - let arrays = vec![Arc::new(Int32Array::from( - vec![None, None, None, None] as Vec> - )) as ArrayRef]; - - let (states, result) = run_update_batch(&arrays)?; - let state_vec = state_to_vec_primitive!(&states[0], Int32Type); - assert_eq!(states.len(), 1); - assert!(state_vec.is_empty()); - assert_eq!(result, ScalarValue::Int64(Some(0))); - - Ok(()) - } - - #[test] - fn count_distinct_update_batch_empty() -> Result<()> { - let arrays = vec![Arc::new(Int32Array::from(vec![0_i32; 0])) as ArrayRef]; - - let (states, result) = run_update_batch(&arrays)?; - let state_vec = state_to_vec_primitive!(&states[0], Int32Type); - assert_eq!(states.len(), 1); - assert!(state_vec.is_empty()); - assert_eq!(result, ScalarValue::Int64(Some(0))); - - Ok(()) - } - - #[test] - fn count_distinct_update() -> Result<()> { - let (states, result) = run_update( - &[DataType::Int32], - &[ - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(5))], - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(5))], - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(2))], - ], - )?; - assert_eq!(states.len(), 1); - assert_eq!(result, ScalarValue::Int64(Some(3))); - - let (states, result) = run_update( - &[DataType::UInt64], - &[ - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(5))], - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(5))], - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(2))], - ], - )?; - assert_eq!(states.len(), 1); - assert_eq!(result, ScalarValue::Int64(Some(3))); - Ok(()) - } - - #[test] - fn count_distinct_update_with_nulls() -> Result<()> { - let (states, result) = run_update( - &[DataType::Int32], - &[ - // None of these updates contains a None, so these are accumulated. - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(Some(-2))], - // Each of these updates contains at least one None, so these - // won't be accumulated. - vec![ScalarValue::Int32(Some(-1))], - vec![ScalarValue::Int32(None)], - vec![ScalarValue::Int32(None)], - ], - )?; - assert_eq!(states.len(), 1); - assert_eq!(result, ScalarValue::Int64(Some(2))); - - let (states, result) = run_update( - &[DataType::UInt64], - &[ - // None of these updates contains a None, so these are accumulated. - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(Some(2))], - // Each of these updates contains at least one None, so these - // won't be accumulated. - vec![ScalarValue::UInt64(Some(1))], - vec![ScalarValue::UInt64(None)], - vec![ScalarValue::UInt64(None)], - ], - )?; - assert_eq!(states.len(), 1); - assert_eq!(result, ScalarValue::Int64(Some(2))); - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/aggregate/covariance.rs b/datafusion/physical-expr/src/aggregate/covariance.rs deleted file mode 100644 index ba9bdbc8aee3..000000000000 --- a/datafusion/physical-expr/src/aggregate/covariance.rs +++ /dev/null @@ -1,773 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expressions that can evaluated at runtime during query execution - -use std::any::Any; -use std::sync::Arc; - -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::array::Float64Array; -use arrow::{ - array::{ArrayRef, UInt64Array}, - compute::cast, - datatypes::DataType, - datatypes::Field, -}; -use datafusion_common::{downcast_value, unwrap_or_internal_err, ScalarValue}; -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::Accumulator; - -use crate::aggregate::stats::StatsType; -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; - -/// COVAR and COVAR_SAMP aggregate expression -#[derive(Debug)] -pub struct Covariance { - name: String, - expr1: Arc, - expr2: Arc, -} - -/// COVAR_POP aggregate expression -#[derive(Debug)] -pub struct CovariancePop { - name: String, - expr1: Arc, - expr2: Arc, -} - -impl Covariance { - /// Create a new COVAR aggregate function - pub fn new( - expr1: Arc, - expr2: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - // the result of covariance just support FLOAT64 data type. - assert!(matches!(data_type, DataType::Float64)); - Self { - name: name.into(), - expr1, - expr2, - } - } -} - -impl AggregateExpr for Covariance { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64, true)) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(CovarianceAccumulator::try_new(StatsType::Sample)?)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![ - Field::new( - format_state_name(&self.name, "count"), - DataType::UInt64, - true, - ), - Field::new( - format_state_name(&self.name, "mean1"), - DataType::Float64, - true, - ), - Field::new( - format_state_name(&self.name, "mean2"), - DataType::Float64, - true, - ), - Field::new( - format_state_name(&self.name, "algo_const"), - DataType::Float64, - true, - ), - ]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr1.clone(), self.expr2.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for Covariance { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name && self.expr1.eq(&x.expr1) && self.expr2.eq(&x.expr2) - }) - .unwrap_or(false) - } -} - -impl CovariancePop { - /// Create a new COVAR_POP aggregate function - pub fn new( - expr1: Arc, - expr2: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - // the result of covariance just support FLOAT64 data type. - assert!(matches!(data_type, DataType::Float64)); - Self { - name: name.into(), - expr1, - expr2, - } - } -} - -impl AggregateExpr for CovariancePop { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64, true)) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(CovarianceAccumulator::try_new( - StatsType::Population, - )?)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![ - Field::new( - format_state_name(&self.name, "count"), - DataType::UInt64, - true, - ), - Field::new( - format_state_name(&self.name, "mean1"), - DataType::Float64, - true, - ), - Field::new( - format_state_name(&self.name, "mean2"), - DataType::Float64, - true, - ), - Field::new( - format_state_name(&self.name, "algo_const"), - DataType::Float64, - true, - ), - ]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr1.clone(), self.expr2.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for CovariancePop { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name && self.expr1.eq(&x.expr1) && self.expr2.eq(&x.expr2) - }) - .unwrap_or(false) - } -} - -/// An accumulator to compute covariance -/// The algrithm used is an online implementation and numerically stable. It is derived from the following paper -/// for calculating variance: -/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products". -/// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577. -/// -/// The algorithm has been analyzed here: -/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances". -/// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154. -/// -/// Though it is not covered in the original paper but is based on the same idea, as a result the algorithm is online, -/// parallelizable and numerically stable. - -#[derive(Debug)] -pub struct CovarianceAccumulator { - algo_const: f64, - mean1: f64, - mean2: f64, - count: u64, - stats_type: StatsType, -} - -impl CovarianceAccumulator { - /// Creates a new `CovarianceAccumulator` - pub fn try_new(s_type: StatsType) -> Result { - Ok(Self { - algo_const: 0_f64, - mean1: 0_f64, - mean2: 0_f64, - count: 0_u64, - stats_type: s_type, - }) - } - - pub fn get_count(&self) -> u64 { - self.count - } - - pub fn get_mean1(&self) -> f64 { - self.mean1 - } - - pub fn get_mean2(&self) -> f64 { - self.mean2 - } - - pub fn get_algo_const(&self) -> f64 { - self.algo_const - } -} - -impl Accumulator for CovarianceAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![ - ScalarValue::from(self.count), - ScalarValue::from(self.mean1), - ScalarValue::from(self.mean2), - ScalarValue::from(self.algo_const), - ]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values1 = &cast(&values[0], &DataType::Float64)?; - let values2 = &cast(&values[1], &DataType::Float64)?; - - let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten(); - let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten(); - - for i in 0..values1.len() { - let value1 = if values1.is_valid(i) { - arr1.next() - } else { - None - }; - let value2 = if values2.is_valid(i) { - arr2.next() - } else { - None - }; - - if value1.is_none() || value2.is_none() { - continue; - } - - let value1 = unwrap_or_internal_err!(value1); - let value2 = unwrap_or_internal_err!(value2); - let new_count = self.count + 1; - let delta1 = value1 - self.mean1; - let new_mean1 = delta1 / new_count as f64 + self.mean1; - let delta2 = value2 - self.mean2; - let new_mean2 = delta2 / new_count as f64 + self.mean2; - let new_c = delta1 * (value2 - new_mean2) + self.algo_const; - - self.count += 1; - self.mean1 = new_mean1; - self.mean2 = new_mean2; - self.algo_const = new_c; - } - - Ok(()) - } - - fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values1 = &cast(&values[0], &DataType::Float64)?; - let values2 = &cast(&values[1], &DataType::Float64)?; - let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten(); - let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten(); - - for i in 0..values1.len() { - let value1 = if values1.is_valid(i) { - arr1.next() - } else { - None - }; - let value2 = if values2.is_valid(i) { - arr2.next() - } else { - None - }; - - if value1.is_none() || value2.is_none() { - continue; - } - - let value1 = unwrap_or_internal_err!(value1); - let value2 = unwrap_or_internal_err!(value2); - - let new_count = self.count - 1; - let delta1 = self.mean1 - value1; - let new_mean1 = delta1 / new_count as f64 + self.mean1; - let delta2 = self.mean2 - value2; - let new_mean2 = delta2 / new_count as f64 + self.mean2; - let new_c = self.algo_const - delta1 * (new_mean2 - value2); - - self.count -= 1; - self.mean1 = new_mean1; - self.mean2 = new_mean2; - self.algo_const = new_c; - } - - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let counts = downcast_value!(states[0], UInt64Array); - let means1 = downcast_value!(states[1], Float64Array); - let means2 = downcast_value!(states[2], Float64Array); - let cs = downcast_value!(states[3], Float64Array); - - for i in 0..counts.len() { - let c = counts.value(i); - if c == 0_u64 { - continue; - } - let new_count = self.count + c; - let new_mean1 = self.mean1 * self.count as f64 / new_count as f64 - + means1.value(i) * c as f64 / new_count as f64; - let new_mean2 = self.mean2 * self.count as f64 / new_count as f64 - + means2.value(i) * c as f64 / new_count as f64; - let delta1 = self.mean1 - means1.value(i); - let delta2 = self.mean2 - means2.value(i); - let new_c = self.algo_const - + cs.value(i) - + delta1 * delta2 * self.count as f64 * c as f64 / new_count as f64; - - self.count = new_count; - self.mean1 = new_mean1; - self.mean2 = new_mean2; - self.algo_const = new_c; - } - Ok(()) - } - - fn evaluate(&mut self) -> Result { - let count = match self.stats_type { - StatsType::Population => self.count, - StatsType::Sample => { - if self.count > 0 { - self.count - 1 - } else { - self.count - } - } - }; - - if count == 0 { - Ok(ScalarValue::Float64(None)) - } else { - Ok(ScalarValue::Float64(Some(self.algo_const / count as f64))) - } - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::aggregate::utils::get_accum_scalar_values_as_arrays; - use crate::expressions::col; - use crate::expressions::tests::aggregate; - use crate::generic_test_op2; - use arrow::{array::*, datatypes::*}; - - #[test] - fn covariance_f64_1() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64])); - let b: ArrayRef = Arc::new(Float64Array::from(vec![4_f64, 5_f64, 6_f64])); - - generic_test_op2!( - a, - b, - DataType::Float64, - DataType::Float64, - CovariancePop, - ScalarValue::from(0.6666666666666666_f64) - ) - } - - #[test] - fn covariance_f64_2() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64])); - let b: ArrayRef = Arc::new(Float64Array::from(vec![4_f64, 5_f64, 6_f64])); - - generic_test_op2!( - a, - b, - DataType::Float64, - DataType::Float64, - Covariance, - ScalarValue::from(1_f64) - ) - } - - #[test] - fn covariance_f64_4() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64])); - let b: ArrayRef = Arc::new(Float64Array::from(vec![4.1_f64, 5_f64, 6_f64])); - - generic_test_op2!( - a, - b, - DataType::Float64, - DataType::Float64, - Covariance, - ScalarValue::from(0.9033333333333335_f64) - ) - } - - #[test] - fn covariance_f64_5() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64])); - let b: ArrayRef = Arc::new(Float64Array::from(vec![4.1_f64, 5_f64, 6_f64])); - - generic_test_op2!( - a, - b, - DataType::Float64, - DataType::Float64, - CovariancePop, - ScalarValue::from(0.6022222222222223_f64) - ) - } - - #[test] - fn covariance_f64_6() -> Result<()> { - let a = Arc::new(Float64Array::from(vec![ - 1_f64, 2_f64, 3_f64, 1.1_f64, 2.2_f64, 3.3_f64, - ])); - let b = Arc::new(Float64Array::from(vec![ - 4_f64, 5_f64, 6_f64, 4.4_f64, 5.5_f64, 6.6_f64, - ])); - - generic_test_op2!( - a, - b, - DataType::Float64, - DataType::Float64, - CovariancePop, - ScalarValue::from(0.7616666666666666_f64) - ) - } - - #[test] - fn covariance_i32() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); - let b: ArrayRef = Arc::new(Int32Array::from(vec![4, 5, 6])); - - generic_test_op2!( - a, - b, - DataType::Int32, - DataType::Int32, - CovariancePop, - ScalarValue::from(0.6666666666666666_f64) - ) - } - - #[test] - fn covariance_u32() -> Result<()> { - let a: ArrayRef = Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32])); - let b: ArrayRef = Arc::new(UInt32Array::from(vec![4_u32, 5_u32, 6_u32])); - generic_test_op2!( - a, - b, - DataType::UInt32, - DataType::UInt32, - CovariancePop, - ScalarValue::from(0.6666666666666666_f64) - ) - } - - #[test] - fn covariance_f32() -> Result<()> { - let a: ArrayRef = Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32])); - let b: ArrayRef = Arc::new(Float32Array::from(vec![4_f32, 5_f32, 6_f32])); - generic_test_op2!( - a, - b, - DataType::Float32, - DataType::Float32, - CovariancePop, - ScalarValue::from(0.6666666666666666_f64) - ) - } - - #[test] - fn covariance_i32_with_nulls_1() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])); - let b: ArrayRef = Arc::new(Int32Array::from(vec![Some(4), None, Some(6)])); - - generic_test_op2!( - a, - b, - DataType::Int32, - DataType::Int32, - CovariancePop, - ScalarValue::from(1_f64) - ) - } - - #[test] - fn covariance_i32_with_nulls_2() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - None, - Some(2), - None, - Some(3), - None, - ])); - let b: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(4), - Some(9), - Some(5), - Some(8), - Some(6), - None, - ])); - - generic_test_op2!( - a, - b, - DataType::Int32, - DataType::Int32, - CovariancePop, - ScalarValue::from(0.6666666666666666_f64) - ) - } - - #[test] - fn covariance_i32_with_nulls_3() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - None, - Some(2), - None, - Some(3), - None, - ])); - let b: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(4), - Some(9), - Some(5), - Some(8), - Some(6), - None, - ])); - - generic_test_op2!( - a, - b, - DataType::Int32, - DataType::Int32, - Covariance, - ScalarValue::from(1_f64) - ) - } - - #[test] - fn covariance_i32_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - let b: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - - generic_test_op2!( - a, - b, - DataType::Int32, - DataType::Int32, - Covariance, - ScalarValue::Float64(None) - ) - } - - #[test] - fn covariance_pop_i32_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - let b: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - - generic_test_op2!( - a, - b, - DataType::Int32, - DataType::Int32, - CovariancePop, - ScalarValue::Float64(None) - ) - } - - #[test] - fn covariance_1_input() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64])); - let b: ArrayRef = Arc::new(Float64Array::from(vec![2_f64])); - - generic_test_op2!( - a, - b, - DataType::Float64, - DataType::Float64, - Covariance, - ScalarValue::Float64(None) - ) - } - - #[test] - fn covariance_pop_1_input() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64])); - let b: ArrayRef = Arc::new(Float64Array::from(vec![2_f64])); - - generic_test_op2!( - a, - b, - DataType::Float64, - DataType::Float64, - CovariancePop, - ScalarValue::from(0_f64) - ) - } - - #[test] - fn covariance_f64_merge_1() -> Result<()> { - let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64])); - let b = Arc::new(Float64Array::from(vec![4_f64, 5_f64, 6_f64])); - let c = Arc::new(Float64Array::from(vec![1.1_f64, 2.2_f64, 3.3_f64])); - let d = Arc::new(Float64Array::from(vec![4.4_f64, 5.5_f64, 6.6_f64])); - - let schema = Schema::new(vec![ - Field::new("a", DataType::Float64, true), - Field::new("b", DataType::Float64, true), - ]); - - let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a, b])?; - let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![c, d])?; - - let agg1 = Arc::new(CovariancePop::new( - col("a", &schema)?, - col("b", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - - let agg2 = Arc::new(CovariancePop::new( - col("a", &schema)?, - col("b", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - - let actual = merge(&batch1, &batch2, agg1, agg2)?; - assert!(actual == ScalarValue::from(0.7616666666666666)); - - Ok(()) - } - - #[test] - fn covariance_f64_merge_2() -> Result<()> { - let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64])); - let b = Arc::new(Float64Array::from(vec![4_f64, 5_f64, 6_f64])); - let c = Arc::new(Float64Array::from(vec![None])); - let d = Arc::new(Float64Array::from(vec![None])); - - let schema = Schema::new(vec![ - Field::new("a", DataType::Float64, true), - Field::new("b", DataType::Float64, true), - ]); - - let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a, b])?; - let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![c, d])?; - - let agg1 = Arc::new(CovariancePop::new( - col("a", &schema)?, - col("b", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - - let agg2 = Arc::new(CovariancePop::new( - col("a", &schema)?, - col("b", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - - let actual = merge(&batch1, &batch2, agg1, agg2)?; - assert!(actual == ScalarValue::from(0.6666666666666666)); - - Ok(()) - } - - fn merge( - batch1: &RecordBatch, - batch2: &RecordBatch, - agg1: Arc, - agg2: Arc, - ) -> Result { - let mut accum1 = agg1.create_accumulator()?; - let mut accum2 = agg2.create_accumulator()?; - let expr1 = agg1.expressions(); - let expr2 = agg2.expressions(); - - let values1 = expr1 - .iter() - .map(|e| { - e.evaluate(batch1) - .and_then(|v| v.into_array(batch1.num_rows())) - }) - .collect::>>()?; - let values2 = expr2 - .iter() - .map(|e| { - e.evaluate(batch2) - .and_then(|v| v.into_array(batch2.num_rows())) - }) - .collect::>>()?; - accum1.update_batch(&values1)?; - accum2.update_batch(&values2)?; - let state2 = get_accum_scalar_values_as_arrays(accum2.as_mut())?; - accum1.merge_batch(&state2)?; - accum1.evaluate() - } -} diff --git a/datafusion/physical-expr/src/aggregate/grouping.rs b/datafusion/physical-expr/src/aggregate/grouping.rs deleted file mode 100644 index d43bcd5c7091..000000000000 --- a/datafusion/physical-expr/src/aggregate/grouping.rs +++ /dev/null @@ -1,103 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expressions that can evaluated at runtime during query execution - -use std::any::Any; -use std::sync::Arc; - -use crate::aggregate::utils::down_cast_any_ref; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::datatypes::DataType; -use arrow::datatypes::Field; -use datafusion_common::{not_impl_err, Result}; -use datafusion_expr::Accumulator; - -use crate::expressions::format_state_name; - -/// GROUPING aggregate expression -/// Returns the amount of non-null values of the given expression. -#[derive(Debug)] -pub struct Grouping { - name: String, - data_type: DataType, - nullable: bool, - expr: Arc, -} - -impl Grouping { - /// Create a new GROUPING aggregate function. - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - expr, - data_type, - nullable: true, - } - } -} - -impl AggregateExpr for Grouping { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Int32, self.nullable)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "grouping"), - DataType::Int32, - true, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn create_accumulator(&self) -> Result> { - not_impl_err!( - "physical plan is not yet implemented for GROUPING aggregate function" - ) - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for Grouping { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.nullable == x.nullable - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} diff --git a/datafusion/physical-expr/src/aggregate/median.rs b/datafusion/physical-expr/src/aggregate/median.rs deleted file mode 100644 index ed373ba13d5e..000000000000 --- a/datafusion/physical-expr/src/aggregate/median.rs +++ /dev/null @@ -1,332 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! # Median - -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::array::{Array, ArrayRef}; -use arrow::datatypes::{DataType, Field}; -use arrow_array::cast::AsArray; -use arrow_array::{downcast_integer, ArrowNativeTypeOp, ArrowNumericType}; -use arrow_buffer::ArrowNativeType; -use datafusion_common::{DataFusionError, Result, ScalarValue}; -use datafusion_expr::Accumulator; -use std::any::Any; -use std::fmt::Formatter; -use std::sync::Arc; - -/// MEDIAN aggregate expression. This uses a lot of memory because all values need to be -/// stored in memory before a result can be computed. If an approximation is sufficient -/// then APPROX_MEDIAN provides a much more efficient solution. -#[derive(Debug)] -pub struct Median { - name: String, - expr: Arc, - data_type: DataType, -} - -impl Median { - /// Create a new MEDIAN aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - expr, - data_type, - } - } -} - -impl AggregateExpr for Median { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.data_type.clone(), true)) - } - - fn create_accumulator(&self) -> Result> { - use arrow_array::types::*; - macro_rules! helper { - ($t:ty, $dt:expr) => { - Ok(Box::new(MedianAccumulator::<$t> { - data_type: $dt.clone(), - all_values: vec![], - })) - }; - } - let dt = &self.data_type; - downcast_integer! { - dt => (helper, dt), - DataType::Float16 => helper!(Float16Type, dt), - DataType::Float32 => helper!(Float32Type, dt), - DataType::Float64 => helper!(Float64Type, dt), - DataType::Decimal128(_, _) => helper!(Decimal128Type, dt), - DataType::Decimal256(_, _) => helper!(Decimal256Type, dt), - _ => Err(DataFusionError::NotImplemented(format!( - "MedianAccumulator not supported for {} with {}", - self.name(), - self.data_type - ))), - } - } - - fn state_fields(&self) -> Result> { - //Intermediate state is a list of the elements we have collected so far - let field = Field::new("item", self.data_type.clone(), true); - let data_type = DataType::List(Arc::new(field)); - - Ok(vec![Field::new( - format_state_name(&self.name, "median"), - data_type, - true, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for Median { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -/// The median accumulator accumulates the raw input values -/// as `ScalarValue`s -/// -/// The intermediate state is represented as a List of scalar values updated by -/// `merge_batch` and a `Vec` of `ArrayRef` that are converted to scalar values -/// in the final evaluation step so that we avoid expensive conversions and -/// allocations during `update_batch`. -struct MedianAccumulator { - data_type: DataType, - all_values: Vec, -} - -impl std::fmt::Debug for MedianAccumulator { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "MedianAccumulator({})", self.data_type) - } -} - -impl Accumulator for MedianAccumulator { - fn state(&mut self) -> Result> { - let all_values = self - .all_values - .iter() - .map(|x| ScalarValue::new_primitive::(Some(*x), &self.data_type)) - .collect::>>()?; - - let arr = ScalarValue::new_list(&all_values, &self.data_type); - Ok(vec![ScalarValue::List(arr)]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = values[0].as_primitive::(); - self.all_values.reserve(values.len() - values.null_count()); - self.all_values.extend(values.iter().flatten()); - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let array = states[0].as_list::(); - for v in array.iter().flatten() { - self.update_batch(&[v])? - } - Ok(()) - } - - fn evaluate(&mut self) -> Result { - let mut d = std::mem::take(&mut self.all_values); - let cmp = |x: &T::Native, y: &T::Native| x.compare(*y); - - let len = d.len(); - let median = if len == 0 { - None - } else if len % 2 == 0 { - let (low, high, _) = d.select_nth_unstable_by(len / 2, cmp); - let (_, low, _) = low.select_nth_unstable_by(low.len() - 1, cmp); - let median = low.add_wrapping(*high).div_wrapping(T::Native::usize_as(2)); - Some(median) - } else { - let (_, median, _) = d.select_nth_unstable_by(len / 2, cmp); - Some(*median) - }; - ScalarValue::new_primitive::(median, &self.data_type) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - + self.all_values.capacity() * std::mem::size_of::() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::expressions::col; - use crate::expressions::tests::aggregate; - use crate::generic_test_op; - use arrow::{array::*, datatypes::*}; - - #[test] - fn median_decimal() -> Result<()> { - // test median - let array: ArrayRef = Arc::new( - (1..7) - .map(Some) - .collect::() - .with_precision_and_scale(10, 4)?, - ); - - generic_test_op!( - array, - DataType::Decimal128(10, 4), - Median, - ScalarValue::Decimal128(Some(3), 10, 4) - ) - } - - #[test] - fn median_decimal_with_nulls() -> Result<()> { - let array: ArrayRef = Arc::new( - (1..6) - .map(|i| if i == 2 { None } else { Some(i) }) - .collect::() - .with_precision_and_scale(10, 4)?, - ); - generic_test_op!( - array, - DataType::Decimal128(10, 4), - Median, - ScalarValue::Decimal128(Some(3), 10, 4) - ) - } - - #[test] - fn median_decimal_all_nulls() -> Result<()> { - // test median - let array: ArrayRef = Arc::new( - std::iter::repeat::>(None) - .take(6) - .collect::() - .with_precision_and_scale(10, 4)?, - ); - generic_test_op!( - array, - DataType::Decimal128(10, 4), - Median, - ScalarValue::Decimal128(None, 10, 4) - ) - } - - #[test] - fn median_i32_odd() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!(a, DataType::Int32, Median, ScalarValue::from(3_i32)) - } - - #[test] - fn median_i32_even() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6])); - generic_test_op!(a, DataType::Int32, Median, ScalarValue::from(3_i32)) - } - - #[test] - fn median_i32_with_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - None, - Some(3), - Some(4), - Some(5), - ])); - generic_test_op!(a, DataType::Int32, Median, ScalarValue::from(3i32)) - } - - #[test] - fn median_i32_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - generic_test_op!(a, DataType::Int32, Median, ScalarValue::Int32(None)) - } - - #[test] - fn median_u32_odd() -> Result<()> { - let a: ArrayRef = - Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); - generic_test_op!(a, DataType::UInt32, Median, ScalarValue::from(3u32)) - } - - #[test] - fn median_u32_even() -> Result<()> { - let a: ArrayRef = Arc::new(UInt32Array::from(vec![ - 1_u32, 2_u32, 3_u32, 4_u32, 5_u32, 6_u32, - ])); - generic_test_op!(a, DataType::UInt32, Median, ScalarValue::from(3u32)) - } - - #[test] - fn median_f32_odd() -> Result<()> { - let a: ArrayRef = - Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); - generic_test_op!(a, DataType::Float32, Median, ScalarValue::from(3_f32)) - } - - #[test] - fn median_f32_even() -> Result<()> { - let a: ArrayRef = Arc::new(Float32Array::from(vec![ - 1_f32, 2_f32, 3_f32, 4_f32, 5_f32, 6_f32, - ])); - generic_test_op!(a, DataType::Float32, Median, ScalarValue::from(3.5_f32)) - } - - #[test] - fn median_f64_odd() -> Result<()> { - let a: ArrayRef = - Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - generic_test_op!(a, DataType::Float64, Median, ScalarValue::from(3_f64)) - } - - #[test] - fn median_f64_even() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![ - 1_f64, 2_f64, 3_f64, 4_f64, 5_f64, 6_f64, - ])); - generic_test_op!(a, DataType::Float64, Median, ScalarValue::from(3.5_f64)) - } -} diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs deleted file mode 100644 index 95ae3207462e..000000000000 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ /dev/null @@ -1,1611 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expressions that can evaluated at runtime during query execution - -use std::any::Any; -use std::sync::Arc; - -use crate::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::compute; -use arrow::datatypes::{ - DataType, Date32Type, Date64Type, Time32MillisecondType, Time32SecondType, - Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType, - TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, -}; -use arrow::{ - array::{ - ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array, Float32Array, - Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, LargeBinaryArray, - LargeStringArray, StringArray, Time32MillisecondArray, Time32SecondArray, - Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray, - TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, - UInt16Array, UInt32Array, UInt64Array, UInt8Array, - }, - datatypes::Field, -}; -use arrow_array::types::{ - Decimal128Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, - UInt16Type, UInt32Type, UInt64Type, UInt8Type, -}; -use datafusion_common::internal_err; -use datafusion_common::ScalarValue; -use datafusion_common::{downcast_value, DataFusionError, Result}; -use datafusion_expr::{Accumulator, GroupsAccumulator}; - -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use arrow::array::Array; -use arrow::array::Decimal128Array; -use arrow::array::Decimal256Array; -use arrow::datatypes::i256; -use arrow::datatypes::Decimal256Type; - -use super::moving_min_max; - -// Min/max aggregation can take Dictionary encode input but always produces unpacked -// (aka non Dictionary) output. We need to adjust the output data type to reflect this. -// The reason min/max aggregate produces unpacked output because there is only one -// min/max value per group; there is no needs to keep them Dictionary encode -fn min_max_aggregate_data_type(input_type: DataType) -> DataType { - if let DataType::Dictionary(_, value_type) = input_type { - *value_type - } else { - input_type - } -} - -/// MAX aggregate expression -#[derive(Debug, Clone)] -pub struct Max { - name: String, - data_type: DataType, - nullable: bool, - expr: Arc, -} - -impl Max { - /// Create a new MAX aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - expr, - data_type: min_max_aggregate_data_type(data_type), - nullable: true, - } - } -} -/// Creates a [`PrimitiveGroupsAccumulator`] for computing `MAX` -/// the specified [`ArrowPrimitiveType`]. -/// -/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType -macro_rules! instantiate_max_accumulator { - ($SELF:expr, $NATIVE:ident, $PRIMTYPE:ident) => {{ - Ok(Box::new( - PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new( - &$SELF.data_type, - |cur, new| { - if *cur < new { - *cur = new - } - }, - ) - // Initialize each accumulator to $NATIVE::MIN - .with_starting_value($NATIVE::MIN), - )) - }}; -} - -/// Creates a [`PrimitiveGroupsAccumulator`] for computing `MIN` -/// the specified [`ArrowPrimitiveType`]. -/// -/// -/// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType -macro_rules! instantiate_min_accumulator { - ($SELF:expr, $NATIVE:ident, $PRIMTYPE:ident) => {{ - Ok(Box::new( - PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new( - &$SELF.data_type, - |cur, new| { - if *cur > new { - *cur = new - } - }, - ) - // Initialize each accumulator to $NATIVE::MAX - .with_starting_value($NATIVE::MAX), - )) - }}; -} - -impl AggregateExpr for Max { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new( - &self.name, - self.data_type.clone(), - self.nullable, - )) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "max"), - self.data_type.clone(), - true, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(MaxAccumulator::try_new(&self.data_type)?)) - } - - fn name(&self) -> &str { - &self.name - } - - fn groups_accumulator_supported(&self) -> bool { - use DataType::*; - matches!( - self.data_type, - Int8 | Int16 - | Int32 - | Int64 - | UInt8 - | UInt16 - | UInt32 - | UInt64 - | Float32 - | Float64 - | Decimal128(_, _) - | Decimal256(_, _) - | Date32 - | Date64 - | Time32(_) - | Time64(_) - | Timestamp(_, _) - ) - } - - fn create_groups_accumulator(&self) -> Result> { - use DataType::*; - use TimeUnit::*; - - match self.data_type { - Int8 => instantiate_max_accumulator!(self, i8, Int8Type), - Int16 => instantiate_max_accumulator!(self, i16, Int16Type), - Int32 => instantiate_max_accumulator!(self, i32, Int32Type), - Int64 => instantiate_max_accumulator!(self, i64, Int64Type), - UInt8 => instantiate_max_accumulator!(self, u8, UInt8Type), - UInt16 => instantiate_max_accumulator!(self, u16, UInt16Type), - UInt32 => instantiate_max_accumulator!(self, u32, UInt32Type), - UInt64 => instantiate_max_accumulator!(self, u64, UInt64Type), - Float32 => { - instantiate_max_accumulator!(self, f32, Float32Type) - } - Float64 => { - instantiate_max_accumulator!(self, f64, Float64Type) - } - Date32 => instantiate_max_accumulator!(self, i32, Date32Type), - Date64 => instantiate_max_accumulator!(self, i64, Date64Type), - Time32(Second) => { - instantiate_max_accumulator!(self, i32, Time32SecondType) - } - Time32(Millisecond) => { - instantiate_max_accumulator!(self, i32, Time32MillisecondType) - } - Time64(Microsecond) => { - instantiate_max_accumulator!(self, i64, Time64MicrosecondType) - } - Time64(Nanosecond) => { - instantiate_max_accumulator!(self, i64, Time64NanosecondType) - } - Timestamp(Second, _) => { - instantiate_max_accumulator!(self, i64, TimestampSecondType) - } - Timestamp(Millisecond, _) => { - instantiate_max_accumulator!(self, i64, TimestampMillisecondType) - } - Timestamp(Microsecond, _) => { - instantiate_max_accumulator!(self, i64, TimestampMicrosecondType) - } - Timestamp(Nanosecond, _) => { - instantiate_max_accumulator!(self, i64, TimestampNanosecondType) - } - Decimal128(_, _) => { - instantiate_max_accumulator!(self, i128, Decimal128Type) - } - Decimal256(_, _) => { - instantiate_max_accumulator!(self, i256, Decimal256Type) - } - - // It would be nice to have a fast implementation for Strings as well - // https://github.com/apache/datafusion/issues/6906 - - // This is only reached if groups_accumulator_supported is out of sync - _ => internal_err!( - "GroupsAccumulator not supported for max({})", - self.data_type - ), - } - } - - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone())) - } - - fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(SlidingMaxAccumulator::try_new(&self.data_type)?)) - } -} - -impl PartialEq for Max { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.nullable == x.nullable - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -// Statically-typed version of min/max(array) -> ScalarValue for string types. -macro_rules! typed_min_max_batch_string { - ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ - let array = downcast_value!($VALUES, $ARRAYTYPE); - let value = compute::$OP(array); - let value = value.and_then(|e| Some(e.to_string())); - ScalarValue::$SCALAR(value) - }}; -} - -// Statically-typed version of min/max(array) -> ScalarValue for binay types. -macro_rules! typed_min_max_batch_binary { - ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ - let array = downcast_value!($VALUES, $ARRAYTYPE); - let value = compute::$OP(array); - let value = value.and_then(|e| Some(e.to_vec())); - ScalarValue::$SCALAR(value) - }}; -} - -// Statically-typed version of min/max(array) -> ScalarValue for non-string types. -macro_rules! typed_min_max_batch { - ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{ - let array = downcast_value!($VALUES, $ARRAYTYPE); - let value = compute::$OP(array); - ScalarValue::$SCALAR(value, $($EXTRA_ARGS.clone()),*) - }}; -} - -// Statically-typed version of min/max(array) -> ScalarValue for non-string types. -// this is a macro to support both operations (min and max). -macro_rules! min_max_batch { - ($VALUES:expr, $OP:ident) => {{ - match $VALUES.data_type() { - DataType::Decimal128(precision, scale) => { - typed_min_max_batch!( - $VALUES, - Decimal128Array, - Decimal128, - $OP, - precision, - scale - ) - } - DataType::Decimal256(precision, scale) => { - typed_min_max_batch!( - $VALUES, - Decimal256Array, - Decimal256, - $OP, - precision, - scale - ) - } - // all types that have a natural order - DataType::Float64 => { - typed_min_max_batch!($VALUES, Float64Array, Float64, $OP) - } - DataType::Float32 => { - typed_min_max_batch!($VALUES, Float32Array, Float32, $OP) - } - DataType::Int64 => typed_min_max_batch!($VALUES, Int64Array, Int64, $OP), - DataType::Int32 => typed_min_max_batch!($VALUES, Int32Array, Int32, $OP), - DataType::Int16 => typed_min_max_batch!($VALUES, Int16Array, Int16, $OP), - DataType::Int8 => typed_min_max_batch!($VALUES, Int8Array, Int8, $OP), - DataType::UInt64 => typed_min_max_batch!($VALUES, UInt64Array, UInt64, $OP), - DataType::UInt32 => typed_min_max_batch!($VALUES, UInt32Array, UInt32, $OP), - DataType::UInt16 => typed_min_max_batch!($VALUES, UInt16Array, UInt16, $OP), - DataType::UInt8 => typed_min_max_batch!($VALUES, UInt8Array, UInt8, $OP), - DataType::Timestamp(TimeUnit::Second, tz_opt) => { - typed_min_max_batch!( - $VALUES, - TimestampSecondArray, - TimestampSecond, - $OP, - tz_opt - ) - } - DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => typed_min_max_batch!( - $VALUES, - TimestampMillisecondArray, - TimestampMillisecond, - $OP, - tz_opt - ), - DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => typed_min_max_batch!( - $VALUES, - TimestampMicrosecondArray, - TimestampMicrosecond, - $OP, - tz_opt - ), - DataType::Timestamp(TimeUnit::Nanosecond, tz_opt) => typed_min_max_batch!( - $VALUES, - TimestampNanosecondArray, - TimestampNanosecond, - $OP, - tz_opt - ), - DataType::Date32 => typed_min_max_batch!($VALUES, Date32Array, Date32, $OP), - DataType::Date64 => typed_min_max_batch!($VALUES, Date64Array, Date64, $OP), - DataType::Time32(TimeUnit::Second) => { - typed_min_max_batch!($VALUES, Time32SecondArray, Time32Second, $OP) - } - DataType::Time32(TimeUnit::Millisecond) => { - typed_min_max_batch!( - $VALUES, - Time32MillisecondArray, - Time32Millisecond, - $OP - ) - } - DataType::Time64(TimeUnit::Microsecond) => { - typed_min_max_batch!( - $VALUES, - Time64MicrosecondArray, - Time64Microsecond, - $OP - ) - } - DataType::Time64(TimeUnit::Nanosecond) => { - typed_min_max_batch!( - $VALUES, - Time64NanosecondArray, - Time64Nanosecond, - $OP - ) - } - other => { - // This should have been handled before - return internal_err!( - "Min/Max accumulator not implemented for type {:?}", - other - ); - } - } - }}; -} - -/// dynamically-typed min(array) -> ScalarValue -fn min_batch(values: &ArrayRef) -> Result { - Ok(match values.data_type() { - DataType::Utf8 => { - typed_min_max_batch_string!(values, StringArray, Utf8, min_string) - } - DataType::LargeUtf8 => { - typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, min_string) - } - DataType::Boolean => { - typed_min_max_batch!(values, BooleanArray, Boolean, min_boolean) - } - DataType::Binary => { - typed_min_max_batch_binary!(&values, BinaryArray, Binary, min_binary) - } - DataType::LargeBinary => { - typed_min_max_batch_binary!( - &values, - LargeBinaryArray, - LargeBinary, - min_binary - ) - } - _ => min_max_batch!(values, min), - }) -} - -/// dynamically-typed max(array) -> ScalarValue -fn max_batch(values: &ArrayRef) -> Result { - Ok(match values.data_type() { - DataType::Utf8 => { - typed_min_max_batch_string!(values, StringArray, Utf8, max_string) - } - DataType::LargeUtf8 => { - typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, max_string) - } - DataType::Boolean => { - typed_min_max_batch!(values, BooleanArray, Boolean, max_boolean) - } - DataType::Binary => { - typed_min_max_batch_binary!(&values, BinaryArray, Binary, max_binary) - } - DataType::LargeBinary => { - typed_min_max_batch_binary!( - &values, - LargeBinaryArray, - LargeBinary, - max_binary - ) - } - _ => min_max_batch!(values, max), - }) -} - -// min/max of two non-string scalar values. -macro_rules! typed_min_max { - ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{ - ScalarValue::$SCALAR( - match ($VALUE, $DELTA) { - (None, None) => None, - (Some(a), None) => Some(*a), - (None, Some(b)) => Some(*b), - (Some(a), Some(b)) => Some((*a).$OP(*b)), - }, - $($EXTRA_ARGS.clone()),* - ) - }}; -} - -// min/max of two scalar string values. -macro_rules! typed_min_max_string { - ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{ - ScalarValue::$SCALAR(match ($VALUE, $DELTA) { - (None, None) => None, - (Some(a), None) => Some(a.clone()), - (None, Some(b)) => Some(b.clone()), - (Some(a), Some(b)) => Some((a).$OP(b).clone()), - }) - }}; -} - -macro_rules! interval_choose_min_max { - (min) => { - std::cmp::Ordering::Greater - }; - (max) => { - std::cmp::Ordering::Less - }; -} - -macro_rules! interval_min_max { - ($OP:tt, $LHS:expr, $RHS:expr) => {{ - match $LHS.partial_cmp(&$RHS) { - Some(interval_choose_min_max!($OP)) => $RHS.clone(), - Some(_) => $LHS.clone(), - None => { - return internal_err!("Comparison error while computing interval min/max") - } - } - }}; -} - -// min/max of two scalar values of the same type -macro_rules! min_max { - ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ - Ok(match ($VALUE, $DELTA) { - ( - lhs @ ScalarValue::Decimal128(lhsv, lhsp, lhss), - rhs @ ScalarValue::Decimal128(rhsv, rhsp, rhss) - ) => { - if lhsp.eq(rhsp) && lhss.eq(rhss) { - typed_min_max!(lhsv, rhsv, Decimal128, $OP, lhsp, lhss) - } else { - return internal_err!( - "MIN/MAX is not expected to receive scalars of incompatible types {:?}", - (lhs, rhs) - ); - } - } - ( - lhs @ ScalarValue::Decimal256(lhsv, lhsp, lhss), - rhs @ ScalarValue::Decimal256(rhsv, rhsp, rhss) - ) => { - if lhsp.eq(rhsp) && lhss.eq(rhss) { - typed_min_max!(lhsv, rhsv, Decimal256, $OP, lhsp, lhss) - } else { - return internal_err!( - "MIN/MAX is not expected to receive scalars of incompatible types {:?}", - (lhs, rhs) - ); - } - } - (ScalarValue::Boolean(lhs), ScalarValue::Boolean(rhs)) => { - typed_min_max!(lhs, rhs, Boolean, $OP) - } - (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => { - typed_min_max!(lhs, rhs, Float64, $OP) - } - (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => { - typed_min_max!(lhs, rhs, Float32, $OP) - } - (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => { - typed_min_max!(lhs, rhs, UInt64, $OP) - } - (ScalarValue::UInt32(lhs), ScalarValue::UInt32(rhs)) => { - typed_min_max!(lhs, rhs, UInt32, $OP) - } - (ScalarValue::UInt16(lhs), ScalarValue::UInt16(rhs)) => { - typed_min_max!(lhs, rhs, UInt16, $OP) - } - (ScalarValue::UInt8(lhs), ScalarValue::UInt8(rhs)) => { - typed_min_max!(lhs, rhs, UInt8, $OP) - } - (ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => { - typed_min_max!(lhs, rhs, Int64, $OP) - } - (ScalarValue::Int32(lhs), ScalarValue::Int32(rhs)) => { - typed_min_max!(lhs, rhs, Int32, $OP) - } - (ScalarValue::Int16(lhs), ScalarValue::Int16(rhs)) => { - typed_min_max!(lhs, rhs, Int16, $OP) - } - (ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => { - typed_min_max!(lhs, rhs, Int8, $OP) - } - (ScalarValue::Utf8(lhs), ScalarValue::Utf8(rhs)) => { - typed_min_max_string!(lhs, rhs, Utf8, $OP) - } - (ScalarValue::LargeUtf8(lhs), ScalarValue::LargeUtf8(rhs)) => { - typed_min_max_string!(lhs, rhs, LargeUtf8, $OP) - } - (ScalarValue::Binary(lhs), ScalarValue::Binary(rhs)) => { - typed_min_max_string!(lhs, rhs, Binary, $OP) - } - (ScalarValue::LargeBinary(lhs), ScalarValue::LargeBinary(rhs)) => { - typed_min_max_string!(lhs, rhs, LargeBinary, $OP) - } - (ScalarValue::TimestampSecond(lhs, l_tz), ScalarValue::TimestampSecond(rhs, _)) => { - typed_min_max!(lhs, rhs, TimestampSecond, $OP, l_tz) - } - ( - ScalarValue::TimestampMillisecond(lhs, l_tz), - ScalarValue::TimestampMillisecond(rhs, _), - ) => { - typed_min_max!(lhs, rhs, TimestampMillisecond, $OP, l_tz) - } - ( - ScalarValue::TimestampMicrosecond(lhs, l_tz), - ScalarValue::TimestampMicrosecond(rhs, _), - ) => { - typed_min_max!(lhs, rhs, TimestampMicrosecond, $OP, l_tz) - } - ( - ScalarValue::TimestampNanosecond(lhs, l_tz), - ScalarValue::TimestampNanosecond(rhs, _), - ) => { - typed_min_max!(lhs, rhs, TimestampNanosecond, $OP, l_tz) - } - ( - ScalarValue::Date32(lhs), - ScalarValue::Date32(rhs), - ) => { - typed_min_max!(lhs, rhs, Date32, $OP) - } - ( - ScalarValue::Date64(lhs), - ScalarValue::Date64(rhs), - ) => { - typed_min_max!(lhs, rhs, Date64, $OP) - } - ( - ScalarValue::Time32Second(lhs), - ScalarValue::Time32Second(rhs), - ) => { - typed_min_max!(lhs, rhs, Time32Second, $OP) - } - ( - ScalarValue::Time32Millisecond(lhs), - ScalarValue::Time32Millisecond(rhs), - ) => { - typed_min_max!(lhs, rhs, Time32Millisecond, $OP) - } - ( - ScalarValue::Time64Microsecond(lhs), - ScalarValue::Time64Microsecond(rhs), - ) => { - typed_min_max!(lhs, rhs, Time64Microsecond, $OP) - } - ( - ScalarValue::Time64Nanosecond(lhs), - ScalarValue::Time64Nanosecond(rhs), - ) => { - typed_min_max!(lhs, rhs, Time64Nanosecond, $OP) - } - ( - ScalarValue::IntervalYearMonth(lhs), - ScalarValue::IntervalYearMonth(rhs), - ) => { - typed_min_max!(lhs, rhs, IntervalYearMonth, $OP) - } - ( - ScalarValue::IntervalMonthDayNano(lhs), - ScalarValue::IntervalMonthDayNano(rhs), - ) => { - typed_min_max!(lhs, rhs, IntervalMonthDayNano, $OP) - } - ( - ScalarValue::IntervalDayTime(lhs), - ScalarValue::IntervalDayTime(rhs), - ) => { - typed_min_max!(lhs, rhs, IntervalDayTime, $OP) - } - ( - ScalarValue::IntervalYearMonth(_), - ScalarValue::IntervalMonthDayNano(_), - ) | ( - ScalarValue::IntervalYearMonth(_), - ScalarValue::IntervalDayTime(_), - ) | ( - ScalarValue::IntervalMonthDayNano(_), - ScalarValue::IntervalDayTime(_), - ) | ( - ScalarValue::IntervalMonthDayNano(_), - ScalarValue::IntervalYearMonth(_), - ) | ( - ScalarValue::IntervalDayTime(_), - ScalarValue::IntervalYearMonth(_), - ) | ( - ScalarValue::IntervalDayTime(_), - ScalarValue::IntervalMonthDayNano(_), - ) => { - interval_min_max!($OP, $VALUE, $DELTA) - } - ( - ScalarValue::DurationSecond(lhs), - ScalarValue::DurationSecond(rhs), - ) => { - typed_min_max!(lhs, rhs, DurationSecond, $OP) - } - ( - ScalarValue::DurationMillisecond(lhs), - ScalarValue::DurationMillisecond(rhs), - ) => { - typed_min_max!(lhs, rhs, DurationMillisecond, $OP) - } - ( - ScalarValue::DurationMicrosecond(lhs), - ScalarValue::DurationMicrosecond(rhs), - ) => { - typed_min_max!(lhs, rhs, DurationMicrosecond, $OP) - } - ( - ScalarValue::DurationNanosecond(lhs), - ScalarValue::DurationNanosecond(rhs), - ) => { - typed_min_max!(lhs, rhs, DurationNanosecond, $OP) - } - e => { - return internal_err!( - "MIN/MAX is not expected to receive scalars of incompatible types {:?}", - e - ) - } - }) - }}; -} - -/// the minimum of two scalar values -pub fn min(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { - min_max!(lhs, rhs, min) -} - -/// the maximum of two scalar values -pub fn max(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { - min_max!(lhs, rhs, max) -} - -/// An accumulator to compute the maximum value -#[derive(Debug)] -pub struct MaxAccumulator { - max: ScalarValue, -} - -impl MaxAccumulator { - /// new max accumulator - pub fn try_new(datatype: &DataType) -> Result { - Ok(Self { - max: ScalarValue::try_from(datatype)?, - }) - } -} - -impl Accumulator for MaxAccumulator { - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; - let delta = &max_batch(values)?; - self.max = max(&self.max, delta)?; - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - self.update_batch(states) - } - - fn state(&mut self) -> Result> { - Ok(vec![self.max.clone()]) - } - - fn evaluate(&mut self) -> Result { - Ok(self.max.clone()) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.max) + self.max.size() - } -} - -/// An accumulator to compute the maximum value -#[derive(Debug)] -pub struct SlidingMaxAccumulator { - max: ScalarValue, - moving_max: moving_min_max::MovingMax, -} - -impl SlidingMaxAccumulator { - /// new max accumulator - pub fn try_new(datatype: &DataType) -> Result { - Ok(Self { - max: ScalarValue::try_from(datatype)?, - moving_max: moving_min_max::MovingMax::::new(), - }) - } -} - -impl Accumulator for SlidingMaxAccumulator { - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - for idx in 0..values[0].len() { - let val = ScalarValue::try_from_array(&values[0], idx)?; - self.moving_max.push(val); - } - if let Some(res) = self.moving_max.max() { - self.max = res.clone(); - } - Ok(()) - } - - fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - for _idx in 0..values[0].len() { - (self.moving_max).pop(); - } - if let Some(res) = self.moving_max.max() { - self.max = res.clone(); - } - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - self.update_batch(states) - } - - fn state(&mut self) -> Result> { - Ok(vec![self.max.clone()]) - } - - fn evaluate(&mut self) -> Result { - Ok(self.max.clone()) - } - - fn supports_retract_batch(&self) -> bool { - true - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.max) + self.max.size() - } -} - -/// MIN aggregate expression -#[derive(Debug, Clone)] -pub struct Min { - name: String, - data_type: DataType, - nullable: bool, - expr: Arc, -} - -impl Min { - /// Create a new MIN aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - expr, - data_type: min_max_aggregate_data_type(data_type), - nullable: true, - } - } -} - -impl AggregateExpr for Min { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new( - &self.name, - self.data_type.clone(), - self.nullable, - )) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(MinAccumulator::try_new(&self.data_type)?)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "min"), - self.data_type.clone(), - true, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } - - fn groups_accumulator_supported(&self) -> bool { - use DataType::*; - matches!( - self.data_type, - Int8 | Int16 - | Int32 - | Int64 - | UInt8 - | UInt16 - | UInt32 - | UInt64 - | Float32 - | Float64 - | Decimal128(_, _) - | Decimal256(_, _) - | Date32 - | Date64 - | Time32(_) - | Time64(_) - | Timestamp(_, _) - ) - } - - fn create_groups_accumulator(&self) -> Result> { - use DataType::*; - use TimeUnit::*; - match self.data_type { - Int8 => instantiate_min_accumulator!(self, i8, Int8Type), - Int16 => instantiate_min_accumulator!(self, i16, Int16Type), - Int32 => instantiate_min_accumulator!(self, i32, Int32Type), - Int64 => instantiate_min_accumulator!(self, i64, Int64Type), - UInt8 => instantiate_min_accumulator!(self, u8, UInt8Type), - UInt16 => instantiate_min_accumulator!(self, u16, UInt16Type), - UInt32 => instantiate_min_accumulator!(self, u32, UInt32Type), - UInt64 => instantiate_min_accumulator!(self, u64, UInt64Type), - Float32 => { - instantiate_min_accumulator!(self, f32, Float32Type) - } - Float64 => { - instantiate_min_accumulator!(self, f64, Float64Type) - } - Date32 => instantiate_min_accumulator!(self, i32, Date32Type), - Date64 => instantiate_min_accumulator!(self, i64, Date64Type), - Time32(Second) => { - instantiate_min_accumulator!(self, i32, Time32SecondType) - } - Time32(Millisecond) => { - instantiate_min_accumulator!(self, i32, Time32MillisecondType) - } - Time64(Microsecond) => { - instantiate_min_accumulator!(self, i64, Time64MicrosecondType) - } - Time64(Nanosecond) => { - instantiate_min_accumulator!(self, i64, Time64NanosecondType) - } - Timestamp(Second, _) => { - instantiate_min_accumulator!(self, i64, TimestampSecondType) - } - Timestamp(Millisecond, _) => { - instantiate_min_accumulator!(self, i64, TimestampMillisecondType) - } - Timestamp(Microsecond, _) => { - instantiate_min_accumulator!(self, i64, TimestampMicrosecondType) - } - Timestamp(Nanosecond, _) => { - instantiate_min_accumulator!(self, i64, TimestampNanosecondType) - } - Decimal128(_, _) => { - instantiate_min_accumulator!(self, i128, Decimal128Type) - } - Decimal256(_, _) => { - instantiate_min_accumulator!(self, i256, Decimal256Type) - } - // This is only reached if groups_accumulator_supported is out of sync - _ => internal_err!( - "GroupsAccumulator not supported for min({})", - self.data_type - ), - } - } - - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone())) - } - - fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(SlidingMinAccumulator::try_new(&self.data_type)?)) - } -} - -impl PartialEq for Min { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.nullable == x.nullable - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -/// An accumulator to compute the minimum value -#[derive(Debug)] -pub struct MinAccumulator { - min: ScalarValue, -} - -impl MinAccumulator { - /// new min accumulator - pub fn try_new(datatype: &DataType) -> Result { - Ok(Self { - min: ScalarValue::try_from(datatype)?, - }) - } -} - -impl Accumulator for MinAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![self.min.clone()]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; - let delta = &min_batch(values)?; - self.min = min(&self.min, delta)?; - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - self.update_batch(states) - } - - fn evaluate(&mut self) -> Result { - Ok(self.min.clone()) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) + self.min.size() - } -} - -/// An accumulator to compute the minimum value -#[derive(Debug)] -pub struct SlidingMinAccumulator { - min: ScalarValue, - moving_min: moving_min_max::MovingMin, -} - -impl SlidingMinAccumulator { - /// new min accumulator - pub fn try_new(datatype: &DataType) -> Result { - Ok(Self { - min: ScalarValue::try_from(datatype)?, - moving_min: moving_min_max::MovingMin::::new(), - }) - } -} - -impl Accumulator for SlidingMinAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![self.min.clone()]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - for idx in 0..values[0].len() { - let val = ScalarValue::try_from_array(&values[0], idx)?; - if !val.is_null() { - self.moving_min.push(val); - } - } - if let Some(res) = self.moving_min.min() { - self.min = res.clone(); - } - Ok(()) - } - - fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - for idx in 0..values[0].len() { - let val = ScalarValue::try_from_array(&values[0], idx)?; - if !val.is_null() { - (self.moving_min).pop(); - } - } - if let Some(res) = self.moving_min.min() { - self.min = res.clone(); - } - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - self.update_batch(states) - } - - fn evaluate(&mut self) -> Result { - Ok(self.min.clone()) - } - - fn supports_retract_batch(&self) -> bool { - true - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) + self.min.size() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::expressions::col; - use crate::expressions::tests::{aggregate, aggregate_new}; - use crate::{generic_test_op, generic_test_op_new}; - use arrow::datatypes::*; - use arrow::record_batch::RecordBatch; - use datafusion_common::ScalarValue::Decimal128; - - #[test] - fn min_decimal() -> Result<()> { - // min - let left = ScalarValue::Decimal128(Some(123), 10, 2); - let right = ScalarValue::Decimal128(Some(124), 10, 2); - let result = min(&left, &right)?; - assert_eq!(result, left); - - // min batch - let array: ArrayRef = Arc::new( - (1..6) - .map(Some) - .collect::() - .with_precision_and_scale(10, 0)?, - ); - - let result = min_batch(&array)?; - assert_eq!(result, ScalarValue::Decimal128(Some(1), 10, 0)); - - // min batch without values - let array: ArrayRef = Arc::new( - std::iter::repeat::>(None) - .take(0) - .collect::() - .with_precision_and_scale(10, 0)?, - ); - let result = min_batch(&array)?; - assert_eq!(ScalarValue::Decimal128(None, 10, 0), result); - - // min batch with agg - let array: ArrayRef = Arc::new( - (1..6) - .map(Some) - .collect::() - .with_precision_and_scale(10, 0)?, - ); - generic_test_op!( - array, - DataType::Decimal128(10, 0), - Min, - ScalarValue::Decimal128(Some(1), 10, 0) - ) - } - - #[test] - fn min_decimal_all_nulls() -> Result<()> { - // min batch all nulls - let array: ArrayRef = Arc::new( - std::iter::repeat::>(None) - .take(6) - .collect::() - .with_precision_and_scale(10, 0)?, - ); - generic_test_op!( - array, - DataType::Decimal128(10, 0), - Min, - ScalarValue::Decimal128(None, 10, 0) - ) - } - - #[test] - fn min_decimal_with_nulls() -> Result<()> { - // min batch with nulls - let array: ArrayRef = Arc::new( - (1..6) - .map(|i| if i == 2 { None } else { Some(i) }) - .collect::() - .with_precision_and_scale(10, 0)?, - ); - - generic_test_op!( - array, - DataType::Decimal128(10, 0), - Min, - ScalarValue::Decimal128(Some(1), 10, 0) - ) - } - - #[test] - fn max_decimal() -> Result<()> { - // max - let left = ScalarValue::Decimal128(Some(123), 10, 2); - let right = ScalarValue::Decimal128(Some(124), 10, 2); - let result = max(&left, &right)?; - assert_eq!(result, right); - - let right = ScalarValue::Decimal128(Some(124), 10, 3); - let result = max(&left, &right); - let err_msg = format!( - "MIN/MAX is not expected to receive scalars of incompatible types {:?}", - (Decimal128(Some(123), 10, 2), Decimal128(Some(124), 10, 3)) - ); - let expect = DataFusionError::Internal(err_msg); - assert!(expect - .strip_backtrace() - .starts_with(&result.unwrap_err().strip_backtrace())); - - // max batch - let array: ArrayRef = Arc::new( - (1..6) - .map(Some) - .collect::() - .with_precision_and_scale(10, 5)?, - ); - let result = max_batch(&array)?; - assert_eq!(result, ScalarValue::Decimal128(Some(5), 10, 5)); - - // max batch without values - let array: ArrayRef = Arc::new( - std::iter::repeat::>(None) - .take(0) - .collect::() - .with_precision_and_scale(10, 0)?, - ); - let result = max_batch(&array)?; - assert_eq!(ScalarValue::Decimal128(None, 10, 0), result); - - // max batch with agg - let array: ArrayRef = Arc::new( - (1..6) - .map(Some) - .collect::() - .with_precision_and_scale(10, 0)?, - ); - generic_test_op!( - array, - DataType::Decimal128(10, 0), - Max, - ScalarValue::Decimal128(Some(5), 10, 0) - ) - } - - #[test] - fn max_decimal_with_nulls() -> Result<()> { - let array: ArrayRef = Arc::new( - (1..6) - .map(|i| if i == 2 { None } else { Some(i) }) - .collect::() - .with_precision_and_scale(10, 0)?, - ); - generic_test_op!( - array, - DataType::Decimal128(10, 0), - Max, - ScalarValue::Decimal128(Some(5), 10, 0) - ) - } - - #[test] - fn max_decimal_all_nulls() -> Result<()> { - let array: ArrayRef = Arc::new( - std::iter::repeat::>(None) - .take(6) - .collect::() - .with_precision_and_scale(10, 0)?, - ); - generic_test_op!( - array, - DataType::Decimal128(10, 0), - Min, - ScalarValue::Decimal128(None, 10, 0) - ) - } - - #[test] - fn max_i32() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!(a, DataType::Int32, Max, ScalarValue::from(5i32)) - } - - #[test] - fn min_i32() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!(a, DataType::Int32, Min, ScalarValue::from(1i32)) - } - - #[test] - fn max_utf8() -> Result<()> { - let a: ArrayRef = Arc::new(StringArray::from(vec!["d", "a", "c", "b"])); - generic_test_op!(a, DataType::Utf8, Max, ScalarValue::from("d")) - } - - #[test] - fn max_large_utf8() -> Result<()> { - let a: ArrayRef = Arc::new(LargeStringArray::from(vec!["d", "a", "c", "b"])); - generic_test_op!( - a, - DataType::LargeUtf8, - Max, - ScalarValue::LargeUtf8(Some("d".to_string())) - ) - } - - #[test] - fn min_utf8() -> Result<()> { - let a: ArrayRef = Arc::new(StringArray::from(vec!["d", "a", "c", "b"])); - generic_test_op!(a, DataType::Utf8, Min, ScalarValue::from("a")) - } - - #[test] - fn min_large_utf8() -> Result<()> { - let a: ArrayRef = Arc::new(LargeStringArray::from(vec!["d", "a", "c", "b"])); - generic_test_op!( - a, - DataType::LargeUtf8, - Min, - ScalarValue::LargeUtf8(Some("a".to_string())) - ) - } - - #[test] - fn max_i32_with_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - None, - Some(3), - Some(4), - Some(5), - ])); - generic_test_op!(a, DataType::Int32, Max, ScalarValue::from(5i32)) - } - - #[test] - fn min_i32_with_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - None, - Some(3), - Some(4), - Some(5), - ])); - generic_test_op!(a, DataType::Int32, Min, ScalarValue::from(1i32)) - } - - #[test] - fn max_i32_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - generic_test_op!(a, DataType::Int32, Max, ScalarValue::Int32(None)) - } - - #[test] - fn min_i32_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - generic_test_op!(a, DataType::Int32, Min, ScalarValue::Int32(None)) - } - - #[test] - fn max_u32() -> Result<()> { - let a: ArrayRef = - Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); - generic_test_op!(a, DataType::UInt32, Max, ScalarValue::from(5_u32)) - } - - #[test] - fn min_u32() -> Result<()> { - let a: ArrayRef = - Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); - generic_test_op!(a, DataType::UInt32, Min, ScalarValue::from(1u32)) - } - - #[test] - fn max_f32() -> Result<()> { - let a: ArrayRef = - Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); - generic_test_op!(a, DataType::Float32, Max, ScalarValue::from(5_f32)) - } - - #[test] - fn min_f32() -> Result<()> { - let a: ArrayRef = - Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); - generic_test_op!(a, DataType::Float32, Min, ScalarValue::from(1_f32)) - } - - #[test] - fn max_f64() -> Result<()> { - let a: ArrayRef = - Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - generic_test_op!(a, DataType::Float64, Max, ScalarValue::from(5_f64)) - } - - #[test] - fn min_f64() -> Result<()> { - let a: ArrayRef = - Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - generic_test_op!(a, DataType::Float64, Min, ScalarValue::from(1_f64)) - } - - #[test] - fn min_date32() -> Result<()> { - let a: ArrayRef = Arc::new(Date32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!(a, DataType::Date32, Min, ScalarValue::Date32(Some(1))) - } - - #[test] - fn min_date64() -> Result<()> { - let a: ArrayRef = Arc::new(Date64Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!(a, DataType::Date64, Min, ScalarValue::Date64(Some(1))) - } - - #[test] - fn max_date32() -> Result<()> { - let a: ArrayRef = Arc::new(Date32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!(a, DataType::Date32, Max, ScalarValue::Date32(Some(5))) - } - - #[test] - fn max_date64() -> Result<()> { - let a: ArrayRef = Arc::new(Date64Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!(a, DataType::Date64, Max, ScalarValue::Date64(Some(5))) - } - - #[test] - fn min_time32second() -> Result<()> { - let a: ArrayRef = Arc::new(Time32SecondArray::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Time32(TimeUnit::Second), - Min, - ScalarValue::Time32Second(Some(1)) - ) - } - - #[test] - fn max_time32second() -> Result<()> { - let a: ArrayRef = Arc::new(Time32SecondArray::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Time32(TimeUnit::Second), - Max, - ScalarValue::Time32Second(Some(5)) - ) - } - - #[test] - fn min_time32millisecond() -> Result<()> { - let a: ArrayRef = Arc::new(Time32MillisecondArray::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Time32(TimeUnit::Millisecond), - Min, - ScalarValue::Time32Millisecond(Some(1)) - ) - } - - #[test] - fn max_time32millisecond() -> Result<()> { - let a: ArrayRef = Arc::new(Time32MillisecondArray::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Time32(TimeUnit::Millisecond), - Max, - ScalarValue::Time32Millisecond(Some(5)) - ) - } - - #[test] - fn min_time64microsecond() -> Result<()> { - let a: ArrayRef = Arc::new(Time64MicrosecondArray::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Time64(TimeUnit::Microsecond), - Min, - ScalarValue::Time64Microsecond(Some(1)) - ) - } - - #[test] - fn max_time64microsecond() -> Result<()> { - let a: ArrayRef = Arc::new(Time64MicrosecondArray::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Time64(TimeUnit::Microsecond), - Max, - ScalarValue::Time64Microsecond(Some(5)) - ) - } - - #[test] - fn min_time64nanosecond() -> Result<()> { - let a: ArrayRef = Arc::new(Time64NanosecondArray::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Time64(TimeUnit::Nanosecond), - Min, - ScalarValue::Time64Nanosecond(Some(1)) - ) - } - - #[test] - fn max_time64nanosecond() -> Result<()> { - let a: ArrayRef = Arc::new(Time64NanosecondArray::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Time64(TimeUnit::Nanosecond), - Max, - ScalarValue::Time64Nanosecond(Some(5)) - ) - } - - #[test] - fn max_new_timestamp_micro() -> Result<()> { - let dt = DataType::Timestamp(TimeUnit::Microsecond, None); - let actual = TimestampMicrosecondArray::from(vec![1, 2, 3, 4, 5]) - .with_data_type(dt.clone()); - let expected: ArrayRef = - Arc::new(TimestampMicrosecondArray::from(vec![5]).with_data_type(dt.clone())); - generic_test_op_new!(Arc::new(actual), dt.clone(), Max, &expected) - } - - #[test] - fn max_new_timestamp_micro_with_tz() -> Result<()> { - let dt = DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())); - let actual = TimestampMicrosecondArray::from(vec![1, 2, 3, 4, 5]) - .with_data_type(dt.clone()); - let expected: ArrayRef = - Arc::new(TimestampMicrosecondArray::from(vec![5]).with_data_type(dt.clone())); - generic_test_op_new!(Arc::new(actual), dt.clone(), Max, &expected) - } - - #[test] - fn max_bool() -> Result<()> { - let a: ArrayRef = Arc::new(BooleanArray::from(vec![false, false])); - generic_test_op!(a, DataType::Boolean, Max, ScalarValue::from(false))?; - - let a: ArrayRef = Arc::new(BooleanArray::from(vec![true, true])); - generic_test_op!(a, DataType::Boolean, Max, ScalarValue::from(true))?; - - let a: ArrayRef = Arc::new(BooleanArray::from(vec![false, true, false])); - generic_test_op!(a, DataType::Boolean, Max, ScalarValue::from(true))?; - - let a: ArrayRef = Arc::new(BooleanArray::from(vec![true, false, true])); - generic_test_op!(a, DataType::Boolean, Max, ScalarValue::from(true))?; - - let a: ArrayRef = Arc::new(BooleanArray::from(Vec::::new())); - generic_test_op!( - a, - DataType::Boolean, - Max, - ScalarValue::from(None as Option) - )?; - - let a: ArrayRef = Arc::new(BooleanArray::from(vec![None as Option])); - generic_test_op!( - a, - DataType::Boolean, - Max, - ScalarValue::from(None as Option) - )?; - - let a: ArrayRef = - Arc::new(BooleanArray::from(vec![None, Some(true), Some(false)])); - generic_test_op!(a, DataType::Boolean, Max, ScalarValue::from(true))?; - - Ok(()) - } - - #[test] - fn min_bool() -> Result<()> { - let a: ArrayRef = Arc::new(BooleanArray::from(vec![false, false])); - generic_test_op!(a, DataType::Boolean, Min, ScalarValue::from(false))?; - - let a: ArrayRef = Arc::new(BooleanArray::from(vec![true, true])); - generic_test_op!(a, DataType::Boolean, Min, ScalarValue::from(true))?; - - let a: ArrayRef = Arc::new(BooleanArray::from(vec![false, true, false])); - generic_test_op!(a, DataType::Boolean, Min, ScalarValue::from(false))?; - - let a: ArrayRef = Arc::new(BooleanArray::from(vec![true, false, true])); - generic_test_op!(a, DataType::Boolean, Min, ScalarValue::from(false))?; - - let a: ArrayRef = Arc::new(BooleanArray::from(Vec::::new())); - generic_test_op!( - a, - DataType::Boolean, - Min, - ScalarValue::from(None as Option) - )?; - - let a: ArrayRef = Arc::new(BooleanArray::from(vec![None as Option])); - generic_test_op!( - a, - DataType::Boolean, - Min, - ScalarValue::from(None as Option) - )?; - - let a: ArrayRef = - Arc::new(BooleanArray::from(vec![None, Some(true), Some(false)])); - generic_test_op!(a, DataType::Boolean, Min, ScalarValue::from(false))?; - - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs deleted file mode 100644 index eff008e8f825..000000000000 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ /dev/null @@ -1,65 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::sync::Arc; - -use crate::expressions::{NthValueAgg, OrderSensitiveArrayAgg}; - -pub use datafusion_physical_expr_common::aggregate::AggregateExpr; - -mod hyperloglog; -mod tdigest; - -pub(crate) mod approx_distinct; -pub(crate) mod approx_median; -pub(crate) mod approx_percentile_cont; -pub(crate) mod approx_percentile_cont_with_weight; -pub(crate) mod array_agg; -pub(crate) mod array_agg_distinct; -pub(crate) mod array_agg_ordered; -pub(crate) mod average; -pub(crate) mod bit_and_or_xor; -pub(crate) mod bool_and_or; -pub(crate) mod correlation; -pub(crate) mod count; -pub(crate) mod count_distinct; -pub(crate) mod covariance; -pub(crate) mod grouping; -pub(crate) mod median; -pub(crate) mod nth_value; -pub(crate) mod string_agg; -#[macro_use] -pub(crate) mod min_max; -pub(crate) mod groups_accumulator; -pub(crate) mod regr; -pub(crate) mod stats; -pub(crate) mod stddev; -pub(crate) mod sum; -pub(crate) mod sum_distinct; -pub(crate) mod variance; - -pub mod build_in; -pub mod moving_min_max; -pub mod utils; - -/// Checks whether the given aggregate expression is order-sensitive. -/// For instance, a `SUM` aggregation doesn't depend on the order of its inputs. -/// However, an `ARRAY_AGG` with `ORDER BY` depends on the input ordering. -pub fn is_order_sensitive(aggr_expr: &Arc) -> bool { - aggr_expr.as_any().is::() - || aggr_expr.as_any().is::() -} diff --git a/datafusion/physical-expr/src/aggregate/moving_min_max.rs b/datafusion/physical-expr/src/aggregate/moving_min_max.rs deleted file mode 100644 index c4fb07679747..000000000000 --- a/datafusion/physical-expr/src/aggregate/moving_min_max.rs +++ /dev/null @@ -1,335 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// The implementation is taken from https://github.com/spebern/moving_min_max/blob/master/src/lib.rs. - -//! Keep track of the minimum or maximum value in a sliding window. -//! -//! `moving min max` provides one data structure for keeping track of the -//! minimum value and one for keeping track of the maximum value in a sliding -//! window. -//! -//! Each element is stored with the current min/max. One stack to push and another one for pop. If pop stack is empty, -//! push to this stack all elements popped from first stack while updating their current min/max. Now pop from -//! the second stack (MovingMin/Max struct works as a queue). To find the minimum element of the queue, -//! look at the smallest/largest two elements of the individual stacks, then take the minimum of those two values. -//! -//! The complexity of the operations are -//! - O(1) for getting the minimum/maximum -//! - O(1) for push -//! - amortized O(1) for pop - -/// ``` -/// # use datafusion_physical_expr::aggregate::moving_min_max::MovingMin; -/// let mut moving_min = MovingMin::::new(); -/// moving_min.push(2); -/// moving_min.push(1); -/// moving_min.push(3); -/// -/// assert_eq!(moving_min.min(), Some(&1)); -/// assert_eq!(moving_min.pop(), Some(2)); -/// -/// assert_eq!(moving_min.min(), Some(&1)); -/// assert_eq!(moving_min.pop(), Some(1)); -/// -/// assert_eq!(moving_min.min(), Some(&3)); -/// assert_eq!(moving_min.pop(), Some(3)); -/// -/// assert_eq!(moving_min.min(), None); -/// assert_eq!(moving_min.pop(), None); -/// ``` -#[derive(Debug)] -pub struct MovingMin { - push_stack: Vec<(T, T)>, - pop_stack: Vec<(T, T)>, -} - -impl Default for MovingMin { - fn default() -> Self { - Self { - push_stack: Vec::new(), - pop_stack: Vec::new(), - } - } -} - -impl MovingMin { - /// Creates a new `MovingMin` to keep track of the minimum in a sliding - /// window. - #[inline] - pub fn new() -> Self { - Self::default() - } - - /// Creates a new `MovingMin` to keep track of the minimum in a sliding - /// window with `capacity` allocated slots. - #[inline] - pub fn with_capacity(capacity: usize) -> Self { - Self { - push_stack: Vec::with_capacity(capacity), - pop_stack: Vec::with_capacity(capacity), - } - } - - /// Returns the minimum of the sliding window or `None` if the window is - /// empty. - #[inline] - pub fn min(&self) -> Option<&T> { - match (self.push_stack.last(), self.pop_stack.last()) { - (None, None) => None, - (Some((_, min)), None) => Some(min), - (None, Some((_, min))) => Some(min), - (Some((_, a)), Some((_, b))) => Some(if a < b { a } else { b }), - } - } - - /// Pushes a new element into the sliding window. - #[inline] - pub fn push(&mut self, val: T) { - self.push_stack.push(match self.push_stack.last() { - Some((_, min)) => { - if val > *min { - (val, min.clone()) - } else { - (val.clone(), val) - } - } - None => (val.clone(), val), - }); - } - - /// Removes and returns the last value of the sliding window. - #[inline] - pub fn pop(&mut self) -> Option { - if self.pop_stack.is_empty() { - match self.push_stack.pop() { - Some((val, _)) => { - let mut last = (val.clone(), val); - self.pop_stack.push(last.clone()); - while let Some((val, _)) = self.push_stack.pop() { - let min = if last.1 < val { - last.1.clone() - } else { - val.clone() - }; - last = (val.clone(), min); - self.pop_stack.push(last.clone()); - } - } - None => return None, - } - } - self.pop_stack.pop().map(|(val, _)| val) - } - - /// Returns the number of elements stored in the sliding window. - #[inline] - pub fn len(&self) -> usize { - self.push_stack.len() + self.pop_stack.len() - } - - /// Returns `true` if the moving window contains no elements. - #[inline] - pub fn is_empty(&self) -> bool { - self.len() == 0 - } -} -/// ``` -/// # use datafusion_physical_expr::aggregate::moving_min_max::MovingMax; -/// let mut moving_max = MovingMax::::new(); -/// moving_max.push(2); -/// moving_max.push(3); -/// moving_max.push(1); -/// -/// assert_eq!(moving_max.max(), Some(&3)); -/// assert_eq!(moving_max.pop(), Some(2)); -/// -/// assert_eq!(moving_max.max(), Some(&3)); -/// assert_eq!(moving_max.pop(), Some(3)); -/// -/// assert_eq!(moving_max.max(), Some(&1)); -/// assert_eq!(moving_max.pop(), Some(1)); -/// -/// assert_eq!(moving_max.max(), None); -/// assert_eq!(moving_max.pop(), None); -/// ``` -#[derive(Debug)] -pub struct MovingMax { - push_stack: Vec<(T, T)>, - pop_stack: Vec<(T, T)>, -} - -impl Default for MovingMax { - fn default() -> Self { - Self { - push_stack: Vec::new(), - pop_stack: Vec::new(), - } - } -} - -impl MovingMax { - /// Creates a new `MovingMax` to keep track of the maximum in a sliding window. - #[inline] - pub fn new() -> Self { - Self::default() - } - - /// Creates a new `MovingMax` to keep track of the maximum in a sliding window with - /// `capacity` allocated slots. - #[inline] - pub fn with_capacity(capacity: usize) -> Self { - Self { - push_stack: Vec::with_capacity(capacity), - pop_stack: Vec::with_capacity(capacity), - } - } - - /// Returns the maximum of the sliding window or `None` if the window is empty. - #[inline] - pub fn max(&self) -> Option<&T> { - match (self.push_stack.last(), self.pop_stack.last()) { - (None, None) => None, - (Some((_, max)), None) => Some(max), - (None, Some((_, max))) => Some(max), - (Some((_, a)), Some((_, b))) => Some(if a > b { a } else { b }), - } - } - - /// Pushes a new element into the sliding window. - #[inline] - pub fn push(&mut self, val: T) { - self.push_stack.push(match self.push_stack.last() { - Some((_, max)) => { - if val < *max { - (val, max.clone()) - } else { - (val.clone(), val) - } - } - None => (val.clone(), val), - }); - } - - /// Removes and returns the last value of the sliding window. - #[inline] - pub fn pop(&mut self) -> Option { - if self.pop_stack.is_empty() { - match self.push_stack.pop() { - Some((val, _)) => { - let mut last = (val.clone(), val); - self.pop_stack.push(last.clone()); - while let Some((val, _)) = self.push_stack.pop() { - let max = if last.1 > val { - last.1.clone() - } else { - val.clone() - }; - last = (val.clone(), max); - self.pop_stack.push(last.clone()); - } - } - None => return None, - } - } - self.pop_stack.pop().map(|(val, _)| val) - } - - /// Returns the number of elements stored in the sliding window. - #[inline] - pub fn len(&self) -> usize { - self.push_stack.len() + self.pop_stack.len() - } - - /// Returns `true` if the moving window contains no elements. - #[inline] - pub fn is_empty(&self) -> bool { - self.len() == 0 - } -} - -#[cfg(test)] -mod tests { - use super::*; - use datafusion_common::Result; - use rand::Rng; - - fn get_random_vec_i32(len: usize) -> Vec { - let mut rng = rand::thread_rng(); - let mut input = Vec::with_capacity(len); - for _i in 0..len { - input.push(rng.gen_range(0..100)); - } - input - } - - fn moving_min_i32(len: usize, n_sliding_window: usize) -> Result<()> { - let data = get_random_vec_i32(len); - let mut expected = Vec::with_capacity(len); - let mut moving_min = MovingMin::::new(); - let mut res = Vec::with_capacity(len); - for i in 0..len { - let start = i.saturating_sub(n_sliding_window); - expected.push(*data[start..i + 1].iter().min().unwrap()); - - moving_min.push(data[i]); - if i > n_sliding_window { - moving_min.pop(); - } - res.push(*moving_min.min().unwrap()); - } - assert_eq!(res, expected); - Ok(()) - } - - fn moving_max_i32(len: usize, n_sliding_window: usize) -> Result<()> { - let data = get_random_vec_i32(len); - let mut expected = Vec::with_capacity(len); - let mut moving_max = MovingMax::::new(); - let mut res = Vec::with_capacity(len); - for i in 0..len { - let start = i.saturating_sub(n_sliding_window); - expected.push(*data[start..i + 1].iter().max().unwrap()); - - moving_max.push(data[i]); - if i > n_sliding_window { - moving_max.pop(); - } - res.push(*moving_max.max().unwrap()); - } - assert_eq!(res, expected); - Ok(()) - } - - #[test] - fn moving_min_tests() -> Result<()> { - moving_min_i32(100, 10)?; - moving_min_i32(100, 20)?; - moving_min_i32(100, 50)?; - moving_min_i32(100, 100)?; - Ok(()) - } - - #[test] - fn moving_max_tests() -> Result<()> { - moving_max_i32(100, 10)?; - moving_max_i32(100, 20)?; - moving_max_i32(100, 50)?; - moving_max_i32(100, 100)?; - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/aggregate/stddev.rs b/datafusion/physical-expr/src/aggregate/stddev.rs deleted file mode 100644 index e5ce1b9230db..000000000000 --- a/datafusion/physical-expr/src/aggregate/stddev.rs +++ /dev/null @@ -1,464 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expressions that can evaluated at runtime during query execution - -use std::any::Any; -use std::sync::Arc; - -use crate::aggregate::stats::StatsType; -use crate::aggregate::utils::down_cast_any_ref; -use crate::aggregate::variance::VarianceAccumulator; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::{array::ArrayRef, datatypes::DataType, datatypes::Field}; -use datafusion_common::ScalarValue; -use datafusion_common::{internal_err, Result}; -use datafusion_expr::Accumulator; - -/// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression -#[derive(Debug)] -pub struct Stddev { - name: String, - expr: Arc, -} - -/// STDDEV_POP population aggregate expression -#[derive(Debug)] -pub struct StddevPop { - name: String, - expr: Arc, -} - -impl Stddev { - /// Create a new STDDEV aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - // the result of stddev just support FLOAT64 and Decimal data type. - assert!(matches!(data_type, DataType::Float64)); - Self { - name: name.into(), - expr, - } - } -} - -impl AggregateExpr for Stddev { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64, true)) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?)) - } - - fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![ - Field::new( - format_state_name(&self.name, "count"), - DataType::UInt64, - true, - ), - Field::new( - format_state_name(&self.name, "mean"), - DataType::Float64, - true, - ), - Field::new(format_state_name(&self.name, "m2"), DataType::Float64, true), - ]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for Stddev { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| self.name == x.name && self.expr.eq(&x.expr)) - .unwrap_or(false) - } -} - -impl StddevPop { - /// Create a new STDDEV aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - // the result of stddev just support FLOAT64 and Decimal data type. - assert!(matches!(data_type, DataType::Float64)); - Self { - name: name.into(), - expr, - } - } -} - -impl AggregateExpr for StddevPop { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64, true)) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?)) - } - - fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![ - Field::new( - format_state_name(&self.name, "count"), - DataType::UInt64, - true, - ), - Field::new( - format_state_name(&self.name, "mean"), - DataType::Float64, - true, - ), - Field::new(format_state_name(&self.name, "m2"), DataType::Float64, true), - ]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for StddevPop { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| self.name == x.name && self.expr.eq(&x.expr)) - .unwrap_or(false) - } -} - -/// An accumulator to compute the average -#[derive(Debug)] -pub struct StddevAccumulator { - variance: VarianceAccumulator, -} - -impl StddevAccumulator { - /// Creates a new `StddevAccumulator` - pub fn try_new(s_type: StatsType) -> Result { - Ok(Self { - variance: VarianceAccumulator::try_new(s_type)?, - }) - } - - pub fn get_m2(&self) -> f64 { - self.variance.get_m2() - } -} - -impl Accumulator for StddevAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![ - ScalarValue::from(self.variance.get_count()), - ScalarValue::from(self.variance.get_mean()), - ScalarValue::from(self.variance.get_m2()), - ]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - self.variance.update_batch(values) - } - - fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - self.variance.retract_batch(values) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - self.variance.merge_batch(states) - } - - fn evaluate(&mut self) -> Result { - let variance = self.variance.evaluate()?; - match variance { - ScalarValue::Float64(e) => { - if e.is_none() { - Ok(ScalarValue::Float64(None)) - } else { - Ok(ScalarValue::Float64(e.map(|f| f.sqrt()))) - } - } - _ => internal_err!("Variance should be f64"), - } - } - - fn size(&self) -> usize { - std::mem::align_of_val(self) - std::mem::align_of_val(&self.variance) - + self.variance.size() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::aggregate::utils::get_accum_scalar_values_as_arrays; - use crate::expressions::col; - use crate::expressions::tests::aggregate; - use crate::generic_test_op; - use arrow::{array::*, datatypes::*}; - - #[test] - fn stddev_f64_1() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64])); - generic_test_op!(a, DataType::Float64, StddevPop, ScalarValue::from(0.5_f64)) - } - - #[test] - fn stddev_f64_2() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64])); - generic_test_op!( - a, - DataType::Float64, - StddevPop, - ScalarValue::from(0.7760297817881877_f64) - ) - } - - #[test] - fn stddev_f64_3() -> Result<()> { - let a: ArrayRef = - Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - generic_test_op!( - a, - DataType::Float64, - StddevPop, - ScalarValue::from(std::f64::consts::SQRT_2) - ) - } - - #[test] - fn stddev_f64_4() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64])); - generic_test_op!( - a, - DataType::Float64, - Stddev, - ScalarValue::from(0.9504384952922168_f64) - ) - } - - #[test] - fn stddev_i32() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Int32, - StddevPop, - ScalarValue::from(std::f64::consts::SQRT_2) - ) - } - - #[test] - fn stddev_u32() -> Result<()> { - let a: ArrayRef = - Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); - generic_test_op!( - a, - DataType::UInt32, - StddevPop, - ScalarValue::from(std::f64::consts::SQRT_2) - ) - } - - #[test] - fn stddev_f32() -> Result<()> { - let a: ArrayRef = - Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); - generic_test_op!( - a, - DataType::Float32, - StddevPop, - ScalarValue::from(std::f64::consts::SQRT_2) - ) - } - - #[test] - fn test_stddev_1_input() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64])); - let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; - - let agg = Arc::new(Stddev::new( - col("a", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - let actual = aggregate(&batch, agg).unwrap(); - assert_eq!(actual, ScalarValue::Float64(None)); - - Ok(()) - } - - #[test] - fn stddev_i32_with_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - None, - Some(3), - Some(4), - Some(5), - ])); - generic_test_op!( - a, - DataType::Int32, - StddevPop, - ScalarValue::from(1.479019945774904_f64) - ) - } - - #[test] - fn stddev_i32_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; - - let agg = Arc::new(Stddev::new( - col("a", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - let actual = aggregate(&batch, agg).unwrap(); - assert_eq!(actual, ScalarValue::Float64(None)); - Ok(()) - } - - #[test] - fn stddev_f64_merge_1() -> Result<()> { - let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64])); - let b = Arc::new(Float64Array::from(vec![4_f64, 5_f64])); - - let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); - - let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; - let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![b])?; - - let agg1 = Arc::new(StddevPop::new( - col("a", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - - let agg2 = Arc::new(StddevPop::new( - col("a", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - - let actual = merge(&batch1, &batch2, agg1, agg2)?; - assert!(actual == ScalarValue::from(std::f64::consts::SQRT_2)); - - Ok(()) - } - - #[test] - fn stddev_f64_merge_2() -> Result<()> { - let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - let b = Arc::new(Float64Array::from(vec![None])); - - let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); - - let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; - let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![b])?; - - let agg1 = Arc::new(StddevPop::new( - col("a", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - - let agg2 = Arc::new(StddevPop::new( - col("a", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - - let actual = merge(&batch1, &batch2, agg1, agg2)?; - assert!(actual == ScalarValue::from(std::f64::consts::SQRT_2)); - - Ok(()) - } - - fn merge( - batch1: &RecordBatch, - batch2: &RecordBatch, - agg1: Arc, - agg2: Arc, - ) -> Result { - let mut accum1 = agg1.create_accumulator()?; - let mut accum2 = agg2.create_accumulator()?; - let expr1 = agg1.expressions(); - let expr2 = agg2.expressions(); - - let values1 = expr1 - .iter() - .map(|e| { - e.evaluate(batch1) - .and_then(|v| v.into_array(batch1.num_rows())) - }) - .collect::>>()?; - let values2 = expr2 - .iter() - .map(|e| { - e.evaluate(batch2) - .and_then(|v| v.into_array(batch2.num_rows())) - }) - .collect::>>()?; - accum1.update_batch(&values1)?; - accum2.update_batch(&values2)?; - let state2 = get_accum_scalar_values_as_arrays(accum2.as_mut())?; - accum1.merge_batch(&state2)?; - accum1.evaluate() - } -} diff --git a/datafusion/physical-expr/src/aggregate/string_agg.rs b/datafusion/physical-expr/src/aggregate/string_agg.rs deleted file mode 100644 index dc0ffc557968..000000000000 --- a/datafusion/physical-expr/src/aggregate/string_agg.rs +++ /dev/null @@ -1,246 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! [`StringAgg`] and [`StringAggAccumulator`] accumulator for the `string_agg` function - -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::{format_state_name, Literal}; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field}; -use datafusion_common::cast::as_generic_string_array; -use datafusion_common::{not_impl_err, Result, ScalarValue}; -use datafusion_expr::Accumulator; -use std::any::Any; -use std::sync::Arc; - -/// STRING_AGG aggregate expression -#[derive(Debug)] -pub struct StringAgg { - name: String, - data_type: DataType, - expr: Arc, - delimiter: Arc, - nullable: bool, -} - -impl StringAgg { - /// Create a new StringAgg aggregate function - pub fn new( - expr: Arc, - delimiter: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - data_type, - delimiter, - expr, - nullable: true, - } - } -} - -impl AggregateExpr for StringAgg { - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new( - &self.name, - self.data_type.clone(), - self.nullable, - )) - } - - fn create_accumulator(&self) -> Result> { - if let Some(delimiter) = self.delimiter.as_any().downcast_ref::() { - match delimiter.value() { - ScalarValue::Utf8(Some(delimiter)) - | ScalarValue::LargeUtf8(Some(delimiter)) => { - return Ok(Box::new(StringAggAccumulator::new(delimiter))); - } - ScalarValue::Null => { - return Ok(Box::new(StringAggAccumulator::new(""))); - } - _ => return not_impl_err!("StringAgg not supported for {}", self.name), - } - } - not_impl_err!("StringAgg not supported for {}", self.name) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "string_agg"), - self.data_type.clone(), - self.nullable, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone(), self.delimiter.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for StringAgg { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.expr.eq(&x.expr) - && self.delimiter.eq(&x.delimiter) - }) - .unwrap_or(false) - } -} - -#[derive(Debug)] -pub(crate) struct StringAggAccumulator { - values: Option, - delimiter: String, -} - -impl StringAggAccumulator { - pub fn new(delimiter: &str) -> Self { - Self { - values: None, - delimiter: delimiter.to_string(), - } - } -} - -impl Accumulator for StringAggAccumulator { - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let string_array: Vec<_> = as_generic_string_array::(&values[0])? - .iter() - .filter_map(|v| v.as_ref().map(ToString::to_string)) - .collect(); - if !string_array.is_empty() { - let s = string_array.join(self.delimiter.as_str()); - let v = self.values.get_or_insert("".to_string()); - if !v.is_empty() { - v.push_str(self.delimiter.as_str()); - } - v.push_str(s.as_str()); - } - Ok(()) - } - - fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - self.update_batch(values)?; - Ok(()) - } - - fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?]) - } - - fn evaluate(&mut self) -> Result { - Ok(ScalarValue::LargeUtf8(self.values.clone())) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - + self.values.as_ref().map(|v| v.capacity()).unwrap_or(0) - + self.delimiter.capacity() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::expressions::tests::aggregate; - use crate::expressions::{col, create_aggregate_expr, try_cast}; - use arrow::datatypes::*; - use arrow::record_batch::RecordBatch; - use arrow_array::LargeStringArray; - use arrow_array::StringArray; - use datafusion_expr::type_coercion::aggregates::coerce_types; - use datafusion_expr::AggregateFunction; - - fn assert_string_aggregate( - array: ArrayRef, - function: AggregateFunction, - distinct: bool, - expected: ScalarValue, - delimiter: String, - ) { - let data_type = array.data_type(); - let sig = function.signature(); - let coerced = - coerce_types(&function, &[data_type.clone(), DataType::Utf8], &sig).unwrap(); - - let input_schema = Schema::new(vec![Field::new("a", data_type.clone(), true)]); - let batch = - RecordBatch::try_new(Arc::new(input_schema.clone()), vec![array]).unwrap(); - - let input = try_cast( - col("a", &input_schema).unwrap(), - &input_schema, - coerced[0].clone(), - ) - .unwrap(); - - let delimiter = Arc::new(Literal::new(ScalarValue::from(delimiter))); - let schema = Schema::new(vec![Field::new("a", coerced[0].clone(), true)]); - let agg = create_aggregate_expr( - &function, - distinct, - &[input, delimiter], - &[], - &schema, - "agg", - false, - ) - .unwrap(); - - let result = aggregate(&batch, agg).unwrap(); - assert_eq!(expected, result); - } - - #[test] - fn string_agg_utf8() { - let a: ArrayRef = Arc::new(StringArray::from(vec!["h", "e", "l", "l", "o"])); - assert_string_aggregate( - a, - AggregateFunction::StringAgg, - false, - ScalarValue::LargeUtf8(Some("h,e,l,l,o".to_owned())), - ",".to_owned(), - ); - } - - #[test] - fn string_agg_largeutf8() { - let a: ArrayRef = Arc::new(LargeStringArray::from(vec!["h", "e", "l", "l", "o"])); - assert_string_aggregate( - a, - AggregateFunction::StringAgg, - false, - ScalarValue::LargeUtf8(Some("h|e|l|l|o".to_owned())), - "|".to_owned(), - ); - } -} diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs deleted file mode 100644 index f19be62bbc95..000000000000 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ /dev/null @@ -1,402 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines `SUM` and `SUM DISTINCT` aggregate accumulators - -use std::any::Any; -use std::sync::Arc; - -use super::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::compute::sum; -use arrow::datatypes::DataType; -use arrow::{array::ArrayRef, datatypes::Field}; -use arrow_array::cast::AsArray; -use arrow_array::types::{ - Decimal128Type, Decimal256Type, Float64Type, Int64Type, UInt64Type, -}; -use arrow_array::{Array, ArrowNativeTypeOp, ArrowNumericType}; -use arrow_buffer::ArrowNativeType; -use datafusion_common::{not_impl_err, Result, ScalarValue}; -use datafusion_expr::type_coercion::aggregates::sum_return_type; -use datafusion_expr::{Accumulator, GroupsAccumulator}; - -/// SUM aggregate expression -#[derive(Debug, Clone)] -pub struct Sum { - name: String, - // The DataType for the input expression - data_type: DataType, - // The DataType for the final sum - return_type: DataType, - expr: Arc, - nullable: bool, -} - -impl Sum { - /// Create a new SUM aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - let return_type = sum_return_type(&data_type).unwrap(); - Self { - name: name.into(), - data_type, - return_type, - expr, - nullable: true, - } - } -} - -/// Sum only supports a subset of numeric types, instead relying on type coercion -/// -/// This macro is similar to [downcast_primitive](arrow_array::downcast_primitive) -/// -/// `s` is a `Sum`, `helper` is a macro accepting (ArrowPrimitiveType, DataType) -macro_rules! downcast_sum { - ($s:ident, $helper:ident) => { - match $s.return_type { - DataType::UInt64 => $helper!(UInt64Type, $s.return_type), - DataType::Int64 => $helper!(Int64Type, $s.return_type), - DataType::Float64 => $helper!(Float64Type, $s.return_type), - DataType::Decimal128(_, _) => $helper!(Decimal128Type, $s.return_type), - DataType::Decimal256(_, _) => $helper!(Decimal256Type, $s.return_type), - _ => not_impl_err!("Sum not supported for {}: {}", $s.name, $s.return_type), - } - }; -} -pub(crate) use downcast_sum; - -impl AggregateExpr for Sum { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new( - &self.name, - self.return_type.clone(), - self.nullable, - )) - } - - fn create_accumulator(&self) -> Result> { - macro_rules! helper { - ($t:ty, $dt:expr) => { - Ok(Box::new(SumAccumulator::<$t>::new($dt.clone()))) - }; - } - downcast_sum!(self, helper) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "sum"), - self.return_type.clone(), - self.nullable, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } - - fn groups_accumulator_supported(&self) -> bool { - true - } - - fn create_groups_accumulator(&self) -> Result> { - macro_rules! helper { - ($t:ty, $dt:expr) => { - Ok(Box::new(PrimitiveGroupsAccumulator::<$t, _>::new( - &$dt, - |x, y| *x = x.add_wrapping(y), - ))) - }; - } - downcast_sum!(self, helper) - } - - fn reverse_expr(&self) -> Option> { - Some(Arc::new(self.clone())) - } - - fn create_sliding_accumulator(&self) -> Result> { - macro_rules! helper { - ($t:ty, $dt:expr) => { - Ok(Box::new(SlidingSumAccumulator::<$t>::new($dt.clone()))) - }; - } - downcast_sum!(self, helper) - } -} - -impl PartialEq for Sum { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.nullable == x.nullable - && self.expr.eq(&x.expr) - }) - .unwrap_or(false) - } -} - -/// This accumulator computes SUM incrementally -struct SumAccumulator { - sum: Option, - data_type: DataType, -} - -impl std::fmt::Debug for SumAccumulator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "SumAccumulator({})", self.data_type) - } -} - -impl SumAccumulator { - fn new(data_type: DataType) -> Self { - Self { - sum: None, - data_type, - } - } -} - -impl Accumulator for SumAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = values[0].as_primitive::(); - if let Some(x) = sum(values) { - let v = self.sum.get_or_insert(T::Native::usize_as(0)); - *v = v.add_wrapping(x); - } - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - self.update_batch(states) - } - - fn evaluate(&mut self) -> Result { - ScalarValue::new_primitive::(self.sum, &self.data_type) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } -} - -/// This accumulator incrementally computes sums over a sliding window -/// -/// This is separate from [`SumAccumulator`] as requires additional state -struct SlidingSumAccumulator { - sum: T::Native, - count: u64, - data_type: DataType, -} - -impl std::fmt::Debug for SlidingSumAccumulator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "SlidingSumAccumulator({})", self.data_type) - } -} - -impl SlidingSumAccumulator { - fn new(data_type: DataType) -> Self { - Self { - sum: T::Native::usize_as(0), - count: 0, - data_type, - } - } -} - -impl Accumulator for SlidingSumAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?, self.count.into()]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = values[0].as_primitive::(); - self.count += (values.len() - values.null_count()) as u64; - if let Some(x) = sum(values) { - self.sum = self.sum.add_wrapping(x) - } - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let values = states[0].as_primitive::(); - if let Some(x) = sum(values) { - self.sum = self.sum.add_wrapping(x) - } - if let Some(x) = sum(states[1].as_primitive::()) { - self.count += x; - } - Ok(()) - } - - fn evaluate(&mut self) -> Result { - let v = (self.count != 0).then_some(self.sum); - ScalarValue::new_primitive::(v, &self.data_type) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } - - fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = values[0].as_primitive::(); - if let Some(x) = sum(values) { - self.sum = self.sum.sub_wrapping(x) - } - self.count -= (values.len() - values.null_count()) as u64; - Ok(()) - } - - fn supports_retract_batch(&self) -> bool { - true - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::expressions::tests::assert_aggregate; - use arrow_array::*; - use datafusion_expr::AggregateFunction; - - #[test] - fn sum_decimal() { - // test agg - let array: ArrayRef = Arc::new( - (1..6) - .map(Some) - .collect::() - .with_precision_and_scale(10, 0) - .unwrap(), - ); - - assert_aggregate( - array, - AggregateFunction::Sum, - false, - ScalarValue::Decimal128(Some(15), 20, 0), - ); - } - - #[test] - fn sum_decimal_with_nulls() { - // test agg - let array: ArrayRef = Arc::new( - (1..6) - .map(|i| if i == 2 { None } else { Some(i) }) - .collect::() - .with_precision_and_scale(35, 0) - .unwrap(), - ); - - assert_aggregate( - array, - AggregateFunction::Sum, - false, - ScalarValue::Decimal128(Some(13), 38, 0), - ); - } - - #[test] - fn sum_decimal_all_nulls() { - // test with batch - let array: ArrayRef = Arc::new( - std::iter::repeat::>(None) - .take(6) - .collect::() - .with_precision_and_scale(10, 0) - .unwrap(), - ); - - // test agg - assert_aggregate( - array, - AggregateFunction::Sum, - false, - ScalarValue::Decimal128(None, 20, 0), - ); - } - - #[test] - fn sum_i32() { - let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - assert_aggregate(a, AggregateFunction::Sum, false, ScalarValue::from(15i64)); - } - - #[test] - fn sum_i32_with_nulls() { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - None, - Some(3), - Some(4), - Some(5), - ])); - assert_aggregate(a, AggregateFunction::Sum, false, ScalarValue::from(13i64)); - } - - #[test] - fn sum_i32_all_nulls() { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - assert_aggregate(a, AggregateFunction::Sum, false, ScalarValue::Int64(None)); - } - - #[test] - fn sum_u32() { - let a: ArrayRef = - Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); - assert_aggregate(a, AggregateFunction::Sum, false, ScalarValue::from(15u64)); - } - - #[test] - fn sum_f32() { - let a: ArrayRef = - Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); - assert_aggregate(a, AggregateFunction::Sum, false, ScalarValue::from(15_f64)); - } - - #[test] - fn sum_f64() { - let a: ArrayRef = - Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - assert_aggregate(a, AggregateFunction::Sum, false, ScalarValue::from(15_f64)); - } -} diff --git a/datafusion/physical-expr/src/aggregate/sum_distinct.rs b/datafusion/physical-expr/src/aggregate/sum_distinct.rs deleted file mode 100644 index 09f3f9b498c1..000000000000 --- a/datafusion/physical-expr/src/aggregate/sum_distinct.rs +++ /dev/null @@ -1,283 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::expressions::format_state_name; -use arrow::datatypes::{DataType, Field}; -use std::any::Any; -use std::sync::Arc; - -use ahash::RandomState; -use arrow::array::{Array, ArrayRef}; -use arrow_array::cast::AsArray; -use arrow_array::types::*; -use arrow_array::ArrowNativeTypeOp; -use arrow_buffer::ArrowNativeType; -use std::collections::HashSet; - -use crate::aggregate::sum::downcast_sum; -use crate::aggregate::utils::{down_cast_any_ref, Hashable}; -use crate::{AggregateExpr, PhysicalExpr}; -use datafusion_common::{not_impl_err, Result, ScalarValue}; -use datafusion_expr::type_coercion::aggregates::sum_return_type; -use datafusion_expr::Accumulator; - -/// Expression for a SUM(DISTINCT) aggregation. -#[derive(Debug)] -pub struct DistinctSum { - /// Column name - name: String, - // The DataType for the input expression - data_type: DataType, - // The DataType for the final sum - return_type: DataType, - /// The input arguments, only contains 1 item for sum - exprs: Vec>, -} - -impl DistinctSum { - /// Create a SUM(DISTINCT) aggregate function. - pub fn new( - exprs: Vec>, - name: String, - data_type: DataType, - ) -> Self { - let return_type = sum_return_type(&data_type).unwrap(); - Self { - name, - data_type, - return_type, - exprs, - } - } -} - -impl AggregateExpr for DistinctSum { - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.return_type.clone(), true)) - } - - fn state_fields(&self) -> Result> { - // State field is a List which stores items to rebuild hash set. - Ok(vec![Field::new_list( - format_state_name(&self.name, "sum distinct"), - Field::new("item", self.return_type.clone(), true), - false, - )]) - } - - fn expressions(&self) -> Vec> { - self.exprs.clone() - } - - fn name(&self) -> &str { - &self.name - } - - fn create_accumulator(&self) -> Result> { - macro_rules! helper { - ($t:ty, $dt:expr) => { - Ok(Box::new(DistinctSumAccumulator::<$t>::try_new(&$dt)?)) - }; - } - downcast_sum!(self, helper) - } -} - -impl PartialEq for DistinctSum { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.exprs.len() == x.exprs.len() - && self - .exprs - .iter() - .zip(x.exprs.iter()) - .all(|(this, other)| this.eq(other)) - }) - .unwrap_or(false) - } -} - -struct DistinctSumAccumulator { - values: HashSet, RandomState>, - data_type: DataType, -} - -impl std::fmt::Debug for DistinctSumAccumulator { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "DistinctSumAccumulator({})", self.data_type) - } -} - -impl DistinctSumAccumulator { - pub fn try_new(data_type: &DataType) -> Result { - Ok(Self { - values: HashSet::default(), - data_type: data_type.clone(), - }) - } -} - -impl Accumulator for DistinctSumAccumulator { - fn state(&mut self) -> Result> { - // 1. Stores aggregate state in `ScalarValue::List` - // 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set - let state_out = { - let distinct_values = self - .values - .iter() - .map(|value| { - ScalarValue::new_primitive::(Some(value.0), &self.data_type) - }) - .collect::>>()?; - - vec![ScalarValue::List(ScalarValue::new_list( - &distinct_values, - &self.data_type, - ))] - }; - Ok(state_out) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - - let array = values[0].as_primitive::(); - match array.nulls().filter(|x| x.null_count() > 0) { - Some(n) => { - for idx in n.valid_indices() { - self.values.insert(Hashable(array.value(idx))); - } - } - None => array.values().iter().for_each(|x| { - self.values.insert(Hashable(*x)); - }), - } - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - for x in states[0].as_list::().iter().flatten() { - self.update_batch(&[x])? - } - Ok(()) - } - - fn evaluate(&mut self) -> Result { - let mut acc = T::Native::usize_as(0); - for distinct_value in self.values.iter() { - acc = acc.add_wrapping(distinct_value.0) - } - let v = (!self.values.is_empty()).then_some(acc); - ScalarValue::new_primitive::(v, &self.data_type) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - + self.values.capacity() * std::mem::size_of::() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::expressions::tests::assert_aggregate; - use arrow::array::*; - use datafusion_expr::AggregateFunction; - - fn run_update_batch( - return_type: DataType, - arrays: &[ArrayRef], - ) -> Result<(Vec, ScalarValue)> { - let agg = DistinctSum::new(vec![], String::from("__col_name__"), return_type); - - let mut accum = agg.create_accumulator()?; - accum.update_batch(arrays)?; - - Ok((accum.state()?, accum.evaluate()?)) - } - - #[test] - fn sum_distinct_update_batch() -> Result<()> { - let array_int64: ArrayRef = Arc::new(Int64Array::from(vec![1, 1, 3])); - let arrays = vec![array_int64]; - let (states, result) = run_update_batch(DataType::Int64, &arrays)?; - - assert_eq!(states.len(), 1); - assert_eq!(result, ScalarValue::Int64(Some(4))); - - Ok(()) - } - - #[test] - fn sum_distinct_i32_with_nulls() { - let array = Arc::new(Int32Array::from(vec![ - Some(1), - Some(1), - None, - Some(2), - Some(2), - Some(3), - ])); - assert_aggregate(array, AggregateFunction::Sum, true, 6_i64.into()); - } - - #[test] - fn sum_distinct_u32_with_nulls() { - let array: ArrayRef = Arc::new(UInt32Array::from(vec![ - Some(1_u32), - Some(1_u32), - Some(3_u32), - Some(3_u32), - None, - ])); - assert_aggregate(array, AggregateFunction::Sum, true, 4_u64.into()); - } - - #[test] - fn sum_distinct_f64() { - let array: ArrayRef = - Arc::new(Float64Array::from(vec![1_f64, 1_f64, 3_f64, 3_f64, 3_f64])); - assert_aggregate(array, AggregateFunction::Sum, true, 4_f64.into()); - } - - #[test] - fn sum_distinct_decimal_with_nulls() { - let array: ArrayRef = Arc::new( - (1..6) - .map(|i| if i == 2 { None } else { Some(i % 2) }) - .collect::() - .with_precision_and_scale(35, 0) - .unwrap(), - ); - assert_aggregate( - array, - AggregateFunction::Sum, - true, - ScalarValue::Decimal128(Some(1), 38, 0), - ); - } -} diff --git a/datafusion/physical-expr/src/aggregate/variance.rs b/datafusion/physical-expr/src/aggregate/variance.rs deleted file mode 100644 index 989041097730..000000000000 --- a/datafusion/physical-expr/src/aggregate/variance.rs +++ /dev/null @@ -1,538 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expressions that can evaluated at runtime during query execution - -use std::any::Any; -use std::sync::Arc; - -use crate::aggregate::stats::StatsType; -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::array::Float64Array; -use arrow::{ - array::{ArrayRef, UInt64Array}, - compute::cast, - datatypes::DataType, - datatypes::Field, -}; -use datafusion_common::downcast_value; -use datafusion_common::{DataFusionError, Result, ScalarValue}; -use datafusion_expr::Accumulator; - -/// VAR and VAR_SAMP aggregate expression -#[derive(Debug)] -pub struct Variance { - name: String, - expr: Arc, -} - -/// VAR_POP aggregate expression -#[derive(Debug)] -pub struct VariancePop { - name: String, - expr: Arc, -} - -impl Variance { - /// Create a new VARIANCE aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - // the result of variance just support FLOAT64 data type. - assert!(matches!(data_type, DataType::Float64)); - Self { - name: name.into(), - expr, - } - } -} - -impl AggregateExpr for Variance { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64, true)) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?)) - } - - fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![ - Field::new( - format_state_name(&self.name, "count"), - DataType::UInt64, - true, - ), - Field::new( - format_state_name(&self.name, "mean"), - DataType::Float64, - true, - ), - Field::new(format_state_name(&self.name, "m2"), DataType::Float64, true), - ]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for Variance { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| self.name == x.name && self.expr.eq(&x.expr)) - .unwrap_or(false) - } -} - -impl VariancePop { - /// Create a new VAR_POP aggregate function - pub fn new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - // the result of variance just support FLOAT64 data type. - assert!(matches!(data_type, DataType::Float64)); - Self { - name: name.into(), - expr, - } - } -} - -impl AggregateExpr for VariancePop { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64, true)) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(VarianceAccumulator::try_new( - StatsType::Population, - )?)) - } - - fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(VarianceAccumulator::try_new( - StatsType::Population, - )?)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![ - Field::new( - format_state_name(&self.name, "count"), - DataType::UInt64, - true, - ), - Field::new( - format_state_name(&self.name, "mean"), - DataType::Float64, - true, - ), - Field::new(format_state_name(&self.name, "m2"), DataType::Float64, true), - ]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for VariancePop { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| self.name == x.name && self.expr.eq(&x.expr)) - .unwrap_or(false) - } -} - -/// An accumulator to compute variance -/// The algrithm used is an online implementation and numerically stable. It is based on this paper: -/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products". -/// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577. -/// -/// The algorithm has been analyzed here: -/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances". -/// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154. - -#[derive(Debug)] -pub struct VarianceAccumulator { - m2: f64, - mean: f64, - count: u64, - stats_type: StatsType, -} - -impl VarianceAccumulator { - /// Creates a new `VarianceAccumulator` - pub fn try_new(s_type: StatsType) -> Result { - Ok(Self { - m2: 0_f64, - mean: 0_f64, - count: 0_u64, - stats_type: s_type, - }) - } - - pub fn get_count(&self) -> u64 { - self.count - } - - pub fn get_mean(&self) -> f64 { - self.mean - } - - pub fn get_m2(&self) -> f64 { - self.m2 - } -} - -impl Accumulator for VarianceAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![ - ScalarValue::from(self.count), - ScalarValue::from(self.mean), - ScalarValue::from(self.m2), - ]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &cast(&values[0], &DataType::Float64)?; - let arr = downcast_value!(values, Float64Array).iter().flatten(); - - for value in arr { - let new_count = self.count + 1; - let delta1 = value - self.mean; - let new_mean = delta1 / new_count as f64 + self.mean; - let delta2 = value - new_mean; - let new_m2 = self.m2 + delta1 * delta2; - - self.count += 1; - self.mean = new_mean; - self.m2 = new_m2; - } - - Ok(()) - } - - fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &cast(&values[0], &DataType::Float64)?; - let arr = downcast_value!(values, Float64Array).iter().flatten(); - - for value in arr { - let new_count = self.count - 1; - let delta1 = self.mean - value; - let new_mean = delta1 / new_count as f64 + self.mean; - let delta2 = new_mean - value; - let new_m2 = self.m2 - delta1 * delta2; - - self.count -= 1; - self.mean = new_mean; - self.m2 = new_m2; - } - - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let counts = downcast_value!(states[0], UInt64Array); - let means = downcast_value!(states[1], Float64Array); - let m2s = downcast_value!(states[2], Float64Array); - - for i in 0..counts.len() { - let c = counts.value(i); - if c == 0_u64 { - continue; - } - let new_count = self.count + c; - let new_mean = self.mean * self.count as f64 / new_count as f64 - + means.value(i) * c as f64 / new_count as f64; - let delta = self.mean - means.value(i); - let new_m2 = self.m2 - + m2s.value(i) - + delta * delta * self.count as f64 * c as f64 / new_count as f64; - - self.count = new_count; - self.mean = new_mean; - self.m2 = new_m2; - } - Ok(()) - } - - fn evaluate(&mut self) -> Result { - let count = match self.stats_type { - StatsType::Population => self.count, - StatsType::Sample => { - if self.count > 0 { - self.count - 1 - } else { - self.count - } - } - }; - - Ok(ScalarValue::Float64(match self.count { - 0 => None, - 1 => { - if let StatsType::Population = self.stats_type { - Some(0.0) - } else { - None - } - } - _ => Some(self.m2 / count as f64), - })) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::aggregate::utils::get_accum_scalar_values_as_arrays; - use crate::expressions::col; - use crate::expressions::tests::aggregate; - use crate::generic_test_op; - use arrow::{array::*, datatypes::*}; - - #[test] - fn variance_f64_1() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64])); - generic_test_op!( - a, - DataType::Float64, - VariancePop, - ScalarValue::from(0.25_f64) - ) - } - - #[test] - fn variance_f64_2() -> Result<()> { - let a: ArrayRef = - Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - generic_test_op!(a, DataType::Float64, VariancePop, ScalarValue::from(2_f64)) - } - - #[test] - fn variance_f64_3() -> Result<()> { - let a: ArrayRef = - Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - generic_test_op!(a, DataType::Float64, Variance, ScalarValue::from(2.5_f64)) - } - - #[test] - fn variance_f64_4() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64])); - generic_test_op!( - a, - DataType::Float64, - Variance, - ScalarValue::from(0.9033333333333333_f64) - ) - } - - #[test] - fn variance_i32() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!(a, DataType::Int32, VariancePop, ScalarValue::from(2_f64)) - } - - #[test] - fn variance_u32() -> Result<()> { - let a: ArrayRef = - Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); - generic_test_op!(a, DataType::UInt32, VariancePop, ScalarValue::from(2.0f64)) - } - - #[test] - fn variance_f32() -> Result<()> { - let a: ArrayRef = - Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); - generic_test_op!(a, DataType::Float32, VariancePop, ScalarValue::from(2_f64)) - } - - #[test] - fn test_variance_1_input() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64])); - let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; - - let agg = Arc::new(Variance::new( - col("a", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - let actual = aggregate(&batch, agg).unwrap(); - assert_eq!(actual, ScalarValue::Float64(None)); - - Ok(()) - } - - #[test] - fn variance_i32_with_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - None, - Some(3), - Some(4), - Some(5), - ])); - generic_test_op!( - a, - DataType::Int32, - VariancePop, - ScalarValue::from(2.1875_f64) - ) - } - - #[test] - fn variance_i32_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; - - let agg = Arc::new(Variance::new( - col("a", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - let actual = aggregate(&batch, agg).unwrap(); - assert_eq!(actual, ScalarValue::Float64(None)); - - Ok(()) - } - - #[test] - fn variance_f64_merge_1() -> Result<()> { - let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64])); - let b = Arc::new(Float64Array::from(vec![4_f64, 5_f64])); - - let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); - - let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; - let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![b])?; - - let agg1 = Arc::new(VariancePop::new( - col("a", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - - let agg2 = Arc::new(VariancePop::new( - col("a", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - - let actual = merge(&batch1, &batch2, agg1, agg2)?; - assert!(actual == ScalarValue::from(2_f64)); - - Ok(()) - } - - #[test] - fn variance_f64_merge_2() -> Result<()> { - let a = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - let b = Arc::new(Float64Array::from(vec![None])); - - let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); - - let batch1 = RecordBatch::try_new(Arc::new(schema.clone()), vec![a])?; - let batch2 = RecordBatch::try_new(Arc::new(schema.clone()), vec![b])?; - - let agg1 = Arc::new(VariancePop::new( - col("a", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - - let agg2 = Arc::new(VariancePop::new( - col("a", &schema)?, - "bla".to_string(), - DataType::Float64, - )); - - let actual = merge(&batch1, &batch2, agg1, agg2)?; - assert!(actual == ScalarValue::from(2_f64)); - - Ok(()) - } - - fn merge( - batch1: &RecordBatch, - batch2: &RecordBatch, - agg1: Arc, - agg2: Arc, - ) -> Result { - let mut accum1 = agg1.create_accumulator()?; - let mut accum2 = agg2.create_accumulator()?; - let expr1 = agg1.expressions(); - let expr2 = agg2.expressions(); - - let values1 = expr1 - .iter() - .map(|e| { - e.evaluate(batch1) - .and_then(|v| v.into_array(batch1.num_rows())) - }) - .collect::>>()?; - let values2 = expr2 - .iter() - .map(|e| { - e.evaluate(batch2) - .and_then(|v| v.into_array(batch2.num_rows())) - }) - .collect::>>()?; - accum1.update_batch(&values1)?; - accum2.update_batch(&values2)?; - let state2 = get_accum_scalar_values_as_arrays(accum2.as_mut())?; - accum1.merge_batch(&state2)?; - accum1.evaluate() - } -} diff --git a/datafusion/physical-expr/src/analysis.rs b/datafusion/physical-expr/src/analysis.rs index ca25bfd647b6..3eac62a4df08 100644 --- a/datafusion/physical-expr/src/analysis.rs +++ b/datafusion/physical-expr/src/analysis.rs @@ -27,7 +27,9 @@ use crate::PhysicalExpr; use arrow::datatypes::Schema; use datafusion_common::stats::Precision; -use datafusion_common::{internal_err, ColumnStatistics, Result, ScalarValue}; +use datafusion_common::{ + internal_datafusion_err, internal_err, ColumnStatistics, Result, ScalarValue, +}; use datafusion_expr::interval_arithmetic::{cardinality_ratio, Interval}; /// The shared context used during the analysis of an expression. Includes @@ -92,7 +94,13 @@ impl ExprBoundaries { col_stats: &ColumnStatistics, col_index: usize, ) -> Result { - let field = &schema.fields()[col_index]; + let field = schema.fields().get(col_index).ok_or_else(|| { + internal_datafusion_err!( + "Could not create `ExprBoundaries`: in `try_from_column` `col_index` + has gone out of bounds with a value of {col_index}, the schema has {} columns.", + schema.fields.len() + ) + })?; let empty_field = ScalarValue::try_from(field.data_type()).unwrap_or(ScalarValue::Null); let interval = Interval::try_new( @@ -111,7 +119,7 @@ impl ExprBoundaries { Ok(ExprBoundaries { column, interval, - distinct_count: col_stats.distinct_count.clone(), + distinct_count: col_stats.distinct_count, }) } @@ -155,7 +163,7 @@ pub fn analyze( ) -> Result { let target_boundaries = context.boundaries; - let mut graph = ExprIntervalGraph::try_new(expr.clone(), schema)?; + let mut graph = ExprIntervalGraph::try_new(Arc::clone(expr), schema)?; let columns = collect_columns(expr) .into_iter() diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index 9ea456b0f879..7305bc1b0a2b 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::fmt::Display; use std::sync::Arc; use super::{add_offset_to_expr, collapse_lex_req, ProjectionMapping}; @@ -27,6 +28,157 @@ use crate::{ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::JoinType; +use datafusion_physical_expr_common::physical_expr::format_physical_expr_list; + +/// A structure representing a expression known to be constant in a physical execution plan. +/// +/// The `ConstExpr` struct encapsulates an expression that is constant during the execution +/// of a query. For example if a predicate like `A = 5` applied earlier in the plan `A` would +/// be known constant +/// +/// # Fields +/// +/// - `expr`: Constant expression for a node in the physical plan. +/// +/// - `across_partitions`: A boolean flag indicating whether the constant +/// expression is the same across partitions. If set to `true`, the constant +/// expression has same value for all partitions. If set to `false`, the +/// constant expression may have different values for different partitions. +/// +/// # Example +/// +/// ```rust +/// # use datafusion_physical_expr::ConstExpr; +/// # use datafusion_physical_expr::expressions::lit; +/// let col = lit(5); +/// // Create a constant expression from a physical expression ref +/// let const_expr = ConstExpr::from(&col); +/// // create a constant expression from a physical expression +/// let const_expr = ConstExpr::from(col); +/// ``` +#[derive(Debug, Clone)] +pub struct ConstExpr { + /// The expression that is known to be constant (e.g. a `Column`) + expr: Arc, + /// Does the constant have the same value across all partitions? See + /// struct docs for more details + across_partitions: bool, +} + +impl PartialEq for ConstExpr { + fn eq(&self, other: &Self) -> bool { + self.across_partitions == other.across_partitions + && self.expr.eq(other.expr.as_any()) + } +} + +impl ConstExpr { + /// Create a new constant expression from a physical expression. + /// + /// Note you can also use `ConstExpr::from` to create a constant expression + /// from a reference as well + pub fn new(expr: Arc) -> Self { + Self { + expr, + // By default, assume constant expressions are not same across partitions. + across_partitions: false, + } + } + + /// Set the `across_partitions` flag + /// + /// See struct docs for more details + pub fn with_across_partitions(mut self, across_partitions: bool) -> Self { + self.across_partitions = across_partitions; + self + } + + /// Is the expression the same across all partitions? + /// + /// See struct docs for more details + pub fn across_partitions(&self) -> bool { + self.across_partitions + } + + pub fn expr(&self) -> &Arc { + &self.expr + } + + pub fn owned_expr(self) -> Arc { + self.expr + } + + pub fn map(&self, f: F) -> Option + where + F: Fn(&Arc) -> Option>, + { + let maybe_expr = f(&self.expr); + maybe_expr.map(|expr| Self { + expr, + across_partitions: self.across_partitions, + }) + } + + /// Returns true if this constant expression is equal to the given expression + pub fn eq_expr(&self, other: impl AsRef) -> bool { + self.expr.eq(other.as_ref().as_any()) + } + + /// Returns a [`Display`]able list of `ConstExpr`. + pub fn format_list(input: &[ConstExpr]) -> impl Display + '_ { + struct DisplayableList<'a>(&'a [ConstExpr]); + impl<'a> Display for DisplayableList<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let mut first = true; + for const_expr in self.0 { + if first { + first = false; + } else { + write!(f, ",")?; + } + write!(f, "{}", const_expr)?; + } + Ok(()) + } + } + DisplayableList(input) + } +} + +/// Display implementation for `ConstExpr` +/// +/// Example `c` or `c(across_partitions)` +impl Display for ConstExpr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}", self.expr)?; + if self.across_partitions { + write!(f, "(across_partitions)")?; + } + Ok(()) + } +} + +impl From> for ConstExpr { + fn from(expr: Arc) -> Self { + Self::new(expr) + } +} + +impl From<&Arc> for ConstExpr { + fn from(expr: &Arc) -> Self { + Self::new(Arc::clone(expr)) + } +} + +/// Checks whether `expr` is among in the `const_exprs`. +pub fn const_exprs_contains( + const_exprs: &[ConstExpr], + expr: &Arc, +) -> bool { + const_exprs + .iter() + .any(|const_expr| const_expr.expr.eq(expr)) +} /// An `EquivalenceClass` is a set of [`Arc`]s that are known /// to have the same value for all tuples in a relation. These are generated by @@ -129,6 +281,12 @@ impl EquivalenceClass { } } +impl Display for EquivalenceClass { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "[{}]", format_physical_expr_list(&self.exprs)) + } +} + /// An `EquivalenceGroup` is a collection of `EquivalenceClass`es where each /// class represents a distinct equivalence class in a relation. #[derive(Debug, Clone)] @@ -200,17 +358,19 @@ impl EquivalenceGroup { } (Some(group_idx), None) => { // Right side is new, extend left side's class: - self.classes[group_idx].push(right.clone()); + self.classes[group_idx].push(Arc::clone(right)); } (None, Some(group_idx)) => { // Left side is new, extend right side's class: - self.classes[group_idx].push(left.clone()); + self.classes[group_idx].push(Arc::clone(left)); } (None, None) => { // None of the expressions is among existing classes. // Create a new equivalence class and extend the group. - self.classes - .push(EquivalenceClass::new(vec![left.clone(), right.clone()])); + self.classes.push(EquivalenceClass::new(vec![ + Arc::clone(left), + Arc::clone(right), + ])); } } } @@ -261,7 +421,7 @@ impl EquivalenceGroup { /// The expression is replaced with the first expression in the equivalence /// class it matches with (if any). pub fn normalize_expr(&self, expr: Arc) -> Arc { - expr.clone() + Arc::clone(&expr) .transform(|expr| { for cls in self.iter() { if cls.contains(&expr) { @@ -321,7 +481,7 @@ impl EquivalenceGroup { // Normalize the requirements: let normalized_sort_reqs = self.normalize_sort_requirements(&sort_reqs); // Convert sort requirements back to sort expressions: - PhysicalSortRequirement::to_sort_exprs(normalized_sort_reqs) + PhysicalSortRequirement::to_sort_exprs(normalized_sort_reqs.inner) } /// This function applies the `normalize_sort_requirement` function for all @@ -331,12 +491,12 @@ impl EquivalenceGroup { &self, sort_reqs: LexRequirementRef, ) -> LexRequirement { - collapse_lex_req( + collapse_lex_req(LexRequirement::new( sort_reqs .iter() .map(|sort_req| self.normalize_sort_requirement(sort_req.clone())) .collect(), - ) + )) } /// Projects `expr` according to the given projection mapping. @@ -362,7 +522,7 @@ impl EquivalenceGroup { .get_equivalence_class(source) .map_or(false, |group| group.contains(expr)) { - return Some(target.clone()); + return Some(Arc::clone(target)); } } } @@ -374,9 +534,9 @@ impl EquivalenceGroup { } children .into_iter() - .map(|child| self.project_expr(mapping, &child)) + .map(|child| self.project_expr(mapping, child)) .collect::>>() - .map(|children| expr.clone().with_new_children(children).unwrap()) + .map(|children| Arc::clone(expr).with_new_children(children).unwrap()) } /// Projects this equivalence group according to the given projection mapping. @@ -394,13 +554,13 @@ impl EquivalenceGroup { let mut new_classes = vec![]; for (source, target) in mapping.iter() { if new_classes.is_empty() { - new_classes.push((source, vec![target.clone()])); + new_classes.push((source, vec![Arc::clone(target)])); } if let Some((_, values)) = new_classes.iter_mut().find(|(key, _)| key.eq(source)) { if !physical_exprs_contains(values, target) { - values.push(target.clone()); + values.push(Arc::clone(target)); } } } @@ -448,10 +608,9 @@ impl EquivalenceGroup { // are equal in the resulting table. if join_type == &JoinType::Inner { for (lhs, rhs) in on.iter() { - let new_lhs = lhs.clone() as _; + let new_lhs = Arc::clone(lhs) as _; // Rewrite rhs to point to the right side of the join: - let new_rhs = rhs - .clone() + let new_rhs = Arc::clone(rhs) .transform(|expr| { if let Some(column) = expr.as_any().downcast_ref::() @@ -473,12 +632,26 @@ impl EquivalenceGroup { } result } - JoinType::LeftSemi | JoinType::LeftAnti => self.clone(), + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => self.clone(), JoinType::RightSemi | JoinType::RightAnti => right_equivalences.clone(), } } } +impl Display for EquivalenceGroup { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "[")?; + let mut iter = self.iter(); + if let Some(cls) = iter.next() { + write!(f, "{}", cls)?; + } + for cls in iter { + write!(f, ", {}", cls)?; + } + write!(f, "]") + } +} + #[cfg(test)] mod tests { @@ -582,7 +755,7 @@ mod tests { let eq_group = eq_properties.eq_group(); for (expr, expected_eq) in expressions { assert!( - expected_eq.eq(&eq_group.normalize_expr(expr.clone())), + expected_eq.eq(&eq_group.normalize_expr(Arc::clone(expr))), "error in test: expr: {expr:?}" ); } @@ -602,9 +775,11 @@ mod tests { Arc::new(Literal::new(ScalarValue::Int32(Some(1)))) as Arc; let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; - let cls1 = EquivalenceClass::new(vec![lit_true.clone(), lit_false.clone()]); - let cls2 = EquivalenceClass::new(vec![lit_true.clone(), col_b_expr.clone()]); - let cls3 = EquivalenceClass::new(vec![lit2.clone(), lit1.clone()]); + let cls1 = + EquivalenceClass::new(vec![Arc::clone(&lit_true), Arc::clone(&lit_false)]); + let cls2 = + EquivalenceClass::new(vec![Arc::clone(&lit_true), Arc::clone(&col_b_expr)]); + let cls3 = EquivalenceClass::new(vec![Arc::clone(&lit2), Arc::clone(&lit1)]); // lit_true is common assert!(cls1.contains_any(&cls2)); diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs index 7fc27172e431..902e53a7f236 100644 --- a/datafusion/physical-expr/src/equivalence/mod.rs +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -18,7 +18,6 @@ use std::sync::Arc; use crate::expressions::Column; -use crate::sort_properties::SortProperties; use crate::{LexRequirement, PhysicalExpr, PhysicalSortRequirement}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; @@ -28,10 +27,12 @@ mod ordering; mod projection; mod properties; -pub use class::{EquivalenceClass, EquivalenceGroup}; +pub use class::{ConstExpr, EquivalenceClass, EquivalenceGroup}; pub use ordering::OrderingEquivalenceClass; pub use projection::ProjectionMapping; -pub use properties::{join_equivalence_properties, EquivalenceProperties}; +pub use properties::{ + calculate_union, join_equivalence_properties, EquivalenceProperties, +}; /// This function constructs a duplicate-free `LexOrderingReq` by filtering out /// duplicate entries that have same physical expression inside. For example, @@ -47,48 +48,7 @@ pub fn collapse_lex_req(input: LexRequirement) -> LexRequirement { output.push(item); } } - collapse_monotonic_lex_req(output) -} - -/// This function constructs a normalized [`LexRequirement`] by filtering out entries -/// that are ordered if the next entry is. -/// Used in `collapse_lex_req` -fn collapse_monotonic_lex_req(input: LexRequirement) -> LexRequirement { - input - .iter() - .enumerate() - .filter_map(|(i, item)| { - // If it's the last entry, there is no next entry - if i == input.len() - 1 { - return Some(item); - } - let next_expr = &input[i + 1]; - - // Only handle expressions with exactly one child - // TODO: it should be possible to handle expressions orderings f(a, b, c), a, b, c - // if f is monotonic in all arguments - if !(item.expr.children().len() == 1 - && item.expr.children()[0].eq(&next_expr.expr)) - { - return Some(item); - } - - let opts = match next_expr.options { - None => return Some(item), - Some(opts) => opts, - }; - - if item.options.map(SortProperties::Ordered) - == Some(item.expr.get_ordering(&[SortProperties::Ordered(opts)])) - { - // Remove the redundant sort - return None; - } - - Some(item) - }) - .cloned() - .collect::>() + LexRequirement::new(output) } /// Adds the `offset` value to `Column` indices inside `expr`. This function is @@ -117,16 +77,10 @@ mod tests { use crate::expressions::col; use crate::PhysicalSortExpr; - use arrow::compute::{lexsort_to_indices, SortColumn}; use arrow::datatypes::{DataType, Field, Schema}; - use arrow_array::{ArrayRef, Float64Array, RecordBatch, UInt32Array}; use arrow_schema::{SchemaRef, SortOptions}; use datafusion_common::{plan_datafusion_err, Result}; - - use itertools::izip; - use rand::rngs::StdRng; - use rand::seq::SliceRandom; - use rand::{Rng, SeedableRng}; + use datafusion_physical_expr_common::sort_expr::LexOrdering; pub fn output_schema( mapping: &ProjectionMapping, @@ -188,8 +142,8 @@ mod tests { let col_e = &col("e", &test_schema)?; let col_f = &col("f", &test_schema)?; let col_g = &col("g", &test_schema)?; - let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); - eq_properties.add_equal_conditions(col_a, col_c); + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); + eq_properties.add_equal_conditions(col_a, col_c)?; let option_asc = SortOptions { descending: false, @@ -216,75 +170,14 @@ mod tests { Ok((test_schema, eq_properties)) } - // Generate a schema which consists of 6 columns (a, b, c, d, e, f) - fn create_test_schema_2() -> Result { - let a = Field::new("a", DataType::Float64, true); - let b = Field::new("b", DataType::Float64, true); - let c = Field::new("c", DataType::Float64, true); - let d = Field::new("d", DataType::Float64, true); - let e = Field::new("e", DataType::Float64, true); - let f = Field::new("f", DataType::Float64, true); - let schema = Arc::new(Schema::new(vec![a, b, c, d, e, f])); - - Ok(schema) - } - - /// Construct a schema with random ordering - /// among column a, b, c, d - /// where - /// Column [a=f] (e.g they are aliases). - /// Column e is constant. - pub fn create_random_schema(seed: u64) -> Result<(SchemaRef, EquivalenceProperties)> { - let test_schema = create_test_schema_2()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_f = &col("f", &test_schema)?; - let col_exprs = [col_a, col_b, col_c, col_d, col_e, col_f]; - - let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); - // Define a and f are aliases - eq_properties.add_equal_conditions(col_a, col_f); - // Column e has constant value. - eq_properties = eq_properties.add_constants([col_e.clone()]); - - // Randomly order columns for sorting - let mut rng = StdRng::seed_from_u64(seed); - let mut remaining_exprs = col_exprs[0..4].to_vec(); // only a, b, c, d are sorted - - let options_asc = SortOptions { - descending: false, - nulls_first: false, - }; - - while !remaining_exprs.is_empty() { - let n_sort_expr = rng.gen_range(0..remaining_exprs.len() + 1); - remaining_exprs.shuffle(&mut rng); - - let ordering = remaining_exprs - .drain(0..n_sort_expr) - .map(|expr| PhysicalSortExpr { - expr: expr.clone(), - options: options_asc, - }) - .collect(); - - eq_properties.add_new_orderings([ordering]); - } - - Ok((test_schema, eq_properties)) - } - // Convert each tuple to PhysicalSortRequirement pub fn convert_to_sort_reqs( in_data: &[(&Arc, Option)], - ) -> Vec { + ) -> LexRequirement { in_data .iter() .map(|(expr, options)| { - PhysicalSortRequirement::new((*expr).clone(), *options) + PhysicalSortRequirement::new(Arc::clone(*expr), *options) }) .collect() } @@ -292,11 +185,11 @@ mod tests { // Convert each tuple to PhysicalSortExpr pub fn convert_to_sort_exprs( in_data: &[(&Arc, SortOptions)], - ) -> Vec { + ) -> LexOrdering { in_data .iter() .map(|(expr, options)| PhysicalSortExpr { - expr: (*expr).clone(), + expr: Arc::clone(*expr), options: *options, }) .collect() @@ -305,7 +198,7 @@ mod tests { // Convert each inner tuple to PhysicalSortExpr pub fn convert_to_orderings( orderings: &[Vec<(&Arc, SortOptions)>], - ) -> Vec> { + ) -> Vec { orderings .iter() .map(|sort_exprs| convert_to_sort_exprs(sort_exprs)) @@ -315,53 +208,28 @@ mod tests { // Convert each tuple to PhysicalSortExpr pub fn convert_to_sort_exprs_owned( in_data: &[(Arc, SortOptions)], - ) -> Vec { - in_data - .iter() - .map(|(expr, options)| PhysicalSortExpr { - expr: (*expr).clone(), - options: *options, - }) - .collect() + ) -> LexOrdering { + LexOrdering::new( + in_data + .iter() + .map(|(expr, options)| PhysicalSortExpr { + expr: Arc::clone(expr), + options: *options, + }) + .collect(), + ) } // Convert each inner tuple to PhysicalSortExpr pub fn convert_to_orderings_owned( orderings: &[Vec<(Arc, SortOptions)>], - ) -> Vec> { + ) -> Vec { orderings .iter() .map(|sort_exprs| convert_to_sort_exprs_owned(sort_exprs)) .collect() } - // Apply projection to the input_data, return projected equivalence properties and record batch - pub fn apply_projection( - proj_exprs: Vec<(Arc, String)>, - input_data: &RecordBatch, - input_eq_properties: &EquivalenceProperties, - ) -> Result<(RecordBatch, EquivalenceProperties)> { - let input_schema = input_data.schema(); - let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; - - let output_schema = output_schema(&projection_mapping, &input_schema)?; - let num_rows = input_data.num_rows(); - // Apply projection to the input record batch. - let projected_values = projection_mapping - .iter() - .map(|(source, _target)| source.evaluate(input_data)?.into_array(num_rows)) - .collect::>>()?; - let projected_batch = if projected_values.is_empty() { - RecordBatch::new_empty(output_schema.clone()) - } else { - RecordBatch::try_new(output_schema.clone(), projected_values)? - }; - - let projected_eq = - input_eq_properties.project(&projection_mapping, output_schema); - Ok((projected_batch, projected_eq)) - } - #[test] fn add_equal_conditions_test() -> Result<()> { let schema = Arc::new(Schema::new(vec![ @@ -380,11 +248,11 @@ mod tests { let col_y_expr = Arc::new(Column::new("y", 4)) as Arc; // a and b are aliases - eq_properties.add_equal_conditions(&col_a_expr, &col_b_expr); + eq_properties.add_equal_conditions(&col_a_expr, &col_b_expr)?; assert_eq!(eq_properties.eq_group().len(), 1); // This new entry is redundant, size shouldn't increase - eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr); + eq_properties.add_equal_conditions(&col_b_expr, &col_a_expr)?; assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = &eq_properties.eq_group().classes[0]; assert_eq!(eq_groups.len(), 2); @@ -393,7 +261,7 @@ mod tests { // b and c are aliases. Exising equivalence class should expand, // however there shouldn't be any new equivalence class - eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr); + eq_properties.add_equal_conditions(&col_b_expr, &col_c_expr)?; assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = &eq_properties.eq_group().classes[0]; assert_eq!(eq_groups.len(), 3); @@ -402,12 +270,12 @@ mod tests { assert!(eq_groups.contains(&col_c_expr)); // This is a new set of equality. Hence equivalent class count should be 2. - eq_properties.add_equal_conditions(&col_x_expr, &col_y_expr); + eq_properties.add_equal_conditions(&col_x_expr, &col_y_expr)?; assert_eq!(eq_properties.eq_group().len(), 2); // This equality bridges distinct equality sets. // Hence equivalent class count should decrease from 2 to 1. - eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr); + eq_properties.add_equal_conditions(&col_x_expr, &col_a_expr)?; assert_eq!(eq_properties.eq_group().len(), 1); let eq_groups = &eq_properties.eq_group().classes[0]; assert_eq!(eq_groups.len(), 5); @@ -419,168 +287,4 @@ mod tests { Ok(()) } - - /// Checks if the table (RecordBatch) remains unchanged when sorted according to the provided `required_ordering`. - /// - /// The function works by adding a unique column of ascending integers to the original table. This column ensures - /// that rows that are otherwise indistinguishable (e.g., if they have the same values in all other columns) can - /// still be differentiated. When sorting the extended table, the unique column acts as a tie-breaker to produce - /// deterministic sorting results. - /// - /// If the table remains the same after sorting with the added unique column, it indicates that the table was - /// already sorted according to `required_ordering` to begin with. - pub fn is_table_same_after_sort( - mut required_ordering: Vec, - batch: RecordBatch, - ) -> Result { - // Clone the original schema and columns - let original_schema = batch.schema(); - let mut columns = batch.columns().to_vec(); - - // Create a new unique column - let n_row = batch.num_rows(); - let vals: Vec = (0..n_row).collect::>(); - let vals: Vec = vals.into_iter().map(|val| val as f64).collect(); - let unique_col = Arc::new(Float64Array::from_iter_values(vals)) as ArrayRef; - columns.push(unique_col.clone()); - - // Create a new schema with the added unique column - let unique_col_name = "unique"; - let unique_field = - Arc::new(Field::new(unique_col_name, DataType::Float64, false)); - let fields: Vec<_> = original_schema - .fields() - .iter() - .cloned() - .chain(std::iter::once(unique_field)) - .collect(); - let schema = Arc::new(Schema::new(fields)); - - // Create a new batch with the added column - let new_batch = RecordBatch::try_new(schema.clone(), columns)?; - - // Add the unique column to the required ordering to ensure deterministic results - required_ordering.push(PhysicalSortExpr { - expr: Arc::new(Column::new(unique_col_name, original_schema.fields().len())), - options: Default::default(), - }); - - // Convert the required ordering to a list of SortColumn - let sort_columns = required_ordering - .iter() - .map(|order_expr| { - let expr_result = order_expr.expr.evaluate(&new_batch)?; - let values = expr_result.into_array(new_batch.num_rows())?; - Ok(SortColumn { - values, - options: Some(order_expr.options), - }) - }) - .collect::>>()?; - - // Check if the indices after sorting match the initial ordering - let sorted_indices = lexsort_to_indices(&sort_columns, None)?; - let original_indices = UInt32Array::from_iter_values(0..n_row as u32); - - Ok(sorted_indices == original_indices) - } - - // If we already generated a random result for one of the - // expressions in the equivalence classes. For other expressions in the same - // equivalence class use same result. This util gets already calculated result, when available. - fn get_representative_arr( - eq_group: &EquivalenceClass, - existing_vec: &[Option], - schema: SchemaRef, - ) -> Option { - for expr in eq_group.iter() { - let col = expr.as_any().downcast_ref::().unwrap(); - let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - if let Some(res) = &existing_vec[idx] { - return Some(res.clone()); - } - } - None - } - - // Generate a table that satisfies the given equivalence properties; i.e. - // equivalences, ordering equivalences, and constants. - pub fn generate_table_for_eq_properties( - eq_properties: &EquivalenceProperties, - n_elem: usize, - n_distinct: usize, - ) -> Result { - let mut rng = StdRng::seed_from_u64(23); - - let schema = eq_properties.schema(); - let mut schema_vec = vec![None; schema.fields.len()]; - - // Utility closure to generate random array - let mut generate_random_array = |num_elems: usize, max_val: usize| -> ArrayRef { - let values: Vec = (0..num_elems) - .map(|_| rng.gen_range(0..max_val) as f64 / 2.0) - .collect(); - Arc::new(Float64Array::from_iter_values(values)) - }; - - // Fill constant columns - for constant in &eq_properties.constants { - let col = constant.as_any().downcast_ref::().unwrap(); - let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - let arr = Arc::new(Float64Array::from_iter_values(vec![0 as f64; n_elem])) - as ArrayRef; - schema_vec[idx] = Some(arr); - } - - // Fill columns based on ordering equivalences - for ordering in eq_properties.oeq_class.iter() { - let (sort_columns, indices): (Vec<_>, Vec<_>) = ordering - .iter() - .map(|PhysicalSortExpr { expr, options }| { - let col = expr.as_any().downcast_ref::().unwrap(); - let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - let arr = generate_random_array(n_elem, n_distinct); - ( - SortColumn { - values: arr, - options: Some(*options), - }, - idx, - ) - }) - .unzip(); - - let sort_arrs = arrow::compute::lexsort(&sort_columns, None)?; - for (idx, arr) in izip!(indices, sort_arrs) { - schema_vec[idx] = Some(arr); - } - } - - // Fill columns based on equivalence groups - for eq_group in eq_properties.eq_group.iter() { - let representative_array = - get_representative_arr(eq_group, &schema_vec, schema.clone()) - .unwrap_or_else(|| generate_random_array(n_elem, n_distinct)); - - for expr in eq_group.iter() { - let col = expr.as_any().downcast_ref::().unwrap(); - let (idx, _field) = schema.column_with_name(col.name()).unwrap(); - schema_vec[idx] = Some(representative_array.clone()); - } - } - - let res: Vec<_> = schema_vec - .into_iter() - .zip(schema.fields.iter()) - .map(|(elem, field)| { - ( - field.name(), - // Generate random values for columns that do not occur in any of the groups (equivalence, ordering equivalence, constants) - elem.unwrap_or_else(|| generate_random_array(n_elem, n_distinct)), - ) - }) - .collect(); - - Ok(RecordBatch::try_from_iter(res)?) - } } diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs index ed4600f2d95e..838c9800f942 100644 --- a/datafusion/physical-expr/src/equivalence/ordering.rs +++ b/datafusion/physical-expr/src/equivalence/ordering.rs @@ -15,13 +15,14 @@ // specific language governing permissions and limitations // under the License. +use std::fmt::Display; use std::hash::Hash; use std::sync::Arc; - -use arrow_schema::SortOptions; +use std::vec::IntoIter; use crate::equivalence::add_offset_to_expr; -use crate::{LexOrdering, PhysicalExpr, PhysicalSortExpr}; +use crate::{LexOrdering, PhysicalExpr}; +use arrow_schema::SortOptions; /// An `OrderingEquivalenceClass` object keeps track of different alternative /// orderings than can describe a schema. For example, consider the following table: @@ -36,7 +37,7 @@ use crate::{LexOrdering, PhysicalExpr, PhysicalSortExpr}; /// /// Here, both `vec![a ASC, b ASC]` and `vec![c DESC, d ASC]` describe the table /// ordering. In this case, we say that these orderings are equivalent. -#[derive(Debug, Clone, Eq, PartialEq, Hash)] +#[derive(Debug, Clone, Eq, PartialEq, Hash, Default)] pub struct OrderingEquivalenceClass { pub orderings: Vec, } @@ -44,7 +45,7 @@ pub struct OrderingEquivalenceClass { impl OrderingEquivalenceClass { /// Creates new empty ordering equivalence class. pub fn empty() -> Self { - Self { orderings: vec![] } + Default::default() } /// Clears (empties) this ordering equivalence class. @@ -104,6 +105,11 @@ impl OrderingEquivalenceClass { self.remove_redundant_entries(); } + /// Adds a single ordering to the existing ordering equivalence class. + pub fn add_new_ordering(&mut self, ordering: LexOrdering) { + self.add_new_orderings([ordering]); + } + /// Removes redundant orderings from this equivalence class. For instance, /// if we already have the ordering `[a ASC, b ASC, c DESC]`, then there is /// no need to keep ordering `[a ASC, b ASC]` in the state. @@ -140,7 +146,12 @@ impl OrderingEquivalenceClass { /// Returns the concatenation of all the orderings. This enables merge /// operations to preserve all equivalent orderings simultaneously. pub fn output_ordering(&self) -> Option { - let output_ordering = self.orderings.iter().flatten().cloned().collect(); + let output_ordering = self + .orderings + .iter() + .flat_map(|ordering| ordering.as_ref()) + .cloned() + .collect(); let output_ordering = collapse_lex_ordering(output_ordering); (!output_ordering.is_empty()).then_some(output_ordering) } @@ -163,7 +174,7 @@ impl OrderingEquivalenceClass { for idx in 0..n_ordering { // Calculate cross product index let idx = outer_idx * n_ordering + idx; - self.orderings[idx].extend(ordering.iter().cloned()); + self.orderings[idx].inner.extend(ordering.iter().cloned()); } } self @@ -173,8 +184,8 @@ impl OrderingEquivalenceClass { /// ordering equivalence class. pub fn add_offset(&mut self, offset: usize) { for ordering in self.orderings.iter_mut() { - for sort_expr in ordering { - sort_expr.expr = add_offset_to_expr(sort_expr.expr.clone(), offset); + for sort_expr in ordering.inner.iter_mut() { + sort_expr.expr = add_offset_to_expr(Arc::clone(&sort_expr.expr), offset); } } } @@ -192,14 +203,23 @@ impl OrderingEquivalenceClass { } } +impl IntoIterator for OrderingEquivalenceClass { + type Item = LexOrdering; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.orderings.into_iter() + } +} + /// This function constructs a duplicate-free `LexOrdering` by filtering out /// duplicate entries that have same physical expression inside. For example, /// `vec![a ASC, a DESC]` collapses to `vec![a ASC]`. pub fn collapse_lex_ordering(input: LexOrdering) -> LexOrdering { - let mut output = Vec::::new(); - for item in input { + let mut output = LexOrdering::default(); + for item in input.iter() { if !output.iter().any(|req| req.expr.eq(&item.expr)) { - output.push(item); + output.push(item.clone()); } } output @@ -219,29 +239,41 @@ fn resolve_overlap(orderings: &mut [LexOrdering], idx: usize, pre_idx: usize) -> false } +impl Display for OrderingEquivalenceClass { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "[")?; + let mut iter = self.orderings.iter(); + if let Some(ordering) = iter.next() { + write!(f, "[{}]", ordering)?; + } + for ordering in iter { + write!(f, ", [{}]", ordering)?; + } + write!(f, "]")?; + Ok(()) + } +} + #[cfg(test)] mod tests { use std::sync::Arc; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow_schema::SortOptions; - use itertools::Itertools; - - use datafusion_common::{DFSchema, Result}; - use datafusion_expr::{Operator, ScalarUDF}; - use crate::equivalence::tests::{ - convert_to_orderings, convert_to_sort_exprs, create_random_schema, - create_test_params, generate_table_for_eq_properties, is_table_same_after_sort, + convert_to_orderings, convert_to_sort_exprs, create_test_schema, }; - use crate::equivalence::{tests::create_test_schema, EquivalenceProperties}; use crate::equivalence::{ - EquivalenceClass, EquivalenceGroup, OrderingEquivalenceClass, + EquivalenceClass, EquivalenceGroup, EquivalenceProperties, + OrderingEquivalenceClass, }; - use crate::expressions::Column; - use crate::expressions::{col, BinaryExpr}; + use crate::expressions::{col, BinaryExpr, Column}; use crate::utils::tests::TestScalarUDF; - use crate::{PhysicalExpr, PhysicalSortExpr}; + use crate::{ConstExpr, PhysicalExpr, PhysicalSortExpr}; + + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::SortOptions; + use datafusion_common::{DFSchema, Result}; + use datafusion_expr::{Operator, ScalarUDF}; + use datafusion_physical_expr_common::sort_expr::LexOrdering; #[test] fn test_ordering_satisfy() -> Result<()> { @@ -249,11 +281,11 @@ mod tests { Field::new("a", DataType::Int64, true), Field::new("b", DataType::Int64, true), ])); - let crude = vec![PhysicalSortExpr { + let crude = LexOrdering::new(vec![PhysicalSortExpr { expr: Arc::new(Column::new("a", 0)), options: SortOptions::default(), - }]; - let finer = vec![ + }]); + let finer = LexOrdering::new(vec![ PhysicalSortExpr { expr: Arc::new(Column::new("a", 0)), options: SortOptions::default(), @@ -262,16 +294,18 @@ mod tests { expr: Arc::new(Column::new("b", 1)), options: SortOptions::default(), }, - ]; + ]); // finer ordering satisfies, crude ordering should return true - let mut eq_properties_finer = EquivalenceProperties::new(input_schema.clone()); + let mut eq_properties_finer = + EquivalenceProperties::new(Arc::clone(&input_schema)); eq_properties_finer.oeq_class.push(finer.clone()); - assert!(eq_properties_finer.ordering_satisfy(&crude)); + assert!(eq_properties_finer.ordering_satisfy(crude.as_ref())); // Crude ordering doesn't satisfy finer ordering. should return false - let mut eq_properties_crude = EquivalenceProperties::new(input_schema.clone()); - eq_properties_crude.oeq_class.push(crude.clone()); - assert!(!eq_properties_crude.ordering_satisfy(&finer)); + let mut eq_properties_crude = + EquivalenceProperties::new(Arc::clone(&input_schema)); + eq_properties_crude.oeq_class.push(crude); + assert!(!eq_properties_crude.ordering_satisfy(finer.as_ref())); Ok(()) } @@ -307,9 +341,9 @@ mod tests { &DFSchema::empty(), )?; let a_plus_b = Arc::new(BinaryExpr::new( - col_a.clone(), + Arc::clone(col_a), Operator::Plus, - col_b.clone(), + Arc::clone(col_b), )) as Arc; let options = SortOptions { descending: false, @@ -541,7 +575,7 @@ mod tests { for (orderings, eq_group, constants, reqs, expected) in test_cases { let err_msg = format!("error in test orderings: {orderings:?}, eq_group: {eq_group:?}, constants: {constants:?}, reqs: {reqs:?}, expected: {expected:?}"); - let mut eq_properties = EquivalenceProperties::new(test_schema.clone()); + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); let orderings = convert_to_orderings(&orderings); eq_properties.add_new_orderings(orderings); let eq_group = eq_group @@ -554,12 +588,14 @@ mod tests { let eq_group = EquivalenceGroup::new(eq_group); eq_properties.add_equivalence_group(eq_group); - let constants = constants.into_iter().cloned(); - eq_properties = eq_properties.add_constants(constants); + let constants = constants + .into_iter() + .map(|expr| ConstExpr::from(expr).with_across_partitions(true)); + eq_properties = eq_properties.with_constants(constants); let reqs = convert_to_sort_exprs(&reqs); assert_eq!( - eq_properties.ordering_satisfy(&reqs), + eq_properties.ordering_satisfy(reqs.as_ref()), expected, "{}", err_msg @@ -569,305 +605,6 @@ mod tests { Ok(()) } - #[test] - fn test_ordering_satisfy_with_equivalence() -> Result<()> { - // Schema satisfies following orderings: - // [a ASC], [d ASC, b ASC], [e DESC, f ASC, g ASC] - // and - // Column [a=c] (e.g they are aliases). - let (test_schema, eq_properties) = create_test_params()?; - let col_a = &col("a", &test_schema)?; - let col_b = &col("b", &test_schema)?; - let col_c = &col("c", &test_schema)?; - let col_d = &col("d", &test_schema)?; - let col_e = &col("e", &test_schema)?; - let col_f = &col("f", &test_schema)?; - let col_g = &col("g", &test_schema)?; - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, 625, 5)?; - - // First element in the tuple stores vector of requirement, second element is the expected return value for ordering_satisfy function - let requirements = vec![ - // `a ASC NULLS LAST`, expects `ordering_satisfy` to be `true`, since existing ordering `a ASC NULLS LAST, b ASC NULLS LAST` satisfies it - (vec![(col_a, option_asc)], true), - (vec![(col_a, option_desc)], false), - // Test whether equivalence works as expected - (vec![(col_c, option_asc)], true), - (vec![(col_c, option_desc)], false), - // Test whether ordering equivalence works as expected - (vec![(col_d, option_asc)], true), - (vec![(col_d, option_asc), (col_b, option_asc)], true), - (vec![(col_d, option_desc), (col_b, option_asc)], false), - ( - vec![ - (col_e, option_desc), - (col_f, option_asc), - (col_g, option_asc), - ], - true, - ), - (vec![(col_e, option_desc), (col_f, option_asc)], true), - (vec![(col_e, option_asc), (col_f, option_asc)], false), - (vec![(col_e, option_desc), (col_b, option_asc)], false), - (vec![(col_e, option_asc), (col_b, option_asc)], false), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_d, option_asc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_desc), - (col_f, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_desc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_d, option_desc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_asc), - (col_f, option_asc), - ], - false, - ), - ( - vec![ - (col_d, option_asc), - (col_b, option_asc), - (col_e, option_asc), - (col_b, option_asc), - ], - false, - ), - (vec![(col_d, option_asc), (col_e, option_desc)], true), - ( - vec![ - (col_d, option_asc), - (col_c, option_asc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_e, option_desc), - (col_f, option_asc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_e, option_desc), - (col_c, option_asc), - (col_b, option_asc), - ], - true, - ), - ( - vec![ - (col_d, option_asc), - (col_e, option_desc), - (col_b, option_asc), - (col_f, option_asc), - ], - true, - ), - ]; - - for (cols, expected) in requirements { - let err_msg = format!("Error in test case:{cols:?}"); - let required = cols - .into_iter() - .map(|(expr, options)| PhysicalSortExpr { - expr: expr.clone(), - options, - }) - .collect::>(); - - // Check expected result with experimental result. - assert_eq!( - is_table_same_after_sort( - required.clone(), - table_data_with_properties.clone() - )?, - expected - ); - assert_eq!( - eq_properties.ordering_satisfy(&required), - expected, - "{err_msg}" - ); - } - Ok(()) - } - - #[test] - fn test_ordering_satisfy_with_equivalence_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 5; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - const SORT_OPTIONS: SortOptions = SortOptions { - descending: false, - nulls_first: false, - }; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - let col_exprs = [ - col("a", &test_schema)?, - col("b", &test_schema)?, - col("c", &test_schema)?, - col("d", &test_schema)?, - col("e", &test_schema)?, - col("f", &test_schema)?, - ]; - - for n_req in 0..=col_exprs.len() { - for exprs in col_exprs.iter().combinations(n_req) { - let requirement = exprs - .into_iter() - .map(|expr| PhysicalSortExpr { - expr: expr.clone(), - options: SORT_OPTIONS, - }) - .collect::>(); - let expected = is_table_same_after_sort( - requirement.clone(), - table_data_with_properties.clone(), - )?; - let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", - requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants - ); - // Check whether ordering_satisfy API result and - // experimental result matches. - assert_eq!( - eq_properties.ordering_satisfy(&requirement), - expected, - "{}", - err_msg - ); - } - } - } - - Ok(()) - } - - #[test] - fn test_ordering_satisfy_with_equivalence_complex_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 100; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - const SORT_OPTIONS: SortOptions = SortOptions { - descending: false, - nulls_first: false, - }; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - - let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); - let floor_a = crate::udf::create_physical_expr( - &test_fun, - &[col("a", &test_schema)?], - &test_schema, - &[], - &DFSchema::empty(), - )?; - let a_plus_b = Arc::new(BinaryExpr::new( - col("a", &test_schema)?, - Operator::Plus, - col("b", &test_schema)?, - )) as Arc; - let exprs = [ - col("a", &test_schema)?, - col("b", &test_schema)?, - col("c", &test_schema)?, - col("d", &test_schema)?, - col("e", &test_schema)?, - col("f", &test_schema)?, - floor_a, - a_plus_b, - ]; - - for n_req in 0..=exprs.len() { - for exprs in exprs.iter().combinations(n_req) { - let requirement = exprs - .into_iter() - .map(|expr| PhysicalSortExpr { - expr: expr.clone(), - options: SORT_OPTIONS, - }) - .collect::>(); - let expected = is_table_same_after_sort( - requirement.clone(), - table_data_with_properties.clone(), - )?; - let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", - requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants - ); - // Check whether ordering_satisfy API result and - // experimental result matches. - - assert_eq!( - eq_properties.ordering_satisfy(&requirement), - (expected | false), - "{}", - err_msg - ); - } - } - } - - Ok(()) - } - #[test] fn test_ordering_satisfy_different_lengths() -> Result<()> { let test_schema = create_test_schema()?; @@ -883,7 +620,7 @@ mod tests { }; // a=c (e.g they are aliases). let mut eq_properties = EquivalenceProperties::new(test_schema); - eq_properties.add_equal_conditions(col_a, col_c); + eq_properties.add_equal_conditions(col_a, col_c)?; let orderings = vec![ vec![(col_a, options)], @@ -918,7 +655,7 @@ mod tests { format!("error in test reqs: {:?}, expected: {:?}", reqs, expected,); let reqs = convert_to_sort_exprs(&reqs); assert_eq!( - eq_properties.ordering_satisfy(&reqs), + eq_properties.ordering_satisfy(reqs.as_ref()), expected, "{}", err_msg diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs index 260610f23dc6..25a05a2a5918 100644 --- a/datafusion/physical-expr/src/equivalence/projection.rs +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -17,14 +17,13 @@ use std::sync::Arc; -use arrow::datatypes::SchemaRef; +use crate::expressions::Column; +use crate::PhysicalExpr; +use arrow::datatypes::SchemaRef; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{internal_err, Result}; -use crate::expressions::Column; -use crate::PhysicalExpr; - /// Stores the mapping between source expressions and target expressions for a /// projection. #[derive(Debug, Clone)] @@ -57,8 +56,7 @@ impl ProjectionMapping { .enumerate() .map(|(expr_idx, (expression, name))| { let target_expr = Arc::new(Column::new(name, expr_idx)) as _; - expression - .clone() + Arc::clone(expression) .transform_down(|e| match e.as_any().downcast_ref::() { Some(col) => { // Sometimes, an expression and its name in the input_schema @@ -84,6 +82,15 @@ impl ProjectionMapping { .map(|map| Self { map }) } + /// Constructs a subset mapping using the provided indices. + /// + /// This is used when the output is a subset of the input without any + /// other transformations. The indices are for columns in the schema. + pub fn from_indices(indices: &[usize], schema: &SchemaRef) -> Result { + let projection_exprs = project_index_to_exprs(indices, schema); + ProjectionMapping::try_new(&projection_exprs, schema) + } + /// Iterate over pairs of (source, target) expressions pub fn iter( &self, @@ -108,32 +115,41 @@ impl ProjectionMapping { self.map .iter() .find(|(source, _)| source.eq(expr)) - .map(|(_, target)| target.clone()) + .map(|(_, target)| Arc::clone(target)) } } +fn project_index_to_exprs( + projection_index: &[usize], + schema: &SchemaRef, +) -> Vec<(Arc, String)> { + projection_index + .iter() + .map(|index| { + let field = schema.field(*index); + ( + Arc::new(Column::new(field.name(), *index)) as Arc, + field.name().to_owned(), + ) + }) + .collect::>() +} + #[cfg(test)] mod tests { - - use arrow::datatypes::{DataType, Field, Schema}; - use arrow_schema::{SortOptions, TimeUnit}; - use itertools::Itertools; - - use datafusion_common::DFSchema; - use datafusion_expr::{Operator, ScalarUDF}; - + use super::*; use crate::equivalence::tests::{ - apply_projection, convert_to_orderings, convert_to_orderings_owned, - create_random_schema, generate_table_for_eq_properties, is_table_same_after_sort, - output_schema, + convert_to_orderings, convert_to_orderings_owned, output_schema, }; use crate::equivalence::EquivalenceProperties; use crate::expressions::{col, BinaryExpr}; use crate::udf::create_physical_expr; use crate::utils::tests::TestScalarUDF; - use crate::PhysicalSortExpr; - use super::*; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::{SortOptions, TimeUnit}; + use datafusion_common::DFSchema; + use datafusion_expr::{Operator, ScalarUDF}; #[test] fn project_orderings() -> Result<()> { @@ -152,24 +168,24 @@ mod tests { let col_e = &col("e", &schema)?; let col_ts = &col("ts", &schema)?; let a_plus_b = Arc::new(BinaryExpr::new( - col_a.clone(), + Arc::clone(col_a), Operator::Plus, - col_b.clone(), + Arc::clone(col_b), )) as Arc; let b_plus_d = Arc::new(BinaryExpr::new( - col_b.clone(), + Arc::clone(col_b), Operator::Plus, - col_d.clone(), + Arc::clone(col_d), )) as Arc; let b_plus_e = Arc::new(BinaryExpr::new( - col_b.clone(), + Arc::clone(col_b), Operator::Plus, - col_e.clone(), + Arc::clone(col_e), )) as Arc; let c_plus_d = Arc::new(BinaryExpr::new( - col_c.clone(), + Arc::clone(col_c), Operator::Plus, - col_d.clone(), + Arc::clone(col_d), )) as Arc; let option_asc = SortOptions { @@ -590,14 +606,14 @@ mod tests { for (idx, (orderings, proj_exprs, expected)) in test_cases.into_iter().enumerate() { - let mut eq_properties = EquivalenceProperties::new(schema.clone()); + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); let orderings = convert_to_orderings(&orderings); eq_properties.add_new_orderings(orderings); let proj_exprs = proj_exprs .into_iter() - .map(|(expr, name)| (expr.clone(), name)) + .map(|(expr, name)| (Arc::clone(expr), name)) .collect::>(); let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; let output_schema = output_schema(&projection_mapping, &schema)?; @@ -646,15 +662,15 @@ mod tests { let col_c = &col("c", &schema)?; let col_ts = &col("ts", &schema)?; let a_plus_b = Arc::new(BinaryExpr::new( - col_a.clone(), + Arc::clone(col_a), Operator::Plus, - col_b.clone(), + Arc::clone(col_b), )) as Arc; let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); let round_c = &create_physical_expr( &test_fun, - &[col_c.clone()], + &[Arc::clone(col_c)], &schema, &[], &DFSchema::empty(), @@ -673,7 +689,7 @@ mod tests { ]; let proj_exprs = proj_exprs .into_iter() - .map(|(expr, name)| (expr.clone(), name)) + .map(|(expr, name)| (Arc::clone(expr), name)) .collect::>(); let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; let output_schema = output_schema(&projection_mapping, &schema)?; @@ -683,9 +699,9 @@ mod tests { let col_c_new = &col("c_new", &output_schema)?; let col_round_c_res = &col("round_c_res", &output_schema)?; let a_new_plus_b_new = Arc::new(BinaryExpr::new( - col_a_new.clone(), + Arc::clone(col_a_new), Operator::Plus, - col_b_new.clone(), + Arc::clone(col_b_new), )) as Arc; let test_cases = vec![ @@ -796,7 +812,7 @@ mod tests { ]; for (idx, (orderings, expected)) in test_cases.iter().enumerate() { - let mut eq_properties = EquivalenceProperties::new(schema.clone()); + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); let orderings = convert_to_orderings(orderings); eq_properties.add_new_orderings(orderings); @@ -804,7 +820,7 @@ mod tests { let expected = convert_to_orderings(expected); let projected_eq = - eq_properties.project(&projection_mapping, output_schema.clone()); + eq_properties.project(&projection_mapping, Arc::clone(&output_schema)); let orderings = projected_eq.oeq_class(); let err_msg = format!( @@ -837,9 +853,9 @@ mod tests { let col_e = &col("e", &schema)?; let col_f = &col("f", &schema)?; let a_plus_b = Arc::new(BinaryExpr::new( - col_a.clone(), + Arc::clone(col_a), Operator::Plus, - col_b.clone(), + Arc::clone(col_b), )) as Arc; let option_asc = SortOptions { @@ -854,7 +870,7 @@ mod tests { ]; let proj_exprs = proj_exprs .into_iter() - .map(|(expr, name)| (expr.clone(), name)) + .map(|(expr, name)| (Arc::clone(expr), name)) .collect::>(); let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &schema)?; let output_schema = output_schema(&projection_mapping, &schema)?; @@ -939,9 +955,9 @@ mod tests { ), ]; for (orderings, equal_columns, expected) in test_cases { - let mut eq_properties = EquivalenceProperties::new(schema.clone()); + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); for (lhs, rhs) in equal_columns { - eq_properties.add_equal_conditions(lhs, rhs); + eq_properties.add_equal_conditions(lhs, rhs)?; } let orderings = convert_to_orderings(&orderings); @@ -950,7 +966,7 @@ mod tests { let expected = convert_to_orderings(&expected); let projected_eq = - eq_properties.project(&projection_mapping, output_schema.clone()); + eq_properties.project(&projection_mapping, Arc::clone(&output_schema)); let orderings = projected_eq.oeq_class(); let err_msg = format!( @@ -966,174 +982,4 @@ mod tests { Ok(()) } - - #[test] - fn project_orderings_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 20; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - // Floor(a) - let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); - let floor_a = create_physical_expr( - &test_fun, - &[col("a", &test_schema)?], - &test_schema, - &[], - &DFSchema::empty(), - )?; - // a + b - let a_plus_b = Arc::new(BinaryExpr::new( - col("a", &test_schema)?, - Operator::Plus, - col("b", &test_schema)?, - )) as Arc; - let proj_exprs = vec![ - (col("a", &test_schema)?, "a_new"), - (col("b", &test_schema)?, "b_new"), - (col("c", &test_schema)?, "c_new"), - (col("d", &test_schema)?, "d_new"), - (col("e", &test_schema)?, "e_new"), - (col("f", &test_schema)?, "f_new"), - (floor_a, "floor(a)"), - (a_plus_b, "a+b"), - ]; - - for n_req in 0..=proj_exprs.len() { - for proj_exprs in proj_exprs.iter().combinations(n_req) { - let proj_exprs = proj_exprs - .into_iter() - .map(|(expr, name)| (expr.clone(), name.to_string())) - .collect::>(); - let (projected_batch, projected_eq) = apply_projection( - proj_exprs.clone(), - &table_data_with_properties, - &eq_properties, - )?; - - // Make sure each ordering after projection is valid. - for ordering in projected_eq.oeq_class().iter() { - let err_msg = format!( - "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, proj_exprs: {:?}", - ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, proj_exprs - ); - // Since ordered section satisfies schema, we expect - // that result will be same after sort (e.g sort was unnecessary). - assert!( - is_table_same_after_sort( - ordering.clone(), - projected_batch.clone(), - )?, - "{}", - err_msg - ); - } - } - } - } - - Ok(()) - } - - #[test] - fn ordering_satisfy_after_projection_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 20; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - const SORT_OPTIONS: SortOptions = SortOptions { - descending: false, - nulls_first: false, - }; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - // Floor(a) - let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); - let floor_a = create_physical_expr( - &test_fun, - &[col("a", &test_schema)?], - &test_schema, - &[], - &DFSchema::empty(), - )?; - // a + b - let a_plus_b = Arc::new(BinaryExpr::new( - col("a", &test_schema)?, - Operator::Plus, - col("b", &test_schema)?, - )) as Arc; - let proj_exprs = vec![ - (col("a", &test_schema)?, "a_new"), - (col("b", &test_schema)?, "b_new"), - (col("c", &test_schema)?, "c_new"), - (col("d", &test_schema)?, "d_new"), - (col("e", &test_schema)?, "e_new"), - (col("f", &test_schema)?, "f_new"), - (floor_a, "floor(a)"), - (a_plus_b, "a+b"), - ]; - - for n_req in 0..=proj_exprs.len() { - for proj_exprs in proj_exprs.iter().combinations(n_req) { - let proj_exprs = proj_exprs - .into_iter() - .map(|(expr, name)| (expr.clone(), name.to_string())) - .collect::>(); - let (projected_batch, projected_eq) = apply_projection( - proj_exprs.clone(), - &table_data_with_properties, - &eq_properties, - )?; - - let projection_mapping = - ProjectionMapping::try_new(&proj_exprs, &test_schema)?; - - let projected_exprs = projection_mapping - .iter() - .map(|(_source, target)| target.clone()) - .collect::>(); - - for n_req in 0..=projected_exprs.len() { - for exprs in projected_exprs.iter().combinations(n_req) { - let requirement = exprs - .into_iter() - .map(|expr| PhysicalSortExpr { - expr: expr.clone(), - options: SORT_OPTIONS, - }) - .collect::>(); - let expected = is_table_same_after_sort( - requirement.clone(), - projected_batch.clone(), - )?; - let err_msg = format!( - "Error in test case requirement:{:?}, expected: {:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}, projected_eq.oeq_class: {:?}, projected_eq.eq_group: {:?}, projected_eq.constants: {:?}, projection_mapping: {:?}", - requirement, expected, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants, projected_eq.oeq_class, projected_eq.eq_group, projected_eq.constants, projection_mapping - ); - // Check whether ordering_satisfy API result and - // experimental result matches. - assert_eq!( - projected_eq.ordering_satisfy(&requirement), - expected, - "{}", - err_msg - ); - } - } - } - } - } - - Ok(()) - } } diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index 555f0ad31786..1eb88d8a26f0 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -15,34 +15,46 @@ // specific language governing permissions and limitations // under the License. +use std::fmt; +use std::fmt::Display; use std::hash::{Hash, Hasher}; +use std::iter::Peekable; +use std::slice::Iter; use std::sync::Arc; -use arrow_schema::{SchemaRef, SortOptions}; -use indexmap::{IndexMap, IndexSet}; -use itertools::Itertools; - -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{JoinSide, JoinType, Result}; - +use super::ordering::collapse_lex_ordering; +use crate::equivalence::class::const_exprs_contains; use crate::equivalence::{ - collapse_lex_req, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping, + collapse_lex_req, EquivalenceClass, EquivalenceGroup, OrderingEquivalenceClass, + ProjectionMapping, }; -use crate::expressions::{CastExpr, Literal}; -use crate::sort_properties::{ExprOrdering, SortProperties}; +use crate::expressions::{with_new_schema, CastExpr, Column, Literal}; use crate::{ - physical_exprs_contains, LexOrdering, LexOrderingRef, LexRequirement, + physical_exprs_contains, ConstExpr, LexOrdering, LexOrderingRef, LexRequirement, LexRequirementRef, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, PhysicalSortRequirement, }; -use super::ordering::collapse_lex_ordering; +use arrow_schema::{SchemaRef, SortOptions}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::{internal_err, plan_err, JoinSide, JoinType, Result}; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; +use datafusion_physical_expr_common::utils::ExprPropertiesNode; + +use indexmap::{IndexMap, IndexSet}; +use itertools::Itertools; -/// A `EquivalenceProperties` object stores useful information related to a schema. +/// A `EquivalenceProperties` object stores information known about the output +/// of a plan node, that can be used to optimize the plan. +/// /// Currently, it keeps track of: -/// - Equivalent expressions, e.g expressions that have same value. -/// - Valid sort expressions (orderings) for the schema. -/// - Constants expressions (e.g expressions that are known to have constant values). +/// - Sort expressions (orderings) +/// - Equivalent expressions: expressions that are known to have same value. +/// - Constants expressions: expressions that are known to contain a single +/// constant value. +/// +/// # Example equivalent sort expressions /// /// Consider table below: /// @@ -57,9 +69,13 @@ use super::ordering::collapse_lex_ordering; /// └---┴---┘ /// ``` /// -/// where both `a ASC` and `b DESC` can describe the table ordering. With -/// `EquivalenceProperties`, we can keep track of these different valid sort -/// expressions and treat `a ASC` and `b DESC` on an equal footing. +/// In this case, both `a ASC` and `b DESC` can describe the table ordering. +/// `EquivalenceProperties`, tracks these different valid sort expressions and +/// treat `a ASC` and `b DESC` on an equal footing. For example if the query +/// specifies the output sorted by EITHER `a ASC` or `b DESC`, the sort can be +/// avoided. +/// +/// # Example equivalent expressions /// /// Similarly, consider the table below: /// @@ -74,11 +90,39 @@ use super::ordering::collapse_lex_ordering; /// └---┴---┘ /// ``` /// -/// where columns `a` and `b` always have the same value. We keep track of such -/// equivalences inside this object. With this information, we can optimize -/// things like partitioning. For example, if the partition requirement is -/// `Hash(a)` and output partitioning is `Hash(b)`, then we can deduce that -/// the existing partitioning satisfies the requirement. +/// In this case, columns `a` and `b` always have the same value, which can of +/// such equivalences inside this object. With this information, Datafusion can +/// optimize operations such as. For example, if the partition requirement is +/// `Hash(a)` and output partitioning is `Hash(b)`, then DataFusion avoids +/// repartitioning the data as the existing partitioning satisfies the +/// requirement. +/// +/// # Code Example +/// ``` +/// # use std::sync::Arc; +/// # use arrow_schema::{Schema, Field, DataType, SchemaRef}; +/// # use datafusion_physical_expr::{ConstExpr, EquivalenceProperties}; +/// # use datafusion_physical_expr::expressions::col; +/// use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +/// # let schema: SchemaRef = Arc::new(Schema::new(vec![ +/// # Field::new("a", DataType::Int32, false), +/// # Field::new("b", DataType::Int32, false), +/// # Field::new("c", DataType::Int32, false), +/// # ])); +/// # let col_a = col("a", &schema).unwrap(); +/// # let col_b = col("b", &schema).unwrap(); +/// # let col_c = col("c", &schema).unwrap(); +/// // This object represents data that is sorted by a ASC, c DESC +/// // with a single constant value of b +/// let mut eq_properties = EquivalenceProperties::new(schema) +/// .with_constants(vec![ConstExpr::from(col_b)]); +/// eq_properties.add_new_ordering(LexOrdering::new(vec![ +/// PhysicalSortExpr::new_default(col_a).asc(), +/// PhysicalSortExpr::new_default(col_c).desc(), +/// ])); +/// +/// assert_eq!(eq_properties.to_string(), "order: [[a@0 ASC, c@2 DESC]], const: [b@1]") +/// ``` #[derive(Debug, Clone)] pub struct EquivalenceProperties { /// Collection of equivalence classes that store expressions with the same @@ -89,7 +133,7 @@ pub struct EquivalenceProperties { /// Expressions whose values are constant throughout the table. /// TODO: We do not need to track constants separately, they can be tracked /// inside `eq_groups` as `Literal` expressions. - pub constants: Vec>, + pub constants: Vec, /// Schema associated with this object. schema: SchemaRef, } @@ -131,7 +175,7 @@ impl EquivalenceProperties { } /// Returns a reference to the constant expressions - pub fn constants(&self) -> &[Arc] { + pub fn constants(&self) -> &[ConstExpr] { &self.constants } @@ -141,7 +185,8 @@ impl EquivalenceProperties { let mut output_ordering = self.oeq_class().output_ordering().unwrap_or_default(); // Prune out constant expressions output_ordering - .retain(|sort_expr| !physical_exprs_contains(constants, &sort_expr.expr)); + .inner + .retain(|sort_expr| !const_exprs_contains(constants, &sort_expr.expr)); (!output_ordering.is_empty()).then_some(output_ordering) } @@ -152,7 +197,7 @@ impl EquivalenceProperties { OrderingEquivalenceClass::new( self.oeq_class .iter() - .map(|ordering| self.normalize_sort_exprs(ordering)) + .map(|ordering| self.normalize_sort_exprs(ordering.as_ref())) .collect(), ) } @@ -161,7 +206,7 @@ impl EquivalenceProperties { pub fn extend(mut self, other: Self) -> Self { self.eq_group.extend(other.eq_group); self.oeq_class.extend(other.oeq_class); - self.add_constants(other.constants) + self.with_constants(other.constants) } /// Clears (empties) the ordering equivalence class within this object. @@ -170,6 +215,12 @@ impl EquivalenceProperties { self.oeq_class.clear(); } + /// Removes constant expressions that may change across partitions. + /// This method should be used when data from different partitions are merged. + pub fn clear_per_partition_constants(&mut self) { + self.constants.retain(|item| item.across_partitions()); + } + /// Extends this `EquivalenceProperties` by adding the orderings inside the /// ordering equivalence class `other`. pub fn add_ordering_equivalence_class(&mut self, other: OrderingEquivalenceClass) { @@ -184,6 +235,11 @@ impl EquivalenceProperties { self.oeq_class.add_new_orderings(orderings); } + /// Adds a single ordering to the existing ordering equivalence class. + pub fn add_new_ordering(&mut self, ordering: LexOrdering) { + self.add_new_orderings([ordering]); + } + /// Incorporates the given equivalence group to into the existing /// equivalence group within. pub fn add_equivalence_group(&mut self, other_eq_group: EquivalenceGroup) { @@ -197,27 +253,146 @@ impl EquivalenceProperties { &mut self, left: &Arc, right: &Arc, - ) { + ) -> Result<()> { + // Discover new constants in light of new the equality: + if self.is_expr_constant(left) { + // Left expression is constant, add right as constant + if !const_exprs_contains(&self.constants, right) { + self.constants + .push(ConstExpr::from(right).with_across_partitions(true)); + } + } else if self.is_expr_constant(right) { + // Right expression is constant, add left as constant + if !const_exprs_contains(&self.constants, left) { + self.constants + .push(ConstExpr::from(left).with_across_partitions(true)); + } + } + + // Add equal expressions to the state self.eq_group.add_equal_conditions(left, right); + + // Discover any new orderings + self.discover_new_orderings(left)?; + Ok(()) + } + + /// Track/register physical expressions with constant values. + #[deprecated(since = "43.0.0", note = "Use [`with_constants`] instead")] + pub fn add_constants(self, constants: impl IntoIterator) -> Self { + self.with_constants(constants) + } + + /// Remove the specified constant + pub fn remove_constant(mut self, c: &ConstExpr) -> Self { + self.constants.retain(|existing| existing != c); + self } /// Track/register physical expressions with constant values. - pub fn add_constants( + pub fn with_constants( mut self, - constants: impl IntoIterator>, + constants: impl IntoIterator, ) -> Self { - for expr in self.eq_group.normalize_exprs(constants) { - if !physical_exprs_contains(&self.constants, &expr) { - self.constants.push(expr); + let (const_exprs, across_partition_flags): ( + Vec>, + Vec, + ) = constants + .into_iter() + .map(|const_expr| { + let across_partitions = const_expr.across_partitions(); + let expr = const_expr.owned_expr(); + (expr, across_partitions) + }) + .unzip(); + for (expr, across_partitions) in self + .eq_group + .normalize_exprs(const_exprs) + .into_iter() + .zip(across_partition_flags) + { + if !const_exprs_contains(&self.constants, &expr) { + let const_expr = + ConstExpr::from(expr).with_across_partitions(across_partitions); + self.constants.push(const_expr); } } + + for ordering in self.normalized_oeq_class().iter() { + if let Err(e) = self.discover_new_orderings(&ordering[0].expr) { + log::debug!("error discovering new orderings: {e}"); + } + } + self } + // Discover new valid orderings in light of a new equality. + // Accepts a single argument (`expr`) which is used to determine + // which orderings should be updated. + // When constants or equivalence classes are changed, there may be new orderings + // that can be discovered with the new equivalence properties. + // For a discussion, see: https://github.com/apache/datafusion/issues/9812 + fn discover_new_orderings(&mut self, expr: &Arc) -> Result<()> { + let normalized_expr = self.eq_group().normalize_expr(Arc::clone(expr)); + let eq_class = self + .eq_group + .classes + .iter() + .find_map(|class| { + class + .contains(&normalized_expr) + .then(|| class.clone().into_vec()) + }) + .unwrap_or_else(|| vec![Arc::clone(&normalized_expr)]); + + let mut new_orderings: Vec = vec![]; + for (ordering, next_expr) in self + .normalized_oeq_class() + .iter() + .filter(|ordering| ordering[0].expr.eq(&normalized_expr)) + // First expression after leading ordering + .filter_map(|ordering| Some(ordering).zip(ordering.inner.get(1))) + { + let leading_ordering = ordering[0].options; + // Currently, we only handle expressions with a single child. + // TODO: It should be possible to handle expressions orderings like + // f(a, b, c), a, b, c if f is monotonic in all arguments. + for equivalent_expr in &eq_class { + let children = equivalent_expr.children(); + if children.len() == 1 + && children[0].eq(&next_expr.expr) + && SortProperties::Ordered(leading_ordering) + == equivalent_expr + .get_properties(&[ExprProperties { + sort_properties: SortProperties::Ordered( + leading_ordering, + ), + range: Interval::make_unbounded( + &equivalent_expr.data_type(&self.schema)?, + )?, + }])? + .sort_properties + { + // Assume existing ordering is [a ASC, b ASC] + // When equality a = f(b) is given, If we know that given ordering `[b ASC]`, ordering `[f(b) ASC]` is valid, + // then we can deduce that ordering `[b ASC]` is also valid. + // Hence, ordering `[b ASC]` can be added to the state as valid ordering. + // (e.g. existing ordering where leading ordering is removed) + new_orderings.push(LexOrdering::new(ordering[1..].to_vec())); + break; + } + } + } + + self.oeq_class.add_new_orderings(new_orderings); + Ok(()) + } + /// Updates the ordering equivalence group within assuming that the table /// is re-sorted according to the argument `sort_exprs`. Note that constants /// and equivalence classes are unchanged as they are unaffected by a re-sort. - pub fn with_reorder(mut self, sort_exprs: Vec) -> Self { + pub fn with_reorder(mut self, sort_exprs: LexOrdering) -> Self { // TODO: In some cases, existing ordering equivalences may still be valid add this analysis. self.oeq_class = OrderingEquivalenceClass::new(vec![sort_exprs]); self @@ -260,7 +435,13 @@ impl EquivalenceProperties { sort_reqs: LexRequirementRef, ) -> LexRequirement { let normalized_sort_reqs = self.eq_group.normalize_sort_requirements(sort_reqs); - let constants_normalized = self.eq_group.normalize_exprs(self.constants.clone()); + let mut constant_exprs = vec![]; + constant_exprs.extend( + self.constants + .iter() + .map(|const_expr| Arc::clone(const_expr.expr())), + ); + let constants_normalized = self.eq_group.normalize_exprs(constant_exprs); // Prune redundant sections in the requirement: collapse_lex_req( normalized_sort_reqs @@ -304,8 +485,8 @@ impl EquivalenceProperties { // From the analysis above, we know that `[a ASC]` is satisfied. Then, // we add column `a` as constant to the algorithm state. This enables us // to deduce that `(b + c) ASC` is satisfied, given `a` is constant. - eq_properties = - eq_properties.add_constants(std::iter::once(normalized_req.expr)); + eq_properties = eq_properties + .with_constants(std::iter::once(ConstExpr::from(normalized_req.expr))); } true } @@ -323,11 +504,15 @@ impl EquivalenceProperties { /// /// Returns `true` if the specified ordering is satisfied, `false` otherwise. fn ordering_satisfy_single(&self, req: &PhysicalSortRequirement) -> bool { - let expr_ordering = self.get_expr_ordering(req.expr.clone()); - let ExprOrdering { expr, data, .. } = expr_ordering; - match data { + let ExprProperties { + sort_properties, .. + } = self.get_expr_properties(Arc::clone(&req.expr)); + match sort_properties { SortProperties::Ordered(options) => { - let sort_expr = PhysicalSortExpr { expr, options }; + let sort_expr = PhysicalSortExpr { + expr: Arc::clone(&req.expr), + options, + }; sort_expr.satisfy(req, self.schema()) } // Singleton expressions satisfies any ordering. @@ -389,8 +574,9 @@ impl EquivalenceProperties { ) -> Option { let mut lhs = self.normalize_sort_requirements(req1); let mut rhs = self.normalize_sort_requirements(req2); - lhs.iter_mut() - .zip(rhs.iter_mut()) + lhs.inner + .iter_mut() + .zip(rhs.inner.iter_mut()) .all(|(lhs, rhs)| { lhs.expr.eq(&rhs.expr) && match (lhs.options, rhs.options) { @@ -409,33 +595,6 @@ impl EquivalenceProperties { .then_some(if lhs.len() >= rhs.len() { lhs } else { rhs }) } - /// Calculates the "meet" of the given orderings (`lhs` and `rhs`). - /// The meet of a set of orderings is the finest ordering that is satisfied - /// by all the orderings in that set. For details, see: - /// - /// - /// - /// If there is no ordering that satisfies both `lhs` and `rhs`, returns - /// `None`. As an example, the meet of orderings `[a ASC]` and `[a ASC, b ASC]` - /// is `[a ASC]`. - pub fn get_meet_ordering( - &self, - lhs: LexOrderingRef, - rhs: LexOrderingRef, - ) -> Option { - let lhs = self.normalize_sort_exprs(lhs); - let rhs = self.normalize_sort_exprs(rhs); - let mut meet = vec![]; - for (lhs, rhs) in lhs.into_iter().zip(rhs.into_iter()) { - if lhs.eq(&rhs) { - meet.push(lhs); - } else { - break; - } - } - (!meet.is_empty()).then_some(meet) - } - /// we substitute the ordering according to input expression type, this is a simplified version /// In this case, we just substitute when the expression satisfy the following condition: /// I. just have one column and is a CAST expression @@ -447,8 +606,8 @@ impl EquivalenceProperties { pub fn substitute_ordering_component( &self, mapping: &ProjectionMapping, - sort_expr: &[PhysicalSortExpr], - ) -> Result>> { + sort_expr: LexOrderingRef, + ) -> Result> { let new_orderings = sort_expr .iter() .map(|sort_expr| { @@ -458,7 +617,7 @@ impl EquivalenceProperties { .filter(|source| expr_refers(source, &sort_expr.expr)) .cloned() .collect(); - let mut res = vec![sort_expr.clone()]; + let mut res = LexOrdering::new(vec![sort_expr.clone()]); // TODO: Add one-to-ones analysis for ScalarFunctions. for r_expr in referring_exprs { // we check whether this expression is substitutable or not @@ -469,7 +628,7 @@ impl EquivalenceProperties { && cast_expr.is_bigger_cast(expr_type) { res.push(PhysicalSortExpr { - expr: r_expr.clone(), + expr: Arc::clone(&r_expr), options: sort_expr.options, }); } @@ -481,7 +640,9 @@ impl EquivalenceProperties { // Generate all valid orderings, given substituted expressions. let res = new_orderings .into_iter() + .map(|ordering| ordering.inner) .multi_cartesian_product() + .map(LexOrdering::new) .collect::>(); Ok(res) } @@ -495,7 +656,7 @@ impl EquivalenceProperties { let orderings = &self.oeq_class.orderings; let new_order = orderings .iter() - .map(|order| self.substitute_ordering_component(mapping, order)) + .map(|order| self.substitute_ordering_component(mapping, order.as_ref())) .collect::>>()?; let new_order = new_order.into_iter().flatten().collect(); self.oeq_class = OrderingEquivalenceClass::new(new_order); @@ -552,7 +713,7 @@ impl EquivalenceProperties { /// c ASC: Node {None, HashSet{a ASC}} /// ``` fn construct_dependency_map(&self, mapping: &ProjectionMapping) -> DependencyMap { - let mut dependency_map = IndexMap::new(); + let mut dependency_map = DependencyMap::new(); for ordering in self.normalized_oeq_class().iter() { for (idx, sort_expr) in ordering.iter().enumerate() { let target_sort_expr = @@ -574,13 +735,11 @@ impl EquivalenceProperties { let dependency = idx.checked_sub(1).map(|a| &ordering[a]); // Add sort expressions that can be projected or referred to // by any of the projection expressions to the dependency map: - dependency_map - .entry(sort_expr.clone()) - .or_insert_with(|| DependencyNode { - target_sort_expr: target_sort_expr.clone(), - dependencies: IndexSet::new(), - }) - .insert_dependency(dependency); + dependency_map.insert( + sort_expr, + target_sort_expr.as_ref(), + dependency, + ); } if !is_projected { // If we can not project, stop constructing the dependency @@ -612,8 +771,9 @@ impl EquivalenceProperties { map: mapping .iter() .map(|(source, target)| { - let normalized_source = self.eq_group.normalize_expr(source.clone()); - (normalized_source, target.clone()) + let normalized_source = + self.eq_group.normalize_expr(Arc::clone(source)); + (normalized_source, Arc::clone(target)) }) .collect(), } @@ -643,8 +803,9 @@ impl EquivalenceProperties { referred_dependencies(&dependency_map, source) .into_iter() .filter_map(|relevant_deps| { - if let SortProperties::Ordered(options) = - get_expr_ordering(source, &relevant_deps) + if let Ok(SortProperties::Ordered(options)) = + get_expr_properties(source, &relevant_deps, &self.schema) + .map(|prop| prop.sort_properties) { Some((options, relevant_deps)) } else { @@ -654,7 +815,7 @@ impl EquivalenceProperties { }) .flat_map(|(options, relevant_deps)| { let sort_expr = PhysicalSortExpr { - expr: target.clone(), + expr: Arc::clone(target), options, }; // Generate dependent orderings (i.e. prefixes for `sort_expr`): @@ -678,7 +839,7 @@ impl EquivalenceProperties { if prefixes.is_empty() { // If prefix is empty, there is no dependency. Insert // empty ordering: - prefixes = vec![vec![]]; + prefixes = vec![LexOrdering::default()]; } // Append current ordering on top its dependencies: for ordering in prefixes.iter_mut() { @@ -710,24 +871,27 @@ impl EquivalenceProperties { /// # Returns /// /// Returns a `Vec>` containing the projected constants. - fn projected_constants( - &self, - mapping: &ProjectionMapping, - ) -> Vec> { + fn projected_constants(&self, mapping: &ProjectionMapping) -> Vec { // First, project existing constants. For example, assume that `a + b` // is known to be constant. If the projection were `a as a_new`, `b as b_new`, // then we would project constant `a + b` as `a_new + b_new`. let mut projected_constants = self .constants .iter() - .flat_map(|expr| self.eq_group.project_expr(mapping, expr)) + .flat_map(|const_expr| { + const_expr.map(|expr| self.eq_group.project_expr(mapping, expr)) + }) .collect::>(); // Add projection expressions that are known to be constant: for (source, target) in mapping.iter() { if self.is_expr_constant(source) - && !physical_exprs_contains(&projected_constants, target) + && !const_exprs_contains(&projected_constants, target) { - projected_constants.push(target.clone()); + let across_partitions = self.is_expr_constant_accross_partitions(source); + // Expression evaluates to single value + projected_constants.push( + ConstExpr::from(target).with_across_partitions(across_partitions), + ); } } projected_constants @@ -782,16 +946,27 @@ impl EquivalenceProperties { let ordered_exprs = search_indices .iter() .flat_map(|&idx| { - let ExprOrdering { expr, data, .. } = - eq_properties.get_expr_ordering(exprs[idx].clone()); - match data { - SortProperties::Ordered(options) => { - Some((PhysicalSortExpr { expr, options }, idx)) - } + let ExprProperties { + sort_properties, .. + } = eq_properties.get_expr_properties(Arc::clone(&exprs[idx])); + match sort_properties { + SortProperties::Ordered(options) => Some(( + PhysicalSortExpr { + expr: Arc::clone(&exprs[idx]), + options, + }, + idx, + )), SortProperties::Singleton => { // Assign default ordering to constant expressions let options = SortOptions::default(); - Some((PhysicalSortExpr { expr, options }, idx)) + Some(( + PhysicalSortExpr { + expr: Arc::clone(&exprs[idx]), + options, + }, + idx, + )) } SortProperties::Unordered => None, } @@ -810,13 +985,14 @@ impl EquivalenceProperties { // an implementation strategy confined to this function. for (PhysicalSortExpr { expr, .. }, idx) in &ordered_exprs { eq_properties = - eq_properties.add_constants(std::iter::once(expr.clone())); + eq_properties.with_constants(std::iter::once(ConstExpr::from(expr))); search_indices.shift_remove(idx); } // Add new ordered section to the state. result.extend(ordered_exprs); } - result.into_iter().unzip() + let (left, right) = result.into_iter().unzip(); + (LexOrdering::new(left), right) } /// This function determines whether the provided expression is constant @@ -835,71 +1011,216 @@ impl EquivalenceProperties { // As an example, assume that we know columns `a` and `b` are constant. // Then, `a`, `b` and `a + b` will all return `true` whereas `c` will // return `false`. - let normalized_constants = self.eq_group.normalize_exprs(self.constants.to_vec()); - let normalized_expr = self.eq_group.normalize_expr(expr.clone()); + let const_exprs = self + .constants + .iter() + .map(|const_expr| Arc::clone(const_expr.expr())); + let normalized_constants = self.eq_group.normalize_exprs(const_exprs); + let normalized_expr = self.eq_group.normalize_expr(Arc::clone(expr)); + is_constant_recurse(&normalized_constants, &normalized_expr) + } + + /// This function determines whether the provided expression is constant + /// across partitions based on the known constants. + /// + /// # Arguments + /// + /// - `expr`: A reference to a `Arc` representing the + /// expression to be checked. + /// + /// # Returns + /// + /// Returns `true` if the expression is constant across all partitions according + /// to equivalence group, `false` otherwise. + pub fn is_expr_constant_accross_partitions( + &self, + expr: &Arc, + ) -> bool { + // As an example, assume that we know columns `a` and `b` are constant. + // Then, `a`, `b` and `a + b` will all return `true` whereas `c` will + // return `false`. + let const_exprs = self.constants.iter().flat_map(|const_expr| { + if const_expr.across_partitions() { + Some(Arc::clone(const_expr.expr())) + } else { + None + } + }); + let normalized_constants = self.eq_group.normalize_exprs(const_exprs); + let normalized_expr = self.eq_group.normalize_expr(Arc::clone(expr)); is_constant_recurse(&normalized_constants, &normalized_expr) } - /// Retrieves the ordering information for a given physical expression. + /// Retrieves the properties for a given physical expression. /// - /// This function constructs an `ExprOrdering` object for the provided + /// This function constructs an [`ExprProperties`] object for the given /// expression, which encapsulates information about the expression's - /// ordering, including its [`SortProperties`]. + /// properties, including its [`SortProperties`] and [`Interval`]. /// - /// # Arguments + /// # Parameters /// /// - `expr`: An `Arc` representing the physical expression /// for which ordering information is sought. /// /// # Returns /// - /// Returns an `ExprOrdering` object containing the ordering information for - /// the given expression. - pub fn get_expr_ordering(&self, expr: Arc) -> ExprOrdering { - ExprOrdering::new_default(expr.clone()) - .transform_up(|expr| Ok(update_ordering(expr, self))) + /// Returns an [`ExprProperties`] object containing the ordering and range + /// information for the given expression. + pub fn get_expr_properties(&self, expr: Arc) -> ExprProperties { + ExprPropertiesNode::new_unknown(expr) + .transform_up(|expr| update_properties(expr, self)) .data() - // Guaranteed to always return `Ok`. - .unwrap() + .map(|node| node.data) + .unwrap_or(ExprProperties::new_unknown()) + } + + /// Transforms this `EquivalenceProperties` into a new `EquivalenceProperties` + /// by mapping columns in the original schema to columns in the new schema + /// by index. + pub fn with_new_schema(self, schema: SchemaRef) -> Result { + // The new schema and the original schema is aligned when they have the + // same number of columns, and fields at the same index have the same + // type in both schemas. + let schemas_aligned = (self.schema.fields.len() == schema.fields.len()) + && self + .schema + .fields + .iter() + .zip(schema.fields.iter()) + .all(|(lhs, rhs)| lhs.data_type().eq(rhs.data_type())); + if !schemas_aligned { + // Rewriting equivalence properties in terms of new schema is not + // safe when schemas are not aligned: + return plan_err!( + "Cannot rewrite old_schema:{:?} with new schema: {:?}", + self.schema, + schema + ); + } + // Rewrite constants according to new schema: + let new_constants = self + .constants + .into_iter() + .map(|const_expr| { + let across_partitions = const_expr.across_partitions(); + let new_const_expr = with_new_schema(const_expr.owned_expr(), &schema)?; + Ok(ConstExpr::new(new_const_expr) + .with_across_partitions(across_partitions)) + }) + .collect::>>()?; + + // Rewrite orderings according to new schema: + let mut new_orderings = vec![]; + for ordering in self.oeq_class.orderings { + let new_ordering = ordering + .inner + .into_iter() + .map(|mut sort_expr| { + sort_expr.expr = with_new_schema(sort_expr.expr, &schema)?; + Ok(sort_expr) + }) + .collect::>()?; + new_orderings.push(new_ordering); + } + + // Rewrite equivalence classes according to the new schema: + let mut eq_classes = vec![]; + for eq_class in self.eq_group.classes { + let new_eq_exprs = eq_class + .into_vec() + .into_iter() + .map(|expr| with_new_schema(expr, &schema)) + .collect::>()?; + eq_classes.push(EquivalenceClass::new(new_eq_exprs)); + } + + // Construct the resulting equivalence properties: + let mut result = EquivalenceProperties::new(schema); + result.constants = new_constants; + result.add_new_orderings(new_orderings); + result.add_equivalence_group(EquivalenceGroup::new(eq_classes)); + + Ok(result) + } +} + +/// More readable display version of the `EquivalenceProperties`. +/// +/// Format: +/// ```text +/// order: [[a ASC, b ASC], [a ASC, c ASC]], eq: [[a = b], [a = c]], const: [a = 1] +/// ``` +impl Display for EquivalenceProperties { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.eq_group.is_empty() + && self.oeq_class.is_empty() + && self.constants.is_empty() + { + return write!(f, "No properties"); + } + if !self.oeq_class.is_empty() { + write!(f, "order: {}", self.oeq_class)?; + } + if !self.eq_group.is_empty() { + write!(f, ", eq: {}", self.eq_group)?; + } + if !self.constants.is_empty() { + write!(f, ", const: [{}]", ConstExpr::format_list(&self.constants))?; + } + Ok(()) } } -/// Calculates the [`SortProperties`] of a given [`ExprOrdering`] node. -/// The node can either be a leaf node, or an intermediate node: +/// Calculates the properties of a given [`ExprPropertiesNode`]. +/// +/// Order information can be retrieved as: /// - If it is a leaf node, we directly find the order of the node by looking -/// at the given sort expression and equivalence properties if it is a `Column` -/// leaf, or we mark it as unordered. In the case of a `Literal` leaf, we mark -/// it as singleton so that it can cooperate with all ordered columns. +/// at the given sort expression and equivalence properties if it is a `Column` +/// leaf, or we mark it as unordered. In the case of a `Literal` leaf, we mark +/// it as singleton so that it can cooperate with all ordered columns. +/// - If it is an intermediate node, the children states matter. Each `PhysicalExpr` +/// and operator has its own rules on how to propagate the children orderings. +/// However, before we engage in recursion, we check whether this intermediate +/// node directly matches with the sort expression. If there is a match, the +/// sort expression emerges at that node immediately, discarding the recursive +/// result coming from its children. +/// +/// Range information is calculated as: +/// - If it is a `Literal` node, we set the range as a point value. If it is a +/// `Column` node, we set the datatype of the range, but cannot give an interval +/// for the range, yet. /// - If it is an intermediate node, the children states matter. Each `PhysicalExpr` -/// and operator has its own rules on how to propagate the children orderings. -/// However, before we engage in recursion, we check whether this intermediate -/// node directly matches with the sort expression. If there is a match, the -/// sort expression emerges at that node immediately, discarding the recursive -/// result coming from its children. -fn update_ordering( - mut node: ExprOrdering, +/// and operator has its own rules on how to propagate the children range. +fn update_properties( + mut node: ExprPropertiesNode, eq_properties: &EquivalenceProperties, -) -> Transformed { - // We have a Column, which is one of the two possible leaf node types: - let normalized_expr = eq_properties.eq_group.normalize_expr(node.expr.clone()); +) -> Result> { + // First, try to gather the information from the children: + if !node.expr.children().is_empty() { + // We have an intermediate (non-leaf) node, account for its children: + let children_props = node.children.iter().map(|c| c.data.clone()).collect_vec(); + node.data = node.expr.get_properties(&children_props)?; + } else if node.expr.as_any().is::() { + // We have a Literal, which is one of the two possible leaf node types: + node.data = node.expr.get_properties(&[])?; + } else if node.expr.as_any().is::() { + // We have a Column, which is the other possible leaf node type: + node.data.range = + Interval::make_unbounded(&node.expr.data_type(eq_properties.schema())?)? + } + // Now, check what we know about orderings: + let normalized_expr = eq_properties + .eq_group + .normalize_expr(Arc::clone(&node.expr)); if eq_properties.is_expr_constant(&normalized_expr) { - node.data = SortProperties::Singleton; + node.data.sort_properties = SortProperties::Singleton; } else if let Some(options) = eq_properties .normalized_oeq_class() .get_options(&normalized_expr) { - node.data = SortProperties::Ordered(options); - } else if !node.expr.children().is_empty() { - // We have an intermediate (non-leaf) node, account for its children: - let children_orderings = node.children.iter().map(|c| c.data).collect_vec(); - node.data = node.expr.get_ordering(&children_orderings); - } else if node.expr.as_any().is::() { - // We have a Literal, which is the other possible leaf node type: - node.data = node.expr.get_ordering(&[]); - } else { - return Transformed::no(node); + node.data.sort_properties = SortProperties::Ordered(options); } - Transformed::yes(node) + Ok(Transformed::yes(node)) } /// This function determines whether the provided expression is constant @@ -973,10 +1294,10 @@ fn referred_dependencies( // Associate `PhysicalExpr`s with `PhysicalSortExpr`s that contain them: let mut expr_to_sort_exprs = IndexMap::::new(); for sort_expr in dependency_map - .keys() + .sort_exprs() .filter(|sort_expr| expr_refers(source, &sort_expr.expr)) { - let key = ExprWrapper(sort_expr.expr.clone()); + let key = ExprWrapper(Arc::clone(&sort_expr.expr)); expr_to_sort_exprs .entry(key) .or_default() @@ -986,10 +1307,16 @@ fn referred_dependencies( // Generate all valid dependencies for the source. For example, if the source // is `a + b` and the map is `[a -> (a ASC, a DESC), b -> (b ASC)]`, we get // `vec![HashSet(a ASC, b ASC), HashSet(a DESC, b ASC)]`. - expr_to_sort_exprs - .values() + let dependencies = expr_to_sort_exprs + .into_values() + .map(Dependencies::into_inner) + .collect::>(); + dependencies + .iter() .multi_cartesian_product() - .map(|referred_deps| referred_deps.into_iter().cloned().collect()) + .map(|referred_deps| { + Dependencies::new_from_iter(referred_deps.into_iter().cloned()) + }) .collect() } @@ -1011,21 +1338,32 @@ fn construct_prefix_orderings( relevant_sort_expr: &PhysicalSortExpr, dependency_map: &DependencyMap, ) -> Vec { - dependency_map[relevant_sort_expr] + let mut dep_enumerator = DependencyEnumerator::new(); + dependency_map + .get(relevant_sort_expr) + .expect("no relevant sort expr found") .dependencies .iter() - .flat_map(|dep| construct_orderings(dep, dependency_map)) + .flat_map(|dep| dep_enumerator.construct_orderings(dep, dependency_map)) .collect() } -/// Given a set of relevant dependencies (`relevant_deps`) and a map of dependencies -/// (`dependency_map`), this function generates all possible prefix orderings -/// based on the given dependencies. +/// Generates all possible orderings where dependencies are satisfied for the +/// current projection expression. +/// +/// # Example +/// If `dependences` is `a + b ASC` and the dependency map holds dependencies +/// * `a ASC` --> `[c ASC]` +/// * `b ASC` --> `[d DESC]`, +/// +/// This function generates these two sort orders +/// * `[c ASC, d DESC, a + b ASC]` +/// * `[d DESC, c ASC, a + b ASC]` /// /// # Parameters /// -/// * `dependencies` - A reference to the dependencies. -/// * `dependency_map` - A reference to the map of dependencies for expressions. +/// * `dependencies` - Set of relevant expressions. +/// * `dependency_map` - Map of dependencies for expressions that may appear in `dependencies` /// /// # Returns /// @@ -1048,14 +1386,9 @@ fn generate_dependency_orderings( // No dependency, dependent is a leading ordering. if relevant_prefixes.is_empty() { // Return an empty ordering: - return vec![vec![]]; + return vec![LexOrdering::default()]; } - // Generate all possible orderings where dependencies are satisfied for the - // current projection expression. For example, if expression is `a + b ASC`, - // and the dependency for `a ASC` is `[c ASC]`, the dependency for `b ASC` - // is `[d DESC]`, then we generate `[c ASC, d DESC, a + b ASC]` and - // `[d DESC, c ASC, a + b ASC]`. relevant_prefixes .into_iter() .multi_cartesian_product() @@ -1063,14 +1396,20 @@ fn generate_dependency_orderings( prefix_orderings .iter() .permutations(prefix_orderings.len()) - .map(|prefixes| prefixes.into_iter().flatten().cloned().collect()) + .map(|prefixes| { + prefixes + .into_iter() + .flat_map(|ordering| ordering.inner.clone()) + .collect() + }) .collect::>() }) .collect() } -/// This function examines the given expression and the sort expressions it -/// refers to determine the ordering properties of the expression. +/// This function examines the given expression and its properties to determine +/// the ordering properties of the expression. The range knowledge is not utilized +/// yet in the scope of this function. /// /// # Parameters /// @@ -1078,26 +1417,41 @@ fn generate_dependency_orderings( /// which ordering properties need to be determined. /// - `dependencies`: A reference to `Dependencies`, containing sort expressions /// referred to by `expr`. +/// - `schema``: A reference to the schema which the `expr` columns refer. /// /// # Returns /// /// A `SortProperties` indicating the ordering information of the given expression. -fn get_expr_ordering( +fn get_expr_properties( expr: &Arc, dependencies: &Dependencies, -) -> SortProperties { + schema: &SchemaRef, +) -> Result { if let Some(column_order) = dependencies.iter().find(|&order| expr.eq(&order.expr)) { // If exact match is found, return its ordering. - SortProperties::Ordered(column_order.options) + Ok(ExprProperties { + sort_properties: SortProperties::Ordered(column_order.options), + range: Interval::make_unbounded(&expr.data_type(schema)?)?, + }) + } else if expr.as_any().downcast_ref::().is_some() { + Ok(ExprProperties { + sort_properties: SortProperties::Unordered, + range: Interval::make_unbounded(&expr.data_type(schema)?)?, + }) + } else if let Some(literal) = expr.as_any().downcast_ref::() { + Ok(ExprProperties { + sort_properties: SortProperties::Singleton, + range: Interval::try_new(literal.value().clone(), literal.value().clone())?, + }) } else { // Find orderings of its children let child_states = expr .children() .iter() - .map(|child| get_expr_ordering(child, dependencies)) - .collect::>(); + .map(|child| get_expr_properties(child, dependencies, schema)) + .collect::>>()?; // Calculate expression ordering using ordering of its children. - expr.get_ordering(&child_states) + expr.get_properties(&child_states) } } @@ -1121,7 +1475,7 @@ struct DependencyNode { } impl DependencyNode { - // Insert dependency to the state (if exists). + /// Insert dependency to the state (if exists). fn insert_dependency(&mut self, dependency: Option<&PhysicalSortExpr>) { if let Some(dep) = dependency { self.dependencies.insert(dep.clone()); @@ -1129,46 +1483,229 @@ impl DependencyNode { } } -// Using `IndexMap` and `IndexSet` makes sure to generate consistent results across different executions for the same query. -// We could have used `HashSet`, `HashMap` in place of them without any loss of functionality. -// As an example, if existing orderings are `[a ASC, b ASC]`, `[c ASC]` for output ordering -// both `[a ASC, b ASC, c ASC]` and `[c ASC, a ASC, b ASC]` are valid (e.g. concatenated version of the alternative orderings). -// When using `HashSet`, `HashMap` it is not guaranteed to generate consistent result, among the possible 2 results in the example above. -type DependencyMap = IndexMap; -type Dependencies = IndexSet; - -/// This function recursively analyzes the dependencies of the given sort -/// expression within the given dependency map to construct lexicographical -/// orderings that include the sort expression and its dependencies. +impl Display for DependencyNode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if let Some(target) = &self.target_sort_expr { + write!(f, "(target: {}, ", target)?; + } else { + write!(f, "(")?; + } + write!(f, "dependencies: [{}])", self.dependencies) + } +} + +/// Maps an expression --> DependencyNode /// -/// # Parameters +/// # Debugging / deplaying `DependencyMap` /// -/// - `referred_sort_expr`: A reference to the sort expression (`PhysicalSortExpr`) -/// for which lexicographical orderings satisfying its dependencies are to be -/// constructed. -/// - `dependency_map`: A reference to the `DependencyMap` that contains -/// dependencies for different `PhysicalSortExpr`s. +/// This structure implements `Display` to assist debugging. For example: /// -/// # Returns +/// ```text +/// DependencyMap: { +/// a@0 ASC --> (target: a@0 ASC, dependencies: [[]]) +/// b@1 ASC --> (target: b@1 ASC, dependencies: [[a@0 ASC, c@2 ASC]]) +/// c@2 ASC --> (target: c@2 ASC, dependencies: [[b@1 ASC, a@0 ASC]]) +/// d@3 ASC --> (target: d@3 ASC, dependencies: [[c@2 ASC, b@1 ASC]]) +/// } +/// ``` /// -/// A vector of lexicographical orderings (`Vec`) based on the given -/// sort expression and its dependencies. -fn construct_orderings( - referred_sort_expr: &PhysicalSortExpr, - dependency_map: &DependencyMap, -) -> Vec { - // We are sure that `referred_sort_expr` is inside `dependency_map`. - let node = &dependency_map[referred_sort_expr]; - // Since we work on intermediate nodes, we are sure `val.target_sort_expr` - // exists. - let target_sort_expr = node.target_sort_expr.clone().unwrap(); - if node.dependencies.is_empty() { - vec![vec![target_sort_expr]] - } else { +/// # Note on IndexMap Rationale +/// +/// Using `IndexMap` (which preserves insert order) to ensure consistent results +/// across different executions for the same query. We could have used +/// `HashSet`, `HashMap` in place of them without any loss of functionality. +/// +/// As an example, if existing orderings are +/// 1. `[a ASC, b ASC]` +/// 2. `[c ASC]` for +/// +/// Then both the following output orderings are valid +/// 1. `[a ASC, b ASC, c ASC]` +/// 2. `[c ASC, a ASC, b ASC]` +/// +/// (this are both valid as they are concatenated versions of the alternative +/// orderings). When using `HashSet`, `HashMap` it is not guaranteed to generate +/// consistent result, among the possible 2 results in the example above. +#[derive(Debug)] +struct DependencyMap { + inner: IndexMap, +} + +impl DependencyMap { + fn new() -> Self { + Self { + inner: IndexMap::new(), + } + } + + /// Insert a new dependency `sort_expr` --> `dependency` into the map. + /// + /// If `target_sort_expr` is none, a new entry is created with empty dependencies. + fn insert( + &mut self, + sort_expr: &PhysicalSortExpr, + target_sort_expr: Option<&PhysicalSortExpr>, + dependency: Option<&PhysicalSortExpr>, + ) { + self.inner + .entry(sort_expr.clone()) + .or_insert_with(|| DependencyNode { + target_sort_expr: target_sort_expr.cloned(), + dependencies: Dependencies::new(), + }) + .insert_dependency(dependency) + } + + /// Iterator over (sort_expr, DependencyNode) pairs + fn iter(&self) -> impl Iterator { + self.inner.iter() + } + + /// iterator over all sort exprs + fn sort_exprs(&self) -> impl Iterator { + self.inner.keys() + } + + /// Return the dependency node for the given sort expression, if any + fn get(&self, sort_expr: &PhysicalSortExpr) -> Option<&DependencyNode> { + self.inner.get(sort_expr) + } +} + +impl Display for DependencyMap { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "DependencyMap: {{")?; + for (sort_expr, node) in self.inner.iter() { + writeln!(f, " {sort_expr} --> {node}")?; + } + writeln!(f, "}}") + } +} + +/// A list of sort expressions that can be calculated from a known set of +/// dependencies. +#[derive(Debug, Default, Clone, PartialEq, Eq)] +struct Dependencies { + inner: IndexSet, +} + +impl Display for Dependencies { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "[")?; + let mut iter = self.inner.iter(); + if let Some(dep) = iter.next() { + write!(f, "{}", dep)?; + } + for dep in iter { + write!(f, ", {}", dep)?; + } + write!(f, "]") + } +} + +impl Dependencies { + /// Create a new empty `Dependencies` instance. + fn new() -> Self { + Self { + inner: IndexSet::new(), + } + } + + /// Create a new `Dependencies` from an iterator of `PhysicalSortExpr`. + fn new_from_iter(iter: impl IntoIterator) -> Self { + Self { + inner: iter.into_iter().collect(), + } + } + + /// Insert a new dependency into the set. + fn insert(&mut self, sort_expr: PhysicalSortExpr) { + self.inner.insert(sort_expr); + } + + /// Iterator over dependencies in the set + fn iter(&self) -> impl Iterator + Clone { + self.inner.iter() + } + + /// Return the inner set of dependencies + fn into_inner(self) -> IndexSet { + self.inner + } + + /// Returns true if there are no dependencies + fn is_empty(&self) -> bool { + self.inner.is_empty() + } +} + +/// Contains a mapping of all dependencies we have processed for each sort expr +struct DependencyEnumerator<'a> { + /// Maps `expr` --> `[exprs]` that have previously been processed + seen: IndexMap<&'a PhysicalSortExpr, IndexSet<&'a PhysicalSortExpr>>, +} + +impl<'a> DependencyEnumerator<'a> { + fn new() -> Self { + Self { + seen: IndexMap::new(), + } + } + + /// Insert a new dependency, + /// + /// returns false if the dependency was already in the map + /// returns true if the dependency was newly inserted + fn insert( + &mut self, + target: &'a PhysicalSortExpr, + dep: &'a PhysicalSortExpr, + ) -> bool { + self.seen.entry(target).or_default().insert(dep) + } + + /// This function recursively analyzes the dependencies of the given sort + /// expression within the given dependency map to construct lexicographical + /// orderings that include the sort expression and its dependencies. + /// + /// # Parameters + /// + /// - `referred_sort_expr`: A reference to the sort expression (`PhysicalSortExpr`) + /// for which lexicographical orderings satisfying its dependencies are to be + /// constructed. + /// - `dependency_map`: A reference to the `DependencyMap` that contains + /// dependencies for different `PhysicalSortExpr`s. + /// + /// # Returns + /// + /// A vector of lexicographical orderings (`Vec`) based on the given + /// sort expression and its dependencies. + fn construct_orderings( + &mut self, + referred_sort_expr: &'a PhysicalSortExpr, + dependency_map: &'a DependencyMap, + ) -> Vec { + let node = dependency_map + .get(referred_sort_expr) + .expect("`referred_sort_expr` should be inside `dependency_map`"); + // Since we work on intermediate nodes, we are sure `val.target_sort_expr` + // exists. + let target_sort_expr = node.target_sort_expr.as_ref().unwrap(); + // An empty dependency means the referred_sort_expr represents a global ordering. + // Return its projected version, which is the target_expression. + if node.dependencies.is_empty() { + return vec![LexOrdering::new(vec![target_sort_expr.clone()])]; + }; + node.dependencies .iter() .flat_map(|dep| { - let mut orderings = construct_orderings(dep, dependency_map); + let mut orderings = if self.insert(target_sort_expr, dep) { + self.construct_orderings(dep, dependency_map) + } else { + vec![] + }; + for ordering in orderings.iter_mut() { ordering.push(target_sort_expr.clone()) } @@ -1197,8 +1734,16 @@ pub fn join_equivalence_properties( on, )); - let left_oeq_class = left.oeq_class; - let mut right_oeq_class = right.oeq_class; + let EquivalenceProperties { + constants: left_constants, + oeq_class: left_oeq_class, + .. + } = left; + let EquivalenceProperties { + constants: right_constants, + oeq_class: mut right_oeq_class, + .. + } = right; match maintains_input_order { [true, false] => { // In this special case, right side ordering can be prefixed with @@ -1251,6 +1796,15 @@ pub fn join_equivalence_properties( [true, true] => unreachable!("Cannot maintain ordering of both sides"), _ => unreachable!("Join operators can not have more than two children"), } + match join_type { + JoinType::LeftAnti | JoinType::LeftSemi => { + result = result.with_constants(left_constants); + } + JoinType::RightAnti | JoinType::RightSemi => { + result = result.with_constants(right_constants); + } + _ => {} + } result } @@ -1292,26 +1846,311 @@ impl Hash for ExprWrapper { } } -#[cfg(test)] -mod tests { - use std::ops::Not; +/// Calculates the union (in the sense of `UnionExec`) `EquivalenceProperties` +/// of `lhs` and `rhs` according to the schema of `lhs`. +/// +/// Rules: The UnionExec does not interleave its inputs: instead it passes each +/// input partition from the children as its own output. +/// +/// Since the output equivalence properties are properties that are true for +/// *all* output partitions, that is the same as being true for all *input* +/// partitions +fn calculate_union_binary( + mut lhs: EquivalenceProperties, + mut rhs: EquivalenceProperties, +) -> Result { + // Harmonize the schema of the rhs with the schema of the lhs (which is the accumulator schema): + if !rhs.schema.eq(&lhs.schema) { + rhs = rhs.with_new_schema(Arc::clone(&lhs.schema))?; + } - use arrow::datatypes::{DataType, Field, Schema}; - use arrow_schema::{Fields, TimeUnit}; + // First, calculate valid constants for the union. An expression is constant + // at the output of the union if it is constant in both sides. + let constants: Vec<_> = lhs + .constants() + .iter() + .filter(|const_expr| const_exprs_contains(rhs.constants(), const_expr.expr())) + .map(|const_expr| { + // TODO: When both sides have a constant column, and the actual + // constant value is the same, then the output properties could + // reflect the constant is valid across all partitions. However we + // don't track the actual value that the ConstExpr takes on, so we + // can't determine that yet + ConstExpr::new(Arc::clone(const_expr.expr())).with_across_partitions(false) + }) + .collect(); - use datafusion_common::DFSchema; - use datafusion_expr::{Operator, ScalarUDF}; + // remove any constants that are shared in both outputs (avoid double counting them) + for c in &constants { + lhs = lhs.remove_constant(c); + rhs = rhs.remove_constant(c); + } - use crate::equivalence::add_offset_to_expr; - use crate::equivalence::tests::{ - convert_to_orderings, convert_to_sort_exprs, convert_to_sort_reqs, - create_random_schema, create_test_params, create_test_schema, - generate_table_for_eq_properties, is_table_same_after_sort, output_schema, - }; - use crate::expressions::{col, BinaryExpr, Column}; - use crate::utils::tests::TestScalarUDF; + // Next, calculate valid orderings for the union by searching for prefixes + // in both sides. + let mut orderings = UnionEquivalentOrderingBuilder::new(); + orderings.add_satisfied_orderings( + lhs.normalized_oeq_class().orderings, + lhs.constants(), + &rhs, + ); + orderings.add_satisfied_orderings( + rhs.normalized_oeq_class().orderings, + rhs.constants(), + &lhs, + ); + let orderings = orderings.build(); + + let mut eq_properties = + EquivalenceProperties::new(lhs.schema).with_constants(constants); + + eq_properties.add_new_orderings(orderings); + Ok(eq_properties) +} + +/// Calculates the union (in the sense of `UnionExec`) `EquivalenceProperties` +/// of the given `EquivalenceProperties` in `eqps` according to the given +/// output `schema` (which need not be the same with those of `lhs` and `rhs` +/// as details such as nullability may be different). +pub fn calculate_union( + eqps: Vec, + schema: SchemaRef, +) -> Result { + // TODO: In some cases, we should be able to preserve some equivalence + // classes. Add support for such cases. + let mut iter = eqps.into_iter(); + let Some(mut acc) = iter.next() else { + return internal_err!( + "Cannot calculate EquivalenceProperties for a union with no inputs" + ); + }; + + // Harmonize the schema of the init with the schema of the union: + if !acc.schema.eq(&schema) { + acc = acc.with_new_schema(schema)?; + } + // Fold in the rest of the EquivalenceProperties: + for props in iter { + acc = calculate_union_binary(acc, props)?; + } + Ok(acc) +} + +#[derive(Debug)] +enum AddedOrdering { + /// The ordering was added to the in progress result + Yes, + /// The ordering was not added + No(LexOrdering), +} + +/// Builds valid output orderings of a `UnionExec` +#[derive(Debug)] +struct UnionEquivalentOrderingBuilder { + orderings: Vec, +} + +impl UnionEquivalentOrderingBuilder { + fn new() -> Self { + Self { orderings: vec![] } + } + + /// Add all orderings from `orderings` that satisfy `properties`, + /// potentially augmented with`constants`. + /// + /// Note: any column that is known to be constant can be inserted into the + /// ordering without changing its meaning + /// + /// For example: + /// * `orderings` contains `[a ASC, c ASC]` and `constants` contains `b` + /// * `properties` has required ordering `[a ASC, b ASC]` + /// + /// Then this will add `[a ASC, b ASC]` to the `orderings` list (as `a` was + /// in the sort order and `b` was a constant). + fn add_satisfied_orderings( + &mut self, + orderings: impl IntoIterator, + constants: &[ConstExpr], + properties: &EquivalenceProperties, + ) { + for mut ordering in orderings.into_iter() { + // Progressively shorten the ordering to search for a satisfied prefix: + loop { + match self.try_add_ordering(ordering, constants, properties) { + AddedOrdering::Yes => break, + AddedOrdering::No(o) => { + ordering = o; + ordering.pop(); + } + } + } + } + } + + /// Adds `ordering`, potentially augmented with constants, if it satisfies + /// the target `properties` properties. + /// + /// Returns + /// + /// * [`AddedOrdering::Yes`] if the ordering was added (either directly or + /// augmented), or was empty. + /// + /// * [`AddedOrdering::No`] if the ordering was not added + fn try_add_ordering( + &mut self, + ordering: LexOrdering, + constants: &[ConstExpr], + properties: &EquivalenceProperties, + ) -> AddedOrdering { + if ordering.is_empty() { + AddedOrdering::Yes + } else if constants.is_empty() && properties.ordering_satisfy(ordering.as_ref()) { + // If the ordering satisfies the target properties, no need to + // augment it with constants. + self.orderings.push(ordering); + AddedOrdering::Yes + } else { + // Did not satisfy target properties, try and augment with constants + // to match the properties + if self.try_find_augmented_ordering(&ordering, constants, properties) { + AddedOrdering::Yes + } else { + AddedOrdering::No(ordering) + } + } + } + + /// Attempts to add `constants` to `ordering` to satisfy the properties. + /// + /// returns true if any orderings were added, false otherwise + fn try_find_augmented_ordering( + &mut self, + ordering: &LexOrdering, + constants: &[ConstExpr], + properties: &EquivalenceProperties, + ) -> bool { + // can't augment if there is nothing to augment with + if constants.is_empty() { + return false; + } + let start_num_orderings = self.orderings.len(); + + // for each equivalent ordering in properties, try and augment + // `ordering` it with the constants to match + for existing_ordering in &properties.oeq_class.orderings { + if let Some(augmented_ordering) = self.augment_ordering( + ordering, + constants, + existing_ordering, + &properties.constants, + ) { + if !augmented_ordering.is_empty() { + assert!(properties.ordering_satisfy(augmented_ordering.as_ref())); + self.orderings.push(augmented_ordering); + } + } + } + + self.orderings.len() > start_num_orderings + } + + /// Attempts to augment the ordering with constants to match the + /// `existing_ordering` + /// + /// Returns Some(ordering) if an augmented ordering was found, None otherwise + fn augment_ordering( + &mut self, + ordering: &LexOrdering, + constants: &[ConstExpr], + existing_ordering: &LexOrdering, + existing_constants: &[ConstExpr], + ) -> Option { + let mut augmented_ordering = LexOrdering::default(); + let mut sort_expr_iter = ordering.inner.iter().peekable(); + let mut existing_sort_expr_iter = existing_ordering.inner.iter().peekable(); + + // walk in parallel down the two orderings, trying to match them up + while sort_expr_iter.peek().is_some() || existing_sort_expr_iter.peek().is_some() + { + // If the next expressions are equal, add the next match + // otherwise try and match with a constant + if let Some(expr) = + advance_if_match(&mut sort_expr_iter, &mut existing_sort_expr_iter) + { + augmented_ordering.push(expr); + } else if let Some(expr) = + advance_if_matches_constant(&mut sort_expr_iter, existing_constants) + { + augmented_ordering.push(expr); + } else if let Some(expr) = + advance_if_matches_constant(&mut existing_sort_expr_iter, constants) + { + augmented_ordering.push(expr); + } else { + // no match, can't continue the ordering, return what we have + break; + } + } + + Some(augmented_ordering) + } + + fn build(self) -> Vec { + self.orderings + } +} + +/// Advances two iterators in parallel +/// +/// If the next expressions are equal, the iterators are advanced and returns +/// the matched expression . +/// +/// Otherwise, the iterators are left unchanged and return `None` +fn advance_if_match( + iter1: &mut Peekable>, + iter2: &mut Peekable>, +) -> Option { + if matches!((iter1.peek(), iter2.peek()), (Some(expr1), Some(expr2)) if expr1.eq(expr2)) + { + iter1.next().unwrap(); + iter2.next().cloned() + } else { + None + } +} + +/// Advances the iterator with a constant +/// +/// If the next expression matches one of the constants, advances the iterator +/// returning the matched expression +/// +/// Otherwise, the iterator is left unchanged and returns `None` +fn advance_if_matches_constant( + iter: &mut Peekable>, + constants: &[ConstExpr], +) -> Option { + let expr = iter.peek()?; + let const_expr = constants.iter().find(|c| c.eq_expr(expr))?; + let found_expr = PhysicalSortExpr::new(Arc::clone(const_expr.expr()), expr.options); + iter.next(); + Some(found_expr) +} + +#[cfg(test)] +mod tests { + use std::ops::Not; use super::*; + use crate::equivalence::add_offset_to_expr; + use crate::equivalence::tests::{ + convert_to_orderings, convert_to_sort_exprs, convert_to_sort_reqs, + create_test_params, create_test_schema, output_schema, + }; + use crate::expressions::{col, BinaryExpr, Column}; + + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::{Fields, TimeUnit}; + use datafusion_expr::Operator; #[test] fn project_equivalence_properties_test() -> Result<()> { @@ -1321,25 +2160,25 @@ mod tests { Field::new("c", DataType::Int64, true), ])); - let input_properties = EquivalenceProperties::new(input_schema.clone()); + let input_properties = EquivalenceProperties::new(Arc::clone(&input_schema)); let col_a = col("a", &input_schema)?; // a as a1, a as a2, a as a3, a as a3 let proj_exprs = vec![ - (col_a.clone(), "a1".to_string()), - (col_a.clone(), "a2".to_string()), - (col_a.clone(), "a3".to_string()), - (col_a.clone(), "a4".to_string()), + (Arc::clone(&col_a), "a1".to_string()), + (Arc::clone(&col_a), "a2".to_string()), + (Arc::clone(&col_a), "a3".to_string()), + (Arc::clone(&col_a), "a4".to_string()), ]; let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; let out_schema = output_schema(&projection_mapping, &input_schema)?; // a as a1, a as a2, a as a3, a as a3 let proj_exprs = vec![ - (col_a.clone(), "a1".to_string()), - (col_a.clone(), "a2".to_string()), - (col_a.clone(), "a3".to_string()), - (col_a.clone(), "a4".to_string()), + (Arc::clone(&col_a), "a1".to_string()), + (Arc::clone(&col_a), "a2".to_string()), + (Arc::clone(&col_a), "a3".to_string()), + (Arc::clone(&col_a), "a4".to_string()), ]; let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; @@ -1362,6 +2201,51 @@ mod tests { Ok(()) } + #[test] + fn project_equivalence_properties_test_multi() -> Result<()> { + // test multiple input orderings with equivalence properties + let input_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + Field::new("c", DataType::Int64, true), + Field::new("d", DataType::Int64, true), + ])); + + let mut input_properties = EquivalenceProperties::new(Arc::clone(&input_schema)); + // add equivalent ordering [a, b, c, d] + input_properties.add_new_ordering(LexOrdering::new(vec![ + parse_sort_expr("a", &input_schema), + parse_sort_expr("b", &input_schema), + parse_sort_expr("c", &input_schema), + parse_sort_expr("d", &input_schema), + ])); + + // add equivalent ordering [a, c, b, d] + input_properties.add_new_ordering(LexOrdering::new(vec![ + parse_sort_expr("a", &input_schema), + parse_sort_expr("c", &input_schema), + parse_sort_expr("b", &input_schema), // NB b and c are swapped + parse_sort_expr("d", &input_schema), + ])); + + // simply project all the columns in order + let proj_exprs = vec![ + (col("a", &input_schema)?, "a".to_string()), + (col("b", &input_schema)?, "b".to_string()), + (col("c", &input_schema)?, "c".to_string()), + (col("d", &input_schema)?, "d".to_string()), + ]; + let projection_mapping = ProjectionMapping::try_new(&proj_exprs, &input_schema)?; + let out_properties = input_properties.project(&projection_mapping, input_schema); + + assert_eq!( + out_properties.to_string(), + "order: [[a@0 ASC, c@2 ASC, b@1 ASC, d@3 ASC], [a@0 ASC, b@1 ASC, c@2 ASC, d@3 ASC]]" + ); + + Ok(()) + } + #[test] fn test_join_equivalence_properties() -> Result<()> { let schema = create_test_schema()?; @@ -1369,8 +2253,8 @@ mod tests { let col_b = &col("b", &schema)?; let col_c = &col("c", &schema)?; let offset = schema.fields.len(); - let col_a2 = &add_offset_to_expr(col_a.clone(), offset); - let col_b2 = &add_offset_to_expr(col_b.clone(), offset); + let col_a2 = &add_offset_to_expr(Arc::clone(col_a), offset); + let col_b2 = &add_offset_to_expr(Arc::clone(col_b), offset); let option_asc = SortOptions { descending: false, nulls_first: false, @@ -1414,8 +2298,8 @@ mod tests { ), ]; for (left_orderings, right_orderings, expected) in test_cases { - let mut left_eq_properties = EquivalenceProperties::new(schema.clone()); - let mut right_eq_properties = EquivalenceProperties::new(schema.clone()); + let mut left_eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); + let mut right_eq_properties = EquivalenceProperties::new(Arc::clone(&schema)); let left_orderings = convert_to_orderings(&left_orderings); let right_orderings = convert_to_orderings(&right_orderings); let expected = convert_to_orderings(&expected); @@ -1463,17 +2347,17 @@ mod tests { let col_b = col("b", &schema)?; let col_d = col("d", &schema)?; let b_plus_d = Arc::new(BinaryExpr::new( - col_b.clone(), + Arc::clone(&col_b), Operator::Plus, - col_d.clone(), + Arc::clone(&col_d), )) as Arc; - let constants = vec![col_a.clone(), col_b.clone()]; - let expr = b_plus_d.clone(); + let constants = vec![Arc::clone(&col_a), Arc::clone(&col_b)]; + let expr = Arc::clone(&b_plus_d); assert!(!is_constant_recurse(&constants, &expr)); - let constants = vec![col_a.clone(), col_b.clone(), col_d.clone()]; - let expr = b_plus_d.clone(); + let constants = vec![Arc::clone(&col_a), Arc::clone(&col_b), Arc::clone(&col_d)]; + let expr = Arc::clone(&b_plus_d); assert!(is_constant_recurse(&constants, &expr)); Ok(()) } @@ -1522,8 +2406,8 @@ mod tests { let mut join_eq_properties = EquivalenceProperties::new(Arc::new(schema)); // a=x and d=w - join_eq_properties.add_equal_conditions(col_a, col_x); - join_eq_properties.add_equal_conditions(col_d, col_w); + join_eq_properties.add_equal_conditions(col_a, col_x)?; + join_eq_properties.add_equal_conditions(col_d, col_w)?; updated_right_ordering_equivalence_class( &mut right_oeq_class, @@ -1560,29 +2444,29 @@ mod tests { let col_c_expr = col("c", &schema)?; let mut eq_properties = EquivalenceProperties::new(Arc::new(schema.clone())); - eq_properties.add_equal_conditions(&col_a_expr, &col_c_expr); + eq_properties.add_equal_conditions(&col_a_expr, &col_c_expr)?; let others = vec![ - vec![PhysicalSortExpr { - expr: col_b_expr.clone(), + LexOrdering::new(vec![PhysicalSortExpr { + expr: Arc::clone(&col_b_expr), options: sort_options, - }], - vec![PhysicalSortExpr { - expr: col_c_expr.clone(), + }]), + LexOrdering::new(vec![PhysicalSortExpr { + expr: Arc::clone(&col_c_expr), options: sort_options, - }], + }]), ]; eq_properties.add_new_orderings(others); let mut expected_eqs = EquivalenceProperties::new(Arc::new(schema)); expected_eqs.add_new_orderings([ - vec![PhysicalSortExpr { - expr: col_b_expr.clone(), + LexOrdering::new(vec![PhysicalSortExpr { + expr: Arc::clone(&col_b_expr), options: sort_options, - }], - vec![PhysicalSortExpr { - expr: col_c_expr.clone(), + }]), + LexOrdering::new(vec![PhysicalSortExpr { + expr: Arc::clone(&col_c_expr), options: sort_options, - }], + }]), ]); let oeq_class = eq_properties.oeq_class().clone(); @@ -1603,9 +2487,9 @@ mod tests { ]); let col_a = &col("a", &schema)?; let col_b = &col("b", &schema)?; - let required_columns = [col_b.clone(), col_a.clone()]; + let required_columns = [Arc::clone(col_b), Arc::clone(col_a)]; let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); - eq_properties.add_new_orderings([vec![ + eq_properties.add_new_orderings([LexOrdering::new(vec![ PhysicalSortExpr { expr: Arc::new(Column::new("b", 1)), options: sort_options_not, @@ -1614,21 +2498,21 @@ mod tests { expr: Arc::new(Column::new("a", 0)), options: sort_options, }, - ]]); + ])]); let (result, idxs) = eq_properties.find_longest_permutation(&required_columns); assert_eq!(idxs, vec![0, 1]); assert_eq!( result, - vec![ + LexOrdering::new(vec![ PhysicalSortExpr { - expr: col_b.clone(), + expr: Arc::clone(col_b), options: sort_options_not }, PhysicalSortExpr { - expr: col_a.clone(), + expr: Arc::clone(col_a), options: sort_options } - ] + ]) ); let schema = Schema::new(vec![ @@ -1638,14 +2522,14 @@ mod tests { ]); let col_a = &col("a", &schema)?; let col_b = &col("b", &schema)?; - let required_columns = [col_b.clone(), col_a.clone()]; + let required_columns = [Arc::clone(col_b), Arc::clone(col_a)]; let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); eq_properties.add_new_orderings([ - vec![PhysicalSortExpr { + LexOrdering::new(vec![PhysicalSortExpr { expr: Arc::new(Column::new("c", 2)), options: sort_options, - }], - vec![ + }]), + LexOrdering::new(vec![ PhysicalSortExpr { expr: Arc::new(Column::new("b", 1)), options: sort_options_not, @@ -1654,22 +2538,22 @@ mod tests { expr: Arc::new(Column::new("a", 0)), options: sort_options, }, - ], + ]), ]); let (result, idxs) = eq_properties.find_longest_permutation(&required_columns); assert_eq!(idxs, vec![0, 1]); assert_eq!( result, - vec![ + LexOrdering::new(vec![ PhysicalSortExpr { - expr: col_b.clone(), + expr: Arc::clone(col_b), options: sort_options_not }, PhysicalSortExpr { - expr: col_a.clone(), + expr: Arc::clone(col_a), options: sort_options } - ] + ]) ); let required_columns = [ @@ -1684,7 +2568,7 @@ mod tests { let mut eq_properties = EquivalenceProperties::new(Arc::new(schema)); // not satisfied orders - eq_properties.add_new_orderings([vec![ + eq_properties.add_new_orderings([LexOrdering::new(vec![ PhysicalSortExpr { expr: Arc::new(Column::new("b", 1)), options: sort_options_not, @@ -1697,7 +2581,7 @@ mod tests { expr: Arc::new(Column::new("a", 0)), options: sort_options, }, - ]]); + ])]); let (_, idxs) = eq_properties.find_longest_permutation(&required_columns); assert_eq!(idxs, vec![0]); @@ -1705,7 +2589,7 @@ mod tests { } #[test] - fn test_update_ordering() -> Result<()> { + fn test_update_properties() -> Result<()> { let schema = Schema::new(vec![ Field::new("a", DataType::Int32, true), Field::new("b", DataType::Int32, true), @@ -1723,39 +2607,39 @@ mod tests { nulls_first: false, }; // b=a (e.g they are aliases) - eq_properties.add_equal_conditions(col_b, col_a); + eq_properties.add_equal_conditions(col_b, col_a)?; // [b ASC], [d ASC] eq_properties.add_new_orderings(vec![ - vec![PhysicalSortExpr { - expr: col_b.clone(), + LexOrdering::new(vec![PhysicalSortExpr { + expr: Arc::clone(col_b), options: option_asc, - }], - vec![PhysicalSortExpr { - expr: col_d.clone(), + }]), + LexOrdering::new(vec![PhysicalSortExpr { + expr: Arc::clone(col_d), options: option_asc, - }], + }]), ]); let test_cases = vec![ // d + b ( Arc::new(BinaryExpr::new( - col_d.clone(), + Arc::clone(col_d), Operator::Plus, - col_b.clone(), + Arc::clone(col_b), )) as Arc, SortProperties::Ordered(option_asc), ), // b - (col_b.clone(), SortProperties::Ordered(option_asc)), + (Arc::clone(col_b), SortProperties::Ordered(option_asc)), // a - (col_a.clone(), SortProperties::Ordered(option_asc)), + (Arc::clone(col_a), SortProperties::Ordered(option_asc)), // a + c ( Arc::new(BinaryExpr::new( - col_a.clone(), + Arc::clone(col_a), Operator::Plus, - col_c.clone(), + Arc::clone(col_c), )), SortProperties::Unordered, ), @@ -1764,96 +2648,19 @@ mod tests { let leading_orderings = eq_properties .oeq_class() .iter() - .flat_map(|ordering| ordering.first().cloned()) + .flat_map(|ordering| ordering.inner.first().cloned()) .collect::>(); - let expr_ordering = eq_properties.get_expr_ordering(expr.clone()); + let expr_props = eq_properties.get_expr_properties(Arc::clone(&expr)); let err_msg = format!( "expr:{:?}, expected: {:?}, actual: {:?}, leading_orderings: {leading_orderings:?}", - expr, expected, expr_ordering.data + expr, expected, expr_props.sort_properties ); - assert_eq!(expr_ordering.data, expected, "{}", err_msg); + assert_eq!(expr_props.sort_properties, expected, "{}", err_msg); } Ok(()) } - #[test] - fn test_find_longest_permutation_random() -> Result<()> { - const N_RANDOM_SCHEMA: usize = 100; - const N_ELEMENTS: usize = 125; - const N_DISTINCT: usize = 5; - - for seed in 0..N_RANDOM_SCHEMA { - // Create a random schema with random properties - let (test_schema, eq_properties) = create_random_schema(seed as u64)?; - // Generate a data that satisfies properties given - let table_data_with_properties = - generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?; - - let test_fun = ScalarUDF::new_from_impl(TestScalarUDF::new()); - let floor_a = crate::udf::create_physical_expr( - &test_fun, - &[col("a", &test_schema)?], - &test_schema, - &[], - &DFSchema::empty(), - )?; - let a_plus_b = Arc::new(BinaryExpr::new( - col("a", &test_schema)?, - Operator::Plus, - col("b", &test_schema)?, - )) as Arc; - let exprs = [ - col("a", &test_schema)?, - col("b", &test_schema)?, - col("c", &test_schema)?, - col("d", &test_schema)?, - col("e", &test_schema)?, - col("f", &test_schema)?, - floor_a, - a_plus_b, - ]; - - for n_req in 0..=exprs.len() { - for exprs in exprs.iter().combinations(n_req) { - let exprs = exprs.into_iter().cloned().collect::>(); - let (ordering, indices) = - eq_properties.find_longest_permutation(&exprs); - // Make sure that find_longest_permutation return values are consistent - let ordering2 = indices - .iter() - .zip(ordering.iter()) - .map(|(&idx, sort_expr)| PhysicalSortExpr { - expr: exprs[idx].clone(), - options: sort_expr.options, - }) - .collect::>(); - assert_eq!( - ordering, ordering2, - "indices and lexicographical ordering do not match" - ); - - let err_msg = format!( - "Error in test case ordering:{:?}, eq_properties.oeq_class: {:?}, eq_properties.eq_group: {:?}, eq_properties.constants: {:?}", - ordering, eq_properties.oeq_class, eq_properties.eq_group, eq_properties.constants - ); - assert_eq!(ordering.len(), indices.len(), "{}", err_msg); - // Since ordered section satisfies schema, we expect - // that result will be same after sort (e.g sort was unnecessary). - assert!( - is_table_same_after_sort( - ordering.clone(), - table_data_with_properties.clone(), - )?, - "{}", - err_msg - ); - } - } - } - - Ok(()) - } #[test] fn test_find_longest_permutation() -> Result<()> { // Schema satisfies following orderings: @@ -1871,9 +2678,9 @@ mod tests { let col_h = &col("h", &test_schema)?; // a + d let a_plus_d = Arc::new(BinaryExpr::new( - col_a.clone(), + Arc::clone(col_a), Operator::Plus, - col_d.clone(), + Arc::clone(col_d), )) as Arc; let option_asc = SortOptions { @@ -1885,16 +2692,16 @@ mod tests { nulls_first: true, }; // [d ASC, h DESC] also satisfies schema. - eq_properties.add_new_orderings([vec![ + eq_properties.add_new_orderings([LexOrdering::new(vec![ PhysicalSortExpr { - expr: col_d.clone(), + expr: Arc::clone(col_d), options: option_asc, }, PhysicalSortExpr { - expr: col_h.clone(), + expr: Arc::clone(col_h), options: option_desc, }, - ]]); + ])]); let test_cases = vec![ // TEST CASE 1 (vec![col_a], vec![(col_a, option_asc)]), @@ -1980,7 +2787,7 @@ mod tests { let col_h = &col("h", &test_schema)?; // Add column h as constant - eq_properties = eq_properties.add_constants(vec![col_h.clone()]); + eq_properties = eq_properties.with_constants(vec![ConstExpr::from(col_h)]); let test_cases = vec![ // TEST CASE 1 @@ -1998,50 +2805,6 @@ mod tests { Ok(()) } - #[test] - fn test_get_meet_ordering() -> Result<()> { - let schema = create_test_schema()?; - let col_a = &col("a", &schema)?; - let col_b = &col("b", &schema)?; - let eq_properties = EquivalenceProperties::new(schema); - let option_asc = SortOptions { - descending: false, - nulls_first: false, - }; - let option_desc = SortOptions { - descending: true, - nulls_first: true, - }; - let tests_cases = vec![ - // Get meet ordering between [a ASC] and [a ASC, b ASC] - // result should be [a ASC] - ( - vec![(col_a, option_asc)], - vec![(col_a, option_asc), (col_b, option_asc)], - Some(vec![(col_a, option_asc)]), - ), - // Get meet ordering between [a ASC] and [a DESC] - // result should be None. - (vec![(col_a, option_asc)], vec![(col_a, option_desc)], None), - // Get meet ordering between [a ASC, b ASC] and [a ASC, b DESC] - // result should be [a ASC]. - ( - vec![(col_a, option_asc), (col_b, option_asc)], - vec![(col_a, option_asc), (col_b, option_desc)], - Some(vec![(col_a, option_asc)]), - ), - ]; - for (lhs, rhs, expected) in tests_cases { - let lhs = convert_to_sort_exprs(&lhs); - let rhs = convert_to_sort_exprs(&rhs); - let expected = expected.map(|expected| convert_to_sort_exprs(&expected)); - let finer = eq_properties.get_meet_ordering(&lhs, &rhs); - assert_eq!(finer, expected) - } - - Ok(()) - } - #[test] fn test_get_finer() -> Result<()> { let schema = create_test_schema()?; @@ -2211,6 +2974,7 @@ mod tests { Ok(()) } + #[test] fn test_eliminate_redundant_monotonic_sorts() -> Result<()> { let schema = Arc::new(Schema::new(vec![ @@ -2218,20 +2982,21 @@ mod tests { Field::new("b", DataType::Utf8, true), Field::new("c", DataType::Timestamp(TimeUnit::Nanosecond, None), true), ])); - let base_properties = EquivalenceProperties::new(schema.clone()).with_reorder( - ["a", "b", "c"] - .into_iter() - .map(|c| { - col(c, schema.as_ref()).map(|expr| PhysicalSortExpr { - expr, - options: SortOptions { - descending: false, - nulls_first: true, - }, + let base_properties = EquivalenceProperties::new(Arc::clone(&schema)) + .with_reorder(LexOrdering::new( + ["a", "b", "c"] + .into_iter() + .map(|c| { + col(c, schema.as_ref()).map(|expr| PhysicalSortExpr { + expr, + options: SortOptions { + descending: false, + nulls_first: true, + }, + }) }) - }) - .collect::>>()?, - ); + .collect::>>()?, + )); struct TestCase { name: &'static str, @@ -2247,11 +3012,27 @@ mod tests { let cast_c = Arc::new(CastExpr::new(col_c, DataType::Date32, None)); let cases = vec![ + TestCase { + name: "(a, b, c) -> (c)", + // b is constant, so it should be removed from the sort order + constants: vec![Arc::clone(&col_b)], + equal_conditions: vec![[ + Arc::clone(&cast_c) as Arc, + Arc::clone(&col_a), + ]], + sort_columns: &["c"], + should_satisfy_ordering: true, + }, + // Same test with above test, where equality order is swapped. + // Algorithm shouldn't depend on this order. TestCase { name: "(a, b, c) -> (c)", // b is constant, so it should be removed from the sort order constants: vec![col_b], - equal_conditions: vec![[cast_c.clone(), col_a.clone()]], + equal_conditions: vec![[ + Arc::clone(&col_a), + Arc::clone(&cast_c) as Arc, + ]], sort_columns: &["c"], should_satisfy_ordering: true, }, @@ -2260,37 +3041,651 @@ mod tests { // b is not constant anymore constants: vec![], // a and c are still compatible, but this is irrelevant since the original ordering is (a, b, c) - equal_conditions: vec![[cast_c.clone(), col_a.clone()]], + equal_conditions: vec![[ + Arc::clone(&cast_c) as Arc, + Arc::clone(&col_a), + ]], sort_columns: &["c"], should_satisfy_ordering: false, }, ]; for case in cases { - let mut properties = base_properties.clone().add_constants(case.constants); - for [left, right] in &case.equal_conditions { - properties.add_equal_conditions(left, right) + // Construct the equivalence properties in different orders + // to exercise different code paths + // (The resulting properties _should_ be the same) + for properties in [ + // Equal conditions before constants + { + let mut properties = base_properties.clone(); + for [left, right] in &case.equal_conditions { + properties.add_equal_conditions(left, right)? + } + properties.with_constants( + case.constants.iter().cloned().map(ConstExpr::from), + ) + }, + // Constants before equal conditions + { + let mut properties = base_properties.clone().with_constants( + case.constants.iter().cloned().map(ConstExpr::from), + ); + for [left, right] in &case.equal_conditions { + properties.add_equal_conditions(left, right)? + } + properties + }, + ] { + let sort = case + .sort_columns + .iter() + .map(|&name| { + col(name, &schema).map(|col| PhysicalSortExpr { + expr: col, + options: SortOptions::default(), + }) + }) + .collect::>()?; + + assert_eq!( + properties.ordering_satisfy(sort.as_ref()), + case.should_satisfy_ordering, + "failed test '{}'", + case.name + ); } + } + + Ok(()) + } - let sort = case - .sort_columns + /// Return a new schema with the same types, but new field names + /// + /// The new field names are the old field names with `text` appended. + /// + /// For example, the schema "a", "b", "c" becomes "a1", "b1", "c1" + /// if `text` is "1". + fn append_fields(schema: &SchemaRef, text: &str) -> SchemaRef { + Arc::new(Schema::new( + schema + .fields() .iter() - .map(|&name| { - col(name, &schema).map(|col| PhysicalSortExpr { - expr: col, - options: SortOptions::default(), - }) + .map(|field| { + Field::new( + // Annotate name with `text`: + format!("{}{}", field.name(), text), + field.data_type().clone(), + field.is_nullable(), + ) }) - .collect::>>()?; + .collect::>(), + )) + } - assert_eq!( - properties.ordering_satisfy(&sort), - case.should_satisfy_ordering, - "failed test '{}'", - case.name + #[test] + fn test_union_equivalence_properties_multi_children_1() { + let schema = create_test_schema().unwrap(); + let schema2 = append_fields(&schema, "1"); + let schema3 = append_fields(&schema, "2"); + UnionEquivalenceTest::new(&schema) + // Children 1 + .with_child_sort(vec![vec!["a", "b", "c"]], &schema) + // Children 2 + .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2) + // Children 3 + .with_child_sort(vec![vec!["a2", "b2"]], &schema3) + .with_expected_sort(vec![vec!["a", "b"]]) + .run() + } + + #[test] + fn test_union_equivalence_properties_multi_children_2() { + let schema = create_test_schema().unwrap(); + let schema2 = append_fields(&schema, "1"); + let schema3 = append_fields(&schema, "2"); + UnionEquivalenceTest::new(&schema) + // Children 1 + .with_child_sort(vec![vec!["a", "b", "c"]], &schema) + // Children 2 + .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2) + // Children 3 + .with_child_sort(vec![vec!["a2", "b2", "c2"]], &schema3) + .with_expected_sort(vec![vec!["a", "b", "c"]]) + .run() + } + + #[test] + fn test_union_equivalence_properties_multi_children_3() { + let schema = create_test_schema().unwrap(); + let schema2 = append_fields(&schema, "1"); + let schema3 = append_fields(&schema, "2"); + UnionEquivalenceTest::new(&schema) + // Children 1 + .with_child_sort(vec![vec!["a", "b"]], &schema) + // Children 2 + .with_child_sort(vec![vec!["a1", "b1", "c1"]], &schema2) + // Children 3 + .with_child_sort(vec![vec!["a2", "b2", "c2"]], &schema3) + .with_expected_sort(vec![vec!["a", "b"]]) + .run() + } + + #[test] + fn test_union_equivalence_properties_multi_children_4() { + let schema = create_test_schema().unwrap(); + let schema2 = append_fields(&schema, "1"); + let schema3 = append_fields(&schema, "2"); + UnionEquivalenceTest::new(&schema) + // Children 1 + .with_child_sort(vec![vec!["a", "b"]], &schema) + // Children 2 + .with_child_sort(vec![vec!["a1", "b1"]], &schema2) + // Children 3 + .with_child_sort(vec![vec!["b2", "c2"]], &schema3) + .with_expected_sort(vec![]) + .run() + } + + #[test] + fn test_union_equivalence_properties_multi_children_5() { + let schema = create_test_schema().unwrap(); + let schema2 = append_fields(&schema, "1"); + UnionEquivalenceTest::new(&schema) + // Children 1 + .with_child_sort(vec![vec!["a", "b"], vec!["c"]], &schema) + // Children 2 + .with_child_sort(vec![vec!["a1", "b1"], vec!["c1"]], &schema2) + .with_expected_sort(vec![vec!["a", "b"], vec!["c"]]) + .run() + } + + #[test] + fn test_union_equivalence_properties_constants_common_constants() { + let schema = create_test_schema().unwrap(); + UnionEquivalenceTest::new(&schema) + .with_child_sort_and_const_exprs( + // First child: [a ASC], const [b, c] + vec![vec!["a"]], + vec!["b", "c"], + &schema, + ) + .with_child_sort_and_const_exprs( + // Second child: [b ASC], const [a, c] + vec![vec!["b"]], + vec!["a", "c"], + &schema, + ) + .with_expected_sort_and_const_exprs( + // Union expected orderings: [[a ASC], [b ASC]], const [c] + vec![vec!["a"], vec!["b"]], + vec!["c"], + ) + .run() + } + + #[test] + fn test_union_equivalence_properties_constants_prefix() { + let schema = create_test_schema().unwrap(); + UnionEquivalenceTest::new(&schema) + .with_child_sort_and_const_exprs( + // First child: [a ASC], const [] + vec![vec!["a"]], + vec![], + &schema, + ) + .with_child_sort_and_const_exprs( + // Second child: [a ASC, b ASC], const [] + vec![vec!["a", "b"]], + vec![], + &schema, + ) + .with_expected_sort_and_const_exprs( + // Union orderings: [a ASC], const [] + vec![vec!["a"]], + vec![], + ) + .run() + } + + #[test] + fn test_union_equivalence_properties_constants_asc_desc_mismatch() { + let schema = create_test_schema().unwrap(); + UnionEquivalenceTest::new(&schema) + .with_child_sort_and_const_exprs( + // First child: [a ASC], const [] + vec![vec!["a"]], + vec![], + &schema, + ) + .with_child_sort_and_const_exprs( + // Second child orderings: [a DESC], const [] + vec![vec!["a DESC"]], + vec![], + &schema, + ) + .with_expected_sort_and_const_exprs( + // Union doesn't have any ordering or constant + vec![], + vec![], + ) + .run() + } + + #[test] + fn test_union_equivalence_properties_constants_different_schemas() { + let schema = create_test_schema().unwrap(); + let schema2 = append_fields(&schema, "1"); + UnionEquivalenceTest::new(&schema) + .with_child_sort_and_const_exprs( + // First child orderings: [a ASC], const [] + vec![vec!["a"]], + vec![], + &schema, + ) + .with_child_sort_and_const_exprs( + // Second child orderings: [a1 ASC, b1 ASC], const [] + vec![vec!["a1", "b1"]], + vec![], + &schema2, + ) + .with_expected_sort_and_const_exprs( + // Union orderings: [a ASC] + // + // Note that a, and a1 are at the same index for their + // corresponding schemas. + vec![vec!["a"]], + vec![], + ) + .run() + } + + #[test] + fn test_union_equivalence_properties_constants_fill_gaps() { + let schema = create_test_schema().unwrap(); + UnionEquivalenceTest::new(&schema) + .with_child_sort_and_const_exprs( + // First child orderings: [a ASC, c ASC], const [b] + vec![vec!["a", "c"]], + vec!["b"], + &schema, + ) + .with_child_sort_and_const_exprs( + // Second child orderings: [b ASC, c ASC], const [a] + vec![vec!["b", "c"]], + vec!["a"], + &schema, + ) + .with_expected_sort_and_const_exprs( + // Union orderings: [ + // [a ASC, b ASC, c ASC], + // [b ASC, a ASC, c ASC] + // ], const [] + vec![vec!["a", "b", "c"], vec!["b", "a", "c"]], + vec![], + ) + .run() + } + + #[test] + fn test_union_equivalence_properties_constants_no_fill_gaps() { + let schema = create_test_schema().unwrap(); + UnionEquivalenceTest::new(&schema) + .with_child_sort_and_const_exprs( + // First child orderings: [a ASC, c ASC], const [d] // some other constant + vec![vec!["a", "c"]], + vec!["d"], + &schema, + ) + .with_child_sort_and_const_exprs( + // Second child orderings: [b ASC, c ASC], const [a] + vec![vec!["b", "c"]], + vec!["a"], + &schema, + ) + .with_expected_sort_and_const_exprs( + // Union orderings: [[a]] (only a is constant) + vec![vec!["a"]], + vec![], + ) + .run() + } + + #[test] + fn test_union_equivalence_properties_constants_fill_some_gaps() { + let schema = create_test_schema().unwrap(); + UnionEquivalenceTest::new(&schema) + .with_child_sort_and_const_exprs( + // First child orderings: [c ASC], const [a, b] // some other constant + vec![vec!["c"]], + vec!["a", "b"], + &schema, + ) + .with_child_sort_and_const_exprs( + // Second child orderings: [a DESC, b], const [] + vec![vec!["a DESC", "b"]], + vec![], + &schema, + ) + .with_expected_sort_and_const_exprs( + // Union orderings: [[a, b]] (can fill in the a/b with constants) + vec![vec!["a DESC", "b"]], + vec![], + ) + .run() + } + + #[test] + fn test_union_equivalence_properties_constants_fill_gaps_non_symmetric() { + let schema = create_test_schema().unwrap(); + UnionEquivalenceTest::new(&schema) + .with_child_sort_and_const_exprs( + // First child orderings: [a ASC, c ASC], const [b] + vec![vec!["a", "c"]], + vec!["b"], + &schema, + ) + .with_child_sort_and_const_exprs( + // Second child orderings: [b ASC, c ASC], const [a] + vec![vec!["b DESC", "c"]], + vec!["a"], + &schema, + ) + .with_expected_sort_and_const_exprs( + // Union orderings: [ + // [a ASC, b ASC, c ASC], + // [b ASC, a ASC, c ASC] + // ], const [] + vec![vec!["a", "b DESC", "c"], vec!["b DESC", "a", "c"]], + vec![], + ) + .run() + } + + #[test] + fn test_union_equivalence_properties_constants_gap_fill_symmetric() { + let schema = create_test_schema().unwrap(); + UnionEquivalenceTest::new(&schema) + .with_child_sort_and_const_exprs( + // First child: [a ASC, b ASC, d ASC], const [c] + vec![vec!["a", "b", "d"]], + vec!["c"], + &schema, + ) + .with_child_sort_and_const_exprs( + // Second child: [a ASC, c ASC, d ASC], const [b] + vec![vec!["a", "c", "d"]], + vec!["b"], + &schema, + ) + .with_expected_sort_and_const_exprs( + // Union orderings: + // [a, b, c, d] + // [a, c, b, d] + vec![vec!["a", "c", "b", "d"], vec!["a", "b", "c", "d"]], + vec![], + ) + .run() + } + + #[test] + fn test_union_equivalence_properties_constants_gap_fill_and_common() { + let schema = create_test_schema().unwrap(); + UnionEquivalenceTest::new(&schema) + .with_child_sort_and_const_exprs( + // First child: [a DESC, d ASC], const [b, c] + vec![vec!["a DESC", "d"]], + vec!["b", "c"], + &schema, + ) + .with_child_sort_and_const_exprs( + // Second child: [a DESC, c ASC, d ASC], const [b] + vec![vec!["a DESC", "c", "d"]], + vec!["b"], + &schema, + ) + .with_expected_sort_and_const_exprs( + // Union orderings: + // [a DESC, c, d] [b] + vec![vec!["a DESC", "c", "d"]], + vec!["b"], + ) + .run() + } + + #[test] + fn test_union_equivalence_properties_constants_middle_desc() { + let schema = create_test_schema().unwrap(); + UnionEquivalenceTest::new(&schema) + .with_child_sort_and_const_exprs( + // NB `b DESC` in the first child + // + // First child: [a ASC, b DESC, d ASC], const [c] + vec![vec!["a", "b DESC", "d"]], + vec!["c"], + &schema, + ) + .with_child_sort_and_const_exprs( + // Second child: [a ASC, c ASC, d ASC], const [b] + vec![vec!["a", "c", "d"]], + vec!["b"], + &schema, + ) + .with_expected_sort_and_const_exprs( + // Union orderings: + // [a, b, d] (c constant) + // [a, c, d] (b constant) + vec![vec!["a", "c", "b DESC", "d"], vec!["a", "b DESC", "c", "d"]], + vec![], + ) + .run() + } + + // TODO tests with multiple constants + + #[derive(Debug)] + struct UnionEquivalenceTest { + /// The schema of the output of the Union + output_schema: SchemaRef, + /// The equivalence properties of each child to the union + child_properties: Vec, + /// The expected output properties of the union. Must be set before + /// running `build` + expected_properties: Option, + } + + impl UnionEquivalenceTest { + fn new(output_schema: &SchemaRef) -> Self { + Self { + output_schema: Arc::clone(output_schema), + child_properties: vec![], + expected_properties: None, + } + } + + /// Add a union input with the specified orderings + /// + /// See [`Self::make_props`] for the format of the strings in `orderings` + fn with_child_sort( + mut self, + orderings: Vec>, + schema: &SchemaRef, + ) -> Self { + let properties = self.make_props(orderings, vec![], schema); + self.child_properties.push(properties); + self + } + + /// Add a union input with the specified orderings and constant + /// equivalences + /// + /// See [`Self::make_props`] for the format of the strings in + /// `orderings` and `constants` + fn with_child_sort_and_const_exprs( + mut self, + orderings: Vec>, + constants: Vec<&str>, + schema: &SchemaRef, + ) -> Self { + let properties = self.make_props(orderings, constants, schema); + self.child_properties.push(properties); + self + } + + /// Set the expected output sort order for the union of the children + /// + /// See [`Self::make_props`] for the format of the strings in `orderings` + fn with_expected_sort(mut self, orderings: Vec>) -> Self { + let properties = self.make_props(orderings, vec![], &self.output_schema); + self.expected_properties = Some(properties); + self + } + + /// Set the expected output sort order and constant expressions for the + /// union of the children + /// + /// See [`Self::make_props`] for the format of the strings in + /// `orderings` and `constants`. + fn with_expected_sort_and_const_exprs( + mut self, + orderings: Vec>, + constants: Vec<&str>, + ) -> Self { + let properties = self.make_props(orderings, constants, &self.output_schema); + self.expected_properties = Some(properties); + self + } + + /// compute the union's output equivalence properties from the child + /// properties, and compare them to the expected properties + fn run(self) { + let Self { + output_schema, + child_properties, + expected_properties, + } = self; + + let expected_properties = + expected_properties.expect("expected_properties not set"); + + // try all permutations of the children + // as the code treats lhs and rhs differently + for child_properties in child_properties + .iter() + .cloned() + .permutations(child_properties.len()) + { + println!("--- permutation ---"); + for c in &child_properties { + println!("{c}"); + } + let actual_properties = + calculate_union(child_properties, Arc::clone(&output_schema)) + .expect("failed to calculate union equivalence properties"); + assert_eq_properties_same( + &actual_properties, + &expected_properties, + format!( + "expected: {expected_properties:?}\nactual: {actual_properties:?}" + ), + ); + } + } + + /// Make equivalence properties for the specified columns named in orderings and constants + /// + /// orderings: strings formatted like `"a"` or `"a DESC"`. See [`parse_sort_expr`] + /// constants: strings formatted like `"a"`. + fn make_props( + &self, + orderings: Vec>, + constants: Vec<&str>, + schema: &SchemaRef, + ) -> EquivalenceProperties { + let orderings = orderings + .iter() + .map(|ordering| { + ordering + .iter() + .map(|name| parse_sort_expr(name, schema)) + .collect::() + }) + .collect::>(); + + let constants = constants + .iter() + .map(|col_name| ConstExpr::new(col(col_name, schema).unwrap())) + .collect::>(); + + EquivalenceProperties::new_with_orderings(Arc::clone(schema), &orderings) + .with_constants(constants) + } + } + + fn assert_eq_properties_same( + lhs: &EquivalenceProperties, + rhs: &EquivalenceProperties, + err_msg: String, + ) { + // Check whether constants are same + let lhs_constants = lhs.constants(); + let rhs_constants = rhs.constants(); + for rhs_constant in rhs_constants { + assert!( + const_exprs_contains(lhs_constants, rhs_constant.expr()), + "{err_msg}\nlhs: {lhs}\nrhs: {rhs}" ); } + assert_eq!( + lhs_constants.len(), + rhs_constants.len(), + "{err_msg}\nlhs: {lhs}\nrhs: {rhs}" + ); - Ok(()) + // Check whether orderings are same. + let lhs_orderings = lhs.oeq_class(); + let rhs_orderings = &rhs.oeq_class.orderings; + for rhs_ordering in rhs_orderings { + assert!( + lhs_orderings.contains(rhs_ordering), + "{err_msg}\nlhs: {lhs}\nrhs: {rhs}" + ); + } + assert_eq!( + lhs_orderings.len(), + rhs_orderings.len(), + "{err_msg}\nlhs: {lhs}\nrhs: {rhs}" + ); + } + + /// Converts a string to a physical sort expression + /// + /// # Example + /// * `"a"` -> (`"a"`, `SortOptions::default()`) + /// * `"a ASC"` -> (`"a"`, `SortOptions { descending: false, nulls_first: false }`) + fn parse_sort_expr(name: &str, schema: &SchemaRef) -> PhysicalSortExpr { + let mut parts = name.split_whitespace(); + let name = parts.next().expect("empty sort expression"); + let mut sort_expr = PhysicalSortExpr::new( + col(name, schema).expect("invalid column name"), + SortOptions::default(), + ); + + if let Some(options) = parts.next() { + sort_expr = match options { + "ASC" => sort_expr.asc(), + "DESC" => sort_expr.desc(), + _ => panic!( + "unknown sort options. Expected 'ASC' or 'DESC', got {}", + options + ), + } + } + + assert!( + parts.next().is_none(), + "unexpected tokens in column name. Expected 'name' / 'name ASC' / 'name DESC' but got '{name}'" + ); + + sort_expr } } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 76154dca0338..47b04a876b37 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -20,27 +20,27 @@ mod kernels; use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; -use crate::expressions::datum::{apply, apply_cmp}; use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; use crate::physical_expr::down_cast_any_ref; -use crate::sort_properties::SortProperties; use crate::PhysicalExpr; use arrow::array::*; use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene}; use arrow::compute::kernels::cmp::*; -use arrow::compute::kernels::comparison::regexp_is_match_utf8; -use arrow::compute::kernels::comparison::regexp_is_match_utf8_scalar; +use arrow::compute::kernels::comparison::{regexp_is_match, regexp_is_match_scalar}; use arrow::compute::kernels::concat_elements::concat_elements_utf8; use arrow::compute::{cast, ilike, like, nilike, nlike}; use arrow::datatypes::*; - +use arrow_schema::ArrowError; use datafusion_common::cast::as_boolean_array; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::{apply_operator, Interval}; +use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::binary::get_result_type; use datafusion_expr::{ColumnarValue, Operator}; +use datafusion_physical_expr_common::datum::{apply, apply_cmp, apply_cmp_for_nested}; +use crate::expressions::binary::kernels::concat_elements_utf8view; use kernels::{ bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar, bitwise_shift_left_dyn, bitwise_shift_left_dyn_scalar, bitwise_shift_right_dyn, @@ -53,6 +53,8 @@ pub struct BinaryExpr { left: Arc, op: Operator, right: Arc, + /// Specifies whether an error is returned on overflow or not + fail_on_overflow: bool, } impl BinaryExpr { @@ -62,7 +64,22 @@ impl BinaryExpr { op: Operator, right: Arc, ) -> Self { - Self { left, op, right } + Self { + left, + op, + right, + fail_on_overflow: false, + } + } + + /// Create new binary expression with explicit fail_on_overflow value + pub fn with_fail_on_overflow(self, fail_on_overflow: bool) -> Self { + Self { + left: self.left, + op: self.op, + right: self.right, + fail_on_overflow, + } } /// Get the left side of the binary expression @@ -114,52 +131,27 @@ impl std::fmt::Display for BinaryExpr { } } -/// Invoke a compute kernel on a pair of binary data arrays -macro_rules! compute_utf8_op { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ - let ll = $LEFT - .as_any() - .downcast_ref::<$DT>() - .expect("compute_op failed to downcast left side array"); - let rr = $RIGHT - .as_any() - .downcast_ref::<$DT>() - .expect("compute_op failed to downcast right side array"); - Ok(Arc::new(paste::expr! {[<$OP _utf8>]}(&ll, &rr)?)) - }}; -} - -macro_rules! binary_string_array_op { - ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ - match $LEFT.data_type() { - DataType::Utf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, StringArray), - DataType::LargeUtf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, LargeStringArray), - other => internal_err!( - "Data type {:?} not supported for binary operation '{}' on string arrays", - other, stringify!($OP) - ), - } - }}; -} - /// Invoke a boolean kernel on a pair of arrays -macro_rules! boolean_op { - ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ - let ll = as_boolean_array($LEFT).expect("boolean_op failed to downcast array"); - let rr = as_boolean_array($RIGHT).expect("boolean_op failed to downcast array"); - Ok(Arc::new($OP(&ll, &rr)?)) - }}; +#[inline] +fn boolean_op( + left: &dyn Array, + right: &dyn Array, + op: impl FnOnce(&BooleanArray, &BooleanArray) -> Result, +) -> Result, ArrowError> { + let ll = as_boolean_array(left).expect("boolean_op failed to downcast left array"); + let rr = as_boolean_array(right).expect("boolean_op failed to downcast right array"); + op(ll, rr).map(|t| Arc::new(t) as _) } macro_rules! binary_string_array_flag_op { ($LEFT:expr, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{ match $LEFT.data_type() { - DataType::Utf8 => { + DataType::Utf8View | DataType::Utf8 => { compute_utf8_flag_op!($LEFT, $RIGHT, $OP, StringArray, $NOT, $FLAG) - } + }, DataType::LargeUtf8 => { compute_utf8_flag_op!($LEFT, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG) - } + }, other => internal_err!( "Data type {:?} not supported for binary_string_array_flag_op operation '{}' on string array", other, stringify!($OP) @@ -185,7 +177,7 @@ macro_rules! compute_utf8_flag_op { } else { None }; - let mut array = paste::expr! {[<$OP _utf8>]}(&ll, &rr, flag.as_ref())?; + let mut array = $OP(ll, rr, flag.as_ref())?; if $NOT { array = not(&array).unwrap(); } @@ -194,14 +186,37 @@ macro_rules! compute_utf8_flag_op { } macro_rules! binary_string_array_flag_op_scalar { - ($LEFT:expr, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{ + ($LEFT:ident, $RIGHT:expr, $OP:ident, $NOT:expr, $FLAG:expr) => {{ + // This macro is slightly different from binary_string_array_flag_op because, when comparing with a scalar value, + // the query can be optimized in such a way that operands will be dicts, so we need to support it here let result: Result> = match $LEFT.data_type() { - DataType::Utf8 => { + DataType::Utf8View | DataType::Utf8 => { compute_utf8_flag_op_scalar!($LEFT, $RIGHT, $OP, StringArray, $NOT, $FLAG) - } + }, DataType::LargeUtf8 => { compute_utf8_flag_op_scalar!($LEFT, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG) - } + }, + DataType::Dictionary(_, _) => { + let values = $LEFT.as_any_dictionary().values(); + + match values.data_type() { + DataType::Utf8View | DataType::Utf8 => compute_utf8_flag_op_scalar!(values, $RIGHT, $OP, StringArray, $NOT, $FLAG), + DataType::LargeUtf8 => compute_utf8_flag_op_scalar!(values, $RIGHT, $OP, LargeStringArray, $NOT, $FLAG), + other => internal_err!( + "Data type {:?} not supported as a dictionary value type for binary_string_array_flag_op_scalar operation '{}' on string array", + other, stringify!($OP) + ), + }.map( + // downcast_dictionary_array duplicates code per possible key type, so we aim to do all prep work before + |evaluated_values| downcast_dictionary_array! { + $LEFT => { + let unpacked_dict = evaluated_values.take_iter($LEFT.keys().iter().map(|opt| opt.map(|v| v as _))).collect::(); + Arc::new(unpacked_dict) as _ + }, + _ => unreachable!(), + } + ) + }, other => internal_err!( "Data type {:?} not supported for binary_string_array_flag_op_scalar operation '{}' on string array", other, stringify!($OP) @@ -219,20 +234,32 @@ macro_rules! compute_utf8_flag_op_scalar { .downcast_ref::<$ARRAYTYPE>() .expect("compute_utf8_flag_op_scalar failed to downcast array"); - if let ScalarValue::Utf8(Some(string_value))|ScalarValue::LargeUtf8(Some(string_value)) = $RIGHT { - let flag = if $FLAG { Some("i") } else { None }; - let mut array = - paste::expr! {[<$OP _utf8_scalar>]}(&ll, &string_value, flag)?; - if $NOT { - array = not(&array).unwrap(); - } - Ok(Arc::new(array)) - } else { - internal_err!( + let string_value = match $RIGHT { + ScalarValue::Utf8(Some(string_value)) | ScalarValue::LargeUtf8(Some(string_value)) => string_value, + ScalarValue::Dictionary(_, value) => { + match *value { + ScalarValue::Utf8(Some(string_value)) | ScalarValue::LargeUtf8(Some(string_value)) => string_value, + other => return internal_err!( + "compute_utf8_flag_op_scalar failed to cast dictionary value {} for operation '{}'", + other, stringify!($OP) + ) + } + }, + _ => return internal_err!( "compute_utf8_flag_op_scalar failed to cast literal value {} for operation '{}'", $RIGHT, stringify!($OP) ) + + }; + + let flag = $FLAG.then_some("i"); + let mut array = + paste::expr! {[<$OP _scalar>]}(ll, &string_value, flag)?; + if $NOT { + array = not(&array).unwrap(); } + + Ok(Arc::new(array)) }}; } @@ -265,9 +292,19 @@ impl PhysicalExpr for BinaryExpr { let schema = batch.schema(); let input_schema = schema.as_ref(); + if left_data_type.is_nested() { + if right_data_type != left_data_type { + return internal_err!("type mismatch"); + } + return apply_cmp_for_nested(self.op, &lhs, &rhs); + } + match self.op { + Operator::Plus if self.fail_on_overflow => return apply(&lhs, &rhs, add), Operator::Plus => return apply(&lhs, &rhs, add_wrapping), + Operator::Minus if self.fail_on_overflow => return apply(&lhs, &rhs, sub), Operator::Minus => return apply(&lhs, &rhs, sub_wrapping), + Operator::Multiply if self.fail_on_overflow => return apply(&lhs, &rhs, mul), Operator::Multiply => return apply(&lhs, &rhs, mul_wrapping), Operator::Divide => return apply(&lhs, &rhs, div), Operator::Modulo => return apply(&lhs, &rhs, rem), @@ -291,10 +328,14 @@ impl PhysicalExpr for BinaryExpr { // Attempt to use special kernels if one input is scalar and the other is an array let scalar_result = match (&lhs, &rhs) { (ColumnarValue::Array(array), ColumnarValue::Scalar(scalar)) => { - // if left is array and right is literal - use scalar operations - self.evaluate_array_scalar(array, scalar.clone())?.map(|r| { - r.and_then(|a| to_result_type_array(&self.op, a, &result_type)) - }) + // if left is array and right is literal(not NULL) - use scalar operations + if scalar.is_null() { + None + } else { + self.evaluate_array_scalar(array, scalar.clone())?.map(|r| { + r.and_then(|a| to_result_type_array(&self.op, a, &result_type)) + }) + } } (_, _) => None, // default to array implementation }; @@ -312,19 +353,18 @@ impl PhysicalExpr for BinaryExpr { .map(ColumnarValue::Array) } - fn children(&self) -> Vec> { - vec![self.left.clone(), self.right.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.left, &self.right] } fn with_new_children( self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(BinaryExpr::new( - children[0].clone(), - self.op, - children[1].clone(), - ))) + Ok(Arc::new( + BinaryExpr::new(Arc::clone(&children[0]), self.op, Arc::clone(&children[1])) + .with_fail_on_overflow(self.fail_on_overflow), + )) } fn evaluate_bounds(&self, children: &[&Interval]) -> Result { @@ -424,7 +464,7 @@ impl PhysicalExpr for BinaryExpr { // end-points of its children. Ok(Some(vec![])) } - } else if self.op.is_comparison_operator() { + } else if self.op.supports_propagation() { Ok( propagate_comparison(&self.op, interval, left_interval, right_interval)? .map(|(left, right)| vec![left, right]), @@ -442,17 +482,45 @@ impl PhysicalExpr for BinaryExpr { self.hash(&mut s); } - /// For each operator, [`BinaryExpr`] has distinct ordering rules. - /// TODO: There may be rules specific to some data types (such as division and multiplication on unsigned integers) - fn get_ordering(&self, children: &[SortProperties]) -> SortProperties { - let (left_child, right_child) = (&children[0], &children[1]); + /// For each operator, [`BinaryExpr`] has distinct rules. + /// TODO: There may be rules specific to some data types and expression ranges. + fn get_properties(&self, children: &[ExprProperties]) -> Result { + let (l_order, l_range) = (children[0].sort_properties, &children[0].range); + let (r_order, r_range) = (children[1].sort_properties, &children[1].range); match self.op() { - Operator::Plus => left_child.add(right_child), - Operator::Minus => left_child.sub(right_child), - Operator::Gt | Operator::GtEq => left_child.gt_or_gteq(right_child), - Operator::Lt | Operator::LtEq => right_child.gt_or_gteq(left_child), - Operator::And | Operator::Or => left_child.and_or(right_child), - _ => SortProperties::Unordered, + Operator::Plus => Ok(ExprProperties { + sort_properties: l_order.add(&r_order), + range: l_range.add(r_range)?, + }), + Operator::Minus => Ok(ExprProperties { + sort_properties: l_order.sub(&r_order), + range: l_range.sub(r_range)?, + }), + Operator::Gt => Ok(ExprProperties { + sort_properties: l_order.gt_or_gteq(&r_order), + range: l_range.gt(r_range)?, + }), + Operator::GtEq => Ok(ExprProperties { + sort_properties: l_order.gt_or_gteq(&r_order), + range: l_range.gt_eq(r_range)?, + }), + Operator::Lt => Ok(ExprProperties { + sort_properties: r_order.gt_or_gteq(&l_order), + range: l_range.lt(r_range)?, + }), + Operator::LtEq => Ok(ExprProperties { + sort_properties: r_order.gt_or_gteq(&l_order), + range: l_range.lt_eq(r_range)?, + }), + Operator::And => Ok(ExprProperties { + sort_properties: r_order.and_or(&l_order), + range: l_range.and(r_range)?, + }), + Operator::Or => Ok(ExprProperties { + sort_properties: r_order.and_or(&l_order), + range: l_range.or(r_range)?, + }), + _ => Ok(ExprProperties::new_unknown()), } } } @@ -461,7 +529,12 @@ impl PartialEq for BinaryExpr { fn eq(&self, other: &dyn Any) -> bool { down_cast_any_ref(other) .downcast_ref::() - .map(|x| self.left.eq(&x.left) && self.op == x.op && self.right.eq(&x.right)) + .map(|x| { + self.left.eq(&x.left) + && self.op == x.op + && self.right.eq(&x.right) + && self.fail_on_overflow.eq(&x.fail_on_overflow) + }) .unwrap_or(false) } } @@ -560,7 +633,7 @@ impl BinaryExpr { | NotLikeMatch | NotILikeMatch => unreachable!(), And => { if left_data_type == &DataType::Boolean { - boolean_op!(&left, &right, and_kleene) + Ok(boolean_op(&left, &right, and_kleene)?) } else { internal_err!( "Cannot evaluate binary expression {:?} with types {:?} and {:?}", @@ -572,7 +645,7 @@ impl BinaryExpr { } Or => { if left_data_type == &DataType::Boolean { - boolean_op!(&left, &right, or_kleene) + Ok(boolean_op(&left, &right, or_kleene)?) } else { internal_err!( "Cannot evaluate binary expression {:?} with types {:?} and {:?}", @@ -599,7 +672,7 @@ impl BinaryExpr { BitwiseXor => bitwise_xor_dyn(left, right), BitwiseShiftRight => bitwise_shift_right_dyn(left, right), BitwiseShiftLeft => bitwise_shift_left_dyn(left, right), - StringConcat => binary_string_array_op!(left, right, concat_elements), + StringConcat => concat_elements(left, right), AtArrow | ArrowAt => { unreachable!("ArrowAt and AtArrow should be rewritten to function") } @@ -607,6 +680,28 @@ impl BinaryExpr { } } +fn concat_elements(left: Arc, right: Arc) -> Result { + Ok(match left.data_type() { + DataType::Utf8 => Arc::new(concat_elements_utf8( + left.as_string::(), + right.as_string::(), + )?), + DataType::LargeUtf8 => Arc::new(concat_elements_utf8( + left.as_string::(), + right.as_string::(), + )?), + DataType::Utf8View => Arc::new(concat_elements_utf8view( + left.as_string_view(), + right.as_string_view(), + )?), + other => { + return internal_err!( + "Data type {other:?} not supported for binary operation 'concat_elements' on string arrays" + ); + } + }) +} + /// Create a binary expression whose arguments are correctly coerced. /// This function errors if it is not possible to coerce the arguments /// to computational types supported by the operator. @@ -619,10 +714,26 @@ pub fn binary( Ok(Arc::new(BinaryExpr::new(lhs, op, rhs))) } +/// Create a similar to expression +pub fn similar_to( + negated: bool, + case_insensitive: bool, + expr: Arc, + pattern: Arc, +) -> Result> { + let binary_op = match (negated, case_insensitive) { + (false, false) => Operator::RegexMatch, + (false, true) => Operator::RegexIMatch, + (true, false) => Operator::RegexNotMatch, + (true, true) => Operator::RegexNotIMatch, + }; + Ok(Arc::new(BinaryExpr::new(expr, binary_op, pattern))) +} + #[cfg(test)] mod tests { use super::*; - use crate::expressions::{col, lit, try_cast, Literal}; + use crate::expressions::{col, lit, try_cast, Column, Literal}; use datafusion_common::plan_datafusion_err; use datafusion_expr::type_coercion::binary::get_input_types; @@ -875,6 +986,54 @@ mod tests { DataType::Boolean, [true, false], ); + test_coercion!( + StringViewArray, + DataType::Utf8View, + vec!["abc"; 5], + StringArray, + DataType::Utf8, + vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"], + Operator::RegexMatch, + BooleanArray, + DataType::Boolean, + [true, false, true, false, false], + ); + test_coercion!( + StringViewArray, + DataType::Utf8View, + vec!["abc"; 5], + StringArray, + DataType::Utf8, + vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"], + Operator::RegexIMatch, + BooleanArray, + DataType::Boolean, + [true, true, true, true, false], + ); + test_coercion!( + StringArray, + DataType::Utf8, + vec!["abc"; 5], + StringViewArray, + DataType::Utf8View, + vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"], + Operator::RegexNotMatch, + BooleanArray, + DataType::Boolean, + [false, true, false, true, true], + ); + test_coercion!( + StringArray, + DataType::Utf8, + vec!["abc"; 5], + StringViewArray, + DataType::Utf8View, + vec!["^a", "^A", "(b|d)", "(B|D)", "^(b|c)"], + Operator::RegexNotIMatch, + BooleanArray, + DataType::Boolean, + [false, false, false, false, true], + ); test_coercion!( StringArray, DataType::Utf8, @@ -1457,8 +1616,11 @@ mod tests { let b = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); apply_arithmetic::( - schema.clone(), - vec![a.clone(), b.clone()], + Arc::clone(&schema), + vec![ + Arc::clone(&a) as Arc, + Arc::clone(&b) as Arc, + ], Operator::Minus, Int32Array::from(vec![0, 0, 1, 4, 11]), )?; @@ -2340,8 +2502,8 @@ mod tests { expected: BooleanArray, ) -> Result<()> { let op = binary_op(col("a", schema)?, op, col("b", schema)?, schema)?; - let data: Vec = vec![left.clone(), right.clone()]; - let batch = RecordBatch::try_new(schema.clone(), data)?; + let data: Vec = vec![Arc::clone(left), Arc::clone(right)]; + let batch = RecordBatch::try_new(Arc::clone(schema), data)?; let result = op .evaluate(&batch)? .into_array(batch.num_rows()) @@ -2436,6 +2598,111 @@ mod tests { Ok(()) } + #[test] + fn regex_with_nulls() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("a", DataType::Utf8, true), + Field::new("b", DataType::Utf8, true), + ]); + let a = Arc::new(StringArray::from(vec![ + Some("abc"), + None, + Some("abc"), + None, + Some("abc"), + ])) as ArrayRef; + let b = Arc::new(StringArray::from(vec![ + Some("^a"), + Some("^A"), + None, + None, + Some("^(b|c)"), + ])) as ArrayRef; + + let regex_expected = + BooleanArray::from(vec![Some(true), None, None, None, Some(false)]); + let regex_not_expected = + BooleanArray::from(vec![Some(false), None, None, None, Some(true)]); + apply_logic_op( + &Arc::new(schema.clone()), + &a, + &b, + Operator::RegexMatch, + regex_expected.clone(), + )?; + apply_logic_op( + &Arc::new(schema.clone()), + &a, + &b, + Operator::RegexIMatch, + regex_expected.clone(), + )?; + apply_logic_op( + &Arc::new(schema.clone()), + &a, + &b, + Operator::RegexNotMatch, + regex_not_expected.clone(), + )?; + apply_logic_op( + &Arc::new(schema), + &a, + &b, + Operator::RegexNotIMatch, + regex_not_expected.clone(), + )?; + + let schema = Schema::new(vec![ + Field::new("a", DataType::LargeUtf8, true), + Field::new("b", DataType::LargeUtf8, true), + ]); + let a = Arc::new(LargeStringArray::from(vec![ + Some("abc"), + None, + Some("abc"), + None, + Some("abc"), + ])) as ArrayRef; + let b = Arc::new(LargeStringArray::from(vec![ + Some("^a"), + Some("^A"), + None, + None, + Some("^(b|c)"), + ])) as ArrayRef; + + apply_logic_op( + &Arc::new(schema.clone()), + &a, + &b, + Operator::RegexMatch, + regex_expected.clone(), + )?; + apply_logic_op( + &Arc::new(schema.clone()), + &a, + &b, + Operator::RegexIMatch, + regex_expected, + )?; + apply_logic_op( + &Arc::new(schema.clone()), + &a, + &b, + Operator::RegexNotMatch, + regex_not_expected.clone(), + )?; + apply_logic_op( + &Arc::new(schema), + &a, + &b, + Operator::RegexNotIMatch, + regex_not_expected, + )?; + + Ok(()) + } + #[test] fn or_with_nulls_op() -> Result<()> { let schema = Schema::new(vec![ @@ -3435,8 +3702,8 @@ mod tests { expected: ArrayRef, ) -> Result<()> { let arithmetic_op = binary_op(col("a", schema)?, op, col("b", schema)?, schema)?; - let data: Vec = vec![left.clone(), right.clone()]; - let batch = RecordBatch::try_new(schema.clone(), data)?; + let data: Vec = vec![Arc::clone(left), Arc::clone(right)]; + let batch = RecordBatch::try_new(Arc::clone(schema), data)?; let result = arithmetic_op .evaluate(&batch)? .into_array(batch.num_rows()) @@ -3731,15 +3998,15 @@ mod tests { let left = Arc::new(Int32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef; let right = Arc::new(Int32Array::from(vec![Some(1), Some(3), Some(7)])) as ArrayRef; - let mut result = bitwise_and_dyn(left.clone(), right.clone())?; + let mut result = bitwise_and_dyn(Arc::clone(&left), Arc::clone(&right))?; let expected = Int32Array::from(vec![Some(0), None, Some(3)]); assert_eq!(result.as_ref(), &expected); - result = bitwise_or_dyn(left.clone(), right.clone())?; + result = bitwise_or_dyn(Arc::clone(&left), Arc::clone(&right))?; let expected = Int32Array::from(vec![Some(13), None, Some(15)]); assert_eq!(result.as_ref(), &expected); - result = bitwise_xor_dyn(left.clone(), right.clone())?; + result = bitwise_xor_dyn(Arc::clone(&left), Arc::clone(&right))?; let expected = Int32Array::from(vec![Some(13), None, Some(12)]); assert_eq!(result.as_ref(), &expected); @@ -3747,15 +4014,15 @@ mod tests { Arc::new(UInt32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef; let right = Arc::new(UInt32Array::from(vec![Some(1), Some(3), Some(7)])) as ArrayRef; - let mut result = bitwise_and_dyn(left.clone(), right.clone())?; + let mut result = bitwise_and_dyn(Arc::clone(&left), Arc::clone(&right))?; let expected = UInt32Array::from(vec![Some(0), None, Some(3)]); assert_eq!(result.as_ref(), &expected); - result = bitwise_or_dyn(left.clone(), right.clone())?; + result = bitwise_or_dyn(Arc::clone(&left), Arc::clone(&right))?; let expected = UInt32Array::from(vec![Some(13), None, Some(15)]); assert_eq!(result.as_ref(), &expected); - result = bitwise_xor_dyn(left.clone(), right.clone())?; + result = bitwise_xor_dyn(Arc::clone(&left), Arc::clone(&right))?; let expected = UInt32Array::from(vec![Some(13), None, Some(12)]); assert_eq!(result.as_ref(), &expected); @@ -3767,24 +4034,26 @@ mod tests { let input = Arc::new(Int32Array::from(vec![Some(2), None, Some(10)])) as ArrayRef; let modules = Arc::new(Int32Array::from(vec![Some(2), Some(4), Some(8)])) as ArrayRef; - let mut result = bitwise_shift_left_dyn(input.clone(), modules.clone())?; + let mut result = + bitwise_shift_left_dyn(Arc::clone(&input), Arc::clone(&modules))?; let expected = Int32Array::from(vec![Some(8), None, Some(2560)]); assert_eq!(result.as_ref(), &expected); - result = bitwise_shift_right_dyn(result.clone(), modules.clone())?; + result = bitwise_shift_right_dyn(Arc::clone(&result), Arc::clone(&modules))?; assert_eq!(result.as_ref(), &input); let input = Arc::new(UInt32Array::from(vec![Some(2), None, Some(10)])) as ArrayRef; let modules = Arc::new(UInt32Array::from(vec![Some(2), Some(4), Some(8)])) as ArrayRef; - let mut result = bitwise_shift_left_dyn(input.clone(), modules.clone())?; + let mut result = + bitwise_shift_left_dyn(Arc::clone(&input), Arc::clone(&modules))?; let expected = UInt32Array::from(vec![Some(8), None, Some(2560)]); assert_eq!(result.as_ref(), &expected); - result = bitwise_shift_right_dyn(result.clone(), modules.clone())?; + result = bitwise_shift_right_dyn(Arc::clone(&result), Arc::clone(&modules))?; assert_eq!(result.as_ref(), &input); Ok(()) } @@ -3793,14 +4062,14 @@ mod tests { fn bitwise_shift_array_overflow_test() -> Result<()> { let input = Arc::new(Int32Array::from(vec![Some(2)])) as ArrayRef; let modules = Arc::new(Int32Array::from(vec![Some(100)])) as ArrayRef; - let result = bitwise_shift_left_dyn(input.clone(), modules.clone())?; + let result = bitwise_shift_left_dyn(Arc::clone(&input), Arc::clone(&modules))?; let expected = Int32Array::from(vec![Some(32)]); assert_eq!(result.as_ref(), &expected); let input = Arc::new(UInt32Array::from(vec![Some(2)])) as ArrayRef; let modules = Arc::new(UInt32Array::from(vec![Some(100)])) as ArrayRef; - let result = bitwise_shift_left_dyn(input.clone(), modules.clone())?; + let result = bitwise_shift_left_dyn(Arc::clone(&input), Arc::clone(&modules))?; let expected = UInt32Array::from(vec![Some(32)]); assert_eq!(result.as_ref(), &expected); @@ -3937,9 +4206,12 @@ mod tests { Arc::new(DictionaryArray::try_new(keys, values).unwrap()) as ArrayRef; // Casting Dictionary to Int32 - let casted = - to_result_type_array(&Operator::Plus, dictionary.clone(), &DataType::Int32) - .unwrap(); + let casted = to_result_type_array( + &Operator::Plus, + Arc::clone(&dictionary), + &DataType::Int32, + ) + .unwrap(); assert_eq!( &casted, &(Arc::new(Int32Array::from(vec![Some(1), None, Some(3), Some(4)])) @@ -3949,16 +4221,164 @@ mod tests { // Array has same datatype as result type, no casting let casted = to_result_type_array( &Operator::Plus, - dictionary.clone(), + Arc::clone(&dictionary), dictionary.data_type(), ) .unwrap(); assert_eq!(&casted, &dictionary); // Not numerical operator, no casting - let casted = - to_result_type_array(&Operator::Eq, dictionary.clone(), &DataType::Int32) - .unwrap(); + let casted = to_result_type_array( + &Operator::Eq, + Arc::clone(&dictionary), + &DataType::Int32, + ) + .unwrap(); assert_eq!(&casted, &dictionary); } + + #[test] + fn test_add_with_overflow() -> Result<()> { + // create test data + let l = Arc::new(Int32Array::from(vec![1, i32::MAX])); + let r = Arc::new(Int32Array::from(vec![2, 1])); + let schema = Arc::new(Schema::new(vec![ + Field::new("l", DataType::Int32, false), + Field::new("r", DataType::Int32, false), + ])); + let batch = RecordBatch::try_new(schema, vec![l, r])?; + + // create expression + let expr = BinaryExpr::new( + Arc::new(Column::new("l", 0)), + Operator::Plus, + Arc::new(Column::new("r", 1)), + ) + .with_fail_on_overflow(true); + + // evaluate expression + let result = expr.evaluate(&batch); + assert!(result + .err() + .unwrap() + .to_string() + .contains("Overflow happened on: 2147483647 + 1")); + Ok(()) + } + + #[test] + fn test_subtract_with_overflow() -> Result<()> { + // create test data + let l = Arc::new(Int32Array::from(vec![1, i32::MIN])); + let r = Arc::new(Int32Array::from(vec![2, 1])); + let schema = Arc::new(Schema::new(vec![ + Field::new("l", DataType::Int32, false), + Field::new("r", DataType::Int32, false), + ])); + let batch = RecordBatch::try_new(schema, vec![l, r])?; + + // create expression + let expr = BinaryExpr::new( + Arc::new(Column::new("l", 0)), + Operator::Minus, + Arc::new(Column::new("r", 1)), + ) + .with_fail_on_overflow(true); + + // evaluate expression + let result = expr.evaluate(&batch); + assert!(result + .err() + .unwrap() + .to_string() + .contains("Overflow happened on: -2147483648 - 1")); + Ok(()) + } + + #[test] + fn test_mul_with_overflow() -> Result<()> { + // create test data + let l = Arc::new(Int32Array::from(vec![1, i32::MAX])); + let r = Arc::new(Int32Array::from(vec![2, 2])); + let schema = Arc::new(Schema::new(vec![ + Field::new("l", DataType::Int32, false), + Field::new("r", DataType::Int32, false), + ])); + let batch = RecordBatch::try_new(schema, vec![l, r])?; + + // create expression + let expr = BinaryExpr::new( + Arc::new(Column::new("l", 0)), + Operator::Multiply, + Arc::new(Column::new("r", 1)), + ) + .with_fail_on_overflow(true); + + // evaluate expression + let result = expr.evaluate(&batch); + assert!(result + .err() + .unwrap() + .to_string() + .contains("Overflow happened on: 2147483647 * 2")); + Ok(()) + } + + /// Test helper for SIMILAR TO binary operation + fn apply_similar_to( + schema: &SchemaRef, + va: Vec<&str>, + vb: Vec<&str>, + negated: bool, + case_insensitive: bool, + expected: &BooleanArray, + ) -> Result<()> { + let a = StringArray::from(va); + let b = StringArray::from(vb); + let op = similar_to( + negated, + case_insensitive, + col("a", schema)?, + col("b", schema)?, + )?; + let batch = + RecordBatch::try_new(Arc::clone(schema), vec![Arc::new(a), Arc::new(b)])?; + let result = op + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + assert_eq!(result.as_ref(), expected); + + Ok(()) + } + + #[test] + fn test_similar_to() { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Utf8, false), + ])); + + let expected = [Some(true), Some(false)].iter().collect(); + // case-sensitive + apply_similar_to( + &schema, + vec!["hello world", "Hello World"], + vec!["hello.*", "hello.*"], + false, + false, + &expected, + ) + .unwrap(); + // case-insensitive + apply_similar_to( + &schema, + vec!["hello world", "bye"], + vec!["hello.*", "hello.*"], + false, + true, + &expected, + ) + .unwrap(); + } } diff --git a/datafusion/physical-expr/src/expressions/binary/kernels.rs b/datafusion/physical-expr/src/expressions/binary/kernels.rs index b0736e140fec..c0685c6decde 100644 --- a/datafusion/physical-expr/src/expressions/binary/kernels.rs +++ b/datafusion/physical-expr/src/expressions/binary/kernels.rs @@ -24,9 +24,10 @@ use arrow::compute::kernels::bitwise::{ bitwise_xor, bitwise_xor_scalar, }; use arrow::datatypes::DataType; -use datafusion_common::internal_err; +use datafusion_common::plan_err; use datafusion_common::{Result, ScalarValue}; +use arrow_schema::ArrowError; use std::sync::Arc; /// Downcasts $LEFT and $RIGHT to $ARRAY_TYPE and then calls $KERNEL($LEFT, $RIGHT) @@ -69,7 +70,7 @@ macro_rules! create_dyn_kernel { DataType::UInt64 => { call_bitwise_kernel!(left, right, $KERNEL, UInt64Array) } - other => internal_err!( + other => plan_err!( "Data type {:?} not supported for binary operation '{}' on dyn arrays", other, stringify!($KERNEL) @@ -115,7 +116,7 @@ macro_rules! create_dyn_scalar_kernel { DataType::UInt16 => call_bitwise_scalar_kernel!(array, scalar, $KERNEL, UInt16Array, u16), DataType::UInt32 => call_bitwise_scalar_kernel!(array, scalar, $KERNEL, UInt32Array, u32), DataType::UInt64 => call_bitwise_scalar_kernel!(array, scalar, $KERNEL, UInt64Array, u64), - other => internal_err!( + other => plan_err!( "Data type {:?} not supported for binary operation '{}' on dyn arrays", other, stringify!($KERNEL) @@ -131,3 +132,35 @@ create_dyn_scalar_kernel!(bitwise_or_dyn_scalar, bitwise_or_scalar); create_dyn_scalar_kernel!(bitwise_xor_dyn_scalar, bitwise_xor_scalar); create_dyn_scalar_kernel!(bitwise_shift_right_dyn_scalar, bitwise_shift_right_scalar); create_dyn_scalar_kernel!(bitwise_shift_left_dyn_scalar, bitwise_shift_left_scalar); + +pub fn concat_elements_utf8view( + left: &StringViewArray, + right: &StringViewArray, +) -> std::result::Result { + let capacity = left + .data_buffers() + .iter() + .zip(right.data_buffers().iter()) + .map(|(b1, b2)| b1.len() + b2.len()) + .sum(); + let mut result = StringViewBuilder::with_capacity(capacity); + + // Avoid reallocations by writing to a reused buffer (note we + // could be even more efficient r by creating the view directly + // here and avoid the buffer but that would be more complex) + let mut buffer = String::new(); + + for (left, right) in left.iter().zip(right.iter()) { + if let (Some(left), Some(right)) = (left, right) { + use std::fmt::Write; + buffer.clear(); + write!(&mut buffer, "{left}{right}") + .expect("writing into string buffer failed"); + result.append_value(&buffer); + } else { + // at least one of the values is null, so the output is also null + result.append_null() + } + } + Ok(result.finish()) +} diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 7b10df9ac146..981e49d73750 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -19,23 +19,50 @@ use std::borrow::Cow; use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; -use crate::expressions::{try_cast, NoOp}; +use crate::expressions::try_cast; use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; use arrow::array::*; -use arrow::compute::kernels::cmp::eq; use arrow::compute::kernels::zip::zip; -use arrow::compute::{and, is_null, not, nullif, or, prep_null_mask_filter}; +use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter}; use arrow::datatypes::{DataType, Schema}; use datafusion_common::cast::as_boolean_array; use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::ColumnarValue; +use super::{Column, Literal}; +use datafusion_physical_expr_common::datum::compare_with_eq; use itertools::Itertools; type WhenThen = (Arc, Arc); +#[derive(Debug, Hash)] +enum EvalMethod { + /// CASE WHEN condition THEN result + /// [WHEN ...] + /// [ELSE result] + /// END + NoExpression, + /// CASE expression + /// WHEN value THEN result + /// [WHEN ...] + /// [ELSE result] + /// END + WithExpression, + /// This is a specialization for a specific use case where we can take a fast path + /// for expressions that are infallible and can be cheaply computed for the entire + /// record batch rather than just for the rows where the predicate is true. + /// + /// CASE WHEN condition THEN column [ELSE NULL] END + InfallibleExprOrNull, + /// This is a specialization for a specific use case where we can take a fast path + /// if there is just one when/then pair and both the `then` and `else` expressions + /// are literal values + /// CASE WHEN condition THEN literal ELSE literal END + ScalarOrScalar, +} + /// The CASE expression is similar to a series of nested if/else and there are two forms that /// can be used. The first form consists of a series of boolean "when" expressions with /// corresponding "then" expressions, and an optional "else" expression. @@ -61,6 +88,8 @@ pub struct CaseExpr { when_then_expr: Vec, /// Optional "else" expression else_expr: Option>, + /// Evaluation method to use + eval_method: EvalMethod, } impl std::fmt::Display for CaseExpr { @@ -79,6 +108,15 @@ impl std::fmt::Display for CaseExpr { } } +/// This is a specialization for a specific use case where we can take a fast path +/// for expressions that are infallible and can be cheaply computed for the entire +/// record batch rather than just for the rows where the predicate is true. For now, +/// this is limited to use with Column expressions but could potentially be used for other +/// expressions in the future +fn is_cheap_and_infallible(expr: &Arc) -> bool { + expr.as_any().is::() +} + impl CaseExpr { /// Create a new CASE WHEN expression pub fn try_new( @@ -86,13 +124,41 @@ impl CaseExpr { when_then_expr: Vec, else_expr: Option>, ) -> Result { + // normalize null literals to None in the else_expr (this already happens + // during SQL planning, but not necessarily for other use cases) + let else_expr = match &else_expr { + Some(e) => match e.as_any().downcast_ref::() { + Some(lit) if lit.value().is_null() => None, + _ => else_expr, + }, + _ => else_expr, + }; + if when_then_expr.is_empty() { exec_err!("There must be at least one WHEN clause") } else { + let eval_method = if expr.is_some() { + EvalMethod::WithExpression + } else if when_then_expr.len() == 1 + && is_cheap_and_infallible(&(when_then_expr[0].1)) + && else_expr.is_none() + { + EvalMethod::InfallibleExprOrNull + } else if when_then_expr.len() == 1 + && when_then_expr[0].1.as_any().is::() + && else_expr.is_some() + && else_expr.as_ref().unwrap().as_any().is::() + { + EvalMethod::ScalarOrScalar + } else { + EvalMethod::NoExpression + }; + Ok(Self { expr, when_then_expr, else_expr, + eval_method, }) } } @@ -138,7 +204,13 @@ impl CaseExpr { .evaluate_selection(batch, &remainder)?; let when_value = when_value.into_array(batch.num_rows())?; // build boolean array representing which rows match the "when" value - let when_match = eq(&when_value, &base_value)?; + let when_match = compare_with_eq( + &when_value, + &base_value, + // The types of case and when expressions will be coerced to match. + // We only need to check if the base_value is nested. + base_value.data_type().is_nested(), + )?; // Treat nulls as false let when_match = match when_match.null_count() { 0 => Cow::Borrowed(&when_match), @@ -168,13 +240,13 @@ impl CaseExpr { } }; - remainder = and(&remainder, ¬(&when_match)?)?; + remainder = and_not(&remainder, &when_match)?; } if let Some(e) = &self.else_expr { // keep `else_expr`'s data type and return type consistent - let expr = try_cast(e.clone(), &batch.schema(), return_type.clone()) - .unwrap_or_else(|_| e.clone()); + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) + .unwrap_or_else(|_| Arc::clone(e)); // null and unmatched tuples should be assigned else value remainder = or(&base_nulls, &remainder)?; let else_ = expr @@ -241,13 +313,13 @@ impl CaseExpr { // Succeed tuples should be filtered out for short-circuit evaluation, // null values for the current when expr should be kept - remainder = and(&remainder, ¬(&when_value)?)?; + remainder = and_not(&remainder, &when_value)?; } if let Some(e) = &self.else_expr { // keep `else_expr`'s data type and return type consistent - let expr = try_cast(e.clone(), &batch.schema(), return_type.clone()) - .unwrap_or_else(|_| e.clone()); + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone()) + .unwrap_or_else(|_| Arc::clone(e)); let else_ = expr .evaluate_selection(batch, &remainder)? .into_array(batch.num_rows())?; @@ -256,6 +328,70 @@ impl CaseExpr { Ok(ColumnarValue::Array(current_value)) } + + /// This function evaluates the specialized case of: + /// + /// CASE WHEN condition THEN column + /// [ELSE NULL] + /// END + /// + /// Note that this function is only safe to use for "then" expressions + /// that are infallible because the expression will be evaluated for all + /// rows in the input batch. + fn case_column_or_null(&self, batch: &RecordBatch) -> Result { + let when_expr = &self.when_then_expr[0].0; + let then_expr = &self.when_then_expr[0].1; + if let ColumnarValue::Array(bit_mask) = when_expr.evaluate(batch)? { + let bit_mask = bit_mask + .as_any() + .downcast_ref::() + .expect("predicate should evaluate to a boolean array"); + // invert the bitmask + let bit_mask = not(bit_mask)?; + match then_expr.evaluate(batch)? { + ColumnarValue::Array(array) => { + Ok(ColumnarValue::Array(nullif(&array, &bit_mask)?)) + } + ColumnarValue::Scalar(_) => { + internal_err!("expression did not evaluate to an array") + } + } + } else { + internal_err!("predicate did not evaluate to an array") + } + } + + fn scalar_or_scalar(&self, batch: &RecordBatch) -> Result { + let return_type = self.data_type(&batch.schema())?; + + // evaluate when expression + let when_value = self.when_then_expr[0].0.evaluate(batch)?; + let when_value = when_value.into_array(batch.num_rows())?; + let when_value = as_boolean_array(&when_value).map_err(|e| { + DataFusionError::Context( + "WHEN expression did not return a BooleanArray".to_string(), + Box::new(e), + ) + })?; + + // Treat 'NULL' as false value + let when_value = match when_value.null_count() { + 0 => Cow::Borrowed(when_value), + _ => Cow::Owned(prep_null_mask_filter(when_value)), + }; + + // evaluate then_value + let then_value = self.when_then_expr[0].1.evaluate(batch)?; + let then_value = Scalar::new(then_value.into_array(1)?); + + // keep `else_expr`'s data type and return type consistent + let e = self.else_expr.as_ref().unwrap(); + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type) + .unwrap_or_else(|_| Arc::clone(e)); + let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?); + + Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?)) + } } impl PhysicalExpr for CaseExpr { @@ -303,31 +439,37 @@ impl PhysicalExpr for CaseExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result { - if self.expr.is_some() { - // this use case evaluates "expr" and then compares the values with the "when" - // values - self.case_when_with_expr(batch) - } else { - // The "when" conditions all evaluate to boolean in this use case and can be - // arbitrary expressions - self.case_when_no_expr(batch) + match self.eval_method { + EvalMethod::WithExpression => { + // this use case evaluates "expr" and then compares the values with the "when" + // values + self.case_when_with_expr(batch) + } + EvalMethod::NoExpression => { + // The "when" conditions all evaluate to boolean in this use case and can be + // arbitrary expressions + self.case_when_no_expr(batch) + } + EvalMethod::InfallibleExprOrNull => { + // Specialization for CASE WHEN expr THEN column [ELSE NULL] END + self.case_column_or_null(batch) + } + EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch), } } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { let mut children = vec![]; - match &self.expr { - Some(expr) => children.push(expr.clone()), - None => children.push(Arc::new(NoOp::new())), + if let Some(expr) = &self.expr { + children.push(expr) } self.when_then_expr.iter().for_each(|(cond, value)| { - children.push(cond.clone()); - children.push(value.clone()); + children.push(cond); + children.push(value); }); - match &self.else_expr { - Some(expr) => children.push(expr.clone()), - None => children.push(Arc::new(NoOp::new())), + if let Some(else_expr) = &self.else_expr { + children.push(else_expr) } children } @@ -340,29 +482,27 @@ impl PhysicalExpr for CaseExpr { if children.len() != self.children().len() { internal_err!("CaseExpr: Wrong number of children") } else { - assert_eq!(children.len() % 2, 0); - let expr = match children[0].clone().as_any().downcast_ref::() { - Some(_) => None, - _ => Some(children[0].clone()), - }; - let else_expr = match children[children.len() - 1] - .clone() - .as_any() - .downcast_ref::() - { - Some(_) => None, - _ => Some(children[children.len() - 1].clone()), - }; - - let branches = children[1..children.len() - 1].to_vec(); - let mut when_then_expr: Vec = vec![]; - for (prev, next) in branches.into_iter().tuples() { - when_then_expr.push((prev, next)); - } + let (expr, when_then_expr, else_expr) = + match (self.expr().is_some(), self.else_expr().is_some()) { + (true, true) => ( + Some(&children[0]), + &children[1..children.len() - 1], + Some(&children[children.len() - 1]), + ), + (true, false) => { + (Some(&children[0]), &children[1..children.len()], None) + } + (false, true) => ( + None, + &children[0..children.len() - 1], + Some(&children[children.len() - 1]), + ), + (false, false) => (None, &children[0..children.len()], None), + }; Ok(Arc::new(CaseExpr::try_new( - expr, - when_then_expr, - else_expr, + expr.cloned(), + when_then_expr.iter().cloned().tuples().collect(), + else_expr.cloned(), )?)) } } @@ -413,8 +553,8 @@ pub fn case( #[cfg(test)] mod tests { use super::*; - use crate::expressions::{binary, cast, col, lit}; + use crate::expressions::{binary, cast, col, lit, BinaryExpr}; use arrow::buffer::Buffer; use arrow::datatypes::DataType::Float64; use arrow::datatypes::*; @@ -874,7 +1014,7 @@ mod tests { ); assert!(expr.is_ok()); let result_type = expr.unwrap().data_type(schema.as_ref())?; - assert_eq!(DataType::Float64, result_type); + assert_eq!(Float64, result_type); Ok(()) } @@ -891,26 +1031,26 @@ mod tests { let expr1 = generate_case_when_with_type_coercion( Some(col("a", &schema)?), vec![ - (when1.clone(), then1.clone()), - (when2.clone(), then2.clone()), + (Arc::clone(&when1), Arc::clone(&then1)), + (Arc::clone(&when2), Arc::clone(&then2)), ], - Some(else_value.clone()), + Some(Arc::clone(&else_value)), &schema, )?; let expr2 = generate_case_when_with_type_coercion( Some(col("a", &schema)?), vec![ - (when1.clone(), then1.clone()), - (when2.clone(), then2.clone()), + (Arc::clone(&when1), Arc::clone(&then1)), + (Arc::clone(&when2), Arc::clone(&then2)), ], - Some(else_value.clone()), + Some(Arc::clone(&else_value)), &schema, )?; let expr3 = generate_case_when_with_type_coercion( Some(col("a", &schema)?), - vec![(when1.clone(), then1.clone()), (when2, then2)], + vec![(Arc::clone(&when1), Arc::clone(&then1)), (when2, then2)], None, &schema, )?; @@ -935,7 +1075,7 @@ mod tests { } #[test] - fn case_tranform() -> Result<()> { + fn case_transform() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); let when1 = lit("foo"); @@ -947,26 +1087,24 @@ mod tests { let expr = generate_case_when_with_type_coercion( Some(col("a", &schema)?), vec![ - (when1.clone(), then1.clone()), - (when2.clone(), then2.clone()), + (Arc::clone(&when1), Arc::clone(&then1)), + (Arc::clone(&when2), Arc::clone(&then2)), ], - Some(else_value.clone()), + Some(Arc::clone(&else_value)), &schema, )?; - let expr2 = expr - .clone() + let expr2 = Arc::clone(&expr) .transform(|e| { - let transformed = - match e.as_any().downcast_ref::() { - Some(lit_value) => match lit_value.value() { - ScalarValue::Utf8(Some(str_value)) => { - Some(lit(str_value.to_uppercase())) - } - _ => None, - }, + let transformed = match e.as_any().downcast_ref::() { + Some(lit_value) => match lit_value.value() { + ScalarValue::Utf8(Some(str_value)) => { + Some(lit(str_value.to_uppercase())) + } _ => None, - }; + }, + _ => None, + }; Ok(if let Some(transformed) = transformed { Transformed::yes(transformed) } else { @@ -976,19 +1114,17 @@ mod tests { .data() .unwrap(); - let expr3 = expr - .clone() + let expr3 = Arc::clone(&expr) .transform_down(|e| { - let transformed = - match e.as_any().downcast_ref::() { - Some(lit_value) => match lit_value.value() { - ScalarValue::Utf8(Some(str_value)) => { - Some(lit(str_value.to_uppercase())) - } - _ => None, - }, + let transformed = match e.as_any().downcast_ref::() { + Some(lit_value) => match lit_value.value() { + ScalarValue::Utf8(Some(str_value)) => { + Some(lit(str_value.to_uppercase())) + } _ => None, - }; + }, + _ => None, + }; Ok(if let Some(transformed) = transformed { Transformed::yes(transformed) } else { @@ -1004,6 +1140,53 @@ mod tests { Ok(()) } + #[test] + fn test_column_or_null_specialization() -> Result<()> { + // create input data + let mut c1 = Int32Builder::new(); + let mut c2 = StringBuilder::new(); + for i in 0..1000 { + c1.append_value(i); + if i % 7 == 0 { + c2.append_null(); + } else { + c2.append_value(format!("string {i}")); + } + } + let c1 = Arc::new(c1.finish()); + let c2 = Arc::new(c2.finish()); + let schema = Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Utf8, true), + ]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2]).unwrap(); + + // CaseWhenExprOrNull should produce same results as CaseExpr + let predicate = Arc::new(BinaryExpr::new( + make_col("c1", 0), + Operator::LtEq, + make_lit_i32(250), + )); + let expr = CaseExpr::try_new(None, vec![(predicate, make_col("c2", 1))], None)?; + assert!(matches!(expr.eval_method, EvalMethod::InfallibleExprOrNull)); + match expr.evaluate(&batch)? { + ColumnarValue::Array(array) => { + assert_eq!(1000, array.len()); + assert_eq!(785, array.null_count()); + } + _ => unreachable!(), + } + Ok(()) + } + + fn make_col(name: &str, index: usize) -> Arc { + Arc::new(Column::new(name, index)) + } + + fn make_lit_i32(n: i32) -> Arc { + Arc::new(Literal::new(ScalarValue::Int32(Some(n)))) + } + fn generate_case_when_with_type_coercion( expr: Option>, when_thens: Vec, diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index a3b32461e581..457c47097a19 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -15,28 +15,32 @@ // specific language governing permissions and limitations // under the License. -use crate::physical_expr::down_cast_any_ref; -use crate::sort_properties::SortProperties; -use crate::PhysicalExpr; use std::any::Any; use std::fmt; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use DataType::*; + +use crate::physical_expr::{down_cast_any_ref, PhysicalExpr}; use arrow::compute::{can_cast_types, CastOptions}; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, DataType::*, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; use datafusion_common::{not_impl_err, Result}; -use datafusion_expr::interval_arithmetic::Interval; -use datafusion_expr::ColumnarValue; +use datafusion_expr_common::columnar_value::ColumnarValue; +use datafusion_expr_common::interval_arithmetic::Interval; +use datafusion_expr_common::sort_properties::ExprProperties; const DEFAULT_CAST_OPTIONS: CastOptions<'static> = CastOptions { safe: false, format_options: DEFAULT_FORMAT_OPTIONS, }; +const DEFAULT_SAFE_CAST_OPTIONS: CastOptions<'static> = CastOptions { + safe: true, + format_options: DEFAULT_FORMAT_OPTIONS, +}; + /// CAST expression casts an expression to a specific data type and returns a runtime error on invalid cast #[derive(Debug, Clone)] pub struct CastExpr { @@ -123,8 +127,8 @@ impl PhysicalExpr for CastExpr { value.cast_to(&self.cast_type, Some(&self.cast_options)) } - fn children(&self) -> Vec> { - vec![self.expr.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.expr] } fn with_new_children( @@ -132,7 +136,7 @@ impl PhysicalExpr for CastExpr { children: Vec>, ) -> Result> { Ok(Arc::new(CastExpr::new( - children[0].clone(), + Arc::clone(&children[0]), self.cast_type.clone(), Some(self.cast_options.clone()), ))) @@ -151,9 +155,9 @@ impl PhysicalExpr for CastExpr { let child_interval = children[0]; // Get child's datatype: let cast_type = child_interval.data_type(); - Ok(Some( - vec![interval.cast_to(&cast_type, &self.cast_options)?], - )) + Ok(Some(vec![ + interval.cast_to(&cast_type, &DEFAULT_SAFE_CAST_OPTIONS)? + ])) } fn dyn_hash(&self, state: &mut dyn Hasher) { @@ -163,9 +167,22 @@ impl PhysicalExpr for CastExpr { self.cast_options.hash(&mut s); } - /// A [`CastExpr`] preserves the ordering of its child. - fn get_ordering(&self, children: &[SortProperties]) -> SortProperties { - children[0] + /// A [`CastExpr`] preserves the ordering of its child if the cast is done + /// under the same datatype family. + fn get_properties(&self, children: &[ExprProperties]) -> Result { + let source_datatype = children[0].range.data_type(); + let target_type = &self.cast_type; + + let unbounded = Interval::make_unbounded(target_type)?; + if (source_datatype.is_numeric() || source_datatype == Boolean) + && target_type.is_numeric() + || source_datatype.is_temporal() && target_type.is_temporal() + || source_datatype.eq(target_type) + { + Ok(children[0].clone().with_range(unbounded)) + } else { + Ok(ExprProperties::new_unknown().with_range(unbounded)) + } } } @@ -194,7 +211,7 @@ pub fn cast_with_options( ) -> Result> { let expr_type = expr.data_type(input_schema)?; if expr_type == cast_type { - Ok(expr.clone()) + Ok(Arc::clone(&expr)) } else if can_cast_types(&expr_type, &cast_type) { Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options))) } else { @@ -217,7 +234,8 @@ pub fn cast( #[cfg(test)] mod tests { use super::*; - use crate::expressions::col; + + use crate::expressions::column::col; use arrow::{ array::{ @@ -353,9 +371,9 @@ mod tests { generic_decimal_to_other_test_cast!( decimal_array, - DataType::Decimal128(10, 3), + Decimal128(10, 3), Decimal128Array, - DataType::Decimal128(20, 6), + Decimal128(20, 6), [ Some(1_234_000), Some(2_222_000), @@ -374,9 +392,9 @@ mod tests { generic_decimal_to_other_test_cast!( decimal_array, - DataType::Decimal128(10, 3), + Decimal128(10, 3), Decimal128Array, - DataType::Decimal128(10, 2), + Decimal128(10, 2), [Some(123), Some(222), Some(0), Some(400), Some(500), None], None ); @@ -395,9 +413,9 @@ mod tests { .with_precision_and_scale(10, 0)?; generic_decimal_to_other_test_cast!( decimal_array, - DataType::Decimal128(10, 0), + Decimal128(10, 0), Int8Array, - DataType::Int8, + Int8, [ Some(1_i8), Some(2_i8), @@ -417,9 +435,9 @@ mod tests { .with_precision_and_scale(10, 0)?; generic_decimal_to_other_test_cast!( decimal_array, - DataType::Decimal128(10, 0), + Decimal128(10, 0), Int16Array, - DataType::Int16, + Int16, [ Some(1_i16), Some(2_i16), @@ -439,9 +457,9 @@ mod tests { .with_precision_and_scale(10, 0)?; generic_decimal_to_other_test_cast!( decimal_array, - DataType::Decimal128(10, 0), + Decimal128(10, 0), Int32Array, - DataType::Int32, + Int32, [ Some(1_i32), Some(2_i32), @@ -460,9 +478,9 @@ mod tests { .with_precision_and_scale(10, 0)?; generic_decimal_to_other_test_cast!( decimal_array, - DataType::Decimal128(10, 0), + Decimal128(10, 0), Int64Array, - DataType::Int64, + Int64, [ Some(1_i64), Some(2_i64), @@ -490,9 +508,9 @@ mod tests { .with_precision_and_scale(10, 3)?; generic_decimal_to_other_test_cast!( decimal_array, - DataType::Decimal128(10, 3), + Decimal128(10, 3), Float32Array, - DataType::Float32, + Float32, [ Some(1.234_f32), Some(2.222_f32), @@ -511,9 +529,9 @@ mod tests { .with_precision_and_scale(20, 6)?; generic_decimal_to_other_test_cast!( decimal_array, - DataType::Decimal128(20, 6), + Decimal128(20, 6), Float64Array, - DataType::Float64, + Float64, [ Some(0.001234_f64), Some(0.002222_f64), @@ -532,10 +550,10 @@ mod tests { // int8 generic_test_cast!( Int8Array, - DataType::Int8, + Int8, vec![1, 2, 3, 4, 5], Decimal128Array, - DataType::Decimal128(3, 0), + Decimal128(3, 0), [Some(1), Some(2), Some(3), Some(4), Some(5)], None ); @@ -543,10 +561,10 @@ mod tests { // int16 generic_test_cast!( Int16Array, - DataType::Int16, + Int16, vec![1, 2, 3, 4, 5], Decimal128Array, - DataType::Decimal128(5, 0), + Decimal128(5, 0), [Some(1), Some(2), Some(3), Some(4), Some(5)], None ); @@ -554,10 +572,10 @@ mod tests { // int32 generic_test_cast!( Int32Array, - DataType::Int32, + Int32, vec![1, 2, 3, 4, 5], Decimal128Array, - DataType::Decimal128(10, 0), + Decimal128(10, 0), [Some(1), Some(2), Some(3), Some(4), Some(5)], None ); @@ -565,10 +583,10 @@ mod tests { // int64 generic_test_cast!( Int64Array, - DataType::Int64, + Int64, vec![1, 2, 3, 4, 5], Decimal128Array, - DataType::Decimal128(20, 0), + Decimal128(20, 0), [Some(1), Some(2), Some(3), Some(4), Some(5)], None ); @@ -576,10 +594,10 @@ mod tests { // int64 to different scale generic_test_cast!( Int64Array, - DataType::Int64, + Int64, vec![1, 2, 3, 4, 5], Decimal128Array, - DataType::Decimal128(20, 2), + Decimal128(20, 2), [Some(100), Some(200), Some(300), Some(400), Some(500)], None ); @@ -587,10 +605,10 @@ mod tests { // float32 generic_test_cast!( Float32Array, - DataType::Float32, + Float32, vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50], Decimal128Array, - DataType::Decimal128(10, 2), + Decimal128(10, 2), [Some(150), Some(250), Some(300), Some(112), Some(550)], None ); @@ -598,10 +616,10 @@ mod tests { // float64 generic_test_cast!( Float64Array, - DataType::Float64, + Float64, vec![1.5, 2.5, 3.0, 1.123_456_8, 5.50], Decimal128Array, - DataType::Decimal128(20, 4), + Decimal128(20, 4), [ Some(15000), Some(25000), @@ -618,10 +636,10 @@ mod tests { fn test_cast_i32_u32() -> Result<()> { generic_test_cast!( Int32Array, - DataType::Int32, + Int32, vec![1, 2, 3, 4, 5], UInt32Array, - DataType::UInt32, + UInt32, [ Some(1_u32), Some(2_u32), @@ -638,10 +656,10 @@ mod tests { fn test_cast_i32_utf8() -> Result<()> { generic_test_cast!( Int32Array, - DataType::Int32, + Int32, vec![1, 2, 3, 4, 5], StringArray, - DataType::Utf8, + Utf8, [Some("1"), Some("2"), Some("3"), Some("4"), Some("5")], None ); @@ -657,10 +675,10 @@ mod tests { .collect(); generic_test_cast!( Int64Array, - DataType::Int64, + Int64, original, TimestampNanosecondArray, - DataType::Timestamp(TimeUnit::Nanosecond, None), + Timestamp(TimeUnit::Nanosecond, None), expected, None ); @@ -670,12 +688,12 @@ mod tests { #[test] fn invalid_cast() { // Ensure a useful error happens at plan time if invalid casts are used - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let schema = Schema::new(vec![Field::new("a", Int32, false)]); let result = cast( col("a", &schema).unwrap(), &schema, - DataType::Interval(IntervalUnit::MonthDayNano), + Interval(IntervalUnit::MonthDayNano), ); result.expect_err("expected Invalid CAST"); } @@ -683,11 +701,10 @@ mod tests { #[test] fn invalid_cast_with_options_error() -> Result<()> { // Ensure a useful error happens at plan time if invalid casts are used - let schema = Schema::new(vec![Field::new("a", DataType::Utf8, false)]); + let schema = Schema::new(vec![Field::new("a", Utf8, false)]); let a = StringArray::from(vec!["9.1"]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - let expression = - cast_with_options(col("a", &schema)?, &schema, DataType::Int32, None)?; + let expression = cast_with_options(col("a", &schema)?, &schema, Int32, None)?; let result = expression.evaluate(&batch); match result { @@ -704,15 +721,11 @@ mod tests { #[test] #[ignore] // TODO: https://github.com/apache/datafusion/issues/5396 fn test_cast_decimal() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Int64, false)]); + let schema = Schema::new(vec![Field::new("a", Int64, false)]); let a = Int64Array::from(vec![100]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - let expression = cast_with_options( - col("a", &schema)?, - &schema, - DataType::Decimal128(38, 38), - None, - )?; + let expression = + cast_with_options(col("a", &schema)?, &schema, Decimal128(38, 38), None)?; expression.evaluate(&batch)?; Ok(()) } diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index 634a56d1d683..3e2d49e9fa69 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -15,69 +15,121 @@ // specific language governing permissions and limitations // under the License. -//! Column expression +//! Physical column reference: [`Column`] use std::any::Any; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::physical_expr::down_cast_any_ref; -use crate::PhysicalExpr; - use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; -use datafusion_common::{internal_err, Result}; +use arrow_schema::SchemaRef; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{internal_err, plan_err, Result}; use datafusion_expr::ColumnarValue; +use crate::physical_expr::{down_cast_any_ref, PhysicalExpr}; + +/// Represents the column at a given index in a RecordBatch +/// +/// This is a physical expression that represents a column at a given index in an +/// arrow [`Schema`] / [`RecordBatch`]. +/// +/// Unlike the [logical `Expr::Column`], this expression is always resolved by schema index, +/// even though it does have a name. This is because the physical plan is always +/// resolved to a specific schema and there is no concept of "relation" +/// +/// # Example: +/// If the schema is `a`, `b`, `c` the `Column` for `b` would be represented by +/// index 1, since `b` is the second colum in the schema. +/// +/// ``` +/// # use datafusion_physical_expr::expressions::Column; +/// # use arrow::datatypes::{DataType, Field, Schema}; +/// // Schema with columns a, b, c +/// let schema = Schema::new(vec![ +/// Field::new("a", DataType::Int32, false), +/// Field::new("b", DataType::Int32, false), +/// Field::new("c", DataType::Int32, false), +/// ]); +/// +/// // reference to column b is index 1 +/// let column_b = Column::new_with_schema("b", &schema).unwrap(); +/// assert_eq!(column_b.index(), 1); +/// +/// // reference to column c is index 2 +/// let column_c = Column::new_with_schema("c", &schema).unwrap(); +/// assert_eq!(column_c.index(), 2); +/// ``` +/// [logical `Expr::Column`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.Expr.html#variant.Column #[derive(Debug, Hash, PartialEq, Eq, Clone)] -pub struct UnKnownColumn { +pub struct Column { + /// The name of the column (used for debugging and display purposes) name: String, + /// The index of the column in its schema + index: usize, } -impl UnKnownColumn { - /// Create a new unknown column expression - pub fn new(name: &str) -> Self { +impl Column { + /// Create a new column expression which references the + /// column with the given index in the schema. + pub fn new(name: &str, index: usize) -> Self { Self { name: name.to_owned(), + index, } } - /// Get the column name + /// Create a new column expression which references the + /// column with the given name in the schema + pub fn new_with_schema(name: &str, schema: &Schema) -> Result { + Ok(Column::new(name, schema.index_of(name)?)) + } + + /// Get the column's name pub fn name(&self) -> &str { &self.name } + + /// Get the column's schema index + pub fn index(&self) -> usize { + self.index + } } -impl std::fmt::Display for UnKnownColumn { +impl std::fmt::Display for Column { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}", self.name) + write!(f, "{}@{}", self.name, self.index) } } -impl PhysicalExpr for UnKnownColumn { +impl PhysicalExpr for Column { /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } /// Get the data type of this expression, given the schema of the input - fn data_type(&self, _input_schema: &Schema) -> Result { - Ok(DataType::Null) + fn data_type(&self, input_schema: &Schema) -> Result { + self.bounds_check(input_schema)?; + Ok(input_schema.field(self.index).data_type().clone()) } - /// Decide whehter this expression is nullable, given the schema of the input - fn nullable(&self, _input_schema: &Schema) -> Result { - Ok(true) + /// Decide whether this expression is nullable, given the schema of the input + fn nullable(&self, input_schema: &Schema) -> Result { + self.bounds_check(input_schema)?; + Ok(input_schema.field(self.index).is_nullable()) } /// Evaluate the expression - fn evaluate(&self, _batch: &RecordBatch) -> Result { - internal_err!("UnKnownColumn::evaluate() should not be called") + fn evaluate(&self, batch: &RecordBatch) -> Result { + self.bounds_check(batch.schema().as_ref())?; + Ok(ColumnarValue::Array(Arc::clone(batch.column(self.index)))) } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { vec![] } @@ -94,7 +146,7 @@ impl PhysicalExpr for UnKnownColumn { } } -impl PartialEq for UnKnownColumn { +impl PartialEq for Column { fn eq(&self, other: &dyn Any) -> bool { down_cast_any_ref(other) .downcast_ref::() @@ -103,10 +155,58 @@ impl PartialEq for UnKnownColumn { } } +impl Column { + fn bounds_check(&self, input_schema: &Schema) -> Result<()> { + if self.index < input_schema.fields.len() { + Ok(()) + } else { + internal_err!( + "PhysicalExpr Column references column '{}' at index {} (zero-based) but input schema only has {} columns: {:?}", + self.name, + self.index, + input_schema.fields.len(), + input_schema.fields().iter().map(|f| f.name()).collect::>() + ) + } + } +} + +/// Create a column expression +pub fn col(name: &str, schema: &Schema) -> Result> { + Ok(Arc::new(Column::new_with_schema(name, schema)?)) +} + +/// Rewrites an expression according to new schema; i.e. changes the columns it +/// refers to with the column at corresponding index in the new schema. Returns +/// an error if the given schema has fewer columns than the original schema. +/// Note that the resulting expression may not be valid if data types in the +/// new schema is incompatible with expression nodes. +pub fn with_new_schema( + expr: Arc, + schema: &SchemaRef, +) -> Result> { + Ok(expr + .transform_up(|expr| { + if let Some(col) = expr.as_any().downcast_ref::() { + let idx = col.index(); + let Some(field) = schema.fields().get(idx) else { + return plan_err!( + "New schema has fewer columns than original schema" + ); + }; + let new_col = Column::new(field.name(), idx); + Ok(Transformed::yes(Arc::new(new_col) as _)) + } else { + Ok(Transformed::no(expr)) + } + })? + .data) +} + #[cfg(test)] mod test { - use crate::expressions::Column; - use crate::PhysicalExpr; + use super::Column; + use crate::physical_expr::PhysicalExpr; use arrow::array::StringArray; use arrow::datatypes::{DataType, Field, Schema}; diff --git a/datafusion/physical-expr/src/expressions/datum.rs b/datafusion/physical-expr/src/expressions/datum.rs deleted file mode 100644 index 2bb79922cfec..000000000000 --- a/datafusion/physical-expr/src/expressions/datum.rs +++ /dev/null @@ -1,58 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow::array::{ArrayRef, Datum}; -use arrow::error::ArrowError; -use arrow_array::BooleanArray; -use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::ColumnarValue; -use std::sync::Arc; - -/// Applies a binary [`Datum`] kernel `f` to `lhs` and `rhs` -/// -/// This maps arrow-rs' [`Datum`] kernels to DataFusion's [`ColumnarValue`] abstraction -pub(crate) fn apply( - lhs: &ColumnarValue, - rhs: &ColumnarValue, - f: impl Fn(&dyn Datum, &dyn Datum) -> Result, -) -> Result { - match (&lhs, &rhs) { - (ColumnarValue::Array(left), ColumnarValue::Array(right)) => { - Ok(ColumnarValue::Array(f(&left.as_ref(), &right.as_ref())?)) - } - (ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => Ok( - ColumnarValue::Array(f(&left.to_scalar()?, &right.as_ref())?), - ), - (ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => Ok( - ColumnarValue::Array(f(&left.as_ref(), &right.to_scalar()?)?), - ), - (ColumnarValue::Scalar(left), ColumnarValue::Scalar(right)) => { - let array = f(&left.to_scalar()?, &right.to_scalar()?)?; - let scalar = ScalarValue::try_from_array(array.as_ref(), 0)?; - Ok(ColumnarValue::Scalar(scalar)) - } - } -} - -/// Applies a binary [`Datum`] comparison kernel `f` to `lhs` and `rhs` -pub(crate) fn apply_cmp( - lhs: &ColumnarValue, - rhs: &ColumnarValue, - f: impl Fn(&dyn Datum, &dyn Datum) -> Result, -) -> Result { - apply(lhs, rhs, |l, r| Ok(Arc::new(f(l, r)?))) -} diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 9ae4c2784ccf..cf57ce3e0e21 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -28,17 +28,20 @@ use crate::PhysicalExpr; use arrow::array::*; use arrow::buffer::BooleanBuffer; use arrow::compute::kernels::boolean::{not, or_kleene}; -use arrow::compute::kernels::cmp::eq; use arrow::compute::take; use arrow::datatypes::*; use arrow::util::bit_iterator::BitIndexIterator; use arrow::{downcast_dictionary_array, downcast_primitive_array}; +use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use datafusion_common::cast::{ as_boolean_array, as_generic_binary_array, as_string_array, }; use datafusion_common::hash_utils::HashValue; -use datafusion_common::{exec_err, internal_err, not_impl_err, Result, ScalarValue}; +use datafusion_common::{ + exec_err, internal_err, not_impl_err, DFSchema, Result, ScalarValue, +}; use datafusion_expr::ColumnarValue; +use datafusion_physical_expr_common::datum::compare_with_eq; use ahash::RandomState; use hashbrown::hash_map::RawEntryMut; @@ -258,6 +261,7 @@ macro_rules! is_equal { } is_equal!(i8, i16, i32, i64, i128, i256, u8, u16, u32, u64); is_equal!(bool, str, [u8]); +is_equal!(IntervalDayTime, IntervalMonthDayNano); macro_rules! is_equal_float { ($($t:ty),+) => { @@ -352,13 +356,16 @@ impl PhysicalExpr for InListExpr { Some(f) => f.contains(value.into_array(num_rows)?.as_ref(), self.negated)?, None => { let value = value.into_array(num_rows)?; + let is_nested = value.data_type().is_nested(); let found = self.list.iter().map(|expr| expr.evaluate(batch)).try_fold( BooleanArray::new(BooleanBuffer::new_unset(num_rows), None), |result, expr| -> Result { - Ok(or_kleene( - &result, - &eq(&value, &expr?.into_array(num_rows)?)?, - )?) + let rhs = compare_with_eq( + &value, + &expr?.into_array(num_rows)?, + is_nested, + )?; + Ok(or_kleene(&result, &rhs)?) }, )?; @@ -372,10 +379,10 @@ impl PhysicalExpr for InListExpr { Ok(ColumnarValue::Array(Arc::new(r))) } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { let mut children = vec![]; - children.push(self.expr.clone()); - children.extend(self.list.clone()); + children.push(&self.expr); + children.extend(&self.list); children } @@ -385,7 +392,7 @@ impl PhysicalExpr for InListExpr { ) -> Result> { // assume the static_filter will not change during the rewrite process Ok(Arc::new(InListExpr::new( - children[0].clone(), + Arc::clone(&children[0]), children[1..].to_vec(), self.negated, self.static_filter.clone(), @@ -414,18 +421,6 @@ impl PartialEq for InListExpr { } } -/// Checks if two types are logically equal, dictionary types are compared by their value types. -fn is_logically_eq(lhs: &DataType, rhs: &DataType) -> bool { - match (lhs, rhs) { - (DataType::Dictionary(_, v1), DataType::Dictionary(_, v2)) => { - v1.as_ref().eq(v2.as_ref()) - } - (DataType::Dictionary(_, l), _) => l.as_ref().eq(rhs), - (_, DataType::Dictionary(_, r)) => lhs.eq(r.as_ref()), - _ => lhs.eq(rhs), - } -} - /// Creates a unary expression InList pub fn in_list( expr: Arc, @@ -437,7 +432,7 @@ pub fn in_list( let expr_data_type = expr.data_type(schema)?; for list_expr in list.iter() { let list_expr_data_type = list_expr.data_type(schema)?; - if !is_logically_eq(&expr_data_type, &list_expr_data_type) { + if !DFSchema::datatype_is_logically_equal(&expr_data_type, &list_expr_data_type) { return internal_err!( "The data type inlist should be same, the value type is {expr_data_type}, one of list expr type is {list_expr_data_type}" ); @@ -548,7 +543,7 @@ mod tests { list, &false, vec![Some(true), Some(false), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -559,7 +554,7 @@ mod tests { list, &true, vec![Some(false), Some(true), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -570,7 +565,7 @@ mod tests { list, &false, vec![Some(true), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -581,7 +576,7 @@ mod tests { list, &true, vec![Some(false), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -606,7 +601,7 @@ mod tests { list.clone(), &false, vec![Some(true), Some(false), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -616,7 +611,7 @@ mod tests { list, &true, vec![Some(false), Some(true), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -631,7 +626,7 @@ mod tests { list.clone(), &false, vec![Some(true), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -641,7 +636,7 @@ mod tests { list, &true, vec![Some(false), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -662,7 +657,7 @@ mod tests { list, &false, vec![Some(true), Some(false), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -673,7 +668,7 @@ mod tests { list, &true, vec![Some(false), Some(true), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -684,7 +679,7 @@ mod tests { list, &false, vec![Some(true), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -695,7 +690,7 @@ mod tests { list, &true, vec![Some(false), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -722,7 +717,7 @@ mod tests { list, &false, vec![Some(true), Some(false), None, Some(false), Some(false)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -733,7 +728,7 @@ mod tests { list, &true, vec![Some(false), Some(true), None, Some(true), Some(true)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -744,7 +739,7 @@ mod tests { list, &false, vec![Some(true), None, None, None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -755,7 +750,7 @@ mod tests { list, &true, vec![Some(false), None, None, None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -766,7 +761,7 @@ mod tests { list, &false, vec![Some(true), Some(false), None, Some(true), Some(false)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -777,7 +772,7 @@ mod tests { list, &true, vec![Some(false), Some(true), None, Some(false), Some(true)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -788,7 +783,7 @@ mod tests { list, &false, vec![Some(true), Some(false), None, Some(false), Some(true)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -799,7 +794,7 @@ mod tests { list, &true, vec![Some(false), Some(true), None, Some(true), Some(false)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -820,7 +815,7 @@ mod tests { list, &false, vec![Some(true), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -831,7 +826,7 @@ mod tests { list, &true, vec![Some(false), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -842,7 +837,7 @@ mod tests { list, &false, vec![Some(true), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -853,7 +848,7 @@ mod tests { list, &true, vec![Some(false), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -877,7 +872,7 @@ mod tests { list, &false, vec![Some(true), Some(false), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -891,7 +886,7 @@ mod tests { list, &true, vec![Some(false), Some(true), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -906,7 +901,7 @@ mod tests { list, &false, vec![Some(true), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -921,7 +916,7 @@ mod tests { list, &true, vec![Some(false), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -945,7 +940,7 @@ mod tests { list, &false, vec![Some(true), Some(false), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -959,7 +954,7 @@ mod tests { list, &true, vec![Some(false), Some(true), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -974,7 +969,7 @@ mod tests { list, &false, vec![Some(true), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -989,7 +984,7 @@ mod tests { list, &true, vec![Some(false), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1016,7 +1011,7 @@ mod tests { list, &false, vec![Some(true), None, Some(false)], - col_a.clone(), + Arc::clone(&col_a), &schema ); // expression: "a not in (100,200) @@ -1026,7 +1021,7 @@ mod tests { list, &true, vec![Some(false), None, Some(true)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1037,7 +1032,7 @@ mod tests { list.clone(), &false, vec![Some(true), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); // expression: "a not in (200,NULL), the data type of list is INT32 AND NULL @@ -1046,7 +1041,7 @@ mod tests { list, &true, vec![Some(false), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1057,7 +1052,7 @@ mod tests { list, &false, vec![Some(true), None, Some(true)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1068,7 +1063,7 @@ mod tests { list, &true, vec![Some(true), None, Some(false)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1081,7 +1076,7 @@ mod tests { list.clone(), &false, vec![Some(true), None, Some(false)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1090,7 +1085,7 @@ mod tests { list, &true, vec![Some(false), None, Some(true)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1107,7 +1102,7 @@ mod tests { let mut phy_exprs = vec![ lit(1i64), expressions::cast(lit(2i32), &schema, DataType::Int64)?, - expressions::try_cast(lit(3.13f32), &schema, DataType::Int64)?, + try_cast(lit(3.13f32), &schema, DataType::Int64)?, ]; let result = try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap(); @@ -1135,7 +1130,7 @@ mod tests { try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap(); // column - phy_exprs.push(expressions::col("a", &schema)?); + phy_exprs.push(col("a", &schema)?); assert!(try_cast_static_filter_to_set(&phy_exprs, &schema).is_err()); Ok(()) @@ -1176,7 +1171,7 @@ mod tests { list.clone(), &false, vec![Some(true), Some(false), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1185,7 +1180,7 @@ mod tests { list.clone(), &true, vec![Some(false), Some(true), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); Ok(()) @@ -1227,13 +1222,13 @@ mod tests { vec![Arc::new(a), Arc::new(b), Arc::new(c)], )?; - let list = vec![col_b.clone(), col_c.clone()]; + let list = vec![Arc::clone(&col_b), Arc::clone(&col_c)]; in_list!( batch, list.clone(), &false, vec![Some(false), Some(true), None, Some(true), Some(true)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1242,7 +1237,7 @@ mod tests { list, &true, vec![Some(true), Some(false), None, Some(false), Some(false)], - col_a.clone(), + Arc::clone(&col_a), &schema ); @@ -1270,22 +1265,22 @@ mod tests { // static_filter has no nulls let list = vec![lit(1_i64), lit(2_i64)]; - test_nullable!(c1_nullable.clone(), list.clone(), &schema, true); - test_nullable!(c2_non_nullable.clone(), list.clone(), &schema, false); + test_nullable!(Arc::clone(&c1_nullable), list.clone(), &schema, true); + test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, false); // static_filter has nulls let list = vec![lit(1_i64), lit(2_i64), lit(ScalarValue::Null)]; - test_nullable!(c1_nullable.clone(), list.clone(), &schema, true); - test_nullable!(c2_non_nullable.clone(), list.clone(), &schema, true); + test_nullable!(Arc::clone(&c1_nullable), list.clone(), &schema, true); + test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, true); - let list = vec![c1_nullable.clone()]; - test_nullable!(c2_non_nullable.clone(), list.clone(), &schema, true); + let list = vec![Arc::clone(&c1_nullable)]; + test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, true); - let list = vec![c2_non_nullable.clone()]; - test_nullable!(c1_nullable.clone(), list.clone(), &schema, true); + let list = vec![Arc::clone(&c2_non_nullable)]; + test_nullable!(Arc::clone(&c1_nullable), list.clone(), &schema, true); - let list = vec![c2_non_nullable.clone(), c2_non_nullable.clone()]; - test_nullable!(c2_non_nullable.clone(), list.clone(), &schema, false); + let list = vec![Arc::clone(&c2_non_nullable), Arc::clone(&c2_non_nullable)]; + test_nullable!(Arc::clone(&c2_non_nullable), list.clone(), &schema, false); Ok(()) } @@ -1378,7 +1373,7 @@ mod tests { list.clone(), &false, vec![Some(true), Some(false), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); } @@ -1390,7 +1385,7 @@ mod tests { list.clone(), &true, vec![Some(false), Some(true), None], - col_a.clone(), + Arc::clone(&col_a), &schema ); } @@ -1410,7 +1405,7 @@ mod tests { list.clone(), &false, vec![Some(true), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); } @@ -1422,7 +1417,7 @@ mod tests { list.clone(), &true, vec![Some(false), None, None], - col_a.clone(), + Arc::clone(&col_a), &schema ); } diff --git a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index c5c673ec28ea..cbab7d0c9d1f 100644 --- a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -22,7 +22,6 @@ use std::{any::Any, sync::Arc}; use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; -use arrow::compute; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, @@ -73,24 +72,25 @@ impl PhysicalExpr for IsNotNullExpr { fn evaluate(&self, batch: &RecordBatch) -> Result { let arg = self.arg.evaluate(batch)?; match arg { - ColumnarValue::Array(array) => Ok(ColumnarValue::Array(Arc::new( - compute::is_not_null(array.as_ref())?, - ))), + ColumnarValue::Array(array) => { + let is_not_null = arrow::compute::is_not_null(&array)?; + Ok(ColumnarValue::Array(Arc::new(is_not_null))) + } ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar( ScalarValue::Boolean(Some(!scalar.is_null())), )), } } - fn children(&self) -> Vec> { - vec![self.arg.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.arg] } fn with_new_children( self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(IsNotNullExpr::new(children[0].clone()))) + Ok(Arc::new(IsNotNullExpr::new(Arc::clone(&children[0])))) } fn dyn_hash(&self, state: &mut dyn Hasher) { @@ -120,6 +120,8 @@ mod tests { array::{BooleanArray, StringArray}, datatypes::*, }; + use arrow_array::{Array, Float64Array, Int32Array, UnionArray}; + use arrow_buffer::ScalarBuffer; use datafusion_common::cast::as_boolean_array; #[test] @@ -143,4 +145,48 @@ mod tests { Ok(()) } + + #[test] + fn union_is_not_null_op() { + // union of [{A=1}, {A=}, {B=1.1}, {B=1.2}, {B=}] + let int_array = Int32Array::from(vec![Some(1), None, None, None, None]); + let float_array = + Float64Array::from(vec![None, None, Some(1.1), Some(1.2), None]); + let type_ids = [0, 0, 1, 1, 1].into_iter().collect::>(); + + let children = vec![Arc::new(int_array) as Arc, Arc::new(float_array)]; + + let union_fields: UnionFields = [ + (0, Arc::new(Field::new("A", DataType::Int32, true))), + (1, Arc::new(Field::new("B", DataType::Float64, true))), + ] + .into_iter() + .collect(); + + let array = + UnionArray::try_new(union_fields.clone(), type_ids, None, children).unwrap(); + + let field = Field::new( + "my_union", + DataType::Union(union_fields, UnionMode::Sparse), + true, + ); + + let schema = Schema::new(vec![field]); + let expr = is_not_null(col("my_union", &schema).unwrap()).unwrap(); + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap(); + + // expression: "a is not null" + let actual = expr + .evaluate(&batch) + .unwrap() + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let actual = as_boolean_array(&actual).unwrap(); + + let expected = &BooleanArray::from(vec![true, false, true, true, false]); + + assert_eq!(expected, actual); + } } diff --git a/datafusion/physical-expr/src/expressions/is_null.rs b/datafusion/physical-expr/src/expressions/is_null.rs index b0f70b6f0d7a..1c8597d3fdea 100644 --- a/datafusion/physical-expr/src/expressions/is_null.rs +++ b/datafusion/physical-expr/src/expressions/is_null.rs @@ -20,7 +20,6 @@ use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; -use arrow::compute; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, @@ -75,7 +74,7 @@ impl PhysicalExpr for IsNullExpr { let arg = self.arg.evaluate(batch)?; match arg { ColumnarValue::Array(array) => Ok(ColumnarValue::Array(Arc::new( - compute::is_null(array.as_ref())?, + arrow::compute::is_null(&array)?, ))), ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar( ScalarValue::Boolean(Some(scalar.is_null())), @@ -83,15 +82,15 @@ impl PhysicalExpr for IsNullExpr { } } - fn children(&self) -> Vec> { - vec![self.arg.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.arg] } fn with_new_children( self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(IsNullExpr::new(children[0].clone()))) + Ok(Arc::new(IsNullExpr::new(Arc::clone(&children[0])))) } fn dyn_hash(&self, state: &mut dyn Hasher) { @@ -108,6 +107,7 @@ impl PartialEq for IsNullExpr { .unwrap_or(false) } } + /// Create an IS NULL expression pub fn is_null(arg: Arc) -> Result> { Ok(Arc::new(IsNullExpr::new(arg))) @@ -121,6 +121,8 @@ mod tests { array::{BooleanArray, StringArray}, datatypes::*, }; + use arrow_array::{Array, Float64Array, Int32Array, UnionArray}; + use arrow_buffer::ScalarBuffer; use datafusion_common::cast::as_boolean_array; #[test] @@ -145,4 +147,70 @@ mod tests { Ok(()) } + + fn union_fields() -> UnionFields { + [ + (0, Arc::new(Field::new("A", DataType::Int32, true))), + (1, Arc::new(Field::new("B", DataType::Float64, true))), + (2, Arc::new(Field::new("C", DataType::Utf8, true))), + ] + .into_iter() + .collect() + } + + #[test] + fn sparse_union_is_null() { + // union of [{A=1}, {A=}, {B=1.1}, {B=1.2}, {B=}, {C=}, {C="a"}] + let int_array = + Int32Array::from(vec![Some(1), None, None, None, None, None, None]); + let float_array = + Float64Array::from(vec![None, None, Some(1.1), Some(1.2), None, None, None]); + let str_array = + StringArray::from(vec![None, None, None, None, None, None, Some("a")]); + let type_ids = [0, 0, 1, 1, 1, 2, 2] + .into_iter() + .collect::>(); + + let children = vec![ + Arc::new(int_array) as Arc, + Arc::new(float_array), + Arc::new(str_array), + ]; + + let array = + UnionArray::try_new(union_fields(), type_ids, None, children).unwrap(); + + let result = arrow::compute::is_null(&array).unwrap(); + + let expected = + &BooleanArray::from(vec![false, true, false, false, true, true, false]); + assert_eq!(expected, &result); + } + + #[test] + fn dense_union_is_null() { + // union of [{A=1}, {A=}, {B=3.2}, {B=}, {C="a"}, {C=}] + let int_array = Int32Array::from(vec![Some(1), None]); + let float_array = Float64Array::from(vec![Some(3.2), None]); + let str_array = StringArray::from(vec![Some("a"), None]); + let type_ids = [0, 0, 1, 1, 2, 2].into_iter().collect::>(); + let offsets = [0, 1, 0, 1, 0, 1] + .into_iter() + .collect::>(); + + let children = vec![ + Arc::new(int_array) as Arc, + Arc::new(float_array), + Arc::new(str_array), + ]; + + let array = + UnionArray::try_new(union_fields(), type_ids, Some(offsets), children) + .unwrap(); + + let result = arrow::compute::is_null(&array).unwrap(); + + let expected = &BooleanArray::from(vec![false, true, false, true, false, true]); + assert_eq!(expected, &result); + } } diff --git a/datafusion/physical-expr/src/expressions/like.rs b/datafusion/physical-expr/src/expressions/like.rs index 6e0beeb0beea..b84ba82b642d 100644 --- a/datafusion/physical-expr/src/expressions/like.rs +++ b/datafusion/physical-expr/src/expressions/like.rs @@ -20,11 +20,11 @@ use std::{any::Any, sync::Arc}; use crate::{physical_expr::down_cast_any_ref, PhysicalExpr}; -use crate::expressions::datum::apply_cmp; use arrow::record_batch::RecordBatch; use arrow_schema::{DataType, Schema}; use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; +use datafusion_physical_expr_common::datum::apply_cmp; // Like expression #[derive(Debug, Hash)] @@ -112,8 +112,8 @@ impl PhysicalExpr for LikeExpr { } } - fn children(&self) -> Vec> { - vec![self.expr.clone(), self.pattern.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.expr, &self.pattern] } fn with_new_children( @@ -123,8 +123,8 @@ impl PhysicalExpr for LikeExpr { Ok(Arc::new(LikeExpr::new( self.negated, self.case_insensitive, - children[0].clone(), - children[1].clone(), + Arc::clone(&children[0]), + Arc::clone(&children[1]), ))) } @@ -148,6 +148,14 @@ impl PartialEq for LikeExpr { } } +/// used for optimize Dictionary like +fn can_like_type(from_type: &DataType) -> bool { + match from_type { + DataType::Dictionary(_, inner_type_from) => **inner_type_from == DataType::Utf8, + _ => false, + } +} + /// Create a like expression, erroring if the argument types are not compatible. pub fn like( negated: bool, @@ -158,7 +166,7 @@ pub fn like( ) -> Result> { let expr_type = &expr.data_type(input_schema)?; let pattern_type = &pattern.data_type(input_schema)?; - if !expr_type.eq(pattern_type) { + if !expr_type.eq(pattern_type) && !can_like_type(expr_type) { return internal_err!( "The type of {expr_type} AND {pattern_type} of like physical should be same" ); diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index 35ea80ea574d..ed24e9028153 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -21,16 +21,17 @@ use std::any::Any; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::physical_expr::down_cast_any_ref; -use crate::sort_properties::SortProperties; -use crate::PhysicalExpr; +use crate::physical_expr::{down_cast_any_ref, PhysicalExpr}; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::{ColumnarValue, Expr}; +use datafusion_expr::Expr; +use datafusion_expr_common::columnar_value::ColumnarValue; +use datafusion_expr_common::interval_arithmetic::Interval; +use datafusion_expr_common::sort_properties::{ExprProperties, SortProperties}; /// Represents a literal value #[derive(Debug, PartialEq, Eq, Hash)] @@ -74,7 +75,7 @@ impl PhysicalExpr for Literal { Ok(ColumnarValue::Scalar(self.value.clone())) } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { vec![] } @@ -90,8 +91,11 @@ impl PhysicalExpr for Literal { self.hash(&mut s); } - fn get_ordering(&self, _children: &[SortProperties]) -> SortProperties { - SortProperties::Singleton + fn get_properties(&self, _children: &[ExprProperties]) -> Result { + Ok(ExprProperties { + sort_properties: SortProperties::Singleton, + range: Interval::try_new(self.value().clone(), self.value().clone())?, + }) } } @@ -115,6 +119,7 @@ pub fn lit(value: T) -> Arc { #[cfg(test)] mod tests { use super::*; + use arrow::array::Int32Array; use arrow::datatypes::*; use datafusion_common::cast::as_int32_array; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 688d5ce6eabf..7d71bd9ff17b 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -22,7 +22,6 @@ mod binary; mod case; mod cast; mod column; -mod datum; mod in_list; mod is_not_null; mod is_null; @@ -32,59 +31,18 @@ mod negative; mod no_op; mod not; mod try_cast; +mod unknown_column; /// Module with some convenient methods used in expression building -pub mod helpers { - pub use crate::aggregate::min_max::{max, min}; -} - -pub use crate::aggregate::approx_distinct::ApproxDistinct; -pub use crate::aggregate::approx_median::ApproxMedian; -pub use crate::aggregate::approx_percentile_cont::ApproxPercentileCont; -pub use crate::aggregate::approx_percentile_cont_with_weight::ApproxPercentileContWithWeight; -pub use crate::aggregate::array_agg::ArrayAgg; -pub use crate::aggregate::array_agg_distinct::DistinctArrayAgg; -pub use crate::aggregate::array_agg_ordered::OrderSensitiveArrayAgg; -pub use crate::aggregate::average::{Avg, AvgAccumulator}; -pub use crate::aggregate::bit_and_or_xor::{BitAnd, BitOr, BitXor, DistinctBitXor}; -pub use crate::aggregate::bool_and_or::{BoolAnd, BoolOr}; -pub use crate::aggregate::build_in::create_aggregate_expr; -pub use crate::aggregate::correlation::Correlation; -pub use crate::aggregate::count::Count; -pub use crate::aggregate::count_distinct::DistinctCount; -pub use crate::aggregate::covariance::{Covariance, CovariancePop}; -pub use crate::aggregate::grouping::Grouping; -pub use crate::aggregate::median::Median; -pub use crate::aggregate::min_max::{Max, Min}; -pub use crate::aggregate::min_max::{MaxAccumulator, MinAccumulator}; -pub use crate::aggregate::nth_value::NthValueAgg; -pub use crate::aggregate::regr::{Regr, RegrType}; pub use crate::aggregate::stats::StatsType; -pub use crate::aggregate::stddev::{Stddev, StddevPop}; -pub use crate::aggregate::string_agg::StringAgg; -pub use crate::aggregate::sum::Sum; -pub use crate::aggregate::sum_distinct::DistinctSum; -pub use crate::aggregate::variance::{Variance, VariancePop}; -pub use crate::window::cume_dist::cume_dist; -pub use crate::window::cume_dist::CumeDist; -pub use crate::window::lead_lag::WindowShift; -pub use crate::window::lead_lag::{lag, lead}; pub use crate::window::nth_value::NthValue; -pub use crate::window::ntile::Ntile; -pub use crate::window::rank::{dense_rank, percent_rank, rank}; -pub use crate::window::rank::{Rank, RankType}; -pub use crate::window::row_number::RowNumber; pub use crate::PhysicalSortExpr; -pub use datafusion_functions_aggregate::first_last::{ - FirstValuePhysicalExpr as FirstValue, LastValuePhysicalExpr as LastValue, -}; -pub use binary::{binary, BinaryExpr}; +pub use binary::{binary, similar_to, BinaryExpr}; pub use case::{case, CaseExpr}; -pub use cast::{cast, cast_with_options, CastExpr}; -pub use column::UnKnownColumn; +pub use cast::{cast, CastExpr}; +pub use column::{col, with_new_schema, Column}; pub use datafusion_expr::utils::format_state_name; -pub use datafusion_physical_expr_common::expressions::column::{col, Column}; pub use in_list::{in_list, InListExpr}; pub use is_not_null::{is_not_null, IsNotNullExpr}; pub use is_null::{is_null, IsNullExpr}; @@ -94,186 +52,4 @@ pub use negative::{negative, NegativeExpr}; pub use no_op::NoOp; pub use not::{not, NotExpr}; pub use try_cast::{try_cast, TryCastExpr}; - -#[cfg(test)] -pub(crate) mod tests { - use std::sync::Arc; - - use crate::expressions::{col, create_aggregate_expr, try_cast}; - use crate::AggregateExpr; - use arrow::record_batch::RecordBatch; - use arrow_array::ArrayRef; - use arrow_schema::{Field, Schema}; - use datafusion_common::{Result, ScalarValue}; - use datafusion_expr::type_coercion::aggregates::coerce_types; - use datafusion_expr::{AggregateFunction, EmitTo}; - - /// macro to perform an aggregation using [`datafusion_expr::Accumulator`] and verify the - /// result. - #[macro_export] - macro_rules! generic_test_op { - ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr) => { - generic_test_op!($ARRAY, $DATATYPE, $OP, $EXPECTED, $EXPECTED.data_type()) - }; - ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ - let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); - - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; - - let agg = Arc::new(<$OP>::new( - col("a", &schema)?, - "bla".to_string(), - $EXPECTED_DATATYPE, - )); - let actual = aggregate(&batch, agg)?; - let expected = ScalarValue::from($EXPECTED); - - assert_eq!(expected, actual); - - Ok(()) as Result<(), ::datafusion_common::DataFusionError> - }}; - } - - /// macro to perform an aggregation using [`crate::GroupsAccumulator`] and verify the result. - /// - /// The difference between this and the above `generic_test_op` is that the former checks - /// the old slow-path [`datafusion_expr::Accumulator`] implementation, while this checks - /// the new [`crate::GroupsAccumulator`] implementation. - #[macro_export] - macro_rules! generic_test_op_new { - ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr) => { - generic_test_op_new!( - $ARRAY, - $DATATYPE, - $OP, - $EXPECTED, - $EXPECTED.data_type().clone() - ) - }; - ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ - let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); - - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; - - let agg = Arc::new(<$OP>::new( - col("a", &schema)?, - "bla".to_string(), - $EXPECTED_DATATYPE, - )); - let actual = aggregate_new(&batch, agg)?; - assert_eq!($EXPECTED, &actual); - - Ok(()) as Result<(), ::datafusion_common::DataFusionError> - }}; - } - - /// Assert `function(array) == expected` performing any necessary type coercion - pub fn assert_aggregate( - array: ArrayRef, - function: AggregateFunction, - distinct: bool, - expected: ScalarValue, - ) { - let data_type = array.data_type(); - let sig = function.signature(); - let coerced = coerce_types(&function, &[data_type.clone()], &sig).unwrap(); - - let input_schema = Schema::new(vec![Field::new("a", data_type.clone(), true)]); - let batch = - RecordBatch::try_new(Arc::new(input_schema.clone()), vec![array]).unwrap(); - - let input = try_cast( - col("a", &input_schema).unwrap(), - &input_schema, - coerced[0].clone(), - ) - .unwrap(); - - let schema = Schema::new(vec![Field::new("a", coerced[0].clone(), true)]); - let agg = create_aggregate_expr( - &function, - distinct, - &[input], - &[], - &schema, - "agg", - false, - ) - .unwrap(); - - let result = aggregate(&batch, agg).unwrap(); - assert_eq!(expected, result); - } - - /// macro to perform an aggregation with two inputs and verify the result. - #[macro_export] - macro_rules! generic_test_op2 { - ($ARRAY1:expr, $ARRAY2:expr, $DATATYPE1:expr, $DATATYPE2:expr, $OP:ident, $EXPECTED:expr) => { - generic_test_op2!( - $ARRAY1, - $ARRAY2, - $DATATYPE1, - $DATATYPE2, - $OP, - $EXPECTED, - $EXPECTED.data_type() - ) - }; - ($ARRAY1:expr, $ARRAY2:expr, $DATATYPE1:expr, $DATATYPE2:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ - let schema = Schema::new(vec![ - Field::new("a", $DATATYPE1, true), - Field::new("b", $DATATYPE2, true), - ]); - let batch = - RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY1, $ARRAY2])?; - - let agg = Arc::new(<$OP>::new( - col("a", &schema)?, - col("b", &schema)?, - "bla".to_string(), - $EXPECTED_DATATYPE, - )); - let actual = aggregate(&batch, agg)?; - let expected = ScalarValue::from($EXPECTED); - - assert_eq!(expected, actual); - - Ok(()) - }}; - } - - pub fn aggregate( - batch: &RecordBatch, - agg: Arc, - ) -> Result { - let mut accum = agg.create_accumulator()?; - let expr = agg.expressions(); - let values = expr - .iter() - .map(|e| { - e.evaluate(batch) - .and_then(|v| v.into_array(batch.num_rows())) - }) - .collect::>>()?; - accum.update_batch(&values)?; - accum.evaluate() - } - - pub fn aggregate_new( - batch: &RecordBatch, - agg: Arc, - ) -> Result { - let mut accum = agg.create_groups_accumulator()?; - let expr = agg.expressions(); - let values = expr - .iter() - .map(|e| { - e.evaluate(batch) - .and_then(|v| v.into_array(batch.num_rows())) - }) - .collect::>>()?; - let indices = vec![0; batch.num_rows()]; - accum.update_batch(&values, &indices, None, 1)?; - accum.evaluate(EmitTo::All) - } -} +pub use unknown_column::UnKnownColumn; diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index f6d4620c427f..399ebde9f726 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -22,7 +22,6 @@ use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::physical_expr::down_cast_any_ref; -use crate::sort_properties::SortProperties; use crate::PhysicalExpr; use arrow::{ @@ -32,6 +31,7 @@ use arrow::{ }; use datafusion_common::{plan_err, Result}; use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::{ type_coercion::{is_interval, is_null, is_signed_numeric, is_timestamp}, ColumnarValue, @@ -89,15 +89,15 @@ impl PhysicalExpr for NegativeExpr { } } - fn children(&self) -> Vec> { - vec![self.arg.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.arg] } fn with_new_children( self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(NegativeExpr::new(children[0].clone()))) + Ok(Arc::new(NegativeExpr::new(Arc::clone(&children[0])))) } fn dyn_hash(&self, state: &mut dyn Hasher) { @@ -134,8 +134,11 @@ impl PhysicalExpr for NegativeExpr { } /// The ordering of a [`NegativeExpr`] is simply the reverse of its child. - fn get_ordering(&self, children: &[SortProperties]) -> SortProperties { - -children[0] + fn get_properties(&self, children: &[ExprProperties]) -> Result { + Ok(ExprProperties { + sort_properties: -children[0].sort_properties, + range: children[0].range.clone().arithmetic_negate()?, + }) } } @@ -254,7 +257,7 @@ mod tests { #[test] fn test_negation_valid_types() -> Result<()> { let negatable_types = [ - DataType::Int8, + Int8, DataType::Timestamp(TimeUnit::Second, None), DataType::Interval(IntervalUnit::YearMonth), ]; diff --git a/datafusion/physical-expr/src/expressions/no_op.rs b/datafusion/physical-expr/src/expressions/no_op.rs index b558ccab154d..9148cb7c1c1d 100644 --- a/datafusion/physical-expr/src/expressions/no_op.rs +++ b/datafusion/physical-expr/src/expressions/no_op.rs @@ -68,7 +68,7 @@ impl PhysicalExpr for NoOp { internal_err!("NoOp::evaluate() should not be called") } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { vec![] } diff --git a/datafusion/physical-expr/src/expressions/not.rs b/datafusion/physical-expr/src/expressions/not.rs index 1428be71cc21..6d91e9dfdd36 100644 --- a/datafusion/physical-expr/src/expressions/not.rs +++ b/datafusion/physical-expr/src/expressions/not.rs @@ -27,6 +27,7 @@ use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::{cast::as_boolean_array, Result, ScalarValue}; +use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::ColumnarValue; /// Not expression @@ -89,15 +90,19 @@ impl PhysicalExpr for NotExpr { } } - fn children(&self) -> Vec> { - vec![self.arg.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.arg] } fn with_new_children( self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(NotExpr::new(children[0].clone()))) + Ok(Arc::new(NotExpr::new(Arc::clone(&children[0])))) + } + + fn evaluate_bounds(&self, children: &[&Interval]) -> Result { + children[0].not() } fn dyn_hash(&self, state: &mut dyn Hasher) { @@ -125,10 +130,11 @@ mod tests { use super::*; use crate::expressions::col; use arrow::{array::BooleanArray, datatypes::*}; + use std::sync::OnceLock; #[test] fn neg_op() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]); + let schema = schema(); let expr = not(col("a", &schema)?)?; assert_eq!(expr.data_type(&schema)?, DataType::Boolean); @@ -137,8 +143,7 @@ mod tests { let input = BooleanArray::from(vec![Some(true), None, Some(false)]); let expected = &BooleanArray::from(vec![Some(false), None, Some(true)]); - let batch = - RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?; + let batch = RecordBatch::try_new(schema, vec![Arc::new(input)])?; let result = expr .evaluate(&batch)? @@ -150,4 +155,48 @@ mod tests { Ok(()) } + + #[test] + fn test_evaluate_bounds() -> Result<()> { + // Note that `None` for boolean intervals is converted to `Some(false)` + // / `Some(true)` by `Interval::make`, so it is not explicitly tested + // here + + // if the bounds are all booleans (false, true) so is the negation + assert_evaluate_bounds( + Interval::make(Some(false), Some(true))?, + Interval::make(Some(false), Some(true))?, + )?; + // (true, false) is not tested because it is not a valid interval (lower + // bound is greater than upper bound) + assert_evaluate_bounds( + Interval::make(Some(true), Some(true))?, + Interval::make(Some(false), Some(false))?, + )?; + assert_evaluate_bounds( + Interval::make(Some(false), Some(false))?, + Interval::make(Some(true), Some(true))?, + )?; + Ok(()) + } + + fn assert_evaluate_bounds( + interval: Interval, + expected_interval: Interval, + ) -> Result<()> { + let not_expr = not(col("a", &schema())?)?; + assert_eq!( + not_expr.evaluate_bounds(&[&interval]).unwrap(), + expected_interval + ); + Ok(()) + } + + fn schema() -> SchemaRef { + Arc::clone(SCHEMA.get_or_init(|| { + Arc::new(Schema::new(vec![Field::new("a", DataType::Boolean, true)])) + })) + } + + static SCHEMA: OnceLock = OnceLock::new(); } diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index d25a904f7d6a..43b6c993d2b2 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -31,7 +31,7 @@ use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; use datafusion_common::{not_impl_err, Result, ScalarValue}; use datafusion_expr::ColumnarValue; -/// TRY_CAST expression casts an expression to a specific data type and retuns NULL on invalid cast +/// TRY_CAST expression casts an expression to a specific data type and returns NULL on invalid cast #[derive(Debug, Hash)] pub struct TryCastExpr { /// The expression to cast @@ -97,8 +97,8 @@ impl PhysicalExpr for TryCastExpr { } } - fn children(&self) -> Vec> { - vec![self.expr.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.expr] } fn with_new_children( @@ -106,7 +106,7 @@ impl PhysicalExpr for TryCastExpr { children: Vec>, ) -> Result> { Ok(Arc::new(TryCastExpr::new( - children[0].clone(), + Arc::clone(&children[0]), self.cast_type.clone(), ))) } @@ -137,7 +137,7 @@ pub fn try_cast( ) -> Result> { let expr_type = expr.data_type(input_schema)?; if expr_type == cast_type { - Ok(expr.clone()) + Ok(Arc::clone(&expr)) } else if can_cast_types(&expr_type, &cast_type) { Ok(Arc::new(TryCastExpr::new(expr, cast_type))) } else { diff --git a/datafusion/physical-expr/src/expressions/unknown_column.rs b/datafusion/physical-expr/src/expressions/unknown_column.rs new file mode 100644 index 000000000000..590efd577963 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/unknown_column.rs @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! UnKnownColumn expression + +use std::any::Any; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +use crate::PhysicalExpr; + +use arrow::{ + datatypes::{DataType, Schema}, + record_batch::RecordBatch, +}; +use datafusion_common::{internal_err, Result}; +use datafusion_expr::ColumnarValue; + +#[derive(Debug, Hash, PartialEq, Eq, Clone)] +pub struct UnKnownColumn { + name: String, +} + +impl UnKnownColumn { + /// Create a new unknown column expression + pub fn new(name: &str) -> Self { + Self { + name: name.to_owned(), + } + } + + /// Get the column name + pub fn name(&self) -> &str { + &self.name + } +} + +impl std::fmt::Display for UnKnownColumn { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}", self.name) + } +} + +impl PhysicalExpr for UnKnownColumn { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + /// Get the data type of this expression, given the schema of the input + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(DataType::Null) + } + + /// Decide whether this expression is nullable, given the schema of the input + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(true) + } + + /// Evaluate the expression + fn evaluate(&self, _batch: &RecordBatch) -> Result { + internal_err!("UnKnownColumn::evaluate() should not be called") + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + Ok(self) + } + + fn dyn_hash(&self, state: &mut dyn Hasher) { + let mut s = state; + self.hash(&mut s); + } +} + +impl PartialEq for UnKnownColumn { + fn eq(&self, _other: &dyn Any) -> bool { + // UnknownColumn is not a valid expression, so it should not be equal to any other expression. + // See https://github.com/apache/datafusion/pull/11536 + false + } +} diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs deleted file mode 100644 index ac5b87e701af..000000000000 --- a/datafusion/physical-expr/src/functions.rs +++ /dev/null @@ -1,473 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Declaration of built-in (scalar) functions. -//! This module contains built-in functions' enumeration and metadata. -//! -//! Generally, a function has: -//! * a signature -//! * a return type, that is a function of the incoming argument's types -//! * the computation, that must accept each valid signature -//! -//! * Signature: see `Signature` -//! * Return type: a function `(arg_types) -> return_type`. E.g. for sqrt, ([f32]) -> f32, ([f64]) -> f64. -//! -//! This module also supports coercion to improve user experience: if -//! an argument i32 is passed to a function that supports f64, the -//! argument is automatically is coerced to f64. - -use std::ops::Neg; -use std::sync::Arc; - -use arrow::{array::ArrayRef, datatypes::Schema}; -use arrow_array::Array; - -use datafusion_common::{DFSchema, Result, ScalarValue}; -pub use datafusion_expr::FuncMonotonicity; -use datafusion_expr::{ - type_coercion::functions::data_types, ColumnarValue, ScalarFunctionImplementation, -}; -use datafusion_expr::{Expr, ScalarFunctionDefinition, ScalarUDF}; - -use crate::sort_properties::SortProperties; -use crate::{PhysicalExpr, ScalarFunctionExpr}; - -/// Create a physical (function) expression. -/// This function errors when `args`' can't be coerced to a valid argument type of the function. -pub fn create_physical_expr( - fun: &ScalarUDF, - input_phy_exprs: &[Arc], - input_schema: &Schema, - args: &[Expr], - input_dfschema: &DFSchema, -) -> Result> { - let input_expr_types = input_phy_exprs - .iter() - .map(|e| e.data_type(input_schema)) - .collect::>>()?; - - // verify that input data types is consistent with function's `TypeSignature` - data_types(&input_expr_types, fun.signature())?; - - // Since we have arg_types, we don't need args and schema. - let return_type = - fun.return_type_from_exprs(args, input_dfschema, &input_expr_types)?; - - let fun_def = ScalarFunctionDefinition::UDF(Arc::new(fun.clone())); - Ok(Arc::new(ScalarFunctionExpr::new( - fun.name(), - fun_def, - input_phy_exprs.to_vec(), - return_type, - fun.monotonicity()?, - fun.signature().type_signature.supports_zero_argument(), - ))) -} - -#[derive(Debug, Clone, Copy)] -pub enum Hint { - /// Indicates the argument needs to be padded if it is scalar - Pad, - /// Indicates the argument can be converted to an array of length 1 - AcceptsSingular, -} - -#[deprecated(since = "36.0.0", note = "Use ColumarValue::values_to_arrays instead")] -pub fn columnar_values_to_array(args: &[ColumnarValue]) -> Result> { - ColumnarValue::values_to_arrays(args) -} - -/// Decorates a function to handle [`ScalarValue`]s by converting them to arrays before calling the function -/// and vice-versa after evaluation. -/// Note that this function makes a scalar function with no arguments or all scalar inputs return a scalar. -/// That's said its output will be same for all input rows in a batch. -#[deprecated( - since = "36.0.0", - note = "Implement your function directly in terms of ColumnarValue or use `ScalarUDF` instead" -)] -pub fn make_scalar_function(inner: F) -> ScalarFunctionImplementation -where - F: Fn(&[ArrayRef]) -> Result + Sync + Send + 'static, -{ - make_scalar_function_inner(inner) -} - -/// Internal implementation, see comments on `make_scalar_function` for caveats -pub(crate) fn make_scalar_function_inner(inner: F) -> ScalarFunctionImplementation -where - F: Fn(&[ArrayRef]) -> Result + Sync + Send + 'static, -{ - make_scalar_function_with_hints(inner, vec![]) -} - -/// Just like [`make_scalar_function`], decorates the given function to handle both [`ScalarValue`]s and arrays. -/// Additionally can receive a `hints` vector which can be used to control the output arrays when generating them -/// from [`ScalarValue`]s. -/// -/// Each element of the `hints` vector gets mapped to the corresponding argument of the function. The number of hints -/// can be less or greater than the number of arguments (for functions with variable number of arguments). Each unmapped -/// argument will assume the default hint (for padding, it is [`Hint::Pad`]). -pub(crate) fn make_scalar_function_with_hints( - inner: F, - hints: Vec, -) -> ScalarFunctionImplementation -where - F: Fn(&[ArrayRef]) -> Result + Sync + Send + 'static, -{ - Arc::new(move |args: &[ColumnarValue]| { - // first, identify if any of the arguments is an Array. If yes, store its `len`, - // as any scalar will need to be converted to an array of len `len`. - let len = args - .iter() - .fold(Option::::None, |acc, arg| match arg { - ColumnarValue::Scalar(_) => acc, - ColumnarValue::Array(a) => Some(a.len()), - }); - - let is_scalar = len.is_none(); - - let inferred_length = len.unwrap_or(1); - let args = args - .iter() - .zip(hints.iter().chain(std::iter::repeat(&Hint::Pad))) - .map(|(arg, hint)| { - // Decide on the length to expand this scalar to depending - // on the given hints. - let expansion_len = match hint { - Hint::AcceptsSingular => 1, - Hint::Pad => inferred_length, - }; - arg.clone().into_array(expansion_len) - }) - .collect::>>()?; - - let result = (inner)(&args); - if is_scalar { - // If all inputs are scalar, keeps output as scalar - let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); - result.map(ColumnarValue::Scalar) - } else { - result.map(ColumnarValue::Array) - } - }) -} - -/// Determines a [`ScalarFunctionExpr`]'s monotonicity for the given arguments -/// and the function's behavior depending on its arguments. -pub fn out_ordering( - func: &FuncMonotonicity, - arg_orderings: &[SortProperties], -) -> SortProperties { - func.iter().zip(arg_orderings).fold( - SortProperties::Singleton, - |prev_sort, (item, arg)| { - let current_sort = func_order_in_one_dimension(item, arg); - - match (prev_sort, current_sort) { - (_, SortProperties::Unordered) => SortProperties::Unordered, - (SortProperties::Singleton, SortProperties::Ordered(_)) => current_sort, - (SortProperties::Ordered(prev), SortProperties::Ordered(current)) - if prev.descending != current.descending => - { - SortProperties::Unordered - } - _ => prev_sort, - } - }, - ) -} - -/// This function decides the monotonicity property of a [`ScalarFunctionExpr`] for a single argument (i.e. across a single dimension), given that argument's sort properties. -fn func_order_in_one_dimension( - func_monotonicity: &Option, - arg: &SortProperties, -) -> SortProperties { - if *arg == SortProperties::Singleton { - SortProperties::Singleton - } else { - match func_monotonicity { - None => SortProperties::Unordered, - Some(false) => { - if let SortProperties::Ordered(_) = arg { - arg.neg() - } else { - SortProperties::Unordered - } - } - Some(true) => { - if let SortProperties::Ordered(_) = arg { - *arg - } else { - SortProperties::Unordered - } - } - } - } -} - -#[cfg(test)] -mod tests { - use arrow::{ - array::UInt64Array, - datatypes::{DataType, Field}, - }; - use arrow_schema::DataType::Utf8; - - use datafusion_common::cast::as_uint64_array; - use datafusion_common::DataFusionError; - use datafusion_common::{internal_err, plan_err}; - use datafusion_expr::{Signature, Volatility}; - - use crate::expressions::try_cast; - use crate::utils::tests::TestScalarUDF; - - use super::*; - - #[test] - fn test_empty_arguments_error() -> Result<()> { - let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let udf = ScalarUDF::new_from_impl(TestScalarUDF { - signature: Signature::variadic(vec![Utf8], Volatility::Immutable), - }); - let expr = create_physical_expr_with_type_coercion( - &udf, - &[], - &schema, - &[], - &DFSchema::empty(), - ); - - match expr { - Ok(..) => { - return plan_err!( - "ScalarUDF function {udf:?} does not support empty arguments" - ); - } - Err(DataFusionError::Plan(_)) => { - // Continue the loop - } - Err(..) => { - return internal_err!( - "ScalarUDF function {udf:?} didn't got the right error with empty arguments"); - } - } - - Ok(()) - } - - // Helper function just for testing. - // Returns `expressions` coerced to types compatible with - // `signature`, if possible. - pub fn coerce( - expressions: &[Arc], - schema: &Schema, - signature: &Signature, - ) -> Result>> { - if expressions.is_empty() { - return Ok(vec![]); - } - - let current_types = expressions - .iter() - .map(|e| e.data_type(schema)) - .collect::>>()?; - - let new_types = data_types(¤t_types, signature)?; - - expressions - .iter() - .enumerate() - .map(|(i, expr)| try_cast(expr.clone(), schema, new_types[i].clone())) - .collect::>>() - } - - // Helper function just for testing. - // The type coercion will be done in the logical phase, should do the type coercion for the test - fn create_physical_expr_with_type_coercion( - fun: &ScalarUDF, - input_phy_exprs: &[Arc], - input_schema: &Schema, - args: &[Expr], - input_dfschema: &DFSchema, - ) -> Result> { - let type_coerced_phy_exprs = - coerce(input_phy_exprs, input_schema, fun.signature()).unwrap(); - create_physical_expr( - fun, - &type_coerced_phy_exprs, - input_schema, - args, - input_dfschema, - ) - } - - fn dummy_function(args: &[ArrayRef]) -> Result { - let result: UInt64Array = - args.iter().map(|array| Some(array.len() as u64)).collect(); - Ok(Arc::new(result) as ArrayRef) - } - - fn unpack_uint64_array(col: Result) -> Result> { - if let ColumnarValue::Array(array) = col? { - Ok(as_uint64_array(&array)?.values().to_vec()) - } else { - internal_err!("Unexpected scalar created by a test function") - } - } - - #[test] - fn test_make_scalar_function() -> Result<()> { - let adapter_func = make_scalar_function_inner(dummy_function); - - let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = ColumnarValue::Array( - ScalarValue::Int64(Some(1)) - .to_array_of_size(5) - .expect("Failed to convert to array of size"), - ); - let result = unpack_uint64_array(adapter_func(&[array_arg, scalar_arg]))?; - assert_eq!(result, vec![5, 5]); - - Ok(()) - } - - #[test] - fn test_make_scalar_function_with_no_hints() -> Result<()> { - let adapter_func = make_scalar_function_with_hints(dummy_function, vec![]); - - let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = ColumnarValue::Array( - ScalarValue::Int64(Some(1)) - .to_array_of_size(5) - .expect("Failed to convert to array of size"), - ); - let result = unpack_uint64_array(adapter_func(&[array_arg, scalar_arg]))?; - assert_eq!(result, vec![5, 5]); - - Ok(()) - } - - #[test] - fn test_make_scalar_function_with_hints() -> Result<()> { - let adapter_func = make_scalar_function_with_hints( - dummy_function, - vec![Hint::Pad, Hint::AcceptsSingular], - ); - - let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = ColumnarValue::Array( - ScalarValue::Int64(Some(1)) - .to_array_of_size(5) - .expect("Failed to convert to array of size"), - ); - let result = unpack_uint64_array(adapter_func(&[array_arg, scalar_arg]))?; - assert_eq!(result, vec![5, 1]); - - Ok(()) - } - - #[test] - fn test_make_scalar_function_with_hints_on_arrays() -> Result<()> { - let array_arg = ColumnarValue::Array( - ScalarValue::Int64(Some(1)) - .to_array_of_size(5) - .expect("Failed to convert to array of size"), - ); - let adapter_func = make_scalar_function_with_hints( - dummy_function, - vec![Hint::Pad, Hint::AcceptsSingular], - ); - - let result = unpack_uint64_array(adapter_func(&[array_arg.clone(), array_arg]))?; - assert_eq!(result, vec![5, 5]); - - Ok(()) - } - - #[test] - fn test_make_scalar_function_with_mixed_hints() -> Result<()> { - let adapter_func = make_scalar_function_with_hints( - dummy_function, - vec![Hint::Pad, Hint::AcceptsSingular, Hint::Pad], - ); - - let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = ColumnarValue::Array( - ScalarValue::Int64(Some(1)) - .to_array_of_size(5) - .expect("Failed to convert to array of size"), - ); - let result = unpack_uint64_array(adapter_func(&[ - array_arg, - scalar_arg.clone(), - scalar_arg, - ]))?; - assert_eq!(result, vec![5, 1, 5]); - - Ok(()) - } - - #[test] - fn test_make_scalar_function_with_more_arguments_than_hints() -> Result<()> { - let adapter_func = make_scalar_function_with_hints( - dummy_function, - vec![Hint::Pad, Hint::AcceptsSingular, Hint::Pad], - ); - - let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = ColumnarValue::Array( - ScalarValue::Int64(Some(1)) - .to_array_of_size(5) - .expect("Failed to convert to array of size"), - ); - let result = unpack_uint64_array(adapter_func(&[ - array_arg.clone(), - scalar_arg.clone(), - scalar_arg, - array_arg, - ]))?; - assert_eq!(result, vec![5, 1, 5, 5]); - - Ok(()) - } - - #[test] - fn test_make_scalar_function_with_hints_than_arguments() -> Result<()> { - let adapter_func = make_scalar_function_with_hints( - dummy_function, - vec![ - Hint::Pad, - Hint::AcceptsSingular, - Hint::Pad, - Hint::Pad, - Hint::AcceptsSingular, - Hint::Pad, - ], - ); - - let scalar_arg = ColumnarValue::Scalar(ScalarValue::Int64(Some(1))); - let array_arg = ColumnarValue::Array( - ScalarValue::Int64(Some(1)) - .to_array_of_size(5) - .expect("Failed to convert to array of size"), - ); - let result = unpack_uint64_array(adapter_func(&[array_arg, scalar_arg]))?; - assert_eq!(result, vec![5, 1]); - - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs b/datafusion/physical-expr/src/intervals/cp_solver.rs index 0c25e26d17aa..8084a52c78d8 100644 --- a/datafusion/physical-expr/src/intervals/cp_solver.rs +++ b/datafusion/physical-expr/src/intervals/cp_solver.rs @@ -19,6 +19,7 @@ use std::collections::HashSet; use std::fmt::{Display, Formatter}; +use std::mem::{size_of, size_of_val}; use std::sync::Arc; use super::utils::{ @@ -128,12 +129,11 @@ impl ExprIntervalGraph { /// Estimate size of bytes including `Self`. pub fn size(&self) -> usize { let node_memory_usage = self.graph.node_count() - * (std::mem::size_of::() - + std::mem::size_of::()); - let edge_memory_usage = self.graph.edge_count() - * (std::mem::size_of::() + std::mem::size_of::() * 2); + * (size_of::() + size_of::()); + let edge_memory_usage = + self.graph.edge_count() * (size_of::() + size_of::() * 2); - std::mem::size_of_val(self) + node_memory_usage + edge_memory_usage + size_of_val(self) + node_memory_usage + edge_memory_usage } } @@ -176,11 +176,11 @@ impl ExprIntervalGraphNode { &self.interval } - /// This function creates a DAEG node from Datafusion's [`ExprTreeNode`] + /// This function creates a DAEG node from DataFusion's [`ExprTreeNode`] /// object. Literals are created with definite, singleton intervals while /// any other expression starts with an indefinite interval ([-∞, ∞]). pub fn make_node(node: &ExprTreeNode, schema: &Schema) -> Result { - let expr = node.expr.clone(); + let expr = Arc::clone(&node.expr); if let Some(literal) = expr.as_any().downcast_ref::() { let value = literal.value(); Interval::try_new(value.clone(), value.clone()) @@ -422,7 +422,7 @@ impl ExprIntervalGraph { let mut removals = vec![]; let mut expr_node_indices = exprs .iter() - .map(|e| (e.clone(), usize::MAX)) + .map(|e| (Arc::clone(e), usize::MAX)) .collect::>(); while let Some(node) = bfs.next(graph) { // Get the plan corresponding to this node: @@ -723,6 +723,7 @@ mod tests { use crate::intervals::test_utils::gen_conjunctive_numerical_expr; use arrow::datatypes::TimeUnit; + use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use arrow_schema::Field; use datafusion_common::ScalarValue; @@ -743,16 +744,17 @@ mod tests { schema: &Schema, ) -> Result<()> { let col_stats = vec![ - (exprs_with_interval.0.clone(), left_interval), - (exprs_with_interval.1.clone(), right_interval), + (Arc::clone(&exprs_with_interval.0), left_interval), + (Arc::clone(&exprs_with_interval.1), right_interval), ]; let expected = vec![ - (exprs_with_interval.0.clone(), left_expected), - (exprs_with_interval.1.clone(), right_expected), + (Arc::clone(&exprs_with_interval.0), left_expected), + (Arc::clone(&exprs_with_interval.1), right_expected), ]; let mut graph = ExprIntervalGraph::try_new(expr, schema)?; - let expr_indexes = graph - .gather_node_indices(&col_stats.iter().map(|(e, _)| e.clone()).collect_vec()); + let expr_indexes = graph.gather_node_indices( + &col_stats.iter().map(|(e, _)| Arc::clone(e)).collect_vec(), + ); let mut col_stat_nodes = col_stats .iter() @@ -869,14 +871,21 @@ mod tests { // left_watermark > right_watermark + 5 let left_and_1 = Arc::new(BinaryExpr::new( - left_col.clone(), + Arc::clone(&left_col) as Arc, Operator::Plus, Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), )); - let expr = Arc::new(BinaryExpr::new(left_and_1, Operator::Gt, right_col.clone())); + let expr = Arc::new(BinaryExpr::new( + left_and_1, + Operator::Gt, + Arc::clone(&right_col) as Arc, + )); experiment( expr, - (left_col.clone(), right_col.clone()), + ( + Arc::clone(&left_col) as Arc, + Arc::clone(&right_col) as Arc, + ), Interval::make(Some(10_i32), Some(20_i32))?, Interval::make(Some(100), None)?, Interval::make(Some(10), Some(20))?, @@ -1390,9 +1399,17 @@ mod tests { )?; let right_child = Interval::try_new( // 1 day 321 ns - ScalarValue::IntervalMonthDayNano(Some(0x1_0000_0000_0000_0141)), + ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano { + months: 0, + days: 1, + nanoseconds: 321, + })), // 1 day 321 ns - ScalarValue::IntervalMonthDayNano(Some(0x1_0000_0000_0000_0141)), + ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano { + months: 0, + days: 1, + nanoseconds: 321, + })), )?; let children = vec![&left_child, &right_child]; let result = expression @@ -1415,9 +1432,17 @@ mod tests { )?, Interval::try_new( // 1 day 321 ns in Duration type - ScalarValue::IntervalMonthDayNano(Some(0x1_0000_0000_0000_0141)), + ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano { + months: 0, + days: 1, + nanoseconds: 321, + })), // 1 day 321 ns in Duration type - ScalarValue::IntervalMonthDayNano(Some(0x1_0000_0000_0000_0141)), + ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano { + months: 0, + days: 1, + nanoseconds: 321, + })), )? ], result @@ -1446,10 +1471,16 @@ mod tests { ScalarValue::TimestampMillisecond(Some(1_603_188_672_000), None), )?; let left_child = Interval::try_new( - // 2 days - ScalarValue::IntervalDayTime(Some(172_800_000)), - // 10 days - ScalarValue::IntervalDayTime(Some(864_000_000)), + // 2 days in millisecond + ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 172_800_000, + })), + // 10 days in millisecond + ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 864_000_000, + })), )?; let children = vec![&left_child, &right_child]; let result = expression @@ -1459,10 +1490,16 @@ mod tests { assert_eq!( vec![ Interval::try_new( - // 2 days - ScalarValue::IntervalDayTime(Some(172_800_000)), + // 2 days in millisecond + ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 172_800_000, + })), // 6 days - ScalarValue::IntervalDayTime(Some(518_400_000)), + ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 518_400_000, + })), )?, Interval::try_new( // 10.10.2020 - 10:11:12 AM diff --git a/datafusion/physical-expr/src/intervals/test_utils.rs b/datafusion/physical-expr/src/intervals/test_utils.rs index 075b8240353d..cedf55bccbf2 100644 --- a/datafusion/physical-expr/src/intervals/test_utils.rs +++ b/datafusion/physical-expr/src/intervals/test_utils.rs @@ -41,12 +41,12 @@ pub fn gen_conjunctive_numerical_expr( ) -> Arc { let (op_1, op_2, op_3, op_4) = op; let left_and_1 = Arc::new(BinaryExpr::new( - left_col.clone(), + Arc::clone(&left_col), op_1, Arc::new(Literal::new(a)), )); let left_and_2 = Arc::new(BinaryExpr::new( - right_col.clone(), + Arc::clone(&right_col), op_2, Arc::new(Literal::new(b)), )); @@ -78,8 +78,18 @@ pub fn gen_conjunctive_temporal_expr( d: ScalarValue, schema: &Schema, ) -> Result, DataFusionError> { - let left_and_1 = binary(left_col.clone(), op_1, Arc::new(Literal::new(a)), schema)?; - let left_and_2 = binary(right_col.clone(), op_2, Arc::new(Literal::new(b)), schema)?; + let left_and_1 = binary( + Arc::clone(&left_col), + op_1, + Arc::new(Literal::new(a)), + schema, + )?; + let left_and_2 = binary( + Arc::clone(&right_col), + op_2, + Arc::new(Literal::new(b)), + schema, + )?; let right_and_1 = binary(left_col, op_3, Arc::new(Literal::new(c)), schema)?; let right_and_2 = binary(right_col, op_4, Arc::new(Literal::new(d)), schema)?; let left_expr = Arc::new(BinaryExpr::new(left_and_1, Operator::Gt, left_and_2)); diff --git a/datafusion/physical-expr/src/intervals/utils.rs b/datafusion/physical-expr/src/intervals/utils.rs index e188b2d56bae..b426a656fba9 100644 --- a/datafusion/physical-expr/src/intervals/utils.rs +++ b/datafusion/physical-expr/src/intervals/utils.rs @@ -24,15 +24,12 @@ use crate::{ PhysicalExpr, }; +use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use arrow_schema::{DataType, SchemaRef}; -use datafusion_common::{internal_datafusion_err, internal_err, Result, ScalarValue}; +use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::Operator; -const MDN_DAY_MASK: i128 = 0xFFFF_FFFF_0000_0000_0000_0000; -const MDN_NS_MASK: i128 = 0xFFFF_FFFF_FFFF_FFFF; -const DT_MS_MASK: i64 = 0xFFFF_FFFF; - /// Indicates whether interval arithmetic is supported for the given expression. /// Currently, we do not support all [`PhysicalExpr`]s for interval calculations. /// We do not support every type of [`Operator`]s either. Over time, this check @@ -172,15 +169,9 @@ fn convert_duration_bound_to_interval( /// If both the month and day fields of [`ScalarValue::IntervalMonthDayNano`] are zero, this function returns the nanoseconds part. /// Otherwise, it returns an error. -fn interval_mdn_to_duration_ns(mdn: &i128) -> Result { - let months = mdn >> 96; - let days = (mdn & MDN_DAY_MASK) >> 64; - let nanoseconds = mdn & MDN_NS_MASK; - - if months == 0 && days == 0 { - nanoseconds - .try_into() - .map_err(|_| internal_datafusion_err!("Resulting duration exceeds i64::MAX")) +fn interval_mdn_to_duration_ns(mdn: &IntervalMonthDayNano) -> Result { + if mdn.months == 0 && mdn.days == 0 { + Ok(mdn.nanoseconds) } else { internal_err!( "The interval cannot have a non-zero month or day value for duration convertibility" @@ -190,12 +181,10 @@ fn interval_mdn_to_duration_ns(mdn: &i128) -> Result { /// If the day field of the [`ScalarValue::IntervalDayTime`] is zero, this function returns the milliseconds part. /// Otherwise, it returns an error. -fn interval_dt_to_duration_ms(dt: &i64) -> Result { - let days = dt >> 32; - let milliseconds = dt & DT_MS_MASK; - - if days == 0 { - Ok(milliseconds) +fn interval_dt_to_duration_ms(dt: &IntervalDayTime) -> Result { + if dt.days == 0 { + // Safe to cast i32 to i64 + Ok(dt.milliseconds as i64) } else { internal_err!( "The interval cannot have a non-zero day value for duration convertibility" diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index e0f19ad133e5..e7c2b4119c5a 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -15,19 +15,25 @@ // specific language governing permissions and limitations // under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] + +// Backward compatibility pub mod aggregate; pub mod analysis; -pub mod binary_map; +pub mod binary_map { + pub use datafusion_physical_expr_common::binary_map::{ArrowBytesSet, OutputType}; +} pub mod equivalence; pub mod expressions; -pub mod functions; pub mod intervals; -pub mod math_expressions; mod partitioning; mod physical_expr; pub mod planner; mod scalar_function; -pub mod udf; +pub mod udf { + pub use crate::scalar_function::create_physical_expr; +} pub mod utils; pub mod window; @@ -39,8 +45,7 @@ pub mod execution_props { pub use aggregate::groups_accumulator::{GroupsAccumulatorAdapter, NullState}; pub use analysis::{analyze, AnalysisContext, ExprBoundaries}; -pub use datafusion_physical_expr_common::aggregate::AggregateExpr; -pub use equivalence::EquivalenceProperties; +pub use equivalence::{calculate_union, ConstExpr, EquivalenceProperties}; pub use partitioning::{Distribution, Partitioning}; pub use physical_expr::{ physical_exprs_bag_equal, physical_exprs_contains, physical_exprs_equal, @@ -59,13 +64,6 @@ pub use scalar_function::ScalarFunctionExpr; pub use datafusion_physical_expr_common::utils::reverse_order_bys; pub use utils::split_conjunction; -// For backwards compatibility -pub mod sort_properties { - pub use datafusion_physical_expr_common::sort_properties::{ - ExprOrdering, SortProperties, - }; -} - // For backwards compatibility pub mod tree_node { pub use datafusion_physical_expr_common::tree_node::ExprContext; diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs deleted file mode 100644 index 503565b1e261..000000000000 --- a/datafusion/physical-expr/src/math_expressions.rs +++ /dev/null @@ -1,126 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Math expressions - -use std::any::type_name; -use std::sync::Arc; - -use arrow::array::ArrayRef; -use arrow::array::{BooleanArray, Float32Array, Float64Array}; -use arrow::datatypes::DataType; -use arrow_array::Array; - -use datafusion_common::exec_err; -use datafusion_common::{DataFusionError, Result}; - -macro_rules! downcast_arg { - ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{ - $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { - DataFusionError::Internal(format!( - "could not cast {} from {} to {}", - $NAME, - $ARG.data_type(), - type_name::<$ARRAY_TYPE>() - )) - })? - }}; -} - -macro_rules! make_function_scalar_inputs_return_type { - ($ARG: expr, $NAME:expr, $ARGS_TYPE:ident, $RETURN_TYPE:ident, $FUNC: block) => {{ - let arg = downcast_arg!($ARG, $NAME, $ARGS_TYPE); - - arg.iter() - .map(|a| match a { - Some(a) => Some($FUNC(a)), - _ => None, - }) - .collect::<$RETURN_TYPE>() - }}; -} - -/// Isnan SQL function -pub fn isnan(args: &[ArrayRef]) -> Result { - match args[0].data_type() { - DataType::Float64 => Ok(Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - "x", - Float64Array, - BooleanArray, - { f64::is_nan } - )) as ArrayRef), - - DataType::Float32 => Ok(Arc::new(make_function_scalar_inputs_return_type!( - &args[0], - "x", - Float32Array, - BooleanArray, - { f32::is_nan } - )) as ArrayRef), - - other => exec_err!("Unsupported data type {other:?} for function isnan"), - } -} - -#[cfg(test)] -mod tests { - - use datafusion_common::cast::as_boolean_array; - - use super::*; - - #[test] - fn test_isnan_f64() { - let args: Vec = vec![Arc::new(Float64Array::from(vec![ - 1.0, - f64::NAN, - 3.0, - -f64::NAN, - ]))]; - - let result = isnan(&args).expect("failed to initialize function isnan"); - let booleans = - as_boolean_array(&result).expect("failed to initialize function isnan"); - - assert_eq!(booleans.len(), 4); - assert!(!booleans.value(0)); - assert!(booleans.value(1)); - assert!(!booleans.value(2)); - assert!(booleans.value(3)); - } - - #[test] - fn test_isnan_f32() { - let args: Vec = vec![Arc::new(Float32Array::from(vec![ - 1.0, - f32::NAN, - 3.0, - f32::NAN, - ]))]; - - let result = isnan(&args).expect("failed to initialize function isnan"); - let booleans = - as_boolean_array(&result).expect("failed to initialize function isnan"); - - assert_eq!(booleans.len(), 4); - assert!(!booleans.value(0)); - assert!(booleans.value(1)); - assert!(!booleans.value(2)); - assert!(booleans.value(3)); - } -} diff --git a/datafusion/physical-expr/src/partitioning.rs b/datafusion/physical-expr/src/partitioning.rs index fcb3278b6022..98c0c864b9f7 100644 --- a/datafusion/physical-expr/src/partitioning.rs +++ b/datafusion/physical-expr/src/partitioning.rs @@ -17,15 +17,19 @@ //! [`Partitioning`] and [`Distribution`] for `ExecutionPlans` +use crate::{ + equivalence::ProjectionMapping, expressions::UnKnownColumn, physical_exprs_equal, + EquivalenceProperties, PhysicalExpr, +}; +use datafusion_physical_expr_common::physical_expr::format_physical_expr_list; use std::fmt; +use std::fmt::Display; use std::sync::Arc; -use crate::{physical_exprs_equal, EquivalenceProperties, PhysicalExpr}; - /// Output partitioning supported by [`ExecutionPlan`]s. /// -/// When `executed`, `ExecutionPlan`s produce one or more independent stream of -/// data batches in parallel, referred to as partitions. The streams are Rust +/// Calling [`ExecutionPlan::execute`] produce one or more independent streams of +/// [`RecordBatch`]es in parallel, referred to as partitions. The streams are Rust /// `async` [`Stream`]s (a special kind of future). The number of output /// partitions varies based on the input and the operation performed. /// @@ -84,9 +88,9 @@ use crate::{physical_exprs_equal, EquivalenceProperties, PhysicalExpr}; /// └──────────┐│┌──────────┘ /// │││ /// │││ -/// RepartitionExec with one input -/// that has 3 partitions, but 3 (async) streams, that internally -/// itself has only 1 output partition pull from the same input stream +/// RepartitionExec with 1 input +/// partition and 3 output partitions 3 (async) streams, that internally +/// pull from the same input stream /// ... /// ``` /// @@ -102,6 +106,8 @@ use crate::{physical_exprs_equal, EquivalenceProperties, PhysicalExpr}; /// Plans such as `FilterExec` produce the same number of output streams /// (partitions) as input streams (partitions). /// +/// [`RecordBatch`]: arrow::record_batch::RecordBatch +/// [`ExecutionPlan::execute`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html#tymethod.execute /// [`ExecutionPlan`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html /// [`Stream`]: https://docs.rs/futures/latest/futures/stream/trait.Stream.html #[derive(Debug, Clone)] @@ -115,8 +121,8 @@ pub enum Partitioning { UnknownPartitioning(usize), } -impl fmt::Display for Partitioning { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { +impl Display for Partitioning { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Partitioning::RoundRobinBatch(size) => write!(f, "RoundRobinBatch({size})"), Partitioning::Hash(phy_exprs, size) => { @@ -152,6 +158,8 @@ impl Partitioning { match required { Distribution::UnspecifiedDistribution => true, Distribution::SinglePartition if self.partition_count() == 1 => true, + // When partition count is 1, hash requirement is satisfied. + Distribution::HashPartitioned(_) if self.partition_count() == 1 => true, Distribution::HashPartitioned(required_exprs) => { match self { // Here we do not check the partition count for hash partitioning and assumes the partition count @@ -167,11 +175,11 @@ impl Partitioning { if !eq_groups.is_empty() { let normalized_required_exprs = required_exprs .iter() - .map(|e| eq_groups.normalize_expr(e.clone())) + .map(|e| eq_groups.normalize_expr(Arc::clone(e))) .collect::>(); let normalized_partition_exprs = partition_exprs .iter() - .map(|e| eq_groups.normalize_expr(e.clone())) + .map(|e| eq_groups.normalize_expr(Arc::clone(e))) .collect::>(); return physical_exprs_equal( &normalized_required_exprs, @@ -187,6 +195,29 @@ impl Partitioning { _ => false, } } + + /// Calculate the output partitioning after applying the given projection. + pub fn project( + &self, + projection_mapping: &ProjectionMapping, + input_eq_properties: &EquivalenceProperties, + ) -> Self { + if let Partitioning::Hash(exprs, part) = self { + let normalized_exprs = exprs + .iter() + .map(|expr| { + input_eq_properties + .project_expr(expr, projection_mapping) + .unwrap_or_else(|| { + Arc::new(UnKnownColumn::new(&expr.to_string())) + }) + }) + .collect(); + Partitioning::Hash(normalized_exprs, *part) + } else { + self.clone() + } + } } impl PartialEq for Partitioning { @@ -234,6 +265,18 @@ impl Distribution { } } +impl Display for Distribution { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Distribution::UnspecifiedDistribution => write!(f, "Unspecified"), + Distribution::SinglePartition => write!(f, "SinglePartition"), + Distribution::HashPartitioned(exprs) => { + write!(f, "HashPartitioned[{}])", format_physical_expr_list(exprs)) + } + } + } +} + #[cfg(test)] mod tests { @@ -290,7 +333,7 @@ mod tests { assert_eq!(result, (true, false, false, false, false)) } Distribution::HashPartitioned(_) => { - assert_eq!(result, (false, false, false, true, false)) + assert_eq!(result, (true, false, false, true, false)) } } } diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index 127194f681a5..c718e6b054ef 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -17,7 +17,7 @@ use std::sync::Arc; -use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +pub(crate) use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use itertools::izip; pub use datafusion_physical_expr_common::physical_expr::down_cast_any_ref; @@ -117,12 +117,12 @@ mod tests { // lit(true), lit(false), lit(4), lit(2), Col(a), Col(b) let physical_exprs: Vec> = vec![ - lit_true.clone(), - lit_false.clone(), - lit4.clone(), - lit2.clone(), - col_a_expr.clone(), - col_b_expr.clone(), + Arc::clone(&lit_true), + Arc::clone(&lit_false), + Arc::clone(&lit4), + Arc::clone(&lit2), + Arc::clone(&col_a_expr), + Arc::clone(&col_b_expr), ]; // below expressions are inside physical_exprs assert!(physical_exprs_contains(&physical_exprs, &lit_true)); @@ -146,10 +146,10 @@ mod tests { Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; let col_b_expr = Arc::new(Column::new("b", 1)) as Arc; - let vec1 = vec![lit_true.clone(), lit_false.clone()]; - let vec2 = vec![lit_true.clone(), col_b_expr.clone()]; - let vec3 = vec![lit2.clone(), lit1.clone()]; - let vec4 = vec![lit_true.clone(), lit_false.clone()]; + let vec1 = vec![Arc::clone(&lit_true), Arc::clone(&lit_false)]; + let vec2 = vec![Arc::clone(&lit_true), Arc::clone(&col_b_expr)]; + let vec3 = vec![Arc::clone(&lit2), Arc::clone(&lit1)]; + let vec4 = vec![Arc::clone(&lit_true), Arc::clone(&lit_false)]; // these vectors are same assert!(physical_exprs_equal(&vec1, &vec1)); diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index f46e5f6ec68f..bffc2c46fc1e 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -17,23 +17,22 @@ use std::sync::Arc; -use arrow::datatypes::Schema; +use crate::scalar_function; +use crate::{ + expressions::{self, binary, like, similar_to, Column, Literal}, + PhysicalExpr, +}; +use arrow::datatypes::Schema; use datafusion_common::{ - exec_err, internal_err, not_impl_err, plan_err, DFSchema, Result, ScalarValue, + exec_err, not_impl_err, plan_err, DFSchema, Result, ScalarValue, ToDFSchema, }; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::{Alias, Cast, InList, ScalarFunction}; use datafusion_expr::var_provider::is_system_variables; use datafusion_expr::var_provider::VarType; use datafusion_expr::{ - binary_expr, Between, BinaryExpr, Expr, GetFieldAccess, GetIndexedField, Like, - Operator, ScalarFunctionDefinition, TryCast, -}; - -use crate::{ - expressions::{self, binary, like, Column, Literal}, - udf, PhysicalExpr, + binary_expr, lit, Between, BinaryExpr, Expr, Like, Operator, TryCast, }; /// [PhysicalExpr] evaluate DataFusion expressions such as `A + 1`, or `CAST(c1 @@ -143,32 +142,26 @@ pub fn create_physical_expr( let binary_op = binary_expr( expr.as_ref().clone(), Operator::IsNotDistinctFrom, - Expr::Literal(ScalarValue::Boolean(Some(true))), + lit(true), ); create_physical_expr(&binary_op, input_dfschema, execution_props) } Expr::IsNotTrue(expr) => { - let binary_op = binary_expr( - expr.as_ref().clone(), - Operator::IsDistinctFrom, - Expr::Literal(ScalarValue::Boolean(Some(true))), - ); + let binary_op = + binary_expr(expr.as_ref().clone(), Operator::IsDistinctFrom, lit(true)); create_physical_expr(&binary_op, input_dfschema, execution_props) } Expr::IsFalse(expr) => { let binary_op = binary_expr( expr.as_ref().clone(), Operator::IsNotDistinctFrom, - Expr::Literal(ScalarValue::Boolean(Some(false))), + lit(false), ); create_physical_expr(&binary_op, input_dfschema, execution_props) } Expr::IsNotFalse(expr) => { - let binary_op = binary_expr( - expr.as_ref().clone(), - Operator::IsDistinctFrom, - Expr::Literal(ScalarValue::Boolean(Some(false))), - ); + let binary_op = + binary_expr(expr.as_ref().clone(), Operator::IsDistinctFrom, lit(false)); create_physical_expr(&binary_op, input_dfschema, execution_props) } Expr::IsUnknown(expr) => { @@ -222,6 +215,22 @@ pub fn create_physical_expr( input_schema, ) } + Expr::SimilarTo(Like { + negated, + expr, + pattern, + escape_char, + case_insensitive, + }) => { + if escape_char.is_some() { + return exec_err!("SIMILAR TO does not support escape_char yet"); + } + let physical_expr = + create_physical_expr(expr, input_dfschema, execution_props)?; + let physical_pattern = + create_physical_expr(pattern, input_dfschema, execution_props)?; + similar_to(*negated, *case_insensitive, physical_expr, physical_pattern) + } Expr::Case(case) => { let expr: Option> = if let Some(e) = &case.expr { Some(create_physical_expr( @@ -245,7 +254,7 @@ pub fn create_physical_expr( when_expr .iter() .zip(then_expr.iter()) - .map(|(w, t)| (w.clone(), t.clone())) + .map(|(w, t)| (Arc::clone(w), Arc::clone(t))) .collect(); let else_expr: Option> = if let Some(e) = &case.else_expr { @@ -286,40 +295,17 @@ pub fn create_physical_expr( input_dfschema, execution_props, )?), - Expr::GetIndexedField(GetIndexedField { expr: _, field }) => match field { - GetFieldAccess::NamedStructField { name: _ } => { - internal_err!( - "NamedStructField should be rewritten in OperatorToFunction" - ) - } - GetFieldAccess::ListIndex { key: _ } => { - internal_err!("ListIndex should be rewritten in OperatorToFunction") - } - GetFieldAccess::ListRange { - start: _, - stop: _, - stride: _, - } => { - internal_err!("ListRange should be rewritten in OperatorToFunction") - } - }, - - Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + Expr::ScalarFunction(ScalarFunction { func, args }) => { let physical_args = create_physical_exprs(args, input_dfschema, execution_props)?; - match func_def { - ScalarFunctionDefinition::UDF(fun) => udf::create_physical_expr( - fun.clone().as_ref(), - &physical_args, - input_schema, - args, - input_dfschema, - ), - ScalarFunctionDefinition::Name(_) => { - internal_err!("Function `Expr` with name should be resolved.") - } - } + scalar_function::create_physical_expr( + Arc::clone(func).as_ref(), + &physical_args, + input_schema, + args, + input_dfschema, + ) } Expr::Between(Between { expr, @@ -333,9 +319,19 @@ pub fn create_physical_expr( // rewrite the between into the two binary operators let binary_expr = binary( - binary(value_expr.clone(), Operator::GtEq, low_expr, input_schema)?, + binary( + Arc::clone(&value_expr), + Operator::GtEq, + low_expr, + input_schema, + )?, Operator::And, - binary(value_expr.clone(), Operator::LtEq, high_expr, input_schema)?, + binary( + Arc::clone(&value_expr), + Operator::LtEq, + high_expr, + input_schema, + )?, input_schema, ); @@ -383,6 +379,13 @@ where .collect::>>() } +/// Convert a logical expression to a physical expression (without any simplification, etc) +pub fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { + let df_schema = schema.clone().to_dfschema().unwrap(); + let execution_props = ExecutionProps::new(); + create_physical_expr(expr, &df_schema, &execution_props).unwrap() +} + #[cfg(test)] mod tests { use arrow_array::{ArrayRef, BooleanArray, RecordBatch, StringArray}; diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 9ae9f3dee3e7..ab53106f6059 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -34,32 +34,25 @@ use std::fmt::{self, Debug, Formatter}; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use arrow::datatypes::{DataType, Schema}; -use arrow::record_batch::RecordBatch; - -use datafusion_common::{internal_err, Result}; -use datafusion_expr::{ - expr_vec_fmt, ColumnarValue, FuncMonotonicity, ScalarFunctionDefinition, -}; - -use crate::functions::out_ordering; use crate::physical_expr::{down_cast_any_ref, physical_exprs_equal}; -use crate::sort_properties::SortProperties; use crate::PhysicalExpr; +use arrow::datatypes::{DataType, Schema}; +use arrow::record_batch::RecordBatch; +use arrow_array::Array; +use datafusion_common::{internal_err, DFSchema, Result, ScalarValue}; +use datafusion_expr::interval_arithmetic::Interval; +use datafusion_expr::sort_properties::ExprProperties; +use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; +use datafusion_expr::{expr_vec_fmt, ColumnarValue, Expr, ScalarUDF}; + /// Physical expression of a scalar function pub struct ScalarFunctionExpr { - fun: ScalarFunctionDefinition, + fun: Arc, name: String, args: Vec>, return_type: DataType, - // Keeps monotonicity information of the function. - // FuncMonotonicity vector is one to one mapped to `args`, - // and it specifies the effect of an increase or decrease in - // the corresponding `arg` to the function value. - monotonicity: Option, - // Whether this function can be invoked with zero arguments - supports_zero_argument: bool, + nullable: bool, } impl Debug for ScalarFunctionExpr { @@ -69,8 +62,6 @@ impl Debug for ScalarFunctionExpr { .field("name", &self.name) .field("args", &self.args) .field("return_type", &self.return_type) - .field("monotonicity", &self.monotonicity) - .field("supports_zero_argument", &self.supports_zero_argument) .finish() } } @@ -79,24 +70,21 @@ impl ScalarFunctionExpr { /// Create a new Scalar function pub fn new( name: &str, - fun: ScalarFunctionDefinition, + fun: Arc, args: Vec>, return_type: DataType, - monotonicity: Option, - supports_zero_argument: bool, ) -> Self { Self { fun, name: name.to_owned(), args, return_type, - monotonicity, - supports_zero_argument, + nullable: true, } } /// Get the scalar function implementation - pub fn fun(&self) -> &ScalarFunctionDefinition { + pub fn fun(&self) -> &ScalarUDF { &self.fun } @@ -115,9 +103,13 @@ impl ScalarFunctionExpr { &self.return_type } - /// Monotonicity information of the function - pub fn monotonicity(&self) -> &Option { - &self.monotonicity + pub fn with_nullable(mut self, nullable: bool) -> Self { + self.nullable = nullable; + self + } + + pub fn nullable(&self) -> bool { + self.nullable } } @@ -138,53 +130,68 @@ impl PhysicalExpr for ScalarFunctionExpr { } fn nullable(&self, _input_schema: &Schema) -> Result { - Ok(true) + Ok(self.nullable) } fn evaluate(&self, batch: &RecordBatch) -> Result { - // evaluate the arguments, if there are no arguments we'll instead pass in a null array - // indicating the batch size (as a convention) - let inputs = match self.args.is_empty() { - // If the function supports zero argument, we pass in a null array indicating the batch size. - // This is for user-defined functions. - // MakeArray support zero argument but has the different behavior from the array with one null. - true if self.supports_zero_argument && self.name != "make_array" => { - vec![ColumnarValue::create_null_array(batch.num_rows())] - } - _ => self - .args - .iter() - .map(|e| e.evaluate(batch)) - .collect::>>()?, - }; + let inputs = self + .args + .iter() + .map(|e| e.evaluate(batch)) + .collect::>>()?; // evaluate the function - match self.fun { - ScalarFunctionDefinition::UDF(ref fun) => fun.invoke(&inputs), - ScalarFunctionDefinition::Name(_) => { - internal_err!( - "Name function must be resolved to one of the other variants prior to physical planning" - ) + let output = self.fun.invoke_batch(&inputs, batch.num_rows())?; + + if let ColumnarValue::Array(array) = &output { + if array.len() != batch.num_rows() { + // If the arguments are a non-empty slice of scalar values, we can assume that + // returning a one-element array is equivalent to returning a scalar. + let preserve_scalar = array.len() == 1 + && !inputs.is_empty() + && inputs + .iter() + .all(|arg| matches!(arg, ColumnarValue::Scalar(_))); + return if preserve_scalar { + ScalarValue::try_from_array(array, 0).map(ColumnarValue::Scalar) + } else { + internal_err!("UDF returned a different number of rows than expected. Expected: {}, Got: {}", + batch.num_rows(), array.len()) + }; } } + Ok(output) } - fn children(&self) -> Vec> { - self.args.clone() + fn children(&self) -> Vec<&Arc> { + self.args.iter().collect() } fn with_new_children( self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(ScalarFunctionExpr::new( - &self.name, - self.fun.clone(), - children, - self.return_type().clone(), - self.monotonicity.clone(), - self.supports_zero_argument, - ))) + Ok(Arc::new( + ScalarFunctionExpr::new( + &self.name, + Arc::clone(&self.fun), + children, + self.return_type().clone(), + ) + .with_nullable(self.nullable), + )) + } + + fn evaluate_bounds(&self, children: &[&Interval]) -> Result { + self.fun.evaluate_bounds(children) + } + + fn propagate_constraints( + &self, + interval: &Interval, + children: &[&Interval], + ) -> Result>> { + self.fun.propagate_constraints(interval, children) } fn dyn_hash(&self, state: &mut dyn Hasher) { @@ -195,11 +202,18 @@ impl PhysicalExpr for ScalarFunctionExpr { // Add `self.fun` when hash is available } - fn get_ordering(&self, children: &[SortProperties]) -> SortProperties { - self.monotonicity - .as_ref() - .map(|monotonicity| out_ordering(monotonicity, children)) - .unwrap_or(SortProperties::Unordered) + fn get_properties(&self, children: &[ExprProperties]) -> Result { + let sort_properties = self.fun.output_ordering(children)?; + let children_range = children + .iter() + .map(|props| &props.range) + .collect::>(); + let range = self.fun().evaluate_bounds(&children_range)?; + + Ok(ExprProperties { + sort_properties, + range, + }) } } @@ -216,3 +230,34 @@ impl PartialEq for ScalarFunctionExpr { .unwrap_or(false) } } + +/// Create a physical expression for the UDF. +pub fn create_physical_expr( + fun: &ScalarUDF, + input_phy_exprs: &[Arc], + input_schema: &Schema, + args: &[Expr], + input_dfschema: &DFSchema, +) -> Result> { + let input_expr_types = input_phy_exprs + .iter() + .map(|e| e.data_type(input_schema)) + .collect::>>()?; + + // verify that input data types is consistent with function's `TypeSignature` + data_types_with_scalar_udf(&input_expr_types, fun)?; + + // Since we have arg_types, we dont need args and schema. + let return_type = + fun.return_type_from_exprs(args, input_dfschema, &input_expr_types)?; + + Ok(Arc::new( + ScalarFunctionExpr::new( + fun.name(), + Arc::new(fun.clone()), + input_phy_exprs.to_vec(), + return_type, + ) + .with_nullable(fun.is_nullable(args, input_dfschema)), + )) +} diff --git a/datafusion/physical-expr/src/udf.rs b/datafusion/physical-expr/src/udf.rs deleted file mode 100644 index 368dfdf92f45..000000000000 --- a/datafusion/physical-expr/src/udf.rs +++ /dev/null @@ -1,96 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! UDF support -use std::sync::Arc; - -use arrow_schema::Schema; - -use datafusion_common::{DFSchema, Result}; -pub use datafusion_expr::ScalarUDF; -use datafusion_expr::{ - type_coercion::functions::data_types, Expr, ScalarFunctionDefinition, -}; - -use crate::{PhysicalExpr, ScalarFunctionExpr}; - -/// Create a physical expression of the UDF. -/// -/// Arguments: -pub fn create_physical_expr( - fun: &ScalarUDF, - input_phy_exprs: &[Arc], - input_schema: &Schema, - args: &[Expr], - input_dfschema: &DFSchema, -) -> Result> { - let input_expr_types = input_phy_exprs - .iter() - .map(|e| e.data_type(input_schema)) - .collect::>>()?; - - // verify that input data types is consistent with function's `TypeSignature` - data_types(&input_expr_types, fun.signature())?; - - // Since we have arg_types, we dont need args and schema. - let return_type = - fun.return_type_from_exprs(args, input_dfschema, &input_expr_types)?; - - let fun_def = ScalarFunctionDefinition::UDF(Arc::new(fun.clone())); - Ok(Arc::new(ScalarFunctionExpr::new( - fun.name(), - fun_def, - input_phy_exprs.to_vec(), - return_type, - fun.monotonicity()?, - fun.signature().type_signature.supports_zero_argument(), - ))) -} - -#[cfg(test)] -mod tests { - use arrow_schema::Schema; - - use datafusion_common::{DFSchema, Result}; - use datafusion_expr::ScalarUDF; - - use crate::utils::tests::TestScalarUDF; - use crate::ScalarFunctionExpr; - - use super::create_physical_expr; - - #[test] - fn test_functions() -> Result<()> { - // create and register the udf - let udf = ScalarUDF::from(TestScalarUDF::new()); - - let e = crate::expressions::lit(1.1); - let p_expr = - create_physical_expr(&udf, &[e], &Schema::empty(), &[], &DFSchema::empty())?; - - assert_eq!( - p_expr - .as_any() - .downcast_ref::() - .unwrap() - .monotonicity(), - &Some(vec![Some(true)]) - ); - - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/utils/guarantee.rs b/datafusion/physical-expr/src/utils/guarantee.rs index e441fe8f4802..fbb59cc92fa0 100644 --- a/datafusion/physical-expr/src/utils/guarantee.rs +++ b/datafusion/physical-expr/src/utils/guarantee.rs @@ -62,14 +62,14 @@ use std::sync::Arc; /// A guarantee can be one of two forms: /// /// 1. The column must be one the values for the predicate to be `true`. If the -/// column takes on any other value, the predicate can not evaluate to `true`. -/// For example, -/// `(a = 1)`, `(a = 1 OR a = 2) or `a IN (1, 2, 3)` +/// column takes on any other value, the predicate can not evaluate to `true`. +/// For example, +/// `(a = 1)`, `(a = 1 OR a = 2)` or `a IN (1, 2, 3)` /// /// 2. The column must NOT be one of the values for the predicate to be `true`. -/// If the column can ONLY take one of these values, the predicate can not -/// evaluate to `true`. For example, -/// `(a != 1)`, `(a != 1 AND a != 2)` or `a NOT IN (1, 2, 3)` +/// If the column can ONLY take one of these values, the predicate can not +/// evaluate to `true`. For example, +/// `(a != 1)`, `(a != 1 AND a != 2)` or `a NOT IN (1, 2, 3)` #[derive(Debug, Clone, PartialEq)] pub struct LiteralGuarantee { pub column: Column, @@ -93,18 +93,18 @@ impl LiteralGuarantee { /// Create a new instance of the guarantee if the provided operator is /// supported. Returns None otherwise. See [`LiteralGuarantee::analyze`] to /// create these structures from an predicate (boolean expression). - fn try_new<'a>( + fn new<'a>( column_name: impl Into, guarantee: Guarantee, literals: impl IntoIterator, - ) -> Option { + ) -> Self { let literals: HashSet<_> = literals.into_iter().cloned().collect(); - Some(Self { + Self { column: Column::from_name(column_name), guarantee, literals, - }) + } } /// Return a list of [`LiteralGuarantee`]s that must be satisfied for `expr` @@ -225,26 +225,21 @@ impl LiteralGuarantee { impl Display for LiteralGuarantee { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let mut sorted_literals: Vec<_> = + self.literals.iter().map(|lit| lit.to_string()).collect(); + sorted_literals.sort(); match self.guarantee { Guarantee::In => write!( f, "{} in ({})", self.column.name, - self.literals - .iter() - .map(|lit| lit.to_string()) - .collect::>() - .join(", ") + sorted_literals.join(", ") ), Guarantee::NotIn => write!( f, "{} not in ({})", self.column.name, - self.literals - .iter() - .map(|lit| lit.to_string()) - .collect::>() - .join(", ") + sorted_literals.join(", ") ), } } @@ -283,7 +278,7 @@ impl<'a> GuaranteeBuilder<'a> { ) } - /// Aggregates a new single column, multi literal term to ths builder + /// Aggregates a new single column, multi literal term to this builder /// combining with previously known guarantees if possible. /// /// # Examples @@ -343,13 +338,10 @@ impl<'a> GuaranteeBuilder<'a> { // This is a new guarantee let new_values: HashSet<_> = new_values.into_iter().collect(); - if let Some(guarantee) = - LiteralGuarantee::try_new(col.name(), guarantee, new_values) - { - // add it to the list of guarantees - self.guarantees.push(Some(guarantee)); - self.map.insert(key, self.guarantees.len() - 1); - } + let guarantee = LiteralGuarantee::new(col.name(), guarantee, new_values); + // add it to the list of guarantees + self.guarantees.push(Some(guarantee)); + self.map.insert(key, self.guarantees.len() - 1); } self @@ -374,6 +366,7 @@ impl<'a> ColOpLit<'a> { /// 1. `col literal` /// 2. `literal col` /// 3. operator is `=` or `!=` + /// /// Returns None otherwise fn try_new(expr: &'a Arc) -> Option { let binary_expr = expr @@ -419,15 +412,16 @@ impl<'a> ColOpLit<'a> { #[cfg(test)] mod test { + use std::sync::OnceLock; + use super::*; - use crate::create_physical_expr; + use crate::planner::logical2physical; + use arrow_schema::{DataType, Field, Schema, SchemaRef}; - use datafusion_common::ToDFSchema; - use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr_fn::*; use datafusion_expr::{lit, Expr}; + use itertools::Itertools; - use std::sync::OnceLock; #[test] fn test_literal() { @@ -854,7 +848,7 @@ mod test { S: Into + 'a, { let literals: Vec<_> = literals.into_iter().map(|s| s.into()).collect(); - LiteralGuarantee::try_new(column, Guarantee::In, literals.iter()).unwrap() + LiteralGuarantee::new(column, Guarantee::In, literals.iter()) } /// Guarantee that the expression is true if the column is NOT any of the specified values @@ -864,26 +858,17 @@ mod test { S: Into + 'a, { let literals: Vec<_> = literals.into_iter().map(|s| s.into()).collect(); - LiteralGuarantee::try_new(column, Guarantee::NotIn, literals.iter()).unwrap() - } - - /// Convert a logical expression to a physical expression (without any simplification, etc) - fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { - let df_schema = schema.clone().to_dfschema().unwrap(); - let execution_props = ExecutionProps::new(); - create_physical_expr(expr, &df_schema, &execution_props).unwrap() + LiteralGuarantee::new(column, Guarantee::NotIn, literals.iter()) } // Schema for testing fn schema() -> SchemaRef { - SCHEMA - .get_or_init(|| { - Arc::new(Schema::new(vec![ - Field::new("a", DataType::Utf8, false), - Field::new("b", DataType::Int32, false), - ])) - }) - .clone() + Arc::clone(SCHEMA.get_or_init(|| { + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Int32, false), + ])) + })) } static SCHEMA: OnceLock = OnceLock::new(); diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 76cee3a1a786..c3d1b1425b7f 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -17,9 +17,10 @@ mod guarantee; pub use guarantee::{Guarantee, LiteralGuarantee}; +use hashbrown::HashSet; use std::borrow::Borrow; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::sync::Arc; use crate::expressions::{BinaryExpr, Column}; @@ -34,6 +35,7 @@ use datafusion_common::tree_node::{ use datafusion_common::Result; use datafusion_expr::Operator; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexOrderingRef}; use itertools::Itertools; use petgraph::graph::NodeIndex; use petgraph::stable_graph::StableGraph; @@ -85,6 +87,10 @@ pub fn map_columns_before_projection( parent_required: &[Arc], proj_exprs: &[(Arc, String)], ) -> Vec> { + if parent_required.is_empty() { + // No need to build mapping. + return vec![]; + } let column_mapping = proj_exprs .iter() .filter_map(|(expr, name)| { @@ -111,7 +117,7 @@ pub fn convert_to_expr>( ) -> Vec> { sequence .into_iter() - .map(|elem| elem.borrow().expr.clone()) + .map(|elem| Arc::clone(&elem.borrow().expr)) .collect() } @@ -166,7 +172,7 @@ impl<'a, T, F: Fn(&ExprTreeNode) -> Result> for expr_node in node.children.iter() { self.graph.add_edge(node_idx, expr_node.data.unwrap(), 0); } - self.visited_plans.push((expr.clone(), node_idx)); + self.visited_plans.push((Arc::clone(expr), node_idx)); node_idx } }; @@ -204,9 +210,7 @@ pub fn collect_columns(expr: &Arc) -> HashSet { let mut columns = HashSet::::new(); expr.apply(|expr| { if let Some(column) = expr.as_any().downcast_ref::() { - if !columns.iter().any(|c| c.eq(column)) { - columns.insert(column.clone()); - } + columns.get_or_insert_owned(column); } Ok(TreeNodeRecursion::Continue) }) @@ -242,10 +246,7 @@ pub fn reassign_predicate_columns( } /// Merge left and right sort expressions, checking for duplicates. -pub fn merge_vectors( - left: &[PhysicalSortExpr], - right: &[PhysicalSortExpr], -) -> Vec { +pub fn merge_vectors(left: LexOrderingRef, right: LexOrderingRef) -> LexOrdering { left.iter() .cloned() .chain(right.iter().cloned()) @@ -255,19 +256,18 @@ pub fn merge_vectors( #[cfg(test)] pub(crate) mod tests { - use arrow_array::{ArrayRef, Float32Array, Float64Array}; use std::any::Any; use std::fmt::{Display, Formatter}; use super::*; use crate::expressions::{binary, cast, col, in_list, lit, Literal}; + use arrow_array::{ArrayRef, Float32Array, Float64Array}; use arrow_schema::{DataType, Field, Schema}; use datafusion_common::{exec_err, DataFusionError, ScalarValue}; + use datafusion_expr::sort_properties::{ExprProperties, SortProperties}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; - use datafusion_expr::{ - ColumnarValue, FuncMonotonicity, ScalarUDFImpl, Signature, Volatility, - }; use petgraph::visit::Bfs; #[derive(Debug, Clone)] @@ -309,8 +309,8 @@ pub(crate) mod tests { } } - fn monotonicity(&self) -> Result> { - Ok(Some(vec![Some(true)])) + fn output_ordering(&self, input: &[ExprProperties]) -> Result { + Ok(input[0].sort_properties) } fn invoke(&self, args: &[ColumnarValue]) -> Result { @@ -380,7 +380,7 @@ pub(crate) mod tests { } fn make_dummy_node(node: &ExprTreeNode) -> Result { - let expr = node.expr.clone(); + let expr = Arc::clone(&node.expr); let dummy_property = if expr.as_any().is::() { "Binary" } else if expr.as_any().is::() { diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index 5892f7f3f3b0..94960c95e4bb 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -25,47 +25,46 @@ use arrow::array::Array; use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; -use datafusion_common::ScalarValue; -use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::{Accumulator, WindowFrame}; - +use crate::aggregate::AggregateFunctionExpr; use crate::window::window_expr::AggregateWindowExpr; use crate::window::{ PartitionBatches, PartitionWindowAggStates, SlidingAggregateWindowExpr, WindowExpr, }; -use crate::{ - expressions::PhysicalSortExpr, reverse_order_bys, AggregateExpr, PhysicalExpr, -}; +use crate::{reverse_order_bys, PhysicalExpr}; +use datafusion_common::ScalarValue; +use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::{Accumulator, WindowFrame}; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexOrderingRef}; /// A window expr that takes the form of an aggregate function. /// /// See comments on [`WindowExpr`] for more details. #[derive(Debug)] pub struct PlainAggregateWindowExpr { - aggregate: Arc, + aggregate: Arc, partition_by: Vec>, - order_by: Vec, + order_by: LexOrdering, window_frame: Arc, } impl PlainAggregateWindowExpr { /// Create a new aggregate window function expression pub fn new( - aggregate: Arc, + aggregate: Arc, partition_by: &[Arc], - order_by: &[PhysicalSortExpr], + order_by: LexOrderingRef, window_frame: Arc, ) -> Self { Self { aggregate, partition_by: partition_by.to_vec(), - order_by: order_by.to_vec(), + order_by: LexOrdering::from_ref(order_by), window_frame, } } /// Get aggregate expr of AggregateWindowExpr - pub fn get_aggregate_expr(&self) -> &Arc { + pub fn get_aggregate_expr(&self) -> &AggregateFunctionExpr { &self.aggregate } } @@ -80,7 +79,7 @@ impl WindowExpr for PlainAggregateWindowExpr { } fn field(&self) -> Result { - self.aggregate.field() + Ok(self.aggregate.field()) } fn name(&self) -> &str { @@ -125,8 +124,8 @@ impl WindowExpr for PlainAggregateWindowExpr { &self.partition_by } - fn order_by(&self) -> &[PhysicalSortExpr] { - &self.order_by + fn order_by(&self) -> LexOrderingRef { + self.order_by.as_ref() } fn get_window_frame(&self) -> &Arc { @@ -138,16 +137,16 @@ impl WindowExpr for PlainAggregateWindowExpr { let reverse_window_frame = self.window_frame.reverse(); if reverse_window_frame.start_bound.is_unbounded() { Arc::new(PlainAggregateWindowExpr::new( - reverse_expr, + Arc::new(reverse_expr), &self.partition_by.clone(), - &reverse_order_bys(&self.order_by), + reverse_order_bys(self.order_by.as_ref()).as_ref(), Arc::new(self.window_frame.reverse()), )) as _ } else { Arc::new(SlidingAggregateWindowExpr::new( - reverse_expr, + Arc::new(reverse_expr), &self.partition_by.clone(), - &reverse_order_bys(&self.order_by), + reverse_order_bys(self.order_by.as_ref()).as_ref(), Arc::new(self.window_frame.reverse()), )) as _ } @@ -176,9 +175,9 @@ impl AggregateWindowExpr for PlainAggregateWindowExpr { value_slice: &[ArrayRef], accumulator: &mut Box, ) -> Result { - let value = if cur_range.start == cur_range.end { - // We produce None if the window is empty. - ScalarValue::try_from(self.aggregate.field()?.data_type())? + if cur_range.start == cur_range.end { + self.aggregate + .default_value(self.aggregate.field().data_type()) } else { // Accumulate any new rows that have entered the window: let update_bound = cur_range.end - last_range.end; @@ -193,8 +192,7 @@ impl AggregateWindowExpr for PlainAggregateWindowExpr { .collect(); accumulator.update_batch(&update)? } - accumulator.evaluate()? - }; - Ok(value) + accumulator.evaluate() + } } } diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index 065260a73e0b..5f6c5e5c2c1b 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -22,11 +22,9 @@ use std::ops::Range; use std::sync::Arc; use super::{BuiltInWindowFunctionExpr, WindowExpr}; -use crate::expressions::PhysicalSortExpr; use crate::window::window_expr::{get_orderby_values, WindowFn}; use crate::window::{PartitionBatches, PartitionWindowAggStates, WindowState}; use crate::{reverse_order_bys, EquivalenceProperties, PhysicalExpr}; - use arrow::array::{new_empty_array, ArrayRef}; use arrow::compute::SortOptions; use arrow::datatypes::Field; @@ -35,13 +33,14 @@ use datafusion_common::utils::evaluate_partition_ranges; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::window_state::{WindowAggState, WindowFrameContext}; use datafusion_expr::WindowFrame; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexOrderingRef}; /// A window expr that takes the form of a [`BuiltInWindowFunctionExpr`]. #[derive(Debug)] pub struct BuiltInWindowExpr { expr: Arc, partition_by: Vec>, - order_by: Vec, + order_by: LexOrdering, window_frame: Arc, } @@ -50,13 +49,13 @@ impl BuiltInWindowExpr { pub fn new( expr: Arc, partition_by: &[Arc], - order_by: &[PhysicalSortExpr], + order_by: LexOrderingRef, window_frame: Arc, ) -> Self { Self { expr, partition_by: partition_by.to_vec(), - order_by: order_by.to_vec(), + order_by: LexOrdering::from_ref(order_by), window_frame, } } @@ -77,7 +76,8 @@ impl BuiltInWindowExpr { if let Some(fn_res_ordering) = self.expr.get_result_ordering(schema) { if self.partition_by.is_empty() { // In the absence of a PARTITION BY, ordering of `self.expr` is global: - eq_properties.add_new_orderings([vec![fn_res_ordering]]); + eq_properties + .add_new_orderings([LexOrdering::new(vec![fn_res_ordering])]); } else { // If we have a PARTITION BY, built-in functions can not introduce // a global ordering unless the existing ordering is compatible @@ -118,8 +118,8 @@ impl WindowExpr for BuiltInWindowExpr { &self.partition_by } - fn order_by(&self) -> &[PhysicalSortExpr] { - &self.order_by + fn order_by(&self) -> LexOrderingRef { + self.order_by.as_ref() } fn evaluate(&self, batch: &RecordBatch) -> Result { @@ -137,7 +137,7 @@ impl WindowExpr for BuiltInWindowExpr { let order_bys_ref = &values[n_args..]; let mut window_frame_ctx = - WindowFrameContext::new(self.window_frame.clone(), sort_options); + WindowFrameContext::new(Arc::clone(&self.window_frame), sort_options); let mut last_range = Range { start: 0, end: 0 }; // We iterate on each row to calculate window frame range and and window function result for idx in 0..num_rows { @@ -217,7 +217,7 @@ impl WindowExpr for BuiltInWindowExpr { .window_frame_ctx .get_or_insert_with(|| { WindowFrameContext::new( - self.window_frame.clone(), + Arc::clone(&self.window_frame), sort_options.clone(), ) }) @@ -267,7 +267,7 @@ impl WindowExpr for BuiltInWindowExpr { Arc::new(BuiltInWindowExpr::new( reverse_expr, &self.partition_by.clone(), - &reverse_order_bys(&self.order_by), + reverse_order_bys(self.order_by.as_ref()).as_ref(), Arc::new(self.window_frame.reverse()), )) as _ }) diff --git a/datafusion/physical-expr/src/window/cume_dist.rs b/datafusion/physical-expr/src/window/cume_dist.rs deleted file mode 100644 index 9720187ea83d..000000000000 --- a/datafusion/physical-expr/src/window/cume_dist.rs +++ /dev/null @@ -1,145 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expression for `cume_dist` that can evaluated -//! at runtime during query execution - -use crate::window::BuiltInWindowFunctionExpr; -use crate::PhysicalExpr; -use arrow::array::ArrayRef; -use arrow::array::Float64Array; -use arrow::datatypes::{DataType, Field}; -use datafusion_common::Result; -use datafusion_expr::PartitionEvaluator; -use std::any::Any; -use std::iter; -use std::ops::Range; -use std::sync::Arc; - -/// CumeDist calculates the cume_dist in the window function with order by -#[derive(Debug)] -pub struct CumeDist { - name: String, - /// Output data type - data_type: DataType, -} - -/// Create a cume_dist window function -pub fn cume_dist(name: String, data_type: &DataType) -> CumeDist { - CumeDist { - name, - data_type: data_type.clone(), - } -} - -impl BuiltInWindowFunctionExpr for CumeDist { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - let nullable = false; - Ok(Field::new(self.name(), self.data_type.clone(), nullable)) - } - - fn expressions(&self) -> Vec> { - vec![] - } - - fn name(&self) -> &str { - &self.name - } - - fn create_evaluator(&self) -> Result> { - Ok(Box::new(CumeDistEvaluator {})) - } -} - -#[derive(Debug)] -pub(crate) struct CumeDistEvaluator; - -impl PartitionEvaluator for CumeDistEvaluator { - fn evaluate_all_with_rank( - &self, - num_rows: usize, - ranks_in_partition: &[Range], - ) -> Result { - let scalar = num_rows as f64; - let result = Float64Array::from_iter_values( - ranks_in_partition - .iter() - .scan(0_u64, |acc, range| { - let len = range.end - range.start; - *acc += len as u64; - let value: f64 = (*acc as f64) / scalar; - let result = iter::repeat(value).take(len); - Some(result) - }) - .flatten(), - ); - Ok(Arc::new(result)) - } - - fn include_rank(&self) -> bool { - true - } -} - -#[cfg(test)] -mod tests { - use super::*; - use datafusion_common::cast::as_float64_array; - - fn test_i32_result( - expr: &CumeDist, - num_rows: usize, - ranks: Vec>, - expected: Vec, - ) -> Result<()> { - let result = expr - .create_evaluator()? - .evaluate_all_with_rank(num_rows, &ranks)?; - let result = as_float64_array(&result)?; - let result = result.values(); - assert_eq!(expected, *result); - Ok(()) - } - - #[test] - #[allow(clippy::single_range_in_vec_init)] - fn test_cume_dist() -> Result<()> { - let r = cume_dist("arr".into(), &DataType::Float64); - - let expected = vec![0.0; 0]; - test_i32_result(&r, 0, vec![], expected)?; - - let expected = vec![1.0; 1]; - test_i32_result(&r, 1, vec![0..1], expected)?; - - let expected = vec![1.0; 2]; - test_i32_result(&r, 2, vec![0..2], expected)?; - - let expected = vec![0.5, 0.5, 1.0, 1.0]; - test_i32_result(&r, 4, vec![0..2, 2..4], expected)?; - - let expected = vec![0.25, 0.5, 0.75, 1.0]; - test_i32_result(&r, 4, vec![0..1, 1..2, 2..3, 3..4], expected)?; - - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs deleted file mode 100644 index 7e35bddef568..000000000000 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ /dev/null @@ -1,489 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expression for `lead` and `lag` that can evaluated -//! at runtime during query execution -use crate::window::BuiltInWindowFunctionExpr; -use crate::PhysicalExpr; -use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field}; -use arrow_array::Array; -use datafusion_common::{arrow_datafusion_err, DataFusionError, Result, ScalarValue}; -use datafusion_expr::PartitionEvaluator; -use std::any::Any; -use std::cmp::min; -use std::collections::VecDeque; -use std::ops::{Neg, Range}; -use std::sync::Arc; - -/// window shift expression -#[derive(Debug)] -pub struct WindowShift { - name: String, - /// Output data type - data_type: DataType, - shift_offset: i64, - expr: Arc, - default_value: ScalarValue, - ignore_nulls: bool, -} - -impl WindowShift { - /// Get shift_offset of window shift expression - pub fn get_shift_offset(&self) -> i64 { - self.shift_offset - } - - /// Get the default_value for window shift expression. - pub fn get_default_value(&self) -> ScalarValue { - self.default_value.clone() - } -} - -/// lead() window function -pub fn lead( - name: String, - data_type: DataType, - expr: Arc, - shift_offset: Option, - default_value: ScalarValue, - ignore_nulls: bool, -) -> WindowShift { - WindowShift { - name, - data_type, - shift_offset: shift_offset.map(|v| v.neg()).unwrap_or(-1), - expr, - default_value, - ignore_nulls, - } -} - -/// lag() window function -pub fn lag( - name: String, - data_type: DataType, - expr: Arc, - shift_offset: Option, - default_value: ScalarValue, - ignore_nulls: bool, -) -> WindowShift { - WindowShift { - name, - data_type, - shift_offset: shift_offset.unwrap_or(1), - expr, - default_value, - ignore_nulls, - } -} - -impl BuiltInWindowFunctionExpr for WindowShift { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - let nullable = true; - Ok(Field::new(&self.name, self.data_type.clone(), nullable)) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } - - fn create_evaluator(&self) -> Result> { - Ok(Box::new(WindowShiftEvaluator { - shift_offset: self.shift_offset, - default_value: self.default_value.clone(), - ignore_nulls: self.ignore_nulls, - non_null_offsets: VecDeque::new(), - })) - } - - fn reverse_expr(&self) -> Option> { - Some(Arc::new(Self { - name: self.name.clone(), - data_type: self.data_type.clone(), - shift_offset: -self.shift_offset, - expr: self.expr.clone(), - default_value: self.default_value.clone(), - ignore_nulls: self.ignore_nulls, - })) - } -} - -#[derive(Debug)] -pub(crate) struct WindowShiftEvaluator { - shift_offset: i64, - default_value: ScalarValue, - ignore_nulls: bool, - // VecDeque contains offset values that between non-null entries - non_null_offsets: VecDeque, -} - -impl WindowShiftEvaluator { - fn is_lag(&self) -> bool { - // Mode is LAG, when shift_offset is positive - self.shift_offset > 0 - } -} - -// implement ignore null for evaluate_all -fn evaluate_all_with_ignore_null( - array: &ArrayRef, - offset: i64, - default_value: &ScalarValue, - is_lag: bool, -) -> Result { - let valid_indices: Vec = - array.nulls().unwrap().valid_indices().collect::>(); - let direction = !is_lag; - let new_array_results: Result, DataFusionError> = (0..array.len()) - .map(|id| { - let result_index = match valid_indices.binary_search(&id) { - Ok(pos) => if direction { - pos.checked_add(offset as usize) - } else { - pos.checked_sub(offset.unsigned_abs() as usize) - } - .and_then(|new_pos| { - if new_pos < valid_indices.len() { - Some(valid_indices[new_pos]) - } else { - None - } - }), - Err(pos) => if direction { - pos.checked_add(offset as usize) - } else if pos > 0 { - pos.checked_sub(offset.unsigned_abs() as usize) - } else { - None - } - .and_then(|new_pos| { - if new_pos < valid_indices.len() { - Some(valid_indices[new_pos]) - } else { - None - } - }), - }; - - match result_index { - Some(index) => ScalarValue::try_from_array(array, index), - None => Ok(default_value.clone()), - } - }) - .collect(); - - let new_array = new_array_results?; - ScalarValue::iter_to_array(new_array) -} -// TODO: change the original arrow::compute::kernels::window::shift impl to support an optional default value -fn shift_with_default_value( - array: &ArrayRef, - offset: i64, - default_value: &ScalarValue, -) -> Result { - use arrow::compute::concat; - - let value_len = array.len() as i64; - if offset == 0 { - Ok(array.clone()) - } else if offset == i64::MIN || offset.abs() >= value_len { - default_value.to_array_of_size(value_len as usize) - } else { - let slice_offset = (-offset).clamp(0, value_len) as usize; - let length = array.len() - offset.unsigned_abs() as usize; - let slice = array.slice(slice_offset, length); - - // Generate array with remaining `null` items - let nulls = offset.unsigned_abs() as usize; - let default_values = default_value.to_array_of_size(nulls)?; - - // Concatenate both arrays, add nulls after if shift > 0 else before - if offset > 0 { - concat(&[default_values.as_ref(), slice.as_ref()]) - .map_err(|e| arrow_datafusion_err!(e)) - } else { - concat(&[slice.as_ref(), default_values.as_ref()]) - .map_err(|e| arrow_datafusion_err!(e)) - } - } -} - -impl PartitionEvaluator for WindowShiftEvaluator { - fn get_range(&self, idx: usize, n_rows: usize) -> Result> { - if self.is_lag() { - let start = if self.non_null_offsets.len() == self.shift_offset as usize { - // How many rows needed previous than the current row to get necessary lag result - let offset: usize = self.non_null_offsets.iter().sum(); - idx.saturating_sub(offset + 1) - } else { - 0 - }; - let end = idx + 1; - Ok(Range { start, end }) - } else { - let end = if self.non_null_offsets.len() == (-self.shift_offset) as usize { - // How many rows needed further than the current row to get necessary lead result - let offset: usize = self.non_null_offsets.iter().sum(); - min(idx + offset + 1, n_rows) - } else { - n_rows - }; - Ok(Range { start: idx, end }) - } - } - - fn is_causal(&self) -> bool { - // Lagging windows are causal by definition: - self.is_lag() - } - - fn evaluate( - &mut self, - values: &[ArrayRef], - range: &Range, - ) -> Result { - let array = &values[0]; - let len = array.len(); - - // LAG mode - let i = if self.is_lag() { - (range.end as i64 - self.shift_offset - 1) as usize - } else { - // LEAD mode - (range.start as i64 - self.shift_offset) as usize - }; - - let mut idx: Option = if i < len { Some(i) } else { None }; - - // LAG with IGNORE NULLS calculated as the current row index - offset, but only for non-NULL rows - // If current row index points to NULL value the row is NOT counted - if self.ignore_nulls && self.is_lag() { - // LAG when NULLS are ignored. - // Find the nonNULL row index that shifted by offset comparing to current row index - idx = if self.non_null_offsets.len() == self.shift_offset as usize { - let total_offset: usize = self.non_null_offsets.iter().sum(); - Some(range.end - 1 - total_offset) - } else { - None - }; - - // Keep track of offset values between non-null entries - if array.is_valid(range.end - 1) { - // Non-null add new offset - self.non_null_offsets.push_back(1); - if self.non_null_offsets.len() > self.shift_offset as usize { - // WE do not need to keep track of more than `lag number of offset` values. - self.non_null_offsets.pop_front(); - } - } else if !self.non_null_offsets.is_empty() { - // Entry is null, increment offset value of the last entry. - let end_idx = self.non_null_offsets.len() - 1; - self.non_null_offsets[end_idx] += 1; - } - } else if self.ignore_nulls && !self.is_lag() { - // LEAD when NULLS are ignored. - // Stores the necessary non-null entry number further than the current row. - let non_null_row_count = (-self.shift_offset) as usize; - - if self.non_null_offsets.is_empty() { - // When empty, fill non_null offsets with the data further than the current row. - let mut offset_val = 1; - for idx in range.start + 1..range.end { - if array.is_valid(idx) { - self.non_null_offsets.push_back(offset_val); - offset_val = 1; - } else { - offset_val += 1; - } - // It is enough to keep track of `non_null_row_count + 1` non-null offset. - // further data is unnecessary for the result. - if self.non_null_offsets.len() == non_null_row_count + 1 { - break; - } - } - } else if range.end < len && array.is_valid(range.end) { - // Update `non_null_offsets` with the new end data. - if array.is_valid(range.end) { - // When non-null, append a new offset. - self.non_null_offsets.push_back(1); - } else { - // When null, increment offset count of the last entry - let last_idx = self.non_null_offsets.len() - 1; - self.non_null_offsets[last_idx] += 1; - } - } - - // Find the nonNULL row index that shifted by offset comparing to current row index - idx = if self.non_null_offsets.len() >= non_null_row_count { - let total_offset: usize = - self.non_null_offsets.iter().take(non_null_row_count).sum(); - Some(range.start + total_offset) - } else { - None - }; - // Prune `self.non_null_offsets` from the start. so that at next iteration - // start of the `self.non_null_offsets` matches with current row. - if !self.non_null_offsets.is_empty() { - self.non_null_offsets[0] -= 1; - if self.non_null_offsets[0] == 0 { - // When offset is 0. Remove it. - self.non_null_offsets.pop_front(); - } - } - } - - // Set the default value if - // - index is out of window bounds - // OR - // - ignore nulls mode and current value is null and is within window bounds - // .unwrap() is safe here as there is a none check in front - #[allow(clippy::unnecessary_unwrap)] - if !(idx.is_none() || (self.ignore_nulls && array.is_null(idx.unwrap()))) { - ScalarValue::try_from_array(array, idx.unwrap()) - } else { - Ok(self.default_value.clone()) - } - } - - fn evaluate_all( - &mut self, - values: &[ArrayRef], - _num_rows: usize, - ) -> Result { - // LEAD, LAG window functions take single column, values will have size 1 - let value = &values[0]; - if !self.ignore_nulls { - shift_with_default_value(value, self.shift_offset, &self.default_value) - } else { - evaluate_all_with_ignore_null( - value, - self.shift_offset, - &self.default_value, - self.is_lag(), - ) - } - } - - fn supports_bounded_execution(&self) -> bool { - true - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::expressions::Column; - use arrow::{array::*, datatypes::*}; - use datafusion_common::cast::as_int32_array; - - fn test_i32_result(expr: WindowShift, expected: Int32Array) -> Result<()> { - let arr: ArrayRef = Arc::new(Int32Array::from(vec![1, -2, 3, -4, 5, -6, 7, 8])); - let values = vec![arr]; - let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]); - let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?; - let values = expr.evaluate_args(&batch)?; - let result = expr - .create_evaluator()? - .evaluate_all(&values, batch.num_rows())?; - let result = as_int32_array(&result)?; - assert_eq!(expected, *result); - Ok(()) - } - - #[test] - fn lead_lag_window_shift() -> Result<()> { - test_i32_result( - lead( - "lead".to_owned(), - DataType::Int32, - Arc::new(Column::new("c3", 0)), - None, - ScalarValue::Null.cast_to(&DataType::Int32)?, - false, - ), - [ - Some(-2), - Some(3), - Some(-4), - Some(5), - Some(-6), - Some(7), - Some(8), - None, - ] - .iter() - .collect::(), - )?; - - test_i32_result( - lag( - "lead".to_owned(), - DataType::Int32, - Arc::new(Column::new("c3", 0)), - None, - ScalarValue::Null.cast_to(&DataType::Int32)?, - false, - ), - [ - None, - Some(1), - Some(-2), - Some(3), - Some(-4), - Some(5), - Some(-6), - Some(7), - ] - .iter() - .collect::(), - )?; - - test_i32_result( - lag( - "lead".to_owned(), - DataType::Int32, - Arc::new(Column::new("c3", 0)), - None, - ScalarValue::Int32(Some(100)), - false, - ), - [ - Some(100), - Some(1), - Some(-2), - Some(3), - Some(-4), - Some(5), - Some(-6), - Some(7), - ] - .iter() - .collect::(), - )?; - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/window/mod.rs b/datafusion/physical-expr/src/window/mod.rs index 644edae36c9c..3c37fff7a1ba 100644 --- a/datafusion/physical-expr/src/window/mod.rs +++ b/datafusion/physical-expr/src/window/mod.rs @@ -18,12 +18,7 @@ mod aggregate; mod built_in; mod built_in_window_function_expr; -pub(crate) mod cume_dist; -pub(crate) mod lead_lag; pub(crate) mod nth_value; -pub(crate) mod ntile; -pub(crate) mod rank; -pub(crate) mod row_number; mod sliding_aggregate; mod window_expr; diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index 55d112e1f6e0..6ec3a23fc586 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -30,7 +30,7 @@ use crate::PhysicalExpr; use arrow::array::{Array, ArrayRef}; use arrow::datatypes::{DataType, Field}; use datafusion_common::Result; -use datafusion_common::{exec_err, ScalarValue}; +use datafusion_common::ScalarValue; use datafusion_expr::window_state::WindowAggState; use datafusion_expr::PartitionEvaluator; @@ -86,16 +86,13 @@ impl NthValue { n: i64, ignore_nulls: bool, ) -> Result { - match n { - 0 => exec_err!("NTH_VALUE expects n to be non-zero"), - _ => Ok(Self { - name: name.into(), - expr, - data_type, - kind: NthValueKind::Nth(n), - ignore_nulls, - }), - } + Ok(Self { + name: name.into(), + expr, + data_type, + kind: NthValueKind::Nth(n), + ignore_nulls, + }) } /// Get the NTH_VALUE kind @@ -116,7 +113,7 @@ impl BuiltInWindowFunctionExpr for NthValue { } fn expressions(&self) -> Vec> { - vec![self.expr.clone()] + vec![Arc::clone(&self.expr)] } fn name(&self) -> &str { @@ -125,7 +122,6 @@ impl BuiltInWindowFunctionExpr for NthValue { fn create_evaluator(&self) -> Result> { let state = NthValueState { - range: Default::default(), finalized_result: None, kind: self.kind, }; @@ -143,7 +139,7 @@ impl BuiltInWindowFunctionExpr for NthValue { }; Some(Arc::new(Self { name: self.name.clone(), - expr: self.expr.clone(), + expr: Arc::clone(&self.expr), data_type: self.data_type.clone(), kind: reversed_kind, ignore_nulls: self.ignore_nulls, @@ -189,10 +185,7 @@ impl PartitionEvaluator for NthValueEvaluator { // Negative index represents reverse direction. (n_range >= reverse_index, true) } - Ordering::Equal => { - // The case n = 0 is not valid for the NTH_VALUE function. - unreachable!(); - } + Ordering::Equal => (false, false), } } }; @@ -299,10 +292,7 @@ impl PartitionEvaluator for NthValueEvaluator { ) } } - Ordering::Equal => { - // The case n = 0 is not valid for the NTH_VALUE function. - unreachable!(); - } + Ordering::Equal => ScalarValue::try_from(arr.data_type()), } } } diff --git a/datafusion/physical-expr/src/window/ntile.rs b/datafusion/physical-expr/src/window/ntile.rs deleted file mode 100644 index fb7a7ad84fb7..000000000000 --- a/datafusion/physical-expr/src/window/ntile.rs +++ /dev/null @@ -1,111 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expression for `ntile` that can evaluated -//! at runtime during query execution - -use crate::expressions::Column; -use crate::window::BuiltInWindowFunctionExpr; -use crate::{PhysicalExpr, PhysicalSortExpr}; - -use arrow::array::{ArrayRef, UInt64Array}; -use arrow::datatypes::Field; -use arrow_schema::{DataType, SchemaRef, SortOptions}; -use datafusion_common::Result; -use datafusion_expr::PartitionEvaluator; - -use std::any::Any; -use std::sync::Arc; - -#[derive(Debug)] -pub struct Ntile { - name: String, - n: u64, - /// Output data type - data_type: DataType, -} - -impl Ntile { - pub fn new(name: String, n: u64, data_type: &DataType) -> Self { - Self { - name, - n, - data_type: data_type.clone(), - } - } - - pub fn get_n(&self) -> u64 { - self.n - } -} - -impl BuiltInWindowFunctionExpr for Ntile { - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - let nullable = false; - Ok(Field::new(self.name(), self.data_type.clone(), nullable)) - } - - fn expressions(&self) -> Vec> { - vec![] - } - - fn name(&self) -> &str { - &self.name - } - - fn create_evaluator(&self) -> Result> { - Ok(Box::new(NtileEvaluator { n: self.n })) - } - - fn get_result_ordering(&self, schema: &SchemaRef) -> Option { - // The built-in NTILE window function introduces a new ordering: - schema.column_with_name(self.name()).map(|(idx, field)| { - let expr = Arc::new(Column::new(field.name(), idx)); - let options = SortOptions { - descending: false, - nulls_first: false, - }; // ASC, NULLS LAST - PhysicalSortExpr { expr, options } - }) - } -} - -#[derive(Debug)] -pub(crate) struct NtileEvaluator { - n: u64, -} - -impl PartitionEvaluator for NtileEvaluator { - fn evaluate_all( - &mut self, - _values: &[ArrayRef], - num_rows: usize, - ) -> Result { - let num_rows = num_rows as u64; - let mut vec: Vec = Vec::new(); - let n = u64::min(self.n, num_rows); - for i in 0..num_rows { - let res = i * n / num_rows; - vec.push(res + 1) - } - Ok(Arc::new(UInt64Array::from(vec))) - } -} diff --git a/datafusion/physical-expr/src/window/row_number.rs b/datafusion/physical-expr/src/window/row_number.rs deleted file mode 100644 index 0a1255018d30..000000000000 --- a/datafusion/physical-expr/src/window/row_number.rs +++ /dev/null @@ -1,166 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expression for `row_number` that can evaluated at runtime during query execution - -use crate::expressions::Column; -use crate::window::window_expr::NumRowsState; -use crate::window::BuiltInWindowFunctionExpr; -use crate::{PhysicalExpr, PhysicalSortExpr}; - -use arrow::array::{ArrayRef, UInt64Array}; -use arrow::datatypes::{DataType, Field}; -use arrow_schema::{SchemaRef, SortOptions}; -use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::PartitionEvaluator; - -use std::any::Any; -use std::ops::Range; -use std::sync::Arc; - -/// row_number expression -#[derive(Debug)] -pub struct RowNumber { - name: String, - /// Output data type - data_type: DataType, -} - -impl RowNumber { - /// Create a new ROW_NUMBER function - pub fn new(name: impl Into, data_type: &DataType) -> Self { - Self { - name: name.into(), - data_type: data_type.clone(), - } - } -} - -impl BuiltInWindowFunctionExpr for RowNumber { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - let nullable = false; - Ok(Field::new(self.name(), self.data_type.clone(), nullable)) - } - - fn expressions(&self) -> Vec> { - vec![] - } - - fn name(&self) -> &str { - &self.name - } - - fn get_result_ordering(&self, schema: &SchemaRef) -> Option { - // The built-in ROW_NUMBER window function introduces a new ordering: - schema.column_with_name(self.name()).map(|(idx, field)| { - let expr = Arc::new(Column::new(field.name(), idx)); - let options = SortOptions { - descending: false, - nulls_first: false, - }; // ASC, NULLS LAST - PhysicalSortExpr { expr, options } - }) - } - - fn create_evaluator(&self) -> Result> { - Ok(Box::::default()) - } -} - -#[derive(Default, Debug)] -pub(crate) struct NumRowsEvaluator { - state: NumRowsState, -} - -impl PartitionEvaluator for NumRowsEvaluator { - fn is_causal(&self) -> bool { - // The ROW_NUMBER function doesn't need "future" values to emit results: - true - } - - /// evaluate window function result inside given range - fn evaluate( - &mut self, - _values: &[ArrayRef], - _range: &Range, - ) -> Result { - self.state.n_rows += 1; - Ok(ScalarValue::UInt64(Some(self.state.n_rows as u64))) - } - - fn evaluate_all( - &mut self, - _values: &[ArrayRef], - num_rows: usize, - ) -> Result { - Ok(Arc::new(UInt64Array::from_iter_values( - 1..(num_rows as u64) + 1, - ))) - } - - fn supports_bounded_execution(&self) -> bool { - true - } -} - -#[cfg(test)] -mod tests { - use super::*; - use arrow::{array::*, datatypes::*}; - use datafusion_common::cast::as_uint64_array; - - #[test] - fn row_number_all_null() -> Result<()> { - let arr: ArrayRef = Arc::new(BooleanArray::from(vec![ - None, None, None, None, None, None, None, None, - ])); - let schema = Schema::new(vec![Field::new("arr", DataType::Boolean, true)]); - let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?; - let row_number = RowNumber::new("row_number".to_owned(), &DataType::UInt64); - let values = row_number.evaluate_args(&batch)?; - let result = row_number - .create_evaluator()? - .evaluate_all(&values, batch.num_rows())?; - let result = as_uint64_array(&result)?; - let result = result.values(); - assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], *result); - Ok(()) - } - - #[test] - fn row_number_all_values() -> Result<()> { - let arr: ArrayRef = Arc::new(BooleanArray::from(vec![ - true, false, true, false, false, true, false, true, - ])); - let schema = Schema::new(vec![Field::new("arr", DataType::Boolean, false)]); - let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?; - let row_number = RowNumber::new("row_number".to_owned(), &DataType::UInt64); - let values = row_number.evaluate_args(&batch)?; - let result = row_number - .create_evaluator()? - .evaluate_all(&values, batch.num_rows())?; - let result = as_uint64_array(&result)?; - let result = result.values(); - assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], *result); - Ok(()) - } -} diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index 1494129cf897..1e46baae7b0a 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -25,16 +25,15 @@ use arrow::array::{Array, ArrayRef}; use arrow::datatypes::Field; use arrow::record_batch::RecordBatch; -use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::{Accumulator, WindowFrame}; - +use crate::aggregate::AggregateFunctionExpr; use crate::window::window_expr::AggregateWindowExpr; use crate::window::{ PartitionBatches, PartitionWindowAggStates, PlainAggregateWindowExpr, WindowExpr, }; -use crate::{ - expressions::PhysicalSortExpr, reverse_order_bys, AggregateExpr, PhysicalExpr, -}; +use crate::{expressions::PhysicalSortExpr, reverse_order_bys, PhysicalExpr}; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::{Accumulator, WindowFrame}; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexOrderingRef}; /// A window expr that takes the form of an aggregate function that /// can be incrementally computed over sliding windows. @@ -42,30 +41,30 @@ use crate::{ /// See comments on [`WindowExpr`] for more details. #[derive(Debug)] pub struct SlidingAggregateWindowExpr { - aggregate: Arc, + aggregate: Arc, partition_by: Vec>, - order_by: Vec, + order_by: LexOrdering, window_frame: Arc, } impl SlidingAggregateWindowExpr { /// Create a new (sliding) aggregate window function expression. pub fn new( - aggregate: Arc, + aggregate: Arc, partition_by: &[Arc], - order_by: &[PhysicalSortExpr], + order_by: LexOrderingRef, window_frame: Arc, ) -> Self { Self { aggregate, partition_by: partition_by.to_vec(), - order_by: order_by.to_vec(), + order_by: LexOrdering::from_ref(order_by), window_frame, } } - /// Get the [AggregateExpr] of this object. - pub fn get_aggregate_expr(&self) -> &Arc { + /// Get the [AggregateFunctionExpr] of this object. + pub fn get_aggregate_expr(&self) -> &AggregateFunctionExpr { &self.aggregate } } @@ -82,7 +81,7 @@ impl WindowExpr for SlidingAggregateWindowExpr { } fn field(&self) -> Result { - self.aggregate.field() + Ok(self.aggregate.field()) } fn name(&self) -> &str { @@ -109,8 +108,8 @@ impl WindowExpr for SlidingAggregateWindowExpr { &self.partition_by } - fn order_by(&self) -> &[PhysicalSortExpr] { - &self.order_by + fn order_by(&self) -> LexOrderingRef { + self.order_by.as_ref() } fn get_window_frame(&self) -> &Arc { @@ -122,16 +121,16 @@ impl WindowExpr for SlidingAggregateWindowExpr { let reverse_window_frame = self.window_frame.reverse(); if reverse_window_frame.start_bound.is_unbounded() { Arc::new(PlainAggregateWindowExpr::new( - reverse_expr, + Arc::new(reverse_expr), &self.partition_by.clone(), - &reverse_order_bys(&self.order_by), + reverse_order_bys(self.order_by.as_ref()).as_ref(), Arc::new(self.window_frame.reverse()), )) as _ } else { Arc::new(SlidingAggregateWindowExpr::new( - reverse_expr, + Arc::new(reverse_expr), &self.partition_by.clone(), - &reverse_order_bys(&self.order_by), + reverse_order_bys(self.order_by.as_ref()).as_ref(), Arc::new(self.window_frame.reverse()), )) as _ } @@ -141,6 +140,34 @@ impl WindowExpr for SlidingAggregateWindowExpr { fn uses_bounded_memory(&self) -> bool { !self.window_frame.end_bound.is_unbounded() } + + fn with_new_expressions( + &self, + args: Vec>, + partition_bys: Vec>, + order_by_exprs: Vec>, + ) -> Option> { + debug_assert_eq!(self.order_by.len(), order_by_exprs.len()); + + let new_order_by = self + .order_by + .iter() + .zip(order_by_exprs) + .map(|(req, new_expr)| PhysicalSortExpr { + expr: new_expr, + options: req.options, + }) + .collect::(); + Some(Arc::new(SlidingAggregateWindowExpr { + aggregate: self + .aggregate + .with_new_expressions(args, vec![]) + .map(Arc::new)?, + partition_by: partition_bys, + order_by: new_order_by, + window_frame: Arc::clone(&self.window_frame), + })) + } } impl AggregateWindowExpr for SlidingAggregateWindowExpr { @@ -158,8 +185,8 @@ impl AggregateWindowExpr for SlidingAggregateWindowExpr { accumulator: &mut Box, ) -> Result { if cur_range.start == cur_range.end { - // We produce None if the window is empty. - ScalarValue::try_from(self.aggregate.field()?.data_type()) + self.aggregate + .default_value(self.aggregate.field().data_type()) } else { // Accumulate any new rows that have entered the window: let update_bound = cur_range.end - last_range.end; @@ -170,6 +197,7 @@ impl AggregateWindowExpr for SlidingAggregateWindowExpr { .collect(); accumulator.update_batch(&update)? } + // Remove rows that have now left the window: let retract_bound = cur_range.start - last_range.start; if retract_bound > 0 { diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index dd9514c69a45..0f882def4433 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -20,7 +20,7 @@ use std::fmt::Debug; use std::ops::Range; use std::sync::Arc; -use crate::{LexOrderingRef, PhysicalExpr, PhysicalSortExpr}; +use crate::{LexOrderingRef, PhysicalExpr}; use arrow::array::{new_empty_array, Array, ArrayRef}; use arrow::compute::kernels::sort::SortColumn; @@ -109,7 +109,7 @@ pub trait WindowExpr: Send + Sync + Debug { fn partition_by(&self) -> &[Arc]; /// Expressions that's from the window function's order by clause, empty if absent - fn order_by(&self) -> &[PhysicalSortExpr]; + fn order_by(&self) -> LexOrderingRef; /// Get order by columns, empty if absent fn order_by_columns(&self, batch: &RecordBatch) -> Result> { @@ -128,6 +128,45 @@ pub trait WindowExpr: Send + Sync + Debug { /// Get the reverse expression of this [WindowExpr]. fn get_reverse_expr(&self) -> Option>; + + /// Returns all expressions used in the [`WindowExpr`]. + /// These expressions are (1) function arguments, (2) partition by expressions, (3) order by expressions. + fn all_expressions(&self) -> WindowPhysicalExpressions { + let args = self.expressions(); + let partition_by_exprs = self.partition_by().to_vec(); + let order_by_exprs = self + .order_by() + .iter() + .map(|sort_expr| Arc::clone(&sort_expr.expr)) + .collect::>(); + WindowPhysicalExpressions { + args, + partition_by_exprs, + order_by_exprs, + } + } + + /// Rewrites [`WindowExpr`], with new expressions given. The argument should be consistent + /// with the return value of the [`WindowExpr::all_expressions`] method. + /// Returns `Some(Arc)` if re-write is supported, otherwise returns `None`. + fn with_new_expressions( + &self, + _args: Vec>, + _partition_bys: Vec>, + _order_by_exprs: Vec>, + ) -> Option> { + None + } +} + +/// Stores the physical expressions used inside the `WindowExpr`. +pub struct WindowPhysicalExpressions { + /// Window function arguments + pub args: Vec>, + /// PARTITION BY expressions + pub partition_by_exprs: Vec>, + /// ORDER BY expressions + pub order_by_exprs: Vec>, } /// Extension trait that adds common functionality to [`AggregateWindowExpr`]s @@ -154,7 +193,7 @@ pub trait AggregateWindowExpr: WindowExpr { let sort_options: Vec = self.order_by().iter().map(|o| o.options).collect(); let mut window_frame_ctx = - WindowFrameContext::new(self.get_window_frame().clone(), sort_options); + WindowFrameContext::new(Arc::clone(self.get_window_frame()), sort_options); self.get_result_column( &mut accumulator, batch, @@ -202,7 +241,7 @@ pub trait AggregateWindowExpr: WindowExpr { let window_frame_ctx = state.window_frame_ctx.get_or_insert_with(|| { let sort_options: Vec = self.order_by().iter().map(|o| o.options).collect(); - WindowFrameContext::new(self.get_window_frame().clone(), sort_options) + WindowFrameContext::new(Arc::clone(self.get_window_frame()), sort_options) }); let out_col = self.get_result_column( accumulator, @@ -491,25 +530,6 @@ pub enum WindowFn { Aggregate(Box), } -/// State for the RANK(percent_rank, rank, dense_rank) built-in window function. -#[derive(Debug, Clone, Default)] -pub struct RankState { - /// The last values for rank as these values change, we increase n_rank - pub last_rank_data: Option>, - /// The index where last_rank_boundary is started - pub last_rank_boundary: usize, - /// Keep the number of entries in current rank - pub current_group_count: usize, - /// Rank number kept from the start - pub n_rank: usize, -} - -/// State for the 'ROW_NUMBER' built-in window function. -#[derive(Debug, Clone, Default)] -pub struct NumRowsState { - pub n_rows: usize, -} - /// Tag to differentiate special use cases of the NTH_VALUE built-in window function. #[derive(Debug, Copy, Clone)] pub enum NthValueKind { @@ -520,7 +540,6 @@ pub enum NthValueKind { #[derive(Debug, Clone)] pub struct NthValueState { - pub range: Range, // In certain cases, we can finalize the result early. Consider this usage: // ``` // FIRST_VALUE(increasing_col) OVER window AS my_first_value diff --git a/datafusion/physical-optimizer/Cargo.toml b/datafusion/physical-optimizer/Cargo.toml new file mode 100644 index 000000000000..e7bf4a80fc45 --- /dev/null +++ b/datafusion/physical-optimizer/Cargo.toml @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-physical-optimizer" +description = "DataFusion Physical Optimizer" +keywords = ["datafusion", "query", "optimizer"] +readme = "README.md" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +authors = { workspace = true } +rust-version = { workspace = true } + +[lints] +workspace = true + +[dependencies] +arrow = { workspace = true } +arrow-schema = { workspace = true } +datafusion-common = { workspace = true, default-features = true } +datafusion-execution = { workspace = true } +datafusion-expr-common = { workspace = true, default-features = true } +datafusion-physical-expr = { workspace = true } +datafusion-physical-plan = { workspace = true } +itertools = { workspace = true } + +[dev-dependencies] +datafusion-functions-aggregate = { workspace = true } +tokio = { workspace = true } diff --git a/datafusion/physical-optimizer/README.md b/datafusion/physical-optimizer/README.md new file mode 100644 index 000000000000..eb361d3f6779 --- /dev/null +++ b/datafusion/physical-optimizer/README.md @@ -0,0 +1,25 @@ + + +# DataFusion Physical Optimizer + +DataFusion is an extensible query execution framework, written in Rust, +that uses Apache Arrow as its in-memory format. + +This crate contains the physical optimizer for DataFusion. diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/physical-optimizer/src/aggregate_statistics.rs similarity index 61% rename from datafusion/core/src/physical_optimizer/aggregate_statistics.rs rename to datafusion/physical-optimizer/src/aggregate_statistics.rs index 505748860388..27870c7865f3 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/physical-optimizer/src/aggregate_statistics.rs @@ -18,21 +18,20 @@ //! Utilizing exact statistics from sources to avoid scanning data use std::sync::Arc; -use super::optimizer::PhysicalOptimizerRule; -use crate::config::ConfigOptions; -use crate::error::Result; -use crate::physical_plan::aggregates::AggregateExec; -use crate::physical_plan::projection::ProjectionExec; -use crate::physical_plan::{expressions, AggregateExpr, ExecutionPlan, Statistics}; -use crate::scalar::ScalarValue; - -use datafusion_common::stats::Precision; +use datafusion_common::config::ConfigOptions; +use datafusion_common::scalar::ScalarValue; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_expr::utils::COUNT_STAR_EXPANSION; +use datafusion_common::Result; +use datafusion_physical_plan::aggregates::AggregateExec; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; +use datafusion_physical_plan::projection::ProjectionExec; +use datafusion_physical_plan::udaf::{AggregateFunctionExpr, StatisticsArgs}; +use datafusion_physical_plan::{expressions, ExecutionPlan}; + +use crate::PhysicalOptimizerRule; /// Optimizer that uses available statistics for aggregate functions -#[derive(Default)] +#[derive(Default, Debug)] pub struct AggregateStatistics {} impl AggregateStatistics { @@ -56,18 +55,19 @@ impl PhysicalOptimizerRule for AggregateStatistics { let stats = partial_agg_exec.input().statistics()?; let mut projections = vec![]; for expr in partial_agg_exec.aggr_expr() { - if let Some((non_null_rows, name)) = - take_optimizable_column_count(&**expr, &stats) - { - projections.push((expressions::lit(non_null_rows), name.to_owned())); - } else if let Some((num_rows, name)) = - take_optimizable_table_count(&**expr, &stats) + let field = expr.field(); + let args = expr.expressions(); + let statistics_args = StatisticsArgs { + statistics: &stats, + return_type: field.data_type(), + is_distinct: expr.is_distinct(), + exprs: args.as_slice(), + }; + if let Some((optimizable_statistic, name)) = + take_optimizable_value_from_statistics(&statistics_args, expr) { - projections.push((expressions::lit(num_rows), name.to_owned())); - } else if let Some((min, name)) = take_optimizable_min(&**expr, &stats) { - projections.push((expressions::lit(min), name.to_owned())); - } else if let Some((max, name)) = take_optimizable_max(&**expr, &stats) { - projections.push((expressions::lit(max), name.to_owned())); + projections + .push((expressions::lit(optimizable_statistic), name.to_owned())); } else { // TODO: we need all aggr_expr to be resolved (cf TODO fullres) break; @@ -106,6 +106,7 @@ impl PhysicalOptimizerRule for AggregateStatistics { /// assert if the node passed as argument is a final `AggregateExec` node that can be optimized: /// - its child (with possible intermediate layers) is a partial `AggregateExec` node /// - they both have no grouping expression +/// /// If this is the case, return a ref to the partial `AggregateExec`, else `None`. /// We would have preferred to return a casted ref to AggregateExec but the recursion requires /// the `ExecutionPlan.children()` method that returns an owned reference. @@ -126,7 +127,7 @@ fn take_optimizable(node: &dyn ExecutionPlan) -> Option> return Some(child); } } - if let [ref childrens_child] = child.children().as_slice() { + if let [childrens_child] = child.children().as_slice() { child = Arc::clone(childrens_child); } else { break; @@ -137,182 +138,100 @@ fn take_optimizable(node: &dyn ExecutionPlan) -> Option> None } -/// If this agg_expr is a count that is exactly defined in the statistics, return it. -fn take_optimizable_table_count( - agg_expr: &dyn AggregateExpr, - stats: &Statistics, +/// If this agg_expr is a max that is exactly defined in the statistics, return it. +fn take_optimizable_value_from_statistics( + statistics_args: &StatisticsArgs, + agg_expr: &AggregateFunctionExpr, ) -> Option<(ScalarValue, String)> { - if let (&Precision::Exact(num_rows), Some(casted_expr)) = ( - &stats.num_rows, - agg_expr.as_any().downcast_ref::(), - ) { - // TODO implementing Eq on PhysicalExpr would help a lot here - if casted_expr.expressions().len() == 1 { - if let Some(lit_expr) = casted_expr.expressions()[0] - .as_any() - .downcast_ref::() - { - if lit_expr.value() == &COUNT_STAR_EXPANSION { - return Some(( - ScalarValue::Int64(Some(num_rows as i64)), - casted_expr.name().to_owned(), - )); - } - } - } - } - None + let value = agg_expr.fun().value_from_stats(statistics_args); + value.map(|val| (val, agg_expr.name().to_string())) } -/// If this agg_expr is a count that can be exactly derived from the statistics, return it. -fn take_optimizable_column_count( - agg_expr: &dyn AggregateExpr, - stats: &Statistics, -) -> Option<(ScalarValue, String)> { - let col_stats = &stats.column_statistics; - if let (&Precision::Exact(num_rows), Some(casted_expr)) = ( - &stats.num_rows, - agg_expr.as_any().downcast_ref::(), - ) { - if casted_expr.expressions().len() == 1 { - // TODO optimize with exprs other than Column - if let Some(col_expr) = casted_expr.expressions()[0] - .as_any() - .downcast_ref::() - { - let current_val = &col_stats[col_expr.index()].null_count; - if let &Precision::Exact(val) = current_val { - return Some(( - ScalarValue::Int64(Some((num_rows - val) as i64)), - casted_expr.name().to_string(), - )); - } - } - } +#[cfg(test)] +mod tests { + use crate::aggregate_statistics::AggregateStatistics; + use crate::PhysicalOptimizerRule; + use datafusion_common::config::ConfigOptions; + use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; + use datafusion_execution::TaskContext; + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_physical_expr::aggregate::AggregateExprBuilder; + use datafusion_physical_expr::PhysicalExpr; + use datafusion_physical_plan::aggregates::AggregateExec; + use datafusion_physical_plan::projection::ProjectionExec; + use datafusion_physical_plan::udaf::AggregateFunctionExpr; + use datafusion_physical_plan::ExecutionPlan; + use std::sync::Arc; + + use datafusion_common::Result; + use datafusion_expr_common::operator::Operator; + + use datafusion_physical_plan::aggregates::PhysicalGroupBy; + use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; + use datafusion_physical_plan::common; + use datafusion_physical_plan::filter::FilterExec; + use datafusion_physical_plan::memory::MemoryExec; + + use arrow::array::Int32Array; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use datafusion_common::cast::as_int64_array; + use datafusion_physical_expr::expressions::{self, cast}; + use datafusion_physical_plan::aggregates::AggregateMode; + + /// Describe the type of aggregate being tested + pub enum TestAggregate { + /// Testing COUNT(*) type aggregates + CountStar, + + /// Testing for COUNT(column) aggregate + ColumnA(Arc), } - None -} -/// If this agg_expr is a min that is exactly defined in the statistics, return it. -fn take_optimizable_min( - agg_expr: &dyn AggregateExpr, - stats: &Statistics, -) -> Option<(ScalarValue, String)> { - if let Precision::Exact(num_rows) = &stats.num_rows { - match *num_rows { - 0 => { - // MIN/MAX with 0 rows is always null - if let Some(casted_expr) = - agg_expr.as_any().downcast_ref::() - { - if let Ok(min_data_type) = - ScalarValue::try_from(casted_expr.field().unwrap().data_type()) - { - return Some((min_data_type, casted_expr.name().to_string())); - } - } - } - value if value > 0 => { - let col_stats = &stats.column_statistics; - if let Some(casted_expr) = - agg_expr.as_any().downcast_ref::() - { - if casted_expr.expressions().len() == 1 { - // TODO optimize with exprs other than Column - if let Some(col_expr) = casted_expr.expressions()[0] - .as_any() - .downcast_ref::() - { - if let Precision::Exact(val) = - &col_stats[col_expr.index()].min_value - { - if !val.is_null() { - return Some(( - val.clone(), - casted_expr.name().to_string(), - )); - } - } - } - } - } + impl TestAggregate { + /// Create a new COUNT(*) aggregate + pub fn new_count_star() -> Self { + Self::CountStar + } + + /// Create a new COUNT(column) aggregate + pub fn new_count_column(schema: &Arc) -> Self { + Self::ColumnA(Arc::clone(schema)) + } + + /// Return appropriate expr depending if COUNT is for col or table (*) + pub fn count_expr(&self, schema: &Schema) -> AggregateFunctionExpr { + AggregateExprBuilder::new(count_udaf(), vec![self.column()]) + .schema(Arc::new(schema.clone())) + .alias(self.column_name()) + .build() + .unwrap() + } + + /// what argument would this aggregate need in the plan? + fn column(&self) -> Arc { + match self { + Self::CountStar => expressions::lit(COUNT_STAR_EXPANSION), + Self::ColumnA(s) => expressions::col("a", s).unwrap(), } - _ => {} } - } - None -} -/// If this agg_expr is a max that is exactly defined in the statistics, return it. -fn take_optimizable_max( - agg_expr: &dyn AggregateExpr, - stats: &Statistics, -) -> Option<(ScalarValue, String)> { - if let Precision::Exact(num_rows) = &stats.num_rows { - match *num_rows { - 0 => { - // MIN/MAX with 0 rows is always null - if let Some(casted_expr) = - agg_expr.as_any().downcast_ref::() - { - if let Ok(max_data_type) = - ScalarValue::try_from(casted_expr.field().unwrap().data_type()) - { - return Some((max_data_type, casted_expr.name().to_string())); - } - } + /// What name would this aggregate produce in a plan? + pub fn column_name(&self) -> &'static str { + match self { + Self::CountStar => "COUNT(*)", + Self::ColumnA(_) => "COUNT(a)", } - value if value > 0 => { - let col_stats = &stats.column_statistics; - if let Some(casted_expr) = - agg_expr.as_any().downcast_ref::() - { - if casted_expr.expressions().len() == 1 { - // TODO optimize with exprs other than Column - if let Some(col_expr) = casted_expr.expressions()[0] - .as_any() - .downcast_ref::() - { - if let Precision::Exact(val) = - &col_stats[col_expr.index()].max_value - { - if !val.is_null() { - return Some(( - val.clone(), - casted_expr.name().to_string(), - )); - } - } - } - } - } + } + + /// What is the expected count? + pub fn expected_count(&self) -> i64 { + match self { + TestAggregate::CountStar => 3, + TestAggregate::ColumnA(_) => 2, } - _ => {} } } - None -} - -#[cfg(test)] -pub(crate) mod tests { - - use super::*; - use crate::logical_expr::Operator; - use crate::physical_plan::aggregates::PhysicalGroupBy; - use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; - use crate::physical_plan::common; - use crate::physical_plan::expressions::Count; - use crate::physical_plan::filter::FilterExec; - use crate::physical_plan::memory::MemoryExec; - use crate::prelude::SessionContext; - - use arrow::array::Int32Array; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow::record_batch::RecordBatch; - use datafusion_common::cast::as_int64_array; - use datafusion_physical_expr::expressions::cast; - use datafusion_physical_expr::PhysicalExpr; - use datafusion_physical_plan::aggregates::AggregateMode; /// Mock data using a MemoryExec which has an exact count statistic fn mock_data() -> Result> { @@ -341,21 +260,20 @@ pub(crate) mod tests { plan: AggregateExec, agg: TestAggregate, ) -> Result<()> { - let session_ctx = SessionContext::new(); - let state = session_ctx.state(); + let task_ctx = Arc::new(TaskContext::default()); let plan: Arc = Arc::new(plan); - let optimized = AggregateStatistics::new() - .optimize(Arc::clone(&plan), state.config_options())?; + let config = ConfigOptions::new(); + let optimized = + AggregateStatistics::new().optimize(Arc::clone(&plan), &config)?; // A ProjectionExec is a sign that the count optimization was applied assert!(optimized.as_any().is::()); // run both the optimized and nonoptimized plan let optimized_result = - common::collect(optimized.execute(0, session_ctx.task_ctx())?).await?; - let nonoptimized_result = - common::collect(plan.execute(0, session_ctx.task_ctx())?).await?; + common::collect(optimized.execute(0, Arc::clone(&task_ctx))?).await?; + let nonoptimized_result = common::collect(plan.execute(0, task_ctx)?).await?; assert_eq!(optimized_result.len(), nonoptimized_result.len()); // and validate the results are the same and expected @@ -384,58 +302,6 @@ pub(crate) mod tests { ); } - /// Describe the type of aggregate being tested - pub(crate) enum TestAggregate { - /// Testing COUNT(*) type aggregates - CountStar, - - /// Testing for COUNT(column) aggregate - ColumnA(Arc), - } - - impl TestAggregate { - pub(crate) fn new_count_star() -> Self { - Self::CountStar - } - - fn new_count_column(schema: &Arc) -> Self { - Self::ColumnA(schema.clone()) - } - - /// Return appropriate expr depending if COUNT is for col or table (*) - pub(crate) fn count_expr(&self) -> Arc { - Arc::new(Count::new( - self.column(), - self.column_name(), - DataType::Int64, - )) - } - - /// what argument would this aggregate need in the plan? - fn column(&self) -> Arc { - match self { - Self::CountStar => expressions::lit(COUNT_STAR_EXPANSION), - Self::ColumnA(s) => expressions::col("a", s).unwrap(), - } - } - - /// What name would this aggregate produce in a plan? - fn column_name(&self) -> &'static str { - match self { - Self::CountStar => "COUNT(*)", - Self::ColumnA(_) => "COUNT(a)", - } - } - - /// What is the expected count? - fn expected_count(&self) -> i64 { - match self { - TestAggregate::CountStar => 3, - TestAggregate::ColumnA(_) => 2, - } - } - } - #[tokio::test] async fn test_count_partial_direct_child() -> Result<()> { // basic test case with the aggregation applied on a source with exact statistics @@ -446,7 +312,7 @@ pub(crate) mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![Arc::new(agg.count_expr(&schema))], vec![None], source, Arc::clone(&schema), @@ -455,7 +321,7 @@ pub(crate) mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![Arc::new(agg.count_expr(&schema))], vec![None], Arc::new(partial_agg), Arc::clone(&schema), @@ -476,7 +342,7 @@ pub(crate) mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![Arc::new(agg.count_expr(&schema))], vec![None], source, Arc::clone(&schema), @@ -485,7 +351,7 @@ pub(crate) mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![Arc::new(agg.count_expr(&schema))], vec![None], Arc::new(partial_agg), Arc::clone(&schema), @@ -505,7 +371,7 @@ pub(crate) mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![Arc::new(agg.count_expr(&schema))], vec![None], source, Arc::clone(&schema), @@ -517,7 +383,7 @@ pub(crate) mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![Arc::new(agg.count_expr(&schema))], vec![None], Arc::new(coalesce), Arc::clone(&schema), @@ -537,7 +403,7 @@ pub(crate) mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![Arc::new(agg.count_expr(&schema))], vec![None], source, Arc::clone(&schema), @@ -549,7 +415,7 @@ pub(crate) mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![Arc::new(agg.count_expr(&schema))], vec![None], Arc::new(coalesce), Arc::clone(&schema), @@ -580,7 +446,7 @@ pub(crate) mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![Arc::new(agg.count_expr(&schema))], vec![None], filter, Arc::clone(&schema), @@ -589,7 +455,7 @@ pub(crate) mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![Arc::new(agg.count_expr(&schema))], vec![None], Arc::new(partial_agg), Arc::clone(&schema), @@ -625,7 +491,7 @@ pub(crate) mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![Arc::new(agg.count_expr(&schema))], vec![None], filter, Arc::clone(&schema), @@ -634,7 +500,7 @@ pub(crate) mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::default(), - vec![agg.count_expr()], + vec![Arc::new(agg.count_expr(&schema))], vec![None], Arc::new(partial_agg), Arc::clone(&schema), diff --git a/datafusion/physical-optimizer/src/combine_partial_final_agg.rs b/datafusion/physical-optimizer/src/combine_partial_final_agg.rs new file mode 100644 index 000000000000..86f7e73e9e35 --- /dev/null +++ b/datafusion/physical-optimizer/src/combine_partial_final_agg.rs @@ -0,0 +1,164 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! CombinePartialFinalAggregate optimizer rule checks the adjacent Partial and Final AggregateExecs +//! and try to combine them if necessary + +use std::sync::Arc; + +use datafusion_common::error::Result; +use datafusion_physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, +}; +use datafusion_physical_plan::ExecutionPlan; + +use crate::PhysicalOptimizerRule; +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_physical_expr::aggregate::AggregateFunctionExpr; +use datafusion_physical_expr::{physical_exprs_equal, PhysicalExpr}; + +/// CombinePartialFinalAggregate optimizer rule combines the adjacent Partial and Final AggregateExecs +/// into a Single AggregateExec if their grouping exprs and aggregate exprs equal. +/// +/// This rule should be applied after the EnforceDistribution and EnforceSorting rules +/// +#[derive(Default, Debug)] +pub struct CombinePartialFinalAggregate {} + +impl CombinePartialFinalAggregate { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl PhysicalOptimizerRule for CombinePartialFinalAggregate { + fn optimize( + &self, + plan: Arc, + _config: &ConfigOptions, + ) -> Result> { + plan.transform_down(|plan| { + // Check if the plan is AggregateExec + let Some(agg_exec) = plan.as_any().downcast_ref::() else { + return Ok(Transformed::no(plan)); + }; + + if !matches!( + agg_exec.mode(), + AggregateMode::Final | AggregateMode::FinalPartitioned + ) { + return Ok(Transformed::no(plan)); + } + + // Check if the input is AggregateExec + let Some(input_agg_exec) = + agg_exec.input().as_any().downcast_ref::() + else { + return Ok(Transformed::no(plan)); + }; + + let transformed = if matches!(input_agg_exec.mode(), AggregateMode::Partial) + && can_combine( + ( + agg_exec.group_expr(), + agg_exec.aggr_expr(), + agg_exec.filter_expr(), + ), + ( + input_agg_exec.group_expr(), + input_agg_exec.aggr_expr(), + input_agg_exec.filter_expr(), + ), + ) { + let mode = if agg_exec.mode() == &AggregateMode::Final { + AggregateMode::Single + } else { + AggregateMode::SinglePartitioned + }; + AggregateExec::try_new( + mode, + input_agg_exec.group_expr().clone(), + input_agg_exec.aggr_expr().to_vec(), + input_agg_exec.filter_expr().to_vec(), + Arc::clone(input_agg_exec.input()), + input_agg_exec.input_schema(), + ) + .map(|combined_agg| combined_agg.with_limit(agg_exec.limit())) + .ok() + .map(Arc::new) + } else { + None + }; + Ok(if let Some(transformed) = transformed { + Transformed::yes(transformed) + } else { + Transformed::no(plan) + }) + }) + .data() + } + + fn name(&self) -> &str { + "CombinePartialFinalAggregate" + } + + fn schema_check(&self) -> bool { + true + } +} + +type GroupExprsRef<'a> = ( + &'a PhysicalGroupBy, + &'a [Arc], + &'a [Option>], +); + +fn can_combine(final_agg: GroupExprsRef, partial_agg: GroupExprsRef) -> bool { + let (final_group_by, final_aggr_expr, final_filter_expr) = final_agg; + let (input_group_by, input_aggr_expr, input_filter_expr) = partial_agg; + + // Compare output expressions of the partial, and input expressions of the final operator. + physical_exprs_equal( + &input_group_by.output_exprs(), + &final_group_by.input_exprs(), + ) && input_group_by.groups() == final_group_by.groups() + && input_group_by.null_expr().len() == final_group_by.null_expr().len() + && input_group_by + .null_expr() + .iter() + .zip(final_group_by.null_expr().iter()) + .all(|((lhs_expr, lhs_str), (rhs_expr, rhs_str))| { + lhs_expr.eq(rhs_expr) && lhs_str == rhs_str + }) + && final_aggr_expr.len() == input_aggr_expr.len() + && final_aggr_expr + .iter() + .zip(input_aggr_expr.iter()) + .all(|(final_expr, partial_expr)| final_expr.eq(partial_expr)) + && final_filter_expr.len() == input_filter_expr.len() + && final_filter_expr.iter().zip(input_filter_expr.iter()).all( + |(final_expr, partial_expr)| match (final_expr, partial_expr) { + (Some(l), Some(r)) => l.eq(r), + (None, None) => true, + _ => false, + }, + ) +} + +// See tests in datafusion/core/tests/physical_optimizer diff --git a/datafusion/physical-optimizer/src/lib.rs b/datafusion/physical-optimizer/src/lib.rs new file mode 100644 index 000000000000..439f1dc873d1 --- /dev/null +++ b/datafusion/physical-optimizer/src/lib.rs @@ -0,0 +1,28 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] + +pub mod aggregate_statistics; +pub mod combine_partial_final_agg; +pub mod limit_pushdown; +pub mod limited_distinct_aggregation; +mod optimizer; +pub mod output_requirements; +pub mod topk_aggregation; + +pub use optimizer::PhysicalOptimizerRule; diff --git a/datafusion/physical-optimizer/src/limit_pushdown.rs b/datafusion/physical-optimizer/src/limit_pushdown.rs new file mode 100644 index 000000000000..8f392b683077 --- /dev/null +++ b/datafusion/physical-optimizer/src/limit_pushdown.rs @@ -0,0 +1,341 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`LimitPushdown`] pushes `LIMIT` down through `ExecutionPlan`s to reduce +//! data transfer as much as possible. + +use std::fmt::Debug; +use std::sync::Arc; + +use crate::PhysicalOptimizerRule; +use datafusion_common::config::ConfigOptions; +use datafusion_common::error::Result; +use datafusion_common::tree_node::{Transformed, TreeNodeRecursion}; +use datafusion_common::utils::combine_limit; +use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; +use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; + +/// This rule inspects [`ExecutionPlan`]'s and pushes down the fetch limit from +/// the parent to the child if applicable. +#[derive(Default, Debug)] +pub struct LimitPushdown {} + +/// This is a "data class" we use within the [`LimitPushdown`] rule to push +/// down [`LimitExec`] in the plan. GlobalRequirements are hold as a rule-wide state +/// and holds the fetch and skip information. The struct also has a field named +/// satisfied which means if the "current" plan is valid in terms of limits or not. +/// +/// For example: If the plan is satisfied with current fetch info, we decide to not add a LocalLimit +/// +/// [`LimitPushdown`]: crate::limit_pushdown::LimitPushdown +/// [`LimitExec`]: crate::limit_pushdown::LimitExec +#[derive(Default, Clone, Debug)] +pub struct GlobalRequirements { + fetch: Option, + skip: usize, + satisfied: bool, +} + +impl LimitPushdown { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl PhysicalOptimizerRule for LimitPushdown { + fn optimize( + &self, + plan: Arc, + _config: &ConfigOptions, + ) -> Result> { + let global_state = GlobalRequirements { + fetch: None, + skip: 0, + satisfied: false, + }; + pushdown_limits(plan, global_state) + } + + fn name(&self) -> &str { + "LimitPushdown" + } + + fn schema_check(&self) -> bool { + true + } +} + +/// This enumeration makes `skip` and `fetch` calculations easier by providing +/// a single API for both local and global limit operators. +#[derive(Debug)] +pub enum LimitExec { + Global(GlobalLimitExec), + Local(LocalLimitExec), +} + +impl LimitExec { + fn input(&self) -> &Arc { + match self { + Self::Global(global) => global.input(), + Self::Local(local) => local.input(), + } + } + + fn fetch(&self) -> Option { + match self { + Self::Global(global) => global.fetch(), + Self::Local(local) => Some(local.fetch()), + } + } + + fn skip(&self) -> usize { + match self { + Self::Global(global) => global.skip(), + Self::Local(_) => 0, + } + } +} + +impl From for Arc { + fn from(limit_exec: LimitExec) -> Self { + match limit_exec { + LimitExec::Global(global) => Arc::new(global), + LimitExec::Local(local) => Arc::new(local), + } + } +} + +/// This function is the main helper function of the `LimitPushDown` rule. +/// The helper takes an `ExecutionPlan` and a global (algorithm) state which is +/// an instance of `GlobalRequirements` and modifies these parameters while +/// checking if the limits can be pushed down or not. +/// +/// If a limit is encountered, a [`TreeNodeRecursion::Stop`] is returned. Otherwise, +/// return a [`TreeNodeRecursion::Continue`]. +pub fn pushdown_limit_helper( + mut pushdown_plan: Arc, + mut global_state: GlobalRequirements, +) -> Result<(Transformed>, GlobalRequirements)> { + // Extract limit, if exist, and return child inputs. + if let Some(limit_exec) = extract_limit(&pushdown_plan) { + // If we have fetch/skip info in the global state already, we need to + // decide which one to continue with: + let (skip, fetch) = combine_limit( + global_state.skip, + global_state.fetch, + limit_exec.skip(), + limit_exec.fetch(), + ); + global_state.skip = skip; + global_state.fetch = fetch; + + // Now the global state has the most recent information, we can remove + // the `LimitExec` plan. We will decide later if we should add it again + // or not. + return Ok(( + Transformed { + data: Arc::clone(limit_exec.input()), + transformed: true, + tnr: TreeNodeRecursion::Stop, + }, + global_state, + )); + } + + // If we have a non-limit operator with fetch capability, update global + // state as necessary: + if pushdown_plan.fetch().is_some() { + if global_state.fetch.is_none() { + global_state.satisfied = true; + } + (global_state.skip, global_state.fetch) = combine_limit( + global_state.skip, + global_state.fetch, + 0, + pushdown_plan.fetch(), + ); + } + + let Some(global_fetch) = global_state.fetch else { + // There's no valid fetch information, exit early: + return if global_state.skip > 0 && !global_state.satisfied { + // There might be a case with only offset, if so add a global limit: + global_state.satisfied = true; + Ok(( + Transformed::yes(add_global_limit( + pushdown_plan, + global_state.skip, + None, + )), + global_state, + )) + } else { + // There's no info on offset or fetch, nothing to do: + Ok((Transformed::no(pushdown_plan), global_state)) + }; + }; + + let skip_and_fetch = Some(global_fetch + global_state.skip); + + if pushdown_plan.supports_limit_pushdown() { + if !combines_input_partitions(&pushdown_plan) { + // We have information in the global state and the plan pushes down, + // continue: + Ok((Transformed::no(pushdown_plan), global_state)) + } else if let Some(plan_with_fetch) = pushdown_plan.with_fetch(skip_and_fetch) { + // This plan is combining input partitions, so we need to add the + // fetch info to plan if possible. If not, we must add a `LimitExec` + // with the information from the global state. + let mut new_plan = plan_with_fetch; + // Execution plans can't (yet) handle skip, so if we have one, + // we still need to add a global limit + if global_state.skip > 0 { + new_plan = + add_global_limit(new_plan, global_state.skip, global_state.fetch); + } + global_state.fetch = skip_and_fetch; + global_state.skip = 0; + global_state.satisfied = true; + Ok((Transformed::yes(new_plan), global_state)) + } else if global_state.satisfied { + // If the plan is already satisfied, do not add a limit: + Ok((Transformed::no(pushdown_plan), global_state)) + } else { + global_state.satisfied = true; + Ok(( + Transformed::yes(add_limit( + pushdown_plan, + global_state.skip, + global_fetch, + )), + global_state, + )) + } + } else { + // The plan does not support push down and it is not a limit. We will need + // to add a limit or a fetch. If the plan is already satisfied, we will try + // to add the fetch info and return the plan. + + // There's no push down, change fetch & skip to default values: + let global_skip = global_state.skip; + global_state.fetch = None; + global_state.skip = 0; + + let maybe_fetchable = pushdown_plan.with_fetch(skip_and_fetch); + if global_state.satisfied { + if let Some(plan_with_fetch) = maybe_fetchable { + Ok((Transformed::yes(plan_with_fetch), global_state)) + } else { + Ok((Transformed::no(pushdown_plan), global_state)) + } + } else { + // Add fetch or a `LimitExec`: + global_state.satisfied = true; + pushdown_plan = if let Some(plan_with_fetch) = maybe_fetchable { + if global_skip > 0 { + add_global_limit(plan_with_fetch, global_skip, Some(global_fetch)) + } else { + plan_with_fetch + } + } else { + add_limit(pushdown_plan, global_skip, global_fetch) + }; + Ok((Transformed::yes(pushdown_plan), global_state)) + } + } +} + +/// Pushes down the limit through the plan. +pub(crate) fn pushdown_limits( + pushdown_plan: Arc, + global_state: GlobalRequirements, +) -> Result> { + // Call pushdown_limit_helper. + // This will either extract the limit node (returning the child), or apply the limit pushdown. + let (mut new_node, mut global_state) = + pushdown_limit_helper(pushdown_plan, global_state)?; + + // While limits exist, continue combining the global_state. + while new_node.tnr == TreeNodeRecursion::Stop { + (new_node, global_state) = pushdown_limit_helper(new_node.data, global_state)?; + } + + // Apply pushdown limits in children + let children = new_node.data.children(); + let new_children = children + .into_iter() + .map(|child| { + pushdown_limits(Arc::::clone(child), global_state.clone()) + }) + .collect::>()?; + new_node.data.with_new_children(new_children) +} + +/// Transforms the [`ExecutionPlan`] into a [`LimitExec`] if it is a +/// [`GlobalLimitExec`] or a [`LocalLimitExec`]. +fn extract_limit(plan: &Arc) -> Option { + if let Some(global_limit) = plan.as_any().downcast_ref::() { + Some(LimitExec::Global(GlobalLimitExec::new( + Arc::clone(global_limit.input()), + global_limit.skip(), + global_limit.fetch(), + ))) + } else { + plan.as_any() + .downcast_ref::() + .map(|local_limit| { + LimitExec::Local(LocalLimitExec::new( + Arc::clone(local_limit.input()), + local_limit.fetch(), + )) + }) + } +} + +/// Checks if the given plan combines input partitions. +fn combines_input_partitions(plan: &Arc) -> bool { + let plan = plan.as_any(); + plan.is::() || plan.is::() +} + +/// Adds a limit to the plan, chooses between global and local limits based on +/// skip value and the number of partitions. +fn add_limit( + pushdown_plan: Arc, + skip: usize, + fetch: usize, +) -> Arc { + if skip > 0 || pushdown_plan.output_partitioning().partition_count() == 1 { + add_global_limit(pushdown_plan, skip, Some(fetch)) + } else { + Arc::new(LocalLimitExec::new(pushdown_plan, fetch + skip)) as _ + } +} + +/// Adds a global limit to the plan. +fn add_global_limit( + pushdown_plan: Arc, + skip: usize, + fetch: Option, +) -> Arc { + Arc::new(GlobalLimitExec::new(pushdown_plan, skip, fetch)) as _ +} + +// See tests in datafusion/core/tests/physical_optimizer diff --git a/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs b/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs new file mode 100644 index 000000000000..7833324f64fa --- /dev/null +++ b/datafusion/physical-optimizer/src/limited_distinct_aggregation.rs @@ -0,0 +1,193 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! A special-case optimizer rule that pushes limit into a grouped aggregation +//! which has no aggregate expressions or sorting requirements + +use std::sync::Arc; + +use datafusion_physical_plan::aggregates::AggregateExec; +use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; +use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; + +use datafusion_common::config::ConfigOptions; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::Result; + +use crate::PhysicalOptimizerRule; +use itertools::Itertools; + +/// An optimizer rule that passes a `limit` hint into grouped aggregations which don't require all +/// rows in the group to be processed for correctness. Example queries fitting this description are: +/// - `SELECT distinct l_orderkey FROM lineitem LIMIT 10;` +/// - `SELECT l_orderkey FROM lineitem GROUP BY l_orderkey LIMIT 10;` +#[derive(Debug)] +pub struct LimitedDistinctAggregation {} + +impl LimitedDistinctAggregation { + /// Create a new `LimitedDistinctAggregation` + pub fn new() -> Self { + Self {} + } + + fn transform_agg( + aggr: &AggregateExec, + limit: usize, + ) -> Option> { + // rules for transforming this Aggregate are held in this method + if !aggr.is_unordered_unfiltered_group_by_distinct() { + return None; + } + + // We found what we want: clone, copy the limit down, and return modified node + let new_aggr = AggregateExec::try_new( + *aggr.mode(), + aggr.group_expr().clone(), + aggr.aggr_expr().to_vec(), + aggr.filter_expr().to_vec(), + aggr.input().to_owned(), + aggr.input_schema(), + ) + .expect("Unable to copy Aggregate!") + .with_limit(Some(limit)); + Some(Arc::new(new_aggr)) + } + + /// transform_limit matches an `AggregateExec` as the child of a `LocalLimitExec` + /// or `GlobalLimitExec` and pushes the limit into the aggregation as a soft limit when + /// there is a group by, but no sorting, no aggregate expressions, and no filters in the + /// aggregation + fn transform_limit(plan: Arc) -> Option> { + let limit: usize; + let mut global_fetch: Option = None; + let mut global_skip: usize = 0; + let children: Vec>; + let mut is_global_limit = false; + if let Some(local_limit) = plan.as_any().downcast_ref::() { + limit = local_limit.fetch(); + children = local_limit.children().into_iter().cloned().collect(); + } else if let Some(global_limit) = plan.as_any().downcast_ref::() + { + global_fetch = global_limit.fetch(); + global_fetch?; + global_skip = global_limit.skip(); + // the aggregate must read at least fetch+skip number of rows + limit = global_fetch.unwrap() + global_skip; + children = global_limit.children().into_iter().cloned().collect(); + is_global_limit = true + } else { + return None; + } + let child = children.iter().exactly_one().ok()?; + // ensure there is no output ordering; can this rule be relaxed? + if plan.output_ordering().is_some() { + return None; + } + // ensure no ordering is required on the input + if plan.required_input_ordering()[0].is_some() { + return None; + } + + // if found_match_aggr is true, match_aggr holds a parent aggregation whose group_by + // must match that of a child aggregation in order to rewrite the child aggregation + let mut match_aggr: Arc = plan; + let mut found_match_aggr = false; + + let mut rewrite_applicable = true; + let closure = |plan: Arc| { + if !rewrite_applicable { + return Ok(Transformed::no(plan)); + } + if let Some(aggr) = plan.as_any().downcast_ref::() { + if found_match_aggr { + if let Some(parent_aggr) = + match_aggr.as_any().downcast_ref::() + { + if !parent_aggr.group_expr().eq(aggr.group_expr()) { + // a partial and final aggregation with different groupings disqualifies + // rewriting the child aggregation + rewrite_applicable = false; + return Ok(Transformed::no(plan)); + } + } + } + // either we run into an Aggregate and transform it, or disable the rewrite + // for subsequent children + match Self::transform_agg(aggr, limit) { + None => {} + Some(new_aggr) => { + match_aggr = plan; + found_match_aggr = true; + return Ok(Transformed::yes(new_aggr)); + } + } + } + rewrite_applicable = false; + Ok(Transformed::no(plan)) + }; + let child = child.to_owned().transform_down(closure).data().ok()?; + if is_global_limit { + return Some(Arc::new(GlobalLimitExec::new( + child, + global_skip, + global_fetch, + ))); + } + Some(Arc::new(LocalLimitExec::new(child, limit))) + } +} + +impl Default for LimitedDistinctAggregation { + fn default() -> Self { + Self::new() + } +} + +impl PhysicalOptimizerRule for LimitedDistinctAggregation { + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result> { + if config.optimizer.enable_distinct_aggregation_soft_limit { + plan.transform_down(|plan| { + Ok( + if let Some(plan) = + LimitedDistinctAggregation::transform_limit(plan.to_owned()) + { + Transformed::yes(plan) + } else { + Transformed::no(plan) + }, + ) + }) + .data() + } else { + Ok(plan) + } + } + + fn name(&self) -> &str { + "LimitedDistinctAggregation" + } + + fn schema_check(&self) -> bool { + true + } +} + +// See tests in datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs diff --git a/datafusion/physical-optimizer/src/optimizer.rs b/datafusion/physical-optimizer/src/optimizer.rs new file mode 100644 index 000000000000..609890e2d43f --- /dev/null +++ b/datafusion/physical-optimizer/src/optimizer.rs @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Physical optimizer traits + +use datafusion_common::config::ConfigOptions; +use datafusion_common::Result; +use datafusion_physical_plan::ExecutionPlan; +use std::fmt::Debug; +use std::sync::Arc; + +/// `PhysicalOptimizerRule` transforms one ['ExecutionPlan'] into another which +/// computes the same results, but in a potentially more efficient way. +/// +/// Use [`SessionState::add_physical_optimizer_rule`] to register additional +/// `PhysicalOptimizerRule`s. +/// +/// [`SessionState::add_physical_optimizer_rule`]: https://docs.rs/datafusion/latest/datafusion/execution/session_state/struct.SessionState.html#method.add_physical_optimizer_rule +pub trait PhysicalOptimizerRule: Debug { + /// Rewrite `plan` to an optimized form + fn optimize( + &self, + plan: Arc, + config: &ConfigOptions, + ) -> Result>; + + /// A human readable name for this optimizer rule + fn name(&self) -> &str; + + /// A flag to indicate whether the physical planner should valid the rule will not + /// change the schema of the plan after the rewriting. + /// Some of the optimization rules might change the nullable properties of the schema + /// and should disable the schema check. + fn schema_check(&self) -> bool; +} diff --git a/datafusion/core/src/physical_optimizer/output_requirements.rs b/datafusion/physical-optimizer/src/output_requirements.rs similarity index 90% rename from datafusion/core/src/physical_optimizer/output_requirements.rs rename to datafusion/physical-optimizer/src/output_requirements.rs index 5bf86e88d646..4f6f91a2348f 100644 --- a/datafusion/core/src/physical_optimizer/output_requirements.rs +++ b/datafusion/physical-optimizer/src/output_requirements.rs @@ -24,9 +24,11 @@ use std::sync::Arc; -use crate::physical_optimizer::PhysicalOptimizerRule; -use crate::physical_plan::sorts::sort::SortExec; -use crate::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; +use datafusion_execution::TaskContext; +use datafusion_physical_plan::sorts::sort::SortExec; +use datafusion_physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, SendableRecordBatchStream, +}; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; @@ -35,6 +37,8 @@ use datafusion_physical_expr::{Distribution, LexRequirement, PhysicalSortRequire use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion_physical_plan::{ExecutionPlanProperties, PlanProperties}; +use crate::PhysicalOptimizerRule; + /// This rule either adds or removes [`OutputRequirements`]s to/from the physical /// plan according to its `mode` attribute, which is set by the constructors /// `new_add_mode` and `new_remove_mode`. With this rule, we can keep track of @@ -86,7 +90,7 @@ enum RuleMode { /// /// See [`OutputRequirements`] for more details #[derive(Debug)] -pub(crate) struct OutputRequirementExec { +pub struct OutputRequirementExec { input: Arc, order_requirement: Option, dist_requirement: Distribution, @@ -94,7 +98,7 @@ pub(crate) struct OutputRequirementExec { } impl OutputRequirementExec { - pub(crate) fn new( + pub fn new( input: Arc, requirements: Option, dist_requirement: Distribution, @@ -108,8 +112,8 @@ impl OutputRequirementExec { } } - pub(crate) fn input(&self) -> Arc { - self.input.clone() + pub fn input(&self) -> Arc { + Arc::clone(&self.input) } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. @@ -157,11 +161,11 @@ impl ExecutionPlan for OutputRequirementExec { vec![true] } - fn children(&self) -> Vec> { - vec![self.input.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.input] } - fn required_input_ordering(&self) -> Vec>> { + fn required_input_ordering(&self) -> Vec> { vec![self.order_requirement.clone()] } @@ -179,8 +183,8 @@ impl ExecutionPlan for OutputRequirementExec { fn execute( &self, _partition: usize, - _context: Arc, - ) -> Result { + _context: Arc, + ) -> Result { unreachable!(); } @@ -248,7 +252,9 @@ fn require_top_ordering_helper( if children.len() != 1 { Ok((plan, false)) } else if let Some(sort_exec) = plan.as_any().downcast_ref::() { - let req_ordering = sort_exec.properties().output_ordering().unwrap_or(&[]); + // In case of constant columns, output ordering of SortExec would give an empty set. + // Therefore; we check the sort expression field of the SortExec to assign the requirements. + let req_ordering = sort_exec.expr(); let req_dist = sort_exec.required_input_distribution()[0].clone(); let reqs = PhysicalSortRequirement::from_sort_exprs(req_ordering); Ok(( @@ -273,10 +279,12 @@ fn require_top_ordering_helper( // When an operator requires an ordering, any `SortExec` below can not // be responsible for (i.e. the originator of) the global ordering. let (new_child, is_changed) = - require_top_ordering_helper(children.swap_remove(0))?; + require_top_ordering_helper(Arc::clone(children.swap_remove(0)))?; Ok((plan.with_new_children(vec![new_child])?, is_changed)) } else { // Stop searching, there is no global ordering desired for the query. Ok((plan, false)) } } + +// See tests in datafusion/core/tests/physical_optimizer diff --git a/datafusion/core/src/physical_optimizer/topk_aggregation.rs b/datafusion/physical-optimizer/src/topk_aggregation.rs similarity index 64% rename from datafusion/core/src/physical_optimizer/topk_aggregation.rs rename to datafusion/physical-optimizer/src/topk_aggregation.rs index 95f7067cbe1b..0e5fb82d9e93 100644 --- a/datafusion/core/src/physical_optimizer/topk_aggregation.rs +++ b/datafusion/physical-optimizer/src/topk_aggregation.rs @@ -19,24 +19,22 @@ use std::sync::Arc; -use crate::physical_optimizer::PhysicalOptimizerRule; -use crate::physical_plan::aggregates::AggregateExec; -use crate::physical_plan::coalesce_batches::CoalesceBatchesExec; -use crate::physical_plan::filter::FilterExec; -use crate::physical_plan::repartition::RepartitionExec; -use crate::physical_plan::sorts::sort::SortExec; -use crate::physical_plan::ExecutionPlan; - -use arrow_schema::DataType; +use crate::PhysicalOptimizerRule; +use arrow::datatypes::DataType; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::PhysicalSortExpr; - +use datafusion_physical_expr::LexOrdering; +use datafusion_physical_plan::aggregates::AggregateExec; +use datafusion_physical_plan::execution_plan::CardinalityEffect; +use datafusion_physical_plan::projection::ProjectionExec; +use datafusion_physical_plan::sorts::sort::SortExec; +use datafusion_physical_plan::ExecutionPlan; use itertools::Itertools; /// An optimizer rule that passes a `limit` hint to aggregations if the whole result is not needed +#[derive(Debug)] pub struct TopKAggregation {} impl TopKAggregation { @@ -47,12 +45,13 @@ impl TopKAggregation { fn transform_agg( aggr: &AggregateExec, - order: &PhysicalSortExpr, + order_by: &str, + order_desc: bool, limit: usize, ) -> Option> { // ensure the sort direction matches aggregate function let (field, desc) = aggr.get_minmax_desc()?; - if desc != order.options.descending { + if desc != order_desc { return None; } let group_key = aggr.group_expr().expr().iter().exactly_one().ok()?; @@ -65,18 +64,17 @@ impl TopKAggregation { } // ensure the sort is on the same field as the aggregate output - let col = order.expr.as_any().downcast_ref::()?; - if col.name() != field.name() { + if order_by != field.name() { return None; } // We found what we want: clone, copy the limit down, and return modified node let new_aggr = AggregateExec::try_new( *aggr.mode(), - aggr.group_by().clone(), + aggr.group_expr().clone(), aggr.aggr_expr().to_vec(), aggr.filter_expr().to_vec(), - aggr.input().clone(), + Arc::clone(aggr.input()), aggr.input_schema(), ) .expect("Unable to copy Aggregate!") @@ -84,23 +82,18 @@ impl TopKAggregation { Some(Arc::new(new_aggr)) } - fn transform_sort(plan: Arc) -> Option> { + fn transform_sort(plan: &Arc) -> Option> { let sort = plan.as_any().downcast_ref::()?; let children = sort.children(); - let child = children.iter().exactly_one().ok()?; + let child = children.into_iter().exactly_one().ok()?; let order = sort.properties().output_ordering()?; let order = order.iter().exactly_one().ok()?; + let order_desc = order.options.descending; + let order = order.expr.as_any().downcast_ref::()?; + let mut cur_col_name = order.name().to_string(); let limit = sort.fetch()?; - let is_cardinality_preserving = |plan: Arc| { - plan.as_any() - .downcast_ref::() - .is_some() - || plan.as_any().downcast_ref::().is_some() - || plan.as_any().downcast_ref::().is_some() - }; - let mut cardinality_preserved = true; let closure = |plan: Arc| { if !cardinality_preserved { @@ -108,20 +101,33 @@ impl TopKAggregation { } if let Some(aggr) = plan.as_any().downcast_ref::() { // either we run into an Aggregate and transform it - match Self::transform_agg(aggr, order, limit) { + match Self::transform_agg(aggr, &cur_col_name, order_desc, limit) { None => cardinality_preserved = false, Some(plan) => return Ok(Transformed::yes(plan)), } + } else if let Some(proj) = plan.as_any().downcast_ref::() { + // track renames due to successive projections + for (src_expr, proj_name) in proj.expr() { + let Some(src_col) = src_expr.as_any().downcast_ref::() else { + continue; + }; + if *proj_name == cur_col_name { + cur_col_name = src_col.name().to_string(); + } + } } else { - // or we continue down whitelisted nodes of other types - if !is_cardinality_preserving(plan.clone()) { - cardinality_preserved = false; + // or we continue down through types that don't reduce cardinality + match plan.cardinality_effect() { + CardinalityEffect::Equal | CardinalityEffect::GreaterEqual => {} + CardinalityEffect::Unknown | CardinalityEffect::LowerEqual => { + cardinality_preserved = false; + } } } Ok(Transformed::no(plan)) }; - let child = child.clone().transform_down(closure).data().ok()?; - let sort = SortExec::new(sort.expr().to_vec(), child) + let child = Arc::clone(child).transform_down(closure).data().ok()?; + let sort = SortExec::new(LexOrdering::new(sort.expr().to_vec()), child) .with_fetch(sort.fetch()) .with_preserve_partitioning(sort.preserve_partitioning()); Some(Arc::new(sort)) @@ -142,13 +148,11 @@ impl PhysicalOptimizerRule for TopKAggregation { ) -> Result> { if config.optimizer.enable_topk_aggregation { plan.transform_down(|plan| { - Ok( - if let Some(plan) = TopKAggregation::transform_sort(plan.clone()) { - Transformed::yes(plan) - } else { - Transformed::no(plan) - }, - ) + Ok(if let Some(plan) = TopKAggregation::transform_sort(&plan) { + Transformed::yes(plan) + } else { + Transformed::no(plan) + }) }) .data() } else { diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index 25e1a6ad5bd3..a9f9b22fafda 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -31,14 +31,15 @@ rust-version = { workspace = true } [lints] workspace = true +[features] +force_hash_collisions = [] + [lib] name = "datafusion_physical_plan" path = "src/lib.rs" [dependencies] -ahash = { version = "0.8", default-features = false, features = [ - "runtime-rng", -] } +ahash = { workspace = true } arrow = { workspace = true } arrow-array = { workspace = true } arrow-buffer = { workspace = true } @@ -50,12 +51,13 @@ datafusion-common = { workspace = true, default-features = true } datafusion-common-runtime = { workspace = true, default-features = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } -datafusion-functions-aggregate = { workspace = true } +datafusion-functions-aggregate-common = { workspace = true } +datafusion-functions-window-common = { workspace = true } datafusion-physical-expr = { workspace = true, default-features = true } datafusion-physical-expr-common = { workspace = true } futures = { workspace = true } half = { workspace = true } -hashbrown = { version = "0.14", features = ["raw"] } +hashbrown = { workspace = true } indexmap = { workspace = true } itertools = { workspace = true, features = ["use_std"] } log = { workspace = true } @@ -66,11 +68,16 @@ rand = { workspace = true } tokio = { workspace = true } [dev-dependencies] +criterion = { version = "0.5", features = ["async_futures"] } +datafusion-functions-aggregate = { workspace = true } rstest = { workspace = true } -rstest_reuse = "0.6.0" -termtree = "0.4.1" +rstest_reuse = "0.7.0" tokio = { workspace = true, features = [ "rt-multi-thread", "fs", "parking_lot", ] } + +[[bench]] +harness = false +name = "spm" diff --git a/datafusion/physical-plan/benches/spm.rs b/datafusion/physical-plan/benches/spm.rs new file mode 100644 index 000000000000..fbbd27409173 --- /dev/null +++ b/datafusion/physical-plan/benches/spm.rs @@ -0,0 +1,146 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::record_batch::RecordBatch; +use arrow_array::{ArrayRef, Int32Array, Int64Array, StringArray}; +use datafusion_execution::TaskContext; +use datafusion_physical_expr::expressions::col; +use datafusion_physical_expr::PhysicalSortExpr; +use datafusion_physical_expr_common::sort_expr::LexOrdering; +use datafusion_physical_plan::memory::MemoryExec; +use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; +use datafusion_physical_plan::{collect, ExecutionPlan}; + +use criterion::async_executor::FuturesExecutor; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; + +fn generate_spm_for_round_robin_tie_breaker( + has_same_value: bool, + enable_round_robin_repartition: bool, + batch_count: usize, + partition_count: usize, +) -> SortPreservingMergeExec { + let row_size = 256; + let rb = if has_same_value { + let a: ArrayRef = Arc::new(Int32Array::from(vec![1; row_size])); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"); row_size])); + let c: ArrayRef = Arc::new(Int64Array::from_iter(vec![0; row_size])); + RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap() + } else { + let v = (0i32..row_size as i32).collect::>(); + let a: ArrayRef = Arc::new(Int32Array::from(v)); + + // Use alphanumeric characters + let charset: Vec = + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + .chars() + .collect(); + + let mut strings = Vec::new(); + for i in 0..256 { + let mut s = String::new(); + s.push(charset[i % charset.len()]); + s.push(charset[(i / charset.len()) % charset.len()]); + strings.push(Some(s)); + } + + let b: ArrayRef = Arc::new(StringArray::from_iter(strings)); + + let v = (0i64..row_size as i64).collect::>(); + let c: ArrayRef = Arc::new(Int64Array::from_iter(v)); + RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap() + }; + + let rbs = (0..batch_count).map(|_| rb.clone()).collect::>(); + let partitiones = vec![rbs.clone(); partition_count]; + + let schema = rb.schema(); + let sort = LexOrdering::new(vec![ + PhysicalSortExpr { + expr: col("b", &schema).unwrap(), + options: Default::default(), + }, + PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: Default::default(), + }, + ]); + + let exec = MemoryExec::try_new(&partitiones, schema, None).unwrap(); + SortPreservingMergeExec::new(sort, Arc::new(exec)) + .with_round_robin_repartition(enable_round_robin_repartition) +} + +fn run_bench( + c: &mut Criterion, + has_same_value: bool, + enable_round_robin_repartition: bool, + batch_count: usize, + partition_count: usize, + description: &str, +) { + let task_ctx = TaskContext::default(); + let task_ctx = Arc::new(task_ctx); + + let spm = Arc::new(generate_spm_for_round_robin_tie_breaker( + has_same_value, + enable_round_robin_repartition, + batch_count, + partition_count, + )) as Arc; + + c.bench_function(description, |b| { + b.to_async(FuturesExecutor) + .iter(|| black_box(collect(Arc::clone(&spm), Arc::clone(&task_ctx)))) + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + let params = [ + (true, false, "low_card_without_tiebreaker"), // low cardinality, no tie breaker + (true, true, "low_card_with_tiebreaker"), // low cardinality, with tie breaker + (false, false, "high_card_without_tiebreaker"), // high cardinality, no tie breaker + (false, true, "high_card_with_tiebreaker"), // high cardinality, with tie breaker + ]; + + let batch_counts = [1, 25, 625]; + let partition_counts = [2, 8, 32]; + + for &(has_same_value, enable_round_robin_repartition, cardinality_label) in ¶ms { + for &batch_count in &batch_counts { + for &partition_count in &partition_counts { + let description = format!( + "{}_batch_count_{}_partition_count_{}", + cardinality_label, batch_count, partition_count + ); + run_bench( + c, + has_same_value, + enable_round_robin_repartition, + batch_count, + partition_count, + &description, + ); + } + } + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/physical-plan/src/aggregates/group_values/bytes.rs b/datafusion/physical-plan/src/aggregates/group_values/bytes.rs index d073c8995a9b..013c027e7306 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/bytes.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/bytes.rs @@ -18,7 +18,8 @@ use crate::aggregates::group_values::GroupValues; use arrow_array::{Array, ArrayRef, OffsetSizeTrait, RecordBatch}; use datafusion_expr::EmitTo; -use datafusion_physical_expr::binary_map::{ArrowBytesMap, OutputType}; +use datafusion_physical_expr_common::binary_map::{ArrowBytesMap, OutputType}; +use std::mem::size_of; /// A [`GroupValues`] storing single column of Utf8/LargeUtf8/Binary/LargeBinary values /// @@ -73,7 +74,7 @@ impl GroupValues for GroupValuesByes { } fn size(&self) -> usize { - self.map.size() + std::mem::size_of::() + self.map.size() + size_of::() } fn is_empty(&self) -> bool { diff --git a/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs b/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs new file mode 100644 index 000000000000..7379b7a538b4 --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs @@ -0,0 +1,130 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::aggregates::group_values::GroupValues; +use arrow_array::{Array, ArrayRef, RecordBatch}; +use datafusion_expr::EmitTo; +use datafusion_physical_expr::binary_map::OutputType; +use datafusion_physical_expr_common::binary_view_map::ArrowBytesViewMap; +use std::mem::size_of; + +/// A [`GroupValues`] storing single column of Utf8View/BinaryView values +/// +/// This specialization is significantly faster than using the more general +/// purpose `Row`s format +pub struct GroupValuesBytesView { + /// Map string/binary values to group index + map: ArrowBytesViewMap, + /// The total number of groups so far (used to assign group_index) + num_groups: usize, +} + +impl GroupValuesBytesView { + pub fn new(output_type: OutputType) -> Self { + Self { + map: ArrowBytesViewMap::new(output_type), + num_groups: 0, + } + } +} + +impl GroupValues for GroupValuesBytesView { + fn intern( + &mut self, + cols: &[ArrayRef], + groups: &mut Vec, + ) -> datafusion_common::Result<()> { + assert_eq!(cols.len(), 1); + + // look up / add entries in the table + let arr = &cols[0]; + + groups.clear(); + self.map.insert_if_new( + arr, + // called for each new group + |_value| { + // assign new group index on each insert + let group_idx = self.num_groups; + self.num_groups += 1; + group_idx + }, + // called for each group + |group_idx| { + groups.push(group_idx); + }, + ); + + // ensure we assigned a group to for each row + assert_eq!(groups.len(), arr.len()); + Ok(()) + } + + fn size(&self) -> usize { + self.map.size() + size_of::() + } + + fn is_empty(&self) -> bool { + self.num_groups == 0 + } + + fn len(&self) -> usize { + self.num_groups + } + + fn emit(&mut self, emit_to: EmitTo) -> datafusion_common::Result> { + // Reset the map to default, and convert it into a single array + let map_contents = self.map.take().into_state(); + + let group_values = match emit_to { + EmitTo::All => { + self.num_groups -= map_contents.len(); + map_contents + } + EmitTo::First(n) if n == self.len() => { + self.num_groups -= map_contents.len(); + map_contents + } + EmitTo::First(n) => { + // if we only wanted to take the first n, insert the rest back + // into the map we could potentially avoid this reallocation, at + // the expense of much more complex code. + // see https://github.com/apache/datafusion/issues/9195 + let emit_group_values = map_contents.slice(0, n); + let remaining_group_values = + map_contents.slice(n, map_contents.len() - n); + + self.num_groups = 0; + let mut group_indexes = vec![]; + self.intern(&[remaining_group_values], &mut group_indexes)?; + + // Verify that the group indexes were assigned in the correct order + assert_eq!(0, group_indexes[0]); + + emit_group_values + } + }; + + Ok(vec![group_values]) + } + + fn clear_shrink(&mut self, _batch: &RecordBatch) { + // in theory we could potentially avoid this reallocation and clear the + // contents of the maps, but for now we just reset the map from the beginning + self.map.take(); + } +} diff --git a/datafusion/physical-plan/src/aggregates/group_values/column.rs b/datafusion/physical-plan/src/aggregates/group_values/column.rs new file mode 100644 index 000000000000..958a4b58d800 --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/group_values/column.rs @@ -0,0 +1,358 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::aggregates::group_values::group_column::{ + ByteGroupValueBuilder, ByteViewGroupValueBuilder, GroupColumn, + PrimitiveGroupValueBuilder, +}; +use crate::aggregates::group_values::GroupValues; +use ahash::RandomState; +use arrow::compute::cast; +use arrow::datatypes::{ + BinaryViewType, Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, + Int32Type, Int64Type, Int8Type, StringViewType, UInt16Type, UInt32Type, UInt64Type, + UInt8Type, +}; +use arrow::record_batch::RecordBatch; +use arrow_array::{Array, ArrayRef}; +use arrow_schema::{DataType, Schema, SchemaRef}; +use datafusion_common::hash_utils::create_hashes; +use datafusion_common::{not_impl_err, DataFusionError, Result}; +use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; +use datafusion_expr::EmitTo; +use datafusion_physical_expr::binary_map::OutputType; +use hashbrown::raw::RawTable; +use std::mem::size_of; + +/// A [`GroupValues`] that stores multiple columns of group values. +/// +/// +pub struct GroupValuesColumn { + /// The output schema + schema: SchemaRef, + + /// Logically maps group values to a group_index in + /// [`Self::group_values`] and in each accumulator + /// + /// Uses the raw API of hashbrown to avoid actually storing the + /// keys (group values) in the table + /// + /// keys: u64 hashes of the GroupValue + /// values: (hash, group_index) + map: RawTable<(u64, usize)>, + + /// The size of `map` in bytes + map_size: usize, + + /// The actual group by values, stored column-wise. Compare from + /// the left to right, each column is stored as [`GroupColumn`]. + /// + /// Performance tests showed that this design is faster than using the + /// more general purpose [`GroupValuesRows`]. See the ticket for details: + /// + /// + /// [`GroupValuesRows`]: crate::aggregates::group_values::row::GroupValuesRows + group_values: Vec>, + + /// reused buffer to store hashes + hashes_buffer: Vec, + + /// Random state for creating hashes + random_state: RandomState, +} + +impl GroupValuesColumn { + /// Create a new instance of GroupValuesColumn if supported for the specified schema + pub fn try_new(schema: SchemaRef) -> Result { + let map = RawTable::with_capacity(0); + Ok(Self { + schema, + map, + map_size: 0, + group_values: vec![], + hashes_buffer: Default::default(), + random_state: Default::default(), + }) + } + + /// Returns true if [`GroupValuesColumn`] supported for the specified schema + pub fn supported_schema(schema: &Schema) -> bool { + schema + .fields() + .iter() + .map(|f| f.data_type()) + .all(Self::supported_type) + } + + /// Returns true if the specified data type is supported by [`GroupValuesColumn`] + /// + /// In order to be supported, there must be a specialized implementation of + /// [`GroupColumn`] for the data type, instantiated in [`Self::intern`] + fn supported_type(data_type: &DataType) -> bool { + matches!( + *data_type, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 + | DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Binary + | DataType::LargeBinary + | DataType::Date32 + | DataType::Date64 + | DataType::Utf8View + | DataType::BinaryView + ) + } +} + +/// instantiates a [`PrimitiveGroupValueBuilder`] and pushes it into $v +/// +/// Arguments: +/// `$v`: the vector to push the new builder into +/// `$nullable`: whether the input can contains nulls +/// `$t`: the primitive type of the builder +/// +macro_rules! instantiate_primitive { + ($v:expr, $nullable:expr, $t:ty) => { + if $nullable { + let b = PrimitiveGroupValueBuilder::<$t, true>::new(); + $v.push(Box::new(b) as _) + } else { + let b = PrimitiveGroupValueBuilder::<$t, false>::new(); + $v.push(Box::new(b) as _) + } + }; +} + +impl GroupValues for GroupValuesColumn { + fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { + let n_rows = cols[0].len(); + + if self.group_values.is_empty() { + let mut v = Vec::with_capacity(cols.len()); + + for f in self.schema.fields().iter() { + let nullable = f.is_nullable(); + match f.data_type() { + &DataType::Int8 => instantiate_primitive!(v, nullable, Int8Type), + &DataType::Int16 => instantiate_primitive!(v, nullable, Int16Type), + &DataType::Int32 => instantiate_primitive!(v, nullable, Int32Type), + &DataType::Int64 => instantiate_primitive!(v, nullable, Int64Type), + &DataType::UInt8 => instantiate_primitive!(v, nullable, UInt8Type), + &DataType::UInt16 => instantiate_primitive!(v, nullable, UInt16Type), + &DataType::UInt32 => instantiate_primitive!(v, nullable, UInt32Type), + &DataType::UInt64 => instantiate_primitive!(v, nullable, UInt64Type), + &DataType::Float32 => { + instantiate_primitive!(v, nullable, Float32Type) + } + &DataType::Float64 => { + instantiate_primitive!(v, nullable, Float64Type) + } + &DataType::Date32 => instantiate_primitive!(v, nullable, Date32Type), + &DataType::Date64 => instantiate_primitive!(v, nullable, Date64Type), + &DataType::Utf8 => { + let b = ByteGroupValueBuilder::::new(OutputType::Utf8); + v.push(Box::new(b) as _) + } + &DataType::LargeUtf8 => { + let b = ByteGroupValueBuilder::::new(OutputType::Utf8); + v.push(Box::new(b) as _) + } + &DataType::Binary => { + let b = ByteGroupValueBuilder::::new(OutputType::Binary); + v.push(Box::new(b) as _) + } + &DataType::LargeBinary => { + let b = ByteGroupValueBuilder::::new(OutputType::Binary); + v.push(Box::new(b) as _) + } + &DataType::Utf8View => { + let b = ByteViewGroupValueBuilder::::new(); + v.push(Box::new(b) as _) + } + &DataType::BinaryView => { + let b = ByteViewGroupValueBuilder::::new(); + v.push(Box::new(b) as _) + } + dt => { + return not_impl_err!("{dt} not supported in GroupValuesColumn") + } + } + } + self.group_values = v; + } + + // tracks to which group each of the input rows belongs + groups.clear(); + + // 1.1 Calculate the group keys for the group values + let batch_hashes = &mut self.hashes_buffer; + batch_hashes.clear(); + batch_hashes.resize(n_rows, 0); + create_hashes(cols, &self.random_state, batch_hashes)?; + + for (row, &target_hash) in batch_hashes.iter().enumerate() { + let entry = self.map.get_mut(target_hash, |(exist_hash, group_idx)| { + // Somewhat surprisingly, this closure can be called even if the + // hash doesn't match, so check the hash first with an integer + // comparison first avoid the more expensive comparison with + // group value. https://github.com/apache/datafusion/pull/11718 + if target_hash != *exist_hash { + return false; + } + + fn check_row_equal( + array_row: &dyn GroupColumn, + lhs_row: usize, + array: &ArrayRef, + rhs_row: usize, + ) -> bool { + array_row.equal_to(lhs_row, array, rhs_row) + } + + for (i, group_val) in self.group_values.iter().enumerate() { + if !check_row_equal(group_val.as_ref(), *group_idx, &cols[i], row) { + return false; + } + } + + true + }); + + let group_idx = match entry { + // Existing group_index for this group value + Some((_hash, group_idx)) => *group_idx, + // 1.2 Need to create new entry for the group + None => { + // Add new entry to aggr_state and save newly created index + // let group_idx = group_values.num_rows(); + // group_values.push(group_rows.row(row)); + + let mut checklen = 0; + let group_idx = self.group_values[0].len(); + for (i, group_value) in self.group_values.iter_mut().enumerate() { + group_value.append_val(&cols[i], row); + let len = group_value.len(); + if i == 0 { + checklen = len; + } else { + debug_assert_eq!(checklen, len); + } + } + + // for hasher function, use precomputed hash value + self.map.insert_accounted( + (target_hash, group_idx), + |(hash, _group_index)| *hash, + &mut self.map_size, + ); + group_idx + } + }; + groups.push(group_idx); + } + + Ok(()) + } + + fn size(&self) -> usize { + let group_values_size: usize = self.group_values.iter().map(|v| v.size()).sum(); + group_values_size + self.map_size + self.hashes_buffer.allocated_size() + } + + fn is_empty(&self) -> bool { + self.len() == 0 + } + + fn len(&self) -> usize { + if self.group_values.is_empty() { + return 0; + } + + self.group_values[0].len() + } + + fn emit(&mut self, emit_to: EmitTo) -> Result> { + let mut output = match emit_to { + EmitTo::All => { + let group_values = std::mem::take(&mut self.group_values); + debug_assert!(self.group_values.is_empty()); + + group_values + .into_iter() + .map(|v| v.build()) + .collect::>() + } + EmitTo::First(n) => { + let output = self + .group_values + .iter_mut() + .map(|v| v.take_n(n)) + .collect::>(); + + // SAFETY: self.map outlives iterator and is not modified concurrently + unsafe { + for bucket in self.map.iter() { + // Decrement group index by n + match bucket.as_ref().1.checked_sub(n) { + // Group index was >= n, shift value down + Some(sub) => bucket.as_mut().1 = sub, + // Group index was < n, so remove from table + None => self.map.erase(bucket), + } + } + } + + output + } + }; + + // TODO: Materialize dictionaries in group keys (#7647) + for (field, array) in self.schema.fields.iter().zip(&mut output) { + let expected = field.data_type(); + if let DataType::Dictionary(_, v) = expected { + let actual = array.data_type(); + if v.as_ref() != actual { + return Err(DataFusionError::Internal(format!( + "Converted group rows expected dictionary of {v} got {actual}" + ))); + } + *array = cast(array.as_ref(), expected)?; + } + } + + Ok(output) + } + + fn clear_shrink(&mut self, batch: &RecordBatch) { + let count = batch.num_rows(); + self.group_values.clear(); + self.map.clear(); + self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared + self.map_size = self.map.capacity() * size_of::<(u64, usize)>(); + self.hashes_buffer.clear(); + self.hashes_buffer.shrink_to(count); + } +} diff --git a/datafusion/physical-plan/src/aggregates/group_values/group_column.rs b/datafusion/physical-plan/src/aggregates/group_values/group_column.rs new file mode 100644 index 000000000000..bba59b6d0caa --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/group_values/group_column.rs @@ -0,0 +1,1257 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::make_view; +use arrow::array::BufferBuilder; +use arrow::array::ByteView; +use arrow::array::GenericBinaryArray; +use arrow::array::GenericStringArray; +use arrow::array::OffsetSizeTrait; +use arrow::array::PrimitiveArray; +use arrow::array::{Array, ArrayRef, ArrowPrimitiveType, AsArray}; +use arrow::buffer::OffsetBuffer; +use arrow::buffer::ScalarBuffer; +use arrow::datatypes::ByteArrayType; +use arrow::datatypes::ByteViewType; +use arrow::datatypes::DataType; +use arrow::datatypes::GenericBinaryType; +use arrow_array::GenericByteViewArray; +use arrow_buffer::Buffer; +use datafusion_common::utils::proxy::VecAllocExt; + +use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder; +use arrow_array::types::GenericStringType; +use datafusion_physical_expr_common::binary_map::{OutputType, INITIAL_BUFFER_CAPACITY}; +use std::marker::PhantomData; +use std::mem::{replace, size_of}; +use std::sync::Arc; +use std::vec; + +const BYTE_VIEW_MAX_BLOCK_SIZE: usize = 2 * 1024 * 1024; + +/// Trait for storing a single column of group values in [`GroupValuesColumn`] +/// +/// Implementations of this trait store an in-progress collection of group values +/// (similar to various builders in Arrow-rs) that allow for quick comparison to +/// incoming rows. +/// +/// [`GroupValuesColumn`]: crate::aggregates::group_values::GroupValuesColumn +pub trait GroupColumn: Send + Sync { + /// Returns equal if the row stored in this builder at `lhs_row` is equal to + /// the row in `array` at `rhs_row` + /// + /// Note that this comparison returns true if both elements are NULL + fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool; + /// Appends the row at `row` in `array` to this builder + fn append_val(&mut self, array: &ArrayRef, row: usize); + /// Returns the number of rows stored in this builder + fn len(&self) -> usize; + /// Returns the number of bytes used by this [`GroupColumn`] + fn size(&self) -> usize; + /// Builds a new array from all of the stored rows + fn build(self: Box) -> ArrayRef; + /// Builds a new array from the first `n` stored rows, shifting the + /// remaining rows to the start of the builder + fn take_n(&mut self, n: usize) -> ArrayRef; +} + +/// An implementation of [`GroupColumn`] for primitive values +/// +/// Optimized to skip null buffer construction if the input is known to be non nullable +/// +/// # Template parameters +/// +/// `T`: the native Rust type that stores the data +/// `NULLABLE`: if the data can contain any nulls +#[derive(Debug)] +pub struct PrimitiveGroupValueBuilder { + group_values: Vec, + nulls: MaybeNullBufferBuilder, +} + +impl PrimitiveGroupValueBuilder +where + T: ArrowPrimitiveType, +{ + /// Create a new `PrimitiveGroupValueBuilder` + pub fn new() -> Self { + Self { + group_values: vec![], + nulls: MaybeNullBufferBuilder::new(), + } + } +} + +impl GroupColumn + for PrimitiveGroupValueBuilder +{ + fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool { + // Perf: skip null check (by short circuit) if input is not nullable + if NULLABLE { + let exist_null = self.nulls.is_null(lhs_row); + let input_null = array.is_null(rhs_row); + if let Some(result) = nulls_equal_to(exist_null, input_null) { + return result; + } + // Otherwise, we need to check their values + } + + self.group_values[lhs_row] == array.as_primitive::().value(rhs_row) + } + + fn append_val(&mut self, array: &ArrayRef, row: usize) { + // Perf: skip null check if input can't have nulls + if NULLABLE { + if array.is_null(row) { + self.nulls.append(true); + self.group_values.push(T::default_value()); + } else { + self.nulls.append(false); + self.group_values.push(array.as_primitive::().value(row)); + } + } else { + self.group_values.push(array.as_primitive::().value(row)); + } + } + + fn len(&self) -> usize { + self.group_values.len() + } + + fn size(&self) -> usize { + self.group_values.allocated_size() + self.nulls.allocated_size() + } + + fn build(self: Box) -> ArrayRef { + let Self { + group_values, + nulls, + } = *self; + + let nulls = nulls.build(); + if !NULLABLE { + assert!(nulls.is_none(), "unexpected nulls in non nullable input"); + } + + Arc::new(PrimitiveArray::::new( + ScalarBuffer::from(group_values), + nulls, + )) + } + + fn take_n(&mut self, n: usize) -> ArrayRef { + let first_n = self.group_values.drain(0..n).collect::>(); + + let first_n_nulls = if NULLABLE { self.nulls.take_n(n) } else { None }; + + Arc::new(PrimitiveArray::::new( + ScalarBuffer::from(first_n), + first_n_nulls, + )) + } +} + +/// An implementation of [`GroupColumn`] for binary and utf8 types. +/// +/// Stores a collection of binary or utf8 group values in a single buffer +/// in a way that allows: +/// +/// 1. Efficient comparison of incoming rows to existing rows +/// 2. Efficient construction of the final output array +pub struct ByteGroupValueBuilder +where + O: OffsetSizeTrait, +{ + output_type: OutputType, + buffer: BufferBuilder, + /// Offsets into `buffer` for each distinct value. These offsets as used + /// directly to create the final `GenericBinaryArray`. The `i`th string is + /// stored in the range `offsets[i]..offsets[i+1]` in `buffer`. Null values + /// are stored as a zero length string. + offsets: Vec, + /// Nulls + nulls: MaybeNullBufferBuilder, +} + +impl ByteGroupValueBuilder +where + O: OffsetSizeTrait, +{ + pub fn new(output_type: OutputType) -> Self { + Self { + output_type, + buffer: BufferBuilder::new(INITIAL_BUFFER_CAPACITY), + offsets: vec![O::default()], + nulls: MaybeNullBufferBuilder::new(), + } + } + + fn append_val_inner(&mut self, array: &ArrayRef, row: usize) + where + B: ByteArrayType, + { + let arr = array.as_bytes::(); + if arr.is_null(row) { + self.nulls.append(true); + // nulls need a zero length in the offset buffer + let offset = self.buffer.len(); + self.offsets.push(O::usize_as(offset)); + } else { + self.nulls.append(false); + let value: &[u8] = arr.value(row).as_ref(); + self.buffer.append_slice(value); + self.offsets.push(O::usize_as(self.buffer.len())); + } + } + + fn equal_to_inner(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool + where + B: ByteArrayType, + { + let array = array.as_bytes::(); + let exist_null = self.nulls.is_null(lhs_row); + let input_null = array.is_null(rhs_row); + if let Some(result) = nulls_equal_to(exist_null, input_null) { + return result; + } + // Otherwise, we need to check their values + self.value(lhs_row) == (array.value(rhs_row).as_ref() as &[u8]) + } + + /// return the current value of the specified row irrespective of null + pub fn value(&self, row: usize) -> &[u8] { + let l = self.offsets[row].as_usize(); + let r = self.offsets[row + 1].as_usize(); + // Safety: the offsets are constructed correctly and never decrease + unsafe { self.buffer.as_slice().get_unchecked(l..r) } + } +} + +impl GroupColumn for ByteGroupValueBuilder +where + O: OffsetSizeTrait, +{ + fn equal_to(&self, lhs_row: usize, column: &ArrayRef, rhs_row: usize) -> bool { + // Sanity array type + match self.output_type { + OutputType::Binary => { + debug_assert!(matches!( + column.data_type(), + DataType::Binary | DataType::LargeBinary + )); + self.equal_to_inner::>(lhs_row, column, rhs_row) + } + OutputType::Utf8 => { + debug_assert!(matches!( + column.data_type(), + DataType::Utf8 | DataType::LargeUtf8 + )); + self.equal_to_inner::>(lhs_row, column, rhs_row) + } + _ => unreachable!("View types should use `ArrowBytesViewMap`"), + } + } + + fn append_val(&mut self, column: &ArrayRef, row: usize) { + // Sanity array type + match self.output_type { + OutputType::Binary => { + debug_assert!(matches!( + column.data_type(), + DataType::Binary | DataType::LargeBinary + )); + self.append_val_inner::>(column, row) + } + OutputType::Utf8 => { + debug_assert!(matches!( + column.data_type(), + DataType::Utf8 | DataType::LargeUtf8 + )); + self.append_val_inner::>(column, row) + } + _ => unreachable!("View types should use `ArrowBytesViewMap`"), + }; + } + + fn len(&self) -> usize { + self.offsets.len() - 1 + } + + fn size(&self) -> usize { + self.buffer.capacity() * size_of::() + + self.offsets.allocated_size() + + self.nulls.allocated_size() + } + + fn build(self: Box) -> ArrayRef { + let Self { + output_type, + mut buffer, + offsets, + nulls, + } = *self; + + let null_buffer = nulls.build(); + + // SAFETY: the offsets were constructed correctly in `insert_if_new` -- + // monotonically increasing, overflows were checked. + let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) }; + let values = buffer.finish(); + match output_type { + OutputType::Binary => { + // SAFETY: the offsets were constructed correctly + Arc::new(unsafe { + GenericBinaryArray::new_unchecked(offsets, values, null_buffer) + }) + } + OutputType::Utf8 => { + // SAFETY: + // 1. the offsets were constructed safely + // + // 2. the input arrays were all the correct type and thus since + // all the values that went in were valid (e.g. utf8) so are all + // the values that come out + Arc::new(unsafe { + GenericStringArray::new_unchecked(offsets, values, null_buffer) + }) + } + _ => unreachable!("View types should use `ArrowBytesViewMap`"), + } + } + + fn take_n(&mut self, n: usize) -> ArrayRef { + debug_assert!(self.len() >= n); + let null_buffer = self.nulls.take_n(n); + let first_remaining_offset = O::as_usize(self.offsets[n]); + + // Given offests like [0, 2, 4, 5] and n = 1, we expect to get + // offsets [0, 2, 3]. We first create two offsets for first_n as [0, 2] and the remaining as [2, 4, 5]. + // And we shift the offset starting from 0 for the remaining one, [2, 4, 5] -> [0, 2, 3]. + let mut first_n_offsets = self.offsets.drain(0..n).collect::>(); + let offset_n = *self.offsets.first().unwrap(); + self.offsets + .iter_mut() + .for_each(|offset| *offset = offset.sub(offset_n)); + first_n_offsets.push(offset_n); + + // SAFETY: the offsets were constructed correctly in `insert_if_new` -- + // monotonically increasing, overflows were checked. + let offsets = + unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(first_n_offsets)) }; + + let mut remaining_buffer = + BufferBuilder::new(self.buffer.len() - first_remaining_offset); + // TODO: Current approach copy the remaining and truncate the original one + // Find out a way to avoid copying buffer but split the original one into two. + remaining_buffer.append_slice(&self.buffer.as_slice()[first_remaining_offset..]); + self.buffer.truncate(first_remaining_offset); + let values = self.buffer.finish(); + self.buffer = remaining_buffer; + + match self.output_type { + OutputType::Binary => { + // SAFETY: the offsets were constructed correctly + Arc::new(unsafe { + GenericBinaryArray::new_unchecked(offsets, values, null_buffer) + }) + } + OutputType::Utf8 => { + // SAFETY: + // 1. the offsets were constructed safely + // + // 2. we asserted the input arrays were all the correct type and + // thus since all the values that went in were valid (e.g. utf8) + // so are all the values that come out + Arc::new(unsafe { + GenericStringArray::new_unchecked(offsets, values, null_buffer) + }) + } + _ => unreachable!("View types should use `ArrowBytesViewMap`"), + } + } +} + +/// An implementation of [`GroupColumn`] for binary view and utf8 view types. +/// +/// Stores a collection of binary view or utf8 view group values in a buffer +/// whose structure is similar to `GenericByteViewArray`, and we can get benefits: +/// +/// 1. Efficient comparison of incoming rows to existing rows +/// 2. Efficient construction of the final output array +/// 3. Efficient to perform `take_n` comparing to use `GenericByteViewBuilder` +pub struct ByteViewGroupValueBuilder { + /// The views of string values + /// + /// If string len <= 12, the view's format will be: + /// string(12B) | len(4B) + /// + /// If string len > 12, its format will be: + /// offset(4B) | buffer_index(4B) | prefix(4B) | len(4B) + views: Vec, + + /// The progressing block + /// + /// New values will be inserted into it until its capacity + /// is not enough(detail can see `max_block_size`). + in_progress: Vec, + + /// The completed blocks + completed: Vec, + + /// The max size of `in_progress` + /// + /// `in_progress` will be flushed into `completed`, and create new `in_progress` + /// when found its remaining capacity(`max_block_size` - `len(in_progress)`), + /// is no enough to store the appended value. + /// + /// Currently it is fixed at 2MB. + max_block_size: usize, + + /// Nulls + nulls: MaybeNullBufferBuilder, + + /// phantom data so the type requires `` + _phantom: PhantomData, +} + +impl ByteViewGroupValueBuilder { + pub fn new() -> Self { + Self { + views: Vec::new(), + in_progress: Vec::new(), + completed: Vec::new(), + max_block_size: BYTE_VIEW_MAX_BLOCK_SIZE, + nulls: MaybeNullBufferBuilder::new(), + _phantom: PhantomData {}, + } + } + + /// Set the max block size + fn with_max_block_size(mut self, max_block_size: usize) -> Self { + self.max_block_size = max_block_size; + self + } + + fn append_val_inner(&mut self, array: &ArrayRef, row: usize) + where + B: ByteViewType, + { + let arr = array.as_byte_view::(); + + // Null row case, set and return + if arr.is_null(row) { + self.nulls.append(true); + self.views.push(0); + return; + } + + // Not null row case + self.nulls.append(false); + let value: &[u8] = arr.value(row).as_ref(); + + let value_len = value.len(); + let view = if value_len <= 12 { + make_view(value, 0, 0) + } else { + // Ensure big enough block to hold the value firstly + self.ensure_in_progress_big_enough(value_len); + + // Append value + let buffer_index = self.completed.len(); + let offset = self.in_progress.len(); + self.in_progress.extend_from_slice(value); + + make_view(value, buffer_index as u32, offset as u32) + }; + + // Append view + self.views.push(view); + } + + fn ensure_in_progress_big_enough(&mut self, value_len: usize) { + debug_assert!(value_len > 12); + let require_cap = self.in_progress.len() + value_len; + + // If current block isn't big enough, flush it and create a new in progress block + if require_cap > self.max_block_size { + let flushed_block = replace( + &mut self.in_progress, + Vec::with_capacity(self.max_block_size), + ); + let buffer = Buffer::from_vec(flushed_block); + self.completed.push(buffer); + } + } + + fn equal_to_inner(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool { + let array = array.as_byte_view::(); + + // Check if nulls equal firstly + let exist_null = self.nulls.is_null(lhs_row); + let input_null = array.is_null(rhs_row); + if let Some(result) = nulls_equal_to(exist_null, input_null) { + return result; + } + + // Otherwise, we need to check their values + let exist_view = self.views[lhs_row]; + let exist_view_len = exist_view as u32; + + let input_view = array.views()[rhs_row]; + let input_view_len = input_view as u32; + + // The check logic + // - Check len equality + // - If inlined, check inlined value + // - If non-inlined, check prefix and then check value in buffer + // when needed + if exist_view_len != input_view_len { + return false; + } + + if exist_view_len <= 12 { + let exist_inline = unsafe { + GenericByteViewArray::::inline_value( + &exist_view, + exist_view_len as usize, + ) + }; + let input_inline = unsafe { + GenericByteViewArray::::inline_value( + &input_view, + input_view_len as usize, + ) + }; + exist_inline == input_inline + } else { + let exist_prefix = + unsafe { GenericByteViewArray::::inline_value(&exist_view, 4) }; + let input_prefix = + unsafe { GenericByteViewArray::::inline_value(&input_view, 4) }; + + if exist_prefix != input_prefix { + return false; + } + + let exist_full = { + let byte_view = ByteView::from(exist_view); + self.value( + byte_view.buffer_index as usize, + byte_view.offset as usize, + byte_view.length as usize, + ) + }; + let input_full: &[u8] = unsafe { array.value_unchecked(rhs_row).as_ref() }; + exist_full == input_full + } + } + + fn value(&self, buffer_index: usize, offset: usize, length: usize) -> &[u8] { + debug_assert!(buffer_index <= self.completed.len()); + + if buffer_index < self.completed.len() { + let block = &self.completed[buffer_index]; + &block[offset..offset + length] + } else { + &self.in_progress[offset..offset + length] + } + } + + fn build_inner(self) -> ArrayRef { + let Self { + views, + in_progress, + mut completed, + nulls, + .. + } = self; + + // Build nulls + let null_buffer = nulls.build(); + + // Build values + // Flush `in_process` firstly + if !in_progress.is_empty() { + let buffer = Buffer::from(in_progress); + completed.push(buffer); + } + + let views = ScalarBuffer::from(views); + + // Safety: + // * all views were correctly made + // * (if utf8): Input was valid Utf8 so buffer contents are + // valid utf8 as well + unsafe { + Arc::new(GenericByteViewArray::::new_unchecked( + views, + completed, + null_buffer, + )) + } + } + + fn take_n_inner(&mut self, n: usize) -> ArrayRef { + debug_assert!(self.len() >= n); + + // The `n == len` case, we need to take all + if self.len() == n { + let new_builder = Self::new().with_max_block_size(self.max_block_size); + let cur_builder = replace(self, new_builder); + return cur_builder.build_inner(); + } + + // The `n < len` case + // Take n for nulls + let null_buffer = self.nulls.take_n(n); + + // Take n for values: + // - Take first n `view`s from `views` + // + // - Find the last non-inlined `view`, if all inlined, + // we can build array and return happily, otherwise we + // we need to continue to process related buffers + // + // - Get the last related `buffer index`(let's name it `buffer index n`) + // from last non-inlined `view` + // + // - Take buffers, the key is that we need to know if we need to take + // the whole last related buffer. The logic is a bit complex, you can + // detail in `take_buffers_with_whole_last`, `take_buffers_with_partial_last` + // and other related steps in following + // + // - Shift the `buffer index` of remaining non-inlined `views` + // + let first_n_views = self.views.drain(0..n).collect::>(); + + let last_non_inlined_view = first_n_views + .iter() + .rev() + .find(|view| ((**view) as u32) > 12); + + // All taken views inlined + let Some(view) = last_non_inlined_view else { + let views = ScalarBuffer::from(first_n_views); + + // Safety: + // * all views were correctly made + // * (if utf8): Input was valid Utf8 so buffer contents are + // valid utf8 as well + unsafe { + return Arc::new(GenericByteViewArray::::new_unchecked( + views, + Vec::new(), + null_buffer, + )); + } + }; + + // Unfortunately, some taken views non-inlined + let view = ByteView::from(*view); + let last_remaining_buffer_index = view.buffer_index as usize; + + // Check should we take the whole `last_remaining_buffer_index` buffer + let take_whole_last_buffer = self.should_take_whole_buffer( + last_remaining_buffer_index, + (view.offset + view.length) as usize, + ); + + // Take related buffers + let buffers = if take_whole_last_buffer { + self.take_buffers_with_whole_last(last_remaining_buffer_index) + } else { + self.take_buffers_with_partial_last( + last_remaining_buffer_index, + (view.offset + view.length) as usize, + ) + }; + + // Shift `buffer index`s finally + let shifts = if take_whole_last_buffer { + last_remaining_buffer_index + 1 + } else { + last_remaining_buffer_index + }; + + self.views.iter_mut().for_each(|view| { + if (*view as u32) > 12 { + let mut byte_view = ByteView::from(*view); + byte_view.buffer_index -= shifts as u32; + *view = byte_view.as_u128(); + } + }); + + // Build array and return + let views = ScalarBuffer::from(first_n_views); + + // Safety: + // * all views were correctly made + // * (if utf8): Input was valid Utf8 so buffer contents are + // valid utf8 as well + unsafe { + Arc::new(GenericByteViewArray::::new_unchecked( + views, + buffers, + null_buffer, + )) + } + } + + fn take_buffers_with_whole_last( + &mut self, + last_remaining_buffer_index: usize, + ) -> Vec { + if last_remaining_buffer_index == self.completed.len() { + self.flush_in_progress(); + } + self.completed + .drain(0..last_remaining_buffer_index + 1) + .collect() + } + + fn take_buffers_with_partial_last( + &mut self, + last_remaining_buffer_index: usize, + last_take_len: usize, + ) -> Vec { + let mut take_buffers = Vec::with_capacity(last_remaining_buffer_index + 1); + + // Take `0 ~ last_remaining_buffer_index - 1` buffers + if !self.completed.is_empty() || last_remaining_buffer_index == 0 { + take_buffers.extend(self.completed.drain(0..last_remaining_buffer_index)); + } + + // Process the `last_remaining_buffer_index` buffers + let last_buffer = if last_remaining_buffer_index < self.completed.len() { + // If it is in `completed`, simply clone + self.completed[last_remaining_buffer_index].clone() + } else { + // If it is `in_progress`, copied `0 ~ offset` part + let taken_last_buffer = self.in_progress[0..last_take_len].to_vec(); + Buffer::from_vec(taken_last_buffer) + }; + take_buffers.push(last_buffer); + + take_buffers + } + + #[inline] + fn should_take_whole_buffer(&self, buffer_index: usize, take_len: usize) -> bool { + if buffer_index < self.completed.len() { + take_len == self.completed[buffer_index].len() + } else { + take_len == self.in_progress.len() + } + } + + fn flush_in_progress(&mut self) { + let flushed_block = replace( + &mut self.in_progress, + Vec::with_capacity(self.max_block_size), + ); + let buffer = Buffer::from_vec(flushed_block); + self.completed.push(buffer); + } +} + +impl GroupColumn for ByteViewGroupValueBuilder { + fn equal_to(&self, lhs_row: usize, array: &ArrayRef, rhs_row: usize) -> bool { + self.equal_to_inner(lhs_row, array, rhs_row) + } + + fn append_val(&mut self, array: &ArrayRef, row: usize) { + self.append_val_inner(array, row) + } + + fn len(&self) -> usize { + self.views.len() + } + + fn size(&self) -> usize { + let buffers_size = self + .completed + .iter() + .map(|buf| buf.capacity() * size_of::()) + .sum::(); + + self.nulls.allocated_size() + + self.views.capacity() * size_of::() + + self.in_progress.capacity() * size_of::() + + buffers_size + + size_of::() + } + + fn build(self: Box) -> ArrayRef { + Self::build_inner(*self) + } + + fn take_n(&mut self, n: usize) -> ArrayRef { + self.take_n_inner(n) + } +} + +/// Determines if the nullability of the existing and new input array can be used +/// to short-circuit the comparison of the two values. +/// +/// Returns `Some(result)` if the result of the comparison can be determined +/// from the nullness of the two values, and `None` if the comparison must be +/// done on the values themselves. +fn nulls_equal_to(lhs_null: bool, rhs_null: bool) -> Option { + match (lhs_null, rhs_null) { + (true, true) => Some(true), + (false, true) | (true, false) => Some(false), + _ => None, + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::{ + array::AsArray, + datatypes::{Int64Type, StringViewType}, + }; + use arrow_array::{ArrayRef, Int64Array, StringArray, StringViewArray}; + use arrow_buffer::{BooleanBufferBuilder, NullBuffer}; + use datafusion_physical_expr::binary_map::OutputType; + + use crate::aggregates::group_values::group_column::{ + ByteViewGroupValueBuilder, PrimitiveGroupValueBuilder, + }; + + use super::{ByteGroupValueBuilder, GroupColumn}; + + #[test] + fn test_take_n() { + let mut builder = ByteGroupValueBuilder::::new(OutputType::Utf8); + let array = Arc::new(StringArray::from(vec![Some("a"), None])) as ArrayRef; + // a, null, null + builder.append_val(&array, 0); + builder.append_val(&array, 1); + builder.append_val(&array, 1); + + // (a, null) remaining: null + let output = builder.take_n(2); + assert_eq!(&output, &array); + + // null, a, null, a + builder.append_val(&array, 0); + builder.append_val(&array, 1); + builder.append_val(&array, 0); + + // (null, a) remaining: (null, a) + let output = builder.take_n(2); + let array = Arc::new(StringArray::from(vec![None, Some("a")])) as ArrayRef; + assert_eq!(&output, &array); + + let array = Arc::new(StringArray::from(vec![ + Some("a"), + None, + Some("longstringfortest"), + ])) as ArrayRef; + + // null, a, longstringfortest, null, null + builder.append_val(&array, 2); + builder.append_val(&array, 1); + builder.append_val(&array, 1); + + // (null, a, longstringfortest, null) remaining: (null) + let output = builder.take_n(4); + let array = Arc::new(StringArray::from(vec![ + None, + Some("a"), + Some("longstringfortest"), + None, + ])) as ArrayRef; + assert_eq!(&output, &array); + } + + #[test] + fn test_nullable_primitive_equal_to() { + // Will cover such cases: + // - exist null, input not null + // - exist null, input null; values not equal + // - exist null, input null; values equal + // - exist not null, input null + // - exist not null, input not null; values not equal + // - exist not null, input not null; values equal + + // Define PrimitiveGroupValueBuilder + let mut builder = PrimitiveGroupValueBuilder::::new(); + let builder_array = Arc::new(Int64Array::from(vec![ + None, + None, + None, + Some(1), + Some(2), + Some(3), + ])) as ArrayRef; + builder.append_val(&builder_array, 0); + builder.append_val(&builder_array, 1); + builder.append_val(&builder_array, 2); + builder.append_val(&builder_array, 3); + builder.append_val(&builder_array, 4); + builder.append_val(&builder_array, 5); + + // Define input array + let (_nulls, values, _) = + Int64Array::from(vec![Some(1), Some(2), None, None, Some(1), Some(3)]) + .into_parts(); + + // explicitly build a boolean buffer where one of the null values also happens to match + let mut boolean_buffer_builder = BooleanBufferBuilder::new(6); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(false); // this sets Some(2) to null above + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + let nulls = NullBuffer::new(boolean_buffer_builder.finish()); + let input_array = Arc::new(Int64Array::new(values, Some(nulls))) as ArrayRef; + + // Check + assert!(!builder.equal_to(0, &input_array, 0)); + assert!(builder.equal_to(1, &input_array, 1)); + assert!(builder.equal_to(2, &input_array, 2)); + assert!(!builder.equal_to(3, &input_array, 3)); + assert!(!builder.equal_to(4, &input_array, 4)); + assert!(builder.equal_to(5, &input_array, 5)); + } + + #[test] + fn test_not_nullable_primitive_equal_to() { + // Will cover such cases: + // - values equal + // - values not equal + + // Define PrimitiveGroupValueBuilder + let mut builder = PrimitiveGroupValueBuilder::::new(); + let builder_array = + Arc::new(Int64Array::from(vec![Some(0), Some(1)])) as ArrayRef; + builder.append_val(&builder_array, 0); + builder.append_val(&builder_array, 1); + + // Define input array + let input_array = Arc::new(Int64Array::from(vec![Some(0), Some(2)])) as ArrayRef; + + // Check + assert!(builder.equal_to(0, &input_array, 0)); + assert!(!builder.equal_to(1, &input_array, 1)); + } + + #[test] + fn test_byte_array_equal_to() { + // Will cover such cases: + // - exist null, input not null + // - exist null, input null; values not equal + // - exist null, input null; values equal + // - exist not null, input null + // - exist not null, input not null; values not equal + // - exist not null, input not null; values equal + + // Define PrimitiveGroupValueBuilder + let mut builder = ByteGroupValueBuilder::::new(OutputType::Utf8); + let builder_array = Arc::new(StringArray::from(vec![ + None, + None, + None, + Some("foo"), + Some("bar"), + Some("baz"), + ])) as ArrayRef; + builder.append_val(&builder_array, 0); + builder.append_val(&builder_array, 1); + builder.append_val(&builder_array, 2); + builder.append_val(&builder_array, 3); + builder.append_val(&builder_array, 4); + builder.append_val(&builder_array, 5); + + // Define input array + let (offsets, buffer, _nulls) = StringArray::from(vec![ + Some("foo"), + Some("bar"), + None, + None, + Some("foo"), + Some("baz"), + ]) + .into_parts(); + + // explicitly build a boolean buffer where one of the null values also happens to match + let mut boolean_buffer_builder = BooleanBufferBuilder::new(6); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(false); // this sets Some("bar") to null above + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + let nulls = NullBuffer::new(boolean_buffer_builder.finish()); + let input_array = + Arc::new(StringArray::new(offsets, buffer, Some(nulls))) as ArrayRef; + + // Check + assert!(!builder.equal_to(0, &input_array, 0)); + assert!(builder.equal_to(1, &input_array, 1)); + assert!(builder.equal_to(2, &input_array, 2)); + assert!(!builder.equal_to(3, &input_array, 3)); + assert!(!builder.equal_to(4, &input_array, 4)); + assert!(builder.equal_to(5, &input_array, 5)); + } + + #[test] + fn test_byte_view_append_val() { + let mut builder = + ByteViewGroupValueBuilder::::new().with_max_block_size(60); + let builder_array = StringViewArray::from(vec![ + Some("this string is quite long"), // in buffer 0 + Some("foo"), + None, + Some("bar"), + Some("this string is also quite long"), // buffer 0 + Some("this string is quite long"), // buffer 1 + Some("bar"), + ]); + let builder_array: ArrayRef = Arc::new(builder_array); + for row in 0..builder_array.len() { + builder.append_val(&builder_array, row); + } + + let output = Box::new(builder).build(); + // should be 2 output buffers to hold all the data + assert_eq!(output.as_string_view().data_buffers().len(), 2,); + assert_eq!(&output, &builder_array) + } + + #[test] + fn test_byte_view_equal_to() { + // Will cover such cases: + // - exist null, input not null + // - exist null, input null; values not equal + // - exist null, input null; values equal + // - exist not null, input null + // - exist not null, input not null; value lens not equal + // - exist not null, input not null; value not equal(inlined case) + // - exist not null, input not null; value equal(inlined case) + // + // - exist not null, input not null; value not equal + // (non-inlined case + prefix not equal) + // + // - exist not null, input not null; value not equal + // (non-inlined case + value in `completed`) + // + // - exist not null, input not null; value equal + // (non-inlined case + value in `completed`) + // + // - exist not null, input not null; value not equal + // (non-inlined case + value in `in_progress`) + // + // - exist not null, input not null; value equal + // (non-inlined case + value in `in_progress`) + + // Set the block size to 40 for ensuring some unlined values are in `in_progress`, + // and some are in `completed`, so both two branches in `value` function can be covered. + let mut builder = + ByteViewGroupValueBuilder::::new().with_max_block_size(60); + let builder_array = Arc::new(StringViewArray::from(vec![ + None, + None, + None, + Some("foo"), + Some("bazz"), + Some("foo"), + Some("bar"), + Some("I am a long string for test eq in completed"), + Some("I am a long string for test eq in progress"), + ])) as ArrayRef; + builder.append_val(&builder_array, 0); + builder.append_val(&builder_array, 1); + builder.append_val(&builder_array, 2); + builder.append_val(&builder_array, 3); + builder.append_val(&builder_array, 4); + builder.append_val(&builder_array, 5); + builder.append_val(&builder_array, 6); + builder.append_val(&builder_array, 7); + builder.append_val(&builder_array, 8); + + // Define input array + let (views, buffer, _nulls) = StringViewArray::from(vec![ + Some("foo"), + Some("bar"), // set to null + None, + None, + Some("baz"), + Some("oof"), + Some("bar"), + Some("i am a long string for test eq in completed"), + Some("I am a long string for test eq in COMPLETED"), + Some("I am a long string for test eq in completed"), + Some("I am a long string for test eq in PROGRESS"), + Some("I am a long string for test eq in progress"), + ]) + .into_parts(); + + // explicitly build a boolean buffer where one of the null values also happens to match + let mut boolean_buffer_builder = BooleanBufferBuilder::new(9); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(false); // this sets Some("bar") to null above + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(false); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + boolean_buffer_builder.append(true); + let nulls = NullBuffer::new(boolean_buffer_builder.finish()); + let input_array = + Arc::new(StringViewArray::new(views, buffer, Some(nulls))) as ArrayRef; + + // Check + assert!(!builder.equal_to(0, &input_array, 0)); + assert!(builder.equal_to(1, &input_array, 1)); + assert!(builder.equal_to(2, &input_array, 2)); + assert!(!builder.equal_to(3, &input_array, 3)); + assert!(!builder.equal_to(4, &input_array, 4)); + assert!(!builder.equal_to(5, &input_array, 5)); + assert!(builder.equal_to(6, &input_array, 6)); + assert!(!builder.equal_to(7, &input_array, 7)); + assert!(!builder.equal_to(7, &input_array, 8)); + assert!(builder.equal_to(7, &input_array, 9)); + assert!(!builder.equal_to(8, &input_array, 10)); + assert!(builder.equal_to(8, &input_array, 11)); + } + + #[test] + fn test_byte_view_take_n() { + // ####### Define cases and init ####### + + // `take_n` is really complex, we should consider and test following situations: + // 1. Take nulls + // 2. Take all `inlined`s + // 3. Take non-inlined + partial last buffer in `completed` + // 4. Take non-inlined + whole last buffer in `completed` + // 5. Take non-inlined + partial last `in_progress` + // 6. Take non-inlined + whole last buffer in `in_progress` + // 7. Take all views at once + + let mut builder = + ByteViewGroupValueBuilder::::new().with_max_block_size(60); + let input_array = StringViewArray::from(vec![ + // Test situation 1 + None, + None, + // Test situation 2 (also test take null together) + None, + Some("foo"), + Some("bar"), + // Test situation 3 (also test take null + inlined) + None, + Some("foo"), + Some("this string is quite long"), + Some("this string is also quite long"), + // Test situation 4 (also test take null + inlined) + None, + Some("bar"), + Some("this string is quite long"), + // Test situation 5 (also test take null + inlined) + None, + Some("foo"), + Some("another string that is is quite long"), + Some("this string not so long"), + // Test situation 6 (also test take null + inlined + insert again after taking) + None, + Some("bar"), + Some("this string is quite long"), + // Insert 4 and just take 3 to ensure it will go the path of situation 6 + None, + // Finally, we create a new builder, insert the whole array and then + // take whole at once for testing situation 7 + ]); + + let input_array: ArrayRef = Arc::new(input_array); + let first_ones_to_append = 16; // For testing situation 1~5 + let second_ones_to_append = 4; // For testing situation 6 + let final_ones_to_append = input_array.len(); // For testing situation 7 + + // ####### Test situation 1~5 ####### + for row in 0..first_ones_to_append { + builder.append_val(&input_array, row); + } + + assert_eq!(builder.completed.len(), 2); + assert_eq!(builder.in_progress.len(), 59); + + // Situation 1 + let taken_array = builder.take_n(2); + assert_eq!(&taken_array, &input_array.slice(0, 2)); + + // Situation 2 + let taken_array = builder.take_n(3); + assert_eq!(&taken_array, &input_array.slice(2, 3)); + + // Situation 3 + let taken_array = builder.take_n(3); + assert_eq!(&taken_array, &input_array.slice(5, 3)); + + let taken_array = builder.take_n(1); + assert_eq!(&taken_array, &input_array.slice(8, 1)); + + // Situation 4 + let taken_array = builder.take_n(3); + assert_eq!(&taken_array, &input_array.slice(9, 3)); + + // Situation 5 + let taken_array = builder.take_n(3); + assert_eq!(&taken_array, &input_array.slice(12, 3)); + + let taken_array = builder.take_n(1); + assert_eq!(&taken_array, &input_array.slice(15, 1)); + + // ####### Test situation 6 ####### + assert!(builder.completed.is_empty()); + assert!(builder.in_progress.is_empty()); + assert!(builder.views.is_empty()); + + for row in first_ones_to_append..first_ones_to_append + second_ones_to_append { + builder.append_val(&input_array, row); + } + + assert!(builder.completed.is_empty()); + assert_eq!(builder.in_progress.len(), 25); + + let taken_array = builder.take_n(3); + assert_eq!(&taken_array, &input_array.slice(16, 3)); + + // ####### Test situation 7 ####### + // Create a new builder + let mut builder = + ByteViewGroupValueBuilder::::new().with_max_block_size(60); + + for row in 0..final_ones_to_append { + builder.append_val(&input_array, row); + } + + assert_eq!(builder.completed.len(), 3); + assert_eq!(builder.in_progress.len(), 25); + + let taken_array = builder.take_n(final_ones_to_append); + assert_eq!(&taken_array, &input_array); + } +} diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index b5bc923b467d..fb7b66775092 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -15,34 +15,86 @@ // specific language governing permissions and limitations // under the License. +//! [`GroupValues`] trait for storing and interning group keys + use arrow::record_batch::RecordBatch; use arrow_array::{downcast_primitive, ArrayRef}; use arrow_schema::{DataType, SchemaRef}; +use bytes_view::GroupValuesBytesView; use datafusion_common::Result; pub(crate) mod primitive; use datafusion_expr::EmitTo; use primitive::GroupValuesPrimitive; +mod column; mod row; +use column::GroupValuesColumn; use row::GroupValuesRows; mod bytes; +mod bytes_view; use bytes::GroupValuesByes; use datafusion_physical_expr::binary_map::OutputType; -/// An interning store for group keys +mod group_column; +mod null_builder; + +/// Stores the group values during hash aggregation. +/// +/// # Background +/// +/// In a query such as `SELECT a, b, count(*) FROM t GROUP BY a, b`, the group values +/// identify each group, and correspond to all the distinct values of `(a,b)`. +/// +/// ```sql +/// -- Input has 4 rows with 3 distinct combinations of (a,b) ("groups") +/// create table t(a int, b varchar) +/// as values (1, 'a'), (2, 'b'), (1, 'a'), (3, 'c'); +/// +/// select a, b, count(*) from t group by a, b; +/// ---- +/// 1 a 2 +/// 2 b 1 +/// 3 c 1 +/// ``` +/// +/// # Design +/// +/// Managing group values is a performance critical operation in hash +/// aggregation. The major operations are: +/// +/// 1. Intern: Quickly finding existing and adding new group values +/// 2. Emit: Returning the group values as an array +/// +/// There are multiple specialized implementations of this trait optimized for +/// different data types and number of columns, optimized for these operations. +/// See [`new_group_values`] for details. +/// +/// # Group Ids +/// +/// Each distinct group in a hash aggregation is identified by a unique group id +/// (usize) which is assigned by instances of this trait. Group ids are +/// continuous without gaps, starting from 0. pub trait GroupValues: Send { - /// Calculates the `groups` for each input row of `cols` + /// Calculates the group id for each input row of `cols`, assigning new + /// group ids as necessary. + /// + /// When the function returns, `groups` must contain the group id for each + /// row in `cols`. + /// + /// If a row has the same value as a previous row, the same group id is + /// assigned. If a row has a new value, the next available group id is + /// assigned. fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()>; - /// Returns the number of bytes used by this [`GroupValues`] + /// Returns the number of bytes of memory used by this [`GroupValues`] fn size(&self) -> usize; /// Returns true if this [`GroupValues`] is empty fn is_empty(&self) -> bool; - /// The number of values stored in this [`GroupValues`] + /// The number of values (distinct group values) stored in this [`GroupValues`] fn len(&self) -> usize; /// Emits the group values @@ -52,6 +104,7 @@ pub trait GroupValues: Send { fn clear_shrink(&mut self, batch: &RecordBatch); } +/// Return a specialized implementation of [`GroupValues`] for the given schema. pub fn new_group_values(schema: SchemaRef) -> Result> { if schema.fields.len() == 1 { let d = schema.fields[0].data_type(); @@ -67,19 +120,32 @@ pub fn new_group_values(schema: SchemaRef) -> Result> { _ => {} } - if let DataType::Utf8 = d { - return Ok(Box::new(GroupValuesByes::::new(OutputType::Utf8))); - } - if let DataType::LargeUtf8 = d { - return Ok(Box::new(GroupValuesByes::::new(OutputType::Utf8))); - } - if let DataType::Binary = d { - return Ok(Box::new(GroupValuesByes::::new(OutputType::Binary))); - } - if let DataType::LargeBinary = d { - return Ok(Box::new(GroupValuesByes::::new(OutputType::Binary))); + match d { + DataType::Utf8 => { + return Ok(Box::new(GroupValuesByes::::new(OutputType::Utf8))); + } + DataType::LargeUtf8 => { + return Ok(Box::new(GroupValuesByes::::new(OutputType::Utf8))); + } + DataType::Utf8View => { + return Ok(Box::new(GroupValuesBytesView::new(OutputType::Utf8View))); + } + DataType::Binary => { + return Ok(Box::new(GroupValuesByes::::new(OutputType::Binary))); + } + DataType::LargeBinary => { + return Ok(Box::new(GroupValuesByes::::new(OutputType::Binary))); + } + DataType::BinaryView => { + return Ok(Box::new(GroupValuesBytesView::new(OutputType::BinaryView))); + } + _ => {} } } - Ok(Box::new(GroupValuesRows::try_new(schema)?)) + if GroupValuesColumn::supported_schema(schema.as_ref()) { + Ok(Box::new(GroupValuesColumn::try_new(schema)?)) + } else { + Ok(Box::new(GroupValuesRows::try_new(schema)?)) + } } diff --git a/datafusion/physical-plan/src/aggregates/group_values/null_builder.rs b/datafusion/physical-plan/src/aggregates/group_values/null_builder.rs new file mode 100644 index 000000000000..0249390f38cd --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/group_values/null_builder.rs @@ -0,0 +1,115 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_buffer::{BooleanBufferBuilder, NullBuffer}; + +/// Builder for an (optional) null mask +/// +/// Optimized for avoid creating the bitmask when all values are non-null +#[derive(Debug)] +pub(crate) enum MaybeNullBufferBuilder { + /// seen `row_count` rows but no nulls yet + NoNulls { row_count: usize }, + /// have at least one null value + /// + /// Note this is an Arrow *VALIDITY* buffer (so it is false for nulls, true + /// for non-nulls) + Nulls(BooleanBufferBuilder), +} + +impl MaybeNullBufferBuilder { + /// Create a new builder + pub fn new() -> Self { + Self::NoNulls { row_count: 0 } + } + + /// Return true if the row at index `row` is null + pub fn is_null(&self, row: usize) -> bool { + match self { + Self::NoNulls { .. } => false, + // validity mask means a unset bit is NULL + Self::Nulls(builder) => !builder.get_bit(row), + } + } + + /// Set the nullness of the next row to `is_null` + /// + /// num_values is the current length of the rows being tracked + /// + /// If `value` is true, the row is null. + /// If `value` is false, the row is non null + pub fn append(&mut self, is_null: bool) { + match self { + Self::NoNulls { row_count } if is_null => { + // have seen no nulls so far, this is the first null, + // need to create the nulls buffer for all currently valid values + // alloc 2x the need given we push a new but immediately + let mut nulls = BooleanBufferBuilder::new(*row_count * 2); + nulls.append_n(*row_count, true); + nulls.append(false); + *self = Self::Nulls(nulls); + } + Self::NoNulls { row_count } => { + *row_count += 1; + } + Self::Nulls(builder) => builder.append(!is_null), + } + } + + /// return the number of heap allocated bytes used by this structure to store boolean values + pub fn allocated_size(&self) -> usize { + match self { + Self::NoNulls { .. } => 0, + // BooleanBufferBuilder builder::capacity returns capacity in bits (not bytes) + Self::Nulls(builder) => builder.capacity() / 8, + } + } + + /// Return a NullBuffer representing the accumulated nulls so far + pub fn build(self) -> Option { + match self { + Self::NoNulls { .. } => None, + Self::Nulls(mut builder) => Some(NullBuffer::from(builder.finish())), + } + } + + /// Returns a NullBuffer representing the first `n` rows accumulated so far + /// shifting any remaining down by `n` + pub fn take_n(&mut self, n: usize) -> Option { + match self { + Self::NoNulls { row_count } => { + *row_count -= n; + None + } + Self::Nulls(builder) => { + // Copy over the values at n..len-1 values to the start of a + // new builder and leave it in self + // + // TODO: it would be great to use something like `set_bits` from arrow here. + let mut new_builder = BooleanBufferBuilder::new(builder.len()); + for i in n..builder.len() { + new_builder.append(builder.get_bit(i)); + } + std::mem::swap(&mut new_builder, builder); + + // take only first n values from the original builder + new_builder.truncate(n); + Some(NullBuffer::from(new_builder.finish())) + } + } + } +} diff --git a/datafusion/physical-plan/src/aggregates/group_values/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/primitive.rs index 18d20f3c47e6..05214ec10d68 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/primitive.rs @@ -23,12 +23,14 @@ use arrow::datatypes::i256; use arrow::record_batch::RecordBatch; use arrow_array::cast::AsArray; use arrow_array::{ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray}; +use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use arrow_schema::DataType; use datafusion_common::Result; use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_expr::EmitTo; use half::f16; use hashbrown::raw::RawTable; +use std::mem::size_of; use std::sync::Arc; /// A trait to allow hashing of floating point numbers @@ -53,6 +55,7 @@ macro_rules! hash_integer { } hash_integer!(i8, i16, i32, i64, i128, i256); hash_integer!(u8, u16, u32, u64); +hash_integer!(IntervalDayTime, IntervalMonthDayNano); macro_rules! hash_float { ($($t:ty),+) => { @@ -149,7 +152,7 @@ where } fn size(&self) -> usize { - self.map.capacity() * std::mem::size_of::() + self.values.allocated_size() + self.map.capacity() * size_of::() + self.values.allocated_size() } fn is_empty(&self) -> bool { diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index 3b7480cd292a..de0ae2e07dd2 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -20,15 +20,24 @@ use ahash::RandomState; use arrow::compute::cast; use arrow::record_batch::RecordBatch; use arrow::row::{RowConverter, Rows, SortField}; -use arrow_array::{Array, ArrayRef}; +use arrow_array::{Array, ArrayRef, ListArray, StructArray}; use arrow_schema::{DataType, SchemaRef}; use datafusion_common::hash_utils::create_hashes; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::Result; use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; use datafusion_expr::EmitTo; use hashbrown::raw::RawTable; +use std::mem::size_of; +use std::sync::Arc; /// A [`GroupValues`] making use of [`Rows`] +/// +/// This is a general implementation of [`GroupValues`] that works for any +/// combination of data types and number of columns, including nested types such as +/// structs and lists. +/// +/// It uses the arrow-rs [`Rows`] to store the group values, which is a row-wise +/// representation. pub struct GroupValuesRows { /// The output schema schema: SchemaRef, @@ -59,9 +68,12 @@ pub struct GroupValuesRows { /// [`Row`]: arrow::row::Row group_values: Option, - // buffer to be reused to store hashes + /// reused buffer to store hashes hashes_buffer: Vec, + /// reused buffer to store rows + rows_buffer: Rows, + /// Random state for creating hashes random_state: RandomState, } @@ -78,6 +90,11 @@ impl GroupValuesRows { let map = RawTable::with_capacity(0); + let starting_rows_capacity = 1000; + + let starting_data_capacity = 64 * starting_rows_capacity; + let rows_buffer = + row_converter.empty_rows(starting_rows_capacity, starting_data_capacity); Ok(Self { schema, row_converter, @@ -85,6 +102,7 @@ impl GroupValuesRows { map_size: 0, group_values: None, hashes_buffer: Default::default(), + rows_buffer, random_state: Default::default(), }) } @@ -93,8 +111,9 @@ impl GroupValuesRows { impl GroupValues for GroupValuesRows { fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { // Convert the group keys into the row format - // Avoid reallocation when https://github.com/apache/arrow-rs/issues/4479 is available - let group_rows = self.row_converter.convert_columns(cols)?; + let group_rows = &mut self.rows_buffer; + group_rows.clear(); + self.row_converter.append(group_rows, cols)?; let n_rows = group_rows.num_rows(); let mut group_values = match self.group_values.take() { @@ -111,12 +130,17 @@ impl GroupValues for GroupValuesRows { batch_hashes.resize(n_rows, 0); create_hashes(cols, &self.random_state, batch_hashes)?; - for (row, &hash) in batch_hashes.iter().enumerate() { - let entry = self.map.get_mut(hash, |(_hash, group_idx)| { - // verify that a group that we are inserting with hash is - // actually the same key value as the group in - // existing_idx (aka group_values @ row) - group_rows.row(row) == group_values.row(*group_idx) + for (row, &target_hash) in batch_hashes.iter().enumerate() { + let entry = self.map.get_mut(target_hash, |(exist_hash, group_idx)| { + // Somewhat surprisingly, this closure can be called even if the + // hash doesn't match, so check the hash first with an integer + // comparison first avoid the more expensive comparison with + // group value. https://github.com/apache/datafusion/pull/11718 + target_hash == *exist_hash + // verify that the group that we are inserting with hash is + // actually the same key value as the group in + // existing_idx (aka group_values @ row) + && group_rows.row(row) == group_values.row(*group_idx) }); let group_idx = match entry { @@ -130,7 +154,7 @@ impl GroupValues for GroupValuesRows { // for hasher function, use precomputed hash value self.map.insert_accounted( - (hash, group_idx), + (target_hash, group_idx), |(hash, _group_index)| *hash, &mut self.map_size, ); @@ -150,6 +174,7 @@ impl GroupValues for GroupValuesRows { self.row_converter.size() + group_values_size + self.map_size + + self.rows_buffer.size() + self.hashes_buffer.allocated_size() } @@ -180,7 +205,7 @@ impl GroupValues for GroupValuesRows { let groups_rows = group_values.iter().take(n); let output = self.row_converter.convert_rows(groups_rows)?; // Clear out first n group keys by copying them to a new Rows. - // TODO file some ticket in arrow-rs to make this more efficent? + // TODO file some ticket in arrow-rs to make this more efficient? let mut new_group_values = self.row_converter.empty_rows(0, 0); for row in group_values.iter().skip(n) { new_group_values.push(row); @@ -203,18 +228,12 @@ impl GroupValues for GroupValuesRows { } }; - // TODO: Materialize dictionaries in group keys (#7647) + // TODO: Materialize dictionaries in group keys + // https://github.com/apache/datafusion/issues/7647 for (field, array) in self.schema.fields.iter().zip(&mut output) { let expected = field.data_type(); - if let DataType::Dictionary(_, v) = expected { - let actual = array.data_type(); - if v.as_ref() != actual { - return Err(DataFusionError::Internal(format!( - "Converted group rows expected dictionary of {v} got {actual}" - ))); - } - *array = cast(array.as_ref(), expected)?; - } + *array = + dictionary_encode_if_necessary(Arc::::clone(array), expected)?; } self.group_values = Some(group_values); @@ -229,8 +248,50 @@ impl GroupValues for GroupValuesRows { }); self.map.clear(); self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared - self.map_size = self.map.capacity() * std::mem::size_of::<(u64, usize)>(); + self.map_size = self.map.capacity() * size_of::<(u64, usize)>(); self.hashes_buffer.clear(); self.hashes_buffer.shrink_to(count); } } + +fn dictionary_encode_if_necessary( + array: ArrayRef, + expected: &DataType, +) -> Result { + match (expected, array.data_type()) { + (DataType::Struct(expected_fields), _) => { + let struct_array = array.as_any().downcast_ref::().unwrap(); + let arrays = expected_fields + .iter() + .zip(struct_array.columns()) + .map(|(expected_field, column)| { + dictionary_encode_if_necessary( + Arc::::clone(column), + expected_field.data_type(), + ) + }) + .collect::>>()?; + + Ok(Arc::new(StructArray::try_new( + expected_fields.clone(), + arrays, + struct_array.nulls().cloned(), + )?)) + } + (DataType::List(expected_field), &DataType::List(_)) => { + let list = array.as_any().downcast_ref::().unwrap(); + + Ok(Arc::new(ListArray::try_new( + Arc::::clone(expected_field), + list.offsets().clone(), + dictionary_encode_if_necessary( + Arc::::clone(list.values()), + expected_field.data_type(), + )?, + list.nulls().cloned(), + )?)) + } + (DataType::Dictionary(_, _), _) => Ok(cast(array.as_ref(), expected)?), + (_, _) => Ok(Arc::::clone(&array)), + } +} diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 14485c833794..5ffe797c5c26 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -26,48 +26,54 @@ use crate::aggregates::{ topk_stream::GroupedTopKAggregateStream, }; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; +use crate::projection::get_field_metadata; use crate::windows::get_ordered_partition_by_indices; use crate::{ - DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode, Partitioning, + DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode, SendableRecordBatchStream, Statistics, }; use arrow::array::ArrayRef; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use arrow_array::{UInt16Array, UInt32Array, UInt64Array, UInt8Array}; use datafusion_common::stats::Precision; use datafusion_common::{internal_err, not_impl_err, Result}; use datafusion_execution::TaskContext; -use datafusion_expr::Accumulator; -use datafusion_physical_expr::aggregate::is_order_sensitive; -use datafusion_physical_expr::equivalence::collapse_lex_req; +use datafusion_expr::{Accumulator, Aggregate}; use datafusion_physical_expr::{ - equivalence::ProjectionMapping, - expressions::{Column, Max, Min, UnKnownColumn}, - AggregateExpr, LexRequirement, PhysicalExpr, -}; -use datafusion_physical_expr::{ - physical_exprs_contains, EquivalenceProperties, LexOrdering, PhysicalSortRequirement, + equivalence::{collapse_lex_req, ProjectionMapping}, + expressions::Column, + physical_exprs_contains, EquivalenceProperties, LexOrdering, LexRequirement, + PhysicalExpr, PhysicalSortRequirement, }; +use crate::execution_plan::CardinalityEffect; +use datafusion_physical_expr::aggregate::AggregateFunctionExpr; use itertools::Itertools; -mod group_values; +pub mod group_values; mod no_grouping; -mod order; +pub mod order; mod row_hash; mod topk; mod topk_stream; -pub use datafusion_expr::AggregateFunction; -pub use datafusion_physical_expr::expressions::create_aggregate_expr; - /// Hash aggregate modes +/// +/// See [`Accumulator::state`] for background information on multi-phase +/// aggregation and how these modes are used. #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum AggregateMode { - /// Partial aggregate that can be applied in parallel across input partitions + /// Partial aggregate that can be applied in parallel across input + /// partitions. + /// + /// This is the first phase of a multi-phase aggregation. Partial, - /// Final aggregate that produces a single partition of output + /// Final aggregate that produces a single partition of output by combining + /// the output of multiple partial aggregates. + /// + /// This is the second phase of a multi-phase aggregation. Final, /// Final aggregate that works on pre-partitioned data. /// @@ -79,12 +85,15 @@ pub enum AggregateMode { /// Applies the entire logical aggregation operation in a single operator, /// as opposed to Partial / Final modes which apply the logical aggregation using /// two operators. - /// This mode requires tha the input is a single partition (like Final) + /// + /// This mode requires that the input is a single partition (like Final) Single, /// Applies the entire logical aggregation operation in a single operator, /// as opposed to Partial / Final modes which apply the logical aggregation using /// two operators. - /// This mode requires tha the input is partitioned by group key (like FinalPartitioned) + /// + /// This mode requires that the input is partitioned by group key (like + /// FinalPartitioned) SinglePartitioned, } @@ -105,9 +114,9 @@ impl AggregateMode { /// Represents `GROUP BY` clause in the plan (including the more general GROUPING SET) /// In the case of a simple `GROUP BY a, b` clause, this will contain the expression [a, b] /// and a single group [false, false]. -/// In the case of `GROUP BY GROUPING SET/CUBE/ROLLUP` the planner will expand the expression +/// In the case of `GROUP BY GROUPING SETS/CUBE/ROLLUP` the planner will expand the expression /// into multiple groups, using null expressions to align each group. -/// For example, with a group by clause `GROUP BY GROUPING SET ((a,b),(a),(b))` the planner should +/// For example, with a group by clause `GROUP BY GROUPING SETS ((a,b),(a),(b))` the planner should /// create a `PhysicalGroupBy` like /// ```text /// PhysicalGroupBy { @@ -128,7 +137,7 @@ pub struct PhysicalGroupBy { null_expr: Vec<(Arc, String)>, /// Null mask for each group in this grouping set. Each group is /// composed of either one of the group expressions in expr or a null - /// expression in null_expr. If `groups[i][j]` is true, then the the + /// expression in null_expr. If `groups[i][j]` is true, then the /// j-th expression in the i-th group is NULL, otherwise it is `expr[j]`. groups: Vec>, } @@ -158,9 +167,17 @@ impl PhysicalGroupBy { } } - /// Returns true if this GROUP BY contains NULL expressions - pub fn contains_null(&self) -> bool { - self.groups.iter().flatten().any(|is_null| *is_null) + /// Calculate GROUP BY expressions nullable + pub fn exprs_nullable(&self) -> Vec { + let mut exprs_nullable = vec![false; self.expr.len()]; + for group in self.groups.iter() { + group.iter().enumerate().for_each(|(index, is_null)| { + if *is_null { + exprs_nullable[index] = true; + } + }) + } + exprs_nullable } /// Returns the group expressions @@ -192,17 +209,103 @@ impl PhysicalGroupBy { pub fn input_exprs(&self) -> Vec> { self.expr .iter() - .map(|(expr, _alias)| expr.clone()) + .map(|(expr, _alias)| Arc::clone(expr)) .collect() } + /// The number of expressions in the output schema. + fn num_output_exprs(&self) -> usize { + let mut num_exprs = self.expr.len(); + if !self.is_single() { + num_exprs += 1 + } + num_exprs + } + /// Return grouping expressions as they occur in the output schema. pub fn output_exprs(&self) -> Vec> { - self.expr - .iter() - .enumerate() - .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _) - .collect() + let num_output_exprs = self.num_output_exprs(); + let mut output_exprs = Vec::with_capacity(num_output_exprs); + output_exprs.extend( + self.expr + .iter() + .enumerate() + .take(num_output_exprs) + .map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _), + ); + if !self.is_single() { + output_exprs.push(Arc::new(Column::new( + Aggregate::INTERNAL_GROUPING_ID, + self.expr.len(), + )) as _); + } + output_exprs + } + + /// Returns the number expression as grouping keys. + fn num_group_exprs(&self) -> usize { + if self.is_single() { + self.expr.len() + } else { + self.expr.len() + 1 + } + } + + /// Returns the fields that are used as the grouping keys. + fn group_fields(&self, input_schema: &Schema) -> Result> { + let mut fields = Vec::with_capacity(self.num_group_exprs()); + for ((expr, name), group_expr_nullable) in + self.expr.iter().zip(self.exprs_nullable().into_iter()) + { + fields.push( + Field::new( + name, + expr.data_type(input_schema)?, + group_expr_nullable || expr.nullable(input_schema)?, + ) + .with_metadata( + get_field_metadata(expr, input_schema).unwrap_or_default(), + ), + ); + } + if !self.is_single() { + fields.push(Field::new( + Aggregate::INTERNAL_GROUPING_ID, + Aggregate::grouping_id_type(self.expr.len()), + false, + )); + } + Ok(fields) + } + + /// Returns the output fields of the group by. + /// + /// This might be different from the `group_fields` that might contain internal expressions that + /// should not be part of the output schema. + fn output_fields(&self, input_schema: &Schema) -> Result> { + let mut fields = self.group_fields(input_schema)?; + fields.truncate(self.num_output_exprs()); + Ok(fields) + } + + /// Returns the `PhysicalGroupBy` for a final aggregation if `self` is used for a partial + /// aggregation. + pub fn as_final(&self) -> PhysicalGroupBy { + let expr: Vec<_> = + self.output_exprs() + .into_iter() + .zip( + self.expr.iter().map(|t| t.1.clone()).chain(std::iter::once( + Aggregate::INTERNAL_GROUPING_ID.to_owned(), + )), + ) + .collect(); + let num_exprs = expr.len(); + Self { + expr, + null_expr: vec![], + groups: vec![vec![false; num_exprs]], + } } } @@ -241,14 +344,14 @@ impl From for SendableRecordBatchStream { } /// Hash aggregate execution plan -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct AggregateExec { /// Aggregation mode (full, partial) mode: AggregateMode, /// Group by expressions group_by: PhysicalGroupBy, /// Aggregate expressions - aggr_expr: Vec>, + aggr_expr: Vec>, /// FILTER (WHERE clause) expression for each aggregate expression filter_expr: Vec>>, /// Set if the output of this aggregation is truncated by a upstream sort/limit clause @@ -272,29 +375,27 @@ pub struct AggregateExec { } impl AggregateExec { - /// Function used in `ConvertFirstLast` optimizer rule, + /// Function used in `OptimizeAggregateOrder` optimizer rule, /// where we need parts of the new value, others cloned from the old one - pub fn new_with_aggr_expr_and_ordering_info( + /// Rewrites aggregate exec with new aggregate expressions. + pub fn with_new_aggr_exprs( &self, - required_input_ordering: Option, - aggr_expr: Vec>, - cache: PlanProperties, - input_order_mode: InputOrderMode, + aggr_expr: Vec>, ) -> Self { Self { aggr_expr, - required_input_ordering, - metrics: ExecutionPlanMetricsSet::new(), - input_order_mode, - cache, // clone the rest of the fields + required_input_ordering: self.required_input_ordering.clone(), + metrics: ExecutionPlanMetricsSet::new(), + input_order_mode: self.input_order_mode.clone(), + cache: self.cache.clone(), mode: self.mode, group_by: self.group_by.clone(), filter_expr: self.filter_expr.clone(), limit: self.limit, - input: self.input.clone(), - schema: self.schema.clone(), - input_schema: self.input_schema.clone(), + input: Arc::clone(&self.input), + schema: Arc::clone(&self.schema), + input_schema: Arc::clone(&self.input_schema), } } @@ -306,18 +407,12 @@ impl AggregateExec { pub fn try_new( mode: AggregateMode, group_by: PhysicalGroupBy, - aggr_expr: Vec>, + aggr_expr: Vec>, filter_expr: Vec>>, input: Arc, input_schema: SchemaRef, ) -> Result { - let schema = create_schema( - &input.schema(), - &group_by.expr, - &aggr_expr, - group_by.contains_null(), - mode, - )?; + let schema = create_schema(&input.schema(), &group_by, &aggr_expr, mode)?; let schema = Arc::new(schema); AggregateExec::try_new_with_schema( @@ -333,7 +428,7 @@ impl AggregateExec { /// Create a new hash aggregate execution plan with the given schema. /// This constructor isn't part of the public API, it is used internally - /// by Datafusion to enforce schema consistency during when re-creating + /// by DataFusion to enforce schema consistency during when re-creating /// `AggregateExec`s inside optimization rules. Schema field names of an /// `AggregateExec` depends on the names of aggregate expressions. Since /// a rule may re-write aggregate expressions (e.g. reverse them) during @@ -343,7 +438,7 @@ impl AggregateExec { fn try_new_with_schema( mode: AggregateMode, group_by: PhysicalGroupBy, - mut aggr_expr: Vec>, + mut aggr_expr: Vec>, filter_expr: Vec>>, input: Arc, input_schema: SchemaRef, @@ -361,13 +456,15 @@ impl AggregateExec { // prefix requirements with this section. In this case, aggregation will // work more efficiently. let indices = get_ordered_partition_by_indices(&groupby_exprs, &input); - let mut new_requirement = indices - .iter() - .map(|&idx| PhysicalSortRequirement { - expr: groupby_exprs[idx].clone(), - options: None, - }) - .collect::>(); + let mut new_requirement = LexRequirement::new( + indices + .iter() + .map(|&idx| PhysicalSortRequirement { + expr: Arc::clone(&groupby_exprs[idx]), + options: None, + }) + .collect::>(), + ); let req = get_finer_aggregate_exprs_requirement( &mut aggr_expr, @@ -375,17 +472,29 @@ impl AggregateExec { input_eq_properties, &mode, )?; - new_requirement.extend(req); + new_requirement.inner.extend(req); new_requirement = collapse_lex_req(new_requirement); - let input_order_mode = - if indices.len() == groupby_exprs.len() && !indices.is_empty() { - InputOrderMode::Sorted - } else if !indices.is_empty() { - InputOrderMode::PartiallySorted(indices) - } else { - InputOrderMode::Linear - }; + // If our aggregation has grouping sets then our base grouping exprs will + // be expanded based on the flags in `group_by.groups` where for each + // group we swap the grouping expr for `null` if the flag is `true` + // That means that each index in `indices` is valid if and only if + // it is not null in every group + let indices: Vec = indices + .into_iter() + .filter(|idx| group_by.groups.iter().all(|group| !group[*idx])) + .collect(); + + let input_order_mode = if indices.len() == groupby_exprs.len() + && !indices.is_empty() + && group_by.groups.len() == 1 + { + InputOrderMode::Sorted + } else if !indices.is_empty() { + InputOrderMode::PartiallySorted(indices) + } else { + InputOrderMode::Linear + }; // construct a map from the input expression to the output expression of the Aggregation group by let projection_mapping = @@ -396,7 +505,7 @@ impl AggregateExec { let cache = Self::compute_properties( &input, - schema.clone(), + Arc::clone(&schema), &projection_mapping, &mode, &input_order_mode, @@ -439,7 +548,7 @@ impl AggregateExec { } /// Aggregate expressions - pub fn aggr_expr(&self) -> &[Arc] { + pub fn aggr_expr(&self) -> &[Arc] { &self.aggr_expr } @@ -455,7 +564,7 @@ impl AggregateExec { /// Get the input schema before any aggregates are applied pub fn input_schema(&self) -> SchemaRef { - self.input_schema.clone() + Arc::clone(&self.input_schema) } /// number of rows soft limit of the AggregateExec @@ -493,17 +602,7 @@ impl AggregateExec { /// Finds the DataType and SortDirection for this Aggregate, if there is one pub fn get_minmax_desc(&self) -> Option<(Field, bool)> { let agg_expr = self.aggr_expr.iter().exactly_one().ok()?; - if let Some(max) = agg_expr.as_any().downcast_ref::() { - Some((max.field().ok()?, true)) - } else if let Some(min) = agg_expr.as_any().downcast_ref::() { - Some((min.field().ok()?, false)) - } else { - None - } - } - - pub fn group_by(&self) -> &PhysicalGroupBy { - &self.group_by + agg_expr.get_minmax_desc() } /// true, if this Aggregate has a group-by with no required or explicit ordering, @@ -512,7 +611,7 @@ impl AggregateExec { /// on an AggregateExec. pub fn is_unordered_unfiltered_group_by_distinct(&self) -> bool { // ensure there is a group by - if self.group_by().is_empty() { + if self.group_expr().is_empty() { return false; } // ensure there are no aggregate expressions @@ -553,26 +652,16 @@ impl AggregateExec { .project(projection_mapping, schema); // Get output partitioning: - let mut output_partitioning = input.output_partitioning().clone(); - if mode.is_first_stage() { + let input_partitioning = input.output_partitioning().clone(); + let output_partitioning = if mode.is_first_stage() { // First stage aggregation will not change the output partitioning, // but needs to respect aliases (e.g. mapping in the GROUP BY // expression). let input_eq_properties = input.equivalence_properties(); - if let Partitioning::Hash(exprs, part) = output_partitioning { - let normalized_exprs = exprs - .iter() - .map(|expr| { - input_eq_properties - .project_expr(expr, projection_mapping) - .unwrap_or_else(|| { - Arc::new(UnKnownColumn::new(&expr.to_string())) - }) - }) - .collect(); - output_partitioning = Partitioning::Hash(normalized_exprs, part); - } - } + input_partitioning.project(projection_mapping, input_eq_properties) + } else { + input_partitioning.clone() + }; // Determine execution mode: let mut exec_mode = input.execution_mode(); @@ -688,7 +777,7 @@ impl ExecutionPlan for AggregateExec { vec![Distribution::UnspecifiedDistribution] } AggregateMode::FinalPartitioned | AggregateMode::SinglePartitioned => { - vec![Distribution::HashPartitioned(self.output_group_expr())] + vec![Distribution::HashPartitioned(self.group_by.input_exprs())] } AggregateMode::Final | AggregateMode::Single => { vec![Distribution::SinglePartition] @@ -700,8 +789,8 @@ impl ExecutionPlan for AggregateExec { vec![self.required_input_ordering.clone()] } - fn children(&self) -> Vec> { - vec![self.input.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.input] } fn with_new_children( @@ -713,9 +802,9 @@ impl ExecutionPlan for AggregateExec { self.group_by.clone(), self.aggr_expr.clone(), self.filter_expr.clone(), - children[0].clone(), - self.input_schema.clone(), - self.schema.clone(), + Arc::clone(&children[0]), + Arc::clone(&self.input_schema), + Arc::clone(&self.schema), )?; me.limit = self.limit; @@ -740,7 +829,7 @@ impl ExecutionPlan for AggregateExec { // - once expressions will be able to compute their own stats, use it here // - case where we group by on a column for which with have the `distinct` stat // TODO stats: aggr expression: - // - aggregations somtimes also preserve invariants such as min, max... + // - aggregations sometimes also preserve invariants such as min, max... let column_statistics = Statistics::unknown_column(&self.schema()); match self.mode { AggregateMode::Final | AggregateMode::FinalPartitioned @@ -781,26 +870,20 @@ impl ExecutionPlan for AggregateExec { } } } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::LowerEqual + } } fn create_schema( input_schema: &Schema, - group_expr: &[(Arc, String)], - aggr_expr: &[Arc], - contains_null_expr: bool, + group_by: &PhysicalGroupBy, + aggr_expr: &[Arc], mode: AggregateMode, ) -> Result { - let mut fields = Vec::with_capacity(group_expr.len() + aggr_expr.len()); - for (expr, name) in group_expr { - fields.push(Field::new( - name, - expr.data_type(input_schema)?, - // In cases where we have multiple grouping sets, we will use NULL expressions in - // order to align the grouping sets. So the field must be nullable even if the underlying - // schema field is not. - contains_null_expr || expr.nullable(input_schema)?, - )) - } + let mut fields = Vec::with_capacity(group_by.num_output_exprs() + aggr_expr.len()); + fields.extend(group_by.output_fields(input_schema)?); match mode { AggregateMode::Partial => { @@ -815,24 +898,26 @@ fn create_schema( | AggregateMode::SinglePartitioned => { // in final mode, the field with the final result of the accumulator for expr in aggr_expr { - fields.push(expr.field()?) + fields.push(expr.field()) } } } - Ok(Schema::new(fields)) + Ok(Schema::new_with_metadata( + fields, + input_schema.metadata().clone(), + )) } -fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef { - let group_fields = schema.fields()[0..group_count].to_vec(); - Arc::new(Schema::new(group_fields)) +fn group_schema(input_schema: &Schema, group_by: &PhysicalGroupBy) -> Result { + Ok(Arc::new(Schema::new(group_by.group_fields(input_schema)?))) } /// Determines the lexical ordering requirement for an aggregate expression. /// /// # Parameters /// -/// - `aggr_expr`: A reference to an `Arc` representing the +/// - `aggr_expr`: A reference to an `AggregateFunctionExpr` representing the /// aggregate expression. /// - `group_by`: A reference to a `PhysicalGroupBy` instance representing the /// physical GROUP BY expression. @@ -844,19 +929,18 @@ fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef { /// A `LexOrdering` instance indicating the lexical ordering requirement for /// the aggregate expression. fn get_aggregate_expr_req( - aggr_expr: &Arc, + aggr_expr: &AggregateFunctionExpr, group_by: &PhysicalGroupBy, agg_mode: &AggregateMode, ) -> LexOrdering { - // If the aggregation function is not order sensitive, or the aggregation - // is performing a "second stage" calculation, or all aggregate function - // requirements are inside the GROUP BY expression, then ignore the ordering - // requirement. - if !is_order_sensitive(aggr_expr) || !agg_mode.is_first_stage() { - return vec![]; + // If the aggregation function is ordering requirement is not absolutely + // necessary, or the aggregation is performing a "second stage" calculation, + // then ignore the ordering requirement. + if !aggr_expr.order_sensitivity().hard_requires() || !agg_mode.is_first_stage() { + return LexOrdering::default(); } - let mut req = aggr_expr.order_bys().unwrap_or_default().to_vec(); + let mut req = LexOrdering::from_ref(aggr_expr.order_bys().unwrap_or_default()); // In non-first stage modes, we accumulate data (using `merge_batch`) from // different partitions (i.e. merge partial results). During this merge, we @@ -893,13 +977,13 @@ fn get_aggregate_expr_req( /// the aggregator requirement is incompatible. fn finer_ordering( existing_req: &LexOrdering, - aggr_expr: &Arc, + aggr_expr: &AggregateFunctionExpr, group_by: &PhysicalGroupBy, eq_properties: &EquivalenceProperties, agg_mode: &AggregateMode, ) -> Option { let aggr_req = get_aggregate_expr_req(aggr_expr, group_by, agg_mode); - eq_properties.get_finer_ordering(existing_req, &aggr_req) + eq_properties.get_finer_ordering(existing_req.as_ref(), aggr_req.as_ref()) } /// Concatenates the given slices. @@ -911,7 +995,7 @@ pub fn concat_slices(lhs: &[T], rhs: &[T]) -> Vec { /// /// # Parameters /// -/// - `aggr_exprs`: A slice of `Arc` containing all the +/// - `aggr_exprs`: A slice of `AggregateFunctionExpr` containing all the /// aggregate expressions. /// - `group_by`: A reference to a `PhysicalGroupBy` instance representing the /// physical GROUP BY expression. @@ -924,18 +1008,18 @@ pub fn concat_slices(lhs: &[T], rhs: &[T]) -> Vec { /// /// A `LexRequirement` instance, which is the requirement that satisfies all the /// aggregate requirements. Returns an error in case of conflicting requirements. -fn get_finer_aggregate_exprs_requirement( - aggr_exprs: &mut [Arc], +pub fn get_finer_aggregate_exprs_requirement( + aggr_exprs: &mut [Arc], group_by: &PhysicalGroupBy, eq_properties: &EquivalenceProperties, agg_mode: &AggregateMode, ) -> Result { - let mut requirement = vec![]; + let mut requirement = LexOrdering::default(); for aggr_expr in aggr_exprs.iter_mut() { if let Some(finer_ordering) = finer_ordering(&requirement, aggr_expr, group_by, eq_properties, agg_mode) { - if eq_properties.ordering_satisfy(&finer_ordering) { + if eq_properties.ordering_satisfy(finer_ordering.as_ref()) { // Requirement is satisfied by existing ordering requirement = finer_ordering; continue; @@ -949,11 +1033,11 @@ fn get_finer_aggregate_exprs_requirement( eq_properties, agg_mode, ) { - if eq_properties.ordering_satisfy(&finer_ordering) { + if eq_properties.ordering_satisfy(finer_ordering.as_ref()) { // Reverse requirement is satisfied by exiting ordering. // Hence reverse the aggregator requirement = finer_ordering; - *aggr_expr = reverse_aggr_expr; + *aggr_expr = Arc::new(reverse_aggr_expr); continue; } } @@ -977,7 +1061,7 @@ fn get_finer_aggregate_exprs_requirement( // There is a requirement that both satisfies existing requirement and reverse // aggregate requirement. Use updated requirement requirement = finer_ordering; - *aggr_expr = reverse_aggr_expr; + *aggr_expr = Arc::new(reverse_aggr_expr); continue; } } @@ -990,15 +1074,18 @@ fn get_finer_aggregate_exprs_requirement( ); } - Ok(PhysicalSortRequirement::from_sort_exprs(&requirement)) + Ok(PhysicalSortRequirement::from_sort_exprs( + requirement.inner.iter(), + )) } -/// returns physical expressions for arguments to evaluate against a batch +/// Returns physical expressions for arguments to evaluate against a batch. +/// /// The expressions are different depending on `mode`: -/// * Partial: AggregateExpr::expressions -/// * Final: columns of `AggregateExpr::state_fields()` -fn aggregate_expressions( - aggr_expr: &[Arc], +/// * Partial: AggregateFunctionExpr::expressions +/// * Final: columns of `AggregateFunctionExpr::state_fields()` +pub fn aggregate_expressions( + aggr_expr: &[Arc], mode: &AggregateMode, col_idx_base: usize, ) -> Result>>> { @@ -1013,7 +1100,7 @@ fn aggregate_expressions( // way order sensitive aggregators can satisfy requirement // themselves. if let Some(ordering_req) = agg.order_bys() { - result.extend(ordering_req.iter().map(|item| item.expr.clone())); + result.extend(ordering_req.iter().map(|item| Arc::clone(&item.expr))); } result }) @@ -1034,12 +1121,12 @@ fn aggregate_expressions( } /// uses `state_fields` to build a vec of physical column expressions required to merge the -/// AggregateExpr' accumulator's state. +/// AggregateFunctionExpr' accumulator's state. /// /// `index_base` is the starting physical column index for the next expanded state field. fn merge_expressions( index_base: usize, - expr: &Arc, + expr: &AggregateFunctionExpr, ) -> Result>> { expr.state_fields().map(|fields| { fields @@ -1050,10 +1137,10 @@ fn merge_expressions( }) } -pub(crate) type AccumulatorItem = Box; +pub type AccumulatorItem = Box; -fn create_accumulators( - aggr_expr: &[Arc], +pub fn create_accumulators( + aggr_expr: &[Arc], ) -> Result> { aggr_expr .iter() @@ -1063,7 +1150,7 @@ fn create_accumulators( /// returns a vector of ArrayRefs, where each entry corresponds to either the /// final value (mode = Final, FinalPartitioned and Single) or states (mode = Partial) -fn finalize_aggregation( +pub fn finalize_aggregation( accumulators: &mut [AccumulatorItem], mode: &AggregateMode, ) -> Result> { @@ -1132,15 +1219,36 @@ fn evaluate_optional( .collect() } +fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result { + if group.len() > 64 { + return not_impl_err!( + "Grouping sets with more than 64 columns are not supported" + ); + } + let group_id = group.iter().fold(0u64, |acc, &is_null| { + (acc << 1) | if is_null { 1 } else { 0 } + }); + let num_rows = batch.num_rows(); + if group.len() <= 8 { + Ok(Arc::new(UInt8Array::from(vec![group_id as u8; num_rows]))) + } else if group.len() <= 16 { + Ok(Arc::new(UInt16Array::from(vec![group_id as u16; num_rows]))) + } else if group.len() <= 32 { + Ok(Arc::new(UInt32Array::from(vec![group_id as u32; num_rows]))) + } else { + Ok(Arc::new(UInt64Array::from(vec![group_id; num_rows]))) + } +} + /// Evaluate a group by expression against a `RecordBatch` /// /// Arguments: -/// `group_by`: the expression to evaluate -/// `batch`: the `RecordBatch` to evaluate against +/// - `group_by`: the expression to evaluate +/// - `batch`: the `RecordBatch` to evaluate against /// /// Returns: A Vec of Vecs of Array of results -/// The outer Vect appears to be for grouping sets -/// The inner Vect contains the results per expression +/// The outer Vec appears to be for grouping sets +/// The inner Vec contains the results per expression /// The inner-inner Array contains the results per row pub(crate) fn evaluate_group_by( group_by: &PhysicalGroupBy, @@ -1164,23 +1272,24 @@ pub(crate) fn evaluate_group_by( }) .collect::>>()?; - Ok(group_by + group_by .groups .iter() .map(|group| { - group - .iter() - .enumerate() - .map(|(idx, is_null)| { - if *is_null { - null_exprs[idx].clone() - } else { - exprs[idx].clone() - } - }) - .collect() + let mut group_values = Vec::with_capacity(group_by.num_group_exprs()); + group_values.extend(group.iter().enumerate().map(|(idx, is_null)| { + if *is_null { + Arc::clone(&null_exprs[idx]) + } else { + Arc::clone(&exprs[idx]) + } + })); + if !group_by.is_single() { + group_values.push(group_id_array(group, batch)?); + } + Ok(group_values) }) - .collect()) + .collect() } #[cfg(test)] @@ -1191,7 +1300,7 @@ mod tests { use crate::coalesce_batches::CoalesceBatchesExec; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::common; - use crate::expressions::{col, Avg}; + use crate::expressions::col; use crate::memory::MemoryExec; use crate::test::assert_is_pending; use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; @@ -1199,19 +1308,30 @@ mod tests { use arrow::array::{Float64Array, UInt32Array}; use arrow::compute::{concat_batches, SortOptions}; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Int32Type}; + use arrow_array::{ + DictionaryArray, Float32Array, Int32Array, StructArray, UInt64Array, + }; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, internal_err, DataFusionError, ScalarValue, }; use datafusion_execution::config::SessionConfig; use datafusion_execution::memory_pool::FairSpillPool; - use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; - use datafusion_physical_expr::expressions::{ - lit, ApproxDistinct, Count, FirstValue, LastValue, Median, OrderSensitiveArrayAgg, - }; - use datafusion_physical_expr::{reverse_order_bys, PhysicalSortExpr}; - + use datafusion_execution::runtime_env::RuntimeEnvBuilder; + use datafusion_functions_aggregate::array_agg::array_agg_udaf; + use datafusion_functions_aggregate::average::avg_udaf; + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_functions_aggregate::first_last::{first_value_udaf, last_value_udaf}; + use datafusion_functions_aggregate::median::median_udaf; + use datafusion_functions_aggregate::sum::sum_udaf; + use datafusion_physical_expr::expressions::lit; + use datafusion_physical_expr::PhysicalSortExpr; + + use crate::common::collect; + use datafusion_physical_expr::aggregate::AggregateExprBuilder; + use datafusion_physical_expr::expressions::Literal; + use datafusion_physical_expr::Partitioning; use futures::{FutureExt, Stream}; // Generate a schema which consists of 5 columns (a, b, c, d, e) @@ -1236,10 +1356,10 @@ mod tests { // define data. ( - schema.clone(), + Arc::clone(&schema), vec![ RecordBatch::try_new( - schema.clone(), + Arc::clone(&schema), vec![ Arc::new(UInt32Array::from(vec![2, 3, 4, 4])), Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])), @@ -1271,10 +1391,10 @@ mod tests { // the expected result by accident, but merging actually works properly; // i.e. it doesn't depend on the data insertion order. ( - schema.clone(), + Arc::clone(&schema), vec![ RecordBatch::try_new( - schema.clone(), + Arc::clone(&schema), vec![ Arc::new(UInt32Array::from(vec![2, 3, 4, 4])), Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])), @@ -1282,7 +1402,7 @@ mod tests { ) .unwrap(), RecordBatch::try_new( - schema.clone(), + Arc::clone(&schema), vec![ Arc::new(UInt32Array::from(vec![2, 3, 3, 4])), Arc::new(Float64Array::from(vec![0.0, 1.0, 2.0, 3.0])), @@ -1290,7 +1410,7 @@ mod tests { ) .unwrap(), RecordBatch::try_new( - schema.clone(), + Arc::clone(&schema), vec![ Arc::new(UInt32Array::from(vec![2, 3, 3, 4])), Arc::new(Float64Array::from(vec![3.0, 4.0, 5.0, 6.0])), @@ -1311,13 +1431,10 @@ mod tests { fn new_spill_ctx(batch_size: usize, max_memory: usize) -> Arc { let session_config = SessionConfig::new().with_batch_size(batch_size); - let runtime = Arc::new( - RuntimeEnv::new( - RuntimeConfig::default() - .with_memory_pool(Arc::new(FairSpillPool::new(max_memory))), - ) - .unwrap(), - ); + let runtime = RuntimeEnvBuilder::default() + .with_memory_pool(Arc::new(FairSpillPool::new(max_memory))) + .build_arc() + .unwrap(); let task_ctx = TaskContext::default() .with_session_config(session_config) .with_runtime(runtime); @@ -1330,30 +1447,32 @@ mod tests { ) -> Result<()> { let input_schema = input.schema(); - let grouping_set = PhysicalGroupBy { - expr: vec![ + let grouping_set = PhysicalGroupBy::new( + vec![ (col("a", &input_schema)?, "a".to_string()), (col("b", &input_schema)?, "b".to_string()), ], - null_expr: vec![ + vec![ (lit(ScalarValue::UInt32(None)), "a".to_string()), (lit(ScalarValue::Float64(None)), "b".to_string()), ], - groups: vec![ + vec![ vec![false, true], // (a, NULL) vec![true, false], // (NULL, b) vec![false, false], // (a,b) ], - }; + ); - let aggregates: Vec> = vec![Arc::new(Count::new( - lit(1i8), - "COUNT(1)".to_string(), - DataType::Int64, - ))]; + let aggregates = vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)]) + .schema(Arc::clone(&input_schema)) + .alias("COUNT(1)") + .build()?, + )]; let task_ctx = if spill { - new_spill_ctx(4, 1000) + // adjust the max memory size to have the partial aggregate result for spill mode. + new_spill_ctx(4, 500) } else { Arc::new(TaskContext::default()) }; @@ -1364,71 +1483,66 @@ mod tests { aggregates.clone(), vec![None], input, - input_schema.clone(), + Arc::clone(&input_schema), )?); let result = - common::collect(partial_aggregate.execute(0, task_ctx.clone())?).await?; + collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?; let expected = if spill { + // In spill mode, we test with the limited memory, if the mem usage exceeds, + // we trigger the early emit rule, which turns out the partial aggregate result. vec![ - "+---+-----+-----------------+", - "| a | b | COUNT(1)[count] |", - "+---+-----+-----------------+", - "| | 1.0 | 1 |", - "| | 1.0 | 1 |", - "| | 2.0 | 1 |", - "| | 2.0 | 1 |", - "| | 3.0 | 1 |", - "| | 3.0 | 1 |", - "| | 4.0 | 1 |", - "| | 4.0 | 1 |", - "| 2 | | 1 |", - "| 2 | | 1 |", - "| 2 | 1.0 | 1 |", - "| 2 | 1.0 | 1 |", - "| 3 | | 1 |", - "| 3 | | 2 |", - "| 3 | 2.0 | 2 |", - "| 3 | 3.0 | 1 |", - "| 4 | | 1 |", - "| 4 | | 2 |", - "| 4 | 3.0 | 1 |", - "| 4 | 4.0 | 2 |", - "+---+-----+-----------------+", + "+---+-----+---------------+-----------------+", + "| a | b | __grouping_id | COUNT(1)[count] |", + "+---+-----+---------------+-----------------+", + "| | 1.0 | 2 | 1 |", + "| | 1.0 | 2 | 1 |", + "| | 2.0 | 2 | 1 |", + "| | 2.0 | 2 | 1 |", + "| | 3.0 | 2 | 1 |", + "| | 3.0 | 2 | 1 |", + "| | 4.0 | 2 | 1 |", + "| | 4.0 | 2 | 1 |", + "| 2 | | 1 | 1 |", + "| 2 | | 1 | 1 |", + "| 2 | 1.0 | 0 | 1 |", + "| 2 | 1.0 | 0 | 1 |", + "| 3 | | 1 | 1 |", + "| 3 | | 1 | 2 |", + "| 3 | 2.0 | 0 | 2 |", + "| 3 | 3.0 | 0 | 1 |", + "| 4 | | 1 | 1 |", + "| 4 | | 1 | 2 |", + "| 4 | 3.0 | 0 | 1 |", + "| 4 | 4.0 | 0 | 2 |", + "+---+-----+---------------+-----------------+", ] } else { vec![ - "+---+-----+-----------------+", - "| a | b | COUNT(1)[count] |", - "+---+-----+-----------------+", - "| | 1.0 | 2 |", - "| | 2.0 | 2 |", - "| | 3.0 | 2 |", - "| | 4.0 | 2 |", - "| 2 | | 2 |", - "| 2 | 1.0 | 2 |", - "| 3 | | 3 |", - "| 3 | 2.0 | 2 |", - "| 3 | 3.0 | 1 |", - "| 4 | | 3 |", - "| 4 | 3.0 | 1 |", - "| 4 | 4.0 | 2 |", - "+---+-----+-----------------+", + "+---+-----+---------------+-----------------+", + "| a | b | __grouping_id | COUNT(1)[count] |", + "+---+-----+---------------+-----------------+", + "| | 1.0 | 2 | 2 |", + "| | 2.0 | 2 | 2 |", + "| | 3.0 | 2 | 2 |", + "| | 4.0 | 2 | 2 |", + "| 2 | | 1 | 2 |", + "| 2 | 1.0 | 0 | 2 |", + "| 3 | | 1 | 3 |", + "| 3 | 2.0 | 0 | 2 |", + "| 3 | 3.0 | 0 | 1 |", + "| 4 | | 1 | 3 |", + "| 4 | 3.0 | 0 | 1 |", + "| 4 | 4.0 | 0 | 2 |", + "+---+-----+---------------+-----------------+", ] }; assert_batches_sorted_eq!(expected, &result); - let groups = partial_aggregate.group_expr().expr().to_vec(); - let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate)); - let final_group: Vec<(Arc, String)> = groups - .iter() - .map(|(_expr, name)| Ok((col(name, &input_schema)?, name.clone()))) - .collect::>()?; - - let final_grouping_set = PhysicalGroupBy::new_single(final_group); + let final_grouping_set = grouping_set.as_final(); let task_ctx = if spill { new_spill_ctx(4, 3160) @@ -1445,29 +1559,28 @@ mod tests { input_schema, )?); - let result = - common::collect(merged_aggregate.execute(0, task_ctx.clone())?).await?; + let result = collect(merged_aggregate.execute(0, Arc::clone(&task_ctx))?).await?; let batch = concat_batches(&result[0].schema(), &result)?; - assert_eq!(batch.num_columns(), 3); + assert_eq!(batch.num_columns(), 4); assert_eq!(batch.num_rows(), 12); let expected = vec![ - "+---+-----+----------+", - "| a | b | COUNT(1) |", - "+---+-----+----------+", - "| | 1.0 | 2 |", - "| | 2.0 | 2 |", - "| | 3.0 | 2 |", - "| | 4.0 | 2 |", - "| 2 | | 2 |", - "| 2 | 1.0 | 2 |", - "| 3 | | 3 |", - "| 3 | 2.0 | 2 |", - "| 3 | 3.0 | 1 |", - "| 4 | | 3 |", - "| 4 | 3.0 | 1 |", - "| 4 | 4.0 | 2 |", - "+---+-----+----------+", + "+---+-----+---------------+----------+", + "| a | b | __grouping_id | COUNT(1) |", + "+---+-----+---------------+----------+", + "| | 1.0 | 2 | 2 |", + "| | 2.0 | 2 | 2 |", + "| | 3.0 | 2 | 2 |", + "| | 4.0 | 2 | 2 |", + "| 2 | | 1 | 2 |", + "| 2 | 1.0 | 0 | 2 |", + "| 3 | | 1 | 3 |", + "| 3 | 2.0 | 0 | 2 |", + "| 3 | 3.0 | 0 | 1 |", + "| 4 | | 1 | 3 |", + "| 4 | 3.0 | 0 | 1 |", + "| 4 | 4.0 | 0 | 2 |", + "+---+-----+---------------+----------+", ]; assert_batches_sorted_eq!(&expected, &result); @@ -1483,19 +1596,21 @@ mod tests { async fn check_aggregates(input: Arc, spill: bool) -> Result<()> { let input_schema = input.schema(); - let grouping_set = PhysicalGroupBy { - expr: vec![(col("a", &input_schema)?, "a".to_string())], - null_expr: vec![], - groups: vec![vec![false]], - }; + let grouping_set = PhysicalGroupBy::new( + vec![(col("a", &input_schema)?, "a".to_string())], + vec![], + vec![vec![false]], + ); - let aggregates: Vec> = vec![Arc::new(Avg::new( - col("b", &input_schema)?, - "AVG(b)".to_string(), - DataType::Float64, - ))]; + let aggregates: Vec> = vec![Arc::new( + AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?]) + .schema(Arc::clone(&input_schema)) + .alias("AVG(b)") + .build()?, + )]; let task_ctx = if spill { + // set to an appropriate value to trigger spill new_spill_ctx(2, 1600) } else { Arc::new(TaskContext::default()) @@ -1507,11 +1622,11 @@ mod tests { aggregates.clone(), vec![None], input, - input_schema.clone(), + Arc::clone(&input_schema), )?); let result = - common::collect(partial_aggregate.execute(0, task_ctx.clone())?).await?; + collect(partial_aggregate.execute(0, Arc::clone(&task_ctx))?).await?; let expected = if spill { vec![ @@ -1540,13 +1655,7 @@ mod tests { let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate)); - let final_group: Vec<(Arc, String)> = grouping_set - .expr - .iter() - .map(|(_expr, name)| Ok((col(name, &input_schema)?, name.clone()))) - .collect::>()?; - - let final_grouping_set = PhysicalGroupBy::new_single(final_group); + let final_grouping_set = grouping_set.as_final(); let merged_aggregate = Arc::new(AggregateExec::try_new( AggregateMode::Final, @@ -1557,8 +1666,13 @@ mod tests { input_schema, )?); - let result = - common::collect(merged_aggregate.execute(0, task_ctx.clone())?).await?; + let task_ctx = if spill { + // enlarge memory limit to let the final aggregation finish + new_spill_ctx(2, 2600) + } else { + Arc::clone(&task_ctx) + }; + let result = collect(merged_aggregate.execute(0, task_ctx)?).await?; let batch = concat_batches(&result[0].schema(), &result)?; assert_eq!(batch.num_columns(), 2); assert_eq!(batch.num_rows(), 3); @@ -1577,12 +1691,24 @@ mod tests { let metrics = merged_aggregate.metrics().unwrap(); let output_rows = metrics.output_rows().unwrap(); + let spill_count = metrics.spill_count().unwrap(); + let spilled_bytes = metrics.spilled_bytes().unwrap(); + let spilled_rows = metrics.spilled_rows().unwrap(); + if spill { // When spilling, the output rows metrics become partial output size + final output size // This is because final aggregation starts while partial aggregation is still emitting assert_eq!(8, output_rows); + + assert!(spill_count > 0); + assert!(spilled_bytes > 0); + assert!(spilled_rows > 0); } else { assert_eq!(3, output_rows); + + assert_eq!(0, spill_count); + assert_eq!(0, spilled_bytes); + assert_eq!(0, spilled_rows); } Ok(()) @@ -1644,7 +1770,7 @@ mod tests { &self.cache } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { vec![] } @@ -1777,49 +1903,46 @@ mod tests { check_grouping_sets(input, true).await } + // Median(a) + fn test_median_agg_expr(schema: SchemaRef) -> Result { + AggregateExprBuilder::new(median_udaf(), vec![col("a", &schema)?]) + .schema(schema) + .alias("MEDIAN(a)") + .build() + } + #[tokio::test] async fn test_oom() -> Result<()> { let input: Arc = Arc::new(TestYieldingExec::new(true)); let input_schema = input.schema(); - let runtime = Arc::new( - RuntimeEnv::new(RuntimeConfig::default().with_memory_limit(1, 1.0)).unwrap(), - ); + let runtime = RuntimeEnvBuilder::default() + .with_memory_limit(1, 1.0) + .build_arc()?; let task_ctx = TaskContext::default().with_runtime(runtime); let task_ctx = Arc::new(task_ctx); let groups_none = PhysicalGroupBy::default(); - let groups_some = PhysicalGroupBy { - expr: vec![(col("a", &input_schema)?, "a".to_string())], - null_expr: vec![], - groups: vec![vec![false]], - }; + let groups_some = PhysicalGroupBy::new( + vec![(col("a", &input_schema)?, "a".to_string())], + vec![], + vec![vec![false]], + ); // something that allocates within the aggregator - let aggregates_v0: Vec> = vec![Arc::new(Median::new( - col("a", &input_schema)?, - "MEDIAN(a)".to_string(), - DataType::UInt32, - ))]; - - // use slow-path in `hash.rs` - let aggregates_v1: Vec> = - vec![Arc::new(ApproxDistinct::new( - col("a", &input_schema)?, - "APPROX_DISTINCT(a)".to_string(), - DataType::UInt32, - ))]; + let aggregates_v0: Vec> = + vec![Arc::new(test_median_agg_expr(Arc::clone(&input_schema))?)]; // use fast-path in `row_hash.rs`. - let aggregates_v2: Vec> = vec![Arc::new(Avg::new( - col("b", &input_schema)?, - "AVG(b)".to_string(), - DataType::Float64, - ))]; + let aggregates_v2: Vec> = vec![Arc::new( + AggregateExprBuilder::new(avg_udaf(), vec![col("b", &input_schema)?]) + .schema(Arc::clone(&input_schema)) + .alias("AVG(b)") + .build()?, + )]; for (version, groups, aggregates) in [ (0, groups_none, aggregates_v0), - (1, groups_some.clone(), aggregates_v1), (2, groups_some, aggregates_v2), ] { let n_aggr = aggregates.len(); @@ -1828,11 +1951,11 @@ mod tests { groups, aggregates, vec![None; n_aggr], - input.clone(), - input_schema.clone(), + Arc::clone(&input), + Arc::clone(&input_schema), )?); - let stream = partial_aggregate.execute_typed(0, task_ctx.clone())?; + let stream = partial_aggregate.execute_typed(0, Arc::clone(&task_ctx))?; // ensure that we really got the version we wanted match version { @@ -1849,7 +1972,7 @@ mod tests { } let stream: SendableRecordBatchStream = stream.into(); - let err = common::collect(stream).await.unwrap_err(); + let err = collect(stream).await.unwrap_err(); // error root cause traversal is a bit complicated, see #4172. let err = err.find_root(); @@ -1866,15 +1989,16 @@ mod tests { async fn test_drop_cancel_without_groups() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); let schema = - Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); + Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, true)])); let groups = PhysicalGroupBy::default(); - let aggregates: Vec> = vec![Arc::new(Avg::new( - col("a", &schema)?, - "AVG(a)".to_string(), - DataType::Float64, - ))]; + let aggregates: Vec> = vec![Arc::new( + AggregateExprBuilder::new(avg_udaf(), vec![col("a", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("AVG(a)") + .build()?, + )]; let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let refs = blocking_exec.refs(); @@ -1901,18 +2025,19 @@ mod tests { async fn test_drop_cancel_with_groups() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Float32, true), - Field::new("b", DataType::Float32, true), + Field::new("a", DataType::Float64, true), + Field::new("b", DataType::Float64, true), ])); let groups = PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]); - let aggregates: Vec> = vec![Arc::new(Avg::new( - col("b", &schema)?, - "AVG(b)".to_string(), - DataType::Float64, - ))]; + let aggregates: Vec> = vec![Arc::new( + AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("AVG(b)") + .build()?, + )]; let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let refs = blocking_exec.refs(); @@ -1940,14 +2065,56 @@ mod tests { for use_coalesce_batches in [false, true] { for is_first_acc in [false, true] { for spill in [false, true] { - first_last_multi_partitions(use_coalesce_batches, is_first_acc, spill) - .await? + first_last_multi_partitions( + use_coalesce_batches, + is_first_acc, + spill, + 4200, + ) + .await? } } } Ok(()) } + // FIRST_VALUE(b ORDER BY b ) + fn test_first_value_agg_expr( + schema: &Schema, + sort_options: SortOptions, + ) -> Result> { + let ordering_req = [PhysicalSortExpr { + expr: col("b", schema)?, + options: sort_options, + }]; + let args = [col("b", schema)?]; + + AggregateExprBuilder::new(first_value_udaf(), args.to_vec()) + .order_by(LexOrdering::new(ordering_req.to_vec())) + .schema(Arc::new(schema.clone())) + .alias(String::from("first_value(b) ORDER BY [b ASC NULLS LAST]")) + .build() + .map(Arc::new) + } + + // LAST_VALUE(b ORDER BY b ) + fn test_last_value_agg_expr( + schema: &Schema, + sort_options: SortOptions, + ) -> Result> { + let ordering_req = [PhysicalSortExpr { + expr: col("b", schema)?, + options: sort_options, + }]; + let args = [col("b", schema)?]; + AggregateExprBuilder::new(last_value_udaf(), args.to_vec()) + .order_by(LexOrdering::new(ordering_req.to_vec())) + .schema(Arc::new(schema.clone())) + .alias(String::from("last_value(b) ORDER BY [b ASC NULLS LAST]")) + .build() + .map(Arc::new) + } + // This function either constructs the physical plan below, // // "AggregateExec: mode=Final, gby=[a@0 as a], aggr=[FIRST_VALUE(b)]", @@ -1969,9 +2136,10 @@ mod tests { use_coalesce_batches: bool, is_first_acc: bool, spill: bool, + max_memory: usize, ) -> Result<()> { let task_ctx = if spill { - new_spill_ctx(2, 3200) + new_spill_ctx(2, max_memory) } else { Arc::new(TaskContext::default()) }; @@ -1985,27 +2153,14 @@ mod tests { let groups = PhysicalGroupBy::new_single(vec![(col("a", &schema)?, "a".to_string())]); - let ordering_req = vec![PhysicalSortExpr { - expr: col("b", &schema)?, - options: SortOptions::default(), - }]; - let aggregates: Vec> = if is_first_acc { - vec![Arc::new(FirstValue::new( - col("b", &schema)?, - "FIRST_VALUE(b)".to_string(), - DataType::Float64, - ordering_req.clone(), - vec![DataType::Float64], - vec![], - ))] + let sort_options = SortOptions { + descending: false, + nulls_first: false, + }; + let aggregates: Vec> = if is_first_acc { + vec![test_first_value_agg_expr(&schema, sort_options)?] } else { - vec![Arc::new(LastValue::new( - col("b", &schema)?, - "LAST_VALUE(b)".to_string(), - DataType::Float64, - ordering_req.clone(), - vec![DataType::Float64], - ))] + vec![test_last_value_agg_expr(&schema, sort_options)?] }; let memory_exec = Arc::new(MemoryExec::try_new( @@ -2015,7 +2170,7 @@ mod tests { vec![partition3], vec![partition4], ], - schema.clone(), + Arc::clone(&schema), None, )?); let aggregate_exec = Arc::new(AggregateExec::try_new( @@ -2024,7 +2179,7 @@ mod tests { aggregates.clone(), vec![None], memory_exec, - schema.clone(), + Arc::clone(&schema), )?); let coalesce = if use_coalesce_batches { let coalesce = Arc::new(CoalescePartitionsExec::new(aggregate_exec)); @@ -2045,24 +2200,24 @@ mod tests { let result = crate::collect(aggregate_final, task_ctx).await?; if is_first_acc { let expected = [ - "+---+----------------+", - "| a | FIRST_VALUE(b) |", - "+---+----------------+", - "| 2 | 0.0 |", - "| 3 | 1.0 |", - "| 4 | 3.0 |", - "+---+----------------+", + "+---+--------------------------------------------+", + "| a | first_value(b) ORDER BY [b ASC NULLS LAST] |", + "+---+--------------------------------------------+", + "| 2 | 0.0 |", + "| 3 | 1.0 |", + "| 4 | 3.0 |", + "+---+--------------------------------------------+", ]; assert_batches_eq!(expected, &result); } else { let expected = [ - "+---+---------------+", - "| a | LAST_VALUE(b) |", - "+---+---------------+", - "| 2 | 3.0 |", - "| 3 | 5.0 |", - "| 4 | 6.0 |", - "+---+---------------+", + "+---+-------------------------------------------+", + "| a | last_value(b) ORDER BY [b ASC NULLS LAST] |", + "+---+-------------------------------------------+", + "| 2 | 3.0 |", + "| 3 | 5.0 |", + "| 4 | 6.0 |", + "+---+-------------------------------------------+", ]; assert_batches_eq!(expected, &result); }; @@ -2072,6 +2227,7 @@ mod tests { #[tokio::test] async fn test_get_finest_requirements() -> Result<()> { let test_schema = create_test_schema()?; + // Assume column a and b are aliases // Assume also that a ASC and c DESC describe the same global ordering for the table. (Since they are ordering equivalent). let options1 = SortOptions { @@ -2081,63 +2237,64 @@ mod tests { let col_a = &col("a", &test_schema)?; let col_b = &col("b", &test_schema)?; let col_c = &col("c", &test_schema)?; - let mut eq_properties = EquivalenceProperties::new(test_schema); + let mut eq_properties = EquivalenceProperties::new(Arc::clone(&test_schema)); // Columns a and b are equal. - eq_properties.add_equal_conditions(col_a, col_b); + eq_properties.add_equal_conditions(col_a, col_b)?; // Aggregate requirements are // [None], [a ASC], [a ASC, b ASC, c ASC], [a ASC, b ASC] respectively let order_by_exprs = vec![ None, Some(vec![PhysicalSortExpr { - expr: col_a.clone(), + expr: Arc::clone(col_a), options: options1, }]), Some(vec![ PhysicalSortExpr { - expr: col_a.clone(), + expr: Arc::clone(col_a), options: options1, }, PhysicalSortExpr { - expr: col_b.clone(), + expr: Arc::clone(col_b), options: options1, }, PhysicalSortExpr { - expr: col_c.clone(), + expr: Arc::clone(col_c), options: options1, }, ]), Some(vec![ PhysicalSortExpr { - expr: col_a.clone(), + expr: Arc::clone(col_a), options: options1, }, PhysicalSortExpr { - expr: col_b.clone(), + expr: Arc::clone(col_b), options: options1, }, ]), ]; - let common_requirement = vec![ + + let common_requirement = LexOrdering::new(vec![ PhysicalSortExpr { - expr: col_a.clone(), + expr: Arc::clone(col_a), options: options1, }, PhysicalSortExpr { - expr: col_c.clone(), + expr: Arc::clone(col_c), options: options1, }, - ]; + ]); let mut aggr_exprs = order_by_exprs .into_iter() .map(|order_by_expr| { - Arc::new(OrderSensitiveArrayAgg::new( - col_a.clone(), - "array_agg", - DataType::Int32, - false, - vec![], - order_by_expr.unwrap_or_default(), - )) as _ + let ordering_req = order_by_expr.unwrap_or_default(); + AggregateExprBuilder::new(array_agg_udaf(), vec![Arc::clone(col_a)]) + .alias("a") + .order_by(LexOrdering::new(ordering_req.to_vec())) + .schema(Arc::clone(&test_schema)) + .build() + .map(Arc::new) + .unwrap() }) .collect::>(); let group_by = PhysicalGroupBy::new_single(vec![]); @@ -2160,48 +2317,457 @@ mod tests { ])); let col_a = col("a", &schema)?; - let col_b = col("b", &schema)?; let option_desc = SortOptions { descending: true, nulls_first: true, }; - let sort_expr = vec![PhysicalSortExpr { - expr: col_b.clone(), - options: option_desc, - }]; - let sort_expr_reverse = reverse_order_bys(&sort_expr); let groups = PhysicalGroupBy::new_single(vec![(col_a, "a".to_string())]); - let aggregates: Vec> = vec![ - Arc::new(FirstValue::new( - col_b.clone(), - "FIRST_VALUE(b)".to_string(), - DataType::Float64, - sort_expr_reverse.clone(), - vec![DataType::Float64], - vec![], - )), - Arc::new(LastValue::new( - col_b.clone(), - "LAST_VALUE(b)".to_string(), - DataType::Float64, - sort_expr.clone(), - vec![DataType::Float64], - )), + let aggregates: Vec> = vec![ + test_first_value_agg_expr(&schema, option_desc)?, + test_last_value_agg_expr(&schema, option_desc)?, ]; let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let aggregate_exec = Arc::new(AggregateExec::try_new( AggregateMode::Partial, groups, - aggregates.clone(), + aggregates, vec![None, None], - blocking_exec.clone(), + Arc::clone(&blocking_exec) as Arc, schema, )?); - let new_agg = aggregate_exec - .clone() - .with_new_children(vec![blocking_exec])?; + let new_agg = + Arc::clone(&aggregate_exec).with_new_children(vec![blocking_exec])?; assert_eq!(new_agg.schema(), aggregate_exec.schema()); Ok(()) } + + #[tokio::test] + async fn test_agg_exec_group_by_const() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float32, true), + Field::new("b", DataType::Float32, true), + Field::new("const", DataType::Int32, false), + ])); + + let col_a = col("a", &schema)?; + let col_b = col("b", &schema)?; + let const_expr = Arc::new(Literal::new(ScalarValue::Int32(Some(1)))); + + let groups = PhysicalGroupBy::new( + vec![ + (col_a, "a".to_string()), + (col_b, "b".to_string()), + (const_expr, "const".to_string()), + ], + vec![ + ( + Arc::new(Literal::new(ScalarValue::Float32(None))), + "a".to_string(), + ), + ( + Arc::new(Literal::new(ScalarValue::Float32(None))), + "b".to_string(), + ), + ( + Arc::new(Literal::new(ScalarValue::Int32(None))), + "const".to_string(), + ), + ], + vec![ + vec![false, true, true], + vec![true, false, true], + vec![true, true, false], + ], + ); + + let aggregates: Vec> = + vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1)]) + .schema(Arc::clone(&schema)) + .alias("1") + .build() + .map(Arc::new)?]; + + let input_batches = (0..4) + .map(|_| { + let a = Arc::new(Float32Array::from(vec![0.; 8192])); + let b = Arc::new(Float32Array::from(vec![0.; 8192])); + let c = Arc::new(Int32Array::from(vec![1; 8192])); + + RecordBatch::try_new(Arc::clone(&schema), vec![a, b, c]).unwrap() + }) + .collect(); + + let input = Arc::new(MemoryExec::try_new( + &[input_batches], + Arc::clone(&schema), + None, + )?); + + let aggregate_exec = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + groups, + aggregates.clone(), + vec![None], + input, + schema, + )?); + + let output = + collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?; + + let expected = [ + "+-----+-----+-------+---------------+-------+", + "| a | b | const | __grouping_id | 1 |", + "+-----+-----+-------+---------------+-------+", + "| | | 1 | 6 | 32768 |", + "| | 0.0 | | 5 | 32768 |", + "| 0.0 | | | 3 | 32768 |", + "+-----+-----+-------+---------------+-------+", + ]; + assert_batches_sorted_eq!(expected, &output); + + Ok(()) + } + + #[tokio::test] + async fn test_agg_exec_struct_of_dicts() -> Result<()> { + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![ + Field::new( + "labels".to_string(), + DataType::Struct( + vec![ + Field::new_dict( + "a".to_string(), + DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8), + ), + true, + 0, + false, + ), + Field::new_dict( + "b".to_string(), + DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8), + ), + true, + 0, + false, + ), + ] + .into(), + ), + false, + ), + Field::new("value", DataType::UInt64, false), + ])), + vec![ + Arc::new(StructArray::from(vec![ + ( + Arc::new(Field::new_dict( + "a".to_string(), + DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8), + ), + true, + 0, + false, + )), + Arc::new( + vec![Some("a"), None, Some("a")] + .into_iter() + .collect::>(), + ) as ArrayRef, + ), + ( + Arc::new(Field::new_dict( + "b".to_string(), + DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Utf8), + ), + true, + 0, + false, + )), + Arc::new( + vec![Some("b"), Some("c"), Some("b")] + .into_iter() + .collect::>(), + ) as ArrayRef, + ), + ])), + Arc::new(UInt64Array::from(vec![1, 1, 1])), + ], + ) + .expect("Failed to create RecordBatch"); + + let group_by = PhysicalGroupBy::new_single(vec![( + col("labels", &batch.schema())?, + "labels".to_string(), + )]); + + let aggr_expr = vec![AggregateExprBuilder::new( + sum_udaf(), + vec![col("value", &batch.schema())?], + ) + .schema(Arc::clone(&batch.schema())) + .alias(String::from("SUM(value)")) + .build() + .map(Arc::new)?]; + + let input = Arc::new(MemoryExec::try_new( + &[vec![batch.clone()]], + Arc::::clone(&batch.schema()), + None, + )?); + let aggregate_exec = Arc::new(AggregateExec::try_new( + AggregateMode::FinalPartitioned, + group_by, + aggr_expr, + vec![None], + Arc::clone(&input) as Arc, + batch.schema(), + )?); + + let session_config = SessionConfig::default(); + let ctx = TaskContext::default().with_session_config(session_config); + let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?; + + let expected = [ + "+--------------+------------+", + "| labels | SUM(value) |", + "+--------------+------------+", + "| {a: a, b: b} | 2 |", + "| {a: , b: c} | 1 |", + "+--------------+------------+", + ]; + assert_batches_eq!(expected, &output); + + Ok(()) + } + + #[tokio::test] + async fn test_skip_aggregation_after_first_batch() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("key", DataType::Int32, true), + Field::new("val", DataType::Int32, true), + ])); + + let group_by = + PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]); + + let aggr_expr = + vec![ + AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?]) + .schema(Arc::clone(&schema)) + .alias(String::from("COUNT(val)")) + .build() + .map(Arc::new)?, + ]; + + let input_data = vec![ + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(Int32Array::from(vec![0, 0, 0])), + ], + ) + .unwrap(), + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![2, 3, 4])), + Arc::new(Int32Array::from(vec![0, 0, 0])), + ], + ) + .unwrap(), + ]; + + let input = Arc::new(MemoryExec::try_new( + &[input_data], + Arc::clone(&schema), + None, + )?); + let aggregate_exec = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + group_by, + aggr_expr, + vec![None], + Arc::clone(&input) as Arc, + schema, + )?); + + let mut session_config = SessionConfig::default(); + session_config = session_config.set( + "datafusion.execution.skip_partial_aggregation_probe_rows_threshold", + &ScalarValue::Int64(Some(2)), + ); + session_config = session_config.set( + "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold", + &ScalarValue::Float64(Some(0.1)), + ); + + let ctx = TaskContext::default().with_session_config(session_config); + let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?; + + let expected = [ + "+-----+-------------------+", + "| key | COUNT(val)[count] |", + "+-----+-------------------+", + "| 1 | 1 |", + "| 2 | 1 |", + "| 3 | 1 |", + "| 2 | 1 |", + "| 3 | 1 |", + "| 4 | 1 |", + "+-----+-------------------+", + ]; + assert_batches_eq!(expected, &output); + + Ok(()) + } + + #[tokio::test] + async fn test_skip_aggregation_after_threshold() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("key", DataType::Int32, true), + Field::new("val", DataType::Int32, true), + ])); + + let group_by = + PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]); + + let aggr_expr = + vec![ + AggregateExprBuilder::new(count_udaf(), vec![col("val", &schema)?]) + .schema(Arc::clone(&schema)) + .alias(String::from("COUNT(val)")) + .build() + .map(Arc::new)?, + ]; + + let input_data = vec![ + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(Int32Array::from(vec![0, 0, 0])), + ], + ) + .unwrap(), + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![2, 3, 4])), + Arc::new(Int32Array::from(vec![0, 0, 0])), + ], + ) + .unwrap(), + RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![2, 3, 4])), + Arc::new(Int32Array::from(vec![0, 0, 0])), + ], + ) + .unwrap(), + ]; + + let input = Arc::new(MemoryExec::try_new( + &[input_data], + Arc::clone(&schema), + None, + )?); + let aggregate_exec = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + group_by, + aggr_expr, + vec![None], + Arc::clone(&input) as Arc, + schema, + )?); + + let mut session_config = SessionConfig::default(); + session_config = session_config.set( + "datafusion.execution.skip_partial_aggregation_probe_rows_threshold", + &ScalarValue::Int64(Some(5)), + ); + session_config = session_config.set( + "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold", + &ScalarValue::Float64(Some(0.1)), + ); + + let ctx = TaskContext::default().with_session_config(session_config); + let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?; + + let expected = [ + "+-----+-------------------+", + "| key | COUNT(val)[count] |", + "+-----+-------------------+", + "| 1 | 1 |", + "| 2 | 2 |", + "| 3 | 2 |", + "| 4 | 1 |", + "| 2 | 1 |", + "| 3 | 1 |", + "| 4 | 1 |", + "+-----+-------------------+", + ]; + assert_batches_eq!(expected, &output); + + Ok(()) + } + + #[test] + fn group_exprs_nullable() -> Result<()> { + let input_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float32, false), + Field::new("b", DataType::Float32, false), + ])); + + let aggr_expr = + vec![ + AggregateExprBuilder::new(count_udaf(), vec![col("a", &input_schema)?]) + .schema(Arc::clone(&input_schema)) + .alias("COUNT(a)") + .build() + .map(Arc::new)?, + ]; + + let grouping_set = PhysicalGroupBy::new( + vec![ + (col("a", &input_schema)?, "a".to_string()), + (col("b", &input_schema)?, "b".to_string()), + ], + vec![ + (lit(ScalarValue::Float32(None)), "a".to_string()), + (lit(ScalarValue::Float32(None)), "b".to_string()), + ], + vec![ + vec![false, true], // (a, NULL) + vec![false, false], // (a,b) + ], + ); + let aggr_schema = create_schema( + &input_schema, + &grouping_set, + &aggr_expr, + AggregateMode::Final, + )?; + let expected_schema = Schema::new(vec![ + Field::new("a", DataType::Float32, false), + Field::new("b", DataType::Float32, true), + Field::new("__grouping_id", DataType::UInt8, false), + Field::new("COUNT(a)", DataType::Int64, false), + ]); + assert_eq!(aggr_schema, expected_schema); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/aggregates/no_grouping.rs b/datafusion/physical-plan/src/aggregates/no_grouping.rs index 5ec95bd79942..99417e4ee3e9 100644 --- a/datafusion/physical-plan/src/aggregates/no_grouping.rs +++ b/datafusion/physical-plan/src/aggregates/no_grouping.rs @@ -140,8 +140,11 @@ impl AggregateStream { let result = finalize_aggregation(&mut this.accumulators, &this.mode) .and_then(|columns| { - RecordBatch::try_new(this.schema.clone(), columns) - .map_err(Into::into) + RecordBatch::try_new( + Arc::clone(&this.schema), + columns, + ) + .map_err(Into::into) }) .record_output(&this.baseline_metrics); @@ -181,7 +184,7 @@ impl Stream for AggregateStream { impl RecordBatchStream for AggregateStream { fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -215,6 +218,7 @@ fn aggregate_batch( Some(filter) => Cow::Owned(batch_filter(&batch, filter)?), None => Cow::Borrowed(&batch), }; + // 1.3 let values = &expr .iter() diff --git a/datafusion/physical-plan/src/aggregates/order/full.rs b/datafusion/physical-plan/src/aggregates/order/full.rs index c15538e8ab8e..218855459b1e 100644 --- a/datafusion/physical-plan/src/aggregates/order/full.rs +++ b/datafusion/physical-plan/src/aggregates/order/full.rs @@ -16,12 +16,13 @@ // under the License. use datafusion_expr::EmitTo; +use std::mem::size_of; /// Tracks grouping state when the data is ordered entirely by its /// group keys /// /// When the group values are sorted, as soon as we see group `n+1` we -/// know we will never see any rows for group `n again and thus they +/// know we will never see any rows for group `n` again and thus they /// can be emitted. /// /// For example, given `SUM(amt) GROUP BY id` if the input is sorted @@ -54,7 +55,7 @@ use datafusion_expr::EmitTo; /// `0..12` can be emitted. Note that `13` can not yet be emitted as /// there may be more values in the next batch with the same group_id. #[derive(Debug)] -pub(crate) struct GroupOrderingFull { +pub struct GroupOrderingFull { state: State, } @@ -63,7 +64,7 @@ enum State { /// Seen no input yet Start, - /// Data is in progress. `current is the current group for which + /// Data is in progress. `current` is the current group for which /// values are being generated. Can emit `current` - 1 InProgress { current: usize }, @@ -139,6 +140,12 @@ impl GroupOrderingFull { } pub(crate) fn size(&self) -> usize { - std::mem::size_of::() + size_of::() + } +} + +impl Default for GroupOrderingFull { + fn default() -> Self { + Self::new() } } diff --git a/datafusion/physical-plan/src/aggregates/order/mod.rs b/datafusion/physical-plan/src/aggregates/order/mod.rs index 556103e1e222..24846d239591 100644 --- a/datafusion/physical-plan/src/aggregates/order/mod.rs +++ b/datafusion/physical-plan/src/aggregates/order/mod.rs @@ -19,18 +19,19 @@ use arrow_array::ArrayRef; use arrow_schema::Schema; use datafusion_common::Result; use datafusion_expr::EmitTo; -use datafusion_physical_expr::PhysicalSortExpr; +use datafusion_physical_expr_common::sort_expr::LexOrderingRef; +use std::mem::size_of; mod full; mod partial; use crate::InputOrderMode; -pub(crate) use full::GroupOrderingFull; -pub(crate) use partial::GroupOrderingPartial; +pub use full::GroupOrderingFull; +pub use partial::GroupOrderingPartial; /// Ordering information for each group in the hash table #[derive(Debug)] -pub(crate) enum GroupOrdering { +pub enum GroupOrdering { /// Groups are not ordered None, /// Groups are ordered by some pre-set of the group keys @@ -44,7 +45,7 @@ impl GroupOrdering { pub fn try_new( input_schema: &Schema, mode: &InputOrderMode, - ordering: &[PhysicalSortExpr], + ordering: LexOrderingRef, ) -> Result { match mode { InputOrderMode::Linear => Ok(GroupOrdering::None), @@ -87,7 +88,7 @@ impl GroupOrdering { /// Called when new groups are added in a batch /// /// * `total_num_groups`: total number of groups (so max - /// group_index is total_num_groups - 1). + /// group_index is total_num_groups - 1). /// /// * `group_values`: group key values for *each row* in the batch /// @@ -117,8 +118,8 @@ impl GroupOrdering { } /// Return the size of memory used by the ordering state, in bytes - pub(crate) fn size(&self) -> usize { - std::mem::size_of::() + pub fn size(&self) -> usize { + size_of::() + match self { GroupOrdering::None => 0, GroupOrdering::Partial(partial) => partial.size(), diff --git a/datafusion/physical-plan/src/aggregates/order/partial.rs b/datafusion/physical-plan/src/aggregates/order/partial.rs index ecd37c913e98..5cc55dc0d028 100644 --- a/datafusion/physical-plan/src/aggregates/order/partial.rs +++ b/datafusion/physical-plan/src/aggregates/order/partial.rs @@ -21,7 +21,9 @@ use arrow_schema::Schema; use datafusion_common::Result; use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_expr::EmitTo; -use datafusion_physical_expr::PhysicalSortExpr; +use datafusion_physical_expr_common::sort_expr::LexOrderingRef; +use std::mem::size_of; +use std::sync::Arc; /// Tracks grouping state when the data is ordered by some subset of /// the group keys. @@ -31,7 +33,7 @@ use datafusion_physical_expr::PhysicalSortExpr; /// key and earlier. /// /// For example, given `SUM(amt) GROUP BY id, state` if the input is -/// sorted by `state, when a new value of `state` is seen, all groups +/// sorted by `state`, when a new value of `state` is seen, all groups /// with prior values of `state` can be emitted. /// /// The state is tracked like this: @@ -59,7 +61,7 @@ use datafusion_physical_expr::PhysicalSortExpr; /// order) recent group index ///``` #[derive(Debug)] -pub(crate) struct GroupOrderingPartial { +pub struct GroupOrderingPartial { /// State machine state: State, @@ -105,7 +107,7 @@ impl GroupOrderingPartial { pub fn try_new( input_schema: &Schema, order_indices: &[usize], - ordering: &[PhysicalSortExpr], + ordering: LexOrderingRef, ) -> Result { assert!(!order_indices.is_empty()); assert!(order_indices.len() <= ordering.len()); @@ -138,7 +140,7 @@ impl GroupOrderingPartial { let sort_values: Vec<_> = self .order_indices .iter() - .map(|&idx| group_values[idx].clone()) + .map(|&idx| Arc::clone(&group_values[idx])) .collect(); Ok(self.row_converter.convert_columns(&sort_values)?) @@ -243,7 +245,7 @@ impl GroupOrderingPartial { /// Return the size of memory allocated by this structure pub(crate) fn size(&self) -> usize { - std::mem::size_of::() + size_of::() + self.order_indices.allocated_size() + self.row_converter.size() } diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index ad0860b93a3a..fe05f7375ed3 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -27,18 +27,18 @@ use crate::aggregates::{ evaluate_group_by, evaluate_many, evaluate_optional, group_schema, AggregateMode, PhysicalGroupBy, }; -use crate::common::IPCWriter; -use crate::metrics::{BaselineMetrics, RecordOutput}; -use crate::sorts::sort::{read_spill_as_stream, sort_batch}; -use crate::sorts::streaming_merge; +use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput}; +use crate::sorts::sort::sort_batch; +use crate::sorts::streaming_merge::StreamingMergeBuilder; +use crate::spill::{read_spill_as_stream, spill_record_batch_by_size}; use crate::stream::RecordBatchStreamAdapter; -use crate::{aggregates, ExecutionPlan, PhysicalExpr}; +use crate::{aggregates, metrics, ExecutionPlan, PhysicalExpr}; use crate::{RecordBatchStream, SendableRecordBatchStream}; use arrow::array::*; use arrow::datatypes::SchemaRef; use arrow_schema::SortOptions; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{internal_err, DataFusionError, Result}; use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; @@ -46,10 +46,12 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; use datafusion_expr::{EmitTo, GroupsAccumulator}; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::{ - AggregateExpr, GroupsAccumulatorAdapter, PhysicalSortExpr, -}; +use datafusion_physical_expr::{GroupsAccumulatorAdapter, PhysicalSortExpr}; +use super::order::GroupOrdering; +use super::AggregateExec; +use datafusion_physical_expr::aggregate::AggregateFunctionExpr; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use futures::ready; use futures::stream::{Stream, StreamExt}; use log::debug; @@ -61,32 +63,150 @@ pub(crate) enum ExecutionState { /// When producing output, the remaining rows to output are stored /// here and are sliced off as needed in batch_size chunks ProducingOutput(RecordBatch), + /// Produce intermediate aggregate state for each input row without + /// aggregation. + /// + /// See "partial aggregation" discussion on [`GroupedHashAggregateStream`] + SkippingAggregation, + /// All input has been consumed and all groups have been emitted Done, } -use super::order::GroupOrdering; -use super::AggregateExec; - /// This encapsulates the spilling state struct SpillState { - /// If data has previously been spilled, the locations of the - /// spill files (in Arrow IPC format) - spills: Vec, - + // ======================================================================== + // PROPERTIES: + // These fields are initialized at the start and remain constant throughout + // the execution. + // ======================================================================== /// Sorting expression for spilling batches - spill_expr: Vec, + spill_expr: LexOrdering, /// Schema for spilling batches spill_schema: SchemaRef, - /// true when streaming merge is in progress - is_stream_merging: bool, - /// aggregate_arguments for merging spilled data merging_aggregate_arguments: Vec>>, /// GROUP BY expressions for merging spilled data merging_group_by: PhysicalGroupBy, + + // ======================================================================== + // STATES: + // Fields changes during execution. Can be buffer, or state flags that + // influence the execution in parent `GroupedHashAggregateStream` + // ======================================================================== + /// If data has previously been spilled, the locations of the + /// spill files (in Arrow IPC format) + spills: Vec, + + /// true when streaming merge is in progress + is_stream_merging: bool, + + // ======================================================================== + // METRICS: + // ======================================================================== + /// Peak memory used for buffered data. + /// Calculated as sum of peak memory values across partitions + peak_mem_used: metrics::Gauge, + /// count of spill files during the execution of the operator + spill_count: metrics::Count, + /// total spilled bytes during the execution of the operator + spilled_bytes: metrics::Count, + /// total spilled rows during the execution of the operator + spilled_rows: metrics::Count, +} + +/// Tracks if the aggregate should skip partial aggregations +/// +/// See "partial aggregation" discussion on [`GroupedHashAggregateStream`] +struct SkipAggregationProbe { + // ======================================================================== + // PROPERTIES: + // These fields are initialized at the start and remain constant throughout + // the execution. + // ======================================================================== + /// Aggregation ratio check performed when the number of input rows exceeds + /// this threshold (from `SessionConfig`) + probe_rows_threshold: usize, + /// Maximum ratio of `num_groups` to `input_rows` for continuing aggregation + /// (from `SessionConfig`). If the ratio exceeds this value, aggregation + /// is skipped and input rows are directly converted to output + probe_ratio_threshold: f64, + + // ======================================================================== + // STATES: + // Fields changes during execution. Can be buffer, or state flags that + // influence the exeuction in parent `GroupedHashAggregateStream` + // ======================================================================== + /// Number of processed input rows (updated during probing) + input_rows: usize, + /// Number of total group values for `input_rows` (updated during probing) + num_groups: usize, + + /// Flag indicating further data aggregation may be skipped (decision made + /// when probing complete) + should_skip: bool, + /// Flag indicating further updates of `SkipAggregationProbe` state won't + /// make any effect (set either while probing or on probing completion) + is_locked: bool, + + // ======================================================================== + // METRICS: + // ======================================================================== + /// Number of rows where state was output without aggregation. + /// + /// * If 0, all input rows were aggregated (should_skip was always false) + /// + /// * if greater than zero, the number of rows which were output directly + /// without aggregation + skipped_aggregation_rows: metrics::Count, +} + +impl SkipAggregationProbe { + fn new( + probe_rows_threshold: usize, + probe_ratio_threshold: f64, + skipped_aggregation_rows: metrics::Count, + ) -> Self { + Self { + input_rows: 0, + num_groups: 0, + probe_rows_threshold, + probe_ratio_threshold, + should_skip: false, + is_locked: false, + skipped_aggregation_rows, + } + } + + /// Updates `SkipAggregationProbe` state: + /// - increments the number of input rows + /// - replaces the number of groups with the new value + /// - on `probe_rows_threshold` exceeded calculates + /// aggregation ratio and sets `should_skip` flag + /// - if `should_skip` is set, locks further state updates + fn update_state(&mut self, input_rows: usize, num_groups: usize) { + if self.is_locked { + return; + } + self.input_rows += input_rows; + self.num_groups = num_groups; + if self.input_rows >= self.probe_rows_threshold { + self.should_skip = self.num_groups as f64 / self.input_rows as f64 + >= self.probe_ratio_threshold; + self.is_locked = true; + } + } + + fn should_skip(&self) -> bool { + self.should_skip + } + + /// Record the number of rows that were output directly without aggregation + fn record_skipped(&mut self, batch: &RecordBatch) { + self.skipped_aggregation_rows.add(batch.num_rows()); + } } /// HashTable based Grouping Aggregator @@ -136,7 +256,7 @@ struct SpillState { /// of `x` and one accumulator for `SUM(y)`, specialized for the data /// type of `y`. /// -/// # Description +/// # Discussion /// /// [`group_values`] does not store any aggregate state inline. It only /// assigns "group indices", one for each (distinct) group value. The @@ -154,7 +274,25 @@ struct SpillState { /// /// [`group_values`]: Self::group_values /// -/// # Spilling +/// # Partial Aggregate and multi-phase grouping +/// +/// As described on [`Accumulator::state`], this operator is used in the context +/// "multi-phase" grouping when the mode is [`AggregateMode::Partial`]. +/// +/// An important optimization for multi-phase partial aggregation is to skip +/// partial aggregation when it is not effective enough to warrant the memory or +/// CPU cost, as is often the case for queries many distinct groups (high +/// cardinality group by). Memory is particularly important because each Partial +/// aggregator must store the intermediate state for each group. +/// +/// If the ratio of the number of groups to the number of input rows exceeds a +/// threshold, and [`GroupsAccumulator::supports_convert_to_state`] is +/// supported, this operator will stop applying Partial aggregation and directly +/// pass the input rows to the next aggregation phase. +/// +/// [`Accumulator::state`]: datafusion_expr::Accumulator::state +/// +/// # Spilling (to disk) /// /// The sizes of group values and accumulators can become large. Before that causes out of memory, /// this hash aggregator outputs partial states early for partial aggregation or spills to local @@ -205,17 +343,15 @@ struct SpillState { /// └─────────────────┘ └─────────────────┘ /// ``` pub(crate) struct GroupedHashAggregateStream { + // ======================================================================== + // PROPERTIES: + // These fields are initialized at the start and remain constant throughout + // the execution. + // ======================================================================== schema: SchemaRef, input: SendableRecordBatchStream, mode: AggregateMode, - /// Accumulators, one for each `AggregateExpr` in the query - /// - /// For example, if the query has aggregates, `SUM(x)`, - /// `COUNT(y)`, there will be two accumulators, each one - /// specialized for that particular aggregate and its input types - accumulators: Vec>, - /// Arguments to pass to each accumulator. /// /// The arguments in `accumulator[i]` is passed `aggregate_arguments[i]` @@ -236,9 +372,30 @@ pub(crate) struct GroupedHashAggregateStream { /// GROUP BY expressions group_by: PhysicalGroupBy, - /// The memory reservation for this grouping - reservation: MemoryReservation, + /// max rows in output RecordBatches + batch_size: usize, + + /// Optional soft limit on the number of `group_values` in a batch + /// If the number of `group_values` in a single batch exceeds this value, + /// the `GroupedHashAggregateStream` operation immediately switches to + /// output mode and emits all groups. + group_values_soft_limit: Option, + // ======================================================================== + // STATE FLAGS: + // These fields will be updated during the execution. And control the flow of + // the execution. + // ======================================================================== + /// Tracks if this stream is generating input or output + exec_state: ExecutionState, + + /// Have we seen the end of the input + input_done: bool, + + // ======================================================================== + // STATE BUFFERS: + // These fields will accumulate intermediate results during the execution. + // ======================================================================== /// An interning store of group keys group_values: Box, @@ -246,34 +403,41 @@ pub(crate) struct GroupedHashAggregateStream { /// processed. Reused across batches here to avoid reallocations current_group_indices: Vec, - /// Tracks if this stream is generating input or output - exec_state: ExecutionState, - - /// Execution metrics - baseline_metrics: BaselineMetrics, - - /// max rows in output RecordBatches - batch_size: usize, + /// Accumulators, one for each `AggregateFunctionExpr` in the query + /// + /// For example, if the query has aggregates, `SUM(x)`, + /// `COUNT(y)`, there will be two accumulators, each one + /// specialized for that particular aggregate and its input types + accumulators: Vec>, + // ======================================================================== + // TASK-SPECIFIC STATES: + // Inner states groups together properties, states for a specific task. + // ======================================================================== /// Optional ordering information, that might allow groups to be /// emitted from the hash table prior to seeing the end of the /// input group_ordering: GroupOrdering, - /// Have we seen the end of the input - input_done: bool, - - /// The [`RuntimeEnv`] associated with the [`TaskContext`] argument - runtime: Arc, - /// The spill state object spill_state: SpillState, - /// Optional soft limit on the number of `group_values` in a batch - /// If the number of `group_values` in a single batch exceeds this value, - /// the `GroupedHashAggregateStream` operation immediately switches to - /// output mode and emits all groups. - group_values_soft_limit: Option, + /// Optional probe for skipping data aggregation, if supported by + /// current stream. + skip_aggregation_probe: Option, + + // ======================================================================== + // EXECUTION RESOURCES: + // Fields related to managing execution resources and monitoring performance. + // ======================================================================== + /// The memory reservation for this grouping + reservation: MemoryReservation, + + /// Execution metrics + baseline_metrics: BaselineMetrics, + + /// The [`RuntimeEnv`] associated with the [`TaskContext`] argument + runtime: Arc, } impl GroupedHashAggregateStream { @@ -301,13 +465,13 @@ impl GroupedHashAggregateStream { let aggregate_arguments = aggregates::aggregate_expressions( &agg.aggr_expr, &agg.mode, - agg_group_by.expr.len(), + agg_group_by.num_group_exprs(), )?; // arguments for aggregating spilled data is the same as the one for final aggregation let merging_aggregate_arguments = aggregates::aggregate_expressions( &agg.aggr_expr, &AggregateMode::Final, - agg_group_by.expr.len(), + agg_group_by.num_group_exprs(), )?; let filter_expressions = match agg.mode { @@ -325,7 +489,7 @@ impl GroupedHashAggregateStream { .map(create_group_accumulator) .collect::>()?; - let group_schema = group_schema(&agg_schema, agg_group_by.expr.len()); + let group_schema = group_schema(&agg.input().schema(), &agg_group_by)?; let spill_expr = group_schema .fields .into_iter() @@ -347,7 +511,7 @@ impl GroupedHashAggregateStream { let group_ordering = GroupOrdering::try_new( &group_schema, &agg.input_order_mode, - ordering.as_slice(), + ordering.as_ref(), )?; let group_values = new_group_values(group_schema)?; @@ -358,10 +522,45 @@ impl GroupedHashAggregateStream { let spill_state = SpillState { spills: vec![], spill_expr, - spill_schema: agg_schema.clone(), + spill_schema: Arc::clone(&agg_schema), is_stream_merging: false, merging_aggregate_arguments, merging_group_by: PhysicalGroupBy::new_single(agg_group_by.expr.clone()), + peak_mem_used: MetricBuilder::new(&agg.metrics) + .gauge("peak_mem_used", partition), + spill_count: MetricBuilder::new(&agg.metrics).spill_count(partition), + spilled_bytes: MetricBuilder::new(&agg.metrics).spilled_bytes(partition), + spilled_rows: MetricBuilder::new(&agg.metrics).spilled_rows(partition), + }; + + // Skip aggregation is supported if: + // - aggregation mode is Partial + // - input is not ordered by GROUP BY expressions, + // since Final mode expects unique group values as its input + // - all accumulators support input batch to intermediate + // aggregate state conversion + // - there is only one GROUP BY expressions set + let skip_aggregation_probe = if agg.mode == AggregateMode::Partial + && matches!(group_ordering, GroupOrdering::None) + && accumulators + .iter() + .all(|acc| acc.supports_convert_to_state()) + && agg_group_by.is_single() + { + let options = &context.session_config().options().execution; + let probe_rows_threshold = + options.skip_partial_aggregation_probe_rows_threshold; + let probe_ratio_threshold = + options.skip_partial_aggregation_probe_ratio_threshold; + let skipped_aggregation_rows = MetricBuilder::new(&agg.metrics) + .counter("skipped_aggregation_rows", partition); + Some(SkipAggregationProbe::new( + probe_rows_threshold, + probe_ratio_threshold, + skipped_aggregation_rows, + )) + } else { + None }; Ok(GroupedHashAggregateStream { @@ -383,6 +582,7 @@ impl GroupedHashAggregateStream { runtime: context.runtime_env(), spill_state, group_values_soft_limit: agg.limit, + skip_aggregation_probe, }) } } @@ -391,7 +591,7 @@ impl GroupedHashAggregateStream { /// that is supported by the aggregate, or a /// [`GroupsAccumulatorAdapter`] if not. pub(crate) fn create_group_accumulator( - agg_expr: &Arc, + agg_expr: &Arc, ) -> Result> { if agg_expr.groups_accumulator_supported() { agg_expr.create_groups_accumulator() @@ -401,7 +601,7 @@ pub(crate) fn create_group_accumulator( "Creating GroupsAccumulatorAdapter for {}: {agg_expr:?}", agg_expr.name() ); - let agg_expr_captured = agg_expr.clone(); + let agg_expr_captured = Arc::clone(agg_expr); let factory = move || agg_expr_captured.create_accumulator(); Ok(Box::new(GroupsAccumulatorAdapter::new(factory))) } @@ -430,9 +630,49 @@ impl Stream for GroupedHashAggregateStream { match &self.exec_state { ExecutionState::ReadingInput => 'reading_input: { match ready!(self.input.poll_next_unpin(cx)) { - // new batch to aggregate + // New batch to aggregate in partial aggregation operator + Some(Ok(batch)) if self.mode == AggregateMode::Partial => { + let timer = elapsed_compute.timer(); + let input_rows = batch.num_rows(); + + // Do the grouping + extract_ok!(self.group_aggregate_batch(batch)); + + self.update_skip_aggregation_probe(input_rows); + + // If we can begin emitting rows, do so, + // otherwise keep consuming input + assert!(!self.input_done); + + // If the number of group values equals or exceeds the soft limit, + // emit all groups and switch to producing output + if self.hit_soft_group_limit() { + timer.done(); + extract_ok!(self.set_input_done_and_produce_output()); + // make sure the exec_state just set is not overwritten below + break 'reading_input; + } + + if let Some(to_emit) = self.group_ordering.emit_to() { + let batch = extract_ok!(self.emit(to_emit, false)); + self.exec_state = ExecutionState::ProducingOutput(batch); + timer.done(); + // make sure the exec_state just set is not overwritten below + break 'reading_input; + } + + extract_ok!(self.emit_early_if_necessary()); + + extract_ok!(self.switch_to_skip_aggregation()); + + timer.done(); + } + + // New batch to aggregate in terminal aggregation operator + // (Final/FinalPartitioned/Single/SinglePartitioned) Some(Ok(batch)) => { let timer = elapsed_compute.timer(); + // Make sure we have enough capacity for `batch`, otherwise spill extract_ok!(self.spill_previous_if_necessary(&batch)); @@ -460,14 +700,16 @@ impl Stream for GroupedHashAggregateStream { break 'reading_input; } - extract_ok!(self.emit_early_if_necessary()); - timer.done(); } + + // Found error from input stream Some(Err(e)) => { // inner had error, return to caller return Poll::Ready(Some(Err(e))); } + + // Found end from input stream None => { // inner is done, emit all rows and switch to producing output extract_ok!(self.set_input_done_and_produce_output()); @@ -475,6 +717,29 @@ impl Stream for GroupedHashAggregateStream { } } + ExecutionState::SkippingAggregation => { + match ready!(self.input.poll_next_unpin(cx)) { + Some(Ok(batch)) => { + let _timer = elapsed_compute.timer(); + if let Some(probe) = self.skip_aggregation_probe.as_mut() { + probe.record_skipped(&batch); + } + let states = self.transform_to_states(batch)?; + return Poll::Ready(Some(Ok( + states.record_output(&self.baseline_metrics) + ))); + } + Some(Err(e)) => { + // inner had error, return to caller + return Poll::Ready(Some(Err(e))); + } + None => { + // inner is done, switching to `Done` state + self.exec_state = ExecutionState::Done; + } + } + } + ExecutionState::ProducingOutput(batch) => { // slice off a part of the batch, if needed let output_batch; @@ -483,6 +748,13 @@ impl Stream for GroupedHashAggregateStream { ( if self.input_done { ExecutionState::Done + } + // In Partial aggregation, we also need to check + // if we should trigger partial skipping + else if self.mode == AggregateMode::Partial + && self.should_skip_aggregation() + { + ExecutionState::SkippingAggregation } else { ExecutionState::ReadingInput }, @@ -515,7 +787,7 @@ impl Stream for GroupedHashAggregateStream { impl RecordBatchStream for GroupedHashAggregateStream { fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -614,18 +886,26 @@ impl GroupedHashAggregateStream { fn update_memory_reservation(&mut self) -> Result<()> { let acc = self.accumulators.iter().map(|x| x.size()).sum::(); - self.reservation.try_resize( + let reservation_result = self.reservation.try_resize( acc + self.group_values.size() + self.group_ordering.size() + self.current_group_indices.allocated_size(), - ) + ); + + if reservation_result.is_ok() { + self.spill_state + .peak_mem_used + .set_max(self.reservation.size()); + } + + reservation_result } /// Create an output RecordBatch with the group keys and /// accumulator states/values specified in emit_to fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result { let schema = if spilling { - self.spill_state.spill_schema.clone() + Arc::clone(&self.spill_state.spill_schema) } else { self.schema() }; @@ -669,10 +949,10 @@ impl GroupedHashAggregateStream { if self.group_values.len() > 0 && batch.num_rows() > 0 && matches!(self.group_ordering, GroupOrdering::None) - && !matches!(self.mode, AggregateMode::Partial) && !self.spill_state.is_stream_merging && self.update_memory_reservation().is_err() { + assert_ne!(self.mode, AggregateMode::Partial); // Use input batch (Partial mode) schema for spilling because // the spilled data will be merged and re-evaluated later. self.spill_state.spill_schema = batch.schema(); @@ -685,22 +965,24 @@ impl GroupedHashAggregateStream { /// Emit all rows, sort them, and store them on disk. fn spill(&mut self) -> Result<()> { let emit = self.emit(EmitTo::All, true)?; - let sorted = sort_batch(&emit, &self.spill_state.spill_expr, None)?; + let sorted = sort_batch(&emit, self.spill_state.spill_expr.as_ref(), None)?; let spillfile = self.runtime.disk_manager.create_tmp_file("HashAggSpill")?; - let mut writer = IPCWriter::new(spillfile.path(), &emit.schema())?; // TODO: slice large `sorted` and write to multiple files in parallel - let mut offset = 0; - let total_rows = sorted.num_rows(); - - while offset < total_rows { - let length = std::cmp::min(total_rows - offset, self.batch_size); - let batch = sorted.slice(offset, length); - offset += batch.num_rows(); - writer.write(&batch)?; - } - - writer.finish()?; + spill_record_batch_by_size( + &sorted, + spillfile.path().into(), + sorted.schema(), + self.batch_size, + )?; self.spill_state.spills.push(spillfile); + + // Update metrics + self.spill_state.spill_count.add(1); + self.spill_state + .spilled_bytes + .add(sorted.get_array_memory_size()); + self.spill_state.spilled_rows.add(sorted.num_rows()); + Ok(()) } @@ -723,9 +1005,9 @@ impl GroupedHashAggregateStream { fn emit_early_if_necessary(&mut self) -> Result<()> { if self.group_values.len() >= self.batch_size && matches!(self.group_ordering, GroupOrdering::None) - && matches!(self.mode, AggregateMode::Partial) && self.update_memory_reservation().is_err() { + assert_eq!(self.mode, AggregateMode::Partial); let n = self.group_values.len() / self.batch_size * self.batch_size; let batch = self.emit(EmitTo::First(n), false)?; self.exec_state = ExecutionState::ProducingOutput(batch); @@ -746,25 +1028,24 @@ impl GroupedHashAggregateStream { let expr = self.spill_state.spill_expr.clone(); let schema = batch.schema(); streams.push(Box::pin(RecordBatchStreamAdapter::new( - schema.clone(), + Arc::clone(&schema), futures::stream::once(futures::future::lazy(move |_| { - sort_batch(&batch, &expr, None) + sort_batch(&batch, expr.as_ref(), None) })), ))); for spill in self.spill_state.spills.drain(..) { - let stream = read_spill_as_stream(spill, schema.clone())?; + let stream = read_spill_as_stream(spill, Arc::clone(&schema), 2)?; streams.push(stream); } self.spill_state.is_stream_merging = true; - self.input = streaming_merge( - streams, - schema, - &self.spill_state.spill_expr, - self.baseline_metrics.clone(), - self.batch_size, - None, - self.reservation.new_empty(), - )?; + self.input = StreamingMergeBuilder::new() + .with_streams(streams) + .with_schema(schema) + .with_expressions(self.spill_state.spill_expr.as_ref()) + .with_metrics(self.baseline_metrics.clone()) + .with_batch_size(self.batch_size) + .with_reservation(self.reservation.new_empty()) + .build()?; self.input_done = false; self.group_ordering = GroupOrdering::Full(GroupOrderingFull::new()); Ok(()) @@ -796,4 +1077,68 @@ impl GroupedHashAggregateStream { timer.done(); Ok(()) } + + /// Updates skip aggregation probe state. + /// + /// Notice: It should only be called in Partial aggregation + fn update_skip_aggregation_probe(&mut self, input_rows: usize) { + if let Some(probe) = self.skip_aggregation_probe.as_mut() { + // Skip aggregation probe is not supported if stream has any spills, + // currently spilling is not supported for Partial aggregation + assert!(self.spill_state.spills.is_empty()); + probe.update_state(input_rows, self.group_values.len()); + }; + } + + /// In case the probe indicates that aggregation may be + /// skipped, forces stream to produce currently accumulated output. + /// + /// Notice: It should only be called in Partial aggregation + fn switch_to_skip_aggregation(&mut self) -> Result<()> { + if let Some(probe) = self.skip_aggregation_probe.as_mut() { + if probe.should_skip() { + let batch = self.emit(EmitTo::All, false)?; + self.exec_state = ExecutionState::ProducingOutput(batch); + } + } + + Ok(()) + } + + /// Returns true if the aggregation probe indicates that aggregation + /// should be skipped. + /// + /// Notice: It should only be called in Partial aggregation + fn should_skip_aggregation(&self) -> bool { + self.skip_aggregation_probe + .as_ref() + .is_some_and(|probe| probe.should_skip()) + } + + /// Transforms input batch to intermediate aggregate state, without grouping it + fn transform_to_states(&self, batch: RecordBatch) -> Result { + let mut group_values = evaluate_group_by(&self.group_by, &batch)?; + let input_values = evaluate_many(&self.aggregate_arguments, &batch)?; + let filter_values = evaluate_optional(&self.filter_expressions, &batch)?; + + if group_values.len() != 1 { + return internal_err!("group_values expected to have single element"); + } + let mut output = group_values.swap_remove(0); + + let iter = self + .accumulators + .iter() + .zip(input_values.iter()) + .zip(filter_values.iter()); + + for ((acc, values), opt_filter) in iter { + let opt_filter = opt_filter.as_ref().map(|filter| filter.as_boolean()); + output.extend(acc.convert_to_state(values, opt_filter)?); + } + + let states_batch = RecordBatch::try_new(self.schema(), output)?; + + Ok(states_batch) + } } diff --git a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs index bae4c6133b9f..34df643b6cf0 100644 --- a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs +++ b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs @@ -26,6 +26,7 @@ use arrow_array::cast::AsArray; use arrow_array::{ downcast_primitive, Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray, StringArray, }; +use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use arrow_schema::DataType; use datafusion_common::DataFusionError; use datafusion_common::Result; @@ -108,7 +109,7 @@ impl StringHashTable { Self { owned, map: TopKHashTable::new(limit, limit * 10), - rnd: ahash::RandomState::default(), + rnd: RandomState::default(), } } } @@ -180,7 +181,7 @@ where Self { owned, map: TopKHashTable::new(limit, limit * 10), - rnd: ahash::RandomState::default(), + rnd: RandomState::default(), } } } @@ -363,9 +364,13 @@ macro_rules! has_integer { has_integer!(i8, i16, i32, i64, i128, i256); has_integer!(u8, u16, u32, u64); +has_integer!(IntervalDayTime, IntervalMonthDayNano); hash_float!(f16, f32, f64); -pub fn new_hash_table(limit: usize, kt: DataType) -> Result> { +pub fn new_hash_table( + limit: usize, + kt: DataType, +) -> Result> { macro_rules! downcast_helper { ($kt:ty, $d:ident) => { return Ok(Box::new(PrimitiveHashTable::<$kt>::new(limit))) diff --git a/datafusion/physical-plan/src/aggregates/topk/heap.rs b/datafusion/physical-plan/src/aggregates/topk/heap.rs index 41826ed72853..e694422e443d 100644 --- a/datafusion/physical-plan/src/aggregates/topk/heap.rs +++ b/datafusion/physical-plan/src/aggregates/topk/heap.rs @@ -20,13 +20,14 @@ use arrow::datatypes::i256; use arrow_array::cast::AsArray; use arrow_array::{downcast_primitive, ArrayRef, ArrowPrimitiveType, PrimitiveArray}; +use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use arrow_schema::DataType; use datafusion_common::DataFusionError; use datafusion_common::Result; use datafusion_physical_expr::aggregate::utils::adjust_output_array; use half::f16; use std::cmp::Ordering; -use std::fmt::{Debug, Formatter}; +use std::fmt::{Debug, Display, Formatter}; use std::sync::Arc; /// A custom version of `Ord` that only exists to we can implement it for the Values in our heap @@ -322,29 +323,53 @@ impl TopKHeap { } } - #[cfg(test)] - fn _tree_print(&self, idx: usize) -> Option> { - let hi = self.heap.get(idx)?; - match hi { - None => None, - Some(hi) => { - let label = - format!("val={:?} idx={}, bucket={}", hi.val, idx, hi.map_idx); - let left = self._tree_print(idx * 2 + 1); - let right = self._tree_print(idx * 2 + 2); - let children = left.into_iter().chain(right); - let me = termtree::Tree::new(label).with_leaves(children); - Some(me) + fn _tree_print( + &self, + idx: usize, + prefix: String, + is_tail: bool, + output: &mut String, + ) { + if let Some(Some(hi)) = self.heap.get(idx) { + let connector = if idx != 0 { + if is_tail { + "└── " + } else { + "├── " + } + } else { + "" + }; + output.push_str(&format!( + "{}{}val={:?} idx={}, bucket={}\n", + prefix, connector, hi.val, idx, hi.map_idx + )); + let new_prefix = if is_tail { "" } else { "│ " }; + let child_prefix = format!("{}{}", prefix, new_prefix); + + let left_idx = idx * 2 + 1; + let right_idx = idx * 2 + 2; + + let left_exists = left_idx < self.len; + let right_exists = right_idx < self.len; + + if left_exists { + self._tree_print(left_idx, child_prefix.clone(), !right_exists, output); + } + if right_exists { + self._tree_print(right_idx, child_prefix, true, output); } } } +} - #[cfg(test)] - fn tree_print(&self) -> String { - match self._tree_print(0) { - None => "".to_string(), - Some(root) => format!("{}", root), +impl Display for TopKHeap { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let mut output = String::new(); + if self.heap.first().is_some() { + self._tree_print(0, String::new(), true, &mut output); } + write!(f, "{}", output) } } @@ -360,9 +385,9 @@ impl HeapItem { impl Debug for HeapItem { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.write_str("bucket=")?; - self.map_idx.fmt(f)?; + Debug::fmt(&self.map_idx, f)?; f.write_str(" val=")?; - self.val.fmt(f)?; + Debug::fmt(&self.val, f)?; f.write_str("\n")?; Ok(()) } @@ -431,9 +456,14 @@ macro_rules! compare_integer { compare_integer!(i8, i16, i32, i64, i128, i256); compare_integer!(u8, u16, u32, u64); +compare_integer!(IntervalDayTime, IntervalMonthDayNano); compare_float!(f16, f32, f64); -pub fn new_heap(limit: usize, desc: bool, vt: DataType) -> Result> { +pub fn new_heap( + limit: usize, + desc: bool, + vt: DataType, +) -> Result> { macro_rules! downcast_helper { ($vt:ty, $d:ident) => { return Ok(Box::new(PrimitiveHeap::<$vt>::new(limit, desc, vt))) @@ -460,7 +490,7 @@ mod tests { let mut heap = TopKHeap::new(10, false); heap.append_or_replace(1, 1, &mut map); - let actual = heap.tree_print(); + let actual = heap.to_string(); let expected = r#" val=1 idx=0, bucket=1 "#; @@ -480,7 +510,7 @@ val=1 idx=0, bucket=1 heap.append_or_replace(2, 2, &mut map); assert_eq!(map, vec![(2, 0), (1, 1)]); - let actual = heap.tree_print(); + let actual = heap.to_string(); let expected = r#" val=2 idx=0, bucket=2 └── val=1 idx=1, bucket=1 @@ -498,7 +528,7 @@ val=2 idx=0, bucket=2 heap.append_or_replace(1, 1, &mut map); heap.append_or_replace(2, 2, &mut map); heap.append_or_replace(3, 3, &mut map); - let actual = heap.tree_print(); + let actual = heap.to_string(); let expected = r#" val=3 idx=0, bucket=3 ├── val=1 idx=1, bucket=1 @@ -508,7 +538,7 @@ val=3 idx=0, bucket=3 let mut map = vec![]; heap.append_or_replace(0, 0, &mut map); - let actual = heap.tree_print(); + let actual = heap.to_string(); let expected = r#" val=2 idx=0, bucket=2 ├── val=1 idx=1, bucket=1 @@ -529,7 +559,7 @@ val=2 idx=0, bucket=2 heap.append_or_replace(2, 2, &mut map); heap.append_or_replace(3, 3, &mut map); heap.append_or_replace(4, 4, &mut map); - let actual = heap.tree_print(); + let actual = heap.to_string(); let expected = r#" val=4 idx=0, bucket=4 ├── val=3 idx=1, bucket=3 @@ -540,7 +570,7 @@ val=4 idx=0, bucket=4 let mut map = vec![]; heap.replace_if_better(1, 0, &mut map); - let actual = heap.tree_print(); + let actual = heap.to_string(); let expected = r#" val=4 idx=0, bucket=4 ├── val=1 idx=1, bucket=1 @@ -561,7 +591,7 @@ val=4 idx=0, bucket=4 heap.append_or_replace(1, 1, &mut map); heap.append_or_replace(2, 2, &mut map); - let actual = heap.tree_print(); + let actual = heap.to_string(); let expected = r#" val=2 idx=0, bucket=2 └── val=1 idx=1, bucket=1 @@ -582,7 +612,7 @@ val=2 idx=0, bucket=2 heap.append_or_replace(1, 1, &mut map); heap.append_or_replace(2, 2, &mut map); - let actual = heap.tree_print(); + let actual = heap.to_string(); let expected = r#" val=2 idx=0, bucket=2 └── val=1 idx=1, bucket=1 @@ -605,7 +635,7 @@ val=2 idx=0, bucket=2 heap.append_or_replace(1, 1, &mut map); heap.append_or_replace(2, 2, &mut map); - let actual = heap.tree_print(); + let actual = heap.to_string(); let expected = r#" val=2 idx=0, bucket=2 └── val=1 idx=1, bucket=1 @@ -614,7 +644,7 @@ val=2 idx=0, bucket=2 let numbers = vec![(0, 1), (1, 2)]; heap.renumber(numbers.as_slice()); - let actual = heap.tree_print(); + let actual = heap.to_string(); let expected = r#" val=2 idx=0, bucket=1 └── val=1 idx=1, bucket=2 diff --git a/datafusion/physical-plan/src/aggregates/topk/priority_map.rs b/datafusion/physical-plan/src/aggregates/topk/priority_map.rs index 668018b9c24c..ed41d22e935b 100644 --- a/datafusion/physical-plan/src/aggregates/topk/priority_map.rs +++ b/datafusion/physical-plan/src/aggregates/topk/priority_map.rs @@ -25,17 +25,12 @@ use datafusion_common::Result; /// A `Map` / `PriorityQueue` combo that evicts the worst values after reaching `capacity` pub struct PriorityMap { - map: Box, - heap: Box, + map: Box, + heap: Box, capacity: usize, mapper: Vec<(usize, usize)>, } -// JUSTIFICATION -// Benefit: ~15% speedup + required to index into RawTable from binary heap -// Soundness: it is only accessed by one thread at a time, and indexes are kept up to date -unsafe impl Send for PriorityMap {} - impl PriorityMap { pub fn new( key_type: DataType, diff --git a/datafusion/physical-plan/src/aggregates/topk_stream.rs b/datafusion/physical-plan/src/aggregates/topk_stream.rs index 9f25473cb9b4..075d8c5f2883 100644 --- a/datafusion/physical-plan/src/aggregates/topk_stream.rs +++ b/datafusion/physical-plan/src/aggregates/topk_stream.rs @@ -84,14 +84,14 @@ impl GroupedTopKAggregateStream { impl RecordBatchStream for GroupedTopKAggregateStream { fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } impl GroupedTopKAggregateStream { fn intern(&mut self, ids: ArrayRef, vals: ArrayRef) -> Result<()> { let len = ids.len(); - self.priority_map.set_batch(ids, vals.clone()); + self.priority_map.set_batch(ids, Arc::clone(&vals)); let has_nulls = vals.null_count() > 0; for row_idx in 0..len { @@ -139,14 +139,14 @@ impl Stream for GroupedTopKAggregateStream { 1, "Exactly 1 group value required" ); - let group_by_values = group_by_values[0][0].clone(); + let group_by_values = Arc::clone(&group_by_values[0][0]); let input_values = evaluate_many( &self.aggregate_arguments, batches.first().unwrap(), )?; assert_eq!(input_values.len(), 1, "Exactly 1 input required"); assert_eq!(input_values[0].len(), 1, "Exactly 1 input required"); - let input_values = input_values[0][0].clone(); + let input_values = Arc::clone(&input_values[0][0]); // iterate over each column of group_by values (*self).intern(group_by_values, input_values)?; @@ -158,7 +158,7 @@ impl Stream for GroupedTopKAggregateStream { return Poll::Ready(None); } let cols = self.priority_map.emit()?; - let batch = RecordBatch::try_new(self.schema.clone(), cols)?; + let batch = RecordBatch::try_new(Arc::clone(&self.schema), cols)?; trace!( "partition {} emit batch with {} rows", self.partition, diff --git a/datafusion/physical-plan/src/analyze.rs b/datafusion/physical-plan/src/analyze.rs index c420581c4323..c8b329fabdaa 100644 --- a/datafusion/physical-plan/src/analyze.rs +++ b/datafusion/physical-plan/src/analyze.rs @@ -40,9 +40,9 @@ use futures::StreamExt; /// discards the results, and then prints out an annotated plan with metrics #[derive(Debug, Clone)] pub struct AnalyzeExec { - /// control how much extra to print + /// Control how much extra to print verbose: bool, - /// if statistics should be displayed + /// If statistics should be displayed show_statistics: bool, /// The input plan (the plan being analyzed) pub(crate) input: Arc, @@ -59,7 +59,7 @@ impl AnalyzeExec { input: Arc, schema: SchemaRef, ) -> Self { - let cache = Self::compute_properties(&input, schema.clone()); + let cache = Self::compute_properties(&input, Arc::clone(&schema)); AnalyzeExec { verbose, show_statistics, @@ -69,12 +69,12 @@ impl AnalyzeExec { } } - /// access to verbose + /// Access to verbose pub fn verbose(&self) -> bool { self.verbose } - /// access to show_statistics + /// Access to show_statistics pub fn show_statistics(&self) -> bool { self.show_statistics } @@ -124,8 +124,8 @@ impl ExecutionPlan for AnalyzeExec { &self.cache } - fn children(&self) -> Vec> { - vec![self.input.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.input] } /// AnalyzeExec is handled specially so this value is ignored @@ -141,7 +141,7 @@ impl ExecutionPlan for AnalyzeExec { self.verbose, self.show_statistics, children.pop().unwrap(), - self.schema.clone(), + Arc::clone(&self.schema), ))) } @@ -164,13 +164,17 @@ impl ExecutionPlan for AnalyzeExec { RecordBatchReceiverStream::builder(self.schema(), num_input_partitions); for input_partition in 0..num_input_partitions { - builder.run_input(self.input.clone(), input_partition, context.clone()); + builder.run_input( + Arc::clone(&self.input), + input_partition, + Arc::clone(&context), + ); } // Create future that computes thefinal output let start = Instant::now(); - let captured_input = self.input.clone(); - let captured_schema = self.schema.clone(); + let captured_input = Arc::clone(&self.input); + let captured_schema = Arc::clone(&self.schema); let verbose = self.verbose; let show_statistics = self.show_statistics; @@ -196,13 +200,13 @@ impl ExecutionPlan for AnalyzeExec { }; Ok(Box::pin(RecordBatchStreamAdapter::new( - self.schema.clone(), + Arc::clone(&self.schema), futures::stream::once(output), ))) } } -/// Creates the ouput of AnalyzeExec as a RecordBatch +/// Creates the output of AnalyzeExec as a RecordBatch fn create_output_batch( verbose: bool, show_statistics: bool, diff --git a/datafusion/physical-plan/src/coalesce/mod.rs b/datafusion/physical-plan/src/coalesce/mod.rs new file mode 100644 index 000000000000..46875fae94fc --- /dev/null +++ b/datafusion/physical-plan/src/coalesce/mod.rs @@ -0,0 +1,600 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::compute::concat_batches; +use arrow_array::builder::StringViewBuilder; +use arrow_array::cast::AsArray; +use arrow_array::{Array, ArrayRef, RecordBatch, RecordBatchOptions}; +use arrow_schema::SchemaRef; +use std::sync::Arc; + +/// Concatenate multiple [`RecordBatch`]es +/// +/// `BatchCoalescer` concatenates multiple small [`RecordBatch`]es, produced by +/// operations such as `FilterExec` and `RepartitionExec`, into larger ones for +/// more efficient processing by subsequent operations. +/// +/// # Background +/// +/// Generally speaking, larger [`RecordBatch`]es are more efficient to process +/// than smaller record batches (until the CPU cache is exceeded) because there +/// is fixed processing overhead per batch. DataFusion tries to operate on +/// batches of `target_batch_size` rows to amortize this overhead +/// +/// ```text +/// ┌────────────────────┐ +/// │ RecordBatch │ +/// │ num_rows = 23 │ +/// └────────────────────┘ ┌────────────────────┐ +/// │ │ +/// ┌────────────────────┐ Coalesce │ │ +/// │ │ Batches │ │ +/// │ RecordBatch │ │ │ +/// │ num_rows = 50 │ ─ ─ ─ ─ ─ ─ ▶ │ │ +/// │ │ │ RecordBatch │ +/// │ │ │ num_rows = 106 │ +/// └────────────────────┘ │ │ +/// │ │ +/// ┌────────────────────┐ │ │ +/// │ │ │ │ +/// │ RecordBatch │ │ │ +/// │ num_rows = 33 │ └────────────────────┘ +/// │ │ +/// └────────────────────┘ +/// ``` +/// +/// # Notes: +/// +/// 1. Output rows are produced in the same order as the input rows +/// +/// 2. The output is a sequence of batches, with all but the last being at least +/// `target_batch_size` rows. +/// +/// 3. Eventually this may also be able to handle other optimizations such as a +/// combined filter/coalesce operation. +/// +#[derive(Debug)] +pub struct BatchCoalescer { + /// The input schema + schema: SchemaRef, + /// Minimum number of rows for coalesces batches + target_batch_size: usize, + /// Total number of rows returned so far + total_rows: usize, + /// Buffered batches + buffer: Vec, + /// Buffered row count + buffered_rows: usize, + /// Limit: maximum number of rows to fetch, `None` means fetch all rows + fetch: Option, +} + +impl BatchCoalescer { + /// Create a new `BatchCoalescer` + /// + /// # Arguments + /// - `schema` - the schema of the output batches + /// - `target_batch_size` - the minimum number of rows for each + /// output batch (until limit reached) + /// - `fetch` - the maximum number of rows to fetch, `None` means fetch all rows + pub fn new( + schema: SchemaRef, + target_batch_size: usize, + fetch: Option, + ) -> Self { + Self { + schema, + target_batch_size, + total_rows: 0, + buffer: vec![], + buffered_rows: 0, + fetch, + } + } + + /// Return the schema of the output batches + pub fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + /// Push next batch, and returns [`CoalescerState`] indicating the current + /// state of the buffer. + pub fn push_batch(&mut self, batch: RecordBatch) -> CoalescerState { + let batch = gc_string_view_batch(&batch); + if self.limit_reached(&batch) { + CoalescerState::LimitReached + } else if self.target_reached(batch) { + CoalescerState::TargetReached + } else { + CoalescerState::Continue + } + } + + /// Return true if the there is no data buffered + pub fn is_empty(&self) -> bool { + self.buffer.is_empty() + } + + /// Checks if the buffer will reach the specified limit after getting + /// `batch`. + /// + /// If fetch would be exceeded, slices the received batch, updates the + /// buffer with it, and returns `true`. + /// + /// Otherwise: does nothing and returns `false`. + fn limit_reached(&mut self, batch: &RecordBatch) -> bool { + match self.fetch { + Some(fetch) if self.total_rows + batch.num_rows() >= fetch => { + // Limit is reached + let remaining_rows = fetch - self.total_rows; + debug_assert!(remaining_rows > 0); + + let batch = batch.slice(0, remaining_rows); + self.buffered_rows += batch.num_rows(); + self.total_rows = fetch; + self.buffer.push(batch); + true + } + _ => false, + } + } + + /// Updates the buffer with the given batch. + /// + /// If the target batch size is reached, returns `true`. Otherwise, returns + /// `false`. + fn target_reached(&mut self, batch: RecordBatch) -> bool { + if batch.num_rows() == 0 { + false + } else { + self.total_rows += batch.num_rows(); + self.buffered_rows += batch.num_rows(); + self.buffer.push(batch); + self.buffered_rows >= self.target_batch_size + } + } + + /// Concatenates and returns all buffered batches, and clears the buffer. + pub fn finish_batch(&mut self) -> datafusion_common::Result { + let batch = concat_batches(&self.schema, &self.buffer)?; + self.buffer.clear(); + self.buffered_rows = 0; + Ok(batch) + } +} + +/// Indicates the state of the [`BatchCoalescer`] buffer after the +/// [`BatchCoalescer::push_batch()`] operation. +/// +/// The caller should take diferent actions, depending on the variant returned. +pub enum CoalescerState { + /// Neither the limit nor the target batch size is reached. + /// + /// Action: continue pushing batches. + Continue, + /// The limit has been reached. + /// + /// Action: call [`BatchCoalescer::finish_batch()`] to get the final + /// buffered results as a batch and finish the query. + LimitReached, + /// The specified minimum number of rows a batch should have is reached. + /// + /// Action: call [`BatchCoalescer::finish_batch()`] to get the current + /// buffered results as a batch and then continue pushing batches. + TargetReached, +} + +/// Heuristically compact `StringViewArray`s to reduce memory usage, if needed +/// +/// Decides when to consolidate the StringView into a new buffer to reduce +/// memory usage and improve string locality for better performance. +/// +/// This differs from `StringViewArray::gc` because: +/// 1. It may not compact the array depending on a heuristic. +/// 2. It uses a precise block size to reduce the number of buffers to track. +/// +/// # Heuristic +/// +/// If the average size of each view is larger than 32 bytes, we compact the array. +/// +/// `StringViewArray` include pointers to buffer that hold the underlying data. +/// One of the great benefits of `StringViewArray` is that many operations +/// (e.g., `filter`) can be done without copying the underlying data. +/// +/// However, after a while (e.g., after `FilterExec` or `HashJoinExec`) the +/// `StringViewArray` may only refer to a small portion of the buffer, +/// significantly increasing memory usage. +fn gc_string_view_batch(batch: &RecordBatch) -> RecordBatch { + let new_columns: Vec = batch + .columns() + .iter() + .map(|c| { + // Try to re-create the `StringViewArray` to prevent holding the underlying buffer too long. + let Some(s) = c.as_string_view_opt() else { + return Arc::clone(c); + }; + let ideal_buffer_size: usize = s + .views() + .iter() + .map(|v| { + let len = (*v as u32) as usize; + if len > 12 { + len + } else { + 0 + } + }) + .sum(); + let actual_buffer_size = s.get_buffer_memory_size(); + + // Re-creating the array copies data and can be time consuming. + // We only do it if the array is sparse + if actual_buffer_size > (ideal_buffer_size * 2) { + // We set the block size to `ideal_buffer_size` so that the new StringViewArray only has one buffer, which accelerate later concat_batches. + // See https://github.com/apache/arrow-rs/issues/6094 for more details. + let mut builder = StringViewBuilder::with_capacity(s.len()); + if ideal_buffer_size > 0 { + builder = builder.with_fixed_block_size(ideal_buffer_size as u32); + } + + for v in s.iter() { + builder.append_option(v); + } + + let gc_string = builder.finish(); + + debug_assert!(gc_string.data_buffers().len() <= 1); // buffer count can be 0 if the `ideal_buffer_size` is 0 + + Arc::new(gc_string) + } else { + Arc::clone(c) + } + }) + .collect(); + let mut options = RecordBatchOptions::new(); + options = options.with_row_count(Some(batch.num_rows())); + RecordBatch::try_new_with_options(batch.schema(), new_columns, &options) + .expect("Failed to re-create the gc'ed record batch") +} + +#[cfg(test)] +mod tests { + use std::ops::Range; + + use super::*; + + use arrow::datatypes::{DataType, Field, Schema}; + use arrow_array::builder::ArrayBuilder; + use arrow_array::{StringViewArray, UInt32Array}; + + #[test] + fn test_coalesce() { + let batch = uint32_batch(0..8); + Test::new() + .with_batches(std::iter::repeat(batch).take(10)) + // expected output is batches of at least 20 rows (except for the final batch) + .with_target_batch_size(21) + .with_expected_output_sizes(vec![24, 24, 24, 8]) + .run() + } + + #[test] + fn test_coalesce_with_fetch_larger_than_input_size() { + let batch = uint32_batch(0..8); + Test::new() + .with_batches(std::iter::repeat(batch).take(10)) + // input is 10 batches x 8 rows (80 rows) with fetch limit of 100 + // expected to behave the same as `test_concat_batches` + .with_target_batch_size(21) + .with_fetch(Some(100)) + .with_expected_output_sizes(vec![24, 24, 24, 8]) + .run(); + } + + #[test] + fn test_coalesce_with_fetch_less_than_input_size() { + let batch = uint32_batch(0..8); + Test::new() + .with_batches(std::iter::repeat(batch).take(10)) + // input is 10 batches x 8 rows (80 rows) with fetch limit of 50 + .with_target_batch_size(21) + .with_fetch(Some(50)) + .with_expected_output_sizes(vec![24, 24, 2]) + .run(); + } + + #[test] + fn test_coalesce_with_fetch_less_than_target_and_no_remaining_rows() { + let batch = uint32_batch(0..8); + Test::new() + .with_batches(std::iter::repeat(batch).take(10)) + // input is 10 batches x 8 rows (80 rows) with fetch limit of 48 + .with_target_batch_size(21) + .with_fetch(Some(48)) + .with_expected_output_sizes(vec![24, 24]) + .run(); + } + + #[test] + fn test_coalesce_with_fetch_less_target_batch_size() { + let batch = uint32_batch(0..8); + Test::new() + .with_batches(std::iter::repeat(batch).take(10)) + // input is 10 batches x 8 rows (80 rows) with fetch limit of 10 + .with_target_batch_size(21) + .with_fetch(Some(10)) + .with_expected_output_sizes(vec![10]) + .run(); + } + + #[test] + fn test_coalesce_single_large_batch_over_fetch() { + let large_batch = uint32_batch(0..100); + Test::new() + .with_batch(large_batch) + .with_target_batch_size(20) + .with_fetch(Some(7)) + .with_expected_output_sizes(vec![7]) + .run() + } + + /// Test for [`BatchCoalescer`] + /// + /// Pushes the input batches to the coalescer and verifies that the resulting + /// batches have the expected number of rows and contents. + #[derive(Debug, Clone, Default)] + struct Test { + /// Batches to feed to the coalescer. Tests must have at least one + /// schema + input_batches: Vec, + /// Expected output sizes of the resulting batches + expected_output_sizes: Vec, + /// target batch size + target_batch_size: usize, + /// Fetch (limit) + fetch: Option, + } + + impl Test { + fn new() -> Self { + Self::default() + } + + /// Set the target batch size + fn with_target_batch_size(mut self, target_batch_size: usize) -> Self { + self.target_batch_size = target_batch_size; + self + } + + /// Set the fetch (limit) + fn with_fetch(mut self, fetch: Option) -> Self { + self.fetch = fetch; + self + } + + /// Extend the input batches with `batch` + fn with_batch(mut self, batch: RecordBatch) -> Self { + self.input_batches.push(batch); + self + } + + /// Extends the input batches with `batches` + fn with_batches( + mut self, + batches: impl IntoIterator, + ) -> Self { + self.input_batches.extend(batches); + self + } + + /// Extends `sizes` to expected output sizes + fn with_expected_output_sizes( + mut self, + sizes: impl IntoIterator, + ) -> Self { + self.expected_output_sizes.extend(sizes); + self + } + + /// Runs the test -- see documentation on [`Test`] for details + fn run(self) { + let Self { + input_batches, + target_batch_size, + fetch, + expected_output_sizes, + } = self; + + let schema = input_batches[0].schema(); + + // create a single large input batch for output comparison + let single_input_batch = concat_batches(&schema, &input_batches).unwrap(); + + let mut coalescer = + BatchCoalescer::new(Arc::clone(&schema), target_batch_size, fetch); + + let mut output_batches = vec![]; + for batch in input_batches { + match coalescer.push_batch(batch) { + CoalescerState::Continue => {} + CoalescerState::LimitReached => { + output_batches.push(coalescer.finish_batch().unwrap()); + break; + } + CoalescerState::TargetReached => { + coalescer.buffered_rows = 0; + output_batches.push(coalescer.finish_batch().unwrap()); + } + } + } + if coalescer.buffered_rows != 0 { + output_batches.extend(coalescer.buffer); + } + + // make sure we got the expected number of output batches and content + let mut starting_idx = 0; + assert_eq!(expected_output_sizes.len(), output_batches.len()); + for (i, (expected_size, batch)) in + expected_output_sizes.iter().zip(output_batches).enumerate() + { + assert_eq!( + *expected_size, + batch.num_rows(), + "Unexpected number of rows in Batch {i}" + ); + + // compare the contents of the batch (using `==` compares the + // underlying memory layout too) + let expected_batch = + single_input_batch.slice(starting_idx, *expected_size); + let batch_strings = batch_to_pretty_strings(&batch); + let expected_batch_strings = batch_to_pretty_strings(&expected_batch); + let batch_strings = batch_strings.lines().collect::>(); + let expected_batch_strings = + expected_batch_strings.lines().collect::>(); + assert_eq!( + expected_batch_strings, batch_strings, + "Unexpected content in Batch {i}:\ + \n\nExpected:\n{expected_batch_strings:#?}\n\nActual:\n{batch_strings:#?}" + ); + starting_idx += *expected_size; + } + } + } + + /// Return a batch of UInt32 with the specified range + fn uint32_batch(range: Range) -> RecordBatch { + let schema = + Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)])); + + RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(UInt32Array::from_iter_values(range))], + ) + .unwrap() + } + + #[test] + fn test_gc_string_view_batch_small_no_compact() { + // view with only short strings (no buffers) --> no need to compact + let array = StringViewTest { + rows: 1000, + strings: vec![Some("a"), Some("b"), Some("c")], + } + .build(); + + let gc_array = do_gc(array.clone()); + compare_string_array_values(&array, &gc_array); + assert_eq!(array.data_buffers().len(), 0); + assert_eq!(array.data_buffers().len(), gc_array.data_buffers().len()); // no compaction + } + + #[test] + fn test_gc_string_view_test_batch_empty() { + let schema = Schema::empty(); + let batch = RecordBatch::new_empty(schema.into()); + let output_batch = gc_string_view_batch(&batch); + assert_eq!(batch.num_columns(), output_batch.num_columns()); + assert_eq!(batch.num_rows(), output_batch.num_rows()); + } + + #[test] + fn test_gc_string_view_batch_large_no_compact() { + // view with large strings (has buffers) but full --> no need to compact + let array = StringViewTest { + rows: 1000, + strings: vec![Some("This string is longer than 12 bytes")], + } + .build(); + + let gc_array = do_gc(array.clone()); + compare_string_array_values(&array, &gc_array); + assert_eq!(array.data_buffers().len(), 5); + assert_eq!(array.data_buffers().len(), gc_array.data_buffers().len()); // no compaction + } + + #[test] + fn test_gc_string_view_batch_large_slice_compact() { + // view with large strings (has buffers) and only partially used --> no need to compact + let array = StringViewTest { + rows: 1000, + strings: vec![Some("this string is longer than 12 bytes")], + } + .build(); + + // slice only 11 rows, so most of the buffer is not used + let array = array.slice(11, 22); + + let gc_array = do_gc(array.clone()); + compare_string_array_values(&array, &gc_array); + assert_eq!(array.data_buffers().len(), 5); + assert_eq!(gc_array.data_buffers().len(), 1); // compacted into a single buffer + } + + /// Compares the values of two string view arrays + fn compare_string_array_values(arr1: &StringViewArray, arr2: &StringViewArray) { + assert_eq!(arr1.len(), arr2.len()); + for (s1, s2) in arr1.iter().zip(arr2.iter()) { + assert_eq!(s1, s2); + } + } + + /// runs garbage collection on string view array + /// and ensures the number of rows are the same + fn do_gc(array: StringViewArray) -> StringViewArray { + let batch = + RecordBatch::try_from_iter(vec![("a", Arc::new(array) as ArrayRef)]).unwrap(); + let gc_batch = gc_string_view_batch(&batch); + assert_eq!(batch.num_rows(), gc_batch.num_rows()); + assert_eq!(batch.schema(), gc_batch.schema()); + gc_batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .clone() + } + + /// Describes parameters for creating a `StringViewArray` + struct StringViewTest { + /// The number of rows in the array + rows: usize, + /// The strings to use in the array (repeated over and over + strings: Vec>, + } + + impl StringViewTest { + /// Create a `StringViewArray` with the parameters specified in this struct + fn build(self) -> StringViewArray { + let mut builder = + StringViewBuilder::with_capacity(100).with_fixed_block_size(8192); + loop { + for &v in self.strings.iter() { + builder.append_option(v); + if builder.len() >= self.rows { + return builder.finish(); + } + } + } + } + } + fn batch_to_pretty_strings(batch: &RecordBatch) -> String { + arrow::util::pretty::pretty_format_batches(&[batch.clone()]) + .unwrap() + .to_string() + } +} diff --git a/datafusion/physical-plan/src/coalesce_batches.rs b/datafusion/physical-plan/src/coalesce_batches.rs index bc7c4a3d0673..11678e7a4696 100644 --- a/datafusion/physical-plan/src/coalesce_batches.rs +++ b/datafusion/physical-plan/src/coalesce_batches.rs @@ -15,8 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! CoalesceBatchesExec combines small batches into larger batches for more efficient use of -//! vectorized processing by upstream operators. +//! [`CoalesceBatchesExec`] combines small batches into larger batches. use std::any::Any; use std::pin::Pin; @@ -30,22 +29,33 @@ use crate::{ }; use arrow::datatypes::SchemaRef; -use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; use datafusion_common::Result; use datafusion_execution::TaskContext; +use crate::coalesce::{BatchCoalescer, CoalescerState}; +use crate::execution_plan::CardinalityEffect; +use futures::ready; use futures::stream::{Stream, StreamExt}; -use log::trace; -/// CoalesceBatchesExec combines small batches into larger batches for more efficient use of -/// vectorized processing by upstream operators. -#[derive(Debug)] +/// `CoalesceBatchesExec` combines small batches into larger batches for more +/// efficient vectorized processing by later operators. +/// +/// The operator buffers batches until it collects `target_batch_size` rows and +/// then emits a single concatenated batch. When only a limited number of rows +/// are necessary (specified by the `fetch` parameter), the operator will stop +/// buffering and returns the final batch once the number of collected rows +/// reaches the `fetch` value. +/// +/// See [`BatchCoalescer`] for more information +#[derive(Debug, Clone)] pub struct CoalesceBatchesExec { /// The input plan input: Arc, - /// Minimum number of rows for coalesces batches + /// Minimum number of rows for coalescing batches target_batch_size: usize, + /// Maximum number of rows to fetch, `None` means fetching all rows + fetch: Option, /// Execution metrics metrics: ExecutionPlanMetricsSet, cache: PlanProperties, @@ -58,11 +68,18 @@ impl CoalesceBatchesExec { Self { input, target_batch_size, + fetch: None, metrics: ExecutionPlanMetricsSet::new(), cache, } } + /// Update fetch with the argument + pub fn with_fetch(mut self, fetch: Option) -> Self { + self.fetch = fetch; + self + } + /// The input plan pub fn input(&self) -> &Arc { &self.input @@ -96,8 +113,13 @@ impl DisplayAs for CoalesceBatchesExec { write!( f, "CoalesceBatchesExec: target_batch_size={}", - self.target_batch_size - ) + self.target_batch_size, + )?; + if let Some(fetch) = self.fetch { + write!(f, ", fetch={fetch}")?; + }; + + Ok(()) } } } @@ -117,8 +139,8 @@ impl ExecutionPlan for CoalesceBatchesExec { &self.cache } - fn children(&self) -> Vec> { - vec![self.input.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.input] } fn maintains_input_order(&self) -> Vec { @@ -133,10 +155,10 @@ impl ExecutionPlan for CoalesceBatchesExec { self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(CoalesceBatchesExec::new( - children[0].clone(), - self.target_batch_size, - ))) + Ok(Arc::new( + CoalesceBatchesExec::new(Arc::clone(&children[0]), self.target_batch_size) + .with_fetch(self.fetch), + )) } fn execute( @@ -146,12 +168,14 @@ impl ExecutionPlan for CoalesceBatchesExec { ) -> Result { Ok(Box::pin(CoalesceBatchesStream { input: self.input.execute(partition, context)?, - schema: self.input.schema(), - target_batch_size: self.target_batch_size, - buffer: Vec::new(), - buffered_rows: 0, - is_closed: false, + coalescer: BatchCoalescer::new( + self.input.schema(), + self.target_batch_size, + self.fetch, + ), baseline_metrics: BaselineMetrics::new(&self.metrics, partition), + // Start by pulling data + inner_state: CoalesceBatchesStreamState::Pull, })) } @@ -160,25 +184,39 @@ impl ExecutionPlan for CoalesceBatchesExec { } fn statistics(&self) -> Result { - self.input.statistics() + Statistics::with_fetch(self.input.statistics()?, self.schema(), self.fetch, 0, 1) + } + + fn with_fetch(&self, limit: Option) -> Option> { + Some(Arc::new(CoalesceBatchesExec { + input: Arc::clone(&self.input), + target_batch_size: self.target_batch_size, + fetch: limit, + metrics: self.metrics.clone(), + cache: self.cache.clone(), + })) + } + + fn fetch(&self) -> Option { + self.fetch + } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::Equal } } +/// Stream for [`CoalesceBatchesExec`]. See [`CoalesceBatchesExec`] for more details. struct CoalesceBatchesStream { /// The input plan input: SendableRecordBatchStream, - /// The input schema - schema: SchemaRef, - /// Minimum number of rows for coalesces batches - target_batch_size: usize, - /// Buffered batches - buffer: Vec, - /// Buffered row count - buffered_rows: usize, - /// Whether the stream has finished returning all of its data or not - is_closed: bool, + /// Buffer for combining batches + coalescer: BatchCoalescer, /// Execution metrics baseline_metrics: BaselineMetrics, + /// The current inner state of the stream. This state dictates the current + /// action or operation to be performed in the streaming process. + inner_state: CoalesceBatchesStreamState, } impl Stream for CoalesceBatchesStream { @@ -198,73 +236,100 @@ impl Stream for CoalesceBatchesStream { } } +/// Enumeration of possible states for `CoalesceBatchesStream`. +/// It represents different stages in the lifecycle of a stream of record batches. +/// +/// An example of state transition: +/// Notation: +/// `[3000]`: A batch with size 3000 +/// `{[2000], [3000]}`: `CoalesceBatchStream`'s internal buffer with 2 batches buffered +/// Input of `CoalesceBatchStream` will generate three batches `[2000], [3000], [4000]` +/// The coalescing procedure will go through the following steps with 4096 coalescing threshold: +/// 1. Read the first batch and get it buffered. +/// - initial state: `Pull` +/// - initial buffer: `{}` +/// - updated buffer: `{[2000]}` +/// - next state: `Pull` +/// 2. Read the second batch, the coalescing target is reached since 2000 + 3000 > 4096 +/// - initial state: `Pull` +/// - initial buffer: `{[2000]}` +/// - updated buffer: `{[2000], [3000]}` +/// - next state: `ReturnBuffer` +/// 4. Two batches in the batch get merged and consumed by the upstream operator. +/// - initial state: `ReturnBuffer` +/// - initial buffer: `{[2000], [3000]}` +/// - updated buffer: `{}` +/// - next state: `Pull` +/// 5. Read the third input batch. +/// - initial state: `Pull` +/// - initial buffer: `{}` +/// - updated buffer: `{[4000]}` +/// - next state: `Pull` +/// 5. The input is ended now. Jump to exhaustion state preparing the finalized data. +/// - initial state: `Pull` +/// - initial buffer: `{[4000]}` +/// - updated buffer: `{[4000]}` +/// - next state: `Exhausted` +#[derive(Debug, Clone, Eq, PartialEq)] +enum CoalesceBatchesStreamState { + /// State to pull a new batch from the input stream. + Pull, + /// State to return a buffered batch. + ReturnBuffer, + /// State indicating that the stream is exhausted. + Exhausted, +} + impl CoalesceBatchesStream { fn poll_next_inner( self: &mut Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { - // Get a clone (uses same underlying atomic) as self gets borrowed below let cloned_time = self.baseline_metrics.elapsed_compute().clone(); - - if self.is_closed { - return Poll::Ready(None); - } loop { - let input_batch = self.input.poll_next_unpin(cx); - // records time on drop - let _timer = cloned_time.timer(); - match input_batch { - Poll::Ready(x) => match x { - Some(Ok(batch)) => { - if batch.num_rows() >= self.target_batch_size - && self.buffer.is_empty() - { - return Poll::Ready(Some(Ok(batch))); - } else if batch.num_rows() == 0 { - // discard empty batches - } else { - // add to the buffered batches - self.buffered_rows += batch.num_rows(); - self.buffer.push(batch); - // check to see if we have enough batches yet - if self.buffered_rows >= self.target_batch_size { - // combine the batches and return - let batch = concat_batches( - &self.schema, - &self.buffer, - self.buffered_rows, - )?; - // reset buffer state - self.buffer.clear(); - self.buffered_rows = 0; - // return batch - return Poll::Ready(Some(Ok(batch))); + match &self.inner_state { + CoalesceBatchesStreamState::Pull => { + // Attempt to pull the next batch from the input stream. + let input_batch = ready!(self.input.poll_next_unpin(cx)); + // Start timing the operation. The timer records time upon being dropped. + let _timer = cloned_time.timer(); + + match input_batch { + Some(Ok(batch)) => match self.coalescer.push_batch(batch) { + CoalescerState::Continue => {} + CoalescerState::LimitReached => { + self.inner_state = CoalesceBatchesStreamState::Exhausted; } + CoalescerState::TargetReached => { + self.inner_state = + CoalesceBatchesStreamState::ReturnBuffer; + } + }, + None => { + // End of input stream, but buffered batches might still be present. + self.inner_state = CoalesceBatchesStreamState::Exhausted; } + other => return Poll::Ready(other), } - None => { - self.is_closed = true; - // we have reached the end of the input stream but there could still - // be buffered batches - if self.buffer.is_empty() { - return Poll::Ready(None); - } else { - // combine the batches and return - let batch = concat_batches( - &self.schema, - &self.buffer, - self.buffered_rows, - )?; - // reset buffer state - self.buffer.clear(); - self.buffered_rows = 0; - // return batch - return Poll::Ready(Some(Ok(batch))); - } - } - other => return Poll::Ready(other), - }, - Poll::Pending => return Poll::Pending, + } + CoalesceBatchesStreamState::ReturnBuffer => { + // Combine buffered batches into one batch and return it. + let batch = self.coalescer.finish_batch()?; + // Set to pull state for the next iteration. + self.inner_state = CoalesceBatchesStreamState::Pull; + return Poll::Ready(Some(Ok(batch))); + } + CoalesceBatchesStreamState::Exhausted => { + // Handle the end of the input stream. + return if self.coalescer.is_empty() { + // If buffer is empty, return None indicating the stream is fully consumed. + Poll::Ready(None) + } else { + // If the buffer still contains batches, prepare to return them. + let batch = self.coalescer.finish_batch()?; + Poll::Ready(Some(Ok(batch))) + }; + } } } } @@ -272,101 +337,6 @@ impl CoalesceBatchesStream { impl RecordBatchStream for CoalesceBatchesStream { fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} - -/// Concatenates an array of `RecordBatch` into one batch -pub fn concat_batches( - schema: &SchemaRef, - batches: &[RecordBatch], - row_count: usize, -) -> ArrowResult { - trace!( - "Combined {} batches containing {} rows", - batches.len(), - row_count - ); - arrow::compute::concat_batches(schema, batches) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{memory::MemoryExec, repartition::RepartitionExec, Partitioning}; - - use arrow::datatypes::{DataType, Field, Schema}; - use arrow_array::UInt32Array; - - #[tokio::test(flavor = "multi_thread")] - async fn test_concat_batches() -> Result<()> { - let schema = test_schema(); - let partition = create_vec_batches(&schema, 10); - let partitions = vec![partition]; - - let output_partitions = coalesce_batches(&schema, partitions, 21).await?; - assert_eq!(1, output_partitions.len()); - - // input is 10 batches x 8 rows (80 rows) - // expected output is batches of at least 20 rows (except for the final batch) - let batches = &output_partitions[0]; - assert_eq!(4, batches.len()); - assert_eq!(24, batches[0].num_rows()); - assert_eq!(24, batches[1].num_rows()); - assert_eq!(24, batches[2].num_rows()); - assert_eq!(8, batches[3].num_rows()); - - Ok(()) - } - - fn test_schema() -> Arc { - Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)])) - } - - async fn coalesce_batches( - schema: &SchemaRef, - input_partitions: Vec>, - target_batch_size: usize, - ) -> Result>> { - // create physical plan - let exec = MemoryExec::try_new(&input_partitions, schema.clone(), None)?; - let exec = - RepartitionExec::try_new(Arc::new(exec), Partitioning::RoundRobinBatch(1))?; - let exec: Arc = - Arc::new(CoalesceBatchesExec::new(Arc::new(exec), target_batch_size)); - - // execute and collect results - let output_partition_count = exec.output_partitioning().partition_count(); - let mut output_partitions = Vec::with_capacity(output_partition_count); - for i in 0..output_partition_count { - // execute this *output* partition and collect all batches - let task_ctx = Arc::new(TaskContext::default()); - let mut stream = exec.execute(i, task_ctx.clone())?; - let mut batches = vec![]; - while let Some(result) = stream.next().await { - batches.push(result?); - } - output_partitions.push(batches); - } - Ok(output_partitions) - } - - /// Create vector batches - fn create_vec_batches(schema: &Schema, n: usize) -> Vec { - let batch = create_batch(schema); - let mut vec = Vec::with_capacity(n); - for _ in 0..n { - vec.push(batch.clone()); - } - vec - } - - /// Create batch - fn create_batch(schema: &Schema) -> RecordBatch { - RecordBatch::try_new( - Arc::new(schema.clone()), - vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))], - ) - .unwrap() + self.coalescer.schema() } } diff --git a/datafusion/physical-plan/src/coalesce_partitions.rs b/datafusion/physical-plan/src/coalesce_partitions.rs index 1c725ce31f14..3da101d6092f 100644 --- a/datafusion/physical-plan/src/coalesce_partitions.rs +++ b/datafusion/physical-plan/src/coalesce_partitions.rs @@ -30,12 +30,13 @@ use super::{ use crate::{DisplayFormatType, ExecutionPlan, Partitioning}; +use crate::execution_plan::CardinalityEffect; use datafusion_common::{internal_err, Result}; use datafusion_execution::TaskContext; /// Merge execution plan executes partitions in parallel and combines them into a single /// partition. No guarantees are made about the order of the resulting partition. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct CoalescePartitionsExec { /// Input execution plan input: Arc, @@ -65,7 +66,7 @@ impl CoalescePartitionsExec { // Coalescing partitions loses existing orderings: let mut eq_properties = input.equivalence_properties().clone(); eq_properties.clear_orderings(); - + eq_properties.clear_per_partition_constants(); PlanProperties::new( eq_properties, // Equivalence Properties Partitioning::UnknownPartitioning(1), // Output Partitioning @@ -102,8 +103,8 @@ impl ExecutionPlan for CoalescePartitionsExec { &self.cache } - fn children(&self) -> Vec> { - vec![self.input.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.input] } fn benefits_from_input_partitioning(&self) -> Vec { @@ -114,7 +115,9 @@ impl ExecutionPlan for CoalescePartitionsExec { self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new(CoalescePartitionsExec::new(children[0].clone()))) + Ok(Arc::new(CoalescePartitionsExec::new(Arc::clone( + &children[0], + )))) } fn execute( @@ -152,7 +155,11 @@ impl ExecutionPlan for CoalescePartitionsExec { // spawn independent tasks whose resulting streams (of batches) // are sent to the channel for consumption. for part_i in 0..input_partitions { - builder.run_input(self.input.clone(), part_i, context.clone()); + builder.run_input( + Arc::clone(&self.input), + part_i, + Arc::clone(&context), + ); } let stream = builder.build(); @@ -168,6 +175,14 @@ impl ExecutionPlan for CoalescePartitionsExec { fn statistics(&self) -> Result { self.input.statistics() } + + fn supports_limit_pushdown(&self) -> bool { + true + } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::Equal + } } #[cfg(test)] @@ -221,10 +236,10 @@ mod tests { let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2)); let refs = blocking_exec.refs(); - let coaelesce_partitions_exec = + let coalesce_partitions_exec = Arc::new(CoalescePartitionsExec::new(blocking_exec)); - let fut = collect(coaelesce_partitions_exec, task_ctx); + let fut = collect(coalesce_partitions_exec, task_ctx); let mut fut = fut.boxed(); assert_is_pending(&mut fut); diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index f7cad9df4ba1..844208999d25 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -22,9 +22,9 @@ use std::fs::{metadata, File}; use std::path::{Path, PathBuf}; use std::sync::Arc; -use super::{ExecutionPlanProperties, SendableRecordBatchStream}; +use super::SendableRecordBatchStream; use crate::stream::RecordBatchReceiverStream; -use crate::{ColumnStatistics, ExecutionPlan, Statistics}; +use crate::{ColumnStatistics, Statistics}; use arrow::datatypes::Schema; use arrow::ipc::writer::{FileWriter, IpcWriteOptions}; @@ -33,8 +33,6 @@ use arrow_array::Array; use datafusion_common::stats::Precision; use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_execution::memory_pool::MemoryReservation; -use datafusion_physical_expr::expressions::{BinaryExpr, Column}; -use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; use futures::{StreamExt, TryStreamExt}; use parking_lot::Mutex; @@ -111,7 +109,7 @@ pub(crate) fn spawn_buffered( builder.spawn(async move { while let Some(item) = input.next().await { if sender.send(item).await.is_err() { - // receiver dropped when query is shutdown early (e.g., limit) or error, + // Receiver dropped when query is shutdown early (e.g., limit) or error, // no need to return propagate the send error. return Ok(()); } @@ -153,16 +151,27 @@ pub fn compute_record_batch_statistics( }) .sum(); - let mut column_statistics = vec![ColumnStatistics::new_unknown(); projection.len()]; + let mut null_counts = vec![0; projection.len()]; for partition in batches.iter() { for batch in partition { for (stat_index, col_index) in projection.iter().enumerate() { - column_statistics[stat_index].null_count = - Precision::Exact(batch.column(*col_index).null_count()); + null_counts[stat_index] += batch + .column(*col_index) + .logical_nulls() + .map(|nulls| nulls.null_count()) + .unwrap_or_default(); } } } + let column_statistics = null_counts + .into_iter() + .map(|null_count| { + let mut s = ColumnStatistics::new_unknown(); + s.null_count = Precision::Exact(null_count); + s + }) + .collect(); Statistics { num_rows: Precision::Exact(nb_rows), @@ -171,99 +180,18 @@ pub fn compute_record_batch_statistics( } } -/// Transposes the given vector of vectors. -pub fn transpose(original: Vec>) -> Vec> { - match original.as_slice() { - [] => vec![], - [first, ..] => { - let mut result = (0..first.len()).map(|_| vec![]).collect::>(); - for row in original { - for (item, transposed_row) in row.into_iter().zip(&mut result) { - transposed_row.push(item); - } - } - result - } - } -} - -/// Calculates the "meet" of given orderings. -/// The meet is the finest ordering that satisfied by all the given -/// orderings, see . -pub fn get_meet_of_orderings( - given: &[Arc], -) -> Option<&[PhysicalSortExpr]> { - given - .iter() - .map(|item| item.output_ordering()) - .collect::>>() - .and_then(get_meet_of_orderings_helper) -} - -fn get_meet_of_orderings_helper( - orderings: Vec<&[PhysicalSortExpr]>, -) -> Option<&[PhysicalSortExpr]> { - let mut idx = 0; - let first = orderings[0]; - loop { - for ordering in orderings.iter() { - if idx >= ordering.len() { - return Some(ordering); - } else { - let schema_aligned = check_expr_alignment( - ordering[idx].expr.as_ref(), - first[idx].expr.as_ref(), - ); - if !schema_aligned || (ordering[idx].options != first[idx].options) { - // In a union, the output schema is that of the first child (by convention). - // Therefore, generate the result from the first child's schema: - return if idx > 0 { Some(&first[..idx]) } else { None }; - } - } - } - idx += 1; - } - - fn check_expr_alignment(first: &dyn PhysicalExpr, second: &dyn PhysicalExpr) -> bool { - match ( - first.as_any().downcast_ref::(), - second.as_any().downcast_ref::(), - first.as_any().downcast_ref::(), - second.as_any().downcast_ref::(), - ) { - (Some(first_col), Some(second_col), _, _) => { - first_col.index() == second_col.index() - } - (_, _, Some(first_binary), Some(second_binary)) => { - if first_binary.op() == second_binary.op() { - check_expr_alignment( - first_binary.left().as_ref(), - second_binary.left().as_ref(), - ) && check_expr_alignment( - first_binary.right().as_ref(), - second_binary.right().as_ref(), - ) - } else { - false - } - } - (_, _, _, _) => false, - } - } -} - /// Write in Arrow IPC format. pub struct IPCWriter { - /// path + /// Path pub path: PathBuf, - /// inner writer + /// Inner writer pub writer: FileWriter, - /// batches written - pub num_batches: u64, - /// rows written - pub num_rows: u64, - /// bytes written - pub num_bytes: u64, + /// Batches written + pub num_batches: usize, + /// Rows written + pub num_rows: usize, + /// Bytes written + pub num_bytes: usize, } impl IPCWriter { @@ -306,9 +234,9 @@ impl IPCWriter { pub fn write(&mut self, batch: &RecordBatch) -> Result<()> { self.writer.write(batch)?; self.num_batches += 1; - self.num_rows += batch.num_rows() as u64; + self.num_rows += batch.num_rows(); let num_bytes: usize = batch.get_array_memory_size(); - self.num_bytes += num_bytes as u64; + self.num_bytes += num_bytes; Ok(()) } @@ -351,296 +279,12 @@ pub fn can_project( #[cfg(test)] mod tests { - use std::ops::Not; - use super::*; - use crate::memory::MemoryExec; - use crate::sorts::sort::SortExec; - use crate::union::UnionExec; - use arrow::compute::SortOptions; use arrow::{ array::{Float32Array, Float64Array, UInt64Array}, datatypes::{DataType, Field}, }; - use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::col; - - #[test] - fn get_meet_of_orderings_helper_common_prefix_test() -> Result<()> { - let input1: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options: SortOptions::default(), - }, - ]; - - let input2: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("x", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("y", 1)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("z", 2)), - options: SortOptions::default(), - }, - ]; - - let input3: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("d", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("e", 1)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("f", 2)), - options: SortOptions::default(), - }, - ]; - - let input4: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("g", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("h", 1)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - // Note that index of this column is not 2. Hence this 3rd entry shouldn't be - // in the output ordering. - expr: Arc::new(Column::new("i", 3)), - options: SortOptions::default(), - }, - ]; - - let expected = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options: SortOptions::default(), - }, - ]; - let result = get_meet_of_orderings_helper(vec![&input1, &input2, &input3]); - assert_eq!(result.unwrap(), expected); - - let expected = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }, - ]; - let result = get_meet_of_orderings_helper(vec![&input1, &input2, &input4]); - assert_eq!(result.unwrap(), expected); - Ok(()) - } - - #[test] - fn get_meet_of_orderings_helper_subset_test() -> Result<()> { - let input1: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }, - ]; - - let input2: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("d", 1)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("e", 2)), - options: SortOptions::default(), - }, - ]; - - let input3: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("f", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("g", 1)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("h", 2)), - options: SortOptions::default(), - }, - ]; - - let result = get_meet_of_orderings_helper(vec![&input1, &input2, &input3]); - assert_eq!(result.unwrap(), input1); - Ok(()) - } - - #[test] - fn get_meet_of_orderings_helper_no_overlap_test() -> Result<()> { - let input1: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 0)), - // Since ordering is conflicting with other inputs - // output ordering should be empty - options: SortOptions::default().not(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("b", 1)), - options: SortOptions::default(), - }, - ]; - - let input2: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("x", 0)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 1)), - options: SortOptions::default(), - }, - ]; - - let input3: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(Column::new("a", 2)), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("y", 1)), - options: SortOptions::default(), - }, - ]; - - let result = get_meet_of_orderings_helper(vec![&input1, &input2]); - assert!(result.is_none()); - - let result = get_meet_of_orderings_helper(vec![&input2, &input3]); - assert!(result.is_none()); - - let result = get_meet_of_orderings_helper(vec![&input1, &input3]); - assert!(result.is_none()); - Ok(()) - } - - #[test] - fn get_meet_of_orderings_helper_binary_exprs() -> Result<()> { - let input1: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(BinaryExpr::new( - Arc::new(Column::new("a", 0)), - Operator::Plus, - Arc::new(Column::new("b", 1)), - )), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options: SortOptions::default(), - }, - ]; - - let input2: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(BinaryExpr::new( - Arc::new(Column::new("x", 0)), - Operator::Plus, - Arc::new(Column::new("y", 1)), - )), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("z", 2)), - options: SortOptions::default(), - }, - ]; - - // erroneous input - let input3: Vec = vec![ - PhysicalSortExpr { - expr: Arc::new(BinaryExpr::new( - Arc::new(Column::new("a", 1)), - Operator::Plus, - Arc::new(Column::new("b", 0)), - )), - options: SortOptions::default(), - }, - PhysicalSortExpr { - expr: Arc::new(Column::new("c", 2)), - options: SortOptions::default(), - }, - ]; - - let result = get_meet_of_orderings_helper(vec![&input1, &input2]); - assert_eq!(input1, result.unwrap()); - - let result = get_meet_of_orderings_helper(vec![&input2, &input3]); - assert!(result.is_none()); - - let result = get_meet_of_orderings_helper(vec![&input1, &input3]); - assert!(result.is_none()); - Ok(()) - } - - #[test] - fn test_meet_of_orderings() -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("f32", DataType::Float32, false), - Field::new("f64", DataType::Float64, false), - ])); - let sort_expr = vec![PhysicalSortExpr { - expr: col("f32", &schema).unwrap(), - options: SortOptions::default(), - }]; - let memory_exec = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?) as _; - let sort_exec = Arc::new(SortExec::new(sort_expr.clone(), memory_exec)) - as Arc; - let memory_exec2 = Arc::new(MemoryExec::try_new(&[], schema, None)?) as _; - // memory_exec2 doesn't have output ordering - let union_exec = UnionExec::new(vec![sort_exec.clone(), memory_exec2]); - let res = get_meet_of_orderings(union_exec.inputs()); - assert!(res.is_none()); - - let union_exec = UnionExec::new(vec![sort_exec.clone(), sort_exec]); - let res = get_meet_of_orderings(union_exec.inputs()); - assert_eq!(res, Some(&sort_expr[..])); - Ok(()) - } #[test] fn test_compute_record_batch_statistics_empty() -> Result<()> { @@ -671,7 +315,7 @@ mod tests { ], )?; - // just select f32,f64 + // Just select f32,f64 let select_projection = Some(vec![0, 1]); let byte_size = batch .project(&select_projection.clone().unwrap()) @@ -705,11 +349,33 @@ mod tests { } #[test] - fn test_transpose() -> Result<()> { - let in_data = vec![vec![1, 2, 3], vec![4, 5, 6]]; - let transposed = transpose(in_data); - let expected = vec![vec![1, 4], vec![2, 5], vec![3, 6]]; - assert_eq!(expected, transposed); + fn test_compute_record_batch_statistics_null() -> Result<()> { + let schema = + Arc::new(Schema::new(vec![Field::new("u64", DataType::UInt64, true)])); + let batch1 = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(UInt64Array::from(vec![Some(1), None, None]))], + )?; + let batch2 = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(UInt64Array::from(vec![Some(1), Some(2), None]))], + )?; + let byte_size = batch1.get_array_memory_size() + batch2.get_array_memory_size(); + let actual = + compute_record_batch_statistics(&[vec![batch1], vec![batch2]], &schema, None); + + let expected = Statistics { + num_rows: Precision::Exact(6), + total_byte_size: Precision::Exact(byte_size), + column_statistics: vec![ColumnStatistics { + distinct_count: Precision::Absent, + max_value: Precision::Absent, + min_value: Precision::Absent, + null_count: Precision::Exact(3), + }], + }; + + assert_eq!(actual, expected); Ok(()) } } diff --git a/datafusion/physical-plan/src/display.rs b/datafusion/physical-plan/src/display.rs index ca93ce5e7b83..9f3a76e28577 100644 --- a/datafusion/physical-plan/src/display.rs +++ b/datafusion/physical-plan/src/display.rs @@ -21,11 +21,13 @@ use std::fmt; use std::fmt::Formatter; -use super::{accept, ExecutionPlan, ExecutionPlanVisitor}; - use arrow_schema::SchemaRef; + use datafusion_common::display::{GraphvizBuilder, PlanType, StringifiedPlan}; -use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_expr::display_schema; +use datafusion_physical_expr::LexOrdering; + +use super::{accept, ExecutionPlan, ExecutionPlanVisitor}; /// Options for controlling how each [`ExecutionPlan`] should format itself #[derive(Debug, Clone, Copy)] @@ -37,12 +39,15 @@ pub enum DisplayFormatType { } /// Wraps an `ExecutionPlan` with various ways to display this plan +#[derive(Debug, Clone)] pub struct DisplayableExecutionPlan<'a> { inner: &'a dyn ExecutionPlan, /// How to show metrics show_metrics: ShowMetrics, /// If statistics should be displayed show_statistics: bool, + /// If schema should be displayed. See [`Self::set_show_schema`] + show_schema: bool, } impl<'a> DisplayableExecutionPlan<'a> { @@ -53,6 +58,7 @@ impl<'a> DisplayableExecutionPlan<'a> { inner, show_metrics: ShowMetrics::None, show_statistics: false, + show_schema: false, } } @@ -64,6 +70,7 @@ impl<'a> DisplayableExecutionPlan<'a> { inner, show_metrics: ShowMetrics::Aggregated, show_statistics: false, + show_schema: false, } } @@ -75,9 +82,19 @@ impl<'a> DisplayableExecutionPlan<'a> { inner, show_metrics: ShowMetrics::Full, show_statistics: false, + show_schema: false, } } + /// Enable display of schema + /// + /// If true, plans will be displayed with schema information at the end + /// of each line. The format is `schema=[[a:Int32;N, b:Int32;N, c:Int32;N]]` + pub fn set_show_schema(mut self, show_schema: bool) -> Self { + self.show_schema = show_schema; + self + } + /// Enable display of statistics pub fn set_show_statistics(mut self, show_statistics: bool) -> Self { self.show_statistics = show_statistics; @@ -105,15 +122,17 @@ impl<'a> DisplayableExecutionPlan<'a> { plan: &'a dyn ExecutionPlan, show_metrics: ShowMetrics, show_statistics: bool, + show_schema: bool, } impl<'a> fmt::Display for Wrapper<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { let mut visitor = IndentVisitor { t: self.format_type, f, indent: 0, show_metrics: self.show_metrics, show_statistics: self.show_statistics, + show_schema: self.show_schema, }; accept(self.plan, &mut visitor) } @@ -123,6 +142,7 @@ impl<'a> DisplayableExecutionPlan<'a> { plan: self.inner, show_metrics: self.show_metrics, show_statistics: self.show_statistics, + show_schema: self.show_schema, } } @@ -144,7 +164,7 @@ impl<'a> DisplayableExecutionPlan<'a> { show_statistics: bool, } impl<'a> fmt::Display for Wrapper<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { let t = DisplayFormatType::Default; let mut visitor = GraphvizVisitor { @@ -179,16 +199,18 @@ impl<'a> DisplayableExecutionPlan<'a> { plan: &'a dyn ExecutionPlan, show_metrics: ShowMetrics, show_statistics: bool, + show_schema: bool, } impl<'a> fmt::Display for Wrapper<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { let mut visitor = IndentVisitor { f, t: DisplayFormatType::Default, indent: 0, show_metrics: self.show_metrics, show_statistics: self.show_statistics, + show_schema: self.show_schema, }; visitor.pre_visit(self.plan)?; Ok(()) @@ -199,6 +221,7 @@ impl<'a> DisplayableExecutionPlan<'a> { plan: self.inner, show_metrics: self.show_metrics, show_statistics: self.show_statistics, + show_schema: self.show_schema, } } @@ -208,12 +231,13 @@ impl<'a> DisplayableExecutionPlan<'a> { } } +/// Enum representing the different levels of metrics to display #[derive(Debug, Clone, Copy)] enum ShowMetrics { /// Do not show any metrics None, - /// Show aggregrated metrics across partition + /// Show aggregated metrics across partition Aggregated, /// Show full per-partition metrics @@ -221,17 +245,27 @@ enum ShowMetrics { } /// Formats plans with a single line per node. +/// +/// # Example +/// +/// ```text +/// ProjectionExec: expr=[column1@0 + 2 as column1 + Int64(2)] +/// FilterExec: column1@0 = 5 +/// ValuesExec +/// ``` struct IndentVisitor<'a, 'b> { /// How to format each node t: DisplayFormatType, /// Write to this formatter - f: &'a mut fmt::Formatter<'b>, + f: &'a mut Formatter<'b>, /// Indent size indent: usize, /// How to show metrics show_metrics: ShowMetrics, /// If statistics should be displayed show_statistics: bool, + /// If schema should be displayed + show_schema: bool, } impl<'a, 'b> ExecutionPlanVisitor for IndentVisitor<'a, 'b> { @@ -265,6 +299,13 @@ impl<'a, 'b> ExecutionPlanVisitor for IndentVisitor<'a, 'b> { let stats = plan.statistics().map_err(|_e| fmt::Error)?; write!(self.f, ", statistics=[{}]", stats)?; } + if self.show_schema { + write!( + self.f, + ", schema={}", + display_schema(plan.schema().as_ref()) + )?; + } writeln!(self.f)?; self.indent += 1; Ok(true) @@ -277,7 +318,7 @@ impl<'a, 'b> ExecutionPlanVisitor for IndentVisitor<'a, 'b> { } struct GraphvizVisitor<'a, 'b> { - f: &'a mut fmt::Formatter<'b>, + f: &'a mut Formatter<'b>, /// How to format each node t: DisplayFormatType, /// How to show metrics @@ -308,8 +349,8 @@ impl ExecutionPlanVisitor for GraphvizVisitor<'_, '_> { struct Wrapper<'a>(&'a dyn ExecutionPlan, DisplayFormatType); - impl<'a> std::fmt::Display for Wrapper<'a> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + impl<'a> fmt::Display for Wrapper<'a> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { self.0.fmt_as(self.1, f) } } @@ -381,14 +422,14 @@ pub trait DisplayAs { /// different from the default one /// /// Should not include a newline - fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result; + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> fmt::Result; } /// A newtype wrapper to display `T` implementing`DisplayAs` using the `Default` mode pub struct DefaultDisplay(pub T); impl fmt::Display for DefaultDisplay { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { self.0.fmt_as(DisplayFormatType::Default, f) } } @@ -397,7 +438,7 @@ impl fmt::Display for DefaultDisplay { pub struct VerboseDisplay(pub T); impl fmt::Display for VerboseDisplay { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { self.0.fmt_as(DisplayFormatType::Verbose, f) } } @@ -407,7 +448,7 @@ impl fmt::Display for VerboseDisplay { pub struct ProjectSchemaDisplay<'a>(pub &'a SchemaRef); impl<'a> fmt::Display for ProjectSchemaDisplay<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { let parts: Vec<_> = self .0 .fields() @@ -418,23 +459,6 @@ impl<'a> fmt::Display for ProjectSchemaDisplay<'a> { } } -/// A wrapper to customize output ordering display. -#[derive(Debug)] -pub struct OutputOrderingDisplay<'a>(pub &'a [PhysicalSortExpr]); - -impl<'a> fmt::Display for OutputOrderingDisplay<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "[")?; - for (i, e) in self.0.iter().enumerate() { - if i > 0 { - write!(f, ", ")? - } - write!(f, "{e}")?; - } - write!(f, "]") - } -} - pub fn display_orderings(f: &mut Formatter, orderings: &[LexOrdering]) -> fmt::Result { if let Some(ordering) = orderings.first() { if !ordering.is_empty() { @@ -448,8 +472,8 @@ pub fn display_orderings(f: &mut Formatter, orderings: &[LexOrdering]) -> fmt::R orderings.iter().enumerate().filter(|(_, o)| !o.is_empty()) { match idx { - 0 => write!(f, "{}", OutputOrderingDisplay(ordering))?, - _ => write!(f, ", {}", OutputOrderingDisplay(ordering))?, + 0 => write!(f, "[{}]", ordering)?, + _ => write!(f, ", [{}]", ordering)?, } } let end = if orderings.len() == 1 { "" } else { "]" }; @@ -465,12 +489,13 @@ mod tests { use std::fmt::Write; use std::sync::Arc; - use super::DisplayableExecutionPlan; - use crate::{DisplayAs, ExecutionPlan, PlanProperties}; - use datafusion_common::{DataFusionError, Result, Statistics}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; + use crate::{DisplayAs, ExecutionPlan, PlanProperties}; + + use super::DisplayableExecutionPlan; + #[derive(Debug, Clone, Copy)] enum TestStatsExecPlan { Panic, @@ -501,7 +526,7 @@ mod tests { unimplemented!() } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { vec![] } diff --git a/datafusion/physical-plan/src/empty.rs b/datafusion/physical-plan/src/empty.rs index 33bf1668b3c9..192619f69f6a 100644 --- a/datafusion/physical-plan/src/empty.rs +++ b/datafusion/physical-plan/src/empty.rs @@ -35,7 +35,7 @@ use datafusion_physical_expr::EquivalenceProperties; use log::trace; /// Execution plan for empty relation with produce_one_row=false -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct EmptyExec { /// The schema for the produced row schema: SchemaRef, @@ -47,7 +47,7 @@ pub struct EmptyExec { impl EmptyExec { /// Create a new EmptyExec pub fn new(schema: SchemaRef) -> Self { - let cache = Self::compute_properties(schema.clone(), 1); + let cache = Self::compute_properties(Arc::clone(&schema), 1); EmptyExec { schema, partitions: 1, @@ -114,7 +114,7 @@ impl ExecutionPlan for EmptyExec { &self.cache } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { vec![] } @@ -142,7 +142,7 @@ impl ExecutionPlan for EmptyExec { Ok(Box::pin(MemoryStream::try_new( self.data()?, - self.schema.clone(), + Arc::clone(&self.schema), None, )?)) } @@ -170,10 +170,10 @@ mod tests { let task_ctx = Arc::new(TaskContext::default()); let schema = test::aggr_test_schema(); - let empty = EmptyExec::new(schema.clone()); + let empty = EmptyExec::new(Arc::clone(&schema)); assert_eq!(empty.schema(), schema); - // we should have no results + // We should have no results let iter = empty.execute(0, task_ctx)?; let batches = common::collect(iter).await?; assert!(batches.is_empty()); @@ -184,9 +184,12 @@ mod tests { #[test] fn with_new_children() -> Result<()> { let schema = test::aggr_test_schema(); - let empty = Arc::new(EmptyExec::new(schema.clone())); + let empty = Arc::new(EmptyExec::new(Arc::clone(&schema))); - let empty2 = with_new_children_if_necessary(empty.clone(), vec![])?; + let empty2 = with_new_children_if_necessary( + Arc::clone(&empty) as Arc, + vec![], + )?; assert_eq!(empty.schema(), empty2.schema()); let too_many_kids = vec![empty2]; @@ -204,7 +207,7 @@ mod tests { let empty = EmptyExec::new(schema); // ask for the wrong partition - assert!(empty.execute(1, task_ctx.clone()).is_err()); + assert!(empty.execute(1, Arc::clone(&task_ctx)).is_err()); assert!(empty.execute(20, task_ctx).is_err()); Ok(()) } diff --git a/datafusion/physical-plan/src/execution_plan.rs b/datafusion/physical-plan/src/execution_plan.rs new file mode 100644 index 000000000000..d65320dbab68 --- /dev/null +++ b/datafusion/physical-plan/src/execution_plan.rs @@ -0,0 +1,1210 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::fmt::Debug; +use std::sync::Arc; + +use arrow::datatypes::SchemaRef; +use arrow::record_batch::RecordBatch; +use arrow_array::Array; +use futures::stream::{StreamExt, TryStreamExt}; +use tokio::task::JoinSet; + +use datafusion_common::config::ConfigOptions; +pub use datafusion_common::hash_utils; +pub use datafusion_common::utils::project_schema; +use datafusion_common::{exec_err, Result}; +pub use datafusion_common::{internal_err, ColumnStatistics, Statistics}; +use datafusion_execution::TaskContext; +pub use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; +pub use datafusion_expr::{Accumulator, ColumnarValue}; +pub use datafusion_physical_expr::window::WindowExpr; +pub use datafusion_physical_expr::{ + expressions, udf, Distribution, Partitioning, PhysicalExpr, +}; +use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; +use datafusion_physical_expr_common::sort_expr::{LexOrderingRef, LexRequirement}; + +use crate::coalesce_partitions::CoalescePartitionsExec; +use crate::display::DisplayableExecutionPlan; +pub use crate::display::{DefaultDisplay, DisplayAs, DisplayFormatType, VerboseDisplay}; +pub use crate::metrics::Metric; +use crate::metrics::MetricsSet; +pub use crate::ordering::InputOrderMode; +use crate::repartition::RepartitionExec; +use crate::sorts::sort_preserving_merge::SortPreservingMergeExec; +pub use crate::stream::EmptyRecordBatchStream; +use crate::stream::RecordBatchStreamAdapter; + +/// Represent nodes in the DataFusion Physical Plan. +/// +/// Calling [`execute`] produces an `async` [`SendableRecordBatchStream`] of +/// [`RecordBatch`] that incrementally computes a partition of the +/// `ExecutionPlan`'s output from its input. See [`Partitioning`] for more +/// details on partitioning. +/// +/// Methods such as [`Self::schema`] and [`Self::properties`] communicate +/// properties of the output to the DataFusion optimizer, and methods such as +/// [`required_input_distribution`] and [`required_input_ordering`] express +/// requirements of the `ExecutionPlan` from its input. +/// +/// [`ExecutionPlan`] can be displayed in a simplified form using the +/// return value from [`displayable`] in addition to the (normally +/// quite verbose) `Debug` output. +/// +/// [`execute`]: ExecutionPlan::execute +/// [`required_input_distribution`]: ExecutionPlan::required_input_distribution +/// [`required_input_ordering`]: ExecutionPlan::required_input_ordering +pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { + /// Short name for the ExecutionPlan, such as 'ParquetExec'. + /// + /// Implementation note: this method can just proxy to + /// [`static_name`](ExecutionPlan::static_name) if no special action is + /// needed. It doesn't provide a default implementation like that because + /// this method doesn't require the `Sized` constrain to allow a wilder + /// range of use cases. + fn name(&self) -> &str; + + /// Short name for the ExecutionPlan, such as 'ParquetExec'. + /// Like [`name`](ExecutionPlan::name) but can be called without an instance. + fn static_name() -> &'static str + where + Self: Sized, + { + let full_name = std::any::type_name::(); + let maybe_start_idx = full_name.rfind(':'); + match maybe_start_idx { + Some(start_idx) => &full_name[start_idx + 1..], + None => "UNKNOWN", + } + } + + /// Returns the execution plan as [`Any`] so that it can be + /// downcast to a specific implementation. + fn as_any(&self) -> &dyn Any; + + /// Get the schema for this execution plan + fn schema(&self) -> SchemaRef { + Arc::clone(self.properties().schema()) + } + + /// Return properties of the output of the `ExecutionPlan`, such as output + /// ordering(s), partitioning information etc. + /// + /// This information is available via methods on [`ExecutionPlanProperties`] + /// trait, which is implemented for all `ExecutionPlan`s. + fn properties(&self) -> &PlanProperties; + + /// Specifies the data distribution requirements for all the + /// children for this `ExecutionPlan`, By default it's [[Distribution::UnspecifiedDistribution]] for each child, + fn required_input_distribution(&self) -> Vec { + vec![Distribution::UnspecifiedDistribution; self.children().len()] + } + + /// Specifies the ordering required for all of the children of this + /// `ExecutionPlan`. + /// + /// For each child, it's the local ordering requirement within + /// each partition rather than the global ordering + /// + /// NOTE that checking `!is_empty()` does **not** check for a + /// required input ordering. Instead, the correct check is that at + /// least one entry must be `Some` + fn required_input_ordering(&self) -> Vec> { + vec![None; self.children().len()] + } + + /// Returns `false` if this `ExecutionPlan`'s implementation may reorder + /// rows within or between partitions. + /// + /// For example, Projection, Filter, and Limit maintain the order + /// of inputs -- they may transform values (Projection) or not + /// produce the same number of rows that went in (Filter and + /// Limit), but the rows that are produced go in the same way. + /// + /// DataFusion uses this metadata to apply certain optimizations + /// such as automatically repartitioning correctly. + /// + /// The default implementation returns `false` + /// + /// WARNING: if you override this default, you *MUST* ensure that + /// the `ExecutionPlan`'s maintains the ordering invariant or else + /// DataFusion may produce incorrect results. + fn maintains_input_order(&self) -> Vec { + vec![false; self.children().len()] + } + + /// Specifies whether the `ExecutionPlan` benefits from increased + /// parallelization at its input for each child. + /// + /// If returns `true`, the `ExecutionPlan` would benefit from partitioning + /// its corresponding child (and thus from more parallelism). For + /// `ExecutionPlan` that do very little work the overhead of extra + /// parallelism may outweigh any benefits + /// + /// The default implementation returns `true` unless this `ExecutionPlan` + /// has signalled it requires a single child input partition. + fn benefits_from_input_partitioning(&self) -> Vec { + // By default try to maximize parallelism with more CPUs if + // possible + self.required_input_distribution() + .into_iter() + .map(|dist| !matches!(dist, Distribution::SinglePartition)) + .collect() + } + + /// Get a list of children `ExecutionPlan`s that act as inputs to this plan. + /// The returned list will be empty for leaf nodes such as scans, will contain + /// a single value for unary nodes, or two values for binary nodes (such as + /// joins). + fn children(&self) -> Vec<&Arc>; + + /// Returns a new `ExecutionPlan` where all existing children were replaced + /// by the `children`, in order + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result>; + + /// If supported, attempt to increase the partitioning of this `ExecutionPlan` to + /// produce `target_partitions` partitions. + /// + /// If the `ExecutionPlan` does not support changing its partitioning, + /// returns `Ok(None)` (the default). + /// + /// It is the `ExecutionPlan` can increase its partitioning, but not to the + /// `target_partitions`, it may return an ExecutionPlan with fewer + /// partitions. This might happen, for example, if each new partition would + /// be too small to be efficiently processed individually. + /// + /// The DataFusion optimizer attempts to use as many threads as possible by + /// repartitioning its inputs to match the target number of threads + /// available (`target_partitions`). Some data sources, such as the built in + /// CSV and Parquet readers, implement this method as they are able to read + /// from their input files in parallel, regardless of how the source data is + /// split amongst files. + fn repartitioned( + &self, + _target_partitions: usize, + _config: &ConfigOptions, + ) -> Result>> { + Ok(None) + } + + /// Begin execution of `partition`, returning a [`Stream`] of + /// [`RecordBatch`]es. + /// + /// # Notes + /// + /// The `execute` method itself is not `async` but it returns an `async` + /// [`futures::stream::Stream`]. This `Stream` should incrementally compute + /// the output, `RecordBatch` by `RecordBatch` (in a streaming fashion). + /// Most `ExecutionPlan`s should not do any work before the first + /// `RecordBatch` is requested from the stream. + /// + /// [`RecordBatchStreamAdapter`] can be used to convert an `async` + /// [`Stream`] into a [`SendableRecordBatchStream`]. + /// + /// Using `async` `Streams` allows for network I/O during execution and + /// takes advantage of Rust's built in support for `async` continuations and + /// crate ecosystem. + /// + /// [`Stream`]: futures::stream::Stream + /// [`StreamExt`]: futures::stream::StreamExt + /// [`TryStreamExt`]: futures::stream::TryStreamExt + /// [`RecordBatchStreamAdapter`]: crate::stream::RecordBatchStreamAdapter + /// + /// # Error handling + /// + /// Any error that occurs during execution is sent as an `Err` in the output + /// stream. + /// + /// `ExecutionPlan` implementations in DataFusion cancel additional work + /// immediately once an error occurs. The rationale is that if the overall + /// query will return an error, any additional work such as continued + /// polling of inputs will be wasted as it will be thrown away. + /// + /// # Cancellation / Aborting Execution + /// + /// The [`Stream`] that is returned must ensure that any allocated resources + /// are freed when the stream itself is dropped. This is particularly + /// important for [`spawn`]ed tasks or threads. Unless care is taken to + /// "abort" such tasks, they may continue to consume resources even after + /// the plan is dropped, generating intermediate results that are never + /// used. + /// Thus, [`spawn`] is disallowed, and instead use [`SpawnedTask`]. + /// + /// For more details see [`SpawnedTask`], [`JoinSet`] and [`RecordBatchReceiverStreamBuilder`] + /// for structures to help ensure all background tasks are cancelled. + /// + /// [`spawn`]: tokio::task::spawn + /// [`JoinSet`]: tokio::task::JoinSet + /// [`SpawnedTask`]: datafusion_common_runtime::SpawnedTask + /// [`RecordBatchReceiverStreamBuilder`]: crate::stream::RecordBatchReceiverStreamBuilder + /// + /// # Implementation Examples + /// + /// While `async` `Stream`s have a non trivial learning curve, the + /// [`futures`] crate provides [`StreamExt`] and [`TryStreamExt`] + /// which help simplify many common operations. + /// + /// Here are some common patterns: + /// + /// ## Return Precomputed `RecordBatch` + /// + /// We can return a precomputed `RecordBatch` as a `Stream`: + /// + /// ``` + /// # use std::sync::Arc; + /// # use arrow_array::RecordBatch; + /// # use arrow_schema::SchemaRef; + /// # use datafusion_common::Result; + /// # use datafusion_execution::{SendableRecordBatchStream, TaskContext}; + /// # use datafusion_physical_plan::memory::MemoryStream; + /// # use datafusion_physical_plan::stream::RecordBatchStreamAdapter; + /// struct MyPlan { + /// batch: RecordBatch, + /// } + /// + /// impl MyPlan { + /// fn execute( + /// &self, + /// partition: usize, + /// context: Arc + /// ) -> Result { + /// // use functions from futures crate convert the batch into a stream + /// let fut = futures::future::ready(Ok(self.batch.clone())); + /// let stream = futures::stream::once(fut); + /// Ok(Box::pin(RecordBatchStreamAdapter::new(self.batch.schema(), stream))) + /// } + /// } + /// ``` + /// + /// ## Lazily (async) Compute `RecordBatch` + /// + /// We can also lazily compute a `RecordBatch` when the returned `Stream` is polled + /// + /// ``` + /// # use std::sync::Arc; + /// # use arrow_array::RecordBatch; + /// # use arrow_schema::SchemaRef; + /// # use datafusion_common::Result; + /// # use datafusion_execution::{SendableRecordBatchStream, TaskContext}; + /// # use datafusion_physical_plan::memory::MemoryStream; + /// # use datafusion_physical_plan::stream::RecordBatchStreamAdapter; + /// struct MyPlan { + /// schema: SchemaRef, + /// } + /// + /// /// Returns a single batch when the returned stream is polled + /// async fn get_batch() -> Result { + /// todo!() + /// } + /// + /// impl MyPlan { + /// fn execute( + /// &self, + /// partition: usize, + /// context: Arc + /// ) -> Result { + /// let fut = get_batch(); + /// let stream = futures::stream::once(fut); + /// Ok(Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream))) + /// } + /// } + /// ``` + /// + /// ## Lazily (async) create a Stream + /// + /// If you need to create the return `Stream` using an `async` function, + /// you can do so by flattening the result: + /// + /// ``` + /// # use std::sync::Arc; + /// # use arrow_array::RecordBatch; + /// # use arrow_schema::SchemaRef; + /// # use futures::TryStreamExt; + /// # use datafusion_common::Result; + /// # use datafusion_execution::{SendableRecordBatchStream, TaskContext}; + /// # use datafusion_physical_plan::memory::MemoryStream; + /// # use datafusion_physical_plan::stream::RecordBatchStreamAdapter; + /// struct MyPlan { + /// schema: SchemaRef, + /// } + /// + /// /// async function that returns a stream + /// async fn get_batch_stream() -> Result { + /// todo!() + /// } + /// + /// impl MyPlan { + /// fn execute( + /// &self, + /// partition: usize, + /// context: Arc + /// ) -> Result { + /// // A future that yields a stream + /// let fut = get_batch_stream(); + /// // Use TryStreamExt::try_flatten to flatten the stream of streams + /// let stream = futures::stream::once(fut).try_flatten(); + /// Ok(Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream))) + /// } + /// } + /// ``` + fn execute( + &self, + partition: usize, + context: Arc, + ) -> Result; + + /// Return a snapshot of the set of [`Metric`]s for this + /// [`ExecutionPlan`]. If no `Metric`s are available, return None. + /// + /// While the values of the metrics in the returned + /// [`MetricsSet`]s may change as execution progresses, the + /// specific metrics will not. + /// + /// Once `self.execute()` has returned (technically the future is + /// resolved) for all available partitions, the set of metrics + /// should be complete. If this function is called prior to + /// `execute()` new metrics may appear in subsequent calls. + fn metrics(&self) -> Option { + None + } + + /// Returns statistics for this `ExecutionPlan` node. If statistics are not + /// available, should return [`Statistics::new_unknown`] (the default), not + /// an error. + /// + /// For TableScan executors, which supports filter pushdown, special attention + /// needs to be paid to whether the stats returned by this method are exact or not + fn statistics(&self) -> Result { + Ok(Statistics::new_unknown(&self.schema())) + } + + /// Returns `true` if a limit can be safely pushed down through this + /// `ExecutionPlan` node. + /// + /// If this method returns `true`, and the query plan contains a limit at + /// the output of this node, DataFusion will push the limit to the input + /// of this node. + fn supports_limit_pushdown(&self) -> bool { + false + } + + /// Returns a fetching variant of this `ExecutionPlan` node, if it supports + /// fetch limits. Returns `None` otherwise. + fn with_fetch(&self, _limit: Option) -> Option> { + None + } + + /// Gets the fetch count for the operator, `None` means there is no fetch. + fn fetch(&self) -> Option { + None + } + + /// Gets the effect on cardinality, if known + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::Unknown + } +} + +/// Extension trait provides an easy API to fetch various properties of +/// [`ExecutionPlan`] objects based on [`ExecutionPlan::properties`]. +pub trait ExecutionPlanProperties { + /// Specifies how the output of this `ExecutionPlan` is split into + /// partitions. + fn output_partitioning(&self) -> &Partitioning; + + /// Specifies whether this plan generates an infinite stream of records. + /// If the plan does not support pipelining, but its input(s) are + /// infinite, returns [`ExecutionMode::PipelineBreaking`] to indicate this. + fn execution_mode(&self) -> ExecutionMode; + + /// If the output of this `ExecutionPlan` within each partition is sorted, + /// returns `Some(keys)` describing the ordering. A `None` return value + /// indicates no assumptions should be made on the output ordering. + /// + /// For example, `SortExec` (obviously) produces sorted output as does + /// `SortPreservingMergeStream`. Less obviously, `Projection` produces sorted + /// output if its input is sorted as it does not reorder the input rows. + fn output_ordering(&self) -> Option; + + /// Get the [`EquivalenceProperties`] within the plan. + /// + /// Equivalence properties tell DataFusion what columns are known to be + /// equal, during various optimization passes. By default, this returns "no + /// known equivalences" which is always correct, but may cause DataFusion to + /// unnecessarily resort data. + /// + /// If this ExecutionPlan makes no changes to the schema of the rows flowing + /// through it or how columns within each row relate to each other, it + /// should return the equivalence properties of its input. For + /// example, since `FilterExec` may remove rows from its input, but does not + /// otherwise modify them, it preserves its input equivalence properties. + /// However, since `ProjectionExec` may calculate derived expressions, it + /// needs special handling. + /// + /// See also [`ExecutionPlan::maintains_input_order`] and [`Self::output_ordering`] + /// for related concepts. + fn equivalence_properties(&self) -> &EquivalenceProperties; +} + +impl ExecutionPlanProperties for Arc { + fn output_partitioning(&self) -> &Partitioning { + self.properties().output_partitioning() + } + + fn execution_mode(&self) -> ExecutionMode { + self.properties().execution_mode() + } + + fn output_ordering(&self) -> Option { + self.properties().output_ordering() + } + + fn equivalence_properties(&self) -> &EquivalenceProperties { + self.properties().equivalence_properties() + } +} + +impl ExecutionPlanProperties for &dyn ExecutionPlan { + fn output_partitioning(&self) -> &Partitioning { + self.properties().output_partitioning() + } + + fn execution_mode(&self) -> ExecutionMode { + self.properties().execution_mode() + } + + fn output_ordering(&self) -> Option { + self.properties().output_ordering() + } + + fn equivalence_properties(&self) -> &EquivalenceProperties { + self.properties().equivalence_properties() + } +} + +/// Describes the execution mode of the result of calling +/// [`ExecutionPlan::execute`] with respect to its size and behavior. +/// +/// The mode of the execution plan is determined by the mode of its input +/// execution plans and the details of the operator itself. For example, a +/// `FilterExec` operator will have the same execution mode as its input, but a +/// `SortExec` operator may have a different execution mode than its input, +/// depending on how the input stream is sorted. +/// +/// There are three possible execution modes: `Bounded`, `Unbounded` and +/// `PipelineBreaking`. +#[derive(Clone, Copy, PartialEq, Debug)] +pub enum ExecutionMode { + /// The stream is bounded / finite. + /// + /// In this case the stream will eventually return `None` to indicate that + /// there are no more records to process. + Bounded, + /// The stream is unbounded / infinite. + /// + /// In this case, the stream will never be done (never return `None`), + /// except in case of error. + /// + /// This mode is often used in "Steaming" use cases where data is + /// incrementally processed as it arrives. + /// + /// Note that even though the operator generates an unbounded stream of + /// results, it can execute with bounded memory and incrementally produces + /// output. + Unbounded, + /// Some of the operator's input stream(s) are unbounded, but the operator + /// cannot generate streaming results from these streaming inputs. + /// + /// In this case, the execution mode will be pipeline breaking, e.g. the + /// operator requires unbounded memory to generate results. This + /// information is used by the planner when performing sanity checks + /// on plans processings unbounded data sources. + PipelineBreaking, +} + +impl ExecutionMode { + /// Check whether the execution mode is unbounded or not. + pub fn is_unbounded(&self) -> bool { + matches!(self, ExecutionMode::Unbounded) + } + + /// Check whether the execution is pipeline friendly. If so, operator can + /// execute safely. + pub fn pipeline_friendly(&self) -> bool { + matches!(self, ExecutionMode::Bounded | ExecutionMode::Unbounded) + } +} + +/// Conservatively "combines" execution modes of a given collection of operators. +pub(crate) fn execution_mode_from_children<'a>( + children: impl IntoIterator>, +) -> ExecutionMode { + let mut result = ExecutionMode::Bounded; + for mode in children.into_iter().map(|child| child.execution_mode()) { + match (mode, result) { + (ExecutionMode::PipelineBreaking, _) + | (_, ExecutionMode::PipelineBreaking) => { + // If any of the modes is `PipelineBreaking`, so is the result: + return ExecutionMode::PipelineBreaking; + } + (ExecutionMode::Unbounded, _) | (_, ExecutionMode::Unbounded) => { + // Unbounded mode eats up bounded mode: + result = ExecutionMode::Unbounded; + } + (ExecutionMode::Bounded, ExecutionMode::Bounded) => { + // When both modes are bounded, so is the result: + result = ExecutionMode::Bounded; + } + } + } + result +} + +/// Stores certain, often expensive to compute, plan properties used in query +/// optimization. +/// +/// These properties are stored a single structure to permit this information to +/// be computed once and then those cached results used multiple times without +/// recomputation (aka a cache) +#[derive(Debug, Clone)] +pub struct PlanProperties { + /// See [ExecutionPlanProperties::equivalence_properties] + pub eq_properties: EquivalenceProperties, + /// See [ExecutionPlanProperties::output_partitioning] + pub partitioning: Partitioning, + /// See [ExecutionPlanProperties::execution_mode] + pub execution_mode: ExecutionMode, + /// See [ExecutionPlanProperties::output_ordering] + output_ordering: Option, +} + +impl PlanProperties { + /// Construct a new `PlanPropertiesCache` from the + pub fn new( + eq_properties: EquivalenceProperties, + partitioning: Partitioning, + execution_mode: ExecutionMode, + ) -> Self { + // Output ordering can be derived from `eq_properties`. + let output_ordering = eq_properties.output_ordering(); + Self { + eq_properties, + partitioning, + execution_mode, + output_ordering, + } + } + + /// Overwrite output partitioning with its new value. + pub fn with_partitioning(mut self, partitioning: Partitioning) -> Self { + self.partitioning = partitioning; + self + } + + /// Overwrite the execution Mode with its new value. + pub fn with_execution_mode(mut self, execution_mode: ExecutionMode) -> Self { + self.execution_mode = execution_mode; + self + } + + /// Overwrite equivalence properties with its new value. + pub fn with_eq_properties(mut self, eq_properties: EquivalenceProperties) -> Self { + // Changing equivalence properties also changes output ordering, so + // make sure to overwrite it: + self.output_ordering = eq_properties.output_ordering(); + self.eq_properties = eq_properties; + self + } + + pub fn equivalence_properties(&self) -> &EquivalenceProperties { + &self.eq_properties + } + + pub fn output_partitioning(&self) -> &Partitioning { + &self.partitioning + } + + pub fn output_ordering(&self) -> Option { + self.output_ordering.as_deref() + } + + pub fn execution_mode(&self) -> ExecutionMode { + self.execution_mode + } + + /// Get schema of the node. + fn schema(&self) -> &SchemaRef { + self.eq_properties.schema() + } +} + +/// Indicate whether a data exchange is needed for the input of `plan`, which will be very helpful +/// especially for the distributed engine to judge whether need to deal with shuffling. +/// Currently there are 3 kinds of execution plan which needs data exchange +/// 1. RepartitionExec for changing the partition number between two `ExecutionPlan`s +/// 2. CoalescePartitionsExec for collapsing all of the partitions into one without ordering guarantee +/// 3. SortPreservingMergeExec for collapsing all of the sorted partitions into one with ordering guarantee +pub fn need_data_exchange(plan: Arc) -> bool { + if let Some(repartition) = plan.as_any().downcast_ref::() { + !matches!( + repartition.properties().output_partitioning(), + Partitioning::RoundRobinBatch(_) + ) + } else if let Some(coalesce) = plan.as_any().downcast_ref::() + { + coalesce.input().output_partitioning().partition_count() > 1 + } else if let Some(sort_preserving_merge) = + plan.as_any().downcast_ref::() + { + sort_preserving_merge + .input() + .output_partitioning() + .partition_count() + > 1 + } else { + false + } +} + +/// Returns a copy of this plan if we change any child according to the pointer comparison. +/// The size of `children` must be equal to the size of `ExecutionPlan::children()`. +pub fn with_new_children_if_necessary( + plan: Arc, + children: Vec>, +) -> Result> { + let old_children = plan.children(); + if children.len() != old_children.len() { + internal_err!("Wrong number of children") + } else if children.is_empty() + || children + .iter() + .zip(old_children.iter()) + .any(|(c1, c2)| !Arc::ptr_eq(c1, c2)) + { + plan.with_new_children(children) + } else { + Ok(plan) + } +} + +/// Return a [wrapper](DisplayableExecutionPlan) around an +/// [`ExecutionPlan`] which can be displayed in various easier to +/// understand ways. +pub fn displayable(plan: &dyn ExecutionPlan) -> DisplayableExecutionPlan<'_> { + DisplayableExecutionPlan::new(plan) +} + +/// Execute the [ExecutionPlan] and collect the results in memory +pub async fn collect( + plan: Arc, + context: Arc, +) -> Result> { + let stream = execute_stream(plan, context)?; + crate::common::collect(stream).await +} + +/// Execute the [ExecutionPlan] and return a single stream of `RecordBatch`es. +/// +/// See [collect] to buffer the `RecordBatch`es in memory. +/// +/// # Aborting Execution +/// +/// Dropping the stream will abort the execution of the query, and free up +/// any allocated resources +pub fn execute_stream( + plan: Arc, + context: Arc, +) -> Result { + match plan.output_partitioning().partition_count() { + 0 => Ok(Box::pin(EmptyRecordBatchStream::new(plan.schema()))), + 1 => plan.execute(0, context), + 2.. => { + // merge into a single partition + let plan = CoalescePartitionsExec::new(Arc::clone(&plan)); + // CoalescePartitionsExec must produce a single partition + assert_eq!(1, plan.properties().output_partitioning().partition_count()); + plan.execute(0, context) + } + } +} + +/// Execute the [ExecutionPlan] and collect the results in memory +pub async fn collect_partitioned( + plan: Arc, + context: Arc, +) -> Result>> { + let streams = execute_stream_partitioned(plan, context)?; + + let mut join_set = JoinSet::new(); + // Execute the plan and collect the results into batches. + streams.into_iter().enumerate().for_each(|(idx, stream)| { + join_set.spawn(async move { + let result: Result> = stream.try_collect().await; + (idx, result) + }); + }); + + let mut batches = vec![]; + // Note that currently this doesn't identify the thread that panicked + // + // TODO: Replace with [join_next_with_id](https://docs.rs/tokio/latest/tokio/task/struct.JoinSet.html#method.join_next_with_id + // once it is stable + while let Some(result) = join_set.join_next().await { + match result { + Ok((idx, res)) => batches.push((idx, res?)), + Err(e) => { + if e.is_panic() { + std::panic::resume_unwind(e.into_panic()); + } else { + unreachable!(); + } + } + } + } + + batches.sort_by_key(|(idx, _)| *idx); + let batches = batches.into_iter().map(|(_, batch)| batch).collect(); + + Ok(batches) +} + +/// Execute the [ExecutionPlan] and return a vec with one stream per output +/// partition +/// +/// # Aborting Execution +/// +/// Dropping the stream will abort the execution of the query, and free up +/// any allocated resources +pub fn execute_stream_partitioned( + plan: Arc, + context: Arc, +) -> Result> { + let num_partitions = plan.output_partitioning().partition_count(); + let mut streams = Vec::with_capacity(num_partitions); + for i in 0..num_partitions { + streams.push(plan.execute(i, Arc::clone(&context))?); + } + Ok(streams) +} + +/// Executes an input stream and ensures that the resulting stream adheres to +/// the `not null` constraints specified in the `sink_schema`. +/// +/// # Arguments +/// +/// * `input` - An execution plan +/// * `sink_schema` - The schema to be applied to the output stream +/// * `partition` - The partition index to be executed +/// * `context` - The task context +/// +/// # Returns +/// +/// * `Result` - A stream of `RecordBatch`es if successful +/// +/// This function first executes the given input plan for the specified partition +/// and context. It then checks if there are any columns in the input that might +/// violate the `not null` constraints specified in the `sink_schema`. If there are +/// such columns, it wraps the resulting stream to enforce the `not null` constraints +/// by invoking the `check_not_null_contraits` function on each batch of the stream. +pub fn execute_input_stream( + input: Arc, + sink_schema: SchemaRef, + partition: usize, + context: Arc, +) -> Result { + let input_stream = input.execute(partition, context)?; + + debug_assert_eq!(sink_schema.fields().len(), input.schema().fields().len()); + + // Find input columns that may violate the not null constraint. + let risky_columns: Vec<_> = sink_schema + .fields() + .iter() + .zip(input.schema().fields().iter()) + .enumerate() + .filter_map(|(idx, (sink_field, input_field))| { + (!sink_field.is_nullable() && input_field.is_nullable()).then_some(idx) + }) + .collect(); + + if risky_columns.is_empty() { + Ok(input_stream) + } else { + // Check not null constraint on the input stream + Ok(Box::pin(RecordBatchStreamAdapter::new( + sink_schema, + input_stream + .map(move |batch| check_not_null_constraints(batch?, &risky_columns)), + ))) + } +} + +/// Checks a `RecordBatch` for `not null` constraints on specified columns. +/// +/// # Arguments +/// +/// * `batch` - The `RecordBatch` to be checked +/// * `column_indices` - A vector of column indices that should be checked for +/// `not null` constraints. +/// +/// # Returns +/// +/// * `Result` - The original `RecordBatch` if all constraints are met +/// +/// This function iterates over the specified column indices and ensures that none +/// of the columns contain null values. If any column contains null values, an error +/// is returned. +pub fn check_not_null_constraints( + batch: RecordBatch, + column_indices: &Vec, +) -> Result { + for &index in column_indices { + if batch.num_columns() <= index { + return exec_err!( + "Invalid batch column count {} expected > {}", + batch.num_columns(), + index + ); + } + + if batch + .column(index) + .logical_nulls() + .map(|nulls| nulls.null_count()) + .unwrap_or_default() + > 0 + { + return exec_err!( + "Invalid batch column at '{}' has null but schema specifies non-nullable", + index + ); + } + } + + Ok(batch) +} + +/// Utility function yielding a string representation of the given [`ExecutionPlan`]. +pub fn get_plan_string(plan: &Arc) -> Vec { + let formatted = displayable(plan.as_ref()).indent(true).to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + actual.iter().map(|elem| elem.to_string()).collect() +} + +/// Indicates the effect an execution plan operator will have on the cardinality +/// of its input stream +pub enum CardinalityEffect { + /// Unknown effect. This is the default + Unknown, + /// The operator is guaranteed to produce exactly one row for + /// each input row + Equal, + /// The operator may produce fewer output rows than it receives input rows + LowerEqual, + /// The operator may produce more output rows than it receives input rows + GreaterEqual, +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::{DictionaryArray, Int32Array, NullArray, RunArray}; + use arrow_schema::{DataType, Field, Schema, SchemaRef}; + use std::any::Any; + use std::sync::Arc; + + use datafusion_common::{Result, Statistics}; + use datafusion_execution::{SendableRecordBatchStream, TaskContext}; + + use crate::{DisplayAs, DisplayFormatType, ExecutionPlan}; + + #[derive(Debug)] + pub struct EmptyExec; + + impl EmptyExec { + pub fn new(_schema: SchemaRef) -> Self { + Self + } + } + + impl DisplayAs for EmptyExec { + fn fmt_as( + &self, + _t: DisplayFormatType, + _f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + unimplemented!() + } + } + + impl ExecutionPlan for EmptyExec { + fn name(&self) -> &'static str { + Self::static_name() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + unimplemented!() + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> Result> { + unimplemented!() + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + unimplemented!() + } + + fn statistics(&self) -> Result { + unimplemented!() + } + } + + #[derive(Debug)] + pub struct RenamedEmptyExec; + + impl RenamedEmptyExec { + pub fn new(_schema: SchemaRef) -> Self { + Self + } + } + + impl DisplayAs for RenamedEmptyExec { + fn fmt_as( + &self, + _t: DisplayFormatType, + _f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + unimplemented!() + } + } + + impl ExecutionPlan for RenamedEmptyExec { + fn name(&self) -> &'static str { + Self::static_name() + } + + fn static_name() -> &'static str + where + Self: Sized, + { + "MyRenamedEmptyExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + unimplemented!() + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> Result> { + unimplemented!() + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + unimplemented!() + } + + fn statistics(&self) -> Result { + unimplemented!() + } + } + + #[test] + fn test_execution_plan_name() { + let schema1 = Arc::new(Schema::empty()); + let default_name_exec = EmptyExec::new(schema1); + assert_eq!(default_name_exec.name(), "EmptyExec"); + + let schema2 = Arc::new(Schema::empty()); + let renamed_exec = RenamedEmptyExec::new(schema2); + assert_eq!(renamed_exec.name(), "MyRenamedEmptyExec"); + assert_eq!(RenamedEmptyExec::static_name(), "MyRenamedEmptyExec"); + } + + /// A compilation test to ensure that the `ExecutionPlan::name()` method can + /// be called from a trait object. + /// Related ticket: https://github.com/apache/datafusion/pull/11047 + #[allow(dead_code)] + fn use_execution_plan_as_trait_object(plan: &dyn ExecutionPlan) { + let _ = plan.name(); + } + + #[test] + fn test_check_not_null_constraints_accept_non_null() -> Result<()> { + check_not_null_constraints( + RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])), + vec![Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)]))], + )?, + &vec![0], + )?; + Ok(()) + } + + #[test] + fn test_check_not_null_constraints_reject_null() -> Result<()> { + let result = check_not_null_constraints( + RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])), + vec![Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]))], + )?, + &vec![0], + ); + assert!(result.is_err()); + assert_starts_with( + result.err().unwrap().message().as_ref(), + "Invalid batch column at '0' has null but schema specifies non-nullable", + ); + Ok(()) + } + + #[test] + fn test_check_not_null_constraints_with_run_end_array() -> Result<()> { + // some null value inside REE array + let run_ends = Int32Array::from(vec![1, 2, 3, 4]); + let values = Int32Array::from(vec![Some(0), None, Some(1), None]); + let run_end_array = RunArray::try_new(&run_ends, &values)?; + let result = check_not_null_constraints( + RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new( + "a", + run_end_array.data_type().to_owned(), + true, + )])), + vec![Arc::new(run_end_array)], + )?, + &vec![0], + ); + assert!(result.is_err()); + assert_starts_with( + result.err().unwrap().message().as_ref(), + "Invalid batch column at '0' has null but schema specifies non-nullable", + ); + Ok(()) + } + + #[test] + fn test_check_not_null_constraints_with_dictionary_array_with_null() -> Result<()> { + let values = Arc::new(Int32Array::from(vec![Some(1), None, Some(3), Some(4)])); + let keys = Int32Array::from(vec![0, 1, 2, 3]); + let dictionary = DictionaryArray::new(keys, values); + let result = check_not_null_constraints( + RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new( + "a", + dictionary.data_type().to_owned(), + true, + )])), + vec![Arc::new(dictionary)], + )?, + &vec![0], + ); + assert!(result.is_err()); + assert_starts_with( + result.err().unwrap().message().as_ref(), + "Invalid batch column at '0' has null but schema specifies non-nullable", + ); + Ok(()) + } + + #[test] + fn test_check_not_null_constraints_with_dictionary_masking_null() -> Result<()> { + // some null value marked out by dictionary array + let values = Arc::new(Int32Array::from(vec![ + Some(1), + None, // this null value is masked by dictionary keys + Some(3), + Some(4), + ])); + let keys = Int32Array::from(vec![0, /*1,*/ 2, 3]); + let dictionary = DictionaryArray::new(keys, values); + check_not_null_constraints( + RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new( + "a", + dictionary.data_type().to_owned(), + true, + )])), + vec![Arc::new(dictionary)], + )?, + &vec![0], + )?; + Ok(()) + } + + #[test] + fn test_check_not_null_constraints_on_null_type() -> Result<()> { + // null value of Null type + let result = check_not_null_constraints( + RecordBatch::try_new( + Arc::new(Schema::new(vec![Field::new("a", DataType::Null, true)])), + vec![Arc::new(NullArray::new(3))], + )?, + &vec![0], + ); + assert!(result.is_err()); + assert_starts_with( + result.err().unwrap().message().as_ref(), + "Invalid batch column at '0' has null but schema specifies non-nullable", + ); + Ok(()) + } + + fn assert_starts_with(actual: impl AsRef, expected_prefix: impl AsRef) { + let actual = actual.as_ref(); + let expected_prefix = expected_prefix.as_ref(); + assert!( + actual.starts_with(expected_prefix), + "Expected '{}' to start with '{}'", + actual, + expected_prefix + ); + } +} diff --git a/datafusion/physical-plan/src/explain.rs b/datafusion/physical-plan/src/explain.rs index 649946993229..96f55a1446b0 100644 --- a/datafusion/physical-plan/src/explain.rs +++ b/datafusion/physical-plan/src/explain.rs @@ -53,7 +53,7 @@ impl ExplainExec { stringified_plans: Vec, verbose: bool, ) -> Self { - let cache = Self::compute_properties(schema.clone()); + let cache = Self::compute_properties(Arc::clone(&schema)); ExplainExec { schema, stringified_plans, @@ -67,7 +67,7 @@ impl ExplainExec { &self.stringified_plans } - /// access to verbose + /// Access to verbose pub fn verbose(&self) -> bool { self.verbose } @@ -111,8 +111,8 @@ impl ExecutionPlan for ExplainExec { &self.cache } - fn children(&self) -> Vec> { - // this is a leaf node and has no children + fn children(&self) -> Vec<&Arc> { + // This is a leaf node and has no children vec![] } @@ -160,7 +160,7 @@ impl ExecutionPlan for ExplainExec { } let record_batch = RecordBatch::try_new( - self.schema.clone(), + Arc::clone(&self.schema), vec![ Arc::new(type_builder.finish()), Arc::new(plan_builder.finish()), @@ -171,7 +171,7 @@ impl ExecutionPlan for ExplainExec { "Before returning RecordBatchStream in ExplainExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); Ok(Box::pin(RecordBatchStreamAdapter::new( - self.schema.clone(), + Arc::clone(&self.schema), futures::stream::iter(vec![Ok(record_batch)]), ))) } diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index bf1ab8b73126..07898e8d22d8 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -15,18 +15,16 @@ // specific language governing permissions and limitations // under the License. -//! FilterExec evaluates a boolean predicate against all input batches to determine which rows to -//! include in its output batches. - use std::any::Any; use std::pin::Pin; use std::sync::Arc; -use std::task::{Context, Poll}; +use std::task::{ready, Context, Poll}; use super::{ ColumnStatistics, DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, }; +use crate::common::can_project; use crate::{ metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, DisplayFormatType, ExecutionPlan, @@ -37,22 +35,26 @@ use arrow::datatypes::{DataType, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_common::cast::as_boolean_array; use datafusion_common::stats::Precision; -use datafusion_common::{plan_err, DataFusionError, Result}; +use datafusion_common::{ + internal_err, plan_err, project_schema, DataFusionError, Result, +}; use datafusion_execution::TaskContext; use datafusion_expr::Operator; +use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::expressions::BinaryExpr; use datafusion_physical_expr::intervals::utils::check_support; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{ - analyze, split_conjunction, AnalysisContext, ExprBoundaries, PhysicalExpr, + analyze, split_conjunction, AnalysisContext, ConstExpr, ExprBoundaries, PhysicalExpr, }; +use crate::execution_plan::CardinalityEffect; use futures::stream::{Stream, StreamExt}; use log::trace; /// FilterExec evaluates a boolean predicate against all input batches to determine which rows to /// include in its output batches. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct FilterExec { /// The expression to filter on. This expression must evaluate to a boolean value. predicate: Arc, @@ -60,9 +62,12 @@ pub struct FilterExec { input: Arc, /// Execution metrics metrics: ExecutionPlanMetricsSet, - /// Selectivity for statistics. 0 = no rows, 100 all rows + /// Selectivity for statistics. 0 = no rows, 100 = all rows default_selectivity: u8, + /// Properties equivalence properties, partitioning, etc. cache: PlanProperties, + /// The projection indices of the columns in the output schema of join + projection: Option>, } impl FilterExec { @@ -74,18 +79,23 @@ impl FilterExec { match predicate.data_type(input.schema().as_ref())? { DataType::Boolean => { let default_selectivity = 20; - let cache = - Self::compute_properties(&input, &predicate, default_selectivity)?; + let cache = Self::compute_properties( + &input, + &predicate, + default_selectivity, + None, + )?; Ok(Self { predicate, - input: input.clone(), + input: Arc::clone(&input), metrics: ExecutionPlanMetricsSet::new(), default_selectivity, cache, + projection: None, }) } other => { - plan_err!("Filter predicate must return boolean values, not {other:?}") + plan_err!("Filter predicate must return BOOLEAN values, got {other:?}") } } } @@ -95,12 +105,43 @@ impl FilterExec { default_selectivity: u8, ) -> Result { if default_selectivity > 100 { - return plan_err!("Default filter selectivity needs to be less than 100"); + return plan_err!( + "Default filter selectivity value needs to be less than or equal to 100" + ); } self.default_selectivity = default_selectivity; Ok(self) } + /// Return new instance of [FilterExec] with the given projection. + pub fn with_projection(&self, projection: Option>) -> Result { + // Check if the projection is valid + can_project(&self.schema(), projection.as_ref())?; + + let projection = match projection { + Some(projection) => match &self.projection { + Some(p) => Some(projection.iter().map(|i| p[*i]).collect()), + None => Some(projection), + }, + None => None, + }; + + let cache = Self::compute_properties( + &self.input, + &self.predicate, + self.default_selectivity, + projection.as_ref(), + )?; + Ok(Self { + predicate: Arc::clone(&self.predicate), + input: Arc::clone(&self.input), + metrics: self.metrics.clone(), + default_selectivity: self.default_selectivity, + cache, + projection, + }) + } + /// The expression to filter on. This expression must evaluate to a boolean value. pub fn predicate(&self) -> &Arc { &self.predicate @@ -116,6 +157,11 @@ impl FilterExec { self.default_selectivity } + /// Projection + pub fn projection(&self) -> Option<&Vec> { + self.projection.as_ref() + } + /// Calculates `Statistics` for `FilterExec`, by applying selectivity (either default, or estimated) to input statistics. fn statistics_helper( input: &Arc, @@ -126,7 +172,7 @@ impl FilterExec { let schema = input.schema(); if !check_support(predicate, &schema) { let selectivity = default_selectivity as f64 / 100.0; - let mut stats = input_stats.into_inexact(); + let mut stats = input_stats.to_inexact(); stats.num_rows = stats.num_rows.with_estimated_selectivity(selectivity); stats.total_byte_size = stats .total_byte_size @@ -162,7 +208,7 @@ impl FilterExec { fn extend_constants( input: &Arc, predicate: &Arc, - ) -> Vec> { + ) -> Vec { let mut res_constants = Vec::new(); let input_eqs = input.equivalence_properties(); @@ -170,10 +216,15 @@ impl FilterExec { for conjunction in conjunctions { if let Some(binary) = conjunction.as_any().downcast_ref::() { if binary.op() == &Operator::Eq { + // Filter evaluates to single value for all partitions if input_eqs.is_expr_constant(binary.left()) { - res_constants.push(binary.right().clone()) + res_constants.push( + ConstExpr::from(binary.right()).with_across_partitions(true), + ) } else if input_eqs.is_expr_constant(binary.right()) { - res_constants.push(binary.left().clone()) + res_constants.push( + ConstExpr::from(binary.left()).with_across_partitions(true), + ) } } } @@ -185,6 +236,7 @@ impl FilterExec { input: &Arc, predicate: &Arc, default_selectivity: u8, + projection: Option<&Vec>, ) -> Result { // Combine the equal predicates with the input equivalence properties // to construct the equivalence properties: @@ -192,24 +244,38 @@ impl FilterExec { let mut eq_properties = input.equivalence_properties().clone(); let (equal_pairs, _) = collect_columns_from_predicate(predicate); for (lhs, rhs) in equal_pairs { - eq_properties.add_equal_conditions(lhs, rhs) + eq_properties.add_equal_conditions(lhs, rhs)? } // Add the columns that have only one viable value (singleton) after // filtering to constants. let constants = collect_columns(predicate) .into_iter() .filter(|column| stats.column_statistics[column.index()].is_singleton()) - .map(|column| Arc::new(column) as _); - // this is for statistics - eq_properties = eq_properties.add_constants(constants); - // this is for logical constant (for example: a = '1', then a could be marked as a constant) + .map(|column| { + let expr = Arc::new(column) as _; + ConstExpr::new(expr).with_across_partitions(true) + }); + // This is for statistics + eq_properties = eq_properties.with_constants(constants); + // This is for logical constant (for example: a = '1', then a could be marked as a constant) // to do: how to deal with multiple situation to represent = (for example c1 between 0 and 0) eq_properties = - eq_properties.add_constants(Self::extend_constants(input, predicate)); + eq_properties.with_constants(Self::extend_constants(input, predicate)); + + let mut output_partitioning = input.output_partitioning().clone(); + // If contains projection, update the PlanProperties. + if let Some(projection) = projection { + let schema = eq_properties.schema(); + let projection_mapping = ProjectionMapping::from_indices(projection, schema)?; + let out_schema = project_schema(schema, Some(projection))?; + output_partitioning = + output_partitioning.project(&projection_mapping, &eq_properties); + eq_properties = eq_properties.project(&projection_mapping, out_schema); + } Ok(PlanProperties::new( eq_properties, - input.output_partitioning().clone(), // Output Partitioning - input.execution_mode(), // Execution Mode + output_partitioning, + input.execution_mode(), )) } } @@ -222,7 +288,25 @@ impl DisplayAs for FilterExec { ) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, "FilterExec: {}", self.predicate) + let display_projections = if let Some(projection) = + self.projection.as_ref() + { + format!( + ", projection=[{}]", + projection + .iter() + .map(|index| format!( + "{}@{}", + self.input.schema().fields().get(*index).unwrap().name(), + index + )) + .collect::>() + .join(", ") + ) + } else { + "".to_string() + }; + write!(f, "FilterExec: {}{}", self.predicate, display_projections) } } } @@ -242,12 +326,12 @@ impl ExecutionPlan for FilterExec { &self.cache } - fn children(&self) -> Vec> { - vec![self.input.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.input] } fn maintains_input_order(&self) -> Vec { - // tell optimizer this operator doesn't reorder its input + // Tell optimizer this operator doesn't reorder its input vec![true] } @@ -255,11 +339,12 @@ impl ExecutionPlan for FilterExec { self: Arc, mut children: Vec>, ) -> Result> { - FilterExec::try_new(self.predicate.clone(), children.swap_remove(0)) + FilterExec::try_new(Arc::clone(&self.predicate), children.swap_remove(0)) .and_then(|e| { let selectivity = e.default_selectivity(); e.with_default_selectivity(selectivity) }) + .and_then(|e| e.with_projection(self.projection().cloned())) .map(|e| Arc::new(e) as _) } @@ -271,10 +356,11 @@ impl ExecutionPlan for FilterExec { trace!("Start FilterExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); Ok(Box::pin(FilterExecStream { - schema: self.input.schema(), - predicate: self.predicate.clone(), + schema: self.schema(), + predicate: Arc::clone(&self.predicate), input: self.input.execute(partition, context)?, baseline_metrics, + projection: self.projection.clone(), })) } @@ -285,7 +371,16 @@ impl ExecutionPlan for FilterExec { /// The output statistics of a filtering operation can be estimated if the /// predicate's selectivity value can be determined for the incoming data. fn statistics(&self) -> Result { - Self::statistics_helper(&self.input, self.predicate(), self.default_selectivity) + let stats = Self::statistics_helper( + &self.input, + self.predicate(), + self.default_selectivity, + )?; + Ok(stats.project(self.projection.as_ref())) + } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::LowerEqual } } @@ -316,7 +411,7 @@ fn collect_new_statistics( (Precision::Inexact(lower), Precision::Inexact(upper)) }; ColumnStatistics { - null_count: input_column_stats[idx].null_count.clone().to_inexact(), + null_count: input_column_stats[idx].null_count.to_inexact(), max_value, min_value, distinct_count: distinct_count.to_inexact(), @@ -329,27 +424,55 @@ fn collect_new_statistics( /// The FilterExec streams wraps the input iterator and applies the predicate expression to /// determine which rows to include in its output batches struct FilterExecStream { - /// Output schema, which is the same as the input schema for this operator + /// Output schema after the projection schema: SchemaRef, /// The expression to filter on. This expression must evaluate to a boolean value. predicate: Arc, /// The input partition to filter. input: SendableRecordBatchStream, - /// runtime metrics recording + /// Runtime metrics recording baseline_metrics: BaselineMetrics, + /// The projection indices of the columns in the input schema + projection: Option>, } -pub(crate) fn batch_filter( +pub fn batch_filter( batch: &RecordBatch, predicate: &Arc, +) -> Result { + filter_and_project(batch, predicate, None, &batch.schema()) +} + +fn filter_and_project( + batch: &RecordBatch, + predicate: &Arc, + projection: Option<&Vec>, + output_schema: &SchemaRef, ) -> Result { predicate .evaluate(batch) .and_then(|v| v.into_array(batch.num_rows())) .and_then(|array| { - Ok(as_boolean_array(&array)?) - // apply filter array to record batch - .and_then(|filter_array| Ok(filter_record_batch(batch, filter_array)?)) + Ok(match (as_boolean_array(&array), projection) { + // Apply filter array to record batch + (Ok(filter_array), None) => filter_record_batch(batch, filter_array)?, + (Ok(filter_array), Some(projection)) => { + let projected_columns = projection + .iter() + .map(|i| Arc::clone(batch.column(*i))) + .collect(); + let projected_batch = RecordBatch::try_new( + Arc::clone(output_schema), + projected_columns, + )?; + filter_record_batch(&projected_batch, filter_array)? + } + (Err(_), _) => { + return internal_err!( + "Cannot create filter_array from non-boolean predicates" + ); + } + }) }) } @@ -362,26 +485,25 @@ impl Stream for FilterExecStream { ) -> Poll> { let poll; loop { - match self.input.poll_next_unpin(cx) { - Poll::Ready(value) => match value { - Some(Ok(batch)) => { - let timer = self.baseline_metrics.elapsed_compute().timer(); - let filtered_batch = batch_filter(&batch, &self.predicate)?; - // skip entirely filtered batches - if filtered_batch.num_rows() == 0 { - continue; - } - timer.done(); - poll = Poll::Ready(Some(Ok(filtered_batch))); - break; - } - _ => { - poll = Poll::Ready(value); - break; + match ready!(self.input.poll_next_unpin(cx)) { + Some(Ok(batch)) => { + let timer = self.baseline_metrics.elapsed_compute().timer(); + let filtered_batch = filter_and_project( + &batch, + &self.predicate, + self.projection.as_ref(), + &self.schema, + )?; + timer.done(); + // Skip entirely filtered batches + if filtered_batch.num_rows() == 0 { + continue; } - }, - Poll::Pending => { - poll = Poll::Pending; + poll = Poll::Ready(Some(Ok(filtered_batch))); + break; + } + value => { + poll = Poll::Ready(value); break; } } @@ -390,14 +512,14 @@ impl Stream for FilterExecStream { } fn size_hint(&self) -> (usize, Option) { - // same number of record batches + // Same number of record batches self.input.size_hint() } } impl RecordBatchStream for FilterExecStream { fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -433,13 +555,12 @@ pub type EqualAndNonEqual<'a> = #[cfg(test)] mod tests { - use super::*; + use crate::empty::EmptyExec; use crate::expressions::*; use crate::test; use crate::test::exec::StatisticsExec; - use crate::empty::EmptyExec; use arrow::datatypes::{Field, Schema}; use arrow_schema::{UnionFields, UnionMode}; use datafusion_common::ScalarValue; @@ -1117,7 +1238,7 @@ mod tests { binary(col("c1", &schema)?, Operator::LtEq, lit(4i32), &schema)?, &schema, )?, - Arc::new(EmptyExec::new(schema.clone())), + Arc::new(EmptyExec::new(Arc::clone(&schema))), )?; exec.statistics().unwrap(); diff --git a/datafusion/physical-plan/src/insert.rs b/datafusion/physical-plan/src/insert.rs index 259db644ae0a..ae8a2acce696 100644 --- a/datafusion/physical-plan/src/insert.rs +++ b/datafusion/physical-plan/src/insert.rs @@ -23,8 +23,8 @@ use std::fmt::Debug; use std::sync::Arc; use super::{ - DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, Partitioning, - PlanProperties, SendableRecordBatchStream, + execute_input_stream, DisplayAs, DisplayFormatType, ExecutionPlan, + ExecutionPlanProperties, Partitioning, PlanProperties, SendableRecordBatchStream, }; use crate::metrics::MetricsSet; use crate::stream::RecordBatchStreamAdapter; @@ -33,13 +33,12 @@ use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use arrow_array::{ArrayRef, UInt64Array}; use arrow_schema::{DataType, Field, Schema}; -use datafusion_common::{exec_err, internal_err, Result}; +use datafusion_common::{internal_err, Result}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{ - Distribution, EquivalenceProperties, PhysicalSortRequirement, -}; +use datafusion_physical_expr::{Distribution, EquivalenceProperties}; use async_trait::async_trait; +use datafusion_physical_expr_common::sort_expr::LexRequirement; use futures::StreamExt; /// `DataSink` implements writing streams of [`RecordBatch`]es to @@ -74,12 +73,10 @@ pub trait DataSink: DisplayAs + Debug + Send + Sync { ) -> Result; } -#[deprecated(since = "38.0.0", note = "Use [`DataSinkExec`] instead")] -pub type FileSinkExec = DataSinkExec; - /// Execution plan for writing record batches to a [`DataSink`] /// /// Returns a single row with the number of values written +#[derive(Clone)] pub struct DataSinkExec { /// Input plan that produces the record batches to be written. input: Arc, @@ -90,11 +87,11 @@ pub struct DataSinkExec { /// Schema describing the structure of the output data. count_schema: SchemaRef, /// Optional required sort order for output data. - sort_order: Option>, + sort_order: Option, cache: PlanProperties, } -impl fmt::Debug for DataSinkExec { +impl Debug for DataSinkExec { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "DataSinkExec schema: {:?}", self.count_schema) } @@ -106,7 +103,7 @@ impl DataSinkExec { input: Arc, sink: Arc, sink_schema: SchemaRef, - sort_order: Option>, + sort_order: Option, ) -> Self { let count_schema = make_count_schema(); let cache = Self::create_schema(&input, count_schema); @@ -120,46 +117,6 @@ impl DataSinkExec { } } - fn execute_input_stream( - &self, - partition: usize, - context: Arc, - ) -> Result { - let input_stream = self.input.execute(partition, context)?; - - debug_assert_eq!( - self.sink_schema.fields().len(), - self.input.schema().fields().len() - ); - - // Find input columns that may violate the not null constraint. - let risky_columns: Vec<_> = self - .sink_schema - .fields() - .iter() - .zip(self.input.schema().fields().iter()) - .enumerate() - .filter_map(|(i, (sink_field, input_field))| { - if !sink_field.is_nullable() && input_field.is_nullable() { - Some(i) - } else { - None - } - }) - .collect(); - - if risky_columns.is_empty() { - Ok(input_stream) - } else { - // Check not null constraint on the input stream - Ok(Box::pin(RecordBatchStreamAdapter::new( - self.sink_schema.clone(), - input_stream - .map(move |batch| check_not_null_contraits(batch?, &risky_columns)), - ))) - } - } - /// Input execution plan pub fn input(&self) -> &Arc { &self.input @@ -171,15 +128,10 @@ impl DataSinkExec { } /// Optional sort order for output data - pub fn sort_order(&self) -> &Option> { + pub fn sort_order(&self) -> &Option { &self.sort_order } - /// Returns the metrics of the underlying [DataSink] - pub fn metrics(&self) -> Option { - self.sink.metrics() - } - fn create_schema( input: &Arc, schema: SchemaRef, @@ -194,11 +146,7 @@ impl DataSinkExec { } impl DisplayAs for DataSinkExec { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { write!(f, "DataSinkExec: sink=")?; @@ -234,7 +182,7 @@ impl ExecutionPlan for DataSinkExec { vec![Distribution::SinglePartition; self.children().len()] } - fn required_input_ordering(&self) -> Vec>> { + fn required_input_ordering(&self) -> Vec> { // The required input ordering is set externally (e.g. by a `ListingTable`). // Otherwise, there is no specific requirement (i.e. `sort_expr` is `None`). vec![self.sort_order.as_ref().cloned()] @@ -248,8 +196,8 @@ impl ExecutionPlan for DataSinkExec { vec![true] } - fn children(&self) -> Vec> { - vec![self.input.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.input] } fn with_new_children( @@ -257,9 +205,9 @@ impl ExecutionPlan for DataSinkExec { children: Vec>, ) -> Result> { Ok(Arc::new(Self::new( - children[0].clone(), - self.sink.clone(), - self.sink_schema.clone(), + Arc::clone(&children[0]), + Arc::clone(&self.sink), + Arc::clone(&self.sink_schema), self.sort_order.clone(), ))) } @@ -274,10 +222,15 @@ impl ExecutionPlan for DataSinkExec { if partition != 0 { return internal_err!("DataSinkExec can only be called on partition 0!"); } - let data = self.execute_input_stream(0, context.clone())?; + let data = execute_input_stream( + Arc::clone(&self.input), + Arc::clone(&self.sink_schema), + 0, + Arc::clone(&context), + )?; - let count_schema = self.count_schema.clone(); - let sink = self.sink.clone(); + let count_schema = Arc::clone(&self.count_schema); + let sink = Arc::clone(&self.sink); let stream = futures::stream::once(async move { sink.write_all(data, &context).await.map(make_count_batch) @@ -289,6 +242,11 @@ impl ExecutionPlan for DataSinkExec { stream, ))) } + + /// Returns the metrics of the underlying [DataSink] + fn metrics(&self) -> Option { + self.sink.metrics() + } } /// Create a output record batch with a count @@ -307,34 +265,10 @@ fn make_count_batch(count: u64) -> RecordBatch { } fn make_count_schema() -> SchemaRef { - // define a schema. + // Define a schema. Arc::new(Schema::new(vec![Field::new( "count", DataType::UInt64, false, )])) } - -fn check_not_null_contraits( - batch: RecordBatch, - column_indices: &Vec, -) -> Result { - for &index in column_indices { - if batch.num_columns() <= index { - return exec_err!( - "Invalid batch column count {} expected > {}", - batch.num_columns(), - index - ); - } - - if batch.column(index).null_count() > 0 { - return exec_err!( - "Invalid batch column at '{}' has null but schema specifies non-nullable", - index - ); - } - } - - Ok(batch) -} diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index 9d1de3715f54..8c8921eba6a1 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -18,13 +18,11 @@ //! Defines the cross join plan for loading the left side of the cross join //! and producing batches in parallel for the right partitions -use std::{any::Any, sync::Arc, task::Poll}; - use super::utils::{ - adjust_right_output_partitioning, BuildProbeJoinMetrics, OnceAsync, OnceFut, + adjust_right_output_partitioning, BatchSplitter, BatchTransformer, + BuildProbeJoinMetrics, NoopBatchTransformer, OnceAsync, OnceFut, StatefulStreamResult, }; -use crate::coalesce_batches::concat_batches; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::{ @@ -33,6 +31,8 @@ use crate::{ ExecutionPlanProperties, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, }; +use arrow::compute::concat_batches; +use std::{any::Any, sync::Arc, task::Poll}; use arrow::datatypes::{Fields, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; @@ -47,10 +47,23 @@ use async_trait::async_trait; use futures::{ready, Stream, StreamExt, TryStreamExt}; /// Data of the left side -type JoinLeftData = (RecordBatch, MemoryReservation); +#[derive(Debug)] +struct JoinLeftData { + /// Single RecordBatch with all rows from the left side + merged_batch: RecordBatch, + /// Track memory reservation for merged_batch. Relies on drop + /// semantics to release reservation when JoinLeftData is dropped. + #[allow(dead_code)] + reservation: MemoryReservation, +} +#[allow(rustdoc::private_intra_doc_links)] /// executes partitions in parallel and combines them into a set of /// partitions by combining all values from the left with all values on the right +/// +/// Note that the `Clone` trait is not implemented for this struct due to the +/// `left_fut` [`OnceAsync`], which is used to coordinate the loading of the +/// left side with the processing in each output stream. #[derive(Debug)] pub struct CrossJoinExec { /// left (build) side which gets loaded in memory @@ -70,16 +83,24 @@ impl CrossJoinExec { /// Create a new [CrossJoinExec]. pub fn new(left: Arc, right: Arc) -> Self { // left then right - let all_columns: Fields = { + let (all_columns, metadata) = { let left_schema = left.schema(); let right_schema = right.schema(); let left_fields = left_schema.fields().iter(); let right_fields = right_schema.fields().iter(); - left_fields.chain(right_fields).cloned().collect() + + let mut metadata = left_schema.metadata().clone(); + metadata.extend(right_schema.metadata().clone()); + + ( + left_fields.chain(right_fields).cloned().collect::(), + metadata, + ) }; - let schema = Arc::new(Schema::new(all_columns)); - let cache = Self::compute_properties(&left, &right, schema.clone()); + let schema = Arc::new(Schema::new(all_columns).with_metadata(metadata)); + let cache = Self::compute_properties(&left, &right, Arc::clone(&schema)); + CrossJoinExec { left, right, @@ -155,29 +176,27 @@ async fn load_left_input( let stream = merge.execute(0, context)?; // Load all batches and count the rows - let (batches, num_rows, _, reservation) = stream - .try_fold( - (Vec::new(), 0usize, metrics, reservation), - |mut acc, batch| async { - let batch_size = batch.get_array_memory_size(); - // Reserve memory for incoming batch - acc.3.try_grow(batch_size)?; - // Update metrics - acc.2.build_mem_used.add(batch_size); - acc.2.build_input_batches.add(1); - acc.2.build_input_rows.add(batch.num_rows()); - // Update rowcount - acc.1 += batch.num_rows(); - // Push batch to output - acc.0.push(batch); - Ok(acc) - }, - ) + let (batches, _metrics, reservation) = stream + .try_fold((Vec::new(), metrics, reservation), |mut acc, batch| async { + let batch_size = batch.get_array_memory_size(); + // Reserve memory for incoming batch + acc.2.try_grow(batch_size)?; + // Update metrics + acc.1.build_mem_used.add(batch_size); + acc.1.build_input_batches.add(1); + acc.1.build_input_rows.add(batch.num_rows()); + // Push batch to output + acc.0.push(batch); + Ok(acc) + }) .await?; - let merged_batch = concat_batches(&left_schema, &batches, num_rows)?; + let merged_batch = concat_batches(&left_schema, &batches)?; - Ok((merged_batch, reservation)) + Ok(JoinLeftData { + merged_batch, + reservation, + }) } impl DisplayAs for CrossJoinExec { @@ -207,8 +226,8 @@ impl ExecutionPlan for CrossJoinExec { &self.cache } - fn children(&self) -> Vec> { - vec![self.left.clone(), self.right.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.left, &self.right] } fn metrics(&self) -> Option { @@ -220,8 +239,8 @@ impl ExecutionPlan for CrossJoinExec { children: Vec>, ) -> Result> { Ok(Arc::new(CrossJoinExec::new( - children[0].clone(), - children[1].clone(), + Arc::clone(&children[0]), + Arc::clone(&children[1]), ))) } @@ -237,7 +256,7 @@ impl ExecutionPlan for CrossJoinExec { partition: usize, context: Arc, ) -> Result { - let stream = self.right.execute(partition, context.clone())?; + let stream = self.right.execute(partition, Arc::clone(&context))?; let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics); @@ -245,24 +264,42 @@ impl ExecutionPlan for CrossJoinExec { let reservation = MemoryConsumer::new("CrossJoinExec").register(context.memory_pool()); + let batch_size = context.session_config().batch_size(); + let enforce_batch_size_in_joins = + context.session_config().enforce_batch_size_in_joins(); + let left_fut = self.left_fut.once(|| { load_left_input( - self.left.clone(), + Arc::clone(&self.left), context, join_metrics.clone(), reservation, ) }); - Ok(Box::pin(CrossJoinStream { - schema: self.schema.clone(), - left_fut, - right: stream, - left_index: 0, - join_metrics, - state: CrossJoinStreamState::WaitBuildSide, - left_data: RecordBatch::new_empty(self.left().schema()), - })) + if enforce_batch_size_in_joins { + Ok(Box::pin(CrossJoinStream { + schema: Arc::clone(&self.schema), + left_fut, + right: stream, + left_index: 0, + join_metrics, + state: CrossJoinStreamState::WaitBuildSide, + left_data: RecordBatch::new_empty(self.left().schema()), + batch_transformer: BatchSplitter::new(batch_size), + })) + } else { + Ok(Box::pin(CrossJoinStream { + schema: Arc::clone(&self.schema), + left_fut, + right: stream, + left_index: 0, + join_metrics, + state: CrossJoinStreamState::WaitBuildSide, + left_data: RecordBatch::new_empty(self.left().schema()), + batch_transformer: NoopBatchTransformer::new(), + })) + } } fn statistics(&self) -> Result { @@ -318,7 +355,7 @@ fn stats_cartesian_product( } /// A stream that issues [RecordBatch]es as they arrive from the right of the join. -struct CrossJoinStream { +struct CrossJoinStream { /// Input schema schema: Arc, /// Future for data from left side @@ -331,13 +368,15 @@ struct CrossJoinStream { join_metrics: BuildProbeJoinMetrics, /// State of the stream state: CrossJoinStreamState, - /// Left data + /// Left data (copy of the entire buffered left side) left_data: RecordBatch, + /// Batch transformer + batch_transformer: T, } -impl RecordBatchStream for CrossJoinStream { +impl RecordBatchStream for CrossJoinStream { fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -389,24 +428,24 @@ fn build_batch( } #[async_trait] -impl Stream for CrossJoinStream { +impl Stream for CrossJoinStream { type Item = Result; fn poll_next( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + ) -> Poll> { self.poll_next_impl(cx) } } -impl CrossJoinStream { +impl CrossJoinStream { /// Separate implementation function that unpins the [`CrossJoinStream`] so /// that partial borrows work correctly fn poll_next_impl( &mut self, cx: &mut std::task::Context<'_>, - ) -> std::task::Poll>> { + ) -> Poll>> { loop { return match self.state { CrossJoinStreamState::WaitBuildSide => { @@ -429,16 +468,17 @@ impl CrossJoinStream { cx: &mut std::task::Context<'_>, ) -> Poll>>> { let build_timer = self.join_metrics.build_time.timer(); - let (left_data, _) = match ready!(self.left_fut.get(cx)) { + let left_data = match ready!(self.left_fut.get(cx)) { Ok(left_data) => left_data, Err(e) => return Poll::Ready(Err(e)), }; build_timer.done(); + let left_data = left_data.merged_batch.clone(); let result = if left_data.num_rows() == 0 { StatefulStreamResult::Ready(None) } else { - self.left_data = left_data.clone(); + self.left_data = left_data; self.state = CrossJoinStreamState::FetchProbeBatch; StatefulStreamResult::Continue }; @@ -469,21 +509,33 @@ impl CrossJoinStream { fn build_batches(&mut self) -> Result>> { let right_batch = self.state.try_as_record_batch()?; if self.left_index < self.left_data.num_rows() { - let join_timer = self.join_metrics.join_time.timer(); - let result = - build_batch(self.left_index, right_batch, &self.left_data, &self.schema); - join_timer.done(); - - if let Ok(ref batch) = result { - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); + match self.batch_transformer.next() { + None => { + let join_timer = self.join_metrics.join_time.timer(); + let result = build_batch( + self.left_index, + right_batch, + &self.left_data, + &self.schema, + ); + join_timer.done(); + + self.batch_transformer.set_batch(result?); + } + Some((batch, last)) => { + if last { + self.left_index += 1; + } + + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + return Ok(StatefulStreamResult::Ready(Some(batch))); + } } - self.left_index += 1; - result.map(|r| StatefulStreamResult::Ready(Some(r))) } else { self.state = CrossJoinStreamState::FetchProbeBatch; - Ok(StatefulStreamResult::Continue) } + Ok(StatefulStreamResult::Continue) } } @@ -494,7 +546,7 @@ mod tests { use crate::test::build_table_scan_i32; use datafusion_common::{assert_batches_sorted_eq, assert_contains}; - use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use datafusion_execution::runtime_env::RuntimeEnvBuilder; async fn join_collect( left: Arc, @@ -578,7 +630,7 @@ mod tests { } #[tokio::test] - async fn test_stats_cartesian_product_with_unknwon_size() { + async fn test_stats_cartesian_product_with_unknown_size() { let left_row_count = 11; let left = Statistics { @@ -679,8 +731,9 @@ mod tests { #[tokio::test] async fn test_overallocation() -> Result<()> { - let runtime_config = RuntimeConfig::new().with_memory_limit(100, 1.0); - let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(100, 1.0) + .build_arc()?; let task_ctx = TaskContext::default().with_runtime(runtime); let task_ctx = Arc::new(task_ctx); @@ -699,9 +752,8 @@ mod tests { assert_contains!( err.to_string(), - "External error: Resources exhausted: Failed to allocate additional" + "External error: Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: CrossJoinExec" ); - assert_contains!(err.to_string(), "CrossJoinExec"); Ok(()) } diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index d3abedbe3806..ae872e13a9f6 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -22,8 +22,9 @@ use std::mem::size_of; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::task::Poll; -use std::{any::Any, usize, vec}; +use std::{any::Any, vec}; +use super::utils::asymmetric_join_output_partitioning; use super::{ utils::{OnceAsync, OnceFut}, PartitionMode, @@ -35,10 +36,10 @@ use crate::{ execution_mode_from_children, handle_state, hash_utils::create_hashes, joins::utils::{ - adjust_indices_by_join_type, adjust_right_output_partitioning, - apply_join_filter_to_indices, build_batch_from_indices, build_join_schema, - check_join_is_valid, estimate_join_statistics, get_final_indices_from_bit_map, - need_produce_result_in_final, partitioned_join_output_partitioning, + adjust_indices_by_join_type, apply_join_filter_to_indices, + build_batch_from_indices, build_join_schema, check_join_is_valid, + estimate_join_statistics, get_final_indices_from_bit_map, + need_produce_result_in_final, symmetric_join_output_partitioning, BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinHashMap, JoinHashMapOffset, JoinHashMapType, JoinOn, JoinOnRef, StatefulStreamResult, }, @@ -49,8 +50,7 @@ use crate::{ }; use arrow::array::{ - Array, ArrayRef, BooleanArray, BooleanBufferBuilder, PrimitiveArray, UInt32Array, - UInt64Array, + Array, ArrayRef, BooleanArray, BooleanBufferBuilder, UInt32Array, UInt64Array, }; use arrow::compute::kernels::cmp::{eq, not_distinct}; use arrow::compute::{and, concat_batches, take, FilterBuilder}; @@ -59,6 +59,7 @@ use arrow::record_batch::RecordBatch; use arrow::util::bit_util; use arrow_array::cast::downcast_array; use arrow_schema::ArrowError; +use datafusion_common::utils::memory::estimate_memory_size; use datafusion_common::{ internal_datafusion_err, internal_err, plan_err, project_schema, DataFusionError, JoinSide, JoinType, Result, @@ -68,10 +69,11 @@ use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::{ join_equivalence_properties, ProjectionMapping, }; -use datafusion_physical_expr::expressions::UnKnownColumn; -use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef}; +use datafusion_physical_expr::PhysicalExprRef; use ahash::RandomState; +use datafusion_expr::Operator; +use datafusion_physical_expr_common::datum::compare_op_for_nested; use futures::{ready, Stream, StreamExt, TryStreamExt}; use parking_lot::Mutex; @@ -134,6 +136,7 @@ impl JoinLeftData { } } +#[allow(rustdoc::private_intra_doc_links)] /// Join execution plan: Evaluates eqijoin predicates in parallel on multiple /// partitions using a hash table and an optional filter list to apply post /// join. @@ -177,9 +180,9 @@ impl JoinLeftData { /// Execution proceeds in 2 stages: /// /// 1. the **build phase** creates a hash table from the tuples of the build side, -/// and single concatenated batch containing data from all fetched record batches. -/// Resulting hash table stores hashed join-key fields for each row as a key, and -/// indices of corresponding rows in concatenated batch. +/// and single concatenated batch containing data from all fetched record batches. +/// Resulting hash table stores hashed join-key fields for each row as a key, and +/// indices of corresponding rows in concatenated batch. /// /// Hash join uses LIFO data structure as a hash table, and in order to retain /// original build-side input order while obtaining data during probe phase, hash @@ -220,7 +223,7 @@ impl JoinLeftData { /// ``` /// /// 2. the **probe phase** where the tuples of the probe side are streamed -/// through, checking for matches of the join keys in the hash table. +/// through, checking for matches of the join keys in the hash table. /// /// ```text /// ┌────────────────┐ ┌────────────────┐ @@ -291,6 +294,10 @@ impl JoinLeftData { /// │ "dimension" │ │ "fact" │ /// └───────────────┘ └───────────────┘ /// ``` +/// +/// Note that the `Clone` trait is not implemented for this struct due to the +/// `left_fut` [`OnceAsync`], which is used to coordinate the loading of the +/// left side with the processing in each output stream. #[derive(Debug)] pub struct HashJoinExec { /// left (build) side which gets hashed @@ -364,7 +371,7 @@ impl HashJoinExec { let cache = Self::compute_properties( &left, &right, - join_schema.clone(), + Arc::clone(&join_schema), *join_type, &on, partition_mode, @@ -414,6 +421,12 @@ impl HashJoinExec { &self.join_type } + /// The schema after join. Please be careful when using this schema, + /// if there is a projection, the schema isn't the same as the output schema. + pub fn join_schema(&self) -> &SchemaRef { + &self.join_schema + } + /// The partitioning mode of this hash join pub fn partition_mode(&self) -> &PartitionMode { &self.mode @@ -430,7 +443,10 @@ impl HashJoinExec { false, matches!( join_type, - JoinType::Inner | JoinType::RightAnti | JoinType::RightSemi + JoinType::Inner + | JoinType::Right + | JoinType::RightAnti + | JoinType::RightSemi ), ] } @@ -458,8 +474,8 @@ impl HashJoinExec { None => None, }; Self::try_new( - self.left.clone(), - self.right.clone(), + Arc::clone(&self.left), + Arc::clone(&self.right), self.on.clone(), self.filter.clone(), &self.join_type, @@ -484,39 +500,22 @@ impl HashJoinExec { left.equivalence_properties().clone(), right.equivalence_properties().clone(), &join_type, - schema.clone(), + Arc::clone(&schema), &Self::maintains_input_order(join_type), Some(Self::probe_side()), on, ); - // Get output partitioning: - let left_columns_len = left.schema().fields.len(); let mut output_partitioning = match mode { - PartitionMode::CollectLeft => match join_type { - JoinType::Inner | JoinType::Right => adjust_right_output_partitioning( - right.output_partitioning(), - left_columns_len, - ), - JoinType::RightSemi | JoinType::RightAnti => { - right.output_partitioning().clone() - } - JoinType::Left - | JoinType::LeftSemi - | JoinType::LeftAnti - | JoinType::Full => Partitioning::UnknownPartitioning( - right.output_partitioning().partition_count(), - ), - }, - PartitionMode::Partitioned => partitioned_join_output_partitioning( - join_type, - left.output_partitioning(), - right.output_partitioning(), - left_columns_len, - ), + PartitionMode::CollectLeft => { + asymmetric_join_output_partitioning(left, right, &join_type) + } PartitionMode::Auto => Partitioning::UnknownPartitioning( right.output_partitioning().partition_count(), ), + PartitionMode::Partitioned => { + symmetric_join_output_partitioning(left, right, &join_type) + } }; // Determine execution mode by checking whether this join is pipeline @@ -530,6 +529,7 @@ impl HashJoinExec { | JoinType::Full | JoinType::LeftAnti | JoinType::LeftSemi + | JoinType::LeftMark )); let mode = if pipeline_breaking { @@ -540,24 +540,12 @@ impl HashJoinExec { // If contains projection, update the PlanProperties. if let Some(projection) = projection { - let projection_exprs = project_index_to_exprs(projection, &schema); // construct a map from the input expressions to the output expression of the Projection let projection_mapping = - ProjectionMapping::try_new(&projection_exprs, &schema)?; + ProjectionMapping::from_indices(projection, &schema)?; let out_schema = project_schema(&schema, Some(projection))?; - if let Partitioning::Hash(exprs, part) = output_partitioning { - let normalized_exprs = exprs - .iter() - .map(|expr| { - eq_properties - .project_expr(expr, &projection_mapping) - .unwrap_or_else(|| { - Arc::new(UnKnownColumn::new(&expr.to_string())) - }) - }) - .collect(); - output_partitioning = Partitioning::Hash(normalized_exprs, part); - } + output_partitioning = + output_partitioning.project(&projection_mapping, &eq_properties); eq_properties = eq_properties.project(&projection_mapping, out_schema); } Ok(PlanProperties::new( @@ -610,25 +598,6 @@ impl DisplayAs for HashJoinExec { } } -fn project_index_to_exprs( - projection_index: &[usize], - schema: &SchemaRef, -) -> Vec<(Arc, String)> { - projection_index - .iter() - .map(|index| { - let field = schema.field(*index); - ( - Arc::new(datafusion_physical_expr::expressions::Column::new( - field.name(), - *index, - )) as Arc, - field.name().to_owned(), - ) - }) - .collect::>() -} - impl ExecutionPlan for HashJoinExec { fn name(&self) -> &'static str { "HashJoinExec" @@ -649,8 +618,11 @@ impl ExecutionPlan for HashJoinExec { Distribution::UnspecifiedDistribution, ], PartitionMode::Partitioned => { - let (left_expr, right_expr) = - self.on.iter().map(|(l, r)| (l.clone(), r.clone())).unzip(); + let (left_expr, right_expr) = self + .on + .iter() + .map(|(l, r)| (Arc::clone(l), Arc::clone(r))) + .unzip(); vec![ Distribution::HashPartitioned(left_expr), Distribution::HashPartitioned(right_expr), @@ -683,8 +655,8 @@ impl ExecutionPlan for HashJoinExec { Self::maintains_input_order(self.join_type) } - fn children(&self) -> Vec> { - vec![self.left.clone(), self.right.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.left, &self.right] } fn with_new_children( @@ -692,8 +664,8 @@ impl ExecutionPlan for HashJoinExec { children: Vec>, ) -> Result> { Ok(Arc::new(HashJoinExec::try_new( - children[0].clone(), - children[1].clone(), + Arc::clone(&children[0]), + Arc::clone(&children[1]), self.on.clone(), self.filter.clone(), &self.join_type, @@ -708,8 +680,16 @@ impl ExecutionPlan for HashJoinExec { partition: usize, context: Arc, ) -> Result { - let on_left = self.on.iter().map(|on| on.0.clone()).collect::>(); - let on_right = self.on.iter().map(|on| on.1.clone()).collect::>(); + let on_left = self + .on + .iter() + .map(|on| Arc::clone(&on.0)) + .collect::>(); + let on_right = self + .on + .iter() + .map(|on| Arc::clone(&on.1)) + .collect::>(); let left_partitions = self.left.output_partitioning().partition_count(); let right_partitions = self.right.output_partitioning().partition_count(); @@ -729,9 +709,9 @@ impl ExecutionPlan for HashJoinExec { collect_left_input( None, self.random_state.clone(), - self.left.clone(), + Arc::clone(&self.left), on_left.clone(), - context.clone(), + Arc::clone(&context), join_metrics.clone(), reservation, need_produce_result_in_final(self.join_type), @@ -746,9 +726,9 @@ impl ExecutionPlan for HashJoinExec { OnceFut::new(collect_left_input( Some(partition), self.random_state.clone(), - self.left.clone(), + Arc::clone(&self.left), on_left.clone(), - context.clone(), + Arc::clone(&context), join_metrics.clone(), reservation, need_produce_result_in_final(self.join_type), @@ -793,6 +773,7 @@ impl ExecutionPlan for HashJoinExec { build_side: BuildSide::Initial(BuildSideInitialState { left_fut }), batch_size, hashes_buffer: vec![], + right_side_ordered: self.right.output_ordering().is_some(), })) } @@ -804,24 +785,15 @@ impl ExecutionPlan for HashJoinExec { // TODO stats: it is not possible in general to know the output size of joins // There are some special cases though, for example: // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)` - let mut stats = estimate_join_statistics( - self.left.clone(), - self.right.clone(), + let stats = estimate_join_statistics( + Arc::clone(&self.left), + Arc::clone(&self.right), self.on.clone(), &self.join_type, &self.join_schema, )?; // Project statistics if there is a projection - if let Some(projection) = &self.projection { - stats.column_statistics = stats - .column_statistics - .into_iter() - .enumerate() - .filter(|(i, _)| projection.contains(i)) - .map(|(_, s)| s) - .collect(); - } - Ok(stats) + Ok(stats.project(self.projection.as_ref())) } } @@ -850,7 +822,7 @@ async fn collect_left_input( }; // Depending on partition argument load single partition or whole left side in memory - let stream = left_input.execute(left_input_partition, context.clone())?; + let stream = left_input.execute(left_input_partition, Arc::clone(&context))?; // This operation performs 2 steps at once: // 1. creates a [JoinHashMap] of all batches from the stream @@ -875,23 +847,12 @@ async fn collect_left_input( // Estimation of memory size, required for hashtable, prior to allocation. // Final result can be verified using `RawTable.allocation_info()` - // - // For majority of cases hashbrown overestimates buckets qty to keep ~1/8 of them empty. - // This formula leads to overallocation for small tables (< 8 elements) but fine overall. - let estimated_buckets = (num_rows.checked_mul(8).ok_or_else(|| { - DataFusionError::Execution( - "usize overflow while estimating number of hasmap buckets".to_string(), - ) - })? / 7) - .next_power_of_two(); - // 16 bytes per `(u64, u64)` - // + 1 byte for each bucket - // + fixed size of JoinHashMap (RawTable + Vec) - let estimated_hastable_size = - 16 * estimated_buckets + estimated_buckets + size_of::(); + let fixed_size = size_of::(); + let estimated_hashtable_size = + estimate_memory_size::<(u64, u64)>(num_rows, fixed_size)?; - reservation.try_grow(estimated_hastable_size)?; - metrics.build_mem_used.add(estimated_hastable_size); + reservation.try_grow(estimated_hashtable_size)?; + metrics.build_mem_used.add(estimated_hashtable_size); let mut hashmap = JoinHashMap::with_capacity(num_rows); let mut hashes_buffer = Vec::new(); @@ -1102,7 +1063,7 @@ impl ProcessProbeBatchState { /// 1. Reads the entire left input (build) and constructs a hash table /// /// 2. Streams [RecordBatch]es as they arrive from the right input (probe) and joins -/// them with the contents of the hash table +/// them with the contents of the hash table struct HashJoinStream { /// Input schema schema: Arc, @@ -1132,11 +1093,13 @@ struct HashJoinStream { batch_size: usize, /// Scratch space for computing hashes hashes_buffer: Vec, + /// Specifies whether the right side has an ordering to potentially preserve + right_side_ordered: bool, } impl RecordBatchStream for HashJoinStream { fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -1212,13 +1175,11 @@ fn lookup_join_hashmap( }) .collect::>>()?; - let (mut probe_builder, mut build_builder, next_offset) = build_hashmap + let (probe_indices, build_indices, next_offset) = build_hashmap .get_matched_indices_with_limit_offset(hashes_buffer, None, limit, offset); - let build_indices: UInt64Array = - PrimitiveArray::new(build_builder.finish().into(), None); - let probe_indices: UInt32Array = - PrimitiveArray::new(probe_builder.finish().into(), None); + let build_indices: UInt64Array = build_indices.into(); + let probe_indices: UInt32Array = probe_indices.into(); let (build_indices, probe_indices) = equal_rows_arr( &build_indices, @@ -1237,6 +1198,17 @@ fn eq_dyn_null( right: &dyn Array, null_equals_null: bool, ) -> Result { + // Nested datatypes cannot use the underlying not_distinct/eq function and must use a special + // implementation + // + if left.data_type().is_nested() { + let op = if null_equals_null { + Operator::IsNotDistinctFrom + } else { + Operator::Eq + }; + return Ok(compare_op_for_nested(op, &left, &right)?); + } match (left.data_type(), right.data_type()) { _ if null_equals_null => not_distinct(&left, &right), _ => eq(&left, &right), @@ -1463,7 +1435,8 @@ impl HashJoinStream { right_indices, index_alignment_range_start..index_alignment_range_end, self.join_type, - ); + self.right_side_ordered, + )?; let result = build_batch_from_indices( &self.schema, @@ -1549,30 +1522,32 @@ impl Stream for HashJoinStream { fn poll_next( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + ) -> Poll> { self.poll_next_impl(cx) } } #[cfg(test)] mod tests { - use super::*; use crate::{ common, expressions::Column, memory::MemoryExec, repartition::RepartitionExec, test::build_table_i32, test::exec::MockExec, }; - use arrow::array::{Date32Array, Int32Array, UInt32Builder, UInt64Builder}; + use arrow::array::{Date32Array, Int32Array}; use arrow::datatypes::{DataType, Field}; + use arrow_array::StructArray; + use arrow_buffer::NullBuffer; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, ScalarValue, }; use datafusion_execution::config::SessionConfig; - use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{BinaryExpr, Literal}; + use datafusion_physical_expr::PhysicalExpr; use hashbrown::raw::RawTable; use rstest::*; @@ -1688,8 +1663,10 @@ mod tests { ) -> Result<(Vec, Vec)> { let partition_count = 4; - let (left_expr, right_expr) = - on.iter().map(|(l, r)| (l.clone(), r.clone())).unzip(); + let (left_expr, right_expr) = on + .iter() + .map(|(l, r)| (Arc::clone(l), Arc::clone(r))) + .unzip(); let left_repartitioned: Arc = match partition_mode { PartitionMode::CollectLeft => Arc::new(CoalescePartitionsExec::new(left)), @@ -1738,7 +1715,7 @@ mod tests { let mut batches = vec![]; for i in 0..partition_count { - let stream = join.execute(i, context.clone())?; + let stream = join.execute(i, Arc::clone(&context))?; let more_batches = common::collect(stream).await?; batches.extend( more_batches @@ -1772,8 +1749,8 @@ mod tests { )]; let (columns, batches) = join_collect( - left.clone(), - right.clone(), + Arc::clone(&left), + Arc::clone(&right), on.clone(), &JoinType::Inner, false, @@ -1819,8 +1796,8 @@ mod tests { )]; let (columns, batches) = partitioned_join_collect( - left.clone(), - right.clone(), + Arc::clone(&left), + Arc::clone(&right), on.clone(), &JoinType::Inner, false, @@ -1953,12 +1930,20 @@ mod tests { assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); - // expected joined records = 3 - // in case batch_size is 1 - additional empty batch for remaining 3-2 row - let mut expected_batch_count = div_ceil(3, batch_size); - if batch_size == 1 { - expected_batch_count += 1; - } + let expected_batch_count = if cfg!(not(feature = "force_hash_collisions")) { + // Expected number of hash table matches = 3 + // in case batch_size is 1 - additional empty batch for remaining 3-2 row + let mut expected_batch_count = div_ceil(3, batch_size); + if batch_size == 1 { + expected_batch_count += 1; + } + expected_batch_count + } else { + // With hash collisions enabled, all records will match each other + // and filtered later. + div_ceil(9, batch_size) + }; + assert_eq!(batches.len(), expected_batch_count); let expected = [ @@ -2015,12 +2000,20 @@ mod tests { assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); - // expected joined records = 3 - // in case batch_size is 1 - additional empty batch for remaining 3-2 row - let mut expected_batch_count = div_ceil(3, batch_size); - if batch_size == 1 { - expected_batch_count += 1; - } + let expected_batch_count = if cfg!(not(feature = "force_hash_collisions")) { + // Expected number of hash table matches = 3 + // in case batch_size is 1 - additional empty batch for remaining 3-2 row + let mut expected_batch_count = div_ceil(3, batch_size); + if batch_size == 1 { + expected_batch_count += 1; + } + expected_batch_count + } else { + // With hash collisions enabled, all records will match each other + // and filtered later. + div_ceil(9, batch_size) + }; + assert_eq!(batches.len(), expected_batch_count); let expected = [ @@ -2123,15 +2116,22 @@ mod tests { assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); // first part - let stream = join.execute(0, task_ctx.clone())?; + let stream = join.execute(0, Arc::clone(&task_ctx))?; let batches = common::collect(stream).await?; - // expected joined records = 1 (first right batch) - // and additional empty batch for non-joined 20-6-80 - let mut expected_batch_count = div_ceil(1, batch_size); - if batch_size == 1 { - expected_batch_count += 1; - } + let expected_batch_count = if cfg!(not(feature = "force_hash_collisions")) { + // Expected number of hash table matches for first right batch = 1 + // and additional empty batch for non-joined 20-6-80 + let mut expected_batch_count = div_ceil(1, batch_size); + if batch_size == 1 { + expected_batch_count += 1; + } + expected_batch_count + } else { + // With hash collisions enabled, all records will match each other + // and filtered later. + div_ceil(6, batch_size) + }; assert_eq!(batches.len(), expected_batch_count); let expected = [ @@ -2146,11 +2146,17 @@ mod tests { assert_batches_eq!(expected, &batches); // second part - let stream = join.execute(1, task_ctx.clone())?; + let stream = join.execute(1, Arc::clone(&task_ctx))?; let batches = common::collect(stream).await?; - // expected joined records = 2 (second right batch) - let expected_batch_count = div_ceil(2, batch_size); + let expected_batch_count = if cfg!(not(feature = "force_hash_collisions")) { + // Expected number of hash table matches for second right batch = 2 + div_ceil(2, batch_size) + } else { + // With hash collisions enabled, all records will match each other + // and filtered later. + div_ceil(3, batch_size) + }; assert_eq!(batches.len(), expected_batch_count); let expected = [ @@ -2361,8 +2367,8 @@ mod tests { )]; let (columns, batches) = join_collect( - left.clone(), - right.clone(), + Arc::clone(&left), + Arc::clone(&right), on.clone(), &JoinType::Left, false, @@ -2405,8 +2411,8 @@ mod tests { )]; let (columns, batches) = partitioned_join_collect( - left.clone(), - right.clone(), + Arc::clone(&left), + Arc::clone(&right), on.clone(), &JoinType::Left, false, @@ -2517,8 +2523,8 @@ mod tests { ); let join = join_with_filter( - left.clone(), - right.clone(), + Arc::clone(&left), + Arc::clone(&right), on.clone(), filter, &JoinType::LeftSemi, @@ -2528,7 +2534,7 @@ mod tests { let columns_header = columns(&join.schema()); assert_eq!(columns_header.clone(), vec!["a1", "b1", "c1"]); - let stream = join.execute(0, task_ctx.clone())?; + let stream = join.execute(0, Arc::clone(&task_ctx))?; let batches = common::collect(stream).await?; let expected = [ @@ -2641,8 +2647,8 @@ mod tests { ); let join = join_with_filter( - left.clone(), - right.clone(), + Arc::clone(&left), + Arc::clone(&right), on.clone(), filter, &JoinType::RightSemi, @@ -2652,7 +2658,7 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a2", "b2", "c2"]); - let stream = join.execute(0, task_ctx.clone())?; + let stream = join.execute(0, Arc::clone(&task_ctx))?; let batches = common::collect(stream).await?; let expected = [ @@ -2763,8 +2769,8 @@ mod tests { ); let join = join_with_filter( - left.clone(), - right.clone(), + Arc::clone(&left), + Arc::clone(&right), on.clone(), filter, &JoinType::LeftAnti, @@ -2774,7 +2780,7 @@ mod tests { let columns_header = columns(&join.schema()); assert_eq!(columns_header, vec!["a1", "b1", "c1"]); - let stream = join.execute(0, task_ctx.clone())?; + let stream = join.execute(0, Arc::clone(&task_ctx))?; let batches = common::collect(stream).await?; let expected = [ @@ -2892,8 +2898,8 @@ mod tests { ); let join = join_with_filter( - left.clone(), - right.clone(), + Arc::clone(&left), + Arc::clone(&right), on.clone(), filter, &JoinType::RightAnti, @@ -2903,7 +2909,7 @@ mod tests { let columns_header = columns(&join.schema()); assert_eq!(columns_header, vec!["a2", "b2", "c2"]); - let stream = join.execute(0, task_ctx.clone())?; + let stream = join.execute(0, Arc::clone(&task_ctx))?; let batches = common::collect(stream).await?; let expected = [ @@ -3082,6 +3088,94 @@ mod tests { Ok(()) } + #[apply(batch_sizes)] + #[tokio::test] + async fn join_left_mark(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + + let (columns, batches) = join_collect( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + &JoinType::LeftMark, + false, + task_ctx, + ) + .await?; + assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]); + + let expected = [ + "+----+----+----+-------+", + "| a1 | b1 | c1 | mark |", + "+----+----+----+-------+", + "| 1 | 4 | 7 | true |", + "| 2 | 5 | 8 | true |", + "| 3 | 7 | 9 | false |", + "+----+----+----+-------+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[apply(batch_sizes)] + #[tokio::test] + async fn partitioned_join_left_mark(batch_size: usize) -> Result<()> { + let task_ctx = prepare_task_ctx(batch_size); + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30, 40]), + ("b1", &vec![4, 4, 5, 6]), + ("c2", &vec![60, 70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + + let (columns, batches) = partitioned_join_collect( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + &JoinType::LeftMark, + false, + task_ctx, + ) + .await?; + assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]); + + let expected = [ + "+----+----+----+-------+", + "| a1 | b1 | c1 | mark |", + "+----+----+----+-------+", + "| 1 | 4 | 7 | true |", + "| 2 | 5 | 8 | true |", + "| 3 | 7 | 9 | false |", + "+----+----+----+-------+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + #[test] fn join_with_hash_collision() -> Result<()> { let mut hashmap_left = RawTable::with_capacity(2); @@ -3093,8 +3187,11 @@ mod tests { let random_state = RandomState::with_seeds(0, 0, 0, 0); let hashes_buff = &mut vec![0; left.num_rows()]; - let hashes = - create_hashes(&[left.columns()[0].clone()], &random_state, hashes_buff)?; + let hashes = create_hashes( + &[Arc::clone(&left.columns()[0])], + &random_state, + hashes_buff, + )?; // Create hash collisions (same hashes) hashmap_left.insert(hashes[0], (hashes[0], 1), |(h, _)| *h); @@ -3122,7 +3219,7 @@ mod tests { &join_hash_map, &left, &right, - &[key_column.clone()], + &[Arc::clone(&key_column)], &[key_column], false, &hashes_buffer, @@ -3130,17 +3227,13 @@ mod tests { (0, None), )?; - let mut left_ids = UInt64Builder::with_capacity(0); - left_ids.append_value(0); - left_ids.append_value(1); + let left_ids: UInt64Array = vec![0, 1].into(); - let mut right_ids = UInt32Builder::with_capacity(0); - right_ids.append_value(0); - right_ids.append_value(1); + let right_ids: UInt32Array = vec![0, 1].into(); - assert_eq!(left_ids.finish(), l); + assert_eq!(left_ids, l); - assert_eq!(right_ids.finish(), r); + assert_eq!(right_ids, r); Ok(()) } @@ -3468,6 +3561,15 @@ mod tests { "| 30 | 6 | 90 |", "+----+----+----+", ]; + let expected_left_mark = vec![ + "+----+----+----+-------+", + "| a1 | b1 | c1 | mark |", + "+----+----+----+-------+", + "| 1 | 4 | 7 | true |", + "| 2 | 5 | 8 | true |", + "| 3 | 7 | 9 | false |", + "+----+----+----+-------+", + ]; let test_cases = vec![ (JoinType::Inner, expected_inner), @@ -3478,17 +3580,18 @@ mod tests { (JoinType::LeftAnti, expected_left_anti), (JoinType::RightSemi, expected_right_semi), (JoinType::RightAnti, expected_right_anti), + (JoinType::LeftMark, expected_left_mark), ]; for (join_type, expected) in test_cases { let (_, batches) = join_collect_with_partition_mode( - left.clone(), - right.clone(), + Arc::clone(&left), + Arc::clone(&right), on.clone(), &join_type, PartitionMode::CollectLeft, false, - task_ctx.clone(), + Arc::clone(&task_ctx), ) .await?; assert_batches_sorted_eq!(expected, &batches); @@ -3506,13 +3609,14 @@ mod tests { let dates: ArrayRef = Arc::new(Date32Array::from(vec![19107, 19108, 19109])); let n: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3])); - let batch = RecordBatch::try_new(schema.clone(), vec![dates, n])?; - let left = - Arc::new(MemoryExec::try_new(&[vec![batch]], schema.clone(), None).unwrap()); + let batch = RecordBatch::try_new(Arc::clone(&schema), vec![dates, n])?; + let left = Arc::new( + MemoryExec::try_new(&[vec![batch]], Arc::clone(&schema), None).unwrap(), + ); let dates: ArrayRef = Arc::new(Date32Array::from(vec![19108, 19108, 19109])); let n: ArrayRef = Arc::new(Int32Array::from(vec![4, 5, 6])); - let batch = RecordBatch::try_new(schema.clone(), vec![dates, n])?; + let batch = RecordBatch::try_new(Arc::clone(&schema), vec![dates, n])?; let right = Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()); let on = vec![( @@ -3574,8 +3678,8 @@ mod tests { for join_type in join_types { let join = join( - left.clone(), - right_input.clone(), + Arc::clone(&left), + Arc::clone(&right_input) as Arc, on.clone(), &join_type, false, @@ -3586,10 +3690,7 @@ mod tests { let stream = join.execute(0, task_ctx).unwrap(); // Expect that an error is returned - let result_string = crate::common::collect(stream) - .await - .unwrap_err() - .to_string(); + let result_string = common::collect(stream).await.unwrap_err().to_string(); assert!( result_string.contains("bad data error"), "actual: {result_string}" @@ -3690,9 +3791,14 @@ mod tests { for batch_size in (1..21).rev() { let task_ctx = prepare_task_ctx(batch_size); - let join = - join(left.clone(), right.clone(), on.clone(), &join_type, false) - .unwrap(); + let join = join( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + &join_type, + false, + ) + .unwrap(); let stream = join.execute(0, task_ctx).unwrap(); let batches = common::collect(stream).await.unwrap(); @@ -3706,9 +3812,9 @@ mod tests { | JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { - (expected_resultset_records + batch_size - 1) / batch_size + div_ceil(expected_resultset_records, batch_size) } - _ => (expected_resultset_records + batch_size - 1) / batch_size + 1, + _ => div_ceil(expected_resultset_records, batch_size) + 1, }; assert_eq!( batches.len(), @@ -3757,26 +3863,32 @@ mod tests { JoinType::LeftAnti, JoinType::RightSemi, JoinType::RightAnti, + JoinType::LeftMark, ]; for join_type in join_types { - let runtime_config = RuntimeConfig::new().with_memory_limit(100, 1.0); - let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(100, 1.0) + .build_arc()?; let task_ctx = TaskContext::default().with_runtime(runtime); let task_ctx = Arc::new(task_ctx); - let join = join(left.clone(), right.clone(), on.clone(), &join_type, false)?; + let join = join( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + &join_type, + false, + )?; let stream = join.execute(0, task_ctx)?; let err = common::collect(stream).await.unwrap_err(); + // Asserting that operator-level reservation attempting to overallocate assert_contains!( err.to_string(), - "External error: Resources exhausted: Failed to allocate additional" + "External error: Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: HashJoinInput" ); - - // Asserting that operator-level reservation attempting to overallocate - assert_contains!(err.to_string(), "HashJoinInput"); } Ok(()) @@ -3829,8 +3941,9 @@ mod tests { ]; for join_type in join_types { - let runtime_config = RuntimeConfig::new().with_memory_limit(100, 1.0); - let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(100, 1.0) + .build_arc()?; let session_config = SessionConfig::default().with_batch_size(50); let task_ctx = TaskContext::default() .with_session_config(session_config) @@ -3838,8 +3951,8 @@ mod tests { let task_ctx = Arc::new(task_ctx); let join = HashJoinExec::try_new( - left.clone(), - right.clone(), + Arc::clone(&left) as Arc, + Arc::clone(&right) as Arc, on.clone(), None, &join_type, @@ -3851,18 +3964,115 @@ mod tests { let stream = join.execute(1, task_ctx)?; let err = common::collect(stream).await.unwrap_err(); + // Asserting that stream-level reservation attempting to overallocate assert_contains!( err.to_string(), - "External error: Resources exhausted: Failed to allocate additional" - ); + "External error: Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: HashJoinInput[1]" - // Asserting that stream-level reservation attempting to overallocate - assert_contains!(err.to_string(), "HashJoinInput[1]"); + ); } Ok(()) } + fn build_table_struct( + struct_name: &str, + field_name_and_values: (&str, &Vec>), + nulls: Option, + ) -> Arc { + let (field_name, values) = field_name_and_values; + let inner_fields = vec![Field::new(field_name, DataType::Int32, true)]; + let schema = Schema::new(vec![Field::new( + struct_name, + DataType::Struct(inner_fields.clone().into()), + nulls.is_some(), + )]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(StructArray::new( + inner_fields.into(), + vec![Arc::new(Int32Array::from(values.clone()))], + nulls, + ))], + ) + .unwrap(); + let schema_ref = batch.schema(); + Arc::new(MemoryExec::try_new(&[vec![batch]], schema_ref, None).unwrap()) + } + + #[tokio::test] + async fn join_on_struct() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let left = + build_table_struct("n1", ("a", &vec![None, Some(1), Some(2), Some(3)]), None); + let right = + build_table_struct("n2", ("a", &vec![None, Some(1), Some(2), Some(4)]), None); + let on = vec![( + Arc::new(Column::new_with_schema("n1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("n2", &right.schema())?) as _, + )]; + + let (columns, batches) = + join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + + assert_eq!(columns, vec!["n1", "n2"]); + + let expected = [ + "+--------+--------+", + "| n1 | n2 |", + "+--------+--------+", + "| {a: } | {a: } |", + "| {a: 1} | {a: 1} |", + "| {a: 2} | {a: 2} |", + "+--------+--------+", + ]; + assert_batches_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_on_struct_with_nulls() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let left = + build_table_struct("n1", ("a", &vec![None]), Some(NullBuffer::new_null(1))); + let right = + build_table_struct("n2", ("a", &vec![None]), Some(NullBuffer::new_null(1))); + let on = vec![( + Arc::new(Column::new_with_schema("n1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("n2", &right.schema())?) as _, + )]; + + let (_, batches_null_eq) = join_collect( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + &JoinType::Inner, + true, + Arc::clone(&task_ctx), + ) + .await?; + + let expected_null_eq = [ + "+----+----+", + "| n1 | n2 |", + "+----+----+", + "| | |", + "+----+----+", + ]; + assert_batches_eq!(expected_null_eq, &batches_null_eq); + + let (_, batches_null_neq) = + join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; + + let expected_null_neq = + ["+----+----+", "| n1 | n2 |", "+----+----+", "+----+----+"]; + assert_batches_eq!(expected_null_neq, &batches_null_neq); + + Ok(()) + } + /// Returns the column names on the schema fn columns(schema: &Schema) -> Vec { schema.fields().iter().map(|f| f.name().clone()).collect() diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 47e262c3c8f6..f36c2395e20f 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -25,28 +25,32 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::task::Poll; -use crate::coalesce_batches::concat_batches; +use super::utils::{ + asymmetric_join_output_partitioning, need_produce_result_in_final, BatchSplitter, + BatchTransformer, NoopBatchTransformer, StatefulStreamResult, +}; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::joins::utils::{ - adjust_indices_by_join_type, adjust_right_output_partitioning, - apply_join_filter_to_indices, build_batch_from_indices, build_join_schema, - check_join_is_valid, estimate_join_statistics, get_final_indices_from_bit_map, - BuildProbeJoinMetrics, ColumnIndex, JoinFilter, OnceAsync, OnceFut, + adjust_indices_by_join_type, apply_join_filter_to_indices, build_batch_from_indices, + build_join_schema, check_join_is_valid, estimate_join_statistics, + get_final_indices_from_bit_map, BuildProbeJoinMetrics, ColumnIndex, JoinFilter, + OnceAsync, OnceFut, }; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::{ - execution_mode_from_children, DisplayAs, DisplayFormatType, Distribution, - ExecutionMode, ExecutionPlan, ExecutionPlanProperties, Partitioning, PlanProperties, + execution_mode_from_children, handle_state, DisplayAs, DisplayFormatType, + Distribution, ExecutionMode, ExecutionPlan, ExecutionPlanProperties, PlanProperties, RecordBatchStream, SendableRecordBatchStream, }; -use arrow::array::{ - BooleanBufferBuilder, UInt32Array, UInt32Builder, UInt64Array, UInt64Builder, -}; +use arrow::array::{BooleanBufferBuilder, UInt32Array, UInt64Array}; +use arrow::compute::concat_batches; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow::util::bit_util; -use datafusion_common::{exec_err, JoinSide, Result, Statistics}; +use datafusion_common::{ + exec_datafusion_err, internal_err, JoinSide, Result, Statistics, +}; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::TaskContext; use datafusion_expr::JoinType; @@ -55,8 +59,6 @@ use datafusion_physical_expr::equivalence::join_equivalence_properties; use futures::{ready, Stream, StreamExt, TryStreamExt}; use parking_lot::Mutex; -use super::utils::need_produce_result_in_final; - /// Shared bitmap for visited left-side indices type SharedBitmapBuilder = Mutex; /// Left (build-side) data @@ -103,6 +105,7 @@ impl JoinLeftData { } } +#[allow(rustdoc::private_intra_doc_links)] /// NestedLoopJoinExec is build-probe join operator, whose main task is to /// perform joins without any equijoin conditions in `ON` clause. /// @@ -138,6 +141,9 @@ impl JoinLeftData { /// "reports" about probe phase completion (which means that "visited" bitmap won't be /// updated anymore), and only the last thread, reporting about completion, will return output. /// +/// Note that the `Clone` trait is not implemented for this struct due to the +/// `left_fut` [`OnceAsync`], which is used to coordinate the loading of the +/// left side with the processing in each output stream. #[derive(Debug)] pub struct NestedLoopJoinExec { /// left side @@ -161,7 +167,7 @@ pub struct NestedLoopJoinExec { } impl NestedLoopJoinExec { - /// Try to create a nwe [`NestedLoopJoinExec`] + /// Try to create a new [`NestedLoopJoinExec`] pub fn try_new( left: Arc, right: Arc, @@ -174,7 +180,8 @@ impl NestedLoopJoinExec { let (schema, column_indices) = build_join_schema(&left_schema, &right_schema, join_type); let schema = Arc::new(schema); - let cache = Self::compute_properties(&left, &right, schema.clone(), *join_type); + let cache = + Self::compute_properties(&left, &right, Arc::clone(&schema), *join_type); Ok(NestedLoopJoinExec { left, @@ -222,36 +229,49 @@ impl NestedLoopJoinExec { right.equivalence_properties().clone(), &join_type, schema, - &[false, false], + &Self::maintains_input_order(join_type), None, // No on columns in nested loop join &[], ); - // Get output partitioning, - let output_partitioning = match join_type { - JoinType::Inner | JoinType::Right => adjust_right_output_partitioning( - right.output_partitioning(), - left.schema().fields().len(), - ), - JoinType::RightSemi | JoinType::RightAnti => { - right.output_partitioning().clone() - } - JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::Full => { - Partitioning::UnknownPartitioning( - right.output_partitioning().partition_count(), - ) - } - }; + let output_partitioning = + asymmetric_join_output_partitioning(left, right, &join_type); // Determine execution mode: - let mut mode = execution_mode_from_children([left, right]); - if mode.is_unbounded() { - mode = ExecutionMode::PipelineBreaking; - } + let mode = if left.execution_mode().is_unbounded() { + ExecutionMode::PipelineBreaking + } else { + execution_mode_from_children([left, right]) + }; PlanProperties::new(eq_properties, output_partitioning, mode) } + + /// Returns a vector indicating whether the left and right inputs maintain their order. + /// The first element corresponds to the left input, and the second to the right. + /// + /// The left (build-side) input's order may change, but the right (probe-side) input's + /// order is maintained for INNER, RIGHT, RIGHT ANTI, and RIGHT SEMI joins. + /// + /// Maintaining the right input's order helps optimize the nodes down the pipeline + /// (See [`ExecutionPlan::maintains_input_order`]). + /// + /// This is a separate method because it is also called when computing properties, before + /// a [`NestedLoopJoinExec`] is created. It also takes [`JoinType`] as an argument, as + /// opposed to `Self`, for the same reason. + fn maintains_input_order(join_type: JoinType) -> Vec { + vec![ + false, + matches!( + join_type, + JoinType::Inner + | JoinType::Right + | JoinType::RightAnti + | JoinType::RightSemi + ), + ] + } } impl DisplayAs for NestedLoopJoinExec { @@ -292,8 +312,12 @@ impl ExecutionPlan for NestedLoopJoinExec { ] } - fn children(&self) -> Vec> { - vec![self.left.clone(), self.right.clone()] + fn maintains_input_order(&self) -> Vec { + Self::maintains_input_order(self.join_type) + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.left, &self.right] } fn with_new_children( @@ -301,8 +325,8 @@ impl ExecutionPlan for NestedLoopJoinExec { children: Vec>, ) -> Result> { Ok(Arc::new(NestedLoopJoinExec::try_new( - children[0].clone(), - children[1].clone(), + Arc::clone(&children[0]), + Arc::clone(&children[1]), self.filter.clone(), &self.join_type, )?)) @@ -322,26 +346,58 @@ impl ExecutionPlan for NestedLoopJoinExec { let inner_table = self.inner_table.once(|| { collect_left_input( - self.left.clone(), - context.clone(), + Arc::clone(&self.left), + Arc::clone(&context), join_metrics.clone(), load_reservation, need_produce_result_in_final(self.join_type), self.right().output_partitioning().partition_count(), ) }); + + let batch_size = context.session_config().batch_size(); + let enforce_batch_size_in_joins = + context.session_config().enforce_batch_size_in_joins(); + let outer_table = self.right.execute(partition, context)?; - Ok(Box::pin(NestedLoopJoinStream { - schema: self.schema.clone(), - filter: self.filter.clone(), - join_type: self.join_type, - outer_table, - inner_table, - is_exhausted: false, - column_indices: self.column_indices.clone(), - join_metrics, - })) + let indices_cache = (UInt64Array::new_null(0), UInt32Array::new_null(0)); + + // Right side has an order and it is maintained during operation. + let right_side_ordered = + self.maintains_input_order()[1] && self.right.output_ordering().is_some(); + + if enforce_batch_size_in_joins { + Ok(Box::pin(NestedLoopJoinStream { + schema: Arc::clone(&self.schema), + filter: self.filter.clone(), + join_type: self.join_type, + outer_table, + inner_table, + column_indices: self.column_indices.clone(), + join_metrics, + indices_cache, + right_side_ordered, + state: NestedLoopJoinStreamState::WaitBuildSide, + batch_transformer: BatchSplitter::new(batch_size), + left_data: None, + })) + } else { + Ok(Box::pin(NestedLoopJoinStream { + schema: Arc::clone(&self.schema), + filter: self.filter.clone(), + join_type: self.join_type, + outer_table, + inner_table, + column_indices: self.column_indices.clone(), + join_metrics, + indices_cache, + right_side_ordered, + state: NestedLoopJoinStreamState::WaitBuildSide, + batch_transformer: NoopBatchTransformer::new(), + left_data: None, + })) + } } fn metrics(&self) -> Option { @@ -350,8 +406,8 @@ impl ExecutionPlan for NestedLoopJoinExec { fn statistics(&self) -> Result { estimate_join_statistics( - self.left.clone(), - self.right.clone(), + Arc::clone(&self.left), + Arc::clone(&self.right), vec![], &self.join_type, &self.schema, @@ -377,19 +433,17 @@ async fn collect_left_input( let stream = merge.execute(0, context)?; // Load all batches and count the rows - let (batches, num_rows, metrics, mut reservation) = stream + let (batches, metrics, mut reservation) = stream .try_fold( - (Vec::new(), 0usize, join_metrics, reservation), + (Vec::new(), join_metrics, reservation), |mut acc, batch| async { let batch_size = batch.get_array_memory_size(); // Reserve memory for incoming batch - acc.3.try_grow(batch_size)?; + acc.2.try_grow(batch_size)?; // Update metrics - acc.2.build_mem_used.add(batch_size); - acc.2.build_input_batches.add(1); - acc.2.build_input_rows.add(batch.num_rows()); - // Update rowcount - acc.1 += batch.num_rows(); + acc.1.build_mem_used.add(batch_size); + acc.1.build_input_batches.add(1); + acc.1.build_input_rows.add(batch.num_rows()); // Push batch to output acc.0.push(batch); Ok(acc) @@ -397,7 +451,7 @@ async fn collect_left_input( ) .await?; - let merged_batch = concat_batches(&schema, &batches, num_rows)?; + let merged_batch = concat_batches(&schema, &batches)?; // Reserve memory for visited_left_side bitmap if required by join type let visited_left_side = if with_visited_left_side { @@ -422,8 +476,37 @@ async fn collect_left_input( )) } +/// This enumeration represents various states of the nested loop join algorithm. +#[derive(Debug, Clone)] +enum NestedLoopJoinStreamState { + /// The initial state, indicating that build-side data not collected yet + WaitBuildSide, + /// Indicates that build-side has been collected, and stream is ready for + /// fetching probe-side + FetchProbeBatch, + /// Indicates that a non-empty batch has been fetched from probe-side, and + /// is ready to be processed + ProcessProbeBatch(RecordBatch), + /// Indicates that probe-side has been fully processed + ExhaustedProbeSide, + /// Indicates that NestedLoopJoinStream execution is completed + Completed, +} + +impl NestedLoopJoinStreamState { + /// Tries to extract a `ProcessProbeBatchState` from the + /// `NestedLoopJoinStreamState` enum. Returns an error if state is not + /// `ProcessProbeBatchState`. + fn try_as_process_probe_batch(&mut self) -> Result<&RecordBatch> { + match self { + NestedLoopJoinStreamState::ProcessProbeBatch(state) => Ok(state), + _ => internal_err!("Expected join stream in ProcessProbeBatch state"), + } + } +} + /// A stream that issues [RecordBatch]es as they arrive from the right of the join. -struct NestedLoopJoinStream { +struct NestedLoopJoinStream { /// Input schema schema: Arc, /// join filter @@ -434,29 +517,86 @@ struct NestedLoopJoinStream { outer_table: SendableRecordBatchStream, /// the inner table data of the nested loop join inner_table: OnceFut, - /// There is nothing to process anymore and left side is processed in case of full join - is_exhausted: bool, /// Information of index and left / right placement of columns column_indices: Vec, // TODO: support null aware equal // null_equals_null: bool /// Join execution metrics join_metrics: BuildProbeJoinMetrics, + /// Cache for join indices calculations + indices_cache: (UInt64Array, UInt32Array), + /// Whether the right side is ordered + right_side_ordered: bool, + /// Current state of the stream + state: NestedLoopJoinStreamState, + /// Transforms the output batch before returning. + batch_transformer: T, + /// Result of the left data future + left_data: Option>, } +/// Creates a Cartesian product of two input batches, preserving the order of the right batch, +/// and applying a join filter if provided. +/// +/// # Example +/// Input: +/// left = [0, 1], right = [0, 1, 2] +/// +/// Output: +/// left_indices = [0, 1, 0, 1, 0, 1], right_indices = [0, 0, 1, 1, 2, 2] +/// +/// Input: +/// left = [0, 1, 2], right = [0, 1, 2, 3], filter = left.a != right.a +/// +/// Output: +/// left_indices = [1, 2, 0, 2, 0, 1, 0, 1, 2], right_indices = [0, 0, 1, 1, 2, 2, 3, 3, 3] fn build_join_indices( - left_row_index: usize, - right_batch: &RecordBatch, left_batch: &RecordBatch, + right_batch: &RecordBatch, filter: Option<&JoinFilter>, + indices_cache: &mut (UInt64Array, UInt32Array), ) -> Result<(UInt64Array, UInt32Array)> { - // left indices: [left_index, left_index, ...., left_index] - // right indices: [0, 1, 2, 3, 4,....,right_row_count] - + let left_row_count = left_batch.num_rows(); let right_row_count = right_batch.num_rows(); - let left_indices = UInt64Array::from(vec![left_row_index as u64; right_row_count]); - let right_indices = UInt32Array::from_iter_values(0..(right_row_count as u32)); - // in the nested loop join, the filter can contain non-equal and equal condition. + let output_row_count = left_row_count * right_row_count; + + // We always use the same indices before applying the filter, so we can cache them + let (left_indices_cache, right_indices_cache) = indices_cache; + let cached_output_row_count = left_indices_cache.len(); + + let (left_indices, right_indices) = + match output_row_count.cmp(&cached_output_row_count) { + std::cmp::Ordering::Equal => { + // Reuse the cached indices + (left_indices_cache.clone(), right_indices_cache.clone()) + } + std::cmp::Ordering::Less => { + // Left_row_count never changes because it's the build side. The changes to the + // right_row_count can be handled trivially by taking the first output_row_count + // elements of the cache because of how the indices are generated. + // (See the Ordering::Greater match arm) + ( + left_indices_cache.slice(0, output_row_count), + right_indices_cache.slice(0, output_row_count), + ) + } + std::cmp::Ordering::Greater => { + // Rebuild the indices cache + + // Produces 0, 1, 2, 0, 1, 2, 0, 1, 2, ... + *left_indices_cache = UInt64Array::from_iter_values( + (0..output_row_count as u64).map(|i| i % left_row_count as u64), + ); + + // Produces 0, 0, 0, 1, 1, 1, 2, 2, 2, ... + *right_indices_cache = UInt32Array::from_iter_values( + (0..output_row_count as u32).map(|i| i / left_row_count as u32), + ); + + (left_indices_cache.clone(), right_indices_cache.clone()) + } + }; + if let Some(filter) = filter { apply_join_filter_to_indices( left_batch, @@ -471,103 +611,168 @@ fn build_join_indices( } } -impl NestedLoopJoinStream { +impl NestedLoopJoinStream { fn poll_next_impl( &mut self, cx: &mut std::task::Context<'_>, ) -> Poll>> { - // all left row + loop { + return match self.state { + NestedLoopJoinStreamState::WaitBuildSide => { + handle_state!(ready!(self.collect_build_side(cx))) + } + NestedLoopJoinStreamState::FetchProbeBatch => { + handle_state!(ready!(self.fetch_probe_batch(cx))) + } + NestedLoopJoinStreamState::ProcessProbeBatch(_) => { + handle_state!(self.process_probe_batch()) + } + NestedLoopJoinStreamState::ExhaustedProbeSide => { + handle_state!(self.process_unmatched_build_batch()) + } + NestedLoopJoinStreamState::Completed => Poll::Ready(None), + }; + } + } + + fn collect_build_side( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { let build_timer = self.join_metrics.build_time.timer(); - let left_data = match ready!(self.inner_table.get_shared(cx)) { - Ok(data) => data, - Err(e) => return Poll::Ready(Some(Err(e))), - }; + // build hash table from left (build) side, if not yet done + self.left_data = Some(ready!(self.inner_table.get_shared(cx))?); build_timer.done(); - // Get or initialize visited_left_side bitmap if required by join type + self.state = NestedLoopJoinStreamState::FetchProbeBatch; + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + /// Fetches next batch from probe-side + /// + /// If a non-empty batch has been fetched, updates state to + /// `ProcessProbeBatchState`, otherwise updates state to `ExhaustedProbeSide`. + fn fetch_probe_batch( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll>>> { + match ready!(self.outer_table.poll_next_unpin(cx)) { + None => { + self.state = NestedLoopJoinStreamState::ExhaustedProbeSide; + } + Some(Ok(right_batch)) => { + self.state = NestedLoopJoinStreamState::ProcessProbeBatch(right_batch); + } + Some(Err(err)) => return Poll::Ready(Err(err)), + }; + + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + + /// Joins current probe batch with build-side data and produces batch with + /// matched output, updates state to `FetchProbeBatch`. + fn process_probe_batch( + &mut self, + ) -> Result>> { + let Some(left_data) = self.left_data.clone() else { + return internal_err!( + "Expected left_data to be Some in ProcessProbeBatch state" + ); + }; let visited_left_side = left_data.bitmap(); + let batch = self.state.try_as_process_probe_batch()?; + + match self.batch_transformer.next() { + None => { + // Setting up timer & updating input metrics + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + let timer = self.join_metrics.join_time.timer(); + + let result = join_left_and_right_batch( + left_data.batch(), + batch, + self.join_type, + self.filter.as_ref(), + &self.column_indices, + &self.schema, + visited_left_side, + &mut self.indices_cache, + self.right_side_ordered, + ); + timer.done(); - self.outer_table - .poll_next_unpin(cx) - .map(|maybe_batch| match maybe_batch { - Some(Ok(right_batch)) => { - // Setting up timer & updating input metrics - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(right_batch.num_rows()); - let timer = self.join_metrics.join_time.timer(); - - let result = join_left_and_right_batch( - left_data.batch(), - &right_batch, - self.join_type, - self.filter.as_ref(), - &self.column_indices, - &self.schema, - visited_left_side, - ); - - // Recording time & updating output metrics - if let Ok(batch) = &result { - timer.done(); - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - } - - Some(result) - } - Some(err) => Some(err), - None => { - if need_produce_result_in_final(self.join_type) && !self.is_exhausted - { - // At this stage `visited_left_side` won't be updated, so it's - // safe to report about probe completion. - // - // Setting `is_exhausted` / returning None will prevent from - // multiple calls of `report_probe_completed()` - if !left_data.report_probe_completed() { - self.is_exhausted = true; - return None; - }; - - // Only setting up timer, input is exhausted - let timer = self.join_metrics.join_time.timer(); - // use the global left bitmap to produce the left indices and right indices - let (left_side, right_side) = - get_final_indices_from_shared_bitmap( - visited_left_side, - self.join_type, - ); - let empty_right_batch = - RecordBatch::new_empty(self.outer_table.schema()); - // use the left and right indices to produce the batch result - let result = build_batch_from_indices( - &self.schema, - left_data.batch(), - &empty_right_batch, - &left_side, - &right_side, - &self.column_indices, - JoinSide::Left, - ); - self.is_exhausted = true; - - // Recording time & updating output metrics - if let Ok(batch) = &result { - timer.done(); - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - } - - Some(result) - } else { - // end of the join loop - None - } + self.batch_transformer.set_batch(result?); + Ok(StatefulStreamResult::Continue) + } + Some((batch, last)) => { + if last { + self.state = NestedLoopJoinStreamState::FetchProbeBatch; } - }) + + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + Ok(StatefulStreamResult::Ready(Some(batch))) + } + } + } + + /// Processes unmatched build-side rows for certain join types and produces + /// output batch, updates state to `Completed`. + fn process_unmatched_build_batch( + &mut self, + ) -> Result>> { + let Some(left_data) = self.left_data.clone() else { + return internal_err!( + "Expected left_data to be Some in ExhaustedProbeSide state" + ); + }; + let visited_left_side = left_data.bitmap(); + if need_produce_result_in_final(self.join_type) { + // At this stage `visited_left_side` won't be updated, so it's + // safe to report about probe completion. + // + // Setting `is_exhausted` / returning None will prevent from + // multiple calls of `report_probe_completed()` + if !left_data.report_probe_completed() { + self.state = NestedLoopJoinStreamState::Completed; + return Ok(StatefulStreamResult::Ready(None)); + }; + + // Only setting up timer, input is exhausted + let timer = self.join_metrics.join_time.timer(); + // use the global left bitmap to produce the left indices and right indices + let (left_side, right_side) = + get_final_indices_from_shared_bitmap(visited_left_side, self.join_type); + let empty_right_batch = RecordBatch::new_empty(self.outer_table.schema()); + // use the left and right indices to produce the batch result + let result = build_batch_from_indices( + &self.schema, + left_data.batch(), + &empty_right_batch, + &left_side, + &right_side, + &self.column_indices, + JoinSide::Left, + ); + self.state = NestedLoopJoinStreamState::Completed; + + // Recording time + if result.is_ok() { + timer.done(); + } + + Ok(StatefulStreamResult::Ready(Some(result?))) + } else { + // end of the join loop + self.state = NestedLoopJoinStreamState::Completed; + Ok(StatefulStreamResult::Ready(None)) + } } } +#[allow(clippy::too_many_arguments)] fn join_left_and_right_batch( left_batch: &RecordBatch, right_batch: &RecordBatch, @@ -576,62 +781,44 @@ fn join_left_and_right_batch( column_indices: &[ColumnIndex], schema: &Schema, visited_left_side: &SharedBitmapBuilder, + indices_cache: &mut (UInt64Array, UInt32Array), + right_side_ordered: bool, ) -> Result { - let indices_result = (0..left_batch.num_rows()) - .map(|left_row_index| { - build_join_indices(left_row_index, right_batch, left_batch, filter) - }) - .collect::>>(); - - let mut left_indices_builder = UInt64Builder::new(); - let mut right_indices_builder = UInt32Builder::new(); - let left_right_indices = match indices_result { - Err(err) => { - exec_err!("Fail to build join indices in NestedLoopJoinExec, error:{err}") - } - Ok(indices) => { - for (left_side, right_side) in indices { - left_indices_builder - .append_values(left_side.values(), &vec![true; left_side.len()]); - right_indices_builder - .append_values(right_side.values(), &vec![true; right_side.len()]); - } - Ok(( - left_indices_builder.finish(), - right_indices_builder.finish(), - )) - } - }; - match left_right_indices { - Ok((left_side, right_side)) => { - // set the left bitmap - // and only full join need the left bitmap - if need_produce_result_in_final(join_type) { - let mut bitmap = visited_left_side.lock(); - left_side.iter().flatten().for_each(|x| { - bitmap.set_bit(x as usize, true); - }); - } - // adjust the two side indices base on the join type - let (left_side, right_side) = adjust_indices_by_join_type( - left_side, - right_side, - 0..right_batch.num_rows(), - join_type, - ); - - build_batch_from_indices( - schema, - left_batch, - right_batch, - &left_side, - &right_side, - column_indices, - JoinSide::Left, - ) - } - Err(e) => Err(e), + let (left_side, right_side) = + build_join_indices(left_batch, right_batch, filter, indices_cache).map_err( + |e| { + exec_datafusion_err!( + "Fail to build join indices in NestedLoopJoinExec, error: {e}" + ) + }, + )?; + + // set the left bitmap + // and only full join need the left bitmap + if need_produce_result_in_final(join_type) { + let mut bitmap = visited_left_side.lock(); + left_side.values().iter().for_each(|x| { + bitmap.set_bit(*x as usize, true); + }); } + // adjust the two side indices base on the join type + let (left_side, right_side) = adjust_indices_by_join_type( + left_side, + right_side, + 0..right_batch.num_rows(), + join_type, + right_side_ordered, + )?; + + build_batch_from_indices( + schema, + left_batch, + right_batch, + &left_side, + &right_side, + column_indices, + JoinSide::Left, + ) } fn get_final_indices_from_shared_bitmap( @@ -642,7 +829,7 @@ fn get_final_indices_from_shared_bitmap( get_final_indices_from_bit_map(&bitmap, join_type) } -impl Stream for NestedLoopJoinStream { +impl Stream for NestedLoopJoinStream { type Item = Result; fn poll_next( @@ -653,15 +840,14 @@ impl Stream for NestedLoopJoinStream { } } -impl RecordBatchStream for NestedLoopJoinStream { +impl RecordBatchStream for NestedLoopJoinStream { fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } #[cfg(test)] -mod tests { - +pub(crate) mod tests { use super::*; use crate::{ common, expressions::Column, memory::MemoryExec, repartition::RepartitionExec, @@ -669,20 +855,59 @@ mod tests { }; use arrow::datatypes::{DataType, Field}; + use arrow_array::Int32Array; + use arrow_schema::SortOptions; use datafusion_common::{assert_batches_sorted_eq, assert_contains, ScalarValue}; - use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{BinaryExpr, Literal}; - use datafusion_physical_expr::PhysicalExpr; + use datafusion_physical_expr::{Partitioning, PhysicalExpr}; + use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; + + use rstest::rstest; fn build_table( a: (&str, &Vec), b: (&str, &Vec), c: (&str, &Vec), + batch_size: Option, + sorted_column_names: Vec<&str>, ) -> Arc { let batch = build_table_i32(a, b, c); let schema = batch.schema(); - Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) + + let batches = if let Some(batch_size) = batch_size { + let num_batches = batch.num_rows().div_ceil(batch_size); + (0..num_batches) + .map(|i| { + let start = i * batch_size; + let remaining_rows = batch.num_rows() - start; + batch.slice(start, batch_size.min(remaining_rows)) + }) + .collect::>() + } else { + vec![batch] + }; + + let mut exec = + MemoryExec::try_new(&[batches], Arc::clone(&schema), None).unwrap(); + if !sorted_column_names.is_empty() { + let mut sort_info = LexOrdering::default(); + for name in sorted_column_names { + let index = schema.index_of(name).unwrap(); + let sort_expr = PhysicalSortExpr { + expr: Arc::new(Column::new(name, index)), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }; + sort_info.push(sort_expr); + } + exec = exec.try_with_sort_information(vec![sort_info]).unwrap(); + } + + Arc::new(exec) } fn build_left_table() -> Arc { @@ -690,6 +915,8 @@ mod tests { ("a1", &vec![5, 9, 11]), ("b1", &vec![5, 8, 8]), ("c1", &vec![50, 90, 110]), + None, + Vec::new(), ) } @@ -698,6 +925,8 @@ mod tests { ("a2", &vec![12, 2, 10]), ("b2", &vec![10, 2, 10]), ("c2", &vec![40, 80, 100]), + None, + Vec::new(), ) } @@ -745,7 +974,7 @@ mod tests { JoinFilter::new(filter_expression, column_indices, intermediate_schema) } - async fn multi_partitioned_join_collect( + pub(crate) async fn multi_partitioned_join_collect( left: Arc, right: Arc, join_type: &JoinType, @@ -766,7 +995,7 @@ mod tests { let columns = columns(&nested_loop_join.schema()); let mut batches = vec![]; for i in 0..partition_count { - let stream = nested_loop_join.execute(i, context.clone())?; + let stream = nested_loop_join.execute(i, Arc::clone(&context))?; let more_batches = common::collect(stream).await?; batches.extend( more_batches @@ -1019,17 +1248,52 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_left_mark_with_filter() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let left = build_left_table(); + let right = build_right_table(); + + let filter = prepare_join_filter(); + let (columns, batches) = multi_partitioned_join_collect( + left, + right, + &JoinType::LeftMark, + Some(filter), + task_ctx, + ) + .await?; + assert_eq!(columns, vec!["a1", "b1", "c1", "mark"]); + let expected = [ + "+----+----+-----+-------+", + "| a1 | b1 | c1 | mark |", + "+----+----+-----+-------+", + "| 11 | 8 | 110 | false |", + "| 5 | 5 | 50 | true |", + "| 9 | 8 | 90 | false |", + "+----+----+-----+-------+", + ]; + + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + #[tokio::test] async fn test_overallocation() -> Result<()> { let left = build_table( ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), + None, + Vec::new(), ); let right = build_table( ("a2", &vec![10, 11]), ("b2", &vec![12, 13]), ("c2", &vec![14, 15]), + None, + Vec::new(), ); let filter = prepare_join_filter(); @@ -1040,19 +1304,21 @@ mod tests { JoinType::Full, JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightSemi, JoinType::RightAnti, ]; for join_type in join_types { - let runtime_config = RuntimeConfig::new().with_memory_limit(100, 1.0); - let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(100, 1.0) + .build_arc()?; let task_ctx = TaskContext::default().with_runtime(runtime); let task_ctx = Arc::new(task_ctx); let err = multi_partitioned_join_collect( - left.clone(), - right.clone(), + Arc::clone(&left), + Arc::clone(&right), &join_type, Some(filter.clone()), task_ctx, @@ -1062,9 +1328,166 @@ mod tests { assert_contains!( err.to_string(), - "External error: Resources exhausted: Failed to allocate additional" + "External error: Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: NestedLoopJoinLoad[0]" ); - assert_contains!(err.to_string(), "NestedLoopJoinLoad[0]"); + } + + Ok(()) + } + + fn prepare_mod_join_filter() -> JoinFilter { + let column_indices = vec![ + ColumnIndex { + index: 1, + side: JoinSide::Left, + }, + ColumnIndex { + index: 1, + side: JoinSide::Right, + }, + ]; + let intermediate_schema = Schema::new(vec![ + Field::new("x", DataType::Int32, true), + Field::new("x", DataType::Int32, true), + ]); + + // left.b1 % 3 + let left_mod = Arc::new(BinaryExpr::new( + Arc::new(Column::new("x", 0)), + Operator::Modulo, + Arc::new(Literal::new(ScalarValue::Int32(Some(3)))), + )) as Arc; + // left.b1 % 3 != 0 + let left_filter = Arc::new(BinaryExpr::new( + left_mod, + Operator::NotEq, + Arc::new(Literal::new(ScalarValue::Int32(Some(0)))), + )) as Arc; + + // right.b2 % 5 + let right_mod = Arc::new(BinaryExpr::new( + Arc::new(Column::new("x", 1)), + Operator::Modulo, + Arc::new(Literal::new(ScalarValue::Int32(Some(5)))), + )) as Arc; + // right.b2 % 5 != 0 + let right_filter = Arc::new(BinaryExpr::new( + right_mod, + Operator::NotEq, + Arc::new(Literal::new(ScalarValue::Int32(Some(0)))), + )) as Arc; + // filter = left.b1 % 3 != 0 and right.b2 % 5 != 0 + let filter_expression = + Arc::new(BinaryExpr::new(left_filter, Operator::And, right_filter)) + as Arc; + + JoinFilter::new(filter_expression, column_indices, intermediate_schema) + } + + fn generate_columns(num_columns: usize, num_rows: usize) -> Vec> { + let column = (1..=num_rows).map(|x| x as i32).collect(); + vec![column; num_columns] + } + + #[rstest] + #[tokio::test] + async fn join_maintains_right_order( + #[values( + JoinType::Inner, + JoinType::Right, + JoinType::RightAnti, + JoinType::RightSemi + )] + join_type: JoinType, + #[values(1, 100, 1000)] left_batch_size: usize, + #[values(1, 100, 1000)] right_batch_size: usize, + ) -> Result<()> { + let left_columns = generate_columns(3, 1000); + let left = build_table( + ("a1", &left_columns[0]), + ("b1", &left_columns[1]), + ("c1", &left_columns[2]), + Some(left_batch_size), + Vec::new(), + ); + + let right_columns = generate_columns(3, 1000); + let right = build_table( + ("a2", &right_columns[0]), + ("b2", &right_columns[1]), + ("c2", &right_columns[2]), + Some(right_batch_size), + vec!["a2", "b2", "c2"], + ); + + let filter = prepare_mod_join_filter(); + + let nested_loop_join = Arc::new(NestedLoopJoinExec::try_new( + left, + Arc::clone(&right), + Some(filter), + &join_type, + )?) as Arc; + assert_eq!(nested_loop_join.maintains_input_order(), vec![false, true]); + + let right_column_indices = match join_type { + JoinType::Inner | JoinType::Right => vec![3, 4, 5], + JoinType::RightAnti | JoinType::RightSemi => vec![0, 1, 2], + _ => unreachable!(), + }; + + let right_ordering = right.output_ordering().unwrap(); + let join_ordering = nested_loop_join.output_ordering().unwrap(); + for (right, join) in right_ordering.iter().zip(join_ordering.iter()) { + let right_column = right.expr.as_any().downcast_ref::().unwrap(); + let join_column = join.expr.as_any().downcast_ref::().unwrap(); + assert_eq!(join_column.name(), join_column.name()); + assert_eq!( + right_column_indices[right_column.index()], + join_column.index() + ); + assert_eq!(right.options, join.options); + } + + let batches = nested_loop_join + .execute(0, Arc::new(TaskContext::default()))? + .try_collect::>() + .await?; + + // Make sure that the order of the right side is maintained + let mut prev_values = [i32::MIN, i32::MIN, i32::MIN]; + + for (batch_index, batch) in batches.iter().enumerate() { + let columns: Vec<_> = right_column_indices + .iter() + .map(|&i| { + batch + .column(i) + .as_any() + .downcast_ref::() + .unwrap() + }) + .collect(); + + for row in 0..batch.num_rows() { + let current_values = [ + columns[0].value(row), + columns[1].value(row), + columns[2].value(row), + ]; + assert!( + current_values + .into_iter() + .zip(prev_values) + .all(|(current, prev)| current >= prev), + "batch_index: {} row: {} current: {:?}, prev: {:?}", + batch_index, + row, + current_values, + prev_values + ); + prev_values = current_values; + } } Ok(()) diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 2da88173a410..3ad892c880f6 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -22,43 +22,56 @@ use std::any::Any; use std::cmp::Ordering; -use std::collections::VecDeque; +use std::collections::{HashMap, VecDeque}; use std::fmt::Formatter; -use std::mem; +use std::fs::File; +use std::io::BufReader; +use std::mem::size_of; use std::ops::Range; use std::pin::Pin; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering::Relaxed; use std::sync::Arc; use std::task::{Context, Poll}; -use crate::expressions::PhysicalSortExpr; -use crate::joins::utils::{ - build_join_schema, check_join_is_valid, estimate_join_statistics, - partitioned_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef, -}; -use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; -use crate::{ - execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution, - ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties, - RecordBatchStream, SendableRecordBatchStream, Statistics, -}; - use arrow::array::*; -use arrow::compute::{self, concat_batches, take, SortOptions}; +use arrow::compute::{ + self, concat_batches, filter_record_batch, is_not_null, take, SortOptions, +}; use arrow::datatypes::{DataType, SchemaRef, TimeUnit}; use arrow::error::ArrowError; +use arrow::ipc::reader::FileReader; +use arrow_array::types::UInt64Type; use datafusion_common::{ - internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, Result, + exec_err, internal_err, not_impl_err, plan_err, DataFusionError, JoinSide, JoinType, + Result, }; +use datafusion_execution::disk_manager::RefCountedTempFile; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; +use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::join_equivalence_properties; use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement}; - +use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; use futures::{Stream, StreamExt}; +use hashbrown::HashSet; + +use crate::expressions::PhysicalSortExpr; +use crate::joins::utils::{ + build_join_schema, check_join_is_valid, estimate_join_statistics, + symmetric_join_output_partitioning, JoinFilter, JoinOn, JoinOnRef, +}; +use crate::metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; +use crate::spill::spill_record_batches; +use crate::{ + execution_mode_from_children, metrics, DisplayAs, DisplayFormatType, Distribution, + ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties, + RecordBatchStream, SendableRecordBatchStream, Statistics, +}; /// join execution plan executes partitions in parallel and combines them into a set of /// partitions. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct SortMergeJoinExec { /// Left sorted joining execution plan pub left: Arc, @@ -75,9 +88,9 @@ pub struct SortMergeJoinExec { /// Execution metrics metrics: ExecutionPlanMetricsSet, /// The left SortExpr - left_sort_exprs: Vec, + left_sort_exprs: LexOrdering, /// The right SortExpr - right_sort_exprs: Vec, + right_sort_exprs: LexOrdering, /// Sort options of join columns used in sorting left and right execution plans pub sort_options: Vec, /// If null_equals_null is true, null == null else null != null @@ -123,11 +136,11 @@ impl SortMergeJoinExec { .zip(sort_options.iter()) .map(|((l, r), sort_op)| { let left = PhysicalSortExpr { - expr: l.clone(), + expr: Arc::clone(l), options: *sort_op, }; let right = PhysicalSortExpr { - expr: r.clone(), + expr: Arc::clone(r), options: *sort_op, }; (left, right) @@ -137,7 +150,7 @@ impl SortMergeJoinExec { let schema = Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0); let cache = - Self::compute_properties(&left, &right, schema.clone(), join_type, &on); + Self::compute_properties(&left, &right, Arc::clone(&schema), join_type, &on); Ok(Self { left, right, @@ -146,8 +159,8 @@ impl SortMergeJoinExec { join_type, schema, metrics: ExecutionPlanMetricsSet::new(), - left_sort_exprs, - right_sort_exprs, + left_sort_exprs: LexOrdering::new(left_sort_exprs), + right_sort_exprs: LexOrdering::new(right_sort_exprs), sort_options, null_equals_null, cache, @@ -167,7 +180,8 @@ impl SortMergeJoinExec { | JoinType::Left | JoinType::Full | JoinType::LeftAnti - | JoinType::LeftSemi => JoinSide::Left, + | JoinType::LeftSemi + | JoinType::LeftMark => JoinSide::Left, } } @@ -175,7 +189,10 @@ impl SortMergeJoinExec { fn maintains_input_order(join_type: JoinType) -> Vec { match join_type { JoinType::Inner => vec![true, false], - JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => vec![true, false], + JoinType::Left + | JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::LeftMark => vec![true, false], JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => { vec![false, true] } @@ -219,14 +236,8 @@ impl SortMergeJoinExec { join_on, ); - // Get output partitioning: - let left_columns_len = left.schema().fields.len(); - let output_partitioning = partitioned_join_output_partitioning( - join_type, - left.output_partitioning(), - right.output_partitioning(), - left_columns_len, - ); + let output_partitioning = + symmetric_join_output_partitioning(left, right, &join_type); // Determine execution mode: let mode = execution_mode_from_children([left, right]); @@ -237,11 +248,6 @@ impl SortMergeJoinExec { impl DisplayAs for SortMergeJoinExec { fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { - let display_filter = self.filter.as_ref().map_or_else( - || "".to_string(), - |f| format!(", filter={}", f.expression()), - ); - match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { let on = self @@ -253,7 +259,12 @@ impl DisplayAs for SortMergeJoinExec { write!( f, "SortMergeJoin: join_type={:?}, on=[{}]{}", - self.join_type, on, display_filter + self.join_type, + on, + self.filter.as_ref().map_or("".to_string(), |f| format!( + ", filter={}", + f.expression() + )) ) } } @@ -274,21 +285,24 @@ impl ExecutionPlan for SortMergeJoinExec { } fn required_input_distribution(&self) -> Vec { - let (left_expr, right_expr) = - self.on.iter().map(|(l, r)| (l.clone(), r.clone())).unzip(); + let (left_expr, right_expr) = self + .on + .iter() + .map(|(l, r)| (Arc::clone(l), Arc::clone(r))) + .unzip(); vec![ Distribution::HashPartitioned(left_expr), Distribution::HashPartitioned(right_expr), ] } - fn required_input_ordering(&self) -> Vec>> { + fn required_input_ordering(&self) -> Vec> { vec![ Some(PhysicalSortRequirement::from_sort_exprs( - &self.left_sort_exprs, + self.left_sort_exprs.iter(), )), Some(PhysicalSortRequirement::from_sort_exprs( - &self.right_sort_exprs, + self.right_sort_exprs.iter(), )), ] } @@ -297,8 +311,8 @@ impl ExecutionPlan for SortMergeJoinExec { Self::maintains_input_order(self.join_type) } - fn children(&self) -> Vec> { - vec![self.left.clone(), self.right.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.left, &self.right] } fn with_new_children( @@ -307,8 +321,8 @@ impl ExecutionPlan for SortMergeJoinExec { ) -> Result> { match &children[..] { [left, right] => Ok(Arc::new(SortMergeJoinExec::try_new( - left.clone(), - right.clone(), + Arc::clone(left), + Arc::clone(right), self.on.clone(), self.filter.clone(), self.join_type, @@ -335,14 +349,24 @@ impl ExecutionPlan for SortMergeJoinExec { let (on_left, on_right) = self.on.iter().cloned().unzip(); let (streamed, buffered, on_streamed, on_buffered) = if SortMergeJoinExec::probe_side(&self.join_type) == JoinSide::Left { - (self.left.clone(), self.right.clone(), on_left, on_right) + ( + Arc::clone(&self.left), + Arc::clone(&self.right), + on_left, + on_right, + ) } else { - (self.right.clone(), self.left.clone(), on_right, on_left) + ( + Arc::clone(&self.right), + Arc::clone(&self.left), + on_right, + on_left, + ) }; // execute children plans - let streamed = streamed.execute(partition, context.clone())?; - let buffered = buffered.execute(partition, context.clone())?; + let streamed = streamed.execute(partition, Arc::clone(&context))?; + let buffered = buffered.execute(partition, Arc::clone(&context))?; // create output buffer let batch_size = context.session_config().batch_size(); @@ -353,7 +377,7 @@ impl ExecutionPlan for SortMergeJoinExec { // create join stream Ok(Box::pin(SMJStream::try_new( - self.schema.clone(), + Arc::clone(&self.schema), self.sort_options.clone(), self.null_equals_null, streamed, @@ -365,6 +389,7 @@ impl ExecutionPlan for SortMergeJoinExec { batch_size, SortMergeJoinMetrics::new(partition, &self.metrics), reservation, + context.runtime_env(), )?)) } @@ -377,8 +402,8 @@ impl ExecutionPlan for SortMergeJoinExec { // There are some special cases though, for example: // - `A LEFT JOIN B ON A.col=B.col` with `COUNT_DISTINCT(B.col)=COUNT(B.col)` estimate_join_statistics( - self.left.clone(), - self.right.clone(), + Arc::clone(&self.left), + Arc::clone(&self.right), self.on.clone(), &self.join_type, &self.schema, @@ -392,16 +417,22 @@ struct SortMergeJoinMetrics { /// Total time for joining probe-side batches to the build-side batches join_time: metrics::Time, /// Number of batches consumed by this operator - input_batches: metrics::Count, + input_batches: Count, /// Number of rows consumed by this operator - input_rows: metrics::Count, + input_rows: Count, /// Number of batches produced by this operator - output_batches: metrics::Count, + output_batches: Count, /// Number of rows produced by this operator - output_rows: metrics::Count, + output_rows: Count, /// Peak memory used for buffered data. /// Calculated as sum of peak memory values across partitions peak_mem_used: metrics::Gauge, + /// count of spills during the execution of the operator + spill_count: Count, + /// total spilled bytes during the execution of the operator + spilled_bytes: Count, + /// total spilled rows during the execution of the operator + spilled_rows: Count, } impl SortMergeJoinMetrics { @@ -415,6 +446,9 @@ impl SortMergeJoinMetrics { MetricBuilder::new(metrics).counter("output_batches", partition); let output_rows = MetricBuilder::new(metrics).output_rows(partition); let peak_mem_used = MetricBuilder::new(metrics).gauge("peak_mem_used", partition); + let spill_count = MetricBuilder::new(metrics).spill_count(partition); + let spilled_bytes = MetricBuilder::new(metrics).spilled_bytes(partition); + let spilled_rows = MetricBuilder::new(metrics).spilled_rows(partition); Self { join_time, @@ -423,6 +457,9 @@ impl SortMergeJoinMetrics { output_batches, output_rows, peak_mem_used, + spill_count, + spilled_bytes, + spilled_rows, } } } @@ -475,6 +512,7 @@ struct StreamedJoinedChunk { /// Array builder for streamed indices streamed_indices: UInt64Builder, /// Array builder for buffered indices + /// This could contain nulls if the join is null-joined buffered_indices: UInt64Builder, } @@ -486,11 +524,14 @@ struct StreamedBatch { /// The join key arrays of streamed batch which are used to compare with buffered batches /// and to produce output. They are produced by evaluating `on` expressions. pub join_arrays: Vec, - /// Chunks of indices from buffered side (may be nulls) joined to streamed pub output_indices: Vec, /// Index of currently scanned batch from buffered data pub buffered_batch_idx: Option, + /// Indices that found a match for the given join filter + /// Used for semi joins to keep track the streaming index which got a join filter match + /// and already emitted to the output. + pub join_filter_matched_idxs: HashSet, } impl StreamedBatch { @@ -502,6 +543,7 @@ impl StreamedBatch { join_arrays, output_indices: vec![], buffered_batch_idx: None, + join_filter_matched_idxs: HashSet::new(), } } @@ -512,6 +554,7 @@ impl StreamedBatch { join_arrays: vec![], output_indices: vec![], buffered_batch_idx: None, + join_filter_matched_idxs: HashSet::new(), } } @@ -549,7 +592,8 @@ impl StreamedBatch { #[derive(Debug)] struct BufferedBatch { /// The buffered record batch - pub batch: RecordBatch, + /// None if the batch spilled to disk th + pub batch: Option, /// The range in which the rows share the same join key pub range: Range, /// Array refs of the join key @@ -558,6 +602,19 @@ struct BufferedBatch { pub null_joined: Vec, /// Size estimation used for reserving / releasing memory pub size_estimation: usize, + /// The indices of buffered batch that failed the join filter. + /// This is a map between buffered row index and a boolean value indicating whether all joined row + /// of the buffered row failed the join filter. + /// When dequeuing the buffered batch, we need to produce null joined rows for these indices. + pub join_filter_failed_map: HashMap, + /// Current buffered batch number of rows. Equal to batch.num_rows() + /// but if batch is spilled to disk this property is preferable + /// and less expensive + pub num_rows: usize, + /// An optional temp spill file name on the disk if the batch spilled + /// None by default + /// Some(fileName) if the batch spilled to the disk + pub spill_file: Option, } impl BufferedBatch { @@ -579,16 +636,20 @@ impl BufferedBatch { .iter() .map(|arr| arr.get_array_memory_size()) .sum::() - + batch.num_rows().next_power_of_two() * mem::size_of::() - + mem::size_of::>() - + mem::size_of::(); + + batch.num_rows().next_power_of_two() * size_of::() + + size_of::>() + + size_of::(); + let num_rows = batch.num_rows(); BufferedBatch { - batch, + batch: Some(batch), range, join_arrays, null_joined: vec![], size_estimation, + join_filter_failed_map: HashMap::new(), + num_rows, + spill_file: None, } } } @@ -614,7 +675,7 @@ struct SMJStream { pub buffered: SendableRecordBatchStream, /// Current processing record batch of streamed pub streamed_batch: StreamedBatch, - /// Currrent buffered data + /// Current buffered data pub buffered_data: BufferedData, /// (used in outer join) Is current streamed row joined at least once? pub streamed_joined: bool, @@ -633,7 +694,7 @@ struct SMJStream { /// optional join filter pub filter: Option, /// Staging output array builders - pub output_record_batches: Vec, + pub output_record_batches: JoinedRecordBatches, /// Staging output size, including output batches and staging joined results. /// Increased when we put rows into buffer and decreased after we actually output batches. /// Used to trigger output when sufficient rows are ready @@ -646,11 +707,158 @@ struct SMJStream { pub join_metrics: SortMergeJoinMetrics, /// Memory reservation pub reservation: MemoryReservation, + /// Runtime env + pub runtime_env: Arc, + /// A unique number for each batch + pub streamed_batch_counter: AtomicUsize, +} + +/// Joined batches with attached join filter information +struct JoinedRecordBatches { + /// Joined batches. Each batch is already joined columns from left and right sources + pub batches: Vec, + /// Filter match mask for each row(matched/non-matched) + pub filter_mask: BooleanBuilder, + /// Row indices to glue together rows in `batches` and `filter_mask` + pub row_indices: UInt64Builder, + /// Which unique batch id the row belongs to + /// It is necessary to differentiate rows that are distributed the way when they point to the same + /// row index but in not the same batches + pub batch_ids: Vec, } impl RecordBatchStream for SMJStream { fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) + } +} + +/// True if next index refers to either: +/// - another batch id +/// - another row index within same batch id +/// - end of row indices +#[inline(always)] +fn last_index_for_row( + row_index: usize, + indices: &UInt64Array, + batch_ids: &[usize], + indices_len: usize, +) -> bool { + row_index == indices_len - 1 + || batch_ids[row_index] != batch_ids[row_index + 1] + || indices.value(row_index) != indices.value(row_index + 1) +} + +// Returns a corrected boolean bitmask for the given join type +// Values in the corrected bitmask can be: true, false, null +// `true` - the row found its match and sent to the output +// `null` - the row ignored, no output +// `false` - the row sent as NULL joined row +fn get_corrected_filter_mask( + join_type: JoinType, + row_indices: &UInt64Array, + batch_ids: &[usize], + filter_mask: &BooleanArray, + expected_size: usize, +) -> Option { + let row_indices_length = row_indices.len(); + let mut corrected_mask: BooleanBuilder = + BooleanBuilder::with_capacity(row_indices_length); + let mut seen_true = false; + + match join_type { + JoinType::Left | JoinType::Right => { + for i in 0..row_indices_length { + let last_index = + last_index_for_row(i, row_indices, batch_ids, row_indices_length); + if filter_mask.value(i) { + seen_true = true; + corrected_mask.append_value(true); + } else if seen_true || !filter_mask.value(i) && !last_index { + corrected_mask.append_null(); // to be ignored and not set to output + } else { + corrected_mask.append_value(false); // to be converted to null joined row + } + + if last_index { + seen_true = false; + } + } + + // Generate null joined rows for records which have no matching join key + let null_matched = expected_size - corrected_mask.len(); + corrected_mask.extend(vec![Some(false); null_matched]); + Some(corrected_mask.finish()) + } + JoinType::LeftMark => { + for i in 0..row_indices_length { + let last_index = + last_index_for_row(i, row_indices, batch_ids, row_indices_length); + if filter_mask.value(i) && !seen_true { + seen_true = true; + corrected_mask.append_value(true); + } else if seen_true || !filter_mask.value(i) && !last_index { + corrected_mask.append_null(); // to be ignored and not set to output + } else { + corrected_mask.append_value(false); // to be converted to null joined row + } + + if last_index { + seen_true = false; + } + } + + // Generate null joined rows for records which have no matching join key + let null_matched = expected_size - corrected_mask.len(); + corrected_mask.extend(vec![Some(false); null_matched]); + Some(corrected_mask.finish()) + } + JoinType::LeftSemi => { + for i in 0..row_indices_length { + let last_index = + last_index_for_row(i, row_indices, batch_ids, row_indices_length); + if filter_mask.value(i) && !seen_true { + seen_true = true; + corrected_mask.append_value(true); + } else { + corrected_mask.append_null(); // to be ignored and not set to output + } + + if last_index { + seen_true = false; + } + } + + Some(corrected_mask.finish()) + } + JoinType::LeftAnti => { + for i in 0..row_indices_length { + let last_index = + last_index_for_row(i, row_indices, batch_ids, row_indices_length); + + if filter_mask.value(i) { + seen_true = true; + } + + if last_index { + if !seen_true { + corrected_mask.append_value(true); + } else { + corrected_mask.append_null(); + } + + seen_true = false; + } else { + corrected_mask.append_null(); + } + } + + let null_matched = expected_size - corrected_mask.len(); + corrected_mask.extend(vec![Some(true); null_matched]); + Some(corrected_mask.finish()) + } + // Only outer joins needs to keep track of processed rows and apply corrected filter mask + _ => None, } } @@ -663,7 +871,6 @@ impl Stream for SMJStream { ) -> Poll> { let join_time = self.join_metrics.join_time.clone(); let _timer = join_time.timer(); - loop { match &self.state { SMJState::Init => { @@ -677,6 +884,28 @@ impl Stream for SMJStream { match self.current_ordering { Ordering::Less | Ordering::Equal => { if !streamed_exhausted { + if self.filter.is_some() + && matches!( + self.join_type, + JoinType::Left + | JoinType::LeftSemi + | JoinType::LeftMark + | JoinType::Right + | JoinType::LeftAnti + ) + { + self.freeze_all()?; + + if !self.output_record_batches.batches.is_empty() + { + let out_filtered_batch = + self.filter_joined_batch()?; + return Poll::Ready(Some(Ok( + out_filtered_batch, + ))); + } + } + self.streamed_joined = false; self.streamed_state = StreamedState::Init; } @@ -730,8 +959,26 @@ impl Stream for SMJStream { } } else { self.freeze_all()?; - if !self.output_record_batches.is_empty() { + if !self.output_record_batches.batches.is_empty() { let record_batch = self.output_record_batch_and_reset()?; + // For non-filtered join output whenever the target output batch size + // is hit. For filtered join its needed to output on later phase + // because target output batch size can be hit in the middle of + // filtering causing the filtering to be incomplete and causing + // correctness issues + if self.filter.is_some() + && matches!( + self.join_type, + JoinType::Left + | JoinType::LeftSemi + | JoinType::Right + | JoinType::LeftAnti + | JoinType::LeftMark + ) + { + continue; + } + return Poll::Ready(Some(Ok(record_batch))); } return Poll::Pending; @@ -739,11 +986,27 @@ impl Stream for SMJStream { } SMJState::Exhausted => { self.freeze_all()?; - if !self.output_record_batches.is_empty() { - let record_batch = self.output_record_batch_and_reset()?; - return Poll::Ready(Some(Ok(record_batch))); + + if !self.output_record_batches.batches.is_empty() { + if self.filter.is_some() + && matches!( + self.join_type, + JoinType::Left + | JoinType::LeftSemi + | JoinType::Right + | JoinType::LeftAnti + | JoinType::LeftMark + ) + { + let out = self.filter_joined_batch()?; + return Poll::Ready(Some(Ok(out))); + } else { + let record_batch = self.output_record_batch_and_reset()?; + return Poll::Ready(Some(Ok(record_batch))); + } + } else { + return Poll::Ready(None); } - return Poll::Ready(None); } } } @@ -765,6 +1028,7 @@ impl SMJStream { batch_size: usize, join_metrics: SortMergeJoinMetrics, reservation: MemoryReservation, + runtime_env: Arc, ) -> Result { let streamed_schema = streamed.schema(); let buffered_schema = buffered.schema(); @@ -773,7 +1037,7 @@ impl SMJStream { sort_options, null_equals_null, schema, - streamed_schema: streamed_schema.clone(), + streamed_schema: Arc::clone(&streamed_schema), buffered_schema, streamed, buffered, @@ -787,12 +1051,19 @@ impl SMJStream { on_streamed, on_buffered, filter, - output_record_batches: vec![], + output_record_batches: JoinedRecordBatches { + batches: vec![], + filter_mask: BooleanBuilder::new(), + row_indices: UInt64Builder::new(), + batch_ids: vec![], + }, output_size: 0, batch_size, join_type, join_metrics, reservation, + runtime_env, + streamed_batch_counter: AtomicUsize::new(0), }) } @@ -824,6 +1095,10 @@ impl SMJStream { self.join_metrics.input_rows.add(batch.num_rows()); self.streamed_batch = StreamedBatch::new(batch, &self.on_streamed); + // Every incoming streaming batch should have its unique id + // Check `JoinedRecordBatches.self.streamed_batch_counter` documentation + self.streamed_batch_counter + .fetch_add(1, std::sync::atomic::Ordering::SeqCst); self.streamed_state = StreamedState::Ready; } } @@ -838,6 +1113,58 @@ impl SMJStream { } } + fn free_reservation(&mut self, buffered_batch: BufferedBatch) -> Result<()> { + // Shrink memory usage for in-memory batches only + if buffered_batch.spill_file.is_none() && buffered_batch.batch.is_some() { + self.reservation + .try_shrink(buffered_batch.size_estimation)?; + } + + Ok(()) + } + + fn allocate_reservation(&mut self, mut buffered_batch: BufferedBatch) -> Result<()> { + match self.reservation.try_grow(buffered_batch.size_estimation) { + Ok(_) => { + self.join_metrics + .peak_mem_used + .set_max(self.reservation.size()); + Ok(()) + } + Err(_) if self.runtime_env.disk_manager.tmp_files_enabled() => { + // spill buffered batch to disk + let spill_file = self + .runtime_env + .disk_manager + .create_tmp_file("sort_merge_join_buffered_spill")?; + + if let Some(batch) = buffered_batch.batch { + spill_record_batches( + vec![batch], + spill_file.path().into(), + Arc::clone(&self.buffered_schema), + )?; + buffered_batch.spill_file = Some(spill_file); + buffered_batch.batch = None; + + // update metrics to register spill + self.join_metrics.spill_count.add(1); + self.join_metrics + .spilled_bytes + .add(buffered_batch.size_estimation); + self.join_metrics.spilled_rows.add(buffered_batch.num_rows); + Ok(()) + } else { + internal_err!("Buffered batch has empty body") + } + } + Err(e) => exec_err!("{}. Disk spilling disabled.", e.message()), + }?; + + self.buffered_data.batches.push_back(buffered_batch); + Ok(()) + } + /// Poll next buffered batches fn poll_buffered_batches(&mut self, cx: &mut Context) -> Poll>> { loop { @@ -846,14 +1173,17 @@ impl SMJStream { // pop previous buffered batches while !self.buffered_data.batches.is_empty() { let head_batch = self.buffered_data.head_batch(); - if head_batch.range.end == head_batch.batch.num_rows() { + // If the head batch is fully processed, dequeue it and produce output of it. + if head_batch.range.end == head_batch.num_rows { self.freeze_dequeuing_buffered()?; if let Some(buffered_batch) = self.buffered_data.batches.pop_front() { - self.reservation.shrink(buffered_batch.size_estimation); + self.free_reservation(buffered_batch)?; } } else { + // If the head batch is not fully processed, break the loop. + // Streamed batch will be joined with the head batch in the next step. break; } } @@ -877,25 +1207,22 @@ impl SMJStream { Poll::Ready(Some(batch)) => { self.join_metrics.input_batches.add(1); self.join_metrics.input_rows.add(batch.num_rows()); + if batch.num_rows() > 0 { let buffered_batch = BufferedBatch::new(batch, 0..1, &self.on_buffered); - self.reservation.try_grow(buffered_batch.size_estimation)?; - self.join_metrics - .peak_mem_used - .set_max(self.reservation.size()); - self.buffered_data.batches.push_back(buffered_batch); + self.allocate_reservation(buffered_batch)?; self.buffered_state = BufferedState::PollingRest; } } }, BufferedState::PollingRest => { if self.buffered_data.tail_batch().range.end - < self.buffered_data.tail_batch().batch.num_rows() + < self.buffered_data.tail_batch().num_rows { while self.buffered_data.tail_batch().range.end - < self.buffered_data.tail_batch().batch.num_rows() + < self.buffered_data.tail_batch().num_rows { if is_join_arrays_equal( &self.buffered_data.head_batch().join_arrays, @@ -918,6 +1245,7 @@ impl SMJStream { self.buffered_state = BufferedState::Ready; } Poll::Ready(Some(batch)) => { + // Polling batches coming concurrently as multiple partitions self.join_metrics.input_batches.add(1); self.join_metrics.input_rows.add(batch.num_rows()); if batch.num_rows() > 0 { @@ -926,12 +1254,7 @@ impl SMJStream { 0..0, &self.on_buffered, ); - self.reservation - .try_grow(buffered_batch.size_estimation)?; - self.join_metrics - .peak_mem_used - .set_max(self.reservation.size()); - self.buffered_data.batches.push_back(buffered_batch); + self.allocate_reservation(buffered_batch)?; } } } @@ -956,14 +1279,14 @@ impl SMJStream { return Ok(Ordering::Less); } - return compare_join_arrays( + compare_join_arrays( &self.streamed_batch.join_arrays, self.streamed_batch.idx, &self.buffered_data.head_batch().join_arrays, self.buffered_data.head_batch().range.start, &self.sort_options, self.null_equals_null, - ); + ) } /// Produce join and fill output buffer until reaching target batch size @@ -973,6 +1296,8 @@ impl SMJStream { let mut join_streamed = false; // Whether to join buffered rows let mut join_buffered = false; + // For Mark join we store a dummy id to indicate the the row has a match + let mut mark_row_as_match = false; // determine whether we need to join streamed/buffered rows match self.current_ordering { @@ -984,13 +1309,30 @@ impl SMJStream { | JoinType::RightSemi | JoinType::Full | JoinType::LeftAnti + | JoinType::LeftMark ) { join_streamed = !self.streamed_joined; } } Ordering::Equal => { - if matches!(self.join_type, JoinType::LeftSemi) { - join_streamed = !self.streamed_joined; + if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftMark) { + mark_row_as_match = matches!(self.join_type, JoinType::LeftMark); + // if the join filter is specified then its needed to output the streamed index + // only if it has not been emitted before + // the `join_filter_matched_idxs` keeps track on if streamed index has a successful + // filter match and prevents the same index to go into output more than once + if self.filter.is_some() { + join_streamed = !self + .streamed_batch + .join_filter_matched_idxs + .contains(&(self.streamed_batch.idx as u64)) + && !self.streamed_joined; + // if the join filter specified there can be references to buffered columns + // so buffered columns are needed to access them + join_buffered = join_streamed; + } else { + join_streamed = !self.streamed_joined; + } } if matches!( self.join_type, @@ -999,6 +1341,11 @@ impl SMJStream { join_streamed = true; join_buffered = true; }; + + if matches!(self.join_type, JoinType::LeftAnti) && self.filter.is_some() { + join_streamed = !self.streamed_joined; + join_buffered = join_streamed; + } } Ordering::Greater => { if matches!(self.join_type, JoinType::Full) { @@ -1025,7 +1372,7 @@ impl SMJStream { Some(scanning_idx), ); } else { - // Join nulls and buffered row + // Join nulls and buffered row for FULL join self.buffered_data .scanning_batch_mut() .null_joined @@ -1046,9 +1393,11 @@ impl SMJStream { } else { Some(self.buffered_data.scanning_batch_idx) }; + // For Mark join we store a dummy id to indicate the the row has a match + let scanning_idx = mark_row_as_match.then_some(0); self.streamed_batch - .append_output_pair(scanning_batch_idx, None); + .append_output_pair(scanning_batch_idx, scanning_idx); self.output_size += 1; self.buffered_data.scanning_finish(); self.streamed_joined = true; @@ -1058,7 +1407,7 @@ impl SMJStream { fn freeze_all(&mut self) -> Result<()> { self.freeze_streamed()?; - self.freeze_buffered(self.buffered_data.batches.len())?; + self.freeze_buffered(self.buffered_data.batches.len(), false)?; Ok(()) } @@ -1068,7 +1417,8 @@ impl SMJStream { // 2. freezes NULLs joined to dequeued buffered batch to "release" it fn freeze_dequeuing_buffered(&mut self) -> Result<()> { self.freeze_streamed()?; - self.freeze_buffered(1)?; + // Only freeze and produce the first batch in buffered_data as the batch is fully processed + self.freeze_buffered(1, true)?; Ok(()) } @@ -1076,7 +1426,14 @@ impl SMJStream { // NULLs on streamed side. // // Applicable only in case of Full join. - fn freeze_buffered(&mut self, batch_count: usize) -> Result<()> { + // + // If `output_not_matched_filter` is true, this will also produce record batches + // for buffered rows which are joined with streamed side but don't match join filter. + fn freeze_buffered( + &mut self, + batch_count: usize, + output_not_matched_filter: bool, + ) -> Result<()> { if !matches!(self.join_type, JoinType::Full) { return Ok(()); } @@ -1084,33 +1441,39 @@ impl SMJStream { let buffered_indices = UInt64Array::from_iter_values( buffered_batch.null_joined.iter().map(|&index| index as u64), ); - if buffered_indices.is_empty() { - continue; + if let Some(record_batch) = produce_buffered_null_batch( + &self.schema, + &self.streamed_schema, + &buffered_indices, + buffered_batch, + )? { + self.output_record_batches.batches.push(record_batch); } buffered_batch.null_joined.clear(); - // Take buffered (right) columns - let buffered_columns = buffered_batch - .batch - .columns() - .iter() - .map(|column| take(column, &buffered_indices, None)) - .collect::, ArrowError>>() - .map_err(Into::::into)?; - - // Create null streamed (left) columns - let mut streamed_columns = self - .streamed_schema - .fields() - .iter() - .map(|f| new_null_array(f.data_type(), buffered_indices.len())) - .collect::>(); + // For buffered row which is joined with streamed side rows but all joined rows + // don't satisfy the join filter + if output_not_matched_filter { + let not_matched_buffered_indices = buffered_batch + .join_filter_failed_map + .iter() + .filter_map(|(idx, failed)| if *failed { Some(*idx) } else { None }) + .collect::>(); - streamed_columns.extend(buffered_columns); - let columns = streamed_columns; + let buffered_indices = UInt64Array::from_iter_values( + not_matched_buffered_indices.iter().copied(), + ); - self.output_record_batches - .push(RecordBatch::try_new(self.schema.clone(), columns)?); + if let Some(record_batch) = produce_buffered_null_batch( + &self.schema, + &self.streamed_schema, + &buffered_indices, + buffered_batch, + )? { + self.output_record_batches.batches.push(record_batch); + } + buffered_batch.join_filter_failed_map.clear(); + } } Ok(()) } @@ -1119,6 +1482,7 @@ impl SMJStream { // for current streamed batch and clears staged output indices. fn freeze_streamed(&mut self) -> Result<()> { for chunk in self.streamed_batch.output_indices.iter_mut() { + // The row indices of joined streamed batch let streamed_indices = chunk.streamed_indices.finish(); if streamed_indices.is_empty() { @@ -1133,53 +1497,65 @@ impl SMJStream { .map(|column| take(column, &streamed_indices, None)) .collect::, ArrowError>>()?; + // The row indices of joined buffered batch let buffered_indices: UInt64Array = chunk.buffered_indices.finish(); - - let mut buffered_columns = - if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { - vec![] - } else if let Some(buffered_idx) = chunk.buffered_batch_idx { - self.buffered_data.batches[buffered_idx] - .batch - .columns() - .iter() - .map(|column| take(column, &buffered_indices, None)) - .collect::, ArrowError>>()? - } else { - self.buffered_schema - .fields() - .iter() - .map(|f| new_null_array(f.data_type(), buffered_indices.len())) - .collect::>() - }; + let mut buffered_columns = if matches!(self.join_type, JoinType::LeftMark) { + vec![Arc::new(is_not_null(&buffered_indices)?) as ArrayRef] + } else if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { + vec![] + } else if let Some(buffered_idx) = chunk.buffered_batch_idx { + get_buffered_columns( + &self.buffered_data, + buffered_idx, + &buffered_indices, + )? + } else { + // If buffered batch none, meaning it is null joined batch. + // We need to create null arrays for buffered columns to join with streamed rows. + create_unmatched_columns( + self.join_type, + &self.buffered_schema, + buffered_indices.len(), + ) + }; let streamed_columns_length = streamed_columns.len(); - let buffered_columns_length = buffered_columns.len(); // Prepare the columns we apply join filter on later. // Only for joined rows between streamed and buffered. let filter_columns = if chunk.buffered_batch_idx.is_some() { if matches!(self.join_type, JoinType::Right) { get_filter_column(&self.filter, &buffered_columns, &streamed_columns) + } else if matches!( + self.join_type, + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark + ) { + // unwrap is safe here as we check is_some on top of if statement + let buffered_columns = get_buffered_columns( + &self.buffered_data, + chunk.buffered_batch_idx.unwrap(), + &buffered_indices, + )?; + + get_filter_column(&self.filter, &streamed_columns, &buffered_columns) } else { get_filter_column(&self.filter, &streamed_columns, &buffered_columns) } } else { - // This chunk is for null joined rows (outer join), we don't need to apply join filter. + // This chunk is totally for null joined rows (outer join), we don't need to apply join filter. + // Any join filter applied only on either streamed or buffered side will be pushed already. vec![] }; let columns = if matches!(self.join_type, JoinType::Right) { - buffered_columns.extend(streamed_columns.clone()); + buffered_columns.extend(streamed_columns); buffered_columns } else { streamed_columns.extend(buffered_columns); streamed_columns }; - let output_batch = - RecordBatch::try_new(self.schema.clone(), columns.clone())?; - + let output_batch = RecordBatch::try_new(Arc::clone(&self.schema), columns)?; // Apply join filter if any if !filter_columns.is_empty() { if let Some(f) = &self.filter { @@ -1194,34 +1570,69 @@ impl SMJStream { .evaluate(&filter_batch)? .into_array(filter_batch.num_rows())?; - // The selection mask of the filter - let mask = datafusion_common::cast::as_boolean_array(&filter_result)?; + // The boolean selection mask of the join filter result + let pre_mask = + datafusion_common::cast::as_boolean_array(&filter_result)?; - // Push the filtered batch to the output - let filtered_batch = - compute::filter_record_batch(&output_batch, mask)?; - self.output_record_batches.push(filtered_batch); + // If there are nulls in join filter result, exclude them from selecting + // the rows to output. + let mask = if pre_mask.null_count() > 0 { + compute::prep_null_mask_filter( + datafusion_common::cast::as_boolean_array(&filter_result)?, + ) + } else { + pre_mask.clone() + }; - // For outer joins, we need to push the null joined rows to the output. + // Push the filtered batch which contains rows passing join filter to the output if matches!( self.join_type, - JoinType::Left | JoinType::Right | JoinType::Full + JoinType::Left + | JoinType::LeftSemi + | JoinType::Right + | JoinType::LeftAnti + | JoinType::LeftMark ) { - // The reverse of the selection mask. For the rows not pass join filter above, - // we need to join them (left or right) with null rows for outer joins. + self.output_record_batches + .batches + .push(output_batch.clone()); + } else { + let filtered_batch = filter_record_batch(&output_batch, &mask)?; + self.output_record_batches.batches.push(filtered_batch); + } + + self.output_record_batches.filter_mask.extend(&mask); + self.output_record_batches + .row_indices + .extend(&streamed_indices); + self.output_record_batches.batch_ids.extend(vec![ + self.streamed_batch_counter.load(Relaxed); + streamed_indices.len() + ]); + + // For outer joins, we need to push the null joined rows to the output if + // all joined rows are failed on the join filter. + // I.e., if all rows joined from a streamed row are failed with the join filter, + // we need to join it with nulls as buffered side. + if matches!(self.join_type, JoinType::Full) { + // We need to get the mask for row indices that the joined rows are failed + // on the join filter. I.e., for a row in streamed side, if all joined rows + // between it and all buffered rows are failed on the join filter, we need to + // output it with null columns from buffered side. For the mask here, it + // behaves like LeftAnti join. let not_mask = if mask.null_count() > 0 { // If the mask contains nulls, we need to use `prep_null_mask_filter` to // handle the nulls in the mask as false to produce rows where the mask // was null itself. - compute::not(&compute::prep_null_mask_filter(mask))? + compute::not(&compute::prep_null_mask_filter(&mask))? } else { - compute::not(mask)? + compute::not(&mask)? }; let null_joined_batch = - compute::filter_record_batch(&output_batch, ¬_mask)?; + filter_record_batch(&output_batch, ¬_mask)?; - let mut buffered_columns = self + let buffered_columns = self .buffered_schema .fields() .iter() @@ -1233,18 +1644,7 @@ impl SMJStream { }) .collect::>(); - let columns = if matches!(self.join_type, JoinType::Right) { - let streamed_columns = null_joined_batch - .columns() - .iter() - .skip(buffered_columns_length) - .cloned() - .collect::>(); - - buffered_columns.extend(streamed_columns); - buffered_columns - } else { - // Left join or full outer join + let columns = { let mut streamed_columns = null_joined_batch .columns() .iter() @@ -1256,58 +1656,49 @@ impl SMJStream { streamed_columns }; + // Push the streamed/buffered batch joined nulls to the output let null_joined_streamed_batch = - RecordBatch::try_new(self.schema.clone(), columns.clone())?; - self.output_record_batches.push(null_joined_streamed_batch); - - // For full join, we also need to output the null joined rows from the buffered side + RecordBatch::try_new(Arc::clone(&self.schema), columns)?; + + self.output_record_batches + .batches + .push(null_joined_streamed_batch); + + // For full join, we also need to output the null joined rows from the buffered side. + // Usually this is done by `freeze_buffered`. However, if a buffered row is joined with + // streamed side, it won't be outputted by `freeze_buffered`. + // We need to check if a buffered row is joined with streamed side and output. + // If it is joined with streamed side, but doesn't match the join filter, + // we need to output it with nulls as streamed side. if matches!(self.join_type, JoinType::Full) { - // Handle not mask for buffered side further. - // For buffered side, we want to output the rows that are not null joined with - // the streamed side. i.e. the rows that are not null in the `buffered_indices`. - let not_mask = if let Some(nulls) = buffered_indices.nulls() { - let mask = not_mask.values() & nulls.inner(); - BooleanArray::new(mask, None) - } else { - not_mask - }; - - let null_joined_batch = - compute::filter_record_batch(&output_batch, ¬_mask)?; - - let mut streamed_columns = self - .streamed_schema - .fields() - .iter() - .map(|f| { - new_null_array( - f.data_type(), - null_joined_batch.num_rows(), - ) - }) - .collect::>(); - - let buffered_columns = null_joined_batch - .columns() - .iter() - .skip(streamed_columns_length) - .cloned() - .collect::>(); + let buffered_batch = &mut self.buffered_data.batches + [chunk.buffered_batch_idx.unwrap()]; + + for i in 0..pre_mask.len() { + // If the buffered row is not joined with streamed side, + // skip it. + if buffered_indices.is_null(i) { + continue; + } - streamed_columns.extend(buffered_columns); + let buffered_index = buffered_indices.value(i); - let null_joined_buffered_batch = RecordBatch::try_new( - self.schema.clone(), - streamed_columns, - )?; - self.output_record_batches.push(null_joined_buffered_batch); + buffered_batch.join_filter_failed_map.insert( + buffered_index, + *buffered_batch + .join_filter_failed_map + .get(&buffered_index) + .unwrap_or(&true) + && !pre_mask.value(i), + ); + } } } } else { - self.output_record_batches.push(output_batch); + self.output_record_batches.batches.push(output_batch); } } else { - self.output_record_batches.push(output_batch); + self.output_record_batches.batches.push(output_batch); } } @@ -1317,20 +1708,131 @@ impl SMJStream { } fn output_record_batch_and_reset(&mut self) -> Result { - let record_batch = concat_batches(&self.schema, &self.output_record_batches)?; + let record_batch = + concat_batches(&self.schema, &self.output_record_batches.batches)?; self.join_metrics.output_batches.add(1); self.join_metrics.output_rows.add(record_batch.num_rows()); // If join filter exists, `self.output_size` is not accurate as we don't know the exact // number of rows in the output record batch. If streamed row joined with buffered rows, // once join filter is applied, the number of output rows may be more than 1. - if record_batch.num_rows() > self.output_size { + // If `record_batch` is empty, we should reset `self.output_size` to 0. It could be happened + // when the join filter is applied and all rows are filtered out. + if record_batch.num_rows() == 0 || record_batch.num_rows() > self.output_size { self.output_size = 0; } else { self.output_size -= record_batch.num_rows(); } - self.output_record_batches.clear(); + + if !(self.filter.is_some() + && matches!( + self.join_type, + JoinType::Left + | JoinType::LeftSemi + | JoinType::Right + | JoinType::LeftAnti + | JoinType::LeftMark + )) + { + self.output_record_batches.batches.clear(); + } Ok(record_batch) } + + fn filter_joined_batch(&mut self) -> Result { + let record_batch = self.output_record_batch_and_reset()?; + let out_indices = self.output_record_batches.row_indices.finish(); + let out_mask = self.output_record_batches.filter_mask.finish(); + let maybe_corrected_mask = get_corrected_filter_mask( + self.join_type, + &out_indices, + &self.output_record_batches.batch_ids, + &out_mask, + record_batch.num_rows(), + ); + + let corrected_mask = if let Some(ref filtered_join_mask) = maybe_corrected_mask { + filtered_join_mask + } else { + &out_mask + }; + + let mut filtered_record_batch = + filter_record_batch(&record_batch, corrected_mask)?; + let buffered_columns_length = self.buffered_schema.fields.len(); + let streamed_columns_length = self.streamed_schema.fields.len(); + + if matches!( + self.join_type, + JoinType::Left | JoinType::LeftMark | JoinType::Right + ) { + let null_mask = compute::not(corrected_mask)?; + let null_joined_batch = filter_record_batch(&record_batch, &null_mask)?; + + let mut buffered_columns = create_unmatched_columns( + self.join_type, + &self.buffered_schema, + null_joined_batch.num_rows(), + ); + + let columns = if matches!(self.join_type, JoinType::Right) { + let streamed_columns = null_joined_batch + .columns() + .iter() + .skip(buffered_columns_length) + .cloned() + .collect::>(); + + buffered_columns.extend(streamed_columns); + buffered_columns + } else { + // Left join or full outer join + let mut streamed_columns = null_joined_batch + .columns() + .iter() + .take(streamed_columns_length) + .cloned() + .collect::>(); + + streamed_columns.extend(buffered_columns); + streamed_columns + }; + + // Push the streamed/buffered batch joined nulls to the output + let null_joined_streamed_batch = + RecordBatch::try_new(Arc::clone(&self.schema), columns)?; + + filtered_record_batch = concat_batches( + &self.schema, + &[filtered_record_batch, null_joined_streamed_batch], + )?; + } else if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { + let output_column_indices = (0..streamed_columns_length).collect::>(); + filtered_record_batch = + filtered_record_batch.project(&output_column_indices)?; + } + + self.output_record_batches.batches.clear(); + self.output_record_batches.batch_ids = vec![]; + self.output_record_batches.filter_mask = BooleanBuilder::new(); + self.output_record_batches.row_indices = UInt64Builder::new(); + Ok(filtered_record_batch) + } +} + +fn create_unmatched_columns( + join_type: JoinType, + schema: &SchemaRef, + size: usize, +) -> Vec { + if matches!(join_type, JoinType::LeftMark) { + vec![Arc::new(BooleanArray::from(vec![false; size])) as ArrayRef] + } else { + schema + .fields() + .iter() + .map(|f| new_null_array(f.data_type(), size)) + .collect::>() + } } /// Gets the arrays which join filters are applied on. @@ -1346,14 +1848,14 @@ fn get_filter_column( .column_indices() .iter() .filter(|col_index| col_index.side == JoinSide::Left) - .map(|i| streamed_columns[i.index].clone()) + .map(|i| Arc::clone(&streamed_columns[i.index])) .collect::>(); let right_columns = f .column_indices() .iter() .filter(|col_index| col_index.side == JoinSide::Right) - .map(|i| buffered_columns[i.index].clone()) + .map(|i| Arc::clone(&buffered_columns[i.index])) .collect::>(); filter_columns.extend(left_columns); @@ -1363,6 +1865,82 @@ fn get_filter_column( filter_columns } +fn produce_buffered_null_batch( + schema: &SchemaRef, + streamed_schema: &SchemaRef, + buffered_indices: &PrimitiveArray, + buffered_batch: &BufferedBatch, +) -> Result> { + if buffered_indices.is_empty() { + return Ok(None); + } + + // Take buffered (right) columns + let buffered_columns = + get_buffered_columns_from_batch(buffered_batch, buffered_indices)?; + + // Create null streamed (left) columns + let mut streamed_columns = streamed_schema + .fields() + .iter() + .map(|f| new_null_array(f.data_type(), buffered_indices.len())) + .collect::>(); + + streamed_columns.extend(buffered_columns); + + Ok(Some(RecordBatch::try_new( + Arc::clone(schema), + streamed_columns, + )?)) +} + +/// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]` +#[inline(always)] +fn get_buffered_columns( + buffered_data: &BufferedData, + buffered_batch_idx: usize, + buffered_indices: &UInt64Array, +) -> Result> { + get_buffered_columns_from_batch( + &buffered_data.batches[buffered_batch_idx], + buffered_indices, + ) +} + +#[inline(always)] +fn get_buffered_columns_from_batch( + buffered_batch: &BufferedBatch, + buffered_indices: &UInt64Array, +) -> Result> { + match (&buffered_batch.spill_file, &buffered_batch.batch) { + // In memory batch + (None, Some(batch)) => Ok(batch + .columns() + .iter() + .map(|column| take(column, &buffered_indices, None)) + .collect::, ArrowError>>() + .map_err(Into::::into)?), + // If the batch was spilled to disk, less likely + (Some(spill_file), None) => { + let mut buffered_cols: Vec = + Vec::with_capacity(buffered_indices.len()); + + let file = BufReader::new(File::open(spill_file.path())?); + let reader = FileReader::try_new(file, None)?; + + for batch in reader { + batch?.columns().iter().for_each(|column| { + buffered_cols.extend(take(column, &buffered_indices, None)) + }); + } + + Ok(buffered_cols) + } + // Invalid combination + (spill, batch) => internal_err!("Unexpected buffered batch spill status. Spill exists: {}. In-memory exists: {}", spill.is_some(), batch.is_some()), + } +} + /// Buffered data contains all buffered batches with one unique join key #[derive(Debug, Default)] struct BufferedData { @@ -1517,9 +2095,10 @@ fn compare_join_arrays( }, DataType::Date32 => compare_value!(Date32Array), DataType::Date64 => compare_value!(Date64Array), - _ => { + dt => { return not_impl_err!( - "Unsupported data type in sort merge join comparator" + "Unsupported data type in sort merge join comparator: {}", + dt ); } } @@ -1583,9 +2162,10 @@ fn is_join_arrays_equal( }, DataType::Date32 => compare_value!(Date32Array), DataType::Date64 => compare_value!(Date64Array), - _ => { + dt => { return not_impl_err!( - "Unsupported data type in sort merge join comparator" + "Unsupported data type in sort merge join comparator: {}", + dt ); } } @@ -1600,24 +2180,30 @@ fn is_join_arrays_equal( mod tests { use std::sync::Arc; - use crate::expressions::Column; - use crate::joins::utils::JoinOn; - use crate::joins::SortMergeJoinExec; - use crate::memory::MemoryExec; - use crate::test::build_table_i32; - use crate::{common, ExecutionPlan}; - use arrow::array::{Date32Array, Date64Array, Int32Array}; - use arrow::compute::SortOptions; + use arrow::compute::{concat_batches, filter_record_batch, SortOptions}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; + use arrow_array::builder::{BooleanBuilder, UInt64Builder}; + use arrow_array::{BooleanArray, UInt64Array}; + + use datafusion_common::JoinType::*; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, JoinType, Result, }; use datafusion_execution::config::SessionConfig; - use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use datafusion_execution::disk_manager::DiskManagerConfig; + use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_execution::TaskContext; + use crate::expressions::Column; + use crate::joins::sort_merge_join::{get_corrected_filter_mask, JoinedRecordBatches}; + use crate::joins::utils::JoinOn; + use crate::joins::SortMergeJoinExec; + use crate::memory::MemoryExec; + use crate::test::build_table_i32; + use crate::{common, ExecutionPlan}; + fn build_table( a: (&str, &Vec), b: (&str, &Vec), @@ -1695,7 +2281,7 @@ mod tests { Field::new(c.0, DataType::Int32, true), ])); let batch = RecordBatch::try_new( - schema.clone(), + Arc::clone(&schema), vec![ Arc::new(Int32Array::from(a.1.clone())), Arc::new(Int32Array::from(b.1.clone())), @@ -1804,7 +2390,7 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let (_, batches) = join_collect(left, right, on, Inner).await?; let expected = [ "+----+----+----+----+----+----+", @@ -1843,7 +2429,7 @@ mod tests { ), ]; - let (_columns, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let (_columns, batches) = join_collect(left, right, on, Inner).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b2 | c1 | a1 | b2 | c2 |", @@ -1881,7 +2467,7 @@ mod tests { ), ]; - let (_columns, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let (_columns, batches) = join_collect(left, right, on, Inner).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b2 | c1 | a1 | b2 | c2 |", @@ -1920,7 +2506,7 @@ mod tests { ), ]; - let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let (_, batches) = join_collect(left, right, on, Inner).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b2 | c1 | a1 | b2 | c2 |", @@ -1961,7 +2547,7 @@ mod tests { left, right, on, - JoinType::Inner, + Inner, vec![ SortOptions { descending: true, @@ -2011,7 +2597,7 @@ mod tests { ]; let (_, batches) = - join_collect_batch_size_equals_two(left, right, on, JoinType::Inner).await?; + join_collect_batch_size_equals_two(left, right, on, Inner).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b2 | c1 | a1 | b2 | c2 |", @@ -2046,7 +2632,7 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Left).await?; + let (_, batches) = join_collect(left, right, on, Left).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b1 | c2 |", @@ -2078,7 +2664,7 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Right).await?; + let (_, batches) = join_collect(left, right, on, Right).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b1 | c2 |", @@ -2110,7 +2696,7 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap()) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Full).await?; + let (_, batches) = join_collect(left, right, on, Full).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", @@ -2142,7 +2728,7 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::LeftAnti).await?; + let (_, batches) = join_collect(left, right, on, LeftAnti).await?; let expected = [ "+----+----+----+", "| a1 | b1 | c1 |", @@ -2173,7 +2759,7 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::LeftSemi).await?; + let (_, batches) = join_collect(left, right, on, LeftSemi).await?; let expected = [ "+----+----+----+", "| a1 | b1 | c1 |", @@ -2188,6 +2774,39 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_left_mark() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30, 40]), + ("b1", &vec![4, 4, 5, 6]), // 5 is double on the right + ("c2", &vec![60, 70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, + )]; + + let (_, batches) = join_collect(left, right, on, LeftMark).await?; + let expected = [ + "+----+----+----+-------+", + "| a1 | b1 | c1 | mark |", + "+----+----+----+-------+", + "| 1 | 4 | 7 | true |", + "| 2 | 5 | 8 | true |", + "| 2 | 5 | 8 | true |", + "| 3 | 7 | 9 | false |", + "+----+----+----+-------+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_eq!(expected, &batches); + Ok(()) + } + #[tokio::test] async fn join_with_duplicated_column_names() -> Result<()> { let left = build_table( @@ -2206,7 +2825,7 @@ mod tests { Arc::new(Column::new_with_schema("b", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let (_, batches) = join_collect(left, right, on, Inner).await?; let expected = [ "+---+---+---+----+---+----+", "| a | b | c | a | b | c |", @@ -2238,7 +2857,7 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let (_, batches) = join_collect(left, right, on, Inner).await?; let expected = ["+------------+------------+------------+------------+------------+------------+", "| a1 | b1 | c1 | a2 | b1 | c2 |", @@ -2270,7 +2889,7 @@ mod tests { Arc::new(Column::new_with_schema("b1", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Inner).await?; + let (_, batches) = join_collect(left, right, on, Inner).await?; let expected = ["+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+", "| a1 | b1 | c1 | a2 | b1 | c2 |", @@ -2301,7 +2920,7 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Left).await?; + let (_, batches) = join_collect(left, right, on, Left).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", @@ -2337,7 +2956,7 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Right).await?; + let (_, batches) = join_collect(left, right, on, Right).await?; let expected = [ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", @@ -2381,7 +3000,7 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Left).await?; + let (_, batches) = join_collect(left, right, on, Left).await?; let expected = vec![ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", @@ -2430,7 +3049,7 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Right).await?; + let (_, batches) = join_collect(left, right, on, Right).await?; let expected = vec![ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", @@ -2479,7 +3098,7 @@ mod tests { Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, )]; - let (_, batches) = join_collect(left, right, on, JoinType::Full).await?; + let (_, batches) = join_collect(left, right, on, Full).await?; let expected = vec![ "+----+----+----+----+----+----+", "| a1 | b1 | c1 | a2 | b2 | c2 |", @@ -2502,7 +3121,7 @@ mod tests { } #[tokio::test] - async fn overallocation_single_batch() -> Result<()> { + async fn overallocation_single_batch_no_spill() -> Result<()> { let left = build_table( ("a1", &vec![0, 1, 2, 3, 4, 5]), ("b1", &vec![1, 2, 3, 4, 5, 6]), @@ -2519,28 +3138,24 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = vec![ - JoinType::Inner, - JoinType::Left, - JoinType::Right, - JoinType::Full, - JoinType::LeftSemi, - JoinType::LeftAnti, - ]; + let join_types = vec![Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark]; - for join_type in join_types { - let runtime_config = RuntimeConfig::new().with_memory_limit(100, 1.0); - let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); - let session_config = SessionConfig::default().with_batch_size(50); + // Disable DiskManager to prevent spilling + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(100, 1.0) + .with_disk_manager(DiskManagerConfig::Disabled) + .build_arc()?; + let session_config = SessionConfig::default().with_batch_size(50); + for join_type in join_types { let task_ctx = TaskContext::default() - .with_session_config(session_config) - .with_runtime(runtime); + .with_session_config(session_config.clone()) + .with_runtime(Arc::clone(&runtime)); let task_ctx = Arc::new(task_ctx); let join = join_with_options( - left.clone(), - right.clone(), + Arc::clone(&left), + Arc::clone(&right), on.clone(), join_type, sort_options.clone(), @@ -2550,18 +3165,20 @@ mod tests { let stream = join.execute(0, task_ctx)?; let err = common::collect(stream).await.unwrap_err(); - assert_contains!( - err.to_string(), - "Resources exhausted: Failed to allocate additional" - ); + assert_contains!(err.to_string(), "Failed to allocate additional"); assert_contains!(err.to_string(), "SMJStream[0]"); + assert_contains!(err.to_string(), "Disk spilling disabled"); + assert!(join.metrics().is_some()); + assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); } Ok(()) } #[tokio::test] - async fn overallocation_multi_batch() -> Result<()> { + async fn overallocation_multi_batch_no_spill() -> Result<()> { let left_batch_1 = build_table_i32( ("a1", &vec![0, 1]), ("b1", &vec![1, 1]), @@ -2599,26 +3216,23 @@ mod tests { )]; let sort_options = vec![SortOptions::default(); on.len()]; - let join_types = vec![ - JoinType::Inner, - JoinType::Left, - JoinType::Right, - JoinType::Full, - JoinType::LeftSemi, - JoinType::LeftAnti, - ]; + let join_types = vec![Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark]; + + // Disable DiskManager to prevent spilling + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(100, 1.0) + .with_disk_manager(DiskManagerConfig::Disabled) + .build_arc()?; + let session_config = SessionConfig::default().with_batch_size(50); for join_type in join_types { - let runtime_config = RuntimeConfig::new().with_memory_limit(100, 1.0); - let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); - let session_config = SessionConfig::default().with_batch_size(50); let task_ctx = TaskContext::default() - .with_session_config(session_config) - .with_runtime(runtime); + .with_session_config(session_config.clone()) + .with_runtime(Arc::clone(&runtime)); let task_ctx = Arc::new(task_ctx); let join = join_with_options( - left.clone(), - right.clone(), + Arc::clone(&left), + Arc::clone(&right), on.clone(), join_type, sort_options.clone(), @@ -2628,15 +3242,871 @@ mod tests { let stream = join.execute(0, task_ctx)?; let err = common::collect(stream).await.unwrap_err(); - assert_contains!( - err.to_string(), - "Resources exhausted: Failed to allocate additional" - ); + assert_contains!(err.to_string(), "Failed to allocate additional"); assert_contains!(err.to_string(), "SMJStream[0]"); + assert_contains!(err.to_string(), "Disk spilling disabled"); + assert!(join.metrics().is_some()); + assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); + } + + Ok(()) + } + + #[tokio::test] + async fn overallocation_single_batch_spill() -> Result<()> { + let left = build_table( + ("a1", &vec![0, 1, 2, 3, 4, 5]), + ("b1", &vec![1, 2, 3, 4, 5, 6]), + ("c1", &vec![4, 5, 6, 7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![0, 10, 20, 30, 40]), + ("b2", &vec![1, 3, 4, 6, 8]), + ("c2", &vec![50, 60, 70, 80, 90]), + ); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let join_types = [Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark]; + + // Enable DiskManager to allow spilling + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(100, 1.0) + .with_disk_manager(DiskManagerConfig::NewOs) + .build_arc()?; + + for batch_size in [1, 50] { + let session_config = SessionConfig::default().with_batch_size(batch_size); + + for join_type in &join_types { + let task_ctx = TaskContext::default() + .with_session_config(session_config.clone()) + .with_runtime(Arc::clone(&runtime)); + let task_ctx = Arc::new(task_ctx); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + *join_type, + sort_options.clone(), + false, + )?; + + let stream = join.execute(0, task_ctx)?; + let spilled_join_result = common::collect(stream).await.unwrap(); + + assert!(join.metrics().is_some()); + assert!(join.metrics().unwrap().spill_count().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0); + + // Run the test with no spill configuration as + let task_ctx_no_spill = + TaskContext::default().with_session_config(session_config.clone()); + let task_ctx_no_spill = Arc::new(task_ctx_no_spill); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + *join_type, + sort_options.clone(), + false, + )?; + let stream = join.execute(0, task_ctx_no_spill)?; + let no_spilled_join_result = common::collect(stream).await.unwrap(); + + assert!(join.metrics().is_some()); + assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); + // Compare spilled and non spilled data to check spill logic doesn't corrupt the data + assert_eq!(spilled_join_result, no_spilled_join_result); + } } Ok(()) } + + #[tokio::test] + async fn overallocation_multi_batch_spill() -> Result<()> { + let left_batch_1 = build_table_i32( + ("a1", &vec![0, 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![4, 5]), + ); + let left_batch_2 = build_table_i32( + ("a1", &vec![2, 3]), + ("b1", &vec![1, 1]), + ("c1", &vec![6, 7]), + ); + let left_batch_3 = build_table_i32( + ("a1", &vec![4, 5]), + ("b1", &vec![1, 1]), + ("c1", &vec![8, 9]), + ); + let right_batch_1 = build_table_i32( + ("a2", &vec![0, 10]), + ("b2", &vec![1, 1]), + ("c2", &vec![50, 60]), + ); + let right_batch_2 = build_table_i32( + ("a2", &vec![20, 30]), + ("b2", &vec![1, 1]), + ("c2", &vec![70, 80]), + ); + let right_batch_3 = + build_table_i32(("a2", &vec![40]), ("b2", &vec![1]), ("c2", &vec![90])); + let left = + build_table_from_batches(vec![left_batch_1, left_batch_2, left_batch_3]); + let right = + build_table_from_batches(vec![right_batch_1, right_batch_2, right_batch_3]); + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let join_types = [Inner, Left, Right, Full, LeftSemi, LeftAnti, LeftMark]; + + // Enable DiskManager to allow spilling + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(500, 1.0) + .with_disk_manager(DiskManagerConfig::NewOs) + .build_arc()?; + + for batch_size in [1, 50] { + let session_config = SessionConfig::default().with_batch_size(batch_size); + + for join_type in &join_types { + let task_ctx = TaskContext::default() + .with_session_config(session_config.clone()) + .with_runtime(Arc::clone(&runtime)); + let task_ctx = Arc::new(task_ctx); + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + *join_type, + sort_options.clone(), + false, + )?; + + let stream = join.execute(0, task_ctx)?; + let spilled_join_result = common::collect(stream).await.unwrap(); + assert!(join.metrics().is_some()); + assert!(join.metrics().unwrap().spill_count().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_bytes().unwrap() > 0); + assert!(join.metrics().unwrap().spilled_rows().unwrap() > 0); + + // Run the test with no spill configuration as + let task_ctx_no_spill = + TaskContext::default().with_session_config(session_config.clone()); + let task_ctx_no_spill = Arc::new(task_ctx_no_spill); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + *join_type, + sort_options.clone(), + false, + )?; + let stream = join.execute(0, task_ctx_no_spill)?; + let no_spilled_join_result = common::collect(stream).await.unwrap(); + + assert!(join.metrics().is_some()); + assert_eq!(join.metrics().unwrap().spill_count(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_bytes(), Some(0)); + assert_eq!(join.metrics().unwrap().spilled_rows(), Some(0)); + // Compare spilled and non spilled data to check spill logic doesn't corrupt the data + assert_eq!(spilled_join_result, no_spilled_join_result); + } + } + + Ok(()) + } + + fn build_joined_record_batches() -> Result { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + Field::new("x", DataType::Int32, true), + Field::new("y", DataType::Int32, true), + ])); + + let mut batches = JoinedRecordBatches { + batches: vec![], + filter_mask: BooleanBuilder::new(), + row_indices: UInt64Builder::new(), + batch_ids: vec![], + }; + + // Insert already prejoined non-filtered rows + batches.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![10, 10])), + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![11, 9])), + ], + )?); + + batches.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![11])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![12])), + ], + )?); + + batches.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![12, 12])), + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![11, 13])), + ], + )?); + + batches.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![13])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int32Array::from(vec![12])), + ], + )?); + + batches.batches.push(RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![14, 14])), + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int32Array::from(vec![12, 11])), + ], + )?); + + let streamed_indices = vec![0, 0]; + batches.batch_ids.extend(vec![0; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![1]; + batches.batch_ids.extend(vec![0; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![0, 0]; + batches.batch_ids.extend(vec![1; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![0]; + batches.batch_ids.extend(vec![2; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); + + let streamed_indices = vec![0, 0]; + batches.batch_ids.extend(vec![3; streamed_indices.len()]); + batches + .row_indices + .extend(&UInt64Array::from(streamed_indices)); + + batches + .filter_mask + .extend(&BooleanArray::from(vec![true, false])); + batches.filter_mask.extend(&BooleanArray::from(vec![true])); + batches + .filter_mask + .extend(&BooleanArray::from(vec![false, true])); + batches.filter_mask.extend(&BooleanArray::from(vec![false])); + batches + .filter_mask + .extend(&BooleanArray::from(vec![false, false])); + + Ok(batches) + } + + #[tokio::test] + async fn test_left_outer_join_filtered_mask() -> Result<()> { + let mut joined_batches = build_joined_record_batches()?; + let schema = joined_batches.batches.first().unwrap().schema(); + + let output = concat_batches(&schema, &joined_batches.batches)?; + let out_mask = joined_batches.filter_mask.finish(); + let out_indices = joined_batches.row_indices.finish(); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + true, false, false, false, false, false, false, false + ]) + ); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + false, false, false, false, false, false, false, false + ]) + ); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0, 0]), + &[0usize; 2], + &BooleanArray::from(vec![true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + true, true, false, false, false, false, false, false + ]) + ); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![true, true, true, false, false, false, false, false]) + ); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + Some(true), + None, + Some(true), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false) + ]) + ); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + None, + None, + Some(true), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false) + ]) + ); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + None, + Some(true), + Some(true), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false) + ]) + ); + + assert_eq!( + get_corrected_filter_mask( + Left, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![ + None, + None, + Some(false), + Some(false), + Some(false), + Some(false), + Some(false), + Some(false) + ]) + ); + + let corrected_mask = get_corrected_filter_mask( + Left, + &out_indices, + &joined_batches.batch_ids, + &out_mask, + output.num_rows(), + ) + .unwrap(); + + assert_eq!( + corrected_mask, + BooleanArray::from(vec![ + Some(true), + None, + Some(true), + None, + Some(true), + Some(false), + None, + Some(false) + ]) + ); + + let filtered_rb = filter_record_batch(&output, &corrected_mask)?; + + assert_batches_eq!( + &[ + "+---+----+---+----+", + "| a | b | x | y |", + "+---+----+---+----+", + "| 1 | 10 | 1 | 11 |", + "| 1 | 11 | 1 | 12 |", + "| 1 | 12 | 1 | 13 |", + "+---+----+---+----+", + ], + &[filtered_rb] + ); + + // output null rows + + let null_mask = arrow::compute::not(&corrected_mask)?; + assert_eq!( + null_mask, + BooleanArray::from(vec![ + Some(false), + None, + Some(false), + None, + Some(false), + Some(true), + None, + Some(true) + ]) + ); + + let null_joined_batch = filter_record_batch(&output, &null_mask)?; + + assert_batches_eq!( + &[ + "+---+----+---+----+", + "| a | b | x | y |", + "+---+----+---+----+", + "| 1 | 13 | 1 | 12 |", + "| 1 | 14 | 1 | 11 |", + "+---+----+---+----+", + ], + &[null_joined_batch] + ); + Ok(()) + } + + #[tokio::test] + async fn test_left_semi_join_filtered_mask() -> Result<()> { + let mut joined_batches = build_joined_record_batches()?; + let schema = joined_batches.batches.first().unwrap().schema(); + + let output = concat_batches(&schema, &joined_batches.batches)?; + let out_mask = joined_batches.filter_mask.finish(); + let out_indices = joined_batches.row_indices.finish(); + + assert_eq!( + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![true]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0, 0]), + &[0usize; 2], + &BooleanArray::from(vec![true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![Some(true), None]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![Some(true), None, None]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![Some(true), None, None]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None, None, Some(true),]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, true, true]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None, Some(true), None]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftSemi, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, false]), + output.num_rows() + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); + + let corrected_mask = get_corrected_filter_mask( + LeftSemi, + &out_indices, + &joined_batches.batch_ids, + &out_mask, + output.num_rows(), + ) + .unwrap(); + + assert_eq!( + corrected_mask, + BooleanArray::from(vec![ + Some(true), + None, + Some(true), + None, + Some(true), + None, + None, + None + ]) + ); + + let filtered_rb = filter_record_batch(&output, &corrected_mask)?; + + assert_batches_eq!( + &[ + "+---+----+---+----+", + "| a | b | x | y |", + "+---+----+---+----+", + "| 1 | 10 | 1 | 11 |", + "| 1 | 11 | 1 | 12 |", + "| 1 | 12 | 1 | 13 |", + "+---+----+---+----+", + ], + &[filtered_rb] + ); + + // output null rows + let null_mask = arrow::compute::not(&corrected_mask)?; + assert_eq!( + null_mask, + BooleanArray::from(vec![ + Some(false), + None, + Some(false), + None, + Some(false), + None, + None, + None + ]) + ); + + let null_joined_batch = filter_record_batch(&output, &null_mask)?; + + assert_batches_eq!( + &[ + "+---+---+---+---+", + "| a | b | x | y |", + "+---+---+---+---+", + "+---+---+---+---+", + ], + &[null_joined_batch] + ); + Ok(()) + } + + #[tokio::test] + async fn test_left_anti_join_filtered_mask() -> Result<()> { + let mut joined_batches = build_joined_record_batches()?; + let schema = joined_batches.batches.first().unwrap().schema(); + + let output = concat_batches(&schema, &joined_batches.batches)?; + let out_mask = joined_batches.filter_mask.finish(); + let out_indices = joined_batches.row_indices.finish(); + + assert_eq!( + get_corrected_filter_mask( + LeftAnti, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![true]), + 1 + ) + .unwrap(), + BooleanArray::from(vec![None]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftAnti, + &UInt64Array::from(vec![0]), + &[0usize], + &BooleanArray::from(vec![false]), + 1 + ) + .unwrap(), + BooleanArray::from(vec![Some(true)]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftAnti, + &UInt64Array::from(vec![0, 0]), + &[0usize; 2], + &BooleanArray::from(vec![true, true]), + 2 + ) + .unwrap(), + BooleanArray::from(vec![None, None]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftAnti, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, true, true]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftAnti, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![true, false, true]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftAnti, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, true]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftAnti, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, true, true]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, None]) + ); + + assert_eq!( + get_corrected_filter_mask( + LeftAnti, + &UInt64Array::from(vec![0, 0, 0]), + &[0usize; 3], + &BooleanArray::from(vec![false, false, false]), + 3 + ) + .unwrap(), + BooleanArray::from(vec![None, None, Some(true)]) + ); + + let corrected_mask = get_corrected_filter_mask( + LeftAnti, + &out_indices, + &joined_batches.batch_ids, + &out_mask, + output.num_rows(), + ) + .unwrap(); + + assert_eq!( + corrected_mask, + BooleanArray::from(vec![ + None, + None, + None, + None, + None, + Some(true), + None, + Some(true) + ]) + ); + + let filtered_rb = filter_record_batch(&output, &corrected_mask)?; + + assert_batches_eq!( + &[ + "+---+----+---+----+", + "| a | b | x | y |", + "+---+----+---+----+", + "| 1 | 13 | 1 | 12 |", + "| 1 | 14 | 1 | 11 |", + "+---+----+---+----+", + ], + &[filtered_rb] + ); + + // output null rows + let null_mask = arrow::compute::not(&corrected_mask)?; + assert_eq!( + null_mask, + BooleanArray::from(vec![ + None, + None, + None, + None, + None, + Some(false), + None, + Some(false), + ]) + ); + + let null_joined_batch = filter_record_batch(&output, &null_mask)?; + + assert_batches_eq!( + &[ + "+---+---+---+---+", + "| a | b | x | y |", + "+---+---+---+---+", + "+---+---+---+---+", + ], + &[null_joined_batch] + ); + Ok(()) + } + /// Returns the column names on the schema fn columns(schema: &Schema) -> Vec { schema.fields().iter().map(|f| f.name().clone()).collect() diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index f19eb30313e6..5ccdd9b40dee 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -19,13 +19,12 @@ //! related functionality, used both in join calculations and optimization rules. use std::collections::{HashMap, VecDeque}; +use std::mem::size_of; use std::sync::Arc; -use std::task::{Context, Poll}; -use std::usize; -use crate::joins::utils::{JoinFilter, JoinHashMapType, StatefulStreamResult}; +use crate::joins::utils::{JoinFilter, JoinHashMapType}; use crate::metrics::{ExecutionPlanMetricsSet, MetricBuilder}; -use crate::{handle_async_state, handle_state, metrics, ExecutionPlan}; +use crate::{metrics, ExecutionPlan}; use arrow::compute::concat_batches; use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, RecordBatch}; @@ -33,18 +32,15 @@ use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder}; use arrow_schema::{Schema, SchemaRef}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{ - arrow_datafusion_err, plan_datafusion_err, DataFusionError, JoinSide, Result, - ScalarValue, + arrow_datafusion_err, DataFusionError, JoinSide, Result, ScalarValue, }; -use datafusion_execution::SendableRecordBatchStream; use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; -use async_trait::async_trait; -use futures::{ready, FutureExt, StreamExt}; +use datafusion_physical_expr_common::sort_expr::LexOrderingRef; use hashbrown::raw::RawTable; use hashbrown::HashSet; @@ -159,8 +155,7 @@ impl PruningJoinHashMap { /// # Returns /// The size of the hash map in bytes. pub(crate) fn size(&self) -> usize { - self.map.allocation_info().1.size() - + self.next.capacity() * std::mem::size_of::() + self.map.allocation_info().1.size() + self.next.capacity() * size_of::() } /// Removes hash values from the map and the list based on the given pruning @@ -246,7 +241,7 @@ pub fn map_origin_col_to_filter_col( Ok(col_to_col_map) } -/// This function analyzes [`PhysicalSortExpr`] graphs with respect to monotonicity +/// This function analyzes [`PhysicalSortExpr`] graphs with respect to output orderings /// (sorting) properties. This is necessary since monotonically increasing and/or /// decreasing expressions are required when using join filter expressions for /// data pruning purposes. @@ -275,7 +270,7 @@ pub fn convert_sort_expr_with_filter_schema( sort_expr: &PhysicalSortExpr, ) -> Result>> { let column_map = map_origin_col_to_filter_col(filter, schema, side)?; - let expr = sort_expr.expr.clone(); + let expr = Arc::clone(&sort_expr.expr); // Get main schema columns: let expr_columns = collect_columns(&expr); // Calculation is possible with `column_map` since sort exprs belong to a child. @@ -374,34 +369,40 @@ impl SortedFilterExpr { filter_expr: Arc, filter_schema: &Schema, ) -> Result { - let dt = &filter_expr.data_type(filter_schema)?; + let dt = filter_expr.data_type(filter_schema)?; Ok(Self { origin_sorted_expr, filter_expr, - interval: Interval::make_unbounded(dt)?, + interval: Interval::make_unbounded(&dt)?, node_index: 0, }) } + /// Get origin expr information pub fn origin_sorted_expr(&self) -> &PhysicalSortExpr { &self.origin_sorted_expr } + /// Get filter expr information pub fn filter_expr(&self) -> &Arc { &self.filter_expr } + /// Get interval information pub fn interval(&self) -> &Interval { &self.interval } + /// Sets interval pub fn set_interval(&mut self, interval: Interval) { self.interval = interval; } + /// Node index in ExprIntervalGraph pub fn node_index(&self) -> usize { self.node_index } + /// Node index setter in ExprIntervalGraph pub fn set_node_index(&mut self, node_index: usize) { self.node_index = node_index; @@ -414,41 +415,45 @@ impl SortedFilterExpr { /// on the first or the last value of the expression in `build_input_buffer` /// and `probe_batch`. /// -/// # Arguments +/// # Parameters /// /// * `build_input_buffer` - The [RecordBatch] on the build side of the join. /// * `build_sorted_filter_expr` - Build side [SortedFilterExpr] to update. /// * `probe_batch` - The `RecordBatch` on the probe side of the join. /// * `probe_sorted_filter_expr` - Probe side `SortedFilterExpr` to update. /// -/// ### Note -/// ```text +/// ## Note /// -/// Interval arithmetic is used to calculate viable join ranges for build-side -/// pruning. This is done by first creating an interval for join filter values in -/// the build side of the join, which spans [-∞, FV] or [FV, ∞] depending on the -/// ordering (descending/ascending) of the filter expression. Here, FV denotes the -/// first value on the build side. This range is then compared with the probe side -/// interval, which either spans [-∞, LV] or [LV, ∞] depending on the ordering -/// (ascending/descending) of the probe side. Here, LV denotes the last value on -/// the probe side. +/// Utilizing interval arithmetic, this function computes feasible join intervals +/// on the pruning side by evaluating the prospective value ranges that might +/// emerge in subsequent data batches from the enforcer side. This is done by +/// first creating an interval for join filter values in the pruning side of the +/// join, which spans `[-∞, FV]` or `[FV, ∞]` depending on the ordering (descending/ +/// ascending) of the filter expression. Here, `FV` denotes the first value on the +/// pruning side. This range is then compared with the enforcer side interval, +/// which either spans `[-∞, LV]` or `[LV, ∞]` depending on the ordering (ascending/ +/// descending) of the probe side. Here, `LV` denotes the last value on the enforcer +/// side. /// /// As a concrete example, consider the following query: /// +/// ```text /// SELECT * FROM left_table, right_table /// WHERE /// left_key = right_key AND /// a > b - 3 AND /// a < b + 10 +/// ``` /// -/// where columns "a" and "b" come from tables "left_table" and "right_table", +/// where columns `a` and `b` come from tables `left_table` and `right_table`, /// respectively. When a new `RecordBatch` arrives at the right side, the -/// condition a > b - 3 will possibly indicate a prunable range for the left +/// condition `a > b - 3` will possibly indicate a prunable range for the left /// side. Conversely, when a new `RecordBatch` arrives at the left side, the -/// condition a < b + 10 will possibly indicate prunability for the right side. -/// Let’s inspect what happens when a new RecordBatch` arrives at the right +/// condition `a < b + 10` will possibly indicate prunability for the right side. +/// Let’s inspect what happens when a new `RecordBatch` arrives at the right /// side (i.e. when the left side is the build side): /// +/// ```text /// Build Probe /// +-------+ +-------+ /// | a | z | | b | y | @@ -461,13 +466,13 @@ impl SortedFilterExpr { /// |+--|--+| |+--|--+| /// | 7 | 1 | | 6 | 3 | /// +-------+ +-------+ +/// ``` /// /// In this case, the interval representing viable (i.e. joinable) values for -/// column "a" is [1, ∞], and the interval representing possible future values -/// for column "b" is [6, ∞]. With these intervals at hand, we next calculate +/// column `a` is `[1, ∞]`, and the interval representing possible future values +/// for column `b` is `[6, ∞]`. With these intervals at hand, we next calculate /// intervals for the whole filter expression and propagate join constraint by /// traversing the expression graph. -/// ``` pub fn calculate_filter_expr_intervals( build_input_buffer: &RecordBatch, build_sorted_filter_expr: &mut SortedFilterExpr, @@ -629,352 +634,6 @@ pub fn record_visited_indices( } } -/// Represents the various states of an eager join stream operation. -/// -/// This enum is used to track the current state of streaming during a join -/// operation. It provides indicators as to which side of the join needs to be -/// pulled next or if one (or both) sides have been exhausted. This allows -/// for efficient management of resources and optimal performance during the -/// join process. -#[derive(Clone, Debug)] -pub enum EagerJoinStreamState { - /// Indicates that the next step should pull from the right side of the join. - PullRight, - - /// Indicates that the next step should pull from the left side of the join. - PullLeft, - - /// State representing that the right side of the join has been fully processed. - RightExhausted, - - /// State representing that the left side of the join has been fully processed. - LeftExhausted, - - /// Represents a state where both sides of the join are exhausted. - /// - /// The `final_result` field indicates whether the join operation has - /// produced a final result or not. - BothExhausted { final_result: bool }, -} - -/// `EagerJoinStream` is an asynchronous trait designed for managing incremental -/// join operations between two streams, such as those used in `SymmetricHashJoinExec` -/// and `SortMergeJoinExec`. Unlike traditional join approaches that need to scan -/// one side of the join fully before proceeding, `EagerJoinStream` facilitates -/// more dynamic join operations by working with streams as they emit data. This -/// approach allows for more efficient processing, particularly in scenarios -/// where waiting for complete data materialization is not feasible or optimal. -/// The trait provides a framework for handling various states of such a join -/// process, ensuring that join logic is efficiently executed as data becomes -/// available from either stream. -/// -/// Implementors of this trait can perform eager joins of data from two different -/// asynchronous streams, typically referred to as left and right streams. The -/// trait provides a comprehensive set of methods to control and execute the join -/// process, leveraging the states defined in `EagerJoinStreamState`. Methods are -/// primarily focused on asynchronously fetching data batches from each stream, -/// processing them, and managing transitions between various states of the join. -/// -/// This trait's default implementations use a state machine approach to navigate -/// different stages of the join operation, handling data from both streams and -/// determining when the join completes. -/// -/// State Transitions: -/// - From `PullLeft` to `PullRight` or `LeftExhausted`: -/// - In `fetch_next_from_left_stream`, when fetching a batch from the left stream: -/// - On success (`Some(Ok(batch))`), state transitions to `PullRight` for -/// processing the batch. -/// - On error (`Some(Err(e))`), the error is returned, and the state remains -/// unchanged. -/// - On no data (`None`), state changes to `LeftExhausted`, returning `Continue` -/// to proceed with the join process. -/// - From `PullRight` to `PullLeft` or `RightExhausted`: -/// - In `fetch_next_from_right_stream`, when fetching from the right stream: -/// - If a batch is available, state changes to `PullLeft` for processing. -/// - On error, the error is returned without changing the state. -/// - If right stream is exhausted (`None`), state transitions to `RightExhausted`, -/// with a `Continue` result. -/// - Handling `RightExhausted` and `LeftExhausted`: -/// - Methods `handle_right_stream_end` and `handle_left_stream_end` manage scenarios -/// when streams are exhausted: -/// - They attempt to continue processing with the other stream. -/// - If both streams are exhausted, state changes to `BothExhausted { final_result: false }`. -/// - Transition to `BothExhausted { final_result: true }`: -/// - Occurs in `prepare_for_final_results_after_exhaustion` when both streams are -/// exhausted, indicating completion of processing and availability of final results. -#[async_trait] -pub trait EagerJoinStream { - /// Implements the main polling logic for the join stream. - /// - /// This method continuously checks the state of the join stream and - /// acts accordingly by delegating the handling to appropriate sub-methods - /// depending on the current state. - /// - /// # Arguments - /// - /// * `cx` - A context that facilitates cooperative non-blocking execution within a task. - /// - /// # Returns - /// - /// * `Poll>>` - A polled result, either a `RecordBatch` or None. - fn poll_next_impl( - &mut self, - cx: &mut Context<'_>, - ) -> Poll>> - where - Self: Send, - { - loop { - return match self.state() { - EagerJoinStreamState::PullRight => { - handle_async_state!(self.fetch_next_from_right_stream(), cx) - } - EagerJoinStreamState::PullLeft => { - handle_async_state!(self.fetch_next_from_left_stream(), cx) - } - EagerJoinStreamState::RightExhausted => { - handle_async_state!(self.handle_right_stream_end(), cx) - } - EagerJoinStreamState::LeftExhausted => { - handle_async_state!(self.handle_left_stream_end(), cx) - } - EagerJoinStreamState::BothExhausted { - final_result: false, - } => { - handle_state!(self.prepare_for_final_results_after_exhaustion()) - } - EagerJoinStreamState::BothExhausted { final_result: true } => { - Poll::Ready(None) - } - }; - } - } - /// Asynchronously pulls the next batch from the right stream. - /// - /// This default implementation checks for the next value in the right stream. - /// If a batch is found, the state is switched to `PullLeft`, and the batch handling - /// is delegated to `process_batch_from_right`. If the stream ends, the state is set to `RightExhausted`. - /// - /// # Returns - /// - /// * `Result>>` - The state result after pulling the batch. - async fn fetch_next_from_right_stream( - &mut self, - ) -> Result>> { - match self.right_stream().next().await { - Some(Ok(batch)) => { - if batch.num_rows() == 0 { - return Ok(StatefulStreamResult::Continue); - } - self.set_state(EagerJoinStreamState::PullLeft); - self.process_batch_from_right(batch) - } - Some(Err(e)) => Err(e), - None => { - self.set_state(EagerJoinStreamState::RightExhausted); - Ok(StatefulStreamResult::Continue) - } - } - } - - /// Asynchronously pulls the next batch from the left stream. - /// - /// This default implementation checks for the next value in the left stream. - /// If a batch is found, the state is switched to `PullRight`, and the batch handling - /// is delegated to `process_batch_from_left`. If the stream ends, the state is set to `LeftExhausted`. - /// - /// # Returns - /// - /// * `Result>>` - The state result after pulling the batch. - async fn fetch_next_from_left_stream( - &mut self, - ) -> Result>> { - match self.left_stream().next().await { - Some(Ok(batch)) => { - if batch.num_rows() == 0 { - return Ok(StatefulStreamResult::Continue); - } - self.set_state(EagerJoinStreamState::PullRight); - self.process_batch_from_left(batch) - } - Some(Err(e)) => Err(e), - None => { - self.set_state(EagerJoinStreamState::LeftExhausted); - Ok(StatefulStreamResult::Continue) - } - } - } - - /// Asynchronously handles the scenario when the right stream is exhausted. - /// - /// In this default implementation, when the right stream is exhausted, it attempts - /// to pull from the left stream. If a batch is found in the left stream, it delegates - /// the handling to `process_batch_from_left`. If both streams are exhausted, the state is set - /// to indicate both streams are exhausted without final results yet. - /// - /// # Returns - /// - /// * `Result>>` - The state result after checking the exhaustion state. - async fn handle_right_stream_end( - &mut self, - ) -> Result>> { - match self.left_stream().next().await { - Some(Ok(batch)) => { - if batch.num_rows() == 0 { - return Ok(StatefulStreamResult::Continue); - } - self.process_batch_after_right_end(batch) - } - Some(Err(e)) => Err(e), - None => { - self.set_state(EagerJoinStreamState::BothExhausted { - final_result: false, - }); - Ok(StatefulStreamResult::Continue) - } - } - } - - /// Asynchronously handles the scenario when the left stream is exhausted. - /// - /// When the left stream is exhausted, this default - /// implementation tries to pull from the right stream and delegates the batch - /// handling to `process_batch_after_left_end`. If both streams are exhausted, the state - /// is updated to indicate so. - /// - /// # Returns - /// - /// * `Result>>` - The state result after checking the exhaustion state. - async fn handle_left_stream_end( - &mut self, - ) -> Result>> { - match self.right_stream().next().await { - Some(Ok(batch)) => { - if batch.num_rows() == 0 { - return Ok(StatefulStreamResult::Continue); - } - self.process_batch_after_left_end(batch) - } - Some(Err(e)) => Err(e), - None => { - self.set_state(EagerJoinStreamState::BothExhausted { - final_result: false, - }); - Ok(StatefulStreamResult::Continue) - } - } - } - - /// Handles the state when both streams are exhausted and final results are yet to be produced. - /// - /// This default implementation switches the state to indicate both streams are - /// exhausted with final results and then invokes the handling for this specific - /// scenario via `process_batches_before_finalization`. - /// - /// # Returns - /// - /// * `Result>>` - The state result after both streams are exhausted. - fn prepare_for_final_results_after_exhaustion( - &mut self, - ) -> Result>> { - self.set_state(EagerJoinStreamState::BothExhausted { final_result: true }); - self.process_batches_before_finalization() - } - - /// Handles a pulled batch from the right stream. - /// - /// # Arguments - /// - /// * `batch` - The pulled `RecordBatch` from the right stream. - /// - /// # Returns - /// - /// * `Result>>` - The state result after processing the batch. - fn process_batch_from_right( - &mut self, - batch: RecordBatch, - ) -> Result>>; - - /// Handles a pulled batch from the left stream. - /// - /// # Arguments - /// - /// * `batch` - The pulled `RecordBatch` from the left stream. - /// - /// # Returns - /// - /// * `Result>>` - The state result after processing the batch. - fn process_batch_from_left( - &mut self, - batch: RecordBatch, - ) -> Result>>; - - /// Handles the situation when only the left stream is exhausted. - /// - /// # Arguments - /// - /// * `right_batch` - The `RecordBatch` from the right stream. - /// - /// # Returns - /// - /// * `Result>>` - The state result after the left stream is exhausted. - fn process_batch_after_left_end( - &mut self, - right_batch: RecordBatch, - ) -> Result>>; - - /// Handles the situation when only the right stream is exhausted. - /// - /// # Arguments - /// - /// * `left_batch` - The `RecordBatch` from the left stream. - /// - /// # Returns - /// - /// * `Result>>` - The state result after the right stream is exhausted. - fn process_batch_after_right_end( - &mut self, - left_batch: RecordBatch, - ) -> Result>>; - - /// Handles the final state after both streams are exhausted. - /// - /// # Returns - /// - /// * `Result>>` - The final state result after processing. - fn process_batches_before_finalization( - &mut self, - ) -> Result>>; - - /// Provides mutable access to the right stream. - /// - /// # Returns - /// - /// * `&mut SendableRecordBatchStream` - Returns a mutable reference to the right stream. - fn right_stream(&mut self) -> &mut SendableRecordBatchStream; - - /// Provides mutable access to the left stream. - /// - /// # Returns - /// - /// * `&mut SendableRecordBatchStream` - Returns a mutable reference to the left stream. - fn left_stream(&mut self) -> &mut SendableRecordBatchStream; - - /// Sets the current state of the join stream. - /// - /// # Arguments - /// - /// * `state` - The new state to be set. - fn set_state(&mut self, state: EagerJoinStreamState); - - /// Fetches the current state of the join stream. - /// - /// # Returns - /// - /// * `EagerJoinStreamState` - The current state of the join stream. - fn state(&mut self) -> EagerJoinStreamState; -} - #[derive(Debug)] pub struct StreamJoinSideMetrics { /// Number of batches consumed by this operator @@ -1048,7 +707,7 @@ fn update_sorted_exprs_with_node_indices( // Extract filter expressions from the sorted expressions: let filter_exprs = sorted_exprs .iter() - .map(|expr| expr.filter_expr().clone()) + .map(|expr| Arc::clone(expr.filter_expr())) .collect::>(); // Gather corresponding node indices for the extracted filter expressions from the graph: @@ -1061,13 +720,21 @@ fn update_sorted_exprs_with_node_indices( } } -/// Prepares and sorts expressions based on a given filter, left and right execution plans, and sort expressions. +/// Prepares and sorts expressions based on a given filter, left and right schemas, +/// and sort expressions. /// -/// # Arguments +/// This function prepares sorted filter expressions for both the left and right +/// sides of a join operation. It first builds the filter order for each side +/// based on the provided `ExecutionPlan`. If both sides have valid sorted filter +/// expressions, the function then constructs an expression interval graph and +/// updates the sorted expressions with node indices. The final sorted filter +/// expressions for both sides are then returned. +/// +/// # Parameters /// /// * `filter` - The join filter to base the sorting on. -/// * `left` - The left execution plan. -/// * `right` - The right execution plan. +/// * `left` - The `ExecutionPlan` for the left side of the join. +/// * `right` - The `ExecutionPlan` for the right side of the join. /// * `left_sort_exprs` - The expressions to sort on the left side. /// * `right_sort_exprs` - The expressions to sort on the right side. /// @@ -1078,12 +745,14 @@ pub fn prepare_sorted_exprs( filter: &JoinFilter, left: &Arc, right: &Arc, - left_sort_exprs: &[PhysicalSortExpr], - right_sort_exprs: &[PhysicalSortExpr], + left_sort_exprs: LexOrderingRef, + right_sort_exprs: LexOrderingRef, ) -> Result<(SortedFilterExpr, SortedFilterExpr, ExprIntervalGraph)> { - // Build the filter order for the left side - let err = || plan_datafusion_err!("Filter does not include the child order"); + let err = || { + datafusion_common::plan_datafusion_err!("Filter does not include the child order") + }; + // Build the filter order for the left side: let left_temp_sorted_filter_expr = build_filter_input_order( JoinSide::Left, filter, @@ -1092,7 +761,7 @@ pub fn prepare_sorted_exprs( )? .ok_or_else(err)?; - // Build the filter order for the right side + // Build the filter order for the right side: let right_temp_sorted_filter_expr = build_filter_input_order( JoinSide::Right, filter, @@ -1107,7 +776,7 @@ pub fn prepare_sorted_exprs( // Build the expression interval graph let mut graph = - ExprIntervalGraph::try_new(filter.expression().clone(), filter.schema())?; + ExprIntervalGraph::try_new(Arc::clone(filter.expression()), filter.schema())?; // Update sorted expressions with node indices update_sorted_exprs_with_node_indices(&mut graph, &mut sorted_exprs); @@ -1169,9 +838,9 @@ pub mod tests { &intermediate_schema, )?; let filter_expr = binary( - filter_left.clone(), + Arc::clone(&filter_left), Operator::Gt, - filter_right.clone(), + Arc::clone(&filter_right), &intermediate_schema, )?; let column_indices = vec![ @@ -1303,15 +972,15 @@ pub mod tests { let filter_expr = complicated_filter(&intermediate_schema)?; let column_indices = vec![ ColumnIndex { - index: 0, + index: left_schema.index_of("la1")?, side: JoinSide::Left, }, ColumnIndex { - index: 4, + index: left_schema.index_of("la2")?, side: JoinSide::Left, }, ColumnIndex { - index: 0, + index: right_schema.index_of("ra1")?, side: JoinSide::Right, }, ]; diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 9d48c2a7d408..5b6dc2cd2ae9 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -27,9 +27,10 @@ use std::any::Any; use std::fmt::{self, Debug}; +use std::mem::{size_of, size_of_val}; use std::sync::Arc; -use std::task::Poll; -use std::{usize, vec}; +use std::task::{Context, Poll}; +use std::vec; use crate::common::SharedMemoryReservation; use crate::joins::hash_join::{equal_rows_arr, update_hash}; @@ -37,17 +38,16 @@ use crate::joins::stream_join_utils::{ calculate_filter_expr_intervals, combine_two_batches, convert_sort_expr_with_filter_schema, get_pruning_anti_indices, get_pruning_semi_indices, prepare_sorted_exprs, record_visited_indices, - EagerJoinStream, EagerJoinStreamState, PruningJoinHashMap, SortedFilterExpr, - StreamJoinMetrics, + PruningJoinHashMap, SortedFilterExpr, StreamJoinMetrics, }; use crate::joins::utils::{ apply_join_filter_to_indices, build_batch_from_indices, build_join_schema, - check_join_is_valid, partitioned_join_output_partitioning, ColumnIndex, JoinFilter, - JoinHashMapType, JoinOn, JoinOnRef, StatefulStreamResult, + check_join_is_valid, symmetric_join_output_partitioning, BatchSplitter, + BatchTransformer, ColumnIndex, JoinFilter, JoinHashMapType, JoinOn, JoinOnRef, + NoopBatchTransformer, StatefulStreamResult, }; use crate::{ execution_mode_from_children, - expressions::PhysicalSortExpr, joins::StreamJoinPartitionMode, metrics::{ExecutionPlanMetricsSet, MetricsSet}, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, @@ -61,6 +61,7 @@ use arrow::array::{ use arrow::compute::concat_batches; use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::RecordBatch; +use arrow_buffer::ArrowNativeType; use datafusion_common::hash_utils::create_hashes; use datafusion_common::utils::bisect; use datafusion_common::{internal_err, plan_err, JoinSide, JoinType, Result}; @@ -72,7 +73,10 @@ use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement}; use ahash::RandomState; -use futures::Stream; +use datafusion_physical_expr_common::sort_expr::{ + LexOrdering, LexOrderingRef, LexRequirement, +}; +use futures::{ready, Stream, StreamExt}; use hashbrown::HashSet; use parking_lot::Mutex; @@ -94,7 +98,7 @@ const HASHMAP_SHRINK_SCALE_FACTOR: usize = 4; /// - If so record the visited rows. If the matched row results must be produced (INNER, LEFT), output the [RecordBatch]. /// - Try to prune other side (probe) with new [RecordBatch]. /// - If the join type indicates that the unmatched rows results must be produced (LEFT, FULL etc.), -/// output the [RecordBatch] when a pruning happens or at the end of the data. +/// output the [RecordBatch] when a pruning happens or at the end of the data. /// /// /// ``` text @@ -163,7 +167,7 @@ const HASHMAP_SHRINK_SCALE_FACTOR: usize = 4; /// making the smallest value in 'left_sorted' 1231 and any rows below (since ascending) /// than that can be dropped from the inner buffer. /// ``` -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct SymmetricHashJoinExec { /// Left side stream pub(crate) left: Arc, @@ -184,9 +188,9 @@ pub struct SymmetricHashJoinExec { /// If null_equals_null is true, null == null else null != null pub(crate) null_equals_null: bool, /// Left side sort expression(s) - pub(crate) left_sort_exprs: Option>, + pub(crate) left_sort_exprs: Option, /// Right side sort expression(s) - pub(crate) right_sort_exprs: Option>, + pub(crate) right_sort_exprs: Option, /// Partition Mode mode: StreamJoinPartitionMode, /// Cache holding plan properties like equivalences, output partitioning etc. @@ -208,14 +212,14 @@ impl SymmetricHashJoinExec { filter: Option, join_type: &JoinType, null_equals_null: bool, - left_sort_exprs: Option>, - right_sort_exprs: Option>, + left_sort_exprs: Option, + right_sort_exprs: Option, mode: StreamJoinPartitionMode, ) -> Result { let left_schema = left.schema(); let right_schema = right.schema(); - // Error out if no "on" contraints are given: + // Error out if no "on" constraints are given: if on.is_empty() { return plan_err!( "On constraints in SymmetricHashJoinExec should be non-empty" @@ -233,7 +237,7 @@ impl SymmetricHashJoinExec { let random_state = RandomState::with_seeds(0, 0, 0, 0); let schema = Arc::new(schema); let cache = - Self::compute_properties(&left, &right, schema.clone(), *join_type, &on); + Self::compute_properties(&left, &right, Arc::clone(&schema), *join_type, &on); Ok(SymmetricHashJoinExec { left, right, @@ -271,14 +275,8 @@ impl SymmetricHashJoinExec { join_on, ); - // Get output partitioning: - let left_columns_len = left.schema().fields.len(); - let output_partitioning = partitioned_join_output_partitioning( - join_type, - left.output_partitioning(), - right.output_partitioning(), - left_columns_len, - ); + let output_partitioning = + symmetric_join_output_partitioning(left, right, &join_type); // Determine execution mode: let mode = execution_mode_from_children([left, right]); @@ -322,12 +320,12 @@ impl SymmetricHashJoinExec { } /// Get left_sort_exprs - pub fn left_sort_exprs(&self) -> Option<&[PhysicalSortExpr]> { + pub fn left_sort_exprs(&self) -> Option { self.left_sort_exprs.as_deref() } /// Get right_sort_exprs - pub fn right_sort_exprs(&self) -> Option<&[PhysicalSortExpr]> { + pub fn right_sort_exprs(&self) -> Option { self.right_sort_exprs.as_deref() } @@ -403,7 +401,7 @@ impl ExecutionPlan for SymmetricHashJoinExec { let (left_expr, right_expr) = self .on .iter() - .map(|(l, r)| (l.clone() as _, r.clone() as _)) + .map(|(l, r)| (Arc::clone(l) as _, Arc::clone(r) as _)) .unzip(); vec![ Distribution::HashPartitioned(left_expr), @@ -416,19 +414,21 @@ impl ExecutionPlan for SymmetricHashJoinExec { } } - fn required_input_ordering(&self) -> Vec>> { + fn required_input_ordering(&self) -> Vec> { vec![ self.left_sort_exprs .as_ref() + .map(LexOrdering::iter) .map(PhysicalSortRequirement::from_sort_exprs), self.right_sort_exprs .as_ref() + .map(LexOrdering::iter) .map(PhysicalSortRequirement::from_sort_exprs), ] } - fn children(&self) -> Vec> { - vec![self.left.clone(), self.right.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.left, &self.right] } fn with_new_children( @@ -436,8 +436,8 @@ impl ExecutionPlan for SymmetricHashJoinExec { children: Vec>, ) -> Result> { Ok(Arc::new(SymmetricHashJoinExec::try_new( - children[0].clone(), - children[1].clone(), + Arc::clone(&children[0]), + Arc::clone(&children[1]), self.on.clone(), self.filter.clone(), &self.join_type, @@ -470,23 +470,27 @@ impl ExecutionPlan for SymmetricHashJoinExec { consider using RepartitionExec" ); } - // If `filter_state` and `filter` are both present, then calculate sorted filter expressions - // for both sides, and build an expression graph. - let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = - match (&self.left_sort_exprs, &self.right_sort_exprs, &self.filter) { - (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => { - let (left, right, graph) = prepare_sorted_exprs( - filter, - &self.left, - &self.right, - left_sort_exprs, - right_sort_exprs, - )?; - (Some(left), Some(right), Some(graph)) - } - // If `filter_state` or `filter` is not present, then return None for all three values: - _ => (None, None, None), - }; + // If `filter_state` and `filter` are both present, then calculate sorted + // filter expressions for both sides, and build an expression graph. + let (left_sorted_filter_expr, right_sorted_filter_expr, graph) = match ( + self.left_sort_exprs(), + self.right_sort_exprs(), + &self.filter, + ) { + (Some(left_sort_exprs), Some(right_sort_exprs), Some(filter)) => { + let (left, right, graph) = prepare_sorted_exprs( + filter, + &self.left, + &self.right, + left_sort_exprs, + right_sort_exprs, + )?; + (Some(left), Some(right), Some(graph)) + } + // If `filter_state` or `filter` is not present, then return None + // for all three values: + _ => (None, None, None), + }; let (on_left, on_right) = self.on.iter().cloned().unzip(); @@ -495,9 +499,13 @@ impl ExecutionPlan for SymmetricHashJoinExec { let right_side_joiner = OneSideHashJoiner::new(JoinSide::Right, on_right, self.right.schema()); - let left_stream = self.left.execute(partition, context.clone())?; + let left_stream = self.left.execute(partition, Arc::clone(&context))?; - let right_stream = self.right.execute(partition, context.clone())?; + let right_stream = self.right.execute(partition, Arc::clone(&context))?; + + let batch_size = context.session_config().batch_size(); + let enforce_batch_size_in_joins = + context.session_config().enforce_batch_size_in_joins(); let reservation = Arc::new(Mutex::new( MemoryConsumer::new(format!("SymmetricHashJoinStream[{partition}]")) @@ -507,29 +515,52 @@ impl ExecutionPlan for SymmetricHashJoinExec { reservation.lock().try_grow(g.size())?; } - Ok(Box::pin(SymmetricHashJoinStream { - left_stream, - right_stream, - schema: self.schema(), - filter: self.filter.clone(), - join_type: self.join_type, - random_state: self.random_state.clone(), - left: left_side_joiner, - right: right_side_joiner, - column_indices: self.column_indices.clone(), - metrics: StreamJoinMetrics::new(partition, &self.metrics), - graph, - left_sorted_filter_expr, - right_sorted_filter_expr, - null_equals_null: self.null_equals_null, - state: EagerJoinStreamState::PullRight, - reservation, - })) + if enforce_batch_size_in_joins { + Ok(Box::pin(SymmetricHashJoinStream { + left_stream, + right_stream, + schema: self.schema(), + filter: self.filter.clone(), + join_type: self.join_type, + random_state: self.random_state.clone(), + left: left_side_joiner, + right: right_side_joiner, + column_indices: self.column_indices.clone(), + metrics: StreamJoinMetrics::new(partition, &self.metrics), + graph, + left_sorted_filter_expr, + right_sorted_filter_expr, + null_equals_null: self.null_equals_null, + state: SHJStreamState::PullRight, + reservation, + batch_transformer: BatchSplitter::new(batch_size), + })) + } else { + Ok(Box::pin(SymmetricHashJoinStream { + left_stream, + right_stream, + schema: self.schema(), + filter: self.filter.clone(), + join_type: self.join_type, + random_state: self.random_state.clone(), + left: left_side_joiner, + right: right_side_joiner, + column_indices: self.column_indices.clone(), + metrics: StreamJoinMetrics::new(partition, &self.metrics), + graph, + left_sorted_filter_expr, + right_sorted_filter_expr, + null_equals_null: self.null_equals_null, + state: SHJStreamState::PullRight, + reservation, + batch_transformer: NoopBatchTransformer::new(), + })) + } } } /// A stream that issues [RecordBatch]es as they arrive from the right of the join. -struct SymmetricHashJoinStream { +struct SymmetricHashJoinStream { /// Input streams left_stream: SendableRecordBatchStream, right_stream: SendableRecordBatchStream, @@ -560,21 +591,25 @@ struct SymmetricHashJoinStream { /// Memory reservation reservation: SharedMemoryReservation, /// State machine for input execution - state: EagerJoinStreamState, + state: SHJStreamState, + /// Transforms the output batch before returning. + batch_transformer: T, } -impl RecordBatchStream for SymmetricHashJoinStream { +impl RecordBatchStream + for SymmetricHashJoinStream +{ fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } -impl Stream for SymmetricHashJoinStream { +impl Stream for SymmetricHashJoinStream { type Item = Result; fn poll_next( mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, + cx: &mut Context<'_>, ) -> Poll> { self.poll_next_impl(cx) } @@ -590,7 +625,7 @@ impl Stream for SymmetricHashJoinStream { /// /// * `buffer`: The record batch to be pruned. /// * `build_side_filter_expr`: The filter expression on the build side used -/// to determine the pruning length. +/// to determine the pruning length. /// /// # Returns /// @@ -639,7 +674,11 @@ fn need_to_produce_result_in_final(build_side: JoinSide, join_type: JoinType) -> if build_side == JoinSide::Left { matches!( join_type, - JoinType::Left | JoinType::LeftAnti | JoinType::Full | JoinType::LeftSemi + JoinType::Left + | JoinType::LeftAnti + | JoinType::Full + | JoinType::LeftSemi + | JoinType::LeftMark ) } else { matches!( @@ -678,6 +717,20 @@ where { // Store the result in a tuple let result = match (build_side, join_type) { + (JoinSide::Left, JoinType::LeftMark) => { + let build_indices = (0..prune_length) + .map(L::Native::from_usize) + .collect::>(); + let probe_indices = (0..prune_length) + .map(|idx| { + // For mark join we output a dummy index 0 to indicate the row had a match + visited_rows + .contains(&(idx + deleted_offset)) + .then_some(R::Native::from_usize(0).unwrap()) + }) + .collect(); + (build_indices, probe_indices) + } // In the case of `Left` or `Right` join, or `Full` join, get the anti indices (JoinSide::Left, JoinType::Left | JoinType::LeftAnti) | (JoinSide::Right, JoinType::Right | JoinType::RightAnti) @@ -841,6 +894,7 @@ pub(crate) fn join_with_probe_batch( JoinType::LeftAnti | JoinType::RightAnti | JoinType::LeftSemi + | JoinType::LeftMark | JoinType::RightSemi ) { Ok(None) @@ -935,13 +989,11 @@ fn lookup_join_hashmap( let (mut matched_probe, mut matched_build) = build_hashmap .get_matched_indices(hash_values.iter().enumerate().rev(), deleted_offset); - matched_probe.as_slice_mut().reverse(); - matched_build.as_slice_mut().reverse(); + matched_probe.reverse(); + matched_build.reverse(); - let build_indices: UInt64Array = - PrimitiveArray::new(matched_build.finish().into(), None); - let probe_indices: UInt32Array = - PrimitiveArray::new(matched_probe.finish().into(), None); + let build_indices: UInt64Array = matched_build.into(); + let probe_indices: UInt32Array = matched_probe.into(); let (build_indices, probe_indices) = equal_rows_arr( &build_indices, @@ -976,15 +1028,15 @@ pub struct OneSideHashJoiner { impl OneSideHashJoiner { pub fn size(&self) -> usize { let mut size = 0; - size += std::mem::size_of_val(self); - size += std::mem::size_of_val(&self.build_side); + size += size_of_val(self); + size += size_of_val(&self.build_side); size += self.input_buffer.get_array_memory_size(); - size += std::mem::size_of_val(&self.on); + size += size_of_val(&self.on); size += self.hashmap.size(); - size += self.hashes_buffer.capacity() * std::mem::size_of::(); - size += self.visited_rows.capacity() * std::mem::size_of::(); - size += std::mem::size_of_val(&self.offset); - size += std::mem::size_of_val(&self.deleted_offset); + size += self.hashes_buffer.capacity() * size_of::(); + size += self.visited_rows.capacity() * size_of::(); + size += size_of_val(&self.offset); + size += size_of_val(&self.deleted_offset); size } pub fn new( @@ -1103,7 +1155,246 @@ impl OneSideHashJoiner { } } -impl EagerJoinStream for SymmetricHashJoinStream { +/// `SymmetricHashJoinStream` manages incremental join operations between two +/// streams. Unlike traditional join approaches that need to scan one side of +/// the join fully before proceeding, `SymmetricHashJoinStream` facilitates +/// more dynamic join operations by working with streams as they emit data. This +/// approach allows for more efficient processing, particularly in scenarios +/// where waiting for complete data materialization is not feasible or optimal. +/// The trait provides a framework for handling various states of such a join +/// process, ensuring that join logic is efficiently executed as data becomes +/// available from either stream. +/// +/// This implementation performs eager joins of data from two different asynchronous +/// streams, typically referred to as left and right streams. The implementation +/// provides a comprehensive set of methods to control and execute the join +/// process, leveraging the states defined in `SHJStreamState`. Methods are +/// primarily focused on asynchronously fetching data batches from each stream, +/// processing them, and managing transitions between various states of the join. +/// +/// This implementations use a state machine approach to navigate different +/// stages of the join operation, handling data from both streams and determining +/// when the join completes. +/// +/// State Transitions: +/// - From `PullLeft` to `PullRight` or `LeftExhausted`: +/// - In `fetch_next_from_left_stream`, when fetching a batch from the left stream: +/// - On success (`Some(Ok(batch))`), state transitions to `PullRight` for +/// processing the batch. +/// - On error (`Some(Err(e))`), the error is returned, and the state remains +/// unchanged. +/// - On no data (`None`), state changes to `LeftExhausted`, returning `Continue` +/// to proceed with the join process. +/// - From `PullRight` to `PullLeft` or `RightExhausted`: +/// - In `fetch_next_from_right_stream`, when fetching from the right stream: +/// - If a batch is available, state changes to `PullLeft` for processing. +/// - On error, the error is returned without changing the state. +/// - If right stream is exhausted (`None`), state transitions to `RightExhausted`, +/// with a `Continue` result. +/// - Handling `RightExhausted` and `LeftExhausted`: +/// - Methods `handle_right_stream_end` and `handle_left_stream_end` manage scenarios +/// when streams are exhausted: +/// - They attempt to continue processing with the other stream. +/// - If both streams are exhausted, state changes to `BothExhausted { final_result: false }`. +/// - Transition to `BothExhausted { final_result: true }`: +/// - Occurs in `prepare_for_final_results_after_exhaustion` when both streams are +/// exhausted, indicating completion of processing and availability of final results. +impl SymmetricHashJoinStream { + /// Implements the main polling logic for the join stream. + /// + /// This method continuously checks the state of the join stream and + /// acts accordingly by delegating the handling to appropriate sub-methods + /// depending on the current state. + /// + /// # Arguments + /// + /// * `cx` - A context that facilitates cooperative non-blocking execution within a task. + /// + /// # Returns + /// + /// * `Poll>>` - A polled result, either a `RecordBatch` or None. + fn poll_next_impl( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + loop { + match self.batch_transformer.next() { + None => { + let result = match self.state() { + SHJStreamState::PullRight => { + ready!(self.fetch_next_from_right_stream(cx)) + } + SHJStreamState::PullLeft => { + ready!(self.fetch_next_from_left_stream(cx)) + } + SHJStreamState::RightExhausted => { + ready!(self.handle_right_stream_end(cx)) + } + SHJStreamState::LeftExhausted => { + ready!(self.handle_left_stream_end(cx)) + } + SHJStreamState::BothExhausted { + final_result: false, + } => self.prepare_for_final_results_after_exhaustion(), + SHJStreamState::BothExhausted { final_result: true } => { + return Poll::Ready(None); + } + }; + + match result? { + StatefulStreamResult::Ready(None) => { + return Poll::Ready(None); + } + StatefulStreamResult::Ready(Some(batch)) => { + self.batch_transformer.set_batch(batch); + } + _ => {} + } + } + Some((batch, _)) => { + self.metrics.output_batches.add(1); + self.metrics.output_rows.add(batch.num_rows()); + return Poll::Ready(Some(Ok(batch))); + } + } + } + } + /// Asynchronously pulls the next batch from the right stream. + /// + /// This default implementation checks for the next value in the right stream. + /// If a batch is found, the state is switched to `PullLeft`, and the batch handling + /// is delegated to `process_batch_from_right`. If the stream ends, the state is set to `RightExhausted`. + /// + /// # Returns + /// + /// * `Result>>` - The state result after pulling the batch. + fn fetch_next_from_right_stream( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>>> { + match ready!(self.right_stream().poll_next_unpin(cx)) { + Some(Ok(batch)) => { + if batch.num_rows() == 0 { + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + self.set_state(SHJStreamState::PullLeft); + Poll::Ready(self.process_batch_from_right(batch)) + } + Some(Err(e)) => Poll::Ready(Err(e)), + None => { + self.set_state(SHJStreamState::RightExhausted); + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + } + } + + /// Asynchronously pulls the next batch from the left stream. + /// + /// This default implementation checks for the next value in the left stream. + /// If a batch is found, the state is switched to `PullRight`, and the batch handling + /// is delegated to `process_batch_from_left`. If the stream ends, the state is set to `LeftExhausted`. + /// + /// # Returns + /// + /// * `Result>>` - The state result after pulling the batch. + fn fetch_next_from_left_stream( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>>> { + match ready!(self.left_stream().poll_next_unpin(cx)) { + Some(Ok(batch)) => { + if batch.num_rows() == 0 { + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + self.set_state(SHJStreamState::PullRight); + Poll::Ready(self.process_batch_from_left(batch)) + } + Some(Err(e)) => Poll::Ready(Err(e)), + None => { + self.set_state(SHJStreamState::LeftExhausted); + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + } + } + + /// Asynchronously handles the scenario when the right stream is exhausted. + /// + /// In this default implementation, when the right stream is exhausted, it attempts + /// to pull from the left stream. If a batch is found in the left stream, it delegates + /// the handling to `process_batch_from_left`. If both streams are exhausted, the state is set + /// to indicate both streams are exhausted without final results yet. + /// + /// # Returns + /// + /// * `Result>>` - The state result after checking the exhaustion state. + fn handle_right_stream_end( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>>> { + match ready!(self.left_stream().poll_next_unpin(cx)) { + Some(Ok(batch)) => { + if batch.num_rows() == 0 { + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + Poll::Ready(self.process_batch_after_right_end(batch)) + } + Some(Err(e)) => Poll::Ready(Err(e)), + None => { + self.set_state(SHJStreamState::BothExhausted { + final_result: false, + }); + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + } + } + + /// Asynchronously handles the scenario when the left stream is exhausted. + /// + /// When the left stream is exhausted, this default + /// implementation tries to pull from the right stream and delegates the batch + /// handling to `process_batch_after_left_end`. If both streams are exhausted, the state + /// is updated to indicate so. + /// + /// # Returns + /// + /// * `Result>>` - The state result after checking the exhaustion state. + fn handle_left_stream_end( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>>> { + match ready!(self.right_stream().poll_next_unpin(cx)) { + Some(Ok(batch)) => { + if batch.num_rows() == 0 { + return Poll::Ready(Ok(StatefulStreamResult::Continue)); + } + Poll::Ready(self.process_batch_after_left_end(batch)) + } + Some(Err(e)) => Poll::Ready(Err(e)), + None => { + self.set_state(SHJStreamState::BothExhausted { + final_result: false, + }); + Poll::Ready(Ok(StatefulStreamResult::Continue)) + } + } + } + + /// Handles the state when both streams are exhausted and final results are yet to be produced. + /// + /// This default implementation switches the state to indicate both streams are + /// exhausted with final results and then invokes the handling for this specific + /// scenario via `process_batches_before_finalization`. + /// + /// # Returns + /// + /// * `Result>>` - The state result after both streams are exhausted. + fn prepare_for_final_results_after_exhaustion( + &mut self, + ) -> Result>> { + self.set_state(SHJStreamState::BothExhausted { final_result: true }); + self.process_batches_before_finalization() + } + fn process_batch_from_right( &mut self, batch: RecordBatch, @@ -1171,11 +1462,8 @@ impl EagerJoinStream for SymmetricHashJoinStream { // Combine the left and right results: let result = combine_two_batches(&self.schema, left_result, right_result)?; - // Update the metrics and return the result: - if let Some(batch) = &result { - // Update the metrics: - self.metrics.output_batches.add(1); - self.metrics.output_rows.add(batch.num_rows()); + // Return the result: + if result.is_some() { return Ok(StatefulStreamResult::Ready(result)); } Ok(StatefulStreamResult::Continue) @@ -1189,30 +1477,28 @@ impl EagerJoinStream for SymmetricHashJoinStream { &mut self.left_stream } - fn set_state(&mut self, state: EagerJoinStreamState) { + fn set_state(&mut self, state: SHJStreamState) { self.state = state; } - fn state(&mut self) -> EagerJoinStreamState { + fn state(&mut self) -> SHJStreamState { self.state.clone() } -} -impl SymmetricHashJoinStream { fn size(&self) -> usize { let mut size = 0; - size += std::mem::size_of_val(&self.schema); - size += std::mem::size_of_val(&self.filter); - size += std::mem::size_of_val(&self.join_type); + size += size_of_val(&self.schema); + size += size_of_val(&self.filter); + size += size_of_val(&self.join_type); size += self.left.size(); size += self.right.size(); - size += std::mem::size_of_val(&self.column_indices); + size += size_of_val(&self.column_indices); size += self.graph.as_ref().map(|g| g.size()).unwrap_or(0); - size += std::mem::size_of_val(&self.left_sorted_filter_expr); - size += std::mem::size_of_val(&self.right_sorted_filter_expr); - size += std::mem::size_of_val(&self.random_state); - size += std::mem::size_of_val(&self.null_equals_null); - size += std::mem::size_of_val(&self.metrics); + size += size_of_val(&self.left_sorted_filter_expr); + size += size_of_val(&self.right_sorted_filter_expr); + size += size_of_val(&self.random_state); + size += size_of_val(&self.null_equals_null); + size += size_of_val(&self.metrics); size } @@ -1312,15 +1598,38 @@ impl SymmetricHashJoinStream { let capacity = self.size(); self.metrics.stream_memory_usage.set(capacity); self.reservation.lock().try_resize(capacity)?; - // Update the metrics if we have a batch; otherwise, continue the loop. - if let Some(batch) = &result { - self.metrics.output_batches.add(1); - self.metrics.output_rows.add(batch.num_rows()); - } Ok(result) } } +/// Represents the various states of an symmetric hash join stream operation. +/// +/// This enum is used to track the current state of streaming during a join +/// operation. It provides indicators as to which side of the join needs to be +/// pulled next or if one (or both) sides have been exhausted. This allows +/// for efficient management of resources and optimal performance during the +/// join process. +#[derive(Clone, Debug)] +pub enum SHJStreamState { + /// Indicates that the next step should pull from the right side of the join. + PullRight, + + /// Indicates that the next step should pull from the left side of the join. + PullLeft, + + /// State representing that the right side of the join has been fully processed. + RightExhausted, + + /// State representing that the left side of the join has been fully processed. + LeftExhausted, + + /// Represents a state where both sides of the join are exhausted. + /// + /// The `final_result` field indicates whether the join operation has + /// produced a final result or not. + BothExhausted { final_result: bool }, +} + #[cfg(test)] mod tests { use std::collections::HashMap; @@ -1340,6 +1649,7 @@ mod tests { use datafusion_execution::config::SessionConfig; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{binary, col, lit, Column}; + use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use once_cell::sync::Lazy; use rstest::*; @@ -1394,13 +1704,13 @@ mod tests { task_ctx: Arc, ) -> Result<()> { let first_batches = partitioned_sym_join_with_filter( - left.clone(), - right.clone(), + Arc::clone(&left), + Arc::clone(&right), on.clone(), filter.clone(), &join_type, false, - task_ctx.clone(), + Arc::clone(&task_ctx), ) .await?; let second_batches = partitioned_hash_join_with_filter( @@ -1421,6 +1731,7 @@ mod tests { JoinType::RightSemi, JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightAnti, JoinType::Full )] @@ -1439,7 +1750,7 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = vec![PhysicalSortExpr { + let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: binary( col("la1", left_schema)?, Operator::Plus, @@ -1447,11 +1758,11 @@ mod tests { left_schema, )?, options: SortOptions::default(), - }]; - let right_sorted = vec![PhysicalSortExpr { + }]); + let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("ra1", right_schema)?, options: SortOptions::default(), - }]; + }]); let (left, right) = create_memory_table( left_partition, right_partition, @@ -1477,15 +1788,15 @@ mod tests { let filter_expr = complicated_filter(&intermediate_schema)?; let column_indices = vec![ ColumnIndex { - index: 0, + index: left_schema.index_of("la1")?, side: JoinSide::Left, }, ColumnIndex { - index: 4, + index: left_schema.index_of("la2")?, side: JoinSide::Left, }, ColumnIndex { - index: 0, + index: right_schema.index_of("ra1")?, side: JoinSide::Right, }, ]; @@ -1505,6 +1816,7 @@ mod tests { JoinType::RightSemi, JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightAnti, JoinType::Full )] @@ -1517,14 +1829,14 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = vec![PhysicalSortExpr { + let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("la1", left_schema)?, options: SortOptions::default(), - }]; - let right_sorted = vec![PhysicalSortExpr { + }]); + let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("ra1", right_schema)?, options: SortOptions::default(), - }]; + }]); let (left, right) = create_memory_table( left_partition, right_partition, @@ -1532,10 +1844,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -1572,6 +1881,7 @@ mod tests { JoinType::RightSemi, JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightAnti, JoinType::Full )] @@ -1586,10 +1896,7 @@ mod tests { let (left, right) = create_memory_table(left_partition, right_partition, vec![], vec![])?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -1626,6 +1933,7 @@ mod tests { JoinType::RightSemi, JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightAnti, JoinType::Full )] @@ -1638,10 +1946,7 @@ mod tests { let (left, right) = create_memory_table(left_partition, right_partition, vec![], vec![])?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; experiment(left, right, None, join_type, on, task_ctx).await?; Ok(()) } @@ -1656,6 +1961,7 @@ mod tests { JoinType::RightSemi, JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightAnti, JoinType::Full )] @@ -1666,20 +1972,20 @@ mod tests { let (left_partition, right_partition) = get_or_create_table((11, 21), 8)?; let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = vec![PhysicalSortExpr { + let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("la1_des", left_schema)?, options: SortOptions { descending: true, nulls_first: true, }, - }]; - let right_sorted = vec![PhysicalSortExpr { + }]); + let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("ra1_des", right_schema)?, options: SortOptions { descending: true, nulls_first: true, }, - }]; + }]); let (left, right) = create_memory_table( left_partition, right_partition, @@ -1687,10 +1993,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -1727,20 +2030,20 @@ mod tests { let (left_partition, right_partition) = get_or_create_table((10, 11), 8)?; let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = vec![PhysicalSortExpr { + let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("l_asc_null_first", left_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]; - let right_sorted = vec![PhysicalSortExpr { + }]); + let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("r_asc_null_first", right_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]; + }]); let (left, right) = create_memory_table( left_partition, right_partition, @@ -1748,10 +2051,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -1788,20 +2088,20 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = vec![PhysicalSortExpr { + let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("l_asc_null_last", left_schema)?, options: SortOptions { descending: false, nulls_first: false, }, - }]; - let right_sorted = vec![PhysicalSortExpr { + }]); + let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("r_asc_null_last", right_schema)?, options: SortOptions { descending: false, nulls_first: false, }, - }]; + }]); let (left, right) = create_memory_table( left_partition, right_partition, @@ -1809,10 +2109,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -1851,20 +2148,20 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = vec![PhysicalSortExpr { + let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("l_desc_null_first", left_schema)?, options: SortOptions { descending: true, nulls_first: true, }, - }]; - let right_sorted = vec![PhysicalSortExpr { + }]); + let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("r_desc_null_first", right_schema)?, options: SortOptions { descending: true, nulls_first: true, }, - }]; + }]); let (left, right) = create_memory_table( left_partition, right_partition, @@ -1872,10 +2169,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Int32, true), @@ -1915,15 +2209,15 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = vec![PhysicalSortExpr { + let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("la1", left_schema)?, options: SortOptions::default(), - }]; + }]); - let right_sorted = vec![PhysicalSortExpr { + let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("ra1", right_schema)?, options: SortOptions::default(), - }]; + }]); let (left, right) = create_memory_table( left_partition, right_partition, @@ -1931,10 +2225,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("0", DataType::Int32, true), @@ -1976,20 +2267,20 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); let left_sorted = vec![ - vec![PhysicalSortExpr { + LexOrdering::new(vec![PhysicalSortExpr { expr: col("la1", left_schema)?, options: SortOptions::default(), - }], - vec![PhysicalSortExpr { + }]), + LexOrdering::new(vec![PhysicalSortExpr { expr: col("la2", left_schema)?, options: SortOptions::default(), - }], + }]), ]; - let right_sorted = vec![PhysicalSortExpr { + let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("ra1", right_schema)?, options: SortOptions::default(), - }]; + }]); let (left, right) = create_memory_table( left_partition, @@ -1998,10 +2289,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("0", DataType::Int32, true), @@ -2039,13 +2327,14 @@ mod tests { JoinType::RightSemi, JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightAnti, JoinType::Full )] join_type: JoinType, #[values( - (4, 5), - (12, 17), + (4, 5), + (12, 17), )] cardinality: (i32, i32), #[values(0, 1, 2)] case_expr: usize, @@ -2057,24 +2346,21 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; - let left_sorted = vec![PhysicalSortExpr { + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; + let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("lt1", left_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]; - let right_sorted = vec![PhysicalSortExpr { + }]); + let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("rt1", right_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]; + }]); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2124,13 +2410,14 @@ mod tests { JoinType::RightSemi, JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightAnti, JoinType::Full )] join_type: JoinType, #[values( - (4, 5), - (12, 17), + (4, 5), + (12, 17), )] cardinality: (i32, i32), ) -> Result<()> { @@ -2141,24 +2428,21 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; - let left_sorted = vec![PhysicalSortExpr { + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; + let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("li1", left_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]; - let right_sorted = vec![PhysicalSortExpr { + }]); + let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("ri1", right_schema)?, options: SortOptions { descending: false, nulls_first: true, }, - }]; + }]); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2201,13 +2485,14 @@ mod tests { JoinType::RightSemi, JoinType::LeftSemi, JoinType::LeftAnti, + JoinType::LeftMark, JoinType::RightAnti, JoinType::Full )] join_type: JoinType, #[values( - (4, 5), - (12, 17), + (4, 5), + (12, 17), )] cardinality: (i32, i32), #[values(0, 1, 2, 3, 4, 5)] case_expr: usize, @@ -2219,14 +2504,14 @@ mod tests { let left_schema = &left_partition[0].schema(); let right_schema = &right_partition[0].schema(); - let left_sorted = vec![PhysicalSortExpr { + let left_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("l_float", left_schema)?, options: SortOptions::default(), - }]; - let right_sorted = vec![PhysicalSortExpr { + }]); + let right_sorted = LexOrdering::new(vec![PhysicalSortExpr { expr: col("r_float", right_schema)?, options: SortOptions::default(), - }]; + }]); let (left, right) = create_memory_table( left_partition, right_partition, @@ -2234,10 +2519,7 @@ mod tests { vec![right_sorted], )?; - let on = vec![( - Arc::new(Column::new_with_schema("lc1", left_schema)?) as _, - Arc::new(Column::new_with_schema("rc1", right_schema)?) as _, - )]; + let on = vec![(col("lc1", left_schema)?, col("rc1", right_schema)?)]; let intermediate_schema = Schema::new(vec![ Field::new("left", DataType::Float64, true), diff --git a/datafusion/physical-plan/src/joins/test_utils.rs b/datafusion/physical-plan/src/joins/test_utils.rs index 6fb3aef5d5bf..421fd0da808c 100644 --- a/datafusion/physical-plan/src/joins/test_utils.rs +++ b/datafusion/physical-plan/src/joins/test_utils.rs @@ -18,7 +18,6 @@ //! This file has test utils for hash joins use std::sync::Arc; -use std::usize; use crate::joins::utils::{JoinFilter, JoinOn}; use crate::joins::{ @@ -33,6 +32,7 @@ use arrow_array::{ ArrayRef, Float64Array, Int32Array, IntervalDayTimeArray, RecordBatch, TimestampMillisecondArray, }; +use arrow_buffer::IntervalDayTime; use arrow_schema::{DataType, Schema}; use datafusion_common::{Result, ScalarValue}; use datafusion_execution::TaskContext; @@ -78,31 +78,39 @@ pub async fn partitioned_sym_join_with_filter( ) -> Result> { let partition_count = 4; - let left_expr = on.iter().map(|(l, _)| l.clone() as _).collect::>(); + let left_expr = on + .iter() + .map(|(l, _)| Arc::clone(l) as _) + .collect::>(); - let right_expr = on.iter().map(|(_, r)| r.clone() as _).collect::>(); + let right_expr = on + .iter() + .map(|(_, r)| Arc::clone(r) as _) + .collect::>(); let join = SymmetricHashJoinExec::try_new( Arc::new(RepartitionExec::try_new( - left.clone(), + Arc::clone(&left), Partitioning::Hash(left_expr, partition_count), )?), Arc::new(RepartitionExec::try_new( - right.clone(), + Arc::clone(&right), Partitioning::Hash(right_expr, partition_count), )?), on, filter, join_type, null_equals_null, - left.output_ordering().map(|p| p.to_vec()), - right.output_ordering().map(|p| p.to_vec()), + left.output_ordering().map(|p| LexOrdering::new(p.to_vec())), + right + .output_ordering() + .map(|p| LexOrdering::new(p.to_vec())), StreamJoinPartitionMode::Partitioned, )?; let mut batches = vec![]; for i in 0..partition_count { - let stream = join.execute(i, context.clone())?; + let stream = join.execute(i, Arc::clone(&context))?; let more_batches = common::collect(stream).await?; batches.extend( more_batches @@ -127,7 +135,7 @@ pub async fn partitioned_hash_join_with_filter( let partition_count = 4; let (left_expr, right_expr) = on .iter() - .map(|(l, r)| (l.clone() as _, r.clone() as _)) + .map(|(l, r)| (Arc::clone(l) as _, Arc::clone(r) as _)) .unzip(); let join = Arc::new(HashJoinExec::try_new( @@ -149,7 +157,7 @@ pub async fn partitioned_hash_join_with_filter( let mut batches = vec![]; for i in 0..partition_count { - let stream = join.execute(i, context.clone())?; + let stream = join.execute(i, Arc::clone(&context))?; let more_batches = common::collect(stream).await?; batches.extend( more_batches @@ -283,7 +291,7 @@ macro_rules! join_expr_tests { ScalarValue::$SCALAR(Some(10 as $type)), (Operator::Gt, Operator::Lt), ), - // left_col - 1 > right_col + 5 AND left_col + 3 < right_col + 10 + // left_col - 1 > right_col + 3 AND left_col + 3 < right_col + 15 1 => gen_conjunctive_numerical_expr( left_col, right_col, @@ -294,9 +302,9 @@ macro_rules! join_expr_tests { Operator::Plus, ), ScalarValue::$SCALAR(Some(1 as $type)), - ScalarValue::$SCALAR(Some(5 as $type)), ScalarValue::$SCALAR(Some(3 as $type)), - ScalarValue::$SCALAR(Some(10 as $type)), + ScalarValue::$SCALAR(Some(3 as $type)), + ScalarValue::$SCALAR(Some(15 as $type)), (Operator::Gt, Operator::Lt), ), // left_col - 1 > right_col + 5 AND left_col - 3 < right_col + 10 @@ -347,7 +355,8 @@ macro_rules! join_expr_tests { ScalarValue::$SCALAR(Some(3 as $type)), (Operator::Gt, Operator::Lt), ), - // left_col - 2 >= right_col - 5 AND left_col - 7 <= right_col - 3 + // left_col - 2 >= right_col + 5 AND left_col + 7 <= right_col - 3 + // (filters all input rows) 5 => gen_conjunctive_numerical_expr( left_col, right_col, @@ -363,7 +372,7 @@ macro_rules! join_expr_tests { ScalarValue::$SCALAR(Some(3 as $type)), (Operator::GtEq, Operator::LtEq), ), - // left_col - 28 >= right_col - 11 AND left_col - 21 <= right_col - 39 + // left_col + 28 >= right_col - 11 AND left_col + 21 <= right_col + 39 6 => gen_conjunctive_numerical_expr( left_col, right_col, @@ -379,7 +388,7 @@ macro_rules! join_expr_tests { ScalarValue::$SCALAR(Some(39 as $type)), (Operator::Gt, Operator::LtEq), ), - // left_col - 28 >= right_col - 11 AND left_col - 21 <= right_col + 39 + // left_col + 28 >= right_col - 11 AND left_col - 21 <= right_col + 39 7 => gen_conjunctive_numerical_expr( left_col, right_col, @@ -462,8 +471,11 @@ pub fn build_sides_record_batches( )); let interval_time: ArrayRef = Arc::new(IntervalDayTimeArray::from( initial_range - .map(|x| x as i64 * 100) // x * 100ms - .collect::>(), + .map(|x| IntervalDayTime { + days: 0, + milliseconds: x * 100, + }) // x * 100ms + .collect::>(), )); let float_asc = Arc::new(Float64Array::from_iter_values( @@ -472,20 +484,29 @@ pub fn build_sides_record_batches( )); let left = RecordBatch::try_from_iter(vec![ - ("la1", ordered.clone()), - ("lb1", cardinality.clone()), + ("la1", Arc::clone(&ordered)), + ("lb1", Arc::clone(&cardinality) as ArrayRef), ("lc1", cardinality_key_left), - ("lt1", time.clone()), - ("la2", ordered.clone()), - ("la1_des", ordered_des.clone()), - ("l_asc_null_first", ordered_asc_null_first.clone()), - ("l_asc_null_last", ordered_asc_null_last.clone()), - ("l_desc_null_first", ordered_desc_null_first.clone()), - ("li1", interval_time.clone()), - ("l_float", float_asc.clone()), + ("lt1", Arc::clone(&time) as ArrayRef), + ("la2", Arc::clone(&ordered)), + ("la1_des", Arc::clone(&ordered_des) as ArrayRef), + ( + "l_asc_null_first", + Arc::clone(&ordered_asc_null_first) as ArrayRef, + ), + ( + "l_asc_null_last", + Arc::clone(&ordered_asc_null_last) as ArrayRef, + ), + ( + "l_desc_null_first", + Arc::clone(&ordered_desc_null_first) as ArrayRef, + ), + ("li1", Arc::clone(&interval_time)), + ("l_float", Arc::clone(&float_asc) as ArrayRef), ])?; let right = RecordBatch::try_from_iter(vec![ - ("ra1", ordered.clone()), + ("ra1", Arc::clone(&ordered)), ("rb1", cardinality), ("rc1", cardinality_key_right), ("rt1", time), @@ -508,10 +529,10 @@ pub fn create_memory_table( ) -> Result<(Arc, Arc)> { let left_schema = left_partition[0].schema(); let left = MemoryExec::try_new(&[left_partition], left_schema, None)? - .with_sort_information(left_sorted); + .try_with_sort_information(left_sorted)?; let right_schema = right_partition[0].schema(); let right = MemoryExec::try_new(&[right_partition], right_schema, None)? - .with_sort_information(right_sorted); + .try_with_sort_information(right_sorted)?; Ok((Arc::new(left), Arc::new(right))) } diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index acf9ed4d7ec8..d3fa37c2ac80 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -20,21 +20,24 @@ use std::collections::HashSet; use std::fmt::{self, Debug}; use std::future::Future; +use std::iter::once; use std::ops::{IndexMut, Range}; use std::sync::Arc; use std::task::{Context, Poll}; -use std::usize; use crate::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder}; -use crate::{ColumnStatistics, ExecutionPlan, Partitioning, Statistics}; +use crate::{ + ColumnStatistics, ExecutionPlan, ExecutionPlanProperties, Partitioning, Statistics, +}; use arrow::array::{ downcast_array, new_null_array, Array, BooleanBufferBuilder, UInt32Array, - UInt32BufferBuilder, UInt32Builder, UInt64Array, UInt64BufferBuilder, + UInt32Builder, UInt64Array, }; use arrow::compute; -use arrow::datatypes::{Field, Schema, SchemaBuilder}; +use arrow::datatypes::{Field, Schema, SchemaBuilder, UInt32Type, UInt64Type}; use arrow::record_batch::{RecordBatch, RecordBatchOptions}; +use arrow_array::builder::UInt64Builder; use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray}; use arrow_buffer::ArrowNativeType; use datafusion_common::cast::as_boolean_array; @@ -64,7 +67,7 @@ use parking_lot::Mutex; /// E.g. 1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and 8 for hash value 1 /// As the key is a hash value, we need to check possible hash collisions in the probe stage /// During this stage it might be the case that a row is contained the same hashmap value, -/// but the values don't match. Those are checked in the [`equal_rows_arr`](crate::joins::hash_join::equal_rows_arr) method. +/// but the values don't match. Those are checked in the `equal_rows_arr` method. /// /// The indices (values) are stored in a separate chained list stored in the `Vec`. /// @@ -143,7 +146,7 @@ impl JoinHashMap { pub(crate) type JoinHashMapOffset = (usize, Option); // Macro for traversing chained values with limit. -// Early returns in case of reacing output tuples limit. +// Early returns in case of reaching output tuples limit. macro_rules! chain_traverse { ( $input_indices:ident, $match_indices:ident, $hash_values:ident, $next_chain:ident, @@ -161,8 +164,8 @@ macro_rules! chain_traverse { } else { i }; - $match_indices.append(match_row_idx); - $input_indices.append($input_idx as u32); + $match_indices.push(match_row_idx); + $input_indices.push($input_idx as u32); $remaining_output -= 1; // Follow the chain to get the next index value let next = $next_chain[match_row_idx as usize]; @@ -236,9 +239,9 @@ pub trait JoinHashMapType { &self, iter: impl Iterator, deleted_offset: Option, - ) -> (UInt32BufferBuilder, UInt64BufferBuilder) { - let mut input_indices = UInt32BufferBuilder::new(0); - let mut match_indices = UInt64BufferBuilder::new(0); + ) -> (Vec, Vec) { + let mut input_indices = vec![]; + let mut match_indices = vec![]; let hash_map = self.get_map(); let next_chain = self.get_list(); @@ -259,8 +262,8 @@ pub trait JoinHashMapType { } else { i }; - match_indices.append(match_row_idx); - input_indices.append(row_idx as u32); + match_indices.push(match_row_idx); + input_indices.push(row_idx as u32); // Follow the chain to get the next index value let next = next_chain[match_row_idx as usize]; if next == 0 { @@ -287,13 +290,9 @@ pub trait JoinHashMapType { deleted_offset: Option, limit: usize, offset: JoinHashMapOffset, - ) -> ( - UInt32BufferBuilder, - UInt64BufferBuilder, - Option, - ) { - let mut input_indices = UInt32BufferBuilder::new(0); - let mut match_indices = UInt64BufferBuilder::new(0); + ) -> (Vec, Vec, Option) { + let mut input_indices = vec![]; + let mut match_indices = vec![]; let mut remaining_output = limit; @@ -371,7 +370,7 @@ impl JoinHashMapType for JoinHashMap { } } -impl fmt::Debug for JoinHashMap { +impl Debug for JoinHashMap { fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result { Ok(()) } @@ -429,27 +428,6 @@ fn check_join_set_is_valid( Ok(()) } -/// Calculate the OutputPartitioning for Partitioned Join -pub fn partitioned_join_output_partitioning( - join_type: JoinType, - left_partitioning: &Partitioning, - right_partitioning: &Partitioning, - left_columns_len: usize, -) -> Partitioning { - match join_type { - JoinType::Inner | JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => { - left_partitioning.clone() - } - JoinType::RightSemi | JoinType::RightAnti => right_partitioning.clone(), - JoinType::Right => { - adjust_right_output_partitioning(right_partitioning, left_columns_len) - } - JoinType::Full => { - Partitioning::UnknownPartitioning(right_partitioning.partition_count()) - } - } -} - /// Adjust the right out partitioning to new Column Index pub fn adjust_right_output_partitioning( right_partitioning: &Partitioning, @@ -459,7 +437,7 @@ pub fn adjust_right_output_partitioning( Partitioning::Hash(exprs, size) => { let new_exprs = exprs .iter() - .map(|expr| add_offset_to_expr(expr.clone(), left_columns_len)) + .map(|expr| add_offset_to_expr(Arc::clone(expr), left_columns_len)) .collect(); Partitioning::Hash(new_exprs, *size) } @@ -471,16 +449,14 @@ pub fn adjust_right_output_partitioning( /// the left column (zeroth index in the tuple) inside `right_ordering`. fn replace_on_columns_of_right_ordering( on_columns: &[(PhysicalExprRef, PhysicalExprRef)], - right_ordering: &mut [PhysicalSortExpr], + right_ordering: &mut LexOrdering, ) -> Result<()> { for (left_col, right_col) in on_columns { - for item in right_ordering.iter_mut() { - let new_expr = item - .expr - .clone() + for item in right_ordering.inner.iter_mut() { + let new_expr = Arc::clone(&item.expr) .transform(|e| { if e.eq(right_col) { - Ok(Transformed::yes(left_col.clone())) + Ok(Transformed::yes(Arc::clone(left_col))) } else { Ok(Transformed::no(e)) } @@ -496,18 +472,18 @@ fn offset_ordering( ordering: LexOrderingRef, join_type: &JoinType, offset: usize, -) -> Vec { +) -> LexOrdering { match join_type { - // In the case below, right ordering should be offseted with the left + // In the case below, right ordering should be offsetted with the left // side length, since we append the right table to the left table. JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => ordering .iter() .map(|sort_expr| PhysicalSortExpr { - expr: add_offset_to_expr(sort_expr.expr.clone(), offset), + expr: add_offset_to_expr(Arc::clone(&sort_expr.expr), offset), options: sort_expr.options, }) .collect(), - _ => ordering.to_vec(), + _ => LexOrdering::from_ref(ordering), } } @@ -527,15 +503,16 @@ pub fn calculate_join_output_ordering( if join_type == JoinType::Inner && probe_side == Some(JoinSide::Left) { replace_on_columns_of_right_ordering( on_columns, - &mut right_ordering.to_vec(), + &mut LexOrdering::from_ref(right_ordering), ) .ok()?; merge_vectors( left_ordering, - &offset_ordering(right_ordering, &join_type, left_columns_len), + offset_ordering(right_ordering, &join_type, left_columns_len) + .as_ref(), ) } else { - left_ordering.to_vec() + LexOrdering::from_ref(left_ordering) } } [false, true] => { @@ -543,11 +520,12 @@ pub fn calculate_join_output_ordering( if join_type == JoinType::Inner && probe_side == Some(JoinSide::Right) { replace_on_columns_of_right_ordering( on_columns, - &mut right_ordering.to_vec(), + &mut LexOrdering::from_ref(right_ordering), ) .ok()?; merge_vectors( - &offset_ordering(right_ordering, &join_type, left_columns_len), + offset_ordering(right_ordering, &join_type, left_columns_len) + .as_ref(), left_ordering, ) } else { @@ -571,15 +549,16 @@ pub struct ColumnIndex { pub side: JoinSide, } -/// Filter applied before join output +/// Filter applied before join output. Fields are crate-public to allow +/// downstream implementations to experiment with custom joins. #[derive(Debug, Clone)] pub struct JoinFilter { /// Filter expression - expression: Arc, + pub(crate) expression: Arc, /// Column indices required to construct intermediate batch for filtering - column_indices: Vec, + pub(crate) column_indices: Vec, /// Physical schema of intermediate batch - schema: Schema, + pub(crate) schema: Schema, } impl JoinFilter { @@ -643,6 +622,7 @@ fn output_join_field(old_field: &Field, join_type: &JoinType, is_left: bool) -> JoinType::RightSemi => false, // doesn't introduce nulls JoinType::LeftAnti => false, // doesn't introduce nulls (or can it??) JoinType::RightAnti => false, // doesn't introduce nulls (or can it??) + JoinType::LeftMark => false, }; if force_nullable { @@ -659,44 +639,10 @@ pub fn build_join_schema( right: &Schema, join_type: &JoinType, ) -> (Schema, Vec) { - let (fields, column_indices): (SchemaBuilder, Vec) = match join_type { - JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { - let left_fields = left - .fields() - .iter() - .map(|f| output_join_field(f, join_type, true)) - .enumerate() - .map(|(index, f)| { - ( - f, - ColumnIndex { - index, - side: JoinSide::Left, - }, - ) - }); - let right_fields = right - .fields() - .iter() - .map(|f| output_join_field(f, join_type, false)) - .enumerate() - .map(|(index, f)| { - ( - f, - ColumnIndex { - index, - side: JoinSide::Right, - }, - ) - }); - - // left then right - left_fields.chain(right_fields).unzip() - } - JoinType::LeftSemi | JoinType::LeftAnti => left - .fields() + let left_fields = || { + left.fields() .iter() - .cloned() + .map(|f| output_join_field(f, join_type, true)) .enumerate() .map(|(index, f)| { ( @@ -707,11 +653,13 @@ pub fn build_join_schema( }, ) }) - .unzip(), - JoinType::RightSemi | JoinType::RightAnti => right + }; + + let right_fields = || { + right .fields() .iter() - .cloned() + .map(|f| output_join_field(f, join_type, false)) .enumerate() .map(|(index, f)| { ( @@ -722,10 +670,34 @@ pub fn build_join_schema( }, ) }) - .unzip(), }; - (fields.finish(), column_indices) + let (fields, column_indices): (SchemaBuilder, Vec) = match join_type { + JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right => { + // left then right + left_fields().chain(right_fields()).unzip() + } + JoinType::LeftSemi | JoinType::LeftAnti => left_fields().unzip(), + JoinType::LeftMark => { + let right_field = once(( + Field::new("mark", arrow_schema::DataType::Boolean, false), + ColumnIndex { + index: 0, + side: JoinSide::None, + }, + )); + left_fields().chain(right_field).unzip() + } + JoinType::RightSemi | JoinType::RightAnti => right_fields().unzip(), + }; + + let metadata = left + .metadata() + .clone() + .into_iter() + .chain(right.metadata().clone()) + .collect(); + (fields.finish().with_metadata(metadata), column_indices) } /// A [`OnceAsync`] can be used to run an async closure once, with subsequent calls @@ -745,8 +717,8 @@ impl Default for OnceAsync { } } -impl std::fmt::Debug for OnceAsync { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl Debug for OnceAsync { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "OnceAsync") } } @@ -848,12 +820,12 @@ fn estimate_join_cardinality( JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => { let ij_cardinality = estimate_inner_join_cardinality( Statistics { - num_rows: left_stats.num_rows.clone(), + num_rows: left_stats.num_rows, total_byte_size: Precision::Absent, column_statistics: left_col_stats, }, Statistics { - num_rows: right_stats.num_rows.clone(), + num_rows: right_stats.num_rows, total_byte_size: Precision::Absent, column_statistics: right_col_stats, }, @@ -920,6 +892,16 @@ fn estimate_join_cardinality( column_statistics: outer_stats.column_statistics, }) } + + JoinType::LeftMark => { + let num_rows = *left_stats.num_rows.get_value()?; + let mut column_statistics = left_stats.column_statistics; + column_statistics.push(ColumnStatistics::new_unknown()); + Some(PartialJoinStatistics { + num_rows, + column_statistics, + }) + } } } @@ -931,7 +913,7 @@ fn estimate_inner_join_cardinality( left_stats: Statistics, right_stats: Statistics, ) -> Option> { - // Immediatedly return if inputs considered as non-overlapping + // Immediately return if inputs considered as non-overlapping if let Some(estimation) = estimate_disjoint_inputs(&left_stats, &right_stats) { return Some(estimation); }; @@ -1045,7 +1027,7 @@ fn max_distinct_count( stats: &ColumnStatistics, ) -> Precision { match &stats.distinct_count { - dc @ (Precision::Exact(_) | Precision::Inexact(_)) => dc.clone(), + &dc @ (Precision::Exact(_) | Precision::Inexact(_)) => dc, _ => { // The number can never be greater than the number of rows we have // minus the nulls (since they don't count as distinct values). @@ -1141,7 +1123,7 @@ impl OnceFut { OnceFutState::Ready(r) => Poll::Ready( r.as_ref() .map(|r| r.as_ref()) - .map_err(|e| DataFusionError::External(Box::new(e.clone()))), + .map_err(|e| DataFusionError::External(Box::new(Arc::clone(e)))), ), } } @@ -1171,7 +1153,11 @@ impl OnceFut { pub(crate) fn need_produce_result_in_final(join_type: JoinType) -> bool { matches!( join_type, - JoinType::Left | JoinType::LeftAnti | JoinType::LeftSemi | JoinType::Full + JoinType::Left + | JoinType::LeftAnti + | JoinType::LeftSemi + | JoinType::LeftMark + | JoinType::Full ) } @@ -1189,6 +1175,13 @@ pub(crate) fn get_final_indices_from_bit_map( join_type: JoinType, ) -> (UInt64Array, UInt32Array) { let left_size = left_bit_map.len(); + if join_type == JoinType::LeftMark { + let left_indices = (0..left_size as u64).collect::(); + let right_indices = (0..left_size) + .map(|idx| left_bit_map.get_bit(idx).then_some(0)) + .collect::(); + return (left_indices, right_indices); + } let left_indices = if join_type == JoinType::LeftSemi { (0..left_size) .filter_map(|idx| (left_bit_map.get_bit(idx)).then_some(idx as u64)) @@ -1272,7 +1265,10 @@ pub(crate) fn build_batch_from_indices( let mut columns: Vec> = Vec::with_capacity(schema.fields().len()); for column_index in column_indices { - let array = if column_index.side == build_side { + let array = if column_index.side == JoinSide::None { + // LeftMark join, the mark column is a true if the indices is not null, otherwise it will be false + Arc::new(compute::is_not_null(probe_indices)?) + } else if column_index.side == build_side { let array = build_input_buffer.column(column_index.index); if array.is_empty() || build_indices.null_count() == build_indices.len() { // Outer join would generate a null index when finding no match at our side. @@ -1304,72 +1300,133 @@ pub(crate) fn adjust_indices_by_join_type( right_indices: UInt32Array, adjust_range: Range, join_type: JoinType, -) -> (UInt64Array, UInt32Array) { + preserve_order_for_right: bool, +) -> Result<(UInt64Array, UInt32Array)> { match join_type { JoinType::Inner => { // matched - (left_indices, right_indices) + Ok((left_indices, right_indices)) } JoinType::Left => { // matched - (left_indices, right_indices) + Ok((left_indices, right_indices)) // unmatched left row will be produced in the end of loop, and it has been set in the left visited bitmap } - JoinType::Right | JoinType::Full => { - // matched - // unmatched right row will be produced in this batch - let right_unmatched_indices = get_anti_indices(adjust_range, &right_indices); + JoinType::Right => { // combine the matched and unmatched right result together - append_right_indices(left_indices, right_indices, right_unmatched_indices) + append_right_indices( + left_indices, + right_indices, + adjust_range, + preserve_order_for_right, + ) + } + JoinType::Full => { + append_right_indices(left_indices, right_indices, adjust_range, false) } JoinType::RightSemi => { // need to remove the duplicated record in the right side let right_indices = get_semi_indices(adjust_range, &right_indices); // the left_indices will not be used later for the `right semi` join - (left_indices, right_indices) + Ok((left_indices, right_indices)) } JoinType::RightAnti => { // need to remove the duplicated record in the right side // get the anti index for the right side let right_indices = get_anti_indices(adjust_range, &right_indices); // the left_indices will not be used later for the `right anti` join - (left_indices, right_indices) + Ok((left_indices, right_indices)) } - JoinType::LeftSemi | JoinType::LeftAnti => { + JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { // matched or unmatched left row will be produced in the end of loop // When visit the right batch, we can output the matched left row and don't need to wait the end of loop - ( + Ok(( UInt64Array::from_iter_values(vec![]), UInt32Array::from_iter_values(vec![]), - ) + )) } } } -/// Appends the `right_unmatched_indices` to the `right_indices`, -/// and fills Null to tail of `left_indices` to -/// keep the length of `right_indices` and `left_indices` consistent. +/// Appends right indices to left indices based on the specified order mode. +/// +/// The function operates in two modes: +/// 1. If `preserve_order_for_right` is true, probe matched and unmatched indices +/// are inserted in order using the `append_probe_indices_in_order()` method. +/// 2. Otherwise, unmatched probe indices are simply appended after matched ones. +/// +/// # Parameters +/// - `left_indices`: UInt64Array of left indices. +/// - `right_indices`: UInt32Array of right indices. +/// - `adjust_range`: Range to adjust the right indices. +/// - `preserve_order_for_right`: Boolean flag to determine the mode of operation. +/// +/// # Returns +/// A tuple of updated `UInt64Array` and `UInt32Array`. pub(crate) fn append_right_indices( left_indices: UInt64Array, right_indices: UInt32Array, - right_unmatched_indices: UInt32Array, -) -> (UInt64Array, UInt32Array) { - // left_indices, right_indices and right_unmatched_indices must not contain the null value - if right_unmatched_indices.is_empty() { - (left_indices, right_indices) + adjust_range: Range, + preserve_order_for_right: bool, +) -> Result<(UInt64Array, UInt32Array)> { + if preserve_order_for_right { + Ok(append_probe_indices_in_order( + left_indices, + right_indices, + adjust_range, + )) } else { - let unmatched_size = right_unmatched_indices.len(); - // the new left indices: left_indices + null array - // the new right indices: right_indices + right_unmatched_indices - let new_left_indices = left_indices - .iter() - .chain(std::iter::repeat(None).take(unmatched_size)) - .collect::(); - let new_right_indices = right_indices - .iter() - .chain(right_unmatched_indices.iter()) - .collect::(); - (new_left_indices, new_right_indices) + let right_unmatched_indices = get_anti_indices(adjust_range, &right_indices); + + if right_unmatched_indices.is_empty() { + Ok((left_indices, right_indices)) + } else { + // `into_builder()` can fail here when there is nothing to be filtered and + // left_indices or right_indices has the same reference to the cached indices. + // In that case, we use a slower alternative. + + // the new left indices: left_indices + null array + let mut new_left_indices_builder = + left_indices.into_builder().unwrap_or_else(|left_indices| { + let mut builder = UInt64Builder::with_capacity( + left_indices.len() + right_unmatched_indices.len(), + ); + debug_assert_eq!( + left_indices.null_count(), + 0, + "expected left indices to have no nulls" + ); + builder.append_slice(left_indices.values()); + builder + }); + new_left_indices_builder.append_nulls(right_unmatched_indices.len()); + let new_left_indices = UInt64Array::from(new_left_indices_builder.finish()); + + // the new right indices: right_indices + right_unmatched_indices + let mut new_right_indices_builder = right_indices + .into_builder() + .unwrap_or_else(|right_indices| { + let mut builder = UInt32Builder::with_capacity( + right_indices.len() + right_unmatched_indices.len(), + ); + debug_assert_eq!( + right_indices.null_count(), + 0, + "expected right indices to have no nulls" + ); + builder.append_slice(right_indices.values()); + builder + }); + debug_assert_eq!( + right_unmatched_indices.null_count(), + 0, + "expected right unmatched indices to have no nulls" + ); + new_right_indices_builder.append_slice(right_unmatched_indices.values()); + let new_right_indices = UInt32Array::from(new_right_indices_builder.finish()); + + Ok((new_left_indices, new_right_indices)) + } } } @@ -1399,7 +1456,7 @@ where .filter_map(|idx| { (!bitmap.get_bit(idx - offset)).then_some(T::Native::from_usize(idx)) }) - .collect::>() + .collect() } /// Returns intersection of `range` and `input_indices` omitting duplicates @@ -1428,7 +1485,61 @@ where .filter_map(|idx| { (bitmap.get_bit(idx - offset)).then_some(T::Native::from_usize(idx)) }) - .collect::>() + .collect() +} + +/// Appends probe indices in order by considering the given build indices. +/// +/// This function constructs new build and probe indices by iterating through +/// the provided indices, and appends any missing values between previous and +/// current probe index with a corresponding null build index. +/// +/// # Parameters +/// +/// - `build_indices`: `PrimitiveArray` of `UInt64Type` containing build indices. +/// - `probe_indices`: `PrimitiveArray` of `UInt32Type` containing probe indices. +/// - `range`: The range of indices to consider. +/// +/// # Returns +/// +/// A tuple of two arrays: +/// - A `PrimitiveArray` of `UInt64Type` with the newly constructed build indices. +/// - A `PrimitiveArray` of `UInt32Type` with the newly constructed probe indices. +fn append_probe_indices_in_order( + build_indices: PrimitiveArray, + probe_indices: PrimitiveArray, + range: Range, +) -> (PrimitiveArray, PrimitiveArray) { + // Builders for new indices: + let mut new_build_indices = UInt64Builder::new(); + let mut new_probe_indices = UInt32Builder::new(); + // Set previous index as the start index for the initial loop: + let mut prev_index = range.start as u32; + // Zip the two iterators. + debug_assert!(build_indices.len() == probe_indices.len()); + for (build_index, probe_index) in build_indices + .values() + .into_iter() + .zip(probe_indices.values().into_iter()) + { + // Append values between previous and current probe index with null build index: + for value in prev_index..*probe_index { + new_probe_indices.append_value(value); + new_build_indices.append_null(); + } + // Append current indices: + new_probe_indices.append_value(*probe_index); + new_build_indices.append_value(*build_index); + // Set current probe index as previous for the next iteration: + prev_index = probe_index + 1; + } + // Append remaining probe indices after the last valid probe index with null build index. + for value in prev_index..range.end as u32 { + new_probe_indices.append_value(value); + new_build_indices.append_null(); + } + // Build arrays and return: + (new_build_indices.finish(), new_probe_indices.finish()) } /// Metrics for build & probe joins @@ -1494,10 +1605,9 @@ impl BuildProbeJoinMetrics { } /// The `handle_state` macro is designed to process the result of a state-changing -/// operation, encountered e.g. in implementations of `EagerJoinStream`. It -/// operates on a `StatefulStreamResult` by matching its variants and executing -/// corresponding actions. This macro is used to streamline code that deals with -/// state transitions, reducing boilerplate and improving readability. +/// operation. It operates on a `StatefulStreamResult` by matching its variants and +/// executing corresponding actions. This macro is used to streamline code that deals +/// with state transitions, reducing boilerplate and improving readability. /// /// # Cases /// @@ -1525,26 +1635,7 @@ macro_rules! handle_state { }; } -/// The `handle_async_state` macro adapts the `handle_state` macro for use in -/// asynchronous operations, particularly when dealing with `Poll` results within -/// async traits like `EagerJoinStream`. It polls the asynchronous state-changing -/// function using `poll_unpin` and then passes the result to `handle_state` for -/// further processing. -/// -/// # Arguments -/// -/// * `$state_func`: An async function or future that returns a -/// `Result>`. -/// * `$cx`: The context to be passed for polling, usually of type `&mut Context`. -/// -#[macro_export] -macro_rules! handle_async_state { - ($state_func:expr, $cx:expr) => { - $crate::handle_state!(ready!($state_func.poll_unpin($cx))) - }; -} - -/// Represents the result of an operation on stateful join stream. +/// Represents the result of a stateful operation. /// /// This enumueration indicates whether the state produced a result that is /// ready for use (`Ready`) or if the operation requires continuation (`Continue`). @@ -1560,6 +1651,135 @@ pub enum StatefulStreamResult { Continue, } +pub(crate) fn symmetric_join_output_partitioning( + left: &Arc, + right: &Arc, + join_type: &JoinType, +) -> Partitioning { + let left_columns_len = left.schema().fields.len(); + let left_partitioning = left.output_partitioning(); + let right_partitioning = right.output_partitioning(); + match join_type { + JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => { + left_partitioning.clone() + } + JoinType::RightSemi | JoinType::RightAnti => right_partitioning.clone(), + JoinType::Inner | JoinType::Right => { + adjust_right_output_partitioning(right_partitioning, left_columns_len) + } + JoinType::Full => { + // We could also use left partition count as they are necessarily equal. + Partitioning::UnknownPartitioning(right_partitioning.partition_count()) + } + } +} + +pub(crate) fn asymmetric_join_output_partitioning( + left: &Arc, + right: &Arc, + join_type: &JoinType, +) -> Partitioning { + match join_type { + JoinType::Inner | JoinType::Right => adjust_right_output_partitioning( + right.output_partitioning(), + left.schema().fields().len(), + ), + JoinType::RightSemi | JoinType::RightAnti => right.output_partitioning().clone(), + JoinType::Left + | JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::Full + | JoinType::LeftMark => Partitioning::UnknownPartitioning( + right.output_partitioning().partition_count(), + ), + } +} + +/// Trait for incrementally generating Join output. +/// +/// This trait is used to limit some join outputs +/// so it does not produce single large batches +pub(crate) trait BatchTransformer: Debug + Clone { + /// Sets the next `RecordBatch` to be processed. + fn set_batch(&mut self, batch: RecordBatch); + + /// Retrieves the next `RecordBatch` from the transformer. + /// Returns `None` if all batches have been produced. + /// The boolean flag indicates whether the batch is the last one. + fn next(&mut self) -> Option<(RecordBatch, bool)>; +} + +#[derive(Debug, Clone)] +/// A batch transformer that does nothing. +pub(crate) struct NoopBatchTransformer { + /// RecordBatch to be processed + batch: Option, +} + +impl NoopBatchTransformer { + pub fn new() -> Self { + Self { batch: None } + } +} + +impl BatchTransformer for NoopBatchTransformer { + fn set_batch(&mut self, batch: RecordBatch) { + self.batch = Some(batch); + } + + fn next(&mut self) -> Option<(RecordBatch, bool)> { + self.batch.take().map(|batch| (batch, true)) + } +} + +#[derive(Debug, Clone)] +/// Splits large batches into smaller batches with a maximum number of rows. +pub(crate) struct BatchSplitter { + /// RecordBatch to be split + batch: Option, + /// Maximum number of rows in a split batch + batch_size: usize, + /// Current row index + row_index: usize, +} + +impl BatchSplitter { + /// Creates a new `BatchSplitter` with the specified batch size. + pub(crate) fn new(batch_size: usize) -> Self { + Self { + batch: None, + batch_size, + row_index: 0, + } + } +} + +impl BatchTransformer for BatchSplitter { + fn set_batch(&mut self, batch: RecordBatch) { + self.batch = Some(batch); + self.row_index = 0; + } + + fn next(&mut self) -> Option<(RecordBatch, bool)> { + let Some(batch) = &self.batch else { + return None; + }; + + let remaining_rows = batch.num_rows() - self.row_index; + let rows_to_slice = remaining_rows.min(self.batch_size); + let sliced_batch = batch.slice(self.row_index, rows_to_slice); + self.row_index += rows_to_slice; + + let mut last = false; + if self.row_index >= batch.num_rows() { + self.batch = None; + last = true; + } + + Some((sliced_batch, last)) + } +} + #[cfg(test)] mod tests { use std::pin::Pin; @@ -1568,11 +1788,13 @@ mod tests { use arrow::datatypes::{DataType, Fields}; use arrow::error::{ArrowError, Result as ArrowResult}; + use arrow_array::Int32Array; use arrow_schema::SortOptions; - use datafusion_common::stats::Precision::{Absent, Exact, Inexact}; use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue}; + use rstest::rstest; + fn check( left: &[Column], right: &[Column], @@ -1746,13 +1968,13 @@ mod tests { ) -> Statistics { Statistics { num_rows: if is_exact { - num_rows.map(Precision::Exact) + num_rows.map(Exact) } else { - num_rows.map(Precision::Inexact) + num_rows.map(Inexact) } - .unwrap_or(Precision::Absent), + .unwrap_or(Absent), column_statistics: column_stats, - total_byte_size: Precision::Absent, + total_byte_size: Absent, } } @@ -1975,9 +2197,7 @@ mod tests { ); assert_eq!( partial_join_stats.map(|s| s.column_statistics), - expected_cardinality - .clone() - .map(|_| [left_col_stats, right_col_stats].concat()) + expected_cardinality.map(|_| [left_col_stats, right_col_stats].concat()) ); } Ok(()) @@ -2000,17 +2220,17 @@ mod tests { assert_eq!( estimate_inner_join_cardinality( Statistics { - num_rows: Precision::Inexact(400), - total_byte_size: Precision::Absent, + num_rows: Inexact(400), + total_byte_size: Absent, column_statistics: left_col_stats, }, Statistics { - num_rows: Precision::Inexact(400), - total_byte_size: Precision::Absent, + num_rows: Inexact(400), + total_byte_size: Absent, column_statistics: right_col_stats, }, ), - Some(Precision::Inexact((400 * 400) / 200)) + Some(Inexact((400 * 400) / 200)) ); Ok(()) } @@ -2018,33 +2238,33 @@ mod tests { #[test] fn test_inner_join_cardinality_decimal_range() -> Result<()> { let left_col_stats = vec![ColumnStatistics { - distinct_count: Precision::Absent, - min_value: Precision::Inexact(ScalarValue::Decimal128(Some(32500), 14, 4)), - max_value: Precision::Inexact(ScalarValue::Decimal128(Some(35000), 14, 4)), + distinct_count: Absent, + min_value: Inexact(ScalarValue::Decimal128(Some(32500), 14, 4)), + max_value: Inexact(ScalarValue::Decimal128(Some(35000), 14, 4)), ..Default::default() }]; let right_col_stats = vec![ColumnStatistics { - distinct_count: Precision::Absent, - min_value: Precision::Inexact(ScalarValue::Decimal128(Some(33500), 14, 4)), - max_value: Precision::Inexact(ScalarValue::Decimal128(Some(34000), 14, 4)), + distinct_count: Absent, + min_value: Inexact(ScalarValue::Decimal128(Some(33500), 14, 4)), + max_value: Inexact(ScalarValue::Decimal128(Some(34000), 14, 4)), ..Default::default() }]; assert_eq!( estimate_inner_join_cardinality( Statistics { - num_rows: Precision::Inexact(100), - total_byte_size: Precision::Absent, + num_rows: Inexact(100), + total_byte_size: Absent, column_statistics: left_col_stats, }, Statistics { - num_rows: Precision::Inexact(100), - total_byte_size: Precision::Absent, + num_rows: Inexact(100), + total_byte_size: Absent, column_statistics: right_col_stats, }, ), - Some(Precision::Inexact(100)) + Some(Inexact(100)) ); Ok(()) } @@ -2340,7 +2560,7 @@ mod tests { ); assert!( absent_outer_estimation.is_none(), - "Expected \"None\" esimated SemiJoin cardinality for absent outer num_rows" + "Expected \"None\" estimated SemiJoin cardinality for absent outer num_rows" ); let absent_inner_estimation = estimate_join_cardinality( @@ -2358,7 +2578,7 @@ mod tests { &join_on, ).expect("Expected non-empty PartialJoinStatistics for SemiJoin with absent inner num_rows"); - assert_eq!(absent_inner_estimation.num_rows, 500, "Expected outer.num_rows esimated SemiJoin cardinality for absent inner num_rows"); + assert_eq!(absent_inner_estimation.num_rows, 500, "Expected outer.num_rows estimated SemiJoin cardinality for absent inner num_rows"); let absent_inner_estimation = estimate_join_cardinality( &JoinType::LeftSemi, @@ -2370,11 +2590,11 @@ mod tests { Statistics { num_rows: Absent, total_byte_size: Absent, - column_statistics: dummy_column_stats.clone(), + column_statistics: dummy_column_stats, }, &join_on, ); - assert!(absent_inner_estimation.is_none(), "Expected \"None\" esimated SemiJoin cardinality for absent outer and inner num_rows"); + assert!(absent_inner_estimation.is_none(), "Expected \"None\" estimated SemiJoin cardinality for absent outer and inner num_rows"); Ok(()) } @@ -2382,7 +2602,7 @@ mod tests { #[test] fn test_calculate_join_output_ordering() -> Result<()> { let options = SortOptions::default(); - let left_ordering = vec![ + let left_ordering = LexOrdering::new(vec![ PhysicalSortExpr { expr: Arc::new(Column::new("a", 0)), options, @@ -2395,8 +2615,8 @@ mod tests { expr: Arc::new(Column::new("d", 3)), options, }, - ]; - let right_ordering = vec![ + ]); + let right_ordering = LexOrdering::new(vec![ PhysicalSortExpr { expr: Arc::new(Column::new("z", 2)), options, @@ -2405,7 +2625,7 @@ mod tests { expr: Arc::new(Column::new("y", 1)), options, }, - ]; + ]); let join_type = JoinType::Inner; let on_columns = [( Arc::new(Column::new("b", 1)) as _, @@ -2416,7 +2636,7 @@ mod tests { let probe_sides = [Some(JoinSide::Left), Some(JoinSide::Right)]; let expected = [ - Some(vec![ + Some(LexOrdering::new(vec![ PhysicalSortExpr { expr: Arc::new(Column::new("a", 0)), options, @@ -2437,8 +2657,8 @@ mod tests { expr: Arc::new(Column::new("y", 6)), options, }, - ]), - Some(vec![ + ])), + Some(LexOrdering::new(vec![ PhysicalSortExpr { expr: Arc::new(Column::new("z", 7)), options, @@ -2459,7 +2679,7 @@ mod tests { expr: Arc::new(Column::new("d", 3)), options, }, - ]), + ])), ]; for (i, (maintains_input_order, probe_side)) in @@ -2467,13 +2687,13 @@ mod tests { { assert_eq!( calculate_join_output_ordering( - &left_ordering, - &right_ordering, + left_ordering.as_ref(), + right_ordering.as_ref(), join_type, &on_columns, left_columns_len, maintains_input_order, - probe_side + probe_side, ), expected[i] ); @@ -2481,4 +2701,49 @@ mod tests { Ok(()) } + + fn create_test_batch(num_rows: usize) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + let data = Arc::new(Int32Array::from_iter_values(0..num_rows as i32)); + RecordBatch::try_new(schema, vec![data]).unwrap() + } + + fn assert_split_batches( + batches: Vec<(RecordBatch, bool)>, + batch_size: usize, + num_rows: usize, + ) { + let mut row_count = 0; + for (batch, last) in batches.into_iter() { + assert_eq!(batch.num_rows(), (num_rows - row_count).min(batch_size)); + let column = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..batch.num_rows() { + assert_eq!(column.value(i), i as i32 + row_count as i32); + } + row_count += batch.num_rows(); + assert_eq!(last, row_count == num_rows); + } + } + + #[rstest] + #[test] + fn test_batch_splitter( + #[values(1, 3, 11)] batch_size: usize, + #[values(1, 6, 50)] num_rows: usize, + ) { + let mut splitter = BatchSplitter::new(batch_size); + splitter.set_batch(create_test_batch(num_rows)); + + let mut batches = Vec::with_capacity(num_rows.div_ceil(batch_size)); + while let Some(batch) = splitter.next() { + batches.push(batch); + } + + assert!(splitter.next().is_none()); + assert_split_batches(batches, batch_size, num_rows); + } } diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index e1c8489655bf..845a74eaea48 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -15,30 +15,37 @@ // specific language governing permissions and limitations // under the License. -//! Traits for physical query plan, supporting parallel execution for partitioned relations. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 -use std::any::Any; -use std::fmt::Debug; -use std::sync::Arc; +#![deny(clippy::clone_on_ref_ptr)] -use crate::coalesce_partitions::CoalescePartitionsExec; -use crate::display::DisplayableExecutionPlan; -use crate::metrics::MetricsSet; -use crate::repartition::RepartitionExec; -use crate::sorts::sort_preserving_merge::SortPreservingMergeExec; +//! Traits for physical query plan, supporting parallel execution for partitioned relations. +//! +//! Entrypoint of this crate is trait [ExecutionPlan]. -use arrow::datatypes::SchemaRef; -use arrow::record_batch::RecordBatch; -use datafusion_common::config::ConfigOptions; -use datafusion_common::utils::DataPtr; -use datafusion_common::Result; -use datafusion_execution::TaskContext; -use datafusion_physical_expr::{ - EquivalenceProperties, LexOrdering, PhysicalSortExpr, PhysicalSortRequirement, +pub use datafusion_common::hash_utils; +pub use datafusion_common::utils::project_schema; +pub use datafusion_common::{internal_err, ColumnStatistics, Statistics}; +pub use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; +pub use datafusion_expr::{Accumulator, ColumnarValue}; +pub use datafusion_physical_expr::window::WindowExpr; +use datafusion_physical_expr::PhysicalSortExpr; +pub use datafusion_physical_expr::{ + expressions, udf, Distribution, Partitioning, PhysicalExpr, }; -use futures::stream::TryStreamExt; -use tokio::task::JoinSet; +pub use crate::display::{DefaultDisplay, DisplayAs, DisplayFormatType, VerboseDisplay}; +pub(crate) use crate::execution_plan::execution_mode_from_children; +pub use crate::execution_plan::{ + collect, collect_partitioned, displayable, execute_input_stream, execute_stream, + execute_stream_partitioned, get_plan_string, with_new_children_if_necessary, + ExecutionMode, ExecutionPlan, ExecutionPlanProperties, PlanProperties, +}; +pub use crate::metrics::Metric; +pub use crate::ordering::InputOrderMode; +pub use crate::stream::EmptyRecordBatchStream; +pub use crate::topk::TopK; +pub use crate::visitor::{accept, visit_execution_plan, ExecutionPlanVisitor}; mod ordering; mod topk; @@ -51,6 +58,7 @@ pub mod coalesce_partitions; pub mod common; pub mod display; pub mod empty; +pub mod execution_plan; pub mod explain; pub mod filter; pub mod insert; @@ -63,6 +71,7 @@ pub mod projection; pub mod recursive_query; pub mod repartition; pub mod sorts; +pub mod spill; pub mod stream; pub mod streaming; pub mod tree_node; @@ -72,853 +81,11 @@ pub mod values; pub mod windows; pub mod work_table; -pub use crate::display::{DefaultDisplay, DisplayAs, DisplayFormatType, VerboseDisplay}; -pub use crate::metrics::Metric; -pub use crate::ordering::InputOrderMode; -pub use crate::topk::TopK; -pub use crate::visitor::{accept, visit_execution_plan, ExecutionPlanVisitor}; - -pub use datafusion_common::hash_utils; -pub use datafusion_common::utils::project_schema; -pub use datafusion_common::{internal_err, ColumnStatistics, Statistics}; -pub use datafusion_expr::{Accumulator, ColumnarValue}; -pub use datafusion_physical_expr::window::WindowExpr; -pub use datafusion_physical_expr::{ - expressions, functions, udf, AggregateExpr, Distribution, Partitioning, PhysicalExpr, -}; - -// Backwards compatibility -pub use crate::stream::EmptyRecordBatchStream; -pub use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; pub mod udaf { - pub use datafusion_physical_expr_common::aggregate::{ - create_aggregate_expr, AggregateFunctionExpr, - }; -} - -/// Represent nodes in the DataFusion Physical Plan. -/// -/// Calling [`execute`] produces an `async` [`SendableRecordBatchStream`] of -/// [`RecordBatch`] that incrementally computes a partition of the -/// `ExecutionPlan`'s output from its input. See [`Partitioning`] for more -/// details on partitioning. -/// -/// Methods such as [`Self::schema`] and [`Self::properties`] communicate -/// properties of the output to the DataFusion optimizer, and methods such as -/// [`required_input_distribution`] and [`required_input_ordering`] express -/// requirements of the `ExecutionPlan` from its input. -/// -/// [`ExecutionPlan`] can be displayed in a simplified form using the -/// return value from [`displayable`] in addition to the (normally -/// quite verbose) `Debug` output. -/// -/// [`execute`]: ExecutionPlan::execute -/// [`required_input_distribution`]: ExecutionPlan::required_input_distribution -/// [`required_input_ordering`]: ExecutionPlan::required_input_ordering -pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { - /// Short name for the ExecutionPlan, such as 'ParquetExec'. - fn name(&self) -> &'static str { - let full_name = std::any::type_name::(); - let maybe_start_idx = full_name.rfind(':'); - match maybe_start_idx { - Some(start_idx) => &full_name[start_idx + 1..], - None => "UNKNOWN", - } - } - /// Returns the execution plan as [`Any`] so that it can be - /// downcast to a specific implementation. - fn as_any(&self) -> &dyn Any; - - /// Get the schema for this execution plan - fn schema(&self) -> SchemaRef { - self.properties().schema().clone() - } - - /// Return properties of the output of the `ExecutionPlan`, such as output - /// ordering(s), partitioning information etc. - /// - /// This information is available via methods on [`ExecutionPlanProperties`] - /// trait, which is implemented for all `ExecutionPlan`s. - fn properties(&self) -> &PlanProperties; - - /// Specifies the data distribution requirements for all the - /// children for this `ExecutionPlan`, By default it's [[Distribution::UnspecifiedDistribution]] for each child, - fn required_input_distribution(&self) -> Vec { - vec![Distribution::UnspecifiedDistribution; self.children().len()] - } - - /// Specifies the ordering required for all of the children of this - /// `ExecutionPlan`. - /// - /// For each child, it's the local ordering requirement within - /// each partition rather than the global ordering - /// - /// NOTE that checking `!is_empty()` does **not** check for a - /// required input ordering. Instead, the correct check is that at - /// least one entry must be `Some` - fn required_input_ordering(&self) -> Vec>> { - vec![None; self.children().len()] - } - - /// Returns `false` if this `ExecutionPlan`'s implementation may reorder - /// rows within or between partitions. - /// - /// For example, Projection, Filter, and Limit maintain the order - /// of inputs -- they may transform values (Projection) or not - /// produce the same number of rows that went in (Filter and - /// Limit), but the rows that are produced go in the same way. - /// - /// DataFusion uses this metadata to apply certain optimizations - /// such as automatically repartitioning correctly. - /// - /// The default implementation returns `false` - /// - /// WARNING: if you override this default, you *MUST* ensure that - /// the `ExecutionPlan`'s maintains the ordering invariant or else - /// DataFusion may produce incorrect results. - fn maintains_input_order(&self) -> Vec { - vec![false; self.children().len()] - } - - /// Specifies whether the `ExecutionPlan` benefits from increased - /// parallelization at its input for each child. - /// - /// If returns `true`, the `ExecutionPlan` would benefit from partitioning - /// its corresponding child (and thus from more parallelism). For - /// `ExecutionPlan` that do very little work the overhead of extra - /// parallelism may outweigh any benefits - /// - /// The default implementation returns `true` unless this `ExecutionPlan` - /// has signalled it requires a single child input partition. - fn benefits_from_input_partitioning(&self) -> Vec { - // By default try to maximize parallelism with more CPUs if - // possible - self.required_input_distribution() - .into_iter() - .map(|dist| !matches!(dist, Distribution::SinglePartition)) - .collect() - } - - /// Get a list of children `ExecutionPlan`s that act as inputs to this plan. - /// The returned list will be empty for leaf nodes such as scans, will contain - /// a single value for unary nodes, or two values for binary nodes (such as - /// joins). - fn children(&self) -> Vec>; - - /// Returns a new `ExecutionPlan` where all existing children were replaced - /// by the `children`, in order - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result>; - - /// If supported, attempt to increase the partitioning of this `ExecutionPlan` to - /// produce `target_partitions` partitions. - /// - /// If the `ExecutionPlan` does not support changing its partitioning, - /// returns `Ok(None)` (the default). - /// - /// It is the `ExecutionPlan` can increase its partitioning, but not to the - /// `target_partitions`, it may return an ExecutionPlan with fewer - /// partitions. This might happen, for example, if each new partition would - /// be too small to be efficiently processed individually. - /// - /// The DataFusion optimizer attempts to use as many threads as possible by - /// repartitioning its inputs to match the target number of threads - /// available (`target_partitions`). Some data sources, such as the built in - /// CSV and Parquet readers, implement this method as they are able to read - /// from their input files in parallel, regardless of how the source data is - /// split amongst files. - fn repartitioned( - &self, - _target_partitions: usize, - _config: &ConfigOptions, - ) -> Result>> { - Ok(None) - } - - /// Begin execution of `partition`, returning a [`Stream`] of - /// [`RecordBatch`]es. - /// - /// # Notes - /// - /// The `execute` method itself is not `async` but it returns an `async` - /// [`futures::stream::Stream`]. This `Stream` should incrementally compute - /// the output, `RecordBatch` by `RecordBatch` (in a streaming fashion). - /// Most `ExecutionPlan`s should not do any work before the first - /// `RecordBatch` is requested from the stream. - /// - /// [`RecordBatchStreamAdapter`] can be used to convert an `async` - /// [`Stream`] into a [`SendableRecordBatchStream`]. - /// - /// Using `async` `Streams` allows for network I/O during execution and - /// takes advantage of Rust's built in support for `async` continuations and - /// crate ecosystem. - /// - /// [`Stream`]: futures::stream::Stream - /// [`StreamExt`]: futures::stream::StreamExt - /// [`TryStreamExt`]: futures::stream::TryStreamExt - /// [`RecordBatchStreamAdapter`]: crate::stream::RecordBatchStreamAdapter - /// - /// # Cancellation / Aborting Execution - /// - /// The [`Stream`] that is returned must ensure that any allocated resources - /// are freed when the stream itself is dropped. This is particularly - /// important for [`spawn`]ed tasks or threads. Unless care is taken to - /// "abort" such tasks, they may continue to consume resources even after - /// the plan is dropped, generating intermediate results that are never - /// used. - /// Thus, [`spawn`] is disallowed, and instead use [`SpawnedTask`]. - /// - /// For more details see [`SpawnedTask`], [`JoinSet`] and [`RecordBatchReceiverStreamBuilder`] - /// for structures to help ensure all background tasks are cancelled. - /// - /// [`spawn`]: tokio::task::spawn - /// [`JoinSet`]: tokio::task::JoinSet - /// [`SpawnedTask`]: datafusion_common_runtime::SpawnedTask - /// [`RecordBatchReceiverStreamBuilder`]: crate::stream::RecordBatchReceiverStreamBuilder - /// - /// # Implementation Examples - /// - /// While `async` `Stream`s have a non trivial learning curve, the - /// [`futures`] crate provides [`StreamExt`] and [`TryStreamExt`] - /// which help simplify many common operations. - /// - /// Here are some common patterns: - /// - /// ## Return Precomputed `RecordBatch` - /// - /// We can return a precomputed `RecordBatch` as a `Stream`: - /// - /// ``` - /// # use std::sync::Arc; - /// # use arrow_array::RecordBatch; - /// # use arrow_schema::SchemaRef; - /// # use datafusion_common::Result; - /// # use datafusion_execution::{SendableRecordBatchStream, TaskContext}; - /// # use datafusion_physical_plan::memory::MemoryStream; - /// # use datafusion_physical_plan::stream::RecordBatchStreamAdapter; - /// struct MyPlan { - /// batch: RecordBatch, - /// } - /// - /// impl MyPlan { - /// fn execute( - /// &self, - /// partition: usize, - /// context: Arc - /// ) -> Result { - /// // use functions from futures crate convert the batch into a stream - /// let fut = futures::future::ready(Ok(self.batch.clone())); - /// let stream = futures::stream::once(fut); - /// Ok(Box::pin(RecordBatchStreamAdapter::new(self.batch.schema(), stream))) - /// } - /// } - /// ``` - /// - /// ## Lazily (async) Compute `RecordBatch` - /// - /// We can also lazily compute a `RecordBatch` when the returned `Stream` is polled - /// - /// ``` - /// # use std::sync::Arc; - /// # use arrow_array::RecordBatch; - /// # use arrow_schema::SchemaRef; - /// # use datafusion_common::Result; - /// # use datafusion_execution::{SendableRecordBatchStream, TaskContext}; - /// # use datafusion_physical_plan::memory::MemoryStream; - /// # use datafusion_physical_plan::stream::RecordBatchStreamAdapter; - /// struct MyPlan { - /// schema: SchemaRef, - /// } - /// - /// /// Returns a single batch when the returned stream is polled - /// async fn get_batch() -> Result { - /// todo!() - /// } - /// - /// impl MyPlan { - /// fn execute( - /// &self, - /// partition: usize, - /// context: Arc - /// ) -> Result { - /// let fut = get_batch(); - /// let stream = futures::stream::once(fut); - /// Ok(Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream))) - /// } - /// } - /// ``` - /// - /// ## Lazily (async) create a Stream - /// - /// If you need to create the return `Stream` using an `async` function, - /// you can do so by flattening the result: - /// - /// ``` - /// # use std::sync::Arc; - /// # use arrow_array::RecordBatch; - /// # use arrow_schema::SchemaRef; - /// # use futures::TryStreamExt; - /// # use datafusion_common::Result; - /// # use datafusion_execution::{SendableRecordBatchStream, TaskContext}; - /// # use datafusion_physical_plan::memory::MemoryStream; - /// # use datafusion_physical_plan::stream::RecordBatchStreamAdapter; - /// struct MyPlan { - /// schema: SchemaRef, - /// } - /// - /// /// async function that returns a stream - /// async fn get_batch_stream() -> Result { - /// todo!() - /// } - /// - /// impl MyPlan { - /// fn execute( - /// &self, - /// partition: usize, - /// context: Arc - /// ) -> Result { - /// // A future that yields a stream - /// let fut = get_batch_stream(); - /// // Use TryStreamExt::try_flatten to flatten the stream of streams - /// let stream = futures::stream::once(fut).try_flatten(); - /// Ok(Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream))) - /// } - /// } - /// ``` - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result; - - /// Return a snapshot of the set of [`Metric`]s for this - /// [`ExecutionPlan`]. If no `Metric`s are available, return None. - /// - /// While the values of the metrics in the returned - /// [`MetricsSet`]s may change as execution progresses, the - /// specific metrics will not. - /// - /// Once `self.execute()` has returned (technically the future is - /// resolved) for all available partitions, the set of metrics - /// should be complete. If this function is called prior to - /// `execute()` new metrics may appear in subsequent calls. - fn metrics(&self) -> Option { - None - } - - /// Returns statistics for this `ExecutionPlan` node. If statistics are not - /// available, should return [`Statistics::new_unknown`] (the default), not - /// an error. - fn statistics(&self) -> Result { - Ok(Statistics::new_unknown(&self.schema())) - } -} - -/// Extension trait provides an easy API to fetch various properties of -/// [`ExecutionPlan`] objects based on [`ExecutionPlan::properties`]. -pub trait ExecutionPlanProperties { - /// Specifies how the output of this `ExecutionPlan` is split into - /// partitions. - fn output_partitioning(&self) -> &Partitioning; - - /// Specifies whether this plan generates an infinite stream of records. - /// If the plan does not support pipelining, but its input(s) are - /// infinite, returns [`ExecutionMode::PipelineBreaking`] to indicate this. - fn execution_mode(&self) -> ExecutionMode; - - /// If the output of this `ExecutionPlan` within each partition is sorted, - /// returns `Some(keys)` describing the ordering. A `None` return value - /// indicates no assumptions should be made on the output ordering. - /// - /// For example, `SortExec` (obviously) produces sorted output as does - /// `SortPreservingMergeStream`. Less obviously, `Projection` produces sorted - /// output if its input is sorted as it does not reorder the input rows. - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]>; - - /// Get the [`EquivalenceProperties`] within the plan. - /// - /// Equivalence properties tell DataFusion what columns are known to be - /// equal, during various optimization passes. By default, this returns "no - /// known equivalences" which is always correct, but may cause DataFusion to - /// unnecessarily resort data. - /// - /// If this ExecutionPlan makes no changes to the schema of the rows flowing - /// through it or how columns within each row relate to each other, it - /// should return the equivalence properties of its input. For - /// example, since `FilterExec` may remove rows from its input, but does not - /// otherwise modify them, it preserves its input equivalence properties. - /// However, since `ProjectionExec` may calculate derived expressions, it - /// needs special handling. - /// - /// See also [`ExecutionPlan::maintains_input_order`] and [`Self::output_ordering`] - /// for related concepts. - fn equivalence_properties(&self) -> &EquivalenceProperties; -} - -impl ExecutionPlanProperties for Arc { - fn output_partitioning(&self) -> &Partitioning { - self.properties().output_partitioning() - } - - fn execution_mode(&self) -> ExecutionMode { - self.properties().execution_mode() - } - - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - self.properties().output_ordering() - } - - fn equivalence_properties(&self) -> &EquivalenceProperties { - self.properties().equivalence_properties() - } -} - -impl ExecutionPlanProperties for &dyn ExecutionPlan { - fn output_partitioning(&self) -> &Partitioning { - self.properties().output_partitioning() - } - - fn execution_mode(&self) -> ExecutionMode { - self.properties().execution_mode() - } - - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - self.properties().output_ordering() - } - - fn equivalence_properties(&self) -> &EquivalenceProperties { - self.properties().equivalence_properties() - } -} - -/// Describes the execution mode of an operator's resulting stream with respect -/// to its size and behavior. There are three possible execution modes: `Bounded`, -/// `Unbounded` and `PipelineBreaking`. -#[derive(Clone, Copy, PartialEq, Debug)] -pub enum ExecutionMode { - /// Represents the mode where generated stream is bounded, e.g. finite. - Bounded, - /// Represents the mode where generated stream is unbounded, e.g. infinite. - /// Even though the operator generates an unbounded stream of results, it - /// works with bounded memory and execution can still continue successfully. - /// - /// The stream that results from calling `execute` on an `ExecutionPlan` that is `Unbounded` - /// will never be done (return `None`), except in case of error. - Unbounded, - /// Represents the mode where some of the operator's input stream(s) are - /// unbounded; however, the operator cannot generate streaming results from - /// these streaming inputs. In this case, the execution mode will be pipeline - /// breaking, e.g. the operator requires unbounded memory to generate results. - PipelineBreaking, -} - -impl ExecutionMode { - /// Check whether the execution mode is unbounded or not. - pub fn is_unbounded(&self) -> bool { - matches!(self, ExecutionMode::Unbounded) - } - - /// Check whether the execution is pipeline friendly. If so, operator can - /// execute safely. - pub fn pipeline_friendly(&self) -> bool { - matches!(self, ExecutionMode::Bounded | ExecutionMode::Unbounded) - } -} - -/// Conservatively "combines" execution modes of a given collection of operators. -fn execution_mode_from_children<'a>( - children: impl IntoIterator>, -) -> ExecutionMode { - let mut result = ExecutionMode::Bounded; - for mode in children.into_iter().map(|child| child.execution_mode()) { - match (mode, result) { - (ExecutionMode::PipelineBreaking, _) - | (_, ExecutionMode::PipelineBreaking) => { - // If any of the modes is `PipelineBreaking`, so is the result: - return ExecutionMode::PipelineBreaking; - } - (ExecutionMode::Unbounded, _) | (_, ExecutionMode::Unbounded) => { - // Unbounded mode eats up bounded mode: - result = ExecutionMode::Unbounded; - } - (ExecutionMode::Bounded, ExecutionMode::Bounded) => { - // When both modes are bounded, so is the result: - result = ExecutionMode::Bounded; - } - } - } - result -} - -/// Stores certain, often expensive to compute, plan properties used in query -/// optimization. -/// -/// These properties are stored a single structure to permit this information to -/// be computed once and then those cached results used multiple times without -/// recomputation (aka a cache) -#[derive(Debug, Clone)] -pub struct PlanProperties { - /// See [ExecutionPlanProperties::equivalence_properties] - pub eq_properties: EquivalenceProperties, - /// See [ExecutionPlanProperties::output_partitioning] - pub partitioning: Partitioning, - /// See [ExecutionPlanProperties::execution_mode] - pub execution_mode: ExecutionMode, - /// See [ExecutionPlanProperties::output_ordering] - output_ordering: Option, -} - -impl PlanProperties { - /// Construct a new `PlanPropertiesCache` from the - pub fn new( - eq_properties: EquivalenceProperties, - partitioning: Partitioning, - execution_mode: ExecutionMode, - ) -> Self { - // Output ordering can be derived from `eq_properties`. - let output_ordering = eq_properties.output_ordering(); - Self { - eq_properties, - partitioning, - execution_mode, - output_ordering, - } - } - - /// Overwrite output partitioning with its new value. - pub fn with_partitioning(mut self, partitioning: Partitioning) -> Self { - self.partitioning = partitioning; - self - } - - /// Overwrite the execution Mode with its new value. - pub fn with_execution_mode(mut self, execution_mode: ExecutionMode) -> Self { - self.execution_mode = execution_mode; - self - } - - /// Overwrite equivalence properties with its new value. - pub fn with_eq_properties(mut self, eq_properties: EquivalenceProperties) -> Self { - // Changing equivalence properties also changes output ordering, so - // make sure to overwrite it: - self.output_ordering = eq_properties.output_ordering(); - self.eq_properties = eq_properties; - self - } - - pub fn equivalence_properties(&self) -> &EquivalenceProperties { - &self.eq_properties - } - - pub fn output_partitioning(&self) -> &Partitioning { - &self.partitioning - } - - pub fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - self.output_ordering.as_deref() - } - - pub fn execution_mode(&self) -> ExecutionMode { - self.execution_mode - } - - /// Get schema of the node. - fn schema(&self) -> &SchemaRef { - self.eq_properties.schema() - } -} - -/// Indicate whether a data exchange is needed for the input of `plan`, which will be very helpful -/// especially for the distributed engine to judge whether need to deal with shuffling. -/// Currently there are 3 kinds of execution plan which needs data exchange -/// 1. RepartitionExec for changing the partition number between two `ExecutionPlan`s -/// 2. CoalescePartitionsExec for collapsing all of the partitions into one without ordering guarantee -/// 3. SortPreservingMergeExec for collapsing all of the sorted partitions into one with ordering guarantee -pub fn need_data_exchange(plan: Arc) -> bool { - if let Some(repartition) = plan.as_any().downcast_ref::() { - !matches!( - repartition.properties().output_partitioning(), - Partitioning::RoundRobinBatch(_) - ) - } else if let Some(coalesce) = plan.as_any().downcast_ref::() - { - coalesce.input().output_partitioning().partition_count() > 1 - } else if let Some(sort_preserving_merge) = - plan.as_any().downcast_ref::() - { - sort_preserving_merge - .input() - .output_partitioning() - .partition_count() - > 1 - } else { - false - } -} - -/// Returns a copy of this plan if we change any child according to the pointer comparison. -/// The size of `children` must be equal to the size of `ExecutionPlan::children()`. -pub fn with_new_children_if_necessary( - plan: Arc, - children: Vec>, -) -> Result> { - let old_children = plan.children(); - if children.len() != old_children.len() { - internal_err!("Wrong number of children") - } else if children.is_empty() - || children - .iter() - .zip(old_children.iter()) - .any(|(c1, c2)| !Arc::data_ptr_eq(c1, c2)) - { - plan.with_new_children(children) - } else { - Ok(plan) - } -} - -/// Return a [wrapper](DisplayableExecutionPlan) around an -/// [`ExecutionPlan`] which can be displayed in various easier to -/// understand ways. -pub fn displayable(plan: &dyn ExecutionPlan) -> DisplayableExecutionPlan<'_> { - DisplayableExecutionPlan::new(plan) -} - -/// Execute the [ExecutionPlan] and collect the results in memory -pub async fn collect( - plan: Arc, - context: Arc, -) -> Result> { - let stream = execute_stream(plan, context)?; - common::collect(stream).await -} - -/// Execute the [ExecutionPlan] and return a single stream of `RecordBatch`es. -/// -/// See [collect] to buffer the `RecordBatch`es in memory. -/// -/// # Aborting Execution -/// -/// Dropping the stream will abort the execution of the query, and free up -/// any allocated resources -pub fn execute_stream( - plan: Arc, - context: Arc, -) -> Result { - match plan.output_partitioning().partition_count() { - 0 => Ok(Box::pin(EmptyRecordBatchStream::new(plan.schema()))), - 1 => plan.execute(0, context), - _ => { - // merge into a single partition - let plan = CoalescePartitionsExec::new(plan.clone()); - // CoalescePartitionsExec must produce a single partition - assert_eq!(1, plan.properties().output_partitioning().partition_count()); - plan.execute(0, context) - } - } -} - -/// Execute the [ExecutionPlan] and collect the results in memory -pub async fn collect_partitioned( - plan: Arc, - context: Arc, -) -> Result>> { - let streams = execute_stream_partitioned(plan, context)?; - - let mut join_set = JoinSet::new(); - // Execute the plan and collect the results into batches. - streams.into_iter().enumerate().for_each(|(idx, stream)| { - join_set.spawn(async move { - let result: Result> = stream.try_collect().await; - (idx, result) - }); - }); - - let mut batches = vec![]; - // Note that currently this doesn't identify the thread that panicked - // - // TODO: Replace with [join_next_with_id](https://docs.rs/tokio/latest/tokio/task/struct.JoinSet.html#method.join_next_with_id - // once it is stable - while let Some(result) = join_set.join_next().await { - match result { - Ok((idx, res)) => batches.push((idx, res?)), - Err(e) => { - if e.is_panic() { - std::panic::resume_unwind(e.into_panic()); - } else { - unreachable!(); - } - } - } - } - - batches.sort_by_key(|(idx, _)| *idx); - let batches = batches.into_iter().map(|(_, batch)| batch).collect(); - - Ok(batches) -} - -/// Execute the [ExecutionPlan] and return a vec with one stream per output -/// partition -/// -/// # Aborting Execution -/// -/// Dropping the stream will abort the execution of the query, and free up -/// any allocated resources -pub fn execute_stream_partitioned( - plan: Arc, - context: Arc, -) -> Result> { - let num_partitions = plan.output_partitioning().partition_count(); - let mut streams = Vec::with_capacity(num_partitions); - for i in 0..num_partitions { - streams.push(plan.execute(i, context.clone())?); - } - Ok(streams) -} - -/// Utility function yielding a string representation of the given [`ExecutionPlan`]. -pub fn get_plan_string(plan: &Arc) -> Vec { - let formatted = displayable(plan.as_ref()).indent(true).to_string(); - let actual: Vec<&str> = formatted.trim().lines().collect(); - actual.iter().map(|elem| elem.to_string()).collect() + pub use datafusion_expr::StatisticsArgs; + pub use datafusion_physical_expr::aggregate::AggregateFunctionExpr; } +pub mod coalesce; #[cfg(test)] -#[allow(clippy::single_component_path_imports)] -use rstest_reuse; - -#[cfg(test)] -mod tests { - use std::any::Any; - use std::sync::Arc; - - use arrow_schema::{Schema, SchemaRef}; - use datafusion_common::{Result, Statistics}; - use datafusion_execution::{SendableRecordBatchStream, TaskContext}; - - use crate::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties}; - - #[derive(Debug)] - pub struct EmptyExec; - - impl EmptyExec { - pub fn new(_schema: SchemaRef) -> Self { - Self - } - } - - impl DisplayAs for EmptyExec { - fn fmt_as( - &self, - _t: DisplayFormatType, - _f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - unimplemented!() - } - } - - impl ExecutionPlan for EmptyExec { - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { - unimplemented!() - } - - fn children(&self) -> Vec> { - vec![] - } - - fn with_new_children( - self: Arc, - _: Vec>, - ) -> Result> { - unimplemented!() - } - - fn execute( - &self, - _partition: usize, - _context: Arc, - ) -> Result { - unimplemented!() - } - - fn statistics(&self) -> Result { - unimplemented!() - } - } - - #[derive(Debug)] - pub struct RenamedEmptyExec; - - impl RenamedEmptyExec { - pub fn new(_schema: SchemaRef) -> Self { - Self - } - } - - impl DisplayAs for RenamedEmptyExec { - fn fmt_as( - &self, - _t: DisplayFormatType, - _f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { - unimplemented!() - } - } - - impl ExecutionPlan for RenamedEmptyExec { - fn name(&self) -> &'static str { - "MyRenamedEmptyExec" - } - - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { - unimplemented!() - } - - fn children(&self) -> Vec> { - vec![] - } - - fn with_new_children( - self: Arc, - _: Vec>, - ) -> Result> { - unimplemented!() - } - - fn execute( - &self, - _partition: usize, - _context: Arc, - ) -> Result { - unimplemented!() - } - - fn statistics(&self) -> Result { - unimplemented!() - } - } - - #[test] - fn test_execution_plan_name() { - let schema1 = Arc::new(Schema::empty()); - let default_name_exec = EmptyExec::new(schema1); - assert_eq!(default_name_exec.name(), "EmptyExec"); - - let schema2 = Arc::new(Schema::empty()); - let renamed_exec = RenamedEmptyExec::new(schema2); - assert_eq!(renamed_exec.name(), "MyRenamedEmptyExec"); - } -} - pub mod test; diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs index fab483b0da7d..ab1e6cb37bc8 100644 --- a/datafusion/physical-plan/src/limit.rs +++ b/datafusion/physical-plan/src/limit.rs @@ -31,15 +31,15 @@ use crate::{DisplayFormatType, Distribution, ExecutionPlan, Partitioning}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use datafusion_common::stats::Precision; use datafusion_common::{internal_err, Result}; use datafusion_execution::TaskContext; +use crate::execution_plan::CardinalityEffect; use futures::stream::{Stream, StreamExt}; use log::trace; /// Limit execution plan -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct GlobalLimitExec { /// Input execution plan input: Arc, @@ -124,8 +124,8 @@ impl ExecutionPlan for GlobalLimitExec { &self.cache } - fn children(&self) -> Vec> { - vec![self.input.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.input] } fn required_input_distribution(&self) -> Vec { @@ -145,7 +145,7 @@ impl ExecutionPlan for GlobalLimitExec { children: Vec>, ) -> Result> { Ok(Arc::new(GlobalLimitExec::new( - children[0].clone(), + Arc::clone(&children[0]), self.skip, self.fetch, ))) @@ -185,80 +185,21 @@ impl ExecutionPlan for GlobalLimitExec { } fn statistics(&self) -> Result { - let input_stats = self.input.statistics()?; - let skip = self.skip; - let col_stats = Statistics::unknown_column(&self.schema()); - let fetch = self.fetch.unwrap_or(usize::MAX); - - let mut fetched_row_number_stats = Statistics { - num_rows: Precision::Exact(fetch), - column_statistics: col_stats.clone(), - total_byte_size: Precision::Absent, - }; + Statistics::with_fetch( + self.input.statistics()?, + self.schema(), + self.fetch, + self.skip, + 1, + ) + } - let stats = match input_stats { - Statistics { - num_rows: Precision::Exact(nr), - .. - } - | Statistics { - num_rows: Precision::Inexact(nr), - .. - } => { - if nr <= skip { - // if all input data will be skipped, return 0 - let mut skip_all_rows_stats = Statistics { - num_rows: Precision::Exact(0), - column_statistics: col_stats, - total_byte_size: Precision::Absent, - }; - if !input_stats.num_rows.is_exact().unwrap_or(false) { - // The input stats are inexact, so the output stats must be too. - skip_all_rows_stats = skip_all_rows_stats.into_inexact(); - } - skip_all_rows_stats - } else if nr <= fetch && self.skip == 0 { - // if the input does not reach the "fetch" globally, and "skip" is zero - // (meaning the input and output are identical), return input stats. - // Can input_stats still be used, but adjusted, in the "skip != 0" case? - input_stats - } else if nr - skip <= fetch { - // after "skip" input rows are skipped, the remaining rows are less than or equal to the - // "fetch" values, so `num_rows` must equal the remaining rows - let remaining_rows: usize = nr - skip; - let mut skip_some_rows_stats = Statistics { - num_rows: Precision::Exact(remaining_rows), - column_statistics: col_stats, - total_byte_size: Precision::Absent, - }; - if !input_stats.num_rows.is_exact().unwrap_or(false) { - // The input stats are inexact, so the output stats must be too. - skip_some_rows_stats = skip_some_rows_stats.into_inexact(); - } - skip_some_rows_stats - } else { - // if the input is greater than "fetch+skip", the num_rows will be the "fetch", - // but we won't be able to predict the other statistics - if !input_stats.num_rows.is_exact().unwrap_or(false) - || self.fetch.is_none() - { - // If the input stats are inexact, the output stats must be too. - // If the fetch value is `usize::MAX` because no LIMIT was specified, - // we also can't represent it as an exact value. - fetched_row_number_stats = - fetched_row_number_stats.into_inexact(); - } - fetched_row_number_stats - } - } - _ => { - // The result output `num_rows` will always be no greater than the limit number. - // Should `num_rows` be marked as `Absent` here when the `fetch` value is large, - // as the actual `num_rows` may be far away from the `fetch` value? - fetched_row_number_stats.into_inexact() - } - }; - Ok(stats) + fn fetch(&self) -> Option { + self.fetch + } + + fn supports_limit_pushdown(&self) -> bool { + true } } @@ -334,8 +275,8 @@ impl ExecutionPlan for LocalLimitExec { &self.cache } - fn children(&self) -> Vec> { - vec![self.input.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.input] } fn benefits_from_input_partitioning(&self) -> Vec { @@ -352,7 +293,7 @@ impl ExecutionPlan for LocalLimitExec { ) -> Result> { match children.len() { 1 => Ok(Arc::new(LocalLimitExec::new( - children[0].clone(), + Arc::clone(&children[0]), self.fetch, ))), _ => internal_err!("LocalLimitExec wrong number of children"), @@ -380,58 +321,30 @@ impl ExecutionPlan for LocalLimitExec { } fn statistics(&self) -> Result { - let input_stats = self.input.statistics()?; - let col_stats = Statistics::unknown_column(&self.schema()); - let stats = match input_stats { - // if the input does not reach the limit globally, return input stats - Statistics { - num_rows: Precision::Exact(nr), - .. - } - | Statistics { - num_rows: Precision::Inexact(nr), - .. - } if nr <= self.fetch => input_stats, - // if the input is greater than the limit, the num_row will be greater - // than the limit because the partitions will be limited separatly - // the statistic - Statistics { - num_rows: Precision::Exact(nr), - .. - } if nr > self.fetch => Statistics { - num_rows: Precision::Exact(self.fetch), - // this is not actually exact, but will be when GlobalLimit is applied - // TODO stats: find a more explicit way to vehiculate this information - column_statistics: col_stats, - total_byte_size: Precision::Absent, - }, - Statistics { - num_rows: Precision::Inexact(nr), - .. - } if nr > self.fetch => Statistics { - num_rows: Precision::Inexact(self.fetch), - // this is not actually exact, but will be when GlobalLimit is applied - // TODO stats: find a more explicit way to vehiculate this information - column_statistics: col_stats, - total_byte_size: Precision::Absent, - }, - _ => Statistics { - // the result output row number will always be no greater than the limit number - num_rows: Precision::Inexact( - self.fetch - * self.properties().output_partitioning().partition_count(), - ), - - column_statistics: col_stats, - total_byte_size: Precision::Absent, - }, - }; - Ok(stats) + Statistics::with_fetch( + self.input.statistics()?, + self.schema(), + Some(self.fetch), + 0, + 1, + ) + } + + fn fetch(&self) -> Option { + Some(self.fetch) + } + + fn supports_limit_pushdown(&self) -> bool { + true + } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::LowerEqual } } /// A Limit stream skips `skip` rows, and then fetch up to `fetch` rows. -struct LimitStream { +pub struct LimitStream { /// The remaining number of rows to skip skip: usize, /// The remaining number of rows to produce @@ -446,7 +359,7 @@ struct LimitStream { } impl LimitStream { - fn new( + pub fn new( input: SendableRecordBatchStream, skip: usize, fetch: Option, @@ -485,7 +398,7 @@ impl LimitStream { if batch.num_rows() > 0 { break poll; } else { - // continue to poll input stream + // Continue to poll input stream } } Poll::Ready(Some(Err(_e))) => break poll, @@ -495,12 +408,12 @@ impl LimitStream { } } - /// fetches from the batch + /// Fetches from the batch fn stream_limit(&mut self, batch: RecordBatch) -> Option { // records time on drop let _timer = self.baseline_metrics.elapsed_compute().timer(); if self.fetch == 0 { - self.input = None; // clear input so it can be dropped early + self.input = None; // Clear input so it can be dropped early None } else if batch.num_rows() < self.fetch { // @@ -509,7 +422,7 @@ impl LimitStream { } else if batch.num_rows() >= self.fetch { let batch_rows = self.fetch; self.fetch = 0; - self.input = None; // clear input so it can be dropped early + self.input = None; // Clear input so it can be dropped early // It is guaranteed that batch_rows is <= batch.num_rows Some(batch.slice(0, batch_rows)) @@ -540,7 +453,7 @@ impl Stream for LimitStream { other => other, }) } - // input has been cleared + // Input has been cleared None => Poll::Ready(None), }; @@ -551,7 +464,7 @@ impl Stream for LimitStream { impl RecordBatchStream for LimitStream { /// Get the schema fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -560,11 +473,12 @@ mod tests { use super::*; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::common::collect; - use crate::{common, test}; + use crate::test; use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use arrow_array::RecordBatchOptions; use arrow_schema::Schema; + use datafusion_common::stats::Precision; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::PhysicalExpr; @@ -575,17 +489,17 @@ mod tests { let num_partitions = 4; let csv = test::scan_partitioned(num_partitions); - // input should have 4 partitions + // Input should have 4 partitions assert_eq!(csv.output_partitioning().partition_count(), num_partitions); let limit = GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), 0, Some(7)); - // the result should contain 4 batches (one per input partition) + // The result should contain 4 batches (one per input partition) let iter = limit.execute(0, task_ctx)?; - let batches = common::collect(iter).await?; + let batches = collect(iter).await?; - // there should be a total of 100 rows + // There should be a total of 100 rows let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum(); assert_eq!(row_count, 7); @@ -606,7 +520,7 @@ mod tests { let index = input.index(); assert_eq!(index.value(), 0); - // limit of six needs to consume the entire first record batch + // Limit of six needs to consume the entire first record batch // (5 rows) and 1 row from the second (1 row) let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0); let limit_stream = @@ -636,7 +550,7 @@ mod tests { let index = input.index(); assert_eq!(index.value(), 0); - // limit of six needs to consume the entire first record batch + // Limit of six needs to consume the entire first record batch // (6 rows) and stop immediately let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0); let limit_stream = @@ -666,7 +580,7 @@ mod tests { let index = input.index(); assert_eq!(index.value(), 0); - // limit of six needs to consume the entire first record batch + // Limit of six needs to consume the entire first record batch // (6 rows) and stop immediately let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0); let limit_stream = @@ -684,7 +598,7 @@ mod tests { Ok(()) } - // test cases for "skip" + // Test cases for "skip" async fn skip_and_fetch(skip: usize, fetch: Option) -> Result { let task_ctx = Arc::new(TaskContext::default()); @@ -697,9 +611,9 @@ mod tests { let offset = GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), skip, fetch); - // the result should contain 4 batches (one per input partition) + // The result should contain 4 batches (one per input partition) let iter = offset.execute(0, task_ctx)?; - let batches = common::collect(iter).await?; + let batches = collect(iter).await?; Ok(batches.iter().map(|batch| batch.num_rows()).sum()) } @@ -719,7 +633,7 @@ mod tests { #[tokio::test] async fn skip_3_fetch_none() -> Result<()> { - // there are total of 400 rows, we skipped 3 rows (offset = 3) + // There are total of 400 rows, we skipped 3 rows (offset = 3) let row_count = skip_and_fetch(3, None).await?; assert_eq!(row_count, 397); Ok(()) @@ -727,7 +641,7 @@ mod tests { #[tokio::test] async fn skip_3_fetch_10_stats() -> Result<()> { - // there are total of 100 rows, we skipped 3 rows (offset = 3) + // There are total of 100 rows, we skipped 3 rows (offset = 3) let row_count = skip_and_fetch(3, Some(10)).await?; assert_eq!(row_count, 10); Ok(()) @@ -742,7 +656,7 @@ mod tests { #[tokio::test] async fn skip_400_fetch_1() -> Result<()> { - // there are a total of 400 rows + // There are a total of 400 rows let row_count = skip_and_fetch(400, Some(1)).await?; assert_eq!(row_count, 0); Ok(()) @@ -750,7 +664,7 @@ mod tests { #[tokio::test] async fn skip_401_fetch_none() -> Result<()> { - // there are total of 400 rows, we skipped 401 rows (offset = 3) + // There are total of 400 rows, we skipped 401 rows (offset = 3) let row_count = skip_and_fetch(401, None).await?; assert_eq!(row_count, 0); Ok(()) @@ -794,7 +708,7 @@ mod tests { let row_count = row_number_inexact_statistics_for_global_limit(400, Some(10)).await?; - assert_eq!(row_count, Precision::Inexact(0)); + assert_eq!(row_count, Precision::Exact(0)); let row_count = row_number_inexact_statistics_for_global_limit(398, Some(10)).await?; @@ -864,11 +778,11 @@ mod tests { // Adding a "GROUP BY i" changes the input stats from Exact to Inexact. let agg = AggregateExec::try_new( AggregateMode::Final, - build_group_by(&csv.schema().clone(), vec!["i".to_string()]), + build_group_by(&csv.schema(), vec!["i".to_string()]), vec![], vec![], - csv.clone(), - csv.schema().clone(), + Arc::clone(&csv), + Arc::clone(&csv.schema()), )?; let agg_exec: Arc = Arc::new(agg); diff --git a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs index 883cdb540a9e..c9ada345afc7 100644 --- a/datafusion/physical-plan/src/memory.rs +++ b/datafusion/physical-plan/src/memory.rs @@ -22,7 +22,6 @@ use std::fmt; use std::sync::Arc; use std::task::{Context, Poll}; -use super::expressions::PhysicalSortExpr; use super::{ common, DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan, Partitioning, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, @@ -33,11 +32,15 @@ use arrow::record_batch::RecordBatch; use datafusion_common::{internal_err, project_schema, Result}; use datafusion_execution::memory_pool::MemoryReservation; use datafusion_execution::TaskContext; +use datafusion_physical_expr::equivalence::ProjectionMapping; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; use futures::Stream; /// Execution plan for reading in-memory batches of data +#[derive(Clone)] pub struct MemoryExec { /// The partitions to query partitions: Vec>, @@ -56,22 +59,17 @@ pub struct MemoryExec { impl fmt::Debug for MemoryExec { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "partitions: [...]")?; - write!(f, "schema: {:?}", self.projected_schema)?; - write!(f, "projection: {:?}", self.projection)?; - if let Some(sort_info) = &self.sort_information.first() { - write!(f, ", output_ordering: {:?}", sort_info)?; - } - Ok(()) + f.debug_struct("MemoryExec") + .field("partitions", &"[...]") + .field("schema", &self.schema) + .field("projection", &self.projection) + .field("sort_information", &self.sort_information) + .finish() } } impl DisplayAs for MemoryExec { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { let partition_sizes: Vec<_> = @@ -81,10 +79,7 @@ impl DisplayAs for MemoryExec { .sort_information .first() .map(|output_ordering| { - format!( - ", output_ordering={}", - PhysicalSortExpr::format_list(output_ordering) - ) + format!(", output_ordering={}", output_ordering) }) .unwrap_or_default(); @@ -116,8 +111,8 @@ impl ExecutionPlan for MemoryExec { &self.cache } - fn children(&self) -> Vec> { - // this is a leaf node and has no children + fn children(&self) -> Vec<&Arc> { + // This is a leaf node and has no children vec![] } @@ -140,7 +135,7 @@ impl ExecutionPlan for MemoryExec { ) -> Result { Ok(Box::pin(MemoryStream::try_new( self.partitions[partition].clone(), - self.projected_schema.clone(), + Arc::clone(&self.projected_schema), self.projection.clone(), )?)) } @@ -164,7 +159,8 @@ impl MemoryExec { projection: Option>, ) -> Result { let projected_schema = project_schema(&schema, projection.as_ref())?; - let cache = Self::compute_properties(projected_schema.clone(), &[], partitions); + let cache = + Self::compute_properties(Arc::clone(&projected_schema), &[], partitions); Ok(Self { partitions: partitions.to_vec(), schema, @@ -176,7 +172,7 @@ impl MemoryExec { }) } - /// set `show_sizes` to determine whether to display partition sizes + /// Set `show_sizes` to determine whether to display partition sizes pub fn with_show_sizes(mut self, show_sizes: bool) -> Self { self.show_sizes = show_sizes; self @@ -206,20 +202,67 @@ impl MemoryExec { /// where both `a ASC` and `b DESC` can describe the table ordering. With /// [`EquivalenceProperties`], we can keep track of these equivalences /// and treat `a ASC` and `b DESC` as the same ordering requirement. - pub fn with_sort_information(mut self, sort_information: Vec) -> Self { - self.sort_information = sort_information; + /// + /// Note that if there is an internal projection, that projection will be + /// also applied to the given `sort_information`. + pub fn try_with_sort_information( + mut self, + mut sort_information: Vec, + ) -> Result { + // All sort expressions must refer to the original schema + let fields = self.schema.fields(); + let ambiguous_column = sort_information + .iter() + .flat_map(|ordering| ordering.inner.clone()) + .flat_map(|expr| collect_columns(&expr.expr)) + .find(|col| { + fields + .get(col.index()) + .map(|field| field.name() != col.name()) + .unwrap_or(true) + }); + if let Some(col) = ambiguous_column { + return internal_err!( + "Column {:?} is not found in the original schema of the MemoryExec", + col + ); + } + // If there is a projection on the source, we also need to project orderings + if let Some(projection) = &self.projection { + let base_eqp = EquivalenceProperties::new_with_orderings( + self.original_schema(), + &sort_information, + ); + let proj_exprs = projection + .iter() + .map(|idx| { + let base_schema = self.original_schema(); + let name = base_schema.field(*idx).name(); + (Arc::new(Column::new(name, *idx)) as _, name.to_string()) + }) + .collect::>(); + let projection_mapping = + ProjectionMapping::try_new(&proj_exprs, &self.original_schema())?; + sort_information = base_eqp + .project(&projection_mapping, self.schema()) + .oeq_class + .orderings; + } + + self.sort_information = sort_information; // We need to update equivalence properties when updating sort information. let eq_properties = EquivalenceProperties::new_with_orderings( self.schema(), &self.sort_information, ); self.cache = self.cache.with_eq_properties(eq_properties); - self + + Ok(self) } pub fn original_schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. @@ -305,7 +348,7 @@ impl Stream for MemoryStream { impl RecordBatchStream for MemoryStream { /// Get the schema fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -319,6 +362,7 @@ mod tests { use arrow_schema::{DataType, Field, Schema, SortOptions}; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::PhysicalSortExpr; + use datafusion_physical_expr_common::sort_expr::LexOrdering; #[test] fn test_memory_order_eq() -> datafusion_common::Result<()> { @@ -327,7 +371,7 @@ mod tests { Field::new("b", DataType::Int64, false), Field::new("c", DataType::Int64, false), ])); - let sort1 = vec![ + let sort1 = LexOrdering::new(vec![ PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions::default(), @@ -336,22 +380,22 @@ mod tests { expr: col("b", &schema)?, options: SortOptions::default(), }, - ]; - let sort2 = vec![PhysicalSortExpr { + ]); + let sort2 = LexOrdering::new(vec![PhysicalSortExpr { expr: col("c", &schema)?, options: SortOptions::default(), - }]; - let mut expected_output_order = vec![]; + }]); + let mut expected_output_order = LexOrdering::default(); expected_output_order.extend(sort1.clone()); expected_output_order.extend(sort2.clone()); let sort_information = vec![sort1.clone(), sort2.clone()]; let mem_exec = MemoryExec::try_new(&[vec![]], schema, None)? - .with_sort_information(sort_information); + .try_with_sort_information(sort_information)?; assert_eq!( - mem_exec.properties().output_ordering().unwrap(), - expected_output_order + mem_exec.properties().output_ordering().unwrap().to_vec(), + expected_output_order.inner ); let eq_properties = mem_exec.properties().equivalence_properties(); assert!(eq_properties.oeq_class().contains(&sort1)); diff --git a/datafusion/physical-plan/src/metrics/baseline.rs b/datafusion/physical-plan/src/metrics/baseline.rs index dc345cd8cdcd..b26a08dd0fad 100644 --- a/datafusion/physical-plan/src/metrics/baseline.rs +++ b/datafusion/physical-plan/src/metrics/baseline.rs @@ -56,7 +56,7 @@ pub struct BaselineMetrics { } impl BaselineMetrics { - /// Create a new BaselineMetric structure, and set `start_time` to now + /// Create a new BaselineMetric structure, and set `start_time` to now pub fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self { let start_time = MetricBuilder::new(metrics).start_timestamp(partition); start_time.record(); diff --git a/datafusion/physical-plan/src/metrics/mod.rs b/datafusion/physical-plan/src/metrics/mod.rs index 9232865aa09c..ead0ca336938 100644 --- a/datafusion/physical-plan/src/metrics/mod.rs +++ b/datafusion/physical-plan/src/metrics/mod.rs @@ -301,8 +301,12 @@ impl MetricsSet { /// Sort the order of metrics so the "most useful" show up first pub fn sorted_for_display(mut self) -> Self { - self.metrics - .sort_unstable_by_key(|metric| metric.value().display_sort_key()); + self.metrics.sort_unstable_by_key(|metric| { + ( + metric.value().display_sort_key(), + metric.value().name().to_owned(), + ) + }); self } @@ -665,7 +669,9 @@ mod tests { MetricBuilder::new(&metrics).end_timestamp(0); MetricBuilder::new(&metrics).start_timestamp(0); MetricBuilder::new(&metrics).elapsed_compute(0); + MetricBuilder::new(&metrics).counter("the_second_counter", 0); MetricBuilder::new(&metrics).counter("the_counter", 0); + MetricBuilder::new(&metrics).counter("the_third_counter", 0); MetricBuilder::new(&metrics).subset_time("the_time", 0); MetricBuilder::new(&metrics).output_rows(0); let metrics = metrics.clone_inner(); @@ -675,9 +681,9 @@ mod tests { n.join(", ") } - assert_eq!("end_timestamp, start_timestamp, elapsed_compute, the_counter, the_time, output_rows", metric_names(&metrics)); + assert_eq!("end_timestamp, start_timestamp, elapsed_compute, the_second_counter, the_counter, the_third_counter, the_time, output_rows", metric_names(&metrics)); let metrics = metrics.sorted_for_display(); - assert_eq!("output_rows, elapsed_compute, the_counter, the_time, start_timestamp, end_timestamp", metric_names(&metrics)); + assert_eq!("output_rows, elapsed_compute, the_counter, the_second_counter, the_third_counter, the_time, start_timestamp, end_timestamp", metric_names(&metrics)); } } diff --git a/datafusion/physical-plan/src/metrics/value.rs b/datafusion/physical-plan/src/metrics/value.rs index 22db8f1e4e88..2eb01914ee0a 100644 --- a/datafusion/physical-plan/src/metrics/value.rs +++ b/datafusion/physical-plan/src/metrics/value.rs @@ -37,7 +37,7 @@ use parking_lot::Mutex; #[derive(Debug, Clone)] pub struct Count { /// value of the metric counter - value: std::sync::Arc, + value: Arc, } impl PartialEq for Count { @@ -86,7 +86,7 @@ impl Count { #[derive(Debug, Clone)] pub struct Gauge { /// value of the metric gauge - value: std::sync::Arc, + value: Arc, } impl PartialEq for Gauge { @@ -168,7 +168,7 @@ impl PartialEq for Time { impl Display for Time { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - let duration = std::time::Duration::from_nanos(self.value() as u64); + let duration = Duration::from_nanos(self.value() as u64); write!(f, "{duration:?}") } } diff --git a/datafusion/physical-plan/src/placeholder_row.rs b/datafusion/physical-plan/src/placeholder_row.rs index c94c2b0607d7..f9437f46f8a6 100644 --- a/datafusion/physical-plan/src/placeholder_row.rs +++ b/datafusion/physical-plan/src/placeholder_row.rs @@ -37,7 +37,7 @@ use datafusion_physical_expr::EquivalenceProperties; use log::trace; /// Execution plan for empty relation with produce_one_row=true -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct PlaceholderRowExec { /// The schema for the produced row schema: SchemaRef, @@ -50,7 +50,7 @@ impl PlaceholderRowExec { /// Create a new PlaceholderRowExec pub fn new(schema: SchemaRef) -> Self { let partitions = 1; - let cache = Self::compute_properties(schema.clone(), partitions); + let cache = Self::compute_properties(Arc::clone(&schema), partitions); PlaceholderRowExec { schema, partitions, @@ -132,7 +132,7 @@ impl ExecutionPlan for PlaceholderRowExec { &self.cache } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { vec![] } @@ -160,7 +160,7 @@ impl ExecutionPlan for PlaceholderRowExec { Ok(Box::pin(MemoryStream::try_new( self.data()?, - self.schema.clone(), + Arc::clone(&self.schema), None, )?)) } @@ -188,7 +188,10 @@ mod tests { let placeholder = Arc::new(PlaceholderRowExec::new(schema)); - let placeholder_2 = with_new_children_if_necessary(placeholder.clone(), vec![])?; + let placeholder_2 = with_new_children_if_necessary( + Arc::clone(&placeholder) as Arc, + vec![], + )?; assert_eq!(placeholder.schema(), placeholder_2.schema()); let too_many_kids = vec![placeholder_2]; @@ -205,8 +208,8 @@ mod tests { let schema = test::aggr_test_schema(); let placeholder = PlaceholderRowExec::new(schema); - // ask for the wrong partition - assert!(placeholder.execute(1, task_ctx.clone()).is_err()); + // Ask for the wrong partition + assert!(placeholder.execute(1, Arc::clone(&task_ctx)).is_err()); assert!(placeholder.execute(20, task_ctx).is_err()); Ok(()) } @@ -220,7 +223,7 @@ mod tests { let iter = placeholder.execute(0, task_ctx)?; let batches = common::collect(iter).await?; - // should have one item + // Should have one item assert_eq!(batches.len(), 1); Ok(()) @@ -234,10 +237,10 @@ mod tests { let placeholder = PlaceholderRowExec::new(schema).with_partitions(partitions); for n in 0..partitions { - let iter = placeholder.execute(n, task_ctx.clone())?; + let iter = placeholder.execute(n, Arc::clone(&task_ctx))?; let batches = common::collect(iter).await?; - // should have one item + // Should have one item assert_eq!(batches.len(), 1); } diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index f72815c01a9e..c1d3f368366f 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -32,9 +32,7 @@ use super::{ DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, }; -use crate::{ - ColumnStatistics, DisplayFormatType, ExecutionPlan, Partitioning, PhysicalExpr, -}; +use crate::{ColumnStatistics, DisplayFormatType, ExecutionPlan, PhysicalExpr}; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::record_batch::{RecordBatch, RecordBatchOptions}; @@ -42,8 +40,9 @@ use datafusion_common::stats::Precision; use datafusion_common::Result; use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::ProjectionMapping; -use datafusion_physical_expr::expressions::{Literal, UnKnownColumn}; +use datafusion_physical_expr::expressions::{CastExpr, Literal}; +use crate::execution_plan::CardinalityEffect; use futures::stream::{Stream, StreamExt}; use log::trace; @@ -91,10 +90,10 @@ impl ProjectionExec { input_schema.metadata().clone(), )); - // construct a map from the input expressions to the output expression of the Projection + // Construct a map from the input expressions to the output expression of the Projection let projection_mapping = ProjectionMapping::try_new(&expr, &input_schema)?; let cache = - Self::compute_properties(&input, &projection_mapping, schema.clone())?; + Self::compute_properties(&input, &projection_mapping, Arc::clone(&schema))?; Ok(Self { expr, schema, @@ -127,22 +126,8 @@ impl ProjectionExec { // Calculate output partitioning, which needs to respect aliases: let input_partition = input.output_partitioning(); - let output_partitioning = if let Partitioning::Hash(exprs, part) = input_partition - { - let normalized_exprs = exprs - .iter() - .map(|expr| { - input_eq_properties - .project_expr(expr, projection_mapping) - .unwrap_or_else(|| { - Arc::new(UnKnownColumn::new(&expr.to_string())) - }) - }) - .collect(); - Partitioning::Hash(normalized_exprs, *part) - } else { - input_partition.clone() - }; + let output_partitioning = + input_partition.project(projection_mapping, &input_eq_properties); Ok(PlanProperties::new( eq_properties, @@ -193,12 +178,12 @@ impl ExecutionPlan for ProjectionExec { &self.cache } - fn children(&self) -> Vec> { - vec![self.input.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.input] } fn maintains_input_order(&self) -> Vec { - // tell optimizer this operator doesn't reorder its input + // Tell optimizer this operator doesn't reorder its input vec![true] } @@ -227,8 +212,8 @@ impl ExecutionPlan for ProjectionExec { ) -> Result { trace!("Start ProjectionExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); Ok(Box::pin(ProjectionStream { - schema: self.schema.clone(), - expr: self.expr.iter().map(|x| x.0.clone()).collect(), + schema: Arc::clone(&self.schema), + expr: self.expr.iter().map(|x| Arc::clone(&x.0)).collect(), input: self.input.execute(partition, context)?, baseline_metrics: BaselineMetrics::new(&self.metrics, partition), })) @@ -242,17 +227,29 @@ impl ExecutionPlan for ProjectionExec { Ok(stats_projection( self.input.statistics()?, self.expr.iter().map(|(e, _)| Arc::clone(e)), - self.schema.clone(), + Arc::clone(&self.schema), )) } + + fn supports_limit_pushdown(&self) -> bool { + true + } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::Equal + } } -/// If e is a direct column reference, returns the field level +/// If 'e' is a direct column reference, returns the field level /// metadata for that field, if any. Otherwise returns None -fn get_field_metadata( +pub(crate) fn get_field_metadata( e: &Arc, input_schema: &Schema, ) -> Option> { + if let Some(cast) = e.as_any().downcast_ref::() { + return get_field_metadata(cast.expr(), input_schema); + } + // Look up field by index in schema (not NAME as there can be more than one // column with the same name) e.as_any() @@ -297,7 +294,7 @@ fn stats_projection( impl ProjectionStream { fn batch_project(&self, batch: &RecordBatch) -> Result { - // records time on drop + // Records time on drop let _timer = self.baseline_metrics.elapsed_compute().timer(); let arrays = self .expr @@ -311,10 +308,10 @@ impl ProjectionStream { if arrays.is_empty() { let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows())); - RecordBatch::try_new_with_options(self.schema.clone(), arrays, &options) + RecordBatch::try_new_with_options(Arc::clone(&self.schema), arrays, &options) .map_err(Into::into) } else { - RecordBatch::try_new(self.schema.clone(), arrays).map_err(Into::into) + RecordBatch::try_new(Arc::clone(&self.schema), arrays).map_err(Into::into) } } } @@ -343,7 +340,7 @@ impl Stream for ProjectionStream { } fn size_hint(&self) -> (usize, Option) { - // same number of record batches + // Same number of record batches self.input.size_hint() } } @@ -351,7 +348,7 @@ impl Stream for ProjectionStream { impl RecordBatchStream for ProjectionStream { /// Get the schema fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -359,7 +356,6 @@ impl RecordBatchStream for ProjectionStream { mod tests { use super::*; use crate::common::collect; - use crate::expressions; use crate::test; use arrow_schema::DataType; @@ -370,10 +366,12 @@ mod tests { let task_ctx = Arc::new(TaskContext::default()); let exec = test::scan_partitioned(1); - let expected = collect(exec.execute(0, task_ctx.clone())?).await.unwrap(); + let expected = collect(exec.execute(0, Arc::clone(&task_ctx))?) + .await + .unwrap(); let projection = ProjectionExec::try_new(vec![], exec)?; - let stream = projection.execute(0, task_ctx.clone())?; + let stream = projection.execute(0, Arc::clone(&task_ctx))?; let output = collect(stream).await.unwrap(); assert_eq!(output.len(), expected.len()); @@ -419,8 +417,8 @@ mod tests { let schema = get_schema(); let exprs: Vec> = vec![ - Arc::new(expressions::Column::new("col1", 1)), - Arc::new(expressions::Column::new("col0", 0)), + Arc::new(Column::new("col1", 1)), + Arc::new(Column::new("col0", 0)), ]; let result = stats_projection(source, exprs.into_iter(), Arc::new(schema)); @@ -453,8 +451,8 @@ mod tests { let schema = get_schema(); let exprs: Vec> = vec![ - Arc::new(expressions::Column::new("col2", 2)), - Arc::new(expressions::Column::new("col0", 0)), + Arc::new(Column::new("col2", 2)), + Arc::new(Column::new("col0", 0)), ]; let result = stats_projection(source, exprs.into_iter(), Arc::new(schema)); diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index ed897d78f0c8..cbf22a4b392f 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -53,7 +53,7 @@ use futures::{ready, Stream, StreamExt}; /// Note that there won't be any limit or checks applied to detect /// an infinite recursion, so it is up to the planner to ensure that /// it won't happen. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct RecursiveQueryExec { /// Name of the query handler name: String, @@ -82,7 +82,7 @@ impl RecursiveQueryExec { // Each recursive query needs its own work table let work_table = Arc::new(WorkTable::new()); // Use the same work table for both the WorkTableExec and the recursive term - let recursive_term = assign_work_table(recursive_term, work_table.clone())?; + let recursive_term = assign_work_table(recursive_term, Arc::clone(&work_table))?; let cache = Self::compute_properties(static_term.schema()); Ok(RecursiveQueryExec { name, @@ -120,8 +120,8 @@ impl ExecutionPlan for RecursiveQueryExec { &self.cache } - fn children(&self) -> Vec> { - vec![self.static_term.clone(), self.recursive_term.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.static_term, &self.recursive_term] } // TODO: control these hints and see whether we can @@ -147,8 +147,8 @@ impl ExecutionPlan for RecursiveQueryExec { ) -> Result> { RecursiveQueryExec::try_new( self.name.clone(), - children[0].clone(), - children[1].clone(), + Arc::clone(&children[0]), + Arc::clone(&children[1]), self.is_distinct, ) .map(|e| Arc::new(e) as _) @@ -167,12 +167,12 @@ impl ExecutionPlan for RecursiveQueryExec { ))); } - let static_stream = self.static_term.execute(partition, context.clone())?; + let static_stream = self.static_term.execute(partition, Arc::clone(&context))?; let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); Ok(Box::pin(RecursiveQueryStream::new( context, - self.work_table.clone(), - self.recursive_term.clone(), + Arc::clone(&self.work_table), + Arc::clone(&self.recursive_term), static_stream, baseline_metrics, ))) @@ -313,9 +313,9 @@ impl RecursiveQueryStream { // Downstream plans should not expect any partitioning. let partition = 0; - let recursive_plan = reset_plan_states(self.recursive_term.clone())?; + let recursive_plan = reset_plan_states(Arc::clone(&self.recursive_term))?; self.recursive_stream = - Some(recursive_plan.execute(partition, self.task_context.clone())?); + Some(recursive_plan.execute(partition, Arc::clone(&self.task_context))?); self.poll_next(cx) } } @@ -334,7 +334,7 @@ fn assign_work_table( } else { work_table_refs += 1; Ok(Transformed::yes(Arc::new( - exec.with_work_table(work_table.clone()), + exec.with_work_table(Arc::clone(&work_table)), ))) } } else if plan.as_any().is::() { @@ -358,7 +358,8 @@ fn reset_plan_states(plan: Arc) -> Result() { Ok(Transformed::no(plan)) } else { - let new_plan = plan.clone().with_new_children(plan.children())?; + let new_plan = Arc::clone(&plan) + .with_new_children(plan.children().into_iter().cloned().collect())?; Ok(Transformed::yes(new_plan)) } }) @@ -393,7 +394,7 @@ impl Stream for RecursiveQueryStream { self.recursive_stream = None; self.poll_next_iteration(cx) } - Some(Ok(batch)) => self.push_batch(batch.clone()), + Some(Ok(batch)) => self.push_batch(batch), _ => Poll::Ready(batch_result), } } else { @@ -405,7 +406,7 @@ impl Stream for RecursiveQueryStream { impl RecordBatchStream for RecursiveQueryStream { /// Get the schema fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } diff --git a/datafusion/physical-plan/src/repartition/distributor_channels.rs b/datafusion/physical-plan/src/repartition/distributor_channels.rs index bad923ce9e82..2e5ef24beac3 100644 --- a/datafusion/physical-plan/src/repartition/distributor_channels.rs +++ b/datafusion/physical-plan/src/repartition/distributor_channels.rs @@ -474,7 +474,7 @@ type SharedGate = Arc; #[cfg(test)] mod tests { - use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::atomic::AtomicBool; use futures::{task::ArcWake, FutureExt}; @@ -829,7 +829,7 @@ mod tests { { let test_waker = Arc::new(TestWaker::default()); let waker = futures::task::waker(Arc::clone(&test_waker)); - let mut cx = std::task::Context::from_waker(&waker); + let mut cx = Context::from_waker(&waker); let res = fut.poll_unpin(&mut cx); (res, test_waker) } diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index b6554f46cf78..bc65b251561b 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -29,25 +29,28 @@ use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; use super::{ DisplayAs, ExecutionPlanProperties, RecordBatchStream, SendableRecordBatchStream, }; -use crate::common::transpose; use crate::hash_utils::create_hashes; use crate::metrics::BaselineMetrics; use crate::repartition::distributor_channels::{ channels, partition_aware_channels, DistributionReceiver, DistributionSender, }; -use crate::sorts::streaming_merge; +use crate::sorts::streaming_merge::StreamingMergeBuilder; use crate::stream::RecordBatchStreamAdapter; use crate::{DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics}; -use arrow::array::{ArrayRef, UInt64Builder}; -use arrow::datatypes::SchemaRef; +use arrow::compute::take_arrays; +use arrow::datatypes::{SchemaRef, UInt32Type}; use arrow::record_batch::RecordBatch; -use datafusion_common::{arrow_datafusion_err, not_impl_err, DataFusionError, Result}; +use arrow_array::{PrimitiveArray, RecordBatchOptions}; +use datafusion_common::utils::transpose; +use datafusion_common::{not_impl_err, DataFusionError, Result}; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::TaskContext; -use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr, PhysicalSortExpr}; +use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr}; +use crate::execution_plan::CardinalityEffect; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use futures::stream::Stream; use futures::{FutureExt, StreamExt, TryStreamExt}; use hashbrown::HashMap; @@ -133,12 +136,12 @@ impl RepartitionExecState { let r_metrics = RepartitionMetrics::new(i, num_output_partitions, &metrics); let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input( - input.clone(), + Arc::clone(&input), i, txs.clone(), partitioning.clone(), r_metrics, - context.clone(), + Arc::clone(&context), )); // In a separate task, wait for each input to be done @@ -261,6 +264,7 @@ impl BatchPartitioner { num_partitions: partitions, hash_buffer, } => { + // Tracking time required for distributing indexes across output partitions let timer = self.timer.timer(); let arrays = exprs @@ -274,37 +278,40 @@ impl BatchPartitioner { create_hashes(&arrays, random_state, hash_buffer)?; let mut indices: Vec<_> = (0..*partitions) - .map(|_| UInt64Builder::with_capacity(batch.num_rows())) + .map(|_| Vec::with_capacity(batch.num_rows())) .collect(); for (index, hash) in hash_buffer.iter().enumerate() { - indices[(*hash % *partitions as u64) as usize] - .append_value(index as u64); + indices[(*hash % *partitions as u64) as usize].push(index as u32); } + // Finished building index-arrays for output partitions + timer.done(); + + // Borrowing partitioner timer to prevent moving `self` to closure + let partitioner_timer = &self.timer; let it = indices .into_iter() .enumerate() - .filter_map(|(partition, mut indices)| { - let indices = indices.finish(); + .filter_map(|(partition, indices)| { + let indices: PrimitiveArray = indices.into(); (!indices.is_empty()).then_some((partition, indices)) }) .map(move |(partition, indices)| { - // Produce batches based on indices - let columns = batch - .columns() - .iter() - .map(|c| { - arrow::compute::take(c.as_ref(), &indices, None) - .map_err(|e| arrow_datafusion_err!(e)) - }) - .collect::>>()?; + // Tracking time required for repartitioned batches construction + let _timer = partitioner_timer.timer(); - let batch = - RecordBatch::try_new(batch.schema(), columns).unwrap(); + // Produce batches based on indices + let columns = take_arrays(batch.columns(), &indices, None)?; - // bind timer so it drops w/ this iterator - let _ = &timer; + let mut options = RecordBatchOptions::new(); + options = options.with_row_count(Some(indices.len())); + let batch = RecordBatch::try_new_with_options( + batch.schema(), + columns, + &options, + ) + .unwrap(); Ok((partition, batch)) }); @@ -373,6 +380,11 @@ impl BatchPartitioner { /// `───────' `───────' ///``` /// +/// # Error Handling +/// +/// If any of the input partitions return an error, the error is propagated to +/// all output partitions and inputs are not polled again. +/// /// # Output Ordering /// /// If more than one stream is being repartitioned, the output will be some @@ -387,12 +399,10 @@ impl BatchPartitioner { /// Paper](https://w6113.github.io/files/papers/volcanoparallelism-89.pdf) /// which uses the term "Exchange" for the concept of repartitioning /// data across threads. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct RepartitionExec { /// Input execution plan input: Arc, - /// Partitioning scheme to use - partitioning: Partitioning, /// Inner state that is initialized when the first output stream is created. state: LazyState, /// Execution metrics @@ -408,7 +418,7 @@ pub struct RepartitionExec { struct RepartitionMetrics { /// Time in nanos to execute child operator and fetch batches fetch_time: metrics::Time, - /// Time in nanos to perform repartitioning + /// Repartitioning elapsed time in nanos repartition_time: metrics::Time, /// Time in nanos for sending resulting batches to channels. /// @@ -427,8 +437,8 @@ impl RepartitionMetrics { MetricBuilder::new(metrics).subset_time("fetch_time", input_partition); // Time in nanos to perform repartitioning - let repart_time = - MetricBuilder::new(metrics).subset_time("repart_time", input_partition); + let repartition_time = + MetricBuilder::new(metrics).subset_time("repartition_time", input_partition); // Time in nanos for sending resulting batches to channels let send_time = (0..num_output_partitions) @@ -443,7 +453,7 @@ impl RepartitionMetrics { Self { fetch_time, - repartition_time: repart_time, + repartition_time, send_time, } } @@ -457,7 +467,7 @@ impl RepartitionExec { /// Partitioning scheme to use pub fn partitioning(&self) -> &Partitioning { - &self.partitioning + &self.cache.partitioning } /// Get preserve_order flag of the RepartitionExecutor @@ -484,7 +494,7 @@ impl DisplayAs for RepartitionExec { f, "{}: partitioning={}, input_partitions={}", self.name(), - self.partitioning, + self.partitioning(), self.input.output_partitioning().partition_count() )?; @@ -493,11 +503,7 @@ impl DisplayAs for RepartitionExec { } if let Some(sort_exprs) = self.sort_exprs() { - write!( - f, - ", sort_exprs={}", - PhysicalSortExpr::format_list(sort_exprs) - )?; + write!(f, ", sort_exprs={}", LexOrdering::from_ref(sort_exprs))?; } Ok(()) } @@ -519,16 +525,18 @@ impl ExecutionPlan for RepartitionExec { &self.cache } - fn children(&self) -> Vec> { - vec![self.input.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.input] } fn with_new_children( self: Arc, mut children: Vec>, ) -> Result> { - let mut repartition = - RepartitionExec::try_new(children.swap_remove(0), self.partitioning.clone())?; + let mut repartition = RepartitionExec::try_new( + children.swap_remove(0), + self.partitioning().clone(), + )?; if self.preserve_order { repartition = repartition.with_preserve_order(); } @@ -536,7 +544,7 @@ impl ExecutionPlan for RepartitionExec { } fn benefits_from_input_partitioning(&self) -> Vec { - vec![matches!(self.partitioning, Partitioning::Hash(_, _))] + vec![matches!(self.partitioning(), Partitioning::Hash(_, _))] } fn maintains_input_order(&self) -> Vec { @@ -556,7 +564,7 @@ impl ExecutionPlan for RepartitionExec { let lazy_state = Arc::clone(&self.state); let input = Arc::clone(&self.input); - let partitioning = self.partitioning.clone(); + let partitioning = self.partitioning().clone(); let metrics = self.metrics.clone(); let preserve_order = self.preserve_order; let name = self.name().to_owned(); @@ -616,7 +624,7 @@ impl ExecutionPlan for RepartitionExec { schema: Arc::clone(&schema_captured), receiver, drop_helper: Arc::clone(&abort_helper), - reservation: reservation.clone(), + reservation: Arc::clone(&reservation), }) as SendableRecordBatchStream }) .collect::>(); @@ -628,15 +636,15 @@ impl ExecutionPlan for RepartitionExec { let merge_reservation = MemoryConsumer::new(format!("{}[Merge {partition}]", name)) .register(context.memory_pool()); - streaming_merge( - input_streams, - schema_captured, - &sort_exprs, - BaselineMetrics::new(&metrics, partition), - context.session_config().batch_size(), - fetch, - merge_reservation, - ) + StreamingMergeBuilder::new() + .with_streams(input_streams) + .with_schema(schema_captured) + .with_expressions(&sort_exprs) + .with_metrics(BaselineMetrics::new(&metrics, partition)) + .with_batch_size(context.session_config().batch_size()) + .with_fetch(fetch) + .with_reservation(merge_reservation) + .build() } else { Ok(Box::pin(RepartitionStream { num_input_partitions, @@ -660,6 +668,10 @@ impl ExecutionPlan for RepartitionExec { fn statistics(&self) -> Result { self.input.statistics() } + + fn cardinality_effect(&self) -> CardinalityEffect { + CardinalityEffect::Equal + } } impl RepartitionExec { @@ -675,7 +687,6 @@ impl RepartitionExec { Self::compute_properties(&input, partitioning.clone(), preserve_order); Ok(RepartitionExec { input, - partitioning, state: Default::default(), metrics: ExecutionPlanMetricsSet::new(), preserve_order, @@ -701,6 +712,11 @@ impl RepartitionExec { if !Self::maintains_input_order_helper(input, preserve_order)[0] { eq_properties.clear_orderings(); } + // When there are more than one input partitions, they will be fused at the output. + // Therefore, remove per partition constants. + if input.output_partitioning().partition_count() > 1 { + eq_properties.clear_per_partition_constants(); + } eq_properties } @@ -861,7 +877,7 @@ impl RepartitionExec { for (_, tx) in txs { // wrap it because need to send error to all output partitions - let err = Err(DataFusionError::External(Box::new(e.clone()))); + let err = Err(DataFusionError::External(Box::new(Arc::clone(&e)))); tx.send(Some(err)).await.ok(); } } @@ -940,7 +956,7 @@ impl Stream for RepartitionStream { impl RecordBatchStream for RepartitionStream { /// Get the schema fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -990,7 +1006,7 @@ impl Stream for PerPartitionStream { impl RecordBatchStream for PerPartitionStream { /// Get the schema fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -1010,11 +1026,11 @@ mod tests { {collect, expressions::col, memory::MemoryExec}, }; - use arrow::array::{StringArray, UInt32Array}; + use arrow::array::{ArrayRef, StringArray, UInt32Array}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::cast::as_string_array; - use datafusion_common::{assert_batches_sorted_eq, exec_err}; - use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use datafusion_common::{arrow_datafusion_err, assert_batches_sorted_eq, exec_err}; + use datafusion_execution::runtime_env::RuntimeEnvBuilder; use tokio::task::JoinSet; @@ -1112,14 +1128,14 @@ mod tests { ) -> Result>> { let task_ctx = Arc::new(TaskContext::default()); // create physical plan - let exec = MemoryExec::try_new(&input_partitions, schema.clone(), None)?; + let exec = MemoryExec::try_new(&input_partitions, Arc::clone(schema), None)?; let exec = RepartitionExec::try_new(Arc::new(exec), partitioning)?; // execute and collect results let mut output_partitions = vec![]; - for i in 0..exec.partitioning.partition_count() { + for i in 0..exec.partitioning().partition_count() { // execute this *output* partition and collect all batches - let mut stream = exec.execute(i, task_ctx.clone())?; + let mut stream = exec.execute(i, Arc::clone(&task_ctx))?; let mut batches = vec![]; while let Some(result) = stream.next().await { batches.push(result?); @@ -1296,14 +1312,18 @@ mod tests { let input = Arc::new(make_barrier_exec()); // partition into two output streams - let exec = RepartitionExec::try_new(input.clone(), partitioning).unwrap(); + let exec = RepartitionExec::try_new( + Arc::clone(&input) as Arc, + partitioning, + ) + .unwrap(); - let output_stream0 = exec.execute(0, task_ctx.clone()).unwrap(); - let output_stream1 = exec.execute(1, task_ctx.clone()).unwrap(); + let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap(); + let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap(); // now, purposely drop output stream 0 // *before* any outputs are produced - std::mem::drop(output_stream0); + drop(output_stream0); // Now, start sending input let mut background_task = JoinSet::new(); @@ -1330,8 +1350,8 @@ mod tests { #[tokio::test] // As the hash results might be different on different platforms or - // wiht different compilers, we will compare the same execution with - // and without droping the output stream. + // with different compilers, we will compare the same execution with + // and without dropping the output stream. async fn hash_repartition_with_dropping_output_stream() { let task_ctx = Arc::new(TaskContext::default()); let partitioning = Partitioning::Hash( @@ -1342,10 +1362,14 @@ mod tests { 2, ); - // We first collect the results without droping the output stream. + // We first collect the results without dropping the output stream. let input = Arc::new(make_barrier_exec()); - let exec = RepartitionExec::try_new(input.clone(), partitioning.clone()).unwrap(); - let output_stream1 = exec.execute(1, task_ctx.clone()).unwrap(); + let exec = RepartitionExec::try_new( + Arc::clone(&input) as Arc, + partitioning.clone(), + ) + .unwrap(); + let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap(); let mut background_task = JoinSet::new(); background_task.spawn(async move { input.wait().await; @@ -1365,12 +1389,16 @@ mod tests { // Now do the same but dropping the stream before waiting for the barrier let input = Arc::new(make_barrier_exec()); - let exec = RepartitionExec::try_new(input.clone(), partitioning).unwrap(); - let output_stream0 = exec.execute(0, task_ctx.clone()).unwrap(); - let output_stream1 = exec.execute(1, task_ctx.clone()).unwrap(); + let exec = RepartitionExec::try_new( + Arc::clone(&input) as Arc, + partitioning, + ) + .unwrap(); + let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap(); + let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap(); // now, purposely drop output stream 0 // *before* any outputs are produced - std::mem::drop(output_stream0); + drop(output_stream0); let mut background_task = JoinSet::new(); background_task.spawn(async move { input.wait().await; @@ -1466,9 +1494,9 @@ mod tests { let schema = batch.schema(); let input = MockExec::new(vec![Ok(batch)], schema); let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap(); - let output_stream0 = exec.execute(0, task_ctx.clone()).unwrap(); + let output_stream0 = exec.execute(0, Arc::clone(&task_ctx)).unwrap(); let batch0 = crate::common::collect(output_stream0).await.unwrap(); - let output_stream1 = exec.execute(1, task_ctx.clone()).unwrap(); + let output_stream1 = exec.execute(1, Arc::clone(&task_ctx)).unwrap(); let batch1 = crate::common::collect(output_stream1).await.unwrap(); assert!(batch0.is_empty() || batch1.is_empty()); Ok(()) @@ -1483,20 +1511,20 @@ mod tests { let partitioning = Partitioning::RoundRobinBatch(4); // setup up context - let runtime = Arc::new( - RuntimeEnv::new(RuntimeConfig::default().with_memory_limit(1, 1.0)).unwrap(), - ); + let runtime = RuntimeEnvBuilder::default() + .with_memory_limit(1, 1.0) + .build_arc()?; let task_ctx = TaskContext::default().with_runtime(runtime); let task_ctx = Arc::new(task_ctx); // create physical plan - let exec = MemoryExec::try_new(&input_partitions, schema.clone(), None)?; + let exec = MemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?; let exec = RepartitionExec::try_new(Arc::new(exec), partitioning)?; // pull partitions - for i in 0..exec.partitioning.partition_count() { - let mut stream = exec.execute(i, task_ctx.clone())?; + for i in 0..exec.partitioning().partition_count() { + let mut stream = exec.execute(i, Arc::clone(&task_ctx))?; let err = arrow_datafusion_err!(stream.next().await.unwrap().unwrap_err().into()); let err = err.find_root(); @@ -1530,10 +1558,10 @@ mod tests { mod test { use arrow_schema::{DataType, Field, Schema, SortOptions}; - use datafusion_physical_expr::expressions::col; - use crate::memory::MemoryExec; use crate::union::UnionExec; + use datafusion_physical_expr::expressions::col; + use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use super::*; @@ -1628,26 +1656,27 @@ mod test { Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)])) } - fn sort_exprs(schema: &Schema) -> Vec { + fn sort_exprs(schema: &Schema) -> LexOrdering { let options = SortOptions::default(); - vec![PhysicalSortExpr { + LexOrdering::new(vec![PhysicalSortExpr { expr: col("c0", schema).unwrap(), options, - }] + }]) } fn memory_exec(schema: &SchemaRef) -> Arc { - Arc::new(MemoryExec::try_new(&[vec![]], schema.clone(), None).unwrap()) + Arc::new(MemoryExec::try_new(&[vec![]], Arc::clone(schema), None).unwrap()) } fn sorted_memory_exec( schema: &SchemaRef, - sort_exprs: Vec, + sort_exprs: LexOrdering, ) -> Arc { Arc::new( - MemoryExec::try_new(&[vec![]], schema.clone(), None) + MemoryExec::try_new(&[vec![]], Arc::clone(schema), None) .unwrap() - .with_sort_information(vec![sort_exprs]), + .try_with_sort_information(vec![sort_exprs]) + .unwrap(), ) } } diff --git a/datafusion/physical-plan/src/sorts/builder.rs b/datafusion/physical-plan/src/sorts/builder.rs index 3527d5738223..d32c60697ec8 100644 --- a/datafusion/physical-plan/src/sorts/builder.rs +++ b/datafusion/physical-plan/src/sorts/builder.rs @@ -20,6 +20,7 @@ use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::Result; use datafusion_execution::memory_pool::MemoryReservation; +use std::sync::Arc; #[derive(Debug, Copy, Clone, Default)] struct BatchCursor { @@ -145,6 +146,9 @@ impl BatchBuilder { retain }); - Ok(Some(RecordBatch::try_new(self.schema.clone(), columns)?)) + Ok(Some(RecordBatch::try_new( + Arc::clone(&self.schema), + columns, + )?)) } } diff --git a/datafusion/physical-plan/src/sorts/cursor.rs b/datafusion/physical-plan/src/sorts/cursor.rs index df90c97faf68..133d736c1467 100644 --- a/datafusion/physical-plan/src/sorts/cursor.rs +++ b/datafusion/physical-plan/src/sorts/cursor.rs @@ -38,6 +38,10 @@ pub trait CursorValues { /// Returns true if `l[l_idx] == r[r_idx]` fn eq(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> bool; + /// Returns true if `row[idx] == row[idx - 1]` + /// Given `idx` should be greater than 0 + fn eq_to_previous(cursor: &Self, idx: usize) -> bool; + /// Returns comparison of `l[l_idx]` and `r[r_idx]` fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering; } @@ -95,6 +99,16 @@ impl Cursor { self.offset += 1; t } + + pub fn is_eq_to_prev_one(&self, prev_cursor: Option<&Cursor>) -> bool { + if self.offset > 0 { + self.is_eq_to_prev_row() + } else if let Some(prev_cursor) = prev_cursor { + self.is_eq_to_prev_row_in_prev_batch(prev_cursor) + } else { + false + } + } } impl PartialEq for Cursor { @@ -103,6 +117,22 @@ impl PartialEq for Cursor { } } +impl Cursor { + fn is_eq_to_prev_row(&self) -> bool { + T::eq_to_previous(&self.values, self.offset) + } + + fn is_eq_to_prev_row_in_prev_batch(&self, other: &Self) -> bool { + assert_eq!(self.offset, 0); + T::eq( + &self.values, + self.offset, + &other.values, + other.values.len() - 1, + ) + } +} + impl Eq for Cursor {} impl PartialOrd for Cursor { @@ -156,6 +186,11 @@ impl CursorValues for RowValues { l.rows.row(l_idx) == r.rows.row(r_idx) } + fn eq_to_previous(cursor: &Self, idx: usize) -> bool { + assert!(idx > 0); + cursor.rows.row(idx) == cursor.rows.row(idx - 1) + } + fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering { l.rows.row(l_idx).cmp(&r.rows.row(r_idx)) } @@ -188,6 +223,11 @@ impl CursorValues for PrimitiveValues { l.0[l_idx].is_eq(r.0[r_idx]) } + fn eq_to_previous(cursor: &Self, idx: usize) -> bool { + assert!(idx > 0); + cursor.0[idx].is_eq(cursor.0[idx - 1]) + } + fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering { l.0[l_idx].compare(r.0[r_idx]) } @@ -219,6 +259,11 @@ impl CursorValues for ByteArrayValues { l.value(l_idx) == r.value(r_idx) } + fn eq_to_previous(cursor: &Self, idx: usize) -> bool { + assert!(idx > 0); + cursor.value(idx) == cursor.value(idx - 1) + } + fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering { l.value(l_idx).cmp(r.value(r_idx)) } @@ -284,6 +329,15 @@ impl CursorValues for ArrayValues { } } + fn eq_to_previous(cursor: &Self, idx: usize) -> bool { + assert!(idx > 0); + match (cursor.is_null(idx), cursor.is_null(idx - 1)) { + (true, true) => true, + (false, false) => T::eq(&cursor.values, idx, &cursor.values, idx - 1), + _ => false, + } + } + fn compare(l: &Self, l_idx: usize, r: &Self, r_idx: usize) -> Ordering { match (l.is_null(l_idx), r.is_null(r_idx)) { (true, true) => Ordering::Equal, diff --git a/datafusion/physical-plan/src/sorts/merge.rs b/datafusion/physical-plan/src/sorts/merge.rs index 422ff3aebdb3..458c1c29c0cf 100644 --- a/datafusion/physical-plan/src/sorts/merge.rs +++ b/datafusion/physical-plan/src/sorts/merge.rs @@ -18,22 +18,28 @@ //! Merge that deals with an arbitrary size of streaming inputs. //! This is an order-preserving merge. +use std::collections::VecDeque; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{ready, Context, Poll}; + use crate::metrics::BaselineMetrics; use crate::sorts::builder::BatchBuilder; use crate::sorts::cursor::{Cursor, CursorValues}; use crate::sorts::stream::PartitionedStream; use crate::RecordBatchStream; + use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::Result; use datafusion_execution::memory_pool::MemoryReservation; + use futures::Stream; -use std::pin::Pin; -use std::task::{ready, Context, Poll}; /// A fallible [`PartitionedStream`] of [`Cursor`] and [`RecordBatch`] type CursorStream = Box>>; +/// Merges a stream of sorted cursors and record batches into a single sorted stream #[derive(Debug)] pub(crate) struct SortPreservingMergeStream { in_progress: BatchBuilder, @@ -85,17 +91,57 @@ pub(crate) struct SortPreservingMergeStream { /// been updated loser_tree_adjusted: bool, - /// target batch size + /// Target batch size batch_size: usize, /// Cursors for each input partition. `None` means the input is exhausted cursors: Vec>>, + /// Configuration parameter to enable round-robin selection of tied winners of loser tree. + /// + /// To address the issue of unbalanced polling between partitions due to tie-breakers being based + /// on partition index, especially in cases of low cardinality, we are making changes to the winner + /// selection mechanism. Previously, partitions with smaller indices were consistently chosen as the winners, + /// leading to an uneven distribution of polling. This caused upstream operator buffers for the other partitions + /// to grow excessively, as they continued receiving data without consuming it. + /// + /// For example, an upstream operator like a repartition execution would keep sending data to certain partitions, + /// but those partitions wouldn't consume the data if they weren't selected as winners. This resulted in inefficient buffer usage. + /// + /// To resolve this, we are modifying the tie-breaking logic. Instead of always choosing the partition with the smallest index, + /// we now select the partition that has the fewest poll counts for the same value. + /// This ensures that multiple partitions with the same value are chosen equally, distributing the polling load in a round-robin fashion. + /// This approach balances the workload more effectively across partitions and avoids excessive buffer growth. + enable_round_robin_tie_breaker: bool, + + /// Flag indicating whether we are in the mode of round-robin + /// tie breaker for the loser tree winners. + round_robin_tie_breaker_mode: bool, + + /// Total number of polls returning the same value, as per partition. + /// We select the one that has less poll counts for tie-breaker in loser tree. + num_of_polled_with_same_value: Vec, + + /// To keep track of reset counts + poll_reset_epochs: Vec, + + /// Current reset count + current_reset_epoch: usize, + + /// Stores the previous value of each partitions for tracking the poll counts on the same value. + prev_cursors: Vec>>, + /// Optional number of rows to fetch fetch: Option, /// number of rows produced produced: usize, + + /// This queue contains partition indices in order. When a partition is polled and returns `Poll::Ready`, + /// it is removed from the vector. If a partition returns `Poll::Pending`, it is moved to the end of the + /// vector to ensure the next iteration starts with a different partition, preventing the same partition + /// from being continuously polled. + uninitiated_partitions: VecDeque, } impl SortPreservingMergeStream { @@ -106,6 +152,7 @@ impl SortPreservingMergeStream { batch_size: usize, fetch: Option, reservation: MemoryReservation, + enable_round_robin_tie_breaker: bool, ) -> Self { let stream_count = streams.partitions(); @@ -115,11 +162,18 @@ impl SortPreservingMergeStream { metrics, aborted: false, cursors: (0..stream_count).map(|_| None).collect(), + prev_cursors: (0..stream_count).map(|_| None).collect(), + round_robin_tie_breaker_mode: false, + num_of_polled_with_same_value: vec![0; stream_count], + current_reset_epoch: 0, + poll_reset_epochs: vec![0; stream_count], loser_tree: vec![], loser_tree_adjusted: false, batch_size, fetch, produced: 0, + uninitiated_partitions: (0..stream_count).collect(), + enable_round_robin_tie_breaker, } } @@ -153,14 +207,36 @@ impl SortPreservingMergeStream { if self.aborted { return Poll::Ready(None); } - // try to initialize the loser tree + // Once all partitions have set their corresponding cursors for the loser tree, + // we skip the following block. Until then, this function may be called multiple + // times and can return Poll::Pending if any partition returns Poll::Pending. if self.loser_tree.is_empty() { - // Ensure all non-exhausted streams have a cursor from which - // rows can be pulled - for i in 0..self.streams.partitions() { - if let Err(e) = ready!(self.maybe_poll_stream(cx, i)) { - self.aborted = true; - return Poll::Ready(Some(Err(e))); + let remaining_partitions = self.uninitiated_partitions.clone(); + for i in remaining_partitions { + match self.maybe_poll_stream(cx, i) { + Poll::Ready(Err(e)) => { + self.aborted = true; + return Poll::Ready(Some(Err(e))); + } + Poll::Pending => { + // If a partition returns Poll::Pending, to avoid continuously polling it + // and potentially increasing upstream buffer sizes, we move it to the + // back of the polling queue. + if let Some(front) = self.uninitiated_partitions.pop_front() { + // This pop_front can never return `None`. + self.uninitiated_partitions.push_back(front); + } + // This function could remain in a pending state, so we manually wake it here. + // However, this approach can be investigated further to find a more natural way + // to avoid disrupting the runtime scheduler. + cx.waker().wake_by_ref(); + return Poll::Pending; + } + _ => { + // If the polling result is Poll::Ready(Some(batch)) or Poll::Ready(None), + // we remove this partition from the queue so it is not polled again. + self.uninitiated_partitions.retain(|idx| *idx != i); + } } } self.init_loser_tree(); @@ -183,7 +259,7 @@ impl SortPreservingMergeStream { } let stream_idx = self.loser_tree[0]; - if self.advance(stream_idx) { + if self.advance_cursors(stream_idx) { self.loser_tree_adjusted = false; self.in_progress.push_row(stream_idx); @@ -201,27 +277,53 @@ impl SortPreservingMergeStream { } } + /// For the given partition, updates the poll count. If the current value is the same + /// of the previous value, it increases the count by 1; otherwise, it is reset as 0. + fn update_poll_count_on_the_same_value(&mut self, partition_idx: usize) { + let cursor = &mut self.cursors[partition_idx]; + + // Check if the current partition's poll count is logically "reset" + if self.poll_reset_epochs[partition_idx] != self.current_reset_epoch { + self.poll_reset_epochs[partition_idx] = self.current_reset_epoch; + self.num_of_polled_with_same_value[partition_idx] = 0; + } + + if let Some(c) = cursor.as_mut() { + // Compare with the last row in the previous batch + let prev_cursor = &self.prev_cursors[partition_idx]; + if c.is_eq_to_prev_one(prev_cursor.as_ref()) { + self.num_of_polled_with_same_value[partition_idx] += 1; + } else { + self.num_of_polled_with_same_value[partition_idx] = 0; + } + } + } + fn fetch_reached(&mut self) -> bool { self.fetch .map(|fetch| self.produced + self.in_progress.len() >= fetch) .unwrap_or(false) } - fn advance(&mut self, stream_idx: usize) -> bool { - let slot = &mut self.cursors[stream_idx]; - match slot.as_mut() { - Some(c) => { - c.advance(); - if c.is_finished() { - *slot = None; - } - true + /// Advances the actual cursor. If it reaches its end, update the + /// previous cursor with it. + /// + /// If the given partition is not exhausted, the function returns `true`. + fn advance_cursors(&mut self, stream_idx: usize) -> bool { + if let Some(cursor) = &mut self.cursors[stream_idx] { + let _ = cursor.advance(); + if cursor.is_finished() { + // Take the current cursor, leaving `None` in its place + self.prev_cursors[stream_idx] = self.cursors[stream_idx].take(); } - None => false, + true + } else { + false } } - /// Returns `true` if the cursor at index `a` is greater than at index `b` + /// Returns `true` if the cursor at index `a` is greater than at index `b`. + /// In an equality case, it compares the partition indices given. #[inline] fn is_gt(&self, a: usize, b: usize) -> bool { match (&self.cursors[a], &self.cursors[b]) { @@ -231,6 +333,19 @@ impl SortPreservingMergeStream { } } + #[inline] + fn is_poll_count_gt(&self, a: usize, b: usize) -> bool { + let poll_a = self.num_of_polled_with_same_value[a]; + let poll_b = self.num_of_polled_with_same_value[b]; + poll_a.cmp(&poll_b).then_with(|| a.cmp(&b)).is_gt() + } + + #[inline] + fn update_winner(&mut self, cmp_node: usize, winner: &mut usize, challenger: usize) { + self.loser_tree[cmp_node] = *winner; + *winner = challenger; + } + /// Find the leaf node index in the loser tree for the given cursor index /// /// Note that this is not necessarily a leaf node in the tree, but it can @@ -292,16 +407,101 @@ impl SortPreservingMergeStream { self.loser_tree_adjusted = true; } - /// Attempts to update the loser tree, following winner replacement, if possible + /// Resets the poll count by incrementing the reset epoch. + fn reset_poll_counts(&mut self) { + self.current_reset_epoch += 1; + } + + /// Handles tie-breaking logic during the adjustment of the loser tree. + /// + /// When comparing elements from multiple partitions in the `update_loser_tree` process, a tie can occur + /// between the current winner and a challenger. This function is invoked when such a tie needs to be + /// resolved according to the round-robin tie-breaker mode. + /// + /// If round-robin tie-breaking is not active, it is enabled, and the poll counts for all elements are reset. + /// The function then compares the poll counts of the current winner and the challenger: + /// - If the winner remains at the top after the final comparison, it increments the winner's poll count. + /// - If the challenger has a lower poll count than the current winner, the challenger becomes the new winner. + /// - If the poll counts are equal but the challenger's index is smaller, the challenger is preferred. + /// + /// # Parameters + /// - `cmp_node`: The index of the comparison node in the loser tree where the tie-breaking is happening. + /// - `winner`: A mutable reference to the current winner, which may be updated based on the tie-breaking result. + /// - `challenger`: The index of the challenger being compared against the winner. + /// + /// This function ensures fair selection among elements with equal values when tie-breaking mode is enabled, + /// aiming to balance the polling across different partitions. + #[inline] + fn handle_tie(&mut self, cmp_node: usize, winner: &mut usize, challenger: usize) { + if !self.round_robin_tie_breaker_mode { + self.round_robin_tie_breaker_mode = true; + // Reset poll count for tie-breaker + self.reset_poll_counts(); + } + // Update poll count if the winner survives in the final match + if *winner == self.loser_tree[0] { + self.update_poll_count_on_the_same_value(*winner); + if self.is_poll_count_gt(*winner, challenger) { + self.update_winner(cmp_node, winner, challenger); + } + } else if challenger < *winner { + // If the winner doesn’t survive in the final match, it indicates that the original winner + // has moved up in value, so the challenger now becomes the new winner. + // This also means that we’re in a new round of the tie breaker, + // and the polls count is outdated (though not yet cleaned up). + // + // By the time we reach this code, both the new winner and the current challenger + // have the same value, and neither has an updated polls count. + // Therefore, we simply select the one with the smaller index. + self.update_winner(cmp_node, winner, challenger); + } + } + + /// Updates the loser tree to reflect the new winner after the previous winner is consumed. + /// This function adjusts the tree by comparing the current winner with challengers from + /// other partitions. + /// + /// If `enable_round_robin_tie_breaker` is true and a tie occurs at the final level, the + /// tie-breaker logic will be applied to ensure fair selection among equal elements. fn update_loser_tree(&mut self) { + // Start with the current winner let mut winner = self.loser_tree[0]; - // Replace overall winner by walking tree of losers + + // Find the leaf node index of the winner in the loser tree. let mut cmp_node = self.lt_leaf_node_index(winner); + + // Traverse up the tree to adjust comparisons until reaching the root. while cmp_node != 0 { let challenger = self.loser_tree[cmp_node]; - if self.is_gt(winner, challenger) { - self.loser_tree[cmp_node] = winner; - winner = challenger; + // If round-robin tie-breaker is enabled and we're at the final comparison (cmp_node == 1) + if self.enable_round_robin_tie_breaker && cmp_node == 1 { + match (&self.cursors[winner], &self.cursors[challenger]) { + (Some(ac), Some(bc)) => { + let ord = ac.cmp(bc); + if ord.is_eq() { + self.handle_tie(cmp_node, &mut winner, challenger); + } else { + // Ends of tie breaker + self.round_robin_tie_breaker_mode = false; + if ord.is_gt() { + self.update_winner(cmp_node, &mut winner, challenger); + } + } + } + (None, _) => { + // Challenger wins, update winner + // Ends of tie breaker + self.round_robin_tie_breaker_mode = false; + self.update_winner(cmp_node, &mut winner, challenger); + } + (_, None) => { + // Winner wins again + // Ends of tie breaker + self.round_robin_tie_breaker_mode = false; + } + } + } else if self.is_gt(winner, challenger) { + self.update_winner(cmp_node, &mut winner, challenger); } cmp_node = self.lt_parent_node_index(cmp_node); } @@ -324,6 +524,6 @@ impl Stream for SortPreservingMergeStream { impl RecordBatchStream for SortPreservingMergeStream { fn schema(&self) -> SchemaRef { - self.in_progress.schema().clone() + Arc::clone(self.in_progress.schema()) } } diff --git a/datafusion/physical-plan/src/sorts/mod.rs b/datafusion/physical-plan/src/sorts/mod.rs index 7c084761fdc3..ab5df37ed327 100644 --- a/datafusion/physical-plan/src/sorts/mod.rs +++ b/datafusion/physical-plan/src/sorts/mod.rs @@ -28,4 +28,3 @@ mod stream; pub mod streaming_merge; pub use index::RowIndex; -pub(crate) use streaming_merge::streaming_merge; diff --git a/datafusion/physical-plan/src/sorts/partial_sort.rs b/datafusion/physical-plan/src/sorts/partial_sort.rs index d24bc5a670e5..8f853464c9bd 100644 --- a/datafusion/physical-plan/src/sorts/partial_sort.rs +++ b/datafusion/physical-plan/src/sorts/partial_sort.rs @@ -57,7 +57,6 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use crate::expressions::PhysicalSortExpr; use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use crate::sorts::sort::sort_batch; use crate::{ @@ -73,6 +72,7 @@ use datafusion_common::Result; use datafusion_execution::{RecordBatchStream, TaskContext}; use datafusion_physical_expr::LexOrdering; +use datafusion_physical_expr_common::sort_expr::LexOrderingRef; use futures::{ready, Stream, StreamExt}; use log::trace; @@ -82,7 +82,7 @@ pub struct PartialSortExec { /// Input schema pub(crate) input: Arc, /// Sort expressions - expr: Vec, + expr: LexOrdering, /// Length of continuous matching columns of input that satisfy /// the required ordering for the sort common_prefix_length: usize, @@ -100,11 +100,11 @@ pub struct PartialSortExec { impl PartialSortExec { /// Create a new partial sort execution plan pub fn new( - expr: Vec, + expr: LexOrdering, input: Arc, common_prefix_length: usize, ) -> Self { - assert!(common_prefix_length > 0); + debug_assert!(common_prefix_length > 0); let preserve_partitioning = false; let cache = Self::compute_properties(&input, expr.clone(), preserve_partitioning); Self { @@ -159,8 +159,8 @@ impl PartialSortExec { } /// Sort expressions - pub fn expr(&self) -> &[PhysicalSortExpr] { - &self.expr + pub fn expr(&self) -> LexOrderingRef { + self.expr.as_ref() } /// If `Some(fetch)`, limits output to only the first "fetch" items @@ -212,13 +212,12 @@ impl DisplayAs for PartialSortExec { ) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - let expr = PhysicalSortExpr::format_list(&self.expr); let common_prefix_length = self.common_prefix_length; match self.fetch { Some(fetch) => { - write!(f, "PartialSortExec: TopK(fetch={fetch}), expr=[{expr}], common_prefix_length=[{common_prefix_length}]", ) + write!(f, "PartialSortExec: TopK(fetch={fetch}), expr=[{}], common_prefix_length=[{common_prefix_length}]", self.expr) } - None => write!(f, "PartialSortExec: expr=[{expr}], common_prefix_length=[{common_prefix_length}]"), + None => write!(f, "PartialSortExec: expr=[{}], common_prefix_length=[{common_prefix_length}]", self.expr), } } } @@ -238,6 +237,10 @@ impl ExecutionPlan for PartialSortExec { &self.cache } + fn fetch(&self) -> Option { + self.fetch + } + fn required_input_distribution(&self) -> Vec { if self.preserve_partitioning { vec![Distribution::UnspecifiedDistribution] @@ -250,8 +253,8 @@ impl ExecutionPlan for PartialSortExec { vec![false] } - fn children(&self) -> Vec> { - vec![self.input.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.input] } fn with_new_children( @@ -260,7 +263,7 @@ impl ExecutionPlan for PartialSortExec { ) -> Result> { let new_partial_sort = PartialSortExec::new( self.expr.clone(), - children[0].clone(), + Arc::clone(&children[0]), self.common_prefix_length, ) .with_fetch(self.fetch) @@ -276,7 +279,7 @@ impl ExecutionPlan for PartialSortExec { ) -> Result { trace!("Start PartialSortExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); - let input = self.input.execute(partition, context.clone())?; + let input = self.input.execute(partition, Arc::clone(&context))?; trace!( "End PartialSortExec's input.execute for partition: {}", @@ -285,7 +288,7 @@ impl ExecutionPlan for PartialSortExec { // Make sure common prefix length is larger than 0 // Otherwise, we should use SortExec. - assert!(self.common_prefix_length > 0); + debug_assert!(self.common_prefix_length > 0); Ok(Box::pin(PartialSortStream { input, @@ -311,7 +314,7 @@ struct PartialSortStream { /// The input plan input: SendableRecordBatchStream, /// Sort expressions - expr: Vec, + expr: LexOrdering, /// Length of prefix common to input ordering and required ordering of plan /// should be more than 0 otherwise PartialSort is not applicable common_prefix_length: usize, @@ -390,7 +393,7 @@ impl PartialSortStream { fn sort_in_mem_batches(self: &mut Pin<&mut Self>) -> Result { let input_batch = concat_batches(&self.schema(), &self.in_mem_batches)?; self.in_mem_batches.clear(); - let result = sort_batch(&input_batch, &self.expr, self.fetch)?; + let result = sort_batch(&input_batch, self.expr.as_ref(), self.fetch)?; if let Some(remaining_fetch) = self.fetch { // remaining_fetch - result.num_rows() is always be >= 0 // because result length of sort_batch with limit cannot be @@ -444,6 +447,7 @@ mod tests { use crate::collect; use crate::expressions::col; + use crate::expressions::PhysicalSortExpr; use crate::memory::MemoryExec; use crate::sorts::sort::SortExec; use crate::test; @@ -471,7 +475,7 @@ mod tests { }; let partial_sort_exec = Arc::new(PartialSortExec::new( - vec![ + LexOrdering::new(vec![ PhysicalSortExpr { expr: col("a", &schema)?, options: option_asc, @@ -484,12 +488,12 @@ mod tests { expr: col("c", &schema)?, options: option_asc, }, - ], - source.clone(), + ]), + Arc::clone(&source), 2, )) as Arc; - let result = collect(partial_sort_exec, task_ctx.clone()).await?; + let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?; let expected_after_sort = [ "+---+---+---+", @@ -535,7 +539,7 @@ mod tests { for common_prefix_length in [1, 2] { let partial_sort_exec = Arc::new( PartialSortExec::new( - vec![ + LexOrdering::new(vec![ PhysicalSortExpr { expr: col("a", &schema)?, options: option_asc, @@ -548,14 +552,14 @@ mod tests { expr: col("c", &schema)?, options: option_asc, }, - ], - source.clone(), + ]), + Arc::clone(&source), common_prefix_length, ) .with_fetch(Some(4)), ) as Arc; - let result = collect(partial_sort_exec, task_ctx.clone()).await?; + let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?; let expected_after_sort = [ "+---+---+---+", @@ -607,7 +611,7 @@ mod tests { [(1, &source_tables[0]), (2, &source_tables[1])] { let partial_sort_exec = Arc::new(PartialSortExec::new( - vec![ + LexOrdering::new(vec![ PhysicalSortExpr { expr: col("a", &schema)?, options: option_asc, @@ -620,12 +624,12 @@ mod tests { expr: col("c", &schema)?, options: option_asc, }, - ], - source.clone(), + ]), + Arc::clone(source), common_prefix_length, )); - let result = collect(partial_sort_exec, task_ctx.clone()).await?; + let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?; assert_eq!(2, result.len()); assert_eq!( task_ctx.runtime_env().memory_pool.reserved(), @@ -676,7 +680,7 @@ mod tests { Arc::new( MemoryExec::try_new( &[vec![batch1, batch2, batch3, batch4]], - schema.clone(), + Arc::clone(&schema), None, ) .unwrap(), @@ -697,7 +701,7 @@ mod tests { }; let schema = mem_exec.schema(); let partial_sort_executor = PartialSortExec::new( - vec![ + LexOrdering::new(vec![ PhysicalSortExpr { expr: col("a", &schema)?, options: option_asc, @@ -710,8 +714,8 @@ mod tests { expr: col("c", &schema)?, options: option_asc, }, - ], - mem_exec.clone(), + ]), + Arc::clone(&mem_exec), 1, ); let partial_sort_exec = @@ -720,7 +724,7 @@ mod tests { partial_sort_executor.expr, partial_sort_executor.input, )) as Arc; - let result = collect(partial_sort_exec, task_ctx.clone()).await?; + let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?; assert_eq!( result.iter().map(|r| r.num_rows()).collect_vec(), [0, 125, 125, 0, 150] @@ -732,7 +736,7 @@ mod tests { "The sort should have returned all memory used back to the memory manager" ); let partial_sort_result = concat_batches(&schema, &result).unwrap(); - let sort_result = collect(sort_exec, task_ctx.clone()).await?; + let sort_result = collect(sort_exec, Arc::clone(&task_ctx)).await?; assert_eq!(sort_result[0], partial_sort_result); Ok(()) @@ -758,7 +762,7 @@ mod tests { (Some(250), vec![0, 125, 125]), ] { let partial_sort_executor = PartialSortExec::new( - vec![ + LexOrdering::new(vec![ PhysicalSortExpr { expr: col("a", &schema)?, options: option_asc, @@ -771,8 +775,8 @@ mod tests { expr: col("c", &schema)?, options: option_asc, }, - ], - mem_exec.clone(), + ]), + Arc::clone(&mem_exec), 1, ) .with_fetch(fetch_size); @@ -783,7 +787,7 @@ mod tests { SortExec::new(partial_sort_executor.expr, partial_sort_executor.input) .with_fetch(fetch_size), ) as Arc; - let result = collect(partial_sort_exec, task_ctx.clone()).await?; + let result = collect(partial_sort_exec, Arc::clone(&task_ctx)).await?; assert_eq!( result.iter().map(|r| r.num_rows()).collect_vec(), expected_batch_num_rows @@ -795,7 +799,7 @@ mod tests { "The sort should have returned all memory used back to the memory manager" ); let partial_sort_result = concat_batches(&schema, &result)?; - let sort_result = collect(sort_exec, task_ctx.clone()).await?; + let sort_result = collect(sort_exec, Arc::clone(&task_ctx)).await?; assert_eq!(sort_result[0], partial_sort_result); } @@ -822,14 +826,18 @@ mod tests { let data: ArrayRef = Arc::new(vec![1, 1, 2].into_iter().map(Some).collect::()); - let batch = RecordBatch::try_new(schema.clone(), vec![data])?; - let input = Arc::new(MemoryExec::try_new(&[vec![batch]], schema.clone(), None)?); + let batch = RecordBatch::try_new(Arc::clone(&schema), vec![data])?; + let input = Arc::new(MemoryExec::try_new( + &[vec![batch]], + Arc::clone(&schema), + None, + )?); let partial_sort_exec = Arc::new(PartialSortExec::new( - vec![PhysicalSortExpr { + LexOrdering::new(vec![PhysicalSortExpr { expr: col("field_name", &schema)?, options: SortOptions::default(), - }], + }]), input, 1, )); @@ -837,13 +845,13 @@ mod tests { let result: Vec = collect(partial_sort_exec, task_ctx).await?; let expected_batch = vec![ RecordBatch::try_new( - schema.clone(), + Arc::clone(&schema), vec![Arc::new( vec![1, 1].into_iter().map(Some).collect::(), )], )?, RecordBatch::try_new( - schema.clone(), + Arc::clone(&schema), vec![Arc::new( vec![2].into_iter().map(Some).collect::(), )], @@ -879,7 +887,7 @@ mod tests { // define data. let batch = RecordBatch::try_new( - schema.clone(), + Arc::clone(&schema), vec![ Arc::new(Float32Array::from(vec![ Some(1.0_f32), @@ -915,7 +923,7 @@ mod tests { )?; let partial_sort_exec = Arc::new(PartialSortExec::new( - vec![ + LexOrdering::new(vec![ PhysicalSortExpr { expr: col("a", &schema)?, options: option_asc, @@ -928,7 +936,7 @@ mod tests { expr: col("c", &schema)?, options: option_desc, }, - ], + ]), Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None)?), 2, )); @@ -961,8 +969,11 @@ mod tests { *partial_sort_exec.schema().field(2).data_type() ); - let result: Vec = - collect(partial_sort_exec.clone(), task_ctx).await?; + let result: Vec = collect( + Arc::clone(&partial_sort_exec) as Arc, + task_ctx, + ) + .await?; assert_batches_eq!(expected, &result); assert_eq!(result.len(), 2); let metrics = partial_sort_exec.metrics().unwrap(); @@ -989,15 +1000,15 @@ mod tests { let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let refs = blocking_exec.refs(); let sort_exec = Arc::new(PartialSortExec::new( - vec![PhysicalSortExpr { + LexOrdering::new(vec![PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions::default(), - }], + }]), blocking_exec, 1, )); - let fut = collect(sort_exec, task_ctx.clone()); + let fut = collect(sort_exec, Arc::clone(&task_ctx)); let mut fut = fut.boxed(); assert_is_pending(&mut fut); diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index e4e3d46dfbbc..d90d0f64ceb4 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -22,18 +22,17 @@ use std::any::Any; use std::fmt; use std::fmt::{Debug, Formatter}; -use std::fs::File; -use std::io::BufReader; -use std::path::{Path, PathBuf}; use std::sync::Arc; -use crate::common::{spawn_buffered, IPCWriter}; +use crate::common::spawn_buffered; use crate::expressions::PhysicalSortExpr; +use crate::limit::LimitStream; use crate::metrics::{ BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, }; -use crate::sorts::streaming_merge::streaming_merge; -use crate::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter}; +use crate::sorts::streaming_merge::StreamingMergeBuilder; +use crate::spill::{read_spill_as_stream, spill_record_batches}; +use crate::stream::RecordBatchStreamAdapter; use crate::topk::TopK; use crate::{ DisplayAs, DisplayFormatType, Distribution, EmptyRecordBatchStream, ExecutionMode, @@ -41,26 +40,25 @@ use crate::{ SendableRecordBatchStream, Statistics, }; -use arrow::compute::{concat_batches, lexsort_to_indices, take, SortColumn}; +use arrow::compute::{concat_batches, lexsort_to_indices, take_arrays, SortColumn}; use arrow::datatypes::SchemaRef; -use arrow::ipc::reader::FileReader; use arrow::record_batch::RecordBatch; use arrow::row::{RowConverter, SortField}; use arrow_array::{Array, RecordBatchOptions, UInt32Array}; use arrow_schema::DataType; -use datafusion_common::{exec_err, DataFusionError, Result}; -use datafusion_common_runtime::SpawnedTask; +use datafusion_common::{internal_err, Result}; use datafusion_execution::disk_manager::RefCountedTempFile; -use datafusion_execution::memory_pool::{ - human_readable_size, MemoryConsumer, MemoryReservation, -}; +use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_execution::TaskContext; use datafusion_physical_expr::LexOrdering; +use datafusion_physical_expr_common::sort_expr::{ + LexOrderingRef, PhysicalSortRequirement, +}; +use crate::execution_plan::CardinalityEffect; use futures::{StreamExt, TryStreamExt}; -use log::{debug, error, trace}; -use tokio::sync::mpsc::Sender; +use log::{debug, trace}; struct ExternalSorterMetrics { /// metrics @@ -97,14 +95,14 @@ impl ExternalSorterMetrics { /// 1. get a non-empty new batch from input /// /// 2. check with the memory manager there is sufficient space to -/// buffer the batch in memory 2.1 if memory sufficient, buffer -/// batch in memory, go to 1. +/// buffer the batch in memory 2.1 if memory sufficient, buffer +/// batch in memory, go to 1. /// /// 2.2 if no more memory is available, sort all buffered batches and /// spill to file. buffer the next batch in memory, go to 1. /// /// 3. when input is exhausted, merge all in memory batches and spills -/// to get a total order. +/// to get a total order. /// /// # When data fits in available memory /// @@ -247,7 +245,7 @@ impl ExternalSorter { pub fn new( partition_id: usize, schema: SchemaRef, - expr: Vec, + expr: LexOrdering, batch_size: usize, fetch: Option, sort_spill_reservation_bytes: usize, @@ -269,7 +267,7 @@ impl ExternalSorter { in_mem_batches: vec![], in_mem_batches_sorted: true, spills: vec![], - expr: expr.into(), + expr: expr.inner.into(), metrics, fetch, reservation, @@ -328,7 +326,7 @@ impl ExternalSorter { /// 1. An in-memory sort/merge (if the input fit in memory) /// /// 2. A combined streaming merge incorporating both in-memory - /// batches and data from spill files on disk. + /// batches and data from spill files on disk. fn sort(&mut self) -> Result { if self.spilled_before() { let mut streams = vec![]; @@ -340,28 +338,23 @@ impl ExternalSorter { for spill in self.spills.drain(..) { if !spill.path().exists() { - return Err(DataFusionError::Internal(format!( - "Spill file {:?} does not exist", - spill.path() - ))); + return internal_err!("Spill file {:?} does not exist", spill.path()); } - let stream = read_spill_as_stream(spill, self.schema.clone())?; + let stream = read_spill_as_stream(spill, Arc::clone(&self.schema), 2)?; streams.push(stream); } - streaming_merge( - streams, - self.schema.clone(), - &self.expr, - self.metrics.baseline.clone(), - self.batch_size, - self.fetch, - self.reservation.new_empty(), - ) - } else if !self.in_mem_batches.is_empty() { - self.in_mem_sort_stream(self.metrics.baseline.clone()) + StreamingMergeBuilder::new() + .with_streams(streams) + .with_schema(Arc::clone(&self.schema)) + .with_expressions(self.expr.to_vec().as_slice()) + .with_metrics(self.metrics.baseline.clone()) + .with_batch_size(self.batch_size) + .with_fetch(self.fetch) + .with_reservation(self.reservation.new_empty()) + .build() } else { - Ok(Box::pin(EmptyRecordBatchStream::new(self.schema.clone()))) + self.in_mem_sort_stream(self.metrics.baseline.clone()) } } @@ -401,12 +394,15 @@ impl ExternalSorter { let spill_file = self.runtime.disk_manager.create_tmp_file("Sorting")?; let batches = std::mem::take(&mut self.in_mem_batches); - let spilled_rows = - spill_sorted_batches(batches, spill_file.path(), self.schema.clone()).await?; + let spilled_rows = spill_record_batches( + batches, + spill_file.path().into(), + Arc::clone(&self.schema), + )?; let used = self.reservation.free(); self.metrics.spill_count.add(1); self.metrics.spilled_bytes.add(used); - self.metrics.spilled_rows.add(spilled_rows as usize); + self.metrics.spilled_rows.add(spilled_rows); self.spills.push(spill_file); Ok(used) } @@ -503,9 +499,19 @@ impl ExternalSorter { &mut self, metrics: BaselineMetrics, ) -> Result { - assert_ne!(self.in_mem_batches.len(), 0); + if self.in_mem_batches.is_empty() { + return Ok(Box::pin(EmptyRecordBatchStream::new(Arc::clone( + &self.schema, + )))); + } + + // The elapsed compute timer is updated when the value is dropped. + // There is no need for an explicit call to drop. + let elapsed_compute = metrics.elapsed_compute().clone(); + let _timer = elapsed_compute.timer(); + if self.in_mem_batches.len() == 1 { - let batch = self.in_mem_batches.remove(0); + let batch = self.in_mem_batches.swap_remove(0); let reservation = self.reservation.take(); return self.sort_batch_stream(batch, metrics, reservation); } @@ -530,15 +536,15 @@ impl ExternalSorter { }) .collect::>()?; - streaming_merge( - streams, - self.schema.clone(), - &self.expr, - metrics, - self.batch_size, - self.fetch, - self.merge_reservation.new_empty(), - ) + StreamingMergeBuilder::new() + .with_streams(streams) + .with_schema(Arc::clone(&self.schema)) + .with_expressions(self.expr.as_ref()) + .with_metrics(metrics) + .with_batch_size(self.batch_size) + .with_fetch(self.fetch) + .with_reservation(self.merge_reservation.new_empty()) + .build() } /// Sorts a single `RecordBatch` into a single stream. @@ -555,9 +561,11 @@ impl ExternalSorter { let schema = batch.schema(); let fetch = self.fetch; - let expressions = self.expr.clone(); + let expressions = Arc::clone(&self.expr); let stream = futures::stream::once(futures::future::lazy(move |_| { + let timer = metrics.elapsed_compute().timer(); let sorted = sort_batch(&batch, &expressions, fetch)?; + timer.done(); metrics.record_output(sorted.num_rows()); drop(batch); drop(reservation); @@ -593,9 +601,9 @@ impl Debug for ExternalSorter { } } -pub(crate) fn sort_batch( +pub fn sort_batch( batch: &RecordBatch, - expressions: &[PhysicalSortExpr], + expressions: LexOrderingRef, fetch: Option, ) -> Result { let sort_columns = expressions @@ -604,18 +612,14 @@ pub(crate) fn sort_batch( .collect::>>()?; let indices = if is_multi_column_with_lists(&sort_columns) { - // lex_sort_to_indices doesn't support List with more than one colum + // lex_sort_to_indices doesn't support List with more than one column // https://github.com/apache/arrow-rs/issues/5454 lexsort_to_indices_multi_columns(sort_columns, fetch)? } else { lexsort_to_indices(&sort_columns, fetch)? }; - let columns = batch - .columns() - .iter() - .map(|c| take(c.as_ref(), &indices, None)) - .collect::>()?; + let columns = take_arrays(batch.columns(), &indices, None)?; let options = RecordBatchOptions::new().with_row_count(Some(indices.len())); Ok(RecordBatch::try_new_with_options( @@ -667,80 +671,16 @@ pub(crate) fn lexsort_to_indices_multi_columns( Ok(indices) } -/// Spills sorted `in_memory_batches` to disk. -/// -/// Returns number of the rows spilled to disk. -async fn spill_sorted_batches( - batches: Vec, - path: &Path, - schema: SchemaRef, -) -> Result { - let path: PathBuf = path.into(); - let task = SpawnedTask::spawn_blocking(move || write_sorted(batches, path, schema)); - match task.join().await { - Ok(r) => r, - Err(e) => exec_err!("Error occurred while spilling {e}"), - } -} - -pub(crate) fn read_spill_as_stream( - path: RefCountedTempFile, - schema: SchemaRef, -) -> Result { - let mut builder = RecordBatchReceiverStream::builder(schema, 2); - let sender = builder.tx(); - - builder.spawn_blocking(move || { - let result = read_spill(sender, path.path()); - if let Err(e) = &result { - error!("Failure while reading spill file: {:?}. Error: {}", path, e); - } - result - }); - - Ok(builder.build()) -} - -fn write_sorted( - batches: Vec, - path: PathBuf, - schema: SchemaRef, -) -> Result { - let mut writer = IPCWriter::new(path.as_ref(), schema.as_ref())?; - for batch in batches { - writer.write(&batch)?; - } - writer.finish()?; - debug!( - "Spilled {} batches of total {} rows to disk, memory released {}", - writer.num_batches, - writer.num_rows, - human_readable_size(writer.num_bytes as usize), - ); - Ok(writer.num_rows) -} - -fn read_spill(sender: Sender>, path: &Path) -> Result<()> { - let file = BufReader::new(File::open(path)?); - let reader = FileReader::try_new(file, None)?; - for batch in reader { - sender - .blocking_send(batch.map_err(Into::into)) - .map_err(|e| DataFusionError::Execution(format!("{e}")))?; - } - Ok(()) -} - /// Sort execution plan. /// /// Support sorting datasets that are larger than the memory allotted /// by the memory manager, by spilling to disk. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct SortExec { /// Input schema pub(crate) input: Arc, /// Sort expressions - expr: Vec, + expr: LexOrdering, /// Containing all metrics set created during sort metrics_set: ExecutionPlanMetricsSet, /// Preserve partitions of input plan. If false, the input partitions @@ -755,7 +695,7 @@ pub struct SortExec { impl SortExec { /// Create a new sort execution plan that produces a single, /// sorted output partition. - pub fn new(expr: Vec, input: Arc) -> Self { + pub fn new(expr: LexOrdering, input: Arc) -> Self { let preserve_partitioning = false; let cache = Self::compute_properties(&input, expr.clone(), preserve_partitioning); Self { @@ -776,7 +716,7 @@ impl SortExec { /// Specify the partitioning behavior of this sort exec /// /// If `preserve_partitioning` is true, sorts each partition - /// individually, producing one sorted strema for each input partition. + /// individually, producing one sorted stream for each input partition. /// /// If `preserve_partitioning` is false, sorts and merges all /// input partitions producing a single, sorted partition. @@ -798,9 +738,22 @@ impl SortExec { /// This can reduce the memory pressure required by the sort /// operation since rows that are not going to be included /// can be dropped. - pub fn with_fetch(mut self, fetch: Option) -> Self { - self.fetch = fetch; - self + pub fn with_fetch(&self, fetch: Option) -> Self { + let mut cache = self.cache.clone(); + if fetch.is_some() && self.cache.execution_mode == ExecutionMode::Unbounded { + // When a theoretically unnecessary sort becomes a top-K (which + // sometimes arises as an intermediate state before full removal), + // its execution mode should become `Bounded`. + cache.execution_mode = ExecutionMode::Bounded; + } + SortExec { + input: Arc::clone(&self.input), + expr: self.expr.clone(), + metrics_set: self.metrics_set.clone(), + preserve_partitioning: self.preserve_partitioning, + fetch, + cache, + } } /// Input schema @@ -809,8 +762,8 @@ impl SortExec { } /// Sort expressions - pub fn expr(&self) -> &[PhysicalSortExpr] { - &self.expr + pub fn expr(&self) -> LexOrderingRef { + self.expr.as_ref() } /// If `Some(fetch)`, limits output to only the first "fetch" items @@ -836,6 +789,18 @@ impl SortExec { sort_exprs: LexOrdering, preserve_partitioning: bool, ) -> PlanProperties { + // Determine execution mode: + let sort_satisfied = input.equivalence_properties().ordering_satisfy_requirement( + PhysicalSortRequirement::from_sort_exprs(sort_exprs.iter()) + .inner + .as_slice(), + ); + let mode = match input.execution_mode() { + ExecutionMode::Unbounded if sort_satisfied => ExecutionMode::Unbounded, + ExecutionMode::Bounded => ExecutionMode::Bounded, + _ => ExecutionMode::PipelineBreaking, + }; + // Calculate equivalence properties; i.e. reset the ordering equivalence // class with the new ordering: let eq_properties = input @@ -847,32 +812,20 @@ impl SortExec { let output_partitioning = Self::output_partitioning_helper(input, preserve_partitioning); - // Determine execution mode: - let mode = match input.execution_mode() { - ExecutionMode::Unbounded | ExecutionMode::PipelineBreaking => { - ExecutionMode::PipelineBreaking - } - ExecutionMode::Bounded => ExecutionMode::Bounded, - }; - PlanProperties::new(eq_properties, output_partitioning, mode) } } impl DisplayAs for SortExec { - fn fmt_as( - &self, - t: DisplayFormatType, - f: &mut std::fmt::Formatter, - ) -> std::fmt::Result { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - let expr = PhysicalSortExpr::format_list(&self.expr); + let preserve_partitioning = self.preserve_partitioning; match self.fetch { Some(fetch) => { - write!(f, "SortExec: TopK(fetch={fetch}), expr=[{expr}]",) + write!(f, "SortExec: TopK(fetch={fetch}), expr=[{}], preserve_partitioning=[{preserve_partitioning}]", self.expr) } - None => write!(f, "SortExec: expr=[{expr}]"), + None => write!(f, "SortExec: expr=[{}], preserve_partitioning=[{preserve_partitioning}]", self.expr), } } } @@ -902,8 +855,8 @@ impl ExecutionPlan for SortExec { } } - fn children(&self) -> Vec> { - vec![self.input.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.input] } fn benefits_from_input_partitioning(&self) -> Vec { @@ -914,7 +867,7 @@ impl ExecutionPlan for SortExec { self: Arc, children: Vec>, ) -> Result> { - let new_sort = SortExec::new(self.expr.clone(), children[0].clone()) + let new_sort = SortExec::new(self.expr.clone(), Arc::clone(&children[0])) .with_fetch(self.fetch) .with_preserve_partitioning(self.preserve_partitioning); @@ -928,59 +881,76 @@ impl ExecutionPlan for SortExec { ) -> Result { trace!("Start SortExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); - let mut input = self.input.execute(partition, context.clone())?; + let mut input = self.input.execute(partition, Arc::clone(&context))?; let execution_options = &context.session_config().options().execution; trace!("End SortExec's input.execute for partition: {}", partition); - if let Some(fetch) = self.fetch.as_ref() { - let mut topk = TopK::try_new( - partition, - input.schema(), - self.expr.clone(), - *fetch, - context.session_config().batch_size(), - context.runtime_env(), - &self.metrics_set, - partition, - )?; - - Ok(Box::pin(RecordBatchStreamAdapter::new( - self.schema(), - futures::stream::once(async move { - while let Some(batch) = input.next().await { - let batch = batch?; - topk.insert_batch(batch)?; - } - topk.emit() - }) - .try_flatten(), - ))) - } else { - let mut sorter = ExternalSorter::new( - partition, - input.schema(), - self.expr.clone(), - context.session_config().batch_size(), - self.fetch, - execution_options.sort_spill_reservation_bytes, - execution_options.sort_in_place_threshold_bytes, - &self.metrics_set, - context.runtime_env(), + let sort_satisfied = self + .input + .equivalence_properties() + .ordering_satisfy_requirement( + PhysicalSortRequirement::from_sort_exprs(self.expr.iter()) + .inner + .as_slice(), ); - Ok(Box::pin(RecordBatchStreamAdapter::new( - self.schema(), - futures::stream::once(async move { - while let Some(batch) = input.next().await { - let batch = batch?; - sorter.insert_batch(batch).await?; - } - sorter.sort() - }) - .try_flatten(), - ))) + match (sort_satisfied, self.fetch.as_ref()) { + (true, Some(fetch)) => Ok(Box::pin(LimitStream::new( + input, + 0, + Some(*fetch), + BaselineMetrics::new(&self.metrics_set, partition), + ))), + (true, None) => Ok(input), + (false, Some(fetch)) => { + let mut topk = TopK::try_new( + partition, + input.schema(), + self.expr.clone(), + *fetch, + context.session_config().batch_size(), + context.runtime_env(), + &self.metrics_set, + partition, + )?; + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + futures::stream::once(async move { + while let Some(batch) = input.next().await { + let batch = batch?; + topk.insert_batch(batch)?; + } + topk.emit() + }) + .try_flatten(), + ))) + } + (false, None) => { + let mut sorter = ExternalSorter::new( + partition, + input.schema(), + self.expr.clone(), + context.session_config().batch_size(), + self.fetch, + execution_options.sort_spill_reservation_bytes, + execution_options.sort_in_place_threshold_bytes, + &self.metrics_set, + context.runtime_env(), + ); + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + futures::stream::once(async move { + while let Some(batch) = input.next().await { + let batch = batch?; + sorter.insert_batch(batch).await?; + } + sorter.sort() + }) + .try_flatten(), + ))) + } } } @@ -989,13 +959,31 @@ impl ExecutionPlan for SortExec { } fn statistics(&self) -> Result { - self.input.statistics() + Statistics::with_fetch(self.input.statistics()?, self.schema(), self.fetch, 0, 1) + } + + fn with_fetch(&self, limit: Option) -> Option> { + Some(Arc::new(SortExec::with_fetch(self, limit))) + } + + fn fetch(&self) -> Option { + self.fetch + } + + fn cardinality_effect(&self) -> CardinalityEffect { + if self.fetch.is_none() { + CardinalityEffect::Equal + } else { + CardinalityEffect::LowerEqual + } } } #[cfg(test)] mod tests { use std::collections::HashMap; + use std::pin::Pin; + use std::task::{Context, Poll}; use super::*; use crate::coalesce_partitions::CoalescePartitionsExec; @@ -1010,12 +998,123 @@ mod tests { use arrow::compute::SortOptions; use arrow::datatypes::*; use datafusion_common::cast::as_primitive_array; + use datafusion_common::{assert_batches_eq, Result, ScalarValue}; use datafusion_execution::config::SessionConfig; - use datafusion_execution::runtime_env::RuntimeConfig; + use datafusion_execution::runtime_env::RuntimeEnvBuilder; + use datafusion_execution::RecordBatchStream; + use datafusion_physical_expr::expressions::{Column, Literal}; + use datafusion_physical_expr::EquivalenceProperties; + + use futures::{FutureExt, Stream}; + + #[derive(Debug, Clone)] + pub struct SortedUnboundedExec { + schema: Schema, + batch_size: u64, + cache: PlanProperties, + } + + impl DisplayAs for SortedUnboundedExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "UnboundableExec",).unwrap() + } + } + Ok(()) + } + } + + impl SortedUnboundedExec { + fn compute_properties(schema: SchemaRef) -> PlanProperties { + let mut eq_properties = EquivalenceProperties::new(schema); + eq_properties.add_new_orderings(vec![LexOrdering::new(vec![ + PhysicalSortExpr::new_default(Arc::new(Column::new("c1", 0))), + ])]); + let mode = ExecutionMode::Unbounded; + PlanProperties::new(eq_properties, Partitioning::UnknownPartitioning(1), mode) + } + } - use datafusion_common::ScalarValue; - use datafusion_physical_expr::expressions::Literal; - use futures::FutureExt; + impl ExecutionPlan for SortedUnboundedExec { + fn name(&self) -> &'static str { + Self::static_name() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> Result> { + Ok(self) + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + Ok(Box::pin(SortedUnboundedStream { + schema: Arc::new(self.schema.clone()), + batch_size: self.batch_size, + offset: 0, + })) + } + } + + #[derive(Debug)] + pub struct SortedUnboundedStream { + schema: SchemaRef, + batch_size: u64, + offset: u64, + } + + impl Stream for SortedUnboundedStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + let batch = SortedUnboundedStream::create_record_batch( + Arc::clone(&self.schema), + self.offset, + self.batch_size, + ); + self.offset += self.batch_size; + Poll::Ready(Some(Ok(batch))) + } + } + + impl RecordBatchStream for SortedUnboundedStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + } + + impl SortedUnboundedStream { + fn create_record_batch( + schema: SchemaRef, + offset: u64, + batch_size: u64, + ) -> RecordBatch { + let values = (0..batch_size).map(|i| offset + i).collect::>(); + let array = UInt64Array::from(values); + let array_ref: ArrayRef = Arc::new(array); + RecordBatch::try_new(schema, vec![array_ref]).unwrap() + } + } #[tokio::test] async fn test_in_mem_sort() -> Result<()> { @@ -1025,14 +1124,14 @@ mod tests { let schema = csv.schema(); let sort_exec = Arc::new(SortExec::new( - vec![PhysicalSortExpr { + LexOrdering::new(vec![PhysicalSortExpr { expr: col("i", &schema)?, options: SortOptions::default(), - }], + }]), Arc::new(CoalescePartitionsExec::new(csv)), )); - let result = collect(sort_exec, task_ctx.clone()).await?; + let result = collect(sort_exec, Arc::clone(&task_ctx)).await?; assert_eq!(result.len(), 1); assert_eq!(result[0].num_rows(), 400); @@ -1054,9 +1153,9 @@ mod tests { .options() .execution .sort_spill_reservation_bytes; - let rt_config = RuntimeConfig::new() - .with_memory_limit(sort_spill_reservation_bytes + 12288, 1.0); - let runtime = Arc::new(RuntimeEnv::new(rt_config)?); + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(sort_spill_reservation_bytes + 12288, 1.0) + .build_arc()?; let task_ctx = Arc::new( TaskContext::default() .with_session_config(session_config) @@ -1068,14 +1167,18 @@ mod tests { let schema = input.schema(); let sort_exec = Arc::new(SortExec::new( - vec![PhysicalSortExpr { + LexOrdering::new(vec![PhysicalSortExpr { expr: col("i", &schema)?, options: SortOptions::default(), - }], + }]), Arc::new(CoalescePartitionsExec::new(input)), )); - let result = collect(sort_exec.clone(), task_ctx.clone()).await?; + let result = collect( + Arc::clone(&sort_exec) as Arc, + Arc::clone(&task_ctx), + ) + .await?; assert_eq!(result.len(), 2); @@ -1126,11 +1229,12 @@ mod tests { .execution .sort_spill_reservation_bytes; - let rt_config = RuntimeConfig::new().with_memory_limit( - sort_spill_reservation_bytes + avg_batch_size * (partitions - 1), - 1.0, - ); - let runtime = Arc::new(RuntimeEnv::new(rt_config)?); + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit( + sort_spill_reservation_bytes + avg_batch_size * (partitions - 1), + 1.0, + ) + .build_arc()?; let task_ctx = Arc::new( TaskContext::default() .with_runtime(runtime) @@ -1142,16 +1246,20 @@ mod tests { let sort_exec = Arc::new( SortExec::new( - vec![PhysicalSortExpr { + LexOrdering::new(vec![PhysicalSortExpr { expr: col("i", &schema)?, options: SortOptions::default(), - }], + }]), Arc::new(CoalescePartitionsExec::new(csv)), ) .with_fetch(fetch), ); - let result = collect(sort_exec.clone(), task_ctx.clone()).await?; + let result = collect( + Arc::clone(&sort_exec) as Arc, + Arc::clone(&task_ctx), + ) + .await?; assert_eq!(result.len(), 1); let metrics = sort_exec.metrics().unwrap(); @@ -1181,15 +1289,16 @@ mod tests { let data: ArrayRef = Arc::new(vec![3, 2, 1].into_iter().map(Some).collect::()); - let batch = RecordBatch::try_new(schema.clone(), vec![data]).unwrap(); - let input = - Arc::new(MemoryExec::try_new(&[vec![batch]], schema.clone(), None).unwrap()); + let batch = RecordBatch::try_new(Arc::clone(&schema), vec![data]).unwrap(); + let input = Arc::new( + MemoryExec::try_new(&[vec![batch]], Arc::clone(&schema), None).unwrap(), + ); let sort_exec = Arc::new(SortExec::new( - vec![PhysicalSortExpr { + LexOrdering::new(vec![PhysicalSortExpr { expr: col("field_name", &schema)?, options: SortOptions::default(), - }], + }]), input, )); @@ -1198,7 +1307,7 @@ mod tests { let expected_data: ArrayRef = Arc::new(vec![1, 2, 3].into_iter().map(Some).collect::()); let expected_batch = - RecordBatch::try_new(schema.clone(), vec![expected_data]).unwrap(); + RecordBatch::try_new(Arc::clone(&schema), vec![expected_data]).unwrap(); // Data is correct assert_eq!(&vec![expected_batch], &result); @@ -1224,7 +1333,7 @@ mod tests { // define data. let batch = RecordBatch::try_new( - schema.clone(), + Arc::clone(&schema), vec![ Arc::new(Int32Array::from(vec![Some(2), None, Some(1), Some(2)])), Arc::new(ListArray::from_iter_primitive::(vec![ @@ -1237,7 +1346,7 @@ mod tests { )?; let sort_exec = Arc::new(SortExec::new( - vec![ + LexOrdering::new(vec![ PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions { @@ -1252,8 +1361,12 @@ mod tests { nulls_first: false, }, }, - ], - Arc::new(MemoryExec::try_new(&[vec![batch]], schema.clone(), None)?), + ]), + Arc::new(MemoryExec::try_new( + &[vec![batch]], + Arc::clone(&schema), + None, + )?), )); assert_eq!(DataType::Int32, *sort_exec.schema().field(0).data_type()); @@ -1262,7 +1375,8 @@ mod tests { *sort_exec.schema().field(1).data_type() ); - let result: Vec = collect(sort_exec.clone(), task_ctx).await?; + let result: Vec = + collect(Arc::clone(&sort_exec) as Arc, task_ctx).await?; let metrics = sort_exec.metrics().unwrap(); assert!(metrics.elapsed_compute().unwrap() > 0); assert_eq!(metrics.output_rows().unwrap(), 4); @@ -1296,7 +1410,7 @@ mod tests { // define data. let batch = RecordBatch::try_new( - schema.clone(), + Arc::clone(&schema), vec![ Arc::new(Float32Array::from(vec![ Some(f32::NAN), @@ -1322,7 +1436,7 @@ mod tests { )?; let sort_exec = Arc::new(SortExec::new( - vec![ + LexOrdering::new(vec![ PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions { @@ -1337,14 +1451,15 @@ mod tests { nulls_first: false, }, }, - ], + ]), Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None)?), )); assert_eq!(DataType::Float32, *sort_exec.schema().field(0).data_type()); assert_eq!(DataType::Float64, *sort_exec.schema().field(1).data_type()); - let result: Vec = collect(sort_exec.clone(), task_ctx).await?; + let result: Vec = + collect(Arc::clone(&sort_exec) as Arc, task_ctx).await?; let metrics = sort_exec.metrics().unwrap(); assert!(metrics.elapsed_compute().unwrap() > 0); assert_eq!(metrics.output_rows().unwrap(), 8); @@ -1400,14 +1515,14 @@ mod tests { let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 1)); let refs = blocking_exec.refs(); let sort_exec = Arc::new(SortExec::new( - vec![PhysicalSortExpr { + LexOrdering::new(vec![PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions::default(), - }], + }]), blocking_exec, )); - let fut = collect(sort_exec, task_ctx.clone()); + let fut = collect(sort_exec, Arc::clone(&task_ctx)); let mut fut = fut.boxed(); assert_is_pending(&mut fut); @@ -1428,14 +1543,52 @@ mod tests { let schema = Arc::new(Schema::empty()); let options = RecordBatchOptions::new().with_row_count(Some(1)); let batch = - RecordBatch::try_new_with_options(schema.clone(), vec![], &options).unwrap(); + RecordBatch::try_new_with_options(Arc::clone(&schema), vec![], &options) + .unwrap(); - let expressions = vec![PhysicalSortExpr { + let expressions = LexOrdering::new(vec![PhysicalSortExpr { expr: Arc::new(Literal::new(ScalarValue::Int64(Some(1)))), options: SortOptions::default(), - }]; + }]); - let result = sort_batch(&batch, &expressions, None).unwrap(); + let result = sort_batch(&batch, expressions.as_ref(), None).unwrap(); assert_eq!(result.num_rows(), 1); } + + #[tokio::test] + async fn topk_unbounded_source() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let schema = Schema::new(vec![Field::new("c1", DataType::UInt64, false)]); + let source = SortedUnboundedExec { + schema: schema.clone(), + batch_size: 2, + cache: SortedUnboundedExec::compute_properties(Arc::new(schema.clone())), + }; + let mut plan = SortExec::new( + LexOrdering::new(vec![PhysicalSortExpr::new_default(Arc::new(Column::new( + "c1", 0, + )))]), + Arc::new(source), + ); + plan = plan.with_fetch(Some(9)); + + let batches = collect(Arc::new(plan), task_ctx).await?; + #[rustfmt::skip] + let expected = [ + "+----+", + "| c1 |", + "+----+", + "| 0 |", + "| 1 |", + "| 2 |", + "| 3 |", + "| 4 |", + "| 5 |", + "| 6 |", + "| 7 |", + "| 8 |", + "+----+",]; + assert_batches_eq!(expected, &batches); + Ok(()) + } } diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs index 88c6c312b94b..9ee0faaa0a44 100644 --- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs +++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs @@ -21,9 +21,9 @@ use std::any::Any; use std::sync::Arc; use crate::common::spawn_buffered; -use crate::expressions::PhysicalSortExpr; +use crate::limit::LimitStream; use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; -use crate::sorts::streaming_merge; +use crate::sorts::streaming_merge::StreamingMergeBuilder; use crate::{ DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties, Partitioning, PlanProperties, SendableRecordBatchStream, Statistics, @@ -34,6 +34,9 @@ use datafusion_execution::memory_pool::MemoryConsumer; use datafusion_execution::TaskContext; use datafusion_physical_expr::PhysicalSortRequirement; +use datafusion_physical_expr_common::sort_expr::{ + LexOrdering, LexOrderingRef, LexRequirement, +}; use log::{debug, trace}; /// Sort preserving merge execution plan @@ -63,45 +66,63 @@ use log::{debug, trace}; /// Input Streams Output stream /// (sorted) (sorted) /// ``` -#[derive(Debug)] +/// +/// # Error Handling +/// +/// If any of the input partitions return an error, the error is propagated to +/// the output and inputs are not polled again. +#[derive(Debug, Clone)] pub struct SortPreservingMergeExec { /// Input plan input: Arc, /// Sort expressions - expr: Vec, + expr: LexOrdering, /// Execution metrics metrics: ExecutionPlanMetricsSet, /// Optional number of rows to fetch. Stops producing rows after this fetch fetch: Option, /// Cache holding plan properties like equivalences, output partitioning etc. cache: PlanProperties, + /// Configuration parameter to enable round-robin selection of tied winners of loser tree. + enable_round_robin_repartition: bool, } impl SortPreservingMergeExec { /// Create a new sort execution plan - pub fn new(expr: Vec, input: Arc) -> Self { - let cache = Self::compute_properties(&input); + pub fn new(expr: LexOrdering, input: Arc) -> Self { + let cache = Self::compute_properties(&input, expr.clone()); Self { input, expr, metrics: ExecutionPlanMetricsSet::new(), fetch: None, cache, + enable_round_robin_repartition: true, } } + /// Sets the number of rows to fetch pub fn with_fetch(mut self, fetch: Option) -> Self { self.fetch = fetch; self } + /// Sets the selection strategy of tied winners of the loser tree algorithm + pub fn with_round_robin_repartition( + mut self, + enable_round_robin_repartition: bool, + ) -> Self { + self.enable_round_robin_repartition = enable_round_robin_repartition; + self + } + /// Input schema pub fn input(&self) -> &Arc { &self.input } /// Sort expressions - pub fn expr(&self) -> &[PhysicalSortExpr] { + pub fn expr(&self) -> LexOrderingRef { &self.expr } @@ -111,11 +132,17 @@ impl SortPreservingMergeExec { } /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. - fn compute_properties(input: &Arc) -> PlanProperties { + fn compute_properties( + input: &Arc, + ordering: LexOrdering, + ) -> PlanProperties { + let mut eq_properties = input.equivalence_properties().clone(); + eq_properties.clear_per_partition_constants(); + eq_properties.add_new_orderings(vec![ordering]); PlanProperties::new( - input.equivalence_properties().clone(), // Equivalence Properties - Partitioning::UnknownPartitioning(1), // Output Partitioning - input.execution_mode(), // Execution Mode + eq_properties, // Equivalence Properties + Partitioning::UnknownPartitioning(1), // Output Partitioning + input.execution_mode(), // Execution Mode ) } } @@ -128,11 +155,7 @@ impl DisplayAs for SortPreservingMergeExec { ) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!( - f, - "SortPreservingMergeExec: [{}]", - PhysicalSortExpr::format_list(&self.expr) - )?; + write!(f, "SortPreservingMergeExec: [{}]", self.expr)?; if let Some(fetch) = self.fetch { write!(f, ", fetch={fetch}")?; }; @@ -157,6 +180,22 @@ impl ExecutionPlan for SortPreservingMergeExec { &self.cache } + fn fetch(&self) -> Option { + self.fetch + } + + /// Sets the number of rows to fetch + fn with_fetch(&self, limit: Option) -> Option> { + Some(Arc::new(Self { + input: Arc::clone(&self.input), + expr: self.expr.clone(), + metrics: self.metrics.clone(), + fetch: limit, + cache: self.cache.clone(), + enable_round_robin_repartition: true, + })) + } + fn required_input_distribution(&self) -> Vec { vec![Distribution::UnspecifiedDistribution] } @@ -165,16 +204,18 @@ impl ExecutionPlan for SortPreservingMergeExec { vec![false] } - fn required_input_ordering(&self) -> Vec>> { - vec![Some(PhysicalSortRequirement::from_sort_exprs(&self.expr))] + fn required_input_ordering(&self) -> Vec> { + vec![Some(PhysicalSortRequirement::from_sort_exprs( + self.expr.iter(), + ))] } fn maintains_input_order(&self) -> Vec { vec![true] } - fn children(&self) -> Vec> { - vec![self.input.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.input] } fn with_new_children( @@ -182,7 +223,7 @@ impl ExecutionPlan for SortPreservingMergeExec { children: Vec>, ) -> Result> { Ok(Arc::new( - SortPreservingMergeExec::new(self.expr.clone(), children[0].clone()) + SortPreservingMergeExec::new(self.expr.clone(), Arc::clone(&children[0])) .with_fetch(self.fetch), )) } @@ -217,31 +258,44 @@ impl ExecutionPlan for SortPreservingMergeExec { 0 => internal_err!( "SortPreservingMergeExec requires at least one input partition" ), - 1 => { - // bypass if there is only one partition to merge (no metrics in this case either) - let result = self.input.execute(0, context); - debug!("Done getting stream for SortPreservingMergeExec::execute with 1 input"); - result - } + 1 => match self.fetch { + Some(fetch) => { + let stream = self.input.execute(0, context)?; + debug!("Done getting stream for SortPreservingMergeExec::execute with 1 input with {fetch}"); + Ok(Box::pin(LimitStream::new( + stream, + 0, + Some(fetch), + BaselineMetrics::new(&self.metrics, partition), + ))) + } + None => { + let stream = self.input.execute(0, context); + debug!("Done getting stream for SortPreservingMergeExec::execute with 1 input without fetch"); + stream + } + }, _ => { let receivers = (0..input_partitions) .map(|partition| { - let stream = self.input.execute(partition, context.clone())?; + let stream = + self.input.execute(partition, Arc::clone(&context))?; Ok(spawn_buffered(stream, 1)) }) .collect::>()?; debug!("Done setting up sender-receiver for SortPreservingMergeExec::execute"); - let result = streaming_merge( - receivers, - schema, - &self.expr, - BaselineMetrics::new(&self.metrics, partition), - context.session_config().batch_size(), - self.fetch, - reservation, - )?; + let result = StreamingMergeBuilder::new() + .with_streams(receivers) + .with_schema(schema) + .with_expressions(self.expr.as_ref()) + .with_metrics(BaselineMetrics::new(&self.metrics, partition)) + .with_batch_size(context.session_config().batch_size()) + .with_fetch(self.fetch) + .with_reservation(reservation) + .with_round_robin_tie_breaker(self.enable_round_robin_repartition) + .build()?; debug!("Got stream result from SortPreservingMergeStream::new_from_receivers"); @@ -257,30 +311,125 @@ impl ExecutionPlan for SortPreservingMergeExec { fn statistics(&self) -> Result { self.input.statistics() } + + fn supports_limit_pushdown(&self) -> bool { + true + } } #[cfg(test)] mod tests { + use std::fmt::Formatter; + use std::pin::Pin; + use std::sync::Mutex; + use std::task::{Context, Poll}; + use std::time::Duration; use super::*; + use crate::coalesce_batches::CoalesceBatchesExec; use crate::coalesce_partitions::CoalescePartitionsExec; use crate::expressions::col; use crate::memory::MemoryExec; use crate::metrics::{MetricValue, Timestamp}; + use crate::repartition::RepartitionExec; use crate::sorts::sort::SortExec; use crate::stream::RecordBatchReceiverStream; use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; use crate::test::{self, assert_is_pending, make_partition}; - use crate::{collect, common}; + use crate::{collect, common, ExecutionMode}; use arrow::array::{ArrayRef, Int32Array, StringArray, TimestampNanosecondArray}; use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; - use datafusion_common::{assert_batches_eq, assert_contains}; + use arrow_array::Int64Array; + use arrow_schema::SchemaRef; + use datafusion_common::{assert_batches_eq, assert_contains, DataFusionError}; + use datafusion_common_runtime::SpawnedTask; use datafusion_execution::config::SessionConfig; + use datafusion_execution::runtime_env::RuntimeEnvBuilder; + use datafusion_execution::RecordBatchStream; + use datafusion_physical_expr::expressions::Column; + use datafusion_physical_expr::EquivalenceProperties; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; + + use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; + use futures::{FutureExt, Stream, StreamExt}; + use tokio::time::timeout; + + // The number in the function is highly related to the memory limit we are testing + // any change of the constant should be aware of + fn generate_task_ctx_for_round_robin_tie_breaker() -> Result> { + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(20_000_000, 1.0) + .build_arc()?; + let config = SessionConfig::new(); + let task_ctx = TaskContext::default() + .with_runtime(runtime) + .with_session_config(config); + Ok(Arc::new(task_ctx)) + } + // The number in the function is highly related to the memory limit we are testing, + // any change of the constant should be aware of + fn generate_spm_for_round_robin_tie_breaker( + enable_round_robin_repartition: bool, + ) -> Result> { + let target_batch_size = 12500; + let row_size = 12500; + let a: ArrayRef = Arc::new(Int32Array::from(vec![1; row_size])); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"); row_size])); + let c: ArrayRef = Arc::new(Int64Array::from_iter(vec![0; row_size])); + let rb = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); + + let rbs = (0..1024).map(|_| rb.clone()).collect::>(); + + let schema = rb.schema(); + let sort = LexOrdering::new(vec![ + PhysicalSortExpr { + expr: col("b", &schema).unwrap(), + options: Default::default(), + }, + PhysicalSortExpr { + expr: col("c", &schema).unwrap(), + options: Default::default(), + }, + ]); + + let exec = MemoryExec::try_new(&[rbs], schema, None).unwrap(); + let repartition_exec = + RepartitionExec::try_new(Arc::new(exec), Partitioning::RoundRobinBatch(2))?; + let coalesce_batches_exec = + CoalesceBatchesExec::new(Arc::new(repartition_exec), target_batch_size); + let spm = SortPreservingMergeExec::new(sort, Arc::new(coalesce_batches_exec)) + .with_round_robin_repartition(enable_round_robin_repartition); + Ok(Arc::new(spm)) + } - use futures::{FutureExt, StreamExt}; + /// This test verifies that memory usage stays within limits when the tie breaker is enabled. + /// Any errors here could indicate unintended changes in tie breaker logic. + /// + /// Note: If you adjust constants in this test, ensure that memory usage differs + /// based on whether the tie breaker is enabled or disabled. + #[tokio::test(flavor = "multi_thread")] + async fn test_round_robin_tie_breaker_success() -> Result<()> { + let task_ctx = generate_task_ctx_for_round_robin_tie_breaker()?; + let spm = generate_spm_for_round_robin_tie_breaker(true)?; + let _collected = collect(spm, task_ctx).await.unwrap(); + Ok(()) + } + + /// This test verifies that memory usage stays within limits when the tie breaker is enabled. + /// Any errors here could indicate unintended changes in tie breaker logic. + /// + /// Note: If you adjust constants in this test, ensure that memory usage differs + /// based on whether the tie breaker is enabled or disabled. + #[tokio::test(flavor = "multi_thread")] + async fn test_round_robin_tie_breaker_fail() -> Result<()> { + let task_ctx = generate_task_ctx_for_round_robin_tie_breaker()?; + let spm = generate_spm_for_round_robin_tie_breaker(false)?; + let _err = collect(spm, task_ctx).await.unwrap_err(); + Ok(()) + } #[tokio::test] async fn test_merge_interleave() { @@ -337,7 +486,7 @@ mod tests { let batch = RecordBatch::try_from_iter(vec![("a", a)]).unwrap(); let schema = batch.schema(); - let sort = vec![]; // no sort expressions + let sort = LexOrdering::default(); // no sort expressions let exec = MemoryExec::try_new(&[vec![batch.clone()], vec![batch]], schema, None) .unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); @@ -516,7 +665,7 @@ mod tests { context: Arc, ) { let schema = partitions[0][0].schema(); - let sort = vec![ + let sort = LexOrdering::new(vec![ PhysicalSortExpr { expr: col("b", &schema).unwrap(), options: Default::default(), @@ -525,7 +674,7 @@ mod tests { expr: col("c", &schema).unwrap(), options: Default::default(), }, - ]; + ]); let exec = MemoryExec::try_new(partitions, schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); @@ -535,7 +684,7 @@ mod tests { async fn sorted_merge( input: Arc, - sort: Vec, + sort: LexOrdering, context: Arc, ) -> RecordBatch { let merge = Arc::new(SortPreservingMergeExec::new(sort, input)); @@ -546,7 +695,7 @@ mod tests { async fn partition_sort( input: Arc, - sort: Vec, + sort: LexOrdering, context: Arc, ) -> RecordBatch { let sort_exec = @@ -556,7 +705,7 @@ mod tests { async fn basic_sort( src: Arc, - sort: Vec, + sort: LexOrdering, context: Arc, ) -> RecordBatch { let merge = Arc::new(CoalescePartitionsExec::new(src)); @@ -573,16 +722,17 @@ mod tests { let csv = test::scan_partitioned(partitions); let schema = csv.schema(); - let sort = vec![PhysicalSortExpr { + let sort = LexOrdering::new(vec![PhysicalSortExpr { expr: col("i", &schema).unwrap(), options: SortOptions { descending: true, nulls_first: true, }, - }]; + }]); - let basic = basic_sort(csv.clone(), sort.clone(), task_ctx.clone()).await; - let partition = partition_sort(csv, sort, task_ctx.clone()).await; + let basic = + basic_sort(Arc::clone(&csv), sort.clone(), Arc::clone(&task_ctx)).await; + let partition = partition_sort(csv, sort, Arc::clone(&task_ctx)).await; let basic = arrow::util::pretty::pretty_format_batches(&[basic]) .unwrap() @@ -623,7 +773,7 @@ mod tests { } async fn sorted_partitioned_input( - sort: Vec, + sort: LexOrdering, sizes: &[usize], context: Arc, ) -> Result> { @@ -642,16 +792,17 @@ mod tests { async fn test_partition_sort_streaming_input() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); let schema = make_partition(11).schema(); - let sort = vec![PhysicalSortExpr { + let sort = LexOrdering::new(vec![PhysicalSortExpr { expr: col("i", &schema).unwrap(), options: Default::default(), - }]; + }]); let input = - sorted_partitioned_input(sort.clone(), &[10, 3, 11], task_ctx.clone()) + sorted_partitioned_input(sort.clone(), &[10, 3, 11], Arc::clone(&task_ctx)) .await?; - let basic = basic_sort(input.clone(), sort.clone(), task_ctx.clone()).await; - let partition = sorted_merge(input, sort, task_ctx.clone()).await; + let basic = + basic_sort(Arc::clone(&input), sort.clone(), Arc::clone(&task_ctx)).await; + let partition = sorted_merge(input, sort, Arc::clone(&task_ctx)).await; assert_eq!(basic.num_rows(), 1200); assert_eq!(partition.num_rows(), 1200); @@ -671,17 +822,17 @@ mod tests { #[tokio::test] async fn test_partition_sort_streaming_input_output() -> Result<()> { let schema = make_partition(11).schema(); - let sort = vec![PhysicalSortExpr { + let sort = LexOrdering::new(vec![PhysicalSortExpr { expr: col("i", &schema).unwrap(), options: Default::default(), - }]; + }]); // Test streaming with default batch size let task_ctx = Arc::new(TaskContext::default()); let input = - sorted_partitioned_input(sort.clone(), &[10, 5, 13], task_ctx.clone()) + sorted_partitioned_input(sort.clone(), &[10, 5, 13], Arc::clone(&task_ctx)) .await?; - let basic = basic_sort(input.clone(), sort.clone(), task_ctx).await; + let basic = basic_sort(Arc::clone(&input), sort.clone(), task_ctx).await; // batch size of 23 let task_ctx = TaskContext::default() @@ -746,7 +897,7 @@ mod tests { let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap(); let schema = b1.schema(); - let sort = vec![ + let sort = LexOrdering::new(vec![ PhysicalSortExpr { expr: col("b", &schema).unwrap(), options: SortOptions { @@ -761,7 +912,7 @@ mod tests { nulls_first: false, }, }, - ]; + ]); let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); @@ -789,32 +940,106 @@ mod tests { ); } + #[tokio::test] + async fn test_sort_merge_single_partition_with_fetch() { + let task_ctx = Arc::new(TaskContext::default()); + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); + let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])); + let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap(); + let schema = batch.schema(); + + let sort = LexOrdering::new(vec![PhysicalSortExpr { + expr: col("b", &schema).unwrap(), + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]); + let exec = MemoryExec::try_new(&[vec![batch]], schema, None).unwrap(); + let merge = Arc::new( + SortPreservingMergeExec::new(sort, Arc::new(exec)).with_fetch(Some(2)), + ); + + let collected = collect(merge, task_ctx).await.unwrap(); + assert_eq!(collected.len(), 1); + + assert_batches_eq!( + &[ + "+---+---+", + "| a | b |", + "+---+---+", + "| 1 | a |", + "| 2 | b |", + "+---+---+", + ], + collected.as_slice() + ); + } + + #[tokio::test] + async fn test_sort_merge_single_partition_without_fetch() { + let task_ctx = Arc::new(TaskContext::default()); + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3])); + let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])); + let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap(); + let schema = batch.schema(); + + let sort = LexOrdering::new(vec![PhysicalSortExpr { + expr: col("b", &schema).unwrap(), + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]); + let exec = MemoryExec::try_new(&[vec![batch]], schema, None).unwrap(); + let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); + + let collected = collect(merge, task_ctx).await.unwrap(); + assert_eq!(collected.len(), 1); + + assert_batches_eq!( + &[ + "+---+---+", + "| a | b |", + "+---+---+", + "| 1 | a |", + "| 2 | b |", + "| 7 | c |", + "| 9 | d |", + "| 3 | e |", + "+---+---+", + ], + collected.as_slice() + ); + } + #[tokio::test] async fn test_async() -> Result<()> { let task_ctx = Arc::new(TaskContext::default()); let schema = make_partition(11).schema(); - let sort = vec![PhysicalSortExpr { + let sort = LexOrdering::new(vec![PhysicalSortExpr { expr: col("i", &schema).unwrap(), options: SortOptions::default(), - }]; + }]); let batches = - sorted_partitioned_input(sort.clone(), &[5, 7, 3], task_ctx.clone()).await?; + sorted_partitioned_input(sort.clone(), &[5, 7, 3], Arc::clone(&task_ctx)) + .await?; let partition_count = batches.output_partitioning().partition_count(); let mut streams = Vec::with_capacity(partition_count); for partition in 0..partition_count { - let mut builder = RecordBatchReceiverStream::builder(schema.clone(), 1); + let mut builder = RecordBatchReceiverStream::builder(Arc::clone(&schema), 1); let sender = builder.tx(); - let mut stream = batches.execute(partition, task_ctx.clone()).unwrap(); + let mut stream = batches.execute(partition, Arc::clone(&task_ctx)).unwrap(); builder.spawn(async move { while let Some(batch) = stream.next().await { sender.send(batch).await.unwrap(); // This causes the MergeStream to wait for more input - tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + tokio::time::sleep(Duration::from_millis(10)).await; } Ok(()) @@ -828,22 +1053,21 @@ mod tests { MemoryConsumer::new("test").register(&task_ctx.runtime_env().memory_pool); let fetch = None; - let merge_stream = streaming_merge( - streams, - batches.schema(), - sort.as_slice(), - BaselineMetrics::new(&metrics, 0), - task_ctx.session_config().batch_size(), - fetch, - reservation, - ) - .unwrap(); + let merge_stream = StreamingMergeBuilder::new() + .with_streams(streams) + .with_schema(batches.schema()) + .with_expressions(sort.as_ref()) + .with_metrics(BaselineMetrics::new(&metrics, 0)) + .with_batch_size(task_ctx.session_config().batch_size()) + .with_fetch(fetch) + .with_reservation(reservation) + .build()?; let mut merged = common::collect(merge_stream).await.unwrap(); assert_eq!(merged.len(), 1); let merged = merged.remove(0); - let basic = basic_sort(batches, sort.clone(), task_ctx.clone()).await; + let basic = basic_sort(batches, sort.clone(), Arc::clone(&task_ctx)).await; let basic = arrow::util::pretty::pretty_format_batches(&[basic]) .unwrap() @@ -872,14 +1096,16 @@ mod tests { let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap(); let schema = b1.schema(); - let sort = vec![PhysicalSortExpr { + let sort = LexOrdering::new(vec![PhysicalSortExpr { expr: col("b", &schema).unwrap(), options: Default::default(), - }]; + }]); let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); - let collected = collect(merge.clone(), task_ctx).await.unwrap(); + let collected = collect(Arc::clone(&merge) as Arc, task_ctx) + .await + .unwrap(); let expected = [ "+----+---+", "| a | b |", @@ -929,10 +1155,10 @@ mod tests { let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2)); let refs = blocking_exec.refs(); let sort_preserving_merge_exec = Arc::new(SortPreservingMergeExec::new( - vec![PhysicalSortExpr { + LexOrdering::new(vec![PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions::default(), - }], + }]), blocking_exec, )); @@ -977,13 +1203,13 @@ mod tests { let schema = partitions[0][0].schema(); - let sort = vec![PhysicalSortExpr { + let sort = LexOrdering::new(vec![PhysicalSortExpr { expr: col("value", &schema).unwrap(), options: SortOptions { descending: false, nulls_first: true, }, - }]; + }]); let exec = MemoryExec::try_new(&partitions, schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); @@ -1024,4 +1250,154 @@ mod tests { collected.as_slice() ); } + + /// It returns pending for the 2nd partition until the 3rd partition is polled. The 1st + /// partition is exhausted from the start, and if it is polled more than one, it panics. + #[derive(Debug, Clone)] + struct CongestedExec { + schema: Schema, + cache: PlanProperties, + congestion_cleared: Arc>, + } + + impl CongestedExec { + fn compute_properties(schema: SchemaRef) -> PlanProperties { + let columns = schema + .fields + .iter() + .enumerate() + .map(|(i, f)| Arc::new(Column::new(f.name(), i)) as Arc) + .collect::>(); + let mut eq_properties = EquivalenceProperties::new(schema); + eq_properties.add_new_orderings(vec![columns + .iter() + .map(|expr| PhysicalSortExpr::new_default(Arc::clone(expr))) + .collect::()]); + let mode = ExecutionMode::Unbounded; + PlanProperties::new(eq_properties, Partitioning::Hash(columns, 3), mode) + } + } + + impl ExecutionPlan for CongestedExec { + fn name(&self) -> &'static str { + Self::static_name() + } + fn as_any(&self) -> &dyn Any { + self + } + fn properties(&self) -> &PlanProperties { + &self.cache + } + fn children(&self) -> Vec<&Arc> { + vec![] + } + fn with_new_children( + self: Arc, + _: Vec>, + ) -> Result> { + Ok(self) + } + fn execute( + &self, + partition: usize, + _context: Arc, + ) -> Result { + Ok(Box::pin(CongestedStream { + schema: Arc::new(self.schema.clone()), + none_polled_once: false, + congestion_cleared: Arc::clone(&self.congestion_cleared), + partition, + })) + } + } + + impl DisplayAs for CongestedExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "CongestedExec",).unwrap() + } + } + Ok(()) + } + } + + /// It returns pending for the 2nd partition until the 3rd partition is polled. The 1st + /// partition is exhausted from the start, and if it is polled more than once, it panics. + #[derive(Debug)] + pub struct CongestedStream { + schema: SchemaRef, + none_polled_once: bool, + congestion_cleared: Arc>, + partition: usize, + } + + impl Stream for CongestedStream { + type Item = Result; + fn poll_next( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + match self.partition { + 0 => { + if self.none_polled_once { + panic!("Exhausted stream is polled more than one") + } else { + self.none_polled_once = true; + Poll::Ready(None) + } + } + 1 => { + let cleared = self.congestion_cleared.lock().unwrap(); + if *cleared { + Poll::Ready(None) + } else { + Poll::Pending + } + } + 2 => { + let mut cleared = self.congestion_cleared.lock().unwrap(); + *cleared = true; + Poll::Ready(None) + } + _ => unreachable!(), + } + } + } + + impl RecordBatchStream for CongestedStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + } + + #[tokio::test] + async fn test_spm_congestion() -> Result<()> { + let task_ctx = Arc::new(TaskContext::default()); + let schema = Schema::new(vec![Field::new("c1", DataType::UInt64, false)]); + let source = CongestedExec { + schema: schema.clone(), + cache: CongestedExec::compute_properties(Arc::new(schema.clone())), + congestion_cleared: Arc::new(Mutex::new(false)), + }; + let spm = SortPreservingMergeExec::new( + LexOrdering::new(vec![PhysicalSortExpr::new_default(Arc::new(Column::new( + "c1", 0, + )))]), + Arc::new(source), + ); + let spm_task = SpawnedTask::spawn(collect(Arc::new(spm), task_ctx)); + + let result = timeout(Duration::from_secs(3), spm_task.join()).await; + match result { + Ok(Ok(Ok(_batches))) => Ok(()), + Ok(Ok(Err(e))) => Err(e), + Ok(Err(_)) => Err(DataFusionError::Execution( + "SortPreservingMerge task panicked or was cancelled".to_string(), + )), + Err(_) => Err(DataFusionError::Execution( + "SortPreservingMerge caused a deadlock".to_string(), + )), + } + } } diff --git a/datafusion/physical-plan/src/sorts/stream.rs b/datafusion/physical-plan/src/sorts/stream.rs index 135b4fbdece4..70beb2c4a91b 100644 --- a/datafusion/physical-plan/src/sorts/stream.rs +++ b/datafusion/physical-plan/src/sorts/stream.rs @@ -24,6 +24,7 @@ use arrow::record_batch::RecordBatch; use arrow::row::{RowConverter, SortField}; use datafusion_common::Result; use datafusion_execution::memory_pool::MemoryReservation; +use datafusion_physical_expr_common::sort_expr::LexOrderingRef; use futures::stream::{Fuse, StreamExt}; use std::marker::PhantomData; use std::sync::Arc; @@ -92,7 +93,7 @@ pub struct RowCursorStream { impl RowCursorStream { pub fn try_new( schema: &Schema, - expressions: &[PhysicalSortExpr], + expressions: LexOrderingRef, streams: Vec, reservation: MemoryReservation, ) -> Result { @@ -109,7 +110,7 @@ impl RowCursorStream { Ok(Self { converter, reservation, - column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(), + column_expressions: expressions.iter().map(|x| Arc::clone(&x.expr)).collect(), streams: FusedStreams(streams), }) } diff --git a/datafusion/physical-plan/src/sorts/streaming_merge.rs b/datafusion/physical-plan/src/sorts/streaming_merge.rs index 9e6618dd1af5..bd74685eac94 100644 --- a/datafusion/physical-plan/src/sorts/streaming_merge.rs +++ b/datafusion/physical-plan/src/sorts/streaming_merge.rs @@ -23,11 +23,12 @@ use crate::sorts::{ merge::SortPreservingMergeStream, stream::{FieldCursorStream, RowCursorStream}, }; -use crate::{PhysicalSortExpr, SendableRecordBatchStream}; +use crate::SendableRecordBatchStream; use arrow::datatypes::{DataType, SchemaRef}; use arrow_array::*; use datafusion_common::{internal_err, Result}; use datafusion_execution::memory_pool::MemoryReservation; +use datafusion_physical_expr_common::sort_expr::LexOrderingRef; macro_rules! primitive_merge_helper { ($t:ty, $($v:ident),+) => { @@ -36,7 +37,7 @@ macro_rules! primitive_merge_helper { } macro_rules! merge_helper { - ($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident, $fetch:ident, $reservation:ident) => {{ + ($t:ty, $sort:ident, $streams:ident, $schema:ident, $tracking_metrics:ident, $batch_size:ident, $fetch:ident, $reservation:ident, $enable_round_robin_tie_breaker:ident) => {{ let streams = FieldCursorStream::<$t>::new($sort, $streams); return Ok(Box::pin(SortPreservingMergeStream::new( Box::new(streams), @@ -45,53 +46,139 @@ macro_rules! merge_helper { $batch_size, $fetch, $reservation, + $enable_round_robin_tie_breaker, ))); }}; } -/// Perform a streaming merge of [`SendableRecordBatchStream`] based on provided sort expressions -/// while preserving order. -pub fn streaming_merge( +#[derive(Default)] +pub struct StreamingMergeBuilder<'a> { streams: Vec, - schema: SchemaRef, - expressions: &[PhysicalSortExpr], - metrics: BaselineMetrics, - batch_size: usize, + schema: Option, + expressions: LexOrderingRef<'a>, + metrics: Option, + batch_size: Option, fetch: Option, - reservation: MemoryReservation, -) -> Result { - // If there are no sort expressions, preserving the order - // doesn't mean anything (and result in infinite loops) - if expressions.is_empty() { - return internal_err!("Sort expressions cannot be empty for streaming merge"); - } - // Special case single column comparisons with optimized cursor implementations - if expressions.len() == 1 { - let sort = expressions[0].clone(); - let data_type = sort.expr.data_type(schema.as_ref())?; - downcast_primitive! { - data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size, fetch, reservation), - DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size, fetch, reservation) - DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size, fetch, reservation) - DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation) - DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation) - _ => {} + reservation: Option, + enable_round_robin_tie_breaker: bool, +} + +impl<'a> StreamingMergeBuilder<'a> { + pub fn new() -> Self { + Self { + enable_round_robin_tie_breaker: true, + ..Default::default() } } - let streams = RowCursorStream::try_new( - schema.as_ref(), - expressions, - streams, - reservation.new_empty(), - )?; - - Ok(Box::pin(SortPreservingMergeStream::new( - Box::new(streams), - schema, - metrics, - batch_size, - fetch, - reservation, - ))) + pub fn with_streams(mut self, streams: Vec) -> Self { + self.streams = streams; + self + } + + pub fn with_schema(mut self, schema: SchemaRef) -> Self { + self.schema = Some(schema); + self + } + + pub fn with_expressions(mut self, expressions: LexOrderingRef<'a>) -> Self { + self.expressions = expressions; + self + } + + pub fn with_metrics(mut self, metrics: BaselineMetrics) -> Self { + self.metrics = Some(metrics); + self + } + + pub fn with_batch_size(mut self, batch_size: usize) -> Self { + self.batch_size = Some(batch_size); + self + } + + pub fn with_fetch(mut self, fetch: Option) -> Self { + self.fetch = fetch; + self + } + + pub fn with_reservation(mut self, reservation: MemoryReservation) -> Self { + self.reservation = Some(reservation); + self + } + + pub fn with_round_robin_tie_breaker( + mut self, + enable_round_robin_tie_breaker: bool, + ) -> Self { + self.enable_round_robin_tie_breaker = enable_round_robin_tie_breaker; + self + } + + pub fn build(self) -> Result { + let Self { + streams, + schema, + metrics, + batch_size, + reservation, + fetch, + expressions, + enable_round_robin_tie_breaker, + } = self; + + // Early return if streams or expressions are empty + let checks = [ + ( + streams.is_empty(), + "Streams cannot be empty for streaming merge", + ), + ( + expressions.is_empty(), + "Sort expressions cannot be empty for streaming merge", + ), + ]; + + if let Some((_, error_message)) = checks.iter().find(|(condition, _)| *condition) + { + return internal_err!("{}", error_message); + } + + // Unwrapping mandatory fields + let schema = schema.expect("Schema cannot be empty for streaming merge"); + let metrics = metrics.expect("Metrics cannot be empty for streaming merge"); + let batch_size = + batch_size.expect("Batch size cannot be empty for streaming merge"); + let reservation = + reservation.expect("Reservation cannot be empty for streaming merge"); + + // Special case single column comparisons with optimized cursor implementations + if expressions.len() == 1 { + let sort = expressions[0].clone(); + let data_type = sort.expr.data_type(schema.as_ref())?; + downcast_primitive! { + data_type => (primitive_merge_helper, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker), + DataType::Utf8 => merge_helper!(StringArray, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker) + DataType::LargeUtf8 => merge_helper!(LargeStringArray, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker) + DataType::Binary => merge_helper!(BinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker) + DataType::LargeBinary => merge_helper!(LargeBinaryArray, sort, streams, schema, metrics, batch_size, fetch, reservation, enable_round_robin_tie_breaker) + _ => {} + } + } + + let streams = RowCursorStream::try_new( + schema.as_ref(), + expressions, + streams, + reservation.new_empty(), + )?; + Ok(Box::pin(SortPreservingMergeStream::new( + Box::new(streams), + schema, + metrics, + batch_size, + fetch, + reservation, + enable_round_robin_tie_breaker, + ))) + } } diff --git a/datafusion/physical-plan/src/spill.rs b/datafusion/physical-plan/src/spill.rs new file mode 100644 index 000000000000..de85a7c6f098 --- /dev/null +++ b/datafusion/physical-plan/src/spill.rs @@ -0,0 +1,185 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines the spilling functions + +use std::fs::File; +use std::io::BufReader; +use std::path::{Path, PathBuf}; + +use arrow::datatypes::SchemaRef; +use arrow::ipc::reader::FileReader; +use arrow::record_batch::RecordBatch; +use log::debug; +use tokio::sync::mpsc::Sender; + +use datafusion_common::{exec_datafusion_err, Result}; +use datafusion_execution::disk_manager::RefCountedTempFile; +use datafusion_execution::memory_pool::human_readable_size; +use datafusion_execution::SendableRecordBatchStream; + +use crate::common::IPCWriter; +use crate::stream::RecordBatchReceiverStream; + +/// Read spilled batches from the disk +/// +/// `path` - temp file +/// `schema` - batches schema, should be the same across batches +/// `buffer` - internal buffer of capacity batches +pub(crate) fn read_spill_as_stream( + path: RefCountedTempFile, + schema: SchemaRef, + buffer: usize, +) -> Result { + let mut builder = RecordBatchReceiverStream::builder(schema, buffer); + let sender = builder.tx(); + + builder.spawn_blocking(move || read_spill(sender, path.path())); + + Ok(builder.build()) +} + +/// Spills in-memory `batches` to disk. +/// +/// Returns total number of the rows spilled to disk. +pub(crate) fn spill_record_batches( + batches: Vec, + path: PathBuf, + schema: SchemaRef, +) -> Result { + let mut writer = IPCWriter::new(path.as_ref(), schema.as_ref())?; + for batch in batches { + writer.write(&batch)?; + } + writer.finish()?; + debug!( + "Spilled {} batches of total {} rows to disk, memory released {}", + writer.num_batches, + writer.num_rows, + human_readable_size(writer.num_bytes), + ); + Ok(writer.num_rows) +} + +fn read_spill(sender: Sender>, path: &Path) -> Result<()> { + let file = BufReader::new(File::open(path)?); + let reader = FileReader::try_new(file, None)?; + for batch in reader { + sender + .blocking_send(batch.map_err(Into::into)) + .map_err(|e| exec_datafusion_err!("{e}"))?; + } + Ok(()) +} + +/// Spill the `RecordBatch` to disk as smaller batches +/// split by `batch_size_rows` +pub fn spill_record_batch_by_size( + batch: &RecordBatch, + path: PathBuf, + schema: SchemaRef, + batch_size_rows: usize, +) -> Result<()> { + let mut offset = 0; + let total_rows = batch.num_rows(); + let mut writer = IPCWriter::new(&path, schema.as_ref())?; + + while offset < total_rows { + let length = std::cmp::min(total_rows - offset, batch_size_rows); + let batch = batch.slice(offset, length); + offset += batch.num_rows(); + writer.write(&batch)?; + } + writer.finish()?; + + Ok(()) +} + +#[cfg(test)] +mod tests { + use crate::spill::{spill_record_batch_by_size, spill_record_batches}; + use crate::test::build_table_i32; + use datafusion_common::Result; + use datafusion_execution::disk_manager::DiskManagerConfig; + use datafusion_execution::DiskManager; + use std::fs::File; + use std::io::BufReader; + use std::sync::Arc; + + #[test] + fn test_batch_spill_and_read() -> Result<()> { + let batch1 = build_table_i32( + ("a2", &vec![0, 1, 2]), + ("b2", &vec![3, 4, 5]), + ("c2", &vec![4, 5, 6]), + ); + + let batch2 = build_table_i32( + ("a2", &vec![10, 11, 12]), + ("b2", &vec![13, 14, 15]), + ("c2", &vec![14, 15, 16]), + ); + + let disk_manager = DiskManager::try_new(DiskManagerConfig::NewOs)?; + + let spill_file = disk_manager.create_tmp_file("Test Spill")?; + let schema = batch1.schema(); + let num_rows = batch1.num_rows() + batch2.num_rows(); + let cnt = spill_record_batches( + vec![batch1, batch2], + spill_file.path().into(), + Arc::clone(&schema), + ); + assert_eq!(cnt.unwrap(), num_rows); + + let file = BufReader::new(File::open(spill_file.path())?); + let reader = arrow::ipc::reader::FileReader::try_new(file, None)?; + + assert_eq!(reader.num_batches(), 2); + assert_eq!(reader.schema(), schema); + + Ok(()) + } + + #[test] + fn test_batch_spill_by_size() -> Result<()> { + let batch1 = build_table_i32( + ("a2", &vec![0, 1, 2, 3]), + ("b2", &vec![3, 4, 5, 6]), + ("c2", &vec![4, 5, 6, 7]), + ); + + let disk_manager = DiskManager::try_new(DiskManagerConfig::NewOs)?; + + let spill_file = disk_manager.create_tmp_file("Test Spill")?; + let schema = batch1.schema(); + spill_record_batch_by_size( + &batch1, + spill_file.path().into(), + Arc::clone(&schema), + 1, + )?; + + let file = BufReader::new(File::open(spill_file.path())?); + let reader = arrow::ipc::reader::FileReader::try_new(file, None)?; + + assert_eq!(reader.num_batches(), 4); + assert_eq!(reader.schema(), schema); + + Ok(()) + } +} diff --git a/datafusion/physical-plan/src/stream.rs b/datafusion/physical-plan/src/stream.rs index 99d9367740be..ec4c9dd502a6 100644 --- a/datafusion/physical-plan/src/stream.rs +++ b/datafusion/physical-plan/src/stream.rs @@ -56,7 +56,7 @@ pub(crate) struct ReceiverStreamBuilder { } impl ReceiverStreamBuilder { - /// create new channels with the specified buffer size + /// Create new channels with the specified buffer size pub fn new(capacity: usize) -> Self { let (tx, rx) = tokio::sync::mpsc::channel(capacity); @@ -83,10 +83,10 @@ impl ReceiverStreamBuilder { } /// Spawn a blocking task that will be aborted if this builder (or the stream - /// built from it) are dropped + /// built from it) are dropped. /// - /// this is often used to spawn tasks that write to the sender - /// retrieved from `Self::tx` + /// This is often used to spawn tasks that write to the sender + /// retrieved from `Self::tx`. pub fn spawn_blocking(&mut self, f: F) where F: FnOnce() -> Result<()>, @@ -103,7 +103,7 @@ impl ReceiverStreamBuilder { mut join_set, } = self; - // don't need tx + // Doesn't need tx drop(tx); // future that checks the result of the join set, and propagates panic if seen @@ -112,7 +112,7 @@ impl ReceiverStreamBuilder { match result { Ok(task_result) => { match task_result { - // nothing to report + // Nothing to report Ok(_) => continue, // This means a blocking task error Err(error) => return Some(Err(error)), @@ -215,7 +215,7 @@ pub struct RecordBatchReceiverStreamBuilder { } impl RecordBatchReceiverStreamBuilder { - /// create new channels with the specified buffer size + /// Create new channels with the specified buffer size pub fn new(schema: SchemaRef, capacity: usize) -> Self { Self { schema, @@ -256,7 +256,7 @@ impl RecordBatchReceiverStreamBuilder { self.inner.spawn_blocking(f) } - /// runs the `partition` of the `input` ExecutionPlan on the + /// Runs the `partition` of the `input` ExecutionPlan on the /// tokio threadpool and writes its outputs to this stream /// /// If the input partition produces an error, the error will be @@ -299,7 +299,7 @@ impl RecordBatchReceiverStreamBuilder { return Ok(()); } - // stop after the first error is encontered (don't + // Stop after the first error is encountered (Don't // drive all streams to completion) if is_err { debug!( @@ -382,7 +382,7 @@ where S: Stream>, { fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -402,7 +402,7 @@ impl EmptyRecordBatchStream { impl RecordBatchStream for EmptyRecordBatchStream { fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -437,12 +437,12 @@ impl ObservedStream { } impl RecordBatchStream for ObservedStream { - fn schema(&self) -> arrow::datatypes::SchemaRef { + fn schema(&self) -> SchemaRef { self.inner.schema() } } -impl futures::Stream for ObservedStream { +impl Stream for ObservedStream { type Item = Result; fn poll_next( @@ -474,7 +474,7 @@ mod test { let schema = schema(); let num_partitions = 10; - let input = PanicExec::new(schema.clone(), num_partitions); + let input = PanicExec::new(Arc::clone(&schema), num_partitions); consume(input, 10).await } @@ -483,13 +483,13 @@ mod test { async fn record_batch_receiver_stream_propagates_panics_early_shutdown() { let schema = schema(); - // make 2 partitions, second partition panics before the first + // Make 2 partitions, second partition panics before the first let num_partitions = 2; - let input = PanicExec::new(schema.clone(), num_partitions) + let input = PanicExec::new(Arc::clone(&schema), num_partitions) .with_partition_panic(0, 10) .with_partition_panic(1, 3); // partition 1 should panic first (after 3 ) - // ensure that the panic results in an early shutdown (that + // Ensure that the panic results in an early shutdown (that // everything stops after the first panic). // Since the stream reads every other batch: (0,1,0,1,0,panic) @@ -504,18 +504,18 @@ mod test { let schema = schema(); // Make an input that never proceeds - let input = BlockingExec::new(schema.clone(), 1); + let input = BlockingExec::new(Arc::clone(&schema), 1); let refs = input.refs(); // Configure a RecordBatchReceiverStream to consume the input let mut builder = RecordBatchReceiverStream::builder(schema, 2); - builder.run_input(Arc::new(input), 0, task_ctx.clone()); + builder.run_input(Arc::new(input), 0, Arc::clone(&task_ctx)); let stream = builder.build(); - // input should still be present + // Input should still be present assert!(std::sync::Weak::strong_count(&refs) > 0); - // drop the stream, ensure the refs go to zero + // Drop the stream, ensure the refs go to zero drop(stream); assert_strong_count_converges_to_zero(refs).await; } @@ -529,15 +529,17 @@ mod test { let schema = schema(); // make an input that will error twice - let error_stream = - MockExec::new(vec![exec_err!("Test1"), exec_err!("Test2")], schema.clone()) - .with_use_task(false); + let error_stream = MockExec::new( + vec![exec_err!("Test1"), exec_err!("Test2")], + Arc::clone(&schema), + ) + .with_use_task(false); let mut builder = RecordBatchReceiverStream::builder(schema, 2); - builder.run_input(Arc::new(error_stream), 0, task_ctx.clone()); + builder.run_input(Arc::new(error_stream), 0, Arc::clone(&task_ctx)); let mut stream = builder.build(); - // get the first result, which should be an error + // Get the first result, which should be an error let first_batch = stream.next().await.unwrap(); let first_err = first_batch.unwrap_err(); assert_eq!(first_err.strip_backtrace(), "Execution error: Test1"); @@ -560,11 +562,15 @@ mod test { let mut builder = RecordBatchReceiverStream::builder(input.schema(), num_partitions); for partition in 0..num_partitions { - builder.run_input(input.clone(), partition, task_ctx.clone()); + builder.run_input( + Arc::clone(&input) as Arc, + partition, + Arc::clone(&task_ctx), + ); } let mut stream = builder.build(); - // drain the stream until it is complete, panic'ing on error + // Drain the stream until it is complete, panic'ing on error let mut num_batches = 0; while let Some(next) = stream.next().await { next.unwrap(); diff --git a/datafusion/physical-plan/src/streaming.rs b/datafusion/physical-plan/src/streaming.rs index d7e254c42fe1..7ccef3248069 100644 --- a/datafusion/physical-plan/src/streaming.rs +++ b/datafusion/physical-plan/src/streaming.rs @@ -18,6 +18,7 @@ //! Generic plans for deferred execution: [`StreamingTableExec`] and [`PartitionStream`] use std::any::Any; +use std::fmt::Debug; use std::sync::Arc; use super::{DisplayAs, DisplayFormatType, ExecutionMode, PlanProperties}; @@ -31,6 +32,8 @@ use datafusion_common::{internal_err, plan_err, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; +use crate::limit::LimitStream; +use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use async_trait::async_trait; use futures::stream::StreamExt; use log::debug; @@ -40,7 +43,7 @@ use log::debug; /// Combined with [`StreamingTableExec`], you can use this trait to implement /// [`ExecutionPlan`] for a custom source with less boiler plate than /// implementing `ExecutionPlan` directly for many use cases. -pub trait PartitionStream: Send + Sync { +pub trait PartitionStream: Debug + Send + Sync { /// Returns the schema of this partition fn schema(&self) -> &SchemaRef; @@ -52,13 +55,16 @@ pub trait PartitionStream: Send + Sync { /// /// If your source can be represented as one or more [`PartitionStream`]s, you can /// use this struct to implement [`ExecutionPlan`]. +#[derive(Clone)] pub struct StreamingTableExec { partitions: Vec>, projection: Option>, projected_schema: SchemaRef, projected_output_ordering: Vec, infinite: bool, + limit: Option, cache: PlanProperties, + metrics: ExecutionPlanMetricsSet, } impl StreamingTableExec { @@ -69,13 +75,14 @@ impl StreamingTableExec { projection: Option<&Vec>, projected_output_ordering: impl IntoIterator, infinite: bool, + limit: Option, ) -> Result { for x in partitions.iter() { let partition_schema = x.schema(); if !schema.eq(partition_schema) { debug!( "Target schema does not match with partition schema. \ - Target_schema: {schema:?}. Partiton Schema: {partition_schema:?}" + Target_schema: {schema:?}. Partition Schema: {partition_schema:?}" ); return plan_err!("Mismatch between schema and batches"); } @@ -88,7 +95,7 @@ impl StreamingTableExec { let projected_output_ordering = projected_output_ordering.into_iter().collect::>(); let cache = Self::compute_properties( - projected_schema.clone(), + Arc::clone(&projected_schema), &projected_output_ordering, &partitions, infinite, @@ -99,7 +106,9 @@ impl StreamingTableExec { projection: projection.cloned().map(Into::into), projected_output_ordering, infinite, + limit, cache, + metrics: ExecutionPlanMetricsSet::new(), }) } @@ -127,6 +136,10 @@ impl StreamingTableExec { self.infinite } + pub fn limit(&self) -> Option { + self.limit + } + /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. fn compute_properties( schema: SchemaRef, @@ -151,7 +164,7 @@ impl StreamingTableExec { } } -impl std::fmt::Debug for StreamingTableExec { +impl Debug for StreamingTableExec { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("LazyMemTableExec").finish_non_exhaustive() } @@ -180,6 +193,9 @@ impl DisplayAs for StreamingTableExec { if self.infinite { write!(f, ", infinite_source=true")?; } + if let Some(fetch) = self.limit { + write!(f, ", fetch={fetch}")?; + } display_orderings(f, &self.projected_output_ordering)?; @@ -203,7 +219,11 @@ impl ExecutionPlan for StreamingTableExec { &self.cache } - fn children(&self) -> Vec> { + fn fetch(&self) -> Option { + self.limit + } + + fn children(&self) -> Vec<&Arc> { vec![] } @@ -224,14 +244,129 @@ impl ExecutionPlan for StreamingTableExec { ctx: Arc, ) -> Result { let stream = self.partitions[partition].execute(ctx); - Ok(match self.projection.clone() { + let projected_stream = match self.projection.clone() { Some(projection) => Box::pin(RecordBatchStreamAdapter::new( - self.projected_schema.clone(), + Arc::clone(&self.projected_schema), stream.map(move |x| { x.and_then(|b| b.project(projection.as_ref()).map_err(Into::into)) }), )), None => stream, + }; + Ok(match self.limit { + None => projected_stream, + Some(fetch) => { + let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + Box::pin(LimitStream::new( + projected_stream, + 0, + Some(fetch), + baseline_metrics, + )) + } }) } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } + + fn with_fetch(&self, limit: Option) -> Option> { + Some(Arc::new(StreamingTableExec { + partitions: self.partitions.clone(), + projection: self.projection.clone(), + projected_schema: Arc::clone(&self.projected_schema), + projected_output_ordering: self.projected_output_ordering.clone(), + infinite: self.infinite, + limit, + cache: self.cache.clone(), + metrics: self.metrics.clone(), + })) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::collect_partitioned; + use crate::streaming::PartitionStream; + use crate::test::{make_partition, TestPartitionStream}; + use arrow::record_batch::RecordBatch; + + #[tokio::test] + async fn test_no_limit() { + let exec = TestBuilder::new() + // Make 2 batches, each with 100 rows + .with_batches(vec![make_partition(100), make_partition(100)]) + .build(); + + let counts = collect_num_rows(Arc::new(exec)).await; + assert_eq!(counts, vec![200]); + } + + #[tokio::test] + async fn test_limit() { + let exec = TestBuilder::new() + // Make 2 batches, each with 100 rows + .with_batches(vec![make_partition(100), make_partition(100)]) + // Limit to only the first 75 rows back + .with_limit(Some(75)) + .build(); + + let counts = collect_num_rows(Arc::new(exec)).await; + assert_eq!(counts, vec![75]); + } + + /// Runs the provided execution plan and returns a vector of the number of + /// rows in each partition + async fn collect_num_rows(exec: Arc) -> Vec { + let ctx = Arc::new(TaskContext::default()); + let partition_batches = collect_partitioned(exec, ctx).await.unwrap(); + partition_batches + .into_iter() + .map(|batches| batches.iter().map(|b| b.num_rows()).sum::()) + .collect() + } + + #[derive(Default)] + struct TestBuilder { + schema: Option, + partitions: Vec>, + projection: Option>, + projected_output_ordering: Vec, + infinite: bool, + limit: Option, + } + + impl TestBuilder { + fn new() -> Self { + Self::default() + } + + /// Set the batches for the stream + fn with_batches(mut self, batches: Vec) -> Self { + let stream = TestPartitionStream::new_with_batches(batches); + self.schema = Some(Arc::clone(stream.schema())); + self.partitions = vec![Arc::new(stream)]; + self + } + + /// Set the limit for the stream + fn with_limit(mut self, limit: Option) -> Self { + self.limit = limit; + self + } + + fn build(self) -> StreamingTableExec { + StreamingTableExec::try_new( + self.schema.unwrap(), + self.partitions, + self.projection.as_ref(), + self.projected_output_ordering, + self.infinite, + self.limit, + ) + .unwrap() + } + } } diff --git a/datafusion/physical-plan/src/test.rs b/datafusion/physical-plan/src/test.rs index 9e6312284c08..90ec9b106850 100644 --- a/datafusion/physical-plan/src/test.rs +++ b/datafusion/physical-plan/src/test.rs @@ -23,9 +23,12 @@ use std::sync::Arc; use arrow_array::{ArrayRef, Int32Array, RecordBatch}; use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use futures::{Future, FutureExt}; use crate::memory::MemoryExec; +use crate::stream::RecordBatchStreamAdapter; +use crate::streaming::PartitionStream; use crate::ExecutionPlan; pub mod exec; @@ -62,7 +65,7 @@ pub fn aggr_test_schema() -> SchemaRef { Arc::new(schema) } -/// returns record batch with 3 columns of i32 in memory +/// Returns record batch with 3 columns of i32 in memory pub fn build_table_i32( a: (&str, &Vec), b: (&str, &Vec), @@ -85,7 +88,7 @@ pub fn build_table_i32( .unwrap() } -/// returns memory table scan wrapped around record batch with 3 columns of i32 +/// Returns memory table scan wrapped around record batch with 3 columns of i32 pub fn build_table_scan_i32( a: (&str, &Vec), b: (&str, &Vec), @@ -121,3 +124,30 @@ pub fn mem_exec(partitions: usize) -> MemoryExec { let projection = None; MemoryExec::try_new(&data, schema, projection).unwrap() } + +// Construct a stream partition for test purposes +#[derive(Debug)] +pub struct TestPartitionStream { + pub schema: SchemaRef, + pub batches: Vec, +} + +impl TestPartitionStream { + /// Create a new stream partition with the provided batches + pub fn new_with_batches(batches: Vec) -> Self { + let schema = batches[0].schema(); + Self { schema, batches } + } +} +impl PartitionStream for TestPartitionStream { + fn schema(&self) -> &SchemaRef { + &self.schema + } + fn execute(&self, _ctx: Arc) -> SendableRecordBatchStream { + let stream = futures::stream::iter(self.batches.clone().into_iter().map(Ok)); + Box::pin(RecordBatchStreamAdapter::new( + Arc::clone(&self.schema), + stream, + )) + } +} diff --git a/datafusion/physical-plan/src/test/exec.rs b/datafusion/physical-plan/src/test/exec.rs index b4f1eac0a655..cf1c0e313733 100644 --- a/datafusion/physical-plan/src/test/exec.rs +++ b/datafusion/physical-plan/src/test/exec.rs @@ -133,7 +133,7 @@ impl MockExec { /// ensure any poll loops are correct. This behavior can be /// changed with `with_use_task` pub fn new(data: Vec>, schema: SchemaRef) -> Self { - let cache = Self::compute_properties(schema.clone()); + let cache = Self::compute_properties(Arc::clone(&schema)); Self { data, schema, @@ -177,6 +177,10 @@ impl DisplayAs for MockExec { } impl ExecutionPlan for MockExec { + fn name(&self) -> &'static str { + Self::static_name() + } + fn as_any(&self) -> &dyn Any { self } @@ -185,7 +189,7 @@ impl ExecutionPlan for MockExec { &self.cache } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { vec![] } @@ -290,7 +294,7 @@ impl BarrierExec { pub fn new(data: Vec>, schema: SchemaRef) -> Self { // wait for all streams and the input let barrier = Arc::new(Barrier::new(data.len() + 1)); - let cache = Self::compute_properties(schema.clone(), &data); + let cache = Self::compute_properties(Arc::clone(&schema), &data); Self { data, schema, @@ -335,6 +339,10 @@ impl DisplayAs for BarrierExec { } impl ExecutionPlan for BarrierExec { + fn name(&self) -> &'static str { + Self::static_name() + } + fn as_any(&self) -> &dyn Any { self } @@ -343,7 +351,7 @@ impl ExecutionPlan for BarrierExec { &self.cache } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { unimplemented!() } @@ -366,7 +374,7 @@ impl ExecutionPlan for BarrierExec { // task simply sends data in order after barrier is reached let data = self.data[partition].clone(); - let b = self.barrier.clone(); + let b = Arc::clone(&self.barrier); let tx = builder.tx(); builder.spawn(async move { println!("Partition {partition} waiting on barrier"); @@ -413,7 +421,7 @@ impl ErrorExec { DataType::Int64, true, )])); - let cache = Self::compute_properties(schema.clone()); + let cache = Self::compute_properties(schema); Self { cache } } @@ -444,6 +452,10 @@ impl DisplayAs for ErrorExec { } impl ExecutionPlan for ErrorExec { + fn name(&self) -> &'static str { + Self::static_name() + } + fn as_any(&self) -> &dyn Any { self } @@ -452,7 +464,7 @@ impl ExecutionPlan for ErrorExec { &self.cache } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { unimplemented!() } @@ -527,6 +539,10 @@ impl DisplayAs for StatisticsExec { } impl ExecutionPlan for StatisticsExec { + fn name(&self) -> &'static str { + Self::static_name() + } + fn as_any(&self) -> &dyn Any { self } @@ -535,7 +551,7 @@ impl ExecutionPlan for StatisticsExec { &self.cache } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { vec![] } @@ -575,7 +591,7 @@ pub struct BlockingExec { impl BlockingExec { /// Create new [`BlockingExec`] with a give schema and number of partitions. pub fn new(schema: SchemaRef, n_partitions: usize) -> Self { - let cache = Self::compute_properties(schema.clone(), n_partitions); + let cache = Self::compute_properties(Arc::clone(&schema), n_partitions); Self { schema, refs: Default::default(), @@ -619,6 +635,10 @@ impl DisplayAs for BlockingExec { } impl ExecutionPlan for BlockingExec { + fn name(&self) -> &'static str { + Self::static_name() + } + fn as_any(&self) -> &dyn Any { self } @@ -627,7 +647,7 @@ impl ExecutionPlan for BlockingExec { &self.cache } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { // this is a leaf node and has no children vec![] } @@ -705,7 +725,7 @@ pub struct PanicExec { schema: SchemaRef, /// Number of output partitions. Each partition will produce this - /// many empty output record batches prior to panicing + /// many empty output record batches prior to panicking batches_until_panics: Vec, cache: PlanProperties, } @@ -715,7 +735,7 @@ impl PanicExec { /// partitions, which will each panic immediately. pub fn new(schema: SchemaRef, n_partitions: usize) -> Self { let batches_until_panics = vec![0; n_partitions]; - let cache = Self::compute_properties(schema.clone(), &batches_until_panics); + let cache = Self::compute_properties(Arc::clone(&schema), &batches_until_panics); Self { schema, batches_until_panics, @@ -760,6 +780,10 @@ impl DisplayAs for PanicExec { } impl ExecutionPlan for PanicExec { + fn name(&self) -> &'static str { + Self::static_name() + } + fn as_any(&self) -> &dyn Any { self } @@ -768,7 +792,7 @@ impl ExecutionPlan for PanicExec { &self.cache } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { // this is a leaf node and has no children vec![] } @@ -821,7 +845,7 @@ impl Stream for PanicStream { if self.ready { self.batches_until_panic -= 1; self.ready = false; - let batch = RecordBatch::new_empty(self.schema.clone()); + let batch = RecordBatch::new_empty(Arc::clone(&self.schema)); return Poll::Ready(Some(Ok(batch))); } else { self.ready = true; diff --git a/datafusion/physical-plan/src/topk/mod.rs b/datafusion/physical-plan/src/topk/mod.rs index 6a77bfaf3ccd..14469ab6c0d9 100644 --- a/datafusion/physical-plan/src/topk/mod.rs +++ b/datafusion/physical-plan/src/topk/mod.rs @@ -21,8 +21,10 @@ use arrow::{ compute::interleave, row::{RowConverter, Rows, SortField}, }; +use std::mem::size_of; use std::{cmp::Ordering, collections::BinaryHeap, sync::Arc}; +use crate::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream}; use arrow_array::{Array, ArrayRef, RecordBatch}; use arrow_schema::SchemaRef; use datafusion_common::Result; @@ -31,10 +33,9 @@ use datafusion_execution::{ runtime_env::RuntimeEnv, }; use datafusion_physical_expr::PhysicalSortExpr; +use datafusion_physical_expr_common::sort_expr::LexOrdering; use hashbrown::HashMap; -use crate::{stream::RecordBatchStreamAdapter, SendableRecordBatchStream}; - use super::metrics::{BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder}; /// Global TopK @@ -94,13 +95,13 @@ pub struct TopK { impl TopK { /// Create a new [`TopK`] that stores the top `k` values, as /// defined by the sort expressions in `expr`. - // TOOD: make a builder or some other nicer API to avoid the + // TODO: make a builder or some other nicer API to avoid the // clippy warning #[allow(clippy::too_many_arguments)] pub fn try_new( partition_id: usize, schema: SchemaRef, - expr: Vec, + expr: LexOrdering, k: usize, batch_size: usize, runtime: Arc, @@ -110,7 +111,7 @@ impl TopK { let reservation = MemoryConsumer::new(format!("TopK[{partition_id}]")) .register(&runtime.memory_pool); - let expr: Arc<[PhysicalSortExpr]> = expr.into(); + let expr: Arc<[PhysicalSortExpr]> = expr.inner.into(); let sort_fields: Vec<_> = expr .iter() @@ -131,7 +132,7 @@ impl TopK { ); Ok(Self { - schema: schema.clone(), + schema: Arc::clone(&schema), metrics: TopKMetrics::new(metrics, partition), reservation, batch_size, @@ -225,7 +226,7 @@ impl TopK { /// return the size of memory used by this operator, in bytes fn size(&self) -> usize { - std::mem::size_of::() + size_of::() + self.row_converter.size() + self.scratch_rows.size() + self.heap.size() @@ -258,7 +259,7 @@ impl TopKMetrics { /// Using the `Row` format handles things such as ascending vs /// descending and nulls first vs nulls last. struct TopKHeap { - /// The maximum number of elemenents to store in this heap. + /// The maximum number of elements to store in this heap. k: usize, /// The target number of rows for output batches batch_size: usize, @@ -355,7 +356,7 @@ impl TopKHeap { /// high, as a single [`RecordBatch`], and a sorted vec of the /// current heap's contents pub fn emit_with_state(&mut self) -> Result<(RecordBatch, Vec)> { - let schema = self.store.schema().clone(); + let schema = Arc::clone(self.store.schema()); // generate sorted rows let topk_rows = std::mem::take(&mut self.inner).into_sorted_vec(); @@ -421,7 +422,7 @@ impl TopKHeap { let num_rows = self.inner.len(); let (new_batch, mut topk_rows) = self.emit_with_state()?; - // clear all old entires in store (this invalidates all + // clear all old entries in store (this invalidates all // store_ids in `inner`) self.store.clear(); @@ -444,8 +445,8 @@ impl TopKHeap { /// return the size of memory used by this heap, in bytes fn size(&self) -> usize { - std::mem::size_of::() - + (self.inner.capacity() * std::mem::size_of::()) + size_of::() + + (self.inner.capacity() * size_of::()) + self.store.size() + self.owned_bytes } @@ -453,7 +454,7 @@ impl TopKHeap { /// Represents one of the top K rows held in this heap. Orders /// according to memcmp of row (e.g. the arrow Row format, but could -/// also be primtive values) +/// also be primitive values) /// /// Reuses allocations to minimize runtime overhead of creating new Vecs #[derive(Debug, PartialEq)] @@ -636,9 +637,8 @@ impl RecordBatchStore { /// returns the size of memory used by this store, including all /// referenced `RecordBatch`es, in bytes pub fn size(&self) -> usize { - std::mem::size_of::() - + self.batches.capacity() - * (std::mem::size_of::() + std::mem::size_of::()) + size_of::() + + self.batches.capacity() * (size_of::() + size_of::()) + self.batches_size } } diff --git a/datafusion/physical-plan/src/tree_node.rs b/datafusion/physical-plan/src/tree_node.rs index 46460cbb6684..96bd0de3d37c 100644 --- a/datafusion/physical-plan/src/tree_node.rs +++ b/datafusion/physical-plan/src/tree_node.rs @@ -26,7 +26,7 @@ use datafusion_common::tree_node::{ConcreteTreeNode, DynTreeNode}; use datafusion_common::Result; impl DynTreeNode for dyn ExecutionPlan { - fn arc_children(&self) -> Vec> { + fn arc_children(&self) -> Vec<&Arc> { self.children() } @@ -62,7 +62,7 @@ impl PlanContext { } pub fn update_plan_from_children(mut self) -> Result { - let children_plans = self.children.iter().map(|c| c.plan.clone()).collect(); + let children_plans = self.children.iter().map(|c| Arc::clone(&c.plan)).collect(); self.plan = with_new_children_if_necessary(self.plan, children_plans)?; Ok(self) @@ -71,7 +71,12 @@ impl PlanContext { impl PlanContext { pub fn new_default(plan: Arc) -> Self { - let children = plan.children().into_iter().map(Self::new_default).collect(); + let children = plan + .children() + .into_iter() + .cloned() + .map(Self::new_default) + .collect(); Self::new(plan, Default::default(), children) } } @@ -86,8 +91,8 @@ impl Display for PlanContext { } impl ConcreteTreeNode for PlanContext { - fn children(&self) -> Vec<&Self> { - self.children.iter().collect() + fn children(&self) -> &[Self] { + &self.children } fn take_children(mut self) -> (Self, Vec) { diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 1354644788ea..bd36753880eb 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -41,7 +41,7 @@ use arrow::record_batch::RecordBatch; use datafusion_common::stats::Precision; use datafusion_common::{exec_err, internal_err, Result}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::EquivalenceProperties; +use datafusion_physical_expr::{calculate_union, EquivalenceProperties}; use futures::Stream; use itertools::Itertools; @@ -85,7 +85,7 @@ use tokio::macros::support::thread_rng_n; /// │Input 1 │ │Input 2 │ /// └─────────────────┘ └──────────────────┘ /// ``` -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct UnionExec { /// Input execution plan inputs: Vec>, @@ -99,7 +99,12 @@ impl UnionExec { /// Create a new UnionExec pub fn new(inputs: Vec>) -> Self { let schema = union_schema(&inputs); - let cache = Self::compute_properties(&inputs, schema); + // The schema of the inputs and the union schema is consistent when: + // - They have the same number of fields, and + // - Their fields have same types at the same indices. + // Here, we know that schemas are consistent and the call below can + // not return an error. + let cache = Self::compute_properties(&inputs, schema).unwrap(); UnionExec { inputs, metrics: ExecutionPlanMetricsSet::new(), @@ -116,43 +121,13 @@ impl UnionExec { fn compute_properties( inputs: &[Arc], schema: SchemaRef, - ) -> PlanProperties { + ) -> Result { // Calculate equivalence properties: - // TODO: In some cases, we should be able to preserve some equivalence - // classes and constants. Add support for such cases. - let children_eqs = inputs + let children_eqps = inputs .iter() - .map(|child| child.equivalence_properties()) + .map(|child| child.equivalence_properties().clone()) .collect::>(); - let mut eq_properties = EquivalenceProperties::new(schema); - // Use the ordering equivalence class of the first child as the seed: - let mut meets = children_eqs[0] - .oeq_class() - .iter() - .map(|item| item.to_vec()) - .collect::>(); - // Iterate over all the children: - for child_eqs in &children_eqs[1..] { - // Compute meet orderings of the current meets and the new ordering - // equivalence class. - let mut idx = 0; - while idx < meets.len() { - // Find all the meets of `current_meet` with this child's orderings: - let valid_meets = child_eqs.oeq_class().iter().filter_map(|ordering| { - child_eqs.get_meet_ordering(ordering, &meets[idx]) - }); - // Use the longest of these meets as others are redundant: - if let Some(next_meet) = valid_meets.max_by_key(|m| m.len()) { - meets[idx] = next_meet; - idx += 1; - } else { - meets.swap_remove(idx); - } - } - } - // We know have all the valid orderings after union, remove redundant - // entries (implicitly) and return: - eq_properties.add_new_orderings(meets); + let eq_properties = calculate_union(children_eqps, schema)?; // Calculate output partitioning; i.e. sum output partitions of the inputs. let num_partitions = inputs @@ -164,7 +139,11 @@ impl UnionExec { // Determine execution mode: let mode = execution_mode_from_children(inputs.iter()); - PlanProperties::new(eq_properties, output_partitioning, mode) + Ok(PlanProperties::new( + eq_properties, + output_partitioning, + mode, + )) } } @@ -196,8 +175,8 @@ impl ExecutionPlan for UnionExec { &self.cache } - fn children(&self) -> Vec> { - self.inputs.clone() + fn children(&self) -> Vec<&Arc> { + self.inputs.iter().collect() } fn maintains_input_order(&self) -> Vec { @@ -281,6 +260,10 @@ impl ExecutionPlan for UnionExec { fn benefits_from_input_partitioning(&self) -> Vec { vec![false; self.children().len()] } + + fn supports_limit_pushdown(&self) -> bool { + true + } } /// Combines multiple input streams by interleaving them. @@ -315,7 +298,7 @@ impl ExecutionPlan for UnionExec { /// | |-----------------+ /// +---------+ /// ``` -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct InterleaveExec { /// Input execution plan inputs: Vec>, @@ -387,8 +370,8 @@ impl ExecutionPlan for InterleaveExec { &self.cache } - fn children(&self) -> Vec> { - self.inputs.clone() + fn children(&self) -> Vec<&Arc> { + self.inputs.iter().collect() } fn maintains_input_order(&self) -> Vec { @@ -399,6 +382,12 @@ impl ExecutionPlan for InterleaveExec { self: Arc, children: Vec>, ) -> Result> { + // New children are no longer interleavable, which might be a bug of optimization rewrite. + if !can_interleave(children.iter()) { + return internal_err!( + "Can not create InterleaveExec: new children can not be interleaved" + ); + } Ok(Arc::new(InterleaveExec::try_new(children)?)) } @@ -417,7 +406,7 @@ impl ExecutionPlan for InterleaveExec { let mut input_stream_vec = vec![]; for input in self.inputs.iter() { if partition < input.output_partitioning().partition_count() { - input_stream_vec.push(input.execute(partition, context.clone())?); + input_stream_vec.push(input.execute(partition, Arc::clone(&context))?); } else { // Do not find a partition to execute break; @@ -479,26 +468,41 @@ pub fn can_interleave>>( } fn union_schema(inputs: &[Arc]) -> SchemaRef { - let fields: Vec = (0..inputs[0].schema().fields().len()) + let first_schema = inputs[0].schema(); + + let fields = (0..first_schema.fields().len()) .map(|i| { inputs .iter() - .filter_map(|input| { - if input.schema().fields().len() > i { - Some(input.schema().field(i).clone()) - } else { - None - } + .enumerate() + .map(|(input_idx, input)| { + let field = input.schema().field(i).clone(); + let mut metadata = field.metadata().clone(); + + let other_metadatas = inputs + .iter() + .enumerate() + .filter(|(other_idx, _)| *other_idx != input_idx) + .flat_map(|(_, other_input)| { + other_input.schema().field(i).metadata().clone().into_iter() + }); + + metadata.extend(other_metadatas); + field.with_metadata(metadata) }) - .find_or_first(|f| f.is_nullable()) + .find_or_first(Field::is_nullable) + // We can unwrap this because if inputs was empty, this would've already panic'ed when we + // indexed into inputs[0]. .unwrap() }) + .collect::>(); + + let all_metadata_merged = inputs + .iter() + .flat_map(|i| i.schema().metadata().clone().into_iter()) .collect(); - Arc::new(Schema::new_with_metadata( - fields, - inputs[0].schema().metadata().clone(), - )) + Arc::new(Schema::new_with_metadata(fields, all_metadata_merged)) } /// CombinedRecordBatchStream can be used to combine a Vec of SendableRecordBatchStreams into one @@ -518,7 +522,7 @@ impl CombinedRecordBatchStream { impl RecordBatchStream for CombinedRecordBatchStream { fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -603,6 +607,7 @@ mod tests { use datafusion_common::ScalarValue; use datafusion_physical_expr::expressions::col; use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; + use datafusion_physical_expr_common::sort_expr::LexOrdering; // Generate a schema which consists of 7 columns (a, b, c, d, e, f, g) fn create_test_schema() -> Result { @@ -621,14 +626,14 @@ mod tests { // Convert each tuple to PhysicalSortExpr fn convert_to_sort_exprs( in_data: &[(&Arc, SortOptions)], - ) -> Vec { + ) -> LexOrdering { in_data .iter() .map(|(expr, options)| PhysicalSortExpr { - expr: (*expr).clone(), + expr: Arc::clone(*expr), options: *options, }) - .collect::>() + .collect::() } #[tokio::test] @@ -810,31 +815,39 @@ mod tests { .map(|ordering| convert_to_sort_exprs(ordering)) .collect::>(); let child1 = Arc::new( - MemoryExec::try_new(&[], schema.clone(), None)? - .with_sort_information(first_orderings), + MemoryExec::try_new(&[], Arc::clone(&schema), None)? + .try_with_sort_information(first_orderings)?, ); let child2 = Arc::new( - MemoryExec::try_new(&[], schema.clone(), None)? - .with_sort_information(second_orderings), + MemoryExec::try_new(&[], Arc::clone(&schema), None)? + .try_with_sort_information(second_orderings)?, ); + let mut union_expected_eq = EquivalenceProperties::new(Arc::clone(&schema)); + union_expected_eq.add_new_orderings(union_expected_orderings); + let union = UnionExec::new(vec![child1, child2]); let union_eq_properties = union.properties().equivalence_properties(); - let union_actual_orderings = union_eq_properties.oeq_class(); let err_msg = format!( "Error in test id: {:?}, test case: {:?}", test_idx, test_cases[test_idx] ); - assert_eq!( - union_actual_orderings.len(), - union_expected_orderings.len(), - "{}", - err_msg - ); - for expected in &union_expected_orderings { - assert!(union_actual_orderings.contains(expected), "{}", err_msg); - } + assert_eq_properties_same(union_eq_properties, &union_expected_eq, err_msg); } Ok(()) } + + fn assert_eq_properties_same( + lhs: &EquivalenceProperties, + rhs: &EquivalenceProperties, + err_msg: String, + ) { + // Check whether orderings are same. + let lhs_orderings = lhs.oeq_class(); + let rhs_orderings = &rhs.oeq_class.orderings; + assert_eq!(lhs_orderings.len(), rhs_orderings.len(), "{}", err_msg); + for rhs_ordering in rhs_orderings { + assert!(lhs_orderings.contains(rhs_ordering), "{}", err_msg); + } + } } diff --git a/datafusion/physical-plan/src/unnest.rs b/datafusion/physical-plan/src/unnest.rs index 06dd8230d39e..b7b9f17eb1b6 100644 --- a/datafusion/physical-plan/src/unnest.rs +++ b/datafusion/physical-plan/src/unnest.rs @@ -17,14 +17,15 @@ //! Define a plan for unnesting values in columns that contain a list type. +use std::cmp::{self, Ordering}; use std::collections::HashMap; use std::{any::Any, sync::Arc}; use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; use super::{DisplayAs, ExecutionPlanProperties, PlanProperties}; use crate::{ - expressions::Column, DisplayFormatType, Distribution, ExecutionPlan, PhysicalExpr, - RecordBatchStream, SendableRecordBatchStream, + DisplayFormatType, Distribution, ExecutionPlan, RecordBatchStream, + SendableRecordBatchStream, }; use arrow::array::{ @@ -36,28 +37,35 @@ use arrow::compute::kernels::zip::zip; use arrow::compute::{cast, is_not_null, kernels, sum}; use arrow::datatypes::{DataType, Int64Type, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; -use arrow_array::{Int64Array, Scalar}; +use arrow_array::{Int64Array, Scalar, StructArray}; use arrow_ord::cmp::lt; -use datafusion_common::{exec_datafusion_err, exec_err, Result, UnnestOptions}; +use datafusion_common::{ + exec_datafusion_err, exec_err, internal_err, Result, UnnestOptions, +}; use datafusion_execution::TaskContext; use datafusion_physical_expr::EquivalenceProperties; use async_trait::async_trait; use futures::{Stream, StreamExt}; +use hashbrown::HashSet; use log::trace; -/// Unnest the given columns by joining the row with each value in the -/// nested type. +/// Unnest the given columns (either with type struct or list) +/// For list unnesting, each rows is vertically transformed into multiple rows +/// For struct unnesting, each columns is horizontally transformed into multiple columns, +/// Thus the original RecordBatch with dimension (n x m) may have new dimension (n' x m') /// /// See [`UnnestOptions`] for more details and an example. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct UnnestExec { /// Input execution plan input: Arc, /// The schema once the unnest is applied schema: SchemaRef, - /// The unnest columns - columns: Vec, + /// Indices of the list-typed columns in the input schema + list_column_indices: Vec, + /// Indices of the struct-typed columns in the input schema + struct_column_indices: Vec, /// Options options: UnnestOptions, /// Execution metrics @@ -70,15 +78,18 @@ impl UnnestExec { /// Create a new [UnnestExec]. pub fn new( input: Arc, - columns: Vec, + list_column_indices: Vec, + struct_column_indices: Vec, schema: SchemaRef, options: UnnestOptions, ) -> Self { - let cache = Self::compute_properties(&input, schema.clone()); + let cache = Self::compute_properties(&input, Arc::clone(&schema)); + UnnestExec { input, schema, - columns, + list_column_indices, + struct_column_indices, options, metrics: Default::default(), cache, @@ -98,6 +109,25 @@ impl UnnestExec { input.execution_mode(), ) } + + /// Input execution plan + pub fn input(&self) -> &Arc { + &self.input + } + + /// Indices of the list-typed columns in the input schema + pub fn list_column_indices(&self) -> &[ListUnnest] { + &self.list_column_indices + } + + /// Indices of the struct-typed columns in the input schema + pub fn struct_column_indices(&self) -> &[usize] { + &self.struct_column_indices + } + + pub fn options(&self) -> &UnnestOptions { + &self.options + } } impl DisplayAs for UnnestExec { @@ -127,8 +157,8 @@ impl ExecutionPlan for UnnestExec { &self.cache } - fn children(&self) -> Vec> { - vec![self.input.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.input] } fn with_new_children( @@ -136,9 +166,10 @@ impl ExecutionPlan for UnnestExec { children: Vec>, ) -> Result> { Ok(Arc::new(UnnestExec::new( - children[0].clone(), - self.columns.clone(), - self.schema.clone(), + Arc::clone(&children[0]), + self.list_column_indices.clone(), + self.struct_column_indices.clone(), + Arc::clone(&self.schema), self.options.clone(), ))) } @@ -157,8 +188,9 @@ impl ExecutionPlan for UnnestExec { Ok(Box::pin(UnnestStream { input, - schema: self.schema.clone(), - columns: self.columns.clone(), + schema: Arc::clone(&self.schema), + list_type_columns: self.list_column_indices.clone(), + struct_column_indices: self.struct_column_indices.iter().copied().collect(), options: self.options.clone(), metrics, })) @@ -171,7 +203,7 @@ impl ExecutionPlan for UnnestExec { #[derive(Clone, Debug)] struct UnnestMetrics { - /// total time for column unnesting + /// Total time for column unnesting elapsed_compute: metrics::Time, /// Number of batches consumed input_batches: metrics::Count, @@ -213,8 +245,11 @@ struct UnnestStream { input: SendableRecordBatchStream, /// Unnested schema schema: Arc, - /// The unnest columns - columns: Vec, + /// represents all unnest operations to be applied to the input (input index, depth) + /// e.g unnest(col1),unnest(unnest(col1)) where col1 has index 1 in original input schema + /// then list_type_columns = [ListUnnest{1,1},ListUnnest{1,2}] + list_type_columns: Vec, + struct_column_indices: HashSet, /// Options options: UnnestOptions, /// Metrics @@ -223,7 +258,7 @@ struct UnnestStream { impl RecordBatchStream for UnnestStream { fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -251,8 +286,13 @@ impl UnnestStream { .map(|maybe_batch| match maybe_batch { Some(Ok(batch)) => { let timer = self.metrics.elapsed_compute.timer(); - let result = - build_batch(&batch, &self.schema, &self.columns, &self.options); + let result = build_batch( + &batch, + &self.schema, + &self.list_type_columns, + &self.struct_column_indices, + &self.options, + ); self.metrics.input_batches.add(1); self.metrics.input_rows.add(batch.num_rows()); if let Ok(ref batch) = result { @@ -279,23 +319,124 @@ impl UnnestStream { } } -/// For each row in a `RecordBatch`, some list columns need to be unnested. -/// We will expand the values in each list into multiple rows, -/// taking the longest length among these lists, and shorter lists are padded with NULLs. -// -/// For columns that don't need to be unnested, repeat their values until reaching the longest length. -fn build_batch( - batch: &RecordBatch, +/// Given a set of struct column indices to flatten +/// try converting the column in input into multiple subfield columns +/// For example +/// struct_col: [a: struct(item: int, name: string), b: int] +/// with a batch +/// {a: {item: 1, name: "a"}, b: 2}, +/// {a: {item: 3, name: "b"}, b: 4] +/// will be converted into +/// {a.item: 1, a.name: "a", b: 2}, +/// {a.item: 3, a.name: "b", b: 4} +fn flatten_struct_cols( + input_batch: &[Arc], schema: &SchemaRef, - columns: &[Column], - options: &UnnestOptions, + struct_column_indices: &HashSet, ) -> Result { - let list_arrays: Vec = columns + // horizontal expansion because of struct unnest + let columns_expanded = input_batch .iter() - .map(|column| column.evaluate(batch)?.into_array(batch.num_rows())) - .collect::>()?; + .enumerate() + .map(|(idx, column_data)| match struct_column_indices.get(&idx) { + Some(_) => match column_data.data_type() { + DataType::Struct(_) => { + let struct_arr = + column_data.as_any().downcast_ref::().unwrap(); + Ok(struct_arr.columns().to_vec()) + } + data_type => internal_err!( + "expecting column {} from input plan to be a struct, got {:?}", + idx, + data_type + ), + }, + None => Ok(vec![Arc::clone(column_data)]), + }) + .collect::>>()? + .into_iter() + .flatten() + .collect(); + Ok(RecordBatch::try_new(Arc::clone(schema), columns_expanded)?) +} - let longest_length = find_longest_length(&list_arrays, options)?; +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +pub struct ListUnnest { + pub index_in_input_schema: usize, + pub depth: usize, +} + +/// This function is used to execute the unnesting on multiple columns all at once, but +/// one level at a time, and is called n times, where n is the highest recursion level among +/// the unnest exprs in the query. +/// +/// For example giving the following query: +/// ```sql +/// select unnest(colA, max_depth:=3) as P1, unnest(colA,max_depth:=2) as P2, unnest(colB, max_depth:=1) as P3 from temp; +/// ``` +/// Then the total times this function being called is 3 +/// +/// It needs to be aware of which level the current unnesting is, because if there exists +/// multiple unnesting on the same column, but with different recursion levels, say +/// **unnest(colA, max_depth:=3)** and **unnest(colA, max_depth:=2)**, then the unnesting +/// of expr **unnest(colA, max_depth:=3)** will start at level 3, while unnesting for expr +/// **unnest(colA, max_depth:=2)** has to start at level 2 +/// +/// Set *colA* as a 3-dimension columns and *colB* as an array (1-dimension). As stated, +/// this function is called with the descending order of recursion depth +/// +/// Depth = 3 +/// - colA(3-dimension) unnest into temp column temp_P1(2_dimension) (unnesting of P1 starts +/// from this level) +/// - colA(3-dimension) having indices repeated by the unnesting operation above +/// - colB(1-dimension) having indices repeated by the unnesting operation above +/// +/// Depth = 2 +/// - temp_P1(2-dimension) unnest into temp column temp_P1(1-dimension) +/// - colA(3-dimension) unnest into temp column temp_P2(2-dimension) (unnesting of P2 starts +/// from this level) +/// - colB(1-dimension) having indices repeated by the unnesting operation above +/// +/// Depth = 1 +/// - temp_P1(1-dimension) unnest into P1 +/// - temp_P2(2-dimension) unnest into P2 +/// - colB(1-dimension) unnest into P3 (unnesting of P3 starts from this level) +/// +/// The returned array will has the same size as the input batch +/// and only contains original columns that are not being unnested. +fn list_unnest_at_level( + batch: &[ArrayRef], + list_type_unnests: &[ListUnnest], + temp_unnested_arrs: &mut HashMap, + level_to_unnest: usize, + options: &UnnestOptions, +) -> Result<(Vec, usize)> { + // Extract unnestable columns at this level + let (arrs_to_unnest, list_unnest_specs): (Vec>, Vec<_>) = + list_type_unnests + .iter() + .filter_map(|unnesting| { + if level_to_unnest == unnesting.depth { + return Some(( + Arc::clone(&batch[unnesting.index_in_input_schema]), + *unnesting, + )); + } + // This means the unnesting on this item has started at higher level + // and need to continue until depth reaches 1 + if level_to_unnest < unnesting.depth { + return Some(( + Arc::clone(temp_unnested_arrs.get(unnesting).unwrap()), + *unnesting, + )); + } + None + }) + .unzip(); + + // Filter out so that list_arrays only contain column with the highest depth + // at the same time, during iteration remove this depth so next time we don't have to unnest them again + let longest_length = find_longest_length(&arrs_to_unnest, options)?; let unnested_length = longest_length.as_primitive::(); let total_length = if unnested_length.is_empty() { 0 @@ -305,22 +446,210 @@ fn build_batch( })? as usize }; if total_length == 0 { - return Ok(RecordBatch::new_empty(schema.clone())); + return Ok((vec![], 0)); } // Unnest all the list arrays - let unnested_arrays = - unnest_list_arrays(&list_arrays, unnested_length, total_length)?; - let unnested_array_map: HashMap<_, _> = unnested_arrays - .into_iter() - .zip(columns.iter()) - .map(|(array, column)| (column.index(), array)) - .collect(); + let unnested_temp_arrays = + unnest_list_arrays(arrs_to_unnest.as_ref(), unnested_length, total_length)?; // Create the take indices array for other columns - let take_indicies = create_take_indicies(unnested_length, total_length); + let take_indices = create_take_indicies(unnested_length, total_length); - batch_from_indices(batch, schema, &unnested_array_map, &take_indicies) + // Dimension of arrays in batch is untouched, but the values are repeated + // as the side effect of unnesting + let ret = repeat_arrs_from_indices(batch, &take_indices)?; + unnested_temp_arrays + .into_iter() + .zip(list_unnest_specs.iter()) + .for_each(|(flatten_arr, unnesting)| { + temp_unnested_arrs.insert(*unnesting, flatten_arr); + }); + Ok((ret, total_length)) +} +struct UnnestingResult { + arr: ArrayRef, + depth: usize, +} + +/// For each row in a `RecordBatch`, some list/struct columns need to be unnested. +/// - For list columns: We will expand the values in each list into multiple rows, +/// taking the longest length among these lists, and shorter lists are padded with NULLs. +/// - For struct columns: We will expand the struct columns into multiple subfield columns. +/// +/// For columns that don't need to be unnested, repeat their values until reaching the longest length. +/// +/// Note: unnest has a big difference in behavior between Postgres and DuckDB +/// +/// Take this example +/// +/// 1. Postgres +/// ```ignored +/// create table temp ( +/// i integer[][][], j integer[] +/// ) +/// insert into temp values ('{{{1,2},{3,4}},{{5,6},{7,8}}}', '{1,2}'); +/// select unnest(i), unnest(j) from temp; +/// ``` +/// +/// Result +/// ```text +/// 1 1 +/// 2 2 +/// 3 +/// 4 +/// 5 +/// 6 +/// 7 +/// 8 +/// ``` +/// 2. DuckDB +/// ```ignore +/// create table temp (i integer[][][], j integer[]); +/// insert into temp values ([[[1,2],[3,4]],[[5,6],[7,8]]], [1,2]); +/// select unnest(i,recursive:=true), unnest(j,recursive:=true) from temp; +/// ``` +/// Result: +/// ```text +/// +/// ┌────────────────────────────────────────────────┬────────────────────────────────────────────────┐ +/// │ unnest(i, "recursive" := CAST('t' AS BOOLEAN)) │ unnest(j, "recursive" := CAST('t' AS BOOLEAN)) │ +/// │ int32 │ int32 │ +/// ├────────────────────────────────────────────────┼────────────────────────────────────────────────┤ +/// │ 1 │ 1 │ +/// │ 2 │ 2 │ +/// │ 3 │ 1 │ +/// │ 4 │ 2 │ +/// │ 5 │ 1 │ +/// │ 6 │ 2 │ +/// │ 7 │ 1 │ +/// │ 8 │ 2 │ +/// └────────────────────────────────────────────────┴────────────────────────────────────────────────┘ +/// ``` +/// +/// The following implementation refer to DuckDB's implementation +fn build_batch( + batch: &RecordBatch, + schema: &SchemaRef, + list_type_columns: &[ListUnnest], + struct_column_indices: &HashSet, + options: &UnnestOptions, +) -> Result { + let transformed = match list_type_columns.len() { + 0 => flatten_struct_cols(batch.columns(), schema, struct_column_indices), + _ => { + let mut temp_unnested_result = HashMap::new(); + let max_recursion = list_type_columns + .iter() + .fold(0, |highest_depth, ListUnnest { depth, .. }| { + cmp::max(highest_depth, *depth) + }); + + // This arr always has the same column count with the input batch + let mut flatten_arrs = vec![]; + + // Original batch has the same columns + // All unnesting results are written to temp_batch + for depth in (1..=max_recursion).rev() { + let input = match depth == max_recursion { + true => batch.columns(), + false => &flatten_arrs, + }; + let (temp_result, num_rows) = list_unnest_at_level( + input, + list_type_columns, + &mut temp_unnested_result, + depth, + options, + )?; + if num_rows == 0 { + return Ok(RecordBatch::new_empty(Arc::clone(schema))); + } + flatten_arrs = temp_result; + } + let unnested_array_map: HashMap> = + temp_unnested_result.into_iter().fold( + HashMap::new(), + |mut acc, + ( + ListUnnest { + index_in_input_schema, + depth, + }, + flattened_array, + )| { + acc.entry(index_in_input_schema).or_default().push( + UnnestingResult { + arr: flattened_array, + depth, + }, + ); + acc + }, + ); + let output_order: HashMap = list_type_columns + .iter() + .enumerate() + .map(|(order, unnest_def)| (*unnest_def, order)) + .collect(); + + // One original column may be unnested multiple times into separate columns + let mut multi_unnested_per_original_index = unnested_array_map + .into_iter() + .map( + // Each item in unnested_columns is the result of unnesting the same input column + // we need to sort them to conform with the original expression order + // e.g unnest(unnest(col)) must goes before unnest(col) + |(original_index, mut unnested_columns)| { + unnested_columns.sort_by( + |UnnestingResult { depth: depth1, .. }, + UnnestingResult { depth: depth2, .. }| + -> Ordering { + output_order + .get(&ListUnnest { + depth: *depth1, + index_in_input_schema: original_index, + }) + .unwrap() + .cmp( + output_order + .get(&ListUnnest { + depth: *depth2, + index_in_input_schema: original_index, + }) + .unwrap(), + ) + }, + ); + ( + original_index, + unnested_columns + .into_iter() + .map(|result| result.arr) + .collect::>(), + ) + }, + ) + .collect::>(); + + let ret = flatten_arrs + .into_iter() + .enumerate() + .flat_map(|(col_idx, arr)| { + // Convert original column into its unnested version(s) + // Plural because one column can be unnested with different recursion level + // and into separate output columns + match multi_unnested_per_original_index.remove(&col_idx) { + Some(unnested_arrays) => unnested_arrays, + None => vec![arr], + } + }) + .collect::>(); + + flatten_struct_cols(&ret, schema, struct_column_indices) + } + }; + transformed } /// Find the longest list length among the given list arrays for each row. @@ -368,7 +697,7 @@ fn find_longest_length( .collect::>()?; let longest_length = list_lengths.iter().skip(1).try_fold( - list_lengths[0].clone(), + Arc::clone(&list_lengths[0]), |longest, current| { let is_lt = lt(&longest, ¤t)?; zip(&is_lt, ¤t, &longest) @@ -439,16 +768,10 @@ fn unnest_list_arrays( }) .collect::>>()?; - // If there is only one list column to unnest and it doesn't contain any NULL lists, - // we can return the values array directly without any copying. - if typed_arrays.len() == 1 && typed_arrays[0].null_count() == 0 { - Ok(vec![typed_arrays[0].values().clone()]) - } else { - typed_arrays - .iter() - .map(|list_array| unnest_list_array(*list_array, length_array, capacity)) - .collect::>() - } + typed_arrays + .iter() + .map(|list_array| unnest_list_array(*list_array, length_array, capacity)) + .collect::>() } /// Unnest a list array according the target length array. @@ -505,7 +828,8 @@ fn unnest_list_array( )?) } -/// Creates take indicies that will be used to expand all columns except for the unnest [`columns`](UnnestExec::columns). +/// Creates take indicies that will be used to expand all columns except for the list type +/// [`columns`](UnnestExec::list_column_indices) that is being unnested. /// Every column value needs to be repeated multiple times according to the length array. /// /// If the length array looks like this: @@ -537,10 +861,10 @@ fn create_take_indicies( builder.finish() } -/// Create the final batch given the unnested column arrays and a `indices` array +/// Create the batch given an arrays and a `indices` array /// that is used by the take kernel to copy values. /// -/// For example if we have the following `RecordBatch`: +/// For example if we have the following batch: /// /// ```ignore /// c1: [1], null, [2, 3, 4], null, [5, 6] @@ -568,31 +892,23 @@ fn create_take_indicies( /// c2: 'a', 'b', 'c', 'c', 'c', null, 'd', 'd' /// ``` /// -fn batch_from_indices( - batch: &RecordBatch, - schema: &SchemaRef, - unnested_list_arrays: &HashMap, +fn repeat_arrs_from_indices( + batch: &[ArrayRef], indices: &PrimitiveArray, -) -> Result { - let arrays = batch - .columns() +) -> Result>> { + batch .iter() - .enumerate() - .map(|(col_idx, arr)| match unnested_list_arrays.get(&col_idx) { - Some(unnested_array) => Ok(unnested_array.clone()), - None => Ok(kernels::take::take(arr, indices, None)?), - }) - .collect::>>()?; - - Ok(RecordBatch::try_new(schema.clone(), arrays.to_vec())?) + .map(|arr| Ok(kernels::take::take(arr, indices, None)?)) + .collect::>() } #[cfg(test)] mod tests { use super::*; - use arrow::datatypes::Field; + use arrow::datatypes::{Field, Int32Type}; use arrow_array::{GenericListArray, OffsetSizeTrait, StringArray}; use arrow_buffer::{BooleanBufferBuilder, NullBuffer, OffsetBuffer}; + use datafusion_common::assert_batches_eq; // Create a GenericListArray with the following list values: // [A, B, C], [], NULL, [D], NULL, [NULL, F] @@ -668,7 +984,7 @@ mod tests { list_array: &dyn ListArrayType, lengths: Vec, expected: Vec>, - ) -> datafusion_common::Result<()> { + ) -> Result<()> { let length_array = Int64Array::from(lengths); let unnested_array = unnest_list_array(list_array, &length_array, 3 * 6)?; let strs = unnested_array.as_string::().iter().collect::>(); @@ -677,7 +993,139 @@ mod tests { } #[test] - fn test_unnest_list_array() -> datafusion_common::Result<()> { + fn test_build_batch_list_arr_recursive() -> Result<()> { + // col1 | col2 + // [[1,2,3],null,[4,5]] | ['a','b'] + // [[7,8,9,10], null, [11,12,13]] | ['c','d'] + // null | ['e'] + let list_arr1 = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5)]), + Some(vec![Some(7), Some(8), Some(9), Some(10)]), + None, + Some(vec![Some(11), Some(12), Some(13)]), + ]); + + let list_arr1_ref = Arc::new(list_arr1) as ArrayRef; + let offsets = OffsetBuffer::from_lengths([3, 3, 0]); + let mut nulls = BooleanBufferBuilder::new(3); + nulls.append(true); + nulls.append(true); + nulls.append(false); + // list> + let col1_field = Field::new_list_field( + DataType::List(Arc::new(Field::new_list_field( + list_arr1_ref.data_type().to_owned(), + true, + ))), + true, + ); + let col1 = ListArray::new( + Arc::new(Field::new_list_field( + list_arr1_ref.data_type().to_owned(), + true, + )), + offsets, + list_arr1_ref, + Some(NullBuffer::new(nulls.finish())), + ); + + let list_arr2 = StringArray::from(vec![ + Some("a"), + Some("b"), + Some("c"), + Some("d"), + Some("e"), + ]); + + let offsets = OffsetBuffer::from_lengths([2, 2, 1]); + let mut nulls = BooleanBufferBuilder::new(3); + nulls.append_n(3, true); + let col2_field = Field::new( + "col2", + DataType::List(Arc::new(Field::new_list_field(DataType::Utf8, true))), + true, + ); + let col2 = GenericListArray::::new( + Arc::new(Field::new_list_field(DataType::Utf8, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(list_arr2), + Some(NullBuffer::new(nulls.finish())), + ); + // convert col1 and col2 to a record batch + let schema = Arc::new(Schema::new(vec![col1_field, col2_field])); + let out_schema = Arc::new(Schema::new(vec![ + Field::new( + "col1_unnest_placeholder_depth_1", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + ), + Field::new("col1_unnest_placeholder_depth_2", DataType::Int32, true), + Field::new("col2_unnest_placeholder_depth_1", DataType::Utf8, true), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(col1) as ArrayRef, Arc::new(col2) as ArrayRef], + ) + .unwrap(); + let list_type_columns = vec![ + ListUnnest { + index_in_input_schema: 0, + depth: 1, + }, + ListUnnest { + index_in_input_schema: 0, + depth: 2, + }, + ListUnnest { + index_in_input_schema: 1, + depth: 1, + }, + ]; + let ret = build_batch( + &batch, + &out_schema, + list_type_columns.as_ref(), + &HashSet::default(), + &UnnestOptions { + preserve_nulls: true, + recursions: vec![], + }, + )?; + + let expected = &[ +"+---------------------------------+---------------------------------+---------------------------------+", +"| col1_unnest_placeholder_depth_1 | col1_unnest_placeholder_depth_2 | col2_unnest_placeholder_depth_1 |", +"+---------------------------------+---------------------------------+---------------------------------+", +"| [1, 2, 3] | 1 | a |", +"| | 2 | b |", +"| [4, 5] | 3 | |", +"| [1, 2, 3] | | a |", +"| | | b |", +"| [4, 5] | | |", +"| [1, 2, 3] | 4 | a |", +"| | 5 | b |", +"| [4, 5] | | |", +"| [7, 8, 9, 10] | 7 | c |", +"| | 8 | d |", +"| [11, 12, 13] | 9 | |", +"| | 10 | |", +"| [7, 8, 9, 10] | | c |", +"| | | d |", +"| [11, 12, 13] | | |", +"| [7, 8, 9, 10] | 11 | c |", +"| | 12 | d |", +"| [11, 12, 13] | 13 | |", +"| | | e |", +"+---------------------------------+---------------------------------+---------------------------------+", + ]; + assert_batches_eq!(expected, &[ret]); + Ok(()) + } + + #[test] + fn test_unnest_list_array() -> Result<()> { // [A, B, C], [], NULL, [D], NULL, [NULL, F] let list_array = make_generic_array::(); verify_unnest_list_array( @@ -725,8 +1173,11 @@ mod tests { list_arrays: &[ArrayRef], preserve_nulls: bool, expected: Vec, - ) -> datafusion_common::Result<()> { - let options = UnnestOptions { preserve_nulls }; + ) -> Result<()> { + let options = UnnestOptions { + preserve_nulls, + recursions: vec![], + }; let longest_length = find_longest_length(list_arrays, &options)?; let expected_array = Int64Array::from(expected); assert_eq!( @@ -740,31 +1191,31 @@ mod tests { } #[test] - fn test_longest_list_length() -> datafusion_common::Result<()> { + fn test_longest_list_length() -> Result<()> { // Test with single ListArray // [A, B, C], [], NULL, [D], NULL, [NULL, F] let list_array = Arc::new(make_generic_array::()) as ArrayRef; - verify_longest_length(&[list_array.clone()], false, vec![3, 0, 0, 1, 0, 2])?; - verify_longest_length(&[list_array.clone()], true, vec![3, 0, 1, 1, 1, 2])?; + verify_longest_length(&[Arc::clone(&list_array)], false, vec![3, 0, 0, 1, 0, 2])?; + verify_longest_length(&[Arc::clone(&list_array)], true, vec![3, 0, 1, 1, 1, 2])?; // Test with single LargeListArray // [A, B, C], [], NULL, [D], NULL, [NULL, F] let list_array = Arc::new(make_generic_array::()) as ArrayRef; - verify_longest_length(&[list_array.clone()], false, vec![3, 0, 0, 1, 0, 2])?; - verify_longest_length(&[list_array.clone()], true, vec![3, 0, 1, 1, 1, 2])?; + verify_longest_length(&[Arc::clone(&list_array)], false, vec![3, 0, 0, 1, 0, 2])?; + verify_longest_length(&[Arc::clone(&list_array)], true, vec![3, 0, 1, 1, 1, 2])?; // Test with single FixedSizeListArray // [A, B], NULL, [C, D], NULL, [NULL, F], [NULL, NULL] let list_array = Arc::new(make_fixed_list()) as ArrayRef; - verify_longest_length(&[list_array.clone()], false, vec![2, 0, 2, 0, 2, 2])?; - verify_longest_length(&[list_array.clone()], true, vec![2, 1, 2, 1, 2, 2])?; + verify_longest_length(&[Arc::clone(&list_array)], false, vec![2, 0, 2, 0, 2, 2])?; + verify_longest_length(&[Arc::clone(&list_array)], true, vec![2, 1, 2, 1, 2, 2])?; // Test with multiple list arrays // [A, B, C], [], NULL, [D], NULL, [NULL, F] // [A, B], NULL, [C, D], NULL, [NULL, F], [NULL, NULL] let list1 = Arc::new(make_generic_array::()) as ArrayRef; let list2 = Arc::new(make_fixed_list()) as ArrayRef; - let list_arrays = vec![list1.clone(), list2.clone()]; + let list_arrays = vec![Arc::clone(&list1), Arc::clone(&list2)]; verify_longest_length(&list_arrays, false, vec![3, 0, 2, 1, 2, 2])?; verify_longest_length(&list_arrays, true, vec![3, 1, 2, 1, 2, 2])?; @@ -772,7 +1223,7 @@ mod tests { } #[test] - fn test_create_take_indicies() -> datafusion_common::Result<()> { + fn test_create_take_indicies() -> Result<()> { let length_array = Int64Array::from(vec![2, 3, 1]); let take_indicies = create_take_indicies(&length_array, 6); let expected = Int64Array::from(vec![0, 0, 1, 1, 1, 2]); diff --git a/datafusion/physical-plan/src/values.rs b/datafusion/physical-plan/src/values.rs index 2aa893fd2916..edadf98cb10c 100644 --- a/datafusion/physical-plan/src/values.rs +++ b/datafusion/physical-plan/src/values.rs @@ -36,7 +36,7 @@ use datafusion_execution::TaskContext; use datafusion_physical_expr::EquivalenceProperties; /// Execution plan for values list based relation (produces constant rows) -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ValuesExec { /// The schema schema: SchemaRef, @@ -47,7 +47,7 @@ pub struct ValuesExec { } impl ValuesExec { - /// create a new values exec from data as expr + /// Create a new values exec from data as expr pub fn try_new( schema: SchemaRef, data: Vec>>, @@ -57,7 +57,7 @@ impl ValuesExec { } let n_row = data.len(); let n_col = schema.fields().len(); - // we have this single row batch as a placeholder to satisfy evaluation argument + // We have this single row batch as a placeholder to satisfy evaluation argument // and generate a single output row let batch = RecordBatch::try_new_with_options( Arc::new(Schema::empty()), @@ -88,7 +88,11 @@ impl ValuesExec { .and_then(ScalarValue::iter_to_array) }) .collect::>>()?; - let batch = RecordBatch::try_new(schema.clone(), arr)?; + let batch = RecordBatch::try_new_with_options( + Arc::clone(&schema), + arr, + &RecordBatchOptions::new().with_row_count(Some(n_row)), + )?; let data: Vec = vec![batch]; Self::try_new_from_batches(schema, data) } @@ -114,7 +118,7 @@ impl ValuesExec { } } - let cache = Self::compute_properties(schema.clone()); + let cache = Self::compute_properties(Arc::clone(&schema)); Ok(ValuesExec { schema, data: batches, @@ -122,7 +126,7 @@ impl ValuesExec { }) } - /// provides the data + /// Provides the data pub fn data(&self) -> Vec { self.data.clone() } @@ -167,7 +171,7 @@ impl ExecutionPlan for ValuesExec { &self.cache } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { vec![] } @@ -175,7 +179,7 @@ impl ExecutionPlan for ValuesExec { self: Arc, _: Vec>, ) -> Result> { - ValuesExec::try_new_from_batches(self.schema.clone(), self.data.clone()) + ValuesExec::try_new_from_batches(Arc::clone(&self.schema), self.data.clone()) .map(|e| Arc::new(e) as _) } @@ -193,7 +197,7 @@ impl ExecutionPlan for ValuesExec { Ok(Box::pin(MemoryStream::try_new( self.data(), - self.schema.clone(), + Arc::clone(&self.schema), None, )?)) } @@ -215,6 +219,7 @@ mod tests { use crate::test::{self, make_partition}; use arrow_schema::{DataType, Field}; + use datafusion_common::stats::{ColumnStatistics, Precision}; #[tokio::test] async fn values_empty_case() -> Result<()> { @@ -260,9 +265,39 @@ mod tests { DataType::UInt32, false, )])); - let _ = ValuesExec::try_new(schema.clone(), vec![vec![lit(1u32)]]).unwrap(); + let _ = ValuesExec::try_new(Arc::clone(&schema), vec![vec![lit(1u32)]]).unwrap(); // Test that a null value is rejected let _ = ValuesExec::try_new(schema, vec![vec![lit(ScalarValue::UInt32(None))]]) .unwrap_err(); } + + #[test] + fn values_stats_with_nulls_only() -> Result<()> { + let data = vec![ + vec![lit(ScalarValue::Null)], + vec![lit(ScalarValue::Null)], + vec![lit(ScalarValue::Null)], + ]; + let rows = data.len(); + let values = ValuesExec::try_new( + Arc::new(Schema::new(vec![Field::new("col0", DataType::Null, true)])), + data, + )?; + + assert_eq!( + values.statistics()?, + Statistics { + num_rows: Precision::Exact(rows), + total_byte_size: Precision::Exact(8), // not important + column_statistics: vec![ColumnStatistics { + null_count: Precision::Exact(rows), // there are only nulls + distinct_count: Precision::Absent, + max_value: Precision::Absent, + min_value: Precision::Absent, + },], + } + ); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index b1c306194813..8c0331f94570 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -27,7 +27,7 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use crate::expressions::PhysicalSortExpr; +use super::utils::create_schema; use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use crate::windows::{ calc_requirements, get_ordered_partition_by_indices, get_partition_by_sort_exprs, @@ -38,18 +38,18 @@ use crate::{ ExecutionPlanProperties, InputOrderMode, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, WindowExpr, }; - +use ahash::RandomState; +use arrow::compute::take_record_batch; use arrow::{ array::{Array, ArrayRef, RecordBatchOptions, UInt32Builder}, - compute::{concat, concat_batches, sort_to_indices}, - datatypes::{Schema, SchemaBuilder, SchemaRef}, + compute::{concat, concat_batches, sort_to_indices, take_arrays}, + datatypes::SchemaRef, record_batch::RecordBatch, }; use datafusion_common::hash_utils::create_hashes; use datafusion_common::stats::Precision; use datafusion_common::utils::{ - evaluate_partition_ranges, get_arrayref_at_indices, get_at_indices, - get_record_batch_at_indices, get_row_at_idx, + evaluate_partition_ranges, get_at_indices, get_row_at_idx, }; use datafusion_common::{arrow_datafusion_err, exec_err, DataFusionError, Result}; use datafusion_execution::TaskContext; @@ -58,9 +58,8 @@ use datafusion_expr::ColumnarValue; use datafusion_physical_expr::window::{ PartitionBatches, PartitionKey, PartitionWindowAggStates, WindowState, }; -use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement}; - -use ahash::RandomState; +use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; use futures::stream::Stream; use futures::{ready, StreamExt}; use hashbrown::raw::RawTable; @@ -68,7 +67,7 @@ use indexmap::IndexMap; use log::debug; /// Window execution plan -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct BoundedWindowAggExec { /// Input plan input: Arc, @@ -149,7 +148,7 @@ impl BoundedWindowAggExec { // We are sure that partition by columns are always at the beginning of sort_keys // Hence returned `PhysicalSortExpr` corresponding to `PARTITION BY` columns can be used safely // to calculate partition separation points - pub fn partition_by_sort_keys(&self) -> Result> { + pub fn partition_by_sort_keys(&self) -> Result { let partition_by = self.window_expr()[0].partition_by(); get_partition_by_sort_exprs( &self.input, @@ -250,24 +249,18 @@ impl ExecutionPlan for BoundedWindowAggExec { &self.cache } - fn children(&self) -> Vec> { - vec![self.input.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.input] } - fn required_input_ordering(&self) -> Vec>> { + fn required_input_ordering(&self) -> Vec> { let partition_bys = self.window_expr()[0].partition_by(); let order_keys = self.window_expr()[0].order_by(); - if self.input_order_mode != InputOrderMode::Sorted - || self.ordered_partition_by_indices.len() >= partition_bys.len() - { - let partition_bys = self - .ordered_partition_by_indices - .iter() - .map(|idx| &partition_bys[*idx]); - vec![calc_requirements(partition_bys, order_keys)] - } else { - vec![calc_requirements(partition_bys, order_keys)] - } + let partition_bys = self + .ordered_partition_by_indices + .iter() + .map(|idx| &partition_bys[*idx]); + vec![calc_requirements(partition_bys, order_keys.iter())] } fn required_input_distribution(&self) -> Vec { @@ -289,7 +282,7 @@ impl ExecutionPlan for BoundedWindowAggExec { ) -> Result> { Ok(Arc::new(BoundedWindowAggExec::try_new( self.window_expr.clone(), - children[0].clone(), + Arc::clone(&children[0]), self.partition_keys.clone(), self.input_order_mode.clone(), )?)) @@ -303,7 +296,7 @@ impl ExecutionPlan for BoundedWindowAggExec { let input = self.input.execute(partition, context)?; let search_mode = self.get_search_algo()?; let stream = Box::pin(BoundedWindowAggStream::new( - self.schema.clone(), + Arc::clone(&self.schema), self.window_expr.clone(), input, BaselineMetrics::new(&self.metrics, partition), @@ -394,7 +387,9 @@ trait PartitionSearcher: Send { // as it may not have the "correct" schema in terms of output // nullability constraints. For details, see the following issue: // https://github.com/apache/datafusion/issues/9320 - .or_insert_with(|| PartitionBatchState::new(self.input_schema().clone())); + .or_insert_with(|| { + PartitionBatchState::new(Arc::clone(self.input_schema())) + }); partition_batch_state.extend(&partition_batch)?; } @@ -513,7 +508,7 @@ impl PartitionSearcher for LinearSearch { let length = indices.len(); for (idx, window_agg_state) in window_agg_states.iter().enumerate() { let partition = &window_agg_state[&row]; - let values = partition.state.out_col.slice(0, length).clone(); + let values = Arc::clone(&partition.state.out_col.slice(0, length)); new_columns[idx].push(values); } let partition_batch_state = &mut partition_buffers[&row]; @@ -540,7 +535,9 @@ impl PartitionSearcher for LinearSearch { // We should emit columns according to row index ordering. let sorted_indices = sort_to_indices(&all_indices, None, None)?; // Construct new column according to row ordering. This fixes ordering - get_arrayref_at_indices(&new_columns, &sorted_indices).map(Some) + take_arrays(&new_columns, &sorted_indices, None) + .map(Some) + .map_err(|e| arrow_datafusion_err!(e)) } fn evaluate_partition_batches( @@ -549,7 +546,7 @@ impl PartitionSearcher for LinearSearch { window_expr: &[Arc], ) -> Result> { let partition_bys = - self.evaluate_partition_by_column_values(record_batch, window_expr)?; + evaluate_partition_by_column_values(record_batch, window_expr)?; // NOTE: In Linear or PartiallySorted modes, we are sure that // `partition_bys` are not empty. // Calculate indices for each partition and construct a new record @@ -560,7 +557,7 @@ impl PartitionSearcher for LinearSearch { let mut new_indices = UInt32Builder::with_capacity(indices.len()); new_indices.append_slice(&indices); let indices = new_indices.finish(); - Ok((row, get_record_batch_at_indices(record_batch, &indices)?)) + Ok((row, take_record_batch(record_batch, &indices)?)) }) .collect() } @@ -616,25 +613,6 @@ impl LinearSearch { } } - /// Calculates partition by expression results for each window expression - /// on `record_batch`. - fn evaluate_partition_by_column_values( - &self, - record_batch: &RecordBatch, - window_expr: &[Arc], - ) -> Result> { - window_expr[0] - .partition_by() - .iter() - .map(|item| match item.evaluate(record_batch)? { - ColumnarValue::Array(array) => Ok(array), - ColumnarValue::Scalar(scalar) => { - scalar.to_array_of_size(record_batch.num_rows()) - } - }) - .collect() - } - /// Calculate indices of each partition (according to PARTITION BY expression) /// `columns` contain partition by expression results. fn get_per_partition_indices( @@ -681,7 +659,7 @@ impl LinearSearch { window_expr: &[Arc], ) -> Result)>> { let partition_by_columns = - self.evaluate_partition_by_column_values(input_buffer, window_expr)?; + evaluate_partition_by_column_values(input_buffer, window_expr)?; // Reset the row_map state: self.row_map_out.clear(); let mut partition_indices: Vec<(PartitionKey, Vec)> = vec![]; @@ -728,7 +706,7 @@ impl LinearSearch { /// when computing partitions. pub struct SortedSearch { /// Stores partition by columns and their ordering information - partition_by_sort_keys: Vec, + partition_by_sort_keys: LexOrdering, /// Input ordering and partition by key ordering need not be the same, so /// this vector stores the mapping between them. For instance, if the input /// is ordered by a, b and the window expression contains a PARTITION BY b, a @@ -850,18 +828,22 @@ impl SortedSearch { } } -fn create_schema( - input_schema: &Schema, +/// Calculates partition by expression results for each window expression +/// on `record_batch`. +fn evaluate_partition_by_column_values( + record_batch: &RecordBatch, window_expr: &[Arc], -) -> Result { - let capacity = input_schema.fields().len() + window_expr.len(); - let mut builder = SchemaBuilder::with_capacity(capacity); - builder.extend(input_schema.fields.iter().cloned()); - // append results to the schema - for expr in window_expr { - builder.push(expr.field()?); - } - Ok(builder.finish()) +) -> Result> { + window_expr[0] + .partition_by() + .iter() + .map(|item| match item.evaluate(record_batch)? { + ColumnarValue::Array(array) => Ok(array), + ColumnarValue::Scalar(scalar) => { + scalar.to_array_of_size(record_batch.num_rows()) + } + }) + .collect() } /// Stream for the bounded window aggregation plan. @@ -935,7 +917,7 @@ impl BoundedWindowAggStream { search_mode: Box, ) -> Result { let state = window_expr.iter().map(|_| IndexMap::new()).collect(); - let empty_batch = RecordBatch::new_empty(schema.clone()); + let empty_batch = RecordBatch::new_empty(Arc::clone(&schema)); Ok(Self { schema, input, @@ -957,7 +939,7 @@ impl BoundedWindowAggStream { cur_window_expr.evaluate_stateful(&self.partition_buffers, state)?; } - let schema = self.schema.clone(); + let schema = Arc::clone(&self.schema); let window_expr_out = self.search_mode.calculate_out_columns( &self.input_buffer, &self.window_agg_states, @@ -1114,7 +1096,7 @@ impl BoundedWindowAggStream { impl RecordBatchStream for BoundedWindowAggStream { /// Get the schema fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -1177,6 +1159,7 @@ mod tests { use std::time::Duration; use crate::common::collect; + use crate::expressions::PhysicalSortExpr; use crate::memory::MemoryExec; use crate::projection::ProjectionExec; use crate::streaming::{PartitionStream, StreamingTableExec}; @@ -1194,15 +1177,16 @@ mod tests { RecordBatchStream, SendableRecordBatchStream, TaskContext, }; use datafusion_expr::{ - AggregateFunction, WindowFrame, WindowFrameBound, WindowFrameUnits, - WindowFunctionDefinition, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; + use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::expressions::{col, Column, NthValue}; use datafusion_physical_expr::window::{ BuiltInWindowExpr, BuiltInWindowFunctionExpr, }; - use datafusion_physical_expr::{LexOrdering, PhysicalExpr, PhysicalSortExpr}; + use datafusion_physical_expr::{LexOrdering, PhysicalExpr}; + use datafusion_physical_expr_common::sort_expr::LexOrderingRef; use futures::future::Shared; use futures::{pin_mut, ready, FutureExt, Stream, StreamExt}; use itertools::Itertools; @@ -1287,7 +1271,7 @@ mod tests { impl RecordBatchStream for TestStreamPartition { fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } @@ -1298,16 +1282,15 @@ mod tests { order_by: &str, ) -> Result> { let schema = input.schema(); - let window_fn = - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count); + let window_fn = WindowFunctionDefinition::AggregateUDF(count_udaf()); let col_expr = Arc::new(Column::new(schema.fields[0].name(), 0)) as Arc; let args = vec![col_expr]; let partitionby_exprs = vec![col(hash, &schema)?]; - let orderby_exprs = vec![PhysicalSortExpr { + let orderby_exprs = LexOrdering::new(vec![PhysicalSortExpr { expr: col(order_by, &schema)?, options: SortOptions::default(), - }]; + }]); let window_frame = WindowFrame::new_bounds( WindowFrameUnits::Range, WindowFrameBound::CurrentRow, @@ -1324,8 +1307,8 @@ mod tests { fn_name, &args, &partitionby_exprs, - &orderby_exprs, - Arc::new(window_frame.clone()), + orderby_exprs.as_ref(), + Arc::new(window_frame), &input.schema(), false, )?], @@ -1421,13 +1404,13 @@ mod tests { } fn schema_orders(schema: &SchemaRef) -> Result> { - let orderings = vec![vec![PhysicalSortExpr { + let orderings = vec![LexOrdering::new(vec![PhysicalSortExpr { expr: col("sn", schema)?, options: SortOptions { descending: false, nulls_first: false, }, - }]]; + }])]; Ok(orderings) } @@ -1464,7 +1447,7 @@ mod tests { } let batch = RecordBatch::try_new( - schema.clone(), + Arc::clone(schema), vec![Arc::new(sn1_array.finish()), Arc::new(hash_array.finish())], )?; batches.push(batch); @@ -1497,8 +1480,8 @@ mod tests { // Source has 2 partitions let partitions = vec![ Arc::new(TestStreamPartition { - schema: schema.clone(), - batches: batches.clone(), + schema: Arc::clone(&schema), + batches, idx: 0, state: PolingState::BatchReturn, sleep_duration: per_batch_wait_duration, @@ -1507,11 +1490,12 @@ mod tests { n_partition ]; let source = Arc::new(StreamingTableExec::try_new( - schema.clone(), + Arc::clone(&schema), partitions, None, orderings, is_infinite, + None, )?) as _; Ok(source) } @@ -1529,28 +1513,38 @@ mod tests { let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); // Create a new batch of data to insert into the table let batch = RecordBatch::try_new( - schema.clone(), + Arc::clone(&schema), vec![Arc::new(arrow_array::Int32Array::from(vec![1, 2, 3]))], )?; let memory_exec = MemoryExec::try_new( &[vec![batch.clone(), batch.clone(), batch.clone()]], - schema.clone(), + Arc::clone(&schema), None, ) .map(|e| Arc::new(e) as Arc)?; let col_a = col("a", &schema)?; - let nth_value_func1 = - NthValue::nth("nth_value(-1)", col_a.clone(), DataType::Int32, 1, false)? - .reverse_expr() - .unwrap(); - let nth_value_func2 = - NthValue::nth("nth_value(-2)", col_a.clone(), DataType::Int32, 2, false)? - .reverse_expr() - .unwrap(); + let nth_value_func1 = NthValue::nth( + "nth_value(-1)", + Arc::clone(&col_a), + DataType::Int32, + 1, + false, + )? + .reverse_expr() + .unwrap(); + let nth_value_func2 = NthValue::nth( + "nth_value(-2)", + Arc::clone(&col_a), + DataType::Int32, + 2, + false, + )? + .reverse_expr() + .unwrap(); let last_value_func = Arc::new(NthValue::last( "last", - col_a.clone(), + Arc::clone(&col_a), DataType::Int32, false, )) as _; @@ -1559,7 +1553,7 @@ mod tests { Arc::new(BuiltInWindowExpr::new( last_value_func, &[], - &[], + LexOrderingRef::default(), Arc::new(WindowFrame::new_bounds( WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::UInt64(None)), @@ -1570,7 +1564,7 @@ mod tests { Arc::new(BuiltInWindowExpr::new( nth_value_func1, &[], - &[], + LexOrderingRef::default(), Arc::new(WindowFrame::new_bounds( WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::UInt64(None)), @@ -1581,7 +1575,7 @@ mod tests { Arc::new(BuiltInWindowExpr::new( nth_value_func2, &[], - &[], + LexOrderingRef::default(), Arc::new(WindowFrame::new_bounds( WindowFrameUnits::Rows, WindowFrameBound::Preceding(ScalarValue::UInt64(None)), @@ -1649,7 +1643,7 @@ mod tests { // // Effectively following query is run on this data // - // SELECT *, COUNT(*) OVER(PARTITION BY duplicated_hash ORDER BY sn RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) + // SELECT *, count(*) OVER(PARTITION BY duplicated_hash ORDER BY sn RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) // FROM test; // // partition `duplicated_hash=2` receives following data from the input @@ -1723,8 +1717,8 @@ mod tests { let plan = projection_exec(window)?; let expected_plan = vec![ - "ProjectionExec: expr=[sn@0 as sn, hash@1 as hash, COUNT([Column { name: \"sn\", index: 0 }]) PARTITION BY: [[Column { name: \"hash\", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: \"sn\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]@2 as col_2]", - " BoundedWindowAggExec: wdw=[COUNT([Column { name: \"sn\", index: 0 }]) PARTITION BY: [[Column { name: \"hash\", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: \"sn\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]: Ok(Field { name: \"COUNT([Column { name: \\\"sn\\\", index: 0 }]) PARTITION BY: [[Column { name: \\\"hash\\\", index: 1 }]], ORDER BY: [[PhysicalSortExpr { expr: Column { name: \\\"sn\\\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }]]\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(1)), is_causal: false }], mode=[Linear]", + "ProjectionExec: expr=[sn@0 as sn, hash@1 as hash, count([Column { name: \"sn\", index: 0 }]) PARTITION BY: [[Column { name: \"hash\", index: 1 }]], ORDER BY: [LexOrdering { inner: [PhysicalSortExpr { expr: Column { name: \"sn\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }] }]@2 as col_2]", + " BoundedWindowAggExec: wdw=[count([Column { name: \"sn\", index: 0 }]) PARTITION BY: [[Column { name: \"hash\", index: 1 }]], ORDER BY: [LexOrdering { inner: [PhysicalSortExpr { expr: Column { name: \"sn\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }] }]: Ok(Field { name: \"count([Column { name: \\\"sn\\\", index: 0 }]) PARTITION BY: [[Column { name: \\\"hash\\\", index: 1 }]], ORDER BY: [LexOrdering { inner: [PhysicalSortExpr { expr: Column { name: \\\"sn\\\", index: 0 }, options: SortOptions { descending: false, nulls_first: true } }] }]\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(1)), is_causal: false }], mode=[Linear]", " StreamingTableExec: partition_sizes=1, projection=[sn, hash], infinite_source=true, output_ordering=[sn@0 ASC NULLS LAST]", ]; diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 6de94b0e4103..217823fb6a0a 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -21,37 +21,76 @@ use std::borrow::Borrow; use std::sync::Arc; use crate::{ - aggregates, - expressions::{ - cume_dist, dense_rank, lag, lead, percent_rank, rank, Literal, NthValue, Ntile, - PhysicalSortExpr, RowNumber, - }, - udaf, ExecutionPlan, ExecutionPlanProperties, InputOrderMode, PhysicalExpr, + expressions::{Literal, NthValue, PhysicalSortExpr}, + ExecutionPlan, ExecutionPlanProperties, InputOrderMode, PhysicalExpr, }; use arrow::datatypes::Schema; use arrow_schema::{DataType, Field, SchemaRef}; -use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue}; +use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue}; use datafusion_expr::{ - BuiltInWindowFunction, PartitionEvaluator, WindowFrame, WindowFunctionDefinition, - WindowUDF, + BuiltInWindowFunction, PartitionEvaluator, ReversedUDWF, WindowFrame, + WindowFunctionDefinition, WindowUDF, }; +use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion_physical_expr::equivalence::collapse_lex_req; use datafusion_physical_expr::{ reverse_order_bys, window::{BuiltInWindowFunctionExpr, SlidingAggregateWindowExpr}, - AggregateExpr, EquivalenceProperties, LexOrdering, PhysicalSortRequirement, + ConstExpr, EquivalenceProperties, LexOrdering, PhysicalSortRequirement, }; +use itertools::Itertools; mod bounded_window_agg_exec; +mod utils; mod window_agg_exec; pub use bounded_window_agg_exec::BoundedWindowAggExec; +use datafusion_functions_window_common::expr::ExpressionArgs; +use datafusion_functions_window_common::field::WindowUDFFieldArgs; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; +use datafusion_physical_expr::expressions::Column; pub use datafusion_physical_expr::window::{ BuiltInWindowExpr, PlainAggregateWindowExpr, WindowExpr, }; +use datafusion_physical_expr_common::sort_expr::{LexOrderingRef, LexRequirement}; pub use window_agg_exec::WindowAggExec; +/// Build field from window function and add it into schema +pub fn schema_add_window_field( + args: &[Arc], + schema: &Schema, + window_fn: &WindowFunctionDefinition, + fn_name: &str, +) -> Result> { + let data_types = args + .iter() + .map(|e| Arc::clone(e).as_ref().data_type(schema)) + .collect::>>()?; + let nullability = args + .iter() + .map(|e| Arc::clone(e).as_ref().nullable(schema)) + .collect::>>()?; + let window_expr_return_type = + window_fn.return_type(&data_types, &nullability, fn_name)?; + let mut window_fields = schema + .fields() + .iter() + .map(|f| f.as_ref().clone()) + .collect_vec(); + // Skip extending schema for UDAF + if let WindowFunctionDefinition::AggregateUDF(_) = window_fn { + Ok(Arc::new(Schema::new(window_fields))) + } else { + window_fields.extend_from_slice(&[Field::new( + fn_name, + window_expr_return_type, + false, + )]); + Ok(Arc::new(Schema::new(window_fields))) + } +} + /// Create a physical expression for window function #[allow(clippy::too_many_arguments)] pub fn create_window_expr( @@ -59,29 +98,12 @@ pub fn create_window_expr( name: String, args: &[Arc], partition_by: &[Arc], - order_by: &[PhysicalSortExpr], + order_by: LexOrderingRef, window_frame: Arc, input_schema: &Schema, ignore_nulls: bool, ) -> Result> { Ok(match fun { - WindowFunctionDefinition::AggregateFunction(fun) => { - let aggregate = aggregates::create_aggregate_expr( - fun, - false, - args, - &[], - input_schema, - name, - ignore_nulls, - )?; - window_expr_from_aggregate_expr( - partition_by, - order_by, - window_frame, - aggregate, - ) - } WindowFunctionDefinition::BuiltInWindowFunction(fun) => { Arc::new(BuiltInWindowExpr::new( create_built_in_window_expr(fun, args, input_schema, name, ignore_nulls)?, @@ -91,19 +113,12 @@ pub fn create_window_expr( )) } WindowFunctionDefinition::AggregateUDF(fun) => { - // TODO: Ordering not supported for Window UDFs yet - let sort_exprs = &[]; - let ordering_req = &[]; - - let aggregate = udaf::create_aggregate_expr( - fun.as_ref(), - args, - sort_exprs, - ordering_req, - input_schema, - name, - ignore_nulls, - )?; + let aggregate = AggregateExprBuilder::new(Arc::clone(fun), args.to_vec()) + .schema(Arc::new(input_schema.clone())) + .alias(name) + .with_ignore_nulls(ignore_nulls) + .build() + .map(Arc::new)?; window_expr_from_aggregate_expr( partition_by, order_by, @@ -111,8 +126,9 @@ pub fn create_window_expr( aggregate, ) } + // TODO: Ordering not supported for Window UDFs yet WindowFunctionDefinition::WindowUDF(fun) => Arc::new(BuiltInWindowExpr::new( - create_udwf_window_expr(fun, args, input_schema, name)?, + create_udwf_window_expr(fun, args, input_schema, name, ignore_nulls)?, partition_by, order_by, window_frame, @@ -123,9 +139,9 @@ pub fn create_window_expr( /// Creates an appropriate [`WindowExpr`] based on the window frame and fn window_expr_from_aggregate_expr( partition_by: &[Arc], - order_by: &[PhysicalSortExpr], + order_by: LexOrderingRef, window_frame: Arc, - aggregate: Arc, + aggregate: Arc, ) -> Arc { // Is there a potentially unlimited sized window frame? let unbounded_window = window_frame.start_bound.is_unbounded(); @@ -147,34 +163,16 @@ fn window_expr_from_aggregate_expr( } } -fn get_scalar_value_from_args( - args: &[Arc], - index: usize, -) -> Result> { - Ok(if let Some(field) = args.get(index) { - let tmp = field - .as_any() - .downcast_ref::() - .ok_or_else(|| DataFusionError::NotImplemented( - format!("There is only support Literal types for field at idx: {index} in Window Function"), - ))? - .value() - .clone(); - Some(tmp) - } else { - None - }) -} +fn get_signed_integer(value: ScalarValue) -> Result { + if value.is_null() { + return Ok(0); + } -fn get_casted_value( - default_value: Option, - dtype: &DataType, -) -> Result { - match default_value { - Some(v) if !v.data_type().is_null() => v.cast_to(dtype), - // If None or Null datatype - _ => ScalarValue::try_from(dtype), + if !value.data_type().is_integer() { + return exec_err!("Expected an integer value"); } + + value.cast_to(&DataType::Int64)?.try_into() } fn create_built_in_window_expr( @@ -188,72 +186,18 @@ fn create_built_in_window_expr( let out_data_type: &DataType = input_schema.field_with_name(&name)?.data_type(); Ok(match fun { - BuiltInWindowFunction::RowNumber => Arc::new(RowNumber::new(name, out_data_type)), - BuiltInWindowFunction::Rank => Arc::new(rank(name, out_data_type)), - BuiltInWindowFunction::DenseRank => Arc::new(dense_rank(name, out_data_type)), - BuiltInWindowFunction::PercentRank => Arc::new(percent_rank(name, out_data_type)), - BuiltInWindowFunction::CumeDist => Arc::new(cume_dist(name, out_data_type)), - BuiltInWindowFunction::Ntile => { - let n = get_scalar_value_from_args(args, 0)?.ok_or_else(|| { - DataFusionError::Execution( - "NTILE requires a positive integer".to_string(), - ) - })?; - - if n.is_null() { - return exec_err!("NTILE requires a positive integer, but finds NULL"); - } - - if n.is_unsigned() { - let n: u64 = n.try_into()?; - Arc::new(Ntile::new(name, n, out_data_type)) - } else { - let n: i64 = n.try_into()?; - if n <= 0 { - return exec_err!("NTILE requires a positive integer"); - } - Arc::new(Ntile::new(name, n as u64, out_data_type)) - } - } - BuiltInWindowFunction::Lag => { - let arg = args[0].clone(); - let shift_offset = get_scalar_value_from_args(args, 1)? - .map(|v| v.try_into()) - .and_then(|v| v.ok()); - let default_value = - get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?; - Arc::new(lag( - name, - out_data_type.clone(), - arg, - shift_offset, - default_value, - ignore_nulls, - )) - } - BuiltInWindowFunction::Lead => { - let arg = args[0].clone(); - let shift_offset = get_scalar_value_from_args(args, 1)? - .map(|v| v.try_into()) - .and_then(|v| v.ok()); - let default_value = - get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?; - Arc::new(lead( - name, - out_data_type.clone(), - arg, - shift_offset, - default_value, - ignore_nulls, - )) - } BuiltInWindowFunction::NthValue => { - let arg = args[0].clone(); - let n = args[1].as_any().downcast_ref::().unwrap().value(); - let n: i64 = n - .clone() - .try_into() - .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?; + let arg = Arc::clone(&args[0]); + let n = get_signed_integer( + args[1] + .as_any() + .downcast_ref::() + .ok_or_else(|| { + exec_datafusion_err!("Expected a signed integer literal for the second argument of nth_value, got {}", args[1]) + })? + .value() + .clone(), + )?; Arc::new(NthValue::nth( name, arg, @@ -263,7 +207,7 @@ fn create_built_in_window_expr( )?) } BuiltInWindowFunction::FirstValue => { - let arg = args[0].clone(); + let arg = Arc::clone(&args[0]); Arc::new(NthValue::first( name, arg, @@ -272,7 +216,7 @@ fn create_built_in_window_expr( )) } BuiltInWindowFunction::LastValue => { - let arg = args[0].clone(); + let arg = Arc::clone(&args[0]); Arc::new(NthValue::last( name, arg, @@ -289,6 +233,7 @@ fn create_udwf_window_expr( args: &[Arc], input_schema: &Schema, name: String, + ignore_nulls: bool, ) -> Result> { // need to get the types into an owned vec for some reason let input_types: Vec<_> = args @@ -296,13 +241,13 @@ fn create_udwf_window_expr( .map(|arg| arg.data_type(input_schema)) .collect::>()?; - // figure out the output type - let data_type = fun.return_type(&input_types)?; Ok(Arc::new(WindowUDFExpr { fun: Arc::clone(fun), args: args.to_vec(), + input_types, name, - data_type, + is_reversed: false, + ignore_nulls, })) } @@ -313,8 +258,14 @@ struct WindowUDFExpr { args: Vec>, /// Display name name: String, - /// result type - data_type: DataType, + /// Types of input expressions + input_types: Vec, + /// This is set to `true` only if the user-defined window function + /// expression supports evaluation in reverse order, and the + /// evaluation order is reversed. + is_reversed: bool, + /// Set to `true` if `IGNORE NULLS` is defined, `false` otherwise. + ignore_nulls: bool, } impl BuiltInWindowFunctionExpr for WindowUDFExpr { @@ -323,16 +274,23 @@ impl BuiltInWindowFunctionExpr for WindowUDFExpr { } fn field(&self) -> Result { - let nullable = true; - Ok(Field::new(&self.name, self.data_type.clone(), nullable)) + self.fun + .field(WindowUDFFieldArgs::new(&self.input_types, &self.name)) } fn expressions(&self) -> Vec> { - self.args.clone() + self.fun + .expressions(ExpressionArgs::new(&self.args, &self.input_types)) } fn create_evaluator(&self) -> Result> { - self.fun.partition_evaluator_factory() + self.fun + .partition_evaluator_factory(PartitionEvaluatorArgs::new( + &self.args, + &self.input_types, + self.is_reversed, + self.ignore_nulls, + )) } fn name(&self) -> &str { @@ -340,7 +298,28 @@ impl BuiltInWindowFunctionExpr for WindowUDFExpr { } fn reverse_expr(&self) -> Option> { - None + match self.fun.reverse_expr() { + ReversedUDWF::Identical => Some(Arc::new(self.clone())), + ReversedUDWF::NotSupported => None, + ReversedUDWF::Reversed(fun) => Some(Arc::new(WindowUDFExpr { + fun, + args: self.args.clone(), + name: self.name.clone(), + input_types: self.input_types.clone(), + is_reversed: !self.is_reversed, + ignore_nulls: self.ignore_nulls, + })), + } + } + + fn get_result_ordering(&self, schema: &SchemaRef) -> Option { + self.fun + .sort_options() + .zip(schema.column_with_name(self.name())) + .map(|(options, (idx, field))| { + let expr = Arc::new(Column::new(field.name(), idx)); + PhysicalSortExpr { expr, options } + }) } } @@ -350,17 +329,22 @@ pub(crate) fn calc_requirements< >( partition_by_exprs: impl IntoIterator, orderby_sort_exprs: impl IntoIterator, -) -> Option> { - let mut sort_reqs = partition_by_exprs - .into_iter() - .map(|partition_by| { - PhysicalSortRequirement::new(partition_by.borrow().clone(), None) - }) - .collect::>(); +) -> Option { + let mut sort_reqs = LexRequirement::new( + partition_by_exprs + .into_iter() + .map(|partition_by| { + PhysicalSortRequirement::new(Arc::clone(partition_by.borrow()), None) + }) + .collect::>(), + ); for element in orderby_sort_exprs.into_iter() { let PhysicalSortExpr { expr, options } = element.borrow(); if !sort_reqs.iter().any(|e| e.expr.eq(expr)) { - sort_reqs.push(PhysicalSortRequirement::new(expr.clone(), Some(*options))); + sort_reqs.push(PhysicalSortRequirement::new( + Arc::clone(expr), + Some(*options), + )); } } // Convert empty result to None. Otherwise wrap result inside Some() @@ -389,7 +373,7 @@ pub(crate) fn get_partition_by_sort_exprs( ) -> Result { let ordered_partition_exprs = ordered_partition_by_indices .iter() - .map(|idx| partition_by_exprs[*idx].clone()) + .map(|idx| Arc::clone(&partition_by_exprs[*idx])) .collect::>(); // Make sure ordered section doesn't move over the partition by expression assert!(ordered_partition_by_indices.len() <= partition_by_exprs.len()); @@ -410,7 +394,7 @@ pub(crate) fn window_equivalence_properties( ) -> EquivalenceProperties { // We need to update the schema, so we can not directly use // `input.equivalence_properties()`. - let mut window_eq_properties = EquivalenceProperties::new(schema.clone()) + let mut window_eq_properties = EquivalenceProperties::new(Arc::clone(schema)) .extend(input.equivalence_properties().clone()); for expr in window_expr { @@ -481,7 +465,7 @@ pub fn get_best_fitting_window( if window_expr.iter().all(|e| e.uses_bounded_memory()) { Ok(Some(Arc::new(BoundedWindowAggExec::try_new( window_expr, - input.clone(), + Arc::clone(input), physical_partition_keys.to_vec(), input_order_mode, )?) as _)) @@ -494,7 +478,7 @@ pub fn get_best_fitting_window( } else { Ok(Some(Arc::new(WindowAggExec::try_new( window_expr, - input.clone(), + Arc::clone(input), physical_partition_keys.to_vec(), )?) as _)) } @@ -507,30 +491,40 @@ pub fn get_best_fitting_window( /// (input ordering is not sufficient to run current window executor). /// - A `Some((bool, InputOrderMode))` value indicates that the window operator /// can run with existing input ordering, so we can remove `SortExec` before it. +/// /// The `bool` field in the return value represents whether we should reverse window /// operator to remove `SortExec` before it. The `InputOrderMode` field represents /// the mode this window operator should work in to accommodate the existing ordering. pub fn get_window_mode( partitionby_exprs: &[Arc], - orderby_keys: &[PhysicalSortExpr], + orderby_keys: LexOrderingRef, input: &Arc, ) -> Option<(bool, InputOrderMode)> { let input_eqs = input.equivalence_properties().clone(); - let mut partition_by_reqs: Vec = vec![]; + let mut partition_by_reqs: LexRequirement = LexRequirement::new(vec![]); let (_, indices) = input_eqs.find_longest_permutation(partitionby_exprs); - partition_by_reqs.extend(indices.iter().map(|&idx| PhysicalSortRequirement { - expr: partitionby_exprs[idx].clone(), + vec![].extend(indices.iter().map(|&idx| PhysicalSortRequirement { + expr: Arc::clone(&partitionby_exprs[idx]), options: None, })); + partition_by_reqs + .inner + .extend(indices.iter().map(|&idx| PhysicalSortRequirement { + expr: Arc::clone(&partitionby_exprs[idx]), + options: None, + })); // Treat partition by exprs as constant. During analysis of requirements are satisfied. - let partition_by_eqs = input_eqs.add_constants(partitionby_exprs.iter().cloned()); - let order_by_reqs = PhysicalSortRequirement::from_sort_exprs(orderby_keys); + let const_exprs = partitionby_exprs.iter().map(ConstExpr::from); + let partition_by_eqs = input_eqs.with_constants(const_exprs); + let order_by_reqs = PhysicalSortRequirement::from_sort_exprs(orderby_keys.iter()); let reverse_order_by_reqs = - PhysicalSortRequirement::from_sort_exprs(&reverse_order_bys(orderby_keys)); + PhysicalSortRequirement::from_sort_exprs(reverse_order_bys(orderby_keys).iter()); for (should_swap, order_by_reqs) in [(false, order_by_reqs), (true, reverse_order_by_reqs)] { - let req = [partition_by_reqs.clone(), order_by_reqs].concat(); + let req = LexRequirement::new( + [partition_by_reqs.inner.clone(), order_by_reqs.inner].concat(), + ); let req = collapse_lex_req(req); if partition_by_eqs.ordering_satisfy_requirement(&req) { // Window can be run with existing ordering @@ -550,7 +544,6 @@ pub fn get_window_mode( #[cfg(test)] mod tests { use super::*; - use crate::aggregates::AggregateFunction; use crate::collect; use crate::expressions::col; use crate::streaming::StreamingTableExec; @@ -560,8 +553,8 @@ mod tests { use arrow::compute::SortOptions; use datafusion_execution::TaskContext; + use datafusion_functions_aggregate::count::count_udaf; use futures::FutureExt; - use InputOrderMode::{Linear, PartiallySorted, Sorted}; fn create_test_schema() -> Result { @@ -619,11 +612,12 @@ mod tests { let sort_exprs = sort_exprs.into_iter().collect(); Ok(Arc::new(StreamingTableExec::try_new( - schema.clone(), + Arc::clone(schema), vec![], None, Some(sort_exprs), infinite_source, + None, )?)) } @@ -672,7 +666,7 @@ mod tests { orderbys.push(PhysicalSortExpr { expr, options }); } - let mut expected: Option> = None; + let mut expected: Option = None; for (col_name, reqs) in expected_params { let options = reqs.map(|(descending, nulls_first)| SortOptions { descending, @@ -683,7 +677,7 @@ mod tests { if let Some(expected) = &mut expected { expected.push(res); } else { - expected = Some(vec![res]); + expected = Some(LexRequirement::new(vec![res])); } } assert_eq!(calc_requirements(partitionbys, orderbys), expected); @@ -701,11 +695,11 @@ mod tests { let refs = blocking_exec.refs(); let window_agg_exec = Arc::new(WindowAggExec::try_new( vec![create_window_expr( - &WindowFunctionDefinition::AggregateFunction(AggregateFunction::Count), + &WindowFunctionDefinition::AggregateUDF(count_udaf()), "count".to_owned(), &[col("a", &schema)?], &[], - &[], + LexOrderingRef::default(), Arc::new(WindowFrame::new(None)), schema.as_ref(), false, @@ -725,7 +719,7 @@ mod tests { } #[tokio::test] - async fn test_satisfiy_nullable() -> Result<()> { + async fn test_satisfy_nullable() -> Result<()> { let schema = create_test_schema()?; let params = vec![ ((true, true), (false, false), false), @@ -902,7 +896,7 @@ mod tests { partition_by_exprs.push(col(col_name, &test_schema)?); } - let mut order_by_exprs = vec![]; + let mut order_by_exprs = LexOrdering::default(); for col_name in order_by_params { let expr = col(col_name, &test_schema)?; // Give default ordering, this is same with input ordering direction @@ -910,8 +904,11 @@ mod tests { let options = SortOptions::default(); order_by_exprs.push(PhysicalSortExpr { expr, options }); } - let res = - get_window_mode(&partition_by_exprs, &order_by_exprs, &exec_unbounded); + let res = get_window_mode( + &partition_by_exprs, + order_by_exprs.as_ref(), + &exec_unbounded, + ); // Since reversibility is not important in this test. Convert Option<(bool, InputOrderMode)> to Option let res = res.map(|(_, mode)| mode); assert_eq!( @@ -1064,7 +1061,7 @@ mod tests { partition_by_exprs.push(col(col_name, &test_schema)?); } - let mut order_by_exprs = vec![]; + let mut order_by_exprs = LexOrdering::default(); for (col_name, descending, nulls_first) in order_by_params { let expr = col(col_name, &test_schema)?; let options = SortOptions { @@ -1075,7 +1072,7 @@ mod tests { } assert_eq!( - get_window_mode(&partition_by_exprs, &order_by_exprs, &exec_unbounded), + get_window_mode(&partition_by_exprs, order_by_exprs.as_ref(), &exec_unbounded), *expected, "Unexpected result for in unbounded test case#: {case_idx:?}, case: {test_case:?}" ); diff --git a/datafusion/physical-plan/src/windows/utils.rs b/datafusion/physical-plan/src/windows/utils.rs new file mode 100644 index 000000000000..13332ea82fa1 --- /dev/null +++ b/datafusion/physical-plan/src/windows/utils.rs @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_schema::{Schema, SchemaBuilder}; +use datafusion_common::Result; +use datafusion_physical_expr::window::WindowExpr; +use std::sync::Arc; + +pub(crate) fn create_schema( + input_schema: &Schema, + window_expr: &[Arc], +) -> Result { + let capacity = input_schema.fields().len() + window_expr.len(); + let mut builder = SchemaBuilder::with_capacity(capacity); + builder.extend(input_schema.fields().iter().cloned()); + // append results to the schema + for expr in window_expr { + builder.push(expr.field()?); + } + Ok(builder + .finish() + .with_metadata(input_schema.metadata().clone())) +} diff --git a/datafusion/physical-plan/src/windows/window_agg_exec.rs b/datafusion/physical-plan/src/windows/window_agg_exec.rs index 46ba21bd797e..f71a0b9fd095 100644 --- a/datafusion/physical-plan/src/windows/window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/window_agg_exec.rs @@ -22,8 +22,7 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use crate::common::transpose; -use crate::expressions::PhysicalSortExpr; +use super::utils::create_schema; use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use crate::windows::{ calc_requirements, get_ordered_partition_by_indices, get_partition_by_sort_exprs, @@ -34,22 +33,20 @@ use crate::{ ExecutionPlan, ExecutionPlanProperties, PhysicalExpr, PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics, WindowExpr, }; - use arrow::array::ArrayRef; use arrow::compute::{concat, concat_batches}; -use arrow::datatypes::{Schema, SchemaBuilder, SchemaRef}; +use arrow::datatypes::SchemaRef; use arrow::error::ArrowError; use arrow::record_batch::RecordBatch; use datafusion_common::stats::Precision; -use datafusion_common::utils::evaluate_partition_ranges; +use datafusion_common::utils::{evaluate_partition_ranges, transpose}; use datafusion_common::{internal_err, Result}; use datafusion_execution::TaskContext; -use datafusion_physical_expr::PhysicalSortRequirement; - +use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; use futures::{ready, Stream, StreamExt}; /// Window execution plan -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct WindowAggExec { /// Input plan pub(crate) input: Arc, @@ -80,7 +77,7 @@ impl WindowAggExec { let ordered_partition_by_indices = get_ordered_partition_by_indices(window_expr[0].partition_by(), &input); - let cache = Self::compute_properties(schema.clone(), &input, &window_expr); + let cache = Self::compute_properties(Arc::clone(&schema), &input, &window_expr); Ok(Self { input, window_expr, @@ -107,7 +104,7 @@ impl WindowAggExec { // We are sure that partition by columns are always at the beginning of sort_keys // Hence returned `PhysicalSortExpr` corresponding to `PARTITION BY` columns can be used safely // to calculate partition separation points - pub fn partition_by_sort_keys(&self) -> Result> { + pub fn partition_by_sort_keys(&self) -> Result { let partition_by = self.window_expr()[0].partition_by(); get_partition_by_sort_exprs( &self.input, @@ -127,7 +124,7 @@ impl WindowAggExec { // Get output partitioning: // Because we can have repartitioning using the partition keys this - // would be either 1 or more than 1 depending on the presense of repartitioning. + // would be either 1 or more than 1 depending on the presence of repartitioning. let output_partitioning = input.output_partitioning().clone(); // Determine execution mode: @@ -185,25 +182,25 @@ impl ExecutionPlan for WindowAggExec { &self.cache } - fn children(&self) -> Vec> { - vec![self.input.clone()] + fn children(&self) -> Vec<&Arc> { + vec![&self.input] } fn maintains_input_order(&self) -> Vec { vec![true] } - fn required_input_ordering(&self) -> Vec>> { + fn required_input_ordering(&self) -> Vec> { let partition_bys = self.window_expr()[0].partition_by(); let order_keys = self.window_expr()[0].order_by(); if self.ordered_partition_by_indices.len() < partition_bys.len() { - vec![calc_requirements(partition_bys, order_keys)] + vec![calc_requirements(partition_bys, order_keys.iter())] } else { let partition_bys = self .ordered_partition_by_indices .iter() .map(|idx| &partition_bys[*idx]); - vec![calc_requirements(partition_bys, order_keys)] + vec![calc_requirements(partition_bys, order_keys.iter())] } } @@ -221,7 +218,7 @@ impl ExecutionPlan for WindowAggExec { ) -> Result> { Ok(Arc::new(WindowAggExec::try_new( self.window_expr.clone(), - children[0].clone(), + Arc::clone(&children[0]), self.partition_keys.clone(), )?)) } @@ -233,7 +230,7 @@ impl ExecutionPlan for WindowAggExec { ) -> Result { let input = self.input.execute(partition, context)?; let stream = Box::pin(WindowAggStream::new( - self.schema.clone(), + Arc::clone(&self.schema), self.window_expr.clone(), input, BaselineMetrics::new(&self.metrics, partition), @@ -266,20 +263,6 @@ impl ExecutionPlan for WindowAggExec { } } -fn create_schema( - input_schema: &Schema, - window_expr: &[Arc], -) -> Result { - let capacity = input_schema.fields().len() + window_expr.len(); - let mut builder = SchemaBuilder::with_capacity(capacity); - builder.extend(input_schema.fields().iter().cloned()); - // append results to the schema - for expr in window_expr { - builder.push(expr.field()?); - } - Ok(builder.finish()) -} - /// Compute the window aggregate columns fn compute_window_aggregates( window_expr: &[Arc], @@ -298,7 +281,7 @@ pub struct WindowAggStream { batches: Vec, finished: bool, window_expr: Vec>, - partition_by_sort_keys: Vec, + partition_by_sort_keys: LexOrdering, baseline_metrics: BaselineMetrics, ordered_partition_by_indices: Vec, } @@ -310,7 +293,7 @@ impl WindowAggStream { window_expr: Vec>, input: SendableRecordBatchStream, baseline_metrics: BaselineMetrics, - partition_by_sort_keys: Vec, + partition_by_sort_keys: LexOrdering, ordered_partition_by_indices: Vec, ) -> Result { // In WindowAggExec all partition by columns should be ordered. @@ -334,7 +317,7 @@ impl WindowAggStream { let _timer = self.baseline_metrics.elapsed_compute().timer(); let batch = concat_batches(&self.input.schema(), &self.batches)?; if batch.num_rows() == 0 { - return Ok(RecordBatch::new_empty(self.schema.clone())); + return Ok(RecordBatch::new_empty(Arc::clone(&self.schema))); } let partition_by_sort_keys = self @@ -367,7 +350,10 @@ impl WindowAggStream { let mut batch_columns = batch.columns().to_vec(); // calculate window cols batch_columns.extend_from_slice(&columns); - Ok(RecordBatch::try_new(self.schema.clone(), batch_columns)?) + Ok(RecordBatch::try_new( + Arc::clone(&self.schema), + batch_columns, + )?) } } @@ -413,6 +399,6 @@ impl WindowAggStream { impl RecordBatchStream for WindowAggStream { /// Get the schema fn schema(&self) -> SchemaRef { - self.schema.clone() + Arc::clone(&self.schema) } } diff --git a/datafusion/physical-plan/src/work_table.rs b/datafusion/physical-plan/src/work_table.rs index b3c9043d4fdc..61d444171cc7 100644 --- a/datafusion/physical-plan/src/work_table.rs +++ b/datafusion/physical-plan/src/work_table.rs @@ -110,7 +110,7 @@ pub struct WorkTableExec { impl WorkTableExec { /// Create a new execution plan for a worktable exec. pub fn new(name: String, schema: SchemaRef) -> Self { - let cache = Self::compute_properties(schema.clone()); + let cache = Self::compute_properties(Arc::clone(&schema)); Self { name, schema, @@ -123,7 +123,7 @@ impl WorkTableExec { pub(super) fn with_work_table(&self, work_table: Arc) -> Self { Self { name: self.name.clone(), - schema: self.schema.clone(), + schema: Arc::clone(&self.schema), metrics: ExecutionPlanMetricsSet::new(), work_table, cache: self.cache.clone(), @@ -169,7 +169,7 @@ impl ExecutionPlan for WorkTableExec { &self.cache } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc> { vec![] } @@ -185,7 +185,7 @@ impl ExecutionPlan for WorkTableExec { self: Arc, _: Vec>, ) -> Result> { - Ok(self.clone()) + Ok(Arc::clone(&self) as Arc) } /// Stream the batches that were written to the work table. @@ -202,7 +202,7 @@ impl ExecutionPlan for WorkTableExec { } let batch = self.work_table.take()?; Ok(Box::pin( - MemoryStream::try_new(batch.batches, self.schema.clone(), None)? + MemoryStream::try_new(batch.batches, Arc::clone(&self.schema), None)? .with_reservation(batch.reservation), )) } @@ -225,31 +225,31 @@ mod tests { #[test] fn test_work_table() { let work_table = WorkTable::new(); - // cann't take from empty work_table + // Can't take from empty work_table assert!(work_table.take().is_err()); let pool = Arc::new(UnboundedMemoryPool::default()) as _; let mut reservation = MemoryConsumer::new("test_work_table").register(&pool); - // update batch to work_table + // Update batch to work_table let array: ArrayRef = Arc::new((0..5).collect::()); let batch = RecordBatch::try_from_iter(vec![("col", array)]).unwrap(); reservation.try_grow(100).unwrap(); work_table.update(ReservedBatches::new(vec![batch.clone()], reservation)); - // take from work_table + // Take from work_table let reserved_batches = work_table.take().unwrap(); assert_eq!(reserved_batches.batches, vec![batch.clone()]); - // consume the batch by the MemoryStream + // Consume the batch by the MemoryStream let memory_stream = MemoryStream::try_new(reserved_batches.batches, batch.schema(), None) .unwrap() .with_reservation(reserved_batches.reservation); - // should still be reserved + // Should still be reserved assert_eq!(pool.reserved(), 100); - // the reservation should be freed after drop the memory_stream + // The reservation should be freed after drop the memory_stream drop(memory_stream); assert_eq!(pool.reserved(), 0); } diff --git a/datafusion/proto-common/.gitignore b/datafusion/proto-common/.gitignore new file mode 100644 index 000000000000..cfa46f072b0b --- /dev/null +++ b/datafusion/proto-common/.gitignore @@ -0,0 +1,4 @@ +# Files generated by regen.sh +proto/proto_descriptor.bin +src/datafusion_common.rs +src/datafusion_common.serde.rs diff --git a/datafusion/proto-common/Cargo.toml b/datafusion/proto-common/Cargo.toml new file mode 100644 index 000000000000..6c53e1b1ced0 --- /dev/null +++ b/datafusion/proto-common/Cargo.toml @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-proto-common" +description = "Protobuf serialization of DataFusion common types" +keywords = ["arrow", "query", "sql"] +version = { workspace = true } +edition = { workspace = true } +readme = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +authors = { workspace = true } +rust-version = "1.79" + +# Exclude proto files so crates.io consumers don't need protoc +exclude = ["*.proto"] + +[lib] +name = "datafusion_proto_common" +path = "src/lib.rs" + +[features] +default = [] +json = ["serde", "serde_json", "pbjson"] + +[dependencies] +arrow = { workspace = true } +chrono = { workspace = true } +datafusion-common = { workspace = true } +object_store = { workspace = true } +pbjson = { workspace = true, optional = true } +prost = { workspace = true } +serde = { version = "1.0", optional = true } +serde_json = { workspace = true, optional = true } + +[dev-dependencies] +doc-comment = { workspace = true } +tokio = { workspace = true } diff --git a/datafusion/proto-common/README.md b/datafusion/proto-common/README.md new file mode 100644 index 000000000000..c8b46424f701 --- /dev/null +++ b/datafusion/proto-common/README.md @@ -0,0 +1,28 @@ + + +# `datafusion-proto-common`: Apache DataFusion Protobuf Serialization / Deserialization + +This crate contains code to convert Apache [DataFusion] primitive types to and from +bytes, which can be useful for sending data over the network. + +See [API Docs] for details and examples. + +[datafusion]: https://datafusion.apache.org +[api docs]: http://docs.rs/datafusion-proto/latest diff --git a/docs/Cargo.toml b/datafusion/proto-common/gen/Cargo.toml similarity index 76% rename from docs/Cargo.toml rename to datafusion/proto-common/gen/Cargo.toml index 14398c841579..6e5783f467a7 100644 --- a/docs/Cargo.toml +++ b/datafusion/proto-common/gen/Cargo.toml @@ -16,20 +16,23 @@ # under the License. [package] -name = "datafusion-docs-tests" -description = "DataFusion Documentation Tests" -publish = false -version = { workspace = true } +name = "gen-common" +description = "Code generation for proto" +version = "0.1.0" edition = { workspace = true } -readme = { workspace = true } +rust-version = "1.79" +authors = { workspace = true } homepage = { workspace = true } repository = { workspace = true } license = { workspace = true } -authors = { workspace = true } -rust-version = { workspace = true } +publish = false + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [lints] workspace = true [dependencies] -datafusion = { workspace = true } +# Pin these dependencies so that the generated output is deterministic +pbjson-build = "=0.7.0" +prost-build = "=0.13.3" diff --git a/datafusion/proto-common/gen/src/main.rs b/datafusion/proto-common/gen/src/main.rs new file mode 100644 index 000000000000..2cbe2afa5488 --- /dev/null +++ b/datafusion/proto-common/gen/src/main.rs @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::path::Path; + +type Error = Box; +type Result = std::result::Result; + +fn main() -> Result<(), String> { + let proto_dir = Path::new("proto"); + let proto_path = Path::new("proto/datafusion_common.proto"); + + // proto definitions has to be there + let descriptor_path = proto_dir.join("proto_descriptor.bin"); + + prost_build::Config::new() + .file_descriptor_set_path(&descriptor_path) + .out_dir("src") + .compile_well_known_types() + .extern_path(".google.protobuf", "::pbjson_types") + .compile_protos(&[proto_path], &["proto"]) + .map_err(|e| format!("protobuf compilation failed: {e}"))?; + + let descriptor_set = std::fs::read(&descriptor_path) + .unwrap_or_else(|e| panic!("Cannot read {:?}: {}", &descriptor_path, e)); + + pbjson_build::Builder::new() + .out_dir("src") + .register_descriptors(&descriptor_set) + .unwrap_or_else(|e| { + panic!("Cannot register descriptors {:?}: {}", &descriptor_set, e) + }) + .build(&[".datafusion_common"]) + .map_err(|e| format!("pbjson compilation failed: {e}"))?; + + let prost = Path::new("src/datafusion_common.rs"); + let pbjson = Path::new("src/datafusion_common.serde.rs"); + + std::fs::copy(prost, "src/generated/prost.rs").unwrap(); + std::fs::copy(pbjson, "src/generated/pbjson.rs").unwrap(); + + Ok(()) +} diff --git a/datafusion/proto-common/proto/datafusion_common.proto b/datafusion/proto-common/proto/datafusion_common.proto new file mode 100644 index 000000000000..65cd33d523cd --- /dev/null +++ b/datafusion/proto-common/proto/datafusion_common.proto @@ -0,0 +1,573 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package datafusion_common; + +message ColumnRelation { + string relation = 1; +} + +message Column { + string name = 1; + ColumnRelation relation = 2; +} + +message DfField{ + Field field = 1; + ColumnRelation qualifier = 2; +} + +message DfSchema { + repeated DfField columns = 1; + map metadata = 2; +} + +message CsvFormat { + CsvOptions options = 5; +} + +message ParquetFormat { + // Used to be bool enable_pruning = 1; + reserved 1; + TableParquetOptions options = 2; +} + +message AvroFormat {} + +message NdJsonFormat { + JsonOptions options = 1; +} + + +message PrimaryKeyConstraint{ + repeated uint64 indices = 1; +} + +message UniqueConstraint{ + repeated uint64 indices = 1; +} + +message Constraint{ + oneof constraint_mode{ + PrimaryKeyConstraint primary_key = 1; + UniqueConstraint unique = 2; + } +} + +message Constraints{ + repeated Constraint constraints = 1; +} + +enum JoinType { + INNER = 0; + LEFT = 1; + RIGHT = 2; + FULL = 3; + LEFTSEMI = 4; + LEFTANTI = 5; + RIGHTSEMI = 6; + RIGHTANTI = 7; + LEFTMARK = 8; +} + +enum JoinConstraint { + ON = 0; + USING = 1; +} + +message AvroOptions {} +message ArrowOptions {} + +message Schema { + repeated Field columns = 1; + map metadata = 2; +} + +message Field { + // name of the field + string name = 1; + ArrowType arrow_type = 2; + bool nullable = 3; + // for complex data types like structs, unions + repeated Field children = 4; + map metadata = 5; + int64 dict_id = 6; + bool dict_ordered = 7; +} + +message Timestamp{ + TimeUnit time_unit = 1; + string timezone = 2; +} + +enum TimeUnit{ + Second = 0; + Millisecond = 1; + Microsecond = 2; + Nanosecond = 3; +} + +enum IntervalUnit{ + YearMonth = 0; + DayTime = 1; + MonthDayNano = 2; +} + +message Decimal{ + reserved 1, 2; + uint32 precision = 3; + int32 scale = 4; +} + +message Decimal256Type{ + reserved 1, 2; + uint32 precision = 3; + int32 scale = 4; +} + +message List{ + Field field_type = 1; +} + +message FixedSizeList{ + Field field_type = 1; + int32 list_size = 2; +} + +message Dictionary{ + ArrowType key = 1; + ArrowType value = 2; +} + +message Struct{ + repeated Field sub_field_types = 1; +} + +message Map { + Field field_type = 1; + bool keys_sorted = 2; +} + +enum UnionMode{ + sparse = 0; + dense = 1; +} + +message Union{ + repeated Field union_types = 1; + UnionMode union_mode = 2; + repeated int32 type_ids = 3; +} + +// Used for List/FixedSizeList/LargeList/Struct/Map +message ScalarNestedValue { + message Dictionary { + bytes ipc_message = 1; + bytes arrow_data = 2; + } + + bytes ipc_message = 1; + bytes arrow_data = 2; + Schema schema = 3; + repeated Dictionary dictionaries = 4; +} + +message ScalarTime32Value { + oneof value { + int32 time32_second_value = 1; + int32 time32_millisecond_value = 2; + }; +} + +message ScalarTime64Value { + oneof value { + int64 time64_microsecond_value = 1; + int64 time64_nanosecond_value = 2; + }; +} + +message ScalarTimestampValue { + oneof value { + int64 time_microsecond_value = 1; + int64 time_nanosecond_value = 2; + int64 time_second_value = 3; + int64 time_millisecond_value = 4; + }; + string timezone = 5; +} + +message ScalarDictionaryValue { + ArrowType index_type = 1; + ScalarValue value = 2; +} + +message IntervalDayTimeValue { + int32 days = 1; + int32 milliseconds = 2; +} + +message IntervalMonthDayNanoValue { + int32 months = 1; + int32 days = 2; + int64 nanos = 3; +} + +message UnionField { + int32 field_id = 1; + Field field = 2; +} + +message UnionValue { + // Note that a null union value must have one or more fields, so we + // encode a null UnionValue as one with value_id == 128 + int32 value_id = 1; + ScalarValue value = 2; + repeated UnionField fields = 3; + UnionMode mode = 4; +} + +message ScalarFixedSizeBinary{ + bytes values = 1; + int32 length = 2; +} + +message ScalarValue{ + // was PrimitiveScalarType null_value = 19; + reserved 19; + + oneof value { + // was PrimitiveScalarType null_value = 19; + // Null value of any type + ArrowType null_value = 33; + + bool bool_value = 1; + string utf8_value = 2; + string large_utf8_value = 3; + string utf8_view_value = 23; + int32 int8_value = 4; + int32 int16_value = 5; + int32 int32_value = 6; + int64 int64_value = 7; + uint32 uint8_value = 8; + uint32 uint16_value = 9; + uint32 uint32_value = 10; + uint64 uint64_value = 11; + float float32_value = 12; + double float64_value = 13; + // Literal Date32 value always has a unit of day + int32 date_32_value = 14; + ScalarTime32Value time32_value = 15; + ScalarNestedValue large_list_value = 16; + ScalarNestedValue list_value = 17; + ScalarNestedValue fixed_size_list_value = 18; + ScalarNestedValue struct_value = 32; + ScalarNestedValue map_value = 41; + + Decimal128 decimal128_value = 20; + Decimal256 decimal256_value = 39; + + int64 date_64_value = 21; + int32 interval_yearmonth_value = 24; + + int64 duration_second_value = 35; + int64 duration_millisecond_value = 36; + int64 duration_microsecond_value = 37; + int64 duration_nanosecond_value = 38; + + ScalarTimestampValue timestamp_value = 26; + ScalarDictionaryValue dictionary_value = 27; + bytes binary_value = 28; + bytes large_binary_value = 29; + bytes binary_view_value = 22; + ScalarTime64Value time64_value = 30; + IntervalDayTimeValue interval_daytime_value = 25; + IntervalMonthDayNanoValue interval_month_day_nano = 31; + ScalarFixedSizeBinary fixed_size_binary_value = 34; + UnionValue union_value = 42; + } +} + +message Decimal128{ + bytes value = 1; + int64 p = 2; + int64 s = 3; +} + +message Decimal256{ + bytes value = 1; + int64 p = 2; + int64 s = 3; +} + +// Serialized data type +message ArrowType{ + oneof arrow_type_enum { + EmptyMessage NONE = 1; // arrow::Type::NA + EmptyMessage BOOL = 2; // arrow::Type::BOOL + EmptyMessage UINT8 = 3; // arrow::Type::UINT8 + EmptyMessage INT8 = 4; // arrow::Type::INT8 + EmptyMessage UINT16 = 5; // represents arrow::Type fields in src/arrow/type.h + EmptyMessage INT16 = 6; + EmptyMessage UINT32 = 7; + EmptyMessage INT32 = 8; + EmptyMessage UINT64 = 9; + EmptyMessage INT64 = 10 ; + EmptyMessage FLOAT16 = 11 ; + EmptyMessage FLOAT32 = 12 ; + EmptyMessage FLOAT64 = 13 ; + EmptyMessage UTF8 = 14 ; + EmptyMessage UTF8_VIEW = 35; + EmptyMessage LARGE_UTF8 = 32; + EmptyMessage BINARY = 15 ; + EmptyMessage BINARY_VIEW = 34; + int32 FIXED_SIZE_BINARY = 16 ; + EmptyMessage LARGE_BINARY = 31; + EmptyMessage DATE32 = 17 ; + EmptyMessage DATE64 = 18 ; + TimeUnit DURATION = 19; + Timestamp TIMESTAMP = 20 ; + TimeUnit TIME32 = 21 ; + TimeUnit TIME64 = 22 ; + IntervalUnit INTERVAL = 23 ; + Decimal DECIMAL = 24 ; + Decimal256Type DECIMAL256 = 36; + List LIST = 25; + List LARGE_LIST = 26; + FixedSizeList FIXED_SIZE_LIST = 27; + Struct STRUCT = 28; + Union UNION = 29; + Dictionary DICTIONARY = 30; + Map MAP = 33; + } +} + +//Useful for representing an empty enum variant in rust +// E.G. enum example{One, Two(i32)} +// maps to +// message example{ +// oneof{ +// EmptyMessage One = 1; +// i32 Two = 2; +// } +//} +message EmptyMessage{} + +enum CompressionTypeVariant { + GZIP = 0; + BZIP2 = 1; + XZ = 2; + ZSTD = 3; + UNCOMPRESSED = 4; +} + +message JsonWriterOptions { + CompressionTypeVariant compression = 1; +} + + +message CsvWriterOptions { + // Compression type + CompressionTypeVariant compression = 1; + // Optional column delimiter. Defaults to `b','` + string delimiter = 2; + // Whether to write column names as file headers. Defaults to `true` + bool has_header = 3; + // Optional date format for date arrays + string date_format = 4; + // Optional datetime format for datetime arrays + string datetime_format = 5; + // Optional timestamp format for timestamp arrays + string timestamp_format = 6; + // Optional time format for time arrays + string time_format = 7; + // Optional value to represent null + string null_value = 8; + // Optional quote. Defaults to `b'"'` + string quote = 9; + // Optional escape. Defaults to `'\\'` + string escape = 10; + // Optional flag whether to double quotes, instead of escaping. Defaults to `true` + bool double_quote = 11; +} + +// Options controlling CSV format +message CsvOptions { + bytes has_header = 1; // Indicates if the CSV has a header row + bytes delimiter = 2; // Delimiter character as a byte + bytes quote = 3; // Quote character as a byte + bytes escape = 4; // Optional escape character as a byte + CompressionTypeVariant compression = 5; // Compression type + uint64 schema_infer_max_rec = 6; // Max records for schema inference + string date_format = 7; // Optional date format + string datetime_format = 8; // Optional datetime format + string timestamp_format = 9; // Optional timestamp format + string timestamp_tz_format = 10; // Optional timestamp with timezone format + string time_format = 11; // Optional time format + string null_value = 12; // Optional representation of null value + bytes comment = 13; // Optional comment character as a byte + bytes double_quote = 14; // Indicates if quotes are doubled + bytes newlines_in_values = 15; // Indicates if newlines are supported in values + bytes terminator = 16; // Optional terminator character as a byte +} + +// Options controlling CSV format +message JsonOptions { + CompressionTypeVariant compression = 1; // Compression type + uint64 schema_infer_max_rec = 2; // Max records for schema inference +} + +message TableParquetOptions { + ParquetOptions global = 1; + repeated ParquetColumnSpecificOptions column_specific_options = 2; + map key_value_metadata = 3; +} + +message ParquetColumnSpecificOptions { + string column_name = 1; + ParquetColumnOptions options = 2; +} + +message ParquetColumnOptions { + oneof bloom_filter_enabled_opt { + bool bloom_filter_enabled = 1; + } + + oneof encoding_opt { + string encoding = 2; + } + + oneof dictionary_enabled_opt { + bool dictionary_enabled = 3; + } + + oneof compression_opt { + string compression = 4; + } + + oneof statistics_enabled_opt { + string statistics_enabled = 5; + } + + oneof bloom_filter_fpp_opt { + double bloom_filter_fpp = 6; + } + + oneof bloom_filter_ndv_opt { + uint64 bloom_filter_ndv = 7; + } + + oneof max_statistics_size_opt { + uint32 max_statistics_size = 8; + } +} + +message ParquetOptions { + // Regular fields + bool enable_page_index = 1; // default = true + bool pruning = 2; // default = true + bool skip_metadata = 3; // default = true + bool pushdown_filters = 5; // default = false + bool reorder_filters = 6; // default = false + uint64 data_pagesize_limit = 7; // default = 1024 * 1024 + uint64 write_batch_size = 8; // default = 1024 + string writer_version = 9; // default = "1.0" + // bool bloom_filter_enabled = 20; // default = false + bool allow_single_file_parallelism = 23; // default = true + uint64 maximum_parallel_row_group_writers = 24; // default = 1 + uint64 maximum_buffered_record_batches_per_stream = 25; // default = 2 + bool bloom_filter_on_read = 26; // default = true + bool bloom_filter_on_write = 27; // default = false + bool schema_force_view_types = 28; // default = false + bool binary_as_string = 29; // default = false + + oneof metadata_size_hint_opt { + uint64 metadata_size_hint = 4; + } + + oneof compression_opt { + string compression = 10; + } + + oneof dictionary_enabled_opt { + bool dictionary_enabled = 11; + } + + oneof statistics_enabled_opt { + string statistics_enabled = 13; + } + + oneof max_statistics_size_opt { + uint64 max_statistics_size = 14; + } + + oneof column_index_truncate_length_opt { + uint64 column_index_truncate_length = 17; + } + + oneof encoding_opt { + string encoding = 19; + } + + oneof bloom_filter_fpp_opt { + double bloom_filter_fpp = 21; + } + + oneof bloom_filter_ndv_opt { + uint64 bloom_filter_ndv = 22; + } + + uint64 dictionary_page_size_limit = 12; + + uint64 data_page_row_count_limit = 18; + + uint64 max_row_group_size = 15; + + string created_by = 16; +} + +enum JoinSide { + LEFT_SIDE = 0; + RIGHT_SIDE = 1; + NONE = 2; +} + +message Precision{ + PrecisionInfo precision_info = 1; + ScalarValue val = 2; +} + +enum PrecisionInfo { + EXACT = 0; + INEXACT = 1; + ABSENT = 2; +} + +message Statistics { + Precision num_rows = 1; + Precision total_byte_size = 2; + repeated ColumnStats column_stats = 3; +} + +message ColumnStats { + Precision min_value = 1; + Precision max_value = 2; + Precision null_count = 3; + Precision distinct_count = 4; +} diff --git a/datafusion/proto-common/regen.sh b/datafusion/proto-common/regen.sh new file mode 100755 index 000000000000..2a5554b8752f --- /dev/null +++ b/datafusion/proto-common/regen.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +cd "$SCRIPT_DIR" && cargo run --manifest-path gen/Cargo.toml diff --git a/datafusion/proto-common/src/common.rs b/datafusion/proto-common/src/common.rs new file mode 100644 index 000000000000..61711dcf8e08 --- /dev/null +++ b/datafusion/proto-common/src/common.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion_common::{internal_datafusion_err, DataFusionError}; + +pub fn proto_error>(message: S) -> DataFusionError { + internal_datafusion_err!("{}", message.into()) +} diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs new file mode 100644 index 000000000000..a554e4ed2805 --- /dev/null +++ b/datafusion/proto-common/src/from_proto/mod.rs @@ -0,0 +1,1137 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashMap; +use std::convert::{TryFrom, TryInto}; +use std::sync::Arc; + +use crate::common::proto_error; +use crate::protobuf_common as protobuf; +use arrow::array::{ArrayRef, AsArray}; +use arrow::buffer::Buffer; +use arrow::csv::WriterBuilder; +use arrow::datatypes::{ + i256, DataType, Field, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, + Schema, TimeUnit, UnionFields, UnionMode, +}; +use arrow::ipc::{reader::read_record_batch, root_as_message}; + +use datafusion_common::{ + arrow_datafusion_err, + config::{ + CsvOptions, JsonOptions, ParquetColumnOptions, ParquetOptions, + TableParquetOptions, + }, + file_options::{csv_writer::CsvWriterOptions, json_writer::JsonWriterOptions}, + parsers::CompressionTypeVariant, + plan_datafusion_err, + stats::Precision, + Column, ColumnStatistics, Constraint, Constraints, DFSchema, DFSchemaRef, + DataFusionError, JoinSide, ScalarValue, Statistics, TableReference, +}; + +#[derive(Debug)] +pub enum Error { + General(String), + + DataFusionError(DataFusionError), + + MissingRequiredField(String), + + AtLeastOneValue(String), + + UnknownEnumVariant { name: String, value: i32 }, +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Self::General(desc) => write!(f, "General error: {desc}"), + + Self::DataFusionError(desc) => { + write!(f, "DataFusion error: {desc:?}") + } + + Self::MissingRequiredField(name) => { + write!(f, "Missing required field {name}") + } + Self::AtLeastOneValue(name) => { + write!(f, "Must have at least one {name}, found 0") + } + Self::UnknownEnumVariant { name, value } => { + write!(f, "Unknown i32 value for {name} enum: {value}") + } + } + } +} + +impl std::error::Error for Error {} + +impl From for Error { + fn from(e: DataFusionError) -> Self { + Error::DataFusionError(e) + } +} + +impl Error { + pub fn required(field: impl Into) -> Error { + Error::MissingRequiredField(field.into()) + } + + pub fn unknown(name: impl Into, value: i32) -> Error { + Error::UnknownEnumVariant { + name: name.into(), + value, + } + } +} + +impl From for DataFusionError { + fn from(e: Error) -> Self { + plan_datafusion_err!("{}", e) + } +} + +/// An extension trait that adds the methods `optional` and `required` to any +/// Option containing a type implementing `TryInto` +pub trait FromOptionalField { + /// Converts an optional protobuf field to an option of a different type + /// + /// Returns None if the option is None, otherwise calls [`TryInto::try_into`] + /// on the contained data, returning any error encountered + fn optional(self) -> datafusion_common::Result, Error>; + + /// Converts an optional protobuf field to a different type, returning an error if None + /// + /// Returns `Error::MissingRequiredField` if None, otherwise calls [`TryInto::try_into`] + /// on the contained data, returning any error encountered + fn required(self, field: impl Into) -> datafusion_common::Result; +} + +impl FromOptionalField for Option +where + T: TryInto, +{ + fn optional(self) -> datafusion_common::Result, Error> { + self.map(|t| t.try_into()).transpose() + } + + fn required(self, field: impl Into) -> datafusion_common::Result { + match self { + None => Err(Error::required(field)), + Some(t) => t.try_into(), + } + } +} + +impl From for Column { + fn from(c: protobuf::Column) -> Self { + let protobuf::Column { relation, name } = c; + + Self::new(relation.map(|r| r.relation), name) + } +} + +impl From<&protobuf::Column> for Column { + fn from(c: &protobuf::Column) -> Self { + c.clone().into() + } +} + +impl TryFrom<&protobuf::DfSchema> for DFSchema { + type Error = Error; + + fn try_from( + df_schema: &protobuf::DfSchema, + ) -> datafusion_common::Result { + let df_fields = df_schema.columns.clone(); + let qualifiers_and_fields: Vec<(Option, Arc)> = df_fields + .iter() + .map(|df_field| { + let field: Field = df_field.field.as_ref().required("field")?; + Ok(( + df_field + .qualifier + .as_ref() + .map(|q| q.relation.clone().into()), + Arc::new(field), + )) + }) + .collect::, Error>>()?; + + Ok(DFSchema::new_with_metadata( + qualifiers_and_fields, + df_schema.metadata.clone(), + )?) + } +} + +impl TryFrom for DFSchemaRef { + type Error = Error; + + fn try_from( + df_schema: protobuf::DfSchema, + ) -> datafusion_common::Result { + let dfschema: DFSchema = (&df_schema).try_into()?; + Ok(Arc::new(dfschema)) + } +} + +impl TryFrom<&protobuf::ArrowType> for DataType { + type Error = Error; + + fn try_from( + arrow_type: &protobuf::ArrowType, + ) -> datafusion_common::Result { + arrow_type + .arrow_type_enum + .as_ref() + .required("arrow_type_enum") + } +} + +impl TryFrom<&protobuf::arrow_type::ArrowTypeEnum> for DataType { + type Error = Error; + fn try_from( + arrow_type_enum: &protobuf::arrow_type::ArrowTypeEnum, + ) -> datafusion_common::Result { + use protobuf::arrow_type; + Ok(match arrow_type_enum { + arrow_type::ArrowTypeEnum::None(_) => DataType::Null, + arrow_type::ArrowTypeEnum::Bool(_) => DataType::Boolean, + arrow_type::ArrowTypeEnum::Uint8(_) => DataType::UInt8, + arrow_type::ArrowTypeEnum::Int8(_) => DataType::Int8, + arrow_type::ArrowTypeEnum::Uint16(_) => DataType::UInt16, + arrow_type::ArrowTypeEnum::Int16(_) => DataType::Int16, + arrow_type::ArrowTypeEnum::Uint32(_) => DataType::UInt32, + arrow_type::ArrowTypeEnum::Int32(_) => DataType::Int32, + arrow_type::ArrowTypeEnum::Uint64(_) => DataType::UInt64, + arrow_type::ArrowTypeEnum::Int64(_) => DataType::Int64, + arrow_type::ArrowTypeEnum::Float16(_) => DataType::Float16, + arrow_type::ArrowTypeEnum::Float32(_) => DataType::Float32, + arrow_type::ArrowTypeEnum::Float64(_) => DataType::Float64, + arrow_type::ArrowTypeEnum::Utf8(_) => DataType::Utf8, + arrow_type::ArrowTypeEnum::Utf8View(_) => DataType::Utf8View, + arrow_type::ArrowTypeEnum::LargeUtf8(_) => DataType::LargeUtf8, + arrow_type::ArrowTypeEnum::Binary(_) => DataType::Binary, + arrow_type::ArrowTypeEnum::BinaryView(_) => DataType::BinaryView, + arrow_type::ArrowTypeEnum::FixedSizeBinary(size) => { + DataType::FixedSizeBinary(*size) + } + arrow_type::ArrowTypeEnum::LargeBinary(_) => DataType::LargeBinary, + arrow_type::ArrowTypeEnum::Date32(_) => DataType::Date32, + arrow_type::ArrowTypeEnum::Date64(_) => DataType::Date64, + arrow_type::ArrowTypeEnum::Duration(time_unit) => { + DataType::Duration(parse_i32_to_time_unit(time_unit)?) + } + arrow_type::ArrowTypeEnum::Timestamp(protobuf::Timestamp { + time_unit, + timezone, + }) => DataType::Timestamp( + parse_i32_to_time_unit(time_unit)?, + match timezone.len() { + 0 => None, + _ => Some(timezone.as_str().into()), + }, + ), + arrow_type::ArrowTypeEnum::Time32(time_unit) => { + DataType::Time32(parse_i32_to_time_unit(time_unit)?) + } + arrow_type::ArrowTypeEnum::Time64(time_unit) => { + DataType::Time64(parse_i32_to_time_unit(time_unit)?) + } + arrow_type::ArrowTypeEnum::Interval(interval_unit) => { + DataType::Interval(parse_i32_to_interval_unit(interval_unit)?) + } + arrow_type::ArrowTypeEnum::Decimal(protobuf::Decimal { + precision, + scale, + }) => DataType::Decimal128(*precision as u8, *scale as i8), + arrow_type::ArrowTypeEnum::Decimal256(protobuf::Decimal256Type { + precision, + scale, + }) => DataType::Decimal256(*precision as u8, *scale as i8), + arrow_type::ArrowTypeEnum::List(list) => { + let list_type = + list.as_ref().field_type.as_deref().required("field_type")?; + DataType::List(Arc::new(list_type)) + } + arrow_type::ArrowTypeEnum::LargeList(list) => { + let list_type = + list.as_ref().field_type.as_deref().required("field_type")?; + DataType::LargeList(Arc::new(list_type)) + } + arrow_type::ArrowTypeEnum::FixedSizeList(list) => { + let list_type = + list.as_ref().field_type.as_deref().required("field_type")?; + let list_size = list.list_size; + DataType::FixedSizeList(Arc::new(list_type), list_size) + } + arrow_type::ArrowTypeEnum::Struct(strct) => DataType::Struct( + parse_proto_fields_to_fields(&strct.sub_field_types)?.into(), + ), + arrow_type::ArrowTypeEnum::Union(union) => { + let union_mode = protobuf::UnionMode::try_from(union.union_mode) + .map_err(|_| Error::unknown("UnionMode", union.union_mode))?; + let union_mode = match union_mode { + protobuf::UnionMode::Dense => UnionMode::Dense, + protobuf::UnionMode::Sparse => UnionMode::Sparse, + }; + let union_fields = parse_proto_fields_to_fields(&union.union_types)?; + + // Default to index based type ids if not provided + let type_ids: Vec<_> = match union.type_ids.is_empty() { + true => (0..union_fields.len() as i8).collect(), + false => union.type_ids.iter().map(|i| *i as i8).collect(), + }; + + DataType::Union(UnionFields::new(type_ids, union_fields), union_mode) + } + arrow_type::ArrowTypeEnum::Dictionary(dict) => { + let key_datatype = dict.as_ref().key.as_deref().required("key")?; + let value_datatype = dict.as_ref().value.as_deref().required("value")?; + DataType::Dictionary(Box::new(key_datatype), Box::new(value_datatype)) + } + arrow_type::ArrowTypeEnum::Map(map) => { + let field: Field = + map.as_ref().field_type.as_deref().required("field_type")?; + let keys_sorted = map.keys_sorted; + DataType::Map(Arc::new(field), keys_sorted) + } + }) + } +} + +impl TryFrom<&protobuf::Field> for Field { + type Error = Error; + fn try_from(field: &protobuf::Field) -> Result { + let datatype = field.arrow_type.as_deref().required("arrow_type")?; + let field = if field.dict_id != 0 { + Self::new_dict( + field.name.as_str(), + datatype, + field.nullable, + field.dict_id, + field.dict_ordered, + ) + .with_metadata(field.metadata.clone()) + } else { + Self::new(field.name.as_str(), datatype, field.nullable) + .with_metadata(field.metadata.clone()) + }; + Ok(field) + } +} + +impl TryFrom<&protobuf::Schema> for Schema { + type Error = Error; + + fn try_from( + schema: &protobuf::Schema, + ) -> datafusion_common::Result { + let fields = schema + .columns + .iter() + .map(Field::try_from) + .collect::, _>>()?; + Ok(Self::new_with_metadata(fields, schema.metadata.clone())) + } +} + +impl TryFrom<&protobuf::ScalarValue> for ScalarValue { + type Error = Error; + + fn try_from( + scalar: &protobuf::ScalarValue, + ) -> datafusion_common::Result { + use protobuf::scalar_value::Value; + + let value = scalar + .value + .as_ref() + .ok_or_else(|| Error::required("value"))?; + + Ok(match value { + Value::BoolValue(v) => Self::Boolean(Some(*v)), + Value::Utf8Value(v) => Self::Utf8(Some(v.to_owned())), + Value::Utf8ViewValue(v) => Self::Utf8View(Some(v.to_owned())), + Value::LargeUtf8Value(v) => Self::LargeUtf8(Some(v.to_owned())), + Value::Int8Value(v) => Self::Int8(Some(*v as i8)), + Value::Int16Value(v) => Self::Int16(Some(*v as i16)), + Value::Int32Value(v) => Self::Int32(Some(*v)), + Value::Int64Value(v) => Self::Int64(Some(*v)), + Value::Uint8Value(v) => Self::UInt8(Some(*v as u8)), + Value::Uint16Value(v) => Self::UInt16(Some(*v as u16)), + Value::Uint32Value(v) => Self::UInt32(Some(*v)), + Value::Uint64Value(v) => Self::UInt64(Some(*v)), + Value::Float32Value(v) => Self::Float32(Some(*v)), + Value::Float64Value(v) => Self::Float64(Some(*v)), + Value::Date32Value(v) => Self::Date32(Some(*v)), + // ScalarValue::List is serialized using arrow IPC format + Value::ListValue(v) + | Value::FixedSizeListValue(v) + | Value::LargeListValue(v) + | Value::StructValue(v) + | Value::MapValue(v) => { + let protobuf::ScalarNestedValue { + ipc_message, + arrow_data, + dictionaries, + schema, + } = &v; + + let schema: Schema = if let Some(schema_ref) = schema { + schema_ref.try_into()? + } else { + return Err(Error::General( + "Invalid schema while deserializing ScalarValue::List" + .to_string(), + )); + }; + + let message = root_as_message(ipc_message.as_slice()).map_err(|e| { + Error::General(format!( + "Error IPC message while deserializing ScalarValue::List: {e}" + )) + })?; + let buffer = Buffer::from(arrow_data.as_slice()); + + let ipc_batch = message.header_as_record_batch().ok_or_else(|| { + Error::General( + "Unexpected message type deserializing ScalarValue::List" + .to_string(), + ) + })?; + + let dict_by_id: HashMap = dictionaries.iter().map(|protobuf::scalar_nested_value::Dictionary { ipc_message, arrow_data }| { + let message = root_as_message(ipc_message.as_slice()).map_err(|e| { + Error::General(format!( + "Error IPC message while deserializing ScalarValue::List dictionary message: {e}" + )) + })?; + let buffer = Buffer::from(arrow_data.as_slice()); + + let dict_batch = message.header_as_dictionary_batch().ok_or_else(|| { + Error::General( + "Unexpected message type deserializing ScalarValue::List dictionary message" + .to_string(), + ) + })?; + + let id = dict_batch.id(); + + let fields_using_this_dictionary = schema.fields_with_dict_id(id); + let first_field = fields_using_this_dictionary.first().ok_or_else(|| { + Error::General("dictionary id not found in schema while deserializing ScalarValue::List".to_string()) + })?; + + let values: ArrayRef = match first_field.data_type() { + DataType::Dictionary(_, ref value_type) => { + // Make a fake schema for the dictionary batch. + let value = value_type.as_ref().clone(); + let schema = Schema::new(vec![Field::new("", value, true)]); + // Read a single column + let record_batch = read_record_batch( + &buffer, + dict_batch.data().unwrap(), + Arc::new(schema), + &Default::default(), + None, + &message.version(), + )?; + Ok(Arc::clone(record_batch.column(0))) + } + _ => Err(Error::General("dictionary id not found in schema while deserializing ScalarValue::List".to_string())), + }?; + + Ok((id,values)) + }).collect::>>()?; + + let record_batch = read_record_batch( + &buffer, + ipc_batch, + Arc::new(schema), + &dict_by_id, + None, + &message.version(), + ) + .map_err(|e| arrow_datafusion_err!(e)) + .map_err(|e| e.context("Decoding ScalarValue::List Value"))?; + let arr = record_batch.column(0); + match value { + Value::ListValue(_) => { + Self::List(arr.as_list::().to_owned().into()) + } + Value::LargeListValue(_) => { + Self::LargeList(arr.as_list::().to_owned().into()) + } + Value::FixedSizeListValue(_) => { + Self::FixedSizeList(arr.as_fixed_size_list().to_owned().into()) + } + Value::StructValue(_) => { + Self::Struct(arr.as_struct().to_owned().into()) + } + Value::MapValue(_) => Self::Map(arr.as_map().to_owned().into()), + _ => unreachable!(), + } + } + Value::NullValue(v) => { + let null_type: DataType = v.try_into()?; + null_type.try_into().map_err(Error::DataFusionError)? + } + Value::Decimal128Value(val) => { + let array = vec_to_array(val.value.clone()); + Self::Decimal128( + Some(i128::from_be_bytes(array)), + val.p as u8, + val.s as i8, + ) + } + Value::Decimal256Value(val) => { + let array = vec_to_array(val.value.clone()); + Self::Decimal256( + Some(i256::from_be_bytes(array)), + val.p as u8, + val.s as i8, + ) + } + Value::Date64Value(v) => Self::Date64(Some(*v)), + Value::Time32Value(v) => { + let time_value = + v.value.as_ref().ok_or_else(|| Error::required("value"))?; + match time_value { + protobuf::scalar_time32_value::Value::Time32SecondValue(t) => { + Self::Time32Second(Some(*t)) + } + protobuf::scalar_time32_value::Value::Time32MillisecondValue(t) => { + Self::Time32Millisecond(Some(*t)) + } + } + } + Value::Time64Value(v) => { + let time_value = + v.value.as_ref().ok_or_else(|| Error::required("value"))?; + match time_value { + protobuf::scalar_time64_value::Value::Time64MicrosecondValue(t) => { + Self::Time64Microsecond(Some(*t)) + } + protobuf::scalar_time64_value::Value::Time64NanosecondValue(t) => { + Self::Time64Nanosecond(Some(*t)) + } + } + } + Value::IntervalYearmonthValue(v) => Self::IntervalYearMonth(Some(*v)), + Value::DurationSecondValue(v) => Self::DurationSecond(Some(*v)), + Value::DurationMillisecondValue(v) => Self::DurationMillisecond(Some(*v)), + Value::DurationMicrosecondValue(v) => Self::DurationMicrosecond(Some(*v)), + Value::DurationNanosecondValue(v) => Self::DurationNanosecond(Some(*v)), + Value::TimestampValue(v) => { + let timezone = if v.timezone.is_empty() { + None + } else { + Some(v.timezone.as_str().into()) + }; + + let ts_value = + v.value.as_ref().ok_or_else(|| Error::required("value"))?; + + match ts_value { + protobuf::scalar_timestamp_value::Value::TimeMicrosecondValue(t) => { + Self::TimestampMicrosecond(Some(*t), timezone) + } + protobuf::scalar_timestamp_value::Value::TimeNanosecondValue(t) => { + Self::TimestampNanosecond(Some(*t), timezone) + } + protobuf::scalar_timestamp_value::Value::TimeSecondValue(t) => { + Self::TimestampSecond(Some(*t), timezone) + } + protobuf::scalar_timestamp_value::Value::TimeMillisecondValue(t) => { + Self::TimestampMillisecond(Some(*t), timezone) + } + } + } + Value::DictionaryValue(v) => { + let index_type: DataType = v + .index_type + .as_ref() + .ok_or_else(|| Error::required("index_type"))? + .try_into()?; + + let value: Self = v + .value + .as_ref() + .ok_or_else(|| Error::required("value"))? + .as_ref() + .try_into()?; + + Self::Dictionary(Box::new(index_type), Box::new(value)) + } + Value::BinaryValue(v) => Self::Binary(Some(v.clone())), + Value::BinaryViewValue(v) => Self::BinaryView(Some(v.clone())), + Value::LargeBinaryValue(v) => Self::LargeBinary(Some(v.clone())), + Value::IntervalDaytimeValue(v) => Self::IntervalDayTime(Some( + IntervalDayTimeType::make_value(v.days, v.milliseconds), + )), + Value::IntervalMonthDayNano(v) => Self::IntervalMonthDayNano(Some( + IntervalMonthDayNanoType::make_value(v.months, v.days, v.nanos), + )), + Value::UnionValue(val) => { + let mode = match val.mode { + 0 => UnionMode::Sparse, + 1 => UnionMode::Dense, + id => Err(Error::unknown("UnionMode", id))?, + }; + let ids = val + .fields + .iter() + .map(|f| f.field_id as i8) + .collect::>(); + let fields = val + .fields + .iter() + .map(|f| f.field.clone()) + .collect::>>(); + let fields = fields.ok_or_else(|| Error::required("UnionField"))?; + let fields = parse_proto_fields_to_fields(&fields)?; + let fields = UnionFields::new(ids, fields); + let v_id = val.value_id as i8; + let val = match &val.value { + None => None, + Some(val) => { + let val: ScalarValue = val + .as_ref() + .try_into() + .map_err(|_| Error::General("Invalid Scalar".to_string()))?; + Some((v_id, Box::new(val))) + } + }; + Self::Union(val, fields, mode) + } + Value::FixedSizeBinaryValue(v) => { + Self::FixedSizeBinary(v.length, Some(v.clone().values)) + } + }) + } +} + +impl From for TimeUnit { + fn from(time_unit: protobuf::TimeUnit) -> Self { + match time_unit { + protobuf::TimeUnit::Second => TimeUnit::Second, + protobuf::TimeUnit::Millisecond => TimeUnit::Millisecond, + protobuf::TimeUnit::Microsecond => TimeUnit::Microsecond, + protobuf::TimeUnit::Nanosecond => TimeUnit::Nanosecond, + } + } +} + +impl From for IntervalUnit { + fn from(interval_unit: protobuf::IntervalUnit) -> Self { + match interval_unit { + protobuf::IntervalUnit::YearMonth => IntervalUnit::YearMonth, + protobuf::IntervalUnit::DayTime => IntervalUnit::DayTime, + protobuf::IntervalUnit::MonthDayNano => IntervalUnit::MonthDayNano, + } + } +} + +impl From for Constraints { + fn from(constraints: protobuf::Constraints) -> Self { + Constraints::new_unverified( + constraints + .constraints + .into_iter() + .map(|item| item.into()) + .collect(), + ) + } +} + +impl From for Constraint { + fn from(value: protobuf::Constraint) -> Self { + match value.constraint_mode.unwrap() { + protobuf::constraint::ConstraintMode::PrimaryKey(elem) => { + Constraint::PrimaryKey( + elem.indices.into_iter().map(|item| item as usize).collect(), + ) + } + protobuf::constraint::ConstraintMode::Unique(elem) => Constraint::Unique( + elem.indices.into_iter().map(|item| item as usize).collect(), + ), + } + } +} + +impl From<&protobuf::ColumnStats> for ColumnStatistics { + fn from(cs: &protobuf::ColumnStats) -> ColumnStatistics { + ColumnStatistics { + null_count: if let Some(nc) = &cs.null_count { + nc.clone().into() + } else { + Precision::Absent + }, + max_value: if let Some(max) = &cs.max_value { + max.clone().into() + } else { + Precision::Absent + }, + min_value: if let Some(min) = &cs.min_value { + min.clone().into() + } else { + Precision::Absent + }, + distinct_count: if let Some(dc) = &cs.distinct_count { + dc.clone().into() + } else { + Precision::Absent + }, + } + } +} + +impl From for Precision { + fn from(s: protobuf::Precision) -> Self { + let Ok(precision_type) = s.precision_info.try_into() else { + return Precision::Absent; + }; + match precision_type { + protobuf::PrecisionInfo::Exact => { + if let Some(val) = s.val { + if let Ok(ScalarValue::UInt64(Some(val))) = + ScalarValue::try_from(&val) + { + Precision::Exact(val as usize) + } else { + Precision::Absent + } + } else { + Precision::Absent + } + } + protobuf::PrecisionInfo::Inexact => { + if let Some(val) = s.val { + if let Ok(ScalarValue::UInt64(Some(val))) = + ScalarValue::try_from(&val) + { + Precision::Inexact(val as usize) + } else { + Precision::Absent + } + } else { + Precision::Absent + } + } + protobuf::PrecisionInfo::Absent => Precision::Absent, + } + } +} + +impl From for Precision { + fn from(s: protobuf::Precision) -> Self { + let Ok(precision_type) = s.precision_info.try_into() else { + return Precision::Absent; + }; + match precision_type { + protobuf::PrecisionInfo::Exact => { + if let Some(val) = s.val { + if let Ok(val) = ScalarValue::try_from(&val) { + Precision::Exact(val) + } else { + Precision::Absent + } + } else { + Precision::Absent + } + } + protobuf::PrecisionInfo::Inexact => { + if let Some(val) = s.val { + if let Ok(val) = ScalarValue::try_from(&val) { + Precision::Inexact(val) + } else { + Precision::Absent + } + } else { + Precision::Absent + } + } + protobuf::PrecisionInfo::Absent => Precision::Absent, + } + } +} + +impl From for JoinSide { + fn from(t: protobuf::JoinSide) -> Self { + match t { + protobuf::JoinSide::LeftSide => JoinSide::Left, + protobuf::JoinSide::RightSide => JoinSide::Right, + protobuf::JoinSide::None => JoinSide::None, + } + } +} + +impl TryFrom<&protobuf::Statistics> for Statistics { + type Error = DataFusionError; + + fn try_from( + s: &protobuf::Statistics, + ) -> datafusion_common::Result { + // Keep it sync with Statistics::to_proto + Ok(Statistics { + num_rows: if let Some(nr) = &s.num_rows { + nr.clone().into() + } else { + Precision::Absent + }, + total_byte_size: if let Some(tbs) = &s.total_byte_size { + tbs.clone().into() + } else { + Precision::Absent + }, + // No column statistic (None) is encoded with empty array + column_statistics: s.column_stats.iter().map(|s| s.into()).collect(), + }) + } +} + +impl From for CompressionTypeVariant { + fn from(value: protobuf::CompressionTypeVariant) -> Self { + match value { + protobuf::CompressionTypeVariant::Gzip => Self::GZIP, + protobuf::CompressionTypeVariant::Bzip2 => Self::BZIP2, + protobuf::CompressionTypeVariant::Xz => Self::XZ, + protobuf::CompressionTypeVariant::Zstd => Self::ZSTD, + protobuf::CompressionTypeVariant::Uncompressed => Self::UNCOMPRESSED, + } + } +} + +impl From for protobuf::CompressionTypeVariant { + fn from(value: CompressionTypeVariant) -> Self { + match value { + CompressionTypeVariant::GZIP => Self::Gzip, + CompressionTypeVariant::BZIP2 => Self::Bzip2, + CompressionTypeVariant::XZ => Self::Xz, + CompressionTypeVariant::ZSTD => Self::Zstd, + CompressionTypeVariant::UNCOMPRESSED => Self::Uncompressed, + } + } +} + +impl TryFrom<&protobuf::CsvWriterOptions> for CsvWriterOptions { + type Error = DataFusionError; + + fn try_from( + opts: &protobuf::CsvWriterOptions, + ) -> datafusion_common::Result { + let write_options = csv_writer_options_from_proto(opts)?; + let compression: CompressionTypeVariant = opts.compression().into(); + Ok(CsvWriterOptions::new(write_options, compression)) + } +} + +impl TryFrom<&protobuf::JsonWriterOptions> for JsonWriterOptions { + type Error = DataFusionError; + + fn try_from( + opts: &protobuf::JsonWriterOptions, + ) -> datafusion_common::Result { + let compression: CompressionTypeVariant = opts.compression().into(); + Ok(JsonWriterOptions::new(compression)) + } +} + +impl TryFrom<&protobuf::CsvOptions> for CsvOptions { + type Error = DataFusionError; + + fn try_from( + proto_opts: &protobuf::CsvOptions, + ) -> datafusion_common::Result { + Ok(CsvOptions { + has_header: proto_opts.has_header.first().map(|h| *h != 0), + delimiter: proto_opts.delimiter[0], + quote: proto_opts.quote[0], + terminator: proto_opts.terminator.first().copied(), + escape: proto_opts.escape.first().copied(), + double_quote: proto_opts.has_header.first().map(|h| *h != 0), + newlines_in_values: proto_opts.newlines_in_values.first().map(|h| *h != 0), + compression: proto_opts.compression().into(), + schema_infer_max_rec: proto_opts.schema_infer_max_rec as usize, + date_format: (!proto_opts.date_format.is_empty()) + .then(|| proto_opts.date_format.clone()), + datetime_format: (!proto_opts.datetime_format.is_empty()) + .then(|| proto_opts.datetime_format.clone()), + timestamp_format: (!proto_opts.timestamp_format.is_empty()) + .then(|| proto_opts.timestamp_format.clone()), + timestamp_tz_format: (!proto_opts.timestamp_tz_format.is_empty()) + .then(|| proto_opts.timestamp_tz_format.clone()), + time_format: (!proto_opts.time_format.is_empty()) + .then(|| proto_opts.time_format.clone()), + null_value: (!proto_opts.null_value.is_empty()) + .then(|| proto_opts.null_value.clone()), + comment: proto_opts.comment.first().copied(), + }) + } +} + +impl TryFrom<&protobuf::ParquetOptions> for ParquetOptions { + type Error = DataFusionError; + + fn try_from( + value: &protobuf::ParquetOptions, + ) -> datafusion_common::Result { + Ok(ParquetOptions { + enable_page_index: value.enable_page_index, + pruning: value.pruning, + skip_metadata: value.skip_metadata, + metadata_size_hint: value + .metadata_size_hint_opt + .map(|opt| match opt { + protobuf::parquet_options::MetadataSizeHintOpt::MetadataSizeHint(v) => Some(v as usize), + }) + .unwrap_or(None), + pushdown_filters: value.pushdown_filters, + reorder_filters: value.reorder_filters, + data_pagesize_limit: value.data_pagesize_limit as usize, + write_batch_size: value.write_batch_size as usize, + writer_version: value.writer_version.clone(), + compression: value.compression_opt.clone().map(|opt| match opt { + protobuf::parquet_options::CompressionOpt::Compression(v) => Some(v), + }).unwrap_or(None), + dictionary_enabled: value.dictionary_enabled_opt.as_ref().map(|protobuf::parquet_options::DictionaryEnabledOpt::DictionaryEnabled(v)| *v), + // Continuing from where we left off in the TryFrom implementation + dictionary_page_size_limit: value.dictionary_page_size_limit as usize, + statistics_enabled: value + .statistics_enabled_opt.clone() + .map(|opt| match opt { + protobuf::parquet_options::StatisticsEnabledOpt::StatisticsEnabled(v) => Some(v), + }) + .unwrap_or(None), + max_statistics_size: value + .max_statistics_size_opt.as_ref() + .map(|opt| match opt { + protobuf::parquet_options::MaxStatisticsSizeOpt::MaxStatisticsSize(v) => Some(*v as usize), + }) + .unwrap_or(None), + max_row_group_size: value.max_row_group_size as usize, + created_by: value.created_by.clone(), + column_index_truncate_length: value + .column_index_truncate_length_opt.as_ref() + .map(|opt| match opt { + protobuf::parquet_options::ColumnIndexTruncateLengthOpt::ColumnIndexTruncateLength(v) => Some(*v as usize), + }) + .unwrap_or(None), + data_page_row_count_limit: value.data_page_row_count_limit as usize, + encoding: value + .encoding_opt.clone() + .map(|opt| match opt { + protobuf::parquet_options::EncodingOpt::Encoding(v) => Some(v), + }) + .unwrap_or(None), + bloom_filter_on_read: value.bloom_filter_on_read, + bloom_filter_on_write: value.bloom_filter_on_write, + bloom_filter_fpp: value.clone() + .bloom_filter_fpp_opt + .map(|opt| match opt { + protobuf::parquet_options::BloomFilterFppOpt::BloomFilterFpp(v) => Some(v), + }) + .unwrap_or(None), + bloom_filter_ndv: value.clone() + .bloom_filter_ndv_opt + .map(|opt| match opt { + protobuf::parquet_options::BloomFilterNdvOpt::BloomFilterNdv(v) => Some(v), + }) + .unwrap_or(None), + allow_single_file_parallelism: value.allow_single_file_parallelism, + maximum_parallel_row_group_writers: value.maximum_parallel_row_group_writers as usize, + maximum_buffered_record_batches_per_stream: value.maximum_buffered_record_batches_per_stream as usize, + schema_force_view_types: value.schema_force_view_types, + binary_as_string: value.binary_as_string, + }) + } +} + +impl TryFrom<&protobuf::ParquetColumnOptions> for ParquetColumnOptions { + type Error = DataFusionError; + fn try_from( + value: &protobuf::ParquetColumnOptions, + ) -> datafusion_common::Result { + Ok(ParquetColumnOptions { + compression: value.compression_opt.clone().map(|opt| match opt { + protobuf::parquet_column_options::CompressionOpt::Compression(v) => Some(v), + }).unwrap_or(None), + dictionary_enabled: value.dictionary_enabled_opt.as_ref().map(|protobuf::parquet_column_options::DictionaryEnabledOpt::DictionaryEnabled(v)| *v), + statistics_enabled: value + .statistics_enabled_opt.clone() + .map(|opt| match opt { + protobuf::parquet_column_options::StatisticsEnabledOpt::StatisticsEnabled(v) => Some(v), + }) + .unwrap_or(None), + max_statistics_size: value + .max_statistics_size_opt + .map(|opt| match opt { + protobuf::parquet_column_options::MaxStatisticsSizeOpt::MaxStatisticsSize(v) => Some(v as usize), + }) + .unwrap_or(None), + encoding: value + .encoding_opt.clone() + .map(|opt| match opt { + protobuf::parquet_column_options::EncodingOpt::Encoding(v) => Some(v), + }) + .unwrap_or(None), + bloom_filter_enabled: value.bloom_filter_enabled_opt.map(|opt| match opt { + protobuf::parquet_column_options::BloomFilterEnabledOpt::BloomFilterEnabled(v) => Some(v), + }) + .unwrap_or(None), + bloom_filter_fpp: value + .bloom_filter_fpp_opt + .map(|opt| match opt { + protobuf::parquet_column_options::BloomFilterFppOpt::BloomFilterFpp(v) => Some(v), + }) + .unwrap_or(None), + bloom_filter_ndv: value + .bloom_filter_ndv_opt + .map(|opt| match opt { + protobuf::parquet_column_options::BloomFilterNdvOpt::BloomFilterNdv(v) => Some(v), + }) + .unwrap_or(None), + }) + } +} + +impl TryFrom<&protobuf::TableParquetOptions> for TableParquetOptions { + type Error = DataFusionError; + fn try_from( + value: &protobuf::TableParquetOptions, + ) -> datafusion_common::Result { + let mut column_specific_options: HashMap = + HashMap::new(); + for protobuf::ParquetColumnSpecificOptions { + column_name, + options: maybe_options, + } in &value.column_specific_options + { + if let Some(options) = maybe_options { + column_specific_options.insert(column_name.clone(), options.try_into()?); + } + } + Ok(TableParquetOptions { + global: value + .global + .as_ref() + .map(|v| v.try_into()) + .unwrap() + .unwrap(), + column_specific_options, + key_value_metadata: Default::default(), + }) + } +} + +impl TryFrom<&protobuf::JsonOptions> for JsonOptions { + type Error = DataFusionError; + + fn try_from( + proto_opts: &protobuf::JsonOptions, + ) -> datafusion_common::Result { + let compression: protobuf::CompressionTypeVariant = proto_opts.compression(); + Ok(JsonOptions { + compression: compression.into(), + schema_infer_max_rec: proto_opts.schema_infer_max_rec as usize, + }) + } +} + +pub fn parse_i32_to_time_unit(value: &i32) -> datafusion_common::Result { + protobuf::TimeUnit::try_from(*value) + .map(|t| t.into()) + .map_err(|_| Error::unknown("TimeUnit", *value)) +} + +pub fn parse_i32_to_interval_unit( + value: &i32, +) -> datafusion_common::Result { + protobuf::IntervalUnit::try_from(*value) + .map(|t| t.into()) + .map_err(|_| Error::unknown("IntervalUnit", *value)) +} + +// panic here because no better way to convert from Vec to Array +fn vec_to_array(v: Vec) -> [T; N] { + v.try_into().unwrap_or_else(|v: Vec| { + panic!("Expected a Vec of length {} but it was {}", N, v.len()) + }) +} + +/// Converts a vector of `protobuf::Field`s to `Arc`s. +pub fn parse_proto_fields_to_fields<'a, I>( + fields: I, +) -> std::result::Result, Error> +where + I: IntoIterator, +{ + fields + .into_iter() + .map(Field::try_from) + .collect::>() +} + +pub(crate) fn csv_writer_options_from_proto( + writer_options: &protobuf::CsvWriterOptions, +) -> datafusion_common::Result { + let mut builder = WriterBuilder::new(); + if !writer_options.delimiter.is_empty() { + if let Some(delimiter) = writer_options.delimiter.chars().next() { + if delimiter.is_ascii() { + builder = builder.with_delimiter(delimiter as u8); + } else { + return Err(proto_error("CSV Delimiter is not ASCII")); + } + } else { + return Err(proto_error("Error parsing CSV Delimiter")); + } + } + if !writer_options.quote.is_empty() { + if let Some(quote) = writer_options.quote.chars().next() { + if quote.is_ascii() { + builder = builder.with_quote(quote as u8); + } else { + return Err(proto_error("CSV Quote is not ASCII")); + } + } else { + return Err(proto_error("Error parsing CSV Quote")); + } + } + if !writer_options.escape.is_empty() { + if let Some(escape) = writer_options.escape.chars().next() { + if escape.is_ascii() { + builder = builder.with_escape(escape as u8); + } else { + return Err(proto_error("CSV Escape is not ASCII")); + } + } else { + return Err(proto_error("Error parsing CSV Escape")); + } + } + Ok(builder + .with_header(writer_options.has_header) + .with_date_format(writer_options.date_format.clone()) + .with_datetime_format(writer_options.datetime_format.clone()) + .with_timestamp_format(writer_options.timestamp_format.clone()) + .with_time_format(writer_options.time_format.clone()) + .with_null(writer_options.null_value.clone()) + .with_double_quote(writer_options.double_quote)) +} diff --git a/datafusion/proto-common/src/generated/mod.rs b/datafusion/proto-common/src/generated/mod.rs new file mode 100644 index 000000000000..24a062e4cad5 --- /dev/null +++ b/datafusion/proto-common/src/generated/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[allow(clippy::all)] +#[rustfmt::skip] +pub mod datafusion_proto_common { + include!("prost.rs"); + + #[cfg(feature = "json")] + include!("pbjson.rs"); +} diff --git a/datafusion/proto-common/src/generated/pbjson.rs b/datafusion/proto-common/src/generated/pbjson.rs new file mode 100644 index 000000000000..e8235ef7b9dd --- /dev/null +++ b/datafusion/proto-common/src/generated/pbjson.rs @@ -0,0 +1,8514 @@ +impl serde::Serialize for ArrowOptions { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let len = 0; + let struct_ser = serializer.serialize_struct("datafusion_common.ArrowOptions", len)?; + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ArrowOptions { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + Err(serde::de::Error::unknown_field(value, FIELDS)) + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ArrowOptions; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.ArrowOptions") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + while map_.next_key::()?.is_some() { + let _ = map_.next_value::()?; + } + Ok(ArrowOptions { + }) + } + } + deserializer.deserialize_struct("datafusion_common.ArrowOptions", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ArrowType { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.arrow_type_enum.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.ArrowType", len)?; + if let Some(v) = self.arrow_type_enum.as_ref() { + match v { + arrow_type::ArrowTypeEnum::None(v) => { + struct_ser.serialize_field("NONE", v)?; + } + arrow_type::ArrowTypeEnum::Bool(v) => { + struct_ser.serialize_field("BOOL", v)?; + } + arrow_type::ArrowTypeEnum::Uint8(v) => { + struct_ser.serialize_field("UINT8", v)?; + } + arrow_type::ArrowTypeEnum::Int8(v) => { + struct_ser.serialize_field("INT8", v)?; + } + arrow_type::ArrowTypeEnum::Uint16(v) => { + struct_ser.serialize_field("UINT16", v)?; + } + arrow_type::ArrowTypeEnum::Int16(v) => { + struct_ser.serialize_field("INT16", v)?; + } + arrow_type::ArrowTypeEnum::Uint32(v) => { + struct_ser.serialize_field("UINT32", v)?; + } + arrow_type::ArrowTypeEnum::Int32(v) => { + struct_ser.serialize_field("INT32", v)?; + } + arrow_type::ArrowTypeEnum::Uint64(v) => { + struct_ser.serialize_field("UINT64", v)?; + } + arrow_type::ArrowTypeEnum::Int64(v) => { + struct_ser.serialize_field("INT64", v)?; + } + arrow_type::ArrowTypeEnum::Float16(v) => { + struct_ser.serialize_field("FLOAT16", v)?; + } + arrow_type::ArrowTypeEnum::Float32(v) => { + struct_ser.serialize_field("FLOAT32", v)?; + } + arrow_type::ArrowTypeEnum::Float64(v) => { + struct_ser.serialize_field("FLOAT64", v)?; + } + arrow_type::ArrowTypeEnum::Utf8(v) => { + struct_ser.serialize_field("UTF8", v)?; + } + arrow_type::ArrowTypeEnum::Utf8View(v) => { + struct_ser.serialize_field("UTF8VIEW", v)?; + } + arrow_type::ArrowTypeEnum::LargeUtf8(v) => { + struct_ser.serialize_field("LARGEUTF8", v)?; + } + arrow_type::ArrowTypeEnum::Binary(v) => { + struct_ser.serialize_field("BINARY", v)?; + } + arrow_type::ArrowTypeEnum::BinaryView(v) => { + struct_ser.serialize_field("BINARYVIEW", v)?; + } + arrow_type::ArrowTypeEnum::FixedSizeBinary(v) => { + struct_ser.serialize_field("FIXEDSIZEBINARY", v)?; + } + arrow_type::ArrowTypeEnum::LargeBinary(v) => { + struct_ser.serialize_field("LARGEBINARY", v)?; + } + arrow_type::ArrowTypeEnum::Date32(v) => { + struct_ser.serialize_field("DATE32", v)?; + } + arrow_type::ArrowTypeEnum::Date64(v) => { + struct_ser.serialize_field("DATE64", v)?; + } + arrow_type::ArrowTypeEnum::Duration(v) => { + let v = TimeUnit::try_from(*v) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; + struct_ser.serialize_field("DURATION", &v)?; + } + arrow_type::ArrowTypeEnum::Timestamp(v) => { + struct_ser.serialize_field("TIMESTAMP", v)?; + } + arrow_type::ArrowTypeEnum::Time32(v) => { + let v = TimeUnit::try_from(*v) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; + struct_ser.serialize_field("TIME32", &v)?; + } + arrow_type::ArrowTypeEnum::Time64(v) => { + let v = TimeUnit::try_from(*v) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; + struct_ser.serialize_field("TIME64", &v)?; + } + arrow_type::ArrowTypeEnum::Interval(v) => { + let v = IntervalUnit::try_from(*v) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; + struct_ser.serialize_field("INTERVAL", &v)?; + } + arrow_type::ArrowTypeEnum::Decimal(v) => { + struct_ser.serialize_field("DECIMAL", v)?; + } + arrow_type::ArrowTypeEnum::Decimal256(v) => { + struct_ser.serialize_field("DECIMAL256", v)?; + } + arrow_type::ArrowTypeEnum::List(v) => { + struct_ser.serialize_field("LIST", v)?; + } + arrow_type::ArrowTypeEnum::LargeList(v) => { + struct_ser.serialize_field("LARGELIST", v)?; + } + arrow_type::ArrowTypeEnum::FixedSizeList(v) => { + struct_ser.serialize_field("FIXEDSIZELIST", v)?; + } + arrow_type::ArrowTypeEnum::Struct(v) => { + struct_ser.serialize_field("STRUCT", v)?; + } + arrow_type::ArrowTypeEnum::Union(v) => { + struct_ser.serialize_field("UNION", v)?; + } + arrow_type::ArrowTypeEnum::Dictionary(v) => { + struct_ser.serialize_field("DICTIONARY", v)?; + } + arrow_type::ArrowTypeEnum::Map(v) => { + struct_ser.serialize_field("MAP", v)?; + } + } + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ArrowType { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "NONE", + "BOOL", + "UINT8", + "INT8", + "UINT16", + "INT16", + "UINT32", + "INT32", + "UINT64", + "INT64", + "FLOAT16", + "FLOAT32", + "FLOAT64", + "UTF8", + "UTF8_VIEW", + "UTF8VIEW", + "LARGE_UTF8", + "LARGEUTF8", + "BINARY", + "BINARY_VIEW", + "BINARYVIEW", + "FIXED_SIZE_BINARY", + "FIXEDSIZEBINARY", + "LARGE_BINARY", + "LARGEBINARY", + "DATE32", + "DATE64", + "DURATION", + "TIMESTAMP", + "TIME32", + "TIME64", + "INTERVAL", + "DECIMAL", + "DECIMAL256", + "LIST", + "LARGE_LIST", + "LARGELIST", + "FIXED_SIZE_LIST", + "FIXEDSIZELIST", + "STRUCT", + "UNION", + "DICTIONARY", + "MAP", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + None, + Bool, + Uint8, + Int8, + Uint16, + Int16, + Uint32, + Int32, + Uint64, + Int64, + Float16, + Float32, + Float64, + Utf8, + Utf8View, + LargeUtf8, + Binary, + BinaryView, + FixedSizeBinary, + LargeBinary, + Date32, + Date64, + Duration, + Timestamp, + Time32, + Time64, + Interval, + Decimal, + Decimal256, + List, + LargeList, + FixedSizeList, + Struct, + Union, + Dictionary, + Map, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "NONE" => Ok(GeneratedField::None), + "BOOL" => Ok(GeneratedField::Bool), + "UINT8" => Ok(GeneratedField::Uint8), + "INT8" => Ok(GeneratedField::Int8), + "UINT16" => Ok(GeneratedField::Uint16), + "INT16" => Ok(GeneratedField::Int16), + "UINT32" => Ok(GeneratedField::Uint32), + "INT32" => Ok(GeneratedField::Int32), + "UINT64" => Ok(GeneratedField::Uint64), + "INT64" => Ok(GeneratedField::Int64), + "FLOAT16" => Ok(GeneratedField::Float16), + "FLOAT32" => Ok(GeneratedField::Float32), + "FLOAT64" => Ok(GeneratedField::Float64), + "UTF8" => Ok(GeneratedField::Utf8), + "UTF8VIEW" | "UTF8_VIEW" => Ok(GeneratedField::Utf8View), + "LARGEUTF8" | "LARGE_UTF8" => Ok(GeneratedField::LargeUtf8), + "BINARY" => Ok(GeneratedField::Binary), + "BINARYVIEW" | "BINARY_VIEW" => Ok(GeneratedField::BinaryView), + "FIXEDSIZEBINARY" | "FIXED_SIZE_BINARY" => Ok(GeneratedField::FixedSizeBinary), + "LARGEBINARY" | "LARGE_BINARY" => Ok(GeneratedField::LargeBinary), + "DATE32" => Ok(GeneratedField::Date32), + "DATE64" => Ok(GeneratedField::Date64), + "DURATION" => Ok(GeneratedField::Duration), + "TIMESTAMP" => Ok(GeneratedField::Timestamp), + "TIME32" => Ok(GeneratedField::Time32), + "TIME64" => Ok(GeneratedField::Time64), + "INTERVAL" => Ok(GeneratedField::Interval), + "DECIMAL" => Ok(GeneratedField::Decimal), + "DECIMAL256" => Ok(GeneratedField::Decimal256), + "LIST" => Ok(GeneratedField::List), + "LARGELIST" | "LARGE_LIST" => Ok(GeneratedField::LargeList), + "FIXEDSIZELIST" | "FIXED_SIZE_LIST" => Ok(GeneratedField::FixedSizeList), + "STRUCT" => Ok(GeneratedField::Struct), + "UNION" => Ok(GeneratedField::Union), + "DICTIONARY" => Ok(GeneratedField::Dictionary), + "MAP" => Ok(GeneratedField::Map), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ArrowType; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.ArrowType") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut arrow_type_enum__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::None => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("NONE")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::None) +; + } + GeneratedField::Bool => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("BOOL")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Bool) +; + } + GeneratedField::Uint8 => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("UINT8")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Uint8) +; + } + GeneratedField::Int8 => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("INT8")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Int8) +; + } + GeneratedField::Uint16 => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("UINT16")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Uint16) +; + } + GeneratedField::Int16 => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("INT16")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Int16) +; + } + GeneratedField::Uint32 => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("UINT32")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Uint32) +; + } + GeneratedField::Int32 => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("INT32")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Int32) +; + } + GeneratedField::Uint64 => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("UINT64")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Uint64) +; + } + GeneratedField::Int64 => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("INT64")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Int64) +; + } + GeneratedField::Float16 => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("FLOAT16")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Float16) +; + } + GeneratedField::Float32 => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("FLOAT32")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Float32) +; + } + GeneratedField::Float64 => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("FLOAT64")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Float64) +; + } + GeneratedField::Utf8 => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("UTF8")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Utf8) +; + } + GeneratedField::Utf8View => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("UTF8VIEW")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Utf8View) +; + } + GeneratedField::LargeUtf8 => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("LARGEUTF8")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::LargeUtf8) +; + } + GeneratedField::Binary => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("BINARY")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Binary) +; + } + GeneratedField::BinaryView => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("BINARYVIEW")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::BinaryView) +; + } + GeneratedField::FixedSizeBinary => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("FIXEDSIZEBINARY")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| arrow_type::ArrowTypeEnum::FixedSizeBinary(x.0)); + } + GeneratedField::LargeBinary => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("LARGEBINARY")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::LargeBinary) +; + } + GeneratedField::Date32 => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("DATE32")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Date32) +; + } + GeneratedField::Date64 => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("DATE64")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Date64) +; + } + GeneratedField::Duration => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("DURATION")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option>()?.map(|x| arrow_type::ArrowTypeEnum::Duration(x as i32)); + } + GeneratedField::Timestamp => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("TIMESTAMP")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Timestamp) +; + } + GeneratedField::Time32 => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("TIME32")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option>()?.map(|x| arrow_type::ArrowTypeEnum::Time32(x as i32)); + } + GeneratedField::Time64 => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("TIME64")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option>()?.map(|x| arrow_type::ArrowTypeEnum::Time64(x as i32)); + } + GeneratedField::Interval => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("INTERVAL")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option>()?.map(|x| arrow_type::ArrowTypeEnum::Interval(x as i32)); + } + GeneratedField::Decimal => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("DECIMAL")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Decimal) +; + } + GeneratedField::Decimal256 => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("DECIMAL256")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Decimal256) +; + } + GeneratedField::List => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("LIST")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::List) +; + } + GeneratedField::LargeList => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("LARGELIST")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::LargeList) +; + } + GeneratedField::FixedSizeList => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("FIXEDSIZELIST")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::FixedSizeList) +; + } + GeneratedField::Struct => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("STRUCT")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Struct) +; + } + GeneratedField::Union => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("UNION")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Union) +; + } + GeneratedField::Dictionary => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("DICTIONARY")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Dictionary) +; + } + GeneratedField::Map => { + if arrow_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("MAP")); + } + arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Map) +; + } + } + } + Ok(ArrowType { + arrow_type_enum: arrow_type_enum__, + }) + } + } + deserializer.deserialize_struct("datafusion_common.ArrowType", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for AvroFormat { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let len = 0; + let struct_ser = serializer.serialize_struct("datafusion_common.AvroFormat", len)?; + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for AvroFormat { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + Err(serde::de::Error::unknown_field(value, FIELDS)) + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = AvroFormat; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.AvroFormat") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + while map_.next_key::()?.is_some() { + let _ = map_.next_value::()?; + } + Ok(AvroFormat { + }) + } + } + deserializer.deserialize_struct("datafusion_common.AvroFormat", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for AvroOptions { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let len = 0; + let struct_ser = serializer.serialize_struct("datafusion_common.AvroOptions", len)?; + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for AvroOptions { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + Err(serde::de::Error::unknown_field(value, FIELDS)) + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = AvroOptions; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.AvroOptions") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + while map_.next_key::()?.is_some() { + let _ = map_.next_value::()?; + } + Ok(AvroOptions { + }) + } + } + deserializer.deserialize_struct("datafusion_common.AvroOptions", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for Column { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.name.is_empty() { + len += 1; + } + if self.relation.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.Column", len)?; + if !self.name.is_empty() { + struct_ser.serialize_field("name", &self.name)?; + } + if let Some(v) = self.relation.as_ref() { + struct_ser.serialize_field("relation", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Column { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "name", + "relation", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Name, + Relation, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "name" => Ok(GeneratedField::Name), + "relation" => Ok(GeneratedField::Relation), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Column; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.Column") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut name__ = None; + let mut relation__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); + } + name__ = Some(map_.next_value()?); + } + GeneratedField::Relation => { + if relation__.is_some() { + return Err(serde::de::Error::duplicate_field("relation")); + } + relation__ = map_.next_value()?; + } + } + } + Ok(Column { + name: name__.unwrap_or_default(), + relation: relation__, + }) + } + } + deserializer.deserialize_struct("datafusion_common.Column", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ColumnRelation { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.relation.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.ColumnRelation", len)?; + if !self.relation.is_empty() { + struct_ser.serialize_field("relation", &self.relation)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ColumnRelation { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "relation", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Relation, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "relation" => Ok(GeneratedField::Relation), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ColumnRelation; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.ColumnRelation") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut relation__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Relation => { + if relation__.is_some() { + return Err(serde::de::Error::duplicate_field("relation")); + } + relation__ = Some(map_.next_value()?); + } + } + } + Ok(ColumnRelation { + relation: relation__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.ColumnRelation", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ColumnStats { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.min_value.is_some() { + len += 1; + } + if self.max_value.is_some() { + len += 1; + } + if self.null_count.is_some() { + len += 1; + } + if self.distinct_count.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.ColumnStats", len)?; + if let Some(v) = self.min_value.as_ref() { + struct_ser.serialize_field("minValue", v)?; + } + if let Some(v) = self.max_value.as_ref() { + struct_ser.serialize_field("maxValue", v)?; + } + if let Some(v) = self.null_count.as_ref() { + struct_ser.serialize_field("nullCount", v)?; + } + if let Some(v) = self.distinct_count.as_ref() { + struct_ser.serialize_field("distinctCount", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ColumnStats { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "min_value", + "minValue", + "max_value", + "maxValue", + "null_count", + "nullCount", + "distinct_count", + "distinctCount", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + MinValue, + MaxValue, + NullCount, + DistinctCount, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "minValue" | "min_value" => Ok(GeneratedField::MinValue), + "maxValue" | "max_value" => Ok(GeneratedField::MaxValue), + "nullCount" | "null_count" => Ok(GeneratedField::NullCount), + "distinctCount" | "distinct_count" => Ok(GeneratedField::DistinctCount), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ColumnStats; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.ColumnStats") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut min_value__ = None; + let mut max_value__ = None; + let mut null_count__ = None; + let mut distinct_count__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::MinValue => { + if min_value__.is_some() { + return Err(serde::de::Error::duplicate_field("minValue")); + } + min_value__ = map_.next_value()?; + } + GeneratedField::MaxValue => { + if max_value__.is_some() { + return Err(serde::de::Error::duplicate_field("maxValue")); + } + max_value__ = map_.next_value()?; + } + GeneratedField::NullCount => { + if null_count__.is_some() { + return Err(serde::de::Error::duplicate_field("nullCount")); + } + null_count__ = map_.next_value()?; + } + GeneratedField::DistinctCount => { + if distinct_count__.is_some() { + return Err(serde::de::Error::duplicate_field("distinctCount")); + } + distinct_count__ = map_.next_value()?; + } + } + } + Ok(ColumnStats { + min_value: min_value__, + max_value: max_value__, + null_count: null_count__, + distinct_count: distinct_count__, + }) + } + } + deserializer.deserialize_struct("datafusion_common.ColumnStats", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for CompressionTypeVariant { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::Gzip => "GZIP", + Self::Bzip2 => "BZIP2", + Self::Xz => "XZ", + Self::Zstd => "ZSTD", + Self::Uncompressed => "UNCOMPRESSED", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for CompressionTypeVariant { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "GZIP", + "BZIP2", + "XZ", + "ZSTD", + "UNCOMPRESSED", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CompressionTypeVariant; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "GZIP" => Ok(CompressionTypeVariant::Gzip), + "BZIP2" => Ok(CompressionTypeVariant::Bzip2), + "XZ" => Ok(CompressionTypeVariant::Xz), + "ZSTD" => Ok(CompressionTypeVariant::Zstd), + "UNCOMPRESSED" => Ok(CompressionTypeVariant::Uncompressed), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} +impl serde::Serialize for Constraint { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.constraint_mode.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.Constraint", len)?; + if let Some(v) = self.constraint_mode.as_ref() { + match v { + constraint::ConstraintMode::PrimaryKey(v) => { + struct_ser.serialize_field("primaryKey", v)?; + } + constraint::ConstraintMode::Unique(v) => { + struct_ser.serialize_field("unique", v)?; + } + } + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Constraint { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "primary_key", + "primaryKey", + "unique", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + PrimaryKey, + Unique, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "primaryKey" | "primary_key" => Ok(GeneratedField::PrimaryKey), + "unique" => Ok(GeneratedField::Unique), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Constraint; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.Constraint") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut constraint_mode__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::PrimaryKey => { + if constraint_mode__.is_some() { + return Err(serde::de::Error::duplicate_field("primaryKey")); + } + constraint_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(constraint::ConstraintMode::PrimaryKey) +; + } + GeneratedField::Unique => { + if constraint_mode__.is_some() { + return Err(serde::de::Error::duplicate_field("unique")); + } + constraint_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(constraint::ConstraintMode::Unique) +; + } + } + } + Ok(Constraint { + constraint_mode: constraint_mode__, + }) + } + } + deserializer.deserialize_struct("datafusion_common.Constraint", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for Constraints { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.constraints.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.Constraints", len)?; + if !self.constraints.is_empty() { + struct_ser.serialize_field("constraints", &self.constraints)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Constraints { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "constraints", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Constraints, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "constraints" => Ok(GeneratedField::Constraints), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Constraints; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.Constraints") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut constraints__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Constraints => { + if constraints__.is_some() { + return Err(serde::de::Error::duplicate_field("constraints")); + } + constraints__ = Some(map_.next_value()?); + } + } + } + Ok(Constraints { + constraints: constraints__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.Constraints", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for CsvFormat { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.options.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.CsvFormat", len)?; + if let Some(v) = self.options.as_ref() { + struct_ser.serialize_field("options", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for CsvFormat { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "options", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Options, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "options" => Ok(GeneratedField::Options), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CsvFormat; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.CsvFormat") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut options__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Options => { + if options__.is_some() { + return Err(serde::de::Error::duplicate_field("options")); + } + options__ = map_.next_value()?; + } + } + } + Ok(CsvFormat { + options: options__, + }) + } + } + deserializer.deserialize_struct("datafusion_common.CsvFormat", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for CsvOptions { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.has_header.is_empty() { + len += 1; + } + if !self.delimiter.is_empty() { + len += 1; + } + if !self.quote.is_empty() { + len += 1; + } + if !self.escape.is_empty() { + len += 1; + } + if self.compression != 0 { + len += 1; + } + if self.schema_infer_max_rec != 0 { + len += 1; + } + if !self.date_format.is_empty() { + len += 1; + } + if !self.datetime_format.is_empty() { + len += 1; + } + if !self.timestamp_format.is_empty() { + len += 1; + } + if !self.timestamp_tz_format.is_empty() { + len += 1; + } + if !self.time_format.is_empty() { + len += 1; + } + if !self.null_value.is_empty() { + len += 1; + } + if !self.comment.is_empty() { + len += 1; + } + if !self.double_quote.is_empty() { + len += 1; + } + if !self.newlines_in_values.is_empty() { + len += 1; + } + if !self.terminator.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.CsvOptions", len)?; + if !self.has_header.is_empty() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("hasHeader", pbjson::private::base64::encode(&self.has_header).as_str())?; + } + if !self.delimiter.is_empty() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("delimiter", pbjson::private::base64::encode(&self.delimiter).as_str())?; + } + if !self.quote.is_empty() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("quote", pbjson::private::base64::encode(&self.quote).as_str())?; + } + if !self.escape.is_empty() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("escape", pbjson::private::base64::encode(&self.escape).as_str())?; + } + if self.compression != 0 { + let v = CompressionTypeVariant::try_from(self.compression) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.compression)))?; + struct_ser.serialize_field("compression", &v)?; + } + if self.schema_infer_max_rec != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("schemaInferMaxRec", ToString::to_string(&self.schema_infer_max_rec).as_str())?; + } + if !self.date_format.is_empty() { + struct_ser.serialize_field("dateFormat", &self.date_format)?; + } + if !self.datetime_format.is_empty() { + struct_ser.serialize_field("datetimeFormat", &self.datetime_format)?; + } + if !self.timestamp_format.is_empty() { + struct_ser.serialize_field("timestampFormat", &self.timestamp_format)?; + } + if !self.timestamp_tz_format.is_empty() { + struct_ser.serialize_field("timestampTzFormat", &self.timestamp_tz_format)?; + } + if !self.time_format.is_empty() { + struct_ser.serialize_field("timeFormat", &self.time_format)?; + } + if !self.null_value.is_empty() { + struct_ser.serialize_field("nullValue", &self.null_value)?; + } + if !self.comment.is_empty() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("comment", pbjson::private::base64::encode(&self.comment).as_str())?; + } + if !self.double_quote.is_empty() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("doubleQuote", pbjson::private::base64::encode(&self.double_quote).as_str())?; + } + if !self.newlines_in_values.is_empty() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("newlinesInValues", pbjson::private::base64::encode(&self.newlines_in_values).as_str())?; + } + if !self.terminator.is_empty() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("terminator", pbjson::private::base64::encode(&self.terminator).as_str())?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for CsvOptions { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "has_header", + "hasHeader", + "delimiter", + "quote", + "escape", + "compression", + "schema_infer_max_rec", + "schemaInferMaxRec", + "date_format", + "dateFormat", + "datetime_format", + "datetimeFormat", + "timestamp_format", + "timestampFormat", + "timestamp_tz_format", + "timestampTzFormat", + "time_format", + "timeFormat", + "null_value", + "nullValue", + "comment", + "double_quote", + "doubleQuote", + "newlines_in_values", + "newlinesInValues", + "terminator", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + HasHeader, + Delimiter, + Quote, + Escape, + Compression, + SchemaInferMaxRec, + DateFormat, + DatetimeFormat, + TimestampFormat, + TimestampTzFormat, + TimeFormat, + NullValue, + Comment, + DoubleQuote, + NewlinesInValues, + Terminator, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "hasHeader" | "has_header" => Ok(GeneratedField::HasHeader), + "delimiter" => Ok(GeneratedField::Delimiter), + "quote" => Ok(GeneratedField::Quote), + "escape" => Ok(GeneratedField::Escape), + "compression" => Ok(GeneratedField::Compression), + "schemaInferMaxRec" | "schema_infer_max_rec" => Ok(GeneratedField::SchemaInferMaxRec), + "dateFormat" | "date_format" => Ok(GeneratedField::DateFormat), + "datetimeFormat" | "datetime_format" => Ok(GeneratedField::DatetimeFormat), + "timestampFormat" | "timestamp_format" => Ok(GeneratedField::TimestampFormat), + "timestampTzFormat" | "timestamp_tz_format" => Ok(GeneratedField::TimestampTzFormat), + "timeFormat" | "time_format" => Ok(GeneratedField::TimeFormat), + "nullValue" | "null_value" => Ok(GeneratedField::NullValue), + "comment" => Ok(GeneratedField::Comment), + "doubleQuote" | "double_quote" => Ok(GeneratedField::DoubleQuote), + "newlinesInValues" | "newlines_in_values" => Ok(GeneratedField::NewlinesInValues), + "terminator" => Ok(GeneratedField::Terminator), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CsvOptions; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.CsvOptions") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut has_header__ = None; + let mut delimiter__ = None; + let mut quote__ = None; + let mut escape__ = None; + let mut compression__ = None; + let mut schema_infer_max_rec__ = None; + let mut date_format__ = None; + let mut datetime_format__ = None; + let mut timestamp_format__ = None; + let mut timestamp_tz_format__ = None; + let mut time_format__ = None; + let mut null_value__ = None; + let mut comment__ = None; + let mut double_quote__ = None; + let mut newlines_in_values__ = None; + let mut terminator__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::HasHeader => { + if has_header__.is_some() { + return Err(serde::de::Error::duplicate_field("hasHeader")); + } + has_header__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + GeneratedField::Delimiter => { + if delimiter__.is_some() { + return Err(serde::de::Error::duplicate_field("delimiter")); + } + delimiter__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + GeneratedField::Quote => { + if quote__.is_some() { + return Err(serde::de::Error::duplicate_field("quote")); + } + quote__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + GeneratedField::Escape => { + if escape__.is_some() { + return Err(serde::de::Error::duplicate_field("escape")); + } + escape__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + GeneratedField::Compression => { + if compression__.is_some() { + return Err(serde::de::Error::duplicate_field("compression")); + } + compression__ = Some(map_.next_value::()? as i32); + } + GeneratedField::SchemaInferMaxRec => { + if schema_infer_max_rec__.is_some() { + return Err(serde::de::Error::duplicate_field("schemaInferMaxRec")); + } + schema_infer_max_rec__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::DateFormat => { + if date_format__.is_some() { + return Err(serde::de::Error::duplicate_field("dateFormat")); + } + date_format__ = Some(map_.next_value()?); + } + GeneratedField::DatetimeFormat => { + if datetime_format__.is_some() { + return Err(serde::de::Error::duplicate_field("datetimeFormat")); + } + datetime_format__ = Some(map_.next_value()?); + } + GeneratedField::TimestampFormat => { + if timestamp_format__.is_some() { + return Err(serde::de::Error::duplicate_field("timestampFormat")); + } + timestamp_format__ = Some(map_.next_value()?); + } + GeneratedField::TimestampTzFormat => { + if timestamp_tz_format__.is_some() { + return Err(serde::de::Error::duplicate_field("timestampTzFormat")); + } + timestamp_tz_format__ = Some(map_.next_value()?); + } + GeneratedField::TimeFormat => { + if time_format__.is_some() { + return Err(serde::de::Error::duplicate_field("timeFormat")); + } + time_format__ = Some(map_.next_value()?); + } + GeneratedField::NullValue => { + if null_value__.is_some() { + return Err(serde::de::Error::duplicate_field("nullValue")); + } + null_value__ = Some(map_.next_value()?); + } + GeneratedField::Comment => { + if comment__.is_some() { + return Err(serde::de::Error::duplicate_field("comment")); + } + comment__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + GeneratedField::DoubleQuote => { + if double_quote__.is_some() { + return Err(serde::de::Error::duplicate_field("doubleQuote")); + } + double_quote__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + GeneratedField::NewlinesInValues => { + if newlines_in_values__.is_some() { + return Err(serde::de::Error::duplicate_field("newlinesInValues")); + } + newlines_in_values__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + GeneratedField::Terminator => { + if terminator__.is_some() { + return Err(serde::de::Error::duplicate_field("terminator")); + } + terminator__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + } + } + Ok(CsvOptions { + has_header: has_header__.unwrap_or_default(), + delimiter: delimiter__.unwrap_or_default(), + quote: quote__.unwrap_or_default(), + escape: escape__.unwrap_or_default(), + compression: compression__.unwrap_or_default(), + schema_infer_max_rec: schema_infer_max_rec__.unwrap_or_default(), + date_format: date_format__.unwrap_or_default(), + datetime_format: datetime_format__.unwrap_or_default(), + timestamp_format: timestamp_format__.unwrap_or_default(), + timestamp_tz_format: timestamp_tz_format__.unwrap_or_default(), + time_format: time_format__.unwrap_or_default(), + null_value: null_value__.unwrap_or_default(), + comment: comment__.unwrap_or_default(), + double_quote: double_quote__.unwrap_or_default(), + newlines_in_values: newlines_in_values__.unwrap_or_default(), + terminator: terminator__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.CsvOptions", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for CsvWriterOptions { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.compression != 0 { + len += 1; + } + if !self.delimiter.is_empty() { + len += 1; + } + if self.has_header { + len += 1; + } + if !self.date_format.is_empty() { + len += 1; + } + if !self.datetime_format.is_empty() { + len += 1; + } + if !self.timestamp_format.is_empty() { + len += 1; + } + if !self.time_format.is_empty() { + len += 1; + } + if !self.null_value.is_empty() { + len += 1; + } + if !self.quote.is_empty() { + len += 1; + } + if !self.escape.is_empty() { + len += 1; + } + if self.double_quote { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.CsvWriterOptions", len)?; + if self.compression != 0 { + let v = CompressionTypeVariant::try_from(self.compression) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.compression)))?; + struct_ser.serialize_field("compression", &v)?; + } + if !self.delimiter.is_empty() { + struct_ser.serialize_field("delimiter", &self.delimiter)?; + } + if self.has_header { + struct_ser.serialize_field("hasHeader", &self.has_header)?; + } + if !self.date_format.is_empty() { + struct_ser.serialize_field("dateFormat", &self.date_format)?; + } + if !self.datetime_format.is_empty() { + struct_ser.serialize_field("datetimeFormat", &self.datetime_format)?; + } + if !self.timestamp_format.is_empty() { + struct_ser.serialize_field("timestampFormat", &self.timestamp_format)?; + } + if !self.time_format.is_empty() { + struct_ser.serialize_field("timeFormat", &self.time_format)?; + } + if !self.null_value.is_empty() { + struct_ser.serialize_field("nullValue", &self.null_value)?; + } + if !self.quote.is_empty() { + struct_ser.serialize_field("quote", &self.quote)?; + } + if !self.escape.is_empty() { + struct_ser.serialize_field("escape", &self.escape)?; + } + if self.double_quote { + struct_ser.serialize_field("doubleQuote", &self.double_quote)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for CsvWriterOptions { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "compression", + "delimiter", + "has_header", + "hasHeader", + "date_format", + "dateFormat", + "datetime_format", + "datetimeFormat", + "timestamp_format", + "timestampFormat", + "time_format", + "timeFormat", + "null_value", + "nullValue", + "quote", + "escape", + "double_quote", + "doubleQuote", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Compression, + Delimiter, + HasHeader, + DateFormat, + DatetimeFormat, + TimestampFormat, + TimeFormat, + NullValue, + Quote, + Escape, + DoubleQuote, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "compression" => Ok(GeneratedField::Compression), + "delimiter" => Ok(GeneratedField::Delimiter), + "hasHeader" | "has_header" => Ok(GeneratedField::HasHeader), + "dateFormat" | "date_format" => Ok(GeneratedField::DateFormat), + "datetimeFormat" | "datetime_format" => Ok(GeneratedField::DatetimeFormat), + "timestampFormat" | "timestamp_format" => Ok(GeneratedField::TimestampFormat), + "timeFormat" | "time_format" => Ok(GeneratedField::TimeFormat), + "nullValue" | "null_value" => Ok(GeneratedField::NullValue), + "quote" => Ok(GeneratedField::Quote), + "escape" => Ok(GeneratedField::Escape), + "doubleQuote" | "double_quote" => Ok(GeneratedField::DoubleQuote), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CsvWriterOptions; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.CsvWriterOptions") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut compression__ = None; + let mut delimiter__ = None; + let mut has_header__ = None; + let mut date_format__ = None; + let mut datetime_format__ = None; + let mut timestamp_format__ = None; + let mut time_format__ = None; + let mut null_value__ = None; + let mut quote__ = None; + let mut escape__ = None; + let mut double_quote__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Compression => { + if compression__.is_some() { + return Err(serde::de::Error::duplicate_field("compression")); + } + compression__ = Some(map_.next_value::()? as i32); + } + GeneratedField::Delimiter => { + if delimiter__.is_some() { + return Err(serde::de::Error::duplicate_field("delimiter")); + } + delimiter__ = Some(map_.next_value()?); + } + GeneratedField::HasHeader => { + if has_header__.is_some() { + return Err(serde::de::Error::duplicate_field("hasHeader")); + } + has_header__ = Some(map_.next_value()?); + } + GeneratedField::DateFormat => { + if date_format__.is_some() { + return Err(serde::de::Error::duplicate_field("dateFormat")); + } + date_format__ = Some(map_.next_value()?); + } + GeneratedField::DatetimeFormat => { + if datetime_format__.is_some() { + return Err(serde::de::Error::duplicate_field("datetimeFormat")); + } + datetime_format__ = Some(map_.next_value()?); + } + GeneratedField::TimestampFormat => { + if timestamp_format__.is_some() { + return Err(serde::de::Error::duplicate_field("timestampFormat")); + } + timestamp_format__ = Some(map_.next_value()?); + } + GeneratedField::TimeFormat => { + if time_format__.is_some() { + return Err(serde::de::Error::duplicate_field("timeFormat")); + } + time_format__ = Some(map_.next_value()?); + } + GeneratedField::NullValue => { + if null_value__.is_some() { + return Err(serde::de::Error::duplicate_field("nullValue")); + } + null_value__ = Some(map_.next_value()?); + } + GeneratedField::Quote => { + if quote__.is_some() { + return Err(serde::de::Error::duplicate_field("quote")); + } + quote__ = Some(map_.next_value()?); + } + GeneratedField::Escape => { + if escape__.is_some() { + return Err(serde::de::Error::duplicate_field("escape")); + } + escape__ = Some(map_.next_value()?); + } + GeneratedField::DoubleQuote => { + if double_quote__.is_some() { + return Err(serde::de::Error::duplicate_field("doubleQuote")); + } + double_quote__ = Some(map_.next_value()?); + } + } + } + Ok(CsvWriterOptions { + compression: compression__.unwrap_or_default(), + delimiter: delimiter__.unwrap_or_default(), + has_header: has_header__.unwrap_or_default(), + date_format: date_format__.unwrap_or_default(), + datetime_format: datetime_format__.unwrap_or_default(), + timestamp_format: timestamp_format__.unwrap_or_default(), + time_format: time_format__.unwrap_or_default(), + null_value: null_value__.unwrap_or_default(), + quote: quote__.unwrap_or_default(), + escape: escape__.unwrap_or_default(), + double_quote: double_quote__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.CsvWriterOptions", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for Decimal { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.precision != 0 { + len += 1; + } + if self.scale != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.Decimal", len)?; + if self.precision != 0 { + struct_ser.serialize_field("precision", &self.precision)?; + } + if self.scale != 0 { + struct_ser.serialize_field("scale", &self.scale)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Decimal { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "precision", + "scale", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Precision, + Scale, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "precision" => Ok(GeneratedField::Precision), + "scale" => Ok(GeneratedField::Scale), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Decimal; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.Decimal") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut precision__ = None; + let mut scale__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Precision => { + if precision__.is_some() { + return Err(serde::de::Error::duplicate_field("precision")); + } + precision__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Scale => { + if scale__.is_some() { + return Err(serde::de::Error::duplicate_field("scale")); + } + scale__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(Decimal { + precision: precision__.unwrap_or_default(), + scale: scale__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.Decimal", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for Decimal128 { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.value.is_empty() { + len += 1; + } + if self.p != 0 { + len += 1; + } + if self.s != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.Decimal128", len)?; + if !self.value.is_empty() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("value", pbjson::private::base64::encode(&self.value).as_str())?; + } + if self.p != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("p", ToString::to_string(&self.p).as_str())?; + } + if self.s != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("s", ToString::to_string(&self.s).as_str())?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Decimal128 { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "value", + "p", + "s", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Value, + P, + S, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "value" => Ok(GeneratedField::Value), + "p" => Ok(GeneratedField::P), + "s" => Ok(GeneratedField::S), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Decimal128; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.Decimal128") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut value__ = None; + let mut p__ = None; + let mut s__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("value")); + } + value__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + GeneratedField::P => { + if p__.is_some() { + return Err(serde::de::Error::duplicate_field("p")); + } + p__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::S => { + if s__.is_some() { + return Err(serde::de::Error::duplicate_field("s")); + } + s__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(Decimal128 { + value: value__.unwrap_or_default(), + p: p__.unwrap_or_default(), + s: s__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.Decimal128", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for Decimal256 { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.value.is_empty() { + len += 1; + } + if self.p != 0 { + len += 1; + } + if self.s != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.Decimal256", len)?; + if !self.value.is_empty() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("value", pbjson::private::base64::encode(&self.value).as_str())?; + } + if self.p != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("p", ToString::to_string(&self.p).as_str())?; + } + if self.s != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("s", ToString::to_string(&self.s).as_str())?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Decimal256 { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "value", + "p", + "s", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Value, + P, + S, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "value" => Ok(GeneratedField::Value), + "p" => Ok(GeneratedField::P), + "s" => Ok(GeneratedField::S), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Decimal256; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.Decimal256") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut value__ = None; + let mut p__ = None; + let mut s__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("value")); + } + value__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + GeneratedField::P => { + if p__.is_some() { + return Err(serde::de::Error::duplicate_field("p")); + } + p__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::S => { + if s__.is_some() { + return Err(serde::de::Error::duplicate_field("s")); + } + s__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(Decimal256 { + value: value__.unwrap_or_default(), + p: p__.unwrap_or_default(), + s: s__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.Decimal256", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for Decimal256Type { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.precision != 0 { + len += 1; + } + if self.scale != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.Decimal256Type", len)?; + if self.precision != 0 { + struct_ser.serialize_field("precision", &self.precision)?; + } + if self.scale != 0 { + struct_ser.serialize_field("scale", &self.scale)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Decimal256Type { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "precision", + "scale", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Precision, + Scale, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "precision" => Ok(GeneratedField::Precision), + "scale" => Ok(GeneratedField::Scale), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Decimal256Type; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.Decimal256Type") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut precision__ = None; + let mut scale__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Precision => { + if precision__.is_some() { + return Err(serde::de::Error::duplicate_field("precision")); + } + precision__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Scale => { + if scale__.is_some() { + return Err(serde::de::Error::duplicate_field("scale")); + } + scale__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(Decimal256Type { + precision: precision__.unwrap_or_default(), + scale: scale__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.Decimal256Type", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for DfField { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.field.is_some() { + len += 1; + } + if self.qualifier.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.DfField", len)?; + if let Some(v) = self.field.as_ref() { + struct_ser.serialize_field("field", v)?; + } + if let Some(v) = self.qualifier.as_ref() { + struct_ser.serialize_field("qualifier", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for DfField { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "field", + "qualifier", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Field, + Qualifier, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "field" => Ok(GeneratedField::Field), + "qualifier" => Ok(GeneratedField::Qualifier), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = DfField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.DfField") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut field__ = None; + let mut qualifier__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Field => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("field")); + } + field__ = map_.next_value()?; + } + GeneratedField::Qualifier => { + if qualifier__.is_some() { + return Err(serde::de::Error::duplicate_field("qualifier")); + } + qualifier__ = map_.next_value()?; + } + } + } + Ok(DfField { + field: field__, + qualifier: qualifier__, + }) + } + } + deserializer.deserialize_struct("datafusion_common.DfField", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for DfSchema { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.columns.is_empty() { + len += 1; + } + if !self.metadata.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.DfSchema", len)?; + if !self.columns.is_empty() { + struct_ser.serialize_field("columns", &self.columns)?; + } + if !self.metadata.is_empty() { + struct_ser.serialize_field("metadata", &self.metadata)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for DfSchema { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "columns", + "metadata", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Columns, + Metadata, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "columns" => Ok(GeneratedField::Columns), + "metadata" => Ok(GeneratedField::Metadata), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = DfSchema; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.DfSchema") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut columns__ = None; + let mut metadata__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Columns => { + if columns__.is_some() { + return Err(serde::de::Error::duplicate_field("columns")); + } + columns__ = Some(map_.next_value()?); + } + GeneratedField::Metadata => { + if metadata__.is_some() { + return Err(serde::de::Error::duplicate_field("metadata")); + } + metadata__ = Some( + map_.next_value::>()? + ); + } + } + } + Ok(DfSchema { + columns: columns__.unwrap_or_default(), + metadata: metadata__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.DfSchema", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for Dictionary { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.key.is_some() { + len += 1; + } + if self.value.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.Dictionary", len)?; + if let Some(v) = self.key.as_ref() { + struct_ser.serialize_field("key", v)?; + } + if let Some(v) = self.value.as_ref() { + struct_ser.serialize_field("value", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Dictionary { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "key", + "value", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Key, + Value, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "key" => Ok(GeneratedField::Key), + "value" => Ok(GeneratedField::Value), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Dictionary; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.Dictionary") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut key__ = None; + let mut value__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Key => { + if key__.is_some() { + return Err(serde::de::Error::duplicate_field("key")); + } + key__ = map_.next_value()?; + } + GeneratedField::Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("value")); + } + value__ = map_.next_value()?; + } + } + } + Ok(Dictionary { + key: key__, + value: value__, + }) + } + } + deserializer.deserialize_struct("datafusion_common.Dictionary", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for EmptyMessage { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let len = 0; + let struct_ser = serializer.serialize_struct("datafusion_common.EmptyMessage", len)?; + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for EmptyMessage { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + Err(serde::de::Error::unknown_field(value, FIELDS)) + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = EmptyMessage; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.EmptyMessage") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + while map_.next_key::()?.is_some() { + let _ = map_.next_value::()?; + } + Ok(EmptyMessage { + }) + } + } + deserializer.deserialize_struct("datafusion_common.EmptyMessage", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for Field { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.name.is_empty() { + len += 1; + } + if self.arrow_type.is_some() { + len += 1; + } + if self.nullable { + len += 1; + } + if !self.children.is_empty() { + len += 1; + } + if !self.metadata.is_empty() { + len += 1; + } + if self.dict_id != 0 { + len += 1; + } + if self.dict_ordered { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.Field", len)?; + if !self.name.is_empty() { + struct_ser.serialize_field("name", &self.name)?; + } + if let Some(v) = self.arrow_type.as_ref() { + struct_ser.serialize_field("arrowType", v)?; + } + if self.nullable { + struct_ser.serialize_field("nullable", &self.nullable)?; + } + if !self.children.is_empty() { + struct_ser.serialize_field("children", &self.children)?; + } + if !self.metadata.is_empty() { + struct_ser.serialize_field("metadata", &self.metadata)?; + } + if self.dict_id != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("dictId", ToString::to_string(&self.dict_id).as_str())?; + } + if self.dict_ordered { + struct_ser.serialize_field("dictOrdered", &self.dict_ordered)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Field { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "name", + "arrow_type", + "arrowType", + "nullable", + "children", + "metadata", + "dict_id", + "dictId", + "dict_ordered", + "dictOrdered", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Name, + ArrowType, + Nullable, + Children, + Metadata, + DictId, + DictOrdered, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "name" => Ok(GeneratedField::Name), + "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), + "nullable" => Ok(GeneratedField::Nullable), + "children" => Ok(GeneratedField::Children), + "metadata" => Ok(GeneratedField::Metadata), + "dictId" | "dict_id" => Ok(GeneratedField::DictId), + "dictOrdered" | "dict_ordered" => Ok(GeneratedField::DictOrdered), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Field; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.Field") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut name__ = None; + let mut arrow_type__ = None; + let mut nullable__ = None; + let mut children__ = None; + let mut metadata__ = None; + let mut dict_id__ = None; + let mut dict_ordered__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); + } + name__ = Some(map_.next_value()?); + } + GeneratedField::ArrowType => { + if arrow_type__.is_some() { + return Err(serde::de::Error::duplicate_field("arrowType")); + } + arrow_type__ = map_.next_value()?; + } + GeneratedField::Nullable => { + if nullable__.is_some() { + return Err(serde::de::Error::duplicate_field("nullable")); + } + nullable__ = Some(map_.next_value()?); + } + GeneratedField::Children => { + if children__.is_some() { + return Err(serde::de::Error::duplicate_field("children")); + } + children__ = Some(map_.next_value()?); + } + GeneratedField::Metadata => { + if metadata__.is_some() { + return Err(serde::de::Error::duplicate_field("metadata")); + } + metadata__ = Some( + map_.next_value::>()? + ); + } + GeneratedField::DictId => { + if dict_id__.is_some() { + return Err(serde::de::Error::duplicate_field("dictId")); + } + dict_id__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::DictOrdered => { + if dict_ordered__.is_some() { + return Err(serde::de::Error::duplicate_field("dictOrdered")); + } + dict_ordered__ = Some(map_.next_value()?); + } + } + } + Ok(Field { + name: name__.unwrap_or_default(), + arrow_type: arrow_type__, + nullable: nullable__.unwrap_or_default(), + children: children__.unwrap_or_default(), + metadata: metadata__.unwrap_or_default(), + dict_id: dict_id__.unwrap_or_default(), + dict_ordered: dict_ordered__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.Field", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for FixedSizeList { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.field_type.is_some() { + len += 1; + } + if self.list_size != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.FixedSizeList", len)?; + if let Some(v) = self.field_type.as_ref() { + struct_ser.serialize_field("fieldType", v)?; + } + if self.list_size != 0 { + struct_ser.serialize_field("listSize", &self.list_size)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for FixedSizeList { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "field_type", + "fieldType", + "list_size", + "listSize", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + FieldType, + ListSize, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "fieldType" | "field_type" => Ok(GeneratedField::FieldType), + "listSize" | "list_size" => Ok(GeneratedField::ListSize), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = FixedSizeList; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.FixedSizeList") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut field_type__ = None; + let mut list_size__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::FieldType => { + if field_type__.is_some() { + return Err(serde::de::Error::duplicate_field("fieldType")); + } + field_type__ = map_.next_value()?; + } + GeneratedField::ListSize => { + if list_size__.is_some() { + return Err(serde::de::Error::duplicate_field("listSize")); + } + list_size__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(FixedSizeList { + field_type: field_type__, + list_size: list_size__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.FixedSizeList", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for IntervalDayTimeValue { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.days != 0 { + len += 1; + } + if self.milliseconds != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.IntervalDayTimeValue", len)?; + if self.days != 0 { + struct_ser.serialize_field("days", &self.days)?; + } + if self.milliseconds != 0 { + struct_ser.serialize_field("milliseconds", &self.milliseconds)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for IntervalDayTimeValue { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "days", + "milliseconds", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Days, + Milliseconds, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "days" => Ok(GeneratedField::Days), + "milliseconds" => Ok(GeneratedField::Milliseconds), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = IntervalDayTimeValue; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.IntervalDayTimeValue") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut days__ = None; + let mut milliseconds__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Days => { + if days__.is_some() { + return Err(serde::de::Error::duplicate_field("days")); + } + days__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Milliseconds => { + if milliseconds__.is_some() { + return Err(serde::de::Error::duplicate_field("milliseconds")); + } + milliseconds__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(IntervalDayTimeValue { + days: days__.unwrap_or_default(), + milliseconds: milliseconds__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.IntervalDayTimeValue", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for IntervalMonthDayNanoValue { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.months != 0 { + len += 1; + } + if self.days != 0 { + len += 1; + } + if self.nanos != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.IntervalMonthDayNanoValue", len)?; + if self.months != 0 { + struct_ser.serialize_field("months", &self.months)?; + } + if self.days != 0 { + struct_ser.serialize_field("days", &self.days)?; + } + if self.nanos != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("nanos", ToString::to_string(&self.nanos).as_str())?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for IntervalMonthDayNanoValue { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "months", + "days", + "nanos", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Months, + Days, + Nanos, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "months" => Ok(GeneratedField::Months), + "days" => Ok(GeneratedField::Days), + "nanos" => Ok(GeneratedField::Nanos), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = IntervalMonthDayNanoValue; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.IntervalMonthDayNanoValue") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut months__ = None; + let mut days__ = None; + let mut nanos__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Months => { + if months__.is_some() { + return Err(serde::de::Error::duplicate_field("months")); + } + months__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Days => { + if days__.is_some() { + return Err(serde::de::Error::duplicate_field("days")); + } + days__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Nanos => { + if nanos__.is_some() { + return Err(serde::de::Error::duplicate_field("nanos")); + } + nanos__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(IntervalMonthDayNanoValue { + months: months__.unwrap_or_default(), + days: days__.unwrap_or_default(), + nanos: nanos__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.IntervalMonthDayNanoValue", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for IntervalUnit { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::YearMonth => "YearMonth", + Self::DayTime => "DayTime", + Self::MonthDayNano => "MonthDayNano", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for IntervalUnit { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "YearMonth", + "DayTime", + "MonthDayNano", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = IntervalUnit; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "YearMonth" => Ok(IntervalUnit::YearMonth), + "DayTime" => Ok(IntervalUnit::DayTime), + "MonthDayNano" => Ok(IntervalUnit::MonthDayNano), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} +impl serde::Serialize for JoinConstraint { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::On => "ON", + Self::Using => "USING", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for JoinConstraint { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "ON", + "USING", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = JoinConstraint; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "ON" => Ok(JoinConstraint::On), + "USING" => Ok(JoinConstraint::Using), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} +impl serde::Serialize for JoinSide { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::LeftSide => "LEFT_SIDE", + Self::RightSide => "RIGHT_SIDE", + Self::None => "NONE", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for JoinSide { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "LEFT_SIDE", + "RIGHT_SIDE", + "NONE", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = JoinSide; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "LEFT_SIDE" => Ok(JoinSide::LeftSide), + "RIGHT_SIDE" => Ok(JoinSide::RightSide), + "NONE" => Ok(JoinSide::None), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} +impl serde::Serialize for JoinType { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::Inner => "INNER", + Self::Left => "LEFT", + Self::Right => "RIGHT", + Self::Full => "FULL", + Self::Leftsemi => "LEFTSEMI", + Self::Leftanti => "LEFTANTI", + Self::Rightsemi => "RIGHTSEMI", + Self::Rightanti => "RIGHTANTI", + Self::Leftmark => "LEFTMARK", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for JoinType { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "INNER", + "LEFT", + "RIGHT", + "FULL", + "LEFTSEMI", + "LEFTANTI", + "RIGHTSEMI", + "RIGHTANTI", + "LEFTMARK", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = JoinType; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "INNER" => Ok(JoinType::Inner), + "LEFT" => Ok(JoinType::Left), + "RIGHT" => Ok(JoinType::Right), + "FULL" => Ok(JoinType::Full), + "LEFTSEMI" => Ok(JoinType::Leftsemi), + "LEFTANTI" => Ok(JoinType::Leftanti), + "RIGHTSEMI" => Ok(JoinType::Rightsemi), + "RIGHTANTI" => Ok(JoinType::Rightanti), + "LEFTMARK" => Ok(JoinType::Leftmark), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} +impl serde::Serialize for JsonOptions { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.compression != 0 { + len += 1; + } + if self.schema_infer_max_rec != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.JsonOptions", len)?; + if self.compression != 0 { + let v = CompressionTypeVariant::try_from(self.compression) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.compression)))?; + struct_ser.serialize_field("compression", &v)?; + } + if self.schema_infer_max_rec != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("schemaInferMaxRec", ToString::to_string(&self.schema_infer_max_rec).as_str())?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for JsonOptions { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "compression", + "schema_infer_max_rec", + "schemaInferMaxRec", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Compression, + SchemaInferMaxRec, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "compression" => Ok(GeneratedField::Compression), + "schemaInferMaxRec" | "schema_infer_max_rec" => Ok(GeneratedField::SchemaInferMaxRec), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = JsonOptions; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.JsonOptions") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut compression__ = None; + let mut schema_infer_max_rec__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Compression => { + if compression__.is_some() { + return Err(serde::de::Error::duplicate_field("compression")); + } + compression__ = Some(map_.next_value::()? as i32); + } + GeneratedField::SchemaInferMaxRec => { + if schema_infer_max_rec__.is_some() { + return Err(serde::de::Error::duplicate_field("schemaInferMaxRec")); + } + schema_infer_max_rec__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(JsonOptions { + compression: compression__.unwrap_or_default(), + schema_infer_max_rec: schema_infer_max_rec__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.JsonOptions", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for JsonWriterOptions { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.compression != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.JsonWriterOptions", len)?; + if self.compression != 0 { + let v = CompressionTypeVariant::try_from(self.compression) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.compression)))?; + struct_ser.serialize_field("compression", &v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for JsonWriterOptions { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "compression", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Compression, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "compression" => Ok(GeneratedField::Compression), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = JsonWriterOptions; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.JsonWriterOptions") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut compression__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Compression => { + if compression__.is_some() { + return Err(serde::de::Error::duplicate_field("compression")); + } + compression__ = Some(map_.next_value::()? as i32); + } + } + } + Ok(JsonWriterOptions { + compression: compression__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.JsonWriterOptions", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for List { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.field_type.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.List", len)?; + if let Some(v) = self.field_type.as_ref() { + struct_ser.serialize_field("fieldType", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for List { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "field_type", + "fieldType", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + FieldType, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "fieldType" | "field_type" => Ok(GeneratedField::FieldType), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = List; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.List") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut field_type__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::FieldType => { + if field_type__.is_some() { + return Err(serde::de::Error::duplicate_field("fieldType")); + } + field_type__ = map_.next_value()?; + } + } + } + Ok(List { + field_type: field_type__, + }) + } + } + deserializer.deserialize_struct("datafusion_common.List", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for Map { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.field_type.is_some() { + len += 1; + } + if self.keys_sorted { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.Map", len)?; + if let Some(v) = self.field_type.as_ref() { + struct_ser.serialize_field("fieldType", v)?; + } + if self.keys_sorted { + struct_ser.serialize_field("keysSorted", &self.keys_sorted)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Map { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "field_type", + "fieldType", + "keys_sorted", + "keysSorted", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + FieldType, + KeysSorted, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "fieldType" | "field_type" => Ok(GeneratedField::FieldType), + "keysSorted" | "keys_sorted" => Ok(GeneratedField::KeysSorted), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Map; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.Map") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut field_type__ = None; + let mut keys_sorted__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::FieldType => { + if field_type__.is_some() { + return Err(serde::de::Error::duplicate_field("fieldType")); + } + field_type__ = map_.next_value()?; + } + GeneratedField::KeysSorted => { + if keys_sorted__.is_some() { + return Err(serde::de::Error::duplicate_field("keysSorted")); + } + keys_sorted__ = Some(map_.next_value()?); + } + } + } + Ok(Map { + field_type: field_type__, + keys_sorted: keys_sorted__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.Map", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for NdJsonFormat { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.options.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.NdJsonFormat", len)?; + if let Some(v) = self.options.as_ref() { + struct_ser.serialize_field("options", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for NdJsonFormat { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "options", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Options, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "options" => Ok(GeneratedField::Options), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = NdJsonFormat; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.NdJsonFormat") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut options__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Options => { + if options__.is_some() { + return Err(serde::de::Error::duplicate_field("options")); + } + options__ = map_.next_value()?; + } + } + } + Ok(NdJsonFormat { + options: options__, + }) + } + } + deserializer.deserialize_struct("datafusion_common.NdJsonFormat", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ParquetColumnOptions { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.bloom_filter_enabled_opt.is_some() { + len += 1; + } + if self.encoding_opt.is_some() { + len += 1; + } + if self.dictionary_enabled_opt.is_some() { + len += 1; + } + if self.compression_opt.is_some() { + len += 1; + } + if self.statistics_enabled_opt.is_some() { + len += 1; + } + if self.bloom_filter_fpp_opt.is_some() { + len += 1; + } + if self.bloom_filter_ndv_opt.is_some() { + len += 1; + } + if self.max_statistics_size_opt.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.ParquetColumnOptions", len)?; + if let Some(v) = self.bloom_filter_enabled_opt.as_ref() { + match v { + parquet_column_options::BloomFilterEnabledOpt::BloomFilterEnabled(v) => { + struct_ser.serialize_field("bloomFilterEnabled", v)?; + } + } + } + if let Some(v) = self.encoding_opt.as_ref() { + match v { + parquet_column_options::EncodingOpt::Encoding(v) => { + struct_ser.serialize_field("encoding", v)?; + } + } + } + if let Some(v) = self.dictionary_enabled_opt.as_ref() { + match v { + parquet_column_options::DictionaryEnabledOpt::DictionaryEnabled(v) => { + struct_ser.serialize_field("dictionaryEnabled", v)?; + } + } + } + if let Some(v) = self.compression_opt.as_ref() { + match v { + parquet_column_options::CompressionOpt::Compression(v) => { + struct_ser.serialize_field("compression", v)?; + } + } + } + if let Some(v) = self.statistics_enabled_opt.as_ref() { + match v { + parquet_column_options::StatisticsEnabledOpt::StatisticsEnabled(v) => { + struct_ser.serialize_field("statisticsEnabled", v)?; + } + } + } + if let Some(v) = self.bloom_filter_fpp_opt.as_ref() { + match v { + parquet_column_options::BloomFilterFppOpt::BloomFilterFpp(v) => { + struct_ser.serialize_field("bloomFilterFpp", v)?; + } + } + } + if let Some(v) = self.bloom_filter_ndv_opt.as_ref() { + match v { + parquet_column_options::BloomFilterNdvOpt::BloomFilterNdv(v) => { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("bloomFilterNdv", ToString::to_string(&v).as_str())?; + } + } + } + if let Some(v) = self.max_statistics_size_opt.as_ref() { + match v { + parquet_column_options::MaxStatisticsSizeOpt::MaxStatisticsSize(v) => { + struct_ser.serialize_field("maxStatisticsSize", v)?; + } + } + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ParquetColumnOptions { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "bloom_filter_enabled", + "bloomFilterEnabled", + "encoding", + "dictionary_enabled", + "dictionaryEnabled", + "compression", + "statistics_enabled", + "statisticsEnabled", + "bloom_filter_fpp", + "bloomFilterFpp", + "bloom_filter_ndv", + "bloomFilterNdv", + "max_statistics_size", + "maxStatisticsSize", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + BloomFilterEnabled, + Encoding, + DictionaryEnabled, + Compression, + StatisticsEnabled, + BloomFilterFpp, + BloomFilterNdv, + MaxStatisticsSize, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "bloomFilterEnabled" | "bloom_filter_enabled" => Ok(GeneratedField::BloomFilterEnabled), + "encoding" => Ok(GeneratedField::Encoding), + "dictionaryEnabled" | "dictionary_enabled" => Ok(GeneratedField::DictionaryEnabled), + "compression" => Ok(GeneratedField::Compression), + "statisticsEnabled" | "statistics_enabled" => Ok(GeneratedField::StatisticsEnabled), + "bloomFilterFpp" | "bloom_filter_fpp" => Ok(GeneratedField::BloomFilterFpp), + "bloomFilterNdv" | "bloom_filter_ndv" => Ok(GeneratedField::BloomFilterNdv), + "maxStatisticsSize" | "max_statistics_size" => Ok(GeneratedField::MaxStatisticsSize), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ParquetColumnOptions; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.ParquetColumnOptions") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut bloom_filter_enabled_opt__ = None; + let mut encoding_opt__ = None; + let mut dictionary_enabled_opt__ = None; + let mut compression_opt__ = None; + let mut statistics_enabled_opt__ = None; + let mut bloom_filter_fpp_opt__ = None; + let mut bloom_filter_ndv_opt__ = None; + let mut max_statistics_size_opt__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::BloomFilterEnabled => { + if bloom_filter_enabled_opt__.is_some() { + return Err(serde::de::Error::duplicate_field("bloomFilterEnabled")); + } + bloom_filter_enabled_opt__ = map_.next_value::<::std::option::Option<_>>()?.map(parquet_column_options::BloomFilterEnabledOpt::BloomFilterEnabled); + } + GeneratedField::Encoding => { + if encoding_opt__.is_some() { + return Err(serde::de::Error::duplicate_field("encoding")); + } + encoding_opt__ = map_.next_value::<::std::option::Option<_>>()?.map(parquet_column_options::EncodingOpt::Encoding); + } + GeneratedField::DictionaryEnabled => { + if dictionary_enabled_opt__.is_some() { + return Err(serde::de::Error::duplicate_field("dictionaryEnabled")); + } + dictionary_enabled_opt__ = map_.next_value::<::std::option::Option<_>>()?.map(parquet_column_options::DictionaryEnabledOpt::DictionaryEnabled); + } + GeneratedField::Compression => { + if compression_opt__.is_some() { + return Err(serde::de::Error::duplicate_field("compression")); + } + compression_opt__ = map_.next_value::<::std::option::Option<_>>()?.map(parquet_column_options::CompressionOpt::Compression); + } + GeneratedField::StatisticsEnabled => { + if statistics_enabled_opt__.is_some() { + return Err(serde::de::Error::duplicate_field("statisticsEnabled")); + } + statistics_enabled_opt__ = map_.next_value::<::std::option::Option<_>>()?.map(parquet_column_options::StatisticsEnabledOpt::StatisticsEnabled); + } + GeneratedField::BloomFilterFpp => { + if bloom_filter_fpp_opt__.is_some() { + return Err(serde::de::Error::duplicate_field("bloomFilterFpp")); + } + bloom_filter_fpp_opt__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| parquet_column_options::BloomFilterFppOpt::BloomFilterFpp(x.0)); + } + GeneratedField::BloomFilterNdv => { + if bloom_filter_ndv_opt__.is_some() { + return Err(serde::de::Error::duplicate_field("bloomFilterNdv")); + } + bloom_filter_ndv_opt__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| parquet_column_options::BloomFilterNdvOpt::BloomFilterNdv(x.0)); + } + GeneratedField::MaxStatisticsSize => { + if max_statistics_size_opt__.is_some() { + return Err(serde::de::Error::duplicate_field("maxStatisticsSize")); + } + max_statistics_size_opt__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| parquet_column_options::MaxStatisticsSizeOpt::MaxStatisticsSize(x.0)); + } + } + } + Ok(ParquetColumnOptions { + bloom_filter_enabled_opt: bloom_filter_enabled_opt__, + encoding_opt: encoding_opt__, + dictionary_enabled_opt: dictionary_enabled_opt__, + compression_opt: compression_opt__, + statistics_enabled_opt: statistics_enabled_opt__, + bloom_filter_fpp_opt: bloom_filter_fpp_opt__, + bloom_filter_ndv_opt: bloom_filter_ndv_opt__, + max_statistics_size_opt: max_statistics_size_opt__, + }) + } + } + deserializer.deserialize_struct("datafusion_common.ParquetColumnOptions", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ParquetColumnSpecificOptions { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.column_name.is_empty() { + len += 1; + } + if self.options.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.ParquetColumnSpecificOptions", len)?; + if !self.column_name.is_empty() { + struct_ser.serialize_field("columnName", &self.column_name)?; + } + if let Some(v) = self.options.as_ref() { + struct_ser.serialize_field("options", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ParquetColumnSpecificOptions { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "column_name", + "columnName", + "options", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + ColumnName, + Options, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "columnName" | "column_name" => Ok(GeneratedField::ColumnName), + "options" => Ok(GeneratedField::Options), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ParquetColumnSpecificOptions; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.ParquetColumnSpecificOptions") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut column_name__ = None; + let mut options__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::ColumnName => { + if column_name__.is_some() { + return Err(serde::de::Error::duplicate_field("columnName")); + } + column_name__ = Some(map_.next_value()?); + } + GeneratedField::Options => { + if options__.is_some() { + return Err(serde::de::Error::duplicate_field("options")); + } + options__ = map_.next_value()?; + } + } + } + Ok(ParquetColumnSpecificOptions { + column_name: column_name__.unwrap_or_default(), + options: options__, + }) + } + } + deserializer.deserialize_struct("datafusion_common.ParquetColumnSpecificOptions", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ParquetFormat { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.options.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.ParquetFormat", len)?; + if let Some(v) = self.options.as_ref() { + struct_ser.serialize_field("options", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ParquetFormat { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "options", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Options, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "options" => Ok(GeneratedField::Options), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ParquetFormat; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.ParquetFormat") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut options__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Options => { + if options__.is_some() { + return Err(serde::de::Error::duplicate_field("options")); + } + options__ = map_.next_value()?; + } + } + } + Ok(ParquetFormat { + options: options__, + }) + } + } + deserializer.deserialize_struct("datafusion_common.ParquetFormat", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ParquetOptions { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.enable_page_index { + len += 1; + } + if self.pruning { + len += 1; + } + if self.skip_metadata { + len += 1; + } + if self.pushdown_filters { + len += 1; + } + if self.reorder_filters { + len += 1; + } + if self.data_pagesize_limit != 0 { + len += 1; + } + if self.write_batch_size != 0 { + len += 1; + } + if !self.writer_version.is_empty() { + len += 1; + } + if self.allow_single_file_parallelism { + len += 1; + } + if self.maximum_parallel_row_group_writers != 0 { + len += 1; + } + if self.maximum_buffered_record_batches_per_stream != 0 { + len += 1; + } + if self.bloom_filter_on_read { + len += 1; + } + if self.bloom_filter_on_write { + len += 1; + } + if self.schema_force_view_types { + len += 1; + } + if self.binary_as_string { + len += 1; + } + if self.dictionary_page_size_limit != 0 { + len += 1; + } + if self.data_page_row_count_limit != 0 { + len += 1; + } + if self.max_row_group_size != 0 { + len += 1; + } + if !self.created_by.is_empty() { + len += 1; + } + if self.metadata_size_hint_opt.is_some() { + len += 1; + } + if self.compression_opt.is_some() { + len += 1; + } + if self.dictionary_enabled_opt.is_some() { + len += 1; + } + if self.statistics_enabled_opt.is_some() { + len += 1; + } + if self.max_statistics_size_opt.is_some() { + len += 1; + } + if self.column_index_truncate_length_opt.is_some() { + len += 1; + } + if self.encoding_opt.is_some() { + len += 1; + } + if self.bloom_filter_fpp_opt.is_some() { + len += 1; + } + if self.bloom_filter_ndv_opt.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.ParquetOptions", len)?; + if self.enable_page_index { + struct_ser.serialize_field("enablePageIndex", &self.enable_page_index)?; + } + if self.pruning { + struct_ser.serialize_field("pruning", &self.pruning)?; + } + if self.skip_metadata { + struct_ser.serialize_field("skipMetadata", &self.skip_metadata)?; + } + if self.pushdown_filters { + struct_ser.serialize_field("pushdownFilters", &self.pushdown_filters)?; + } + if self.reorder_filters { + struct_ser.serialize_field("reorderFilters", &self.reorder_filters)?; + } + if self.data_pagesize_limit != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("dataPagesizeLimit", ToString::to_string(&self.data_pagesize_limit).as_str())?; + } + if self.write_batch_size != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("writeBatchSize", ToString::to_string(&self.write_batch_size).as_str())?; + } + if !self.writer_version.is_empty() { + struct_ser.serialize_field("writerVersion", &self.writer_version)?; + } + if self.allow_single_file_parallelism { + struct_ser.serialize_field("allowSingleFileParallelism", &self.allow_single_file_parallelism)?; + } + if self.maximum_parallel_row_group_writers != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("maximumParallelRowGroupWriters", ToString::to_string(&self.maximum_parallel_row_group_writers).as_str())?; + } + if self.maximum_buffered_record_batches_per_stream != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("maximumBufferedRecordBatchesPerStream", ToString::to_string(&self.maximum_buffered_record_batches_per_stream).as_str())?; + } + if self.bloom_filter_on_read { + struct_ser.serialize_field("bloomFilterOnRead", &self.bloom_filter_on_read)?; + } + if self.bloom_filter_on_write { + struct_ser.serialize_field("bloomFilterOnWrite", &self.bloom_filter_on_write)?; + } + if self.schema_force_view_types { + struct_ser.serialize_field("schemaForceViewTypes", &self.schema_force_view_types)?; + } + if self.binary_as_string { + struct_ser.serialize_field("binaryAsString", &self.binary_as_string)?; + } + if self.dictionary_page_size_limit != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("dictionaryPageSizeLimit", ToString::to_string(&self.dictionary_page_size_limit).as_str())?; + } + if self.data_page_row_count_limit != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("dataPageRowCountLimit", ToString::to_string(&self.data_page_row_count_limit).as_str())?; + } + if self.max_row_group_size != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("maxRowGroupSize", ToString::to_string(&self.max_row_group_size).as_str())?; + } + if !self.created_by.is_empty() { + struct_ser.serialize_field("createdBy", &self.created_by)?; + } + if let Some(v) = self.metadata_size_hint_opt.as_ref() { + match v { + parquet_options::MetadataSizeHintOpt::MetadataSizeHint(v) => { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("metadataSizeHint", ToString::to_string(&v).as_str())?; + } + } + } + if let Some(v) = self.compression_opt.as_ref() { + match v { + parquet_options::CompressionOpt::Compression(v) => { + struct_ser.serialize_field("compression", v)?; + } + } + } + if let Some(v) = self.dictionary_enabled_opt.as_ref() { + match v { + parquet_options::DictionaryEnabledOpt::DictionaryEnabled(v) => { + struct_ser.serialize_field("dictionaryEnabled", v)?; + } + } + } + if let Some(v) = self.statistics_enabled_opt.as_ref() { + match v { + parquet_options::StatisticsEnabledOpt::StatisticsEnabled(v) => { + struct_ser.serialize_field("statisticsEnabled", v)?; + } + } + } + if let Some(v) = self.max_statistics_size_opt.as_ref() { + match v { + parquet_options::MaxStatisticsSizeOpt::MaxStatisticsSize(v) => { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("maxStatisticsSize", ToString::to_string(&v).as_str())?; + } + } + } + if let Some(v) = self.column_index_truncate_length_opt.as_ref() { + match v { + parquet_options::ColumnIndexTruncateLengthOpt::ColumnIndexTruncateLength(v) => { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("columnIndexTruncateLength", ToString::to_string(&v).as_str())?; + } + } + } + if let Some(v) = self.encoding_opt.as_ref() { + match v { + parquet_options::EncodingOpt::Encoding(v) => { + struct_ser.serialize_field("encoding", v)?; + } + } + } + if let Some(v) = self.bloom_filter_fpp_opt.as_ref() { + match v { + parquet_options::BloomFilterFppOpt::BloomFilterFpp(v) => { + struct_ser.serialize_field("bloomFilterFpp", v)?; + } + } + } + if let Some(v) = self.bloom_filter_ndv_opt.as_ref() { + match v { + parquet_options::BloomFilterNdvOpt::BloomFilterNdv(v) => { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("bloomFilterNdv", ToString::to_string(&v).as_str())?; + } + } + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ParquetOptions { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "enable_page_index", + "enablePageIndex", + "pruning", + "skip_metadata", + "skipMetadata", + "pushdown_filters", + "pushdownFilters", + "reorder_filters", + "reorderFilters", + "data_pagesize_limit", + "dataPagesizeLimit", + "write_batch_size", + "writeBatchSize", + "writer_version", + "writerVersion", + "allow_single_file_parallelism", + "allowSingleFileParallelism", + "maximum_parallel_row_group_writers", + "maximumParallelRowGroupWriters", + "maximum_buffered_record_batches_per_stream", + "maximumBufferedRecordBatchesPerStream", + "bloom_filter_on_read", + "bloomFilterOnRead", + "bloom_filter_on_write", + "bloomFilterOnWrite", + "schema_force_view_types", + "schemaForceViewTypes", + "binary_as_string", + "binaryAsString", + "dictionary_page_size_limit", + "dictionaryPageSizeLimit", + "data_page_row_count_limit", + "dataPageRowCountLimit", + "max_row_group_size", + "maxRowGroupSize", + "created_by", + "createdBy", + "metadata_size_hint", + "metadataSizeHint", + "compression", + "dictionary_enabled", + "dictionaryEnabled", + "statistics_enabled", + "statisticsEnabled", + "max_statistics_size", + "maxStatisticsSize", + "column_index_truncate_length", + "columnIndexTruncateLength", + "encoding", + "bloom_filter_fpp", + "bloomFilterFpp", + "bloom_filter_ndv", + "bloomFilterNdv", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + EnablePageIndex, + Pruning, + SkipMetadata, + PushdownFilters, + ReorderFilters, + DataPagesizeLimit, + WriteBatchSize, + WriterVersion, + AllowSingleFileParallelism, + MaximumParallelRowGroupWriters, + MaximumBufferedRecordBatchesPerStream, + BloomFilterOnRead, + BloomFilterOnWrite, + SchemaForceViewTypes, + BinaryAsString, + DictionaryPageSizeLimit, + DataPageRowCountLimit, + MaxRowGroupSize, + CreatedBy, + MetadataSizeHint, + Compression, + DictionaryEnabled, + StatisticsEnabled, + MaxStatisticsSize, + ColumnIndexTruncateLength, + Encoding, + BloomFilterFpp, + BloomFilterNdv, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "enablePageIndex" | "enable_page_index" => Ok(GeneratedField::EnablePageIndex), + "pruning" => Ok(GeneratedField::Pruning), + "skipMetadata" | "skip_metadata" => Ok(GeneratedField::SkipMetadata), + "pushdownFilters" | "pushdown_filters" => Ok(GeneratedField::PushdownFilters), + "reorderFilters" | "reorder_filters" => Ok(GeneratedField::ReorderFilters), + "dataPagesizeLimit" | "data_pagesize_limit" => Ok(GeneratedField::DataPagesizeLimit), + "writeBatchSize" | "write_batch_size" => Ok(GeneratedField::WriteBatchSize), + "writerVersion" | "writer_version" => Ok(GeneratedField::WriterVersion), + "allowSingleFileParallelism" | "allow_single_file_parallelism" => Ok(GeneratedField::AllowSingleFileParallelism), + "maximumParallelRowGroupWriters" | "maximum_parallel_row_group_writers" => Ok(GeneratedField::MaximumParallelRowGroupWriters), + "maximumBufferedRecordBatchesPerStream" | "maximum_buffered_record_batches_per_stream" => Ok(GeneratedField::MaximumBufferedRecordBatchesPerStream), + "bloomFilterOnRead" | "bloom_filter_on_read" => Ok(GeneratedField::BloomFilterOnRead), + "bloomFilterOnWrite" | "bloom_filter_on_write" => Ok(GeneratedField::BloomFilterOnWrite), + "schemaForceViewTypes" | "schema_force_view_types" => Ok(GeneratedField::SchemaForceViewTypes), + "binaryAsString" | "binary_as_string" => Ok(GeneratedField::BinaryAsString), + "dictionaryPageSizeLimit" | "dictionary_page_size_limit" => Ok(GeneratedField::DictionaryPageSizeLimit), + "dataPageRowCountLimit" | "data_page_row_count_limit" => Ok(GeneratedField::DataPageRowCountLimit), + "maxRowGroupSize" | "max_row_group_size" => Ok(GeneratedField::MaxRowGroupSize), + "createdBy" | "created_by" => Ok(GeneratedField::CreatedBy), + "metadataSizeHint" | "metadata_size_hint" => Ok(GeneratedField::MetadataSizeHint), + "compression" => Ok(GeneratedField::Compression), + "dictionaryEnabled" | "dictionary_enabled" => Ok(GeneratedField::DictionaryEnabled), + "statisticsEnabled" | "statistics_enabled" => Ok(GeneratedField::StatisticsEnabled), + "maxStatisticsSize" | "max_statistics_size" => Ok(GeneratedField::MaxStatisticsSize), + "columnIndexTruncateLength" | "column_index_truncate_length" => Ok(GeneratedField::ColumnIndexTruncateLength), + "encoding" => Ok(GeneratedField::Encoding), + "bloomFilterFpp" | "bloom_filter_fpp" => Ok(GeneratedField::BloomFilterFpp), + "bloomFilterNdv" | "bloom_filter_ndv" => Ok(GeneratedField::BloomFilterNdv), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ParquetOptions; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.ParquetOptions") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut enable_page_index__ = None; + let mut pruning__ = None; + let mut skip_metadata__ = None; + let mut pushdown_filters__ = None; + let mut reorder_filters__ = None; + let mut data_pagesize_limit__ = None; + let mut write_batch_size__ = None; + let mut writer_version__ = None; + let mut allow_single_file_parallelism__ = None; + let mut maximum_parallel_row_group_writers__ = None; + let mut maximum_buffered_record_batches_per_stream__ = None; + let mut bloom_filter_on_read__ = None; + let mut bloom_filter_on_write__ = None; + let mut schema_force_view_types__ = None; + let mut binary_as_string__ = None; + let mut dictionary_page_size_limit__ = None; + let mut data_page_row_count_limit__ = None; + let mut max_row_group_size__ = None; + let mut created_by__ = None; + let mut metadata_size_hint_opt__ = None; + let mut compression_opt__ = None; + let mut dictionary_enabled_opt__ = None; + let mut statistics_enabled_opt__ = None; + let mut max_statistics_size_opt__ = None; + let mut column_index_truncate_length_opt__ = None; + let mut encoding_opt__ = None; + let mut bloom_filter_fpp_opt__ = None; + let mut bloom_filter_ndv_opt__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::EnablePageIndex => { + if enable_page_index__.is_some() { + return Err(serde::de::Error::duplicate_field("enablePageIndex")); + } + enable_page_index__ = Some(map_.next_value()?); + } + GeneratedField::Pruning => { + if pruning__.is_some() { + return Err(serde::de::Error::duplicate_field("pruning")); + } + pruning__ = Some(map_.next_value()?); + } + GeneratedField::SkipMetadata => { + if skip_metadata__.is_some() { + return Err(serde::de::Error::duplicate_field("skipMetadata")); + } + skip_metadata__ = Some(map_.next_value()?); + } + GeneratedField::PushdownFilters => { + if pushdown_filters__.is_some() { + return Err(serde::de::Error::duplicate_field("pushdownFilters")); + } + pushdown_filters__ = Some(map_.next_value()?); + } + GeneratedField::ReorderFilters => { + if reorder_filters__.is_some() { + return Err(serde::de::Error::duplicate_field("reorderFilters")); + } + reorder_filters__ = Some(map_.next_value()?); + } + GeneratedField::DataPagesizeLimit => { + if data_pagesize_limit__.is_some() { + return Err(serde::de::Error::duplicate_field("dataPagesizeLimit")); + } + data_pagesize_limit__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::WriteBatchSize => { + if write_batch_size__.is_some() { + return Err(serde::de::Error::duplicate_field("writeBatchSize")); + } + write_batch_size__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::WriterVersion => { + if writer_version__.is_some() { + return Err(serde::de::Error::duplicate_field("writerVersion")); + } + writer_version__ = Some(map_.next_value()?); + } + GeneratedField::AllowSingleFileParallelism => { + if allow_single_file_parallelism__.is_some() { + return Err(serde::de::Error::duplicate_field("allowSingleFileParallelism")); + } + allow_single_file_parallelism__ = Some(map_.next_value()?); + } + GeneratedField::MaximumParallelRowGroupWriters => { + if maximum_parallel_row_group_writers__.is_some() { + return Err(serde::de::Error::duplicate_field("maximumParallelRowGroupWriters")); + } + maximum_parallel_row_group_writers__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::MaximumBufferedRecordBatchesPerStream => { + if maximum_buffered_record_batches_per_stream__.is_some() { + return Err(serde::de::Error::duplicate_field("maximumBufferedRecordBatchesPerStream")); + } + maximum_buffered_record_batches_per_stream__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::BloomFilterOnRead => { + if bloom_filter_on_read__.is_some() { + return Err(serde::de::Error::duplicate_field("bloomFilterOnRead")); + } + bloom_filter_on_read__ = Some(map_.next_value()?); + } + GeneratedField::BloomFilterOnWrite => { + if bloom_filter_on_write__.is_some() { + return Err(serde::de::Error::duplicate_field("bloomFilterOnWrite")); + } + bloom_filter_on_write__ = Some(map_.next_value()?); + } + GeneratedField::SchemaForceViewTypes => { + if schema_force_view_types__.is_some() { + return Err(serde::de::Error::duplicate_field("schemaForceViewTypes")); + } + schema_force_view_types__ = Some(map_.next_value()?); + } + GeneratedField::BinaryAsString => { + if binary_as_string__.is_some() { + return Err(serde::de::Error::duplicate_field("binaryAsString")); + } + binary_as_string__ = Some(map_.next_value()?); + } + GeneratedField::DictionaryPageSizeLimit => { + if dictionary_page_size_limit__.is_some() { + return Err(serde::de::Error::duplicate_field("dictionaryPageSizeLimit")); + } + dictionary_page_size_limit__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::DataPageRowCountLimit => { + if data_page_row_count_limit__.is_some() { + return Err(serde::de::Error::duplicate_field("dataPageRowCountLimit")); + } + data_page_row_count_limit__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::MaxRowGroupSize => { + if max_row_group_size__.is_some() { + return Err(serde::de::Error::duplicate_field("maxRowGroupSize")); + } + max_row_group_size__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::CreatedBy => { + if created_by__.is_some() { + return Err(serde::de::Error::duplicate_field("createdBy")); + } + created_by__ = Some(map_.next_value()?); + } + GeneratedField::MetadataSizeHint => { + if metadata_size_hint_opt__.is_some() { + return Err(serde::de::Error::duplicate_field("metadataSizeHint")); + } + metadata_size_hint_opt__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| parquet_options::MetadataSizeHintOpt::MetadataSizeHint(x.0)); + } + GeneratedField::Compression => { + if compression_opt__.is_some() { + return Err(serde::de::Error::duplicate_field("compression")); + } + compression_opt__ = map_.next_value::<::std::option::Option<_>>()?.map(parquet_options::CompressionOpt::Compression); + } + GeneratedField::DictionaryEnabled => { + if dictionary_enabled_opt__.is_some() { + return Err(serde::de::Error::duplicate_field("dictionaryEnabled")); + } + dictionary_enabled_opt__ = map_.next_value::<::std::option::Option<_>>()?.map(parquet_options::DictionaryEnabledOpt::DictionaryEnabled); + } + GeneratedField::StatisticsEnabled => { + if statistics_enabled_opt__.is_some() { + return Err(serde::de::Error::duplicate_field("statisticsEnabled")); + } + statistics_enabled_opt__ = map_.next_value::<::std::option::Option<_>>()?.map(parquet_options::StatisticsEnabledOpt::StatisticsEnabled); + } + GeneratedField::MaxStatisticsSize => { + if max_statistics_size_opt__.is_some() { + return Err(serde::de::Error::duplicate_field("maxStatisticsSize")); + } + max_statistics_size_opt__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| parquet_options::MaxStatisticsSizeOpt::MaxStatisticsSize(x.0)); + } + GeneratedField::ColumnIndexTruncateLength => { + if column_index_truncate_length_opt__.is_some() { + return Err(serde::de::Error::duplicate_field("columnIndexTruncateLength")); + } + column_index_truncate_length_opt__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| parquet_options::ColumnIndexTruncateLengthOpt::ColumnIndexTruncateLength(x.0)); + } + GeneratedField::Encoding => { + if encoding_opt__.is_some() { + return Err(serde::de::Error::duplicate_field("encoding")); + } + encoding_opt__ = map_.next_value::<::std::option::Option<_>>()?.map(parquet_options::EncodingOpt::Encoding); + } + GeneratedField::BloomFilterFpp => { + if bloom_filter_fpp_opt__.is_some() { + return Err(serde::de::Error::duplicate_field("bloomFilterFpp")); + } + bloom_filter_fpp_opt__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| parquet_options::BloomFilterFppOpt::BloomFilterFpp(x.0)); + } + GeneratedField::BloomFilterNdv => { + if bloom_filter_ndv_opt__.is_some() { + return Err(serde::de::Error::duplicate_field("bloomFilterNdv")); + } + bloom_filter_ndv_opt__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| parquet_options::BloomFilterNdvOpt::BloomFilterNdv(x.0)); + } + } + } + Ok(ParquetOptions { + enable_page_index: enable_page_index__.unwrap_or_default(), + pruning: pruning__.unwrap_or_default(), + skip_metadata: skip_metadata__.unwrap_or_default(), + pushdown_filters: pushdown_filters__.unwrap_or_default(), + reorder_filters: reorder_filters__.unwrap_or_default(), + data_pagesize_limit: data_pagesize_limit__.unwrap_or_default(), + write_batch_size: write_batch_size__.unwrap_or_default(), + writer_version: writer_version__.unwrap_or_default(), + allow_single_file_parallelism: allow_single_file_parallelism__.unwrap_or_default(), + maximum_parallel_row_group_writers: maximum_parallel_row_group_writers__.unwrap_or_default(), + maximum_buffered_record_batches_per_stream: maximum_buffered_record_batches_per_stream__.unwrap_or_default(), + bloom_filter_on_read: bloom_filter_on_read__.unwrap_or_default(), + bloom_filter_on_write: bloom_filter_on_write__.unwrap_or_default(), + schema_force_view_types: schema_force_view_types__.unwrap_or_default(), + binary_as_string: binary_as_string__.unwrap_or_default(), + dictionary_page_size_limit: dictionary_page_size_limit__.unwrap_or_default(), + data_page_row_count_limit: data_page_row_count_limit__.unwrap_or_default(), + max_row_group_size: max_row_group_size__.unwrap_or_default(), + created_by: created_by__.unwrap_or_default(), + metadata_size_hint_opt: metadata_size_hint_opt__, + compression_opt: compression_opt__, + dictionary_enabled_opt: dictionary_enabled_opt__, + statistics_enabled_opt: statistics_enabled_opt__, + max_statistics_size_opt: max_statistics_size_opt__, + column_index_truncate_length_opt: column_index_truncate_length_opt__, + encoding_opt: encoding_opt__, + bloom_filter_fpp_opt: bloom_filter_fpp_opt__, + bloom_filter_ndv_opt: bloom_filter_ndv_opt__, + }) + } + } + deserializer.deserialize_struct("datafusion_common.ParquetOptions", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for Precision { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.precision_info != 0 { + len += 1; + } + if self.val.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.Precision", len)?; + if self.precision_info != 0 { + let v = PrecisionInfo::try_from(self.precision_info) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.precision_info)))?; + struct_ser.serialize_field("precisionInfo", &v)?; + } + if let Some(v) = self.val.as_ref() { + struct_ser.serialize_field("val", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Precision { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "precision_info", + "precisionInfo", + "val", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + PrecisionInfo, + Val, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "precisionInfo" | "precision_info" => Ok(GeneratedField::PrecisionInfo), + "val" => Ok(GeneratedField::Val), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Precision; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.Precision") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut precision_info__ = None; + let mut val__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::PrecisionInfo => { + if precision_info__.is_some() { + return Err(serde::de::Error::duplicate_field("precisionInfo")); + } + precision_info__ = Some(map_.next_value::()? as i32); + } + GeneratedField::Val => { + if val__.is_some() { + return Err(serde::de::Error::duplicate_field("val")); + } + val__ = map_.next_value()?; + } + } + } + Ok(Precision { + precision_info: precision_info__.unwrap_or_default(), + val: val__, + }) + } + } + deserializer.deserialize_struct("datafusion_common.Precision", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for PrecisionInfo { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::Exact => "EXACT", + Self::Inexact => "INEXACT", + Self::Absent => "ABSENT", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for PrecisionInfo { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "EXACT", + "INEXACT", + "ABSENT", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PrecisionInfo; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "EXACT" => Ok(PrecisionInfo::Exact), + "INEXACT" => Ok(PrecisionInfo::Inexact), + "ABSENT" => Ok(PrecisionInfo::Absent), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} +impl serde::Serialize for PrimaryKeyConstraint { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.indices.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.PrimaryKeyConstraint", len)?; + if !self.indices.is_empty() { + struct_ser.serialize_field("indices", &self.indices.iter().map(ToString::to_string).collect::>())?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for PrimaryKeyConstraint { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "indices", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Indices, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "indices" => Ok(GeneratedField::Indices), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PrimaryKeyConstraint; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.PrimaryKeyConstraint") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut indices__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Indices => { + if indices__.is_some() { + return Err(serde::de::Error::duplicate_field("indices")); + } + indices__ = + Some(map_.next_value::>>()? + .into_iter().map(|x| x.0).collect()) + ; + } + } + } + Ok(PrimaryKeyConstraint { + indices: indices__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.PrimaryKeyConstraint", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ScalarDictionaryValue { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.index_type.is_some() { + len += 1; + } + if self.value.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.ScalarDictionaryValue", len)?; + if let Some(v) = self.index_type.as_ref() { + struct_ser.serialize_field("indexType", v)?; + } + if let Some(v) = self.value.as_ref() { + struct_ser.serialize_field("value", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ScalarDictionaryValue { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "index_type", + "indexType", + "value", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + IndexType, + Value, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "indexType" | "index_type" => Ok(GeneratedField::IndexType), + "value" => Ok(GeneratedField::Value), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ScalarDictionaryValue; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.ScalarDictionaryValue") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut index_type__ = None; + let mut value__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::IndexType => { + if index_type__.is_some() { + return Err(serde::de::Error::duplicate_field("indexType")); + } + index_type__ = map_.next_value()?; + } + GeneratedField::Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("value")); + } + value__ = map_.next_value()?; + } + } + } + Ok(ScalarDictionaryValue { + index_type: index_type__, + value: value__, + }) + } + } + deserializer.deserialize_struct("datafusion_common.ScalarDictionaryValue", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ScalarFixedSizeBinary { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.values.is_empty() { + len += 1; + } + if self.length != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.ScalarFixedSizeBinary", len)?; + if !self.values.is_empty() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("values", pbjson::private::base64::encode(&self.values).as_str())?; + } + if self.length != 0 { + struct_ser.serialize_field("length", &self.length)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ScalarFixedSizeBinary { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "values", + "length", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Values, + Length, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "values" => Ok(GeneratedField::Values), + "length" => Ok(GeneratedField::Length), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ScalarFixedSizeBinary; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.ScalarFixedSizeBinary") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut values__ = None; + let mut length__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Values => { + if values__.is_some() { + return Err(serde::de::Error::duplicate_field("values")); + } + values__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + GeneratedField::Length => { + if length__.is_some() { + return Err(serde::de::Error::duplicate_field("length")); + } + length__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(ScalarFixedSizeBinary { + values: values__.unwrap_or_default(), + length: length__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.ScalarFixedSizeBinary", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ScalarNestedValue { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.ipc_message.is_empty() { + len += 1; + } + if !self.arrow_data.is_empty() { + len += 1; + } + if self.schema.is_some() { + len += 1; + } + if !self.dictionaries.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.ScalarNestedValue", len)?; + if !self.ipc_message.is_empty() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("ipcMessage", pbjson::private::base64::encode(&self.ipc_message).as_str())?; + } + if !self.arrow_data.is_empty() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("arrowData", pbjson::private::base64::encode(&self.arrow_data).as_str())?; + } + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; + } + if !self.dictionaries.is_empty() { + struct_ser.serialize_field("dictionaries", &self.dictionaries)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ScalarNestedValue { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "ipc_message", + "ipcMessage", + "arrow_data", + "arrowData", + "schema", + "dictionaries", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + IpcMessage, + ArrowData, + Schema, + Dictionaries, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "ipcMessage" | "ipc_message" => Ok(GeneratedField::IpcMessage), + "arrowData" | "arrow_data" => Ok(GeneratedField::ArrowData), + "schema" => Ok(GeneratedField::Schema), + "dictionaries" => Ok(GeneratedField::Dictionaries), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ScalarNestedValue; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.ScalarNestedValue") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut ipc_message__ = None; + let mut arrow_data__ = None; + let mut schema__ = None; + let mut dictionaries__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::IpcMessage => { + if ipc_message__.is_some() { + return Err(serde::de::Error::duplicate_field("ipcMessage")); + } + ipc_message__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + GeneratedField::ArrowData => { + if arrow_data__.is_some() { + return Err(serde::de::Error::duplicate_field("arrowData")); + } + arrow_data__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); + } + schema__ = map_.next_value()?; + } + GeneratedField::Dictionaries => { + if dictionaries__.is_some() { + return Err(serde::de::Error::duplicate_field("dictionaries")); + } + dictionaries__ = Some(map_.next_value()?); + } + } + } + Ok(ScalarNestedValue { + ipc_message: ipc_message__.unwrap_or_default(), + arrow_data: arrow_data__.unwrap_or_default(), + schema: schema__, + dictionaries: dictionaries__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.ScalarNestedValue", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for scalar_nested_value::Dictionary { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.ipc_message.is_empty() { + len += 1; + } + if !self.arrow_data.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.ScalarNestedValue.Dictionary", len)?; + if !self.ipc_message.is_empty() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("ipcMessage", pbjson::private::base64::encode(&self.ipc_message).as_str())?; + } + if !self.arrow_data.is_empty() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("arrowData", pbjson::private::base64::encode(&self.arrow_data).as_str())?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for scalar_nested_value::Dictionary { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "ipc_message", + "ipcMessage", + "arrow_data", + "arrowData", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + IpcMessage, + ArrowData, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "ipcMessage" | "ipc_message" => Ok(GeneratedField::IpcMessage), + "arrowData" | "arrow_data" => Ok(GeneratedField::ArrowData), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = scalar_nested_value::Dictionary; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.ScalarNestedValue.Dictionary") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut ipc_message__ = None; + let mut arrow_data__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::IpcMessage => { + if ipc_message__.is_some() { + return Err(serde::de::Error::duplicate_field("ipcMessage")); + } + ipc_message__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + GeneratedField::ArrowData => { + if arrow_data__.is_some() { + return Err(serde::de::Error::duplicate_field("arrowData")); + } + arrow_data__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + } + } + Ok(scalar_nested_value::Dictionary { + ipc_message: ipc_message__.unwrap_or_default(), + arrow_data: arrow_data__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.ScalarNestedValue.Dictionary", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ScalarTime32Value { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.value.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.ScalarTime32Value", len)?; + if let Some(v) = self.value.as_ref() { + match v { + scalar_time32_value::Value::Time32SecondValue(v) => { + struct_ser.serialize_field("time32SecondValue", v)?; + } + scalar_time32_value::Value::Time32MillisecondValue(v) => { + struct_ser.serialize_field("time32MillisecondValue", v)?; + } + } + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ScalarTime32Value { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "time32_second_value", + "time32SecondValue", + "time32_millisecond_value", + "time32MillisecondValue", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Time32SecondValue, + Time32MillisecondValue, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "time32SecondValue" | "time32_second_value" => Ok(GeneratedField::Time32SecondValue), + "time32MillisecondValue" | "time32_millisecond_value" => Ok(GeneratedField::Time32MillisecondValue), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ScalarTime32Value; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.ScalarTime32Value") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut value__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Time32SecondValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("time32SecondValue")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_time32_value::Value::Time32SecondValue(x.0)); + } + GeneratedField::Time32MillisecondValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("time32MillisecondValue")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_time32_value::Value::Time32MillisecondValue(x.0)); + } + } + } + Ok(ScalarTime32Value { + value: value__, + }) + } + } + deserializer.deserialize_struct("datafusion_common.ScalarTime32Value", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ScalarTime64Value { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.value.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.ScalarTime64Value", len)?; + if let Some(v) = self.value.as_ref() { + match v { + scalar_time64_value::Value::Time64MicrosecondValue(v) => { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("time64MicrosecondValue", ToString::to_string(&v).as_str())?; + } + scalar_time64_value::Value::Time64NanosecondValue(v) => { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("time64NanosecondValue", ToString::to_string(&v).as_str())?; + } + } + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ScalarTime64Value { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "time64_microsecond_value", + "time64MicrosecondValue", + "time64_nanosecond_value", + "time64NanosecondValue", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Time64MicrosecondValue, + Time64NanosecondValue, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "time64MicrosecondValue" | "time64_microsecond_value" => Ok(GeneratedField::Time64MicrosecondValue), + "time64NanosecondValue" | "time64_nanosecond_value" => Ok(GeneratedField::Time64NanosecondValue), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ScalarTime64Value; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.ScalarTime64Value") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut value__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Time64MicrosecondValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("time64MicrosecondValue")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_time64_value::Value::Time64MicrosecondValue(x.0)); + } + GeneratedField::Time64NanosecondValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("time64NanosecondValue")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_time64_value::Value::Time64NanosecondValue(x.0)); + } + } + } + Ok(ScalarTime64Value { + value: value__, + }) + } + } + deserializer.deserialize_struct("datafusion_common.ScalarTime64Value", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ScalarTimestampValue { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.timezone.is_empty() { + len += 1; + } + if self.value.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.ScalarTimestampValue", len)?; + if !self.timezone.is_empty() { + struct_ser.serialize_field("timezone", &self.timezone)?; + } + if let Some(v) = self.value.as_ref() { + match v { + scalar_timestamp_value::Value::TimeMicrosecondValue(v) => { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("timeMicrosecondValue", ToString::to_string(&v).as_str())?; + } + scalar_timestamp_value::Value::TimeNanosecondValue(v) => { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("timeNanosecondValue", ToString::to_string(&v).as_str())?; + } + scalar_timestamp_value::Value::TimeSecondValue(v) => { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("timeSecondValue", ToString::to_string(&v).as_str())?; + } + scalar_timestamp_value::Value::TimeMillisecondValue(v) => { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("timeMillisecondValue", ToString::to_string(&v).as_str())?; + } + } + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ScalarTimestampValue { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "timezone", + "time_microsecond_value", + "timeMicrosecondValue", + "time_nanosecond_value", + "timeNanosecondValue", + "time_second_value", + "timeSecondValue", + "time_millisecond_value", + "timeMillisecondValue", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Timezone, + TimeMicrosecondValue, + TimeNanosecondValue, + TimeSecondValue, + TimeMillisecondValue, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "timezone" => Ok(GeneratedField::Timezone), + "timeMicrosecondValue" | "time_microsecond_value" => Ok(GeneratedField::TimeMicrosecondValue), + "timeNanosecondValue" | "time_nanosecond_value" => Ok(GeneratedField::TimeNanosecondValue), + "timeSecondValue" | "time_second_value" => Ok(GeneratedField::TimeSecondValue), + "timeMillisecondValue" | "time_millisecond_value" => Ok(GeneratedField::TimeMillisecondValue), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ScalarTimestampValue; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.ScalarTimestampValue") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut timezone__ = None; + let mut value__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Timezone => { + if timezone__.is_some() { + return Err(serde::de::Error::duplicate_field("timezone")); + } + timezone__ = Some(map_.next_value()?); + } + GeneratedField::TimeMicrosecondValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("timeMicrosecondValue")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_timestamp_value::Value::TimeMicrosecondValue(x.0)); + } + GeneratedField::TimeNanosecondValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("timeNanosecondValue")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_timestamp_value::Value::TimeNanosecondValue(x.0)); + } + GeneratedField::TimeSecondValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("timeSecondValue")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_timestamp_value::Value::TimeSecondValue(x.0)); + } + GeneratedField::TimeMillisecondValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("timeMillisecondValue")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_timestamp_value::Value::TimeMillisecondValue(x.0)); + } + } + } + Ok(ScalarTimestampValue { + timezone: timezone__.unwrap_or_default(), + value: value__, + }) + } + } + deserializer.deserialize_struct("datafusion_common.ScalarTimestampValue", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for ScalarValue { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.value.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.ScalarValue", len)?; + if let Some(v) = self.value.as_ref() { + match v { + scalar_value::Value::NullValue(v) => { + struct_ser.serialize_field("nullValue", v)?; + } + scalar_value::Value::BoolValue(v) => { + struct_ser.serialize_field("boolValue", v)?; + } + scalar_value::Value::Utf8Value(v) => { + struct_ser.serialize_field("utf8Value", v)?; + } + scalar_value::Value::LargeUtf8Value(v) => { + struct_ser.serialize_field("largeUtf8Value", v)?; + } + scalar_value::Value::Utf8ViewValue(v) => { + struct_ser.serialize_field("utf8ViewValue", v)?; + } + scalar_value::Value::Int8Value(v) => { + struct_ser.serialize_field("int8Value", v)?; + } + scalar_value::Value::Int16Value(v) => { + struct_ser.serialize_field("int16Value", v)?; + } + scalar_value::Value::Int32Value(v) => { + struct_ser.serialize_field("int32Value", v)?; + } + scalar_value::Value::Int64Value(v) => { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("int64Value", ToString::to_string(&v).as_str())?; + } + scalar_value::Value::Uint8Value(v) => { + struct_ser.serialize_field("uint8Value", v)?; + } + scalar_value::Value::Uint16Value(v) => { + struct_ser.serialize_field("uint16Value", v)?; + } + scalar_value::Value::Uint32Value(v) => { + struct_ser.serialize_field("uint32Value", v)?; + } + scalar_value::Value::Uint64Value(v) => { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("uint64Value", ToString::to_string(&v).as_str())?; + } + scalar_value::Value::Float32Value(v) => { + struct_ser.serialize_field("float32Value", v)?; + } + scalar_value::Value::Float64Value(v) => { + struct_ser.serialize_field("float64Value", v)?; + } + scalar_value::Value::Date32Value(v) => { + struct_ser.serialize_field("date32Value", v)?; + } + scalar_value::Value::Time32Value(v) => { + struct_ser.serialize_field("time32Value", v)?; + } + scalar_value::Value::LargeListValue(v) => { + struct_ser.serialize_field("largeListValue", v)?; + } + scalar_value::Value::ListValue(v) => { + struct_ser.serialize_field("listValue", v)?; + } + scalar_value::Value::FixedSizeListValue(v) => { + struct_ser.serialize_field("fixedSizeListValue", v)?; + } + scalar_value::Value::StructValue(v) => { + struct_ser.serialize_field("structValue", v)?; + } + scalar_value::Value::MapValue(v) => { + struct_ser.serialize_field("mapValue", v)?; + } + scalar_value::Value::Decimal128Value(v) => { + struct_ser.serialize_field("decimal128Value", v)?; + } + scalar_value::Value::Decimal256Value(v) => { + struct_ser.serialize_field("decimal256Value", v)?; + } + scalar_value::Value::Date64Value(v) => { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("date64Value", ToString::to_string(&v).as_str())?; + } + scalar_value::Value::IntervalYearmonthValue(v) => { + struct_ser.serialize_field("intervalYearmonthValue", v)?; + } + scalar_value::Value::DurationSecondValue(v) => { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("durationSecondValue", ToString::to_string(&v).as_str())?; + } + scalar_value::Value::DurationMillisecondValue(v) => { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("durationMillisecondValue", ToString::to_string(&v).as_str())?; + } + scalar_value::Value::DurationMicrosecondValue(v) => { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("durationMicrosecondValue", ToString::to_string(&v).as_str())?; + } + scalar_value::Value::DurationNanosecondValue(v) => { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("durationNanosecondValue", ToString::to_string(&v).as_str())?; + } + scalar_value::Value::TimestampValue(v) => { + struct_ser.serialize_field("timestampValue", v)?; + } + scalar_value::Value::DictionaryValue(v) => { + struct_ser.serialize_field("dictionaryValue", v)?; + } + scalar_value::Value::BinaryValue(v) => { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("binaryValue", pbjson::private::base64::encode(&v).as_str())?; + } + scalar_value::Value::LargeBinaryValue(v) => { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("largeBinaryValue", pbjson::private::base64::encode(&v).as_str())?; + } + scalar_value::Value::BinaryViewValue(v) => { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("binaryViewValue", pbjson::private::base64::encode(&v).as_str())?; + } + scalar_value::Value::Time64Value(v) => { + struct_ser.serialize_field("time64Value", v)?; + } + scalar_value::Value::IntervalDaytimeValue(v) => { + struct_ser.serialize_field("intervalDaytimeValue", v)?; + } + scalar_value::Value::IntervalMonthDayNano(v) => { + struct_ser.serialize_field("intervalMonthDayNano", v)?; + } + scalar_value::Value::FixedSizeBinaryValue(v) => { + struct_ser.serialize_field("fixedSizeBinaryValue", v)?; + } + scalar_value::Value::UnionValue(v) => { + struct_ser.serialize_field("unionValue", v)?; + } + } + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for ScalarValue { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "null_value", + "nullValue", + "bool_value", + "boolValue", + "utf8_value", + "utf8Value", + "large_utf8_value", + "largeUtf8Value", + "utf8_view_value", + "utf8ViewValue", + "int8_value", + "int8Value", + "int16_value", + "int16Value", + "int32_value", + "int32Value", + "int64_value", + "int64Value", + "uint8_value", + "uint8Value", + "uint16_value", + "uint16Value", + "uint32_value", + "uint32Value", + "uint64_value", + "uint64Value", + "float32_value", + "float32Value", + "float64_value", + "float64Value", + "date_32_value", + "date32Value", + "time32_value", + "time32Value", + "large_list_value", + "largeListValue", + "list_value", + "listValue", + "fixed_size_list_value", + "fixedSizeListValue", + "struct_value", + "structValue", + "map_value", + "mapValue", + "decimal128_value", + "decimal128Value", + "decimal256_value", + "decimal256Value", + "date_64_value", + "date64Value", + "interval_yearmonth_value", + "intervalYearmonthValue", + "duration_second_value", + "durationSecondValue", + "duration_millisecond_value", + "durationMillisecondValue", + "duration_microsecond_value", + "durationMicrosecondValue", + "duration_nanosecond_value", + "durationNanosecondValue", + "timestamp_value", + "timestampValue", + "dictionary_value", + "dictionaryValue", + "binary_value", + "binaryValue", + "large_binary_value", + "largeBinaryValue", + "binary_view_value", + "binaryViewValue", + "time64_value", + "time64Value", + "interval_daytime_value", + "intervalDaytimeValue", + "interval_month_day_nano", + "intervalMonthDayNano", + "fixed_size_binary_value", + "fixedSizeBinaryValue", + "union_value", + "unionValue", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + NullValue, + BoolValue, + Utf8Value, + LargeUtf8Value, + Utf8ViewValue, + Int8Value, + Int16Value, + Int32Value, + Int64Value, + Uint8Value, + Uint16Value, + Uint32Value, + Uint64Value, + Float32Value, + Float64Value, + Date32Value, + Time32Value, + LargeListValue, + ListValue, + FixedSizeListValue, + StructValue, + MapValue, + Decimal128Value, + Decimal256Value, + Date64Value, + IntervalYearmonthValue, + DurationSecondValue, + DurationMillisecondValue, + DurationMicrosecondValue, + DurationNanosecondValue, + TimestampValue, + DictionaryValue, + BinaryValue, + LargeBinaryValue, + BinaryViewValue, + Time64Value, + IntervalDaytimeValue, + IntervalMonthDayNano, + FixedSizeBinaryValue, + UnionValue, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "nullValue" | "null_value" => Ok(GeneratedField::NullValue), + "boolValue" | "bool_value" => Ok(GeneratedField::BoolValue), + "utf8Value" | "utf8_value" => Ok(GeneratedField::Utf8Value), + "largeUtf8Value" | "large_utf8_value" => Ok(GeneratedField::LargeUtf8Value), + "utf8ViewValue" | "utf8_view_value" => Ok(GeneratedField::Utf8ViewValue), + "int8Value" | "int8_value" => Ok(GeneratedField::Int8Value), + "int16Value" | "int16_value" => Ok(GeneratedField::Int16Value), + "int32Value" | "int32_value" => Ok(GeneratedField::Int32Value), + "int64Value" | "int64_value" => Ok(GeneratedField::Int64Value), + "uint8Value" | "uint8_value" => Ok(GeneratedField::Uint8Value), + "uint16Value" | "uint16_value" => Ok(GeneratedField::Uint16Value), + "uint32Value" | "uint32_value" => Ok(GeneratedField::Uint32Value), + "uint64Value" | "uint64_value" => Ok(GeneratedField::Uint64Value), + "float32Value" | "float32_value" => Ok(GeneratedField::Float32Value), + "float64Value" | "float64_value" => Ok(GeneratedField::Float64Value), + "date32Value" | "date_32_value" => Ok(GeneratedField::Date32Value), + "time32Value" | "time32_value" => Ok(GeneratedField::Time32Value), + "largeListValue" | "large_list_value" => Ok(GeneratedField::LargeListValue), + "listValue" | "list_value" => Ok(GeneratedField::ListValue), + "fixedSizeListValue" | "fixed_size_list_value" => Ok(GeneratedField::FixedSizeListValue), + "structValue" | "struct_value" => Ok(GeneratedField::StructValue), + "mapValue" | "map_value" => Ok(GeneratedField::MapValue), + "decimal128Value" | "decimal128_value" => Ok(GeneratedField::Decimal128Value), + "decimal256Value" | "decimal256_value" => Ok(GeneratedField::Decimal256Value), + "date64Value" | "date_64_value" => Ok(GeneratedField::Date64Value), + "intervalYearmonthValue" | "interval_yearmonth_value" => Ok(GeneratedField::IntervalYearmonthValue), + "durationSecondValue" | "duration_second_value" => Ok(GeneratedField::DurationSecondValue), + "durationMillisecondValue" | "duration_millisecond_value" => Ok(GeneratedField::DurationMillisecondValue), + "durationMicrosecondValue" | "duration_microsecond_value" => Ok(GeneratedField::DurationMicrosecondValue), + "durationNanosecondValue" | "duration_nanosecond_value" => Ok(GeneratedField::DurationNanosecondValue), + "timestampValue" | "timestamp_value" => Ok(GeneratedField::TimestampValue), + "dictionaryValue" | "dictionary_value" => Ok(GeneratedField::DictionaryValue), + "binaryValue" | "binary_value" => Ok(GeneratedField::BinaryValue), + "largeBinaryValue" | "large_binary_value" => Ok(GeneratedField::LargeBinaryValue), + "binaryViewValue" | "binary_view_value" => Ok(GeneratedField::BinaryViewValue), + "time64Value" | "time64_value" => Ok(GeneratedField::Time64Value), + "intervalDaytimeValue" | "interval_daytime_value" => Ok(GeneratedField::IntervalDaytimeValue), + "intervalMonthDayNano" | "interval_month_day_nano" => Ok(GeneratedField::IntervalMonthDayNano), + "fixedSizeBinaryValue" | "fixed_size_binary_value" => Ok(GeneratedField::FixedSizeBinaryValue), + "unionValue" | "union_value" => Ok(GeneratedField::UnionValue), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = ScalarValue; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.ScalarValue") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut value__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::NullValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("nullValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::NullValue) +; + } + GeneratedField::BoolValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("boolValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::BoolValue); + } + GeneratedField::Utf8Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("utf8Value")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Utf8Value); + } + GeneratedField::LargeUtf8Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("largeUtf8Value")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::LargeUtf8Value); + } + GeneratedField::Utf8ViewValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("utf8ViewValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Utf8ViewValue); + } + GeneratedField::Int8Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("int8Value")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Int8Value(x.0)); + } + GeneratedField::Int16Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("int16Value")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Int16Value(x.0)); + } + GeneratedField::Int32Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("int32Value")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Int32Value(x.0)); + } + GeneratedField::Int64Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("int64Value")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Int64Value(x.0)); + } + GeneratedField::Uint8Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("uint8Value")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Uint8Value(x.0)); + } + GeneratedField::Uint16Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("uint16Value")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Uint16Value(x.0)); + } + GeneratedField::Uint32Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("uint32Value")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Uint32Value(x.0)); + } + GeneratedField::Uint64Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("uint64Value")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Uint64Value(x.0)); + } + GeneratedField::Float32Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("float32Value")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Float32Value(x.0)); + } + GeneratedField::Float64Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("float64Value")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Float64Value(x.0)); + } + GeneratedField::Date32Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("date32Value")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Date32Value(x.0)); + } + GeneratedField::Time32Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("time32Value")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Time32Value) +; + } + GeneratedField::LargeListValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("largeListValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::LargeListValue) +; + } + GeneratedField::ListValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("listValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::ListValue) +; + } + GeneratedField::FixedSizeListValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("fixedSizeListValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::FixedSizeListValue) +; + } + GeneratedField::StructValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("structValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::StructValue) +; + } + GeneratedField::MapValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("mapValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::MapValue) +; + } + GeneratedField::Decimal128Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("decimal128Value")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Decimal128Value) +; + } + GeneratedField::Decimal256Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("decimal256Value")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Decimal256Value) +; + } + GeneratedField::Date64Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("date64Value")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Date64Value(x.0)); + } + GeneratedField::IntervalYearmonthValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("intervalYearmonthValue")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::IntervalYearmonthValue(x.0)); + } + GeneratedField::DurationSecondValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("durationSecondValue")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::DurationSecondValue(x.0)); + } + GeneratedField::DurationMillisecondValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("durationMillisecondValue")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::DurationMillisecondValue(x.0)); + } + GeneratedField::DurationMicrosecondValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("durationMicrosecondValue")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::DurationMicrosecondValue(x.0)); + } + GeneratedField::DurationNanosecondValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("durationNanosecondValue")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::DurationNanosecondValue(x.0)); + } + GeneratedField::TimestampValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("timestampValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::TimestampValue) +; + } + GeneratedField::DictionaryValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("dictionaryValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::DictionaryValue) +; + } + GeneratedField::BinaryValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("binaryValue")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| scalar_value::Value::BinaryValue(x.0)); + } + GeneratedField::LargeBinaryValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("largeBinaryValue")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| scalar_value::Value::LargeBinaryValue(x.0)); + } + GeneratedField::BinaryViewValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("binaryViewValue")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| scalar_value::Value::BinaryViewValue(x.0)); + } + GeneratedField::Time64Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("time64Value")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Time64Value) +; + } + GeneratedField::IntervalDaytimeValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("intervalDaytimeValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::IntervalDaytimeValue) +; + } + GeneratedField::IntervalMonthDayNano => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("intervalMonthDayNano")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::IntervalMonthDayNano) +; + } + GeneratedField::FixedSizeBinaryValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("fixedSizeBinaryValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::FixedSizeBinaryValue) +; + } + GeneratedField::UnionValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("unionValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::UnionValue) +; + } + } + } + Ok(ScalarValue { + value: value__, + }) + } + } + deserializer.deserialize_struct("datafusion_common.ScalarValue", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for Schema { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.columns.is_empty() { + len += 1; + } + if !self.metadata.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.Schema", len)?; + if !self.columns.is_empty() { + struct_ser.serialize_field("columns", &self.columns)?; + } + if !self.metadata.is_empty() { + struct_ser.serialize_field("metadata", &self.metadata)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Schema { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "columns", + "metadata", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Columns, + Metadata, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "columns" => Ok(GeneratedField::Columns), + "metadata" => Ok(GeneratedField::Metadata), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Schema; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.Schema") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut columns__ = None; + let mut metadata__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Columns => { + if columns__.is_some() { + return Err(serde::de::Error::duplicate_field("columns")); + } + columns__ = Some(map_.next_value()?); + } + GeneratedField::Metadata => { + if metadata__.is_some() { + return Err(serde::de::Error::duplicate_field("metadata")); + } + metadata__ = Some( + map_.next_value::>()? + ); + } + } + } + Ok(Schema { + columns: columns__.unwrap_or_default(), + metadata: metadata__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.Schema", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for Statistics { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.num_rows.is_some() { + len += 1; + } + if self.total_byte_size.is_some() { + len += 1; + } + if !self.column_stats.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.Statistics", len)?; + if let Some(v) = self.num_rows.as_ref() { + struct_ser.serialize_field("numRows", v)?; + } + if let Some(v) = self.total_byte_size.as_ref() { + struct_ser.serialize_field("totalByteSize", v)?; + } + if !self.column_stats.is_empty() { + struct_ser.serialize_field("columnStats", &self.column_stats)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Statistics { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "num_rows", + "numRows", + "total_byte_size", + "totalByteSize", + "column_stats", + "columnStats", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + NumRows, + TotalByteSize, + ColumnStats, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "numRows" | "num_rows" => Ok(GeneratedField::NumRows), + "totalByteSize" | "total_byte_size" => Ok(GeneratedField::TotalByteSize), + "columnStats" | "column_stats" => Ok(GeneratedField::ColumnStats), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Statistics; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.Statistics") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut num_rows__ = None; + let mut total_byte_size__ = None; + let mut column_stats__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::NumRows => { + if num_rows__.is_some() { + return Err(serde::de::Error::duplicate_field("numRows")); + } + num_rows__ = map_.next_value()?; + } + GeneratedField::TotalByteSize => { + if total_byte_size__.is_some() { + return Err(serde::de::Error::duplicate_field("totalByteSize")); + } + total_byte_size__ = map_.next_value()?; + } + GeneratedField::ColumnStats => { + if column_stats__.is_some() { + return Err(serde::de::Error::duplicate_field("columnStats")); + } + column_stats__ = Some(map_.next_value()?); + } + } + } + Ok(Statistics { + num_rows: num_rows__, + total_byte_size: total_byte_size__, + column_stats: column_stats__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.Statistics", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for Struct { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.sub_field_types.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.Struct", len)?; + if !self.sub_field_types.is_empty() { + struct_ser.serialize_field("subFieldTypes", &self.sub_field_types)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Struct { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "sub_field_types", + "subFieldTypes", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + SubFieldTypes, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "subFieldTypes" | "sub_field_types" => Ok(GeneratedField::SubFieldTypes), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Struct; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.Struct") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut sub_field_types__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::SubFieldTypes => { + if sub_field_types__.is_some() { + return Err(serde::de::Error::duplicate_field("subFieldTypes")); + } + sub_field_types__ = Some(map_.next_value()?); + } + } + } + Ok(Struct { + sub_field_types: sub_field_types__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.Struct", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for TableParquetOptions { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.global.is_some() { + len += 1; + } + if !self.column_specific_options.is_empty() { + len += 1; + } + if !self.key_value_metadata.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.TableParquetOptions", len)?; + if let Some(v) = self.global.as_ref() { + struct_ser.serialize_field("global", v)?; + } + if !self.column_specific_options.is_empty() { + struct_ser.serialize_field("columnSpecificOptions", &self.column_specific_options)?; + } + if !self.key_value_metadata.is_empty() { + struct_ser.serialize_field("keyValueMetadata", &self.key_value_metadata)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for TableParquetOptions { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "global", + "column_specific_options", + "columnSpecificOptions", + "key_value_metadata", + "keyValueMetadata", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Global, + ColumnSpecificOptions, + KeyValueMetadata, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "global" => Ok(GeneratedField::Global), + "columnSpecificOptions" | "column_specific_options" => Ok(GeneratedField::ColumnSpecificOptions), + "keyValueMetadata" | "key_value_metadata" => Ok(GeneratedField::KeyValueMetadata), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = TableParquetOptions; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.TableParquetOptions") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut global__ = None; + let mut column_specific_options__ = None; + let mut key_value_metadata__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Global => { + if global__.is_some() { + return Err(serde::de::Error::duplicate_field("global")); + } + global__ = map_.next_value()?; + } + GeneratedField::ColumnSpecificOptions => { + if column_specific_options__.is_some() { + return Err(serde::de::Error::duplicate_field("columnSpecificOptions")); + } + column_specific_options__ = Some(map_.next_value()?); + } + GeneratedField::KeyValueMetadata => { + if key_value_metadata__.is_some() { + return Err(serde::de::Error::duplicate_field("keyValueMetadata")); + } + key_value_metadata__ = Some( + map_.next_value::>()? + ); + } + } + } + Ok(TableParquetOptions { + global: global__, + column_specific_options: column_specific_options__.unwrap_or_default(), + key_value_metadata: key_value_metadata__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.TableParquetOptions", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for TimeUnit { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::Second => "Second", + Self::Millisecond => "Millisecond", + Self::Microsecond => "Microsecond", + Self::Nanosecond => "Nanosecond", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for TimeUnit { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "Second", + "Millisecond", + "Microsecond", + "Nanosecond", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = TimeUnit; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "Second" => Ok(TimeUnit::Second), + "Millisecond" => Ok(TimeUnit::Millisecond), + "Microsecond" => Ok(TimeUnit::Microsecond), + "Nanosecond" => Ok(TimeUnit::Nanosecond), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} +impl serde::Serialize for Timestamp { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.time_unit != 0 { + len += 1; + } + if !self.timezone.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.Timestamp", len)?; + if self.time_unit != 0 { + let v = TimeUnit::try_from(self.time_unit) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.time_unit)))?; + struct_ser.serialize_field("timeUnit", &v)?; + } + if !self.timezone.is_empty() { + struct_ser.serialize_field("timezone", &self.timezone)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Timestamp { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "time_unit", + "timeUnit", + "timezone", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + TimeUnit, + Timezone, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "timeUnit" | "time_unit" => Ok(GeneratedField::TimeUnit), + "timezone" => Ok(GeneratedField::Timezone), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Timestamp; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.Timestamp") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut time_unit__ = None; + let mut timezone__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::TimeUnit => { + if time_unit__.is_some() { + return Err(serde::de::Error::duplicate_field("timeUnit")); + } + time_unit__ = Some(map_.next_value::()? as i32); + } + GeneratedField::Timezone => { + if timezone__.is_some() { + return Err(serde::de::Error::duplicate_field("timezone")); + } + timezone__ = Some(map_.next_value()?); + } + } + } + Ok(Timestamp { + time_unit: time_unit__.unwrap_or_default(), + timezone: timezone__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.Timestamp", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for Union { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.union_types.is_empty() { + len += 1; + } + if self.union_mode != 0 { + len += 1; + } + if !self.type_ids.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.Union", len)?; + if !self.union_types.is_empty() { + struct_ser.serialize_field("unionTypes", &self.union_types)?; + } + if self.union_mode != 0 { + let v = UnionMode::try_from(self.union_mode) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.union_mode)))?; + struct_ser.serialize_field("unionMode", &v)?; + } + if !self.type_ids.is_empty() { + struct_ser.serialize_field("typeIds", &self.type_ids)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for Union { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "union_types", + "unionTypes", + "union_mode", + "unionMode", + "type_ids", + "typeIds", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + UnionTypes, + UnionMode, + TypeIds, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "unionTypes" | "union_types" => Ok(GeneratedField::UnionTypes), + "unionMode" | "union_mode" => Ok(GeneratedField::UnionMode), + "typeIds" | "type_ids" => Ok(GeneratedField::TypeIds), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = Union; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.Union") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut union_types__ = None; + let mut union_mode__ = None; + let mut type_ids__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::UnionTypes => { + if union_types__.is_some() { + return Err(serde::de::Error::duplicate_field("unionTypes")); + } + union_types__ = Some(map_.next_value()?); + } + GeneratedField::UnionMode => { + if union_mode__.is_some() { + return Err(serde::de::Error::duplicate_field("unionMode")); + } + union_mode__ = Some(map_.next_value::()? as i32); + } + GeneratedField::TypeIds => { + if type_ids__.is_some() { + return Err(serde::de::Error::duplicate_field("typeIds")); + } + type_ids__ = + Some(map_.next_value::>>()? + .into_iter().map(|x| x.0).collect()) + ; + } + } + } + Ok(Union { + union_types: union_types__.unwrap_or_default(), + union_mode: union_mode__.unwrap_or_default(), + type_ids: type_ids__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.Union", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for UnionField { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.field_id != 0 { + len += 1; + } + if self.field.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.UnionField", len)?; + if self.field_id != 0 { + struct_ser.serialize_field("fieldId", &self.field_id)?; + } + if let Some(v) = self.field.as_ref() { + struct_ser.serialize_field("field", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for UnionField { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "field_id", + "fieldId", + "field", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + FieldId, + Field, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "fieldId" | "field_id" => Ok(GeneratedField::FieldId), + "field" => Ok(GeneratedField::Field), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = UnionField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.UnionField") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut field_id__ = None; + let mut field__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::FieldId => { + if field_id__.is_some() { + return Err(serde::de::Error::duplicate_field("fieldId")); + } + field_id__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Field => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("field")); + } + field__ = map_.next_value()?; + } + } + } + Ok(UnionField { + field_id: field_id__.unwrap_or_default(), + field: field__, + }) + } + } + deserializer.deserialize_struct("datafusion_common.UnionField", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for UnionMode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::Sparse => "sparse", + Self::Dense => "dense", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for UnionMode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "sparse", + "dense", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = UnionMode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "sparse" => Ok(UnionMode::Sparse), + "dense" => Ok(UnionMode::Dense), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} +impl serde::Serialize for UnionValue { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.value_id != 0 { + len += 1; + } + if self.value.is_some() { + len += 1; + } + if !self.fields.is_empty() { + len += 1; + } + if self.mode != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.UnionValue", len)?; + if self.value_id != 0 { + struct_ser.serialize_field("valueId", &self.value_id)?; + } + if let Some(v) = self.value.as_ref() { + struct_ser.serialize_field("value", v)?; + } + if !self.fields.is_empty() { + struct_ser.serialize_field("fields", &self.fields)?; + } + if self.mode != 0 { + let v = UnionMode::try_from(self.mode) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.mode)))?; + struct_ser.serialize_field("mode", &v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for UnionValue { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "value_id", + "valueId", + "value", + "fields", + "mode", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + ValueId, + Value, + Fields, + Mode, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "valueId" | "value_id" => Ok(GeneratedField::ValueId), + "value" => Ok(GeneratedField::Value), + "fields" => Ok(GeneratedField::Fields), + "mode" => Ok(GeneratedField::Mode), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = UnionValue; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.UnionValue") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut value_id__ = None; + let mut value__ = None; + let mut fields__ = None; + let mut mode__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::ValueId => { + if value_id__.is_some() { + return Err(serde::de::Error::duplicate_field("valueId")); + } + value_id__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("value")); + } + value__ = map_.next_value()?; + } + GeneratedField::Fields => { + if fields__.is_some() { + return Err(serde::de::Error::duplicate_field("fields")); + } + fields__ = Some(map_.next_value()?); + } + GeneratedField::Mode => { + if mode__.is_some() { + return Err(serde::de::Error::duplicate_field("mode")); + } + mode__ = Some(map_.next_value::()? as i32); + } + } + } + Ok(UnionValue { + value_id: value_id__.unwrap_or_default(), + value: value__, + fields: fields__.unwrap_or_default(), + mode: mode__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.UnionValue", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for UniqueConstraint { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.indices.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion_common.UniqueConstraint", len)?; + if !self.indices.is_empty() { + struct_ser.serialize_field("indices", &self.indices.iter().map(ToString::to_string).collect::>())?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for UniqueConstraint { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "indices", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Indices, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "indices" => Ok(GeneratedField::Indices), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = UniqueConstraint; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion_common.UniqueConstraint") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut indices__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Indices => { + if indices__.is_some() { + return Err(serde::de::Error::duplicate_field("indices")); + } + indices__ = + Some(map_.next_value::>>()? + .into_iter().map(|x| x.0).collect()) + ; + } + } + } + Ok(UniqueConstraint { + indices: indices__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion_common.UniqueConstraint", FIELDS, GeneratedVisitor) + } +} diff --git a/datafusion/proto-common/src/generated/prost.rs b/datafusion/proto-common/src/generated/prost.rs new file mode 100644 index 000000000000..68e7f74c7f49 --- /dev/null +++ b/datafusion/proto-common/src/generated/prost.rs @@ -0,0 +1,1127 @@ +// This file is @generated by prost-build. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ColumnRelation { + #[prost(string, tag = "1")] + pub relation: ::prost::alloc::string::String, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Column { + #[prost(string, tag = "1")] + pub name: ::prost::alloc::string::String, + #[prost(message, optional, tag = "2")] + pub relation: ::core::option::Option, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DfField { + #[prost(message, optional, tag = "1")] + pub field: ::core::option::Option, + #[prost(message, optional, tag = "2")] + pub qualifier: ::core::option::Option, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DfSchema { + #[prost(message, repeated, tag = "1")] + pub columns: ::prost::alloc::vec::Vec, + #[prost(map = "string, string", tag = "2")] + pub metadata: ::std::collections::HashMap< + ::prost::alloc::string::String, + ::prost::alloc::string::String, + >, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CsvFormat { + #[prost(message, optional, tag = "5")] + pub options: ::core::option::Option, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ParquetFormat { + #[prost(message, optional, tag = "2")] + pub options: ::core::option::Option, +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct AvroFormat {} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct NdJsonFormat { + #[prost(message, optional, tag = "1")] + pub options: ::core::option::Option, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PrimaryKeyConstraint { + #[prost(uint64, repeated, tag = "1")] + pub indices: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct UniqueConstraint { + #[prost(uint64, repeated, tag = "1")] + pub indices: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Constraint { + #[prost(oneof = "constraint::ConstraintMode", tags = "1, 2")] + pub constraint_mode: ::core::option::Option, +} +/// Nested message and enum types in `Constraint`. +pub mod constraint { + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum ConstraintMode { + #[prost(message, tag = "1")] + PrimaryKey(super::PrimaryKeyConstraint), + #[prost(message, tag = "2")] + Unique(super::UniqueConstraint), + } +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Constraints { + #[prost(message, repeated, tag = "1")] + pub constraints: ::prost::alloc::vec::Vec, +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct AvroOptions {} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct ArrowOptions {} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Schema { + #[prost(message, repeated, tag = "1")] + pub columns: ::prost::alloc::vec::Vec, + #[prost(map = "string, string", tag = "2")] + pub metadata: ::std::collections::HashMap< + ::prost::alloc::string::String, + ::prost::alloc::string::String, + >, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Field { + /// name of the field + #[prost(string, tag = "1")] + pub name: ::prost::alloc::string::String, + #[prost(message, optional, boxed, tag = "2")] + pub arrow_type: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(bool, tag = "3")] + pub nullable: bool, + /// for complex data types like structs, unions + #[prost(message, repeated, tag = "4")] + pub children: ::prost::alloc::vec::Vec, + #[prost(map = "string, string", tag = "5")] + pub metadata: ::std::collections::HashMap< + ::prost::alloc::string::String, + ::prost::alloc::string::String, + >, + #[prost(int64, tag = "6")] + pub dict_id: i64, + #[prost(bool, tag = "7")] + pub dict_ordered: bool, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Timestamp { + #[prost(enumeration = "TimeUnit", tag = "1")] + pub time_unit: i32, + #[prost(string, tag = "2")] + pub timezone: ::prost::alloc::string::String, +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct Decimal { + #[prost(uint32, tag = "3")] + pub precision: u32, + #[prost(int32, tag = "4")] + pub scale: i32, +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct Decimal256Type { + #[prost(uint32, tag = "3")] + pub precision: u32, + #[prost(int32, tag = "4")] + pub scale: i32, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct List { + #[prost(message, optional, boxed, tag = "1")] + pub field_type: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FixedSizeList { + #[prost(message, optional, boxed, tag = "1")] + pub field_type: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(int32, tag = "2")] + pub list_size: i32, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Dictionary { + #[prost(message, optional, boxed, tag = "1")] + pub key: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "2")] + pub value: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Struct { + #[prost(message, repeated, tag = "1")] + pub sub_field_types: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Map { + #[prost(message, optional, boxed, tag = "1")] + pub field_type: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(bool, tag = "2")] + pub keys_sorted: bool, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Union { + #[prost(message, repeated, tag = "1")] + pub union_types: ::prost::alloc::vec::Vec, + #[prost(enumeration = "UnionMode", tag = "2")] + pub union_mode: i32, + #[prost(int32, repeated, tag = "3")] + pub type_ids: ::prost::alloc::vec::Vec, +} +/// Used for List/FixedSizeList/LargeList/Struct/Map +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ScalarNestedValue { + #[prost(bytes = "vec", tag = "1")] + pub ipc_message: ::prost::alloc::vec::Vec, + #[prost(bytes = "vec", tag = "2")] + pub arrow_data: ::prost::alloc::vec::Vec, + #[prost(message, optional, tag = "3")] + pub schema: ::core::option::Option, + #[prost(message, repeated, tag = "4")] + pub dictionaries: ::prost::alloc::vec::Vec, +} +/// Nested message and enum types in `ScalarNestedValue`. +pub mod scalar_nested_value { + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct Dictionary { + #[prost(bytes = "vec", tag = "1")] + pub ipc_message: ::prost::alloc::vec::Vec, + #[prost(bytes = "vec", tag = "2")] + pub arrow_data: ::prost::alloc::vec::Vec, + } +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct ScalarTime32Value { + #[prost(oneof = "scalar_time32_value::Value", tags = "1, 2")] + pub value: ::core::option::Option, +} +/// Nested message and enum types in `ScalarTime32Value`. +pub mod scalar_time32_value { + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + pub enum Value { + #[prost(int32, tag = "1")] + Time32SecondValue(i32), + #[prost(int32, tag = "2")] + Time32MillisecondValue(i32), + } +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct ScalarTime64Value { + #[prost(oneof = "scalar_time64_value::Value", tags = "1, 2")] + pub value: ::core::option::Option, +} +/// Nested message and enum types in `ScalarTime64Value`. +pub mod scalar_time64_value { + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + pub enum Value { + #[prost(int64, tag = "1")] + Time64MicrosecondValue(i64), + #[prost(int64, tag = "2")] + Time64NanosecondValue(i64), + } +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ScalarTimestampValue { + #[prost(string, tag = "5")] + pub timezone: ::prost::alloc::string::String, + #[prost(oneof = "scalar_timestamp_value::Value", tags = "1, 2, 3, 4")] + pub value: ::core::option::Option, +} +/// Nested message and enum types in `ScalarTimestampValue`. +pub mod scalar_timestamp_value { + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + pub enum Value { + #[prost(int64, tag = "1")] + TimeMicrosecondValue(i64), + #[prost(int64, tag = "2")] + TimeNanosecondValue(i64), + #[prost(int64, tag = "3")] + TimeSecondValue(i64), + #[prost(int64, tag = "4")] + TimeMillisecondValue(i64), + } +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ScalarDictionaryValue { + #[prost(message, optional, tag = "1")] + pub index_type: ::core::option::Option, + #[prost(message, optional, boxed, tag = "2")] + pub value: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct IntervalDayTimeValue { + #[prost(int32, tag = "1")] + pub days: i32, + #[prost(int32, tag = "2")] + pub milliseconds: i32, +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct IntervalMonthDayNanoValue { + #[prost(int32, tag = "1")] + pub months: i32, + #[prost(int32, tag = "2")] + pub days: i32, + #[prost(int64, tag = "3")] + pub nanos: i64, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct UnionField { + #[prost(int32, tag = "1")] + pub field_id: i32, + #[prost(message, optional, tag = "2")] + pub field: ::core::option::Option, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct UnionValue { + /// Note that a null union value must have one or more fields, so we + /// encode a null UnionValue as one with value_id == 128 + #[prost(int32, tag = "1")] + pub value_id: i32, + #[prost(message, optional, boxed, tag = "2")] + pub value: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "3")] + pub fields: ::prost::alloc::vec::Vec, + #[prost(enumeration = "UnionMode", tag = "4")] + pub mode: i32, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ScalarFixedSizeBinary { + #[prost(bytes = "vec", tag = "1")] + pub values: ::prost::alloc::vec::Vec, + #[prost(int32, tag = "2")] + pub length: i32, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ScalarValue { + #[prost( + oneof = "scalar_value::Value", + tags = "33, 1, 2, 3, 23, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 41, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 22, 30, 25, 31, 34, 42" + )] + pub value: ::core::option::Option, +} +/// Nested message and enum types in `ScalarValue`. +pub mod scalar_value { + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum Value { + /// was PrimitiveScalarType null_value = 19; + /// Null value of any type + #[prost(message, tag = "33")] + NullValue(super::ArrowType), + #[prost(bool, tag = "1")] + BoolValue(bool), + #[prost(string, tag = "2")] + Utf8Value(::prost::alloc::string::String), + #[prost(string, tag = "3")] + LargeUtf8Value(::prost::alloc::string::String), + #[prost(string, tag = "23")] + Utf8ViewValue(::prost::alloc::string::String), + #[prost(int32, tag = "4")] + Int8Value(i32), + #[prost(int32, tag = "5")] + Int16Value(i32), + #[prost(int32, tag = "6")] + Int32Value(i32), + #[prost(int64, tag = "7")] + Int64Value(i64), + #[prost(uint32, tag = "8")] + Uint8Value(u32), + #[prost(uint32, tag = "9")] + Uint16Value(u32), + #[prost(uint32, tag = "10")] + Uint32Value(u32), + #[prost(uint64, tag = "11")] + Uint64Value(u64), + #[prost(float, tag = "12")] + Float32Value(f32), + #[prost(double, tag = "13")] + Float64Value(f64), + /// Literal Date32 value always has a unit of day + #[prost(int32, tag = "14")] + Date32Value(i32), + #[prost(message, tag = "15")] + Time32Value(super::ScalarTime32Value), + #[prost(message, tag = "16")] + LargeListValue(super::ScalarNestedValue), + #[prost(message, tag = "17")] + ListValue(super::ScalarNestedValue), + #[prost(message, tag = "18")] + FixedSizeListValue(super::ScalarNestedValue), + #[prost(message, tag = "32")] + StructValue(super::ScalarNestedValue), + #[prost(message, tag = "41")] + MapValue(super::ScalarNestedValue), + #[prost(message, tag = "20")] + Decimal128Value(super::Decimal128), + #[prost(message, tag = "39")] + Decimal256Value(super::Decimal256), + #[prost(int64, tag = "21")] + Date64Value(i64), + #[prost(int32, tag = "24")] + IntervalYearmonthValue(i32), + #[prost(int64, tag = "35")] + DurationSecondValue(i64), + #[prost(int64, tag = "36")] + DurationMillisecondValue(i64), + #[prost(int64, tag = "37")] + DurationMicrosecondValue(i64), + #[prost(int64, tag = "38")] + DurationNanosecondValue(i64), + #[prost(message, tag = "26")] + TimestampValue(super::ScalarTimestampValue), + #[prost(message, tag = "27")] + DictionaryValue(::prost::alloc::boxed::Box), + #[prost(bytes, tag = "28")] + BinaryValue(::prost::alloc::vec::Vec), + #[prost(bytes, tag = "29")] + LargeBinaryValue(::prost::alloc::vec::Vec), + #[prost(bytes, tag = "22")] + BinaryViewValue(::prost::alloc::vec::Vec), + #[prost(message, tag = "30")] + Time64Value(super::ScalarTime64Value), + #[prost(message, tag = "25")] + IntervalDaytimeValue(super::IntervalDayTimeValue), + #[prost(message, tag = "31")] + IntervalMonthDayNano(super::IntervalMonthDayNanoValue), + #[prost(message, tag = "34")] + FixedSizeBinaryValue(super::ScalarFixedSizeBinary), + #[prost(message, tag = "42")] + UnionValue(::prost::alloc::boxed::Box), + } +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Decimal128 { + #[prost(bytes = "vec", tag = "1")] + pub value: ::prost::alloc::vec::Vec, + #[prost(int64, tag = "2")] + pub p: i64, + #[prost(int64, tag = "3")] + pub s: i64, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Decimal256 { + #[prost(bytes = "vec", tag = "1")] + pub value: ::prost::alloc::vec::Vec, + #[prost(int64, tag = "2")] + pub p: i64, + #[prost(int64, tag = "3")] + pub s: i64, +} +/// Serialized data type +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ArrowType { + #[prost( + oneof = "arrow_type::ArrowTypeEnum", + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 36, 25, 26, 27, 28, 29, 30, 33" + )] + pub arrow_type_enum: ::core::option::Option, +} +/// Nested message and enum types in `ArrowType`. +pub mod arrow_type { + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum ArrowTypeEnum { + /// arrow::Type::NA + #[prost(message, tag = "1")] + None(super::EmptyMessage), + /// arrow::Type::BOOL + #[prost(message, tag = "2")] + Bool(super::EmptyMessage), + /// arrow::Type::UINT8 + #[prost(message, tag = "3")] + Uint8(super::EmptyMessage), + /// arrow::Type::INT8 + #[prost(message, tag = "4")] + Int8(super::EmptyMessage), + /// represents arrow::Type fields in src/arrow/type.h + #[prost(message, tag = "5")] + Uint16(super::EmptyMessage), + #[prost(message, tag = "6")] + Int16(super::EmptyMessage), + #[prost(message, tag = "7")] + Uint32(super::EmptyMessage), + #[prost(message, tag = "8")] + Int32(super::EmptyMessage), + #[prost(message, tag = "9")] + Uint64(super::EmptyMessage), + #[prost(message, tag = "10")] + Int64(super::EmptyMessage), + #[prost(message, tag = "11")] + Float16(super::EmptyMessage), + #[prost(message, tag = "12")] + Float32(super::EmptyMessage), + #[prost(message, tag = "13")] + Float64(super::EmptyMessage), + #[prost(message, tag = "14")] + Utf8(super::EmptyMessage), + #[prost(message, tag = "35")] + Utf8View(super::EmptyMessage), + #[prost(message, tag = "32")] + LargeUtf8(super::EmptyMessage), + #[prost(message, tag = "15")] + Binary(super::EmptyMessage), + #[prost(message, tag = "34")] + BinaryView(super::EmptyMessage), + #[prost(int32, tag = "16")] + FixedSizeBinary(i32), + #[prost(message, tag = "31")] + LargeBinary(super::EmptyMessage), + #[prost(message, tag = "17")] + Date32(super::EmptyMessage), + #[prost(message, tag = "18")] + Date64(super::EmptyMessage), + #[prost(enumeration = "super::TimeUnit", tag = "19")] + Duration(i32), + #[prost(message, tag = "20")] + Timestamp(super::Timestamp), + #[prost(enumeration = "super::TimeUnit", tag = "21")] + Time32(i32), + #[prost(enumeration = "super::TimeUnit", tag = "22")] + Time64(i32), + #[prost(enumeration = "super::IntervalUnit", tag = "23")] + Interval(i32), + #[prost(message, tag = "24")] + Decimal(super::Decimal), + #[prost(message, tag = "36")] + Decimal256(super::Decimal256Type), + #[prost(message, tag = "25")] + List(::prost::alloc::boxed::Box), + #[prost(message, tag = "26")] + LargeList(::prost::alloc::boxed::Box), + #[prost(message, tag = "27")] + FixedSizeList(::prost::alloc::boxed::Box), + #[prost(message, tag = "28")] + Struct(super::Struct), + #[prost(message, tag = "29")] + Union(super::Union), + #[prost(message, tag = "30")] + Dictionary(::prost::alloc::boxed::Box), + #[prost(message, tag = "33")] + Map(::prost::alloc::boxed::Box), + } +} +/// Useful for representing an empty enum variant in rust +/// E.G. enum example{One, Two(i32)} +/// maps to +/// message example{ +/// oneof{ +/// EmptyMessage One = 1; +/// i32 Two = 2; +/// } +/// } +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct EmptyMessage {} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct JsonWriterOptions { + #[prost(enumeration = "CompressionTypeVariant", tag = "1")] + pub compression: i32, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CsvWriterOptions { + /// Compression type + #[prost(enumeration = "CompressionTypeVariant", tag = "1")] + pub compression: i32, + /// Optional column delimiter. Defaults to `b','` + #[prost(string, tag = "2")] + pub delimiter: ::prost::alloc::string::String, + /// Whether to write column names as file headers. Defaults to `true` + #[prost(bool, tag = "3")] + pub has_header: bool, + /// Optional date format for date arrays + #[prost(string, tag = "4")] + pub date_format: ::prost::alloc::string::String, + /// Optional datetime format for datetime arrays + #[prost(string, tag = "5")] + pub datetime_format: ::prost::alloc::string::String, + /// Optional timestamp format for timestamp arrays + #[prost(string, tag = "6")] + pub timestamp_format: ::prost::alloc::string::String, + /// Optional time format for time arrays + #[prost(string, tag = "7")] + pub time_format: ::prost::alloc::string::String, + /// Optional value to represent null + #[prost(string, tag = "8")] + pub null_value: ::prost::alloc::string::String, + /// Optional quote. Defaults to `b'"'` + #[prost(string, tag = "9")] + pub quote: ::prost::alloc::string::String, + /// Optional escape. Defaults to `'\\'` + #[prost(string, tag = "10")] + pub escape: ::prost::alloc::string::String, + /// Optional flag whether to double quotes, instead of escaping. Defaults to `true` + #[prost(bool, tag = "11")] + pub double_quote: bool, +} +/// Options controlling CSV format +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CsvOptions { + /// Indicates if the CSV has a header row + #[prost(bytes = "vec", tag = "1")] + pub has_header: ::prost::alloc::vec::Vec, + /// Delimiter character as a byte + #[prost(bytes = "vec", tag = "2")] + pub delimiter: ::prost::alloc::vec::Vec, + /// Quote character as a byte + #[prost(bytes = "vec", tag = "3")] + pub quote: ::prost::alloc::vec::Vec, + /// Optional escape character as a byte + #[prost(bytes = "vec", tag = "4")] + pub escape: ::prost::alloc::vec::Vec, + /// Compression type + #[prost(enumeration = "CompressionTypeVariant", tag = "5")] + pub compression: i32, + /// Max records for schema inference + #[prost(uint64, tag = "6")] + pub schema_infer_max_rec: u64, + /// Optional date format + #[prost(string, tag = "7")] + pub date_format: ::prost::alloc::string::String, + /// Optional datetime format + #[prost(string, tag = "8")] + pub datetime_format: ::prost::alloc::string::String, + /// Optional timestamp format + #[prost(string, tag = "9")] + pub timestamp_format: ::prost::alloc::string::String, + /// Optional timestamp with timezone format + #[prost(string, tag = "10")] + pub timestamp_tz_format: ::prost::alloc::string::String, + /// Optional time format + #[prost(string, tag = "11")] + pub time_format: ::prost::alloc::string::String, + /// Optional representation of null value + #[prost(string, tag = "12")] + pub null_value: ::prost::alloc::string::String, + /// Optional comment character as a byte + #[prost(bytes = "vec", tag = "13")] + pub comment: ::prost::alloc::vec::Vec, + /// Indicates if quotes are doubled + #[prost(bytes = "vec", tag = "14")] + pub double_quote: ::prost::alloc::vec::Vec, + /// Indicates if newlines are supported in values + #[prost(bytes = "vec", tag = "15")] + pub newlines_in_values: ::prost::alloc::vec::Vec, + /// Optional terminator character as a byte + #[prost(bytes = "vec", tag = "16")] + pub terminator: ::prost::alloc::vec::Vec, +} +/// Options controlling CSV format +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct JsonOptions { + /// Compression type + #[prost(enumeration = "CompressionTypeVariant", tag = "1")] + pub compression: i32, + /// Max records for schema inference + #[prost(uint64, tag = "2")] + pub schema_infer_max_rec: u64, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct TableParquetOptions { + #[prost(message, optional, tag = "1")] + pub global: ::core::option::Option, + #[prost(message, repeated, tag = "2")] + pub column_specific_options: ::prost::alloc::vec::Vec, + #[prost(map = "string, string", tag = "3")] + pub key_value_metadata: ::std::collections::HashMap< + ::prost::alloc::string::String, + ::prost::alloc::string::String, + >, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ParquetColumnSpecificOptions { + #[prost(string, tag = "1")] + pub column_name: ::prost::alloc::string::String, + #[prost(message, optional, tag = "2")] + pub options: ::core::option::Option, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ParquetColumnOptions { + #[prost(oneof = "parquet_column_options::BloomFilterEnabledOpt", tags = "1")] + pub bloom_filter_enabled_opt: ::core::option::Option< + parquet_column_options::BloomFilterEnabledOpt, + >, + #[prost(oneof = "parquet_column_options::EncodingOpt", tags = "2")] + pub encoding_opt: ::core::option::Option, + #[prost(oneof = "parquet_column_options::DictionaryEnabledOpt", tags = "3")] + pub dictionary_enabled_opt: ::core::option::Option< + parquet_column_options::DictionaryEnabledOpt, + >, + #[prost(oneof = "parquet_column_options::CompressionOpt", tags = "4")] + pub compression_opt: ::core::option::Option, + #[prost(oneof = "parquet_column_options::StatisticsEnabledOpt", tags = "5")] + pub statistics_enabled_opt: ::core::option::Option< + parquet_column_options::StatisticsEnabledOpt, + >, + #[prost(oneof = "parquet_column_options::BloomFilterFppOpt", tags = "6")] + pub bloom_filter_fpp_opt: ::core::option::Option< + parquet_column_options::BloomFilterFppOpt, + >, + #[prost(oneof = "parquet_column_options::BloomFilterNdvOpt", tags = "7")] + pub bloom_filter_ndv_opt: ::core::option::Option< + parquet_column_options::BloomFilterNdvOpt, + >, + #[prost(oneof = "parquet_column_options::MaxStatisticsSizeOpt", tags = "8")] + pub max_statistics_size_opt: ::core::option::Option< + parquet_column_options::MaxStatisticsSizeOpt, + >, +} +/// Nested message and enum types in `ParquetColumnOptions`. +pub mod parquet_column_options { + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + pub enum BloomFilterEnabledOpt { + #[prost(bool, tag = "1")] + BloomFilterEnabled(bool), + } + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum EncodingOpt { + #[prost(string, tag = "2")] + Encoding(::prost::alloc::string::String), + } + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + pub enum DictionaryEnabledOpt { + #[prost(bool, tag = "3")] + DictionaryEnabled(bool), + } + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum CompressionOpt { + #[prost(string, tag = "4")] + Compression(::prost::alloc::string::String), + } + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum StatisticsEnabledOpt { + #[prost(string, tag = "5")] + StatisticsEnabled(::prost::alloc::string::String), + } + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + pub enum BloomFilterFppOpt { + #[prost(double, tag = "6")] + BloomFilterFpp(f64), + } + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + pub enum BloomFilterNdvOpt { + #[prost(uint64, tag = "7")] + BloomFilterNdv(u64), + } + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + pub enum MaxStatisticsSizeOpt { + #[prost(uint32, tag = "8")] + MaxStatisticsSize(u32), + } +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ParquetOptions { + /// Regular fields + /// + /// default = true + #[prost(bool, tag = "1")] + pub enable_page_index: bool, + /// default = true + #[prost(bool, tag = "2")] + pub pruning: bool, + /// default = true + #[prost(bool, tag = "3")] + pub skip_metadata: bool, + /// default = false + #[prost(bool, tag = "5")] + pub pushdown_filters: bool, + /// default = false + #[prost(bool, tag = "6")] + pub reorder_filters: bool, + /// default = 1024 * 1024 + #[prost(uint64, tag = "7")] + pub data_pagesize_limit: u64, + /// default = 1024 + #[prost(uint64, tag = "8")] + pub write_batch_size: u64, + /// default = "1.0" + #[prost(string, tag = "9")] + pub writer_version: ::prost::alloc::string::String, + /// bool bloom_filter_enabled = 20; // default = false + /// + /// default = true + #[prost(bool, tag = "23")] + pub allow_single_file_parallelism: bool, + /// default = 1 + #[prost(uint64, tag = "24")] + pub maximum_parallel_row_group_writers: u64, + /// default = 2 + #[prost(uint64, tag = "25")] + pub maximum_buffered_record_batches_per_stream: u64, + /// default = true + #[prost(bool, tag = "26")] + pub bloom_filter_on_read: bool, + /// default = false + #[prost(bool, tag = "27")] + pub bloom_filter_on_write: bool, + /// default = false + #[prost(bool, tag = "28")] + pub schema_force_view_types: bool, + /// default = false + #[prost(bool, tag = "29")] + pub binary_as_string: bool, + #[prost(uint64, tag = "12")] + pub dictionary_page_size_limit: u64, + #[prost(uint64, tag = "18")] + pub data_page_row_count_limit: u64, + #[prost(uint64, tag = "15")] + pub max_row_group_size: u64, + #[prost(string, tag = "16")] + pub created_by: ::prost::alloc::string::String, + #[prost(oneof = "parquet_options::MetadataSizeHintOpt", tags = "4")] + pub metadata_size_hint_opt: ::core::option::Option< + parquet_options::MetadataSizeHintOpt, + >, + #[prost(oneof = "parquet_options::CompressionOpt", tags = "10")] + pub compression_opt: ::core::option::Option, + #[prost(oneof = "parquet_options::DictionaryEnabledOpt", tags = "11")] + pub dictionary_enabled_opt: ::core::option::Option< + parquet_options::DictionaryEnabledOpt, + >, + #[prost(oneof = "parquet_options::StatisticsEnabledOpt", tags = "13")] + pub statistics_enabled_opt: ::core::option::Option< + parquet_options::StatisticsEnabledOpt, + >, + #[prost(oneof = "parquet_options::MaxStatisticsSizeOpt", tags = "14")] + pub max_statistics_size_opt: ::core::option::Option< + parquet_options::MaxStatisticsSizeOpt, + >, + #[prost(oneof = "parquet_options::ColumnIndexTruncateLengthOpt", tags = "17")] + pub column_index_truncate_length_opt: ::core::option::Option< + parquet_options::ColumnIndexTruncateLengthOpt, + >, + #[prost(oneof = "parquet_options::EncodingOpt", tags = "19")] + pub encoding_opt: ::core::option::Option, + #[prost(oneof = "parquet_options::BloomFilterFppOpt", tags = "21")] + pub bloom_filter_fpp_opt: ::core::option::Option, + #[prost(oneof = "parquet_options::BloomFilterNdvOpt", tags = "22")] + pub bloom_filter_ndv_opt: ::core::option::Option, +} +/// Nested message and enum types in `ParquetOptions`. +pub mod parquet_options { + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + pub enum MetadataSizeHintOpt { + #[prost(uint64, tag = "4")] + MetadataSizeHint(u64), + } + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum CompressionOpt { + #[prost(string, tag = "10")] + Compression(::prost::alloc::string::String), + } + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + pub enum DictionaryEnabledOpt { + #[prost(bool, tag = "11")] + DictionaryEnabled(bool), + } + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum StatisticsEnabledOpt { + #[prost(string, tag = "13")] + StatisticsEnabled(::prost::alloc::string::String), + } + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + pub enum MaxStatisticsSizeOpt { + #[prost(uint64, tag = "14")] + MaxStatisticsSize(u64), + } + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + pub enum ColumnIndexTruncateLengthOpt { + #[prost(uint64, tag = "17")] + ColumnIndexTruncateLength(u64), + } + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum EncodingOpt { + #[prost(string, tag = "19")] + Encoding(::prost::alloc::string::String), + } + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + pub enum BloomFilterFppOpt { + #[prost(double, tag = "21")] + BloomFilterFpp(f64), + } + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + pub enum BloomFilterNdvOpt { + #[prost(uint64, tag = "22")] + BloomFilterNdv(u64), + } +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Precision { + #[prost(enumeration = "PrecisionInfo", tag = "1")] + pub precision_info: i32, + #[prost(message, optional, tag = "2")] + pub val: ::core::option::Option, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Statistics { + #[prost(message, optional, tag = "1")] + pub num_rows: ::core::option::Option, + #[prost(message, optional, tag = "2")] + pub total_byte_size: ::core::option::Option, + #[prost(message, repeated, tag = "3")] + pub column_stats: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ColumnStats { + #[prost(message, optional, tag = "1")] + pub min_value: ::core::option::Option, + #[prost(message, optional, tag = "2")] + pub max_value: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub null_count: ::core::option::Option, + #[prost(message, optional, tag = "4")] + pub distinct_count: ::core::option::Option, +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum JoinType { + Inner = 0, + Left = 1, + Right = 2, + Full = 3, + Leftsemi = 4, + Leftanti = 5, + Rightsemi = 6, + Rightanti = 7, + Leftmark = 8, +} +impl JoinType { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::Inner => "INNER", + Self::Left => "LEFT", + Self::Right => "RIGHT", + Self::Full => "FULL", + Self::Leftsemi => "LEFTSEMI", + Self::Leftanti => "LEFTANTI", + Self::Rightsemi => "RIGHTSEMI", + Self::Rightanti => "RIGHTANTI", + Self::Leftmark => "LEFTMARK", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "INNER" => Some(Self::Inner), + "LEFT" => Some(Self::Left), + "RIGHT" => Some(Self::Right), + "FULL" => Some(Self::Full), + "LEFTSEMI" => Some(Self::Leftsemi), + "LEFTANTI" => Some(Self::Leftanti), + "RIGHTSEMI" => Some(Self::Rightsemi), + "RIGHTANTI" => Some(Self::Rightanti), + "LEFTMARK" => Some(Self::Leftmark), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum JoinConstraint { + On = 0, + Using = 1, +} +impl JoinConstraint { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::On => "ON", + Self::Using => "USING", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "ON" => Some(Self::On), + "USING" => Some(Self::Using), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum TimeUnit { + Second = 0, + Millisecond = 1, + Microsecond = 2, + Nanosecond = 3, +} +impl TimeUnit { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::Second => "Second", + Self::Millisecond => "Millisecond", + Self::Microsecond => "Microsecond", + Self::Nanosecond => "Nanosecond", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "Second" => Some(Self::Second), + "Millisecond" => Some(Self::Millisecond), + "Microsecond" => Some(Self::Microsecond), + "Nanosecond" => Some(Self::Nanosecond), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum IntervalUnit { + YearMonth = 0, + DayTime = 1, + MonthDayNano = 2, +} +impl IntervalUnit { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::YearMonth => "YearMonth", + Self::DayTime => "DayTime", + Self::MonthDayNano => "MonthDayNano", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "YearMonth" => Some(Self::YearMonth), + "DayTime" => Some(Self::DayTime), + "MonthDayNano" => Some(Self::MonthDayNano), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum UnionMode { + Sparse = 0, + Dense = 1, +} +impl UnionMode { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::Sparse => "sparse", + Self::Dense => "dense", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "sparse" => Some(Self::Sparse), + "dense" => Some(Self::Dense), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum CompressionTypeVariant { + Gzip = 0, + Bzip2 = 1, + Xz = 2, + Zstd = 3, + Uncompressed = 4, +} +impl CompressionTypeVariant { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::Gzip => "GZIP", + Self::Bzip2 => "BZIP2", + Self::Xz => "XZ", + Self::Zstd => "ZSTD", + Self::Uncompressed => "UNCOMPRESSED", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "GZIP" => Some(Self::Gzip), + "BZIP2" => Some(Self::Bzip2), + "XZ" => Some(Self::Xz), + "ZSTD" => Some(Self::Zstd), + "UNCOMPRESSED" => Some(Self::Uncompressed), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum JoinSide { + LeftSide = 0, + RightSide = 1, + None = 2, +} +impl JoinSide { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::LeftSide => "LEFT_SIDE", + Self::RightSide => "RIGHT_SIDE", + Self::None => "NONE", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "LEFT_SIDE" => Some(Self::LeftSide), + "RIGHT_SIDE" => Some(Self::RightSide), + "NONE" => Some(Self::None), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum PrecisionInfo { + Exact = 0, + Inexact = 1, + Absent = 2, +} +impl PrecisionInfo { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::Exact => "EXACT", + Self::Inexact => "INEXACT", + Self::Absent => "ABSENT", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "EXACT" => Some(Self::Exact), + "INEXACT" => Some(Self::Inexact), + "ABSENT" => Some(Self::Absent), + _ => None, + } + } +} diff --git a/datafusion/proto-common/src/lib.rs b/datafusion/proto-common/src/lib.rs new file mode 100644 index 000000000000..91e393915442 --- /dev/null +++ b/datafusion/proto-common/src/lib.rs @@ -0,0 +1,95 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] + +//! Serialize / Deserialize DataFusion Primitive Types to bytes +//! +//! This crate provides support for serializing and deserializing the +//! following structures to and from bytes: +//! +//! 1. [`ScalarValue`]'s +//! +//! [`ScalarValue`]: datafusion_common::ScalarValue +//! +//! Internally, this crate is implemented by converting the common types to [protocol +//! buffers] using [prost]. +//! +//! [protocol buffers]: https://developers.google.com/protocol-buffers +//! [prost]: https://docs.rs/prost/latest/prost/ +//! +//! # Version Compatibility +//! +//! The serialized form are not guaranteed to be compatible across +//! DataFusion versions. A plan serialized with one version of DataFusion +//! may not be able to deserialized with a different version. +//! +//! # See Also +//! +//! The binary format created by this crate supports the full range of DataFusion +//! plans, but is DataFusion specific. See [datafusion-substrait] for a crate +//! which can encode many DataFusion plans using the [substrait.io] standard. +//! +//! [datafusion-substrait]: https://docs.rs/datafusion-substrait/latest/datafusion_substrait +//! [substrait.io]: https://substrait.io +//! +//! # Example: Serializing [`ScalarValue`]s +//! ``` +//! # use datafusion_common::{ScalarValue, Result}; +//! # use prost::{bytes::{Bytes, BytesMut}}; +//! # use datafusion_common::plan_datafusion_err; +//! # use datafusion_proto_common::protobuf_common; +//! # use prost::Message; +//! # fn main() -> Result<()>{ +//! // Create a new ScalarValue +//! let val = ScalarValue::UInt64(Some(3)); +//! let mut buffer = BytesMut::new(); +//! let protobuf: protobuf_common::ScalarValue = match val { +//! ScalarValue::UInt64(Some(val)) => { +//! protobuf_common::ScalarValue{value: Some(protobuf_common::scalar_value::Value::Uint64Value(val))} +//! } +//! _ => unreachable!(), +//! }; +//! +//! protobuf.encode(&mut buffer) +//! .map_err(|e| plan_datafusion_err!("Error encoding protobuf as bytes: {e}"))?; +//! // Convert it to bytes (for sending over the network, etc.) +//! let bytes: Bytes = buffer.into(); +//! +//! let protobuf = protobuf_common::ScalarValue::decode(bytes).map_err(|e| plan_datafusion_err!("Error decoding ScalarValue as protobuf: {e}"))?; +//! // Decode bytes from somewhere (over network, etc.) back to ScalarValue +//! let decoded_val: ScalarValue = match protobuf.value { +//! Some(protobuf_common::scalar_value::Value::Uint64Value(val)) => ScalarValue::UInt64(Some(val)), +//! _ => unreachable!(), +//! }; +//! assert_eq!(val, decoded_val); +//! # Ok(()) +//! # } +//! ``` + +pub mod common; +pub mod from_proto; +pub mod generated; +pub mod to_proto; + +pub use from_proto::Error as FromProtoError; +pub use generated::datafusion_proto_common as protobuf_common; +pub use generated::datafusion_proto_common::*; +pub use to_proto::Error as ToProtoError; + +#[cfg(doctest)] +doc_comment::doctest!("../README.md", readme_example_test); diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs new file mode 100644 index 000000000000..02a642a4af93 --- /dev/null +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -0,0 +1,1056 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashMap; +use std::sync::Arc; + +use crate::protobuf_common as protobuf; +use crate::protobuf_common::{ + arrow_type::ArrowTypeEnum, scalar_value::Value, EmptyMessage, +}; +use arrow::array::{ArrayRef, RecordBatch}; +use arrow::csv::WriterBuilder; +use arrow::datatypes::{ + DataType, Field, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema, + SchemaRef, TimeUnit, UnionMode, +}; +use arrow::ipc::writer::{DictionaryTracker, IpcDataGenerator}; +use datafusion_common::{ + config::{ + CsvOptions, JsonOptions, ParquetColumnOptions, ParquetOptions, + TableParquetOptions, + }, + file_options::{csv_writer::CsvWriterOptions, json_writer::JsonWriterOptions}, + parsers::CompressionTypeVariant, + plan_datafusion_err, + stats::Precision, + Column, ColumnStatistics, Constraint, Constraints, DFSchema, DFSchemaRef, + DataFusionError, JoinSide, ScalarValue, Statistics, +}; + +#[derive(Debug)] +pub enum Error { + General(String), + + InvalidScalarValue(ScalarValue), + + InvalidScalarType(DataType), + + InvalidTimeUnit(TimeUnit), + + NotImplemented(String), +} + +impl std::error::Error for Error {} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Self::General(desc) => write!(f, "General error: {desc}"), + Self::InvalidScalarValue(value) => { + write!(f, "{value:?} is invalid as a DataFusion scalar value") + } + Self::InvalidScalarType(data_type) => { + write!(f, "{data_type:?} is invalid as a DataFusion scalar type") + } + Self::InvalidTimeUnit(time_unit) => { + write!( + f, + "Only TimeUnit::Microsecond and TimeUnit::Nanosecond are valid time units, found: {time_unit:?}" + ) + } + Self::NotImplemented(s) => { + write!(f, "Not implemented: {s}") + } + } + } +} + +impl From for DataFusionError { + fn from(e: Error) -> Self { + plan_datafusion_err!("{}", e) + } +} + +impl TryFrom<&Field> for protobuf::Field { + type Error = Error; + + fn try_from(field: &Field) -> Result { + let arrow_type = field.data_type().try_into()?; + Ok(Self { + name: field.name().to_owned(), + arrow_type: Some(Box::new(arrow_type)), + nullable: field.is_nullable(), + children: Vec::new(), + metadata: field.metadata().clone(), + dict_id: field.dict_id().unwrap_or(0), + dict_ordered: field.dict_is_ordered().unwrap_or(false), + }) + } +} + +impl TryFrom<&DataType> for protobuf::ArrowType { + type Error = Error; + + fn try_from(val: &DataType) -> Result { + let arrow_type_enum: ArrowTypeEnum = val.try_into()?; + Ok(Self { + arrow_type_enum: Some(arrow_type_enum), + }) + } +} + +impl TryFrom<&DataType> for protobuf::arrow_type::ArrowTypeEnum { + type Error = Error; + + fn try_from(val: &DataType) -> Result { + let res = match val { + DataType::Null => Self::None(EmptyMessage {}), + DataType::Boolean => Self::Bool(EmptyMessage {}), + DataType::Int8 => Self::Int8(EmptyMessage {}), + DataType::Int16 => Self::Int16(EmptyMessage {}), + DataType::Int32 => Self::Int32(EmptyMessage {}), + DataType::Int64 => Self::Int64(EmptyMessage {}), + DataType::UInt8 => Self::Uint8(EmptyMessage {}), + DataType::UInt16 => Self::Uint16(EmptyMessage {}), + DataType::UInt32 => Self::Uint32(EmptyMessage {}), + DataType::UInt64 => Self::Uint64(EmptyMessage {}), + DataType::Float16 => Self::Float16(EmptyMessage {}), + DataType::Float32 => Self::Float32(EmptyMessage {}), + DataType::Float64 => Self::Float64(EmptyMessage {}), + DataType::Timestamp(time_unit, timezone) => { + Self::Timestamp(protobuf::Timestamp { + time_unit: protobuf::TimeUnit::from(time_unit) as i32, + timezone: timezone.as_deref().unwrap_or("").to_string(), + }) + } + DataType::Date32 => Self::Date32(EmptyMessage {}), + DataType::Date64 => Self::Date64(EmptyMessage {}), + DataType::Time32(time_unit) => { + Self::Time32(protobuf::TimeUnit::from(time_unit) as i32) + } + DataType::Time64(time_unit) => { + Self::Time64(protobuf::TimeUnit::from(time_unit) as i32) + } + DataType::Duration(time_unit) => { + Self::Duration(protobuf::TimeUnit::from(time_unit) as i32) + } + DataType::Interval(interval_unit) => { + Self::Interval(protobuf::IntervalUnit::from(interval_unit) as i32) + } + DataType::Binary => Self::Binary(EmptyMessage {}), + DataType::BinaryView => Self::BinaryView(EmptyMessage {}), + DataType::FixedSizeBinary(size) => Self::FixedSizeBinary(*size), + DataType::LargeBinary => Self::LargeBinary(EmptyMessage {}), + DataType::Utf8 => Self::Utf8(EmptyMessage {}), + DataType::Utf8View => Self::Utf8View(EmptyMessage {}), + DataType::LargeUtf8 => Self::LargeUtf8(EmptyMessage {}), + DataType::List(item_type) => Self::List(Box::new(protobuf::List { + field_type: Some(Box::new(item_type.as_ref().try_into()?)), + })), + DataType::FixedSizeList(item_type, size) => { + Self::FixedSizeList(Box::new(protobuf::FixedSizeList { + field_type: Some(Box::new(item_type.as_ref().try_into()?)), + list_size: *size, + })) + } + DataType::LargeList(item_type) => Self::LargeList(Box::new(protobuf::List { + field_type: Some(Box::new(item_type.as_ref().try_into()?)), + })), + DataType::Struct(struct_fields) => Self::Struct(protobuf::Struct { + sub_field_types: convert_arc_fields_to_proto_fields(struct_fields)?, + }), + DataType::Union(fields, union_mode) => { + let union_mode = match union_mode { + UnionMode::Sparse => protobuf::UnionMode::Sparse, + UnionMode::Dense => protobuf::UnionMode::Dense, + }; + Self::Union(protobuf::Union { + union_types: convert_arc_fields_to_proto_fields(fields.iter().map(|(_, item)|item))?, + union_mode: union_mode.into(), + type_ids: fields.iter().map(|(x, _)| x as i32).collect(), + }) + } + DataType::Dictionary(key_type, value_type) => { + Self::Dictionary(Box::new(protobuf::Dictionary { + key: Some(Box::new(key_type.as_ref().try_into()?)), + value: Some(Box::new(value_type.as_ref().try_into()?)), + })) + } + DataType::Decimal128(precision, scale) => Self::Decimal(protobuf::Decimal { + precision: *precision as u32, + scale: *scale as i32, + }), + DataType::Decimal256(precision, scale) => Self::Decimal256(protobuf::Decimal256Type { + precision: *precision as u32, + scale: *scale as i32, + }), + DataType::Map(field, sorted) => { + Self::Map(Box::new( + protobuf::Map { + field_type: Some(Box::new(field.as_ref().try_into()?)), + keys_sorted: *sorted, + } + )) + } + DataType::RunEndEncoded(_, _) => { + return Err(Error::General( + "Proto serialization error: The RunEndEncoded data type is not yet supported".to_owned() + )) + } + DataType::ListView(_) | DataType::LargeListView(_) => { + return Err(Error::General(format!("Proto serialization error: {val} not yet supported"))) + } + }; + + Ok(res) + } +} + +impl From for protobuf::Column { + fn from(c: Column) -> Self { + Self { + relation: c.relation.map(|relation| protobuf::ColumnRelation { + relation: relation.to_string(), + }), + name: c.name, + } + } +} + +impl From<&Column> for protobuf::Column { + fn from(c: &Column) -> Self { + c.clone().into() + } +} + +impl TryFrom<&Schema> for protobuf::Schema { + type Error = Error; + + fn try_from(schema: &Schema) -> Result { + Ok(Self { + columns: convert_arc_fields_to_proto_fields(schema.fields())?, + metadata: schema.metadata.clone(), + }) + } +} + +impl TryFrom for protobuf::Schema { + type Error = Error; + + fn try_from(schema: SchemaRef) -> Result { + Ok(Self { + columns: convert_arc_fields_to_proto_fields(schema.fields())?, + metadata: schema.metadata.clone(), + }) + } +} + +impl TryFrom<&DFSchema> for protobuf::DfSchema { + type Error = Error; + + fn try_from(s: &DFSchema) -> Result { + let columns = s + .iter() + .map(|(qualifier, field)| { + Ok(protobuf::DfField { + field: Some(field.as_ref().try_into()?), + qualifier: qualifier.map(|r| protobuf::ColumnRelation { + relation: r.to_string(), + }), + }) + }) + .collect::, Error>>()?; + Ok(Self { + columns, + metadata: s.metadata().clone(), + }) + } +} + +impl TryFrom<&DFSchemaRef> for protobuf::DfSchema { + type Error = Error; + + fn try_from(s: &DFSchemaRef) -> Result { + s.as_ref().try_into() + } +} + +impl TryFrom<&ScalarValue> for protobuf::ScalarValue { + type Error = Error; + + fn try_from(val: &ScalarValue) -> Result { + let data_type = val.data_type(); + match val { + ScalarValue::Boolean(val) => { + create_proto_scalar(val.as_ref(), &data_type, |s| Value::BoolValue(*s)) + } + ScalarValue::Float16(val) => { + create_proto_scalar(val.as_ref(), &data_type, |s| { + Value::Float32Value((*s).into()) + }) + } + ScalarValue::Float32(val) => { + create_proto_scalar(val.as_ref(), &data_type, |s| Value::Float32Value(*s)) + } + ScalarValue::Float64(val) => { + create_proto_scalar(val.as_ref(), &data_type, |s| Value::Float64Value(*s)) + } + ScalarValue::Int8(val) => { + create_proto_scalar(val.as_ref(), &data_type, |s| { + Value::Int8Value(*s as i32) + }) + } + ScalarValue::Int16(val) => { + create_proto_scalar(val.as_ref(), &data_type, |s| { + Value::Int16Value(*s as i32) + }) + } + ScalarValue::Int32(val) => { + create_proto_scalar(val.as_ref(), &data_type, |s| Value::Int32Value(*s)) + } + ScalarValue::Int64(val) => { + create_proto_scalar(val.as_ref(), &data_type, |s| Value::Int64Value(*s)) + } + ScalarValue::UInt8(val) => { + create_proto_scalar(val.as_ref(), &data_type, |s| { + Value::Uint8Value(*s as u32) + }) + } + ScalarValue::UInt16(val) => { + create_proto_scalar(val.as_ref(), &data_type, |s| { + Value::Uint16Value(*s as u32) + }) + } + ScalarValue::UInt32(val) => { + create_proto_scalar(val.as_ref(), &data_type, |s| Value::Uint32Value(*s)) + } + ScalarValue::UInt64(val) => { + create_proto_scalar(val.as_ref(), &data_type, |s| Value::Uint64Value(*s)) + } + ScalarValue::Utf8(val) => { + create_proto_scalar(val.as_ref(), &data_type, |s| { + Value::Utf8Value(s.to_owned()) + }) + } + ScalarValue::LargeUtf8(val) => { + create_proto_scalar(val.as_ref(), &data_type, |s| { + Value::LargeUtf8Value(s.to_owned()) + }) + } + ScalarValue::Utf8View(val) => { + create_proto_scalar(val.as_ref(), &data_type, |s| { + Value::Utf8ViewValue(s.to_owned()) + }) + } + ScalarValue::List(arr) => { + encode_scalar_nested_value(arr.to_owned() as ArrayRef, val) + } + ScalarValue::LargeList(arr) => { + encode_scalar_nested_value(arr.to_owned() as ArrayRef, val) + } + ScalarValue::FixedSizeList(arr) => { + encode_scalar_nested_value(arr.to_owned() as ArrayRef, val) + } + ScalarValue::Struct(arr) => { + encode_scalar_nested_value(arr.to_owned() as ArrayRef, val) + } + ScalarValue::Map(arr) => { + encode_scalar_nested_value(arr.to_owned() as ArrayRef, val) + } + ScalarValue::Date32(val) => { + create_proto_scalar(val.as_ref(), &data_type, |s| Value::Date32Value(*s)) + } + ScalarValue::TimestampMicrosecond(val, tz) => { + create_proto_scalar(val.as_ref(), &data_type, |s| { + Value::TimestampValue(protobuf::ScalarTimestampValue { + timezone: tz.as_deref().unwrap_or("").to_string(), + value: Some( + protobuf::scalar_timestamp_value::Value::TimeMicrosecondValue( + *s, + ), + ), + }) + }) + } + ScalarValue::TimestampNanosecond(val, tz) => { + create_proto_scalar(val.as_ref(), &data_type, |s| { + Value::TimestampValue(protobuf::ScalarTimestampValue { + timezone: tz.as_deref().unwrap_or("").to_string(), + value: Some( + protobuf::scalar_timestamp_value::Value::TimeNanosecondValue( + *s, + ), + ), + }) + }) + } + ScalarValue::Decimal128(val, p, s) => match *val { + Some(v) => { + let array = v.to_be_bytes(); + let vec_val: Vec = array.to_vec(); + Ok(protobuf::ScalarValue { + value: Some(Value::Decimal128Value(protobuf::Decimal128 { + value: vec_val, + p: *p as i64, + s: *s as i64, + })), + }) + } + None => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::NullValue( + (&data_type).try_into()?, + )), + }), + }, + ScalarValue::Decimal256(val, p, s) => match *val { + Some(v) => { + let array = v.to_be_bytes(); + let vec_val: Vec = array.to_vec(); + Ok(protobuf::ScalarValue { + value: Some(Value::Decimal256Value(protobuf::Decimal256 { + value: vec_val, + p: *p as i64, + s: *s as i64, + })), + }) + } + None => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::NullValue( + (&data_type).try_into()?, + )), + }), + }, + ScalarValue::Date64(val) => { + create_proto_scalar(val.as_ref(), &data_type, |s| Value::Date64Value(*s)) + } + ScalarValue::TimestampSecond(val, tz) => { + create_proto_scalar(val.as_ref(), &data_type, |s| { + Value::TimestampValue(protobuf::ScalarTimestampValue { + timezone: tz.as_deref().unwrap_or("").to_string(), + value: Some( + protobuf::scalar_timestamp_value::Value::TimeSecondValue(*s), + ), + }) + }) + } + ScalarValue::TimestampMillisecond(val, tz) => { + create_proto_scalar(val.as_ref(), &data_type, |s| { + Value::TimestampValue(protobuf::ScalarTimestampValue { + timezone: tz.as_deref().unwrap_or("").to_string(), + value: Some( + protobuf::scalar_timestamp_value::Value::TimeMillisecondValue( + *s, + ), + ), + }) + }) + } + ScalarValue::IntervalYearMonth(val) => { + create_proto_scalar(val.as_ref(), &data_type, |s| { + Value::IntervalYearmonthValue(*s) + }) + } + ScalarValue::Null => Ok(protobuf::ScalarValue { + value: Some(Value::NullValue((&data_type).try_into()?)), + }), + + ScalarValue::Binary(val) => { + create_proto_scalar(val.as_ref(), &data_type, |s| { + Value::BinaryValue(s.to_owned()) + }) + } + ScalarValue::BinaryView(val) => { + create_proto_scalar(val.as_ref(), &data_type, |s| { + Value::BinaryViewValue(s.to_owned()) + }) + } + ScalarValue::LargeBinary(val) => { + create_proto_scalar(val.as_ref(), &data_type, |s| { + Value::LargeBinaryValue(s.to_owned()) + }) + } + ScalarValue::FixedSizeBinary(length, val) => { + create_proto_scalar(val.as_ref(), &data_type, |s| { + Value::FixedSizeBinaryValue(protobuf::ScalarFixedSizeBinary { + values: s.to_owned(), + length: *length, + }) + }) + } + + ScalarValue::Time32Second(v) => { + create_proto_scalar(v.as_ref(), &data_type, |v| { + Value::Time32Value(protobuf::ScalarTime32Value { + value: Some( + protobuf::scalar_time32_value::Value::Time32SecondValue(*v), + ), + }) + }) + } + + ScalarValue::Time32Millisecond(v) => { + create_proto_scalar(v.as_ref(), &data_type, |v| { + Value::Time32Value(protobuf::ScalarTime32Value { + value: Some( + protobuf::scalar_time32_value::Value::Time32MillisecondValue( + *v, + ), + ), + }) + }) + } + + ScalarValue::Time64Microsecond(v) => { + create_proto_scalar(v.as_ref(), &data_type, |v| { + Value::Time64Value(protobuf::ScalarTime64Value { + value: Some( + protobuf::scalar_time64_value::Value::Time64MicrosecondValue( + *v, + ), + ), + }) + }) + } + + ScalarValue::Time64Nanosecond(v) => { + create_proto_scalar(v.as_ref(), &data_type, |v| { + Value::Time64Value(protobuf::ScalarTime64Value { + value: Some( + protobuf::scalar_time64_value::Value::Time64NanosecondValue( + *v, + ), + ), + }) + }) + } + + ScalarValue::IntervalDayTime(val) => { + let value = if let Some(v) = val { + let (days, milliseconds) = IntervalDayTimeType::to_parts(*v); + Value::IntervalDaytimeValue(protobuf::IntervalDayTimeValue { + days, + milliseconds, + }) + } else { + Value::NullValue((&data_type).try_into()?) + }; + + Ok(protobuf::ScalarValue { value: Some(value) }) + } + + ScalarValue::IntervalMonthDayNano(v) => { + let value = if let Some(v) = v { + let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(*v); + Value::IntervalMonthDayNano(protobuf::IntervalMonthDayNanoValue { + months, + days, + nanos, + }) + } else { + Value::NullValue((&data_type).try_into()?) + }; + + Ok(protobuf::ScalarValue { value: Some(value) }) + } + + ScalarValue::DurationSecond(v) => { + let value = match v { + Some(v) => Value::DurationSecondValue(*v), + None => Value::NullValue((&data_type).try_into()?), + }; + Ok(protobuf::ScalarValue { value: Some(value) }) + } + ScalarValue::DurationMillisecond(v) => { + let value = match v { + Some(v) => Value::DurationMillisecondValue(*v), + None => Value::NullValue((&data_type).try_into()?), + }; + Ok(protobuf::ScalarValue { value: Some(value) }) + } + ScalarValue::DurationMicrosecond(v) => { + let value = match v { + Some(v) => Value::DurationMicrosecondValue(*v), + None => Value::NullValue((&data_type).try_into()?), + }; + Ok(protobuf::ScalarValue { value: Some(value) }) + } + ScalarValue::DurationNanosecond(v) => { + let value = match v { + Some(v) => Value::DurationNanosecondValue(*v), + None => Value::NullValue((&data_type).try_into()?), + }; + Ok(protobuf::ScalarValue { value: Some(value) }) + } + + ScalarValue::Union(val, df_fields, mode) => { + let mut fields = + Vec::::with_capacity(df_fields.len()); + for (id, field) in df_fields.iter() { + let field_id = id as i32; + let field = Some(field.as_ref().try_into()?); + let field = protobuf::UnionField { field_id, field }; + fields.push(field); + } + let mode = match mode { + UnionMode::Sparse => 0, + UnionMode::Dense => 1, + }; + let value = match val { + None => None, + Some((_id, v)) => Some(Box::new(v.as_ref().try_into()?)), + }; + let val = protobuf::UnionValue { + value_id: val.as_ref().map(|(id, _v)| *id as i32).unwrap_or(0), + value, + fields, + mode, + }; + let val = Value::UnionValue(Box::new(val)); + let val = protobuf::ScalarValue { value: Some(val) }; + Ok(val) + } + + ScalarValue::Dictionary(index_type, val) => { + let value: protobuf::ScalarValue = val.as_ref().try_into()?; + Ok(protobuf::ScalarValue { + value: Some(Value::DictionaryValue(Box::new( + protobuf::ScalarDictionaryValue { + index_type: Some(index_type.as_ref().try_into()?), + value: Some(Box::new(value)), + }, + ))), + }) + } + } + } +} + +impl From<&TimeUnit> for protobuf::TimeUnit { + fn from(val: &TimeUnit) -> Self { + match val { + TimeUnit::Second => protobuf::TimeUnit::Second, + TimeUnit::Millisecond => protobuf::TimeUnit::Millisecond, + TimeUnit::Microsecond => protobuf::TimeUnit::Microsecond, + TimeUnit::Nanosecond => protobuf::TimeUnit::Nanosecond, + } + } +} + +impl From<&IntervalUnit> for protobuf::IntervalUnit { + fn from(interval_unit: &IntervalUnit) -> Self { + match interval_unit { + IntervalUnit::YearMonth => protobuf::IntervalUnit::YearMonth, + IntervalUnit::DayTime => protobuf::IntervalUnit::DayTime, + IntervalUnit::MonthDayNano => protobuf::IntervalUnit::MonthDayNano, + } + } +} + +impl From for protobuf::Constraints { + fn from(value: Constraints) -> Self { + let constraints = value.into_iter().map(|item| item.into()).collect(); + protobuf::Constraints { constraints } + } +} + +impl From for protobuf::Constraint { + fn from(value: Constraint) -> Self { + let res = match value { + Constraint::PrimaryKey(indices) => { + let indices = indices.into_iter().map(|item| item as u64).collect(); + protobuf::constraint::ConstraintMode::PrimaryKey( + protobuf::PrimaryKeyConstraint { indices }, + ) + } + Constraint::Unique(indices) => { + let indices = indices.into_iter().map(|item| item as u64).collect(); + protobuf::constraint::ConstraintMode::PrimaryKey( + protobuf::PrimaryKeyConstraint { indices }, + ) + } + }; + protobuf::Constraint { + constraint_mode: Some(res), + } + } +} + +impl From<&Precision> for protobuf::Precision { + fn from(s: &Precision) -> protobuf::Precision { + match s { + Precision::Exact(val) => protobuf::Precision { + precision_info: protobuf::PrecisionInfo::Exact.into(), + val: Some(crate::protobuf_common::ScalarValue { + value: Some(Value::Uint64Value(*val as u64)), + }), + }, + Precision::Inexact(val) => protobuf::Precision { + precision_info: protobuf::PrecisionInfo::Inexact.into(), + val: Some(crate::protobuf_common::ScalarValue { + value: Some(Value::Uint64Value(*val as u64)), + }), + }, + Precision::Absent => protobuf::Precision { + precision_info: protobuf::PrecisionInfo::Absent.into(), + val: Some(crate::protobuf_common::ScalarValue { value: None }), + }, + } + } +} + +impl From<&Precision> for protobuf::Precision { + fn from(s: &Precision) -> protobuf::Precision { + match s { + Precision::Exact(val) => protobuf::Precision { + precision_info: protobuf::PrecisionInfo::Exact.into(), + val: val.try_into().ok(), + }, + Precision::Inexact(val) => protobuf::Precision { + precision_info: protobuf::PrecisionInfo::Inexact.into(), + val: val.try_into().ok(), + }, + Precision::Absent => protobuf::Precision { + precision_info: protobuf::PrecisionInfo::Absent.into(), + val: Some(crate::protobuf_common::ScalarValue { value: None }), + }, + } + } +} + +impl From<&Statistics> for protobuf::Statistics { + fn from(s: &Statistics) -> protobuf::Statistics { + let column_stats = s.column_statistics.iter().map(|s| s.into()).collect(); + protobuf::Statistics { + num_rows: Some(protobuf::Precision::from(&s.num_rows)), + total_byte_size: Some(protobuf::Precision::from(&s.total_byte_size)), + column_stats, + } + } +} + +impl From<&ColumnStatistics> for protobuf::ColumnStats { + fn from(s: &ColumnStatistics) -> protobuf::ColumnStats { + protobuf::ColumnStats { + min_value: Some(protobuf::Precision::from(&s.min_value)), + max_value: Some(protobuf::Precision::from(&s.max_value)), + null_count: Some(protobuf::Precision::from(&s.null_count)), + distinct_count: Some(protobuf::Precision::from(&s.distinct_count)), + } + } +} + +impl From for protobuf::JoinSide { + fn from(t: JoinSide) -> Self { + match t { + JoinSide::Left => protobuf::JoinSide::LeftSide, + JoinSide::Right => protobuf::JoinSide::RightSide, + JoinSide::None => protobuf::JoinSide::None, + } + } +} + +impl From<&CompressionTypeVariant> for protobuf::CompressionTypeVariant { + fn from(value: &CompressionTypeVariant) -> Self { + match value { + CompressionTypeVariant::GZIP => Self::Gzip, + CompressionTypeVariant::BZIP2 => Self::Bzip2, + CompressionTypeVariant::XZ => Self::Xz, + CompressionTypeVariant::ZSTD => Self::Zstd, + CompressionTypeVariant::UNCOMPRESSED => Self::Uncompressed, + } + } +} + +impl TryFrom<&CsvWriterOptions> for protobuf::CsvWriterOptions { + type Error = DataFusionError; + + fn try_from(opts: &CsvWriterOptions) -> datafusion_common::Result { + Ok(csv_writer_options_to_proto( + &opts.writer_options, + &opts.compression, + )) + } +} + +impl TryFrom<&JsonWriterOptions> for protobuf::JsonWriterOptions { + type Error = DataFusionError; + + fn try_from( + opts: &JsonWriterOptions, + ) -> datafusion_common::Result { + let compression: protobuf::CompressionTypeVariant = opts.compression.into(); + Ok(protobuf::JsonWriterOptions { + compression: compression.into(), + }) + } +} + +impl TryFrom<&ParquetOptions> for protobuf::ParquetOptions { + type Error = DataFusionError; + + fn try_from(value: &ParquetOptions) -> datafusion_common::Result { + Ok(protobuf::ParquetOptions { + enable_page_index: value.enable_page_index, + pruning: value.pruning, + skip_metadata: value.skip_metadata, + metadata_size_hint_opt: value.metadata_size_hint.map(|v| protobuf::parquet_options::MetadataSizeHintOpt::MetadataSizeHint(v as u64)), + pushdown_filters: value.pushdown_filters, + reorder_filters: value.reorder_filters, + data_pagesize_limit: value.data_pagesize_limit as u64, + write_batch_size: value.write_batch_size as u64, + writer_version: value.writer_version.clone(), + compression_opt: value.compression.clone().map(protobuf::parquet_options::CompressionOpt::Compression), + dictionary_enabled_opt: value.dictionary_enabled.map(protobuf::parquet_options::DictionaryEnabledOpt::DictionaryEnabled), + dictionary_page_size_limit: value.dictionary_page_size_limit as u64, + statistics_enabled_opt: value.statistics_enabled.clone().map(protobuf::parquet_options::StatisticsEnabledOpt::StatisticsEnabled), + max_statistics_size_opt: value.max_statistics_size.map(|v| protobuf::parquet_options::MaxStatisticsSizeOpt::MaxStatisticsSize(v as u64)), + max_row_group_size: value.max_row_group_size as u64, + created_by: value.created_by.clone(), + column_index_truncate_length_opt: value.column_index_truncate_length.map(|v| protobuf::parquet_options::ColumnIndexTruncateLengthOpt::ColumnIndexTruncateLength(v as u64)), + data_page_row_count_limit: value.data_page_row_count_limit as u64, + encoding_opt: value.encoding.clone().map(protobuf::parquet_options::EncodingOpt::Encoding), + bloom_filter_on_read: value.bloom_filter_on_read, + bloom_filter_on_write: value.bloom_filter_on_write, + bloom_filter_fpp_opt: value.bloom_filter_fpp.map(protobuf::parquet_options::BloomFilterFppOpt::BloomFilterFpp), + bloom_filter_ndv_opt: value.bloom_filter_ndv.map(protobuf::parquet_options::BloomFilterNdvOpt::BloomFilterNdv), + allow_single_file_parallelism: value.allow_single_file_parallelism, + maximum_parallel_row_group_writers: value.maximum_parallel_row_group_writers as u64, + maximum_buffered_record_batches_per_stream: value.maximum_buffered_record_batches_per_stream as u64, + schema_force_view_types: value.schema_force_view_types, + binary_as_string: value.binary_as_string, + }) + } +} + +impl TryFrom<&ParquetColumnOptions> for protobuf::ParquetColumnOptions { + type Error = DataFusionError; + + fn try_from( + value: &ParquetColumnOptions, + ) -> datafusion_common::Result { + Ok(protobuf::ParquetColumnOptions { + compression_opt: value + .compression + .clone() + .map(protobuf::parquet_column_options::CompressionOpt::Compression), + dictionary_enabled_opt: value + .dictionary_enabled + .map(protobuf::parquet_column_options::DictionaryEnabledOpt::DictionaryEnabled), + statistics_enabled_opt: value + .statistics_enabled + .clone() + .map(protobuf::parquet_column_options::StatisticsEnabledOpt::StatisticsEnabled), + max_statistics_size_opt: value.max_statistics_size.map(|v| { + protobuf::parquet_column_options::MaxStatisticsSizeOpt::MaxStatisticsSize( + v as u32, + ) + }), + encoding_opt: value + .encoding + .clone() + .map(protobuf::parquet_column_options::EncodingOpt::Encoding), + bloom_filter_enabled_opt: value + .bloom_filter_enabled + .map(protobuf::parquet_column_options::BloomFilterEnabledOpt::BloomFilterEnabled), + bloom_filter_fpp_opt: value + .bloom_filter_fpp + .map(protobuf::parquet_column_options::BloomFilterFppOpt::BloomFilterFpp), + bloom_filter_ndv_opt: value + .bloom_filter_ndv + .map(protobuf::parquet_column_options::BloomFilterNdvOpt::BloomFilterNdv), + }) + } +} + +impl TryFrom<&TableParquetOptions> for protobuf::TableParquetOptions { + type Error = DataFusionError; + fn try_from( + value: &TableParquetOptions, + ) -> datafusion_common::Result { + let column_specific_options = value + .column_specific_options + .iter() + .map(|(k, v)| { + Ok(protobuf::ParquetColumnSpecificOptions { + column_name: k.into(), + options: Some(v.try_into()?), + }) + }) + .collect::>>()?; + let key_value_metadata = value + .key_value_metadata + .iter() + .filter_map(|(k, v)| v.as_ref().map(|v| (k.clone(), v.clone()))) + .collect::>(); + Ok(protobuf::TableParquetOptions { + global: Some((&value.global).try_into()?), + column_specific_options, + key_value_metadata, + }) + } +} + +impl TryFrom<&CsvOptions> for protobuf::CsvOptions { + type Error = DataFusionError; // Define or use an appropriate error type + + fn try_from(opts: &CsvOptions) -> datafusion_common::Result { + let compression: protobuf::CompressionTypeVariant = opts.compression.into(); + Ok(protobuf::CsvOptions { + has_header: opts.has_header.map_or_else(Vec::new, |h| vec![h as u8]), + delimiter: vec![opts.delimiter], + quote: vec![opts.quote], + terminator: opts.terminator.map_or_else(Vec::new, |e| vec![e]), + escape: opts.escape.map_or_else(Vec::new, |e| vec![e]), + double_quote: opts.double_quote.map_or_else(Vec::new, |h| vec![h as u8]), + newlines_in_values: opts + .newlines_in_values + .map_or_else(Vec::new, |h| vec![h as u8]), + compression: compression.into(), + schema_infer_max_rec: opts.schema_infer_max_rec as u64, + date_format: opts.date_format.clone().unwrap_or_default(), + datetime_format: opts.datetime_format.clone().unwrap_or_default(), + timestamp_format: opts.timestamp_format.clone().unwrap_or_default(), + timestamp_tz_format: opts.timestamp_tz_format.clone().unwrap_or_default(), + time_format: opts.time_format.clone().unwrap_or_default(), + null_value: opts.null_value.clone().unwrap_or_default(), + comment: opts.comment.map_or_else(Vec::new, |h| vec![h]), + }) + } +} + +impl TryFrom<&JsonOptions> for protobuf::JsonOptions { + type Error = DataFusionError; + + fn try_from(opts: &JsonOptions) -> datafusion_common::Result { + let compression: protobuf::CompressionTypeVariant = opts.compression.into(); + Ok(protobuf::JsonOptions { + compression: compression.into(), + schema_infer_max_rec: opts.schema_infer_max_rec as u64, + }) + } +} + +/// Creates a scalar protobuf value from an optional value (T), and +/// encoding None as the appropriate datatype +fn create_proto_scalar protobuf::scalar_value::Value>( + v: Option<&I>, + null_arrow_type: &DataType, + constructor: T, +) -> Result { + let value = v + .map(constructor) + .unwrap_or(protobuf::scalar_value::Value::NullValue( + null_arrow_type.try_into()?, + )); + + Ok(protobuf::ScalarValue { value: Some(value) }) +} + +// ScalarValue::List / FixedSizeList / LargeList / Struct / Map are serialized using +// Arrow IPC messages as a single column RecordBatch +fn encode_scalar_nested_value( + arr: ArrayRef, + val: &ScalarValue, +) -> Result { + let batch = RecordBatch::try_from_iter(vec![("field_name", arr)]).map_err(|e| { + Error::General(format!( + "Error creating temporary batch while encoding ScalarValue::List: {e}" + )) + })?; + + let gen = IpcDataGenerator {}; + let mut dict_tracker = DictionaryTracker::new(false); + let (encoded_dictionaries, encoded_message) = gen + .encoded_batch(&batch, &mut dict_tracker, &Default::default()) + .map_err(|e| { + Error::General(format!("Error encoding ScalarValue::List as IPC: {e}")) + })?; + + let schema: protobuf::Schema = batch.schema().try_into()?; + + let scalar_list_value = protobuf::ScalarNestedValue { + ipc_message: encoded_message.ipc_message, + arrow_data: encoded_message.arrow_data, + dictionaries: encoded_dictionaries + .into_iter() + .map(|data| protobuf::scalar_nested_value::Dictionary { + ipc_message: data.ipc_message, + arrow_data: data.arrow_data, + }) + .collect(), + schema: Some(schema), + }; + + match val { + ScalarValue::List(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::ListValue(scalar_list_value)), + }), + ScalarValue::LargeList(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::LargeListValue( + scalar_list_value, + )), + }), + ScalarValue::FixedSizeList(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::FixedSizeListValue( + scalar_list_value, + )), + }), + ScalarValue::Struct(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::StructValue( + scalar_list_value, + )), + }), + ScalarValue::Map(_) => Ok(protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::MapValue(scalar_list_value)), + }), + _ => unreachable!(), + } +} + +/// Converts a vector of `Arc`s to `protobuf::Field`s +fn convert_arc_fields_to_proto_fields<'a, I>( + fields: I, +) -> Result, Error> +where + I: IntoIterator>, +{ + fields + .into_iter() + .map(|field| field.as_ref().try_into()) + .collect::, Error>>() +} + +pub(crate) fn csv_writer_options_to_proto( + csv_options: &WriterBuilder, + compression: &CompressionTypeVariant, +) -> protobuf::CsvWriterOptions { + let compression: protobuf::CompressionTypeVariant = compression.into(); + protobuf::CsvWriterOptions { + compression: compression.into(), + delimiter: (csv_options.delimiter() as char).to_string(), + has_header: csv_options.header(), + date_format: csv_options.date_format().unwrap_or("").to_owned(), + datetime_format: csv_options.datetime_format().unwrap_or("").to_owned(), + timestamp_format: csv_options.timestamp_format().unwrap_or("").to_owned(), + time_format: csv_options.time_format().unwrap_or("").to_owned(), + null_value: csv_options.null().to_owned(), + quote: (csv_options.quote() as char).to_string(), + escape: (csv_options.escape() as char).to_string(), + double_quote: csv_options.double_quote(), + } +} diff --git a/datafusion/proto/.gitignore b/datafusion/proto/.gitignore index 3aa373dc479b..662b95f238c2 100644 --- a/datafusion/proto/.gitignore +++ b/datafusion/proto/.gitignore @@ -1,4 +1,5 @@ # Files generated by regen.sh proto/proto_descriptor.bin src/datafusion.rs -datafusion.serde.rs +src/datafusion.serde.rs +src/datafusion_common.rs diff --git a/datafusion/proto/CONTRIBUTING.md b/datafusion/proto/CONTRIBUTING.md index f124c233d04f..db3658c72610 100644 --- a/datafusion/proto/CONTRIBUTING.md +++ b/datafusion/proto/CONTRIBUTING.md @@ -29,4 +29,4 @@ valid installation of [protoc] (see [installation instructions] for details). ``` [protoc]: https://github.com/protocolbuffers/protobuf#protocol-compiler-installation -[installation instructions]: https://datafusion.apache.org/contributor-guide/#protoc-installation +[installation instructions]: https://datafusion.apache.org/contributor-guide/getting_started.html#protoc-installation diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index ecb41e4e263e..cd6c385b8918 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -27,14 +27,11 @@ repository = { workspace = true } license = { workspace = true } authors = { workspace = true } # Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.73" +rust-version = "1.79" # Exclude proto files so crates.io consumers don't need protoc exclude = ["*.proto"] -[lints] -workspace = true - [lib] name = "datafusion_proto" path = "src/lib.rs" @@ -47,17 +44,21 @@ parquet = ["datafusion/parquet", "datafusion-common/parquet"] [dependencies] arrow = { workspace = true } chrono = { workspace = true } -datafusion = { workspace = true, default-features = true } +datafusion = { workspace = true, default-features = false } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } +datafusion-proto-common = { workspace = true } object_store = { workspace = true } -pbjson = { version = "0.6.0", optional = true } -prost = "0.12.0" +pbjson = { workspace = true, optional = true } +prost = { workspace = true } serde = { version = "1.0", optional = true } serde_json = { workspace = true, optional = true } [dev-dependencies] +datafusion = { workspace = true, default-features = true } datafusion-functions = { workspace = true, default-features = true } +datafusion-functions-aggregate = { workspace = true } +datafusion-functions-window-common = { workspace = true } doc-comment = { workspace = true } strum = { version = "0.26.1", features = ["derive"] } tokio = { workspace = true, features = ["rt-multi-thread"] } diff --git a/datafusion/proto/gen/Cargo.toml b/datafusion/proto/gen/Cargo.toml index ca93706419e4..aee8fac4a120 100644 --- a/datafusion/proto/gen/Cargo.toml +++ b/datafusion/proto/gen/Cargo.toml @@ -20,7 +20,7 @@ name = "gen" description = "Code generation for proto" version = "0.1.0" edition = { workspace = true } -rust-version = "1.73" +rust-version = "1.79" authors = { workspace = true } homepage = { workspace = true } repository = { workspace = true } @@ -34,5 +34,5 @@ workspace = true [dependencies] # Pin these dependencies so that the generated output is deterministic -pbjson-build = "=0.6.2" -prost-build = "=0.12.4" +pbjson-build = "=0.7.0" +prost-build = "=0.13.3" diff --git a/datafusion/proto/gen/src/main.rs b/datafusion/proto/gen/src/main.rs index a8936a2cb9cc..be61ff58fa8d 100644 --- a/datafusion/proto/gen/src/main.rs +++ b/datafusion/proto/gen/src/main.rs @@ -21,16 +21,19 @@ type Error = Box; type Result = std::result::Result; fn main() -> Result<(), String> { - let proto_dir = Path::new("proto"); - let proto_path = Path::new("proto/datafusion.proto"); + let proto_dir = Path::new("datafusion/proto"); + let proto_path = Path::new("datafusion/proto/proto/datafusion.proto"); + let out_dir = Path::new("datafusion/proto/src"); // proto definitions has to be there - let descriptor_path = proto_dir.join("proto_descriptor.bin"); + let descriptor_path = proto_dir.join("proto/proto_descriptor.bin"); prost_build::Config::new() + .protoc_arg("--experimental_allow_proto3_optional") .file_descriptor_set_path(&descriptor_path) - .out_dir("src") + .out_dir(out_dir) .compile_well_known_types() + .protoc_arg("--experimental_allow_proto3_optional") .extern_path(".google.protobuf", "::pbjson_types") .compile_protos(&[proto_path], &["proto"]) .map_err(|e| format!("protobuf compilation failed: {e}"))?; @@ -39,7 +42,7 @@ fn main() -> Result<(), String> { .unwrap_or_else(|e| panic!("Cannot read {:?}: {}", &descriptor_path, e)); pbjson_build::Builder::new() - .out_dir("src") + .out_dir(out_dir) .register_descriptors(&descriptor_set) .unwrap_or_else(|e| { panic!("Cannot register descriptors {:?}: {}", &descriptor_set, e) @@ -47,11 +50,21 @@ fn main() -> Result<(), String> { .build(&[".datafusion"]) .map_err(|e| format!("pbjson compilation failed: {e}"))?; - let prost = Path::new("src/datafusion.rs"); - let pbjson = Path::new("src/datafusion.serde.rs"); - - std::fs::copy(prost, "src/generated/prost.rs").unwrap(); - std::fs::copy(pbjson, "src/generated/pbjson.rs").unwrap(); + let prost = proto_dir.join("src/datafusion.rs"); + let pbjson = proto_dir.join("src/datafusion.serde.rs"); + let common_path = proto_dir.join("src/datafusion_common.rs"); + println!( + "Copying {} to {}", + prost.display(), + proto_dir.join("src/generated/prost.rs").display() + ); + std::fs::copy(prost, proto_dir.join("src/generated/prost.rs")).unwrap(); + std::fs::copy(pbjson, proto_dir.join("src/generated/pbjson.rs")).unwrap(); + std::fs::copy( + common_path, + proto_dir.join("src/generated/datafusion_proto_common.rs"), + ) + .unwrap(); Ok(()) } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 5bffdc3af774..b68c47c57eb9 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1,6 +1,4 @@ /* - - * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information @@ -26,24 +24,7 @@ option java_multiple_files = true; option java_package = "org.apache.arrow.datafusion.protobuf"; option java_outer_classname = "DatafusionProto"; -message ColumnRelation { - string relation = 1; -} - -message Column { - string name = 1; - ColumnRelation relation = 2; -} - -message DfField{ - Field field = 1; - ColumnRelation qualifier = 2; -} - -message DfSchema { - repeated DfField columns = 1; - map metadata = 2; -} +import "datafusion/proto-common/proto/datafusion_common.proto"; // logical plan // LogicalPlan is a nested type @@ -77,6 +58,7 @@ message LogicalPlanNode { DropViewNode drop_view = 27; DistinctOnNode distinct_on = 28; CopyToNode copy_to = 29; + UnnestNode unnest = 30; } } @@ -89,46 +71,39 @@ message ProjectionColumns { repeated string columns = 1; } -message CsvFormat { - CsvOptions options = 5; -} - -message ParquetFormat { - // Used to be bool enable_pruning = 1; - reserved 1; - TableParquetOptions options = 2; -} - -message AvroFormat {} - message LogicalExprNodeCollection { repeated LogicalExprNode logical_expr_nodes = 1; } +message SortExprNodeCollection { + repeated SortExprNode sort_expr_nodes = 1; +} + message ListingTableScanNode { reserved 1; // was string table_name TableReference table_name = 14; repeated string paths = 2; string file_extension = 3; ProjectionColumns projection = 4; - Schema schema = 5; + datafusion_common.Schema schema = 5; repeated LogicalExprNode filters = 6; repeated string table_partition_cols = 7; bool collect_stat = 8; uint32 target_partitions = 9; oneof FileFormatType { - CsvFormat csv = 10; - ParquetFormat parquet = 11; - AvroFormat avro = 12; + datafusion_common.CsvFormat csv = 10; + datafusion_common.ParquetFormat parquet = 11; + datafusion_common.AvroFormat avro = 12; + datafusion_common.NdJsonFormat json = 15; } - repeated LogicalExprNodeCollection file_sort_order = 13; + repeated SortExprNodeCollection file_sort_order = 13; } message ViewTableScanNode { reserved 1; // was string table_name TableReference table_name = 6; LogicalPlanNode input = 2; - Schema schema = 3; + datafusion_common.Schema schema = 3; ProjectionColumns projection = 4; string definition = 5; } @@ -138,7 +113,7 @@ message CustomTableScanNode { reserved 1; // was string table_name TableReference table_name = 6; ProjectionColumns projection = 2; - Schema schema = 3; + datafusion_common.Schema schema = 3; repeated LogicalExprNode filters = 4; bytes custom_table_data = 5; } @@ -158,7 +133,7 @@ message SelectionNode { message SortNode { LogicalPlanNode input = 1; - repeated LogicalExprNode expr = 2; + repeated SortExprNode expr = 2; // Maximum number of highest/lowest rows to fetch; negative means no limit int64 fetch = 3; } @@ -180,67 +155,45 @@ message EmptyRelationNode { bool produce_one_row = 1; } -message PrimaryKeyConstraint{ - repeated uint64 indices = 1; -} - -message UniqueConstraint{ - repeated uint64 indices = 1; -} - -message Constraint{ - oneof constraint_mode{ - PrimaryKeyConstraint primary_key = 1; - UniqueConstraint unique = 2; - } -} - -message Constraints{ - repeated Constraint constraints = 1; -} - message CreateExternalTableNode { reserved 1; // was string name - TableReference name = 12; + TableReference name = 9; string location = 2; string file_type = 3; - bool has_header = 4; - DfSchema schema = 5; - repeated string table_partition_cols = 6; - bool if_not_exists = 7; - string delimiter = 8; - string definition = 9; - reserved 10; // was string file_compression_type - CompressionTypeVariant file_compression_type = 17; - repeated LogicalExprNodeCollection order_exprs = 13; - bool unbounded = 14; - map options = 11; - Constraints constraints = 15; - map column_defaults = 16; + datafusion_common.DfSchema schema = 4; + repeated string table_partition_cols = 5; + bool if_not_exists = 6; + bool temporary = 14; + string definition = 7; + repeated SortExprNodeCollection order_exprs = 10; + bool unbounded = 11; + map options = 8; + datafusion_common.Constraints constraints = 12; + map column_defaults = 13; } message PrepareNode { string name = 1; - repeated ArrowType data_types = 2; + repeated datafusion_common.ArrowType data_types = 2; LogicalPlanNode input = 3; } message CreateCatalogSchemaNode { string schema_name = 1; bool if_not_exists = 2; - DfSchema schema = 3; + datafusion_common.DfSchema schema = 3; } message CreateCatalogNode { string catalog_name = 1; bool if_not_exists = 2; - DfSchema schema = 3; + datafusion_common.DfSchema schema = 3; } message DropViewNode { TableReference name = 1; bool if_exists = 2; - DfSchema schema = 3; + datafusion_common.DfSchema schema = 3; } message CreateViewNode { @@ -248,6 +201,7 @@ message CreateViewNode { TableReference name = 5; LogicalPlanNode input = 2; bool or_replace = 3; + bool temporary = 6; string definition = 4; } @@ -279,27 +233,11 @@ message WindowNode { repeated LogicalExprNode window_expr = 2; } -enum JoinType { - INNER = 0; - LEFT = 1; - RIGHT = 2; - FULL = 3; - LEFTSEMI = 4; - LEFTANTI = 5; - RIGHTSEMI = 6; - RIGHTANTI = 7; -} - -enum JoinConstraint { - ON = 0; - USING = 1; -} - message JoinNode { LogicalPlanNode left = 1; LogicalPlanNode right = 2; - JoinType join_type = 3; - JoinConstraint join_constraint = 4; + datafusion_common.JoinType join_type = 3; + datafusion_common.JoinConstraint join_constraint = 4; repeated LogicalExprNode left_join_key = 5; repeated LogicalExprNode right_join_key = 6; bool null_equals_null = 7; @@ -313,25 +251,50 @@ message DistinctNode { message DistinctOnNode { repeated LogicalExprNode on_expr = 1; repeated LogicalExprNode select_expr = 2; - repeated LogicalExprNode sort_expr = 3; + repeated SortExprNode sort_expr = 3; LogicalPlanNode input = 4; } message CopyToNode { - LogicalPlanNode input = 1; - string output_url = 2; - oneof format_options { - CsvOptions csv = 8; - JsonOptions json = 9; - TableParquetOptions parquet = 10; - AvroOptions avro = 11; - ArrowOptions arrow = 12; - } - repeated string partition_by = 7; + LogicalPlanNode input = 1; + string output_url = 2; + bytes file_type = 3; + repeated string partition_by = 7; +} + +message UnnestNode { + LogicalPlanNode input = 1; + repeated datafusion_common.Column exec_columns = 2; + repeated ColumnUnnestListItem list_type_columns = 3; + repeated uint64 struct_type_columns = 4; + repeated uint64 dependency_indices = 5; + datafusion_common.DfSchema schema = 6; + UnnestOptions options = 7; +} +message ColumnUnnestListItem { + uint32 input_index = 1; + ColumnUnnestListRecursion recursion = 2; } -message AvroOptions {} -message ArrowOptions {} +message ColumnUnnestListRecursions { + repeated ColumnUnnestListRecursion recursions = 2; +} + +message ColumnUnnestListRecursion { + datafusion_common.Column output_column = 1; + uint32 depth = 2; +} + +message UnnestOptions { + bool preserve_nulls = 1; + repeated RecursionUnnestOption recursions = 2; +} + +message RecursionUnnestOption { + datafusion_common.Column output_column = 1; + datafusion_common.Column input_column = 2; + uint32 depth = 3; +} message UnionNode { repeated LogicalPlanNode inputs = 1; @@ -364,18 +327,16 @@ message SubqueryAliasNode { message LogicalExprNode { oneof ExprType { // column references - Column column = 1; + datafusion_common.Column column = 1; // alias AliasNode alias = 2; - ScalarValue literal = 3; + datafusion_common.ScalarValue literal = 3; // binary expressions BinaryExprNode binary_expr = 4; - // aggregate expressions - AggregateExprNode aggregate_expr = 5; // null checks IsNull is_null_expr = 6; @@ -385,10 +346,10 @@ message LogicalExprNode { BetweenNode between = 9; CaseNode case_ = 10; CastNode cast = 11; - SortExprNode sort = 12; NegativeNode negative = 13; InListNode in_list = 14; Wildcard wildcard = 15; + // was ScalarFunctionNode scalar_function = 16; TryCastNode try_cast = 17; // window expressions @@ -400,7 +361,7 @@ message LogicalExprNode { // Scalar UDF expressions ScalarUDFExprNode scalar_udf_expr = 20; - GetIndexedField get_indexed_field = 21; + // GetIndexedField get_indexed_field = 21; GroupingSetNode grouping_set = 22; @@ -426,12 +387,12 @@ message LogicalExprNode { } message Wildcard { - string qualifier = 1; + TableReference qualifier = 1; } message PlaceholderNode { string id = 1; - ArrowType data_type = 2; + datafusion_common.ArrowType data_type = 2; } message LogicalExprList { @@ -451,7 +412,7 @@ message RollupNode { } message NamedStructField { - ScalarValue name = 1; + datafusion_common.ScalarValue name = 1; } message ListIndex { @@ -464,15 +425,6 @@ message ListRange { LogicalExprNode stride = 3; } -message GetIndexedField { - LogicalExprNode expr = 1; - oneof field { - NamedStructField named_struct_field = 2; - ListIndex list_index = 3; - ListRange list_range = 4; - } -} - message IsNull { LogicalExprNode expr = 1; } @@ -537,61 +489,14 @@ message InListNode { bool negated = 3; } -enum AggregateFunction { - MIN = 0; - MAX = 1; - SUM = 2; - AVG = 3; - COUNT = 4; - APPROX_DISTINCT = 5; - ARRAY_AGG = 6; - VARIANCE = 7; - VARIANCE_POP = 8; - COVARIANCE = 9; - COVARIANCE_POP = 10; - STDDEV = 11; - STDDEV_POP = 12; - CORRELATION = 13; - APPROX_PERCENTILE_CONT = 14; - APPROX_MEDIAN = 15; - APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16; - GROUPING = 17; - MEDIAN = 18; - BIT_AND = 19; - BIT_OR = 20; - BIT_XOR = 21; - BOOL_AND = 22; - BOOL_OR = 23; - // When a function with the same name exists among built-in window functions, - // we append "_AGG" to obey name scoping rules. - FIRST_VALUE_AGG = 24; - LAST_VALUE_AGG = 25; - REGR_SLOPE = 26; - REGR_INTERCEPT = 27; - REGR_COUNT = 28; - REGR_R2 = 29; - REGR_AVGX = 30; - REGR_AVGY = 31; - REGR_SXX = 32; - REGR_SYY = 33; - REGR_SXY = 34; - STRING_AGG = 35; - NTH_VALUE_AGG = 36; -} - -message AggregateExprNode { - AggregateFunction aggr_function = 1; - repeated LogicalExprNode expr = 2; - bool distinct = 3; - LogicalExprNode filter = 4; - repeated LogicalExprNode order_by = 5; -} message AggregateUDFExprNode { string fun_name = 1; repeated LogicalExprNode args = 2; + bool distinct = 5; LogicalExprNode filter = 3; - repeated LogicalExprNode order_by = 4; + repeated SortExprNode order_by = 4; + optional bytes fun_definition = 6; } message ScalarUDFExprNode { @@ -601,14 +506,15 @@ message ScalarUDFExprNode { } enum BuiltInWindowFunction { - ROW_NUMBER = 0; - RANK = 1; - DENSE_RANK = 2; - PERCENT_RANK = 3; - CUME_DIST = 4; - NTILE = 5; - LAG = 6; - LEAD = 7; + UNSPECIFIED = 0; // https://protobuf.dev/programming-guides/dos-donts/#unspecified-enum + // ROW_NUMBER = 0; + // RANK = 1; + // DENSE_RANK = 2; + // PERCENT_RANK = 3; + // CUME_DIST = 4; + // NTILE = 5; + // LAG = 6; + // LEAD = 7; FIRST_VALUE = 8; LAST_VALUE = 9; NTH_VALUE = 10; @@ -616,16 +522,16 @@ enum BuiltInWindowFunction { message WindowExprNode { oneof window_function { - AggregateFunction aggr_function = 1; BuiltInWindowFunction built_in_function = 2; string udaf = 3; string udwf = 9; } - LogicalExprNode expr = 4; + repeated LogicalExprNode exprs = 4; repeated LogicalExprNode partition_by = 5; - repeated LogicalExprNode order_by = 6; + repeated SortExprNode order_by = 6; // repeated LogicalExprNode filter = 7; WindowFrame window_frame = 8; + optional bytes fun_definition = 10; } message BetweenNode { @@ -669,12 +575,12 @@ message WhenThen { message CastNode { LogicalExprNode expr = 1; - ArrowType arrow_type = 2; + datafusion_common.ArrowType arrow_type = 2; } message TryCastNode { LogicalExprNode expr = 1; - ArrowType arrow_type = 2; + datafusion_common.ArrowType arrow_type = 2; } message SortExprNode { @@ -707,273 +613,22 @@ enum WindowFrameBoundType { message WindowFrameBound { WindowFrameBoundType window_frame_bound_type = 1; - ScalarValue bound_value = 2; + datafusion_common.ScalarValue bound_value = 2; } /////////////////////////////////////////////////////////////////////////////////////////////////// // Arrow Data Types /////////////////////////////////////////////////////////////////////////////////////////////////// -message Schema { - repeated Field columns = 1; - map metadata = 2; -} - -message Field { - // name of the field - string name = 1; - ArrowType arrow_type = 2; - bool nullable = 3; - // for complex data types like structs, unions - repeated Field children = 4; - map metadata = 5; - int64 dict_id = 6; - bool dict_ordered = 7; -} - message FixedSizeBinary{ int32 length = 1; } -message Timestamp{ - TimeUnit time_unit = 1; - string timezone = 2; -} - enum DateUnit{ Day = 0; DateMillisecond = 1; } -enum TimeUnit{ - Second = 0; - Millisecond = 1; - Microsecond = 2; - Nanosecond = 3; -} - -enum IntervalUnit{ - YearMonth = 0; - DayTime = 1; - MonthDayNano = 2; -} - -message Decimal{ - reserved 1, 2; - uint32 precision = 3; - int32 scale = 4; -} - -message List{ - Field field_type = 1; -} - -message FixedSizeList{ - Field field_type = 1; - int32 list_size = 2; -} - -message Dictionary{ - ArrowType key = 1; - ArrowType value = 2; -} - -message Struct{ - repeated Field sub_field_types = 1; -} - -message Map { - Field field_type = 1; - bool keys_sorted = 2; -} - -enum UnionMode{ - sparse = 0; - dense = 1; -} - -message Union{ - repeated Field union_types = 1; - UnionMode union_mode = 2; - repeated int32 type_ids = 3; -} - -// Used for List/FixedSizeList/LargeList/Struct -message ScalarNestedValue { - bytes ipc_message = 1; - bytes arrow_data = 2; - Schema schema = 3; -} - -message ScalarTime32Value { - oneof value { - int32 time32_second_value = 1; - int32 time32_millisecond_value = 2; - }; -} - -message ScalarTime64Value { - oneof value { - int64 time64_microsecond_value = 1; - int64 time64_nanosecond_value = 2; - }; -} - -message ScalarTimestampValue { - oneof value { - int64 time_microsecond_value = 1; - int64 time_nanosecond_value = 2; - int64 time_second_value = 3; - int64 time_millisecond_value = 4; - }; - string timezone = 5; -} - -message ScalarDictionaryValue { - ArrowType index_type = 1; - ScalarValue value = 2; -} - -message IntervalMonthDayNanoValue { - int32 months = 1; - int32 days = 2; - int64 nanos = 3; -} - -message UnionField { - int32 field_id = 1; - Field field = 2; -} - -message UnionValue { - // Note that a null union value must have one or more fields, so we - // encode a null UnionValue as one with value_id == 128 - int32 value_id = 1; - ScalarValue value = 2; - repeated UnionField fields = 3; - UnionMode mode = 4; -} - -message ScalarFixedSizeBinary{ - bytes values = 1; - int32 length = 2; -} - -message ScalarValue{ - // was PrimitiveScalarType null_value = 19; - reserved 19; - - oneof value { - // was PrimitiveScalarType null_value = 19; - // Null value of any type - ArrowType null_value = 33; - - bool bool_value = 1; - string utf8_value = 2; - string large_utf8_value = 3; - int32 int8_value = 4; - int32 int16_value = 5; - int32 int32_value = 6; - int64 int64_value = 7; - uint32 uint8_value = 8; - uint32 uint16_value = 9; - uint32 uint32_value = 10; - uint64 uint64_value = 11; - float float32_value = 12; - double float64_value = 13; - // Literal Date32 value always has a unit of day - int32 date_32_value = 14; - ScalarTime32Value time32_value = 15; - ScalarNestedValue large_list_value = 16; - ScalarNestedValue list_value = 17; - ScalarNestedValue fixed_size_list_value = 18; - ScalarNestedValue struct_value = 32; - - Decimal128 decimal128_value = 20; - Decimal256 decimal256_value = 39; - - int64 date_64_value = 21; - int32 interval_yearmonth_value = 24; - int64 interval_daytime_value = 25; - - int64 duration_second_value = 35; - int64 duration_millisecond_value = 36; - int64 duration_microsecond_value = 37; - int64 duration_nanosecond_value = 38; - - ScalarTimestampValue timestamp_value = 26; - ScalarDictionaryValue dictionary_value = 27; - bytes binary_value = 28; - bytes large_binary_value = 29; - ScalarTime64Value time64_value = 30; - IntervalMonthDayNanoValue interval_month_day_nano = 31; - ScalarFixedSizeBinary fixed_size_binary_value = 34; - UnionValue union_value = 42; - } -} - -message Decimal128{ - bytes value = 1; - int64 p = 2; - int64 s = 3; -} - -message Decimal256{ - bytes value = 1; - int64 p = 2; - int64 s = 3; -} - -// Serialized data type -message ArrowType{ - oneof arrow_type_enum { - EmptyMessage NONE = 1; // arrow::Type::NA - EmptyMessage BOOL = 2; // arrow::Type::BOOL - EmptyMessage UINT8 = 3; // arrow::Type::UINT8 - EmptyMessage INT8 = 4; // arrow::Type::INT8 - EmptyMessage UINT16 = 5; // represents arrow::Type fields in src/arrow/type.h - EmptyMessage INT16 = 6; - EmptyMessage UINT32 = 7; - EmptyMessage INT32 = 8; - EmptyMessage UINT64 = 9; - EmptyMessage INT64 = 10 ; - EmptyMessage FLOAT16 = 11 ; - EmptyMessage FLOAT32 = 12 ; - EmptyMessage FLOAT64 = 13 ; - EmptyMessage UTF8 = 14 ; - EmptyMessage LARGE_UTF8 = 32; - EmptyMessage BINARY = 15 ; - int32 FIXED_SIZE_BINARY = 16 ; - EmptyMessage LARGE_BINARY = 31; - EmptyMessage DATE32 = 17 ; - EmptyMessage DATE64 = 18 ; - TimeUnit DURATION = 19; - Timestamp TIMESTAMP = 20 ; - TimeUnit TIME32 = 21 ; - TimeUnit TIME64 = 22 ; - IntervalUnit INTERVAL = 23 ; - Decimal DECIMAL = 24 ; - List LIST = 25; - List LARGE_LIST = 26; - FixedSizeList FIXED_SIZE_LIST = 27; - Struct STRUCT = 28; - Union UNION = 29; - Dictionary DICTIONARY = 30; - Map MAP = 33; - } -} - -//Useful for representing an empty enum variant in rust -// E.G. enum example{One, Two(i32)} -// maps to -// message example{ -// oneof{ -// EmptyMessage One = 1; -// i32 Two = 2; -// } -//} -message EmptyMessage{} - message AnalyzedLogicalPlanType { string analyzer_name = 1; } @@ -988,16 +643,18 @@ message OptimizedPhysicalPlanType { message PlanType { oneof plan_type_enum { - EmptyMessage InitialLogicalPlan = 1; + datafusion_common.EmptyMessage InitialLogicalPlan = 1; AnalyzedLogicalPlanType AnalyzedLogicalPlan = 7; - EmptyMessage FinalAnalyzedLogicalPlan = 8; + datafusion_common.EmptyMessage FinalAnalyzedLogicalPlan = 8; OptimizedLogicalPlanType OptimizedLogicalPlan = 2; - EmptyMessage FinalLogicalPlan = 3; - EmptyMessage InitialPhysicalPlan = 4; - EmptyMessage InitialPhysicalPlanWithStats = 9; + datafusion_common.EmptyMessage FinalLogicalPlan = 3; + datafusion_common.EmptyMessage InitialPhysicalPlan = 4; + datafusion_common.EmptyMessage InitialPhysicalPlanWithStats = 9; + datafusion_common.EmptyMessage InitialPhysicalPlanWithSchema = 11; OptimizedPhysicalPlanType OptimizedPhysicalPlan = 5; - EmptyMessage FinalPhysicalPlan = 6; - EmptyMessage FinalPhysicalPlanWithStats = 10; + datafusion_common.EmptyMessage FinalPhysicalPlan = 6; + datafusion_common.EmptyMessage FinalPhysicalPlanWithStats = 10; + datafusion_common.EmptyMessage FinalPhysicalPlanWithSchema = 12; } } @@ -1058,226 +715,88 @@ message PhysicalPlanNode { AnalyzeExecNode analyze = 23; JsonSinkExecNode json_sink = 24; SymmetricHashJoinExecNode symmetric_hash_join = 25; - InterleaveExecNode interleave = 26; + InterleaveExecNode interleave = 26; PlaceholderRowExecNode placeholder_row = 27; CsvSinkExecNode csv_sink = 28; ParquetSinkExecNode parquet_sink = 29; + UnnestExecNode unnest = 30; } } -enum CompressionTypeVariant { - GZIP = 0; - BZIP2 = 1; - XZ = 2; - ZSTD = 3; - UNCOMPRESSED = 4; -} - message PartitionColumn { string name = 1; - ArrowType arrow_type = 2; -} - - -message JsonWriterOptions { - CompressionTypeVariant compression = 1; -} - - -message CsvWriterOptions { - // Compression type - CompressionTypeVariant compression = 1; - // Optional column delimiter. Defaults to `b','` - string delimiter = 2; - // Whether to write column names as file headers. Defaults to `true` - bool has_header = 3; - // Optional date format for date arrays - string date_format = 4; - // Optional datetime format for datetime arrays - string datetime_format = 5; - // Optional timestamp format for timestamp arrays - string timestamp_format = 6; - // Optional time format for time arrays - string time_format = 7; - // Optional value to represent null - string null_value = 8; -} - -// Options controlling CSV format -message CsvOptions { - bool has_header = 1; // Indicates if the CSV has a header row - bytes delimiter = 2; // Delimiter character as a byte - bytes quote = 3; // Quote character as a byte - bytes escape = 4; // Optional escape character as a byte - CompressionTypeVariant compression = 5; // Compression type - uint64 schema_infer_max_rec = 6; // Max records for schema inference - string date_format = 7; // Optional date format - string datetime_format = 8; // Optional datetime format - string timestamp_format = 9; // Optional timestamp format - string timestamp_tz_format = 10; // Optional timestamp with timezone format - string time_format = 11; // Optional time format - string null_value = 12; // Optional representation of null value + datafusion_common.ArrowType arrow_type = 2; } -// Options controlling CSV format -message JsonOptions { - CompressionTypeVariant compression = 1; // Compression type - uint64 schema_infer_max_rec = 2; // Max records for schema inference -} message FileSinkConfig { reserved 6; // writer_mode + reserved 8; // was `overwrite` which has been superseded by `insert_op` string object_store_url = 1; repeated PartitionedFile file_groups = 2; repeated string table_paths = 3; - Schema output_schema = 4; + datafusion_common.Schema output_schema = 4; repeated PartitionColumn table_partition_cols = 5; - bool overwrite = 8; + bool keep_partition_by_columns = 9; + InsertOp insert_op = 10; +} + +enum InsertOp { + Append = 0; + Overwrite = 1; + Replace = 2; } message JsonSink { FileSinkConfig config = 1; - JsonWriterOptions writer_options = 2; + datafusion_common.JsonWriterOptions writer_options = 2; } message JsonSinkExecNode { PhysicalPlanNode input = 1; JsonSink sink = 2; - Schema sink_schema = 3; + datafusion_common.Schema sink_schema = 3; PhysicalSortExprNodeCollection sort_order = 4; } message CsvSink { FileSinkConfig config = 1; - CsvWriterOptions writer_options = 2; + datafusion_common.CsvWriterOptions writer_options = 2; } message CsvSinkExecNode { PhysicalPlanNode input = 1; CsvSink sink = 2; - Schema sink_schema = 3; + datafusion_common.Schema sink_schema = 3; PhysicalSortExprNodeCollection sort_order = 4; } -message TableParquetOptions { - ParquetOptions global = 1; - repeated ColumnSpecificOptions column_specific_options = 2; -} - -message ColumnSpecificOptions { - string column_name = 1; - ColumnOptions options = 2; -} - -message ColumnOptions { - oneof bloom_filter_enabled_opt { - bool bloom_filter_enabled = 1; - } - - oneof encoding_opt { - string encoding = 2; - } - - oneof dictionary_enabled_opt { - bool dictionary_enabled = 3; - } - - oneof compression_opt { - string compression = 4; - } - - oneof statistics_enabled_opt { - string statistics_enabled = 5; - } - - oneof bloom_filter_fpp_opt { - double bloom_filter_fpp = 6; - } - - oneof bloom_filter_ndv_opt { - uint64 bloom_filter_ndv = 7; - } - - oneof max_statistics_size_opt { - uint32 max_statistics_size = 8; - } -} - -message ParquetOptions { - // Regular fields - bool enable_page_index = 1; // default = true - bool pruning = 2; // default = true - bool skip_metadata = 3; // default = true - bool pushdown_filters = 5; // default = false - bool reorder_filters = 6; // default = false - uint64 data_pagesize_limit = 7; // default = 1024 * 1024 - uint64 write_batch_size = 8; // default = 1024 - string writer_version = 9; // default = "1.0" - bool bloom_filter_enabled = 20; // default = false - bool allow_single_file_parallelism = 23; // default = true - uint64 maximum_parallel_row_group_writers = 24; // default = 1 - uint64 maximum_buffered_record_batches_per_stream = 25; // default = 2 - - oneof metadata_size_hint_opt { - uint64 metadata_size_hint = 4; - } - - oneof compression_opt { - string compression = 10; - } - - oneof dictionary_enabled_opt { - bool dictionary_enabled = 11; - } - - oneof statistics_enabled_opt { - string statistics_enabled = 13; - } - - oneof max_statistics_size_opt { - uint64 max_statistics_size = 14; - } - - oneof column_index_truncate_length_opt { - uint64 column_index_truncate_length = 17; - } - - oneof encoding_opt { - string encoding = 19; - } - - oneof bloom_filter_fpp_opt { - double bloom_filter_fpp = 21; - } - - oneof bloom_filter_ndv_opt { - uint64 bloom_filter_ndv = 22; - } - - uint64 dictionary_page_size_limit = 12; - - uint64 data_page_row_count_limit = 18; - - uint64 max_row_group_size = 15; - - string created_by = 16; -} - - - message ParquetSink { FileSinkConfig config = 1; - TableParquetOptions parquet_options = 2; + datafusion_common.TableParquetOptions parquet_options = 2; } message ParquetSinkExecNode { PhysicalPlanNode input = 1; ParquetSink sink = 2; - Schema sink_schema = 3; + datafusion_common.Schema sink_schema = 3; PhysicalSortExprNodeCollection sort_order = 4; } +message UnnestExecNode { + PhysicalPlanNode input = 1; + datafusion_common.Schema schema = 2; + repeated ListUnnest list_type_columns = 3; + repeated uint64 struct_type_columns = 4; + UnnestOptions options = 5; +} + +message ListUnnest { + uint32 index_in_input_schema = 1; + uint32 depth = 2; +} + message PhysicalExtensionNode { bytes node = 1; repeated PhysicalPlanNode inputs = 2; @@ -1292,7 +811,7 @@ message PhysicalExprNode { // column references PhysicalColumn column = 1; - ScalarValue literal = 2; + datafusion_common.ScalarValue literal = 2; // binary expressions PhysicalBinaryExprNode binary_expr = 3; @@ -1310,14 +829,17 @@ message PhysicalExprNode { PhysicalSortExprNode sort = 10; PhysicalNegativeNode negative = 11; PhysicalInListNode in_list = 12; + // was PhysicalScalarFunctionNode scalar_function = 13; PhysicalTryCastNode try_cast = 14; - // window expressions PhysicalWindowExprNode window_expr = 15; PhysicalScalarUdfNode scalar_udf = 16; + // was PhysicalDateTimeIntervalExprNode date_time_interval_expr = 17; PhysicalLikeExprNode like_expr = 18; + + PhysicalExtensionExprNode extension = 19; } } @@ -1325,30 +847,31 @@ message PhysicalScalarUdfNode { string name = 1; repeated PhysicalExprNode args = 2; optional bytes fun_definition = 3; - ArrowType return_type = 4; + datafusion_common.ArrowType return_type = 4; } message PhysicalAggregateExprNode { oneof AggregateFunction { - AggregateFunction aggr_function = 1; string user_defined_aggr_function = 4; } repeated PhysicalExprNode expr = 2; repeated PhysicalSortExprNode ordering_req = 5; bool distinct = 3; + bool ignore_nulls = 6; + optional bytes fun_definition = 7; } message PhysicalWindowExprNode { oneof window_function { - AggregateFunction aggr_function = 1; BuiltInWindowFunction built_in_function = 2; - // udaf = 3 + string user_defined_aggr_function = 3; } repeated PhysicalExprNode args = 4; repeated PhysicalExprNode partition_by = 5; repeated PhysicalSortExprNode order_by = 6; WindowFrame window_frame = 7; string name = 8; + optional bytes fun_definition = 9; } message PhysicalIsNull { @@ -1412,22 +935,28 @@ message PhysicalCaseNode { message PhysicalTryCastNode { PhysicalExprNode expr = 1; - ArrowType arrow_type = 2; + datafusion_common.ArrowType arrow_type = 2; } message PhysicalCastNode { PhysicalExprNode expr = 1; - ArrowType arrow_type = 2; + datafusion_common.ArrowType arrow_type = 2; } message PhysicalNegativeNode { PhysicalExprNode expr = 1; } +message PhysicalExtensionExprNode { + bytes expr = 1; + repeated PhysicalExprNode inputs = 2; +} + message FilterExecNode { PhysicalPlanNode input = 1; PhysicalExprNode expr = 2; uint32 default_filter_selectivity = 3; + repeated uint32 projection = 9; } message FileGroup { @@ -1448,10 +977,10 @@ message FileScanExecConf { reserved 10; repeated FileGroup file_groups = 1; - Schema schema = 2; + datafusion_common.Schema schema = 2; repeated uint32 projection = 4; ScanLimit limit = 5; - Statistics statistics = 6; + datafusion_common.Statistics statistics = 6; repeated string table_partition_cols = 7; string object_store_url = 8; repeated PhysicalSortExprNodeCollection output_ordering = 9; @@ -1474,6 +1003,10 @@ message CsvScanExecNode { oneof optional_escape { string escape = 5; } + oneof optional_comment { + string comment = 6; + } + bool newlines_in_values = 7; } message AvroScanExecNode { @@ -1490,7 +1023,7 @@ message HashJoinExecNode { PhysicalPlanNode left = 1; PhysicalPlanNode right = 2; repeated JoinOn on = 3; - JoinType join_type = 4; + datafusion_common.JoinType join_type = 4; PartitionMode partition_mode = 6; bool null_equals_null = 7; JoinFilter filter = 8; @@ -1506,7 +1039,7 @@ message SymmetricHashJoinExecNode { PhysicalPlanNode left = 1; PhysicalPlanNode right = 2; repeated JoinOn on = 3; - JoinType join_type = 4; + datafusion_common.JoinType join_type = 4; StreamPartitionMode partition_mode = 6; bool null_equals_null = 7; JoinFilter filter = 8; @@ -1523,7 +1056,7 @@ message UnionExecNode { } message ExplainExecNode { - Schema schema = 1; + datafusion_common.Schema schema = 1; repeated StringifiedPlan stringified_plans = 2; bool verbose = 3; } @@ -1532,7 +1065,7 @@ message AnalyzeExecNode { bool verbose = 1; bool show_statistics = 2; PhysicalPlanNode input = 3; - Schema schema = 4; + datafusion_common.Schema schema = 4; } message CrossJoinExecNode { @@ -1551,11 +1084,11 @@ message JoinOn { } message EmptyExecNode { - Schema schema = 1; + datafusion_common.Schema schema = 1; } message PlaceholderRowExecNode { - Schema schema = 1; + datafusion_common.Schema schema = 1; } message ProjectionExecNode { @@ -1582,9 +1115,9 @@ message WindowAggExecNode { repeated PhysicalExprNode partition_keys = 5; // Set optional to `None` for `BoundedWindowAggExec`. oneof input_order_mode { - EmptyMessage linear = 7; + datafusion_common.EmptyMessage linear = 7; PartiallySortedInputOrderMode partially_sorted = 8; - EmptyMessage sorted = 9; + datafusion_common.EmptyMessage sorted = 9; } } @@ -1596,6 +1129,11 @@ message MaybePhysicalSortExprs { repeated PhysicalSortExprNode sort_expr = 1; } +message AggLimit { + // wrap into a message to make it optional + uint64 limit = 1; +} + message AggregateExecNode { repeated PhysicalExprNode group_expr = 1; repeated PhysicalExprNode aggr_expr = 2; @@ -1604,10 +1142,11 @@ message AggregateExecNode { repeated string group_expr_name = 5; repeated string aggr_expr_name = 6; // we need the input schema to the partial aggregate to pass to the final aggregate - Schema input_schema = 7; + datafusion_common.Schema input_schema = 7; repeated PhysicalExprNode null_expr = 8; repeated bool groups = 9; repeated MaybeFilter filter_expr = 10; + AggLimit limit = 11; } message GlobalLimitExecNode { @@ -1641,13 +1180,14 @@ message SortPreservingMergeExecNode { message NestedLoopJoinExecNode { PhysicalPlanNode left = 1; PhysicalPlanNode right = 2; - JoinType join_type = 3; + datafusion_common.JoinType join_type = 3; JoinFilter filter = 4; } message CoalesceBatchesExecNode { PhysicalPlanNode input = 1; uint32 target_batch_size = 2; + optional uint32 fetch = 3; } message CoalescePartitionsExecNode { @@ -1661,35 +1201,40 @@ message PhysicalHashRepartition { message RepartitionExecNode{ PhysicalPlanNode input = 1; + // oneof partition_method { + // uint64 round_robin = 2; + // PhysicalHashRepartition hash = 3; + // uint64 unknown = 4; + // } + Partitioning partitioning = 5; +} + +message Partitioning { oneof partition_method { - uint64 round_robin = 2; - PhysicalHashRepartition hash = 3; - uint64 unknown = 4; + uint64 round_robin = 1; + PhysicalHashRepartition hash = 2; + uint64 unknown = 3; } } message JoinFilter{ PhysicalExprNode expression = 1; repeated ColumnIndex column_indices = 2; - Schema schema = 3; + datafusion_common.Schema schema = 3; } message ColumnIndex{ uint32 index = 1; - JoinSide side = 2; -} - -enum JoinSide{ - LEFT_SIDE = 0; - RIGHT_SIDE = 1; + datafusion_common.JoinSide side = 2; } message PartitionedFile { string path = 1; uint64 size = 2; uint64 last_modified_ns = 3; - repeated ScalarValue partition_values = 4; + repeated datafusion_common.ScalarValue partition_values = 4; FileRange range = 5; + datafusion_common.Statistics statistics = 6; } message FileRange { @@ -1701,29 +1246,5 @@ message PartitionStats { int64 num_rows = 1; int64 num_batches = 2; int64 num_bytes = 3; - repeated ColumnStats column_stats = 4; -} - -message Precision{ - PrecisionInfo precision_info = 1; - ScalarValue val = 2; -} - -enum PrecisionInfo { - EXACT = 0; - INEXACT = 1; - ABSENT = 2; -} - -message Statistics { - Precision num_rows = 1; - Precision total_byte_size = 2; - repeated ColumnStats column_stats = 3; -} - -message ColumnStats { - Precision min_value = 1; - Precision max_value = 2; - Precision null_count = 3; - Precision distinct_count = 4; + repeated datafusion_common.ColumnStats column_stats = 4; } diff --git a/datafusion/proto/regen.sh b/datafusion/proto/regen.sh index 4b7ad4d58533..02970a90add4 100755 --- a/datafusion/proto/regen.sh +++ b/datafusion/proto/regen.sh @@ -17,5 +17,5 @@ # specific language governing permissions and limitations # under the License. -SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) -cd $SCRIPT_DIR && cargo run --manifest-path gen/Cargo.toml \ No newline at end of file +repo_root=$(git rev-parse --show-toplevel) +cd "$repo_root" && cargo run --manifest-path datafusion/proto/gen/Cargo.toml diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 901aa2455e16..12ddb4cb2e32 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -39,6 +39,7 @@ use std::sync::Arc; use datafusion::execution::registry::FunctionRegistry; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionContext; +use datafusion_expr::planner::ExprPlanner; mod registry; @@ -115,7 +116,7 @@ impl Serializeable for Expr { Ok(Arc::new(create_udf( name, vec![], - Arc::new(arrow::datatypes::DataType::Null), + arrow::datatypes::DataType::Null, Volatility::Immutable, Arc::new(|_| unimplemented!()), ))) @@ -165,6 +166,10 @@ impl Serializeable for Expr { "register_udwf called in Placeholder Registry!" ) } + + fn expr_planners(&self) -> Vec> { + vec![] + } } Expr::from_bytes_with_registry(&bytes, &PlaceHolderRegistry)?; diff --git a/datafusion/proto/src/bytes/registry.rs b/datafusion/proto/src/bytes/registry.rs index 4bf2bb3d7b79..eae2425f8ac1 100644 --- a/datafusion/proto/src/bytes/registry.rs +++ b/datafusion/proto/src/bytes/registry.rs @@ -20,6 +20,7 @@ use std::{collections::HashSet, sync::Arc}; use datafusion::execution::registry::FunctionRegistry; use datafusion_common::plan_err; use datafusion_common::Result; +use datafusion_expr::planner::ExprPlanner; use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; /// A default [`FunctionRegistry`] registry that does not resolve any @@ -54,4 +55,8 @@ impl FunctionRegistry for NoRegistry { fn register_udwf(&mut self, udwf: Arc) -> Result>> { plan_err!("No function registry provided to deserialize, so can not deserialize User Defined Window Function '{}'", udwf.inner().name()) } + + fn expr_planners(&self) -> Vec> { + vec![] + } } diff --git a/datafusion/proto/src/common.rs b/datafusion/proto/src/common.rs index b18831048e1a..2b052a31b8b7 100644 --- a/datafusion/proto/src/common.rs +++ b/datafusion/proto/src/common.rs @@ -68,7 +68,3 @@ macro_rules! convert_box_required { } }}; } - -pub fn proto_error>(message: S) -> DataFusionError { - DataFusionError::Internal(message.into()) -} diff --git a/datafusion/proto/src/generated/datafusion_proto_common.rs b/datafusion/proto/src/generated/datafusion_proto_common.rs new file mode 100644 index 000000000000..68e7f74c7f49 --- /dev/null +++ b/datafusion/proto/src/generated/datafusion_proto_common.rs @@ -0,0 +1,1127 @@ +// This file is @generated by prost-build. +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ColumnRelation { + #[prost(string, tag = "1")] + pub relation: ::prost::alloc::string::String, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Column { + #[prost(string, tag = "1")] + pub name: ::prost::alloc::string::String, + #[prost(message, optional, tag = "2")] + pub relation: ::core::option::Option, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DfField { + #[prost(message, optional, tag = "1")] + pub field: ::core::option::Option, + #[prost(message, optional, tag = "2")] + pub qualifier: ::core::option::Option, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DfSchema { + #[prost(message, repeated, tag = "1")] + pub columns: ::prost::alloc::vec::Vec, + #[prost(map = "string, string", tag = "2")] + pub metadata: ::std::collections::HashMap< + ::prost::alloc::string::String, + ::prost::alloc::string::String, + >, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CsvFormat { + #[prost(message, optional, tag = "5")] + pub options: ::core::option::Option, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ParquetFormat { + #[prost(message, optional, tag = "2")] + pub options: ::core::option::Option, +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct AvroFormat {} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct NdJsonFormat { + #[prost(message, optional, tag = "1")] + pub options: ::core::option::Option, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PrimaryKeyConstraint { + #[prost(uint64, repeated, tag = "1")] + pub indices: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct UniqueConstraint { + #[prost(uint64, repeated, tag = "1")] + pub indices: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Constraint { + #[prost(oneof = "constraint::ConstraintMode", tags = "1, 2")] + pub constraint_mode: ::core::option::Option, +} +/// Nested message and enum types in `Constraint`. +pub mod constraint { + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum ConstraintMode { + #[prost(message, tag = "1")] + PrimaryKey(super::PrimaryKeyConstraint), + #[prost(message, tag = "2")] + Unique(super::UniqueConstraint), + } +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Constraints { + #[prost(message, repeated, tag = "1")] + pub constraints: ::prost::alloc::vec::Vec, +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct AvroOptions {} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct ArrowOptions {} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Schema { + #[prost(message, repeated, tag = "1")] + pub columns: ::prost::alloc::vec::Vec, + #[prost(map = "string, string", tag = "2")] + pub metadata: ::std::collections::HashMap< + ::prost::alloc::string::String, + ::prost::alloc::string::String, + >, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Field { + /// name of the field + #[prost(string, tag = "1")] + pub name: ::prost::alloc::string::String, + #[prost(message, optional, boxed, tag = "2")] + pub arrow_type: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(bool, tag = "3")] + pub nullable: bool, + /// for complex data types like structs, unions + #[prost(message, repeated, tag = "4")] + pub children: ::prost::alloc::vec::Vec, + #[prost(map = "string, string", tag = "5")] + pub metadata: ::std::collections::HashMap< + ::prost::alloc::string::String, + ::prost::alloc::string::String, + >, + #[prost(int64, tag = "6")] + pub dict_id: i64, + #[prost(bool, tag = "7")] + pub dict_ordered: bool, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Timestamp { + #[prost(enumeration = "TimeUnit", tag = "1")] + pub time_unit: i32, + #[prost(string, tag = "2")] + pub timezone: ::prost::alloc::string::String, +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct Decimal { + #[prost(uint32, tag = "3")] + pub precision: u32, + #[prost(int32, tag = "4")] + pub scale: i32, +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct Decimal256Type { + #[prost(uint32, tag = "3")] + pub precision: u32, + #[prost(int32, tag = "4")] + pub scale: i32, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct List { + #[prost(message, optional, boxed, tag = "1")] + pub field_type: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FixedSizeList { + #[prost(message, optional, boxed, tag = "1")] + pub field_type: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(int32, tag = "2")] + pub list_size: i32, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Dictionary { + #[prost(message, optional, boxed, tag = "1")] + pub key: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "2")] + pub value: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Struct { + #[prost(message, repeated, tag = "1")] + pub sub_field_types: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Map { + #[prost(message, optional, boxed, tag = "1")] + pub field_type: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(bool, tag = "2")] + pub keys_sorted: bool, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Union { + #[prost(message, repeated, tag = "1")] + pub union_types: ::prost::alloc::vec::Vec, + #[prost(enumeration = "UnionMode", tag = "2")] + pub union_mode: i32, + #[prost(int32, repeated, tag = "3")] + pub type_ids: ::prost::alloc::vec::Vec, +} +/// Used for List/FixedSizeList/LargeList/Struct/Map +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ScalarNestedValue { + #[prost(bytes = "vec", tag = "1")] + pub ipc_message: ::prost::alloc::vec::Vec, + #[prost(bytes = "vec", tag = "2")] + pub arrow_data: ::prost::alloc::vec::Vec, + #[prost(message, optional, tag = "3")] + pub schema: ::core::option::Option, + #[prost(message, repeated, tag = "4")] + pub dictionaries: ::prost::alloc::vec::Vec, +} +/// Nested message and enum types in `ScalarNestedValue`. +pub mod scalar_nested_value { + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct Dictionary { + #[prost(bytes = "vec", tag = "1")] + pub ipc_message: ::prost::alloc::vec::Vec, + #[prost(bytes = "vec", tag = "2")] + pub arrow_data: ::prost::alloc::vec::Vec, + } +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct ScalarTime32Value { + #[prost(oneof = "scalar_time32_value::Value", tags = "1, 2")] + pub value: ::core::option::Option, +} +/// Nested message and enum types in `ScalarTime32Value`. +pub mod scalar_time32_value { + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + pub enum Value { + #[prost(int32, tag = "1")] + Time32SecondValue(i32), + #[prost(int32, tag = "2")] + Time32MillisecondValue(i32), + } +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct ScalarTime64Value { + #[prost(oneof = "scalar_time64_value::Value", tags = "1, 2")] + pub value: ::core::option::Option, +} +/// Nested message and enum types in `ScalarTime64Value`. +pub mod scalar_time64_value { + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + pub enum Value { + #[prost(int64, tag = "1")] + Time64MicrosecondValue(i64), + #[prost(int64, tag = "2")] + Time64NanosecondValue(i64), + } +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ScalarTimestampValue { + #[prost(string, tag = "5")] + pub timezone: ::prost::alloc::string::String, + #[prost(oneof = "scalar_timestamp_value::Value", tags = "1, 2, 3, 4")] + pub value: ::core::option::Option, +} +/// Nested message and enum types in `ScalarTimestampValue`. +pub mod scalar_timestamp_value { + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + pub enum Value { + #[prost(int64, tag = "1")] + TimeMicrosecondValue(i64), + #[prost(int64, tag = "2")] + TimeNanosecondValue(i64), + #[prost(int64, tag = "3")] + TimeSecondValue(i64), + #[prost(int64, tag = "4")] + TimeMillisecondValue(i64), + } +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ScalarDictionaryValue { + #[prost(message, optional, tag = "1")] + pub index_type: ::core::option::Option, + #[prost(message, optional, boxed, tag = "2")] + pub value: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct IntervalDayTimeValue { + #[prost(int32, tag = "1")] + pub days: i32, + #[prost(int32, tag = "2")] + pub milliseconds: i32, +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct IntervalMonthDayNanoValue { + #[prost(int32, tag = "1")] + pub months: i32, + #[prost(int32, tag = "2")] + pub days: i32, + #[prost(int64, tag = "3")] + pub nanos: i64, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct UnionField { + #[prost(int32, tag = "1")] + pub field_id: i32, + #[prost(message, optional, tag = "2")] + pub field: ::core::option::Option, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct UnionValue { + /// Note that a null union value must have one or more fields, so we + /// encode a null UnionValue as one with value_id == 128 + #[prost(int32, tag = "1")] + pub value_id: i32, + #[prost(message, optional, boxed, tag = "2")] + pub value: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "3")] + pub fields: ::prost::alloc::vec::Vec, + #[prost(enumeration = "UnionMode", tag = "4")] + pub mode: i32, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ScalarFixedSizeBinary { + #[prost(bytes = "vec", tag = "1")] + pub values: ::prost::alloc::vec::Vec, + #[prost(int32, tag = "2")] + pub length: i32, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ScalarValue { + #[prost( + oneof = "scalar_value::Value", + tags = "33, 1, 2, 3, 23, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 41, 20, 39, 21, 24, 35, 36, 37, 38, 26, 27, 28, 29, 22, 30, 25, 31, 34, 42" + )] + pub value: ::core::option::Option, +} +/// Nested message and enum types in `ScalarValue`. +pub mod scalar_value { + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum Value { + /// was PrimitiveScalarType null_value = 19; + /// Null value of any type + #[prost(message, tag = "33")] + NullValue(super::ArrowType), + #[prost(bool, tag = "1")] + BoolValue(bool), + #[prost(string, tag = "2")] + Utf8Value(::prost::alloc::string::String), + #[prost(string, tag = "3")] + LargeUtf8Value(::prost::alloc::string::String), + #[prost(string, tag = "23")] + Utf8ViewValue(::prost::alloc::string::String), + #[prost(int32, tag = "4")] + Int8Value(i32), + #[prost(int32, tag = "5")] + Int16Value(i32), + #[prost(int32, tag = "6")] + Int32Value(i32), + #[prost(int64, tag = "7")] + Int64Value(i64), + #[prost(uint32, tag = "8")] + Uint8Value(u32), + #[prost(uint32, tag = "9")] + Uint16Value(u32), + #[prost(uint32, tag = "10")] + Uint32Value(u32), + #[prost(uint64, tag = "11")] + Uint64Value(u64), + #[prost(float, tag = "12")] + Float32Value(f32), + #[prost(double, tag = "13")] + Float64Value(f64), + /// Literal Date32 value always has a unit of day + #[prost(int32, tag = "14")] + Date32Value(i32), + #[prost(message, tag = "15")] + Time32Value(super::ScalarTime32Value), + #[prost(message, tag = "16")] + LargeListValue(super::ScalarNestedValue), + #[prost(message, tag = "17")] + ListValue(super::ScalarNestedValue), + #[prost(message, tag = "18")] + FixedSizeListValue(super::ScalarNestedValue), + #[prost(message, tag = "32")] + StructValue(super::ScalarNestedValue), + #[prost(message, tag = "41")] + MapValue(super::ScalarNestedValue), + #[prost(message, tag = "20")] + Decimal128Value(super::Decimal128), + #[prost(message, tag = "39")] + Decimal256Value(super::Decimal256), + #[prost(int64, tag = "21")] + Date64Value(i64), + #[prost(int32, tag = "24")] + IntervalYearmonthValue(i32), + #[prost(int64, tag = "35")] + DurationSecondValue(i64), + #[prost(int64, tag = "36")] + DurationMillisecondValue(i64), + #[prost(int64, tag = "37")] + DurationMicrosecondValue(i64), + #[prost(int64, tag = "38")] + DurationNanosecondValue(i64), + #[prost(message, tag = "26")] + TimestampValue(super::ScalarTimestampValue), + #[prost(message, tag = "27")] + DictionaryValue(::prost::alloc::boxed::Box), + #[prost(bytes, tag = "28")] + BinaryValue(::prost::alloc::vec::Vec), + #[prost(bytes, tag = "29")] + LargeBinaryValue(::prost::alloc::vec::Vec), + #[prost(bytes, tag = "22")] + BinaryViewValue(::prost::alloc::vec::Vec), + #[prost(message, tag = "30")] + Time64Value(super::ScalarTime64Value), + #[prost(message, tag = "25")] + IntervalDaytimeValue(super::IntervalDayTimeValue), + #[prost(message, tag = "31")] + IntervalMonthDayNano(super::IntervalMonthDayNanoValue), + #[prost(message, tag = "34")] + FixedSizeBinaryValue(super::ScalarFixedSizeBinary), + #[prost(message, tag = "42")] + UnionValue(::prost::alloc::boxed::Box), + } +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Decimal128 { + #[prost(bytes = "vec", tag = "1")] + pub value: ::prost::alloc::vec::Vec, + #[prost(int64, tag = "2")] + pub p: i64, + #[prost(int64, tag = "3")] + pub s: i64, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Decimal256 { + #[prost(bytes = "vec", tag = "1")] + pub value: ::prost::alloc::vec::Vec, + #[prost(int64, tag = "2")] + pub p: i64, + #[prost(int64, tag = "3")] + pub s: i64, +} +/// Serialized data type +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ArrowType { + #[prost( + oneof = "arrow_type::ArrowTypeEnum", + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 35, 32, 15, 34, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 36, 25, 26, 27, 28, 29, 30, 33" + )] + pub arrow_type_enum: ::core::option::Option, +} +/// Nested message and enum types in `ArrowType`. +pub mod arrow_type { + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum ArrowTypeEnum { + /// arrow::Type::NA + #[prost(message, tag = "1")] + None(super::EmptyMessage), + /// arrow::Type::BOOL + #[prost(message, tag = "2")] + Bool(super::EmptyMessage), + /// arrow::Type::UINT8 + #[prost(message, tag = "3")] + Uint8(super::EmptyMessage), + /// arrow::Type::INT8 + #[prost(message, tag = "4")] + Int8(super::EmptyMessage), + /// represents arrow::Type fields in src/arrow/type.h + #[prost(message, tag = "5")] + Uint16(super::EmptyMessage), + #[prost(message, tag = "6")] + Int16(super::EmptyMessage), + #[prost(message, tag = "7")] + Uint32(super::EmptyMessage), + #[prost(message, tag = "8")] + Int32(super::EmptyMessage), + #[prost(message, tag = "9")] + Uint64(super::EmptyMessage), + #[prost(message, tag = "10")] + Int64(super::EmptyMessage), + #[prost(message, tag = "11")] + Float16(super::EmptyMessage), + #[prost(message, tag = "12")] + Float32(super::EmptyMessage), + #[prost(message, tag = "13")] + Float64(super::EmptyMessage), + #[prost(message, tag = "14")] + Utf8(super::EmptyMessage), + #[prost(message, tag = "35")] + Utf8View(super::EmptyMessage), + #[prost(message, tag = "32")] + LargeUtf8(super::EmptyMessage), + #[prost(message, tag = "15")] + Binary(super::EmptyMessage), + #[prost(message, tag = "34")] + BinaryView(super::EmptyMessage), + #[prost(int32, tag = "16")] + FixedSizeBinary(i32), + #[prost(message, tag = "31")] + LargeBinary(super::EmptyMessage), + #[prost(message, tag = "17")] + Date32(super::EmptyMessage), + #[prost(message, tag = "18")] + Date64(super::EmptyMessage), + #[prost(enumeration = "super::TimeUnit", tag = "19")] + Duration(i32), + #[prost(message, tag = "20")] + Timestamp(super::Timestamp), + #[prost(enumeration = "super::TimeUnit", tag = "21")] + Time32(i32), + #[prost(enumeration = "super::TimeUnit", tag = "22")] + Time64(i32), + #[prost(enumeration = "super::IntervalUnit", tag = "23")] + Interval(i32), + #[prost(message, tag = "24")] + Decimal(super::Decimal), + #[prost(message, tag = "36")] + Decimal256(super::Decimal256Type), + #[prost(message, tag = "25")] + List(::prost::alloc::boxed::Box), + #[prost(message, tag = "26")] + LargeList(::prost::alloc::boxed::Box), + #[prost(message, tag = "27")] + FixedSizeList(::prost::alloc::boxed::Box), + #[prost(message, tag = "28")] + Struct(super::Struct), + #[prost(message, tag = "29")] + Union(super::Union), + #[prost(message, tag = "30")] + Dictionary(::prost::alloc::boxed::Box), + #[prost(message, tag = "33")] + Map(::prost::alloc::boxed::Box), + } +} +/// Useful for representing an empty enum variant in rust +/// E.G. enum example{One, Two(i32)} +/// maps to +/// message example{ +/// oneof{ +/// EmptyMessage One = 1; +/// i32 Two = 2; +/// } +/// } +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct EmptyMessage {} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct JsonWriterOptions { + #[prost(enumeration = "CompressionTypeVariant", tag = "1")] + pub compression: i32, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CsvWriterOptions { + /// Compression type + #[prost(enumeration = "CompressionTypeVariant", tag = "1")] + pub compression: i32, + /// Optional column delimiter. Defaults to `b','` + #[prost(string, tag = "2")] + pub delimiter: ::prost::alloc::string::String, + /// Whether to write column names as file headers. Defaults to `true` + #[prost(bool, tag = "3")] + pub has_header: bool, + /// Optional date format for date arrays + #[prost(string, tag = "4")] + pub date_format: ::prost::alloc::string::String, + /// Optional datetime format for datetime arrays + #[prost(string, tag = "5")] + pub datetime_format: ::prost::alloc::string::String, + /// Optional timestamp format for timestamp arrays + #[prost(string, tag = "6")] + pub timestamp_format: ::prost::alloc::string::String, + /// Optional time format for time arrays + #[prost(string, tag = "7")] + pub time_format: ::prost::alloc::string::String, + /// Optional value to represent null + #[prost(string, tag = "8")] + pub null_value: ::prost::alloc::string::String, + /// Optional quote. Defaults to `b'"'` + #[prost(string, tag = "9")] + pub quote: ::prost::alloc::string::String, + /// Optional escape. Defaults to `'\\'` + #[prost(string, tag = "10")] + pub escape: ::prost::alloc::string::String, + /// Optional flag whether to double quotes, instead of escaping. Defaults to `true` + #[prost(bool, tag = "11")] + pub double_quote: bool, +} +/// Options controlling CSV format +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CsvOptions { + /// Indicates if the CSV has a header row + #[prost(bytes = "vec", tag = "1")] + pub has_header: ::prost::alloc::vec::Vec, + /// Delimiter character as a byte + #[prost(bytes = "vec", tag = "2")] + pub delimiter: ::prost::alloc::vec::Vec, + /// Quote character as a byte + #[prost(bytes = "vec", tag = "3")] + pub quote: ::prost::alloc::vec::Vec, + /// Optional escape character as a byte + #[prost(bytes = "vec", tag = "4")] + pub escape: ::prost::alloc::vec::Vec, + /// Compression type + #[prost(enumeration = "CompressionTypeVariant", tag = "5")] + pub compression: i32, + /// Max records for schema inference + #[prost(uint64, tag = "6")] + pub schema_infer_max_rec: u64, + /// Optional date format + #[prost(string, tag = "7")] + pub date_format: ::prost::alloc::string::String, + /// Optional datetime format + #[prost(string, tag = "8")] + pub datetime_format: ::prost::alloc::string::String, + /// Optional timestamp format + #[prost(string, tag = "9")] + pub timestamp_format: ::prost::alloc::string::String, + /// Optional timestamp with timezone format + #[prost(string, tag = "10")] + pub timestamp_tz_format: ::prost::alloc::string::String, + /// Optional time format + #[prost(string, tag = "11")] + pub time_format: ::prost::alloc::string::String, + /// Optional representation of null value + #[prost(string, tag = "12")] + pub null_value: ::prost::alloc::string::String, + /// Optional comment character as a byte + #[prost(bytes = "vec", tag = "13")] + pub comment: ::prost::alloc::vec::Vec, + /// Indicates if quotes are doubled + #[prost(bytes = "vec", tag = "14")] + pub double_quote: ::prost::alloc::vec::Vec, + /// Indicates if newlines are supported in values + #[prost(bytes = "vec", tag = "15")] + pub newlines_in_values: ::prost::alloc::vec::Vec, + /// Optional terminator character as a byte + #[prost(bytes = "vec", tag = "16")] + pub terminator: ::prost::alloc::vec::Vec, +} +/// Options controlling CSV format +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct JsonOptions { + /// Compression type + #[prost(enumeration = "CompressionTypeVariant", tag = "1")] + pub compression: i32, + /// Max records for schema inference + #[prost(uint64, tag = "2")] + pub schema_infer_max_rec: u64, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct TableParquetOptions { + #[prost(message, optional, tag = "1")] + pub global: ::core::option::Option, + #[prost(message, repeated, tag = "2")] + pub column_specific_options: ::prost::alloc::vec::Vec, + #[prost(map = "string, string", tag = "3")] + pub key_value_metadata: ::std::collections::HashMap< + ::prost::alloc::string::String, + ::prost::alloc::string::String, + >, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ParquetColumnSpecificOptions { + #[prost(string, tag = "1")] + pub column_name: ::prost::alloc::string::String, + #[prost(message, optional, tag = "2")] + pub options: ::core::option::Option, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ParquetColumnOptions { + #[prost(oneof = "parquet_column_options::BloomFilterEnabledOpt", tags = "1")] + pub bloom_filter_enabled_opt: ::core::option::Option< + parquet_column_options::BloomFilterEnabledOpt, + >, + #[prost(oneof = "parquet_column_options::EncodingOpt", tags = "2")] + pub encoding_opt: ::core::option::Option, + #[prost(oneof = "parquet_column_options::DictionaryEnabledOpt", tags = "3")] + pub dictionary_enabled_opt: ::core::option::Option< + parquet_column_options::DictionaryEnabledOpt, + >, + #[prost(oneof = "parquet_column_options::CompressionOpt", tags = "4")] + pub compression_opt: ::core::option::Option, + #[prost(oneof = "parquet_column_options::StatisticsEnabledOpt", tags = "5")] + pub statistics_enabled_opt: ::core::option::Option< + parquet_column_options::StatisticsEnabledOpt, + >, + #[prost(oneof = "parquet_column_options::BloomFilterFppOpt", tags = "6")] + pub bloom_filter_fpp_opt: ::core::option::Option< + parquet_column_options::BloomFilterFppOpt, + >, + #[prost(oneof = "parquet_column_options::BloomFilterNdvOpt", tags = "7")] + pub bloom_filter_ndv_opt: ::core::option::Option< + parquet_column_options::BloomFilterNdvOpt, + >, + #[prost(oneof = "parquet_column_options::MaxStatisticsSizeOpt", tags = "8")] + pub max_statistics_size_opt: ::core::option::Option< + parquet_column_options::MaxStatisticsSizeOpt, + >, +} +/// Nested message and enum types in `ParquetColumnOptions`. +pub mod parquet_column_options { + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + pub enum BloomFilterEnabledOpt { + #[prost(bool, tag = "1")] + BloomFilterEnabled(bool), + } + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum EncodingOpt { + #[prost(string, tag = "2")] + Encoding(::prost::alloc::string::String), + } + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + pub enum DictionaryEnabledOpt { + #[prost(bool, tag = "3")] + DictionaryEnabled(bool), + } + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum CompressionOpt { + #[prost(string, tag = "4")] + Compression(::prost::alloc::string::String), + } + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum StatisticsEnabledOpt { + #[prost(string, tag = "5")] + StatisticsEnabled(::prost::alloc::string::String), + } + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + pub enum BloomFilterFppOpt { + #[prost(double, tag = "6")] + BloomFilterFpp(f64), + } + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + pub enum BloomFilterNdvOpt { + #[prost(uint64, tag = "7")] + BloomFilterNdv(u64), + } + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + pub enum MaxStatisticsSizeOpt { + #[prost(uint32, tag = "8")] + MaxStatisticsSize(u32), + } +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ParquetOptions { + /// Regular fields + /// + /// default = true + #[prost(bool, tag = "1")] + pub enable_page_index: bool, + /// default = true + #[prost(bool, tag = "2")] + pub pruning: bool, + /// default = true + #[prost(bool, tag = "3")] + pub skip_metadata: bool, + /// default = false + #[prost(bool, tag = "5")] + pub pushdown_filters: bool, + /// default = false + #[prost(bool, tag = "6")] + pub reorder_filters: bool, + /// default = 1024 * 1024 + #[prost(uint64, tag = "7")] + pub data_pagesize_limit: u64, + /// default = 1024 + #[prost(uint64, tag = "8")] + pub write_batch_size: u64, + /// default = "1.0" + #[prost(string, tag = "9")] + pub writer_version: ::prost::alloc::string::String, + /// bool bloom_filter_enabled = 20; // default = false + /// + /// default = true + #[prost(bool, tag = "23")] + pub allow_single_file_parallelism: bool, + /// default = 1 + #[prost(uint64, tag = "24")] + pub maximum_parallel_row_group_writers: u64, + /// default = 2 + #[prost(uint64, tag = "25")] + pub maximum_buffered_record_batches_per_stream: u64, + /// default = true + #[prost(bool, tag = "26")] + pub bloom_filter_on_read: bool, + /// default = false + #[prost(bool, tag = "27")] + pub bloom_filter_on_write: bool, + /// default = false + #[prost(bool, tag = "28")] + pub schema_force_view_types: bool, + /// default = false + #[prost(bool, tag = "29")] + pub binary_as_string: bool, + #[prost(uint64, tag = "12")] + pub dictionary_page_size_limit: u64, + #[prost(uint64, tag = "18")] + pub data_page_row_count_limit: u64, + #[prost(uint64, tag = "15")] + pub max_row_group_size: u64, + #[prost(string, tag = "16")] + pub created_by: ::prost::alloc::string::String, + #[prost(oneof = "parquet_options::MetadataSizeHintOpt", tags = "4")] + pub metadata_size_hint_opt: ::core::option::Option< + parquet_options::MetadataSizeHintOpt, + >, + #[prost(oneof = "parquet_options::CompressionOpt", tags = "10")] + pub compression_opt: ::core::option::Option, + #[prost(oneof = "parquet_options::DictionaryEnabledOpt", tags = "11")] + pub dictionary_enabled_opt: ::core::option::Option< + parquet_options::DictionaryEnabledOpt, + >, + #[prost(oneof = "parquet_options::StatisticsEnabledOpt", tags = "13")] + pub statistics_enabled_opt: ::core::option::Option< + parquet_options::StatisticsEnabledOpt, + >, + #[prost(oneof = "parquet_options::MaxStatisticsSizeOpt", tags = "14")] + pub max_statistics_size_opt: ::core::option::Option< + parquet_options::MaxStatisticsSizeOpt, + >, + #[prost(oneof = "parquet_options::ColumnIndexTruncateLengthOpt", tags = "17")] + pub column_index_truncate_length_opt: ::core::option::Option< + parquet_options::ColumnIndexTruncateLengthOpt, + >, + #[prost(oneof = "parquet_options::EncodingOpt", tags = "19")] + pub encoding_opt: ::core::option::Option, + #[prost(oneof = "parquet_options::BloomFilterFppOpt", tags = "21")] + pub bloom_filter_fpp_opt: ::core::option::Option, + #[prost(oneof = "parquet_options::BloomFilterNdvOpt", tags = "22")] + pub bloom_filter_ndv_opt: ::core::option::Option, +} +/// Nested message and enum types in `ParquetOptions`. +pub mod parquet_options { + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + pub enum MetadataSizeHintOpt { + #[prost(uint64, tag = "4")] + MetadataSizeHint(u64), + } + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum CompressionOpt { + #[prost(string, tag = "10")] + Compression(::prost::alloc::string::String), + } + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + pub enum DictionaryEnabledOpt { + #[prost(bool, tag = "11")] + DictionaryEnabled(bool), + } + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum StatisticsEnabledOpt { + #[prost(string, tag = "13")] + StatisticsEnabled(::prost::alloc::string::String), + } + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + pub enum MaxStatisticsSizeOpt { + #[prost(uint64, tag = "14")] + MaxStatisticsSize(u64), + } + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + pub enum ColumnIndexTruncateLengthOpt { + #[prost(uint64, tag = "17")] + ColumnIndexTruncateLength(u64), + } + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum EncodingOpt { + #[prost(string, tag = "19")] + Encoding(::prost::alloc::string::String), + } + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + pub enum BloomFilterFppOpt { + #[prost(double, tag = "21")] + BloomFilterFpp(f64), + } + #[derive(Clone, Copy, PartialEq, ::prost::Oneof)] + pub enum BloomFilterNdvOpt { + #[prost(uint64, tag = "22")] + BloomFilterNdv(u64), + } +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Precision { + #[prost(enumeration = "PrecisionInfo", tag = "1")] + pub precision_info: i32, + #[prost(message, optional, tag = "2")] + pub val: ::core::option::Option, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Statistics { + #[prost(message, optional, tag = "1")] + pub num_rows: ::core::option::Option, + #[prost(message, optional, tag = "2")] + pub total_byte_size: ::core::option::Option, + #[prost(message, repeated, tag = "3")] + pub column_stats: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ColumnStats { + #[prost(message, optional, tag = "1")] + pub min_value: ::core::option::Option, + #[prost(message, optional, tag = "2")] + pub max_value: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub null_count: ::core::option::Option, + #[prost(message, optional, tag = "4")] + pub distinct_count: ::core::option::Option, +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum JoinType { + Inner = 0, + Left = 1, + Right = 2, + Full = 3, + Leftsemi = 4, + Leftanti = 5, + Rightsemi = 6, + Rightanti = 7, + Leftmark = 8, +} +impl JoinType { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::Inner => "INNER", + Self::Left => "LEFT", + Self::Right => "RIGHT", + Self::Full => "FULL", + Self::Leftsemi => "LEFTSEMI", + Self::Leftanti => "LEFTANTI", + Self::Rightsemi => "RIGHTSEMI", + Self::Rightanti => "RIGHTANTI", + Self::Leftmark => "LEFTMARK", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "INNER" => Some(Self::Inner), + "LEFT" => Some(Self::Left), + "RIGHT" => Some(Self::Right), + "FULL" => Some(Self::Full), + "LEFTSEMI" => Some(Self::Leftsemi), + "LEFTANTI" => Some(Self::Leftanti), + "RIGHTSEMI" => Some(Self::Rightsemi), + "RIGHTANTI" => Some(Self::Rightanti), + "LEFTMARK" => Some(Self::Leftmark), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum JoinConstraint { + On = 0, + Using = 1, +} +impl JoinConstraint { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::On => "ON", + Self::Using => "USING", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "ON" => Some(Self::On), + "USING" => Some(Self::Using), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum TimeUnit { + Second = 0, + Millisecond = 1, + Microsecond = 2, + Nanosecond = 3, +} +impl TimeUnit { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::Second => "Second", + Self::Millisecond => "Millisecond", + Self::Microsecond => "Microsecond", + Self::Nanosecond => "Nanosecond", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "Second" => Some(Self::Second), + "Millisecond" => Some(Self::Millisecond), + "Microsecond" => Some(Self::Microsecond), + "Nanosecond" => Some(Self::Nanosecond), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum IntervalUnit { + YearMonth = 0, + DayTime = 1, + MonthDayNano = 2, +} +impl IntervalUnit { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::YearMonth => "YearMonth", + Self::DayTime => "DayTime", + Self::MonthDayNano => "MonthDayNano", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "YearMonth" => Some(Self::YearMonth), + "DayTime" => Some(Self::DayTime), + "MonthDayNano" => Some(Self::MonthDayNano), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum UnionMode { + Sparse = 0, + Dense = 1, +} +impl UnionMode { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::Sparse => "sparse", + Self::Dense => "dense", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "sparse" => Some(Self::Sparse), + "dense" => Some(Self::Dense), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum CompressionTypeVariant { + Gzip = 0, + Bzip2 = 1, + Xz = 2, + Zstd = 3, + Uncompressed = 4, +} +impl CompressionTypeVariant { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::Gzip => "GZIP", + Self::Bzip2 => "BZIP2", + Self::Xz => "XZ", + Self::Zstd => "ZSTD", + Self::Uncompressed => "UNCOMPRESSED", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "GZIP" => Some(Self::Gzip), + "BZIP2" => Some(Self::Bzip2), + "XZ" => Some(Self::Xz), + "ZSTD" => Some(Self::Zstd), + "UNCOMPRESSED" => Some(Self::Uncompressed), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum JoinSide { + LeftSide = 0, + RightSide = 1, + None = 2, +} +impl JoinSide { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::LeftSide => "LEFT_SIDE", + Self::RightSide => "RIGHT_SIDE", + Self::None => "NONE", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "LEFT_SIDE" => Some(Self::LeftSide), + "RIGHT_SIDE" => Some(Self::RightSide), + "NONE" => Some(Self::None), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum PrecisionInfo { + Exact = 0, + Inexact = 1, + Absent = 2, +} +impl PrecisionInfo { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::Exact => "EXACT", + Self::Inexact => "INEXACT", + Self::Absent => "ABSENT", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "EXACT" => Some(Self::Exact), + "INEXACT" => Some(Self::Inexact), + "ABSENT" => Some(Self::Absent), + _ => None, + } + } +} diff --git a/datafusion/proto/src/generated/mod.rs b/datafusion/proto/src/generated/mod.rs index e48e96e887cf..da3302a74375 100644 --- a/datafusion/proto/src/generated/mod.rs +++ b/datafusion/proto/src/generated/mod.rs @@ -17,10 +17,12 @@ #[allow(clippy::all)] #[rustfmt::skip] -#[cfg(not(docsrs))] pub mod datafusion { include!("prost.rs"); + include!("datafusion_proto_common.rs"); #[cfg(feature = "json")] include!("pbjson.rs"); } + +pub use datafusion_proto_common::protobuf_common as datafusion_common; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 0fb6f4623745..e54edb718808 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -1,3 +1,98 @@ +impl serde::Serialize for AggLimit { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.limit != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.AggLimit", len)?; + if self.limit != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("limit", ToString::to_string(&self.limit).as_str())?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for AggLimit { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "limit", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Limit, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "limit" => Ok(GeneratedField::Limit), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = AggLimit; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.AggLimit") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut limit__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Limit => { + if limit__.is_some() { + return Err(serde::de::Error::duplicate_field("limit")); + } + limit__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(AggLimit { + limit: limit__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.AggLimit", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for AggregateExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -36,6 +131,9 @@ impl serde::Serialize for AggregateExecNode { if !self.filter_expr.is_empty() { len += 1; } + if self.limit.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.AggregateExecNode", len)?; if !self.group_expr.is_empty() { struct_ser.serialize_field("groupExpr", &self.group_expr)?; @@ -69,6 +167,9 @@ impl serde::Serialize for AggregateExecNode { if !self.filter_expr.is_empty() { struct_ser.serialize_field("filterExpr", &self.filter_expr)?; } + if let Some(v) = self.limit.as_ref() { + struct_ser.serialize_field("limit", v)?; + } struct_ser.end() } } @@ -96,6 +197,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { "groups", "filter_expr", "filterExpr", + "limit", ]; #[allow(clippy::enum_variant_names)] @@ -110,6 +212,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { NullExpr, Groups, FilterExpr, + Limit, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -141,6 +244,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { "nullExpr" | "null_expr" => Ok(GeneratedField::NullExpr), "groups" => Ok(GeneratedField::Groups), "filterExpr" | "filter_expr" => Ok(GeneratedField::FilterExpr), + "limit" => Ok(GeneratedField::Limit), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -170,6 +274,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { let mut null_expr__ = None; let mut groups__ = None; let mut filter_expr__ = None; + let mut limit__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::GroupExpr => { @@ -232,6 +337,12 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { } filter_expr__ = Some(map_.next_value()?); } + GeneratedField::Limit => { + if limit__.is_some() { + return Err(serde::de::Error::duplicate_field("limit")); + } + limit__ = map_.next_value()?; + } } } Ok(AggregateExecNode { @@ -245,13 +356,94 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { null_expr: null_expr__.unwrap_or_default(), groups: groups__.unwrap_or_default(), filter_expr: filter_expr__.unwrap_or_default(), + limit: limit__, }) } } deserializer.deserialize_struct("datafusion.AggregateExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for AggregateExprNode { +impl serde::Serialize for AggregateMode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::Partial => "PARTIAL", + Self::Final => "FINAL", + Self::FinalPartitioned => "FINAL_PARTITIONED", + Self::Single => "SINGLE", + Self::SinglePartitioned => "SINGLE_PARTITIONED", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for AggregateMode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "PARTIAL", + "FINAL", + "FINAL_PARTITIONED", + "SINGLE", + "SINGLE_PARTITIONED", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = AggregateMode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "PARTIAL" => Ok(AggregateMode::Partial), + "FINAL" => Ok(AggregateMode::Final), + "FINAL_PARTITIONED" => Ok(AggregateMode::FinalPartitioned), + "SINGLE" => Ok(AggregateMode::Single), + "SINGLE_PARTITIONED" => Ok(AggregateMode::SinglePartitioned), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} +impl serde::Serialize for AggregateNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -259,65 +451,47 @@ impl serde::Serialize for AggregateExprNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.aggr_function != 0 { - len += 1; - } - if !self.expr.is_empty() { - len += 1; - } - if self.distinct { + if self.input.is_some() { len += 1; } - if self.filter.is_some() { + if !self.group_expr.is_empty() { len += 1; } - if !self.order_by.is_empty() { + if !self.aggr_expr.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.AggregateExprNode", len)?; - if self.aggr_function != 0 { - let v = AggregateFunction::try_from(self.aggr_function) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.aggr_function)))?; - struct_ser.serialize_field("aggrFunction", &v)?; - } - if !self.expr.is_empty() { - struct_ser.serialize_field("expr", &self.expr)?; - } - if self.distinct { - struct_ser.serialize_field("distinct", &self.distinct)?; + let mut struct_ser = serializer.serialize_struct("datafusion.AggregateNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; } - if let Some(v) = self.filter.as_ref() { - struct_ser.serialize_field("filter", v)?; + if !self.group_expr.is_empty() { + struct_ser.serialize_field("groupExpr", &self.group_expr)?; } - if !self.order_by.is_empty() { - struct_ser.serialize_field("orderBy", &self.order_by)?; + if !self.aggr_expr.is_empty() { + struct_ser.serialize_field("aggrExpr", &self.aggr_expr)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for AggregateExprNode { +impl<'de> serde::Deserialize<'de> for AggregateNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "aggr_function", - "aggrFunction", - "expr", - "distinct", - "filter", - "order_by", - "orderBy", + "input", + "group_expr", + "groupExpr", + "aggr_expr", + "aggrExpr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - AggrFunction, - Expr, - Distinct, - Filter, - OrderBy, + Input, + GroupExpr, + AggrExpr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -339,11 +513,9 @@ impl<'de> serde::Deserialize<'de> for AggregateExprNode { E: serde::de::Error, { match value { - "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction), - "expr" => Ok(GeneratedField::Expr), - "distinct" => Ok(GeneratedField::Distinct), - "filter" => Ok(GeneratedField::Filter), - "orderBy" | "order_by" => Ok(GeneratedField::OrderBy), + "input" => Ok(GeneratedField::Input), + "groupExpr" | "group_expr" => Ok(GeneratedField::GroupExpr), + "aggrExpr" | "aggr_expr" => Ok(GeneratedField::AggrExpr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -353,324 +525,52 @@ impl<'de> serde::Deserialize<'de> for AggregateExprNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = AggregateExprNode; + type Value = AggregateNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.AggregateExprNode") + formatter.write_str("struct datafusion.AggregateNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut aggr_function__ = None; - let mut expr__ = None; - let mut distinct__ = None; - let mut filter__ = None; - let mut order_by__ = None; + let mut input__ = None; + let mut group_expr__ = None; + let mut aggr_expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::AggrFunction => { - if aggr_function__.is_some() { - return Err(serde::de::Error::duplicate_field("aggrFunction")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - aggr_function__ = Some(map_.next_value::()? as i32); + input__ = map_.next_value()?; } - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::GroupExpr => { + if group_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("groupExpr")); } - expr__ = Some(map_.next_value()?); - } - GeneratedField::Distinct => { - if distinct__.is_some() { - return Err(serde::de::Error::duplicate_field("distinct")); - } - distinct__ = Some(map_.next_value()?); - } - GeneratedField::Filter => { - if filter__.is_some() { - return Err(serde::de::Error::duplicate_field("filter")); - } - filter__ = map_.next_value()?; + group_expr__ = Some(map_.next_value()?); } - GeneratedField::OrderBy => { - if order_by__.is_some() { - return Err(serde::de::Error::duplicate_field("orderBy")); + GeneratedField::AggrExpr => { + if aggr_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("aggrExpr")); } - order_by__ = Some(map_.next_value()?); + aggr_expr__ = Some(map_.next_value()?); } } } - Ok(AggregateExprNode { - aggr_function: aggr_function__.unwrap_or_default(), - expr: expr__.unwrap_or_default(), - distinct: distinct__.unwrap_or_default(), - filter: filter__, - order_by: order_by__.unwrap_or_default(), + Ok(AggregateNode { + input: input__, + group_expr: group_expr__.unwrap_or_default(), + aggr_expr: aggr_expr__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.AggregateExprNode", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for AggregateFunction { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - let variant = match self { - Self::Min => "MIN", - Self::Max => "MAX", - Self::Sum => "SUM", - Self::Avg => "AVG", - Self::Count => "COUNT", - Self::ApproxDistinct => "APPROX_DISTINCT", - Self::ArrayAgg => "ARRAY_AGG", - Self::Variance => "VARIANCE", - Self::VariancePop => "VARIANCE_POP", - Self::Covariance => "COVARIANCE", - Self::CovariancePop => "COVARIANCE_POP", - Self::Stddev => "STDDEV", - Self::StddevPop => "STDDEV_POP", - Self::Correlation => "CORRELATION", - Self::ApproxPercentileCont => "APPROX_PERCENTILE_CONT", - Self::ApproxMedian => "APPROX_MEDIAN", - Self::ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT", - Self::Grouping => "GROUPING", - Self::Median => "MEDIAN", - Self::BitAnd => "BIT_AND", - Self::BitOr => "BIT_OR", - Self::BitXor => "BIT_XOR", - Self::BoolAnd => "BOOL_AND", - Self::BoolOr => "BOOL_OR", - Self::FirstValueAgg => "FIRST_VALUE_AGG", - Self::LastValueAgg => "LAST_VALUE_AGG", - Self::RegrSlope => "REGR_SLOPE", - Self::RegrIntercept => "REGR_INTERCEPT", - Self::RegrCount => "REGR_COUNT", - Self::RegrR2 => "REGR_R2", - Self::RegrAvgx => "REGR_AVGX", - Self::RegrAvgy => "REGR_AVGY", - Self::RegrSxx => "REGR_SXX", - Self::RegrSyy => "REGR_SYY", - Self::RegrSxy => "REGR_SXY", - Self::StringAgg => "STRING_AGG", - Self::NthValueAgg => "NTH_VALUE_AGG", - }; - serializer.serialize_str(variant) - } -} -impl<'de> serde::Deserialize<'de> for AggregateFunction { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "MIN", - "MAX", - "SUM", - "AVG", - "COUNT", - "APPROX_DISTINCT", - "ARRAY_AGG", - "VARIANCE", - "VARIANCE_POP", - "COVARIANCE", - "COVARIANCE_POP", - "STDDEV", - "STDDEV_POP", - "CORRELATION", - "APPROX_PERCENTILE_CONT", - "APPROX_MEDIAN", - "APPROX_PERCENTILE_CONT_WITH_WEIGHT", - "GROUPING", - "MEDIAN", - "BIT_AND", - "BIT_OR", - "BIT_XOR", - "BOOL_AND", - "BOOL_OR", - "FIRST_VALUE_AGG", - "LAST_VALUE_AGG", - "REGR_SLOPE", - "REGR_INTERCEPT", - "REGR_COUNT", - "REGR_R2", - "REGR_AVGX", - "REGR_AVGY", - "REGR_SXX", - "REGR_SYY", - "REGR_SXY", - "STRING_AGG", - "NTH_VALUE_AGG", - ]; - - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = AggregateFunction; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - fn visit_i64(self, v: i64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) - }) - } - - fn visit_u64(self, v: u64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) - }) - } - - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "MIN" => Ok(AggregateFunction::Min), - "MAX" => Ok(AggregateFunction::Max), - "SUM" => Ok(AggregateFunction::Sum), - "AVG" => Ok(AggregateFunction::Avg), - "COUNT" => Ok(AggregateFunction::Count), - "APPROX_DISTINCT" => Ok(AggregateFunction::ApproxDistinct), - "ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg), - "VARIANCE" => Ok(AggregateFunction::Variance), - "VARIANCE_POP" => Ok(AggregateFunction::VariancePop), - "COVARIANCE" => Ok(AggregateFunction::Covariance), - "COVARIANCE_POP" => Ok(AggregateFunction::CovariancePop), - "STDDEV" => Ok(AggregateFunction::Stddev), - "STDDEV_POP" => Ok(AggregateFunction::StddevPop), - "CORRELATION" => Ok(AggregateFunction::Correlation), - "APPROX_PERCENTILE_CONT" => Ok(AggregateFunction::ApproxPercentileCont), - "APPROX_MEDIAN" => Ok(AggregateFunction::ApproxMedian), - "APPROX_PERCENTILE_CONT_WITH_WEIGHT" => Ok(AggregateFunction::ApproxPercentileContWithWeight), - "GROUPING" => Ok(AggregateFunction::Grouping), - "MEDIAN" => Ok(AggregateFunction::Median), - "BIT_AND" => Ok(AggregateFunction::BitAnd), - "BIT_OR" => Ok(AggregateFunction::BitOr), - "BIT_XOR" => Ok(AggregateFunction::BitXor), - "BOOL_AND" => Ok(AggregateFunction::BoolAnd), - "BOOL_OR" => Ok(AggregateFunction::BoolOr), - "FIRST_VALUE_AGG" => Ok(AggregateFunction::FirstValueAgg), - "LAST_VALUE_AGG" => Ok(AggregateFunction::LastValueAgg), - "REGR_SLOPE" => Ok(AggregateFunction::RegrSlope), - "REGR_INTERCEPT" => Ok(AggregateFunction::RegrIntercept), - "REGR_COUNT" => Ok(AggregateFunction::RegrCount), - "REGR_R2" => Ok(AggregateFunction::RegrR2), - "REGR_AVGX" => Ok(AggregateFunction::RegrAvgx), - "REGR_AVGY" => Ok(AggregateFunction::RegrAvgy), - "REGR_SXX" => Ok(AggregateFunction::RegrSxx), - "REGR_SYY" => Ok(AggregateFunction::RegrSyy), - "REGR_SXY" => Ok(AggregateFunction::RegrSxy), - "STRING_AGG" => Ok(AggregateFunction::StringAgg), - "NTH_VALUE_AGG" => Ok(AggregateFunction::NthValueAgg), - _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), - } - } - } - deserializer.deserialize_any(GeneratedVisitor) - } -} -impl serde::Serialize for AggregateMode { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - let variant = match self { - Self::Partial => "PARTIAL", - Self::Final => "FINAL", - Self::FinalPartitioned => "FINAL_PARTITIONED", - Self::Single => "SINGLE", - Self::SinglePartitioned => "SINGLE_PARTITIONED", - }; - serializer.serialize_str(variant) - } -} -impl<'de> serde::Deserialize<'de> for AggregateMode { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "PARTIAL", - "FINAL", - "FINAL_PARTITIONED", - "SINGLE", - "SINGLE_PARTITIONED", - ]; - - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = AggregateMode; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - fn visit_i64(self, v: i64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) - }) - } - - fn visit_u64(self, v: u64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) - }) - } - - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "PARTIAL" => Ok(AggregateMode::Partial), - "FINAL" => Ok(AggregateMode::Final), - "FINAL_PARTITIONED" => Ok(AggregateMode::FinalPartitioned), - "SINGLE" => Ok(AggregateMode::Single), - "SINGLE_PARTITIONED" => Ok(AggregateMode::SinglePartitioned), - _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), - } - } - } - deserializer.deserialize_any(GeneratedVisitor) + deserializer.deserialize_struct("datafusion.AggregateNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for AggregateNode { +impl serde::Serialize for AggregateUdfExprNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -678,47 +578,74 @@ impl serde::Serialize for AggregateNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.input.is_some() { + if !self.fun_name.is_empty() { len += 1; } - if !self.group_expr.is_empty() { + if !self.args.is_empty() { len += 1; } - if !self.aggr_expr.is_empty() { + if self.distinct { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.AggregateNode", len)?; - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; + if self.filter.is_some() { + len += 1; } - if !self.group_expr.is_empty() { - struct_ser.serialize_field("groupExpr", &self.group_expr)?; + if !self.order_by.is_empty() { + len += 1; } - if !self.aggr_expr.is_empty() { - struct_ser.serialize_field("aggrExpr", &self.aggr_expr)?; + if self.fun_definition.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.AggregateUDFExprNode", len)?; + if !self.fun_name.is_empty() { + struct_ser.serialize_field("funName", &self.fun_name)?; + } + if !self.args.is_empty() { + struct_ser.serialize_field("args", &self.args)?; + } + if self.distinct { + struct_ser.serialize_field("distinct", &self.distinct)?; + } + if let Some(v) = self.filter.as_ref() { + struct_ser.serialize_field("filter", v)?; + } + if !self.order_by.is_empty() { + struct_ser.serialize_field("orderBy", &self.order_by)?; + } + if let Some(v) = self.fun_definition.as_ref() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("funDefinition", pbjson::private::base64::encode(&v).as_str())?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for AggregateNode { +impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "input", - "group_expr", - "groupExpr", - "aggr_expr", - "aggrExpr", + "fun_name", + "funName", + "args", + "distinct", + "filter", + "order_by", + "orderBy", + "fun_definition", + "funDefinition", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Input, - GroupExpr, - AggrExpr, + FunName, + Args, + Distinct, + Filter, + OrderBy, + FunDefinition, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -740,9 +667,12 @@ impl<'de> serde::Deserialize<'de> for AggregateNode { E: serde::de::Error, { match value { - "input" => Ok(GeneratedField::Input), - "groupExpr" | "group_expr" => Ok(GeneratedField::GroupExpr), - "aggrExpr" | "aggr_expr" => Ok(GeneratedField::AggrExpr), + "funName" | "fun_name" => Ok(GeneratedField::FunName), + "args" => Ok(GeneratedField::Args), + "distinct" => Ok(GeneratedField::Distinct), + "filter" => Ok(GeneratedField::Filter), + "orderBy" | "order_by" => Ok(GeneratedField::OrderBy), + "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -752,156 +682,22 @@ impl<'de> serde::Deserialize<'de> for AggregateNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = AggregateNode; + type Value = AggregateUdfExprNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.AggregateNode") + formatter.write_str("struct datafusion.AggregateUDFExprNode") } - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut input__ = None; - let mut group_expr__ = None; - let mut aggr_expr__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); - } - input__ = map_.next_value()?; - } - GeneratedField::GroupExpr => { - if group_expr__.is_some() { - return Err(serde::de::Error::duplicate_field("groupExpr")); - } - group_expr__ = Some(map_.next_value()?); - } - GeneratedField::AggrExpr => { - if aggr_expr__.is_some() { - return Err(serde::de::Error::duplicate_field("aggrExpr")); - } - aggr_expr__ = Some(map_.next_value()?); - } - } - } - Ok(AggregateNode { - input: input__, - group_expr: group_expr__.unwrap_or_default(), - aggr_expr: aggr_expr__.unwrap_or_default(), - }) - } - } - deserializer.deserialize_struct("datafusion.AggregateNode", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for AggregateUdfExprNode { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if !self.fun_name.is_empty() { - len += 1; - } - if !self.args.is_empty() { - len += 1; - } - if self.filter.is_some() { - len += 1; - } - if !self.order_by.is_empty() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.AggregateUDFExprNode", len)?; - if !self.fun_name.is_empty() { - struct_ser.serialize_field("funName", &self.fun_name)?; - } - if !self.args.is_empty() { - struct_ser.serialize_field("args", &self.args)?; - } - if let Some(v) = self.filter.as_ref() { - struct_ser.serialize_field("filter", v)?; - } - if !self.order_by.is_empty() { - struct_ser.serialize_field("orderBy", &self.order_by)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "fun_name", - "funName", - "args", - "filter", - "order_by", - "orderBy", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - FunName, - Args, - Filter, - OrderBy, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "funName" | "fun_name" => Ok(GeneratedField::FunName), - "args" => Ok(GeneratedField::Args), - "filter" => Ok(GeneratedField::Filter), - "orderBy" | "order_by" => Ok(GeneratedField::OrderBy), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = AggregateUdfExprNode; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.AggregateUDFExprNode") - } - - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut fun_name__ = None; let mut args__ = None; + let mut distinct__ = None; let mut filter__ = None; let mut order_by__ = None; + let mut fun_definition__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::FunName => { @@ -916,6 +712,12 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { } args__ = Some(map_.next_value()?); } + GeneratedField::Distinct => { + if distinct__.is_some() { + return Err(serde::de::Error::duplicate_field("distinct")); + } + distinct__ = Some(map_.next_value()?); + } GeneratedField::Filter => { if filter__.is_some() { return Err(serde::de::Error::duplicate_field("filter")); @@ -928,13 +730,23 @@ impl<'de> serde::Deserialize<'de> for AggregateUdfExprNode { } order_by__ = Some(map_.next_value()?); } + GeneratedField::FunDefinition => { + if fun_definition__.is_some() { + return Err(serde::de::Error::duplicate_field("funDefinition")); + } + fun_definition__ = + map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) + ; + } } } Ok(AggregateUdfExprNode { fun_name: fun_name__.unwrap_or_default(), args: args__.unwrap_or_default(), + distinct: distinct__.unwrap_or_default(), filter: filter__, order_by: order_by__.unwrap_or_default(), + fun_definition: fun_definition__, }) } } @@ -1409,29 +1221,38 @@ impl<'de> serde::Deserialize<'de> for AnalyzedLogicalPlanType { deserializer.deserialize_struct("datafusion.AnalyzedLogicalPlanType", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ArrowOptions { +impl serde::Serialize for AvroScanExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where S: serde::Serializer, { use serde::ser::SerializeStruct; - let len = 0; - let struct_ser = serializer.serialize_struct("datafusion.ArrowOptions", len)?; + let mut len = 0; + if self.base_conf.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.AvroScanExecNode", len)?; + if let Some(v) = self.base_conf.as_ref() { + struct_ser.serialize_field("baseConf", v)?; + } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ArrowOptions { +impl<'de> serde::Deserialize<'de> for AvroScanExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ + "base_conf", + "baseConf", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { + BaseConf, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -1452,7 +1273,10 @@ impl<'de> serde::Deserialize<'de> for ArrowOptions { where E: serde::de::Error, { - Err(serde::de::Error::unknown_field(value, FIELDS)) + match value { + "baseConf" | "base_conf" => Ok(GeneratedField::BaseConf), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } } } deserializer.deserialize_identifier(GeneratedVisitor) @@ -1460,27 +1284,36 @@ impl<'de> serde::Deserialize<'de> for ArrowOptions { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ArrowOptions; + type Value = AvroScanExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ArrowOptions") + formatter.write_str("struct datafusion.AvroScanExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - while map_.next_key::()?.is_some() { - let _ = map_.next_value::()?; + let mut base_conf__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::BaseConf => { + if base_conf__.is_some() { + return Err(serde::de::Error::duplicate_field("baseConf")); + } + base_conf__ = map_.next_value()?; + } + } } - Ok(ArrowOptions { + Ok(AvroScanExecNode { + base_conf: base_conf__, }) } } - deserializer.deserialize_struct("datafusion.ArrowOptions", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.AvroScanExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ArrowType { +impl serde::Serialize for BareTableReference { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -1488,206 +1321,29 @@ impl serde::Serialize for ArrowType { { use serde::ser::SerializeStruct; let mut len = 0; - if self.arrow_type_enum.is_some() { + if !self.table.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ArrowType", len)?; - if let Some(v) = self.arrow_type_enum.as_ref() { - match v { - arrow_type::ArrowTypeEnum::None(v) => { - struct_ser.serialize_field("NONE", v)?; - } - arrow_type::ArrowTypeEnum::Bool(v) => { - struct_ser.serialize_field("BOOL", v)?; - } - arrow_type::ArrowTypeEnum::Uint8(v) => { - struct_ser.serialize_field("UINT8", v)?; - } - arrow_type::ArrowTypeEnum::Int8(v) => { - struct_ser.serialize_field("INT8", v)?; - } - arrow_type::ArrowTypeEnum::Uint16(v) => { - struct_ser.serialize_field("UINT16", v)?; - } - arrow_type::ArrowTypeEnum::Int16(v) => { - struct_ser.serialize_field("INT16", v)?; - } - arrow_type::ArrowTypeEnum::Uint32(v) => { - struct_ser.serialize_field("UINT32", v)?; - } - arrow_type::ArrowTypeEnum::Int32(v) => { - struct_ser.serialize_field("INT32", v)?; - } - arrow_type::ArrowTypeEnum::Uint64(v) => { - struct_ser.serialize_field("UINT64", v)?; - } - arrow_type::ArrowTypeEnum::Int64(v) => { - struct_ser.serialize_field("INT64", v)?; - } - arrow_type::ArrowTypeEnum::Float16(v) => { - struct_ser.serialize_field("FLOAT16", v)?; - } - arrow_type::ArrowTypeEnum::Float32(v) => { - struct_ser.serialize_field("FLOAT32", v)?; - } - arrow_type::ArrowTypeEnum::Float64(v) => { - struct_ser.serialize_field("FLOAT64", v)?; - } - arrow_type::ArrowTypeEnum::Utf8(v) => { - struct_ser.serialize_field("UTF8", v)?; - } - arrow_type::ArrowTypeEnum::LargeUtf8(v) => { - struct_ser.serialize_field("LARGEUTF8", v)?; - } - arrow_type::ArrowTypeEnum::Binary(v) => { - struct_ser.serialize_field("BINARY", v)?; - } - arrow_type::ArrowTypeEnum::FixedSizeBinary(v) => { - struct_ser.serialize_field("FIXEDSIZEBINARY", v)?; - } - arrow_type::ArrowTypeEnum::LargeBinary(v) => { - struct_ser.serialize_field("LARGEBINARY", v)?; - } - arrow_type::ArrowTypeEnum::Date32(v) => { - struct_ser.serialize_field("DATE32", v)?; - } - arrow_type::ArrowTypeEnum::Date64(v) => { - struct_ser.serialize_field("DATE64", v)?; - } - arrow_type::ArrowTypeEnum::Duration(v) => { - let v = TimeUnit::try_from(*v) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; - struct_ser.serialize_field("DURATION", &v)?; - } - arrow_type::ArrowTypeEnum::Timestamp(v) => { - struct_ser.serialize_field("TIMESTAMP", v)?; - } - arrow_type::ArrowTypeEnum::Time32(v) => { - let v = TimeUnit::try_from(*v) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; - struct_ser.serialize_field("TIME32", &v)?; - } - arrow_type::ArrowTypeEnum::Time64(v) => { - let v = TimeUnit::try_from(*v) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; - struct_ser.serialize_field("TIME64", &v)?; - } - arrow_type::ArrowTypeEnum::Interval(v) => { - let v = IntervalUnit::try_from(*v) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; - struct_ser.serialize_field("INTERVAL", &v)?; - } - arrow_type::ArrowTypeEnum::Decimal(v) => { - struct_ser.serialize_field("DECIMAL", v)?; - } - arrow_type::ArrowTypeEnum::List(v) => { - struct_ser.serialize_field("LIST", v)?; - } - arrow_type::ArrowTypeEnum::LargeList(v) => { - struct_ser.serialize_field("LARGELIST", v)?; - } - arrow_type::ArrowTypeEnum::FixedSizeList(v) => { - struct_ser.serialize_field("FIXEDSIZELIST", v)?; - } - arrow_type::ArrowTypeEnum::Struct(v) => { - struct_ser.serialize_field("STRUCT", v)?; - } - arrow_type::ArrowTypeEnum::Union(v) => { - struct_ser.serialize_field("UNION", v)?; - } - arrow_type::ArrowTypeEnum::Dictionary(v) => { - struct_ser.serialize_field("DICTIONARY", v)?; - } - arrow_type::ArrowTypeEnum::Map(v) => { - struct_ser.serialize_field("MAP", v)?; - } - } + let mut struct_ser = serializer.serialize_struct("datafusion.BareTableReference", len)?; + if !self.table.is_empty() { + struct_ser.serialize_field("table", &self.table)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ArrowType { +impl<'de> serde::Deserialize<'de> for BareTableReference { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "NONE", - "BOOL", - "UINT8", - "INT8", - "UINT16", - "INT16", - "UINT32", - "INT32", - "UINT64", - "INT64", - "FLOAT16", - "FLOAT32", - "FLOAT64", - "UTF8", - "LARGE_UTF8", - "LARGEUTF8", - "BINARY", - "FIXED_SIZE_BINARY", - "FIXEDSIZEBINARY", - "LARGE_BINARY", - "LARGEBINARY", - "DATE32", - "DATE64", - "DURATION", - "TIMESTAMP", - "TIME32", - "TIME64", - "INTERVAL", - "DECIMAL", - "LIST", - "LARGE_LIST", - "LARGELIST", - "FIXED_SIZE_LIST", - "FIXEDSIZELIST", - "STRUCT", - "UNION", - "DICTIONARY", - "MAP", + "table", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - None, - Bool, - Uint8, - Int8, - Uint16, - Int16, - Uint32, - Int32, - Uint64, - Int64, - Float16, - Float32, - Float64, - Utf8, - LargeUtf8, - Binary, - FixedSizeBinary, - LargeBinary, - Date32, - Date64, - Duration, - Timestamp, - Time32, - Time64, - Interval, - Decimal, - List, - LargeList, - FixedSizeList, - Struct, - Union, - Dictionary, - Map, + Table, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -1709,39 +1365,7 @@ impl<'de> serde::Deserialize<'de> for ArrowType { E: serde::de::Error, { match value { - "NONE" => Ok(GeneratedField::None), - "BOOL" => Ok(GeneratedField::Bool), - "UINT8" => Ok(GeneratedField::Uint8), - "INT8" => Ok(GeneratedField::Int8), - "UINT16" => Ok(GeneratedField::Uint16), - "INT16" => Ok(GeneratedField::Int16), - "UINT32" => Ok(GeneratedField::Uint32), - "INT32" => Ok(GeneratedField::Int32), - "UINT64" => Ok(GeneratedField::Uint64), - "INT64" => Ok(GeneratedField::Int64), - "FLOAT16" => Ok(GeneratedField::Float16), - "FLOAT32" => Ok(GeneratedField::Float32), - "FLOAT64" => Ok(GeneratedField::Float64), - "UTF8" => Ok(GeneratedField::Utf8), - "LARGEUTF8" | "LARGE_UTF8" => Ok(GeneratedField::LargeUtf8), - "BINARY" => Ok(GeneratedField::Binary), - "FIXEDSIZEBINARY" | "FIXED_SIZE_BINARY" => Ok(GeneratedField::FixedSizeBinary), - "LARGEBINARY" | "LARGE_BINARY" => Ok(GeneratedField::LargeBinary), - "DATE32" => Ok(GeneratedField::Date32), - "DATE64" => Ok(GeneratedField::Date64), - "DURATION" => Ok(GeneratedField::Duration), - "TIMESTAMP" => Ok(GeneratedField::Timestamp), - "TIME32" => Ok(GeneratedField::Time32), - "TIME64" => Ok(GeneratedField::Time64), - "INTERVAL" => Ok(GeneratedField::Interval), - "DECIMAL" => Ok(GeneratedField::Decimal), - "LIST" => Ok(GeneratedField::List), - "LARGELIST" | "LARGE_LIST" => Ok(GeneratedField::LargeList), - "FIXEDSIZELIST" | "FIXED_SIZE_LIST" => Ok(GeneratedField::FixedSizeList), - "STRUCT" => Ok(GeneratedField::Struct), - "UNION" => Ok(GeneratedField::Union), - "DICTIONARY" => Ok(GeneratedField::Dictionary), - "MAP" => Ok(GeneratedField::Map), + "table" => Ok(GeneratedField::Table), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -1751,278 +1375,90 @@ impl<'de> serde::Deserialize<'de> for ArrowType { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ArrowType; + type Value = BareTableReference; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ArrowType") + formatter.write_str("struct datafusion.BareTableReference") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut arrow_type_enum__ = None; + let mut table__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::None => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("NONE")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::None) -; - } - GeneratedField::Bool => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("BOOL")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Bool) -; - } - GeneratedField::Uint8 => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("UINT8")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Uint8) -; - } - GeneratedField::Int8 => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("INT8")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Int8) -; - } - GeneratedField::Uint16 => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("UINT16")); + GeneratedField::Table => { + if table__.is_some() { + return Err(serde::de::Error::duplicate_field("table")); } - arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Uint16) -; - } - GeneratedField::Int16 => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("INT16")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Int16) -; - } - GeneratedField::Uint32 => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("UINT32")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Uint32) -; - } - GeneratedField::Int32 => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("INT32")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Int32) -; - } - GeneratedField::Uint64 => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("UINT64")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Uint64) -; - } - GeneratedField::Int64 => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("INT64")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Int64) -; - } - GeneratedField::Float16 => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("FLOAT16")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Float16) -; - } - GeneratedField::Float32 => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("FLOAT32")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Float32) -; - } - GeneratedField::Float64 => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("FLOAT64")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Float64) -; - } - GeneratedField::Utf8 => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("UTF8")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Utf8) -; - } - GeneratedField::LargeUtf8 => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("LARGEUTF8")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::LargeUtf8) -; - } - GeneratedField::Binary => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("BINARY")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Binary) -; - } - GeneratedField::FixedSizeBinary => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("FIXEDSIZEBINARY")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| arrow_type::ArrowTypeEnum::FixedSizeBinary(x.0)); - } - GeneratedField::LargeBinary => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("LARGEBINARY")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::LargeBinary) -; - } - GeneratedField::Date32 => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("DATE32")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Date32) -; - } - GeneratedField::Date64 => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("DATE64")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Date64) -; - } - GeneratedField::Duration => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("DURATION")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option>()?.map(|x| arrow_type::ArrowTypeEnum::Duration(x as i32)); - } - GeneratedField::Timestamp => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("TIMESTAMP")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Timestamp) -; - } - GeneratedField::Time32 => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("TIME32")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option>()?.map(|x| arrow_type::ArrowTypeEnum::Time32(x as i32)); - } - GeneratedField::Time64 => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("TIME64")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option>()?.map(|x| arrow_type::ArrowTypeEnum::Time64(x as i32)); - } - GeneratedField::Interval => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("INTERVAL")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option>()?.map(|x| arrow_type::ArrowTypeEnum::Interval(x as i32)); - } - GeneratedField::Decimal => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("DECIMAL")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Decimal) -; - } - GeneratedField::List => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("LIST")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::List) -; - } - GeneratedField::LargeList => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("LARGELIST")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::LargeList) -; - } - GeneratedField::FixedSizeList => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("FIXEDSIZELIST")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::FixedSizeList) -; - } - GeneratedField::Struct => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("STRUCT")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Struct) -; - } - GeneratedField::Union => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("UNION")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Union) -; - } - GeneratedField::Dictionary => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("DICTIONARY")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Dictionary) -; - } - GeneratedField::Map => { - if arrow_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("MAP")); - } - arrow_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(arrow_type::ArrowTypeEnum::Map) -; + table__ = Some(map_.next_value()?); } } } - Ok(ArrowType { - arrow_type_enum: arrow_type_enum__, + Ok(BareTableReference { + table: table__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.ArrowType", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.BareTableReference", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for AvroFormat { +impl serde::Serialize for BetweenNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where S: serde::Serializer, { use serde::ser::SerializeStruct; - let len = 0; - let struct_ser = serializer.serialize_struct("datafusion.AvroFormat", len)?; + let mut len = 0; + if self.expr.is_some() { + len += 1; + } + if self.negated { + len += 1; + } + if self.low.is_some() { + len += 1; + } + if self.high.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.BetweenNode", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; + } + if self.negated { + struct_ser.serialize_field("negated", &self.negated)?; + } + if let Some(v) = self.low.as_ref() { + struct_ser.serialize_field("low", v)?; + } + if let Some(v) = self.high.as_ref() { + struct_ser.serialize_field("high", v)?; + } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for AvroFormat { +impl<'de> serde::Deserialize<'de> for BetweenNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ + "expr", + "negated", + "low", + "high", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { + Expr, + Negated, + Low, + High, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -2043,7 +1479,13 @@ impl<'de> serde::Deserialize<'de> for AvroFormat { where E: serde::de::Error, { - Err(serde::de::Error::unknown_field(value, FIELDS)) + match value { + "expr" => Ok(GeneratedField::Expr), + "negated" => Ok(GeneratedField::Negated), + "low" => Ok(GeneratedField::Low), + "high" => Ok(GeneratedField::High), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } } } deserializer.deserialize_identifier(GeneratedVisitor) @@ -2051,49 +1493,98 @@ impl<'de> serde::Deserialize<'de> for AvroFormat { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = AvroFormat; + type Value = BetweenNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.AvroFormat") + formatter.write_str("struct datafusion.BetweenNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - while map_.next_key::()?.is_some() { - let _ = map_.next_value::()?; + let mut expr__ = None; + let mut negated__ = None; + let mut low__ = None; + let mut high__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); + } + expr__ = map_.next_value()?; + } + GeneratedField::Negated => { + if negated__.is_some() { + return Err(serde::de::Error::duplicate_field("negated")); + } + negated__ = Some(map_.next_value()?); + } + GeneratedField::Low => { + if low__.is_some() { + return Err(serde::de::Error::duplicate_field("low")); + } + low__ = map_.next_value()?; + } + GeneratedField::High => { + if high__.is_some() { + return Err(serde::de::Error::duplicate_field("high")); + } + high__ = map_.next_value()?; + } + } } - Ok(AvroFormat { + Ok(BetweenNode { + expr: expr__, + negated: negated__.unwrap_or_default(), + low: low__, + high: high__, }) } } - deserializer.deserialize_struct("datafusion.AvroFormat", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.BetweenNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for AvroOptions { +impl serde::Serialize for BinaryExprNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where S: serde::Serializer, { use serde::ser::SerializeStruct; - let len = 0; - let struct_ser = serializer.serialize_struct("datafusion.AvroOptions", len)?; + let mut len = 0; + if !self.operands.is_empty() { + len += 1; + } + if !self.op.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.BinaryExprNode", len)?; + if !self.operands.is_empty() { + struct_ser.serialize_field("operands", &self.operands)?; + } + if !self.op.is_empty() { + struct_ser.serialize_field("op", &self.op)?; + } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for AvroOptions { +impl<'de> serde::Deserialize<'de> for BinaryExprNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ + "operands", + "op", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { + Operands, + Op, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -2114,7 +1605,11 @@ impl<'de> serde::Deserialize<'de> for AvroOptions { where E: serde::de::Error, { - Err(serde::de::Error::unknown_field(value, FIELDS)) + match value { + "operands" => Ok(GeneratedField::Operands), + "op" => Ok(GeneratedField::Op), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } } } deserializer.deserialize_identifier(GeneratedVisitor) @@ -2122,119 +1617,121 @@ impl<'de> serde::Deserialize<'de> for AvroOptions { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = AvroOptions; + type Value = BinaryExprNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.AvroOptions") + formatter.write_str("struct datafusion.BinaryExprNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - while map_.next_key::()?.is_some() { - let _ = map_.next_value::()?; + let mut operands__ = None; + let mut op__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Operands => { + if operands__.is_some() { + return Err(serde::de::Error::duplicate_field("operands")); + } + operands__ = Some(map_.next_value()?); + } + GeneratedField::Op => { + if op__.is_some() { + return Err(serde::de::Error::duplicate_field("op")); + } + op__ = Some(map_.next_value()?); + } + } } - Ok(AvroOptions { + Ok(BinaryExprNode { + operands: operands__.unwrap_or_default(), + op: op__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.AvroOptions", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.BinaryExprNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for AvroScanExecNode { +impl serde::Serialize for BuiltInWindowFunction { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where S: serde::Serializer, { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.base_conf.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.AvroScanExecNode", len)?; - if let Some(v) = self.base_conf.as_ref() { - struct_ser.serialize_field("baseConf", v)?; - } - struct_ser.end() + let variant = match self { + Self::Unspecified => "UNSPECIFIED", + Self::FirstValue => "FIRST_VALUE", + Self::LastValue => "LAST_VALUE", + Self::NthValue => "NTH_VALUE", + }; + serializer.serialize_str(variant) } } -impl<'de> serde::Deserialize<'de> for AvroScanExecNode { +impl<'de> serde::Deserialize<'de> for BuiltInWindowFunction { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "base_conf", - "baseConf", + "UNSPECIFIED", + "FIRST_VALUE", + "LAST_VALUE", + "NTH_VALUE", ]; - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - BaseConf, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; + struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = BuiltInWindowFunction; - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "baseConf" | "base_conf" => Ok(GeneratedField::BaseConf), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = AvroScanExecNode; - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.AvroScanExecNode") + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) } - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, { - let mut base_conf__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::BaseConf => { - if base_conf__.is_some() { - return Err(serde::de::Error::duplicate_field("baseConf")); - } - base_conf__ = map_.next_value()?; - } - } + match value { + "UNSPECIFIED" => Ok(BuiltInWindowFunction::Unspecified), + "FIRST_VALUE" => Ok(BuiltInWindowFunction::FirstValue), + "LAST_VALUE" => Ok(BuiltInWindowFunction::LastValue), + "NTH_VALUE" => Ok(BuiltInWindowFunction::NthValue), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } - Ok(AvroScanExecNode { - base_conf: base_conf__, - }) } } - deserializer.deserialize_struct("datafusion.AvroScanExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_any(GeneratedVisitor) } } -impl serde::Serialize for BareTableReference { +impl serde::Serialize for CaseNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -2242,29 +1739,47 @@ impl serde::Serialize for BareTableReference { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.table.is_empty() { + if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.BareTableReference", len)?; - if !self.table.is_empty() { - struct_ser.serialize_field("table", &self.table)?; + if !self.when_then_expr.is_empty() { + len += 1; + } + if self.else_expr.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CaseNode", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; + } + if !self.when_then_expr.is_empty() { + struct_ser.serialize_field("whenThenExpr", &self.when_then_expr)?; + } + if let Some(v) = self.else_expr.as_ref() { + struct_ser.serialize_field("elseExpr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for BareTableReference { +impl<'de> serde::Deserialize<'de> for CaseNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "table", + "expr", + "when_then_expr", + "whenThenExpr", + "else_expr", + "elseExpr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Table, + Expr, + WhenThenExpr, + ElseExpr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -2286,7 +1801,9 @@ impl<'de> serde::Deserialize<'de> for BareTableReference { E: serde::de::Error, { match value { - "table" => Ok(GeneratedField::Table), + "expr" => Ok(GeneratedField::Expr), + "whenThenExpr" | "when_then_expr" => Ok(GeneratedField::WhenThenExpr), + "elseExpr" | "else_expr" => Ok(GeneratedField::ElseExpr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -2296,36 +1813,52 @@ impl<'de> serde::Deserialize<'de> for BareTableReference { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = BareTableReference; + type Value = CaseNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.BareTableReference") + formatter.write_str("struct datafusion.CaseNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut table__ = None; + let mut expr__ = None; + let mut when_then_expr__ = None; + let mut else_expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Table => { - if table__.is_some() { - return Err(serde::de::Error::duplicate_field("table")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - table__ = Some(map_.next_value()?); + expr__ = map_.next_value()?; + } + GeneratedField::WhenThenExpr => { + if when_then_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("whenThenExpr")); + } + when_then_expr__ = Some(map_.next_value()?); + } + GeneratedField::ElseExpr => { + if else_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("elseExpr")); + } + else_expr__ = map_.next_value()?; } } } - Ok(BareTableReference { - table: table__.unwrap_or_default(), + Ok(CaseNode { + expr: expr__, + when_then_expr: when_then_expr__.unwrap_or_default(), + else_expr: else_expr__, }) } } - deserializer.deserialize_struct("datafusion.BareTableReference", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CaseNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for BetweenNode { +impl serde::Serialize for CastNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -2336,32 +1869,20 @@ impl serde::Serialize for BetweenNode { if self.expr.is_some() { len += 1; } - if self.negated { - len += 1; - } - if self.low.is_some() { - len += 1; - } - if self.high.is_some() { + if self.arrow_type.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.BetweenNode", len)?; + let mut struct_ser = serializer.serialize_struct("datafusion.CastNode", len)?; if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; } - if self.negated { - struct_ser.serialize_field("negated", &self.negated)?; - } - if let Some(v) = self.low.as_ref() { - struct_ser.serialize_field("low", v)?; - } - if let Some(v) = self.high.as_ref() { - struct_ser.serialize_field("high", v)?; + if let Some(v) = self.arrow_type.as_ref() { + struct_ser.serialize_field("arrowType", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for BetweenNode { +impl<'de> serde::Deserialize<'de> for CastNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where @@ -2369,17 +1890,14 @@ impl<'de> serde::Deserialize<'de> for BetweenNode { { const FIELDS: &[&str] = &[ "expr", - "negated", - "low", - "high", + "arrow_type", + "arrowType", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Expr, - Negated, - Low, - High, + ArrowType, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -2402,9 +1920,7 @@ impl<'de> serde::Deserialize<'de> for BetweenNode { { match value { "expr" => Ok(GeneratedField::Expr), - "negated" => Ok(GeneratedField::Negated), - "low" => Ok(GeneratedField::Low), - "high" => Ok(GeneratedField::High), + "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -2414,20 +1930,18 @@ impl<'de> serde::Deserialize<'de> for BetweenNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = BetweenNode; + type Value = CastNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.BetweenNode") + formatter.write_str("struct datafusion.CastNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; - let mut negated__ = None; - let mut low__ = None; - let mut high__ = None; + let mut arrow_type__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { @@ -2436,38 +1950,24 @@ impl<'de> serde::Deserialize<'de> for BetweenNode { } expr__ = map_.next_value()?; } - GeneratedField::Negated => { - if negated__.is_some() { - return Err(serde::de::Error::duplicate_field("negated")); - } - negated__ = Some(map_.next_value()?); - } - GeneratedField::Low => { - if low__.is_some() { - return Err(serde::de::Error::duplicate_field("low")); - } - low__ = map_.next_value()?; - } - GeneratedField::High => { - if high__.is_some() { - return Err(serde::de::Error::duplicate_field("high")); + GeneratedField::ArrowType => { + if arrow_type__.is_some() { + return Err(serde::de::Error::duplicate_field("arrowType")); } - high__ = map_.next_value()?; + arrow_type__ = map_.next_value()?; } } } - Ok(BetweenNode { + Ok(CastNode { expr: expr__, - negated: negated__.unwrap_or_default(), - low: low__, - high: high__, + arrow_type: arrow_type__, }) } } - deserializer.deserialize_struct("datafusion.BetweenNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CastNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for BinaryExprNode { +impl serde::Serialize for CoalesceBatchesExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -2475,37 +1975,46 @@ impl serde::Serialize for BinaryExprNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.operands.is_empty() { + if self.input.is_some() { len += 1; } - if !self.op.is_empty() { + if self.target_batch_size != 0 { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.BinaryExprNode", len)?; - if !self.operands.is_empty() { - struct_ser.serialize_field("operands", &self.operands)?; + if self.fetch.is_some() { + len += 1; } - if !self.op.is_empty() { - struct_ser.serialize_field("op", &self.op)?; + let mut struct_ser = serializer.serialize_struct("datafusion.CoalesceBatchesExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if self.target_batch_size != 0 { + struct_ser.serialize_field("targetBatchSize", &self.target_batch_size)?; + } + if let Some(v) = self.fetch.as_ref() { + struct_ser.serialize_field("fetch", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for BinaryExprNode { +impl<'de> serde::Deserialize<'de> for CoalesceBatchesExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "operands", - "op", + "input", + "target_batch_size", + "targetBatchSize", + "fetch", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Operands, - Op, + Input, + TargetBatchSize, + Fetch, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -2527,9 +2036,10 @@ impl<'de> serde::Deserialize<'de> for BinaryExprNode { E: serde::de::Error, { match value { - "operands" => Ok(GeneratedField::Operands), - "op" => Ok(GeneratedField::Op), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + "input" => Ok(GeneratedField::Input), + "targetBatchSize" | "target_batch_size" => Ok(GeneratedField::TargetBatchSize), + "fetch" => Ok(GeneratedField::Fetch), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } } @@ -2538,142 +2048,147 @@ impl<'de> serde::Deserialize<'de> for BinaryExprNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = BinaryExprNode; + type Value = CoalesceBatchesExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.BinaryExprNode") + formatter.write_str("struct datafusion.CoalesceBatchesExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut operands__ = None; - let mut op__ = None; + let mut input__ = None; + let mut target_batch_size__ = None; + let mut fetch__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Operands => { - if operands__.is_some() { - return Err(serde::de::Error::duplicate_field("operands")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - operands__ = Some(map_.next_value()?); + input__ = map_.next_value()?; } - GeneratedField::Op => { - if op__.is_some() { - return Err(serde::de::Error::duplicate_field("op")); + GeneratedField::TargetBatchSize => { + if target_batch_size__.is_some() { + return Err(serde::de::Error::duplicate_field("targetBatchSize")); } - op__ = Some(map_.next_value()?); + target_batch_size__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Fetch => { + if fetch__.is_some() { + return Err(serde::de::Error::duplicate_field("fetch")); + } + fetch__ = + map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| x.0) + ; } } } - Ok(BinaryExprNode { - operands: operands__.unwrap_or_default(), - op: op__.unwrap_or_default(), + Ok(CoalesceBatchesExecNode { + input: input__, + target_batch_size: target_batch_size__.unwrap_or_default(), + fetch: fetch__, }) } } - deserializer.deserialize_struct("datafusion.BinaryExprNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CoalesceBatchesExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for BuiltInWindowFunction { +impl serde::Serialize for CoalescePartitionsExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where S: serde::Serializer, { - let variant = match self { - Self::RowNumber => "ROW_NUMBER", - Self::Rank => "RANK", - Self::DenseRank => "DENSE_RANK", - Self::PercentRank => "PERCENT_RANK", - Self::CumeDist => "CUME_DIST", - Self::Ntile => "NTILE", - Self::Lag => "LAG", - Self::Lead => "LEAD", - Self::FirstValue => "FIRST_VALUE", - Self::LastValue => "LAST_VALUE", - Self::NthValue => "NTH_VALUE", - }; - serializer.serialize_str(variant) + use serde::ser::SerializeStruct; + let mut len = 0; + if self.input.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CoalescePartitionsExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for BuiltInWindowFunction { +impl<'de> serde::Deserialize<'de> for CoalescePartitionsExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "ROW_NUMBER", - "RANK", - "DENSE_RANK", - "PERCENT_RANK", - "CUME_DIST", - "NTILE", - "LAG", - "LEAD", - "FIRST_VALUE", - "LAST_VALUE", - "NTH_VALUE", + "input", ]; - struct GeneratedVisitor; + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Input, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = BuiltInWindowFunction; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } - fn visit_i64(self, v: i64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) - }) + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "input" => Ok(GeneratedField::Input), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CoalescePartitionsExecNode; - fn visit_u64(self, v: u64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) - }) + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.CoalescePartitionsExecNode") } - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, { - match value { - "ROW_NUMBER" => Ok(BuiltInWindowFunction::RowNumber), - "RANK" => Ok(BuiltInWindowFunction::Rank), - "DENSE_RANK" => Ok(BuiltInWindowFunction::DenseRank), - "PERCENT_RANK" => Ok(BuiltInWindowFunction::PercentRank), - "CUME_DIST" => Ok(BuiltInWindowFunction::CumeDist), - "NTILE" => Ok(BuiltInWindowFunction::Ntile), - "LAG" => Ok(BuiltInWindowFunction::Lag), - "LEAD" => Ok(BuiltInWindowFunction::Lead), - "FIRST_VALUE" => Ok(BuiltInWindowFunction::FirstValue), - "LAST_VALUE" => Ok(BuiltInWindowFunction::LastValue), - "NTH_VALUE" => Ok(BuiltInWindowFunction::NthValue), - _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + let mut input__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + } } + Ok(CoalescePartitionsExecNode { + input: input__, + }) } } - deserializer.deserialize_any(GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CoalescePartitionsExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CaseNode { +impl serde::Serialize for ColumnIndex { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -2681,47 +2196,39 @@ impl serde::Serialize for CaseNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { - len += 1; - } - if !self.when_then_expr.is_empty() { + if self.index != 0 { len += 1; } - if self.else_expr.is_some() { + if self.side != 0 { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CaseNode", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; - } - if !self.when_then_expr.is_empty() { - struct_ser.serialize_field("whenThenExpr", &self.when_then_expr)?; + let mut struct_ser = serializer.serialize_struct("datafusion.ColumnIndex", len)?; + if self.index != 0 { + struct_ser.serialize_field("index", &self.index)?; } - if let Some(v) = self.else_expr.as_ref() { - struct_ser.serialize_field("elseExpr", v)?; + if self.side != 0 { + let v = super::datafusion_common::JoinSide::try_from(self.side) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.side)))?; + struct_ser.serialize_field("side", &v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CaseNode { +impl<'de> serde::Deserialize<'de> for ColumnIndex { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", - "when_then_expr", - "whenThenExpr", - "else_expr", - "elseExpr", + "index", + "side", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, - WhenThenExpr, - ElseExpr, + Index, + Side, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -2743,9 +2250,8 @@ impl<'de> serde::Deserialize<'de> for CaseNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), - "whenThenExpr" | "when_then_expr" => Ok(GeneratedField::WhenThenExpr), - "elseExpr" | "else_expr" => Ok(GeneratedField::ElseExpr), + "index" => Ok(GeneratedField::Index), + "side" => Ok(GeneratedField::Side), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -2755,52 +2261,46 @@ impl<'de> serde::Deserialize<'de> for CaseNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CaseNode; + type Value = ColumnIndex; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CaseNode") + formatter.write_str("struct datafusion.ColumnIndex") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; - let mut when_then_expr__ = None; - let mut else_expr__ = None; + let mut index__ = None; + let mut side__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); - } - expr__ = map_.next_value()?; - } - GeneratedField::WhenThenExpr => { - if when_then_expr__.is_some() { - return Err(serde::de::Error::duplicate_field("whenThenExpr")); + GeneratedField::Index => { + if index__.is_some() { + return Err(serde::de::Error::duplicate_field("index")); } - when_then_expr__ = Some(map_.next_value()?); + index__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } - GeneratedField::ElseExpr => { - if else_expr__.is_some() { - return Err(serde::de::Error::duplicate_field("elseExpr")); + GeneratedField::Side => { + if side__.is_some() { + return Err(serde::de::Error::duplicate_field("side")); } - else_expr__ = map_.next_value()?; + side__ = Some(map_.next_value::()? as i32); } } } - Ok(CaseNode { - expr: expr__, - when_then_expr: when_then_expr__.unwrap_or_default(), - else_expr: else_expr__, + Ok(ColumnIndex { + index: index__.unwrap_or_default(), + side: side__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.CaseNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ColumnIndex", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CastNode { +impl serde::Serialize for ColumnUnnestListItem { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -2808,38 +2308,38 @@ impl serde::Serialize for CastNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if self.input_index != 0 { len += 1; } - if self.arrow_type.is_some() { + if self.recursion.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CastNode", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.ColumnUnnestListItem", len)?; + if self.input_index != 0 { + struct_ser.serialize_field("inputIndex", &self.input_index)?; } - if let Some(v) = self.arrow_type.as_ref() { - struct_ser.serialize_field("arrowType", v)?; + if let Some(v) = self.recursion.as_ref() { + struct_ser.serialize_field("recursion", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CastNode { +impl<'de> serde::Deserialize<'de> for ColumnUnnestListItem { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", - "arrow_type", - "arrowType", + "input_index", + "inputIndex", + "recursion", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, - ArrowType, + InputIndex, + Recursion, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -2861,8 +2361,8 @@ impl<'de> serde::Deserialize<'de> for CastNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), - "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), + "inputIndex" | "input_index" => Ok(GeneratedField::InputIndex), + "recursion" => Ok(GeneratedField::Recursion), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -2872,44 +2372,46 @@ impl<'de> serde::Deserialize<'de> for CastNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CastNode; + type Value = ColumnUnnestListItem; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CastNode") + formatter.write_str("struct datafusion.ColumnUnnestListItem") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; - let mut arrow_type__ = None; + let mut input_index__ = None; + let mut recursion__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::InputIndex => { + if input_index__.is_some() { + return Err(serde::de::Error::duplicate_field("inputIndex")); } - expr__ = map_.next_value()?; + input_index__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } - GeneratedField::ArrowType => { - if arrow_type__.is_some() { - return Err(serde::de::Error::duplicate_field("arrowType")); + GeneratedField::Recursion => { + if recursion__.is_some() { + return Err(serde::de::Error::duplicate_field("recursion")); } - arrow_type__ = map_.next_value()?; + recursion__ = map_.next_value()?; } } } - Ok(CastNode { - expr: expr__, - arrow_type: arrow_type__, + Ok(ColumnUnnestListItem { + input_index: input_index__.unwrap_or_default(), + recursion: recursion__, }) } } - deserializer.deserialize_struct("datafusion.CastNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ColumnUnnestListItem", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CoalesceBatchesExecNode { +impl serde::Serialize for ColumnUnnestListRecursion { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -2917,38 +2419,38 @@ impl serde::Serialize for CoalesceBatchesExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.input.is_some() { + if self.output_column.is_some() { len += 1; } - if self.target_batch_size != 0 { + if self.depth != 0 { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CoalesceBatchesExecNode", len)?; - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.ColumnUnnestListRecursion", len)?; + if let Some(v) = self.output_column.as_ref() { + struct_ser.serialize_field("outputColumn", v)?; } - if self.target_batch_size != 0 { - struct_ser.serialize_field("targetBatchSize", &self.target_batch_size)?; + if self.depth != 0 { + struct_ser.serialize_field("depth", &self.depth)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CoalesceBatchesExecNode { +impl<'de> serde::Deserialize<'de> for ColumnUnnestListRecursion { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "input", - "target_batch_size", - "targetBatchSize", + "output_column", + "outputColumn", + "depth", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Input, - TargetBatchSize, + OutputColumn, + Depth, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -2970,8 +2472,8 @@ impl<'de> serde::Deserialize<'de> for CoalesceBatchesExecNode { E: serde::de::Error, { match value { - "input" => Ok(GeneratedField::Input), - "targetBatchSize" | "target_batch_size" => Ok(GeneratedField::TargetBatchSize), + "outputColumn" | "output_column" => Ok(GeneratedField::OutputColumn), + "depth" => Ok(GeneratedField::Depth), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -2981,46 +2483,46 @@ impl<'de> serde::Deserialize<'de> for CoalesceBatchesExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CoalesceBatchesExecNode; + type Value = ColumnUnnestListRecursion; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CoalesceBatchesExecNode") + formatter.write_str("struct datafusion.ColumnUnnestListRecursion") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut input__ = None; - let mut target_batch_size__ = None; + let mut output_column__ = None; + let mut depth__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); + GeneratedField::OutputColumn => { + if output_column__.is_some() { + return Err(serde::de::Error::duplicate_field("outputColumn")); } - input__ = map_.next_value()?; + output_column__ = map_.next_value()?; } - GeneratedField::TargetBatchSize => { - if target_batch_size__.is_some() { - return Err(serde::de::Error::duplicate_field("targetBatchSize")); + GeneratedField::Depth => { + if depth__.is_some() { + return Err(serde::de::Error::duplicate_field("depth")); } - target_batch_size__ = + depth__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } } } - Ok(CoalesceBatchesExecNode { - input: input__, - target_batch_size: target_batch_size__.unwrap_or_default(), + Ok(ColumnUnnestListRecursion { + output_column: output_column__, + depth: depth__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.CoalesceBatchesExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ColumnUnnestListRecursion", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CoalescePartitionsExecNode { +impl serde::Serialize for ColumnUnnestListRecursions { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -3028,29 +2530,29 @@ impl serde::Serialize for CoalescePartitionsExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.input.is_some() { + if !self.recursions.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CoalescePartitionsExecNode", len)?; - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.ColumnUnnestListRecursions", len)?; + if !self.recursions.is_empty() { + struct_ser.serialize_field("recursions", &self.recursions)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CoalescePartitionsExecNode { +impl<'de> serde::Deserialize<'de> for ColumnUnnestListRecursions { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "input", + "recursions", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Input, + Recursions, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -3072,7 +2574,7 @@ impl<'de> serde::Deserialize<'de> for CoalescePartitionsExecNode { E: serde::de::Error, { match value { - "input" => Ok(GeneratedField::Input), + "recursions" => Ok(GeneratedField::Recursions), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -3082,36 +2584,36 @@ impl<'de> serde::Deserialize<'de> for CoalescePartitionsExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CoalescePartitionsExecNode; + type Value = ColumnUnnestListRecursions; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CoalescePartitionsExecNode") + formatter.write_str("struct datafusion.ColumnUnnestListRecursions") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut input__ = None; + let mut recursions__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); + GeneratedField::Recursions => { + if recursions__.is_some() { + return Err(serde::de::Error::duplicate_field("recursions")); } - input__ = map_.next_value()?; + recursions__ = Some(map_.next_value()?); } } } - Ok(CoalescePartitionsExecNode { - input: input__, + Ok(ColumnUnnestListRecursions { + recursions: recursions__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.CoalescePartitionsExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ColumnUnnestListRecursions", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for Column { +impl serde::Serialize for CopyToNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -3119,37 +2621,58 @@ impl serde::Serialize for Column { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.name.is_empty() { + if self.input.is_some() { len += 1; } - if self.relation.is_some() { + if !self.output_url.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.Column", len)?; - if !self.name.is_empty() { - struct_ser.serialize_field("name", &self.name)?; + if !self.file_type.is_empty() { + len += 1; + } + if !self.partition_by.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CopyToNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if !self.output_url.is_empty() { + struct_ser.serialize_field("outputUrl", &self.output_url)?; } - if let Some(v) = self.relation.as_ref() { - struct_ser.serialize_field("relation", v)?; + if !self.file_type.is_empty() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("fileType", pbjson::private::base64::encode(&self.file_type).as_str())?; + } + if !self.partition_by.is_empty() { + struct_ser.serialize_field("partitionBy", &self.partition_by)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for Column { +impl<'de> serde::Deserialize<'de> for CopyToNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "name", - "relation", + "input", + "output_url", + "outputUrl", + "file_type", + "fileType", + "partition_by", + "partitionBy", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Name, - Relation, + Input, + OutputUrl, + FileType, + PartitionBy, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -3171,8 +2694,10 @@ impl<'de> serde::Deserialize<'de> for Column { E: serde::de::Error, { match value { - "name" => Ok(GeneratedField::Name), - "relation" => Ok(GeneratedField::Relation), + "input" => Ok(GeneratedField::Input), + "outputUrl" | "output_url" => Ok(GeneratedField::OutputUrl), + "fileType" | "file_type" => Ok(GeneratedField::FileType), + "partitionBy" | "partition_by" => Ok(GeneratedField::PartitionBy), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -3182,44 +2707,62 @@ impl<'de> serde::Deserialize<'de> for Column { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = Column; + type Value = CopyToNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.Column") + formatter.write_str("struct datafusion.CopyToNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut name__ = None; - let mut relation__ = None; + let mut input__ = None; + let mut output_url__ = None; + let mut file_type__ = None; + let mut partition_by__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Name => { - if name__.is_some() { - return Err(serde::de::Error::duplicate_field("name")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - name__ = Some(map_.next_value()?); + input__ = map_.next_value()?; } - GeneratedField::Relation => { - if relation__.is_some() { - return Err(serde::de::Error::duplicate_field("relation")); + GeneratedField::OutputUrl => { + if output_url__.is_some() { + return Err(serde::de::Error::duplicate_field("outputUrl")); + } + output_url__ = Some(map_.next_value()?); + } + GeneratedField::FileType => { + if file_type__.is_some() { + return Err(serde::de::Error::duplicate_field("fileType")); + } + file_type__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } + GeneratedField::PartitionBy => { + if partition_by__.is_some() { + return Err(serde::de::Error::duplicate_field("partitionBy")); } - relation__ = map_.next_value()?; + partition_by__ = Some(map_.next_value()?); } } } - Ok(Column { - name: name__.unwrap_or_default(), - relation: relation__, + Ok(CopyToNode { + input: input__, + output_url: output_url__.unwrap_or_default(), + file_type: file_type__.unwrap_or_default(), + partition_by: partition_by__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.Column", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CopyToNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ColumnIndex { +impl serde::Serialize for CreateCatalogNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -3227,39 +2770,47 @@ impl serde::Serialize for ColumnIndex { { use serde::ser::SerializeStruct; let mut len = 0; - if self.index != 0 { + if !self.catalog_name.is_empty() { len += 1; } - if self.side != 0 { + if self.if_not_exists { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ColumnIndex", len)?; - if self.index != 0 { - struct_ser.serialize_field("index", &self.index)?; + if self.schema.is_some() { + len += 1; } - if self.side != 0 { - let v = JoinSide::try_from(self.side) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.side)))?; - struct_ser.serialize_field("side", &v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.CreateCatalogNode", len)?; + if !self.catalog_name.is_empty() { + struct_ser.serialize_field("catalogName", &self.catalog_name)?; + } + if self.if_not_exists { + struct_ser.serialize_field("ifNotExists", &self.if_not_exists)?; + } + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ColumnIndex { +impl<'de> serde::Deserialize<'de> for CreateCatalogNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "index", - "side", + "catalog_name", + "catalogName", + "if_not_exists", + "ifNotExists", + "schema", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Index, - Side, + CatalogName, + IfNotExists, + Schema, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -3281,8 +2832,9 @@ impl<'de> serde::Deserialize<'de> for ColumnIndex { E: serde::de::Error, { match value { - "index" => Ok(GeneratedField::Index), - "side" => Ok(GeneratedField::Side), + "catalogName" | "catalog_name" => Ok(GeneratedField::CatalogName), + "ifNotExists" | "if_not_exists" => Ok(GeneratedField::IfNotExists), + "schema" => Ok(GeneratedField::Schema), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -3292,46 +2844,52 @@ impl<'de> serde::Deserialize<'de> for ColumnIndex { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ColumnIndex; + type Value = CreateCatalogNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ColumnIndex") + formatter.write_str("struct datafusion.CreateCatalogNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut index__ = None; - let mut side__ = None; + let mut catalog_name__ = None; + let mut if_not_exists__ = None; + let mut schema__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Index => { - if index__.is_some() { - return Err(serde::de::Error::duplicate_field("index")); + GeneratedField::CatalogName => { + if catalog_name__.is_some() { + return Err(serde::de::Error::duplicate_field("catalogName")); } - index__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + catalog_name__ = Some(map_.next_value()?); } - GeneratedField::Side => { - if side__.is_some() { - return Err(serde::de::Error::duplicate_field("side")); + GeneratedField::IfNotExists => { + if if_not_exists__.is_some() { + return Err(serde::de::Error::duplicate_field("ifNotExists")); } - side__ = Some(map_.next_value::()? as i32); + if_not_exists__ = Some(map_.next_value()?); + } + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); + } + schema__ = map_.next_value()?; } } } - Ok(ColumnIndex { - index: index__.unwrap_or_default(), - side: side__.unwrap_or_default(), + Ok(CreateCatalogNode { + catalog_name: catalog_name__.unwrap_or_default(), + if_not_exists: if_not_exists__.unwrap_or_default(), + schema: schema__, }) } } - deserializer.deserialize_struct("datafusion.ColumnIndex", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CreateCatalogNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ColumnOptions { +impl serde::Serialize for CreateCatalogSchemaNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -3339,124 +2897,47 @@ impl serde::Serialize for ColumnOptions { { use serde::ser::SerializeStruct; let mut len = 0; - if self.bloom_filter_enabled_opt.is_some() { + if !self.schema_name.is_empty() { len += 1; } - if self.encoding_opt.is_some() { + if self.if_not_exists { len += 1; } - if self.dictionary_enabled_opt.is_some() { + if self.schema.is_some() { len += 1; } - if self.compression_opt.is_some() { - len += 1; + let mut struct_ser = serializer.serialize_struct("datafusion.CreateCatalogSchemaNode", len)?; + if !self.schema_name.is_empty() { + struct_ser.serialize_field("schemaName", &self.schema_name)?; } - if self.statistics_enabled_opt.is_some() { - len += 1; - } - if self.bloom_filter_fpp_opt.is_some() { - len += 1; - } - if self.bloom_filter_ndv_opt.is_some() { - len += 1; - } - if self.max_statistics_size_opt.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.ColumnOptions", len)?; - if let Some(v) = self.bloom_filter_enabled_opt.as_ref() { - match v { - column_options::BloomFilterEnabledOpt::BloomFilterEnabled(v) => { - struct_ser.serialize_field("bloomFilterEnabled", v)?; - } - } - } - if let Some(v) = self.encoding_opt.as_ref() { - match v { - column_options::EncodingOpt::Encoding(v) => { - struct_ser.serialize_field("encoding", v)?; - } - } - } - if let Some(v) = self.dictionary_enabled_opt.as_ref() { - match v { - column_options::DictionaryEnabledOpt::DictionaryEnabled(v) => { - struct_ser.serialize_field("dictionaryEnabled", v)?; - } - } - } - if let Some(v) = self.compression_opt.as_ref() { - match v { - column_options::CompressionOpt::Compression(v) => { - struct_ser.serialize_field("compression", v)?; - } - } - } - if let Some(v) = self.statistics_enabled_opt.as_ref() { - match v { - column_options::StatisticsEnabledOpt::StatisticsEnabled(v) => { - struct_ser.serialize_field("statisticsEnabled", v)?; - } - } - } - if let Some(v) = self.bloom_filter_fpp_opt.as_ref() { - match v { - column_options::BloomFilterFppOpt::BloomFilterFpp(v) => { - struct_ser.serialize_field("bloomFilterFpp", v)?; - } - } - } - if let Some(v) = self.bloom_filter_ndv_opt.as_ref() { - match v { - column_options::BloomFilterNdvOpt::BloomFilterNdv(v) => { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("bloomFilterNdv", ToString::to_string(&v).as_str())?; - } - } + if self.if_not_exists { + struct_ser.serialize_field("ifNotExists", &self.if_not_exists)?; } - if let Some(v) = self.max_statistics_size_opt.as_ref() { - match v { - column_options::MaxStatisticsSizeOpt::MaxStatisticsSize(v) => { - struct_ser.serialize_field("maxStatisticsSize", v)?; - } - } + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ColumnOptions { +impl<'de> serde::Deserialize<'de> for CreateCatalogSchemaNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "bloom_filter_enabled", - "bloomFilterEnabled", - "encoding", - "dictionary_enabled", - "dictionaryEnabled", - "compression", - "statistics_enabled", - "statisticsEnabled", - "bloom_filter_fpp", - "bloomFilterFpp", - "bloom_filter_ndv", - "bloomFilterNdv", - "max_statistics_size", - "maxStatisticsSize", + "schema_name", + "schemaName", + "if_not_exists", + "ifNotExists", + "schema", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - BloomFilterEnabled, - Encoding, - DictionaryEnabled, - Compression, - StatisticsEnabled, - BloomFilterFpp, - BloomFilterNdv, - MaxStatisticsSize, + SchemaName, + IfNotExists, + Schema, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -3478,14 +2959,9 @@ impl<'de> serde::Deserialize<'de> for ColumnOptions { E: serde::de::Error, { match value { - "bloomFilterEnabled" | "bloom_filter_enabled" => Ok(GeneratedField::BloomFilterEnabled), - "encoding" => Ok(GeneratedField::Encoding), - "dictionaryEnabled" | "dictionary_enabled" => Ok(GeneratedField::DictionaryEnabled), - "compression" => Ok(GeneratedField::Compression), - "statisticsEnabled" | "statistics_enabled" => Ok(GeneratedField::StatisticsEnabled), - "bloomFilterFpp" | "bloom_filter_fpp" => Ok(GeneratedField::BloomFilterFpp), - "bloomFilterNdv" | "bloom_filter_ndv" => Ok(GeneratedField::BloomFilterNdv), - "maxStatisticsSize" | "max_statistics_size" => Ok(GeneratedField::MaxStatisticsSize), + "schemaName" | "schema_name" => Ok(GeneratedField::SchemaName), + "ifNotExists" | "if_not_exists" => Ok(GeneratedField::IfNotExists), + "schema" => Ok(GeneratedField::Schema), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -3495,92 +2971,52 @@ impl<'de> serde::Deserialize<'de> for ColumnOptions { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ColumnOptions; + type Value = CreateCatalogSchemaNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ColumnOptions") + formatter.write_str("struct datafusion.CreateCatalogSchemaNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut bloom_filter_enabled_opt__ = None; - let mut encoding_opt__ = None; - let mut dictionary_enabled_opt__ = None; - let mut compression_opt__ = None; - let mut statistics_enabled_opt__ = None; - let mut bloom_filter_fpp_opt__ = None; - let mut bloom_filter_ndv_opt__ = None; - let mut max_statistics_size_opt__ = None; + let mut schema_name__ = None; + let mut if_not_exists__ = None; + let mut schema__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::BloomFilterEnabled => { - if bloom_filter_enabled_opt__.is_some() { - return Err(serde::de::Error::duplicate_field("bloomFilterEnabled")); - } - bloom_filter_enabled_opt__ = map_.next_value::<::std::option::Option<_>>()?.map(column_options::BloomFilterEnabledOpt::BloomFilterEnabled); - } - GeneratedField::Encoding => { - if encoding_opt__.is_some() { - return Err(serde::de::Error::duplicate_field("encoding")); - } - encoding_opt__ = map_.next_value::<::std::option::Option<_>>()?.map(column_options::EncodingOpt::Encoding); - } - GeneratedField::DictionaryEnabled => { - if dictionary_enabled_opt__.is_some() { - return Err(serde::de::Error::duplicate_field("dictionaryEnabled")); - } - dictionary_enabled_opt__ = map_.next_value::<::std::option::Option<_>>()?.map(column_options::DictionaryEnabledOpt::DictionaryEnabled); - } - GeneratedField::Compression => { - if compression_opt__.is_some() { - return Err(serde::de::Error::duplicate_field("compression")); - } - compression_opt__ = map_.next_value::<::std::option::Option<_>>()?.map(column_options::CompressionOpt::Compression); - } - GeneratedField::StatisticsEnabled => { - if statistics_enabled_opt__.is_some() { - return Err(serde::de::Error::duplicate_field("statisticsEnabled")); - } - statistics_enabled_opt__ = map_.next_value::<::std::option::Option<_>>()?.map(column_options::StatisticsEnabledOpt::StatisticsEnabled); - } - GeneratedField::BloomFilterFpp => { - if bloom_filter_fpp_opt__.is_some() { - return Err(serde::de::Error::duplicate_field("bloomFilterFpp")); + GeneratedField::SchemaName => { + if schema_name__.is_some() { + return Err(serde::de::Error::duplicate_field("schemaName")); } - bloom_filter_fpp_opt__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| column_options::BloomFilterFppOpt::BloomFilterFpp(x.0)); + schema_name__ = Some(map_.next_value()?); } - GeneratedField::BloomFilterNdv => { - if bloom_filter_ndv_opt__.is_some() { - return Err(serde::de::Error::duplicate_field("bloomFilterNdv")); + GeneratedField::IfNotExists => { + if if_not_exists__.is_some() { + return Err(serde::de::Error::duplicate_field("ifNotExists")); } - bloom_filter_ndv_opt__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| column_options::BloomFilterNdvOpt::BloomFilterNdv(x.0)); + if_not_exists__ = Some(map_.next_value()?); } - GeneratedField::MaxStatisticsSize => { - if max_statistics_size_opt__.is_some() { - return Err(serde::de::Error::duplicate_field("maxStatisticsSize")); + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); } - max_statistics_size_opt__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| column_options::MaxStatisticsSizeOpt::MaxStatisticsSize(x.0)); + schema__ = map_.next_value()?; } } } - Ok(ColumnOptions { - bloom_filter_enabled_opt: bloom_filter_enabled_opt__, - encoding_opt: encoding_opt__, - dictionary_enabled_opt: dictionary_enabled_opt__, - compression_opt: compression_opt__, - statistics_enabled_opt: statistics_enabled_opt__, - bloom_filter_fpp_opt: bloom_filter_fpp_opt__, - bloom_filter_ndv_opt: bloom_filter_ndv_opt__, - max_statistics_size_opt: max_statistics_size_opt__, + Ok(CreateCatalogSchemaNode { + schema_name: schema_name__.unwrap_or_default(), + if_not_exists: if_not_exists__.unwrap_or_default(), + schema: schema__, }) } } - deserializer.deserialize_struct("datafusion.ColumnOptions", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CreateCatalogSchemaNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ColumnRelation { +impl serde::Serialize for CreateExternalTableNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -3588,29 +3024,130 @@ impl serde::Serialize for ColumnRelation { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.relation.is_empty() { + if self.name.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ColumnRelation", len)?; - if !self.relation.is_empty() { - struct_ser.serialize_field("relation", &self.relation)?; + if !self.location.is_empty() { + len += 1; + } + if !self.file_type.is_empty() { + len += 1; + } + if self.schema.is_some() { + len += 1; + } + if !self.table_partition_cols.is_empty() { + len += 1; + } + if self.if_not_exists { + len += 1; + } + if self.temporary { + len += 1; + } + if !self.definition.is_empty() { + len += 1; + } + if !self.order_exprs.is_empty() { + len += 1; + } + if self.unbounded { + len += 1; + } + if !self.options.is_empty() { + len += 1; + } + if self.constraints.is_some() { + len += 1; + } + if !self.column_defaults.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CreateExternalTableNode", len)?; + if let Some(v) = self.name.as_ref() { + struct_ser.serialize_field("name", v)?; + } + if !self.location.is_empty() { + struct_ser.serialize_field("location", &self.location)?; + } + if !self.file_type.is_empty() { + struct_ser.serialize_field("fileType", &self.file_type)?; + } + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; + } + if !self.table_partition_cols.is_empty() { + struct_ser.serialize_field("tablePartitionCols", &self.table_partition_cols)?; + } + if self.if_not_exists { + struct_ser.serialize_field("ifNotExists", &self.if_not_exists)?; + } + if self.temporary { + struct_ser.serialize_field("temporary", &self.temporary)?; + } + if !self.definition.is_empty() { + struct_ser.serialize_field("definition", &self.definition)?; + } + if !self.order_exprs.is_empty() { + struct_ser.serialize_field("orderExprs", &self.order_exprs)?; + } + if self.unbounded { + struct_ser.serialize_field("unbounded", &self.unbounded)?; + } + if !self.options.is_empty() { + struct_ser.serialize_field("options", &self.options)?; + } + if let Some(v) = self.constraints.as_ref() { + struct_ser.serialize_field("constraints", v)?; + } + if !self.column_defaults.is_empty() { + struct_ser.serialize_field("columnDefaults", &self.column_defaults)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ColumnRelation { +impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "relation", + "name", + "location", + "file_type", + "fileType", + "schema", + "table_partition_cols", + "tablePartitionCols", + "if_not_exists", + "ifNotExists", + "temporary", + "definition", + "order_exprs", + "orderExprs", + "unbounded", + "options", + "constraints", + "column_defaults", + "columnDefaults", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Relation, + Name, + Location, + FileType, + Schema, + TablePartitionCols, + IfNotExists, + Temporary, + Definition, + OrderExprs, + Unbounded, + Options, + Constraints, + ColumnDefaults, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -3632,7 +3169,19 @@ impl<'de> serde::Deserialize<'de> for ColumnRelation { E: serde::de::Error, { match value { - "relation" => Ok(GeneratedField::Relation), + "name" => Ok(GeneratedField::Name), + "location" => Ok(GeneratedField::Location), + "fileType" | "file_type" => Ok(GeneratedField::FileType), + "schema" => Ok(GeneratedField::Schema), + "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), + "ifNotExists" | "if_not_exists" => Ok(GeneratedField::IfNotExists), + "temporary" => Ok(GeneratedField::Temporary), + "definition" => Ok(GeneratedField::Definition), + "orderExprs" | "order_exprs" => Ok(GeneratedField::OrderExprs), + "unbounded" => Ok(GeneratedField::Unbounded), + "options" => Ok(GeneratedField::Options), + "constraints" => Ok(GeneratedField::Constraints), + "columnDefaults" | "column_defaults" => Ok(GeneratedField::ColumnDefaults), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -3642,75 +3191,199 @@ impl<'de> serde::Deserialize<'de> for ColumnRelation { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ColumnRelation; + type Value = CreateExternalTableNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ColumnRelation") + formatter.write_str("struct datafusion.CreateExternalTableNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut relation__ = None; + let mut name__ = None; + let mut location__ = None; + let mut file_type__ = None; + let mut schema__ = None; + let mut table_partition_cols__ = None; + let mut if_not_exists__ = None; + let mut temporary__ = None; + let mut definition__ = None; + let mut order_exprs__ = None; + let mut unbounded__ = None; + let mut options__ = None; + let mut constraints__ = None; + let mut column_defaults__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Relation => { - if relation__.is_some() { - return Err(serde::de::Error::duplicate_field("relation")); + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); } - relation__ = Some(map_.next_value()?); + name__ = map_.next_value()?; } - } - } - Ok(ColumnRelation { - relation: relation__.unwrap_or_default(), - }) - } - } - deserializer.deserialize_struct("datafusion.ColumnRelation", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for ColumnSpecificOptions { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if !self.column_name.is_empty() { - len += 1; - } - if self.options.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.ColumnSpecificOptions", len)?; - if !self.column_name.is_empty() { - struct_ser.serialize_field("columnName", &self.column_name)?; - } - if let Some(v) = self.options.as_ref() { - struct_ser.serialize_field("options", v)?; + GeneratedField::Location => { + if location__.is_some() { + return Err(serde::de::Error::duplicate_field("location")); + } + location__ = Some(map_.next_value()?); + } + GeneratedField::FileType => { + if file_type__.is_some() { + return Err(serde::de::Error::duplicate_field("fileType")); + } + file_type__ = Some(map_.next_value()?); + } + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); + } + schema__ = map_.next_value()?; + } + GeneratedField::TablePartitionCols => { + if table_partition_cols__.is_some() { + return Err(serde::de::Error::duplicate_field("tablePartitionCols")); + } + table_partition_cols__ = Some(map_.next_value()?); + } + GeneratedField::IfNotExists => { + if if_not_exists__.is_some() { + return Err(serde::de::Error::duplicate_field("ifNotExists")); + } + if_not_exists__ = Some(map_.next_value()?); + } + GeneratedField::Temporary => { + if temporary__.is_some() { + return Err(serde::de::Error::duplicate_field("temporary")); + } + temporary__ = Some(map_.next_value()?); + } + GeneratedField::Definition => { + if definition__.is_some() { + return Err(serde::de::Error::duplicate_field("definition")); + } + definition__ = Some(map_.next_value()?); + } + GeneratedField::OrderExprs => { + if order_exprs__.is_some() { + return Err(serde::de::Error::duplicate_field("orderExprs")); + } + order_exprs__ = Some(map_.next_value()?); + } + GeneratedField::Unbounded => { + if unbounded__.is_some() { + return Err(serde::de::Error::duplicate_field("unbounded")); + } + unbounded__ = Some(map_.next_value()?); + } + GeneratedField::Options => { + if options__.is_some() { + return Err(serde::de::Error::duplicate_field("options")); + } + options__ = Some( + map_.next_value::>()? + ); + } + GeneratedField::Constraints => { + if constraints__.is_some() { + return Err(serde::de::Error::duplicate_field("constraints")); + } + constraints__ = map_.next_value()?; + } + GeneratedField::ColumnDefaults => { + if column_defaults__.is_some() { + return Err(serde::de::Error::duplicate_field("columnDefaults")); + } + column_defaults__ = Some( + map_.next_value::>()? + ); + } + } + } + Ok(CreateExternalTableNode { + name: name__, + location: location__.unwrap_or_default(), + file_type: file_type__.unwrap_or_default(), + schema: schema__, + table_partition_cols: table_partition_cols__.unwrap_or_default(), + if_not_exists: if_not_exists__.unwrap_or_default(), + temporary: temporary__.unwrap_or_default(), + definition: definition__.unwrap_or_default(), + order_exprs: order_exprs__.unwrap_or_default(), + unbounded: unbounded__.unwrap_or_default(), + options: options__.unwrap_or_default(), + constraints: constraints__, + column_defaults: column_defaults__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.CreateExternalTableNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for CreateViewNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.name.is_some() { + len += 1; + } + if self.input.is_some() { + len += 1; + } + if self.or_replace { + len += 1; + } + if self.temporary { + len += 1; + } + if !self.definition.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CreateViewNode", len)?; + if let Some(v) = self.name.as_ref() { + struct_ser.serialize_field("name", v)?; + } + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if self.or_replace { + struct_ser.serialize_field("orReplace", &self.or_replace)?; + } + if self.temporary { + struct_ser.serialize_field("temporary", &self.temporary)?; + } + if !self.definition.is_empty() { + struct_ser.serialize_field("definition", &self.definition)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ColumnSpecificOptions { +impl<'de> serde::Deserialize<'de> for CreateViewNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "column_name", - "columnName", - "options", + "name", + "input", + "or_replace", + "orReplace", + "temporary", + "definition", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - ColumnName, - Options, + Name, + Input, + OrReplace, + Temporary, + Definition, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -3732,8 +3405,11 @@ impl<'de> serde::Deserialize<'de> for ColumnSpecificOptions { E: serde::de::Error, { match value { - "columnName" | "column_name" => Ok(GeneratedField::ColumnName), - "options" => Ok(GeneratedField::Options), + "name" => Ok(GeneratedField::Name), + "input" => Ok(GeneratedField::Input), + "orReplace" | "or_replace" => Ok(GeneratedField::OrReplace), + "temporary" => Ok(GeneratedField::Temporary), + "definition" => Ok(GeneratedField::Definition), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -3743,44 +3419,68 @@ impl<'de> serde::Deserialize<'de> for ColumnSpecificOptions { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ColumnSpecificOptions; + type Value = CreateViewNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ColumnSpecificOptions") + formatter.write_str("struct datafusion.CreateViewNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut column_name__ = None; - let mut options__ = None; + let mut name__ = None; + let mut input__ = None; + let mut or_replace__ = None; + let mut temporary__ = None; + let mut definition__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::ColumnName => { - if column_name__.is_some() { - return Err(serde::de::Error::duplicate_field("columnName")); + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); } - column_name__ = Some(map_.next_value()?); + name__ = map_.next_value()?; } - GeneratedField::Options => { - if options__.is_some() { - return Err(serde::de::Error::duplicate_field("options")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - options__ = map_.next_value()?; + input__ = map_.next_value()?; + } + GeneratedField::OrReplace => { + if or_replace__.is_some() { + return Err(serde::de::Error::duplicate_field("orReplace")); + } + or_replace__ = Some(map_.next_value()?); + } + GeneratedField::Temporary => { + if temporary__.is_some() { + return Err(serde::de::Error::duplicate_field("temporary")); + } + temporary__ = Some(map_.next_value()?); + } + GeneratedField::Definition => { + if definition__.is_some() { + return Err(serde::de::Error::duplicate_field("definition")); + } + definition__ = Some(map_.next_value()?); } } } - Ok(ColumnSpecificOptions { - column_name: column_name__.unwrap_or_default(), - options: options__, + Ok(CreateViewNode { + name: name__, + input: input__, + or_replace: or_replace__.unwrap_or_default(), + temporary: temporary__.unwrap_or_default(), + definition: definition__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.ColumnSpecificOptions", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CreateViewNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ColumnStats { +impl serde::Serialize for CrossJoinExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -3788,57 +3488,37 @@ impl serde::Serialize for ColumnStats { { use serde::ser::SerializeStruct; let mut len = 0; - if self.min_value.is_some() { - len += 1; - } - if self.max_value.is_some() { - len += 1; - } - if self.null_count.is_some() { + if self.left.is_some() { len += 1; } - if self.distinct_count.is_some() { + if self.right.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ColumnStats", len)?; - if let Some(v) = self.min_value.as_ref() { - struct_ser.serialize_field("minValue", v)?; - } - if let Some(v) = self.max_value.as_ref() { - struct_ser.serialize_field("maxValue", v)?; - } - if let Some(v) = self.null_count.as_ref() { - struct_ser.serialize_field("nullCount", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.CrossJoinExecNode", len)?; + if let Some(v) = self.left.as_ref() { + struct_ser.serialize_field("left", v)?; } - if let Some(v) = self.distinct_count.as_ref() { - struct_ser.serialize_field("distinctCount", v)?; + if let Some(v) = self.right.as_ref() { + struct_ser.serialize_field("right", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ColumnStats { +impl<'de> serde::Deserialize<'de> for CrossJoinExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "min_value", - "minValue", - "max_value", - "maxValue", - "null_count", - "nullCount", - "distinct_count", - "distinctCount", + "left", + "right", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - MinValue, - MaxValue, - NullCount, - DistinctCount, + Left, + Right, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -3860,10 +3540,8 @@ impl<'de> serde::Deserialize<'de> for ColumnStats { E: serde::de::Error, { match value { - "minValue" | "min_value" => Ok(GeneratedField::MinValue), - "maxValue" | "max_value" => Ok(GeneratedField::MaxValue), - "nullCount" | "null_count" => Ok(GeneratedField::NullCount), - "distinctCount" | "distinct_count" => Ok(GeneratedField::DistinctCount), + "left" => Ok(GeneratedField::Left), + "right" => Ok(GeneratedField::Right), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -3873,140 +3551,152 @@ impl<'de> serde::Deserialize<'de> for ColumnStats { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ColumnStats; + type Value = CrossJoinExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ColumnStats") + formatter.write_str("struct datafusion.CrossJoinExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut min_value__ = None; - let mut max_value__ = None; - let mut null_count__ = None; - let mut distinct_count__ = None; + let mut left__ = None; + let mut right__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::MinValue => { - if min_value__.is_some() { - return Err(serde::de::Error::duplicate_field("minValue")); - } - min_value__ = map_.next_value()?; - } - GeneratedField::MaxValue => { - if max_value__.is_some() { - return Err(serde::de::Error::duplicate_field("maxValue")); - } - max_value__ = map_.next_value()?; - } - GeneratedField::NullCount => { - if null_count__.is_some() { - return Err(serde::de::Error::duplicate_field("nullCount")); + GeneratedField::Left => { + if left__.is_some() { + return Err(serde::de::Error::duplicate_field("left")); } - null_count__ = map_.next_value()?; + left__ = map_.next_value()?; } - GeneratedField::DistinctCount => { - if distinct_count__.is_some() { - return Err(serde::de::Error::duplicate_field("distinctCount")); + GeneratedField::Right => { + if right__.is_some() { + return Err(serde::de::Error::duplicate_field("right")); } - distinct_count__ = map_.next_value()?; + right__ = map_.next_value()?; } } } - Ok(ColumnStats { - min_value: min_value__, - max_value: max_value__, - null_count: null_count__, - distinct_count: distinct_count__, + Ok(CrossJoinExecNode { + left: left__, + right: right__, }) } } - deserializer.deserialize_struct("datafusion.ColumnStats", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CrossJoinExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CompressionTypeVariant { +impl serde::Serialize for CrossJoinNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where S: serde::Serializer, { - let variant = match self { - Self::Gzip => "GZIP", - Self::Bzip2 => "BZIP2", - Self::Xz => "XZ", - Self::Zstd => "ZSTD", - Self::Uncompressed => "UNCOMPRESSED", - }; - serializer.serialize_str(variant) + use serde::ser::SerializeStruct; + let mut len = 0; + if self.left.is_some() { + len += 1; + } + if self.right.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CrossJoinNode", len)?; + if let Some(v) = self.left.as_ref() { + struct_ser.serialize_field("left", v)?; + } + if let Some(v) = self.right.as_ref() { + struct_ser.serialize_field("right", v)?; + } + struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CompressionTypeVariant { +impl<'de> serde::Deserialize<'de> for CrossJoinNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "GZIP", - "BZIP2", - "XZ", - "ZSTD", - "UNCOMPRESSED", + "left", + "right", ]; - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CompressionTypeVariant; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - fn visit_i64(self, v: i64) -> std::result::Result + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Left, + Right, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result where - E: serde::de::Error, + D: serde::Deserializer<'de>, { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) - }) + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "left" => Ok(GeneratedField::Left), + "right" => Ok(GeneratedField::Right), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CrossJoinNode; - fn visit_u64(self, v: u64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) - }) + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.CrossJoinNode") } - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, { - match value { - "GZIP" => Ok(CompressionTypeVariant::Gzip), - "BZIP2" => Ok(CompressionTypeVariant::Bzip2), - "XZ" => Ok(CompressionTypeVariant::Xz), - "ZSTD" => Ok(CompressionTypeVariant::Zstd), - "UNCOMPRESSED" => Ok(CompressionTypeVariant::Uncompressed), - _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + let mut left__ = None; + let mut right__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Left => { + if left__.is_some() { + return Err(serde::de::Error::duplicate_field("left")); + } + left__ = map_.next_value()?; + } + GeneratedField::Right => { + if right__.is_some() { + return Err(serde::de::Error::duplicate_field("right")); + } + right__ = map_.next_value()?; + } + } } + Ok(CrossJoinNode { + left: left__, + right: right__, + }) } } - deserializer.deserialize_any(GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CrossJoinNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for Constraint { +impl serde::Serialize for CsvScanExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -4014,39 +3704,88 @@ impl serde::Serialize for Constraint { { use serde::ser::SerializeStruct; let mut len = 0; - if self.constraint_mode.is_some() { + if self.base_conf.is_some() { + len += 1; + } + if self.has_header { + len += 1; + } + if !self.delimiter.is_empty() { + len += 1; + } + if !self.quote.is_empty() { + len += 1; + } + if self.newlines_in_values { + len += 1; + } + if self.optional_escape.is_some() { + len += 1; + } + if self.optional_comment.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.Constraint", len)?; - if let Some(v) = self.constraint_mode.as_ref() { + let mut struct_ser = serializer.serialize_struct("datafusion.CsvScanExecNode", len)?; + if let Some(v) = self.base_conf.as_ref() { + struct_ser.serialize_field("baseConf", v)?; + } + if self.has_header { + struct_ser.serialize_field("hasHeader", &self.has_header)?; + } + if !self.delimiter.is_empty() { + struct_ser.serialize_field("delimiter", &self.delimiter)?; + } + if !self.quote.is_empty() { + struct_ser.serialize_field("quote", &self.quote)?; + } + if self.newlines_in_values { + struct_ser.serialize_field("newlinesInValues", &self.newlines_in_values)?; + } + if let Some(v) = self.optional_escape.as_ref() { match v { - constraint::ConstraintMode::PrimaryKey(v) => { - struct_ser.serialize_field("primaryKey", v)?; + csv_scan_exec_node::OptionalEscape::Escape(v) => { + struct_ser.serialize_field("escape", v)?; } - constraint::ConstraintMode::Unique(v) => { - struct_ser.serialize_field("unique", v)?; + } + } + if let Some(v) = self.optional_comment.as_ref() { + match v { + csv_scan_exec_node::OptionalComment::Comment(v) => { + struct_ser.serialize_field("comment", v)?; } } } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for Constraint { +impl<'de> serde::Deserialize<'de> for CsvScanExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "primary_key", - "primaryKey", - "unique", + "base_conf", + "baseConf", + "has_header", + "hasHeader", + "delimiter", + "quote", + "newlines_in_values", + "newlinesInValues", + "escape", + "comment", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - PrimaryKey, - Unique, + BaseConf, + HasHeader, + Delimiter, + Quote, + NewlinesInValues, + Escape, + Comment, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -4068,8 +3807,13 @@ impl<'de> serde::Deserialize<'de> for Constraint { E: serde::de::Error, { match value { - "primaryKey" | "primary_key" => Ok(GeneratedField::PrimaryKey), - "unique" => Ok(GeneratedField::Unique), + "baseConf" | "base_conf" => Ok(GeneratedField::BaseConf), + "hasHeader" | "has_header" => Ok(GeneratedField::HasHeader), + "delimiter" => Ok(GeneratedField::Delimiter), + "quote" => Ok(GeneratedField::Quote), + "newlinesInValues" | "newlines_in_values" => Ok(GeneratedField::NewlinesInValues), + "escape" => Ok(GeneratedField::Escape), + "comment" => Ok(GeneratedField::Comment), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -4079,44 +3823,84 @@ impl<'de> serde::Deserialize<'de> for Constraint { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = Constraint; + type Value = CsvScanExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.Constraint") + formatter.write_str("struct datafusion.CsvScanExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut constraint_mode__ = None; + let mut base_conf__ = None; + let mut has_header__ = None; + let mut delimiter__ = None; + let mut quote__ = None; + let mut newlines_in_values__ = None; + let mut optional_escape__ = None; + let mut optional_comment__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::PrimaryKey => { - if constraint_mode__.is_some() { - return Err(serde::de::Error::duplicate_field("primaryKey")); + GeneratedField::BaseConf => { + if base_conf__.is_some() { + return Err(serde::de::Error::duplicate_field("baseConf")); } - constraint_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(constraint::ConstraintMode::PrimaryKey) -; + base_conf__ = map_.next_value()?; + } + GeneratedField::HasHeader => { + if has_header__.is_some() { + return Err(serde::de::Error::duplicate_field("hasHeader")); + } + has_header__ = Some(map_.next_value()?); } - GeneratedField::Unique => { - if constraint_mode__.is_some() { - return Err(serde::de::Error::duplicate_field("unique")); + GeneratedField::Delimiter => { + if delimiter__.is_some() { + return Err(serde::de::Error::duplicate_field("delimiter")); } - constraint_mode__ = map_.next_value::<::std::option::Option<_>>()?.map(constraint::ConstraintMode::Unique) -; + delimiter__ = Some(map_.next_value()?); + } + GeneratedField::Quote => { + if quote__.is_some() { + return Err(serde::de::Error::duplicate_field("quote")); + } + quote__ = Some(map_.next_value()?); + } + GeneratedField::NewlinesInValues => { + if newlines_in_values__.is_some() { + return Err(serde::de::Error::duplicate_field("newlinesInValues")); + } + newlines_in_values__ = Some(map_.next_value()?); + } + GeneratedField::Escape => { + if optional_escape__.is_some() { + return Err(serde::de::Error::duplicate_field("escape")); + } + optional_escape__ = map_.next_value::<::std::option::Option<_>>()?.map(csv_scan_exec_node::OptionalEscape::Escape); + } + GeneratedField::Comment => { + if optional_comment__.is_some() { + return Err(serde::de::Error::duplicate_field("comment")); + } + optional_comment__ = map_.next_value::<::std::option::Option<_>>()?.map(csv_scan_exec_node::OptionalComment::Comment); } } } - Ok(Constraint { - constraint_mode: constraint_mode__, + Ok(CsvScanExecNode { + base_conf: base_conf__, + has_header: has_header__.unwrap_or_default(), + delimiter: delimiter__.unwrap_or_default(), + quote: quote__.unwrap_or_default(), + newlines_in_values: newlines_in_values__.unwrap_or_default(), + optional_escape: optional_escape__, + optional_comment: optional_comment__, }) } } - deserializer.deserialize_struct("datafusion.Constraint", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CsvScanExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for Constraints { +impl serde::Serialize for CsvSink { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -4124,29 +3908,38 @@ impl serde::Serialize for Constraints { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.constraints.is_empty() { + if self.config.is_some() { + len += 1; + } + if self.writer_options.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.Constraints", len)?; - if !self.constraints.is_empty() { - struct_ser.serialize_field("constraints", &self.constraints)?; + let mut struct_ser = serializer.serialize_struct("datafusion.CsvSink", len)?; + if let Some(v) = self.config.as_ref() { + struct_ser.serialize_field("config", v)?; + } + if let Some(v) = self.writer_options.as_ref() { + struct_ser.serialize_field("writerOptions", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for Constraints { +impl<'de> serde::Deserialize<'de> for CsvSink { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "constraints", + "config", + "writer_options", + "writerOptions", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Constraints, + Config, + WriterOptions, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -4168,7 +3961,8 @@ impl<'de> serde::Deserialize<'de> for Constraints { E: serde::de::Error, { match value { - "constraints" => Ok(GeneratedField::Constraints), + "config" => Ok(GeneratedField::Config), + "writerOptions" | "writer_options" => Ok(GeneratedField::WriterOptions), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -4178,36 +3972,44 @@ impl<'de> serde::Deserialize<'de> for Constraints { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = Constraints; + type Value = CsvSink; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.Constraints") + formatter.write_str("struct datafusion.CsvSink") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut constraints__ = None; + let mut config__ = None; + let mut writer_options__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Constraints => { - if constraints__.is_some() { - return Err(serde::de::Error::duplicate_field("constraints")); + GeneratedField::Config => { + if config__.is_some() { + return Err(serde::de::Error::duplicate_field("config")); + } + config__ = map_.next_value()?; + } + GeneratedField::WriterOptions => { + if writer_options__.is_some() { + return Err(serde::de::Error::duplicate_field("writerOptions")); } - constraints__ = Some(map_.next_value()?); + writer_options__ = map_.next_value()?; } } } - Ok(Constraints { - constraints: constraints__.unwrap_or_default(), + Ok(CsvSink { + config: config__, + writer_options: writer_options__, }) } } - deserializer.deserialize_struct("datafusion.Constraints", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CsvSink", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CopyToNode { +impl serde::Serialize for CsvSinkExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -4218,48 +4020,32 @@ impl serde::Serialize for CopyToNode { if self.input.is_some() { len += 1; } - if !self.output_url.is_empty() { + if self.sink.is_some() { len += 1; } - if !self.partition_by.is_empty() { + if self.sink_schema.is_some() { len += 1; } - if self.format_options.is_some() { + if self.sort_order.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CopyToNode", len)?; + let mut struct_ser = serializer.serialize_struct("datafusion.CsvSinkExecNode", len)?; if let Some(v) = self.input.as_ref() { struct_ser.serialize_field("input", v)?; } - if !self.output_url.is_empty() { - struct_ser.serialize_field("outputUrl", &self.output_url)?; + if let Some(v) = self.sink.as_ref() { + struct_ser.serialize_field("sink", v)?; } - if !self.partition_by.is_empty() { - struct_ser.serialize_field("partitionBy", &self.partition_by)?; + if let Some(v) = self.sink_schema.as_ref() { + struct_ser.serialize_field("sinkSchema", v)?; } - if let Some(v) = self.format_options.as_ref() { - match v { - copy_to_node::FormatOptions::Csv(v) => { - struct_ser.serialize_field("csv", v)?; - } - copy_to_node::FormatOptions::Json(v) => { - struct_ser.serialize_field("json", v)?; - } - copy_to_node::FormatOptions::Parquet(v) => { - struct_ser.serialize_field("parquet", v)?; - } - copy_to_node::FormatOptions::Avro(v) => { - struct_ser.serialize_field("avro", v)?; - } - copy_to_node::FormatOptions::Arrow(v) => { - struct_ser.serialize_field("arrow", v)?; - } - } + if let Some(v) = self.sort_order.as_ref() { + struct_ser.serialize_field("sortOrder", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CopyToNode { +impl<'de> serde::Deserialize<'de> for CsvSinkExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where @@ -4267,27 +4053,19 @@ impl<'de> serde::Deserialize<'de> for CopyToNode { { const FIELDS: &[&str] = &[ "input", - "output_url", - "outputUrl", - "partition_by", - "partitionBy", - "csv", - "json", - "parquet", - "avro", - "arrow", + "sink", + "sink_schema", + "sinkSchema", + "sort_order", + "sortOrder", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Input, - OutputUrl, - PartitionBy, - Csv, - Json, - Parquet, - Avro, - Arrow, + Sink, + SinkSchema, + SortOrder, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -4310,13 +4088,9 @@ impl<'de> serde::Deserialize<'de> for CopyToNode { { match value { "input" => Ok(GeneratedField::Input), - "outputUrl" | "output_url" => Ok(GeneratedField::OutputUrl), - "partitionBy" | "partition_by" => Ok(GeneratedField::PartitionBy), - "csv" => Ok(GeneratedField::Csv), - "json" => Ok(GeneratedField::Json), - "parquet" => Ok(GeneratedField::Parquet), - "avro" => Ok(GeneratedField::Avro), - "arrow" => Ok(GeneratedField::Arrow), + "sink" => Ok(GeneratedField::Sink), + "sinkSchema" | "sink_schema" => Ok(GeneratedField::SinkSchema), + "sortOrder" | "sort_order" => Ok(GeneratedField::SortOrder), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -4326,20 +4100,20 @@ impl<'de> serde::Deserialize<'de> for CopyToNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CopyToNode; + type Value = CsvSinkExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CopyToNode") + formatter.write_str("struct datafusion.CsvSinkExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut input__ = None; - let mut output_url__ = None; - let mut partition_by__ = None; - let mut format_options__ = None; + let mut sink__ = None; + let mut sink_schema__ = None; + let mut sort_order__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { @@ -4348,67 +4122,38 @@ impl<'de> serde::Deserialize<'de> for CopyToNode { } input__ = map_.next_value()?; } - GeneratedField::OutputUrl => { - if output_url__.is_some() { - return Err(serde::de::Error::duplicate_field("outputUrl")); - } - output_url__ = Some(map_.next_value()?); - } - GeneratedField::PartitionBy => { - if partition_by__.is_some() { - return Err(serde::de::Error::duplicate_field("partitionBy")); - } - partition_by__ = Some(map_.next_value()?); - } - GeneratedField::Csv => { - if format_options__.is_some() { - return Err(serde::de::Error::duplicate_field("csv")); - } - format_options__ = map_.next_value::<::std::option::Option<_>>()?.map(copy_to_node::FormatOptions::Csv) -; - } - GeneratedField::Json => { - if format_options__.is_some() { - return Err(serde::de::Error::duplicate_field("json")); - } - format_options__ = map_.next_value::<::std::option::Option<_>>()?.map(copy_to_node::FormatOptions::Json) -; - } - GeneratedField::Parquet => { - if format_options__.is_some() { - return Err(serde::de::Error::duplicate_field("parquet")); + GeneratedField::Sink => { + if sink__.is_some() { + return Err(serde::de::Error::duplicate_field("sink")); } - format_options__ = map_.next_value::<::std::option::Option<_>>()?.map(copy_to_node::FormatOptions::Parquet) -; + sink__ = map_.next_value()?; } - GeneratedField::Avro => { - if format_options__.is_some() { - return Err(serde::de::Error::duplicate_field("avro")); + GeneratedField::SinkSchema => { + if sink_schema__.is_some() { + return Err(serde::de::Error::duplicate_field("sinkSchema")); } - format_options__ = map_.next_value::<::std::option::Option<_>>()?.map(copy_to_node::FormatOptions::Avro) -; + sink_schema__ = map_.next_value()?; } - GeneratedField::Arrow => { - if format_options__.is_some() { - return Err(serde::de::Error::duplicate_field("arrow")); + GeneratedField::SortOrder => { + if sort_order__.is_some() { + return Err(serde::de::Error::duplicate_field("sortOrder")); } - format_options__ = map_.next_value::<::std::option::Option<_>>()?.map(copy_to_node::FormatOptions::Arrow) -; + sort_order__ = map_.next_value()?; } } } - Ok(CopyToNode { + Ok(CsvSinkExecNode { input: input__, - output_url: output_url__.unwrap_or_default(), - partition_by: partition_by__.unwrap_or_default(), - format_options: format_options__, + sink: sink__, + sink_schema: sink_schema__, + sort_order: sort_order__, }) } } - deserializer.deserialize_struct("datafusion.CopyToNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CsvSinkExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CreateCatalogNode { +impl serde::Serialize for CubeNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -4416,47 +4161,29 @@ impl serde::Serialize for CreateCatalogNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.catalog_name.is_empty() { - len += 1; - } - if self.if_not_exists { - len += 1; - } - if self.schema.is_some() { + if !self.expr.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CreateCatalogNode", len)?; - if !self.catalog_name.is_empty() { - struct_ser.serialize_field("catalogName", &self.catalog_name)?; - } - if self.if_not_exists { - struct_ser.serialize_field("ifNotExists", &self.if_not_exists)?; - } - if let Some(v) = self.schema.as_ref() { - struct_ser.serialize_field("schema", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.CubeNode", len)?; + if !self.expr.is_empty() { + struct_ser.serialize_field("expr", &self.expr)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CreateCatalogNode { +impl<'de> serde::Deserialize<'de> for CubeNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "catalog_name", - "catalogName", - "if_not_exists", - "ifNotExists", - "schema", + "expr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - CatalogName, - IfNotExists, - Schema, + Expr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -4478,9 +4205,7 @@ impl<'de> serde::Deserialize<'de> for CreateCatalogNode { E: serde::de::Error, { match value { - "catalogName" | "catalog_name" => Ok(GeneratedField::CatalogName), - "ifNotExists" | "if_not_exists" => Ok(GeneratedField::IfNotExists), - "schema" => Ok(GeneratedField::Schema), + "expr" => Ok(GeneratedField::Expr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -4490,52 +4215,36 @@ impl<'de> serde::Deserialize<'de> for CreateCatalogNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CreateCatalogNode; + type Value = CubeNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CreateCatalogNode") + formatter.write_str("struct datafusion.CubeNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut catalog_name__ = None; - let mut if_not_exists__ = None; - let mut schema__ = None; + let mut expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::CatalogName => { - if catalog_name__.is_some() { - return Err(serde::de::Error::duplicate_field("catalogName")); - } - catalog_name__ = Some(map_.next_value()?); - } - GeneratedField::IfNotExists => { - if if_not_exists__.is_some() { - return Err(serde::de::Error::duplicate_field("ifNotExists")); - } - if_not_exists__ = Some(map_.next_value()?); - } - GeneratedField::Schema => { - if schema__.is_some() { - return Err(serde::de::Error::duplicate_field("schema")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - schema__ = map_.next_value()?; + expr__ = Some(map_.next_value()?); } } } - Ok(CreateCatalogNode { - catalog_name: catalog_name__.unwrap_or_default(), - if_not_exists: if_not_exists__.unwrap_or_default(), - schema: schema__, + Ok(CubeNode { + expr: expr__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.CreateCatalogNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CubeNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CreateCatalogSchemaNode { +impl serde::Serialize for CustomTableScanNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -4543,47 +4252,65 @@ impl serde::Serialize for CreateCatalogSchemaNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.schema_name.is_empty() { + if self.table_name.is_some() { len += 1; } - if self.if_not_exists { + if self.projection.is_some() { len += 1; } if self.schema.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CreateCatalogSchemaNode", len)?; - if !self.schema_name.is_empty() { - struct_ser.serialize_field("schemaName", &self.schema_name)?; + if !self.filters.is_empty() { + len += 1; } - if self.if_not_exists { - struct_ser.serialize_field("ifNotExists", &self.if_not_exists)?; + if !self.custom_table_data.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CustomTableScanNode", len)?; + if let Some(v) = self.table_name.as_ref() { + struct_ser.serialize_field("tableName", v)?; + } + if let Some(v) = self.projection.as_ref() { + struct_ser.serialize_field("projection", v)?; } if let Some(v) = self.schema.as_ref() { struct_ser.serialize_field("schema", v)?; } + if !self.filters.is_empty() { + struct_ser.serialize_field("filters", &self.filters)?; + } + if !self.custom_table_data.is_empty() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("customTableData", pbjson::private::base64::encode(&self.custom_table_data).as_str())?; + } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CreateCatalogSchemaNode { +impl<'de> serde::Deserialize<'de> for CustomTableScanNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "schema_name", - "schemaName", - "if_not_exists", - "ifNotExists", + "table_name", + "tableName", + "projection", "schema", + "filters", + "custom_table_data", + "customTableData", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - SchemaName, - IfNotExists, + TableName, + Projection, Schema, + Filters, + CustomTableData, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -4605,9 +4332,11 @@ impl<'de> serde::Deserialize<'de> for CreateCatalogSchemaNode { E: serde::de::Error, { match value { - "schemaName" | "schema_name" => Ok(GeneratedField::SchemaName), - "ifNotExists" | "if_not_exists" => Ok(GeneratedField::IfNotExists), + "tableName" | "table_name" => Ok(GeneratedField::TableName), + "projection" => Ok(GeneratedField::Projection), "schema" => Ok(GeneratedField::Schema), + "filters" => Ok(GeneratedField::Filters), + "customTableData" | "custom_table_data" => Ok(GeneratedField::CustomTableData), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -4617,32 +4346,34 @@ impl<'de> serde::Deserialize<'de> for CreateCatalogSchemaNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CreateCatalogSchemaNode; + type Value = CustomTableScanNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CreateCatalogSchemaNode") + formatter.write_str("struct datafusion.CustomTableScanNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut schema_name__ = None; - let mut if_not_exists__ = None; + let mut table_name__ = None; + let mut projection__ = None; let mut schema__ = None; + let mut filters__ = None; + let mut custom_table_data__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::SchemaName => { - if schema_name__.is_some() { - return Err(serde::de::Error::duplicate_field("schemaName")); + GeneratedField::TableName => { + if table_name__.is_some() { + return Err(serde::de::Error::duplicate_field("tableName")); } - schema_name__ = Some(map_.next_value()?); + table_name__ = map_.next_value()?; } - GeneratedField::IfNotExists => { - if if_not_exists__.is_some() { - return Err(serde::de::Error::duplicate_field("ifNotExists")); + GeneratedField::Projection => { + if projection__.is_some() { + return Err(serde::de::Error::duplicate_field("projection")); } - if_not_exists__ = Some(map_.next_value()?); + projection__ = map_.next_value()?; } GeneratedField::Schema => { if schema__.is_some() { @@ -4650,19 +4381,106 @@ impl<'de> serde::Deserialize<'de> for CreateCatalogSchemaNode { } schema__ = map_.next_value()?; } + GeneratedField::Filters => { + if filters__.is_some() { + return Err(serde::de::Error::duplicate_field("filters")); + } + filters__ = Some(map_.next_value()?); + } + GeneratedField::CustomTableData => { + if custom_table_data__.is_some() { + return Err(serde::de::Error::duplicate_field("customTableData")); + } + custom_table_data__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } } } - Ok(CreateCatalogSchemaNode { - schema_name: schema_name__.unwrap_or_default(), - if_not_exists: if_not_exists__.unwrap_or_default(), + Ok(CustomTableScanNode { + table_name: table_name__, + projection: projection__, schema: schema__, + filters: filters__.unwrap_or_default(), + custom_table_data: custom_table_data__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.CreateCatalogSchemaNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.CustomTableScanNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CreateExternalTableNode { +impl serde::Serialize for DateUnit { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::Day => "Day", + Self::DateMillisecond => "DateMillisecond", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for DateUnit { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "Day", + "DateMillisecond", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = DateUnit; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "Day" => Ok(DateUnit::Day), + "DateMillisecond" => Ok(DateUnit::DateMillisecond), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} +impl serde::Serialize for DistinctNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -4670,150 +4488,29 @@ impl serde::Serialize for CreateExternalTableNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.name.is_some() { - len += 1; - } - if !self.location.is_empty() { - len += 1; - } - if !self.file_type.is_empty() { - len += 1; - } - if self.has_header { - len += 1; - } - if self.schema.is_some() { - len += 1; - } - if !self.table_partition_cols.is_empty() { - len += 1; - } - if self.if_not_exists { - len += 1; - } - if !self.delimiter.is_empty() { - len += 1; - } - if !self.definition.is_empty() { - len += 1; - } - if self.file_compression_type != 0 { - len += 1; - } - if !self.order_exprs.is_empty() { - len += 1; - } - if self.unbounded { - len += 1; - } - if !self.options.is_empty() { - len += 1; - } - if self.constraints.is_some() { - len += 1; - } - if !self.column_defaults.is_empty() { + if self.input.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CreateExternalTableNode", len)?; - if let Some(v) = self.name.as_ref() { - struct_ser.serialize_field("name", v)?; - } - if !self.location.is_empty() { - struct_ser.serialize_field("location", &self.location)?; - } - if !self.file_type.is_empty() { - struct_ser.serialize_field("fileType", &self.file_type)?; - } - if self.has_header { - struct_ser.serialize_field("hasHeader", &self.has_header)?; - } - if let Some(v) = self.schema.as_ref() { - struct_ser.serialize_field("schema", v)?; - } - if !self.table_partition_cols.is_empty() { - struct_ser.serialize_field("tablePartitionCols", &self.table_partition_cols)?; - } - if self.if_not_exists { - struct_ser.serialize_field("ifNotExists", &self.if_not_exists)?; - } - if !self.delimiter.is_empty() { - struct_ser.serialize_field("delimiter", &self.delimiter)?; - } - if !self.definition.is_empty() { - struct_ser.serialize_field("definition", &self.definition)?; - } - if self.file_compression_type != 0 { - let v = CompressionTypeVariant::try_from(self.file_compression_type) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.file_compression_type)))?; - struct_ser.serialize_field("fileCompressionType", &v)?; - } - if !self.order_exprs.is_empty() { - struct_ser.serialize_field("orderExprs", &self.order_exprs)?; - } - if self.unbounded { - struct_ser.serialize_field("unbounded", &self.unbounded)?; - } - if !self.options.is_empty() { - struct_ser.serialize_field("options", &self.options)?; - } - if let Some(v) = self.constraints.as_ref() { - struct_ser.serialize_field("constraints", v)?; - } - if !self.column_defaults.is_empty() { - struct_ser.serialize_field("columnDefaults", &self.column_defaults)?; + let mut struct_ser = serializer.serialize_struct("datafusion.DistinctNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { +impl<'de> serde::Deserialize<'de> for DistinctNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "name", - "location", - "file_type", - "fileType", - "has_header", - "hasHeader", - "schema", - "table_partition_cols", - "tablePartitionCols", - "if_not_exists", - "ifNotExists", - "delimiter", - "definition", - "file_compression_type", - "fileCompressionType", - "order_exprs", - "orderExprs", - "unbounded", - "options", - "constraints", - "column_defaults", - "columnDefaults", + "input", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Name, - Location, - FileType, - HasHeader, - Schema, - TablePartitionCols, - IfNotExists, - Delimiter, - Definition, - FileCompressionType, - OrderExprs, - Unbounded, - Options, - Constraints, - ColumnDefaults, + Input, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -4835,21 +4532,7 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { E: serde::de::Error, { match value { - "name" => Ok(GeneratedField::Name), - "location" => Ok(GeneratedField::Location), - "fileType" | "file_type" => Ok(GeneratedField::FileType), - "hasHeader" | "has_header" => Ok(GeneratedField::HasHeader), - "schema" => Ok(GeneratedField::Schema), - "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), - "ifNotExists" | "if_not_exists" => Ok(GeneratedField::IfNotExists), - "delimiter" => Ok(GeneratedField::Delimiter), - "definition" => Ok(GeneratedField::Definition), - "fileCompressionType" | "file_compression_type" => Ok(GeneratedField::FileCompressionType), - "orderExprs" | "order_exprs" => Ok(GeneratedField::OrderExprs), - "unbounded" => Ok(GeneratedField::Unbounded), - "options" => Ok(GeneratedField::Options), - "constraints" => Ok(GeneratedField::Constraints), - "columnDefaults" | "column_defaults" => Ok(GeneratedField::ColumnDefaults), + "input" => Ok(GeneratedField::Input), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -4859,152 +4542,36 @@ impl<'de> serde::Deserialize<'de> for CreateExternalTableNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CreateExternalTableNode; + type Value = DistinctNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CreateExternalTableNode") + formatter.write_str("struct datafusion.DistinctNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut name__ = None; - let mut location__ = None; - let mut file_type__ = None; - let mut has_header__ = None; - let mut schema__ = None; - let mut table_partition_cols__ = None; - let mut if_not_exists__ = None; - let mut delimiter__ = None; - let mut definition__ = None; - let mut file_compression_type__ = None; - let mut order_exprs__ = None; - let mut unbounded__ = None; - let mut options__ = None; - let mut constraints__ = None; - let mut column_defaults__ = None; + let mut input__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Name => { - if name__.is_some() { - return Err(serde::de::Error::duplicate_field("name")); - } - name__ = map_.next_value()?; - } - GeneratedField::Location => { - if location__.is_some() { - return Err(serde::de::Error::duplicate_field("location")); - } - location__ = Some(map_.next_value()?); - } - GeneratedField::FileType => { - if file_type__.is_some() { - return Err(serde::de::Error::duplicate_field("fileType")); - } - file_type__ = Some(map_.next_value()?); - } - GeneratedField::HasHeader => { - if has_header__.is_some() { - return Err(serde::de::Error::duplicate_field("hasHeader")); - } - has_header__ = Some(map_.next_value()?); - } - GeneratedField::Schema => { - if schema__.is_some() { - return Err(serde::de::Error::duplicate_field("schema")); - } - schema__ = map_.next_value()?; - } - GeneratedField::TablePartitionCols => { - if table_partition_cols__.is_some() { - return Err(serde::de::Error::duplicate_field("tablePartitionCols")); - } - table_partition_cols__ = Some(map_.next_value()?); - } - GeneratedField::IfNotExists => { - if if_not_exists__.is_some() { - return Err(serde::de::Error::duplicate_field("ifNotExists")); - } - if_not_exists__ = Some(map_.next_value()?); - } - GeneratedField::Delimiter => { - if delimiter__.is_some() { - return Err(serde::de::Error::duplicate_field("delimiter")); - } - delimiter__ = Some(map_.next_value()?); - } - GeneratedField::Definition => { - if definition__.is_some() { - return Err(serde::de::Error::duplicate_field("definition")); - } - definition__ = Some(map_.next_value()?); - } - GeneratedField::FileCompressionType => { - if file_compression_type__.is_some() { - return Err(serde::de::Error::duplicate_field("fileCompressionType")); - } - file_compression_type__ = Some(map_.next_value::()? as i32); - } - GeneratedField::OrderExprs => { - if order_exprs__.is_some() { - return Err(serde::de::Error::duplicate_field("orderExprs")); - } - order_exprs__ = Some(map_.next_value()?); - } - GeneratedField::Unbounded => { - if unbounded__.is_some() { - return Err(serde::de::Error::duplicate_field("unbounded")); - } - unbounded__ = Some(map_.next_value()?); - } - GeneratedField::Options => { - if options__.is_some() { - return Err(serde::de::Error::duplicate_field("options")); - } - options__ = Some( - map_.next_value::>()? - ); - } - GeneratedField::Constraints => { - if constraints__.is_some() { - return Err(serde::de::Error::duplicate_field("constraints")); - } - constraints__ = map_.next_value()?; - } - GeneratedField::ColumnDefaults => { - if column_defaults__.is_some() { - return Err(serde::de::Error::duplicate_field("columnDefaults")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - column_defaults__ = Some( - map_.next_value::>()? - ); + input__ = map_.next_value()?; } } } - Ok(CreateExternalTableNode { - name: name__, - location: location__.unwrap_or_default(), - file_type: file_type__.unwrap_or_default(), - has_header: has_header__.unwrap_or_default(), - schema: schema__, - table_partition_cols: table_partition_cols__.unwrap_or_default(), - if_not_exists: if_not_exists__.unwrap_or_default(), - delimiter: delimiter__.unwrap_or_default(), - definition: definition__.unwrap_or_default(), - file_compression_type: file_compression_type__.unwrap_or_default(), - order_exprs: order_exprs__.unwrap_or_default(), - unbounded: unbounded__.unwrap_or_default(), - options: options__.unwrap_or_default(), - constraints: constraints__, - column_defaults: column_defaults__.unwrap_or_default(), + Ok(DistinctNode { + input: input__, }) } } - deserializer.deserialize_struct("datafusion.CreateExternalTableNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.DistinctNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CreateViewNode { +impl serde::Serialize for DistinctOnNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -5012,54 +4579,56 @@ impl serde::Serialize for CreateViewNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.name.is_some() { + if !self.on_expr.is_empty() { len += 1; } - if self.input.is_some() { + if !self.select_expr.is_empty() { len += 1; } - if self.or_replace { + if !self.sort_expr.is_empty() { len += 1; } - if !self.definition.is_empty() { + if self.input.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CreateViewNode", len)?; - if let Some(v) = self.name.as_ref() { - struct_ser.serialize_field("name", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.DistinctOnNode", len)?; + if !self.on_expr.is_empty() { + struct_ser.serialize_field("onExpr", &self.on_expr)?; } - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; + if !self.select_expr.is_empty() { + struct_ser.serialize_field("selectExpr", &self.select_expr)?; } - if self.or_replace { - struct_ser.serialize_field("orReplace", &self.or_replace)?; + if !self.sort_expr.is_empty() { + struct_ser.serialize_field("sortExpr", &self.sort_expr)?; } - if !self.definition.is_empty() { - struct_ser.serialize_field("definition", &self.definition)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CreateViewNode { +impl<'de> serde::Deserialize<'de> for DistinctOnNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "name", + "on_expr", + "onExpr", + "select_expr", + "selectExpr", + "sort_expr", + "sortExpr", "input", - "or_replace", - "orReplace", - "definition", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Name, + OnExpr, + SelectExpr, + SortExpr, Input, - OrReplace, - Definition, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -5081,10 +4650,10 @@ impl<'de> serde::Deserialize<'de> for CreateViewNode { E: serde::de::Error, { match value { - "name" => Ok(GeneratedField::Name), + "onExpr" | "on_expr" => Ok(GeneratedField::OnExpr), + "selectExpr" | "select_expr" => Ok(GeneratedField::SelectExpr), + "sortExpr" | "sort_expr" => Ok(GeneratedField::SortExpr), "input" => Ok(GeneratedField::Input), - "orReplace" | "or_replace" => Ok(GeneratedField::OrReplace), - "definition" => Ok(GeneratedField::Definition), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -5094,60 +4663,60 @@ impl<'de> serde::Deserialize<'de> for CreateViewNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CreateViewNode; + type Value = DistinctOnNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CreateViewNode") + formatter.write_str("struct datafusion.DistinctOnNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut name__ = None; + let mut on_expr__ = None; + let mut select_expr__ = None; + let mut sort_expr__ = None; let mut input__ = None; - let mut or_replace__ = None; - let mut definition__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Name => { - if name__.is_some() { - return Err(serde::de::Error::duplicate_field("name")); + GeneratedField::OnExpr => { + if on_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("onExpr")); } - name__ = map_.next_value()?; + on_expr__ = Some(map_.next_value()?); } - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); + GeneratedField::SelectExpr => { + if select_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("selectExpr")); } - input__ = map_.next_value()?; + select_expr__ = Some(map_.next_value()?); } - GeneratedField::OrReplace => { - if or_replace__.is_some() { - return Err(serde::de::Error::duplicate_field("orReplace")); + GeneratedField::SortExpr => { + if sort_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("sortExpr")); } - or_replace__ = Some(map_.next_value()?); + sort_expr__ = Some(map_.next_value()?); } - GeneratedField::Definition => { - if definition__.is_some() { - return Err(serde::de::Error::duplicate_field("definition")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - definition__ = Some(map_.next_value()?); + input__ = map_.next_value()?; } } } - Ok(CreateViewNode { - name: name__, + Ok(DistinctOnNode { + on_expr: on_expr__.unwrap_or_default(), + select_expr: select_expr__.unwrap_or_default(), + sort_expr: sort_expr__.unwrap_or_default(), input: input__, - or_replace: or_replace__.unwrap_or_default(), - definition: definition__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.CreateViewNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.DistinctOnNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CrossJoinExecNode { +impl serde::Serialize for DropViewNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -5155,37 +4724,46 @@ impl serde::Serialize for CrossJoinExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.left.is_some() { + if self.name.is_some() { len += 1; } - if self.right.is_some() { + if self.if_exists { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CrossJoinExecNode", len)?; - if let Some(v) = self.left.as_ref() { - struct_ser.serialize_field("left", v)?; + if self.schema.is_some() { + len += 1; } - if let Some(v) = self.right.as_ref() { - struct_ser.serialize_field("right", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.DropViewNode", len)?; + if let Some(v) = self.name.as_ref() { + struct_ser.serialize_field("name", v)?; + } + if self.if_exists { + struct_ser.serialize_field("ifExists", &self.if_exists)?; + } + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CrossJoinExecNode { +impl<'de> serde::Deserialize<'de> for DropViewNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "left", - "right", - ]; - - #[allow(clippy::enum_variant_names)] + "name", + "if_exists", + "ifExists", + "schema", + ]; + + #[allow(clippy::enum_variant_names)] enum GeneratedField { - Left, - Right, + Name, + IfExists, + Schema, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -5207,8 +4785,9 @@ impl<'de> serde::Deserialize<'de> for CrossJoinExecNode { E: serde::de::Error, { match value { - "left" => Ok(GeneratedField::Left), - "right" => Ok(GeneratedField::Right), + "name" => Ok(GeneratedField::Name), + "ifExists" | "if_exists" => Ok(GeneratedField::IfExists), + "schema" => Ok(GeneratedField::Schema), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -5218,44 +4797,52 @@ impl<'de> serde::Deserialize<'de> for CrossJoinExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CrossJoinExecNode; + type Value = DropViewNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CrossJoinExecNode") + formatter.write_str("struct datafusion.DropViewNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut left__ = None; - let mut right__ = None; + let mut name__ = None; + let mut if_exists__ = None; + let mut schema__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Left => { - if left__.is_some() { - return Err(serde::de::Error::duplicate_field("left")); + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); } - left__ = map_.next_value()?; + name__ = map_.next_value()?; } - GeneratedField::Right => { - if right__.is_some() { - return Err(serde::de::Error::duplicate_field("right")); + GeneratedField::IfExists => { + if if_exists__.is_some() { + return Err(serde::de::Error::duplicate_field("ifExists")); } - right__ = map_.next_value()?; + if_exists__ = Some(map_.next_value()?); + } + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); + } + schema__ = map_.next_value()?; } } } - Ok(CrossJoinExecNode { - left: left__, - right: right__, + Ok(DropViewNode { + name: name__, + if_exists: if_exists__.unwrap_or_default(), + schema: schema__, }) } } - deserializer.deserialize_struct("datafusion.CrossJoinExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.DropViewNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CrossJoinNode { +impl serde::Serialize for EmptyExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -5263,37 +4850,29 @@ impl serde::Serialize for CrossJoinNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.left.is_some() { - len += 1; - } - if self.right.is_some() { + if self.schema.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CrossJoinNode", len)?; - if let Some(v) = self.left.as_ref() { - struct_ser.serialize_field("left", v)?; - } - if let Some(v) = self.right.as_ref() { - struct_ser.serialize_field("right", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.EmptyExecNode", len)?; + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CrossJoinNode { +impl<'de> serde::Deserialize<'de> for EmptyExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "left", - "right", + "schema", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Left, - Right, + Schema, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -5315,8 +4894,7 @@ impl<'de> serde::Deserialize<'de> for CrossJoinNode { E: serde::de::Error, { match value { - "left" => Ok(GeneratedField::Left), - "right" => Ok(GeneratedField::Right), + "schema" => Ok(GeneratedField::Schema), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -5326,44 +4904,36 @@ impl<'de> serde::Deserialize<'de> for CrossJoinNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CrossJoinNode; + type Value = EmptyExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CrossJoinNode") + formatter.write_str("struct datafusion.EmptyExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut left__ = None; - let mut right__ = None; + let mut schema__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Left => { - if left__.is_some() { - return Err(serde::de::Error::duplicate_field("left")); - } - left__ = map_.next_value()?; - } - GeneratedField::Right => { - if right__.is_some() { - return Err(serde::de::Error::duplicate_field("right")); + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); } - right__ = map_.next_value()?; + schema__ = map_.next_value()?; } } } - Ok(CrossJoinNode { - left: left__, - right: right__, + Ok(EmptyExecNode { + schema: schema__, }) } } - deserializer.deserialize_struct("datafusion.CrossJoinNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.EmptyExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CsvFormat { +impl serde::Serialize for EmptyRelationNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -5371,29 +4941,30 @@ impl serde::Serialize for CsvFormat { { use serde::ser::SerializeStruct; let mut len = 0; - if self.options.is_some() { + if self.produce_one_row { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CsvFormat", len)?; - if let Some(v) = self.options.as_ref() { - struct_ser.serialize_field("options", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.EmptyRelationNode", len)?; + if self.produce_one_row { + struct_ser.serialize_field("produceOneRow", &self.produce_one_row)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CsvFormat { +impl<'de> serde::Deserialize<'de> for EmptyRelationNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "options", + "produce_one_row", + "produceOneRow", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Options, + ProduceOneRow, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -5415,7 +4986,7 @@ impl<'de> serde::Deserialize<'de> for CsvFormat { E: serde::de::Error, { match value { - "options" => Ok(GeneratedField::Options), + "produceOneRow" | "produce_one_row" => Ok(GeneratedField::ProduceOneRow), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -5425,36 +4996,36 @@ impl<'de> serde::Deserialize<'de> for CsvFormat { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CsvFormat; + type Value = EmptyRelationNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CsvFormat") + formatter.write_str("struct datafusion.EmptyRelationNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut options__ = None; + let mut produce_one_row__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Options => { - if options__.is_some() { - return Err(serde::de::Error::duplicate_field("options")); + GeneratedField::ProduceOneRow => { + if produce_one_row__.is_some() { + return Err(serde::de::Error::duplicate_field("produceOneRow")); } - options__ = map_.next_value()?; + produce_one_row__ = Some(map_.next_value()?); } } } - Ok(CsvFormat { - options: options__, + Ok(EmptyRelationNode { + produce_one_row: produce_one_row__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.CsvFormat", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.EmptyRelationNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CsvOptions { +impl serde::Serialize for ExplainExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -5462,131 +5033,46 @@ impl serde::Serialize for CsvOptions { { use serde::ser::SerializeStruct; let mut len = 0; - if self.has_header { - len += 1; - } - if !self.delimiter.is_empty() { - len += 1; - } - if !self.quote.is_empty() { - len += 1; - } - if !self.escape.is_empty() { - len += 1; - } - if self.compression != 0 { - len += 1; - } - if self.schema_infer_max_rec != 0 { - len += 1; - } - if !self.date_format.is_empty() { - len += 1; - } - if !self.datetime_format.is_empty() { - len += 1; - } - if !self.timestamp_format.is_empty() { - len += 1; - } - if !self.timestamp_tz_format.is_empty() { + if self.schema.is_some() { len += 1; } - if !self.time_format.is_empty() { + if !self.stringified_plans.is_empty() { len += 1; } - if !self.null_value.is_empty() { + if self.verbose { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CsvOptions", len)?; - if self.has_header { - struct_ser.serialize_field("hasHeader", &self.has_header)?; - } - if !self.delimiter.is_empty() { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("delimiter", pbjson::private::base64::encode(&self.delimiter).as_str())?; - } - if !self.quote.is_empty() { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("quote", pbjson::private::base64::encode(&self.quote).as_str())?; - } - if !self.escape.is_empty() { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("escape", pbjson::private::base64::encode(&self.escape).as_str())?; - } - if self.compression != 0 { - let v = CompressionTypeVariant::try_from(self.compression) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.compression)))?; - struct_ser.serialize_field("compression", &v)?; - } - if self.schema_infer_max_rec != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("schemaInferMaxRec", ToString::to_string(&self.schema_infer_max_rec).as_str())?; - } - if !self.date_format.is_empty() { - struct_ser.serialize_field("dateFormat", &self.date_format)?; - } - if !self.datetime_format.is_empty() { - struct_ser.serialize_field("datetimeFormat", &self.datetime_format)?; - } - if !self.timestamp_format.is_empty() { - struct_ser.serialize_field("timestampFormat", &self.timestamp_format)?; - } - if !self.timestamp_tz_format.is_empty() { - struct_ser.serialize_field("timestampTzFormat", &self.timestamp_tz_format)?; + let mut struct_ser = serializer.serialize_struct("datafusion.ExplainExecNode", len)?; + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; } - if !self.time_format.is_empty() { - struct_ser.serialize_field("timeFormat", &self.time_format)?; + if !self.stringified_plans.is_empty() { + struct_ser.serialize_field("stringifiedPlans", &self.stringified_plans)?; } - if !self.null_value.is_empty() { - struct_ser.serialize_field("nullValue", &self.null_value)?; + if self.verbose { + struct_ser.serialize_field("verbose", &self.verbose)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CsvOptions { +impl<'de> serde::Deserialize<'de> for ExplainExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "has_header", - "hasHeader", - "delimiter", - "quote", - "escape", - "compression", - "schema_infer_max_rec", - "schemaInferMaxRec", - "date_format", - "dateFormat", - "datetime_format", - "datetimeFormat", - "timestamp_format", - "timestampFormat", - "timestamp_tz_format", - "timestampTzFormat", - "time_format", - "timeFormat", - "null_value", - "nullValue", + "schema", + "stringified_plans", + "stringifiedPlans", + "verbose", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - HasHeader, - Delimiter, - Quote, - Escape, - Compression, - SchemaInferMaxRec, - DateFormat, - DatetimeFormat, - TimestampFormat, - TimestampTzFormat, - TimeFormat, - NullValue, + Schema, + StringifiedPlans, + Verbose, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -5608,18 +5094,9 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { E: serde::de::Error, { match value { - "hasHeader" | "has_header" => Ok(GeneratedField::HasHeader), - "delimiter" => Ok(GeneratedField::Delimiter), - "quote" => Ok(GeneratedField::Quote), - "escape" => Ok(GeneratedField::Escape), - "compression" => Ok(GeneratedField::Compression), - "schemaInferMaxRec" | "schema_infer_max_rec" => Ok(GeneratedField::SchemaInferMaxRec), - "dateFormat" | "date_format" => Ok(GeneratedField::DateFormat), - "datetimeFormat" | "datetime_format" => Ok(GeneratedField::DatetimeFormat), - "timestampFormat" | "timestamp_format" => Ok(GeneratedField::TimestampFormat), - "timestampTzFormat" | "timestamp_tz_format" => Ok(GeneratedField::TimestampTzFormat), - "timeFormat" | "time_format" => Ok(GeneratedField::TimeFormat), - "nullValue" | "null_value" => Ok(GeneratedField::NullValue), + "schema" => Ok(GeneratedField::Schema), + "stringifiedPlans" | "stringified_plans" => Ok(GeneratedField::StringifiedPlans), + "verbose" => Ok(GeneratedField::Verbose), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -5629,132 +5106,52 @@ impl<'de> serde::Deserialize<'de> for CsvOptions { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CsvOptions; + type Value = ExplainExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CsvOptions") + formatter.write_str("struct datafusion.ExplainExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut has_header__ = None; - let mut delimiter__ = None; - let mut quote__ = None; - let mut escape__ = None; - let mut compression__ = None; - let mut schema_infer_max_rec__ = None; - let mut date_format__ = None; - let mut datetime_format__ = None; - let mut timestamp_format__ = None; - let mut timestamp_tz_format__ = None; - let mut time_format__ = None; - let mut null_value__ = None; + let mut schema__ = None; + let mut stringified_plans__ = None; + let mut verbose__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::HasHeader => { - if has_header__.is_some() { - return Err(serde::de::Error::duplicate_field("hasHeader")); + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); } - has_header__ = Some(map_.next_value()?); + schema__ = map_.next_value()?; } - GeneratedField::Delimiter => { - if delimiter__.is_some() { - return Err(serde::de::Error::duplicate_field("delimiter")); + GeneratedField::StringifiedPlans => { + if stringified_plans__.is_some() { + return Err(serde::de::Error::duplicate_field("stringifiedPlans")); } - delimiter__ = - Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) - ; + stringified_plans__ = Some(map_.next_value()?); } - GeneratedField::Quote => { - if quote__.is_some() { - return Err(serde::de::Error::duplicate_field("quote")); + GeneratedField::Verbose => { + if verbose__.is_some() { + return Err(serde::de::Error::duplicate_field("verbose")); } - quote__ = - Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) - ; - } - GeneratedField::Escape => { - if escape__.is_some() { - return Err(serde::de::Error::duplicate_field("escape")); - } - escape__ = - Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) - ; - } - GeneratedField::Compression => { - if compression__.is_some() { - return Err(serde::de::Error::duplicate_field("compression")); - } - compression__ = Some(map_.next_value::()? as i32); - } - GeneratedField::SchemaInferMaxRec => { - if schema_infer_max_rec__.is_some() { - return Err(serde::de::Error::duplicate_field("schemaInferMaxRec")); - } - schema_infer_max_rec__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::DateFormat => { - if date_format__.is_some() { - return Err(serde::de::Error::duplicate_field("dateFormat")); - } - date_format__ = Some(map_.next_value()?); - } - GeneratedField::DatetimeFormat => { - if datetime_format__.is_some() { - return Err(serde::de::Error::duplicate_field("datetimeFormat")); - } - datetime_format__ = Some(map_.next_value()?); - } - GeneratedField::TimestampFormat => { - if timestamp_format__.is_some() { - return Err(serde::de::Error::duplicate_field("timestampFormat")); - } - timestamp_format__ = Some(map_.next_value()?); - } - GeneratedField::TimestampTzFormat => { - if timestamp_tz_format__.is_some() { - return Err(serde::de::Error::duplicate_field("timestampTzFormat")); - } - timestamp_tz_format__ = Some(map_.next_value()?); - } - GeneratedField::TimeFormat => { - if time_format__.is_some() { - return Err(serde::de::Error::duplicate_field("timeFormat")); - } - time_format__ = Some(map_.next_value()?); - } - GeneratedField::NullValue => { - if null_value__.is_some() { - return Err(serde::de::Error::duplicate_field("nullValue")); - } - null_value__ = Some(map_.next_value()?); + verbose__ = Some(map_.next_value()?); } } } - Ok(CsvOptions { - has_header: has_header__.unwrap_or_default(), - delimiter: delimiter__.unwrap_or_default(), - quote: quote__.unwrap_or_default(), - escape: escape__.unwrap_or_default(), - compression: compression__.unwrap_or_default(), - schema_infer_max_rec: schema_infer_max_rec__.unwrap_or_default(), - date_format: date_format__.unwrap_or_default(), - datetime_format: datetime_format__.unwrap_or_default(), - timestamp_format: timestamp_format__.unwrap_or_default(), - timestamp_tz_format: timestamp_tz_format__.unwrap_or_default(), - time_format: time_format__.unwrap_or_default(), - null_value: null_value__.unwrap_or_default(), + Ok(ExplainExecNode { + schema: schema__, + stringified_plans: stringified_plans__.unwrap_or_default(), + verbose: verbose__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.CsvOptions", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ExplainExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CsvScanExecNode { +impl serde::Serialize for ExplainNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -5762,67 +5159,37 @@ impl serde::Serialize for CsvScanExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.base_conf.is_some() { - len += 1; - } - if self.has_header { - len += 1; - } - if !self.delimiter.is_empty() { - len += 1; - } - if !self.quote.is_empty() { + if self.input.is_some() { len += 1; } - if self.optional_escape.is_some() { + if self.verbose { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CsvScanExecNode", len)?; - if let Some(v) = self.base_conf.as_ref() { - struct_ser.serialize_field("baseConf", v)?; - } - if self.has_header { - struct_ser.serialize_field("hasHeader", &self.has_header)?; - } - if !self.delimiter.is_empty() { - struct_ser.serialize_field("delimiter", &self.delimiter)?; - } - if !self.quote.is_empty() { - struct_ser.serialize_field("quote", &self.quote)?; + let mut struct_ser = serializer.serialize_struct("datafusion.ExplainNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; } - if let Some(v) = self.optional_escape.as_ref() { - match v { - csv_scan_exec_node::OptionalEscape::Escape(v) => { - struct_ser.serialize_field("escape", v)?; - } - } + if self.verbose { + struct_ser.serialize_field("verbose", &self.verbose)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CsvScanExecNode { +impl<'de> serde::Deserialize<'de> for ExplainNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "base_conf", - "baseConf", - "has_header", - "hasHeader", - "delimiter", - "quote", - "escape", + "input", + "verbose", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - BaseConf, - HasHeader, - Delimiter, - Quote, - Escape, + Input, + Verbose, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -5844,11 +5211,8 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { E: serde::de::Error, { match value { - "baseConf" | "base_conf" => Ok(GeneratedField::BaseConf), - "hasHeader" | "has_header" => Ok(GeneratedField::HasHeader), - "delimiter" => Ok(GeneratedField::Delimiter), - "quote" => Ok(GeneratedField::Quote), - "escape" => Ok(GeneratedField::Escape), + "input" => Ok(GeneratedField::Input), + "verbose" => Ok(GeneratedField::Verbose), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -5858,68 +5222,44 @@ impl<'de> serde::Deserialize<'de> for CsvScanExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CsvScanExecNode; + type Value = ExplainNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CsvScanExecNode") + formatter.write_str("struct datafusion.ExplainNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut base_conf__ = None; - let mut has_header__ = None; - let mut delimiter__ = None; - let mut quote__ = None; - let mut optional_escape__ = None; + let mut input__ = None; + let mut verbose__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::BaseConf => { - if base_conf__.is_some() { - return Err(serde::de::Error::duplicate_field("baseConf")); - } - base_conf__ = map_.next_value()?; - } - GeneratedField::HasHeader => { - if has_header__.is_some() { - return Err(serde::de::Error::duplicate_field("hasHeader")); - } - has_header__ = Some(map_.next_value()?); - } - GeneratedField::Delimiter => { - if delimiter__.is_some() { - return Err(serde::de::Error::duplicate_field("delimiter")); - } - delimiter__ = Some(map_.next_value()?); - } - GeneratedField::Quote => { - if quote__.is_some() { - return Err(serde::de::Error::duplicate_field("quote")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - quote__ = Some(map_.next_value()?); + input__ = map_.next_value()?; } - GeneratedField::Escape => { - if optional_escape__.is_some() { - return Err(serde::de::Error::duplicate_field("escape")); + GeneratedField::Verbose => { + if verbose__.is_some() { + return Err(serde::de::Error::duplicate_field("verbose")); } - optional_escape__ = map_.next_value::<::std::option::Option<_>>()?.map(csv_scan_exec_node::OptionalEscape::Escape); + verbose__ = Some(map_.next_value()?); } } } - Ok(CsvScanExecNode { - base_conf: base_conf__, - has_header: has_header__.unwrap_or_default(), - delimiter: delimiter__.unwrap_or_default(), - quote: quote__.unwrap_or_default(), - optional_escape: optional_escape__, + Ok(ExplainNode { + input: input__, + verbose: verbose__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.CsvScanExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ExplainNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CsvSink { +impl serde::Serialize for FileGroup { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -5927,38 +5267,29 @@ impl serde::Serialize for CsvSink { { use serde::ser::SerializeStruct; let mut len = 0; - if self.config.is_some() { - len += 1; - } - if self.writer_options.is_some() { + if !self.files.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CsvSink", len)?; - if let Some(v) = self.config.as_ref() { - struct_ser.serialize_field("config", v)?; - } - if let Some(v) = self.writer_options.as_ref() { - struct_ser.serialize_field("writerOptions", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.FileGroup", len)?; + if !self.files.is_empty() { + struct_ser.serialize_field("files", &self.files)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CsvSink { +impl<'de> serde::Deserialize<'de> for FileGroup { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "config", - "writer_options", - "writerOptions", + "files", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Config, - WriterOptions, + Files, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -5980,8 +5311,7 @@ impl<'de> serde::Deserialize<'de> for CsvSink { E: serde::de::Error, { match value { - "config" => Ok(GeneratedField::Config), - "writerOptions" | "writer_options" => Ok(GeneratedField::WriterOptions), + "files" => Ok(GeneratedField::Files), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -5991,44 +5321,36 @@ impl<'de> serde::Deserialize<'de> for CsvSink { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CsvSink; + type Value = FileGroup; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CsvSink") + formatter.write_str("struct datafusion.FileGroup") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut config__ = None; - let mut writer_options__ = None; + let mut files__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Config => { - if config__.is_some() { - return Err(serde::de::Error::duplicate_field("config")); - } - config__ = map_.next_value()?; - } - GeneratedField::WriterOptions => { - if writer_options__.is_some() { - return Err(serde::de::Error::duplicate_field("writerOptions")); + GeneratedField::Files => { + if files__.is_some() { + return Err(serde::de::Error::duplicate_field("files")); } - writer_options__ = map_.next_value()?; + files__ = Some(map_.next_value()?); } } } - Ok(CsvSink { - config: config__, - writer_options: writer_options__, + Ok(FileGroup { + files: files__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.CsvSink", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.FileGroup", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CsvSinkExecNode { +impl serde::Serialize for FileRange { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -6036,55 +5358,41 @@ impl serde::Serialize for CsvSinkExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.input.is_some() { - len += 1; - } - if self.sink.is_some() { - len += 1; - } - if self.sink_schema.is_some() { + if self.start != 0 { len += 1; } - if self.sort_order.is_some() { + if self.end != 0 { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CsvSinkExecNode", len)?; - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; - } - if let Some(v) = self.sink.as_ref() { - struct_ser.serialize_field("sink", v)?; - } - if let Some(v) = self.sink_schema.as_ref() { - struct_ser.serialize_field("sinkSchema", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.FileRange", len)?; + if self.start != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("start", ToString::to_string(&self.start).as_str())?; } - if let Some(v) = self.sort_order.as_ref() { - struct_ser.serialize_field("sortOrder", v)?; + if self.end != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("end", ToString::to_string(&self.end).as_str())?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CsvSinkExecNode { +impl<'de> serde::Deserialize<'de> for FileRange { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "input", - "sink", - "sink_schema", - "sinkSchema", - "sort_order", - "sortOrder", + "start", + "end", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Input, - Sink, - SinkSchema, - SortOrder, + Start, + End, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -6106,10 +5414,8 @@ impl<'de> serde::Deserialize<'de> for CsvSinkExecNode { E: serde::de::Error, { match value { - "input" => Ok(GeneratedField::Input), - "sink" => Ok(GeneratedField::Sink), - "sinkSchema" | "sink_schema" => Ok(GeneratedField::SinkSchema), - "sortOrder" | "sort_order" => Ok(GeneratedField::SortOrder), + "start" => Ok(GeneratedField::Start), + "end" => Ok(GeneratedField::End), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -6119,60 +5425,48 @@ impl<'de> serde::Deserialize<'de> for CsvSinkExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CsvSinkExecNode; + type Value = FileRange; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CsvSinkExecNode") + formatter.write_str("struct datafusion.FileRange") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut input__ = None; - let mut sink__ = None; - let mut sink_schema__ = None; - let mut sort_order__ = None; + let mut start__ = None; + let mut end__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); - } - input__ = map_.next_value()?; - } - GeneratedField::Sink => { - if sink__.is_some() { - return Err(serde::de::Error::duplicate_field("sink")); - } - sink__ = map_.next_value()?; - } - GeneratedField::SinkSchema => { - if sink_schema__.is_some() { - return Err(serde::de::Error::duplicate_field("sinkSchema")); + GeneratedField::Start => { + if start__.is_some() { + return Err(serde::de::Error::duplicate_field("start")); } - sink_schema__ = map_.next_value()?; + start__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } - GeneratedField::SortOrder => { - if sort_order__.is_some() { - return Err(serde::de::Error::duplicate_field("sortOrder")); + GeneratedField::End => { + if end__.is_some() { + return Err(serde::de::Error::duplicate_field("end")); } - sort_order__ = map_.next_value()?; + end__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } } } - Ok(CsvSinkExecNode { - input: input__, - sink: sink__, - sink_schema: sink_schema__, - sort_order: sort_order__, + Ok(FileRange { + start: start__.unwrap_or_default(), + end: end__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.CsvSinkExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.FileRange", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CsvWriterOptions { +impl serde::Serialize for FileScanExecConf { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -6180,93 +5474,89 @@ impl serde::Serialize for CsvWriterOptions { { use serde::ser::SerializeStruct; let mut len = 0; - if self.compression != 0 { + if !self.file_groups.is_empty() { len += 1; } - if !self.delimiter.is_empty() { + if self.schema.is_some() { len += 1; } - if self.has_header { + if !self.projection.is_empty() { len += 1; } - if !self.date_format.is_empty() { + if self.limit.is_some() { len += 1; } - if !self.datetime_format.is_empty() { + if self.statistics.is_some() { len += 1; } - if !self.timestamp_format.is_empty() { + if !self.table_partition_cols.is_empty() { len += 1; } - if !self.time_format.is_empty() { + if !self.object_store_url.is_empty() { len += 1; } - if !self.null_value.is_empty() { + if !self.output_ordering.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CsvWriterOptions", len)?; - if self.compression != 0 { - let v = CompressionTypeVariant::try_from(self.compression) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.compression)))?; - struct_ser.serialize_field("compression", &v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.FileScanExecConf", len)?; + if !self.file_groups.is_empty() { + struct_ser.serialize_field("fileGroups", &self.file_groups)?; } - if !self.delimiter.is_empty() { - struct_ser.serialize_field("delimiter", &self.delimiter)?; + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; } - if self.has_header { - struct_ser.serialize_field("hasHeader", &self.has_header)?; + if !self.projection.is_empty() { + struct_ser.serialize_field("projection", &self.projection)?; } - if !self.date_format.is_empty() { - struct_ser.serialize_field("dateFormat", &self.date_format)?; + if let Some(v) = self.limit.as_ref() { + struct_ser.serialize_field("limit", v)?; } - if !self.datetime_format.is_empty() { - struct_ser.serialize_field("datetimeFormat", &self.datetime_format)?; + if let Some(v) = self.statistics.as_ref() { + struct_ser.serialize_field("statistics", v)?; } - if !self.timestamp_format.is_empty() { - struct_ser.serialize_field("timestampFormat", &self.timestamp_format)?; + if !self.table_partition_cols.is_empty() { + struct_ser.serialize_field("tablePartitionCols", &self.table_partition_cols)?; } - if !self.time_format.is_empty() { - struct_ser.serialize_field("timeFormat", &self.time_format)?; + if !self.object_store_url.is_empty() { + struct_ser.serialize_field("objectStoreUrl", &self.object_store_url)?; } - if !self.null_value.is_empty() { - struct_ser.serialize_field("nullValue", &self.null_value)?; + if !self.output_ordering.is_empty() { + struct_ser.serialize_field("outputOrdering", &self.output_ordering)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CsvWriterOptions { +impl<'de> serde::Deserialize<'de> for FileScanExecConf { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "compression", - "delimiter", - "has_header", - "hasHeader", - "date_format", - "dateFormat", - "datetime_format", - "datetimeFormat", - "timestamp_format", - "timestampFormat", - "time_format", - "timeFormat", - "null_value", - "nullValue", + "file_groups", + "fileGroups", + "schema", + "projection", + "limit", + "statistics", + "table_partition_cols", + "tablePartitionCols", + "object_store_url", + "objectStoreUrl", + "output_ordering", + "outputOrdering", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Compression, - Delimiter, - HasHeader, - DateFormat, - DatetimeFormat, - TimestampFormat, - TimeFormat, - NullValue, + FileGroups, + Schema, + Projection, + Limit, + Statistics, + TablePartitionCols, + ObjectStoreUrl, + OutputOrdering, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -6288,14 +5578,14 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { E: serde::de::Error, { match value { - "compression" => Ok(GeneratedField::Compression), - "delimiter" => Ok(GeneratedField::Delimiter), - "hasHeader" | "has_header" => Ok(GeneratedField::HasHeader), - "dateFormat" | "date_format" => Ok(GeneratedField::DateFormat), - "datetimeFormat" | "datetime_format" => Ok(GeneratedField::DatetimeFormat), - "timestampFormat" | "timestamp_format" => Ok(GeneratedField::TimestampFormat), - "timeFormat" | "time_format" => Ok(GeneratedField::TimeFormat), - "nullValue" | "null_value" => Ok(GeneratedField::NullValue), + "fileGroups" | "file_groups" => Ok(GeneratedField::FileGroups), + "schema" => Ok(GeneratedField::Schema), + "projection" => Ok(GeneratedField::Projection), + "limit" => Ok(GeneratedField::Limit), + "statistics" => Ok(GeneratedField::Statistics), + "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), + "objectStoreUrl" | "object_store_url" => Ok(GeneratedField::ObjectStoreUrl), + "outputOrdering" | "output_ordering" => Ok(GeneratedField::OutputOrdering), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -6305,92 +5595,95 @@ impl<'de> serde::Deserialize<'de> for CsvWriterOptions { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CsvWriterOptions; + type Value = FileScanExecConf; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CsvWriterOptions") + formatter.write_str("struct datafusion.FileScanExecConf") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut compression__ = None; - let mut delimiter__ = None; - let mut has_header__ = None; - let mut date_format__ = None; - let mut datetime_format__ = None; - let mut timestamp_format__ = None; - let mut time_format__ = None; - let mut null_value__ = None; + let mut file_groups__ = None; + let mut schema__ = None; + let mut projection__ = None; + let mut limit__ = None; + let mut statistics__ = None; + let mut table_partition_cols__ = None; + let mut object_store_url__ = None; + let mut output_ordering__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Compression => { - if compression__.is_some() { - return Err(serde::de::Error::duplicate_field("compression")); + GeneratedField::FileGroups => { + if file_groups__.is_some() { + return Err(serde::de::Error::duplicate_field("fileGroups")); } - compression__ = Some(map_.next_value::()? as i32); + file_groups__ = Some(map_.next_value()?); } - GeneratedField::Delimiter => { - if delimiter__.is_some() { - return Err(serde::de::Error::duplicate_field("delimiter")); + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); } - delimiter__ = Some(map_.next_value()?); + schema__ = map_.next_value()?; } - GeneratedField::HasHeader => { - if has_header__.is_some() { - return Err(serde::de::Error::duplicate_field("hasHeader")); + GeneratedField::Projection => { + if projection__.is_some() { + return Err(serde::de::Error::duplicate_field("projection")); } - has_header__ = Some(map_.next_value()?); + projection__ = + Some(map_.next_value::>>()? + .into_iter().map(|x| x.0).collect()) + ; } - GeneratedField::DateFormat => { - if date_format__.is_some() { - return Err(serde::de::Error::duplicate_field("dateFormat")); + GeneratedField::Limit => { + if limit__.is_some() { + return Err(serde::de::Error::duplicate_field("limit")); } - date_format__ = Some(map_.next_value()?); + limit__ = map_.next_value()?; } - GeneratedField::DatetimeFormat => { - if datetime_format__.is_some() { - return Err(serde::de::Error::duplicate_field("datetimeFormat")); + GeneratedField::Statistics => { + if statistics__.is_some() { + return Err(serde::de::Error::duplicate_field("statistics")); } - datetime_format__ = Some(map_.next_value()?); + statistics__ = map_.next_value()?; } - GeneratedField::TimestampFormat => { - if timestamp_format__.is_some() { - return Err(serde::de::Error::duplicate_field("timestampFormat")); + GeneratedField::TablePartitionCols => { + if table_partition_cols__.is_some() { + return Err(serde::de::Error::duplicate_field("tablePartitionCols")); } - timestamp_format__ = Some(map_.next_value()?); + table_partition_cols__ = Some(map_.next_value()?); } - GeneratedField::TimeFormat => { - if time_format__.is_some() { - return Err(serde::de::Error::duplicate_field("timeFormat")); + GeneratedField::ObjectStoreUrl => { + if object_store_url__.is_some() { + return Err(serde::de::Error::duplicate_field("objectStoreUrl")); } - time_format__ = Some(map_.next_value()?); + object_store_url__ = Some(map_.next_value()?); } - GeneratedField::NullValue => { - if null_value__.is_some() { - return Err(serde::de::Error::duplicate_field("nullValue")); + GeneratedField::OutputOrdering => { + if output_ordering__.is_some() { + return Err(serde::de::Error::duplicate_field("outputOrdering")); } - null_value__ = Some(map_.next_value()?); + output_ordering__ = Some(map_.next_value()?); } } } - Ok(CsvWriterOptions { - compression: compression__.unwrap_or_default(), - delimiter: delimiter__.unwrap_or_default(), - has_header: has_header__.unwrap_or_default(), - date_format: date_format__.unwrap_or_default(), - datetime_format: datetime_format__.unwrap_or_default(), - timestamp_format: timestamp_format__.unwrap_or_default(), - time_format: time_format__.unwrap_or_default(), - null_value: null_value__.unwrap_or_default(), + Ok(FileScanExecConf { + file_groups: file_groups__.unwrap_or_default(), + schema: schema__, + projection: projection__.unwrap_or_default(), + limit: limit__, + statistics: statistics__, + table_partition_cols: table_partition_cols__.unwrap_or_default(), + object_store_url: object_store_url__.unwrap_or_default(), + output_ordering: output_ordering__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.CsvWriterOptions", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.FileScanExecConf", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CubeNode { +impl serde::Serialize for FileSinkConfig { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -6398,33 +5691,90 @@ impl serde::Serialize for CubeNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.expr.is_empty() { + if !self.object_store_url.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CubeNode", len)?; - if !self.expr.is_empty() { - struct_ser.serialize_field("expr", &self.expr)?; + if !self.file_groups.is_empty() { + len += 1; + } + if !self.table_paths.is_empty() { + len += 1; + } + if self.output_schema.is_some() { + len += 1; + } + if !self.table_partition_cols.is_empty() { + len += 1; + } + if self.keep_partition_by_columns { + len += 1; + } + if self.insert_op != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.FileSinkConfig", len)?; + if !self.object_store_url.is_empty() { + struct_ser.serialize_field("objectStoreUrl", &self.object_store_url)?; + } + if !self.file_groups.is_empty() { + struct_ser.serialize_field("fileGroups", &self.file_groups)?; + } + if !self.table_paths.is_empty() { + struct_ser.serialize_field("tablePaths", &self.table_paths)?; + } + if let Some(v) = self.output_schema.as_ref() { + struct_ser.serialize_field("outputSchema", v)?; + } + if !self.table_partition_cols.is_empty() { + struct_ser.serialize_field("tablePartitionCols", &self.table_partition_cols)?; + } + if self.keep_partition_by_columns { + struct_ser.serialize_field("keepPartitionByColumns", &self.keep_partition_by_columns)?; + } + if self.insert_op != 0 { + let v = InsertOp::try_from(self.insert_op) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.insert_op)))?; + struct_ser.serialize_field("insertOp", &v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CubeNode { +impl<'de> serde::Deserialize<'de> for FileSinkConfig { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "object_store_url", + "objectStoreUrl", + "file_groups", + "fileGroups", + "table_paths", + "tablePaths", + "output_schema", + "outputSchema", + "table_partition_cols", + "tablePartitionCols", + "keep_partition_by_columns", + "keepPartitionByColumns", + "insert_op", + "insertOp", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where + ObjectStoreUrl, + FileGroups, + TablePaths, + OutputSchema, + TablePartitionCols, + KeepPartitionByColumns, + InsertOp, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where D: serde::Deserializer<'de>, { struct GeneratedVisitor; @@ -6442,7 +5792,13 @@ impl<'de> serde::Deserialize<'de> for CubeNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), + "objectStoreUrl" | "object_store_url" => Ok(GeneratedField::ObjectStoreUrl), + "fileGroups" | "file_groups" => Ok(GeneratedField::FileGroups), + "tablePaths" | "table_paths" => Ok(GeneratedField::TablePaths), + "outputSchema" | "output_schema" => Ok(GeneratedField::OutputSchema), + "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), + "keepPartitionByColumns" | "keep_partition_by_columns" => Ok(GeneratedField::KeepPartitionByColumns), + "insertOp" | "insert_op" => Ok(GeneratedField::InsertOp), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -6452,36 +5808,84 @@ impl<'de> serde::Deserialize<'de> for CubeNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CubeNode; + type Value = FileSinkConfig; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CubeNode") + formatter.write_str("struct datafusion.FileSinkConfig") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; + let mut object_store_url__ = None; + let mut file_groups__ = None; + let mut table_paths__ = None; + let mut output_schema__ = None; + let mut table_partition_cols__ = None; + let mut keep_partition_by_columns__ = None; + let mut insert_op__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::ObjectStoreUrl => { + if object_store_url__.is_some() { + return Err(serde::de::Error::duplicate_field("objectStoreUrl")); } - expr__ = Some(map_.next_value()?); + object_store_url__ = Some(map_.next_value()?); + } + GeneratedField::FileGroups => { + if file_groups__.is_some() { + return Err(serde::de::Error::duplicate_field("fileGroups")); + } + file_groups__ = Some(map_.next_value()?); + } + GeneratedField::TablePaths => { + if table_paths__.is_some() { + return Err(serde::de::Error::duplicate_field("tablePaths")); + } + table_paths__ = Some(map_.next_value()?); + } + GeneratedField::OutputSchema => { + if output_schema__.is_some() { + return Err(serde::de::Error::duplicate_field("outputSchema")); + } + output_schema__ = map_.next_value()?; + } + GeneratedField::TablePartitionCols => { + if table_partition_cols__.is_some() { + return Err(serde::de::Error::duplicate_field("tablePartitionCols")); + } + table_partition_cols__ = Some(map_.next_value()?); + } + GeneratedField::KeepPartitionByColumns => { + if keep_partition_by_columns__.is_some() { + return Err(serde::de::Error::duplicate_field("keepPartitionByColumns")); + } + keep_partition_by_columns__ = Some(map_.next_value()?); + } + GeneratedField::InsertOp => { + if insert_op__.is_some() { + return Err(serde::de::Error::duplicate_field("insertOp")); + } + insert_op__ = Some(map_.next_value::()? as i32); } } } - Ok(CubeNode { - expr: expr__.unwrap_or_default(), + Ok(FileSinkConfig { + object_store_url: object_store_url__.unwrap_or_default(), + file_groups: file_groups__.unwrap_or_default(), + table_paths: table_paths__.unwrap_or_default(), + output_schema: output_schema__, + table_partition_cols: table_partition_cols__.unwrap_or_default(), + keep_partition_by_columns: keep_partition_by_columns__.unwrap_or_default(), + insert_op: insert_op__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.CubeNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.FileSinkConfig", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for CustomTableScanNode { +impl serde::Serialize for FilterExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -6489,64 +5893,54 @@ impl serde::Serialize for CustomTableScanNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.table_name.is_some() { - len += 1; - } - if self.projection.is_some() { + if self.input.is_some() { len += 1; } - if self.schema.is_some() { + if self.expr.is_some() { len += 1; } - if !self.filters.is_empty() { + if self.default_filter_selectivity != 0 { len += 1; } - if !self.custom_table_data.is_empty() { + if !self.projection.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.CustomTableScanNode", len)?; - if let Some(v) = self.table_name.as_ref() { - struct_ser.serialize_field("tableName", v)?; - } - if let Some(v) = self.projection.as_ref() { - struct_ser.serialize_field("projection", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.FilterExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; } - if let Some(v) = self.schema.as_ref() { - struct_ser.serialize_field("schema", v)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } - if !self.filters.is_empty() { - struct_ser.serialize_field("filters", &self.filters)?; + if self.default_filter_selectivity != 0 { + struct_ser.serialize_field("defaultFilterSelectivity", &self.default_filter_selectivity)?; } - if !self.custom_table_data.is_empty() { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("customTableData", pbjson::private::base64::encode(&self.custom_table_data).as_str())?; + if !self.projection.is_empty() { + struct_ser.serialize_field("projection", &self.projection)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for CustomTableScanNode { +impl<'de> serde::Deserialize<'de> for FilterExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "table_name", - "tableName", + "input", + "expr", + "default_filter_selectivity", + "defaultFilterSelectivity", "projection", - "schema", - "filters", - "custom_table_data", - "customTableData", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - TableName, + Input, + Expr, + DefaultFilterSelectivity, Projection, - Schema, - Filters, - CustomTableData, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -6568,11 +5962,10 @@ impl<'de> serde::Deserialize<'de> for CustomTableScanNode { E: serde::de::Error, { match value { - "tableName" | "table_name" => Ok(GeneratedField::TableName), + "input" => Ok(GeneratedField::Input), + "expr" => Ok(GeneratedField::Expr), + "defaultFilterSelectivity" | "default_filter_selectivity" => Ok(GeneratedField::DefaultFilterSelectivity), "projection" => Ok(GeneratedField::Projection), - "schema" => Ok(GeneratedField::Schema), - "filters" => Ok(GeneratedField::Filters), - "customTableData" | "custom_table_data" => Ok(GeneratedField::CustomTableData), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -6582,186 +5975,102 @@ impl<'de> serde::Deserialize<'de> for CustomTableScanNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = CustomTableScanNode; + type Value = FilterExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.CustomTableScanNode") + formatter.write_str("struct datafusion.FilterExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut table_name__ = None; + let mut input__ = None; + let mut expr__ = None; + let mut default_filter_selectivity__ = None; let mut projection__ = None; - let mut schema__ = None; - let mut filters__ = None; - let mut custom_table_data__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::TableName => { - if table_name__.is_some() { - return Err(serde::de::Error::duplicate_field("tableName")); - } - table_name__ = map_.next_value()?; - } - GeneratedField::Projection => { - if projection__.is_some() { - return Err(serde::de::Error::duplicate_field("projection")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - projection__ = map_.next_value()?; + input__ = map_.next_value()?; } - GeneratedField::Schema => { - if schema__.is_some() { - return Err(serde::de::Error::duplicate_field("schema")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - schema__ = map_.next_value()?; + expr__ = map_.next_value()?; } - GeneratedField::Filters => { - if filters__.is_some() { - return Err(serde::de::Error::duplicate_field("filters")); + GeneratedField::DefaultFilterSelectivity => { + if default_filter_selectivity__.is_some() { + return Err(serde::de::Error::duplicate_field("defaultFilterSelectivity")); } - filters__ = Some(map_.next_value()?); + default_filter_selectivity__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } - GeneratedField::CustomTableData => { - if custom_table_data__.is_some() { - return Err(serde::de::Error::duplicate_field("customTableData")); + GeneratedField::Projection => { + if projection__.is_some() { + return Err(serde::de::Error::duplicate_field("projection")); } - custom_table_data__ = - Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + projection__ = + Some(map_.next_value::>>()? + .into_iter().map(|x| x.0).collect()) ; } } } - Ok(CustomTableScanNode { - table_name: table_name__, - projection: projection__, - schema: schema__, - filters: filters__.unwrap_or_default(), - custom_table_data: custom_table_data__.unwrap_or_default(), + Ok(FilterExecNode { + input: input__, + expr: expr__, + default_filter_selectivity: default_filter_selectivity__.unwrap_or_default(), + projection: projection__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.CustomTableScanNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.FilterExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for DateUnit { +impl serde::Serialize for FixedSizeBinary { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where S: serde::Serializer, { - let variant = match self { - Self::Day => "Day", - Self::DateMillisecond => "DateMillisecond", - }; - serializer.serialize_str(variant) + use serde::ser::SerializeStruct; + let mut len = 0; + if self.length != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.FixedSizeBinary", len)?; + if self.length != 0 { + struct_ser.serialize_field("length", &self.length)?; + } + struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for DateUnit { +impl<'de> serde::Deserialize<'de> for FixedSizeBinary { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "Day", - "DateMillisecond", + "length", ]; - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = DateUnit; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - fn visit_i64(self, v: i64) -> std::result::Result + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Length, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result where - E: serde::de::Error, + D: serde::Deserializer<'de>, { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) - }) - } - - fn visit_u64(self, v: u64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) - }) - } - - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "Day" => Ok(DateUnit::Day), - "DateMillisecond" => Ok(DateUnit::DateMillisecond), - _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), - } - } - } - deserializer.deserialize_any(GeneratedVisitor) - } -} -impl serde::Serialize for Decimal { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.precision != 0 { - len += 1; - } - if self.scale != 0 { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.Decimal", len)?; - if self.precision != 0 { - struct_ser.serialize_field("precision", &self.precision)?; - } - if self.scale != 0 { - struct_ser.serialize_field("scale", &self.scale)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for Decimal { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "precision", - "scale", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Precision, - Scale, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; + struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { type Value = GeneratedField; @@ -6776,8 +6085,7 @@ impl<'de> serde::Deserialize<'de> for Decimal { E: serde::de::Error, { match value { - "precision" => Ok(GeneratedField::Precision), - "scale" => Ok(GeneratedField::Scale), + "length" => Ok(GeneratedField::Length), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -6787,48 +6095,38 @@ impl<'de> serde::Deserialize<'de> for Decimal { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = Decimal; + type Value = FixedSizeBinary; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.Decimal") + formatter.write_str("struct datafusion.FixedSizeBinary") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut precision__ = None; - let mut scale__ = None; + let mut length__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Precision => { - if precision__.is_some() { - return Err(serde::de::Error::duplicate_field("precision")); - } - precision__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::Scale => { - if scale__.is_some() { - return Err(serde::de::Error::duplicate_field("scale")); + GeneratedField::Length => { + if length__.is_some() { + return Err(serde::de::Error::duplicate_field("length")); } - scale__ = + length__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } } } - Ok(Decimal { - precision: precision__.unwrap_or_default(), - scale: scale__.unwrap_or_default(), + Ok(FixedSizeBinary { + length: length__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.Decimal", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.FixedSizeBinary", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for Decimal128 { +impl serde::Serialize for FullTableReference { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -6836,48 +6134,45 @@ impl serde::Serialize for Decimal128 { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.value.is_empty() { + if !self.catalog.is_empty() { len += 1; } - if self.p != 0 { + if !self.schema.is_empty() { len += 1; } - if self.s != 0 { + if !self.table.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.Decimal128", len)?; - if !self.value.is_empty() { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("value", pbjson::private::base64::encode(&self.value).as_str())?; + let mut struct_ser = serializer.serialize_struct("datafusion.FullTableReference", len)?; + if !self.catalog.is_empty() { + struct_ser.serialize_field("catalog", &self.catalog)?; } - if self.p != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("p", ToString::to_string(&self.p).as_str())?; + if !self.schema.is_empty() { + struct_ser.serialize_field("schema", &self.schema)?; } - if self.s != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("s", ToString::to_string(&self.s).as_str())?; + if !self.table.is_empty() { + struct_ser.serialize_field("table", &self.table)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for Decimal128 { +impl<'de> serde::Deserialize<'de> for FullTableReference { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "value", - "p", - "s", + "catalog", + "schema", + "table", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Value, - P, - S, + Catalog, + Schema, + Table, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -6899,9 +6194,9 @@ impl<'de> serde::Deserialize<'de> for Decimal128 { E: serde::de::Error, { match value { - "value" => Ok(GeneratedField::Value), - "p" => Ok(GeneratedField::P), - "s" => Ok(GeneratedField::S), + "catalog" => Ok(GeneratedField::Catalog), + "schema" => Ok(GeneratedField::Schema), + "table" => Ok(GeneratedField::Table), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -6911,58 +6206,52 @@ impl<'de> serde::Deserialize<'de> for Decimal128 { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = Decimal128; + type Value = FullTableReference; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.Decimal128") + formatter.write_str("struct datafusion.FullTableReference") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut value__ = None; - let mut p__ = None; - let mut s__ = None; + let mut catalog__ = None; + let mut schema__ = None; + let mut table__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Value => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("value")); + GeneratedField::Catalog => { + if catalog__.is_some() { + return Err(serde::de::Error::duplicate_field("catalog")); } - value__ = - Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) - ; + catalog__ = Some(map_.next_value()?); } - GeneratedField::P => { - if p__.is_some() { - return Err(serde::de::Error::duplicate_field("p")); + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); } - p__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + schema__ = Some(map_.next_value()?); } - GeneratedField::S => { - if s__.is_some() { - return Err(serde::de::Error::duplicate_field("s")); + GeneratedField::Table => { + if table__.is_some() { + return Err(serde::de::Error::duplicate_field("table")); } - s__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + table__ = Some(map_.next_value()?); } } } - Ok(Decimal128 { - value: value__.unwrap_or_default(), - p: p__.unwrap_or_default(), - s: s__.unwrap_or_default(), + Ok(FullTableReference { + catalog: catalog__.unwrap_or_default(), + schema: schema__.unwrap_or_default(), + table: table__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.Decimal128", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.FullTableReference", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for Decimal256 { +impl serde::Serialize for GlobalLimitExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -6970,48 +6259,47 @@ impl serde::Serialize for Decimal256 { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.value.is_empty() { + if self.input.is_some() { len += 1; } - if self.p != 0 { + if self.skip != 0 { len += 1; } - if self.s != 0 { + if self.fetch != 0 { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.Decimal256", len)?; - if !self.value.is_empty() { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("value", pbjson::private::base64::encode(&self.value).as_str())?; + let mut struct_ser = serializer.serialize_struct("datafusion.GlobalLimitExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; } - if self.p != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("p", ToString::to_string(&self.p).as_str())?; + if self.skip != 0 { + struct_ser.serialize_field("skip", &self.skip)?; } - if self.s != 0 { + if self.fetch != 0 { #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("s", ToString::to_string(&self.s).as_str())?; + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("fetch", ToString::to_string(&self.fetch).as_str())?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for Decimal256 { +impl<'de> serde::Deserialize<'de> for GlobalLimitExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "value", - "p", - "s", + "input", + "skip", + "fetch", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Value, - P, - S, + Input, + Skip, + Fetch, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7033,9 +6321,9 @@ impl<'de> serde::Deserialize<'de> for Decimal256 { E: serde::de::Error, { match value { - "value" => Ok(GeneratedField::Value), - "p" => Ok(GeneratedField::P), - "s" => Ok(GeneratedField::S), + "input" => Ok(GeneratedField::Input), + "skip" => Ok(GeneratedField::Skip), + "fetch" => Ok(GeneratedField::Fetch), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7045,58 +6333,56 @@ impl<'de> serde::Deserialize<'de> for Decimal256 { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = Decimal256; + type Value = GlobalLimitExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.Decimal256") + formatter.write_str("struct datafusion.GlobalLimitExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut value__ = None; - let mut p__ = None; - let mut s__ = None; + let mut input__ = None; + let mut skip__ = None; + let mut fetch__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Value => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("value")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - value__ = - Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) - ; + input__ = map_.next_value()?; } - GeneratedField::P => { - if p__.is_some() { - return Err(serde::de::Error::duplicate_field("p")); + GeneratedField::Skip => { + if skip__.is_some() { + return Err(serde::de::Error::duplicate_field("skip")); } - p__ = + skip__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } - GeneratedField::S => { - if s__.is_some() { - return Err(serde::de::Error::duplicate_field("s")); + GeneratedField::Fetch => { + if fetch__.is_some() { + return Err(serde::de::Error::duplicate_field("fetch")); } - s__ = + fetch__ = Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } } } - Ok(Decimal256 { - value: value__.unwrap_or_default(), - p: p__.unwrap_or_default(), - s: s__.unwrap_or_default(), + Ok(GlobalLimitExecNode { + input: input__, + skip: skip__.unwrap_or_default(), + fetch: fetch__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.Decimal256", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.GlobalLimitExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for DfField { +impl serde::Serialize for GroupingSetNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -7104,37 +6390,29 @@ impl serde::Serialize for DfField { { use serde::ser::SerializeStruct; let mut len = 0; - if self.field.is_some() { - len += 1; - } - if self.qualifier.is_some() { + if !self.expr.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.DfField", len)?; - if let Some(v) = self.field.as_ref() { - struct_ser.serialize_field("field", v)?; - } - if let Some(v) = self.qualifier.as_ref() { - struct_ser.serialize_field("qualifier", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.GroupingSetNode", len)?; + if !self.expr.is_empty() { + struct_ser.serialize_field("expr", &self.expr)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for DfField { +impl<'de> serde::Deserialize<'de> for GroupingSetNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "field", - "qualifier", + "expr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Field, - Qualifier, + Expr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7156,8 +6434,7 @@ impl<'de> serde::Deserialize<'de> for DfField { E: serde::de::Error, { match value { - "field" => Ok(GeneratedField::Field), - "qualifier" => Ok(GeneratedField::Qualifier), + "expr" => Ok(GeneratedField::Expr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7167,44 +6444,36 @@ impl<'de> serde::Deserialize<'de> for DfField { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = DfField; + type Value = GroupingSetNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.DfField") + formatter.write_str("struct datafusion.GroupingSetNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut field__ = None; - let mut qualifier__ = None; + let mut expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Field => { - if field__.is_some() { - return Err(serde::de::Error::duplicate_field("field")); - } - field__ = map_.next_value()?; - } - GeneratedField::Qualifier => { - if qualifier__.is_some() { - return Err(serde::de::Error::duplicate_field("qualifier")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - qualifier__ = map_.next_value()?; + expr__ = Some(map_.next_value()?); } } } - Ok(DfField { - field: field__, - qualifier: qualifier__, + Ok(GroupingSetNode { + expr: expr__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.DfField", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.GroupingSetNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for DfSchema { +impl serde::Serialize for HashJoinExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -7212,147 +6481,92 @@ impl serde::Serialize for DfSchema { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.columns.is_empty() { + if self.left.is_some() { len += 1; } - if !self.metadata.is_empty() { + if self.right.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.DfSchema", len)?; - if !self.columns.is_empty() { - struct_ser.serialize_field("columns", &self.columns)?; + if !self.on.is_empty() { + len += 1; } - if !self.metadata.is_empty() { - struct_ser.serialize_field("metadata", &self.metadata)?; + if self.join_type != 0 { + len += 1; } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for DfSchema { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "columns", - "metadata", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Columns, - Metadata, + if self.partition_mode != 0 { + len += 1; } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "columns" => Ok(GeneratedField::Columns), - "metadata" => Ok(GeneratedField::Metadata), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = DfSchema; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.DfSchema") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut columns__ = None; - let mut metadata__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Columns => { - if columns__.is_some() { - return Err(serde::de::Error::duplicate_field("columns")); - } - columns__ = Some(map_.next_value()?); - } - GeneratedField::Metadata => { - if metadata__.is_some() { - return Err(serde::de::Error::duplicate_field("metadata")); - } - metadata__ = Some( - map_.next_value::>()? - ); - } - } - } - Ok(DfSchema { - columns: columns__.unwrap_or_default(), - metadata: metadata__.unwrap_or_default(), - }) - } + if self.null_equals_null { + len += 1; } - deserializer.deserialize_struct("datafusion.DfSchema", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for Dictionary { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.key.is_some() { + if self.filter.is_some() { len += 1; } - if self.value.is_some() { + if !self.projection.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.Dictionary", len)?; - if let Some(v) = self.key.as_ref() { - struct_ser.serialize_field("key", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.HashJoinExecNode", len)?; + if let Some(v) = self.left.as_ref() { + struct_ser.serialize_field("left", v)?; + } + if let Some(v) = self.right.as_ref() { + struct_ser.serialize_field("right", v)?; + } + if !self.on.is_empty() { + struct_ser.serialize_field("on", &self.on)?; + } + if self.join_type != 0 { + let v = super::datafusion_common::JoinType::try_from(self.join_type) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_type)))?; + struct_ser.serialize_field("joinType", &v)?; } - if let Some(v) = self.value.as_ref() { - struct_ser.serialize_field("value", v)?; + if self.partition_mode != 0 { + let v = PartitionMode::try_from(self.partition_mode) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.partition_mode)))?; + struct_ser.serialize_field("partitionMode", &v)?; + } + if self.null_equals_null { + struct_ser.serialize_field("nullEqualsNull", &self.null_equals_null)?; + } + if let Some(v) = self.filter.as_ref() { + struct_ser.serialize_field("filter", v)?; + } + if !self.projection.is_empty() { + struct_ser.serialize_field("projection", &self.projection)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for Dictionary { +impl<'de> serde::Deserialize<'de> for HashJoinExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "key", - "value", + "left", + "right", + "on", + "join_type", + "joinType", + "partition_mode", + "partitionMode", + "null_equals_null", + "nullEqualsNull", + "filter", + "projection", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Key, - Value, + Left, + Right, + On, + JoinType, + PartitionMode, + NullEqualsNull, + Filter, + Projection, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7374,8 +6588,14 @@ impl<'de> serde::Deserialize<'de> for Dictionary { E: serde::de::Error, { match value { - "key" => Ok(GeneratedField::Key), - "value" => Ok(GeneratedField::Value), + "left" => Ok(GeneratedField::Left), + "right" => Ok(GeneratedField::Right), + "on" => Ok(GeneratedField::On), + "joinType" | "join_type" => Ok(GeneratedField::JoinType), + "partitionMode" | "partition_mode" => Ok(GeneratedField::PartitionMode), + "nullEqualsNull" | "null_equals_null" => Ok(GeneratedField::NullEqualsNull), + "filter" => Ok(GeneratedField::Filter), + "projection" => Ok(GeneratedField::Projection), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7385,44 +6605,95 @@ impl<'de> serde::Deserialize<'de> for Dictionary { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = Dictionary; + type Value = HashJoinExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.Dictionary") + formatter.write_str("struct datafusion.HashJoinExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut key__ = None; - let mut value__ = None; + let mut left__ = None; + let mut right__ = None; + let mut on__ = None; + let mut join_type__ = None; + let mut partition_mode__ = None; + let mut null_equals_null__ = None; + let mut filter__ = None; + let mut projection__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Key => { - if key__.is_some() { - return Err(serde::de::Error::duplicate_field("key")); + GeneratedField::Left => { + if left__.is_some() { + return Err(serde::de::Error::duplicate_field("left")); } - key__ = map_.next_value()?; + left__ = map_.next_value()?; } - GeneratedField::Value => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("value")); + GeneratedField::Right => { + if right__.is_some() { + return Err(serde::de::Error::duplicate_field("right")); + } + right__ = map_.next_value()?; + } + GeneratedField::On => { + if on__.is_some() { + return Err(serde::de::Error::duplicate_field("on")); + } + on__ = Some(map_.next_value()?); + } + GeneratedField::JoinType => { + if join_type__.is_some() { + return Err(serde::de::Error::duplicate_field("joinType")); + } + join_type__ = Some(map_.next_value::()? as i32); + } + GeneratedField::PartitionMode => { + if partition_mode__.is_some() { + return Err(serde::de::Error::duplicate_field("partitionMode")); + } + partition_mode__ = Some(map_.next_value::()? as i32); + } + GeneratedField::NullEqualsNull => { + if null_equals_null__.is_some() { + return Err(serde::de::Error::duplicate_field("nullEqualsNull")); + } + null_equals_null__ = Some(map_.next_value()?); + } + GeneratedField::Filter => { + if filter__.is_some() { + return Err(serde::de::Error::duplicate_field("filter")); } - value__ = map_.next_value()?; + filter__ = map_.next_value()?; + } + GeneratedField::Projection => { + if projection__.is_some() { + return Err(serde::de::Error::duplicate_field("projection")); + } + projection__ = + Some(map_.next_value::>>()? + .into_iter().map(|x| x.0).collect()) + ; } } } - Ok(Dictionary { - key: key__, - value: value__, + Ok(HashJoinExecNode { + left: left__, + right: right__, + on: on__.unwrap_or_default(), + join_type: join_type__.unwrap_or_default(), + partition_mode: partition_mode__.unwrap_or_default(), + null_equals_null: null_equals_null__.unwrap_or_default(), + filter: filter__, + projection: projection__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.Dictionary", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.HashJoinExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for DistinctNode { +impl serde::Serialize for HashRepartition { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -7430,29 +6701,41 @@ impl serde::Serialize for DistinctNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.input.is_some() { + if !self.hash_expr.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.DistinctNode", len)?; - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; + if self.partition_count != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.HashRepartition", len)?; + if !self.hash_expr.is_empty() { + struct_ser.serialize_field("hashExpr", &self.hash_expr)?; + } + if self.partition_count != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("partitionCount", ToString::to_string(&self.partition_count).as_str())?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for DistinctNode { +impl<'de> serde::Deserialize<'de> for HashRepartition { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "input", + "hash_expr", + "hashExpr", + "partition_count", + "partitionCount", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Input, + HashExpr, + PartitionCount, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7474,7 +6757,8 @@ impl<'de> serde::Deserialize<'de> for DistinctNode { E: serde::de::Error, { match value { - "input" => Ok(GeneratedField::Input), + "hashExpr" | "hash_expr" => Ok(GeneratedField::HashExpr), + "partitionCount" | "partition_count" => Ok(GeneratedField::PartitionCount), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7484,36 +6768,46 @@ impl<'de> serde::Deserialize<'de> for DistinctNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = DistinctNode; + type Value = HashRepartition; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.DistinctNode") + formatter.write_str("struct datafusion.HashRepartition") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut input__ = None; + let mut hash_expr__ = None; + let mut partition_count__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); + GeneratedField::HashExpr => { + if hash_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("hashExpr")); } - input__ = map_.next_value()?; + hash_expr__ = Some(map_.next_value()?); + } + GeneratedField::PartitionCount => { + if partition_count__.is_some() { + return Err(serde::de::Error::duplicate_field("partitionCount")); + } + partition_count__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } } } - Ok(DistinctNode { - input: input__, + Ok(HashRepartition { + hash_expr: hash_expr__.unwrap_or_default(), + partition_count: partition_count__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.DistinctNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.HashRepartition", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for DistinctOnNode { +impl serde::Serialize for ILikeNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -7521,56 +6815,54 @@ impl serde::Serialize for DistinctOnNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.on_expr.is_empty() { + if self.negated { len += 1; } - if !self.select_expr.is_empty() { + if self.expr.is_some() { len += 1; } - if !self.sort_expr.is_empty() { + if self.pattern.is_some() { len += 1; } - if self.input.is_some() { + if !self.escape_char.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.DistinctOnNode", len)?; - if !self.on_expr.is_empty() { - struct_ser.serialize_field("onExpr", &self.on_expr)?; + let mut struct_ser = serializer.serialize_struct("datafusion.ILikeNode", len)?; + if self.negated { + struct_ser.serialize_field("negated", &self.negated)?; } - if !self.select_expr.is_empty() { - struct_ser.serialize_field("selectExpr", &self.select_expr)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } - if !self.sort_expr.is_empty() { - struct_ser.serialize_field("sortExpr", &self.sort_expr)?; + if let Some(v) = self.pattern.as_ref() { + struct_ser.serialize_field("pattern", v)?; } - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; + if !self.escape_char.is_empty() { + struct_ser.serialize_field("escapeChar", &self.escape_char)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for DistinctOnNode { +impl<'de> serde::Deserialize<'de> for ILikeNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "on_expr", - "onExpr", - "select_expr", - "selectExpr", - "sort_expr", - "sortExpr", - "input", + "negated", + "expr", + "pattern", + "escape_char", + "escapeChar", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - OnExpr, - SelectExpr, - SortExpr, - Input, + Negated, + Expr, + Pattern, + EscapeChar, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7592,10 +6884,10 @@ impl<'de> serde::Deserialize<'de> for DistinctOnNode { E: serde::de::Error, { match value { - "onExpr" | "on_expr" => Ok(GeneratedField::OnExpr), - "selectExpr" | "select_expr" => Ok(GeneratedField::SelectExpr), - "sortExpr" | "sort_expr" => Ok(GeneratedField::SortExpr), - "input" => Ok(GeneratedField::Input), + "negated" => Ok(GeneratedField::Negated), + "expr" => Ok(GeneratedField::Expr), + "pattern" => Ok(GeneratedField::Pattern), + "escapeChar" | "escape_char" => Ok(GeneratedField::EscapeChar), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7605,60 +6897,60 @@ impl<'de> serde::Deserialize<'de> for DistinctOnNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = DistinctOnNode; + type Value = ILikeNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.DistinctOnNode") + formatter.write_str("struct datafusion.ILikeNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut on_expr__ = None; - let mut select_expr__ = None; - let mut sort_expr__ = None; - let mut input__ = None; + let mut negated__ = None; + let mut expr__ = None; + let mut pattern__ = None; + let mut escape_char__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::OnExpr => { - if on_expr__.is_some() { - return Err(serde::de::Error::duplicate_field("onExpr")); + GeneratedField::Negated => { + if negated__.is_some() { + return Err(serde::de::Error::duplicate_field("negated")); } - on_expr__ = Some(map_.next_value()?); + negated__ = Some(map_.next_value()?); } - GeneratedField::SelectExpr => { - if select_expr__.is_some() { - return Err(serde::de::Error::duplicate_field("selectExpr")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - select_expr__ = Some(map_.next_value()?); + expr__ = map_.next_value()?; } - GeneratedField::SortExpr => { - if sort_expr__.is_some() { - return Err(serde::de::Error::duplicate_field("sortExpr")); + GeneratedField::Pattern => { + if pattern__.is_some() { + return Err(serde::de::Error::duplicate_field("pattern")); } - sort_expr__ = Some(map_.next_value()?); + pattern__ = map_.next_value()?; } - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); + GeneratedField::EscapeChar => { + if escape_char__.is_some() { + return Err(serde::de::Error::duplicate_field("escapeChar")); } - input__ = map_.next_value()?; + escape_char__ = Some(map_.next_value()?); } } } - Ok(DistinctOnNode { - on_expr: on_expr__.unwrap_or_default(), - select_expr: select_expr__.unwrap_or_default(), - sort_expr: sort_expr__.unwrap_or_default(), - input: input__, + Ok(ILikeNode { + negated: negated__.unwrap_or_default(), + expr: expr__, + pattern: pattern__, + escape_char: escape_char__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.DistinctOnNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ILikeNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for DropViewNode { +impl serde::Serialize for InListNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -7666,46 +6958,45 @@ impl serde::Serialize for DropViewNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.name.is_some() { + if self.expr.is_some() { len += 1; } - if self.if_exists { + if !self.list.is_empty() { len += 1; } - if self.schema.is_some() { + if self.negated { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.DropViewNode", len)?; - if let Some(v) = self.name.as_ref() { - struct_ser.serialize_field("name", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.InListNode", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } - if self.if_exists { - struct_ser.serialize_field("ifExists", &self.if_exists)?; + if !self.list.is_empty() { + struct_ser.serialize_field("list", &self.list)?; } - if let Some(v) = self.schema.as_ref() { - struct_ser.serialize_field("schema", v)?; + if self.negated { + struct_ser.serialize_field("negated", &self.negated)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for DropViewNode { +impl<'de> serde::Deserialize<'de> for InListNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "name", - "if_exists", - "ifExists", - "schema", + "expr", + "list", + "negated", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Name, - IfExists, - Schema, + Expr, + List, + Negated, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7727,9 +7018,9 @@ impl<'de> serde::Deserialize<'de> for DropViewNode { E: serde::de::Error, { match value { - "name" => Ok(GeneratedField::Name), - "ifExists" | "if_exists" => Ok(GeneratedField::IfExists), - "schema" => Ok(GeneratedField::Schema), + "expr" => Ok(GeneratedField::Expr), + "list" => Ok(GeneratedField::List), + "negated" => Ok(GeneratedField::Negated), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7739,165 +7030,156 @@ impl<'de> serde::Deserialize<'de> for DropViewNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = DropViewNode; + type Value = InListNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.DropViewNode") + formatter.write_str("struct datafusion.InListNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut name__ = None; - let mut if_exists__ = None; - let mut schema__ = None; + let mut expr__ = None; + let mut list__ = None; + let mut negated__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Name => { - if name__.is_some() { - return Err(serde::de::Error::duplicate_field("name")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - name__ = map_.next_value()?; + expr__ = map_.next_value()?; } - GeneratedField::IfExists => { - if if_exists__.is_some() { - return Err(serde::de::Error::duplicate_field("ifExists")); + GeneratedField::List => { + if list__.is_some() { + return Err(serde::de::Error::duplicate_field("list")); } - if_exists__ = Some(map_.next_value()?); + list__ = Some(map_.next_value()?); } - GeneratedField::Schema => { - if schema__.is_some() { - return Err(serde::de::Error::duplicate_field("schema")); + GeneratedField::Negated => { + if negated__.is_some() { + return Err(serde::de::Error::duplicate_field("negated")); } - schema__ = map_.next_value()?; + negated__ = Some(map_.next_value()?); } } } - Ok(DropViewNode { - name: name__, - if_exists: if_exists__.unwrap_or_default(), - schema: schema__, + Ok(InListNode { + expr: expr__, + list: list__.unwrap_or_default(), + negated: negated__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.DropViewNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.InListNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for EmptyExecNode { +impl serde::Serialize for InsertOp { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where S: serde::Serializer, { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.schema.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.EmptyExecNode", len)?; - if let Some(v) = self.schema.as_ref() { - struct_ser.serialize_field("schema", v)?; - } - struct_ser.end() + let variant = match self { + Self::Append => "Append", + Self::Overwrite => "Overwrite", + Self::Replace => "Replace", + }; + serializer.serialize_str(variant) } } -impl<'de> serde::Deserialize<'de> for EmptyExecNode { +impl<'de> serde::Deserialize<'de> for InsertOp { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "schema", + "Append", + "Overwrite", + "Replace", ]; - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Schema, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "schema" => Ok(GeneratedField::Schema), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = EmptyExecNode; + type Value = InsertOp; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.EmptyExecNode") + write!(formatter, "expected one of: {:?}", &FIELDS) } - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, { - let mut schema__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Schema => { - if schema__.is_some() { - return Err(serde::de::Error::duplicate_field("schema")); - } - schema__ = map_.next_value()?; - } - } + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "Append" => Ok(InsertOp::Append), + "Overwrite" => Ok(InsertOp::Overwrite), + "Replace" => Ok(InsertOp::Replace), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } - Ok(EmptyExecNode { - schema: schema__, - }) } } - deserializer.deserialize_struct("datafusion.EmptyExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_any(GeneratedVisitor) } } -impl serde::Serialize for EmptyMessage { +impl serde::Serialize for InterleaveExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where S: serde::Serializer, { use serde::ser::SerializeStruct; - let len = 0; - let struct_ser = serializer.serialize_struct("datafusion.EmptyMessage", len)?; + let mut len = 0; + if !self.inputs.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.InterleaveExecNode", len)?; + if !self.inputs.is_empty() { + struct_ser.serialize_field("inputs", &self.inputs)?; + } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for EmptyMessage { +impl<'de> serde::Deserialize<'de> for InterleaveExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ + "inputs", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { + Inputs, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7918,7 +7200,10 @@ impl<'de> serde::Deserialize<'de> for EmptyMessage { where E: serde::de::Error, { - Err(serde::de::Error::unknown_field(value, FIELDS)) + match value { + "inputs" => Ok(GeneratedField::Inputs), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } } } deserializer.deserialize_identifier(GeneratedVisitor) @@ -7926,27 +7211,36 @@ impl<'de> serde::Deserialize<'de> for EmptyMessage { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = EmptyMessage; + type Value = InterleaveExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.EmptyMessage") + formatter.write_str("struct datafusion.InterleaveExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - while map_.next_key::()?.is_some() { - let _ = map_.next_value::()?; + let mut inputs__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Inputs => { + if inputs__.is_some() { + return Err(serde::de::Error::duplicate_field("inputs")); + } + inputs__ = Some(map_.next_value()?); + } + } } - Ok(EmptyMessage { + Ok(InterleaveExecNode { + inputs: inputs__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.EmptyMessage", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.InterleaveExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for EmptyRelationNode { +impl serde::Serialize for IsFalse { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -7954,30 +7248,29 @@ impl serde::Serialize for EmptyRelationNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.produce_one_row { + if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.EmptyRelationNode", len)?; - if self.produce_one_row { - struct_ser.serialize_field("produceOneRow", &self.produce_one_row)?; + let mut struct_ser = serializer.serialize_struct("datafusion.IsFalse", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for EmptyRelationNode { +impl<'de> serde::Deserialize<'de> for IsFalse { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "produce_one_row", - "produceOneRow", + "expr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - ProduceOneRow, + Expr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7999,7 +7292,7 @@ impl<'de> serde::Deserialize<'de> for EmptyRelationNode { E: serde::de::Error, { match value { - "produceOneRow" | "produce_one_row" => Ok(GeneratedField::ProduceOneRow), + "expr" => Ok(GeneratedField::Expr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -8009,36 +7302,36 @@ impl<'de> serde::Deserialize<'de> for EmptyRelationNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = EmptyRelationNode; + type Value = IsFalse; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.EmptyRelationNode") + formatter.write_str("struct datafusion.IsFalse") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut produce_one_row__ = None; + let mut expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::ProduceOneRow => { - if produce_one_row__.is_some() { - return Err(serde::de::Error::duplicate_field("produceOneRow")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - produce_one_row__ = Some(map_.next_value()?); + expr__ = map_.next_value()?; } } } - Ok(EmptyRelationNode { - produce_one_row: produce_one_row__.unwrap_or_default(), + Ok(IsFalse { + expr: expr__, }) } } - deserializer.deserialize_struct("datafusion.EmptyRelationNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.IsFalse", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ExplainExecNode { +impl serde::Serialize for IsNotFalse { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -8046,46 +7339,29 @@ impl serde::Serialize for ExplainExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.schema.is_some() { - len += 1; - } - if !self.stringified_plans.is_empty() { - len += 1; - } - if self.verbose { + if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ExplainExecNode", len)?; - if let Some(v) = self.schema.as_ref() { - struct_ser.serialize_field("schema", v)?; - } - if !self.stringified_plans.is_empty() { - struct_ser.serialize_field("stringifiedPlans", &self.stringified_plans)?; - } - if self.verbose { - struct_ser.serialize_field("verbose", &self.verbose)?; + let mut struct_ser = serializer.serialize_struct("datafusion.IsNotFalse", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ExplainExecNode { +impl<'de> serde::Deserialize<'de> for IsNotFalse { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "schema", - "stringified_plans", - "stringifiedPlans", - "verbose", + "expr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Schema, - StringifiedPlans, - Verbose, + Expr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -8107,9 +7383,7 @@ impl<'de> serde::Deserialize<'de> for ExplainExecNode { E: serde::de::Error, { match value { - "schema" => Ok(GeneratedField::Schema), - "stringifiedPlans" | "stringified_plans" => Ok(GeneratedField::StringifiedPlans), - "verbose" => Ok(GeneratedField::Verbose), + "expr" => Ok(GeneratedField::Expr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -8119,52 +7393,36 @@ impl<'de> serde::Deserialize<'de> for ExplainExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ExplainExecNode; + type Value = IsNotFalse; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ExplainExecNode") + formatter.write_str("struct datafusion.IsNotFalse") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut schema__ = None; - let mut stringified_plans__ = None; - let mut verbose__ = None; + let mut expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Schema => { - if schema__.is_some() { - return Err(serde::de::Error::duplicate_field("schema")); - } - schema__ = map_.next_value()?; - } - GeneratedField::StringifiedPlans => { - if stringified_plans__.is_some() { - return Err(serde::de::Error::duplicate_field("stringifiedPlans")); - } - stringified_plans__ = Some(map_.next_value()?); - } - GeneratedField::Verbose => { - if verbose__.is_some() { - return Err(serde::de::Error::duplicate_field("verbose")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - verbose__ = Some(map_.next_value()?); + expr__ = map_.next_value()?; } } } - Ok(ExplainExecNode { - schema: schema__, - stringified_plans: stringified_plans__.unwrap_or_default(), - verbose: verbose__.unwrap_or_default(), + Ok(IsNotFalse { + expr: expr__, }) } } - deserializer.deserialize_struct("datafusion.ExplainExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.IsNotFalse", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ExplainNode { +impl serde::Serialize for IsNotNull { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -8172,37 +7430,29 @@ impl serde::Serialize for ExplainNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.input.is_some() { - len += 1; - } - if self.verbose { + if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ExplainNode", len)?; - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; - } - if self.verbose { - struct_ser.serialize_field("verbose", &self.verbose)?; + let mut struct_ser = serializer.serialize_struct("datafusion.IsNotNull", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ExplainNode { +impl<'de> serde::Deserialize<'de> for IsNotNull { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "input", - "verbose", + "expr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Input, - Verbose, + Expr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -8224,8 +7474,7 @@ impl<'de> serde::Deserialize<'de> for ExplainNode { E: serde::de::Error, { match value { - "input" => Ok(GeneratedField::Input), - "verbose" => Ok(GeneratedField::Verbose), + "expr" => Ok(GeneratedField::Expr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -8235,44 +7484,36 @@ impl<'de> serde::Deserialize<'de> for ExplainNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ExplainNode; + type Value = IsNotNull; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ExplainNode") + formatter.write_str("struct datafusion.IsNotNull") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut input__ = None; - let mut verbose__ = None; + let mut expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); - } - input__ = map_.next_value()?; - } - GeneratedField::Verbose => { - if verbose__.is_some() { - return Err(serde::de::Error::duplicate_field("verbose")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - verbose__ = Some(map_.next_value()?); + expr__ = map_.next_value()?; } } } - Ok(ExplainNode { - input: input__, - verbose: verbose__.unwrap_or_default(), + Ok(IsNotNull { + expr: expr__, }) } } - deserializer.deserialize_struct("datafusion.ExplainNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.IsNotNull", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for Field { +impl serde::Serialize for IsNotTrue { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -8280,81 +7521,29 @@ impl serde::Serialize for Field { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.name.is_empty() { - len += 1; - } - if self.arrow_type.is_some() { - len += 1; - } - if self.nullable { - len += 1; - } - if !self.children.is_empty() { - len += 1; - } - if !self.metadata.is_empty() { + if self.expr.is_some() { len += 1; } - if self.dict_id != 0 { - len += 1; - } - if self.dict_ordered { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.Field", len)?; - if !self.name.is_empty() { - struct_ser.serialize_field("name", &self.name)?; - } - if let Some(v) = self.arrow_type.as_ref() { - struct_ser.serialize_field("arrowType", v)?; - } - if self.nullable { - struct_ser.serialize_field("nullable", &self.nullable)?; - } - if !self.children.is_empty() { - struct_ser.serialize_field("children", &self.children)?; - } - if !self.metadata.is_empty() { - struct_ser.serialize_field("metadata", &self.metadata)?; - } - if self.dict_id != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("dictId", ToString::to_string(&self.dict_id).as_str())?; - } - if self.dict_ordered { - struct_ser.serialize_field("dictOrdered", &self.dict_ordered)?; + let mut struct_ser = serializer.serialize_struct("datafusion.IsNotTrue", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for Field { +impl<'de> serde::Deserialize<'de> for IsNotTrue { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "name", - "arrow_type", - "arrowType", - "nullable", - "children", - "metadata", - "dict_id", - "dictId", - "dict_ordered", - "dictOrdered", + "expr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Name, - ArrowType, - Nullable, - Children, - Metadata, - DictId, - DictOrdered, + Expr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -8376,13 +7565,7 @@ impl<'de> serde::Deserialize<'de> for Field { E: serde::de::Error, { match value { - "name" => Ok(GeneratedField::Name), - "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), - "nullable" => Ok(GeneratedField::Nullable), - "children" => Ok(GeneratedField::Children), - "metadata" => Ok(GeneratedField::Metadata), - "dictId" | "dict_id" => Ok(GeneratedField::DictId), - "dictOrdered" | "dict_ordered" => Ok(GeneratedField::DictOrdered), + "expr" => Ok(GeneratedField::Expr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -8392,88 +7575,36 @@ impl<'de> serde::Deserialize<'de> for Field { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = Field; + type Value = IsNotTrue; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.Field") + formatter.write_str("struct datafusion.IsNotTrue") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut name__ = None; - let mut arrow_type__ = None; - let mut nullable__ = None; - let mut children__ = None; - let mut metadata__ = None; - let mut dict_id__ = None; - let mut dict_ordered__ = None; + let mut expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Name => { - if name__.is_some() { - return Err(serde::de::Error::duplicate_field("name")); - } - name__ = Some(map_.next_value()?); - } - GeneratedField::ArrowType => { - if arrow_type__.is_some() { - return Err(serde::de::Error::duplicate_field("arrowType")); - } - arrow_type__ = map_.next_value()?; - } - GeneratedField::Nullable => { - if nullable__.is_some() { - return Err(serde::de::Error::duplicate_field("nullable")); - } - nullable__ = Some(map_.next_value()?); - } - GeneratedField::Children => { - if children__.is_some() { - return Err(serde::de::Error::duplicate_field("children")); - } - children__ = Some(map_.next_value()?); - } - GeneratedField::Metadata => { - if metadata__.is_some() { - return Err(serde::de::Error::duplicate_field("metadata")); - } - metadata__ = Some( - map_.next_value::>()? - ); - } - GeneratedField::DictId => { - if dict_id__.is_some() { - return Err(serde::de::Error::duplicate_field("dictId")); - } - dict_id__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::DictOrdered => { - if dict_ordered__.is_some() { - return Err(serde::de::Error::duplicate_field("dictOrdered")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - dict_ordered__ = Some(map_.next_value()?); + expr__ = map_.next_value()?; } } } - Ok(Field { - name: name__.unwrap_or_default(), - arrow_type: arrow_type__, - nullable: nullable__.unwrap_or_default(), - children: children__.unwrap_or_default(), - metadata: metadata__.unwrap_or_default(), - dict_id: dict_id__.unwrap_or_default(), - dict_ordered: dict_ordered__.unwrap_or_default(), + Ok(IsNotTrue { + expr: expr__, }) } } - deserializer.deserialize_struct("datafusion.Field", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.IsNotTrue", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for FileGroup { +impl serde::Serialize for IsNotUnknown { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -8481,29 +7612,29 @@ impl serde::Serialize for FileGroup { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.files.is_empty() { + if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.FileGroup", len)?; - if !self.files.is_empty() { - struct_ser.serialize_field("files", &self.files)?; + let mut struct_ser = serializer.serialize_struct("datafusion.IsNotUnknown", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for FileGroup { +impl<'de> serde::Deserialize<'de> for IsNotUnknown { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "files", + "expr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Files, + Expr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -8525,7 +7656,7 @@ impl<'de> serde::Deserialize<'de> for FileGroup { E: serde::de::Error, { match value { - "files" => Ok(GeneratedField::Files), + "expr" => Ok(GeneratedField::Expr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -8535,36 +7666,36 @@ impl<'de> serde::Deserialize<'de> for FileGroup { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = FileGroup; + type Value = IsNotUnknown; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.FileGroup") + formatter.write_str("struct datafusion.IsNotUnknown") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut files__ = None; + let mut expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Files => { - if files__.is_some() { - return Err(serde::de::Error::duplicate_field("files")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - files__ = Some(map_.next_value()?); + expr__ = map_.next_value()?; } } } - Ok(FileGroup { - files: files__.unwrap_or_default(), + Ok(IsNotUnknown { + expr: expr__, }) } } - deserializer.deserialize_struct("datafusion.FileGroup", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.IsNotUnknown", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for FileRange { +impl serde::Serialize for IsNull { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -8572,39 +7703,29 @@ impl serde::Serialize for FileRange { { use serde::ser::SerializeStruct; let mut len = 0; - if self.start != 0 { - len += 1; - } - if self.end != 0 { + if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.FileRange", len)?; - if self.start != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("start", ToString::to_string(&self.start).as_str())?; - } - if self.end != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("end", ToString::to_string(&self.end).as_str())?; + let mut struct_ser = serializer.serialize_struct("datafusion.IsNull", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for FileRange { +impl<'de> serde::Deserialize<'de> for IsNull { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "start", - "end", + "expr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Start, - End, + Expr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -8626,8 +7747,7 @@ impl<'de> serde::Deserialize<'de> for FileRange { E: serde::de::Error, { match value { - "start" => Ok(GeneratedField::Start), - "end" => Ok(GeneratedField::End), + "expr" => Ok(GeneratedField::Expr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -8637,48 +7757,36 @@ impl<'de> serde::Deserialize<'de> for FileRange { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = FileRange; + type Value = IsNull; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.FileRange") + formatter.write_str("struct datafusion.IsNull") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut start__ = None; - let mut end__ = None; + let mut expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Start => { - if start__.is_some() { - return Err(serde::de::Error::duplicate_field("start")); - } - start__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::End => { - if end__.is_some() { - return Err(serde::de::Error::duplicate_field("end")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - end__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + expr__ = map_.next_value()?; } } } - Ok(FileRange { - start: start__.unwrap_or_default(), - end: end__.unwrap_or_default(), + Ok(IsNull { + expr: expr__, }) } } - deserializer.deserialize_struct("datafusion.FileRange", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.IsNull", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for FileScanExecConf { +impl serde::Serialize for IsTrue { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -8686,89 +7794,29 @@ impl serde::Serialize for FileScanExecConf { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.file_groups.is_empty() { - len += 1; - } - if self.schema.is_some() { - len += 1; - } - if !self.projection.is_empty() { - len += 1; - } - if self.limit.is_some() { - len += 1; - } - if self.statistics.is_some() { - len += 1; - } - if !self.table_partition_cols.is_empty() { - len += 1; - } - if !self.object_store_url.is_empty() { - len += 1; - } - if !self.output_ordering.is_empty() { + if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.FileScanExecConf", len)?; - if !self.file_groups.is_empty() { - struct_ser.serialize_field("fileGroups", &self.file_groups)?; - } - if let Some(v) = self.schema.as_ref() { - struct_ser.serialize_field("schema", v)?; - } - if !self.projection.is_empty() { - struct_ser.serialize_field("projection", &self.projection)?; - } - if let Some(v) = self.limit.as_ref() { - struct_ser.serialize_field("limit", v)?; - } - if let Some(v) = self.statistics.as_ref() { - struct_ser.serialize_field("statistics", v)?; - } - if !self.table_partition_cols.is_empty() { - struct_ser.serialize_field("tablePartitionCols", &self.table_partition_cols)?; - } - if !self.object_store_url.is_empty() { - struct_ser.serialize_field("objectStoreUrl", &self.object_store_url)?; - } - if !self.output_ordering.is_empty() { - struct_ser.serialize_field("outputOrdering", &self.output_ordering)?; + let mut struct_ser = serializer.serialize_struct("datafusion.IsTrue", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for FileScanExecConf { +impl<'de> serde::Deserialize<'de> for IsTrue { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "file_groups", - "fileGroups", - "schema", - "projection", - "limit", - "statistics", - "table_partition_cols", - "tablePartitionCols", - "object_store_url", - "objectStoreUrl", - "output_ordering", - "outputOrdering", + "expr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - FileGroups, - Schema, - Projection, - Limit, - Statistics, - TablePartitionCols, - ObjectStoreUrl, - OutputOrdering, + Expr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -8790,14 +7838,7 @@ impl<'de> serde::Deserialize<'de> for FileScanExecConf { E: serde::de::Error, { match value { - "fileGroups" | "file_groups" => Ok(GeneratedField::FileGroups), - "schema" => Ok(GeneratedField::Schema), - "projection" => Ok(GeneratedField::Projection), - "limit" => Ok(GeneratedField::Limit), - "statistics" => Ok(GeneratedField::Statistics), - "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), - "objectStoreUrl" | "object_store_url" => Ok(GeneratedField::ObjectStoreUrl), - "outputOrdering" | "output_ordering" => Ok(GeneratedField::OutputOrdering), + "expr" => Ok(GeneratedField::Expr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -8807,95 +7848,36 @@ impl<'de> serde::Deserialize<'de> for FileScanExecConf { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = FileScanExecConf; + type Value = IsTrue; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.FileScanExecConf") + formatter.write_str("struct datafusion.IsTrue") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut file_groups__ = None; - let mut schema__ = None; - let mut projection__ = None; - let mut limit__ = None; - let mut statistics__ = None; - let mut table_partition_cols__ = None; - let mut object_store_url__ = None; - let mut output_ordering__ = None; + let mut expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::FileGroups => { - if file_groups__.is_some() { - return Err(serde::de::Error::duplicate_field("fileGroups")); - } - file_groups__ = Some(map_.next_value()?); - } - GeneratedField::Schema => { - if schema__.is_some() { - return Err(serde::de::Error::duplicate_field("schema")); - } - schema__ = map_.next_value()?; - } - GeneratedField::Projection => { - if projection__.is_some() { - return Err(serde::de::Error::duplicate_field("projection")); - } - projection__ = - Some(map_.next_value::>>()? - .into_iter().map(|x| x.0).collect()) - ; - } - GeneratedField::Limit => { - if limit__.is_some() { - return Err(serde::de::Error::duplicate_field("limit")); - } - limit__ = map_.next_value()?; - } - GeneratedField::Statistics => { - if statistics__.is_some() { - return Err(serde::de::Error::duplicate_field("statistics")); - } - statistics__ = map_.next_value()?; - } - GeneratedField::TablePartitionCols => { - if table_partition_cols__.is_some() { - return Err(serde::de::Error::duplicate_field("tablePartitionCols")); - } - table_partition_cols__ = Some(map_.next_value()?); - } - GeneratedField::ObjectStoreUrl => { - if object_store_url__.is_some() { - return Err(serde::de::Error::duplicate_field("objectStoreUrl")); - } - object_store_url__ = Some(map_.next_value()?); - } - GeneratedField::OutputOrdering => { - if output_ordering__.is_some() { - return Err(serde::de::Error::duplicate_field("outputOrdering")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - output_ordering__ = Some(map_.next_value()?); + expr__ = map_.next_value()?; } } } - Ok(FileScanExecConf { - file_groups: file_groups__.unwrap_or_default(), - schema: schema__, - projection: projection__.unwrap_or_default(), - limit: limit__, - statistics: statistics__, - table_partition_cols: table_partition_cols__.unwrap_or_default(), - object_store_url: object_store_url__.unwrap_or_default(), - output_ordering: output_ordering__.unwrap_or_default(), + Ok(IsTrue { + expr: expr__, }) } } - deserializer.deserialize_struct("datafusion.FileScanExecConf", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.IsTrue", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for FileSinkConfig { +impl serde::Serialize for IsUnknown { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -8903,74 +7885,29 @@ impl serde::Serialize for FileSinkConfig { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.object_store_url.is_empty() { - len += 1; - } - if !self.file_groups.is_empty() { - len += 1; - } - if !self.table_paths.is_empty() { - len += 1; - } - if self.output_schema.is_some() { - len += 1; - } - if !self.table_partition_cols.is_empty() { - len += 1; - } - if self.overwrite { + if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.FileSinkConfig", len)?; - if !self.object_store_url.is_empty() { - struct_ser.serialize_field("objectStoreUrl", &self.object_store_url)?; - } - if !self.file_groups.is_empty() { - struct_ser.serialize_field("fileGroups", &self.file_groups)?; - } - if !self.table_paths.is_empty() { - struct_ser.serialize_field("tablePaths", &self.table_paths)?; - } - if let Some(v) = self.output_schema.as_ref() { - struct_ser.serialize_field("outputSchema", v)?; - } - if !self.table_partition_cols.is_empty() { - struct_ser.serialize_field("tablePartitionCols", &self.table_partition_cols)?; - } - if self.overwrite { - struct_ser.serialize_field("overwrite", &self.overwrite)?; + let mut struct_ser = serializer.serialize_struct("datafusion.IsUnknown", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for FileSinkConfig { +impl<'de> serde::Deserialize<'de> for IsUnknown { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "object_store_url", - "objectStoreUrl", - "file_groups", - "fileGroups", - "table_paths", - "tablePaths", - "output_schema", - "outputSchema", - "table_partition_cols", - "tablePartitionCols", - "overwrite", + "expr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - ObjectStoreUrl, - FileGroups, - TablePaths, - OutputSchema, - TablePartitionCols, - Overwrite, + Expr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -8992,12 +7929,7 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { E: serde::de::Error, { match value { - "objectStoreUrl" | "object_store_url" => Ok(GeneratedField::ObjectStoreUrl), - "fileGroups" | "file_groups" => Ok(GeneratedField::FileGroups), - "tablePaths" | "table_paths" => Ok(GeneratedField::TablePaths), - "outputSchema" | "output_schema" => Ok(GeneratedField::OutputSchema), - "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), - "overwrite" => Ok(GeneratedField::Overwrite), + "expr" => Ok(GeneratedField::Expr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -9007,76 +7939,36 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = FileSinkConfig; + type Value = IsUnknown; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.FileSinkConfig") + formatter.write_str("struct datafusion.IsUnknown") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut object_store_url__ = None; - let mut file_groups__ = None; - let mut table_paths__ = None; - let mut output_schema__ = None; - let mut table_partition_cols__ = None; - let mut overwrite__ = None; + let mut expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::ObjectStoreUrl => { - if object_store_url__.is_some() { - return Err(serde::de::Error::duplicate_field("objectStoreUrl")); - } - object_store_url__ = Some(map_.next_value()?); - } - GeneratedField::FileGroups => { - if file_groups__.is_some() { - return Err(serde::de::Error::duplicate_field("fileGroups")); - } - file_groups__ = Some(map_.next_value()?); - } - GeneratedField::TablePaths => { - if table_paths__.is_some() { - return Err(serde::de::Error::duplicate_field("tablePaths")); - } - table_paths__ = Some(map_.next_value()?); - } - GeneratedField::OutputSchema => { - if output_schema__.is_some() { - return Err(serde::de::Error::duplicate_field("outputSchema")); - } - output_schema__ = map_.next_value()?; - } - GeneratedField::TablePartitionCols => { - if table_partition_cols__.is_some() { - return Err(serde::de::Error::duplicate_field("tablePartitionCols")); - } - table_partition_cols__ = Some(map_.next_value()?); - } - GeneratedField::Overwrite => { - if overwrite__.is_some() { - return Err(serde::de::Error::duplicate_field("overwrite")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - overwrite__ = Some(map_.next_value()?); + expr__ = map_.next_value()?; } } } - Ok(FileSinkConfig { - object_store_url: object_store_url__.unwrap_or_default(), - file_groups: file_groups__.unwrap_or_default(), - table_paths: table_paths__.unwrap_or_default(), - output_schema: output_schema__, - table_partition_cols: table_partition_cols__.unwrap_or_default(), - overwrite: overwrite__.unwrap_or_default(), + Ok(IsUnknown { + expr: expr__, }) } } - deserializer.deserialize_struct("datafusion.FileSinkConfig", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.IsUnknown", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for FilterExecNode { +impl serde::Serialize for JoinFilter { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -9084,46 +7976,46 @@ impl serde::Serialize for FilterExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.input.is_some() { + if self.expression.is_some() { len += 1; } - if self.expr.is_some() { + if !self.column_indices.is_empty() { len += 1; } - if self.default_filter_selectivity != 0 { + if self.schema.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.FilterExecNode", len)?; - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.JoinFilter", len)?; + if let Some(v) = self.expression.as_ref() { + struct_ser.serialize_field("expression", v)?; } - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + if !self.column_indices.is_empty() { + struct_ser.serialize_field("columnIndices", &self.column_indices)?; } - if self.default_filter_selectivity != 0 { - struct_ser.serialize_field("defaultFilterSelectivity", &self.default_filter_selectivity)?; + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for FilterExecNode { +impl<'de> serde::Deserialize<'de> for JoinFilter { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "input", - "expr", - "default_filter_selectivity", - "defaultFilterSelectivity", + "expression", + "column_indices", + "columnIndices", + "schema", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Input, - Expr, - DefaultFilterSelectivity, + Expression, + ColumnIndices, + Schema, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -9145,9 +8037,9 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { E: serde::de::Error, { match value { - "input" => Ok(GeneratedField::Input), - "expr" => Ok(GeneratedField::Expr), - "defaultFilterSelectivity" | "default_filter_selectivity" => Ok(GeneratedField::DefaultFilterSelectivity), + "expression" => Ok(GeneratedField::Expression), + "columnIndices" | "column_indices" => Ok(GeneratedField::ColumnIndices), + "schema" => Ok(GeneratedField::Schema), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -9157,54 +8049,52 @@ impl<'de> serde::Deserialize<'de> for FilterExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = FilterExecNode; + type Value = JoinFilter; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.FilterExecNode") + formatter.write_str("struct datafusion.JoinFilter") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut input__ = None; - let mut expr__ = None; - let mut default_filter_selectivity__ = None; + let mut expression__ = None; + let mut column_indices__ = None; + let mut schema__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); - } - input__ = map_.next_value()?; + GeneratedField::Expression => { + if expression__.is_some() { + return Err(serde::de::Error::duplicate_field("expression")); + } + expression__ = map_.next_value()?; } - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::ColumnIndices => { + if column_indices__.is_some() { + return Err(serde::de::Error::duplicate_field("columnIndices")); } - expr__ = map_.next_value()?; + column_indices__ = Some(map_.next_value()?); } - GeneratedField::DefaultFilterSelectivity => { - if default_filter_selectivity__.is_some() { - return Err(serde::de::Error::duplicate_field("defaultFilterSelectivity")); + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); } - default_filter_selectivity__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + schema__ = map_.next_value()?; } } } - Ok(FilterExecNode { - input: input__, - expr: expr__, - default_filter_selectivity: default_filter_selectivity__.unwrap_or_default(), + Ok(JoinFilter { + expression: expression__, + column_indices: column_indices__.unwrap_or_default(), + schema: schema__, }) } } - deserializer.deserialize_struct("datafusion.FilterExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.JoinFilter", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for FixedSizeBinary { +impl serde::Serialize for JoinNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -9212,29 +8102,94 @@ impl serde::Serialize for FixedSizeBinary { { use serde::ser::SerializeStruct; let mut len = 0; - if self.length != 0 { + if self.left.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.FixedSizeBinary", len)?; - if self.length != 0 { - struct_ser.serialize_field("length", &self.length)?; + if self.right.is_some() { + len += 1; + } + if self.join_type != 0 { + len += 1; + } + if self.join_constraint != 0 { + len += 1; + } + if !self.left_join_key.is_empty() { + len += 1; + } + if !self.right_join_key.is_empty() { + len += 1; + } + if self.null_equals_null { + len += 1; + } + if self.filter.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.JoinNode", len)?; + if let Some(v) = self.left.as_ref() { + struct_ser.serialize_field("left", v)?; + } + if let Some(v) = self.right.as_ref() { + struct_ser.serialize_field("right", v)?; + } + if self.join_type != 0 { + let v = super::datafusion_common::JoinType::try_from(self.join_type) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_type)))?; + struct_ser.serialize_field("joinType", &v)?; + } + if self.join_constraint != 0 { + let v = super::datafusion_common::JoinConstraint::try_from(self.join_constraint) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_constraint)))?; + struct_ser.serialize_field("joinConstraint", &v)?; + } + if !self.left_join_key.is_empty() { + struct_ser.serialize_field("leftJoinKey", &self.left_join_key)?; + } + if !self.right_join_key.is_empty() { + struct_ser.serialize_field("rightJoinKey", &self.right_join_key)?; + } + if self.null_equals_null { + struct_ser.serialize_field("nullEqualsNull", &self.null_equals_null)?; + } + if let Some(v) = self.filter.as_ref() { + struct_ser.serialize_field("filter", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for FixedSizeBinary { +impl<'de> serde::Deserialize<'de> for JoinNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "length", + "left", + "right", + "join_type", + "joinType", + "join_constraint", + "joinConstraint", + "left_join_key", + "leftJoinKey", + "right_join_key", + "rightJoinKey", + "null_equals_null", + "nullEqualsNull", + "filter", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Length, + Left, + Right, + JoinType, + JoinConstraint, + LeftJoinKey, + RightJoinKey, + NullEqualsNull, + Filter, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -9256,7 +8211,14 @@ impl<'de> serde::Deserialize<'de> for FixedSizeBinary { E: serde::de::Error, { match value { - "length" => Ok(GeneratedField::Length), + "left" => Ok(GeneratedField::Left), + "right" => Ok(GeneratedField::Right), + "joinType" | "join_type" => Ok(GeneratedField::JoinType), + "joinConstraint" | "join_constraint" => Ok(GeneratedField::JoinConstraint), + "leftJoinKey" | "left_join_key" => Ok(GeneratedField::LeftJoinKey), + "rightJoinKey" | "right_join_key" => Ok(GeneratedField::RightJoinKey), + "nullEqualsNull" | "null_equals_null" => Ok(GeneratedField::NullEqualsNull), + "filter" => Ok(GeneratedField::Filter), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -9266,38 +8228,92 @@ impl<'de> serde::Deserialize<'de> for FixedSizeBinary { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = FixedSizeBinary; + type Value = JoinNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.FixedSizeBinary") + formatter.write_str("struct datafusion.JoinNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut length__ = None; + let mut left__ = None; + let mut right__ = None; + let mut join_type__ = None; + let mut join_constraint__ = None; + let mut left_join_key__ = None; + let mut right_join_key__ = None; + let mut null_equals_null__ = None; + let mut filter__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Length => { - if length__.is_some() { - return Err(serde::de::Error::duplicate_field("length")); + GeneratedField::Left => { + if left__.is_some() { + return Err(serde::de::Error::duplicate_field("left")); } - length__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + left__ = map_.next_value()?; + } + GeneratedField::Right => { + if right__.is_some() { + return Err(serde::de::Error::duplicate_field("right")); + } + right__ = map_.next_value()?; + } + GeneratedField::JoinType => { + if join_type__.is_some() { + return Err(serde::de::Error::duplicate_field("joinType")); + } + join_type__ = Some(map_.next_value::()? as i32); + } + GeneratedField::JoinConstraint => { + if join_constraint__.is_some() { + return Err(serde::de::Error::duplicate_field("joinConstraint")); + } + join_constraint__ = Some(map_.next_value::()? as i32); + } + GeneratedField::LeftJoinKey => { + if left_join_key__.is_some() { + return Err(serde::de::Error::duplicate_field("leftJoinKey")); + } + left_join_key__ = Some(map_.next_value()?); + } + GeneratedField::RightJoinKey => { + if right_join_key__.is_some() { + return Err(serde::de::Error::duplicate_field("rightJoinKey")); + } + right_join_key__ = Some(map_.next_value()?); + } + GeneratedField::NullEqualsNull => { + if null_equals_null__.is_some() { + return Err(serde::de::Error::duplicate_field("nullEqualsNull")); + } + null_equals_null__ = Some(map_.next_value()?); + } + GeneratedField::Filter => { + if filter__.is_some() { + return Err(serde::de::Error::duplicate_field("filter")); + } + filter__ = map_.next_value()?; } } } - Ok(FixedSizeBinary { - length: length__.unwrap_or_default(), + Ok(JoinNode { + left: left__, + right: right__, + join_type: join_type__.unwrap_or_default(), + join_constraint: join_constraint__.unwrap_or_default(), + left_join_key: left_join_key__.unwrap_or_default(), + right_join_key: right_join_key__.unwrap_or_default(), + null_equals_null: null_equals_null__.unwrap_or_default(), + filter: filter__, }) } } - deserializer.deserialize_struct("datafusion.FixedSizeBinary", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.JoinNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for FixedSizeList { +impl serde::Serialize for JoinOn { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -9305,39 +8321,37 @@ impl serde::Serialize for FixedSizeList { { use serde::ser::SerializeStruct; let mut len = 0; - if self.field_type.is_some() { + if self.left.is_some() { len += 1; } - if self.list_size != 0 { + if self.right.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.FixedSizeList", len)?; - if let Some(v) = self.field_type.as_ref() { - struct_ser.serialize_field("fieldType", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.JoinOn", len)?; + if let Some(v) = self.left.as_ref() { + struct_ser.serialize_field("left", v)?; } - if self.list_size != 0 { - struct_ser.serialize_field("listSize", &self.list_size)?; + if let Some(v) = self.right.as_ref() { + struct_ser.serialize_field("right", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for FixedSizeList { +impl<'de> serde::Deserialize<'de> for JoinOn { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "field_type", - "fieldType", - "list_size", - "listSize", + "left", + "right", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - FieldType, - ListSize, + Left, + Right, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -9359,8 +8373,8 @@ impl<'de> serde::Deserialize<'de> for FixedSizeList { E: serde::de::Error, { match value { - "fieldType" | "field_type" => Ok(GeneratedField::FieldType), - "listSize" | "list_size" => Ok(GeneratedField::ListSize), + "left" => Ok(GeneratedField::Left), + "right" => Ok(GeneratedField::Right), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -9370,46 +8384,44 @@ impl<'de> serde::Deserialize<'de> for FixedSizeList { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = FixedSizeList; + type Value = JoinOn; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.FixedSizeList") + formatter.write_str("struct datafusion.JoinOn") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut field_type__ = None; - let mut list_size__ = None; + let mut left__ = None; + let mut right__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::FieldType => { - if field_type__.is_some() { - return Err(serde::de::Error::duplicate_field("fieldType")); + GeneratedField::Left => { + if left__.is_some() { + return Err(serde::de::Error::duplicate_field("left")); } - field_type__ = map_.next_value()?; + left__ = map_.next_value()?; } - GeneratedField::ListSize => { - if list_size__.is_some() { - return Err(serde::de::Error::duplicate_field("listSize")); + GeneratedField::Right => { + if right__.is_some() { + return Err(serde::de::Error::duplicate_field("right")); } - list_size__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + right__ = map_.next_value()?; } } } - Ok(FixedSizeList { - field_type: field_type__, - list_size: list_size__.unwrap_or_default(), + Ok(JoinOn { + left: left__, + right: right__, }) } } - deserializer.deserialize_struct("datafusion.FixedSizeList", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.JoinOn", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for FullTableReference { +impl serde::Serialize for JsonSink { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -9417,45 +8429,38 @@ impl serde::Serialize for FullTableReference { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.catalog.is_empty() { - len += 1; - } - if !self.schema.is_empty() { + if self.config.is_some() { len += 1; } - if !self.table.is_empty() { + if self.writer_options.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.FullTableReference", len)?; - if !self.catalog.is_empty() { - struct_ser.serialize_field("catalog", &self.catalog)?; - } - if !self.schema.is_empty() { - struct_ser.serialize_field("schema", &self.schema)?; + let mut struct_ser = serializer.serialize_struct("datafusion.JsonSink", len)?; + if let Some(v) = self.config.as_ref() { + struct_ser.serialize_field("config", v)?; } - if !self.table.is_empty() { - struct_ser.serialize_field("table", &self.table)?; + if let Some(v) = self.writer_options.as_ref() { + struct_ser.serialize_field("writerOptions", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for FullTableReference { +impl<'de> serde::Deserialize<'de> for JsonSink { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "catalog", - "schema", - "table", + "config", + "writer_options", + "writerOptions", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Catalog, - Schema, - Table, + Config, + WriterOptions, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -9477,9 +8482,8 @@ impl<'de> serde::Deserialize<'de> for FullTableReference { E: serde::de::Error, { match value { - "catalog" => Ok(GeneratedField::Catalog), - "schema" => Ok(GeneratedField::Schema), - "table" => Ok(GeneratedField::Table), + "config" => Ok(GeneratedField::Config), + "writerOptions" | "writer_options" => Ok(GeneratedField::WriterOptions), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -9489,52 +8493,44 @@ impl<'de> serde::Deserialize<'de> for FullTableReference { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = FullTableReference; + type Value = JsonSink; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.FullTableReference") + formatter.write_str("struct datafusion.JsonSink") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut catalog__ = None; - let mut schema__ = None; - let mut table__ = None; + let mut config__ = None; + let mut writer_options__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Catalog => { - if catalog__.is_some() { - return Err(serde::de::Error::duplicate_field("catalog")); - } - catalog__ = Some(map_.next_value()?); - } - GeneratedField::Schema => { - if schema__.is_some() { - return Err(serde::de::Error::duplicate_field("schema")); + GeneratedField::Config => { + if config__.is_some() { + return Err(serde::de::Error::duplicate_field("config")); } - schema__ = Some(map_.next_value()?); + config__ = map_.next_value()?; } - GeneratedField::Table => { - if table__.is_some() { - return Err(serde::de::Error::duplicate_field("table")); + GeneratedField::WriterOptions => { + if writer_options__.is_some() { + return Err(serde::de::Error::duplicate_field("writerOptions")); } - table__ = Some(map_.next_value()?); + writer_options__ = map_.next_value()?; } } } - Ok(FullTableReference { - catalog: catalog__.unwrap_or_default(), - schema: schema__.unwrap_or_default(), - table: table__.unwrap_or_default(), + Ok(JsonSink { + config: config__, + writer_options: writer_options__, }) } } - deserializer.deserialize_struct("datafusion.FullTableReference", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.JsonSink", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for GetIndexedField { +impl serde::Serialize for JsonSinkExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -9542,54 +8538,55 @@ impl serde::Serialize for GetIndexedField { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if self.input.is_some() { len += 1; } - if self.field.is_some() { + if self.sink.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.GetIndexedField", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + if self.sink_schema.is_some() { + len += 1; } - if let Some(v) = self.field.as_ref() { - match v { - get_indexed_field::Field::NamedStructField(v) => { - struct_ser.serialize_field("namedStructField", v)?; - } - get_indexed_field::Field::ListIndex(v) => { - struct_ser.serialize_field("listIndex", v)?; - } - get_indexed_field::Field::ListRange(v) => { - struct_ser.serialize_field("listRange", v)?; - } - } + if self.sort_order.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.JsonSinkExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if let Some(v) = self.sink.as_ref() { + struct_ser.serialize_field("sink", v)?; + } + if let Some(v) = self.sink_schema.as_ref() { + struct_ser.serialize_field("sinkSchema", v)?; + } + if let Some(v) = self.sort_order.as_ref() { + struct_ser.serialize_field("sortOrder", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for GetIndexedField { +impl<'de> serde::Deserialize<'de> for JsonSinkExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", - "named_struct_field", - "namedStructField", - "list_index", - "listIndex", - "list_range", - "listRange", + "input", + "sink", + "sink_schema", + "sinkSchema", + "sort_order", + "sortOrder", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, - NamedStructField, - ListIndex, - ListRange, + Input, + Sink, + SinkSchema, + SortOrder, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -9611,10 +8608,10 @@ impl<'de> serde::Deserialize<'de> for GetIndexedField { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), - "namedStructField" | "named_struct_field" => Ok(GeneratedField::NamedStructField), - "listIndex" | "list_index" => Ok(GeneratedField::ListIndex), - "listRange" | "list_range" => Ok(GeneratedField::ListRange), + "input" => Ok(GeneratedField::Input), + "sink" => Ok(GeneratedField::Sink), + "sinkSchema" | "sink_schema" => Ok(GeneratedField::SinkSchema), + "sortOrder" | "sort_order" => Ok(GeneratedField::SortOrder), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -9624,59 +8621,60 @@ impl<'de> serde::Deserialize<'de> for GetIndexedField { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GetIndexedField; + type Value = JsonSinkExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.GetIndexedField") + formatter.write_str("struct datafusion.JsonSinkExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; - let mut field__ = None; + let mut input__ = None; + let mut sink__ = None; + let mut sink_schema__ = None; + let mut sort_order__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - expr__ = map_.next_value()?; + input__ = map_.next_value()?; } - GeneratedField::NamedStructField => { - if field__.is_some() { - return Err(serde::de::Error::duplicate_field("namedStructField")); + GeneratedField::Sink => { + if sink__.is_some() { + return Err(serde::de::Error::duplicate_field("sink")); } - field__ = map_.next_value::<::std::option::Option<_>>()?.map(get_indexed_field::Field::NamedStructField) -; + sink__ = map_.next_value()?; } - GeneratedField::ListIndex => { - if field__.is_some() { - return Err(serde::de::Error::duplicate_field("listIndex")); + GeneratedField::SinkSchema => { + if sink_schema__.is_some() { + return Err(serde::de::Error::duplicate_field("sinkSchema")); } - field__ = map_.next_value::<::std::option::Option<_>>()?.map(get_indexed_field::Field::ListIndex) -; + sink_schema__ = map_.next_value()?; } - GeneratedField::ListRange => { - if field__.is_some() { - return Err(serde::de::Error::duplicate_field("listRange")); + GeneratedField::SortOrder => { + if sort_order__.is_some() { + return Err(serde::de::Error::duplicate_field("sortOrder")); } - field__ = map_.next_value::<::std::option::Option<_>>()?.map(get_indexed_field::Field::ListRange) -; + sort_order__ = map_.next_value()?; } } } - Ok(GetIndexedField { - expr: expr__, - field: field__, + Ok(JsonSinkExecNode { + input: input__, + sink: sink__, + sink_schema: sink_schema__, + sort_order: sort_order__, }) } } - deserializer.deserialize_struct("datafusion.GetIndexedField", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.JsonSinkExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for GlobalLimitExecNode { +impl serde::Serialize for LikeNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -9684,46 +8682,54 @@ impl serde::Serialize for GlobalLimitExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.input.is_some() { + if self.negated { len += 1; } - if self.skip != 0 { + if self.expr.is_some() { len += 1; } - if self.fetch != 0 { + if self.pattern.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.GlobalLimitExecNode", len)?; - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; + if !self.escape_char.is_empty() { + len += 1; } - if self.skip != 0 { - struct_ser.serialize_field("skip", &self.skip)?; + let mut struct_ser = serializer.serialize_struct("datafusion.LikeNode", len)?; + if self.negated { + struct_ser.serialize_field("negated", &self.negated)?; } - if self.fetch != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("fetch", ToString::to_string(&self.fetch).as_str())?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; + } + if let Some(v) = self.pattern.as_ref() { + struct_ser.serialize_field("pattern", v)?; + } + if !self.escape_char.is_empty() { + struct_ser.serialize_field("escapeChar", &self.escape_char)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for GlobalLimitExecNode { +impl<'de> serde::Deserialize<'de> for LikeNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "input", - "skip", - "fetch", + "negated", + "expr", + "pattern", + "escape_char", + "escapeChar", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Input, - Skip, - Fetch, + Negated, + Expr, + Pattern, + EscapeChar, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -9745,9 +8751,10 @@ impl<'de> serde::Deserialize<'de> for GlobalLimitExecNode { E: serde::de::Error, { match value { - "input" => Ok(GeneratedField::Input), - "skip" => Ok(GeneratedField::Skip), - "fetch" => Ok(GeneratedField::Fetch), + "negated" => Ok(GeneratedField::Negated), + "expr" => Ok(GeneratedField::Expr), + "pattern" => Ok(GeneratedField::Pattern), + "escapeChar" | "escape_char" => Ok(GeneratedField::EscapeChar), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -9757,56 +8764,60 @@ impl<'de> serde::Deserialize<'de> for GlobalLimitExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GlobalLimitExecNode; + type Value = LikeNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.GlobalLimitExecNode") + formatter.write_str("struct datafusion.LikeNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut input__ = None; - let mut skip__ = None; - let mut fetch__ = None; + let mut negated__ = None; + let mut expr__ = None; + let mut pattern__ = None; + let mut escape_char__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); + GeneratedField::Negated => { + if negated__.is_some() { + return Err(serde::de::Error::duplicate_field("negated")); } - input__ = map_.next_value()?; + negated__ = Some(map_.next_value()?); } - GeneratedField::Skip => { - if skip__.is_some() { - return Err(serde::de::Error::duplicate_field("skip")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - skip__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + expr__ = map_.next_value()?; } - GeneratedField::Fetch => { - if fetch__.is_some() { - return Err(serde::de::Error::duplicate_field("fetch")); + GeneratedField::Pattern => { + if pattern__.is_some() { + return Err(serde::de::Error::duplicate_field("pattern")); } - fetch__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + pattern__ = map_.next_value()?; + } + GeneratedField::EscapeChar => { + if escape_char__.is_some() { + return Err(serde::de::Error::duplicate_field("escapeChar")); + } + escape_char__ = Some(map_.next_value()?); } } } - Ok(GlobalLimitExecNode { - input: input__, - skip: skip__.unwrap_or_default(), - fetch: fetch__.unwrap_or_default(), + Ok(LikeNode { + negated: negated__.unwrap_or_default(), + expr: expr__, + pattern: pattern__, + escape_char: escape_char__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.GlobalLimitExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.LikeNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for GroupingSetNode { +impl serde::Serialize for LimitNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -9814,29 +8825,49 @@ impl serde::Serialize for GroupingSetNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.expr.is_empty() { + if self.input.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.GroupingSetNode", len)?; - if !self.expr.is_empty() { - struct_ser.serialize_field("expr", &self.expr)?; + if self.skip != 0 { + len += 1; + } + if self.fetch != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.LimitNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if self.skip != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("skip", ToString::to_string(&self.skip).as_str())?; + } + if self.fetch != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("fetch", ToString::to_string(&self.fetch).as_str())?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for GroupingSetNode { +impl<'de> serde::Deserialize<'de> for LimitNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "input", + "skip", + "fetch", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, + Input, + Skip, + Fetch, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -9858,7 +8889,9 @@ impl<'de> serde::Deserialize<'de> for GroupingSetNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), + "input" => Ok(GeneratedField::Input), + "skip" => Ok(GeneratedField::Skip), + "fetch" => Ok(GeneratedField::Fetch), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -9868,36 +8901,56 @@ impl<'de> serde::Deserialize<'de> for GroupingSetNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GroupingSetNode; + type Value = LimitNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.GroupingSetNode") + formatter.write_str("struct datafusion.LimitNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; + let mut input__ = None; + let mut skip__ = None; + let mut fetch__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - expr__ = Some(map_.next_value()?); + input__ = map_.next_value()?; + } + GeneratedField::Skip => { + if skip__.is_some() { + return Err(serde::de::Error::duplicate_field("skip")); + } + skip__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Fetch => { + if fetch__.is_some() { + return Err(serde::de::Error::duplicate_field("fetch")); + } + fetch__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } } } - Ok(GroupingSetNode { - expr: expr__.unwrap_or_default(), + Ok(LimitNode { + input: input__, + skip: skip__.unwrap_or_default(), + fetch: fetch__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.GroupingSetNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.LimitNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for HashJoinExecNode { +impl serde::Serialize for ListIndex { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -9905,92 +8958,29 @@ impl serde::Serialize for HashJoinExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.left.is_some() { - len += 1; - } - if self.right.is_some() { - len += 1; - } - if !self.on.is_empty() { - len += 1; - } - if self.join_type != 0 { - len += 1; - } - if self.partition_mode != 0 { - len += 1; - } - if self.null_equals_null { - len += 1; - } - if self.filter.is_some() { - len += 1; - } - if !self.projection.is_empty() { + if self.key.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.HashJoinExecNode", len)?; - if let Some(v) = self.left.as_ref() { - struct_ser.serialize_field("left", v)?; - } - if let Some(v) = self.right.as_ref() { - struct_ser.serialize_field("right", v)?; - } - if !self.on.is_empty() { - struct_ser.serialize_field("on", &self.on)?; - } - if self.join_type != 0 { - let v = JoinType::try_from(self.join_type) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_type)))?; - struct_ser.serialize_field("joinType", &v)?; - } - if self.partition_mode != 0 { - let v = PartitionMode::try_from(self.partition_mode) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.partition_mode)))?; - struct_ser.serialize_field("partitionMode", &v)?; - } - if self.null_equals_null { - struct_ser.serialize_field("nullEqualsNull", &self.null_equals_null)?; - } - if let Some(v) = self.filter.as_ref() { - struct_ser.serialize_field("filter", v)?; - } - if !self.projection.is_empty() { - struct_ser.serialize_field("projection", &self.projection)?; + let mut struct_ser = serializer.serialize_struct("datafusion.ListIndex", len)?; + if let Some(v) = self.key.as_ref() { + struct_ser.serialize_field("key", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for HashJoinExecNode { +impl<'de> serde::Deserialize<'de> for ListIndex { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "left", - "right", - "on", - "join_type", - "joinType", - "partition_mode", - "partitionMode", - "null_equals_null", - "nullEqualsNull", - "filter", - "projection", + "key", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Left, - Right, - On, - JoinType, - PartitionMode, - NullEqualsNull, - Filter, - Projection, + Key, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -10012,14 +9002,7 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { E: serde::de::Error, { match value { - "left" => Ok(GeneratedField::Left), - "right" => Ok(GeneratedField::Right), - "on" => Ok(GeneratedField::On), - "joinType" | "join_type" => Ok(GeneratedField::JoinType), - "partitionMode" | "partition_mode" => Ok(GeneratedField::PartitionMode), - "nullEqualsNull" | "null_equals_null" => Ok(GeneratedField::NullEqualsNull), - "filter" => Ok(GeneratedField::Filter), - "projection" => Ok(GeneratedField::Projection), + "key" => Ok(GeneratedField::Key), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -10029,95 +9012,36 @@ impl<'de> serde::Deserialize<'de> for HashJoinExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = HashJoinExecNode; + type Value = ListIndex; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.HashJoinExecNode") + formatter.write_str("struct datafusion.ListIndex") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut left__ = None; - let mut right__ = None; - let mut on__ = None; - let mut join_type__ = None; - let mut partition_mode__ = None; - let mut null_equals_null__ = None; - let mut filter__ = None; - let mut projection__ = None; + let mut key__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Left => { - if left__.is_some() { - return Err(serde::de::Error::duplicate_field("left")); - } - left__ = map_.next_value()?; - } - GeneratedField::Right => { - if right__.is_some() { - return Err(serde::de::Error::duplicate_field("right")); - } - right__ = map_.next_value()?; - } - GeneratedField::On => { - if on__.is_some() { - return Err(serde::de::Error::duplicate_field("on")); - } - on__ = Some(map_.next_value()?); - } - GeneratedField::JoinType => { - if join_type__.is_some() { - return Err(serde::de::Error::duplicate_field("joinType")); - } - join_type__ = Some(map_.next_value::()? as i32); - } - GeneratedField::PartitionMode => { - if partition_mode__.is_some() { - return Err(serde::de::Error::duplicate_field("partitionMode")); - } - partition_mode__ = Some(map_.next_value::()? as i32); - } - GeneratedField::NullEqualsNull => { - if null_equals_null__.is_some() { - return Err(serde::de::Error::duplicate_field("nullEqualsNull")); - } - null_equals_null__ = Some(map_.next_value()?); - } - GeneratedField::Filter => { - if filter__.is_some() { - return Err(serde::de::Error::duplicate_field("filter")); - } - filter__ = map_.next_value()?; - } - GeneratedField::Projection => { - if projection__.is_some() { - return Err(serde::de::Error::duplicate_field("projection")); + GeneratedField::Key => { + if key__.is_some() { + return Err(serde::de::Error::duplicate_field("key")); } - projection__ = - Some(map_.next_value::>>()? - .into_iter().map(|x| x.0).collect()) - ; + key__ = map_.next_value()?; } } } - Ok(HashJoinExecNode { - left: left__, - right: right__, - on: on__.unwrap_or_default(), - join_type: join_type__.unwrap_or_default(), - partition_mode: partition_mode__.unwrap_or_default(), - null_equals_null: null_equals_null__.unwrap_or_default(), - filter: filter__, - projection: projection__.unwrap_or_default(), + Ok(ListIndex { + key: key__, }) } } - deserializer.deserialize_struct("datafusion.HashJoinExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ListIndex", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for HashRepartition { +impl serde::Serialize for ListRange { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -10125,40 +9049,45 @@ impl serde::Serialize for HashRepartition { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.hash_expr.is_empty() { + if self.start.is_some() { len += 1; } - if self.partition_count != 0 { + if self.stop.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.HashRepartition", len)?; - if !self.hash_expr.is_empty() { - struct_ser.serialize_field("hashExpr", &self.hash_expr)?; + if self.stride.is_some() { + len += 1; } - if self.partition_count != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("partitionCount", ToString::to_string(&self.partition_count).as_str())?; + let mut struct_ser = serializer.serialize_struct("datafusion.ListRange", len)?; + if let Some(v) = self.start.as_ref() { + struct_ser.serialize_field("start", v)?; + } + if let Some(v) = self.stop.as_ref() { + struct_ser.serialize_field("stop", v)?; + } + if let Some(v) = self.stride.as_ref() { + struct_ser.serialize_field("stride", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for HashRepartition { +impl<'de> serde::Deserialize<'de> for ListRange { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "hash_expr", - "hashExpr", - "partition_count", - "partitionCount", + "start", + "stop", + "stride", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - HashExpr, - PartitionCount, + Start, + Stop, + Stride, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -10180,8 +9109,9 @@ impl<'de> serde::Deserialize<'de> for HashRepartition { E: serde::de::Error, { match value { - "hashExpr" | "hash_expr" => Ok(GeneratedField::HashExpr), - "partitionCount" | "partition_count" => Ok(GeneratedField::PartitionCount), + "start" => Ok(GeneratedField::Start), + "stop" => Ok(GeneratedField::Stop), + "stride" => Ok(GeneratedField::Stride), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -10191,46 +9121,52 @@ impl<'de> serde::Deserialize<'de> for HashRepartition { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = HashRepartition; + type Value = ListRange; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.HashRepartition") + formatter.write_str("struct datafusion.ListRange") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut hash_expr__ = None; - let mut partition_count__ = None; + let mut start__ = None; + let mut stop__ = None; + let mut stride__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::HashExpr => { - if hash_expr__.is_some() { - return Err(serde::de::Error::duplicate_field("hashExpr")); + GeneratedField::Start => { + if start__.is_some() { + return Err(serde::de::Error::duplicate_field("start")); } - hash_expr__ = Some(map_.next_value()?); + start__ = map_.next_value()?; } - GeneratedField::PartitionCount => { - if partition_count__.is_some() { - return Err(serde::de::Error::duplicate_field("partitionCount")); + GeneratedField::Stop => { + if stop__.is_some() { + return Err(serde::de::Error::duplicate_field("stop")); } - partition_count__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + stop__ = map_.next_value()?; + } + GeneratedField::Stride => { + if stride__.is_some() { + return Err(serde::de::Error::duplicate_field("stride")); + } + stride__ = map_.next_value()?; } } } - Ok(HashRepartition { - hash_expr: hash_expr__.unwrap_or_default(), - partition_count: partition_count__.unwrap_or_default(), + Ok(ListRange { + start: start__, + stop: stop__, + stride: stride__, }) } } - deserializer.deserialize_struct("datafusion.HashRepartition", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ListRange", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ILikeNode { +impl serde::Serialize for ListUnnest { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -10238,54 +9174,38 @@ impl serde::Serialize for ILikeNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.negated { - len += 1; - } - if self.expr.is_some() { + if self.index_in_input_schema != 0 { len += 1; } - if self.pattern.is_some() { - len += 1; - } - if !self.escape_char.is_empty() { + if self.depth != 0 { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ILikeNode", len)?; - if self.negated { - struct_ser.serialize_field("negated", &self.negated)?; - } - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; - } - if let Some(v) = self.pattern.as_ref() { - struct_ser.serialize_field("pattern", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.ListUnnest", len)?; + if self.index_in_input_schema != 0 { + struct_ser.serialize_field("indexInInputSchema", &self.index_in_input_schema)?; } - if !self.escape_char.is_empty() { - struct_ser.serialize_field("escapeChar", &self.escape_char)?; + if self.depth != 0 { + struct_ser.serialize_field("depth", &self.depth)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ILikeNode { +impl<'de> serde::Deserialize<'de> for ListUnnest { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "negated", - "expr", - "pattern", - "escape_char", - "escapeChar", + "index_in_input_schema", + "indexInInputSchema", + "depth", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Negated, - Expr, - Pattern, - EscapeChar, + IndexInInputSchema, + Depth, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -10307,10 +9227,8 @@ impl<'de> serde::Deserialize<'de> for ILikeNode { E: serde::de::Error, { match value { - "negated" => Ok(GeneratedField::Negated), - "expr" => Ok(GeneratedField::Expr), - "pattern" => Ok(GeneratedField::Pattern), - "escapeChar" | "escape_char" => Ok(GeneratedField::EscapeChar), + "indexInInputSchema" | "index_in_input_schema" => Ok(GeneratedField::IndexInInputSchema), + "depth" => Ok(GeneratedField::Depth), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -10320,60 +9238,48 @@ impl<'de> serde::Deserialize<'de> for ILikeNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ILikeNode; + type Value = ListUnnest; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ILikeNode") + formatter.write_str("struct datafusion.ListUnnest") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut negated__ = None; - let mut expr__ = None; - let mut pattern__ = None; - let mut escape_char__ = None; + let mut index_in_input_schema__ = None; + let mut depth__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Negated => { - if negated__.is_some() { - return Err(serde::de::Error::duplicate_field("negated")); - } - negated__ = Some(map_.next_value()?); - } - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); - } - expr__ = map_.next_value()?; - } - GeneratedField::Pattern => { - if pattern__.is_some() { - return Err(serde::de::Error::duplicate_field("pattern")); + GeneratedField::IndexInInputSchema => { + if index_in_input_schema__.is_some() { + return Err(serde::de::Error::duplicate_field("indexInInputSchema")); } - pattern__ = map_.next_value()?; + index_in_input_schema__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } - GeneratedField::EscapeChar => { - if escape_char__.is_some() { - return Err(serde::de::Error::duplicate_field("escapeChar")); + GeneratedField::Depth => { + if depth__.is_some() { + return Err(serde::de::Error::duplicate_field("depth")); } - escape_char__ = Some(map_.next_value()?); + depth__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } } } - Ok(ILikeNode { - negated: negated__.unwrap_or_default(), - expr: expr__, - pattern: pattern__, - escape_char: escape_char__.unwrap_or_default(), + Ok(ListUnnest { + index_in_input_schema: index_in_input_schema__.unwrap_or_default(), + depth: depth__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.ILikeNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ListUnnest", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for InListNode { +impl serde::Serialize for ListingTableScanNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -10381,45 +9287,134 @@ impl serde::Serialize for InListNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if self.table_name.is_some() { len += 1; } - if !self.list.is_empty() { + if !self.paths.is_empty() { len += 1; } - if self.negated { + if !self.file_extension.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.InListNode", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + if self.projection.is_some() { + len += 1; } - if !self.list.is_empty() { - struct_ser.serialize_field("list", &self.list)?; + if self.schema.is_some() { + len += 1; } - if self.negated { - struct_ser.serialize_field("negated", &self.negated)?; + if !self.filters.is_empty() { + len += 1; + } + if !self.table_partition_cols.is_empty() { + len += 1; + } + if self.collect_stat { + len += 1; + } + if self.target_partitions != 0 { + len += 1; + } + if !self.file_sort_order.is_empty() { + len += 1; + } + if self.file_format_type.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ListingTableScanNode", len)?; + if let Some(v) = self.table_name.as_ref() { + struct_ser.serialize_field("tableName", v)?; + } + if !self.paths.is_empty() { + struct_ser.serialize_field("paths", &self.paths)?; + } + if !self.file_extension.is_empty() { + struct_ser.serialize_field("fileExtension", &self.file_extension)?; + } + if let Some(v) = self.projection.as_ref() { + struct_ser.serialize_field("projection", v)?; + } + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; + } + if !self.filters.is_empty() { + struct_ser.serialize_field("filters", &self.filters)?; + } + if !self.table_partition_cols.is_empty() { + struct_ser.serialize_field("tablePartitionCols", &self.table_partition_cols)?; + } + if self.collect_stat { + struct_ser.serialize_field("collectStat", &self.collect_stat)?; + } + if self.target_partitions != 0 { + struct_ser.serialize_field("targetPartitions", &self.target_partitions)?; + } + if !self.file_sort_order.is_empty() { + struct_ser.serialize_field("fileSortOrder", &self.file_sort_order)?; + } + if let Some(v) = self.file_format_type.as_ref() { + match v { + listing_table_scan_node::FileFormatType::Csv(v) => { + struct_ser.serialize_field("csv", v)?; + } + listing_table_scan_node::FileFormatType::Parquet(v) => { + struct_ser.serialize_field("parquet", v)?; + } + listing_table_scan_node::FileFormatType::Avro(v) => { + struct_ser.serialize_field("avro", v)?; + } + listing_table_scan_node::FileFormatType::Json(v) => { + struct_ser.serialize_field("json", v)?; + } + } } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for InListNode { +impl<'de> serde::Deserialize<'de> for ListingTableScanNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", - "list", - "negated", + "table_name", + "tableName", + "paths", + "file_extension", + "fileExtension", + "projection", + "schema", + "filters", + "table_partition_cols", + "tablePartitionCols", + "collect_stat", + "collectStat", + "target_partitions", + "targetPartitions", + "file_sort_order", + "fileSortOrder", + "csv", + "parquet", + "avro", + "json", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, - List, - Negated, + TableName, + Paths, + FileExtension, + Projection, + Schema, + Filters, + TablePartitionCols, + CollectStat, + TargetPartitions, + FileSortOrder, + Csv, + Parquet, + Avro, + Json, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -10441,9 +9436,20 @@ impl<'de> serde::Deserialize<'de> for InListNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), - "list" => Ok(GeneratedField::List), - "negated" => Ok(GeneratedField::Negated), + "tableName" | "table_name" => Ok(GeneratedField::TableName), + "paths" => Ok(GeneratedField::Paths), + "fileExtension" | "file_extension" => Ok(GeneratedField::FileExtension), + "projection" => Ok(GeneratedField::Projection), + "schema" => Ok(GeneratedField::Schema), + "filters" => Ok(GeneratedField::Filters), + "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), + "collectStat" | "collect_stat" => Ok(GeneratedField::CollectStat), + "targetPartitions" | "target_partitions" => Ok(GeneratedField::TargetPartitions), + "fileSortOrder" | "file_sort_order" => Ok(GeneratedField::FileSortOrder), + "csv" => Ok(GeneratedField::Csv), + "parquet" => Ok(GeneratedField::Parquet), + "avro" => Ok(GeneratedField::Avro), + "json" => Ok(GeneratedField::Json), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -10453,52 +9459,140 @@ impl<'de> serde::Deserialize<'de> for InListNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = InListNode; + type Value = ListingTableScanNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.InListNode") + formatter.write_str("struct datafusion.ListingTableScanNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; - let mut list__ = None; - let mut negated__ = None; + let mut table_name__ = None; + let mut paths__ = None; + let mut file_extension__ = None; + let mut projection__ = None; + let mut schema__ = None; + let mut filters__ = None; + let mut table_partition_cols__ = None; + let mut collect_stat__ = None; + let mut target_partitions__ = None; + let mut file_sort_order__ = None; + let mut file_format_type__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::TableName => { + if table_name__.is_some() { + return Err(serde::de::Error::duplicate_field("tableName")); } - expr__ = map_.next_value()?; + table_name__ = map_.next_value()?; } - GeneratedField::List => { - if list__.is_some() { - return Err(serde::de::Error::duplicate_field("list")); + GeneratedField::Paths => { + if paths__.is_some() { + return Err(serde::de::Error::duplicate_field("paths")); } - list__ = Some(map_.next_value()?); + paths__ = Some(map_.next_value()?); } - GeneratedField::Negated => { - if negated__.is_some() { - return Err(serde::de::Error::duplicate_field("negated")); + GeneratedField::FileExtension => { + if file_extension__.is_some() { + return Err(serde::de::Error::duplicate_field("fileExtension")); } - negated__ = Some(map_.next_value()?); + file_extension__ = Some(map_.next_value()?); + } + GeneratedField::Projection => { + if projection__.is_some() { + return Err(serde::de::Error::duplicate_field("projection")); + } + projection__ = map_.next_value()?; + } + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); + } + schema__ = map_.next_value()?; + } + GeneratedField::Filters => { + if filters__.is_some() { + return Err(serde::de::Error::duplicate_field("filters")); + } + filters__ = Some(map_.next_value()?); + } + GeneratedField::TablePartitionCols => { + if table_partition_cols__.is_some() { + return Err(serde::de::Error::duplicate_field("tablePartitionCols")); + } + table_partition_cols__ = Some(map_.next_value()?); + } + GeneratedField::CollectStat => { + if collect_stat__.is_some() { + return Err(serde::de::Error::duplicate_field("collectStat")); + } + collect_stat__ = Some(map_.next_value()?); + } + GeneratedField::TargetPartitions => { + if target_partitions__.is_some() { + return Err(serde::de::Error::duplicate_field("targetPartitions")); + } + target_partitions__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::FileSortOrder => { + if file_sort_order__.is_some() { + return Err(serde::de::Error::duplicate_field("fileSortOrder")); + } + file_sort_order__ = Some(map_.next_value()?); + } + GeneratedField::Csv => { + if file_format_type__.is_some() { + return Err(serde::de::Error::duplicate_field("csv")); + } + file_format_type__ = map_.next_value::<::std::option::Option<_>>()?.map(listing_table_scan_node::FileFormatType::Csv) +; + } + GeneratedField::Parquet => { + if file_format_type__.is_some() { + return Err(serde::de::Error::duplicate_field("parquet")); + } + file_format_type__ = map_.next_value::<::std::option::Option<_>>()?.map(listing_table_scan_node::FileFormatType::Parquet) +; + } + GeneratedField::Avro => { + if file_format_type__.is_some() { + return Err(serde::de::Error::duplicate_field("avro")); + } + file_format_type__ = map_.next_value::<::std::option::Option<_>>()?.map(listing_table_scan_node::FileFormatType::Avro) +; + } + GeneratedField::Json => { + if file_format_type__.is_some() { + return Err(serde::de::Error::duplicate_field("json")); + } + file_format_type__ = map_.next_value::<::std::option::Option<_>>()?.map(listing_table_scan_node::FileFormatType::Json) +; } } } - Ok(InListNode { - expr: expr__, - list: list__.unwrap_or_default(), - negated: negated__.unwrap_or_default(), + Ok(ListingTableScanNode { + table_name: table_name__, + paths: paths__.unwrap_or_default(), + file_extension: file_extension__.unwrap_or_default(), + projection: projection__, + schema: schema__, + filters: filters__.unwrap_or_default(), + table_partition_cols: table_partition_cols__.unwrap_or_default(), + collect_stat: collect_stat__.unwrap_or_default(), + target_partitions: target_partitions__.unwrap_or_default(), + file_sort_order: file_sort_order__.unwrap_or_default(), + file_format_type: file_format_type__, }) } } - deserializer.deserialize_struct("datafusion.InListNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ListingTableScanNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for InterleaveExecNode { +impl serde::Serialize for LocalLimitExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -10506,29 +9600,37 @@ impl serde::Serialize for InterleaveExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.inputs.is_empty() { + if self.input.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.InterleaveExecNode", len)?; - if !self.inputs.is_empty() { - struct_ser.serialize_field("inputs", &self.inputs)?; + if self.fetch != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.LocalLimitExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if self.fetch != 0 { + struct_ser.serialize_field("fetch", &self.fetch)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for InterleaveExecNode { +impl<'de> serde::Deserialize<'de> for LocalLimitExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "inputs", + "input", + "fetch", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Inputs, + Input, + Fetch, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -10550,7 +9652,8 @@ impl<'de> serde::Deserialize<'de> for InterleaveExecNode { E: serde::de::Error, { match value { - "inputs" => Ok(GeneratedField::Inputs), + "input" => Ok(GeneratedField::Input), + "fetch" => Ok(GeneratedField::Fetch), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -10560,36 +9663,46 @@ impl<'de> serde::Deserialize<'de> for InterleaveExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = InterleaveExecNode; + type Value = LocalLimitExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.InterleaveExecNode") + formatter.write_str("struct datafusion.LocalLimitExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut inputs__ = None; + let mut input__ = None; + let mut fetch__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Inputs => { - if inputs__.is_some() { - return Err(serde::de::Error::duplicate_field("inputs")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - inputs__ = Some(map_.next_value()?); + input__ = map_.next_value()?; + } + GeneratedField::Fetch => { + if fetch__.is_some() { + return Err(serde::de::Error::duplicate_field("fetch")); + } + fetch__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } } } - Ok(InterleaveExecNode { - inputs: inputs__.unwrap_or_default(), + Ok(LocalLimitExecNode { + input: input__, + fetch: fetch__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.InterleaveExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.LocalLimitExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for IntervalMonthDayNanoValue { +impl serde::Serialize for LogicalExprList { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -10597,46 +9710,29 @@ impl serde::Serialize for IntervalMonthDayNanoValue { { use serde::ser::SerializeStruct; let mut len = 0; - if self.months != 0 { - len += 1; - } - if self.days != 0 { - len += 1; - } - if self.nanos != 0 { + if !self.expr.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.IntervalMonthDayNanoValue", len)?; - if self.months != 0 { - struct_ser.serialize_field("months", &self.months)?; - } - if self.days != 0 { - struct_ser.serialize_field("days", &self.days)?; - } - if self.nanos != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("nanos", ToString::to_string(&self.nanos).as_str())?; + let mut struct_ser = serializer.serialize_struct("datafusion.LogicalExprList", len)?; + if !self.expr.is_empty() { + struct_ser.serialize_field("expr", &self.expr)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for IntervalMonthDayNanoValue { +impl<'de> serde::Deserialize<'de> for LogicalExprList { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "months", - "days", - "nanos", + "expr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Months, - Days, - Nanos, + Expr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -10658,9 +9754,7 @@ impl<'de> serde::Deserialize<'de> for IntervalMonthDayNanoValue { E: serde::de::Error, { match value { - "months" => Ok(GeneratedField::Months), - "days" => Ok(GeneratedField::Days), - "nanos" => Ok(GeneratedField::Nanos), + "expr" => Ok(GeneratedField::Expr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -10670,132 +9764,36 @@ impl<'de> serde::Deserialize<'de> for IntervalMonthDayNanoValue { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = IntervalMonthDayNanoValue; + type Value = LogicalExprList; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.IntervalMonthDayNanoValue") + formatter.write_str("struct datafusion.LogicalExprList") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut months__ = None; - let mut days__ = None; - let mut nanos__ = None; + let mut expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Months => { - if months__.is_some() { - return Err(serde::de::Error::duplicate_field("months")); - } - months__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::Days => { - if days__.is_some() { - return Err(serde::de::Error::duplicate_field("days")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - days__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::Nanos => { - if nanos__.is_some() { - return Err(serde::de::Error::duplicate_field("nanos")); - } - nanos__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + expr__ = Some(map_.next_value()?); } } } - Ok(IntervalMonthDayNanoValue { - months: months__.unwrap_or_default(), - days: days__.unwrap_or_default(), - nanos: nanos__.unwrap_or_default(), + Ok(LogicalExprList { + expr: expr__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.IntervalMonthDayNanoValue", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for IntervalUnit { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - let variant = match self { - Self::YearMonth => "YearMonth", - Self::DayTime => "DayTime", - Self::MonthDayNano => "MonthDayNano", - }; - serializer.serialize_str(variant) - } -} -impl<'de> serde::Deserialize<'de> for IntervalUnit { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "YearMonth", - "DayTime", - "MonthDayNano", - ]; - - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = IntervalUnit; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - fn visit_i64(self, v: i64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) - }) - } - - fn visit_u64(self, v: u64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) - }) - } - - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "YearMonth" => Ok(IntervalUnit::YearMonth), - "DayTime" => Ok(IntervalUnit::DayTime), - "MonthDayNano" => Ok(IntervalUnit::MonthDayNano), - _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), - } - } - } - deserializer.deserialize_any(GeneratedVisitor) + deserializer.deserialize_struct("datafusion.LogicalExprList", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for IsFalse { +impl serde::Serialize for LogicalExprNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -10803,29 +9801,201 @@ impl serde::Serialize for IsFalse { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if self.expr_type.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.IsFalse", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.LogicalExprNode", len)?; + if let Some(v) = self.expr_type.as_ref() { + match v { + logical_expr_node::ExprType::Column(v) => { + struct_ser.serialize_field("column", v)?; + } + logical_expr_node::ExprType::Alias(v) => { + struct_ser.serialize_field("alias", v)?; + } + logical_expr_node::ExprType::Literal(v) => { + struct_ser.serialize_field("literal", v)?; + } + logical_expr_node::ExprType::BinaryExpr(v) => { + struct_ser.serialize_field("binaryExpr", v)?; + } + logical_expr_node::ExprType::IsNullExpr(v) => { + struct_ser.serialize_field("isNullExpr", v)?; + } + logical_expr_node::ExprType::IsNotNullExpr(v) => { + struct_ser.serialize_field("isNotNullExpr", v)?; + } + logical_expr_node::ExprType::NotExpr(v) => { + struct_ser.serialize_field("notExpr", v)?; + } + logical_expr_node::ExprType::Between(v) => { + struct_ser.serialize_field("between", v)?; + } + logical_expr_node::ExprType::Case(v) => { + struct_ser.serialize_field("case", v)?; + } + logical_expr_node::ExprType::Cast(v) => { + struct_ser.serialize_field("cast", v)?; + } + logical_expr_node::ExprType::Negative(v) => { + struct_ser.serialize_field("negative", v)?; + } + logical_expr_node::ExprType::InList(v) => { + struct_ser.serialize_field("inList", v)?; + } + logical_expr_node::ExprType::Wildcard(v) => { + struct_ser.serialize_field("wildcard", v)?; + } + logical_expr_node::ExprType::TryCast(v) => { + struct_ser.serialize_field("tryCast", v)?; + } + logical_expr_node::ExprType::WindowExpr(v) => { + struct_ser.serialize_field("windowExpr", v)?; + } + logical_expr_node::ExprType::AggregateUdfExpr(v) => { + struct_ser.serialize_field("aggregateUdfExpr", v)?; + } + logical_expr_node::ExprType::ScalarUdfExpr(v) => { + struct_ser.serialize_field("scalarUdfExpr", v)?; + } + logical_expr_node::ExprType::GroupingSet(v) => { + struct_ser.serialize_field("groupingSet", v)?; + } + logical_expr_node::ExprType::Cube(v) => { + struct_ser.serialize_field("cube", v)?; + } + logical_expr_node::ExprType::Rollup(v) => { + struct_ser.serialize_field("rollup", v)?; + } + logical_expr_node::ExprType::IsTrue(v) => { + struct_ser.serialize_field("isTrue", v)?; + } + logical_expr_node::ExprType::IsFalse(v) => { + struct_ser.serialize_field("isFalse", v)?; + } + logical_expr_node::ExprType::IsUnknown(v) => { + struct_ser.serialize_field("isUnknown", v)?; + } + logical_expr_node::ExprType::IsNotTrue(v) => { + struct_ser.serialize_field("isNotTrue", v)?; + } + logical_expr_node::ExprType::IsNotFalse(v) => { + struct_ser.serialize_field("isNotFalse", v)?; + } + logical_expr_node::ExprType::IsNotUnknown(v) => { + struct_ser.serialize_field("isNotUnknown", v)?; + } + logical_expr_node::ExprType::Like(v) => { + struct_ser.serialize_field("like", v)?; + } + logical_expr_node::ExprType::Ilike(v) => { + struct_ser.serialize_field("ilike", v)?; + } + logical_expr_node::ExprType::SimilarTo(v) => { + struct_ser.serialize_field("similarTo", v)?; + } + logical_expr_node::ExprType::Placeholder(v) => { + struct_ser.serialize_field("placeholder", v)?; + } + logical_expr_node::ExprType::Unnest(v) => { + struct_ser.serialize_field("unnest", v)?; + } + } } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for IsFalse { +impl<'de> serde::Deserialize<'de> for LogicalExprNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "column", + "alias", + "literal", + "binary_expr", + "binaryExpr", + "is_null_expr", + "isNullExpr", + "is_not_null_expr", + "isNotNullExpr", + "not_expr", + "notExpr", + "between", + "case_", + "case", + "cast", + "negative", + "in_list", + "inList", + "wildcard", + "try_cast", + "tryCast", + "window_expr", + "windowExpr", + "aggregate_udf_expr", + "aggregateUdfExpr", + "scalar_udf_expr", + "scalarUdfExpr", + "grouping_set", + "groupingSet", + "cube", + "rollup", + "is_true", + "isTrue", + "is_false", + "isFalse", + "is_unknown", + "isUnknown", + "is_not_true", + "isNotTrue", + "is_not_false", + "isNotFalse", + "is_not_unknown", + "isNotUnknown", + "like", + "ilike", + "similar_to", + "similarTo", + "placeholder", + "unnest", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, + Column, + Alias, + Literal, + BinaryExpr, + IsNullExpr, + IsNotNullExpr, + NotExpr, + Between, + Case, + Cast, + Negative, + InList, + Wildcard, + TryCast, + WindowExpr, + AggregateUdfExpr, + ScalarUdfExpr, + GroupingSet, + Cube, + Rollup, + IsTrue, + IsFalse, + IsUnknown, + IsNotTrue, + IsNotFalse, + IsNotUnknown, + Like, + Ilike, + SimilarTo, + Placeholder, + Unnest, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -10847,5441 +10017,37 @@ impl<'de> serde::Deserialize<'de> for IsFalse { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = IsFalse; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.IsFalse") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut expr__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); - } - expr__ = map_.next_value()?; - } - } - } - Ok(IsFalse { - expr: expr__, - }) - } - } - deserializer.deserialize_struct("datafusion.IsFalse", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for IsNotFalse { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.expr.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.IsNotFalse", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for IsNotFalse { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "expr", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Expr, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "expr" => Ok(GeneratedField::Expr), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = IsNotFalse; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.IsNotFalse") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut expr__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); - } - expr__ = map_.next_value()?; - } - } - } - Ok(IsNotFalse { - expr: expr__, - }) - } - } - deserializer.deserialize_struct("datafusion.IsNotFalse", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for IsNotNull { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.expr.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.IsNotNull", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for IsNotNull { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "expr", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Expr, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "expr" => Ok(GeneratedField::Expr), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = IsNotNull; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.IsNotNull") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut expr__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); - } - expr__ = map_.next_value()?; - } - } - } - Ok(IsNotNull { - expr: expr__, - }) - } - } - deserializer.deserialize_struct("datafusion.IsNotNull", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for IsNotTrue { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.expr.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.IsNotTrue", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for IsNotTrue { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "expr", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Expr, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "expr" => Ok(GeneratedField::Expr), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = IsNotTrue; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.IsNotTrue") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut expr__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); - } - expr__ = map_.next_value()?; - } - } - } - Ok(IsNotTrue { - expr: expr__, - }) - } - } - deserializer.deserialize_struct("datafusion.IsNotTrue", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for IsNotUnknown { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.expr.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.IsNotUnknown", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for IsNotUnknown { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "expr", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Expr, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "expr" => Ok(GeneratedField::Expr), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = IsNotUnknown; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.IsNotUnknown") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut expr__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); - } - expr__ = map_.next_value()?; - } - } - } - Ok(IsNotUnknown { - expr: expr__, - }) - } - } - deserializer.deserialize_struct("datafusion.IsNotUnknown", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for IsNull { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.expr.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.IsNull", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for IsNull { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "expr", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Expr, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "expr" => Ok(GeneratedField::Expr), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = IsNull; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.IsNull") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut expr__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); - } - expr__ = map_.next_value()?; - } - } - } - Ok(IsNull { - expr: expr__, - }) - } - } - deserializer.deserialize_struct("datafusion.IsNull", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for IsTrue { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.expr.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.IsTrue", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for IsTrue { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "expr", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Expr, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "expr" => Ok(GeneratedField::Expr), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = IsTrue; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.IsTrue") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut expr__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); - } - expr__ = map_.next_value()?; - } - } - } - Ok(IsTrue { - expr: expr__, - }) - } - } - deserializer.deserialize_struct("datafusion.IsTrue", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for IsUnknown { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.expr.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.IsUnknown", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for IsUnknown { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "expr", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Expr, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "expr" => Ok(GeneratedField::Expr), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = IsUnknown; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.IsUnknown") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut expr__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); - } - expr__ = map_.next_value()?; - } - } - } - Ok(IsUnknown { - expr: expr__, - }) - } - } - deserializer.deserialize_struct("datafusion.IsUnknown", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for JoinConstraint { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - let variant = match self { - Self::On => "ON", - Self::Using => "USING", - }; - serializer.serialize_str(variant) - } -} -impl<'de> serde::Deserialize<'de> for JoinConstraint { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "ON", - "USING", - ]; - - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = JoinConstraint; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - fn visit_i64(self, v: i64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) - }) - } - - fn visit_u64(self, v: u64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) - }) - } - - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "ON" => Ok(JoinConstraint::On), - "USING" => Ok(JoinConstraint::Using), - _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), - } - } - } - deserializer.deserialize_any(GeneratedVisitor) - } -} -impl serde::Serialize for JoinFilter { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.expression.is_some() { - len += 1; - } - if !self.column_indices.is_empty() { - len += 1; - } - if self.schema.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.JoinFilter", len)?; - if let Some(v) = self.expression.as_ref() { - struct_ser.serialize_field("expression", v)?; - } - if !self.column_indices.is_empty() { - struct_ser.serialize_field("columnIndices", &self.column_indices)?; - } - if let Some(v) = self.schema.as_ref() { - struct_ser.serialize_field("schema", v)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for JoinFilter { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "expression", - "column_indices", - "columnIndices", - "schema", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Expression, - ColumnIndices, - Schema, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "expression" => Ok(GeneratedField::Expression), - "columnIndices" | "column_indices" => Ok(GeneratedField::ColumnIndices), - "schema" => Ok(GeneratedField::Schema), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = JoinFilter; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.JoinFilter") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut expression__ = None; - let mut column_indices__ = None; - let mut schema__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Expression => { - if expression__.is_some() { - return Err(serde::de::Error::duplicate_field("expression")); - } - expression__ = map_.next_value()?; - } - GeneratedField::ColumnIndices => { - if column_indices__.is_some() { - return Err(serde::de::Error::duplicate_field("columnIndices")); - } - column_indices__ = Some(map_.next_value()?); - } - GeneratedField::Schema => { - if schema__.is_some() { - return Err(serde::de::Error::duplicate_field("schema")); - } - schema__ = map_.next_value()?; - } - } - } - Ok(JoinFilter { - expression: expression__, - column_indices: column_indices__.unwrap_or_default(), - schema: schema__, - }) - } - } - deserializer.deserialize_struct("datafusion.JoinFilter", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for JoinNode { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.left.is_some() { - len += 1; - } - if self.right.is_some() { - len += 1; - } - if self.join_type != 0 { - len += 1; - } - if self.join_constraint != 0 { - len += 1; - } - if !self.left_join_key.is_empty() { - len += 1; - } - if !self.right_join_key.is_empty() { - len += 1; - } - if self.null_equals_null { - len += 1; - } - if self.filter.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.JoinNode", len)?; - if let Some(v) = self.left.as_ref() { - struct_ser.serialize_field("left", v)?; - } - if let Some(v) = self.right.as_ref() { - struct_ser.serialize_field("right", v)?; - } - if self.join_type != 0 { - let v = JoinType::try_from(self.join_type) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_type)))?; - struct_ser.serialize_field("joinType", &v)?; - } - if self.join_constraint != 0 { - let v = JoinConstraint::try_from(self.join_constraint) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_constraint)))?; - struct_ser.serialize_field("joinConstraint", &v)?; - } - if !self.left_join_key.is_empty() { - struct_ser.serialize_field("leftJoinKey", &self.left_join_key)?; - } - if !self.right_join_key.is_empty() { - struct_ser.serialize_field("rightJoinKey", &self.right_join_key)?; - } - if self.null_equals_null { - struct_ser.serialize_field("nullEqualsNull", &self.null_equals_null)?; - } - if let Some(v) = self.filter.as_ref() { - struct_ser.serialize_field("filter", v)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for JoinNode { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "left", - "right", - "join_type", - "joinType", - "join_constraint", - "joinConstraint", - "left_join_key", - "leftJoinKey", - "right_join_key", - "rightJoinKey", - "null_equals_null", - "nullEqualsNull", - "filter", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Left, - Right, - JoinType, - JoinConstraint, - LeftJoinKey, - RightJoinKey, - NullEqualsNull, - Filter, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "left" => Ok(GeneratedField::Left), - "right" => Ok(GeneratedField::Right), - "joinType" | "join_type" => Ok(GeneratedField::JoinType), - "joinConstraint" | "join_constraint" => Ok(GeneratedField::JoinConstraint), - "leftJoinKey" | "left_join_key" => Ok(GeneratedField::LeftJoinKey), - "rightJoinKey" | "right_join_key" => Ok(GeneratedField::RightJoinKey), - "nullEqualsNull" | "null_equals_null" => Ok(GeneratedField::NullEqualsNull), - "filter" => Ok(GeneratedField::Filter), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = JoinNode; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.JoinNode") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut left__ = None; - let mut right__ = None; - let mut join_type__ = None; - let mut join_constraint__ = None; - let mut left_join_key__ = None; - let mut right_join_key__ = None; - let mut null_equals_null__ = None; - let mut filter__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Left => { - if left__.is_some() { - return Err(serde::de::Error::duplicate_field("left")); - } - left__ = map_.next_value()?; - } - GeneratedField::Right => { - if right__.is_some() { - return Err(serde::de::Error::duplicate_field("right")); - } - right__ = map_.next_value()?; - } - GeneratedField::JoinType => { - if join_type__.is_some() { - return Err(serde::de::Error::duplicate_field("joinType")); - } - join_type__ = Some(map_.next_value::()? as i32); - } - GeneratedField::JoinConstraint => { - if join_constraint__.is_some() { - return Err(serde::de::Error::duplicate_field("joinConstraint")); - } - join_constraint__ = Some(map_.next_value::()? as i32); - } - GeneratedField::LeftJoinKey => { - if left_join_key__.is_some() { - return Err(serde::de::Error::duplicate_field("leftJoinKey")); - } - left_join_key__ = Some(map_.next_value()?); - } - GeneratedField::RightJoinKey => { - if right_join_key__.is_some() { - return Err(serde::de::Error::duplicate_field("rightJoinKey")); - } - right_join_key__ = Some(map_.next_value()?); - } - GeneratedField::NullEqualsNull => { - if null_equals_null__.is_some() { - return Err(serde::de::Error::duplicate_field("nullEqualsNull")); - } - null_equals_null__ = Some(map_.next_value()?); - } - GeneratedField::Filter => { - if filter__.is_some() { - return Err(serde::de::Error::duplicate_field("filter")); - } - filter__ = map_.next_value()?; - } - } - } - Ok(JoinNode { - left: left__, - right: right__, - join_type: join_type__.unwrap_or_default(), - join_constraint: join_constraint__.unwrap_or_default(), - left_join_key: left_join_key__.unwrap_or_default(), - right_join_key: right_join_key__.unwrap_or_default(), - null_equals_null: null_equals_null__.unwrap_or_default(), - filter: filter__, - }) - } - } - deserializer.deserialize_struct("datafusion.JoinNode", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for JoinOn { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.left.is_some() { - len += 1; - } - if self.right.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.JoinOn", len)?; - if let Some(v) = self.left.as_ref() { - struct_ser.serialize_field("left", v)?; - } - if let Some(v) = self.right.as_ref() { - struct_ser.serialize_field("right", v)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for JoinOn { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "left", - "right", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Left, - Right, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "left" => Ok(GeneratedField::Left), - "right" => Ok(GeneratedField::Right), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = JoinOn; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.JoinOn") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut left__ = None; - let mut right__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Left => { - if left__.is_some() { - return Err(serde::de::Error::duplicate_field("left")); - } - left__ = map_.next_value()?; - } - GeneratedField::Right => { - if right__.is_some() { - return Err(serde::de::Error::duplicate_field("right")); - } - right__ = map_.next_value()?; - } - } - } - Ok(JoinOn { - left: left__, - right: right__, - }) - } - } - deserializer.deserialize_struct("datafusion.JoinOn", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for JoinSide { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - let variant = match self { - Self::LeftSide => "LEFT_SIDE", - Self::RightSide => "RIGHT_SIDE", - }; - serializer.serialize_str(variant) - } -} -impl<'de> serde::Deserialize<'de> for JoinSide { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "LEFT_SIDE", - "RIGHT_SIDE", - ]; - - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = JoinSide; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - fn visit_i64(self, v: i64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) - }) - } - - fn visit_u64(self, v: u64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) - }) - } - - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "LEFT_SIDE" => Ok(JoinSide::LeftSide), - "RIGHT_SIDE" => Ok(JoinSide::RightSide), - _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), - } - } - } - deserializer.deserialize_any(GeneratedVisitor) - } -} -impl serde::Serialize for JoinType { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - let variant = match self { - Self::Inner => "INNER", - Self::Left => "LEFT", - Self::Right => "RIGHT", - Self::Full => "FULL", - Self::Leftsemi => "LEFTSEMI", - Self::Leftanti => "LEFTANTI", - Self::Rightsemi => "RIGHTSEMI", - Self::Rightanti => "RIGHTANTI", - }; - serializer.serialize_str(variant) - } -} -impl<'de> serde::Deserialize<'de> for JoinType { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "INNER", - "LEFT", - "RIGHT", - "FULL", - "LEFTSEMI", - "LEFTANTI", - "RIGHTSEMI", - "RIGHTANTI", - ]; - - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = JoinType; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - fn visit_i64(self, v: i64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) - }) - } - - fn visit_u64(self, v: u64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) - }) - } - - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "INNER" => Ok(JoinType::Inner), - "LEFT" => Ok(JoinType::Left), - "RIGHT" => Ok(JoinType::Right), - "FULL" => Ok(JoinType::Full), - "LEFTSEMI" => Ok(JoinType::Leftsemi), - "LEFTANTI" => Ok(JoinType::Leftanti), - "RIGHTSEMI" => Ok(JoinType::Rightsemi), - "RIGHTANTI" => Ok(JoinType::Rightanti), - _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), - } - } - } - deserializer.deserialize_any(GeneratedVisitor) - } -} -impl serde::Serialize for JsonOptions { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.compression != 0 { - len += 1; - } - if self.schema_infer_max_rec != 0 { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.JsonOptions", len)?; - if self.compression != 0 { - let v = CompressionTypeVariant::try_from(self.compression) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.compression)))?; - struct_ser.serialize_field("compression", &v)?; - } - if self.schema_infer_max_rec != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("schemaInferMaxRec", ToString::to_string(&self.schema_infer_max_rec).as_str())?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for JsonOptions { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "compression", - "schema_infer_max_rec", - "schemaInferMaxRec", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Compression, - SchemaInferMaxRec, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "compression" => Ok(GeneratedField::Compression), - "schemaInferMaxRec" | "schema_infer_max_rec" => Ok(GeneratedField::SchemaInferMaxRec), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = JsonOptions; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.JsonOptions") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut compression__ = None; - let mut schema_infer_max_rec__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Compression => { - if compression__.is_some() { - return Err(serde::de::Error::duplicate_field("compression")); - } - compression__ = Some(map_.next_value::()? as i32); - } - GeneratedField::SchemaInferMaxRec => { - if schema_infer_max_rec__.is_some() { - return Err(serde::de::Error::duplicate_field("schemaInferMaxRec")); - } - schema_infer_max_rec__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - } - } - Ok(JsonOptions { - compression: compression__.unwrap_or_default(), - schema_infer_max_rec: schema_infer_max_rec__.unwrap_or_default(), - }) - } - } - deserializer.deserialize_struct("datafusion.JsonOptions", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for JsonSink { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.config.is_some() { - len += 1; - } - if self.writer_options.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.JsonSink", len)?; - if let Some(v) = self.config.as_ref() { - struct_ser.serialize_field("config", v)?; - } - if let Some(v) = self.writer_options.as_ref() { - struct_ser.serialize_field("writerOptions", v)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for JsonSink { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "config", - "writer_options", - "writerOptions", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Config, - WriterOptions, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "config" => Ok(GeneratedField::Config), - "writerOptions" | "writer_options" => Ok(GeneratedField::WriterOptions), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = JsonSink; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.JsonSink") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut config__ = None; - let mut writer_options__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Config => { - if config__.is_some() { - return Err(serde::de::Error::duplicate_field("config")); - } - config__ = map_.next_value()?; - } - GeneratedField::WriterOptions => { - if writer_options__.is_some() { - return Err(serde::de::Error::duplicate_field("writerOptions")); - } - writer_options__ = map_.next_value()?; - } - } - } - Ok(JsonSink { - config: config__, - writer_options: writer_options__, - }) - } - } - deserializer.deserialize_struct("datafusion.JsonSink", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for JsonSinkExecNode { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.input.is_some() { - len += 1; - } - if self.sink.is_some() { - len += 1; - } - if self.sink_schema.is_some() { - len += 1; - } - if self.sort_order.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.JsonSinkExecNode", len)?; - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; - } - if let Some(v) = self.sink.as_ref() { - struct_ser.serialize_field("sink", v)?; - } - if let Some(v) = self.sink_schema.as_ref() { - struct_ser.serialize_field("sinkSchema", v)?; - } - if let Some(v) = self.sort_order.as_ref() { - struct_ser.serialize_field("sortOrder", v)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for JsonSinkExecNode { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "input", - "sink", - "sink_schema", - "sinkSchema", - "sort_order", - "sortOrder", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Input, - Sink, - SinkSchema, - SortOrder, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "input" => Ok(GeneratedField::Input), - "sink" => Ok(GeneratedField::Sink), - "sinkSchema" | "sink_schema" => Ok(GeneratedField::SinkSchema), - "sortOrder" | "sort_order" => Ok(GeneratedField::SortOrder), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = JsonSinkExecNode; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.JsonSinkExecNode") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut input__ = None; - let mut sink__ = None; - let mut sink_schema__ = None; - let mut sort_order__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); - } - input__ = map_.next_value()?; - } - GeneratedField::Sink => { - if sink__.is_some() { - return Err(serde::de::Error::duplicate_field("sink")); - } - sink__ = map_.next_value()?; - } - GeneratedField::SinkSchema => { - if sink_schema__.is_some() { - return Err(serde::de::Error::duplicate_field("sinkSchema")); - } - sink_schema__ = map_.next_value()?; - } - GeneratedField::SortOrder => { - if sort_order__.is_some() { - return Err(serde::de::Error::duplicate_field("sortOrder")); - } - sort_order__ = map_.next_value()?; - } - } - } - Ok(JsonSinkExecNode { - input: input__, - sink: sink__, - sink_schema: sink_schema__, - sort_order: sort_order__, - }) - } - } - deserializer.deserialize_struct("datafusion.JsonSinkExecNode", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for JsonWriterOptions { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.compression != 0 { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.JsonWriterOptions", len)?; - if self.compression != 0 { - let v = CompressionTypeVariant::try_from(self.compression) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.compression)))?; - struct_ser.serialize_field("compression", &v)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for JsonWriterOptions { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "compression", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Compression, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "compression" => Ok(GeneratedField::Compression), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = JsonWriterOptions; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.JsonWriterOptions") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut compression__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Compression => { - if compression__.is_some() { - return Err(serde::de::Error::duplicate_field("compression")); - } - compression__ = Some(map_.next_value::()? as i32); - } - } - } - Ok(JsonWriterOptions { - compression: compression__.unwrap_or_default(), - }) - } - } - deserializer.deserialize_struct("datafusion.JsonWriterOptions", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for LikeNode { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.negated { - len += 1; - } - if self.expr.is_some() { - len += 1; - } - if self.pattern.is_some() { - len += 1; - } - if !self.escape_char.is_empty() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.LikeNode", len)?; - if self.negated { - struct_ser.serialize_field("negated", &self.negated)?; - } - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; - } - if let Some(v) = self.pattern.as_ref() { - struct_ser.serialize_field("pattern", v)?; - } - if !self.escape_char.is_empty() { - struct_ser.serialize_field("escapeChar", &self.escape_char)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for LikeNode { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "negated", - "expr", - "pattern", - "escape_char", - "escapeChar", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Negated, - Expr, - Pattern, - EscapeChar, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "negated" => Ok(GeneratedField::Negated), - "expr" => Ok(GeneratedField::Expr), - "pattern" => Ok(GeneratedField::Pattern), - "escapeChar" | "escape_char" => Ok(GeneratedField::EscapeChar), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = LikeNode; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.LikeNode") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut negated__ = None; - let mut expr__ = None; - let mut pattern__ = None; - let mut escape_char__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Negated => { - if negated__.is_some() { - return Err(serde::de::Error::duplicate_field("negated")); - } - negated__ = Some(map_.next_value()?); - } - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); - } - expr__ = map_.next_value()?; - } - GeneratedField::Pattern => { - if pattern__.is_some() { - return Err(serde::de::Error::duplicate_field("pattern")); - } - pattern__ = map_.next_value()?; - } - GeneratedField::EscapeChar => { - if escape_char__.is_some() { - return Err(serde::de::Error::duplicate_field("escapeChar")); - } - escape_char__ = Some(map_.next_value()?); - } - } - } - Ok(LikeNode { - negated: negated__.unwrap_or_default(), - expr: expr__, - pattern: pattern__, - escape_char: escape_char__.unwrap_or_default(), - }) - } - } - deserializer.deserialize_struct("datafusion.LikeNode", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for LimitNode { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.input.is_some() { - len += 1; - } - if self.skip != 0 { - len += 1; - } - if self.fetch != 0 { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.LimitNode", len)?; - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; - } - if self.skip != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("skip", ToString::to_string(&self.skip).as_str())?; - } - if self.fetch != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("fetch", ToString::to_string(&self.fetch).as_str())?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for LimitNode { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "input", - "skip", - "fetch", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Input, - Skip, - Fetch, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "input" => Ok(GeneratedField::Input), - "skip" => Ok(GeneratedField::Skip), - "fetch" => Ok(GeneratedField::Fetch), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = LimitNode; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.LimitNode") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut input__ = None; - let mut skip__ = None; - let mut fetch__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); - } - input__ = map_.next_value()?; - } - GeneratedField::Skip => { - if skip__.is_some() { - return Err(serde::de::Error::duplicate_field("skip")); - } - skip__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::Fetch => { - if fetch__.is_some() { - return Err(serde::de::Error::duplicate_field("fetch")); - } - fetch__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - } - } - Ok(LimitNode { - input: input__, - skip: skip__.unwrap_or_default(), - fetch: fetch__.unwrap_or_default(), - }) - } - } - deserializer.deserialize_struct("datafusion.LimitNode", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for List { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.field_type.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.List", len)?; - if let Some(v) = self.field_type.as_ref() { - struct_ser.serialize_field("fieldType", v)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for List { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "field_type", - "fieldType", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - FieldType, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "fieldType" | "field_type" => Ok(GeneratedField::FieldType), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = List; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.List") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut field_type__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::FieldType => { - if field_type__.is_some() { - return Err(serde::de::Error::duplicate_field("fieldType")); - } - field_type__ = map_.next_value()?; - } - } - } - Ok(List { - field_type: field_type__, - }) - } - } - deserializer.deserialize_struct("datafusion.List", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for ListIndex { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.key.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.ListIndex", len)?; - if let Some(v) = self.key.as_ref() { - struct_ser.serialize_field("key", v)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for ListIndex { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "key", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Key, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "key" => Ok(GeneratedField::Key), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ListIndex; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ListIndex") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut key__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Key => { - if key__.is_some() { - return Err(serde::de::Error::duplicate_field("key")); - } - key__ = map_.next_value()?; - } - } - } - Ok(ListIndex { - key: key__, - }) - } - } - deserializer.deserialize_struct("datafusion.ListIndex", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for ListRange { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.start.is_some() { - len += 1; - } - if self.stop.is_some() { - len += 1; - } - if self.stride.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.ListRange", len)?; - if let Some(v) = self.start.as_ref() { - struct_ser.serialize_field("start", v)?; - } - if let Some(v) = self.stop.as_ref() { - struct_ser.serialize_field("stop", v)?; - } - if let Some(v) = self.stride.as_ref() { - struct_ser.serialize_field("stride", v)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for ListRange { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "start", - "stop", - "stride", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Start, - Stop, - Stride, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "start" => Ok(GeneratedField::Start), - "stop" => Ok(GeneratedField::Stop), - "stride" => Ok(GeneratedField::Stride), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ListRange; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ListRange") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut start__ = None; - let mut stop__ = None; - let mut stride__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Start => { - if start__.is_some() { - return Err(serde::de::Error::duplicate_field("start")); - } - start__ = map_.next_value()?; - } - GeneratedField::Stop => { - if stop__.is_some() { - return Err(serde::de::Error::duplicate_field("stop")); - } - stop__ = map_.next_value()?; - } - GeneratedField::Stride => { - if stride__.is_some() { - return Err(serde::de::Error::duplicate_field("stride")); - } - stride__ = map_.next_value()?; - } - } - } - Ok(ListRange { - start: start__, - stop: stop__, - stride: stride__, - }) - } - } - deserializer.deserialize_struct("datafusion.ListRange", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for ListingTableScanNode { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.table_name.is_some() { - len += 1; - } - if !self.paths.is_empty() { - len += 1; - } - if !self.file_extension.is_empty() { - len += 1; - } - if self.projection.is_some() { - len += 1; - } - if self.schema.is_some() { - len += 1; - } - if !self.filters.is_empty() { - len += 1; - } - if !self.table_partition_cols.is_empty() { - len += 1; - } - if self.collect_stat { - len += 1; - } - if self.target_partitions != 0 { - len += 1; - } - if !self.file_sort_order.is_empty() { - len += 1; - } - if self.file_format_type.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.ListingTableScanNode", len)?; - if let Some(v) = self.table_name.as_ref() { - struct_ser.serialize_field("tableName", v)?; - } - if !self.paths.is_empty() { - struct_ser.serialize_field("paths", &self.paths)?; - } - if !self.file_extension.is_empty() { - struct_ser.serialize_field("fileExtension", &self.file_extension)?; - } - if let Some(v) = self.projection.as_ref() { - struct_ser.serialize_field("projection", v)?; - } - if let Some(v) = self.schema.as_ref() { - struct_ser.serialize_field("schema", v)?; - } - if !self.filters.is_empty() { - struct_ser.serialize_field("filters", &self.filters)?; - } - if !self.table_partition_cols.is_empty() { - struct_ser.serialize_field("tablePartitionCols", &self.table_partition_cols)?; - } - if self.collect_stat { - struct_ser.serialize_field("collectStat", &self.collect_stat)?; - } - if self.target_partitions != 0 { - struct_ser.serialize_field("targetPartitions", &self.target_partitions)?; - } - if !self.file_sort_order.is_empty() { - struct_ser.serialize_field("fileSortOrder", &self.file_sort_order)?; - } - if let Some(v) = self.file_format_type.as_ref() { - match v { - listing_table_scan_node::FileFormatType::Csv(v) => { - struct_ser.serialize_field("csv", v)?; - } - listing_table_scan_node::FileFormatType::Parquet(v) => { - struct_ser.serialize_field("parquet", v)?; - } - listing_table_scan_node::FileFormatType::Avro(v) => { - struct_ser.serialize_field("avro", v)?; - } - } - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for ListingTableScanNode { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "table_name", - "tableName", - "paths", - "file_extension", - "fileExtension", - "projection", - "schema", - "filters", - "table_partition_cols", - "tablePartitionCols", - "collect_stat", - "collectStat", - "target_partitions", - "targetPartitions", - "file_sort_order", - "fileSortOrder", - "csv", - "parquet", - "avro", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - TableName, - Paths, - FileExtension, - Projection, - Schema, - Filters, - TablePartitionCols, - CollectStat, - TargetPartitions, - FileSortOrder, - Csv, - Parquet, - Avro, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "tableName" | "table_name" => Ok(GeneratedField::TableName), - "paths" => Ok(GeneratedField::Paths), - "fileExtension" | "file_extension" => Ok(GeneratedField::FileExtension), - "projection" => Ok(GeneratedField::Projection), - "schema" => Ok(GeneratedField::Schema), - "filters" => Ok(GeneratedField::Filters), - "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), - "collectStat" | "collect_stat" => Ok(GeneratedField::CollectStat), - "targetPartitions" | "target_partitions" => Ok(GeneratedField::TargetPartitions), - "fileSortOrder" | "file_sort_order" => Ok(GeneratedField::FileSortOrder), - "csv" => Ok(GeneratedField::Csv), - "parquet" => Ok(GeneratedField::Parquet), - "avro" => Ok(GeneratedField::Avro), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ListingTableScanNode; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ListingTableScanNode") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut table_name__ = None; - let mut paths__ = None; - let mut file_extension__ = None; - let mut projection__ = None; - let mut schema__ = None; - let mut filters__ = None; - let mut table_partition_cols__ = None; - let mut collect_stat__ = None; - let mut target_partitions__ = None; - let mut file_sort_order__ = None; - let mut file_format_type__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::TableName => { - if table_name__.is_some() { - return Err(serde::de::Error::duplicate_field("tableName")); - } - table_name__ = map_.next_value()?; - } - GeneratedField::Paths => { - if paths__.is_some() { - return Err(serde::de::Error::duplicate_field("paths")); - } - paths__ = Some(map_.next_value()?); - } - GeneratedField::FileExtension => { - if file_extension__.is_some() { - return Err(serde::de::Error::duplicate_field("fileExtension")); - } - file_extension__ = Some(map_.next_value()?); - } - GeneratedField::Projection => { - if projection__.is_some() { - return Err(serde::de::Error::duplicate_field("projection")); - } - projection__ = map_.next_value()?; - } - GeneratedField::Schema => { - if schema__.is_some() { - return Err(serde::de::Error::duplicate_field("schema")); - } - schema__ = map_.next_value()?; - } - GeneratedField::Filters => { - if filters__.is_some() { - return Err(serde::de::Error::duplicate_field("filters")); - } - filters__ = Some(map_.next_value()?); - } - GeneratedField::TablePartitionCols => { - if table_partition_cols__.is_some() { - return Err(serde::de::Error::duplicate_field("tablePartitionCols")); - } - table_partition_cols__ = Some(map_.next_value()?); - } - GeneratedField::CollectStat => { - if collect_stat__.is_some() { - return Err(serde::de::Error::duplicate_field("collectStat")); - } - collect_stat__ = Some(map_.next_value()?); - } - GeneratedField::TargetPartitions => { - if target_partitions__.is_some() { - return Err(serde::de::Error::duplicate_field("targetPartitions")); - } - target_partitions__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::FileSortOrder => { - if file_sort_order__.is_some() { - return Err(serde::de::Error::duplicate_field("fileSortOrder")); - } - file_sort_order__ = Some(map_.next_value()?); - } - GeneratedField::Csv => { - if file_format_type__.is_some() { - return Err(serde::de::Error::duplicate_field("csv")); - } - file_format_type__ = map_.next_value::<::std::option::Option<_>>()?.map(listing_table_scan_node::FileFormatType::Csv) -; - } - GeneratedField::Parquet => { - if file_format_type__.is_some() { - return Err(serde::de::Error::duplicate_field("parquet")); - } - file_format_type__ = map_.next_value::<::std::option::Option<_>>()?.map(listing_table_scan_node::FileFormatType::Parquet) -; - } - GeneratedField::Avro => { - if file_format_type__.is_some() { - return Err(serde::de::Error::duplicate_field("avro")); - } - file_format_type__ = map_.next_value::<::std::option::Option<_>>()?.map(listing_table_scan_node::FileFormatType::Avro) -; - } - } - } - Ok(ListingTableScanNode { - table_name: table_name__, - paths: paths__.unwrap_or_default(), - file_extension: file_extension__.unwrap_or_default(), - projection: projection__, - schema: schema__, - filters: filters__.unwrap_or_default(), - table_partition_cols: table_partition_cols__.unwrap_or_default(), - collect_stat: collect_stat__.unwrap_or_default(), - target_partitions: target_partitions__.unwrap_or_default(), - file_sort_order: file_sort_order__.unwrap_or_default(), - file_format_type: file_format_type__, - }) - } - } - deserializer.deserialize_struct("datafusion.ListingTableScanNode", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for LocalLimitExecNode { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.input.is_some() { - len += 1; - } - if self.fetch != 0 { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.LocalLimitExecNode", len)?; - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; - } - if self.fetch != 0 { - struct_ser.serialize_field("fetch", &self.fetch)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for LocalLimitExecNode { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "input", - "fetch", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Input, - Fetch, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "input" => Ok(GeneratedField::Input), - "fetch" => Ok(GeneratedField::Fetch), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = LocalLimitExecNode; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.LocalLimitExecNode") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut input__ = None; - let mut fetch__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); - } - input__ = map_.next_value()?; - } - GeneratedField::Fetch => { - if fetch__.is_some() { - return Err(serde::de::Error::duplicate_field("fetch")); - } - fetch__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - } - } - Ok(LocalLimitExecNode { - input: input__, - fetch: fetch__.unwrap_or_default(), - }) - } - } - deserializer.deserialize_struct("datafusion.LocalLimitExecNode", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for LogicalExprList { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if !self.expr.is_empty() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.LogicalExprList", len)?; - if !self.expr.is_empty() { - struct_ser.serialize_field("expr", &self.expr)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for LogicalExprList { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "expr", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Expr, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "expr" => Ok(GeneratedField::Expr), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = LogicalExprList; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.LogicalExprList") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut expr__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); - } - expr__ = Some(map_.next_value()?); - } - } - } - Ok(LogicalExprList { - expr: expr__.unwrap_or_default(), - }) - } - } - deserializer.deserialize_struct("datafusion.LogicalExprList", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for LogicalExprNode { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.expr_type.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.LogicalExprNode", len)?; - if let Some(v) = self.expr_type.as_ref() { - match v { - logical_expr_node::ExprType::Column(v) => { - struct_ser.serialize_field("column", v)?; - } - logical_expr_node::ExprType::Alias(v) => { - struct_ser.serialize_field("alias", v)?; - } - logical_expr_node::ExprType::Literal(v) => { - struct_ser.serialize_field("literal", v)?; - } - logical_expr_node::ExprType::BinaryExpr(v) => { - struct_ser.serialize_field("binaryExpr", v)?; - } - logical_expr_node::ExprType::AggregateExpr(v) => { - struct_ser.serialize_field("aggregateExpr", v)?; - } - logical_expr_node::ExprType::IsNullExpr(v) => { - struct_ser.serialize_field("isNullExpr", v)?; - } - logical_expr_node::ExprType::IsNotNullExpr(v) => { - struct_ser.serialize_field("isNotNullExpr", v)?; - } - logical_expr_node::ExprType::NotExpr(v) => { - struct_ser.serialize_field("notExpr", v)?; - } - logical_expr_node::ExprType::Between(v) => { - struct_ser.serialize_field("between", v)?; - } - logical_expr_node::ExprType::Case(v) => { - struct_ser.serialize_field("case", v)?; - } - logical_expr_node::ExprType::Cast(v) => { - struct_ser.serialize_field("cast", v)?; - } - logical_expr_node::ExprType::Sort(v) => { - struct_ser.serialize_field("sort", v)?; - } - logical_expr_node::ExprType::Negative(v) => { - struct_ser.serialize_field("negative", v)?; - } - logical_expr_node::ExprType::InList(v) => { - struct_ser.serialize_field("inList", v)?; - } - logical_expr_node::ExprType::Wildcard(v) => { - struct_ser.serialize_field("wildcard", v)?; - } - logical_expr_node::ExprType::TryCast(v) => { - struct_ser.serialize_field("tryCast", v)?; - } - logical_expr_node::ExprType::WindowExpr(v) => { - struct_ser.serialize_field("windowExpr", v)?; - } - logical_expr_node::ExprType::AggregateUdfExpr(v) => { - struct_ser.serialize_field("aggregateUdfExpr", v)?; - } - logical_expr_node::ExprType::ScalarUdfExpr(v) => { - struct_ser.serialize_field("scalarUdfExpr", v)?; - } - logical_expr_node::ExprType::GetIndexedField(v) => { - struct_ser.serialize_field("getIndexedField", v)?; - } - logical_expr_node::ExprType::GroupingSet(v) => { - struct_ser.serialize_field("groupingSet", v)?; - } - logical_expr_node::ExprType::Cube(v) => { - struct_ser.serialize_field("cube", v)?; - } - logical_expr_node::ExprType::Rollup(v) => { - struct_ser.serialize_field("rollup", v)?; - } - logical_expr_node::ExprType::IsTrue(v) => { - struct_ser.serialize_field("isTrue", v)?; - } - logical_expr_node::ExprType::IsFalse(v) => { - struct_ser.serialize_field("isFalse", v)?; - } - logical_expr_node::ExprType::IsUnknown(v) => { - struct_ser.serialize_field("isUnknown", v)?; - } - logical_expr_node::ExprType::IsNotTrue(v) => { - struct_ser.serialize_field("isNotTrue", v)?; - } - logical_expr_node::ExprType::IsNotFalse(v) => { - struct_ser.serialize_field("isNotFalse", v)?; - } - logical_expr_node::ExprType::IsNotUnknown(v) => { - struct_ser.serialize_field("isNotUnknown", v)?; - } - logical_expr_node::ExprType::Like(v) => { - struct_ser.serialize_field("like", v)?; - } - logical_expr_node::ExprType::Ilike(v) => { - struct_ser.serialize_field("ilike", v)?; - } - logical_expr_node::ExprType::SimilarTo(v) => { - struct_ser.serialize_field("similarTo", v)?; - } - logical_expr_node::ExprType::Placeholder(v) => { - struct_ser.serialize_field("placeholder", v)?; - } - logical_expr_node::ExprType::Unnest(v) => { - struct_ser.serialize_field("unnest", v)?; - } - } - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for LogicalExprNode { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "column", - "alias", - "literal", - "binary_expr", - "binaryExpr", - "aggregate_expr", - "aggregateExpr", - "is_null_expr", - "isNullExpr", - "is_not_null_expr", - "isNotNullExpr", - "not_expr", - "notExpr", - "between", - "case_", - "case", - "cast", - "sort", - "negative", - "in_list", - "inList", - "wildcard", - "try_cast", - "tryCast", - "window_expr", - "windowExpr", - "aggregate_udf_expr", - "aggregateUdfExpr", - "scalar_udf_expr", - "scalarUdfExpr", - "get_indexed_field", - "getIndexedField", - "grouping_set", - "groupingSet", - "cube", - "rollup", - "is_true", - "isTrue", - "is_false", - "isFalse", - "is_unknown", - "isUnknown", - "is_not_true", - "isNotTrue", - "is_not_false", - "isNotFalse", - "is_not_unknown", - "isNotUnknown", - "like", - "ilike", - "similar_to", - "similarTo", - "placeholder", - "unnest", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Column, - Alias, - Literal, - BinaryExpr, - AggregateExpr, - IsNullExpr, - IsNotNullExpr, - NotExpr, - Between, - Case, - Cast, - Sort, - Negative, - InList, - Wildcard, - TryCast, - WindowExpr, - AggregateUdfExpr, - ScalarUdfExpr, - GetIndexedField, - GroupingSet, - Cube, - Rollup, - IsTrue, - IsFalse, - IsUnknown, - IsNotTrue, - IsNotFalse, - IsNotUnknown, - Like, - Ilike, - SimilarTo, - Placeholder, - Unnest, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "column" => Ok(GeneratedField::Column), - "alias" => Ok(GeneratedField::Alias), - "literal" => Ok(GeneratedField::Literal), - "binaryExpr" | "binary_expr" => Ok(GeneratedField::BinaryExpr), - "aggregateExpr" | "aggregate_expr" => Ok(GeneratedField::AggregateExpr), - "isNullExpr" | "is_null_expr" => Ok(GeneratedField::IsNullExpr), - "isNotNullExpr" | "is_not_null_expr" => Ok(GeneratedField::IsNotNullExpr), - "notExpr" | "not_expr" => Ok(GeneratedField::NotExpr), - "between" => Ok(GeneratedField::Between), - "case" | "case_" => Ok(GeneratedField::Case), - "cast" => Ok(GeneratedField::Cast), - "sort" => Ok(GeneratedField::Sort), - "negative" => Ok(GeneratedField::Negative), - "inList" | "in_list" => Ok(GeneratedField::InList), - "wildcard" => Ok(GeneratedField::Wildcard), - "tryCast" | "try_cast" => Ok(GeneratedField::TryCast), - "windowExpr" | "window_expr" => Ok(GeneratedField::WindowExpr), - "aggregateUdfExpr" | "aggregate_udf_expr" => Ok(GeneratedField::AggregateUdfExpr), - "scalarUdfExpr" | "scalar_udf_expr" => Ok(GeneratedField::ScalarUdfExpr), - "getIndexedField" | "get_indexed_field" => Ok(GeneratedField::GetIndexedField), - "groupingSet" | "grouping_set" => Ok(GeneratedField::GroupingSet), - "cube" => Ok(GeneratedField::Cube), - "rollup" => Ok(GeneratedField::Rollup), - "isTrue" | "is_true" => Ok(GeneratedField::IsTrue), - "isFalse" | "is_false" => Ok(GeneratedField::IsFalse), - "isUnknown" | "is_unknown" => Ok(GeneratedField::IsUnknown), - "isNotTrue" | "is_not_true" => Ok(GeneratedField::IsNotTrue), - "isNotFalse" | "is_not_false" => Ok(GeneratedField::IsNotFalse), - "isNotUnknown" | "is_not_unknown" => Ok(GeneratedField::IsNotUnknown), - "like" => Ok(GeneratedField::Like), - "ilike" => Ok(GeneratedField::Ilike), - "similarTo" | "similar_to" => Ok(GeneratedField::SimilarTo), - "placeholder" => Ok(GeneratedField::Placeholder), - "unnest" => Ok(GeneratedField::Unnest), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = LogicalExprNode; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.LogicalExprNode") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut expr_type__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Column => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("column")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Column) -; - } - GeneratedField::Alias => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("alias")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Alias) -; - } - GeneratedField::Literal => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("literal")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Literal) -; - } - GeneratedField::BinaryExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("binaryExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::BinaryExpr) -; - } - GeneratedField::AggregateExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("aggregateExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::AggregateExpr) -; - } - GeneratedField::IsNullExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("isNullExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNullExpr) -; - } - GeneratedField::IsNotNullExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("isNotNullExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNotNullExpr) -; - } - GeneratedField::NotExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("notExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::NotExpr) -; - } - GeneratedField::Between => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("between")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Between) -; - } - GeneratedField::Case => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("case")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Case) -; - } - GeneratedField::Cast => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("cast")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Cast) -; - } - GeneratedField::Sort => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("sort")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Sort) -; - } - GeneratedField::Negative => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("negative")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Negative) -; - } - GeneratedField::InList => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("inList")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::InList) -; - } - GeneratedField::Wildcard => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("wildcard")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Wildcard) -; - } - GeneratedField::TryCast => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("tryCast")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::TryCast) -; - } - GeneratedField::WindowExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("windowExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::WindowExpr) -; - } - GeneratedField::AggregateUdfExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("aggregateUdfExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::AggregateUdfExpr) -; - } - GeneratedField::ScalarUdfExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("scalarUdfExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::ScalarUdfExpr) -; - } - GeneratedField::GetIndexedField => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("getIndexedField")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::GetIndexedField) -; - } - GeneratedField::GroupingSet => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("groupingSet")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::GroupingSet) -; - } - GeneratedField::Cube => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("cube")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Cube) -; - } - GeneratedField::Rollup => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("rollup")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Rollup) -; - } - GeneratedField::IsTrue => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("isTrue")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsTrue) -; - } - GeneratedField::IsFalse => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("isFalse")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsFalse) -; - } - GeneratedField::IsUnknown => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("isUnknown")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsUnknown) -; - } - GeneratedField::IsNotTrue => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("isNotTrue")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNotTrue) -; - } - GeneratedField::IsNotFalse => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("isNotFalse")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNotFalse) -; - } - GeneratedField::IsNotUnknown => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("isNotUnknown")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNotUnknown) -; - } - GeneratedField::Like => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("like")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Like) -; - } - GeneratedField::Ilike => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("ilike")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Ilike) -; - } - GeneratedField::SimilarTo => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("similarTo")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::SimilarTo) -; - } - GeneratedField::Placeholder => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("placeholder")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Placeholder) -; - } - GeneratedField::Unnest => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("unnest")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Unnest) -; - } - } - } - Ok(LogicalExprNode { - expr_type: expr_type__, - }) - } - } - deserializer.deserialize_struct("datafusion.LogicalExprNode", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for LogicalExprNodeCollection { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if !self.logical_expr_nodes.is_empty() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.LogicalExprNodeCollection", len)?; - if !self.logical_expr_nodes.is_empty() { - struct_ser.serialize_field("logicalExprNodes", &self.logical_expr_nodes)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for LogicalExprNodeCollection { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "logical_expr_nodes", - "logicalExprNodes", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - LogicalExprNodes, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "logicalExprNodes" | "logical_expr_nodes" => Ok(GeneratedField::LogicalExprNodes), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = LogicalExprNodeCollection; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.LogicalExprNodeCollection") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut logical_expr_nodes__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::LogicalExprNodes => { - if logical_expr_nodes__.is_some() { - return Err(serde::de::Error::duplicate_field("logicalExprNodes")); - } - logical_expr_nodes__ = Some(map_.next_value()?); - } - } - } - Ok(LogicalExprNodeCollection { - logical_expr_nodes: logical_expr_nodes__.unwrap_or_default(), - }) - } - } - deserializer.deserialize_struct("datafusion.LogicalExprNodeCollection", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for LogicalExtensionNode { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if !self.node.is_empty() { - len += 1; - } - if !self.inputs.is_empty() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.LogicalExtensionNode", len)?; - if !self.node.is_empty() { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("node", pbjson::private::base64::encode(&self.node).as_str())?; - } - if !self.inputs.is_empty() { - struct_ser.serialize_field("inputs", &self.inputs)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for LogicalExtensionNode { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "node", - "inputs", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Node, - Inputs, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "node" => Ok(GeneratedField::Node), - "inputs" => Ok(GeneratedField::Inputs), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = LogicalExtensionNode; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.LogicalExtensionNode") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut node__ = None; - let mut inputs__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Node => { - if node__.is_some() { - return Err(serde::de::Error::duplicate_field("node")); - } - node__ = - Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) - ; - } - GeneratedField::Inputs => { - if inputs__.is_some() { - return Err(serde::de::Error::duplicate_field("inputs")); - } - inputs__ = Some(map_.next_value()?); - } - } - } - Ok(LogicalExtensionNode { - node: node__.unwrap_or_default(), - inputs: inputs__.unwrap_or_default(), - }) - } - } - deserializer.deserialize_struct("datafusion.LogicalExtensionNode", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for LogicalPlanNode { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.logical_plan_type.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.LogicalPlanNode", len)?; - if let Some(v) = self.logical_plan_type.as_ref() { - match v { - logical_plan_node::LogicalPlanType::ListingScan(v) => { - struct_ser.serialize_field("listingScan", v)?; - } - logical_plan_node::LogicalPlanType::Projection(v) => { - struct_ser.serialize_field("projection", v)?; - } - logical_plan_node::LogicalPlanType::Selection(v) => { - struct_ser.serialize_field("selection", v)?; - } - logical_plan_node::LogicalPlanType::Limit(v) => { - struct_ser.serialize_field("limit", v)?; - } - logical_plan_node::LogicalPlanType::Aggregate(v) => { - struct_ser.serialize_field("aggregate", v)?; - } - logical_plan_node::LogicalPlanType::Join(v) => { - struct_ser.serialize_field("join", v)?; - } - logical_plan_node::LogicalPlanType::Sort(v) => { - struct_ser.serialize_field("sort", v)?; - } - logical_plan_node::LogicalPlanType::Repartition(v) => { - struct_ser.serialize_field("repartition", v)?; - } - logical_plan_node::LogicalPlanType::EmptyRelation(v) => { - struct_ser.serialize_field("emptyRelation", v)?; - } - logical_plan_node::LogicalPlanType::CreateExternalTable(v) => { - struct_ser.serialize_field("createExternalTable", v)?; - } - logical_plan_node::LogicalPlanType::Explain(v) => { - struct_ser.serialize_field("explain", v)?; - } - logical_plan_node::LogicalPlanType::Window(v) => { - struct_ser.serialize_field("window", v)?; - } - logical_plan_node::LogicalPlanType::Analyze(v) => { - struct_ser.serialize_field("analyze", v)?; - } - logical_plan_node::LogicalPlanType::CrossJoin(v) => { - struct_ser.serialize_field("crossJoin", v)?; - } - logical_plan_node::LogicalPlanType::Values(v) => { - struct_ser.serialize_field("values", v)?; - } - logical_plan_node::LogicalPlanType::Extension(v) => { - struct_ser.serialize_field("extension", v)?; - } - logical_plan_node::LogicalPlanType::CreateCatalogSchema(v) => { - struct_ser.serialize_field("createCatalogSchema", v)?; - } - logical_plan_node::LogicalPlanType::Union(v) => { - struct_ser.serialize_field("union", v)?; - } - logical_plan_node::LogicalPlanType::CreateCatalog(v) => { - struct_ser.serialize_field("createCatalog", v)?; - } - logical_plan_node::LogicalPlanType::SubqueryAlias(v) => { - struct_ser.serialize_field("subqueryAlias", v)?; - } - logical_plan_node::LogicalPlanType::CreateView(v) => { - struct_ser.serialize_field("createView", v)?; - } - logical_plan_node::LogicalPlanType::Distinct(v) => { - struct_ser.serialize_field("distinct", v)?; - } - logical_plan_node::LogicalPlanType::ViewScan(v) => { - struct_ser.serialize_field("viewScan", v)?; - } - logical_plan_node::LogicalPlanType::CustomScan(v) => { - struct_ser.serialize_field("customScan", v)?; - } - logical_plan_node::LogicalPlanType::Prepare(v) => { - struct_ser.serialize_field("prepare", v)?; - } - logical_plan_node::LogicalPlanType::DropView(v) => { - struct_ser.serialize_field("dropView", v)?; - } - logical_plan_node::LogicalPlanType::DistinctOn(v) => { - struct_ser.serialize_field("distinctOn", v)?; - } - logical_plan_node::LogicalPlanType::CopyTo(v) => { - struct_ser.serialize_field("copyTo", v)?; - } - } - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for LogicalPlanNode { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "listing_scan", - "listingScan", - "projection", - "selection", - "limit", - "aggregate", - "join", - "sort", - "repartition", - "empty_relation", - "emptyRelation", - "create_external_table", - "createExternalTable", - "explain", - "window", - "analyze", - "cross_join", - "crossJoin", - "values", - "extension", - "create_catalog_schema", - "createCatalogSchema", - "union", - "create_catalog", - "createCatalog", - "subquery_alias", - "subqueryAlias", - "create_view", - "createView", - "distinct", - "view_scan", - "viewScan", - "custom_scan", - "customScan", - "prepare", - "drop_view", - "dropView", - "distinct_on", - "distinctOn", - "copy_to", - "copyTo", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - ListingScan, - Projection, - Selection, - Limit, - Aggregate, - Join, - Sort, - Repartition, - EmptyRelation, - CreateExternalTable, - Explain, - Window, - Analyze, - CrossJoin, - Values, - Extension, - CreateCatalogSchema, - Union, - CreateCatalog, - SubqueryAlias, - CreateView, - Distinct, - ViewScan, - CustomScan, - Prepare, - DropView, - DistinctOn, - CopyTo, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "listingScan" | "listing_scan" => Ok(GeneratedField::ListingScan), - "projection" => Ok(GeneratedField::Projection), - "selection" => Ok(GeneratedField::Selection), - "limit" => Ok(GeneratedField::Limit), - "aggregate" => Ok(GeneratedField::Aggregate), - "join" => Ok(GeneratedField::Join), - "sort" => Ok(GeneratedField::Sort), - "repartition" => Ok(GeneratedField::Repartition), - "emptyRelation" | "empty_relation" => Ok(GeneratedField::EmptyRelation), - "createExternalTable" | "create_external_table" => Ok(GeneratedField::CreateExternalTable), - "explain" => Ok(GeneratedField::Explain), - "window" => Ok(GeneratedField::Window), - "analyze" => Ok(GeneratedField::Analyze), - "crossJoin" | "cross_join" => Ok(GeneratedField::CrossJoin), - "values" => Ok(GeneratedField::Values), - "extension" => Ok(GeneratedField::Extension), - "createCatalogSchema" | "create_catalog_schema" => Ok(GeneratedField::CreateCatalogSchema), - "union" => Ok(GeneratedField::Union), - "createCatalog" | "create_catalog" => Ok(GeneratedField::CreateCatalog), - "subqueryAlias" | "subquery_alias" => Ok(GeneratedField::SubqueryAlias), - "createView" | "create_view" => Ok(GeneratedField::CreateView), - "distinct" => Ok(GeneratedField::Distinct), - "viewScan" | "view_scan" => Ok(GeneratedField::ViewScan), - "customScan" | "custom_scan" => Ok(GeneratedField::CustomScan), - "prepare" => Ok(GeneratedField::Prepare), - "dropView" | "drop_view" => Ok(GeneratedField::DropView), - "distinctOn" | "distinct_on" => Ok(GeneratedField::DistinctOn), - "copyTo" | "copy_to" => Ok(GeneratedField::CopyTo), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = LogicalPlanNode; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.LogicalPlanNode") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut logical_plan_type__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::ListingScan => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("listingScan")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::ListingScan) -; - } - GeneratedField::Projection => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("projection")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Projection) -; - } - GeneratedField::Selection => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("selection")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Selection) -; - } - GeneratedField::Limit => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("limit")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Limit) -; - } - GeneratedField::Aggregate => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("aggregate")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Aggregate) -; - } - GeneratedField::Join => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("join")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Join) -; - } - GeneratedField::Sort => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("sort")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Sort) -; - } - GeneratedField::Repartition => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("repartition")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Repartition) -; - } - GeneratedField::EmptyRelation => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("emptyRelation")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::EmptyRelation) -; - } - GeneratedField::CreateExternalTable => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("createExternalTable")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CreateExternalTable) -; - } - GeneratedField::Explain => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("explain")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Explain) -; - } - GeneratedField::Window => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("window")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Window) -; - } - GeneratedField::Analyze => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("analyze")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Analyze) -; - } - GeneratedField::CrossJoin => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("crossJoin")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CrossJoin) -; - } - GeneratedField::Values => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("values")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Values) -; - } - GeneratedField::Extension => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("extension")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Extension) -; - } - GeneratedField::CreateCatalogSchema => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("createCatalogSchema")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CreateCatalogSchema) -; - } - GeneratedField::Union => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("union")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Union) -; - } - GeneratedField::CreateCatalog => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("createCatalog")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CreateCatalog) -; - } - GeneratedField::SubqueryAlias => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("subqueryAlias")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::SubqueryAlias) -; - } - GeneratedField::CreateView => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("createView")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CreateView) -; - } - GeneratedField::Distinct => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("distinct")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Distinct) -; - } - GeneratedField::ViewScan => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("viewScan")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::ViewScan) -; - } - GeneratedField::CustomScan => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("customScan")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CustomScan) -; - } - GeneratedField::Prepare => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("prepare")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Prepare) -; - } - GeneratedField::DropView => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("dropView")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::DropView) -; - } - GeneratedField::DistinctOn => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("distinctOn")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::DistinctOn) -; - } - GeneratedField::CopyTo => { - if logical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("copyTo")); - } - logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CopyTo) -; - } - } - } - Ok(LogicalPlanNode { - logical_plan_type: logical_plan_type__, - }) - } - } - deserializer.deserialize_struct("datafusion.LogicalPlanNode", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for Map { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.field_type.is_some() { - len += 1; - } - if self.keys_sorted { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.Map", len)?; - if let Some(v) = self.field_type.as_ref() { - struct_ser.serialize_field("fieldType", v)?; - } - if self.keys_sorted { - struct_ser.serialize_field("keysSorted", &self.keys_sorted)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for Map { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "field_type", - "fieldType", - "keys_sorted", - "keysSorted", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - FieldType, - KeysSorted, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "fieldType" | "field_type" => Ok(GeneratedField::FieldType), - "keysSorted" | "keys_sorted" => Ok(GeneratedField::KeysSorted), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = Map; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.Map") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut field_type__ = None; - let mut keys_sorted__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::FieldType => { - if field_type__.is_some() { - return Err(serde::de::Error::duplicate_field("fieldType")); - } - field_type__ = map_.next_value()?; - } - GeneratedField::KeysSorted => { - if keys_sorted__.is_some() { - return Err(serde::de::Error::duplicate_field("keysSorted")); - } - keys_sorted__ = Some(map_.next_value()?); - } - } - } - Ok(Map { - field_type: field_type__, - keys_sorted: keys_sorted__.unwrap_or_default(), - }) - } - } - deserializer.deserialize_struct("datafusion.Map", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for MaybeFilter { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.expr.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.MaybeFilter", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for MaybeFilter { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "expr", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Expr, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "expr" => Ok(GeneratedField::Expr), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = MaybeFilter; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.MaybeFilter") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut expr__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); - } - expr__ = map_.next_value()?; - } - } - } - Ok(MaybeFilter { - expr: expr__, - }) - } - } - deserializer.deserialize_struct("datafusion.MaybeFilter", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for MaybePhysicalSortExprs { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if !self.sort_expr.is_empty() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.MaybePhysicalSortExprs", len)?; - if !self.sort_expr.is_empty() { - struct_ser.serialize_field("sortExpr", &self.sort_expr)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for MaybePhysicalSortExprs { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "sort_expr", - "sortExpr", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - SortExpr, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "sortExpr" | "sort_expr" => Ok(GeneratedField::SortExpr), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = MaybePhysicalSortExprs; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.MaybePhysicalSortExprs") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut sort_expr__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::SortExpr => { - if sort_expr__.is_some() { - return Err(serde::de::Error::duplicate_field("sortExpr")); - } - sort_expr__ = Some(map_.next_value()?); - } - } - } - Ok(MaybePhysicalSortExprs { - sort_expr: sort_expr__.unwrap_or_default(), - }) - } - } - deserializer.deserialize_struct("datafusion.MaybePhysicalSortExprs", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for NamedStructField { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.name.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.NamedStructField", len)?; - if let Some(v) = self.name.as_ref() { - struct_ser.serialize_field("name", v)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for NamedStructField { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "name", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Name, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "name" => Ok(GeneratedField::Name), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = NamedStructField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.NamedStructField") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut name__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Name => { - if name__.is_some() { - return Err(serde::de::Error::duplicate_field("name")); - } - name__ = map_.next_value()?; - } - } - } - Ok(NamedStructField { - name: name__, - }) - } - } - deserializer.deserialize_struct("datafusion.NamedStructField", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for NegativeNode { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.expr.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.NegativeNode", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for NegativeNode { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "expr", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Expr, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "expr" => Ok(GeneratedField::Expr), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = NegativeNode; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.NegativeNode") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut expr__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); - } - expr__ = map_.next_value()?; - } - } - } - Ok(NegativeNode { - expr: expr__, - }) - } - } - deserializer.deserialize_struct("datafusion.NegativeNode", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for NestedLoopJoinExecNode { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.left.is_some() { - len += 1; - } - if self.right.is_some() { - len += 1; - } - if self.join_type != 0 { - len += 1; - } - if self.filter.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.NestedLoopJoinExecNode", len)?; - if let Some(v) = self.left.as_ref() { - struct_ser.serialize_field("left", v)?; - } - if let Some(v) = self.right.as_ref() { - struct_ser.serialize_field("right", v)?; - } - if self.join_type != 0 { - let v = JoinType::try_from(self.join_type) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_type)))?; - struct_ser.serialize_field("joinType", &v)?; - } - if let Some(v) = self.filter.as_ref() { - struct_ser.serialize_field("filter", v)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for NestedLoopJoinExecNode { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "left", - "right", - "join_type", - "joinType", - "filter", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Left, - Right, - JoinType, - Filter, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "left" => Ok(GeneratedField::Left), - "right" => Ok(GeneratedField::Right), - "joinType" | "join_type" => Ok(GeneratedField::JoinType), - "filter" => Ok(GeneratedField::Filter), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = NestedLoopJoinExecNode; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.NestedLoopJoinExecNode") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut left__ = None; - let mut right__ = None; - let mut join_type__ = None; - let mut filter__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Left => { - if left__.is_some() { - return Err(serde::de::Error::duplicate_field("left")); - } - left__ = map_.next_value()?; - } - GeneratedField::Right => { - if right__.is_some() { - return Err(serde::de::Error::duplicate_field("right")); - } - right__ = map_.next_value()?; - } - GeneratedField::JoinType => { - if join_type__.is_some() { - return Err(serde::de::Error::duplicate_field("joinType")); - } - join_type__ = Some(map_.next_value::()? as i32); - } - GeneratedField::Filter => { - if filter__.is_some() { - return Err(serde::de::Error::duplicate_field("filter")); - } - filter__ = map_.next_value()?; - } - } - } - Ok(NestedLoopJoinExecNode { - left: left__, - right: right__, - join_type: join_type__.unwrap_or_default(), - filter: filter__, - }) - } - } - deserializer.deserialize_struct("datafusion.NestedLoopJoinExecNode", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for Not { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.expr.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.Not", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for Not { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "expr", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Expr, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "expr" => Ok(GeneratedField::Expr), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = Not; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.Not") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut expr__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); - } - expr__ = map_.next_value()?; - } - } - } - Ok(Not { - expr: expr__, - }) - } - } - deserializer.deserialize_struct("datafusion.Not", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for OptimizedLogicalPlanType { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if !self.optimizer_name.is_empty() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.OptimizedLogicalPlanType", len)?; - if !self.optimizer_name.is_empty() { - struct_ser.serialize_field("optimizerName", &self.optimizer_name)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for OptimizedLogicalPlanType { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "optimizer_name", - "optimizerName", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - OptimizerName, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "optimizerName" | "optimizer_name" => Ok(GeneratedField::OptimizerName), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = OptimizedLogicalPlanType; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.OptimizedLogicalPlanType") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut optimizer_name__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::OptimizerName => { - if optimizer_name__.is_some() { - return Err(serde::de::Error::duplicate_field("optimizerName")); - } - optimizer_name__ = Some(map_.next_value()?); - } - } - } - Ok(OptimizedLogicalPlanType { - optimizer_name: optimizer_name__.unwrap_or_default(), - }) - } - } - deserializer.deserialize_struct("datafusion.OptimizedLogicalPlanType", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for OptimizedPhysicalPlanType { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if !self.optimizer_name.is_empty() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.OptimizedPhysicalPlanType", len)?; - if !self.optimizer_name.is_empty() { - struct_ser.serialize_field("optimizerName", &self.optimizer_name)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for OptimizedPhysicalPlanType { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "optimizer_name", - "optimizerName", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - OptimizerName, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "optimizerName" | "optimizer_name" => Ok(GeneratedField::OptimizerName), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = OptimizedPhysicalPlanType; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.OptimizedPhysicalPlanType") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut optimizer_name__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::OptimizerName => { - if optimizer_name__.is_some() { - return Err(serde::de::Error::duplicate_field("optimizerName")); - } - optimizer_name__ = Some(map_.next_value()?); - } - } - } - Ok(OptimizedPhysicalPlanType { - optimizer_name: optimizer_name__.unwrap_or_default(), - }) - } - } - deserializer.deserialize_struct("datafusion.OptimizedPhysicalPlanType", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for ParquetFormat { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.options.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.ParquetFormat", len)?; - if let Some(v) = self.options.as_ref() { - struct_ser.serialize_field("options", v)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for ParquetFormat { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "options", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Options, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "options" => Ok(GeneratedField::Options), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ParquetFormat; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ParquetFormat") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut options__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Options => { - if options__.is_some() { - return Err(serde::de::Error::duplicate_field("options")); - } - options__ = map_.next_value()?; - } - } - } - Ok(ParquetFormat { - options: options__, - }) - } - } - deserializer.deserialize_struct("datafusion.ParquetFormat", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for ParquetOptions { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.enable_page_index { - len += 1; - } - if self.pruning { - len += 1; - } - if self.skip_metadata { - len += 1; - } - if self.pushdown_filters { - len += 1; - } - if self.reorder_filters { - len += 1; - } - if self.data_pagesize_limit != 0 { - len += 1; - } - if self.write_batch_size != 0 { - len += 1; - } - if !self.writer_version.is_empty() { - len += 1; - } - if self.bloom_filter_enabled { - len += 1; - } - if self.allow_single_file_parallelism { - len += 1; - } - if self.maximum_parallel_row_group_writers != 0 { - len += 1; - } - if self.maximum_buffered_record_batches_per_stream != 0 { - len += 1; - } - if self.dictionary_page_size_limit != 0 { - len += 1; - } - if self.data_page_row_count_limit != 0 { - len += 1; - } - if self.max_row_group_size != 0 { - len += 1; - } - if !self.created_by.is_empty() { - len += 1; - } - if self.metadata_size_hint_opt.is_some() { - len += 1; - } - if self.compression_opt.is_some() { - len += 1; - } - if self.dictionary_enabled_opt.is_some() { - len += 1; - } - if self.statistics_enabled_opt.is_some() { - len += 1; - } - if self.max_statistics_size_opt.is_some() { - len += 1; - } - if self.column_index_truncate_length_opt.is_some() { - len += 1; - } - if self.encoding_opt.is_some() { - len += 1; - } - if self.bloom_filter_fpp_opt.is_some() { - len += 1; - } - if self.bloom_filter_ndv_opt.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.ParquetOptions", len)?; - if self.enable_page_index { - struct_ser.serialize_field("enablePageIndex", &self.enable_page_index)?; - } - if self.pruning { - struct_ser.serialize_field("pruning", &self.pruning)?; - } - if self.skip_metadata { - struct_ser.serialize_field("skipMetadata", &self.skip_metadata)?; - } - if self.pushdown_filters { - struct_ser.serialize_field("pushdownFilters", &self.pushdown_filters)?; - } - if self.reorder_filters { - struct_ser.serialize_field("reorderFilters", &self.reorder_filters)?; - } - if self.data_pagesize_limit != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("dataPagesizeLimit", ToString::to_string(&self.data_pagesize_limit).as_str())?; - } - if self.write_batch_size != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("writeBatchSize", ToString::to_string(&self.write_batch_size).as_str())?; - } - if !self.writer_version.is_empty() { - struct_ser.serialize_field("writerVersion", &self.writer_version)?; - } - if self.bloom_filter_enabled { - struct_ser.serialize_field("bloomFilterEnabled", &self.bloom_filter_enabled)?; - } - if self.allow_single_file_parallelism { - struct_ser.serialize_field("allowSingleFileParallelism", &self.allow_single_file_parallelism)?; - } - if self.maximum_parallel_row_group_writers != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("maximumParallelRowGroupWriters", ToString::to_string(&self.maximum_parallel_row_group_writers).as_str())?; - } - if self.maximum_buffered_record_batches_per_stream != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("maximumBufferedRecordBatchesPerStream", ToString::to_string(&self.maximum_buffered_record_batches_per_stream).as_str())?; - } - if self.dictionary_page_size_limit != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("dictionaryPageSizeLimit", ToString::to_string(&self.dictionary_page_size_limit).as_str())?; - } - if self.data_page_row_count_limit != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("dataPageRowCountLimit", ToString::to_string(&self.data_page_row_count_limit).as_str())?; - } - if self.max_row_group_size != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("maxRowGroupSize", ToString::to_string(&self.max_row_group_size).as_str())?; - } - if !self.created_by.is_empty() { - struct_ser.serialize_field("createdBy", &self.created_by)?; - } - if let Some(v) = self.metadata_size_hint_opt.as_ref() { - match v { - parquet_options::MetadataSizeHintOpt::MetadataSizeHint(v) => { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("metadataSizeHint", ToString::to_string(&v).as_str())?; - } - } - } - if let Some(v) = self.compression_opt.as_ref() { - match v { - parquet_options::CompressionOpt::Compression(v) => { - struct_ser.serialize_field("compression", v)?; - } - } - } - if let Some(v) = self.dictionary_enabled_opt.as_ref() { - match v { - parquet_options::DictionaryEnabledOpt::DictionaryEnabled(v) => { - struct_ser.serialize_field("dictionaryEnabled", v)?; - } - } - } - if let Some(v) = self.statistics_enabled_opt.as_ref() { - match v { - parquet_options::StatisticsEnabledOpt::StatisticsEnabled(v) => { - struct_ser.serialize_field("statisticsEnabled", v)?; - } - } - } - if let Some(v) = self.max_statistics_size_opt.as_ref() { - match v { - parquet_options::MaxStatisticsSizeOpt::MaxStatisticsSize(v) => { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("maxStatisticsSize", ToString::to_string(&v).as_str())?; - } - } - } - if let Some(v) = self.column_index_truncate_length_opt.as_ref() { - match v { - parquet_options::ColumnIndexTruncateLengthOpt::ColumnIndexTruncateLength(v) => { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("columnIndexTruncateLength", ToString::to_string(&v).as_str())?; - } - } - } - if let Some(v) = self.encoding_opt.as_ref() { - match v { - parquet_options::EncodingOpt::Encoding(v) => { - struct_ser.serialize_field("encoding", v)?; - } - } - } - if let Some(v) = self.bloom_filter_fpp_opt.as_ref() { - match v { - parquet_options::BloomFilterFppOpt::BloomFilterFpp(v) => { - struct_ser.serialize_field("bloomFilterFpp", v)?; - } - } - } - if let Some(v) = self.bloom_filter_ndv_opt.as_ref() { - match v { - parquet_options::BloomFilterNdvOpt::BloomFilterNdv(v) => { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("bloomFilterNdv", ToString::to_string(&v).as_str())?; - } - } - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for ParquetOptions { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "enable_page_index", - "enablePageIndex", - "pruning", - "skip_metadata", - "skipMetadata", - "pushdown_filters", - "pushdownFilters", - "reorder_filters", - "reorderFilters", - "data_pagesize_limit", - "dataPagesizeLimit", - "write_batch_size", - "writeBatchSize", - "writer_version", - "writerVersion", - "bloom_filter_enabled", - "bloomFilterEnabled", - "allow_single_file_parallelism", - "allowSingleFileParallelism", - "maximum_parallel_row_group_writers", - "maximumParallelRowGroupWriters", - "maximum_buffered_record_batches_per_stream", - "maximumBufferedRecordBatchesPerStream", - "dictionary_page_size_limit", - "dictionaryPageSizeLimit", - "data_page_row_count_limit", - "dataPageRowCountLimit", - "max_row_group_size", - "maxRowGroupSize", - "created_by", - "createdBy", - "metadata_size_hint", - "metadataSizeHint", - "compression", - "dictionary_enabled", - "dictionaryEnabled", - "statistics_enabled", - "statisticsEnabled", - "max_statistics_size", - "maxStatisticsSize", - "column_index_truncate_length", - "columnIndexTruncateLength", - "encoding", - "bloom_filter_fpp", - "bloomFilterFpp", - "bloom_filter_ndv", - "bloomFilterNdv", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - EnablePageIndex, - Pruning, - SkipMetadata, - PushdownFilters, - ReorderFilters, - DataPagesizeLimit, - WriteBatchSize, - WriterVersion, - BloomFilterEnabled, - AllowSingleFileParallelism, - MaximumParallelRowGroupWriters, - MaximumBufferedRecordBatchesPerStream, - DictionaryPageSizeLimit, - DataPageRowCountLimit, - MaxRowGroupSize, - CreatedBy, - MetadataSizeHint, - Compression, - DictionaryEnabled, - StatisticsEnabled, - MaxStatisticsSize, - ColumnIndexTruncateLength, - Encoding, - BloomFilterFpp, - BloomFilterNdv, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "enablePageIndex" | "enable_page_index" => Ok(GeneratedField::EnablePageIndex), - "pruning" => Ok(GeneratedField::Pruning), - "skipMetadata" | "skip_metadata" => Ok(GeneratedField::SkipMetadata), - "pushdownFilters" | "pushdown_filters" => Ok(GeneratedField::PushdownFilters), - "reorderFilters" | "reorder_filters" => Ok(GeneratedField::ReorderFilters), - "dataPagesizeLimit" | "data_pagesize_limit" => Ok(GeneratedField::DataPagesizeLimit), - "writeBatchSize" | "write_batch_size" => Ok(GeneratedField::WriteBatchSize), - "writerVersion" | "writer_version" => Ok(GeneratedField::WriterVersion), - "bloomFilterEnabled" | "bloom_filter_enabled" => Ok(GeneratedField::BloomFilterEnabled), - "allowSingleFileParallelism" | "allow_single_file_parallelism" => Ok(GeneratedField::AllowSingleFileParallelism), - "maximumParallelRowGroupWriters" | "maximum_parallel_row_group_writers" => Ok(GeneratedField::MaximumParallelRowGroupWriters), - "maximumBufferedRecordBatchesPerStream" | "maximum_buffered_record_batches_per_stream" => Ok(GeneratedField::MaximumBufferedRecordBatchesPerStream), - "dictionaryPageSizeLimit" | "dictionary_page_size_limit" => Ok(GeneratedField::DictionaryPageSizeLimit), - "dataPageRowCountLimit" | "data_page_row_count_limit" => Ok(GeneratedField::DataPageRowCountLimit), - "maxRowGroupSize" | "max_row_group_size" => Ok(GeneratedField::MaxRowGroupSize), - "createdBy" | "created_by" => Ok(GeneratedField::CreatedBy), - "metadataSizeHint" | "metadata_size_hint" => Ok(GeneratedField::MetadataSizeHint), - "compression" => Ok(GeneratedField::Compression), - "dictionaryEnabled" | "dictionary_enabled" => Ok(GeneratedField::DictionaryEnabled), - "statisticsEnabled" | "statistics_enabled" => Ok(GeneratedField::StatisticsEnabled), - "maxStatisticsSize" | "max_statistics_size" => Ok(GeneratedField::MaxStatisticsSize), - "columnIndexTruncateLength" | "column_index_truncate_length" => Ok(GeneratedField::ColumnIndexTruncateLength), - "encoding" => Ok(GeneratedField::Encoding), - "bloomFilterFpp" | "bloom_filter_fpp" => Ok(GeneratedField::BloomFilterFpp), - "bloomFilterNdv" | "bloom_filter_ndv" => Ok(GeneratedField::BloomFilterNdv), + "column" => Ok(GeneratedField::Column), + "alias" => Ok(GeneratedField::Alias), + "literal" => Ok(GeneratedField::Literal), + "binaryExpr" | "binary_expr" => Ok(GeneratedField::BinaryExpr), + "isNullExpr" | "is_null_expr" => Ok(GeneratedField::IsNullExpr), + "isNotNullExpr" | "is_not_null_expr" => Ok(GeneratedField::IsNotNullExpr), + "notExpr" | "not_expr" => Ok(GeneratedField::NotExpr), + "between" => Ok(GeneratedField::Between), + "case" | "case_" => Ok(GeneratedField::Case), + "cast" => Ok(GeneratedField::Cast), + "negative" => Ok(GeneratedField::Negative), + "inList" | "in_list" => Ok(GeneratedField::InList), + "wildcard" => Ok(GeneratedField::Wildcard), + "tryCast" | "try_cast" => Ok(GeneratedField::TryCast), + "windowExpr" | "window_expr" => Ok(GeneratedField::WindowExpr), + "aggregateUdfExpr" | "aggregate_udf_expr" => Ok(GeneratedField::AggregateUdfExpr), + "scalarUdfExpr" | "scalar_udf_expr" => Ok(GeneratedField::ScalarUdfExpr), + "groupingSet" | "grouping_set" => Ok(GeneratedField::GroupingSet), + "cube" => Ok(GeneratedField::Cube), + "rollup" => Ok(GeneratedField::Rollup), + "isTrue" | "is_true" => Ok(GeneratedField::IsTrue), + "isFalse" | "is_false" => Ok(GeneratedField::IsFalse), + "isUnknown" | "is_unknown" => Ok(GeneratedField::IsUnknown), + "isNotTrue" | "is_not_true" => Ok(GeneratedField::IsNotTrue), + "isNotFalse" | "is_not_false" => Ok(GeneratedField::IsNotFalse), + "isNotUnknown" | "is_not_unknown" => Ok(GeneratedField::IsNotUnknown), + "like" => Ok(GeneratedField::Like), + "ilike" => Ok(GeneratedField::Ilike), + "similarTo" | "similar_to" => Ok(GeneratedField::SimilarTo), + "placeholder" => Ok(GeneratedField::Placeholder), + "unnest" => Ok(GeneratedField::Unnest), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -16291,460 +10057,247 @@ impl<'de> serde::Deserialize<'de> for ParquetOptions { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ParquetOptions; + type Value = LogicalExprNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ParquetOptions") + formatter.write_str("struct datafusion.LogicalExprNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut enable_page_index__ = None; - let mut pruning__ = None; - let mut skip_metadata__ = None; - let mut pushdown_filters__ = None; - let mut reorder_filters__ = None; - let mut data_pagesize_limit__ = None; - let mut write_batch_size__ = None; - let mut writer_version__ = None; - let mut bloom_filter_enabled__ = None; - let mut allow_single_file_parallelism__ = None; - let mut maximum_parallel_row_group_writers__ = None; - let mut maximum_buffered_record_batches_per_stream__ = None; - let mut dictionary_page_size_limit__ = None; - let mut data_page_row_count_limit__ = None; - let mut max_row_group_size__ = None; - let mut created_by__ = None; - let mut metadata_size_hint_opt__ = None; - let mut compression_opt__ = None; - let mut dictionary_enabled_opt__ = None; - let mut statistics_enabled_opt__ = None; - let mut max_statistics_size_opt__ = None; - let mut column_index_truncate_length_opt__ = None; - let mut encoding_opt__ = None; - let mut bloom_filter_fpp_opt__ = None; - let mut bloom_filter_ndv_opt__ = None; + let mut expr_type__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::EnablePageIndex => { - if enable_page_index__.is_some() { - return Err(serde::de::Error::duplicate_field("enablePageIndex")); - } - enable_page_index__ = Some(map_.next_value()?); - } - GeneratedField::Pruning => { - if pruning__.is_some() { - return Err(serde::de::Error::duplicate_field("pruning")); - } - pruning__ = Some(map_.next_value()?); - } - GeneratedField::SkipMetadata => { - if skip_metadata__.is_some() { - return Err(serde::de::Error::duplicate_field("skipMetadata")); - } - skip_metadata__ = Some(map_.next_value()?); - } - GeneratedField::PushdownFilters => { - if pushdown_filters__.is_some() { - return Err(serde::de::Error::duplicate_field("pushdownFilters")); - } - pushdown_filters__ = Some(map_.next_value()?); - } - GeneratedField::ReorderFilters => { - if reorder_filters__.is_some() { - return Err(serde::de::Error::duplicate_field("reorderFilters")); - } - reorder_filters__ = Some(map_.next_value()?); - } - GeneratedField::DataPagesizeLimit => { - if data_pagesize_limit__.is_some() { - return Err(serde::de::Error::duplicate_field("dataPagesizeLimit")); - } - data_pagesize_limit__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::WriteBatchSize => { - if write_batch_size__.is_some() { - return Err(serde::de::Error::duplicate_field("writeBatchSize")); - } - write_batch_size__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::WriterVersion => { - if writer_version__.is_some() { - return Err(serde::de::Error::duplicate_field("writerVersion")); + GeneratedField::Column => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("column")); } - writer_version__ = Some(map_.next_value()?); + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Column) +; } - GeneratedField::BloomFilterEnabled => { - if bloom_filter_enabled__.is_some() { - return Err(serde::de::Error::duplicate_field("bloomFilterEnabled")); + GeneratedField::Alias => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("alias")); } - bloom_filter_enabled__ = Some(map_.next_value()?); + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Alias) +; } - GeneratedField::AllowSingleFileParallelism => { - if allow_single_file_parallelism__.is_some() { - return Err(serde::de::Error::duplicate_field("allowSingleFileParallelism")); + GeneratedField::Literal => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("literal")); } - allow_single_file_parallelism__ = Some(map_.next_value()?); + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Literal) +; } - GeneratedField::MaximumParallelRowGroupWriters => { - if maximum_parallel_row_group_writers__.is_some() { - return Err(serde::de::Error::duplicate_field("maximumParallelRowGroupWriters")); + GeneratedField::BinaryExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("binaryExpr")); } - maximum_parallel_row_group_writers__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::BinaryExpr) +; } - GeneratedField::MaximumBufferedRecordBatchesPerStream => { - if maximum_buffered_record_batches_per_stream__.is_some() { - return Err(serde::de::Error::duplicate_field("maximumBufferedRecordBatchesPerStream")); + GeneratedField::IsNullExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("isNullExpr")); } - maximum_buffered_record_batches_per_stream__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNullExpr) +; } - GeneratedField::DictionaryPageSizeLimit => { - if dictionary_page_size_limit__.is_some() { - return Err(serde::de::Error::duplicate_field("dictionaryPageSizeLimit")); + GeneratedField::IsNotNullExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("isNotNullExpr")); } - dictionary_page_size_limit__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNotNullExpr) +; } - GeneratedField::DataPageRowCountLimit => { - if data_page_row_count_limit__.is_some() { - return Err(serde::de::Error::duplicate_field("dataPageRowCountLimit")); + GeneratedField::NotExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("notExpr")); } - data_page_row_count_limit__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::NotExpr) +; } - GeneratedField::MaxRowGroupSize => { - if max_row_group_size__.is_some() { - return Err(serde::de::Error::duplicate_field("maxRowGroupSize")); + GeneratedField::Between => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("between")); } - max_row_group_size__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Between) +; } - GeneratedField::CreatedBy => { - if created_by__.is_some() { - return Err(serde::de::Error::duplicate_field("createdBy")); + GeneratedField::Case => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("case")); } - created_by__ = Some(map_.next_value()?); + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Case) +; } - GeneratedField::MetadataSizeHint => { - if metadata_size_hint_opt__.is_some() { - return Err(serde::de::Error::duplicate_field("metadataSizeHint")); + GeneratedField::Cast => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("cast")); } - metadata_size_hint_opt__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| parquet_options::MetadataSizeHintOpt::MetadataSizeHint(x.0)); + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Cast) +; } - GeneratedField::Compression => { - if compression_opt__.is_some() { - return Err(serde::de::Error::duplicate_field("compression")); + GeneratedField::Negative => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("negative")); } - compression_opt__ = map_.next_value::<::std::option::Option<_>>()?.map(parquet_options::CompressionOpt::Compression); + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Negative) +; } - GeneratedField::DictionaryEnabled => { - if dictionary_enabled_opt__.is_some() { - return Err(serde::de::Error::duplicate_field("dictionaryEnabled")); + GeneratedField::InList => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("inList")); } - dictionary_enabled_opt__ = map_.next_value::<::std::option::Option<_>>()?.map(parquet_options::DictionaryEnabledOpt::DictionaryEnabled); + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::InList) +; } - GeneratedField::StatisticsEnabled => { - if statistics_enabled_opt__.is_some() { - return Err(serde::de::Error::duplicate_field("statisticsEnabled")); + GeneratedField::Wildcard => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("wildcard")); } - statistics_enabled_opt__ = map_.next_value::<::std::option::Option<_>>()?.map(parquet_options::StatisticsEnabledOpt::StatisticsEnabled); + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Wildcard) +; } - GeneratedField::MaxStatisticsSize => { - if max_statistics_size_opt__.is_some() { - return Err(serde::de::Error::duplicate_field("maxStatisticsSize")); + GeneratedField::TryCast => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("tryCast")); } - max_statistics_size_opt__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| parquet_options::MaxStatisticsSizeOpt::MaxStatisticsSize(x.0)); + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::TryCast) +; } - GeneratedField::ColumnIndexTruncateLength => { - if column_index_truncate_length_opt__.is_some() { - return Err(serde::de::Error::duplicate_field("columnIndexTruncateLength")); + GeneratedField::WindowExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("windowExpr")); } - column_index_truncate_length_opt__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| parquet_options::ColumnIndexTruncateLengthOpt::ColumnIndexTruncateLength(x.0)); + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::WindowExpr) +; } - GeneratedField::Encoding => { - if encoding_opt__.is_some() { - return Err(serde::de::Error::duplicate_field("encoding")); + GeneratedField::AggregateUdfExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("aggregateUdfExpr")); } - encoding_opt__ = map_.next_value::<::std::option::Option<_>>()?.map(parquet_options::EncodingOpt::Encoding); + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::AggregateUdfExpr) +; } - GeneratedField::BloomFilterFpp => { - if bloom_filter_fpp_opt__.is_some() { - return Err(serde::de::Error::duplicate_field("bloomFilterFpp")); + GeneratedField::ScalarUdfExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("scalarUdfExpr")); } - bloom_filter_fpp_opt__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| parquet_options::BloomFilterFppOpt::BloomFilterFpp(x.0)); + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::ScalarUdfExpr) +; } - GeneratedField::BloomFilterNdv => { - if bloom_filter_ndv_opt__.is_some() { - return Err(serde::de::Error::duplicate_field("bloomFilterNdv")); + GeneratedField::GroupingSet => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("groupingSet")); } - bloom_filter_ndv_opt__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| parquet_options::BloomFilterNdvOpt::BloomFilterNdv(x.0)); + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::GroupingSet) +; } - } - } - Ok(ParquetOptions { - enable_page_index: enable_page_index__.unwrap_or_default(), - pruning: pruning__.unwrap_or_default(), - skip_metadata: skip_metadata__.unwrap_or_default(), - pushdown_filters: pushdown_filters__.unwrap_or_default(), - reorder_filters: reorder_filters__.unwrap_or_default(), - data_pagesize_limit: data_pagesize_limit__.unwrap_or_default(), - write_batch_size: write_batch_size__.unwrap_or_default(), - writer_version: writer_version__.unwrap_or_default(), - bloom_filter_enabled: bloom_filter_enabled__.unwrap_or_default(), - allow_single_file_parallelism: allow_single_file_parallelism__.unwrap_or_default(), - maximum_parallel_row_group_writers: maximum_parallel_row_group_writers__.unwrap_or_default(), - maximum_buffered_record_batches_per_stream: maximum_buffered_record_batches_per_stream__.unwrap_or_default(), - dictionary_page_size_limit: dictionary_page_size_limit__.unwrap_or_default(), - data_page_row_count_limit: data_page_row_count_limit__.unwrap_or_default(), - max_row_group_size: max_row_group_size__.unwrap_or_default(), - created_by: created_by__.unwrap_or_default(), - metadata_size_hint_opt: metadata_size_hint_opt__, - compression_opt: compression_opt__, - dictionary_enabled_opt: dictionary_enabled_opt__, - statistics_enabled_opt: statistics_enabled_opt__, - max_statistics_size_opt: max_statistics_size_opt__, - column_index_truncate_length_opt: column_index_truncate_length_opt__, - encoding_opt: encoding_opt__, - bloom_filter_fpp_opt: bloom_filter_fpp_opt__, - bloom_filter_ndv_opt: bloom_filter_ndv_opt__, - }) - } - } - deserializer.deserialize_struct("datafusion.ParquetOptions", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for ParquetScanExecNode { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.base_conf.is_some() { - len += 1; - } - if self.predicate.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.ParquetScanExecNode", len)?; - if let Some(v) = self.base_conf.as_ref() { - struct_ser.serialize_field("baseConf", v)?; - } - if let Some(v) = self.predicate.as_ref() { - struct_ser.serialize_field("predicate", v)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for ParquetScanExecNode { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "base_conf", - "baseConf", - "predicate", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - BaseConf, - Predicate, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "baseConf" | "base_conf" => Ok(GeneratedField::BaseConf), - "predicate" => Ok(GeneratedField::Predicate), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + GeneratedField::Cube => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("cube")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Cube) +; } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ParquetScanExecNode; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ParquetScanExecNode") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut base_conf__ = None; - let mut predicate__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::BaseConf => { - if base_conf__.is_some() { - return Err(serde::de::Error::duplicate_field("baseConf")); + GeneratedField::Rollup => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("rollup")); } - base_conf__ = map_.next_value()?; + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Rollup) +; } - GeneratedField::Predicate => { - if predicate__.is_some() { - return Err(serde::de::Error::duplicate_field("predicate")); + GeneratedField::IsTrue => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("isTrue")); } - predicate__ = map_.next_value()?; + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsTrue) +; } - } - } - Ok(ParquetScanExecNode { - base_conf: base_conf__, - predicate: predicate__, - }) - } - } - deserializer.deserialize_struct("datafusion.ParquetScanExecNode", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for ParquetSink { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.config.is_some() { - len += 1; - } - if self.parquet_options.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.ParquetSink", len)?; - if let Some(v) = self.config.as_ref() { - struct_ser.serialize_field("config", v)?; - } - if let Some(v) = self.parquet_options.as_ref() { - struct_ser.serialize_field("parquetOptions", v)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for ParquetSink { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "config", - "parquet_options", - "parquetOptions", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Config, - ParquetOptions, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "config" => Ok(GeneratedField::Config), - "parquetOptions" | "parquet_options" => Ok(GeneratedField::ParquetOptions), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + GeneratedField::IsFalse => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("isFalse")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsFalse) +; } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ParquetSink; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ParquetSink") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut config__ = None; - let mut parquet_options__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Config => { - if config__.is_some() { - return Err(serde::de::Error::duplicate_field("config")); + GeneratedField::IsUnknown => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("isUnknown")); } - config__ = map_.next_value()?; + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsUnknown) +; } - GeneratedField::ParquetOptions => { - if parquet_options__.is_some() { - return Err(serde::de::Error::duplicate_field("parquetOptions")); + GeneratedField::IsNotTrue => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("isNotTrue")); } - parquet_options__ = map_.next_value()?; + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNotTrue) +; + } + GeneratedField::IsNotFalse => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("isNotFalse")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNotFalse) +; + } + GeneratedField::IsNotUnknown => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("isNotUnknown")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::IsNotUnknown) +; + } + GeneratedField::Like => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("like")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Like) +; + } + GeneratedField::Ilike => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("ilike")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Ilike) +; + } + GeneratedField::SimilarTo => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("similarTo")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::SimilarTo) +; + } + GeneratedField::Placeholder => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("placeholder")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Placeholder) +; + } + GeneratedField::Unnest => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("unnest")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_expr_node::ExprType::Unnest) +; } } } - Ok(ParquetSink { - config: config__, - parquet_options: parquet_options__, + Ok(LogicalExprNode { + expr_type: expr_type__, }) } } - deserializer.deserialize_struct("datafusion.ParquetSink", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.LogicalExprNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ParquetSinkExecNode { +impl serde::Serialize for LogicalExprNodeCollection { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -16752,55 +10305,30 @@ impl serde::Serialize for ParquetSinkExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.input.is_some() { - len += 1; - } - if self.sink.is_some() { - len += 1; - } - if self.sink_schema.is_some() { - len += 1; - } - if self.sort_order.is_some() { + if !self.logical_expr_nodes.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ParquetSinkExecNode", len)?; - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; - } - if let Some(v) = self.sink.as_ref() { - struct_ser.serialize_field("sink", v)?; - } - if let Some(v) = self.sink_schema.as_ref() { - struct_ser.serialize_field("sinkSchema", v)?; - } - if let Some(v) = self.sort_order.as_ref() { - struct_ser.serialize_field("sortOrder", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.LogicalExprNodeCollection", len)?; + if !self.logical_expr_nodes.is_empty() { + struct_ser.serialize_field("logicalExprNodes", &self.logical_expr_nodes)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ParquetSinkExecNode { +impl<'de> serde::Deserialize<'de> for LogicalExprNodeCollection { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "input", - "sink", - "sink_schema", - "sinkSchema", - "sort_order", - "sortOrder", + "logical_expr_nodes", + "logicalExprNodes", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Input, - Sink, - SinkSchema, - SortOrder, + LogicalExprNodes, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -16822,10 +10350,7 @@ impl<'de> serde::Deserialize<'de> for ParquetSinkExecNode { E: serde::de::Error, { match value { - "input" => Ok(GeneratedField::Input), - "sink" => Ok(GeneratedField::Sink), - "sinkSchema" | "sink_schema" => Ok(GeneratedField::SinkSchema), - "sortOrder" | "sort_order" => Ok(GeneratedField::SortOrder), + "logicalExprNodes" | "logical_expr_nodes" => Ok(GeneratedField::LogicalExprNodes), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -16835,60 +10360,36 @@ impl<'de> serde::Deserialize<'de> for ParquetSinkExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ParquetSinkExecNode; + type Value = LogicalExprNodeCollection; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ParquetSinkExecNode") + formatter.write_str("struct datafusion.LogicalExprNodeCollection") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut input__ = None; - let mut sink__ = None; - let mut sink_schema__ = None; - let mut sort_order__ = None; + let mut logical_expr_nodes__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); - } - input__ = map_.next_value()?; - } - GeneratedField::Sink => { - if sink__.is_some() { - return Err(serde::de::Error::duplicate_field("sink")); - } - sink__ = map_.next_value()?; - } - GeneratedField::SinkSchema => { - if sink_schema__.is_some() { - return Err(serde::de::Error::duplicate_field("sinkSchema")); - } - sink_schema__ = map_.next_value()?; - } - GeneratedField::SortOrder => { - if sort_order__.is_some() { - return Err(serde::de::Error::duplicate_field("sortOrder")); + GeneratedField::LogicalExprNodes => { + if logical_expr_nodes__.is_some() { + return Err(serde::de::Error::duplicate_field("logicalExprNodes")); } - sort_order__ = map_.next_value()?; + logical_expr_nodes__ = Some(map_.next_value()?); } } } - Ok(ParquetSinkExecNode { - input: input__, - sink: sink__, - sink_schema: sink_schema__, - sort_order: sort_order__, + Ok(LogicalExprNodeCollection { + logical_expr_nodes: logical_expr_nodes__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.ParquetSinkExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.LogicalExprNodeCollection", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PartialTableReference { +impl serde::Serialize for LogicalExtensionNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -16896,37 +10397,39 @@ impl serde::Serialize for PartialTableReference { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.schema.is_empty() { + if !self.node.is_empty() { len += 1; } - if !self.table.is_empty() { + if !self.inputs.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PartialTableReference", len)?; - if !self.schema.is_empty() { - struct_ser.serialize_field("schema", &self.schema)?; + let mut struct_ser = serializer.serialize_struct("datafusion.LogicalExtensionNode", len)?; + if !self.node.is_empty() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("node", pbjson::private::base64::encode(&self.node).as_str())?; } - if !self.table.is_empty() { - struct_ser.serialize_field("table", &self.table)?; + if !self.inputs.is_empty() { + struct_ser.serialize_field("inputs", &self.inputs)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PartialTableReference { +impl<'de> serde::Deserialize<'de> for LogicalExtensionNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "schema", - "table", + "node", + "inputs", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Schema, - Table, + Node, + Inputs, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -16948,8 +10451,8 @@ impl<'de> serde::Deserialize<'de> for PartialTableReference { E: serde::de::Error, { match value { - "schema" => Ok(GeneratedField::Schema), - "table" => Ok(GeneratedField::Table), + "node" => Ok(GeneratedField::Node), + "inputs" => Ok(GeneratedField::Inputs), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -16959,44 +10462,46 @@ impl<'de> serde::Deserialize<'de> for PartialTableReference { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PartialTableReference; + type Value = LogicalExtensionNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PartialTableReference") + formatter.write_str("struct datafusion.LogicalExtensionNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut schema__ = None; - let mut table__ = None; + let mut node__ = None; + let mut inputs__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Schema => { - if schema__.is_some() { - return Err(serde::de::Error::duplicate_field("schema")); + GeneratedField::Node => { + if node__.is_some() { + return Err(serde::de::Error::duplicate_field("node")); } - schema__ = Some(map_.next_value()?); + node__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; } - GeneratedField::Table => { - if table__.is_some() { - return Err(serde::de::Error::duplicate_field("table")); + GeneratedField::Inputs => { + if inputs__.is_some() { + return Err(serde::de::Error::duplicate_field("inputs")); } - table__ = Some(map_.next_value()?); + inputs__ = Some(map_.next_value()?); } } } - Ok(PartialTableReference { - schema: schema__.unwrap_or_default(), - table: table__.unwrap_or_default(), + Ok(LogicalExtensionNode { + node: node__.unwrap_or_default(), + inputs: inputs__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.PartialTableReference", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.LogicalExtensionNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PartiallySortedInputOrderMode { +impl serde::Serialize for LogicalPlanNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -17004,29 +10509,186 @@ impl serde::Serialize for PartiallySortedInputOrderMode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.columns.is_empty() { + if self.logical_plan_type.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PartiallySortedInputOrderMode", len)?; - if !self.columns.is_empty() { - struct_ser.serialize_field("columns", &self.columns.iter().map(ToString::to_string).collect::>())?; + let mut struct_ser = serializer.serialize_struct("datafusion.LogicalPlanNode", len)?; + if let Some(v) = self.logical_plan_type.as_ref() { + match v { + logical_plan_node::LogicalPlanType::ListingScan(v) => { + struct_ser.serialize_field("listingScan", v)?; + } + logical_plan_node::LogicalPlanType::Projection(v) => { + struct_ser.serialize_field("projection", v)?; + } + logical_plan_node::LogicalPlanType::Selection(v) => { + struct_ser.serialize_field("selection", v)?; + } + logical_plan_node::LogicalPlanType::Limit(v) => { + struct_ser.serialize_field("limit", v)?; + } + logical_plan_node::LogicalPlanType::Aggregate(v) => { + struct_ser.serialize_field("aggregate", v)?; + } + logical_plan_node::LogicalPlanType::Join(v) => { + struct_ser.serialize_field("join", v)?; + } + logical_plan_node::LogicalPlanType::Sort(v) => { + struct_ser.serialize_field("sort", v)?; + } + logical_plan_node::LogicalPlanType::Repartition(v) => { + struct_ser.serialize_field("repartition", v)?; + } + logical_plan_node::LogicalPlanType::EmptyRelation(v) => { + struct_ser.serialize_field("emptyRelation", v)?; + } + logical_plan_node::LogicalPlanType::CreateExternalTable(v) => { + struct_ser.serialize_field("createExternalTable", v)?; + } + logical_plan_node::LogicalPlanType::Explain(v) => { + struct_ser.serialize_field("explain", v)?; + } + logical_plan_node::LogicalPlanType::Window(v) => { + struct_ser.serialize_field("window", v)?; + } + logical_plan_node::LogicalPlanType::Analyze(v) => { + struct_ser.serialize_field("analyze", v)?; + } + logical_plan_node::LogicalPlanType::CrossJoin(v) => { + struct_ser.serialize_field("crossJoin", v)?; + } + logical_plan_node::LogicalPlanType::Values(v) => { + struct_ser.serialize_field("values", v)?; + } + logical_plan_node::LogicalPlanType::Extension(v) => { + struct_ser.serialize_field("extension", v)?; + } + logical_plan_node::LogicalPlanType::CreateCatalogSchema(v) => { + struct_ser.serialize_field("createCatalogSchema", v)?; + } + logical_plan_node::LogicalPlanType::Union(v) => { + struct_ser.serialize_field("union", v)?; + } + logical_plan_node::LogicalPlanType::CreateCatalog(v) => { + struct_ser.serialize_field("createCatalog", v)?; + } + logical_plan_node::LogicalPlanType::SubqueryAlias(v) => { + struct_ser.serialize_field("subqueryAlias", v)?; + } + logical_plan_node::LogicalPlanType::CreateView(v) => { + struct_ser.serialize_field("createView", v)?; + } + logical_plan_node::LogicalPlanType::Distinct(v) => { + struct_ser.serialize_field("distinct", v)?; + } + logical_plan_node::LogicalPlanType::ViewScan(v) => { + struct_ser.serialize_field("viewScan", v)?; + } + logical_plan_node::LogicalPlanType::CustomScan(v) => { + struct_ser.serialize_field("customScan", v)?; + } + logical_plan_node::LogicalPlanType::Prepare(v) => { + struct_ser.serialize_field("prepare", v)?; + } + logical_plan_node::LogicalPlanType::DropView(v) => { + struct_ser.serialize_field("dropView", v)?; + } + logical_plan_node::LogicalPlanType::DistinctOn(v) => { + struct_ser.serialize_field("distinctOn", v)?; + } + logical_plan_node::LogicalPlanType::CopyTo(v) => { + struct_ser.serialize_field("copyTo", v)?; + } + logical_plan_node::LogicalPlanType::Unnest(v) => { + struct_ser.serialize_field("unnest", v)?; + } + } } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PartiallySortedInputOrderMode { +impl<'de> serde::Deserialize<'de> for LogicalPlanNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "columns", + "listing_scan", + "listingScan", + "projection", + "selection", + "limit", + "aggregate", + "join", + "sort", + "repartition", + "empty_relation", + "emptyRelation", + "create_external_table", + "createExternalTable", + "explain", + "window", + "analyze", + "cross_join", + "crossJoin", + "values", + "extension", + "create_catalog_schema", + "createCatalogSchema", + "union", + "create_catalog", + "createCatalog", + "subquery_alias", + "subqueryAlias", + "create_view", + "createView", + "distinct", + "view_scan", + "viewScan", + "custom_scan", + "customScan", + "prepare", + "drop_view", + "dropView", + "distinct_on", + "distinctOn", + "copy_to", + "copyTo", + "unnest", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Columns, + ListingScan, + Projection, + Selection, + Limit, + Aggregate, + Join, + Sort, + Repartition, + EmptyRelation, + CreateExternalTable, + Explain, + Window, + Analyze, + CrossJoin, + Values, + Extension, + CreateCatalogSchema, + Union, + CreateCatalog, + SubqueryAlias, + CreateView, + Distinct, + ViewScan, + CustomScan, + Prepare, + DropView, + DistinctOn, + CopyTo, + Unnest, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -17048,49 +10710,271 @@ impl<'de> serde::Deserialize<'de> for PartiallySortedInputOrderMode { E: serde::de::Error, { match value { - "columns" => Ok(GeneratedField::Columns), + "listingScan" | "listing_scan" => Ok(GeneratedField::ListingScan), + "projection" => Ok(GeneratedField::Projection), + "selection" => Ok(GeneratedField::Selection), + "limit" => Ok(GeneratedField::Limit), + "aggregate" => Ok(GeneratedField::Aggregate), + "join" => Ok(GeneratedField::Join), + "sort" => Ok(GeneratedField::Sort), + "repartition" => Ok(GeneratedField::Repartition), + "emptyRelation" | "empty_relation" => Ok(GeneratedField::EmptyRelation), + "createExternalTable" | "create_external_table" => Ok(GeneratedField::CreateExternalTable), + "explain" => Ok(GeneratedField::Explain), + "window" => Ok(GeneratedField::Window), + "analyze" => Ok(GeneratedField::Analyze), + "crossJoin" | "cross_join" => Ok(GeneratedField::CrossJoin), + "values" => Ok(GeneratedField::Values), + "extension" => Ok(GeneratedField::Extension), + "createCatalogSchema" | "create_catalog_schema" => Ok(GeneratedField::CreateCatalogSchema), + "union" => Ok(GeneratedField::Union), + "createCatalog" | "create_catalog" => Ok(GeneratedField::CreateCatalog), + "subqueryAlias" | "subquery_alias" => Ok(GeneratedField::SubqueryAlias), + "createView" | "create_view" => Ok(GeneratedField::CreateView), + "distinct" => Ok(GeneratedField::Distinct), + "viewScan" | "view_scan" => Ok(GeneratedField::ViewScan), + "customScan" | "custom_scan" => Ok(GeneratedField::CustomScan), + "prepare" => Ok(GeneratedField::Prepare), + "dropView" | "drop_view" => Ok(GeneratedField::DropView), + "distinctOn" | "distinct_on" => Ok(GeneratedField::DistinctOn), + "copyTo" | "copy_to" => Ok(GeneratedField::CopyTo), + "unnest" => Ok(GeneratedField::Unnest), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PartiallySortedInputOrderMode; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PartiallySortedInputOrderMode") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut columns__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Columns => { - if columns__.is_some() { - return Err(serde::de::Error::duplicate_field("columns")); + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = LogicalPlanNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.LogicalPlanNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut logical_plan_type__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::ListingScan => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("listingScan")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::ListingScan) +; + } + GeneratedField::Projection => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("projection")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Projection) +; + } + GeneratedField::Selection => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("selection")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Selection) +; + } + GeneratedField::Limit => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("limit")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Limit) +; + } + GeneratedField::Aggregate => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("aggregate")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Aggregate) +; + } + GeneratedField::Join => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("join")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Join) +; + } + GeneratedField::Sort => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("sort")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Sort) +; + } + GeneratedField::Repartition => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("repartition")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Repartition) +; + } + GeneratedField::EmptyRelation => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("emptyRelation")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::EmptyRelation) +; + } + GeneratedField::CreateExternalTable => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("createExternalTable")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CreateExternalTable) +; + } + GeneratedField::Explain => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("explain")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Explain) +; + } + GeneratedField::Window => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("window")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Window) +; + } + GeneratedField::Analyze => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("analyze")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Analyze) +; + } + GeneratedField::CrossJoin => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("crossJoin")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CrossJoin) +; + } + GeneratedField::Values => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("values")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Values) +; + } + GeneratedField::Extension => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("extension")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Extension) +; + } + GeneratedField::CreateCatalogSchema => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("createCatalogSchema")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CreateCatalogSchema) +; + } + GeneratedField::Union => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("union")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Union) +; + } + GeneratedField::CreateCatalog => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("createCatalog")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CreateCatalog) +; + } + GeneratedField::SubqueryAlias => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("subqueryAlias")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::SubqueryAlias) +; + } + GeneratedField::CreateView => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("createView")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CreateView) +; + } + GeneratedField::Distinct => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("distinct")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Distinct) +; + } + GeneratedField::ViewScan => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("viewScan")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::ViewScan) +; + } + GeneratedField::CustomScan => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("customScan")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CustomScan) +; + } + GeneratedField::Prepare => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("prepare")); } - columns__ = - Some(map_.next_value::>>()? - .into_iter().map(|x| x.0).collect()) - ; + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Prepare) +; + } + GeneratedField::DropView => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("dropView")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::DropView) +; + } + GeneratedField::DistinctOn => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("distinctOn")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::DistinctOn) +; + } + GeneratedField::CopyTo => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("copyTo")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CopyTo) +; + } + GeneratedField::Unnest => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("unnest")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::Unnest) +; } } } - Ok(PartiallySortedInputOrderMode { - columns: columns__.unwrap_or_default(), + Ok(LogicalPlanNode { + logical_plan_type: logical_plan_type__, }) } } - deserializer.deserialize_struct("datafusion.PartiallySortedInputOrderMode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.LogicalPlanNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PartitionColumn { +impl serde::Serialize for MaybeFilter { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -17098,38 +10982,29 @@ impl serde::Serialize for PartitionColumn { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.name.is_empty() { - len += 1; - } - if self.arrow_type.is_some() { + if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PartitionColumn", len)?; - if !self.name.is_empty() { - struct_ser.serialize_field("name", &self.name)?; - } - if let Some(v) = self.arrow_type.as_ref() { - struct_ser.serialize_field("arrowType", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.MaybeFilter", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PartitionColumn { +impl<'de> serde::Deserialize<'de> for MaybeFilter { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "name", - "arrow_type", - "arrowType", + "expr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Name, - ArrowType, + Expr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -17151,8 +11026,7 @@ impl<'de> serde::Deserialize<'de> for PartitionColumn { E: serde::de::Error, { match value { - "name" => Ok(GeneratedField::Name), - "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), + "expr" => Ok(GeneratedField::Expr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -17162,118 +11036,36 @@ impl<'de> serde::Deserialize<'de> for PartitionColumn { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PartitionColumn; + type Value = MaybeFilter; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PartitionColumn") + formatter.write_str("struct datafusion.MaybeFilter") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut name__ = None; - let mut arrow_type__ = None; + let mut expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Name => { - if name__.is_some() { - return Err(serde::de::Error::duplicate_field("name")); - } - name__ = Some(map_.next_value()?); - } - GeneratedField::ArrowType => { - if arrow_type__.is_some() { - return Err(serde::de::Error::duplicate_field("arrowType")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - arrow_type__ = map_.next_value()?; + expr__ = map_.next_value()?; } } } - Ok(PartitionColumn { - name: name__.unwrap_or_default(), - arrow_type: arrow_type__, + Ok(MaybeFilter { + expr: expr__, }) } } - deserializer.deserialize_struct("datafusion.PartitionColumn", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for PartitionMode { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - let variant = match self { - Self::CollectLeft => "COLLECT_LEFT", - Self::Partitioned => "PARTITIONED", - Self::Auto => "AUTO", - }; - serializer.serialize_str(variant) - } -} -impl<'de> serde::Deserialize<'de> for PartitionMode { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "COLLECT_LEFT", - "PARTITIONED", - "AUTO", - ]; - - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PartitionMode; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - fn visit_i64(self, v: i64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) - }) - } - - fn visit_u64(self, v: u64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) - }) - } - - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "COLLECT_LEFT" => Ok(PartitionMode::CollectLeft), - "PARTITIONED" => Ok(PartitionMode::Partitioned), - "AUTO" => Ok(PartitionMode::Auto), - _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), - } - } - } - deserializer.deserialize_any(GeneratedVisitor) + deserializer.deserialize_struct("datafusion.MaybeFilter", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PartitionStats { +impl serde::Serialize for MaybePhysicalSortExprs { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -17281,60 +11073,30 @@ impl serde::Serialize for PartitionStats { { use serde::ser::SerializeStruct; let mut len = 0; - if self.num_rows != 0 { - len += 1; - } - if self.num_batches != 0 { - len += 1; - } - if self.num_bytes != 0 { - len += 1; - } - if !self.column_stats.is_empty() { + if !self.sort_expr.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PartitionStats", len)?; - if self.num_rows != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("numRows", ToString::to_string(&self.num_rows).as_str())?; - } - if self.num_batches != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("numBatches", ToString::to_string(&self.num_batches).as_str())?; - } - if self.num_bytes != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("numBytes", ToString::to_string(&self.num_bytes).as_str())?; - } - if !self.column_stats.is_empty() { - struct_ser.serialize_field("columnStats", &self.column_stats)?; + let mut struct_ser = serializer.serialize_struct("datafusion.MaybePhysicalSortExprs", len)?; + if !self.sort_expr.is_empty() { + struct_ser.serialize_field("sortExpr", &self.sort_expr)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PartitionStats { +impl<'de> serde::Deserialize<'de> for MaybePhysicalSortExprs { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "num_rows", - "numRows", - "num_batches", - "numBatches", - "num_bytes", - "numBytes", - "column_stats", - "columnStats", + "sort_expr", + "sortExpr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - NumRows, - NumBatches, - NumBytes, - ColumnStats, + SortExpr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -17356,10 +11118,7 @@ impl<'de> serde::Deserialize<'de> for PartitionStats { E: serde::de::Error, { match value { - "numRows" | "num_rows" => Ok(GeneratedField::NumRows), - "numBatches" | "num_batches" => Ok(GeneratedField::NumBatches), - "numBytes" | "num_bytes" => Ok(GeneratedField::NumBytes), - "columnStats" | "column_stats" => Ok(GeneratedField::ColumnStats), + "sortExpr" | "sort_expr" => Ok(GeneratedField::SortExpr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -17369,66 +11128,36 @@ impl<'de> serde::Deserialize<'de> for PartitionStats { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PartitionStats; + type Value = MaybePhysicalSortExprs; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PartitionStats") + formatter.write_str("struct datafusion.MaybePhysicalSortExprs") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut num_rows__ = None; - let mut num_batches__ = None; - let mut num_bytes__ = None; - let mut column_stats__ = None; + let mut sort_expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::NumRows => { - if num_rows__.is_some() { - return Err(serde::de::Error::duplicate_field("numRows")); - } - num_rows__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::NumBatches => { - if num_batches__.is_some() { - return Err(serde::de::Error::duplicate_field("numBatches")); - } - num_batches__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::NumBytes => { - if num_bytes__.is_some() { - return Err(serde::de::Error::duplicate_field("numBytes")); - } - num_bytes__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::ColumnStats => { - if column_stats__.is_some() { - return Err(serde::de::Error::duplicate_field("columnStats")); + GeneratedField::SortExpr => { + if sort_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("sortExpr")); } - column_stats__ = Some(map_.next_value()?); + sort_expr__ = Some(map_.next_value()?); } } } - Ok(PartitionStats { - num_rows: num_rows__.unwrap_or_default(), - num_batches: num_batches__.unwrap_or_default(), - num_bytes: num_bytes__.unwrap_or_default(), - column_stats: column_stats__.unwrap_or_default(), + Ok(MaybePhysicalSortExprs { + sort_expr: sort_expr__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.PartitionStats", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.MaybePhysicalSortExprs", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PartitionedFile { +impl serde::Serialize for NamedStructField { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -17436,65 +11165,29 @@ impl serde::Serialize for PartitionedFile { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.path.is_empty() { - len += 1; - } - if self.size != 0 { - len += 1; - } - if self.last_modified_ns != 0 { - len += 1; - } - if !self.partition_values.is_empty() { - len += 1; - } - if self.range.is_some() { + if self.name.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PartitionedFile", len)?; - if !self.path.is_empty() { - struct_ser.serialize_field("path", &self.path)?; - } - if self.size != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("size", ToString::to_string(&self.size).as_str())?; - } - if self.last_modified_ns != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("lastModifiedNs", ToString::to_string(&self.last_modified_ns).as_str())?; - } - if !self.partition_values.is_empty() { - struct_ser.serialize_field("partitionValues", &self.partition_values)?; - } - if let Some(v) = self.range.as_ref() { - struct_ser.serialize_field("range", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.NamedStructField", len)?; + if let Some(v) = self.name.as_ref() { + struct_ser.serialize_field("name", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PartitionedFile { +impl<'de> serde::Deserialize<'de> for NamedStructField { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "path", - "size", - "last_modified_ns", - "lastModifiedNs", - "partition_values", - "partitionValues", - "range", + "name", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Path, - Size, - LastModifiedNs, - PartitionValues, - Range, + Name, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -17516,11 +11209,7 @@ impl<'de> serde::Deserialize<'de> for PartitionedFile { E: serde::de::Error, { match value { - "path" => Ok(GeneratedField::Path), - "size" => Ok(GeneratedField::Size), - "lastModifiedNs" | "last_modified_ns" => Ok(GeneratedField::LastModifiedNs), - "partitionValues" | "partition_values" => Ok(GeneratedField::PartitionValues), - "range" => Ok(GeneratedField::Range), + "name" => Ok(GeneratedField::Name), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -17530,72 +11219,36 @@ impl<'de> serde::Deserialize<'de> for PartitionedFile { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PartitionedFile; + type Value = NamedStructField; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PartitionedFile") + formatter.write_str("struct datafusion.NamedStructField") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut path__ = None; - let mut size__ = None; - let mut last_modified_ns__ = None; - let mut partition_values__ = None; - let mut range__ = None; + let mut name__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Path => { - if path__.is_some() { - return Err(serde::de::Error::duplicate_field("path")); - } - path__ = Some(map_.next_value()?); - } - GeneratedField::Size => { - if size__.is_some() { - return Err(serde::de::Error::duplicate_field("size")); - } - size__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::LastModifiedNs => { - if last_modified_ns__.is_some() { - return Err(serde::de::Error::duplicate_field("lastModifiedNs")); - } - last_modified_ns__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::PartitionValues => { - if partition_values__.is_some() { - return Err(serde::de::Error::duplicate_field("partitionValues")); - } - partition_values__ = Some(map_.next_value()?); - } - GeneratedField::Range => { - if range__.is_some() { - return Err(serde::de::Error::duplicate_field("range")); + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); } - range__ = map_.next_value()?; + name__ = map_.next_value()?; } } } - Ok(PartitionedFile { - path: path__.unwrap_or_default(), - size: size__.unwrap_or_default(), - last_modified_ns: last_modified_ns__.unwrap_or_default(), - partition_values: partition_values__.unwrap_or_default(), - range: range__, + Ok(NamedStructField { + name: name__, }) } } - deserializer.deserialize_struct("datafusion.PartitionedFile", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.NamedStructField", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalAggregateExprNode { +impl serde::Serialize for NegativeNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -17603,44 +11256,17 @@ impl serde::Serialize for PhysicalAggregateExprNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.expr.is_empty() { - len += 1; - } - if !self.ordering_req.is_empty() { - len += 1; - } - if self.distinct { - len += 1; - } - if self.aggregate_function.is_some() { + if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalAggregateExprNode", len)?; - if !self.expr.is_empty() { - struct_ser.serialize_field("expr", &self.expr)?; - } - if !self.ordering_req.is_empty() { - struct_ser.serialize_field("orderingReq", &self.ordering_req)?; - } - if self.distinct { - struct_ser.serialize_field("distinct", &self.distinct)?; - } - if let Some(v) = self.aggregate_function.as_ref() { - match v { - physical_aggregate_expr_node::AggregateFunction::AggrFunction(v) => { - let v = AggregateFunction::try_from(*v) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; - struct_ser.serialize_field("aggrFunction", &v)?; - } - physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(v) => { - struct_ser.serialize_field("userDefinedAggrFunction", v)?; - } - } + let mut struct_ser = serializer.serialize_struct("datafusion.NegativeNode", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { +impl<'de> serde::Deserialize<'de> for NegativeNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where @@ -17648,22 +11274,11 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { { const FIELDS: &[&str] = &[ "expr", - "ordering_req", - "orderingReq", - "distinct", - "aggr_function", - "aggrFunction", - "user_defined_aggr_function", - "userDefinedAggrFunction", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Expr, - OrderingReq, - Distinct, - AggrFunction, - UserDefinedAggrFunction, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -17686,10 +11301,6 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { { match value { "expr" => Ok(GeneratedField::Expr), - "orderingReq" | "ordering_req" => Ok(GeneratedField::OrderingReq), - "distinct" => Ok(GeneratedField::Distinct), - "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction), - "userDefinedAggrFunction" | "user_defined_aggr_function" => Ok(GeneratedField::UserDefinedAggrFunction), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -17699,66 +11310,36 @@ impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalAggregateExprNode; + type Value = NegativeNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalAggregateExprNode") + formatter.write_str("struct datafusion.NegativeNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; - let mut ordering_req__ = None; - let mut distinct__ = None; - let mut aggregate_function__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = Some(map_.next_value()?); - } - GeneratedField::OrderingReq => { - if ordering_req__.is_some() { - return Err(serde::de::Error::duplicate_field("orderingReq")); - } - ordering_req__ = Some(map_.next_value()?); - } - GeneratedField::Distinct => { - if distinct__.is_some() { - return Err(serde::de::Error::duplicate_field("distinct")); - } - distinct__ = Some(map_.next_value()?); - } - GeneratedField::AggrFunction => { - if aggregate_function__.is_some() { - return Err(serde::de::Error::duplicate_field("aggrFunction")); - } - aggregate_function__ = map_.next_value::<::std::option::Option>()?.map(|x| physical_aggregate_expr_node::AggregateFunction::AggrFunction(x as i32)); - } - GeneratedField::UserDefinedAggrFunction => { - if aggregate_function__.is_some() { - return Err(serde::de::Error::duplicate_field("userDefinedAggrFunction")); - } - aggregate_function__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction); + expr__ = map_.next_value()?; } } } - Ok(PhysicalAggregateExprNode { - expr: expr__.unwrap_or_default(), - ordering_req: ordering_req__.unwrap_or_default(), - distinct: distinct__.unwrap_or_default(), - aggregate_function: aggregate_function__, + Ok(NegativeNode { + expr: expr__, }) } } - deserializer.deserialize_struct("datafusion.PhysicalAggregateExprNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.NegativeNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalAliasNode { +impl serde::Serialize for NestedLoopJoinExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -17766,37 +11347,56 @@ impl serde::Serialize for PhysicalAliasNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if self.left.is_some() { + len += 1; + } + if self.right.is_some() { len += 1; } - if !self.alias.is_empty() { + if self.join_type != 0 { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalAliasNode", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + if self.filter.is_some() { + len += 1; } - if !self.alias.is_empty() { - struct_ser.serialize_field("alias", &self.alias)?; + let mut struct_ser = serializer.serialize_struct("datafusion.NestedLoopJoinExecNode", len)?; + if let Some(v) = self.left.as_ref() { + struct_ser.serialize_field("left", v)?; + } + if let Some(v) = self.right.as_ref() { + struct_ser.serialize_field("right", v)?; + } + if self.join_type != 0 { + let v = super::datafusion_common::JoinType::try_from(self.join_type) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_type)))?; + struct_ser.serialize_field("joinType", &v)?; + } + if let Some(v) = self.filter.as_ref() { + struct_ser.serialize_field("filter", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalAliasNode { +impl<'de> serde::Deserialize<'de> for NestedLoopJoinExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", - "alias", + "left", + "right", + "join_type", + "joinType", + "filter", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, - Alias, + Left, + Right, + JoinType, + Filter, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -17818,8 +11418,10 @@ impl<'de> serde::Deserialize<'de> for PhysicalAliasNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), - "alias" => Ok(GeneratedField::Alias), + "left" => Ok(GeneratedField::Left), + "right" => Ok(GeneratedField::Right), + "joinType" | "join_type" => Ok(GeneratedField::JoinType), + "filter" => Ok(GeneratedField::Filter), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -17829,44 +11431,60 @@ impl<'de> serde::Deserialize<'de> for PhysicalAliasNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalAliasNode; + type Value = NestedLoopJoinExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalAliasNode") + formatter.write_str("struct datafusion.NestedLoopJoinExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; - let mut alias__ = None; + let mut left__ = None; + let mut right__ = None; + let mut join_type__ = None; + let mut filter__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Left => { + if left__.is_some() { + return Err(serde::de::Error::duplicate_field("left")); } - expr__ = map_.next_value()?; + left__ = map_.next_value()?; } - GeneratedField::Alias => { - if alias__.is_some() { - return Err(serde::de::Error::duplicate_field("alias")); + GeneratedField::Right => { + if right__.is_some() { + return Err(serde::de::Error::duplicate_field("right")); } - alias__ = Some(map_.next_value()?); + right__ = map_.next_value()?; + } + GeneratedField::JoinType => { + if join_type__.is_some() { + return Err(serde::de::Error::duplicate_field("joinType")); + } + join_type__ = Some(map_.next_value::()? as i32); + } + GeneratedField::Filter => { + if filter__.is_some() { + return Err(serde::de::Error::duplicate_field("filter")); + } + filter__ = map_.next_value()?; } } } - Ok(PhysicalAliasNode { - expr: expr__, - alias: alias__.unwrap_or_default(), + Ok(NestedLoopJoinExecNode { + left: left__, + right: right__, + join_type: join_type__.unwrap_or_default(), + filter: filter__, }) } } - deserializer.deserialize_struct("datafusion.PhysicalAliasNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.NestedLoopJoinExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalBinaryExprNode { +impl serde::Serialize for Not { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -17874,45 +11492,29 @@ impl serde::Serialize for PhysicalBinaryExprNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.l.is_some() { - len += 1; - } - if self.r.is_some() { - len += 1; - } - if !self.op.is_empty() { + if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalBinaryExprNode", len)?; - if let Some(v) = self.l.as_ref() { - struct_ser.serialize_field("l", v)?; - } - if let Some(v) = self.r.as_ref() { - struct_ser.serialize_field("r", v)?; - } - if !self.op.is_empty() { - struct_ser.serialize_field("op", &self.op)?; + let mut struct_ser = serializer.serialize_struct("datafusion.Not", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalBinaryExprNode { +impl<'de> serde::Deserialize<'de> for Not { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "l", - "r", - "op", + "expr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - L, - R, - Op, + Expr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -17934,9 +11536,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalBinaryExprNode { E: serde::de::Error, { match value { - "l" => Ok(GeneratedField::L), - "r" => Ok(GeneratedField::R), - "op" => Ok(GeneratedField::Op), + "expr" => Ok(GeneratedField::Expr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -17946,52 +11546,36 @@ impl<'de> serde::Deserialize<'de> for PhysicalBinaryExprNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalBinaryExprNode; + type Value = Not; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalBinaryExprNode") + formatter.write_str("struct datafusion.Not") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut l__ = None; - let mut r__ = None; - let mut op__ = None; + let mut expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::L => { - if l__.is_some() { - return Err(serde::de::Error::duplicate_field("l")); - } - l__ = map_.next_value()?; - } - GeneratedField::R => { - if r__.is_some() { - return Err(serde::de::Error::duplicate_field("r")); - } - r__ = map_.next_value()?; - } - GeneratedField::Op => { - if op__.is_some() { - return Err(serde::de::Error::duplicate_field("op")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - op__ = Some(map_.next_value()?); + expr__ = map_.next_value()?; } } } - Ok(PhysicalBinaryExprNode { - l: l__, - r: r__, - op: op__.unwrap_or_default(), + Ok(Not { + expr: expr__, }) } } - deserializer.deserialize_struct("datafusion.PhysicalBinaryExprNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.Not", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalCaseNode { +impl serde::Serialize for OptimizedLogicalPlanType { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -17999,47 +11583,30 @@ impl serde::Serialize for PhysicalCaseNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { - len += 1; - } - if !self.when_then_expr.is_empty() { - len += 1; - } - if self.else_expr.is_some() { + if !self.optimizer_name.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalCaseNode", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; - } - if !self.when_then_expr.is_empty() { - struct_ser.serialize_field("whenThenExpr", &self.when_then_expr)?; - } - if let Some(v) = self.else_expr.as_ref() { - struct_ser.serialize_field("elseExpr", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.OptimizedLogicalPlanType", len)?; + if !self.optimizer_name.is_empty() { + struct_ser.serialize_field("optimizerName", &self.optimizer_name)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalCaseNode { +impl<'de> serde::Deserialize<'de> for OptimizedLogicalPlanType { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", - "when_then_expr", - "whenThenExpr", - "else_expr", - "elseExpr", + "optimizer_name", + "optimizerName", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, - WhenThenExpr, - ElseExpr, + OptimizerName, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -18061,9 +11628,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalCaseNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), - "whenThenExpr" | "when_then_expr" => Ok(GeneratedField::WhenThenExpr), - "elseExpr" | "else_expr" => Ok(GeneratedField::ElseExpr), + "optimizerName" | "optimizer_name" => Ok(GeneratedField::OptimizerName), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -18073,52 +11638,36 @@ impl<'de> serde::Deserialize<'de> for PhysicalCaseNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalCaseNode; + type Value = OptimizedLogicalPlanType; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalCaseNode") + formatter.write_str("struct datafusion.OptimizedLogicalPlanType") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; - let mut when_then_expr__ = None; - let mut else_expr__ = None; + let mut optimizer_name__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); - } - expr__ = map_.next_value()?; - } - GeneratedField::WhenThenExpr => { - if when_then_expr__.is_some() { - return Err(serde::de::Error::duplicate_field("whenThenExpr")); - } - when_then_expr__ = Some(map_.next_value()?); - } - GeneratedField::ElseExpr => { - if else_expr__.is_some() { - return Err(serde::de::Error::duplicate_field("elseExpr")); + GeneratedField::OptimizerName => { + if optimizer_name__.is_some() { + return Err(serde::de::Error::duplicate_field("optimizerName")); } - else_expr__ = map_.next_value()?; + optimizer_name__ = Some(map_.next_value()?); } } } - Ok(PhysicalCaseNode { - expr: expr__, - when_then_expr: when_then_expr__.unwrap_or_default(), - else_expr: else_expr__, + Ok(OptimizedLogicalPlanType { + optimizer_name: optimizer_name__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.PhysicalCaseNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.OptimizedLogicalPlanType", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalCastNode { +impl serde::Serialize for OptimizedPhysicalPlanType { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -18126,38 +11675,30 @@ impl serde::Serialize for PhysicalCastNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { - len += 1; - } - if self.arrow_type.is_some() { + if !self.optimizer_name.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalCastNode", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; - } - if let Some(v) = self.arrow_type.as_ref() { - struct_ser.serialize_field("arrowType", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.OptimizedPhysicalPlanType", len)?; + if !self.optimizer_name.is_empty() { + struct_ser.serialize_field("optimizerName", &self.optimizer_name)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalCastNode { +impl<'de> serde::Deserialize<'de> for OptimizedPhysicalPlanType { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", - "arrow_type", - "arrowType", + "optimizer_name", + "optimizerName", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, - ArrowType, + OptimizerName, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -18179,8 +11720,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalCastNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), - "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), + "optimizerName" | "optimizer_name" => Ok(GeneratedField::OptimizerName), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -18190,44 +11730,36 @@ impl<'de> serde::Deserialize<'de> for PhysicalCastNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalCastNode; + type Value = OptimizedPhysicalPlanType; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalCastNode") + formatter.write_str("struct datafusion.OptimizedPhysicalPlanType") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; - let mut arrow_type__ = None; + let mut optimizer_name__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); - } - expr__ = map_.next_value()?; - } - GeneratedField::ArrowType => { - if arrow_type__.is_some() { - return Err(serde::de::Error::duplicate_field("arrowType")); + GeneratedField::OptimizerName => { + if optimizer_name__.is_some() { + return Err(serde::de::Error::duplicate_field("optimizerName")); } - arrow_type__ = map_.next_value()?; + optimizer_name__ = Some(map_.next_value()?); } } } - Ok(PhysicalCastNode { - expr: expr__, - arrow_type: arrow_type__, + Ok(OptimizedPhysicalPlanType { + optimizer_name: optimizer_name__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.PhysicalCastNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.OptimizedPhysicalPlanType", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalColumn { +impl serde::Serialize for ParquetScanExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -18235,37 +11767,38 @@ impl serde::Serialize for PhysicalColumn { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.name.is_empty() { + if self.base_conf.is_some() { len += 1; } - if self.index != 0 { + if self.predicate.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalColumn", len)?; - if !self.name.is_empty() { - struct_ser.serialize_field("name", &self.name)?; + let mut struct_ser = serializer.serialize_struct("datafusion.ParquetScanExecNode", len)?; + if let Some(v) = self.base_conf.as_ref() { + struct_ser.serialize_field("baseConf", v)?; } - if self.index != 0 { - struct_ser.serialize_field("index", &self.index)?; + if let Some(v) = self.predicate.as_ref() { + struct_ser.serialize_field("predicate", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalColumn { +impl<'de> serde::Deserialize<'de> for ParquetScanExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "name", - "index", + "base_conf", + "baseConf", + "predicate", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Name, - Index, + BaseConf, + Predicate, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -18287,8 +11820,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalColumn { E: serde::de::Error, { match value { - "name" => Ok(GeneratedField::Name), - "index" => Ok(GeneratedField::Index), + "baseConf" | "base_conf" => Ok(GeneratedField::BaseConf), + "predicate" => Ok(GeneratedField::Predicate), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -18298,46 +11831,44 @@ impl<'de> serde::Deserialize<'de> for PhysicalColumn { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalColumn; + type Value = ParquetScanExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalColumn") + formatter.write_str("struct datafusion.ParquetScanExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut name__ = None; - let mut index__ = None; + let mut base_conf__ = None; + let mut predicate__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Name => { - if name__.is_some() { - return Err(serde::de::Error::duplicate_field("name")); + GeneratedField::BaseConf => { + if base_conf__.is_some() { + return Err(serde::de::Error::duplicate_field("baseConf")); } - name__ = Some(map_.next_value()?); + base_conf__ = map_.next_value()?; } - GeneratedField::Index => { - if index__.is_some() { - return Err(serde::de::Error::duplicate_field("index")); + GeneratedField::Predicate => { + if predicate__.is_some() { + return Err(serde::de::Error::duplicate_field("predicate")); } - index__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + predicate__ = map_.next_value()?; } } } - Ok(PhysicalColumn { - name: name__.unwrap_or_default(), - index: index__.unwrap_or_default(), + Ok(ParquetScanExecNode { + base_conf: base_conf__, + predicate: predicate__, }) } } - deserializer.deserialize_struct("datafusion.PhysicalColumn", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ParquetScanExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalDateTimeIntervalExprNode { +impl serde::Serialize for ParquetSink { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -18345,45 +11876,38 @@ impl serde::Serialize for PhysicalDateTimeIntervalExprNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.l.is_some() { - len += 1; - } - if self.r.is_some() { + if self.config.is_some() { len += 1; } - if !self.op.is_empty() { + if self.parquet_options.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalDateTimeIntervalExprNode", len)?; - if let Some(v) = self.l.as_ref() { - struct_ser.serialize_field("l", v)?; - } - if let Some(v) = self.r.as_ref() { - struct_ser.serialize_field("r", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.ParquetSink", len)?; + if let Some(v) = self.config.as_ref() { + struct_ser.serialize_field("config", v)?; } - if !self.op.is_empty() { - struct_ser.serialize_field("op", &self.op)?; + if let Some(v) = self.parquet_options.as_ref() { + struct_ser.serialize_field("parquetOptions", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalDateTimeIntervalExprNode { +impl<'de> serde::Deserialize<'de> for ParquetSink { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "l", - "r", - "op", + "config", + "parquet_options", + "parquetOptions", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - L, - R, - Op, + Config, + ParquetOptions, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -18405,9 +11929,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalDateTimeIntervalExprNode { E: serde::de::Error, { match value { - "l" => Ok(GeneratedField::L), - "r" => Ok(GeneratedField::R), - "op" => Ok(GeneratedField::Op), + "config" => Ok(GeneratedField::Config), + "parquetOptions" | "parquet_options" => Ok(GeneratedField::ParquetOptions), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -18417,52 +11940,44 @@ impl<'de> serde::Deserialize<'de> for PhysicalDateTimeIntervalExprNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalDateTimeIntervalExprNode; + type Value = ParquetSink; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalDateTimeIntervalExprNode") + formatter.write_str("struct datafusion.ParquetSink") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut l__ = None; - let mut r__ = None; - let mut op__ = None; + let mut config__ = None; + let mut parquet_options__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::L => { - if l__.is_some() { - return Err(serde::de::Error::duplicate_field("l")); - } - l__ = map_.next_value()?; - } - GeneratedField::R => { - if r__.is_some() { - return Err(serde::de::Error::duplicate_field("r")); + GeneratedField::Config => { + if config__.is_some() { + return Err(serde::de::Error::duplicate_field("config")); } - r__ = map_.next_value()?; + config__ = map_.next_value()?; } - GeneratedField::Op => { - if op__.is_some() { - return Err(serde::de::Error::duplicate_field("op")); + GeneratedField::ParquetOptions => { + if parquet_options__.is_some() { + return Err(serde::de::Error::duplicate_field("parquetOptions")); } - op__ = Some(map_.next_value()?); + parquet_options__ = map_.next_value()?; } } } - Ok(PhysicalDateTimeIntervalExprNode { - l: l__, - r: r__, - op: op__.unwrap_or_default(), + Ok(ParquetSink { + config: config__, + parquet_options: parquet_options__, }) } } - deserializer.deserialize_struct("datafusion.PhysicalDateTimeIntervalExprNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ParquetSink", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalExprNode { +impl serde::Serialize for ParquetSinkExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -18470,119 +11985,55 @@ impl serde::Serialize for PhysicalExprNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr_type.is_some() { + if self.input.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalExprNode", len)?; - if let Some(v) = self.expr_type.as_ref() { - match v { - physical_expr_node::ExprType::Column(v) => { - struct_ser.serialize_field("column", v)?; - } - physical_expr_node::ExprType::Literal(v) => { - struct_ser.serialize_field("literal", v)?; - } - physical_expr_node::ExprType::BinaryExpr(v) => { - struct_ser.serialize_field("binaryExpr", v)?; - } - physical_expr_node::ExprType::AggregateExpr(v) => { - struct_ser.serialize_field("aggregateExpr", v)?; - } - physical_expr_node::ExprType::IsNullExpr(v) => { - struct_ser.serialize_field("isNullExpr", v)?; - } - physical_expr_node::ExprType::IsNotNullExpr(v) => { - struct_ser.serialize_field("isNotNullExpr", v)?; - } - physical_expr_node::ExprType::NotExpr(v) => { - struct_ser.serialize_field("notExpr", v)?; - } - physical_expr_node::ExprType::Case(v) => { - struct_ser.serialize_field("case", v)?; - } - physical_expr_node::ExprType::Cast(v) => { - struct_ser.serialize_field("cast", v)?; - } - physical_expr_node::ExprType::Sort(v) => { - struct_ser.serialize_field("sort", v)?; - } - physical_expr_node::ExprType::Negative(v) => { - struct_ser.serialize_field("negative", v)?; - } - physical_expr_node::ExprType::InList(v) => { - struct_ser.serialize_field("inList", v)?; - } - physical_expr_node::ExprType::TryCast(v) => { - struct_ser.serialize_field("tryCast", v)?; - } - physical_expr_node::ExprType::WindowExpr(v) => { - struct_ser.serialize_field("windowExpr", v)?; - } - physical_expr_node::ExprType::ScalarUdf(v) => { - struct_ser.serialize_field("scalarUdf", v)?; - } - physical_expr_node::ExprType::LikeExpr(v) => { - struct_ser.serialize_field("likeExpr", v)?; - } - } + if self.sink.is_some() { + len += 1; + } + if self.sink_schema.is_some() { + len += 1; + } + if self.sort_order.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.ParquetSinkExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if let Some(v) = self.sink.as_ref() { + struct_ser.serialize_field("sink", v)?; + } + if let Some(v) = self.sink_schema.as_ref() { + struct_ser.serialize_field("sinkSchema", v)?; + } + if let Some(v) = self.sort_order.as_ref() { + struct_ser.serialize_field("sortOrder", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalExprNode { +impl<'de> serde::Deserialize<'de> for ParquetSinkExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "column", - "literal", - "binary_expr", - "binaryExpr", - "aggregate_expr", - "aggregateExpr", - "is_null_expr", - "isNullExpr", - "is_not_null_expr", - "isNotNullExpr", - "not_expr", - "notExpr", - "case_", - "case", - "cast", - "sort", - "negative", - "in_list", - "inList", - "try_cast", - "tryCast", - "window_expr", - "windowExpr", - "scalar_udf", - "scalarUdf", - "like_expr", - "likeExpr", + "input", + "sink", + "sink_schema", + "sinkSchema", + "sort_order", + "sortOrder", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Column, - Literal, - BinaryExpr, - AggregateExpr, - IsNullExpr, - IsNotNullExpr, - NotExpr, - Case, - Cast, - Sort, - Negative, - InList, - TryCast, - WindowExpr, - ScalarUdf, - LikeExpr, + Input, + Sink, + SinkSchema, + SortOrder, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -18604,22 +12055,10 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { E: serde::de::Error, { match value { - "column" => Ok(GeneratedField::Column), - "literal" => Ok(GeneratedField::Literal), - "binaryExpr" | "binary_expr" => Ok(GeneratedField::BinaryExpr), - "aggregateExpr" | "aggregate_expr" => Ok(GeneratedField::AggregateExpr), - "isNullExpr" | "is_null_expr" => Ok(GeneratedField::IsNullExpr), - "isNotNullExpr" | "is_not_null_expr" => Ok(GeneratedField::IsNotNullExpr), - "notExpr" | "not_expr" => Ok(GeneratedField::NotExpr), - "case" | "case_" => Ok(GeneratedField::Case), - "cast" => Ok(GeneratedField::Cast), - "sort" => Ok(GeneratedField::Sort), - "negative" => Ok(GeneratedField::Negative), - "inList" | "in_list" => Ok(GeneratedField::InList), - "tryCast" | "try_cast" => Ok(GeneratedField::TryCast), - "windowExpr" | "window_expr" => Ok(GeneratedField::WindowExpr), - "scalarUdf" | "scalar_udf" => Ok(GeneratedField::ScalarUdf), - "likeExpr" | "like_expr" => Ok(GeneratedField::LikeExpr), + "input" => Ok(GeneratedField::Input), + "sink" => Ok(GeneratedField::Sink), + "sinkSchema" | "sink_schema" => Ok(GeneratedField::SinkSchema), + "sortOrder" | "sort_order" => Ok(GeneratedField::SortOrder), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -18629,142 +12068,168 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalExprNode; + type Value = ParquetSinkExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalExprNode") + formatter.write_str("struct datafusion.ParquetSinkExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr_type__ = None; + let mut input__ = None; + let mut sink__ = None; + let mut sink_schema__ = None; + let mut sort_order__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Column => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("column")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Column) -; - } - GeneratedField::Literal => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("literal")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Literal) -; - } - GeneratedField::BinaryExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("binaryExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::BinaryExpr) -; - } - GeneratedField::AggregateExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("aggregateExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::AggregateExpr) -; - } - GeneratedField::IsNullExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("isNullExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::IsNullExpr) -; - } - GeneratedField::IsNotNullExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("isNotNullExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::IsNotNullExpr) -; - } - GeneratedField::NotExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("notExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::NotExpr) -; - } - GeneratedField::Case => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("case")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Case) -; - } - GeneratedField::Cast => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("cast")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Cast) -; - } - GeneratedField::Sort => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("sort")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Sort) -; + input__ = map_.next_value()?; } - GeneratedField::Negative => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("negative")); + GeneratedField::Sink => { + if sink__.is_some() { + return Err(serde::de::Error::duplicate_field("sink")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Negative) -; + sink__ = map_.next_value()?; } - GeneratedField::InList => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("inList")); + GeneratedField::SinkSchema => { + if sink_schema__.is_some() { + return Err(serde::de::Error::duplicate_field("sinkSchema")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::InList) -; + sink_schema__ = map_.next_value()?; } - GeneratedField::TryCast => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("tryCast")); + GeneratedField::SortOrder => { + if sort_order__.is_some() { + return Err(serde::de::Error::duplicate_field("sortOrder")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::TryCast) -; + sort_order__ = map_.next_value()?; } - GeneratedField::WindowExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("windowExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::WindowExpr) -; + } + } + Ok(ParquetSinkExecNode { + input: input__, + sink: sink__, + sink_schema: sink_schema__, + sort_order: sort_order__, + }) + } + } + deserializer.deserialize_struct("datafusion.ParquetSinkExecNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for PartialTableReference { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.schema.is_empty() { + len += 1; + } + if !self.table.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PartialTableReference", len)?; + if !self.schema.is_empty() { + struct_ser.serialize_field("schema", &self.schema)?; + } + if !self.table.is_empty() { + struct_ser.serialize_field("table", &self.table)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for PartialTableReference { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "schema", + "table", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Schema, + Table, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "schema" => Ok(GeneratedField::Schema), + "table" => Ok(GeneratedField::Table), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } - GeneratedField::ScalarUdf => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("scalarUdf")); + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PartialTableReference; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.PartialTableReference") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut schema__ = None; + let mut table__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::ScalarUdf) -; + schema__ = Some(map_.next_value()?); } - GeneratedField::LikeExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("likeExpr")); + GeneratedField::Table => { + if table__.is_some() { + return Err(serde::de::Error::duplicate_field("table")); } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::LikeExpr) -; + table__ = Some(map_.next_value()?); } } } - Ok(PhysicalExprNode { - expr_type: expr_type__, + Ok(PartialTableReference { + schema: schema__.unwrap_or_default(), + table: table__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.PhysicalExprNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PartialTableReference", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalExtensionNode { +impl serde::Serialize for PartiallySortedInputOrderMode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -18772,38 +12237,29 @@ impl serde::Serialize for PhysicalExtensionNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.node.is_empty() { - len += 1; - } - if !self.inputs.is_empty() { + if !self.columns.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalExtensionNode", len)?; - if !self.node.is_empty() { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("node", pbjson::private::base64::encode(&self.node).as_str())?; - } - if !self.inputs.is_empty() { - struct_ser.serialize_field("inputs", &self.inputs)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PartiallySortedInputOrderMode", len)?; + if !self.columns.is_empty() { + struct_ser.serialize_field("columns", &self.columns.iter().map(ToString::to_string).collect::>())?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalExtensionNode { +impl<'de> serde::Deserialize<'de> for PartiallySortedInputOrderMode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "node", - "inputs", + "columns", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Node, - Inputs, + Columns, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -18825,8 +12281,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExtensionNode { E: serde::de::Error, { match value { - "node" => Ok(GeneratedField::Node), - "inputs" => Ok(GeneratedField::Inputs), + "columns" => Ok(GeneratedField::Columns), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -18836,46 +12291,39 @@ impl<'de> serde::Deserialize<'de> for PhysicalExtensionNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalExtensionNode; + type Value = PartiallySortedInputOrderMode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalExtensionNode") + formatter.write_str("struct datafusion.PartiallySortedInputOrderMode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut node__ = None; - let mut inputs__ = None; + let mut columns__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Node => { - if node__.is_some() { - return Err(serde::de::Error::duplicate_field("node")); + GeneratedField::Columns => { + if columns__.is_some() { + return Err(serde::de::Error::duplicate_field("columns")); } - node__ = - Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + columns__ = + Some(map_.next_value::>>()? + .into_iter().map(|x| x.0).collect()) ; } - GeneratedField::Inputs => { - if inputs__.is_some() { - return Err(serde::de::Error::duplicate_field("inputs")); - } - inputs__ = Some(map_.next_value()?); - } } } - Ok(PhysicalExtensionNode { - node: node__.unwrap_or_default(), - inputs: inputs__.unwrap_or_default(), + Ok(PartiallySortedInputOrderMode { + columns: columns__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.PhysicalExtensionNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PartiallySortedInputOrderMode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalHashRepartition { +impl serde::Serialize for PartitionColumn { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -18883,40 +12331,38 @@ impl serde::Serialize for PhysicalHashRepartition { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.hash_expr.is_empty() { + if !self.name.is_empty() { len += 1; } - if self.partition_count != 0 { + if self.arrow_type.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalHashRepartition", len)?; - if !self.hash_expr.is_empty() { - struct_ser.serialize_field("hashExpr", &self.hash_expr)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PartitionColumn", len)?; + if !self.name.is_empty() { + struct_ser.serialize_field("name", &self.name)?; } - if self.partition_count != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("partitionCount", ToString::to_string(&self.partition_count).as_str())?; + if let Some(v) = self.arrow_type.as_ref() { + struct_ser.serialize_field("arrowType", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalHashRepartition { +impl<'de> serde::Deserialize<'de> for PartitionColumn { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "hash_expr", - "hashExpr", - "partition_count", - "partitionCount", + "name", + "arrow_type", + "arrowType", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - HashExpr, - PartitionCount, + Name, + ArrowType, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -18938,8 +12384,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalHashRepartition { E: serde::de::Error, { match value { - "hashExpr" | "hash_expr" => Ok(GeneratedField::HashExpr), - "partitionCount" | "partition_count" => Ok(GeneratedField::PartitionCount), + "name" => Ok(GeneratedField::Name), + "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -18949,46 +12395,118 @@ impl<'de> serde::Deserialize<'de> for PhysicalHashRepartition { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalHashRepartition; + type Value = PartitionColumn; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalHashRepartition") + formatter.write_str("struct datafusion.PartitionColumn") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut hash_expr__ = None; - let mut partition_count__ = None; + let mut name__ = None; + let mut arrow_type__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::HashExpr => { - if hash_expr__.is_some() { - return Err(serde::de::Error::duplicate_field("hashExpr")); + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); } - hash_expr__ = Some(map_.next_value()?); + name__ = Some(map_.next_value()?); } - GeneratedField::PartitionCount => { - if partition_count__.is_some() { - return Err(serde::de::Error::duplicate_field("partitionCount")); + GeneratedField::ArrowType => { + if arrow_type__.is_some() { + return Err(serde::de::Error::duplicate_field("arrowType")); } - partition_count__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + arrow_type__ = map_.next_value()?; } } } - Ok(PhysicalHashRepartition { - hash_expr: hash_expr__.unwrap_or_default(), - partition_count: partition_count__.unwrap_or_default(), - }) + Ok(PartitionColumn { + name: name__.unwrap_or_default(), + arrow_type: arrow_type__, + }) + } + } + deserializer.deserialize_struct("datafusion.PartitionColumn", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for PartitionMode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::CollectLeft => "COLLECT_LEFT", + Self::Partitioned => "PARTITIONED", + Self::Auto => "AUTO", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for PartitionMode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "COLLECT_LEFT", + "PARTITIONED", + "AUTO", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PartitionMode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "COLLECT_LEFT" => Ok(PartitionMode::CollectLeft), + "PARTITIONED" => Ok(PartitionMode::Partitioned), + "AUTO" => Ok(PartitionMode::Auto), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } } } - deserializer.deserialize_struct("datafusion.PhysicalHashRepartition", FIELDS, GeneratedVisitor) + deserializer.deserialize_any(GeneratedVisitor) } } -impl serde::Serialize for PhysicalInListNode { +impl serde::Serialize for PartitionStats { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -18996,45 +12514,63 @@ impl serde::Serialize for PhysicalInListNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if self.num_rows != 0 { len += 1; } - if !self.list.is_empty() { + if self.num_batches != 0 { len += 1; } - if self.negated { + if self.num_bytes != 0 { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalInListNode", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + if !self.column_stats.is_empty() { + len += 1; } - if !self.list.is_empty() { - struct_ser.serialize_field("list", &self.list)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PartitionStats", len)?; + if self.num_rows != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("numRows", ToString::to_string(&self.num_rows).as_str())?; } - if self.negated { - struct_ser.serialize_field("negated", &self.negated)?; + if self.num_batches != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("numBatches", ToString::to_string(&self.num_batches).as_str())?; + } + if self.num_bytes != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("numBytes", ToString::to_string(&self.num_bytes).as_str())?; + } + if !self.column_stats.is_empty() { + struct_ser.serialize_field("columnStats", &self.column_stats)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalInListNode { +impl<'de> serde::Deserialize<'de> for PartitionStats { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", - "list", - "negated", + "num_rows", + "numRows", + "num_batches", + "numBatches", + "num_bytes", + "numBytes", + "column_stats", + "columnStats", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, - List, - Negated, + NumRows, + NumBatches, + NumBytes, + ColumnStats, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -19056,9 +12592,10 @@ impl<'de> serde::Deserialize<'de> for PhysicalInListNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), - "list" => Ok(GeneratedField::List), - "negated" => Ok(GeneratedField::Negated), + "numRows" | "num_rows" => Ok(GeneratedField::NumRows), + "numBatches" | "num_batches" => Ok(GeneratedField::NumBatches), + "numBytes" | "num_bytes" => Ok(GeneratedField::NumBytes), + "columnStats" | "column_stats" => Ok(GeneratedField::ColumnStats), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -19068,52 +12605,66 @@ impl<'de> serde::Deserialize<'de> for PhysicalInListNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalInListNode; + type Value = PartitionStats; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalInListNode") + formatter.write_str("struct datafusion.PartitionStats") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; - let mut list__ = None; - let mut negated__ = None; + let mut num_rows__ = None; + let mut num_batches__ = None; + let mut num_bytes__ = None; + let mut column_stats__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::NumRows => { + if num_rows__.is_some() { + return Err(serde::de::Error::duplicate_field("numRows")); } - expr__ = map_.next_value()?; + num_rows__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } - GeneratedField::List => { - if list__.is_some() { - return Err(serde::de::Error::duplicate_field("list")); + GeneratedField::NumBatches => { + if num_batches__.is_some() { + return Err(serde::de::Error::duplicate_field("numBatches")); } - list__ = Some(map_.next_value()?); + num_batches__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } - GeneratedField::Negated => { - if negated__.is_some() { - return Err(serde::de::Error::duplicate_field("negated")); + GeneratedField::NumBytes => { + if num_bytes__.is_some() { + return Err(serde::de::Error::duplicate_field("numBytes")); } - negated__ = Some(map_.next_value()?); + num_bytes__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::ColumnStats => { + if column_stats__.is_some() { + return Err(serde::de::Error::duplicate_field("columnStats")); + } + column_stats__ = Some(map_.next_value()?); } } } - Ok(PhysicalInListNode { - expr: expr__, - list: list__.unwrap_or_default(), - negated: negated__.unwrap_or_default(), + Ok(PartitionStats { + num_rows: num_rows__.unwrap_or_default(), + num_batches: num_batches__.unwrap_or_default(), + num_bytes: num_bytes__.unwrap_or_default(), + column_stats: column_stats__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.PhysicalInListNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PartitionStats", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalIsNotNull { +impl serde::Serialize for PartitionedFile { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -19121,29 +12672,75 @@ impl serde::Serialize for PhysicalIsNotNull { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if !self.path.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalIsNotNull", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + if self.size != 0 { + len += 1; + } + if self.last_modified_ns != 0 { + len += 1; + } + if !self.partition_values.is_empty() { + len += 1; + } + if self.range.is_some() { + len += 1; + } + if self.statistics.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PartitionedFile", len)?; + if !self.path.is_empty() { + struct_ser.serialize_field("path", &self.path)?; + } + if self.size != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("size", ToString::to_string(&self.size).as_str())?; + } + if self.last_modified_ns != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("lastModifiedNs", ToString::to_string(&self.last_modified_ns).as_str())?; + } + if !self.partition_values.is_empty() { + struct_ser.serialize_field("partitionValues", &self.partition_values)?; + } + if let Some(v) = self.range.as_ref() { + struct_ser.serialize_field("range", v)?; + } + if let Some(v) = self.statistics.as_ref() { + struct_ser.serialize_field("statistics", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalIsNotNull { +impl<'de> serde::Deserialize<'de> for PartitionedFile { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "path", + "size", + "last_modified_ns", + "lastModifiedNs", + "partition_values", + "partitionValues", + "range", + "statistics", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, + Path, + Size, + LastModifiedNs, + PartitionValues, + Range, + Statistics, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -19165,7 +12762,12 @@ impl<'de> serde::Deserialize<'de> for PhysicalIsNotNull { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), + "path" => Ok(GeneratedField::Path), + "size" => Ok(GeneratedField::Size), + "lastModifiedNs" | "last_modified_ns" => Ok(GeneratedField::LastModifiedNs), + "partitionValues" | "partition_values" => Ok(GeneratedField::PartitionValues), + "range" => Ok(GeneratedField::Range), + "statistics" => Ok(GeneratedField::Statistics), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -19175,36 +12777,80 @@ impl<'de> serde::Deserialize<'de> for PhysicalIsNotNull { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalIsNotNull; + type Value = PartitionedFile; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalIsNotNull") + formatter.write_str("struct datafusion.PartitionedFile") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; + let mut path__ = None; + let mut size__ = None; + let mut last_modified_ns__ = None; + let mut partition_values__ = None; + let mut range__ = None; + let mut statistics__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Path => { + if path__.is_some() { + return Err(serde::de::Error::duplicate_field("path")); + } + path__ = Some(map_.next_value()?); + } + GeneratedField::Size => { + if size__.is_some() { + return Err(serde::de::Error::duplicate_field("size")); + } + size__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::LastModifiedNs => { + if last_modified_ns__.is_some() { + return Err(serde::de::Error::duplicate_field("lastModifiedNs")); + } + last_modified_ns__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::PartitionValues => { + if partition_values__.is_some() { + return Err(serde::de::Error::duplicate_field("partitionValues")); + } + partition_values__ = Some(map_.next_value()?); + } + GeneratedField::Range => { + if range__.is_some() { + return Err(serde::de::Error::duplicate_field("range")); } - expr__ = map_.next_value()?; + range__ = map_.next_value()?; + } + GeneratedField::Statistics => { + if statistics__.is_some() { + return Err(serde::de::Error::duplicate_field("statistics")); + } + statistics__ = map_.next_value()?; } } } - Ok(PhysicalIsNotNull { - expr: expr__, + Ok(PartitionedFile { + path: path__.unwrap_or_default(), + size: size__.unwrap_or_default(), + last_modified_ns: last_modified_ns__.unwrap_or_default(), + partition_values: partition_values__.unwrap_or_default(), + range: range__, + statistics: statistics__, }) } } - deserializer.deserialize_struct("datafusion.PhysicalIsNotNull", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PartitionedFile", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalIsNull { +impl serde::Serialize for Partitioning { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -19212,29 +12858,48 @@ impl serde::Serialize for PhysicalIsNull { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if self.partition_method.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalIsNull", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.Partitioning", len)?; + if let Some(v) = self.partition_method.as_ref() { + match v { + partitioning::PartitionMethod::RoundRobin(v) => { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("roundRobin", ToString::to_string(&v).as_str())?; + } + partitioning::PartitionMethod::Hash(v) => { + struct_ser.serialize_field("hash", v)?; + } + partitioning::PartitionMethod::Unknown(v) => { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("unknown", ToString::to_string(&v).as_str())?; + } + } } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalIsNull { +impl<'de> serde::Deserialize<'de> for Partitioning { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "round_robin", + "roundRobin", + "hash", + "unknown", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, + RoundRobin, + Hash, + Unknown, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -19256,7 +12921,9 @@ impl<'de> serde::Deserialize<'de> for PhysicalIsNull { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), + "roundRobin" | "round_robin" => Ok(GeneratedField::RoundRobin), + "hash" => Ok(GeneratedField::Hash), + "unknown" => Ok(GeneratedField::Unknown), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -19266,36 +12933,49 @@ impl<'de> serde::Deserialize<'de> for PhysicalIsNull { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalIsNull; + type Value = Partitioning; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalIsNull") + formatter.write_str("struct datafusion.Partitioning") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; + let mut partition_method__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::RoundRobin => { + if partition_method__.is_some() { + return Err(serde::de::Error::duplicate_field("roundRobin")); } - expr__ = map_.next_value()?; + partition_method__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| partitioning::PartitionMethod::RoundRobin(x.0)); + } + GeneratedField::Hash => { + if partition_method__.is_some() { + return Err(serde::de::Error::duplicate_field("hash")); + } + partition_method__ = map_.next_value::<::std::option::Option<_>>()?.map(partitioning::PartitionMethod::Hash) +; + } + GeneratedField::Unknown => { + if partition_method__.is_some() { + return Err(serde::de::Error::duplicate_field("unknown")); + } + partition_method__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| partitioning::PartitionMethod::Unknown(x.0)); } } } - Ok(PhysicalIsNull { - expr: expr__, + Ok(Partitioning { + partition_method: partition_method__, }) } } - deserializer.deserialize_struct("datafusion.PhysicalIsNull", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.Partitioning", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalLikeExprNode { +impl serde::Serialize for PhysicalAggregateExprNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -19303,54 +12983,79 @@ impl serde::Serialize for PhysicalLikeExprNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.negated { + if !self.expr.is_empty() { len += 1; } - if self.case_insensitive { + if !self.ordering_req.is_empty() { len += 1; } - if self.expr.is_some() { + if self.distinct { len += 1; } - if self.pattern.is_some() { + if self.ignore_nulls { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalLikeExprNode", len)?; - if self.negated { - struct_ser.serialize_field("negated", &self.negated)?; + if self.fun_definition.is_some() { + len += 1; } - if self.case_insensitive { - struct_ser.serialize_field("caseInsensitive", &self.case_insensitive)?; + if self.aggregate_function.is_some() { + len += 1; } - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalAggregateExprNode", len)?; + if !self.expr.is_empty() { + struct_ser.serialize_field("expr", &self.expr)?; } - if let Some(v) = self.pattern.as_ref() { - struct_ser.serialize_field("pattern", v)?; + if !self.ordering_req.is_empty() { + struct_ser.serialize_field("orderingReq", &self.ordering_req)?; + } + if self.distinct { + struct_ser.serialize_field("distinct", &self.distinct)?; + } + if self.ignore_nulls { + struct_ser.serialize_field("ignoreNulls", &self.ignore_nulls)?; + } + if let Some(v) = self.fun_definition.as_ref() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("funDefinition", pbjson::private::base64::encode(&v).as_str())?; + } + if let Some(v) = self.aggregate_function.as_ref() { + match v { + physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(v) => { + struct_ser.serialize_field("userDefinedAggrFunction", v)?; + } + } } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalLikeExprNode { +impl<'de> serde::Deserialize<'de> for PhysicalAggregateExprNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "negated", - "case_insensitive", - "caseInsensitive", "expr", - "pattern", + "ordering_req", + "orderingReq", + "distinct", + "ignore_nulls", + "ignoreNulls", + "fun_definition", + "funDefinition", + "user_defined_aggr_function", + "userDefinedAggrFunction", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Negated, - CaseInsensitive, Expr, - Pattern, + OrderingReq, + Distinct, + IgnoreNulls, + FunDefinition, + UserDefinedAggrFunction, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -19372,10 +13077,12 @@ impl<'de> serde::Deserialize<'de> for PhysicalLikeExprNode { E: serde::de::Error, { match value { - "negated" => Ok(GeneratedField::Negated), - "caseInsensitive" | "case_insensitive" => Ok(GeneratedField::CaseInsensitive), "expr" => Ok(GeneratedField::Expr), - "pattern" => Ok(GeneratedField::Pattern), + "orderingReq" | "ordering_req" => Ok(GeneratedField::OrderingReq), + "distinct" => Ok(GeneratedField::Distinct), + "ignoreNulls" | "ignore_nulls" => Ok(GeneratedField::IgnoreNulls), + "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), + "userDefinedAggrFunction" | "user_defined_aggr_function" => Ok(GeneratedField::UserDefinedAggrFunction), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -19385,60 +13092,78 @@ impl<'de> serde::Deserialize<'de> for PhysicalLikeExprNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalLikeExprNode; + type Value = PhysicalAggregateExprNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalLikeExprNode") + formatter.write_str("struct datafusion.PhysicalAggregateExprNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut negated__ = None; - let mut case_insensitive__ = None; let mut expr__ = None; - let mut pattern__ = None; + let mut ordering_req__ = None; + let mut distinct__ = None; + let mut ignore_nulls__ = None; + let mut fun_definition__ = None; + let mut aggregate_function__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Negated => { - if negated__.is_some() { - return Err(serde::de::Error::duplicate_field("negated")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - negated__ = Some(map_.next_value()?); + expr__ = Some(map_.next_value()?); } - GeneratedField::CaseInsensitive => { - if case_insensitive__.is_some() { - return Err(serde::de::Error::duplicate_field("caseInsensitive")); + GeneratedField::OrderingReq => { + if ordering_req__.is_some() { + return Err(serde::de::Error::duplicate_field("orderingReq")); } - case_insensitive__ = Some(map_.next_value()?); + ordering_req__ = Some(map_.next_value()?); } - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Distinct => { + if distinct__.is_some() { + return Err(serde::de::Error::duplicate_field("distinct")); } - expr__ = map_.next_value()?; + distinct__ = Some(map_.next_value()?); } - GeneratedField::Pattern => { - if pattern__.is_some() { - return Err(serde::de::Error::duplicate_field("pattern")); + GeneratedField::IgnoreNulls => { + if ignore_nulls__.is_some() { + return Err(serde::de::Error::duplicate_field("ignoreNulls")); } - pattern__ = map_.next_value()?; + ignore_nulls__ = Some(map_.next_value()?); + } + GeneratedField::FunDefinition => { + if fun_definition__.is_some() { + return Err(serde::de::Error::duplicate_field("funDefinition")); + } + fun_definition__ = + map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) + ; + } + GeneratedField::UserDefinedAggrFunction => { + if aggregate_function__.is_some() { + return Err(serde::de::Error::duplicate_field("userDefinedAggrFunction")); + } + aggregate_function__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction); } } } - Ok(PhysicalLikeExprNode { - negated: negated__.unwrap_or_default(), - case_insensitive: case_insensitive__.unwrap_or_default(), - expr: expr__, - pattern: pattern__, + Ok(PhysicalAggregateExprNode { + expr: expr__.unwrap_or_default(), + ordering_req: ordering_req__.unwrap_or_default(), + distinct: distinct__.unwrap_or_default(), + ignore_nulls: ignore_nulls__.unwrap_or_default(), + fun_definition: fun_definition__, + aggregate_function: aggregate_function__, }) } } - deserializer.deserialize_struct("datafusion.PhysicalLikeExprNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalAggregateExprNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalNegativeNode { +impl serde::Serialize for PhysicalAliasNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -19449,14 +13174,20 @@ impl serde::Serialize for PhysicalNegativeNode { if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalNegativeNode", len)?; + if !self.alias.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalAliasNode", len)?; if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; } + if !self.alias.is_empty() { + struct_ser.serialize_field("alias", &self.alias)?; + } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalNegativeNode { +impl<'de> serde::Deserialize<'de> for PhysicalAliasNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where @@ -19464,11 +13195,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalNegativeNode { { const FIELDS: &[&str] = &[ "expr", + "alias", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Expr, + Alias, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -19491,6 +13224,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalNegativeNode { { match value { "expr" => Ok(GeneratedField::Expr), + "alias" => Ok(GeneratedField::Alias), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -19500,17 +13234,18 @@ impl<'de> serde::Deserialize<'de> for PhysicalNegativeNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalNegativeNode; + type Value = PhysicalAliasNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalNegativeNode") + formatter.write_str("struct datafusion.PhysicalAliasNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; + let mut alias__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { @@ -19519,17 +13254,24 @@ impl<'de> serde::Deserialize<'de> for PhysicalNegativeNode { } expr__ = map_.next_value()?; } + GeneratedField::Alias => { + if alias__.is_some() { + return Err(serde::de::Error::duplicate_field("alias")); + } + alias__ = Some(map_.next_value()?); + } } } - Ok(PhysicalNegativeNode { + Ok(PhysicalAliasNode { expr: expr__, + alias: alias__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.PhysicalNegativeNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalAliasNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalNot { +impl serde::Serialize for PhysicalBinaryExprNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -19537,29 +13279,45 @@ impl serde::Serialize for PhysicalNot { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if self.l.is_some() { + len += 1; + } + if self.r.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalNot", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + if !self.op.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalBinaryExprNode", len)?; + if let Some(v) = self.l.as_ref() { + struct_ser.serialize_field("l", v)?; + } + if let Some(v) = self.r.as_ref() { + struct_ser.serialize_field("r", v)?; + } + if !self.op.is_empty() { + struct_ser.serialize_field("op", &self.op)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalNot { +impl<'de> serde::Deserialize<'de> for PhysicalBinaryExprNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "l", + "r", + "op", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, + L, + R, + Op, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -19581,7 +13339,9 @@ impl<'de> serde::Deserialize<'de> for PhysicalNot { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), + "l" => Ok(GeneratedField::L), + "r" => Ok(GeneratedField::R), + "op" => Ok(GeneratedField::Op), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -19591,36 +13351,52 @@ impl<'de> serde::Deserialize<'de> for PhysicalNot { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalNot; + type Value = PhysicalBinaryExprNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalNot") + formatter.write_str("struct datafusion.PhysicalBinaryExprNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; + let mut l__ = None; + let mut r__ = None; + let mut op__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::L => { + if l__.is_some() { + return Err(serde::de::Error::duplicate_field("l")); } - expr__ = map_.next_value()?; + l__ = map_.next_value()?; + } + GeneratedField::R => { + if r__.is_some() { + return Err(serde::de::Error::duplicate_field("r")); + } + r__ = map_.next_value()?; + } + GeneratedField::Op => { + if op__.is_some() { + return Err(serde::de::Error::duplicate_field("op")); + } + op__ = Some(map_.next_value()?); } } } - Ok(PhysicalNot { - expr: expr__, + Ok(PhysicalBinaryExprNode { + l: l__, + r: r__, + op: op__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.PhysicalNot", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalBinaryExprNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalPlanNode { +impl serde::Serialize for PhysicalCaseNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -19628,183 +13404,47 @@ impl serde::Serialize for PhysicalPlanNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.physical_plan_type.is_some() { + if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalPlanNode", len)?; - if let Some(v) = self.physical_plan_type.as_ref() { - match v { - physical_plan_node::PhysicalPlanType::ParquetScan(v) => { - struct_ser.serialize_field("parquetScan", v)?; - } - physical_plan_node::PhysicalPlanType::CsvScan(v) => { - struct_ser.serialize_field("csvScan", v)?; - } - physical_plan_node::PhysicalPlanType::Empty(v) => { - struct_ser.serialize_field("empty", v)?; - } - physical_plan_node::PhysicalPlanType::Projection(v) => { - struct_ser.serialize_field("projection", v)?; - } - physical_plan_node::PhysicalPlanType::GlobalLimit(v) => { - struct_ser.serialize_field("globalLimit", v)?; - } - physical_plan_node::PhysicalPlanType::LocalLimit(v) => { - struct_ser.serialize_field("localLimit", v)?; - } - physical_plan_node::PhysicalPlanType::Aggregate(v) => { - struct_ser.serialize_field("aggregate", v)?; - } - physical_plan_node::PhysicalPlanType::HashJoin(v) => { - struct_ser.serialize_field("hashJoin", v)?; - } - physical_plan_node::PhysicalPlanType::Sort(v) => { - struct_ser.serialize_field("sort", v)?; - } - physical_plan_node::PhysicalPlanType::CoalesceBatches(v) => { - struct_ser.serialize_field("coalesceBatches", v)?; - } - physical_plan_node::PhysicalPlanType::Filter(v) => { - struct_ser.serialize_field("filter", v)?; - } - physical_plan_node::PhysicalPlanType::Merge(v) => { - struct_ser.serialize_field("merge", v)?; - } - physical_plan_node::PhysicalPlanType::Repartition(v) => { - struct_ser.serialize_field("repartition", v)?; - } - physical_plan_node::PhysicalPlanType::Window(v) => { - struct_ser.serialize_field("window", v)?; - } - physical_plan_node::PhysicalPlanType::CrossJoin(v) => { - struct_ser.serialize_field("crossJoin", v)?; - } - physical_plan_node::PhysicalPlanType::AvroScan(v) => { - struct_ser.serialize_field("avroScan", v)?; - } - physical_plan_node::PhysicalPlanType::Extension(v) => { - struct_ser.serialize_field("extension", v)?; - } - physical_plan_node::PhysicalPlanType::Union(v) => { - struct_ser.serialize_field("union", v)?; - } - physical_plan_node::PhysicalPlanType::Explain(v) => { - struct_ser.serialize_field("explain", v)?; - } - physical_plan_node::PhysicalPlanType::SortPreservingMerge(v) => { - struct_ser.serialize_field("sortPreservingMerge", v)?; - } - physical_plan_node::PhysicalPlanType::NestedLoopJoin(v) => { - struct_ser.serialize_field("nestedLoopJoin", v)?; - } - physical_plan_node::PhysicalPlanType::Analyze(v) => { - struct_ser.serialize_field("analyze", v)?; - } - physical_plan_node::PhysicalPlanType::JsonSink(v) => { - struct_ser.serialize_field("jsonSink", v)?; - } - physical_plan_node::PhysicalPlanType::SymmetricHashJoin(v) => { - struct_ser.serialize_field("symmetricHashJoin", v)?; - } - physical_plan_node::PhysicalPlanType::Interleave(v) => { - struct_ser.serialize_field("interleave", v)?; - } - physical_plan_node::PhysicalPlanType::PlaceholderRow(v) => { - struct_ser.serialize_field("placeholderRow", v)?; - } - physical_plan_node::PhysicalPlanType::CsvSink(v) => { - struct_ser.serialize_field("csvSink", v)?; - } - physical_plan_node::PhysicalPlanType::ParquetSink(v) => { - struct_ser.serialize_field("parquetSink", v)?; - } - } + if !self.when_then_expr.is_empty() { + len += 1; } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "parquet_scan", - "parquetScan", - "csv_scan", - "csvScan", - "empty", - "projection", - "global_limit", - "globalLimit", - "local_limit", - "localLimit", - "aggregate", - "hash_join", - "hashJoin", - "sort", - "coalesce_batches", - "coalesceBatches", - "filter", - "merge", - "repartition", - "window", - "cross_join", - "crossJoin", - "avro_scan", - "avroScan", - "extension", - "union", - "explain", - "sort_preserving_merge", - "sortPreservingMerge", - "nested_loop_join", - "nestedLoopJoin", - "analyze", - "json_sink", - "jsonSink", - "symmetric_hash_join", - "symmetricHashJoin", - "interleave", - "placeholder_row", - "placeholderRow", - "csv_sink", - "csvSink", - "parquet_sink", - "parquetSink", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - ParquetScan, - CsvScan, - Empty, - Projection, - GlobalLimit, - LocalLimit, - Aggregate, - HashJoin, - Sort, - CoalesceBatches, - Filter, - Merge, - Repartition, - Window, - CrossJoin, - AvroScan, - Extension, - Union, - Explain, - SortPreservingMerge, - NestedLoopJoin, - Analyze, - JsonSink, - SymmetricHashJoin, - Interleave, - PlaceholderRow, - CsvSink, - ParquetSink, + if self.else_expr.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalCaseNode", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; + } + if !self.when_then_expr.is_empty() { + struct_ser.serialize_field("whenThenExpr", &self.when_then_expr)?; + } + if let Some(v) = self.else_expr.as_ref() { + struct_ser.serialize_field("elseExpr", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for PhysicalCaseNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "expr", + "when_then_expr", + "whenThenExpr", + "else_expr", + "elseExpr", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Expr, + WhenThenExpr, + ElseExpr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -19826,34 +13466,9 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { E: serde::de::Error, { match value { - "parquetScan" | "parquet_scan" => Ok(GeneratedField::ParquetScan), - "csvScan" | "csv_scan" => Ok(GeneratedField::CsvScan), - "empty" => Ok(GeneratedField::Empty), - "projection" => Ok(GeneratedField::Projection), - "globalLimit" | "global_limit" => Ok(GeneratedField::GlobalLimit), - "localLimit" | "local_limit" => Ok(GeneratedField::LocalLimit), - "aggregate" => Ok(GeneratedField::Aggregate), - "hashJoin" | "hash_join" => Ok(GeneratedField::HashJoin), - "sort" => Ok(GeneratedField::Sort), - "coalesceBatches" | "coalesce_batches" => Ok(GeneratedField::CoalesceBatches), - "filter" => Ok(GeneratedField::Filter), - "merge" => Ok(GeneratedField::Merge), - "repartition" => Ok(GeneratedField::Repartition), - "window" => Ok(GeneratedField::Window), - "crossJoin" | "cross_join" => Ok(GeneratedField::CrossJoin), - "avroScan" | "avro_scan" => Ok(GeneratedField::AvroScan), - "extension" => Ok(GeneratedField::Extension), - "union" => Ok(GeneratedField::Union), - "explain" => Ok(GeneratedField::Explain), - "sortPreservingMerge" | "sort_preserving_merge" => Ok(GeneratedField::SortPreservingMerge), - "nestedLoopJoin" | "nested_loop_join" => Ok(GeneratedField::NestedLoopJoin), - "analyze" => Ok(GeneratedField::Analyze), - "jsonSink" | "json_sink" => Ok(GeneratedField::JsonSink), - "symmetricHashJoin" | "symmetric_hash_join" => Ok(GeneratedField::SymmetricHashJoin), - "interleave" => Ok(GeneratedField::Interleave), - "placeholderRow" | "placeholder_row" => Ok(GeneratedField::PlaceholderRow), - "csvSink" | "csv_sink" => Ok(GeneratedField::CsvSink), - "parquetSink" | "parquet_sink" => Ok(GeneratedField::ParquetSink), + "expr" => Ok(GeneratedField::Expr), + "whenThenExpr" | "when_then_expr" => Ok(GeneratedField::WhenThenExpr), + "elseExpr" | "else_expr" => Ok(GeneratedField::ElseExpr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -19863,226 +13478,52 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalPlanNode; + type Value = PhysicalCaseNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalPlanNode") + formatter.write_str("struct datafusion.PhysicalCaseNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut physical_plan_type__ = None; + let mut expr__ = None; + let mut when_then_expr__ = None; + let mut else_expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::ParquetScan => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("parquetScan")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::ParquetScan) -; - } - GeneratedField::CsvScan => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("csvScan")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::CsvScan) -; - } - GeneratedField::Empty => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("empty")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Empty) -; - } - GeneratedField::Projection => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("projection")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Projection) -; - } - GeneratedField::GlobalLimit => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("globalLimit")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::GlobalLimit) -; - } - GeneratedField::LocalLimit => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("localLimit")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::LocalLimit) -; - } - GeneratedField::Aggregate => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("aggregate")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Aggregate) -; - } - GeneratedField::HashJoin => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("hashJoin")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::HashJoin) -; - } - GeneratedField::Sort => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("sort")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Sort) -; - } - GeneratedField::CoalesceBatches => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("coalesceBatches")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::CoalesceBatches) -; - } - GeneratedField::Filter => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("filter")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Filter) -; - } - GeneratedField::Merge => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("merge")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Merge) -; - } - GeneratedField::Repartition => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("repartition")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Repartition) -; - } - GeneratedField::Window => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("window")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Window) -; - } - GeneratedField::CrossJoin => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("crossJoin")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::CrossJoin) -; - } - GeneratedField::AvroScan => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("avroScan")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::AvroScan) -; - } - GeneratedField::Extension => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("extension")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Extension) -; - } - GeneratedField::Union => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("union")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Union) -; - } - GeneratedField::Explain => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("explain")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Explain) -; - } - GeneratedField::SortPreservingMerge => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("sortPreservingMerge")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::SortPreservingMerge) -; - } - GeneratedField::NestedLoopJoin => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("nestedLoopJoin")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::NestedLoopJoin) -; - } - GeneratedField::Analyze => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("analyze")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Analyze) -; - } - GeneratedField::JsonSink => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("jsonSink")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::JsonSink) -; - } - GeneratedField::SymmetricHashJoin => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("symmetricHashJoin")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::SymmetricHashJoin) -; - } - GeneratedField::Interleave => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("interleave")); - } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Interleave) -; - } - GeneratedField::PlaceholderRow => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("placeholderRow")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::PlaceholderRow) -; + expr__ = map_.next_value()?; } - GeneratedField::CsvSink => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("csvSink")); + GeneratedField::WhenThenExpr => { + if when_then_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("whenThenExpr")); } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::CsvSink) -; + when_then_expr__ = Some(map_.next_value()?); } - GeneratedField::ParquetSink => { - if physical_plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("parquetSink")); + GeneratedField::ElseExpr => { + if else_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("elseExpr")); } - physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::ParquetSink) -; + else_expr__ = map_.next_value()?; } } } - Ok(PhysicalPlanNode { - physical_plan_type: physical_plan_type__, + Ok(PhysicalCaseNode { + expr: expr__, + when_then_expr: when_then_expr__.unwrap_or_default(), + else_expr: else_expr__, }) } } - deserializer.deserialize_struct("datafusion.PhysicalPlanNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalCaseNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalScalarUdfNode { +impl serde::Serialize for PhysicalCastNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -20090,36 +13531,132 @@ impl serde::Serialize for PhysicalScalarUdfNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.name.is_empty() { + if self.expr.is_some() { len += 1; } - if !self.args.is_empty() { + if self.arrow_type.is_some() { len += 1; } - if self.fun_definition.is_some() { + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalCastNode", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; + } + if let Some(v) = self.arrow_type.as_ref() { + struct_ser.serialize_field("arrowType", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for PhysicalCastNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "expr", + "arrow_type", + "arrowType", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Expr, + ArrowType, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "expr" => Ok(GeneratedField::Expr), + "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PhysicalCastNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.PhysicalCastNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut expr__ = None; + let mut arrow_type__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); + } + expr__ = map_.next_value()?; + } + GeneratedField::ArrowType => { + if arrow_type__.is_some() { + return Err(serde::de::Error::duplicate_field("arrowType")); + } + arrow_type__ = map_.next_value()?; + } + } + } + Ok(PhysicalCastNode { + expr: expr__, + arrow_type: arrow_type__, + }) + } + } + deserializer.deserialize_struct("datafusion.PhysicalCastNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for PhysicalColumn { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.name.is_empty() { len += 1; } - if self.return_type.is_some() { + if self.index != 0 { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalScalarUdfNode", len)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalColumn", len)?; if !self.name.is_empty() { struct_ser.serialize_field("name", &self.name)?; } - if !self.args.is_empty() { - struct_ser.serialize_field("args", &self.args)?; - } - if let Some(v) = self.fun_definition.as_ref() { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("funDefinition", pbjson::private::base64::encode(&v).as_str())?; - } - if let Some(v) = self.return_type.as_ref() { - struct_ser.serialize_field("returnType", v)?; + if self.index != 0 { + struct_ser.serialize_field("index", &self.index)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { +impl<'de> serde::Deserialize<'de> for PhysicalColumn { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where @@ -20127,19 +13664,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { { const FIELDS: &[&str] = &[ "name", - "args", - "fun_definition", - "funDefinition", - "return_type", - "returnType", + "index", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Name, - Args, - FunDefinition, - ReturnType, + Index, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -20162,9 +13693,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { { match value { "name" => Ok(GeneratedField::Name), - "args" => Ok(GeneratedField::Args), - "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), - "returnType" | "return_type" => Ok(GeneratedField::ReturnType), + "index" => Ok(GeneratedField::Index), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -20174,20 +13703,18 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalScalarUdfNode; + type Value = PhysicalColumn; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalScalarUdfNode") + formatter.write_str("struct datafusion.PhysicalColumn") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut name__ = None; - let mut args__ = None; - let mut fun_definition__ = None; - let mut return_type__ = None; + let mut index__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Name => { @@ -20196,40 +13723,26 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { } name__ = Some(map_.next_value()?); } - GeneratedField::Args => { - if args__.is_some() { - return Err(serde::de::Error::duplicate_field("args")); - } - args__ = Some(map_.next_value()?); - } - GeneratedField::FunDefinition => { - if fun_definition__.is_some() { - return Err(serde::de::Error::duplicate_field("funDefinition")); + GeneratedField::Index => { + if index__.is_some() { + return Err(serde::de::Error::duplicate_field("index")); } - fun_definition__ = - map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) + index__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } - GeneratedField::ReturnType => { - if return_type__.is_some() { - return Err(serde::de::Error::duplicate_field("returnType")); - } - return_type__ = map_.next_value()?; - } } } - Ok(PhysicalScalarUdfNode { + Ok(PhysicalColumn { name: name__.unwrap_or_default(), - args: args__.unwrap_or_default(), - fun_definition: fun_definition__, - return_type: return_type__, + index: index__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.PhysicalScalarUdfNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalColumn", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalSortExprNode { +impl serde::Serialize for PhysicalDateTimeIntervalExprNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -20237,46 +13750,45 @@ impl serde::Serialize for PhysicalSortExprNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if self.l.is_some() { len += 1; } - if self.asc { + if self.r.is_some() { len += 1; } - if self.nulls_first { + if !self.op.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalSortExprNode", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalDateTimeIntervalExprNode", len)?; + if let Some(v) = self.l.as_ref() { + struct_ser.serialize_field("l", v)?; } - if self.asc { - struct_ser.serialize_field("asc", &self.asc)?; + if let Some(v) = self.r.as_ref() { + struct_ser.serialize_field("r", v)?; } - if self.nulls_first { - struct_ser.serialize_field("nullsFirst", &self.nulls_first)?; + if !self.op.is_empty() { + struct_ser.serialize_field("op", &self.op)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalSortExprNode { +impl<'de> serde::Deserialize<'de> for PhysicalDateTimeIntervalExprNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", - "asc", - "nulls_first", - "nullsFirst", + "l", + "r", + "op", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, - Asc, - NullsFirst, + L, + R, + Op, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -20298,9 +13810,9 @@ impl<'de> serde::Deserialize<'de> for PhysicalSortExprNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), - "asc" => Ok(GeneratedField::Asc), - "nullsFirst" | "nulls_first" => Ok(GeneratedField::NullsFirst), + "l" => Ok(GeneratedField::L), + "r" => Ok(GeneratedField::R), + "op" => Ok(GeneratedField::Op), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -20310,52 +13822,52 @@ impl<'de> serde::Deserialize<'de> for PhysicalSortExprNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalSortExprNode; + type Value = PhysicalDateTimeIntervalExprNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalSortExprNode") + formatter.write_str("struct datafusion.PhysicalDateTimeIntervalExprNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where - V: serde::de::MapAccess<'de>, - { - let mut expr__ = None; - let mut asc__ = None; - let mut nulls_first__ = None; + V: serde::de::MapAccess<'de>, + { + let mut l__ = None; + let mut r__ = None; + let mut op__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::L => { + if l__.is_some() { + return Err(serde::de::Error::duplicate_field("l")); } - expr__ = map_.next_value()?; + l__ = map_.next_value()?; } - GeneratedField::Asc => { - if asc__.is_some() { - return Err(serde::de::Error::duplicate_field("asc")); + GeneratedField::R => { + if r__.is_some() { + return Err(serde::de::Error::duplicate_field("r")); } - asc__ = Some(map_.next_value()?); + r__ = map_.next_value()?; } - GeneratedField::NullsFirst => { - if nulls_first__.is_some() { - return Err(serde::de::Error::duplicate_field("nullsFirst")); + GeneratedField::Op => { + if op__.is_some() { + return Err(serde::de::Error::duplicate_field("op")); } - nulls_first__ = Some(map_.next_value()?); + op__ = Some(map_.next_value()?); } } } - Ok(PhysicalSortExprNode { - expr: expr__, - asc: asc__.unwrap_or_default(), - nulls_first: nulls_first__.unwrap_or_default(), + Ok(PhysicalDateTimeIntervalExprNode { + l: l__, + r: r__, + op: op__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.PhysicalSortExprNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalDateTimeIntervalExprNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalSortExprNodeCollection { +impl serde::Serialize for PhysicalExprNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -20363,30 +13875,124 @@ impl serde::Serialize for PhysicalSortExprNodeCollection { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.physical_sort_expr_nodes.is_empty() { + if self.expr_type.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalSortExprNodeCollection", len)?; - if !self.physical_sort_expr_nodes.is_empty() { - struct_ser.serialize_field("physicalSortExprNodes", &self.physical_sort_expr_nodes)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalExprNode", len)?; + if let Some(v) = self.expr_type.as_ref() { + match v { + physical_expr_node::ExprType::Column(v) => { + struct_ser.serialize_field("column", v)?; + } + physical_expr_node::ExprType::Literal(v) => { + struct_ser.serialize_field("literal", v)?; + } + physical_expr_node::ExprType::BinaryExpr(v) => { + struct_ser.serialize_field("binaryExpr", v)?; + } + physical_expr_node::ExprType::AggregateExpr(v) => { + struct_ser.serialize_field("aggregateExpr", v)?; + } + physical_expr_node::ExprType::IsNullExpr(v) => { + struct_ser.serialize_field("isNullExpr", v)?; + } + physical_expr_node::ExprType::IsNotNullExpr(v) => { + struct_ser.serialize_field("isNotNullExpr", v)?; + } + physical_expr_node::ExprType::NotExpr(v) => { + struct_ser.serialize_field("notExpr", v)?; + } + physical_expr_node::ExprType::Case(v) => { + struct_ser.serialize_field("case", v)?; + } + physical_expr_node::ExprType::Cast(v) => { + struct_ser.serialize_field("cast", v)?; + } + physical_expr_node::ExprType::Sort(v) => { + struct_ser.serialize_field("sort", v)?; + } + physical_expr_node::ExprType::Negative(v) => { + struct_ser.serialize_field("negative", v)?; + } + physical_expr_node::ExprType::InList(v) => { + struct_ser.serialize_field("inList", v)?; + } + physical_expr_node::ExprType::TryCast(v) => { + struct_ser.serialize_field("tryCast", v)?; + } + physical_expr_node::ExprType::WindowExpr(v) => { + struct_ser.serialize_field("windowExpr", v)?; + } + physical_expr_node::ExprType::ScalarUdf(v) => { + struct_ser.serialize_field("scalarUdf", v)?; + } + physical_expr_node::ExprType::LikeExpr(v) => { + struct_ser.serialize_field("likeExpr", v)?; + } + physical_expr_node::ExprType::Extension(v) => { + struct_ser.serialize_field("extension", v)?; + } + } } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalSortExprNodeCollection { +impl<'de> serde::Deserialize<'de> for PhysicalExprNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "physical_sort_expr_nodes", - "physicalSortExprNodes", + "column", + "literal", + "binary_expr", + "binaryExpr", + "aggregate_expr", + "aggregateExpr", + "is_null_expr", + "isNullExpr", + "is_not_null_expr", + "isNotNullExpr", + "not_expr", + "notExpr", + "case_", + "case", + "cast", + "sort", + "negative", + "in_list", + "inList", + "try_cast", + "tryCast", + "window_expr", + "windowExpr", + "scalar_udf", + "scalarUdf", + "like_expr", + "likeExpr", + "extension", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - PhysicalSortExprNodes, + Column, + Literal, + BinaryExpr, + AggregateExpr, + IsNullExpr, + IsNotNullExpr, + NotExpr, + Case, + Cast, + Sort, + Negative, + InList, + TryCast, + WindowExpr, + ScalarUdf, + LikeExpr, + Extension, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -20408,46 +14014,175 @@ impl<'de> serde::Deserialize<'de> for PhysicalSortExprNodeCollection { E: serde::de::Error, { match value { - "physicalSortExprNodes" | "physical_sort_expr_nodes" => Ok(GeneratedField::PhysicalSortExprNodes), + "column" => Ok(GeneratedField::Column), + "literal" => Ok(GeneratedField::Literal), + "binaryExpr" | "binary_expr" => Ok(GeneratedField::BinaryExpr), + "aggregateExpr" | "aggregate_expr" => Ok(GeneratedField::AggregateExpr), + "isNullExpr" | "is_null_expr" => Ok(GeneratedField::IsNullExpr), + "isNotNullExpr" | "is_not_null_expr" => Ok(GeneratedField::IsNotNullExpr), + "notExpr" | "not_expr" => Ok(GeneratedField::NotExpr), + "case" | "case_" => Ok(GeneratedField::Case), + "cast" => Ok(GeneratedField::Cast), + "sort" => Ok(GeneratedField::Sort), + "negative" => Ok(GeneratedField::Negative), + "inList" | "in_list" => Ok(GeneratedField::InList), + "tryCast" | "try_cast" => Ok(GeneratedField::TryCast), + "windowExpr" | "window_expr" => Ok(GeneratedField::WindowExpr), + "scalarUdf" | "scalar_udf" => Ok(GeneratedField::ScalarUdf), + "likeExpr" | "like_expr" => Ok(GeneratedField::LikeExpr), + "extension" => Ok(GeneratedField::Extension), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalSortExprNodeCollection; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalSortExprNodeCollection") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut physical_sort_expr_nodes__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::PhysicalSortExprNodes => { - if physical_sort_expr_nodes__.is_some() { - return Err(serde::de::Error::duplicate_field("physicalSortExprNodes")); + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PhysicalExprNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.PhysicalExprNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut expr_type__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Column => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("column")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Column) +; + } + GeneratedField::Literal => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("literal")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Literal) +; + } + GeneratedField::BinaryExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("binaryExpr")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::BinaryExpr) +; + } + GeneratedField::AggregateExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("aggregateExpr")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::AggregateExpr) +; + } + GeneratedField::IsNullExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("isNullExpr")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::IsNullExpr) +; + } + GeneratedField::IsNotNullExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("isNotNullExpr")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::IsNotNullExpr) +; + } + GeneratedField::NotExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("notExpr")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::NotExpr) +; + } + GeneratedField::Case => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("case")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Case) +; + } + GeneratedField::Cast => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("cast")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Cast) +; + } + GeneratedField::Sort => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("sort")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Sort) +; + } + GeneratedField::Negative => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("negative")); } - physical_sort_expr_nodes__ = Some(map_.next_value()?); + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Negative) +; + } + GeneratedField::InList => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("inList")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::InList) +; + } + GeneratedField::TryCast => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("tryCast")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::TryCast) +; + } + GeneratedField::WindowExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("windowExpr")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::WindowExpr) +; + } + GeneratedField::ScalarUdf => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("scalarUdf")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::ScalarUdf) +; + } + GeneratedField::LikeExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("likeExpr")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::LikeExpr) +; + } + GeneratedField::Extension => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("extension")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Extension) +; } } } - Ok(PhysicalSortExprNodeCollection { - physical_sort_expr_nodes: physical_sort_expr_nodes__.unwrap_or_default(), + Ok(PhysicalExprNode { + expr_type: expr_type__, }) } } - deserializer.deserialize_struct("datafusion.PhysicalSortExprNodeCollection", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalExprNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalTryCastNode { +impl serde::Serialize for PhysicalExtensionExprNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -20455,23 +14190,25 @@ impl serde::Serialize for PhysicalTryCastNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if !self.expr.is_empty() { len += 1; } - if self.arrow_type.is_some() { + if !self.inputs.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalTryCastNode", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalExtensionExprNode", len)?; + if !self.expr.is_empty() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("expr", pbjson::private::base64::encode(&self.expr).as_str())?; } - if let Some(v) = self.arrow_type.as_ref() { - struct_ser.serialize_field("arrowType", v)?; + if !self.inputs.is_empty() { + struct_ser.serialize_field("inputs", &self.inputs)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalTryCastNode { +impl<'de> serde::Deserialize<'de> for PhysicalExtensionExprNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where @@ -20479,14 +14216,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalTryCastNode { { const FIELDS: &[&str] = &[ "expr", - "arrow_type", - "arrowType", + "inputs", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Expr, - ArrowType, + Inputs, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -20509,7 +14245,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalTryCastNode { { match value { "expr" => Ok(GeneratedField::Expr), - "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), + "inputs" => Ok(GeneratedField::Inputs), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -20519,44 +14255,46 @@ impl<'de> serde::Deserialize<'de> for PhysicalTryCastNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalTryCastNode; + type Value = PhysicalExtensionExprNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalTryCastNode") + formatter.write_str("struct datafusion.PhysicalExtensionExprNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut expr__ = None; - let mut arrow_type__ = None; + let mut inputs__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map_.next_value()?; + expr__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; } - GeneratedField::ArrowType => { - if arrow_type__.is_some() { - return Err(serde::de::Error::duplicate_field("arrowType")); + GeneratedField::Inputs => { + if inputs__.is_some() { + return Err(serde::de::Error::duplicate_field("inputs")); } - arrow_type__ = map_.next_value()?; + inputs__ = Some(map_.next_value()?); } } } - Ok(PhysicalTryCastNode { - expr: expr__, - arrow_type: arrow_type__, + Ok(PhysicalExtensionExprNode { + expr: expr__.unwrap_or_default(), + inputs: inputs__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.PhysicalTryCastNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalExtensionExprNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalWhenThen { +impl serde::Serialize for PhysicalExtensionNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -20564,39 +14302,39 @@ impl serde::Serialize for PhysicalWhenThen { { use serde::ser::SerializeStruct; let mut len = 0; - if self.when_expr.is_some() { + if !self.node.is_empty() { len += 1; } - if self.then_expr.is_some() { + if !self.inputs.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalWhenThen", len)?; - if let Some(v) = self.when_expr.as_ref() { - struct_ser.serialize_field("whenExpr", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalExtensionNode", len)?; + if !self.node.is_empty() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("node", pbjson::private::base64::encode(&self.node).as_str())?; } - if let Some(v) = self.then_expr.as_ref() { - struct_ser.serialize_field("thenExpr", v)?; + if !self.inputs.is_empty() { + struct_ser.serialize_field("inputs", &self.inputs)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalWhenThen { +impl<'de> serde::Deserialize<'de> for PhysicalExtensionNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "when_expr", - "whenExpr", - "then_expr", - "thenExpr", + "node", + "inputs", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - WhenExpr, - ThenExpr, + Node, + Inputs, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -20618,8 +14356,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalWhenThen { E: serde::de::Error, { match value { - "whenExpr" | "when_expr" => Ok(GeneratedField::WhenExpr), - "thenExpr" | "then_expr" => Ok(GeneratedField::ThenExpr), + "node" => Ok(GeneratedField::Node), + "inputs" => Ok(GeneratedField::Inputs), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -20629,44 +14367,46 @@ impl<'de> serde::Deserialize<'de> for PhysicalWhenThen { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalWhenThen; + type Value = PhysicalExtensionNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalWhenThen") + formatter.write_str("struct datafusion.PhysicalExtensionNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut when_expr__ = None; - let mut then_expr__ = None; + let mut node__ = None; + let mut inputs__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::WhenExpr => { - if when_expr__.is_some() { - return Err(serde::de::Error::duplicate_field("whenExpr")); + GeneratedField::Node => { + if node__.is_some() { + return Err(serde::de::Error::duplicate_field("node")); } - when_expr__ = map_.next_value()?; + node__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; } - GeneratedField::ThenExpr => { - if then_expr__.is_some() { - return Err(serde::de::Error::duplicate_field("thenExpr")); + GeneratedField::Inputs => { + if inputs__.is_some() { + return Err(serde::de::Error::duplicate_field("inputs")); } - then_expr__ = map_.next_value()?; + inputs__ = Some(map_.next_value()?); } } } - Ok(PhysicalWhenThen { - when_expr: when_expr__, - then_expr: then_expr__, + Ok(PhysicalExtensionNode { + node: node__.unwrap_or_default(), + inputs: inputs__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.PhysicalWhenThen", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalExtensionNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalWindowExprNode { +impl serde::Serialize for PhysicalHashRepartition { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -20674,87 +14414,41 @@ impl serde::Serialize for PhysicalWindowExprNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.args.is_empty() { - len += 1; - } - if !self.partition_by.is_empty() { - len += 1; - } - if !self.order_by.is_empty() { - len += 1; - } - if self.window_frame.is_some() { - len += 1; - } - if !self.name.is_empty() { + if !self.hash_expr.is_empty() { len += 1; } - if self.window_function.is_some() { + if self.partition_count != 0 { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalWindowExprNode", len)?; - if !self.args.is_empty() { - struct_ser.serialize_field("args", &self.args)?; - } - if !self.partition_by.is_empty() { - struct_ser.serialize_field("partitionBy", &self.partition_by)?; - } - if !self.order_by.is_empty() { - struct_ser.serialize_field("orderBy", &self.order_by)?; - } - if let Some(v) = self.window_frame.as_ref() { - struct_ser.serialize_field("windowFrame", v)?; - } - if !self.name.is_empty() { - struct_ser.serialize_field("name", &self.name)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalHashRepartition", len)?; + if !self.hash_expr.is_empty() { + struct_ser.serialize_field("hashExpr", &self.hash_expr)?; } - if let Some(v) = self.window_function.as_ref() { - match v { - physical_window_expr_node::WindowFunction::AggrFunction(v) => { - let v = AggregateFunction::try_from(*v) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; - struct_ser.serialize_field("aggrFunction", &v)?; - } - physical_window_expr_node::WindowFunction::BuiltInFunction(v) => { - let v = BuiltInWindowFunction::try_from(*v) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; - struct_ser.serialize_field("builtInFunction", &v)?; - } - } + if self.partition_count != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("partitionCount", ToString::to_string(&self.partition_count).as_str())?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { +impl<'de> serde::Deserialize<'de> for PhysicalHashRepartition { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "args", - "partition_by", - "partitionBy", - "order_by", - "orderBy", - "window_frame", - "windowFrame", - "name", - "aggr_function", - "aggrFunction", - "built_in_function", - "builtInFunction", + "hash_expr", + "hashExpr", + "partition_count", + "partitionCount", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Args, - PartitionBy, - OrderBy, - WindowFrame, - Name, - AggrFunction, - BuiltInFunction, + HashExpr, + PartitionCount, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -20776,13 +14470,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { E: serde::de::Error, { match value { - "args" => Ok(GeneratedField::Args), - "partitionBy" | "partition_by" => Ok(GeneratedField::PartitionBy), - "orderBy" | "order_by" => Ok(GeneratedField::OrderBy), - "windowFrame" | "window_frame" => Ok(GeneratedField::WindowFrame), - "name" => Ok(GeneratedField::Name), - "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction), - "builtInFunction" | "built_in_function" => Ok(GeneratedField::BuiltInFunction), + "hashExpr" | "hash_expr" => Ok(GeneratedField::HashExpr), + "partitionCount" | "partition_count" => Ok(GeneratedField::PartitionCount), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -20792,82 +14481,46 @@ impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalWindowExprNode; + type Value = PhysicalHashRepartition; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalWindowExprNode") + formatter.write_str("struct datafusion.PhysicalHashRepartition") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut args__ = None; - let mut partition_by__ = None; - let mut order_by__ = None; - let mut window_frame__ = None; - let mut name__ = None; - let mut window_function__ = None; + let mut hash_expr__ = None; + let mut partition_count__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Args => { - if args__.is_some() { - return Err(serde::de::Error::duplicate_field("args")); - } - args__ = Some(map_.next_value()?); - } - GeneratedField::PartitionBy => { - if partition_by__.is_some() { - return Err(serde::de::Error::duplicate_field("partitionBy")); - } - partition_by__ = Some(map_.next_value()?); - } - GeneratedField::OrderBy => { - if order_by__.is_some() { - return Err(serde::de::Error::duplicate_field("orderBy")); - } - order_by__ = Some(map_.next_value()?); - } - GeneratedField::WindowFrame => { - if window_frame__.is_some() { - return Err(serde::de::Error::duplicate_field("windowFrame")); - } - window_frame__ = map_.next_value()?; - } - GeneratedField::Name => { - if name__.is_some() { - return Err(serde::de::Error::duplicate_field("name")); - } - name__ = Some(map_.next_value()?); - } - GeneratedField::AggrFunction => { - if window_function__.is_some() { - return Err(serde::de::Error::duplicate_field("aggrFunction")); + GeneratedField::HashExpr => { + if hash_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("hashExpr")); } - window_function__ = map_.next_value::<::std::option::Option>()?.map(|x| physical_window_expr_node::WindowFunction::AggrFunction(x as i32)); + hash_expr__ = Some(map_.next_value()?); } - GeneratedField::BuiltInFunction => { - if window_function__.is_some() { - return Err(serde::de::Error::duplicate_field("builtInFunction")); + GeneratedField::PartitionCount => { + if partition_count__.is_some() { + return Err(serde::de::Error::duplicate_field("partitionCount")); } - window_function__ = map_.next_value::<::std::option::Option>()?.map(|x| physical_window_expr_node::WindowFunction::BuiltInFunction(x as i32)); + partition_count__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } } } - Ok(PhysicalWindowExprNode { - args: args__.unwrap_or_default(), - partition_by: partition_by__.unwrap_or_default(), - order_by: order_by__.unwrap_or_default(), - window_frame: window_frame__, - name: name__.unwrap_or_default(), - window_function: window_function__, + Ok(PhysicalHashRepartition { + hash_expr: hash_expr__.unwrap_or_default(), + partition_count: partition_count__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.PhysicalWindowExprNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalHashRepartition", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PlaceholderNode { +impl serde::Serialize for PhysicalInListNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -20875,38 +14528,45 @@ impl serde::Serialize for PlaceholderNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.id.is_empty() { + if self.expr.is_some() { len += 1; } - if self.data_type.is_some() { + if !self.list.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PlaceholderNode", len)?; - if !self.id.is_empty() { - struct_ser.serialize_field("id", &self.id)?; + if self.negated { + len += 1; } - if let Some(v) = self.data_type.as_ref() { - struct_ser.serialize_field("dataType", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalInListNode", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; + } + if !self.list.is_empty() { + struct_ser.serialize_field("list", &self.list)?; + } + if self.negated { + struct_ser.serialize_field("negated", &self.negated)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PlaceholderNode { +impl<'de> serde::Deserialize<'de> for PhysicalInListNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "id", - "data_type", - "dataType", + "expr", + "list", + "negated", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Id, - DataType, + Expr, + List, + Negated, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -20928,8 +14588,9 @@ impl<'de> serde::Deserialize<'de> for PlaceholderNode { E: serde::de::Error, { match value { - "id" => Ok(GeneratedField::Id), - "dataType" | "data_type" => Ok(GeneratedField::DataType), + "expr" => Ok(GeneratedField::Expr), + "list" => Ok(GeneratedField::List), + "negated" => Ok(GeneratedField::Negated), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -20939,44 +14600,52 @@ impl<'de> serde::Deserialize<'de> for PlaceholderNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PlaceholderNode; + type Value = PhysicalInListNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PlaceholderNode") + formatter.write_str("struct datafusion.PhysicalInListNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut id__ = None; - let mut data_type__ = None; + let mut expr__ = None; + let mut list__ = None; + let mut negated__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Id => { - if id__.is_some() { - return Err(serde::de::Error::duplicate_field("id")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - id__ = Some(map_.next_value()?); + expr__ = map_.next_value()?; } - GeneratedField::DataType => { - if data_type__.is_some() { - return Err(serde::de::Error::duplicate_field("dataType")); + GeneratedField::List => { + if list__.is_some() { + return Err(serde::de::Error::duplicate_field("list")); } - data_type__ = map_.next_value()?; + list__ = Some(map_.next_value()?); + } + GeneratedField::Negated => { + if negated__.is_some() { + return Err(serde::de::Error::duplicate_field("negated")); + } + negated__ = Some(map_.next_value()?); } } } - Ok(PlaceholderNode { - id: id__.unwrap_or_default(), - data_type: data_type__, + Ok(PhysicalInListNode { + expr: expr__, + list: list__.unwrap_or_default(), + negated: negated__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.PlaceholderNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalInListNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PlaceholderRowExecNode { +impl serde::Serialize for PhysicalIsNotNull { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -20984,29 +14653,29 @@ impl serde::Serialize for PlaceholderRowExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.schema.is_some() { + if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PlaceholderRowExecNode", len)?; - if let Some(v) = self.schema.as_ref() { - struct_ser.serialize_field("schema", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalIsNotNull", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PlaceholderRowExecNode { +impl<'de> serde::Deserialize<'de> for PhysicalIsNotNull { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "schema", + "expr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Schema, + Expr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -21028,7 +14697,7 @@ impl<'de> serde::Deserialize<'de> for PlaceholderRowExecNode { E: serde::de::Error, { match value { - "schema" => Ok(GeneratedField::Schema), + "expr" => Ok(GeneratedField::Expr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -21038,115 +14707,66 @@ impl<'de> serde::Deserialize<'de> for PlaceholderRowExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PlaceholderRowExecNode; + type Value = PhysicalIsNotNull; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PlaceholderRowExecNode") + formatter.write_str("struct datafusion.PhysicalIsNotNull") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut schema__ = None; + let mut expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Schema => { - if schema__.is_some() { - return Err(serde::de::Error::duplicate_field("schema")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - schema__ = map_.next_value()?; + expr__ = map_.next_value()?; } - } - } - Ok(PlaceholderRowExecNode { - schema: schema__, - }) - } - } - deserializer.deserialize_struct("datafusion.PlaceholderRowExecNode", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for PlanType { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.plan_type_enum.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.PlanType", len)?; - if let Some(v) = self.plan_type_enum.as_ref() { - match v { - plan_type::PlanTypeEnum::InitialLogicalPlan(v) => { - struct_ser.serialize_field("InitialLogicalPlan", v)?; - } - plan_type::PlanTypeEnum::AnalyzedLogicalPlan(v) => { - struct_ser.serialize_field("AnalyzedLogicalPlan", v)?; - } - plan_type::PlanTypeEnum::FinalAnalyzedLogicalPlan(v) => { - struct_ser.serialize_field("FinalAnalyzedLogicalPlan", v)?; - } - plan_type::PlanTypeEnum::OptimizedLogicalPlan(v) => { - struct_ser.serialize_field("OptimizedLogicalPlan", v)?; - } - plan_type::PlanTypeEnum::FinalLogicalPlan(v) => { - struct_ser.serialize_field("FinalLogicalPlan", v)?; - } - plan_type::PlanTypeEnum::InitialPhysicalPlan(v) => { - struct_ser.serialize_field("InitialPhysicalPlan", v)?; - } - plan_type::PlanTypeEnum::InitialPhysicalPlanWithStats(v) => { - struct_ser.serialize_field("InitialPhysicalPlanWithStats", v)?; - } - plan_type::PlanTypeEnum::OptimizedPhysicalPlan(v) => { - struct_ser.serialize_field("OptimizedPhysicalPlan", v)?; - } - plan_type::PlanTypeEnum::FinalPhysicalPlan(v) => { - struct_ser.serialize_field("FinalPhysicalPlan", v)?; - } - plan_type::PlanTypeEnum::FinalPhysicalPlanWithStats(v) => { - struct_ser.serialize_field("FinalPhysicalPlanWithStats", v)?; + } } + Ok(PhysicalIsNotNull { + expr: expr__, + }) } } + deserializer.deserialize_struct("datafusion.PhysicalIsNotNull", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for PhysicalIsNull { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.expr.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalIsNull", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; + } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PlanType { +impl<'de> serde::Deserialize<'de> for PhysicalIsNull { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "InitialLogicalPlan", - "AnalyzedLogicalPlan", - "FinalAnalyzedLogicalPlan", - "OptimizedLogicalPlan", - "FinalLogicalPlan", - "InitialPhysicalPlan", - "InitialPhysicalPlanWithStats", - "OptimizedPhysicalPlan", - "FinalPhysicalPlan", - "FinalPhysicalPlanWithStats", + "expr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - InitialLogicalPlan, - AnalyzedLogicalPlan, - FinalAnalyzedLogicalPlan, - OptimizedLogicalPlan, - FinalLogicalPlan, - InitialPhysicalPlan, - InitialPhysicalPlanWithStats, - OptimizedPhysicalPlan, - FinalPhysicalPlan, - FinalPhysicalPlanWithStats, + Expr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -21168,16 +14788,7 @@ impl<'de> serde::Deserialize<'de> for PlanType { E: serde::de::Error, { match value { - "InitialLogicalPlan" => Ok(GeneratedField::InitialLogicalPlan), - "AnalyzedLogicalPlan" => Ok(GeneratedField::AnalyzedLogicalPlan), - "FinalAnalyzedLogicalPlan" => Ok(GeneratedField::FinalAnalyzedLogicalPlan), - "OptimizedLogicalPlan" => Ok(GeneratedField::OptimizedLogicalPlan), - "FinalLogicalPlan" => Ok(GeneratedField::FinalLogicalPlan), - "InitialPhysicalPlan" => Ok(GeneratedField::InitialPhysicalPlan), - "InitialPhysicalPlanWithStats" => Ok(GeneratedField::InitialPhysicalPlanWithStats), - "OptimizedPhysicalPlan" => Ok(GeneratedField::OptimizedPhysicalPlan), - "FinalPhysicalPlan" => Ok(GeneratedField::FinalPhysicalPlan), - "FinalPhysicalPlanWithStats" => Ok(GeneratedField::FinalPhysicalPlanWithStats), + "expr" => Ok(GeneratedField::Expr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -21187,100 +14798,36 @@ impl<'de> serde::Deserialize<'de> for PlanType { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PlanType; + type Value = PhysicalIsNull; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PlanType") + formatter.write_str("struct datafusion.PhysicalIsNull") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut plan_type_enum__ = None; + let mut expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::InitialLogicalPlan => { - if plan_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("InitialLogicalPlan")); - } - plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::InitialLogicalPlan) -; - } - GeneratedField::AnalyzedLogicalPlan => { - if plan_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("AnalyzedLogicalPlan")); - } - plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::AnalyzedLogicalPlan) -; - } - GeneratedField::FinalAnalyzedLogicalPlan => { - if plan_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("FinalAnalyzedLogicalPlan")); - } - plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::FinalAnalyzedLogicalPlan) -; - } - GeneratedField::OptimizedLogicalPlan => { - if plan_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("OptimizedLogicalPlan")); - } - plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::OptimizedLogicalPlan) -; - } - GeneratedField::FinalLogicalPlan => { - if plan_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("FinalLogicalPlan")); - } - plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::FinalLogicalPlan) -; - } - GeneratedField::InitialPhysicalPlan => { - if plan_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("InitialPhysicalPlan")); - } - plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::InitialPhysicalPlan) -; - } - GeneratedField::InitialPhysicalPlanWithStats => { - if plan_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("InitialPhysicalPlanWithStats")); - } - plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::InitialPhysicalPlanWithStats) -; - } - GeneratedField::OptimizedPhysicalPlan => { - if plan_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("OptimizedPhysicalPlan")); - } - plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::OptimizedPhysicalPlan) -; - } - GeneratedField::FinalPhysicalPlan => { - if plan_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("FinalPhysicalPlan")); - } - plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::FinalPhysicalPlan) -; - } - GeneratedField::FinalPhysicalPlanWithStats => { - if plan_type_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("FinalPhysicalPlanWithStats")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::FinalPhysicalPlanWithStats) -; + expr__ = map_.next_value()?; } } } - Ok(PlanType { - plan_type_enum: plan_type_enum__, + Ok(PhysicalIsNull { + expr: expr__, }) } } - deserializer.deserialize_struct("datafusion.PlanType", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalIsNull", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for Precision { +impl serde::Serialize for PhysicalLikeExprNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -21288,40 +14835,54 @@ impl serde::Serialize for Precision { { use serde::ser::SerializeStruct; let mut len = 0; - if self.precision_info != 0 { + if self.negated { + len += 1; + } + if self.case_insensitive { len += 1; } - if self.val.is_some() { + if self.expr.is_some() { + len += 1; + } + if self.pattern.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.Precision", len)?; - if self.precision_info != 0 { - let v = PrecisionInfo::try_from(self.precision_info) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.precision_info)))?; - struct_ser.serialize_field("precisionInfo", &v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalLikeExprNode", len)?; + if self.negated { + struct_ser.serialize_field("negated", &self.negated)?; + } + if self.case_insensitive { + struct_ser.serialize_field("caseInsensitive", &self.case_insensitive)?; + } + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } - if let Some(v) = self.val.as_ref() { - struct_ser.serialize_field("val", v)?; + if let Some(v) = self.pattern.as_ref() { + struct_ser.serialize_field("pattern", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for Precision { +impl<'de> serde::Deserialize<'de> for PhysicalLikeExprNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "precision_info", - "precisionInfo", - "val", + "negated", + "case_insensitive", + "caseInsensitive", + "expr", + "pattern", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - PrecisionInfo, - Val, + Negated, + CaseInsensitive, + Expr, + Pattern, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -21343,8 +14904,10 @@ impl<'de> serde::Deserialize<'de> for Precision { E: serde::de::Error, { match value { - "precisionInfo" | "precision_info" => Ok(GeneratedField::PrecisionInfo), - "val" => Ok(GeneratedField::Val), + "negated" => Ok(GeneratedField::Negated), + "caseInsensitive" | "case_insensitive" => Ok(GeneratedField::CaseInsensitive), + "expr" => Ok(GeneratedField::Expr), + "pattern" => Ok(GeneratedField::Pattern), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -21354,118 +14917,151 @@ impl<'de> serde::Deserialize<'de> for Precision { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = Precision; + type Value = PhysicalLikeExprNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.Precision") + formatter.write_str("struct datafusion.PhysicalLikeExprNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut precision_info__ = None; - let mut val__ = None; + let mut negated__ = None; + let mut case_insensitive__ = None; + let mut expr__ = None; + let mut pattern__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::PrecisionInfo => { - if precision_info__.is_some() { - return Err(serde::de::Error::duplicate_field("precisionInfo")); + GeneratedField::Negated => { + if negated__.is_some() { + return Err(serde::de::Error::duplicate_field("negated")); + } + negated__ = Some(map_.next_value()?); + } + GeneratedField::CaseInsensitive => { + if case_insensitive__.is_some() { + return Err(serde::de::Error::duplicate_field("caseInsensitive")); + } + case_insensitive__ = Some(map_.next_value()?); + } + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - precision_info__ = Some(map_.next_value::()? as i32); + expr__ = map_.next_value()?; } - GeneratedField::Val => { - if val__.is_some() { - return Err(serde::de::Error::duplicate_field("val")); + GeneratedField::Pattern => { + if pattern__.is_some() { + return Err(serde::de::Error::duplicate_field("pattern")); } - val__ = map_.next_value()?; + pattern__ = map_.next_value()?; } } } - Ok(Precision { - precision_info: precision_info__.unwrap_or_default(), - val: val__, + Ok(PhysicalLikeExprNode { + negated: negated__.unwrap_or_default(), + case_insensitive: case_insensitive__.unwrap_or_default(), + expr: expr__, + pattern: pattern__, }) } } - deserializer.deserialize_struct("datafusion.Precision", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalLikeExprNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PrecisionInfo { +impl serde::Serialize for PhysicalNegativeNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where S: serde::Serializer, { - let variant = match self { - Self::Exact => "EXACT", - Self::Inexact => "INEXACT", - Self::Absent => "ABSENT", - }; - serializer.serialize_str(variant) + use serde::ser::SerializeStruct; + let mut len = 0; + if self.expr.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalNegativeNode", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; + } + struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PrecisionInfo { +impl<'de> serde::Deserialize<'de> for PhysicalNegativeNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "EXACT", - "INEXACT", - "ABSENT", + "expr", ]; + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Expr, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "expr" => Ok(GeneratedField::Expr), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PrecisionInfo; + type Value = PhysicalNegativeNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - fn visit_i64(self, v: i64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) - }) - } - - fn visit_u64(self, v: u64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) - }) + formatter.write_str("struct datafusion.PhysicalNegativeNode") } - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, { - match value { - "EXACT" => Ok(PrecisionInfo::Exact), - "INEXACT" => Ok(PrecisionInfo::Inexact), - "ABSENT" => Ok(PrecisionInfo::Absent), - _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + let mut expr__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); + } + expr__ = map_.next_value()?; + } + } } + Ok(PhysicalNegativeNode { + expr: expr__, + }) } } - deserializer.deserialize_any(GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalNegativeNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PrepareNode { +impl serde::Serialize for PhysicalNot { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -21473,46 +15069,29 @@ impl serde::Serialize for PrepareNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.name.is_empty() { - len += 1; - } - if !self.data_types.is_empty() { - len += 1; - } - if self.input.is_some() { + if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PrepareNode", len)?; - if !self.name.is_empty() { - struct_ser.serialize_field("name", &self.name)?; - } - if !self.data_types.is_empty() { - struct_ser.serialize_field("dataTypes", &self.data_types)?; - } - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalNot", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PrepareNode { +impl<'de> serde::Deserialize<'de> for PhysicalNot { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "name", - "data_types", - "dataTypes", - "input", + "expr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Name, - DataTypes, - Input, + Expr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -21534,9 +15113,7 @@ impl<'de> serde::Deserialize<'de> for PrepareNode { E: serde::de::Error, { match value { - "name" => Ok(GeneratedField::Name), - "dataTypes" | "data_types" => Ok(GeneratedField::DataTypes), - "input" => Ok(GeneratedField::Input), + "expr" => Ok(GeneratedField::Expr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -21546,52 +15123,36 @@ impl<'de> serde::Deserialize<'de> for PrepareNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PrepareNode; + type Value = PhysicalNot; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PrepareNode") + formatter.write_str("struct datafusion.PhysicalNot") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut name__ = None; - let mut data_types__ = None; - let mut input__ = None; + let mut expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Name => { - if name__.is_some() { - return Err(serde::de::Error::duplicate_field("name")); - } - name__ = Some(map_.next_value()?); - } - GeneratedField::DataTypes => { - if data_types__.is_some() { - return Err(serde::de::Error::duplicate_field("dataTypes")); - } - data_types__ = Some(map_.next_value()?); - } - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - input__ = map_.next_value()?; + expr__ = map_.next_value()?; } } } - Ok(PrepareNode { - name: name__.unwrap_or_default(), - data_types: data_types__.unwrap_or_default(), - input: input__, + Ok(PhysicalNot { + expr: expr__, }) } } - deserializer.deserialize_struct("datafusion.PrepareNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalNot", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PrimaryKeyConstraint { +impl serde::Serialize for PhysicalPlanNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -21599,29 +15160,188 @@ impl serde::Serialize for PrimaryKeyConstraint { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.indices.is_empty() { + if self.physical_plan_type.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.PrimaryKeyConstraint", len)?; - if !self.indices.is_empty() { - struct_ser.serialize_field("indices", &self.indices.iter().map(ToString::to_string).collect::>())?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalPlanNode", len)?; + if let Some(v) = self.physical_plan_type.as_ref() { + match v { + physical_plan_node::PhysicalPlanType::ParquetScan(v) => { + struct_ser.serialize_field("parquetScan", v)?; + } + physical_plan_node::PhysicalPlanType::CsvScan(v) => { + struct_ser.serialize_field("csvScan", v)?; + } + physical_plan_node::PhysicalPlanType::Empty(v) => { + struct_ser.serialize_field("empty", v)?; + } + physical_plan_node::PhysicalPlanType::Projection(v) => { + struct_ser.serialize_field("projection", v)?; + } + physical_plan_node::PhysicalPlanType::GlobalLimit(v) => { + struct_ser.serialize_field("globalLimit", v)?; + } + physical_plan_node::PhysicalPlanType::LocalLimit(v) => { + struct_ser.serialize_field("localLimit", v)?; + } + physical_plan_node::PhysicalPlanType::Aggregate(v) => { + struct_ser.serialize_field("aggregate", v)?; + } + physical_plan_node::PhysicalPlanType::HashJoin(v) => { + struct_ser.serialize_field("hashJoin", v)?; + } + physical_plan_node::PhysicalPlanType::Sort(v) => { + struct_ser.serialize_field("sort", v)?; + } + physical_plan_node::PhysicalPlanType::CoalesceBatches(v) => { + struct_ser.serialize_field("coalesceBatches", v)?; + } + physical_plan_node::PhysicalPlanType::Filter(v) => { + struct_ser.serialize_field("filter", v)?; + } + physical_plan_node::PhysicalPlanType::Merge(v) => { + struct_ser.serialize_field("merge", v)?; + } + physical_plan_node::PhysicalPlanType::Repartition(v) => { + struct_ser.serialize_field("repartition", v)?; + } + physical_plan_node::PhysicalPlanType::Window(v) => { + struct_ser.serialize_field("window", v)?; + } + physical_plan_node::PhysicalPlanType::CrossJoin(v) => { + struct_ser.serialize_field("crossJoin", v)?; + } + physical_plan_node::PhysicalPlanType::AvroScan(v) => { + struct_ser.serialize_field("avroScan", v)?; + } + physical_plan_node::PhysicalPlanType::Extension(v) => { + struct_ser.serialize_field("extension", v)?; + } + physical_plan_node::PhysicalPlanType::Union(v) => { + struct_ser.serialize_field("union", v)?; + } + physical_plan_node::PhysicalPlanType::Explain(v) => { + struct_ser.serialize_field("explain", v)?; + } + physical_plan_node::PhysicalPlanType::SortPreservingMerge(v) => { + struct_ser.serialize_field("sortPreservingMerge", v)?; + } + physical_plan_node::PhysicalPlanType::NestedLoopJoin(v) => { + struct_ser.serialize_field("nestedLoopJoin", v)?; + } + physical_plan_node::PhysicalPlanType::Analyze(v) => { + struct_ser.serialize_field("analyze", v)?; + } + physical_plan_node::PhysicalPlanType::JsonSink(v) => { + struct_ser.serialize_field("jsonSink", v)?; + } + physical_plan_node::PhysicalPlanType::SymmetricHashJoin(v) => { + struct_ser.serialize_field("symmetricHashJoin", v)?; + } + physical_plan_node::PhysicalPlanType::Interleave(v) => { + struct_ser.serialize_field("interleave", v)?; + } + physical_plan_node::PhysicalPlanType::PlaceholderRow(v) => { + struct_ser.serialize_field("placeholderRow", v)?; + } + physical_plan_node::PhysicalPlanType::CsvSink(v) => { + struct_ser.serialize_field("csvSink", v)?; + } + physical_plan_node::PhysicalPlanType::ParquetSink(v) => { + struct_ser.serialize_field("parquetSink", v)?; + } + physical_plan_node::PhysicalPlanType::Unnest(v) => { + struct_ser.serialize_field("unnest", v)?; + } + } } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for PrimaryKeyConstraint { +impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "indices", + "parquet_scan", + "parquetScan", + "csv_scan", + "csvScan", + "empty", + "projection", + "global_limit", + "globalLimit", + "local_limit", + "localLimit", + "aggregate", + "hash_join", + "hashJoin", + "sort", + "coalesce_batches", + "coalesceBatches", + "filter", + "merge", + "repartition", + "window", + "cross_join", + "crossJoin", + "avro_scan", + "avroScan", + "extension", + "union", + "explain", + "sort_preserving_merge", + "sortPreservingMerge", + "nested_loop_join", + "nestedLoopJoin", + "analyze", + "json_sink", + "jsonSink", + "symmetric_hash_join", + "symmetricHashJoin", + "interleave", + "placeholder_row", + "placeholderRow", + "csv_sink", + "csvSink", + "parquet_sink", + "parquetSink", + "unnest", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Indices, + ParquetScan, + CsvScan, + Empty, + Projection, + GlobalLimit, + LocalLimit, + Aggregate, + HashJoin, + Sort, + CoalesceBatches, + Filter, + Merge, + Repartition, + Window, + CrossJoin, + AvroScan, + Extension, + Union, + Explain, + SortPreservingMerge, + NestedLoopJoin, + Analyze, + JsonSink, + SymmetricHashJoin, + Interleave, + PlaceholderRow, + CsvSink, + ParquetSink, + Unnest, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -21643,7 +15363,35 @@ impl<'de> serde::Deserialize<'de> for PrimaryKeyConstraint { E: serde::de::Error, { match value { - "indices" => Ok(GeneratedField::Indices), + "parquetScan" | "parquet_scan" => Ok(GeneratedField::ParquetScan), + "csvScan" | "csv_scan" => Ok(GeneratedField::CsvScan), + "empty" => Ok(GeneratedField::Empty), + "projection" => Ok(GeneratedField::Projection), + "globalLimit" | "global_limit" => Ok(GeneratedField::GlobalLimit), + "localLimit" | "local_limit" => Ok(GeneratedField::LocalLimit), + "aggregate" => Ok(GeneratedField::Aggregate), + "hashJoin" | "hash_join" => Ok(GeneratedField::HashJoin), + "sort" => Ok(GeneratedField::Sort), + "coalesceBatches" | "coalesce_batches" => Ok(GeneratedField::CoalesceBatches), + "filter" => Ok(GeneratedField::Filter), + "merge" => Ok(GeneratedField::Merge), + "repartition" => Ok(GeneratedField::Repartition), + "window" => Ok(GeneratedField::Window), + "crossJoin" | "cross_join" => Ok(GeneratedField::CrossJoin), + "avroScan" | "avro_scan" => Ok(GeneratedField::AvroScan), + "extension" => Ok(GeneratedField::Extension), + "union" => Ok(GeneratedField::Union), + "explain" => Ok(GeneratedField::Explain), + "sortPreservingMerge" | "sort_preserving_merge" => Ok(GeneratedField::SortPreservingMerge), + "nestedLoopJoin" | "nested_loop_join" => Ok(GeneratedField::NestedLoopJoin), + "analyze" => Ok(GeneratedField::Analyze), + "jsonSink" | "json_sink" => Ok(GeneratedField::JsonSink), + "symmetricHashJoin" | "symmetric_hash_join" => Ok(GeneratedField::SymmetricHashJoin), + "interleave" => Ok(GeneratedField::Interleave), + "placeholderRow" | "placeholder_row" => Ok(GeneratedField::PlaceholderRow), + "csvSink" | "csv_sink" => Ok(GeneratedField::CsvSink), + "parquetSink" | "parquet_sink" => Ok(GeneratedField::ParquetSink), + "unnest" => Ok(GeneratedField::Unnest), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -21653,130 +15401,233 @@ impl<'de> serde::Deserialize<'de> for PrimaryKeyConstraint { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PrimaryKeyConstraint; + type Value = PhysicalPlanNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PrimaryKeyConstraint") + formatter.write_str("struct datafusion.PhysicalPlanNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut indices__ = None; + let mut physical_plan_type__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Indices => { - if indices__.is_some() { - return Err(serde::de::Error::duplicate_field("indices")); + GeneratedField::ParquetScan => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("parquetScan")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::ParquetScan) +; + } + GeneratedField::CsvScan => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("csvScan")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::CsvScan) +; + } + GeneratedField::Empty => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("empty")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Empty) +; + } + GeneratedField::Projection => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("projection")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Projection) +; + } + GeneratedField::GlobalLimit => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("globalLimit")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::GlobalLimit) +; + } + GeneratedField::LocalLimit => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("localLimit")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::LocalLimit) +; + } + GeneratedField::Aggregate => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("aggregate")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Aggregate) +; + } + GeneratedField::HashJoin => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("hashJoin")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::HashJoin) +; + } + GeneratedField::Sort => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("sort")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Sort) +; + } + GeneratedField::CoalesceBatches => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("coalesceBatches")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::CoalesceBatches) +; + } + GeneratedField::Filter => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("filter")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Filter) +; + } + GeneratedField::Merge => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("merge")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Merge) +; + } + GeneratedField::Repartition => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("repartition")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Repartition) +; + } + GeneratedField::Window => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("window")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Window) +; + } + GeneratedField::CrossJoin => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("crossJoin")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::CrossJoin) +; + } + GeneratedField::AvroScan => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("avroScan")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::AvroScan) +; + } + GeneratedField::Extension => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("extension")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Extension) +; + } + GeneratedField::Union => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("union")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Union) +; + } + GeneratedField::Explain => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("explain")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Explain) +; + } + GeneratedField::SortPreservingMerge => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("sortPreservingMerge")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::SortPreservingMerge) +; + } + GeneratedField::NestedLoopJoin => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("nestedLoopJoin")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::NestedLoopJoin) +; + } + GeneratedField::Analyze => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("analyze")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Analyze) +; + } + GeneratedField::JsonSink => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("jsonSink")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::JsonSink) +; + } + GeneratedField::SymmetricHashJoin => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("symmetricHashJoin")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::SymmetricHashJoin) +; + } + GeneratedField::Interleave => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("interleave")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Interleave) +; + } + GeneratedField::PlaceholderRow => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("placeholderRow")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::PlaceholderRow) +; + } + GeneratedField::CsvSink => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("csvSink")); } - indices__ = - Some(map_.next_value::>>()? - .into_iter().map(|x| x.0).collect()) - ; + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::CsvSink) +; } - } - } - Ok(PrimaryKeyConstraint { - indices: indices__.unwrap_or_default(), - }) - } - } - deserializer.deserialize_struct("datafusion.PrimaryKeyConstraint", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for ProjectionColumns { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if !self.columns.is_empty() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.ProjectionColumns", len)?; - if !self.columns.is_empty() { - struct_ser.serialize_field("columns", &self.columns)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for ProjectionColumns { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "columns", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Columns, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "columns" => Ok(GeneratedField::Columns), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + GeneratedField::ParquetSink => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("parquetSink")); + } + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::ParquetSink) +; } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ProjectionColumns; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ProjectionColumns") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut columns__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Columns => { - if columns__.is_some() { - return Err(serde::de::Error::duplicate_field("columns")); + GeneratedField::Unnest => { + if physical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("unnest")); } - columns__ = Some(map_.next_value()?); + physical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::Unnest) +; } } } - Ok(ProjectionColumns { - columns: columns__.unwrap_or_default(), + Ok(PhysicalPlanNode { + physical_plan_type: physical_plan_type__, }) } } - deserializer.deserialize_struct("datafusion.ProjectionColumns", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalPlanNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ProjectionExecNode { +impl serde::Serialize for PhysicalScalarUdfNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -21784,46 +15635,57 @@ impl serde::Serialize for ProjectionExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.input.is_some() { + if !self.name.is_empty() { len += 1; } - if !self.expr.is_empty() { + if !self.args.is_empty() { len += 1; } - if !self.expr_name.is_empty() { + if self.fun_definition.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ProjectionExecNode", len)?; - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; + if self.return_type.is_some() { + len += 1; } - if !self.expr.is_empty() { - struct_ser.serialize_field("expr", &self.expr)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalScalarUdfNode", len)?; + if !self.name.is_empty() { + struct_ser.serialize_field("name", &self.name)?; } - if !self.expr_name.is_empty() { - struct_ser.serialize_field("exprName", &self.expr_name)?; + if !self.args.is_empty() { + struct_ser.serialize_field("args", &self.args)?; + } + if let Some(v) = self.fun_definition.as_ref() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("funDefinition", pbjson::private::base64::encode(&v).as_str())?; + } + if let Some(v) = self.return_type.as_ref() { + struct_ser.serialize_field("returnType", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ProjectionExecNode { +impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "input", - "expr", - "expr_name", - "exprName", + "name", + "args", + "fun_definition", + "funDefinition", + "return_type", + "returnType", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Input, - Expr, - ExprName, + Name, + Args, + FunDefinition, + ReturnType, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -21845,9 +15707,10 @@ impl<'de> serde::Deserialize<'de> for ProjectionExecNode { E: serde::de::Error, { match value { - "input" => Ok(GeneratedField::Input), - "expr" => Ok(GeneratedField::Expr), - "exprName" | "expr_name" => Ok(GeneratedField::ExprName), + "name" => Ok(GeneratedField::Name), + "args" => Ok(GeneratedField::Args), + "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), + "returnType" | "return_type" => Ok(GeneratedField::ReturnType), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -21857,52 +15720,62 @@ impl<'de> serde::Deserialize<'de> for ProjectionExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ProjectionExecNode; + type Value = PhysicalScalarUdfNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ProjectionExecNode") + formatter.write_str("struct datafusion.PhysicalScalarUdfNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut input__ = None; - let mut expr__ = None; - let mut expr_name__ = None; + let mut name__ = None; + let mut args__ = None; + let mut fun_definition__ = None; + let mut return_type__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); } - input__ = map_.next_value()?; + name__ = Some(map_.next_value()?); } - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Args => { + if args__.is_some() { + return Err(serde::de::Error::duplicate_field("args")); } - expr__ = Some(map_.next_value()?); + args__ = Some(map_.next_value()?); } - GeneratedField::ExprName => { - if expr_name__.is_some() { - return Err(serde::de::Error::duplicate_field("exprName")); + GeneratedField::FunDefinition => { + if fun_definition__.is_some() { + return Err(serde::de::Error::duplicate_field("funDefinition")); } - expr_name__ = Some(map_.next_value()?); + fun_definition__ = + map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) + ; + } + GeneratedField::ReturnType => { + if return_type__.is_some() { + return Err(serde::de::Error::duplicate_field("returnType")); + } + return_type__ = map_.next_value()?; } } } - Ok(ProjectionExecNode { - input: input__, - expr: expr__.unwrap_or_default(), - expr_name: expr_name__.unwrap_or_default(), + Ok(PhysicalScalarUdfNode { + name: name__.unwrap_or_default(), + args: args__.unwrap_or_default(), + fun_definition: fun_definition__, + return_type: return_type__, }) } } - deserializer.deserialize_struct("datafusion.ProjectionExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalScalarUdfNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ProjectionNode { +impl serde::Serialize for PhysicalSortExprNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -21910,49 +15783,46 @@ impl serde::Serialize for ProjectionNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.input.is_some() { + if self.expr.is_some() { len += 1; } - if !self.expr.is_empty() { + if self.asc { len += 1; } - if self.optional_alias.is_some() { + if self.nulls_first { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ProjectionNode", len)?; - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalSortExprNode", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } - if !self.expr.is_empty() { - struct_ser.serialize_field("expr", &self.expr)?; + if self.asc { + struct_ser.serialize_field("asc", &self.asc)?; } - if let Some(v) = self.optional_alias.as_ref() { - match v { - projection_node::OptionalAlias::Alias(v) => { - struct_ser.serialize_field("alias", v)?; - } - } + if self.nulls_first { + struct_ser.serialize_field("nullsFirst", &self.nulls_first)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ProjectionNode { +impl<'de> serde::Deserialize<'de> for PhysicalSortExprNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "input", "expr", - "alias", + "asc", + "nulls_first", + "nullsFirst", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Input, Expr, - Alias, + Asc, + NullsFirst, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -21974,9 +15844,9 @@ impl<'de> serde::Deserialize<'de> for ProjectionNode { E: serde::de::Error, { match value { - "input" => Ok(GeneratedField::Input), "expr" => Ok(GeneratedField::Expr), - "alias" => Ok(GeneratedField::Alias), + "asc" => Ok(GeneratedField::Asc), + "nullsFirst" | "nulls_first" => Ok(GeneratedField::NullsFirst), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -21986,52 +15856,52 @@ impl<'de> serde::Deserialize<'de> for ProjectionNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ProjectionNode; + type Value = PhysicalSortExprNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ProjectionNode") + formatter.write_str("struct datafusion.PhysicalSortExprNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut input__ = None; let mut expr__ = None; - let mut optional_alias__ = None; + let mut asc__ = None; + let mut nulls_first__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); - } - input__ = map_.next_value()?; - } GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = Some(map_.next_value()?); + expr__ = map_.next_value()?; } - GeneratedField::Alias => { - if optional_alias__.is_some() { - return Err(serde::de::Error::duplicate_field("alias")); + GeneratedField::Asc => { + if asc__.is_some() { + return Err(serde::de::Error::duplicate_field("asc")); } - optional_alias__ = map_.next_value::<::std::option::Option<_>>()?.map(projection_node::OptionalAlias::Alias); + asc__ = Some(map_.next_value()?); + } + GeneratedField::NullsFirst => { + if nulls_first__.is_some() { + return Err(serde::de::Error::duplicate_field("nullsFirst")); + } + nulls_first__ = Some(map_.next_value()?); } } } - Ok(ProjectionNode { - input: input__, - expr: expr__.unwrap_or_default(), - optional_alias: optional_alias__, + Ok(PhysicalSortExprNode { + expr: expr__, + asc: asc__.unwrap_or_default(), + nulls_first: nulls_first__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.ProjectionNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalSortExprNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for RepartitionExecNode { +impl serde::Serialize for PhysicalSortExprNodeCollection { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -22039,54 +15909,30 @@ impl serde::Serialize for RepartitionExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.input.is_some() { - len += 1; - } - if self.partition_method.is_some() { + if !self.physical_sort_expr_nodes.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.RepartitionExecNode", len)?; - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; - } - if let Some(v) = self.partition_method.as_ref() { - match v { - repartition_exec_node::PartitionMethod::RoundRobin(v) => { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("roundRobin", ToString::to_string(&v).as_str())?; - } - repartition_exec_node::PartitionMethod::Hash(v) => { - struct_ser.serialize_field("hash", v)?; - } - repartition_exec_node::PartitionMethod::Unknown(v) => { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("unknown", ToString::to_string(&v).as_str())?; - } - } + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalSortExprNodeCollection", len)?; + if !self.physical_sort_expr_nodes.is_empty() { + struct_ser.serialize_field("physicalSortExprNodes", &self.physical_sort_expr_nodes)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for RepartitionExecNode { +impl<'de> serde::Deserialize<'de> for PhysicalSortExprNodeCollection { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "input", - "round_robin", - "roundRobin", - "hash", - "unknown", + "physical_sort_expr_nodes", + "physicalSortExprNodes", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Input, - RoundRobin, - Hash, - Unknown, + PhysicalSortExprNodes, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -22108,10 +15954,7 @@ impl<'de> serde::Deserialize<'de> for RepartitionExecNode { E: serde::de::Error, { match value { - "input" => Ok(GeneratedField::Input), - "roundRobin" | "round_robin" => Ok(GeneratedField::RoundRobin), - "hash" => Ok(GeneratedField::Hash), - "unknown" => Ok(GeneratedField::Unknown), + "physicalSortExprNodes" | "physical_sort_expr_nodes" => Ok(GeneratedField::PhysicalSortExprNodes), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -22121,57 +15964,36 @@ impl<'de> serde::Deserialize<'de> for RepartitionExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = RepartitionExecNode; + type Value = PhysicalSortExprNodeCollection; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.RepartitionExecNode") + formatter.write_str("struct datafusion.PhysicalSortExprNodeCollection") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut input__ = None; - let mut partition_method__ = None; + let mut physical_sort_expr_nodes__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); - } - input__ = map_.next_value()?; - } - GeneratedField::RoundRobin => { - if partition_method__.is_some() { - return Err(serde::de::Error::duplicate_field("roundRobin")); - } - partition_method__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| repartition_exec_node::PartitionMethod::RoundRobin(x.0)); - } - GeneratedField::Hash => { - if partition_method__.is_some() { - return Err(serde::de::Error::duplicate_field("hash")); - } - partition_method__ = map_.next_value::<::std::option::Option<_>>()?.map(repartition_exec_node::PartitionMethod::Hash) -; - } - GeneratedField::Unknown => { - if partition_method__.is_some() { - return Err(serde::de::Error::duplicate_field("unknown")); + GeneratedField::PhysicalSortExprNodes => { + if physical_sort_expr_nodes__.is_some() { + return Err(serde::de::Error::duplicate_field("physicalSortExprNodes")); } - partition_method__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| repartition_exec_node::PartitionMethod::Unknown(x.0)); + physical_sort_expr_nodes__ = Some(map_.next_value()?); } } } - Ok(RepartitionExecNode { - input: input__, - partition_method: partition_method__, + Ok(PhysicalSortExprNodeCollection { + physical_sort_expr_nodes: physical_sort_expr_nodes__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.RepartitionExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalSortExprNodeCollection", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for RepartitionNode { +impl serde::Serialize for PhysicalTryCastNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -22179,48 +16001,38 @@ impl serde::Serialize for RepartitionNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.input.is_some() { + if self.expr.is_some() { len += 1; } - if self.partition_method.is_some() { + if self.arrow_type.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.RepartitionNode", len)?; - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalTryCastNode", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } - if let Some(v) = self.partition_method.as_ref() { - match v { - repartition_node::PartitionMethod::RoundRobin(v) => { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("roundRobin", ToString::to_string(&v).as_str())?; - } - repartition_node::PartitionMethod::Hash(v) => { - struct_ser.serialize_field("hash", v)?; - } - } + if let Some(v) = self.arrow_type.as_ref() { + struct_ser.serialize_field("arrowType", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for RepartitionNode { +impl<'de> serde::Deserialize<'de> for PhysicalTryCastNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "input", - "round_robin", - "roundRobin", - "hash", + "expr", + "arrow_type", + "arrowType", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Input, - RoundRobin, - Hash, + Expr, + ArrowType, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -22242,9 +16054,8 @@ impl<'de> serde::Deserialize<'de> for RepartitionNode { E: serde::de::Error, { match value { - "input" => Ok(GeneratedField::Input), - "roundRobin" | "round_robin" => Ok(GeneratedField::RoundRobin), - "hash" => Ok(GeneratedField::Hash), + "expr" => Ok(GeneratedField::Expr), + "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -22254,51 +16065,44 @@ impl<'de> serde::Deserialize<'de> for RepartitionNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = RepartitionNode; + type Value = PhysicalTryCastNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.RepartitionNode") + formatter.write_str("struct datafusion.PhysicalTryCastNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut input__ = None; - let mut partition_method__ = None; + let mut expr__ = None; + let mut arrow_type__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); - } - input__ = map_.next_value()?; - } - GeneratedField::RoundRobin => { - if partition_method__.is_some() { - return Err(serde::de::Error::duplicate_field("roundRobin")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - partition_method__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| repartition_node::PartitionMethod::RoundRobin(x.0)); + expr__ = map_.next_value()?; } - GeneratedField::Hash => { - if partition_method__.is_some() { - return Err(serde::de::Error::duplicate_field("hash")); + GeneratedField::ArrowType => { + if arrow_type__.is_some() { + return Err(serde::de::Error::duplicate_field("arrowType")); } - partition_method__ = map_.next_value::<::std::option::Option<_>>()?.map(repartition_node::PartitionMethod::Hash) -; + arrow_type__ = map_.next_value()?; } } } - Ok(RepartitionNode { - input: input__, - partition_method: partition_method__, + Ok(PhysicalTryCastNode { + expr: expr__, + arrow_type: arrow_type__, }) } } - deserializer.deserialize_struct("datafusion.RepartitionNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalTryCastNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for RollupNode { +impl serde::Serialize for PhysicalWhenThen { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -22306,29 +16110,39 @@ impl serde::Serialize for RollupNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.expr.is_empty() { + if self.when_expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.RollupNode", len)?; - if !self.expr.is_empty() { - struct_ser.serialize_field("expr", &self.expr)?; + if self.then_expr.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalWhenThen", len)?; + if let Some(v) = self.when_expr.as_ref() { + struct_ser.serialize_field("whenExpr", v)?; + } + if let Some(v) = self.then_expr.as_ref() { + struct_ser.serialize_field("thenExpr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for RollupNode { +impl<'de> serde::Deserialize<'de> for PhysicalWhenThen { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "when_expr", + "whenExpr", + "then_expr", + "thenExpr", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, + WhenExpr, + ThenExpr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -22350,7 +16164,8 @@ impl<'de> serde::Deserialize<'de> for RollupNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), + "whenExpr" | "when_expr" => Ok(GeneratedField::WhenExpr), + "thenExpr" | "then_expr" => Ok(GeneratedField::ThenExpr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -22360,36 +16175,44 @@ impl<'de> serde::Deserialize<'de> for RollupNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = RollupNode; + type Value = PhysicalWhenThen; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.RollupNode") + formatter.write_str("struct datafusion.PhysicalWhenThen") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; + let mut when_expr__ = None; + let mut then_expr__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::WhenExpr => { + if when_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("whenExpr")); } - expr__ = Some(map_.next_value()?); + when_expr__ = map_.next_value()?; + } + GeneratedField::ThenExpr => { + if then_expr__.is_some() { + return Err(serde::de::Error::duplicate_field("thenExpr")); + } + then_expr__ = map_.next_value()?; } } } - Ok(RollupNode { - expr: expr__.unwrap_or_default(), + Ok(PhysicalWhenThen { + when_expr: when_expr__, + then_expr: then_expr__, }) } } - deserializer.deserialize_struct("datafusion.RollupNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalWhenThen", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ScalarDictionaryValue { +impl serde::Serialize for PhysicalWindowExprNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -22397,38 +16220,96 @@ impl serde::Serialize for ScalarDictionaryValue { { use serde::ser::SerializeStruct; let mut len = 0; - if self.index_type.is_some() { + if !self.args.is_empty() { + len += 1; + } + if !self.partition_by.is_empty() { + len += 1; + } + if !self.order_by.is_empty() { + len += 1; + } + if self.window_frame.is_some() { + len += 1; + } + if !self.name.is_empty() { + len += 1; + } + if self.fun_definition.is_some() { len += 1; } - if self.value.is_some() { + if self.window_function.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ScalarDictionaryValue", len)?; - if let Some(v) = self.index_type.as_ref() { - struct_ser.serialize_field("indexType", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalWindowExprNode", len)?; + if !self.args.is_empty() { + struct_ser.serialize_field("args", &self.args)?; + } + if !self.partition_by.is_empty() { + struct_ser.serialize_field("partitionBy", &self.partition_by)?; + } + if !self.order_by.is_empty() { + struct_ser.serialize_field("orderBy", &self.order_by)?; + } + if let Some(v) = self.window_frame.as_ref() { + struct_ser.serialize_field("windowFrame", v)?; + } + if !self.name.is_empty() { + struct_ser.serialize_field("name", &self.name)?; + } + if let Some(v) = self.fun_definition.as_ref() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("funDefinition", pbjson::private::base64::encode(&v).as_str())?; } - if let Some(v) = self.value.as_ref() { - struct_ser.serialize_field("value", v)?; + if let Some(v) = self.window_function.as_ref() { + match v { + physical_window_expr_node::WindowFunction::BuiltInFunction(v) => { + let v = BuiltInWindowFunction::try_from(*v) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; + struct_ser.serialize_field("builtInFunction", &v)?; + } + physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(v) => { + struct_ser.serialize_field("userDefinedAggrFunction", v)?; + } + } } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ScalarDictionaryValue { +impl<'de> serde::Deserialize<'de> for PhysicalWindowExprNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "index_type", - "indexType", - "value", + "args", + "partition_by", + "partitionBy", + "order_by", + "orderBy", + "window_frame", + "windowFrame", + "name", + "fun_definition", + "funDefinition", + "built_in_function", + "builtInFunction", + "user_defined_aggr_function", + "userDefinedAggrFunction", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - IndexType, - Value, + Args, + PartitionBy, + OrderBy, + WindowFrame, + Name, + FunDefinition, + BuiltInFunction, + UserDefinedAggrFunction, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -22450,8 +16331,14 @@ impl<'de> serde::Deserialize<'de> for ScalarDictionaryValue { E: serde::de::Error, { match value { - "indexType" | "index_type" => Ok(GeneratedField::IndexType), - "value" => Ok(GeneratedField::Value), + "args" => Ok(GeneratedField::Args), + "partitionBy" | "partition_by" => Ok(GeneratedField::PartitionBy), + "orderBy" | "order_by" => Ok(GeneratedField::OrderBy), + "windowFrame" | "window_frame" => Ok(GeneratedField::WindowFrame), + "name" => Ok(GeneratedField::Name), + "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), + "builtInFunction" | "built_in_function" => Ok(GeneratedField::BuiltInFunction), + "userDefinedAggrFunction" | "user_defined_aggr_function" => Ok(GeneratedField::UserDefinedAggrFunction), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -22461,44 +16348,92 @@ impl<'de> serde::Deserialize<'de> for ScalarDictionaryValue { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ScalarDictionaryValue; + type Value = PhysicalWindowExprNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ScalarDictionaryValue") + formatter.write_str("struct datafusion.PhysicalWindowExprNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut index_type__ = None; - let mut value__ = None; + let mut args__ = None; + let mut partition_by__ = None; + let mut order_by__ = None; + let mut window_frame__ = None; + let mut name__ = None; + let mut fun_definition__ = None; + let mut window_function__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::IndexType => { - if index_type__.is_some() { - return Err(serde::de::Error::duplicate_field("indexType")); + GeneratedField::Args => { + if args__.is_some() { + return Err(serde::de::Error::duplicate_field("args")); + } + args__ = Some(map_.next_value()?); + } + GeneratedField::PartitionBy => { + if partition_by__.is_some() { + return Err(serde::de::Error::duplicate_field("partitionBy")); + } + partition_by__ = Some(map_.next_value()?); + } + GeneratedField::OrderBy => { + if order_by__.is_some() { + return Err(serde::de::Error::duplicate_field("orderBy")); + } + order_by__ = Some(map_.next_value()?); + } + GeneratedField::WindowFrame => { + if window_frame__.is_some() { + return Err(serde::de::Error::duplicate_field("windowFrame")); + } + window_frame__ = map_.next_value()?; + } + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); + } + name__ = Some(map_.next_value()?); + } + GeneratedField::FunDefinition => { + if fun_definition__.is_some() { + return Err(serde::de::Error::duplicate_field("funDefinition")); + } + fun_definition__ = + map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) + ; + } + GeneratedField::BuiltInFunction => { + if window_function__.is_some() { + return Err(serde::de::Error::duplicate_field("builtInFunction")); } - index_type__ = map_.next_value()?; + window_function__ = map_.next_value::<::std::option::Option>()?.map(|x| physical_window_expr_node::WindowFunction::BuiltInFunction(x as i32)); } - GeneratedField::Value => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("value")); + GeneratedField::UserDefinedAggrFunction => { + if window_function__.is_some() { + return Err(serde::de::Error::duplicate_field("userDefinedAggrFunction")); } - value__ = map_.next_value()?; + window_function__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_window_expr_node::WindowFunction::UserDefinedAggrFunction); } } } - Ok(ScalarDictionaryValue { - index_type: index_type__, - value: value__, + Ok(PhysicalWindowExprNode { + args: args__.unwrap_or_default(), + partition_by: partition_by__.unwrap_or_default(), + order_by: order_by__.unwrap_or_default(), + window_frame: window_frame__, + name: name__.unwrap_or_default(), + fun_definition: fun_definition__, + window_function: window_function__, }) } } - deserializer.deserialize_struct("datafusion.ScalarDictionaryValue", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PhysicalWindowExprNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ScalarFixedSizeBinary { +impl serde::Serialize for PlaceholderNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -22506,38 +16441,38 @@ impl serde::Serialize for ScalarFixedSizeBinary { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.values.is_empty() { + if !self.id.is_empty() { len += 1; } - if self.length != 0 { + if self.data_type.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ScalarFixedSizeBinary", len)?; - if !self.values.is_empty() { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("values", pbjson::private::base64::encode(&self.values).as_str())?; + let mut struct_ser = serializer.serialize_struct("datafusion.PlaceholderNode", len)?; + if !self.id.is_empty() { + struct_ser.serialize_field("id", &self.id)?; } - if self.length != 0 { - struct_ser.serialize_field("length", &self.length)?; + if let Some(v) = self.data_type.as_ref() { + struct_ser.serialize_field("dataType", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ScalarFixedSizeBinary { +impl<'de> serde::Deserialize<'de> for PlaceholderNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "values", - "length", + "id", + "data_type", + "dataType", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Values, - Length, + Id, + DataType, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -22559,8 +16494,8 @@ impl<'de> serde::Deserialize<'de> for ScalarFixedSizeBinary { E: serde::de::Error, { match value { - "values" => Ok(GeneratedField::Values), - "length" => Ok(GeneratedField::Length), + "id" => Ok(GeneratedField::Id), + "dataType" | "data_type" => Ok(GeneratedField::DataType), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -22570,48 +16505,44 @@ impl<'de> serde::Deserialize<'de> for ScalarFixedSizeBinary { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ScalarFixedSizeBinary; + type Value = PlaceholderNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ScalarFixedSizeBinary") + formatter.write_str("struct datafusion.PlaceholderNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut values__ = None; - let mut length__ = None; + let mut id__ = None; + let mut data_type__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Values => { - if values__.is_some() { - return Err(serde::de::Error::duplicate_field("values")); + GeneratedField::Id => { + if id__.is_some() { + return Err(serde::de::Error::duplicate_field("id")); } - values__ = - Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) - ; + id__ = Some(map_.next_value()?); } - GeneratedField::Length => { - if length__.is_some() { - return Err(serde::de::Error::duplicate_field("length")); + GeneratedField::DataType => { + if data_type__.is_some() { + return Err(serde::de::Error::duplicate_field("dataType")); } - length__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + data_type__ = map_.next_value()?; } } } - Ok(ScalarFixedSizeBinary { - values: values__.unwrap_or_default(), - length: length__.unwrap_or_default(), + Ok(PlaceholderNode { + id: id__.unwrap_or_default(), + data_type: data_type__, }) } } - deserializer.deserialize_struct("datafusion.ScalarFixedSizeBinary", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PlaceholderNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ScalarNestedValue { +impl serde::Serialize for PlaceholderRowExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -22619,48 +16550,28 @@ impl serde::Serialize for ScalarNestedValue { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.ipc_message.is_empty() { - len += 1; - } - if !self.arrow_data.is_empty() { - len += 1; - } if self.schema.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ScalarNestedValue", len)?; - if !self.ipc_message.is_empty() { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("ipcMessage", pbjson::private::base64::encode(&self.ipc_message).as_str())?; - } - if !self.arrow_data.is_empty() { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("arrowData", pbjson::private::base64::encode(&self.arrow_data).as_str())?; - } + let mut struct_ser = serializer.serialize_struct("datafusion.PlaceholderRowExecNode", len)?; if let Some(v) = self.schema.as_ref() { struct_ser.serialize_field("schema", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ScalarNestedValue { +impl<'de> serde::Deserialize<'de> for PlaceholderRowExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "ipc_message", - "ipcMessage", - "arrow_data", - "arrowData", "schema", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - IpcMessage, - ArrowData, Schema, } impl<'de> serde::Deserialize<'de> for GeneratedField { @@ -22683,8 +16594,6 @@ impl<'de> serde::Deserialize<'de> for ScalarNestedValue { E: serde::de::Error, { match value { - "ipcMessage" | "ipc_message" => Ok(GeneratedField::IpcMessage), - "arrowData" | "arrow_data" => Ok(GeneratedField::ArrowData), "schema" => Ok(GeneratedField::Schema), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } @@ -22695,37 +16604,19 @@ impl<'de> serde::Deserialize<'de> for ScalarNestedValue { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ScalarNestedValue; + type Value = PlaceholderRowExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ScalarNestedValue") + formatter.write_str("struct datafusion.PlaceholderRowExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut ipc_message__ = None; - let mut arrow_data__ = None; let mut schema__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::IpcMessage => { - if ipc_message__.is_some() { - return Err(serde::de::Error::duplicate_field("ipcMessage")); - } - ipc_message__ = - Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) - ; - } - GeneratedField::ArrowData => { - if arrow_data__.is_some() { - return Err(serde::de::Error::duplicate_field("arrowData")); - } - arrow_data__ = - Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) - ; - } GeneratedField::Schema => { if schema__.is_some() { return Err(serde::de::Error::duplicate_field("schema")); @@ -22734,17 +16625,15 @@ impl<'de> serde::Deserialize<'de> for ScalarNestedValue { } } } - Ok(ScalarNestedValue { - ipc_message: ipc_message__.unwrap_or_default(), - arrow_data: arrow_data__.unwrap_or_default(), + Ok(PlaceholderRowExecNode { schema: schema__, }) } } - deserializer.deserialize_struct("datafusion.ScalarNestedValue", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PlaceholderRowExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ScalarTime32Value { +impl serde::Serialize for PlanType { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -22752,40 +16641,88 @@ impl serde::Serialize for ScalarTime32Value { { use serde::ser::SerializeStruct; let mut len = 0; - if self.value.is_some() { + if self.plan_type_enum.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ScalarTime32Value", len)?; - if let Some(v) = self.value.as_ref() { + let mut struct_ser = serializer.serialize_struct("datafusion.PlanType", len)?; + if let Some(v) = self.plan_type_enum.as_ref() { match v { - scalar_time32_value::Value::Time32SecondValue(v) => { - struct_ser.serialize_field("time32SecondValue", v)?; + plan_type::PlanTypeEnum::InitialLogicalPlan(v) => { + struct_ser.serialize_field("InitialLogicalPlan", v)?; + } + plan_type::PlanTypeEnum::AnalyzedLogicalPlan(v) => { + struct_ser.serialize_field("AnalyzedLogicalPlan", v)?; + } + plan_type::PlanTypeEnum::FinalAnalyzedLogicalPlan(v) => { + struct_ser.serialize_field("FinalAnalyzedLogicalPlan", v)?; + } + plan_type::PlanTypeEnum::OptimizedLogicalPlan(v) => { + struct_ser.serialize_field("OptimizedLogicalPlan", v)?; + } + plan_type::PlanTypeEnum::FinalLogicalPlan(v) => { + struct_ser.serialize_field("FinalLogicalPlan", v)?; + } + plan_type::PlanTypeEnum::InitialPhysicalPlan(v) => { + struct_ser.serialize_field("InitialPhysicalPlan", v)?; + } + plan_type::PlanTypeEnum::InitialPhysicalPlanWithStats(v) => { + struct_ser.serialize_field("InitialPhysicalPlanWithStats", v)?; + } + plan_type::PlanTypeEnum::InitialPhysicalPlanWithSchema(v) => { + struct_ser.serialize_field("InitialPhysicalPlanWithSchema", v)?; + } + plan_type::PlanTypeEnum::OptimizedPhysicalPlan(v) => { + struct_ser.serialize_field("OptimizedPhysicalPlan", v)?; } - scalar_time32_value::Value::Time32MillisecondValue(v) => { - struct_ser.serialize_field("time32MillisecondValue", v)?; + plan_type::PlanTypeEnum::FinalPhysicalPlan(v) => { + struct_ser.serialize_field("FinalPhysicalPlan", v)?; + } + plan_type::PlanTypeEnum::FinalPhysicalPlanWithStats(v) => { + struct_ser.serialize_field("FinalPhysicalPlanWithStats", v)?; + } + plan_type::PlanTypeEnum::FinalPhysicalPlanWithSchema(v) => { + struct_ser.serialize_field("FinalPhysicalPlanWithSchema", v)?; } } } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ScalarTime32Value { +impl<'de> serde::Deserialize<'de> for PlanType { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "time32_second_value", - "time32SecondValue", - "time32_millisecond_value", - "time32MillisecondValue", + "InitialLogicalPlan", + "AnalyzedLogicalPlan", + "FinalAnalyzedLogicalPlan", + "OptimizedLogicalPlan", + "FinalLogicalPlan", + "InitialPhysicalPlan", + "InitialPhysicalPlanWithStats", + "InitialPhysicalPlanWithSchema", + "OptimizedPhysicalPlan", + "FinalPhysicalPlan", + "FinalPhysicalPlanWithStats", + "FinalPhysicalPlanWithSchema", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Time32SecondValue, - Time32MillisecondValue, + InitialLogicalPlan, + AnalyzedLogicalPlan, + FinalAnalyzedLogicalPlan, + OptimizedLogicalPlan, + FinalLogicalPlan, + InitialPhysicalPlan, + InitialPhysicalPlanWithStats, + InitialPhysicalPlanWithSchema, + OptimizedPhysicalPlan, + FinalPhysicalPlan, + FinalPhysicalPlanWithStats, + FinalPhysicalPlanWithSchema, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -22807,8 +16744,18 @@ impl<'de> serde::Deserialize<'de> for ScalarTime32Value { E: serde::de::Error, { match value { - "time32SecondValue" | "time32_second_value" => Ok(GeneratedField::Time32SecondValue), - "time32MillisecondValue" | "time32_millisecond_value" => Ok(GeneratedField::Time32MillisecondValue), + "InitialLogicalPlan" => Ok(GeneratedField::InitialLogicalPlan), + "AnalyzedLogicalPlan" => Ok(GeneratedField::AnalyzedLogicalPlan), + "FinalAnalyzedLogicalPlan" => Ok(GeneratedField::FinalAnalyzedLogicalPlan), + "OptimizedLogicalPlan" => Ok(GeneratedField::OptimizedLogicalPlan), + "FinalLogicalPlan" => Ok(GeneratedField::FinalLogicalPlan), + "InitialPhysicalPlan" => Ok(GeneratedField::InitialPhysicalPlan), + "InitialPhysicalPlanWithStats" => Ok(GeneratedField::InitialPhysicalPlanWithStats), + "InitialPhysicalPlanWithSchema" => Ok(GeneratedField::InitialPhysicalPlanWithSchema), + "OptimizedPhysicalPlan" => Ok(GeneratedField::OptimizedPhysicalPlan), + "FinalPhysicalPlan" => Ok(GeneratedField::FinalPhysicalPlan), + "FinalPhysicalPlanWithStats" => Ok(GeneratedField::FinalPhysicalPlanWithStats), + "FinalPhysicalPlanWithSchema" => Ok(GeneratedField::FinalPhysicalPlanWithSchema), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -22818,42 +16765,114 @@ impl<'de> serde::Deserialize<'de> for ScalarTime32Value { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ScalarTime32Value; + type Value = PlanType; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ScalarTime32Value") + formatter.write_str("struct datafusion.PlanType") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut value__ = None; + let mut plan_type_enum__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Time32SecondValue => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("time32SecondValue")); + GeneratedField::InitialLogicalPlan => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("InitialLogicalPlan")); + } + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::InitialLogicalPlan) +; + } + GeneratedField::AnalyzedLogicalPlan => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("AnalyzedLogicalPlan")); + } + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::AnalyzedLogicalPlan) +; + } + GeneratedField::FinalAnalyzedLogicalPlan => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("FinalAnalyzedLogicalPlan")); + } + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::FinalAnalyzedLogicalPlan) +; + } + GeneratedField::OptimizedLogicalPlan => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("OptimizedLogicalPlan")); + } + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::OptimizedLogicalPlan) +; + } + GeneratedField::FinalLogicalPlan => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("FinalLogicalPlan")); + } + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::FinalLogicalPlan) +; + } + GeneratedField::InitialPhysicalPlan => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("InitialPhysicalPlan")); + } + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::InitialPhysicalPlan) +; + } + GeneratedField::InitialPhysicalPlanWithStats => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("InitialPhysicalPlanWithStats")); + } + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::InitialPhysicalPlanWithStats) +; + } + GeneratedField::InitialPhysicalPlanWithSchema => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("InitialPhysicalPlanWithSchema")); + } + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::InitialPhysicalPlanWithSchema) +; + } + GeneratedField::OptimizedPhysicalPlan => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("OptimizedPhysicalPlan")); + } + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::OptimizedPhysicalPlan) +; + } + GeneratedField::FinalPhysicalPlan => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("FinalPhysicalPlan")); } - value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_time32_value::Value::Time32SecondValue(x.0)); + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::FinalPhysicalPlan) +; + } + GeneratedField::FinalPhysicalPlanWithStats => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("FinalPhysicalPlanWithStats")); + } + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::FinalPhysicalPlanWithStats) +; } - GeneratedField::Time32MillisecondValue => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("time32MillisecondValue")); + GeneratedField::FinalPhysicalPlanWithSchema => { + if plan_type_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("FinalPhysicalPlanWithSchema")); } - value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_time32_value::Value::Time32MillisecondValue(x.0)); + plan_type_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(plan_type::PlanTypeEnum::FinalPhysicalPlanWithSchema) +; } } } - Ok(ScalarTime32Value { - value: value__, + Ok(PlanType { + plan_type_enum: plan_type_enum__, }) } } - deserializer.deserialize_struct("datafusion.ScalarTime32Value", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PlanType", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ScalarTime64Value { +impl serde::Serialize for PrepareNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -22861,42 +16880,46 @@ impl serde::Serialize for ScalarTime64Value { { use serde::ser::SerializeStruct; let mut len = 0; - if self.value.is_some() { + if !self.name.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ScalarTime64Value", len)?; - if let Some(v) = self.value.as_ref() { - match v { - scalar_time64_value::Value::Time64MicrosecondValue(v) => { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("time64MicrosecondValue", ToString::to_string(&v).as_str())?; - } - scalar_time64_value::Value::Time64NanosecondValue(v) => { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("time64NanosecondValue", ToString::to_string(&v).as_str())?; - } - } + if !self.data_types.is_empty() { + len += 1; + } + if self.input.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PrepareNode", len)?; + if !self.name.is_empty() { + struct_ser.serialize_field("name", &self.name)?; + } + if !self.data_types.is_empty() { + struct_ser.serialize_field("dataTypes", &self.data_types)?; + } + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ScalarTime64Value { +impl<'de> serde::Deserialize<'de> for PrepareNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "time64_microsecond_value", - "time64MicrosecondValue", - "time64_nanosecond_value", - "time64NanosecondValue", + "name", + "data_types", + "dataTypes", + "input", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Time64MicrosecondValue, - Time64NanosecondValue, + Name, + DataTypes, + Input, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -22918,8 +16941,9 @@ impl<'de> serde::Deserialize<'de> for ScalarTime64Value { E: serde::de::Error, { match value { - "time64MicrosecondValue" | "time64_microsecond_value" => Ok(GeneratedField::Time64MicrosecondValue), - "time64NanosecondValue" | "time64_nanosecond_value" => Ok(GeneratedField::Time64NanosecondValue), + "name" => Ok(GeneratedField::Name), + "dataTypes" | "data_types" => Ok(GeneratedField::DataTypes), + "input" => Ok(GeneratedField::Input), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -22929,42 +16953,52 @@ impl<'de> serde::Deserialize<'de> for ScalarTime64Value { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ScalarTime64Value; + type Value = PrepareNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ScalarTime64Value") + formatter.write_str("struct datafusion.PrepareNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut value__ = None; + let mut name__ = None; + let mut data_types__ = None; + let mut input__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Time64MicrosecondValue => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("time64MicrosecondValue")); + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); + } + name__ = Some(map_.next_value()?); + } + GeneratedField::DataTypes => { + if data_types__.is_some() { + return Err(serde::de::Error::duplicate_field("dataTypes")); } - value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_time64_value::Value::Time64MicrosecondValue(x.0)); + data_types__ = Some(map_.next_value()?); } - GeneratedField::Time64NanosecondValue => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("time64NanosecondValue")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_time64_value::Value::Time64NanosecondValue(x.0)); + input__ = map_.next_value()?; } } } - Ok(ScalarTime64Value { - value: value__, + Ok(PrepareNode { + name: name__.unwrap_or_default(), + data_types: data_types__.unwrap_or_default(), + input: input__, }) } } - deserializer.deserialize_struct("datafusion.ScalarTime64Value", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.PrepareNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ScalarTimestampValue { +impl serde::Serialize for ProjectionColumns { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -22972,64 +17006,29 @@ impl serde::Serialize for ScalarTimestampValue { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.timezone.is_empty() { - len += 1; - } - if self.value.is_some() { + if !self.columns.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ScalarTimestampValue", len)?; - if !self.timezone.is_empty() { - struct_ser.serialize_field("timezone", &self.timezone)?; - } - if let Some(v) = self.value.as_ref() { - match v { - scalar_timestamp_value::Value::TimeMicrosecondValue(v) => { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("timeMicrosecondValue", ToString::to_string(&v).as_str())?; - } - scalar_timestamp_value::Value::TimeNanosecondValue(v) => { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("timeNanosecondValue", ToString::to_string(&v).as_str())?; - } - scalar_timestamp_value::Value::TimeSecondValue(v) => { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("timeSecondValue", ToString::to_string(&v).as_str())?; - } - scalar_timestamp_value::Value::TimeMillisecondValue(v) => { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("timeMillisecondValue", ToString::to_string(&v).as_str())?; - } - } + let mut struct_ser = serializer.serialize_struct("datafusion.ProjectionColumns", len)?; + if !self.columns.is_empty() { + struct_ser.serialize_field("columns", &self.columns)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ScalarTimestampValue { +impl<'de> serde::Deserialize<'de> for ProjectionColumns { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "timezone", - "time_microsecond_value", - "timeMicrosecondValue", - "time_nanosecond_value", - "timeNanosecondValue", - "time_second_value", - "timeSecondValue", - "time_millisecond_value", - "timeMillisecondValue", + "columns", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Timezone, - TimeMicrosecondValue, - TimeNanosecondValue, - TimeSecondValue, - TimeMillisecondValue, + Columns, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -23051,11 +17050,7 @@ impl<'de> serde::Deserialize<'de> for ScalarTimestampValue { E: serde::de::Error, { match value { - "timezone" => Ok(GeneratedField::Timezone), - "timeMicrosecondValue" | "time_microsecond_value" => Ok(GeneratedField::TimeMicrosecondValue), - "timeNanosecondValue" | "time_nanosecond_value" => Ok(GeneratedField::TimeNanosecondValue), - "timeSecondValue" | "time_second_value" => Ok(GeneratedField::TimeSecondValue), - "timeMillisecondValue" | "time_millisecond_value" => Ok(GeneratedField::TimeMillisecondValue), + "columns" => Ok(GeneratedField::Columns), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -23065,62 +17060,36 @@ impl<'de> serde::Deserialize<'de> for ScalarTimestampValue { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ScalarTimestampValue; + type Value = ProjectionColumns; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ScalarTimestampValue") + formatter.write_str("struct datafusion.ProjectionColumns") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut timezone__ = None; - let mut value__ = None; + let mut columns__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Timezone => { - if timezone__.is_some() { - return Err(serde::de::Error::duplicate_field("timezone")); - } - timezone__ = Some(map_.next_value()?); - } - GeneratedField::TimeMicrosecondValue => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("timeMicrosecondValue")); - } - value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_timestamp_value::Value::TimeMicrosecondValue(x.0)); - } - GeneratedField::TimeNanosecondValue => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("timeNanosecondValue")); - } - value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_timestamp_value::Value::TimeNanosecondValue(x.0)); - } - GeneratedField::TimeSecondValue => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("timeSecondValue")); - } - value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_timestamp_value::Value::TimeSecondValue(x.0)); - } - GeneratedField::TimeMillisecondValue => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("timeMillisecondValue")); + GeneratedField::Columns => { + if columns__.is_some() { + return Err(serde::de::Error::duplicate_field("columns")); } - value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_timestamp_value::Value::TimeMillisecondValue(x.0)); + columns__ = Some(map_.next_value()?); } } } - Ok(ScalarTimestampValue { - timezone: timezone__.unwrap_or_default(), - value: value__, + Ok(ProjectionColumns { + columns: columns__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.ScalarTimestampValue", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ProjectionColumns", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ScalarUdfExprNode { +impl serde::Serialize for ProjectionExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -23128,48 +17097,46 @@ impl serde::Serialize for ScalarUdfExprNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.fun_name.is_empty() { + if self.input.is_some() { len += 1; } - if !self.args.is_empty() { + if !self.expr.is_empty() { len += 1; } - if self.fun_definition.is_some() { + if !self.expr_name.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ScalarUDFExprNode", len)?; - if !self.fun_name.is_empty() { - struct_ser.serialize_field("funName", &self.fun_name)?; + let mut struct_ser = serializer.serialize_struct("datafusion.ProjectionExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; } - if !self.args.is_empty() { - struct_ser.serialize_field("args", &self.args)?; + if !self.expr.is_empty() { + struct_ser.serialize_field("expr", &self.expr)?; } - if let Some(v) = self.fun_definition.as_ref() { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("funDefinition", pbjson::private::base64::encode(&v).as_str())?; + if !self.expr_name.is_empty() { + struct_ser.serialize_field("exprName", &self.expr_name)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ScalarUdfExprNode { +impl<'de> serde::Deserialize<'de> for ProjectionExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "fun_name", - "funName", - "args", - "fun_definition", - "funDefinition", + "input", + "expr", + "expr_name", + "exprName", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - FunName, - Args, - FunDefinition, + Input, + Expr, + ExprName, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -23191,9 +17158,9 @@ impl<'de> serde::Deserialize<'de> for ScalarUdfExprNode { E: serde::de::Error, { match value { - "funName" | "fun_name" => Ok(GeneratedField::FunName), - "args" => Ok(GeneratedField::Args), - "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), + "input" => Ok(GeneratedField::Input), + "expr" => Ok(GeneratedField::Expr), + "exprName" | "expr_name" => Ok(GeneratedField::ExprName), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -23203,54 +17170,52 @@ impl<'de> serde::Deserialize<'de> for ScalarUdfExprNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ScalarUdfExprNode; + type Value = ProjectionExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ScalarUDFExprNode") + formatter.write_str("struct datafusion.ProjectionExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut fun_name__ = None; - let mut args__ = None; - let mut fun_definition__ = None; + let mut input__ = None; + let mut expr__ = None; + let mut expr_name__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::FunName => { - if fun_name__.is_some() { - return Err(serde::de::Error::duplicate_field("funName")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - fun_name__ = Some(map_.next_value()?); + input__ = map_.next_value()?; } - GeneratedField::Args => { - if args__.is_some() { - return Err(serde::de::Error::duplicate_field("args")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - args__ = Some(map_.next_value()?); + expr__ = Some(map_.next_value()?); } - GeneratedField::FunDefinition => { - if fun_definition__.is_some() { - return Err(serde::de::Error::duplicate_field("funDefinition")); + GeneratedField::ExprName => { + if expr_name__.is_some() { + return Err(serde::de::Error::duplicate_field("exprName")); } - fun_definition__ = - map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) - ; + expr_name__ = Some(map_.next_value()?); } } } - Ok(ScalarUdfExprNode { - fun_name: fun_name__.unwrap_or_default(), - args: args__.unwrap_or_default(), - fun_definition: fun_definition__, + Ok(ProjectionExecNode { + input: input__, + expr: expr__.unwrap_or_default(), + expr_name: expr_name__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.ScalarUDFExprNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ProjectionExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ScalarValue { +impl serde::Serialize for ProjectionNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -23258,260 +17223,49 @@ impl serde::Serialize for ScalarValue { { use serde::ser::SerializeStruct; let mut len = 0; - if self.value.is_some() { + if self.input.is_some() { + len += 1; + } + if !self.expr.is_empty() { + len += 1; + } + if self.optional_alias.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ScalarValue", len)?; - if let Some(v) = self.value.as_ref() { + let mut struct_ser = serializer.serialize_struct("datafusion.ProjectionNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if !self.expr.is_empty() { + struct_ser.serialize_field("expr", &self.expr)?; + } + if let Some(v) = self.optional_alias.as_ref() { match v { - scalar_value::Value::NullValue(v) => { - struct_ser.serialize_field("nullValue", v)?; - } - scalar_value::Value::BoolValue(v) => { - struct_ser.serialize_field("boolValue", v)?; - } - scalar_value::Value::Utf8Value(v) => { - struct_ser.serialize_field("utf8Value", v)?; - } - scalar_value::Value::LargeUtf8Value(v) => { - struct_ser.serialize_field("largeUtf8Value", v)?; - } - scalar_value::Value::Int8Value(v) => { - struct_ser.serialize_field("int8Value", v)?; - } - scalar_value::Value::Int16Value(v) => { - struct_ser.serialize_field("int16Value", v)?; - } - scalar_value::Value::Int32Value(v) => { - struct_ser.serialize_field("int32Value", v)?; - } - scalar_value::Value::Int64Value(v) => { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("int64Value", ToString::to_string(&v).as_str())?; - } - scalar_value::Value::Uint8Value(v) => { - struct_ser.serialize_field("uint8Value", v)?; - } - scalar_value::Value::Uint16Value(v) => { - struct_ser.serialize_field("uint16Value", v)?; - } - scalar_value::Value::Uint32Value(v) => { - struct_ser.serialize_field("uint32Value", v)?; - } - scalar_value::Value::Uint64Value(v) => { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("uint64Value", ToString::to_string(&v).as_str())?; - } - scalar_value::Value::Float32Value(v) => { - struct_ser.serialize_field("float32Value", v)?; - } - scalar_value::Value::Float64Value(v) => { - struct_ser.serialize_field("float64Value", v)?; - } - scalar_value::Value::Date32Value(v) => { - struct_ser.serialize_field("date32Value", v)?; - } - scalar_value::Value::Time32Value(v) => { - struct_ser.serialize_field("time32Value", v)?; - } - scalar_value::Value::LargeListValue(v) => { - struct_ser.serialize_field("largeListValue", v)?; - } - scalar_value::Value::ListValue(v) => { - struct_ser.serialize_field("listValue", v)?; - } - scalar_value::Value::FixedSizeListValue(v) => { - struct_ser.serialize_field("fixedSizeListValue", v)?; - } - scalar_value::Value::StructValue(v) => { - struct_ser.serialize_field("structValue", v)?; - } - scalar_value::Value::Decimal128Value(v) => { - struct_ser.serialize_field("decimal128Value", v)?; - } - scalar_value::Value::Decimal256Value(v) => { - struct_ser.serialize_field("decimal256Value", v)?; - } - scalar_value::Value::Date64Value(v) => { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("date64Value", ToString::to_string(&v).as_str())?; - } - scalar_value::Value::IntervalYearmonthValue(v) => { - struct_ser.serialize_field("intervalYearmonthValue", v)?; - } - scalar_value::Value::IntervalDaytimeValue(v) => { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("intervalDaytimeValue", ToString::to_string(&v).as_str())?; - } - scalar_value::Value::DurationSecondValue(v) => { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("durationSecondValue", ToString::to_string(&v).as_str())?; - } - scalar_value::Value::DurationMillisecondValue(v) => { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("durationMillisecondValue", ToString::to_string(&v).as_str())?; - } - scalar_value::Value::DurationMicrosecondValue(v) => { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("durationMicrosecondValue", ToString::to_string(&v).as_str())?; - } - scalar_value::Value::DurationNanosecondValue(v) => { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("durationNanosecondValue", ToString::to_string(&v).as_str())?; - } - scalar_value::Value::TimestampValue(v) => { - struct_ser.serialize_field("timestampValue", v)?; - } - scalar_value::Value::DictionaryValue(v) => { - struct_ser.serialize_field("dictionaryValue", v)?; - } - scalar_value::Value::BinaryValue(v) => { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("binaryValue", pbjson::private::base64::encode(&v).as_str())?; - } - scalar_value::Value::LargeBinaryValue(v) => { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("largeBinaryValue", pbjson::private::base64::encode(&v).as_str())?; - } - scalar_value::Value::Time64Value(v) => { - struct_ser.serialize_field("time64Value", v)?; - } - scalar_value::Value::IntervalMonthDayNano(v) => { - struct_ser.serialize_field("intervalMonthDayNano", v)?; - } - scalar_value::Value::FixedSizeBinaryValue(v) => { - struct_ser.serialize_field("fixedSizeBinaryValue", v)?; - } - scalar_value::Value::UnionValue(v) => { - struct_ser.serialize_field("unionValue", v)?; + projection_node::OptionalAlias::Alias(v) => { + struct_ser.serialize_field("alias", v)?; } } } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ScalarValue { +impl<'de> serde::Deserialize<'de> for ProjectionNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "null_value", - "nullValue", - "bool_value", - "boolValue", - "utf8_value", - "utf8Value", - "large_utf8_value", - "largeUtf8Value", - "int8_value", - "int8Value", - "int16_value", - "int16Value", - "int32_value", - "int32Value", - "int64_value", - "int64Value", - "uint8_value", - "uint8Value", - "uint16_value", - "uint16Value", - "uint32_value", - "uint32Value", - "uint64_value", - "uint64Value", - "float32_value", - "float32Value", - "float64_value", - "float64Value", - "date_32_value", - "date32Value", - "time32_value", - "time32Value", - "large_list_value", - "largeListValue", - "list_value", - "listValue", - "fixed_size_list_value", - "fixedSizeListValue", - "struct_value", - "structValue", - "decimal128_value", - "decimal128Value", - "decimal256_value", - "decimal256Value", - "date_64_value", - "date64Value", - "interval_yearmonth_value", - "intervalYearmonthValue", - "interval_daytime_value", - "intervalDaytimeValue", - "duration_second_value", - "durationSecondValue", - "duration_millisecond_value", - "durationMillisecondValue", - "duration_microsecond_value", - "durationMicrosecondValue", - "duration_nanosecond_value", - "durationNanosecondValue", - "timestamp_value", - "timestampValue", - "dictionary_value", - "dictionaryValue", - "binary_value", - "binaryValue", - "large_binary_value", - "largeBinaryValue", - "time64_value", - "time64Value", - "interval_month_day_nano", - "intervalMonthDayNano", - "fixed_size_binary_value", - "fixedSizeBinaryValue", - "union_value", - "unionValue", + "input", + "expr", + "alias", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - NullValue, - BoolValue, - Utf8Value, - LargeUtf8Value, - Int8Value, - Int16Value, - Int32Value, - Int64Value, - Uint8Value, - Uint16Value, - Uint32Value, - Uint64Value, - Float32Value, - Float64Value, - Date32Value, - Time32Value, - LargeListValue, - ListValue, - FixedSizeListValue, - StructValue, - Decimal128Value, - Decimal256Value, - Date64Value, - IntervalYearmonthValue, - IntervalDaytimeValue, - DurationSecondValue, - DurationMillisecondValue, - DurationMicrosecondValue, - DurationNanosecondValue, - TimestampValue, - DictionaryValue, - BinaryValue, - LargeBinaryValue, - Time64Value, - IntervalMonthDayNano, - FixedSizeBinaryValue, - UnionValue, + Input, + Expr, + Alias, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -23533,43 +17287,9 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { E: serde::de::Error, { match value { - "nullValue" | "null_value" => Ok(GeneratedField::NullValue), - "boolValue" | "bool_value" => Ok(GeneratedField::BoolValue), - "utf8Value" | "utf8_value" => Ok(GeneratedField::Utf8Value), - "largeUtf8Value" | "large_utf8_value" => Ok(GeneratedField::LargeUtf8Value), - "int8Value" | "int8_value" => Ok(GeneratedField::Int8Value), - "int16Value" | "int16_value" => Ok(GeneratedField::Int16Value), - "int32Value" | "int32_value" => Ok(GeneratedField::Int32Value), - "int64Value" | "int64_value" => Ok(GeneratedField::Int64Value), - "uint8Value" | "uint8_value" => Ok(GeneratedField::Uint8Value), - "uint16Value" | "uint16_value" => Ok(GeneratedField::Uint16Value), - "uint32Value" | "uint32_value" => Ok(GeneratedField::Uint32Value), - "uint64Value" | "uint64_value" => Ok(GeneratedField::Uint64Value), - "float32Value" | "float32_value" => Ok(GeneratedField::Float32Value), - "float64Value" | "float64_value" => Ok(GeneratedField::Float64Value), - "date32Value" | "date_32_value" => Ok(GeneratedField::Date32Value), - "time32Value" | "time32_value" => Ok(GeneratedField::Time32Value), - "largeListValue" | "large_list_value" => Ok(GeneratedField::LargeListValue), - "listValue" | "list_value" => Ok(GeneratedField::ListValue), - "fixedSizeListValue" | "fixed_size_list_value" => Ok(GeneratedField::FixedSizeListValue), - "structValue" | "struct_value" => Ok(GeneratedField::StructValue), - "decimal128Value" | "decimal128_value" => Ok(GeneratedField::Decimal128Value), - "decimal256Value" | "decimal256_value" => Ok(GeneratedField::Decimal256Value), - "date64Value" | "date_64_value" => Ok(GeneratedField::Date64Value), - "intervalYearmonthValue" | "interval_yearmonth_value" => Ok(GeneratedField::IntervalYearmonthValue), - "intervalDaytimeValue" | "interval_daytime_value" => Ok(GeneratedField::IntervalDaytimeValue), - "durationSecondValue" | "duration_second_value" => Ok(GeneratedField::DurationSecondValue), - "durationMillisecondValue" | "duration_millisecond_value" => Ok(GeneratedField::DurationMillisecondValue), - "durationMicrosecondValue" | "duration_microsecond_value" => Ok(GeneratedField::DurationMicrosecondValue), - "durationNanosecondValue" | "duration_nanosecond_value" => Ok(GeneratedField::DurationNanosecondValue), - "timestampValue" | "timestamp_value" => Ok(GeneratedField::TimestampValue), - "dictionaryValue" | "dictionary_value" => Ok(GeneratedField::DictionaryValue), - "binaryValue" | "binary_value" => Ok(GeneratedField::BinaryValue), - "largeBinaryValue" | "large_binary_value" => Ok(GeneratedField::LargeBinaryValue), - "time64Value" | "time64_value" => Ok(GeneratedField::Time64Value), - "intervalMonthDayNano" | "interval_month_day_nano" => Ok(GeneratedField::IntervalMonthDayNano), - "fixedSizeBinaryValue" | "fixed_size_binary_value" => Ok(GeneratedField::FixedSizeBinaryValue), - "unionValue" | "union_value" => Ok(GeneratedField::UnionValue), + "input" => Ok(GeneratedField::Input), + "expr" => Ok(GeneratedField::Expr), + "alias" => Ok(GeneratedField::Alias), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -23579,266 +17299,181 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ScalarValue; + type Value = ProjectionNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ScalarValue") + formatter.write_str("struct datafusion.ProjectionNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut value__ = None; + let mut input__ = None; + let mut expr__ = None; + let mut optional_alias__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::NullValue => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("nullValue")); - } - value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::NullValue) -; - } - GeneratedField::BoolValue => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("boolValue")); - } - value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::BoolValue); - } - GeneratedField::Utf8Value => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("utf8Value")); - } - value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Utf8Value); - } - GeneratedField::LargeUtf8Value => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("largeUtf8Value")); - } - value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::LargeUtf8Value); - } - GeneratedField::Int8Value => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("int8Value")); - } - value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Int8Value(x.0)); - } - GeneratedField::Int16Value => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("int16Value")); - } - value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Int16Value(x.0)); - } - GeneratedField::Int32Value => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("int32Value")); - } - value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Int32Value(x.0)); - } - GeneratedField::Int64Value => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("int64Value")); - } - value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Int64Value(x.0)); - } - GeneratedField::Uint8Value => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("uint8Value")); - } - value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Uint8Value(x.0)); - } - GeneratedField::Uint16Value => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("uint16Value")); - } - value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Uint16Value(x.0)); - } - GeneratedField::Uint32Value => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("uint32Value")); - } - value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Uint32Value(x.0)); - } - GeneratedField::Uint64Value => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("uint64Value")); - } - value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Uint64Value(x.0)); - } - GeneratedField::Float32Value => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("float32Value")); - } - value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Float32Value(x.0)); - } - GeneratedField::Float64Value => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("float64Value")); - } - value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Float64Value(x.0)); - } - GeneratedField::Date32Value => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("date32Value")); - } - value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Date32Value(x.0)); - } - GeneratedField::Time32Value => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("time32Value")); - } - value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Time32Value) -; - } - GeneratedField::LargeListValue => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("largeListValue")); - } - value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::LargeListValue) -; - } - GeneratedField::ListValue => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("listValue")); - } - value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::ListValue) -; - } - GeneratedField::FixedSizeListValue => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("fixedSizeListValue")); - } - value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::FixedSizeListValue) -; - } - GeneratedField::StructValue => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("structValue")); - } - value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::StructValue) -; - } - GeneratedField::Decimal128Value => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("decimal128Value")); - } - value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Decimal128Value) -; - } - GeneratedField::Decimal256Value => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("decimal256Value")); - } - value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Decimal256Value) -; - } - GeneratedField::Date64Value => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("date64Value")); - } - value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::Date64Value(x.0)); - } - GeneratedField::IntervalYearmonthValue => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("intervalYearmonthValue")); - } - value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::IntervalYearmonthValue(x.0)); - } - GeneratedField::IntervalDaytimeValue => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("intervalDaytimeValue")); - } - value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::IntervalDaytimeValue(x.0)); - } - GeneratedField::DurationSecondValue => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("durationSecondValue")); - } - value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::DurationSecondValue(x.0)); - } - GeneratedField::DurationMillisecondValue => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("durationMillisecondValue")); - } - value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::DurationMillisecondValue(x.0)); - } - GeneratedField::DurationMicrosecondValue => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("durationMicrosecondValue")); - } - value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::DurationMicrosecondValue(x.0)); - } - GeneratedField::DurationNanosecondValue => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("durationNanosecondValue")); - } - value__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| scalar_value::Value::DurationNanosecondValue(x.0)); - } - GeneratedField::TimestampValue => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("timestampValue")); - } - value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::TimestampValue) -; - } - GeneratedField::DictionaryValue => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("dictionaryValue")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::DictionaryValue) -; + input__ = map_.next_value()?; } - GeneratedField::BinaryValue => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("binaryValue")); + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); } - value__ = map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| scalar_value::Value::BinaryValue(x.0)); + expr__ = Some(map_.next_value()?); } - GeneratedField::LargeBinaryValue => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("largeBinaryValue")); + GeneratedField::Alias => { + if optional_alias__.is_some() { + return Err(serde::de::Error::duplicate_field("alias")); } - value__ = map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| scalar_value::Value::LargeBinaryValue(x.0)); + optional_alias__ = map_.next_value::<::std::option::Option<_>>()?.map(projection_node::OptionalAlias::Alias); } - GeneratedField::Time64Value => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("time64Value")); - } - value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::Time64Value) -; + } + } + Ok(ProjectionNode { + input: input__, + expr: expr__.unwrap_or_default(), + optional_alias: optional_alias__, + }) + } + } + deserializer.deserialize_struct("datafusion.ProjectionNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for RecursionUnnestOption { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.output_column.is_some() { + len += 1; + } + if self.input_column.is_some() { + len += 1; + } + if self.depth != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.RecursionUnnestOption", len)?; + if let Some(v) = self.output_column.as_ref() { + struct_ser.serialize_field("outputColumn", v)?; + } + if let Some(v) = self.input_column.as_ref() { + struct_ser.serialize_field("inputColumn", v)?; + } + if self.depth != 0 { + struct_ser.serialize_field("depth", &self.depth)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for RecursionUnnestOption { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "output_column", + "outputColumn", + "input_column", + "inputColumn", + "depth", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + OutputColumn, + InputColumn, + Depth, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "outputColumn" | "output_column" => Ok(GeneratedField::OutputColumn), + "inputColumn" | "input_column" => Ok(GeneratedField::InputColumn), + "depth" => Ok(GeneratedField::Depth), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } - GeneratedField::IntervalMonthDayNano => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("intervalMonthDayNano")); + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = RecursionUnnestOption; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.RecursionUnnestOption") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut output_column__ = None; + let mut input_column__ = None; + let mut depth__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::OutputColumn => { + if output_column__.is_some() { + return Err(serde::de::Error::duplicate_field("outputColumn")); } - value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::IntervalMonthDayNano) -; + output_column__ = map_.next_value()?; } - GeneratedField::FixedSizeBinaryValue => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("fixedSizeBinaryValue")); + GeneratedField::InputColumn => { + if input_column__.is_some() { + return Err(serde::de::Error::duplicate_field("inputColumn")); } - value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::FixedSizeBinaryValue) -; + input_column__ = map_.next_value()?; } - GeneratedField::UnionValue => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("unionValue")); + GeneratedField::Depth => { + if depth__.is_some() { + return Err(serde::de::Error::duplicate_field("depth")); } - value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::UnionValue) -; + depth__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } } } - Ok(ScalarValue { - value: value__, + Ok(RecursionUnnestOption { + output_column: output_column__, + input_column: input_column__, + depth: depth__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.ScalarValue", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.RecursionUnnestOption", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ScanLimit { +impl serde::Serialize for RepartitionExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -23846,29 +17481,37 @@ impl serde::Serialize for ScanLimit { { use serde::ser::SerializeStruct; let mut len = 0; - if self.limit != 0 { + if self.input.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ScanLimit", len)?; - if self.limit != 0 { - struct_ser.serialize_field("limit", &self.limit)?; + if self.partitioning.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.RepartitionExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if let Some(v) = self.partitioning.as_ref() { + struct_ser.serialize_field("partitioning", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ScanLimit { +impl<'de> serde::Deserialize<'de> for RepartitionExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "limit", + "input", + "partitioning", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Limit, + Input, + Partitioning, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -23890,7 +17533,8 @@ impl<'de> serde::Deserialize<'de> for ScanLimit { E: serde::de::Error, { match value { - "limit" => Ok(GeneratedField::Limit), + "input" => Ok(GeneratedField::Input), + "partitioning" => Ok(GeneratedField::Partitioning), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -23900,38 +17544,44 @@ impl<'de> serde::Deserialize<'de> for ScanLimit { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ScanLimit; + type Value = RepartitionExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ScanLimit") + formatter.write_str("struct datafusion.RepartitionExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut limit__ = None; + let mut input__ = None; + let mut partitioning__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Limit => { - if limit__.is_some() { - return Err(serde::de::Error::duplicate_field("limit")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - limit__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + input__ = map_.next_value()?; + } + GeneratedField::Partitioning => { + if partitioning__.is_some() { + return Err(serde::de::Error::duplicate_field("partitioning")); + } + partitioning__ = map_.next_value()?; } } } - Ok(ScanLimit { - limit: limit__.unwrap_or_default(), + Ok(RepartitionExecNode { + input: input__, + partitioning: partitioning__, }) } } - deserializer.deserialize_struct("datafusion.ScanLimit", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.RepartitionExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for Schema { +impl serde::Serialize for RepartitionNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -23939,37 +17589,49 @@ impl serde::Serialize for Schema { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.columns.is_empty() { + if self.input.is_some() { len += 1; } - if !self.metadata.is_empty() { + if self.partition_method.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.Schema", len)?; - if !self.columns.is_empty() { - struct_ser.serialize_field("columns", &self.columns)?; + let mut struct_ser = serializer.serialize_struct("datafusion.RepartitionNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; } - if !self.metadata.is_empty() { - struct_ser.serialize_field("metadata", &self.metadata)?; + if let Some(v) = self.partition_method.as_ref() { + match v { + repartition_node::PartitionMethod::RoundRobin(v) => { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("roundRobin", ToString::to_string(&v).as_str())?; + } + repartition_node::PartitionMethod::Hash(v) => { + struct_ser.serialize_field("hash", v)?; + } + } } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for Schema { +impl<'de> serde::Deserialize<'de> for RepartitionNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "columns", - "metadata", + "input", + "round_robin", + "roundRobin", + "hash", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Columns, - Metadata, + Input, + RoundRobin, + Hash, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -23991,8 +17653,9 @@ impl<'de> serde::Deserialize<'de> for Schema { E: serde::de::Error, { match value { - "columns" => Ok(GeneratedField::Columns), - "metadata" => Ok(GeneratedField::Metadata), + "input" => Ok(GeneratedField::Input), + "roundRobin" | "round_robin" => Ok(GeneratedField::RoundRobin), + "hash" => Ok(GeneratedField::Hash), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -24002,46 +17665,51 @@ impl<'de> serde::Deserialize<'de> for Schema { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = Schema; + type Value = RepartitionNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.Schema") + formatter.write_str("struct datafusion.RepartitionNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut columns__ = None; - let mut metadata__ = None; + let mut input__ = None; + let mut partition_method__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Columns => { - if columns__.is_some() { - return Err(serde::de::Error::duplicate_field("columns")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - columns__ = Some(map_.next_value()?); + input__ = map_.next_value()?; } - GeneratedField::Metadata => { - if metadata__.is_some() { - return Err(serde::de::Error::duplicate_field("metadata")); + GeneratedField::RoundRobin => { + if partition_method__.is_some() { + return Err(serde::de::Error::duplicate_field("roundRobin")); } - metadata__ = Some( - map_.next_value::>()? - ); + partition_method__ = map_.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| repartition_node::PartitionMethod::RoundRobin(x.0)); + } + GeneratedField::Hash => { + if partition_method__.is_some() { + return Err(serde::de::Error::duplicate_field("hash")); + } + partition_method__ = map_.next_value::<::std::option::Option<_>>()?.map(repartition_node::PartitionMethod::Hash) +; } } } - Ok(Schema { - columns: columns__.unwrap_or_default(), - metadata: metadata__.unwrap_or_default(), + Ok(RepartitionNode { + input: input__, + partition_method: partition_method__, }) } } - deserializer.deserialize_struct("datafusion.Schema", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.RepartitionNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for SelectionExecNode { +impl serde::Serialize for RollupNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -24049,17 +17717,17 @@ impl serde::Serialize for SelectionExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if !self.expr.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.SelectionExecNode", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.RollupNode", len)?; + if !self.expr.is_empty() { + struct_ser.serialize_field("expr", &self.expr)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for SelectionExecNode { +impl<'de> serde::Deserialize<'de> for RollupNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where @@ -24103,13 +17771,13 @@ impl<'de> serde::Deserialize<'de> for SelectionExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = SelectionExecNode; + type Value = RollupNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.SelectionExecNode") + formatter.write_str("struct datafusion.RollupNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { @@ -24120,19 +17788,19 @@ impl<'de> serde::Deserialize<'de> for SelectionExecNode { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = map_.next_value()?; + expr__ = Some(map_.next_value()?); } } } - Ok(SelectionExecNode { - expr: expr__, + Ok(RollupNode { + expr: expr__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.SelectionExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.RollupNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for SelectionNode { +impl serde::Serialize for ScalarUdfExprNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -24140,37 +17808,49 @@ impl serde::Serialize for SelectionNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.input.is_some() { + if !self.fun_name.is_empty() { len += 1; } - if self.expr.is_some() { + if !self.args.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.SelectionNode", len)?; - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; + if self.fun_definition.is_some() { + len += 1; } - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.ScalarUDFExprNode", len)?; + if !self.fun_name.is_empty() { + struct_ser.serialize_field("funName", &self.fun_name)?; + } + if !self.args.is_empty() { + struct_ser.serialize_field("args", &self.args)?; + } + if let Some(v) = self.fun_definition.as_ref() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("funDefinition", pbjson::private::base64::encode(&v).as_str())?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for SelectionNode { +impl<'de> serde::Deserialize<'de> for ScalarUdfExprNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "input", - "expr", + "fun_name", + "funName", + "args", + "fun_definition", + "funDefinition", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Input, - Expr, + FunName, + Args, + FunDefinition, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -24192,8 +17872,9 @@ impl<'de> serde::Deserialize<'de> for SelectionNode { E: serde::de::Error, { match value { - "input" => Ok(GeneratedField::Input), - "expr" => Ok(GeneratedField::Expr), + "funName" | "fun_name" => Ok(GeneratedField::FunName), + "args" => Ok(GeneratedField::Args), + "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -24203,44 +17884,54 @@ impl<'de> serde::Deserialize<'de> for SelectionNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = SelectionNode; + type Value = ScalarUdfExprNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.SelectionNode") + formatter.write_str("struct datafusion.ScalarUDFExprNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut input__ = None; - let mut expr__ = None; + let mut fun_name__ = None; + let mut args__ = None; + let mut fun_definition__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); + GeneratedField::FunName => { + if fun_name__.is_some() { + return Err(serde::de::Error::duplicate_field("funName")); + } + fun_name__ = Some(map_.next_value()?); + } + GeneratedField::Args => { + if args__.is_some() { + return Err(serde::de::Error::duplicate_field("args")); } - input__ = map_.next_value()?; + args__ = Some(map_.next_value()?); } - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::FunDefinition => { + if fun_definition__.is_some() { + return Err(serde::de::Error::duplicate_field("funDefinition")); } - expr__ = map_.next_value()?; + fun_definition__ = + map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) + ; } } } - Ok(SelectionNode { - input: input__, - expr: expr__, + Ok(ScalarUdfExprNode { + fun_name: fun_name__.unwrap_or_default(), + args: args__.unwrap_or_default(), + fun_definition: fun_definition__, }) } } - deserializer.deserialize_struct("datafusion.SelectionNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ScalarUDFExprNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for SimilarToNode { +impl serde::Serialize for ScanLimit { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -24248,54 +17939,29 @@ impl serde::Serialize for SimilarToNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.negated { - len += 1; - } - if self.expr.is_some() { - len += 1; - } - if self.pattern.is_some() { - len += 1; - } - if !self.escape_char.is_empty() { + if self.limit != 0 { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.SimilarToNode", len)?; - if self.negated { - struct_ser.serialize_field("negated", &self.negated)?; - } - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; - } - if let Some(v) = self.pattern.as_ref() { - struct_ser.serialize_field("pattern", v)?; - } - if !self.escape_char.is_empty() { - struct_ser.serialize_field("escapeChar", &self.escape_char)?; + let mut struct_ser = serializer.serialize_struct("datafusion.ScanLimit", len)?; + if self.limit != 0 { + struct_ser.serialize_field("limit", &self.limit)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for SimilarToNode { +impl<'de> serde::Deserialize<'de> for ScanLimit { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "negated", - "expr", - "pattern", - "escape_char", - "escapeChar", + "limit", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Negated, - Expr, - Pattern, - EscapeChar, + Limit, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -24317,10 +17983,7 @@ impl<'de> serde::Deserialize<'de> for SimilarToNode { E: serde::de::Error, { match value { - "negated" => Ok(GeneratedField::Negated), - "expr" => Ok(GeneratedField::Expr), - "pattern" => Ok(GeneratedField::Pattern), - "escapeChar" | "escape_char" => Ok(GeneratedField::EscapeChar), + "limit" => Ok(GeneratedField::Limit), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -24330,60 +17993,38 @@ impl<'de> serde::Deserialize<'de> for SimilarToNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = SimilarToNode; + type Value = ScanLimit; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.SimilarToNode") + formatter.write_str("struct datafusion.ScanLimit") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut negated__ = None; - let mut expr__ = None; - let mut pattern__ = None; - let mut escape_char__ = None; + let mut limit__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Negated => { - if negated__.is_some() { - return Err(serde::de::Error::duplicate_field("negated")); - } - negated__ = Some(map_.next_value()?); - } - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); - } - expr__ = map_.next_value()?; - } - GeneratedField::Pattern => { - if pattern__.is_some() { - return Err(serde::de::Error::duplicate_field("pattern")); - } - pattern__ = map_.next_value()?; - } - GeneratedField::EscapeChar => { - if escape_char__.is_some() { - return Err(serde::de::Error::duplicate_field("escapeChar")); + GeneratedField::Limit => { + if limit__.is_some() { + return Err(serde::de::Error::duplicate_field("limit")); } - escape_char__ = Some(map_.next_value()?); + limit__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } } } - Ok(SimilarToNode { - negated: negated__.unwrap_or_default(), - expr: expr__, - pattern: pattern__, - escape_char: escape_char__.unwrap_or_default(), + Ok(ScanLimit { + limit: limit__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.SimilarToNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.ScanLimit", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for SortExecNode { +impl serde::Serialize for SelectionExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -24391,55 +18032,29 @@ impl serde::Serialize for SortExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.input.is_some() { - len += 1; - } - if !self.expr.is_empty() { - len += 1; - } - if self.fetch != 0 { - len += 1; - } - if self.preserve_partitioning { + if self.expr.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.SortExecNode", len)?; - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; - } - if !self.expr.is_empty() { - struct_ser.serialize_field("expr", &self.expr)?; - } - if self.fetch != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("fetch", ToString::to_string(&self.fetch).as_str())?; - } - if self.preserve_partitioning { - struct_ser.serialize_field("preservePartitioning", &self.preserve_partitioning)?; + let mut struct_ser = serializer.serialize_struct("datafusion.SelectionExecNode", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for SortExecNode { +impl<'de> serde::Deserialize<'de> for SelectionExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "input", "expr", - "fetch", - "preserve_partitioning", - "preservePartitioning", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Input, Expr, - Fetch, - PreservePartitioning, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -24461,10 +18076,7 @@ impl<'de> serde::Deserialize<'de> for SortExecNode { E: serde::de::Error, { match value { - "input" => Ok(GeneratedField::Input), "expr" => Ok(GeneratedField::Expr), - "fetch" => Ok(GeneratedField::Fetch), - "preservePartitioning" | "preserve_partitioning" => Ok(GeneratedField::PreservePartitioning), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -24474,62 +18086,36 @@ impl<'de> serde::Deserialize<'de> for SortExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = SortExecNode; + type Value = SelectionExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.SortExecNode") + formatter.write_str("struct datafusion.SelectionExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut input__ = None; let mut expr__ = None; - let mut fetch__ = None; - let mut preserve_partitioning__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); - } - input__ = map_.next_value()?; - } GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = Some(map_.next_value()?); - } - GeneratedField::Fetch => { - if fetch__.is_some() { - return Err(serde::de::Error::duplicate_field("fetch")); - } - fetch__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::PreservePartitioning => { - if preserve_partitioning__.is_some() { - return Err(serde::de::Error::duplicate_field("preservePartitioning")); - } - preserve_partitioning__ = Some(map_.next_value()?); + expr__ = map_.next_value()?; } } } - Ok(SortExecNode { - input: input__, - expr: expr__.unwrap_or_default(), - fetch: fetch__.unwrap_or_default(), - preserve_partitioning: preserve_partitioning__.unwrap_or_default(), + Ok(SelectionExecNode { + expr: expr__, }) } } - deserializer.deserialize_struct("datafusion.SortExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.SelectionExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for SortExprNode { +impl serde::Serialize for SelectionNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -24537,46 +18123,37 @@ impl serde::Serialize for SortExprNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if self.input.is_some() { len += 1; } - if self.asc { + if self.expr.is_some() { len += 1; } - if self.nulls_first { - len += 1; + let mut struct_ser = serializer.serialize_struct("datafusion.SelectionNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; } - let mut struct_ser = serializer.serialize_struct("datafusion.SortExprNode", len)?; if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; } - if self.asc { - struct_ser.serialize_field("asc", &self.asc)?; - } - if self.nulls_first { - struct_ser.serialize_field("nullsFirst", &self.nulls_first)?; - } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for SortExprNode { +impl<'de> serde::Deserialize<'de> for SelectionNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ + "input", "expr", - "asc", - "nulls_first", - "nullsFirst", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { + Input, Expr, - Asc, - NullsFirst, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -24598,9 +18175,8 @@ impl<'de> serde::Deserialize<'de> for SortExprNode { E: serde::de::Error, { match value { + "input" => Ok(GeneratedField::Input), "expr" => Ok(GeneratedField::Expr), - "asc" => Ok(GeneratedField::Asc), - "nullsFirst" | "nulls_first" => Ok(GeneratedField::NullsFirst), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -24610,52 +18186,44 @@ impl<'de> serde::Deserialize<'de> for SortExprNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = SortExprNode; + type Value = SelectionNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.SortExprNode") + formatter.write_str("struct datafusion.SelectionNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { + let mut input__ = None; let mut expr__ = None; - let mut asc__ = None; - let mut nulls_first__ = None; while let Some(k) = map_.next_key()? { match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } expr__ = map_.next_value()?; } - GeneratedField::Asc => { - if asc__.is_some() { - return Err(serde::de::Error::duplicate_field("asc")); - } - asc__ = Some(map_.next_value()?); - } - GeneratedField::NullsFirst => { - if nulls_first__.is_some() { - return Err(serde::de::Error::duplicate_field("nullsFirst")); - } - nulls_first__ = Some(map_.next_value()?); - } } } - Ok(SortExprNode { + Ok(SelectionNode { + input: input__, expr: expr__, - asc: asc__.unwrap_or_default(), - nulls_first: nulls_first__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.SortExprNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.SelectionNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for SortNode { +impl serde::Serialize for SimilarToNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -24663,46 +18231,54 @@ impl serde::Serialize for SortNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.input.is_some() { + if self.negated { len += 1; } - if !self.expr.is_empty() { + if self.expr.is_some() { len += 1; } - if self.fetch != 0 { + if self.pattern.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.SortNode", len)?; - if let Some(v) = self.input.as_ref() { - struct_ser.serialize_field("input", v)?; + if !self.escape_char.is_empty() { + len += 1; } - if !self.expr.is_empty() { - struct_ser.serialize_field("expr", &self.expr)?; + let mut struct_ser = serializer.serialize_struct("datafusion.SimilarToNode", len)?; + if self.negated { + struct_ser.serialize_field("negated", &self.negated)?; } - if self.fetch != 0 { - #[allow(clippy::needless_borrow)] - struct_ser.serialize_field("fetch", ToString::to_string(&self.fetch).as_str())?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; + } + if let Some(v) = self.pattern.as_ref() { + struct_ser.serialize_field("pattern", v)?; + } + if !self.escape_char.is_empty() { + struct_ser.serialize_field("escapeChar", &self.escape_char)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for SortNode { +impl<'de> serde::Deserialize<'de> for SimilarToNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "input", + "negated", "expr", - "fetch", + "pattern", + "escape_char", + "escapeChar", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Input, + Negated, Expr, - Fetch, + Pattern, + EscapeChar, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -24724,9 +18300,10 @@ impl<'de> serde::Deserialize<'de> for SortNode { E: serde::de::Error, { match value { - "input" => Ok(GeneratedField::Input), + "negated" => Ok(GeneratedField::Negated), "expr" => Ok(GeneratedField::Expr), - "fetch" => Ok(GeneratedField::Fetch), + "pattern" => Ok(GeneratedField::Pattern), + "escapeChar" | "escape_char" => Ok(GeneratedField::EscapeChar), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -24736,54 +18313,60 @@ impl<'de> serde::Deserialize<'de> for SortNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = SortNode; + type Value = SimilarToNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.SortNode") + formatter.write_str("struct datafusion.SimilarToNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut input__ = None; + let mut negated__ = None; let mut expr__ = None; - let mut fetch__ = None; + let mut pattern__ = None; + let mut escape_char__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); + GeneratedField::Negated => { + if negated__.is_some() { + return Err(serde::de::Error::duplicate_field("negated")); } - input__ = map_.next_value()?; + negated__ = Some(map_.next_value()?); } GeneratedField::Expr => { if expr__.is_some() { return Err(serde::de::Error::duplicate_field("expr")); } - expr__ = Some(map_.next_value()?); + expr__ = map_.next_value()?; } - GeneratedField::Fetch => { - if fetch__.is_some() { - return Err(serde::de::Error::duplicate_field("fetch")); + GeneratedField::Pattern => { + if pattern__.is_some() { + return Err(serde::de::Error::duplicate_field("pattern")); } - fetch__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + pattern__ = map_.next_value()?; + } + GeneratedField::EscapeChar => { + if escape_char__.is_some() { + return Err(serde::de::Error::duplicate_field("escapeChar")); + } + escape_char__ = Some(map_.next_value()?); } } } - Ok(SortNode { - input: input__, - expr: expr__.unwrap_or_default(), - fetch: fetch__.unwrap_or_default(), + Ok(SimilarToNode { + negated: negated__.unwrap_or_default(), + expr: expr__, + pattern: pattern__, + escape_char: escape_char__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.SortNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.SimilarToNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for SortPreservingMergeExecNode { +impl serde::Serialize for SortExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -24800,7 +18383,10 @@ impl serde::Serialize for SortPreservingMergeExecNode { if self.fetch != 0 { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.SortPreservingMergeExecNode", len)?; + if self.preserve_partitioning { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.SortExecNode", len)?; if let Some(v) = self.input.as_ref() { struct_ser.serialize_field("input", v)?; } @@ -24809,12 +18395,16 @@ impl serde::Serialize for SortPreservingMergeExecNode { } if self.fetch != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("fetch", ToString::to_string(&self.fetch).as_str())?; } + if self.preserve_partitioning { + struct_ser.serialize_field("preservePartitioning", &self.preserve_partitioning)?; + } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode { +impl<'de> serde::Deserialize<'de> for SortExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where @@ -24824,6 +18414,8 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode { "input", "expr", "fetch", + "preserve_partitioning", + "preservePartitioning", ]; #[allow(clippy::enum_variant_names)] @@ -24831,6 +18423,7 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode { Input, Expr, Fetch, + PreservePartitioning, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -24855,6 +18448,7 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode { "input" => Ok(GeneratedField::Input), "expr" => Ok(GeneratedField::Expr), "fetch" => Ok(GeneratedField::Fetch), + "preservePartitioning" | "preserve_partitioning" => Ok(GeneratedField::PreservePartitioning), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -24864,19 +18458,20 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = SortPreservingMergeExecNode; + type Value = SortExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.SortPreservingMergeExecNode") + formatter.write_str("struct datafusion.SortExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { let mut input__ = None; let mut expr__ = None; let mut fetch__ = None; + let mut preserve_partitioning__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Input => { @@ -24899,19 +18494,26 @@ impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode { Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } + GeneratedField::PreservePartitioning => { + if preserve_partitioning__.is_some() { + return Err(serde::de::Error::duplicate_field("preservePartitioning")); + } + preserve_partitioning__ = Some(map_.next_value()?); + } } } - Ok(SortPreservingMergeExecNode { + Ok(SortExecNode { input: input__, expr: expr__.unwrap_or_default(), fetch: fetch__.unwrap_or_default(), + preserve_partitioning: preserve_partitioning__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.SortPreservingMergeExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.SortExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for Statistics { +impl serde::Serialize for SortExprNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -24919,48 +18521,46 @@ impl serde::Serialize for Statistics { { use serde::ser::SerializeStruct; let mut len = 0; - if self.num_rows.is_some() { + if self.expr.is_some() { len += 1; } - if self.total_byte_size.is_some() { + if self.asc { len += 1; } - if !self.column_stats.is_empty() { + if self.nulls_first { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.Statistics", len)?; - if let Some(v) = self.num_rows.as_ref() { - struct_ser.serialize_field("numRows", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.SortExprNode", len)?; + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; } - if let Some(v) = self.total_byte_size.as_ref() { - struct_ser.serialize_field("totalByteSize", v)?; + if self.asc { + struct_ser.serialize_field("asc", &self.asc)?; } - if !self.column_stats.is_empty() { - struct_ser.serialize_field("columnStats", &self.column_stats)?; + if self.nulls_first { + struct_ser.serialize_field("nullsFirst", &self.nulls_first)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for Statistics { +impl<'de> serde::Deserialize<'de> for SortExprNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "num_rows", - "numRows", - "total_byte_size", - "totalByteSize", - "column_stats", - "columnStats", + "expr", + "asc", + "nulls_first", + "nullsFirst", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - NumRows, - TotalByteSize, - ColumnStats, + Expr, + Asc, + NullsFirst, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -24982,9 +18582,9 @@ impl<'de> serde::Deserialize<'de> for Statistics { E: serde::de::Error, { match value { - "numRows" | "num_rows" => Ok(GeneratedField::NumRows), - "totalByteSize" | "total_byte_size" => Ok(GeneratedField::TotalByteSize), - "columnStats" | "column_stats" => Ok(GeneratedField::ColumnStats), + "expr" => Ok(GeneratedField::Expr), + "asc" => Ok(GeneratedField::Asc), + "nullsFirst" | "nulls_first" => Ok(GeneratedField::NullsFirst), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -24994,123 +18594,52 @@ impl<'de> serde::Deserialize<'de> for Statistics { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = Statistics; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.Statistics") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut num_rows__ = None; - let mut total_byte_size__ = None; - let mut column_stats__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::NumRows => { - if num_rows__.is_some() { - return Err(serde::de::Error::duplicate_field("numRows")); - } - num_rows__ = map_.next_value()?; - } - GeneratedField::TotalByteSize => { - if total_byte_size__.is_some() { - return Err(serde::de::Error::duplicate_field("totalByteSize")); - } - total_byte_size__ = map_.next_value()?; - } - GeneratedField::ColumnStats => { - if column_stats__.is_some() { - return Err(serde::de::Error::duplicate_field("columnStats")); - } - column_stats__ = Some(map_.next_value()?); - } - } - } - Ok(Statistics { - num_rows: num_rows__, - total_byte_size: total_byte_size__, - column_stats: column_stats__.unwrap_or_default(), - }) - } - } - deserializer.deserialize_struct("datafusion.Statistics", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for StreamPartitionMode { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - let variant = match self { - Self::SinglePartition => "SINGLE_PARTITION", - Self::PartitionedExec => "PARTITIONED_EXEC", - }; - serializer.serialize_str(variant) - } -} -impl<'de> serde::Deserialize<'de> for StreamPartitionMode { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "SINGLE_PARTITION", - "PARTITIONED_EXEC", - ]; - - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = StreamPartitionMode; + type Value = SortExprNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - fn visit_i64(self, v: i64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) - }) - } - - fn visit_u64(self, v: u64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) - }) + formatter.write_str("struct datafusion.SortExprNode") } - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, { - match value { - "SINGLE_PARTITION" => Ok(StreamPartitionMode::SinglePartition), - "PARTITIONED_EXEC" => Ok(StreamPartitionMode::PartitionedExec), - _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + let mut expr__ = None; + let mut asc__ = None; + let mut nulls_first__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); + } + expr__ = map_.next_value()?; + } + GeneratedField::Asc => { + if asc__.is_some() { + return Err(serde::de::Error::duplicate_field("asc")); + } + asc__ = Some(map_.next_value()?); + } + GeneratedField::NullsFirst => { + if nulls_first__.is_some() { + return Err(serde::de::Error::duplicate_field("nullsFirst")); + } + nulls_first__ = Some(map_.next_value()?); + } + } } + Ok(SortExprNode { + expr: expr__, + asc: asc__.unwrap_or_default(), + nulls_first: nulls_first__.unwrap_or_default(), + }) } } - deserializer.deserialize_any(GeneratedVisitor) + deserializer.deserialize_struct("datafusion.SortExprNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for StringifiedPlan { +impl serde::Serialize for SortExprNodeCollection { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -25118,38 +18647,30 @@ impl serde::Serialize for StringifiedPlan { { use serde::ser::SerializeStruct; let mut len = 0; - if self.plan_type.is_some() { - len += 1; - } - if !self.plan.is_empty() { + if !self.sort_expr_nodes.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.StringifiedPlan", len)?; - if let Some(v) = self.plan_type.as_ref() { - struct_ser.serialize_field("planType", v)?; - } - if !self.plan.is_empty() { - struct_ser.serialize_field("plan", &self.plan)?; + let mut struct_ser = serializer.serialize_struct("datafusion.SortExprNodeCollection", len)?; + if !self.sort_expr_nodes.is_empty() { + struct_ser.serialize_field("sortExprNodes", &self.sort_expr_nodes)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for StringifiedPlan { +impl<'de> serde::Deserialize<'de> for SortExprNodeCollection { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "plan_type", - "planType", - "plan", + "sort_expr_nodes", + "sortExprNodes", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - PlanType, - Plan, + SortExprNodes, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -25171,8 +18692,7 @@ impl<'de> serde::Deserialize<'de> for StringifiedPlan { E: serde::de::Error, { match value { - "planType" | "plan_type" => Ok(GeneratedField::PlanType), - "plan" => Ok(GeneratedField::Plan), + "sortExprNodes" | "sort_expr_nodes" => Ok(GeneratedField::SortExprNodes), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -25182,44 +18702,36 @@ impl<'de> serde::Deserialize<'de> for StringifiedPlan { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = StringifiedPlan; + type Value = SortExprNodeCollection; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.StringifiedPlan") + formatter.write_str("struct datafusion.SortExprNodeCollection") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut plan_type__ = None; - let mut plan__ = None; + let mut sort_expr_nodes__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::PlanType => { - if plan_type__.is_some() { - return Err(serde::de::Error::duplicate_field("planType")); - } - plan_type__ = map_.next_value()?; - } - GeneratedField::Plan => { - if plan__.is_some() { - return Err(serde::de::Error::duplicate_field("plan")); + GeneratedField::SortExprNodes => { + if sort_expr_nodes__.is_some() { + return Err(serde::de::Error::duplicate_field("sortExprNodes")); } - plan__ = Some(map_.next_value()?); + sort_expr_nodes__ = Some(map_.next_value()?); } } } - Ok(StringifiedPlan { - plan_type: plan_type__, - plan: plan__.unwrap_or_default(), + Ok(SortExprNodeCollection { + sort_expr_nodes: sort_expr_nodes__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.StringifiedPlan", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.SortExprNodeCollection", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for Struct { +impl serde::Serialize for SortNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -25227,30 +18739,47 @@ impl serde::Serialize for Struct { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.sub_field_types.is_empty() { + if self.input.is_some() { + len += 1; + } + if !self.expr.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.Struct", len)?; - if !self.sub_field_types.is_empty() { - struct_ser.serialize_field("subFieldTypes", &self.sub_field_types)?; + if self.fetch != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.SortNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if !self.expr.is_empty() { + struct_ser.serialize_field("expr", &self.expr)?; + } + if self.fetch != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("fetch", ToString::to_string(&self.fetch).as_str())?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for Struct { +impl<'de> serde::Deserialize<'de> for SortNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "sub_field_types", - "subFieldTypes", + "input", + "expr", + "fetch", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - SubFieldTypes, + Input, + Expr, + Fetch, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -25272,7 +18801,9 @@ impl<'de> serde::Deserialize<'de> for Struct { E: serde::de::Error, { match value { - "subFieldTypes" | "sub_field_types" => Ok(GeneratedField::SubFieldTypes), + "input" => Ok(GeneratedField::Input), + "expr" => Ok(GeneratedField::Expr), + "fetch" => Ok(GeneratedField::Fetch), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -25282,36 +18813,54 @@ impl<'de> serde::Deserialize<'de> for Struct { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = Struct; + type Value = SortNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.Struct") + formatter.write_str("struct datafusion.SortNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut sub_field_types__ = None; + let mut input__ = None; + let mut expr__ = None; + let mut fetch__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::SubFieldTypes => { - if sub_field_types__.is_some() { - return Err(serde::de::Error::duplicate_field("subFieldTypes")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); + } + expr__ = Some(map_.next_value()?); + } + GeneratedField::Fetch => { + if fetch__.is_some() { + return Err(serde::de::Error::duplicate_field("fetch")); } - sub_field_types__ = Some(map_.next_value()?); + fetch__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; } } } - Ok(Struct { - sub_field_types: sub_field_types__.unwrap_or_default(), + Ok(SortNode { + input: input__, + expr: expr__.unwrap_or_default(), + fetch: fetch__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.Struct", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.SortNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for SubqueryAliasNode { +impl serde::Serialize for SortPreservingMergeExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -25322,20 +18871,28 @@ impl serde::Serialize for SubqueryAliasNode { if self.input.is_some() { len += 1; } - if self.alias.is_some() { + if !self.expr.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.SubqueryAliasNode", len)?; + if self.fetch != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.SortPreservingMergeExecNode", len)?; if let Some(v) = self.input.as_ref() { struct_ser.serialize_field("input", v)?; } - if let Some(v) = self.alias.as_ref() { - struct_ser.serialize_field("alias", v)?; + if !self.expr.is_empty() { + struct_ser.serialize_field("expr", &self.expr)?; + } + if self.fetch != 0 { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("fetch", ToString::to_string(&self.fetch).as_str())?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for SubqueryAliasNode { +impl<'de> serde::Deserialize<'de> for SortPreservingMergeExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where @@ -25343,13 +18900,15 @@ impl<'de> serde::Deserialize<'de> for SubqueryAliasNode { { const FIELDS: &[&str] = &[ "input", - "alias", + "expr", + "fetch", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Input, - Alias, + Expr, + Fetch, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -25372,54 +18931,136 @@ impl<'de> serde::Deserialize<'de> for SubqueryAliasNode { { match value { "input" => Ok(GeneratedField::Input), - "alias" => Ok(GeneratedField::Alias), + "expr" => Ok(GeneratedField::Expr), + "fetch" => Ok(GeneratedField::Fetch), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } } deserializer.deserialize_identifier(GeneratedVisitor) } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = SubqueryAliasNode; + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = SortPreservingMergeExecNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.SortPreservingMergeExecNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut input__ = None; + let mut expr__ = None; + let mut fetch__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); + } + expr__ = Some(map_.next_value()?); + } + GeneratedField::Fetch => { + if fetch__.is_some() { + return Err(serde::de::Error::duplicate_field("fetch")); + } + fetch__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(SortPreservingMergeExecNode { + input: input__, + expr: expr__.unwrap_or_default(), + fetch: fetch__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.SortPreservingMergeExecNode", FIELDS, GeneratedVisitor) + } +} +impl serde::Serialize for StreamPartitionMode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::SinglePartition => "SINGLE_PARTITION", + Self::PartitionedExec => "PARTITIONED_EXEC", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for StreamPartitionMode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "SINGLE_PARTITION", + "PARTITIONED_EXEC", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = StreamPartitionMode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.SubqueryAliasNode") + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) } - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, { - let mut input__ = None; - let mut alias__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Input => { - if input__.is_some() { - return Err(serde::de::Error::duplicate_field("input")); - } - input__ = map_.next_value()?; - } - GeneratedField::Alias => { - if alias__.is_some() { - return Err(serde::de::Error::duplicate_field("alias")); - } - alias__ = map_.next_value()?; - } - } + match value { + "SINGLE_PARTITION" => Ok(StreamPartitionMode::SinglePartition), + "PARTITIONED_EXEC" => Ok(StreamPartitionMode::PartitionedExec), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } - Ok(SubqueryAliasNode { - input: input__, - alias: alias__, - }) } } - deserializer.deserialize_struct("datafusion.SubqueryAliasNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_any(GeneratedVisitor) } } -impl serde::Serialize for SymmetricHashJoinExecNode { +impl serde::Serialize for StringifiedPlan { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -25427,102 +19068,38 @@ impl serde::Serialize for SymmetricHashJoinExecNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.left.is_some() { - len += 1; - } - if self.right.is_some() { - len += 1; - } - if !self.on.is_empty() { - len += 1; - } - if self.join_type != 0 { - len += 1; - } - if self.partition_mode != 0 { - len += 1; - } - if self.null_equals_null { - len += 1; - } - if self.filter.is_some() { - len += 1; - } - if !self.left_sort_exprs.is_empty() { + if self.plan_type.is_some() { len += 1; } - if !self.right_sort_exprs.is_empty() { + if !self.plan.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.SymmetricHashJoinExecNode", len)?; - if let Some(v) = self.left.as_ref() { - struct_ser.serialize_field("left", v)?; - } - if let Some(v) = self.right.as_ref() { - struct_ser.serialize_field("right", v)?; - } - if !self.on.is_empty() { - struct_ser.serialize_field("on", &self.on)?; - } - if self.join_type != 0 { - let v = JoinType::try_from(self.join_type) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_type)))?; - struct_ser.serialize_field("joinType", &v)?; - } - if self.partition_mode != 0 { - let v = StreamPartitionMode::try_from(self.partition_mode) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.partition_mode)))?; - struct_ser.serialize_field("partitionMode", &v)?; - } - if self.null_equals_null { - struct_ser.serialize_field("nullEqualsNull", &self.null_equals_null)?; - } - if let Some(v) = self.filter.as_ref() { - struct_ser.serialize_field("filter", v)?; - } - if !self.left_sort_exprs.is_empty() { - struct_ser.serialize_field("leftSortExprs", &self.left_sort_exprs)?; + let mut struct_ser = serializer.serialize_struct("datafusion.StringifiedPlan", len)?; + if let Some(v) = self.plan_type.as_ref() { + struct_ser.serialize_field("planType", v)?; } - if !self.right_sort_exprs.is_empty() { - struct_ser.serialize_field("rightSortExprs", &self.right_sort_exprs)?; + if !self.plan.is_empty() { + struct_ser.serialize_field("plan", &self.plan)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { +impl<'de> serde::Deserialize<'de> for StringifiedPlan { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "left", - "right", - "on", - "join_type", - "joinType", - "partition_mode", - "partitionMode", - "null_equals_null", - "nullEqualsNull", - "filter", - "left_sort_exprs", - "leftSortExprs", - "right_sort_exprs", - "rightSortExprs", + "plan_type", + "planType", + "plan", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Left, - Right, - On, - JoinType, - PartitionMode, - NullEqualsNull, - Filter, - LeftSortExprs, - RightSortExprs, + PlanType, + Plan, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -25544,15 +19121,8 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { E: serde::de::Error, { match value { - "left" => Ok(GeneratedField::Left), - "right" => Ok(GeneratedField::Right), - "on" => Ok(GeneratedField::On), - "joinType" | "join_type" => Ok(GeneratedField::JoinType), - "partitionMode" | "partition_mode" => Ok(GeneratedField::PartitionMode), - "nullEqualsNull" | "null_equals_null" => Ok(GeneratedField::NullEqualsNull), - "filter" => Ok(GeneratedField::Filter), - "leftSortExprs" | "left_sort_exprs" => Ok(GeneratedField::LeftSortExprs), - "rightSortExprs" | "right_sort_exprs" => Ok(GeneratedField::RightSortExprs), + "planType" | "plan_type" => Ok(GeneratedField::PlanType), + "plan" => Ok(GeneratedField::Plan), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -25562,100 +19132,44 @@ impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = SymmetricHashJoinExecNode; + type Value = StringifiedPlan; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.SymmetricHashJoinExecNode") + formatter.write_str("struct datafusion.StringifiedPlan") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut left__ = None; - let mut right__ = None; - let mut on__ = None; - let mut join_type__ = None; - let mut partition_mode__ = None; - let mut null_equals_null__ = None; - let mut filter__ = None; - let mut left_sort_exprs__ = None; - let mut right_sort_exprs__ = None; + let mut plan_type__ = None; + let mut plan__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Left => { - if left__.is_some() { - return Err(serde::de::Error::duplicate_field("left")); - } - left__ = map_.next_value()?; - } - GeneratedField::Right => { - if right__.is_some() { - return Err(serde::de::Error::duplicate_field("right")); - } - right__ = map_.next_value()?; - } - GeneratedField::On => { - if on__.is_some() { - return Err(serde::de::Error::duplicate_field("on")); - } - on__ = Some(map_.next_value()?); - } - GeneratedField::JoinType => { - if join_type__.is_some() { - return Err(serde::de::Error::duplicate_field("joinType")); - } - join_type__ = Some(map_.next_value::()? as i32); - } - GeneratedField::PartitionMode => { - if partition_mode__.is_some() { - return Err(serde::de::Error::duplicate_field("partitionMode")); - } - partition_mode__ = Some(map_.next_value::()? as i32); - } - GeneratedField::NullEqualsNull => { - if null_equals_null__.is_some() { - return Err(serde::de::Error::duplicate_field("nullEqualsNull")); - } - null_equals_null__ = Some(map_.next_value()?); - } - GeneratedField::Filter => { - if filter__.is_some() { - return Err(serde::de::Error::duplicate_field("filter")); - } - filter__ = map_.next_value()?; - } - GeneratedField::LeftSortExprs => { - if left_sort_exprs__.is_some() { - return Err(serde::de::Error::duplicate_field("leftSortExprs")); + GeneratedField::PlanType => { + if plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("planType")); } - left_sort_exprs__ = Some(map_.next_value()?); + plan_type__ = map_.next_value()?; } - GeneratedField::RightSortExprs => { - if right_sort_exprs__.is_some() { - return Err(serde::de::Error::duplicate_field("rightSortExprs")); + GeneratedField::Plan => { + if plan__.is_some() { + return Err(serde::de::Error::duplicate_field("plan")); } - right_sort_exprs__ = Some(map_.next_value()?); + plan__ = Some(map_.next_value()?); } } } - Ok(SymmetricHashJoinExecNode { - left: left__, - right: right__, - on: on__.unwrap_or_default(), - join_type: join_type__.unwrap_or_default(), - partition_mode: partition_mode__.unwrap_or_default(), - null_equals_null: null_equals_null__.unwrap_or_default(), - filter: filter__, - left_sort_exprs: left_sort_exprs__.unwrap_or_default(), - right_sort_exprs: right_sort_exprs__.unwrap_or_default(), + Ok(StringifiedPlan { + plan_type: plan_type__, + plan: plan__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.SymmetricHashJoinExecNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.StringifiedPlan", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for TableParquetOptions { +impl serde::Serialize for SubqueryAliasNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -25663,38 +19177,37 @@ impl serde::Serialize for TableParquetOptions { { use serde::ser::SerializeStruct; let mut len = 0; - if self.global.is_some() { + if self.input.is_some() { len += 1; } - if !self.column_specific_options.is_empty() { + if self.alias.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.TableParquetOptions", len)?; - if let Some(v) = self.global.as_ref() { - struct_ser.serialize_field("global", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.SubqueryAliasNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; } - if !self.column_specific_options.is_empty() { - struct_ser.serialize_field("columnSpecificOptions", &self.column_specific_options)?; + if let Some(v) = self.alias.as_ref() { + struct_ser.serialize_field("alias", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for TableParquetOptions { +impl<'de> serde::Deserialize<'de> for SubqueryAliasNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "global", - "column_specific_options", - "columnSpecificOptions", + "input", + "alias", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Global, - ColumnSpecificOptions, + Input, + Alias, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -25716,8 +19229,8 @@ impl<'de> serde::Deserialize<'de> for TableParquetOptions { E: serde::de::Error, { match value { - "global" => Ok(GeneratedField::Global), - "columnSpecificOptions" | "column_specific_options" => Ok(GeneratedField::ColumnSpecificOptions), + "input" => Ok(GeneratedField::Input), + "alias" => Ok(GeneratedField::Alias), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -25727,44 +19240,44 @@ impl<'de> serde::Deserialize<'de> for TableParquetOptions { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = TableParquetOptions; + type Value = SubqueryAliasNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.TableParquetOptions") + formatter.write_str("struct datafusion.SubqueryAliasNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut global__ = None; - let mut column_specific_options__ = None; + let mut input__ = None; + let mut alias__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Global => { - if global__.is_some() { - return Err(serde::de::Error::duplicate_field("global")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - global__ = map_.next_value()?; + input__ = map_.next_value()?; } - GeneratedField::ColumnSpecificOptions => { - if column_specific_options__.is_some() { - return Err(serde::de::Error::duplicate_field("columnSpecificOptions")); + GeneratedField::Alias => { + if alias__.is_some() { + return Err(serde::de::Error::duplicate_field("alias")); } - column_specific_options__ = Some(map_.next_value()?); + alias__ = map_.next_value()?; } } } - Ok(TableParquetOptions { - global: global__, - column_specific_options: column_specific_options__.unwrap_or_default(), + Ok(SubqueryAliasNode { + input: input__, + alias: alias__, }) } } - deserializer.deserialize_struct("datafusion.TableParquetOptions", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.SubqueryAliasNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for TableReference { +impl serde::Serialize for SymmetricHashJoinExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -25772,43 +19285,102 @@ impl serde::Serialize for TableReference { { use serde::ser::SerializeStruct; let mut len = 0; - if self.table_reference_enum.is_some() { + if self.left.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.TableReference", len)?; - if let Some(v) = self.table_reference_enum.as_ref() { - match v { - table_reference::TableReferenceEnum::Bare(v) => { - struct_ser.serialize_field("bare", v)?; - } - table_reference::TableReferenceEnum::Partial(v) => { - struct_ser.serialize_field("partial", v)?; - } - table_reference::TableReferenceEnum::Full(v) => { - struct_ser.serialize_field("full", v)?; - } - } + if self.right.is_some() { + len += 1; + } + if !self.on.is_empty() { + len += 1; + } + if self.join_type != 0 { + len += 1; + } + if self.partition_mode != 0 { + len += 1; + } + if self.null_equals_null { + len += 1; + } + if self.filter.is_some() { + len += 1; + } + if !self.left_sort_exprs.is_empty() { + len += 1; + } + if !self.right_sort_exprs.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.SymmetricHashJoinExecNode", len)?; + if let Some(v) = self.left.as_ref() { + struct_ser.serialize_field("left", v)?; + } + if let Some(v) = self.right.as_ref() { + struct_ser.serialize_field("right", v)?; + } + if !self.on.is_empty() { + struct_ser.serialize_field("on", &self.on)?; + } + if self.join_type != 0 { + let v = super::datafusion_common::JoinType::try_from(self.join_type) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.join_type)))?; + struct_ser.serialize_field("joinType", &v)?; + } + if self.partition_mode != 0 { + let v = StreamPartitionMode::try_from(self.partition_mode) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.partition_mode)))?; + struct_ser.serialize_field("partitionMode", &v)?; + } + if self.null_equals_null { + struct_ser.serialize_field("nullEqualsNull", &self.null_equals_null)?; + } + if let Some(v) = self.filter.as_ref() { + struct_ser.serialize_field("filter", v)?; + } + if !self.left_sort_exprs.is_empty() { + struct_ser.serialize_field("leftSortExprs", &self.left_sort_exprs)?; + } + if !self.right_sort_exprs.is_empty() { + struct_ser.serialize_field("rightSortExprs", &self.right_sort_exprs)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for TableReference { +impl<'de> serde::Deserialize<'de> for SymmetricHashJoinExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "bare", - "partial", - "full", + "left", + "right", + "on", + "join_type", + "joinType", + "partition_mode", + "partitionMode", + "null_equals_null", + "nullEqualsNull", + "filter", + "left_sort_exprs", + "leftSortExprs", + "right_sort_exprs", + "rightSortExprs", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Bare, - Partial, - Full, + Left, + Right, + On, + JoinType, + PartitionMode, + NullEqualsNull, + Filter, + LeftSortExprs, + RightSortExprs, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -25830,9 +19402,15 @@ impl<'de> serde::Deserialize<'de> for TableReference { E: serde::de::Error, { match value { - "bare" => Ok(GeneratedField::Bare), - "partial" => Ok(GeneratedField::Partial), - "full" => Ok(GeneratedField::Full), + "left" => Ok(GeneratedField::Left), + "right" => Ok(GeneratedField::Right), + "on" => Ok(GeneratedField::On), + "joinType" | "join_type" => Ok(GeneratedField::JoinType), + "partitionMode" | "partition_mode" => Ok(GeneratedField::PartitionMode), + "nullEqualsNull" | "null_equals_null" => Ok(GeneratedField::NullEqualsNull), + "filter" => Ok(GeneratedField::Filter), + "leftSortExprs" | "left_sort_exprs" => Ok(GeneratedField::LeftSortExprs), + "rightSortExprs" | "right_sort_exprs" => Ok(GeneratedField::RightSortExprs), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -25842,128 +19420,100 @@ impl<'de> serde::Deserialize<'de> for TableReference { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = TableReference; + type Value = SymmetricHashJoinExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.TableReference") + formatter.write_str("struct datafusion.SymmetricHashJoinExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut table_reference_enum__ = None; + let mut left__ = None; + let mut right__ = None; + let mut on__ = None; + let mut join_type__ = None; + let mut partition_mode__ = None; + let mut null_equals_null__ = None; + let mut filter__ = None; + let mut left_sort_exprs__ = None; + let mut right_sort_exprs__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Bare => { - if table_reference_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("bare")); + GeneratedField::Left => { + if left__.is_some() { + return Err(serde::de::Error::duplicate_field("left")); } - table_reference_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(table_reference::TableReferenceEnum::Bare) -; + left__ = map_.next_value()?; } - GeneratedField::Partial => { - if table_reference_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("partial")); + GeneratedField::Right => { + if right__.is_some() { + return Err(serde::de::Error::duplicate_field("right")); } - table_reference_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(table_reference::TableReferenceEnum::Partial) -; + right__ = map_.next_value()?; } - GeneratedField::Full => { - if table_reference_enum__.is_some() { - return Err(serde::de::Error::duplicate_field("full")); + GeneratedField::On => { + if on__.is_some() { + return Err(serde::de::Error::duplicate_field("on")); } - table_reference_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(table_reference::TableReferenceEnum::Full) -; + on__ = Some(map_.next_value()?); } - } - } - Ok(TableReference { - table_reference_enum: table_reference_enum__, - }) - } - } - deserializer.deserialize_struct("datafusion.TableReference", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for TimeUnit { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - let variant = match self { - Self::Second => "Second", - Self::Millisecond => "Millisecond", - Self::Microsecond => "Microsecond", - Self::Nanosecond => "Nanosecond", - }; - serializer.serialize_str(variant) - } -} -impl<'de> serde::Deserialize<'de> for TimeUnit { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "Second", - "Millisecond", - "Microsecond", - "Nanosecond", - ]; - - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = TimeUnit; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - fn visit_i64(self, v: i64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) - }) - } - - fn visit_u64(self, v: u64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) - }) - } - - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "Second" => Ok(TimeUnit::Second), - "Millisecond" => Ok(TimeUnit::Millisecond), - "Microsecond" => Ok(TimeUnit::Microsecond), - "Nanosecond" => Ok(TimeUnit::Nanosecond), - _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + GeneratedField::JoinType => { + if join_type__.is_some() { + return Err(serde::de::Error::duplicate_field("joinType")); + } + join_type__ = Some(map_.next_value::()? as i32); + } + GeneratedField::PartitionMode => { + if partition_mode__.is_some() { + return Err(serde::de::Error::duplicate_field("partitionMode")); + } + partition_mode__ = Some(map_.next_value::()? as i32); + } + GeneratedField::NullEqualsNull => { + if null_equals_null__.is_some() { + return Err(serde::de::Error::duplicate_field("nullEqualsNull")); + } + null_equals_null__ = Some(map_.next_value()?); + } + GeneratedField::Filter => { + if filter__.is_some() { + return Err(serde::de::Error::duplicate_field("filter")); + } + filter__ = map_.next_value()?; + } + GeneratedField::LeftSortExprs => { + if left_sort_exprs__.is_some() { + return Err(serde::de::Error::duplicate_field("leftSortExprs")); + } + left_sort_exprs__ = Some(map_.next_value()?); + } + GeneratedField::RightSortExprs => { + if right_sort_exprs__.is_some() { + return Err(serde::de::Error::duplicate_field("rightSortExprs")); + } + right_sort_exprs__ = Some(map_.next_value()?); + } + } } + Ok(SymmetricHashJoinExecNode { + left: left__, + right: right__, + on: on__.unwrap_or_default(), + join_type: join_type__.unwrap_or_default(), + partition_mode: partition_mode__.unwrap_or_default(), + null_equals_null: null_equals_null__.unwrap_or_default(), + filter: filter__, + left_sort_exprs: left_sort_exprs__.unwrap_or_default(), + right_sort_exprs: right_sort_exprs__.unwrap_or_default(), + }) } } - deserializer.deserialize_any(GeneratedVisitor) + deserializer.deserialize_struct("datafusion.SymmetricHashJoinExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for Timestamp { +impl serde::Serialize for TableReference { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -25971,40 +19521,43 @@ impl serde::Serialize for Timestamp { { use serde::ser::SerializeStruct; let mut len = 0; - if self.time_unit != 0 { - len += 1; - } - if !self.timezone.is_empty() { + if self.table_reference_enum.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.Timestamp", len)?; - if self.time_unit != 0 { - let v = TimeUnit::try_from(self.time_unit) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.time_unit)))?; - struct_ser.serialize_field("timeUnit", &v)?; - } - if !self.timezone.is_empty() { - struct_ser.serialize_field("timezone", &self.timezone)?; + let mut struct_ser = serializer.serialize_struct("datafusion.TableReference", len)?; + if let Some(v) = self.table_reference_enum.as_ref() { + match v { + table_reference::TableReferenceEnum::Bare(v) => { + struct_ser.serialize_field("bare", v)?; + } + table_reference::TableReferenceEnum::Partial(v) => { + struct_ser.serialize_field("partial", v)?; + } + table_reference::TableReferenceEnum::Full(v) => { + struct_ser.serialize_field("full", v)?; + } + } } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for Timestamp { +impl<'de> serde::Deserialize<'de> for TableReference { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "time_unit", - "timeUnit", - "timezone", + "bare", + "partial", + "full", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - TimeUnit, - Timezone, + Bare, + Partial, + Full, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -26026,8 +19579,9 @@ impl<'de> serde::Deserialize<'de> for Timestamp { E: serde::de::Error, { match value { - "timeUnit" | "time_unit" => Ok(GeneratedField::TimeUnit), - "timezone" => Ok(GeneratedField::Timezone), + "bare" => Ok(GeneratedField::Bare), + "partial" => Ok(GeneratedField::Partial), + "full" => Ok(GeneratedField::Full), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -26037,41 +19591,48 @@ impl<'de> serde::Deserialize<'de> for Timestamp { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = Timestamp; + type Value = TableReference; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.Timestamp") + formatter.write_str("struct datafusion.TableReference") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut time_unit__ = None; - let mut timezone__ = None; + let mut table_reference_enum__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::TimeUnit => { - if time_unit__.is_some() { - return Err(serde::de::Error::duplicate_field("timeUnit")); + GeneratedField::Bare => { + if table_reference_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("bare")); + } + table_reference_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(table_reference::TableReferenceEnum::Bare) +; + } + GeneratedField::Partial => { + if table_reference_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("partial")); } - time_unit__ = Some(map_.next_value::()? as i32); + table_reference_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(table_reference::TableReferenceEnum::Partial) +; } - GeneratedField::Timezone => { - if timezone__.is_some() { - return Err(serde::de::Error::duplicate_field("timezone")); + GeneratedField::Full => { + if table_reference_enum__.is_some() { + return Err(serde::de::Error::duplicate_field("full")); } - timezone__ = Some(map_.next_value()?); + table_reference_enum__ = map_.next_value::<::std::option::Option<_>>()?.map(table_reference::TableReferenceEnum::Full) +; } } } - Ok(Timestamp { - time_unit: time_unit__.unwrap_or_default(), - timezone: timezone__.unwrap_or_default(), + Ok(TableReference { + table_reference_enum: table_reference_enum__, }) } } - deserializer.deserialize_struct("datafusion.Timestamp", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.TableReference", FIELDS, GeneratedVisitor) } } impl serde::Serialize for TryCastNode { @@ -26183,139 +19744,6 @@ impl<'de> serde::Deserialize<'de> for TryCastNode { deserializer.deserialize_struct("datafusion.TryCastNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for Union { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if !self.union_types.is_empty() { - len += 1; - } - if self.union_mode != 0 { - len += 1; - } - if !self.type_ids.is_empty() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.Union", len)?; - if !self.union_types.is_empty() { - struct_ser.serialize_field("unionTypes", &self.union_types)?; - } - if self.union_mode != 0 { - let v = UnionMode::try_from(self.union_mode) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.union_mode)))?; - struct_ser.serialize_field("unionMode", &v)?; - } - if !self.type_ids.is_empty() { - struct_ser.serialize_field("typeIds", &self.type_ids)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for Union { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "union_types", - "unionTypes", - "union_mode", - "unionMode", - "type_ids", - "typeIds", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - UnionTypes, - UnionMode, - TypeIds, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "unionTypes" | "union_types" => Ok(GeneratedField::UnionTypes), - "unionMode" | "union_mode" => Ok(GeneratedField::UnionMode), - "typeIds" | "type_ids" => Ok(GeneratedField::TypeIds), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = Union; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.Union") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut union_types__ = None; - let mut union_mode__ = None; - let mut type_ids__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::UnionTypes => { - if union_types__.is_some() { - return Err(serde::de::Error::duplicate_field("unionTypes")); - } - union_types__ = Some(map_.next_value()?); - } - GeneratedField::UnionMode => { - if union_mode__.is_some() { - return Err(serde::de::Error::duplicate_field("unionMode")); - } - union_mode__ = Some(map_.next_value::()? as i32); - } - GeneratedField::TypeIds => { - if type_ids__.is_some() { - return Err(serde::de::Error::duplicate_field("typeIds")); - } - type_ids__ = - Some(map_.next_value::>>()? - .into_iter().map(|x| x.0).collect()) - ; - } - } - } - Ok(Union { - union_types: union_types__.unwrap_or_default(), - union_mode: union_mode__.unwrap_or_default(), - type_ids: type_ids__.unwrap_or_default(), - }) - } - } - deserializer.deserialize_struct("datafusion.Union", FIELDS, GeneratedVisitor) - } -} impl serde::Serialize for UnionExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -26407,7 +19835,7 @@ impl<'de> serde::Deserialize<'de> for UnionExecNode { deserializer.deserialize_struct("datafusion.UnionExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for UnionField { +impl serde::Serialize for UnionNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -26415,38 +19843,29 @@ impl serde::Serialize for UnionField { { use serde::ser::SerializeStruct; let mut len = 0; - if self.field_id != 0 { - len += 1; - } - if self.field.is_some() { + if !self.inputs.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.UnionField", len)?; - if self.field_id != 0 { - struct_ser.serialize_field("fieldId", &self.field_id)?; - } - if let Some(v) = self.field.as_ref() { - struct_ser.serialize_field("field", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.UnionNode", len)?; + if !self.inputs.is_empty() { + struct_ser.serialize_field("inputs", &self.inputs)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for UnionField { +impl<'de> serde::Deserialize<'de> for UnionNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "field_id", - "fieldId", - "field", + "inputs", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - FieldId, - Field, + Inputs, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -26468,8 +19887,7 @@ impl<'de> serde::Deserialize<'de> for UnionField { E: serde::de::Error, { match value { - "fieldId" | "field_id" => Ok(GeneratedField::FieldId), - "field" => Ok(GeneratedField::Field), + "inputs" => Ok(GeneratedField::Inputs), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -26479,117 +19897,36 @@ impl<'de> serde::Deserialize<'de> for UnionField { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = UnionField; + type Value = UnionNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.UnionField") + formatter.write_str("struct datafusion.UnionNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut field_id__ = None; - let mut field__ = None; + let mut inputs__ = None; while let Some(k) = map_.next_key()? { - match k { - GeneratedField::FieldId => { - if field_id__.is_some() { - return Err(serde::de::Error::duplicate_field("fieldId")); - } - field_id__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::Field => { - if field__.is_some() { - return Err(serde::de::Error::duplicate_field("field")); - } - field__ = map_.next_value()?; - } - } - } - Ok(UnionField { - field_id: field_id__.unwrap_or_default(), - field: field__, - }) - } - } - deserializer.deserialize_struct("datafusion.UnionField", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for UnionMode { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - let variant = match self { - Self::Sparse => "sparse", - Self::Dense => "dense", - }; - serializer.serialize_str(variant) - } -} -impl<'de> serde::Deserialize<'de> for UnionMode { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "sparse", - "dense", - ]; - - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = UnionMode; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - fn visit_i64(self, v: i64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) - }) - } - - fn visit_u64(self, v: u64) -> std::result::Result - where - E: serde::de::Error, - { - i32::try_from(v) - .ok() - .and_then(|x| x.try_into().ok()) - .ok_or_else(|| { - serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) - }) - } - - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "sparse" => Ok(UnionMode::Sparse), - "dense" => Ok(UnionMode::Dense), - _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + match k { + GeneratedField::Inputs => { + if inputs__.is_some() { + return Err(serde::de::Error::duplicate_field("inputs")); + } + inputs__ = Some(map_.next_value()?); + } + } } + Ok(UnionNode { + inputs: inputs__.unwrap_or_default(), + }) } } - deserializer.deserialize_any(GeneratedVisitor) + deserializer.deserialize_struct("datafusion.UnionNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for UnionNode { +impl serde::Serialize for Unnest { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -26597,29 +19934,29 @@ impl serde::Serialize for UnionNode { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.inputs.is_empty() { + if !self.exprs.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.UnionNode", len)?; - if !self.inputs.is_empty() { - struct_ser.serialize_field("inputs", &self.inputs)?; + let mut struct_ser = serializer.serialize_struct("datafusion.Unnest", len)?; + if !self.exprs.is_empty() { + struct_ser.serialize_field("exprs", &self.exprs)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for UnionNode { +impl<'de> serde::Deserialize<'de> for Unnest { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "inputs", + "exprs", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Inputs, + Exprs, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -26641,7 +19978,7 @@ impl<'de> serde::Deserialize<'de> for UnionNode { E: serde::de::Error, { match value { - "inputs" => Ok(GeneratedField::Inputs), + "exprs" => Ok(GeneratedField::Exprs), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -26651,36 +19988,36 @@ impl<'de> serde::Deserialize<'de> for UnionNode { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = UnionNode; + type Value = Unnest; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.UnionNode") + formatter.write_str("struct datafusion.Unnest") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut inputs__ = None; + let mut exprs__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Inputs => { - if inputs__.is_some() { - return Err(serde::de::Error::duplicate_field("inputs")); + GeneratedField::Exprs => { + if exprs__.is_some() { + return Err(serde::de::Error::duplicate_field("exprs")); } - inputs__ = Some(map_.next_value()?); + exprs__ = Some(map_.next_value()?); } } } - Ok(UnionNode { - inputs: inputs__.unwrap_or_default(), + Ok(Unnest { + exprs: exprs__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.UnionNode", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.Unnest", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for UnionValue { +impl serde::Serialize for UnnestExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -26688,56 +20025,63 @@ impl serde::Serialize for UnionValue { { use serde::ser::SerializeStruct; let mut len = 0; - if self.value_id != 0 { + if self.input.is_some() { len += 1; } - if self.value.is_some() { + if self.schema.is_some() { len += 1; } - if !self.fields.is_empty() { + if !self.list_type_columns.is_empty() { len += 1; } - if self.mode != 0 { + if !self.struct_type_columns.is_empty() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.UnionValue", len)?; - if self.value_id != 0 { - struct_ser.serialize_field("valueId", &self.value_id)?; + if self.options.is_some() { + len += 1; } - if let Some(v) = self.value.as_ref() { - struct_ser.serialize_field("value", v)?; + let mut struct_ser = serializer.serialize_struct("datafusion.UnnestExecNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; } - if !self.fields.is_empty() { - struct_ser.serialize_field("fields", &self.fields)?; + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; } - if self.mode != 0 { - let v = UnionMode::try_from(self.mode) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.mode)))?; - struct_ser.serialize_field("mode", &v)?; + if !self.list_type_columns.is_empty() { + struct_ser.serialize_field("listTypeColumns", &self.list_type_columns)?; + } + if !self.struct_type_columns.is_empty() { + struct_ser.serialize_field("structTypeColumns", &self.struct_type_columns.iter().map(ToString::to_string).collect::>())?; + } + if let Some(v) = self.options.as_ref() { + struct_ser.serialize_field("options", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for UnionValue { +impl<'de> serde::Deserialize<'de> for UnnestExecNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "value_id", - "valueId", - "value", - "fields", - "mode", + "input", + "schema", + "list_type_columns", + "listTypeColumns", + "struct_type_columns", + "structTypeColumns", + "options", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - ValueId, - Value, - Fields, - Mode, + Input, + Schema, + ListTypeColumns, + StructTypeColumns, + Options, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -26759,10 +20103,11 @@ impl<'de> serde::Deserialize<'de> for UnionValue { E: serde::de::Error, { match value { - "valueId" | "value_id" => Ok(GeneratedField::ValueId), - "value" => Ok(GeneratedField::Value), - "fields" => Ok(GeneratedField::Fields), - "mode" => Ok(GeneratedField::Mode), + "input" => Ok(GeneratedField::Input), + "schema" => Ok(GeneratedField::Schema), + "listTypeColumns" | "list_type_columns" => Ok(GeneratedField::ListTypeColumns), + "structTypeColumns" | "struct_type_columns" => Ok(GeneratedField::StructTypeColumns), + "options" => Ok(GeneratedField::Options), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -26772,62 +20117,71 @@ impl<'de> serde::Deserialize<'de> for UnionValue { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = UnionValue; + type Value = UnnestExecNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.UnionValue") + formatter.write_str("struct datafusion.UnnestExecNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut value_id__ = None; - let mut value__ = None; - let mut fields__ = None; - let mut mode__ = None; + let mut input__ = None; + let mut schema__ = None; + let mut list_type_columns__ = None; + let mut struct_type_columns__ = None; + let mut options__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::ValueId => { - if value_id__.is_some() { - return Err(serde::de::Error::duplicate_field("valueId")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); } - value_id__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; + input__ = map_.next_value()?; } - GeneratedField::Value => { - if value__.is_some() { - return Err(serde::de::Error::duplicate_field("value")); + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); } - value__ = map_.next_value()?; + schema__ = map_.next_value()?; } - GeneratedField::Fields => { - if fields__.is_some() { - return Err(serde::de::Error::duplicate_field("fields")); + GeneratedField::ListTypeColumns => { + if list_type_columns__.is_some() { + return Err(serde::de::Error::duplicate_field("listTypeColumns")); } - fields__ = Some(map_.next_value()?); + list_type_columns__ = Some(map_.next_value()?); } - GeneratedField::Mode => { - if mode__.is_some() { - return Err(serde::de::Error::duplicate_field("mode")); + GeneratedField::StructTypeColumns => { + if struct_type_columns__.is_some() { + return Err(serde::de::Error::duplicate_field("structTypeColumns")); + } + struct_type_columns__ = + Some(map_.next_value::>>()? + .into_iter().map(|x| x.0).collect()) + ; + } + GeneratedField::Options => { + if options__.is_some() { + return Err(serde::de::Error::duplicate_field("options")); } - mode__ = Some(map_.next_value::()? as i32); + options__ = map_.next_value()?; } } } - Ok(UnionValue { - value_id: value_id__.unwrap_or_default(), - value: value__, - fields: fields__.unwrap_or_default(), - mode: mode__.unwrap_or_default(), + Ok(UnnestExecNode { + input: input__, + schema: schema__, + list_type_columns: list_type_columns__.unwrap_or_default(), + struct_type_columns: struct_type_columns__.unwrap_or_default(), + options: options__, }) } } - deserializer.deserialize_struct("datafusion.UnionValue", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.UnnestExecNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for UniqueConstraint { +impl serde::Serialize for UnnestNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -26835,29 +20189,81 @@ impl serde::Serialize for UniqueConstraint { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.indices.is_empty() { + if self.input.is_some() { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.UniqueConstraint", len)?; - if !self.indices.is_empty() { - struct_ser.serialize_field("indices", &self.indices.iter().map(ToString::to_string).collect::>())?; + if !self.exec_columns.is_empty() { + len += 1; + } + if !self.list_type_columns.is_empty() { + len += 1; + } + if !self.struct_type_columns.is_empty() { + len += 1; + } + if !self.dependency_indices.is_empty() { + len += 1; + } + if self.schema.is_some() { + len += 1; + } + if self.options.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.UnnestNode", len)?; + if let Some(v) = self.input.as_ref() { + struct_ser.serialize_field("input", v)?; + } + if !self.exec_columns.is_empty() { + struct_ser.serialize_field("execColumns", &self.exec_columns)?; + } + if !self.list_type_columns.is_empty() { + struct_ser.serialize_field("listTypeColumns", &self.list_type_columns)?; + } + if !self.struct_type_columns.is_empty() { + struct_ser.serialize_field("structTypeColumns", &self.struct_type_columns.iter().map(ToString::to_string).collect::>())?; + } + if !self.dependency_indices.is_empty() { + struct_ser.serialize_field("dependencyIndices", &self.dependency_indices.iter().map(ToString::to_string).collect::>())?; + } + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; + } + if let Some(v) = self.options.as_ref() { + struct_ser.serialize_field("options", v)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for UniqueConstraint { +impl<'de> serde::Deserialize<'de> for UnnestNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "indices", + "input", + "exec_columns", + "execColumns", + "list_type_columns", + "listTypeColumns", + "struct_type_columns", + "structTypeColumns", + "dependency_indices", + "dependencyIndices", + "schema", + "options", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Indices, + Input, + ExecColumns, + ListTypeColumns, + StructTypeColumns, + DependencyIndices, + Schema, + Options, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -26879,7 +20285,13 @@ impl<'de> serde::Deserialize<'de> for UniqueConstraint { E: serde::de::Error, { match value { - "indices" => Ok(GeneratedField::Indices), + "input" => Ok(GeneratedField::Input), + "execColumns" | "exec_columns" => Ok(GeneratedField::ExecColumns), + "listTypeColumns" | "list_type_columns" => Ok(GeneratedField::ListTypeColumns), + "structTypeColumns" | "struct_type_columns" => Ok(GeneratedField::StructTypeColumns), + "dependencyIndices" | "dependency_indices" => Ok(GeneratedField::DependencyIndices), + "schema" => Ok(GeneratedField::Schema), + "options" => Ok(GeneratedField::Options), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -26889,39 +20301,90 @@ impl<'de> serde::Deserialize<'de> for UniqueConstraint { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = UniqueConstraint; + type Value = UnnestNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.UniqueConstraint") + formatter.write_str("struct datafusion.UnnestNode") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut indices__ = None; + let mut input__ = None; + let mut exec_columns__ = None; + let mut list_type_columns__ = None; + let mut struct_type_columns__ = None; + let mut dependency_indices__ = None; + let mut schema__ = None; + let mut options__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Indices => { - if indices__.is_some() { - return Err(serde::de::Error::duplicate_field("indices")); + GeneratedField::Input => { + if input__.is_some() { + return Err(serde::de::Error::duplicate_field("input")); + } + input__ = map_.next_value()?; + } + GeneratedField::ExecColumns => { + if exec_columns__.is_some() { + return Err(serde::de::Error::duplicate_field("execColumns")); + } + exec_columns__ = Some(map_.next_value()?); + } + GeneratedField::ListTypeColumns => { + if list_type_columns__.is_some() { + return Err(serde::de::Error::duplicate_field("listTypeColumns")); + } + list_type_columns__ = Some(map_.next_value()?); + } + GeneratedField::StructTypeColumns => { + if struct_type_columns__.is_some() { + return Err(serde::de::Error::duplicate_field("structTypeColumns")); } - indices__ = + struct_type_columns__ = Some(map_.next_value::>>()? .into_iter().map(|x| x.0).collect()) ; } + GeneratedField::DependencyIndices => { + if dependency_indices__.is_some() { + return Err(serde::de::Error::duplicate_field("dependencyIndices")); + } + dependency_indices__ = + Some(map_.next_value::>>()? + .into_iter().map(|x| x.0).collect()) + ; + } + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); + } + schema__ = map_.next_value()?; + } + GeneratedField::Options => { + if options__.is_some() { + return Err(serde::de::Error::duplicate_field("options")); + } + options__ = map_.next_value()?; + } } } - Ok(UniqueConstraint { - indices: indices__.unwrap_or_default(), + Ok(UnnestNode { + input: input__, + exec_columns: exec_columns__.unwrap_or_default(), + list_type_columns: list_type_columns__.unwrap_or_default(), + struct_type_columns: struct_type_columns__.unwrap_or_default(), + dependency_indices: dependency_indices__.unwrap_or_default(), + schema: schema__, + options: options__, }) } } - deserializer.deserialize_struct("datafusion.UniqueConstraint", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.UnnestNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for Unnest { +impl serde::Serialize for UnnestOptions { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -26929,29 +20392,38 @@ impl serde::Serialize for Unnest { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.exprs.is_empty() { + if self.preserve_nulls { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.Unnest", len)?; - if !self.exprs.is_empty() { - struct_ser.serialize_field("exprs", &self.exprs)?; + if !self.recursions.is_empty() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.UnnestOptions", len)?; + if self.preserve_nulls { + struct_ser.serialize_field("preserveNulls", &self.preserve_nulls)?; + } + if !self.recursions.is_empty() { + struct_ser.serialize_field("recursions", &self.recursions)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for Unnest { +impl<'de> serde::Deserialize<'de> for UnnestOptions { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "exprs", + "preserve_nulls", + "preserveNulls", + "recursions", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Exprs, + PreserveNulls, + Recursions, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -26973,7 +20445,8 @@ impl<'de> serde::Deserialize<'de> for Unnest { E: serde::de::Error, { match value { - "exprs" => Ok(GeneratedField::Exprs), + "preserveNulls" | "preserve_nulls" => Ok(GeneratedField::PreserveNulls), + "recursions" => Ok(GeneratedField::Recursions), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -26983,33 +20456,41 @@ impl<'de> serde::Deserialize<'de> for Unnest { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = Unnest; + type Value = UnnestOptions; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.Unnest") + formatter.write_str("struct datafusion.UnnestOptions") } - fn visit_map(self, mut map_: V) -> std::result::Result + fn visit_map(self, mut map_: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut exprs__ = None; + let mut preserve_nulls__ = None; + let mut recursions__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Exprs => { - if exprs__.is_some() { - return Err(serde::de::Error::duplicate_field("exprs")); + GeneratedField::PreserveNulls => { + if preserve_nulls__.is_some() { + return Err(serde::de::Error::duplicate_field("preserveNulls")); } - exprs__ = Some(map_.next_value()?); + preserve_nulls__ = Some(map_.next_value()?); + } + GeneratedField::Recursions => { + if recursions__.is_some() { + return Err(serde::de::Error::duplicate_field("recursions")); + } + recursions__ = Some(map_.next_value()?); } } } - Ok(Unnest { - exprs: exprs__.unwrap_or_default(), + Ok(UnnestOptions { + preserve_nulls: preserve_nulls__.unwrap_or_default(), + recursions: recursions__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.Unnest", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.UnnestOptions", FIELDS, GeneratedVisitor) } } impl serde::Serialize for ValuesNode { @@ -27029,6 +20510,7 @@ impl serde::Serialize for ValuesNode { let mut struct_ser = serializer.serialize_struct("datafusion.ValuesNode", len)?; if self.n_cols != 0 { #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("nCols", ToString::to_string(&self.n_cols).as_str())?; } if !self.values_list.is_empty() { @@ -27403,12 +20885,12 @@ impl serde::Serialize for Wildcard { { use serde::ser::SerializeStruct; let mut len = 0; - if !self.qualifier.is_empty() { + if self.qualifier.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.Wildcard", len)?; - if !self.qualifier.is_empty() { - struct_ser.serialize_field("qualifier", &self.qualifier)?; + if let Some(v) = self.qualifier.as_ref() { + struct_ser.serialize_field("qualifier", v)?; } struct_ser.end() } @@ -27474,12 +20956,12 @@ impl<'de> serde::Deserialize<'de> for Wildcard { if qualifier__.is_some() { return Err(serde::de::Error::duplicate_field("qualifier")); } - qualifier__ = Some(map_.next_value()?); + qualifier__ = map_.next_value()?; } } } Ok(Wildcard { - qualifier: qualifier__.unwrap_or_default(), + qualifier: qualifier__, }) } } @@ -27670,7 +21152,7 @@ impl serde::Serialize for WindowExprNode { { use serde::ser::SerializeStruct; let mut len = 0; - if self.expr.is_some() { + if !self.exprs.is_empty() { len += 1; } if !self.partition_by.is_empty() { @@ -27682,12 +21164,15 @@ impl serde::Serialize for WindowExprNode { if self.window_frame.is_some() { len += 1; } + if self.fun_definition.is_some() { + len += 1; + } if self.window_function.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.WindowExprNode", len)?; - if let Some(v) = self.expr.as_ref() { - struct_ser.serialize_field("expr", v)?; + if !self.exprs.is_empty() { + struct_ser.serialize_field("exprs", &self.exprs)?; } if !self.partition_by.is_empty() { struct_ser.serialize_field("partitionBy", &self.partition_by)?; @@ -27698,13 +21183,13 @@ impl serde::Serialize for WindowExprNode { if let Some(v) = self.window_frame.as_ref() { struct_ser.serialize_field("windowFrame", v)?; } + if let Some(v) = self.fun_definition.as_ref() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("funDefinition", pbjson::private::base64::encode(&v).as_str())?; + } if let Some(v) = self.window_function.as_ref() { match v { - window_expr_node::WindowFunction::AggrFunction(v) => { - let v = AggregateFunction::try_from(*v) - .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; - struct_ser.serialize_field("aggrFunction", &v)?; - } window_expr_node::WindowFunction::BuiltInFunction(v) => { let v = BuiltInWindowFunction::try_from(*v) .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", *v)))?; @@ -27728,15 +21213,15 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "expr", + "exprs", "partition_by", "partitionBy", "order_by", "orderBy", "window_frame", "windowFrame", - "aggr_function", - "aggrFunction", + "fun_definition", + "funDefinition", "built_in_function", "builtInFunction", "udaf", @@ -27745,11 +21230,11 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { #[allow(clippy::enum_variant_names)] enum GeneratedField { - Expr, + Exprs, PartitionBy, OrderBy, WindowFrame, - AggrFunction, + FunDefinition, BuiltInFunction, Udaf, Udwf, @@ -27774,11 +21259,11 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { E: serde::de::Error, { match value { - "expr" => Ok(GeneratedField::Expr), + "exprs" => Ok(GeneratedField::Exprs), "partitionBy" | "partition_by" => Ok(GeneratedField::PartitionBy), "orderBy" | "order_by" => Ok(GeneratedField::OrderBy), "windowFrame" | "window_frame" => Ok(GeneratedField::WindowFrame), - "aggrFunction" | "aggr_function" => Ok(GeneratedField::AggrFunction), + "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), "builtInFunction" | "built_in_function" => Ok(GeneratedField::BuiltInFunction), "udaf" => Ok(GeneratedField::Udaf), "udwf" => Ok(GeneratedField::Udwf), @@ -27801,18 +21286,19 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { where V: serde::de::MapAccess<'de>, { - let mut expr__ = None; + let mut exprs__ = None; let mut partition_by__ = None; let mut order_by__ = None; let mut window_frame__ = None; + let mut fun_definition__ = None; let mut window_function__ = None; while let Some(k) = map_.next_key()? { match k { - GeneratedField::Expr => { - if expr__.is_some() { - return Err(serde::de::Error::duplicate_field("expr")); + GeneratedField::Exprs => { + if exprs__.is_some() { + return Err(serde::de::Error::duplicate_field("exprs")); } - expr__ = map_.next_value()?; + exprs__ = Some(map_.next_value()?); } GeneratedField::PartitionBy => { if partition_by__.is_some() { @@ -27832,11 +21318,13 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { } window_frame__ = map_.next_value()?; } - GeneratedField::AggrFunction => { - if window_function__.is_some() { - return Err(serde::de::Error::duplicate_field("aggrFunction")); + GeneratedField::FunDefinition => { + if fun_definition__.is_some() { + return Err(serde::de::Error::duplicate_field("funDefinition")); } - window_function__ = map_.next_value::<::std::option::Option>()?.map(|x| window_expr_node::WindowFunction::AggrFunction(x as i32)); + fun_definition__ = + map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) + ; } GeneratedField::BuiltInFunction => { if window_function__.is_some() { @@ -27859,10 +21347,11 @@ impl<'de> serde::Deserialize<'de> for WindowExprNode { } } Ok(WindowExprNode { - expr: expr__, + exprs: exprs__.unwrap_or_default(), partition_by: partition_by__.unwrap_or_default(), order_by: order_by__.unwrap_or_default(), window_frame: window_frame__, + fun_definition: fun_definition__, window_function: window_function__, }) } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index d0210eb7cfd3..dfc30e809108 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1,51 +1,16 @@ // This file is @generated by prost-build. -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ColumnRelation { - #[prost(string, tag = "1")] - pub relation: ::prost::alloc::string::String, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Column { - #[prost(string, tag = "1")] - pub name: ::prost::alloc::string::String, - #[prost(message, optional, tag = "2")] - pub relation: ::core::option::Option, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct DfField { - #[prost(message, optional, tag = "1")] - pub field: ::core::option::Option, - #[prost(message, optional, tag = "2")] - pub qualifier: ::core::option::Option, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct DfSchema { - #[prost(message, repeated, tag = "1")] - pub columns: ::prost::alloc::vec::Vec, - #[prost(map = "string, string", tag = "2")] - pub metadata: ::std::collections::HashMap< - ::prost::alloc::string::String, - ::prost::alloc::string::String, - >, -} /// logical plan /// LogicalPlan is a nested type -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct LogicalPlanNode { #[prost( oneof = "logical_plan_node::LogicalPlanType", - tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29" + tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30" )] pub logical_plan_type: ::core::option::Option, } /// Nested message and enum types in `LogicalPlanNode`. pub mod logical_plan_node { - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum LogicalPlanType { #[prost(message, tag = "1")] @@ -104,9 +69,10 @@ pub mod logical_plan_node { DistinctOn(::prost::alloc::boxed::Box), #[prost(message, tag = "29")] CopyTo(::prost::alloc::boxed::Box), + #[prost(message, tag = "30")] + Unnest(::prost::alloc::boxed::Box), } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct LogicalExtensionNode { #[prost(bytes = "vec", tag = "1")] @@ -114,34 +80,21 @@ pub struct LogicalExtensionNode { #[prost(message, repeated, tag = "2")] pub inputs: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ProjectionColumns { #[prost(string, repeated, tag = "1")] pub columns: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct CsvFormat { - #[prost(message, optional, tag = "5")] - pub options: ::core::option::Option, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ParquetFormat { - #[prost(message, optional, tag = "2")] - pub options: ::core::option::Option, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct AvroFormat {} -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct LogicalExprNodeCollection { #[prost(message, repeated, tag = "1")] pub logical_expr_nodes: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct SortExprNodeCollection { + #[prost(message, repeated, tag = "1")] + pub sort_expr_nodes: ::prost::alloc::vec::Vec, +} #[derive(Clone, PartialEq, ::prost::Message)] pub struct ListingTableScanNode { #[prost(message, optional, tag = "14")] @@ -153,7 +106,7 @@ pub struct ListingTableScanNode { #[prost(message, optional, tag = "4")] pub projection: ::core::option::Option, #[prost(message, optional, tag = "5")] - pub schema: ::core::option::Option, + pub schema: ::core::option::Option, #[prost(message, repeated, tag = "6")] pub filters: ::prost::alloc::vec::Vec, #[prost(string, repeated, tag = "7")] @@ -163,26 +116,26 @@ pub struct ListingTableScanNode { #[prost(uint32, tag = "9")] pub target_partitions: u32, #[prost(message, repeated, tag = "13")] - pub file_sort_order: ::prost::alloc::vec::Vec, - #[prost(oneof = "listing_table_scan_node::FileFormatType", tags = "10, 11, 12")] + pub file_sort_order: ::prost::alloc::vec::Vec, + #[prost(oneof = "listing_table_scan_node::FileFormatType", tags = "10, 11, 12, 15")] pub file_format_type: ::core::option::Option< listing_table_scan_node::FileFormatType, >, } /// Nested message and enum types in `ListingTableScanNode`. pub mod listing_table_scan_node { - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum FileFormatType { #[prost(message, tag = "10")] - Csv(super::CsvFormat), + Csv(super::super::datafusion_common::CsvFormat), #[prost(message, tag = "11")] - Parquet(super::ParquetFormat), + Parquet(super::super::datafusion_common::ParquetFormat), #[prost(message, tag = "12")] - Avro(super::AvroFormat), + Avro(super::super::datafusion_common::AvroFormat), + #[prost(message, tag = "15")] + Json(super::super::datafusion_common::NdJsonFormat), } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ViewTableScanNode { #[prost(message, optional, tag = "6")] @@ -190,14 +143,13 @@ pub struct ViewTableScanNode { #[prost(message, optional, boxed, tag = "2")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, optional, tag = "3")] - pub schema: ::core::option::Option, + pub schema: ::core::option::Option, #[prost(message, optional, tag = "4")] pub projection: ::core::option::Option, #[prost(string, tag = "5")] pub definition: ::prost::alloc::string::String, } /// Logical Plan to Scan a CustomTableProvider registered at runtime -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CustomTableScanNode { #[prost(message, optional, tag = "6")] @@ -205,13 +157,12 @@ pub struct CustomTableScanNode { #[prost(message, optional, tag = "2")] pub projection: ::core::option::Option, #[prost(message, optional, tag = "3")] - pub schema: ::core::option::Option, + pub schema: ::core::option::Option, #[prost(message, repeated, tag = "4")] pub filters: ::prost::alloc::vec::Vec, #[prost(bytes = "vec", tag = "5")] pub custom_table_data: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ProjectionNode { #[prost(message, optional, boxed, tag = "1")] @@ -223,14 +174,12 @@ pub struct ProjectionNode { } /// Nested message and enum types in `ProjectionNode`. pub mod projection_node { - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum OptionalAlias { #[prost(string, tag = "3")] Alias(::prost::alloc::string::String), } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct SelectionNode { #[prost(message, optional, boxed, tag = "1")] @@ -238,18 +187,16 @@ pub struct SelectionNode { #[prost(message, optional, tag = "2")] pub expr: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct SortNode { #[prost(message, optional, boxed, tag = "1")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, repeated, tag = "2")] - pub expr: ::prost::alloc::vec::Vec, + pub expr: ::prost::alloc::vec::Vec, /// Maximum number of highest/lowest rows to fetch; negative means no limit #[prost(int64, tag = "3")] pub fetch: i64, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct RepartitionNode { #[prost(message, optional, boxed, tag = "1")] @@ -259,7 +206,6 @@ pub struct RepartitionNode { } /// Nested message and enum types in `RepartitionNode`. pub mod repartition_node { - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum PartitionMethod { #[prost(uint64, tag = "2")] @@ -268,7 +214,6 @@ pub mod repartition_node { Hash(super::HashRepartition), } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct HashRepartition { #[prost(message, repeated, tag = "1")] @@ -276,98 +221,55 @@ pub struct HashRepartition { #[prost(uint64, tag = "2")] pub partition_count: u64, } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct EmptyRelationNode { #[prost(bool, tag = "1")] pub produce_one_row: bool, } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PrimaryKeyConstraint { - #[prost(uint64, repeated, tag = "1")] - pub indices: ::prost::alloc::vec::Vec, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct UniqueConstraint { - #[prost(uint64, repeated, tag = "1")] - pub indices: ::prost::alloc::vec::Vec, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Constraint { - #[prost(oneof = "constraint::ConstraintMode", tags = "1, 2")] - pub constraint_mode: ::core::option::Option, -} -/// Nested message and enum types in `Constraint`. -pub mod constraint { - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum ConstraintMode { - #[prost(message, tag = "1")] - PrimaryKey(super::PrimaryKeyConstraint), - #[prost(message, tag = "2")] - Unique(super::UniqueConstraint), - } -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Constraints { - #[prost(message, repeated, tag = "1")] - pub constraints: ::prost::alloc::vec::Vec, -} -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CreateExternalTableNode { - #[prost(message, optional, tag = "12")] + #[prost(message, optional, tag = "9")] pub name: ::core::option::Option, #[prost(string, tag = "2")] pub location: ::prost::alloc::string::String, #[prost(string, tag = "3")] pub file_type: ::prost::alloc::string::String, - #[prost(bool, tag = "4")] - pub has_header: bool, - #[prost(message, optional, tag = "5")] - pub schema: ::core::option::Option, - #[prost(string, repeated, tag = "6")] + #[prost(message, optional, tag = "4")] + pub schema: ::core::option::Option, + #[prost(string, repeated, tag = "5")] pub table_partition_cols: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, - #[prost(bool, tag = "7")] + #[prost(bool, tag = "6")] pub if_not_exists: bool, - #[prost(string, tag = "8")] - pub delimiter: ::prost::alloc::string::String, - #[prost(string, tag = "9")] - pub definition: ::prost::alloc::string::String, - #[prost(enumeration = "CompressionTypeVariant", tag = "17")] - pub file_compression_type: i32, - #[prost(message, repeated, tag = "13")] - pub order_exprs: ::prost::alloc::vec::Vec, #[prost(bool, tag = "14")] + pub temporary: bool, + #[prost(string, tag = "7")] + pub definition: ::prost::alloc::string::String, + #[prost(message, repeated, tag = "10")] + pub order_exprs: ::prost::alloc::vec::Vec, + #[prost(bool, tag = "11")] pub unbounded: bool, - #[prost(map = "string, string", tag = "11")] + #[prost(map = "string, string", tag = "8")] pub options: ::std::collections::HashMap< ::prost::alloc::string::String, ::prost::alloc::string::String, >, - #[prost(message, optional, tag = "15")] - pub constraints: ::core::option::Option, - #[prost(map = "string, message", tag = "16")] + #[prost(message, optional, tag = "12")] + pub constraints: ::core::option::Option, + #[prost(map = "string, message", tag = "13")] pub column_defaults: ::std::collections::HashMap< ::prost::alloc::string::String, LogicalExprNode, >, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PrepareNode { #[prost(string, tag = "1")] pub name: ::prost::alloc::string::String, #[prost(message, repeated, tag = "2")] - pub data_types: ::prost::alloc::vec::Vec, + pub data_types: ::prost::alloc::vec::Vec, #[prost(message, optional, boxed, tag = "3")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CreateCatalogSchemaNode { #[prost(string, tag = "1")] @@ -375,9 +277,8 @@ pub struct CreateCatalogSchemaNode { #[prost(bool, tag = "2")] pub if_not_exists: bool, #[prost(message, optional, tag = "3")] - pub schema: ::core::option::Option, + pub schema: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CreateCatalogNode { #[prost(string, tag = "1")] @@ -385,9 +286,8 @@ pub struct CreateCatalogNode { #[prost(bool, tag = "2")] pub if_not_exists: bool, #[prost(message, optional, tag = "3")] - pub schema: ::core::option::Option, + pub schema: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct DropViewNode { #[prost(message, optional, tag = "1")] @@ -395,9 +295,8 @@ pub struct DropViewNode { #[prost(bool, tag = "2")] pub if_exists: bool, #[prost(message, optional, tag = "3")] - pub schema: ::core::option::Option, + pub schema: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CreateViewNode { #[prost(message, optional, tag = "5")] @@ -406,12 +305,13 @@ pub struct CreateViewNode { pub input: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(bool, tag = "3")] pub or_replace: bool, + #[prost(bool, tag = "6")] + pub temporary: bool, #[prost(string, tag = "4")] pub definition: ::prost::alloc::string::String, } /// a node containing data for defining values list. unlike in SQL where it's two dimensional, here /// the list is flattened, and with the field n_cols it can be parsed and partitioned into rows -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ValuesNode { #[prost(uint64, tag = "1")] @@ -419,7 +319,6 @@ pub struct ValuesNode { #[prost(message, repeated, tag = "2")] pub values_list: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct AnalyzeNode { #[prost(message, optional, boxed, tag = "1")] @@ -427,7 +326,6 @@ pub struct AnalyzeNode { #[prost(bool, tag = "2")] pub verbose: bool, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ExplainNode { #[prost(message, optional, boxed, tag = "1")] @@ -435,7 +333,6 @@ pub struct ExplainNode { #[prost(bool, tag = "2")] pub verbose: bool, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct AggregateNode { #[prost(message, optional, boxed, tag = "1")] @@ -445,7 +342,6 @@ pub struct AggregateNode { #[prost(message, repeated, tag = "3")] pub aggr_expr: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct WindowNode { #[prost(message, optional, boxed, tag = "1")] @@ -453,16 +349,15 @@ pub struct WindowNode { #[prost(message, repeated, tag = "2")] pub window_expr: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct JoinNode { #[prost(message, optional, boxed, tag = "1")] pub left: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, optional, boxed, tag = "2")] pub right: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(enumeration = "JoinType", tag = "3")] + #[prost(enumeration = "super::datafusion_common::JoinType", tag = "3")] pub join_type: i32, - #[prost(enumeration = "JoinConstraint", tag = "4")] + #[prost(enumeration = "super::datafusion_common::JoinConstraint", tag = "4")] pub join_constraint: i32, #[prost(message, repeated, tag = "5")] pub left_join_key: ::prost::alloc::vec::Vec, @@ -473,13 +368,11 @@ pub struct JoinNode { #[prost(message, optional, tag = "8")] pub filter: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct DistinctNode { #[prost(message, optional, boxed, tag = "1")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct DistinctOnNode { #[prost(message, repeated, tag = "1")] @@ -487,52 +380,78 @@ pub struct DistinctOnNode { #[prost(message, repeated, tag = "2")] pub select_expr: ::prost::alloc::vec::Vec, #[prost(message, repeated, tag = "3")] - pub sort_expr: ::prost::alloc::vec::Vec, + pub sort_expr: ::prost::alloc::vec::Vec, #[prost(message, optional, boxed, tag = "4")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CopyToNode { #[prost(message, optional, boxed, tag = "1")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(string, tag = "2")] pub output_url: ::prost::alloc::string::String, + #[prost(bytes = "vec", tag = "3")] + pub file_type: ::prost::alloc::vec::Vec, #[prost(string, repeated, tag = "7")] pub partition_by: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, - #[prost(oneof = "copy_to_node::FormatOptions", tags = "8, 9, 10, 11, 12")] - pub format_options: ::core::option::Option, } -/// Nested message and enum types in `CopyToNode`. -pub mod copy_to_node { - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum FormatOptions { - #[prost(message, tag = "8")] - Csv(super::CsvOptions), - #[prost(message, tag = "9")] - Json(super::JsonOptions), - #[prost(message, tag = "10")] - Parquet(super::TableParquetOptions), - #[prost(message, tag = "11")] - Avro(super::AvroOptions), - #[prost(message, tag = "12")] - Arrow(super::ArrowOptions), - } +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct UnnestNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "2")] + pub exec_columns: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "3")] + pub list_type_columns: ::prost::alloc::vec::Vec, + #[prost(uint64, repeated, tag = "4")] + pub struct_type_columns: ::prost::alloc::vec::Vec, + #[prost(uint64, repeated, tag = "5")] + pub dependency_indices: ::prost::alloc::vec::Vec, + #[prost(message, optional, tag = "6")] + pub schema: ::core::option::Option, + #[prost(message, optional, tag = "7")] + pub options: ::core::option::Option, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ColumnUnnestListItem { + #[prost(uint32, tag = "1")] + pub input_index: u32, + #[prost(message, optional, tag = "2")] + pub recursion: ::core::option::Option, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ColumnUnnestListRecursions { + #[prost(message, repeated, tag = "2")] + pub recursions: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] -pub struct AvroOptions {} -#[allow(clippy::derive_partial_eq_without_eq)] +pub struct ColumnUnnestListRecursion { + #[prost(message, optional, tag = "1")] + pub output_column: ::core::option::Option, + #[prost(uint32, tag = "2")] + pub depth: u32, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct UnnestOptions { + #[prost(bool, tag = "1")] + pub preserve_nulls: bool, + #[prost(message, repeated, tag = "2")] + pub recursions: ::prost::alloc::vec::Vec, +} #[derive(Clone, PartialEq, ::prost::Message)] -pub struct ArrowOptions {} -#[allow(clippy::derive_partial_eq_without_eq)] +pub struct RecursionUnnestOption { + #[prost(message, optional, tag = "1")] + pub output_column: ::core::option::Option, + #[prost(message, optional, tag = "2")] + pub input_column: ::core::option::Option, + #[prost(uint32, tag = "3")] + pub depth: u32, +} #[derive(Clone, PartialEq, ::prost::Message)] pub struct UnionNode { #[prost(message, repeated, tag = "1")] pub inputs: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CrossJoinNode { #[prost(message, optional, boxed, tag = "1")] @@ -540,7 +459,6 @@ pub struct CrossJoinNode { #[prost(message, optional, boxed, tag = "2")] pub right: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct LimitNode { #[prost(message, optional, boxed, tag = "1")] @@ -552,13 +470,11 @@ pub struct LimitNode { #[prost(int64, tag = "3")] pub fetch: i64, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct SelectionExecNode { #[prost(message, optional, tag = "1")] pub expr: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct SubqueryAliasNode { #[prost(message, optional, boxed, tag = "1")] @@ -567,34 +483,29 @@ pub struct SubqueryAliasNode { pub alias: ::core::option::Option, } /// logical expressions -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct LogicalExprNode { #[prost( oneof = "logical_expr_node::ExprType", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 13, 14, 15, 17, 18, 19, 20, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35" )] pub expr_type: ::core::option::Option, } /// Nested message and enum types in `LogicalExprNode`. pub mod logical_expr_node { - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum ExprType { /// column references #[prost(message, tag = "1")] - Column(super::Column), + Column(super::super::datafusion_common::Column), /// alias #[prost(message, tag = "2")] Alias(::prost::alloc::boxed::Box), #[prost(message, tag = "3")] - Literal(super::ScalarValue), + Literal(super::super::datafusion_common::ScalarValue), /// binary expressions #[prost(message, tag = "4")] BinaryExpr(super::BinaryExprNode), - /// aggregate expressions - #[prost(message, tag = "5")] - AggregateExpr(::prost::alloc::boxed::Box), /// null checks #[prost(message, tag = "6")] IsNullExpr(::prost::alloc::boxed::Box), @@ -608,27 +519,24 @@ pub mod logical_expr_node { Case(::prost::alloc::boxed::Box), #[prost(message, tag = "11")] Cast(::prost::alloc::boxed::Box), - #[prost(message, tag = "12")] - Sort(::prost::alloc::boxed::Box), #[prost(message, tag = "13")] Negative(::prost::alloc::boxed::Box), #[prost(message, tag = "14")] InList(::prost::alloc::boxed::Box), #[prost(message, tag = "15")] Wildcard(super::Wildcard), + /// was ScalarFunctionNode scalar_function = 16; #[prost(message, tag = "17")] TryCast(::prost::alloc::boxed::Box), /// window expressions #[prost(message, tag = "18")] - WindowExpr(::prost::alloc::boxed::Box), + WindowExpr(super::WindowExprNode), /// AggregateUDF expressions #[prost(message, tag = "19")] AggregateUdfExpr(::prost::alloc::boxed::Box), /// Scalar UDF expressions #[prost(message, tag = "20")] ScalarUdfExpr(super::ScalarUdfExprNode), - #[prost(message, tag = "21")] - GetIndexedField(::prost::alloc::boxed::Box), #[prost(message, tag = "22")] GroupingSet(super::GroupingSetNode), #[prost(message, tag = "23")] @@ -659,142 +567,102 @@ pub mod logical_expr_node { Unnest(super::Unnest), } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Wildcard { - #[prost(string, tag = "1")] - pub qualifier: ::prost::alloc::string::String, + #[prost(message, optional, tag = "1")] + pub qualifier: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PlaceholderNode { #[prost(string, tag = "1")] pub id: ::prost::alloc::string::String, #[prost(message, optional, tag = "2")] - pub data_type: ::core::option::Option, + pub data_type: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct LogicalExprList { #[prost(message, repeated, tag = "1")] pub expr: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct GroupingSetNode { #[prost(message, repeated, tag = "1")] pub expr: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CubeNode { #[prost(message, repeated, tag = "1")] pub expr: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct RollupNode { #[prost(message, repeated, tag = "1")] pub expr: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct NamedStructField { #[prost(message, optional, tag = "1")] - pub name: ::core::option::Option, + pub name: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ListIndex { - #[prost(message, optional, boxed, tag = "1")] - pub key: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, tag = "1")] + pub key: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ListRange { - #[prost(message, optional, boxed, tag = "1")] - pub start: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(message, optional, boxed, tag = "2")] - pub stop: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(message, optional, boxed, tag = "3")] - pub stride: ::core::option::Option<::prost::alloc::boxed::Box>, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct GetIndexedField { - #[prost(message, optional, boxed, tag = "1")] - pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(oneof = "get_indexed_field::Field", tags = "2, 3, 4")] - pub field: ::core::option::Option, -} -/// Nested message and enum types in `GetIndexedField`. -pub mod get_indexed_field { - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum Field { - #[prost(message, tag = "2")] - NamedStructField(super::NamedStructField), - #[prost(message, tag = "3")] - ListIndex(::prost::alloc::boxed::Box), - #[prost(message, tag = "4")] - ListRange(::prost::alloc::boxed::Box), - } + #[prost(message, optional, tag = "1")] + pub start: ::core::option::Option, + #[prost(message, optional, tag = "2")] + pub stop: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub stride: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct IsNull { #[prost(message, optional, boxed, tag = "1")] pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct IsNotNull { #[prost(message, optional, boxed, tag = "1")] pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct IsTrue { #[prost(message, optional, boxed, tag = "1")] pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct IsFalse { #[prost(message, optional, boxed, tag = "1")] pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct IsUnknown { #[prost(message, optional, boxed, tag = "1")] pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct IsNotTrue { #[prost(message, optional, boxed, tag = "1")] pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct IsNotFalse { #[prost(message, optional, boxed, tag = "1")] pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct IsNotUnknown { #[prost(message, optional, boxed, tag = "1")] pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Not { #[prost(message, optional, boxed, tag = "1")] pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct AliasNode { #[prost(message, optional, boxed, tag = "1")] @@ -804,7 +672,6 @@ pub struct AliasNode { #[prost(message, repeated, tag = "3")] pub relation: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct BinaryExprNode { /// Represents the operands from the left inner most expression @@ -815,19 +682,16 @@ pub struct BinaryExprNode { #[prost(string, tag = "3")] pub op: ::prost::alloc::string::String, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct NegativeNode { #[prost(message, optional, boxed, tag = "1")] pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct Unnest { #[prost(message, repeated, tag = "1")] pub exprs: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct InListNode { #[prost(message, optional, boxed, tag = "1")] @@ -837,33 +701,21 @@ pub struct InListNode { #[prost(bool, tag = "3")] pub negated: bool, } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct AggregateExprNode { - #[prost(enumeration = "AggregateFunction", tag = "1")] - pub aggr_function: i32, - #[prost(message, repeated, tag = "2")] - pub expr: ::prost::alloc::vec::Vec, - #[prost(bool, tag = "3")] - pub distinct: bool, - #[prost(message, optional, boxed, tag = "4")] - pub filter: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(message, repeated, tag = "5")] - pub order_by: ::prost::alloc::vec::Vec, -} -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct AggregateUdfExprNode { #[prost(string, tag = "1")] pub fun_name: ::prost::alloc::string::String, #[prost(message, repeated, tag = "2")] pub args: ::prost::alloc::vec::Vec, + #[prost(bool, tag = "5")] + pub distinct: bool, #[prost(message, optional, boxed, tag = "3")] pub filter: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, repeated, tag = "4")] - pub order_by: ::prost::alloc::vec::Vec, + pub order_by: ::prost::alloc::vec::Vec, + #[prost(bytes = "vec", optional, tag = "6")] + pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ScalarUdfExprNode { #[prost(string, tag = "1")] @@ -873,28 +725,26 @@ pub struct ScalarUdfExprNode { #[prost(bytes = "vec", optional, tag = "3")] pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct WindowExprNode { - #[prost(message, optional, boxed, tag = "4")] - pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "4")] + pub exprs: ::prost::alloc::vec::Vec, #[prost(message, repeated, tag = "5")] pub partition_by: ::prost::alloc::vec::Vec, #[prost(message, repeated, tag = "6")] - pub order_by: ::prost::alloc::vec::Vec, + pub order_by: ::prost::alloc::vec::Vec, /// repeated LogicalExprNode filter = 7; #[prost(message, optional, tag = "8")] pub window_frame: ::core::option::Option, - #[prost(oneof = "window_expr_node::WindowFunction", tags = "1, 2, 3, 9")] + #[prost(bytes = "vec", optional, tag = "10")] + pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, + #[prost(oneof = "window_expr_node::WindowFunction", tags = "2, 3, 9")] pub window_function: ::core::option::Option, } /// Nested message and enum types in `WindowExprNode`. pub mod window_expr_node { - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum WindowFunction { - #[prost(enumeration = "super::AggregateFunction", tag = "1")] - AggrFunction(i32), #[prost(enumeration = "super::BuiltInWindowFunction", tag = "2")] BuiltInFunction(i32), #[prost(string, tag = "3")] @@ -903,7 +753,6 @@ pub mod window_expr_node { Udwf(::prost::alloc::string::String), } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct BetweenNode { #[prost(message, optional, boxed, tag = "1")] @@ -915,7 +764,6 @@ pub struct BetweenNode { #[prost(message, optional, boxed, tag = "4")] pub high: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct LikeNode { #[prost(bool, tag = "1")] @@ -927,7 +775,6 @@ pub struct LikeNode { #[prost(string, tag = "4")] pub escape_char: ::prost::alloc::string::String, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ILikeNode { #[prost(bool, tag = "1")] @@ -939,7 +786,6 @@ pub struct ILikeNode { #[prost(string, tag = "4")] pub escape_char: ::prost::alloc::string::String, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct SimilarToNode { #[prost(bool, tag = "1")] @@ -951,7 +797,6 @@ pub struct SimilarToNode { #[prost(string, tag = "4")] pub escape_char: ::prost::alloc::string::String, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CaseNode { #[prost(message, optional, boxed, tag = "1")] @@ -961,7 +806,6 @@ pub struct CaseNode { #[prost(message, optional, boxed, tag = "3")] pub else_expr: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct WhenThen { #[prost(message, optional, tag = "1")] @@ -969,33 +813,29 @@ pub struct WhenThen { #[prost(message, optional, tag = "2")] pub then_expr: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CastNode { #[prost(message, optional, boxed, tag = "1")] pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, optional, tag = "2")] - pub arrow_type: ::core::option::Option, + pub arrow_type: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct TryCastNode { #[prost(message, optional, boxed, tag = "1")] pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, optional, tag = "2")] - pub arrow_type: ::core::option::Option, + pub arrow_type: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct SortExprNode { - #[prost(message, optional, boxed, tag = "1")] - pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, tag = "1")] + pub expr: ::core::option::Option, #[prost(bool, tag = "2")] pub asc: bool, #[prost(bool, tag = "3")] pub nulls_first: bool, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct WindowFrame { #[prost(enumeration = "WindowFrameUnits", tag = "1")] @@ -1011,504 +851,77 @@ pub struct WindowFrame { pub mod window_frame { /// "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see and ) /// this syntax is ugly but is binary compatible with the "optional" keyword (see ) - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum EndBound { #[prost(message, tag = "3")] Bound(super::WindowFrameBound), } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct WindowFrameBound { #[prost(enumeration = "WindowFrameBoundType", tag = "1")] pub window_frame_bound_type: i32, #[prost(message, optional, tag = "2")] - pub bound_value: ::core::option::Option, + pub bound_value: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Schema { - #[prost(message, repeated, tag = "1")] - pub columns: ::prost::alloc::vec::Vec, - #[prost(map = "string, string", tag = "2")] - pub metadata: ::std::collections::HashMap< - ::prost::alloc::string::String, - ::prost::alloc::string::String, - >, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Field { - /// name of the field - #[prost(string, tag = "1")] - pub name: ::prost::alloc::string::String, - #[prost(message, optional, boxed, tag = "2")] - pub arrow_type: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(bool, tag = "3")] - pub nullable: bool, - /// for complex data types like structs, unions - #[prost(message, repeated, tag = "4")] - pub children: ::prost::alloc::vec::Vec, - #[prost(map = "string, string", tag = "5")] - pub metadata: ::std::collections::HashMap< - ::prost::alloc::string::String, - ::prost::alloc::string::String, - >, - #[prost(int64, tag = "6")] - pub dict_id: i64, - #[prost(bool, tag = "7")] - pub dict_ordered: bool, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct FixedSizeBinary { #[prost(int32, tag = "1")] pub length: i32, } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Timestamp { - #[prost(enumeration = "TimeUnit", tag = "1")] - pub time_unit: i32, - #[prost(string, tag = "2")] - pub timezone: ::prost::alloc::string::String, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Decimal { - #[prost(uint32, tag = "3")] - pub precision: u32, - #[prost(int32, tag = "4")] - pub scale: i32, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct List { - #[prost(message, optional, boxed, tag = "1")] - pub field_type: ::core::option::Option<::prost::alloc::boxed::Box>, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct FixedSizeList { - #[prost(message, optional, boxed, tag = "1")] - pub field_type: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(int32, tag = "2")] - pub list_size: i32, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Dictionary { - #[prost(message, optional, boxed, tag = "1")] - pub key: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(message, optional, boxed, tag = "2")] - pub value: ::core::option::Option<::prost::alloc::boxed::Box>, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Struct { - #[prost(message, repeated, tag = "1")] - pub sub_field_types: ::prost::alloc::vec::Vec, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Map { - #[prost(message, optional, boxed, tag = "1")] - pub field_type: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(bool, tag = "2")] - pub keys_sorted: bool, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Union { - #[prost(message, repeated, tag = "1")] - pub union_types: ::prost::alloc::vec::Vec, - #[prost(enumeration = "UnionMode", tag = "2")] - pub union_mode: i32, - #[prost(int32, repeated, tag = "3")] - pub type_ids: ::prost::alloc::vec::Vec, -} -/// Used for List/FixedSizeList/LargeList/Struct -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ScalarNestedValue { - #[prost(bytes = "vec", tag = "1")] - pub ipc_message: ::prost::alloc::vec::Vec, - #[prost(bytes = "vec", tag = "2")] - pub arrow_data: ::prost::alloc::vec::Vec, - #[prost(message, optional, tag = "3")] - pub schema: ::core::option::Option, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ScalarTime32Value { - #[prost(oneof = "scalar_time32_value::Value", tags = "1, 2")] - pub value: ::core::option::Option, -} -/// Nested message and enum types in `ScalarTime32Value`. -pub mod scalar_time32_value { - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum Value { - #[prost(int32, tag = "1")] - Time32SecondValue(i32), - #[prost(int32, tag = "2")] - Time32MillisecondValue(i32), - } -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ScalarTime64Value { - #[prost(oneof = "scalar_time64_value::Value", tags = "1, 2")] - pub value: ::core::option::Option, -} -/// Nested message and enum types in `ScalarTime64Value`. -pub mod scalar_time64_value { - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum Value { - #[prost(int64, tag = "1")] - Time64MicrosecondValue(i64), - #[prost(int64, tag = "2")] - Time64NanosecondValue(i64), - } -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ScalarTimestampValue { - #[prost(string, tag = "5")] - pub timezone: ::prost::alloc::string::String, - #[prost(oneof = "scalar_timestamp_value::Value", tags = "1, 2, 3, 4")] - pub value: ::core::option::Option, -} -/// Nested message and enum types in `ScalarTimestampValue`. -pub mod scalar_timestamp_value { - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum Value { - #[prost(int64, tag = "1")] - TimeMicrosecondValue(i64), - #[prost(int64, tag = "2")] - TimeNanosecondValue(i64), - #[prost(int64, tag = "3")] - TimeSecondValue(i64), - #[prost(int64, tag = "4")] - TimeMillisecondValue(i64), - } -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ScalarDictionaryValue { - #[prost(message, optional, tag = "1")] - pub index_type: ::core::option::Option, - #[prost(message, optional, boxed, tag = "2")] - pub value: ::core::option::Option<::prost::alloc::boxed::Box>, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct IntervalMonthDayNanoValue { - #[prost(int32, tag = "1")] - pub months: i32, - #[prost(int32, tag = "2")] - pub days: i32, - #[prost(int64, tag = "3")] - pub nanos: i64, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct UnionField { - #[prost(int32, tag = "1")] - pub field_id: i32, - #[prost(message, optional, tag = "2")] - pub field: ::core::option::Option, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct UnionValue { - /// Note that a null union value must have one or more fields, so we - /// encode a null UnionValue as one with value_id == 128 - #[prost(int32, tag = "1")] - pub value_id: i32, - #[prost(message, optional, boxed, tag = "2")] - pub value: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(message, repeated, tag = "3")] - pub fields: ::prost::alloc::vec::Vec, - #[prost(enumeration = "UnionMode", tag = "4")] - pub mode: i32, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ScalarFixedSizeBinary { - #[prost(bytes = "vec", tag = "1")] - pub values: ::prost::alloc::vec::Vec, - #[prost(int32, tag = "2")] - pub length: i32, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ScalarValue { - #[prost( - oneof = "scalar_value::Value", - tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 34, 42" - )] - pub value: ::core::option::Option, -} -/// Nested message and enum types in `ScalarValue`. -pub mod scalar_value { - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum Value { - /// was PrimitiveScalarType null_value = 19; - /// Null value of any type - #[prost(message, tag = "33")] - NullValue(super::ArrowType), - #[prost(bool, tag = "1")] - BoolValue(bool), - #[prost(string, tag = "2")] - Utf8Value(::prost::alloc::string::String), - #[prost(string, tag = "3")] - LargeUtf8Value(::prost::alloc::string::String), - #[prost(int32, tag = "4")] - Int8Value(i32), - #[prost(int32, tag = "5")] - Int16Value(i32), - #[prost(int32, tag = "6")] - Int32Value(i32), - #[prost(int64, tag = "7")] - Int64Value(i64), - #[prost(uint32, tag = "8")] - Uint8Value(u32), - #[prost(uint32, tag = "9")] - Uint16Value(u32), - #[prost(uint32, tag = "10")] - Uint32Value(u32), - #[prost(uint64, tag = "11")] - Uint64Value(u64), - #[prost(float, tag = "12")] - Float32Value(f32), - #[prost(double, tag = "13")] - Float64Value(f64), - /// Literal Date32 value always has a unit of day - #[prost(int32, tag = "14")] - Date32Value(i32), - #[prost(message, tag = "15")] - Time32Value(super::ScalarTime32Value), - #[prost(message, tag = "16")] - LargeListValue(super::ScalarNestedValue), - #[prost(message, tag = "17")] - ListValue(super::ScalarNestedValue), - #[prost(message, tag = "18")] - FixedSizeListValue(super::ScalarNestedValue), - #[prost(message, tag = "32")] - StructValue(super::ScalarNestedValue), - #[prost(message, tag = "20")] - Decimal128Value(super::Decimal128), - #[prost(message, tag = "39")] - Decimal256Value(super::Decimal256), - #[prost(int64, tag = "21")] - Date64Value(i64), - #[prost(int32, tag = "24")] - IntervalYearmonthValue(i32), - #[prost(int64, tag = "25")] - IntervalDaytimeValue(i64), - #[prost(int64, tag = "35")] - DurationSecondValue(i64), - #[prost(int64, tag = "36")] - DurationMillisecondValue(i64), - #[prost(int64, tag = "37")] - DurationMicrosecondValue(i64), - #[prost(int64, tag = "38")] - DurationNanosecondValue(i64), - #[prost(message, tag = "26")] - TimestampValue(super::ScalarTimestampValue), - #[prost(message, tag = "27")] - DictionaryValue(::prost::alloc::boxed::Box), - #[prost(bytes, tag = "28")] - BinaryValue(::prost::alloc::vec::Vec), - #[prost(bytes, tag = "29")] - LargeBinaryValue(::prost::alloc::vec::Vec), - #[prost(message, tag = "30")] - Time64Value(super::ScalarTime64Value), - #[prost(message, tag = "31")] - IntervalMonthDayNano(super::IntervalMonthDayNanoValue), - #[prost(message, tag = "34")] - FixedSizeBinaryValue(super::ScalarFixedSizeBinary), - #[prost(message, tag = "42")] - UnionValue(::prost::alloc::boxed::Box), - } -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Decimal128 { - #[prost(bytes = "vec", tag = "1")] - pub value: ::prost::alloc::vec::Vec, - #[prost(int64, tag = "2")] - pub p: i64, - #[prost(int64, tag = "3")] - pub s: i64, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Decimal256 { - #[prost(bytes = "vec", tag = "1")] - pub value: ::prost::alloc::vec::Vec, - #[prost(int64, tag = "2")] - pub p: i64, - #[prost(int64, tag = "3")] - pub s: i64, -} -/// Serialized data type -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ArrowType { - #[prost( - oneof = "arrow_type::ArrowTypeEnum", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 32, 15, 16, 31, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 33" - )] - pub arrow_type_enum: ::core::option::Option, -} -/// Nested message and enum types in `ArrowType`. -pub mod arrow_type { - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum ArrowTypeEnum { - /// arrow::Type::NA - #[prost(message, tag = "1")] - None(super::EmptyMessage), - /// arrow::Type::BOOL - #[prost(message, tag = "2")] - Bool(super::EmptyMessage), - /// arrow::Type::UINT8 - #[prost(message, tag = "3")] - Uint8(super::EmptyMessage), - /// arrow::Type::INT8 - #[prost(message, tag = "4")] - Int8(super::EmptyMessage), - /// represents arrow::Type fields in src/arrow/type.h - #[prost(message, tag = "5")] - Uint16(super::EmptyMessage), - #[prost(message, tag = "6")] - Int16(super::EmptyMessage), - #[prost(message, tag = "7")] - Uint32(super::EmptyMessage), - #[prost(message, tag = "8")] - Int32(super::EmptyMessage), - #[prost(message, tag = "9")] - Uint64(super::EmptyMessage), - #[prost(message, tag = "10")] - Int64(super::EmptyMessage), - #[prost(message, tag = "11")] - Float16(super::EmptyMessage), - #[prost(message, tag = "12")] - Float32(super::EmptyMessage), - #[prost(message, tag = "13")] - Float64(super::EmptyMessage), - #[prost(message, tag = "14")] - Utf8(super::EmptyMessage), - #[prost(message, tag = "32")] - LargeUtf8(super::EmptyMessage), - #[prost(message, tag = "15")] - Binary(super::EmptyMessage), - #[prost(int32, tag = "16")] - FixedSizeBinary(i32), - #[prost(message, tag = "31")] - LargeBinary(super::EmptyMessage), - #[prost(message, tag = "17")] - Date32(super::EmptyMessage), - #[prost(message, tag = "18")] - Date64(super::EmptyMessage), - #[prost(enumeration = "super::TimeUnit", tag = "19")] - Duration(i32), - #[prost(message, tag = "20")] - Timestamp(super::Timestamp), - #[prost(enumeration = "super::TimeUnit", tag = "21")] - Time32(i32), - #[prost(enumeration = "super::TimeUnit", tag = "22")] - Time64(i32), - #[prost(enumeration = "super::IntervalUnit", tag = "23")] - Interval(i32), - #[prost(message, tag = "24")] - Decimal(super::Decimal), - #[prost(message, tag = "25")] - List(::prost::alloc::boxed::Box), - #[prost(message, tag = "26")] - LargeList(::prost::alloc::boxed::Box), - #[prost(message, tag = "27")] - FixedSizeList(::prost::alloc::boxed::Box), - #[prost(message, tag = "28")] - Struct(super::Struct), - #[prost(message, tag = "29")] - Union(super::Union), - #[prost(message, tag = "30")] - Dictionary(::prost::alloc::boxed::Box), - #[prost(message, tag = "33")] - Map(::prost::alloc::boxed::Box), - } -} -/// Useful for representing an empty enum variant in rust -/// E.G. enum example{One, Two(i32)} -/// maps to -/// message example{ -/// oneof{ -/// EmptyMessage One = 1; -/// i32 Two = 2; -/// } -/// } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct EmptyMessage {} -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct AnalyzedLogicalPlanType { #[prost(string, tag = "1")] pub analyzer_name: ::prost::alloc::string::String, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct OptimizedLogicalPlanType { #[prost(string, tag = "1")] pub optimizer_name: ::prost::alloc::string::String, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct OptimizedPhysicalPlanType { #[prost(string, tag = "1")] pub optimizer_name: ::prost::alloc::string::String, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PlanType { - #[prost(oneof = "plan_type::PlanTypeEnum", tags = "1, 7, 8, 2, 3, 4, 9, 5, 6, 10")] + #[prost( + oneof = "plan_type::PlanTypeEnum", + tags = "1, 7, 8, 2, 3, 4, 9, 11, 5, 6, 10, 12" + )] pub plan_type_enum: ::core::option::Option, } /// Nested message and enum types in `PlanType`. pub mod plan_type { - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum PlanTypeEnum { #[prost(message, tag = "1")] - InitialLogicalPlan(super::EmptyMessage), + InitialLogicalPlan(super::super::datafusion_common::EmptyMessage), #[prost(message, tag = "7")] AnalyzedLogicalPlan(super::AnalyzedLogicalPlanType), #[prost(message, tag = "8")] - FinalAnalyzedLogicalPlan(super::EmptyMessage), + FinalAnalyzedLogicalPlan(super::super::datafusion_common::EmptyMessage), #[prost(message, tag = "2")] OptimizedLogicalPlan(super::OptimizedLogicalPlanType), #[prost(message, tag = "3")] - FinalLogicalPlan(super::EmptyMessage), + FinalLogicalPlan(super::super::datafusion_common::EmptyMessage), #[prost(message, tag = "4")] - InitialPhysicalPlan(super::EmptyMessage), + InitialPhysicalPlan(super::super::datafusion_common::EmptyMessage), #[prost(message, tag = "9")] - InitialPhysicalPlanWithStats(super::EmptyMessage), + InitialPhysicalPlanWithStats(super::super::datafusion_common::EmptyMessage), + #[prost(message, tag = "11")] + InitialPhysicalPlanWithSchema(super::super::datafusion_common::EmptyMessage), #[prost(message, tag = "5")] OptimizedPhysicalPlan(super::OptimizedPhysicalPlanType), #[prost(message, tag = "6")] - FinalPhysicalPlan(super::EmptyMessage), + FinalPhysicalPlan(super::super::datafusion_common::EmptyMessage), #[prost(message, tag = "10")] - FinalPhysicalPlanWithStats(super::EmptyMessage), + FinalPhysicalPlanWithStats(super::super::datafusion_common::EmptyMessage), + #[prost(message, tag = "12")] + FinalPhysicalPlanWithSchema(super::super::datafusion_common::EmptyMessage), } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct StringifiedPlan { #[prost(message, optional, tag = "1")] @@ -1516,13 +929,11 @@ pub struct StringifiedPlan { #[prost(string, tag = "2")] pub plan: ::prost::alloc::string::String, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct BareTableReference { #[prost(string, tag = "1")] pub table: ::prost::alloc::string::String, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PartialTableReference { #[prost(string, tag = "1")] @@ -1530,7 +941,6 @@ pub struct PartialTableReference { #[prost(string, tag = "2")] pub table: ::prost::alloc::string::String, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FullTableReference { #[prost(string, tag = "1")] @@ -1540,7 +950,6 @@ pub struct FullTableReference { #[prost(string, tag = "3")] pub table: ::prost::alloc::string::String, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct TableReference { #[prost(oneof = "table_reference::TableReferenceEnum", tags = "1, 2, 3")] @@ -1550,7 +959,6 @@ pub struct TableReference { } /// Nested message and enum types in `TableReference`. pub mod table_reference { - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum TableReferenceEnum { #[prost(message, tag = "1")] @@ -1562,18 +970,16 @@ pub mod table_reference { } } /// PhysicalPlanNode is a nested type -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalPlanNode { #[prost( oneof = "physical_plan_node::PhysicalPlanType", - tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29" + tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30" )] pub physical_plan_type: ::core::option::Option, } /// Nested message and enum types in `PhysicalPlanNode`. pub mod physical_plan_node { - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum PhysicalPlanType { #[prost(message, tag = "1")] @@ -1634,103 +1040,17 @@ pub mod physical_plan_node { CsvSink(::prost::alloc::boxed::Box), #[prost(message, tag = "29")] ParquetSink(::prost::alloc::boxed::Box), + #[prost(message, tag = "30")] + Unnest(::prost::alloc::boxed::Box), } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PartitionColumn { #[prost(string, tag = "1")] pub name: ::prost::alloc::string::String, #[prost(message, optional, tag = "2")] - pub arrow_type: ::core::option::Option, + pub arrow_type: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct JsonWriterOptions { - #[prost(enumeration = "CompressionTypeVariant", tag = "1")] - pub compression: i32, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct CsvWriterOptions { - /// Compression type - #[prost(enumeration = "CompressionTypeVariant", tag = "1")] - pub compression: i32, - /// Optional column delimiter. Defaults to `b','` - #[prost(string, tag = "2")] - pub delimiter: ::prost::alloc::string::String, - /// Whether to write column names as file headers. Defaults to `true` - #[prost(bool, tag = "3")] - pub has_header: bool, - /// Optional date format for date arrays - #[prost(string, tag = "4")] - pub date_format: ::prost::alloc::string::String, - /// Optional datetime format for datetime arrays - #[prost(string, tag = "5")] - pub datetime_format: ::prost::alloc::string::String, - /// Optional timestamp format for timestamp arrays - #[prost(string, tag = "6")] - pub timestamp_format: ::prost::alloc::string::String, - /// Optional time format for time arrays - #[prost(string, tag = "7")] - pub time_format: ::prost::alloc::string::String, - /// Optional value to represent null - #[prost(string, tag = "8")] - pub null_value: ::prost::alloc::string::String, -} -/// Options controlling CSV format -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct CsvOptions { - /// Indicates if the CSV has a header row - #[prost(bool, tag = "1")] - pub has_header: bool, - /// Delimiter character as a byte - #[prost(bytes = "vec", tag = "2")] - pub delimiter: ::prost::alloc::vec::Vec, - /// Quote character as a byte - #[prost(bytes = "vec", tag = "3")] - pub quote: ::prost::alloc::vec::Vec, - /// Optional escape character as a byte - #[prost(bytes = "vec", tag = "4")] - pub escape: ::prost::alloc::vec::Vec, - /// Compression type - #[prost(enumeration = "CompressionTypeVariant", tag = "5")] - pub compression: i32, - /// Max records for schema inference - #[prost(uint64, tag = "6")] - pub schema_infer_max_rec: u64, - /// Optional date format - #[prost(string, tag = "7")] - pub date_format: ::prost::alloc::string::String, - /// Optional datetime format - #[prost(string, tag = "8")] - pub datetime_format: ::prost::alloc::string::String, - /// Optional timestamp format - #[prost(string, tag = "9")] - pub timestamp_format: ::prost::alloc::string::String, - /// Optional timestamp with timezone format - #[prost(string, tag = "10")] - pub timestamp_tz_format: ::prost::alloc::string::String, - /// Optional time format - #[prost(string, tag = "11")] - pub time_format: ::prost::alloc::string::String, - /// Optional representation of null value - #[prost(string, tag = "12")] - pub null_value: ::prost::alloc::string::String, -} -/// Options controlling CSV format -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct JsonOptions { - /// Compression type - #[prost(enumeration = "CompressionTypeVariant", tag = "1")] - pub compression: i32, - /// Max records for schema inference - #[prost(uint64, tag = "2")] - pub schema_infer_max_rec: u64, -} -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FileSinkConfig { #[prost(string, tag = "1")] @@ -1738,293 +1058,65 @@ pub struct FileSinkConfig { #[prost(message, repeated, tag = "2")] pub file_groups: ::prost::alloc::vec::Vec, #[prost(string, repeated, tag = "3")] - pub table_paths: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, - #[prost(message, optional, tag = "4")] - pub output_schema: ::core::option::Option, - #[prost(message, repeated, tag = "5")] - pub table_partition_cols: ::prost::alloc::vec::Vec, - #[prost(bool, tag = "8")] - pub overwrite: bool, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct JsonSink { - #[prost(message, optional, tag = "1")] - pub config: ::core::option::Option, - #[prost(message, optional, tag = "2")] - pub writer_options: ::core::option::Option, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct JsonSinkExecNode { - #[prost(message, optional, boxed, tag = "1")] - pub input: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(message, optional, tag = "2")] - pub sink: ::core::option::Option, - #[prost(message, optional, tag = "3")] - pub sink_schema: ::core::option::Option, - #[prost(message, optional, tag = "4")] - pub sort_order: ::core::option::Option, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct CsvSink { - #[prost(message, optional, tag = "1")] - pub config: ::core::option::Option, - #[prost(message, optional, tag = "2")] - pub writer_options: ::core::option::Option, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct CsvSinkExecNode { - #[prost(message, optional, boxed, tag = "1")] - pub input: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(message, optional, tag = "2")] - pub sink: ::core::option::Option, - #[prost(message, optional, tag = "3")] - pub sink_schema: ::core::option::Option, - #[prost(message, optional, tag = "4")] - pub sort_order: ::core::option::Option, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct TableParquetOptions { - #[prost(message, optional, tag = "1")] - pub global: ::core::option::Option, - #[prost(message, repeated, tag = "2")] - pub column_specific_options: ::prost::alloc::vec::Vec, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ColumnSpecificOptions { - #[prost(string, tag = "1")] - pub column_name: ::prost::alloc::string::String, - #[prost(message, optional, tag = "2")] - pub options: ::core::option::Option, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ColumnOptions { - #[prost(oneof = "column_options::BloomFilterEnabledOpt", tags = "1")] - pub bloom_filter_enabled_opt: ::core::option::Option< - column_options::BloomFilterEnabledOpt, - >, - #[prost(oneof = "column_options::EncodingOpt", tags = "2")] - pub encoding_opt: ::core::option::Option, - #[prost(oneof = "column_options::DictionaryEnabledOpt", tags = "3")] - pub dictionary_enabled_opt: ::core::option::Option< - column_options::DictionaryEnabledOpt, - >, - #[prost(oneof = "column_options::CompressionOpt", tags = "4")] - pub compression_opt: ::core::option::Option, - #[prost(oneof = "column_options::StatisticsEnabledOpt", tags = "5")] - pub statistics_enabled_opt: ::core::option::Option< - column_options::StatisticsEnabledOpt, - >, - #[prost(oneof = "column_options::BloomFilterFppOpt", tags = "6")] - pub bloom_filter_fpp_opt: ::core::option::Option, - #[prost(oneof = "column_options::BloomFilterNdvOpt", tags = "7")] - pub bloom_filter_ndv_opt: ::core::option::Option, - #[prost(oneof = "column_options::MaxStatisticsSizeOpt", tags = "8")] - pub max_statistics_size_opt: ::core::option::Option< - column_options::MaxStatisticsSizeOpt, - >, -} -/// Nested message and enum types in `ColumnOptions`. -pub mod column_options { - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum BloomFilterEnabledOpt { - #[prost(bool, tag = "1")] - BloomFilterEnabled(bool), - } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum EncodingOpt { - #[prost(string, tag = "2")] - Encoding(::prost::alloc::string::String), - } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum DictionaryEnabledOpt { - #[prost(bool, tag = "3")] - DictionaryEnabled(bool), - } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum CompressionOpt { - #[prost(string, tag = "4")] - Compression(::prost::alloc::string::String), - } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum StatisticsEnabledOpt { - #[prost(string, tag = "5")] - StatisticsEnabled(::prost::alloc::string::String), - } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum BloomFilterFppOpt { - #[prost(double, tag = "6")] - BloomFilterFpp(f64), - } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum BloomFilterNdvOpt { - #[prost(uint64, tag = "7")] - BloomFilterNdv(u64), - } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum MaxStatisticsSizeOpt { - #[prost(uint32, tag = "8")] - MaxStatisticsSize(u32), - } + pub table_paths: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + #[prost(message, optional, tag = "4")] + pub output_schema: ::core::option::Option, + #[prost(message, repeated, tag = "5")] + pub table_partition_cols: ::prost::alloc::vec::Vec, + #[prost(bool, tag = "9")] + pub keep_partition_by_columns: bool, + #[prost(enumeration = "InsertOp", tag = "10")] + pub insert_op: i32, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] -pub struct ParquetOptions { - /// Regular fields - /// - /// default = true - #[prost(bool, tag = "1")] - pub enable_page_index: bool, - /// default = true - #[prost(bool, tag = "2")] - pub pruning: bool, - /// default = true - #[prost(bool, tag = "3")] - pub skip_metadata: bool, - /// default = false - #[prost(bool, tag = "5")] - pub pushdown_filters: bool, - /// default = false - #[prost(bool, tag = "6")] - pub reorder_filters: bool, - /// default = 1024 * 1024 - #[prost(uint64, tag = "7")] - pub data_pagesize_limit: u64, - /// default = 1024 - #[prost(uint64, tag = "8")] - pub write_batch_size: u64, - /// default = "1.0" - #[prost(string, tag = "9")] - pub writer_version: ::prost::alloc::string::String, - /// default = false - #[prost(bool, tag = "20")] - pub bloom_filter_enabled: bool, - /// default = true - #[prost(bool, tag = "23")] - pub allow_single_file_parallelism: bool, - /// default = 1 - #[prost(uint64, tag = "24")] - pub maximum_parallel_row_group_writers: u64, - /// default = 2 - #[prost(uint64, tag = "25")] - pub maximum_buffered_record_batches_per_stream: u64, - #[prost(uint64, tag = "12")] - pub dictionary_page_size_limit: u64, - #[prost(uint64, tag = "18")] - pub data_page_row_count_limit: u64, - #[prost(uint64, tag = "15")] - pub max_row_group_size: u64, - #[prost(string, tag = "16")] - pub created_by: ::prost::alloc::string::String, - #[prost(oneof = "parquet_options::MetadataSizeHintOpt", tags = "4")] - pub metadata_size_hint_opt: ::core::option::Option< - parquet_options::MetadataSizeHintOpt, - >, - #[prost(oneof = "parquet_options::CompressionOpt", tags = "10")] - pub compression_opt: ::core::option::Option, - #[prost(oneof = "parquet_options::DictionaryEnabledOpt", tags = "11")] - pub dictionary_enabled_opt: ::core::option::Option< - parquet_options::DictionaryEnabledOpt, - >, - #[prost(oneof = "parquet_options::StatisticsEnabledOpt", tags = "13")] - pub statistics_enabled_opt: ::core::option::Option< - parquet_options::StatisticsEnabledOpt, - >, - #[prost(oneof = "parquet_options::MaxStatisticsSizeOpt", tags = "14")] - pub max_statistics_size_opt: ::core::option::Option< - parquet_options::MaxStatisticsSizeOpt, +pub struct JsonSink { + #[prost(message, optional, tag = "1")] + pub config: ::core::option::Option, + #[prost(message, optional, tag = "2")] + pub writer_options: ::core::option::Option< + super::datafusion_common::JsonWriterOptions, >, - #[prost(oneof = "parquet_options::ColumnIndexTruncateLengthOpt", tags = "17")] - pub column_index_truncate_length_opt: ::core::option::Option< - parquet_options::ColumnIndexTruncateLengthOpt, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct JsonSinkExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, tag = "2")] + pub sink: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub sink_schema: ::core::option::Option, + #[prost(message, optional, tag = "4")] + pub sort_order: ::core::option::Option, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CsvSink { + #[prost(message, optional, tag = "1")] + pub config: ::core::option::Option, + #[prost(message, optional, tag = "2")] + pub writer_options: ::core::option::Option< + super::datafusion_common::CsvWriterOptions, >, - #[prost(oneof = "parquet_options::EncodingOpt", tags = "19")] - pub encoding_opt: ::core::option::Option, - #[prost(oneof = "parquet_options::BloomFilterFppOpt", tags = "21")] - pub bloom_filter_fpp_opt: ::core::option::Option, - #[prost(oneof = "parquet_options::BloomFilterNdvOpt", tags = "22")] - pub bloom_filter_ndv_opt: ::core::option::Option, -} -/// Nested message and enum types in `ParquetOptions`. -pub mod parquet_options { - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum MetadataSizeHintOpt { - #[prost(uint64, tag = "4")] - MetadataSizeHint(u64), - } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum CompressionOpt { - #[prost(string, tag = "10")] - Compression(::prost::alloc::string::String), - } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum DictionaryEnabledOpt { - #[prost(bool, tag = "11")] - DictionaryEnabled(bool), - } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum StatisticsEnabledOpt { - #[prost(string, tag = "13")] - StatisticsEnabled(::prost::alloc::string::String), - } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum MaxStatisticsSizeOpt { - #[prost(uint64, tag = "14")] - MaxStatisticsSize(u64), - } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum ColumnIndexTruncateLengthOpt { - #[prost(uint64, tag = "17")] - ColumnIndexTruncateLength(u64), - } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum EncodingOpt { - #[prost(string, tag = "19")] - Encoding(::prost::alloc::string::String), - } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum BloomFilterFppOpt { - #[prost(double, tag = "21")] - BloomFilterFpp(f64), - } - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum BloomFilterNdvOpt { - #[prost(uint64, tag = "22")] - BloomFilterNdv(u64), - } } -#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CsvSinkExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, tag = "2")] + pub sink: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub sink_schema: ::core::option::Option, + #[prost(message, optional, tag = "4")] + pub sort_order: ::core::option::Option, +} #[derive(Clone, PartialEq, ::prost::Message)] pub struct ParquetSink { #[prost(message, optional, tag = "1")] pub config: ::core::option::Option, #[prost(message, optional, tag = "2")] - pub parquet_options: ::core::option::Option, + pub parquet_options: ::core::option::Option< + super::datafusion_common::TableParquetOptions, + >, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ParquetSinkExecNode { #[prost(message, optional, boxed, tag = "1")] @@ -2032,11 +1124,30 @@ pub struct ParquetSinkExecNode { #[prost(message, optional, tag = "2")] pub sink: ::core::option::Option, #[prost(message, optional, tag = "3")] - pub sink_schema: ::core::option::Option, + pub sink_schema: ::core::option::Option, #[prost(message, optional, tag = "4")] pub sort_order: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct UnnestExecNode { + #[prost(message, optional, boxed, tag = "1")] + pub input: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, tag = "2")] + pub schema: ::core::option::Option, + #[prost(message, repeated, tag = "3")] + pub list_type_columns: ::prost::alloc::vec::Vec, + #[prost(uint64, repeated, tag = "4")] + pub struct_type_columns: ::prost::alloc::vec::Vec, + #[prost(message, optional, tag = "5")] + pub options: ::core::option::Option, +} +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct ListUnnest { + #[prost(uint32, tag = "1")] + pub index_in_input_schema: u32, + #[prost(uint32, tag = "2")] + pub depth: u32, +} #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalExtensionNode { #[prost(bytes = "vec", tag = "1")] @@ -2045,25 +1156,23 @@ pub struct PhysicalExtensionNode { pub inputs: ::prost::alloc::vec::Vec, } /// physical expressions -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalExprNode { #[prost( oneof = "physical_expr_node::ExprType", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19" )] pub expr_type: ::core::option::Option, } /// Nested message and enum types in `PhysicalExprNode`. pub mod physical_expr_node { - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum ExprType { /// column references #[prost(message, tag = "1")] Column(super::PhysicalColumn), #[prost(message, tag = "2")] - Literal(super::ScalarValue), + Literal(super::super::datafusion_common::ScalarValue), /// binary expressions #[prost(message, tag = "3")] BinaryExpr(::prost::alloc::boxed::Box), @@ -2087,18 +1196,21 @@ pub mod physical_expr_node { Negative(::prost::alloc::boxed::Box), #[prost(message, tag = "12")] InList(::prost::alloc::boxed::Box), + /// was PhysicalScalarFunctionNode scalar_function = 13; #[prost(message, tag = "14")] TryCast(::prost::alloc::boxed::Box), /// window expressions #[prost(message, tag = "15")] WindowExpr(super::PhysicalWindowExprNode), + /// was PhysicalDateTimeIntervalExprNode date_time_interval_expr = 17; #[prost(message, tag = "16")] ScalarUdf(super::PhysicalScalarUdfNode), #[prost(message, tag = "18")] LikeExpr(::prost::alloc::boxed::Box), + #[prost(message, tag = "19")] + Extension(super::PhysicalExtensionExprNode), } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalScalarUdfNode { #[prost(string, tag = "1")] @@ -2108,9 +1220,8 @@ pub struct PhysicalScalarUdfNode { #[prost(bytes = "vec", optional, tag = "3")] pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, #[prost(message, optional, tag = "4")] - pub return_type: ::core::option::Option, + pub return_type: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalAggregateExprNode { #[prost(message, repeated, tag = "2")] @@ -2119,23 +1230,23 @@ pub struct PhysicalAggregateExprNode { pub ordering_req: ::prost::alloc::vec::Vec, #[prost(bool, tag = "3")] pub distinct: bool, - #[prost(oneof = "physical_aggregate_expr_node::AggregateFunction", tags = "1, 4")] + #[prost(bool, tag = "6")] + pub ignore_nulls: bool, + #[prost(bytes = "vec", optional, tag = "7")] + pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, + #[prost(oneof = "physical_aggregate_expr_node::AggregateFunction", tags = "4")] pub aggregate_function: ::core::option::Option< physical_aggregate_expr_node::AggregateFunction, >, } /// Nested message and enum types in `PhysicalAggregateExprNode`. pub mod physical_aggregate_expr_node { - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum AggregateFunction { - #[prost(enumeration = "super::AggregateFunction", tag = "1")] - AggrFunction(i32), #[prost(string, tag = "4")] UserDefinedAggrFunction(::prost::alloc::string::String), } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalWindowExprNode { #[prost(message, repeated, tag = "4")] @@ -2148,42 +1259,38 @@ pub struct PhysicalWindowExprNode { pub window_frame: ::core::option::Option, #[prost(string, tag = "8")] pub name: ::prost::alloc::string::String, - #[prost(oneof = "physical_window_expr_node::WindowFunction", tags = "1, 2")] + #[prost(bytes = "vec", optional, tag = "9")] + pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, + #[prost(oneof = "physical_window_expr_node::WindowFunction", tags = "2, 3")] pub window_function: ::core::option::Option< physical_window_expr_node::WindowFunction, >, } /// Nested message and enum types in `PhysicalWindowExprNode`. pub mod physical_window_expr_node { - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum WindowFunction { - #[prost(enumeration = "super::AggregateFunction", tag = "1")] - AggrFunction(i32), - /// udaf = 3 #[prost(enumeration = "super::BuiltInWindowFunction", tag = "2")] BuiltInFunction(i32), + #[prost(string, tag = "3")] + UserDefinedAggrFunction(::prost::alloc::string::String), } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalIsNull { #[prost(message, optional, boxed, tag = "1")] pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalIsNotNull { #[prost(message, optional, boxed, tag = "1")] pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalNot { #[prost(message, optional, boxed, tag = "1")] pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalAliasNode { #[prost(message, optional, tag = "1")] @@ -2191,7 +1298,6 @@ pub struct PhysicalAliasNode { #[prost(string, tag = "2")] pub alias: ::prost::alloc::string::String, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalBinaryExprNode { #[prost(message, optional, boxed, tag = "1")] @@ -2201,7 +1307,6 @@ pub struct PhysicalBinaryExprNode { #[prost(string, tag = "3")] pub op: ::prost::alloc::string::String, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalDateTimeIntervalExprNode { #[prost(message, optional, tag = "1")] @@ -2211,7 +1316,6 @@ pub struct PhysicalDateTimeIntervalExprNode { #[prost(string, tag = "3")] pub op: ::prost::alloc::string::String, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalLikeExprNode { #[prost(bool, tag = "1")] @@ -2223,7 +1327,6 @@ pub struct PhysicalLikeExprNode { #[prost(message, optional, boxed, tag = "4")] pub pattern: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalSortExprNode { #[prost(message, optional, boxed, tag = "1")] @@ -2233,7 +1336,6 @@ pub struct PhysicalSortExprNode { #[prost(bool, tag = "3")] pub nulls_first: bool, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalWhenThen { #[prost(message, optional, tag = "1")] @@ -2241,7 +1343,6 @@ pub struct PhysicalWhenThen { #[prost(message, optional, tag = "2")] pub then_expr: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalInListNode { #[prost(message, optional, boxed, tag = "1")] @@ -2251,7 +1352,6 @@ pub struct PhysicalInListNode { #[prost(bool, tag = "3")] pub negated: bool, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalCaseNode { #[prost(message, optional, boxed, tag = "1")] @@ -2261,29 +1361,32 @@ pub struct PhysicalCaseNode { #[prost(message, optional, boxed, tag = "3")] pub else_expr: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalTryCastNode { #[prost(message, optional, boxed, tag = "1")] pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, optional, tag = "2")] - pub arrow_type: ::core::option::Option, + pub arrow_type: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalCastNode { #[prost(message, optional, boxed, tag = "1")] pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, optional, tag = "2")] - pub arrow_type: ::core::option::Option, + pub arrow_type: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalNegativeNode { #[prost(message, optional, boxed, tag = "1")] pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PhysicalExtensionExprNode { + #[prost(bytes = "vec", tag = "1")] + pub expr: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "2")] + pub inputs: ::prost::alloc::vec::Vec, +} #[derive(Clone, PartialEq, ::prost::Message)] pub struct FilterExecNode { #[prost(message, optional, boxed, tag = "1")] @@ -2292,39 +1395,37 @@ pub struct FilterExecNode { pub expr: ::core::option::Option, #[prost(uint32, tag = "3")] pub default_filter_selectivity: u32, + #[prost(uint32, repeated, tag = "9")] + pub projection: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FileGroup { #[prost(message, repeated, tag = "1")] pub files: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct ScanLimit { /// wrap into a message to make it optional #[prost(uint32, tag = "1")] pub limit: u32, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalSortExprNodeCollection { #[prost(message, repeated, tag = "1")] pub physical_sort_expr_nodes: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FileScanExecConf { #[prost(message, repeated, tag = "1")] pub file_groups: ::prost::alloc::vec::Vec, #[prost(message, optional, tag = "2")] - pub schema: ::core::option::Option, + pub schema: ::core::option::Option, #[prost(uint32, repeated, tag = "4")] pub projection: ::prost::alloc::vec::Vec, #[prost(message, optional, tag = "5")] pub limit: ::core::option::Option, #[prost(message, optional, tag = "6")] - pub statistics: ::core::option::Option, + pub statistics: ::core::option::Option, #[prost(string, repeated, tag = "7")] pub table_partition_cols: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, #[prost(string, tag = "8")] @@ -2332,7 +1433,6 @@ pub struct FileScanExecConf { #[prost(message, repeated, tag = "9")] pub output_ordering: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ParquetScanExecNode { #[prost(message, optional, tag = "1")] @@ -2340,7 +1440,6 @@ pub struct ParquetScanExecNode { #[prost(message, optional, tag = "3")] pub predicate: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CsvScanExecNode { #[prost(message, optional, tag = "1")] @@ -2351,25 +1450,31 @@ pub struct CsvScanExecNode { pub delimiter: ::prost::alloc::string::String, #[prost(string, tag = "4")] pub quote: ::prost::alloc::string::String, + #[prost(bool, tag = "7")] + pub newlines_in_values: bool, #[prost(oneof = "csv_scan_exec_node::OptionalEscape", tags = "5")] pub optional_escape: ::core::option::Option, + #[prost(oneof = "csv_scan_exec_node::OptionalComment", tags = "6")] + pub optional_comment: ::core::option::Option, } /// Nested message and enum types in `CsvScanExecNode`. pub mod csv_scan_exec_node { - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum OptionalEscape { #[prost(string, tag = "5")] Escape(::prost::alloc::string::String), } + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum OptionalComment { + #[prost(string, tag = "6")] + Comment(::prost::alloc::string::String), + } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct AvroScanExecNode { #[prost(message, optional, tag = "1")] pub base_conf: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct HashJoinExecNode { #[prost(message, optional, boxed, tag = "1")] @@ -2378,7 +1483,7 @@ pub struct HashJoinExecNode { pub right: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, repeated, tag = "3")] pub on: ::prost::alloc::vec::Vec, - #[prost(enumeration = "JoinType", tag = "4")] + #[prost(enumeration = "super::datafusion_common::JoinType", tag = "4")] pub join_type: i32, #[prost(enumeration = "PartitionMode", tag = "6")] pub partition_mode: i32, @@ -2389,7 +1494,6 @@ pub struct HashJoinExecNode { #[prost(uint32, repeated, tag = "9")] pub projection: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct SymmetricHashJoinExecNode { #[prost(message, optional, boxed, tag = "1")] @@ -2398,7 +1502,7 @@ pub struct SymmetricHashJoinExecNode { pub right: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, repeated, tag = "3")] pub on: ::prost::alloc::vec::Vec, - #[prost(enumeration = "JoinType", tag = "4")] + #[prost(enumeration = "super::datafusion_common::JoinType", tag = "4")] pub join_type: i32, #[prost(enumeration = "StreamPartitionMode", tag = "6")] pub partition_mode: i32, @@ -2411,29 +1515,25 @@ pub struct SymmetricHashJoinExecNode { #[prost(message, repeated, tag = "10")] pub right_sort_exprs: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct InterleaveExecNode { #[prost(message, repeated, tag = "1")] pub inputs: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct UnionExecNode { #[prost(message, repeated, tag = "1")] pub inputs: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ExplainExecNode { #[prost(message, optional, tag = "1")] - pub schema: ::core::option::Option, + pub schema: ::core::option::Option, #[prost(message, repeated, tag = "2")] pub stringified_plans: ::prost::alloc::vec::Vec, #[prost(bool, tag = "3")] pub verbose: bool, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct AnalyzeExecNode { #[prost(bool, tag = "1")] @@ -2443,9 +1543,8 @@ pub struct AnalyzeExecNode { #[prost(message, optional, boxed, tag = "3")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, optional, tag = "4")] - pub schema: ::core::option::Option, + pub schema: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CrossJoinExecNode { #[prost(message, optional, boxed, tag = "1")] @@ -2453,7 +1552,6 @@ pub struct CrossJoinExecNode { #[prost(message, optional, boxed, tag = "2")] pub right: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalColumn { #[prost(string, tag = "1")] @@ -2461,7 +1559,6 @@ pub struct PhysicalColumn { #[prost(uint32, tag = "2")] pub index: u32, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct JoinOn { #[prost(message, optional, tag = "1")] @@ -2469,19 +1566,16 @@ pub struct JoinOn { #[prost(message, optional, tag = "2")] pub right: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct EmptyExecNode { #[prost(message, optional, tag = "1")] - pub schema: ::core::option::Option, + pub schema: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PlaceholderRowExecNode { #[prost(message, optional, tag = "1")] - pub schema: ::core::option::Option, + pub schema: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ProjectionExecNode { #[prost(message, optional, boxed, tag = "1")] @@ -2491,13 +1585,11 @@ pub struct ProjectionExecNode { #[prost(string, repeated, tag = "3")] pub expr_name: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PartiallySortedInputOrderMode { #[prost(uint64, repeated, tag = "6")] pub columns: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct WindowAggExecNode { #[prost(message, optional, boxed, tag = "1")] @@ -2513,30 +1605,32 @@ pub struct WindowAggExecNode { /// Nested message and enum types in `WindowAggExecNode`. pub mod window_agg_exec_node { /// Set optional to `None` for `BoundedWindowAggExec`. - #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum InputOrderMode { #[prost(message, tag = "7")] - Linear(super::EmptyMessage), + Linear(super::super::datafusion_common::EmptyMessage), #[prost(message, tag = "8")] PartiallySorted(super::PartiallySortedInputOrderMode), #[prost(message, tag = "9")] - Sorted(super::EmptyMessage), + Sorted(super::super::datafusion_common::EmptyMessage), } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct MaybeFilter { #[prost(message, optional, tag = "1")] pub expr: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct MaybePhysicalSortExprs { #[prost(message, repeated, tag = "1")] pub sort_expr: ::prost::alloc::vec::Vec, } -#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct AggLimit { + /// wrap into a message to make it optional + #[prost(uint64, tag = "1")] + pub limit: u64, +} #[derive(Clone, PartialEq, ::prost::Message)] pub struct AggregateExecNode { #[prost(message, repeated, tag = "1")] @@ -2553,15 +1647,16 @@ pub struct AggregateExecNode { pub aggr_expr_name: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, /// we need the input schema to the partial aggregate to pass to the final aggregate #[prost(message, optional, tag = "7")] - pub input_schema: ::core::option::Option, + pub input_schema: ::core::option::Option, #[prost(message, repeated, tag = "8")] pub null_expr: ::prost::alloc::vec::Vec, #[prost(bool, repeated, tag = "9")] pub groups: ::prost::alloc::vec::Vec, #[prost(message, repeated, tag = "10")] pub filter_expr: ::prost::alloc::vec::Vec, + #[prost(message, optional, tag = "11")] + pub limit: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct GlobalLimitExecNode { #[prost(message, optional, boxed, tag = "1")] @@ -2573,7 +1668,6 @@ pub struct GlobalLimitExecNode { #[prost(int64, tag = "3")] pub fetch: i64, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct LocalLimitExecNode { #[prost(message, optional, boxed, tag = "1")] @@ -2581,7 +1675,6 @@ pub struct LocalLimitExecNode { #[prost(uint32, tag = "2")] pub fetch: u32, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct SortExecNode { #[prost(message, optional, boxed, tag = "1")] @@ -2594,7 +1687,6 @@ pub struct SortExecNode { #[prost(bool, tag = "4")] pub preserve_partitioning: bool, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct SortPreservingMergeExecNode { #[prost(message, optional, boxed, tag = "1")] @@ -2605,33 +1697,31 @@ pub struct SortPreservingMergeExecNode { #[prost(int64, tag = "3")] pub fetch: i64, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct NestedLoopJoinExecNode { #[prost(message, optional, boxed, tag = "1")] pub left: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, optional, boxed, tag = "2")] pub right: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(enumeration = "JoinType", tag = "3")] + #[prost(enumeration = "super::datafusion_common::JoinType", tag = "3")] pub join_type: i32, #[prost(message, optional, tag = "4")] pub filter: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CoalesceBatchesExecNode { #[prost(message, optional, boxed, tag = "1")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(uint32, tag = "2")] pub target_batch_size: u32, + #[prost(uint32, optional, tag = "3")] + pub fetch: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct CoalescePartitionsExecNode { #[prost(message, optional, boxed, tag = "1")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalHashRepartition { #[prost(message, repeated, tag = "1")] @@ -2639,28 +1729,35 @@ pub struct PhysicalHashRepartition { #[prost(uint64, tag = "2")] pub partition_count: u64, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct RepartitionExecNode { #[prost(message, optional, boxed, tag = "1")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(oneof = "repartition_exec_node::PartitionMethod", tags = "2, 3, 4")] - pub partition_method: ::core::option::Option, + /// oneof partition_method { + /// uint64 round_robin = 2; + /// PhysicalHashRepartition hash = 3; + /// uint64 unknown = 4; + /// } + #[prost(message, optional, tag = "5")] + pub partitioning: ::core::option::Option, } -/// Nested message and enum types in `RepartitionExecNode`. -pub mod repartition_exec_node { - #[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Partitioning { + #[prost(oneof = "partitioning::PartitionMethod", tags = "1, 2, 3")] + pub partition_method: ::core::option::Option, +} +/// Nested message and enum types in `Partitioning`. +pub mod partitioning { #[derive(Clone, PartialEq, ::prost::Oneof)] pub enum PartitionMethod { - #[prost(uint64, tag = "2")] + #[prost(uint64, tag = "1")] RoundRobin(u64), - #[prost(message, tag = "3")] + #[prost(message, tag = "2")] Hash(super::PhysicalHashRepartition), - #[prost(uint64, tag = "4")] + #[prost(uint64, tag = "3")] Unknown(u64), } } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct JoinFilter { #[prost(message, optional, tag = "1")] @@ -2668,17 +1765,15 @@ pub struct JoinFilter { #[prost(message, repeated, tag = "2")] pub column_indices: ::prost::alloc::vec::Vec, #[prost(message, optional, tag = "3")] - pub schema: ::core::option::Option, + pub schema: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct ColumnIndex { #[prost(uint32, tag = "1")] pub index: u32, - #[prost(enumeration = "JoinSide", tag = "2")] + #[prost(enumeration = "super::datafusion_common::JoinSide", tag = "2")] pub side: i32, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PartitionedFile { #[prost(string, tag = "1")] @@ -2688,19 +1783,21 @@ pub struct PartitionedFile { #[prost(uint64, tag = "3")] pub last_modified_ns: u64, #[prost(message, repeated, tag = "4")] - pub partition_values: ::prost::alloc::vec::Vec, + pub partition_values: ::prost::alloc::vec::Vec< + super::datafusion_common::ScalarValue, + >, #[prost(message, optional, tag = "5")] pub range: ::core::option::Option, + #[prost(message, optional, tag = "6")] + pub statistics: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct FileRange { #[prost(int64, tag = "1")] pub start: i64, #[prost(int64, tag = "2")] pub end: i64, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct PartitionStats { #[prost(int64, tag = "1")] @@ -2710,256 +1807,21 @@ pub struct PartitionStats { #[prost(int64, tag = "3")] pub num_bytes: i64, #[prost(message, repeated, tag = "4")] - pub column_stats: ::prost::alloc::vec::Vec, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Precision { - #[prost(enumeration = "PrecisionInfo", tag = "1")] - pub precision_info: i32, - #[prost(message, optional, tag = "2")] - pub val: ::core::option::Option, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct Statistics { - #[prost(message, optional, tag = "1")] - pub num_rows: ::core::option::Option, - #[prost(message, optional, tag = "2")] - pub total_byte_size: ::core::option::Option, - #[prost(message, repeated, tag = "3")] - pub column_stats: ::prost::alloc::vec::Vec, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ColumnStats { - #[prost(message, optional, tag = "1")] - pub min_value: ::core::option::Option, - #[prost(message, optional, tag = "2")] - pub max_value: ::core::option::Option, - #[prost(message, optional, tag = "3")] - pub null_count: ::core::option::Option, - #[prost(message, optional, tag = "4")] - pub distinct_count: ::core::option::Option, -} -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] -pub enum JoinType { - Inner = 0, - Left = 1, - Right = 2, - Full = 3, - Leftsemi = 4, - Leftanti = 5, - Rightsemi = 6, - Rightanti = 7, -} -impl JoinType { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic use. - pub fn as_str_name(&self) -> &'static str { - match self { - JoinType::Inner => "INNER", - JoinType::Left => "LEFT", - JoinType::Right => "RIGHT", - JoinType::Full => "FULL", - JoinType::Leftsemi => "LEFTSEMI", - JoinType::Leftanti => "LEFTANTI", - JoinType::Rightsemi => "RIGHTSEMI", - JoinType::Rightanti => "RIGHTANTI", - } - } - /// Creates an enum from field names used in the ProtoBuf definition. - pub fn from_str_name(value: &str) -> ::core::option::Option { - match value { - "INNER" => Some(Self::Inner), - "LEFT" => Some(Self::Left), - "RIGHT" => Some(Self::Right), - "FULL" => Some(Self::Full), - "LEFTSEMI" => Some(Self::Leftsemi), - "LEFTANTI" => Some(Self::Leftanti), - "RIGHTSEMI" => Some(Self::Rightsemi), - "RIGHTANTI" => Some(Self::Rightanti), - _ => None, - } - } -} -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] -pub enum JoinConstraint { - On = 0, - Using = 1, -} -impl JoinConstraint { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic use. - pub fn as_str_name(&self) -> &'static str { - match self { - JoinConstraint::On => "ON", - JoinConstraint::Using => "USING", - } - } - /// Creates an enum from field names used in the ProtoBuf definition. - pub fn from_str_name(value: &str) -> ::core::option::Option { - match value { - "ON" => Some(Self::On), - "USING" => Some(Self::Using), - _ => None, - } - } -} -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] -pub enum AggregateFunction { - Min = 0, - Max = 1, - Sum = 2, - Avg = 3, - Count = 4, - ApproxDistinct = 5, - ArrayAgg = 6, - Variance = 7, - VariancePop = 8, - Covariance = 9, - CovariancePop = 10, - Stddev = 11, - StddevPop = 12, - Correlation = 13, - ApproxPercentileCont = 14, - ApproxMedian = 15, - ApproxPercentileContWithWeight = 16, - Grouping = 17, - Median = 18, - BitAnd = 19, - BitOr = 20, - BitXor = 21, - BoolAnd = 22, - BoolOr = 23, - /// When a function with the same name exists among built-in window functions, - /// we append "_AGG" to obey name scoping rules. - FirstValueAgg = 24, - LastValueAgg = 25, - RegrSlope = 26, - RegrIntercept = 27, - RegrCount = 28, - RegrR2 = 29, - RegrAvgx = 30, - RegrAvgy = 31, - RegrSxx = 32, - RegrSyy = 33, - RegrSxy = 34, - StringAgg = 35, - NthValueAgg = 36, -} -impl AggregateFunction { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic use. - pub fn as_str_name(&self) -> &'static str { - match self { - AggregateFunction::Min => "MIN", - AggregateFunction::Max => "MAX", - AggregateFunction::Sum => "SUM", - AggregateFunction::Avg => "AVG", - AggregateFunction::Count => "COUNT", - AggregateFunction::ApproxDistinct => "APPROX_DISTINCT", - AggregateFunction::ArrayAgg => "ARRAY_AGG", - AggregateFunction::Variance => "VARIANCE", - AggregateFunction::VariancePop => "VARIANCE_POP", - AggregateFunction::Covariance => "COVARIANCE", - AggregateFunction::CovariancePop => "COVARIANCE_POP", - AggregateFunction::Stddev => "STDDEV", - AggregateFunction::StddevPop => "STDDEV_POP", - AggregateFunction::Correlation => "CORRELATION", - AggregateFunction::ApproxPercentileCont => "APPROX_PERCENTILE_CONT", - AggregateFunction::ApproxMedian => "APPROX_MEDIAN", - AggregateFunction::ApproxPercentileContWithWeight => { - "APPROX_PERCENTILE_CONT_WITH_WEIGHT" - } - AggregateFunction::Grouping => "GROUPING", - AggregateFunction::Median => "MEDIAN", - AggregateFunction::BitAnd => "BIT_AND", - AggregateFunction::BitOr => "BIT_OR", - AggregateFunction::BitXor => "BIT_XOR", - AggregateFunction::BoolAnd => "BOOL_AND", - AggregateFunction::BoolOr => "BOOL_OR", - AggregateFunction::FirstValueAgg => "FIRST_VALUE_AGG", - AggregateFunction::LastValueAgg => "LAST_VALUE_AGG", - AggregateFunction::RegrSlope => "REGR_SLOPE", - AggregateFunction::RegrIntercept => "REGR_INTERCEPT", - AggregateFunction::RegrCount => "REGR_COUNT", - AggregateFunction::RegrR2 => "REGR_R2", - AggregateFunction::RegrAvgx => "REGR_AVGX", - AggregateFunction::RegrAvgy => "REGR_AVGY", - AggregateFunction::RegrSxx => "REGR_SXX", - AggregateFunction::RegrSyy => "REGR_SYY", - AggregateFunction::RegrSxy => "REGR_SXY", - AggregateFunction::StringAgg => "STRING_AGG", - AggregateFunction::NthValueAgg => "NTH_VALUE_AGG", - } - } - /// Creates an enum from field names used in the ProtoBuf definition. - pub fn from_str_name(value: &str) -> ::core::option::Option { - match value { - "MIN" => Some(Self::Min), - "MAX" => Some(Self::Max), - "SUM" => Some(Self::Sum), - "AVG" => Some(Self::Avg), - "COUNT" => Some(Self::Count), - "APPROX_DISTINCT" => Some(Self::ApproxDistinct), - "ARRAY_AGG" => Some(Self::ArrayAgg), - "VARIANCE" => Some(Self::Variance), - "VARIANCE_POP" => Some(Self::VariancePop), - "COVARIANCE" => Some(Self::Covariance), - "COVARIANCE_POP" => Some(Self::CovariancePop), - "STDDEV" => Some(Self::Stddev), - "STDDEV_POP" => Some(Self::StddevPop), - "CORRELATION" => Some(Self::Correlation), - "APPROX_PERCENTILE_CONT" => Some(Self::ApproxPercentileCont), - "APPROX_MEDIAN" => Some(Self::ApproxMedian), - "APPROX_PERCENTILE_CONT_WITH_WEIGHT" => { - Some(Self::ApproxPercentileContWithWeight) - } - "GROUPING" => Some(Self::Grouping), - "MEDIAN" => Some(Self::Median), - "BIT_AND" => Some(Self::BitAnd), - "BIT_OR" => Some(Self::BitOr), - "BIT_XOR" => Some(Self::BitXor), - "BOOL_AND" => Some(Self::BoolAnd), - "BOOL_OR" => Some(Self::BoolOr), - "FIRST_VALUE_AGG" => Some(Self::FirstValueAgg), - "LAST_VALUE_AGG" => Some(Self::LastValueAgg), - "REGR_SLOPE" => Some(Self::RegrSlope), - "REGR_INTERCEPT" => Some(Self::RegrIntercept), - "REGR_COUNT" => Some(Self::RegrCount), - "REGR_R2" => Some(Self::RegrR2), - "REGR_AVGX" => Some(Self::RegrAvgx), - "REGR_AVGY" => Some(Self::RegrAvgy), - "REGR_SXX" => Some(Self::RegrSxx), - "REGR_SYY" => Some(Self::RegrSyy), - "REGR_SXY" => Some(Self::RegrSxy), - "STRING_AGG" => Some(Self::StringAgg), - "NTH_VALUE_AGG" => Some(Self::NthValueAgg), - _ => None, - } - } + pub column_stats: ::prost::alloc::vec::Vec, } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum BuiltInWindowFunction { - RowNumber = 0, - Rank = 1, - DenseRank = 2, - PercentRank = 3, - CumeDist = 4, - Ntile = 5, - Lag = 6, - Lead = 7, + /// + Unspecified = 0, + /// ROW_NUMBER = 0; + /// RANK = 1; + /// DENSE_RANK = 2; + /// PERCENT_RANK = 3; + /// CUME_DIST = 4; + /// NTILE = 5; + /// LAG = 6; + /// LEAD = 7; FirstValue = 8, LastValue = 9, NthValue = 10, @@ -2971,30 +1833,16 @@ impl BuiltInWindowFunction { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - BuiltInWindowFunction::RowNumber => "ROW_NUMBER", - BuiltInWindowFunction::Rank => "RANK", - BuiltInWindowFunction::DenseRank => "DENSE_RANK", - BuiltInWindowFunction::PercentRank => "PERCENT_RANK", - BuiltInWindowFunction::CumeDist => "CUME_DIST", - BuiltInWindowFunction::Ntile => "NTILE", - BuiltInWindowFunction::Lag => "LAG", - BuiltInWindowFunction::Lead => "LEAD", - BuiltInWindowFunction::FirstValue => "FIRST_VALUE", - BuiltInWindowFunction::LastValue => "LAST_VALUE", - BuiltInWindowFunction::NthValue => "NTH_VALUE", + Self::Unspecified => "UNSPECIFIED", + Self::FirstValue => "FIRST_VALUE", + Self::LastValue => "LAST_VALUE", + Self::NthValue => "NTH_VALUE", } } /// Creates an enum from field names used in the ProtoBuf definition. pub fn from_str_name(value: &str) -> ::core::option::Option { match value { - "ROW_NUMBER" => Some(Self::RowNumber), - "RANK" => Some(Self::Rank), - "DENSE_RANK" => Some(Self::DenseRank), - "PERCENT_RANK" => Some(Self::PercentRank), - "CUME_DIST" => Some(Self::CumeDist), - "NTILE" => Some(Self::Ntile), - "LAG" => Some(Self::Lag), - "LEAD" => Some(Self::Lead), + "UNSPECIFIED" => Some(Self::Unspecified), "FIRST_VALUE" => Some(Self::FirstValue), "LAST_VALUE" => Some(Self::LastValue), "NTH_VALUE" => Some(Self::NthValue), @@ -3016,9 +1864,9 @@ impl WindowFrameUnits { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - WindowFrameUnits::Rows => "ROWS", - WindowFrameUnits::Range => "RANGE", - WindowFrameUnits::Groups => "GROUPS", + Self::Rows => "ROWS", + Self::Range => "RANGE", + Self::Groups => "GROUPS", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -3045,9 +1893,9 @@ impl WindowFrameBoundType { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - WindowFrameBoundType::CurrentRow => "CURRENT_ROW", - WindowFrameBoundType::Preceding => "PRECEDING", - WindowFrameBoundType::Following => "FOLLOWING", + Self::CurrentRow => "CURRENT_ROW", + Self::Preceding => "PRECEDING", + Self::Following => "FOLLOWING", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -3073,8 +1921,8 @@ impl DateUnit { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - DateUnit::Day => "Day", - DateUnit::DateMillisecond => "DateMillisecond", + Self::Day => "Day", + Self::DateMillisecond => "DateMillisecond", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -3088,122 +1936,29 @@ impl DateUnit { } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] -pub enum TimeUnit { - Second = 0, - Millisecond = 1, - Microsecond = 2, - Nanosecond = 3, -} -impl TimeUnit { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic use. - pub fn as_str_name(&self) -> &'static str { - match self { - TimeUnit::Second => "Second", - TimeUnit::Millisecond => "Millisecond", - TimeUnit::Microsecond => "Microsecond", - TimeUnit::Nanosecond => "Nanosecond", - } - } - /// Creates an enum from field names used in the ProtoBuf definition. - pub fn from_str_name(value: &str) -> ::core::option::Option { - match value { - "Second" => Some(Self::Second), - "Millisecond" => Some(Self::Millisecond), - "Microsecond" => Some(Self::Microsecond), - "Nanosecond" => Some(Self::Nanosecond), - _ => None, - } - } -} -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] -pub enum IntervalUnit { - YearMonth = 0, - DayTime = 1, - MonthDayNano = 2, -} -impl IntervalUnit { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic use. - pub fn as_str_name(&self) -> &'static str { - match self { - IntervalUnit::YearMonth => "YearMonth", - IntervalUnit::DayTime => "DayTime", - IntervalUnit::MonthDayNano => "MonthDayNano", - } - } - /// Creates an enum from field names used in the ProtoBuf definition. - pub fn from_str_name(value: &str) -> ::core::option::Option { - match value { - "YearMonth" => Some(Self::YearMonth), - "DayTime" => Some(Self::DayTime), - "MonthDayNano" => Some(Self::MonthDayNano), - _ => None, - } - } -} -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] -pub enum UnionMode { - Sparse = 0, - Dense = 1, -} -impl UnionMode { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic use. - pub fn as_str_name(&self) -> &'static str { - match self { - UnionMode::Sparse => "sparse", - UnionMode::Dense => "dense", - } - } - /// Creates an enum from field names used in the ProtoBuf definition. - pub fn from_str_name(value: &str) -> ::core::option::Option { - match value { - "sparse" => Some(Self::Sparse), - "dense" => Some(Self::Dense), - _ => None, - } - } +pub enum InsertOp { + Append = 0, + Overwrite = 1, + Replace = 2, } -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] -pub enum CompressionTypeVariant { - Gzip = 0, - Bzip2 = 1, - Xz = 2, - Zstd = 3, - Uncompressed = 4, -} -impl CompressionTypeVariant { +impl InsertOp { /// String value of the enum field names used in the ProtoBuf definition. /// /// The values are not transformed in any way and thus are considered stable /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - CompressionTypeVariant::Gzip => "GZIP", - CompressionTypeVariant::Bzip2 => "BZIP2", - CompressionTypeVariant::Xz => "XZ", - CompressionTypeVariant::Zstd => "ZSTD", - CompressionTypeVariant::Uncompressed => "UNCOMPRESSED", + Self::Append => "Append", + Self::Overwrite => "Overwrite", + Self::Replace => "Replace", } } /// Creates an enum from field names used in the ProtoBuf definition. pub fn from_str_name(value: &str) -> ::core::option::Option { match value { - "GZIP" => Some(Self::Gzip), - "BZIP2" => Some(Self::Bzip2), - "XZ" => Some(Self::Xz), - "ZSTD" => Some(Self::Zstd), - "UNCOMPRESSED" => Some(Self::Uncompressed), + "Append" => Some(Self::Append), + "Overwrite" => Some(Self::Overwrite), + "Replace" => Some(Self::Replace), _ => None, } } @@ -3222,9 +1977,9 @@ impl PartitionMode { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - PartitionMode::CollectLeft => "COLLECT_LEFT", - PartitionMode::Partitioned => "PARTITIONED", - PartitionMode::Auto => "AUTO", + Self::CollectLeft => "COLLECT_LEFT", + Self::Partitioned => "PARTITIONED", + Self::Auto => "AUTO", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -3250,8 +2005,8 @@ impl StreamPartitionMode { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - StreamPartitionMode::SinglePartition => "SINGLE_PARTITION", - StreamPartitionMode::PartitionedExec => "PARTITIONED_EXEC", + Self::SinglePartition => "SINGLE_PARTITION", + Self::PartitionedExec => "PARTITIONED_EXEC", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -3279,11 +2034,11 @@ impl AggregateMode { /// (if the ProtoBuf definition does not change) and safe for programmatic use. pub fn as_str_name(&self) -> &'static str { match self { - AggregateMode::Partial => "PARTIAL", - AggregateMode::Final => "FINAL", - AggregateMode::FinalPartitioned => "FINAL_PARTITIONED", - AggregateMode::Single => "SINGLE", - AggregateMode::SinglePartitioned => "SINGLE_PARTITIONED", + Self::Partial => "PARTIAL", + Self::Final => "FINAL", + Self::FinalPartitioned => "FINAL_PARTITIONED", + Self::Single => "SINGLE", + Self::SinglePartitioned => "SINGLE_PARTITIONED", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -3298,58 +2053,3 @@ impl AggregateMode { } } } -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] -pub enum JoinSide { - LeftSide = 0, - RightSide = 1, -} -impl JoinSide { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic use. - pub fn as_str_name(&self) -> &'static str { - match self { - JoinSide::LeftSide => "LEFT_SIDE", - JoinSide::RightSide => "RIGHT_SIDE", - } - } - /// Creates an enum from field names used in the ProtoBuf definition. - pub fn from_str_name(value: &str) -> ::core::option::Option { - match value { - "LEFT_SIDE" => Some(Self::LeftSide), - "RIGHT_SIDE" => Some(Self::RightSide), - _ => None, - } - } -} -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] -#[repr(i32)] -pub enum PrecisionInfo { - Exact = 0, - Inexact = 1, - Absent = 2, -} -impl PrecisionInfo { - /// String value of the enum field names used in the ProtoBuf definition. - /// - /// The values are not transformed in any way and thus are considered stable - /// (if the ProtoBuf definition does not change) and safe for programmatic use. - pub fn as_str_name(&self) -> &'static str { - match self { - PrecisionInfo::Exact => "EXACT", - PrecisionInfo::Inexact => "INEXACT", - PrecisionInfo::Absent => "ABSENT", - } - } - /// Creates an enum from field names used in the ProtoBuf definition. - pub fn from_str_name(value: &str) -> ::core::option::Option { - match value { - "EXACT" => Some(Self::Exact), - "INEXACT" => Some(Self::Inexact), - "ABSENT" => Some(Self::Absent), - _ => None, - } - } -} diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index 5d60b9b57454..e7019553f53d 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -14,6 +14,8 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] //! Serialize / Deserialize DataFusion Plans to bytes //! @@ -117,7 +119,15 @@ pub mod generated; pub mod logical_plan; pub mod physical_plan; -pub use generated::datafusion as protobuf; +pub mod protobuf { + pub use crate::generated::datafusion::*; + pub use datafusion_proto_common::common::proto_error; + pub use datafusion_proto_common::protobuf_common::{ + ArrowOptions, ArrowType, AvroFormat, AvroOptions, CsvFormat, DfSchema, + EmptyMessage, Field, JoinSide, NdJsonFormat, ParquetFormat, ScalarValue, Schema, + }; + pub use datafusion_proto_common::{FromProtoError, ToProtoError}; +} #[cfg(doctest)] doc_comment::doctest!("../README.md", readme_example_test); diff --git a/datafusion/proto/src/logical_plan/file_formats.rs b/datafusion/proto/src/logical_plan/file_formats.rs new file mode 100644 index 000000000000..02be3e11c1cb --- /dev/null +++ b/datafusion/proto/src/logical_plan/file_formats.rs @@ -0,0 +1,766 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use datafusion::{ + config::{ + CsvOptions, JsonOptions, ParquetColumnOptions, ParquetOptions, + TableParquetOptions, + }, + datasource::file_format::{ + arrow::ArrowFormatFactory, csv::CsvFormatFactory, json::JsonFormatFactory, + parquet::ParquetFormatFactory, FileFormatFactory, + }, + prelude::SessionContext, +}; +use datafusion_common::{ + exec_err, not_impl_err, parsers::CompressionTypeVariant, DataFusionError, + TableReference, +}; +use prost::Message; + +use crate::protobuf::{ + parquet_column_options, parquet_options, CsvOptions as CsvOptionsProto, + JsonOptions as JsonOptionsProto, ParquetColumnOptions as ParquetColumnOptionsProto, + ParquetColumnSpecificOptions, ParquetOptions as ParquetOptionsProto, + TableParquetOptions as TableParquetOptionsProto, +}; + +use super::LogicalExtensionCodec; + +#[derive(Debug)] +pub struct CsvLogicalExtensionCodec; + +impl CsvOptionsProto { + fn from_factory(factory: &CsvFormatFactory) -> Self { + if let Some(options) = &factory.options { + CsvOptionsProto { + has_header: options.has_header.map_or(vec![], |v| vec![v as u8]), + delimiter: vec![options.delimiter], + quote: vec![options.quote], + terminator: options.terminator.map_or(vec![], |v| vec![v]), + escape: options.escape.map_or(vec![], |v| vec![v]), + double_quote: options.double_quote.map_or(vec![], |v| vec![v as u8]), + compression: options.compression as i32, + schema_infer_max_rec: options.schema_infer_max_rec as u64, + date_format: options.date_format.clone().unwrap_or_default(), + datetime_format: options.datetime_format.clone().unwrap_or_default(), + timestamp_format: options.timestamp_format.clone().unwrap_or_default(), + timestamp_tz_format: options + .timestamp_tz_format + .clone() + .unwrap_or_default(), + time_format: options.time_format.clone().unwrap_or_default(), + null_value: options.null_value.clone().unwrap_or_default(), + comment: options.comment.map_or(vec![], |v| vec![v]), + newlines_in_values: options + .newlines_in_values + .map_or(vec![], |v| vec![v as u8]), + } + } else { + CsvOptionsProto::default() + } + } +} + +impl From<&CsvOptionsProto> for CsvOptions { + fn from(proto: &CsvOptionsProto) -> Self { + CsvOptions { + has_header: if !proto.has_header.is_empty() { + Some(proto.has_header[0] != 0) + } else { + None + }, + delimiter: proto.delimiter.first().copied().unwrap_or(b','), + quote: proto.quote.first().copied().unwrap_or(b'"'), + terminator: if !proto.terminator.is_empty() { + Some(proto.terminator[0]) + } else { + None + }, + escape: if !proto.escape.is_empty() { + Some(proto.escape[0]) + } else { + None + }, + double_quote: if !proto.double_quote.is_empty() { + Some(proto.double_quote[0] != 0) + } else { + None + }, + compression: match proto.compression { + 0 => CompressionTypeVariant::GZIP, + 1 => CompressionTypeVariant::BZIP2, + 2 => CompressionTypeVariant::XZ, + 3 => CompressionTypeVariant::ZSTD, + _ => CompressionTypeVariant::UNCOMPRESSED, + }, + schema_infer_max_rec: proto.schema_infer_max_rec as usize, + date_format: if proto.date_format.is_empty() { + None + } else { + Some(proto.date_format.clone()) + }, + datetime_format: if proto.datetime_format.is_empty() { + None + } else { + Some(proto.datetime_format.clone()) + }, + timestamp_format: if proto.timestamp_format.is_empty() { + None + } else { + Some(proto.timestamp_format.clone()) + }, + timestamp_tz_format: if proto.timestamp_tz_format.is_empty() { + None + } else { + Some(proto.timestamp_tz_format.clone()) + }, + time_format: if proto.time_format.is_empty() { + None + } else { + Some(proto.time_format.clone()) + }, + null_value: if proto.null_value.is_empty() { + None + } else { + Some(proto.null_value.clone()) + }, + comment: if !proto.comment.is_empty() { + Some(proto.comment[0]) + } else { + None + }, + newlines_in_values: if proto.newlines_in_values.is_empty() { + None + } else { + Some(proto.newlines_in_values[0] != 0) + }, + } + } +} + +// TODO! This is a placeholder for now and needs to be implemented for real. +impl LogicalExtensionCodec for CsvLogicalExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[datafusion_expr::LogicalPlan], + _ctx: &SessionContext, + ) -> datafusion_common::Result { + not_impl_err!("Method not implemented") + } + + fn try_encode( + &self, + _node: &datafusion_expr::Extension, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_table_provider( + &self, + _buf: &[u8], + _table_ref: &TableReference, + _schema: arrow::datatypes::SchemaRef, + _ctx: &SessionContext, + ) -> datafusion_common::Result> { + not_impl_err!("Method not implemented") + } + + fn try_encode_table_provider( + &self, + _table_ref: &TableReference, + _node: Arc, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_file_format( + &self, + buf: &[u8], + _ctx: &SessionContext, + ) -> datafusion_common::Result> { + let proto = CsvOptionsProto::decode(buf).map_err(|e| { + DataFusionError::Execution(format!( + "Failed to decode CsvOptionsProto: {:?}", + e + )) + })?; + let options: CsvOptions = (&proto).into(); + Ok(Arc::new(CsvFormatFactory { + options: Some(options), + })) + } + + fn try_encode_file_format( + &self, + buf: &mut Vec, + node: Arc, + ) -> datafusion_common::Result<()> { + let options = + if let Some(csv_factory) = node.as_any().downcast_ref::() { + csv_factory.options.clone().unwrap_or_default() + } else { + return exec_err!("{}", "Unsupported FileFormatFactory type".to_string()); + }; + + let proto = CsvOptionsProto::from_factory(&CsvFormatFactory { + options: Some(options), + }); + + proto.encode(buf).map_err(|e| { + DataFusionError::Execution(format!("Failed to encode CsvOptions: {:?}", e)) + })?; + + Ok(()) + } +} + +impl JsonOptionsProto { + fn from_factory(factory: &JsonFormatFactory) -> Self { + if let Some(options) = &factory.options { + JsonOptionsProto { + compression: options.compression as i32, + schema_infer_max_rec: options.schema_infer_max_rec as u64, + } + } else { + JsonOptionsProto::default() + } + } +} + +impl From<&JsonOptionsProto> for JsonOptions { + fn from(proto: &JsonOptionsProto) -> Self { + JsonOptions { + compression: match proto.compression { + 0 => CompressionTypeVariant::GZIP, + 1 => CompressionTypeVariant::BZIP2, + 2 => CompressionTypeVariant::XZ, + 3 => CompressionTypeVariant::ZSTD, + _ => CompressionTypeVariant::UNCOMPRESSED, + }, + schema_infer_max_rec: proto.schema_infer_max_rec as usize, + } + } +} + +#[derive(Debug)] +pub struct JsonLogicalExtensionCodec; + +// TODO! This is a placeholder for now and needs to be implemented for real. +impl LogicalExtensionCodec for JsonLogicalExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[datafusion_expr::LogicalPlan], + _ctx: &SessionContext, + ) -> datafusion_common::Result { + not_impl_err!("Method not implemented") + } + + fn try_encode( + &self, + _node: &datafusion_expr::Extension, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_table_provider( + &self, + _buf: &[u8], + _table_ref: &TableReference, + _schema: arrow::datatypes::SchemaRef, + _ctx: &SessionContext, + ) -> datafusion_common::Result> { + not_impl_err!("Method not implemented") + } + + fn try_encode_table_provider( + &self, + _table_ref: &TableReference, + _node: Arc, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_file_format( + &self, + buf: &[u8], + _ctx: &SessionContext, + ) -> datafusion_common::Result> { + let proto = JsonOptionsProto::decode(buf).map_err(|e| { + DataFusionError::Execution(format!( + "Failed to decode JsonOptionsProto: {:?}", + e + )) + })?; + let options: JsonOptions = (&proto).into(); + Ok(Arc::new(JsonFormatFactory { + options: Some(options), + })) + } + + fn try_encode_file_format( + &self, + buf: &mut Vec, + node: Arc, + ) -> datafusion_common::Result<()> { + let options = if let Some(json_factory) = + node.as_any().downcast_ref::() + { + json_factory.options.clone().unwrap_or_default() + } else { + return Err(DataFusionError::Execution( + "Unsupported FileFormatFactory type".to_string(), + )); + }; + + let proto = JsonOptionsProto::from_factory(&JsonFormatFactory { + options: Some(options), + }); + + proto.encode(buf).map_err(|e| { + DataFusionError::Execution(format!("Failed to encode JsonOptions: {:?}", e)) + })?; + + Ok(()) + } +} + +impl TableParquetOptionsProto { + fn from_factory(factory: &ParquetFormatFactory) -> Self { + let global_options = if let Some(ref options) = factory.options { + options.clone() + } else { + return TableParquetOptionsProto::default(); + }; + + let column_specific_options = global_options.column_specific_options; + TableParquetOptionsProto { + global: Some(ParquetOptionsProto { + enable_page_index: global_options.global.enable_page_index, + pruning: global_options.global.pruning, + skip_metadata: global_options.global.skip_metadata, + metadata_size_hint_opt: global_options.global.metadata_size_hint.map(|size| { + parquet_options::MetadataSizeHintOpt::MetadataSizeHint(size as u64) + }), + pushdown_filters: global_options.global.pushdown_filters, + reorder_filters: global_options.global.reorder_filters, + data_pagesize_limit: global_options.global.data_pagesize_limit as u64, + write_batch_size: global_options.global.write_batch_size as u64, + writer_version: global_options.global.writer_version.clone(), + compression_opt: global_options.global.compression.map(|compression| { + parquet_options::CompressionOpt::Compression(compression) + }), + dictionary_enabled_opt: global_options.global.dictionary_enabled.map(|enabled| { + parquet_options::DictionaryEnabledOpt::DictionaryEnabled(enabled) + }), + dictionary_page_size_limit: global_options.global.dictionary_page_size_limit as u64, + statistics_enabled_opt: global_options.global.statistics_enabled.map(|enabled| { + parquet_options::StatisticsEnabledOpt::StatisticsEnabled(enabled) + }), + max_statistics_size_opt: global_options.global.max_statistics_size.map(|size| { + parquet_options::MaxStatisticsSizeOpt::MaxStatisticsSize(size as u64) + }), + max_row_group_size: global_options.global.max_row_group_size as u64, + created_by: global_options.global.created_by.clone(), + column_index_truncate_length_opt: global_options.global.column_index_truncate_length.map(|length| { + parquet_options::ColumnIndexTruncateLengthOpt::ColumnIndexTruncateLength(length as u64) + }), + data_page_row_count_limit: global_options.global.data_page_row_count_limit as u64, + encoding_opt: global_options.global.encoding.map(|encoding| { + parquet_options::EncodingOpt::Encoding(encoding) + }), + bloom_filter_on_read: global_options.global.bloom_filter_on_read, + bloom_filter_on_write: global_options.global.bloom_filter_on_write, + bloom_filter_fpp_opt: global_options.global.bloom_filter_fpp.map(|fpp| { + parquet_options::BloomFilterFppOpt::BloomFilterFpp(fpp) + }), + bloom_filter_ndv_opt: global_options.global.bloom_filter_ndv.map(|ndv| { + parquet_options::BloomFilterNdvOpt::BloomFilterNdv(ndv) + }), + allow_single_file_parallelism: global_options.global.allow_single_file_parallelism, + maximum_parallel_row_group_writers: global_options.global.maximum_parallel_row_group_writers as u64, + maximum_buffered_record_batches_per_stream: global_options.global.maximum_buffered_record_batches_per_stream as u64, + schema_force_view_types: global_options.global.schema_force_view_types, + binary_as_string: global_options.global.binary_as_string, + }), + column_specific_options: column_specific_options.into_iter().map(|(column_name, options)| { + ParquetColumnSpecificOptions { + column_name, + options: Some(ParquetColumnOptionsProto { + bloom_filter_enabled_opt: options.bloom_filter_enabled.map(|enabled| { + parquet_column_options::BloomFilterEnabledOpt::BloomFilterEnabled(enabled) + }), + encoding_opt: options.encoding.map(|encoding| { + parquet_column_options::EncodingOpt::Encoding(encoding) + }), + dictionary_enabled_opt: options.dictionary_enabled.map(|enabled| { + parquet_column_options::DictionaryEnabledOpt::DictionaryEnabled(enabled) + }), + compression_opt: options.compression.map(|compression| { + parquet_column_options::CompressionOpt::Compression(compression) + }), + statistics_enabled_opt: options.statistics_enabled.map(|enabled| { + parquet_column_options::StatisticsEnabledOpt::StatisticsEnabled(enabled) + }), + bloom_filter_fpp_opt: options.bloom_filter_fpp.map(|fpp| { + parquet_column_options::BloomFilterFppOpt::BloomFilterFpp(fpp) + }), + bloom_filter_ndv_opt: options.bloom_filter_ndv.map(|ndv| { + parquet_column_options::BloomFilterNdvOpt::BloomFilterNdv(ndv) + }), + max_statistics_size_opt: options.max_statistics_size.map(|size| { + parquet_column_options::MaxStatisticsSizeOpt::MaxStatisticsSize(size as u32) + }), + }) + } + }).collect(), + key_value_metadata: global_options.key_value_metadata + .iter() + .filter_map(|(key, value)| { + value.as_ref().map(|v| (key.clone(), v.clone())) + }) + .collect(), + } + } +} + +impl From<&ParquetOptionsProto> for ParquetOptions { + fn from(proto: &ParquetOptionsProto) -> Self { + ParquetOptions { + enable_page_index: proto.enable_page_index, + pruning: proto.pruning, + skip_metadata: proto.skip_metadata, + metadata_size_hint: proto.metadata_size_hint_opt.as_ref().map(|opt| match opt { + parquet_options::MetadataSizeHintOpt::MetadataSizeHint(size) => *size as usize, + }), + pushdown_filters: proto.pushdown_filters, + reorder_filters: proto.reorder_filters, + data_pagesize_limit: proto.data_pagesize_limit as usize, + write_batch_size: proto.write_batch_size as usize, + writer_version: proto.writer_version.clone(), + compression: proto.compression_opt.as_ref().map(|opt| match opt { + parquet_options::CompressionOpt::Compression(compression) => compression.clone(), + }), + dictionary_enabled: proto.dictionary_enabled_opt.as_ref().map(|opt| match opt { + parquet_options::DictionaryEnabledOpt::DictionaryEnabled(enabled) => *enabled, + }), + dictionary_page_size_limit: proto.dictionary_page_size_limit as usize, + statistics_enabled: proto.statistics_enabled_opt.as_ref().map(|opt| match opt { + parquet_options::StatisticsEnabledOpt::StatisticsEnabled(statistics) => statistics.clone(), + }), + max_statistics_size: proto.max_statistics_size_opt.as_ref().map(|opt| match opt { + parquet_options::MaxStatisticsSizeOpt::MaxStatisticsSize(size) => *size as usize, + }), + max_row_group_size: proto.max_row_group_size as usize, + created_by: proto.created_by.clone(), + column_index_truncate_length: proto.column_index_truncate_length_opt.as_ref().map(|opt| match opt { + parquet_options::ColumnIndexTruncateLengthOpt::ColumnIndexTruncateLength(length) => *length as usize, + }), + data_page_row_count_limit: proto.data_page_row_count_limit as usize, + encoding: proto.encoding_opt.as_ref().map(|opt| match opt { + parquet_options::EncodingOpt::Encoding(encoding) => encoding.clone(), + }), + bloom_filter_on_read: proto.bloom_filter_on_read, + bloom_filter_on_write: proto.bloom_filter_on_write, + bloom_filter_fpp: proto.bloom_filter_fpp_opt.as_ref().map(|opt| match opt { + parquet_options::BloomFilterFppOpt::BloomFilterFpp(fpp) => *fpp, + }), + bloom_filter_ndv: proto.bloom_filter_ndv_opt.as_ref().map(|opt| match opt { + parquet_options::BloomFilterNdvOpt::BloomFilterNdv(ndv) => *ndv, + }), + allow_single_file_parallelism: proto.allow_single_file_parallelism, + maximum_parallel_row_group_writers: proto.maximum_parallel_row_group_writers as usize, + maximum_buffered_record_batches_per_stream: proto.maximum_buffered_record_batches_per_stream as usize, + schema_force_view_types: proto.schema_force_view_types, + binary_as_string: proto.binary_as_string, + } + } +} + +impl From for ParquetColumnOptions { + fn from(proto: ParquetColumnOptionsProto) -> Self { + ParquetColumnOptions { + bloom_filter_enabled: proto.bloom_filter_enabled_opt.map( + |parquet_column_options::BloomFilterEnabledOpt::BloomFilterEnabled(v)| v, + ), + encoding: proto + .encoding_opt + .map(|parquet_column_options::EncodingOpt::Encoding(v)| v), + dictionary_enabled: proto.dictionary_enabled_opt.map( + |parquet_column_options::DictionaryEnabledOpt::DictionaryEnabled(v)| v, + ), + compression: proto + .compression_opt + .map(|parquet_column_options::CompressionOpt::Compression(v)| v), + statistics_enabled: proto.statistics_enabled_opt.map( + |parquet_column_options::StatisticsEnabledOpt::StatisticsEnabled(v)| v, + ), + bloom_filter_fpp: proto + .bloom_filter_fpp_opt + .map(|parquet_column_options::BloomFilterFppOpt::BloomFilterFpp(v)| v), + bloom_filter_ndv: proto + .bloom_filter_ndv_opt + .map(|parquet_column_options::BloomFilterNdvOpt::BloomFilterNdv(v)| v), + max_statistics_size: proto.max_statistics_size_opt.map( + |parquet_column_options::MaxStatisticsSizeOpt::MaxStatisticsSize(v)| { + v as usize + }, + ), + } + } +} + +impl From<&TableParquetOptionsProto> for TableParquetOptions { + fn from(proto: &TableParquetOptionsProto) -> Self { + TableParquetOptions { + global: proto + .global + .as_ref() + .map(ParquetOptions::from) + .unwrap_or_default(), + column_specific_options: proto + .column_specific_options + .iter() + .map(|parquet_column_options| { + ( + parquet_column_options.column_name.clone(), + ParquetColumnOptions::from( + parquet_column_options.options.clone().unwrap_or_default(), + ), + ) + }) + .collect(), + key_value_metadata: proto + .key_value_metadata + .iter() + .map(|(k, v)| (k.clone(), Some(v.clone()))) + .collect(), + } + } +} + +#[derive(Debug)] +pub struct ParquetLogicalExtensionCodec; + +// TODO! This is a placeholder for now and needs to be implemented for real. +impl LogicalExtensionCodec for ParquetLogicalExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[datafusion_expr::LogicalPlan], + _ctx: &SessionContext, + ) -> datafusion_common::Result { + not_impl_err!("Method not implemented") + } + + fn try_encode( + &self, + _node: &datafusion_expr::Extension, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_table_provider( + &self, + _buf: &[u8], + _table_ref: &TableReference, + _schema: arrow::datatypes::SchemaRef, + _ctx: &SessionContext, + ) -> datafusion_common::Result> { + not_impl_err!("Method not implemented") + } + + fn try_encode_table_provider( + &self, + _table_ref: &TableReference, + _node: Arc, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_file_format( + &self, + buf: &[u8], + _ctx: &SessionContext, + ) -> datafusion_common::Result> { + let proto = TableParquetOptionsProto::decode(buf).map_err(|e| { + DataFusionError::Execution(format!( + "Failed to decode TableParquetOptionsProto: {:?}", + e + )) + })?; + let options: TableParquetOptions = (&proto).into(); + Ok(Arc::new(ParquetFormatFactory { + options: Some(options), + })) + } + + fn try_encode_file_format( + &self, + buf: &mut Vec, + node: Arc, + ) -> datafusion_common::Result<()> { + let options = if let Some(parquet_factory) = + node.as_any().downcast_ref::() + { + parquet_factory.options.clone().unwrap_or_default() + } else { + return Err(DataFusionError::Execution( + "Unsupported FileFormatFactory type".to_string(), + )); + }; + + let proto = TableParquetOptionsProto::from_factory(&ParquetFormatFactory { + options: Some(options), + }); + + proto.encode(buf).map_err(|e| { + DataFusionError::Execution(format!( + "Failed to encode TableParquetOptionsProto: {:?}", + e + )) + })?; + + Ok(()) + } +} + +#[derive(Debug)] +pub struct ArrowLogicalExtensionCodec; + +// TODO! This is a placeholder for now and needs to be implemented for real. +impl LogicalExtensionCodec for ArrowLogicalExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[datafusion_expr::LogicalPlan], + _ctx: &SessionContext, + ) -> datafusion_common::Result { + not_impl_err!("Method not implemented") + } + + fn try_encode( + &self, + _node: &datafusion_expr::Extension, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_table_provider( + &self, + _buf: &[u8], + _table_ref: &TableReference, + _schema: arrow::datatypes::SchemaRef, + _ctx: &SessionContext, + ) -> datafusion_common::Result> { + not_impl_err!("Method not implemented") + } + + fn try_encode_table_provider( + &self, + _table_ref: &TableReference, + _node: Arc, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_file_format( + &self, + __buf: &[u8], + __ctx: &SessionContext, + ) -> datafusion_common::Result> { + Ok(Arc::new(ArrowFormatFactory::new())) + } + + fn try_encode_file_format( + &self, + __buf: &mut Vec, + __node: Arc, + ) -> datafusion_common::Result<()> { + Ok(()) + } +} + +#[derive(Debug)] +pub struct AvroLogicalExtensionCodec; + +// TODO! This is a placeholder for now and needs to be implemented for real. +impl LogicalExtensionCodec for AvroLogicalExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[datafusion_expr::LogicalPlan], + _ctx: &SessionContext, + ) -> datafusion_common::Result { + not_impl_err!("Method not implemented") + } + + fn try_encode( + &self, + _node: &datafusion_expr::Extension, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_table_provider( + &self, + _buf: &[u8], + _table_ref: &TableReference, + _schema: arrow::datatypes::SchemaRef, + _cts: &SessionContext, + ) -> datafusion_common::Result> { + not_impl_err!("Method not implemented") + } + + fn try_encode_table_provider( + &self, + _table_ref: &TableReference, + _node: Arc, + _buf: &mut Vec, + ) -> datafusion_common::Result<()> { + not_impl_err!("Method not implemented") + } + + fn try_decode_file_format( + &self, + __buf: &[u8], + __ctx: &SessionContext, + ) -> datafusion_common::Result> { + Ok(Arc::new(ArrowFormatFactory::new())) + } + + fn try_encode_file_format( + &self, + __buf: &mut Vec, + __node: Arc, + ) -> datafusion_common::Result<()> { + Ok(()) + } +} diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 83b232da9d21..f25fb0bf2561 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -17,35 +17,27 @@ use std::sync::Arc; -use arrow::{ - array::AsArray, - buffer::Buffer, - datatypes::{ - i256, DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit, - UnionFields, UnionMode, - }, - ipc::{reader::read_record_batch, root_as_message}, -}; - use datafusion::execution::registry::FunctionRegistry; use datafusion_common::{ - arrow_datafusion_err, internal_err, plan_datafusion_err, Column, Constraint, - Constraints, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, - TableReference, + exec_datafusion_err, internal_err, plan_datafusion_err, RecursionUnnestOption, + Result, ScalarValue, TableReference, UnnestOptions, }; -use datafusion_expr::expr::Unnest; -use datafusion_expr::expr::{Alias, Placeholder}; -use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; +use datafusion_expr::expr::{Alias, Placeholder, Sort}; +use datafusion_expr::expr::{Unnest, WildcardOptions}; +use datafusion_expr::ExprFunctionExt; use datafusion_expr::{ - expr::{self, InList, Sort, WindowFunction}, + expr::{self, InList, WindowFunction}, logical_plan::{PlanType, StringifiedPlan}, - AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, Case, Cast, Expr, - GetFieldAccess, GetIndexedField, GroupingSet, + Between, BinaryExpr, BuiltInWindowFunction, Case, Cast, Expr, GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, }; +use datafusion_proto_common::{from_proto::FromOptionalField, FromProtoError as Error}; +use crate::protobuf::plan_type::PlanTypeEnum::{ + FinalPhysicalPlanWithSchema, InitialPhysicalPlanWithSchema, +}; use crate::protobuf::{ self, plan_type::PlanTypeEnum::{ @@ -60,143 +52,23 @@ use crate::protobuf::{ use super::LogicalExtensionCodec; -#[derive(Debug)] -pub enum Error { - General(String), - - DataFusionError(DataFusionError), - - MissingRequiredField(String), - - AtLeastOneValue(String), - - UnknownEnumVariant { name: String, value: i32 }, -} - -impl std::fmt::Display for Error { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - Self::General(desc) => write!(f, "General error: {desc}"), - - Self::DataFusionError(desc) => { - write!(f, "DataFusion error: {desc:?}") - } - - Self::MissingRequiredField(name) => { - write!(f, "Missing required field {name}") - } - Self::AtLeastOneValue(name) => { - write!(f, "Must have at least one {name}, found 0") - } - Self::UnknownEnumVariant { name, value } => { - write!(f, "Unknown i32 value for {name} enum: {value}") - } - } - } -} - -impl std::error::Error for Error {} - -impl From for Error { - fn from(e: DataFusionError) -> Self { - Error::DataFusionError(e) - } -} - -impl Error { - fn required(field: impl Into) -> Error { - Error::MissingRequiredField(field.into()) - } - - fn unknown(name: impl Into, value: i32) -> Error { - Error::UnknownEnumVariant { - name: name.into(), - value, - } - } -} - -/// An extension trait that adds the methods `optional` and `required` to any -/// Option containing a type implementing `TryInto` -pub trait FromOptionalField { - /// Converts an optional protobuf field to an option of a different type - /// - /// Returns None if the option is None, otherwise calls [`TryInto::try_into`] - /// on the contained data, returning any error encountered - fn optional(self) -> Result, Error>; - - /// Converts an optional protobuf field to a different type, returning an error if None - /// - /// Returns `Error::MissingRequiredField` if None, otherwise calls [`TryInto::try_into`] - /// on the contained data, returning any error encountered - fn required(self, field: impl Into) -> Result; -} - -impl FromOptionalField for Option -where - T: TryInto, -{ - fn optional(self) -> Result, Error> { - self.map(|t| t.try_into()).transpose() - } - - fn required(self, field: impl Into) -> Result { - match self { - None => Err(Error::required(field)), - Some(t) => t.try_into(), +impl From<&protobuf::UnnestOptions> for UnnestOptions { + fn from(opts: &protobuf::UnnestOptions) -> Self { + Self { + preserve_nulls: opts.preserve_nulls, + recursions: opts + .recursions + .iter() + .map(|r| RecursionUnnestOption { + input_column: r.input_column.as_ref().unwrap().into(), + output_column: r.output_column.as_ref().unwrap().into(), + depth: r.depth as usize, + }) + .collect::>(), } } } -impl From for Column { - fn from(c: protobuf::Column) -> Self { - let protobuf::Column { relation, name } = c; - - Self::new(relation.map(|r| r.relation), name) - } -} - -impl From<&protobuf::Column> for Column { - fn from(c: &protobuf::Column) -> Self { - c.clone().into() - } -} - -impl TryFrom<&protobuf::DfSchema> for DFSchema { - type Error = Error; - - fn try_from(df_schema: &protobuf::DfSchema) -> Result { - let df_fields = df_schema.columns.clone(); - let qualifiers_and_fields: Vec<(Option, Arc)> = df_fields - .iter() - .map(|df_field| { - let field: Field = df_field.field.as_ref().required("field")?; - Ok(( - df_field - .qualifier - .as_ref() - .map(|q| q.relation.clone().into()), - Arc::new(field), - )) - }) - .collect::, Error>>()?; - - Ok(DFSchema::new_with_metadata( - qualifiers_and_fields, - df_schema.metadata.clone(), - )?) - } -} - -impl TryFrom for DFSchemaRef { - type Error = Error; - - fn try_from(df_schema: protobuf::DfSchema) -> Result { - let dfschema: DFSchema = (&df_schema).try_into()?; - Ok(Arc::new(dfschema)) - } -} - impl From for WindowFrameUnits { fn from(units: protobuf::WindowFrameUnits) -> Self { match units { @@ -233,144 +105,6 @@ impl TryFrom for TableReference { } } -impl TryFrom<&protobuf::ArrowType> for DataType { - type Error = Error; - - fn try_from(arrow_type: &protobuf::ArrowType) -> Result { - arrow_type - .arrow_type_enum - .as_ref() - .required("arrow_type_enum") - } -} - -impl TryFrom<&protobuf::arrow_type::ArrowTypeEnum> for DataType { - type Error = Error; - fn try_from( - arrow_type_enum: &protobuf::arrow_type::ArrowTypeEnum, - ) -> Result { - use protobuf::arrow_type; - Ok(match arrow_type_enum { - arrow_type::ArrowTypeEnum::None(_) => DataType::Null, - arrow_type::ArrowTypeEnum::Bool(_) => DataType::Boolean, - arrow_type::ArrowTypeEnum::Uint8(_) => DataType::UInt8, - arrow_type::ArrowTypeEnum::Int8(_) => DataType::Int8, - arrow_type::ArrowTypeEnum::Uint16(_) => DataType::UInt16, - arrow_type::ArrowTypeEnum::Int16(_) => DataType::Int16, - arrow_type::ArrowTypeEnum::Uint32(_) => DataType::UInt32, - arrow_type::ArrowTypeEnum::Int32(_) => DataType::Int32, - arrow_type::ArrowTypeEnum::Uint64(_) => DataType::UInt64, - arrow_type::ArrowTypeEnum::Int64(_) => DataType::Int64, - arrow_type::ArrowTypeEnum::Float16(_) => DataType::Float16, - arrow_type::ArrowTypeEnum::Float32(_) => DataType::Float32, - arrow_type::ArrowTypeEnum::Float64(_) => DataType::Float64, - arrow_type::ArrowTypeEnum::Utf8(_) => DataType::Utf8, - arrow_type::ArrowTypeEnum::LargeUtf8(_) => DataType::LargeUtf8, - arrow_type::ArrowTypeEnum::Binary(_) => DataType::Binary, - arrow_type::ArrowTypeEnum::FixedSizeBinary(size) => { - DataType::FixedSizeBinary(*size) - } - arrow_type::ArrowTypeEnum::LargeBinary(_) => DataType::LargeBinary, - arrow_type::ArrowTypeEnum::Date32(_) => DataType::Date32, - arrow_type::ArrowTypeEnum::Date64(_) => DataType::Date64, - arrow_type::ArrowTypeEnum::Duration(time_unit) => { - DataType::Duration(parse_i32_to_time_unit(time_unit)?) - } - arrow_type::ArrowTypeEnum::Timestamp(protobuf::Timestamp { - time_unit, - timezone, - }) => DataType::Timestamp( - parse_i32_to_time_unit(time_unit)?, - match timezone.len() { - 0 => None, - _ => Some(timezone.as_str().into()), - }, - ), - arrow_type::ArrowTypeEnum::Time32(time_unit) => { - DataType::Time32(parse_i32_to_time_unit(time_unit)?) - } - arrow_type::ArrowTypeEnum::Time64(time_unit) => { - DataType::Time64(parse_i32_to_time_unit(time_unit)?) - } - arrow_type::ArrowTypeEnum::Interval(interval_unit) => { - DataType::Interval(parse_i32_to_interval_unit(interval_unit)?) - } - arrow_type::ArrowTypeEnum::Decimal(protobuf::Decimal { - precision, - scale, - }) => DataType::Decimal128(*precision as u8, *scale as i8), - arrow_type::ArrowTypeEnum::List(list) => { - let list_type = - list.as_ref().field_type.as_deref().required("field_type")?; - DataType::List(Arc::new(list_type)) - } - arrow_type::ArrowTypeEnum::LargeList(list) => { - let list_type = - list.as_ref().field_type.as_deref().required("field_type")?; - DataType::LargeList(Arc::new(list_type)) - } - arrow_type::ArrowTypeEnum::FixedSizeList(list) => { - let list_type = - list.as_ref().field_type.as_deref().required("field_type")?; - let list_size = list.list_size; - DataType::FixedSizeList(Arc::new(list_type), list_size) - } - arrow_type::ArrowTypeEnum::Struct(strct) => DataType::Struct( - parse_proto_fields_to_fields(&strct.sub_field_types)?.into(), - ), - arrow_type::ArrowTypeEnum::Union(union) => { - let union_mode = protobuf::UnionMode::try_from(union.union_mode) - .map_err(|_| Error::unknown("UnionMode", union.union_mode))?; - let union_mode = match union_mode { - protobuf::UnionMode::Dense => UnionMode::Dense, - protobuf::UnionMode::Sparse => UnionMode::Sparse, - }; - let union_fields = parse_proto_fields_to_fields(&union.union_types)?; - - // Default to index based type ids if not provided - let type_ids: Vec<_> = match union.type_ids.is_empty() { - true => (0..union_fields.len() as i8).collect(), - false => union.type_ids.iter().map(|i| *i as i8).collect(), - }; - - DataType::Union(UnionFields::new(type_ids, union_fields), union_mode) - } - arrow_type::ArrowTypeEnum::Dictionary(dict) => { - let key_datatype = dict.as_ref().key.as_deref().required("key")?; - let value_datatype = dict.as_ref().value.as_deref().required("value")?; - DataType::Dictionary(Box::new(key_datatype), Box::new(value_datatype)) - } - arrow_type::ArrowTypeEnum::Map(map) => { - let field: Field = - map.as_ref().field_type.as_deref().required("field_type")?; - let keys_sorted = map.keys_sorted; - DataType::Map(Arc::new(field), keys_sorted) - } - }) - } -} - -impl TryFrom<&protobuf::Field> for Field { - type Error = Error; - fn try_from(field: &protobuf::Field) -> Result { - let datatype = field.arrow_type.as_deref().required("arrow_type")?; - let field = if field.dict_id != 0 { - Self::new_dict( - field.name.as_str(), - datatype, - field.nullable, - field.dict_id, - field.dict_ordered, - ) - .with_metadata(field.metadata.clone()) - } else { - Self::new(field.name.as_str(), datatype, field.nullable) - .with_metadata(field.metadata.clone()) - }; - Ok(field) - } -} - impl From<&protobuf::StringifiedPlan> for StringifiedPlan { fn from(stringified_plan: &protobuf::StringifiedPlan) -> Self { Self { @@ -398,6 +132,7 @@ impl From<&protobuf::StringifiedPlan> for StringifiedPlan { FinalLogicalPlan(_) => PlanType::FinalLogicalPlan, InitialPhysicalPlan(_) => PlanType::InitialPhysicalPlan, InitialPhysicalPlanWithStats(_) => PlanType::InitialPhysicalPlanWithStats, + InitialPhysicalPlanWithSchema(_) => PlanType::InitialPhysicalPlanWithSchema, OptimizedPhysicalPlan(OptimizedPhysicalPlanType { optimizer_name }) => { PlanType::OptimizedPhysicalPlan { optimizer_name: optimizer_name.clone(), @@ -405,314 +140,24 @@ impl From<&protobuf::StringifiedPlan> for StringifiedPlan { } FinalPhysicalPlan(_) => PlanType::FinalPhysicalPlan, FinalPhysicalPlanWithStats(_) => PlanType::FinalPhysicalPlanWithStats, + FinalPhysicalPlanWithSchema(_) => PlanType::FinalPhysicalPlanWithSchema, }, plan: Arc::new(stringified_plan.plan.clone()), } } } -impl From for AggregateFunction { - fn from(agg_fun: protobuf::AggregateFunction) -> Self { - match agg_fun { - protobuf::AggregateFunction::Min => Self::Min, - protobuf::AggregateFunction::Max => Self::Max, - protobuf::AggregateFunction::Sum => Self::Sum, - protobuf::AggregateFunction::Avg => Self::Avg, - protobuf::AggregateFunction::BitAnd => Self::BitAnd, - protobuf::AggregateFunction::BitOr => Self::BitOr, - protobuf::AggregateFunction::BitXor => Self::BitXor, - protobuf::AggregateFunction::BoolAnd => Self::BoolAnd, - protobuf::AggregateFunction::BoolOr => Self::BoolOr, - protobuf::AggregateFunction::Count => Self::Count, - protobuf::AggregateFunction::ApproxDistinct => Self::ApproxDistinct, - protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg, - protobuf::AggregateFunction::Variance => Self::Variance, - protobuf::AggregateFunction::VariancePop => Self::VariancePop, - protobuf::AggregateFunction::Covariance => Self::Covariance, - protobuf::AggregateFunction::CovariancePop => Self::CovariancePop, - protobuf::AggregateFunction::Stddev => Self::Stddev, - protobuf::AggregateFunction::StddevPop => Self::StddevPop, - protobuf::AggregateFunction::Correlation => Self::Correlation, - protobuf::AggregateFunction::RegrSlope => Self::RegrSlope, - protobuf::AggregateFunction::RegrIntercept => Self::RegrIntercept, - protobuf::AggregateFunction::RegrCount => Self::RegrCount, - protobuf::AggregateFunction::RegrR2 => Self::RegrR2, - protobuf::AggregateFunction::RegrAvgx => Self::RegrAvgx, - protobuf::AggregateFunction::RegrAvgy => Self::RegrAvgy, - protobuf::AggregateFunction::RegrSxx => Self::RegrSXX, - protobuf::AggregateFunction::RegrSyy => Self::RegrSYY, - protobuf::AggregateFunction::RegrSxy => Self::RegrSXY, - protobuf::AggregateFunction::ApproxPercentileCont => { - Self::ApproxPercentileCont - } - protobuf::AggregateFunction::ApproxPercentileContWithWeight => { - Self::ApproxPercentileContWithWeight - } - protobuf::AggregateFunction::ApproxMedian => Self::ApproxMedian, - protobuf::AggregateFunction::Grouping => Self::Grouping, - protobuf::AggregateFunction::Median => Self::Median, - protobuf::AggregateFunction::FirstValueAgg => Self::FirstValue, - protobuf::AggregateFunction::LastValueAgg => Self::LastValue, - protobuf::AggregateFunction::NthValueAgg => Self::NthValue, - protobuf::AggregateFunction::StringAgg => Self::StringAgg, - } - } -} - impl From for BuiltInWindowFunction { fn from(built_in_function: protobuf::BuiltInWindowFunction) -> Self { match built_in_function { - protobuf::BuiltInWindowFunction::RowNumber => Self::RowNumber, - protobuf::BuiltInWindowFunction::Rank => Self::Rank, - protobuf::BuiltInWindowFunction::PercentRank => Self::PercentRank, - protobuf::BuiltInWindowFunction::DenseRank => Self::DenseRank, - protobuf::BuiltInWindowFunction::Lag => Self::Lag, - protobuf::BuiltInWindowFunction::Lead => Self::Lead, + protobuf::BuiltInWindowFunction::Unspecified => todo!(), protobuf::BuiltInWindowFunction::FirstValue => Self::FirstValue, - protobuf::BuiltInWindowFunction::CumeDist => Self::CumeDist, - protobuf::BuiltInWindowFunction::Ntile => Self::Ntile, protobuf::BuiltInWindowFunction::NthValue => Self::NthValue, protobuf::BuiltInWindowFunction::LastValue => Self::LastValue, } } } -impl TryFrom<&protobuf::Schema> for Schema { - type Error = Error; - - fn try_from(schema: &protobuf::Schema) -> Result { - let fields = schema - .columns - .iter() - .map(Field::try_from) - .collect::, _>>()?; - Ok(Self::new_with_metadata(fields, schema.metadata.clone())) - } -} - -impl TryFrom<&protobuf::ScalarValue> for ScalarValue { - type Error = Error; - - fn try_from(scalar: &protobuf::ScalarValue) -> Result { - use protobuf::scalar_value::Value; - - let value = scalar - .value - .as_ref() - .ok_or_else(|| Error::required("value"))?; - - Ok(match value { - Value::BoolValue(v) => Self::Boolean(Some(*v)), - Value::Utf8Value(v) => Self::Utf8(Some(v.to_owned())), - Value::LargeUtf8Value(v) => Self::LargeUtf8(Some(v.to_owned())), - Value::Int8Value(v) => Self::Int8(Some(*v as i8)), - Value::Int16Value(v) => Self::Int16(Some(*v as i16)), - Value::Int32Value(v) => Self::Int32(Some(*v)), - Value::Int64Value(v) => Self::Int64(Some(*v)), - Value::Uint8Value(v) => Self::UInt8(Some(*v as u8)), - Value::Uint16Value(v) => Self::UInt16(Some(*v as u16)), - Value::Uint32Value(v) => Self::UInt32(Some(*v)), - Value::Uint64Value(v) => Self::UInt64(Some(*v)), - Value::Float32Value(v) => Self::Float32(Some(*v)), - Value::Float64Value(v) => Self::Float64(Some(*v)), - Value::Date32Value(v) => Self::Date32(Some(*v)), - // ScalarValue::List is serialized using arrow IPC format - Value::ListValue(v) - | Value::FixedSizeListValue(v) - | Value::LargeListValue(v) - | Value::StructValue(v) => { - let protobuf::ScalarNestedValue { - ipc_message, - arrow_data, - schema, - } = &v; - - let schema: Schema = if let Some(schema_ref) = schema { - schema_ref.try_into()? - } else { - return Err(Error::General( - "Invalid schema while deserializing ScalarValue::List" - .to_string(), - )); - }; - - let message = root_as_message(ipc_message.as_slice()).map_err(|e| { - Error::General(format!( - "Error IPC message while deserializing ScalarValue::List: {e}" - )) - })?; - let buffer = Buffer::from(arrow_data); - - let ipc_batch = message.header_as_record_batch().ok_or_else(|| { - Error::General( - "Unexpected message type deserializing ScalarValue::List" - .to_string(), - ) - })?; - - let record_batch = read_record_batch( - &buffer, - ipc_batch, - Arc::new(schema), - &Default::default(), - None, - &message.version(), - ) - .map_err(|e| arrow_datafusion_err!(e)) - .map_err(|e| e.context("Decoding ScalarValue::List Value"))?; - let arr = record_batch.column(0); - match value { - Value::ListValue(_) => { - Self::List(arr.as_list::().to_owned().into()) - } - Value::LargeListValue(_) => { - Self::LargeList(arr.as_list::().to_owned().into()) - } - Value::FixedSizeListValue(_) => { - Self::FixedSizeList(arr.as_fixed_size_list().to_owned().into()) - } - Value::StructValue(_) => { - Self::Struct(arr.as_struct().to_owned().into()) - } - _ => unreachable!(), - } - } - Value::NullValue(v) => { - let null_type: DataType = v.try_into()?; - null_type.try_into().map_err(Error::DataFusionError)? - } - Value::Decimal128Value(val) => { - let array = vec_to_array(val.value.clone()); - Self::Decimal128( - Some(i128::from_be_bytes(array)), - val.p as u8, - val.s as i8, - ) - } - Value::Decimal256Value(val) => { - let array = vec_to_array(val.value.clone()); - Self::Decimal256( - Some(i256::from_be_bytes(array)), - val.p as u8, - val.s as i8, - ) - } - Value::Date64Value(v) => Self::Date64(Some(*v)), - Value::Time32Value(v) => { - let time_value = - v.value.as_ref().ok_or_else(|| Error::required("value"))?; - match time_value { - protobuf::scalar_time32_value::Value::Time32SecondValue(t) => { - Self::Time32Second(Some(*t)) - } - protobuf::scalar_time32_value::Value::Time32MillisecondValue(t) => { - Self::Time32Millisecond(Some(*t)) - } - } - } - Value::Time64Value(v) => { - let time_value = - v.value.as_ref().ok_or_else(|| Error::required("value"))?; - match time_value { - protobuf::scalar_time64_value::Value::Time64MicrosecondValue(t) => { - Self::Time64Microsecond(Some(*t)) - } - protobuf::scalar_time64_value::Value::Time64NanosecondValue(t) => { - Self::Time64Nanosecond(Some(*t)) - } - } - } - Value::IntervalYearmonthValue(v) => Self::IntervalYearMonth(Some(*v)), - Value::IntervalDaytimeValue(v) => Self::IntervalDayTime(Some(*v)), - Value::DurationSecondValue(v) => Self::DurationSecond(Some(*v)), - Value::DurationMillisecondValue(v) => Self::DurationMillisecond(Some(*v)), - Value::DurationMicrosecondValue(v) => Self::DurationMicrosecond(Some(*v)), - Value::DurationNanosecondValue(v) => Self::DurationNanosecond(Some(*v)), - Value::TimestampValue(v) => { - let timezone = if v.timezone.is_empty() { - None - } else { - Some(v.timezone.as_str().into()) - }; - - let ts_value = - v.value.as_ref().ok_or_else(|| Error::required("value"))?; - - match ts_value { - protobuf::scalar_timestamp_value::Value::TimeMicrosecondValue(t) => { - Self::TimestampMicrosecond(Some(*t), timezone) - } - protobuf::scalar_timestamp_value::Value::TimeNanosecondValue(t) => { - Self::TimestampNanosecond(Some(*t), timezone) - } - protobuf::scalar_timestamp_value::Value::TimeSecondValue(t) => { - Self::TimestampSecond(Some(*t), timezone) - } - protobuf::scalar_timestamp_value::Value::TimeMillisecondValue(t) => { - Self::TimestampMillisecond(Some(*t), timezone) - } - } - } - Value::DictionaryValue(v) => { - let index_type: DataType = v - .index_type - .as_ref() - .ok_or_else(|| Error::required("index_type"))? - .try_into()?; - - let value: Self = v - .value - .as_ref() - .ok_or_else(|| Error::required("value"))? - .as_ref() - .try_into()?; - - Self::Dictionary(Box::new(index_type), Box::new(value)) - } - Value::BinaryValue(v) => Self::Binary(Some(v.clone())), - Value::LargeBinaryValue(v) => Self::LargeBinary(Some(v.clone())), - Value::IntervalMonthDayNano(v) => Self::IntervalMonthDayNano(Some( - IntervalMonthDayNanoType::make_value(v.months, v.days, v.nanos), - )), - Value::UnionValue(val) => { - let mode = match val.mode { - 0 => UnionMode::Sparse, - 1 => UnionMode::Dense, - id => Err(Error::unknown("UnionMode", id))?, - }; - let ids = val - .fields - .iter() - .map(|f| f.field_id as i8) - .collect::>(); - let fields = val - .fields - .iter() - .map(|f| f.field.clone()) - .collect::>>(); - let fields = fields.ok_or_else(|| Error::required("UnionField"))?; - let fields = parse_proto_fields_to_fields(&fields)?; - let fields = UnionFields::new(ids, fields); - let v_id = val.value_id as i8; - let val = match &val.value { - None => None, - Some(val) => { - let val: ScalarValue = val - .as_ref() - .try_into() - .map_err(|_| Error::General("Invalid Scalar".to_string()))?; - Some((v_id, Box::new(val))) - } - }; - Self::Union(val, fields, mode) - } - Value::FixedSizeBinaryValue(v) => { - Self::FixedSizeBinary(v.length, Some(v.clone().values)) - } - }) - } -} - impl TryFrom for WindowFrame { type Error = Error; @@ -757,27 +202,6 @@ impl TryFrom for WindowFrameBound { } } -impl From for TimeUnit { - fn from(time_unit: protobuf::TimeUnit) -> Self { - match time_unit { - protobuf::TimeUnit::Second => TimeUnit::Second, - protobuf::TimeUnit::Millisecond => TimeUnit::Millisecond, - protobuf::TimeUnit::Microsecond => TimeUnit::Microsecond, - protobuf::TimeUnit::Nanosecond => TimeUnit::Nanosecond, - } - } -} - -impl From for IntervalUnit { - fn from(interval_unit: protobuf::IntervalUnit) -> Self { - match interval_unit { - protobuf::IntervalUnit::YearMonth => IntervalUnit::YearMonth, - protobuf::IntervalUnit::DayTime => IntervalUnit::DayTime, - protobuf::IntervalUnit::MonthDayNano => IntervalUnit::MonthDayNano, - } - } -} - impl From for JoinType { fn from(t: protobuf::JoinType) -> Self { match t { @@ -789,6 +213,7 @@ impl From for JoinType { protobuf::JoinType::Rightsemi => JoinType::RightSemi, protobuf::JoinType::Leftanti => JoinType::LeftAnti, protobuf::JoinType::Rightanti => JoinType::RightAnti, + protobuf::JoinType::Leftmark => JoinType::LeftMark, } } } @@ -802,51 +227,6 @@ impl From for JoinConstraint { } } -impl From for Constraints { - fn from(constraints: protobuf::Constraints) -> Self { - Constraints::new_unverified( - constraints - .constraints - .into_iter() - .map(|item| item.into()) - .collect(), - ) - } -} - -impl From for Constraint { - fn from(value: protobuf::Constraint) -> Self { - match value.constraint_mode.unwrap() { - protobuf::constraint::ConstraintMode::PrimaryKey(elem) => { - Constraint::PrimaryKey( - elem.indices.into_iter().map(|item| item as usize).collect(), - ) - } - protobuf::constraint::ConstraintMode::Unique(elem) => Constraint::Unique( - elem.indices.into_iter().map(|item| item as usize).collect(), - ), - } - } -} - -pub fn parse_i32_to_time_unit(value: &i32) -> Result { - protobuf::TimeUnit::try_from(*value) - .map(|t| t.into()) - .map_err(|_| Error::unknown("TimeUnit", *value)) -} - -pub fn parse_i32_to_interval_unit(value: &i32) -> Result { - protobuf::IntervalUnit::try_from(*value) - .map(|t| t.into()) - .map_err(|_| Error::unknown("IntervalUnit", *value)) -} - -pub fn parse_i32_to_aggregate_function(value: &i32) -> Result { - protobuf::AggregateFunction::try_from(*value) - .map(|a| a.into()) - .map_err(|_| Error::unknown("AggregateFunction", *value)) -} - pub fn parse_expr( proto: &protobuf::LogicalExprNode, registry: &dyn FunctionRegistry, @@ -879,63 +259,6 @@ pub fn parse_expr( }) .expect("Binary expression could not be reduced to a single expression.")) } - ExprType::GetIndexedField(get_indexed_field) => { - let expr = parse_required_expr( - get_indexed_field.expr.as_deref(), - registry, - "expr", - codec, - )?; - let field = match &get_indexed_field.field { - Some(protobuf::get_indexed_field::Field::NamedStructField( - named_struct_field, - )) => GetFieldAccess::NamedStructField { - name: named_struct_field - .name - .as_ref() - .ok_or_else(|| Error::required("value"))? - .try_into()?, - }, - Some(protobuf::get_indexed_field::Field::ListIndex(list_index)) => { - GetFieldAccess::ListIndex { - key: Box::new(parse_required_expr( - list_index.key.as_deref(), - registry, - "key", - codec, - )?), - } - } - Some(protobuf::get_indexed_field::Field::ListRange(list_range)) => { - GetFieldAccess::ListRange { - start: Box::new(parse_required_expr( - list_range.start.as_deref(), - registry, - "start", - codec, - )?), - stop: Box::new(parse_required_expr( - list_range.stop.as_deref(), - registry, - "stop", - codec, - )?), - stride: Box::new(parse_required_expr( - list_range.stride.as_deref(), - registry, - "stride", - codec, - )?), - } - } - None => return Err(proto_error("Field must not be None")), - }; - - Ok(Expr::GetIndexedField(GetIndexedField::new( - Box::new(expr), - field, - ))) - } ExprType::Column(column) => Ok(Expr::Column(column.into())), ExprType::Literal(literal) => { let scalar_value: ScalarValue = literal.try_into()?; @@ -947,110 +270,78 @@ pub fn parse_expr( .as_ref() .ok_or_else(|| Error::required("window_function"))?; let partition_by = parse_exprs(&expr.partition_by, registry, codec)?; - let mut order_by = parse_exprs(&expr.order_by, registry, codec)?; + let mut order_by = parse_sorts(&expr.order_by, registry, codec)?; let window_frame = expr .window_frame .as_ref() .map::, _>(|window_frame| { - let window_frame = window_frame.clone().try_into()?; - check_window_frame(&window_frame, order_by.len()) + let window_frame: WindowFrame = window_frame.clone().try_into()?; + window_frame + .regularize_order_bys(&mut order_by) .map(|_| window_frame) }) .transpose()? .ok_or_else(|| { - DataFusionError::Execution( - "missing window frame during deserialization".to_string(), - ) + exec_datafusion_err!("missing window frame during deserialization") })?; - // TODO: support proto for null treatment - let null_treatment = None; - regularize_window_order_by(&window_frame, &mut order_by)?; + // TODO: support proto for null treatment match window_function { - window_expr_node::WindowFunction::AggrFunction(i) => { - let aggr_function = parse_i32_to_aggregate_function(i)?; - - Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::expr::WindowFunctionDefinition::AggregateFunction( - aggr_function, - ), - vec![parse_required_expr(expr.expr.as_deref(), registry, "expr", codec)?], - partition_by, - order_by, - window_frame, - None - ))) - } window_expr_node::WindowFunction::BuiltInFunction(i) => { let built_in_function = protobuf::BuiltInWindowFunction::try_from(*i) .map_err(|_| Error::unknown("BuiltInWindowFunction", *i))? .into(); - let args = - parse_optional_expr(expr.expr.as_deref(), registry, codec)? - .map(|e| vec![e]) - .unwrap_or_else(Vec::new); + let args = parse_exprs(&expr.exprs, registry, codec)?; - Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::expr::WindowFunctionDefinition::BuiltInWindowFunction( + Expr::WindowFunction(WindowFunction::new( + expr::WindowFunctionDefinition::BuiltInWindowFunction( built_in_function, ), args, - partition_by, - order_by, - window_frame, - null_treatment - ))) + )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .build() + .map_err(Error::DataFusionError) } window_expr_node::WindowFunction::Udaf(udaf_name) => { - let udaf_function = registry.udaf(udaf_name)?; - let args = - parse_optional_expr(expr.expr.as_deref(), registry, codec)? - .map(|e| vec![e]) - .unwrap_or_else(Vec::new); - Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::expr::WindowFunctionDefinition::AggregateUDF( - udaf_function, - ), + let udaf_function = match &expr.fun_definition { + Some(buf) => codec.try_decode_udaf(udaf_name, buf)?, + None => registry.udaf(udaf_name)?, + }; + + let args = parse_exprs(&expr.exprs, registry, codec)?; + Expr::WindowFunction(WindowFunction::new( + expr::WindowFunctionDefinition::AggregateUDF(udaf_function), args, - partition_by, - order_by, - window_frame, - None, - ))) + )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .build() + .map_err(Error::DataFusionError) } window_expr_node::WindowFunction::Udwf(udwf_name) => { - let udwf_function = registry.udwf(udwf_name)?; - let args = - parse_optional_expr(expr.expr.as_deref(), registry, codec)? - .map(|e| vec![e]) - .unwrap_or_else(Vec::new); - Ok(Expr::WindowFunction(WindowFunction::new( - datafusion_expr::expr::WindowFunctionDefinition::WindowUDF( - udwf_function, - ), + let udwf_function = match &expr.fun_definition { + Some(buf) => codec.try_decode_udwf(udwf_name, buf)?, + None => registry.udwf(udwf_name)?, + }; + + let args = parse_exprs(&expr.exprs, registry, codec)?; + Expr::WindowFunction(WindowFunction::new( + expr::WindowFunctionDefinition::WindowUDF(udwf_function), args, - partition_by, - order_by, - window_frame, - None, - ))) + )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .build() + .map_err(Error::DataFusionError) } } } - ExprType::AggregateExpr(expr) => { - let fun = parse_i32_to_aggregate_function(&expr.aggr_function)?; - - Ok(Expr::AggregateFunction(expr::AggregateFunction::new( - fun, - parse_exprs(&expr.expr, registry, codec)?, - expr.distinct, - parse_optional_expr(expr.filter.as_deref(), registry, codec)? - .map(Box::new), - parse_vec_expr(&expr.order_by, registry, codec)?, - None, - ))) - } ExprType::Alias(alias) => Ok(Expr::Alias(Alias::new( parse_required_expr(alias.expr.as_deref(), registry, "expr", codec)?, alias @@ -1227,16 +518,6 @@ pub fn parse_expr( let data_type = cast.arrow_type.as_ref().required("arrow_type")?; Ok(Expr::TryCast(TryCast::new(expr, data_type))) } - ExprType::Sort(sort) => Ok(Expr::Sort(Sort::new( - Box::new(parse_required_expr( - sort.expr.as_deref(), - registry, - "expr", - codec, - )?), - sort.asc, - sort.nulls_first, - ))), ExprType::Negative(negative) => Ok(Expr::Negative(Box::new( parse_required_expr(negative.expr.as_deref(), registry, "expr", codec)?, ))), @@ -1257,13 +538,13 @@ pub fn parse_expr( parse_exprs(&in_list.list, registry, codec)?, in_list.negated, ))), - ExprType::Wildcard(protobuf::Wildcard { qualifier }) => Ok(Expr::Wildcard { - qualifier: if qualifier.is_empty() { - None - } else { - Some(qualifier.clone()) - }, - }), + ExprType::Wildcard(protobuf::Wildcard { qualifier }) => { + let qualifier = qualifier.to_owned().map(|x| x.try_into()).transpose()?; + Ok(Expr::Wildcard { + qualifier, + options: WildcardOptions::default(), + }) + } ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { fun_name, args, @@ -1279,14 +560,20 @@ pub fn parse_expr( ))) } ExprType::AggregateUdfExpr(pb) => { - let agg_fn = registry.udaf(pb.fun_name.as_str())?; + let agg_fn = match &pb.fun_definition { + Some(buf) => codec.try_decode_udaf(&pb.fun_name, buf)?, + None => registry.udaf(&pb.fun_name)?, + }; Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( agg_fn, parse_exprs(&pb.args, registry, codec)?, - false, + pb.distinct, parse_optional_expr(pb.filter.as_deref(), registry, codec)?.map(Box::new), - parse_vec_expr(&pb.order_by, registry, codec)?, + match pb.order_by.len() { + 0 => None, + _ => Some(parse_sorts(&pb.order_by, registry, codec)?), + }, None, ))) } @@ -1332,6 +619,32 @@ where Ok(res) } +pub fn parse_sorts<'a, I>( + protos: I, + registry: &dyn FunctionRegistry, + codec: &dyn LogicalExtensionCodec, +) -> Result, Error> +where + I: IntoIterator, +{ + protos + .into_iter() + .map(|sort| parse_sort(sort, registry, codec)) + .collect::, Error>>() +} + +pub fn parse_sort( + sort: &protobuf::SortExprNode, + registry: &dyn FunctionRegistry, + codec: &dyn LogicalExtensionCodec, +) -> Result { + Ok(Sort::new( + parse_required_expr(sort.expr.as_ref(), registry, "expr", codec)?, + sort.asc, + sort.nulls_first, + )) +} + /// Parse an optional escape_char for Like, ILike, SimilarTo fn parse_escape_char(s: &str) -> Result> { match s.len() { @@ -1341,13 +654,6 @@ fn parse_escape_char(s: &str) -> Result> { } } -// panic here because no better way to convert from Vec to Array -fn vec_to_array(v: Vec) -> [T; N] { - v.try_into().unwrap_or_else(|v: Vec| { - panic!("Expected a Vec of length {} but it was {}", N, v.len()) - }) -} - pub fn from_proto_binary_op(op: &str) -> Result { match op { "And" => Ok(Operator::And), @@ -1383,16 +689,6 @@ pub fn from_proto_binary_op(op: &str) -> Result { } } -fn parse_vec_expr( - p: &[protobuf::LogicalExprNode], - registry: &dyn FunctionRegistry, - codec: &dyn LogicalExtensionCodec, -) -> Result>, Error> { - let res = parse_exprs(p, registry, codec)?; - // Convert empty vector to None. - Ok((!res.is_empty()).then_some(res)) -} - fn parse_optional_expr( p: Option<&protobuf::LogicalExprNode>, registry: &dyn FunctionRegistry, @@ -1419,16 +715,3 @@ fn parse_required_expr( fn proto_error>(message: S) -> Error { Error::General(message.into()) } - -/// Converts a vector of `protobuf::Field`s to `Arc`s. -fn parse_proto_fields_to_fields<'a, I>( - fields: I, -) -> std::result::Result, Error> -where - I: IntoIterator, -{ - fields - .into_iter() - .map(Field::try_from) - .collect::>() -} diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index cccfb0c27307..1993598f5cf7 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -19,24 +19,32 @@ use std::collections::HashMap; use std::fmt::Debug; use std::sync::Arc; -use crate::common::proto_error; use crate::protobuf::logical_plan_node::LogicalPlanType::CustomScan; -use crate::protobuf::{CustomTableScanNode, LogicalExprNodeCollection}; +use crate::protobuf::{ + ColumnUnnestListItem, ColumnUnnestListRecursion, CustomTableScanNode, + SortExprNodeCollection, +}; use crate::{ - convert_required, + convert_required, into_required, protobuf::{ self, listing_table_scan_node::FileFormatType, logical_plan_node::LogicalPlanType, LogicalExtensionNode, LogicalPlanNode, }, }; -use arrow::csv::WriterBuilder; +use crate::protobuf::{proto_error, ToProtoError}; use arrow::datatypes::{DataType, Schema, SchemaRef}; #[cfg(feature = "parquet")] use datafusion::datasource::file_format::parquet::ParquetFormat; +use datafusion::datasource::file_format::{ + file_type_to_format, format_as_file_type, FileFormatFactory, +}; use datafusion::{ datasource::{ - file_format::{avro::AvroFormat, csv::CsvFormat, FileFormat}, + file_format::{ + avro::AvroFormat, csv::CsvFormat, json::JsonFormat as OtherNdJsonFormat, + FileFormat, + }, listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl}, view::ViewTable, TableProvider, @@ -44,41 +52,33 @@ use datafusion::{ datasource::{provider_as_source, source_as_provider}, prelude::SessionContext, }; +use datafusion_common::file_options::file_type::FileType; use datafusion_common::{ - context, internal_err, not_impl_err, parsers::CompressionTypeVariant, - plan_datafusion_err, DataFusionError, Result, TableReference, + context, internal_datafusion_err, internal_err, not_impl_err, DataFusionError, + Result, TableReference, }; use datafusion_expr::{ dml, logical_plan::{ builder::project, Aggregate, CreateCatalog, CreateCatalogSchema, - CreateExternalTable, CreateView, CrossJoin, DdlStatement, Distinct, - EmptyRelation, Extension, Join, JoinConstraint, Limit, Prepare, Projection, - Repartition, Sort, SubqueryAlias, TableScan, Values, Window, + CreateExternalTable, CreateView, DdlStatement, Distinct, EmptyRelation, + Extension, Join, JoinConstraint, Prepare, Projection, Repartition, Sort, + SubqueryAlias, TableScan, Values, Window, }, - DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, ScalarUDF, + DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, ScalarUDF, SortExpr, + WindowUDF, }; +use datafusion_expr::{AggregateUDF, ColumnUnnestList, FetchType, SkipType, Unnest}; +use self::to_proto::{serialize_expr, serialize_exprs}; +use crate::logical_plan::to_proto::serialize_sorts; use prost::bytes::BufMut; use prost::Message; -use self::to_proto::serialize_expr; - +pub mod file_formats; pub mod from_proto; pub mod to_proto; -impl From for DataFusionError { - fn from(e: from_proto::Error) -> Self { - plan_datafusion_err!("{}", e) - } -} - -impl From for DataFusionError { - fn from(e: to_proto::Error) -> Self { - plan_datafusion_err!("{}", e) - } -} - pub trait AsLogicalPlan: Debug + Send + Sync + Clone { fn try_decode(buf: &[u8]) -> Result where @@ -116,16 +116,34 @@ pub trait LogicalExtensionCodec: Debug + Send + Sync { fn try_decode_table_provider( &self, buf: &[u8], + table_ref: &TableReference, schema: SchemaRef, ctx: &SessionContext, ) -> Result>; fn try_encode_table_provider( &self, + table_ref: &TableReference, node: Arc, buf: &mut Vec, ) -> Result<()>; + fn try_decode_file_format( + &self, + _buf: &[u8], + _ctx: &SessionContext, + ) -> Result> { + not_impl_err!("LogicalExtensionCodec is not provided for file format") + } + + fn try_encode_file_format( + &self, + _buf: &mut Vec, + _node: Arc, + ) -> Result<()> { + Ok(()) + } + fn try_decode_udf(&self, name: &str, _buf: &[u8]) -> Result> { not_impl_err!("LogicalExtensionCodec is not provided for scalar function {name}") } @@ -133,6 +151,24 @@ pub trait LogicalExtensionCodec: Debug + Send + Sync { fn try_encode_udf(&self, _node: &ScalarUDF, _buf: &mut Vec) -> Result<()> { Ok(()) } + + fn try_decode_udaf(&self, name: &str, _buf: &[u8]) -> Result> { + not_impl_err!( + "LogicalExtensionCodec is not provided for aggregate function {name}" + ) + } + + fn try_encode_udaf(&self, _node: &AggregateUDF, _buf: &mut Vec) -> Result<()> { + Ok(()) + } + + fn try_decode_udwf(&self, name: &str, _buf: &[u8]) -> Result> { + not_impl_err!("LogicalExtensionCodec is not provided for window function {name}") + } + + fn try_encode_udwf(&self, _node: &WindowUDF, _buf: &mut Vec) -> Result<()> { + Ok(()) + } } #[derive(Debug, Clone)] @@ -155,6 +191,7 @@ impl LogicalExtensionCodec for DefaultLogicalExtensionCodec { fn try_decode_table_provider( &self, _buf: &[u8], + _table_ref: &TableReference, _schema: SchemaRef, _ctx: &SessionContext, ) -> Result> { @@ -163,6 +200,7 @@ impl LogicalExtensionCodec for DefaultLogicalExtensionCodec { fn try_encode_table_provider( &self, + _table_ref: &TableReference, _node: Arc, _buf: &mut Vec, ) -> Result<()> { @@ -239,26 +277,18 @@ impl AsLogicalPlan for LogicalPlanNode { values .values_list .chunks_exact(n_cols) - .map(|r| { - r.iter() - .map(|expr| { - from_proto::parse_expr(expr, ctx, extension_codec) - }) - .collect::, from_proto::Error>>() - }) + .map(|r| from_proto::parse_exprs(r, ctx, extension_codec)) .collect::, _>>() .map_err(|e| e.into()) }?; + LogicalPlanBuilder::values(values)?.build() } LogicalPlanType::Projection(projection) => { let input: LogicalPlan = into_logical_plan!(projection.input, ctx, extension_codec)?; - let expr: Vec = projection - .expr - .iter() - .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) - .collect::, _>>()?; + let expr: Vec = + from_proto::parse_exprs(&projection.expr, ctx, extension_codec)?; let new_proj = project(input, expr)?; match projection.optional_alias.as_ref() { @@ -290,26 +320,17 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlanType::Window(window) => { let input: LogicalPlan = into_logical_plan!(window.input, ctx, extension_codec)?; - let window_expr = window - .window_expr - .iter() - .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) - .collect::, _>>()?; + let window_expr = + from_proto::parse_exprs(&window.window_expr, ctx, extension_codec)?; LogicalPlanBuilder::from(input).window(window_expr)?.build() } LogicalPlanType::Aggregate(aggregate) => { let input: LogicalPlan = into_logical_plan!(aggregate.input, ctx, extension_codec)?; - let group_expr = aggregate - .group_expr - .iter() - .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) - .collect::, _>>()?; - let aggr_expr = aggregate - .aggr_expr - .iter() - .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) - .collect::, _>>()?; + let group_expr = + from_proto::parse_exprs(&aggregate.group_expr, ctx, extension_codec)?; + let aggr_expr = + from_proto::parse_exprs(&aggregate.aggr_expr, ctx, extension_codec)?; LogicalPlanBuilder::from(input) .aggregate(group_expr, aggr_expr)? .build() @@ -327,20 +348,16 @@ impl AsLogicalPlan for LogicalPlanNode { projection = Some(column_indices); } - let filters = scan - .filters - .iter() - .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) - .collect::, _>>()?; + let filters = + from_proto::parse_exprs(&scan.filters, ctx, extension_codec)?; let mut all_sort_orders = vec![]; for order in &scan.file_sort_order { - let file_sort_order = order - .logical_expr_nodes - .iter() - .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) - .collect::, _>>()?; - all_sort_orders.push(file_sort_order) + all_sort_orders.push(from_proto::parse_sorts( + &order.sort_expr_nodes, + ctx, + extension_codec, + )?) } let file_format: Arc = @@ -349,13 +366,18 @@ impl AsLogicalPlan for LogicalPlanNode { "logical_plan::from_proto() Unsupported file format '{self:?}'" )) })? { - #[cfg(feature = "parquet")] + #[cfg_attr(not(feature = "parquet"), allow(unused_variables))] FileFormatType::Parquet(protobuf::ParquetFormat {options}) => { - let mut parquet = ParquetFormat::default(); - if let Some(options) = options { - parquet = parquet.with_options(options.try_into()?) + #[cfg(feature = "parquet")] + { + let mut parquet = ParquetFormat::default(); + if let Some(options) = options { + parquet = parquet.with_options(options.try_into()?) + } + Arc::new(parquet) } - Arc::new(parquet) + #[cfg(not(feature = "parquet"))] + panic!("Unable to process parquet file since `parquet` feature is not enabled"); } FileFormatType::Csv(protobuf::CsvFormat { options @@ -364,7 +386,17 @@ impl AsLogicalPlan for LogicalPlanNode { if let Some(options) = options { csv = csv.with_options(options.try_into()?) } - Arc::new(csv)}, + Arc::new(csv) + }, + FileFormatType::Json(protobuf::NdJsonFormat { + options + }) => { + let mut json = OtherNdJsonFormat::default(); + if let Some(options) = options { + json = json.with_options(options.try_into()?) + } + Arc::new(json) + } FileFormatType::Avro(..) => Arc::new(AvroFormat), }; @@ -418,7 +450,7 @@ impl AsLogicalPlan for LogicalPlanNode { )? .build() } - LogicalPlanType::CustomScan(scan) => { + CustomScan(scan) => { let schema: Schema = convert_required!(scan.schema)?; let schema = Arc::new(schema); let mut projection = None; @@ -431,20 +463,19 @@ impl AsLogicalPlan for LogicalPlanNode { projection = Some(column_indices); } - let filters = scan - .filters - .iter() - .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) - .collect::, _>>()?; + let filters = + from_proto::parse_exprs(&scan.filters, ctx, extension_codec)?; + + let table_name = + from_table_reference(scan.table_name.as_ref(), "CustomScan")?; + let provider = extension_codec.try_decode_table_provider( &scan.custom_table_data, + &table_name, schema, ctx, )?; - let table_name = - from_table_reference(scan.table_name.as_ref(), "CustomScan")?; - LogicalPlanBuilder::scan_with_filters( table_name, provider_as_source(provider), @@ -456,12 +487,12 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlanType::Sort(sort) => { let input: LogicalPlan = into_logical_plan!(sort.input, ctx, extension_codec)?; - let sort_expr: Vec = sort - .expr - .iter() - .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) - .collect::, _>>()?; - LogicalPlanBuilder::from(input).sort(sort_expr)?.build() + let sort_expr: Vec = + from_proto::parse_sorts(&sort.expr, ctx, extension_codec)?; + let fetch: Option = sort.fetch.try_into().ok(); + LogicalPlanBuilder::from(input) + .sort_with_limit(sort_expr, fetch)? + .build() } LogicalPlanType::Repartition(repartition) => { use datafusion::logical_expr::Partitioning; @@ -469,9 +500,9 @@ impl AsLogicalPlan for LogicalPlanNode { into_logical_plan!(repartition.input, ctx, extension_codec)?; use protobuf::repartition_node::PartitionMethod; let pb_partition_method = repartition.partition_method.as_ref().ok_or_else(|| { - DataFusionError::Internal(String::from( - "Protobuf deserialization error, RepartitionNode was missing required field 'partition_method'", - )) + internal_datafusion_err!( + "Protobuf deserialization error, RepartitionNode was missing required field 'partition_method'" + ) })?; let partitioning_scheme = match pb_partition_method { @@ -479,12 +510,7 @@ impl AsLogicalPlan for LogicalPlanNode { hash_expr: pb_hash_expr, partition_count, }) => Partitioning::Hash( - pb_hash_expr - .iter() - .map(|expr| { - from_proto::parse_expr(expr, ctx, extension_codec) - }) - .collect::, _>>()?, + from_proto::parse_exprs(pb_hash_expr, ctx, extension_codec)?, *partition_count as usize, ), PartitionMethod::RoundRobin(partition_count) => { @@ -502,7 +528,7 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlanType::CreateExternalTable(create_extern_table) => { let pb_schema = (create_extern_table.schema.clone()).ok_or_else(|| { DataFusionError::Internal(String::from( - "Protobuf deserialization error, CreateExternalTableNode was missing required field schema.", + "Protobuf deserialization error, CreateExternalTableNode was missing required field schema." )) })?; @@ -524,12 +550,11 @@ impl AsLogicalPlan for LogicalPlanNode { let mut order_exprs = vec![]; for expr in &create_extern_table.order_exprs { - let order_expr = expr - .logical_expr_nodes - .iter() - .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) - .collect::, _>>()?; - order_exprs.push(order_expr) + order_exprs.push(from_proto::parse_sorts( + &expr.sort_expr_nodes, + ctx, + extension_codec, + )?); } let mut column_defaults = @@ -539,37 +564,28 @@ impl AsLogicalPlan for LogicalPlanNode { column_defaults.insert(col_name.clone(), expr); } - let file_compression_type = protobuf::CompressionTypeVariant::try_from( - create_extern_table.file_compression_type, - ) - .map_err(|_| { - proto_error(format!( - "Unknown file compression type {}", - create_extern_table.file_compression_type - )) - })?; - - Ok(LogicalPlan::Ddl(DdlStatement::CreateExternalTable(CreateExternalTable { - schema: pb_schema.try_into()?, - name: from_table_reference(create_extern_table.name.as_ref(), "CreateExternalTable")?, - location: create_extern_table.location.clone(), - file_type: create_extern_table.file_type.clone(), - has_header: create_extern_table.has_header, - delimiter: create_extern_table.delimiter.chars().next().ok_or_else(|| { - DataFusionError::Internal(String::from("Protobuf deserialization error, unable to parse CSV delimiter")) - })?, - table_partition_cols: create_extern_table - .table_partition_cols - .clone(), - order_exprs, - if_not_exists: create_extern_table.if_not_exists, - file_compression_type: file_compression_type.into(), - definition, - unbounded: create_extern_table.unbounded, - options: create_extern_table.options.clone(), - constraints: constraints.into(), - column_defaults, - }))) + Ok(LogicalPlan::Ddl(DdlStatement::CreateExternalTable( + CreateExternalTable { + schema: pb_schema.try_into()?, + name: from_table_reference( + create_extern_table.name.as_ref(), + "CreateExternalTable", + )?, + location: create_extern_table.location.clone(), + file_type: create_extern_table.file_type.clone(), + table_partition_cols: create_extern_table + .table_partition_cols + .clone(), + order_exprs, + if_not_exists: create_extern_table.if_not_exists, + temporary: create_extern_table.temporary, + definition, + unbounded: create_extern_table.unbounded, + options: create_extern_table.options.clone(), + constraints: constraints.into(), + column_defaults, + }, + ))) } LogicalPlanType::CreateView(create_view) => { let plan = create_view @@ -585,6 +601,7 @@ impl AsLogicalPlan for LogicalPlanNode { Ok(LogicalPlan::Ddl(DdlStatement::CreateView(CreateView { name: from_table_reference(create_view.name.as_ref(), "CreateView")?, + temporary: create_view.temporary, input: Arc::new(plan), or_replace: create_view.or_replace, definition, @@ -657,16 +674,10 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlanBuilder::from(input).limit(skip, fetch)?.build() } LogicalPlanType::Join(join) => { - let left_keys: Vec = join - .left_join_key - .iter() - .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) - .collect::, _>>()?; - let right_keys: Vec = join - .right_join_key - .iter() - .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) - .collect::, _>>()?; + let left_keys: Vec = + from_proto::parse_exprs(&join.left_join_key, ctx, extension_codec)?; + let right_keys: Vec = + from_proto::parse_exprs(&join.right_join_key, ctx, extension_codec)?; let join_type = protobuf::JoinType::try_from(join.join_type).map_err(|_| { proto_error(format!( @@ -705,7 +716,12 @@ impl AsLogicalPlan for LogicalPlanNode { // The equijoin keys in using-join must be column. let using_keys = left_keys .into_iter() - .map(|key| key.try_into_col()) + .map(|key| { + key.try_as_col().cloned() + .ok_or_else(|| internal_datafusion_err!( + "Using join keys must be column references, got: {key:?}" + )) + }) .collect::, _>>()?; builder.join_using( into_logical_plan!(join.right, ctx, extension_codec)?, @@ -763,27 +779,20 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlanType::DistinctOn(distinct_on) => { let input: LogicalPlan = into_logical_plan!(distinct_on.input, ctx, extension_codec)?; - let on_expr = distinct_on - .on_expr - .iter() - .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) - .collect::, _>>()?; - let select_expr = distinct_on - .select_expr - .iter() - .map(|expr| from_proto::parse_expr(expr, ctx, extension_codec)) - .collect::, _>>()?; + let on_expr = + from_proto::parse_exprs(&distinct_on.on_expr, ctx, extension_codec)?; + let select_expr = from_proto::parse_exprs( + &distinct_on.select_expr, + ctx, + extension_codec, + )?; let sort_expr = match distinct_on.sort_expr.len() { 0 => None, - _ => Some( - distinct_on - .sort_expr - .iter() - .map(|expr| { - from_proto::parse_expr(expr, ctx, extension_codec) - }) - .collect::, _>>()?, - ), + _ => Some(from_proto::parse_sorts( + &distinct_on.sort_expr, + ctx, + extension_codec, + )?), }; LogicalPlanBuilder::from(input) .distinct_on(on_expr, select_expr, sort_expr)? @@ -835,26 +844,66 @@ impl AsLogicalPlan for LogicalPlanNode { .prepare(prepare.name.clone(), data_types)? .build() } - LogicalPlanType::DropView(dropview) => Ok(datafusion_expr::LogicalPlan::Ddl( - datafusion_expr::DdlStatement::DropView(DropView { + LogicalPlanType::DropView(dropview) => { + Ok(LogicalPlan::Ddl(DdlStatement::DropView(DropView { name: from_table_reference(dropview.name.as_ref(), "DropView")?, if_exists: dropview.if_exists, schema: Arc::new(convert_required!(dropview.schema)?), - }), - )), + }))) + } LogicalPlanType::CopyTo(copy) => { let input: LogicalPlan = into_logical_plan!(copy.input, ctx, extension_codec)?; - Ok(datafusion_expr::LogicalPlan::Copy( - datafusion_expr::dml::CopyTo { - input: Arc::new(input), - output_url: copy.output_url.clone(), - partition_by: copy.partition_by.clone(), - format_options: convert_required!(copy.format_options)?, - options: Default::default(), - }, - )) + let file_type: Arc = format_as_file_type( + extension_codec.try_decode_file_format(©.file_type, ctx)?, + ); + + Ok(LogicalPlan::Copy(dml::CopyTo { + input: Arc::new(input), + output_url: copy.output_url.clone(), + partition_by: copy.partition_by.clone(), + file_type, + options: Default::default(), + })) + } + LogicalPlanType::Unnest(unnest) => { + let input: LogicalPlan = + into_logical_plan!(unnest.input, ctx, extension_codec)?; + Ok(LogicalPlan::Unnest(Unnest { + input: Arc::new(input), + exec_columns: unnest.exec_columns.iter().map(|c| c.into()).collect(), + list_type_columns: unnest + .list_type_columns + .iter() + .map(|c| { + let recursion_item = c.recursion.as_ref().unwrap(); + ( + c.input_index as _, + ColumnUnnestList { + output_column: recursion_item + .output_column + .as_ref() + .unwrap() + .into(), + depth: recursion_item.depth as _, + }, + ) + }) + .collect(), + struct_type_columns: unnest + .struct_type_columns + .iter() + .map(|c| *c as usize) + .collect(), + dependency_indices: unnest + .dependency_indices + .iter() + .map(|c| *c as usize) + .collect(), + schema: Arc::new(convert_required!(unnest.schema)?), + options: into_required!(unnest.options)?, + })) } } } @@ -873,12 +922,9 @@ impl AsLogicalPlan for LogicalPlanNode { } else { values[0].len() } as u64; - let values_list = values - .iter() - .flatten() - .map(|v| serialize_expr(v, extension_codec)) - .collect::, _>>()?; - Ok(protobuf::LogicalPlanNode { + let values_list = + serialize_exprs(values.iter().flatten(), extension_codec)?; + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Values( protobuf::ValuesNode { n_cols, @@ -912,10 +958,8 @@ impl AsLogicalPlan for LogicalPlanNode { }; let schema: protobuf::Schema = schema.as_ref().try_into()?; - let filters: Vec = filters - .iter() - .map(|filter| serialize_expr(filter, extension_codec)) - .collect::, _>>()?; + let filters: Vec = + serialize_exprs(filters, extension_codec)?; if let Some(listing_table) = source.downcast_ref::() { let any = listing_table.options().format.as_any(); @@ -939,6 +983,14 @@ impl AsLogicalPlan for LogicalPlanNode { })); } + if let Some(json) = any.downcast_ref::() { + let options = json.options(); + maybe_some_type = + Some(FileFormatType::Json(protobuf::NdJsonFormat { + options: Some(options.try_into()?), + })) + } + if any.is::() { maybe_some_type = Some(FileFormatType::Avro(protobuf::AvroFormat {})) @@ -956,18 +1008,15 @@ impl AsLogicalPlan for LogicalPlanNode { let options = listing_table.options(); - let mut exprs_vec: Vec = vec![]; + let mut exprs_vec: Vec = vec![]; for order in &options.file_sort_order { - let expr_vec = LogicalExprNodeCollection { - logical_expr_nodes: order - .iter() - .map(|expr| serialize_expr(expr, extension_codec)) - .collect::, to_proto::Error>>()?, + let expr_vec = SortExprNodeCollection { + sort_expr_nodes: serialize_sorts(order, extension_codec)?, }; exprs_vec.push(expr_vec); } - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::ListingScan( protobuf::ListingTableScanNode { file_format_type: Some(file_format_type), @@ -993,12 +1042,12 @@ impl AsLogicalPlan for LogicalPlanNode { )), }) } else if let Some(view_table) = source.downcast_ref::() { - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::ViewScan(Box::new( protobuf::ViewTableScanNode { table_name: Some(table_name.clone().into()), input: Some(Box::new( - protobuf::LogicalPlanNode::try_from_logical_plan( + LogicalPlanNode::try_from_logical_plan( view_table.logical_plan(), extension_codec, )?, @@ -1015,7 +1064,7 @@ impl AsLogicalPlan for LogicalPlanNode { } else { let mut bytes = vec![]; extension_codec - .try_encode_table_provider(provider, &mut bytes) + .try_encode_table_provider(table_name, provider, &mut bytes) .map_err(|e| context!("Error serializing custom table", e))?; let scan = CustomScan(CustomTableScanNode { table_name: Some(table_name.clone().into()), @@ -1031,31 +1080,27 @@ impl AsLogicalPlan for LogicalPlanNode { } } LogicalPlan::Projection(Projection { expr, input, .. }) => { - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Projection(Box::new( protobuf::ProjectionNode { input: Some(Box::new( - protobuf::LogicalPlanNode::try_from_logical_plan( + LogicalPlanNode::try_from_logical_plan( input.as_ref(), extension_codec, )?, )), - expr: expr - .iter() - .map(|expr| serialize_expr(expr, extension_codec)) - .collect::, to_proto::Error>>()?, + expr: serialize_exprs(expr, extension_codec)?, optional_alias: None, }, ))), }) } LogicalPlan::Filter(filter) => { - let input: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - filter.input.as_ref(), - extension_codec, - )?; - Ok(protobuf::LogicalPlanNode { + let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + filter.input.as_ref(), + extension_codec, + )?; + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Selection(Box::new( protobuf::SelectionNode { input: Some(Box::new(input)), @@ -1068,12 +1113,11 @@ impl AsLogicalPlan for LogicalPlanNode { }) } LogicalPlan::Distinct(Distinct::All(input)) => { - let input: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - input.as_ref(), - extension_codec, - )?; - Ok(protobuf::LogicalPlanNode { + let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + input.as_ref(), + extension_codec, + )?; + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Distinct(Box::new( protobuf::DistinctNode { input: Some(Box::new(input)), @@ -1088,29 +1132,19 @@ impl AsLogicalPlan for LogicalPlanNode { input, .. })) => { - let input: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - input.as_ref(), - extension_codec, - )?; + let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + input.as_ref(), + extension_codec, + )?; let sort_expr = match sort_expr { None => vec![], - Some(sort_expr) => sort_expr - .iter() - .map(|expr| serialize_expr(expr, extension_codec)) - .collect::, _>>()?, + Some(sort_expr) => serialize_sorts(sort_expr, extension_codec)?, }; - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::DistinctOn(Box::new( protobuf::DistinctOnNode { - on_expr: on_expr - .iter() - .map(|expr| serialize_expr(expr, extension_codec)) - .collect::, _>>()?, - select_expr: select_expr - .iter() - .map(|expr| serialize_expr(expr, extension_codec)) - .collect::, _>>()?, + on_expr: serialize_exprs(on_expr, extension_codec)?, + select_expr: serialize_exprs(select_expr, extension_codec)?, sort_expr, input: Some(Box::new(input)), }, @@ -1120,19 +1154,15 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlan::Window(Window { input, window_expr, .. }) => { - let input: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - input.as_ref(), - extension_codec, - )?; - Ok(protobuf::LogicalPlanNode { + let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + input.as_ref(), + extension_codec, + )?; + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Window(Box::new( protobuf::WindowNode { input: Some(Box::new(input)), - window_expr: window_expr - .iter() - .map(|expr| serialize_expr(expr, extension_codec)) - .collect::, _>>()?, + window_expr: serialize_exprs(window_expr, extension_codec)?, }, ))), }) @@ -1143,23 +1173,16 @@ impl AsLogicalPlan for LogicalPlanNode { input, .. }) => { - let input: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - input.as_ref(), - extension_codec, - )?; - Ok(protobuf::LogicalPlanNode { + let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + input.as_ref(), + extension_codec, + )?; + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Aggregate(Box::new( protobuf::AggregateNode { input: Some(Box::new(input)), - group_expr: group_expr - .iter() - .map(|expr| serialize_expr(expr, extension_codec)) - .collect::, _>>()?, - aggr_expr: aggr_expr - .iter() - .map(|expr| serialize_expr(expr, extension_codec)) - .collect::, _>>()?, + group_expr: serialize_exprs(group_expr, extension_codec)?, + aggr_expr: serialize_exprs(aggr_expr, extension_codec)?, }, ))), }) @@ -1174,16 +1197,14 @@ impl AsLogicalPlan for LogicalPlanNode { null_equals_null, .. }) => { - let left: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - left.as_ref(), - extension_codec, - )?; - let right: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - right.as_ref(), - extension_codec, - )?; + let left: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + left.as_ref(), + extension_codec, + )?; + let right: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + right.as_ref(), + extension_codec, + )?; let (left_join_key, right_join_key) = on .iter() .map(|(l, r)| { @@ -1192,7 +1213,7 @@ impl AsLogicalPlan for LogicalPlanNode { serialize_expr(r, extension_codec)?, )) }) - .collect::, to_proto::Error>>()? + .collect::, ToProtoError>>()? .into_iter() .unzip(); let join_type: protobuf::JoinType = join_type.to_owned().into(); @@ -1202,7 +1223,7 @@ impl AsLogicalPlan for LogicalPlanNode { .as_ref() .map(|e| serialize_expr(e, extension_codec)) .map_or(Ok(None), |v| v.map(Some))?; - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Join(Box::new( protobuf::JoinNode { left: Some(Box::new(left)), @@ -1221,12 +1242,11 @@ impl AsLogicalPlan for LogicalPlanNode { not_impl_err!("LogicalPlan serde is not yet implemented for subqueries") } LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => { - let input: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - input.as_ref(), - extension_codec, - )?; - Ok(protobuf::LogicalPlanNode { + let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + input.as_ref(), + extension_codec, + )?; + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::SubqueryAlias(Box::new( protobuf::SubqueryAliasNode { input: Some(Box::new(input)), @@ -1235,37 +1255,44 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Limit(Limit { input, skip, fetch }) => { - let input: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - input.as_ref(), - extension_codec, - )?; - Ok(protobuf::LogicalPlanNode { + LogicalPlan::Limit(limit) => { + let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + limit.input.as_ref(), + extension_codec, + )?; + let SkipType::Literal(skip) = limit.get_skip_type()? else { + return Err(proto_error( + "LogicalPlan::Limit only supports literal skip values", + )); + }; + let FetchType::Literal(fetch) = limit.get_fetch_type()? else { + return Err(proto_error( + "LogicalPlan::Limit only supports literal fetch values", + )); + }; + + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Limit(Box::new( protobuf::LimitNode { input: Some(Box::new(input)), - skip: *skip as i64, + skip: skip as i64, fetch: fetch.unwrap_or(i64::MAX as usize) as i64, }, ))), }) } LogicalPlan::Sort(Sort { input, expr, fetch }) => { - let input: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - input.as_ref(), - extension_codec, - )?; - let selection_expr: Vec = expr - .iter() - .map(|expr| serialize_expr(expr, extension_codec)) - .collect::, to_proto::Error>>()?; - Ok(protobuf::LogicalPlanNode { + let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + input.as_ref(), + extension_codec, + )?; + let sort_expr: Vec = + serialize_sorts(expr, extension_codec)?; + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Sort(Box::new( protobuf::SortNode { input: Some(Box::new(input)), - expr: selection_expr, + expr: sort_expr, fetch: fetch.map(|f| f as i64).unwrap_or(-1i64), }, ))), @@ -1276,11 +1303,10 @@ impl AsLogicalPlan for LogicalPlanNode { partitioning_scheme, }) => { use datafusion::logical_expr::Partitioning; - let input: protobuf::LogicalPlanNode = - protobuf::LogicalPlanNode::try_from_logical_plan( - input.as_ref(), - extension_codec, - )?; + let input: LogicalPlanNode = LogicalPlanNode::try_from_logical_plan( + input.as_ref(), + extension_codec, + )?; // Assumed common usize field was batch size // Used u64 to avoid any nastyness involving large values, most data clusters are probably uniformly 64 bits any ways @@ -1289,10 +1315,7 @@ impl AsLogicalPlan for LogicalPlanNode { let pb_partition_method = match partitioning_scheme { Partitioning::Hash(exprs, partition_count) => { PartitionMethod::Hash(protobuf::HashRepartition { - hash_expr: exprs - .iter() - .map(|expr| serialize_expr(expr, extension_codec)) - .collect::, to_proto::Error>>()?, + hash_expr: serialize_exprs(exprs, extension_codec)?, partition_count: *partition_count as u64, }) } @@ -1304,7 +1327,7 @@ impl AsLogicalPlan for LogicalPlanNode { } }; - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Repartition(Box::new( protobuf::RepartitionNode { input: Some(Box::new(input)), @@ -1315,7 +1338,7 @@ impl AsLogicalPlan for LogicalPlanNode { } LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row, .. - }) => Ok(protobuf::LogicalPlanNode { + }) => Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::EmptyRelation( protobuf::EmptyRelationNode { produce_one_row: *produce_one_row, @@ -1327,27 +1350,22 @@ impl AsLogicalPlan for LogicalPlanNode { name, location, file_type, - has_header, - delimiter, schema: df_schema, table_partition_cols, if_not_exists, definition, - file_compression_type, order_exprs, unbounded, options, constraints, column_defaults, + temporary, }, )) => { - let mut converted_order_exprs: Vec = vec![]; + let mut converted_order_exprs: Vec = vec![]; for order in order_exprs { - let temp = LogicalExprNodeCollection { - logical_expr_nodes: order - .iter() - .map(|expr| serialize_expr(expr, extension_codec)) - .collect::, to_proto::Error>>()?, + let temp = SortExprNodeCollection { + sort_expr_nodes: serialize_sorts(order, extension_codec)?, }; converted_order_exprs.push(temp); } @@ -1359,23 +1377,18 @@ impl AsLogicalPlan for LogicalPlanNode { .insert(col_name.clone(), serialize_expr(expr, extension_codec)?); } - let file_compression_type = - protobuf::CompressionTypeVariant::from(file_compression_type); - - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateExternalTable( protobuf::CreateExternalTableNode { name: Some(name.clone().into()), location: location.clone(), file_type: file_type.clone(), - has_header: *has_header, schema: Some(df_schema.try_into()?), table_partition_cols: table_partition_cols.clone(), if_not_exists: *if_not_exists, - delimiter: String::from(*delimiter), + temporary: *temporary, order_exprs: converted_order_exprs, definition: definition.clone().unwrap_or_default(), - file_compression_type: file_compression_type.into(), unbounded: *unbounded, options: options.clone(), constraints: Some(constraints.clone().into()), @@ -1389,7 +1402,8 @@ impl AsLogicalPlan for LogicalPlanNode { input, or_replace, definition, - })) => Ok(protobuf::LogicalPlanNode { + temporary, + })) => Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateView(Box::new( protobuf::CreateViewNode { name: Some(name.clone().into()), @@ -1398,6 +1412,7 @@ impl AsLogicalPlan for LogicalPlanNode { extension_codec, )?)), or_replace: *or_replace, + temporary: *temporary, definition: definition.clone().unwrap_or_default(), }, ))), @@ -1408,7 +1423,7 @@ impl AsLogicalPlan for LogicalPlanNode { if_not_exists, schema: df_schema, }, - )) => Ok(protobuf::LogicalPlanNode { + )) => Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateCatalogSchema( protobuf::CreateCatalogSchemaNode { schema_name: schema_name.clone(), @@ -1421,7 +1436,7 @@ impl AsLogicalPlan for LogicalPlanNode { catalog_name, if_not_exists, schema: df_schema, - })) => Ok(protobuf::LogicalPlanNode { + })) => Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CreateCatalog( protobuf::CreateCatalogNode { catalog_name: catalog_name.clone(), @@ -1431,11 +1446,11 @@ impl AsLogicalPlan for LogicalPlanNode { )), }), LogicalPlan::Analyze(a) => { - let input = protobuf::LogicalPlanNode::try_from_logical_plan( + let input = LogicalPlanNode::try_from_logical_plan( a.input.as_ref(), extension_codec, )?; - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Analyze(Box::new( protobuf::AnalyzeNode { input: Some(Box::new(input)), @@ -1445,11 +1460,11 @@ impl AsLogicalPlan for LogicalPlanNode { }) } LogicalPlan::Explain(a) => { - let input = protobuf::LogicalPlanNode::try_from_logical_plan( + let input = LogicalPlanNode::try_from_logical_plan( a.plan.as_ref(), extension_codec, )?; - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Explain(Box::new( protobuf::ExplainNode { input: Some(Box::new(input)), @@ -1462,37 +1477,14 @@ impl AsLogicalPlan for LogicalPlanNode { let inputs: Vec = union .inputs .iter() - .map(|i| { - protobuf::LogicalPlanNode::try_from_logical_plan( - i, - extension_codec, - ) - }) + .map(|i| LogicalPlanNode::try_from_logical_plan(i, extension_codec)) .collect::>()?; - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Union( protobuf::UnionNode { inputs }, )), }) } - LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => { - let left = protobuf::LogicalPlanNode::try_from_logical_plan( - left.as_ref(), - extension_codec, - )?; - let right = protobuf::LogicalPlanNode::try_from_logical_plan( - right.as_ref(), - extension_codec, - )?; - Ok(protobuf::LogicalPlanNode { - logical_plan_type: Some(LogicalPlanType::CrossJoin(Box::new( - protobuf::CrossJoinNode { - left: Some(Box::new(left)), - right: Some(Box::new(right)), - }, - ))), - }) - } LogicalPlan::Extension(extension) => { let mut buf: Vec = vec![]; extension_codec.try_encode(extension, &mut buf)?; @@ -1501,15 +1493,10 @@ impl AsLogicalPlan for LogicalPlanNode { .node .inputs() .iter() - .map(|i| { - protobuf::LogicalPlanNode::try_from_logical_plan( - i, - extension_codec, - ) - }) + .map(|i| LogicalPlanNode::try_from_logical_plan(i, extension_codec)) .collect::>()?; - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Extension( LogicalExtensionNode { node: buf, inputs }, )), @@ -1520,11 +1507,9 @@ impl AsLogicalPlan for LogicalPlanNode { data_types, input, }) => { - let input = protobuf::LogicalPlanNode::try_from_logical_plan( - input, - extension_codec, - )?; - Ok(protobuf::LogicalPlanNode { + let input = + LogicalPlanNode::try_from_logical_plan(input, extension_codec)?; + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Prepare(Box::new( protobuf::PrepareNode { name: name.clone(), @@ -1537,12 +1522,56 @@ impl AsLogicalPlan for LogicalPlanNode { ))), }) } - LogicalPlan::Unnest(_) => Err(proto_error( - "LogicalPlan serde is not yet implemented for Unnest", - )), + LogicalPlan::Unnest(Unnest { + input, + exec_columns, + list_type_columns, + struct_type_columns, + dependency_indices, + schema, + options, + }) => { + let input = + LogicalPlanNode::try_from_logical_plan(input, extension_codec)?; + let proto_unnest_list_items = list_type_columns + .iter() + .map(|(index, ul)| ColumnUnnestListItem { + input_index: *index as _, + recursion: Some(ColumnUnnestListRecursion { + output_column: Some(ul.output_column.to_owned().into()), + depth: ul.depth as _, + }), + }) + .collect(); + Ok(LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::Unnest(Box::new( + protobuf::UnnestNode { + input: Some(Box::new(input)), + exec_columns: exec_columns + .iter() + .map(|col| col.into()) + .collect(), + list_type_columns: proto_unnest_list_items, + struct_type_columns: struct_type_columns + .iter() + .map(|c| *c as u64) + .collect(), + dependency_indices: dependency_indices + .iter() + .map(|c| *c as u64) + .collect(), + schema: Some(schema.try_into()?), + options: Some(options.into()), + }, + ))), + }) + } LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(_)) => Err(proto_error( "LogicalPlan serde is not yet implemented for CreateMemoryTable", )), + LogicalPlan::Ddl(DdlStatement::CreateIndex(_)) => Err(proto_error( + "LogicalPlan serde is not yet implemented for CreateIndex", + )), LogicalPlan::Ddl(DdlStatement::DropTable(_)) => Err(proto_error( "LogicalPlan serde is not yet implemented for DropTable", )), @@ -1550,7 +1579,7 @@ impl AsLogicalPlan for LogicalPlanNode { name, if_exists, schema, - })) => Ok(protobuf::LogicalPlanNode { + })) => Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::DropView( protobuf::DropViewNode { name: Some(name.clone().into()), @@ -1577,21 +1606,22 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlan::Copy(dml::CopyTo { input, output_url, - format_options, + file_type, partition_by, .. }) => { - let input = protobuf::LogicalPlanNode::try_from_logical_plan( - input, - extension_codec, - )?; + let input = + LogicalPlanNode::try_from_logical_plan(input, extension_codec)?; + let mut buf = Vec::new(); + extension_codec + .try_encode_file_format(&mut buf, file_type_to_format(file_type)?)?; - Ok(protobuf::LogicalPlanNode { + Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::CopyTo(Box::new( protobuf::CopyToNode { input: Some(Box::new(input)), output_url: output_url.to_string(), - format_options: Some(format_options.try_into()?), + file_type: buf, partition_by: partition_by.clone(), }, ))), @@ -1603,47 +1633,9 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlan::RecursiveQuery(_) => Err(proto_error( "LogicalPlan serde is not yet implemented for RecursiveQuery", )), + LogicalPlan::Execute(_) => Err(proto_error( + "LogicalPlan serde is not yet implemented for Execute", + )), } } } - -pub(crate) fn csv_writer_options_to_proto( - csv_options: &WriterBuilder, - compression: &CompressionTypeVariant, -) -> protobuf::CsvWriterOptions { - let compression: protobuf::CompressionTypeVariant = compression.into(); - protobuf::CsvWriterOptions { - compression: compression.into(), - delimiter: (csv_options.delimiter() as char).to_string(), - has_header: csv_options.header(), - date_format: csv_options.date_format().unwrap_or("").to_owned(), - datetime_format: csv_options.datetime_format().unwrap_or("").to_owned(), - timestamp_format: csv_options.timestamp_format().unwrap_or("").to_owned(), - time_format: csv_options.time_format().unwrap_or("").to_owned(), - null_value: csv_options.null().to_owned(), - } -} - -pub(crate) fn csv_writer_options_from_proto( - writer_options: &protobuf::CsvWriterOptions, -) -> Result { - let mut builder = WriterBuilder::new(); - if !writer_options.delimiter.is_empty() { - if let Some(delimiter) = writer_options.delimiter.chars().next() { - if delimiter.is_ascii() { - builder = builder.with_delimiter(delimiter as u8); - } else { - return Err(proto_error("CSV Delimiter is not ASCII")); - } - } else { - return Err(proto_error("Error parsing CSV Delimiter")); - } - } - Ok(builder - .with_header(writer_options.has_header) - .with_date_format(writer_options.date_format.clone()) - .with_datetime_format(writer_options.datetime_format.clone()) - .with_timestamp_format(writer_options.timestamp_format.clone()) - .with_time_format(writer_options.time_format.clone()) - .with_null(writer_options.null_value.clone())) -} diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 45aebc88dc63..8af7b19d9091 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -19,284 +19,47 @@ //! DataFusion logical plans to be serialized and transmitted between //! processes. -use std::sync::Arc; - -use arrow::{ - array::ArrayRef, - datatypes::{ - DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, SchemaRef, - TimeUnit, UnionMode, - }, - ipc::writer::{DictionaryTracker, IpcDataGenerator}, - record_batch::RecordBatch, -}; - -use datafusion_common::{ - Column, Constraint, Constraints, DFSchema, DFSchemaRef, ScalarValue, TableReference, -}; +use datafusion_common::{TableReference, UnnestOptions}; use datafusion_expr::expr::{ - self, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Cast, GetFieldAccess, - GetIndexedField, GroupingSet, InList, Like, Placeholder, ScalarFunction, - ScalarFunctionDefinition, Sort, Unnest, + self, Alias, Between, BinaryExpr, Cast, GroupingSet, InList, Like, Placeholder, + ScalarFunction, Unnest, }; use datafusion_expr::{ - logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction, - BuiltInWindowFunction, Expr, JoinConstraint, JoinType, TryCast, WindowFrame, - WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, + logical_plan::PlanType, logical_plan::StringifiedPlan, BuiltInWindowFunction, Expr, + JoinConstraint, JoinType, SortExpr, TryCast, WindowFrame, WindowFrameBound, + WindowFrameUnits, WindowFunctionDefinition, }; +use crate::protobuf::RecursionUnnestOption; use crate::protobuf::{ self, - arrow_type::ArrowTypeEnum, plan_type::PlanTypeEnum::{ AnalyzedLogicalPlan, FinalAnalyzedLogicalPlan, FinalLogicalPlan, - FinalPhysicalPlan, FinalPhysicalPlanWithStats, InitialLogicalPlan, - InitialPhysicalPlan, InitialPhysicalPlanWithStats, OptimizedLogicalPlan, - OptimizedPhysicalPlan, + FinalPhysicalPlan, FinalPhysicalPlanWithSchema, FinalPhysicalPlanWithStats, + InitialLogicalPlan, InitialPhysicalPlan, InitialPhysicalPlanWithSchema, + InitialPhysicalPlanWithStats, OptimizedLogicalPlan, OptimizedPhysicalPlan, }, AnalyzedLogicalPlanType, CubeNode, EmptyMessage, GroupingSetNode, LogicalExprList, OptimizedLogicalPlanType, OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, - UnionField, UnionValue, + ToProtoError as Error, }; use super::LogicalExtensionCodec; -#[derive(Debug)] -pub enum Error { - General(String), - - InvalidScalarValue(ScalarValue), - - InvalidScalarType(DataType), - - InvalidTimeUnit(TimeUnit), - - NotImplemented(String), -} - -impl std::error::Error for Error {} - -impl std::fmt::Display for Error { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - Self::General(desc) => write!(f, "General error: {desc}"), - Self::InvalidScalarValue(value) => { - write!(f, "{value:?} is invalid as a DataFusion scalar value") - } - Self::InvalidScalarType(data_type) => { - write!(f, "{data_type:?} is invalid as a DataFusion scalar type") - } - Self::InvalidTimeUnit(time_unit) => { - write!( - f, - "Only TimeUnit::Microsecond and TimeUnit::Nanosecond are valid time units, found: {time_unit:?}" - ) - } - Self::NotImplemented(s) => { - write!(f, "Not implemented: {s}") - } - } - } -} - -impl TryFrom<&Field> for protobuf::Field { - type Error = Error; - - fn try_from(field: &Field) -> Result { - let arrow_type = field.data_type().try_into()?; - Ok(Self { - name: field.name().to_owned(), - arrow_type: Some(Box::new(arrow_type)), - nullable: field.is_nullable(), - children: Vec::new(), - metadata: field.metadata().clone(), - dict_id: field.dict_id().unwrap_or(0), - dict_ordered: field.dict_is_ordered().unwrap_or(false), - }) - } -} - -impl TryFrom<&DataType> for protobuf::ArrowType { - type Error = Error; - - fn try_from(val: &DataType) -> Result { - let arrow_type_enum: ArrowTypeEnum = val.try_into()?; - Ok(Self { - arrow_type_enum: Some(arrow_type_enum), - }) - } -} - -impl TryFrom<&DataType> for protobuf::arrow_type::ArrowTypeEnum { - type Error = Error; - - fn try_from(val: &DataType) -> Result { - let res = match val { - DataType::Null => Self::None(EmptyMessage {}), - DataType::Boolean => Self::Bool(EmptyMessage {}), - DataType::Int8 => Self::Int8(EmptyMessage {}), - DataType::Int16 => Self::Int16(EmptyMessage {}), - DataType::Int32 => Self::Int32(EmptyMessage {}), - DataType::Int64 => Self::Int64(EmptyMessage {}), - DataType::UInt8 => Self::Uint8(EmptyMessage {}), - DataType::UInt16 => Self::Uint16(EmptyMessage {}), - DataType::UInt32 => Self::Uint32(EmptyMessage {}), - DataType::UInt64 => Self::Uint64(EmptyMessage {}), - DataType::Float16 => Self::Float16(EmptyMessage {}), - DataType::Float32 => Self::Float32(EmptyMessage {}), - DataType::Float64 => Self::Float64(EmptyMessage {}), - DataType::Timestamp(time_unit, timezone) => { - Self::Timestamp(protobuf::Timestamp { - time_unit: protobuf::TimeUnit::from(time_unit) as i32, - timezone: timezone.as_deref().unwrap_or("").to_string(), - }) - } - DataType::Date32 => Self::Date32(EmptyMessage {}), - DataType::Date64 => Self::Date64(EmptyMessage {}), - DataType::Time32(time_unit) => { - Self::Time32(protobuf::TimeUnit::from(time_unit) as i32) - } - DataType::Time64(time_unit) => { - Self::Time64(protobuf::TimeUnit::from(time_unit) as i32) - } - DataType::Duration(time_unit) => { - Self::Duration(protobuf::TimeUnit::from(time_unit) as i32) - } - DataType::Interval(interval_unit) => { - Self::Interval(protobuf::IntervalUnit::from(interval_unit) as i32) - } - DataType::Binary => Self::Binary(EmptyMessage {}), - DataType::FixedSizeBinary(size) => Self::FixedSizeBinary(*size), - DataType::LargeBinary => Self::LargeBinary(EmptyMessage {}), - DataType::Utf8 => Self::Utf8(EmptyMessage {}), - DataType::LargeUtf8 => Self::LargeUtf8(EmptyMessage {}), - DataType::List(item_type) => Self::List(Box::new(protobuf::List { - field_type: Some(Box::new(item_type.as_ref().try_into()?)), - })), - DataType::FixedSizeList(item_type, size) => { - Self::FixedSizeList(Box::new(protobuf::FixedSizeList { - field_type: Some(Box::new(item_type.as_ref().try_into()?)), - list_size: *size, - })) - } - DataType::LargeList(item_type) => Self::LargeList(Box::new(protobuf::List { - field_type: Some(Box::new(item_type.as_ref().try_into()?)), - })), - DataType::Struct(struct_fields) => Self::Struct(protobuf::Struct { - sub_field_types: convert_arc_fields_to_proto_fields(struct_fields)?, - }), - DataType::Union(fields, union_mode) => { - let union_mode = match union_mode { - UnionMode::Sparse => protobuf::UnionMode::Sparse, - UnionMode::Dense => protobuf::UnionMode::Dense, - }; - Self::Union(protobuf::Union { - union_types: convert_arc_fields_to_proto_fields(fields.iter().map(|(_, item)|item))?, - union_mode: union_mode.into(), - type_ids: fields.iter().map(|(x, _)| x as i32).collect(), - }) - } - DataType::Dictionary(key_type, value_type) => { - Self::Dictionary(Box::new(protobuf::Dictionary { - key: Some(Box::new(key_type.as_ref().try_into()?)), - value: Some(Box::new(value_type.as_ref().try_into()?)), - })) - } - DataType::Decimal128(precision, scale) => Self::Decimal(protobuf::Decimal { - precision: *precision as u32, - scale: *scale as i32, - }), - DataType::Decimal256(_, _) => { - return Err(Error::General("Proto serialization error: The Decimal256 data type is not yet supported".to_owned())) - } - DataType::Map(field, sorted) => { - Self::Map(Box::new( - protobuf::Map { - field_type: Some(Box::new(field.as_ref().try_into()?)), - keys_sorted: *sorted, - } - )) - } - DataType::RunEndEncoded(_, _) => { - return Err(Error::General( - "Proto serialization error: The RunEndEncoded data type is not yet supported".to_owned() - )) - } - DataType::Utf8View | DataType::BinaryView | DataType::ListView(_) | DataType::LargeListView(_) => { - return Err(Error::General(format!("Proto serialization error: {val} not yet supported"))) - } - }; - - Ok(res) - } -} - -impl From for protobuf::Column { - fn from(c: Column) -> Self { +impl From<&UnnestOptions> for protobuf::UnnestOptions { + fn from(opts: &UnnestOptions) -> Self { Self { - relation: c.relation.map(|relation| protobuf::ColumnRelation { - relation: relation.to_string(), - }), - name: c.name, - } - } -} - -impl From<&Column> for protobuf::Column { - fn from(c: &Column) -> Self { - c.clone().into() - } -} - -impl TryFrom<&Schema> for protobuf::Schema { - type Error = Error; - - fn try_from(schema: &Schema) -> Result { - Ok(Self { - columns: convert_arc_fields_to_proto_fields(schema.fields())?, - metadata: schema.metadata.clone(), - }) - } -} - -impl TryFrom for protobuf::Schema { - type Error = Error; - - fn try_from(schema: SchemaRef) -> Result { - Ok(Self { - columns: convert_arc_fields_to_proto_fields(schema.fields())?, - metadata: schema.metadata.clone(), - }) - } -} - -impl TryFrom<&DFSchema> for protobuf::DfSchema { - type Error = Error; - - fn try_from(s: &DFSchema) -> Result { - let columns = s - .iter() - .map(|(qualifier, field)| { - Ok(protobuf::DfField { - field: Some(field.as_ref().try_into()?), - qualifier: qualifier.map(|r| protobuf::ColumnRelation { - relation: r.to_string(), - }), + preserve_nulls: opts.preserve_nulls, + recursions: opts + .recursions + .iter() + .map(|r| RecursionUnnestOption { + input_column: Some((&r.input_column).into()), + output_column: Some((&r.output_column).into()), + depth: r.depth as u32, }) - }) - .collect::, Error>>()?; - Ok(Self { - columns, - metadata: s.metadata().clone(), - }) - } -} - -impl TryFrom<&DFSchemaRef> for protobuf::DfSchema { - type Error = Error; - - fn try_from(s: &DFSchemaRef) -> Result { - s.as_ref().try_into() + .collect(), + } } } @@ -343,75 +106,27 @@ impl From<&StringifiedPlan> for protobuf::StringifiedPlan { PlanType::InitialPhysicalPlanWithStats => Some(protobuf::PlanType { plan_type_enum: Some(InitialPhysicalPlanWithStats(EmptyMessage {})), }), + PlanType::InitialPhysicalPlanWithSchema => Some(protobuf::PlanType { + plan_type_enum: Some(InitialPhysicalPlanWithSchema(EmptyMessage {})), + }), PlanType::FinalPhysicalPlanWithStats => Some(protobuf::PlanType { plan_type_enum: Some(FinalPhysicalPlanWithStats(EmptyMessage {})), }), + PlanType::FinalPhysicalPlanWithSchema => Some(protobuf::PlanType { + plan_type_enum: Some(FinalPhysicalPlanWithSchema(EmptyMessage {})), + }), }, plan: stringified_plan.plan.to_string(), } } } -impl From<&AggregateFunction> for protobuf::AggregateFunction { - fn from(value: &AggregateFunction) -> Self { - match value { - AggregateFunction::Min => Self::Min, - AggregateFunction::Max => Self::Max, - AggregateFunction::Sum => Self::Sum, - AggregateFunction::Avg => Self::Avg, - AggregateFunction::BitAnd => Self::BitAnd, - AggregateFunction::BitOr => Self::BitOr, - AggregateFunction::BitXor => Self::BitXor, - AggregateFunction::BoolAnd => Self::BoolAnd, - AggregateFunction::BoolOr => Self::BoolOr, - AggregateFunction::Count => Self::Count, - AggregateFunction::ApproxDistinct => Self::ApproxDistinct, - AggregateFunction::ArrayAgg => Self::ArrayAgg, - AggregateFunction::Variance => Self::Variance, - AggregateFunction::VariancePop => Self::VariancePop, - AggregateFunction::Covariance => Self::Covariance, - AggregateFunction::CovariancePop => Self::CovariancePop, - AggregateFunction::Stddev => Self::Stddev, - AggregateFunction::StddevPop => Self::StddevPop, - AggregateFunction::Correlation => Self::Correlation, - AggregateFunction::RegrSlope => Self::RegrSlope, - AggregateFunction::RegrIntercept => Self::RegrIntercept, - AggregateFunction::RegrCount => Self::RegrCount, - AggregateFunction::RegrR2 => Self::RegrR2, - AggregateFunction::RegrAvgx => Self::RegrAvgx, - AggregateFunction::RegrAvgy => Self::RegrAvgy, - AggregateFunction::RegrSXX => Self::RegrSxx, - AggregateFunction::RegrSYY => Self::RegrSyy, - AggregateFunction::RegrSXY => Self::RegrSxy, - AggregateFunction::ApproxPercentileCont => Self::ApproxPercentileCont, - AggregateFunction::ApproxPercentileContWithWeight => { - Self::ApproxPercentileContWithWeight - } - AggregateFunction::ApproxMedian => Self::ApproxMedian, - AggregateFunction::Grouping => Self::Grouping, - AggregateFunction::Median => Self::Median, - AggregateFunction::FirstValue => Self::FirstValueAgg, - AggregateFunction::LastValue => Self::LastValueAgg, - AggregateFunction::NthValue => Self::NthValueAgg, - AggregateFunction::StringAgg => Self::StringAgg, - } - } -} - impl From<&BuiltInWindowFunction> for protobuf::BuiltInWindowFunction { fn from(value: &BuiltInWindowFunction) -> Self { match value { BuiltInWindowFunction::FirstValue => Self::FirstValue, BuiltInWindowFunction::LastValue => Self::LastValue, BuiltInWindowFunction::NthValue => Self::NthValue, - BuiltInWindowFunction::Ntile => Self::Ntile, - BuiltInWindowFunction::CumeDist => Self::CumeDist, - BuiltInWindowFunction::PercentRank => Self::PercentRank, - BuiltInWindowFunction::RowNumber => Self::RowNumber, - BuiltInWindowFunction::Rank => Self::Rank, - BuiltInWindowFunction::Lag => Self::Lag, - BuiltInWindowFunction::Lead => Self::Lead, - BuiltInWindowFunction::DenseRank => Self::DenseRank, } } } @@ -596,202 +311,95 @@ pub fn serialize_expr( // TODO: support null treatment in proto null_treatment: _, }) => { - let window_function = match fun { - WindowFunctionDefinition::AggregateFunction(fun) => { - protobuf::window_expr_node::WindowFunction::AggrFunction( - protobuf::AggregateFunction::from(fun).into(), - ) - } - WindowFunctionDefinition::BuiltInWindowFunction(fun) => { + let (window_function, fun_definition) = match fun { + WindowFunctionDefinition::BuiltInWindowFunction(fun) => ( protobuf::window_expr_node::WindowFunction::BuiltInFunction( protobuf::BuiltInWindowFunction::from(fun).into(), - ) - } + ), + None, + ), WindowFunctionDefinition::AggregateUDF(aggr_udf) => { - protobuf::window_expr_node::WindowFunction::Udaf( - aggr_udf.name().to_string(), + let mut buf = Vec::new(); + let _ = codec.try_encode_udaf(aggr_udf, &mut buf); + ( + protobuf::window_expr_node::WindowFunction::Udaf( + aggr_udf.name().to_string(), + ), + (!buf.is_empty()).then_some(buf), ) } WindowFunctionDefinition::WindowUDF(window_udf) => { - protobuf::window_expr_node::WindowFunction::Udwf( - window_udf.name().to_string(), + let mut buf = Vec::new(); + let _ = codec.try_encode_udwf(window_udf, &mut buf); + ( + protobuf::window_expr_node::WindowFunction::Udwf( + window_udf.name().to_string(), + ), + (!buf.is_empty()).then_some(buf), ) } }; - let arg_expr: Option> = if !args.is_empty() { - let arg = &args[0]; - Some(Box::new(serialize_expr(arg, codec)?)) - } else { - None - }; let partition_by = serialize_exprs(partition_by, codec)?; - let order_by = serialize_exprs(order_by, codec)?; + let order_by = serialize_sorts(order_by, codec)?; let window_frame: Option = Some(window_frame.try_into()?); - let window_expr = Box::new(protobuf::WindowExprNode { - expr: arg_expr, + let window_expr = protobuf::WindowExprNode { + exprs: serialize_exprs(args, codec)?, window_function: Some(window_function), partition_by, order_by, window_frame, - }); + fun_definition, + }; protobuf::LogicalExprNode { expr_type: Some(ExprType::WindowExpr(window_expr)), } } Expr::AggregateFunction(expr::AggregateFunction { - ref func_def, + ref func, ref args, ref distinct, ref filter, ref order_by, null_treatment: _, - }) => match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => { - let aggr_function = match fun { - AggregateFunction::ApproxDistinct => { - protobuf::AggregateFunction::ApproxDistinct - } - AggregateFunction::ApproxPercentileCont => { - protobuf::AggregateFunction::ApproxPercentileCont - } - AggregateFunction::ApproxPercentileContWithWeight => { - protobuf::AggregateFunction::ApproxPercentileContWithWeight - } - AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, - AggregateFunction::Min => protobuf::AggregateFunction::Min, - AggregateFunction::Max => protobuf::AggregateFunction::Max, - AggregateFunction::Sum => protobuf::AggregateFunction::Sum, - AggregateFunction::BitAnd => protobuf::AggregateFunction::BitAnd, - AggregateFunction::BitOr => protobuf::AggregateFunction::BitOr, - AggregateFunction::BitXor => protobuf::AggregateFunction::BitXor, - AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd, - AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr, - AggregateFunction::Avg => protobuf::AggregateFunction::Avg, - AggregateFunction::Count => protobuf::AggregateFunction::Count, - AggregateFunction::Variance => protobuf::AggregateFunction::Variance, - AggregateFunction::VariancePop => { - protobuf::AggregateFunction::VariancePop - } - AggregateFunction::Covariance => { - protobuf::AggregateFunction::Covariance - } - AggregateFunction::CovariancePop => { - protobuf::AggregateFunction::CovariancePop - } - AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev, - AggregateFunction::StddevPop => { - protobuf::AggregateFunction::StddevPop - } - AggregateFunction::Correlation => { - protobuf::AggregateFunction::Correlation - } - AggregateFunction::RegrSlope => { - protobuf::AggregateFunction::RegrSlope - } - AggregateFunction::RegrIntercept => { - protobuf::AggregateFunction::RegrIntercept - } - AggregateFunction::RegrR2 => protobuf::AggregateFunction::RegrR2, - AggregateFunction::RegrAvgx => protobuf::AggregateFunction::RegrAvgx, - AggregateFunction::RegrAvgy => protobuf::AggregateFunction::RegrAvgy, - AggregateFunction::RegrCount => { - protobuf::AggregateFunction::RegrCount - } - AggregateFunction::RegrSXX => protobuf::AggregateFunction::RegrSxx, - AggregateFunction::RegrSYY => protobuf::AggregateFunction::RegrSyy, - AggregateFunction::RegrSXY => protobuf::AggregateFunction::RegrSxy, - AggregateFunction::ApproxMedian => { - protobuf::AggregateFunction::ApproxMedian - } - AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, - AggregateFunction::Median => protobuf::AggregateFunction::Median, - AggregateFunction::FirstValue => { - protobuf::AggregateFunction::FirstValueAgg - } - AggregateFunction::LastValue => { - protobuf::AggregateFunction::LastValueAgg - } - AggregateFunction::NthValue => { - protobuf::AggregateFunction::NthValueAgg - } - AggregateFunction::StringAgg => { - protobuf::AggregateFunction::StringAgg - } - }; - - let aggregate_expr = protobuf::AggregateExprNode { - aggr_function: aggr_function.into(), - expr: serialize_exprs(args, codec)?, - distinct: *distinct, - filter: match filter { - Some(e) => Some(Box::new(serialize_expr(e, codec)?)), - None => None, - }, - order_by: match order_by { - Some(e) => serialize_exprs(e, codec)?, - None => vec![], - }, - }; - protobuf::LogicalExprNode { - expr_type: Some(ExprType::AggregateExpr(Box::new(aggregate_expr))), - } - } - AggregateFunctionDefinition::UDF(fun) => protobuf::LogicalExprNode { + }) => { + let mut buf = Vec::new(); + let _ = codec.try_encode_udaf(func, &mut buf); + protobuf::LogicalExprNode { expr_type: Some(ExprType::AggregateUdfExpr(Box::new( protobuf::AggregateUdfExprNode { - fun_name: fun.name().to_string(), + fun_name: func.name().to_string(), args: serialize_exprs(args, codec)?, + distinct: *distinct, filter: match filter { Some(e) => Some(Box::new(serialize_expr(e.as_ref(), codec)?)), None => None, }, order_by: match order_by { - Some(e) => serialize_exprs(e, codec)?, + Some(e) => serialize_sorts(e, codec)?, None => vec![], }, + fun_definition: (!buf.is_empty()).then_some(buf), }, ))), - }, - AggregateFunctionDefinition::Name(_) => { - return Err(Error::NotImplemented( - "Proto serialization error: Trying to serialize a unresolved function" - .to_string(), - )); } - }, + } Expr::ScalarVariable(_, _) => { return Err(Error::General( "Proto serialization error: Scalar Variable not supported".to_string(), )) } - Expr::ScalarFunction(ScalarFunction { func_def, args }) => { - let args = serialize_exprs(args, codec)?; - match func_def { - ScalarFunctionDefinition::UDF(fun) => { - let mut buf = Vec::new(); - let _ = codec.try_encode_udf(fun.as_ref(), &mut buf); - - let fun_definition = if buf.is_empty() { None } else { Some(buf) }; - - protobuf::LogicalExprNode { - expr_type: Some(ExprType::ScalarUdfExpr( - protobuf::ScalarUdfExprNode { - fun_name: fun.name().to_string(), - fun_definition, - args, - }, - )), - } - } - ScalarFunctionDefinition::Name(_) => { - return Err(Error::NotImplemented( - "Proto serialization error: Trying to serialize a unresolved function" - .to_string(), - )); - } + Expr::ScalarFunction(ScalarFunction { func, args }) => { + let mut buf = Vec::new(); + let _ = codec.try_encode_udf(func, &mut buf); + protobuf::LogicalExprNode { + expr_type: Some(ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { + fun_name: func.name().to_string(), + fun_definition: (!buf.is_empty()).then_some(buf), + args: serialize_exprs(args, codec)?, + })), } } Expr::Not(expr) => { @@ -926,20 +534,6 @@ pub fn serialize_expr( expr_type: Some(ExprType::TryCast(expr)), } } - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => { - let expr = Box::new(protobuf::SortExprNode { - expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), - asc: *asc, - nulls_first: *nulls_first, - }); - protobuf::LogicalExprNode { - expr_type: Some(ExprType::Sort(expr)), - } - } Expr::Negative(expr) => { let expr = Box::new(protobuf::NegativeNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), @@ -970,9 +564,9 @@ pub fn serialize_expr( expr_type: Some(ExprType::InList(expr)), } } - Expr::Wildcard { qualifier } => protobuf::LogicalExprNode { + Expr::Wildcard { qualifier, .. } => protobuf::LogicalExprNode { expr_type: Some(ExprType::Wildcard(protobuf::Wildcard { - qualifier: qualifier.clone().unwrap_or("".to_string()), + qualifier: qualifier.to_owned().map(|x| x.into()), })), }, Expr::ScalarSubquery(_) @@ -983,45 +577,6 @@ pub fn serialize_expr( // see discussion in https://github.com/apache/datafusion/issues/2565 return Err(Error::General("Proto serialization error: Expr::ScalarSubquery(_) | Expr::InSubquery(_) | Expr::Exists { .. } | Exp:OuterReferenceColumn not supported".to_string())); } - Expr::GetIndexedField(GetIndexedField { expr, field }) => { - let field = match field { - GetFieldAccess::NamedStructField { name } => { - protobuf::get_indexed_field::Field::NamedStructField( - protobuf::NamedStructField { - name: Some(name.try_into()?), - }, - ) - } - GetFieldAccess::ListIndex { key } => { - protobuf::get_indexed_field::Field::ListIndex(Box::new( - protobuf::ListIndex { - key: Some(Box::new(serialize_expr(key.as_ref(), codec)?)), - }, - )) - } - GetFieldAccess::ListRange { - start, - stop, - stride, - } => protobuf::get_indexed_field::Field::ListRange(Box::new( - protobuf::ListRange { - start: Some(Box::new(serialize_expr(start.as_ref(), codec)?)), - stop: Some(Box::new(serialize_expr(stop.as_ref(), codec)?)), - stride: Some(Box::new(serialize_expr(stride.as_ref(), codec)?)), - }, - )), - }; - - protobuf::LogicalExprNode { - expr_type: Some(ExprType::GetIndexedField(Box::new( - protobuf::GetIndexedField { - expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), - field: Some(field), - }, - ))), - } - } - Expr::GroupingSet(GroupingSet::Cube(exprs)) => protobuf::LogicalExprNode { expr_type: Some(ExprType::Cube(CubeNode { expr: serialize_exprs(exprs, codec)?, @@ -1063,349 +618,28 @@ pub fn serialize_expr( Ok(expr_node) } -impl TryFrom<&ScalarValue> for protobuf::ScalarValue { - type Error = Error; - - fn try_from(val: &ScalarValue) -> Result { - use protobuf::scalar_value::Value; - - let data_type = val.data_type(); - match val { - ScalarValue::Boolean(val) => { - create_proto_scalar(val.as_ref(), &data_type, |s| Value::BoolValue(*s)) - } - ScalarValue::Float32(val) => { - create_proto_scalar(val.as_ref(), &data_type, |s| Value::Float32Value(*s)) - } - ScalarValue::Float64(val) => { - create_proto_scalar(val.as_ref(), &data_type, |s| Value::Float64Value(*s)) - } - ScalarValue::Int8(val) => { - create_proto_scalar(val.as_ref(), &data_type, |s| { - Value::Int8Value(*s as i32) - }) - } - ScalarValue::Int16(val) => { - create_proto_scalar(val.as_ref(), &data_type, |s| { - Value::Int16Value(*s as i32) - }) - } - ScalarValue::Int32(val) => { - create_proto_scalar(val.as_ref(), &data_type, |s| Value::Int32Value(*s)) - } - ScalarValue::Int64(val) => { - create_proto_scalar(val.as_ref(), &data_type, |s| Value::Int64Value(*s)) - } - ScalarValue::UInt8(val) => { - create_proto_scalar(val.as_ref(), &data_type, |s| { - Value::Uint8Value(*s as u32) - }) - } - ScalarValue::UInt16(val) => { - create_proto_scalar(val.as_ref(), &data_type, |s| { - Value::Uint16Value(*s as u32) - }) - } - ScalarValue::UInt32(val) => { - create_proto_scalar(val.as_ref(), &data_type, |s| Value::Uint32Value(*s)) - } - ScalarValue::UInt64(val) => { - create_proto_scalar(val.as_ref(), &data_type, |s| Value::Uint64Value(*s)) - } - ScalarValue::Utf8(val) => { - create_proto_scalar(val.as_ref(), &data_type, |s| { - Value::Utf8Value(s.to_owned()) - }) - } - ScalarValue::LargeUtf8(val) => { - create_proto_scalar(val.as_ref(), &data_type, |s| { - Value::LargeUtf8Value(s.to_owned()) - }) - } - ScalarValue::List(arr) => { - encode_scalar_nested_value(arr.to_owned() as ArrayRef, val) - } - ScalarValue::LargeList(arr) => { - encode_scalar_nested_value(arr.to_owned() as ArrayRef, val) - } - ScalarValue::FixedSizeList(arr) => { - encode_scalar_nested_value(arr.to_owned() as ArrayRef, val) - } - ScalarValue::Struct(arr) => { - encode_scalar_nested_value(arr.to_owned() as ArrayRef, val) - } - ScalarValue::Date32(val) => { - create_proto_scalar(val.as_ref(), &data_type, |s| Value::Date32Value(*s)) - } - ScalarValue::TimestampMicrosecond(val, tz) => { - create_proto_scalar(val.as_ref(), &data_type, |s| { - Value::TimestampValue(protobuf::ScalarTimestampValue { - timezone: tz.as_deref().unwrap_or("").to_string(), - value: Some( - protobuf::scalar_timestamp_value::Value::TimeMicrosecondValue( - *s, - ), - ), - }) - }) - } - ScalarValue::TimestampNanosecond(val, tz) => { - create_proto_scalar(val.as_ref(), &data_type, |s| { - Value::TimestampValue(protobuf::ScalarTimestampValue { - timezone: tz.as_deref().unwrap_or("").to_string(), - value: Some( - protobuf::scalar_timestamp_value::Value::TimeNanosecondValue( - *s, - ), - ), - }) - }) - } - ScalarValue::Decimal128(val, p, s) => match *val { - Some(v) => { - let array = v.to_be_bytes(); - let vec_val: Vec = array.to_vec(); - Ok(protobuf::ScalarValue { - value: Some(Value::Decimal128Value(protobuf::Decimal128 { - value: vec_val, - p: *p as i64, - s: *s as i64, - })), - }) - } - None => Ok(protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::NullValue( - (&data_type).try_into()?, - )), - }), - }, - ScalarValue::Decimal256(val, p, s) => match *val { - Some(v) => { - let array = v.to_be_bytes(); - let vec_val: Vec = array.to_vec(); - Ok(protobuf::ScalarValue { - value: Some(Value::Decimal256Value(protobuf::Decimal256 { - value: vec_val, - p: *p as i64, - s: *s as i64, - })), - }) - } - None => Ok(protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::NullValue( - (&data_type).try_into()?, - )), - }), - }, - ScalarValue::Date64(val) => { - create_proto_scalar(val.as_ref(), &data_type, |s| Value::Date64Value(*s)) - } - ScalarValue::TimestampSecond(val, tz) => { - create_proto_scalar(val.as_ref(), &data_type, |s| { - Value::TimestampValue(protobuf::ScalarTimestampValue { - timezone: tz.as_deref().unwrap_or("").to_string(), - value: Some( - protobuf::scalar_timestamp_value::Value::TimeSecondValue(*s), - ), - }) - }) - } - ScalarValue::TimestampMillisecond(val, tz) => { - create_proto_scalar(val.as_ref(), &data_type, |s| { - Value::TimestampValue(protobuf::ScalarTimestampValue { - timezone: tz.as_deref().unwrap_or("").to_string(), - value: Some( - protobuf::scalar_timestamp_value::Value::TimeMillisecondValue( - *s, - ), - ), - }) - }) - } - ScalarValue::IntervalYearMonth(val) => { - create_proto_scalar(val.as_ref(), &data_type, |s| { - Value::IntervalYearmonthValue(*s) - }) - } - ScalarValue::IntervalDayTime(val) => { - create_proto_scalar(val.as_ref(), &data_type, |s| { - Value::IntervalDaytimeValue(*s) - }) - } - ScalarValue::Null => Ok(protobuf::ScalarValue { - value: Some(Value::NullValue((&data_type).try_into()?)), - }), - - ScalarValue::Binary(val) => { - create_proto_scalar(val.as_ref(), &data_type, |s| { - Value::BinaryValue(s.to_owned()) - }) - } - ScalarValue::LargeBinary(val) => { - create_proto_scalar(val.as_ref(), &data_type, |s| { - Value::LargeBinaryValue(s.to_owned()) - }) - } - ScalarValue::FixedSizeBinary(length, val) => { - create_proto_scalar(val.as_ref(), &data_type, |s| { - Value::FixedSizeBinaryValue(protobuf::ScalarFixedSizeBinary { - values: s.to_owned(), - length: *length, - }) - }) - } - - ScalarValue::Time32Second(v) => { - create_proto_scalar(v.as_ref(), &data_type, |v| { - Value::Time32Value(protobuf::ScalarTime32Value { - value: Some( - protobuf::scalar_time32_value::Value::Time32SecondValue(*v), - ), - }) - }) - } - - ScalarValue::Time32Millisecond(v) => { - create_proto_scalar(v.as_ref(), &data_type, |v| { - Value::Time32Value(protobuf::ScalarTime32Value { - value: Some( - protobuf::scalar_time32_value::Value::Time32MillisecondValue( - *v, - ), - ), - }) - }) - } - - ScalarValue::Time64Microsecond(v) => { - create_proto_scalar(v.as_ref(), &data_type, |v| { - Value::Time64Value(protobuf::ScalarTime64Value { - value: Some( - protobuf::scalar_time64_value::Value::Time64MicrosecondValue( - *v, - ), - ), - }) - }) - } - - ScalarValue::Time64Nanosecond(v) => { - create_proto_scalar(v.as_ref(), &data_type, |v| { - Value::Time64Value(protobuf::ScalarTime64Value { - value: Some( - protobuf::scalar_time64_value::Value::Time64NanosecondValue( - *v, - ), - ), - }) - }) - } - - ScalarValue::IntervalMonthDayNano(v) => { - let value = if let Some(v) = v { - let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(*v); - Value::IntervalMonthDayNano(protobuf::IntervalMonthDayNanoValue { - months, - days, - nanos, - }) - } else { - Value::NullValue((&data_type).try_into()?) - }; - - Ok(protobuf::ScalarValue { value: Some(value) }) - } - - ScalarValue::DurationSecond(v) => { - let value = match v { - Some(v) => Value::DurationSecondValue(*v), - None => Value::NullValue((&data_type).try_into()?), - }; - Ok(protobuf::ScalarValue { value: Some(value) }) - } - ScalarValue::DurationMillisecond(v) => { - let value = match v { - Some(v) => Value::DurationMillisecondValue(*v), - None => Value::NullValue((&data_type).try_into()?), - }; - Ok(protobuf::ScalarValue { value: Some(value) }) - } - ScalarValue::DurationMicrosecond(v) => { - let value = match v { - Some(v) => Value::DurationMicrosecondValue(*v), - None => Value::NullValue((&data_type).try_into()?), - }; - Ok(protobuf::ScalarValue { value: Some(value) }) - } - ScalarValue::DurationNanosecond(v) => { - let value = match v { - Some(v) => Value::DurationNanosecondValue(*v), - None => Value::NullValue((&data_type).try_into()?), - }; - Ok(protobuf::ScalarValue { value: Some(value) }) - } - - ScalarValue::Union(val, df_fields, mode) => { - let mut fields = Vec::::with_capacity(df_fields.len()); - for (id, field) in df_fields.iter() { - let field_id = id as i32; - let field = Some(field.as_ref().try_into()?); - let field = UnionField { field_id, field }; - fields.push(field); - } - let mode = match mode { - UnionMode::Sparse => 0, - UnionMode::Dense => 1, - }; - let value = match val { - None => None, - Some((_id, v)) => Some(Box::new(v.as_ref().try_into()?)), - }; - let val = UnionValue { - value_id: val.as_ref().map(|(id, _v)| *id as i32).unwrap_or(0), - value, - fields, - mode, - }; - let val = Value::UnionValue(Box::new(val)); - let val = protobuf::ScalarValue { value: Some(val) }; - Ok(val) - } - - ScalarValue::Dictionary(index_type, val) => { - let value: protobuf::ScalarValue = val.as_ref().try_into()?; - Ok(protobuf::ScalarValue { - value: Some(Value::DictionaryValue(Box::new( - protobuf::ScalarDictionaryValue { - index_type: Some(index_type.as_ref().try_into()?), - value: Some(Box::new(value)), - }, - ))), - }) - } - } - } -} - -impl From<&TimeUnit> for protobuf::TimeUnit { - fn from(val: &TimeUnit) -> Self { - match val { - TimeUnit::Second => protobuf::TimeUnit::Second, - TimeUnit::Millisecond => protobuf::TimeUnit::Millisecond, - TimeUnit::Microsecond => protobuf::TimeUnit::Microsecond, - TimeUnit::Nanosecond => protobuf::TimeUnit::Nanosecond, - } - } -} - -impl From<&IntervalUnit> for protobuf::IntervalUnit { - fn from(interval_unit: &IntervalUnit) -> Self { - match interval_unit { - IntervalUnit::YearMonth => protobuf::IntervalUnit::YearMonth, - IntervalUnit::DayTime => protobuf::IntervalUnit::DayTime, - IntervalUnit::MonthDayNano => protobuf::IntervalUnit::MonthDayNano, - } - } +pub fn serialize_sorts<'a, I>( + sorts: I, + codec: &dyn LogicalExtensionCodec, +) -> Result, Error> +where + I: IntoIterator, +{ + sorts + .into_iter() + .map(|sort| { + let SortExpr { + expr, + asc, + nulls_first, + } = sort; + Ok(protobuf::SortExprNode { + expr: Some(serialize_expr(expr, codec)?), + asc: *asc, + nulls_first: *nulls_first, + }) + }) + .collect::, Error>>() } impl From for protobuf::TableReference { @@ -1451,6 +685,7 @@ impl From for protobuf::JoinType { JoinType::RightSemi => protobuf::JoinType::Rightsemi, JoinType::LeftAnti => protobuf::JoinType::Leftanti, JoinType::RightAnti => protobuf::JoinType::Rightanti, + JoinType::LeftMark => protobuf::JoinType::Leftmark, } } } @@ -1463,112 +698,3 @@ impl From for protobuf::JoinConstraint { } } } - -impl From for protobuf::Constraints { - fn from(value: Constraints) -> Self { - let constraints = value.into_iter().map(|item| item.into()).collect(); - protobuf::Constraints { constraints } - } -} - -impl From for protobuf::Constraint { - fn from(value: Constraint) -> Self { - let res = match value { - Constraint::PrimaryKey(indices) => { - let indices = indices.into_iter().map(|item| item as u64).collect(); - protobuf::constraint::ConstraintMode::PrimaryKey( - protobuf::PrimaryKeyConstraint { indices }, - ) - } - Constraint::Unique(indices) => { - let indices = indices.into_iter().map(|item| item as u64).collect(); - protobuf::constraint::ConstraintMode::PrimaryKey( - protobuf::PrimaryKeyConstraint { indices }, - ) - } - }; - protobuf::Constraint { - constraint_mode: Some(res), - } - } -} - -/// Creates a scalar protobuf value from an optional value (T), and -/// encoding None as the appropriate datatype -fn create_proto_scalar protobuf::scalar_value::Value>( - v: Option<&I>, - null_arrow_type: &DataType, - constructor: T, -) -> Result { - let value = v - .map(constructor) - .unwrap_or(protobuf::scalar_value::Value::NullValue( - null_arrow_type.try_into()?, - )); - - Ok(protobuf::ScalarValue { value: Some(value) }) -} - -// ScalarValue::List / FixedSizeList / LargeList / Struct are serialized using -// Arrow IPC messages as a single column RecordBatch -fn encode_scalar_nested_value( - arr: ArrayRef, - val: &ScalarValue, -) -> Result { - let batch = RecordBatch::try_from_iter(vec![("field_name", arr)]).map_err(|e| { - Error::General(format!( - "Error creating temporary batch while encoding ScalarValue::List: {e}" - )) - })?; - - let gen = IpcDataGenerator {}; - let mut dict_tracker = DictionaryTracker::new(false); - let (_, encoded_message) = gen - .encoded_batch(&batch, &mut dict_tracker, &Default::default()) - .map_err(|e| { - Error::General(format!("Error encoding ScalarValue::List as IPC: {e}")) - })?; - - let schema: protobuf::Schema = batch.schema().try_into()?; - - let scalar_list_value = protobuf::ScalarNestedValue { - ipc_message: encoded_message.ipc_message, - arrow_data: encoded_message.arrow_data, - schema: Some(schema), - }; - - match val { - ScalarValue::List(_) => Ok(protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::ListValue(scalar_list_value)), - }), - ScalarValue::LargeList(_) => Ok(protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::LargeListValue( - scalar_list_value, - )), - }), - ScalarValue::FixedSizeList(_) => Ok(protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::FixedSizeListValue( - scalar_list_value, - )), - }), - ScalarValue::Struct(_) => Ok(protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::StructValue( - scalar_list_value, - )), - }), - _ => unreachable!(), - } -} - -/// Converts a vector of `Arc`s to `protobuf::Field`s -fn convert_arc_fields_to_proto_fields<'a, I>( - fields: I, -) -> Result, Error> -where - I: IntoIterator>, -{ - fields - .into_iter() - .map(|field| field.as_ref().try_into()) - .collect::, Error>>() -} diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 6184332ea581..316166042fc4 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -17,11 +17,11 @@ //! Serde code to convert from protocol buffers to Rust data structures. -use std::collections::HashMap; use std::sync::Arc; use arrow::compute::SortOptions; use chrono::{TimeZone, Utc}; +use datafusion_expr::dml::InsertOp; use object_store::path::Path; use object_store::ObjectMeta; @@ -35,31 +35,20 @@ use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{FileScanConfig, FileSinkConfig}; use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::WindowFunctionDefinition; -use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; +use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ in_list, BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, LikeExpr, Literal, NegativeExpr, NotExpr, TryCastExpr, }; -use datafusion::physical_plan::windows::create_window_expr; -use datafusion::physical_plan::{ - ColumnStatistics, Partitioning, PhysicalExpr, Statistics, WindowExpr, -}; -use datafusion_common::config::{ - ColumnOptions, CsvOptions, FormatOptions, JsonOptions, ParquetOptions, - TableParquetOptions, -}; -use datafusion_common::file_options::csv_writer::CsvWriterOptions; -use datafusion_common::file_options::json_writer::JsonWriterOptions; -use datafusion_common::parsers::CompressionTypeVariant; -use datafusion_common::stats::Precision; -use datafusion_common::{not_impl_err, DataFusionError, JoinSide, Result, ScalarValue}; -use datafusion_expr::ScalarFunctionDefinition; - -use crate::common::proto_error; +use datafusion::physical_plan::windows::{create_window_expr, schema_add_window_field}; +use datafusion::physical_plan::{Partitioning, PhysicalExpr, WindowExpr}; +use datafusion_common::{not_impl_err, DataFusionError, Result}; +use datafusion_proto_common::common::proto_error; + use crate::convert_required; -use crate::logical_plan::{self, csv_writer_options_from_proto}; +use crate::logical_plan::{self}; +use crate::protobuf; use crate::protobuf::physical_expr_node::ExprType; -use crate::protobuf::{self, copy_to_node}; use super::PhysicalExtensionCodec; @@ -110,13 +99,13 @@ pub fn parse_physical_sort_exprs( registry: &dyn FunctionRegistry, input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, -) -> Result> { +) -> Result { proto .iter() .map(|sort_expr| { parse_physical_sort_expr(sort_expr, registry, input_schema, codec) }) - .collect::>>() + .collect::>() } /// Parses a physical window expr from a protobuf. @@ -137,7 +126,6 @@ pub fn parse_physical_window_expr( ) -> Result> { let window_node_expr = parse_physical_exprs(&proto.args, registry, input_schema, codec)?; - let partition_by = parse_physical_exprs(&proto.partition_by, registry, input_schema, codec)?; @@ -156,14 +144,40 @@ pub fn parse_physical_window_expr( ) })?; + let fun = if let Some(window_func) = proto.window_function.as_ref() { + match window_func { + protobuf::physical_window_expr_node::WindowFunction::BuiltInFunction(n) => { + let f = protobuf::BuiltInWindowFunction::try_from(*n).map_err(|_| { + proto_error(format!( + "Received an unknown window builtin function: {n}" + )) + })?; + + WindowFunctionDefinition::BuiltInWindowFunction(f.into()) + } + protobuf::physical_window_expr_node::WindowFunction::UserDefinedAggrFunction(udaf_name) => { + WindowFunctionDefinition::AggregateUDF(match &proto.fun_definition { + Some(buf) => codec.try_decode_udaf(udaf_name, buf)?, + None => registry.udaf(udaf_name)? + }) + } + } + } else { + return Err(proto_error("Missing required field in protobuf")); + }; + + let name = proto.name.clone(); + // TODO: Remove extended_schema if functions are all UDAF + let extended_schema = + schema_add_window_field(&window_node_expr, input_schema, &fun, &name)?; create_window_expr( - &convert_required!(proto.window_function)?, - proto.name.clone(), + &fun, + name, &window_node_expr, &partition_by, - &order_by, + order_by.as_ref(), Arc::new(window_frame), - input_schema, + &extended_schema, false, ) } @@ -342,8 +356,7 @@ pub fn parse_physical_expr( Some(buf) => codec.try_decode_udf(&e.name, buf)?, None => registry.udf(e.name.as_str())?, }; - let signature = udf.signature(); - let scalar_fun_def = ScalarFunctionDefinition::UDF(udf.clone()); + let scalar_fun_def = Arc::clone(&udf); let args = parse_physical_exprs(&e.args, registry, input_schema, codec)?; @@ -352,8 +365,6 @@ pub fn parse_physical_expr( scalar_fun_def, args, convert_required!(e.return_type)?, - None, - signature.type_signature.supports_zero_argument(), )) } ExprType::LikeExpr(like_expr) => Arc::new(LikeExpr::new( @@ -374,6 +385,14 @@ pub fn parse_physical_expr( codec, )?, )), + ExprType::Extension(extension) => { + let inputs: Vec> = extension + .inputs + .iter() + .map(|e| parse_physical_expr(e, registry, input_schema, codec)) + .collect::>()?; + (codec.try_decode_expr(extension.expr.as_slice(), &inputs)?) as _ + } }; Ok(pexpr) @@ -393,37 +412,6 @@ fn parse_required_physical_expr( }) } -impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> - for WindowFunctionDefinition -{ - type Error = DataFusionError; - - fn try_from( - expr: &protobuf::physical_window_expr_node::WindowFunction, - ) -> Result { - match expr { - protobuf::physical_window_expr_node::WindowFunction::AggrFunction(n) => { - let f = protobuf::AggregateFunction::try_from(*n).map_err(|_| { - proto_error(format!( - "Received an unknown window aggregate function: {n}" - )) - })?; - - Ok(WindowFunctionDefinition::AggregateFunction(f.into())) - } - protobuf::physical_window_expr_node::WindowFunction::BuiltInFunction(n) => { - let f = protobuf::BuiltInWindowFunction::try_from(*n).map_err(|_| { - proto_error(format!( - "Received an unknown window builtin function: {n}" - )) - })?; - - Ok(WindowFunctionDefinition::BuiltInWindowFunction(f.into())) - } - } - } -} - pub fn parse_protobuf_hash_partitioning( partitioning: Option<&protobuf::PhysicalHashRepartition>, registry: &dyn FunctionRegistry, @@ -448,6 +436,38 @@ pub fn parse_protobuf_hash_partitioning( } } +pub fn parse_protobuf_partitioning( + partitioning: Option<&protobuf::Partitioning>, + registry: &dyn FunctionRegistry, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, +) -> Result> { + match partitioning { + Some(protobuf::Partitioning { partition_method }) => match partition_method { + Some(protobuf::partitioning::PartitionMethod::RoundRobin( + partition_count, + )) => Ok(Some(Partitioning::RoundRobinBatch( + *partition_count as usize, + ))), + Some(protobuf::partitioning::PartitionMethod::Hash(hash_repartition)) => { + parse_protobuf_hash_partitioning( + Some(hash_repartition), + registry, + input_schema, + codec, + ) + } + Some(protobuf::partitioning::PartitionMethod::Unknown(partition_count)) => { + Ok(Some(Partitioning::UnknownPartitioning( + *partition_count as usize, + ))) + } + None => Ok(None), + }, + None => Ok(None), + } +} + pub fn parse_protobuf_file_scan_config( proto: &protobuf::FileScanExecConf, registry: &dyn FunctionRegistry, @@ -537,6 +557,7 @@ impl TryFrom<&protobuf::PartitionedFile> for PartitionedFile { .map(|v| v.try_into()) .collect::, _>>()?, range: val.range.as_ref().map(|v| v.try_into()).transpose()?, + statistics: val.statistics.as_ref().map(|v| v.try_into()).transpose()?, extensions: None, }) } @@ -564,134 +585,6 @@ impl TryFrom<&protobuf::FileGroup> for Vec { } } -impl From<&protobuf::ColumnStats> for ColumnStatistics { - fn from(cs: &protobuf::ColumnStats) -> ColumnStatistics { - ColumnStatistics { - null_count: if let Some(nc) = &cs.null_count { - nc.clone().into() - } else { - Precision::Absent - }, - max_value: if let Some(max) = &cs.max_value { - max.clone().into() - } else { - Precision::Absent - }, - min_value: if let Some(min) = &cs.min_value { - min.clone().into() - } else { - Precision::Absent - }, - distinct_count: if let Some(dc) = &cs.distinct_count { - dc.clone().into() - } else { - Precision::Absent - }, - } - } -} - -impl From for Precision { - fn from(s: protobuf::Precision) -> Self { - let Ok(precision_type) = s.precision_info.try_into() else { - return Precision::Absent; - }; - match precision_type { - protobuf::PrecisionInfo::Exact => { - if let Some(val) = s.val { - if let Ok(ScalarValue::UInt64(Some(val))) = - ScalarValue::try_from(&val) - { - Precision::Exact(val as usize) - } else { - Precision::Absent - } - } else { - Precision::Absent - } - } - protobuf::PrecisionInfo::Inexact => { - if let Some(val) = s.val { - if let Ok(ScalarValue::UInt64(Some(val))) = - ScalarValue::try_from(&val) - { - Precision::Inexact(val as usize) - } else { - Precision::Absent - } - } else { - Precision::Absent - } - } - protobuf::PrecisionInfo::Absent => Precision::Absent, - } - } -} - -impl From for Precision { - fn from(s: protobuf::Precision) -> Self { - let Ok(precision_type) = s.precision_info.try_into() else { - return Precision::Absent; - }; - match precision_type { - protobuf::PrecisionInfo::Exact => { - if let Some(val) = s.val { - if let Ok(val) = ScalarValue::try_from(&val) { - Precision::Exact(val) - } else { - Precision::Absent - } - } else { - Precision::Absent - } - } - protobuf::PrecisionInfo::Inexact => { - if let Some(val) = s.val { - if let Ok(val) = ScalarValue::try_from(&val) { - Precision::Inexact(val) - } else { - Precision::Absent - } - } else { - Precision::Absent - } - } - protobuf::PrecisionInfo::Absent => Precision::Absent, - } - } -} - -impl From for JoinSide { - fn from(t: protobuf::JoinSide) -> Self { - match t { - protobuf::JoinSide::LeftSide => JoinSide::Left, - protobuf::JoinSide::RightSide => JoinSide::Right, - } - } -} - -impl TryFrom<&protobuf::Statistics> for Statistics { - type Error = DataFusionError; - - fn try_from(s: &protobuf::Statistics) -> Result { - // Keep it sync with Statistics::to_proto - Ok(Statistics { - num_rows: if let Some(nr) = &s.num_rows { - nr.clone().into() - } else { - Precision::Absent - }, - total_byte_size: if let Some(tbs) = &s.total_byte_size { - tbs.clone().into() - } else { - Precision::Absent - }, - // No column statistic (None) is encoded with empty array - column_statistics: s.column_stats.iter().map(|s| s.into()).collect(), - }) - } -} - impl TryFrom<&protobuf::JsonSink> for JsonSink { type Error = DataFusionError; @@ -748,258 +641,19 @@ impl TryFrom<&protobuf::FileSinkConfig> for FileSinkConfig { Ok((name.clone(), data_type)) }) .collect::>>()?; + let insert_op = match conf.insert_op() { + protobuf::InsertOp::Append => InsertOp::Append, + protobuf::InsertOp::Overwrite => InsertOp::Overwrite, + protobuf::InsertOp::Replace => InsertOp::Replace, + }; Ok(Self { object_store_url: ObjectStoreUrl::parse(&conf.object_store_url)?, file_groups, table_paths, output_schema: Arc::new(convert_required!(conf.output_schema)?), table_partition_cols, - overwrite: conf.overwrite, - }) - } -} - -impl From for CompressionTypeVariant { - fn from(value: protobuf::CompressionTypeVariant) -> Self { - match value { - protobuf::CompressionTypeVariant::Gzip => Self::GZIP, - protobuf::CompressionTypeVariant::Bzip2 => Self::BZIP2, - protobuf::CompressionTypeVariant::Xz => Self::XZ, - protobuf::CompressionTypeVariant::Zstd => Self::ZSTD, - protobuf::CompressionTypeVariant::Uncompressed => Self::UNCOMPRESSED, - } - } -} - -impl From for protobuf::CompressionTypeVariant { - fn from(value: CompressionTypeVariant) -> Self { - match value { - CompressionTypeVariant::GZIP => Self::Gzip, - CompressionTypeVariant::BZIP2 => Self::Bzip2, - CompressionTypeVariant::XZ => Self::Xz, - CompressionTypeVariant::ZSTD => Self::Zstd, - CompressionTypeVariant::UNCOMPRESSED => Self::Uncompressed, - } - } -} - -impl TryFrom<&protobuf::CsvWriterOptions> for CsvWriterOptions { - type Error = DataFusionError; - - fn try_from(opts: &protobuf::CsvWriterOptions) -> Result { - let write_options = csv_writer_options_from_proto(opts)?; - let compression: CompressionTypeVariant = opts.compression().into(); - Ok(CsvWriterOptions::new(write_options, compression)) - } -} - -impl TryFrom<&protobuf::JsonWriterOptions> for JsonWriterOptions { - type Error = DataFusionError; - - fn try_from(opts: &protobuf::JsonWriterOptions) -> Result { - let compression: CompressionTypeVariant = opts.compression().into(); - Ok(JsonWriterOptions::new(compression)) - } -} - -impl TryFrom<&protobuf::CsvOptions> for CsvOptions { - type Error = DataFusionError; - - fn try_from(proto_opts: &protobuf::CsvOptions) -> Result { - Ok(CsvOptions { - has_header: proto_opts.has_header, - delimiter: proto_opts.delimiter[0], - quote: proto_opts.quote[0], - escape: proto_opts.escape.first().copied(), - compression: proto_opts.compression().into(), - schema_infer_max_rec: proto_opts.schema_infer_max_rec as usize, - date_format: (!proto_opts.date_format.is_empty()) - .then(|| proto_opts.date_format.clone()), - datetime_format: (!proto_opts.datetime_format.is_empty()) - .then(|| proto_opts.datetime_format.clone()), - timestamp_format: (!proto_opts.timestamp_format.is_empty()) - .then(|| proto_opts.timestamp_format.clone()), - timestamp_tz_format: (!proto_opts.timestamp_tz_format.is_empty()) - .then(|| proto_opts.timestamp_tz_format.clone()), - time_format: (!proto_opts.time_format.is_empty()) - .then(|| proto_opts.time_format.clone()), - null_value: (!proto_opts.null_value.is_empty()) - .then(|| proto_opts.null_value.clone()), - }) - } -} - -impl TryFrom<&protobuf::ParquetOptions> for ParquetOptions { - type Error = DataFusionError; - - fn try_from(value: &protobuf::ParquetOptions) -> Result { - Ok(ParquetOptions { - enable_page_index: value.enable_page_index, - pruning: value.pruning, - skip_metadata: value.skip_metadata, - metadata_size_hint: value - .metadata_size_hint_opt.clone() - .map(|opt| match opt { - protobuf::parquet_options::MetadataSizeHintOpt::MetadataSizeHint(v) => Some(v as usize), - }) - .unwrap_or(None), - pushdown_filters: value.pushdown_filters, - reorder_filters: value.reorder_filters, - data_pagesize_limit: value.data_pagesize_limit as usize, - write_batch_size: value.write_batch_size as usize, - writer_version: value.writer_version.clone(), - compression: value.compression_opt.clone().map(|opt| match opt { - protobuf::parquet_options::CompressionOpt::Compression(v) => Some(v), - }).unwrap_or(None), - dictionary_enabled: value.dictionary_enabled_opt.as_ref().map(|protobuf::parquet_options::DictionaryEnabledOpt::DictionaryEnabled(v)| *v), - // Continuing from where we left off in the TryFrom implementation - dictionary_page_size_limit: value.dictionary_page_size_limit as usize, - statistics_enabled: value - .statistics_enabled_opt.clone() - .map(|opt| match opt { - protobuf::parquet_options::StatisticsEnabledOpt::StatisticsEnabled(v) => Some(v), - }) - .unwrap_or(None), - max_statistics_size: value - .max_statistics_size_opt.as_ref() - .map(|opt| match opt { - protobuf::parquet_options::MaxStatisticsSizeOpt::MaxStatisticsSize(v) => Some(*v as usize), - }) - .unwrap_or(None), - max_row_group_size: value.max_row_group_size as usize, - created_by: value.created_by.clone(), - column_index_truncate_length: value - .column_index_truncate_length_opt.as_ref() - .map(|opt| match opt { - protobuf::parquet_options::ColumnIndexTruncateLengthOpt::ColumnIndexTruncateLength(v) => Some(*v as usize), - }) - .unwrap_or(None), - data_page_row_count_limit: value.data_page_row_count_limit as usize, - encoding: value - .encoding_opt.clone() - .map(|opt| match opt { - protobuf::parquet_options::EncodingOpt::Encoding(v) => Some(v), - }) - .unwrap_or(None), - bloom_filter_enabled: value.bloom_filter_enabled, - bloom_filter_fpp: value.clone() - .bloom_filter_fpp_opt - .map(|opt| match opt { - protobuf::parquet_options::BloomFilterFppOpt::BloomFilterFpp(v) => Some(v), - }) - .unwrap_or(None), - bloom_filter_ndv: value.clone() - .bloom_filter_ndv_opt - .map(|opt| match opt { - protobuf::parquet_options::BloomFilterNdvOpt::BloomFilterNdv(v) => Some(v), - }) - .unwrap_or(None), - allow_single_file_parallelism: value.allow_single_file_parallelism, - maximum_parallel_row_group_writers: value.maximum_parallel_row_group_writers as usize, - maximum_buffered_record_batches_per_stream: value.maximum_buffered_record_batches_per_stream as usize, - - }) - } -} - -impl TryFrom<&protobuf::ColumnOptions> for ColumnOptions { - type Error = DataFusionError; - fn try_from(value: &protobuf::ColumnOptions) -> Result { - Ok(ColumnOptions { - compression: value.compression_opt.clone().map(|opt| match opt { - protobuf::column_options::CompressionOpt::Compression(v) => Some(v), - }).unwrap_or(None), - dictionary_enabled: value.dictionary_enabled_opt.as_ref().map(|protobuf::column_options::DictionaryEnabledOpt::DictionaryEnabled(v)| *v), - statistics_enabled: value - .statistics_enabled_opt.clone() - .map(|opt| match opt { - protobuf::column_options::StatisticsEnabledOpt::StatisticsEnabled(v) => Some(v), - }) - .unwrap_or(None), - max_statistics_size: value - .max_statistics_size_opt.clone() - .map(|opt| match opt { - protobuf::column_options::MaxStatisticsSizeOpt::MaxStatisticsSize(v) => Some(v as usize), - }) - .unwrap_or(None), - encoding: value - .encoding_opt.clone() - .map(|opt| match opt { - protobuf::column_options::EncodingOpt::Encoding(v) => Some(v), - }) - .unwrap_or(None), - bloom_filter_enabled: value.bloom_filter_enabled_opt.clone().map(|opt| match opt { - protobuf::column_options::BloomFilterEnabledOpt::BloomFilterEnabled(v) => Some(v), - }) - .unwrap_or(None), - bloom_filter_fpp: value - .bloom_filter_fpp_opt.clone() - .map(|opt| match opt { - protobuf::column_options::BloomFilterFppOpt::BloomFilterFpp(v) => Some(v), - }) - .unwrap_or(None), - bloom_filter_ndv: value - .bloom_filter_ndv_opt.clone() - .map(|opt| match opt { - protobuf::column_options::BloomFilterNdvOpt::BloomFilterNdv(v) => Some(v), - }) - .unwrap_or(None), - }) - } -} - -impl TryFrom<&protobuf::TableParquetOptions> for TableParquetOptions { - type Error = DataFusionError; - fn try_from(value: &protobuf::TableParquetOptions) -> Result { - let mut column_specific_options: HashMap = HashMap::new(); - for protobuf::ColumnSpecificOptions { - column_name, - options: maybe_options, - } in &value.column_specific_options - { - if let Some(options) = maybe_options { - column_specific_options.insert(column_name.clone(), options.try_into()?); - } - } - Ok(TableParquetOptions { - global: value - .global - .as_ref() - .map(|v| v.try_into()) - .unwrap() - .unwrap(), - column_specific_options, - }) - } -} - -impl TryFrom<&protobuf::JsonOptions> for JsonOptions { - type Error = DataFusionError; - - fn try_from(proto_opts: &protobuf::JsonOptions) -> Result { - let compression: protobuf::CompressionTypeVariant = proto_opts.compression(); - Ok(JsonOptions { - compression: compression.into(), - schema_infer_max_rec: proto_opts.schema_infer_max_rec as usize, - }) - } -} - -impl TryFrom<©_to_node::FormatOptions> for FormatOptions { - type Error = DataFusionError; - fn try_from(value: ©_to_node::FormatOptions) -> Result { - Ok(match value { - copy_to_node::FormatOptions::Csv(options) => { - FormatOptions::CSV(options.try_into()?) - } - copy_to_node::FormatOptions::Json(options) => { - FormatOptions::JSON(options.try_into()?) - } - copy_to_node::FormatOptions::Parquet(options) => { - FormatOptions::PARQUET(options.try_into()?) - } - copy_to_node::FormatOptions::Avro(_) => FormatOptions::AVRO, - copy_to_node::FormatOptions::Arrow(_) => FormatOptions::ARROW, + insert_op, + keep_partition_by_columns: conf.keep_partition_by_columns, }) } } diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 1c5ba861d297..e84eae2b9082 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -18,6 +18,7 @@ use std::fmt::Debug; use std::sync::Arc; +use datafusion::physical_expr::aggregate::AggregateExprBuilder; use prost::bytes::BufMut; use prost::Message; @@ -33,8 +34,9 @@ use datafusion::datasource::physical_plan::ParquetExec; use datafusion::datasource::physical_plan::{AvroExec, CsvExec}; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::execution::FunctionRegistry; -use datafusion::physical_expr::{PhysicalExprRef, PhysicalSortRequirement}; -use datafusion::physical_plan::aggregates::{create_aggregate_expr, AggregateMode}; +use datafusion::physical_expr::aggregate::AggregateFunctionExpr; +use datafusion::physical_expr::{LexOrdering, PhysicalExprRef, PhysicalSortRequirement}; +use datafusion::physical_plan::aggregates::AggregateMode; use datafusion::physical_plan::aggregates::{AggregateExec, PhysicalGroupBy}; use datafusion::physical_plan::analyze::AnalyzeExec; use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; @@ -56,16 +58,15 @@ use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; +use datafusion::physical_plan::unnest::{ListUnnest, UnnestExec}; use datafusion::physical_plan::windows::{BoundedWindowAggExec, WindowAggExec}; use datafusion::physical_plan::{ - udaf, AggregateExpr, ExecutionPlan, InputOrderMode, Partitioning, PhysicalExpr, - WindowExpr, + ExecutionPlan, InputOrderMode, PhysicalExpr, WindowExpr, }; use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; -use datafusion_expr::ScalarUDF; +use datafusion_expr::{AggregateUDF, ScalarUDF}; -use crate::common::{byte_to_string, proto_error, str_to_byte}; -use crate::convert_required; +use crate::common::{byte_to_string, str_to_byte}; use crate::physical_plan::from_proto::{ parse_physical_expr, parse_physical_sort_expr, parse_physical_sort_exprs, parse_physical_window_expr, parse_protobuf_file_scan_config, @@ -77,10 +78,13 @@ use crate::physical_plan::to_proto::{ use crate::protobuf::physical_aggregate_expr_node::AggregateFunction; use crate::protobuf::physical_expr_node::ExprType; use crate::protobuf::physical_plan_node::PhysicalPlanType; -use crate::protobuf::repartition_exec_node::PartitionMethod; -use crate::protobuf::{self, window_agg_exec_node}; +use crate::protobuf::{ + self, proto_error, window_agg_exec_node, ListUnnest as ProtoListUnnest, +}; +use crate::{convert_required, into_required}; -use self::to_proto::serialize_physical_expr; +use self::from_proto::parse_protobuf_partitioning; +use self::to_proto::{serialize_partitioning, serialize_physical_expr}; pub mod from_proto; pub mod to_proto; @@ -177,7 +181,19 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { ) })?; let filter_selectivity = filter.default_filter_selectivity.try_into(); - let filter = FilterExec::try_new(predicate, input)?; + let projection = if !filter.projection.is_empty() { + Some( + filter + .projection + .iter() + .map(|i| *i as usize) + .collect::>(), + ) + } else { + None + }; + let filter = + FilterExec::try_new(predicate, input)?.with_projection(projection)?; match filter_selectivity { Ok(filter_selectivity) => Ok(Arc::new( filter.with_default_selectivity(filter_selectivity)?, @@ -187,50 +203,68 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { )), } } - PhysicalPlanType::CsvScan(scan) => Ok(Arc::new(CsvExec::new( - parse_protobuf_file_scan_config( + PhysicalPlanType::CsvScan(scan) => Ok(Arc::new( + CsvExec::builder(parse_protobuf_file_scan_config( scan.base_conf.as_ref().unwrap(), registry, extension_codec, - )?, - scan.has_header, - str_to_byte(&scan.delimiter, "delimiter")?, - str_to_byte(&scan.quote, "quote")?, - if let Some(protobuf::csv_scan_exec_node::OptionalEscape::Escape( - escape, - )) = &scan.optional_escape - { - Some(str_to_byte(escape, "escape")?) - } else { - None - }, - FileCompressionType::UNCOMPRESSED, - ))), - #[cfg(feature = "parquet")] + )?) + .with_has_header(scan.has_header) + .with_delimeter(str_to_byte(&scan.delimiter, "delimiter")?) + .with_quote(str_to_byte(&scan.quote, "quote")?) + .with_escape( + if let Some(protobuf::csv_scan_exec_node::OptionalEscape::Escape( + escape, + )) = &scan.optional_escape + { + Some(str_to_byte(escape, "escape")?) + } else { + None + }, + ) + .with_comment( + if let Some(protobuf::csv_scan_exec_node::OptionalComment::Comment( + comment, + )) = &scan.optional_comment + { + Some(str_to_byte(comment, "comment")?) + } else { + None + }, + ) + .with_newlines_in_values(scan.newlines_in_values) + .with_file_compression_type(FileCompressionType::UNCOMPRESSED) + .build(), + )), + #[cfg_attr(not(feature = "parquet"), allow(unused_variables))] PhysicalPlanType::ParquetScan(scan) => { - let base_config = parse_protobuf_file_scan_config( - scan.base_conf.as_ref().unwrap(), - registry, - extension_codec, - )?; - let predicate = scan - .predicate - .as_ref() - .map(|expr| { - parse_physical_expr( - expr, - registry, - base_config.file_schema.as_ref(), - extension_codec, - ) - }) - .transpose()?; - Ok(Arc::new(ParquetExec::new( - base_config, - predicate, - None, - Default::default(), - ))) + #[cfg(feature = "parquet")] + { + let base_config = parse_protobuf_file_scan_config( + scan.base_conf.as_ref().unwrap(), + registry, + extension_codec, + )?; + let predicate = scan + .predicate + .as_ref() + .map(|expr| { + parse_physical_expr( + expr, + registry, + base_config.file_schema.as_ref(), + extension_codec, + ) + }) + .transpose()?; + let mut builder = ParquetExec::builder(base_config); + if let Some(predicate) = predicate { + builder = builder.with_predicate(predicate) + } + Ok(builder.build_arc()) + } + #[cfg(not(feature = "parquet"))] + panic!("Unable to process a Parquet PhysicalPlan when `parquet` feature is not enabled") } PhysicalPlanType::AvroScan(scan) => { Ok(Arc::new(AvroExec::new(parse_protobuf_file_scan_config( @@ -246,10 +280,13 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { runtime, extension_codec, )?; - Ok(Arc::new(CoalesceBatchesExec::new( - input, - coalesce_batches.target_batch_size as usize, - ))) + Ok(Arc::new( + CoalesceBatchesExec::new( + input, + coalesce_batches.target_batch_size as usize, + ) + .with_fetch(coalesce_batches.fetch.map(|f| f as usize)), + )) } PhysicalPlanType::Merge(merge) => { let input: Arc = @@ -263,47 +300,16 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { runtime, extension_codec, )?; - match repart.partition_method { - Some(PartitionMethod::Hash(ref hash_part)) => { - let expr = hash_part - .hash_expr - .iter() - .map(|e| { - parse_physical_expr( - e, - registry, - input.schema().as_ref(), - extension_codec, - ) - }) - .collect::>, _>>()?; - - Ok(Arc::new(RepartitionExec::try_new( - input, - Partitioning::Hash( - expr, - hash_part.partition_count.try_into().unwrap(), - ), - )?)) - } - Some(PartitionMethod::RoundRobin(partition_count)) => { - Ok(Arc::new(RepartitionExec::try_new( - input, - Partitioning::RoundRobinBatch( - partition_count.try_into().unwrap(), - ), - )?)) - } - Some(PartitionMethod::Unknown(partition_count)) => { - Ok(Arc::new(RepartitionExec::try_new( - input, - Partitioning::UnknownPartitioning( - partition_count.try_into().unwrap(), - ), - )?)) - } - _ => internal_err!("Invalid partitioning scheme"), - } + let partitioning = parse_protobuf_partitioning( + repart.partitioning.as_ref(), + registry, + input.schema().as_ref(), + extension_codec, + )?; + Ok(Arc::new(RepartitionExec::try_new( + input, + partitioning.unwrap(), + )?)) } PhysicalPlanType::GlobalLimit(limit) => { let input: Arc = @@ -482,7 +488,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { }) .collect::, _>>()?; - let physical_aggr_expr: Vec> = hash_agg + let physical_aggr_expr: Vec> = hash_agg .aggr_expr .iter() .zip(hash_agg.aggr_expr_name.iter()) @@ -494,38 +500,26 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { match expr_type { ExprType::AggregateExpr(agg_node) => { let input_phy_expr: Vec> = agg_node.expr.iter() - .map(|e| parse_physical_expr(e, registry, &physical_schema, extension_codec).unwrap()).collect(); - let ordering_req: Vec = agg_node.ordering_req.iter() - .map(|e| parse_physical_sort_expr(e, registry, &physical_schema, extension_codec).unwrap()).collect(); + .map(|e| parse_physical_expr(e, registry, &physical_schema, extension_codec)).collect::>>()?; + let ordering_req: LexOrdering = agg_node.ordering_req.iter() + .map(|e| parse_physical_sort_expr(e, registry, &physical_schema, extension_codec)) + .collect::>()?; agg_node.aggregate_function.as_ref().map(|func| { match func { - AggregateFunction::AggrFunction(i) => { - let aggr_function = protobuf::AggregateFunction::try_from(*i) - .map_err( - |_| { - proto_error(format!( - "Received an unknown aggregate function: {i}" - )) - }, - )?; - - create_aggregate_expr( - &aggr_function.into(), - agg_node.distinct, - input_phy_expr.as_slice(), - &ordering_req, - &physical_schema, - name.to_string(), - false, - ) - } AggregateFunction::UserDefinedAggrFunction(udaf_name) => { - let agg_udf = registry.udaf(udaf_name)?; - // TODO: `order by` is not supported for UDAF yet - let sort_exprs = &[]; - let ordering_req = &[]; - let ignore_nulls = false; - udaf::create_aggregate_expr(agg_udf.as_ref(), &input_phy_expr, sort_exprs, ordering_req, &physical_schema, name, ignore_nulls) + let agg_udf = match &agg_node.fun_definition { + Some(buf) => extension_codec.try_decode_udaf(udaf_name, buf)?, + None => registry.udaf(udaf_name)? + }; + + AggregateExprBuilder::new(agg_udf, input_phy_expr) + .schema(Arc::clone(&physical_schema)) + .alias(name) + .with_ignore_nulls(agg_node.ignore_nulls) + .with_distinct(agg_node.distinct) + .order_by(ordering_req) + .build() + .map(Arc::new) } } }).transpose()?.ok_or_else(|| { @@ -539,14 +533,23 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { }) .collect::, _>>()?; - Ok(Arc::new(AggregateExec::try_new( + let limit = hash_agg + .limit + .as_ref() + .map(|lit_value| lit_value.limit as usize); + + let agg = AggregateExec::try_new( agg_mode, PhysicalGroupBy::new(group_expr, null_expr, groups), physical_aggr_expr, physical_filter_expr, input, physical_schema, - )?)) + )?; + + let agg = agg.with_limit(limit); + + Ok(Arc::new(agg)) } PhysicalPlanType::HashJoin(hashjoin) => { let left: Arc = into_physical_plan( @@ -849,7 +852,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { "physical_plan::from_proto() Unexpected expr {self:?}" )) })?; - if let protobuf::physical_expr_node::ExprType::Sort(sort_expr) = expr { + if let ExprType::Sort(sort_expr) = expr { let expr = sort_expr .expr .as_ref() @@ -872,7 +875,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { ) } }) - .collect::, _>>()?; + .collect::>()?; let fetch = if sort.fetch < 0 { None } else { @@ -896,7 +899,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { "physical_plan::from_proto() Unexpected expr {self:?}" )) })?; - if let protobuf::physical_expr_node::ExprType::Sort(sort_expr) = expr { + if let ExprType::Sort(sort_expr) = expr { let expr = sort_expr .expr .as_ref() @@ -919,7 +922,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { ) } }) - .collect::, _>>()?; + .collect::>()?; let fetch = if sort.fetch < 0 { None } else { @@ -1023,7 +1026,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { .as_ref() .ok_or_else(|| proto_error("Missing required field in protobuf"))? .try_into()?; - let sink_schema = convert_required!(sink.sink_schema)?; + let sink_schema = input.schema(); let sort_order = sink .sort_order .as_ref() @@ -1034,13 +1037,13 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { &sink_schema, extension_codec, ) - .map(|item| PhysicalSortRequirement::from_sort_exprs(&item)) + .map(|item| PhysicalSortRequirement::from_sort_exprs(&item.inner)) }) .transpose()?; Ok(Arc::new(DataSinkExec::new( input, Arc::new(data_sink), - Arc::new(sink_schema), + sink_schema, sort_order, ))) } @@ -1053,7 +1056,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { .as_ref() .ok_or_else(|| proto_error("Missing required field in protobuf"))? .try_into()?; - let sink_schema = convert_required!(sink.sink_schema)?; + let sink_schema = input.schema(); let sort_order = sink .sort_order .as_ref() @@ -1064,44 +1067,79 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { &sink_schema, extension_codec, ) - .map(|item| PhysicalSortRequirement::from_sort_exprs(&item)) + .map(|item| PhysicalSortRequirement::from_sort_exprs(&item.inner)) }) .transpose()?; Ok(Arc::new(DataSinkExec::new( input, Arc::new(data_sink), - Arc::new(sink_schema), + sink_schema, sort_order, ))) } + #[cfg_attr(not(feature = "parquet"), allow(unused_variables))] PhysicalPlanType::ParquetSink(sink) => { - let input = - into_physical_plan(&sink.input, registry, runtime, extension_codec)?; + #[cfg(feature = "parquet")] + { + let input = into_physical_plan( + &sink.input, + registry, + runtime, + extension_codec, + )?; - let data_sink: ParquetSink = sink - .sink - .as_ref() - .ok_or_else(|| proto_error("Missing required field in protobuf"))? - .try_into()?; - let sink_schema = convert_required!(sink.sink_schema)?; - let sort_order = sink - .sort_order - .as_ref() - .map(|collection| { - parse_physical_sort_exprs( - &collection.physical_sort_expr_nodes, - registry, - &sink_schema, - extension_codec, - ) - .map(|item| PhysicalSortRequirement::from_sort_exprs(&item)) - }) - .transpose()?; - Ok(Arc::new(DataSinkExec::new( + let data_sink: ParquetSink = sink + .sink + .as_ref() + .ok_or_else(|| proto_error("Missing required field in protobuf"))? + .try_into()?; + let sink_schema = input.schema(); + let sort_order = sink + .sort_order + .as_ref() + .map(|collection| { + parse_physical_sort_exprs( + &collection.physical_sort_expr_nodes, + registry, + &sink_schema, + extension_codec, + ) + .map(|item| { + PhysicalSortRequirement::from_sort_exprs(&item.inner) + }) + }) + .transpose()?; + Ok(Arc::new(DataSinkExec::new( + input, + Arc::new(data_sink), + sink_schema, + sort_order, + ))) + } + #[cfg(not(feature = "parquet"))] + panic!("Trying to use ParquetSink without `parquet` feature enabled"); + } + PhysicalPlanType::Unnest(unnest) => { + let input = into_physical_plan( + &unnest.input, + registry, + runtime, + extension_codec, + )?; + + Ok(Arc::new(UnnestExec::new( input, - Arc::new(data_sink), - Arc::new(sink_schema), - sort_order, + unnest + .list_type_columns + .iter() + .map(|c| ListUnnest { + index_in_input_schema: c.index_in_input_schema as _, + depth: c.depth as _, + }) + .collect(), + unnest.struct_type_columns.iter().map(|c| *c as _).collect(), + Arc::new(convert_required!(unnest.schema)?), + into_required!(unnest.options)?, ))) } } @@ -1114,7 +1152,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { where Self: Sized, { - let plan_clone = plan.clone(); + let plan_clone = Arc::clone(&plan); let plan = plan.as_any(); if let Some(exec) = plan.downcast_ref::() { @@ -1141,7 +1179,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { let expr = exec .expr() .iter() - .map(|expr| serialize_physical_expr(expr.0.clone(), extension_codec)) + .map(|expr| serialize_physical_expr(&expr.0, extension_codec)) .collect::>>()?; let expr_name = exec.expr().iter().map(|expr| expr.1.clone()).collect(); return Ok(protobuf::PhysicalPlanNode { @@ -1182,10 +1220,16 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { protobuf::FilterExecNode { input: Some(Box::new(input)), expr: Some(serialize_physical_expr( - exec.predicate().clone(), + exec.predicate(), extension_codec, )?), default_filter_selectivity: exec.default_selectivity() as u32, + projection: exec + .projection() + .as_ref() + .map_or_else(Vec::new, |v| { + v.iter().map(|x| *x as u32).collect::>() + }), }, ))), }); @@ -1239,8 +1283,8 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { .on() .iter() .map(|tuple| { - let l = serialize_physical_expr(tuple.0.to_owned(), extension_codec)?; - let r = serialize_physical_expr(tuple.1.to_owned(), extension_codec)?; + let l = serialize_physical_expr(&tuple.0, extension_codec)?; + let r = serialize_physical_expr(&tuple.1, extension_codec)?; Ok::<_, DataFusionError>(protobuf::JoinOn { left: Some(l), right: Some(r), @@ -1252,10 +1296,8 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { .filter() .as_ref() .map(|f| { - let expression = serialize_physical_expr( - f.expression().to_owned(), - extension_codec, - )?; + let expression = + serialize_physical_expr(f.expression(), extension_codec)?; let column_indices = f .column_indices() .iter() @@ -1313,8 +1355,8 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { .on() .iter() .map(|tuple| { - let l = serialize_physical_expr(tuple.0.to_owned(), extension_codec)?; - let r = serialize_physical_expr(tuple.1.to_owned(), extension_codec)?; + let l = serialize_physical_expr(&tuple.0, extension_codec)?; + let r = serialize_physical_expr(&tuple.1, extension_codec)?; Ok::<_, DataFusionError>(protobuf::JoinOn { left: Some(l), right: Some(r), @@ -1326,10 +1368,8 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { .filter() .as_ref() .map(|f| { - let expression = serialize_physical_expr( - f.expression().to_owned(), - extension_codec, - )?; + let expression = + serialize_physical_expr(f.expression(), extension_codec)?; let column_indices = f .column_indices() .iter() @@ -1367,7 +1407,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { .map(|expr| { Ok(protobuf::PhysicalSortExprNode { expr: Some(Box::new(serialize_physical_expr( - expr.expr.to_owned(), + &expr.expr, extension_codec, )?)), asc: !expr.options.descending, @@ -1387,7 +1427,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { .map(|expr| { Ok(protobuf::PhysicalSortExprNode { expr: Some(Box::new(serialize_physical_expr( - expr.expr.to_owned(), + &expr.expr, extension_codec, )?)), asc: !expr.options.descending, @@ -1467,11 +1507,8 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { let agg_names = exec .aggr_expr() .iter() - .map(|expr| match expr.field() { - Ok(field) => Ok(field.name().clone()), - Err(e) => Err(e), - }) - .collect::>()?; + .map(|expr| expr.name().to_string()) + .collect::>(); let agg_mode = match exec.mode() { AggregateMode::Partial => protobuf::AggregateMode::Partial, @@ -1494,16 +1531,20 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { .group_expr() .null_expr() .iter() - .map(|expr| serialize_physical_expr(expr.0.to_owned(), extension_codec)) + .map(|expr| serialize_physical_expr(&expr.0, extension_codec)) .collect::>>()?; let group_expr = exec .group_expr() .expr() .iter() - .map(|expr| serialize_physical_expr(expr.0.to_owned(), extension_codec)) + .map(|expr| serialize_physical_expr(&expr.0, extension_codec)) .collect::>>()?; + let limit = exec.limit().map(|value| protobuf::AggLimit { + limit: value as u64, + }); + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Aggregate(Box::new( protobuf::AggregateExecNode { @@ -1517,6 +1558,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { input_schema: Some(input_schema.as_ref().try_into()?), null_expr, groups, + limit, }, ))), }); @@ -1554,6 +1596,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { protobuf::CoalesceBatchesExecNode { input: Some(Box::new(input)), target_batch_size: coalesce_batches.target_batch_size() as u32, + fetch: coalesce_batches.fetch().map(|n| n as u32), }, ))), }); @@ -1577,6 +1620,14 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { } else { None }, + optional_comment: if let Some(comment) = exec.comment() { + Some(protobuf::csv_scan_exec_node::OptionalComment::Comment( + byte_to_string(comment, "comment")?, + )) + } else { + None + }, + newlines_in_values: exec.newlines_in_values(), }, )), }); @@ -1586,7 +1637,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { if let Some(exec) = plan.downcast_ref::() { let predicate = exec .predicate() - .map(|pred| serialize_physical_expr(pred.clone(), extension_codec)) + .map(|pred| serialize_physical_expr(pred, extension_codec)) .transpose()?; return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::ParquetScan( @@ -1634,31 +1685,14 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { extension_codec, )?; - let pb_partition_method = match exec.partitioning() { - Partitioning::Hash(exprs, partition_count) => { - PartitionMethod::Hash(protobuf::PhysicalHashRepartition { - hash_expr: exprs - .iter() - .map(|expr| { - serialize_physical_expr(expr.clone(), extension_codec) - }) - .collect::>>()?, - partition_count: *partition_count as u64, - }) - } - Partitioning::RoundRobinBatch(partition_count) => { - PartitionMethod::RoundRobin(*partition_count as u64) - } - Partitioning::UnknownPartitioning(partition_count) => { - PartitionMethod::Unknown(*partition_count as u64) - } - }; + let pb_partitioning = + serialize_partitioning(exec.partitioning(), extension_codec)?; return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Repartition(Box::new( protobuf::RepartitionExecNode { input: Some(Box::new(input)), - partition_method: Some(pb_partition_method), + partitioning: Some(pb_partitioning), }, ))), }); @@ -1675,16 +1709,14 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { .map(|expr| { let sort_expr = Box::new(protobuf::PhysicalSortExprNode { expr: Some(Box::new(serialize_physical_expr( - expr.expr.to_owned(), + &expr.expr, extension_codec, )?)), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }); Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Sort( - sort_expr, - )), + expr_type: Some(ExprType::Sort(sort_expr)), }) }) .collect::>>()?; @@ -1744,16 +1776,14 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { .map(|expr| { let sort_expr = Box::new(protobuf::PhysicalSortExprNode { expr: Some(Box::new(serialize_physical_expr( - expr.expr.to_owned(), + &expr.expr, extension_codec, )?)), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }); Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Sort( - sort_expr, - )), + expr_type: Some(ExprType::Sort(sort_expr)), }) }) .collect::>>()?; @@ -1783,10 +1813,8 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { .filter() .as_ref() .map(|f| { - let expression = serialize_physical_expr( - f.expression().to_owned(), - extension_codec, - )?; + let expression = + serialize_physical_expr(f.expression(), extension_codec)?; let column_indices = f .column_indices() .iter() @@ -1828,13 +1856,13 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { let window_expr = exec .window_expr() .iter() - .map(|e| serialize_physical_window_expr(e.clone(), extension_codec)) + .map(|e| serialize_physical_window_expr(e, extension_codec)) .collect::>>()?; let partition_keys = exec .partition_keys .iter() - .map(|e| serialize_physical_expr(e.clone(), extension_codec)) + .map(|e| serialize_physical_expr(e, extension_codec)) .collect::>>()?; return Ok(protobuf::PhysicalPlanNode { @@ -1858,13 +1886,13 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { let window_expr = exec .window_expr() .iter() - .map(|e| serialize_physical_window_expr(e.clone(), extension_codec)) + .map(|e| serialize_physical_window_expr(e, extension_codec)) .collect::>>()?; let partition_keys = exec .partition_keys .iter() - .map(|e| serialize_physical_expr(e.clone(), extension_codec)) + .map(|e| serialize_physical_expr(e, extension_codec)) .collect::>>()?; let input_order_mode = match &exec.input_order_mode { @@ -1908,7 +1936,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { let expr: PhysicalSortExpr = requirement.to_owned().into(); let sort_expr = protobuf::PhysicalSortExprNode { expr: Some(Box::new(serialize_physical_expr( - expr.expr.to_owned(), + &expr.expr, extension_codec, )?)), asc: !expr.options.descending, @@ -1950,6 +1978,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { }); } + #[cfg(feature = "parquet")] if let Some(sink) = exec.sink().as_any().downcast_ref::() { return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::ParquetSink(Box::new( @@ -1966,12 +1995,43 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { // If unknown DataSink then let extension handle it } + if let Some(exec) = plan.downcast_ref::() { + let input = protobuf::PhysicalPlanNode::try_from_physical_plan( + exec.input().to_owned(), + extension_codec, + )?; + + return Ok(protobuf::PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Unnest(Box::new( + protobuf::UnnestExecNode { + input: Some(Box::new(input)), + schema: Some(exec.schema().try_into()?), + list_type_columns: exec + .list_column_indices() + .iter() + .map(|c| ProtoListUnnest { + index_in_input_schema: c.index_in_input_schema as _, + depth: c.depth as _, + }) + .collect(), + struct_type_columns: exec + .struct_column_indices() + .iter() + .map(|c| *c as _) + .collect(), + options: Some(exec.options().into()), + }, + ))), + }); + } + let mut buf: Vec = vec![]; - match extension_codec.try_encode(plan_clone.clone(), &mut buf) { + match extension_codec.try_encode(Arc::clone(&plan_clone), &mut buf) { Ok(_) => { let inputs: Vec = plan_clone .children() .into_iter() + .cloned() .map(|i| { protobuf::PhysicalPlanNode::try_from_physical_plan( i, @@ -2035,6 +2095,32 @@ pub trait PhysicalExtensionCodec: Debug + Send + Sync { fn try_encode_udf(&self, _node: &ScalarUDF, _buf: &mut Vec) -> Result<()> { Ok(()) } + + fn try_decode_expr( + &self, + _buf: &[u8], + _inputs: &[Arc], + ) -> Result> { + not_impl_err!("PhysicalExtensionCodec is not provided") + } + + fn try_encode_expr( + &self, + _node: &Arc, + _buf: &mut Vec, + ) -> Result<()> { + not_impl_err!("PhysicalExtensionCodec is not provided") + } + + fn try_decode_udaf(&self, name: &str, _buf: &[u8]) -> Result> { + not_impl_err!( + "PhysicalExtensionCodec is not provided for aggregate function {name}" + ) + } + + fn try_encode_udaf(&self, _node: &AggregateUDF, _buf: &mut Vec) -> Result<()> { + Ok(()) + } } #[derive(Debug)] diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index aa6121bebc34..4bf7e353326e 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -21,22 +21,14 @@ use std::sync::Arc; #[cfg(feature = "parquet")] use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; -use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; +use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ - ApproxDistinct, ApproxMedian, ApproxPercentileCont, ApproxPercentileContWithWeight, - ArrayAgg, Avg, BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, - CastExpr, Column, Correlation, Count, Covariance, CovariancePop, CumeDist, - DistinctArrayAgg, DistinctBitXor, DistinctCount, DistinctSum, FirstValue, Grouping, - InListExpr, IsNotNullExpr, IsNullExpr, LastValue, Literal, Max, Median, Min, - NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, - RankType, Regr, RegrType, RowNumber, Stddev, StddevPop, StringAgg, Sum, TryCastExpr, - Variance, VariancePop, WindowShift, + BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr, IsNullExpr, + Literal, NegativeExpr, NotExpr, NthValue, TryCastExpr, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; -use datafusion::physical_plan::{ - AggregateExpr, ColumnStatistics, PhysicalExpr, Statistics, WindowExpr, -}; +use datafusion::physical_plan::{Partitioning, PhysicalExpr, WindowExpr}; use datafusion::{ datasource::{ file_format::{csv::CsvSink, json::JsonSink}, @@ -45,184 +37,125 @@ use datafusion::{ }, physical_plan::expressions::LikeExpr, }; -use datafusion_common::config::{ - ColumnOptions, CsvOptions, FormatOptions, JsonOptions, ParquetOptions, - TableParquetOptions, -}; -use datafusion_common::{ - file_options::{csv_writer::CsvWriterOptions, json_writer::JsonWriterOptions}, - internal_err, not_impl_err, - parsers::CompressionTypeVariant, - stats::Precision, - DataFusionError, JoinSide, Result, -}; -use datafusion_expr::ScalarFunctionDefinition; +use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; +use datafusion_expr::WindowFrame; -use crate::logical_plan::csv_writer_options_to_proto; use crate::protobuf::{ - self, copy_to_node, physical_aggregate_expr_node, physical_window_expr_node, - scalar_value::Value, ArrowOptions, AvroOptions, PhysicalSortExprNode, - PhysicalSortExprNodeCollection, ScalarValue, + self, physical_aggregate_expr_node, physical_window_expr_node, PhysicalSortExprNode, + PhysicalSortExprNodeCollection, }; use super::PhysicalExtensionCodec; pub fn serialize_physical_aggr_expr( - aggr_expr: Arc, + aggr_expr: Arc, codec: &dyn PhysicalExtensionCodec, ) -> Result { - let expressions = serialize_physical_exprs(aggr_expr.expressions(), codec)?; - let ordering_req = aggr_expr.order_bys().unwrap_or(&[]).to_vec(); + let expressions = serialize_physical_exprs(&aggr_expr.expressions(), codec)?; + let ordering_req = match aggr_expr.order_bys() { + Some(order) => LexOrdering::from_ref(order), + None => LexOrdering::default(), + }; let ordering_req = serialize_physical_sort_exprs(ordering_req, codec)?; - if let Some(a) = aggr_expr.as_any().downcast_ref::() { - let name = a.fun().name().to_string(); - return Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr( - protobuf::PhysicalAggregateExprNode { - aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(name)), - expr: expressions, - ordering_req, - distinct: false, - }, - )), - }); - } - - let AggrFn { - inner: aggr_function, - distinct, - } = aggr_expr_to_aggr_fn(aggr_expr.as_ref())?; - + let name = aggr_expr.fun().name().to_string(); + let mut buf = Vec::new(); + codec.try_encode_udaf(aggr_expr.fun(), &mut buf)?; Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr( protobuf::PhysicalAggregateExprNode { - aggregate_function: Some( - physical_aggregate_expr_node::AggregateFunction::AggrFunction( - aggr_function as i32, - ), - ), + aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(name)), expr: expressions, ordering_req, - distinct, + distinct: aggr_expr.is_distinct(), + ignore_nulls: aggr_expr.ignore_nulls(), + fun_definition: (!buf.is_empty()).then_some(buf) }, )), }) } +fn serialize_physical_window_aggr_expr( + aggr_expr: &AggregateFunctionExpr, + _window_frame: &WindowFrame, + codec: &dyn PhysicalExtensionCodec, +) -> Result<(physical_window_expr_node::WindowFunction, Option>)> { + if aggr_expr.is_distinct() || aggr_expr.ignore_nulls() { + // TODO + return not_impl_err!( + "Distinct aggregate functions not supported in window expressions" + ); + } + + let mut buf = Vec::new(); + codec.try_encode_udaf(aggr_expr.fun(), &mut buf)?; + Ok(( + physical_window_expr_node::WindowFunction::UserDefinedAggrFunction( + aggr_expr.fun().name().to_string(), + ), + (!buf.is_empty()).then_some(buf), + )) +} + pub fn serialize_physical_window_expr( - window_expr: Arc, + window_expr: &Arc, codec: &dyn PhysicalExtensionCodec, ) -> Result { let expr = window_expr.as_any(); let mut args = window_expr.expressions().to_vec(); let window_frame = window_expr.get_window_frame(); - let window_function = if let Some(built_in_window_expr) = + let (window_function, fun_definition) = if let Some(built_in_window_expr) = expr.downcast_ref::() { let expr = built_in_window_expr.get_built_in_func_expr(); let built_in_fn_expr = expr.as_any(); - let builtin_fn = if built_in_fn_expr.downcast_ref::().is_some() { - protobuf::BuiltInWindowFunction::RowNumber - } else if let Some(rank_expr) = built_in_fn_expr.downcast_ref::() { - match rank_expr.get_type() { - RankType::Basic => protobuf::BuiltInWindowFunction::Rank, - RankType::Dense => protobuf::BuiltInWindowFunction::DenseRank, - RankType::Percent => protobuf::BuiltInWindowFunction::PercentRank, - } - } else if built_in_fn_expr.downcast_ref::().is_some() { - protobuf::BuiltInWindowFunction::CumeDist - } else if let Some(ntile_expr) = built_in_fn_expr.downcast_ref::() { - args.insert( - 0, - Arc::new(Literal::new(datafusion_common::ScalarValue::Int64(Some( - ntile_expr.get_n() as i64, - )))), - ); - protobuf::BuiltInWindowFunction::Ntile - } else if let Some(window_shift_expr) = - built_in_fn_expr.downcast_ref::() - { - args.insert( - 1, - Arc::new(Literal::new(datafusion_common::ScalarValue::Int64(Some( - window_shift_expr.get_shift_offset(), - )))), - ); - args.insert( - 2, - Arc::new(Literal::new(window_shift_expr.get_default_value())), - ); - - if window_shift_expr.get_shift_offset() >= 0 { - protobuf::BuiltInWindowFunction::Lag - } else { - protobuf::BuiltInWindowFunction::Lead - } - } else if let Some(nth_value_expr) = built_in_fn_expr.downcast_ref::() { - match nth_value_expr.get_kind() { - NthValueKind::First => protobuf::BuiltInWindowFunction::FirstValue, - NthValueKind::Last => protobuf::BuiltInWindowFunction::LastValue, - NthValueKind::Nth(n) => { - args.insert( - 1, - Arc::new(Literal::new(datafusion_common::ScalarValue::Int64( - Some(n), - ))), - ); - protobuf::BuiltInWindowFunction::NthValue + let builtin_fn = + if let Some(nth_value_expr) = built_in_fn_expr.downcast_ref::() { + match nth_value_expr.get_kind() { + NthValueKind::First => protobuf::BuiltInWindowFunction::FirstValue, + NthValueKind::Last => protobuf::BuiltInWindowFunction::LastValue, + NthValueKind::Nth(n) => { + args.insert( + 1, + Arc::new(Literal::new( + datafusion_common::ScalarValue::Int64(Some(n)), + )), + ); + protobuf::BuiltInWindowFunction::NthValue + } } - } - } else { - return not_impl_err!("BuiltIn function not supported: {expr:?}"); - }; + } else { + return not_impl_err!("BuiltIn function not supported: {expr:?}"); + }; - physical_window_expr_node::WindowFunction::BuiltInFunction(builtin_fn as i32) + ( + physical_window_expr_node::WindowFunction::BuiltInFunction(builtin_fn as i32), + None, + ) } else if let Some(plain_aggr_window_expr) = expr.downcast_ref::() { - let AggrFn { inner, distinct } = - aggr_expr_to_aggr_fn(plain_aggr_window_expr.get_aggregate_expr().as_ref())?; - - if distinct { - // TODO - return not_impl_err!( - "Distinct aggregate functions not supported in window expressions" - ); - } - - if !window_frame.start_bound.is_unbounded() { - return Err(DataFusionError::Internal(format!("Invalid PlainAggregateWindowExpr = {window_expr:?} with WindowFrame = {window_frame:?}"))); - } - - physical_window_expr_node::WindowFunction::AggrFunction(inner as i32) + serialize_physical_window_aggr_expr( + plain_aggr_window_expr.get_aggregate_expr(), + window_frame, + codec, + )? } else if let Some(sliding_aggr_window_expr) = expr.downcast_ref::() { - let AggrFn { inner, distinct } = - aggr_expr_to_aggr_fn(sliding_aggr_window_expr.get_aggregate_expr().as_ref())?; - - if distinct { - // TODO - return not_impl_err!( - "Distinct aggregate functions not supported in window expressions" - ); - } - - if window_frame.start_bound.is_unbounded() { - return Err(DataFusionError::Internal(format!("Invalid SlidingAggregateWindowExpr = {window_expr:?} with WindowFrame = {window_frame:?}"))); - } - - physical_window_expr_node::WindowFunction::AggrFunction(inner as i32) + serialize_physical_window_aggr_expr( + sliding_aggr_window_expr.get_aggregate_expr(), + window_frame, + codec, + )? } else { return not_impl_err!("WindowExpr not supported: {window_expr:?}"); }; - let args = serialize_physical_exprs(args, codec)?; - let partition_by = - serialize_physical_exprs(window_expr.partition_by().to_vec(), codec)?; + let args = serialize_physical_exprs(&args, codec)?; + let partition_by = serialize_physical_exprs(window_expr.partition_by(), codec)?; let order_by = serialize_physical_sort_exprs(window_expr.order_by().to_vec(), codec)?; let window_frame: protobuf::WindowFrame = window_frame .as_ref() @@ -236,110 +169,10 @@ pub fn serialize_physical_window_expr( window_frame: Some(window_frame), window_function: Some(window_function), name: window_expr.name().to_string(), + fun_definition, }) } -struct AggrFn { - inner: protobuf::AggregateFunction, - distinct: bool, -} - -fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { - let aggr_expr = expr.as_any(); - let mut distinct = false; - - let inner = if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Count - } else if aggr_expr.downcast_ref::().is_some() { - distinct = true; - protobuf::AggregateFunction::Count - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Grouping - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::BitAnd - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::BitOr - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::BitXor - } else if aggr_expr.downcast_ref::().is_some() { - distinct = true; - protobuf::AggregateFunction::BitXor - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::BoolAnd - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::BoolOr - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Sum - } else if aggr_expr.downcast_ref::().is_some() { - distinct = true; - protobuf::AggregateFunction::Sum - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::ApproxDistinct - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::ArrayAgg - } else if aggr_expr.downcast_ref::().is_some() { - distinct = true; - protobuf::AggregateFunction::ArrayAgg - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::ArrayAgg - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Min - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Max - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Avg - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Variance - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::VariancePop - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Covariance - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::CovariancePop - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Stddev - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::StddevPop - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Correlation - } else if let Some(regr_expr) = aggr_expr.downcast_ref::() { - match regr_expr.get_regr_type() { - RegrType::Slope => protobuf::AggregateFunction::RegrSlope, - RegrType::Intercept => protobuf::AggregateFunction::RegrIntercept, - RegrType::Count => protobuf::AggregateFunction::RegrCount, - RegrType::R2 => protobuf::AggregateFunction::RegrR2, - RegrType::AvgX => protobuf::AggregateFunction::RegrAvgx, - RegrType::AvgY => protobuf::AggregateFunction::RegrAvgy, - RegrType::SXX => protobuf::AggregateFunction::RegrSxx, - RegrType::SYY => protobuf::AggregateFunction::RegrSyy, - RegrType::SXY => protobuf::AggregateFunction::RegrSxy, - } - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::ApproxPercentileCont - } else if aggr_expr - .downcast_ref::() - .is_some() - { - protobuf::AggregateFunction::ApproxPercentileContWithWeight - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::ApproxMedian - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Median - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::FirstValueAgg - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::LastValueAgg - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::StringAgg - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::NthValueAgg - } else { - return not_impl_err!("Aggregate function not supported: {expr:?}"); - }; - - Ok(AggrFn { inner, distinct }) -} - pub fn serialize_physical_sort_exprs( sort_exprs: I, codec: &dyn PhysicalExtensionCodec, @@ -358,7 +191,7 @@ pub fn serialize_physical_sort_expr( codec: &dyn PhysicalExtensionCodec, ) -> Result { let PhysicalSortExpr { expr, options } = sort_expr; - let expr = serialize_physical_expr(expr, codec)?; + let expr = serialize_physical_expr(&expr, codec)?; Ok(PhysicalSortExprNode { expr: Some(Box::new(expr)), asc: !options.descending, @@ -366,12 +199,12 @@ pub fn serialize_physical_sort_expr( }) } -pub fn serialize_physical_exprs( +pub fn serialize_physical_exprs<'a, I>( values: I, codec: &dyn PhysicalExtensionCodec, ) -> Result> where - I: IntoIterator>, + I: IntoIterator>, { values .into_iter() @@ -384,7 +217,7 @@ where /// If required, a [`PhysicalExtensionCodec`] can be provided which can handle /// serialization of udfs requiring specialized serialization (see [`PhysicalExtensionCodec::try_encode_udf`]) pub fn serialize_physical_expr( - value: Arc, + value: &Arc, codec: &dyn PhysicalExtensionCodec, ) -> Result { let expr = value.as_any(); @@ -400,14 +233,8 @@ pub fn serialize_physical_expr( }) } else if let Some(expr) = expr.downcast_ref::() { let binary_expr = Box::new(protobuf::PhysicalBinaryExprNode { - l: Some(Box::new(serialize_physical_expr( - expr.left().clone(), - codec, - )?)), - r: Some(Box::new(serialize_physical_expr( - expr.right().clone(), - codec, - )?)), + l: Some(Box::new(serialize_physical_expr(expr.left(), codec)?)), + r: Some(Box::new(serialize_physical_expr(expr.right(), codec)?)), op: format!("{:?}", expr.op()), }); @@ -425,8 +252,7 @@ pub fn serialize_physical_expr( expr: expr .expr() .map(|exp| { - serialize_physical_expr(exp.clone(), codec) - .map(Box::new) + serialize_physical_expr(exp, codec).map(Box::new) }) .transpose()?, when_then_expr: expr @@ -441,10 +267,7 @@ pub fn serialize_physical_expr( >>()?, else_expr: expr .else_expr() - .map(|a| { - serialize_physical_expr(a.clone(), codec) - .map(Box::new) - }) + .map(|a| serialize_physical_expr(a, codec).map(Box::new)) .transpose()?, }, ), @@ -455,10 +278,7 @@ pub fn serialize_physical_expr( Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::NotExpr(Box::new( protobuf::PhysicalNot { - expr: Some(Box::new(serialize_physical_expr( - expr.arg().to_owned(), - codec, - )?)), + expr: Some(Box::new(serialize_physical_expr(expr.arg(), codec)?)), }, ))), }) @@ -466,10 +286,7 @@ pub fn serialize_physical_expr( Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::IsNullExpr( Box::new(protobuf::PhysicalIsNull { - expr: Some(Box::new(serialize_physical_expr( - expr.arg().to_owned(), - codec, - )?)), + expr: Some(Box::new(serialize_physical_expr(expr.arg(), codec)?)), }), )), }) @@ -477,10 +294,7 @@ pub fn serialize_physical_expr( Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::IsNotNullExpr( Box::new(protobuf::PhysicalIsNotNull { - expr: Some(Box::new(serialize_physical_expr( - expr.arg().to_owned(), - codec, - )?)), + expr: Some(Box::new(serialize_physical_expr(expr.arg(), codec)?)), }), )), }) @@ -488,11 +302,8 @@ pub fn serialize_physical_expr( Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::InList(Box::new( protobuf::PhysicalInListNode { - expr: Some(Box::new(serialize_physical_expr( - expr.expr().to_owned(), - codec, - )?)), - list: serialize_physical_exprs(expr.list().to_vec(), codec)?, + expr: Some(Box::new(serialize_physical_expr(expr.expr(), codec)?)), + list: serialize_physical_exprs(expr.list(), codec)?, negated: expr.negated(), }, ))), @@ -501,10 +312,7 @@ pub fn serialize_physical_expr( Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::Negative(Box::new( protobuf::PhysicalNegativeNode { - expr: Some(Box::new(serialize_physical_expr( - expr.arg().to_owned(), - codec, - )?)), + expr: Some(Box::new(serialize_physical_expr(expr.arg(), codec)?)), }, ))), }) @@ -518,10 +326,7 @@ pub fn serialize_physical_expr( Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::Cast(Box::new( protobuf::PhysicalCastNode { - expr: Some(Box::new(serialize_physical_expr( - cast.expr().to_owned(), - codec, - )?)), + expr: Some(Box::new(serialize_physical_expr(cast.expr(), codec)?)), arrow_type: Some(cast.cast_type().try_into()?), }, ))), @@ -530,36 +335,20 @@ pub fn serialize_physical_expr( Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast(Box::new( protobuf::PhysicalTryCastNode { - expr: Some(Box::new(serialize_physical_expr( - cast.expr().to_owned(), - codec, - )?)), + expr: Some(Box::new(serialize_physical_expr(cast.expr(), codec)?)), arrow_type: Some(cast.cast_type().try_into()?), }, ))), }) } else if let Some(expr) = expr.downcast_ref::() { - let args = serialize_physical_exprs(expr.args().to_vec(), codec)?; - let mut buf = Vec::new(); - match expr.fun() { - ScalarFunctionDefinition::UDF(udf) => { - codec.try_encode_udf(udf, &mut buf)?; - } - _ => { - return not_impl_err!( - "Proto serialization error: Trying to serialize a unresolved function" - ); - } - } - - let fun_definition = if buf.is_empty() { None } else { Some(buf) }; + codec.try_encode_udf(expr.fun(), &mut buf)?; Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarUdf( protobuf::PhysicalScalarUdfNode { name: expr.name().to_string(), - args, - fun_definition, + args: serialize_physical_exprs(expr.args(), codec)?, + fun_definition: (!buf.is_empty()).then_some(buf), return_type: Some(expr.return_type().try_into()?), }, )), @@ -570,30 +359,74 @@ pub fn serialize_physical_expr( protobuf::PhysicalLikeExprNode { negated: expr.negated(), case_insensitive: expr.case_insensitive(), - expr: Some(Box::new(serialize_physical_expr( - expr.expr().to_owned(), - codec, - )?)), + expr: Some(Box::new(serialize_physical_expr(expr.expr(), codec)?)), pattern: Some(Box::new(serialize_physical_expr( - expr.pattern().to_owned(), + expr.pattern(), codec, )?)), }, ))), }) } else { - internal_err!("physical_plan::to_proto() unsupported expression {value:?}") + let mut buf: Vec = vec![]; + match codec.try_encode_expr(value, &mut buf) { + Ok(_) => { + let inputs: Vec = value + .children() + .into_iter() + .map(|e| serialize_physical_expr(e, codec)) + .collect::>()?; + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Extension( + protobuf::PhysicalExtensionExprNode { expr: buf, inputs }, + )), + }) + } + Err(e) => internal_err!( + "Unsupported physical expr and extension codec failed with [{e}]. Expr: {value:?}" + ), + } } } +pub fn serialize_partitioning( + partitioning: &Partitioning, + codec: &dyn PhysicalExtensionCodec, +) -> Result { + let serialized_partitioning = match partitioning { + Partitioning::RoundRobinBatch(partition_count) => protobuf::Partitioning { + partition_method: Some(protobuf::partitioning::PartitionMethod::RoundRobin( + *partition_count as u64, + )), + }, + Partitioning::Hash(exprs, partition_count) => { + let serialized_exprs = serialize_physical_exprs(exprs, codec)?; + protobuf::Partitioning { + partition_method: Some(protobuf::partitioning::PartitionMethod::Hash( + protobuf::PhysicalHashRepartition { + hash_expr: serialized_exprs, + partition_count: *partition_count as u64, + }, + )), + } + } + Partitioning::UnknownPartitioning(partition_count) => protobuf::Partitioning { + partition_method: Some(protobuf::partitioning::PartitionMethod::Unknown( + *partition_count as u64, + )), + }, + }; + Ok(serialized_partitioning) +} + fn serialize_when_then_expr( when_expr: &Arc, then_expr: &Arc, codec: &dyn PhysicalExtensionCodec, ) -> Result { Ok(protobuf::PhysicalWhenThen { - when_expr: Some(serialize_physical_expr(when_expr.clone(), codec)?), - then_expr: Some(serialize_physical_expr(then_expr.clone(), codec)?), + when_expr: Some(serialize_physical_expr(when_expr, codec)?), + then_expr: Some(serialize_physical_expr(then_expr, codec)?), }) } @@ -617,6 +450,7 @@ impl TryFrom<&PartitionedFile> for protobuf::PartitionedFile { .map(|v| v.try_into()) .collect::, _>>()?, range: pf.range.as_ref().map(|r| r.try_into()).transpose()?, + statistics: pf.statistics.as_ref().map(|s| s.into()), }) } } @@ -645,70 +479,6 @@ impl TryFrom<&[PartitionedFile]> for protobuf::FileGroup { } } -impl From<&Precision> for protobuf::Precision { - fn from(s: &Precision) -> protobuf::Precision { - match s { - Precision::Exact(val) => protobuf::Precision { - precision_info: protobuf::PrecisionInfo::Exact.into(), - val: Some(ScalarValue { - value: Some(Value::Uint64Value(*val as u64)), - }), - }, - Precision::Inexact(val) => protobuf::Precision { - precision_info: protobuf::PrecisionInfo::Inexact.into(), - val: Some(ScalarValue { - value: Some(Value::Uint64Value(*val as u64)), - }), - }, - Precision::Absent => protobuf::Precision { - precision_info: protobuf::PrecisionInfo::Absent.into(), - val: Some(ScalarValue { value: None }), - }, - } - } -} - -impl From<&Precision> for protobuf::Precision { - fn from(s: &Precision) -> protobuf::Precision { - match s { - Precision::Exact(val) => protobuf::Precision { - precision_info: protobuf::PrecisionInfo::Exact.into(), - val: val.try_into().ok(), - }, - Precision::Inexact(val) => protobuf::Precision { - precision_info: protobuf::PrecisionInfo::Inexact.into(), - val: val.try_into().ok(), - }, - Precision::Absent => protobuf::Precision { - precision_info: protobuf::PrecisionInfo::Absent.into(), - val: Some(ScalarValue { value: None }), - }, - } - } -} - -impl From<&Statistics> for protobuf::Statistics { - fn from(s: &Statistics) -> protobuf::Statistics { - let column_stats = s.column_statistics.iter().map(|s| s.into()).collect(); - protobuf::Statistics { - num_rows: Some(protobuf::Precision::from(&s.num_rows)), - total_byte_size: Some(protobuf::Precision::from(&s.total_byte_size)), - column_stats, - } - } -} - -impl From<&ColumnStatistics> for protobuf::ColumnStats { - fn from(s: &ColumnStatistics) -> protobuf::ColumnStats { - protobuf::ColumnStats { - min_value: Some(protobuf::Precision::from(&s.min_value)), - max_value: Some(protobuf::Precision::from(&s.max_value)), - null_count: Some(protobuf::Precision::from(&s.null_count)), - distinct_count: Some(protobuf::Precision::from(&s.distinct_count)), - } - } -} - pub fn serialize_file_scan_config( conf: &FileScanConfig, codec: &dyn PhysicalExtensionCodec, @@ -763,15 +533,6 @@ pub fn serialize_file_scan_config( }) } -impl From for protobuf::JoinSide { - fn from(t: JoinSide) -> Self { - match t { - JoinSide::Left => protobuf::JoinSide::LeftSide, - JoinSide::Right => protobuf::JoinSide::RightSide, - } - } -} - pub fn serialize_maybe_filter( expr: Option>, codec: &dyn PhysicalExtensionCodec, @@ -779,7 +540,7 @@ pub fn serialize_maybe_filter( match expr { None => Ok(protobuf::MaybeFilter { expr: None }), Some(expr) => Ok(protobuf::MaybeFilter { - expr: Some(serialize_physical_expr(expr, codec)?), + expr: Some(serialize_physical_expr(&expr, codec)?), }), } } @@ -848,186 +609,8 @@ impl TryFrom<&FileSinkConfig> for protobuf::FileSinkConfig { table_paths, output_schema: Some(conf.output_schema.as_ref().try_into()?), table_partition_cols, - overwrite: conf.overwrite, - }) - } -} - -impl From<&CompressionTypeVariant> for protobuf::CompressionTypeVariant { - fn from(value: &CompressionTypeVariant) -> Self { - match value { - CompressionTypeVariant::GZIP => Self::Gzip, - CompressionTypeVariant::BZIP2 => Self::Bzip2, - CompressionTypeVariant::XZ => Self::Xz, - CompressionTypeVariant::ZSTD => Self::Zstd, - CompressionTypeVariant::UNCOMPRESSED => Self::Uncompressed, - } - } -} - -impl TryFrom<&CsvWriterOptions> for protobuf::CsvWriterOptions { - type Error = DataFusionError; - - fn try_from(opts: &CsvWriterOptions) -> Result { - Ok(csv_writer_options_to_proto( - &opts.writer_options, - &opts.compression, - )) - } -} - -impl TryFrom<&JsonWriterOptions> for protobuf::JsonWriterOptions { - type Error = DataFusionError; - - fn try_from(opts: &JsonWriterOptions) -> Result { - let compression: protobuf::CompressionTypeVariant = opts.compression.into(); - Ok(protobuf::JsonWriterOptions { - compression: compression.into(), - }) - } -} - -impl TryFrom<&ParquetOptions> for protobuf::ParquetOptions { - type Error = DataFusionError; - - fn try_from(value: &ParquetOptions) -> Result { - Ok(protobuf::ParquetOptions { - enable_page_index: value.enable_page_index, - pruning: value.pruning, - skip_metadata: value.skip_metadata, - metadata_size_hint_opt: value.metadata_size_hint.map(|v| protobuf::parquet_options::MetadataSizeHintOpt::MetadataSizeHint(v as u64)), - pushdown_filters: value.pushdown_filters, - reorder_filters: value.reorder_filters, - data_pagesize_limit: value.data_pagesize_limit as u64, - write_batch_size: value.write_batch_size as u64, - writer_version: value.writer_version.clone(), - compression_opt: value.compression.clone().map(protobuf::parquet_options::CompressionOpt::Compression), - dictionary_enabled_opt: value.dictionary_enabled.map(protobuf::parquet_options::DictionaryEnabledOpt::DictionaryEnabled), - dictionary_page_size_limit: value.dictionary_page_size_limit as u64, - statistics_enabled_opt: value.statistics_enabled.clone().map(protobuf::parquet_options::StatisticsEnabledOpt::StatisticsEnabled), - max_statistics_size_opt: value.max_statistics_size.map(|v| protobuf::parquet_options::MaxStatisticsSizeOpt::MaxStatisticsSize(v as u64)), - max_row_group_size: value.max_row_group_size as u64, - created_by: value.created_by.clone(), - column_index_truncate_length_opt: value.column_index_truncate_length.map(|v| protobuf::parquet_options::ColumnIndexTruncateLengthOpt::ColumnIndexTruncateLength(v as u64)), - data_page_row_count_limit: value.data_page_row_count_limit as u64, - encoding_opt: value.encoding.clone().map(protobuf::parquet_options::EncodingOpt::Encoding), - bloom_filter_enabled: value.bloom_filter_enabled, - bloom_filter_fpp_opt: value.bloom_filter_fpp.map(protobuf::parquet_options::BloomFilterFppOpt::BloomFilterFpp), - bloom_filter_ndv_opt: value.bloom_filter_ndv.map(protobuf::parquet_options::BloomFilterNdvOpt::BloomFilterNdv), - allow_single_file_parallelism: value.allow_single_file_parallelism, - maximum_parallel_row_group_writers: value.maximum_parallel_row_group_writers as u64, - maximum_buffered_record_batches_per_stream: value.maximum_buffered_record_batches_per_stream as u64, - }) - } -} - -impl TryFrom<&ColumnOptions> for protobuf::ColumnOptions { - type Error = DataFusionError; - - fn try_from(value: &ColumnOptions) -> Result { - Ok(protobuf::ColumnOptions { - compression_opt: value - .compression - .clone() - .map(protobuf::column_options::CompressionOpt::Compression), - dictionary_enabled_opt: value - .dictionary_enabled - .map(protobuf::column_options::DictionaryEnabledOpt::DictionaryEnabled), - statistics_enabled_opt: value - .statistics_enabled - .clone() - .map(protobuf::column_options::StatisticsEnabledOpt::StatisticsEnabled), - max_statistics_size_opt: value.max_statistics_size.map(|v| { - protobuf::column_options::MaxStatisticsSizeOpt::MaxStatisticsSize( - v as u32, - ) - }), - encoding_opt: value - .encoding - .clone() - .map(protobuf::column_options::EncodingOpt::Encoding), - bloom_filter_enabled_opt: value - .bloom_filter_enabled - .map(protobuf::column_options::BloomFilterEnabledOpt::BloomFilterEnabled), - bloom_filter_fpp_opt: value - .bloom_filter_fpp - .map(protobuf::column_options::BloomFilterFppOpt::BloomFilterFpp), - bloom_filter_ndv_opt: value - .bloom_filter_ndv - .map(protobuf::column_options::BloomFilterNdvOpt::BloomFilterNdv), - }) - } -} - -impl TryFrom<&TableParquetOptions> for protobuf::TableParquetOptions { - type Error = DataFusionError; - fn try_from(value: &TableParquetOptions) -> Result { - let column_specific_options = value - .column_specific_options - .iter() - .map(|(k, v)| { - Ok(protobuf::ColumnSpecificOptions { - column_name: k.into(), - options: Some(v.try_into()?), - }) - }) - .collect::>>()?; - Ok(protobuf::TableParquetOptions { - global: Some((&value.global).try_into()?), - column_specific_options, - }) - } -} - -impl TryFrom<&CsvOptions> for protobuf::CsvOptions { - type Error = DataFusionError; // Define or use an appropriate error type - - fn try_from(opts: &CsvOptions) -> Result { - let compression: protobuf::CompressionTypeVariant = opts.compression.into(); - Ok(protobuf::CsvOptions { - has_header: opts.has_header, - delimiter: vec![opts.delimiter], - quote: vec![opts.quote], - escape: opts.escape.map_or_else(Vec::new, |e| vec![e]), - compression: compression.into(), - schema_infer_max_rec: opts.schema_infer_max_rec as u64, - date_format: opts.date_format.clone().unwrap_or_default(), - datetime_format: opts.datetime_format.clone().unwrap_or_default(), - timestamp_format: opts.timestamp_format.clone().unwrap_or_default(), - timestamp_tz_format: opts.timestamp_tz_format.clone().unwrap_or_default(), - time_format: opts.time_format.clone().unwrap_or_default(), - null_value: opts.null_value.clone().unwrap_or_default(), - }) - } -} - -impl TryFrom<&JsonOptions> for protobuf::JsonOptions { - type Error = DataFusionError; - - fn try_from(opts: &JsonOptions) -> Result { - let compression: protobuf::CompressionTypeVariant = opts.compression.into(); - Ok(protobuf::JsonOptions { - compression: compression.into(), - schema_infer_max_rec: opts.schema_infer_max_rec as u64, - }) - } -} - -impl TryFrom<&FormatOptions> for copy_to_node::FormatOptions { - type Error = DataFusionError; - fn try_from(value: &FormatOptions) -> std::result::Result { - Ok(match value { - FormatOptions::CSV(options) => { - copy_to_node::FormatOptions::Csv(options.try_into()?) - } - FormatOptions::JSON(options) => { - copy_to_node::FormatOptions::Json(options.try_into()?) - } - FormatOptions::PARQUET(options) => { - copy_to_node::FormatOptions::Parquet(options.try_into()?) - } - FormatOptions::AVRO => copy_to_node::FormatOptions::Avro(AvroOptions {}), - FormatOptions::ARROW => copy_to_node::FormatOptions::Arrow(ArrowOptions {}), + keep_partition_by_columns: conf.keep_partition_by_columns, + insert_op: conf.insert_op as i32, }) } } diff --git a/datafusion/proto/tests/cases/mod.rs b/datafusion/proto/tests/cases/mod.rs index b17289205f3d..fbb2cd8f1e83 100644 --- a/datafusion/proto/tests/cases/mod.rs +++ b/datafusion/proto/tests/cases/mod.rs @@ -15,6 +15,113 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; + +use arrow::datatypes::DataType; + +use datafusion_common::plan_err; +use datafusion_expr::function::AccumulatorArgs; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, Signature, Volatility, +}; + mod roundtrip_logical_plan; mod roundtrip_physical_plan; mod serialize; + +#[derive(Debug, PartialEq, Eq, Hash)] +struct MyRegexUdf { + signature: Signature, + // regex as original string + pattern: String, + aliases: Vec, +} + +impl MyRegexUdf { + fn new(pattern: String) -> Self { + let signature = Signature::exact(vec![DataType::Utf8], Volatility::Immutable); + Self { + signature, + pattern, + aliases: vec!["aggregate_udf_alias".to_string()], + } + } +} + +/// Implement the ScalarUDFImpl trait for MyRegexUdf +impl ScalarUDFImpl for MyRegexUdf { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "regex_udf" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, args: &[DataType]) -> datafusion_common::Result { + if matches!(args, [DataType::Utf8]) { + Ok(DataType::Int64) + } else { + plan_err!("regex_udf only accepts Utf8 arguments") + } + } + fn invoke( + &self, + _args: &[ColumnarValue], + ) -> datafusion_common::Result { + unimplemented!() + } + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct MyRegexUdfNode { + #[prost(string, tag = "1")] + pub pattern: String, +} + +#[derive(Debug, PartialEq, Eq, Hash)] +struct MyAggregateUDF { + signature: Signature, + result: String, +} + +impl MyAggregateUDF { + fn new(result: String) -> Self { + let signature = Signature::exact(vec![DataType::Int64], Volatility::Immutable); + Self { signature, result } + } +} + +impl AggregateUDFImpl for MyAggregateUDF { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "aggregate_udf" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type( + &self, + _arg_types: &[DataType], + ) -> datafusion_common::Result { + Ok(DataType::Utf8) + } + fn accumulator( + &self, + _acc_args: AccumulatorArgs, + ) -> datafusion_common::Result> { + unimplemented!() + } +} + +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct MyAggregateUdfNode { + #[prost(string, tag = "1")] + pub result: String, +} diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 1fd6160c2c6c..14d91913e7cd 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -15,53 +15,85 @@ // specific language governing permissions and limitations // under the License. +use arrow::array::{ + ArrayRef, FixedSizeListArray, Int32Builder, MapArray, MapBuilder, StringBuilder, +}; +use arrow::datatypes::{ + DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, + IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, + DECIMAL256_MAX_PRECISION, +}; +use datafusion::datasource::file_format::json::JsonFormatFactory; +use datafusion_common::parsers::CompressionTypeVariant; +use prost::Message; use std::any::Any; use std::collections::HashMap; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; use std::vec; -use arrow::array::{ArrayRef, FixedSizeListArray}; -use arrow::datatypes::{ - DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, - IntervalUnit, Schema, SchemaRef, TimeUnit, UnionFields, UnionMode, +use datafusion::catalog::{TableProvider, TableProviderFactory}; +use datafusion::datasource::file_format::arrow::ArrowFormatFactory; +use datafusion::datasource::file_format::csv::CsvFormatFactory; +use datafusion::datasource::file_format::parquet::ParquetFormatFactory; +use datafusion::datasource::file_format::{format_as_file_type, DefaultFileType}; +use datafusion::execution::session_state::SessionStateBuilder; +use datafusion::execution::FunctionRegistry; +use datafusion::functions_aggregate::count::count_udaf; +use datafusion::functions_aggregate::expr_fn::{ + approx_median, approx_percentile_cont, approx_percentile_cont_with_weight, count, + count_distinct, covar_pop, covar_samp, first_value, grouping, max, median, min, + stddev, stddev_pop, sum, var_pop, var_sample, +}; +use datafusion::functions_aggregate::min_max::max_udaf; +use datafusion::functions_nested::map::map; +use datafusion::functions_window::expr_fn::{ + cume_dist, dense_rank, lag, lead, ntile, percent_rank, rank, row_number, }; -use datafusion::datasource::provider::TableProviderFactory; -use datafusion::datasource::TableProvider; -use datafusion::execution::context::SessionState; -use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; -use datafusion::functions_aggregate::expr_fn::first_value; +use datafusion::functions_window::rank::rank_udwf; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; -use datafusion_common::config::{FormatOptions, TableOptions}; +use datafusion_common::config::TableOptions; use datafusion_common::scalar::ScalarStructBuilder; use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, DFSchemaRef, - DataFusionError, FileType, Result, ScalarValue, + DataFusionError, Result, ScalarValue, TableReference, }; use datafusion_expr::dml::CopyTo; use datafusion_expr::expr::{ self, Between, BinaryExpr, Case, Cast, GroupingSet, InList, Like, ScalarFunction, - Sort, Unnest, + Unnest, WildcardOptions, }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ - Accumulator, AggregateFunction, ColumnarValue, ExprSchemable, LogicalPlan, Operator, - PartitionEvaluator, ScalarUDF, ScalarUDFImpl, Signature, TryCast, Volatility, + Accumulator, AggregateUDF, ColumnarValue, ExprFunctionExt, ExprSchemable, Literal, + LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, Signature, TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, WindowUDFImpl, }; +use datafusion_functions_aggregate::average::avg_udaf; +use datafusion_functions_aggregate::expr_fn::{ + approx_distinct, array_agg, avg, bit_and, bit_or, bit_xor, bool_and, bool_or, corr, + nth_value, +}; +use datafusion_functions_aggregate::string_agg::string_agg; +use datafusion_functions_window_common::field::WindowUDFFieldArgs; +use datafusion_functions_window_common::partition::PartitionEvaluatorArgs; use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec, }; +use datafusion_proto::logical_plan::file_formats::{ + ArrowLogicalExtensionCodec, CsvLogicalExtensionCodec, JsonLogicalExtensionCodec, + ParquetLogicalExtensionCodec, +}; use datafusion_proto::logical_plan::to_proto::serialize_expr; -use datafusion_proto::logical_plan::LogicalExtensionCodec; -use datafusion_proto::logical_plan::{from_proto, DefaultLogicalExtensionCodec}; +use datafusion_proto::logical_plan::{ + from_proto, DefaultLogicalExtensionCodec, LogicalExtensionCodec, +}; use datafusion_proto::protobuf; -use datafusion::execution::FunctionRegistry; -use prost::Message; +use crate::cases::{MyAggregateUDF, MyAggregateUdfNode, MyRegexUdf, MyRegexUdfNode}; #[cfg(feature = "json")] fn roundtrip_json_test(proto: &protobuf::LogicalExprNode) { @@ -78,10 +110,8 @@ fn roundtrip_json_test(_proto: &protobuf::LogicalExprNode) {} fn roundtrip_expr_test(initial_struct: Expr, ctx: SessionContext) { let extension_codec = DefaultLogicalExtensionCodec {}; let proto: protobuf::LogicalExprNode = - match serialize_expr(&initial_struct, &extension_codec) { - Ok(p) => p, - Err(e) => panic!("Error serializing expression: {:?}", e), - }; + serialize_expr(&initial_struct, &extension_codec) + .unwrap_or_else(|e| panic!("Error serializing expression: {:?}", e)); let round_trip: Expr = from_proto::parse_expr(&proto, &ctx, &extension_codec).unwrap(); @@ -116,6 +146,9 @@ pub struct TestTableProto { /// URL of the table root #[prost(string, tag = "1")] pub url: String, + /// Qualified table name + #[prost(string, tag = "2")] + pub table_name: String, } #[derive(Debug)] @@ -138,12 +171,14 @@ impl LogicalExtensionCodec for TestTableProviderCodec { fn try_decode_table_provider( &self, buf: &[u8], + table_ref: &TableReference, schema: SchemaRef, _ctx: &SessionContext, ) -> Result> { let msg = TestTableProto::decode(buf).map_err(|_| { DataFusionError::Internal("Error decoding test table".to_string()) })?; + assert_eq!(msg.table_name, table_ref.to_string()); let provider = TestTableProvider { url: msg.url, schema, @@ -153,6 +188,7 @@ impl LogicalExtensionCodec for TestTableProviderCodec { fn try_encode_table_provider( &self, + table_ref: &TableReference, node: Arc, buf: &mut Vec, ) -> Result<()> { @@ -163,6 +199,7 @@ impl LogicalExtensionCodec for TestTableProviderCodec { .expect("Can't encode non-test tables"); let msg = TestTableProto { url: table.url.clone(), + table_name: table_ref.to_string(), }; msg.encode(buf).map_err(|_| { DataFusionError::Internal("Error encoding test table".to_string()) @@ -175,10 +212,7 @@ async fn roundtrip_custom_tables() -> Result<()> { let mut table_factories: HashMap> = HashMap::new(); table_factories.insert("TESTTABLE".to_string(), Arc::new(TestTableFactory {})); - let cfg = RuntimeConfig::new(); - let env = RuntimeEnv::new(cfg).unwrap(); - let ses = SessionConfig::new(); - let mut state = SessionState::new_with_config_rt(ses, Arc::new(env)); + let mut state = SessionStateBuilder::new().with_default_features().build(); // replace factories *state.table_factories_mut() = table_factories; let ctx = SessionContext::new_with_state(state); @@ -235,10 +269,10 @@ async fn roundtrip_custom_listing_tables() -> Result<()> { primary key(c) ) STORED AS CSV - WITH HEADER ROW WITH ORDER (a ASC, b ASC) WITH ORDER (c ASC) - LOCATION '../core/tests/data/window_2.csv';"; + LOCATION '../core/tests/data/window_2.csv' + OPTIONS ('format.has_header' 'true')"; let plan = ctx.state().create_logical_plan(query).await?; @@ -265,10 +299,10 @@ async fn roundtrip_logical_plan_aggregation_with_pk() -> Result<()> { primary key(c) ) STORED AS CSV - WITH HEADER ROW WITH ORDER (a ASC, b ASC) WITH ORDER (c ASC) - LOCATION '../core/tests/data/window_2.csv';", + LOCATION '../core/tests/data/window_2.csv' + OPTIONS ('format.has_header' 'true')", ) .await?; @@ -279,7 +313,7 @@ async fn roundtrip_logical_plan_aggregation_with_pk() -> Result<()> { let bytes = logical_plan_to_bytes(&plan)?; let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; - assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + assert_eq!(format!("{plan}"), format!("{logical_round_trip}")); Ok(()) } @@ -305,7 +339,33 @@ async fn roundtrip_logical_plan_aggregation() -> Result<()> { let bytes = logical_plan_to_bytes(&plan)?; let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; - assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + assert_eq!(format!("{plan}"), format!("{logical_round_trip}")); + + Ok(()) +} + +#[tokio::test] +async fn roundtrip_logical_plan_sort() -> Result<()> { + let ctx = SessionContext::new(); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Decimal128(15, 2), true), + ]); + + ctx.register_csv( + "t1", + "tests/testdata/test.csv", + CsvReadOptions::default().schema(&schema), + ) + .await?; + + let query = "SELECT a, b FROM t1 ORDER BY b LIMIT 5"; + let plan = ctx.sql(query).await?.into_optimized_plan()?; + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan}"), format!("{logical_round_trip}")); Ok(()) } @@ -315,21 +375,21 @@ async fn roundtrip_logical_plan_copy_to_sql_options() -> Result<()> { let ctx = SessionContext::new(); let input = create_csv_scan(&ctx).await?; - let mut table_options = ctx.copied_table_options(); - table_options.set_file_format(FileType::CSV); - table_options.set("format.delimiter", ";")?; + let file_type = format_as_file_type(Arc::new(CsvFormatFactory::new())); let plan = LogicalPlan::Copy(CopyTo { input: Arc::new(input), output_url: "test.csv".to_string(), partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], - format_options: FormatOptions::CSV(table_options.csv.clone()), + file_type, options: Default::default(), }); - let bytes = logical_plan_to_bytes(&plan)?; - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; - assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + let codec = CsvLogicalExtensionCodec {}; + let bytes = logical_plan_to_bytes_with_extension_codec(&plan, &codec)?; + let logical_round_trip = + logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?; + assert_eq!(format!("{plan}"), format!("{logical_round_trip}")); Ok(()) } @@ -344,7 +404,7 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> { TableOptions::default_from_session_config(ctx.state().config_options()); let mut parquet_format = table_options.parquet; - parquet_format.global.bloom_filter_enabled = true; + parquet_format.global.bloom_filter_on_read = true; parquet_format.global.created_by = "DataFusion Test".to_string(); parquet_format.global.writer_version = "PARQUET_2_0".to_string(); parquet_format.global.write_batch_size = 111; @@ -353,26 +413,28 @@ async fn roundtrip_logical_plan_copy_to_writer_options() -> Result<()> { parquet_format.global.dictionary_page_size_limit = 444; parquet_format.global.max_row_group_size = 555; + let file_type = format_as_file_type(Arc::new( + ParquetFormatFactory::new_with_options(parquet_format), + )); + let plan = LogicalPlan::Copy(CopyTo { input: Arc::new(input), output_url: "test.parquet".to_string(), - format_options: FormatOptions::PARQUET(parquet_format.clone()), + file_type, partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], options: Default::default(), }); - let bytes = logical_plan_to_bytes(&plan)?; - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + let codec = ParquetLogicalExtensionCodec {}; + let bytes = logical_plan_to_bytes_with_extension_codec(&plan, &codec)?; + let logical_round_trip = + logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?; assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); - match logical_round_trip { LogicalPlan::Copy(copy_to) => { assert_eq!("test.parquet", copy_to.output_url); assert_eq!(vec!["a", "b", "c"], copy_to.partition_by); - assert_eq!( - copy_to.format_options, - FormatOptions::PARQUET(parquet_format) - ); + assert_eq!(copy_to.file_type.get_ext(), "parquet".to_string()); } _ => panic!(), } @@ -385,22 +447,26 @@ async fn roundtrip_logical_plan_copy_to_arrow() -> Result<()> { let input = create_csv_scan(&ctx).await?; + let file_type = format_as_file_type(Arc::new(ArrowFormatFactory::new())); + let plan = LogicalPlan::Copy(CopyTo { input: Arc::new(input), output_url: "test.arrow".to_string(), partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], - format_options: FormatOptions::ARROW, + file_type, options: Default::default(), }); - let bytes = logical_plan_to_bytes(&plan)?; - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; - assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + let codec = ArrowLogicalExtensionCodec {}; + let bytes = logical_plan_to_bytes_with_extension_codec(&plan, &codec)?; + let logical_round_trip = + logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?; + assert_eq!(format!("{plan}"), format!("{logical_round_trip}")); match logical_round_trip { LogicalPlan::Copy(copy_to) => { assert_eq!("test.arrow", copy_to.output_url); - assert_eq!(FormatOptions::ARROW, copy_to.format_options); + assert_eq!("arrow".to_string(), copy_to.file_type.get_ext()); assert_eq!(vec!["a", "b", "c"], copy_to.partition_by); } _ => panic!(), @@ -426,29 +492,193 @@ async fn roundtrip_logical_plan_copy_to_csv() -> Result<()> { csv_format.time_format = Some("HH:mm:ss".to_string()); csv_format.null_value = Some("NIL".to_string()); + let file_type = format_as_file_type(Arc::new(CsvFormatFactory::new_with_options( + csv_format.clone(), + ))); + let plan = LogicalPlan::Copy(CopyTo { input: Arc::new(input), output_url: "test.csv".to_string(), partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], - format_options: FormatOptions::CSV(csv_format.clone()), + file_type, options: Default::default(), }); - let bytes = logical_plan_to_bytes(&plan)?; - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + let codec = CsvLogicalExtensionCodec {}; + let bytes = logical_plan_to_bytes_with_extension_codec(&plan, &codec)?; + let logical_round_trip = + logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?; assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); match logical_round_trip { LogicalPlan::Copy(copy_to) => { assert_eq!("test.csv", copy_to.output_url); - assert_eq!(FormatOptions::CSV(csv_format), copy_to.format_options); + assert_eq!("csv".to_string(), copy_to.file_type.get_ext()); + assert_eq!(vec!["a", "b", "c"], copy_to.partition_by); + + let file_type = copy_to + .file_type + .as_ref() + .as_any() + .downcast_ref::() + .unwrap(); + + let format_factory = file_type.as_format_factory(); + let csv_factory = format_factory + .as_ref() + .as_any() + .downcast_ref::() + .unwrap(); + let csv_config = csv_factory.options.as_ref().unwrap(); + assert_eq!(csv_format.delimiter, csv_config.delimiter); + assert_eq!(csv_format.date_format, csv_config.date_format); + assert_eq!(csv_format.datetime_format, csv_config.datetime_format); + assert_eq!(csv_format.timestamp_format, csv_config.timestamp_format); + assert_eq!(csv_format.time_format, csv_config.time_format); + assert_eq!(csv_format.null_value, csv_config.null_value) + } + _ => panic!(), + } + + Ok(()) +} + +#[tokio::test] +async fn roundtrip_logical_plan_copy_to_json() -> Result<()> { + let ctx = SessionContext::new(); + + // Assume create_json_scan creates a logical plan for scanning a JSON file + let input = create_json_scan(&ctx).await?; + + let table_options = + TableOptions::default_from_session_config(ctx.state().config_options()); + let mut json_format = table_options.json; + + // Set specific JSON format options + json_format.compression = CompressionTypeVariant::GZIP; + json_format.schema_infer_max_rec = 1000; + + let file_type = format_as_file_type(Arc::new(JsonFormatFactory::new_with_options( + json_format.clone(), + ))); + + let plan = LogicalPlan::Copy(CopyTo { + input: Arc::new(input), + output_url: "test.json".to_string(), + partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], + file_type, + options: Default::default(), + }); + + // Assume JsonLogicalExtensionCodec is implemented similarly to CsvLogicalExtensionCodec + let codec = JsonLogicalExtensionCodec {}; + let bytes = logical_plan_to_bytes_with_extension_codec(&plan, &codec)?; + let logical_round_trip = + logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?; + assert_eq!(format!("{plan}"), format!("{logical_round_trip}")); + + match logical_round_trip { + LogicalPlan::Copy(copy_to) => { + assert_eq!("test.json", copy_to.output_url); + assert_eq!("json".to_string(), copy_to.file_type.get_ext()); assert_eq!(vec!["a", "b", "c"], copy_to.partition_by); + + let file_type = copy_to + .file_type + .as_ref() + .as_any() + .downcast_ref::() + .unwrap(); + + let format_factory = file_type.as_format_factory(); + let json_factory = format_factory + .as_ref() + .as_any() + .downcast_ref::() + .unwrap(); + let json_config = json_factory.options.as_ref().unwrap(); + assert_eq!(json_format.compression, json_config.compression); + assert_eq!( + json_format.schema_infer_max_rec, + json_config.schema_infer_max_rec + ); } _ => panic!(), } Ok(()) } + +#[tokio::test] +async fn roundtrip_logical_plan_copy_to_parquet() -> Result<()> { + let ctx = SessionContext::new(); + + // Assume create_parquet_scan creates a logical plan for scanning a Parquet file + let input = create_parquet_scan(&ctx).await?; + + let table_options = + TableOptions::default_from_session_config(ctx.state().config_options()); + let mut parquet_format = table_options.parquet; + + // Set specific Parquet format options + let mut key_value_metadata = HashMap::new(); + key_value_metadata.insert("test".to_string(), Some("test".to_string())); + parquet_format + .key_value_metadata + .clone_from(&key_value_metadata); + + parquet_format.global.allow_single_file_parallelism = false; + parquet_format.global.created_by = "test".to_string(); + + let file_type = format_as_file_type(Arc::new( + ParquetFormatFactory::new_with_options(parquet_format.clone()), + )); + + let plan = LogicalPlan::Copy(CopyTo { + input: Arc::new(input), + output_url: "test.parquet".to_string(), + partition_by: vec!["a".to_string(), "b".to_string(), "c".to_string()], + file_type, + options: Default::default(), + }); + + // Assume ParquetLogicalExtensionCodec is implemented similarly to JsonLogicalExtensionCodec + let codec = ParquetLogicalExtensionCodec {}; + let bytes = logical_plan_to_bytes_with_extension_codec(&plan, &codec)?; + let logical_round_trip = + logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?; + assert_eq!(format!("{plan}"), format!("{logical_round_trip}")); + + match logical_round_trip { + LogicalPlan::Copy(copy_to) => { + assert_eq!("test.parquet", copy_to.output_url); + assert_eq!("parquet".to_string(), copy_to.file_type.get_ext()); + assert_eq!(vec!["a", "b", "c"], copy_to.partition_by); + + let file_type = copy_to + .file_type + .as_ref() + .as_any() + .downcast_ref::() + .unwrap(); + + let format_factory = file_type.as_format_factory(); + let parquet_factory = format_factory + .as_ref() + .as_any() + .downcast_ref::() + .unwrap(); + let parquet_config = parquet_factory.options.as_ref().unwrap(); + assert_eq!(parquet_config.key_value_metadata, key_value_metadata); + assert!(!parquet_config.global.allow_single_file_parallelism); + assert_eq!(parquet_config.global.created_by, "test".to_string()); + } + _ => panic!(), + } + + Ok(()) +} + async fn create_csv_scan(ctx: &SessionContext) -> Result { ctx.register_csv("t1", "tests/testdata/test.csv", CsvReadOptions::default()) .await?; @@ -457,6 +687,32 @@ async fn create_csv_scan(ctx: &SessionContext) -> Result Result { + ctx.register_json( + "t1", + "../core/tests/data/1.json", + NdJsonReadOptions::default(), + ) + .await?; + + let input = ctx.table("t1").await?.into_optimized_plan()?; + Ok(input) +} + +async fn create_parquet_scan( + ctx: &SessionContext, +) -> Result { + ctx.register_parquet( + "t1", + "../substrait/tests/testdata/empty.parquet", + ParquetReadOptions::default(), + ) + .await?; + + let input = ctx.table("t1").await?.into_optimized_plan()?; + Ok(input) +} + #[tokio::test] async fn roundtrip_logical_plan_distinct_on() -> Result<()> { let ctx = SessionContext::new(); @@ -478,7 +734,7 @@ async fn roundtrip_logical_plan_distinct_on() -> Result<()> { let bytes = logical_plan_to_bytes(&plan)?; let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; - assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + assert_eq!(format!("{plan}"), format!("{logical_round_trip}")); Ok(()) } @@ -504,7 +760,7 @@ async fn roundtrip_single_count_distinct() -> Result<()> { let bytes = logical_plan_to_bytes(&plan)?; let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; - assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + assert_eq!(format!("{plan}"), format!("{logical_round_trip}")); Ok(()) } @@ -517,7 +773,32 @@ async fn roundtrip_logical_plan_with_extension() -> Result<()> { let plan = ctx.table("t1").await?.into_optimized_plan()?; let bytes = logical_plan_to_bytes(&plan)?; let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; - assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + assert_eq!(format!("{plan}"), format!("{logical_round_trip}")); + Ok(()) +} + +#[tokio::test] +async fn roundtrip_logical_plan_unnest() -> Result<()> { + let ctx = SessionContext::new(); + let schema = Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new( + "b", + DataType::List(Arc::new(Field::new("item", DataType::Int32, false))), + true, + ), + ]); + ctx.register_csv( + "t1", + "tests/testdata/test.csv", + CsvReadOptions::default().schema(&schema), + ) + .await?; + let query = "SELECT unnest(b) FROM t1"; + let plan = ctx.sql(query).await?.into_optimized_plan()?; + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan}"), format!("{logical_round_trip}")); Ok(()) } @@ -580,10 +861,22 @@ async fn roundtrip_expr_api() -> Result<()> { make_array(vec![lit(1), lit(2), lit(3)]), lit(1), lit(2), + Some(lit(1)), + ), + array_slice( + make_array(vec![lit(1), lit(2), lit(3)]), lit(1), + lit(2), + None, ), array_pop_front(make_array(vec![lit(1), lit(2), lit(3)])), array_pop_back(make_array(vec![lit(1), lit(2), lit(3)])), + array_any_value(make_array(vec![ + lit(ScalarValue::Null), + lit(1), + lit(2), + lit(3), + ])), array_reverse(make_array(vec![lit(1), lit(2), lit(3)])), array_position( make_array(vec![lit(1), lit(2), lit(3), lit(4)]), @@ -613,14 +906,71 @@ async fn roundtrip_expr_api() -> Result<()> { lit(1), ), array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)), - first_value(vec![lit(1)], false, None, None, None), + count(lit(1)), + count_distinct(lit(1)), + first_value(lit(1), None), + first_value(lit(1), Some(vec![lit(2).sort(true, true)])), + avg(lit(1.5)), + covar_samp(lit(1.5), lit(2.2)), + covar_pop(lit(1.5), lit(2.2)), + corr(lit(1.5), lit(2.2)), + sum(lit(1)), + max(lit(1)), + median(lit(2)), + min(lit(2)), + var_sample(lit(2.2)), + var_pop(lit(2.2)), + stddev(lit(2.2)), + stddev_pop(lit(2.2)), + approx_distinct(lit(2)), + approx_median(lit(2)), + approx_percentile_cont(lit(2), lit(0.5), None), + approx_percentile_cont(lit(2), lit(0.5), Some(lit(50))), + approx_percentile_cont_with_weight(lit(2), lit(1), lit(0.5)), + grouping(lit(1)), + bit_and(lit(2)), + bit_or(lit(2)), + bit_xor(lit(2)), + string_agg(col("a").cast_to(&DataType::Utf8, &schema)?, lit("|")), + bool_and(lit(true)), + bool_or(lit(true)), + array_agg(lit(1)), + array_agg(lit(1)).distinct().build().unwrap(), + map( + vec![lit(1), lit(2), lit(3)], + vec![lit(10), lit(20), lit(30)], + ), + cume_dist(), + row_number(), + rank(), + dense_rank(), + percent_rank(), + lead(col("b"), None, None), + lead(col("b"), Some(2), None), + lead(col("b"), Some(2), Some(ScalarValue::from(100))), + lag(col("b"), None, None), + lag(col("b"), Some(2), None), + lag(col("b"), Some(2), Some(ScalarValue::from(100))), + ntile(lit(3)), + nth_value(col("b"), 1, vec![]), + nth_value( + col("b"), + 1, + vec![col("a").sort(false, false), col("b").sort(true, false)], + ), + nth_value(col("b"), -1, vec![]), + nth_value( + col("b"), + -1, + vec![col("a").sort(false, false), col("b").sort(true, false)], + ), ]; // ensure expressions created with the expr api can be round tripped let plan = table.select(expr_list)?.into_optimized_plan()?; let bytes = logical_plan_to_bytes(&plan)?; let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; - assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + assert_eq!(format!("{plan}"), format!("{logical_round_trip}")); Ok(()) } @@ -640,13 +990,13 @@ async fn roundtrip_logical_plan_with_view_scan() -> Result<()> { let bytes = logical_plan_to_bytes(&plan)?; let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; - assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + assert_eq!(format!("{plan}"), format!("{logical_round_trip}")); // DROP let plan = ctx.sql("DROP VIEW view_t1").await?.into_optimized_plan()?; let bytes = logical_plan_to_bytes(&plan)?; let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; - assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + assert_eq!(format!("{plan}"), format!("{logical_round_trip}")); Ok(()) } @@ -658,7 +1008,7 @@ pub mod proto { pub k: u64, #[prost(message, optional, tag = "2")] - pub expr: ::core::option::Option, + pub expr: Option, } #[derive(Clone, PartialEq, Eq, ::prost::Message)] @@ -666,15 +1016,9 @@ pub mod proto { #[prost(uint64, tag = "1")] pub k: u64, } - - #[derive(Clone, PartialEq, ::prost::Message)] - pub struct MyRegexUdfNode { - #[prost(string, tag = "1")] - pub pattern: String, - } } -#[derive(PartialEq, Eq, Hash)] +#[derive(PartialEq, Eq, PartialOrd, Hash)] struct TopKPlanNode { k: usize, input: LogicalPlan, @@ -718,14 +1062,22 @@ impl UserDefinedLogicalNodeCore for TopKPlanNode { write!(f, "TopK: k={}", self.k) } - fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self { + fn with_exprs_and_inputs( + &self, + mut exprs: Vec, + mut inputs: Vec, + ) -> Result { assert_eq!(inputs.len(), 1, "input size inconsistent"); assert_eq!(exprs.len(), 1, "expression size inconsistent"); - Self { + Ok(Self { k: self.k, - input: inputs[0].clone(), - expr: exprs[0].clone(), - } + input: inputs.swap_remove(0), + expr: exprs.swap_remove(0), + }) + } + + fn supports_limit_pushdown(&self) -> bool { + false // Disallow limit push-down by default } } @@ -782,6 +1134,7 @@ impl LogicalExtensionCodec for TopKExtensionCodec { fn try_decode_table_provider( &self, _buf: &[u8], + _table_ref: &TableReference, _schema: SchemaRef, _ctx: &SessionContext, ) -> Result> { @@ -790,6 +1143,7 @@ impl LogicalExtensionCodec for TopKExtensionCodec { fn try_encode_table_provider( &self, + _table_ref: &TableReference, _node: Arc, _buf: &mut Vec, ) -> Result<()> { @@ -798,51 +1152,9 @@ impl LogicalExtensionCodec for TopKExtensionCodec { } #[derive(Debug)] -struct MyRegexUdf { - signature: Signature, - // regex as original string - pattern: String, -} - -impl MyRegexUdf { - fn new(pattern: String) -> Self { - Self { - signature: Signature::uniform( - 1, - vec![DataType::Int32], - Volatility::Immutable, - ), - pattern, - } - } -} - -/// Implement the ScalarUDFImpl trait for MyRegexUdf -impl ScalarUDFImpl for MyRegexUdf { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "regex_udf" - } - fn signature(&self) -> &Signature { - &self.signature - } - fn return_type(&self, args: &[DataType]) -> Result { - if !matches!(args.first(), Some(&DataType::Utf8)) { - return plan_err!("regex_udf only accepts Utf8 arguments"); - } - Ok(DataType::Int32) - } - fn invoke(&self, _args: &[ColumnarValue]) -> Result { - unimplemented!() - } -} - -#[derive(Debug)] -pub struct ScalarUDFExtensionCodec {} +pub struct UDFExtensionCodec; -impl LogicalExtensionCodec for ScalarUDFExtensionCodec { +impl LogicalExtensionCodec for UDFExtensionCodec { fn try_decode( &self, _buf: &[u8], @@ -859,6 +1171,7 @@ impl LogicalExtensionCodec for ScalarUDFExtensionCodec { fn try_decode_table_provider( &self, _buf: &[u8], + _table_ref: &TableReference, _schema: SchemaRef, _ctx: &SessionContext, ) -> Result> { @@ -867,6 +1180,7 @@ impl LogicalExtensionCodec for ScalarUDFExtensionCodec { fn try_encode_table_provider( &self, + _table_ref: &TableReference, _node: Arc, _buf: &mut Vec, ) -> Result<()> { @@ -875,13 +1189,11 @@ impl LogicalExtensionCodec for ScalarUDFExtensionCodec { fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { if name == "regex_udf" { - let proto = proto::MyRegexUdfNode::decode(buf).map_err(|err| { - DataFusionError::Internal(format!("failed to decode regex_udf: {}", err)) + let proto = MyRegexUdfNode::decode(buf).map_err(|err| { + DataFusionError::Internal(format!("failed to decode regex_udf: {err}")) })?; - Ok(Arc::new(ScalarUDF::new_from_impl(MyRegexUdf::new( - proto.pattern, - )))) + Ok(Arc::new(ScalarUDF::from(MyRegexUdf::new(proto.pattern)))) } else { not_impl_err!("unrecognized scalar UDF implementation, cannot decode") } @@ -890,18 +1202,46 @@ impl LogicalExtensionCodec for ScalarUDFExtensionCodec { fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { let binding = node.inner(); let udf = binding.as_any().downcast_ref::().unwrap(); - let proto = proto::MyRegexUdfNode { + let proto = MyRegexUdfNode { pattern: udf.pattern.clone(), }; - proto.encode(buf).map_err(|e| { - DataFusionError::Internal(format!("failed to encode udf: {e:?}")) + proto.encode(buf).map_err(|err| { + DataFusionError::Internal(format!("failed to encode udf: {err}")) + })?; + Ok(()) + } + + fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result> { + if name == "aggregate_udf" { + let proto = MyAggregateUdfNode::decode(buf).map_err(|err| { + DataFusionError::Internal(format!( + "failed to decode aggregate_udf: {err}" + )) + })?; + + Ok(Arc::new(AggregateUDF::from(MyAggregateUDF::new( + proto.result, + )))) + } else { + not_impl_err!("unrecognized aggregate UDF implementation, cannot decode") + } + } + + fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec) -> Result<()> { + let binding = node.inner(); + let udf = binding.as_any().downcast_ref::().unwrap(); + let proto = MyAggregateUdfNode { + result: udf.result.clone(), + }; + proto.encode(buf).map_err(|err| { + DataFusionError::Internal(format!("failed to encode udf: {err}")) })?; Ok(()) } } #[test] -fn round_trip_scalar_values() { +fn round_trip_scalar_values_and_data_types() { let should_pass: Vec = vec![ ScalarValue::Boolean(None), ScalarValue::Float32(None), @@ -916,7 +1256,7 @@ fn round_trip_scalar_values() { ScalarValue::UInt64(None), ScalarValue::Utf8(None), ScalarValue::LargeUtf8(None), - ScalarValue::List(ScalarValue::new_list(&[], &DataType::Boolean)), + ScalarValue::List(ScalarValue::new_list_nullable(&[], &DataType::Boolean)), ScalarValue::LargeList(ScalarValue::new_large_list(&[], &DataType::Boolean)), ScalarValue::Date32(None), ScalarValue::Boolean(Some(true)), @@ -955,6 +1295,8 @@ fn round_trip_scalar_values() { ScalarValue::UInt64(Some(0)), ScalarValue::Utf8(Some(String::from("Test string "))), ScalarValue::LargeUtf8(Some(String::from("Test Large utf8"))), + ScalarValue::Utf8View(Some(String::from("Test stringview"))), + ScalarValue::BinaryView(Some(b"binaryview".to_vec())), ScalarValue::Date32(Some(0)), ScalarValue::Date32(Some(i32::MAX)), ScalarValue::Date32(None), @@ -1008,7 +1350,7 @@ fn round_trip_scalar_values() { i64::MAX, ))), ScalarValue::IntervalMonthDayNano(None), - ScalarValue::List(ScalarValue::new_list( + ScalarValue::List(ScalarValue::new_list_nullable( &[ ScalarValue::Float32(Some(-213.1)), ScalarValue::Float32(None), @@ -1028,10 +1370,13 @@ fn round_trip_scalar_values() { ], &DataType::Float32, )), - ScalarValue::List(ScalarValue::new_list( + ScalarValue::List(ScalarValue::new_list_nullable( &[ - ScalarValue::List(ScalarValue::new_list(&[], &DataType::Float32)), - ScalarValue::List(ScalarValue::new_list( + ScalarValue::List(ScalarValue::new_list_nullable( + &[], + &DataType::Float32, + )), + ScalarValue::List(ScalarValue::new_list_nullable( &[ ScalarValue::Float32(Some(-213.1)), ScalarValue::Float32(None), @@ -1094,31 +1439,141 @@ fn round_trip_scalar_values() { ) .build() .unwrap(), + ScalarStructBuilder::new() + .with_scalar( + Field::new("a", DataType::Int32, true), + ScalarValue::from(23i32), + ) + .with_scalar( + Field::new("b", DataType::Boolean, false), + ScalarValue::from(false), + ) + .with_scalar( + Field::new( + "c", + DataType::Dictionary( + Box::new(DataType::UInt16), + Box::new(DataType::Utf8), + ), + false, + ), + ScalarValue::Dictionary( + Box::new(DataType::UInt16), + Box::new("value".into()), + ), + ) + .build() + .unwrap(), ScalarValue::try_from(&DataType::Struct(Fields::from(vec![ Field::new("a", DataType::Int32, true), Field::new("b", DataType::Boolean, false), ]))) .unwrap(), + ScalarValue::try_from(&DataType::Struct(Fields::from(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Boolean, false), + Field::new( + "c", + DataType::Dictionary( + Box::new(DataType::UInt16), + Box::new(DataType::Binary), + ), + false, + ), + Field::new( + "d", + DataType::new_list( + DataType::Dictionary( + Box::new(DataType::UInt16), + Box::new(DataType::Binary), + ), + false, + ), + false, + ), + ]))) + .unwrap(), + ScalarValue::try_from(&DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Int32, true), + Field::new("value", DataType::Utf8, false), + ])), + false, + )), + false, + )) + .unwrap(), + ScalarValue::try_from(&DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Int32, true), + Field::new("value", DataType::Utf8, true), + ])), + false, + )), + true, + )) + .unwrap(), + ScalarValue::Map(Arc::new(create_map_array_test_case())), ScalarValue::FixedSizeBinary(b"bar".to_vec().len() as i32, Some(b"bar".to_vec())), ScalarValue::FixedSizeBinary(0, None), ScalarValue::FixedSizeBinary(5, None), ]; - for test_case in should_pass.into_iter() { - let proto: protobuf::ScalarValue = (&test_case) - .try_into() - .expect("failed conversion to protobuf"); - + // ScalarValue directly + for test_case in should_pass.iter() { + let proto: protobuf::ScalarValue = + test_case.try_into().expect("failed conversion to protobuf"); let roundtrip: ScalarValue = (&proto) .try_into() .expect("failed conversion from protobuf"); assert_eq!( - test_case, roundtrip, + test_case, &roundtrip, "ScalarValue was not the same after round trip!\n\n\ Input: {test_case:?}\n\nRoundtrip: {roundtrip:?}" ); } + + // DataType conversion + for test_case in should_pass.iter() { + let dt = test_case.data_type(); + + let proto: protobuf::ArrowType = (&dt) + .try_into() + .expect("datatype failed conversion to protobuf"); + let roundtrip: DataType = (&proto) + .try_into() + .expect("datatype failed conversion from protobuf"); + + assert_eq!( + dt, roundtrip, + "DataType was not the same after round trip!\n\n\ + Input: {dt:?}\n\nRoundtrip: {roundtrip:?}" + ); + } +} + +// create a map array [{joe:1}, {blogs:2, foo:4}, {}, null] for testing +fn create_map_array_test_case() -> MapArray { + let string_builder = StringBuilder::new(); + let int_builder = Int32Builder::with_capacity(4); + let mut builder = MapBuilder::new(None, string_builder, int_builder); + builder.keys().append_value("joe"); + builder.values().append_value(1); + builder.append(true).unwrap(); + + builder.keys().append_value("blogs"); + builder.values().append_value(2); + builder.keys().append_value("foo"); + builder.values().append_value(4); + builder.append(true).unwrap(); + builder.append(true).unwrap(); + builder.append(false).unwrap(); + builder.finish() } #[test] @@ -1202,6 +1657,7 @@ fn round_trip_datatype() { DataType::Utf8, DataType::LargeUtf8, DataType::Decimal128(7, 12), + DataType::Decimal256(DECIMAL256_MAX_PRECISION, 0), // Recursive list tests DataType::List(new_arc_field("Level1", DataType::Binary, true)), DataType::List(new_arc_field( @@ -1331,13 +1787,11 @@ fn roundtrip_dict_id() -> Result<()> { // encode let mut buf: Vec = vec![]; - let schema_proto: datafusion_proto::generated::datafusion::Schema = - schema.try_into().unwrap(); + let schema_proto: protobuf::Schema = schema.try_into().unwrap(); schema_proto.encode(&mut buf).unwrap(); // decode - let schema_proto = - datafusion_proto::generated::datafusion::Schema::decode(buf.as_slice()).unwrap(); + let schema_proto = protobuf::Schema::decode(buf.as_slice()).unwrap(); let decoded: Schema = (&schema_proto).try_into()?; // assert @@ -1375,8 +1829,7 @@ fn roundtrip_null_scalar_values() { for test_case in test_types.into_iter() { let proto_scalar: protobuf::ScalarValue = (&test_case).try_into().unwrap(); - let returned_scalar: datafusion::scalar::ScalarValue = - (&proto_scalar).try_into().unwrap(); + let returned_scalar: ScalarValue = (&proto_scalar).try_into().unwrap(); assert_eq!(format!("{:?}", &test_case), format!("{returned_scalar:?}")); } } @@ -1568,14 +2021,6 @@ fn roundtrip_try_cast() { roundtrip_expr_test(test_expr, ctx); } -#[test] -fn roundtrip_sort_expr() { - let test_expr = Expr::Sort(Sort::new(Box::new(lit(1.0_f32)), true, true)); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); -} - #[test] fn roundtrip_negative() { let test_expr = Expr::Negative(Box::new(lit(1.0_f32))); @@ -1608,7 +2053,10 @@ fn roundtrip_unnest() { #[test] fn roundtrip_wildcard() { - let test_expr = Expr::Wildcard { qualifier: None }; + let test_expr = Expr::Wildcard { + qualifier: None, + options: WildcardOptions::default(), + }; let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); @@ -1618,6 +2066,7 @@ fn roundtrip_wildcard() { fn roundtrip_qualified_wildcard() { let test_expr = Expr::Wildcard { qualifier: Some("foo".into()), + options: WildcardOptions::default(), }; let ctx = SessionContext::new(); @@ -1683,43 +2132,18 @@ fn roundtrip_similar_to() { #[test] fn roundtrip_count() { - let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("bananas")], - false, - None, - None, - None, - )); + let test_expr = count(col("bananas")); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); } #[test] fn roundtrip_count_distinct() { - let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::Count, - vec![col("bananas")], - true, - None, - None, - None, - )); - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); -} - -#[test] -fn roundtrip_approx_percentile_cont() { - let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::ApproxPercentileCont, - vec![col("bananas"), lit(0.42_f32)], - false, - None, - None, - None, - )); - + let test_expr = count_udaf() + .call(vec![col("bananas")]) + .distinct() + .build() + .unwrap(); let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx); } @@ -1730,27 +2154,24 @@ fn roundtrip_aggregate_udf() { struct Dummy {} impl Accumulator for Dummy { - fn state(&mut self) -> datafusion::error::Result> { + fn state(&mut self) -> Result> { Ok(vec![]) } - fn update_batch( - &mut self, - _values: &[ArrayRef], - ) -> datafusion::error::Result<()> { + fn update_batch(&mut self, _values: &[ArrayRef]) -> Result<()> { Ok(()) } - fn merge_batch(&mut self, _states: &[ArrayRef]) -> datafusion::error::Result<()> { + fn merge_batch(&mut self, _states: &[ArrayRef]) -> Result<()> { Ok(()) } - fn evaluate(&mut self) -> datafusion::error::Result { + fn evaluate(&mut self) -> Result { Ok(ScalarValue::Float64(None)) } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } @@ -1795,7 +2216,7 @@ fn roundtrip_scalar_udf() { let udf = create_udf( "dummy", vec![DataType::Utf8], - Arc::new(DataType::Utf8), + DataType::Utf8, Volatility::Immutable, scalar_fn, ); @@ -1813,25 +2234,27 @@ fn roundtrip_scalar_udf() { #[test] fn roundtrip_scalar_udf_extension_codec() { - let pattern = ".*"; - let udf = ScalarUDF::from(MyRegexUdf::new(pattern.to_string())); - let test_expr = - Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf.clone()), vec![])); - + let udf = ScalarUDF::from(MyRegexUdf::new(".*".to_owned())); + let test_expr = udf.call(vec!["foo".lit()]); let ctx = SessionContext::new(); - ctx.register_udf(udf); - - let extension_codec = ScalarUDFExtensionCodec {}; - let proto: protobuf::LogicalExprNode = - match serialize_expr(&test_expr, &extension_codec) { - Ok(p) => p, - Err(e) => panic!("Error serializing expression: {:?}", e), - }; - let round_trip: Expr = - from_proto::parse_expr(&proto, &ctx, &extension_codec).unwrap(); + let proto = serialize_expr(&test_expr, &UDFExtensionCodec).expect("serialize expr"); + let round_trip = + from_proto::parse_expr(&proto, &ctx, &UDFExtensionCodec).expect("parse expr"); assert_eq!(format!("{:?}", &test_expr), format!("{round_trip:?}")); + roundtrip_json_test(&proto); +} + +#[test] +fn roundtrip_aggregate_udf_extension_codec() { + let udf = AggregateUDF::from(MyAggregateUDF::new("DataFusion".to_owned())); + let test_expr = udf.call(vec![42.lit()]); + let ctx = SessionContext::new(); + let proto = serialize_expr(&test_expr, &UDFExtensionCodec).expect("serialize expr"); + let round_trip = + from_proto::parse_expr(&proto, &ctx, &UDFExtensionCodec).expect("parse expr"); + assert_eq!(format!("{:?}", &test_expr), format!("{round_trip:?}")); roundtrip_json_test(&proto); } @@ -1896,27 +2319,25 @@ fn roundtrip_window() { // 1. without window_frame let test_expr1 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::BuiltInWindowFunction( - datafusion_expr::BuiltInWindowFunction::Rank, - ), + WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], - vec![col("col1")], - vec![col("col2")], - WindowFrame::new(Some(false)), - None, - )); + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2").sort(true, false)]) + .window_frame(WindowFrame::new(Some(false))) + .build() + .unwrap(); // 2. with default window_frame let test_expr2 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::BuiltInWindowFunction( - datafusion_expr::BuiltInWindowFunction::Rank, - ), + WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], - vec![col("col1")], - vec![col("col2")], - WindowFrame::new(Some(false)), - None, - )); + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2").sort(false, true)]) + .window_frame(WindowFrame::new(Some(false))) + .build() + .unwrap(); // 3. with window_frame with row numbers let range_number_frame = WindowFrame::new_bounds( @@ -1926,15 +2347,14 @@ fn roundtrip_window() { ); let test_expr3 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::BuiltInWindowFunction( - datafusion_expr::BuiltInWindowFunction::Rank, - ), + WindowFunctionDefinition::WindowUDF(rank_udwf()), vec![], - vec![col("col1")], - vec![col("col2")], - range_number_frame, - None, - )); + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2").sort(false, false)]) + .window_frame(range_number_frame) + .build() + .unwrap(); // 4. test with AggregateFunction let row_number_frame = WindowFrame::new_bounds( @@ -1944,40 +2364,38 @@ fn roundtrip_window() { ); let test_expr4 = Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(AggregateFunction::Max), + WindowFunctionDefinition::AggregateUDF(max_udaf()), vec![col("col1")], - vec![col("col1")], - vec![col("col2")], - row_number_frame.clone(), - None, - )); + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2").sort(true, true)]) + .window_frame(row_number_frame.clone()) + .build() + .unwrap(); // 5. test with AggregateUDF #[derive(Debug)] struct DummyAggr {} impl Accumulator for DummyAggr { - fn state(&mut self) -> datafusion::error::Result> { + fn state(&mut self) -> Result> { Ok(vec![]) } - fn update_batch( - &mut self, - _values: &[ArrayRef], - ) -> datafusion::error::Result<()> { + fn update_batch(&mut self, _values: &[ArrayRef]) -> Result<()> { Ok(()) } - fn merge_batch(&mut self, _states: &[ArrayRef]) -> datafusion::error::Result<()> { + fn merge_batch(&mut self, _states: &[ArrayRef]) -> Result<()> { Ok(()) } - fn evaluate(&mut self) -> datafusion::error::Result { + fn evaluate(&mut self) -> Result { Ok(ScalarValue::Float64(None)) } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } @@ -1998,11 +2416,12 @@ fn roundtrip_window() { let test_expr5 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::AggregateUDF(Arc::new(dummy_agg.clone())), vec![col("col1")], - vec![col("col1")], - vec![col("col2")], - row_number_frame.clone(), - None, - )); + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2").sort(true, true)]) + .window_frame(row_number_frame.clone()) + .build() + .unwrap(); ctx.register_udaf(dummy_agg); // 6. test with WindowUDF @@ -2049,19 +2468,23 @@ fn roundtrip_window() { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types.len() != 1 { - return plan_err!( - "dummy_udwf expects 1 argument, got {}: {:?}", - arg_types.len(), - arg_types - ); - } - Ok(arg_types[0].clone()) + fn partition_evaluator( + &self, + _partition_evaluator_args: PartitionEvaluatorArgs, + ) -> Result> { + make_partition_evaluator() } - fn partition_evaluator(&self) -> Result> { - make_partition_evaluator() + fn field(&self, field_args: WindowUDFFieldArgs) -> Result { + if let Some(return_type) = field_args.get_input_type(0) { + Ok(Field::new(field_args.name(), return_type, true)) + } else { + plan_err!( + "dummy_udwf expects 1 argument, got {}: {:?}", + field_args.input_types().len(), + field_args.input_types() + ) + } } } @@ -2074,11 +2497,20 @@ fn roundtrip_window() { let test_expr6 = Expr::WindowFunction(expr::WindowFunction::new( WindowFunctionDefinition::WindowUDF(Arc::new(dummy_window_udf.clone())), vec![col("col1")], + )) + .partition_by(vec![col("col1")]) + .order_by(vec![col("col2").sort(true, true)]) + .window_frame(row_number_frame.clone()) + .build() + .unwrap(); + + let text_expr7 = Expr::WindowFunction(expr::WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(avg_udaf()), vec![col("col1")], - vec![col("col2")], - row_number_frame, - None, - )); + )) + .window_frame(row_number_frame) + .build() + .unwrap(); ctx.register_udwf(dummy_window_udf); @@ -2087,5 +2519,6 @@ fn roundtrip_window() { roundtrip_expr_test(test_expr3, ctx.clone()); roundtrip_expr_test(test_expr4, ctx.clone()); roundtrip_expr_test(test_expr5, ctx.clone()); - roundtrip_expr_test(test_expr6, ctx); + roundtrip_expr_test(test_expr6, ctx.clone()); + roundtrip_expr_test(text_expr7, ctx); } diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 642860d6397b..1e078ee410c6 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -16,13 +16,24 @@ // under the License. use std::any::Any; +use std::fmt::Display; +use std::hash::Hasher; use std::ops::Deref; use std::sync::Arc; use std::vec; +use arrow::array::RecordBatch; use arrow::csv::WriterBuilder; +use arrow::datatypes::{Fields, TimeUnit}; +use datafusion::physical_expr::aggregate::AggregateExprBuilder; +use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; +use datafusion_expr::dml::InsertOp; +use datafusion_functions_aggregate::approx_percentile_cont::approx_percentile_cont_udaf; +use datafusion_functions_aggregate::array_agg::array_agg_udaf; +use datafusion_functions_aggregate::min_max::max_udaf; use prost::Message; +use crate::cases::{MyAggregateUDF, MyAggregateUdfNode, MyRegexUdf, MyRegexUdfNode}; use datafusion::arrow::array::ArrayRef; use datafusion::arrow::compute::kernels::sort::SortOptions; use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema}; @@ -36,18 +47,22 @@ use datafusion::datasource::physical_plan::{ FileSinkConfig, ParquetExec, }; use datafusion::execution::FunctionRegistry; +use datafusion::functions_aggregate::sum::sum_udaf; use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility}; -use datafusion::physical_expr::expressions::{Count, Max, NthValueAgg}; +use datafusion::physical_expr::expressions::Literal; use datafusion::physical_expr::window::SlidingAggregateWindowExpr; -use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr}; +use datafusion::physical_expr::{ + LexOrdering, LexOrderingRef, LexRequirement, PhysicalSortRequirement, + ScalarFunctionExpr, +}; use datafusion::physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, }; use datafusion::physical_plan::analyze::AnalyzeExec; use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::expressions::{ - binary, cast, col, in_list, like, lit, Avg, BinaryExpr, Column, DistinctCount, - NotExpr, NthValue, PhysicalSortExpr, StringAgg, Sum, + binary, cast, col, in_list, like, lit, BinaryExpr, Column, NotExpr, NthValue, + PhysicalSortExpr, }; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::insert::DataSinkExec; @@ -60,12 +75,11 @@ use datafusion::physical_plan::projection::ProjectionExec; use datafusion::physical_plan::repartition::RepartitionExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::union::{InterleaveExec, UnionExec}; +use datafusion::physical_plan::unnest::{ListUnnest, UnnestExec}; use datafusion::physical_plan::windows::{ BuiltInWindowExpr, PlainAggregateWindowExpr, WindowAggExec, }; -use datafusion::physical_plan::{ - udaf, AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, Statistics, -}; +use datafusion::physical_plan::{ExecutionPlan, Partitioning, PhysicalExpr, Statistics}; use datafusion::prelude::SessionContext; use datafusion::scalar::ScalarValue; use datafusion_common::config::TableParquetOptions; @@ -73,12 +87,16 @@ use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; -use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; +use datafusion_common::{ + internal_err, not_impl_err, DataFusionError, Result, UnnestOptions, +}; use datafusion_expr::{ - Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, - ScalarFunctionDefinition, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, - WindowFrame, WindowFrameBound, + Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF, + Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound, }; +use datafusion_functions_aggregate::average::avg_udaf; +use datafusion_functions_aggregate::nth_value::nth_value_udaf; +use datafusion_functions_aggregate::string_agg::string_agg_udaf; use datafusion_proto::physical_plan::{ AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, }; @@ -254,8 +272,7 @@ fn roundtrip_nested_loop_join() -> Result<()> { fn roundtrip_window() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false); let field_b = Field::new("b", DataType::Int64, false); - let field_c = Field::new("FIRST_VALUE(a) PARTITION BY [b] ORDER BY [a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", DataType::Int64, false); - let schema = Arc::new(Schema::new(vec![field_a, field_b, field_c])); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); let window_frame = WindowFrame::new_bounds( datafusion_expr::WindowFrameUnits::Range, @@ -282,13 +299,16 @@ fn roundtrip_window() -> Result<()> { )); let plain_aggr_window_expr = Arc::new(PlainAggregateWindowExpr::new( - Arc::new(Avg::new( - cast(col("b", &schema)?, &schema, DataType::Float64)?, - "AVG(b)".to_string(), - DataType::Float64, - )), - &[], + AggregateExprBuilder::new( + avg_udaf(), + vec![cast(col("b", &schema)?, &schema, DataType::Float64)?], + ) + .schema(Arc::clone(&schema)) + .alias("avg(b)") + .build() + .map(Arc::new)?, &[], + LexOrderingRef::default(), Arc::new(WindowFrame::new(None)), )); @@ -298,14 +318,17 @@ fn roundtrip_window() -> Result<()> { WindowFrameBound::Preceding(ScalarValue::Int64(None)), ); + let args = vec![cast(col("a", &schema)?, &schema, DataType::Float64)?]; + let sum_expr = AggregateExprBuilder::new(sum_udaf(), args) + .schema(Arc::clone(&schema)) + .alias("SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEEDING") + .build() + .map(Arc::new)?; + let sliding_aggr_window_expr = Arc::new(SlidingAggregateWindowExpr::new( - Arc::new(Sum::new( - cast(col("a", &schema)?, &schema, DataType::Float64)?, - "SUM(a) RANGE BETWEEN CURRENT ROW AND UNBOUNDED PRECEEDING", - DataType::Float64, - )), - &[], + sum_expr, &[], + LexOrderingRef::default(), Arc::new(window_frame), )); @@ -331,30 +354,28 @@ fn rountrip_aggregate() -> Result<()> { let groups: Vec<(Arc, String)> = vec![(col("a", &schema)?, "unused".to_string())]; - let test_cases: Vec>> = vec![ + let avg_expr = AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("AVG(b)") + .build()?; + let nth_expr = + AggregateExprBuilder::new(nth_value_udaf(), vec![col("b", &schema)?, lit(1u64)]) + .schema(Arc::clone(&schema)) + .alias("NTH_VALUE(b, 1)") + .build()?; + let str_agg_expr = + AggregateExprBuilder::new(string_agg_udaf(), vec![col("b", &schema)?, lit(1u64)]) + .schema(Arc::clone(&schema)) + .alias("NTH_VALUE(b, 1)") + .build()?; + + let test_cases = vec![ // AVG - vec![Arc::new(Avg::new( - cast(col("b", &schema)?, &schema, DataType::Float64)?, - "AVG(b)".to_string(), - DataType::Float64, - ))], + vec![Arc::new(avg_expr)], // NTH_VALUE - vec![Arc::new(NthValueAgg::new( - col("b", &schema)?, - 1, - "NTH_VALUE(b, 1)".to_string(), - DataType::Int64, - false, - Vec::new(), - Vec::new(), - ))], + vec![Arc::new(nth_expr)], // STRING_AGG - vec![Arc::new(StringAgg::new( - cast(col("b", &schema)?, &schema, DataType::Utf8)?, - lit(ScalarValue::Utf8(Some(",".to_string()))), - "STRING_AGG(name, ',')".to_string(), - DataType::Utf8, - ))], + vec![Arc::new(str_agg_expr)], ]; for aggregates in test_cases { @@ -372,6 +393,102 @@ fn rountrip_aggregate() -> Result<()> { Ok(()) } +#[test] +fn rountrip_aggregate_with_limit() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + + let groups: Vec<(Arc, String)> = + vec![(col("a", &schema)?, "unused".to_string())]; + + let aggregates = + vec![ + AggregateExprBuilder::new(avg_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("AVG(b)") + .build() + .map(Arc::new)?, + ]; + + let agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::new_single(groups.clone()), + aggregates, + vec![None], + Arc::new(EmptyExec::new(schema.clone())), + schema, + )?; + let agg = agg.with_limit(Some(12)); + roundtrip_test(Arc::new(agg)) +} + +#[test] +fn rountrip_aggregate_with_approx_pencentile_cont() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + + let groups: Vec<(Arc, String)> = + vec![(col("a", &schema)?, "unused".to_string())]; + + let aggregates = vec![AggregateExprBuilder::new( + approx_percentile_cont_udaf(), + vec![col("b", &schema)?, lit(0.5)], + ) + .schema(Arc::clone(&schema)) + .alias("APPROX_PERCENTILE_CONT(b, 0.5)") + .build() + .map(Arc::new)?]; + + let agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::new_single(groups.clone()), + aggregates, + vec![None], + Arc::new(EmptyExec::new(schema.clone())), + schema, + )?; + roundtrip_test(Arc::new(agg)) +} + +#[test] +fn rountrip_aggregate_with_sort() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + + let groups: Vec<(Arc, String)> = + vec![(col("a", &schema)?, "unused".to_string())]; + let sort_exprs = LexOrdering::new(vec![PhysicalSortExpr { + expr: col("b", &schema)?, + options: SortOptions { + descending: false, + nulls_first: true, + }, + }]); + + let aggregates = + vec![ + AggregateExprBuilder::new(array_agg_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("ARRAY_AGG(b)") + .order_by(sort_exprs) + .build() + .map(Arc::new)?, + ]; + + let agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::new_single(groups.clone()), + aggregates, + vec![None], + Arc::new(EmptyExec::new(schema.clone())), + schema, + )?; + roundtrip_test(Arc::new(agg)) +} + #[test] fn roundtrip_aggregate_udaf() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false); @@ -419,21 +536,20 @@ fn roundtrip_aggregate_udaf() -> Result<()> { let groups: Vec<(Arc, String)> = vec![(col("a", &schema)?, "unused".to_string())]; - let aggregates: Vec> = vec![udaf::create_aggregate_expr( - &udaf, - &[col("b", &schema)?], - &[], - &[], - &schema, - "example_agg", - false, - )?]; + let aggregates = + vec![ + AggregateExprBuilder::new(Arc::new(udaf), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("example_agg") + .build() + .map(Arc::new)?, + ]; roundtrip_test_with_context( Arc::new(AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::new_single(groups.clone()), - aggregates.clone(), + aggregates, vec![None], Arc::new(EmptyExec::new(schema.clone())), schema, @@ -470,7 +586,7 @@ fn roundtrip_sort() -> Result<()> { let field_a = Field::new("a", DataType::Boolean, false); let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); - let sort_exprs = vec![ + let sort_exprs = LexOrdering::new(vec![ PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions { @@ -485,7 +601,7 @@ fn roundtrip_sort() -> Result<()> { nulls_first: true, }, }, - ]; + ]); roundtrip_test(Arc::new(SortExec::new( sort_exprs, Arc::new(EmptyExec::new(schema)), @@ -497,7 +613,7 @@ fn roundtrip_sort_preserve_partitioning() -> Result<()> { let field_a = Field::new("a", DataType::Boolean, false); let field_b = Field::new("b", DataType::Int64, false); let schema = Arc::new(Schema::new(vec![field_a, field_b])); - let sort_exprs = vec![ + let sort_exprs = LexOrdering::new(vec![ PhysicalSortExpr { expr: col("a", &schema)?, options: SortOptions { @@ -512,7 +628,7 @@ fn roundtrip_sort_preserve_partitioning() -> Result<()> { nulls_first: true, }, }, - ]; + ]); roundtrip_test(Arc::new(SortExec::new( sort_exprs.clone(), @@ -525,6 +641,23 @@ fn roundtrip_sort_preserve_partitioning() -> Result<()> { )) } +#[test] +fn roundtrip_coalesce_with_fetch() -> Result<()> { + let field_a = Field::new("a", DataType::Boolean, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + + roundtrip_test(Arc::new(CoalesceBatchesExec::new( + Arc::new(EmptyExec::new(schema.clone())), + 8096, + )))?; + + roundtrip_test(Arc::new( + CoalesceBatchesExec::new(Arc::new(EmptyExec::new(schema)), 8096) + .with_fetch(Some(10)), + )) +} + #[test] fn roundtrip_parquet_exec_with_pruning_predicate() -> Result<()> { let scan_config = FileScanConfig { @@ -556,12 +689,11 @@ fn roundtrip_parquet_exec_with_pruning_predicate() -> Result<()> { Operator::Eq, lit("1"), )); - roundtrip_test(Arc::new(ParquetExec::new( - scan_config, - Some(predicate), - None, - Default::default(), - ))) + roundtrip_test( + ParquetExec::builder(scan_config) + .with_predicate(predicate) + .build_arc(), + ) } #[tokio::test] @@ -587,12 +719,147 @@ async fn roundtrip_parquet_exec_with_table_partition_cols() -> Result<()> { output_ordering: vec![], }; - roundtrip_test(Arc::new(ParquetExec::new( - scan_config, - None, - None, - Default::default(), - ))) + roundtrip_test(ParquetExec::builder(scan_config).build_arc()) +} + +#[test] +fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { + let scan_config = FileScanConfig { + object_store_url: ObjectStoreUrl::local_filesystem(), + file_schema: Arc::new(Schema::new(vec![Field::new( + "col", + DataType::Utf8, + false, + )])), + file_groups: vec![vec![PartitionedFile::new( + "/path/to/file.parquet".to_string(), + 1024, + )]], + statistics: Statistics { + num_rows: Precision::Inexact(100), + total_byte_size: Precision::Inexact(1024), + column_statistics: Statistics::unknown_column(&Arc::new(Schema::new(vec![ + Field::new("col", DataType::Utf8, false), + ]))), + }, + projection: None, + limit: None, + table_partition_cols: vec![], + output_ordering: vec![], + }; + + #[derive(Debug, Hash, Clone)] + struct CustomPredicateExpr { + inner: Arc, + } + impl Display for CustomPredicateExpr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "CustomPredicateExpr") + } + } + impl PartialEq for CustomPredicateExpr { + fn eq(&self, other: &dyn Any) -> bool { + other + .downcast_ref::() + .map(|x| self.inner.eq(&x.inner)) + .unwrap_or(false) + } + } + impl PhysicalExpr for CustomPredicateExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _input_schema: &Schema) -> Result { + unreachable!() + } + + fn nullable(&self, _input_schema: &Schema) -> Result { + unreachable!() + } + + fn evaluate(&self, _batch: &RecordBatch) -> Result { + unreachable!() + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.inner] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + todo!() + } + + fn dyn_hash(&self, _state: &mut dyn Hasher) { + unreachable!() + } + } + + #[derive(Debug)] + struct CustomPhysicalExtensionCodec; + impl PhysicalExtensionCodec for CustomPhysicalExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[Arc], + _registry: &dyn FunctionRegistry, + ) -> Result> { + unreachable!() + } + + fn try_encode( + &self, + _node: Arc, + _buf: &mut Vec, + ) -> Result<()> { + unreachable!() + } + + fn try_decode_expr( + &self, + buf: &[u8], + inputs: &[Arc], + ) -> Result> { + if buf == "CustomPredicateExpr".as_bytes() { + Ok(Arc::new(CustomPredicateExpr { + inner: inputs[0].clone(), + })) + } else { + internal_err!("Not supported") + } + } + + fn try_encode_expr( + &self, + node: &Arc, + buf: &mut Vec, + ) -> Result<()> { + if node + .as_any() + .downcast_ref::() + .is_some() + { + buf.extend_from_slice("CustomPredicateExpr".as_bytes()); + Ok(()) + } else { + internal_err!("Not supported") + } + } + } + + let custom_predicate_expr = Arc::new(CustomPredicateExpr { + inner: Arc::new(Column::new("col", 1)), + }); + let exec_plan = ParquetExec::builder(scan_config) + .with_predicate(custom_predicate_expr) + .build_arc(); + + let ctx = SessionContext::new(); + roundtrip_test_and_return(exec_plan, &ctx, &CustomPhysicalExtensionCodec {})?; + Ok(()) } #[test] @@ -613,20 +880,18 @@ fn roundtrip_scalar_udf() -> Result<()> { let udf = create_udf( "dummy", vec![DataType::Int64], - Arc::new(DataType::Int64), + DataType::Int64, Volatility::Immutable, scalar_fn.clone(), ); - let fun_def = ScalarFunctionDefinition::UDF(Arc::new(udf.clone())); + let fun_def = Arc::new(udf.clone()); let expr = ScalarFunctionExpr::new( "dummy", fun_def, vec![col("a", &schema)?], DataType::Int64, - None, - false, ); let project = @@ -639,123 +904,95 @@ fn roundtrip_scalar_udf() -> Result<()> { roundtrip_test_with_context(Arc::new(project), &ctx) } -#[test] -fn roundtrip_scalar_udf_extension_codec() -> Result<()> { - #[derive(Debug)] - struct MyRegexUdf { - signature: Signature, - // regex as original string - pattern: String, +#[derive(Debug)] +struct UDFExtensionCodec; + +impl PhysicalExtensionCodec for UDFExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[Arc], + _registry: &dyn FunctionRegistry, + ) -> Result> { + not_impl_err!("No extension codec provided") } - impl MyRegexUdf { - fn new(pattern: String) -> Self { - Self { - signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable), - pattern, - } - } + fn try_encode( + &self, + _node: Arc, + _buf: &mut Vec, + ) -> Result<()> { + not_impl_err!("No extension codec provided") } - /// Implement the ScalarUDFImpl trait for MyRegexUdf - impl ScalarUDFImpl for MyRegexUdf { - fn as_any(&self) -> &dyn Any { - self - } + fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { + if name == "regex_udf" { + let proto = MyRegexUdfNode::decode(buf).map_err(|err| { + DataFusionError::Internal(format!("failed to decode regex_udf: {err}")) + })?; - fn name(&self) -> &str { - "regex_udf" + Ok(Arc::new(ScalarUDF::from(MyRegexUdf::new(proto.pattern)))) + } else { + not_impl_err!("unrecognized scalar UDF implementation, cannot decode") } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, args: &[DataType]) -> Result { - if !matches!(args.first(), Some(&DataType::Utf8)) { - return plan_err!("regex_udf only accepts Utf8 arguments"); - } - Ok(DataType::Int64) - } - - fn invoke(&self, _args: &[ColumnarValue]) -> Result { - unimplemented!() - } - } - - #[derive(Clone, PartialEq, ::prost::Message)] - pub struct MyRegexUdfNode { - #[prost(string, tag = "1")] - pub pattern: String, } - #[derive(Debug)] - pub struct ScalarUDFExtensionCodec {} - - impl PhysicalExtensionCodec for ScalarUDFExtensionCodec { - fn try_decode( - &self, - _buf: &[u8], - _inputs: &[Arc], - _registry: &dyn FunctionRegistry, - ) -> Result> { - not_impl_err!("No extension codec provided") - } - - fn try_encode( - &self, - _node: Arc, - _buf: &mut Vec, - ) -> Result<()> { - not_impl_err!("No extension codec provided") + fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { + let binding = node.inner(); + if let Some(udf) = binding.as_any().downcast_ref::() { + let proto = MyRegexUdfNode { + pattern: udf.pattern.clone(), + }; + proto.encode(buf).map_err(|err| { + DataFusionError::Internal(format!("failed to encode udf: {err}")) + })?; } + Ok(()) + } - fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { - if name == "regex_udf" { - let proto = MyRegexUdfNode::decode(buf).map_err(|err| { - DataFusionError::Internal(format!( - "failed to decode regex_udf: {}", - err - )) - })?; - - Ok(Arc::new(ScalarUDF::new_from_impl(MyRegexUdf::new( - proto.pattern, - )))) - } else { - not_impl_err!("unrecognized scalar UDF implementation, cannot decode") - } + fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result> { + if name == "aggregate_udf" { + let proto = MyAggregateUdfNode::decode(buf).map_err(|err| { + DataFusionError::Internal(format!( + "failed to decode aggregate_udf: {err}" + )) + })?; + + Ok(Arc::new(AggregateUDF::from(MyAggregateUDF::new( + proto.result, + )))) + } else { + not_impl_err!("unrecognized scalar UDF implementation, cannot decode") } + } - fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { - let binding = node.inner(); - if let Some(udf) = binding.as_any().downcast_ref::() { - let proto = MyRegexUdfNode { - pattern: udf.pattern.clone(), - }; - proto.encode(buf).map_err(|e| { - DataFusionError::Internal(format!("failed to encode udf: {e:?}")) - })?; - } - Ok(()) + fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec) -> Result<()> { + let binding = node.inner(); + if let Some(udf) = binding.as_any().downcast_ref::() { + let proto = MyAggregateUdfNode { + result: udf.result.clone(), + }; + proto.encode(buf).map_err(|err| { + DataFusionError::Internal(format!("failed to encode udf: {err:?}")) + })?; } + Ok(()) } +} +#[test] +fn roundtrip_scalar_udf_extension_codec() -> Result<()> { let field_text = Field::new("text", DataType::Utf8, true); let field_published = Field::new("published", DataType::Boolean, false); let field_author = Field::new("author", DataType::Utf8, false); let schema = Arc::new(Schema::new(vec![field_text, field_published, field_author])); let input = Arc::new(EmptyExec::new(schema.clone())); - let pattern = ".*"; - let udf = ScalarUDF::from(MyRegexUdf::new(pattern.to_string())); let udf_expr = Arc::new(ScalarFunctionExpr::new( - udf.name(), - ScalarFunctionDefinition::UDF(Arc::new(udf.clone())), + "regex_udf", + Arc::new(ScalarUDF::from(MyRegexUdf::new(".*".to_string()))), vec![col("text", &schema)?], DataType::Int64, - None, - false, )); let filter = Arc::new(FilterExec::try_new( @@ -766,12 +1003,18 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { )), input, )?); + let aggr_expr = + AggregateExprBuilder::new(max_udaf(), vec![udf_expr as Arc]) + .schema(schema.clone()) + .alias("max") + .build() + .map(Arc::new)?; let window = Arc::new(WindowAggExec::try_new( vec![Arc::new(PlainAggregateWindowExpr::new( - Arc::new(Max::new(udf_expr.clone(), "max", DataType::Int64)), + aggr_expr.clone(), &[col("author", &schema)?], - &[], + LexOrderingRef::default(), Arc::new(WindowFrame::new(None)), ))], filter, @@ -781,41 +1024,84 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { let aggregate = Arc::new(AggregateExec::try_new( AggregateMode::Final, PhysicalGroupBy::new(vec![], vec![], vec![]), - vec![Arc::new(Count::new(udf_expr, "count", DataType::Int64))], + vec![aggr_expr], vec![None], window, - schema.clone(), + schema, )?); let ctx = SessionContext::new(); - let codec = ScalarUDFExtensionCodec {}; - roundtrip_test_and_return(aggregate, &ctx, &codec)?; + roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec)?; Ok(()) } #[test] -fn roundtrip_distinct_count() -> Result<()> { - let field_a = Field::new("a", DataType::Int64, false); - let field_b = Field::new("b", DataType::Int64, false); - let schema = Arc::new(Schema::new(vec![field_a, field_b])); +fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { + let field_text = Field::new("text", DataType::Utf8, true); + let field_published = Field::new("published", DataType::Boolean, false); + let field_author = Field::new("author", DataType::Utf8, false); + let schema = Arc::new(Schema::new(vec![field_text, field_published, field_author])); + let input = Arc::new(EmptyExec::new(schema.clone())); - let aggregates: Vec> = vec![Arc::new(DistinctCount::new( + let udf_expr = Arc::new(ScalarFunctionExpr::new( + "regex_udf", + Arc::new(ScalarUDF::from(MyRegexUdf::new(".*".to_string()))), + vec![col("text", &schema)?], DataType::Int64, - col("b", &schema)?, - "COUNT(DISTINCT b)".to_string(), - ))]; + )); - let groups: Vec<(Arc, String)> = - vec![(col("a", &schema)?, "unused".to_string())]; + let udaf = Arc::new(AggregateUDF::from(MyAggregateUDF::new( + "result".to_string(), + ))); + let aggr_args: Vec> = + vec![Arc::new(Literal::new(ScalarValue::from(42)))]; + + let aggr_expr = AggregateExprBuilder::new(Arc::clone(&udaf), aggr_args.clone()) + .schema(Arc::clone(&schema)) + .alias("aggregate_udf") + .build() + .map(Arc::new)?; - roundtrip_test(Arc::new(AggregateExec::try_new( + let filter = Arc::new(FilterExec::try_new( + Arc::new(BinaryExpr::new( + col("published", &schema)?, + Operator::And, + Arc::new(BinaryExpr::new(udf_expr, Operator::Gt, lit(0))), + )), + input, + )?); + + let window = Arc::new(WindowAggExec::try_new( + vec![Arc::new(PlainAggregateWindowExpr::new( + aggr_expr, + &[col("author", &schema)?], + LexOrderingRef::default(), + Arc::new(WindowFrame::new(None)), + ))], + filter, + vec![col("author", &schema)?], + )?); + + let aggr_expr = AggregateExprBuilder::new(udaf, aggr_args.clone()) + .schema(Arc::clone(&schema)) + .alias("aggregate_udf") + .distinct() + .ignore_nulls() + .build() + .map(Arc::new)?; + + let aggregate = Arc::new(AggregateExec::try_new( AggregateMode::Final, - PhysicalGroupBy::new_single(groups), - aggregates.clone(), + PhysicalGroupBy::new(vec![], vec![], vec![]), + vec![aggr_expr], vec![None], - Arc::new(EmptyExec::new(schema.clone())), + window, schema, - )?)) + )?); + + let ctx = SessionContext::new(); + roundtrip_test_and_return(aggregate, &ctx, &UDFExtensionCodec)?; + Ok(()) } #[test] @@ -867,24 +1153,25 @@ fn roundtrip_json_sink() -> Result<()> { table_paths: vec![ListingTableUrl::parse("file:///")?], output_schema: schema.clone(), table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)], - overwrite: true, + insert_op: InsertOp::Overwrite, + keep_partition_by_columns: true, }; let data_sink = Arc::new(JsonSink::new( file_sink_config, JsonWriterOptions::new(CompressionTypeVariant::UNCOMPRESSED), )); - let sort_order = vec![PhysicalSortRequirement::new( + let sort_order = LexRequirement::new(vec![PhysicalSortRequirement::new( Arc::new(Column::new("plan_type", 0)), Some(SortOptions { descending: true, nulls_first: false, }), - )]; + )]); roundtrip_test(Arc::new(DataSinkExec::new( input, data_sink, - schema.clone(), + schema, Some(sort_order), ))) } @@ -902,19 +1189,20 @@ fn roundtrip_csv_sink() -> Result<()> { table_paths: vec![ListingTableUrl::parse("file:///")?], output_schema: schema.clone(), table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)], - overwrite: true, + insert_op: InsertOp::Overwrite, + keep_partition_by_columns: true, }; let data_sink = Arc::new(CsvSink::new( file_sink_config, CsvWriterOptions::new(WriterBuilder::default(), CompressionTypeVariant::ZSTD), )); - let sort_order = vec![PhysicalSortRequirement::new( + let sort_order = LexRequirement::new(vec![PhysicalSortRequirement::new( Arc::new(Column::new("plan_type", 0)), Some(SortOptions { descending: true, nulls_first: false, }), - )]; + )]); let ctx = SessionContext::new(); let codec = DefaultPhysicalExtensionCodec {}; @@ -922,7 +1210,7 @@ fn roundtrip_csv_sink() -> Result<()> { Arc::new(DataSinkExec::new( input, data_sink, - schema.clone(), + schema, Some(sort_order), )), &ctx, @@ -960,24 +1248,25 @@ fn roundtrip_parquet_sink() -> Result<()> { table_paths: vec![ListingTableUrl::parse("file:///")?], output_schema: schema.clone(), table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)], - overwrite: true, + insert_op: InsertOp::Overwrite, + keep_partition_by_columns: true, }; let data_sink = Arc::new(ParquetSink::new( file_sink_config, TableParquetOptions::default(), )); - let sort_order = vec![PhysicalSortRequirement::new( + let sort_order = LexRequirement::new(vec![PhysicalSortRequirement::new( Arc::new(Column::new("plan_type", 0)), Some(SortOptions { descending: true, nulls_first: false, }), - )]; + )]); roundtrip_test(Arc::new(DataSinkExec::new( input, data_sink, - schema.clone(), + schema, Some(sort_order), ))) } @@ -1010,17 +1299,17 @@ fn roundtrip_sym_hash_join() -> Result<()> { ] { for left_order in &[ None, - Some(vec![PhysicalSortExpr { + Some(LexOrdering::new(vec![PhysicalSortExpr { expr: Arc::new(Column::new("col", schema_left.index_of("col")?)), options: Default::default(), - }]), + }])), ] { for right_order in &[ None, - Some(vec![PhysicalSortExpr { + Some(LexOrdering::new(vec![PhysicalSortExpr { expr: Arc::new(Column::new("col", schema_right.index_of("col")?)), options: Default::default(), - }]), + }])), ] { roundtrip_test(Arc::new( datafusion::physical_plan::joins::SymmetricHashJoinExec::try_new( @@ -1066,9 +1355,55 @@ fn roundtrip_interleave() -> Result<()> { )?; let right = RepartitionExec::try_new( Arc::new(EmptyExec::new(Arc::new(schema_right))), - partition.clone(), + partition, )?; let inputs: Vec> = vec![Arc::new(left), Arc::new(right)]; let interleave = InterleaveExec::try_new(inputs)?; roundtrip_test(Arc::new(interleave)) } + +#[test] +fn roundtrip_unnest() -> Result<()> { + let fa = Field::new("a", DataType::Int64, true); + let fb0 = Field::new_list_field(DataType::Utf8, true); + let fb = Field::new_list("b", fb0.clone(), false); + let fc1 = Field::new("c1", DataType::Boolean, false); + let fc2 = Field::new("c2", DataType::Date64, true); + let fc = Field::new_struct("c", Fields::from(vec![fc1.clone(), fc2.clone()]), true); + let fd0 = Field::new_list_field(DataType::Float32, false); + let fd = Field::new_list("d", fd0.clone(), true); + let fe1 = Field::new("e1", DataType::UInt16, false); + let fe2 = Field::new("e2", DataType::Duration(TimeUnit::Millisecond), true); + let fe3 = Field::new("e3", DataType::Timestamp(TimeUnit::Millisecond, None), true); + let fe_fields = Fields::from(vec![fe1.clone(), fe2.clone(), fe3.clone()]); + let fe = Field::new_struct("e", fe_fields, false); + + let fb0 = fb0.with_name("b"); + let fd0 = fd0.with_name("d"); + let input_schema = Arc::new(Schema::new(vec![fa.clone(), fb, fc, fd, fe])); + let output_schema = + Arc::new(Schema::new(vec![fa, fb0, fc1, fc2, fd0, fe1, fe2, fe3])); + let input = Arc::new(EmptyExec::new(input_schema)); + let options = UnnestOptions::default(); + let unnest = UnnestExec::new( + input, + vec![ + ListUnnest { + index_in_input_schema: 1, + depth: 1, + }, + ListUnnest { + index_in_input_schema: 1, + depth: 2, + }, + ListUnnest { + index_in_input_schema: 3, + depth: 2, + }, + ], + vec![2, 4], + output_schema, + options, + ); + roundtrip_test(Arc::new(unnest)) +} diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs index cc683e778ebc..d1b50105d053 100644 --- a/datafusion/proto/tests/cases/serialize.rs +++ b/datafusion/proto/tests/cases/serialize.rs @@ -238,7 +238,7 @@ fn context_with_udf() -> SessionContext { let udf = create_udf( "dummy", vec![DataType::Utf8], - Arc::new(DataType::Utf8), + DataType::Utf8, Volatility::Immutable, scalar_fn, ); @@ -276,7 +276,7 @@ fn test_expression_serialization_roundtrip() { /// Extracts the first part of a function name /// 'foo(bar)' -> 'foo' fn extract_function_name(expr: &Expr) -> String { - let name = expr.display_name().unwrap(); + let name = expr.schema_name().to_string(); name.split('(').next().unwrap().to_string() } } diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index ef3ed265c7ab..1eef1b718ba6 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -46,13 +46,18 @@ arrow-array = { workspace = true } arrow-schema = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } +indexmap = { workspace = true } log = { workspace = true } +regex = { workspace = true } sqlparser = { workspace = true } strum = { version = "0.26.1", features = ["derive"] } [dev-dependencies] ctor = { workspace = true } datafusion-functions = { workspace = true, default-features = true } +datafusion-functions-aggregate = { workspace = true } +datafusion-functions-nested = { workspace = true } +datafusion-functions-window = { workspace = true } env_logger = { workspace = true } paste = "^1.0" rstest = { workspace = true } diff --git a/datafusion/sql/examples/sql.rs b/datafusion/sql/examples/sql.rs index 5bab2f19cfc0..aa17be6273ae 100644 --- a/datafusion/sql/examples/sql.rs +++ b/datafusion/sql/examples/sql.rs @@ -15,26 +15,32 @@ // specific language governing permissions and limitations // under the License. +use std::{collections::HashMap, sync::Arc}; + use arrow_schema::{DataType, Field, Schema}; + use datafusion_common::config::ConfigOptions; use datafusion_common::{plan_err, Result}; +use datafusion_expr::planner::ExprPlanner; use datafusion_expr::WindowUDF; use datafusion_expr::{ logical_plan::builder::LogicalTableSource, AggregateUDF, ScalarUDF, TableSource, }; +use datafusion_functions::core::planner::CoreFunctionPlanner; +use datafusion_functions_aggregate::count::count_udaf; +use datafusion_functions_aggregate::sum::sum_udaf; use datafusion_sql::{ planner::{ContextProvider, SqlToRel}, sqlparser::{dialect::GenericDialect, parser::Parser}, TableReference, }; -use std::{collections::HashMap, sync::Arc}; fn main() { let sql = "SELECT \ c.id, c.first_name, c.last_name, \ COUNT(*) as num_orders, \ - SUM(o.price) AS total_price, \ - SUM(o.price * s.sales_tax) AS state_tax \ + sum(o.price) AS total_price, \ + sum(o.price * s.sales_tax) AS state_tax \ FROM customer c \ JOIN state s ON c.state = s.id \ JOIN orders o ON c.id = o.customer_id \ @@ -49,20 +55,35 @@ fn main() { let statement = &ast[0]; // create a logical query plan - let context_provider = MyContextProvider::new(); + let context_provider = MyContextProvider::new() + .with_udaf(sum_udaf()) + .with_udaf(count_udaf()) + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); let sql_to_rel = SqlToRel::new(&context_provider); let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap(); // show the plan - println!("{plan:?}"); + println!("{plan}"); } struct MyContextProvider { options: ConfigOptions, tables: HashMap>, + udafs: HashMap>, + expr_planners: Vec>, } impl MyContextProvider { + fn with_udaf(mut self, udaf: Arc) -> Self { + self.udafs.insert(udaf.name().to_string(), udaf); + self + } + + fn with_expr_planner(mut self, planner: Arc) -> Self { + self.expr_planners.push(planner); + self + } + fn new() -> Self { let mut tables = HashMap::new(); tables.insert( @@ -94,6 +115,8 @@ impl MyContextProvider { Self { tables, options: Default::default(), + udafs: Default::default(), + expr_planners: vec![], } } } @@ -107,7 +130,7 @@ fn create_table_source(fields: Vec) -> Arc { impl ContextProvider for MyContextProvider { fn get_table_source(&self, name: TableReference) -> Result> { match self.tables.get(name.table()) { - Some(table) => Ok(table.clone()), + Some(table) => Ok(Arc::clone(table)), _ => plan_err!("Table not found: {}", name.table()), } } @@ -116,8 +139,8 @@ impl ContextProvider for MyContextProvider { None } - fn get_aggregate_meta(&self, _name: &str) -> Option> { - None + fn get_aggregate_meta(&self, name: &str) -> Option> { + self.udafs.get(name).cloned() } fn get_variable_type(&self, _variable_names: &[String]) -> Option { @@ -132,15 +155,19 @@ impl ContextProvider for MyContextProvider { &self.options } - fn udfs_names(&self) -> Vec { + fn udf_names(&self) -> Vec { Vec::new() } - fn udafs_names(&self) -> Vec { + fn udaf_names(&self) -> Vec { Vec::new() } - fn udwfs_names(&self) -> Vec { + fn udwf_names(&self) -> Vec { Vec::new() } + + fn get_expr_planners(&self) -> &[Arc] { + &self.expr_planners + } } diff --git a/datafusion/sql/src/cte.rs b/datafusion/sql/src/cte.rs index 4f7b9bb6d11d..c288d6ca7067 100644 --- a/datafusion/sql/src/cte.rs +++ b/datafusion/sql/src/cte.rs @@ -38,7 +38,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Process CTEs from top to bottom for cte in with.cte_tables { // A `WITH` block can't use the same name more than once - let cte_name = self.normalizer.normalize(cte.alias.name.clone()); + let cte_name = self.ident_normalizer.normalize(cte.alias.name.clone()); if planner_context.contains_cte(&cte_name) { return plan_err!( "WITH query name {cte_name:?} specified more than once" @@ -66,10 +66,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { cte_query: Query, planner_context: &mut PlannerContext, ) -> Result { - // CTE expr don't need extend outer_query_schema, - // so we clone a new planner_context here. - let mut cte_planner_context = planner_context.clone(); - self.query_to_plan(cte_query, &mut cte_planner_context) + self.query_to_plan(cte_query, planner_context) } fn recursive_cte( @@ -101,8 +98,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } }; - // Each recursive CTE consists from two parts in the logical plan: - // 1. A static term (the left hand side on the SQL, where the + // Each recursive CTE consists of two parts in the logical plan: + // 1. A static term (the left-hand side on the SQL, where the // referencing to the same CTE is not allowed) // // 2. A recursive term (the right hand side, and the recursive @@ -113,8 +110,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // allow us to infer the schema to be used in the recursive term. // ---------- Step 1: Compile the static term ------------------ - let static_plan = - self.set_expr_to_plan(*left_expr, &mut planner_context.clone())?; + let static_plan = self.set_expr_to_plan(*left_expr, planner_context)?; // Since the recursive CTEs include a component that references a // table with its name, like the example below: @@ -148,7 +144,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // as the input to the recursive term let work_table_plan = LogicalPlanBuilder::scan( cte_name.to_string(), - work_table_source.clone(), + Arc::clone(&work_table_source), None, )? .build()?; @@ -166,8 +162,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // this uses the named_relation we inserted above to resolve the // relation. This ensures that the recursive term uses the named relation logical plan // and thus the 'continuance' physical plan as its input and source - let recursive_plan = - self.set_expr_to_plan(*right_expr, &mut planner_context.clone())?; + let recursive_plan = self.set_expr_to_plan(*right_expr, planner_context)?; // Check if the recursive term references the CTE itself, // if not, it is a non-recursive CTE diff --git a/datafusion/sql/src/expr/binary_op.rs b/datafusion/sql/src/expr/binary_op.rs index 0d37742e5b07..fcb57e8a82e4 100644 --- a/datafusion/sql/src/expr/binary_op.rs +++ b/datafusion/sql/src/expr/binary_op.rs @@ -51,6 +51,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { BinaryOperator::PGBitwiseShiftRight => Ok(Operator::BitwiseShiftRight), BinaryOperator::PGBitwiseShiftLeft => Ok(Operator::BitwiseShiftLeft), BinaryOperator::StringConcat => Ok(Operator::StringConcat), + BinaryOperator::ArrowAt => Ok(Operator::ArrowAt), + BinaryOperator::AtArrow => Ok(Operator::AtArrow), _ => not_impl_err!("Unsupported SQL binary operator {op:?}"), } } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 68cba15634d5..619eadcf0fb8 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -16,67 +16,183 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; + use arrow_schema::DataType; use datafusion_common::{ - internal_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema, - Dependency, Result, + internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, + DFSchema, Dependency, Result, }; -use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; +use datafusion_expr::expr::WildcardOptions; +use datafusion_expr::planner::PlannerResult; use datafusion_expr::{ - expr, AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFunctionDefinition, + expr, Expr, ExprFunctionExt, ExprSchemable, WindowFrame, WindowFunctionDefinition, }; use datafusion_expr::{ expr::{ScalarFunction, Unnest}, BuiltInWindowFunction, }; use sqlparser::ast::{ - Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, WindowType, + DuplicateTreatment, Expr as SQLExpr, Function as SQLFunction, FunctionArg, + FunctionArgExpr, FunctionArgumentClause, FunctionArgumentList, FunctionArguments, + NullTreatment, ObjectName, OrderByExpr, WindowType, }; -use std::str::FromStr; use strum::IntoEnumIterator; /// Suggest a valid function based on an invalid input function name +/// +/// Returns `None` if no valid matches are found. This happens when there are no +/// functions registered with the context. pub fn suggest_valid_function( input_function_name: &str, is_window_func: bool, ctx: &dyn ContextProvider, -) -> String { +) -> Option { let valid_funcs = if is_window_func { // All aggregate functions and builtin window functions let mut funcs = Vec::new(); - funcs.extend(AggregateFunction::iter().map(|func| func.to_string())); - funcs.extend(ctx.udafs_names()); + funcs.extend(ctx.udaf_names()); funcs.extend(BuiltInWindowFunction::iter().map(|func| func.to_string())); - funcs.extend(ctx.udwfs_names()); + funcs.extend(ctx.udwf_names()); funcs } else { // All scalar functions and aggregate functions let mut funcs = Vec::new(); - funcs.extend(ctx.udfs_names()); - funcs.extend(AggregateFunction::iter().map(|func| func.to_string())); - funcs.extend(ctx.udafs_names()); + funcs.extend(ctx.udf_names()); + funcs.extend(ctx.udaf_names()); funcs }; find_closest_match(valid_funcs, input_function_name) } -/// Find the closest matching string to the target string in the candidates list, using edit distance(case insensitve) -/// Input `candidates` must not be empty otherwise it will panic -fn find_closest_match(candidates: Vec, target: &str) -> String { +/// Find the closest matching string to the target string in the candidates list, using edit distance(case insensitive) +/// Input `candidates` must not be empty otherwise an error is returned. +fn find_closest_match(candidates: Vec, target: &str) -> Option { let target = target.to_lowercase(); - candidates - .into_iter() - .min_by_key(|candidate| { - datafusion_common::utils::datafusion_strsim::levenshtein( - &candidate.to_lowercase(), - &target, - ) + candidates.into_iter().min_by_key(|candidate| { + datafusion_common::utils::datafusion_strsim::levenshtein( + &candidate.to_lowercase(), + &target, + ) + }) +} + +/// Arguments to for a function call extracted from the SQL AST +#[derive(Debug)] +struct FunctionArgs { + /// Function name + name: ObjectName, + /// Argument expressions + args: Vec, + /// ORDER BY clause, if any + order_by: Vec, + /// OVER clause, if any + over: Option, + /// FILTER clause, if any + filter: Option>, + /// NULL treatment clause, if any + null_treatment: Option, + /// DISTINCT + distinct: bool, +} + +impl FunctionArgs { + fn try_new(function: SQLFunction) -> Result { + let SQLFunction { + name, + args, + over, + filter, + mut null_treatment, + within_group, + .. + } = function; + + // Handle no argument form (aka `current_time` as opposed to `current_time()`) + let FunctionArguments::List(args) = args else { + return Ok(Self { + name, + args: vec![], + order_by: vec![], + over, + filter, + null_treatment, + distinct: false, + }); + }; + + let FunctionArgumentList { + duplicate_treatment, + args, + clauses, + } = args; + + let distinct = match duplicate_treatment { + Some(DuplicateTreatment::Distinct) => true, + Some(DuplicateTreatment::All) => false, + None => false, + }; + + // Pull out argument handling + let mut order_by = None; + for clause in clauses { + match clause { + FunctionArgumentClause::IgnoreOrRespectNulls(nt) => { + if null_treatment.is_some() { + return not_impl_err!( + "Calling {name}: Duplicated null treatment clause" + ); + } + null_treatment = Some(nt); + } + FunctionArgumentClause::OrderBy(oby) => { + if order_by.is_some() { + return not_impl_err!("Calling {name}: Duplicated ORDER BY clause in function arguments"); + } + order_by = Some(oby); + } + FunctionArgumentClause::Limit(limit) => { + return not_impl_err!( + "Calling {name}: LIMIT not supported in function arguments: {limit}" + ) + } + FunctionArgumentClause::OnOverflow(overflow) => { + return not_impl_err!( + "Calling {name}: ON OVERFLOW not supported in function arguments: {overflow}" + ) + } + FunctionArgumentClause::Having(having) => { + return not_impl_err!( + "Calling {name}: HAVING not supported in function arguments: {having}" + ) + } + FunctionArgumentClause::Separator(sep) => { + return not_impl_err!( + "Calling {name}: SEPARATOR not supported in function arguments: {sep}" + ) + } + } + } + + if !within_group.is_empty() { + return not_impl_err!("WITHIN GROUP is not supported yet: {within_group:?}"); + } + + let order_by = order_by.unwrap_or_default(); + + Ok(Self { + name, + args, + order_by, + over, + filter, + null_treatment, + distinct, }) - .expect("No candidates provided.") // Panic if `candidates` argument is empty + } } impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -86,16 +202,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let SQLFunction { + let function_args = FunctionArgs::try_new(function)?; + let FunctionArgs { name, args, + order_by, over, - distinct, filter, null_treatment, - special: _, // true if not called with trailing parens - order_by, - } = function; + distinct, + } = function_args; // If function is a window function (it has an OVER clause), // it shouldn't have ordering requirement as function argument @@ -110,7 +226,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { crate::utils::normalize_ident(name.0[0].clone()) }; - // user-defined function (UDF) should have precedence + if name.eq("make_map") { + let mut fn_args = + self.function_args_to_expr(args.clone(), schema, planner_context)?; + for planner in self.context_provider.get_expr_planners().iter() { + match planner.plan_make_map(fn_args)? { + PlannerResult::Planned(expr) => return Ok(expr), + PlannerResult::Original(args) => fn_args = args, + } + } + } + + // User-defined function (UDF) should have precedence if let Some(fm) = self.context_provider.get_function_meta(&name) { let args = self.function_args_to_expr(args, schema, planner_context)?; return Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fm, args))); @@ -118,8 +245,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Build Unnest expression if name.eq("unnest") { - let mut exprs = - self.function_args_to_expr(args.clone(), schema, planner_context)?; + let mut exprs = self.function_args_to_expr(args, schema, planner_context)?; if exprs.len() != 1 { return plan_err!("unnest() requires exactly one argument"); } @@ -134,39 +260,37 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ); } - // then, window function + // Then, window function if let Some(WindowType::WindowSpec(window)) = over { let partition_by = window .partition_by .into_iter() - // ignore window spec PARTITION BY for scalar values + // Ignore window spec PARTITION BY for scalar values // as they do not change and thus do not generate new partitions .filter(|e| !matches!(e, sqlparser::ast::Expr::Value { .. },)) .map(|e| self.sql_expr_to_logical_expr(e, schema, planner_context)) .collect::>>()?; let mut order_by = self.order_by_to_sort_expr( - &window.order_by, + window.order_by, schema, planner_context, // Numeric literals in window function ORDER BY are treated as constants false, + None, )?; let func_deps = schema.functional_dependencies(); // Find whether ties are possible in the given ordering let is_ordering_strict = order_by.iter().find_map(|orderby_expr| { - if let Expr::Sort(sort_expr) = orderby_expr { - if let Expr::Column(col) = sort_expr.expr.as_ref() { - let idx = schema.index_of_column(col).ok()?; - return if func_deps.iter().any(|dep| { - dep.source_indices == vec![idx] - && dep.mode == Dependency::Single - }) { - Some(true) - } else { - Some(false) - }; - } + if let Expr::Column(col) = &orderby_expr.expr { + let idx = schema.index_of_column(col).ok()?; + return if func_deps.iter().any(|dep| { + dep.source_indices == vec![idx] && dep.mode == Dependency::Single + }) { + Some(true) + } else { + Some(false) + }; } Some(false) }); @@ -175,14 +299,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .window_frame .as_ref() .map(|window_frame| { - let window_frame = window_frame.clone().try_into()?; - check_window_frame(&window_frame, order_by.len()) + let window_frame: WindowFrame = window_frame.clone().try_into()?; + window_frame + .regularize_order_bys(&mut order_by) .map(|_| window_frame) }) .transpose()?; let window_frame = if let Some(window_frame) = window_frame { - regularize_window_order_by(&window_frame, &mut order_by)?; window_frame } else if let Some(is_ordering_strict) = is_ordering_strict { WindowFrame::new(Some(is_ordering_strict)) @@ -191,75 +315,51 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }; if let Ok(fun) = self.find_window_func(&name) { - let expr = match fun { - WindowFunctionDefinition::AggregateFunction(aggregate_fun) => { - let args = - self.function_args_to_expr(args, schema, planner_context)?; - - Expr::WindowFunction(expr::WindowFunction::new( - WindowFunctionDefinition::AggregateFunction(aggregate_fun), - args, - partition_by, - order_by, - window_frame, - null_treatment, - )) - } - _ => Expr::WindowFunction(expr::WindowFunction::new( - fun, - self.function_args_to_expr(args, schema, planner_context)?, - partition_by, - order_by, - window_frame, - null_treatment, - )), - }; - return Ok(expr); + return Expr::WindowFunction(expr::WindowFunction::new( + fun, + self.function_args_to_expr(args, schema, planner_context)?, + )) + .partition_by(partition_by) + .order_by(order_by) + .window_frame(window_frame) + .null_treatment(null_treatment) + .build(); } } else { // User defined aggregate functions (UDAF) have precedence in case it has the same name as a scalar built-in function if let Some(fm) = self.context_provider.get_aggregate_meta(&name) { - let order_by = - self.order_by_to_sort_expr(&order_by, schema, planner_context, true)?; - let order_by = (!order_by.is_empty()).then_some(order_by); - let args = self.function_args_to_expr(args, schema, planner_context)?; - // TODO: Support filter and distinct for UDAFs - return Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( - fm, - args, - false, - None, + let order_by = self.order_by_to_sort_expr( order_by, - null_treatment, - ))); - } - - // next, aggregate built-ins - if let Ok(fun) = AggregateFunction::from_str(&name) { - let order_by = - self.order_by_to_sort_expr(&order_by, schema, planner_context, true)?; + schema, + planner_context, + true, + None, + )?; let order_by = (!order_by.is_empty()).then_some(order_by); let args = self.function_args_to_expr(args, schema, planner_context)?; let filter: Option> = filter .map(|e| self.sql_expr_to_logical_expr(*e, schema, planner_context)) .transpose()? .map(Box::new); - - return Ok(Expr::AggregateFunction(expr::AggregateFunction::new( - fun, + return Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( + fm, args, distinct, filter, order_by, null_treatment, ))); - }; + } } // Could not find the relevant function, so return an error - let suggested_func_name = - suggest_valid_function(&name, is_function_window, self.context_provider); - plan_err!("Invalid function '{name}'.\nDid you mean '{suggested_func_name}'?") + if let Some(suggested_func_name) = + suggest_valid_function(&name, is_function_window, self.context_provider) + { + plan_err!("Invalid function '{name}'.\nDid you mean '{suggested_func_name}'?") + } else { + internal_err!("No functions registered with this context.") + } } pub(super) fn sql_fn_name_to_expr( @@ -283,22 +383,26 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &self, name: &str, ) -> Result { - expr::find_df_window_func(name) - // next check user defined aggregates - .or_else(|| { - self.context_provider - .get_aggregate_meta(name) - .map(WindowFunctionDefinition::AggregateUDF) - }) - // next check user defined window functions - .or_else(|| { - self.context_provider - .get_window_meta(name) - .map(WindowFunctionDefinition::WindowUDF) - }) - .ok_or_else(|| { - plan_datafusion_err!("There is no window function named {name}") - }) + // Check udaf first + let udaf = self.context_provider.get_aggregate_meta(name); + // Use the builtin window function instead of the user-defined aggregate function + if udaf.as_ref().is_some_and(|udaf| { + udaf.name() != "first_value" + && udaf.name() != "last_value" + && udaf.name() != "nth_value" + }) { + Ok(WindowFunctionDefinition::AggregateUDF(udaf.unwrap())) + } else { + expr::find_df_window_func(name) + .or_else(|| { + self.context_provider + .get_window_meta(name) + .map(WindowFunctionDefinition::WindowUDF) + }) + .ok_or_else(|| { + plan_datafusion_err!("There is no window function named {name}") + }) + } } fn sql_fn_arg_to_logical_expr( @@ -317,12 +421,28 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { name: _, arg: FunctionArgExpr::Wildcard, operator: _, - } => Ok(Expr::Wildcard { qualifier: None }), + } => Ok(Expr::Wildcard { + qualifier: None, + options: WildcardOptions::default(), + }), FunctionArg::Unnamed(FunctionArgExpr::Expr(arg)) => { self.sql_expr_to_logical_expr(arg, schema, planner_context) } - FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => { - Ok(Expr::Wildcard { qualifier: None }) + FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => Ok(Expr::Wildcard { + qualifier: None, + options: WildcardOptions::default(), + }), + FunctionArg::Unnamed(FunctionArgExpr::QualifiedWildcard(object_name)) => { + let qualifier = self.object_name_to_table_reference(object_name)?; + // Sanity check on qualifier with schema + let qualified_indices = schema.fields_indices_with_qualified(&qualifier); + if qualified_indices.is_empty() { + return plan_err!("Invalid qualifier {qualifier}"); + } + Ok(Expr::Wildcard { + qualifier: Some(qualifier), + options: WildcardOptions::default(), + }) } _ => not_impl_err!("Unsupported qualified wildcard argument: {sql:?}"), } @@ -344,10 +464,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { match arg.get_type(schema)? { DataType::List(_) | DataType::LargeList(_) - | DataType::FixedSizeList(_, _) => Ok(()), - DataType::Struct(_) => { - not_impl_err!("unnest() does not support struct yet") - } + | DataType::FixedSizeList(_, _) + | DataType::Struct(_) => Ok(()), DataType::Null => { not_impl_err!("unnest() does not support null yet") } diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index 713ad6f72c24..e103f68fc927 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -15,14 +15,18 @@ // specific language governing permissions and limitations // under the License. -use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow_schema::Field; +use sqlparser::ast::{Expr as SQLExpr, Ident}; + use datafusion_common::{ - internal_err, plan_datafusion_err, Column, DFSchema, DataFusionError, Result, - TableReference, + internal_err, not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema, + DataFusionError, Result, TableReference, }; +use datafusion_expr::planner::PlannerResult; use datafusion_expr::{Case, Expr}; -use sqlparser::ast::{Expr as SQLExpr, Ident}; + +use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; +use datafusion_expr::UNNAMED_TABLE; impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(super) fn sql_identifier_to_expr( @@ -46,41 +50,36 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // interpret names with '.' as if they were // compound identifiers, but this is not a compound // identifier. (e.g. it is "foo.bar" not foo.bar) - let normalize_ident = self.normalizer.normalize(id); - match schema.field_with_unqualified_name(normalize_ident.as_str()) { - Ok(_) => { - // found a match without a qualified name, this is a inner table column - Ok(Expr::Column(Column { - relation: None, - name: normalize_ident, - })) - } - Err(_) => { - // check the outer_query_schema and try to find a match - if let Some(outer) = planner_context.outer_query_schema() { - match outer.qualified_field_with_unqualified_name( - normalize_ident.as_str(), - ) { - Ok((qualifier, field)) => { - // found an exact match on a qualified name in the outer plan schema, so this is an outer reference column - Ok(Expr::OuterReferenceColumn( - field.data_type().clone(), - Column::from((qualifier, field)), - )) - } - Err(_) => Ok(Expr::Column(Column { - relation: None, - name: normalize_ident, - })), - } - } else { - Ok(Expr::Column(Column { - relation: None, - name: normalize_ident, - })) - } + let normalize_ident = self.ident_normalizer.normalize(id); + + // Check for qualified field with unqualified name + if let Ok((qualifier, _)) = + schema.qualified_field_with_unqualified_name(normalize_ident.as_str()) + { + return Ok(Expr::Column(Column { + relation: qualifier.filter(|q| q.table() != UNNAMED_TABLE).cloned(), + name: normalize_ident, + })); + } + + // Check the outer query schema + if let Some(outer) = planner_context.outer_query_schema() { + if let Ok((qualifier, field)) = + outer.qualified_field_with_unqualified_name(normalize_ident.as_str()) + { + // Found an exact match on a qualified name in the outer plan schema, so this is an outer reference column + return Ok(Expr::OuterReferenceColumn( + field.data_type().clone(), + Column::from((qualifier, field)), + )); } } + + // Default case + Ok(Expr::Column(Column { + relation: None, + name: normalize_ident, + })) } } @@ -97,7 +96,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if ids[0].value.starts_with('@') { let var_names: Vec<_> = ids .into_iter() - .map(|id| self.normalizer.normalize(id)) + .map(|id| self.ident_normalizer.normalize(id)) .collect(); let ty = self .context_provider @@ -111,63 +110,63 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } else { let ids = ids .into_iter() - .map(|id| self.normalizer.normalize(id)) + .map(|id| self.ident_normalizer.normalize(id)) .collect::>(); - // Currently not supporting more than one nested level - // Though ideally once that support is in place, this code should work with it - // TODO: remove when can support multiple nested identifiers - if ids.len() > 5 { - return internal_err!("Unsupported compound identifier: {ids:?}"); - } - let search_result = search_dfschema(&ids, schema); match search_result { - // found matching field with spare identifier(s) for nested field(s) in structure + // Found matching field with spare identifier(s) for nested field(s) in structure Some((field, qualifier, nested_names)) if !nested_names.is_empty() => { - // TODO: remove when can support multiple nested identifiers - if nested_names.len() > 1 { - return internal_err!( - "Nested identifiers not yet supported for column {}", - Column::from((qualifier, field)).quoted_flat_name() - ); + // Found matching field with spare identifier(s) for nested field(s) in structure + for planner in self.context_provider.get_expr_planners() { + if let Ok(planner_result) = planner.plan_compound_identifier( + field, + qualifier, + nested_names, + ) { + match planner_result { + PlannerResult::Planned(expr) => { + return Ok(expr); + } + PlannerResult::Original(_args) => {} + } + } } - let nested_name = nested_names[0].to_string(); - Ok(Expr::Column(Column::from((qualifier, field))).field(nested_name)) + plan_err!("could not parse compound identifier from {ids:?}") } - // found matching field with no spare identifier(s) + // Found matching field with no spare identifier(s) Some((field, qualifier, _nested_names)) => { Ok(Expr::Column(Column::from((qualifier, field)))) } None => { - // return default where use all identifiers to not have a nested field + // Return default where use all identifiers to not have a nested field // this len check is because at 5 identifiers will have to have a nested field if ids.len() == 5 { - internal_err!("Unsupported compound identifier: {ids:?}") + not_impl_err!("compound identifier: {ids:?}") } else { - // check the outer_query_schema and try to find a match + // Check the outer_query_schema and try to find a match if let Some(outer) = planner_context.outer_query_schema() { let search_result = search_dfschema(&ids, outer); match search_result { - // found matching field with spare identifier(s) for nested field(s) in structure + // Found matching field with spare identifier(s) for nested field(s) in structure Some((field, qualifier, nested_names)) if !nested_names.is_empty() => { // TODO: remove when can support nested identifiers for OuterReferenceColumn - internal_err!( + not_impl_err!( "Nested identifiers are not yet supported for OuterReferenceColumn {}", Column::from((qualifier, field)).quoted_flat_name() ) } - // found matching field with no spare identifier(s) + // Found matching field with no spare identifier(s) Some((field, qualifier, _nested_names)) => { - // found an exact match on a qualified name in the outer plan schema, so this is an outer reference column + // Found an exact match on a qualified name in the outer plan schema, so this is an outer reference column Ok(Expr::OuterReferenceColumn( field.data_type().clone(), Column::from((qualifier, field)), )) } - // found no matching field, will return a default + // Found no matching field, will return a default None => { let s = &ids[0..ids.len()]; // safe unwrap as s can never be empty or exceed the bounds @@ -178,7 +177,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } else { let s = &ids[0..ids.len()]; - // safe unwrap as s can never be empty or exceed the bounds + // Safe unwrap as s can never be empty or exceed the bounds let (relation, column_name) = form_identifier(s).unwrap(); Ok(Expr::Column(Column::new(relation, column_name))) } @@ -312,15 +311,15 @@ fn search_dfschema<'ids, 'schema>( fn generate_schema_search_terms( ids: &[String], ) -> impl Iterator, &String, &[String])> { - // take at most 4 identifiers to form a Column to search with + // Take at most 4 identifiers to form a Column to search with // - 1 for the column name // - 0 to 3 for the TableReference let bound = ids.len().min(4); - // search terms from most specific to least specific + // Search terms from most specific to least specific (0..bound).rev().map(|i| { let nested_names_index = i + 1; let qualifier_and_column = &ids[0..nested_names_index]; - // safe unwrap as qualifier_and_column can never be empty or exceed the bounds + // Safe unwrap as qualifier_and_column can never be empty or exceed the bounds let (relation, column_name) = form_identifier(qualifier_and_column).unwrap(); (relation, column_name, &ids[nested_names_index..]) }) @@ -332,7 +331,7 @@ mod test { #[test] // testing according to documentation of generate_schema_search_terms function - // where ensure generated search terms are in correct order with correct values + // where it ensures generated search terms are in correct order with correct values fn test_generate_schema_search_terms() -> Result<()> { type ExpectedItem = ( Option, diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 0d1db8a29cce..432e8668c52e 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -17,19 +17,23 @@ use arrow_schema::DataType; use arrow_schema::TimeUnit; -use sqlparser::ast::{ArrayAgg, Expr as SQLExpr, JsonOperator, TrimWhereField, Value}; -use sqlparser::parser::ParserError::ParserError; +use datafusion_expr::planner::{ + PlannerResult, RawBinaryExpr, RawDictionaryExpr, RawFieldAccessExpr, +}; +use sqlparser::ast::{ + BinaryOperator, CastFormat, CastKind, DataType as SQLDataType, DictionaryField, + Expr as SQLExpr, MapEntry, StructField, Subscript, TrimWhereField, Value, +}; use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, Result, ScalarValue, }; -use datafusion_expr::expr::AggregateFunctionDefinition; -use datafusion_expr::expr::InList; use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::expr::{InList, WildcardOptions}; use datafusion_expr::{ - col, expr, lit, AggregateFunction, Between, BinaryExpr, Cast, Expr, ExprSchemable, - GetFieldAccess, GetIndexedField, Like, Literal, Operator, TryCast, + lit, Between, BinaryExpr, Cast, Expr, ExprSchemable, GetFieldAccess, Like, Literal, + Operator, TryCast, }; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; @@ -38,7 +42,6 @@ mod binary_op; mod function; mod grouping_set; mod identifier; -mod json_access; mod order_by; mod subquery; mod substring; @@ -54,7 +57,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { enum StackEntry { SQLExpr(Box), - Operator(Operator), + Operator(BinaryOperator), } // Virtual stack machine to convert SQLExpr to Expr @@ -71,17 +74,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::BinaryOp { left, op, right } => { // Note the order that we push the entries to the stack // is important. We want to visit the left node first. - let op = self.parse_sql_binary_op(op)?; - stack.push(StackEntry::Operator(op)); - stack.push(StackEntry::SQLExpr(right)); - stack.push(StackEntry::SQLExpr(left)); - } - SQLExpr::JsonAccess { - left, - operator, - right, - } => { - let op = self.parse_sql_json_access(operator)?; stack.push(StackEntry::Operator(op)); stack.push(StackEntry::SQLExpr(right)); stack.push(StackEntry::SQLExpr(left)); @@ -99,13 +91,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { StackEntry::Operator(op) => { let right = eval_stack.pop().unwrap(); let left = eval_stack.pop().unwrap(); - - let expr = Expr::BinaryExpr(BinaryExpr::new( - Box::new(left), - op, - Box::new(right), - )); - + let expr = self.build_logical_expr(op, left, right, schema)?; eval_stack.push(expr); } } @@ -116,6 +102,34 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(expr) } + fn build_logical_expr( + &self, + op: BinaryOperator, + left: Expr, + right: Expr, + schema: &DFSchema, + ) -> Result { + // try extension planers + let mut binary_expr = RawBinaryExpr { op, left, right }; + for planner in self.context_provider.get_expr_planners() { + match planner.plan_binary_op(binary_expr, schema)? { + PlannerResult::Planned(expr) => { + return Ok(expr); + } + PlannerResult::Original(expr) => { + binary_expr = expr; + } + } + } + + let RawBinaryExpr { op, left, right } = binary_expr; + Ok(Expr::BinaryExpr(BinaryExpr::new( + Box::new(left), + self.parse_sql_binary_op(op)?, + Box::new(right), + ))) + } + /// Generate a relational expression from a SQL expression pub fn sql_to_expr( &self, @@ -160,92 +174,45 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { + // NOTE: This function is called recusively, so each match arm body should be as + // small as possible to avoid stack overflows in debug builds. Follow the + // common pattern of extracting into a separate function for non-trivial + // arms. See https://github.com/apache/datafusion/pull/12384 for more context. match sql { SQLExpr::Value(value) => { self.parse_value(value, planner_context.prepare_param_data_types()) } - SQLExpr::Extract { field, expr } => { - let date_part = self - .context_provider - .get_function_meta("date_part") - .ok_or_else(|| { - internal_datafusion_err!( - "Unable to find expected 'date_part' function" - ) - })?; - let args = vec![ + SQLExpr::Extract { field, expr, .. } => { + let mut extract_args = vec![ Expr::Literal(ScalarValue::from(format!("{field}"))), self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ]; - Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - date_part, args, - ))) + + for planner in self.context_provider.get_expr_planners() { + match planner.plan_extract(extract_args)? { + PlannerResult::Planned(expr) => return Ok(expr), + PlannerResult::Original(args) => { + extract_args = args; + } + } + } + + not_impl_err!("Extract not supported by ExprPlanner: {extract_args:?}") } SQLExpr::Array(arr) => self.sql_array_literal(arr.elem, schema), - SQLExpr::Interval(interval) => { - self.sql_interval_to_expr(false, interval, schema, planner_context) - } + SQLExpr::Interval(interval) => self.sql_interval_to_expr(false, interval), SQLExpr::Identifier(id) => { self.sql_identifier_to_expr(id, schema, planner_context) } - SQLExpr::MapAccess { column, keys } => { - if let SQLExpr::Identifier(id) = *column { - let keys = keys.into_iter().map(|mak| mak.key).collect(); - self.plan_indexed( - col(self.normalizer.normalize(id)), - keys, - schema, - planner_context, - ) - } else { - not_impl_err!( - "map access requires an identifier, found column {column} instead" - ) - } + SQLExpr::MapAccess { .. } => { + not_impl_err!("Map Access") } - SQLExpr::ArrayIndex { obj, indexes } => { - fn is_unsupported(expr: &SQLExpr) -> bool { - matches!(expr, SQLExpr::JsonAccess { .. }) - } - fn simplify_array_index_expr(expr: Expr, index: Expr) -> (Expr, bool) { - match &expr { - Expr::AggregateFunction(agg_func) if agg_func.func_def == datafusion_expr::expr::AggregateFunctionDefinition::BuiltIn(AggregateFunction::ArrayAgg) => { - let mut new_args = agg_func.args.clone(); - new_args.push(index.clone()); - (Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new( - datafusion_expr::AggregateFunction::NthValue, - new_args, - agg_func.distinct, - agg_func.filter.clone(), - agg_func.order_by.clone(), - agg_func.null_treatment, - )), true) - }, - _ => (expr, false), - } - } - let expr = - self.sql_expr_to_logical_expr(*obj, schema, planner_context)?; - if indexes.len() > 1 || is_unsupported(&indexes[0]) { - return self.plan_indexed(expr, indexes, schema, planner_context); - } - let (new_expr, changed) = simplify_array_index_expr( - expr, - self.sql_expr_to_logical_expr( - indexes[0].clone(), - schema, - planner_context, - )?, - ); - - if changed { - Ok(new_expr) - } else { - self.plan_indexed(new_expr, indexes, schema, planner_context) - } + // ["foo"], [4] or [4:5] + SQLExpr::Subscript { expr, subscript } => { + self.sql_subscript_to_expr(*expr, subscript, schema, planner_context) } SQLExpr::CompoundIdentifier(ids) => { @@ -267,36 +234,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ), SQLExpr::Cast { + kind: CastKind::Cast | CastKind::DoubleColon, expr, data_type, format, - } => { - if let Some(format) = format { - return not_impl_err!("CAST with format is not supported: {format}"); - } - - let dt = self.convert_data_type(&data_type)?; - let expr = - self.sql_expr_to_logical_expr(*expr, schema, planner_context)?; - - // numeric constants are treated as seconds (rather as nanoseconds) - // to align with postgres / duckdb semantics - let expr = match &dt { - DataType::Timestamp(TimeUnit::Nanosecond, tz) - if expr.get_type(schema)? == DataType::Int64 => - { - Expr::Cast(Cast::new( - Box::new(expr), - DataType::Timestamp(TimeUnit::Second, tz.clone()), - )) - } - _ => expr, - }; + } => self.sql_cast_to_expr(*expr, data_type, format, schema, planner_context), - Ok(Expr::Cast(Cast::new(Box::new(expr), dt))) - } - - SQLExpr::TryCast { + SQLExpr::Cast { + kind: CastKind::TryCast | CastKind::SafeCast, expr, data_type, format, @@ -467,7 +412,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { expr, substring_from, substring_for, - special: false, + special: _, } => self.sql_substring_to_expr( expr, substring_from, @@ -497,10 +442,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context, ), - SQLExpr::AggregateExpressionWithFilter { expr, filter } => { - self.sql_agg_with_filter_to_expr(*expr, *filter, schema, planner_context) - } - SQLExpr::Function(function) => { self.sql_function_to_expr(function, schema, planner_context) } @@ -552,12 +493,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.parse_scalar_subquery(*subquery, schema, planner_context) } - SQLExpr::ArrayAgg(array_agg) => { - self.parse_array_agg(array_agg, schema, planner_context) - } - SQLExpr::Struct { values, fields } => { - self.parse_struct(values, fields, schema, planner_context) + self.parse_struct(schema, planner_context, values, fields) } SQLExpr::Position { expr, r#in } => { self.sql_position_to_expr(*expr, *r#in, schema, planner_context) @@ -571,43 +508,196 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema, planner_context, )?), - DataType::Timestamp(TimeUnit::Nanosecond, Some(time_zone.into())), + match *time_zone { + SQLExpr::Value(Value::SingleQuotedString(s)) => { + DataType::Timestamp(TimeUnit::Nanosecond, Some(s.into())) + } + _ => { + return not_impl_err!( + "Unsupported ast node in sqltorel: {time_zone:?}" + ) + } + }, ))), + SQLExpr::Dictionary(fields) => { + self.try_plan_dictionary_literal(fields, schema, planner_context) + } + SQLExpr::Map(map) => { + self.try_plan_map_literal(map.entries, schema, planner_context) + } + SQLExpr::AnyOp { + left, + compare_op, + right, + } => { + let mut binary_expr = RawBinaryExpr { + op: compare_op, + left: self.sql_expr_to_logical_expr( + *left, + schema, + planner_context, + )?, + right: self.sql_expr_to_logical_expr( + *right, + schema, + planner_context, + )?, + }; + for planner in self.context_provider.get_expr_planners() { + match planner.plan_any(binary_expr)? { + PlannerResult::Planned(expr) => { + return Ok(expr); + } + PlannerResult::Original(expr) => { + binary_expr = expr; + } + } + } + not_impl_err!("AnyOp not supported by ExprPlanner: {binary_expr:?}") + } + SQLExpr::Wildcard => Ok(Expr::Wildcard { + qualifier: None, + options: WildcardOptions::default(), + }), + SQLExpr::QualifiedWildcard(object_name) => Ok(Expr::Wildcard { + qualifier: Some(self.object_name_to_table_reference(object_name)?), + options: WildcardOptions::default(), + }), + SQLExpr::Tuple(values) => self.parse_tuple(schema, planner_context, values), _ => not_impl_err!("Unsupported ast node in sqltorel: {sql:?}"), } } - /// Parses a struct(..) expression + /// Parses a struct(..) expression and plans it creation fn parse_struct( &self, - values: Vec, - fields: Vec, - input_schema: &DFSchema, + schema: &DFSchema, planner_context: &mut PlannerContext, + values: Vec, + fields: Vec, ) -> Result { if !fields.is_empty() { return not_impl_err!("Struct fields are not supported yet"); } - - if values + let is_named_struct = values .iter() - .any(|value| matches!(value, SQLExpr::Named { .. })) - { - self.create_named_struct(values, input_schema, planner_context) + .any(|value| matches!(value, SQLExpr::Named { .. })); + + let mut create_struct_args = if is_named_struct { + self.create_named_struct_expr(values, schema, planner_context)? } else { - self.create_struct(values, input_schema, planner_context) + self.create_struct_expr(values, schema, planner_context)? + }; + + for planner in self.context_provider.get_expr_planners() { + match planner.plan_struct_literal(create_struct_args, is_named_struct)? { + PlannerResult::Planned(expr) => return Ok(expr), + PlannerResult::Original(args) => create_struct_args = args, + } + } + not_impl_err!("Struct not supported by ExprPlanner: {create_struct_args:?}") + } + + fn parse_tuple( + &self, + schema: &DFSchema, + planner_context: &mut PlannerContext, + values: Vec, + ) -> Result { + match values.first() { + Some(SQLExpr::Identifier(_)) | Some(SQLExpr::Value(_)) => { + self.parse_struct(schema, planner_context, values, vec![]) + } + None => not_impl_err!("Empty tuple not supported yet"), + _ => { + not_impl_err!("Only identifiers and literals are supported in tuples") + } + } + } + + fn sql_position_to_expr( + &self, + substr_expr: SQLExpr, + str_expr: SQLExpr, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + let substr = + self.sql_expr_to_logical_expr(substr_expr, schema, planner_context)?; + let fullstr = self.sql_expr_to_logical_expr(str_expr, schema, planner_context)?; + let mut position_args = vec![fullstr, substr]; + for planner in self.context_provider.get_expr_planners() { + match planner.plan_position(position_args)? { + PlannerResult::Planned(expr) => return Ok(expr), + PlannerResult::Original(args) => { + position_args = args; + } + } + } + + not_impl_err!("Position not supported by ExprPlanner: {position_args:?}") + } + + fn try_plan_dictionary_literal( + &self, + fields: Vec, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + let mut keys = vec![]; + let mut values = vec![]; + for field in fields { + let key = lit(field.key.value); + let value = + self.sql_expr_to_logical_expr(*field.value, schema, planner_context)?; + keys.push(key); + values.push(value); + } + + let mut raw_expr = RawDictionaryExpr { keys, values }; + + for planner in self.context_provider.get_expr_planners() { + match planner.plan_dictionary_literal(raw_expr, schema)? { + PlannerResult::Planned(expr) => { + return Ok(expr); + } + PlannerResult::Original(expr) => raw_expr = expr, + } + } + not_impl_err!("Dictionary not supported by ExprPlanner: {raw_expr:?}") + } + + fn try_plan_map_literal( + &self, + entries: Vec, + schema: &DFSchema, + planner_context: &mut PlannerContext, + ) -> Result { + let mut exprs: Vec<_> = entries + .into_iter() + .flat_map(|entry| vec![entry.key, entry.value].into_iter()) + .map(|expr| self.sql_expr_to_logical_expr(*expr, schema, planner_context)) + .collect::>>()?; + for planner in self.context_provider.get_expr_planners() { + match planner.plan_make_map(exprs)? { + PlannerResult::Planned(expr) => { + return Ok(expr); + } + PlannerResult::Original(expr) => exprs = expr, + } } + not_impl_err!("MAP not supported by ExprPlanner: {exprs:?}") } // Handles a call to struct(...) where the arguments are named. For example // `struct (v as foo, v2 as bar)` by creating a call to the `named_struct` function - fn create_named_struct( + fn create_named_struct_expr( &self, values: Vec, input_schema: &DFSchema, planner_context: &mut PlannerContext, - ) -> Result { - let args = values + ) -> Result> { + Ok(values .into_iter() .enumerate() .map(|(i, value)| { @@ -636,95 +726,24 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .collect::>>()? .into_iter() .flatten() - .collect(); - - let named_struct_func = self - .context_provider - .get_function_meta("named_struct") - .ok_or_else(|| { - internal_datafusion_err!("Unable to find expected 'named_struct' function") - })?; - - Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - named_struct_func, - args, - ))) + .collect()) } // Handles a call to struct(...) where the arguments are not named. For example // `struct (v, v2)` by creating a call to the `struct` function // which will create a struct with fields named `c0`, `c1`, etc. - fn create_struct( + fn create_struct_expr( &self, values: Vec, input_schema: &DFSchema, planner_context: &mut PlannerContext, - ) -> Result { - let args = values + ) -> Result> { + values .into_iter() .map(|value| { self.sql_expr_to_logical_expr(value, input_schema, planner_context) }) - .collect::>>()?; - let struct_func = self - .context_provider - .get_function_meta("struct") - .ok_or_else(|| { - internal_datafusion_err!("Unable to find expected 'struct' function") - })?; - - Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - struct_func, - args, - ))) - } - - fn parse_array_agg( - &self, - array_agg: ArrayAgg, - input_schema: &DFSchema, - planner_context: &mut PlannerContext, - ) -> Result { - // Some dialects have special syntax for array_agg. DataFusion only supports it like a function. - let ArrayAgg { - distinct, - expr, - order_by, - limit, - within_group, - } = array_agg; - let order_by = if let Some(order_by) = order_by { - Some(self.order_by_to_sort_expr( - &order_by, - input_schema, - planner_context, - true, - )?) - } else { - None - }; - - if let Some(limit) = limit { - return not_impl_err!("LIMIT not supported in ARRAY_AGG: {limit}"); - } - - if within_group { - return not_impl_err!("WITHIN GROUP not supported in ARRAY_AGG"); - } - - let args = - vec![self.sql_expr_to_logical_expr(*expr, input_schema, planner_context)?]; - - // next, aggregate built-ins - Ok(Expr::AggregateFunction(expr::AggregateFunction::new( - AggregateFunction::ArrayAgg, - args, - distinct, - None, - order_by, - None, - ))) - // see if we can rewrite it into NTH-VALUE + .collect::>>() } fn sql_in_list_to_expr( @@ -753,7 +772,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { negated: bool, expr: SQLExpr, pattern: SQLExpr, - escape_char: Option, + escape_char: Option, schema: &DFSchema, planner_context: &mut PlannerContext, case_insensitive: bool, @@ -763,6 +782,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if pattern_type != DataType::Utf8 && pattern_type != DataType::Null { return plan_err!("Invalid pattern in LIKE expression"); } + let escape_char = if let Some(char) = escape_char { + if char.len() != 1 { + return plan_err!("Invalid escape character in LIKE expression"); + } + Some(char.chars().next().unwrap()) + } else { + None + }; Ok(Expr::Like(Like::new( negated, Box::new(self.sql_expr_to_logical_expr(expr, schema, planner_context)?), @@ -777,7 +804,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { negated: bool, expr: SQLExpr, pattern: SQLExpr, - escape_char: Option, + escape_char: Option, schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { @@ -786,6 +813,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if pattern_type != DataType::Utf8 && pattern_type != DataType::Null { return plan_err!("Invalid pattern in SIMILAR TO expression"); } + let escape_char = if let Some(char) = escape_char { + if char.len() != 1 { + return plan_err!("Invalid escape character in SIMILAR TO expression"); + } + Some(char.chars().next().unwrap()) + } else { + None + }; Ok(Expr::SimilarTo(Like::new( negated, Box::new(self.sql_expr_to_logical_expr(expr, schema, planner_context)?), @@ -854,18 +889,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let fun = self - .context_provider - .get_function_meta("overlay") - .ok_or_else(|| { - internal_datafusion_err!("Unable to find expected 'overlay' function") - })?; let arg = self.sql_expr_to_logical_expr(expr, schema, planner_context)?; let what_arg = self.sql_expr_to_logical_expr(overlay_what, schema, planner_context)?; let from_arg = self.sql_expr_to_logical_expr(overlay_from, schema, planner_context)?; - let args = match overlay_for { + let mut overlay_args = match overlay_for { Some(for_expr) => { let for_expr = self.sql_expr_to_logical_expr(*for_expr, schema, planner_context)?; @@ -873,155 +902,125 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } None => vec![arg, what_arg, from_arg], }; - Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) - } - fn sql_position_to_expr( - &self, - substr_expr: SQLExpr, - str_expr: SQLExpr, - schema: &DFSchema, - planner_context: &mut PlannerContext, - ) -> Result { - let fun = self - .context_provider - .get_function_meta("strpos") - .ok_or_else(|| { - internal_datafusion_err!("Unable to find expected 'strpos' function") - })?; - let substr = - self.sql_expr_to_logical_expr(substr_expr, schema, planner_context)?; - let fullstr = self.sql_expr_to_logical_expr(str_expr, schema, planner_context)?; - let args = vec![fullstr, substr]; - Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) + for planner in self.context_provider.get_expr_planners() { + match planner.plan_overlay(overlay_args)? { + PlannerResult::Planned(expr) => return Ok(expr), + PlannerResult::Original(args) => overlay_args = args, + } + } + not_impl_err!("Overlay not supported by ExprPlanner: {overlay_args:?}") } - fn sql_agg_with_filter_to_expr( + + fn sql_cast_to_expr( &self, expr: SQLExpr, - filter: SQLExpr, + data_type: SQLDataType, + format: Option, schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - match self.sql_expr_to_logical_expr(expr, schema, planner_context)? { - Expr::AggregateFunction(expr::AggregateFunction { - func_def: AggregateFunctionDefinition::BuiltIn(fun), - args, - distinct, - order_by, - null_treatment, - filter: None, // filter is passed in - }) => Ok(Expr::AggregateFunction(expr::AggregateFunction::new( - fun, - args, - distinct, - Some(Box::new(self.sql_expr_to_logical_expr( - filter, - schema, - planner_context, - )?)), - order_by, - null_treatment, - ))), - Expr::AggregateFunction(..) => { - internal_err!("Expected null filter clause in aggregate function") - } - _ => internal_err!( - "AggregateExpressionWithFilter expression was not an AggregateFunction" - ), + if let Some(format) = format { + return not_impl_err!("CAST with format is not supported: {format}"); } + + let dt = self.convert_data_type(&data_type)?; + let expr = self.sql_expr_to_logical_expr(expr, schema, planner_context)?; + + // numeric constants are treated as seconds (rather as nanoseconds) + // to align with postgres / duckdb semantics + let expr = match &dt { + DataType::Timestamp(TimeUnit::Nanosecond, tz) + if expr.get_type(schema)? == DataType::Int64 => + { + Expr::Cast(Cast::new( + Box::new(expr), + DataType::Timestamp(TimeUnit::Second, tz.clone()), + )) + } + _ => expr, + }; + + Ok(Expr::Cast(Cast::new(Box::new(expr), dt))) } - fn plan_indices( + fn sql_subscript_to_expr( &self, expr: SQLExpr, + subscript: Box, schema: &DFSchema, planner_context: &mut PlannerContext, - ) -> Result { - let field = match expr.clone() { - SQLExpr::Value( - Value::SingleQuotedString(s) | Value::DoubleQuotedString(s), - ) => GetFieldAccess::NamedStructField { - name: ScalarValue::from(s), - }, - SQLExpr::JsonAccess { - left, - operator: JsonOperator::Colon, - right, + ) -> Result { + let expr = self.sql_expr_to_logical_expr(expr, schema, planner_context)?; + + let field_access = match *subscript { + Subscript::Index { index } => { + // index can be a name, in which case it is a named field access + match index { + SQLExpr::Value( + Value::SingleQuotedString(s) | Value::DoubleQuotedString(s), + ) => GetFieldAccess::NamedStructField { + name: ScalarValue::from(s), + }, + SQLExpr::JsonAccess { .. } => { + return not_impl_err!("JsonAccess"); + } + // otherwise treat like a list index + _ => GetFieldAccess::ListIndex { + key: Box::new(self.sql_expr_to_logical_expr( + index, + schema, + planner_context, + )?), + }, + } + } + Subscript::Slice { + lower_bound, + upper_bound, + stride, } => { - let (start, stop, stride) = if let SQLExpr::JsonAccess { - left: l, - operator: JsonOperator::Colon, - right: r, - } = *left - { - let start = Box::new(self.sql_expr_to_logical_expr( - *l, - schema, - planner_context, - )?); - let stop = Box::new(self.sql_expr_to_logical_expr( - *r, - schema, - planner_context, - )?); - let stride = Box::new(self.sql_expr_to_logical_expr( - *right, - schema, - planner_context, - )?); - (start, stop, stride) + // Means access like [:2] + let lower_bound = if let Some(lower_bound) = lower_bound { + self.sql_expr_to_logical_expr(lower_bound, schema, planner_context) } else { - let start = Box::new(self.sql_expr_to_logical_expr( - *left, - schema, - planner_context, - )?); - let stop = Box::new(self.sql_expr_to_logical_expr( - *right, - schema, - planner_context, - )?); - let stride = Box::new(Expr::Literal(ScalarValue::Int64(Some(1)))); - (start, stop, stride) + not_impl_err!("Slice subscript requires a lower bound") + }?; + + // means access like [2:] + let upper_bound = if let Some(upper_bound) = upper_bound { + self.sql_expr_to_logical_expr(upper_bound, schema, planner_context) + } else { + not_impl_err!("Slice subscript requires an upper bound") + }?; + + // stride, default to 1 + let stride = if let Some(stride) = stride { + self.sql_expr_to_logical_expr(stride, schema, planner_context)? + } else { + lit(1i64) }; + GetFieldAccess::ListRange { - start, - stop, - stride, + start: Box::new(lower_bound), + stop: Box::new(upper_bound), + stride: Box::new(stride), } } - _ => GetFieldAccess::ListIndex { - key: Box::new(self.sql_expr_to_logical_expr( - expr, - schema, - planner_context, - )?), - }, }; - Ok(field) - } - - fn plan_indexed( - &self, - expr: Expr, - mut keys: Vec, - schema: &DFSchema, - planner_context: &mut PlannerContext, - ) -> Result { - let indices = keys.pop().ok_or_else(|| { - ParserError("Internal error: Missing index key expression".to_string()) - })?; - - let expr = if !keys.is_empty() { - self.plan_indexed(expr, keys, schema, planner_context)? - } else { - expr - }; + let mut field_access_expr = RawFieldAccessExpr { expr, field_access }; + for planner in self.context_provider.get_expr_planners() { + match planner.plan_field_access(field_access_expr, schema)? { + PlannerResult::Planned(expr) => return Ok(expr), + PlannerResult::Original(expr) => { + field_access_expr = expr; + } + } + } - Ok(Expr::GetIndexedField(GetIndexedField::new( - Box::new(expr), - self.plan_indices(indices, schema, planner_context)?, - ))) + not_impl_err!( + "GetFieldAccess not supported by ExprPlanner: {field_access_expr:?}" + ) } } @@ -1069,7 +1068,7 @@ mod tests { impl ContextProvider for TestContextProvider { fn get_table_source(&self, name: TableReference) -> Result> { match self.tables.get(name.table()) { - Some(table) => Ok(table.clone()), + Some(table) => Ok(Arc::clone(table)), _ => plan_err!("Table not found: {}", name.table()), } } @@ -1094,15 +1093,15 @@ mod tests { None } - fn udfs_names(&self) -> Vec { + fn udf_names(&self) -> Vec { Vec::new() } - fn udafs_names(&self) -> Vec { + fn udaf_names(&self) -> Vec { Vec::new() } - fn udwfs_names(&self) -> Vec { + fn udwf_names(&self) -> Vec { Vec::new() } } diff --git a/datafusion/sql/src/expr/order_by.rs b/datafusion/sql/src/expr/order_by.rs index 4ccdf6c2d418..00289806876f 100644 --- a/datafusion/sql/src/expr/order_by.rs +++ b/datafusion/sql/src/expr/order_by.rs @@ -16,32 +16,62 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::{plan_datafusion_err, plan_err, Column, DFSchema, Result}; +use datafusion_common::{ + not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema, Result, +}; use datafusion_expr::expr::Sort; -use datafusion_expr::Expr; +use datafusion_expr::{Expr, SortExpr}; use sqlparser::ast::{Expr as SQLExpr, OrderByExpr, Value}; impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// Convert sql [OrderByExpr] to `Vec`. /// - /// If `literal_to_column` is true, treat any numeric literals (e.g. `2`) as a 1 based index - /// into the SELECT list (e.g. `SELECT a, b FROM table ORDER BY 2`). + /// `input_schema` and `additional_schema` are used to resolve column references in the order-by expressions. + /// `input_schema` is the schema of the input logical plan, typically derived from the SELECT list. + /// + /// Usually order-by expressions can only reference the input plan's columns. + /// But the `SELECT ... FROM ... ORDER BY ...` syntax is a special case. Besides the input schema, + /// it can reference an `additional_schema` derived from the `FROM` clause. + /// + /// If `literal_to_column` is true, treat any numeric literals (e.g. `2`) as a 1 based index into the + /// SELECT list (e.g. `SELECT a, b FROM table ORDER BY 2`). Literals only reference the `input_schema`. + /// /// If false, interpret numeric literals as constant values. pub(crate) fn order_by_to_sort_expr( &self, - exprs: &[OrderByExpr], - schema: &DFSchema, + exprs: Vec, + input_schema: &DFSchema, planner_context: &mut PlannerContext, literal_to_column: bool, - ) -> Result> { + additional_schema: Option<&DFSchema>, + ) -> Result> { + if exprs.is_empty() { + return Ok(vec![]); + } + + let mut combined_schema; + let order_by_schema = match additional_schema { + Some(schema) => { + combined_schema = input_schema.clone(); + combined_schema.merge(schema); + &combined_schema + } + None => input_schema, + }; + let mut expr_vec = vec![]; for e in exprs { let OrderByExpr { asc, expr, nulls_first, + with_fill, } = e; + if let Some(with_fill) = with_fill { + return not_impl_err!("ORDER BY WITH FILL is not supported: {with_fill}"); + } + let expr = match expr { SQLExpr::Value(Value::Number(v, _)) if literal_to_column => { let field_index = v @@ -52,26 +82,30 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { return plan_err!( "Order by index starts at 1 for column indexes" ); - } else if schema.fields().len() < field_index { + } else if input_schema.fields().len() < field_index { return plan_err!( "Order by column out of bounds, specified: {}, max: {}", field_index, - schema.fields().len() + input_schema.fields().len() ); } - Expr::Column(Column::from(schema.qualified_field(field_index - 1))) + Expr::Column(Column::from( + input_schema.qualified_field(field_index - 1), + )) + } + e => { + self.sql_expr_to_logical_expr(e, order_by_schema, planner_context)? } - e => self.sql_expr_to_logical_expr(e.clone(), schema, planner_context)?, }; let asc = asc.unwrap_or(true); - expr_vec.push(Expr::Sort(Sort::new( - Box::new(expr), + expr_vec.push(Sort::new( + expr, asc, - // when asc is true, by default nulls last to be consistent with postgres + // When asc is true, by default nulls last to be consistent with postgres // postgres rule: https://www.postgresql.org/docs/current/queries-order.html nulls_first.unwrap_or(!asc), - ))) + )) } Ok(expr_vec) } diff --git a/datafusion/sql/src/expr/subquery.rs b/datafusion/sql/src/expr/subquery.rs index d34065d92fe5..ff161c6ed644 100644 --- a/datafusion/sql/src/expr/subquery.rs +++ b/datafusion/sql/src/expr/subquery.rs @@ -33,7 +33,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context: &mut PlannerContext, ) -> Result { let old_outer_query_schema = - planner_context.set_outer_query_schema(Some(input_schema.clone())); + planner_context.set_outer_query_schema(Some(input_schema.clone().into())); let sub_plan = self.query_to_plan(subquery, planner_context)?; let outer_ref_columns = sub_plan.all_out_ref_exprs(); planner_context.set_outer_query_schema(old_outer_query_schema); @@ -55,7 +55,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context: &mut PlannerContext, ) -> Result { let old_outer_query_schema = - planner_context.set_outer_query_schema(Some(input_schema.clone())); + planner_context.set_outer_query_schema(Some(input_schema.clone().into())); let sub_plan = self.query_to_plan(subquery, planner_context)?; let outer_ref_columns = sub_plan.all_out_ref_exprs(); planner_context.set_outer_query_schema(old_outer_query_schema); @@ -77,7 +77,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context: &mut PlannerContext, ) -> Result { let old_outer_query_schema = - planner_context.set_outer_query_schema(Some(input_schema.clone())); + planner_context.set_outer_query_schema(Some(input_schema.clone().into())); let sub_plan = self.query_to_plan(subquery, planner_context)?; let outer_ref_columns = sub_plan.all_out_ref_exprs(); planner_context.set_outer_query_schema(old_outer_query_schema); diff --git a/datafusion/sql/src/expr/substring.rs b/datafusion/sql/src/expr/substring.rs index f58c6f3b94d0..f58ab5ff3612 100644 --- a/datafusion/sql/src/expr/substring.rs +++ b/datafusion/sql/src/expr/substring.rs @@ -16,9 +16,9 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::{internal_datafusion_err, plan_err}; +use datafusion_common::{not_impl_err, plan_err}; use datafusion_common::{DFSchema, Result, ScalarValue}; -use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::planner::PlannerResult; use datafusion_expr::Expr; use sqlparser::ast::Expr as SQLExpr; @@ -31,7 +31,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let args = match (substring_from, substring_for) { + let mut substring_args = match (substring_from, substring_for) { (Some(from_expr), Some(for_expr)) => { let arg = self.sql_expr_to_logical_expr(*expr, schema, planner_context)?; @@ -68,13 +68,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } }; - let fun = self - .context_provider - .get_function_meta("substr") - .ok_or_else(|| { - internal_datafusion_err!("Unable to find expected 'substr' function") - })?; + for planner in self.context_provider.get_expr_planners() { + match planner.plan_substring(substring_args)? { + PlannerResult::Planned(expr) => return Ok(expr), + PlannerResult::Original(args) => { + substring_args = args; + } + } + } - Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) + not_impl_err!( + "Substring not supported by UserDefinedExtensionPlanners: {substring_args:?}" + ) } } diff --git a/datafusion/sql/src/expr/unary_op.rs b/datafusion/sql/src/expr/unary_op.rs index 9fcee7a06124..06988eb03893 100644 --- a/datafusion/sql/src/expr/unary_op.rs +++ b/datafusion/sql/src/expr/unary_op.rs @@ -16,8 +16,11 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::{not_impl_err, DFSchema, Result}; -use datafusion_expr::Expr; +use datafusion_common::{not_impl_err, plan_err, DFSchema, Result}; +use datafusion_expr::{ + type_coercion::{is_interval, is_timestamp}, + Expr, ExprSchemable, +}; use sqlparser::ast::{Expr as SQLExpr, UnaryOperator, Value}; impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -33,19 +36,29 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.sql_expr_to_logical_expr(expr, schema, planner_context)?, ))), UnaryOperator::Plus => { - Ok(self.sql_expr_to_logical_expr(expr, schema, planner_context)?) + let operand = + self.sql_expr_to_logical_expr(expr, schema, planner_context)?; + let (data_type, _) = operand.data_type_and_nullable(schema)?; + if data_type.is_numeric() + || is_interval(&data_type) + || is_timestamp(&data_type) + { + Ok(operand) + } else { + plan_err!("Unary operator '+' only supports numeric, interval and timestamp types") + } } UnaryOperator::Minus => { match expr { - // optimization: if it's a number literal, we apply the negative operator + // Optimization: if it's a number literal, we apply the negative operator // here directly to calculate the new literal. SQLExpr::Value(Value::Number(n, _)) => { self.parse_sql_number(&n, true) } SQLExpr::Interval(interval) => { - self.sql_interval_to_expr(true, interval, schema, planner_context) + self.sql_interval_to_expr(true, interval) } - // not a literal, apply negative operator on expression + // Not a literal, apply negative operator on expression _ => Ok(Expr::Negative(Box::new(self.sql_expr_to_logical_expr( expr, schema, diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index 25857db839c8..7dc15de7ad71 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -16,16 +16,19 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use arrow::compute::kernels::cast_utils::parse_interval_month_day_nano; +use arrow::compute::kernels::cast_utils::{ + parse_interval_month_day_nano_config, IntervalParseConfig, IntervalUnit, +}; use arrow::datatypes::DECIMAL128_MAX_PRECISION; use arrow_schema::DataType; use datafusion_common::{ - not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, + internal_err, not_impl_err, plan_err, DFSchema, DataFusionError, Result, ScalarValue, }; -use datafusion_expr::expr::{BinaryExpr, Placeholder, ScalarFunction}; +use datafusion_expr::expr::{BinaryExpr, Placeholder}; +use datafusion_expr::planner::PlannerResult; use datafusion_expr::{lit, Expr, Operator}; use log::debug; -use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, Value}; +use sqlparser::ast::{BinaryOperator, Expr as SQLExpr, Interval, UnaryOperator, Value}; use sqlparser::parser::ParserError::ParserError; use std::borrow::Cow; @@ -50,6 +53,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan_err!("Invalid HexStringLiteral '{s}'") } } + Value::DollarQuotedString(s) => Ok(lit(s.value)), Value::EscapedStringLiteral(s) => Ok(lit(s)), _ => plan_err!("Unsupported Value '{value:?}'"), } @@ -129,6 +133,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ))) } + // IMPORTANT: Keep sql_array_literal's function body small to prevent stack overflow + // This function is recursively called, potentially leading to deep call stacks. pub(super) fn sql_array_literal( &self, elements: Vec, @@ -141,23 +147,34 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }) .collect::>>()?; - if let Some(udf) = self.context_provider.get_function_meta("make_array") { - Ok(Expr::ScalarFunction(ScalarFunction::new_udf(udf, values))) - } else { - not_impl_err!( - "array_expression featrue is disable, So should implement make_array UDF by yourself" - ) + self.try_plan_array_literal(values, schema) + } + + fn try_plan_array_literal( + &self, + values: Vec, + schema: &DFSchema, + ) -> Result { + let mut exprs = values; + for planner in self.context_provider.get_expr_planners() { + match planner.plan_array_literal(exprs, schema)? { + PlannerResult::Planned(expr) => { + return Ok(expr); + } + PlannerResult::Original(values) => exprs = values, + } } + + internal_err!("Expected a simplified result, but none was found") } /// Convert a SQL interval expression to a DataFusion logical plan /// expression + #[allow(clippy::only_used_in_recursion)] pub(super) fn sql_interval_to_expr( &self, negative: bool, interval: Interval, - schema: &DFSchema, - planner_context: &mut PlannerContext, ) -> Result { if interval.leading_precision.is_some() { return not_impl_err!( @@ -180,174 +197,89 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ); } - // Only handle string exprs for now - let value = match *interval.value { - SQLExpr::Value( - Value::SingleQuotedString(s) | Value::DoubleQuotedString(s), - ) => { - if negative { - format!("-{s}") - } else { - s + if let SQLExpr::BinaryOp { left, op, right } = *interval.value { + let df_op = match op { + BinaryOperator::Plus => Operator::Plus, + BinaryOperator::Minus => Operator::Minus, + _ => { + return not_impl_err!("Unsupported interval operator: {op:?}"); } - } - // Support expressions like `interval '1 month' + date/timestamp`. - // Such expressions are parsed like this by sqlparser-rs - // - // Interval - // BinaryOp - // Value(StringLiteral) - // Cast - // Value(StringLiteral) - // - // This code rewrites them to the following: - // - // BinaryOp - // Interval - // Value(StringLiteral) - // Cast - // Value(StringLiteral) - SQLExpr::BinaryOp { left, op, right } => { - let df_op = match op { - BinaryOperator::Plus => Operator::Plus, - BinaryOperator::Minus => Operator::Minus, - _ => { - return not_impl_err!("Unsupported interval operator: {op:?}"); - } - }; - match ( - interval.leading_field.as_ref(), - left.as_ref(), - right.as_ref(), - ) { - (_, _, SQLExpr::Value(_)) => { - let left_expr = self.sql_interval_to_expr( - negative, - Interval { - value: left, - leading_field: interval.leading_field.clone(), - leading_precision: None, - last_field: None, - fractional_seconds_precision: None, - }, - schema, - planner_context, - )?; - let right_expr = self.sql_interval_to_expr( - false, - Interval { - value: right, - leading_field: interval.leading_field, - leading_precision: None, - last_field: None, - fractional_seconds_precision: None, - }, - schema, - planner_context, - )?; - return Ok(Expr::BinaryExpr(BinaryExpr::new( - Box::new(left_expr), - df_op, - Box::new(right_expr), - ))); - } - // In this case, the left node is part of the interval - // expr and the right node is an independent expr. - // - // Leading field is not supported when the right operand - // is not a value. - (None, _, _) => { - let left_expr = self.sql_interval_to_expr( - negative, - Interval { - value: left, - leading_field: None, - leading_precision: None, - last_field: None, - fractional_seconds_precision: None, - }, - schema, - planner_context, - )?; - let right_expr = self.sql_expr_to_logical_expr( - *right, - schema, - planner_context, - )?; - return Ok(Expr::BinaryExpr(BinaryExpr::new( - Box::new(left_expr), - df_op, - Box::new(right_expr), - ))); - } - _ => { - let value = SQLExpr::BinaryOp { left, op, right }; - return not_impl_err!( - "Unsupported interval argument. Expected string literal, got: {value:?}" - ); - } - } - } - _ => { - return not_impl_err!( - "Unsupported interval argument. Expected string literal, got: {:?}", - interval.value - ); - } - }; + }; + let left_expr = self.sql_interval_to_expr( + negative, + Interval { + value: left, + leading_field: interval.leading_field.clone(), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }, + )?; + let right_expr = self.sql_interval_to_expr( + false, + Interval { + value: right, + leading_field: interval.leading_field, + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }, + )?; + return Ok(Expr::BinaryExpr(BinaryExpr::new( + Box::new(left_expr), + df_op, + Box::new(right_expr), + ))); + } - let value = if has_units(&value) { - // If the interval already contains a unit - // `INTERVAL '5 month' rather than `INTERVAL '5' month` - // skip the other unit - value - } else { - // leading_field really means the unit if specified - // for example, "month" in `INTERVAL '5' month` - match interval.leading_field.as_ref() { - Some(leading_field) => { - format!("{value} {leading_field}") - } - None => { - // default to seconds for the units - // `INTERVAL '5' is parsed as '5 seconds' - format!("{value} seconds") - } - } + let value = interval_literal(*interval.value, negative)?; + + // leading_field really means the unit if specified + // For example, "month" in `INTERVAL '5' month` + let value = match interval.leading_field.as_ref() { + Some(leading_field) => format!("{value} {leading_field}"), + None => value, }; - let val = parse_interval_month_day_nano(&value)?; + let config = IntervalParseConfig::new(IntervalUnit::Second); + let val = parse_interval_month_day_nano_config(&value, config)?; Ok(lit(ScalarValue::IntervalMonthDayNano(Some(val)))) } } -// TODO make interval parsing better in arrow-rs / expose `IntervalType` -fn has_units(val: &str) -> bool { - let val = val.to_lowercase(); - val.ends_with("century") - || val.ends_with("centuries") - || val.ends_with("decade") - || val.ends_with("decades") - || val.ends_with("year") - || val.ends_with("years") - || val.ends_with("month") - || val.ends_with("months") - || val.ends_with("week") - || val.ends_with("weeks") - || val.ends_with("day") - || val.ends_with("days") - || val.ends_with("hour") - || val.ends_with("hours") - || val.ends_with("minute") - || val.ends_with("minutes") - || val.ends_with("second") - || val.ends_with("seconds") - || val.ends_with("millisecond") - || val.ends_with("milliseconds") - || val.ends_with("microsecond") - || val.ends_with("microseconds") - || val.ends_with("nanosecond") - || val.ends_with("nanoseconds") +fn interval_literal(interval_value: SQLExpr, negative: bool) -> Result { + let s = match interval_value { + SQLExpr::Value(Value::SingleQuotedString(s) | Value::DoubleQuotedString(s)) => s, + SQLExpr::Value(Value::Number(ref v, long)) => { + if long { + return not_impl_err!( + "Unsupported interval argument. Long number not supported: {interval_value:?}" + ); + } else { + v.to_string() + } + } + SQLExpr::UnaryOp { op, expr } => { + let negative = match op { + UnaryOperator::Minus => !negative, + UnaryOperator::Plus => negative, + _ => { + return not_impl_err!( + "Unsupported SQL unary operator in interval {op:?}" + ); + } + }; + interval_literal(*expr, negative)? + } + _ => { + return not_impl_err!("Unsupported interval argument. Expected string literal or number, got: {interval_value:?}"); + } + }; + if negative { + Ok(format!("-{s}")) + } else { + Ok(s) + } } /// Try to decode bytes from hex literal string. @@ -391,9 +323,9 @@ const fn try_decode_hex_char(c: u8) -> Option { fn parse_decimal_128(unsigned_number: &str, negative: bool) -> Result { // remove leading zeroes let trimmed = unsigned_number.trim_start_matches('0'); - // parse precision and scale, remove decimal point if exists + // Parse precision and scale, remove decimal point if exists let (precision, scale, replaced_str) = if trimmed == "." { - // special cases for numbers such as “0.”, “000.”, and so on. + // Special cases for numbers such as “0.”, “000.”, and so on. (1, 0, Cow::Borrowed("0")) } else if let Some(i) = trimmed.find('.') { ( @@ -402,7 +334,7 @@ fn parse_decimal_128(unsigned_number: &str, negative: bool) -> Result { Cow::Owned(trimmed.replace('.', "")), ) } else { - // no decimal point, keep as is + // No decimal point, keep as is (trimmed.len(), 0, Cow::Borrowed(trimmed)) }; @@ -412,7 +344,7 @@ fn parse_decimal_128(unsigned_number: &str, negative: bool) -> Result { ))) })?; - // check precision overflow + // Check precision overflow if precision as u8 > DECIMAL128_MAX_PRECISION { return Err(DataFusionError::from(ParserError(format!( "Cannot parse {replaced_str} as i128 when building decimal: precision overflow" diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs index 1040cc61c702..956f5e17e26f 100644 --- a/datafusion/sql/src/lib.rs +++ b/datafusion/sql/src/lib.rs @@ -14,19 +14,25 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +// Make cheap clones clear: https://github.com/apache/datafusion/issues/11143 +#![deny(clippy::clone_on_ref_ptr)] -//! This module provides: +//! This crate provides: //! //! 1. A SQL parser, [`DFParser`], that translates SQL query text into -//! an abstract syntax tree (AST), [`Statement`]. +//! an abstract syntax tree (AST), [`Statement`]. //! //! 2. A SQL query planner [`SqlToRel`] that creates [`LogicalPlan`]s -//! from [`Statement`]s. +//! from [`Statement`]s. +//! +//! 3. A SQL [`unparser`] that converts [`Expr`]s and [`LogicalPlan`]s +//! into SQL query text. //! //! [`DFParser`]: parser::DFParser //! [`Statement`]: parser::Statement //! [`SqlToRel`]: planner::SqlToRel //! [`LogicalPlan`]: datafusion_expr::logical_plan::LogicalPlan +//! [`Expr`]: datafusion_expr::expr::Expr mod cte; mod expr; diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index 5a999ab21d30..8a984f1645e9 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -17,14 +17,12 @@ //! [`DFParser`]: DataFusion SQL Parser based on [`sqlparser`] -use std::collections::{HashMap, VecDeque}; +use std::collections::VecDeque; use std::fmt; -use std::str::FromStr; -use datafusion_common::parsers::CompressionTypeVariant; use sqlparser::{ ast::{ - ColumnDef, ColumnOptionDef, ObjectName, OrderByExpr, Query, + ColumnDef, ColumnOptionDef, Expr, ObjectName, OrderByExpr, Query, Statement as SQLStatement, TableConstraint, Value, }, dialect::{keywords::Keyword, Dialect, GenericDialect}, @@ -103,8 +101,6 @@ pub struct CopyToStatement { pub target: String, /// Partition keys pub partitioned_by: Vec, - /// Indicates whether there is a header row (e.g. CSV) - pub has_header: bool, /// File type (Parquet, NDJSON, CSV etc.) pub stored_as: Option, /// Target specific options @@ -130,12 +126,9 @@ impl fmt::Display for CopyToStatement { write!(f, " PARTITIONED BY ({})", partitioned_by.join(", "))?; } - if self.has_header { - write!(f, " WITH HEADER ROW")?; - } - if !options.is_empty() { - let opts: Vec<_> = options.iter().map(|(k, v)| format!("{k} {v}")).collect(); + let opts: Vec<_> = + options.iter().map(|(k, v)| format!("'{k}' {v}")).collect(); write!(f, " OPTIONS ({})", opts.join(", "))?; } @@ -172,9 +165,6 @@ pub(crate) type LexOrdering = Vec; /// [ IF NOT EXISTS ] /// [ () ] /// STORED AS -/// [ WITH HEADER ROW ] -/// [ DELIMITER ] -/// [ COMPRESSION TYPE ] /// [ PARTITIONED BY ( | ) ] /// [ WITH ORDER () /// [ OPTIONS () ] @@ -191,15 +181,11 @@ pub(crate) type LexOrdering = Vec; #[derive(Debug, Clone, PartialEq, Eq)] pub struct CreateExternalTable { /// Table name - pub name: String, + pub name: ObjectName, /// Optional schema pub columns: Vec, /// File type (Parquet, NDJSON, CSV, etc) pub file_type: String, - /// CSV Header row? - pub has_header: bool, - /// User defined delimiter for CSVs - pub delimiter: char, /// Path to file pub location: String, /// Partition Columns @@ -208,12 +194,12 @@ pub struct CreateExternalTable { pub order_exprs: Vec, /// Option to not error if table already exists pub if_not_exists: bool, - /// File compression type (GZIP, BZIP2, XZ) - pub file_compression_type: CompressionTypeVariant, + /// Whether the table is a temporary table + pub temporary: bool, /// Infinite streams? pub unbounded: bool, /// Table(provider) specific options - pub options: HashMap, + pub options: Vec<(String, Value)>, /// A table-level constraint pub constraints: Vec, } @@ -234,7 +220,7 @@ impl fmt::Display for CreateExternalTable { /// /// This can either be a [`Statement`] from [`sqlparser`] from a /// standard SQL dialect, or a DataFusion extension such as `CREATE -/// EXTERAL TABLE`. See [`DFParser`] for more information. +/// EXTERNAL TABLE`. See [`DFParser`] for more information. /// /// [`Statement`]: sqlparser::ast::Statement #[derive(Debug, Clone, PartialEq, Eq)] @@ -269,7 +255,7 @@ fn ensure_not_set(field: &Option, name: &str) -> Result<(), ParserError> { Ok(()) } -/// Datafusion SQL Parser based on [`sqlparser`] +/// DataFusion SQL Parser based on [`sqlparser`] /// /// Parses DataFusion's SQL dialect, often delegating to [`sqlparser`]'s [`Parser`]. /// @@ -339,6 +325,14 @@ impl<'a> DFParser<'a> { Ok(stmts) } + pub fn parse_sql_into_expr_with_dialect( + sql: &str, + dialect: &dyn Dialect, + ) -> Result { + let mut parser = DFParser::new_with_dialect(sql, dialect)?; + parser.parse_expr() + } + /// Report an unexpected token fn expected( &self, @@ -383,6 +377,19 @@ impl<'a> DFParser<'a> { } } + pub fn parse_expr(&mut self) -> Result { + if let Token::Word(w) = self.parser.peek_token().token { + match w.keyword { + Keyword::CREATE | Keyword::COPY | Keyword::EXPLAIN => { + return parser_err!("Unsupported command in expression"); + } + _ => {} + } + } + + self.parser.parse_expr() + } + /// Parse a SQL `COPY TO` statement pub fn parse_copy(&mut self) -> Result { // parse as a query @@ -401,7 +408,6 @@ impl<'a> DFParser<'a> { stored_as: Option, target: Option, partitioned_by: Option>, - has_header: Option, options: Option>, } @@ -428,8 +434,7 @@ impl<'a> DFParser<'a> { Keyword::WITH => { self.parser.expect_keyword(Keyword::HEADER)?; self.parser.expect_keyword(Keyword::ROW)?; - ensure_not_set(&builder.has_header, "WITH HEADER ROW")?; - builder.has_header = Some(true); + return parser_err!("WITH HEADER ROW clause is no longer in use. Please use the OPTIONS clause with 'format.has_header' set appropriately, e.g., OPTIONS ('format.has_header' 'true')"); } Keyword::PARTITIONED => { self.parser.expect_keyword(Keyword::BY)?; @@ -466,7 +471,6 @@ impl<'a> DFParser<'a> { source, target, partitioned_by: builder.partitioned_by.unwrap_or(vec![]), - has_header: builder.has_header.unwrap_or(false), stored_as: builder.stored_as, options: builder.options.unwrap_or(vec![]), })) @@ -481,7 +485,21 @@ impl<'a> DFParser<'a> { pub fn parse_option_key(&mut self) -> Result { let next_token = self.parser.next_token(); match next_token.token { - Token::Word(Word { value, .. }) => Ok(value), + Token::Word(Word { value, .. }) => { + let mut parts = vec![value]; + while self.parser.consume_token(&Token::Period) { + let next_token = self.parser.next_token(); + if let Token::Word(Word { value, .. }) = next_token.token { + parts.push(value); + } else { + // Unquoted namespaced keys have to conform to the syntax + // "[\.]*". If we have a key that breaks this + // pattern, error out: + return self.parser.expected("key name", next_token); + } + } + Ok(parts.join(".")) + } Token::SingleQuotedString(s) => Ok(s), Token::DoubleQuotedString(s) => Ok(s), Token::EscapedStringLiteral(s) => Ok(s), @@ -498,18 +516,12 @@ impl<'a> DFParser<'a> { pub fn parse_option_value(&mut self) -> Result { let next_token = self.parser.next_token(); match next_token.token { - Token::Word(Word { value, .. }) => Ok(Value::UnQuotedString(value)), + // e.g. things like "snappy" or "gzip" that may be keywords + Token::Word(word) => Ok(Value::SingleQuotedString(word.value)), Token::SingleQuotedString(s) => Ok(Value::SingleQuotedString(s)), Token::DoubleQuotedString(s) => Ok(Value::DoubleQuotedString(s)), Token::EscapedStringLiteral(s) => Ok(Value::EscapedStringLiteral(s)), - Token::Number(ref n, l) => match n.parse() { - Ok(n) => Ok(Value::Number(n, l)), - // The tokenizer should have ensured `n` is an integer - // so this should not be possible - Err(e) => parser_err!(format!( - "Unexpected error: could not parse '{n}' as number: {e}" - )), - }, + Token::Number(n, l) => Ok(Value::Number(n, l)), _ => self.parser.expected("string or numeric value", next_token), } } @@ -608,6 +620,7 @@ impl<'a> DFParser<'a> { expr, asc, nulls_first, + with_fill: None, }) } @@ -688,6 +701,10 @@ impl<'a> DFParser<'a> { &mut self, unbounded: bool, ) -> Result { + let temporary = self + .parser + .parse_one_of_keywords(&[Keyword::TEMP, Keyword::TEMPORARY]) + .is_some(); self.parser.expect_keyword(Keyword::TABLE)?; let if_not_exists = self.parser @@ -699,12 +716,9 @@ impl<'a> DFParser<'a> { struct Builder { file_type: Option, location: Option, - has_header: Option, - delimiter: Option, - file_compression_type: Option, table_partition_cols: Option>, order_exprs: Vec, - options: Option>, + options: Option>, } let mut builder = Builder::default(); @@ -734,22 +748,15 @@ impl<'a> DFParser<'a> { } else { self.parser.expect_keyword(Keyword::HEADER)?; self.parser.expect_keyword(Keyword::ROW)?; - ensure_not_set(&builder.has_header, "WITH HEADER ROW")?; - builder.has_header = Some(true); + return parser_err!("WITH HEADER ROW clause is no longer in use. Please use the OPTIONS clause with 'format.has_header' set appropriately, e.g., OPTIONS (format.has_header true)"); } } Keyword::DELIMITER => { - ensure_not_set(&builder.delimiter, "DELIMITER")?; - builder.delimiter = Some(self.parse_delimiter()?); + return parser_err!("DELIMITER clause is no longer in use. Please use the OPTIONS clause with 'format.delimiter' set appropriately, e.g., OPTIONS (format.delimiter ',')"); } Keyword::COMPRESSION => { self.parser.expect_keyword(Keyword::TYPE)?; - ensure_not_set( - &builder.file_compression_type, - "COMPRESSION TYPE", - )?; - builder.file_compression_type = - Some(self.parse_file_compression_type()?); + return parser_err!("COMPRESSION TYPE clause is no longer in use. Please use the OPTIONS clause with 'format.compression' set appropriately, e.g., OPTIONS (format.compression gzip)"); } Keyword::PARTITIONED => { self.parser.expect_keyword(Keyword::BY)?; @@ -760,10 +767,10 @@ impl<'a> DFParser<'a> { // Note that mixing both names and definitions is not allowed let peeked = self.parser.peek_nth_token(2); if peeked == Token::Comma || peeked == Token::RParen { - // list of column names + // List of column names builder.table_partition_cols = Some(self.parse_partitions()?) } else { - // list of column defs + // List of column defs let (cols, cons) = self.parse_columns()?; builder.table_partition_cols = Some( cols.iter().map(|col| col.name.to_string()).collect(), @@ -781,7 +788,7 @@ impl<'a> DFParser<'a> { } Keyword::OPTIONS => { ensure_not_set(&builder.options, "OPTIONS")?; - builder.options = Some(self.parse_string_options()?); + builder.options = Some(self.parse_value_options()?); } _ => { unreachable!() @@ -812,20 +819,16 @@ impl<'a> DFParser<'a> { } let create = CreateExternalTable { - name: table_name.to_string(), + name: table_name, columns, file_type: builder.file_type.unwrap(), - has_header: builder.has_header.unwrap_or(false), - delimiter: builder.delimiter.unwrap_or(','), location: builder.location.unwrap(), table_partition_cols: builder.table_partition_cols.unwrap_or(vec![]), order_exprs: builder.order_exprs, if_not_exists, - file_compression_type: builder - .file_compression_type - .unwrap_or(CompressionTypeVariant::UNCOMPRESSED), + temporary, unbounded, - options: builder.options.unwrap_or(HashMap::new()), + options: builder.options.unwrap_or(Vec::new()), constraints, }; Ok(Statement::CreateExternalTable(create)) @@ -840,45 +843,10 @@ impl<'a> DFParser<'a> { } } - /// Parses the set of - fn parse_file_compression_type( - &mut self, - ) -> Result { - let token = self.parser.next_token(); - match &token.token { - Token::Word(w) => CompressionTypeVariant::from_str(&w.value), - _ => self.expected("one of GZIP, BZIP2, XZ, ZSTD", token), - } - } - - /// Parses (key value) style options where the values are literal strings. - fn parse_string_options(&mut self) -> Result, ParserError> { - let mut options = HashMap::new(); - self.parser.expect_token(&Token::LParen)?; - - loop { - let key = self.parser.parse_literal_string()?; - let value = self.parser.parse_literal_string()?; - options.insert(key, value); - let comma = self.parser.consume_token(&Token::Comma); - if self.parser.consume_token(&Token::RParen) { - // allow a trailing comma, even though it's not in standard - break; - } else if !comma { - return self.expected( - "',' or ')' after option definition", - self.parser.peek_token(), - ); - } - } - Ok(options) - } - /// Parses (key value) style options into a map of String --> [`Value`]. /// - /// Unlike [`Self::parse_string_options`], this method supports - /// keywords as key names as well as multiple value types such as - /// Numbers as well as Strings. + /// This method supports keywords as key names as well as multiple + /// value types such as Numbers as well as Strings. fn parse_value_options(&mut self) -> Result, ParserError> { let mut options = vec![]; self.parser.expect_token(&Token::LParen)?; @@ -889,7 +857,7 @@ impl<'a> DFParser<'a> { options.push((key, value)); let comma = self.parser.consume_token(&Token::Comma); if self.parser.consume_token(&Token::RParen) { - // allow a trailing comma, even though it's not in standard + // Allow a trailing comma, even though it's not in standard break; } else if !comma { return self.expected( @@ -900,16 +868,6 @@ impl<'a> DFParser<'a> { } Ok(options) } - - fn parse_delimiter(&mut self) -> Result { - let token = self.parser.parse_literal_string()?; - match token.len() { - 1 => Ok(token.chars().next().unwrap()), - _ => Err(ParserError::TokenizerError( - "Delimiter must be a single char".to_string(), - )), - } - } } #[cfg(test)] @@ -917,7 +875,6 @@ mod tests { use super::*; use sqlparser::ast::Expr::Identifier; use sqlparser::ast::{BinaryOperator, DataType, Expr, Ident}; - use CompressionTypeVariant::UNCOMPRESSED; fn expect_parse_ok(sql: &str, expected: Statement) -> Result<(), ParserError> { let statements = DFParser::parse_sql(sql)?; @@ -965,19 +922,18 @@ mod tests { // positive case let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv'"; let display = None; + let name = ObjectName(vec![Ident::from("t")]); let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![make_column_def("c1", DataType::Int(display))], file_type: "CSV".to_string(), - has_header: false, - delimiter: ',', location: "foo.csv".into(), table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, - file_compression_type: UNCOMPRESSED, + temporary: false, unbounded: false, - options: HashMap::new(), + options: vec![], constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -985,18 +941,16 @@ mod tests { // positive case: leading space let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv' "; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![make_column_def("c1", DataType::Int(None))], file_type: "CSV".to_string(), - has_header: false, - delimiter: ',', location: "foo.csv".into(), table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, - file_compression_type: UNCOMPRESSED, + temporary: false, unbounded: false, - options: HashMap::new(), + options: vec![], constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -1005,38 +959,37 @@ mod tests { let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv' ;"; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![make_column_def("c1", DataType::Int(None))], file_type: "CSV".to_string(), - has_header: false, - delimiter: ',', location: "foo.csv".into(), table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, - file_compression_type: UNCOMPRESSED, + temporary: false, unbounded: false, - options: HashMap::new(), + options: vec![], constraints: vec![], }); expect_parse_ok(sql, expected)?; // positive case with delimiter - let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV DELIMITER '|' LOCATION 'foo.csv'"; + let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv' OPTIONS (format.delimiter '|')"; let display = None; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![make_column_def("c1", DataType::Int(display))], file_type: "CSV".to_string(), - has_header: false, - delimiter: '|', location: "foo.csv".into(), table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, - file_compression_type: UNCOMPRESSED, + temporary: false, unbounded: false, - options: HashMap::new(), + options: vec![( + "format.delimiter".into(), + Value::SingleQuotedString("|".into()), + )], constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -1045,69 +998,47 @@ mod tests { let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV PARTITIONED BY (p1, p2) LOCATION 'foo.csv'"; let display = None; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![make_column_def("c1", DataType::Int(display))], file_type: "CSV".to_string(), - has_header: false, - delimiter: ',', location: "foo.csv".into(), table_partition_cols: vec!["p1".to_string(), "p2".to_string()], order_exprs: vec![], if_not_exists: false, - file_compression_type: UNCOMPRESSED, + temporary: false, unbounded: false, - options: HashMap::new(), + options: vec![], constraints: vec![], }); expect_parse_ok(sql, expected)?; - // positive case: it is ok for case insensitive sql stmt with `WITH HEADER ROW` tokens - let sqls = vec![ - "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH HEADER ROW LOCATION 'foo.csv'", - "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV with header row LOCATION 'foo.csv'" - ]; - for sql in sqls { - let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), - columns: vec![make_column_def("c1", DataType::Int(display))], - file_type: "CSV".to_string(), - has_header: true, - delimiter: ',', - location: "foo.csv".into(), - table_partition_cols: vec![], - order_exprs: vec![], - if_not_exists: false, - file_compression_type: UNCOMPRESSED, - unbounded: false, - options: HashMap::new(), - constraints: vec![], - }); - expect_parse_ok(sql, expected)?; - } - // positive case: it is ok for sql stmt with `COMPRESSION TYPE GZIP` tokens - let sqls = vec![ - ("CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV COMPRESSION TYPE GZIP LOCATION 'foo.csv'", "GZIP"), - ("CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV COMPRESSION TYPE BZIP2 LOCATION 'foo.csv'", "BZIP2"), - ("CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV COMPRESSION TYPE XZ LOCATION 'foo.csv'", "XZ"), - ("CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV COMPRESSION TYPE ZSTD LOCATION 'foo.csv'", "ZSTD"), - ]; - for (sql, file_compression_type) in sqls { + let sqls = + vec![ + ("CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv' OPTIONS + ('format.compression' 'GZIP')", "GZIP"), + ("CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv' OPTIONS + ('format.compression' 'BZIP2')", "BZIP2"), + ("CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv' OPTIONS + ('format.compression' 'XZ')", "XZ"), + ("CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv' OPTIONS + ('format.compression' 'ZSTD')", "ZSTD"), + ]; + for (sql, compression) in sqls { let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![make_column_def("c1", DataType::Int(display))], file_type: "CSV".to_string(), - has_header: false, - delimiter: ',', location: "foo.csv".into(), table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, - file_compression_type: CompressionTypeVariant::from_str( - file_compression_type, - )?, + temporary: false, unbounded: false, - options: HashMap::new(), + options: vec![( + "format.compression".into(), + Value::SingleQuotedString(compression.into()), + )], constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -1116,18 +1047,16 @@ mod tests { // positive case: it is ok for parquet files not to have columns specified let sql = "CREATE EXTERNAL TABLE t STORED AS PARQUET LOCATION 'foo.parquet'"; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![], file_type: "PARQUET".to_string(), - has_header: false, - delimiter: ',', location: "foo.parquet".into(), table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, - file_compression_type: UNCOMPRESSED, + temporary: false, unbounded: false, - options: HashMap::new(), + options: vec![], constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -1135,18 +1064,16 @@ mod tests { // positive case: it is ok for parquet files to be other than upper case let sql = "CREATE EXTERNAL TABLE t STORED AS parqueT LOCATION 'foo.parquet'"; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![], file_type: "PARQUET".to_string(), - has_header: false, - delimiter: ',', location: "foo.parquet".into(), table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, - file_compression_type: UNCOMPRESSED, + temporary: false, unbounded: false, - options: HashMap::new(), + options: vec![], constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -1154,18 +1081,16 @@ mod tests { // positive case: it is ok for avro files not to have columns specified let sql = "CREATE EXTERNAL TABLE t STORED AS AVRO LOCATION 'foo.avro'"; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![], file_type: "AVRO".to_string(), - has_header: false, - delimiter: ',', location: "foo.avro".into(), table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, - file_compression_type: UNCOMPRESSED, + temporary: false, unbounded: false, - options: HashMap::new(), + options: vec![], constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -1174,41 +1099,37 @@ mod tests { let sql = "CREATE EXTERNAL TABLE IF NOT EXISTS t STORED AS PARQUET LOCATION 'foo.parquet'"; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![], file_type: "PARQUET".to_string(), - has_header: false, - delimiter: ',', location: "foo.parquet".into(), table_partition_cols: vec![], order_exprs: vec![], if_not_exists: true, - file_compression_type: UNCOMPRESSED, + temporary: false, unbounded: false, - options: HashMap::new(), + options: vec![], constraints: vec![], }); expect_parse_ok(sql, expected)?; - // positive case: column definiton allowed in 'partition by' clause + // positive case: column definition allowed in 'partition by' clause let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV PARTITIONED BY (p1 int) LOCATION 'foo.csv'"; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![ make_column_def("c1", DataType::Int(None)), make_column_def("p1", DataType::Int(None)), ], file_type: "CSV".to_string(), - has_header: false, - delimiter: ',', location: "foo.csv".into(), table_partition_cols: vec!["p1".to_string()], order_exprs: vec![], if_not_exists: false, - file_compression_type: UNCOMPRESSED, + temporary: false, unbounded: false, - options: HashMap::new(), + options: vec![], constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -1216,7 +1137,10 @@ mod tests { // negative case: mixed column defs and column names in `PARTITIONED BY` clause let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV PARTITIONED BY (p1 int, c1) LOCATION 'foo.csv'"; - expect_parse_error(sql, "sql parser error: Expected a data type name, found: )"); + expect_parse_error( + sql, + "sql parser error: Expected: a data type name, found: )", + ); // negative case: mixed column defs and column names in `PARTITIONED BY` clause let sql = @@ -1227,18 +1151,16 @@ mod tests { let sql = "CREATE EXTERNAL TABLE t STORED AS x OPTIONS ('k1' 'v1') LOCATION 'blahblah'"; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![], file_type: "X".to_string(), - has_header: false, - delimiter: ',', location: "blahblah".into(), table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, - file_compression_type: UNCOMPRESSED, + temporary: false, unbounded: false, - options: HashMap::from([("k1".into(), "v1".into())]), + options: vec![("k1".into(), Value::SingleQuotedString("v1".into()))], constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -1247,21 +1169,19 @@ mod tests { let sql = "CREATE EXTERNAL TABLE t STORED AS x OPTIONS ('k1' 'v1', k2 v2) LOCATION 'blahblah'"; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![], file_type: "X".to_string(), - has_header: false, - delimiter: ',', location: "blahblah".into(), table_partition_cols: vec![], order_exprs: vec![], if_not_exists: false, - file_compression_type: UNCOMPRESSED, + temporary: false, unbounded: false, - options: HashMap::from([ - ("k1".into(), "v1".into()), - ("k2".into(), "v2".into()), - ]), + options: vec![ + ("k1".into(), Value::SingleQuotedString("v1".into())), + ("k2".into(), Value::SingleQuotedString("v2".into())), + ], constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -1289,11 +1209,9 @@ mod tests { ]; for (sql, (asc, nulls_first)) in sqls.iter().zip(expected.into_iter()) { let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![make_column_def("c1", DataType::Int(None))], file_type: "CSV".to_string(), - has_header: false, - delimiter: ',', location: "foo.csv".into(), table_partition_cols: vec![], order_exprs: vec![vec![OrderByExpr { @@ -1303,11 +1221,12 @@ mod tests { }), asc, nulls_first, + with_fill: None, }]], if_not_exists: false, - file_compression_type: UNCOMPRESSED, + temporary: false, unbounded: false, - options: HashMap::new(), + options: vec![], constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -1317,14 +1236,12 @@ mod tests { let sql = "CREATE EXTERNAL TABLE t(c1 int, c2 int) STORED AS CSV WITH ORDER (c1 ASC, c2 DESC NULLS FIRST) LOCATION 'foo.csv'"; let display = None; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![ make_column_def("c1", DataType::Int(display)), make_column_def("c2", DataType::Int(display)), ], file_type: "CSV".to_string(), - has_header: false, - delimiter: ',', location: "foo.csv".into(), table_partition_cols: vec![], order_exprs: vec![vec![ @@ -1335,6 +1252,7 @@ mod tests { }), asc: Some(true), nulls_first: None, + with_fill: None, }, OrderByExpr { expr: Identifier(Ident { @@ -1343,12 +1261,13 @@ mod tests { }), asc: Some(false), nulls_first: Some(true), + with_fill: None, }, ]], if_not_exists: false, - file_compression_type: UNCOMPRESSED, + temporary: false, unbounded: false, - options: HashMap::new(), + options: vec![], constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -1357,14 +1276,12 @@ mod tests { let sql = "CREATE EXTERNAL TABLE t(c1 int, c2 int) STORED AS CSV WITH ORDER (c1 - c2 ASC) LOCATION 'foo.csv'"; let display = None; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![ make_column_def("c1", DataType::Int(display)), make_column_def("c2", DataType::Int(display)), ], file_type: "CSV".to_string(), - has_header: false, - delimiter: ',', location: "foo.csv".into(), table_partition_cols: vec![], order_exprs: vec![vec![OrderByExpr { @@ -1381,11 +1298,12 @@ mod tests { }, asc: Some(true), nulls_first: None, + with_fill: None, }]], if_not_exists: false, - file_compression_type: UNCOMPRESSED, + temporary: false, unbounded: false, - options: HashMap::new(), + options: vec![], constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -1394,23 +1312,21 @@ mod tests { let sql = " CREATE UNBOUNDED EXTERNAL TABLE IF NOT EXISTS t (c1 int, c2 float) STORED AS PARQUET - DELIMITER '*' - WITH HEADER ROW WITH ORDER (c1 - c2 ASC) - COMPRESSION TYPE zstd PARTITIONED BY (c1) LOCATION 'foo.parquet' - OPTIONS (ROW_GROUP_SIZE '1024', 'TRUNCATE' 'NO') - "; + OPTIONS ('format.compression' 'zstd', + 'format.delimiter' '*', + 'ROW_GROUP_SIZE' '1024', + 'TRUNCATE' 'NO', + 'format.has_header' 'true')"; let expected = Statement::CreateExternalTable(CreateExternalTable { - name: "t".into(), + name: name.clone(), columns: vec![ make_column_def("c1", DataType::Int(None)), make_column_def("c2", DataType::Float(None)), ], file_type: "PARQUET".to_string(), - has_header: true, - delimiter: '*', location: "foo.parquet".into(), table_partition_cols: vec!["c1".into()], order_exprs: vec![vec![OrderByExpr { @@ -1427,14 +1343,30 @@ mod tests { }, asc: Some(true), nulls_first: None, + with_fill: None, }]], if_not_exists: true, - file_compression_type: CompressionTypeVariant::ZSTD, + temporary: false, unbounded: true, - options: HashMap::from([ - ("ROW_GROUP_SIZE".into(), "1024".into()), - ("TRUNCATE".into(), "NO".into()), - ]), + options: vec![ + ( + "format.compression".into(), + Value::SingleQuotedString("zstd".into()), + ), + ( + "format.delimiter".into(), + Value::SingleQuotedString("*".into()), + ), + ( + "ROW_GROUP_SIZE".into(), + Value::SingleQuotedString("1024".into()), + ), + ("TRUNCATE".into(), Value::SingleQuotedString("NO".into())), + ( + "format.has_header".into(), + Value::SingleQuotedString("true".into()), + ), + ], constraints: vec![], }); expect_parse_ok(sql, expected)?; @@ -1452,7 +1384,6 @@ mod tests { source: object_name("foo"), target: "bar".to_string(), partitioned_by: vec![], - has_header: false, stored_as: Some("CSV".to_owned()), options: vec![], }); @@ -1488,7 +1419,6 @@ mod tests { source: object_name("foo"), target: "bar".to_string(), partitioned_by: vec![], - has_header: false, stored_as: Some("PARQUET".to_owned()), options: vec![], }); @@ -1519,14 +1449,17 @@ mod tests { panic!("Expected query, got {statement:?}"); }; - let sql = "COPY (SELECT 1) TO bar STORED AS CSV WITH HEADER ROW"; + let sql = + "COPY (SELECT 1) TO bar STORED AS CSV OPTIONS ('format.has_header' 'true')"; let expected = Statement::CopyTo(CopyToStatement { source: CopyToSource::Query(query), target: "bar".to_string(), partitioned_by: vec![], - has_header: true, stored_as: Some("CSV".to_owned()), - options: vec![], + options: vec![( + "format.has_header".into(), + Value::SingleQuotedString("true".into()), + )], }); assert_eq!(verified_stmt(sql), expected); Ok(()) @@ -1534,16 +1467,15 @@ mod tests { #[test] fn copy_to_options() -> Result<(), ParserError> { - let sql = "COPY foo TO bar STORED AS CSV OPTIONS (row_group_size 55)"; + let sql = "COPY foo TO bar STORED AS CSV OPTIONS ('row_group_size' '55')"; let expected = Statement::CopyTo(CopyToStatement { source: object_name("foo"), target: "bar".to_string(), partitioned_by: vec![], - has_header: false, stored_as: Some("CSV".to_owned()), options: vec![( "row_group_size".to_string(), - Value::Number("55".to_string(), false), + Value::SingleQuotedString("55".to_string()), )], }); assert_eq!(verified_stmt(sql), expected); @@ -1552,16 +1484,15 @@ mod tests { #[test] fn copy_to_partitioned_by() -> Result<(), ParserError> { - let sql = "COPY foo TO bar STORED AS CSV PARTITIONED BY (a) OPTIONS (row_group_size 55)"; + let sql = "COPY foo TO bar STORED AS CSV PARTITIONED BY (a) OPTIONS ('row_group_size' '55')"; let expected = Statement::CopyTo(CopyToStatement { source: object_name("foo"), target: "bar".to_string(), partitioned_by: vec!["a".to_string()], - has_header: false, stored_as: Some("CSV".to_owned()), options: vec![( "row_group_size".to_string(), - Value::Number("55".to_string(), false), + Value::SingleQuotedString("55".to_string()), )], }); assert_eq!(verified_stmt(sql), expected); @@ -1572,7 +1503,7 @@ mod tests { fn copy_to_multi_options() -> Result<(), ParserError> { // order of options is preserved let sql = - "COPY foo TO bar STORED AS parquet OPTIONS ('format.row_group_size' 55, 'format.compression' snappy)"; + "COPY foo TO bar STORED AS parquet OPTIONS ('format.row_group_size' 55, 'format.compression' snappy, 'execution.keep_partition_by_columns' true)"; let expected_options = vec![ ( @@ -1581,7 +1512,11 @@ mod tests { ), ( "format.compression".to_string(), - Value::UnQuotedString("snappy".to_string()), + Value::SingleQuotedString("snappy".to_string()), + ), + ( + "execution.keep_partition_by_columns".to_string(), + Value::SingleQuotedString("true".to_string()), ), ]; @@ -1615,10 +1550,10 @@ mod tests { /// that: /// /// 1. parsing `sql` results in the same [`Statement`] as parsing - /// `canonical`. + /// `canonical`. /// /// 2. re-serializing the result of parsing `sql` produces the same - /// `canonical` sql string + /// `canonical` sql string fn one_statement_parses_to(sql: &str, canonical: &str) -> Statement { let mut statements = DFParser::parse_sql(sql).unwrap(); assert_eq!(statements.len(), 1); diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 0066f75f0d30..4d44d5ff2584 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -22,15 +22,13 @@ use std::vec; use arrow_schema::*; use datafusion_common::{ - field_not_found, internal_err, plan_datafusion_err, SchemaError, + field_not_found, internal_err, plan_datafusion_err, DFSchemaRef, SchemaError, }; -use datafusion_expr::WindowUDF; -use sqlparser::ast::TimezoneInfo; use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo}; use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; use sqlparser::ast::{DataType as SQLDataType, Ident, ObjectName, TableAlias}; +use sqlparser::ast::{TimezoneInfo, Value}; -use datafusion_common::config::ConfigOptions; use datafusion_common::TableReference; use datafusion_common::{ not_impl_err, plan_err, unqualified_field_not_found, DFSchema, DataFusionError, @@ -38,64 +36,18 @@ use datafusion_common::{ }; use datafusion_expr::logical_plan::{LogicalPlan, LogicalPlanBuilder}; use datafusion_expr::utils::find_column_exprs; -use datafusion_expr::TableSource; -use datafusion_expr::{col, AggregateUDF, Expr, ScalarUDF}; - -use crate::utils::make_decimal_type; - -/// The ContextProvider trait allows the query planner to obtain meta-data about tables and -/// functions referenced in SQL statements -pub trait ContextProvider { - #[deprecated(since = "32.0.0", note = "please use `get_table_source` instead")] - fn get_table_provider(&self, name: TableReference) -> Result> { - self.get_table_source(name) - } - /// Getter for a datasource - fn get_table_source(&self, name: TableReference) -> Result>; - /// Getter for a table function - fn get_table_function_source( - &self, - _name: &str, - _args: Vec, - ) -> Result> { - not_impl_err!("Table Functions are not supported") - } - - /// This provides a worktable (an intermediate table that is used to store the results of a CTE during execution) - /// We don't directly implement this in the logical plan's ['SqlToRel`] - /// because the sql code needs access to a table that contains execution-related types that can't be a direct dependency - /// of the sql crate (namely, the `CteWorktable`). - /// The [`ContextProvider`] provides a way to "hide" this dependency. - fn create_cte_work_table( - &self, - _name: &str, - _schema: SchemaRef, - ) -> Result> { - not_impl_err!("Recursive CTE is not implemented") - } - - /// Getter for a UDF description - fn get_function_meta(&self, name: &str) -> Option>; - /// Getter for a UDAF description - fn get_aggregate_meta(&self, name: &str) -> Option>; - /// Getter for a UDWF - fn get_window_meta(&self, name: &str) -> Option>; - /// Getter for system/user-defined variable type - fn get_variable_type(&self, variable_names: &[String]) -> Option; - - /// Get configuration options - fn options(&self) -> &ConfigOptions; - - fn udfs_names(&self) -> Vec; - fn udafs_names(&self) -> Vec; - fn udwfs_names(&self) -> Vec; -} +use datafusion_expr::{col, Expr}; + +use crate::utils::{make_decimal_type, value_to_string}; +pub use datafusion_expr::planner::ContextProvider; /// SQL parser options #[derive(Debug)] pub struct ParserOptions { pub parse_float_as_decimal: bool, pub enable_ident_normalization: bool, + pub support_varchar_with_length: bool, + pub enable_options_value_normalization: bool, } impl Default for ParserOptions { @@ -103,6 +55,8 @@ impl Default for ParserOptions { Self { parse_float_as_decimal: false, enable_ident_normalization: true, + support_varchar_with_length: true, + enable_options_value_normalization: true, } } } @@ -133,21 +87,59 @@ impl IdentNormalizer { } } +/// Value Normalizer +#[derive(Debug)] +pub struct ValueNormalizer { + normalize: bool, +} + +impl Default for ValueNormalizer { + fn default() -> Self { + Self { normalize: true } + } +} + +impl ValueNormalizer { + pub fn new(normalize: bool) -> Self { + Self { normalize } + } + + pub fn normalize(&self, value: Value) -> Option { + match (value_to_string(&value), self.normalize) { + (Some(s), true) => Some(s.to_ascii_lowercase()), + (Some(s), false) => Some(s), + (None, _) => None, + } + } +} + /// Struct to store the states used by the Planner. The Planner will leverage the states to resolve /// CTEs, Views, subqueries and PREPARE statements. The states include /// Common Table Expression (CTE) provided with WITH clause and /// Parameter Data Types provided with PREPARE statement and the query schema of the -/// outer query plan +/// outer query plan. +/// +/// # Cloning +/// +/// Only the `ctes` are truly cloned when the `PlannerContext` is cloned. This helps resolve +/// scoping issues of CTEs. By using cloning, a subquery can inherit CTEs from the outer query +/// and can also define its own private CTEs without affecting the outer query. +/// #[derive(Debug, Clone)] pub struct PlannerContext { /// Data types for numbered parameters ($1, $2, etc), if supplied /// in `PREPARE` statement - prepare_param_data_types: Vec, + prepare_param_data_types: Arc>, /// Map of CTE name to logical plan of the WITH clause. /// Use `Arc` to allow cheap cloning ctes: HashMap>, /// The query schema of the outer query plan, used to resolve the columns in subquery - outer_query_schema: Option, + outer_query_schema: Option, + /// The joined schemas of all FROM clauses planned so far. When planning LATERAL + /// FROM clauses, this should become a suffix of the `outer_query_schema`. + outer_from_schema: Option, + /// The query schema defined by the table + create_table_schema: Option, } impl Default for PlannerContext { @@ -160,9 +152,11 @@ impl PlannerContext { /// Create an empty PlannerContext pub fn new() -> Self { Self { - prepare_param_data_types: vec![], + prepare_param_data_types: Arc::new(vec![]), ctes: HashMap::new(), outer_query_schema: None, + outer_from_schema: None, + create_table_schema: None, } } @@ -171,31 +165,66 @@ impl PlannerContext { mut self, prepare_param_data_types: Vec, ) -> Self { - self.prepare_param_data_types = prepare_param_data_types; + self.prepare_param_data_types = prepare_param_data_types.into(); self } - // return a reference to the outer queries schema + // Return a reference to the outer query's schema pub fn outer_query_schema(&self) -> Option<&DFSchema> { - self.outer_query_schema.as_ref() + self.outer_query_schema.as_ref().map(|s| s.as_ref()) } - /// sets the outer query schema, returning the existing one, if + /// Sets the outer query schema, returning the existing one, if /// any pub fn set_outer_query_schema( &mut self, - mut schema: Option, - ) -> Option { + mut schema: Option, + ) -> Option { std::mem::swap(&mut self.outer_query_schema, &mut schema); schema } + pub fn set_table_schema( + &mut self, + mut schema: Option, + ) -> Option { + std::mem::swap(&mut self.create_table_schema, &mut schema); + schema + } + + pub fn table_schema(&self) -> Option { + self.create_table_schema.clone() + } + + // Return a clone of the outer FROM schema + pub fn outer_from_schema(&self) -> Option> { + self.outer_from_schema.clone() + } + + /// Sets the outer FROM schema, returning the existing one, if any + pub fn set_outer_from_schema( + &mut self, + mut schema: Option, + ) -> Option { + std::mem::swap(&mut self.outer_from_schema, &mut schema); + schema + } + + /// Extends the FROM schema, returning the existing one, if any + pub fn extend_outer_from_schema(&mut self, schema: &DFSchemaRef) -> Result<()> { + match self.outer_from_schema.as_mut() { + Some(from_schema) => Arc::make_mut(from_schema).merge(schema), + None => self.outer_from_schema = Some(Arc::clone(schema)), + }; + Ok(()) + } + /// Return the types of parameters (`$1`, `$2`, etc) if known pub fn prepare_param_data_types(&self) -> &[DataType] { &self.prepare_param_data_types } - /// returns true if there is a Common Table Expression (CTE) / + /// Returns true if there is a Common Table Expression (CTE) / /// Subquery for the specified name pub fn contains_cte(&self, cte_name: &str) -> bool { self.ctes.contains_key(cte_name) @@ -224,7 +253,8 @@ impl PlannerContext { pub struct SqlToRel<'a, S: ContextProvider> { pub(crate) context_provider: &'a S, pub(crate) options: ParserOptions, - pub(crate) normalizer: IdentNormalizer, + pub(crate) ident_normalizer: IdentNormalizer, + pub(crate) value_normalizer: ValueNormalizer, } impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -235,11 +265,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// Create a new query planner pub fn new_with_options(context_provider: &'a S, options: ParserOptions) -> Self { - let normalize = options.enable_ident_normalization; + let ident_normalize = options.enable_ident_normalization; + let options_value_normalize = options.enable_options_value_normalization; + SqlToRel { context_provider, options, - normalizer: IdentNormalizer::new(normalize), + ident_normalizer: IdentNormalizer::new(ident_normalize), + value_normalizer: ValueNormalizer::new(options_value_normalize), } } @@ -253,7 +286,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .iter() .any(|x| x.option == ColumnOption::NotNull); fields.push(Field::new( - self.normalizer.normalize(column.name), + self.ident_normalizer.normalize(column.name), data_type, !not_nullable, )); @@ -291,8 +324,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let default_expr = self .sql_to_expr(default_sql_expr.clone(), &empty_schema, planner_context) .map_err(error_desc)?; - column_defaults - .push((self.normalizer.normalize(column.name.clone()), default_expr)); + column_defaults.push(( + self.ident_normalizer.normalize(column.name.clone()), + default_expr, + )); } } Ok(column_defaults) @@ -307,7 +342,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let plan = self.apply_expr_alias(plan, alias.columns)?; LogicalPlanBuilder::from(plan) - .alias(TableReference::bare(self.normalizer.normalize(alias.name)))? + .alias(TableReference::bare( + self.ident_normalizer.normalize(alias.name), + ))? .build() } @@ -328,7 +365,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let fields = plan.schema().fields().clone(); LogicalPlanBuilder::from(plan) .project(fields.iter().zip(idents.into_iter()).map(|(field, ident)| { - col(field.name()).alias(self.normalizer.normalize(ident)) + col(field.name()).alias(self.ident_normalizer.normalize(ident)) }))? .build() } @@ -365,12 +402,26 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(crate) fn convert_data_type(&self, sql_type: &SQLDataType) -> Result { match sql_type { - SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_sql_type)) - | SQLDataType::Array(ArrayElemTypeDef::SquareBracket(inner_sql_type)) => { + SQLDataType::Array(ArrayElemTypeDef::AngleBracket(inner_sql_type)) => { // Arrays may be multi-dimensional. let inner_data_type = self.convert_data_type(inner_sql_type)?; Ok(DataType::new_list(inner_data_type, true)) } + SQLDataType::Array(ArrayElemTypeDef::SquareBracket( + inner_sql_type, + maybe_array_size, + )) => { + let inner_data_type = self.convert_data_type(inner_sql_type)?; + if let Some(array_size) = maybe_array_size { + Ok(DataType::new_fixed_size_list( + inner_data_type, + *array_size as i32, + true, + )) + } else { + Ok(DataType::new_list(inner_data_type, true)) + } + } SQLDataType::Array(ArrayElemTypeDef::None) => { not_impl_err!("Arrays with unspecified type is not supported") } @@ -390,15 +441,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLDataType::UnsignedInt(_) | SQLDataType::UnsignedInteger(_) | SQLDataType::UnsignedInt4(_) => { Ok(DataType::UInt32) } + SQLDataType::Varchar(length) => { + match (length, self.options.support_varchar_with_length) { + (Some(_), false) => plan_err!("does not support Varchar with length, please set `support_varchar_with_length` to be true"), + _ => Ok(DataType::Utf8), + } + } SQLDataType::UnsignedBigInt(_) | SQLDataType::UnsignedInt8(_) => Ok(DataType::UInt64), SQLDataType::Float(_) => Ok(DataType::Float32), SQLDataType::Real | SQLDataType::Float4 => Ok(DataType::Float32), SQLDataType::Double | SQLDataType::DoublePrecision | SQLDataType::Float8 => Ok(DataType::Float64), SQLDataType::Char(_) - | SQLDataType::Varchar(_) | SQLDataType::Text | SQLDataType::String(_) => Ok(DataType::Utf8), - SQLDataType::Timestamp(None, tz_info) => { + SQLDataType::Timestamp(precision, tz_info) + if precision.is_none() || [0, 3, 6, 9].contains(&precision.unwrap()) => { let tz = if matches!(tz_info, TimezoneInfo::Tz) || matches!(tz_info, TimezoneInfo::WithTimeZone) { @@ -410,7 +467,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Timestamp Without Time zone None }; - Ok(DataType::Timestamp(TimeUnit::Nanosecond, tz.map(Into::into))) + let precision = match precision { + Some(0) => TimeUnit::Second, + Some(3) => TimeUnit::Millisecond, + Some(6) => TimeUnit::Microsecond, + None | Some(9) => TimeUnit::Nanosecond, + _ => unreachable!(), + }; + Ok(DataType::Timestamp(precision, tz.map(Into::into))) } SQLDataType::Date => Ok(DataType::Date32), SQLDataType::Time(None, tz_info) => { @@ -438,6 +502,25 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } SQLDataType::Bytea => Ok(DataType::Binary), SQLDataType::Interval => Ok(DataType::Interval(IntervalUnit::MonthDayNano)), + SQLDataType::Struct(fields, _) => { + let fields = fields + .iter() + .enumerate() + .map(|(idx, field)| { + let data_type = self.convert_data_type(&field.field_type)?; + let field_name = match &field.field_name{ + Some(ident) => ident.clone(), + None => Ident::new(format!("c{idx}")) + }; + Ok(Arc::new(Field::new( + self.ident_normalizer.normalize(field_name), + data_type, + true, + ))) + }) + .collect::>>()?; + Ok(DataType::Struct(Fields::from(fields))) + } // Explicitly list all other types so that if sqlparser // adds/changes the `SQLDataType` the compiler will tell us on upgrade // and avoid bugs like https://github.com/apache/datafusion/issues/3059 @@ -460,9 +543,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { | SQLDataType::CharVarying(_) | SQLDataType::CharacterLargeObject(_) | SQLDataType::CharLargeObject(_) - // precision is not supported - | SQLDataType::Timestamp(Some(_), _) - // precision is not supported + // Unsupported precision + | SQLDataType::Timestamp(_, _) + // Precision is not supported | SQLDataType::Time(Some(_), _) | SQLDataType::Dec(_) | SQLDataType::BigNumeric(_) @@ -471,9 +554,30 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { | SQLDataType::Bytes(_) | SQLDataType::Int64 | SQLDataType::Float64 - | SQLDataType::Struct(_) | SQLDataType::JSONB | SQLDataType::Unspecified + // Clickhouse datatypes + | SQLDataType::Int16 + | SQLDataType::Int32 + | SQLDataType::Int128 + | SQLDataType::Int256 + | SQLDataType::UInt8 + | SQLDataType::UInt16 + | SQLDataType::UInt32 + | SQLDataType::UInt64 + | SQLDataType::UInt128 + | SQLDataType::UInt256 + | SQLDataType::Float32 + | SQLDataType::Date32 + | SQLDataType::Datetime64(_, _) + | SQLDataType::FixedString(_) + | SQLDataType::Map(_, _) + | SQLDataType::Tuple(_) + | SQLDataType::Nested(_) + | SQLDataType::Union(_) + | SQLDataType::Nullable(_) + | SQLDataType::LowCardinality(_) + | SQLDataType::Trigger => not_impl_err!( "Unsupported SQL type {sql_type:?}" ), @@ -505,7 +609,7 @@ pub fn object_name_to_table_reference( object_name: ObjectName, enable_normalization: bool, ) -> Result { - // use destructure to make it clear no fields on ObjectName are ignored + // Use destructure to make it clear no fields on ObjectName are ignored let ObjectName(idents) = object_name; idents_to_table_reference(idents, enable_normalization) } @@ -516,7 +620,7 @@ pub(crate) fn idents_to_table_reference( enable_normalization: bool, ) -> Result { struct IdentTaker(Vec); - /// take the next identifier from the back of idents, panic'ing if + /// Take the next identifier from the back of idents, panic'ing if /// there are none left impl IdentTaker { fn take(&mut self, enable_normalization: bool) -> String { diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index d5d3bcc4a13b..1ef009132f9e 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -19,56 +19,59 @@ use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::{plan_err, Constraints, Result, ScalarValue}; +use datafusion_common::{not_impl_err, Constraints, DFSchema, Result}; +use datafusion_expr::expr::Sort; use datafusion_expr::{ - CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder, - Operator, + CreateMemoryTable, DdlStatement, Distinct, LogicalPlan, LogicalPlanBuilder, }; use sqlparser::ast::{ - Expr as SQLExpr, Offset as SQLOffset, OrderByExpr, Query, SetExpr, Value, + Expr as SQLExpr, Offset as SQLOffset, OrderBy, OrderByExpr, Query, SelectInto, + SetExpr, }; impl<'a, S: ContextProvider> SqlToRel<'a, S> { - /// Generate a logical plan from an SQL query + /// Generate a logical plan from an SQL query/subquery pub(crate) fn query_to_plan( &self, query: Query, - planner_context: &mut PlannerContext, + outer_planner_context: &mut PlannerContext, ) -> Result { - self.query_to_plan_with_schema(query, planner_context) - } + // Each query has its own planner context, including CTEs that are visible within that query. + // It also inherits the CTEs from the outer query by cloning the outer planner context. + let mut query_plan_context = outer_planner_context.clone(); + let planner_context = &mut query_plan_context; - /// Generate a logic plan from an SQL query. - /// It's implementation of `subquery_to_plan` and `query_to_plan`. - /// It shouldn't be invoked directly. - fn query_to_plan_with_schema( - &self, - query: Query, - planner_context: &mut PlannerContext, - ) -> Result { - let mut set_expr = query.body; if let Some(with) = query.with { self.plan_with_clause(with, planner_context)?; } - // Take the `SelectInto` for later processing. - let select_into = match set_expr.as_mut() { - SetExpr::Select(select) => select.into.take(), - _ => None, - }; - let plan = self.set_expr_to_plan(*set_expr, planner_context)?; - let plan = self.order_by(plan, query.order_by, planner_context)?; - let mut plan = self.limit(plan, query.offset, query.limit)?; - if let Some(into) = select_into { - plan = LogicalPlan::Ddl(DdlStatement::CreateMemoryTable(CreateMemoryTable { - name: self.object_name_to_table_reference(into.name)?, - constraints: Constraints::empty(), - input: Arc::new(plan), - if_not_exists: false, - or_replace: false, - column_defaults: vec![], - })) + + let set_expr = *query.body; + match set_expr { + SetExpr::Select(mut select) => { + let select_into = select.into.take(); + // Order-by expressions may refer to columns in the `FROM` clause, + // so we need to process `SELECT` and `ORDER BY` together. + let oby_exprs = to_order_by_exprs(query.order_by)?; + let plan = self.select_to_plan(*select, oby_exprs, planner_context)?; + let plan = + self.limit(plan, query.offset, query.limit, planner_context)?; + // Process the `SELECT INTO` after `LIMIT`. + self.select_into(plan, select_into) + } + other => { + let plan = self.set_expr_to_plan(other, planner_context)?; + let oby_exprs = to_order_by_exprs(query.order_by)?; + let order_by_rex = self.order_by_to_sort_expr( + oby_exprs, + plan.schema(), + planner_context, + true, + None, + )?; + let plan = self.order_by(plan, order_by_rex)?; + self.limit(plan, query.offset, query.limit, planner_context) + } } - Ok(plan) } /// Wrap a plan in a limit @@ -77,107 +80,77 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { input: LogicalPlan, skip: Option, fetch: Option, + planner_context: &mut PlannerContext, ) -> Result { if skip.is_none() && fetch.is_none() { return Ok(input); } - let skip = match skip { - Some(skip_expr) => { - let expr = self.sql_to_expr( - skip_expr.value, - input.schema(), - &mut PlannerContext::new(), - )?; - let n = get_constant_result(&expr, "OFFSET")?; - convert_usize_with_check(n, "OFFSET") - } - _ => Ok(0), - }?; - - let fetch = match fetch { - Some(limit_expr) - if limit_expr != sqlparser::ast::Expr::Value(Value::Null) => - { - let expr = self.sql_to_expr( - limit_expr, - input.schema(), - &mut PlannerContext::new(), - )?; - let n = get_constant_result(&expr, "LIMIT")?; - Some(convert_usize_with_check(n, "LIMIT")?) - } - _ => None, - }; + // skip and fetch expressions are not allowed to reference columns from the input plan + let empty_schema = DFSchema::empty(); - LogicalPlanBuilder::from(input).limit(skip, fetch)?.build() + let skip = skip + .map(|o| self.sql_to_expr(o.value, &empty_schema, planner_context)) + .transpose()?; + let fetch = fetch + .map(|e| self.sql_to_expr(e, &empty_schema, planner_context)) + .transpose()?; + LogicalPlanBuilder::from(input) + .limit_by_expr(skip, fetch)? + .build() } /// Wrap the logical in a sort - fn order_by( + pub(super) fn order_by( &self, plan: LogicalPlan, - order_by: Vec, - planner_context: &mut PlannerContext, + order_by: Vec, ) -> Result { if order_by.is_empty() { return Ok(plan); } - let order_by_rex = - self.order_by_to_sort_expr(&order_by, plan.schema(), planner_context, true)?; - if let LogicalPlan::Distinct(Distinct::On(ref distinct_on)) = plan { // In case of `DISTINCT ON` we must capture the sort expressions since during the plan // optimization we're effectively doing a `first_value` aggregation according to them. - let distinct_on = distinct_on.clone().with_sort_expr(order_by_rex)?; + let distinct_on = distinct_on.clone().with_sort_expr(order_by)?; Ok(LogicalPlan::Distinct(Distinct::On(distinct_on))) } else { - LogicalPlanBuilder::from(plan).sort(order_by_rex)?.build() + LogicalPlanBuilder::from(plan).sort(order_by)?.build() } } -} -/// Retrieves the constant result of an expression, evaluating it if possible. -/// -/// This function takes an expression and an argument name as input and returns -/// a `Result` indicating either the constant result of the expression or an -/// error if the expression cannot be evaluated. -/// -/// # Arguments -/// -/// * `expr` - An `Expr` representing the expression to evaluate. -/// * `arg_name` - The name of the argument for error messages. -/// -/// # Returns -/// -/// * `Result` - An `Ok` variant containing the constant result if evaluation is successful, -/// or an `Err` variant containing an error message if evaluation fails. -/// -/// tracks a more general solution -fn get_constant_result(expr: &Expr, arg_name: &str) -> Result { - match expr { - Expr::Literal(ScalarValue::Int64(Some(s))) => Ok(*s), - Expr::BinaryExpr(binary_expr) => { - let lhs = get_constant_result(&binary_expr.left, arg_name)?; - let rhs = get_constant_result(&binary_expr.right, arg_name)?; - let res = match binary_expr.op { - Operator::Plus => lhs + rhs, - Operator::Minus => lhs - rhs, - Operator::Multiply => lhs * rhs, - _ => return plan_err!("Unsupported operator for {arg_name} clause"), - }; - Ok(res) + /// Wrap the logical plan in a `SelectInto` + fn select_into( + &self, + plan: LogicalPlan, + select_into: Option, + ) -> Result { + match select_into { + Some(into) => Ok(LogicalPlan::Ddl(DdlStatement::CreateMemoryTable( + CreateMemoryTable { + name: self.object_name_to_table_reference(into.name)?, + constraints: Constraints::empty(), + input: Arc::new(plan), + if_not_exists: false, + or_replace: false, + temporary: false, + column_defaults: vec![], + }, + ))), + _ => Ok(plan), } - _ => plan_err!("Unexpected expression in {arg_name} clause"), } } -/// Converts an `i64` to `usize`, performing a boundary check. -fn convert_usize_with_check(n: i64, arg_name: &str) -> Result { - if n < 0 { - plan_err!("{arg_name} must be >= 0, '{n}' was provided.") - } else { - Ok(n as usize) +/// Returns the order by expressions from the query. +fn to_order_by_exprs(order_by: Option) -> Result> { + let Some(OrderBy { exprs, interpolate }) = order_by else { + // If no order by, return an empty array. + return Ok(vec![]); + }; + if let Some(_interpolate) = interpolate { + return not_impl_err!("ORDER BY INTERPOLATE is not supported"); } + Ok(exprs) } diff --git a/datafusion/sql/src/relation/join.rs b/datafusion/sql/src/relation/join.rs index 262bae397cee..3f34608e3756 100644 --- a/datafusion/sql/src/relation/join.rs +++ b/datafusion/sql/src/relation/join.rs @@ -18,7 +18,7 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{not_impl_err, Column, Result}; use datafusion_expr::{JoinType, LogicalPlan, LogicalPlanBuilder}; -use sqlparser::ast::{Join, JoinConstraint, JoinOperator, TableWithJoins}; +use sqlparser::ast::{Join, JoinConstraint, JoinOperator, TableFactor, TableWithJoins}; use std::collections::HashSet; impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -27,34 +27,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { t: TableWithJoins, planner_context: &mut PlannerContext, ) -> Result { - // From clause may exist CTEs, we should separate them from global CTEs. - // CTEs in from clause are allowed to be duplicated. - // Such as `select * from (WITH source AS (select 1 as e) SELECT * FROM source) t1, (WITH source AS (select 1 as e) SELECT * FROM source) t2;` which is valid. - // So always use original global CTEs to plan CTEs in from clause. - // Btw, don't need to add CTEs in from to global CTEs. - let origin_planner_context = planner_context.clone(); - let left = self.create_relation(t.relation, planner_context)?; - match t.joins.len() { - 0 => { - *planner_context = origin_planner_context; - Ok(left) - } - _ => { - let mut joins = t.joins.into_iter(); - *planner_context = origin_planner_context.clone(); - let mut left = self.parse_relation_join( - left, - joins.next().unwrap(), // length of joins > 0 - planner_context, - )?; - for join in joins { - *planner_context = origin_planner_context.clone(); - left = self.parse_relation_join(left, join, planner_context)?; - } - *planner_context = origin_planner_context; - Ok(left) - } + let mut left = if is_lateral(&t.relation) { + self.create_relation_subquery(t.relation, planner_context)? + } else { + self.create_relation(t.relation, planner_context)? + }; + let old_outer_from_schema = planner_context.outer_from_schema(); + for join in t.joins { + planner_context.extend_outer_from_schema(left.schema())?; + left = self.parse_relation_join(left, join, planner_context)?; } + planner_context.set_outer_from_schema(old_outer_from_schema); + Ok(left) } fn parse_relation_join( @@ -63,7 +47,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { join: Join, planner_context: &mut PlannerContext, ) -> Result { - let right = self.create_relation(join.relation, planner_context)?; + let right = if is_lateral_join(&join)? { + self.create_relation_subquery(join.relation, planner_context)? + } else { + self.create_relation(join.relation, planner_context)? + }; match join.join_operator { JoinOperator::LeftOuter(constraint) => { self.parse_join(left, right, constraint, JoinType::Left, planner_context) @@ -138,7 +126,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { JoinConstraint::Using(idents) => { let keys: Vec = idents .into_iter() - .map(|x| Column::from_name(self.normalizer.normalize(x))) + .map(|x| Column::from_name(self.ident_normalizer.normalize(x))) .collect(); LogicalPlanBuilder::from(left) .join_using(right, join_type, keys)? @@ -163,7 +151,39 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .build() } } - JoinConstraint::None => not_impl_err!("NONE constraint is not supported"), + JoinConstraint::None => LogicalPlanBuilder::from(left) + .join_on(right, join_type, [])? + .build(), } } } + +/// Return `true` iff the given [`TableFactor`] is lateral. +pub(crate) fn is_lateral(factor: &TableFactor) -> bool { + match factor { + TableFactor::Derived { lateral, .. } => *lateral, + TableFactor::Function { lateral, .. } => *lateral, + _ => false, + } +} + +/// Return `true` iff the given [`Join`] is lateral. +pub(crate) fn is_lateral_join(join: &Join) -> Result { + let is_lateral_syntax = is_lateral(&join.relation); + let is_apply_syntax = match join.join_operator { + JoinOperator::FullOuter(..) + | JoinOperator::RightOuter(..) + | JoinOperator::RightAnti(..) + | JoinOperator::RightSemi(..) + if is_lateral_syntax => + { + return not_impl_err!( + "LATERAL syntax is not supported for \ + FULL OUTER and RIGHT [OUTER | ANTI | SEMI] joins" + ); + } + JoinOperator::CrossApply | JoinOperator::OuterApply => true, + _ => false, + }; + Ok(is_lateral_syntax || is_apply_syntax) +} diff --git a/datafusion/sql/src/relation/mod.rs b/datafusion/sql/src/relation/mod.rs index 9380e569f2e4..256cc58e71dc 100644 --- a/datafusion/sql/src/relation/mod.rs +++ b/datafusion/sql/src/relation/mod.rs @@ -15,9 +15,15 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; + +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{not_impl_err, plan_err, DFSchema, Result, TableReference}; +use datafusion_expr::builder::subquery_alias; use datafusion_expr::{expr::Unnest, Expr, LogicalPlan, LogicalPlanBuilder}; +use datafusion_expr::{Subquery, SubqueryAlias}; use sqlparser::ast::{FunctionArg, FunctionArgExpr, TableFactor}; mod join; @@ -36,6 +42,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if let Some(func_args) = args { let tbl_func_name = name.0.first().unwrap().value.to_string(); let args = func_args + .args .into_iter() .flat_map(|arg| { if let FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) = arg @@ -63,7 +70,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .build()?; (plan, alias) } else { - // normalize name and alias + // Normalize name and alias let table_ref = self.object_name_to_table_reference(name)?; let table_name = table_ref.to_string(); let cte = planner_context.get_cte(&table_name); @@ -101,11 +108,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { array_exprs, with_offset: false, with_offset_alias: None, + with_ordinality, } => { + if with_ordinality { + return not_impl_err!("UNNEST with ordinality is not supported yet"); + } + // Unnest table factor has empty input let schema = DFSchema::empty(); let input = LogicalPlanBuilder::empty(true).build()?; - // Unnest table factor can have multiple arugments. + // Unnest table factor can have multiple arguments. // We treat each argument as a separate unnest expression. let unnest_exprs = array_exprs .into_iter() @@ -137,10 +149,86 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ); } }; + + let optimized_plan = optimize_subquery_sort(plan)?.data; if let Some(alias) = alias { - self.apply_table_alias(plan, alias) + self.apply_table_alias(optimized_plan, alias) } else { - Ok(plan) + Ok(optimized_plan) } } + + pub(crate) fn create_relation_subquery( + &self, + subquery: TableFactor, + planner_context: &mut PlannerContext, + ) -> Result { + // At this point for a syntactically valid query the outer_from_schema is + // guaranteed to be set, so the `.unwrap()` call will never panic. This + // is the case because we only call this method for lateral table + // factors, and those can never be the first factor in a FROM list. This + // means we arrived here through the `for` loop in `plan_from_tables` or + // the `for` loop in `plan_table_with_joins`. + let old_from_schema = planner_context + .set_outer_from_schema(None) + .unwrap_or_else(|| Arc::new(DFSchema::empty())); + let new_query_schema = match planner_context.outer_query_schema() { + Some(old_query_schema) => { + let mut new_query_schema = old_from_schema.as_ref().clone(); + new_query_schema.merge(old_query_schema); + Some(Arc::new(new_query_schema)) + } + None => Some(Arc::clone(&old_from_schema)), + }; + let old_query_schema = planner_context.set_outer_query_schema(new_query_schema); + + let plan = self.create_relation(subquery, planner_context)?; + let outer_ref_columns = plan.all_out_ref_exprs(); + + planner_context.set_outer_query_schema(old_query_schema); + planner_context.set_outer_from_schema(Some(old_from_schema)); + + match plan { + LogicalPlan::SubqueryAlias(SubqueryAlias { input, alias, .. }) => { + subquery_alias( + LogicalPlan::Subquery(Subquery { + subquery: input, + outer_ref_columns, + }), + alias, + ) + } + plan => Ok(LogicalPlan::Subquery(Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + })), + } + } +} + +fn optimize_subquery_sort(plan: LogicalPlan) -> Result> { + // When initializing subqueries, we examine sort options since they might be unnecessary. + // They are only important if the subquery result is affected by the ORDER BY statement, + // which can happen when we have: + // 1. DISTINCT ON / ARRAY_AGG ... => Handled by an `Aggregate` and its requirements. + // 2. RANK / ROW_NUMBER ... => Handled by a `WindowAggr` and its requirements. + // 3. LIMIT => Handled by a `Sort`, so we need to search for it. + let mut has_limit = false; + let new_plan = plan.transform_down(|c| { + if let LogicalPlan::Limit(_) = c { + has_limit = true; + return Ok(Transformed::no(c)); + } + match c { + LogicalPlan::Sort(s) => { + if !has_limit { + has_limit = false; + return Ok(Transformed::yes(s.input.as_ref().clone())); + } + Ok(Transformed::no(LogicalPlan::Sort(s))) + } + _ => Ok(Transformed::no(c)), + } + }); + new_plan } diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index fdcef0ef6acc..80a08da5e35d 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -21,26 +21,27 @@ use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use crate::utils::{ check_columns_satisfy_exprs, extract_aliases, rebase_expr, resolve_aliases_to_exprs, - resolve_columns, resolve_positions_to_exprs, + resolve_columns, resolve_positions_to_exprs, rewrite_recursive_unnests_bottom_up, }; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; -use datafusion_common::{Column, UnnestOptions}; -use datafusion_expr::expr::{Alias, Unnest}; +use datafusion_common::{RecursionUnnestOption, UnnestOptions}; +use datafusion_expr::expr::{Alias, PlannedReplaceSelectItem, WildcardOptions}; use datafusion_expr::expr_rewriter::{ - normalize_col, normalize_col_with_schemas_and_ambiguity_check, + normalize_col, normalize_col_with_schemas_and_ambiguity_check, normalize_sorts, }; use datafusion_expr::utils::{ - expand_qualified_wildcard, expand_wildcard, expr_as_column_expr, expr_to_columns, - find_aggregate_exprs, find_window_exprs, + expr_as_column_expr, expr_to_columns, find_aggregate_exprs, find_window_exprs, }; use datafusion_expr::{ - Expr, Filter, GroupingSet, LogicalPlan, LogicalPlanBuilder, Partitioning, + qualified_wildcard_with_options, wildcard_with_options, Aggregate, Expr, Filter, + GroupingSet, LogicalPlan, LogicalPlanBuilder, Partitioning, }; +use indexmap::IndexMap; use sqlparser::ast::{ - Distinct, Expr as SQLExpr, GroupByExpr, ReplaceSelectItem, WildcardAdditionalOptions, - WindowType, + Distinct, Expr as SQLExpr, GroupByExpr, NamedWindowExpr, OrderByExpr, + WildcardAdditionalOptions, WindowType, }; use sqlparser::ast::{NamedWindowDefinition, Select, SelectItem, TableWithJoins}; @@ -49,9 +50,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(super) fn select_to_plan( &self, mut select: Select, + order_by: Vec, planner_context: &mut PlannerContext, ) -> Result { - // check for unsupported syntax first + // Check for unsupported syntax first if !select.cluster_by.is_empty() { return not_impl_err!("CLUSTER BY"); } @@ -68,17 +70,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { return not_impl_err!("SORT BY"); } - // process `from` clause + // Process `from` clause let plan = self.plan_from_tables(select.from, planner_context)?; let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_)); - // process `where` clause + // Process `where` clause let base_plan = self.plan_selection(select.selection, plan, planner_context)?; - // handle named windows before processing the projection expression + // Handle named windows before processing the projection expression check_conflicting_windows(&select.named_window)?; match_window_definitions(&mut select.projection, &select.named_window)?; - // process the SELECT expressions, with wildcards expanded. + // Process the SELECT expressions let select_exprs = self.prepare_select_exprs( &base_plan, select.projection, @@ -86,15 +88,27 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context, )?; - // having and group by clause may reference aliases defined in select projection + // Having and group by clause may reference aliases defined in select projection let projected_plan = self.project(base_plan.clone(), select_exprs.clone())?; + // Place the fields of the base plan at the front so that when there are references // with the same name, the fields of the base plan will be searched first. // See https://github.com/apache/datafusion/issues/9162 let mut combined_schema = base_plan.schema().as_ref().clone(); combined_schema.merge(projected_plan.schema()); - // this alias map is resolved and looked up in both having exprs and group by exprs + // Order-by expressions prioritize referencing columns from the select list, + // then from the FROM clause. + let order_by_rex = self.order_by_to_sort_expr( + order_by, + projected_plan.schema().as_ref(), + planner_context, + true, + Some(base_plan.schema().as_ref()), + )?; + let order_by_rex = normalize_sorts(order_by_rex, &projected_plan)?; + + // This alias map is resolved and looked up in both having exprs and group by exprs let alias_map = extract_aliases(&select_exprs); // Optionally the HAVING expression. @@ -119,24 +133,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // // SELECT c1, MAX(c2) AS m FROM t GROUP BY c1 HAVING MAX(c2) > 10; // - let having_expr = resolve_aliases_to_exprs(&having_expr, &alias_map)?; + let having_expr = resolve_aliases_to_exprs(having_expr, &alias_map)?; normalize_col(having_expr, &projected_plan) }) .transpose()?; - // The outer expressions we will search through for - // aggregates. Aggregates may be sourced from the SELECT... - let mut aggr_expr_haystack = select_exprs.clone(); - // ... or from the HAVING. - if let Some(having_expr) = &having_expr_opt { - aggr_expr_haystack.push(having_expr.clone()); - } - + // The outer expressions we will search through for aggregates. + // Aggregates may be sourced from the SELECT list or from the HAVING expression. + let aggr_expr_haystack = select_exprs.iter().chain(having_expr_opt.iter()); // All of the aggregate expressions (deduplicated). - let aggr_exprs = find_aggregate_exprs(&aggr_expr_haystack); + let aggr_exprs = find_aggregate_exprs(aggr_expr_haystack); // All of the group by expressions - let group_by_exprs = if let GroupByExpr::Expressions(exprs) = select.group_by { + let group_by_exprs = if let GroupByExpr::Expressions(exprs, _) = select.group_by { exprs .into_iter() .map(|e| { @@ -145,16 +154,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &combined_schema, planner_context, )?; - // aliases from the projection can conflict with same-named expressions in the input + + // Aliases from the projection can conflict with same-named expressions in the input let mut alias_map = alias_map.clone(); for f in base_plan.schema().fields() { alias_map.remove(f.name()); } let group_by_expr = - resolve_aliases_to_exprs(&group_by_expr, &alias_map)?; + resolve_aliases_to_exprs(group_by_expr, &alias_map)?; let group_by_expr = - resolve_positions_to_exprs(&group_by_expr, &select_exprs) - .unwrap_or(group_by_expr); + resolve_positions_to_exprs(group_by_expr, &select_exprs)?; let group_by_expr = normalize_col(group_by_expr, &projected_plan)?; self.validate_schema_satisfies_exprs( base_plan.schema(), @@ -179,7 +188,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .collect() }; - // process group by, aggregation or having + // Process group by, aggregation or having let (plan, mut select_exprs_post_aggr, having_expr_post_aggr) = if !group_by_exprs .is_empty() || !aggr_exprs.is_empty() @@ -200,13 +209,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let plan = if let Some(having_expr_post_aggr) = having_expr_post_aggr { LogicalPlanBuilder::from(plan) - .filter(having_expr_post_aggr)? + .having(having_expr_post_aggr)? .build()? } else { plan }; - // process window function + // Process window function let window_func_exprs = find_window_exprs(&select_exprs_post_aggr); let plan = if window_func_exprs.is_empty() { @@ -214,7 +223,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } else { let plan = LogicalPlanBuilder::window_plan(plan, window_func_exprs.clone())?; - // re-write the projection + // Re-write the projection select_exprs_post_aggr = select_exprs_post_aggr .iter() .map(|expr| rebase_expr(expr, &window_func_exprs, &plan)) @@ -223,10 +232,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan }; - // try process unnest expression or do the final projection + // Try processing unnest expression or do the final projection let plan = self.try_process_unnest(plan, select_exprs_post_aggr)?; - // process distinct clause + // Process distinct clause let plan = match select.distinct { None => Ok(plan), Some(Distinct::Distinct) => { @@ -248,9 +257,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .collect::>>()?; // Build the final plan - return LogicalPlanBuilder::from(base_plan) + LogicalPlanBuilder::from(base_plan) .distinct_on(on_expr, select_exprs, None)? - .build(); + .build() } }?; @@ -274,7 +283,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan }; - Ok(plan) + self.order_by(plan, order_by_rex) } /// Try converting Expr(Unnest(Expr)) to Projection/Unnest/Projection @@ -283,64 +292,196 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { input: LogicalPlan, select_exprs: Vec, ) -> Result { - let mut unnest_columns = vec![]; - let mut inner_projection_exprs = vec![]; + // Try process group by unnest + let input = self.try_process_aggregate_unnest(input)?; + + let mut intermediate_plan = input; + let mut intermediate_select_exprs = select_exprs; + + // Each expr in select_exprs can contains multiple unnest stage + // The transformation happen bottom up, one at a time for each iteration + // Only exhaust the loop if no more unnest transformation is found + for i in 0.. { + let mut unnest_columns = IndexMap::new(); + // from which column used for projection, before the unnest happen + // including non unnest column and unnest column + let mut inner_projection_exprs = vec![]; + + // expr returned here maybe different from the originals in inner_projection_exprs + // for example: + // - unnest(struct_col) will be transformed into unnest(struct_col).field1, unnest(struct_col).field2 + // - unnest(array_col) will be transformed into unnest(array_col).element + // - unnest(array_col) + 1 will be transformed into unnest(array_col).element +1 + let outer_projection_exprs = rewrite_recursive_unnests_bottom_up( + &intermediate_plan, + &mut unnest_columns, + &mut inner_projection_exprs, + &intermediate_select_exprs, + )?; - let outer_projection_exprs = select_exprs - .into_iter() - .map(|expr| { - let Transformed { - data: transformed_expr, - transformed, - tnr: _, - } = expr.transform_up(|expr: Expr| { - if let Expr::Unnest(Unnest { expr: ref arg }) = expr { - let column_name = expr.display_name()?; - unnest_columns.push(column_name.clone()); - // Add alias for the argument expression, to avoid naming conflicts with other expressions - // in the select list. For example: `select unnest(col1), col1 from t`. - inner_projection_exprs - .push(arg.clone().alias(column_name.clone())); - Ok(Transformed::yes(Expr::Column(Column::from_name( - column_name, - )))) - } else { - Ok(Transformed::no(expr)) + // No more unnest is possible + if unnest_columns.is_empty() { + // The original expr does not contain any unnest + if i == 0 { + return LogicalPlanBuilder::from(intermediate_plan) + .project(intermediate_select_exprs)? + .build(); + } + break; + } else { + // Set preserve_nulls to false to ensure compatibility with DuckDB and PostgreSQL + let mut unnest_options = UnnestOptions::new().with_preserve_nulls(false); + let mut unnest_col_vec = vec![]; + + for (col, maybe_list_unnest) in unnest_columns.into_iter() { + if let Some(list_unnest) = maybe_list_unnest { + unnest_options = list_unnest.into_iter().fold( + unnest_options, + |options, unnest_list| { + options.with_recursions(RecursionUnnestOption { + input_column: col.clone(), + output_column: unnest_list.output_column, + depth: unnest_list.depth, + }) + }, + ); + } + unnest_col_vec.push(col); + } + let plan = LogicalPlanBuilder::from(intermediate_plan) + .project(inner_projection_exprs)? + .unnest_columns_with_options(unnest_col_vec, unnest_options)? + .build()?; + intermediate_plan = plan; + intermediate_select_exprs = outer_projection_exprs; + } + } + + LogicalPlanBuilder::from(intermediate_plan) + .project(intermediate_select_exprs)? + .build() + } + + fn try_process_aggregate_unnest(&self, input: LogicalPlan) -> Result { + match input { + LogicalPlan::Aggregate(agg) => { + let agg_expr = agg.aggr_expr.clone(); + let (new_input, new_group_by_exprs) = + self.try_process_group_by_unnest(agg)?; + LogicalPlanBuilder::from(new_input) + .aggregate(new_group_by_exprs, agg_expr)? + .build() + } + LogicalPlan::Filter(mut filter) => { + filter.input = + Arc::new(self.try_process_aggregate_unnest(Arc::unwrap_or_clone( + filter.input, + ))?); + Ok(LogicalPlan::Filter(filter)) + } + _ => Ok(input), + } + } + + /// Try converting Unnest(Expr) of group by to Unnest/Projection. + /// Return the new input and group_by_exprs of Aggregate. + /// Select exprs can be different from agg exprs, for example: + fn try_process_group_by_unnest( + &self, + agg: Aggregate, + ) -> Result<(LogicalPlan, Vec)> { + let mut aggr_expr_using_columns: Option> = None; + + let Aggregate { + input, + group_expr, + aggr_expr, + .. + } = agg; + + // Process unnest of group_by_exprs, and input of agg will be rewritten + // for example: + // + // ``` + // Aggregate: groupBy=[[UNNEST(Column(Column { relation: Some(Bare { table: "tab" }), name: "array_col" }))]], aggr=[[]] + // TableScan: tab + // ``` + // + // will be transformed into + // + // ``` + // Aggregate: groupBy=[[unnest(tab.array_col)]], aggr=[[]] + // Unnest: lists[unnest(tab.array_col)] structs[] + // Projection: tab.array_col AS unnest(tab.array_col) + // TableScan: tab + // ``` + let mut intermediate_plan = Arc::unwrap_or_clone(input); + let mut intermediate_select_exprs = group_expr; + + loop { + let mut unnest_columns = IndexMap::new(); + let mut inner_projection_exprs = vec![]; + + let outer_projection_exprs = rewrite_recursive_unnests_bottom_up( + &intermediate_plan, + &mut unnest_columns, + &mut inner_projection_exprs, + &intermediate_select_exprs, + )?; + + if unnest_columns.is_empty() { + break; + } else { + let mut unnest_options = UnnestOptions::new().with_preserve_nulls(false); + + let mut projection_exprs = match &aggr_expr_using_columns { + Some(exprs) => (*exprs).clone(), + None => { + let mut columns = HashSet::new(); + for expr in &aggr_expr { + expr.apply(|expr| { + if let Expr::Column(c) = expr { + columns.insert(Expr::Column(c.clone())); + } + Ok(TreeNodeRecursion::Continue) + }) + // As the closure always returns Ok, this "can't" error + .expect("Unexpected error"); + } + aggr_expr_using_columns = Some(columns.clone()); + columns } - })?; - - if !transformed { - if matches!(&transformed_expr, Expr::Column(_)) { - inner_projection_exprs.push(transformed_expr.clone()); - Ok(transformed_expr) - } else { - // We need to evaluate the expr in the inner projection, - // outer projection just select its name - let column_name = transformed_expr.display_name()?; - inner_projection_exprs.push(transformed_expr); - Ok(Expr::Column(Column::from_name(column_name))) + }; + projection_exprs.extend(inner_projection_exprs); + + let mut unnest_col_vec = vec![]; + + for (col, maybe_list_unnest) in unnest_columns.into_iter() { + if let Some(list_unnest) = maybe_list_unnest { + unnest_options = list_unnest.into_iter().fold( + unnest_options, + |options, unnest_list| { + options.with_recursions(RecursionUnnestOption { + input_column: col.clone(), + output_column: unnest_list.output_column, + depth: unnest_list.depth, + }) + }, + ); } - } else { - Ok(transformed_expr) + unnest_col_vec.push(col); } - }) - .collect::>>()?; - // Do the final projection - if unnest_columns.is_empty() { - LogicalPlanBuilder::from(input) - .project(inner_projection_exprs)? - .build() - } else { - let columns = unnest_columns.into_iter().map(|col| col.into()).collect(); - // Set preserve_nulls to false to ensure compatibility with DuckDB and PostgreSQL - let unnest_options = UnnestOptions::new().with_preserve_nulls(false); - LogicalPlanBuilder::from(input) - .project(inner_projection_exprs)? - .unnest_columns_with_options(columns, unnest_options)? - .project(outer_projection_exprs)? - .build() + intermediate_plan = LogicalPlanBuilder::from(intermediate_plan) + .project(projection_exprs)? + .unnest_columns_with_options(unnest_col_vec, unnest_options)? + .build()?; + + intermediate_select_exprs = outer_projection_exprs; + } } + + Ok((intermediate_plan, intermediate_select_exprs)) } fn plan_selection( @@ -360,6 +501,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let filter_expr = self.sql_to_expr(predicate_expr, plan.schema(), planner_context)?; + + // Check for aggregation functions + let aggregate_exprs = + find_aggregate_exprs(std::slice::from_ref(&filter_expr)); + if !aggregate_exprs.is_empty() { + return plan_err!( + "Aggregate functions are not allowed in the WHERE clause. Consider using HAVING instead" + ); + } + let mut using_columns = HashSet::new(); expr_to_columns(&filter_expr, &mut using_columns)?; let filter_expr = normalize_col_with_schemas_and_ambiguity_check( @@ -385,27 +536,35 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { match from.len() { 0 => Ok(LogicalPlanBuilder::empty(true).build()?), 1 => { - let from = from.remove(0); - self.plan_table_with_joins(from, planner_context) + let input = from.remove(0); + self.plan_table_with_joins(input, planner_context) } _ => { - let mut plans = from - .into_iter() - .map(|t| self.plan_table_with_joins(t, planner_context)); - - let mut left = LogicalPlanBuilder::from(plans.next().unwrap()?); - - for right in plans { - left = left.cross_join(right?)?; + let mut from = from.into_iter(); + + let mut left = LogicalPlanBuilder::from({ + let input = from.next().unwrap(); + self.plan_table_with_joins(input, planner_context)? + }); + let old_outer_from_schema = { + let left_schema = Some(Arc::clone(left.schema())); + planner_context.set_outer_from_schema(left_schema) + }; + for input in from { + // Join `input` with the current result (`left`). + let right = self.plan_table_with_joins(input, planner_context)?; + left = left.cross_join(right)?; + // Update the outer FROM schema. + let left_schema = Some(Arc::clone(left.schema())); + planner_context.set_outer_from_schema(left_schema); } - Ok(left.build()?) + planner_context.set_outer_from_schema(old_outer_from_schema); + left.build() } } } /// Returns the `Expr`'s corresponding to a SQL query's SELECT expressions. - /// - /// Wildcards are expanded into the concrete list of columns. fn prepare_select_exprs( &self, plan: &LogicalPlan, @@ -449,7 +608,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &[&[plan.schema()]], &plan.using_columns()?, )?; - let name = self.normalizer.normalize(alias); + let name = self.ident_normalizer.normalize(alias); // avoiding adding an alias if the column name is the same. let expr = match &col { Expr::Column(column) if column.name.eq(&name) => col, @@ -459,49 +618,30 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } SelectItem::Wildcard(options) => { Self::check_wildcard_options(&options)?; - if empty_from { return plan_err!("SELECT * with no tables specified is not valid"); } - // do not expand from outer schema - let expanded_exprs = - expand_wildcard(plan.schema().as_ref(), plan, Some(&options))?; - // If there is a REPLACE statement, replace that column with the given - // replace expression. Column name remains the same. - if let Some(replace) = options.opt_replace { - self.replace_columns( - plan, - empty_from, - planner_context, - expanded_exprs, - replace, - ) - } else { - Ok(expanded_exprs) - } + let planned_options = self.plan_wildcard_options( + plan, + empty_from, + planner_context, + options, + )?; + Ok(vec![wildcard_with_options(planned_options)]) } - SelectItem::QualifiedWildcard(ref object_name, options) => { + SelectItem::QualifiedWildcard(object_name, options) => { Self::check_wildcard_options(&options)?; - let qualifier = format!("{object_name}"); - // do not expand from outer schema - let expanded_exprs = expand_qualified_wildcard( - &qualifier, - plan.schema().as_ref(), - Some(&options), + let qualifier = self.object_name_to_table_reference(object_name)?; + let planned_options = self.plan_wildcard_options( + plan, + empty_from, + planner_context, + options, )?; - // If there is a REPLACE statement, replace that column with the given - // replace expression. Column name remains the same. - if let Some(replace) = options.opt_replace { - self.replace_columns( - plan, - empty_from, - planner_context, - expanded_exprs, - replace, - ) - } else { - Ok(expanded_exprs) - } + Ok(vec![qualified_wildcard_with_options( + qualifier, + planned_options, + )]) } } } @@ -513,6 +653,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { opt_except: _opt_except, opt_rename, opt_replace: _opt_replace, + opt_ilike: _opt_ilike, } = options; if opt_rename.is_some() { @@ -525,40 +666,44 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } /// If there is a REPLACE statement in the projected expression in the form of - /// "REPLACE (some_column_within_an_expr AS some_column)", this function replaces - /// that column with the given replace expression. Column name remains the same. - /// Multiple REPLACEs are also possible with comma separations. - fn replace_columns( + /// "REPLACE (some_column_within_an_expr AS some_column)", we should plan the + /// replace expressions first. + fn plan_wildcard_options( &self, plan: &LogicalPlan, empty_from: bool, planner_context: &mut PlannerContext, - mut exprs: Vec, - replace: ReplaceSelectItem, - ) -> Result> { - for expr in exprs.iter_mut() { - if let Expr::Column(Column { name, .. }) = expr { - if let Some(item) = replace - .items - .iter() - .find(|item| item.column_name.value == *name) - { - let new_expr = self.sql_select_to_rex( + options: WildcardAdditionalOptions, + ) -> Result { + let planned_option = WildcardOptions { + ilike: options.opt_ilike, + exclude: options.opt_exclude, + except: options.opt_except, + replace: None, + rename: options.opt_rename, + }; + if let Some(replace) = options.opt_replace { + let replace_expr = replace + .items + .iter() + .map(|item| { + Ok(self.sql_select_to_rex( SelectItem::UnnamedExpr(item.expr.clone()), plan, empty_from, planner_context, )?[0] - .clone(); - *expr = Expr::Alias(Alias { - expr: Box::new(new_expr), - relation: None, - name: name.clone(), - }); - } - } + .clone()) + }) + .collect::>>()?; + let planned_replace = PlannedReplaceSelectItem { + items: replace.items.into_iter().map(|i| *i).collect(), + planned_expressions: replace_expr, + }; + Ok(planned_option.with_replace(planned_replace)) + } else { + Ok(planned_option) } - Ok(exprs) } /// Wrap a plan in a projection @@ -603,7 +748,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let plan = LogicalPlanBuilder::from(input.clone()) .aggregate(group_by_exprs.to_vec(), aggr_exprs.to_vec())? .build()?; - let group_by_exprs = if let LogicalPlan::Aggregate(agg) = &plan { &agg.group_expr } else { @@ -712,10 +856,17 @@ fn match_window_definitions( } | SelectItem::UnnamedExpr(SQLExpr::Function(f)) = proj { - for NamedWindowDefinition(window_ident, window_spec) in named_windows.iter() { + for NamedWindowDefinition(window_ident, window_expr) in named_windows.iter() { if let Some(WindowType::NamedWindow(ident)) = &f.over { if ident.eq(window_ident) { - f.over = Some(WindowType::WindowSpec(window_spec.clone())) + f.over = Some(match window_expr { + NamedWindowExpr::NamedWindow(ident) => { + WindowType::NamedWindow(ident.clone()) + } + NamedWindowExpr::WindowSpec(spec) => { + WindowType::WindowSpec(spec.clone()) + } + }) } } } diff --git a/datafusion/sql/src/set_expr.rs b/datafusion/sql/src/set_expr.rs index cbe41c33c729..248aad846996 100644 --- a/datafusion/sql/src/set_expr.rs +++ b/datafusion/sql/src/set_expr.rs @@ -27,7 +27,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context: &mut PlannerContext, ) -> Result { match set_expr { - SetExpr::Select(s) => self.select_to_plan(*s, planner_context), + SetExpr::Select(s) => self.select_to_plan(*s, vec![], planner_context), SetExpr::Values(v) => self.sql_values_to_plan(v, planner_context), SetExpr::SetOperation { op, diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index c81217aa7017..00949aa13ae1 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -30,14 +30,15 @@ use crate::planner::{ use crate::utils::normalize_ident; use arrow_schema::{DataType, Fields}; +use datafusion_common::error::_plan_err; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{ exec_err, not_impl_err, plan_datafusion_err, plan_err, schema_err, - unqualified_field_not_found, Column, Constraints, DFSchema, DFSchemaRef, - DataFusionError, FileType, Result, ScalarValue, SchemaError, SchemaReference, - TableReference, ToDFSchema, + unqualified_field_not_found, Column, Constraint, Constraints, DFSchema, DFSchemaRef, + DataFusionError, Result, ScalarValue, SchemaError, SchemaReference, TableReference, + ToDFSchema, }; -use datafusion_expr::dml::CopyTo; +use datafusion_expr::dml::{CopyTo, InsertOp}; use datafusion_expr::expr_rewriter::normalize_col_with_schemas_and_ambiguity_check; use datafusion_expr::logical_plan::builder::project; use datafusion_expr::logical_plan::DdlStatement; @@ -45,17 +46,19 @@ use datafusion_expr::utils::expr_to_columns; use datafusion_expr::{ cast, col, Analyze, CreateCatalog, CreateCatalogSchema, CreateExternalTable as PlanCreateExternalTable, CreateFunction, CreateFunctionBody, - CreateMemoryTable, CreateView, DescribeTable, DmlStatement, DropCatalogSchema, - DropFunction, DropTable, DropView, EmptyRelation, Explain, ExprSchemable, Filter, - LogicalPlan, LogicalPlanBuilder, OperateFunctionArg, PlanType, Prepare, SetVariable, + CreateIndex as PlanCreateIndex, CreateMemoryTable, CreateView, DescribeTable, + DmlStatement, DropCatalogSchema, DropFunction, DropTable, DropView, EmptyRelation, + Execute, Explain, Expr, ExprSchemable, Filter, LogicalPlan, LogicalPlanBuilder, + OperateFunctionArg, PlanType, Prepare, SetVariable, SortExpr, Statement as PlanStatement, ToStringifiedPlan, TransactionAccessMode, TransactionConclusion, TransactionEnd, TransactionIsolationLevel, TransactionStart, Volatility, WriteOp, }; -use sqlparser::ast; +use sqlparser::ast::{self, SqliteOnConflict}; use sqlparser::ast::{ - Assignment, ColumnDef, CreateTableOptions, DescribeAlias, Expr as SQLExpr, Expr, - FromTable, Ident, ObjectName, ObjectType, Query, SchemaName, SetExpr, + Assignment, AssignmentTarget, ColumnDef, CreateIndex, CreateTable, + CreateTableOptions, Delete, DescribeAlias, Expr as SQLExpr, FromTable, Ident, Insert, + ObjectName, ObjectType, OneOrManyWithParens, Query, SchemaName, SetExpr, ShowCreateObject, ShowStatementFilter, Statement, TableConstraint, TableFactor, TableWithJoins, TransactionMode, UnaryOperator, Value, }; @@ -96,7 +99,7 @@ fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec constraints.push(ast::TableConstraint::Unique { + } => constraints.push(TableConstraint::Unique { name: name.clone(), columns: vec![column.name.clone()], characteristics: *characteristics, @@ -108,7 +111,7 @@ fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec constraints.push(ast::TableConstraint::PrimaryKey { + } => constraints.push(TableConstraint::PrimaryKey { name: name.clone(), columns: vec![column.name.clone()], characteristics: *characteristics, @@ -122,7 +125,7 @@ fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec constraints.push(ast::TableConstraint::ForeignKey { + } => constraints.push(TableConstraint::ForeignKey { name: name.clone(), columns: vec![], foreign_table: foreign_table.clone(), @@ -132,7 +135,7 @@ fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec { - constraints.push(ast::TableConstraint::Check { + constraints.push(TableConstraint::Check { name: name.clone(), expr: Box::new(expr.clone()), }) @@ -146,7 +149,10 @@ fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec {} + | ast::ColumnOption::OnUpdate(_) + | ast::ColumnOption::Materialized(_) + | ast::ColumnOption::Ephemeral(_) + | ast::ColumnOption::Alias(_) => {} } } } @@ -194,8 +200,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { match statement { Statement::ExplainTable { describe_alias: DescribeAlias::Describe, // only parse 'DESCRIBE table_name' and not 'EXPLAIN table_name' - hive_format: _, table_name, + .. } => self.describe_table_to_plan(table_name), Statement::Explain { verbose, @@ -212,11 +218,20 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Statement::SetVariable { local, hivevar, - variable, + variables, value, - } => self.set_variable_to_plan(local, hivevar, &variable, value), + } => self.set_variable_to_plan(local, hivevar, &variables, value), - Statement::CreateTable { + Statement::CreateTable(CreateTable { + temporary, + external, + global, + transient, + volatile, + hive_distribution, + hive_formats, + file_format, + location, query, name, columns, @@ -225,8 +240,153 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { with_options, if_not_exists, or_replace, - .. - } if table_properties.is_empty() && with_options.is_empty() => { + without_rowid, + like, + clone, + engine, + comment, + auto_increment_offset, + default_charset, + collation, + on_commit, + on_cluster, + primary_key, + order_by, + partition_by, + cluster_by, + clustered_by, + options, + strict, + copy_grants, + enable_schema_evolution, + change_tracking, + data_retention_time_in_days, + max_data_extension_time_in_days, + default_ddl_collation, + with_aggregation_policy, + with_row_access_policy, + with_tags, + }) if table_properties.is_empty() && with_options.is_empty() => { + if temporary { + return not_impl_err!("Temporary tables not supported")?; + } + if external { + return not_impl_err!("External tables not supported")?; + } + if global.is_some() { + return not_impl_err!("Global tables not supported")?; + } + if transient { + return not_impl_err!("Transient tables not supported")?; + } + if volatile { + return not_impl_err!("Volatile tables not supported")?; + } + if hive_distribution != ast::HiveDistributionStyle::NONE { + return not_impl_err!( + "Hive distribution not supported: {hive_distribution:?}" + )?; + } + if !matches!( + hive_formats, + Some(ast::HiveFormat { + row_format: None, + serde_properties: None, + storage: None, + location: None, + }) + ) { + return not_impl_err!( + "Hive formats not supported: {hive_formats:?}" + )?; + } + if file_format.is_some() { + return not_impl_err!("File format not supported")?; + } + if location.is_some() { + return not_impl_err!("Location not supported")?; + } + if without_rowid { + return not_impl_err!("Without rowid not supported")?; + } + if like.is_some() { + return not_impl_err!("Like not supported")?; + } + if clone.is_some() { + return not_impl_err!("Clone not supported")?; + } + if engine.is_some() { + return not_impl_err!("Engine not supported")?; + } + if comment.is_some() { + return not_impl_err!("Comment not supported")?; + } + if auto_increment_offset.is_some() { + return not_impl_err!("Auto increment offset not supported")?; + } + if default_charset.is_some() { + return not_impl_err!("Default charset not supported")?; + } + if collation.is_some() { + return not_impl_err!("Collation not supported")?; + } + if on_commit.is_some() { + return not_impl_err!("On commit not supported")?; + } + if on_cluster.is_some() { + return not_impl_err!("On cluster not supported")?; + } + if primary_key.is_some() { + return not_impl_err!("Primary key not supported")?; + } + if order_by.is_some() { + return not_impl_err!("Order by not supported")?; + } + if partition_by.is_some() { + return not_impl_err!("Partition by not supported")?; + } + if cluster_by.is_some() { + return not_impl_err!("Cluster by not supported")?; + } + if clustered_by.is_some() { + return not_impl_err!("Clustered by not supported")?; + } + if options.is_some() { + return not_impl_err!("Options not supported")?; + } + if strict { + return not_impl_err!("Strict not supported")?; + } + if copy_grants { + return not_impl_err!("Copy grants not supported")?; + } + if enable_schema_evolution.is_some() { + return not_impl_err!("Enable schema evolution not supported")?; + } + if change_tracking.is_some() { + return not_impl_err!("Change tracking not supported")?; + } + if data_retention_time_in_days.is_some() { + return not_impl_err!("Data retention time in days not supported")?; + } + if max_data_extension_time_in_days.is_some() { + return not_impl_err!( + "Max data extension time in days not supported" + )?; + } + if default_ddl_collation.is_some() { + return not_impl_err!("Default DDL collation not supported")?; + } + if with_aggregation_policy.is_some() { + return not_impl_err!("With aggregation policy not supported")?; + } + if with_row_access_policy.is_some() { + return not_impl_err!("With row access policy not supported")?; + } + if with_tags.is_some() { + return not_impl_err!("With tags not supported")?; + } + // Merge inline constraints and existing constraints let mut all_constraints = constraints; let inline_constraints = calc_inline_constraints_from_columns(&columns); @@ -234,13 +394,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Build column default values let column_defaults = self.build_column_defaults(&columns, planner_context)?; + + let has_columns = !columns.is_empty(); + let schema = self.build_schema(columns)?.to_dfschema_ref()?; + if has_columns { + planner_context.set_table_schema(Some(Arc::clone(&schema))); + } + match query { Some(query) => { let plan = self.query_to_plan(*query, planner_context)?; let input_schema = plan.schema(); - let plan = if !columns.is_empty() { - let schema = self.build_schema(columns)?.to_dfschema_ref()?; + let plan = if has_columns { if schema.fields().len() != input_schema.fields().len() { return plan_err!( "Mismatch: {} columns specified, but result has {} columns", @@ -268,7 +434,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan }; - let constraints = Constraints::new_from_table_constraints( + let constraints = Self::new_constraint_from_table_constraints( &all_constraints, plan.schema(), )?; @@ -281,18 +447,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if_not_exists, or_replace, column_defaults, + temporary, }, ))) } None => { - let schema = self.build_schema(columns)?.to_dfschema_ref()?; let plan = EmptyRelation { produce_one_row: false, schema, }; let plan = LogicalPlan::EmptyRelation(plan); - let constraints = Constraints::new_from_table_constraints( + let constraints = Self::new_constraint_from_table_constraints( &all_constraints, plan.schema(), )?; @@ -304,6 +470,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if_not_exists, or_replace, column_defaults, + temporary, }, ))) } @@ -312,12 +479,37 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Statement::CreateView { or_replace, + materialized, name, columns, query, options: CreateTableOptions::None, - .. + cluster_by, + comment, + with_no_schema_binding, + if_not_exists, + temporary, + to, } => { + if materialized { + return not_impl_err!("Materialized views not supported")?; + } + if !cluster_by.is_empty() { + return not_impl_err!("Cluster by not supported")?; + } + if comment.is_some() { + return not_impl_err!("Comment not supported")?; + } + if with_no_schema_binding { + return not_impl_err!("With no schema binding not supported")?; + } + if if_not_exists { + return not_impl_err!("If not exists not supported")?; + } + if to.is_some() { + return not_impl_err!("To not supported")?; + } + let columns = columns .into_iter() .map(|view_column_def| { @@ -339,6 +531,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { input: Arc::new(plan), or_replace, definition: sql, + temporary, }))) } Statement::ShowCreate { obj_type, obj_name } => match obj_type { @@ -405,18 +598,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } ObjectType::Schema => { let name = match name { - TableReference::Bare { table } => Ok(SchemaReference::Bare { schema: table } ) , - TableReference::Partial { schema, table } => Ok(SchemaReference::Full { schema: table,catalog: schema }), + TableReference::Bare { table } => Ok(SchemaReference::Bare { schema: table }), + TableReference::Partial { schema, table } => Ok(SchemaReference::Full { schema: table, catalog: schema }), TableReference::Full { catalog: _, schema: _, table: _ } => { Err(ParserError("Invalid schema specifier (has 3 parts)".to_string())) - }, + } }?; Ok(LogicalPlan::Ddl(DdlStatement::DropCatalogSchema(DropCatalogSchema { name, if_exists, cascade, schema: DFSchemaRef::new(DFSchema::empty()), - })))}, + }))) + } _ => not_impl_err!( "Only `DROP TABLE/VIEW/SCHEMA ...` statement is supported currently" ), @@ -448,6 +642,30 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { input: Arc::new(plan), })) } + Statement::Execute { + name, + parameters, + using, + } => { + // `USING` is a MySQL-specific syntax and currently not supported. + if !using.is_empty() { + return not_impl_err!( + "Execute statement with USING is not supported" + ); + } + + let empty_schema = DFSchema::empty(); + let parameters = parameters + .into_iter() + .map(|expr| self.sql_to_expr(expr, &empty_schema, planner_context)) + .collect::>>()?; + + Ok(LogicalPlan::Execute(Execute { + name: ident_to_string(&name), + parameters, + schema: DFSchemaRef::new(empty_schema), + })) + } Statement::ShowTables { extended, @@ -463,7 +681,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { filter, } => self.show_columns_to_plan(extended, full, table_name, filter), - Statement::Insert { + Statement::Insert(Insert { or, into, table_name, @@ -477,12 +695,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { returning, ignore, table_alias, - replace_into, + mut replace_into, priority, insert_alias, - } => { - if or.is_some() { - plan_err!("Inserts with or clauses not supported")?; + }) => { + if let Some(or) = or { + match or { + SqliteOnConflict::Replace => replace_into = true, + _ => plan_err!("Inserts with {or} clause is not supported")?, + } } if partitioned.is_some() { plan_err!("Partitioned inserts not yet supported")?; @@ -510,9 +731,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { "Inserts with a table alias not supported: {table_alias:?}" )? }; - if replace_into { - plan_err!("Inserts with a `REPLACE INTO` clause not supported")? - }; if let Some(priority) = priority { plan_err!( "Inserts with a `PRIORITY` clause not supported: {priority:?}" @@ -522,7 +740,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan_err!("Inserts with an alias not supported")?; } let _ = into; // optional keyword doesn't change behavior - self.insert_to_plan(table_name, columns, source, overwrite) + self.insert_to_plan(table_name, columns, source, overwrite, replace_into) } Statement::Update { table, @@ -537,7 +755,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.update_to_plan(table, assignments, from, selection) } - Statement::Delete { + Statement::Delete(Delete { tables, using, selection, @@ -545,7 +763,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { from, order_by, limit, - } => { + }) => { if !tables.is_empty() { plan_err!("DELETE not supported")?; } @@ -582,7 +800,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } let isolation_level: ast::TransactionIsolationLevel = modes .iter() - .filter_map(|m: &ast::TransactionMode| match m { + .filter_map(|m: &TransactionMode| match m { TransactionMode::AccessMode(_) => None, TransactionMode::IsolationLevel(level) => Some(level), }) @@ -591,7 +809,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .unwrap_or(ast::TransactionIsolationLevel::Serializable); let access_mode: ast::TransactionAccessMode = modes .iter() - .filter_map(|m: &ast::TransactionMode| match m { + .filter_map(|m: &TransactionMode| match m { TransactionMode::AccessMode(mode) => Some(mode), TransactionMode::IsolationLevel(_) => None, }) @@ -652,7 +870,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { name, args, return_type, - params, + function_body, + behavior, + language, + .. } => { let return_type = match return_type { Some(t) => Some(self.convert_data_type(&t)?), @@ -687,14 +908,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } None => None, }; - // at the moment functions can't be qualified `schema.name` + // At the moment functions can't be qualified `schema.name` let name = match &name.0[..] { [] => exec_err!("Function should have name")?, [n] => n.value.clone(), [..] => not_impl_err!("Qualified functions are not supported")?, }; // - // convert resulting expression to data fusion expression + // Convert resulting expression to data fusion expression // let arg_types = args.as_ref().map(|arg| { arg.iter().map(|t| t.data_type.clone()).collect::>() @@ -702,9 +923,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut planner_context = PlannerContext::new() .with_prepare_param_data_types(arg_types.unwrap_or_default()); - let result_expression = match params.return_ { + let function_body = match function_body { Some(r) => Some(self.sql_to_expr( - r, + match r { + ast::CreateFunctionBody::AsBeforeOptions(expr) => expr, + ast::CreateFunctionBody::AsAfterOptions(expr) => expr, + ast::CreateFunctionBody::Return(expr) => expr, + }, &DFSchema::empty(), &mut planner_context, )?), @@ -712,14 +937,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }; let params = CreateFunctionBody { - language: params.language, - behavior: params.behavior.map(|b| match b { + language, + behavior: behavior.map(|b| match b { ast::FunctionBehavior::Immutable => Volatility::Immutable, ast::FunctionBehavior::Stable => Volatility::Stable, ast::FunctionBehavior::Volatile => Volatility::Volatile, }), - as_: params.as_.map(|m| m.into()), - return_: result_expression, + function_body, }; let statement = DdlStatement::CreateFunction(CreateFunction { @@ -739,10 +963,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { func_desc, .. } => { - // according to postgresql documentation it can be only one function + // According to postgresql documentation it can be only one function // specified in drop statement if let Some(desc) = func_desc.first() { - // at the moment functions can't be qualified `schema.name` + // At the moment functions can't be qualified `schema.name` let name = match &desc.name.0[..] { [] => exec_err!("Function should have name")?, [n] => n.value.clone(), @@ -758,6 +982,42 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { exec_err!("Function name not provided") } } + Statement::CreateIndex(CreateIndex { + name, + table_name, + using, + columns, + unique, + if_not_exists, + .. + }) => { + let name: Option = name.as_ref().map(object_name_to_string); + let table = self.object_name_to_table_reference(table_name)?; + let table_schema = self + .context_provider + .get_table_source(table.clone())? + .schema() + .to_dfschema_ref()?; + let using: Option = using.as_ref().map(ident_to_string); + let columns = self.order_by_to_sort_expr( + columns, + &table_schema, + planner_context, + false, + None, + )?; + Ok(LogicalPlan::Ddl(DdlStatement::CreateIndex( + PlanCreateIndex { + name, + table, + using, + columns, + unique, + if_not_exists, + schema: DFSchemaRef::new(DFSchema::empty()), + }, + ))) + } _ => { not_impl_err!("Unsupported SQL statement: {sql:?}") } @@ -798,7 +1058,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { filter: Option, ) -> Result { if self.has_table("information_schema", "tables") { - // we only support the basic "SHOW TABLES" + // We only support the basic "SHOW TABLES" // https://github.com/apache/datafusion/issues/3188 if db_name.is_some() || filter.is_some() || full || extended { plan_err!("Unsupported parameters to SHOW TABLES") @@ -829,7 +1089,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } fn copy_to_plan(&self, statement: CopyToStatement) -> Result { - // determine if source is table or query and handle accordingly + // Determine if source is table or query and handle accordingly let copy_source = statement.source; let (input, input_schema, table_ref) = match copy_source { CopyToSource::Relation(object_name) => { @@ -839,72 +1099,47 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { self.context_provider.get_table_source(table_ref.clone())?; let plan = LogicalPlanBuilder::scan(table_name, table_source, None)?.build()?; - let input_schema = plan.schema().clone(); + let input_schema = Arc::clone(plan.schema()); (plan, input_schema, Some(table_ref)) } CopyToSource::Query(query) => { let plan = self.query_to_plan(query, &mut PlannerContext::new())?; - let input_schema = plan.schema().clone(); + let input_schema = Arc::clone(plan.schema()); (plan, input_schema, None) } }; - let mut options = HashMap::new(); - for (key, value) in statement.options { - let value_string = match value { - Value::SingleQuotedString(s) => s.to_string(), - Value::DollarQuotedString(s) => s.to_string(), - Value::UnQuotedString(s) => s.to_string(), - Value::Number(_, _) | Value::Boolean(_) => value.to_string(), - Value::DoubleQuotedString(_) - | Value::EscapedStringLiteral(_) - | Value::NationalStringLiteral(_) - | Value::SingleQuotedByteStringLiteral(_) - | Value::DoubleQuotedByteStringLiteral(_) - | Value::RawStringLiteral(_) - | Value::HexStringLiteral(_) - | Value::Null - | Value::Placeholder(_) => { - return plan_err!("Unsupported Value in COPY statement {}", value); - } - }; - if !(&key.contains('.')) { - // If config does not belong to any namespace, assume it is - // a format option and apply the format prefix for backwards - // compatibility. + let options_map = self.parse_options_map(statement.options, true)?; - let renamed_key = format!("format.{}", key); - options.insert(renamed_key.to_lowercase(), value_string.to_lowercase()); + let maybe_file_type = if let Some(stored_as) = &statement.stored_as { + if let Ok(ext_file_type) = self.context_provider.get_file_type(stored_as) { + Some(ext_file_type) } else { - options.insert(key.to_lowercase(), value_string.to_lowercase()); + None } - } - - let file_type = if let Some(file_type) = statement.stored_as { - FileType::from_str(&file_type).map_err(|_| { - DataFusionError::Configuration(format!("Unknown FileType {}", file_type)) - })? } else { - let e = || { - DataFusionError::Configuration( - "Format not explicitly set and unable to get file extension! Use STORED AS to define file format." - .to_string(), - ) - }; - // try to infer file format from file extension - let extension: &str = &Path::new(&statement.target) - .extension() - .ok_or_else(e)? - .to_str() - .ok_or_else(e)? - .to_lowercase(); - - FileType::from_str(extension).map_err(|e| { - DataFusionError::Configuration(format!( - "{}. Use STORED AS to define file format.", - e - )) - })? + None + }; + + let file_type = match maybe_file_type { + Some(ft) => ft, + None => { + let e = || { + DataFusionError::Configuration( + "Format not explicitly set and unable to get file extension! Use STORED AS to define file format." + .to_string(), + ) + }; + // Try to infer file format from file extension + let extension: &str = &Path::new(&statement.target) + .extension() + .ok_or_else(e)? + .to_str() + .ok_or_else(e)? + .to_lowercase(); + + self.context_provider.get_file_type(extension)? + } }; let partition_by = statement @@ -919,9 +1154,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Ok(LogicalPlan::Copy(CopyTo { input: Arc::new(input), output_url: statement.target, - format_options: file_type.into(), + file_type, partition_by, - options, + options: options_map, })) } @@ -930,22 +1165,44 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { order_exprs: Vec, schema: &DFSchemaRef, planner_context: &mut PlannerContext, - ) -> Result>> { - // Ask user to provide a schema if schema is empty. + ) -> Result>> { if !order_exprs.is_empty() && schema.fields().is_empty() { - return plan_err!( - "Provide a schema before specifying the order while creating a table." - ); + let results = order_exprs + .iter() + .map(|lex_order| { + let result = lex_order + .iter() + .map(|order_by_expr| { + let ordered_expr = &order_by_expr.expr; + let ordered_expr = ordered_expr.to_owned(); + let ordered_expr = self + .sql_expr_to_logical_expr( + ordered_expr, + schema, + planner_context, + ) + .unwrap(); + let asc = order_by_expr.asc.unwrap_or(true); + let nulls_first = order_by_expr.nulls_first.unwrap_or(!asc); + + SortExpr::new(ordered_expr, asc, nulls_first) + }) + .collect::>(); + result + }) + .collect::>>(); + + return Ok(results); } let mut all_results = vec![]; for expr in order_exprs { // Convert each OrderByExpr to a SortExpr: let expr_vec = - self.order_by_to_sort_expr(&expr, schema, planner_context, true)?; + self.order_by_to_sort_expr(expr, schema, planner_context, true, None)?; // Verify that columns of all SortExprs exist in the schema: - for expr in expr_vec.iter() { - for column in expr.to_columns()?.iter() { + for sort in expr_vec.iter() { + for column in sort.expr.column_refs().iter() { if !schema.has_column(column) { // Return an error if any column is not in the schema: return plan_err!("Column {column} is not in schema"); @@ -968,12 +1225,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { name, columns, file_type, - has_header, - delimiter, location, table_partition_cols, if_not_exists, - file_compression_type, + temporary, order_exprs, unbounded, options, @@ -985,8 +1240,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let inline_constraints = calc_inline_constraints_from_columns(&columns); all_constraints.extend(inline_constraints); + let options_map = self.parse_options_map(options, false)?; + + let compression = options_map + .get("format.compression") + .map(|c| CompressionTypeVariant::from_str(c)) + .transpose()?; if (file_type == "PARQUET" || file_type == "AVRO" || file_type == "ARROW") - && file_compression_type != CompressionTypeVariant::UNCOMPRESSED + && compression + .map(|c| c != CompressionTypeVariant::UNCOMPRESSED) + .unwrap_or(false) { plan_err!( "File compression type cannot be set for PARQUET, AVRO, or ARROW files." @@ -1007,31 +1270,126 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let ordered_exprs = self.build_order_by(order_exprs, &df_schema, &mut planner_context)?; - // External tables do not support schemas at the moment, so the name is just a table name - let name = TableReference::bare(name); + let name = self.object_name_to_table_reference(name)?; let constraints = - Constraints::new_from_table_constraints(&all_constraints, &df_schema)?; + Self::new_constraint_from_table_constraints(&all_constraints, &df_schema)?; Ok(LogicalPlan::Ddl(DdlStatement::CreateExternalTable( PlanCreateExternalTable { schema: df_schema, name, location, file_type, - has_header, - delimiter, table_partition_cols, if_not_exists, + temporary, definition, - file_compression_type, order_exprs: ordered_exprs, unbounded, - options, + options: options_map, constraints, column_defaults, }, ))) } + /// Convert each [TableConstraint] to corresponding [Constraint] + fn new_constraint_from_table_constraints( + constraints: &[TableConstraint], + df_schema: &DFSchemaRef, + ) -> Result { + let constraints = constraints + .iter() + .map(|c: &TableConstraint| match c { + TableConstraint::Unique { name, columns, .. } => { + let field_names = df_schema.field_names(); + // Get unique constraint indices in the schema: + let indices = columns + .iter() + .map(|u| { + let idx = field_names + .iter() + .position(|item| *item == u.value) + .ok_or_else(|| { + let name = name + .as_ref() + .map(|name| format!("with name '{name}' ")) + .unwrap_or("".to_string()); + DataFusionError::Execution( + format!("Column for unique constraint {}not found in schema: {}", name,u.value) + ) + })?; + Ok(idx) + }) + .collect::>>()?; + Ok(Constraint::Unique(indices)) + } + TableConstraint::PrimaryKey { columns, .. } => { + let field_names = df_schema.field_names(); + // Get primary key indices in the schema: + let indices = columns + .iter() + .map(|pk| { + let idx = field_names + .iter() + .position(|item| *item == pk.value) + .ok_or_else(|| { + DataFusionError::Execution(format!( + "Column for primary key not found in schema: {}", + pk.value + )) + })?; + Ok(idx) + }) + .collect::>>()?; + Ok(Constraint::PrimaryKey(indices)) + } + TableConstraint::ForeignKey { .. } => { + _plan_err!("Foreign key constraints are not currently supported") + } + TableConstraint::Check { .. } => { + _plan_err!("Check constraints are not currently supported") + } + TableConstraint::Index { .. } => { + _plan_err!("Indexes are not currently supported") + } + TableConstraint::FulltextOrSpatial { .. } => { + _plan_err!("Indexes are not currently supported") + } + }) + .collect::>>()?; + Ok(Constraints::new_unverified(constraints)) + } + + fn parse_options_map( + &self, + options: Vec<(String, Value)>, + allow_duplicates: bool, + ) -> Result> { + let mut options_map = HashMap::new(); + for (key, value) in options { + if !allow_duplicates && options_map.contains_key(&key) { + return plan_err!("Option {key} is specified multiple times"); + } + + let Some(value_string) = self.value_normalizer.normalize(value.clone()) + else { + return plan_err!("Unsupported Value {}", value); + }; + + if !(&key.contains('.')) { + // If config does not belong to any namespace, assume it is + // a format option and apply the format prefix for backwards + // compatibility. + let renamed_key = format!("format.{}", key); + options_map.insert(renamed_key.to_lowercase(), value_string); + } else { + options_map.insert(key.to_lowercase(), value_string); + } + } + + Ok(options_map) + } + /// Generate a plan for EXPLAIN ... that will print out a plan /// /// Note this is the sqlparser explain statement, not the @@ -1097,6 +1455,22 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // we could introduce alias in OptionDefinition if this string matching thing grows format!("{base_query} WHERE name = 'datafusion.execution.time_zone'") } else { + // These values are what are used to make the information_schema table, so we just + // check here, before actually planning or executing the query, if it would produce no + // results, and error preemptively if it would (for a better UX) + let is_valid_variable = self + .context_provider + .options() + .entries() + .iter() + .any(|opt| opt.key == variable); + + if !is_valid_variable { + return plan_err!( + "'{variable}' is not a variable which can be viewed with 'SHOW'" + ); + } + format!("{base_query} WHERE name = '{variable}'") }; @@ -1110,8 +1484,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &self, local: bool, hivevar: bool, - variable: &ObjectName, - value: Vec, + variables: &OneOrManyWithParens, + value: Vec, ) -> Result { if local { return not_impl_err!("LOCAL is not supported"); @@ -1121,35 +1495,31 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { return not_impl_err!("HIVEVAR is not supported"); } - let variable = object_name_to_string(variable); + let variable = match variables { + OneOrManyWithParens::One(v) => object_name_to_string(v), + OneOrManyWithParens::Many(vs) => { + return not_impl_err!( + "SET only supports single variable assignment: {vs:?}" + ); + } + }; let mut variable_lower = variable.to_lowercase(); if variable_lower == "timezone" || variable_lower == "time.zone" { - // we could introduce alias in OptionDefinition if this string matching thing grows + // We could introduce alias in OptionDefinition if this string matching thing grows variable_lower = "datafusion.execution.time_zone".to_string(); } - // parse value string from Expr + // Parse value string from Expr let value_string = match &value[0] { SQLExpr::Identifier(i) => ident_to_string(i), - SQLExpr::Value(v) => match v { - Value::SingleQuotedString(s) => s.to_string(), - Value::DollarQuotedString(s) => s.to_string(), - Value::Number(_, _) | Value::Boolean(_) => v.to_string(), - Value::DoubleQuotedString(_) - | Value::UnQuotedString(_) - | Value::EscapedStringLiteral(_) - | Value::NationalStringLiteral(_) - | Value::SingleQuotedByteStringLiteral(_) - | Value::DoubleQuotedByteStringLiteral(_) - | Value::RawStringLiteral(_) - | Value::HexStringLiteral(_) - | Value::Null - | Value::Placeholder(_) => { + SQLExpr::Value(v) => match crate::utils::value_to_string(v) { + None => { return plan_err!("Unsupported Value {}", value[0]); } + Some(v) => v, }, - // for capture signed number e.g. +8, -8 + // For capture signed number e.g. +8, -8 SQLExpr::UnaryOp { op, expr } => match op { UnaryOperator::Plus => format!("+{expr}"), UnaryOperator::Minus => format!("-{expr}"), @@ -1174,7 +1544,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fn delete_to_plan( &self, table_name: ObjectName, - predicate_expr: Option, + predicate_expr: Option, ) -> Result { // Do a table lookup to verify the table exists let table_ref = self.object_name_to_table_reference(table_name.clone())?; @@ -1206,12 +1576,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } }; - let plan = LogicalPlan::Dml(DmlStatement { - table_name: table_ref, - table_schema: schema.into(), - op: WriteOp::Delete, - input: Arc::new(source), - }); + let plan = LogicalPlan::Dml(DmlStatement::new( + table_ref, + schema.into(), + WriteOp::Delete, + Arc::new(source), + )); Ok(plan) } @@ -1220,7 +1590,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { table: TableWithJoins, assignments: Vec, from: Option, - predicate_expr: Option, + predicate_expr: Option, ) -> Result { let (table_name, table_alias) = match &table.relation { TableFactor::Table { name, alias, .. } => (name.clone(), alias.clone()), @@ -1240,8 +1610,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut assign_map = assignments .iter() .map(|assign| { - let col_name: &Ident = assign - .id + let cols = match &assign.target { + AssignmentTarget::ColumnName(cols) => cols, + _ => plan_err!("Tuples are not supported")?, + }; + let col_name: &Ident = cols + .0 .iter() .last() .ok_or_else(|| plan_datafusion_err!("Empty column id"))?; @@ -1249,7 +1623,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { table_schema.field_with_unqualified_name(&col_name.value)?; Ok((col_name.value.clone(), assign.value.clone())) }) - .collect::>>()?; + .collect::>>()?; // Build scan, join with from table if it exists. let mut input_tables = vec![table]; @@ -1288,8 +1662,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &mut planner_context, )?; // Update placeholder's datatype to the type of the target column - if let datafusion_expr::Expr::Placeholder(placeholder) = &mut expr - { + if let Expr::Placeholder(placeholder) = &mut expr { placeholder.data_type = placeholder .data_type .take() @@ -1301,14 +1674,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { None => { // If the target table has an alias, use it to qualify the column name if let Some(alias) = &table_alias { - datafusion_expr::Expr::Column(Column::new( - Some(self.normalizer.normalize(alias.name.clone())), + Expr::Column(Column::new( + Some(self.ident_normalizer.normalize(alias.name.clone())), field.name(), )) } else { - datafusion_expr::Expr::Column(Column::from(( - qualifier, field, - ))) + Expr::Column(Column::from((qualifier, field))) } } }; @@ -1318,12 +1689,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let source = project(source, exprs)?; - let plan = LogicalPlan::Dml(DmlStatement { + let plan = LogicalPlan::Dml(DmlStatement::new( table_name, table_schema, - op: WriteOp::Update, - input: Arc::new(source), - }); + WriteOp::Update, + Arc::new(source), + )); Ok(plan) } @@ -1333,6 +1704,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { columns: Vec, source: Box, overwrite: bool, + replace_into: bool, ) -> Result { // Do a table lookup to verify the table exists let table_name = self.object_name_to_table_reference(table_name)?; @@ -1342,10 +1714,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Get insert fields and target table's value indices // - // if value_indices[i] = Some(j), it means that the value of the i-th target table's column is + // If value_indices[i] = Some(j), it means that the value of the i-th target table's column is // derived from the j-th output of the source. // - // if value_indices[i] = None, it means that the value of the i-th target table's column is + // If value_indices[i] = None, it means that the value of the i-th target table's column is // not provided, and should be filled with a default value later. let (fields, value_indices) = if columns.is_empty() { // Empty means we're inserting into all columns of the table @@ -1359,7 +1731,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut value_indices = vec![None; table_schema.fields().len()]; let fields = columns .into_iter() - .map(|c| self.normalizer.normalize(c)) + .map(|c| self.ident_normalizer.normalize(c)) .enumerate() .map(|(i, c)| { let column_index = table_schema @@ -1383,7 +1755,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if let SetExpr::Values(ast::Values { rows, .. }) = (*source.body).clone() { for row in rows.iter() { for (idx, val) in row.iter().enumerate() { - if let ast::Expr::Value(Value::Placeholder(name)) = val { + if let SQLExpr::Value(Value::Placeholder(name)) = val { let name = name.replace('$', "").parse::().map_err(|_| { plan_datafusion_err!("Can't parse placeholder: {name}") @@ -1416,37 +1788,38 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .map(|(i, value_index)| { let target_field = table_schema.field(i); let expr = match value_index { - Some(v) => datafusion_expr::Expr::Column(Column::from( - source.schema().qualified_field(v), - )) - .cast_to(target_field.data_type(), source.schema())?, + Some(v) => { + Expr::Column(Column::from(source.schema().qualified_field(v))) + .cast_to(target_field.data_type(), source.schema())? + } // The value is not specified. Fill in the default value for the column. None => table_source .get_column_default(target_field.name()) .cloned() .unwrap_or_else(|| { // If there is no default for the column, then the default is NULL - datafusion_expr::Expr::Literal(ScalarValue::Null) + Expr::Literal(ScalarValue::Null) }) .cast_to(target_field.data_type(), &DFSchema::empty())?, }; Ok(expr.alias(target_field.name())) }) - .collect::>>()?; + .collect::>>()?; let source = project(source, exprs)?; - let op = if overwrite { - WriteOp::InsertOverwrite - } else { - WriteOp::InsertInto + let insert_op = match (overwrite, replace_into) { + (false, false) => InsertOp::Append, + (true, false) => InsertOp::Overwrite, + (false, true) => InsertOp::Replace, + (true, true) => plan_err!("Conflicting insert operations: `overwrite` and `replace_into` cannot both be true")?, }; - let plan = LogicalPlan::Dml(DmlStatement { + let plan = LogicalPlan::Dml(DmlStatement::new( table_name, - table_schema: Arc::new(table_schema), - op, - input: Arc::new(source), - }); + Arc::new(table_schema), + WriteOp::Insert(insert_op), + Arc::new(source), + )); Ok(plan) } @@ -1476,7 +1849,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let table_ref = self.object_name_to_table_reference(sql_table_name)?; let _ = self.context_provider.get_table_source(table_ref)?; - // treat both FULL and EXTENDED as the same + // Treat both FULL and EXTENDED as the same let select_list = if full || extended { "*" } else { diff --git a/datafusion/sql/src/unparser/ast.rs b/datafusion/sql/src/unparser/ast.rs index 0a76aee2e066..2de1ce9125a7 100644 --- a/datafusion/sql/src/unparser/ast.rs +++ b/datafusion/sql/src/unparser/ast.rs @@ -41,66 +41,69 @@ pub(super) struct QueryBuilder { #[allow(dead_code)] impl QueryBuilder { pub fn with(&mut self, value: Option) -> &mut Self { - let new = self; - new.with = value; - new + self.with = value; + self } pub fn body(&mut self, value: Box) -> &mut Self { - let new = self; - new.body = Option::Some(value); - new + self.body = Some(value); + self + } + pub fn take_body(&mut self) -> Option> { + self.body.take() } pub fn order_by(&mut self, value: Vec) -> &mut Self { - let new = self; - new.order_by = value; - new + self.order_by = value; + self } pub fn limit(&mut self, value: Option) -> &mut Self { - let new = self; - new.limit = value; - new + self.limit = value; + self } pub fn limit_by(&mut self, value: Vec) -> &mut Self { - let new = self; - new.limit_by = value; - new + self.limit_by = value; + self } pub fn offset(&mut self, value: Option) -> &mut Self { - let new = self; - new.offset = value; - new + self.offset = value; + self } pub fn fetch(&mut self, value: Option) -> &mut Self { - let new = self; - new.fetch = value; - new + self.fetch = value; + self } pub fn locks(&mut self, value: Vec) -> &mut Self { - let new = self; - new.locks = value; - new + self.locks = value; + self } pub fn for_clause(&mut self, value: Option) -> &mut Self { - let new = self; - new.for_clause = value; - new + self.for_clause = value; + self } pub fn build(&self) -> Result { + let order_by = if self.order_by.is_empty() { + None + } else { + Some(ast::OrderBy { + exprs: self.order_by.clone(), + interpolate: None, + }) + }; + Ok(ast::Query { with: self.with.clone(), body: match self.body { Some(ref value) => value.clone(), - None => { - return Result::Err(Into::into(UninitializedFieldError::from("body"))) - } + None => return Err(Into::into(UninitializedFieldError::from("body"))), }, - order_by: self.order_by.clone(), + order_by, limit: self.limit.clone(), limit_by: self.limit_by.clone(), offset: self.offset.clone(), fetch: self.fetch.clone(), locks: self.locks.clone(), for_clause: self.for_clause.clone(), + settings: None, + format_clause: None, }) } fn create_empty() -> Self { @@ -145,90 +148,95 @@ pub(super) struct SelectBuilder { #[allow(dead_code)] impl SelectBuilder { pub fn distinct(&mut self, value: Option) -> &mut Self { - let new = self; - new.distinct = value; - new + self.distinct = value; + self } pub fn top(&mut self, value: Option) -> &mut Self { - let new = self; - new.top = value; - new + self.top = value; + self } pub fn projection(&mut self, value: Vec) -> &mut Self { - let new = self; - new.projection = value; - new + self.projection = value; + self } pub fn already_projected(&self) -> bool { !self.projection.is_empty() } pub fn into(&mut self, value: Option) -> &mut Self { - let new = self; - new.into = value; - new + self.into = value; + self } pub fn from(&mut self, value: Vec) -> &mut Self { - let new = self; - new.from = value; - new + self.from = value; + self } pub fn push_from(&mut self, value: TableWithJoinsBuilder) -> &mut Self { - let new = self; - new.from.push(value); - new + self.from.push(value); + self } pub fn pop_from(&mut self) -> Option { self.from.pop() } pub fn lateral_views(&mut self, value: Vec) -> &mut Self { - let new = self; - new.lateral_views = value; - new + self.lateral_views = value; + self } pub fn selection(&mut self, value: Option) -> &mut Self { - let new = self; - new.selection = value; - new + // With filter pushdown optimization, the LogicalPlan can have filters defined as part of `TableScan` and `Filter` nodes. + // To avoid overwriting one of the filters, we combine the existing filter with the additional filter. + // Example: | + // | Projection: customer.c_phone AS cntrycode, customer.c_acctbal | + // | Filter: CAST(customer.c_acctbal AS Decimal128(38, 6)) > () | + // | Subquery: + // | .. | + // | TableScan: customer, full_filters=[customer.c_mktsegment = Utf8("BUILDING")] + match (&self.selection, value) { + (Some(existing_selection), Some(new_selection)) => { + self.selection = Some(ast::Expr::BinaryOp { + left: Box::new(existing_selection.clone()), + op: ast::BinaryOperator::And, + right: Box::new(new_selection), + }); + } + (None, Some(new_selection)) => { + self.selection = Some(new_selection); + } + (_, None) => (), + } + + self } pub fn group_by(&mut self, value: ast::GroupByExpr) -> &mut Self { - let new = self; - new.group_by = Option::Some(value); - new + self.group_by = Some(value); + self } pub fn cluster_by(&mut self, value: Vec) -> &mut Self { - let new = self; - new.cluster_by = value; - new + self.cluster_by = value; + self } pub fn distribute_by(&mut self, value: Vec) -> &mut Self { - let new = self; - new.distribute_by = value; - new + self.distribute_by = value; + self } pub fn sort_by(&mut self, value: Vec) -> &mut Self { - let new = self; - new.sort_by = value; - new + self.sort_by = value; + self } pub fn having(&mut self, value: Option) -> &mut Self { - let new = self; - new.having = value; - new + self.having = value; + self } pub fn named_window(&mut self, value: Vec) -> &mut Self { - let new = self; - new.named_window = value; - new + self.named_window = value; + self } pub fn qualify(&mut self, value: Option) -> &mut Self { - let new = self; - new.qualify = value; - new + self.qualify = value; + self } pub fn value_table_mode(&mut self, value: Option) -> &mut Self { - let new = self; - new.value_table_mode = value; - new + self.value_table_mode = value; + self } pub fn build(&self) -> Result { Ok(ast::Select { @@ -239,16 +247,14 @@ impl SelectBuilder { from: self .from .iter() - .map(|b| b.build()) + .filter_map(|b| b.build().transpose()) .collect::, BuilderError>>()?, lateral_views: self.lateral_views.clone(), selection: self.selection.clone(), group_by: match self.group_by { Some(ref value) => value.clone(), None => { - return Result::Err(Into::into(UninitializedFieldError::from( - "group_by", - ))) + return Err(Into::into(UninitializedFieldError::from("group_by"))) } }, cluster_by: self.cluster_by.clone(), @@ -258,6 +264,9 @@ impl SelectBuilder { named_window: self.named_window.clone(), qualify: self.qualify.clone(), value_table_mode: self.value_table_mode, + connect_by: None, + window_before_qualify: false, + prewhere: None, }) } fn create_empty() -> Self { @@ -269,7 +278,7 @@ impl SelectBuilder { from: Default::default(), lateral_views: Default::default(), selection: Default::default(), - group_by: Some(ast::GroupByExpr::Expressions(Vec::new())), + group_by: Some(ast::GroupByExpr::Expressions(Vec::new(), Vec::new())), cluster_by: Default::default(), distribute_by: Default::default(), sort_by: Default::default(), @@ -295,34 +304,30 @@ pub(super) struct TableWithJoinsBuilder { #[allow(dead_code)] impl TableWithJoinsBuilder { pub fn relation(&mut self, value: RelationBuilder) -> &mut Self { - let new = self; - new.relation = Option::Some(value); - new + self.relation = Some(value); + self } pub fn joins(&mut self, value: Vec) -> &mut Self { - let new = self; - new.joins = value; - new + self.joins = value; + self } pub fn push_join(&mut self, value: ast::Join) -> &mut Self { - let new = self; - new.joins.push(value); - new + self.joins.push(value); + self } - pub fn build(&self) -> Result { - Ok(ast::TableWithJoins { - relation: match self.relation { - Some(ref value) => value.build()?, - None => { - return Result::Err(Into::into(UninitializedFieldError::from( - "relation", - ))) - } + pub fn build(&self) -> Result, BuilderError> { + match self.relation { + Some(ref value) => match value.build()? { + Some(relation) => Ok(Some(ast::TableWithJoins { + relation, + joins: self.joins.clone(), + })), + None => Ok(None), }, - joins: self.joins.clone(), - }) + None => Err(Into::into(UninitializedFieldError::from("relation"))), + } } fn create_empty() -> Self { Self { @@ -347,6 +352,7 @@ pub(super) struct RelationBuilder { enum TableFactorBuilder { Table(TableRelationBuilder), Derived(DerivedRelationBuilder), + Empty, } #[allow(dead_code)] @@ -355,14 +361,16 @@ impl RelationBuilder { self.relation.is_some() } pub fn table(&mut self, value: TableRelationBuilder) -> &mut Self { - let new = self; - new.relation = Option::Some(TableFactorBuilder::Table(value)); - new + self.relation = Some(TableFactorBuilder::Table(value)); + self } pub fn derived(&mut self, value: DerivedRelationBuilder) -> &mut Self { - let new = self; - new.relation = Option::Some(TableFactorBuilder::Derived(value)); - new + self.relation = Some(TableFactorBuilder::Derived(value)); + self + } + pub fn empty(&mut self) -> &mut Self { + self.relation = Some(TableFactorBuilder::Empty); + self } pub fn alias(&mut self, value: Option) -> &mut Self { let new = self; @@ -373,17 +381,17 @@ impl RelationBuilder { Some(TableFactorBuilder::Derived(ref mut rel_builder)) => { rel_builder.alias = value; } + Some(TableFactorBuilder::Empty) => (), None => (), } new } - pub fn build(&self) -> Result { + pub fn build(&self) -> Result, BuilderError> { Ok(match self.relation { - Some(TableFactorBuilder::Table(ref value)) => value.build()?, - Some(TableFactorBuilder::Derived(ref value)) => value.build()?, - None => { - return Result::Err(Into::into(UninitializedFieldError::from("relation"))) - } + Some(TableFactorBuilder::Table(ref value)) => Some(value.build()?), + Some(TableFactorBuilder::Derived(ref value)) => Some(value.build()?), + Some(TableFactorBuilder::Empty) => None, + None => return Err(Into::into(UninitializedFieldError::from("relation"))), }) } fn create_empty() -> Self { @@ -411,48 +419,44 @@ pub(super) struct TableRelationBuilder { #[allow(dead_code)] impl TableRelationBuilder { pub fn name(&mut self, value: ast::ObjectName) -> &mut Self { - let new = self; - new.name = Option::Some(value); - new + self.name = Some(value); + self } pub fn alias(&mut self, value: Option) -> &mut Self { - let new = self; - new.alias = value; - new + self.alias = value; + self } pub fn args(&mut self, value: Option>) -> &mut Self { - let new = self; - new.args = value; - new + self.args = value; + self } pub fn with_hints(&mut self, value: Vec) -> &mut Self { - let new = self; - new.with_hints = value; - new + self.with_hints = value; + self } pub fn version(&mut self, value: Option) -> &mut Self { - let new = self; - new.version = value; - new + self.version = value; + self } pub fn partitions(&mut self, value: Vec) -> &mut Self { - let new = self; - new.partitions = value; - new + self.partitions = value; + self } pub fn build(&self) -> Result { Ok(ast::TableFactor::Table { name: match self.name { Some(ref value) => value.clone(), - None => { - return Result::Err(Into::into(UninitializedFieldError::from("name"))) - } + None => return Err(Into::into(UninitializedFieldError::from("name"))), }, alias: self.alias.clone(), - args: self.args.clone(), + args: self.args.clone().map(|args| ast::TableFunctionArgs { + args, + settings: None, + }), with_hints: self.with_hints.clone(), version: self.version.clone(), partitions: self.partitions.clone(), + with_ordinality: false, }) } fn create_empty() -> Self { @@ -481,36 +485,27 @@ pub(super) struct DerivedRelationBuilder { #[allow(dead_code)] impl DerivedRelationBuilder { pub fn lateral(&mut self, value: bool) -> &mut Self { - let new = self; - new.lateral = Option::Some(value); - new + self.lateral = Some(value); + self } pub fn subquery(&mut self, value: Box) -> &mut Self { - let new = self; - new.subquery = Option::Some(value); - new + self.subquery = Some(value); + self } pub fn alias(&mut self, value: Option) -> &mut Self { - let new = self; - new.alias = value; - new + self.alias = value; + self } fn build(&self) -> Result { Ok(ast::TableFactor::Derived { lateral: match self.lateral { Some(ref value) => *value, - None => { - return Result::Err(Into::into(UninitializedFieldError::from( - "lateral", - ))) - } + None => return Err(Into::into(UninitializedFieldError::from("lateral"))), }, subquery: match self.subquery { Some(ref value) => value.clone(), None => { - return Result::Err(Into::into(UninitializedFieldError::from( - "subquery", - ))) + return Err(Into::into(UninitializedFieldError::from("subquery"))) } }, alias: self.alias.clone(), @@ -536,7 +531,7 @@ impl Default for DerivedRelationBuilder { pub(super) struct UninitializedFieldError(&'static str); impl UninitializedFieldError { - /// Create a new `UnitializedFieldError` for the specified field name. + /// Create a new `UninitializedFieldError` for the specified field name. pub fn new(field_name: &'static str) -> Self { UninitializedFieldError(field_name) } diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index ccbd388fc4b7..88159ab6df15 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -15,59 +15,666 @@ // specific language governing permissions and limitations // under the License. -/// Dialect is used to capture dialect specific syntax. -/// Note: this trait will eventually be replaced by the Dialect in the SQLparser package +use std::sync::Arc; + +use arrow_schema::TimeUnit; +use datafusion_expr::Expr; +use regex::Regex; +use sqlparser::{ + ast::{self, Function, Ident, ObjectName, TimezoneInfo}, + keywords::ALL_KEYWORDS, +}; + +use datafusion_common::Result; + +use super::{utils::date_part_to_sql, Unparser}; + +/// `Dialect` to use for Unparsing +/// +/// The default dialect tries to avoid quoting identifiers unless necessary (e.g. `a` instead of `"a"`) +/// but this behavior can be overridden as needed +/// +/// **Note**: This trait will eventually be replaced by the Dialect in the SQLparser package /// /// See -pub trait Dialect { - fn identifier_quote_style(&self) -> Option; +/// See also the discussion in +pub trait Dialect: Send + Sync { + /// Return the character used to quote identifiers. + fn identifier_quote_style(&self, _identifier: &str) -> Option; + + /// Does the dialect support specifying `NULLS FIRST/LAST` in `ORDER BY` clauses? + fn supports_nulls_first_in_sort(&self) -> bool { + true + } + + /// Does the dialect use TIMESTAMP to represent Date64 rather than DATETIME? + /// E.g. Trino, Athena and Dremio does not have DATETIME data type + fn use_timestamp_for_date64(&self) -> bool { + false + } + + fn interval_style(&self) -> IntervalStyle { + IntervalStyle::PostgresVerbose + } + + /// Does the dialect use DOUBLE PRECISION to represent Float64 rather than DOUBLE? + /// E.g. Postgres uses DOUBLE PRECISION instead of DOUBLE + fn float64_ast_dtype(&self) -> ast::DataType { + ast::DataType::Double + } + + /// The SQL type to use for Arrow Utf8 unparsing + /// Most dialects use VARCHAR, but some, like MySQL, require CHAR + fn utf8_cast_dtype(&self) -> ast::DataType { + ast::DataType::Varchar(None) + } + + /// The SQL type to use for Arrow LargeUtf8 unparsing + /// Most dialects use TEXT, but some, like MySQL, require CHAR + fn large_utf8_cast_dtype(&self) -> ast::DataType { + ast::DataType::Text + } + + /// The date field extract style to use: `DateFieldExtractStyle` + fn date_field_extract_style(&self) -> DateFieldExtractStyle { + DateFieldExtractStyle::DatePart + } + + /// The SQL type to use for Arrow Int64 unparsing + /// Most dialects use BigInt, but some, like MySQL, require SIGNED + fn int64_cast_dtype(&self) -> ast::DataType { + ast::DataType::BigInt(None) + } + + /// The SQL type to use for Arrow Int32 unparsing + /// Most dialects use Integer, but some, like MySQL, require SIGNED + fn int32_cast_dtype(&self) -> ast::DataType { + ast::DataType::Integer(None) + } + + /// The SQL type to use for Timestamp unparsing + /// Most dialects use Timestamp, but some, like MySQL, require Datetime + /// Some dialects like Dremio does not support WithTimeZone and requires always Timestamp + fn timestamp_cast_dtype( + &self, + _time_unit: &TimeUnit, + tz: &Option>, + ) -> ast::DataType { + let tz_info = match tz { + Some(_) => TimezoneInfo::WithTimeZone, + None => TimezoneInfo::None, + }; + + ast::DataType::Timestamp(None, tz_info) + } + + /// The SQL type to use for Arrow Date32 unparsing + /// Most dialects use Date, but some, like SQLite require TEXT + fn date32_cast_dtype(&self) -> ast::DataType { + ast::DataType::Date + } + + /// Does the dialect support specifying column aliases as part of alias table definition? + /// (SELECT col1, col2 from my_table) AS my_table_alias(col1_alias, col2_alias) + fn supports_column_alias_in_table_alias(&self) -> bool { + true + } + + /// Whether the dialect requires a table alias for any subquery in the FROM clause + /// This affects behavior when deriving logical plans for Sort, Limit, etc. + fn requires_derived_table_alias(&self) -> bool { + false + } + + /// Allows the dialect to override scalar function unparsing if the dialect has specific rules. + /// Returns None if the default unparsing should be used, or Some(ast::Expr) if there is + /// a custom implementation for the function. + fn scalar_function_to_sql_overrides( + &self, + _unparser: &Unparser, + _func_name: &str, + _args: &[Expr], + ) -> Result> { + Ok(None) + } +} + +/// `IntervalStyle` to use for unparsing +/// +/// +/// different DBMS follows different standards, popular ones are: +/// postgres_verbose: '2 years 15 months 100 weeks 99 hours 123456789 milliseconds' which is +/// compatible with arrow display format, as well as duckdb +/// sql standard format is '1-2' for year-month, or '1 10:10:10.123456' for day-time +/// +#[derive(Clone, Copy)] +pub enum IntervalStyle { + PostgresVerbose, + SQLStandard, + MySQL, } + +/// Datetime subfield extraction style for unparsing +/// +/// `` +/// Different DBMSs follow different standards; popular ones are: +/// date_part('YEAR', date '2001-02-16') +/// EXTRACT(YEAR from date '2001-02-16') +/// Some DBMSs, like Postgres, support both, whereas others like MySQL require EXTRACT. +#[derive(Clone, Copy, PartialEq)] +pub enum DateFieldExtractStyle { + DatePart, + Extract, + Strftime, +} + pub struct DefaultDialect {} impl Dialect for DefaultDialect { - fn identifier_quote_style(&self) -> Option { - Some('"') + fn identifier_quote_style(&self, identifier: &str) -> Option { + let identifier_regex = Regex::new(r"^[a-zA-Z_][a-zA-Z0-9_]*$").unwrap(); + let id_upper = identifier.to_uppercase(); + // Special case ignore "ID", see https://github.com/sqlparser-rs/sqlparser-rs/issues/1382 + // ID is a keyword in ClickHouse, but we don't want to quote it when unparsing SQL here + if (id_upper != "ID" && ALL_KEYWORDS.contains(&id_upper.as_str())) + || !identifier_regex.is_match(identifier) + { + Some('"') + } else { + None + } } } pub struct PostgreSqlDialect {} impl Dialect for PostgreSqlDialect { - fn identifier_quote_style(&self) -> Option { + fn identifier_quote_style(&self, _: &str) -> Option { Some('"') } + + fn interval_style(&self) -> IntervalStyle { + IntervalStyle::PostgresVerbose + } + + fn float64_ast_dtype(&self) -> ast::DataType { + ast::DataType::DoublePrecision + } + + fn scalar_function_to_sql_overrides( + &self, + unparser: &Unparser, + func_name: &str, + args: &[Expr], + ) -> Result> { + if func_name == "round" { + return Ok(Some( + self.round_to_sql_enforce_numeric(unparser, func_name, args)?, + )); + } + + Ok(None) + } +} + +impl PostgreSqlDialect { + fn round_to_sql_enforce_numeric( + &self, + unparser: &Unparser, + func_name: &str, + args: &[Expr], + ) -> Result { + let mut args = unparser.function_args_to_sql(args)?; + + // Enforce the first argument to be Numeric + if let Some(ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(expr))) = + args.first_mut() + { + if let ast::Expr::Cast { data_type, .. } = expr { + // Don't create an additional cast wrapper if we can update the existing one + *data_type = ast::DataType::Numeric(ast::ExactNumberInfo::None); + } else { + // Wrap the expression in a new cast + *expr = ast::Expr::Cast { + kind: ast::CastKind::Cast, + expr: Box::new(expr.clone()), + data_type: ast::DataType::Numeric(ast::ExactNumberInfo::None), + format: None, + }; + } + } + + Ok(ast::Expr::Function(Function { + name: ObjectName(vec![Ident { + value: func_name.to_string(), + quote_style: None, + }]), + args: ast::FunctionArguments::List(ast::FunctionArgumentList { + duplicate_treatment: None, + args, + clauses: vec![], + }), + filter: None, + null_treatment: None, + over: None, + within_group: vec![], + parameters: ast::FunctionArguments::None, + })) + } } pub struct MySqlDialect {} impl Dialect for MySqlDialect { - fn identifier_quote_style(&self) -> Option { + fn identifier_quote_style(&self, _: &str) -> Option { Some('`') } + + fn supports_nulls_first_in_sort(&self) -> bool { + false + } + + fn interval_style(&self) -> IntervalStyle { + IntervalStyle::MySQL + } + + fn utf8_cast_dtype(&self) -> ast::DataType { + ast::DataType::Char(None) + } + + fn large_utf8_cast_dtype(&self) -> ast::DataType { + ast::DataType::Char(None) + } + + fn date_field_extract_style(&self) -> DateFieldExtractStyle { + DateFieldExtractStyle::Extract + } + + fn int64_cast_dtype(&self) -> ast::DataType { + ast::DataType::Custom(ObjectName(vec![Ident::new("SIGNED")]), vec![]) + } + + fn int32_cast_dtype(&self) -> ast::DataType { + ast::DataType::Custom(ObjectName(vec![Ident::new("SIGNED")]), vec![]) + } + + fn timestamp_cast_dtype( + &self, + _time_unit: &TimeUnit, + _tz: &Option>, + ) -> ast::DataType { + ast::DataType::Datetime(None) + } + + fn requires_derived_table_alias(&self) -> bool { + true + } + + fn scalar_function_to_sql_overrides( + &self, + unparser: &Unparser, + func_name: &str, + args: &[Expr], + ) -> Result> { + if func_name == "date_part" { + return date_part_to_sql(unparser, self.date_field_extract_style(), args); + } + + Ok(None) + } } pub struct SqliteDialect {} impl Dialect for SqliteDialect { - fn identifier_quote_style(&self) -> Option { + fn identifier_quote_style(&self, _: &str) -> Option { Some('`') } + + fn date_field_extract_style(&self) -> DateFieldExtractStyle { + DateFieldExtractStyle::Strftime + } + + fn date32_cast_dtype(&self) -> ast::DataType { + ast::DataType::Text + } + + fn supports_column_alias_in_table_alias(&self) -> bool { + false + } + + fn scalar_function_to_sql_overrides( + &self, + unparser: &Unparser, + func_name: &str, + args: &[Expr], + ) -> Result> { + if func_name == "date_part" { + return date_part_to_sql(unparser, self.date_field_extract_style(), args); + } + + Ok(None) + } } pub struct CustomDialect { identifier_quote_style: Option, + supports_nulls_first_in_sort: bool, + use_timestamp_for_date64: bool, + interval_style: IntervalStyle, + float64_ast_dtype: ast::DataType, + utf8_cast_dtype: ast::DataType, + large_utf8_cast_dtype: ast::DataType, + date_field_extract_style: DateFieldExtractStyle, + int64_cast_dtype: ast::DataType, + int32_cast_dtype: ast::DataType, + timestamp_cast_dtype: ast::DataType, + timestamp_tz_cast_dtype: ast::DataType, + date32_cast_dtype: ast::DataType, + supports_column_alias_in_table_alias: bool, + requires_derived_table_alias: bool, +} + +impl Default for CustomDialect { + fn default() -> Self { + Self { + identifier_quote_style: None, + supports_nulls_first_in_sort: true, + use_timestamp_for_date64: false, + interval_style: IntervalStyle::SQLStandard, + float64_ast_dtype: ast::DataType::Double, + utf8_cast_dtype: ast::DataType::Varchar(None), + large_utf8_cast_dtype: ast::DataType::Text, + date_field_extract_style: DateFieldExtractStyle::DatePart, + int64_cast_dtype: ast::DataType::BigInt(None), + int32_cast_dtype: ast::DataType::Integer(None), + timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None), + timestamp_tz_cast_dtype: ast::DataType::Timestamp( + None, + TimezoneInfo::WithTimeZone, + ), + date32_cast_dtype: ast::DataType::Date, + supports_column_alias_in_table_alias: true, + requires_derived_table_alias: false, + } + } } impl CustomDialect { + // Create a CustomDialect + #[deprecated(note = "please use `CustomDialectBuilder` instead")] pub fn new(identifier_quote_style: Option) -> Self { Self { identifier_quote_style, + ..Default::default() } } } impl Dialect for CustomDialect { - fn identifier_quote_style(&self) -> Option { + fn identifier_quote_style(&self, _: &str) -> Option { self.identifier_quote_style } + + fn supports_nulls_first_in_sort(&self) -> bool { + self.supports_nulls_first_in_sort + } + + fn use_timestamp_for_date64(&self) -> bool { + self.use_timestamp_for_date64 + } + + fn interval_style(&self) -> IntervalStyle { + self.interval_style + } + + fn float64_ast_dtype(&self) -> ast::DataType { + self.float64_ast_dtype.clone() + } + + fn utf8_cast_dtype(&self) -> ast::DataType { + self.utf8_cast_dtype.clone() + } + + fn large_utf8_cast_dtype(&self) -> ast::DataType { + self.large_utf8_cast_dtype.clone() + } + + fn date_field_extract_style(&self) -> DateFieldExtractStyle { + self.date_field_extract_style + } + + fn int64_cast_dtype(&self) -> ast::DataType { + self.int64_cast_dtype.clone() + } + + fn int32_cast_dtype(&self) -> ast::DataType { + self.int32_cast_dtype.clone() + } + + fn timestamp_cast_dtype( + &self, + _time_unit: &TimeUnit, + tz: &Option>, + ) -> ast::DataType { + if tz.is_some() { + self.timestamp_tz_cast_dtype.clone() + } else { + self.timestamp_cast_dtype.clone() + } + } + + fn date32_cast_dtype(&self) -> ast::DataType { + self.date32_cast_dtype.clone() + } + + fn supports_column_alias_in_table_alias(&self) -> bool { + self.supports_column_alias_in_table_alias + } + + fn scalar_function_to_sql_overrides( + &self, + unparser: &Unparser, + func_name: &str, + args: &[Expr], + ) -> Result> { + if func_name == "date_part" { + return date_part_to_sql(unparser, self.date_field_extract_style(), args); + } + + Ok(None) + } + + fn requires_derived_table_alias(&self) -> bool { + self.requires_derived_table_alias + } +} + +/// `CustomDialectBuilder` to build `CustomDialect` using builder pattern +/// +/// +/// # Examples +/// +/// Building a custom dialect with all default options set in CustomDialectBuilder::new() +/// but with `use_timestamp_for_date64` overridden to `true` +/// +/// ``` +/// use datafusion_sql::unparser::dialect::CustomDialectBuilder; +/// let dialect = CustomDialectBuilder::new() +/// .with_use_timestamp_for_date64(true) +/// .build(); +/// ``` +pub struct CustomDialectBuilder { + identifier_quote_style: Option, + supports_nulls_first_in_sort: bool, + use_timestamp_for_date64: bool, + interval_style: IntervalStyle, + float64_ast_dtype: ast::DataType, + utf8_cast_dtype: ast::DataType, + large_utf8_cast_dtype: ast::DataType, + date_field_extract_style: DateFieldExtractStyle, + int64_cast_dtype: ast::DataType, + int32_cast_dtype: ast::DataType, + timestamp_cast_dtype: ast::DataType, + timestamp_tz_cast_dtype: ast::DataType, + date32_cast_dtype: ast::DataType, + supports_column_alias_in_table_alias: bool, + requires_derived_table_alias: bool, +} + +impl Default for CustomDialectBuilder { + fn default() -> Self { + Self::new() + } +} + +impl CustomDialectBuilder { + pub fn new() -> Self { + Self { + identifier_quote_style: None, + supports_nulls_first_in_sort: true, + use_timestamp_for_date64: false, + interval_style: IntervalStyle::PostgresVerbose, + float64_ast_dtype: ast::DataType::Double, + utf8_cast_dtype: ast::DataType::Varchar(None), + large_utf8_cast_dtype: ast::DataType::Text, + date_field_extract_style: DateFieldExtractStyle::DatePart, + int64_cast_dtype: ast::DataType::BigInt(None), + int32_cast_dtype: ast::DataType::Integer(None), + timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None), + timestamp_tz_cast_dtype: ast::DataType::Timestamp( + None, + TimezoneInfo::WithTimeZone, + ), + date32_cast_dtype: ast::DataType::Date, + supports_column_alias_in_table_alias: true, + requires_derived_table_alias: false, + } + } + + pub fn build(self) -> CustomDialect { + CustomDialect { + identifier_quote_style: self.identifier_quote_style, + supports_nulls_first_in_sort: self.supports_nulls_first_in_sort, + use_timestamp_for_date64: self.use_timestamp_for_date64, + interval_style: self.interval_style, + float64_ast_dtype: self.float64_ast_dtype, + utf8_cast_dtype: self.utf8_cast_dtype, + large_utf8_cast_dtype: self.large_utf8_cast_dtype, + date_field_extract_style: self.date_field_extract_style, + int64_cast_dtype: self.int64_cast_dtype, + int32_cast_dtype: self.int32_cast_dtype, + timestamp_cast_dtype: self.timestamp_cast_dtype, + timestamp_tz_cast_dtype: self.timestamp_tz_cast_dtype, + date32_cast_dtype: self.date32_cast_dtype, + supports_column_alias_in_table_alias: self + .supports_column_alias_in_table_alias, + requires_derived_table_alias: self.requires_derived_table_alias, + } + } + + /// Customize the dialect with a specific identifier quote style, e.g. '`', '"' + pub fn with_identifier_quote_style(mut self, identifier_quote_style: char) -> Self { + self.identifier_quote_style = Some(identifier_quote_style); + self + } + + /// Customize the dialect to support `NULLS FIRST` in `ORDER BY` clauses + pub fn with_supports_nulls_first_in_sort( + mut self, + supports_nulls_first_in_sort: bool, + ) -> Self { + self.supports_nulls_first_in_sort = supports_nulls_first_in_sort; + self + } + + /// Customize the dialect to uses TIMESTAMP when casting Date64 rather than DATETIME + pub fn with_use_timestamp_for_date64( + mut self, + use_timestamp_for_date64: bool, + ) -> Self { + self.use_timestamp_for_date64 = use_timestamp_for_date64; + self + } + + /// Customize the dialect with a specific interval style listed in `IntervalStyle` + pub fn with_interval_style(mut self, interval_style: IntervalStyle) -> Self { + self.interval_style = interval_style; + self + } + + /// Customize the dialect with a specific SQL type for Float64 casting: DOUBLE, DOUBLE PRECISION, etc. + pub fn with_float64_ast_dtype(mut self, float64_ast_dtype: ast::DataType) -> Self { + self.float64_ast_dtype = float64_ast_dtype; + self + } + + /// Customize the dialect with a specific SQL type for Utf8 casting: VARCHAR, CHAR, etc. + pub fn with_utf8_cast_dtype(mut self, utf8_cast_dtype: ast::DataType) -> Self { + self.utf8_cast_dtype = utf8_cast_dtype; + self + } + + /// Customize the dialect with a specific SQL type for LargeUtf8 casting: TEXT, CHAR, etc. + pub fn with_large_utf8_cast_dtype( + mut self, + large_utf8_cast_dtype: ast::DataType, + ) -> Self { + self.large_utf8_cast_dtype = large_utf8_cast_dtype; + self + } + + /// Customize the dialect with a specific date field extract style listed in `DateFieldExtractStyle` + pub fn with_date_field_extract_style( + mut self, + date_field_extract_style: DateFieldExtractStyle, + ) -> Self { + self.date_field_extract_style = date_field_extract_style; + self + } + + /// Customize the dialect with a specific SQL type for Int64 casting: BigInt, SIGNED, etc. + pub fn with_int64_cast_dtype(mut self, int64_cast_dtype: ast::DataType) -> Self { + self.int64_cast_dtype = int64_cast_dtype; + self + } + + /// Customize the dialect with a specific SQL type for Int32 casting: Integer, SIGNED, etc. + pub fn with_int32_cast_dtype(mut self, int32_cast_dtype: ast::DataType) -> Self { + self.int32_cast_dtype = int32_cast_dtype; + self + } + + /// Customize the dialect with a specific SQL type for Timestamp casting: Timestamp, Datetime, etc. + pub fn with_timestamp_cast_dtype( + mut self, + timestamp_cast_dtype: ast::DataType, + timestamp_tz_cast_dtype: ast::DataType, + ) -> Self { + self.timestamp_cast_dtype = timestamp_cast_dtype; + self.timestamp_tz_cast_dtype = timestamp_tz_cast_dtype; + self + } + + pub fn with_date32_cast_dtype(mut self, date32_cast_dtype: ast::DataType) -> Self { + self.date32_cast_dtype = date32_cast_dtype; + self + } + + /// Customize the dialect to support column aliases as part of alias table definition + pub fn with_supports_column_alias_in_table_alias( + mut self, + supports_column_alias_in_table_alias: bool, + ) -> Self { + self.supports_column_alias_in_table_alias = supports_column_alias_in_table_alias; + self + } + + pub fn with_requires_derived_table_alias( + mut self, + requires_derived_table_alias: bool, + ) -> Self { + self.requires_derived_table_alias = requires_derived_table_alias; + self + } } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index d091fbe14dbd..b41b0a54b86f 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -15,45 +15,83 @@ // specific language governing permissions and limitations // under the License. -use arrow_array::{Date32Array, Date64Array}; +use datafusion_expr::expr::Unnest; +use sqlparser::ast::Value::SingleQuotedString; +use sqlparser::ast::{ + self, BinaryOperator, Expr as AstExpr, Function, Ident, Interval, ObjectName, + TimezoneInfo, UnaryOperator, +}; +use std::sync::Arc; +use std::vec; + +use super::dialect::IntervalStyle; +use super::Unparser; +use arrow::datatypes::{Decimal128Type, Decimal256Type, DecimalType}; +use arrow::util::display::array_value_to_string; +use arrow_array::types::{ + ArrowTemporalType, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, + Time64NanosecondType, TimestampMicrosecondType, TimestampMillisecondType, + TimestampNanosecondType, TimestampSecondType, +}; +use arrow_array::{Date32Array, Date64Array, PrimitiveArray}; use arrow_schema::DataType; use datafusion_common::{ - internal_datafusion_err, not_impl_err, plan_err, Column, Result, ScalarValue, + internal_datafusion_err, internal_err, not_impl_err, plan_err, Column, Result, + ScalarValue, }; use datafusion_expr::{ - expr::{ - AggregateFunctionDefinition, Alias, Exists, InList, ScalarFunction, Sort, - WindowFunction, - }, - Between, BinaryExpr, Case, Cast, Expr, Like, Operator, -}; -use sqlparser::ast::{ - self, Expr as AstExpr, Function, FunctionArg, Ident, UnaryOperator, + expr::{Alias, Exists, InList, ScalarFunction, Sort, WindowFunction}, + Between, BinaryExpr, Case, Cast, Expr, GroupingSet, Like, Operator, TryCast, }; -use super::Unparser; - -/// Convert a DataFusion [`Expr`] to `sqlparser::ast::Expr` +/// Convert a DataFusion [`Expr`] to [`ast::Expr`] +/// +/// This function is the opposite of [`SqlToRel::sql_to_expr`] and can be used +/// to, among other things, convert [`Expr`]s to SQL strings. Such strings could +/// be used to pass filters or other expressions to another SQL engine. +/// +/// # Errors +/// +/// Throws an error if [`Expr`] can not be represented by an [`ast::Expr`] /// -/// This function is the opposite of `SqlToRel::sql_to_expr` and can -/// be used to, among other things, convert `Expr`s to strings. +/// # See Also +/// +/// * [`Unparser`] for more control over the conversion to SQL +/// * [`plan_to_sql`] for converting a [`LogicalPlan`] to SQL /// /// # Example /// ``` /// use datafusion_expr::{col, lit}; /// use datafusion_sql::unparser::expr_to_sql; -/// let expr = col("a").gt(lit(4)); -/// let sql = expr_to_sql(&expr).unwrap(); -/// -/// assert_eq!(format!("{}", sql), "(\"a\" > 4)") +/// let expr = col("a").gt(lit(4)); // form an expression `a > 4` +/// let sql = expr_to_sql(&expr).unwrap(); // convert to ast::Expr +/// // use the Display impl to convert to SQL text +/// assert_eq!(sql.to_string(), "(a > 4)") /// ``` +/// +/// [`SqlToRel::sql_to_expr`]: crate::planner::SqlToRel::sql_to_expr +/// [`plan_to_sql`]: crate::unparser::plan_to_sql +/// [`LogicalPlan`]: datafusion_expr::logical_plan::LogicalPlan pub fn expr_to_sql(expr: &Expr) -> Result { let unparser = Unparser::default(); unparser.expr_to_sql(expr) } +const LOWEST: &BinaryOperator = &BinaryOperator::Or; +// Closest precedence we have to IS operator is BitwiseAnd (any other) in PG docs +// (https://www.postgresql.org/docs/7.2/sql-precedence.html) +const IS: &BinaryOperator = &BinaryOperator::BitwiseAnd; + impl Unparser<'_> { pub fn expr_to_sql(&self, expr: &Expr) -> Result { + let mut root_expr = self.expr_to_sql_inner(expr)?; + if self.pretty { + root_expr = self.remove_unnecessary_nesting(root_expr, LOWEST, LOWEST); + } + Ok(root_expr) + } + + fn expr_to_sql_inner(&self, expr: &Expr) -> Result { match expr { Expr::InList(InList { expr, @@ -62,43 +100,25 @@ impl Unparser<'_> { }) => { let list_expr = list .iter() - .map(|e| self.expr_to_sql(e)) + .map(|e| self.expr_to_sql_inner(e)) .collect::>>()?; Ok(ast::Expr::InList { - expr: Box::new(self.expr_to_sql(expr)?), + expr: Box::new(self.expr_to_sql_inner(expr)?), list: list_expr, negated: *negated, }) } - Expr::ScalarFunction(ScalarFunction { func_def, args }) => { - let func_name = func_def.name(); + Expr::ScalarFunction(ScalarFunction { func, args }) => { + let func_name = func.name(); - let args = args - .iter() - .map(|e| { - if matches!(e, Expr::Wildcard { qualifier: None }) { - Ok(FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard)) - } else { - self.expr_to_sql(e).map(|e| { - FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) - }) - } - }) - .collect::>>()?; + if let Some(expr) = self + .dialect + .scalar_function_to_sql_overrides(self, func_name, args)? + { + return Ok(expr); + } - Ok(ast::Expr::Function(Function { - name: ast::ObjectName(vec![Ident { - value: func_name.to_string(), - quote_style: None, - }]), - args, - filter: None, - null_treatment: None, - over: None, - distinct: false, - special: false, - order_by: vec![], - })) + self.scalar_function_to_sql(func_name, args) } Expr::Between(Between { expr, @@ -106,9 +126,9 @@ impl Unparser<'_> { low, high, }) => { - let sql_parser_expr = self.expr_to_sql(expr)?; - let sql_low = self.expr_to_sql(low)?; - let sql_high = self.expr_to_sql(high)?; + let sql_parser_expr = self.expr_to_sql_inner(expr)?; + let sql_low = self.expr_to_sql_inner(low)?; + let sql_high = self.expr_to_sql_inner(high)?; Ok(ast::Expr::Nested(Box::new(self.between_op_to_sql( sql_parser_expr, *negated, @@ -118,8 +138,8 @@ impl Unparser<'_> { } Expr::Column(col) => self.col_to_sql(col), Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let l = self.expr_to_sql(left.as_ref())?; - let r = self.expr_to_sql(right.as_ref())?; + let l = self.expr_to_sql_inner(left.as_ref())?; + let r = self.expr_to_sql_inner(right.as_ref())?; let op = self.op_to_sql(op)?; Ok(ast::Expr::Nested(Box::new(self.binary_op_to_sql(l, r, op)))) @@ -131,21 +151,21 @@ impl Unparser<'_> { }) => { let conditions = when_then_expr .iter() - .map(|(w, _)| self.expr_to_sql(w)) + .map(|(w, _)| self.expr_to_sql_inner(w)) .collect::>>()?; let results = when_then_expr .iter() - .map(|(_, t)| self.expr_to_sql(t)) + .map(|(_, t)| self.expr_to_sql_inner(t)) .collect::>>()?; let operand = match expr.as_ref() { - Some(e) => match self.expr_to_sql(e) { + Some(e) => match self.expr_to_sql_inner(e) { Ok(sql_expr) => Some(Box::new(sql_expr)), Err(_) => None, }, None => None, }; let else_result = match else_expr.as_ref() { - Some(e) => match self.expr_to_sql(e) { + Some(e) => match self.expr_to_sql_inner(e) { Ok(sql_expr) => Some(Box::new(sql_expr)), Err(_) => None, }, @@ -160,24 +180,89 @@ impl Unparser<'_> { }) } Expr::Cast(Cast { expr, data_type }) => { - let inner_expr = self.expr_to_sql(expr)?; - Ok(ast::Expr::Cast { - expr: Box::new(inner_expr), - data_type: self.arrow_dtype_to_ast_dtype(data_type)?, - format: None, - }) + let inner_expr = self.expr_to_sql_inner(expr)?; + match data_type { + DataType::Dictionary(_, _) => match inner_expr { + // Dictionary values don't need to be cast to other types when rewritten back to sql + ast::Expr::Value(_) => Ok(inner_expr), + _ => Ok(ast::Expr::Cast { + kind: ast::CastKind::Cast, + expr: Box::new(inner_expr), + data_type: self.arrow_dtype_to_ast_dtype(data_type)?, + format: None, + }), + }, + _ => Ok(ast::Expr::Cast { + kind: ast::CastKind::Cast, + expr: Box::new(inner_expr), + data_type: self.arrow_dtype_to_ast_dtype(data_type)?, + format: None, + }), + } } Expr::Literal(value) => Ok(self.scalar_to_sql(value)?), - Expr::Alias(Alias { expr, name: _, .. }) => self.expr_to_sql(expr), + Expr::Alias(Alias { expr, name: _, .. }) => self.expr_to_sql_inner(expr), Expr::WindowFunction(WindowFunction { - fun: _, - args: _, - partition_by: _, - order_by: _, - window_frame: _, + fun, + args, + partition_by, + order_by, + window_frame, null_treatment: _, }) => { - not_impl_err!("Unsupported expression: {expr:?}") + let func_name = fun.name(); + + let args = self.function_args_to_sql(args)?; + + let units = match window_frame.units { + datafusion_expr::window_frame::WindowFrameUnits::Rows => { + ast::WindowFrameUnits::Rows + } + datafusion_expr::window_frame::WindowFrameUnits::Range => { + ast::WindowFrameUnits::Range + } + datafusion_expr::window_frame::WindowFrameUnits::Groups => { + ast::WindowFrameUnits::Groups + } + }; + + let order_by = order_by + .iter() + .map(|sort_expr| self.sort_to_sql(sort_expr)) + .collect::>>()?; + + let start_bound = self.convert_bound(&window_frame.start_bound)?; + let end_bound = self.convert_bound(&window_frame.end_bound)?; + let over = Some(ast::WindowType::WindowSpec(ast::WindowSpec { + window_name: None, + partition_by: partition_by + .iter() + .map(|e| self.expr_to_sql_inner(e)) + .collect::>>()?, + order_by, + window_frame: Some(ast::WindowFrame { + units, + start_bound, + end_bound: Option::from(end_bound), + }), + })); + + Ok(ast::Expr::Function(Function { + name: ObjectName(vec![Ident { + value: func_name.to_string(), + quote_style: None, + }]), + args: ast::FunctionArguments::List(ast::FunctionArgumentList { + duplicate_treatment: None, + args, + clauses: vec![], + }), + filter: None, + null_treatment: None, + over, + within_group: vec![], + parameters: ast::FunctionArguments::None, + })) } Expr::SimilarTo(Like { negated, @@ -194,47 +279,35 @@ impl Unparser<'_> { case_insensitive: _, }) => Ok(ast::Expr::Like { negated: *negated, - expr: Box::new(self.expr_to_sql(expr)?), - pattern: Box::new(self.expr_to_sql(pattern)?), - escape_char: *escape_char, + expr: Box::new(self.expr_to_sql_inner(expr)?), + pattern: Box::new(self.expr_to_sql_inner(pattern)?), + escape_char: escape_char.map(|c| c.to_string()), }), Expr::AggregateFunction(agg) => { - let func_name = if let AggregateFunctionDefinition::BuiltIn(built_in) = - &agg.func_def - { - built_in.name() - } else { - return not_impl_err!( - "Only built in agg functions are supported, got {agg:?}" - ); - }; - - let args = agg - .args - .iter() - .map(|e| { - if matches!(e, Expr::Wildcard { qualifier: None }) { - Ok(FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard)) - } else { - self.expr_to_sql(e).map(|e| { - FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) - }) - } - }) - .collect::>>()?; + let func_name = agg.func.name(); + let args = self.function_args_to_sql(&agg.args)?; + let filter = match &agg.filter { + Some(filter) => Some(Box::new(self.expr_to_sql_inner(filter)?)), + None => None, + }; Ok(ast::Expr::Function(Function { - name: ast::ObjectName(vec![Ident { + name: ObjectName(vec![Ident { value: func_name.to_string(), quote_style: None, }]), - args, - filter: None, + args: ast::FunctionArguments::List(ast::FunctionArgumentList { + duplicate_treatment: agg + .distinct + .then_some(ast::DuplicateTreatment::Distinct), + args, + clauses: vec![], + }), + filter, null_treatment: None, over: None, - distinct: agg.distinct, - special: false, - order_by: vec![], + within_group: vec![], + parameters: ast::FunctionArguments::None, })) } Expr::ScalarSubquery(subq) => { @@ -250,7 +323,7 @@ impl Unparser<'_> { Ok(ast::Expr::Subquery(sub_query)) } Expr::InSubquery(insubq) => { - let inexpr = Box::new(self.expr_to_sql(insubq.expr.as_ref())?); + let inexpr = Box::new(self.expr_to_sql_inner(insubq.expr.as_ref())?); let sub_statement = self.plan_to_sql(insubq.subquery.subquery.as_ref())?; let sub_query = if let ast::Statement::Query(inner_query) = sub_statement @@ -282,65 +355,173 @@ impl Unparser<'_> { negated: *negated, }) } - Expr::Sort(Sort { - expr, - asc: _, - nulls_first: _, - }) => self.expr_to_sql(expr), - Expr::IsNotNull(expr) => { - Ok(ast::Expr::IsNotNull(Box::new(self.expr_to_sql(expr)?))) + Expr::IsNull(expr) => { + Ok(ast::Expr::IsNull(Box::new(self.expr_to_sql_inner(expr)?))) } + Expr::IsNotNull(expr) => Ok(ast::Expr::IsNotNull(Box::new( + self.expr_to_sql_inner(expr)?, + ))), Expr::IsTrue(expr) => { - Ok(ast::Expr::IsTrue(Box::new(self.expr_to_sql(expr)?))) - } - Expr::IsNotTrue(expr) => { - Ok(ast::Expr::IsNotTrue(Box::new(self.expr_to_sql(expr)?))) + Ok(ast::Expr::IsTrue(Box::new(self.expr_to_sql_inner(expr)?))) } + Expr::IsNotTrue(expr) => Ok(ast::Expr::IsNotTrue(Box::new( + self.expr_to_sql_inner(expr)?, + ))), Expr::IsFalse(expr) => { - Ok(ast::Expr::IsFalse(Box::new(self.expr_to_sql(expr)?))) - } - Expr::IsUnknown(expr) => { - Ok(ast::Expr::IsUnknown(Box::new(self.expr_to_sql(expr)?))) - } - Expr::IsNotUnknown(expr) => { - Ok(ast::Expr::IsNotUnknown(Box::new(self.expr_to_sql(expr)?))) - } + Ok(ast::Expr::IsFalse(Box::new(self.expr_to_sql_inner(expr)?))) + } + Expr::IsNotFalse(expr) => Ok(ast::Expr::IsNotFalse(Box::new( + self.expr_to_sql_inner(expr)?, + ))), + Expr::IsUnknown(expr) => Ok(ast::Expr::IsUnknown(Box::new( + self.expr_to_sql_inner(expr)?, + ))), + Expr::IsNotUnknown(expr) => Ok(ast::Expr::IsNotUnknown(Box::new( + self.expr_to_sql_inner(expr)?, + ))), Expr::Not(expr) => { - let sql_parser_expr = self.expr_to_sql(expr)?; + let sql_parser_expr = self.expr_to_sql_inner(expr)?; Ok(AstExpr::UnaryOp { op: UnaryOperator::Not, expr: Box::new(sql_parser_expr), }) } Expr::Negative(expr) => { - let sql_parser_expr = self.expr_to_sql(expr)?; + let sql_parser_expr = self.expr_to_sql_inner(expr)?; Ok(AstExpr::UnaryOp { op: UnaryOperator::Minus, expr: Box::new(sql_parser_expr), }) } - Expr::ScalarVariable(_, _) => { - not_impl_err!("Unsupported Expr conversion: {expr:?}") - } - Expr::IsNull(_) => not_impl_err!("Unsupported Expr conversion: {expr:?}"), - Expr::IsNotFalse(_) => not_impl_err!("Unsupported Expr conversion: {expr:?}"), - Expr::GetIndexedField(_) => { - not_impl_err!("Unsupported Expr conversion: {expr:?}") - } - Expr::TryCast(_) => not_impl_err!("Unsupported Expr conversion: {expr:?}"), - Expr::Wildcard { qualifier: _ } => { - not_impl_err!("Unsupported Expr conversion: {expr:?}") - } - Expr::GroupingSet(_) => { - not_impl_err!("Unsupported Expr conversion: {expr:?}") - } - Expr::Placeholder(_) => { - not_impl_err!("Unsupported Expr conversion: {expr:?}") + Expr::ScalarVariable(_, ids) => { + if ids.is_empty() { + return internal_err!("Not a valid ScalarVariable"); + } + + Ok(if ids.len() == 1 { + ast::Expr::Identifier( + self.new_ident_without_quote_style(ids[0].to_string()), + ) + } else { + ast::Expr::CompoundIdentifier( + ids.iter() + .map(|i| self.new_ident_without_quote_style(i.to_string())) + .collect(), + ) + }) } - Expr::OuterReferenceColumn(_, _) => { - not_impl_err!("Unsupported Expr conversion: {expr:?}") + Expr::TryCast(TryCast { expr, data_type }) => { + let inner_expr = self.expr_to_sql_inner(expr)?; + Ok(ast::Expr::Cast { + kind: ast::CastKind::TryCast, + expr: Box::new(inner_expr), + data_type: self.arrow_dtype_to_ast_dtype(data_type)?, + format: None, + }) } - Expr::Unnest(_) => not_impl_err!("Unsupported Expr conversion: {expr:?}"), + // TODO: unparsing wildcard addition options + Expr::Wildcard { qualifier, .. } => { + if let Some(qualifier) = qualifier { + let idents: Vec = + qualifier.to_vec().into_iter().map(Ident::new).collect(); + Ok(ast::Expr::QualifiedWildcard(ObjectName(idents))) + } else { + Ok(ast::Expr::Wildcard) + } + } + Expr::GroupingSet(grouping_set) => match grouping_set { + GroupingSet::GroupingSets(grouping_sets) => { + let expr_ast_sets = grouping_sets + .iter() + .map(|set| { + set.iter() + .map(|e| self.expr_to_sql_inner(e)) + .collect::>>() + }) + .collect::>>()?; + + Ok(ast::Expr::GroupingSets(expr_ast_sets)) + } + GroupingSet::Cube(cube) => { + let expr_ast_sets = cube + .iter() + .map(|e| { + let sql = self.expr_to_sql_inner(e)?; + Ok(vec![sql]) + }) + .collect::>>()?; + Ok(ast::Expr::Cube(expr_ast_sets)) + } + GroupingSet::Rollup(rollup) => { + let expr_ast_sets: Vec> = rollup + .iter() + .map(|e| { + let sql = self.expr_to_sql_inner(e)?; + Ok(vec![sql]) + }) + .collect::>>()?; + Ok(ast::Expr::Rollup(expr_ast_sets)) + } + }, + Expr::Placeholder(p) => { + Ok(ast::Expr::Value(ast::Value::Placeholder(p.id.to_string()))) + } + Expr::OuterReferenceColumn(_, col) => self.col_to_sql(col), + Expr::Unnest(unnest) => self.unnest_to_sql(unnest), + } + } + + pub fn scalar_function_to_sql( + &self, + func_name: &str, + args: &[Expr], + ) -> Result { + let args = self.function_args_to_sql(args)?; + Ok(ast::Expr::Function(Function { + name: ObjectName(vec![Ident { + value: func_name.to_string(), + quote_style: None, + }]), + args: ast::FunctionArguments::List(ast::FunctionArgumentList { + duplicate_treatment: None, + args, + clauses: vec![], + }), + filter: None, + null_treatment: None, + over: None, + within_group: vec![], + parameters: ast::FunctionArguments::None, + })) + } + + pub fn sort_to_sql(&self, sort: &Sort) -> Result { + let Sort { + expr, + asc, + nulls_first, + } = sort; + let sql_parser_expr = self.expr_to_sql(expr)?; + + let nulls_first = if self.dialect.supports_nulls_first_in_sort() { + Some(*nulls_first) + } else { + None + }; + + Ok(ast::OrderByExpr { + expr: sql_parser_expr, + asc: Some(*asc), + nulls_first, + with_fill: None, + }) + } + + fn ast_type_for_date64_in_cast(&self) -> ast::DataType { + if self.dialect.use_timestamp_for_date64() { + ast::DataType::Timestamp(None, TimezoneInfo::None) + } else { + ast::DataType::Datetime(None) } } @@ -349,16 +530,82 @@ impl Unparser<'_> { let mut id = table_ref.to_vec(); id.push(col.name.to_string()); return Ok(ast::Expr::CompoundIdentifier( - id.iter().map(|i| self.new_ident(i.to_string())).collect(), + id.iter() + .map(|i| self.new_ident_quoted_if_needs(i.to_string())) + .collect(), )); } - Ok(ast::Expr::Identifier(self.new_ident(col.name.to_string()))) + Ok(ast::Expr::Identifier( + self.new_ident_quoted_if_needs(col.name.to_string()), + )) + } + + fn convert_bound( + &self, + bound: &datafusion_expr::window_frame::WindowFrameBound, + ) -> Result { + match bound { + datafusion_expr::window_frame::WindowFrameBound::Preceding(val) => { + Ok(ast::WindowFrameBound::Preceding({ + let val = self.scalar_to_sql(val)?; + if let ast::Expr::Value(ast::Value::Null) = &val { + None + } else { + Some(Box::new(val)) + } + })) + } + datafusion_expr::window_frame::WindowFrameBound::Following(val) => { + Ok(ast::WindowFrameBound::Following({ + let val = self.scalar_to_sql(val)?; + if let ast::Expr::Value(ast::Value::Null) = &val { + None + } else { + Some(Box::new(val)) + } + })) + } + datafusion_expr::window_frame::WindowFrameBound::CurrentRow => { + Ok(ast::WindowFrameBound::CurrentRow) + } + } + } + + pub(crate) fn function_args_to_sql( + &self, + args: &[Expr], + ) -> Result> { + args.iter() + .map(|e| { + if matches!( + e, + Expr::Wildcard { + qualifier: None, + .. + } + ) { + Ok(ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard)) + } else { + self.expr_to_sql(e) + .map(|e| ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e))) + } + }) + .collect::>>() + } + + /// This function can create an identifier with or without quotes based on the dialect rules + pub(super) fn new_ident_quoted_if_needs(&self, ident: String) -> Ident { + let quote_style = self.dialect.identifier_quote_style(&ident); + Ident { + value: ident, + quote_style, + } } - pub(super) fn new_ident(&self, str: String) -> ast::Ident { - ast::Ident { + pub(super) fn new_ident_without_quote_style(&self, str: String) -> Ident { + Ident { value: str, - quote_style: self.dialect.identifier_quote_style(), + quote_style: None, } } @@ -366,7 +613,7 @@ impl Unparser<'_> { &self, lhs: ast::Expr, rhs: ast::Expr, - op: ast::BinaryOperator, + op: BinaryOperator, ) -> ast::Expr { ast::Expr::BinaryOp { left: Box::new(lhs), @@ -375,6 +622,88 @@ impl Unparser<'_> { } } + /// Given an expression of the form `((a + b) * (c * d))`, + /// the parenthesing is redundant if the precedence of the nested expression is already higher + /// than the surrounding operators' precedence. The above expression would become + /// `(a + b) * c * d`. + /// + /// Also note that when fetching the precedence of a nested expression, we ignore other nested + /// expressions, so precedence of expr `(a * (b + c))` equals `*` and not `+`. + fn remove_unnecessary_nesting( + &self, + expr: ast::Expr, + left_op: &BinaryOperator, + right_op: &BinaryOperator, + ) -> ast::Expr { + match expr { + ast::Expr::Nested(nested) => { + let surrounding_precedence = self + .sql_op_precedence(left_op) + .max(self.sql_op_precedence(right_op)); + + let inner_precedence = self.inner_precedence(&nested); + + let not_associative = + matches!(left_op, BinaryOperator::Minus | BinaryOperator::Divide); + + if inner_precedence == surrounding_precedence && not_associative { + ast::Expr::Nested(Box::new( + self.remove_unnecessary_nesting(*nested, LOWEST, LOWEST), + )) + } else if inner_precedence >= surrounding_precedence { + self.remove_unnecessary_nesting(*nested, left_op, right_op) + } else { + ast::Expr::Nested(Box::new( + self.remove_unnecessary_nesting(*nested, LOWEST, LOWEST), + )) + } + } + ast::Expr::BinaryOp { left, op, right } => ast::Expr::BinaryOp { + left: Box::new(self.remove_unnecessary_nesting(*left, left_op, &op)), + right: Box::new(self.remove_unnecessary_nesting(*right, &op, right_op)), + op, + }, + ast::Expr::IsTrue(expr) => ast::Expr::IsTrue(Box::new( + self.remove_unnecessary_nesting(*expr, left_op, IS), + )), + ast::Expr::IsNotTrue(expr) => ast::Expr::IsNotTrue(Box::new( + self.remove_unnecessary_nesting(*expr, left_op, IS), + )), + ast::Expr::IsFalse(expr) => ast::Expr::IsFalse(Box::new( + self.remove_unnecessary_nesting(*expr, left_op, IS), + )), + ast::Expr::IsNotFalse(expr) => ast::Expr::IsNotFalse(Box::new( + self.remove_unnecessary_nesting(*expr, left_op, IS), + )), + ast::Expr::IsNull(expr) => ast::Expr::IsNull(Box::new( + self.remove_unnecessary_nesting(*expr, left_op, IS), + )), + ast::Expr::IsNotNull(expr) => ast::Expr::IsNotNull(Box::new( + self.remove_unnecessary_nesting(*expr, left_op, IS), + )), + ast::Expr::IsUnknown(expr) => ast::Expr::IsUnknown(Box::new( + self.remove_unnecessary_nesting(*expr, left_op, IS), + )), + ast::Expr::IsNotUnknown(expr) => ast::Expr::IsNotUnknown(Box::new( + self.remove_unnecessary_nesting(*expr, left_op, IS), + )), + _ => expr, + } + } + + fn inner_precedence(&self, expr: &ast::Expr) -> u8 { + match expr { + ast::Expr::Nested(_) | ast::Expr::Identifier(_) | ast::Expr::Value(_) => 100, + ast::Expr::BinaryOp { op, .. } => self.sql_op_precedence(op), + // Closest precedence we currently have to Between is PGLikeMatch + // (https://www.postgresql.org/docs/7.2/sql-precedence.html) + ast::Expr::Between { .. } => { + self.sql_op_precedence(&BinaryOperator::PGLikeMatch) + } + _ => 0, + } + } + pub(super) fn between_op_to_sql( &self, expr: ast::Expr, @@ -390,42 +719,147 @@ impl Unparser<'_> { } } - fn op_to_sql(&self, op: &Operator) -> Result { + fn sql_op_precedence(&self, op: &BinaryOperator) -> u8 { + match self.sql_to_op(op) { + Ok(op) => op.precedence(), + Err(_) => 0, + } + } + + fn sql_to_op(&self, op: &BinaryOperator) -> Result { match op { - Operator::Eq => Ok(ast::BinaryOperator::Eq), - Operator::NotEq => Ok(ast::BinaryOperator::NotEq), - Operator::Lt => Ok(ast::BinaryOperator::Lt), - Operator::LtEq => Ok(ast::BinaryOperator::LtEq), - Operator::Gt => Ok(ast::BinaryOperator::Gt), - Operator::GtEq => Ok(ast::BinaryOperator::GtEq), - Operator::Plus => Ok(ast::BinaryOperator::Plus), - Operator::Minus => Ok(ast::BinaryOperator::Minus), - Operator::Multiply => Ok(ast::BinaryOperator::Multiply), - Operator::Divide => Ok(ast::BinaryOperator::Divide), - Operator::Modulo => Ok(ast::BinaryOperator::Modulo), - Operator::And => Ok(ast::BinaryOperator::And), - Operator::Or => Ok(ast::BinaryOperator::Or), + BinaryOperator::Eq => Ok(Operator::Eq), + BinaryOperator::NotEq => Ok(Operator::NotEq), + BinaryOperator::Lt => Ok(Operator::Lt), + BinaryOperator::LtEq => Ok(Operator::LtEq), + BinaryOperator::Gt => Ok(Operator::Gt), + BinaryOperator::GtEq => Ok(Operator::GtEq), + BinaryOperator::Plus => Ok(Operator::Plus), + BinaryOperator::Minus => Ok(Operator::Minus), + BinaryOperator::Multiply => Ok(Operator::Multiply), + BinaryOperator::Divide => Ok(Operator::Divide), + BinaryOperator::Modulo => Ok(Operator::Modulo), + BinaryOperator::And => Ok(Operator::And), + BinaryOperator::Or => Ok(Operator::Or), + BinaryOperator::PGRegexMatch => Ok(Operator::RegexMatch), + BinaryOperator::PGRegexIMatch => Ok(Operator::RegexIMatch), + BinaryOperator::PGRegexNotMatch => Ok(Operator::RegexNotMatch), + BinaryOperator::PGRegexNotIMatch => Ok(Operator::RegexNotIMatch), + BinaryOperator::PGILikeMatch => Ok(Operator::ILikeMatch), + BinaryOperator::PGNotLikeMatch => Ok(Operator::NotLikeMatch), + BinaryOperator::PGLikeMatch => Ok(Operator::LikeMatch), + BinaryOperator::PGNotILikeMatch => Ok(Operator::NotILikeMatch), + BinaryOperator::BitwiseAnd => Ok(Operator::BitwiseAnd), + BinaryOperator::BitwiseOr => Ok(Operator::BitwiseOr), + BinaryOperator::BitwiseXor => Ok(Operator::BitwiseXor), + BinaryOperator::PGBitwiseShiftRight => Ok(Operator::BitwiseShiftRight), + BinaryOperator::PGBitwiseShiftLeft => Ok(Operator::BitwiseShiftLeft), + BinaryOperator::StringConcat => Ok(Operator::StringConcat), + BinaryOperator::AtArrow => Ok(Operator::AtArrow), + BinaryOperator::ArrowAt => Ok(Operator::ArrowAt), + _ => not_impl_err!("unsupported operation: {op:?}"), + } + } + + fn op_to_sql(&self, op: &Operator) -> Result { + match op { + Operator::Eq => Ok(BinaryOperator::Eq), + Operator::NotEq => Ok(BinaryOperator::NotEq), + Operator::Lt => Ok(BinaryOperator::Lt), + Operator::LtEq => Ok(BinaryOperator::LtEq), + Operator::Gt => Ok(BinaryOperator::Gt), + Operator::GtEq => Ok(BinaryOperator::GtEq), + Operator::Plus => Ok(BinaryOperator::Plus), + Operator::Minus => Ok(BinaryOperator::Minus), + Operator::Multiply => Ok(BinaryOperator::Multiply), + Operator::Divide => Ok(BinaryOperator::Divide), + Operator::Modulo => Ok(BinaryOperator::Modulo), + Operator::And => Ok(BinaryOperator::And), + Operator::Or => Ok(BinaryOperator::Or), Operator::IsDistinctFrom => not_impl_err!("unsupported operation: {op:?}"), Operator::IsNotDistinctFrom => not_impl_err!("unsupported operation: {op:?}"), - Operator::RegexMatch => Ok(ast::BinaryOperator::PGRegexMatch), - Operator::RegexIMatch => Ok(ast::BinaryOperator::PGRegexIMatch), - Operator::RegexNotMatch => Ok(ast::BinaryOperator::PGRegexNotMatch), - Operator::RegexNotIMatch => Ok(ast::BinaryOperator::PGRegexNotIMatch), - Operator::ILikeMatch => Ok(ast::BinaryOperator::PGILikeMatch), - Operator::NotLikeMatch => Ok(ast::BinaryOperator::PGNotLikeMatch), - Operator::LikeMatch => Ok(ast::BinaryOperator::PGLikeMatch), - Operator::NotILikeMatch => Ok(ast::BinaryOperator::PGNotILikeMatch), - Operator::BitwiseAnd => Ok(ast::BinaryOperator::BitwiseAnd), - Operator::BitwiseOr => Ok(ast::BinaryOperator::BitwiseOr), - Operator::BitwiseXor => Ok(ast::BinaryOperator::BitwiseXor), - Operator::BitwiseShiftRight => Ok(ast::BinaryOperator::PGBitwiseShiftRight), - Operator::BitwiseShiftLeft => Ok(ast::BinaryOperator::PGBitwiseShiftLeft), - Operator::StringConcat => Ok(ast::BinaryOperator::StringConcat), + Operator::RegexMatch => Ok(BinaryOperator::PGRegexMatch), + Operator::RegexIMatch => Ok(BinaryOperator::PGRegexIMatch), + Operator::RegexNotMatch => Ok(BinaryOperator::PGRegexNotMatch), + Operator::RegexNotIMatch => Ok(BinaryOperator::PGRegexNotIMatch), + Operator::ILikeMatch => Ok(BinaryOperator::PGILikeMatch), + Operator::NotLikeMatch => Ok(BinaryOperator::PGNotLikeMatch), + Operator::LikeMatch => Ok(BinaryOperator::PGLikeMatch), + Operator::NotILikeMatch => Ok(BinaryOperator::PGNotILikeMatch), + Operator::BitwiseAnd => Ok(BinaryOperator::BitwiseAnd), + Operator::BitwiseOr => Ok(BinaryOperator::BitwiseOr), + Operator::BitwiseXor => Ok(BinaryOperator::BitwiseXor), + Operator::BitwiseShiftRight => Ok(BinaryOperator::PGBitwiseShiftRight), + Operator::BitwiseShiftLeft => Ok(BinaryOperator::PGBitwiseShiftLeft), + Operator::StringConcat => Ok(BinaryOperator::StringConcat), Operator::AtArrow => not_impl_err!("unsupported operation: {op:?}"), Operator::ArrowAt => not_impl_err!("unsupported operation: {op:?}"), } } + fn handle_timestamp( + &self, + v: &ScalarValue, + tz: &Option>, + ) -> Result + where + i64: From, + { + let ts = if let Some(tz) = tz { + v.to_array()? + .as_any() + .downcast_ref::>() + .ok_or(internal_datafusion_err!( + "Failed to downcast type {v:?} to arrow array" + ))? + .value_as_datetime_with_tz(0, tz.parse()?) + .ok_or(internal_datafusion_err!( + "Unable to convert {v:?} to DateTime" + ))? + .to_string() + } else { + v.to_array()? + .as_any() + .downcast_ref::>() + .ok_or(internal_datafusion_err!( + "Failed to downcast type {v:?} to arrow array" + ))? + .value_as_datetime(0) + .ok_or(internal_datafusion_err!( + "Unable to convert {v:?} to DateTime" + ))? + .to_string() + }; + Ok(ast::Expr::Cast { + kind: ast::CastKind::Cast, + expr: Box::new(ast::Expr::Value(SingleQuotedString(ts))), + data_type: ast::DataType::Timestamp(None, TimezoneInfo::None), + format: None, + }) + } + + fn handle_time(&self, v: &ScalarValue) -> Result + where + i64: From, + { + let time = v + .to_array()? + .as_any() + .downcast_ref::>() + .ok_or(internal_datafusion_err!( + "Failed to downcast type {v:?} to arrow array" + ))? + .value_as_time(0) + .ok_or(internal_datafusion_err!("Unable to convert {v:?} to Time"))? + .to_string(); + Ok(ast::Expr::Cast { + kind: ast::CastKind::Cast, + expr: Box::new(ast::Expr::Value(SingleQuotedString(time))), + data_type: ast::DataType::Time(None, TimezoneInfo::None), + format: None, + }) + } + /// DataFusion ScalarValues sometimes require a ast::Expr to construct. /// For example ScalarValue::Date32(d) corresponds to the ast::Expr CAST('datestr' as DATE) fn scalar_to_sql(&self, v: &ScalarValue) -> Result { @@ -435,20 +869,38 @@ impl Unparser<'_> { Ok(ast::Expr::Value(ast::Value::Boolean(b.to_owned()))) } ScalarValue::Boolean(None) => Ok(ast::Expr::Value(ast::Value::Null)), - ScalarValue::Float32(Some(f)) => { + ScalarValue::Float16(Some(f)) => { Ok(ast::Expr::Value(ast::Value::Number(f.to_string(), false))) } + ScalarValue::Float16(None) => Ok(ast::Expr::Value(ast::Value::Null)), + ScalarValue::Float32(Some(f)) => { + let f_val = match f.fract() { + 0.0 => format!("{:.1}", f), + _ => format!("{}", f), + }; + Ok(ast::Expr::Value(ast::Value::Number(f_val, false))) + } ScalarValue::Float32(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::Float64(Some(f)) => { - Ok(ast::Expr::Value(ast::Value::Number(f.to_string(), false))) + let f_val = match f.fract() { + 0.0 => format!("{:.1}", f), + _ => format!("{}", f), + }; + Ok(ast::Expr::Value(ast::Value::Number(f_val, false))) } ScalarValue::Float64(None) => Ok(ast::Expr::Value(ast::Value::Null)), - ScalarValue::Decimal128(Some(_), ..) => { - not_impl_err!("Unsupported scalar: {v:?}") + ScalarValue::Decimal128(Some(value), precision, scale) => { + Ok(ast::Expr::Value(ast::Value::Number( + Decimal128Type::format_decimal(*value, *precision, *scale), + false, + ))) } ScalarValue::Decimal128(None, ..) => Ok(ast::Expr::Value(ast::Value::Null)), - ScalarValue::Decimal256(Some(_), ..) => { - not_impl_err!("Unsupported scalar: {v:?}") + ScalarValue::Decimal256(Some(value), precision, scale) => { + Ok(ast::Expr::Value(ast::Value::Number( + Decimal256Type::format_decimal(*value, *precision, *scale), + false, + ))) } ScalarValue::Decimal256(None, ..) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::Int8(Some(i)) => { @@ -483,16 +935,24 @@ impl Unparser<'_> { Ok(ast::Expr::Value(ast::Value::Number(ui.to_string(), false))) } ScalarValue::UInt64(None) => Ok(ast::Expr::Value(ast::Value::Null)), - ScalarValue::Utf8(Some(str)) => Ok(ast::Expr::Value( - ast::Value::SingleQuotedString(str.to_string()), - )), + ScalarValue::Utf8(Some(str)) => { + Ok(ast::Expr::Value(SingleQuotedString(str.to_string()))) + } ScalarValue::Utf8(None) => Ok(ast::Expr::Value(ast::Value::Null)), - ScalarValue::LargeUtf8(Some(str)) => Ok(ast::Expr::Value( - ast::Value::SingleQuotedString(str.to_string()), - )), + ScalarValue::Utf8View(Some(str)) => { + Ok(ast::Expr::Value(SingleQuotedString(str.to_string()))) + } + ScalarValue::Utf8View(None) => Ok(ast::Expr::Value(ast::Value::Null)), + ScalarValue::LargeUtf8(Some(str)) => { + Ok(ast::Expr::Value(SingleQuotedString(str.to_string()))) + } ScalarValue::LargeUtf8(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::Binary(Some(_)) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::Binary(None) => Ok(ast::Expr::Value(ast::Value::Null)), + ScalarValue::BinaryView(Some(_)) => { + not_impl_err!("Unsupported scalar: {v:?}") + } + ScalarValue::BinaryView(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::FixedSizeBinary(..) => { not_impl_err!("Unsupported scalar: {v:?}") } @@ -517,7 +977,8 @@ impl Unparser<'_> { ))?; Ok(ast::Expr::Cast { - expr: Box::new(ast::Expr::Value(ast::Value::SingleQuotedString( + kind: ast::CastKind::Cast, + expr: Box::new(ast::Expr::Value(SingleQuotedString( date.to_string(), ))), data_type: ast::DataType::Date, @@ -539,71 +1000,68 @@ impl Unparser<'_> { ))?; Ok(ast::Expr::Cast { - expr: Box::new(ast::Expr::Value(ast::Value::SingleQuotedString( + kind: ast::CastKind::Cast, + expr: Box::new(ast::Expr::Value(SingleQuotedString( datetime.to_string(), ))), - data_type: ast::DataType::Datetime(None), + data_type: self.ast_type_for_date64_in_cast(), format: None, }) } ScalarValue::Date64(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::Time32Second(Some(_t)) => { - not_impl_err!("Unsupported scalar: {v:?}") + self.handle_time::(v) } ScalarValue::Time32Second(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::Time32Millisecond(Some(_t)) => { - not_impl_err!("Unsupported scalar: {v:?}") + self.handle_time::(v) } ScalarValue::Time32Millisecond(None) => { Ok(ast::Expr::Value(ast::Value::Null)) } ScalarValue::Time64Microsecond(Some(_t)) => { - not_impl_err!("Unsupported scalar: {v:?}") + self.handle_time::(v) } ScalarValue::Time64Microsecond(None) => { Ok(ast::Expr::Value(ast::Value::Null)) } ScalarValue::Time64Nanosecond(Some(_t)) => { - not_impl_err!("Unsupported scalar: {v:?}") + self.handle_time::(v) } ScalarValue::Time64Nanosecond(None) => Ok(ast::Expr::Value(ast::Value::Null)), - ScalarValue::TimestampSecond(Some(_ts), _) => { - not_impl_err!("Unsupported scalar: {v:?}") + ScalarValue::TimestampSecond(Some(_ts), tz) => { + self.handle_timestamp::(v, tz) } ScalarValue::TimestampSecond(None, _) => { Ok(ast::Expr::Value(ast::Value::Null)) } - ScalarValue::TimestampMillisecond(Some(_ts), _) => { - not_impl_err!("Unsupported scalar: {v:?}") + ScalarValue::TimestampMillisecond(Some(_ts), tz) => { + self.handle_timestamp::(v, tz) } ScalarValue::TimestampMillisecond(None, _) => { Ok(ast::Expr::Value(ast::Value::Null)) } - ScalarValue::TimestampMicrosecond(Some(_ts), _) => { - not_impl_err!("Unsupported scalar: {v:?}") + ScalarValue::TimestampMicrosecond(Some(_ts), tz) => { + self.handle_timestamp::(v, tz) } ScalarValue::TimestampMicrosecond(None, _) => { Ok(ast::Expr::Value(ast::Value::Null)) } - ScalarValue::TimestampNanosecond(Some(_ts), _) => { - not_impl_err!("Unsupported scalar: {v:?}") + ScalarValue::TimestampNanosecond(Some(_ts), tz) => { + self.handle_timestamp::(v, tz) } ScalarValue::TimestampNanosecond(None, _) => { Ok(ast::Expr::Value(ast::Value::Null)) } - ScalarValue::IntervalYearMonth(Some(_i)) => { - not_impl_err!("Unsupported scalar: {v:?}") + ScalarValue::IntervalYearMonth(Some(_)) + | ScalarValue::IntervalDayTime(Some(_)) + | ScalarValue::IntervalMonthDayNano(Some(_)) => { + self.interval_scalar_to_sql(v) } ScalarValue::IntervalYearMonth(None) => { Ok(ast::Expr::Value(ast::Value::Null)) } - ScalarValue::IntervalDayTime(Some(_i)) => { - not_impl_err!("Unsupported scalar: {v:?}") - } ScalarValue::IntervalDayTime(None) => Ok(ast::Expr::Value(ast::Value::Null)), - ScalarValue::IntervalMonthDayNano(Some(_i)) => { - not_impl_err!("Unsupported scalar: {v:?}") - } ScalarValue::IntervalMonthDayNano(None) => { Ok(ast::Expr::Value(ast::Value::Null)) } @@ -630,11 +1088,278 @@ impl Unparser<'_> { Ok(ast::Expr::Value(ast::Value::Null)) } ScalarValue::Struct(_) => not_impl_err!("Unsupported scalar: {v:?}"), + ScalarValue::Map(_) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::Union(..) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::Dictionary(..) => not_impl_err!("Unsupported scalar: {v:?}"), } } + /// MySQL requires INTERVAL sql to be in the format: INTERVAL 1 YEAR + INTERVAL 1 MONTH + INTERVAL 1 DAY etc + /// `` + /// Interval sequence can't be wrapped in brackets - (INTERVAL 1 YEAR + INTERVAL 1 MONTH ...) so we need to generate + /// a single INTERVAL expression so it works correct for interval substraction cases + /// MySQL supports the DAY_MICROSECOND unit type (format is DAYS HOURS:MINUTES:SECONDS.MICROSECONDS), but it is not supported by sqlparser + /// so we calculate the best single interval to represent the provided duration + fn interval_to_mysql_expr( + &self, + months: i32, + days: i32, + microseconds: i64, + ) -> Result { + // MONTH only + if months != 0 && days == 0 && microseconds == 0 { + let interval = Interval { + value: Box::new(ast::Expr::Value(ast::Value::Number( + months.to_string(), + false, + ))), + leading_field: Some(ast::DateTimeField::Month), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }; + return Ok(ast::Expr::Interval(interval)); + } else if months != 0 { + return not_impl_err!("Unsupported Interval scalar with both Month and DayTime for IntervalStyle::MySQL"); + } + + // DAY only + if microseconds == 0 { + let interval = Interval { + value: Box::new(ast::Expr::Value(ast::Value::Number( + days.to_string(), + false, + ))), + leading_field: Some(ast::DateTimeField::Day), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }; + return Ok(ast::Expr::Interval(interval)); + } + + // Calculate the best single interval to represent the provided days and microseconds + + let microseconds = microseconds + (days as i64 * 24 * 60 * 60 * 1_000_000); + + if microseconds % 1_000_000 != 0 { + let interval = Interval { + value: Box::new(ast::Expr::Value(ast::Value::Number( + microseconds.to_string(), + false, + ))), + leading_field: Some(ast::DateTimeField::Microsecond), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }; + return Ok(ast::Expr::Interval(interval)); + } + + let secs = microseconds / 1_000_000; + + if secs % 60 != 0 { + let interval = Interval { + value: Box::new(ast::Expr::Value(ast::Value::Number( + secs.to_string(), + false, + ))), + leading_field: Some(ast::DateTimeField::Second), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }; + return Ok(ast::Expr::Interval(interval)); + } + + let mins = secs / 60; + + if mins % 60 != 0 { + let interval = Interval { + value: Box::new(ast::Expr::Value(ast::Value::Number( + mins.to_string(), + false, + ))), + leading_field: Some(ast::DateTimeField::Minute), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }; + return Ok(ast::Expr::Interval(interval)); + } + + let hours = mins / 60; + + if hours % 24 != 0 { + let interval = Interval { + value: Box::new(ast::Expr::Value(ast::Value::Number( + hours.to_string(), + false, + ))), + leading_field: Some(ast::DateTimeField::Hour), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }; + return Ok(ast::Expr::Interval(interval)); + } + + let days = hours / 24; + + let interval = Interval { + value: Box::new(ast::Expr::Value(ast::Value::Number( + days.to_string(), + false, + ))), + leading_field: Some(ast::DateTimeField::Day), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }; + Ok(ast::Expr::Interval(interval)) + } + + fn interval_scalar_to_sql(&self, v: &ScalarValue) -> Result { + match self.dialect.interval_style() { + IntervalStyle::PostgresVerbose => { + let wrap_array = v.to_array()?; + let Some(result) = array_value_to_string(&wrap_array, 0).ok() else { + return internal_err!( + "Unable to convert interval scalar value to string" + ); + }; + let interval = Interval { + value: Box::new(ast::Expr::Value(SingleQuotedString( + result.to_uppercase(), + ))), + leading_field: None, + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }; + Ok(ast::Expr::Interval(interval)) + } + // If the interval standard is SQLStandard, implement a simple unparse logic + IntervalStyle::SQLStandard => match v { + ScalarValue::IntervalYearMonth(Some(v)) => { + let interval = Interval { + value: Box::new(ast::Expr::Value(SingleQuotedString( + v.to_string(), + ))), + leading_field: Some(ast::DateTimeField::Month), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }; + Ok(ast::Expr::Interval(interval)) + } + ScalarValue::IntervalDayTime(Some(v)) => { + let days = v.days; + let secs = v.milliseconds / 1_000; + let mins = secs / 60; + let hours = mins / 60; + + let secs = secs - (mins * 60); + let mins = mins - (hours * 60); + + let millis = v.milliseconds % 1_000; + let interval = Interval { + value: Box::new(ast::Expr::Value(SingleQuotedString(format!( + "{days} {hours}:{mins}:{secs}.{millis:3}" + )))), + leading_field: Some(ast::DateTimeField::Day), + leading_precision: None, + last_field: Some(ast::DateTimeField::Second), + fractional_seconds_precision: None, + }; + Ok(ast::Expr::Interval(interval)) + } + ScalarValue::IntervalMonthDayNano(Some(v)) => { + if v.months >= 0 && v.days == 0 && v.nanoseconds == 0 { + let interval = Interval { + value: Box::new(ast::Expr::Value(SingleQuotedString( + v.months.to_string(), + ))), + leading_field: Some(ast::DateTimeField::Month), + leading_precision: None, + last_field: None, + fractional_seconds_precision: None, + }; + Ok(ast::Expr::Interval(interval)) + } else if v.months == 0 && v.nanoseconds % 1_000_000 == 0 { + let days = v.days; + let secs = v.nanoseconds / 1_000_000_000; + let mins = secs / 60; + let hours = mins / 60; + + let secs = secs - (mins * 60); + let mins = mins - (hours * 60); + + let millis = (v.nanoseconds % 1_000_000_000) / 1_000_000; + + let interval = Interval { + value: Box::new(ast::Expr::Value(SingleQuotedString( + format!("{days} {hours}:{mins}:{secs}.{millis:03}"), + ))), + leading_field: Some(ast::DateTimeField::Day), + leading_precision: None, + last_field: Some(ast::DateTimeField::Second), + fractional_seconds_precision: None, + }; + Ok(ast::Expr::Interval(interval)) + } else { + not_impl_err!("Unsupported IntervalMonthDayNano scalar with both Month and DayTime for IntervalStyle::SQLStandard") + } + } + _ => not_impl_err!( + "Unsupported ScalarValue for Interval conversion: {v:?}" + ), + }, + IntervalStyle::MySQL => match v { + ScalarValue::IntervalYearMonth(Some(v)) => { + self.interval_to_mysql_expr(*v, 0, 0) + } + ScalarValue::IntervalDayTime(Some(v)) => { + self.interval_to_mysql_expr(0, v.days, v.milliseconds as i64 * 1_000) + } + ScalarValue::IntervalMonthDayNano(Some(v)) => { + if v.nanoseconds % 1_000 != 0 { + return not_impl_err!( + "Unsupported IntervalMonthDayNano scalar with nanoseconds precision for IntervalStyle::MySQL" + ); + } + self.interval_to_mysql_expr(v.months, v.days, v.nanoseconds / 1_000) + } + _ => not_impl_err!( + "Unsupported ScalarValue for Interval conversion: {v:?}" + ), + }, + } + } + + /// Converts an UNNEST operation to an AST expression by wrapping it as a function call, + /// since there is no direct representation for UNNEST in the AST. + fn unnest_to_sql(&self, unnest: &Unnest) -> Result { + let args = self.function_args_to_sql(std::slice::from_ref(&unnest.expr))?; + + Ok(ast::Expr::Function(Function { + name: ObjectName(vec![Ident { + value: "UNNEST".to_string(), + quote_style: None, + }]), + args: ast::FunctionArguments::List(ast::FunctionArgumentList { + duplicate_treatment: None, + args, + clauses: vec![], + }), + filter: None, + null_treatment: None, + over: None, + within_group: vec![], + parameters: ast::FunctionArguments::None, + })) + } + fn arrow_dtype_to_ast_dtype(&self, data_type: &DataType) -> Result { match data_type { DataType::Null => { @@ -643,8 +1368,8 @@ impl Unparser<'_> { DataType::Boolean => Ok(ast::DataType::Bool), DataType::Int8 => Ok(ast::DataType::TinyInt(None)), DataType::Int16 => Ok(ast::DataType::SmallInt(None)), - DataType::Int32 => Ok(ast::DataType::Integer(None)), - DataType::Int64 => Ok(ast::DataType::BigInt(None)), + DataType::Int32 => Ok(self.dialect.int32_cast_dtype()), + DataType::Int64 => Ok(self.dialect.int64_cast_dtype()), DataType::UInt8 => Ok(ast::DataType::UnsignedTinyInt(None)), DataType::UInt16 => Ok(ast::DataType::UnsignedSmallInt(None)), DataType::UInt32 => Ok(ast::DataType::UnsignedInteger(None)), @@ -653,12 +1378,12 @@ impl Unparser<'_> { not_impl_err!("Unsupported DataType: conversion: {data_type:?}") } DataType::Float32 => Ok(ast::DataType::Float(None)), - DataType::Float64 => Ok(ast::DataType::Double), - DataType::Timestamp(_, _) => { - not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + DataType::Float64 => Ok(self.dialect.float64_ast_dtype()), + DataType::Timestamp(time_unit, tz) => { + Ok(self.dialect.timestamp_cast_dtype(time_unit, tz)) } - DataType::Date32 => Ok(ast::DataType::Date), - DataType::Date64 => Ok(ast::DataType::Datetime(None)), + DataType::Date32 => Ok(self.dialect.date32_cast_dtype()), + DataType::Date64 => Ok(self.ast_type_for_date64_in_cast()), DataType::Time32(_) => { not_impl_err!("Unsupported DataType: conversion: {data_type:?}") } @@ -683,8 +1408,8 @@ impl Unparser<'_> { DataType::BinaryView => { not_impl_err!("Unsupported DataType: conversion: {data_type:?}") } - DataType::Utf8 => Ok(ast::DataType::Varchar(None)), - DataType::LargeUtf8 => Ok(ast::DataType::Text), + DataType::Utf8 => Ok(self.dialect.utf8_cast_dtype()), + DataType::LargeUtf8 => Ok(self.dialect.large_utf8_cast_dtype()), DataType::Utf8View => { not_impl_err!("Unsupported DataType: conversion: {data_type:?}") } @@ -712,11 +1437,18 @@ impl Unparser<'_> { DataType::Dictionary(_, _) => { not_impl_err!("Unsupported DataType: conversion: {data_type:?}") } - DataType::Decimal128(_, _) => { - not_impl_err!("Unsupported DataType: conversion: {data_type:?}") - } - DataType::Decimal256(_, _) => { - not_impl_err!("Unsupported DataType: conversion: {data_type:?}") + DataType::Decimal128(precision, scale) + | DataType::Decimal256(precision, scale) => { + let mut new_precision = *precision as u64; + let mut new_scale = *scale as u64; + if *scale < 0 { + new_precision = (*precision as i16 - *scale as i16) as u64; + new_scale = 0 + } + + Ok(ast::DataType::Decimal( + ast::ExactNumberInfo::PrecisionAndScale(new_precision, new_scale), + )) } DataType::Map(_, _) => { not_impl_err!("Unsupported DataType: conversion: {data_type:?}") @@ -730,16 +1462,30 @@ impl Unparser<'_> { #[cfg(test)] mod tests { + use std::ops::{Add, Sub}; use std::{any::Any, sync::Arc, vec}; + use arrow::datatypes::TimeUnit; use arrow::datatypes::{Field, Schema}; + use arrow_schema::DataType::Int8; + use ast::ObjectName; use datafusion_common::TableReference; + use datafusion_expr::expr::WildcardOptions; use datafusion_expr::{ - case, col, exists, expr::AggregateFunction, lit, not, not_exists, table_scan, - ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, Volatility, + case, col, cube, exists, grouping_set, interval_datetime_lit, + interval_year_month_lit, lit, not, not_exists, out_ref_col, placeholder, rollup, + table_scan, try_cast, when, wildcard, ColumnarValue, ScalarUDF, ScalarUDFImpl, + Signature, Volatility, WindowFrame, WindowFunctionDefinition, }; + use datafusion_expr::{interval_month_day_nano_lit, ExprFunctionExt}; + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_functions_aggregate::expr_fn::sum; + use datafusion_functions_window::row_number::row_number_udwf; - use crate::unparser::dialect::CustomDialect; + use crate::unparser::dialect::{ + CustomDialect, CustomDialectBuilder, DateFieldExtractStyle, Dialect, + PostgreSqlDialect, + }; use super::*; @@ -784,52 +1530,92 @@ mod tests { fn expr_to_sql_ok() -> Result<()> { let dummy_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let dummy_logical_plan = table_scan(Some("t"), &dummy_schema, None)? - .project(vec![Expr::Wildcard { qualifier: None }])? + .project(vec![Expr::Wildcard { + qualifier: None, + options: WildcardOptions::default(), + }])? .filter(col("a").eq(lit(1)))? .build()?; let tests: Vec<(Expr, &str)> = vec![ - ((col("a") + col("b")).gt(lit(4)), r#"(("a" + "b") > 4)"#), + ((col("a") + col("b")).gt(lit(4)), r#"((a + b) > 4)"#), ( Expr::Column(Column { relation: Some(TableReference::partial("a", "b")), name: "c".to_string(), }) .gt(lit(4)), - r#"("a"."b"."c" > 4)"#, + r#"(a.b.c > 4)"#, ), ( case(col("a")) .when(lit(1), lit(true)) .when(lit(0), lit(false)) .otherwise(lit(ScalarValue::Null))?, - r#"CASE "a" WHEN 1 THEN true WHEN 0 THEN false ELSE NULL END"#, + r#"CASE a WHEN 1 THEN true WHEN 0 THEN false ELSE NULL END"#, + ), + ( + when(col("a").is_null(), lit(true)).otherwise(lit(false))?, + r#"CASE WHEN a IS NULL THEN true ELSE false END"#, + ), + ( + when(col("a").is_not_null(), lit(true)).otherwise(lit(false))?, + r#"CASE WHEN a IS NOT NULL THEN true ELSE false END"#, ), ( Expr::Cast(Cast { expr: Box::new(col("a")), data_type: DataType::Date64, }), - r#"CAST("a" AS DATETIME)"#, + r#"CAST(a AS DATETIME)"#, + ), + ( + Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type: DataType::Timestamp( + TimeUnit::Nanosecond, + Some("+08:00".into()), + ), + }), + r#"CAST(a AS TIMESTAMP WITH TIME ZONE)"#, + ), + ( + Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type: DataType::Timestamp(TimeUnit::Millisecond, None), + }), + r#"CAST(a AS TIMESTAMP)"#, ), ( Expr::Cast(Cast { expr: Box::new(col("a")), data_type: DataType::UInt32, }), - r#"CAST("a" AS INTEGER UNSIGNED)"#, + r#"CAST(a AS INTEGER UNSIGNED)"#, ), ( col("a").in_list(vec![lit(1), lit(2), lit(3)], false), - r#""a" IN (1, 2, 3)"#, + r#"a IN (1, 2, 3)"#, ), ( col("a").in_list(vec![lit(1), lit(2), lit(3)], true), - r#""a" NOT IN (1, 2, 3)"#, + r#"a NOT IN (1, 2, 3)"#, ), ( ScalarUDF::new_from_impl(DummyUDF::new()).call(vec![col("a"), col("b")]), - r#"dummy_udf("a", "b")"#, + r#"dummy_udf(a, b)"#, + ), + ( + ScalarUDF::new_from_impl(DummyUDF::new()) + .call(vec![col("a"), col("b")]) + .is_null(), + r#"dummy_udf(a, b) IS NULL"#, + ), + ( + ScalarUDF::new_from_impl(DummyUDF::new()) + .call(vec![col("a"), col("b")]) + .is_not_null(), + r#"dummy_udf(a, b) IS NOT NULL"#, ), ( Expr::Like(Like { @@ -839,7 +1625,7 @@ mod tests { escape_char: Some('o'), case_insensitive: true, }), - r#""a" NOT LIKE 'foo' ESCAPE 'o'"#, + r#"a NOT LIKE 'foo' ESCAPE 'o'"#, ), ( Expr::SimilarTo(Like { @@ -849,7 +1635,7 @@ mod tests { escape_char: Some('o'), case_insensitive: true, }), - r#""a" LIKE 'foo' ESCAPE 'o'"#, + r#"a LIKE 'foo' ESCAPE 'o'"#, ), ( Expr::Literal(ScalarValue::Date64(Some(0))), @@ -876,67 +1662,228 @@ mod tests { r#"CAST('1969-12-31' AS DATE)"#, ), ( - Expr::AggregateFunction(AggregateFunction { - func_def: AggregateFunctionDefinition::BuiltIn( - datafusion_expr::AggregateFunction::Sum, - ), - args: vec![col("a")], - distinct: false, - filter: None, - order_by: None, + Expr::Literal(ScalarValue::TimestampSecond(Some(10001), None)), + r#"CAST('1970-01-01 02:46:41' AS TIMESTAMP)"#, + ), + ( + Expr::Literal(ScalarValue::TimestampSecond( + Some(10001), + Some("+08:00".into()), + )), + r#"CAST('1970-01-01 10:46:41 +08:00' AS TIMESTAMP)"#, + ), + ( + Expr::Literal(ScalarValue::TimestampMillisecond(Some(10001), None)), + r#"CAST('1970-01-01 00:00:10.001' AS TIMESTAMP)"#, + ), + ( + Expr::Literal(ScalarValue::TimestampMillisecond( + Some(10001), + Some("+08:00".into()), + )), + r#"CAST('1970-01-01 08:00:10.001 +08:00' AS TIMESTAMP)"#, + ), + ( + Expr::Literal(ScalarValue::TimestampMicrosecond(Some(10001), None)), + r#"CAST('1970-01-01 00:00:00.010001' AS TIMESTAMP)"#, + ), + ( + Expr::Literal(ScalarValue::TimestampMicrosecond( + Some(10001), + Some("+08:00".into()), + )), + r#"CAST('1970-01-01 08:00:00.010001 +08:00' AS TIMESTAMP)"#, + ), + ( + Expr::Literal(ScalarValue::TimestampNanosecond(Some(10001), None)), + r#"CAST('1970-01-01 00:00:00.000010001' AS TIMESTAMP)"#, + ), + ( + Expr::Literal(ScalarValue::TimestampNanosecond( + Some(10001), + Some("+08:00".into()), + )), + r#"CAST('1970-01-01 08:00:00.000010001 +08:00' AS TIMESTAMP)"#, + ), + ( + Expr::Literal(ScalarValue::Time32Second(Some(10001))), + r#"CAST('02:46:41' AS TIME)"#, + ), + ( + Expr::Literal(ScalarValue::Time32Millisecond(Some(10001))), + r#"CAST('00:00:10.001' AS TIME)"#, + ), + ( + Expr::Literal(ScalarValue::Time64Microsecond(Some(10001))), + r#"CAST('00:00:00.010001' AS TIME)"#, + ), + ( + Expr::Literal(ScalarValue::Time64Nanosecond(Some(10001))), + r#"CAST('00:00:00.000010001' AS TIME)"#, + ), + (sum(col("a")), r#"sum(a)"#), + ( + count_udaf() + .call(vec![Expr::Wildcard { + qualifier: None, + options: WildcardOptions::default(), + }]) + .distinct() + .build() + .unwrap(), + "count(DISTINCT *)", + ), + ( + count_udaf() + .call(vec![Expr::Wildcard { + qualifier: None, + options: WildcardOptions::default(), + }]) + .filter(lit(true)) + .build() + .unwrap(), + "count(*) FILTER (WHERE true)", + ), + ( + Expr::WindowFunction(WindowFunction { + fun: WindowFunctionDefinition::WindowUDF(row_number_udwf()), + args: vec![col("col")], + partition_by: vec![], + order_by: vec![], + window_frame: WindowFrame::new(None), null_treatment: None, }), - r#"SUM("a")"#, + r#"row_number(col) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)"#, ), ( - Expr::AggregateFunction(AggregateFunction { - func_def: AggregateFunctionDefinition::BuiltIn( - datafusion_expr::AggregateFunction::Count, + Expr::WindowFunction(WindowFunction { + fun: WindowFunctionDefinition::AggregateUDF(count_udaf()), + args: vec![wildcard()], + partition_by: vec![], + order_by: vec![Sort::new(col("a"), false, true)], + window_frame: WindowFrame::new_bounds( + datafusion_expr::WindowFrameUnits::Range, + datafusion_expr::WindowFrameBound::Preceding( + ScalarValue::UInt32(Some(6)), + ), + datafusion_expr::WindowFrameBound::Following( + ScalarValue::UInt32(Some(2)), + ), ), - args: vec![Expr::Wildcard { qualifier: None }], - distinct: true, - filter: None, - order_by: None, null_treatment: None, }), - "COUNT(DISTINCT *)", + r#"count(*) OVER (ORDER BY a DESC NULLS FIRST RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING)"#, ), - (col("a").is_not_null(), r#""a" IS NOT NULL"#), + (col("a").is_not_null(), r#"a IS NOT NULL"#), + (col("a").is_null(), r#"a IS NULL"#), ( (col("a") + col("b")).gt(lit(4)).is_true(), - r#"(("a" + "b") > 4) IS TRUE"#, + r#"((a + b) > 4) IS TRUE"#, ), ( (col("a") + col("b")).gt(lit(4)).is_not_true(), - r#"(("a" + "b") > 4) IS NOT TRUE"#, + r#"((a + b) > 4) IS NOT TRUE"#, ), ( (col("a") + col("b")).gt(lit(4)).is_false(), - r#"(("a" + "b") > 4) IS FALSE"#, + r#"((a + b) > 4) IS FALSE"#, + ), + ( + (col("a") + col("b")).gt(lit(4)).is_not_false(), + r#"((a + b) > 4) IS NOT FALSE"#, ), ( (col("a") + col("b")).gt(lit(4)).is_unknown(), - r#"(("a" + "b") > 4) IS UNKNOWN"#, + r#"((a + b) > 4) IS UNKNOWN"#, ), ( (col("a") + col("b")).gt(lit(4)).is_not_unknown(), - r#"(("a" + "b") > 4) IS NOT UNKNOWN"#, + r#"((a + b) > 4) IS NOT UNKNOWN"#, ), - (not(col("a")), r#"NOT "a""#), + (not(col("a")), r#"NOT a"#), ( Expr::between(col("a"), lit(1), lit(7)), - r#"("a" BETWEEN 1 AND 7)"#, + r#"(a BETWEEN 1 AND 7)"#, ), - (Expr::Negative(Box::new(col("a"))), r#"-"a""#), + (Expr::Negative(Box::new(col("a"))), r#"-a"#), ( exists(Arc::new(dummy_logical_plan.clone())), - r#"EXISTS (SELECT "t"."a" FROM "t" WHERE ("t"."a" = 1))"#, + r#"EXISTS (SELECT * FROM t WHERE (t.a = 1))"#, + ), + ( + not_exists(Arc::new(dummy_logical_plan)), + r#"NOT EXISTS (SELECT * FROM t WHERE (t.a = 1))"#, + ), + ( + try_cast(col("a"), DataType::Date64), + r#"TRY_CAST(a AS DATETIME)"#, + ), + ( + try_cast(col("a"), DataType::UInt32), + r#"TRY_CAST(a AS INTEGER UNSIGNED)"#, ), ( - not_exists(Arc::new(dummy_logical_plan.clone())), - r#"NOT EXISTS (SELECT "t"."a" FROM "t" WHERE ("t"."a" = 1))"#, + Expr::ScalarVariable(Int8, vec![String::from("@a")]), + r#"@a"#, + ), + ( + Expr::ScalarVariable( + Int8, + vec![String::from("@root"), String::from("foo")], + ), + r#"@root.foo"#, + ), + (col("x").eq(placeholder("$1")), r#"(x = $1)"#), + ( + out_ref_col(DataType::Int32, "t.a").gt(lit(1)), + r#"(t.a > 1)"#, + ), + ( + grouping_set(vec![vec![col("a"), col("b")], vec![col("a")]]), + r#"GROUPING SETS ((a, b), (a))"#, + ), + (cube(vec![col("a"), col("b")]), r#"CUBE (a, b)"#), + (rollup(vec![col("a"), col("b")]), r#"ROLLUP (a, b)"#), + (col("table").eq(lit(1)), r#"("table" = 1)"#), + ( + col("123_need_quoted").eq(lit(1)), + r#"("123_need_quoted" = 1)"#, + ), + (col("need-quoted").eq(lit(1)), r#"("need-quoted" = 1)"#), + (col("need quoted").eq(lit(1)), r#"("need quoted" = 1)"#), + // See test_interval_scalar_to_expr for interval literals + ( + (col("a") + col("b")).gt(Expr::Literal(ScalarValue::Decimal128( + Some(100123), + 28, + 3, + ))), + r#"((a + b) > 100.123)"#, + ), + ( + (col("a") + col("b")).gt(Expr::Literal(ScalarValue::Decimal256( + Some(100123.into()), + 28, + 3, + ))), + r#"((a + b) > 100.123)"#, + ), + ( + Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type: DataType::Decimal128(10, -2), + }), + r#"CAST(a AS DECIMAL(12,0))"#, + ), + ( + Expr::Unnest(Unnest { + expr: Box::new(Expr::Column(Column { + relation: Some(TableReference::partial("schema", "table")), + name: "array_col".to_string(), + })), + }), + r#"UNNEST("schema"."table".array_col)"#, ), - (col("a").sort(true, true), r#""a""#), ]; for (expr, expected) in tests { @@ -951,8 +1898,10 @@ mod tests { } #[test] - fn custom_dialect() -> Result<()> { - let dialect = CustomDialect::new(Some('\'')); + fn custom_dialect_with_identifier_quote_style() -> Result<()> { + let dialect = CustomDialectBuilder::new() + .with_identifier_quote_style('\'') + .build(); let unparser = Unparser::new(&dialect); let expr = col("a").gt(lit(4)); @@ -967,8 +1916,8 @@ mod tests { } #[test] - fn custom_dialect_none() -> Result<()> { - let dialect = CustomDialect::new(None); + fn custom_dialect_without_identifier_quote_style() -> Result<()> { + let dialect = CustomDialect::default(); let unparser = Unparser::new(&dialect); let expr = col("a").gt(lit(4)); @@ -981,4 +1930,499 @@ mod tests { Ok(()) } + + #[test] + fn custom_dialect_use_timestamp_for_date64() -> Result<()> { + for (use_timestamp_for_date64, identifier) in + [(false, "DATETIME"), (true, "TIMESTAMP")] + { + let dialect = CustomDialectBuilder::new() + .with_use_timestamp_for_date64(use_timestamp_for_date64) + .build(); + let unparser = Unparser::new(&dialect); + + let expr = Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type: DataType::Date64, + }); + let ast = unparser.expr_to_sql(&expr)?; + + let actual = format!("{}", ast); + + let expected = format!(r#"CAST(a AS {identifier})"#); + assert_eq!(actual, expected); + } + Ok(()) + } + + #[test] + fn custom_dialect_float64_ast_dtype() -> Result<()> { + for (float64_ast_dtype, identifier) in [ + (ast::DataType::Double, "DOUBLE"), + (ast::DataType::DoublePrecision, "DOUBLE PRECISION"), + ] { + let dialect = CustomDialectBuilder::new() + .with_float64_ast_dtype(float64_ast_dtype) + .build(); + let unparser = Unparser::new(&dialect); + + let expr = Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type: DataType::Float64, + }); + let ast = unparser.expr_to_sql(&expr)?; + + let actual = format!("{}", ast); + + let expected = format!(r#"CAST(a AS {identifier})"#); + assert_eq!(actual, expected); + } + Ok(()) + } + + #[test] + fn customer_dialect_support_nulls_first_in_ort() -> Result<()> { + let tests: Vec<(Sort, &str, bool)> = vec![ + (col("a").sort(true, true), r#"a ASC NULLS FIRST"#, true), + (col("a").sort(true, true), r#"a ASC"#, false), + ]; + + for (expr, expected, supports_nulls_first_in_sort) in tests { + let dialect = CustomDialectBuilder::new() + .with_supports_nulls_first_in_sort(supports_nulls_first_in_sort) + .build(); + let unparser = Unparser::new(&dialect); + let ast = unparser.sort_to_sql(&expr)?; + + let actual = format!("{}", ast); + + assert_eq!(actual, expected); + } + + Ok(()) + } + + #[test] + fn test_interval_scalar_to_expr() { + let tests = [ + ( + interval_month_day_nano_lit("1 MONTH"), + IntervalStyle::SQLStandard, + "INTERVAL '1' MONTH", + ), + ( + interval_month_day_nano_lit("1.5 DAY"), + IntervalStyle::SQLStandard, + "INTERVAL '1 12:0:0.000' DAY TO SECOND", + ), + ( + interval_month_day_nano_lit("-1.5 DAY"), + IntervalStyle::SQLStandard, + "INTERVAL '-1 -12:0:0.000' DAY TO SECOND", + ), + ( + interval_month_day_nano_lit("1.51234 DAY"), + IntervalStyle::SQLStandard, + "INTERVAL '1 12:17:46.176' DAY TO SECOND", + ), + ( + interval_datetime_lit("1.51234 DAY"), + IntervalStyle::SQLStandard, + "INTERVAL '1 12:17:46.176' DAY TO SECOND", + ), + ( + interval_year_month_lit("1 YEAR"), + IntervalStyle::SQLStandard, + "INTERVAL '12' MONTH", + ), + ( + interval_month_day_nano_lit( + "1 YEAR 1 MONTH 1 DAY 3 HOUR 10 MINUTE 20 SECOND", + ), + IntervalStyle::PostgresVerbose, + r#"INTERVAL '13 MONS 1 DAYS 3 HOURS 10 MINS 20.000000000 SECS'"#, + ), + ( + interval_month_day_nano_lit("1.5 MONTH"), + IntervalStyle::PostgresVerbose, + r#"INTERVAL '1 MONS 15 DAYS'"#, + ), + ( + interval_month_day_nano_lit("-3 MONTH"), + IntervalStyle::PostgresVerbose, + r#"INTERVAL '-3 MONS'"#, + ), + ( + interval_month_day_nano_lit("1 MONTH") + .add(interval_month_day_nano_lit("1 DAY")), + IntervalStyle::PostgresVerbose, + r#"(INTERVAL '1 MONS' + INTERVAL '1 DAYS')"#, + ), + ( + interval_month_day_nano_lit("1 MONTH") + .sub(interval_month_day_nano_lit("1 DAY")), + IntervalStyle::PostgresVerbose, + r#"(INTERVAL '1 MONS' - INTERVAL '1 DAYS')"#, + ), + ( + interval_datetime_lit("10 DAY 1 HOUR 10 MINUTE 20 SECOND"), + IntervalStyle::PostgresVerbose, + r#"INTERVAL '10 DAYS 1 HOURS 10 MINS 20.000 SECS'"#, + ), + ( + interval_datetime_lit("10 DAY 1.5 HOUR 10 MINUTE 20 SECOND"), + IntervalStyle::PostgresVerbose, + r#"INTERVAL '10 DAYS 1 HOURS 40 MINS 20.000 SECS'"#, + ), + ( + interval_year_month_lit("1 YEAR 1 MONTH"), + IntervalStyle::PostgresVerbose, + r#"INTERVAL '1 YEARS 1 MONS'"#, + ), + ( + interval_year_month_lit("1.5 YEAR 1 MONTH"), + IntervalStyle::PostgresVerbose, + r#"INTERVAL '1 YEARS 7 MONS'"#, + ), + ( + interval_year_month_lit("1 YEAR 1 MONTH"), + IntervalStyle::MySQL, + r#"INTERVAL 13 MONTH"#, + ), + ( + interval_month_day_nano_lit("1 YEAR -1 MONTH"), + IntervalStyle::MySQL, + r#"INTERVAL 11 MONTH"#, + ), + ( + interval_month_day_nano_lit("15 DAY"), + IntervalStyle::MySQL, + r#"INTERVAL 15 DAY"#, + ), + ( + interval_month_day_nano_lit("-40 HOURS"), + IntervalStyle::MySQL, + r#"INTERVAL -40 HOUR"#, + ), + ( + interval_datetime_lit("-1.5 DAY 1 HOUR"), + IntervalStyle::MySQL, + "INTERVAL -35 HOUR", + ), + ( + interval_datetime_lit("1000000 DAY 1.5 HOUR 10 MINUTE 20 SECOND"), + IntervalStyle::MySQL, + r#"INTERVAL 86400006020 SECOND"#, + ), + ( + interval_year_month_lit("0 DAY 0 HOUR"), + IntervalStyle::MySQL, + r#"INTERVAL 0 DAY"#, + ), + ( + interval_month_day_nano_lit("-1296000000 SECOND"), + IntervalStyle::MySQL, + r#"INTERVAL -15000 DAY"#, + ), + ]; + + for (value, style, expected) in tests { + let dialect = CustomDialectBuilder::new() + .with_interval_style(style) + .build(); + let unparser = Unparser::new(&dialect); + + let ast = unparser.expr_to_sql(&value).expect("to be unparsed"); + + let actual = format!("{ast}"); + + assert_eq!(actual, expected); + } + } + + #[test] + fn test_float_scalar_to_expr() { + let tests = [ + (Expr::Literal(ScalarValue::Float64(Some(3f64))), "3.0"), + (Expr::Literal(ScalarValue::Float64(Some(3.1f64))), "3.1"), + (Expr::Literal(ScalarValue::Float32(Some(-2f32))), "-2.0"), + ( + Expr::Literal(ScalarValue::Float32(Some(-2.989f32))), + "-2.989", + ), + ]; + for (value, expected) in tests { + let dialect = CustomDialectBuilder::new().build(); + let unparser = Unparser::new(&dialect); + + let ast = unparser.expr_to_sql(&value).expect("to be unparsed"); + let actual = format!("{ast}"); + + assert_eq!(actual, expected); + } + } + + #[test] + fn custom_dialect_use_char_for_utf8_cast() -> Result<()> { + let default_dialect = CustomDialectBuilder::default().build(); + let mysql_custom_dialect = CustomDialectBuilder::new() + .with_utf8_cast_dtype(ast::DataType::Char(None)) + .with_large_utf8_cast_dtype(ast::DataType::Char(None)) + .build(); + + for (dialect, data_type, identifier) in [ + (&default_dialect, DataType::Utf8, "VARCHAR"), + (&default_dialect, DataType::LargeUtf8, "TEXT"), + (&mysql_custom_dialect, DataType::Utf8, "CHAR"), + (&mysql_custom_dialect, DataType::LargeUtf8, "CHAR"), + ] { + let unparser = Unparser::new(dialect); + + let expr = Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type, + }); + let ast = unparser.expr_to_sql(&expr)?; + + let actual = format!("{}", ast); + let expected = format!(r#"CAST(a AS {identifier})"#); + + assert_eq!(actual, expected); + } + Ok(()) + } + + #[test] + fn custom_dialect_with_date_field_extract_style() -> Result<()> { + for (extract_style, unit, expected) in [ + ( + DateFieldExtractStyle::DatePart, + "YEAR", + "date_part('YEAR', x)", + ), + ( + DateFieldExtractStyle::Extract, + "YEAR", + "EXTRACT(YEAR FROM x)", + ), + (DateFieldExtractStyle::Strftime, "YEAR", "strftime('%Y', x)"), + ( + DateFieldExtractStyle::DatePart, + "MONTH", + "date_part('MONTH', x)", + ), + ( + DateFieldExtractStyle::Extract, + "MONTH", + "EXTRACT(MONTH FROM x)", + ), + ( + DateFieldExtractStyle::Strftime, + "MONTH", + "strftime('%m', x)", + ), + ( + DateFieldExtractStyle::DatePart, + "DAY", + "date_part('DAY', x)", + ), + (DateFieldExtractStyle::Strftime, "DAY", "strftime('%d', x)"), + (DateFieldExtractStyle::Extract, "DAY", "EXTRACT(DAY FROM x)"), + ] { + let dialect = CustomDialectBuilder::new() + .with_date_field_extract_style(extract_style) + .build(); + + let unparser = Unparser::new(&dialect); + let expr = ScalarUDF::new_from_impl( + datafusion_functions::datetime::date_part::DatePartFunc::new(), + ) + .call(vec![Expr::Literal(ScalarValue::new_utf8(unit)), col("x")]); + + let ast = unparser.expr_to_sql(&expr)?; + let actual = format!("{}", ast); + + assert_eq!(actual, expected); + } + Ok(()) + } + + #[test] + fn custom_dialect_with_int64_cast_dtype() -> Result<()> { + let default_dialect = CustomDialectBuilder::new().build(); + let mysql_dialect = CustomDialectBuilder::new() + .with_int64_cast_dtype(ast::DataType::Custom( + ObjectName(vec![Ident::new("SIGNED")]), + vec![], + )) + .build(); + + for (dialect, identifier) in + [(default_dialect, "BIGINT"), (mysql_dialect, "SIGNED")] + { + let unparser = Unparser::new(&dialect); + let expr = Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type: DataType::Int64, + }); + let ast = unparser.expr_to_sql(&expr)?; + + let actual = format!("{}", ast); + let expected = format!(r#"CAST(a AS {identifier})"#); + + assert_eq!(actual, expected); + } + Ok(()) + } + + #[test] + fn custom_dialect_with_int32_cast_dtype() -> Result<()> { + let default_dialect = CustomDialectBuilder::new().build(); + let mysql_dialect = CustomDialectBuilder::new() + .with_int32_cast_dtype(ast::DataType::Custom( + ObjectName(vec![Ident::new("SIGNED")]), + vec![], + )) + .build(); + + for (dialect, identifier) in + [(default_dialect, "INTEGER"), (mysql_dialect, "SIGNED")] + { + let unparser = Unparser::new(&dialect); + let expr = Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type: DataType::Int32, + }); + let ast = unparser.expr_to_sql(&expr)?; + + let actual = format!("{}", ast); + let expected = format!(r#"CAST(a AS {identifier})"#); + + assert_eq!(actual, expected); + } + Ok(()) + } + + #[test] + fn custom_dialect_with_timestamp_cast_dtype() -> Result<()> { + let default_dialect = CustomDialectBuilder::new().build(); + let mysql_dialect = CustomDialectBuilder::new() + .with_timestamp_cast_dtype( + ast::DataType::Datetime(None), + ast::DataType::Datetime(None), + ) + .build(); + + let timestamp = DataType::Timestamp(TimeUnit::Nanosecond, None); + let timestamp_with_tz = + DataType::Timestamp(TimeUnit::Nanosecond, Some("+08:00".into())); + + for (dialect, data_type, identifier) in [ + (&default_dialect, ×tamp, "TIMESTAMP"), + ( + &default_dialect, + ×tamp_with_tz, + "TIMESTAMP WITH TIME ZONE", + ), + (&mysql_dialect, ×tamp, "DATETIME"), + (&mysql_dialect, ×tamp_with_tz, "DATETIME"), + ] { + let unparser = Unparser::new(dialect); + let expr = Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type: data_type.clone(), + }); + let ast = unparser.expr_to_sql(&expr)?; + + let actual = format!("{}", ast); + let expected = format!(r#"CAST(a AS {identifier})"#); + + assert_eq!(actual, expected); + } + Ok(()) + } + + #[test] + fn custom_dialect_date32_ast_dtype() -> Result<()> { + let default_dialect = CustomDialectBuilder::default().build(); + let sqlite_custom_dialect = CustomDialectBuilder::new() + .with_date32_cast_dtype(ast::DataType::Text) + .build(); + + for (dialect, data_type, identifier) in [ + (&default_dialect, DataType::Date32, "DATE"), + (&sqlite_custom_dialect, DataType::Date32, "TEXT"), + ] { + let unparser = Unparser::new(dialect); + + let expr = Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type, + }); + let ast = unparser.expr_to_sql(&expr)?; + + let actual = format!("{}", ast); + let expected = format!(r#"CAST(a AS {identifier})"#); + + assert_eq!(actual, expected); + } + Ok(()) + } + + #[test] + fn test_cast_value_to_dict_expr() { + let tests = [( + Expr::Cast(Cast { + expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( + "variation".to_string(), + )))), + data_type: DataType::Dictionary(Box::new(Int8), Box::new(DataType::Utf8)), + }), + "'variation'", + )]; + for (value, expected) in tests { + let dialect = CustomDialectBuilder::new().build(); + let unparser = Unparser::new(&dialect); + + let ast = unparser.expr_to_sql(&value).expect("to be unparsed"); + let actual = format!("{ast}"); + + assert_eq!(actual, expected); + } + } + + #[test] + fn test_round_scalar_fn_to_expr() -> Result<()> { + let default_dialect: Arc = Arc::new( + CustomDialectBuilder::new() + .with_identifier_quote_style('"') + .build(), + ); + let postgres_dialect: Arc = Arc::new(PostgreSqlDialect {}); + + for (dialect, identifier) in + [(default_dialect, "DOUBLE"), (postgres_dialect, "NUMERIC")] + { + let unparser = Unparser::new(dialect.as_ref()); + let expr = Expr::ScalarFunction(ScalarFunction { + func: Arc::new(ScalarUDF::from( + datafusion_functions::math::round::RoundFunc::new(), + )), + args: vec![ + Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type: DataType::Float64, + }), + Expr::Literal(ScalarValue::Int64(Some(2))), + ], + }); + let ast = unparser.expr_to_sql(&expr)?; + + let actual = format!("{}", ast); + let expected = format!(r#"round(CAST("a" AS {identifier}), 2)"#); + + assert_eq!(actual, expected); + } + Ok(()) + } } diff --git a/datafusion/sql/src/unparser/mod.rs b/datafusion/sql/src/unparser/mod.rs index fb0285901c3f..83ae64ba238b 100644 --- a/datafusion/sql/src/unparser/mod.rs +++ b/datafusion/sql/src/unparser/mod.rs @@ -15,9 +15,12 @@ // specific language governing permissions and limitations // under the License. +//! [`Unparser`] for converting `Expr` to SQL text + mod ast; mod expr; mod plan; +mod rewrite; mod utils; pub use expr::expr_to_sql; @@ -26,13 +29,81 @@ pub use plan::plan_to_sql; use self::dialect::{DefaultDialect, Dialect}; pub mod dialect; +/// Convert a DataFusion [`Expr`] to [`sqlparser::ast::Expr`] +/// +/// See [`expr_to_sql`] for background. `Unparser` allows greater control of +/// the conversion, but with a more complicated API. +/// +/// To get more human-readable output, see [`Self::with_pretty`] +/// +/// # Example +/// ``` +/// use datafusion_expr::{col, lit}; +/// use datafusion_sql::unparser::Unparser; +/// let expr = col("a").gt(lit(4)); // form an expression `a > 4` +/// let unparser = Unparser::default(); +/// let sql = unparser.expr_to_sql(&expr).unwrap();// convert to AST +/// // use the Display impl to convert to SQL text +/// assert_eq!(sql.to_string(), "(a > 4)"); +/// // now convert to pretty sql +/// let unparser = unparser.with_pretty(true); +/// let sql = unparser.expr_to_sql(&expr).unwrap(); +/// assert_eq!(sql.to_string(), "a > 4"); // note lack of parenthesis +/// ``` +/// +/// [`Expr`]: datafusion_expr::Expr pub struct Unparser<'a> { dialect: &'a dyn Dialect, + pretty: bool, } impl<'a> Unparser<'a> { pub fn new(dialect: &'a dyn Dialect) -> Self { - Self { dialect } + Self { + dialect, + pretty: false, + } + } + + /// Create pretty SQL output, better suited for human consumption + /// + /// See example on the struct level documentation + /// + /// # Pretty Output + /// + /// By default, `Unparser` generates SQL text that will parse back to the + /// same parsed [`Expr`], which is useful for creating machine readable + /// expressions to send to other systems. However, the resulting expressions are + /// not always nice to read for humans. + /// + /// For example + /// + /// ```sql + /// ((a + 4) > 5) + /// ``` + /// + /// This method removes parenthesis using to the precedence rules of + /// DataFusion. If the output is reparsed, the resulting [`Expr`] produces + /// same value as the original in DataFusion, but with a potentially + /// different order of operations. + /// + /// Note that this setting may create invalid SQL for other SQL query + /// engines with different precedence rules + /// + /// # Example + /// ``` + /// use datafusion_expr::{col, lit}; + /// use datafusion_sql::unparser::Unparser; + /// let expr = col("a").gt(lit(4)).and(col("b").lt(lit(5))); // form an expression `a > 4 AND b < 5` + /// let unparser = Unparser::default().with_pretty(true); + /// let sql = unparser.expr_to_sql(&expr).unwrap(); + /// assert_eq!(sql.to_string(), "a > 4 AND b < 5"); // note lack of parenthesis + /// ``` + /// + /// [`Expr`]: datafusion_expr::Expr + pub fn with_pretty(mut self, pretty: bool) -> Self { + self.pretty = pretty; + self } } @@ -40,6 +111,7 @@ impl<'a> Default for Unparser<'a> { fn default() -> Self { Self { dialect: &DefaultDialect {}, + pretty: false, } } } diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index c9b0a8a04c7e..8167ddacffb4 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -15,25 +15,48 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::{internal_err, not_impl_err, plan_err, DataFusionError, Result}; -use datafusion_expr::{expr::Alias, Expr, JoinConstraint, JoinType, LogicalPlan}; -use sqlparser::ast::{self}; - -use crate::unparser::utils::unproject_agg_exprs; - use super::{ ast::{ BuilderError, DerivedRelationBuilder, QueryBuilder, RelationBuilder, SelectBuilder, TableRelationBuilder, TableWithJoinsBuilder, }, - utils::find_agg_node_within_select, + rewrite::{ + inject_column_aliases_into_subquery, normalize_union_schema, + rewrite_plan_for_sort_on_non_projected_fields, + subquery_alias_inner_query_and_columns, TableAliasRewriter, + }, + utils::{ + find_agg_node_within_select, find_unnest_node_within_select, + find_window_nodes_within_select, try_transform_to_simple_table_scan_with_filters, + unproject_sort_expr, unproject_unnest_expr, unproject_window_exprs, + }, Unparser, }; +use crate::unparser::utils::unproject_agg_exprs; +use datafusion_common::{ + internal_err, not_impl_err, + tree_node::{TransformedResult, TreeNode}, + Column, DataFusionError, Result, TableReference, +}; +use datafusion_expr::{ + expr::Alias, BinaryExpr, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan, + LogicalPlanBuilder, Operator, Projection, SortExpr, TableScan, +}; +use sqlparser::ast::{self, Ident, SetExpr}; +use std::sync::Arc; -/// Convert a DataFusion [`LogicalPlan`] to `sqlparser::ast::Statement` +/// Convert a DataFusion [`LogicalPlan`] to [`ast::Statement`] /// -/// This function is the opposite of `SqlToRel::sql_statement_to_plan` and can -/// be used to, among other things, convert `LogicalPlan`s to strings. +/// This function is the opposite of [`SqlToRel::sql_statement_to_plan`] and can +/// be used to, among other things, to convert `LogicalPlan`s to SQL strings. +/// +/// # Errors +/// +/// This function returns an error if the plan cannot be converted to SQL. +/// +/// # See Also +/// +/// * [`expr_to_sql`] for converting [`Expr`], a single expression to SQL /// /// # Example /// ``` @@ -44,16 +67,20 @@ use super::{ /// Field::new("id", DataType::Utf8, false), /// Field::new("value", DataType::Utf8, false), /// ]); +/// // Scan 'table' and select columns 'id' and 'value' /// let plan = table_scan(Some("table"), &schema, None) /// .unwrap() /// .project(vec![col("id"), col("value")]) /// .unwrap() /// .build() /// .unwrap(); -/// let sql = plan_to_sql(&plan).unwrap(); -/// -/// assert_eq!(format!("{}", sql), "SELECT \"table\".\"id\", \"table\".\"value\" FROM \"table\"") +/// let sql = plan_to_sql(&plan).unwrap(); // convert to AST +/// // use the Display impl to convert to SQL text +/// assert_eq!(sql.to_string(), "SELECT \"table\".id, \"table\".\"value\" FROM \"table\"") /// ``` +/// +/// [`SqlToRel::sql_statement_to_plan`]: crate::planner::SqlToRel::sql_statement_to_plan +/// [`expr_to_sql`]: crate::unparser::expr_to_sql pub fn plan_to_sql(plan: &LogicalPlan) -> Result { let unparser = Unparser::default(); unparser.plan_to_sql(plan) @@ -61,6 +88,8 @@ pub fn plan_to_sql(plan: &LogicalPlan) -> Result { impl Unparser<'_> { pub fn plan_to_sql(&self, plan: &LogicalPlan) -> Result { + let plan = normalize_union_schema(plan)?; + match plan { LogicalPlan::Projection(_) | LogicalPlan::Filter(_) @@ -68,7 +97,6 @@ impl Unparser<'_> { | LogicalPlan::Aggregate(_) | LogicalPlan::Sort(_) | LogicalPlan::Join(_) - | LogicalPlan::CrossJoin(_) | LogicalPlan::Repartition(_) | LogicalPlan::Union(_) | LogicalPlan::TableScan(_) @@ -78,12 +106,13 @@ impl Unparser<'_> { | LogicalPlan::Limit(_) | LogicalPlan::Statement(_) | LogicalPlan::Values(_) - | LogicalPlan::Distinct(_) => self.select_to_sql(plan), - LogicalPlan::Dml(_) => self.dml_to_sql(plan), + | LogicalPlan::Distinct(_) => self.select_to_sql_statement(&plan), + LogicalPlan::Dml(_) => self.dml_to_sql(&plan), LogicalPlan::Explain(_) | LogicalPlan::Analyze(_) | LogicalPlan::Extension(_) | LogicalPlan::Prepare(_) + | LogicalPlan::Execute(_) | LogicalPlan::Ddl(_) | LogicalPlan::Copy(_) | LogicalPlan::DescribeTable(_) @@ -92,104 +121,214 @@ impl Unparser<'_> { } } - fn select_to_sql(&self, plan: &LogicalPlan) -> Result { - let mut query_builder = QueryBuilder::default(); + fn select_to_sql_statement(&self, plan: &LogicalPlan) -> Result { + let mut query_builder = Some(QueryBuilder::default()); + + let body = self.select_to_sql_expr(plan, &mut query_builder)?; + + let query = query_builder.unwrap().body(Box::new(body)).build()?; + + Ok(ast::Statement::Query(Box::new(query))) + } + + fn select_to_sql_expr( + &self, + plan: &LogicalPlan, + query: &mut Option, + ) -> Result { let mut select_builder = SelectBuilder::default(); select_builder.push_from(TableWithJoinsBuilder::default()); let mut relation_builder = RelationBuilder::default(); self.select_to_sql_recursively( plan, - &mut query_builder, + query, &mut select_builder, &mut relation_builder, )?; + // If we were able to construct a full body (i.e. UNION ALL), return it + if let Some(body) = query.as_mut().and_then(|q| q.take_body()) { + return Ok(*body); + } + + // If no projection is set, add a wildcard projection to the select + // which will be translated to `SELECT *` in the SQL statement + if !select_builder.already_projected() { + select_builder.projection(vec![ast::SelectItem::Wildcard( + ast::WildcardAdditionalOptions::default(), + )]); + } + let mut twj = select_builder.pop_from().unwrap(); twj.relation(relation_builder); select_builder.push_from(twj); - let body = ast::SetExpr::Select(Box::new(select_builder.build()?)); - let query = query_builder.body(Box::new(body)).build()?; + Ok(SetExpr::Select(Box::new(select_builder.build()?))) + } - Ok(ast::Statement::Query(Box::new(query))) + /// Reconstructs a SELECT SQL statement from a logical plan by unprojecting column expressions + /// found in a [Projection] node. This requires scanning the plan tree for relevant Aggregate + /// and Window nodes and matching column expressions to the appropriate agg or window expressions. + fn reconstruct_select_statement( + &self, + plan: &LogicalPlan, + p: &Projection, + select: &mut SelectBuilder, + ) -> Result<()> { + let mut exprs = p.expr.clone(); + + // If an Unnest node is found within the select, find and unproject the unnest column + if let Some(unnest) = find_unnest_node_within_select(plan) { + exprs = exprs + .into_iter() + .map(|e| unproject_unnest_expr(e, unnest)) + .collect::>>()?; + }; + + match ( + find_agg_node_within_select(plan, true), + find_window_nodes_within_select(plan, None, true), + ) { + (Some(agg), window) => { + let window_option = window.as_deref(); + let items = exprs + .into_iter() + .map(|proj_expr| { + let unproj = unproject_agg_exprs(proj_expr, agg, window_option)?; + self.select_item_to_sql(&unproj) + }) + .collect::>>()?; + + select.projection(items); + select.group_by(ast::GroupByExpr::Expressions( + agg.group_expr + .iter() + .map(|expr| self.expr_to_sql(expr)) + .collect::>>()?, + vec![], + )); + } + (None, Some(window)) => { + let items = exprs + .into_iter() + .map(|proj_expr| { + let unproj = unproject_window_exprs(proj_expr, &window)?; + self.select_item_to_sql(&unproj) + }) + .collect::>>()?; + + select.projection(items); + } + _ => { + let items = exprs + .iter() + .map(|e| self.select_item_to_sql(e)) + .collect::>>()?; + select.projection(items); + } + } + Ok(()) + } + + fn derive( + &self, + plan: &LogicalPlan, + relation: &mut RelationBuilder, + alias: Option, + ) -> Result<()> { + let mut derived_builder = DerivedRelationBuilder::default(); + derived_builder.lateral(false).alias(alias).subquery({ + let inner_statement = self.plan_to_sql(plan)?; + if let ast::Statement::Query(inner_query) = inner_statement { + inner_query + } else { + return internal_err!( + "Subquery must be a Query, but found {inner_statement:?}" + ); + } + }); + relation.derived(derived_builder); + + Ok(()) + } + + fn derive_with_dialect_alias( + &self, + alias: &str, + plan: &LogicalPlan, + relation: &mut RelationBuilder, + ) -> Result<()> { + if self.dialect.requires_derived_table_alias() { + self.derive( + plan, + relation, + Some(self.new_table_alias(alias.to_string(), vec![])), + ) + } else { + self.derive(plan, relation, None) + } } fn select_to_sql_recursively( &self, plan: &LogicalPlan, - query: &mut QueryBuilder, + query: &mut Option, select: &mut SelectBuilder, relation: &mut RelationBuilder, ) -> Result<()> { match plan { LogicalPlan::TableScan(scan) => { + if let Some(unparsed_table_scan) = + Self::unparse_table_scan_pushdown(plan, None)? + { + return self.select_to_sql_recursively( + &unparsed_table_scan, + query, + select, + relation, + ); + } let mut builder = TableRelationBuilder::default(); let mut table_parts = vec![]; + if let Some(catalog_name) = scan.table_name.catalog() { + table_parts + .push(self.new_ident_quoted_if_needs(catalog_name.to_string())); + } if let Some(schema_name) = scan.table_name.schema() { - table_parts.push(self.new_ident(schema_name.to_string())); + table_parts + .push(self.new_ident_quoted_if_needs(schema_name.to_string())); } - table_parts.push(self.new_ident(scan.table_name.table().to_string())); + table_parts.push( + self.new_ident_quoted_if_needs(scan.table_name.table().to_string()), + ); builder.name(ast::ObjectName(table_parts)); relation.table(builder); Ok(()) } LogicalPlan::Projection(p) => { - // A second projection implies a derived tablefactor - if !select.already_projected() { - // Special handling when projecting an agregation plan - if let Some(agg) = find_agg_node_within_select(plan, true) { - let items = p - .expr - .iter() - .map(|proj_expr| { - let unproj = unproject_agg_exprs(proj_expr, agg)?; - self.select_item_to_sql(&unproj) - }) - .collect::>>()?; + if let Some(new_plan) = rewrite_plan_for_sort_on_non_projected_fields(p) { + return self + .select_to_sql_recursively(&new_plan, query, select, relation); + } - select.projection(items); - select.group_by(ast::GroupByExpr::Expressions( - agg.group_expr - .iter() - .map(|expr| self.expr_to_sql(expr)) - .collect::>>()?, - )); - } else { - let items = p - .expr - .iter() - .map(|e| self.select_item_to_sql(e)) - .collect::>>()?; - select.projection(items); - } - self.select_to_sql_recursively( - p.input.as_ref(), - query, - select, + // Projection can be top-level plan for derived table + if select.already_projected() { + return self.derive_with_dialect_alias( + "derived_projection", + plan, relation, - ) - } else { - let mut derived_builder = DerivedRelationBuilder::default(); - derived_builder.lateral(false).alias(None).subquery({ - let inner_statment = self.plan_to_sql(plan)?; - if let ast::Statement::Query(inner_query) = inner_statment { - inner_query - } else { - return internal_err!( - "Subquery must be a Query, but found {inner_statment:?}" - ); - } - }); - relation.derived(derived_builder); - Ok(()) + ); } + self.reconstruct_select_statement(plan, p, select)?; + self.select_to_sql_recursively(p.input.as_ref(), query, select, relation) } LogicalPlan::Filter(filter) => { if let Some(agg) = find_agg_node_within_select(plan, select.already_projected()) { - let unprojected = unproject_agg_exprs(&filter.predicate, agg)?; + let unprojected = + unproject_agg_exprs(filter.predicate.clone(), agg, None)?; let filter_expr = self.expr_to_sql(&unprojected)?; select.having(Some(filter_expr)); } else { @@ -205,11 +344,33 @@ impl Unparser<'_> { ) } LogicalPlan::Limit(limit) => { - if let Some(fetch) = limit.fetch { - query.limit(Some(ast::Expr::Value(ast::Value::Number( - fetch.to_string(), - false, - )))); + // Limit can be top-level plan for derived table + if select.already_projected() { + return self.derive_with_dialect_alias( + "derived_limit", + plan, + relation, + ); + } + if let Some(fetch) = &limit.fetch { + let Some(query) = query.as_mut() else { + return internal_err!( + "Limit operator only valid in a statement context." + ); + }; + query.limit(Some(self.expr_to_sql(fetch)?)); + } + + if let Some(skip) = &limit.skip { + let Some(query) = query.as_mut() else { + return internal_err!( + "Offset operator only valid in a statement context." + ); + }; + query.offset(Some(ast::Offset { + rows: ast::OffsetRows::None, + value: self.expr_to_sql(skip)?, + })); } self.select_to_sql_recursively( @@ -220,7 +381,38 @@ impl Unparser<'_> { ) } LogicalPlan::Sort(sort) => { - query.order_by(self.sort_to_sql(sort.expr.clone())?); + // Sort can be top-level plan for derived table + if select.already_projected() { + return self.derive_with_dialect_alias( + "derived_sort", + plan, + relation, + ); + } + let Some(query_ref) = query else { + return internal_err!( + "Sort operator only valid in a statement context." + ); + }; + + if let Some(fetch) = sort.fetch { + query_ref.limit(Some(ast::Expr::Value(ast::Value::Number( + fetch.to_string(), + false, + )))); + }; + + let agg = find_agg_node_within_select(plan, select.already_projected()); + // unproject sort expressions + let sort_exprs: Vec = sort + .expr + .iter() + .map(|sort_expr| { + unproject_sort_expr(sort_expr, agg, sort.input.as_ref()) + }) + .collect::>>()?; + + query_ref.order_by(self.sorts_to_sql(&sort_exprs)?); self.select_to_sql_recursively( sort.input.as_ref(), @@ -230,7 +422,7 @@ impl Unparser<'_> { ) } LogicalPlan::Aggregate(agg) => { - // Aggregate nodes are handled simulatenously with Projection nodes + // Aggregate nodes are handled simultaneously with Projection nodes self.select_to_sql_recursively( agg.input.as_ref(), query, @@ -238,61 +430,130 @@ impl Unparser<'_> { relation, ) } - LogicalPlan::Distinct(_distinct) => { - not_impl_err!("Unsupported operator: {plan:?}") + LogicalPlan::Distinct(distinct) => { + // Distinct can be top-level plan for derived table + if select.already_projected() { + return self.derive_with_dialect_alias( + "derived_distinct", + plan, + relation, + ); + } + let (select_distinct, input) = match distinct { + Distinct::All(input) => (ast::Distinct::Distinct, input.as_ref()), + Distinct::On(on) => { + let exprs = on + .on_expr + .iter() + .map(|e| self.expr_to_sql(e)) + .collect::>>()?; + let items = on + .select_expr + .iter() + .map(|e| self.select_item_to_sql(e)) + .collect::>>()?; + if let Some(sort_expr) = &on.sort_expr { + if let Some(query_ref) = query { + query_ref.order_by(self.sorts_to_sql(sort_expr)?); + } else { + return internal_err!( + "Sort operator only valid in a statement context." + ); + } + } + select.projection(items); + (ast::Distinct::On(exprs), on.input.as_ref()) + } + }; + select.distinct(Some(select_distinct)); + self.select_to_sql_recursively(input, query, select, relation) } LogicalPlan::Join(join) => { - match join.join_constraint { - JoinConstraint::On => {} - JoinConstraint::Using => { - return not_impl_err!( - "Unsupported join constraint: {:?}", - join.join_constraint - ) - } - } + let mut table_scan_filters = vec![]; - // parse filter if exists - let join_filter = match &join.filter { - Some(filter) => Some(self.expr_to_sql(filter)?), - None => None, - }; + let left_plan = + match try_transform_to_simple_table_scan_with_filters(&join.left)? { + Some((plan, filters)) => { + table_scan_filters.extend(filters); + Arc::new(plan) + } + None => Arc::clone(&join.left), + }; - // map join.on to `l.a = r.a AND l.b = r.b AND ...` - let eq_op = ast::BinaryOperator::Eq; - let join_on = self.join_conditions_to_sql(&join.on, eq_op)?; + self.select_to_sql_recursively( + left_plan.as_ref(), + query, + select, + relation, + )?; - // Merge `join_on` and `join_filter` - let join_expr = match (join_filter, join_on) { - (Some(filter), Some(on)) => Some(self.and_op_to_sql(filter, on)), - (Some(filter), None) => Some(filter), - (None, Some(on)) => Some(on), - (None, None) => None, - }; - let join_constraint = match join_expr { - Some(expr) => ast::JoinConstraint::On(expr), - None => ast::JoinConstraint::None, - }; + let right_plan = + match try_transform_to_simple_table_scan_with_filters(&join.right)? { + Some((plan, filters)) => { + table_scan_filters.extend(filters); + Arc::new(plan) + } + None => Arc::clone(&join.right), + }; let mut right_relation = RelationBuilder::default(); self.select_to_sql_recursively( - join.left.as_ref(), + right_plan.as_ref(), query, select, - relation, + &mut right_relation, )?; + + let join_filters = if table_scan_filters.is_empty() { + join.filter.clone() + } else { + // Combine `table_scan_filters` into a single filter using `AND` + let Some(combined_filters) = + table_scan_filters.into_iter().reduce(|acc, filter| { + Expr::BinaryExpr(BinaryExpr { + left: Box::new(acc), + op: Operator::And, + right: Box::new(filter), + }) + }) + else { + return internal_err!("Failed to combine TableScan filters"); + }; + + // Combine `join.filter` with `combined_filters` using `AND` + match &join.filter { + Some(filter) => Some(Expr::BinaryExpr(BinaryExpr { + left: Box::new(filter.clone()), + op: Operator::And, + right: Box::new(combined_filters), + })), + None => Some(combined_filters), + } + }; + + let join_constraint = self.join_constraint_to_sql( + join.join_constraint, + &join.on, + join_filters.as_ref(), + )?; + self.select_to_sql_recursively( - join.right.as_ref(), + right_plan.as_ref(), query, select, &mut right_relation, )?; + let Ok(Some(relation)) = right_relation.build() else { + return internal_err!("Failed to build right relation"); + }; + let ast_join = ast::Join { - relation: right_relation.build()?, + relation, + global: false, join_operator: self - .join_operator_to_sql(join.join_type, join_constraint), + .join_operator_to_sql(join.join_type, join_constraint)?, }; let mut from = select.pop_from().unwrap(); from.push_join(ast_join); @@ -301,31 +562,271 @@ impl Unparser<'_> { Ok(()) } LogicalPlan::SubqueryAlias(plan_alias) => { - // Handle bottom-up to allocate relation - self.select_to_sql_recursively( - plan_alias.input.as_ref(), - query, - select, - relation, + let (plan, mut columns) = + subquery_alias_inner_query_and_columns(plan_alias); + let unparsed_table_scan = Self::unparse_table_scan_pushdown( + plan, + Some(plan_alias.alias.clone()), )?; + // if the child plan is a TableScan with pushdown operations, we don't need to + // create an additional subquery for it + if !select.already_projected() && unparsed_table_scan.is_none() { + select.projection(vec![ast::SelectItem::Wildcard( + ast::WildcardAdditionalOptions::default(), + )]); + } + let plan = unparsed_table_scan.unwrap_or_else(|| plan.clone()); + if !columns.is_empty() + && !self.dialect.supports_column_alias_in_table_alias() + { + // Instead of specifying column aliases as part of the outer table, inject them directly into the inner projection + let rewritten_plan = + match inject_column_aliases_into_subquery(plan, columns) { + Ok(p) => p, + Err(e) => { + return internal_err!( + "Failed to transform SubqueryAlias plan: {e}" + ) + } + }; + + columns = vec![]; + + self.select_to_sql_recursively( + &rewritten_plan, + query, + select, + relation, + )?; + } else { + self.select_to_sql_recursively(&plan, query, select, relation)?; + } relation.alias(Some( - self.new_table_alias(plan_alias.alias.table().to_string()), + self.new_table_alias(plan_alias.alias.table().to_string(), columns), )); Ok(()) } - LogicalPlan::Union(_union) => { - not_impl_err!("Unsupported operator: {plan:?}") + LogicalPlan::Union(union) => { + if union.inputs.len() != 2 { + return not_impl_err!( + "UNION ALL expected 2 inputs, but found {}", + union.inputs.len() + ); + } + + // Covers cases where the UNION is a subquery and the projection is at the top level + if select.already_projected() { + return self.derive_with_dialect_alias( + "derived_union", + plan, + relation, + ); + } + + let input_exprs: Vec = union + .inputs + .iter() + .map(|input| self.select_to_sql_expr(input, query)) + .collect::>>()?; + + let union_expr = SetExpr::SetOperation { + op: ast::SetOperator::Union, + set_quantifier: ast::SetQuantifier::All, + left: Box::new(input_exprs[0].clone()), + right: Box::new(input_exprs[1].clone()), + }; + + let Some(query) = query.as_mut() else { + return internal_err!( + "UNION ALL operator only valid in a statement context" + ); + }; + query.body(Box::new(union_expr)); + + Ok(()) } - LogicalPlan::Window(_window) => { - not_impl_err!("Unsupported operator: {plan:?}") + LogicalPlan::Window(window) => { + // Window nodes are handled simultaneously with Projection nodes + self.select_to_sql_recursively( + window.input.as_ref(), + query, + select, + relation, + ) + } + LogicalPlan::EmptyRelation(_) => { + relation.empty(); + Ok(()) } LogicalPlan::Extension(_) => not_impl_err!("Unsupported operator: {plan:?}"), + LogicalPlan::Unnest(unnest) => { + if !unnest.struct_type_columns.is_empty() { + return internal_err!( + "Struct type columns are not currently supported in UNNEST: {:?}", + unnest.struct_type_columns + ); + } + + // In the case of UNNEST, the Unnest node is followed by a duplicate Projection node that we should skip. + // Otherwise, there will be a duplicate SELECT clause. + // | Projection: table.col1, UNNEST(table.col2) + // | Unnest: UNNEST(table.col2) + // | Projection: table.col1, table.col2 AS UNNEST(table.col2) + // | Filter: table.col3 = Int64(3) + // | TableScan: table projection=None + if let LogicalPlan::Projection(p) = unnest.input.as_ref() { + // continue with projection input + self.select_to_sql_recursively(&p.input, query, select, relation) + } else { + internal_err!("Unnest input is not a Projection: {unnest:?}") + } + } _ => not_impl_err!("Unsupported operator: {plan:?}"), } } + fn is_scan_with_pushdown(scan: &TableScan) -> bool { + scan.projection.is_some() || !scan.filters.is_empty() || scan.fetch.is_some() + } + + /// Try to unparse a table scan with pushdown operations into a new subquery plan. + /// If the table scan is without any pushdown operations, return None. + fn unparse_table_scan_pushdown( + plan: &LogicalPlan, + alias: Option, + ) -> Result> { + match plan { + LogicalPlan::TableScan(table_scan) => { + if !Self::is_scan_with_pushdown(table_scan) { + return Ok(None); + } + let table_schema = table_scan.source.schema(); + let mut filter_alias_rewriter = + alias.as_ref().map(|alias_name| TableAliasRewriter { + table_schema: &table_schema, + alias_name: alias_name.clone(), + }); + + let mut builder = LogicalPlanBuilder::scan( + table_scan.table_name.clone(), + Arc::clone(&table_scan.source), + None, + )?; + // We will rebase the column references to the new alias if it exists. + // If the projection or filters are empty, we will append alias to the table scan. + // + // Example: + // select t1.c1 from t1 where t1.c1 > 1 -> select a.c1 from t1 as a where a.c1 > 1 + if let Some(ref alias) = alias { + if table_scan.projection.is_some() || !table_scan.filters.is_empty() { + builder = builder.alias(alias.clone())?; + } + } + + if let Some(project_vec) = &table_scan.projection { + let project_columns = project_vec + .iter() + .cloned() + .map(|i| { + let schema = table_scan.source.schema(); + let field = schema.field(i); + if alias.is_some() { + Column::new(alias.clone(), field.name().clone()) + } else { + Column::new( + Some(table_scan.table_name.clone()), + field.name().clone(), + ) + } + }) + .collect::>(); + builder = builder.project(project_columns)?; + } + + let filter_expr: Result> = table_scan + .filters + .iter() + .cloned() + .map(|expr| { + if let Some(ref mut rewriter) = filter_alias_rewriter { + expr.rewrite(rewriter).data() + } else { + Ok(expr) + } + }) + .reduce(|acc, expr_result| { + acc.and_then(|acc_expr| { + expr_result.map(|expr| acc_expr.and(expr)) + }) + }) + .transpose(); + + if let Some(filter) = filter_expr? { + builder = builder.filter(filter)?; + } + + if let Some(fetch) = table_scan.fetch { + builder = builder.limit(0, Some(fetch))?; + } + + // If the table scan has an alias but no projection or filters, it means no column references are rebased. + // So we will append the alias to this subquery. + // Example: + // select * from t1 limit 10 -> (select * from t1 limit 10) as a + if let Some(alias) = alias { + if table_scan.projection.is_none() && table_scan.filters.is_empty() { + builder = builder.alias(alias)?; + } + } + + Ok(Some(builder.build()?)) + } + LogicalPlan::SubqueryAlias(subquery_alias) => { + Self::unparse_table_scan_pushdown( + &subquery_alias.input, + Some(subquery_alias.alias.clone()), + ) + } + // SubqueryAlias could be rewritten to a plan with a projection as the top node by [rewrite::subquery_alias_inner_query_and_columns]. + // The inner table scan could be a scan with pushdown operations. + LogicalPlan::Projection(projection) => { + if let Some(plan) = + Self::unparse_table_scan_pushdown(&projection.input, alias.clone())? + { + let exprs = if alias.is_some() { + let mut alias_rewriter = + alias.as_ref().map(|alias_name| TableAliasRewriter { + table_schema: plan.schema().as_arrow(), + alias_name: alias_name.clone(), + }); + projection + .expr + .iter() + .cloned() + .map(|expr| { + if let Some(ref mut rewriter) = alias_rewriter { + expr.rewrite(rewriter).data() + } else { + Ok(expr) + } + }) + .collect::>>()? + } else { + projection.expr.clone() + }; + Ok(Some( + LogicalPlanBuilder::from(plan).project(exprs)?.build()?, + )) + } else { + Ok(None) + } + } + _ => Ok(None), + } + } + fn select_item_to_sql(&self, expr: &Expr) -> Result { match expr { Expr::Alias(Alias { expr, name, .. }) => { @@ -333,7 +834,7 @@ impl Unparser<'_> { Ok(ast::SelectItem::ExprWithAlias { expr: inner, - alias: self.new_ident(name.to_string()), + alias: self.new_ident_quoted_if_needs(name.to_string()), }) } _ => { @@ -344,20 +845,10 @@ impl Unparser<'_> { } } - fn sort_to_sql(&self, sort_exprs: Vec) -> Result> { + fn sorts_to_sql(&self, sort_exprs: &[SortExpr]) -> Result> { sort_exprs .iter() - .map(|expr: &Expr| match expr { - Expr::Sort(sort_expr) => { - let col = self.expr_to_sql(&sort_expr.expr)?; - Ok(ast::OrderByExpr { - asc: Some(sort_expr.asc), - expr: col, - nulls_first: Some(sort_expr.nulls_first), - }) - } - _ => plan_err!("Expecting Sort expr"), - }) + .map(|sort_expr| self.sort_to_sql(sort_expr)) .collect::>>() } @@ -365,8 +856,8 @@ impl Unparser<'_> { &self, join_type: JoinType, constraint: ast::JoinConstraint, - ) -> ast::JoinOperator { - match join_type { + ) -> Result { + Ok(match join_type { JoinType::Inner => ast::JoinOperator::Inner(constraint), JoinType::Left => ast::JoinOperator::LeftOuter(constraint), JoinType::Right => ast::JoinOperator::RightOuter(constraint), @@ -375,37 +866,122 @@ impl Unparser<'_> { JoinType::LeftSemi => ast::JoinOperator::LeftSemi(constraint), JoinType::RightAnti => ast::JoinOperator::RightAnti(constraint), JoinType::RightSemi => ast::JoinOperator::RightSemi(constraint), + JoinType::LeftMark => unimplemented!("Unparsing of Left Mark join type"), + }) + } + + /// Convert the components of a USING clause to the USING AST. Returns + /// 'None' if the conditions are not compatible with a USING expression, + /// e.g. non-column expressions or non-matching names. + fn join_using_to_sql( + &self, + join_conditions: &[(Expr, Expr)], + ) -> Option { + let mut idents = Vec::with_capacity(join_conditions.len()); + for (left, right) in join_conditions { + match (left, right) { + ( + Expr::Column(Column { + relation: _, + name: left_name, + }), + Expr::Column(Column { + relation: _, + name: right_name, + }), + ) if left_name == right_name => { + idents.push(self.new_ident_quoted_if_needs(left_name.to_string())); + } + // USING is only valid with matching column names; arbitrary expressions + // are not allowed + _ => return None, + } } + Some(ast::JoinConstraint::Using(idents)) } - fn join_conditions_to_sql( + /// Convert a join constraint and associated conditions and filter to a SQL AST node + fn join_constraint_to_sql( &self, - join_conditions: &Vec<(Expr, Expr)>, - eq_op: ast::BinaryOperator, - ) -> Result> { - // Only support AND conjunction for each binary expression in join conditions - let mut exprs: Vec = vec![]; + constraint: JoinConstraint, + conditions: &[(Expr, Expr)], + filter: Option<&Expr>, + ) -> Result { + match (constraint, conditions, filter) { + // No constraints + (JoinConstraint::On | JoinConstraint::Using, [], None) => { + Ok(ast::JoinConstraint::None) + } + + (JoinConstraint::Using, conditions, None) => { + match self.join_using_to_sql(conditions) { + Some(using) => Ok(using), + // As above, this should not be reachable from parsed SQL, + // but a user could create this; we "downgrade" to ON. + None => self.join_conditions_to_sql_on(conditions, None), + } + } + + // Two cases here: + // 1. Straightforward ON case, with possible equi-join conditions + // and additional filters + // 2. USING with additional filters; we "downgrade" to ON, because + // you can't use USING with arbitrary filters. (This should not + // be accessible from parsed SQL, but may have been a + // custom-built JOIN by a user.) + (JoinConstraint::On | JoinConstraint::Using, conditions, filter) => { + self.join_conditions_to_sql_on(conditions, filter) + } + } + } + + // Convert a list of equi0join conditions and an optional filter to a SQL ON + // AST node, with the equi-join conditions and the filter merged into a + // single conditional expression + fn join_conditions_to_sql_on( + &self, + join_conditions: &[(Expr, Expr)], + filter: Option<&Expr>, + ) -> Result { + let mut condition = None; + // AND the join conditions together to create the overall condition for (left, right) in join_conditions { - // Parse left + // Parse left and right let l = self.expr_to_sql(left)?; - // Parse right let r = self.expr_to_sql(right)?; - // AND with existing expression - exprs.push(self.binary_op_to_sql(l, r, eq_op.clone())); + let e = self.binary_op_to_sql(l, r, ast::BinaryOperator::Eq); + condition = match condition { + Some(expr) => Some(self.and_op_to_sql(expr, e)), + None => Some(e), + }; } - let join_expr: Option = - exprs.into_iter().reduce(|r, l| self.and_op_to_sql(r, l)); - Ok(join_expr) + + // Then AND the non-equijoin filter condition as well + condition = match (condition, filter) { + (Some(expr), Some(filter)) => { + Some(self.and_op_to_sql(expr, self.expr_to_sql(filter)?)) + } + (Some(expr), None) => Some(expr), + (None, Some(filter)) => Some(self.expr_to_sql(filter)?), + (None, None) => None, + }; + + let constraint = match condition { + Some(filter) => ast::JoinConstraint::On(filter), + None => ast::JoinConstraint::None, + }; + + Ok(constraint) } fn and_op_to_sql(&self, lhs: ast::Expr, rhs: ast::Expr) -> ast::Expr { self.binary_op_to_sql(lhs, rhs, ast::BinaryOperator::And) } - fn new_table_alias(&self, alias: String) -> ast::TableAlias { + fn new_table_alias(&self, alias: String, columns: Vec) -> ast::TableAlias { ast::TableAlias { - name: self.new_ident(alias), - columns: Vec::new(), + name: self.new_ident_quoted_if_needs(alias), + columns, } } diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs new file mode 100644 index 000000000000..57d700f86955 --- /dev/null +++ b/datafusion/sql/src/unparser/rewrite.rs @@ -0,0 +1,368 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; + +use arrow_schema::Schema; +use datafusion_common::{ + tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, + Column, Result, TableReference, +}; +use datafusion_expr::{expr::Alias, tree_node::transform_sort_vec}; +use datafusion_expr::{Expr, LogicalPlan, Projection, Sort, SortExpr}; +use sqlparser::ast::Ident; + +/// Normalize the schema of a union plan to remove qualifiers from the schema fields and sort expressions. +/// +/// DataFusion will return an error if two columns in the schema have the same name with no table qualifiers. +/// There are certain types of UNION queries that can result in having two columns with the same name, and the +/// solution was to add table qualifiers to the schema fields. +/// See for more context on this decision. +/// +/// However, this causes a problem when unparsing these queries back to SQL - as the table qualifier has +/// logically been erased and is no longer a valid reference. +/// +/// The following input SQL: +/// ```sql +/// SELECT table1.foo FROM table1 +/// UNION ALL +/// SELECT table2.foo FROM table2 +/// ORDER BY foo +/// ``` +/// +/// Would be unparsed into the following invalid SQL without this transformation: +/// ```sql +/// SELECT table1.foo FROM table1 +/// UNION ALL +/// SELECT table2.foo FROM table2 +/// ORDER BY table1.foo +/// ``` +/// +/// Which would result in a SQL error, as `table1.foo` is not a valid reference in the context of the UNION. +pub(super) fn normalize_union_schema(plan: &LogicalPlan) -> Result { + let plan = plan.clone(); + + let transformed_plan = plan.transform_up(|plan| match plan { + LogicalPlan::Union(mut union) => { + let schema = Arc::unwrap_or_clone(union.schema); + let schema = schema.strip_qualifiers(); + + union.schema = Arc::new(schema); + Ok(Transformed::yes(LogicalPlan::Union(union))) + } + LogicalPlan::Sort(sort) => { + // Only rewrite Sort expressions that have a UNION as their input + if !matches!(&*sort.input, LogicalPlan::Union(_)) { + return Ok(Transformed::no(LogicalPlan::Sort(sort))); + } + + Ok(Transformed::yes(LogicalPlan::Sort(Sort { + expr: rewrite_sort_expr_for_union(sort.expr)?, + input: sort.input, + fetch: sort.fetch, + }))) + } + _ => Ok(Transformed::no(plan)), + }); + transformed_plan.data() +} + +/// Rewrite sort expressions that have a UNION plan as their input to remove the table reference. +fn rewrite_sort_expr_for_union(exprs: Vec) -> Result> { + let sort_exprs = transform_sort_vec(exprs, &mut |expr| { + expr.transform_up(|expr| { + if let Expr::Column(mut col) = expr { + col.relation = None; + Ok(Transformed::yes(Expr::Column(col))) + } else { + Ok(Transformed::no(expr)) + } + }) + }) + .data()?; + + Ok(sort_exprs) +} + +/// Rewrite logic plan for query that order by columns are not in projections +/// Plan before rewrite: +/// +/// Projection: j1.j1_string, j2.j2_string +/// Sort: j1.j1_id DESC NULLS FIRST, j2.j2_id DESC NULLS FIRST +/// Projection: j1.j1_string, j2.j2_string, j1.j1_id, j2.j2_id +/// Inner Join: Filter: j1.j1_id = j2.j2_id +/// TableScan: j1 +/// TableScan: j2 +/// +/// Plan after rewrite +/// +/// Sort: j1.j1_id DESC NULLS FIRST, j2.j2_id DESC NULLS FIRST +/// Projection: j1.j1_string, j2.j2_string +/// Inner Join: Filter: j1.j1_id = j2.j2_id +/// TableScan: j1 +/// TableScan: j2 +/// +/// This prevents the original plan generate query with derived table but missing alias. +pub(super) fn rewrite_plan_for_sort_on_non_projected_fields( + p: &Projection, +) -> Option { + let LogicalPlan::Sort(sort) = p.input.as_ref() else { + return None; + }; + + let LogicalPlan::Projection(inner_p) = sort.input.as_ref() else { + return None; + }; + + let mut map = HashMap::new(); + let inner_exprs = inner_p + .expr + .iter() + .enumerate() + .map(|(i, f)| match f { + Expr::Alias(alias) => { + let a = Expr::Column(alias.name.clone().into()); + map.insert(a.clone(), f.clone()); + a + } + Expr::Column(_) => { + map.insert( + Expr::Column(inner_p.schema.field(i).name().into()), + f.clone(), + ); + f.clone() + } + _ => { + let a = Expr::Column(inner_p.schema.field(i).name().into()); + map.insert(a.clone(), f.clone()); + a + } + }) + .collect::>(); + + let mut collects = p.expr.clone(); + for sort in &sort.expr { + collects.push(sort.expr.clone()); + } + + // Compare outer collects Expr::to_string with inner collected transformed values + // alias -> alias column + // column -> remain + // others, extract schema field name + let outer_collects = collects.iter().map(Expr::to_string).collect::>(); + let inner_collects = inner_exprs + .iter() + .map(Expr::to_string) + .collect::>(); + + if outer_collects == inner_collects { + let mut sort = sort.clone(); + let mut inner_p = inner_p.clone(); + + let new_exprs = p + .expr + .iter() + .map(|e| map.get(e).unwrap_or(e).clone()) + .collect::>(); + + inner_p.expr.clone_from(&new_exprs); + sort.input = Arc::new(LogicalPlan::Projection(inner_p)); + + Some(LogicalPlan::Sort(sort)) + } else { + None + } +} + +/// This logic is to work out the columns and inner query for SubqueryAlias plan for both types of +/// subquery +/// - `(SELECT column_a as a from table) AS A` +/// - `(SELECT column_a from table) AS A (a)` +/// +/// A roundtrip example for table alias with columns +/// +/// query: SELECT id FROM (SELECT j1_id from j1) AS c (id) +/// +/// LogicPlan: +/// Projection: c.id +/// SubqueryAlias: c +/// Projection: j1.j1_id AS id +/// Projection: j1.j1_id +/// TableScan: j1 +/// +/// Before introducing this logic, the unparsed query would be `SELECT c.id FROM (SELECT j1.j1_id AS +/// id FROM (SELECT j1.j1_id FROM j1)) AS c`. +/// The query is invalid as `j1.j1_id` is not a valid identifier in the derived table +/// `(SELECT j1.j1_id FROM j1)` +/// +/// With this logic, the unparsed query will be: +/// `SELECT c.id FROM (SELECT j1.j1_id FROM j1) AS c (id)` +/// +/// Caveat: this won't handle the case like `select * from (select 1, 2) AS a (b, c)` +/// as the parser gives a wrong plan which has mismatch `Int(1)` types: Literal and +/// Column in the Projections. Once the parser side is fixed, this logic should work +pub(super) fn subquery_alias_inner_query_and_columns( + subquery_alias: &datafusion_expr::SubqueryAlias, +) -> (&LogicalPlan, Vec) { + let plan: &LogicalPlan = subquery_alias.input.as_ref(); + + let LogicalPlan::Projection(outer_projections) = plan else { + return (plan, vec![]); + }; + + // Check if it's projection inside projection + let Some(inner_projection) = find_projection(outer_projections.input.as_ref()) else { + return (plan, vec![]); + }; + + let mut columns: Vec = vec![]; + // Check if the inner projection and outer projection have a matching pattern like + // Projection: j1.j1_id AS id + // Projection: j1.j1_id + for (i, inner_expr) in inner_projection.expr.iter().enumerate() { + let Expr::Alias(ref outer_alias) = &outer_projections.expr[i] else { + return (plan, vec![]); + }; + + // Inner projection schema fields store the projection name which is used in outer + // projection expr + let inner_expr_string = match inner_expr { + Expr::Column(_) => inner_expr.to_string(), + _ => inner_projection.schema.field(i).name().clone(), + }; + + if outer_alias.expr.to_string() != inner_expr_string { + return (plan, vec![]); + }; + + columns.push(outer_alias.name.as_str().into()); + } + + (outer_projections.input.as_ref(), columns) +} + +/// Injects column aliases into a subquery's logical plan. The function searches for a `Projection` +/// within the given plan, which may be wrapped by other operators (e.g., LIMIT, SORT). +/// If the top-level plan is a `Projection`, it directly injects the column aliases. +/// Otherwise, it iterates through the plan's children to locate and transform the `Projection`. +/// +/// Example: +/// - `SELECT col1, col2 FROM table LIMIT 10` plan with aliases `["alias_1", "some_alias_2"]` will be transformed to +/// - `SELECT col1 AS alias_1, col2 AS some_alias_2 FROM table LIMIT 10` +pub(super) fn inject_column_aliases_into_subquery( + plan: LogicalPlan, + aliases: Vec, +) -> Result { + match &plan { + LogicalPlan::Projection(inner_p) => Ok(inject_column_aliases(inner_p, aliases)), + _ => { + // projection is wrapped by other operator (LIMIT, SORT, etc), iterate through the plan to find it + plan.map_children(|child| { + if let LogicalPlan::Projection(p) = &child { + Ok(Transformed::yes(inject_column_aliases(p, aliases.clone()))) + } else { + Ok(Transformed::no(child)) + } + }) + .map(|plan| plan.data) + } + } +} + +/// Injects column aliases into the projection of a logical plan by wrapping expressions +/// with `Expr::Alias` using the provided list of aliases. +/// +/// Example: +/// - `SELECT col1, col2 FROM table` with aliases `["alias_1", "some_alias_2"]` will be transformed to +/// - `SELECT col1 AS alias_1, col2 AS some_alias_2 FROM table` +pub(super) fn inject_column_aliases( + projection: &Projection, + aliases: impl IntoIterator, +) -> LogicalPlan { + let mut updated_projection = projection.clone(); + + let new_exprs = updated_projection + .expr + .into_iter() + .zip(aliases) + .map(|(expr, col_alias)| { + let relation = match &expr { + Expr::Column(col) => col.relation.clone(), + _ => None, + }; + + Expr::Alias(Alias { + expr: Box::new(expr.clone()), + relation, + name: col_alias.value, + }) + }) + .collect::>(); + + updated_projection.expr = new_exprs; + + LogicalPlan::Projection(updated_projection) +} + +fn find_projection(logical_plan: &LogicalPlan) -> Option<&Projection> { + match logical_plan { + LogicalPlan::Projection(p) => Some(p), + LogicalPlan::Limit(p) => find_projection(p.input.as_ref()), + LogicalPlan::Distinct(p) => find_projection(p.input().as_ref()), + LogicalPlan::Sort(p) => find_projection(p.input.as_ref()), + _ => None, + } +} + +/// A `TreeNodeRewriter` implementation that rewrites `Expr::Column` expressions by +/// replacing the column's name with an alias if the column exists in the provided schema. +/// +/// This is typically used to apply table aliases in query plans, ensuring that +/// the column references in the expressions use the correct table alias. +/// +/// # Fields +/// +/// * `table_schema`: The schema (`SchemaRef`) representing the table structure +/// from which the columns are referenced. This is used to look up columns by their names. +/// * `alias_name`: The alias (`TableReference`) that will replace the table name +/// in the column references when applicable. +pub struct TableAliasRewriter<'a> { + pub table_schema: &'a Schema, + pub alias_name: TableReference, +} + +impl TreeNodeRewriter for TableAliasRewriter<'_> { + type Node = Expr; + + fn f_down(&mut self, expr: Expr) -> Result> { + match expr { + Expr::Column(column) => { + if let Ok(field) = self.table_schema.field_with_name(&column.name) { + let new_column = + Column::new(Some(self.alias_name.clone()), field.name().clone()); + Ok(Transformed::yes(Expr::Column(new_column))) + } else { + Ok(Transformed::no(Expr::Column(column))) + } + } + _ => Ok(Transformed::no(expr)), + } + } +} diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index c1b02c330fae..284956cef195 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -15,22 +15,30 @@ // specific language governing permissions and limitations // under the License. +use std::{cmp::Ordering, sync::Arc, vec}; + use datafusion_common::{ internal_err, - tree_node::{Transformed, TreeNode}, - Result, + tree_node::{Transformed, TransformedResult, TreeNode}, + Column, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr::{ + expr, utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, + LogicalPlanBuilder, Projection, SortExpr, Unnest, Window, }; -use datafusion_expr::{Aggregate, Expr, LogicalPlan}; +use sqlparser::ast; -/// Recursively searches children of [LogicalPlan] to find an Aggregate node if one exists +use super::{dialect::DateFieldExtractStyle, rewrite::TableAliasRewriter, Unparser}; + +/// Recursively searches children of [LogicalPlan] to find an Aggregate node if exists /// prior to encountering a Join, TableScan, or a nested subquery (derived table factor). -/// If an Aggregate node is not found prior to this or at all before reaching the end +/// If an Aggregate or node is not found prior to this or at all before reaching the end /// of the tree, None is returned. pub(crate) fn find_agg_node_within_select( plan: &LogicalPlan, already_projected: bool, ) -> Option<&Aggregate> { - // Note that none of the nodes that have a corresponding agg node can have more + // Note that none of the nodes that have a corresponding node can have more // than 1 input node. E.g. Projection / Filter always have 1 input node. let input = plan.inputs(); let input = if input.len() > 1 { @@ -38,6 +46,7 @@ pub(crate) fn find_agg_node_within_select( } else { input.first()? }; + // Agg nodes explicitly return immediately with a single node if let LogicalPlan::Aggregate(agg) = input { Some(agg) } else if let LogicalPlan::TableScan(_) = input { @@ -53,27 +62,121 @@ pub(crate) fn find_agg_node_within_select( } } +/// Recursively searches children of [LogicalPlan] to find Unnest node if exist +pub(crate) fn find_unnest_node_within_select(plan: &LogicalPlan) -> Option<&Unnest> { + // Note that none of the nodes that have a corresponding node can have more + // than 1 input node. E.g. Projection / Filter always have 1 input node. + let input = plan.inputs(); + let input = if input.len() > 1 { + return None; + } else { + input.first()? + }; + + if let LogicalPlan::Unnest(unnest) = input { + Some(unnest) + } else if let LogicalPlan::TableScan(_) = input { + None + } else if let LogicalPlan::Projection(_) = input { + None + } else { + find_unnest_node_within_select(input) + } +} + +/// Recursively searches children of [LogicalPlan] to find Window nodes if exist +/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor). +/// If Window node is not found prior to this or at all before reaching the end +/// of the tree, None is returned. +pub(crate) fn find_window_nodes_within_select<'a>( + plan: &'a LogicalPlan, + mut prev_windows: Option>, + already_projected: bool, +) -> Option> { + // Note that none of the nodes that have a corresponding node can have more + // than 1 input node. E.g. Projection / Filter always have 1 input node. + let input = plan.inputs(); + let input = if input.len() > 1 { + return prev_windows; + } else { + input.first()? + }; + + // Window nodes accumulate in a vec until encountering a TableScan or 2nd projection + match input { + LogicalPlan::Window(window) => { + prev_windows = match &mut prev_windows { + Some(windows) => { + windows.push(window); + prev_windows + } + _ => Some(vec![window]), + }; + find_window_nodes_within_select(input, prev_windows, already_projected) + } + LogicalPlan::Projection(_) => { + if already_projected { + prev_windows + } else { + find_window_nodes_within_select(input, prev_windows, true) + } + } + LogicalPlan::TableScan(_) => prev_windows, + _ => find_window_nodes_within_select(input, prev_windows, already_projected), + } +} + +/// Recursively identify Column expressions and transform them into the appropriate unnest expression +/// +/// For example, if expr contains the column expr "unnest_placeholder(make_array(Int64(1),Int64(2),Int64(2),Int64(5),NULL),depth=1)" +/// it will be transformed into an actual unnest expression UNNEST([1, 2, 2, 5, NULL]) +pub(crate) fn unproject_unnest_expr(expr: Expr, unnest: &Unnest) -> Result { + expr.transform(|sub_expr| { + if let Expr::Column(col_ref) = &sub_expr { + // Check if the column is among the columns to run unnest on. + // Currently, only List/Array columns (defined in `list_type_columns`) are supported for unnesting. + if unnest.list_type_columns.iter().any(|e| e.1.output_column.name == col_ref.name) { + if let Ok(idx) = unnest.schema.index_of_column(col_ref) { + if let LogicalPlan::Projection(Projection { expr, .. }) = unnest.input.as_ref() { + if let Some(unprojected_expr) = expr.get(idx) { + let unnest_expr = Expr::Unnest(expr::Unnest::new(unprojected_expr.clone())); + return Ok(Transformed::yes(unnest_expr)); + } + } + } + return internal_err!( + "Tried to unproject unnest expr for column '{}' that was not found in the provided Unnest!", &col_ref.name + ); + } + } + + Ok(Transformed::no(sub_expr)) + + }).map(|e| e.data) +} + /// Recursively identify all Column expressions and transform them into the appropriate /// aggregate expression contained in agg. /// /// For example, if expr contains the column expr "COUNT(*)" it will be transformed /// into an actual aggregate expression COUNT(*) as identified in the aggregate node. -pub(crate) fn unproject_agg_exprs(expr: &Expr, agg: &Aggregate) -> Result { - expr.clone() - .transform(|sub_expr| { +pub(crate) fn unproject_agg_exprs( + expr: Expr, + agg: &Aggregate, + windows: Option<&[&Window]>, +) -> Result { + expr.transform(|sub_expr| { if let Expr::Column(c) = sub_expr { - // find the column in the agg schmea - if let Ok(n) = agg.schema.index_of_column(&c) { - let unprojected_expr = agg - .group_expr - .iter() - .chain(agg.aggr_expr.iter()) - .nth(n) - .unwrap(); + if let Some(unprojected_expr) = find_agg_expr(agg, &c)? { Ok(Transformed::yes(unprojected_expr.clone())) + } else if let Some(unprojected_expr) = + windows.and_then(|w| find_window_expr(w, &c.name).cloned()) + { + // Window function can contain an aggregation columns, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected + return Ok(Transformed::yes(unproject_agg_exprs(unprojected_expr, agg, None)?)); } else { internal_err!( - "Tried to unproject agg expr not found in provided Aggregate!" + "Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name ) } } else { @@ -82,3 +185,263 @@ pub(crate) fn unproject_agg_exprs(expr: &Expr, agg: &Aggregate) -> Result }) .map(|e| e.data) } + +/// Recursively identify all Column expressions and transform them into the appropriate +/// window expression contained in window. +/// +/// For example, if expr contains the column expr "COUNT(*) PARTITION BY id" it will be transformed +/// into an actual window expression as identified in the window node. +pub(crate) fn unproject_window_exprs(expr: Expr, windows: &[&Window]) -> Result { + expr.transform(|sub_expr| { + if let Expr::Column(c) = sub_expr { + if let Some(unproj) = find_window_expr(windows, &c.name) { + Ok(Transformed::yes(unproj.clone())) + } else { + Ok(Transformed::no(Expr::Column(c))) + } + } else { + Ok(Transformed::no(sub_expr)) + } + }) + .map(|e| e.data) +} + +fn find_agg_expr<'a>(agg: &'a Aggregate, column: &Column) -> Result> { + if let Ok(index) = agg.schema.index_of_column(column) { + if matches!(agg.group_expr.as_slice(), [Expr::GroupingSet(_)]) { + // For grouping set expr, we must operate by expression list from the grouping set + let grouping_expr = grouping_set_to_exprlist(agg.group_expr.as_slice())?; + match index.cmp(&grouping_expr.len()) { + Ordering::Less => Ok(grouping_expr.into_iter().nth(index)), + Ordering::Equal => { + internal_err!( + "Tried to unproject column referring to internal grouping id" + ) + } + Ordering::Greater => { + Ok(agg.aggr_expr.get(index - grouping_expr.len() - 1)) + } + } + } else { + Ok(agg.group_expr.iter().chain(agg.aggr_expr.iter()).nth(index)) + } + } else { + Ok(None) + } +} + +fn find_window_expr<'a>( + windows: &'a [&'a Window], + column_name: &'a str, +) -> Option<&'a Expr> { + windows + .iter() + .flat_map(|w| w.window_expr.iter()) + .find(|expr| expr.schema_name().to_string() == column_name) +} + +/// Transforms a Column expression into the actual expression from aggregation or projection if found. +/// This is required because if an ORDER BY expression is present in an Aggregate or Select, it is replaced +/// with a Column expression (e.g., "sum(catalog_returns.cr_net_loss)"). We need to transform it back to +/// the actual expression, such as sum("catalog_returns"."cr_net_loss"). +pub(crate) fn unproject_sort_expr( + sort_expr: &SortExpr, + agg: Option<&Aggregate>, + input: &LogicalPlan, +) -> Result { + let mut sort_expr = sort_expr.clone(); + + // Remove alias if present, because ORDER BY cannot use aliases + if let Expr::Alias(alias) = &sort_expr.expr { + sort_expr.expr = *alias.expr.clone(); + } + + let Expr::Column(ref col_ref) = sort_expr.expr else { + return Ok(sort_expr); + }; + + if col_ref.relation.is_some() { + return Ok(sort_expr); + }; + + // In case of aggregation there could be columns containing aggregation functions we need to unproject + if let Some(agg) = agg { + if agg.schema.is_column_from_schema(col_ref) { + let new_expr = unproject_agg_exprs(sort_expr.expr, agg, None)?; + sort_expr.expr = new_expr; + return Ok(sort_expr); + } + } + + // If SELECT and ORDER BY contain the same expression with a scalar function, the ORDER BY expression will + // be replaced by a Column expression (e.g., "substr(customer.c_last_name, Int64(0), Int64(5))"), and we need + // to transform it back to the actual expression. + if let LogicalPlan::Projection(Projection { expr, schema, .. }) = input { + if let Ok(idx) = schema.index_of_column(col_ref) { + if let Some(Expr::ScalarFunction(scalar_fn)) = expr.get(idx) { + sort_expr.expr = Expr::ScalarFunction(scalar_fn.clone()); + } + } + return Ok(sort_expr); + } + + Ok(sort_expr) +} + +/// Iterates through the children of a [LogicalPlan] to find a TableScan node before encountering +/// a Projection or any unexpected node that indicates the presence of a Projection (SELECT) in the plan. +/// If a TableScan node is found, returns the TableScan node without filters, along with the collected filters separately. +/// If the plan contains a Projection, returns None. +/// +/// Note: If a table alias is present, TableScan filters are rewritten to reference the alias. +/// +/// LogicalPlan example: +/// Filter: ta.j1_id < 5 +/// Alias: ta +/// TableScan: j1, j1_id > 10 +/// +/// Will return LogicalPlan below: +/// Alias: ta +/// TableScan: j1 +/// And filters: [ta.j1_id < 5, ta.j1_id > 10] +pub(crate) fn try_transform_to_simple_table_scan_with_filters( + plan: &LogicalPlan, +) -> Result)>> { + let mut filters: Vec = vec![]; + let mut plan_stack = vec![plan]; + let mut table_alias = None; + + while let Some(current_plan) = plan_stack.pop() { + match current_plan { + LogicalPlan::SubqueryAlias(alias) => { + table_alias = Some(alias.alias.clone()); + plan_stack.push(alias.input.as_ref()); + } + LogicalPlan::Filter(filter) => { + filters.push(filter.predicate.clone()); + plan_stack.push(filter.input.as_ref()); + } + LogicalPlan::TableScan(table_scan) => { + let table_schema = table_scan.source.schema(); + // optional rewriter if table has an alias + let mut filter_alias_rewriter = + table_alias.as_ref().map(|alias_name| TableAliasRewriter { + table_schema: &table_schema, + alias_name: alias_name.clone(), + }); + + // rewrite filters to use table alias if present + let table_scan_filters = table_scan + .filters + .iter() + .cloned() + .map(|expr| { + if let Some(ref mut rewriter) = filter_alias_rewriter { + expr.rewrite(rewriter).data() + } else { + Ok(expr) + } + }) + .collect::, DataFusionError>>()?; + + filters.extend(table_scan_filters); + + let mut builder = LogicalPlanBuilder::scan( + table_scan.table_name.clone(), + Arc::clone(&table_scan.source), + None, + )?; + + if let Some(alias) = table_alias.take() { + builder = builder.alias(alias)?; + } + + let plan = builder.build()?; + + return Ok(Some((plan, filters))); + } + _ => { + return Ok(None); + } + } + } + + Ok(None) +} + +/// Converts a date_part function to SQL, tailoring it to the supported date field extraction style. +pub(crate) fn date_part_to_sql( + unparser: &Unparser, + style: DateFieldExtractStyle, + date_part_args: &[Expr], +) -> Result> { + match (style, date_part_args.len()) { + (DateFieldExtractStyle::Extract, 2) => { + let date_expr = unparser.expr_to_sql(&date_part_args[1])?; + if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &date_part_args[0] { + let field = match field.to_lowercase().as_str() { + "year" => ast::DateTimeField::Year, + "month" => ast::DateTimeField::Month, + "day" => ast::DateTimeField::Day, + "hour" => ast::DateTimeField::Hour, + "minute" => ast::DateTimeField::Minute, + "second" => ast::DateTimeField::Second, + _ => return Ok(None), + }; + + return Ok(Some(ast::Expr::Extract { + field, + expr: Box::new(date_expr), + syntax: ast::ExtractSyntax::From, + })); + } + } + (DateFieldExtractStyle::Strftime, 2) => { + let column = unparser.expr_to_sql(&date_part_args[1])?; + + if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &date_part_args[0] { + let field = match field.to_lowercase().as_str() { + "year" => "%Y", + "month" => "%m", + "day" => "%d", + "hour" => "%H", + "minute" => "%M", + "second" => "%S", + _ => return Ok(None), + }; + + return Ok(Some(ast::Expr::Function(ast::Function { + name: ast::ObjectName(vec![ast::Ident { + value: "strftime".to_string(), + quote_style: None, + }]), + args: ast::FunctionArguments::List(ast::FunctionArgumentList { + duplicate_treatment: None, + args: vec![ + ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr( + ast::Expr::Value(ast::Value::SingleQuotedString( + field.to_string(), + )), + )), + ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(column)), + ], + clauses: vec![], + }), + filter: None, + null_treatment: None, + over: None, + within_group: vec![], + parameters: ast::FunctionArguments::None, + }))); + } + } + (DateFieldExtractStyle::DatePart, _) => { + return Ok(Some( + unparser.scalar_function_to_sql("date_part", date_part_args)?, + )); + } + _ => {} + }; + + Ok(None) +} diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 2c50d3af1f5e..14436de01843 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -18,18 +18,26 @@ //! SQL Utility Functions use std::collections::HashMap; +use std::vec; use arrow_schema::{ DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE, }; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, +}; use datafusion_common::{ - exec_err, internal_err, plan_err, Column, DataFusionError, Result, ScalarValue, + exec_err, internal_err, plan_err, Column, DFSchemaRef, DataFusionError, Result, + ScalarValue, }; -use datafusion_expr::expr::{Alias, GroupingSet, WindowFunction}; +use datafusion_expr::builder::get_struct_unnested_columns; +use datafusion_expr::expr::{Alias, GroupingSet, Unnest, WindowFunction}; use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs}; -use datafusion_expr::{expr_vec_fmt, Expr, LogicalPlan}; -use sqlparser::ast::Ident; +use datafusion_expr::{ + col, expr_vec_fmt, ColumnUnnestList, Expr, ExprSchemable, LogicalPlan, +}; +use indexmap::IndexMap; +use sqlparser::ast::{Ident, Value}; /// Make a best-effort attempt at resolving all columns in the expression tree pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result { @@ -148,50 +156,54 @@ pub(crate) fn extract_aliases(exprs: &[Expr]) -> HashMap { } /// Given an expression that's literal int encoding position, lookup the corresponding expression -/// in the select_exprs list, if the index is within the bounds and it is indeed a position literal; -/// Otherwise, return None +/// in the select_exprs list, if the index is within the bounds and it is indeed a position literal, +/// otherwise, returns planning error. +/// If input expression is not an int literal, returns expression as-is. pub(crate) fn resolve_positions_to_exprs( - expr: &Expr, + expr: Expr, select_exprs: &[Expr], -) -> Option { +) -> Result { match expr { // sql_expr_to_logical_expr maps number to i64 // https://github.com/apache/datafusion/blob/8d175c759e17190980f270b5894348dc4cff9bbf/datafusion/src/sql/planner.rs#L882-L887 Expr::Literal(ScalarValue::Int64(Some(position))) - if position > &0_i64 && position <= &(select_exprs.len() as i64) => + if position > 0_i64 && position <= select_exprs.len() as i64 => { let index = (position - 1) as usize; let select_expr = &select_exprs[index]; - Some(match select_expr { + Ok(match select_expr { Expr::Alias(Alias { expr, .. }) => *expr.clone(), _ => select_expr.clone(), }) } - _ => None, + Expr::Literal(ScalarValue::Int64(Some(position))) => plan_err!( + "Cannot find column with position {} in SELECT clause. Valid columns: 1 to {}", + position, select_exprs.len() + ), + _ => Ok(expr), } } /// Rebuilds an `Expr` with columns that refer to aliases replaced by the /// alias' underlying `Expr`. pub(crate) fn resolve_aliases_to_exprs( - expr: &Expr, + expr: Expr, aliases: &HashMap, ) -> Result { - expr.clone() - .transform_up(|nested_expr| match nested_expr { - Expr::Column(c) if c.relation.is_none() => { - if let Some(aliased_expr) = aliases.get(&c.name) { - Ok(Transformed::yes(aliased_expr.clone())) - } else { - Ok(Transformed::no(Expr::Column(c))) - } + expr.transform_up(|nested_expr| match nested_expr { + Expr::Column(c) if c.relation.is_none() => { + if let Some(aliased_expr) = aliases.get(&c.name) { + Ok(Transformed::yes(aliased_expr.clone())) + } else { + Ok(Transformed::no(Expr::Column(c))) } - _ => Ok(Transformed::no(nested_expr)), - }) - .data() + } + _ => Ok(Transformed::no(nested_expr)), + }) + .data() } -/// given a slice of window expressions sharing the same sort key, find their common partition +/// Given a slice of window expressions sharing the same sort key, find their common partition /// keys. pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr]> { let all_partition_keys = window_exprs @@ -255,3 +267,726 @@ pub(crate) fn normalize_ident(id: Ident) -> String { None => id.value.to_ascii_lowercase(), } } + +pub(crate) fn value_to_string(value: &Value) -> Option { + match value { + Value::SingleQuotedString(s) => Some(s.to_string()), + Value::DollarQuotedString(s) => Some(s.to_string()), + Value::Number(_, _) | Value::Boolean(_) => Some(value.to_string()), + Value::UnicodeStringLiteral(s) => Some(s.to_string()), + Value::EscapedStringLiteral(s) => Some(s.to_string()), + Value::DoubleQuotedString(_) + | Value::NationalStringLiteral(_) + | Value::SingleQuotedByteStringLiteral(_) + | Value::DoubleQuotedByteStringLiteral(_) + | Value::TripleSingleQuotedString(_) + | Value::TripleDoubleQuotedString(_) + | Value::TripleSingleQuotedByteStringLiteral(_) + | Value::TripleDoubleQuotedByteStringLiteral(_) + | Value::SingleQuotedRawStringLiteral(_) + | Value::DoubleQuotedRawStringLiteral(_) + | Value::TripleSingleQuotedRawStringLiteral(_) + | Value::TripleDoubleQuotedRawStringLiteral(_) + | Value::HexStringLiteral(_) + | Value::Null + | Value::Placeholder(_) => None, + } +} + +pub(crate) fn rewrite_recursive_unnests_bottom_up( + input: &LogicalPlan, + unnest_placeholder_columns: &mut IndexMap>>, + inner_projection_exprs: &mut Vec, + original_exprs: &[Expr], +) -> Result> { + Ok(original_exprs + .iter() + .map(|expr| { + rewrite_recursive_unnest_bottom_up( + input, + unnest_placeholder_columns, + inner_projection_exprs, + expr, + ) + }) + .collect::>>()? + .into_iter() + .flatten() + .collect::>()) +} + +/* +This is only usedful when used with transform down up +A full example of how the transformation works: + */ +struct RecursiveUnnestRewriter<'a> { + input_schema: &'a DFSchemaRef, + root_expr: &'a Expr, + // Useful to detect which child expr is a part of/ not a part of unnest operation + top_most_unnest: Option, + consecutive_unnest: Vec>, + inner_projection_exprs: &'a mut Vec, + columns_unnestings: &'a mut IndexMap>>, + transformed_root_exprs: Option>, +} +impl<'a> RecursiveUnnestRewriter<'a> { + /// This struct stores the history of expr + /// during its tree-traversal with a notation of + /// \[None,**Unnest(exprA)**,**Unnest(exprB)**,None,None\] + /// then this function will returns \[**Unnest(exprA)**,**Unnest(exprB)**\] + /// + /// The first item will be the inner most expr + fn get_latest_consecutive_unnest(&self) -> Vec { + self.consecutive_unnest + .iter() + .rev() + .skip_while(|item| item.is_none()) + .take_while(|item| item.is_some()) + .to_owned() + .cloned() + .map(|item| item.unwrap()) + .collect() + } + + fn transform( + &mut self, + level: usize, + alias_name: String, + expr_in_unnest: &Expr, + struct_allowed: bool, + ) -> Result> { + let inner_expr_name = expr_in_unnest.schema_name().to_string(); + + // Full context, we are trying to plan the execution as InnerProjection->Unnest->OuterProjection + // inside unnest execution, each column inside the inner projection + // will be transformed into new columns. Thus we need to keep track of these placeholding column names + let placeholder_name = format!("unnest_placeholder({})", inner_expr_name); + let post_unnest_name = + format!("unnest_placeholder({},depth={})", inner_expr_name, level); + // This is due to the fact that unnest transformation should keep the original + // column name as is, to comply with group by and order by + let placeholder_column = Column::from_name(placeholder_name.clone()); + + let (data_type, _) = expr_in_unnest.data_type_and_nullable(self.input_schema)?; + + match data_type { + DataType::Struct(inner_fields) => { + if !struct_allowed { + return internal_err!("unnest on struct can only be applied at the root level of select expression"); + } + push_projection_dedupl( + self.inner_projection_exprs, + expr_in_unnest.clone().alias(placeholder_name.clone()), + ); + self.columns_unnestings + .insert(Column::from_name(placeholder_name.clone()), None); + Ok( + get_struct_unnested_columns(&placeholder_name, &inner_fields) + .into_iter() + .map(Expr::Column) + .collect(), + ) + } + DataType::List(_) + | DataType::FixedSizeList(_, _) + | DataType::LargeList(_) => { + push_projection_dedupl( + self.inner_projection_exprs, + expr_in_unnest.clone().alias(placeholder_name.clone()), + ); + + let post_unnest_expr = col(post_unnest_name.clone()).alias(alias_name); + let list_unnesting = self + .columns_unnestings + .entry(placeholder_column) + .or_insert(Some(vec![])); + let unnesting = ColumnUnnestList { + output_column: Column::from_name(post_unnest_name), + depth: level, + }; + let list_unnestings = list_unnesting.as_mut().unwrap(); + if !list_unnestings.contains(&unnesting) { + list_unnestings.push(unnesting); + } + Ok(vec![post_unnest_expr]) + } + _ => { + internal_err!("unnest on non-list or struct type is not supported") + } + } + } +} + +impl<'a> TreeNodeRewriter for RecursiveUnnestRewriter<'a> { + type Node = Expr; + + /// This downward traversal needs to keep track of: + /// - Whether or not some unnest expr has been visited from the top util the current node + /// - If some unnest expr has been visited, maintain a stack of such information, this + /// is used to detect if some recursive unnest expr exists (e.g **unnest(unnest(unnest(3d column))))** + fn f_down(&mut self, expr: Expr) -> Result> { + if let Expr::Unnest(ref unnest_expr) = expr { + let (data_type, _) = + unnest_expr.expr.data_type_and_nullable(self.input_schema)?; + self.consecutive_unnest.push(Some(unnest_expr.clone())); + // if expr inside unnest is a struct, do not consider + // the next unnest as consecutive unnest (if any) + // meaning unnest(unnest(struct_arr_col)) can't + // be interpreted as unest(struct_arr_col, depth:=2) + // but has to be split into multiple unnest logical plan instead + // a.k.a: + // - unnest(struct_col) + // unnest(struct_arr_col) as struct_col + + if let DataType::Struct(_) = data_type { + self.consecutive_unnest.push(None); + } + if self.top_most_unnest.is_none() { + self.top_most_unnest = Some(unnest_expr.clone()); + } + + Ok(Transformed::no(expr)) + } else { + self.consecutive_unnest.push(None); + Ok(Transformed::no(expr)) + } + } + + /// The rewriting only happens when the traversal has reached the top-most unnest expr + /// within a sequence of consecutive unnest exprs node + /// + /// For example an expr of **unnest(unnest(column1)) + unnest(unnest(unnest(column2)))** + /// ```text + /// ┌──────────────────┐ + /// │ binaryexpr │ + /// │ │ + /// └──────────────────┘ + /// f_down / / │ │ + /// / / f_up │ │ + /// / / f_down│ │f_up + /// unnest │ │ + /// │ │ + /// f_down / / f_up(rewriting) │ │ + /// / / + /// / / unnest + /// unnest + /// f_down / / f_up(rewriting) + /// f_down / /f_up / / + /// / / / / + /// / / unnest + /// column1 + /// f_down / /f_up + /// / / + /// / / + /// column2 + /// ``` + /// + fn f_up(&mut self, expr: Expr) -> Result> { + if let Expr::Unnest(ref traversing_unnest) = expr { + if traversing_unnest == self.top_most_unnest.as_ref().unwrap() { + self.top_most_unnest = None; + } + // Find inside consecutive_unnest, the sequence of continous unnest exprs + + // Get the latest consecutive unnest exprs + // and check if current upward traversal is the returning to the root expr + // for example given a expr `unnest(unnest(col))` then the traversal happens like: + // down(unnest) -> down(unnest) -> down(col) -> up(col) -> up(unnest) -> up(unnest) + // the result of such traversal is unnest(col, depth:=2) + let unnest_stack = self.get_latest_consecutive_unnest(); + + // This traversal has reached the top most unnest again + // e.g Unnest(top) -> Unnest(2nd) -> Column(bottom) + // -> Unnest(2nd) -> Unnest(top) a.k.a here + // Thus + // Unnest(Unnest(some_col)) is rewritten into Unnest(some_col, depth:=2) + if traversing_unnest == unnest_stack.last().unwrap() { + let most_inner = unnest_stack.first().unwrap(); + let inner_expr = most_inner.expr.as_ref(); + // unnest(unnest(struct_arr_col)) is not allow to be done recursively + // it needs to be splitted into multiple unnest logical plan + // unnest(struct_arr) + // unnest(struct_arr_col) as struct_arr + // instead of unnest(struct_arr_col, depth = 2) + + let unnest_recursion = unnest_stack.len(); + let struct_allowed = (&expr == self.root_expr) && unnest_recursion == 1; + + let mut transformed_exprs = self.transform( + unnest_recursion, + expr.schema_name().to_string(), + inner_expr, + struct_allowed, + )?; + if struct_allowed { + self.transformed_root_exprs = Some(transformed_exprs.clone()); + } + return Ok(Transformed::new( + transformed_exprs.swap_remove(0), + true, + TreeNodeRecursion::Continue, + )); + } + } else { + self.consecutive_unnest.push(None); + } + + // For column exprs that are not descendants of any unnest node + // retain their projection + // e.g given expr tree unnest(col_a) + col_b, we have to retain projection of col_b + // this condition can be checked by maintaining an Option + if matches!(&expr, Expr::Column(_)) && self.top_most_unnest.is_none() { + push_projection_dedupl(self.inner_projection_exprs, expr.clone()); + } + + Ok(Transformed::no(expr)) + } +} + +fn push_projection_dedupl(projection: &mut Vec, expr: Expr) { + let schema_name = expr.schema_name().to_string(); + if !projection + .iter() + .any(|e| e.schema_name().to_string() == schema_name) + { + projection.push(expr); + } +} +/// The context is we want to rewrite unnest() into InnerProjection->Unnest->OuterProjection +/// Given an expression which contains unnest expr as one of its children, +/// Try transform depends on unnest type +/// - For list column: unnest(col) with type list -> unnest(col) with type list::item +/// - For struct column: unnest(struct(field1, field2)) -> unnest(struct).field1, unnest(struct).field2 +/// +/// The transformed exprs will be used in the outer projection +/// If along the path from root to bottom, there are multiple unnest expressions, the transformation +/// is done only for the bottom expression +pub(crate) fn rewrite_recursive_unnest_bottom_up( + input: &LogicalPlan, + unnest_placeholder_columns: &mut IndexMap>>, + inner_projection_exprs: &mut Vec, + original_expr: &Expr, +) -> Result> { + let mut rewriter = RecursiveUnnestRewriter { + input_schema: input.schema(), + root_expr: original_expr, + top_most_unnest: None, + consecutive_unnest: vec![], + inner_projection_exprs, + columns_unnestings: unnest_placeholder_columns, + transformed_root_exprs: None, + }; + + // This transformation is only done for list unnest + // struct unnest is done at the root level, and at the later stage + // because the syntax of TreeNode only support transform into 1 Expr, while + // Unnest struct will be transformed into multiple Exprs + // TODO: This can be resolved after this issue is resolved: https://github.com/apache/datafusion/issues/10102 + // + // The transformation looks like: + // - unnest(array_col) will be transformed into Column("unnest_place_holder(array_col)") + // - unnest(array_col) + 1 will be transformed into Column("unnest_place_holder(array_col) + 1") + let Transformed { + data: transformed_expr, + transformed, + tnr: _, + } = original_expr.clone().rewrite(&mut rewriter)?; + + if !transformed { + if matches!(&transformed_expr, Expr::Column(_)) + || matches!(&transformed_expr, Expr::Wildcard { .. }) + { + push_projection_dedupl(inner_projection_exprs, transformed_expr.clone()); + Ok(vec![transformed_expr]) + } else { + // We need to evaluate the expr in the inner projection, + // outer projection just select its name + let column_name = transformed_expr.schema_name().to_string(); + push_projection_dedupl(inner_projection_exprs, transformed_expr); + Ok(vec![Expr::Column(Column::from_name(column_name))]) + } + } else { + if let Some(transformed_root_exprs) = rewriter.transformed_root_exprs { + return Ok(transformed_root_exprs); + } + Ok(vec![transformed_expr]) + } +} + +#[cfg(test)] +mod tests { + use std::{ops::Add, sync::Arc}; + + use arrow::datatypes::{DataType as ArrowDataType, Field, Schema}; + use arrow_schema::Fields; + use datafusion_common::{Column, DFSchema, Result}; + use datafusion_expr::{ + col, lit, unnest, ColumnUnnestList, EmptyRelation, LogicalPlan, + }; + use datafusion_functions::core::expr_ext::FieldAccessor; + use datafusion_functions_aggregate::expr_fn::count; + use indexmap::IndexMap; + + use crate::utils::{resolve_positions_to_exprs, rewrite_recursive_unnest_bottom_up}; + + fn column_unnests_eq( + l: Vec<&str>, + r: &IndexMap>>, + ) { + let r_formatted: Vec = r + .iter() + .map(|i| match i.1 { + None => format!("{}", i.0), + Some(vec) => format!( + "{}=>[{}]", + i.0, + vec.iter() + .map(|i| format!("{}", i)) + .collect::>() + .join(", ") + ), + }) + .collect(); + let l_formatted: Vec = l.iter().map(|i| i.to_string()).collect(); + assert_eq!(l_formatted, r_formatted); + } + + #[test] + fn test_transform_bottom_unnest_recursive() -> Result<()> { + let schema = Schema::new(vec![ + Field::new( + "3d_col", + ArrowDataType::List(Arc::new(Field::new( + "2d_col", + ArrowDataType::List(Arc::new(Field::new( + "elements", + ArrowDataType::Int64, + true, + ))), + true, + ))), + true, + ), + Field::new("i64_col", ArrowDataType::Int64, true), + ]); + + let dfschema = DFSchema::try_from(schema)?; + + let input = LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::new(dfschema), + }); + + let mut unnest_placeholder_columns = IndexMap::new(); + let mut inner_projection_exprs = vec![]; + + // unnest(unnest(3d_col)) + unnest(unnest(3d_col)) + let original_expr = unnest(unnest(col("3d_col"))) + .add(unnest(unnest(col("3d_col")))) + .add(col("i64_col")); + let transformed_exprs = rewrite_recursive_unnest_bottom_up( + &input, + &mut unnest_placeholder_columns, + &mut inner_projection_exprs, + &original_expr, + )?; + // Only the bottom most unnest exprs are transformed + assert_eq!( + transformed_exprs, + vec![col("unnest_placeholder(3d_col,depth=2)") + .alias("UNNEST(UNNEST(3d_col))") + .add( + col("unnest_placeholder(3d_col,depth=2)") + .alias("UNNEST(UNNEST(3d_col))") + ) + .add(col("i64_col"))] + ); + column_unnests_eq( + vec![ + "unnest_placeholder(3d_col)=>[unnest_placeholder(3d_col,depth=2)|depth=2]", + ], + &unnest_placeholder_columns, + ); + + // Still reference struct_col in original schema but with alias, + // to avoid colliding with the projection on the column itself if any + assert_eq!( + inner_projection_exprs, + vec![ + col("3d_col").alias("unnest_placeholder(3d_col)"), + col("i64_col") + ] + ); + + // unnest(3d_col) as 2d_col + let original_expr_2 = unnest(col("3d_col")).alias("2d_col"); + let transformed_exprs = rewrite_recursive_unnest_bottom_up( + &input, + &mut unnest_placeholder_columns, + &mut inner_projection_exprs, + &original_expr_2, + )?; + + assert_eq!( + transformed_exprs, + vec![ + (col("unnest_placeholder(3d_col,depth=1)").alias("UNNEST(3d_col)")) + .alias("2d_col") + ] + ); + column_unnests_eq( + vec!["unnest_placeholder(3d_col)=>[unnest_placeholder(3d_col,depth=2)|depth=2, unnest_placeholder(3d_col,depth=1)|depth=1]"], + &unnest_placeholder_columns, + ); + // Still reference struct_col in original schema but with alias, + // to avoid colliding with the projection on the column itself if any + assert_eq!( + inner_projection_exprs, + vec![ + col("3d_col").alias("unnest_placeholder(3d_col)"), + col("i64_col") + ] + ); + + Ok(()) + } + + #[test] + fn test_transform_bottom_unnest() -> Result<()> { + let schema = Schema::new(vec![ + Field::new( + "struct_col", + ArrowDataType::Struct(Fields::from(vec![ + Field::new("field1", ArrowDataType::Int32, false), + Field::new("field2", ArrowDataType::Int32, false), + ])), + false, + ), + Field::new( + "array_col", + ArrowDataType::List(Arc::new(Field::new( + "item", + ArrowDataType::Int64, + true, + ))), + true, + ), + Field::new("int_col", ArrowDataType::Int32, false), + ]); + + let dfschema = DFSchema::try_from(schema)?; + + let input = LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::new(dfschema), + }); + + let mut unnest_placeholder_columns = IndexMap::new(); + let mut inner_projection_exprs = vec![]; + + // unnest(struct_col) + let original_expr = unnest(col("struct_col")); + let transformed_exprs = rewrite_recursive_unnest_bottom_up( + &input, + &mut unnest_placeholder_columns, + &mut inner_projection_exprs, + &original_expr, + )?; + assert_eq!( + transformed_exprs, + vec![ + col("unnest_placeholder(struct_col).field1"), + col("unnest_placeholder(struct_col).field2"), + ] + ); + column_unnests_eq( + vec!["unnest_placeholder(struct_col)"], + &unnest_placeholder_columns, + ); + // Still reference struct_col in original schema but with alias, + // to avoid colliding with the projection on the column itself if any + assert_eq!( + inner_projection_exprs, + vec![col("struct_col").alias("unnest_placeholder(struct_col)"),] + ); + + // unnest(array_col) + 1 + let original_expr = unnest(col("array_col")).add(lit(1i64)); + let transformed_exprs = rewrite_recursive_unnest_bottom_up( + &input, + &mut unnest_placeholder_columns, + &mut inner_projection_exprs, + &original_expr, + )?; + column_unnests_eq( + vec![ + "unnest_placeholder(struct_col)", + "unnest_placeholder(array_col)=>[unnest_placeholder(array_col,depth=1)|depth=1]", + ], + &unnest_placeholder_columns, + ); + // Only transform the unnest children + assert_eq!( + transformed_exprs, + vec![col("unnest_placeholder(array_col,depth=1)") + .alias("UNNEST(array_col)") + .add(lit(1i64))] + ); + + // Keep appending to the current vector + // Still reference array_col in original schema but with alias, + // to avoid colliding with the projection on the column itself if any + assert_eq!( + inner_projection_exprs, + vec![ + col("struct_col").alias("unnest_placeholder(struct_col)"), + col("array_col").alias("unnest_placeholder(array_col)") + ] + ); + + Ok(()) + } + + // Unnest -> field access -> unnest + #[test] + fn test_transform_non_consecutive_unnests() -> Result<()> { + // List of struct + // [struct{'subfield1':list(i64), 'subfield2':list(utf8)}] + let schema = Schema::new(vec![ + Field::new( + "struct_list", + ArrowDataType::List(Arc::new(Field::new( + "element", + ArrowDataType::Struct(Fields::from(vec![ + Field::new( + // list of i64 + "subfield1", + ArrowDataType::List(Arc::new(Field::new( + "i64_element", + ArrowDataType::Int64, + true, + ))), + true, + ), + Field::new( + // list of utf8 + "subfield2", + ArrowDataType::List(Arc::new(Field::new( + "utf8_element", + ArrowDataType::Utf8, + true, + ))), + true, + ), + ])), + true, + ))), + true, + ), + Field::new("int_col", ArrowDataType::Int32, false), + ]); + + let dfschema = DFSchema::try_from(schema)?; + + let input = LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::new(dfschema), + }); + + let mut unnest_placeholder_columns = IndexMap::new(); + let mut inner_projection_exprs = vec![]; + + // An expr with multiple unnest + let select_expr1 = unnest(unnest(col("struct_list")).field("subfield1")); + let transformed_exprs = rewrite_recursive_unnest_bottom_up( + &input, + &mut unnest_placeholder_columns, + &mut inner_projection_exprs, + &select_expr1, + )?; + // Only the inner most/ bottom most unnest is transformed + assert_eq!( + transformed_exprs, + vec![unnest( + col("unnest_placeholder(struct_list,depth=1)") + .alias("UNNEST(struct_list)") + .field("subfield1") + )] + ); + + column_unnests_eq( + vec![ + "unnest_placeholder(struct_list)=>[unnest_placeholder(struct_list,depth=1)|depth=1]", + ], + &unnest_placeholder_columns, + ); + + assert_eq!( + inner_projection_exprs, + vec![col("struct_list").alias("unnest_placeholder(struct_list)")] + ); + + // continue rewrite another expr in select + let select_expr2 = unnest(unnest(col("struct_list")).field("subfield2")); + let transformed_exprs = rewrite_recursive_unnest_bottom_up( + &input, + &mut unnest_placeholder_columns, + &mut inner_projection_exprs, + &select_expr2, + )?; + // Only the inner most/ bottom most unnest is transformed + assert_eq!( + transformed_exprs, + vec![unnest( + col("unnest_placeholder(struct_list,depth=1)") + .alias("UNNEST(struct_list)") + .field("subfield2") + )] + ); + + // unnest place holder columns remain the same + // because expr1 and expr2 derive from the same unnest result + column_unnests_eq( + vec![ + "unnest_placeholder(struct_list)=>[unnest_placeholder(struct_list,depth=1)|depth=1]", + ], + &unnest_placeholder_columns, + ); + + assert_eq!( + inner_projection_exprs, + vec![col("struct_list").alias("unnest_placeholder(struct_list)")] + ); + + Ok(()) + } + + #[test] + fn test_resolve_positions_to_exprs() -> Result<()> { + let select_exprs = vec![col("c1"), col("c2"), count(lit(1))]; + + // Assert 1 resolved as first column in select list + let resolved = resolve_positions_to_exprs(lit(1i64), &select_exprs)?; + assert_eq!(resolved, col("c1")); + + // Assert error if index out of select clause bounds + let resolved = resolve_positions_to_exprs(lit(-1i64), &select_exprs); + assert!(resolved.is_err_and(|e| e.message().contains( + "Cannot find column with position -1 in SELECT clause. Valid columns: 1 to 3" + ))); + + let resolved = resolve_positions_to_exprs(lit(5i64), &select_exprs); + assert!(resolved.is_err_and(|e| e.message().contains( + "Cannot find column with position 5 in SELECT clause. Valid columns: 1 to 3" + ))); + + // Assert expression returned as-is + let resolved = resolve_positions_to_exprs(lit("text"), &select_exprs)?; + assert_eq!(resolved, lit("text")); + + let resolved = resolve_positions_to_exprs(col("fake"), &select_exprs)?; + assert_eq!(resolved, col("fake")); + + Ok(()) + } +} diff --git a/datafusion/sql/src/values.rs b/datafusion/sql/src/values.rs index 9efb75bd60e4..a4001bea7dea 100644 --- a/datafusion/sql/src/values.rs +++ b/datafusion/sql/src/values.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{DFSchema, Result}; use datafusion_expr::{LogicalPlan, LogicalPlanBuilder}; @@ -31,16 +33,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { rows, } = values; - // values should not be based on any other schema - let schema = DFSchema::empty(); + let empty_schema = Arc::new(DFSchema::empty()); let values = rows .into_iter() .map(|row| { row.into_iter() - .map(|v| self.sql_to_expr(v, &schema, planner_context)) + .map(|v| self.sql_to_expr(v, &empty_schema, planner_context)) .collect::>>() }) .collect::>>()?; - LogicalPlanBuilder::values(values)?.build() + + let schema = planner_context.table_schema().unwrap_or(empty_schema); + if schema.fields().is_empty() { + LogicalPlanBuilder::values(values)?.build() + } else { + LogicalPlanBuilder::values_with_schema(values, &schema)?.build() + } } } diff --git a/datafusion/physical-expr-common/src/expressions/mod.rs b/datafusion/sql/tests/cases/mod.rs similarity index 97% rename from datafusion/physical-expr-common/src/expressions/mod.rs rename to datafusion/sql/tests/cases/mod.rs index d102422081dc..fc4c59cc88d8 100644 --- a/datafusion/physical-expr-common/src/expressions/mod.rs +++ b/datafusion/sql/tests/cases/mod.rs @@ -15,4 +15,4 @@ // specific language governing permissions and limitations // under the License. -pub mod column; +mod plan_to_sql; diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs new file mode 100644 index 000000000000..ea0ccb8e4b43 --- /dev/null +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -0,0 +1,1190 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; +use std::vec; + +use arrow_schema::*; +use datafusion_common::{DFSchema, Result, TableReference}; +use datafusion_expr::test::function_stub::{count_udaf, max_udaf, min_udaf, sum_udaf}; +use datafusion_expr::{col, lit, table_scan, wildcard, LogicalPlanBuilder}; +use datafusion_functions::unicode; +use datafusion_functions_aggregate::grouping::grouping_udaf; +use datafusion_functions_nested::make_array::make_array_udf; +use datafusion_functions_window::rank::rank_udwf; +use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel}; +use datafusion_sql::unparser::dialect::{ + DefaultDialect as UnparserDefaultDialect, Dialect as UnparserDialect, + MySqlDialect as UnparserMySqlDialect, SqliteDialect, +}; +use datafusion_sql::unparser::{expr_to_sql, plan_to_sql, Unparser}; + +use crate::common::{MockContextProvider, MockSessionState}; +use datafusion_expr::builder::{ + table_scan_with_filter_and_fetch, table_scan_with_filters, +}; +use datafusion_functions::core::planner::CoreFunctionPlanner; +use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect}; +use sqlparser::parser::Parser; + +#[test] +fn roundtrip_expr() { + let tests: Vec<(TableReference, &str, &str)> = vec![ + (TableReference::bare("person"), "age > 35", r#"(age > 35)"#), + ( + TableReference::bare("person"), + "id = '10'", + r#"(id = '10')"#, + ), + ( + TableReference::bare("person"), + "CAST(id AS VARCHAR)", + r#"CAST(id AS VARCHAR)"#, + ), + ( + TableReference::bare("person"), + "sum((age * 2))", + r#"sum((age * 2))"#, + ), + ]; + + let roundtrip = |table, sql: &str| -> Result { + let dialect = GenericDialect {}; + let sql_expr = Parser::new(&dialect).try_with_sql(sql)?.parse_expr()?; + let state = MockSessionState::default().with_aggregate_function(sum_udaf()); + let context = MockContextProvider { state }; + let schema = context.get_table_source(table)?.schema(); + let df_schema = DFSchema::try_from(schema.as_ref().clone())?; + let sql_to_rel = SqlToRel::new(&context); + let expr = + sql_to_rel.sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new())?; + + let ast = expr_to_sql(&expr)?; + + Ok(ast.to_string()) + }; + + for (table, query, expected) in tests { + let actual = roundtrip(table, query).unwrap(); + assert_eq!(actual, expected); + } +} + +#[test] +fn roundtrip_statement() -> Result<()> { + let tests: Vec<&str> = vec![ + "select 1;", + "select 1 limit 0;", + "select ta.j1_id from j1 ta join (select 1 as j1_id) tb on ta.j1_id = tb.j1_id;", + "select ta.j1_id from j1 ta join (select 1 as j1_id) tb using (j1_id);", + "select ta.j1_id from j1 ta join (select 1 as j1_id) tb on ta.j1_id = tb.j1_id where ta.j1_id > 1;", + "select ta.j1_id from (select 1 as j1_id) ta;", + "select ta.j1_id from j1 ta;", + "select ta.j1_id from j1 ta order by ta.j1_id;", + "select * from j1 ta order by ta.j1_id, ta.j1_string desc;", + "select * from j1 limit 10;", + "select ta.j1_id from j1 ta where ta.j1_id > 1;", + "select ta.j1_id, tb.j2_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id);", + "select ta.j1_id, tb.j2_string, tc.j3_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id) join j3 tc on (ta.j1_id = tc.j3_id);", + "select * from (select id, first_name from person)", + "select * from (select id, first_name from (select * from person))", + "select id, count(*) as cnt from (select id from person) group by id", + "select (id-1)/2, count(*) / (sum(id/10)-1) as agg_expr from (select (id-1) as id from person) group by id", + "select CAST(id/2 as VARCHAR) NOT LIKE 'foo*' from person where NOT EXISTS (select ta.j1_id, tb.j2_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id))", + r#"select "First Name" from person_quoted_cols"#, + "select DISTINCT id FROM person", + "select DISTINCT on (id) id, first_name from person", + "select DISTINCT on (id) id, first_name from person order by id", + r#"select id, count("First Name") as cnt from (select id, "First Name" from person_quoted_cols) group by id"#, + "select id, count(*) as cnt from (select p1.id as id from person p1 inner join person p2 on p1.id=p2.id) group by id", + "select id, count(*), first_name from person group by first_name, id", + "select id, sum(age), first_name from person group by first_name, id", + "select id, count(*), first_name + from person + where id!=3 and first_name=='test' + group by first_name, id + having count(*)>5 and count(*)<10 + order by count(*)", + r#"select id, count("First Name") as count_first_name, "Last Name" + from person_quoted_cols + where id!=3 and "First Name"=='test' + group by "Last Name", id + having count_first_name>5 and count_first_name<10 + order by count_first_name, "Last Name""#, + r#"select p.id, count("First Name") as count_first_name, + "Last Name", sum(qp.id/p.id - (select sum(id) from person_quoted_cols) ) / (select count(*) from person) + from (select id, "First Name", "Last Name" from person_quoted_cols) qp + inner join (select * from person) p + on p.id = qp.id + where p.id!=3 and "First Name"=='test' and qp.id in + (select id from (select id, count(*) from person group by id having count(*) > 0)) + group by "Last Name", p.id + having count_first_name>5 and count_first_name<10 + order by count_first_name, "Last Name""#, + r#"SELECT j1_string as string FROM j1 + UNION ALL + SELECT j2_string as string FROM j2"#, + r#"SELECT j1_string as string FROM j1 + UNION ALL + SELECT j2_string as string FROM j2 + ORDER BY string DESC + LIMIT 10"#, + r#"SELECT col1, id FROM ( + SELECT j1_string AS col1, j1_id AS id FROM j1 + UNION ALL + SELECT j2_string AS col1, j2_id AS id FROM j2 + UNION ALL + SELECT j3_string AS col1, j3_id AS id FROM j3 + ) AS subquery GROUP BY col1, id ORDER BY col1 ASC, id ASC"#, + "SELECT id, count(*) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + last_name, sum(id) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + first_name from person", + r#"SELECT id, count(distinct id) over (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + sum(id) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) from person"#, + "SELECT id, sum(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) from person", + "WITH t1 AS (SELECT j1_id AS id, j1_string name FROM j1), t2 AS (SELECT j2_id AS id, j2_string name FROM j2) SELECT * FROM t1 JOIN t2 USING (id, name)", + "WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col), w3 as (SELECT 'c' as col) SELECT * FROM w1 UNION ALL SELECT * FROM w2 UNION ALL SELECT * FROM w3", + "WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col), w3 as (SELECT 'c' as col), w4 as (SELECT 'd' as col) SELECT * FROM w1 UNION ALL SELECT * FROM w2 UNION ALL SELECT * FROM w3 UNION ALL SELECT * FROM w4", + "WITH w1 AS (SELECT 'a' as col), w2 AS (SELECT 'b' as col) SELECT * FROM w1 JOIN w2 ON w1.col = w2.col UNION ALL SELECT * FROM w1 JOIN w2 ON w1.col = w2.col UNION ALL SELECT * FROM w1 JOIN w2 ON w1.col = w2.col", + r#"SELECT id, first_name, + SUM(id) AS total_sum, + SUM(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum, + MAX(SUM(id)) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total + FROM person JOIN orders ON person.id = orders.customer_id GROUP BY id, first_name"#, + r#"SELECT id, first_name, + SUM(id) AS total_sum, + SUM(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum, + MAX(SUM(id)) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total + FROM (SELECT id, first_name from person) person JOIN (SELECT customer_id FROM orders) orders ON person.id = orders.customer_id GROUP BY id, first_name"#, + r#"SELECT id, first_name, last_name, customer_id, SUM(id) AS total_sum + FROM person + JOIN orders ON person.id = orders.customer_id + GROUP BY ROLLUP(id, first_name, last_name, customer_id)"#, + r#"SELECT id, first_name, last_name, + SUM(id) AS total_sum, + COUNT(*) AS total_count, + SUM(id) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS running_total + FROM person + GROUP BY GROUPING SETS ((id, first_name, last_name), (first_name, last_name), (last_name))"#, + ]; + + // For each test sql string, we transform as follows: + // sql -> ast::Statement (s1) -> LogicalPlan (p1) -> ast::Statement (s2) -> LogicalPlan (p2) + // We test not that s1==s2, but rather p1==p2. This ensures that unparser preserves the logical + // query information of the original sql string and disreguards other differences in syntax or + // quoting. + for query in tests { + let dialect = GenericDialect {}; + let statement = Parser::new(&dialect) + .try_with_sql(query)? + .parse_statement()?; + let state = MockSessionState::default() + .with_aggregate_function(sum_udaf()) + .with_aggregate_function(count_udaf()) + .with_aggregate_function(max_udaf()) + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); + let context = MockContextProvider { state }; + let sql_to_rel = SqlToRel::new(&context); + let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); + + let roundtrip_statement = plan_to_sql(&plan)?; + + let actual = &roundtrip_statement.to_string(); + println!("roundtrip sql: {actual}"); + println!("plan {}", plan.display_indent()); + + let plan_roundtrip = sql_to_rel + .sql_statement_to_plan(roundtrip_statement.clone()) + .unwrap(); + + assert_eq!(plan, plan_roundtrip); + } + + Ok(()) +} + +#[test] +fn roundtrip_crossjoin() -> Result<()> { + let query = "select j1.j1_id, j2.j2_string from j1, j2"; + + let dialect = GenericDialect {}; + let statement = Parser::new(&dialect) + .try_with_sql(query)? + .parse_statement()?; + + let state = MockSessionState::default() + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); + + let context = MockContextProvider { state }; + let sql_to_rel = SqlToRel::new(&context); + let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); + + let roundtrip_statement = plan_to_sql(&plan)?; + + let actual = &roundtrip_statement.to_string(); + println!("roundtrip sql: {actual}"); + println!("plan {}", plan.display_indent()); + + let plan_roundtrip = sql_to_rel + .sql_statement_to_plan(roundtrip_statement) + .unwrap(); + + let expected = "Projection: j1.j1_id, j2.j2_string\ + \n Cross Join: \ + \n TableScan: j1\ + \n TableScan: j2"; + + assert_eq!(plan_roundtrip.to_string(), expected); + + Ok(()) +} + +#[test] +fn roundtrip_statement_with_dialect() -> Result<()> { + struct TestStatementWithDialect { + sql: &'static str, + expected: &'static str, + parser_dialect: Box, + unparser_dialect: Box, + } + let tests: Vec = vec![ + TestStatementWithDialect { + sql: "select min(ta.j1_id) as j1_min from j1 ta order by min(ta.j1_id) limit 10;", + expected: + // top projection sort gets derived into a subquery + // for MySQL, this subquery needs an alias + "SELECT `j1_min` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, min(`ta`.`j1_id`) FROM `j1` AS `ta` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10", + parser_dialect: Box::new(MySqlDialect {}), + unparser_dialect: Box::new(UnparserMySqlDialect {}), + }, + TestStatementWithDialect { + sql: "select min(ta.j1_id) as j1_min from j1 ta order by min(ta.j1_id) limit 10;", + expected: + // top projection sort still gets derived into a subquery in default dialect + // except for the default dialect, the subquery is left non-aliased + "SELECT j1_min FROM (SELECT min(ta.j1_id) AS j1_min, min(ta.j1_id) FROM j1 AS ta ORDER BY min(ta.j1_id) ASC NULLS LAST) LIMIT 10", + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "select min(ta.j1_id) as j1_min, max(tb.j1_max) from j1 ta, (select distinct max(ta.j1_id) as j1_max from j1 ta order by max(ta.j1_id)) tb order by min(ta.j1_id) limit 10;", + expected: + "SELECT `j1_min`, `max(tb.j1_max)` FROM (SELECT min(`ta`.`j1_id`) AS `j1_min`, max(`tb`.`j1_max`), min(`ta`.`j1_id`) FROM `j1` AS `ta` JOIN (SELECT `j1_max` FROM (SELECT DISTINCT max(`ta`.`j1_id`) AS `j1_max` FROM `j1` AS `ta`) AS `derived_distinct`) AS `tb` ORDER BY min(`ta`.`j1_id`) ASC) AS `derived_sort` LIMIT 10", + parser_dialect: Box::new(MySqlDialect {}), + unparser_dialect: Box::new(UnparserMySqlDialect {}), + }, + TestStatementWithDialect { + sql: "select j1_id from (select 1 as j1_id);", + expected: + "SELECT `j1_id` FROM (SELECT 1 AS `j1_id`) AS `derived_projection`", + parser_dialect: Box::new(MySqlDialect {}), + unparser_dialect: Box::new(UnparserMySqlDialect {}), + }, + TestStatementWithDialect { + sql: "select * from (select * from j1 limit 10);", + expected: + "SELECT * FROM (SELECT * FROM `j1` LIMIT 10) AS `derived_limit`", + parser_dialect: Box::new(MySqlDialect {}), + unparser_dialect: Box::new(UnparserMySqlDialect {}), + }, + TestStatementWithDialect { + sql: "select ta.j1_id from j1 ta order by j1_id limit 10;", + expected: + "SELECT `ta`.`j1_id` FROM `j1` AS `ta` ORDER BY `ta`.`j1_id` ASC LIMIT 10", + parser_dialect: Box::new(MySqlDialect {}), + unparser_dialect: Box::new(UnparserMySqlDialect {}), + }, + TestStatementWithDialect { + sql: "select ta.j1_id from j1 ta order by j1_id limit 10;", + expected: r#"SELECT ta.j1_id FROM j1 AS ta ORDER BY ta.j1_id ASC NULLS LAST LIMIT 10"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT j1_id FROM j1 + UNION ALL + SELECT tb.j2_id as j1_id FROM j2 tb + ORDER BY j1_id + LIMIT 10;", + expected: r#"SELECT j1.j1_id FROM j1 UNION ALL SELECT tb.j2_id AS j1_id FROM j2 AS tb ORDER BY j1_id ASC NULLS LAST LIMIT 10"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + // Test query with derived tables that put distinct,sort,limit on the wrong level + TestStatementWithDialect { + sql: "SELECT j1_string from j1 order by j1_id", + expected: r#"SELECT j1.j1_string FROM j1 ORDER BY j1.j1_id ASC NULLS LAST"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT j1_string AS a from j1 order by j1_id", + expected: r#"SELECT j1.j1_string AS a FROM j1 ORDER BY j1.j1_id ASC NULLS LAST"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT j1_string from j1 join j2 on j1.j1_id = j2.j2_id order by j1_id", + expected: r#"SELECT j1.j1_string FROM j1 JOIN j2 ON (j1.j1_id = j2.j2_id) ORDER BY j1.j1_id ASC NULLS LAST"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: " + SELECT + j1_string, + j2_string + FROM + ( + SELECT + distinct j1_id, + j1_string, + j2_string + from + j1 + INNER join j2 ON j1.j1_id = j2.j2_id + order by + j1.j1_id desc + limit + 10 + ) abc + ORDER BY + abc.j2_string", + expected: r#"SELECT abc.j1_string, abc.j2_string FROM (SELECT DISTINCT j1.j1_id, j1.j1_string, j2.j2_string FROM j1 JOIN j2 ON (j1.j1_id = j2.j2_id) ORDER BY j1.j1_id DESC NULLS FIRST LIMIT 10) AS abc ORDER BY abc.j2_string ASC NULLS LAST"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + // more tests around subquery/derived table roundtrip + TestStatementWithDialect { + sql: "SELECT string_count FROM ( + SELECT + j1_id, + min(j2_string) + FROM + j1 LEFT OUTER JOIN j2 ON + j1_id = j2_id + GROUP BY + j1_id + ) AS agg (id, string_count) + ", + expected: r#"SELECT agg.string_count FROM (SELECT j1.j1_id, min(j2.j2_string) FROM j1 LEFT JOIN j2 ON (j1.j1_id = j2.j2_id) GROUP BY j1.j1_id) AS agg (id, string_count)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: " + SELECT + j1_string, + j2_string + FROM + ( + SELECT + j1_id, + j1_string, + j2_string + from + j1 + INNER join j2 ON j1.j1_id = j2.j2_id + group by + j1_id, + j1_string, + j2_string + order by + j1.j1_id desc + limit + 10 + ) abc + ORDER BY + abc.j2_string", + expected: r#"SELECT abc.j1_string, abc.j2_string FROM (SELECT j1.j1_id, j1.j1_string, j2.j2_string FROM j1 JOIN j2 ON (j1.j1_id = j2.j2_id) GROUP BY j1.j1_id, j1.j1_string, j2.j2_string ORDER BY j1.j1_id DESC NULLS FIRST LIMIT 10) AS abc ORDER BY abc.j2_string ASC NULLS LAST"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + // Test query that order by columns are not in select columns + TestStatementWithDialect { + sql: " + SELECT + j1_string + FROM + ( + SELECT + j1_string, + j2_string + from + j1 + INNER join j2 ON j1.j1_id = j2.j2_id + order by + j1.j1_id desc, + j2.j2_id desc + limit + 10 + ) abc + ORDER BY + j2_string", + expected: r#"SELECT abc.j1_string FROM (SELECT j1.j1_string, j2.j2_string FROM j1 JOIN j2 ON (j1.j1_id = j2.j2_id) ORDER BY j1.j1_id DESC NULLS FIRST, j2.j2_id DESC NULLS FIRST LIMIT 10) AS abc ORDER BY abc.j2_string ASC NULLS LAST"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT id FROM (SELECT j1_id from j1) AS c (id)", + expected: r#"SELECT c.id FROM (SELECT j1.j1_id FROM j1) AS c (id)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT id FROM (SELECT j1_id as id from j1) AS c", + expected: r#"SELECT c.id FROM (SELECT j1.j1_id AS id FROM j1) AS c"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + // Test query that has calculation in derived table with columns + TestStatementWithDialect { + sql: "SELECT id FROM (SELECT j1_id + 1 * 3 from j1) AS c (id)", + expected: r#"SELECT c.id FROM (SELECT (j1.j1_id + (1 * 3)) FROM j1) AS c (id)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + // Test query that has limit/distinct/order in derived table with columns + TestStatementWithDialect { + sql: "SELECT id FROM (SELECT distinct (j1_id + 1 * 3) FROM j1 LIMIT 1) AS c (id)", + expected: r#"SELECT c.id FROM (SELECT DISTINCT (j1.j1_id + (1 * 3)) FROM j1 LIMIT 1) AS c (id)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT id FROM (SELECT j1_id + 1 FROM j1 ORDER BY j1_id DESC LIMIT 1) AS c (id)", + expected: r#"SELECT c.id FROM (SELECT (j1.j1_id + 1) FROM j1 ORDER BY j1.j1_id DESC NULLS FIRST LIMIT 1) AS c (id)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT id FROM (SELECT CAST((CAST(j1_id as BIGINT) + 1) as int) * 10 FROM j1 LIMIT 1) AS c (id)", + expected: r#"SELECT c.id FROM (SELECT (CAST((CAST(j1.j1_id AS BIGINT) + 1) AS INTEGER) * 10) FROM j1 LIMIT 1) AS c (id)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT id FROM (SELECT CAST(j1_id as BIGINT) + 1 FROM j1 ORDER BY j1_id LIMIT 1) AS c (id)", + expected: r#"SELECT c.id FROM (SELECT (CAST(j1.j1_id AS BIGINT) + 1) FROM j1 ORDER BY j1.j1_id ASC NULLS LAST LIMIT 1) AS c (id)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT temp_j.id2 FROM (SELECT j1_id, j1_string FROM j1) AS temp_j(id2, string2)", + expected: r#"SELECT temp_j.id2 FROM (SELECT j1.j1_id, j1.j1_string FROM j1) AS temp_j (id2, string2)"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(UnparserDefaultDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT temp_j.id2 FROM (SELECT j1_id, j1_string FROM j1) AS temp_j(id2, string2)", + expected: r#"SELECT `temp_j`.`id2` FROM (SELECT `j1`.`j1_id` AS `id2`, `j1`.`j1_string` AS `string2` FROM `j1`) AS `temp_j`"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(SqliteDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT * FROM (SELECT j1_id + 1 FROM j1) AS temp_j(id2)", + expected: r#"SELECT * FROM (SELECT (`j1`.`j1_id` + 1) AS `id2` FROM `j1`) AS `temp_j`"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(SqliteDialect {}), + }, + TestStatementWithDialect { + sql: "SELECT * FROM (SELECT j1_id FROM j1 LIMIT 1) AS temp_j(id2)", + expected: r#"SELECT * FROM (SELECT `j1`.`j1_id` AS `id2` FROM `j1` LIMIT 1) AS `temp_j`"#, + parser_dialect: Box::new(GenericDialect {}), + unparser_dialect: Box::new(SqliteDialect {}), + }, + ]; + + for query in tests { + let statement = Parser::new(&*query.parser_dialect) + .try_with_sql(query.sql)? + .parse_statement()?; + + let state = MockSessionState::default() + .with_aggregate_function(max_udaf()) + .with_aggregate_function(min_udaf()) + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); + + let context = MockContextProvider { state }; + let sql_to_rel = SqlToRel::new(&context); + let plan = sql_to_rel + .sql_statement_to_plan(statement) + .unwrap_or_else(|e| panic!("Failed to parse sql: {}\n{e}", query.sql)); + + let unparser = Unparser::new(&*query.unparser_dialect); + let roundtrip_statement = unparser.plan_to_sql(&plan)?; + + let actual = &roundtrip_statement.to_string(); + println!("roundtrip sql: {actual}"); + println!("plan {}", plan.display_indent()); + + assert_eq!(query.expected, actual); + } + + Ok(()) +} + +#[test] +fn test_unnest_logical_plan() -> Result<()> { + let query = "select unnest(struct_col), unnest(array_col), struct_col, array_col from unnest_table"; + + let dialect = GenericDialect {}; + let statement = Parser::new(&dialect) + .try_with_sql(query)? + .parse_statement()?; + + let context = MockContextProvider { + state: MockSessionState::default(), + }; + let sql_to_rel = SqlToRel::new(&context); + let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); + let expected = r#" +Projection: unnest_placeholder(unnest_table.struct_col).field1, unnest_placeholder(unnest_table.struct_col).field2, unnest_placeholder(unnest_table.array_col,depth=1) AS UNNEST(unnest_table.array_col), unnest_table.struct_col, unnest_table.array_col + Unnest: lists[unnest_placeholder(unnest_table.array_col)|depth=1] structs[unnest_placeholder(unnest_table.struct_col)] + Projection: unnest_table.struct_col AS unnest_placeholder(unnest_table.struct_col), unnest_table.array_col AS unnest_placeholder(unnest_table.array_col), unnest_table.struct_col, unnest_table.array_col + TableScan: unnest_table"#.trim_start(); + + assert_eq!(plan.to_string(), expected); + + Ok(()) +} + +#[test] +fn test_table_references_in_plan_to_sql() { + fn test(table_name: &str, expected_sql: &str) { + let schema = Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("value", DataType::Utf8, false), + ]); + let plan = table_scan(Some(table_name), &schema, None) + .unwrap() + .project(vec![col("id"), col("value")]) + .unwrap() + .build() + .unwrap(); + let sql = plan_to_sql(&plan).unwrap(); + + assert_eq!(sql.to_string(), expected_sql) + } + + test( + "catalog.schema.table", + r#"SELECT "catalog"."schema"."table".id, "catalog"."schema"."table"."value" FROM "catalog"."schema"."table""#, + ); + test( + "schema.table", + r#"SELECT "schema"."table".id, "schema"."table"."value" FROM "schema"."table""#, + ); + test( + "table", + r#"SELECT "table".id, "table"."value" FROM "table""#, + ); +} + +#[test] +fn test_table_scan_with_no_projection_in_plan_to_sql() { + fn test(table_name: &str, expected_sql: &str) { + let schema = Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("value", DataType::Utf8, false), + ]); + + let plan = table_scan(Some(table_name), &schema, None) + .unwrap() + .build() + .unwrap(); + let sql = plan_to_sql(&plan).unwrap(); + assert_eq!(sql.to_string(), expected_sql) + } + + test( + "catalog.schema.table", + r#"SELECT * FROM "catalog"."schema"."table""#, + ); + test("schema.table", r#"SELECT * FROM "schema"."table""#); + test("table", r#"SELECT * FROM "table""#); +} + +#[test] +fn test_pretty_roundtrip() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("age", DataType::Utf8, false), + ]); + + let df_schema = DFSchema::try_from(schema)?; + + let context = MockContextProvider { + state: MockSessionState::default(), + }; + let sql_to_rel = SqlToRel::new(&context); + + let unparser = Unparser::default().with_pretty(true); + + let sql_to_pretty_unparse = vec![ + ("((id < 5) OR (age = 8))", "id < 5 OR age = 8"), + ("((id + 5) * (age * 8))", "(id + 5) * age * 8"), + ("(3 + (5 * 6) * 3)", "3 + 5 * 6 * 3"), + ("((3 * (5 + 6)) * 3)", "3 * (5 + 6) * 3"), + ("((3 AND (5 OR 6)) * 3)", "(3 AND (5 OR 6)) * 3"), + ("((3 + (5 + 6)) * 3)", "(3 + 5 + 6) * 3"), + ("((3 + (5 + 6)) + 3)", "3 + 5 + 6 + 3"), + ("3 + 5 + 6 + 3", "3 + 5 + 6 + 3"), + ("3 + (5 + (6 + 3))", "3 + 5 + 6 + 3"), + ("3 + ((5 + 6) + 3)", "3 + 5 + 6 + 3"), + ("(3 + 5) + (6 + 3)", "3 + 5 + 6 + 3"), + ("((3 + 5) + (6 + 3))", "3 + 5 + 6 + 3"), + ( + "((id > 10) OR (age BETWEEN 10 AND 20))", + "id > 10 OR age BETWEEN 10 AND 20", + ), + ( + "((id > 10) * (age BETWEEN 10 AND 20))", + "(id > 10) * (age BETWEEN 10 AND 20)", + ), + ("id - (age - 8)", "id - (age - 8)"), + ("((id - age) - 8)", "id - age - 8"), + ("(id OR (age - 8))", "id OR age - 8"), + ("(id / (age - 8))", "id / (age - 8)"), + ("((id / age) * 8)", "id / age * 8"), + ("((age + 10) < 20) IS TRUE", "(age + 10 < 20) IS TRUE"), + ( + "(20 > (age + 5)) IS NOT FALSE", + "(20 > age + 5) IS NOT FALSE", + ), + ("(true AND false) IS FALSE", "(true AND false) IS FALSE"), + ("true AND (false IS FALSE)", "true AND false IS FALSE"), + ]; + + for (sql, pretty) in sql_to_pretty_unparse.iter() { + let sql_expr = Parser::new(&GenericDialect {}) + .try_with_sql(sql)? + .parse_expr()?; + let expr = + sql_to_rel.sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new())?; + let round_trip_sql = unparser.expr_to_sql(&expr)?.to_string(); + assert_eq!(pretty.to_string(), round_trip_sql); + + // verify that the pretty string parses to the same underlying Expr + let pretty_sql_expr = Parser::new(&GenericDialect {}) + .try_with_sql(pretty)? + .parse_expr()?; + + let pretty_expr = sql_to_rel.sql_to_expr( + pretty_sql_expr, + &df_schema, + &mut PlannerContext::new(), + )?; + + assert_eq!(expr.to_string(), pretty_expr.to_string()); + } + + Ok(()) +} + +fn sql_round_trip(dialect: D, query: &str, expect: &str) +where + D: Dialect, +{ + let statement = Parser::new(&dialect) + .try_with_sql(query) + .unwrap() + .parse_statement() + .unwrap(); + + let context = MockContextProvider { + state: MockSessionState::default() + .with_aggregate_function(sum_udaf()) + .with_aggregate_function(max_udaf()) + .with_aggregate_function(grouping_udaf()) + .with_window_function(rank_udwf()) + .with_scalar_function(Arc::new(unicode::substr().as_ref().clone())) + .with_scalar_function(make_array_udf()), + }; + let sql_to_rel = SqlToRel::new(&context); + let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); + + let roundtrip_statement = plan_to_sql(&plan).unwrap(); + assert_eq!(roundtrip_statement.to_string(), expect); +} + +#[test] +fn test_table_scan_alias() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("age", DataType::Utf8, false), + ]); + + let plan = table_scan(Some("t1"), &schema, None)? + .project(vec![col("id")])? + .alias("a")? + .build()?; + let sql = plan_to_sql(&plan)?; + assert_eq!(sql.to_string(), "SELECT * FROM (SELECT t1.id FROM t1) AS a"); + + let plan = table_scan(Some("t1"), &schema, None)? + .project(vec![col("id")])? + .alias("a")? + .build()?; + + let sql = plan_to_sql(&plan)?; + assert_eq!(sql.to_string(), "SELECT * FROM (SELECT t1.id FROM t1) AS a"); + + let plan = table_scan(Some("t1"), &schema, None)? + .filter(col("id").gt(lit(5)))? + .project(vec![col("id")])? + .alias("a")? + .build()?; + let sql = plan_to_sql(&plan)?; + assert_eq!( + sql.to_string(), + "SELECT * FROM (SELECT t1.id FROM t1 WHERE (t1.id > 5)) AS a" + ); + + let table_scan_with_two_filter = table_scan_with_filters( + Some("t1"), + &schema, + None, + vec![col("id").gt(lit(1)), col("age").lt(lit(2))], + )? + .project(vec![col("id")])? + .alias("a")? + .build()?; + let table_scan_with_two_filter = plan_to_sql(&table_scan_with_two_filter)?; + assert_eq!( + table_scan_with_two_filter.to_string(), + "SELECT a.id FROM t1 AS a WHERE ((a.id > 1) AND (a.age < 2))" + ); + + let table_scan_with_fetch = + table_scan_with_filter_and_fetch(Some("t1"), &schema, None, vec![], Some(10))? + .project(vec![col("id")])? + .alias("a")? + .build()?; + let table_scan_with_fetch = plan_to_sql(&table_scan_with_fetch)?; + assert_eq!( + table_scan_with_fetch.to_string(), + "SELECT a.id FROM (SELECT * FROM t1 LIMIT 10) AS a" + ); + + let table_scan_with_pushdown_all = table_scan_with_filter_and_fetch( + Some("t1"), + &schema, + Some(vec![0, 1]), + vec![col("id").gt(lit(1))], + Some(10), + )? + .project(vec![col("id")])? + .alias("a")? + .build()?; + let table_scan_with_pushdown_all = plan_to_sql(&table_scan_with_pushdown_all)?; + assert_eq!( + table_scan_with_pushdown_all.to_string(), + "SELECT a.id FROM (SELECT a.id, a.age FROM t1 AS a WHERE (a.id > 1) LIMIT 10) AS a" + ); + Ok(()) +} + +#[test] +fn test_table_scan_pushdown() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("age", DataType::Utf8, false), + ]); + let scan_with_projection = + table_scan(Some("t1"), &schema, Some(vec![0, 1]))?.build()?; + let scan_with_projection = plan_to_sql(&scan_with_projection)?; + assert_eq!( + scan_with_projection.to_string(), + "SELECT t1.id, t1.age FROM t1" + ); + + let scan_with_projection = table_scan(Some("t1"), &schema, Some(vec![1]))?.build()?; + let scan_with_projection = plan_to_sql(&scan_with_projection)?; + assert_eq!(scan_with_projection.to_string(), "SELECT t1.age FROM t1"); + + let scan_with_no_projection = table_scan(Some("t1"), &schema, None)?.build()?; + let scan_with_no_projection = plan_to_sql(&scan_with_no_projection)?; + assert_eq!(scan_with_no_projection.to_string(), "SELECT * FROM t1"); + + let table_scan_with_projection_alias = + table_scan(Some("t1"), &schema, Some(vec![0, 1]))? + .alias("ta")? + .build()?; + let table_scan_with_projection_alias = + plan_to_sql(&table_scan_with_projection_alias)?; + assert_eq!( + table_scan_with_projection_alias.to_string(), + "SELECT ta.id, ta.age FROM t1 AS ta" + ); + + let table_scan_with_projection_alias = + table_scan(Some("t1"), &schema, Some(vec![1]))? + .alias("ta")? + .build()?; + let table_scan_with_projection_alias = + plan_to_sql(&table_scan_with_projection_alias)?; + assert_eq!( + table_scan_with_projection_alias.to_string(), + "SELECT ta.age FROM t1 AS ta" + ); + + let table_scan_with_no_projection_alias = table_scan(Some("t1"), &schema, None)? + .alias("ta")? + .build()?; + let table_scan_with_no_projection_alias = + plan_to_sql(&table_scan_with_no_projection_alias)?; + assert_eq!( + table_scan_with_no_projection_alias.to_string(), + "SELECT * FROM t1 AS ta" + ); + + let query_from_table_scan_with_projection = LogicalPlanBuilder::from( + table_scan(Some("t1"), &schema, Some(vec![0, 1]))?.build()?, + ) + .project(vec![wildcard()])? + .build()?; + let query_from_table_scan_with_projection = + plan_to_sql(&query_from_table_scan_with_projection)?; + assert_eq!( + query_from_table_scan_with_projection.to_string(), + "SELECT * FROM (SELECT t1.id, t1.age FROM t1)" + ); + + let table_scan_with_filter = table_scan_with_filters( + Some("t1"), + &schema, + None, + vec![col("id").gt(col("age"))], + )? + .build()?; + let table_scan_with_filter = plan_to_sql(&table_scan_with_filter)?; + assert_eq!( + table_scan_with_filter.to_string(), + "SELECT * FROM t1 WHERE (t1.id > t1.age)" + ); + + let table_scan_with_two_filter = table_scan_with_filters( + Some("t1"), + &schema, + None, + vec![col("id").gt(lit(1)), col("age").lt(lit(2))], + )? + .build()?; + let table_scan_with_two_filter = plan_to_sql(&table_scan_with_two_filter)?; + assert_eq!( + table_scan_with_two_filter.to_string(), + "SELECT * FROM t1 WHERE ((t1.id > 1) AND (t1.age < 2))" + ); + + let table_scan_with_filter_alias = table_scan_with_filters( + Some("t1"), + &schema, + None, + vec![col("id").gt(col("age"))], + )? + .alias("ta")? + .build()?; + let table_scan_with_filter_alias = plan_to_sql(&table_scan_with_filter_alias)?; + assert_eq!( + table_scan_with_filter_alias.to_string(), + "SELECT * FROM t1 AS ta WHERE (ta.id > ta.age)" + ); + + let table_scan_with_projection_and_filter = table_scan_with_filters( + Some("t1"), + &schema, + Some(vec![0, 1]), + vec![col("id").gt(col("age"))], + )? + .build()?; + let table_scan_with_projection_and_filter = + plan_to_sql(&table_scan_with_projection_and_filter)?; + assert_eq!( + table_scan_with_projection_and_filter.to_string(), + "SELECT t1.id, t1.age FROM t1 WHERE (t1.id > t1.age)" + ); + + let table_scan_with_projection_and_filter = table_scan_with_filters( + Some("t1"), + &schema, + Some(vec![1]), + vec![col("id").gt(col("age"))], + )? + .build()?; + let table_scan_with_projection_and_filter = + plan_to_sql(&table_scan_with_projection_and_filter)?; + assert_eq!( + table_scan_with_projection_and_filter.to_string(), + "SELECT t1.age FROM t1 WHERE (t1.id > t1.age)" + ); + + let table_scan_with_inline_fetch = + table_scan_with_filter_and_fetch(Some("t1"), &schema, None, vec![], Some(10))? + .build()?; + let table_scan_with_inline_fetch = plan_to_sql(&table_scan_with_inline_fetch)?; + assert_eq!( + table_scan_with_inline_fetch.to_string(), + "SELECT * FROM t1 LIMIT 10" + ); + + let table_scan_with_projection_and_inline_fetch = table_scan_with_filter_and_fetch( + Some("t1"), + &schema, + Some(vec![0, 1]), + vec![], + Some(10), + )? + .build()?; + let table_scan_with_projection_and_inline_fetch = + plan_to_sql(&table_scan_with_projection_and_inline_fetch)?; + assert_eq!( + table_scan_with_projection_and_inline_fetch.to_string(), + "SELECT t1.id, t1.age FROM t1 LIMIT 10" + ); + + let table_scan_with_all = table_scan_with_filter_and_fetch( + Some("t1"), + &schema, + Some(vec![0, 1]), + vec![col("id").gt(col("age"))], + Some(10), + )? + .build()?; + let table_scan_with_all = plan_to_sql(&table_scan_with_all)?; + assert_eq!( + table_scan_with_all.to_string(), + "SELECT t1.id, t1.age FROM t1 WHERE (t1.id > t1.age) LIMIT 10" + ); + + let table_scan_with_additional_filter = table_scan_with_filters( + Some("t1"), + &schema, + None, + vec![col("id").gt(col("age"))], + )? + .filter(col("id").eq(lit(5)))? + .build()?; + let table_scan_with_filter = plan_to_sql(&table_scan_with_additional_filter)?; + assert_eq!( + table_scan_with_filter.to_string(), + "SELECT * FROM t1 WHERE (t1.id = 5) AND (t1.id > t1.age)" + ); + + Ok(()) +} + +#[test] +fn test_sort_with_push_down_fetch() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("age", DataType::Utf8, false), + ]); + + let plan = table_scan(Some("t1"), &schema, None)? + .project(vec![col("id"), col("age")])? + .sort_with_limit(vec![col("age").sort(true, true)], Some(10))? + .build()?; + + let sql = plan_to_sql(&plan)?; + assert_eq!( + format!("{}", sql), + "SELECT t1.id, t1.age FROM t1 ORDER BY t1.age ASC NULLS FIRST LIMIT 10" + ); + Ok(()) +} + +#[test] +fn test_join_with_table_scan_filters() -> Result<()> { + let schema_left = Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("name", DataType::Utf8, false), + ]); + + let schema_right = Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("age", DataType::Utf8, false), + ]); + + let left_plan = table_scan_with_filters( + Some("left_table"), + &schema_left, + None, + vec![col("name").like(lit("some_name"))], + )? + .alias("left")? + .build()?; + + let right_plan = table_scan_with_filters( + Some("right_table"), + &schema_right, + None, + vec![col("age").gt(lit(10))], + )? + .build()?; + + let join_plan_with_filter = LogicalPlanBuilder::from(left_plan.clone()) + .join( + right_plan.clone(), + datafusion_expr::JoinType::Inner, + (vec!["left.id"], vec!["right_table.id"]), + Some(col("left.id").gt(lit(5))), + )? + .build()?; + + let sql = plan_to_sql(&join_plan_with_filter)?; + + let expected_sql = r#"SELECT * FROM left_table AS "left" JOIN right_table ON "left".id = right_table.id AND (("left".id > 5) AND ("left"."name" LIKE 'some_name' AND (age > 10)))"#; + + assert_eq!(sql.to_string(), expected_sql); + + let join_plan_no_filter = LogicalPlanBuilder::from(left_plan.clone()) + .join( + right_plan, + datafusion_expr::JoinType::Inner, + (vec!["left.id"], vec!["right_table.id"]), + None, + )? + .build()?; + + let sql = plan_to_sql(&join_plan_no_filter)?; + + let expected_sql = r#"SELECT * FROM left_table AS "left" JOIN right_table ON "left".id = right_table.id AND ("left"."name" LIKE 'some_name' AND (age > 10))"#; + + assert_eq!(sql.to_string(), expected_sql); + + let right_plan_with_filter = table_scan_with_filters( + Some("right_table"), + &schema_right, + None, + vec![col("age").gt(lit(10))], + )? + .filter(col("right_table.name").eq(lit("before_join_filter_val")))? + .build()?; + + let join_plan_multiple_filters = LogicalPlanBuilder::from(left_plan.clone()) + .join( + right_plan_with_filter, + datafusion_expr::JoinType::Inner, + (vec!["left.id"], vec!["right_table.id"]), + Some(col("left.id").gt(lit(5))), + )? + .filter(col("left.name").eq(lit("after_join_filter_val")))? + .build()?; + + let sql = plan_to_sql(&join_plan_multiple_filters)?; + + let expected_sql = r#"SELECT * FROM left_table AS "left" JOIN right_table ON "left".id = right_table.id AND (("left".id > 5) AND (("left"."name" LIKE 'some_name' AND (right_table."name" = 'before_join_filter_val')) AND (age > 10))) WHERE ("left"."name" = 'after_join_filter_val')"#; + + assert_eq!(sql.to_string(), expected_sql); + + Ok(()) +} + +#[test] +fn test_interval_lhs_eq() { + sql_round_trip( + GenericDialect {}, + "select interval '2 seconds' = interval '2 seconds'", + "SELECT (INTERVAL '2.000000000 SECS' = INTERVAL '2.000000000 SECS')", + ); +} + +#[test] +fn test_interval_lhs_lt() { + sql_round_trip( + GenericDialect {}, + "select interval '2 seconds' < interval '2 seconds'", + "SELECT (INTERVAL '2.000000000 SECS' < INTERVAL '2.000000000 SECS')", + ); +} + +#[test] +fn test_without_offset() { + sql_round_trip(MySqlDialect {}, "select 1", "SELECT 1"); +} + +#[test] +fn test_with_offset0() { + sql_round_trip(MySqlDialect {}, "select 1 offset 0", "SELECT 1 OFFSET 0"); +} + +#[test] +fn test_with_offset95() { + sql_round_trip(MySqlDialect {}, "select 1 offset 95", "SELECT 1 OFFSET 95"); +} + +#[test] +fn test_order_by_to_sql() { + // order by aggregation function + sql_round_trip( + GenericDialect {}, + r#"SELECT id, first_name, SUM(id) FROM person GROUP BY id, first_name ORDER BY SUM(id) ASC, first_name DESC, id, first_name LIMIT 10"#, + r#"SELECT person.id, person.first_name, sum(person.id) FROM person GROUP BY person.id, person.first_name ORDER BY sum(person.id) ASC NULLS LAST, person.first_name DESC NULLS FIRST, person.id ASC NULLS LAST, person.first_name ASC NULLS LAST LIMIT 10"#, + ); + + // order by aggregation function alias + sql_round_trip( + GenericDialect {}, + r#"SELECT id, first_name, SUM(id) as total_sum FROM person GROUP BY id, first_name ORDER BY total_sum ASC, first_name DESC, id, first_name LIMIT 10"#, + r#"SELECT person.id, person.first_name, sum(person.id) AS total_sum FROM person GROUP BY person.id, person.first_name ORDER BY total_sum ASC NULLS LAST, person.first_name DESC NULLS FIRST, person.id ASC NULLS LAST, person.first_name ASC NULLS LAST LIMIT 10"#, + ); + + // order by scalar function from projection + sql_round_trip( + GenericDialect {}, + r#"SELECT id, first_name, substr(first_name,0,5) FROM person ORDER BY id, substr(first_name,0,5)"#, + r#"SELECT person.id, person.first_name, substr(person.first_name, 0, 5) FROM person ORDER BY person.id ASC NULLS LAST, substr(person.first_name, 0, 5) ASC NULLS LAST"#, + ); +} + +#[test] +fn test_aggregation_to_sql() { + sql_round_trip( + GenericDialect {}, + r#"SELECT id, first_name, + SUM(id) AS total_sum, + SUM(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum, + MAX(SUM(id)) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total, + rank() OVER (PARTITION BY grouping(id) + grouping(age), CASE WHEN grouping(age) = 0 THEN id END ORDER BY sum(id) DESC) AS rank_within_parent_1, + rank() OVER (PARTITION BY grouping(age) + grouping(id), CASE WHEN (CAST(grouping(age) AS BIGINT) = 0) THEN id END ORDER BY sum(id) DESC) AS rank_within_parent_2 + FROM person + GROUP BY id, first_name;"#, + r#"SELECT person.id, person.first_name, +sum(person.id) AS total_sum, sum(person.id) OVER (PARTITION BY person.first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) AS moving_sum, +max(sum(person.id)) OVER (PARTITION BY person.first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS max_total, +rank() OVER (PARTITION BY (grouping(person.id) + grouping(person.age)), CASE WHEN (grouping(person.age) = 0) THEN person.id END ORDER BY sum(person.id) DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rank_within_parent_1, +rank() OVER (PARTITION BY (grouping(person.age) + grouping(person.id)), CASE WHEN (CAST(grouping(person.age) AS BIGINT) = 0) THEN person.id END ORDER BY sum(person.id) DESC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS rank_within_parent_2 +FROM person +GROUP BY person.id, person.first_name"#.replace("\n", " ").as_str(), + ); +} + +#[test] +fn test_unnest_to_sql() { + sql_round_trip( + GenericDialect {}, + r#"SELECT unnest(array_col) as u1, struct_col, array_col FROM unnest_table WHERE array_col != NULL ORDER BY struct_col, array_col"#, + r#"SELECT UNNEST(unnest_table.array_col) AS u1, unnest_table.struct_col, unnest_table.array_col FROM unnest_table WHERE (unnest_table.array_col <> NULL) ORDER BY unnest_table.struct_col ASC NULLS LAST, unnest_table.array_col ASC NULLS LAST"#, + ); + + sql_round_trip( + GenericDialect {}, + r#"SELECT unnest(make_array(1, 2, 2, 5, NULL)) as u1"#, + r#"SELECT UNNEST(make_array(1, 2, 2, 5, NULL)) AS u1"#, + ); +} diff --git a/datafusion/sql/tests/common/mod.rs b/datafusion/sql/tests/common/mod.rs new file mode 100644 index 000000000000..b0fa17031849 --- /dev/null +++ b/datafusion/sql/tests/common/mod.rs @@ -0,0 +1,282 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +#[cfg(test)] +use std::collections::HashMap; +use std::fmt::Display; +use std::{sync::Arc, vec}; + +use arrow_schema::*; +use datafusion_common::config::ConfigOptions; +use datafusion_common::file_options::file_type::FileType; +use datafusion_common::{plan_err, GetExt, Result, TableReference}; +use datafusion_expr::planner::ExprPlanner; +use datafusion_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF}; +use datafusion_sql::planner::ContextProvider; + +struct MockCsvType {} + +impl GetExt for MockCsvType { + fn get_ext(&self) -> String { + "csv".to_string() + } +} + +impl FileType for MockCsvType { + fn as_any(&self) -> &dyn Any { + self + } +} + +impl Display for MockCsvType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.get_ext()) + } +} + +#[derive(Default)] +pub(crate) struct MockSessionState { + scalar_functions: HashMap>, + aggregate_functions: HashMap>, + expr_planners: Vec>, + window_functions: HashMap>, + pub config_options: ConfigOptions, +} + +impl MockSessionState { + pub fn with_expr_planner(mut self, expr_planner: Arc) -> Self { + self.expr_planners.push(expr_planner); + self + } + + pub fn with_scalar_function(mut self, scalar_function: Arc) -> Self { + self.scalar_functions + .insert(scalar_function.name().to_string(), scalar_function); + self + } + + pub fn with_aggregate_function( + mut self, + aggregate_function: Arc, + ) -> Self { + // TODO: change to to_string() if all the function name is converted to lowercase + self.aggregate_functions.insert( + aggregate_function.name().to_string().to_lowercase(), + aggregate_function, + ); + self + } + + pub fn with_window_function(mut self, window_function: Arc) -> Self { + self.window_functions + .insert(window_function.name().to_string(), window_function); + self + } +} + +pub(crate) struct MockContextProvider { + pub(crate) state: MockSessionState, +} + +impl ContextProvider for MockContextProvider { + fn get_table_source(&self, name: TableReference) -> Result> { + let schema = match name.table() { + "test" => Ok(Schema::new(vec![ + Field::new("t_date32", DataType::Date32, false), + Field::new("t_date64", DataType::Date64, false), + ])), + "j1" => Ok(Schema::new(vec![ + Field::new("j1_id", DataType::Int32, false), + Field::new("j1_string", DataType::Utf8, false), + ])), + "j2" => Ok(Schema::new(vec![ + Field::new("j2_id", DataType::Int32, false), + Field::new("j2_string", DataType::Utf8, false), + ])), + "j3" => Ok(Schema::new(vec![ + Field::new("j3_id", DataType::Int32, false), + Field::new("j3_string", DataType::Utf8, false), + ])), + "test_decimal" => Ok(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("price", DataType::Decimal128(10, 2), false), + ])), + "person" => Ok(Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("first_name", DataType::Utf8, false), + Field::new("last_name", DataType::Utf8, false), + Field::new("age", DataType::Int32, false), + Field::new("state", DataType::Utf8, false), + Field::new("salary", DataType::Float64, false), + Field::new( + "birth_date", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + ), + Field::new("😀", DataType::Int32, false), + ])), + "person_quoted_cols" => Ok(Schema::new(vec![ + Field::new("id", DataType::UInt32, false), + Field::new("First Name", DataType::Utf8, false), + Field::new("Last Name", DataType::Utf8, false), + Field::new("Age", DataType::Int32, false), + Field::new("State", DataType::Utf8, false), + Field::new("Salary", DataType::Float64, false), + Field::new( + "Birth Date", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + ), + Field::new("😀", DataType::Int32, false), + ])), + "orders" => Ok(Schema::new(vec![ + Field::new("order_id", DataType::UInt32, false), + Field::new("customer_id", DataType::UInt32, false), + Field::new("o_item_id", DataType::Utf8, false), + Field::new("qty", DataType::Int32, false), + Field::new("price", DataType::Float64, false), + Field::new("delivered", DataType::Boolean, false), + ])), + "array" => Ok(Schema::new(vec![ + Field::new( + "left", + DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), + false, + ), + Field::new( + "right", + DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), + false, + ), + ])), + "lineitem" => Ok(Schema::new(vec![ + Field::new("l_item_id", DataType::UInt32, false), + Field::new("l_description", DataType::Utf8, false), + Field::new("price", DataType::Float64, false), + ])), + "aggregate_test_100" => Ok(Schema::new(vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::UInt32, false), + Field::new("c3", DataType::Int8, false), + Field::new("c4", DataType::Int16, false), + Field::new("c5", DataType::Int32, false), + Field::new("c6", DataType::Int64, false), + Field::new("c7", DataType::UInt8, false), + Field::new("c8", DataType::UInt16, false), + Field::new("c9", DataType::UInt32, false), + Field::new("c10", DataType::UInt64, false), + Field::new("c11", DataType::Float32, false), + Field::new("c12", DataType::Float64, false), + Field::new("c13", DataType::Utf8, false), + ])), + "UPPERCASE_test" => Ok(Schema::new(vec![ + Field::new("Id", DataType::UInt32, false), + Field::new("lower", DataType::UInt32, false), + ])), + "unnest_table" => Ok(Schema::new(vec![ + Field::new( + "array_col", + DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), + false, + ), + Field::new( + "struct_col", + DataType::Struct(Fields::from(vec![ + Field::new("field1", DataType::Int64, true), + Field::new("field2", DataType::Utf8, true), + ])), + false, + ), + ])), + _ => plan_err!("No table named: {} found", name.table()), + }; + + match schema { + Ok(t) => Ok(Arc::new(EmptyTable::new(Arc::new(t)))), + Err(e) => Err(e), + } + } + + fn get_function_meta(&self, name: &str) -> Option> { + self.state.scalar_functions.get(name).cloned() + } + + fn get_aggregate_meta(&self, name: &str) -> Option> { + self.state.aggregate_functions.get(name).cloned() + } + + fn get_variable_type(&self, _: &[String]) -> Option { + unimplemented!() + } + + fn get_window_meta(&self, name: &str) -> Option> { + self.state.window_functions.get(name).cloned() + } + + fn options(&self) -> &ConfigOptions { + &self.state.config_options + } + + fn get_file_type(&self, _ext: &str) -> Result> { + Ok(Arc::new(MockCsvType {})) + } + + fn create_cte_work_table( + &self, + _name: &str, + schema: SchemaRef, + ) -> Result> { + Ok(Arc::new(EmptyTable::new(schema))) + } + + fn udf_names(&self) -> Vec { + self.state.scalar_functions.keys().cloned().collect() + } + + fn udaf_names(&self) -> Vec { + self.state.aggregate_functions.keys().cloned().collect() + } + + fn udwf_names(&self) -> Vec { + Vec::new() + } + + fn get_expr_planners(&self) -> &[Arc] { + &self.state.expr_planners + } +} + +struct EmptyTable { + table_schema: SchemaRef, +} + +impl EmptyTable { + fn new(table_schema: SchemaRef) -> Self { + Self { table_schema } + } +} + +impl TableSource for EmptyTable { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.table_schema) + } +} diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 319aa5b5fd30..698c408e538f 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -18,36 +18,48 @@ use std::any::Any; #[cfg(test)] use std::collections::HashMap; -use std::{sync::Arc, vec}; +use std::sync::Arc; +use std::vec; use arrow_schema::TimeUnit::Nanosecond; use arrow_schema::*; -use datafusion_common::config::ConfigOptions; +use common::MockContextProvider; use datafusion_common::{ - assert_contains, plan_err, DFSchema, DataFusionError, ParamValues, Result, - ScalarValue, TableReference, + assert_contains, DataFusionError, ParamValues, Result, ScalarValue, }; use datafusion_expr::{ + col, + dml::CopyTo, logical_plan::{LogicalPlan, Prepare}, - AggregateUDF, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, TableSource, - Volatility, WindowUDF, + test::function_stub::sum_udaf, + ColumnarValue, CreateExternalTable, CreateIndex, DdlStatement, ScalarUDF, + ScalarUDFImpl, Signature, Volatility, }; -use datafusion_sql::unparser::{expr_to_sql, plan_to_sql}; +use datafusion_functions::{string, unicode}; use datafusion_sql::{ parser::DFParser, - planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel}, + planner::{ParserOptions, SqlToRel}, }; -use datafusion_functions::{string, unicode}; +use crate::common::MockSessionState; +use datafusion_functions::core::planner::CoreFunctionPlanner; +use datafusion_functions_aggregate::{ + approx_median::approx_median_udaf, count::count_udaf, min_max::max_udaf, + min_max::min_udaf, +}; +use datafusion_functions_aggregate::{average::avg_udaf, grouping::grouping_udaf}; +use datafusion_functions_window::rank::rank_udwf; use rstest::rstest; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; -use sqlparser::parser::Parser; + +mod cases; +mod common; #[test] fn test_schema_support() { quick_test( "SELECT * FROM s1.test", - "Projection: s1.test.t_date32, s1.test.t_date64\ + "Projection: *\ \n TableScan: s1.test", ); } @@ -80,6 +92,8 @@ fn parse_decimals() { ParserOptions { parse_float_as_decimal: true, enable_ident_normalization: false, + support_varchar_with_length: false, + enable_options_value_normalization: false, }, ); } @@ -133,16 +147,83 @@ fn parse_ident_normalization() { ParserOptions { parse_float_as_decimal: false, enable_ident_normalization, + support_varchar_with_length: false, + enable_options_value_normalization: false, }, ); if plan.is_ok() { - assert_eq!(expected, format!("{plan:?}")); + let plan = plan.unwrap(); + assert_eq!(expected, format!("Ok({plan})")); } else { assert_eq!(expected, plan.unwrap_err().strip_backtrace()); } } } +#[test] +fn test_parse_options_value_normalization() { + let test_data = [ + ( + "CREATE EXTERNAL TABLE test OPTIONS ('location' 'LoCaTiOn') STORED AS PARQUET LOCATION 'fake_location'", + "CreateExternalTable: Bare { table: \"test\" }", + HashMap::from([("format.location", "LoCaTiOn")]), + false, + ), + ( + "CREATE EXTERNAL TABLE test OPTIONS ('location' 'LoCaTiOn') STORED AS PARQUET LOCATION 'fake_location'", + "CreateExternalTable: Bare { table: \"test\" }", + HashMap::from([("format.location", "location")]), + true, + ), + ( + "COPY test TO 'fake_location' STORED AS PARQUET OPTIONS ('location' 'LoCaTiOn')", + "CopyTo: format=csv output_url=fake_location options: (format.location LoCaTiOn)\n TableScan: test", + HashMap::from([("format.location", "LoCaTiOn")]), + false, + ), + ( + "COPY test TO 'fake_location' STORED AS PARQUET OPTIONS ('location' 'LoCaTiOn')", + "CopyTo: format=csv output_url=fake_location options: (format.location location)\n TableScan: test", + HashMap::from([("format.location", "location")]), + true, + ), + ]; + + for (sql, expected_plan, expected_options, enable_options_value_normalization) in + test_data + { + let plan = logical_plan_with_options( + sql, + ParserOptions { + parse_float_as_decimal: false, + enable_ident_normalization: false, + support_varchar_with_length: false, + enable_options_value_normalization, + }, + ); + if let Ok(plan) = plan { + assert_eq!(expected_plan, format!("{plan}")); + + match plan { + LogicalPlan::Ddl(DdlStatement::CreateExternalTable( + CreateExternalTable { options, .. }, + )) + | LogicalPlan::Copy(CopyTo { options, .. }) => { + expected_options.iter().for_each(|(k, v)| { + assert_eq!(Some(&v.to_string()), options.get(*k)); + }); + } + _ => panic!( + "Expected Ddl(CreateExternalTable) or Copy(CopyTo) but got {:?}", + plan + ), + } + } else { + assert_eq!(expected_plan, plan.unwrap_err().strip_backtrace()); + } + } +} + #[test] fn select_no_relation() { quick_test( @@ -201,9 +282,9 @@ fn cast_from_subquery() { #[test] fn try_cast_from_aggregation() { quick_test( - "SELECT TRY_CAST(SUM(age) AS FLOAT) FROM person", - "Projection: TRY_CAST(SUM(person.age) AS Float32)\ - \n Aggregate: groupBy=[[]], aggr=[[SUM(person.age)]]\ + "SELECT TRY_CAST(sum(age) AS FLOAT) FROM person", + "Projection: TRY_CAST(sum(person.age) AS Float32)\ + \n Aggregate: groupBy=[[]], aggr=[[sum(person.age)]]\ \n TableScan: person", ); } @@ -437,7 +518,7 @@ fn plan_copy_to_query() { let plan = r#" CopyTo: format=csv output_url=output.csv options: () Limit: skip=0, fetch=10 - Projection: test_decimal.id, test_decimal.price + Projection: * TableScan: test_decimal "# .trim(); @@ -557,23 +638,13 @@ fn select_repeated_column() { ); } -#[test] -fn select_wildcard_with_repeated_column() { - let sql = "SELECT *, age FROM person"; - let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!( - "Error during planning: Projections require unique expression names but the expression \"person.age\" at position 3 and \"person.age\" at position 8 have the same name. Consider aliasing (\"AS\") one of them.", - err.strip_backtrace() - ); -} - #[test] fn select_wildcard_with_repeated_column_but_is_aliased() { quick_test( - "SELECT *, first_name AS fn from person", - "Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, person.first_name AS fn\ + "SELECT *, first_name AS fn from person", + "Projection: *, person.first_name AS fn\ \n TableScan: person", - ); + ); } #[test] @@ -790,7 +861,7 @@ fn where_selection_with_ambiguous_column() { #[test] fn natural_join() { let sql = "SELECT * FROM lineitem a NATURAL JOIN lineitem b"; - let expected = "Projection: a.l_item_id, a.l_description, a.price\ + let expected = "Projection: *\ \n Inner Join: Using a.l_item_id = b.l_item_id, a.l_description = b.l_description, a.price = b.price\ \n SubqueryAlias: a\ \n TableScan: lineitem\ @@ -826,8 +897,8 @@ fn natural_right_join() { #[test] fn natural_join_no_common_becomes_cross_join() { let sql = "SELECT * FROM person a NATURAL JOIN lineitem b"; - let expected = "Projection: a.id, a.first_name, a.last_name, a.age, a.state, a.salary, a.birth_date, a.😀, b.l_item_id, b.l_description, b.price\ - \n CrossJoin:\ + let expected = "Projection: *\ + \n Cross Join: \ \n SubqueryAlias: a\ \n TableScan: person\ \n SubqueryAlias: b\ @@ -838,8 +909,7 @@ fn natural_join_no_common_becomes_cross_join() { #[test] fn using_join_multiple_keys() { let sql = "SELECT * FROM person a join person b using (id, age)"; - let expected = "Projection: a.id, a.first_name, a.last_name, a.age, a.state, a.salary, a.birth_date, a.😀, \ - b.first_name, b.last_name, b.state, b.salary, b.birth_date, b.😀\ + let expected = "Projection: *\ \n Inner Join: Using a.id = b.id, a.age = b.age\ \n SubqueryAlias: a\ \n TableScan: person\ @@ -853,8 +923,7 @@ fn using_join_multiple_keys_subquery() { let sql = "SELECT age FROM (SELECT * FROM person a join person b using (id, age, state))"; let expected = "Projection: a.age\ - \n Projection: a.id, a.first_name, a.last_name, a.age, a.state, a.salary, a.birth_date, a.😀, \ - b.first_name, b.last_name, b.salary, b.birth_date, b.😀\ + \n Projection: *\ \n Inner Join: Using a.id = b.id, a.age = b.age, a.state = b.state\ \n SubqueryAlias: a\ \n TableScan: person\ @@ -866,8 +935,7 @@ fn using_join_multiple_keys_subquery() { #[test] fn using_join_multiple_keys_qualified_wildcard_select() { let sql = "SELECT a.* FROM person a join person b using (id, age)"; - let expected = - "Projection: a.id, a.first_name, a.last_name, a.age, a.state, a.salary, a.birth_date, a.😀\ + let expected = "Projection: a.*\ \n Inner Join: Using a.id = b.id, a.age = b.age\ \n SubqueryAlias: a\ \n TableScan: person\ @@ -879,8 +947,7 @@ fn using_join_multiple_keys_qualified_wildcard_select() { #[test] fn using_join_multiple_keys_select_all_columns() { let sql = "SELECT a.*, b.* FROM person a join person b using (id, age)"; - let expected = "Projection: a.id, a.first_name, a.last_name, a.age, a.state, a.salary, a.birth_date, a.😀, \ - b.id, b.first_name, b.last_name, b.age, b.state, b.salary, b.birth_date, b.😀\ + let expected = "Projection: a.*, b.*\ \n Inner Join: Using a.id = b.id, a.age = b.age\ \n SubqueryAlias: a\ \n TableScan: person\ @@ -892,9 +959,7 @@ fn using_join_multiple_keys_select_all_columns() { #[test] fn using_join_multiple_keys_multiple_joins() { let sql = "SELECT * FROM person a join person b using (id, age, state) join person c using (id, age, state)"; - let expected = "Projection: a.id, a.first_name, a.last_name, a.age, a.state, a.salary, a.birth_date, a.😀, \ - b.first_name, b.last_name, b.salary, b.birth_date, b.😀, \ - c.first_name, c.last_name, c.salary, c.birth_date, c.😀\ + let expected = "Projection: *\ \n Inner Join: Using a.id = c.id, a.age = c.age, a.state = c.state\ \n Inner Join: Using a.id = b.id, a.age = b.age, a.state = b.state\ \n SubqueryAlias: a\ @@ -938,7 +1003,7 @@ fn select_with_having_refers_to_invalid_column() { HAVING first_name = 'M'"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Error during planning: HAVING clause references non-aggregate values: Expression person.first_name could not be resolved from available columns: person.id, MAX(person.age)", + "Error during planning: HAVING clause references non-aggregate values: Expression person.first_name could not be resolved from available columns: person.id, max(person.age)", err.strip_backtrace() ); } @@ -962,7 +1027,7 @@ fn select_with_having_with_aggregate_not_in_select() { HAVING MAX(age) > 100"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Error during planning: Projection references non-aggregate values: Expression person.first_name could not be resolved from available columns: MAX(person.age)", + "Error during planning: Projection references non-aggregate values: Expression person.first_name could not be resolved from available columns: max(person.age)", err.strip_backtrace() ); } @@ -972,33 +1037,33 @@ fn select_aggregate_with_having_that_reuses_aggregate() { let sql = "SELECT MAX(age) FROM person HAVING MAX(age) < 30"; - let expected = "Projection: MAX(person.age)\ - \n Filter: MAX(person.age) < Int64(30)\ - \n Aggregate: groupBy=[[]], aggr=[[MAX(person.age)]]\ + let expected = "Projection: max(person.age)\ + \n Filter: max(person.age) < Int64(30)\ + \n Aggregate: groupBy=[[]], aggr=[[max(person.age)]]\ \n TableScan: person"; quick_test(sql, expected); } #[test] fn select_aggregate_with_having_with_aggregate_not_in_select() { - let sql = "SELECT MAX(age) + let sql = "SELECT max(age) FROM person - HAVING MAX(first_name) > 'M'"; - let expected = "Projection: MAX(person.age)\ - \n Filter: MAX(person.first_name) > Utf8(\"M\")\ - \n Aggregate: groupBy=[[]], aggr=[[MAX(person.age), MAX(person.first_name)]]\ + HAVING max(first_name) > 'M'"; + let expected = "Projection: max(person.age)\ + \n Filter: max(person.first_name) > Utf8(\"M\")\ + \n Aggregate: groupBy=[[]], aggr=[[max(person.age), max(person.first_name)]]\ \n TableScan: person"; quick_test(sql, expected); } #[test] fn select_aggregate_with_having_referencing_column_not_in_select() { - let sql = "SELECT COUNT(*) + let sql = "SELECT count(*) FROM person HAVING first_name = 'M'"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Error during planning: HAVING clause references non-aggregate values: Expression person.first_name could not be resolved from available columns: COUNT(*)", + "Error during planning: HAVING clause references non-aggregate values: Expression person.first_name could not be resolved from available columns: count(*)", err.strip_backtrace() ); } @@ -1009,21 +1074,21 @@ fn select_aggregate_aliased_with_having_referencing_aggregate_by_its_alias() { FROM person HAVING max_age < 30"; // FIXME: add test for having in execution - let expected = "Projection: MAX(person.age) AS max_age\ - \n Filter: MAX(person.age) < Int64(30)\ - \n Aggregate: groupBy=[[]], aggr=[[MAX(person.age)]]\ + let expected = "Projection: max(person.age) AS max_age\ + \n Filter: max(person.age) < Int64(30)\ + \n Aggregate: groupBy=[[]], aggr=[[max(person.age)]]\ \n TableScan: person"; quick_test(sql, expected); } #[test] fn select_aggregate_aliased_with_having_that_reuses_aggregate_but_not_by_its_alias() { - let sql = "SELECT MAX(age) as max_age + let sql = "SELECT max(age) as max_age FROM person - HAVING MAX(age) < 30"; - let expected = "Projection: MAX(person.age) AS max_age\ - \n Filter: MAX(person.age) < Int64(30)\ - \n Aggregate: groupBy=[[]], aggr=[[MAX(person.age)]]\ + HAVING max(age) < 30"; + let expected = "Projection: max(person.age) AS max_age\ + \n Filter: max(person.age) < Int64(30)\ + \n Aggregate: groupBy=[[]], aggr=[[max(person.age)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -1034,23 +1099,23 @@ fn select_aggregate_with_group_by_with_having() { FROM person GROUP BY first_name HAVING first_name = 'M'"; - let expected = "Projection: person.first_name, MAX(person.age)\ + let expected = "Projection: person.first_name, max(person.age)\ \n Filter: person.first_name = Utf8(\"M\")\ - \n Aggregate: groupBy=[[person.first_name]], aggr=[[MAX(person.age)]]\ + \n Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age)]]\ \n TableScan: person"; quick_test(sql, expected); } #[test] fn select_aggregate_with_group_by_with_having_and_where() { - let sql = "SELECT first_name, MAX(age) + let sql = "SELECT first_name, max(age) FROM person WHERE id > 5 GROUP BY first_name HAVING MAX(age) < 100"; - let expected = "Projection: person.first_name, MAX(person.age)\ - \n Filter: MAX(person.age) < Int64(100)\ - \n Aggregate: groupBy=[[person.first_name]], aggr=[[MAX(person.age)]]\ + let expected = "Projection: person.first_name, max(person.age)\ + \n Filter: max(person.age) < Int64(100)\ + \n Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age)]]\ \n Filter: person.id > Int64(5)\ \n TableScan: person"; quick_test(sql, expected); @@ -1063,9 +1128,9 @@ fn select_aggregate_with_group_by_with_having_and_where_filtering_on_aggregate_c WHERE id > 5 AND age > 18 GROUP BY first_name HAVING MAX(age) < 100"; - let expected = "Projection: person.first_name, MAX(person.age)\ - \n Filter: MAX(person.age) < Int64(100)\ - \n Aggregate: groupBy=[[person.first_name]], aggr=[[MAX(person.age)]]\ + let expected = "Projection: person.first_name, max(person.age)\ + \n Filter: max(person.age) < Int64(100)\ + \n Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age)]]\ \n Filter: person.id > Int64(5) AND person.age > Int64(18)\ \n TableScan: person"; quick_test(sql, expected); @@ -1077,9 +1142,9 @@ fn select_aggregate_with_group_by_with_having_using_column_by_alias() { FROM person GROUP BY first_name HAVING MAX(age) > 2 AND fn = 'M'"; - let expected = "Projection: person.first_name AS fn, MAX(person.age)\ - \n Filter: MAX(person.age) > Int64(2) AND person.first_name = Utf8(\"M\")\ - \n Aggregate: groupBy=[[person.first_name]], aggr=[[MAX(person.age)]]\ + let expected = "Projection: person.first_name AS fn, max(person.age)\ + \n Filter: max(person.age) > Int64(2) AND person.first_name = Utf8(\"M\")\ + \n Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -1091,9 +1156,9 @@ fn select_aggregate_with_group_by_with_having_using_columns_with_and_without_the FROM person GROUP BY first_name HAVING MAX(age) > 2 AND max_age < 5 AND first_name = 'M' AND fn = 'N'"; - let expected = "Projection: person.first_name AS fn, MAX(person.age) AS max_age\ - \n Filter: MAX(person.age) > Int64(2) AND MAX(person.age) < Int64(5) AND person.first_name = Utf8(\"M\") AND person.first_name = Utf8(\"N\")\ - \n Aggregate: groupBy=[[person.first_name]], aggr=[[MAX(person.age)]]\ + let expected = "Projection: person.first_name AS fn, max(person.age) AS max_age\ + \n Filter: max(person.age) > Int64(2) AND max(person.age) < Int64(5) AND person.first_name = Utf8(\"M\") AND person.first_name = Utf8(\"N\")\ + \n Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -1104,9 +1169,9 @@ fn select_aggregate_with_group_by_with_having_that_reuses_aggregate() { FROM person GROUP BY first_name HAVING MAX(age) > 100"; - let expected = "Projection: person.first_name, MAX(person.age)\ - \n Filter: MAX(person.age) > Int64(100)\ - \n Aggregate: groupBy=[[person.first_name]], aggr=[[MAX(person.age)]]\ + let expected = "Projection: person.first_name, max(person.age)\ + \n Filter: max(person.age) > Int64(100)\ + \n Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -1119,7 +1184,7 @@ fn select_aggregate_with_group_by_with_having_referencing_column_not_in_group_by HAVING MAX(age) > 10 AND last_name = 'M'"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Error during planning: HAVING clause references non-aggregate values: Expression person.last_name could not be resolved from available columns: person.first_name, MAX(person.age)", + "Error during planning: HAVING clause references non-aggregate values: Expression person.last_name could not be resolved from available columns: person.first_name, max(person.age)", err.strip_backtrace() ); } @@ -1130,22 +1195,22 @@ fn select_aggregate_with_group_by_with_having_that_reuses_aggregate_multiple_tim FROM person GROUP BY first_name HAVING MAX(age) > 100 AND MAX(age) < 200"; - let expected = "Projection: person.first_name, MAX(person.age)\ - \n Filter: MAX(person.age) > Int64(100) AND MAX(person.age) < Int64(200)\ - \n Aggregate: groupBy=[[person.first_name]], aggr=[[MAX(person.age)]]\ + let expected = "Projection: person.first_name, max(person.age)\ + \n Filter: max(person.age) > Int64(100) AND max(person.age) < Int64(200)\ + \n Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age)]]\ \n TableScan: person"; quick_test(sql, expected); } #[test] -fn select_aggregate_with_group_by_with_having_using_aggreagate_not_in_select() { +fn select_aggregate_with_group_by_with_having_using_aggregate_not_in_select() { let sql = "SELECT first_name, MAX(age) FROM person GROUP BY first_name HAVING MAX(age) > 100 AND MIN(id) < 50"; - let expected = "Projection: person.first_name, MAX(person.age)\ - \n Filter: MAX(person.age) > Int64(100) AND MIN(person.id) < Int64(50)\ - \n Aggregate: groupBy=[[person.first_name]], aggr=[[MAX(person.age), MIN(person.id)]]\ + let expected = "Projection: person.first_name, max(person.age)\ + \n Filter: max(person.age) > Int64(100) AND min(person.id) < Int64(50)\ + \n Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age), min(person.id)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -1157,9 +1222,9 @@ fn select_aggregate_aliased_with_group_by_with_having_referencing_aggregate_by_i FROM person GROUP BY first_name HAVING max_age > 100"; - let expected = "Projection: person.first_name, MAX(person.age) AS max_age\ - \n Filter: MAX(person.age) > Int64(100)\ - \n Aggregate: groupBy=[[person.first_name]], aggr=[[MAX(person.age)]]\ + let expected = "Projection: person.first_name, max(person.age) AS max_age\ + \n Filter: max(person.age) > Int64(100)\ + \n Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -1171,23 +1236,23 @@ fn select_aggregate_compound_aliased_with_group_by_with_having_referencing_compo FROM person GROUP BY first_name HAVING max_age_plus_one > 100"; - let expected = "Projection: person.first_name, MAX(person.age) + Int64(1) AS max_age_plus_one\ - \n Filter: MAX(person.age) + Int64(1) > Int64(100)\ - \n Aggregate: groupBy=[[person.first_name]], aggr=[[MAX(person.age)]]\ + let expected = "Projection: person.first_name, max(person.age) + Int64(1) AS max_age_plus_one\ + \n Filter: max(person.age) + Int64(1) > Int64(100)\ + \n Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age)]]\ \n TableScan: person"; quick_test(sql, expected); } #[test] -fn select_aggregate_with_group_by_with_having_using_derived_column_aggreagate_not_in_select( +fn select_aggregate_with_group_by_with_having_using_derived_column_aggregate_not_in_select( ) { let sql = "SELECT first_name, MAX(age) FROM person GROUP BY first_name HAVING MAX(age) > 100 AND MIN(id - 2) < 50"; - let expected = "Projection: person.first_name, MAX(person.age)\ - \n Filter: MAX(person.age) > Int64(100) AND MIN(person.id - Int64(2)) < Int64(50)\ - \n Aggregate: groupBy=[[person.first_name]], aggr=[[MAX(person.age), MIN(person.id - Int64(2))]]\ + let expected = "Projection: person.first_name, max(person.age)\ + \n Filter: max(person.age) > Int64(100) AND min(person.id - Int64(2)) < Int64(50)\ + \n Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age), min(person.id - Int64(2))]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -1197,10 +1262,10 @@ fn select_aggregate_with_group_by_with_having_using_count_star_not_in_select() { let sql = "SELECT first_name, MAX(age) FROM person GROUP BY first_name - HAVING MAX(age) > 100 AND COUNT(*) < 50"; - let expected = "Projection: person.first_name, MAX(person.age)\ - \n Filter: MAX(person.age) > Int64(100) AND COUNT(*) < Int64(50)\ - \n Aggregate: groupBy=[[person.first_name]], aggr=[[MAX(person.age), COUNT(*)]]\ + HAVING MAX(age) > 100 AND count(*) < 50"; + let expected = "Projection: person.first_name, max(person.age)\ + \n Filter: max(person.age) > Int64(100) AND count(*) < Int64(50)\ + \n Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.age), count(*)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -1221,33 +1286,17 @@ fn select_binary_expr_nested() { quick_test(sql, expected); } -#[test] -fn select_at_arrow_operator() { - let sql = "SELECT left @> right from array"; - let expected = "Projection: array.left @> array.right\ - \n TableScan: array"; - quick_test(sql, expected); -} - -#[test] -fn select_arrow_at_operator() { - let sql = "SELECT left <@ right from array"; - let expected = "Projection: array.left <@ array.right\ - \n TableScan: array"; - quick_test(sql, expected); -} - #[test] fn select_wildcard_with_groupby() { quick_test( r#"SELECT * FROM person GROUP BY id, first_name, last_name, age, state, salary, birth_date, "😀""#, - "Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀\ + "Projection: *\ \n Aggregate: groupBy=[[person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀]], aggr=[[]]\ \n TableScan: person", ); quick_test( "SELECT * FROM (SELECT first_name, last_name FROM person) AS a GROUP BY first_name, last_name", - "Projection: a.first_name, a.last_name\ + "Projection: *\ \n Aggregate: groupBy=[[a.first_name, a.last_name]], aggr=[[]]\ \n SubqueryAlias: a\ \n Projection: person.first_name, person.last_name\ @@ -1259,8 +1308,8 @@ fn select_wildcard_with_groupby() { fn select_simple_aggregate() { quick_test( "SELECT MIN(age) FROM person", - "Projection: MIN(person.age)\ - \n Aggregate: groupBy=[[]], aggr=[[MIN(person.age)]]\ + "Projection: min(person.age)\ + \n Aggregate: groupBy=[[]], aggr=[[min(person.age)]]\ \n TableScan: person", ); } @@ -1268,9 +1317,9 @@ fn select_simple_aggregate() { #[test] fn test_sum_aggregate() { quick_test( - "SELECT SUM(age) from person", - "Projection: SUM(person.age)\ - \n Aggregate: groupBy=[[]], aggr=[[SUM(person.age)]]\ + "SELECT sum(age) from person", + "Projection: sum(person.age)\ + \n Aggregate: groupBy=[[]], aggr=[[sum(person.age)]]\ \n TableScan: person", ); } @@ -1287,7 +1336,7 @@ fn select_simple_aggregate_repeated_aggregate() { let sql = "SELECT MIN(age), MIN(age) FROM person"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Error during planning: Projections require unique expression names but the expression \"MIN(person.age)\" at position 0 and \"MIN(person.age)\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.", + "Error during planning: Projections require unique expression names but the expression \"min(person.age)\" at position 0 and \"min(person.age)\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.", err.strip_backtrace() ); } @@ -1296,8 +1345,8 @@ fn select_simple_aggregate_repeated_aggregate() { fn select_simple_aggregate_repeated_aggregate_with_single_alias() { quick_test( "SELECT MIN(age), MIN(age) AS a FROM person", - "Projection: MIN(person.age), MIN(person.age) AS a\ - \n Aggregate: groupBy=[[]], aggr=[[MIN(person.age)]]\ + "Projection: min(person.age), min(person.age) AS a\ + \n Aggregate: groupBy=[[]], aggr=[[min(person.age)]]\ \n TableScan: person", ); } @@ -1306,8 +1355,8 @@ fn select_simple_aggregate_repeated_aggregate_with_single_alias() { fn select_simple_aggregate_repeated_aggregate_with_unique_aliases() { quick_test( "SELECT MIN(age) AS a, MIN(age) AS b FROM person", - "Projection: MIN(person.age) AS a, MIN(person.age) AS b\ - \n Aggregate: groupBy=[[]], aggr=[[MIN(person.age)]]\ + "Projection: min(person.age) AS a, min(person.age) AS b\ + \n Aggregate: groupBy=[[]], aggr=[[min(person.age)]]\ \n TableScan: person", ); } @@ -1328,7 +1377,7 @@ fn select_simple_aggregate_repeated_aggregate_with_repeated_aliases() { let sql = "SELECT MIN(age) AS a, MIN(age) AS a FROM person"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Error during planning: Projections require unique expression names but the expression \"MIN(person.age) AS a\" at position 0 and \"MIN(person.age) AS a\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.", + "Error during planning: Projections require unique expression names but the expression \"min(person.age) AS a\" at position 0 and \"min(person.age) AS a\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.", err.strip_backtrace() ); } @@ -1337,8 +1386,8 @@ fn select_simple_aggregate_repeated_aggregate_with_repeated_aliases() { fn select_simple_aggregate_with_groupby() { quick_test( "SELECT state, MIN(age), MAX(age) FROM person GROUP BY state", - "Projection: person.state, MIN(person.age), MAX(person.age)\ - \n Aggregate: groupBy=[[person.state]], aggr=[[MIN(person.age), MAX(person.age)]]\ + "Projection: person.state, min(person.age), max(person.age)\ + \n Aggregate: groupBy=[[person.state]], aggr=[[min(person.age), max(person.age)]]\ \n TableScan: person", ); } @@ -1347,8 +1396,8 @@ fn select_simple_aggregate_with_groupby() { fn select_simple_aggregate_with_groupby_with_aliases() { quick_test( "SELECT state AS a, MIN(age) AS b FROM person GROUP BY state", - "Projection: person.state AS a, MIN(person.age) AS b\ - \n Aggregate: groupBy=[[person.state]], aggr=[[MIN(person.age)]]\ + "Projection: person.state AS a, min(person.age) AS b\ + \n Aggregate: groupBy=[[person.state]], aggr=[[min(person.age)]]\ \n TableScan: person", ); } @@ -1358,7 +1407,7 @@ fn select_simple_aggregate_with_groupby_with_aliases_repeated() { let sql = "SELECT state AS a, MIN(age) AS a FROM person GROUP BY state"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Error during planning: Projections require unique expression names but the expression \"person.state AS a\" at position 0 and \"MIN(person.age) AS a\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.", + "Error during planning: Projections require unique expression names but the expression \"person.state AS a\" at position 0 and \"min(person.age) AS a\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.", err.strip_backtrace() ); } @@ -1367,24 +1416,24 @@ fn select_simple_aggregate_with_groupby_with_aliases_repeated() { fn select_simple_aggregate_with_groupby_column_unselected() { quick_test( "SELECT MIN(age), MAX(age) FROM person GROUP BY state", - "Projection: MIN(person.age), MAX(person.age)\ - \n Aggregate: groupBy=[[person.state]], aggr=[[MIN(person.age), MAX(person.age)]]\ + "Projection: min(person.age), max(person.age)\ + \n Aggregate: groupBy=[[person.state]], aggr=[[min(person.age), max(person.age)]]\ \n TableScan: person", ); } #[test] fn select_simple_aggregate_with_groupby_and_column_in_group_by_does_not_exist() { - let sql = "SELECT SUM(age) FROM person GROUP BY doesnotexist"; + let sql = "SELECT sum(age) FROM person GROUP BY doesnotexist"; let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!("Schema error: No field named doesnotexist. Valid fields are \"SUM(person.age)\", \ + assert_eq!("Schema error: No field named doesnotexist. Valid fields are \"sum(person.age)\", \ person.id, person.first_name, person.last_name, person.age, person.state, \ person.salary, person.birth_date, person.\"😀\".", err.strip_backtrace()); } #[test] fn select_simple_aggregate_with_groupby_and_column_in_aggregate_does_not_exist() { - let sql = "SELECT SUM(doesnotexist) FROM person GROUP BY first_name"; + let sql = "SELECT sum(doesnotexist) FROM person GROUP BY first_name"; let err = logical_plan(sql).expect_err("query should have failed"); assert_field_not_found(err, "doesnotexist"); } @@ -1410,7 +1459,7 @@ fn recursive_ctes() { select * from numbers;"; quick_test( sql, - "Projection: numbers.n\ + "Projection: *\ \n SubqueryAlias: numbers\ \n RecursiveQuery: is_distinct=false\ \n Projection: Int64(1) AS n\ @@ -1432,8 +1481,9 @@ fn recursive_ctes_disabled() { select * from numbers;"; // manually setting up test here so that we can disable recursive ctes - let mut context = MockContextProvider::default(); - context.options_mut().execution.enable_recursive_ctes = false; + let mut state = MockSessionState::default(); + state.config_options.execution.enable_recursive_ctes = false; + let context = MockContextProvider { state }; let planner = SqlToRel::new_with_options(&context, ParserOptions::default()); let result = DFParser::parse_sql_with_dialect(sql, &GenericDialect {}); @@ -1452,8 +1502,8 @@ fn recursive_ctes_disabled() { fn select_simple_aggregate_with_groupby_and_column_is_in_aggregate_and_groupby() { quick_test( "SELECT MAX(first_name) FROM person GROUP BY first_name", - "Projection: MAX(person.first_name)\ - \n Aggregate: groupBy=[[person.first_name]], aggr=[[MAX(person.first_name)]]\ + "Projection: max(person.first_name)\ + \n Aggregate: groupBy=[[person.first_name]], aggr=[[max(person.first_name)]]\ \n TableScan: person", ); } @@ -1461,15 +1511,15 @@ fn select_simple_aggregate_with_groupby_and_column_is_in_aggregate_and_groupby() #[test] fn select_simple_aggregate_with_groupby_can_use_positions() { quick_test( - "SELECT state, age AS b, COUNT(1) FROM person GROUP BY 1, 2", - "Projection: person.state, person.age AS b, COUNT(Int64(1))\ - \n Aggregate: groupBy=[[person.state, person.age]], aggr=[[COUNT(Int64(1))]]\ + "SELECT state, age AS b, count(1) FROM person GROUP BY 1, 2", + "Projection: person.state, person.age AS b, count(Int64(1))\ + \n Aggregate: groupBy=[[person.state, person.age]], aggr=[[count(Int64(1))]]\ \n TableScan: person", ); quick_test( - "SELECT state, age AS b, COUNT(1) FROM person GROUP BY 2, 1", - "Projection: person.state, person.age AS b, COUNT(Int64(1))\ - \n Aggregate: groupBy=[[person.age, person.state]], aggr=[[COUNT(Int64(1))]]\ + "SELECT state, age AS b, count(1) FROM person GROUP BY 2, 1", + "Projection: person.state, person.age AS b, count(Int64(1))\ + \n Aggregate: groupBy=[[person.age, person.state]], aggr=[[count(Int64(1))]]\ \n TableScan: person", ); } @@ -1479,24 +1529,24 @@ fn select_simple_aggregate_with_groupby_position_out_of_range() { let sql = "SELECT state, MIN(age) FROM person GROUP BY 0"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Error during planning: Projection references non-aggregate values: Expression person.state could not be resolved from available columns: Int64(0), MIN(person.age)", - err.strip_backtrace() - ); + "Error during planning: Cannot find column with position 0 in SELECT clause. Valid columns: 1 to 2", + err.strip_backtrace() + ); let sql2 = "SELECT state, MIN(age) FROM person GROUP BY 5"; let err2 = logical_plan(sql2).expect_err("query should have failed"); assert_eq!( - "Error during planning: Projection references non-aggregate values: Expression person.state could not be resolved from available columns: Int64(5), MIN(person.age)", - err2.strip_backtrace() - ); + "Error during planning: Cannot find column with position 5 in SELECT clause. Valid columns: 1 to 2", + err2.strip_backtrace() + ); } #[test] fn select_simple_aggregate_with_groupby_can_use_alias() { quick_test( "SELECT state AS a, MIN(age) AS b FROM person GROUP BY a", - "Projection: person.state AS a, MIN(person.age) AS b\ - \n Aggregate: groupBy=[[person.state]], aggr=[[MIN(person.age)]]\ + "Projection: person.state AS a, min(person.age) AS b\ + \n Aggregate: groupBy=[[person.state]], aggr=[[min(person.age)]]\ \n TableScan: person", ); } @@ -1506,7 +1556,7 @@ fn select_simple_aggregate_with_groupby_aggregate_repeated() { let sql = "SELECT state, MIN(age), MIN(age) FROM person GROUP BY state"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Error during planning: Projections require unique expression names but the expression \"MIN(person.age)\" at position 1 and \"MIN(person.age)\" at position 2 have the same name. Consider aliasing (\"AS\") one of them.", + "Error during planning: Projections require unique expression names but the expression \"min(person.age)\" at position 1 and \"min(person.age)\" at position 2 have the same name. Consider aliasing (\"AS\") one of them.", err.strip_backtrace() ); } @@ -1515,8 +1565,8 @@ fn select_simple_aggregate_with_groupby_aggregate_repeated() { fn select_simple_aggregate_with_groupby_aggregate_repeated_and_one_has_alias() { quick_test( "SELECT state, MIN(age), MIN(age) AS ma FROM person GROUP BY state", - "Projection: person.state, MIN(person.age), MIN(person.age) AS ma\ - \n Aggregate: groupBy=[[person.state]], aggr=[[MIN(person.age)]]\ + "Projection: person.state, min(person.age), min(person.age) AS ma\ + \n Aggregate: groupBy=[[person.state]], aggr=[[min(person.age)]]\ \n TableScan: person", ) } @@ -1525,8 +1575,8 @@ fn select_simple_aggregate_with_groupby_aggregate_repeated_and_one_has_alias() { fn select_simple_aggregate_with_groupby_non_column_expression_unselected() { quick_test( "SELECT MIN(first_name) FROM person GROUP BY age + 1", - "Projection: MIN(person.first_name)\ - \n Aggregate: groupBy=[[person.age + Int64(1)]], aggr=[[MIN(person.first_name)]]\ + "Projection: min(person.first_name)\ + \n Aggregate: groupBy=[[person.age + Int64(1)]], aggr=[[min(person.first_name)]]\ \n TableScan: person", ); } @@ -1535,14 +1585,14 @@ fn select_simple_aggregate_with_groupby_non_column_expression_unselected() { fn select_simple_aggregate_with_groupby_non_column_expression_selected_and_resolvable() { quick_test( "SELECT age + 1, MIN(first_name) FROM person GROUP BY age + 1", - "Projection: person.age + Int64(1), MIN(person.first_name)\ - \n Aggregate: groupBy=[[person.age + Int64(1)]], aggr=[[MIN(person.first_name)]]\ + "Projection: person.age + Int64(1), min(person.first_name)\ + \n Aggregate: groupBy=[[person.age + Int64(1)]], aggr=[[min(person.first_name)]]\ \n TableScan: person", ); quick_test( "SELECT MIN(first_name), age + 1 FROM person GROUP BY age + 1", - "Projection: MIN(person.first_name), person.age + Int64(1)\ - \n Aggregate: groupBy=[[person.age + Int64(1)]], aggr=[[MIN(person.first_name)]]\ + "Projection: min(person.first_name), person.age + Int64(1)\ + \n Aggregate: groupBy=[[person.age + Int64(1)]], aggr=[[min(person.first_name)]]\ \n TableScan: person", ); } @@ -1551,8 +1601,8 @@ fn select_simple_aggregate_with_groupby_non_column_expression_selected_and_resol fn select_simple_aggregate_with_groupby_non_column_expression_nested_and_resolvable() { quick_test( "SELECT ((age + 1) / 2) * (age + 1), MIN(first_name) FROM person GROUP BY age + 1", - "Projection: person.age + Int64(1) / Int64(2) * person.age + Int64(1), MIN(person.first_name)\ - \n Aggregate: groupBy=[[person.age + Int64(1)]], aggr=[[MIN(person.first_name)]]\ + "Projection: person.age + Int64(1) / Int64(2) * person.age + Int64(1), min(person.first_name)\ + \n Aggregate: groupBy=[[person.age + Int64(1)]], aggr=[[min(person.first_name)]]\ \n TableScan: person", ); } @@ -1564,7 +1614,7 @@ fn select_simple_aggregate_with_groupby_non_column_expression_nested_and_not_res let sql = "SELECT ((age + 1) / 2) * (age + 9), MIN(first_name) FROM person GROUP BY age + 1"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Error during planning: Projection references non-aggregate values: Expression person.age could not be resolved from available columns: person.age + Int64(1), MIN(person.first_name)", + "Error during planning: Projection references non-aggregate values: Expression person.age could not be resolved from available columns: person.age + Int64(1), min(person.first_name)", err.strip_backtrace() ); } @@ -1574,7 +1624,7 @@ fn select_simple_aggregate_with_groupby_non_column_expression_and_its_column_sel let sql = "SELECT age, MIN(first_name) FROM person GROUP BY age + 1"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Error during planning: Projection references non-aggregate values: Expression person.age could not be resolved from available columns: person.age + Int64(1), MIN(person.first_name)", + "Error during planning: Projection references non-aggregate values: Expression person.age could not be resolved from available columns: person.age + Int64(1), min(person.first_name)", err.strip_backtrace() ); } @@ -1583,8 +1633,8 @@ fn select_simple_aggregate_with_groupby_non_column_expression_and_its_column_sel fn select_simple_aggregate_nested_in_binary_expr_with_groupby() { quick_test( "SELECT state, MIN(age) < 10 FROM person GROUP BY state", - "Projection: person.state, MIN(person.age) < Int64(10)\ - \n Aggregate: groupBy=[[person.state]], aggr=[[MIN(person.age)]]\ + "Projection: person.state, min(person.age) < Int64(10)\ + \n Aggregate: groupBy=[[person.state]], aggr=[[min(person.age)]]\ \n TableScan: person", ); } @@ -1593,8 +1643,8 @@ fn select_simple_aggregate_nested_in_binary_expr_with_groupby() { fn select_simple_aggregate_and_nested_groupby_column() { quick_test( "SELECT age + 1, MAX(first_name) FROM person GROUP BY age", - "Projection: person.age + Int64(1), MAX(person.first_name)\ - \n Aggregate: groupBy=[[person.age]], aggr=[[MAX(person.first_name)]]\ + "Projection: person.age + Int64(1), max(person.first_name)\ + \n Aggregate: groupBy=[[person.age]], aggr=[[max(person.first_name)]]\ \n TableScan: person", ); } @@ -1603,8 +1653,8 @@ fn select_simple_aggregate_and_nested_groupby_column() { fn select_aggregate_compounded_with_groupby_column() { quick_test( "SELECT age + MIN(salary) FROM person GROUP BY age", - "Projection: person.age + MIN(person.salary)\ - \n Aggregate: groupBy=[[person.age]], aggr=[[MIN(person.salary)]]\ + "Projection: person.age + min(person.salary)\ + \n Aggregate: groupBy=[[person.age]], aggr=[[min(person.salary)]]\ \n TableScan: person", ); } @@ -1613,8 +1663,8 @@ fn select_aggregate_compounded_with_groupby_column() { fn select_aggregate_with_non_column_inner_expression_with_groupby() { quick_test( "SELECT state, MIN(age + 1) FROM person GROUP BY state", - "Projection: person.state, MIN(person.age + Int64(1))\ - \n Aggregate: groupBy=[[person.state]], aggr=[[MIN(person.age + Int64(1))]]\ + "Projection: person.state, min(person.age + Int64(1))\ + \n Aggregate: groupBy=[[person.state]], aggr=[[min(person.age + Int64(1))]]\ \n TableScan: person", ); } @@ -1622,26 +1672,26 @@ fn select_aggregate_with_non_column_inner_expression_with_groupby() { #[test] fn test_wildcard() { quick_test( - "SELECT * from person", - "Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀\ + "SELECT * from person", + "Projection: *\ \n TableScan: person", - ); + ); } #[test] fn select_count_one() { - let sql = "SELECT COUNT(1) FROM person"; - let expected = "Projection: COUNT(Int64(1))\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\ + let sql = "SELECT count(1) FROM person"; + let expected = "Projection: count(Int64(1))\ + \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\ \n TableScan: person"; quick_test(sql, expected); } #[test] fn select_count_column() { - let sql = "SELECT COUNT(id) FROM person"; - let expected = "Projection: COUNT(person.id)\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(person.id)]]\ + let sql = "SELECT count(id) FROM person"; + let expected = "Projection: count(person.id)\ + \n Aggregate: groupBy=[[]], aggr=[[count(person.id)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -1649,8 +1699,8 @@ fn select_count_column() { #[test] fn select_approx_median() { let sql = "SELECT approx_median(age) FROM person"; - let expected = "Projection: APPROX_MEDIAN(person.age)\ - \n Aggregate: groupBy=[[]], aggr=[[APPROX_MEDIAN(person.age)]]\ + let expected = "Projection: approx_median(person.age)\ + \n Aggregate: groupBy=[[]], aggr=[[approx_median(person.age)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -1802,8 +1852,8 @@ fn select_group_by() { #[test] fn select_group_by_columns_not_in_select() { let sql = "SELECT MAX(age) FROM person GROUP BY state"; - let expected = "Projection: MAX(person.age)\ - \n Aggregate: groupBy=[[person.state]], aggr=[[MAX(person.age)]]\ + let expected = "Projection: max(person.age)\ + \n Aggregate: groupBy=[[person.state]], aggr=[[max(person.age)]]\ \n TableScan: person"; quick_test(sql, expected); @@ -1811,9 +1861,9 @@ fn select_group_by_columns_not_in_select() { #[test] fn select_group_by_count_star() { - let sql = "SELECT state, COUNT(*) FROM person GROUP BY state"; - let expected = "Projection: person.state, COUNT(*)\ - \n Aggregate: groupBy=[[person.state]], aggr=[[COUNT(*)]]\ + let sql = "SELECT state, count(*) FROM person GROUP BY state"; + let expected = "Projection: person.state, count(*)\ + \n Aggregate: groupBy=[[person.state]], aggr=[[count(*)]]\ \n TableScan: person"; quick_test(sql, expected); @@ -1821,10 +1871,10 @@ fn select_group_by_count_star() { #[test] fn select_group_by_needs_projection() { - let sql = "SELECT COUNT(state), state FROM person GROUP BY state"; + let sql = "SELECT count(state), state FROM person GROUP BY state"; let expected = "\ - Projection: COUNT(person.state), person.state\ - \n Aggregate: groupBy=[[person.state]], aggr=[[COUNT(person.state)]]\ + Projection: count(person.state), person.state\ + \n Aggregate: groupBy=[[person.state]], aggr=[[count(person.state)]]\ \n TableScan: person"; quick_test(sql, expected); @@ -1833,8 +1883,8 @@ fn select_group_by_needs_projection() { #[test] fn select_7480_1() { let sql = "SELECT c1, MIN(c12) FROM aggregate_test_100 GROUP BY c1, c13"; - let expected = "Projection: aggregate_test_100.c1, MIN(aggregate_test_100.c12)\ - \n Aggregate: groupBy=[[aggregate_test_100.c1, aggregate_test_100.c13]], aggr=[[MIN(aggregate_test_100.c12)]]\ + let expected = "Projection: aggregate_test_100.c1, min(aggregate_test_100.c12)\ + \n Aggregate: groupBy=[[aggregate_test_100.c1, aggregate_test_100.c13]], aggr=[[min(aggregate_test_100.c12)]]\ \n TableScan: aggregate_test_100"; quick_test(sql, expected); } @@ -1844,7 +1894,7 @@ fn select_7480_2() { let sql = "SELECT c1, c13, MIN(c12) FROM aggregate_test_100 GROUP BY c1"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "Error during planning: Projection references non-aggregate values: Expression aggregate_test_100.c13 could not be resolved from available columns: aggregate_test_100.c1, MIN(aggregate_test_100.c12)", + "Error during planning: Projection references non-aggregate values: Expression aggregate_test_100.c13 could not be resolved from available columns: aggregate_test_100.c1, min(aggregate_test_100.c12)", err.strip_backtrace() ); } @@ -1864,6 +1914,13 @@ fn create_external_table_with_pk() { quick_test(sql, expected); } +#[test] +fn create_external_table_wih_schema() { + let sql = "CREATE EXTERNAL TABLE staging.foo STORED AS CSV LOCATION 'foo.csv'"; + let expected = "CreateExternalTable: Partial { schema: \"staging\", table: \"foo\" }"; + quick_test(sql, expected); +} + #[test] fn create_schema_with_quoted_name() { let sql = "CREATE SCHEMA \"quoted_schema_name\""; @@ -1903,12 +1960,12 @@ fn create_external_table_csv_no_schema() { fn create_external_table_with_compression_type() { // positive case let sqls = vec![ - "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV COMPRESSION TYPE GZIP LOCATION 'foo.csv.gz'", - "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV COMPRESSION TYPE BZIP2 LOCATION 'foo.csv.bz2'", - "CREATE EXTERNAL TABLE t(c1 int) STORED AS JSON COMPRESSION TYPE GZIP LOCATION 'foo.json.gz'", - "CREATE EXTERNAL TABLE t(c1 int) STORED AS JSON COMPRESSION TYPE BZIP2 LOCATION 'foo.json.bz2'", - "CREATE EXTERNAL TABLE t(c1 int) STORED AS NONSTANDARD COMPRESSION TYPE GZIP LOCATION 'foo.unk'", - ]; + "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv.gz' OPTIONS ('format.compression' 'gzip')", + "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv.bz2' OPTIONS ('format.compression' 'bzip2')", + "CREATE EXTERNAL TABLE t(c1 int) STORED AS JSON LOCATION 'foo.json.gz' OPTIONS ('format.compression' 'gzip')", + "CREATE EXTERNAL TABLE t(c1 int) STORED AS JSON LOCATION 'foo.json.bz2' OPTIONS ('format.compression' 'bzip2')", + "CREATE EXTERNAL TABLE t(c1 int) STORED AS NONSTANDARD LOCATION 'foo.unk' OPTIONS ('format.compression' 'gzip')", + ]; for sql in sqls { let expected = "CreateExternalTable: Bare { table: \"t\" }"; quick_test(sql, expected); @@ -1916,12 +1973,12 @@ fn create_external_table_with_compression_type() { // negative case let sqls = vec![ - "CREATE EXTERNAL TABLE t STORED AS AVRO COMPRESSION TYPE GZIP LOCATION 'foo.avro'", - "CREATE EXTERNAL TABLE t STORED AS AVRO COMPRESSION TYPE BZIP2 LOCATION 'foo.avro'", - "CREATE EXTERNAL TABLE t STORED AS PARQUET COMPRESSION TYPE GZIP LOCATION 'foo.parquet'", - "CREATE EXTERNAL TABLE t STORED AS PARQUET COMPRESSION TYPE BZIP2 LOCATION 'foo.parquet'", - "CREATE EXTERNAL TABLE t STORED AS ARROW COMPRESSION TYPE GZIP LOCATION 'foo.arrow'", - "CREATE EXTERNAL TABLE t STORED AS ARROW COMPRESSION TYPE BZIP2 LOCATION 'foo.arrow'", + "CREATE EXTERNAL TABLE t STORED AS AVRO LOCATION 'foo.avro' OPTIONS ('format.compression' 'gzip')", + "CREATE EXTERNAL TABLE t STORED AS AVRO LOCATION 'foo.avro' OPTIONS ('format.compression' 'bzip2')", + "CREATE EXTERNAL TABLE t STORED AS PARQUET LOCATION 'foo.parquet' OPTIONS ('format.compression' 'gzip')", + "CREATE EXTERNAL TABLE t STORED AS PARQUET LOCATION 'foo.parquet' OPTIONS ('format.compression' 'bzip2')", + "CREATE EXTERNAL TABLE t STORED AS ARROW LOCATION 'foo.arrow' OPTIONS ('format.compression' 'gzip')", + "CREATE EXTERNAL TABLE t STORED AS ARROW LOCATION 'foo.arrow' OPTIONS ('format.compression' 'bzip2')", ]; for sql in sqls { let err = logical_plan(sql).expect_err("query should have failed"); @@ -1953,6 +2010,13 @@ fn create_external_table_parquet_no_schema() { quick_test(sql, expected); } +#[test] +fn create_external_table_parquet_no_schema_sort_order() { + let sql = "CREATE EXTERNAL TABLE t STORED AS PARQUET LOCATION 'foo.parquet' WITH ORDER (id)"; + let expected = "CreateExternalTable: Bare { table: \"t\" }"; + quick_test(sql, expected); +} + #[test] fn equijoin_explicit_syntax() { let sql = "SELECT id, order_id \ @@ -2053,7 +2117,7 @@ fn project_wildcard_on_join_with_using() { FROM lineitem \ JOIN lineitem as lineitem2 \ USING (l_item_id)"; - let expected = "Projection: lineitem.l_item_id, lineitem.l_description, lineitem.price, lineitem2.l_description, lineitem2.price\ + let expected = "Projection: *\ \n Inner Join: Using lineitem.l_item_id = lineitem2.l_item_id\ \n TableScan: lineitem\ \n SubqueryAlias: lineitem2\ @@ -2111,154 +2175,12 @@ fn union_all() { quick_test(sql, expected); } -#[test] -fn union_with_different_column_names() { - let sql = "SELECT order_id from orders UNION ALL SELECT customer_id FROM orders"; - let expected = "Union\ - \n Projection: orders.order_id\ - \n TableScan: orders\ - \n Projection: orders.customer_id AS order_id\ - \n TableScan: orders"; - quick_test(sql, expected); -} - -#[test] -fn union_values_with_no_alias() { - let sql = "SELECT 1, 2 UNION ALL SELECT 3, 4"; - let expected = "Union\ - \n Projection: Int64(1) AS Int64(1), Int64(2) AS Int64(2)\ - \n EmptyRelation\ - \n Projection: Int64(3) AS Int64(1), Int64(4) AS Int64(2)\ - \n EmptyRelation"; - quick_test(sql, expected); -} - -#[test] -fn union_with_incompatible_data_type() { - let sql = "SELECT interval '1 year 1 day' UNION ALL SELECT 1"; - let err = logical_plan(sql) - .expect_err("query should have failed") - .strip_backtrace(); - assert_eq!( - "Error during planning: UNION Column Int64(1) (type: Int64) is not compatible with column IntervalMonthDayNano(\"950737950189618795196236955648\") (type: Interval(MonthDayNano))", - err - ); -} - -#[test] -fn union_with_different_decimal_data_types() { - let sql = "SELECT 1 a UNION ALL SELECT 1.1 a"; - let expected = "Union\ - \n Projection: CAST(Int64(1) AS Float64) AS a\ - \n EmptyRelation\ - \n Projection: Float64(1.1) AS a\ - \n EmptyRelation"; - quick_test(sql, expected); -} - -#[test] -fn union_with_null() { - let sql = "SELECT NULL a UNION ALL SELECT 1.1 a"; - let expected = "Union\ - \n Projection: CAST(NULL AS Float64) AS a\ - \n EmptyRelation\ - \n Projection: Float64(1.1) AS a\ - \n EmptyRelation"; - quick_test(sql, expected); -} - -#[test] -fn union_with_float_and_string() { - let sql = "SELECT 'a' a UNION ALL SELECT 1.1 a"; - let expected = "Union\ - \n Projection: Utf8(\"a\") AS a\ - \n EmptyRelation\ - \n Projection: CAST(Float64(1.1) AS Utf8) AS a\ - \n EmptyRelation"; - quick_test(sql, expected); -} - -#[test] -fn union_with_multiply_cols() { - let sql = "SELECT 'a' a, 1 b UNION ALL SELECT 1.1 a, 1.1 b"; - let expected = "Union\ - \n Projection: Utf8(\"a\") AS a, CAST(Int64(1) AS Float64) AS b\ - \n EmptyRelation\ - \n Projection: CAST(Float64(1.1) AS Utf8) AS a, Float64(1.1) AS b\ - \n EmptyRelation"; - quick_test(sql, expected); -} - -#[test] -fn sorted_union_with_different_types_and_group_by() { - let sql = "SELECT a FROM (select 1 a) x GROUP BY 1 UNION ALL (SELECT a FROM (select 1.1 a) x GROUP BY 1) ORDER BY 1"; - let expected = "Sort: x.a ASC NULLS LAST\ - \n Union\ - \n Projection: CAST(x.a AS Float64) AS a\ - \n Aggregate: groupBy=[[x.a]], aggr=[[]]\ - \n SubqueryAlias: x\ - \n Projection: Int64(1) AS a\ - \n EmptyRelation\ - \n Projection: x.a\ - \n Aggregate: groupBy=[[x.a]], aggr=[[]]\ - \n SubqueryAlias: x\ - \n Projection: Float64(1.1) AS a\ - \n EmptyRelation"; - quick_test(sql, expected); -} - -#[test] -fn union_with_binary_expr_and_cast() { - let sql = "SELECT cast(0.0 + a as integer) FROM (select 1 a) x GROUP BY 1 UNION ALL (SELECT 2.1 + a FROM (select 1 a) x GROUP BY 1)"; - let expected = "Union\ - \n Projection: CAST(Float64(0) + x.a AS Float64) AS Float64(0) + x.a\ - \n Aggregate: groupBy=[[CAST(Float64(0) + x.a AS Int32)]], aggr=[[]]\ - \n SubqueryAlias: x\ - \n Projection: Int64(1) AS a\ - \n EmptyRelation\ - \n Projection: Float64(2.1) + x.a AS Float64(0) + x.a\ - \n Aggregate: groupBy=[[Float64(2.1) + x.a]], aggr=[[]]\ - \n SubqueryAlias: x\ - \n Projection: Int64(1) AS a\ - \n EmptyRelation"; - quick_test(sql, expected); -} - -#[test] -fn union_with_aliases() { - let sql = "SELECT a as a1 FROM (select 1 a) x GROUP BY 1 UNION ALL (SELECT a as a1 FROM (select 1.1 a) x GROUP BY 1)"; - let expected = "Union\ - \n Projection: CAST(x.a AS Float64) AS a1\ - \n Aggregate: groupBy=[[x.a]], aggr=[[]]\ - \n SubqueryAlias: x\ - \n Projection: Int64(1) AS a\ - \n EmptyRelation\ - \n Projection: x.a AS a1\ - \n Aggregate: groupBy=[[x.a]], aggr=[[]]\ - \n SubqueryAlias: x\ - \n Projection: Float64(1.1) AS a\ - \n EmptyRelation"; - quick_test(sql, expected); -} - -#[test] -fn union_with_incompatible_data_types() { - let sql = "SELECT 'a' a UNION ALL SELECT true a"; - let err = logical_plan(sql) - .expect_err("query should have failed") - .strip_backtrace(); - assert_eq!( - "Error during planning: UNION Column a (type: Boolean) is not compatible with column a (type: Utf8)", - err - ); -} - #[test] fn empty_over() { let sql = "SELECT order_id, MAX(order_id) OVER () from orders"; let expected = "\ - Projection: orders.order_id, MAX(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ - \n WindowAggr: windowExpr=[[MAX(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ + Projection: orders.order_id, max(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ + \n WindowAggr: windowExpr=[[max(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ \n TableScan: orders"; quick_test(sql, expected); } @@ -2267,8 +2189,8 @@ fn empty_over() { fn empty_over_with_alias() { let sql = "SELECT order_id oid, MAX(order_id) OVER () max_oid from orders"; let expected = "\ - Projection: orders.order_id AS oid, MAX(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS max_oid\ - \n WindowAggr: windowExpr=[[MAX(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ + Projection: orders.order_id AS oid, max(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS max_oid\ + \n WindowAggr: windowExpr=[[max(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ \n TableScan: orders"; quick_test(sql, expected); } @@ -2277,8 +2199,8 @@ fn empty_over_with_alias() { fn empty_over_dup_with_alias() { let sql = "SELECT order_id oid, MAX(order_id) OVER () max_oid, MAX(order_id) OVER () max_oid_dup from orders"; let expected = "\ - Projection: orders.order_id AS oid, MAX(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS max_oid, MAX(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS max_oid_dup\ - \n WindowAggr: windowExpr=[[MAX(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ + Projection: orders.order_id AS oid, max(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS max_oid, max(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS max_oid_dup\ + \n WindowAggr: windowExpr=[[max(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ \n TableScan: orders"; quick_test(sql, expected); } @@ -2287,9 +2209,9 @@ fn empty_over_dup_with_alias() { fn empty_over_dup_with_different_sort() { let sql = "SELECT order_id oid, MAX(order_id) OVER (), MAX(order_id) OVER (ORDER BY order_id) from orders"; let expected = "\ - Projection: orders.order_id AS oid, MAX(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, MAX(orders.order_id) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ - \n WindowAggr: windowExpr=[[MAX(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ - \n WindowAggr: windowExpr=[[MAX(orders.order_id) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + Projection: orders.order_id AS oid, max(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, max(orders.order_id) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ + \n WindowAggr: windowExpr=[[max(orders.order_id) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ + \n WindowAggr: windowExpr=[[max(orders.order_id) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ \n TableScan: orders"; quick_test(sql, expected); } @@ -2298,18 +2220,18 @@ fn empty_over_dup_with_different_sort() { fn empty_over_plus() { let sql = "SELECT order_id, MAX(qty * 1.1) OVER () from orders"; let expected = "\ - Projection: orders.order_id, MAX(orders.qty * Float64(1.1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ - \n WindowAggr: windowExpr=[[MAX(orders.qty * Float64(1.1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ + Projection: orders.order_id, max(orders.qty * Float64(1.1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ + \n WindowAggr: windowExpr=[[max(orders.qty * Float64(1.1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ \n TableScan: orders"; quick_test(sql, expected); } #[test] fn empty_over_multiple() { - let sql = "SELECT order_id, MAX(qty) OVER (), min(qty) over (), aVg(qty) OVER () from orders"; + let sql = "SELECT order_id, MAX(qty) OVER (), min(qty) over (), avg(qty) OVER () from orders"; let expected = "\ - Projection: orders.order_id, MAX(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, MIN(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, AVG(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ - \n WindowAggr: windowExpr=[[MAX(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, MIN(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, AVG(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ + Projection: orders.order_id, max(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, min(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, avg(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ + \n WindowAggr: windowExpr=[[max(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, min(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, avg(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ \n TableScan: orders"; quick_test(sql, expected); } @@ -2327,8 +2249,8 @@ fn empty_over_multiple() { fn over_partition_by() { let sql = "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id) from orders"; let expected = "\ - Projection: orders.order_id, MAX(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ - \n WindowAggr: windowExpr=[[MAX(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ + Projection: orders.order_id, max(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ + \n WindowAggr: windowExpr=[[max(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ \n TableScan: orders"; quick_test(sql, expected); } @@ -2349,9 +2271,9 @@ fn over_partition_by() { fn over_order_by() { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; let expected = "\ - Projection: orders.order_id, MAX(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, MIN(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ - \n WindowAggr: windowExpr=[[MAX(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n WindowAggr: windowExpr=[[MIN(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + Projection: orders.order_id, max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, min(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ + \n WindowAggr: windowExpr=[[max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + \n WindowAggr: windowExpr=[[min(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ \n TableScan: orders"; quick_test(sql, expected); } @@ -2360,9 +2282,9 @@ fn over_order_by() { fn over_order_by_with_window_frame_double_end() { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id ROWS BETWEEN 3 PRECEDING and 3 FOLLOWING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; let expected = "\ - Projection: orders.order_id, MAX(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING, MIN(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ - \n WindowAggr: windowExpr=[[MAX(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING]]\ - \n WindowAggr: windowExpr=[[MIN(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + Projection: orders.order_id, max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING, min(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ + \n WindowAggr: windowExpr=[[max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] ROWS BETWEEN 3 PRECEDING AND 3 FOLLOWING]]\ + \n WindowAggr: windowExpr=[[min(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ \n TableScan: orders"; quick_test(sql, expected); } @@ -2371,9 +2293,9 @@ fn over_order_by_with_window_frame_double_end() { fn over_order_by_with_window_frame_single_end() { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id ROWS 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; let expected = "\ - Projection: orders.order_id, MAX(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] ROWS BETWEEN 3 PRECEDING AND CURRENT ROW, MIN(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ - \n WindowAggr: windowExpr=[[MAX(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] ROWS BETWEEN 3 PRECEDING AND CURRENT ROW]]\ - \n WindowAggr: windowExpr=[[MIN(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + Projection: orders.order_id, max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] ROWS BETWEEN 3 PRECEDING AND CURRENT ROW, min(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ + \n WindowAggr: windowExpr=[[max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] ROWS BETWEEN 3 PRECEDING AND CURRENT ROW]]\ + \n WindowAggr: windowExpr=[[min(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ \n TableScan: orders"; quick_test(sql, expected); } @@ -2382,9 +2304,9 @@ fn over_order_by_with_window_frame_single_end() { fn over_order_by_with_window_frame_single_end_groups() { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id GROUPS 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; let expected = "\ - Projection: orders.order_id, MAX(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW, MIN(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ - \n WindowAggr: windowExpr=[[MAX(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW]]\ - \n WindowAggr: windowExpr=[[MIN(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + Projection: orders.order_id, max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW, min(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ + \n WindowAggr: windowExpr=[[max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] GROUPS BETWEEN 3 PRECEDING AND CURRENT ROW]]\ + \n WindowAggr: windowExpr=[[min(orders.qty) ORDER BY [orders.order_id DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ \n TableScan: orders"; quick_test(sql, expected); } @@ -2405,9 +2327,9 @@ fn over_order_by_with_window_frame_single_end_groups() { fn over_order_by_two_sort_keys() { let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id), MIN(qty) OVER (ORDER BY (order_id + 1)) from orders"; let expected = "\ - Projection: orders.order_id, MAX(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, MIN(orders.qty) ORDER BY [orders.order_id + Int64(1) ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ - \n WindowAggr: windowExpr=[[MAX(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n WindowAggr: windowExpr=[[MIN(orders.qty) ORDER BY [orders.order_id + Int64(1) ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + Projection: orders.order_id, max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, min(orders.qty) ORDER BY [orders.order_id + Int64(1) ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ + \n WindowAggr: windowExpr=[[max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + \n WindowAggr: windowExpr=[[min(orders.qty) ORDER BY [orders.order_id + Int64(1) ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ \n TableScan: orders"; quick_test(sql, expected); } @@ -2427,12 +2349,12 @@ fn over_order_by_two_sort_keys() { /// ``` #[test] fn over_order_by_sort_keys_sorting() { - let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY qty, order_id), SUM(qty) OVER (), MIN(qty) OVER (ORDER BY order_id, qty) from orders"; + let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY qty, order_id), sum(qty) OVER (), MIN(qty) OVER (ORDER BY order_id, qty) from orders"; let expected = "\ - Projection: orders.order_id, MAX(orders.qty) ORDER BY [orders.qty ASC NULLS LAST, orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, MIN(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST, orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ - \n WindowAggr: windowExpr=[[SUM(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ - \n WindowAggr: windowExpr=[[MAX(orders.qty) ORDER BY [orders.qty ASC NULLS LAST, orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n WindowAggr: windowExpr=[[MIN(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST, orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + Projection: orders.order_id, max(orders.qty) ORDER BY [orders.qty ASC NULLS LAST, orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, min(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST, orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ + \n WindowAggr: windowExpr=[[sum(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ + \n WindowAggr: windowExpr=[[max(orders.qty) ORDER BY [orders.qty ASC NULLS LAST, orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + \n WindowAggr: windowExpr=[[min(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST, orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ \n TableScan: orders"; quick_test(sql, expected); } @@ -2450,12 +2372,12 @@ fn over_order_by_sort_keys_sorting() { /// ``` #[test] fn over_order_by_sort_keys_sorting_prefix_compacting() { - let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id), SUM(qty) OVER (), MIN(qty) OVER (ORDER BY order_id, qty) from orders"; + let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id), sum(qty) OVER (), MIN(qty) OVER (ORDER BY order_id, qty) from orders"; let expected = "\ - Projection: orders.order_id, MAX(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, MIN(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST, orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ - \n WindowAggr: windowExpr=[[SUM(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ - \n WindowAggr: windowExpr=[[MAX(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n WindowAggr: windowExpr=[[MIN(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST, orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + Projection: orders.order_id, max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, min(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST, orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ + \n WindowAggr: windowExpr=[[sum(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ + \n WindowAggr: windowExpr=[[max(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + \n WindowAggr: windowExpr=[[min(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST, orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ \n TableScan: orders"; quick_test(sql, expected); } @@ -2478,13 +2400,13 @@ fn over_order_by_sort_keys_sorting_prefix_compacting() { /// sort #[test] fn over_order_by_sort_keys_sorting_global_order_compacting() { - let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY qty, order_id), SUM(qty) OVER (), MIN(qty) OVER (ORDER BY order_id, qty) from orders ORDER BY order_id"; + let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY qty, order_id), sum(qty) OVER (), MIN(qty) OVER (ORDER BY order_id, qty) from orders ORDER BY order_id"; let expected = "\ Sort: orders.order_id ASC NULLS LAST\ - \n Projection: orders.order_id, MAX(orders.qty) ORDER BY [orders.qty ASC NULLS LAST, orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, MIN(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST, orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ - \n WindowAggr: windowExpr=[[SUM(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ - \n WindowAggr: windowExpr=[[MAX(orders.qty) ORDER BY [orders.qty ASC NULLS LAST, orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n WindowAggr: windowExpr=[[MIN(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST, orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + \n Projection: orders.order_id, max(orders.qty) ORDER BY [orders.qty ASC NULLS LAST, orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, min(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST, orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ + \n WindowAggr: windowExpr=[[sum(orders.qty) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ + \n WindowAggr: windowExpr=[[max(orders.qty) ORDER BY [orders.qty ASC NULLS LAST, orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + \n WindowAggr: windowExpr=[[min(orders.qty) ORDER BY [orders.order_id ASC NULLS LAST, orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ \n TableScan: orders"; quick_test(sql, expected); } @@ -2503,8 +2425,8 @@ fn over_partition_by_order_by() { let sql = "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id ORDER BY qty) from orders"; let expected = "\ - Projection: orders.order_id, MAX(orders.qty) PARTITION BY [orders.order_id] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ - \n WindowAggr: windowExpr=[[MAX(orders.qty) PARTITION BY [orders.order_id] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + Projection: orders.order_id, max(orders.qty) PARTITION BY [orders.order_id] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ + \n WindowAggr: windowExpr=[[max(orders.qty) PARTITION BY [orders.order_id] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ \n TableScan: orders"; quick_test(sql, expected); } @@ -2523,8 +2445,8 @@ fn over_partition_by_order_by_no_dup() { let sql = "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id, qty ORDER BY qty) from orders"; let expected = "\ - Projection: orders.order_id, MAX(orders.qty) PARTITION BY [orders.order_id, orders.qty] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ - \n WindowAggr: windowExpr=[[MAX(orders.qty) PARTITION BY [orders.order_id, orders.qty] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + Projection: orders.order_id, max(orders.qty) PARTITION BY [orders.order_id, orders.qty] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ + \n WindowAggr: windowExpr=[[max(orders.qty) PARTITION BY [orders.order_id, orders.qty] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ \n TableScan: orders"; quick_test(sql, expected); } @@ -2546,9 +2468,9 @@ fn over_partition_by_order_by_mix_up() { let sql = "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id, qty ORDER BY qty), MIN(qty) OVER (PARTITION BY qty ORDER BY order_id) from orders"; let expected = "\ - Projection: orders.order_id, MAX(orders.qty) PARTITION BY [orders.order_id, orders.qty] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, MIN(orders.qty) PARTITION BY [orders.qty] ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ - \n WindowAggr: windowExpr=[[MIN(orders.qty) PARTITION BY [orders.qty] ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n WindowAggr: windowExpr=[[MAX(orders.qty) PARTITION BY [orders.order_id, orders.qty] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + Projection: orders.order_id, max(orders.qty) PARTITION BY [orders.order_id, orders.qty] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, min(orders.qty) PARTITION BY [orders.qty] ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ + \n WindowAggr: windowExpr=[[min(orders.qty) PARTITION BY [orders.qty] ORDER BY [orders.order_id ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + \n WindowAggr: windowExpr=[[max(orders.qty) PARTITION BY [orders.order_id, orders.qty] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ \n TableScan: orders"; quick_test(sql, expected); } @@ -2569,9 +2491,9 @@ fn over_partition_by_order_by_mix_up_prefix() { let sql = "SELECT order_id, MAX(qty) OVER (PARTITION BY order_id ORDER BY qty), MIN(qty) OVER (PARTITION BY order_id, qty ORDER BY price) from orders"; let expected = "\ - Projection: orders.order_id, MAX(orders.qty) PARTITION BY [orders.order_id] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, MIN(orders.qty) PARTITION BY [orders.order_id, orders.qty] ORDER BY [orders.price ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ - \n WindowAggr: windowExpr=[[MAX(orders.qty) PARTITION BY [orders.order_id] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n WindowAggr: windowExpr=[[MIN(orders.qty) PARTITION BY [orders.order_id, orders.qty] ORDER BY [orders.price ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + Projection: orders.order_id, max(orders.qty) PARTITION BY [orders.order_id] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, min(orders.qty) PARTITION BY [orders.order_id, orders.qty] ORDER BY [orders.price ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\ + \n WindowAggr: windowExpr=[[max(orders.qty) PARTITION BY [orders.order_id] ORDER BY [orders.qty ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + \n WindowAggr: windowExpr=[[min(orders.qty) PARTITION BY [orders.order_id, orders.qty] ORDER BY [orders.price ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ \n TableScan: orders"; quick_test(sql, expected); } @@ -2581,8 +2503,8 @@ fn approx_median_window() { let sql = "SELECT order_id, APPROX_MEDIAN(qty) OVER(PARTITION BY order_id) from orders"; let expected = "\ - Projection: orders.order_id, APPROX_MEDIAN(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ - \n WindowAggr: windowExpr=[[APPROX_MEDIAN(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ + Projection: orders.order_id, approx_median(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ + \n WindowAggr: windowExpr=[[approx_median(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ \n TableScan: orders"; quick_test(sql, expected); } @@ -2624,8 +2546,8 @@ fn select_groupby_orderby() { // expect that this is not an ambiguous reference let expected = "Sort: birth_date ASC NULLS LAST\ - \n Projection: AVG(person.age) AS value, date_trunc(Utf8(\"month\"), person.birth_date) AS birth_date\ - \n Aggregate: groupBy=[[person.birth_date]], aggr=[[AVG(person.age)]]\ + \n Projection: avg(person.age) AS value, date_trunc(Utf8(\"month\"), person.birth_date) AS birth_date\ + \n Aggregate: groupBy=[[person.birth_date]], aggr=[[avg(person.age)]]\ \n TableScan: person"; quick_test(sql, expected); @@ -2664,7 +2586,8 @@ fn logical_plan_with_options(sql: &str, options: ParserOptions) -> Result Result { - let context = MockContextProvider::default(); + let state = MockSessionState::default().with_aggregate_function(sum_udaf()); + let context = MockContextProvider { state }; let planner = SqlToRel::new(&context); let result = DFParser::parse_sql_with_dialect(sql, dialect); let mut ast = result?; @@ -2676,30 +2599,45 @@ fn logical_plan_with_dialect_and_options( dialect: &dyn Dialect, options: ParserOptions, ) -> Result { - let context = MockContextProvider::default() - .with_udf(unicode::character_length().as_ref().clone()) - .with_udf(string::concat().as_ref().clone()) - .with_udf(make_udf( + let state = MockSessionState::default() + .with_scalar_function(Arc::new(unicode::character_length().as_ref().clone())) + .with_scalar_function(Arc::new(string::concat().as_ref().clone())) + .with_scalar_function(Arc::new(make_udf( "nullif", vec![DataType::Int32, DataType::Int32], DataType::Int32, - )) - .with_udf(make_udf( + ))) + .with_scalar_function(Arc::new(make_udf( "round", vec![DataType::Float64, DataType::Int64], DataType::Float32, - )) - .with_udf(make_udf( + ))) + .with_scalar_function(Arc::new(make_udf( "arrow_cast", vec![DataType::Int64, DataType::Utf8], DataType::Float64, - )) - .with_udf(make_udf( + ))) + .with_scalar_function(Arc::new(make_udf( "date_trunc", vec![DataType::Utf8, DataType::Timestamp(Nanosecond, None)], DataType::Int32, - )) - .with_udf(make_udf("sqrt", vec![DataType::Int64], DataType::Int64)); + ))) + .with_scalar_function(Arc::new(make_udf( + "sqrt", + vec![DataType::Int64], + DataType::Int64, + ))) + .with_aggregate_function(sum_udaf()) + .with_aggregate_function(approx_median_udaf()) + .with_aggregate_function(count_udaf()) + .with_aggregate_function(avg_udaf()) + .with_aggregate_function(min_udaf()) + .with_aggregate_function(max_udaf()) + .with_aggregate_function(grouping_udaf()) + .with_window_function(rank_udwf()) + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())); + + let context = MockContextProvider { state }; let planner = SqlToRel::new_with_options(&context, options); let result = DFParser::parse_sql_with_dialect(sql, dialect); let mut ast = result?; @@ -2757,7 +2695,7 @@ fn quick_test(sql: &str, expected: &str) { fn quick_test_with_options(sql: &str, expected: &str, options: ParserOptions) { let plan = logical_plan_with_options(sql, options).unwrap(); - assert_eq!(format!("{plan:?}"), expected); + assert_eq!(format!("{plan}"), expected); } fn prepare_stmt_quick_test( @@ -2769,7 +2707,7 @@ fn prepare_stmt_quick_test( let assert_plan = plan.clone(); // verify plan - assert_eq!(format!("{assert_plan:?}"), expected_plan); + assert_eq!(format!("{assert_plan}"), expected_plan); // verify data types if let LogicalPlan::Prepare(Prepare { data_types, .. }) = assert_plan { @@ -2787,174 +2725,11 @@ fn prepare_stmt_replace_params_quick_test( ) -> LogicalPlan { // replace params let plan = plan.with_param_values(param_values).unwrap(); - assert_eq!(format!("{plan:?}"), expected_plan); + assert_eq!(format!("{plan}"), expected_plan); plan } -#[derive(Default)] -struct MockContextProvider { - options: ConfigOptions, - udfs: HashMap>, - udafs: HashMap>, -} - -impl MockContextProvider { - fn options_mut(&mut self) -> &mut ConfigOptions { - &mut self.options - } - - fn with_udf(mut self, udf: ScalarUDF) -> Self { - self.udfs.insert(udf.name().to_string(), Arc::new(udf)); - self - } -} - -impl ContextProvider for MockContextProvider { - fn get_table_source(&self, name: TableReference) -> Result> { - let schema = match name.table() { - "test" => Ok(Schema::new(vec![ - Field::new("t_date32", DataType::Date32, false), - Field::new("t_date64", DataType::Date64, false), - ])), - "j1" => Ok(Schema::new(vec![ - Field::new("j1_id", DataType::Int32, false), - Field::new("j1_string", DataType::Utf8, false), - ])), - "j2" => Ok(Schema::new(vec![ - Field::new("j2_id", DataType::Int32, false), - Field::new("j2_string", DataType::Utf8, false), - ])), - "j3" => Ok(Schema::new(vec![ - Field::new("j3_id", DataType::Int32, false), - Field::new("j3_string", DataType::Utf8, false), - ])), - "test_decimal" => Ok(Schema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("price", DataType::Decimal128(10, 2), false), - ])), - "person" => Ok(Schema::new(vec![ - Field::new("id", DataType::UInt32, false), - Field::new("first_name", DataType::Utf8, false), - Field::new("last_name", DataType::Utf8, false), - Field::new("age", DataType::Int32, false), - Field::new("state", DataType::Utf8, false), - Field::new("salary", DataType::Float64, false), - Field::new( - "birth_date", - DataType::Timestamp(TimeUnit::Nanosecond, None), - false, - ), - Field::new("😀", DataType::Int32, false), - ])), - "person_quoted_cols" => Ok(Schema::new(vec![ - Field::new("id", DataType::UInt32, false), - Field::new("First Name", DataType::Utf8, false), - Field::new("Last Name", DataType::Utf8, false), - Field::new("Age", DataType::Int32, false), - Field::new("State", DataType::Utf8, false), - Field::new("Salary", DataType::Float64, false), - Field::new( - "Birth Date", - DataType::Timestamp(TimeUnit::Nanosecond, None), - false, - ), - Field::new("😀", DataType::Int32, false), - ])), - "orders" => Ok(Schema::new(vec![ - Field::new("order_id", DataType::UInt32, false), - Field::new("customer_id", DataType::UInt32, false), - Field::new("o_item_id", DataType::Utf8, false), - Field::new("qty", DataType::Int32, false), - Field::new("price", DataType::Float64, false), - Field::new("delivered", DataType::Boolean, false), - ])), - "array" => Ok(Schema::new(vec![ - Field::new( - "left", - DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), - false, - ), - Field::new( - "right", - DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), - false, - ), - ])), - "lineitem" => Ok(Schema::new(vec![ - Field::new("l_item_id", DataType::UInt32, false), - Field::new("l_description", DataType::Utf8, false), - Field::new("price", DataType::Float64, false), - ])), - "aggregate_test_100" => Ok(Schema::new(vec![ - Field::new("c1", DataType::Utf8, false), - Field::new("c2", DataType::UInt32, false), - Field::new("c3", DataType::Int8, false), - Field::new("c4", DataType::Int16, false), - Field::new("c5", DataType::Int32, false), - Field::new("c6", DataType::Int64, false), - Field::new("c7", DataType::UInt8, false), - Field::new("c8", DataType::UInt16, false), - Field::new("c9", DataType::UInt32, false), - Field::new("c10", DataType::UInt64, false), - Field::new("c11", DataType::Float32, false), - Field::new("c12", DataType::Float64, false), - Field::new("c13", DataType::Utf8, false), - ])), - "UPPERCASE_test" => Ok(Schema::new(vec![ - Field::new("Id", DataType::UInt32, false), - Field::new("lower", DataType::UInt32, false), - ])), - _ => plan_err!("No table named: {} found", name.table()), - }; - - match schema { - Ok(t) => Ok(Arc::new(EmptyTable::new(Arc::new(t)))), - Err(e) => Err(e), - } - } - - fn get_function_meta(&self, name: &str) -> Option> { - self.udfs.get(name).cloned() - } - - fn get_aggregate_meta(&self, name: &str) -> Option> { - self.udafs.get(name).cloned() - } - - fn get_variable_type(&self, _: &[String]) -> Option { - unimplemented!() - } - - fn get_window_meta(&self, _name: &str) -> Option> { - None - } - - fn options(&self) -> &ConfigOptions { - &self.options - } - - fn create_cte_work_table( - &self, - _name: &str, - schema: SchemaRef, - ) -> Result> { - Ok(Arc::new(EmptyTable::new(schema))) - } - - fn udfs_names(&self) -> Vec { - self.udfs.keys().cloned().collect() - } - - fn udafs_names(&self) -> Vec { - self.udafs.keys().cloned().collect() - } - - fn udwfs_names(&self) -> Vec { - Vec::new() - } -} - #[test] fn select_partially_qualified_column() { let sql = r#"SELECT person.first_name FROM public.person"#; @@ -2969,8 +2744,8 @@ fn cross_join_not_to_inner_join() { "select person.id from person, orders, lineitem where person.id = person.age;"; let expected = "Projection: person.id\ \n Filter: person.id = person.age\ - \n CrossJoin:\ - \n CrossJoin:\ + \n Cross Join: \ + \n Cross Join: \ \n TableScan: person\ \n TableScan: orders\ \n TableScan: lineitem"; @@ -2993,7 +2768,7 @@ fn join_with_aliases() { fn negative_interval_plus_interval_in_projection() { let sql = "select -interval '2 days' + interval '5 days';"; let expected = - "Projection: IntervalMonthDayNano(\"79228162477370849446124847104\") + IntervalMonthDayNano(\"92233720368547758080\")\n EmptyRelation"; + "Projection: IntervalMonthDayNano(\"IntervalMonthDayNano { months: 0, days: -2, nanoseconds: 0 }\") + IntervalMonthDayNano(\"IntervalMonthDayNano { months: 0, days: 5, nanoseconds: 0 }\")\n EmptyRelation"; quick_test(sql, expected); } @@ -3001,7 +2776,7 @@ fn negative_interval_plus_interval_in_projection() { fn complex_interval_expression_in_projection() { let sql = "select -interval '2 days' + interval '5 days'+ (-interval '3 days' + interval '5 days');"; let expected = - "Projection: IntervalMonthDayNano(\"79228162477370849446124847104\") + IntervalMonthDayNano(\"92233720368547758080\") + IntervalMonthDayNano(\"79228162458924105372415295488\") + IntervalMonthDayNano(\"92233720368547758080\")\n EmptyRelation"; + "Projection: IntervalMonthDayNano(\"IntervalMonthDayNano { months: 0, days: -2, nanoseconds: 0 }\") + IntervalMonthDayNano(\"IntervalMonthDayNano { months: 0, days: 5, nanoseconds: 0 }\") + IntervalMonthDayNano(\"IntervalMonthDayNano { months: 0, days: -3, nanoseconds: 0 }\") + IntervalMonthDayNano(\"IntervalMonthDayNano { months: 0, days: 5, nanoseconds: 0 }\")\n EmptyRelation"; quick_test(sql, expected); } @@ -3009,7 +2784,7 @@ fn complex_interval_expression_in_projection() { fn negative_sum_intervals_in_projection() { let sql = "select -((interval '2 days' + interval '5 days') + -(interval '4 days' + interval '7 days'));"; let expected = - "Projection: (- IntervalMonthDayNano(\"36893488147419103232\") + IntervalMonthDayNano(\"92233720368547758080\") + (- IntervalMonthDayNano(\"73786976294838206464\") + IntervalMonthDayNano(\"129127208515966861312\")))\n EmptyRelation"; + "Projection: (- IntervalMonthDayNano(\"IntervalMonthDayNano { months: 0, days: 2, nanoseconds: 0 }\") + IntervalMonthDayNano(\"IntervalMonthDayNano { months: 0, days: 5, nanoseconds: 0 }\") + (- IntervalMonthDayNano(\"IntervalMonthDayNano { months: 0, days: 4, nanoseconds: 0 }\") + IntervalMonthDayNano(\"IntervalMonthDayNano { months: 0, days: 7, nanoseconds: 0 }\")))\n EmptyRelation"; quick_test(sql, expected); } @@ -3017,8 +2792,7 @@ fn negative_sum_intervals_in_projection() { fn date_plus_interval_in_projection() { let sql = "select t_date32 + interval '5 days' FROM test"; let expected = - "Projection: test.t_date32 + IntervalMonthDayNano(\"92233720368547758080\")\ - \n TableScan: test"; + "Projection: test.t_date32 + IntervalMonthDayNano(\"IntervalMonthDayNano { months: 0, days: 5, nanoseconds: 0 }\")\n TableScan: test"; quick_test(sql, expected); } @@ -3030,7 +2804,7 @@ fn date_plus_interval_in_filter() { AND cast('1999-12-31' as date) + interval '30 days'"; let expected = "Projection: test.t_date64\ - \n Filter: test.t_date64 BETWEEN CAST(Utf8(\"1999-12-31\") AS Date32) AND CAST(Utf8(\"1999-12-31\") AS Date32) + IntervalMonthDayNano(\"553402322211286548480\")\ + \n Filter: test.t_date64 BETWEEN CAST(Utf8(\"1999-12-31\") AS Date32) AND CAST(Utf8(\"1999-12-31\") AS Date32) + IntervalMonthDayNano(\"IntervalMonthDayNano { months: 0, days: 30, nanoseconds: 0 }\")\ \n TableScan: test"; quick_test(sql, expected); } @@ -3068,11 +2842,11 @@ fn exists_subquery_schema_outer_schema_overlap() { \n Subquery:\ \n Projection: person.first_name\ \n Filter: person.id = p2.id AND person.last_name = outer_ref(p.last_name) AND person.state = outer_ref(p.state)\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: person\ \n SubqueryAlias: p2\ \n TableScan: person\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: person\ \n SubqueryAlias: p\ \n TableScan: person"; @@ -3089,7 +2863,7 @@ fn exists_subquery_wildcard() { let expected = "Projection: p.id\ \n Filter: EXISTS ()\ \n Subquery:\ - \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀\ + \n Projection: *\ \n Filter: person.last_name = outer_ref(p.last_name) AND person.state = outer_ref(p.state)\ \n TableScan: person\ \n SubqueryAlias: p\ @@ -3135,8 +2909,8 @@ fn scalar_subquery() { let expected = "Projection: p.id, ()\ \n Subquery:\ - \n Projection: MAX(person.id)\ - \n Aggregate: groupBy=[[]], aggr=[[MAX(person.id)]]\ + \n Projection: max(person.id)\ + \n Aggregate: groupBy=[[]], aggr=[[max(person.id)]]\ \n Filter: person.last_name = outer_ref(p.last_name)\ \n TableScan: person\ \n SubqueryAlias: p\ @@ -3157,13 +2931,13 @@ fn scalar_subquery_reference_outer_field() { let expected = "Projection: j1.j1_string, j2.j2_string\ \n Filter: j1.j1_id = j2.j2_id - Int64(1) AND j2.j2_id < ()\ \n Subquery:\ - \n Projection: COUNT(*)\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(*)]]\ + \n Projection: count(*)\ + \n Aggregate: groupBy=[[]], aggr=[[count(*)]]\ \n Filter: outer_ref(j2.j2_id) = j1.j1_id AND j1.j1_id = j3.j3_id\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: j1\ \n TableScan: j3\ - \n CrossJoin:\ + \n Cross Join: \ \n TableScan: j1\ \n TableScan: j2"; @@ -3176,13 +2950,13 @@ fn subquery_references_cte() { cte AS (SELECT * FROM person) \ SELECT * FROM person WHERE EXISTS (SELECT * FROM cte WHERE id = person.id)"; - let expected = "Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀\ + let expected = "Projection: *\ \n Filter: EXISTS ()\ \n Subquery:\ - \n Projection: cte.id, cte.first_name, cte.last_name, cte.age, cte.state, cte.salary, cte.birth_date, cte.😀\ + \n Projection: *\ \n Filter: cte.id = outer_ref(person.id)\ \n SubqueryAlias: cte\ - \n Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀\ + \n Projection: *\ \n TableScan: person\ \n TableScan: person"; @@ -3197,7 +2971,7 @@ fn cte_with_no_column_names() { ) \ SELECT * FROM numbers;"; - let expected = "Projection: numbers.a, numbers.b, numbers.c\ + let expected = "Projection: *\ \n SubqueryAlias: numbers\ \n Projection: Int64(1) AS a, Int64(2) AS b, Int64(3) AS c\ \n EmptyRelation"; @@ -3213,7 +2987,7 @@ fn cte_with_column_names() { ) \ SELECT * FROM numbers;"; - let expected = "Projection: numbers.a, numbers.b, numbers.c\ + let expected = "Projection: *\ \n SubqueryAlias: numbers\ \n Projection: Int64(1) AS a, Int64(2) AS b, Int64(3) AS c\ \n Projection: Int64(1), Int64(2), Int64(3)\ @@ -3231,7 +3005,7 @@ fn cte_with_column_aliases_precedence() { ) \ SELECT * FROM numbers;"; - let expected = "Projection: numbers.a, numbers.b, numbers.c\ + let expected = "Projection: *\ \n SubqueryAlias: numbers\ \n Projection: x AS a, y AS b, z AS c\ \n Projection: Int64(1) AS x, Int64(2) AS y, Int64(3) AS z\ @@ -3255,19 +3029,19 @@ fn cte_unbalanced_number_of_columns() { #[test] fn aggregate_with_rollup() { let sql = - "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, ROLLUP (state, age)"; - let expected = "Projection: person.id, person.state, person.age, COUNT(*)\ - \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[COUNT(*)]]\ + "SELECT id, state, age, count(*) FROM person GROUP BY id, ROLLUP (state, age)"; + let expected = "Projection: person.id, person.state, person.age, count(*)\ + \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[count(*)]]\ \n TableScan: person"; quick_test(sql, expected); } #[test] fn aggregate_with_rollup_with_grouping() { - let sql = "SELECT id, state, age, grouping(state), grouping(age), grouping(state) + grouping(age), COUNT(*) \ + let sql = "SELECT id, state, age, grouping(state), grouping(age), grouping(state) + grouping(age), count(*) \ FROM person GROUP BY id, ROLLUP (state, age)"; - let expected = "Projection: person.id, person.state, person.age, GROUPING(person.state), GROUPING(person.age), GROUPING(person.state) + GROUPING(person.age), COUNT(*)\ - \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[GROUPING(person.state), GROUPING(person.age), COUNT(*)]]\ + let expected = "Projection: person.id, person.state, person.age, grouping(person.state), grouping(person.age), grouping(person.state) + grouping(person.age), count(*)\ + \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[grouping(person.state), grouping(person.age), count(*)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -3287,9 +3061,9 @@ fn rank_partition_grouping() { from person group by rollup(state, last_name)"; - let expected = "Projection: SUM(person.age) AS total_sum, person.state, person.last_name, GROUPING(person.state) + GROUPING(person.last_name) AS x, RANK() PARTITION BY [GROUPING(person.state) + GROUPING(person.last_name), CASE WHEN GROUPING(person.last_name) = Int64(0) THEN person.state END] ORDER BY [SUM(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS the_rank\ - \n WindowAggr: windowExpr=[[RANK() PARTITION BY [GROUPING(person.state) + GROUPING(person.last_name), CASE WHEN GROUPING(person.last_name) = Int64(0) THEN person.state END] ORDER BY [SUM(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ - \n Aggregate: groupBy=[[ROLLUP (person.state, person.last_name)]], aggr=[[SUM(person.age), GROUPING(person.state), GROUPING(person.last_name)]]\ + let expected = "Projection: sum(person.age) AS total_sum, person.state, person.last_name, grouping(person.state) + grouping(person.last_name) AS x, rank() PARTITION BY [grouping(person.state) + grouping(person.last_name), CASE WHEN grouping(person.last_name) = Int64(0) THEN person.state END] ORDER BY [sum(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS the_rank\ + \n WindowAggr: windowExpr=[[rank() PARTITION BY [grouping(person.state) + grouping(person.last_name), CASE WHEN grouping(person.last_name) = Int64(0) THEN person.state END] ORDER BY [sum(person.age) DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]\ + \n Aggregate: groupBy=[[ROLLUP (person.state, person.last_name)]], aggr=[[sum(person.age), grouping(person.state), grouping(person.last_name)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -3297,9 +3071,9 @@ fn rank_partition_grouping() { #[test] fn aggregate_with_cube() { let sql = - "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, CUBE (state, age)"; - let expected = "Projection: person.id, person.state, person.age, COUNT(*)\ - \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.age), (person.id, person.state, person.age))]], aggr=[[COUNT(*)]]\ + "SELECT id, state, age, count(*) FROM person GROUP BY id, CUBE (state, age)"; + let expected = "Projection: person.id, person.state, person.age, count(*)\ + \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.age), (person.id, person.state, person.age))]], aggr=[[count(*)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -3314,9 +3088,9 @@ fn round_decimal() { #[test] fn aggregate_with_grouping_sets() { - let sql = "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, GROUPING SETS ((state), (state, age), (id, state))"; - let expected = "Projection: person.id, person.state, person.age, COUNT(*)\ - \n Aggregate: groupBy=[[GROUPING SETS ((person.id, person.state), (person.id, person.state, person.age), (person.id, person.id, person.state))]], aggr=[[COUNT(*)]]\ + let sql = "SELECT id, state, age, count(*) FROM person GROUP BY id, GROUPING SETS ((state), (state, age), (id, state))"; + let expected = "Projection: person.id, person.state, person.age, count(*)\ + \n Aggregate: groupBy=[[GROUPING SETS ((person.id, person.state), (person.id, person.state, person.age), (person.id, person.id, person.state))]], aggr=[[count(*)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -3345,13 +3119,121 @@ fn join_on_complex_condition() { quick_test(sql, expected); } +#[test] +fn lateral_constant() { + let sql = "SELECT * FROM j1, LATERAL (SELECT 1) AS j2"; + let expected = "Projection: *\ + \n Cross Join: \ + \n TableScan: j1\ + \n SubqueryAlias: j2\ + \n Subquery:\ + \n Projection: Int64(1)\ + \n EmptyRelation"; + quick_test(sql, expected); +} + +#[test] +fn lateral_comma_join() { + let sql = "SELECT j1_string, j2_string FROM + j1, \ + LATERAL (SELECT * FROM j2 WHERE j1_id < j2_id) AS j2"; + let expected = "Projection: j1.j1_string, j2.j2_string\ + \n Cross Join: \ + \n TableScan: j1\ + \n SubqueryAlias: j2\ + \n Subquery:\ + \n Projection: *\ + \n Filter: outer_ref(j1.j1_id) < j2.j2_id\ + \n TableScan: j2"; + quick_test(sql, expected); +} + +#[test] +fn lateral_comma_join_referencing_join_rhs() { + let sql = "SELECT * FROM\ + \n j1 JOIN (j2 JOIN j3 ON(j2_id = j3_id - 2)) ON(j1_id = j2_id),\ + \n LATERAL (SELECT * FROM j3 WHERE j3_string = j2_string) as j4;"; + let expected = "Projection: *\ + \n Cross Join: \ + \n Inner Join: Filter: j1.j1_id = j2.j2_id\ + \n TableScan: j1\ + \n Inner Join: Filter: j2.j2_id = j3.j3_id - Int64(2)\ + \n TableScan: j2\ + \n TableScan: j3\ + \n SubqueryAlias: j4\ + \n Subquery:\ + \n Projection: *\ + \n Filter: j3.j3_string = outer_ref(j2.j2_string)\ + \n TableScan: j3"; + quick_test(sql, expected); +} + +#[test] +fn lateral_comma_join_with_shadowing() { + // The j1_id on line 3 references the (closest) j1 definition from line 2. + let sql = "\ + SELECT * FROM j1, LATERAL (\ + SELECT * FROM j1, LATERAL (\ + SELECT * FROM j2 WHERE j1_id = j2_id\ + ) as j2\ + ) as j2;"; + let expected = "Projection: *\ + \n Cross Join: \ + \n TableScan: j1\ + \n SubqueryAlias: j2\ + \n Subquery:\ + \n Projection: *\ + \n Cross Join: \ + \n TableScan: j1\ + \n SubqueryAlias: j2\ + \n Subquery:\ + \n Projection: *\ + \n Filter: outer_ref(j1.j1_id) = j2.j2_id\ + \n TableScan: j2"; + quick_test(sql, expected); +} + +#[test] +fn lateral_left_join() { + let sql = "SELECT j1_string, j2_string FROM \ + j1 \ + LEFT JOIN LATERAL (SELECT * FROM j2 WHERE j1_id < j2_id) AS j2 ON(true);"; + let expected = "Projection: j1.j1_string, j2.j2_string\ + \n Left Join: Filter: Boolean(true)\ + \n TableScan: j1\ + \n SubqueryAlias: j2\ + \n Subquery:\ + \n Projection: *\ + \n Filter: outer_ref(j1.j1_id) < j2.j2_id\ + \n TableScan: j2"; + quick_test(sql, expected); +} + +#[test] +fn lateral_nested_left_join() { + let sql = "SELECT * FROM + j1, \ + (j2 LEFT JOIN LATERAL (SELECT * FROM j3 WHERE j1_id + j2_id = j3_id) AS j3 ON(true))"; + let expected = "Projection: *\ + \n Cross Join: \ + \n TableScan: j1\ + \n Left Join: Filter: Boolean(true)\ + \n TableScan: j2\ + \n SubqueryAlias: j3\ + \n Subquery:\ + \n Projection: *\ + \n Filter: outer_ref(j1.j1_id) + outer_ref(j2.j2_id) = j3.j3_id\ + \n TableScan: j3"; + quick_test(sql, expected); +} + #[test] fn hive_aggregate_with_filter() -> Result<()> { let dialect = &HiveDialect {}; - let sql = "SELECT SUM(age) FILTER (WHERE age > 4) FROM person"; + let sql = "SELECT sum(age) FILTER (WHERE age > 4) FROM person"; let plan = logical_plan_with_dialect(sql, dialect)?; - let expected = "Projection: SUM(person.age) FILTER (WHERE person.age > Int64(4))\ - \n Aggregate: groupBy=[[]], aggr=[[SUM(person.age) FILTER (WHERE person.age > Int64(4))]]\ + let expected = "Projection: sum(person.age) FILTER (WHERE person.age > Int64(4))\ + \n Aggregate: groupBy=[[]], aggr=[[sum(person.age) FILTER (WHERE person.age > Int64(4))]]\ \n TableScan: person" .to_string(); assert_eq!(plan.display_indent().to_string(), expected); @@ -3367,8 +3249,8 @@ fn order_by_unaliased_name() { "select p.state z, sum(age) q from person p group by p.state order by p.state"; let expected = "Projection: z, q\ \n Sort: p.state ASC NULLS LAST\ - \n Projection: p.state AS z, SUM(p.age) AS q, p.state\ - \n Aggregate: groupBy=[[p.state]], aggr=[[SUM(p.age)]]\ + \n Projection: p.state AS z, sum(p.age) AS q, p.state\ + \n Aggregate: groupBy=[[p.state]], aggr=[[sum(p.age)]]\ \n SubqueryAlias: p\ \n TableScan: person"; quick_test(sql, expected); @@ -3439,7 +3321,7 @@ fn test_offset_before_limit() { #[test] fn test_distribute_by() { let sql = "select id from person distribute by state"; - let expected = "Repartition: DistributeBy(state)\ + let expected = "Repartition: DistributeBy(person.state)\ \n Projection: person.id\ \n TableScan: person"; quick_test(sql, expected); @@ -3612,7 +3494,7 @@ fn test_select_all_inner_join() { INNER JOIN orders \ ON orders.customer_id * 2 = person.id + 10"; - let expected = "Projection: person.id, person.first_name, person.last_name, person.age, person.state, person.salary, person.birth_date, person.😀, orders.order_id, orders.customer_id, orders.o_item_id, orders.qty, orders.price, orders.delivered\ + let expected = "Projection: *\ \n Inner Join: Filter: orders.customer_id * Int64(2) = person.id + Int64(10)\ \n TableScan: person\ \n TableScan: orders"; @@ -3646,7 +3528,7 @@ fn test_select_distinct_order_by() { let sql = "SELECT distinct '1' from person order by id"; let expected = - "Error during planning: For SELECT DISTINCT, ORDER BY expressions id must appear in select list"; + "Error during planning: For SELECT DISTINCT, ORDER BY expressions person.id must appear in select list"; // It should return error. let result = logical_plan(sql); @@ -3657,7 +3539,7 @@ fn test_select_distinct_order_by() { #[rstest] #[case::select_cluster_by_unsupported( - "SELECT customer_name, SUM(order_total) as total_order_amount FROM orders CLUSTER BY customer_name", + "SELECT customer_name, sum(order_total) as total_order_amount FROM orders CLUSTER BY customer_name", "This feature is not implemented: CLUSTER BY" )] #[case::select_lateral_view_unsupported( @@ -3795,7 +3677,7 @@ fn test_prepare_statement_to_plan_panic_prepare_wrong_syntax() { let sql = "PREPARE AS SELECT id, age FROM person WHERE age = $foo"; assert_eq!( logical_plan(sql).unwrap_err().strip_backtrace(), - "SQL error: ParserError(\"Expected AS, found: SELECT\")" + "SQL error: ParserError(\"Expected: AS, found: SELECT\")" ) } @@ -3836,7 +3718,7 @@ fn test_non_prepare_statement_should_infer_types() { #[test] #[should_panic( - expected = "value: SQL(ParserError(\"Expected [NOT] NULL or TRUE|FALSE or [NOT] DISTINCT FROM after IS, found: $1\"" + expected = "value: SQL(ParserError(\"Expected: [NOT] NULL or TRUE|FALSE or [NOT] DISTINCT FROM after IS, found: $1\"" )] fn test_prepare_statement_to_plan_panic_is_param() { let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age is $1"; @@ -3947,7 +3829,7 @@ fn test_prepare_statement_to_plan_params_as_constants() { /////////////////// // replace params with values let param_values = vec![ScalarValue::Int32(Some(10))]; - let expected_plan = "Projection: Int32(10)\n EmptyRelation"; + let expected_plan = "Projection: Int32(10) AS $1\n EmptyRelation"; prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); @@ -3963,7 +3845,8 @@ fn test_prepare_statement_to_plan_params_as_constants() { /////////////////// // replace params with values let param_values = vec![ScalarValue::Int32(Some(10))]; - let expected_plan = "Projection: Int64(1) + Int32(10)\n EmptyRelation"; + let expected_plan = + "Projection: Int64(1) + Int32(10) AS Int64(1) + $1\n EmptyRelation"; prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); @@ -3982,7 +3865,9 @@ fn test_prepare_statement_to_plan_params_as_constants() { ScalarValue::Int32(Some(10)), ScalarValue::Float64(Some(10.0)), ]; - let expected_plan = "Projection: Int64(1) + Int32(10) + Float64(10)\n EmptyRelation"; + let expected_plan = + "Projection: Int64(1) + Int32(10) + Float64(10) AS Int64(1) + $1 + $2\ + \n EmptyRelation"; prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); } @@ -4095,8 +3980,8 @@ fn test_prepare_statement_infer_types_subquery() { Projection: person.id, person.age Filter: person.age = () Subquery: - Projection: MAX(person.age) - Aggregate: groupBy=[[]], aggr=[[MAX(person.age)]] + Projection: max(person.age) + Aggregate: groupBy=[[]], aggr=[[max(person.age)]] Filter: person.id = $1 TableScan: person TableScan: person @@ -4116,8 +4001,8 @@ Projection: person.id, person.age Projection: person.id, person.age Filter: person.age = () Subquery: - Projection: MAX(person.age) - Aggregate: groupBy=[[]], aggr=[[MAX(person.age)]] + Projection: max(person.age) + Aggregate: groupBy=[[]], aggr=[[max(person.age)]] Filter: person.id = UInt32(10) TableScan: person TableScan: person @@ -4197,7 +4082,7 @@ fn test_prepare_statement_insert_infer() { \n Projection: column1 AS id, column2 AS first_name, column3 AS last_name, \ CAST(NULL AS Int32) AS age, CAST(NULL AS Utf8) AS state, CAST(NULL AS Float64) AS salary, \ CAST(NULL AS Timestamp(Nanosecond, None)) AS birth_date, CAST(NULL AS Int32) AS 😀\ - \n Values: (UInt32(1), Utf8(\"Alan\"), Utf8(\"Turing\"))"; + \n Values: (UInt32(1) AS $1, Utf8(\"Alan\") AS $2, Utf8(\"Turing\") AS $3)"; let plan = plan.replace_params_with_values(¶m_values).unwrap(); prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); @@ -4278,7 +4163,7 @@ fn test_prepare_statement_to_plan_multi_params() { ScalarValue::from("xyz"), ]; let expected_plan = - "Projection: person.id, person.age, Utf8(\"xyz\")\ + "Projection: person.id, person.age, Utf8(\"xyz\") AS $6\ \n Filter: person.age IN ([Int32(10), Int32(20)]) AND person.salary > Float64(100) AND person.salary < Float64(200) OR person.first_name < Utf8(\"abc\")\ \n TableScan: person"; @@ -4288,17 +4173,17 @@ fn test_prepare_statement_to_plan_multi_params() { #[test] fn test_prepare_statement_to_plan_having() { let sql = "PREPARE my_plan(INT, DOUBLE, DOUBLE, DOUBLE) AS - SELECT id, SUM(age) + SELECT id, sum(age) FROM person \ WHERE salary > $2 GROUP BY id - HAVING sum(age) < $1 AND SUM(age) > 10 OR SUM(age) in ($3, $4)\ + HAVING sum(age) < $1 AND sum(age) > 10 OR sum(age) in ($3, $4)\ "; let expected_plan = "Prepare: \"my_plan\" [Int32, Float64, Float64, Float64] \ - \n Projection: person.id, SUM(person.age)\ - \n Filter: SUM(person.age) < $1 AND SUM(person.age) > Int64(10) OR SUM(person.age) IN ([$3, $4])\ - \n Aggregate: groupBy=[[person.id]], aggr=[[SUM(person.age)]]\ + \n Projection: person.id, sum(person.age)\ + \n Filter: sum(person.age) < $1 AND sum(person.age) > Int64(10) OR sum(person.age) IN ([$3, $4])\ + \n Aggregate: groupBy=[[person.id]], aggr=[[sum(person.age)]]\ \n Filter: person.salary > $2\ \n TableScan: person"; @@ -4315,21 +4200,44 @@ fn test_prepare_statement_to_plan_having() { ScalarValue::Float64(Some(300.0)), ]; let expected_plan = - "Projection: person.id, SUM(person.age)\ - \n Filter: SUM(person.age) < Int32(10) AND SUM(person.age) > Int64(10) OR SUM(person.age) IN ([Float64(200), Float64(300)])\ - \n Aggregate: groupBy=[[person.id]], aggr=[[SUM(person.age)]]\ + "Projection: person.id, sum(person.age)\ + \n Filter: sum(person.age) < Int32(10) AND sum(person.age) > Int64(10) OR sum(person.age) IN ([Float64(200), Float64(300)])\ + \n Aggregate: groupBy=[[person.id]], aggr=[[sum(person.age)]]\ \n Filter: person.salary > Float64(100)\ \n TableScan: person"; prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); } +#[test] +fn test_prepare_statement_to_plan_limit() { + let sql = "PREPARE my_plan(BIGINT, BIGINT) AS + SELECT id FROM person \ + OFFSET $1 LIMIT $2"; + + let expected_plan = "Prepare: \"my_plan\" [Int64, Int64] \ + \n Limit: skip=$1, fetch=$2\ + \n Projection: person.id\ + \n TableScan: person"; + + let expected_dt = "[Int64, Int64]"; + + let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + + // replace params with values + let param_values = vec![ScalarValue::Int64(Some(10)), ScalarValue::Int64(Some(200))]; + let expected_plan = "Limit: skip=10, fetch=200\ + \n Projection: person.id\ + \n TableScan: person"; + prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); +} + #[test] fn test_prepare_statement_to_plan_value_list() { let sql = "PREPARE my_plan(STRING, STRING) AS SELECT * FROM (VALUES(1, $1), (2, $2)) AS t (num, letter);"; let expected_plan = "Prepare: \"my_plan\" [Utf8, Utf8] \ - \n Projection: t.num, t.letter\ + \n Projection: *\ \n SubqueryAlias: t\ \n Projection: column1 AS num, column2 AS letter\ \n Values: (Int64(1), $1), (Int64(2), $2)"; @@ -4344,10 +4252,10 @@ fn test_prepare_statement_to_plan_value_list() { ScalarValue::from("a".to_string()), ScalarValue::from("b".to_string()), ]; - let expected_plan = "Projection: t.num, t.letter\ + let expected_plan = "Projection: *\ \n SubqueryAlias: t\ \n Projection: column1 AS num, column2 AS letter\ - \n Values: (Int64(1), Utf8(\"a\")), (Int64(2), Utf8(\"b\"))"; + \n Values: (Int64(1), Utf8(\"a\") AS $1), (Int64(2), Utf8(\"b\") AS $2)"; prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); } @@ -4394,9 +4302,9 @@ fn test_table_alias() { (select age from person) t2 \ ) as f"; - let expected = "Projection: f.id, f.age\ + let expected = "Projection: *\ \n SubqueryAlias: f\ - \n CrossJoin:\ + \n Cross Join: \ \n SubqueryAlias: t1\ \n Projection: person.id\ \n TableScan: person\ @@ -4411,10 +4319,10 @@ fn test_table_alias() { (select age from person) t2 \ ) as f (c1, c2)"; - let expected = "Projection: f.c1, f.c2\ + let expected = "Projection: *\ \n SubqueryAlias: f\ \n Projection: t1.id AS c1, t2.age AS c2\ - \n CrossJoin:\ + \n Cross Join: \ \n SubqueryAlias: t1\ \n Projection: person.id\ \n TableScan: person\ @@ -4489,7 +4397,7 @@ fn test_field_not_found_window_function() { let qualified_sql = "SELECT order_id, MAX(qty) OVER (PARTITION BY orders.order_id) from orders"; - let expected = "Projection: orders.order_id, MAX(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\n WindowAggr: windowExpr=[[MAX(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\n TableScan: orders"; + let expected = "Projection: orders.order_id, max(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\n WindowAggr: windowExpr=[[max(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\n TableScan: orders"; quick_test(qualified_sql, expected); } @@ -4515,10 +4423,39 @@ fn test_parse_escaped_string_literal_value() { let sql = r"SELECT character_length(E'\000') AS len"; assert_eq!( logical_plan(sql).unwrap_err().strip_backtrace(), - "SQL error: TokenizerError(\"Unterminated encoded string literal at Line: 1, Column 25\")" + "SQL error: TokenizerError(\"Unterminated encoded string literal at Line: 1, Column: 25\")" ) } +#[test] +fn plan_create_index() { + let sql = + "CREATE UNIQUE INDEX IF NOT EXISTS idx_name ON test USING btree (name, age DESC)"; + let plan = logical_plan_with_options(sql, ParserOptions::default()).unwrap(); + match plan { + LogicalPlan::Ddl(DdlStatement::CreateIndex(CreateIndex { + name, + table, + using, + columns, + unique, + if_not_exists, + .. + })) => { + assert_eq!(name, Some("idx_name".to_string())); + assert_eq!(format!("{table}"), "test"); + assert_eq!(using, Some("btree".to_string())); + assert_eq!( + columns, + vec![col("name").sort(true, false), col("age").sort(false, true),] + ); + assert!(unique); + assert!(if_not_exists); + } + _ => panic!("wrong plan type"), + } +} + fn assert_field_not_found(err: DataFusionError, name: &str) { match err { DataFusionError::SchemaError { .. } => { @@ -4532,151 +4469,29 @@ fn assert_field_not_found(err: DataFusionError, name: &str) { } } -struct EmptyTable { - table_schema: SchemaRef, -} - -impl EmptyTable { - fn new(table_schema: SchemaRef) -> Self { - Self { table_schema } - } -} - -impl TableSource for EmptyTable { - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn schema(&self) -> SchemaRef { - self.table_schema.clone() - } +#[cfg(test)] +#[ctor::ctor] +fn init() { + // Enable RUST_LOG logging configuration for tests + let _ = env_logger::try_init(); } #[test] -fn roundtrip_expr() { - let tests: Vec<(TableReference, &str, &str)> = vec![ - ( - TableReference::bare("person"), - "age > 35", - r#"("age" > 35)"#, - ), - ( - TableReference::bare("person"), - "id = '10'", - r#"("id" = '10')"#, - ), - ( - TableReference::bare("person"), - "CAST(id AS VARCHAR)", - r#"CAST("id" AS VARCHAR)"#, - ), - ( - TableReference::bare("person"), - "SUM((age * 2))", - r#"SUM(("age" * 2))"#, - ), - ]; - - let roundtrip = |table, sql: &str| -> Result { - let dialect = GenericDialect {}; - let sql_expr = Parser::new(&dialect).try_with_sql(sql)?.parse_expr()?; - - let context = MockContextProvider::default(); - let schema = context.get_table_source(table)?.schema(); - let df_schema = DFSchema::try_from(schema.as_ref().clone())?; - let sql_to_rel = SqlToRel::new(&context); - let expr = - sql_to_rel.sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new())?; - - let ast = expr_to_sql(&expr)?; - - Ok(format!("{}", ast)) - }; - - for (table, query, expected) in tests { - let actual = roundtrip(table, query).unwrap(); - assert_eq!(actual, expected); - } -} +fn test_no_functions_registered() { + let sql = "SELECT foo()"; -#[test] -fn roundtrip_statement() -> Result<()> { - let tests: Vec<&str> = vec![ - "select ta.j1_id from j1 ta;", - "select ta.j1_id from j1 ta order by ta.j1_id;", - "select * from j1 ta order by ta.j1_id, ta.j1_string desc;", - "select * from j1 limit 10;", - "select ta.j1_id from j1 ta where ta.j1_id > 1;", - "select ta.j1_id, tb.j2_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id);", - "select ta.j1_id, tb.j2_string, tc.j3_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id) join j3 tc on (ta.j1_id = tc.j3_id);", - "select * from (select id, first_name from person)", - "select * from (select id, first_name from (select * from person))", - "select id, count(*) as cnt from (select id from person) group by id", - "select (id-1)/2, count(*) / (sum(id/10)-1) as agg_expr from (select (id-1) as id from person) group by id", - "select CAST(id/2 as VARCHAR) NOT LIKE 'foo*' from person where NOT EXISTS (select ta.j1_id, tb.j2_string from j1 ta join j2 tb on (ta.j1_id = tb.j2_id))", - r#"select "First Name" from person_quoted_cols"#, - r#"select id, count("First Name") as cnt from (select id, "First Name" from person_quoted_cols) group by id"#, - "select id, count(*) as cnt from (select p1.id as id from person p1 inner join person p2 on p1.id=p2.id) group by id", - "select id, count(*), first_name from person group by first_name, id", - "select id, sum(age), first_name from person group by first_name, id", - "select id, count(*), first_name - from person - where id!=3 and first_name=='test' - group by first_name, id - having count(*)>5 and count(*)<10 - order by count(*)", - r#"select id, count("First Name") as count_first_name, "Last Name" - from person_quoted_cols - where id!=3 and "First Name"=='test' - group by "Last Name", id - having count_first_name>5 and count_first_name<10 - order by count_first_name, "Last Name""#, - r#"select p.id, count("First Name") as count_first_name, - "Last Name", sum(qp.id/p.id - (select sum(id) from person_quoted_cols) ) / (select count(*) from person) - from (select id, "First Name", "Last Name" from person_quoted_cols) qp - inner join (select * from person) p - on p.id = qp.id - where p.id!=3 and "First Name"=='test' and qp.id in - (select id from (select id, count(*) from person group by id having count(*) > 0)) - group by "Last Name", p.id - having count_first_name>5 and count_first_name<10 - order by count_first_name, "Last Name""#, - ]; - - // For each test sql string, we transform as follows: - // sql -> ast::Statement (s1) -> LogicalPlan (p1) -> ast::Statement (s2) -> LogicalPlan (p2) - // We test not that s1==s2, but rather p1==p2. This ensures that unparser preserves the logical - // query information of the original sql string and disreguards other differences in syntax or - // quoting. - for query in tests { - let dialect = GenericDialect {}; - let statement = Parser::new(&dialect) - .try_with_sql(query)? - .parse_statement()?; - - let context = MockContextProvider::default(); - let sql_to_rel = SqlToRel::new(&context); - let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap(); - - let roundtrip_statement = plan_to_sql(&plan)?; - - let actual = format!("{}", &roundtrip_statement); - println!("roundtrip sql: {actual}"); - println!("plan {}", plan.display_indent()); - - let plan_roundtrip = sql_to_rel - .sql_statement_to_plan(roundtrip_statement.clone()) - .unwrap(); - - assert_eq!(plan, plan_roundtrip); - } + let options = ParserOptions::default(); + let dialect = &GenericDialect {}; + let state = MockSessionState::default(); + let context = MockContextProvider { state }; + let planner = SqlToRel::new_with_options(&context, options); + let result = DFParser::parse_sql_with_dialect(sql, dialect); + let mut ast = result.unwrap(); - Ok(()) -} + let err = planner.statement_to_plan(ast.pop_front().unwrap()); -#[cfg(test)] -#[ctor::ctor] -fn init() { - // Enable RUST_LOG logging configuration for tests - let _ = env_logger::try_init(); + assert_contains!( + err.unwrap_err().to_string(), + "Internal error: No functions registered with this context." + ); } diff --git a/datafusion/sqllogictest/Cargo.toml b/datafusion/sqllogictest/Cargo.toml index c652c8041ff1..07dbc60e86bc 100644 --- a/datafusion/sqllogictest/Cargo.toml +++ b/datafusion/sqllogictest/Cargo.toml @@ -39,8 +39,8 @@ async-trait = { workspace = true } bigdecimal = { workspace = true } bytes = { workspace = true, optional = true } chrono = { workspace = true, optional = true } -clap = { version = "4.4.8", features = ["derive", "env"] } -datafusion = { workspace = true, default-features = true } +clap = { version = "4.5.16", features = ["derive", "env"] } +datafusion = { workspace = true, default-features = true, features = ["avro"] } datafusion-common = { workspace = true, default-features = true } datafusion-common-runtime = { workspace = true, default-features = true } futures = { workspace = true } @@ -51,7 +51,7 @@ object_store = { workspace = true } postgres-protocol = { version = "0.6.4", optional = true } postgres-types = { version = "0.2.4", optional = true } rust_decimal = { version = "1.27.0" } -sqllogictest = "0.20.0" +sqllogictest = "0.22.0" sqlparser = { workspace = true } tempfile = { workspace = true } thiserror = { workspace = true } @@ -60,7 +60,13 @@ tokio-postgres = { version = "0.7.7", optional = true } [features] avro = ["datafusion/avro"] -postgres = ["bytes", "chrono", "tokio-postgres", "postgres-types", "postgres-protocol"] +postgres = [ + "bytes", + "chrono", + "tokio-postgres", + "postgres-types", + "postgres-protocol", +] [dev-dependencies] env_logger = { workspace = true } diff --git a/datafusion/sqllogictest/README.md b/datafusion/sqllogictest/README.md index 930df4796776..5becc75c985a 100644 --- a/datafusion/sqllogictest/README.md +++ b/datafusion/sqllogictest/README.md @@ -133,7 +133,7 @@ In order to run the sqllogictests running against a previously running Postgres PG_COMPAT=true PG_URI="postgresql://postgres@127.0.0.1/postgres" cargo test --features=postgres --test sqllogictests ``` -The environemnt variables: +The environment variables: 1. `PG_COMPAT` instructs sqllogictest to run against Postgres (not DataFusion) 2. `PG_URI` contains a `libpq` style connection string, whose format is described in @@ -225,7 +225,7 @@ query ``` -- `test_name`: Uniquely identify the test name (Datafusion only) +- `test_name`: Uniquely identify the test name (DataFusion only) - `type_string`: A short string that specifies the number of result columns and the expected datatype of each result column. There is one character in the for each result column. The characters codes are: - 'B' - **B**oolean, diff --git a/datafusion/sqllogictest/bin/sqllogictests.rs b/datafusion/sqllogictest/bin/sqllogictests.rs index 560328ee8619..c3e739d146c6 100644 --- a/datafusion/sqllogictest/bin/sqllogictests.rs +++ b/datafusion/sqllogictest/bin/sqllogictests.rs @@ -18,12 +18,11 @@ use std::ffi::OsStr; use std::fs; use std::path::{Path, PathBuf}; -#[cfg(target_family = "windows")] -use std::thread; use clap::Parser; use datafusion_sqllogictest::{DataFusion, TestContext}; use futures::stream::StreamExt; +use itertools::Itertools; use log::info; use sqllogictest::strict_column_validator; @@ -33,28 +32,29 @@ use datafusion_common_runtime::SpawnedTask; const TEST_DIRECTORY: &str = "test_files/"; const PG_COMPAT_FILE_PREFIX: &str = "pg_compat_"; -#[cfg(target_family = "windows")] -pub fn main() { - // Tests from `tpch/tpch.slt` fail with stackoverflow with the default stack size. - thread::Builder::new() - .stack_size(2 * 1024 * 1024) // 2 MB - .spawn(move || { - tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .unwrap() - .block_on(async { run_tests().await }) - .unwrap() - }) +pub fn main() -> Result<()> { + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() .unwrap() - .join() - .unwrap(); + .block_on(run_tests()) } -#[tokio::main] -#[cfg(not(target_family = "windows"))] -pub async fn main() -> Result<()> { - run_tests().await +fn value_validator(actual: &[Vec], expected: &[String]) -> bool { + let expected = expected + .iter() + // Trailing whitespace from lines in SLT will typically be removed, but do not fail if it is not + // If particular test wants to cover trailing whitespace on a value, + // it should project additional non-whitespace column on the right. + .map(|s| s.trim_end().to_owned()) + .collect::>(); + let actual = actual + .iter() + .map(|strs| strs.iter().join(" ")) + // Editors do not preserve trailing whitespace, so expected may or may not lack it included + .map(|s| s.trim_end().to_owned()) + .collect::>(); + actual == expected } /// Sets up an empty directory at test_files/scratch/ @@ -79,7 +79,16 @@ async fn run_tests() -> Result<()> { // Enable logging (e.g. set RUST_LOG=debug to see debug logs) env_logger::init(); - let options: Options = clap::Parser::parse(); + let options: Options = Parser::parse(); + if options.list { + // nextest parses stdout, so print messages to stderr + eprintln!("NOTICE: --list option unsupported, quitting"); + // return Ok, not error so that tools like nextest which are listing all + // workspace tests (by running `cargo test ... --list --format terse`) + // do not fail when they encounter this binary. Instead, print nothing + // to stdout and return OK so they can continue listing other tests. + return Ok(()); + } options.warn_on_ignored(); // Run all tests in parallel, reporting failures at the end @@ -149,6 +158,7 @@ async fn run_test_file(test_file: TestFile) -> Result<()> { )) }); runner.with_column_validator(strict_column_validator); + runner.with_validator(value_validator); runner .run_file_async(path) .await @@ -167,6 +177,7 @@ async fn run_test_file_with_postgres(test_file: TestFile) -> Result<()> { let mut runner = sqllogictest::Runner::new(|| Postgres::connect(relative_path.clone())); runner.with_column_validator(strict_column_validator); + runner.with_validator(value_validator); runner .run_file_async(path) .await @@ -185,7 +196,6 @@ async fn run_complete_file(test_file: TestFile) -> Result<()> { path, relative_path, } = test_file; - use sqllogictest::default_validator; info!("Using complete mode to complete: {}", path.display()); @@ -205,7 +215,7 @@ async fn run_complete_file(test_file: TestFile) -> Result<()> { .update_test_file( path, col_separator, - default_validator, + value_validator, strict_column_validator, ) .await @@ -273,7 +283,7 @@ fn read_dir_recursive>(path: P) -> Result> { /// Append all paths recursively to dst fn read_dir_recursive_impl(dst: &mut Vec, path: &Path) -> Result<()> { - let entries = std::fs::read_dir(path) + let entries = fs::read_dir(path) .map_err(|e| exec_datafusion_err!("Error reading directory {path:?}: {e}"))?; for entry in entries { let path = entry @@ -294,7 +304,7 @@ fn read_dir_recursive_impl(dst: &mut Vec, path: &Path) -> Result<()> { /// Parsed command line options /// -/// This structure attempts to mimic the command line options +/// This structure attempts to mimic the command line options of the built in rust test runner /// accepted by IDEs such as CLion that pass arguments /// /// See for more details @@ -338,6 +348,18 @@ struct Options { help = "IGNORED (for compatibility with built in rust test runner)" )] show_output: bool, + + #[clap( + long, + help = "Quits immediately, not listing anything (for compatibility with built-in rust test runner)" + )] + list: bool, + + #[clap( + long, + help = "IGNORED (for compatibility with built-in rust test runner)" + )] + ignored: bool, } impl Options { @@ -372,15 +394,15 @@ impl Options { /// Logs warning messages to stdout if any ignored options are passed fn warn_on_ignored(&self) { if self.format.is_some() { - println!("WARNING: Ignoring `--format` compatibility option"); + eprintln!("WARNING: Ignoring `--format` compatibility option"); } if self.z_options.is_some() { - println!("WARNING: Ignoring `-Z` compatibility option"); + eprintln!("WARNING: Ignoring `-Z` compatibility option"); } if self.show_output { - println!("WARNING: Ignoring `--show-output` compatibility option"); + eprintln!("WARNING: Ignoring `--show-output` compatibility option"); } } } diff --git a/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs index 520b6b53b32d..8337d2e9a39c 100644 --- a/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs +++ b/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +use crate::engines::output::DFColumnType; +use arrow::array::{Array, AsArray}; use arrow::datatypes::Fields; use arrow::util::display::ArrayFormatter; use arrow::{array, array::ArrayRef, datatypes::DataType, record_batch::RecordBatch}; @@ -23,8 +25,6 @@ use datafusion_common::DataFusionError; use std::path::PathBuf; use std::sync::OnceLock; -use crate::engines::output::DFColumnType; - use super::super::conversion::*; use super::error::{DFSqlLogicTestError, Result}; @@ -233,6 +233,16 @@ pub fn cell_to_string(col: &ArrayRef, row: usize) -> Result { DataType::Utf8 => { Ok(varchar_to_str(get_row_value!(array::StringArray, col, row))) } + DataType::Utf8View => Ok(varchar_to_str(get_row_value!( + array::StringViewArray, + col, + row + ))), + DataType::Dictionary(_, _) => { + let dict = col.as_any_dictionary(); + let key = dict.normalized_keys()[row]; + Ok(cell_to_string(dict.values(), key)?) + } _ => { let f = ArrayFormatter::try_new(col.as_ref(), &DEFAULT_FORMAT_OPTIONS); Ok(f.unwrap().value(row).to_string()) @@ -262,12 +272,25 @@ pub(crate) fn convert_schema_to_types(columns: &Fields) -> Vec { | DataType::Float64 | DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => DFColumnType::Float, - DataType::Utf8 | DataType::LargeUtf8 => DFColumnType::Text, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => { + DFColumnType::Text + } DataType::Date32 | DataType::Date64 | DataType::Time32(_) | DataType::Time64(_) => DFColumnType::DateTime, DataType::Timestamp(_, _) => DFColumnType::Timestamp, + DataType::Dictionary(key_type, value_type) => { + if key_type.is_integer() { + // mapping dictionary string types to Text + match value_type.as_ref() { + DataType::Utf8 | DataType::LargeUtf8 => DFColumnType::Text, + _ => DFColumnType::Another, + } + } else { + DFColumnType::Another + } + } _ => DFColumnType::Another, }) .collect() diff --git a/datafusion/sqllogictest/src/engines/datafusion_engine/runner.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/runner.rs index afd0a241ca5e..5c24b49cfe86 100644 --- a/datafusion/sqllogictest/src/engines/datafusion_engine/runner.rs +++ b/datafusion/sqllogictest/src/engines/datafusion_engine/runner.rs @@ -15,10 +15,13 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; use std::{path::PathBuf, time::Duration}; use arrow::record_batch::RecordBatch; use async_trait::async_trait; +use datafusion::physical_plan::common::collect; +use datafusion::physical_plan::execute_stream; use datafusion::prelude::SessionContext; use log::info; use sqllogictest::DBOutput; @@ -69,9 +72,12 @@ impl sqllogictest::AsyncDB for DataFusion { async fn run_query(ctx: &SessionContext, sql: impl Into) -> Result { let df = ctx.sql(sql.into().as_str()).await?; + let task_ctx = Arc::new(df.task_ctx()); + let plan = df.create_physical_plan().await?; - let types = normalize::convert_schema_to_types(df.schema().fields()); - let results: Vec = df.collect().await?; + let stream = execute_stream(plan, task_ctx)?; + let types = normalize::convert_schema_to_types(stream.schema().fields()); + let results: Vec = collect(stream).await?; let rows = normalize::convert_batches(results)?; if rows.is_empty() && types.is_empty() { diff --git a/datafusion/sqllogictest/src/test_context.rs b/datafusion/sqllogictest/src/test_context.rs index dd27727e3ad5..477f225443e2 100644 --- a/datafusion/sqllogictest/src/test_context.rs +++ b/datafusion/sqllogictest/src/test_context.rs @@ -27,12 +27,12 @@ use arrow::array::{ }; use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit}; use arrow::record_batch::RecordBatch; -use datafusion::execution::context::SessionState; use datafusion::logical_expr::{create_udf, ColumnarValue, Expr, ScalarUDF, Volatility}; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionConfig; use datafusion::{ - catalog::{schema::MemorySchemaProvider, CatalogProvider, MemoryCatalogProvider}, + catalog::CatalogProvider, + catalog_common::{memory::MemoryCatalogProvider, memory::MemorySchemaProvider}, datasource::{MemTable, TableProvider, TableType}, prelude::{CsvReadOptions, SessionContext}, }; @@ -40,6 +40,7 @@ use datafusion_common::cast::as_float64_array; use datafusion_common::DataFusionError; use async_trait::async_trait; +use datafusion::catalog::Session; use log::info; use tempfile::TempDir; @@ -97,6 +98,9 @@ impl TestContext { return None; } } + "dynamic_file.slt" => { + test_ctx.ctx = test_ctx.ctx.enable_url_table(); + } "joins.slt" => { info!("Registering partition table tables"); let example_udf = create_example_udf(); @@ -135,7 +139,7 @@ impl TestContext { } #[cfg(feature = "avro")] -pub async fn register_avro_tables(ctx: &mut crate::TestContext) { +pub async fn register_avro_tables(ctx: &mut TestContext) { use datafusion::prelude::AvroReadOptions; ctx.enable_testdir(); @@ -203,6 +207,7 @@ pub async fn register_partition_table(test_ctx: &mut TestContext) { // registers a LOCAL TEMPORARY table. pub async fn register_temp_table(ctx: &SessionContext) { + #[derive(Debug)] struct TestTable(TableType); #[async_trait] @@ -221,7 +226,7 @@ pub async fn register_temp_table(ctx: &SessionContext) { async fn scan( &self, - _state: &SessionState, + _state: &dyn Session, _: Option<&Vec>, _: &[Expr], _: Option, @@ -309,17 +314,49 @@ pub async fn register_metadata_tables(ctx: &SessionContext) { String::from("metadata_key"), String::from("the name field"), )])); - - let schema = Schema::new(vec![id, name]).with_metadata(HashMap::from([( - String::from("metadata_key"), - String::from("the entire schema"), - )])); + let l_name = + Field::new("l_name", DataType::Utf8, true).with_metadata(HashMap::from([( + String::from("metadata_key"), + String::from("the l_name field"), + )])); + + let ts = Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), false) + .with_metadata(HashMap::from([( + String::from("metadata_key"), + String::from("ts non-nullable field"), + )])); + + let nonnull_name = + Field::new("nonnull_name", DataType::Utf8, false).with_metadata(HashMap::from([ + ( + String::from("metadata_key"), + String::from("the nonnull_name field"), + ), + ])); + + let schema = Schema::new(vec![id, name, l_name, ts, nonnull_name]).with_metadata( + HashMap::from([( + String::from("metadata_key"), + String::from("the entire schema"), + )]), + ); let batch = RecordBatch::try_new( Arc::new(schema), vec![ Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])) as _, Arc::new(StringArray::from(vec![None, Some("bar"), Some("baz")])) as _, + Arc::new(StringArray::from(vec![None, Some("l_bar"), Some("l_baz")])) as _, + Arc::new(TimestampNanosecondArray::from(vec![ + 1599572549190855123, + 1599572549190855123, + 1599572549190855123, + ])) as _, + Arc::new(StringArray::from(vec![ + Some("no_foo"), + Some("no_bar"), + Some("no_baz"), + ])) as _, ], ) .unwrap(); @@ -355,7 +392,7 @@ fn create_example_udf() -> ScalarUDF { // Expects two f64 values: vec![DataType::Float64, DataType::Float64], // Returns an f64 value: - Arc::new(DataType::Float64), + DataType::Float64, Volatility::Immutable, adder, ) diff --git a/datafusion/sqllogictest/test_files/agg_func_substitute.slt b/datafusion/sqllogictest/test_files/agg_func_substitute.slt index 7beb20a52134..9a0a1d587433 100644 --- a/datafusion/sqllogictest/test_files/agg_func_substitute.slt +++ b/datafusion/sqllogictest/test_files/agg_func_substitute.slt @@ -27,10 +27,10 @@ CREATE EXTERNAL TABLE multiple_ordered_table ( d INTEGER ) STORED AS CSV -WITH HEADER ROW WITH ORDER (a ASC, b ASC) WITH ORDER (c ASC) -LOCATION '../../datafusion/core/tests/data/window_2.csv'; +LOCATION '../../datafusion/core/tests/data/window_2.csv' +OPTIONS ('format.has_header' 'true'); query TT @@ -39,16 +39,16 @@ EXPLAIN SELECT a, ARRAY_AGG(c ORDER BY c)[1] as result GROUP BY a; ---- logical_plan -01)Projection: multiple_ordered_table.a, NTH_VALUE(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] AS result -02)--Aggregate: groupBy=[[multiple_ordered_table.a]], aggr=[[NTH_VALUE(multiple_ordered_table.c, Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]]] +01)Projection: multiple_ordered_table.a, nth_value(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] AS result +02)--Aggregate: groupBy=[[multiple_ordered_table.a]], aggr=[[nth_value(multiple_ordered_table.c, Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]]] 03)----TableScan: multiple_ordered_table projection=[a, c] physical_plan -01)ProjectionExec: expr=[a@0 as a, NTH_VALUE(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]@1 as result] -02)--AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[NTH_VALUE(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted -03)----SortExec: expr=[a@0 ASC NULLS LAST] +01)ProjectionExec: expr=[a@0 as a, nth_value(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]@1 as result] +02)--AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[nth_value(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted +03)----SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] 04)------CoalesceBatchesExec: target_batch_size=8192 05)--------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 -06)----------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[NTH_VALUE(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted +06)----------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[nth_value(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted 07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 08)--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true @@ -59,16 +59,16 @@ EXPLAIN SELECT a, NTH_VALUE(c, 1 ORDER BY c) as result GROUP BY a; ---- logical_plan -01)Projection: multiple_ordered_table.a, NTH_VALUE(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] AS result -02)--Aggregate: groupBy=[[multiple_ordered_table.a]], aggr=[[NTH_VALUE(multiple_ordered_table.c, Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]]] +01)Projection: multiple_ordered_table.a, nth_value(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] AS result +02)--Aggregate: groupBy=[[multiple_ordered_table.a]], aggr=[[nth_value(multiple_ordered_table.c, Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]]] 03)----TableScan: multiple_ordered_table projection=[a, c] physical_plan -01)ProjectionExec: expr=[a@0 as a, NTH_VALUE(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]@1 as result] -02)--AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[NTH_VALUE(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted -03)----SortExec: expr=[a@0 ASC NULLS LAST] +01)ProjectionExec: expr=[a@0 as a, nth_value(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]@1 as result] +02)--AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[nth_value(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted +03)----SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] 04)------CoalesceBatchesExec: target_batch_size=8192 05)--------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 -06)----------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[NTH_VALUE(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted +06)----------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[nth_value(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted 07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 08)--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true @@ -78,16 +78,16 @@ EXPLAIN SELECT a, ARRAY_AGG(c ORDER BY c)[1 + 100] as result GROUP BY a; ---- logical_plan -01)Projection: multiple_ordered_table.a, NTH_VALUE(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] AS result -02)--Aggregate: groupBy=[[multiple_ordered_table.a]], aggr=[[NTH_VALUE(multiple_ordered_table.c, Int64(101)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] AS NTH_VALUE(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]]] +01)Projection: multiple_ordered_table.a, nth_value(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] AS result +02)--Aggregate: groupBy=[[multiple_ordered_table.a]], aggr=[[nth_value(multiple_ordered_table.c, Int64(101)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] AS nth_value(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]]] 03)----TableScan: multiple_ordered_table projection=[a, c] physical_plan -01)ProjectionExec: expr=[a@0 as a, NTH_VALUE(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]@1 as result] -02)--AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[NTH_VALUE(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted -03)----SortExec: expr=[a@0 ASC NULLS LAST] +01)ProjectionExec: expr=[a@0 as a, nth_value(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]@1 as result] +02)--AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[nth_value(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted +03)----SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] 04)------CoalesceBatchesExec: target_batch_size=8192 05)--------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 -06)----------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[NTH_VALUE(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted +06)----------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[nth_value(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted 07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 08)--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 8b5b84e76650..917e037682f2 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -35,8 +35,8 @@ CREATE EXTERNAL TABLE aggregate_test_100 ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); statement ok CREATE TABLE d_table (c1 decimal(10,3), c2 varchar) @@ -72,32 +72,43 @@ CREATE TABLE test (c1 BIGINT,c2 BIGINT) as values ####### # https://github.com/apache/datafusion/issues/3353 -statement error DataFusion error: Schema error: Schema contains duplicate unqualified field name "APPROX_DISTINCT\(aggregate_test_100\.c9\)" +statement error DataFusion error: Schema error: Schema contains duplicate unqualified field name "approx_distinct\(aggregate_test_100\.c9\)" SELECT approx_distinct(c9) count_c9, approx_distinct(cast(c9 as varchar)) count_c9_str FROM aggregate_test_100 # csv_query_approx_percentile_cont_with_weight -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT_WITH_WEIGHT\(Utf8, Int8, Float64\)'. You might need to add explicit type casts. +statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Utf8, Int8, Float64\] to the signature OneOf(.*) failed(.|\n)* SELECT approx_percentile_cont_with_weight(c1, c2, 0.95) FROM aggregate_test_100 -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT_WITH_WEIGHT\(Int16, Utf8, Float64\)'\. You might need to add explicit type casts\. +statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Int16, Utf8, Float64\] to the signature OneOf(.*) failed(.|\n)* SELECT approx_percentile_cont_with_weight(c3, c1, 0.95) FROM aggregate_test_100 -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT_WITH_WEIGHT\(Int16, Int8, Utf8\)'\. You might need to add explicit type casts\. +statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Int16, Int8, Utf8\] to the signature OneOf(.*) failed(.|\n)* SELECT approx_percentile_cont_with_weight(c3, c2, c1) FROM aggregate_test_100 # csv_query_approx_percentile_cont_with_histogram_bins -statement error This feature is not implemented: Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal \(got data type Int64\). +statement error DataFusion error: External error: This feature is not implemented: Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal \(got data type Int64\)\. SELECT c1, approx_percentile_cont(c3, 0.95, -1000) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT\(Int16, Float64, Utf8\)'\. You might need to add explicit type casts\. +statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Int16, Float64, Utf8\] to the signature OneOf(.*) failed(.|\n)* SELECT approx_percentile_cont(c3, 0.95, c1) FROM aggregate_test_100 -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT\(Int16, Float64, Float64\)'\. You might need to add explicit type casts\. +statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Int16, Float64, Float64\] to the signature OneOf(.*) failed(.|\n)* SELECT approx_percentile_cont(c3, 0.95, 111.1) FROM aggregate_test_100 -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT\(Float64, Float64, Float64\)'\. You might need to add explicit type casts\. +statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Float64, Float64, Float64\] to the signature OneOf(.*) failed(.|\n)* SELECT approx_percentile_cont(c12, 0.95, 111.1) FROM aggregate_test_100 +statement error DataFusion error: This feature is not implemented: Percentile value for 'APPROX_PERCENTILE_CONT' must be a literal +SELECT approx_percentile_cont(c12, c12) FROM aggregate_test_100 + +statement error DataFusion error: This feature is not implemented: Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be a literal +SELECT approx_percentile_cont(c12, 0.95, c5) FROM aggregate_test_100 + +# Not supported over sliding windows +query error This feature is not implemented: Aggregate can not be used as a sliding accumulator because `retract_batch` is not implemented +SELECT approx_percentile_cont(c3, 0.5) OVER (ROWS BETWEEN 4 PRECEDING AND CURRENT ROW) +FROM aggregate_test_100 + # array agg can use order by query ? SELECT array_agg(c13 ORDER BY c13) @@ -116,8 +127,8 @@ c2 INT NOT NULL, c3 INT NOT NULL ) STORED AS CSV -WITH HEADER ROW -LOCATION '../core/tests/data/aggregate_agg_multi_order.csv'; +LOCATION '../core/tests/data/aggregate_agg_multi_order.csv' +OPTIONS ('format.has_header' 'true'); # test array_agg with order by multiple columns query ? @@ -129,13 +140,13 @@ query TT explain select array_agg(c1 order by c2 desc, c3) from agg_order; ---- logical_plan -01)Aggregate: groupBy=[[]], aggr=[[ARRAY_AGG(agg_order.c1) ORDER BY [agg_order.c2 DESC NULLS FIRST, agg_order.c3 ASC NULLS LAST]]] +01)Aggregate: groupBy=[[]], aggr=[[array_agg(agg_order.c1) ORDER BY [agg_order.c2 DESC NULLS FIRST, agg_order.c3 ASC NULLS LAST]]] 02)--TableScan: agg_order projection=[c1, c2, c3] physical_plan -01)AggregateExec: mode=Final, gby=[], aggr=[ARRAY_AGG(agg_order.c1) ORDER BY [agg_order.c2 DESC NULLS FIRST, agg_order.c3 ASC NULLS LAST]] +01)AggregateExec: mode=Final, gby=[], aggr=[array_agg(agg_order.c1) ORDER BY [agg_order.c2 DESC NULLS FIRST, agg_order.c3 ASC NULLS LAST]] 02)--CoalescePartitionsExec -03)----AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(agg_order.c1) ORDER BY [agg_order.c2 DESC NULLS FIRST, agg_order.c3 ASC NULLS LAST]] -04)------SortExec: expr=[c2@1 DESC,c3@2 ASC NULLS LAST] +03)----AggregateExec: mode=Partial, gby=[], aggr=[array_agg(agg_order.c1) ORDER BY [agg_order.c2 DESC NULLS FIRST, agg_order.c3 ASC NULLS LAST]] +04)------SortExec: expr=[c2@1 DESC, c3@2 ASC NULLS LAST], preserve_partitioning=[true] 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/aggregate_agg_multi_order.csv]]}, projection=[c1, c2, c3], has_header=true @@ -183,7 +194,7 @@ CREATE TABLE array_agg_distinct_list_table AS VALUES ; # Apply array_sort to have deterministic result, higher dimension nested array also works but not for array sort, -# so they are covered in `datafusion/physical-expr/src/aggregate/array_agg_distinct.rs` +# so they are covered in `datafusion/functions-aggregate/src/array_agg.rs` query ?? select array_sort(c1), array_sort(c2) from ( select array_agg(distinct column1) as c1, array_agg(distinct column2) as c2 from array_agg_distinct_list_table @@ -194,10 +205,77 @@ select array_sort(c1), array_sort(c2) from ( statement ok drop table array_agg_distinct_list_table; -statement error This feature is not implemented: LIMIT not supported in ARRAY_AGG: 1 +statement error This feature is not implemented: Calling array_agg: LIMIT not supported in function arguments: 1 SELECT array_agg(c13 LIMIT 1) FROM aggregate_test_100 +# Test distinct aggregate function with merge batch +query II +with A as ( + select 1 as id, 2 as foo + UNION ALL + select 1, null + UNION ALL + select 1, null + UNION ALL + select 1, 3 + UNION ALL + select 1, 2 + ---- The order is non-deterministic, verify with length +) select array_length(array_agg(distinct a.foo)), sum(distinct 1) from A a group by a.id; +---- +3 1 + +# It has only AggregateExec with FinalPartitioned mode, so `merge_batch` is used +# If the plan is changed, whether the `merge_batch` is used should be verified to ensure the test coverage +query TT +explain with A as ( + select 1 as id, 2 as foo + UNION ALL + select 1, null + UNION ALL + select 1, null + UNION ALL + select 1, 3 + UNION ALL + select 1, 2 +) select array_length(array_agg(distinct a.foo)), sum(distinct 1) from A a group by a.id; +---- +logical_plan +01)Projection: array_length(array_agg(DISTINCT a.foo)), sum(DISTINCT Int64(1)) +02)--Aggregate: groupBy=[[a.id]], aggr=[[array_agg(DISTINCT a.foo), sum(DISTINCT Int64(1))]] +03)----SubqueryAlias: a +04)------SubqueryAlias: a +05)--------Union +06)----------Projection: Int64(1) AS id, Int64(2) AS foo +07)------------EmptyRelation +08)----------Projection: Int64(1) AS id, Int64(NULL) AS foo +09)------------EmptyRelation +10)----------Projection: Int64(1) AS id, Int64(NULL) AS foo +11)------------EmptyRelation +12)----------Projection: Int64(1) AS id, Int64(3) AS foo +13)------------EmptyRelation +14)----------Projection: Int64(1) AS id, Int64(2) AS foo +15)------------EmptyRelation +physical_plan +01)ProjectionExec: expr=[array_length(array_agg(DISTINCT a.foo)@1) as array_length(array_agg(DISTINCT a.foo)), sum(DISTINCT Int64(1))@2 as sum(DISTINCT Int64(1))] +02)--AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[array_agg(DISTINCT a.foo), sum(DISTINCT Int64(1))] +03)----CoalesceBatchesExec: target_batch_size=8192 +04)------RepartitionExec: partitioning=Hash([id@0], 4), input_partitions=5 +05)--------AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[array_agg(DISTINCT a.foo), sum(DISTINCT Int64(1))], ordering_mode=Sorted +06)----------UnionExec +07)------------ProjectionExec: expr=[1 as id, 2 as foo] +08)--------------PlaceholderRowExec +09)------------ProjectionExec: expr=[1 as id, NULL as foo] +10)--------------PlaceholderRowExec +11)------------ProjectionExec: expr=[1 as id, NULL as foo] +12)--------------PlaceholderRowExec +13)------------ProjectionExec: expr=[1 as id, 3 as foo] +14)--------------PlaceholderRowExec +15)------------ProjectionExec: expr=[1 as id, 2 as foo] +16)--------------PlaceholderRowExec + + # FIX: custom absolute values # csv_query_avg_multi_batch @@ -382,6 +460,15 @@ SELECT var(distinct c2) FROM aggregate_test_100 statement error DataFusion error: This feature is not implemented: VAR\(DISTINCT\) aggregations are not available SELECT var(c2), var(distinct c2) FROM aggregate_test_100 +# csv_query_distinct_variance_population +query R +SELECT var_pop(distinct c2) FROM aggregate_test_100 +---- +2 + +statement error DataFusion error: This feature is not implemented: VAR_POP\(DISTINCT\) aggregations are not available +SELECT var_pop(c2), var_pop(distinct c2) FROM aggregate_test_100 + # csv_query_variance_5 query R SELECT var_samp(c2) FROM aggregate_test_100 @@ -424,6 +511,85 @@ select stddev(sq.column1) from (values (1.1), (2.0), (3.0)) as sq ---- 0.950438495292 +# csv_query_stddev_7 +query IR +SELECT c2, stddev_samp(c12) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2 +---- +1 0.303641032262 +2 0.284581967411 +3 0.296002660506 +4 0.284324609109 +5 0.331034486752 + +# csv_query_stddev_8 +query IR +SELECT c2, stddev_pop(c12) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2 +---- +1 0.296659845456 +2 0.278038978602 +3 0.288107833475 +4 0.278074953424 +5 0.318992813225 + +# csv_query_stddev_9 +query IR +SELECT c2, var_pop(c12) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2 +---- +1 0.088007063906 +2 0.077305673622 +3 0.083006123709 +4 0.077325679722 +5 0.101756414889 + +# csv_query_stddev_10 +query IR +SELECT c2, var_samp(c12) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2 +---- +1 0.092197876473 +2 0.080986896176 +3 0.087617575027 +4 0.080840483345 +5 0.109583831419 + +# csv_query_stddev_11 +query IR +SELECT c2, var_samp(c12) FROM aggregate_test_100 WHERE c12 > 0.90 GROUP BY c2 ORDER BY c2 +---- +1 0.000889240174 +2 0.000785878272 +3 NULL +4 NULL +5 0.000269544643 + +# Use PostgresSQL dialect +statement ok +set datafusion.sql_parser.dialect = 'Postgres'; + +# csv_query_stddev_12 +query IR +SELECT c2, var_samp(c12) FILTER (WHERE c12 > 0.95) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2 +---- +1 0.000791243479 +2 0.000061521903 +3 NULL +4 NULL +5 NULL + +# Restore the default dialect +statement ok +set datafusion.sql_parser.dialect = 'Generic'; + +# csv_query_stddev_13 +query IR +SELECT c2, var_samp(CASE WHEN c12 > 0.90 THEN c12 ELSE null END) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2 +---- +1 0.000889240174 +2 0.000785878272 +3 NULL +4 NULL +5 0.000269544643 + + # csv_query_approx_median_1 query I SELECT approx_median(c2) FROM aggregate_test_100 @@ -442,6 +608,11 @@ SELECT approx_median(c12) FROM aggregate_test_100 ---- 0.555006541052 +# csv_query_approx_median_4 +# test with string, approx median only supports numeric +statement error +SELECT approx_median(c1) FROM aggregate_test_100 + # csv_query_median_1 query I SELECT median(c2) FROM aggregate_test_100 @@ -472,8 +643,10 @@ SELECT median(distinct col_i8) FROM median_table ---- 100 -statement error DataFusion error: This feature is not implemented: MEDIAN\(DISTINCT\) aggregations are not available +query II SELECT median(col_i8), median(distinct col_i8) FROM median_table +---- +-14 100 # approx_distinct_median_i8 query I @@ -550,6 +723,290 @@ SELECT approx_median(col_f64_nan) FROM median_table ---- NaN +# median decimal +statement ok +create table t(c decimal(10, 4)) as values (0.0001), (0.0002), (0.0003), (0.0004), (0.0005), (0.0006); + +query RT +select median(c), arrow_typeof(median(c)) from t; +---- +0.0003 Decimal128(10, 4) + +query RT +select approx_median(c), arrow_typeof(approx_median(c)) from t; +---- +0.00035 Float64 + +statement ok +drop table t; + +# median decimal with nulls +statement ok +create table t(c decimal(10, 4)) as values (0.0001), (null), (0.0003), (0.0004), (0.0005); + +query RT +select median(c), arrow_typeof(median(c)) from t; +---- +0.0003 Decimal128(10, 4) + +statement ok +drop table t; + +# median decimal with all nulls +statement ok +create table t(c decimal(10, 4)) as values (null), (null), (null); + +query RT +select median(c), arrow_typeof(median(c)) from t; +---- +NULL Decimal128(10, 4) + +statement ok +drop table t; + +# median odd +statement ok +create table t(c int) as values (1), (2), (3), (4), (5); + +query I +select median(c) from t; +---- +3 + +statement ok +drop table t; + +# median even +statement ok +create table t(c int) as values (1), (2), (3), (4), (5), (6); + +query I +select median(c) from t; +---- +3 + +statement ok +drop table t; + +# median with nulls +statement ok +create table t(c int) as values (1), (null), (3), (4), (5); + +query I +select median(c) from t; +---- +3 + +statement ok +drop table t; + +# median with all nulls +statement ok +create table t(c int) as values (null), (null), (null); + +query I +select median(c) from t; +---- +NULL + +statement ok +drop table t; + +# median u32 +statement ok +create table t(c int unsigned) as values (1), (2), (3), (4), (5); + +query I +select median(c) from t; +---- +3 + +statement ok +drop table t; + +# median f32 +statement ok +create table t(c float) as values (1.1), (2.2), (3.3), (4.4), (5.5); + +query R +select median(c) from t; +---- +3.3 + +statement ok +drop table t; + +# median distinct decimal +statement ok +create table t(c decimal(10, 4)) as values (0.0001), (0.0001), (0.0001), (0.0001), (0.0002), (0.0002), (0.0003), (0.0003); + +query R +select median(distinct c) from t; +---- +0.0002 + +statement ok +drop table t; + +# median distinct decimal with nulls +statement ok +create table t(c decimal(10, 4)) as values (0.0001), (0.0001), (0.0001), (null), (null), (0.0002), (0.0003), (0.0003); + +query R +select median(distinct c) from t; +---- +0.0002 + +statement ok +drop table t; + +# distinct median i32 odd +statement ok +create table t(c int) as values (2), (1), (1), (2), (1), (3); + +query I +select median(distinct c) from t; +---- +2 + +statement ok +drop table t; + +# distinct median i32 even +statement ok +create table t(c int) as values (1), (1), (3), (1), (1); + +query I +select median(distinct c) from t; +---- +2 + +statement ok +drop table t; + +# distinct median i32 with nulls +statement ok +create table t(c int) as values (1), (null), (1), (1), (3); + +query I +select median(distinct c) from t; +---- +2 + +statement ok +drop table t; + +# distinct median u32 odd +statement ok +create table t(c int unsigned) as values (1), (1), (2), (1), (3); + +query I +select median(distinct c) from t; +---- +2 + +statement ok +drop table t; + +# distinct median u32 even +statement ok +create table t(c int unsigned) as values (1), (1), (1), (1), (3), (3); + +query I +select median(distinct c) from t; +---- +2 + +statement ok +drop table t; + +# distinct median f32 odd +statement ok +create table t(c float) as values (3), (2), (1), (1), (1); + +query R +select median(distinct c) from t; +---- +2 + +statement ok +drop table t; + +# distinct median f32 even +statement ok +create table t(c float) as values (1), (1), (1), (1), (2); + +query R +select median(distinct c) from t; +---- +1.5 + +statement ok +drop table t; + +# distinct median f64 odd +statement ok +create table t(c double) as values (1), (1), (1), (2), (3); + +query R +select median(distinct c) from t; +---- +2 + +statement ok +drop table t; + +# distinct median f64 even +statement ok +create table t(c double) as values (1), (1), (1), (1), (2); + +query R +select median(distinct c) from t; +---- +1.5 + +statement ok +drop table t; + +# distinct median i32 +statement ok +create table t(c int) as values (1), (1), (1), (1), (2), (2), (3), (3); + +query I +select median(distinct c) from t; +---- +2 + +statement ok +drop table t; + +# optimize distinct median to group by +statement ok +create table t(c int) as values (1), (1), (1), (1), (2), (2), (3), (3); + +query TT +explain select median(distinct c) from t; +---- +logical_plan +01)Projection: median(alias1) AS median(DISTINCT t.c) +02)--Aggregate: groupBy=[[]], aggr=[[median(alias1)]] +03)----Aggregate: groupBy=[[t.c AS alias1]], aggr=[[]] +04)------TableScan: t projection=[c] +physical_plan +01)ProjectionExec: expr=[median(alias1)@0 as median(DISTINCT t.c)] +02)--AggregateExec: mode=Final, gby=[], aggr=[median(alias1)] +03)----CoalescePartitionsExec +04)------AggregateExec: mode=Partial, gby=[], aggr=[median(alias1)] +05)--------AggregateExec: mode=FinalPartitioned, gby=[alias1@0 as alias1], aggr=[] +06)----------CoalesceBatchesExec: target_batch_size=8192 +07)------------RepartitionExec: partitioning=Hash([alias1@0], 4), input_partitions=4 +08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +09)----------------AggregateExec: mode=Partial, gby=[c@0 as alias1], aggr=[] +10)------------------MemoryExec: partitions=1, partition_sizes=[1] + +statement ok +drop table t; + # median_multi # test case for https://github.com/apache/datafusion/issues/3105 # has an intermediate grouping @@ -667,6 +1124,14 @@ SELECT COUNT(*) FROM aggregate_test_100 ---- 100 +query I +SELECT COUNT(aggregate_test_100.*) FROM aggregate_test_100 +---- +100 + +query error Error during planning: Invalid qualifier foo +SELECT COUNT(foo.*) FROM aggregate_test_100 + # csv_query_count_literal query I SELECT COUNT(2) FROM aggregate_test_100 @@ -870,11 +1335,79 @@ SELECT (ABS(1 - CAST(approx_percentile_cont(c11, 0.9) AS DOUBLE) / 0.834) < 0.05 ---- true -# csv_query_cube_avg -query TIR -SELECT c1, c2, AVG(c3) FROM aggregate_test_100 GROUP BY CUBE (c1, c2) ORDER BY c1, c2 +# percentile_cont_with_nulls +query I +SELECT APPROX_PERCENTILE_CONT(v, 0.5) FROM (VALUES (1), (2), (3), (NULL), (NULL), (NULL)) as t (v); ---- -a 1 -17.6 +2 + +# percentile_cont_with_nulls_only +query I +SELECT APPROX_PERCENTILE_CONT(v, 0.5) FROM (VALUES (CAST(NULL as INT))) as t (v); +---- +NULL + +# +# percentile_cont edge cases +# + +statement ok +CREATE TABLE tmp_percentile_cont(v1 INT, v2 DOUBLE); + +statement ok +INSERT INTO tmp_percentile_cont VALUES (1, 'NaN'::Double), (2, 'NaN'::Double), (3, 'NaN'::Double); + +# ISSUE: https://github.com/apache/datafusion/issues/11871 +# Note `approx_median()` is using the same implementation as `approx_percentile_cont()` +query R +select APPROX_MEDIAN(v2) from tmp_percentile_cont WHERE v1 = 1; +---- +NaN + +# ISSUE: https://github.com/apache/datafusion/issues/11870 +query R +select APPROX_PERCENTILE_CONT(v2, 0.8) from tmp_percentile_cont; +---- +NaN + +# ISSUE: https://github.com/apache/datafusion/issues/11869 +# Note: `approx_percentile_cont_with_weight()` uses the same implementation as `approx_percentile_cont()` +query R +SELECT APPROX_PERCENTILE_CONT_WITH_WEIGHT( + v2, + '+Inf'::Double, + 0.9 +) +FROM tmp_percentile_cont; +---- +NaN + +statement ok +DROP TABLE tmp_percentile_cont; + +# Test for issue where approx_percentile_cont_with_weight + +statement ok +CREATE TABLE t1(v1 BOOL); + +statement ok +INSERT INTO t1 VALUES (TRUE); + +# ISSUE: https://github.com/apache/datafusion/issues/12716 +# This test verifies that approx_percentile_cont_with_weight does not panic when given 'NaN' and returns 'inf' +query R +SELECT approx_percentile_cont_with_weight('NaN'::DOUBLE, 0, 0) FROM t1 WHERE t1.v1; +---- +Infinity + +statement ok +DROP TABLE t1; + +# csv_query_cube_avg +query TIR +SELECT c1, c2, AVG(c3) FROM aggregate_test_100 GROUP BY CUBE (c1, c2) ORDER BY c1, c2 +---- +a 1 -17.6 a 2 -15.333333333333 a 3 -4.5 a 4 -32 @@ -1161,154 +1694,154 @@ e e 1323 query TTI SELECT a.c1, b.c1, SUM(a.c2) FROM aggregate_test_100 as a CROSS JOIN aggregate_test_100 as b GROUP BY CUBE (a.c1, b.c1) ORDER BY a.c1, b.c1 ---- -a a 1260 -a b 1140 -a c 1260 -a d 1080 -a e 1260 -a NULL 6000 -b a 1302 -b b 1178 -b c 1302 -b d 1116 -b e 1302 -b NULL 6200 -c a 1176 -c b 1064 -c c 1176 -c d 1008 -c e 1176 -c NULL 5600 -d a 924 -d b 836 -d c 924 -d d 792 -d e 924 -d NULL 4400 -e a 1323 -e b 1197 -e c 1323 -e d 1134 -e e 1323 -e NULL 6300 -NULL a 5985 -NULL b 5415 -NULL c 5985 -NULL d 5130 -NULL e 5985 +a a 1260 +a b 1140 +a c 1260 +a d 1080 +a e 1260 +a NULL 6000 +b a 1302 +b b 1178 +b c 1302 +b d 1116 +b e 1302 +b NULL 6200 +c a 1176 +c b 1064 +c c 1176 +c d 1008 +c e 1176 +c NULL 5600 +d a 924 +d b 836 +d c 924 +d d 792 +d e 924 +d NULL 4400 +e a 1323 +e b 1197 +e c 1323 +e d 1134 +e e 1323 +e NULL 6300 +NULL a 5985 +NULL b 5415 +NULL c 5985 +NULL d 5130 +NULL e 5985 NULL NULL 28500 # csv_query_cube_distinct_count query TII SELECT c1, c2, COUNT(DISTINCT c3) FROM aggregate_test_100 GROUP BY CUBE (c1,c2) ORDER BY c1,c2 ---- -a 1 5 -a 2 3 -a 3 5 -a 4 4 -a 5 3 -a NULL 19 -b 1 3 -b 2 4 -b 3 2 -b 4 5 -b 5 5 -b NULL 17 -c 1 4 -c 2 7 -c 3 4 -c 4 4 -c 5 2 -c NULL 21 -d 1 7 -d 2 3 -d 3 3 -d 4 3 -d 5 2 -d NULL 18 -e 1 3 -e 2 4 -e 3 4 -e 4 7 -e 5 2 -e NULL 18 -NULL 1 22 -NULL 2 20 -NULL 3 17 -NULL 4 23 -NULL 5 14 +a 1 5 +a 2 3 +a 3 5 +a 4 4 +a 5 3 +a NULL 19 +b 1 3 +b 2 4 +b 3 2 +b 4 5 +b 5 5 +b NULL 17 +c 1 4 +c 2 7 +c 3 4 +c 4 4 +c 5 2 +c NULL 21 +d 1 7 +d 2 3 +d 3 3 +d 4 3 +d 5 2 +d NULL 18 +e 1 3 +e 2 4 +e 3 4 +e 4 7 +e 5 2 +e NULL 18 +NULL 1 22 +NULL 2 20 +NULL 3 17 +NULL 4 23 +NULL 5 14 NULL NULL 80 # csv_query_rollup_distinct_count query TII SELECT c1, c2, COUNT(DISTINCT c3) FROM aggregate_test_100 GROUP BY ROLLUP (c1,c2) ORDER BY c1,c2 ---- -a 1 5 -a 2 3 -a 3 5 -a 4 4 -a 5 3 -a NULL 19 -b 1 3 -b 2 4 -b 3 2 -b 4 5 -b 5 5 -b NULL 17 -c 1 4 -c 2 7 -c 3 4 -c 4 4 -c 5 2 -c NULL 21 -d 1 7 -d 2 3 -d 3 3 -d 4 3 -d 5 2 -d NULL 18 -e 1 3 -e 2 4 -e 3 4 -e 4 7 -e 5 2 -e NULL 18 +a 1 5 +a 2 3 +a 3 5 +a 4 4 +a 5 3 +a NULL 19 +b 1 3 +b 2 4 +b 3 2 +b 4 5 +b 5 5 +b NULL 17 +c 1 4 +c 2 7 +c 3 4 +c 4 4 +c 5 2 +c NULL 21 +d 1 7 +d 2 3 +d 3 3 +d 4 3 +d 5 2 +d NULL 18 +e 1 3 +e 2 4 +e 3 4 +e 4 7 +e 5 2 +e NULL 18 NULL NULL 80 # csv_query_rollup_sum_crossjoin query TTI SELECT a.c1, b.c1, SUM(a.c2) FROM aggregate_test_100 as a CROSS JOIN aggregate_test_100 as b GROUP BY ROLLUP (a.c1, b.c1) ORDER BY a.c1, b.c1 ---- -a a 1260 -a b 1140 -a c 1260 -a d 1080 -a e 1260 -a NULL 6000 -b a 1302 -b b 1178 -b c 1302 -b d 1116 -b e 1302 -b NULL 6200 -c a 1176 -c b 1064 -c c 1176 -c d 1008 -c e 1176 -c NULL 5600 -d a 924 -d b 836 -d c 924 -d d 792 -d e 924 -d NULL 4400 -e a 1323 -e b 1197 -e c 1323 -e d 1134 -e e 1323 -e NULL 6300 +a a 1260 +a b 1140 +a c 1260 +a d 1080 +a e 1260 +a NULL 6000 +b a 1302 +b b 1178 +b c 1302 +b d 1116 +b e 1302 +b NULL 6200 +c a 1176 +c b 1064 +c c 1176 +c d 1008 +c e 1176 +c NULL 5600 +d a 924 +d b 836 +d c 924 +d d 792 +d e 924 +d NULL 4400 +e a 1323 +e b 1197 +e c 1323 +e d 1134 +e e 1323 +e NULL 6300 NULL NULL 28500 # query_count_without_from @@ -1327,7 +1860,7 @@ SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT query ? SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 LIMIT 0) test ---- -[] +NULL # csv_query_array_agg_one query ? @@ -1386,31 +1919,12 @@ NULL 4 29 1.260869565217 123 -117 23 NULL 5 -194 -13.857142857143 118 -101 14 NULL NULL 781 7.81 125 -117 100 -# TODO: array_agg_distinct output is non-deterministic -- rewrite with array_sort(list_sort) -# unnest is also not available, so manually unnesting via CROSS JOIN -# additional count(1) forces array_agg_distinct instead of array_agg over aggregated by c2 data -# +# select with count to forces array_agg_distinct function, since single distinct expression is converted to group by by optimizer # csv_query_array_agg_distinct -query III -WITH indices AS ( - SELECT 1 AS idx UNION ALL - SELECT 2 AS idx UNION ALL - SELECT 3 AS idx UNION ALL - SELECT 4 AS idx UNION ALL - SELECT 5 AS idx -) -SELECT data.arr[indices.idx] as element, array_length(data.arr) as array_len, dummy -FROM ( - SELECT array_agg(distinct c2) as arr, count(1) as dummy FROM aggregate_test_100 -) data - CROSS JOIN indices -ORDER BY 1 ----- -1 5 100 -2 5 100 -3 5 100 -4 5 100 -5 5 100 +query ?I +SELECT array_sort(array_agg(distinct c2)), count(1) FROM aggregate_test_100 +---- +[1, 2, 3, 4, 5] 100 # aggregate_time_min_and_max query TT @@ -1418,30 +1932,39 @@ select min(t), max(t) from (select '00:00:00' as t union select '00:00:01' unio ---- 00:00:00 00:00:02 -# aggregate_decimal_min -query RT -select min(c1), arrow_typeof(min(c1)) from d_table ----- --100.009 Decimal128(10, 3) - -# aggregate_decimal_max -query RT -select max(c1), arrow_typeof(max(c1)) from d_table +# aggregate Interval(MonthDayNano) min/max +query T?? +select + arrow_typeof(min(column1)), min(column1), max(column1) +from values + (interval '1 month'), + (interval '2 months'), + (interval '2 month 15 days'), + (interval '-2 month') ---- -110.009 Decimal128(10, 3) +Interval(MonthDayNano) -2 mons 2 mons 15 days -# aggregate_decimal_sum -query RT -select sum(c1), arrow_typeof(sum(c1)) from d_table +# aggregate Interval(DayTime) min/max +query T?? +select + arrow_typeof(min(column1)), min(column1), max(column1) +from values + (arrow_cast('60 minutes', 'Interval(DayTime)')), + (arrow_cast('-3 minutes', 'Interval(DayTime)')), + (arrow_cast('30 minutes', 'Interval(DayTime)')); ---- -100 Decimal128(20, 3) +Interval(DayTime) -3 mins 1 hours -# aggregate_decimal_avg -query RT -select avg(c1), arrow_typeof(avg(c1)) from d_table +# aggregate Interval(YearMonth) min/max +query T?? +select + arrow_typeof(min(column1)), min(column1), max(column1) +from values + (arrow_cast('-1 year', 'Interval(YearMonth)')), + (arrow_cast('13 months', 'Interval(YearMonth)')), + (arrow_cast('1 year', 'Interval(YearMonth)')); ---- -5 Decimal128(14, 7) - +Interval(YearMonth) -1 years 0 mons 1 years 1 mons # aggregate query II @@ -1474,6 +1997,12 @@ SELECT MIN(c1), MIN(c2) FROM test ---- 0 1 +query error min/max was called with 2 arguments. It requires only 1. +SELECT MIN(c1, c2) FROM test + +query error min/max was called with 2 arguments. It requires only 1. +SELECT MAX(c1, c2) FROM test + # aggregate_grouped query II SELECT c1, SUM(c2) FROM test GROUP BY c1 order by c1 @@ -1560,6 +2089,91 @@ SELECT max(c1) FROM test; # count_basic +statement ok +create table t (c int) as values (1), (2), (null), (3), (null), (4), (5); + +query IT +select count(c), arrow_typeof(count(c)) from t; +---- +5 Int64 + +statement ok +drop table t; + +# test count with all nulls +statement ok +create table t (c int) as values (null), (null), (null), (null), (null); + +query IT +select count(c), arrow_typeof(count(c)) from t; +---- +0 Int64 + +statement ok +drop table t; + +# test with empty +statement ok +create table t (c int); + +query IT +select count(c), arrow_typeof(count(c)) from t; +---- +0 Int64 + +statement ok +drop table t; + +# test count with string +statement ok +create table t (c string) as values ('a'), ('b'), (null), ('c'), (null), ('d'), ('e'); + +query IT +select count(c), arrow_typeof(count(c)) from t; +---- +5 Int64 + +statement ok +drop table t; + +# test count with largeutf8 +statement ok +create table t (c string) as values + (arrow_cast('a', 'LargeUtf8')), + (arrow_cast('b', 'LargeUtf8')), + (arrow_cast(null, 'LargeUtf8')), + (arrow_cast('c', 'LargeUtf8')) +; + +query T +select arrow_typeof(c) from t; +---- +Utf8 +Utf8 +Utf8 +Utf8 + +query IT +select count(c), arrow_typeof(count(c)) from t; +---- +3 Int64 + +statement ok +drop table t; + +# test count with multiple columns +statement ok +create table t (c1 int, c2 int) as values (1, 1), (2, null), (null, 2), (null, null), (3, 3), (null, 4); + +query IT +select count(c1, c2), arrow_typeof(count(c1, c2)) from t; +---- +2 Int64 + +statement ok +drop table t; + + query II SELECT COUNT(c1), COUNT(c2) FROM test ---- @@ -1577,6 +2191,10 @@ SELECT count(c1, c2) FROM test ---- 3 +# count(distinct) with multiple arguments +query error DataFusion error: This feature is not implemented: COUNT DISTINCT with multiple arguments +SELECT count(distinct c1, c2) FROM test + # count_null query III SELECT count(null), count(null, null), count(distinct null) FROM test @@ -1617,227 +2235,1761 @@ select avg(c1) from test ---- 1.75 -# simple_mean -query R -select mean(c1) from test +# avg_decimal +statement ok +create table t (c1 decimal(10, 0)) as values (1), (2), (3), (4), (5), (6); + +query RT +select avg(c1), arrow_typeof(avg(c1)) from t; ---- -1.75 +3.5 Decimal128(14, 4) +statement ok +drop table t; +# avg_decimal_with_nulls +statement ok +create table t (c1 decimal(10, 0)) as values (1), (NULL), (3), (4), (5); -# query_sum_distinct - 2 different aggregate functions: avg and sum(distinct) -query RI -SELECT AVG(c1), SUM(DISTINCT c2) FROM test +query RT +select avg(c1), arrow_typeof(avg(c1)) from t; ---- -1.75 3 +3.25 Decimal128(14, 4) -# query_sum_distinct - 2 sum(distinct) functions -query II -SELECT SUM(DISTINCT c1), SUM(DISTINCT c2) FROM test ----- -4 3 +statement ok +drop table t; -# # query_count_distinct -query I -SELECT COUNT(DISTINCT c1) FROM test ----- -3 +# avg_decimal_all_nulls +statement ok +create table t (c1 decimal(10, 0)) as values (NULL), (NULL), (NULL), (NULL), (NULL), (NULL); -# TODO: count_distinct_integers_aggregated_single_partition +query RT +select avg(c1), arrow_typeof(avg(c1)) from t; +---- +NULL Decimal128(14, 4) -# TODO: count_distinct_integers_aggregated_multiple_partitions +statement ok +drop table t; -# TODO: aggregate_with_alias +# avg_i32 +statement ok +create table t (c1 int) as values (1), (2), (3), (4), (5); -# array_agg_zero -query ? -SELECT ARRAY_AGG([]) +query RT +select avg(c1), arrow_typeof(avg(c1)) from t; ---- -[[]] +3 Float64 -# array_agg_one -query ? -SELECT ARRAY_AGG([1]) ----- -[[1]] +statement ok +drop table t; -# test_approx_percentile_cont_decimal_support -query TI -SELECT c1, approx_percentile_cont(c2, cast(0.85 as decimal(10,2))) apc FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +# avg_i32_with_nulls +statement ok +create table t (c1 int) as values (1), (NULL), (3), (4), (5); + +query RT +select avg(c1), arrow_typeof(avg(c1)) from t; ---- -a 4 -b 5 -c 4 -d 4 -e 4 +3.25 Float64 +statement ok +drop table t; -# array_agg_zero -query ? -SELECT ARRAY_AGG([]); ----- -[[]] +# avg_i32_all_nulls +statement ok +create table t (c1 int) as values (NULL), (NULL); -# array_agg_one -query ? -SELECT ARRAY_AGG([1]); +query RT +select avg(c1), arrow_typeof(avg(c1)) from t; ---- -[[1]] +NULL Float64 -# variance_single_value -query RRRR -select var(sq.column1), var_pop(sq.column1), stddev(sq.column1), stddev_pop(sq.column1) from (values (1.0)) as sq; ----- -NULL 0 NULL 0 +statement ok +drop table t; -# variance_two_values -query RRRR -select var(sq.column1), var_pop(sq.column1), stddev(sq.column1), stddev_pop(sq.column1) from (values (1.0), (3.0)) as sq; +# avg_u32 +statement ok +create table t (c1 int unsigned) as values (1), (2), (3), (4), (5); + +query RT +select avg(c1), arrow_typeof(avg(c1)) from t; ---- -2 1 1.414213562373 1 +3 Float64 + +statement ok +drop table t; + +# avg_f32 +statement ok +create table t (c1 float) as values (1), (2), (3), (4), (5); + +query RT +select avg(c1), arrow_typeof(avg(c1)) from t; +---- +3 Float64 + +statement ok +drop table t; + +# avg_f64 +statement ok +create table t (c1 double) as values (1), (2), (3), (4), (5); + +query RT +select avg(c1), arrow_typeof(avg(c1)) from t; +---- +3 Float64 + +statement ok +drop table t; + +# covariance_f64_1 +statement ok +create table t (c1 double, c2 double) as values (1, 4), (2, 5), (3, 6); + +query RT +select covar_pop(c1, c2), arrow_typeof(covar_pop(c1, c2)) from t; +---- +0.666666666667 Float64 + +statement ok +drop table t; + +# covariance_f64_2 +statement ok +create table t (c1 double, c2 double) as values (1, 4), (2, 5), (3, 6); + +query RT +select covar_samp(c1, c2), arrow_typeof(covar_samp(c1, c2)) from t; +---- +1 Float64 + +statement ok +drop table t; + +# covariance_f64_4 +statement ok +create table t (c1 double, c2 double) as values (1.1, 4.1), (2.0, 5.0), (3.0, 6.0); + +query RT +select covar_samp(c1, c2), arrow_typeof(covar_samp(c1, c2)) from t; +---- +0.903333333333 Float64 + +statement ok +drop table t; + +# covariance_f64_5 +statement ok +create table t (c1 double, c2 double) as values (1.1, 4.1), (2.0, 5.0), (3.0, 6.0); + +query RT +select covar_pop(c1, c2), arrow_typeof(covar_pop(c1, c2)) from t; +---- +0.602222222222 Float64 + +statement ok +drop table t; + +# covariance_f64_6 +statement ok +create table t (c1 double, c2 double) as values (1.0, 4.0), (2.0, 5.0), (3.0, 6.0), (1.1, 4.4), (2.2, 5.5), (3.3, 6.6); + +query RT +select covar_pop(c1, c2), arrow_typeof(covar_pop(c1, c2)) from t; +---- +0.761666666667 Float64 + +statement ok +drop table t; + +# covariance_i32 +statement ok +create table t (c1 int, c2 int) as values (1, 4), (2, 5), (3, 6); + +query RT +select covar_pop(c1, c2), arrow_typeof(covar_pop(c1, c2)) from t; +---- +0.666666666667 Float64 + +statement ok +drop table t; + +# covariance_u32 +statement ok +create table t (c1 int unsigned, c2 int unsigned) as values (1, 4), (2, 5), (3, 6); + +query RT +select covar_pop(c1, c2), arrow_typeof(covar_pop(c1, c2)) from t; +---- +0.666666666667 Float64 + +statement ok +drop table t; + +# covariance_f32 +statement ok +create table t (c1 float, c2 float) as values (1, 4), (2, 5), (3, 6); + +query RT +select covar_pop(c1, c2), arrow_typeof(covar_pop(c1, c2)) from t; +---- +0.666666666667 Float64 + +statement ok +drop table t; + +# covariance_i32_with_nulls_1 +statement ok +create table t (c1 int, c2 int) as values (1, 4), (null, null), (3, 6); + +query RT +select covar_pop(c1, c2), arrow_typeof(covar_pop(c1, c2)) from t; +---- +1 Float64 + +statement ok +drop table t; + +# covariance_i32_with_nulls_2 +statement ok +create table t (c1 int, c2 int) as values (1, 4), (null, 9), (2, 5), (null, 8), (3, 6), (null, null); + +query RT +select covar_pop(c1, c2), arrow_typeof(covar_pop(c1, c2)) from t; +---- +0.666666666667 Float64 + +statement ok +drop table t; + +# covariance_i32_with_nulls_3 +statement ok +create table t (c1 int, c2 int) as values (1, 4), (null, 9), (2, 5), (null, 8), (3, 6), (null, null); + +query RT +select covar_samp(c1, c2), arrow_typeof(covar_samp(c1, c2)) from t; +---- +1 Float64 + +statement ok +drop table t; + +# covariance_i32_all_nulls +statement ok +create table t (c1 int, c2 int) as values (null, null), (null, null); + +query RT +select covar_samp(c1, c2), arrow_typeof(covar_samp(c1, c2)) from t; +---- +NULL Float64 + +statement ok +drop table t; + +# covariance_pop_i32_all_nulls +statement ok +create table t (c1 int, c2 int) as values (null, null), (null, null); + +query RT +select covar_pop(c1, c2), arrow_typeof(covar_pop(c1, c2)) from t; +---- +NULL Float64 + +statement ok +drop table t; + +# covariance_1_input +statement ok +create table t (c1 double, c2 double) as values (1, 2); + +query RT +select covar_samp(c1, c2), arrow_typeof(covar_samp(c1, c2)) from t; +---- +NULL Float64 + +statement ok +drop table t; + +# covariance_pop_1_input +statement ok +create table t (c1 double, c2 double) as values (1, 2); + +query RT +select covar_pop(c1, c2), arrow_typeof(covar_pop(c1, c2)) from t; +---- +0 Float64 + +statement ok +drop table t; + +# variance_f64_1 +statement ok +create table t (c double) as values (1), (2), (3), (4), (5); + +query RT +select var(c), arrow_typeof(var(c)) from t; +---- +2.5 Float64 + +statement ok +drop table t; + +# aggregate stddev f64_1 +statement ok +create table t (c1 double) as values (1), (2); + +query RT +select stddev_pop(c1), arrow_typeof(stddev_pop(c1)) from t; +---- +0.5 Float64 + +statement ok +drop table t; + +# aggregate stddev f64_2 +statement ok +create table t (c1 double) as values (1.1), (2), (3); + +query RT +select stddev_pop(c1), arrow_typeof(stddev_pop(c1)) from t; +---- +0.776029781788 Float64 + +statement ok +drop table t; + +# aggregate stddev f64_3 +statement ok +create table t (c1 double) as values (1), (2), (3), (4), (5); + +query RT +select stddev_pop(c1), arrow_typeof(stddev_pop(c1)) from t; +---- +1.414213562373 Float64 + +statement ok +drop table t; + +# aggregate stddev f64_4 +statement ok +create table t (c1 double) as values (1.1), (2), (3); + +query RT +select stddev(c1), arrow_typeof(stddev(c1)) from t; +---- +0.950438495292 Float64 + +statement ok +drop table t; + +# aggregate stddev i32 +statement ok +create table t (c1 int) as values (1), (2), (3), (4), (5); + +query RT +select stddev_pop(c1), arrow_typeof(stddev_pop(c1)) from t; +---- +1.414213562373 Float64 + +statement ok +drop table t; + +# aggregate stddev u32 +statement ok +create table t (c1 int unsigned) as values (1), (2), (3), (4), (5); + +query RT +select stddev_pop(c1), arrow_typeof(stddev_pop(c1)) from t; +---- +1.414213562373 Float64 + +statement ok +drop table t; + +# aggregate stddev f32 +statement ok +create table t (c1 float) as values (1), (2), (3), (4), (5); + +query RT +select stddev_pop(c1), arrow_typeof(stddev_pop(c1)) from t; +---- +1.414213562373 Float64 + +statement ok +drop table t; + +# aggregate stddev single_input +statement ok +create table t (c1 double) as values (1); + +query RT +select stddev_pop(c1), arrow_typeof(stddev_pop(c1)) from t; +---- +0 Float64 + +statement ok +drop table t; + +# aggregate stddev with_nulls +statement ok +create table t (c1 int) as values (1), (null), (3), (4), (5); + +query RT +select stddev_pop(c1), arrow_typeof(stddev_pop(c1)) from t; +---- +1.479019945775 Float64 + +statement ok +drop table t; + +# aggregate stddev all_nulls +statement ok +create table t (c1 int) as values (null), (null); + +query RT +select stddev_pop(c1), arrow_typeof(stddev_pop(c1)) from t; +---- +NULL Float64 + +statement ok +drop table t; + +# aggregate variance f64_1 +statement ok +create table t (c1 double) as values (1), (2); + +query RT +select var_pop(c1), arrow_typeof(var_pop(c1)) from t; +---- +0.25 Float64 + +statement ok +drop table t; + +# aggregate variance f64_2 +statement ok +create table t (c1 double) as values (1), (2), (3), (4), (5); + +query RT +select var_pop(c1), arrow_typeof(var_pop(c1)) from t; +---- +2 Float64 + +statement ok +drop table t; + +# aggregate variance f64_3 +statement ok +create table t (c1 double) as values (1), (2), (3), (4), (5); + +query RT +select var(c1), arrow_typeof(var(c1)) from t; +---- +2.5 Float64 + +statement ok +drop table t; + +# variance_f64_2 +statement ok +create table t (c double) as values (1.1), (2), (3); + +query RT +select var(c), arrow_typeof(var(c)) from t; +---- +0.903333333333 Float64 + +statement ok +drop table t; + +# aggregate variance f64_4 +statement ok +create table t (c1 double) as values (1.1), (2), (3); + +query RT +select var(c1), arrow_typeof(var(c1)) from t; +---- +0.903333333333 Float64 + +statement ok +drop table t; + +# variance_1_input +statement ok +create table t (a double not null) as values (1); + +query RT +select var(a), arrow_typeof(var(a)) from t; +---- +NULL Float64 + +statement ok +drop table t; + +# variance_i32_all_nulls +statement ok +create table t (a int) as values (null), (null); + +query RT +select var(a), arrow_typeof(var(a)) from t; +---- +NULL Float64 + +statement ok +drop table t; + +# aggregate variance i32 +statement ok +create table t (c1 int) as values (1), (2), (3), (4), (5); + +query RT +select var_pop(c1), arrow_typeof(var_pop(c1)) from t; +---- +2 Float64 + +statement ok +drop table t; + +# aggregate variance u32 +statement ok +create table t (c1 int unsigned) as values (1), (2), (3), (4), (5); + +query RT +select var_pop(c1), arrow_typeof(var_pop(c1)) from t; +---- +2 Float64 + +statement ok +drop table t; + +# aggregate variance f32 +statement ok +create table t (c1 float) as values (1), (2), (3), (4), (5); + +query RT +select var_pop(c1), arrow_typeof(var_pop(c1)) from t; +---- +2 Float64 + +statement ok +drop table t; + +# aggregate single input +statement ok +create table t (c1 double) as values (1); + +query RT +select var_pop(c1), arrow_typeof(var_pop(c1)) from t; +---- +0 Float64 + +statement ok +drop table t; + +# aggregate i32 with nulls +statement ok +create table t (c1 int) as values (1), (null), (3), (4), (5); + +query RT +select var_pop(c1), arrow_typeof(var_pop(c1)) from t; +---- +2.1875 Float64 + +statement ok +drop table t; + +# aggregate i32 all nulls +statement ok +create table t (c1 int) as values (null), (null); + +query RT +select var_pop(c1), arrow_typeof(var_pop(c1)) from t; +---- +NULL Float64 + +statement ok +drop table t; + +# simple_mean +query R +select mean(c1) from test +---- +1.75 + +# aggregate sum distinct, coerced result from i32 to i64 +statement ok +create table t (c int) as values (1), (2), (1), (3), (null), (null), (-3), (-3); + +query IT +select sum(distinct c), arrow_typeof(sum(distinct c)) from t; +---- +3 Int64 + +statement ok +drop table t; + +# aggregate sum distinct, coerced result from u32 to u64 +statement ok +create table t (c int unsigned) as values (1), (2), (1), (3), (null), (null), (3); + +query IT +select sum(distinct c), arrow_typeof(sum(distinct c)) from t; +---- +6 UInt64 + +statement ok +drop table t; + +# aggregate sum distinct, coerced result from f32 to f64 +statement ok +create table t (c float) as values (1.0), (2.2), (1.0), (3.3), (null), (null), (3.3), (-2.0); + +query RT +select sum(distinct c), arrow_typeof(sum(distinct c)) from t; +---- +4.5 Float64 + +statement ok +drop table t; + +# aggregate sum distinct with decimal +statement ok +create table t (c decimal(35, 0)) as values (1), (2), (1), (3), (null), (null), (3), (-2); + +query RT +select sum(distinct c), arrow_typeof(sum(distinct c)) from t; +---- +4 Decimal128(38, 0) + +statement ok +drop table t; + +# query_sum_distinct - 2 different aggregate functions: avg and sum(distinct) +query RI +SELECT AVG(c1), SUM(DISTINCT c2) FROM test +---- +1.75 3 + +# query_sum_distinct - 2 sum(distinct) functions +query II +SELECT SUM(DISTINCT c1), SUM(DISTINCT c2) FROM test +---- +4 3 + +# # query_count_distinct +query I +SELECT COUNT(DISTINCT c1) FROM test +---- +3 + +# TODO: count_distinct_integers_aggregated_single_partition + +# TODO: count_distinct_integers_aggregated_multiple_partitions + +# TODO: aggregate_with_alias + +# test_approx_percentile_cont_decimal_support +query TI +SELECT c1, approx_percentile_cont(c2, cast(0.85 as decimal(10,2))) apc FROM aggregate_test_100 GROUP BY 1 ORDER BY 1 +---- +a 4 +b 5 +c 4 +d 4 +e 4 + +# array_agg_zero +query ? +SELECT ARRAY_AGG([]) +---- +[[]] + +# array_agg_one +query ? +SELECT ARRAY_AGG([1]) +---- +[[1]] + +# test array_agg with no row qualified +statement ok +create table t(a int, b float, c bigint) as values (1, 1.2, 2); + +# returns NULL, follows DuckDB's behaviour +query ? +select array_agg(a) from t where a > 2; +---- +NULL + +query ? +select array_agg(b) from t where b > 3.1; +---- +NULL + +query ? +select array_agg(c) from t where c > 3; +---- +NULL + +query ?I +select array_agg(c), count(1) from t where c > 3; +---- +NULL 0 + +# returns 0 rows if group by is applied, follows DuckDB's behaviour +query ? +select array_agg(a) from t where a > 3 group by a; +---- + +query ?I +select array_agg(a), count(1) from t where a > 3 group by a; +---- + +# returns NULL, follows DuckDB's behaviour +query ? +select array_agg(distinct a) from t where a > 3; +---- +NULL + +query ?I +select array_agg(distinct a), count(1) from t where a > 3; +---- +NULL 0 + +# returns 0 rows if group by is applied, follows DuckDB's behaviour +query ? +select array_agg(distinct a) from t where a > 3 group by a; +---- + +query ?I +select array_agg(distinct a), count(1) from t where a > 3 group by a; +---- + +# test order sensitive array agg +query ? +select array_agg(a order by a) from t where a > 3; +---- +NULL + +query ? +select array_agg(a order by a) from t where a > 3 group by a; +---- + +query ?I +select array_agg(a order by a), count(1) from t where a > 3 group by a; +---- + +statement ok +drop table t; + +# test with no values +statement ok +create table t(a int, b float, c bigint); + +query ? +select array_agg(a) from t; +---- +NULL + +query ? +select array_agg(b) from t; +---- +NULL + +query ? +select array_agg(c) from t; +---- +NULL + +query ?I +select array_agg(distinct a), count(1) from t; +---- +NULL 0 + +query ?I +select array_agg(distinct b), count(1) from t; +---- +NULL 0 + +query ?I +select array_agg(distinct b), count(1) from t; +---- +NULL 0 + +statement ok +drop table t; + + +# array_agg_i32 +statement ok +create table t (c1 int) as values (1), (2), (3), (4), (5); + +query ? +select array_agg(c1) from t; +---- +[1, 2, 3, 4, 5] + +statement ok +drop table t; + +# array_agg_nested +statement ok +create table t as values (make_array([1, 2, 3], [4, 5])), (make_array([6], [7, 8])), (make_array([9])); + +query ? +select array_agg(column1) from t; +---- +[[[1, 2, 3], [4, 5]], [[6], [7, 8]], [[9]]] + +statement ok +drop table t; + +# variance_single_value +query RRRR +select var(sq.column1), var_pop(sq.column1), stddev(sq.column1), stddev_pop(sq.column1) from (values (1.0)) as sq; +---- +NULL 0 NULL 0 + +# variance_two_values +query RRRR +select var(sq.column1), var_pop(sq.column1), stddev(sq.column1), stddev_pop(sq.column1) from (values (1.0), (3.0)) as sq; +---- +2 1 1.414213562373 1 # aggregates on empty tables statement ok -CREATE TABLE empty (column1 bigint, column2 int); +CREATE TABLE empty (column1 bigint, column2 int); + +# no group by column +query IIRIIIII +SELECT + count(column1), -- counts should be zero, even for nulls + sum(column1), -- other aggregates should be null + avg(column1), + min(column1), + max(column1), + bit_and(column1), + bit_or(column1), + bit_xor(column1) +FROM empty +---- +0 NULL NULL NULL NULL NULL NULL NULL + +# Same query but with grouping (no groups, so no output) +query IIRIIIIII +SELECT + count(column1), + sum(column1), + avg(column1), + min(column1), + max(column1), + bit_and(column1), + bit_or(column1), + bit_xor(column1), + column2 +FROM empty +GROUP BY column2 +ORDER BY column2; +---- + + +statement ok +drop table empty + +# aggregates on all nulls +statement ok +CREATE TABLE the_nulls +AS VALUES + (null::bigint, 1), + (null::bigint, 1), + (null::bigint, 2); + +query II +select * from the_nulls +---- +NULL 1 +NULL 1 +NULL 2 + +# no group by column +query IIRIIIII +SELECT + count(column1), -- counts should be zero, even for nulls + sum(column1), -- other aggregates should be null + avg(column1), + min(column1), + max(column1), + bit_and(column1), + bit_or(column1), + bit_xor(column1) +FROM the_nulls +---- +0 NULL NULL NULL NULL NULL NULL NULL + +# Same query but with grouping +query IIRIIIIII +SELECT + count(column1), -- counts should be zero, even for nulls + sum(column1), -- other aggregates should be null + avg(column1), + min(column1), + max(column1), + bit_and(column1), + bit_or(column1), + bit_xor(column1), + column2 +FROM the_nulls +GROUP BY column2 +ORDER BY column2; +---- +0 NULL NULL NULL NULL NULL NULL NULL 1 +0 NULL NULL NULL NULL NULL NULL NULL 2 + + +statement ok +drop table the_nulls; + +statement ok +create table bit_aggregate_functions ( + c1 SMALLINT NOT NULL, + c2 SMALLINT NOT NULL, + c3 SMALLINT, + tag varchar +) +as values + (5, 10, 11, 'A'), + (33, 11, null, 'B'), + (9, 12, null, 'A'); + +# query_bit_and, query_bit_or, query_bit_xor +query IIIIIIIII +SELECT + bit_and(c1), + bit_and(c2), + bit_and(c3), + bit_or(c1), + bit_or(c2), + bit_or(c3), + bit_xor(c1), + bit_xor(c2), + bit_xor(c3) +FROM bit_aggregate_functions +---- +1 8 11 45 15 11 45 13 11 + +# query_bit_and, query_bit_or, query_bit_xor, with group +query IIIIIIIIIT +SELECT + bit_and(c1), + bit_and(c2), + bit_and(c3), + bit_or(c1), + bit_or(c2), + bit_or(c3), + bit_xor(c1), + bit_xor(c2), + bit_xor(c3), + tag +FROM bit_aggregate_functions +GROUP BY tag +ORDER BY tag +---- +1 8 11 13 14 11 12 6 11 A +33 11 NULL 33 11 NULL 33 11 NULL B + + +# bit_and_i32 +statement ok +create table t (c int) as values (4), (7), (15); + +query IT +Select bit_and(c), arrow_typeof(bit_and(c)) from t; +---- +4 Int32 + +statement ok +drop table t; + +# bit_and_i32_with_nulls +statement ok +create table t (c int) as values (1), (NULL), (3), (5); + +query IT +Select bit_and(c), arrow_typeof(bit_and(c)) from t; +---- +1 Int32 + +statement ok +drop table t; + +# bit_and_i32_all_nulls +statement ok +create table t (c int) as values (NULL), (NULL); + +query IT +Select bit_and(c), arrow_typeof(bit_and(c)) from t; +---- +NULL Int32 + +statement ok +drop table t; + +# bit_and_u32 +statement ok +create table t (c int unsigned) as values (4), (7), (15); + +query IT +Select bit_and(c), arrow_typeof(bit_and(c)) from t; +---- +4 UInt32 + +statement ok +drop table t; + +# bit_or_i32 +statement ok +create table t (c int) as values (4), (7), (15); + +query IT +Select bit_or(c), arrow_typeof(bit_or(c)) from t; +---- +15 Int32 + +statement ok +drop table t; + +# bit_or_i32_with_nulls +statement ok +create table t (c int) as values (1), (NULL), (3), (5); + +query IT +Select bit_or(c), arrow_typeof(bit_or(c)) from t; +---- +7 Int32 + +statement ok +drop table t; + +#bit_or_i32_all_nulls +statement ok +create table t (c int) as values (NULL), (NULL); + +query IT +Select bit_or(c), arrow_typeof(bit_or(c)) from t; +---- +NULL Int32 + +statement ok +drop table t; + + +#bit_or_u32 +statement ok +create table t (c int unsigned) as values (4), (7), (15); + +query IT +Select bit_or(c), arrow_typeof(bit_or(c)) from t; +---- +15 UInt32 + +statement ok +drop table t; + +#bit_xor_i32 +statement ok +create table t (c int) as values (4), (7), (4), (7), (15); + +query IT +Select bit_xor(c), arrow_typeof(bit_xor(c)) from t; +---- +15 Int32 + +statement ok +drop table t; + +# bit_xor_i32_with_nulls +statement ok +create table t (c int) as values (1), (1), (NULL), (3), (5); + +query IT +Select bit_xor(c), arrow_typeof(bit_xor(c)) from t; +---- +6 Int32 + +statement ok +drop table t; + +# bit_xor_i32_all_nulls +statement ok +create table t (c int) as values (NULL), (NULL); + +query IT +Select bit_xor(c), arrow_typeof(bit_xor(c)) from t; +---- +NULL Int32 + +statement ok +drop table t; + +# bit_xor_u32 +statement ok +create table t (c int unsigned) as values (4), (7), (4), (7), (15); + +query IT +Select bit_xor(c), arrow_typeof(bit_xor(c)) from t; +---- +15 UInt32 + +statement ok +drop table t; + +# bit_xor_distinct_i32 +statement ok +create table t (c int) as values (4), (7), (4), (7), (15); + +query IT +Select bit_xor(DISTINCT c), arrow_typeof(bit_xor(DISTINCT c)) from t; +---- +12 Int32 + +statement ok +drop table t; + +# bit_xor_distinct_i32_with_nulls +statement ok +create table t (c int) as values (1), (1), (NULL), (3), (5); + +query IT +Select bit_xor(DISTINCT c), arrow_typeof(bit_xor(DISTINCT c)) from t; +---- +7 Int32 + + +statement ok +drop table t; + +# bit_xor_distinct_i32_all_nulls +statement ok +create table t (c int ) as values (NULL), (NULL); + +query IT +Select bit_xor(DISTINCT c), arrow_typeof(bit_xor(DISTINCT c)) from t; +---- +NULL Int32 + + +statement ok +drop table t; + +# bit_xor_distinct_u32 +statement ok +create table t (c int unsigned) as values (4), (7), (4), (7), (15); + +query IT +Select bit_xor(DISTINCT c), arrow_typeof(bit_xor(DISTINCT c)) from t; +---- +12 UInt32 + +statement ok +drop table t; + +################# +# Min_Max Begin # +################# +# min_decimal, max_decimal +statement ok +CREATE TABLE decimals (value DECIMAL(10, 2)); + +statement ok +INSERT INTO decimals VALUES (123.0001), (124.00); + +query RR +SELECT MIN(value), MAX(value) FROM decimals; +---- +123 124 + +statement ok +DROP TABLE decimals; + +statement ok +CREATE TABLE decimals_batch (value DECIMAL(10, 0)); + +statement ok +INSERT INTO decimals_batch VALUES (1), (2), (3), (4), (5); + +query RR +SELECT MIN(value), MAX(value) FROM decimals_batch; +---- +1 5 + +statement ok +DROP TABLE decimals_batch; + +statement ok +CREATE TABLE decimals_empty (value DECIMAL(10, 0)); + +query RR +SELECT MIN(value), MAX(value) FROM decimals_empty; +---- +NULL NULL + +statement ok +DROP TABLE decimals_empty; + +# min_decimal_all_nulls, max_decimal_all_nulls +statement ok +CREATE TABLE decimals_all_nulls (value DECIMAL(10, 0)); + +statement ok +INSERT INTO decimals_all_nulls VALUES (NULL), (NULL), (NULL), (NULL), (NULL), (NULL); + +query RR +SELECT MIN(value), MAX(value) FROM decimals_all_nulls; +---- +NULL NULL + +statement ok +DROP TABLE decimals_all_nulls; + +# min_decimal_with_nulls, max_decimal_with_nulls +statement ok +CREATE TABLE decimals_with_nulls (value DECIMAL(10, 0)); + +statement ok +INSERT INTO decimals_with_nulls VALUES (1), (NULL), (3), (4), (5); + +query RR +SELECT MIN(value), MAX(value) FROM decimals_with_nulls; +---- +1 5 + +statement ok +DROP TABLE decimals_with_nulls; + +statement ok +CREATE TABLE decimals_error (value DECIMAL(10, 2)); + +statement ok +INSERT INTO decimals_error VALUES (123.00), (arrow_cast(124.001, 'Decimal128(10, 3)')); + +query RR +SELECT MIN(value), MAX(value) FROM decimals_error; +---- +123 124 + +statement ok +DROP TABLE decimals_error; + +statement ok +CREATE TABLE decimals_agg (value DECIMAL(10, 0)); + +statement ok +INSERT INTO decimals_agg VALUES (1), (2), (3), (4), (5); + +query RR +SELECT MIN(value), MAX(value) FROM decimals_agg; +---- +1 5 + +statement ok +DROP TABLE decimals_agg; + +# min_i32, max_i32 +statement ok +CREATE TABLE integers (value INT); + +statement ok +INSERT INTO integers VALUES (1), (2), (3), (4), (5); + +query II +SELECT MIN(value), MAX(value) FROM integers +---- +1 5 + +statement ok +DROP TABLE integers; + +# min_utf8, max_utf8 +statement ok +CREATE TABLE strings (value TEXT); + +statement ok +INSERT INTO strings VALUES ('d'), ('a'), ('c'), ('b'); + +query TT +SELECT MIN(value), MAX(value) FROM strings +---- +a d + +statement ok +DROP TABLE strings; + +# min_i32_with_nulls, max_i32_with_nulls +statement ok +CREATE TABLE integers_with_nulls (value INT); + +statement ok +INSERT INTO integers_with_nulls VALUES (1), (NULL), (3), (4), (5); + +query II +SELECT MIN(value), MAX(value) FROM integers_with_nulls +---- +1 5 + +# grouping_sets with null values +query II rowsort +SELECT value, min(value) FROM integers_with_nulls GROUP BY CUBE(value) +---- +1 1 +3 3 +4 4 +5 5 +NULL 1 +NULL NULL + + +statement ok +DROP TABLE integers_with_nulls; + +# min_i32_all_nulls, max_i32_all_nulls +statement ok +CREATE TABLE integers_all_nulls (value INT); + +query II +SELECT MIN(value), MAX(value) FROM integers_all_nulls +---- +NULL NULL + +statement ok +DROP TABLE integers_all_nulls; + +# min_u32, max_u32 +statement ok +CREATE TABLE uintegers (value INT UNSIGNED); + +statement ok +INSERT INTO uintegers VALUES (1), (2), (3), (4), (5); + +query II +SELECT MIN(value), MAX(value) FROM uintegers +---- +1 5 + +statement ok +DROP TABLE uintegers; + +# min_f32, max_f32 +statement ok +CREATE TABLE floats (value FLOAT); + +statement ok +INSERT INTO floats VALUES (1.0), (2.0), (3.0), (4.0), (5.0); + +query RR +SELECT MIN(value), MAX(value) FROM floats +---- +1 5 + +statement ok +DROP TABLE floats; + +# min_f64, max_f64 +statement ok +CREATE TABLE doubles (value DOUBLE); + +statement ok +INSERT INTO doubles VALUES (1.0), (2.0), (3.0), (4.0), (5.0); + +query RR +SELECT MIN(value), MAX(value) FROM doubles +---- +1 5 + +statement ok +DROP TABLE doubles; + +# min_date, max_date +statement ok +CREATE TABLE dates (value DATE); + +statement ok +INSERT INTO dates VALUES ('1970-01-02'), ('1970-01-03'), ('1970-01-04'), ('1970-01-05'), ('1970-01-06'); -# no group by column -query IIRIIIII -SELECT - count(column1), -- counts should be zero, even for nulls - sum(column1), -- other aggregates should be null - avg(column1), - min(column1), - max(column1), - bit_and(column1), - bit_or(column1), - bit_xor(column1) -FROM empty +query DD +SELECT MIN(value), MAX(value) FROM dates ---- -0 NULL NULL NULL NULL NULL NULL NULL +1970-01-02 1970-01-06 -# Same query but with grouping (no groups, so no output) -query IIRIIIIII -SELECT - count(column1), - sum(column1), - avg(column1), - min(column1), - max(column1), - bit_and(column1), - bit_or(column1), - bit_xor(column1), - column2 -FROM empty -GROUP BY column2 -ORDER BY column2; +statement ok +DROP TABLE dates; + +# min_seconds, max_seconds +statement ok +CREATE TABLE times (value TIME); + +statement ok +INSERT INTO times VALUES ('00:00:01'), ('00:00:02'), ('00:00:03'), ('00:00:04'), ('00:00:05'); + +query DD +SELECT MIN(value), MAX(value) FROM times ---- +00:00:01 00:00:05 +statement ok +DROP TABLE times; +# min_milliseconds, max_milliseconds statement ok -drop table empty +CREATE TABLE time32millisecond (value TIME); -# aggregates on all nulls statement ok -CREATE TABLE the_nulls -AS VALUES - (null::bigint, 1), - (null::bigint, 1), - (null::bigint, 2); +INSERT INTO time32millisecond VALUES ('00:00:00.001'), ('00:00:00.002'), ('00:00:00.003'), ('00:00:00.004'), ('00:00:00.005'); -query II -select * from the_nulls +query DD +SELECT MIN(value), MAX(value) FROM time32millisecond ---- -NULL 1 -NULL 1 -NULL 2 +00:00:00.001 00:00:00.005 -# no group by column -query IIRIIIII -SELECT - count(column1), -- counts should be zero, even for nulls - sum(column1), -- other aggregates should be null - avg(column1), - min(column1), - max(column1), - bit_and(column1), - bit_or(column1), - bit_xor(column1) -FROM the_nulls +statement ok +DROP TABLE time32millisecond; + +# min_microseconds, max_microseconds +statement ok +CREATE TABLE time64microsecond (value TIME); + +statement ok +INSERT INTO time64microsecond VALUES ('00:00:00.000001'), ('00:00:00.000002'), ('00:00:00.000003'), ('00:00:00.000004'), ('00:00:00.000005'); + +query DD +SELECT MIN(value), MAX(value) FROM time64microsecond ---- -0 NULL NULL NULL NULL NULL NULL NULL +00:00:00.000001 00:00:00.000005 -# Same query but with grouping -query IIRIIIIII -SELECT - count(column1), -- counts should be zero, even for nulls - sum(column1), -- other aggregates should be null - avg(column1), - min(column1), - max(column1), - bit_and(column1), - bit_or(column1), - bit_xor(column1), - column2 -FROM the_nulls -GROUP BY column2 -ORDER BY column2; +statement ok +DROP TABLE time64microsecond; + +# min_nanoseconds, max_nanoseconds +statement ok +CREATE TABLE time64nanosecond (value TIME); + +statement ok +INSERT INTO time64nanosecond VALUES ('00:00:00.000000001'), ('00:00:00.000000002'), ('00:00:00.000000003'), ('00:00:00.000000004'), ('00:00:00.000000005'); + +query DD +SELECT MIN(value), MAX(value) FROM time64nanosecond ---- -0 NULL NULL NULL NULL NULL NULL NULL 1 -0 NULL NULL NULL NULL NULL NULL NULL 2 +00:00:00.000000001 00:00:00.000000005 + +statement ok +DROP TABLE time64nanosecond; +# min_timestamp, max_timestamp +statement ok +CREATE TABLE timestampmicrosecond (value TIMESTAMP); statement ok -drop table the_nulls; +INSERT INTO timestampmicrosecond VALUES ('1970-01-01 00:00:00.000001'), ('1970-01-01 00:00:00.000002'), ('1970-01-01 00:00:00.000003'), ('1970-01-01 00:00:00.000004'), ('1970-01-01 00:00:00.000005'); + +query PP +SELECT MIN(value), MAX(value) FROM timestampmicrosecond +---- +1970-01-01T00:00:00.000001 1970-01-01T00:00:00.000005 statement ok -create table bit_aggregate_functions ( - c1 SMALLINT NOT NULL, - c2 SMALLINT NOT NULL, - c3 SMALLINT, - tag varchar -) -as values - (5, 10, 11, 'A'), - (33, 11, null, 'B'), - (9, 12, null, 'A'); +DROP TABLE timestampmicrosecond; + +# max_bool +statement ok +CREATE TABLE max_bool (value BOOLEAN); + +statement ok +INSERT INTO max_bool VALUES (false), (false); + +query B +SELECT MAX(value) FROM max_bool +---- +false + +statement ok +DROP TABLE max_bool; + +statement ok +CREATE TABLE max_bool (value BOOLEAN); + +statement ok +INSERT INTO max_bool VALUES (true), (true); + +query B +SELECT MAX(value) FROM max_bool +---- +true + +statement ok +DROP TABLE max_bool; + +statement ok +CREATE TABLE max_bool (value BOOLEAN); + +statement ok +INSERT INTO max_bool VALUES (false), (true), (false); + +query B +SELECT MAX(value) FROM max_bool +---- +true + +statement ok +DROP TABLE max_bool; + +statement ok +CREATE TABLE max_bool (value BOOLEAN); + +statement ok +INSERT INTO max_bool VALUES (true), (false), (true); + +query B +SELECT MAX(value) FROM max_bool +---- +true + +statement ok +DROP TABLE max_bool; + +# min_bool +statement ok +CREATE TABLE min_bool (value BOOLEAN); + +statement ok +INSERT INTO min_bool VALUES (false), (false); + +query B +SELECT MIN(value) FROM min_bool +---- +false + +statement ok +DROP TABLE min_bool; + +statement ok +CREATE TABLE min_bool (value BOOLEAN); + +statement ok +INSERT INTO min_bool VALUES (true), (true); + +query B +SELECT MIN(value) FROM min_bool +---- +true + +statement ok +DROP TABLE min_bool; + +statement ok +CREATE TABLE min_bool (value BOOLEAN); + +statement ok +INSERT INTO min_bool VALUES (false), (true), (false); + +query B +SELECT MIN(value) FROM min_bool +---- +false + +statement ok +DROP TABLE min_bool; + +statement ok +CREATE TABLE min_bool (value BOOLEAN); + +statement ok +INSERT INTO min_bool VALUES (true), (false), (true); + +query B +SELECT MIN(value) FROM min_bool +---- +false + +statement ok +DROP TABLE min_bool; + +################# +# Min_Max End # +################# + + + +################# +# min_max on strings/binary with null values and groups +################# + +statement ok +CREATE TABLE strings (value TEXT, id int); + +statement ok +INSERT INTO strings VALUES + ('c', 1), + ('d', 1), + ('a', 3), + ('c', 1), + ('b', 1), + (NULL, 1), + (NULL, 4), + ('d', 1), + ('z', 2), + ('c', 1), + ('a', 2); + +############ Utf8 ############ + +query IT +SELECT id, MIN(value) FROM strings GROUP BY id ORDER BY id; +---- +1 b +2 a +3 a +4 NULL + +query IT +SELECT id, MAX(value) FROM strings GROUP BY id ORDER BY id; +---- +1 d +2 z +3 a +4 NULL + +############ LargeUtf8 ############ + +statement ok +CREATE VIEW large_strings AS SELECT id, arrow_cast(value, 'LargeUtf8') as value FROM strings; + + +query IT +SELECT id, MIN(value) FROM large_strings GROUP BY id ORDER BY id; +---- +1 b +2 a +3 a +4 NULL + +query IT +SELECT id, MAX(value) FROM large_strings GROUP BY id ORDER BY id; +---- +1 d +2 z +3 a +4 NULL + +statement ok +DROP VIEW large_strings + +############ Utf8View ############ + +statement ok +CREATE VIEW string_views AS SELECT id, arrow_cast(value, 'Utf8View') as value FROM strings; + + +query IT +SELECT id, MIN(value) FROM string_views GROUP BY id ORDER BY id; +---- +1 b +2 a +3 a +4 NULL + +query IT +SELECT id, MAX(value) FROM string_views GROUP BY id ORDER BY id; +---- +1 d +2 z +3 a +4 NULL + +statement ok +DROP VIEW string_views + +############ Binary ############ + +statement ok +CREATE VIEW binary AS SELECT id, arrow_cast(value, 'Binary') as value FROM strings; + + +query I? +SELECT id, MIN(value) FROM binary GROUP BY id ORDER BY id; +---- +1 62 +2 61 +3 61 +4 NULL + +query I? +SELECT id, MAX(value) FROM binary GROUP BY id ORDER BY id; +---- +1 64 +2 7a +3 61 +4 NULL + +statement ok +DROP VIEW binary + +############ LargeBinary ############ + +statement ok +CREATE VIEW large_binary AS SELECT id, arrow_cast(value, 'LargeBinary') as value FROM strings; + + +query I? +SELECT id, MIN(value) FROM large_binary GROUP BY id ORDER BY id; +---- +1 62 +2 61 +3 61 +4 NULL + +query I? +SELECT id, MAX(value) FROM large_binary GROUP BY id ORDER BY id; +---- +1 64 +2 7a +3 61 +4 NULL + +statement ok +DROP VIEW large_binary + +############ BinaryView ############ + +statement ok +CREATE VIEW binary_views AS SELECT id, arrow_cast(value, 'BinaryView') as value FROM strings; -# query_bit_and, query_bit_or, query_bit_xor -query IIIIIIIII -SELECT - bit_and(c1), - bit_and(c2), - bit_and(c3), - bit_or(c1), - bit_or(c2), - bit_or(c3), - bit_xor(c1), - bit_xor(c2), - bit_xor(c3) -FROM bit_aggregate_functions + +query I? +SELECT id, MIN(value) FROM binary_views GROUP BY id ORDER BY id; ---- -1 8 11 45 15 11 45 13 11 +1 62 +2 61 +3 61 +4 NULL -# query_bit_and, query_bit_or, query_bit_xor, with group -query IIIIIIIIIT -SELECT - bit_and(c1), - bit_and(c2), - bit_and(c3), - bit_or(c1), - bit_or(c2), - bit_or(c3), - bit_xor(c1), - bit_xor(c2), - bit_xor(c3), - tag -FROM bit_aggregate_functions -GROUP BY tag -ORDER BY tag +query I? +SELECT id, MAX(value) FROM binary_views GROUP BY id ORDER BY id; ---- -1 8 11 13 14 11 12 6 11 A -33 11 NULL 33 11 NULL 33 11 NULL B +1 64 +2 7a +3 61 +4 NULL + +statement ok +DROP VIEW binary_views + +statement ok +DROP TABLE strings; + +################# +# End min_max on strings/binary with null values and groups +################# statement ok @@ -1849,7 +4001,7 @@ create table bool_aggregate_functions ( c5 boolean, c6 boolean, c7 boolean, - c8 boolean, + c8 boolean ) as values (true, true, false, false, true, true, null, null), @@ -1880,6 +4032,51 @@ SELECT bool_or(distinct c1), bool_or(distinct c2), bool_or(distinct c3), bool_or ---- true true true false true true false NULL +# Test issue: https://github.com/apache/datafusion/issues/11846 +statement ok +create table t1(v1 int, v2 boolean); + +statement ok +insert into t1 values (1, true), (1, true); + +statement ok +insert into t1 values (3, null), (3, true); + +statement ok +insert into t1 values (2, false), (2, true); + +statement ok +insert into t1 values (6, false), (6, false); + +statement ok +insert into t1 values (4, null), (4, null); + +statement ok +insert into t1 values (5, false), (5, null); + +query IB +select v1, bool_and(v2) from t1 group by v1 order by v1; +---- +1 true +2 false +3 true +4 NULL +5 false +6 false + +query IB +select v1, bool_or(v2) from t1 group by v1 order by v1; +---- +1 true +2 true +3 true +4 NULL +5 false +6 false + +statement ok +drop table t1; + # All supported timestamp types # "nanos" --> TimestampNanosecondArray @@ -1911,7 +4108,7 @@ select column3 as tag from t_source; -# Demonstate the contents +# Demonstrate the contents query PPPPPPPPTT select * from t; ---- @@ -1922,10 +4119,10 @@ NULL NULL NULL NULL NULL NULL NULL NULL Row 2 Y # aggregate_timestamps_sum -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'SUM\(Timestamp\(Nanosecond, None\)\)'\. You might need to add explicit type casts\. +query error SELECT sum(nanos), sum(micros), sum(millis), sum(secs) FROM t; -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'SUM\(Timestamp\(Nanosecond, None\)\)'\. You might need to add explicit type casts\. +query error SELECT tag, sum(nanos), sum(micros), sum(millis), sum(secs) FROM t GROUP BY tag ORDER BY tag; # aggregate_timestamps_count @@ -1977,10 +4174,10 @@ X 2 2 2 2 Y 1 1 1 1 # aggregate_timestamps_avg -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'AVG\(Timestamp\(Nanosecond, None\)\)'\. You might need to add explicit type casts\. +query error SELECT avg(nanos), avg(micros), avg(millis), avg(secs) FROM t -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'AVG\(Timestamp\(Nanosecond, None\)\)'\. You might need to add explicit type casts\. +query error SELECT tag, avg(nanos), avg(micros), avg(millis), avg(secs) FROM t GROUP BY tag ORDER BY tag; # aggregate_duration_array_agg @@ -2022,7 +4219,7 @@ select column3 as tag from t_source; -# Demonstate the contents +# Demonstrate the contents query DDTT select * from t; ---- @@ -2033,10 +4230,10 @@ NULL NULL Row 2 Y # aggregate_timestamps_sum -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'SUM\(Date32\)'\. You might need to add explicit type casts\. +query error SELECT sum(date32), sum(date64) FROM t; -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'SUM\(Date32\)'\. You might need to add explicit type casts\. +query error SELECT tag, sum(date32), sum(date64) FROM t GROUP BY tag ORDER BY tag; # aggregate_timestamps_count @@ -2077,10 +4274,10 @@ Y 2021-01-01 2021-01-01T00:00:00 # aggregate_timestamps_avg -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'AVG\(Date32\)'\. You might need to add explicit type casts\. +query error SELECT avg(date32), avg(date64) FROM t -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'AVG\(Date32\)'\. You might need to add explicit type casts\. +query error SELECT tag, avg(date32), avg(date64) FROM t GROUP BY tag ORDER BY tag; @@ -2120,7 +4317,7 @@ select column3 as tag from t_source; -# Demonstate the contents +# Demonstrate the contents query DDDDTT select * from t; ---- @@ -2130,10 +4327,10 @@ select * from t; 21:06:28.247821084 21:06:28.247821 21:06:28.247 21:06:28 Row 3 B # aggregate_times_sum -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'SUM\(Time64\(Nanosecond\)\)'\. You might need to add explicit type casts\. +query error SELECT sum(nanos), sum(micros), sum(millis), sum(secs) FROM t -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'SUM\(Time64\(Nanosecond\)\)'\. You might need to add explicit type casts\. +query error SELECT tag, sum(nanos), sum(micros), sum(millis), sum(secs) FROM t GROUP BY tag ORDER BY tag # aggregate_times_count @@ -2175,10 +4372,10 @@ B 21:06:28.247821084 21:06:28.247821 21:06:28.247 21:06:28 # aggregate_times_avg -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'AVG\(Time64\(Nanosecond\)\)'\. You might need to add explicit type casts\. +query error SELECT avg(nanos), avg(micros), avg(millis), avg(secs) FROM t -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'AVG\(Time64\(Nanosecond\)\)'\. You might need to add explicit type casts\. +query error SELECT tag, avg(nanos), avg(micros), avg(millis), avg(secs) FROM t GROUP BY tag ORDER BY tag; statement ok @@ -2314,6 +4511,81 @@ select sum(c1), arrow_typeof(sum(c1)) from d_table; ---- 100 Decimal128(20, 3) +# aggregate sum with decimal +statement ok +create table t (c decimal(35, 3)) as values (10), (null), (20); + +query RT +select sum(c), arrow_typeof(sum(c)) from t; +---- +30 Decimal128(38, 3) + +statement ok +drop table t; + +# aggregate sum with i32, sum coerced result to i64 +statement ok +create table t (c int) as values (1), (-1), (10), (null), (-11); + +query IT +select sum(c), arrow_typeof(sum(c)) from t; +---- +-1 Int64 + +statement ok +drop table t; + +# aggregate sum with all nulls +statement ok +create table t (c1 decimal(10, 0), c2 int) as values (null, null), (null, null), (null, null); + +query RTIT +select + sum(c1), arrow_typeof(sum(c1)), + sum(c2), arrow_typeof(sum(c2)) +from t; +---- +NULL Decimal128(20, 0) NULL Int64 + +statement ok +drop table t; + +# aggregate sum with u32, sum coerced result to u64 +statement ok +create table t (c int unsigned) as values (1), (0), (10), (null), (4); + +query IT +select sum(c), arrow_typeof(sum(c)) from t; +---- +15 UInt64 + +statement ok +drop table t; + +# aggregate sum with f32, sum coerced result to f64 +statement ok +create table t (c float) as values (1.2), (0.2), (-1.2), (null), (-1.0); + +query RT +select sum(c), arrow_typeof(sum(c)) from t; +---- +-0.79999999702 Float64 + +statement ok +drop table t; + +# aggregate sum with f64 +statement ok +create table t (c double) as values (1.2), (0.2), (-1.2), (null), (-1.0); + +query RT +select sum(c), arrow_typeof(sum(c)) from t; +---- +-0.8 Float64 + +statement ok +drop table t; + query TRT select c2, sum(c1), arrow_typeof(sum(c1)) from d_table GROUP BY c2 ORDER BY c2; ---- @@ -2486,7 +4758,7 @@ set datafusion.sql_parser.dialect = 'Generic'; statement ok create table dict_test as values (1, arrow_cast('foo', 'Dictionary(Int32, Utf8)')), (2, arrow_cast('bar', 'Dictionary(Int32, Utf8)')); -query I? +query IT select * from dict_test; ---- 1 foo @@ -2537,7 +4809,7 @@ select avg(distinct x_dict) from value_dict; ---- 3 -statement error DataFusion error: This feature is not implemented: AVG\(DISTINCT\) aggregations are not available +query error select avg(x_dict), avg(distinct x_dict) from value_dict; query I @@ -2640,6 +4912,33 @@ false true NULL +# +# Add valid distinct case as aggregation plan test +# + +query TT +EXPLAIN SELECT DISTINCT c3, min(c1) FROM aggregate_test_100 group by c3 limit 5; +---- +logical_plan +01)Limit: skip=0, fetch=5 +02)--Aggregate: groupBy=[[aggregate_test_100.c3, min(aggregate_test_100.c1)]], aggr=[[]] +03)----Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[min(aggregate_test_100.c1)]] +04)------TableScan: aggregate_test_100 projection=[c1, c3] +physical_plan +01)GlobalLimitExec: skip=0, fetch=5 +02)--CoalescePartitionsExec +03)----AggregateExec: mode=FinalPartitioned, gby=[c3@0 as c3, min(aggregate_test_100.c1)@1 as min(aggregate_test_100.c1)], aggr=[], lim=[5] +04)------CoalesceBatchesExec: target_batch_size=8192 +05)--------RepartitionExec: partitioning=Hash([c3@0, min(aggregate_test_100.c1)@1], 4), input_partitions=4 +06)----------AggregateExec: mode=Partial, gby=[c3@0 as c3, min(aggregate_test_100.c1)@1 as min(aggregate_test_100.c1)], aggr=[], lim=[5] +07)------------AggregateExec: mode=FinalPartitioned, gby=[c3@0 as c3], aggr=[min(aggregate_test_100.c1)] +08)--------------CoalesceBatchesExec: target_batch_size=8192 +09)----------------RepartitionExec: partitioning=Hash([c3@0], 4), input_partitions=4 +10)------------------AggregateExec: mode=Partial, gby=[c3@1 as c3], aggr=[min(aggregate_test_100.c1)] +11)--------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +12)----------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c3], has_header=true + + # # Push limit into distinct group-by aggregation tests # @@ -2655,19 +4954,14 @@ EXPLAIN SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 limit 5; logical_plan 01)Limit: skip=0, fetch=5 02)--Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]] -03)----Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]] -04)------TableScan: aggregate_test_100 projection=[c3] +03)----TableScan: aggregate_test_100 projection=[c3] physical_plan 01)GlobalLimitExec: skip=0, fetch=5 02)--AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[], lim=[5] 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[], lim=[5] 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -06)----------AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[], lim=[5] -07)------------CoalescePartitionsExec -08)--------------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[], lim=[5] -09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -10)------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3], has_header=true +06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3], has_header=true query I SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 limit 5; @@ -2741,16 +5035,16 @@ query TT EXPLAIN SELECT max(c1), c2, c3 FROM aggregate_test_100 group by c2, c3 limit 5; ---- logical_plan -01)Projection: MAX(aggregate_test_100.c1), aggregate_test_100.c2, aggregate_test_100.c3 +01)Projection: max(aggregate_test_100.c1), aggregate_test_100.c2, aggregate_test_100.c3 02)--Limit: skip=0, fetch=5 -03)----Aggregate: groupBy=[[aggregate_test_100.c2, aggregate_test_100.c3]], aggr=[[MAX(aggregate_test_100.c1)]] +03)----Aggregate: groupBy=[[aggregate_test_100.c2, aggregate_test_100.c3]], aggr=[[max(aggregate_test_100.c1)]] 04)------TableScan: aggregate_test_100 projection=[c1, c2, c3] physical_plan -01)ProjectionExec: expr=[MAX(aggregate_test_100.c1)@2 as MAX(aggregate_test_100.c1), c2@0 as c2, c3@1 as c3] +01)ProjectionExec: expr=[max(aggregate_test_100.c1)@2 as max(aggregate_test_100.c1), c2@0 as c2, c3@1 as c3] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[MAX(aggregate_test_100.c1)] +03)----AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[max(aggregate_test_100.c1)] 04)------CoalescePartitionsExec -05)--------AggregateExec: mode=Partial, gby=[c2@1 as c2, c3@2 as c3], aggr=[MAX(aggregate_test_100.c1)] +05)--------AggregateExec: mode=Partial, gby=[c2@1 as c2, c3@2 as c3], aggr=[max(aggregate_test_100.c1)] 06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 07)------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3], has_header=true @@ -2789,16 +5083,18 @@ query TT EXPLAIN SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3; ---- logical_plan -01)Limit: skip=0, fetch=3 -02)--Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]] -03)----TableScan: aggregate_test_100 projection=[c2, c3] +01)Projection: aggregate_test_100.c2, aggregate_test_100.c3 +02)--Limit: skip=0, fetch=3 +03)----Aggregate: groupBy=[[ROLLUP (aggregate_test_100.c2, aggregate_test_100.c3)]], aggr=[[]] +04)------TableScan: aggregate_test_100 projection=[c2, c3] physical_plan -01)GlobalLimitExec: skip=0, fetch=3 -02)--AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3], aggr=[], lim=[3] -03)----CoalescePartitionsExec -04)------AggregateExec: mode=Partial, gby=[(NULL as c2, NULL as c3), (c2@0 as c2, NULL as c3), (c2@0 as c2, c3@1 as c3)], aggr=[] -05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true +01)ProjectionExec: expr=[c2@0 as c2, c3@1 as c3] +02)--GlobalLimitExec: skip=0, fetch=3 +03)----AggregateExec: mode=Final, gby=[c2@0 as c2, c3@1 as c3, __grouping_id@2 as __grouping_id], aggr=[], lim=[3] +04)------CoalescePartitionsExec +05)--------AggregateExec: mode=Partial, gby=[(NULL as c2, NULL as c3), (c2@0 as c2, NULL as c3), (c2@0 as c2, c3@1 as c3)], aggr=[] +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3], has_header=true query II SELECT c2, c3 FROM aggregate_test_100 group by rollup(c2, c3) limit 3; @@ -2818,19 +5114,14 @@ EXPLAIN SELECT DISTINCT c3 FROM aggregate_test_100 group by c3 limit 5; logical_plan 01)Limit: skip=0, fetch=5 02)--Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]] -03)----Aggregate: groupBy=[[aggregate_test_100.c3]], aggr=[[]] -04)------TableScan: aggregate_test_100 projection=[c3] +03)----TableScan: aggregate_test_100 projection=[c3] physical_plan 01)GlobalLimitExec: skip=0, fetch=5 02)--AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[] 03)----CoalescePartitionsExec 04)------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[] 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -06)----------AggregateExec: mode=Final, gby=[c3@0 as c3], aggr=[] -07)------------CoalescePartitionsExec -08)--------------AggregateExec: mode=Partial, gby=[c3@0 as c3], aggr=[] -09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -10)------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3], has_header=true +06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3], has_header=true statement ok set datafusion.optimizer.enable_distinct_aggregation_soft_limit = true; @@ -2873,27 +5164,27 @@ select regr_sxy(NULL, 'bar'); # regr_*() NULL results -query RRRRRRRRR +query RRIRRRRRR select regr_slope(1,1), regr_intercept(1,1), regr_count(1,1), regr_r2(1,1), regr_avgx(1,1), regr_avgy(1,1), regr_sxx(1,1), regr_syy(1,1), regr_sxy(1,1); ---- NULL NULL 1 NULL 1 1 0 0 0 -query RRRRRRRRR +query RRIRRRRRR select regr_slope(1, NULL), regr_intercept(1, NULL), regr_count(1, NULL), regr_r2(1, NULL), regr_avgx(1, NULL), regr_avgy(1, NULL), regr_sxx(1, NULL), regr_syy(1, NULL), regr_sxy(1, NULL); ---- NULL NULL 0 NULL NULL NULL NULL NULL NULL -query RRRRRRRRR +query RRIRRRRRR select regr_slope(NULL, 1), regr_intercept(NULL, 1), regr_count(NULL, 1), regr_r2(NULL, 1), regr_avgx(NULL, 1), regr_avgy(NULL, 1), regr_sxx(NULL, 1), regr_syy(NULL, 1), regr_sxy(NULL, 1); ---- NULL NULL 0 NULL NULL NULL NULL NULL NULL -query RRRRRRRRR +query RRIRRRRRR select regr_slope(NULL, NULL), regr_intercept(NULL, NULL), regr_count(NULL, NULL), regr_r2(NULL, NULL), regr_avgx(NULL, NULL), regr_avgy(NULL, NULL), regr_sxx(NULL, NULL), regr_syy(NULL, NULL), regr_sxy(NULL, NULL); ---- NULL NULL 0 NULL NULL NULL NULL NULL NULL -query RRRRRRRRR +query RRIRRRRRR select regr_slope(column2, column1), regr_intercept(column2, column1), regr_count(column2, column1), regr_r2(column2, column1), regr_avgx(column2, column1), regr_avgy(column2, column1), regr_sxx(column2, column1), regr_syy(column2, column1), regr_sxy(column2, column1) from (values (1,2), (1,4), (1,6)); ---- NULL NULL 3 NULL 1 4 0 8 0 @@ -2901,7 +5192,7 @@ NULL NULL 3 NULL 1 4 0 8 0 # regr_*() basic tests -query RRRRRRRRR +query RRIRRRRRR select regr_slope(column2, column1), regr_intercept(column2, column1), @@ -2916,7 +5207,7 @@ from (values (1,2), (2,4), (3,6)); ---- 2 0 3 1 2 4 2 8 4 -query RRRRRRRRR +query RRIRRRRRR select regr_slope(c12, c11), regr_intercept(c12, c11), @@ -2934,7 +5225,7 @@ from aggregate_test_100; # regr_*() functions ignore NULLs -query RRRRRRRRR +query RRIRRRRRR select regr_slope(column2, column1), regr_intercept(column2, column1), @@ -2949,7 +5240,7 @@ from (values (1,NULL), (2,4), (3,6)); ---- 2 0 2 1 2.5 5 0.5 2 1 -query RRRRRRRRR +query RRIRRRRRR select regr_slope(column2, column1), regr_intercept(column2, column1), @@ -2964,7 +5255,7 @@ from (values (1,NULL), (NULL,4), (3,6)); ---- NULL NULL 1 NULL 3 6 0 0 0 -query RRRRRRRRR +query RRIRRRRRR select regr_slope(column2, column1), regr_intercept(column2, column1), @@ -2979,7 +5270,7 @@ from (values (1,NULL), (NULL,4), (NULL,NULL)); ---- NULL NULL 0 NULL NULL NULL NULL NULL NULL -query TRRRRRRRRR rowsort +query TRRIRRRRRR rowsort select column3, regr_slope(column2, column1), @@ -3004,7 +5295,7 @@ c NULL NULL 1 NULL 1 10 0 0 0 statement ok set datafusion.execution.batch_size = 1; -query RRRRRRRRR +query RRIRRRRRR select regr_slope(c12, c11), regr_intercept(c12, c11), @@ -3022,7 +5313,7 @@ from aggregate_test_100; statement ok set datafusion.execution.batch_size = 2; -query RRRRRRRRR +query RRIRRRRRR select regr_slope(c12, c11), regr_intercept(c12, c11), @@ -3040,7 +5331,7 @@ from aggregate_test_100; statement ok set datafusion.execution.batch_size = 3; -query RRRRRRRRR +query RRIRRRRRR select regr_slope(c12, c11), regr_intercept(c12, c11), @@ -3061,7 +5352,7 @@ set datafusion.execution.batch_size = 8192; # regr_*() testing retract_batch() from RegrAccumulator's internal implementation -query RRRRRRRRR +query RRIRRRRRR SELECT regr_slope(column2, column1) OVER w AS slope, regr_intercept(column2, column1) OVER w AS intercept, @@ -3082,7 +5373,7 @@ NULL NULL 1 NULL 1 2 0 0 0 4.5 -7 3 0.964285714286 4 11 2 42 9 3 0 3 1 5 15 2 18 6 -query RRRRRRRRR +query RRIRRRRRR SELECT regr_slope(column2, column1) OVER w AS slope, regr_intercept(column2, column1) OVER w AS intercept, @@ -3137,7 +5428,7 @@ SELECT STRING_AGG(column1, '|') FROM (values (''), (null), ('')); statement ok CREATE TABLE strings(g INTEGER, x VARCHAR, y VARCHAR) -query ITT +query I INSERT INTO strings VALUES (1,'a','/'), (1,'b','-'), (2,'i','/'), (2,NULL,'-'), (2,'j','+'), (3,'p','/'), (4,'x','/'), (4,'y','-'), (4,'z','+') ---- 9 @@ -3181,6 +5472,57 @@ GROUP BY dummy ---- text1, text1, text1 +# Tests for aggregating with NaN values +statement ok +CREATE TABLE float_table ( + col_f32 FLOAT, + col_f32_nan FLOAT, + col_f64 DOUBLE, + col_f64_nan DOUBLE +) as VALUES +( -128.2, -128.2, -128.2, -128.2 ), +( 32768.3, arrow_cast('NAN','Float32'), 32768.3, 32768.3 ), +( 27.3, 27.3, 27.3, arrow_cast('NAN','Float64') ); + +# Test string_agg with largeutf8 +statement ok +create table string_agg_large_utf8 (c string) as values + (arrow_cast('a', 'LargeUtf8')), + (arrow_cast('b', 'LargeUtf8')), + (arrow_cast('c', 'LargeUtf8')) +; + +query T +SELECT STRING_AGG(c, ',') FROM string_agg_large_utf8; +---- +a,b,c + +statement ok +drop table string_agg_large_utf8; + +query RRRRI +select min(col_f32), max(col_f32), avg(col_f32), sum(col_f32), count(col_f32) from float_table; +---- +-128.2 32768.3 10889.13359451294 32667.40078353882 3 + +query RRRRI +select min(col_f32_nan), max(col_f32_nan), avg(col_f32_nan), sum(col_f32_nan), count(col_f32_nan) from float_table; +---- +-128.2 NaN NaN NaN 3 + +query RRRRI +select min(col_f64), max(col_f64), avg(col_f64), sum(col_f64), count(col_f64) from float_table; +---- +-128.2 32768.3 10889.133333333333 32667.4 3 + +query RRRRI +select min(col_f64_nan), max(col_f64_nan), avg(col_f64_nan), sum(col_f64_nan), count(col_f64_nan) from float_table; +---- +-128.2 NaN NaN NaN 3 + +statement ok +drop table float_table + # Queries with nested count(*) @@ -3330,25 +5672,37 @@ query TT EXPLAIN SELECT MIN(col0) FROM empty; ---- logical_plan -01)Aggregate: groupBy=[[]], aggr=[[MIN(empty.col0)]] +01)Aggregate: groupBy=[[]], aggr=[[min(empty.col0)]] 02)--TableScan: empty projection=[col0] physical_plan -01)ProjectionExec: expr=[NULL as MIN(empty.col0)] +01)ProjectionExec: expr=[NULL as min(empty.col0)] 02)--PlaceholderRowExec query TT EXPLAIN SELECT MAX(col0) FROM empty; ---- logical_plan -01)Aggregate: groupBy=[[]], aggr=[[MAX(empty.col0)]] +01)Aggregate: groupBy=[[]], aggr=[[max(empty.col0)]] 02)--TableScan: empty projection=[col0] physical_plan -01)ProjectionExec: expr=[NULL as MAX(empty.col0)] +01)ProjectionExec: expr=[NULL as max(empty.col0)] 02)--PlaceholderRowExec statement ok DROP TABLE empty; +# verify count aggregate function should not be nullable +statement ok +create table empty; + +query I +select distinct count() from empty; +---- +0 + +statement ok +DROP TABLE empty; + statement ok CREATE TABLE t(col0 INTEGER) as VALUES(2); @@ -3498,6 +5852,34 @@ SELECT LAST_VALUE(column1 ORDER BY column2 DESC) IGNORE NULLS FROM t; statement ok DROP TABLE t; +# Test for CASE with NULL in aggregate function +statement ok +CREATE TABLE example(data double precision); + +statement ok +INSERT INTO example VALUES (1), (2), (NULL), (4); + +query RR +SELECT + sum(CASE WHEN data is NULL THEN NULL ELSE data+1 END) as then_null, + sum(CASE WHEN data is NULL THEN data+1 ELSE NULL END) as else_null +FROM example; +---- +10 NULL + +query R +SELECT + CASE data WHEN 1 THEN NULL WHEN 2 THEN 3.3 ELSE NULL END as case_null +FROM example; +---- +NULL +3.3 +NULL +NULL + +statement ok +drop table example; + # Test Convert FirstLast optimizer rule statement ok CREATE EXTERNAL TABLE convert_first_last_table ( @@ -3506,23 +5888,23 @@ c2 INT NOT NULL, c3 INT NOT NULL ) STORED AS CSV -WITH HEADER ROW WITH ORDER (c1 ASC) WITH ORDER (c2 DESC) WITH ORDER (c3 ASC) -LOCATION '../core/tests/data/convert_first_last.csv'; +LOCATION '../core/tests/data/convert_first_last.csv' +OPTIONS ('format.has_header' 'true'); # test first to last, the result does not show difference, we need to check the conversion by `explain` query TT explain select first_value(c1 order by c3 desc) from convert_first_last_table; ---- logical_plan -01)Aggregate: groupBy=[[]], aggr=[[FIRST_VALUE(convert_first_last_table.c1) ORDER BY [convert_first_last_table.c3 DESC NULLS FIRST]]] +01)Aggregate: groupBy=[[]], aggr=[[first_value(convert_first_last_table.c1) ORDER BY [convert_first_last_table.c3 DESC NULLS FIRST]]] 02)--TableScan: convert_first_last_table projection=[c1, c3] physical_plan -01)AggregateExec: mode=Final, gby=[], aggr=[FIRST_VALUE(convert_first_last_table.c1) ORDER BY [convert_first_last_table.c3 DESC NULLS FIRST]] +01)AggregateExec: mode=Final, gby=[], aggr=[first_value(convert_first_last_table.c1) ORDER BY [convert_first_last_table.c3 DESC NULLS FIRST]] 02)--CoalescePartitionsExec -03)----AggregateExec: mode=Partial, gby=[], aggr=[LAST_VALUE(convert_first_last_table.c1) ORDER BY [convert_first_last_table.c3 ASC NULLS LAST]] +03)----AggregateExec: mode=Partial, gby=[], aggr=[last_value(convert_first_last_table.c1) ORDER BY [convert_first_last_table.c3 ASC NULLS LAST]] 04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/convert_first_last.csv]]}, projection=[c1, c3], output_orderings=[[c1@0 ASC NULLS LAST], [c3@1 ASC NULLS LAST]], has_header=true @@ -3531,11 +5913,176 @@ query TT explain select last_value(c1 order by c2 asc) from convert_first_last_table; ---- logical_plan -01)Aggregate: groupBy=[[]], aggr=[[LAST_VALUE(convert_first_last_table.c1) ORDER BY [convert_first_last_table.c2 ASC NULLS LAST]]] +01)Aggregate: groupBy=[[]], aggr=[[last_value(convert_first_last_table.c1) ORDER BY [convert_first_last_table.c2 ASC NULLS LAST]]] 02)--TableScan: convert_first_last_table projection=[c1, c2] physical_plan -01)AggregateExec: mode=Final, gby=[], aggr=[LAST_VALUE(convert_first_last_table.c1) ORDER BY [convert_first_last_table.c2 ASC NULLS LAST]] +01)AggregateExec: mode=Final, gby=[], aggr=[last_value(convert_first_last_table.c1) ORDER BY [convert_first_last_table.c2 ASC NULLS LAST]] 02)--CoalescePartitionsExec -03)----AggregateExec: mode=Partial, gby=[], aggr=[FIRST_VALUE(convert_first_last_table.c1) ORDER BY [convert_first_last_table.c2 DESC NULLS FIRST]] +03)----AggregateExec: mode=Partial, gby=[], aggr=[first_value(convert_first_last_table.c1) ORDER BY [convert_first_last_table.c2 DESC NULLS FIRST]] 04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/convert_first_last.csv]]}, projection=[c1, c2], output_orderings=[[c1@0 ASC NULLS LAST], [c2@1 DESC]], has_header=true + +# test building plan with aggreagte sum + +statement ok +create table employee_csv(id int, first_name string, last_name varchar, state varchar, salary bigint) as values (1, 'jenson', 'huang', 'unemployed', 10); + +query TI +select state, sum(salary) total_salary from employee_csv group by state; +---- +unemployed 10 + +statement ok +set datafusion.explain.logical_plan_only = true; + +query TT +explain select state, sum(salary) as total_salary from employee_csv group by state; +---- +logical_plan +01)Projection: employee_csv.state, sum(employee_csv.salary) AS total_salary +02)--Aggregate: groupBy=[[employee_csv.state]], aggr=[[sum(employee_csv.salary)]] +03)----TableScan: employee_csv projection=[state, salary] + +# fail if there is duplicate name +query error DataFusion error: Schema error: Schema contains qualified field name employee_csv\.state and unqualified field name state which would be ambiguous +select state, sum(salary) as state from employee_csv group by state; + +statement ok +set datafusion.explain.logical_plan_only = false; + +statement ok +drop table employee_csv; + +# test null literal handling in supported aggregate functions +query I??III?T +select count(null), min(null), max(null), bit_and(NULL), bit_or(NULL), bit_xor(NULL), nth_value(NULL, 1), string_agg(NULL, ','); +---- +0 NULL NULL NULL NULL NULL NULL NULL + +statement ok +create table having_test(v1 int, v2 int) + +statement ok +create table join_table(v1 int, v2 int) + +statement ok +insert into having_test values (1, 2), (2, 3), (3, 4) + +statement ok +insert into join_table values (1, 2), (2, 3), (3, 4) + + +query II +select * from having_test group by v1, v2 having max(v1) = 3 +---- +3 4 + +query TT +EXPLAIN select * from having_test group by v1, v2 having max(v1) = 3 +---- +logical_plan +01)Projection: having_test.v1, having_test.v2 +02)--Filter: max(having_test.v1) = Int32(3) +03)----Aggregate: groupBy=[[having_test.v1, having_test.v2]], aggr=[[max(having_test.v1)]] +04)------TableScan: having_test projection=[v1, v2] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: max(having_test.v1)@2 = 3, projection=[v1@0, v2@1] +03)----AggregateExec: mode=FinalPartitioned, gby=[v1@0 as v1, v2@1 as v2], aggr=[max(having_test.v1)] +04)------CoalesceBatchesExec: target_batch_size=8192 +05)--------RepartitionExec: partitioning=Hash([v1@0, v2@1], 4), input_partitions=4 +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------AggregateExec: mode=Partial, gby=[v1@0 as v1, v2@1 as v2], aggr=[max(having_test.v1)] +08)--------------MemoryExec: partitions=1, partition_sizes=[1] + + +query error +select * from having_test having max(v1) = 3 + +query I +select max(v1) from having_test having max(v1) = 3 +---- +3 + +query I +select max(v1), * exclude (v1, v2) from having_test having max(v1) = 3 +---- +3 + +# because v1, v2 is not in the group by clause, the sql is invalid +query III +select max(v1), * replace ('v1' as v3) from having_test group by v1, v2 having max(v1) = 3 +---- +3 3 4 + +query III +select max(v1), t.* from having_test t group by v1, v2 having max(v1) = 3 +---- +3 3 4 + +# j.* should also be included in the group-by clause +query error +select max(t.v1), j.* from having_test t join join_table j on t.v1 = j.v1 group by t.v1, t.v2 having max(t.v1) = 3 + +query III +select max(t.v1), j.* from having_test t join join_table j on t.v1 = j.v1 group by j.v1, j.v2 having max(t.v1) = 3 +---- +3 3 4 + +# If the select items only contain scalar expressions, the having clause is valid. +query P +select now() from having_test having max(v1) = 4 +---- + +# If the select items only contain scalar expressions, the having clause is valid. +query I +select 0 from having_test having max(v1) = 4 +---- + +# v2 should also be included in group-by clause +query error +select * from having_test group by v1 having max(v1) = 3 + +statement ok +drop table having_test + +statement ok +drop table join_table + +# test min/max Float16 without group expression +query RRTT +WITH data AS ( + SELECT arrow_cast(1, 'Float16') AS f + UNION ALL + SELECT arrow_cast(6, 'Float16') AS f +) +SELECT MIN(f), MAX(f), arrow_typeof(MIN(f)), arrow_typeof(MAX(f)) FROM data; +---- +1 6 Float16 Float16 + +# test min/max Float16 with group expression +query IRRTT +WITH data AS ( + SELECT 1 as k, arrow_cast(1.8125, 'Float16') AS f + UNION ALL + SELECT 1 as k, arrow_cast(6.8007813, 'Float16') AS f + UNION ALL + SELECT 2 AS k, arrow_cast(8.5, 'Float16') AS f +) +SELECT k, MIN(f), MAX(f), arrow_typeof(MIN(f)), arrow_typeof(MAX(f)) +FROM data +GROUP BY k +ORDER BY k; +---- +1 1.8125 6.8007813 Float16 Float16 +2 8.5 8.5 Float16 Float16 + +statement ok +CREATE TABLE t1(v1 int); + +# issue: https://github.com/apache/datafusion/issues/12814 +statement error DataFusion error: Error during planning: Aggregate functions are not allowed in the WHERE clause. Consider using HAVING instead +SELECT v1 FROM t1 WHERE ((count(v1) % 1) << 1) > 0; + +statement ok +DROP TABLE t1; diff --git a/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt b/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt new file mode 100644 index 000000000000..a2e51cffacf7 --- /dev/null +++ b/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt @@ -0,0 +1,713 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# The main goal of these tests is to verify correctness of transforming +# input values to state by accumulators, supporting `convert_to_state`. + + +# Setup test data table +statement ok +CREATE EXTERNAL TABLE aggregate_test_100 ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT, + c5 INT, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 INT UNSIGNED NOT NULL, + c10 BIGINT UNSIGNED NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL +) +STORED AS CSV +LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); + +# Table to test `bool_and()`, `bool_or()` aggregate functions +statement ok +CREATE TABLE aggregate_test_100_bool ( + v1 VARCHAR NOT NULL, + v2 BOOLEAN, + v3 BOOLEAN +); + +statement ok +INSERT INTO aggregate_test_100_bool +SELECT + c1 as v1, + CASE WHEN c2 > 3 THEN TRUE WHEN c2 > 1 THEN FALSE ELSE NULL END as v2, + CASE WHEN c1='a' OR c1='b' THEN TRUE WHEN c1='c' OR c1='d' THEN FALSE ELSE NULL END as v3 +FROM aggregate_test_100; + +# Prepare settings to skip partial aggregation from the beginning +statement ok +set datafusion.execution.skip_partial_aggregation_probe_rows_threshold = 0; + +statement ok +set datafusion.execution.skip_partial_aggregation_probe_ratio_threshold = 0.0; + +statement ok +set datafusion.execution.target_partitions = 2; + +statement ok +set datafusion.execution.batch_size = 1; + +statement ok +set datafusion.sql_parser.dialect = 'Postgres'; + +# Grouping by unique fields allows to check all accumulators +query ITIIII +SELECT c5, c1, + COUNT(), + COUNT(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END), + COUNT() FILTER (WHERE c1 = 'b'), + COUNT(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END) FILTER (WHERE c1 = 'b') +FROM aggregate_test_100 +GROUP BY 1, 2 ORDER BY 1 LIMIT 5; +---- +-2141999138 c 1 0 0 0 +-2141451704 a 1 1 0 0 +-2138770630 b 1 0 1 0 +-2117946883 d 1 0 0 0 +-2098805236 c 1 0 0 0 + +query ITIIII +SELECT c5, c1, + MIN(c5), + MIN(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END), + MIN(c5) FILTER (WHERE c1 = 'b'), + MIN(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END) FILTER (WHERE c1 = 'b') +FROM aggregate_test_100 +GROUP BY 1, 2 ORDER BY 1 LIMIT 5; +---- +-2141999138 c -2141999138 NULL NULL NULL +-2141451704 a -2141451704 -2141451704 NULL NULL +-2138770630 b -2138770630 NULL -2138770630 NULL +-2117946883 d -2117946883 NULL NULL NULL +-2098805236 c -2098805236 NULL NULL NULL + +query ITIIII +SELECT c5, c1, + MAX(c5), + MAX(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END), + MAX(c5) FILTER (WHERE c1 = 'b'), + MAX(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END) FILTER (WHERE c1 = 'b') +FROM aggregate_test_100 +GROUP BY 1, 2 ORDER BY 1 LIMIT 5; +---- +-2141999138 c -2141999138 NULL NULL NULL +-2141451704 a -2141451704 -2141451704 NULL NULL +-2138770630 b -2138770630 NULL -2138770630 NULL +-2117946883 d -2117946883 NULL NULL NULL +-2098805236 c -2098805236 NULL NULL NULL + +query ITIIII +SELECT c5, c1, + SUM(c5), + SUM(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END), + SUM(c5) FILTER (WHERE c1 = 'b'), + SUM(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END) FILTER (WHERE c1 = 'b') +FROM aggregate_test_100 +GROUP BY 1, 2 ORDER BY 1 LIMIT 5; +---- +-2141999138 c -2141999138 NULL NULL NULL +-2141451704 a -2141451704 -2141451704 NULL NULL +-2138770630 b -2138770630 NULL -2138770630 NULL +-2117946883 d -2117946883 NULL NULL NULL +-2098805236 c -2098805236 NULL NULL NULL + +query ITIIII +SELECT c5, c1, + MEDIAN(c5), + MEDIAN(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END), + MEDIAN(c5) FILTER (WHERE c1 = 'b'), + MEDIAN(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END) FILTER (WHERE c1 = 'b') +FROM aggregate_test_100 +GROUP BY 1, 2 ORDER BY 1 LIMIT 5; +---- +-2141999138 c -2141999138 NULL NULL NULL +-2141451704 a -2141451704 -2141451704 NULL NULL +-2138770630 b -2138770630 NULL -2138770630 NULL +-2117946883 d -2117946883 NULL NULL NULL +-2098805236 c -2098805236 NULL NULL NULL + +query ITIIII +SELECT c5, c1, + APPROX_MEDIAN(c5), + APPROX_MEDIAN(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END), + APPROX_MEDIAN(c5) FILTER (WHERE c1 = 'b'), + APPROX_MEDIAN(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END) FILTER (WHERE c1 = 'b') +FROM aggregate_test_100 +GROUP BY 1, 2 ORDER BY 1 LIMIT 5; +---- +-2141999138 c -2141999138 NULL NULL NULL +-2141451704 a -2141451704 -2141451704 NULL NULL +-2138770630 b -2138770630 NULL -2138770630 NULL +-2117946883 d -2117946883 NULL NULL NULL +-2098805236 c -2098805236 NULL NULL NULL + +query ITIIII +SELECT c5, c1, + APPROX_DISTINCT(c5), + APPROX_DISTINCT(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END), + APPROX_DISTINCT(c5) FILTER (WHERE c1 = 'b'), + APPROX_DISTINCT(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END) FILTER (WHERE c1 = 'b') +FROM aggregate_test_100 +GROUP BY 1, 2 ORDER BY 1 LIMIT 5; +---- +-2141999138 c 1 0 0 0 +-2141451704 a 1 1 0 0 +-2138770630 b 1 0 1 0 +-2117946883 d 1 0 0 0 +-2098805236 c 1 0 0 0 + +# FIXME: add bool_and(v3) column when issue fixed +# ISSUE https://github.com/apache/datafusion/issues/11846 +query TBBB rowsort +select v1, bool_or(v2), bool_and(v2), bool_or(v3) +from aggregate_test_100_bool +group by v1 +---- +a true false true +b true false true +c true false false +d true false false +e true false NULL + +query TBBB rowsort +select v1, + bool_or(v2) FILTER (WHERE v1 = 'a' OR v1 = 'c' OR v1 = 'e'), + bool_or(v2) FILTER (WHERE v2 = false), + bool_or(v2) FILTER (WHERE v2 = NULL) +from aggregate_test_100_bool +group by v1 +---- +a true false NULL +b NULL false NULL +c true false NULL +d NULL false NULL +e true false NULL + +# Prepare settings to always skip aggregation after couple of batches +statement ok +set datafusion.execution.skip_partial_aggregation_probe_rows_threshold = 10; + +statement ok +set datafusion.execution.skip_partial_aggregation_probe_ratio_threshold = 0.0; + +statement ok +set datafusion.execution.target_partitions = 2; + +statement ok +set datafusion.execution.batch_size = 4; + +# Inserting into nullable table with batch_size specified above +# to prevent creation on single in-memory batch +statement ok +CREATE TABLE aggregate_test_100_null ( + c2 TINYINT NOT NULL, + c5 INT NOT NULL, + c3 SMALLINT, + c11 FLOAT +); + +statement ok +INSERT INTO aggregate_test_100_null +SELECT + c2, + c5, + CASE WHEN c1 = 'e' THEN NULL ELSE c3 END as c3, + CASE WHEN c1 = 'a' THEN NULL ELSE c11 END as c11 +FROM aggregate_test_100; + +# Test count varchar / int / float +query IIII +SELECT c2, count(c1), count(c5), count(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 22 22 22 +2 22 22 22 +3 19 19 19 +4 23 23 23 +5 14 14 14 + +# Test min / max for int / float +query IIIRR +SELECT c2, min(c5), max(c5), min(c11), max(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 -1991133944 2143473091 0.064453244 0.89651865 +2 -2138770630 2053379412 0.055064857 0.8315913 +3 -2141999138 2030965207 0.034291923 0.9488028 +4 -1885422396 2064155045 0.028003037 0.7459874 +5 -2117946883 2025611582 0.12559289 0.87989986 + +# Test sum for int / float +query IIR +SELECT c2, sum(c5), sum(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 -438598674 12.153253793716 +2 -8259865364 9.577824473381 +3 1956035476 9.590891361237 +4 16155718643 9.531112968922 +5 6449337880 7.074412226677 + +# Test median for int / float +query IIR +SELECT c2, median(c5), median(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 23971150 0.5922606 +2 -562486880 0.43422085 +3 240273900 0.40199697 +4 762932956 0.48515016 +5 604973998 0.49842384 + +# Test approx_median for int / float +query IIR +SELECT c2, approx_median(c5), approx_median(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 191655437 0.59926736 +2 -587831330 0.43230486 +3 240273900 0.40199697 +4 762932956 0.48515016 +5 593204320 0.5156586 + +# Test approx_distinct for varchar / int +query III +SELECT c2, approx_distinct(c1), approx_distinct(c5) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 5 22 +2 5 22 +3 5 19 +4 5 23 +5 5 14 + +# Test count with nullable fields +query III +SELECT c2, count(c3), count(c11) FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; +---- +1 19 17 +2 17 19 +3 15 13 +4 16 19 +5 12 11 + +# Test min / max with nullable fields +query IIIRR +SELECT c2, min(c3), max(c3), min(c11), max(c11) FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; +---- +1 -99 125 0.064453244 0.89651865 +2 -117 122 0.09683716 0.8315913 +3 -101 123 0.034291923 0.94669616 +4 -117 123 0.028003037 0.7085086 +5 -101 118 0.12559289 0.87989986 + +# Test sum with nullable fields +query IIR +SELECT c2, sum(c3), sum(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 367 12.153253793716 +2 184 9.577824473381 +3 395 9.590891361237 +4 29 9.531112968922 +5 -194 7.074412226677 + +# Test median with nullable fields +query IIR +SELECT c2, median(c3), median(c11) FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; +---- +1 12 0.6067944 +2 1 0.46076488 +3 14 0.40154034 +4 -17 0.48515016 +5 -35 0.5536642 + +# Test approx_median with nullable fields +query IIR +SELECT c2, approx_median(c3), approx_median(c11) FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; +---- +1 12 0.6067944 +2 1 0.46076488 +3 14 0.40154034 +4 -7 0.48515016 +5 -39 0.5536642 + +# Test approx_distinct with nullable fields +query II +SELECT c2, approx_distinct(c3) FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; +---- +1 19 +2 16 +3 13 +4 16 +5 12 + +# Test avg for tinyint / float +query TRR +SELECT + c1, + avg(c2), + avg(c11) +FROM aggregate_test_100 GROUP BY c1 ORDER BY c1; +---- +a 2.857142857143 0.438223421574 +b 3.263157894737 0.496481208425 +c 2.666666666667 0.425241138254 +d 2.444444444444 0.541519476308 +e 3 0.505440263521 + +# FIXME: add bool_and(v3) column when issue fixed +# ISSUE https://github.com/apache/datafusion/issues/11846 +query TBBB rowsort +select v1, bool_or(v2), bool_and(v2), bool_or(v3) +from aggregate_test_100_bool +group by v1 +---- +a true false true +b true false true +c true false false +d true false false +e true false NULL + +query TBBB rowsort +select v1, + bool_or(v2) FILTER (WHERE v1 = 'a' OR v1 = 'c' OR v1 = 'e'), + bool_or(v2) FILTER (WHERE v2 = false), + bool_or(v2) FILTER (WHERE v2 = NULL) +from aggregate_test_100_bool +group by v1 +---- +a true false NULL +b NULL false NULL +c true false NULL +d NULL false NULL +e true false NULL + +# Enabling PG dialect for filtered aggregates tests +statement ok +set datafusion.sql_parser.dialect = 'Postgres'; + +# Test count with filter +query III +SELECT + c2, + count(c3) FILTER (WHERE c3 > 0), + count(c3) FILTER (WHERE c11 > 10) +FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 13 0 +2 13 0 +3 13 0 +4 13 0 +5 5 0 + +# Test min / max with filter +query III +SELECT + c2, + min(c3) FILTER (WHERE c3 > 0), + max(c3) FILTER (WHERE c3 < 0) +FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 12 -5 +2 1 -29 +3 13 -2 +4 3 -38 +5 36 -5 + +# Test sum with filter +query II +SELECT + c2, + sum(c3) FILTER (WHERE c1 != 'e' AND c3 > 0) +FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 612 +2 565 +3 466 +4 417 +5 284 + +# Test approx_distinct with filter +query III +SELECT + c2, + approx_distinct(c3) FILTER (WHERE c3 > 0), + approx_distinct(c3) FILTER (WHERE c11 > 10) +FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 13 0 +2 12 0 +3 11 0 +4 13 0 +5 5 0 + +# Test median with filter +query III +SELECT + c2, + median(c3) FILTER (WHERE c3 > 0), + median(c3) FILTER (WHERE c3 < 0) +FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 57 -56 +2 52 -60 +3 71 -74 +4 65 -69 +5 64 -59 + +# Test approx_median with filter +query III +SELECT + c2, + approx_median(c3) FILTER (WHERE c3 > 0), + approx_median(c3) FILTER (WHERE c3 < 0) +FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 57 -56 +2 52 -60 +3 71 -76 +4 65 -64 +5 64 -59 + +# Test count with nullable fields and filter +query III +SELECT c2, + COUNT(c3) FILTER (WHERE c5 > 0), + COUNT(c11) FILTER(WHERE c5 > 0) +FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; +---- +1 11 6 +2 6 6 +3 8 6 +4 11 14 +5 8 7 + +# Test avg for tinyint / float +query TRR +SELECT + c1, + avg(c2) FILTER (WHERE c2 != 5), + avg(c11) FILTER (WHERE c2 != 5) +FROM aggregate_test_100 GROUP BY c1 ORDER BY c1; +---- +a 2.5 0.449071887467 +b 2.642857142857 0.445486298629 +c 2.421052631579 0.422882117723 +d 2.125 0.518706191331 +e 2.789473684211 0.536785323369 + +# Test count with nullable fields and nullable filter +query III +SELECT c2, + COUNT(c3) FILTER (WHERE c11 > 0.5), + COUNT(c11) FILTER(WHERE c3 > 0) +FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; +---- +1 10 9 +2 7 8 +3 3 6 +4 3 7 +5 6 3 + +# Test min / max with nullable fields and filter +query IIIRR +SELECT c2, + MIN(c3) FILTER (WHERE c5 > 0), + MAX(c3) FILTER (WHERE c5 > 0), + MIN(c11) FILTER (WHERE c5 < 0), + MAX(c11) FILTER (WHERE c5 < 0) +FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; +---- +1 -99 103 0.2578469 0.89651865 +2 -48 93 0.09683716 0.8315913 +3 -76 123 0.034291923 0.94669616 +4 -117 123 0.06563997 0.57360977 +5 -94 68 0.12559289 0.75173044 + +# Test min / max with nullable fields and nullable filter +query III +SELECT c2, + MIN(c3) FILTER (WHERE c11 > 0.5), + MAX(c3) FILTER (WHERE c11 > 0.5) +FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; +---- +1 -99 125 +2 -106 122 +3 -76 73 +4 -117 47 +5 -82 118 + +# Test sum with nullable field and nullable / non-nullable filters +query IIIRR +SELECT c2, + SUM(c3) FILTER (WHERE c5 > 0), + SUM(c3) FILTER (WHERE c11 < 0.5), + SUM(c11) FILTER (WHERE c5 < 0), + SUM(c11) FILTER (WHERE c3 > 0) +FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; +---- +1 -3 77 7.214695632458 5.085060358047 +2 100 77 6.197732746601 3.150197088718 +3 109 211 2.80575042963 2.80632930994 +4 -171 56 2.10740506649 1.939846396446 +5 -86 -76 1.8741710186 1.600569307804 + +# Test approx_distinct with nullable fields and filter +query II +SELECT c2, + approx_distinct(c3) FILTER (WHERE c5 > 0) +FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; +---- +1 11 +2 6 +3 6 +4 11 +5 8 + +# Test approx_distinct with nullable fields and nullable filter +query II +SELECT c2, + approx_distinct(c3) FILTER (WHERE c11 > 0.5) +FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; +---- +1 10 +2 6 +3 3 +4 3 +5 6 + +# Test median with nullable fields and filter +query IIR +SELECT c2, + median(c3) FILTER (WHERE c5 > 0), + median(c11) FILTER (WHERE c5 < 0) +FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; +---- +1 -5 0.6623719 +2 15 0.52930677 +3 13 0.32792538 +4 -38 0.49774808 +5 -18 0.49842384 + +# Test min / max with nullable fields and nullable filter +query II +SELECT c2, + median(c3) FILTER (WHERE c11 > 0.5) +FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; +---- +1 33 +2 -29 +3 22 +4 -90 +5 -22 + +# Test approx_median with nullable fields and filter +query IIR +SELECT c2, + approx_median(c3) FILTER (WHERE c5 > 0), + approx_median(c11) FILTER (WHERE c5 < 0) +FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; +---- +1 -5 0.6623719 +2 12 0.52930677 +3 13 0.32792538 +4 -38 0.49774808 +5 -21 0.47652745 + +# Test approx_median with nullable fields and nullable filter +query II +SELECT c2, + approx_median(c3) FILTER (WHERE c11 > 0.5) +FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; +---- +1 35 +2 -29 +3 22 +4 -90 +5 -32 + +statement ok +DROP TABLE aggregate_test_100_null; + +# Test for aggregate functions with different intermediate types +# Need more than 10 values to trigger skipping +statement ok +CREATE TABLE decimal_table(i int, d decimal(10,3)) as +VALUES (1, 1.1), (2, 2.2), (3, 3.3), (2, 4.4), (1, 5.5); + +statement ok +CREATE TABLE t(id int) as values (1), (2), (3), (4), (5), (6), (7), (8), (9), (10); + +query IR +SELECT i, sum(d) +FROM decimal_table CROSS JOIN t +GROUP BY i +ORDER BY i; +---- +1 66 +2 66 +3 33 + +statement ok +DROP TABLE decimal_table; + +# Extra tests for 'bool_*()' edge cases +statement ok +set datafusion.execution.skip_partial_aggregation_probe_rows_threshold = 0; + +statement ok +set datafusion.execution.skip_partial_aggregation_probe_ratio_threshold = 0.0; + +statement ok +set datafusion.execution.target_partitions = 1; + +statement ok +set datafusion.execution.batch_size = 1; + +statement ok +create table bool_aggregate_functions ( + c1 boolean not null, + c2 boolean not null, + c3 boolean not null, + c4 boolean not null, + c5 boolean, + c6 boolean, + c7 boolean, + c8 boolean +) +as values + (true, true, false, false, true, true, null, null), + (true, false, true, false, false, null, false, null), + (true, true, false, false, null, true, false, null); + +query BBBBBBBB +SELECT bool_and(c1), bool_and(c2), bool_and(c3), bool_and(c4), bool_and(c5), bool_and(c6), bool_and(c7), bool_and(c8) FROM bool_aggregate_functions +---- +true false false false false true false NULL + +statement ok +set datafusion.execution.skip_partial_aggregation_probe_rows_threshold = 2; + +query BBBBBBBB +SELECT bool_and(c1), bool_and(c2), bool_and(c3), bool_and(c4), bool_and(c5), bool_and(c6), bool_and(c7), bool_and(c8) FROM bool_aggregate_functions +---- +true false false false false true false NULL + +statement ok +DROP TABLE aggregate_test_100_bool diff --git a/datafusion/sqllogictest/test_files/aggregates_topk.slt b/datafusion/sqllogictest/test_files/aggregates_topk.slt index ab6d9af7bb81..a67fec695f6c 100644 --- a/datafusion/sqllogictest/test_files/aggregates_topk.slt +++ b/datafusion/sqllogictest/test_files/aggregates_topk.slt @@ -40,21 +40,24 @@ query TT explain select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4; ---- logical_plan -01)Limit: skip=0, fetch=4 -02)--Sort: MAX(traces.timestamp) DESC NULLS FIRST, fetch=4 -03)----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MAX(traces.timestamp)]] -04)------TableScan: traces projection=[trace_id, timestamp] +01)Sort: max(traces.timestamp) DESC NULLS FIRST, fetch=4 +02)--Aggregate: groupBy=[[traces.trace_id]], aggr=[[max(traces.timestamp)]] +03)----TableScan: traces projection=[trace_id, timestamp] physical_plan -01)GlobalLimitExec: skip=0, fetch=4 -02)--SortPreservingMergeExec: [MAX(traces.timestamp)@1 DESC], fetch=4 -03)----SortExec: TopK(fetch=4), expr=[MAX(traces.timestamp)@1 DESC] -04)------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] -05)--------CoalesceBatchesExec: target_batch_size=8192 -06)----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 -07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -08)--------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] -09)----------------MemoryExec: partitions=1, partition_sizes=[1] +01)SortPreservingMergeExec: [max(traces.timestamp)@1 DESC], fetch=4 +02)--SortExec: TopK(fetch=4), expr=[max(traces.timestamp)@1 DESC], preserve_partitioning=[true] +03)----AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[max(traces.timestamp)] +04)------CoalesceBatchesExec: target_batch_size=8192 +05)--------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[max(traces.timestamp)] +08)--------------MemoryExec: partitions=1, partition_sizes=[1] +query TI +select * from (select trace_id, MAX(timestamp) max_ts from traces t group by trace_id) where trace_id != 'b' order by max_ts desc limit 3; +---- +c 4 +a 1 query TI select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4; @@ -91,84 +94,82 @@ c 1 2 statement ok set datafusion.optimizer.enable_topk_aggregation = true; +query TI +select * from (select trace_id, MAX(timestamp) max_ts from traces t group by trace_id) where max_ts != 3 order by max_ts desc limit 2; +---- +c 4 +a 1 + query TT explain select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4; ---- logical_plan -01)Limit: skip=0, fetch=4 -02)--Sort: MAX(traces.timestamp) DESC NULLS FIRST, fetch=4 -03)----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MAX(traces.timestamp)]] -04)------TableScan: traces projection=[trace_id, timestamp] +01)Sort: max(traces.timestamp) DESC NULLS FIRST, fetch=4 +02)--Aggregate: groupBy=[[traces.trace_id]], aggr=[[max(traces.timestamp)]] +03)----TableScan: traces projection=[trace_id, timestamp] physical_plan -01)GlobalLimitExec: skip=0, fetch=4 -02)--SortPreservingMergeExec: [MAX(traces.timestamp)@1 DESC], fetch=4 -03)----SortExec: TopK(fetch=4), expr=[MAX(traces.timestamp)@1 DESC] -04)------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)], lim=[4] -05)--------CoalesceBatchesExec: target_batch_size=8192 -06)----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 -07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -08)--------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)], lim=[4] -09)----------------MemoryExec: partitions=1, partition_sizes=[1] +01)SortPreservingMergeExec: [max(traces.timestamp)@1 DESC], fetch=4 +02)--SortExec: TopK(fetch=4), expr=[max(traces.timestamp)@1 DESC], preserve_partitioning=[true] +03)----AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[max(traces.timestamp)], lim=[4] +04)------CoalesceBatchesExec: target_batch_size=8192 +05)--------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[max(traces.timestamp)], lim=[4] +08)--------------MemoryExec: partitions=1, partition_sizes=[1] query TT explain select trace_id, MIN(timestamp) from traces group by trace_id order by MIN(timestamp) desc limit 4; ---- logical_plan -01)Limit: skip=0, fetch=4 -02)--Sort: MIN(traces.timestamp) DESC NULLS FIRST, fetch=4 -03)----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MIN(traces.timestamp)]] -04)------TableScan: traces projection=[trace_id, timestamp] +01)Sort: min(traces.timestamp) DESC NULLS FIRST, fetch=4 +02)--Aggregate: groupBy=[[traces.trace_id]], aggr=[[min(traces.timestamp)]] +03)----TableScan: traces projection=[trace_id, timestamp] physical_plan -01)GlobalLimitExec: skip=0, fetch=4 -02)--SortPreservingMergeExec: [MIN(traces.timestamp)@1 DESC], fetch=4 -03)----SortExec: TopK(fetch=4), expr=[MIN(traces.timestamp)@1 DESC] -04)------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MIN(traces.timestamp)] -05)--------CoalesceBatchesExec: target_batch_size=8192 -06)----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 -07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -08)--------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MIN(traces.timestamp)] -09)----------------MemoryExec: partitions=1, partition_sizes=[1] +01)SortPreservingMergeExec: [min(traces.timestamp)@1 DESC], fetch=4 +02)--SortExec: TopK(fetch=4), expr=[min(traces.timestamp)@1 DESC], preserve_partitioning=[true] +03)----AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[min(traces.timestamp)] +04)------CoalesceBatchesExec: target_batch_size=8192 +05)--------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[min(traces.timestamp)] +08)--------------MemoryExec: partitions=1, partition_sizes=[1] query TT explain select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) asc limit 4; ---- logical_plan -01)Limit: skip=0, fetch=4 -02)--Sort: MAX(traces.timestamp) ASC NULLS LAST, fetch=4 -03)----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MAX(traces.timestamp)]] -04)------TableScan: traces projection=[trace_id, timestamp] +01)Sort: max(traces.timestamp) ASC NULLS LAST, fetch=4 +02)--Aggregate: groupBy=[[traces.trace_id]], aggr=[[max(traces.timestamp)]] +03)----TableScan: traces projection=[trace_id, timestamp] physical_plan -01)GlobalLimitExec: skip=0, fetch=4 -02)--SortPreservingMergeExec: [MAX(traces.timestamp)@1 ASC NULLS LAST], fetch=4 -03)----SortExec: TopK(fetch=4), expr=[MAX(traces.timestamp)@1 ASC NULLS LAST] -04)------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] -05)--------CoalesceBatchesExec: target_batch_size=8192 -06)----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 -07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -08)--------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] -09)----------------MemoryExec: partitions=1, partition_sizes=[1] +01)SortPreservingMergeExec: [max(traces.timestamp)@1 ASC NULLS LAST], fetch=4 +02)--SortExec: TopK(fetch=4), expr=[max(traces.timestamp)@1 ASC NULLS LAST], preserve_partitioning=[true] +03)----AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[max(traces.timestamp)] +04)------CoalesceBatchesExec: target_batch_size=8192 +05)--------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[max(traces.timestamp)] +08)--------------MemoryExec: partitions=1, partition_sizes=[1] query TT explain select trace_id, MAX(timestamp) from traces group by trace_id order by trace_id asc limit 4; ---- logical_plan -01)Limit: skip=0, fetch=4 -02)--Sort: traces.trace_id ASC NULLS LAST, fetch=4 -03)----Aggregate: groupBy=[[traces.trace_id]], aggr=[[MAX(traces.timestamp)]] -04)------TableScan: traces projection=[trace_id, timestamp] +01)Sort: traces.trace_id ASC NULLS LAST, fetch=4 +02)--Aggregate: groupBy=[[traces.trace_id]], aggr=[[max(traces.timestamp)]] +03)----TableScan: traces projection=[trace_id, timestamp] physical_plan -01)GlobalLimitExec: skip=0, fetch=4 -02)--SortPreservingMergeExec: [trace_id@0 ASC NULLS LAST], fetch=4 -03)----SortExec: TopK(fetch=4), expr=[trace_id@0 ASC NULLS LAST] -04)------AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] -05)--------CoalesceBatchesExec: target_batch_size=8192 -06)----------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 -07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -08)--------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[MAX(traces.timestamp)] -09)----------------MemoryExec: partitions=1, partition_sizes=[1] +01)SortPreservingMergeExec: [trace_id@0 ASC NULLS LAST], fetch=4 +02)--SortExec: TopK(fetch=4), expr=[trace_id@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[max(traces.timestamp)] +04)------CoalesceBatchesExec: target_batch_size=8192 +05)--------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[max(traces.timestamp)] +08)--------------MemoryExec: partitions=1, partition_sizes=[1] query TI -select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4; +select trace_id, max(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 4; ---- c 4 b 3 @@ -176,7 +177,7 @@ a 1 NULL 0 query TI -select trace_id, MIN(timestamp) from traces group by trace_id order by MIN(timestamp) asc limit 4; +select trace_id, min(timestamp) from traces group by trace_id order by MIN(timestamp) asc limit 4; ---- b -2 a -1 @@ -184,21 +185,21 @@ NULL 0 c 2 query TI -select trace_id, MAX(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 3; +select trace_id, max(timestamp) from traces group by trace_id order by MAX(timestamp) desc limit 3; ---- c 4 b 3 a 1 query TI -select trace_id, MIN(timestamp) from traces group by trace_id order by MIN(timestamp) asc limit 3; +select trace_id, min(timestamp) from traces group by trace_id order by MIN(timestamp) asc limit 3; ---- b -2 a -1 NULL 0 query TII -select trace_id, other, MIN(timestamp) from traces group by trace_id, other order by MIN(timestamp) asc limit 4; +select trace_id, other, min(timestamp) from traces group by trace_id, other order by MIN(timestamp) asc limit 4; ---- b 0 -2 a -1 -1 @@ -206,7 +207,7 @@ NULL 0 0 a 1 1 query TII -select trace_id, MIN(other), MIN(timestamp) from traces group by trace_id order by MIN(timestamp), MIN(other) limit 4; +select trace_id, min(other), MIN(timestamp) from traces group by trace_id order by MIN(timestamp), MIN(other) limit 4; ---- b 0 -2 a -1 -1 diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index c3c5603dafc6..1e60699a1f65 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -346,8 +346,8 @@ AS VALUES (arrow_cast(make_array([[1,2]], [[3, 4]]), 'FixedSizeList(2, List(List(Int64)))'), arrow_cast(make_array([1], [2]), 'FixedSizeList(2, List(Int64))')), (arrow_cast(make_array([[1,2]], [[4, 4]]), 'FixedSizeList(2, List(List(Int64)))'), arrow_cast(make_array([1,2], [3, 4]), 'FixedSizeList(2, List(Int64))')), (arrow_cast(make_array([[1,2]], [[4, 4]]), 'FixedSizeList(2, List(List(Int64)))'), arrow_cast(make_array([1,2,3], [1]), 'FixedSizeList(2, List(Int64))')), - (arrow_cast(make_array([[1], [2]], []), 'FixedSizeList(2, List(List(Int64)))'), arrow_cast(make_array([2], [3]), 'FixedSizeList(2, List(Int64))')), - (arrow_cast(make_array([[1], [2]], []), 'FixedSizeList(2, List(List(Int64)))'), arrow_cast(make_array([1], [2]), 'FixedSizeList(2, List(Int64))')), + (arrow_cast(make_array([[1], [2]], [[]]), 'FixedSizeList(2, List(List(Int64)))'), arrow_cast(make_array([2], [3]), 'FixedSizeList(2, List(Int64))')), + (arrow_cast(make_array([[1], [2]], [[]]), 'FixedSizeList(2, List(List(Int64)))'), arrow_cast(make_array([1], [2]), 'FixedSizeList(2, List(Int64))')), (arrow_cast(make_array([[1], [2]], [[2], [3]]), 'FixedSizeList(2, List(List(Int64)))'), arrow_cast(make_array([1], [2]), 'FixedSizeList(2, List(Int64))')), (arrow_cast(make_array([[1], [2]], [[2], [3]]), 'FixedSizeList(2, List(List(Int64)))'), arrow_cast(make_array([1], [2]), 'FixedSizeList(2, List(Int64))')) ; @@ -629,6 +629,38 @@ AS VALUES (arrow_cast(make_array([28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30], [31, 32, 33], [34, 35, 36], [28, 29, 30]), 'FixedSizeList(10, List(Int64))'), [28, 29, 30], [28, 29, 30], 10) ; +statement ok +CREATE TABLE arrays_distance_table +AS VALUES + (make_array(1, 2, 3), make_array(1, 2, 3), make_array(1.1, 2.2, 3.3) , make_array(1.1, NULL, 3.3)), + (make_array(1, 2, 3), make_array(4, 5, 6), make_array(4.4, 5.5, 6.6), make_array(4.4, NULL, 6.6)), + (make_array(1, 2, 3), make_array(7, 8, 9), make_array(7.7, 8.8, 9.9), make_array(7.7, NULL, 9.9)), + (make_array(1, 2, 3), make_array(10, 11, 12), make_array(10.1, 11.2, 12.3), make_array(10.1, NULL, 12.3)) +; + +statement ok +CREATE TABLE large_arrays_distance_table +AS + SELECT + arrow_cast(column1, 'LargeList(Int64)') AS column1, + arrow_cast(column2, 'LargeList(Int64)') AS column2, + arrow_cast(column3, 'LargeList(Float64)') AS column3, + arrow_cast(column4, 'LargeList(Float64)') AS column4 +FROM arrays_distance_table +; + +statement ok +CREATE TABLE fixed_size_arrays_distance_table +AS + SELECT + arrow_cast(column1, 'FixedSizeList(3, Int64)') AS column1, + arrow_cast(column2, 'FixedSizeList(3, Int64)') AS column2, + arrow_cast(column3, 'FixedSizeList(3, Float64)') AS column3, + arrow_cast(column4, 'FixedSizeList(3, Float64)') AS column4 +FROM arrays_distance_table +; + + # Array literal ## boolean coercion is not supported @@ -1136,8 +1168,12 @@ from arrays_values_without_nulls; ## array_element (aliases: array_extract, list_extract, list_element) +# Testing with empty arguments should result in an error +query error DataFusion error: Error during planning: Error during planning: array_element does not support zero arguments +select array_element(); + # array_element error -query error DataFusion error: Error during planning: No function matches the given name and argument types 'array_element\(Int64, Int64\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tarray_element\(array, index\) +query error select array_element(1, 2); # array_element with null @@ -1963,6 +1999,149 @@ select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), co [1, 2, 3, 4, 5] [43, 44, 45, 46] [41, 42, 43, 44, 45] [5] [, 54, 55, 56, 57, 58, 59, 60] [55] +# Test issue: https://github.com/apache/datafusion/issues/10425 +# `from` may be larger than `to` and `stride` is positive +query ???? +select array_slice(a, -1, 2, 1), array_slice(a, -1, 2), + array_slice(a, 3, 2, 1), array_slice(a, 3, 2) + from (values ([1.0, 2.0, 3.0, 3.0]), ([4.0, 5.0, 3.0]), ([6.0])) t(a); +---- +[] [] [] [] +[] [] [] [] +[6.0] [6.0] [] [] + +# Testing with empty arguments should result in an error +query error DataFusion error: Error during planning: Error during planning: array_slice does not support zero arguments +select array_slice(); + +## array_any_value (aliases: list_any_value) + +# Testing with empty arguments should result in an error +query error +select array_any_value(); + +# Testing with non-array arguments should result in an error +query error +select array_any_value(1), array_any_value('a'), array_any_value(NULL); + +# array_any_value scalar function #1 (with null and non-null elements) + +query ITII +select array_any_value(make_array(NULL, 1, 2, 3, 4, 5)), array_any_value(make_array(NULL, 'h', 'e', 'l', 'l', 'o')), array_any_value(make_array(NULL, NULL)), array_any_value(make_array(NULL, NULL, 1, 2, 3)); +---- +1 h NULL 1 + +query ITITI +select array_any_value(arrow_cast(make_array(NULL, 1, 2, 3, 4, 5), 'LargeList(Int64)')), array_any_value(arrow_cast(make_array(NULL, 'h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)')), array_any_value(arrow_cast(make_array(NULL, NULL), 'LargeList(Int64)')), array_any_value(arrow_cast(make_array(NULL, NULL), 'LargeList(Utf8)')), array_any_value(arrow_cast(make_array(NULL, NULL, 1, 2, 3), 'LargeList(Int64)'));; +---- +1 h NULL NULL 1 + +query ITITI +select array_any_value(arrow_cast(make_array(NULL, 1, 2, 3, 4, 5), 'FixedSizeList(6, Int64)')), array_any_value(arrow_cast(make_array(NULL, 'h', 'e', 'l', 'l', 'o'), 'FixedSizeList(6, Utf8)')), array_any_value(arrow_cast(make_array(NULL, NULL), 'FixedSizeList(2, Int64)')), array_any_value(arrow_cast(make_array(NULL, NULL), 'FixedSizeList(2, Utf8)')), array_any_value(arrow_cast(make_array(NULL, NULL, 1, 2, 3, 4), 'FixedSizeList(6, Int64)')); +---- +1 h NULL NULL 1 + +# array_any_value scalar function #2 (with nested array) + +query ? +select array_any_value(make_array(NULL, make_array(NULL, 1, 2, 3, 4, 5), make_array(NULL, 6, 7, 8, 9, 10))); +---- +[, 1, 2, 3, 4, 5] + +query ? +select array_any_value(arrow_cast(make_array(NULL, make_array(NULL, 1, 2, 3, 4, 5), make_array(NULL, 6, 7, 8, 9, 10)), 'LargeList(List(Int64))')); +---- +[, 1, 2, 3, 4, 5] + +query ? +select array_any_value(arrow_cast(make_array(NULL, make_array(NULL, 1, 2, 3, 4, 5), make_array(NULL, 6, 7, 8, 9, 10)), 'FixedSizeList(3, List(Int64))')); +---- +[, 1, 2, 3, 4, 5] + +# array_any_value scalar function #3 (using function alias `list_any_value`) +query IT +select list_any_value(make_array(NULL, 1, 2, 3, 4, 5)), list_any_value(make_array(NULL, 'h', 'e', 'l', 'l', 'o')); +---- +1 h + +query IT +select list_any_value(arrow_cast(make_array(NULL, 1, 2, 3, 4, 5), 'LargeList(Int64)')), list_any_value(arrow_cast(make_array(NULL, 'h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)')); +---- +1 h + +query IT +select list_any_value(arrow_cast(make_array(NULL, 1, 2, 3, 4, 5), 'FixedSizeList(6, Int64)')), list_any_value(arrow_cast(make_array(NULL, 'h', 'e', 'l', 'l', 'o'), 'FixedSizeList(6, Utf8)')); +---- +1 h + +# array_any_value with columns + +query I +select array_any_value(column1) from slices; +---- +2 +11 +21 +31 +NULL +41 +51 + +query I +select array_any_value(arrow_cast(column1, 'LargeList(Int64)')) from slices; +---- +2 +11 +21 +31 +NULL +41 +51 + +query I +select array_any_value(column1) from fixed_slices; +---- +2 +11 +21 +31 +41 +51 + +# array_any_value with columns and scalars + +query II +select array_any_value(make_array(NULL, 1, 2, 3, 4, 5)), array_any_value(column1) from slices; +---- +1 2 +1 11 +1 21 +1 31 +1 NULL +1 41 +1 51 + +query II +select array_any_value(arrow_cast(make_array(NULL, 1, 2, 3, 4, 5), 'LargeList(Int64)')), array_any_value(arrow_cast(column1, 'LargeList(Int64)')) from slices; +---- +1 2 +1 11 +1 21 +1 31 +1 NULL +1 41 +1 51 + +query II +select array_any_value(make_array(NULL, 1, 2, 3, 4, 5)), array_any_value(column1) from fixed_slices; +---- +1 2 +1 11 +1 21 +1 31 +1 41 +1 51 + # make_array with nulls query ??????? select make_array(make_array('a','b'), null), @@ -2018,6 +2197,25 @@ NULL [, 51, 52, 54, 55, 56, 57, 58, 59, 60] [61, 62, 63, 64, 65, 66, 67, 68, 69, 70] +# test with empty array +query ? +select array_sort([]); +---- +[] + +# test with empty row, the row that does not match the condition has row count 0 +statement ok +create table t1(a int, b int) as values (100, 1), (101, 2), (102, 3), (101, 2); + +# rowsort is to ensure the order of group by is deterministic, array_sort has no effect here, since the sum() always returns single row. +query ? rowsort +select array_sort([sum(a)]) from t1 where a > 100 group by b; +---- +[102] +[202] + +statement ok +drop table t1; ## list_sort (aliases: `array_sort`) query ??? @@ -2051,10 +2249,10 @@ select query ???? select - array_append(arrow_cast(make_array(), 'LargeList(Null)'), 4), - array_append(arrow_cast(make_array(), 'LargeList(Null)'), null), + array_append(arrow_cast(make_array(), 'LargeList(Int64)'), 4), + array_append(arrow_cast(make_array(), 'LargeList(Int64)'), null), array_append(arrow_cast(make_array(1, null, 3), 'LargeList(Int64)'), 4), - array_append(arrow_cast(make_array(null, null), 'LargeList(Null)'), 1) + array_append(arrow_cast(make_array(null, null), 'LargeList(Int64)'), 1) ; ---- [4] [] [1, , 3, 4] [, , 1] @@ -2535,7 +2733,7 @@ query ???? select array_repeat(arrow_cast([1], 'LargeList(Int64)'), 5), array_repeat(arrow_cast([1.1, 2.2, 3.3], 'LargeList(Float64)'), 3), - array_repeat(arrow_cast([null, null], 'LargeList(Null)'), 3), + array_repeat(arrow_cast([null, null], 'LargeList(Int64)'), 3), array_repeat(arrow_cast([[1, 2], [3, 4]], 'LargeList(List(Int64))'), 2); ---- [[1], [1], [1], [1], [1]] [[1.1, 2.2, 3.3], [1.1, 2.2, 3.3], [1.1, 2.2, 3.3]] [[, ], [, ], [, ]] [[[1, 2], [3, 4]], [[1, 2], [3, 4]]] @@ -2598,6 +2796,12 @@ drop table large_array_repeat_table; ## array_concat (aliases: `array_cat`, `list_concat`, `list_cat`) +# test with empty array +query ? +select array_concat([]); +---- +[] + # array_concat error query error DataFusion error: Error during planning: The array_concat function can only accept list as the args\. select array_concat(1, 2); @@ -2642,19 +2846,19 @@ select array_concat(make_array(), make_array(2, 3)); query ? select array_concat(make_array(make_array(1, 2), make_array(3, 4)), make_array(make_array())); ---- -[[1, 2], [3, 4]] +[[1, 2], [3, 4], []] # array_concat scalar function #8 (with empty arrays) query ? select array_concat(make_array(make_array(1, 2), make_array(3, 4)), make_array(make_array()), make_array(make_array(), make_array()), make_array(make_array(5, 6), make_array(7, 8))); ---- -[[1, 2], [3, 4], [5, 6], [7, 8]] +[[1, 2], [3, 4], [], [], [], [5, 6], [7, 8]] # array_concat scalar function #9 (with empty arrays) query ? select array_concat(make_array(make_array()), make_array(make_array(1, 2), make_array(3, 4))); ---- -[[1, 2], [3, 4]] +[[], [1, 2], [3, 4]] # array_cat scalar function #10 (function alias `array_concat`) query ?? @@ -3724,6 +3928,54 @@ select array_to_string(make_array(), ',') ---- (empty) +# array to string dictionary +statement ok +CREATE TABLE table1 AS VALUES + (1, 'foo'), + (3, 'bar'), + (1, 'foo'), + (2, NULL), + (NULL, 'baz') + ; + +# expect 1-3-1-2 (dictionary values should be repeated) +query T +SELECT array_to_string(array_agg(column1),'-') +FROM ( + SELECT arrow_cast(column1, 'Dictionary(Int32, Int32)') as column1 + FROM table1 +); +---- +1-3-1-2 + +# expect foo,bar,foo,baz (dictionary values should be repeated) +query T +SELECT array_to_string(array_agg(column2),',') +FROM ( + SELECT arrow_cast(column2, 'Dictionary(Int64, Utf8)') as column2 + FROM table1 +); +---- +foo,bar,foo,baz + +# Expect only values that are in the group +query I?T +SELECT column1, array_agg(column2), array_to_string(array_agg(column2),',') +FROM ( + SELECT column1, arrow_cast(column2, 'Dictionary(Int32, Utf8)') as column2 + FROM table1 +) +GROUP BY column1 +ORDER BY column1; +---- +1 [foo, foo] foo,foo +2 [] (empty) +3 [bar] bar +NULL [baz] baz + +statement ok +drop table table1; + ## array_union (aliases: `list_union`) @@ -3756,7 +4008,7 @@ select array_union([1,2,3], []); [1, 2, 3] query ? -select array_union(arrow_cast([1,2,3], 'LargeList(Int64)'), arrow_cast([], 'LargeList(Null)')); +select array_union(arrow_cast([1,2,3], 'LargeList(Int64)'), arrow_cast([], 'LargeList(Int64)')); ---- [1, 2, 3] @@ -3804,7 +4056,7 @@ select array_union([], []); [] query ? -select array_union(arrow_cast([], 'LargeList(Null)'), arrow_cast([], 'LargeList(Null)')); +select array_union(arrow_cast([], 'LargeList(Int64)'), arrow_cast([], 'LargeList(Int64)')); ---- [] @@ -3815,7 +4067,7 @@ select array_union([[null]], []); [[]] query ? -select array_union(arrow_cast([[null]], 'LargeList(List(Null))'), arrow_cast([], 'LargeList(Null)')); +select array_union(arrow_cast([[null]], 'LargeList(List(Int64))'), arrow_cast([], 'LargeList(Int64)')); ---- [[]] @@ -3826,7 +4078,7 @@ select array_union([null], [null]); [] query ? -select array_union(arrow_cast([[null]], 'LargeList(List(Null))'), arrow_cast([[null]], 'LargeList(List(Null))')); +select array_union(arrow_cast([[null]], 'LargeList(List(Int64))'), arrow_cast([[null]], 'LargeList(List(Int64))')); ---- [[]] @@ -3837,7 +4089,7 @@ select array_union(null, []); [] query ? -select array_union(null, arrow_cast([], 'LargeList(Null)')); +select array_union(null, arrow_cast([], 'LargeList(Int64)')); ---- [] @@ -4074,14 +4326,14 @@ select cardinality(make_array()), cardinality(make_array(make_array())) NULL 0 query II -select cardinality(arrow_cast(make_array(), 'LargeList(Null)')), cardinality(arrow_cast(make_array(make_array()), 'LargeList(List(Null))')) +select cardinality(arrow_cast(make_array(), 'LargeList(Int64)')), cardinality(arrow_cast(make_array(make_array()), 'LargeList(List(Int64))')) ---- NULL 0 #TODO #https://github.com/apache/datafusion/issues/9158 #query II -#select cardinality(arrow_cast(make_array(), 'FixedSizeList(1, Null)')), cardinality(arrow_cast(make_array(make_array()), 'FixedSizeList(1, List(Null))')) +#select cardinality(arrow_cast(make_array(), 'FixedSizeList(1, Null)')), cardinality(arrow_cast(make_array(make_array()), 'FixedSizeList(1, List(Int64))')) #---- #NULL 0 @@ -4622,10 +4874,89 @@ NULL 10 NULL 10 NULL 10 +query RRR +select array_distance([2], [3]), list_distance([1], [2]), list_distance([1], [-2]); +---- +1 1 3 + +query error +select list_distance([1], [1, 2]); + +query R +select array_distance([[1, 1]], [1, 2]); +---- +1 + +query R +select array_distance([[1, 1]], [[1, 2]]); +---- +1 + +query R +select array_distance([[1, 1]], [[1, 2]]); +---- +1 + +query RR +select array_distance([1, 1, 0, 0], [2, 2, 1, 1]), list_distance([1, 2, 3], [1, 2, 3]); +---- +2 0 + +query RR +select array_distance([1.0, 1, 0, 0], [2, 2.0, 1, 1]), list_distance([1, 2.0, 3], [1, 2, 3]); +---- +2 0 + +query R +select list_distance([1, 1, NULL, 0], [2, 2, NULL, NULL]); +---- +NULL + +query R +select list_distance([NULL, NULL], [NULL, NULL]); +---- +NULL + +query R +select list_distance([1.0, 2.0, 3.0], [1.0, 2.0, 3.5]) AS distance; +---- +0.5 + +query R +select list_distance([1, 2, 3], [1, 2, 3]) AS distance; +---- +0 + +# array_distance with columns +query RRR +select array_distance(column1, column2), array_distance(column1, column3), array_distance(column1, column4) from arrays_distance_table; +---- +0 0.374165738677 NULL +5.196152422707 6.063827174318 NULL +10.392304845413 11.778794505381 NULL +15.58845726812 15.935494971917 NULL + +query RRR +select array_distance(column1, column2), array_distance(column1, column3), array_distance(column1, column4) from large_arrays_distance_table; +---- +0 0.374165738677 NULL +5.196152422707 6.063827174318 NULL +10.392304845413 11.778794505381 NULL +15.58845726812 15.935494971917 NULL + +query RRR +select array_distance(column1, column2), array_distance(column1, column3), array_distance(column1, column4) from fixed_size_arrays_distance_table; +---- +0 0.374165738677 NULL +5.196152422707 6.063827174318 NULL +10.392304845413 11.778794505381 NULL +15.58845726812 15.935494971917 NULL + + ## array_dims (aliases: `list_dims`) # array dims error -query error DataFusion error: Error during planning: No function matches the given name and argument types 'array_dims\(Int64\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tarray_dims\(array\) +query error select array_dims(1); # array_dims scalar function @@ -4667,7 +4998,7 @@ select array_dims(make_array()), array_dims(make_array(make_array())) NULL [1, 0] query ?? -select array_dims(arrow_cast(make_array(), 'LargeList(Null)')), array_dims(arrow_cast(make_array(make_array()), 'LargeList(List(Null))')) +select array_dims(arrow_cast(make_array(), 'LargeList(Int64)')), array_dims(arrow_cast(make_array(make_array()), 'LargeList(List(Int64))')) ---- NULL [1, 0] @@ -4829,7 +5160,7 @@ select array_ndims(make_array()), array_ndims(make_array(make_array())) 1 2 query II -select array_ndims(arrow_cast(make_array(), 'LargeList(Null)')), array_ndims(arrow_cast(make_array(make_array()), 'LargeList(List(Null))')) +select array_ndims(arrow_cast(make_array(), 'LargeList(Int64)')), array_ndims(arrow_cast(make_array(make_array()), 'LargeList(List(Int64))')) ---- 1 2 @@ -4850,7 +5181,7 @@ select list_ndims(make_array()), list_ndims(make_array(make_array())) 1 2 query II -select list_ndims(arrow_cast(make_array(), 'LargeList(Null)')), list_ndims(arrow_cast(make_array(make_array()), 'LargeList(List(Null))')) +select list_ndims(arrow_cast(make_array(), 'LargeList(Int64)')), list_ndims(arrow_cast(make_array(make_array()), 'LargeList(List(Int64))')) ---- 1 2 @@ -4879,11 +5210,19 @@ NULL 1 1 ## array_has/array_has_all/array_has_any -query BB +# If lhs is empty, return false +query B +select array_has([], 1); +---- +false + +# If rhs is Null, we returns Null +query BBB select array_has([], null), - array_has([1, 2, 3], null); + array_has([1, 2, 3], null), + array_has([null, 1], null); ---- -false false +NULL NULL NULL #TODO: array_has_all and array_has_any cannot handle NULL #query BBBB @@ -5169,8 +5508,9 @@ false false false true true false true false true false false true false true false false -false false false false -false false false false +NULL NULL false false +false false NULL false +false false false NULL query BBBB select array_has(arrow_cast(column1, 'LargeList(List(Int64))'), make_array(5, 6)), @@ -5183,8 +5523,9 @@ false false false true true false true false true false false true false true false false -false false false false -false false false false +NULL NULL false false +false false NULL false +false false false NULL query BBBB select array_has(column1, make_array(5, 6)), @@ -5197,8 +5538,9 @@ false false false true true false true false true false false true false true false false -false false false false -false false false false +NULL NULL false false +false false NULL false +false false false NULL query BBBBBBBBBBBBB select array_has_all(make_array(1,2,3), make_array(1,3)), @@ -5255,6 +5597,25 @@ true false true false false false true true false false true false true #---- #true false true false false false true true false false true false true +# any operator +query ? +select column3 from arrays where 'L'=any(column3); +---- +[L, o, r, e, m] + +query I +select count(*) from arrays where 'L'=any(column3); +---- +1 + +query I +select count(*) from arrays where 'X'=any(column3); +---- +0 + +query error DataFusion error: Error during planning: Unsupported AnyOp: '>', only '=' is supported +select count(*) from arrays where 'X'>any(column3); + ## array_distinct #TODO: https://github.com/apache/datafusion/issues/7142 @@ -5465,7 +5826,7 @@ select array_intersect([], []); [] query ? -select array_intersect(arrow_cast([], 'LargeList(Null)'), arrow_cast([], 'LargeList(Null)')); +select array_intersect(arrow_cast([], 'LargeList(Int64)'), arrow_cast([], 'LargeList(Int64)')); ---- [] @@ -5495,7 +5856,17 @@ select array_intersect([], null); [] query ? -select array_intersect(arrow_cast([], 'LargeList(Null)'), null); +select array_intersect([[1,2,3]], [[]]); +---- +[] + +query ? +select array_intersect([[null]], [[]]); +---- +[] + +query ? +select array_intersect(arrow_cast([], 'LargeList(Int64)'), null); ---- [] @@ -5505,7 +5876,7 @@ select array_intersect(null, []); NULL query ? -select array_intersect(null, arrow_cast([], 'LargeList(Null)')); +select array_intersect(null, arrow_cast([], 'LargeList(Int64)')); ---- NULL @@ -5584,9 +5955,9 @@ select ---- [] [] [0] [0] -# Test range for other egde cases +# Test range for other edge cases query ???????? -select +select range(9223372036854775807, 9223372036854775807, -1) as c1, range(9223372036854775807, 9223372036854775806, -1) as c2, range(9223372036854775807, 9223372036854775807, 1) as c3, @@ -5631,26 +6002,76 @@ select range(NULL) ---- NULL -## should throw error -query error +## should return NULL +query ? select range(DATE '1992-09-01', NULL, INTERVAL '1' YEAR); +---- +NULL -query error +## should return NULL +query ? +select range(TIMESTAMP '1992-09-01', NULL, INTERVAL '1' YEAR); +---- +NULL + +query ? select range(DATE '1992-09-01', DATE '1993-03-01', NULL); +---- +NULL -query error +query ? +select range(TIMESTAMP '1992-09-01', TIMESTAMP '1993-03-01', NULL); +---- +NULL + +query ? select range(NULL, DATE '1993-03-01', INTERVAL '1' YEAR); +---- +NULL + +query ? +select range(NULL, TIMESTAMP '1993-03-01', INTERVAL '1' YEAR); +---- +NULL + +query ? +select range(NULL, NULL, NULL); +---- +NULL + +query ? +select range(NULL::timestamp, NULL::timestamp, NULL); +---- +NULL query ? select range(DATE '1989-04-01', DATE '1993-03-01', INTERVAL '-1' YEAR) ---- [] +query ? +select range(TIMESTAMP '1989-04-01', TIMESTAMP '1993-03-01', INTERVAL '-1' YEAR) +---- +[] + query ? select range(DATE '1993-03-01', DATE '1989-04-01', INTERVAL '1' YEAR) ---- [] +query ? +select range(TIMESTAMP '1993-03-01', TIMESTAMP '1989-04-01', INTERVAL '1' YEAR) +---- +[] + +query error DataFusion error: Execution error: Cannot generate date range less than 1 day\. +select range(DATE '1993-03-01', DATE '1993-03-01', INTERVAL '1' HOUR) + +query ? +select range(TIMESTAMP '1993-03-01', TIMESTAMP '1993-03-01', INTERVAL '1' HOUR) +---- +[] + query ????????? select generate_series(5), generate_series(2, 5), @@ -5665,31 +6086,124 @@ select generate_series(5), ---- [0, 1, 2, 3, 4, 5] [2, 3, 4, 5] [2, 5, 8] [1, 2, 3, 4, 5] [5, 4, 3, 2, 1] [10, 7, 4] [1992-09-01, 1992-10-01, 1992-11-01, 1992-12-01, 1993-01-01, 1993-02-01, 1993-03-01] [1993-02-01, 1993-01-31, 1993-01-30, 1993-01-29, 1993-01-28, 1993-01-27, 1993-01-26, 1993-01-25, 1993-01-24, 1993-01-23, 1993-01-22, 1993-01-21, 1993-01-20, 1993-01-19, 1993-01-18, 1993-01-17, 1993-01-16, 1993-01-15, 1993-01-14, 1993-01-13, 1993-01-12, 1993-01-11, 1993-01-10, 1993-01-09, 1993-01-08, 1993-01-07, 1993-01-06, 1993-01-05, 1993-01-04, 1993-01-03, 1993-01-02, 1993-01-01] [1989-04-01, 1990-04-01, 1991-04-01, 1992-04-01] -## should throw error -query error +query ? +select generate_series('2021-01-01'::timestamp, '2021-01-01T15:00:00'::timestamp, INTERVAL '1' HOUR); +---- +[2021-01-01T00:00:00, 2021-01-01T01:00:00, 2021-01-01T02:00:00, 2021-01-01T03:00:00, 2021-01-01T04:00:00, 2021-01-01T05:00:00, 2021-01-01T06:00:00, 2021-01-01T07:00:00, 2021-01-01T08:00:00, 2021-01-01T09:00:00, 2021-01-01T10:00:00, 2021-01-01T11:00:00, 2021-01-01T12:00:00, 2021-01-01T13:00:00, 2021-01-01T14:00:00, 2021-01-01T15:00:00] + +query ? +select generate_series('2021-01-01T00:00:00EST'::timestamp, '2021-01-01T15:00:00-12:00'::timestamp, INTERVAL '1' HOUR); +---- +[2021-01-01T05:00:00, 2021-01-01T06:00:00, 2021-01-01T07:00:00, 2021-01-01T08:00:00, 2021-01-01T09:00:00, 2021-01-01T10:00:00, 2021-01-01T11:00:00, 2021-01-01T12:00:00, 2021-01-01T13:00:00, 2021-01-01T14:00:00, 2021-01-01T15:00:00, 2021-01-01T16:00:00, 2021-01-01T17:00:00, 2021-01-01T18:00:00, 2021-01-01T19:00:00, 2021-01-01T20:00:00, 2021-01-01T21:00:00, 2021-01-01T22:00:00, 2021-01-01T23:00:00, 2021-01-02T00:00:00, 2021-01-02T01:00:00, 2021-01-02T02:00:00, 2021-01-02T03:00:00] + +query ? +select generate_series(arrow_cast('2021-01-01T00:00:00', 'Timestamp(Nanosecond, Some("-05:00"))'), arrow_cast('2021-01-01T15:00:00', 'Timestamp(Nanosecond, Some("+05:00"))'), INTERVAL '1' HOUR); +---- +[2021-01-01T00:00:00-05:00, 2021-01-01T01:00:00-05:00, 2021-01-01T02:00:00-05:00, 2021-01-01T03:00:00-05:00, 2021-01-01T04:00:00-05:00, 2021-01-01T05:00:00-05:00] + +## -5500000000 ns is -5.5 sec +query ? +select generate_series(arrow_cast('2021-01-01T00:00:00', 'Timestamp(Nanosecond, Some("-05:00"))'), arrow_cast('2021-01-01T06:00:00', 'Timestamp(Nanosecond, Some("-05:00"))'), INTERVAL '1 HOUR 30 MINUTE -5500000000 NANOSECOND'); +---- +[2021-01-01T00:00:00-05:00, 2021-01-01T01:29:54.500-05:00, 2021-01-01T02:59:49-05:00, 2021-01-01T04:29:43.500-05:00, 2021-01-01T05:59:38-05:00] + +## mixing types for timestamps is not supported +query error DataFusion error: Internal error: Unexpected argument type for GENERATE_SERIES : Date32 +select generate_series(arrow_cast('2021-01-01T00:00:00', 'Timestamp(Nanosecond, Some("-05:00"))'), DATE '2021-01-02', INTERVAL '1' HOUR); + + +## should return NULL +query ? select generate_series(DATE '1992-09-01', NULL, INTERVAL '1' YEAR); +---- +NULL -query error +## should return NULL +query ? +select generate_series(TIMESTAMP '1992-09-01', NULL, INTERVAL '1' YEAR); +---- +NULL + +query ? select generate_series(DATE '1992-09-01', DATE '1993-03-01', NULL); +---- +NULL -query error +query ? +select generate_series(TIMESTAMP '1992-09-01', DATE '1993-03-01', NULL); +---- +NULL + +query ? select generate_series(NULL, DATE '1993-03-01', INTERVAL '1' YEAR); +---- +NULL + +query ? +select generate_series(NULL, TIMESTAMP '1993-03-01', INTERVAL '1' YEAR); +---- +NULL + +query ? +select generate_series(NULL, NULL, NULL); +---- +NULL +query ? +select generate_series(NULL::timestamp, NULL::timestamp, NULL); +---- +NULL query ? select generate_series(DATE '1989-04-01', DATE '1993-03-01', INTERVAL '-1' YEAR) ---- [] +query ? +select generate_series(TIMESTAMP '1989-04-01', TIMESTAMP '1993-03-01', INTERVAL '-1' YEAR) +---- +[] + query ? select generate_series(DATE '1993-03-01', DATE '1989-04-01', INTERVAL '1' YEAR) ---- [] +query ? +select generate_series(TIMESTAMP '1993-03-01', TIMESTAMP '1989-04-01', INTERVAL '1' YEAR) +---- +[] + +query error DataFusion error: Execution error: Cannot generate date range less than 1 day. +select generate_series(DATE '2000-01-01', DATE '2000-01-03', INTERVAL '1' HOUR) + +query error DataFusion error: Execution error: Cannot generate date range less than 1 day. +select generate_series(DATE '2000-01-01', DATE '2000-01-03', INTERVAL '-1' HOUR) + +query ? +select generate_series(TIMESTAMP '2000-01-01', TIMESTAMP '2000-01-02', INTERVAL '1' HOUR) +---- +[2000-01-01T00:00:00, 2000-01-01T01:00:00, 2000-01-01T02:00:00, 2000-01-01T03:00:00, 2000-01-01T04:00:00, 2000-01-01T05:00:00, 2000-01-01T06:00:00, 2000-01-01T07:00:00, 2000-01-01T08:00:00, 2000-01-01T09:00:00, 2000-01-01T10:00:00, 2000-01-01T11:00:00, 2000-01-01T12:00:00, 2000-01-01T13:00:00, 2000-01-01T14:00:00, 2000-01-01T15:00:00, 2000-01-01T16:00:00, 2000-01-01T17:00:00, 2000-01-01T18:00:00, 2000-01-01T19:00:00, 2000-01-01T20:00:00, 2000-01-01T21:00:00, 2000-01-01T22:00:00, 2000-01-01T23:00:00, 2000-01-02T00:00:00] + +query ? +select generate_series(TIMESTAMP '2000-01-02', TIMESTAMP '2000-01-01', INTERVAL '-1' HOUR) +---- +[2000-01-02T00:00:00, 2000-01-01T23:00:00, 2000-01-01T22:00:00, 2000-01-01T21:00:00, 2000-01-01T20:00:00, 2000-01-01T19:00:00, 2000-01-01T18:00:00, 2000-01-01T17:00:00, 2000-01-01T16:00:00, 2000-01-01T15:00:00, 2000-01-01T14:00:00, 2000-01-01T13:00:00, 2000-01-01T12:00:00, 2000-01-01T11:00:00, 2000-01-01T10:00:00, 2000-01-01T09:00:00, 2000-01-01T08:00:00, 2000-01-01T07:00:00, 2000-01-01T06:00:00, 2000-01-01T05:00:00, 2000-01-01T04:00:00, 2000-01-01T03:00:00, 2000-01-01T02:00:00, 2000-01-01T01:00:00, 2000-01-01T00:00:00] + +# Test generate_series with small intervals +query ? +select generate_series('2000-01-01T00:00:00.000000001Z'::timestamp, '2000-01-01T00:00:00.00000001Z'::timestamp, INTERVAL '1' NANOSECONDS) +---- +[2000-01-01T00:00:00.000000001, 2000-01-01T00:00:00.000000002, 2000-01-01T00:00:00.000000003, 2000-01-01T00:00:00.000000004, 2000-01-01T00:00:00.000000005, 2000-01-01T00:00:00.000000006, 2000-01-01T00:00:00.000000007, 2000-01-01T00:00:00.000000008, 2000-01-01T00:00:00.000000009, 2000-01-01T00:00:00.000000010] + # Test generate_series with zero step query error DataFusion error: Execution error: step can't be 0 for function generate_series\(start \[, stop, step\]\) select generate_series(1, 1, 0); +# Test generate_series with zero step +query error DataFusion error: Execution error: Interval argument to GENERATE_SERIES must not be 0 +select generate_series(TIMESTAMP '2000-01-02', TIMESTAMP '2000-01-01', INTERVAL '0' MINUTE); + # Test generate_series with big steps query ???? select @@ -5701,9 +6215,9 @@ select [-9223372036854775808] [9223372036854775807] [0, -9223372036854775808] [0, 9223372036854775807] -# Test generate_series for other egde cases +# Test generate_series for other edge cases query ???? -select +select generate_series(9223372036854775807, 9223372036854775807, -1) as c1, generate_series(9223372036854775807, 9223372036854775807, 1) as c2, generate_series(-9223372036854775808, -9223372036854775808, -1) as c3, @@ -5744,6 +6258,77 @@ select generate_series(NULL) ---- NULL +# Test generate_series with a table of date values +statement ok +CREATE TABLE date_table( + start DATE, + stop DATE, + step INTERVAL +) AS VALUES + (DATE '1992-01-01', DATE '1993-01-02', INTERVAL '1' MONTH), + (DATE '1993-02-01', DATE '1993-01-01', INTERVAL '-1' DAY), + (DATE '1989-04-01', DATE '1993-03-01', INTERVAL '1' YEAR); + +query ? +select generate_series(start, stop, step) from date_table; +---- +[1992-01-01, 1992-02-01, 1992-03-01, 1992-04-01, 1992-05-01, 1992-06-01, 1992-07-01, 1992-08-01, 1992-09-01, 1992-10-01, 1992-11-01, 1992-12-01, 1993-01-01] +[1993-02-01, 1993-01-31, 1993-01-30, 1993-01-29, 1993-01-28, 1993-01-27, 1993-01-26, 1993-01-25, 1993-01-24, 1993-01-23, 1993-01-22, 1993-01-21, 1993-01-20, 1993-01-19, 1993-01-18, 1993-01-17, 1993-01-16, 1993-01-15, 1993-01-14, 1993-01-13, 1993-01-12, 1993-01-11, 1993-01-10, 1993-01-09, 1993-01-08, 1993-01-07, 1993-01-06, 1993-01-05, 1993-01-04, 1993-01-03, 1993-01-02, 1993-01-01] +[1989-04-01, 1990-04-01, 1991-04-01, 1992-04-01] + +query ? +select generate_series(start, stop, INTERVAL '1 year') from date_table; +---- +[1992-01-01, 1993-01-01] +[] +[1989-04-01, 1990-04-01, 1991-04-01, 1992-04-01] + +query ? +select generate_series(start, '1993-03-01'::date, INTERVAL '1 year') from date_table; +---- +[1992-01-01, 1993-01-01] +[1993-02-01] +[1989-04-01, 1990-04-01, 1991-04-01, 1992-04-01] + +# Test generate_series with a table of timestamp values +statement ok +CREATE TABLE timestamp_table( + start TIMESTAMP, + stop TIMESTAMP, + step INTERVAL +) AS VALUES + (TIMESTAMP '1992-01-01T00:00:00', TIMESTAMP '1993-01-02T00:00:00', INTERVAL '1' MONTH), + (TIMESTAMP '1993-02-01T00:00:00', TIMESTAMP '1993-01-01T00:00:00', INTERVAL '-1' DAY), + (TIMESTAMP '1989-04-01T00:00:00', TIMESTAMP '1993-03-01T00:00:00', INTERVAL '1' YEAR); + +query ? +select generate_series(start, stop, step) from timestamp_table; +---- +[1992-01-01T00:00:00, 1992-02-01T00:00:00, 1992-03-01T00:00:00, 1992-04-01T00:00:00, 1992-05-01T00:00:00, 1992-06-01T00:00:00, 1992-07-01T00:00:00, 1992-08-01T00:00:00, 1992-09-01T00:00:00, 1992-10-01T00:00:00, 1992-11-01T00:00:00, 1992-12-01T00:00:00, 1993-01-01T00:00:00] +[1993-02-01T00:00:00, 1993-01-31T00:00:00, 1993-01-30T00:00:00, 1993-01-29T00:00:00, 1993-01-28T00:00:00, 1993-01-27T00:00:00, 1993-01-26T00:00:00, 1993-01-25T00:00:00, 1993-01-24T00:00:00, 1993-01-23T00:00:00, 1993-01-22T00:00:00, 1993-01-21T00:00:00, 1993-01-20T00:00:00, 1993-01-19T00:00:00, 1993-01-18T00:00:00, 1993-01-17T00:00:00, 1993-01-16T00:00:00, 1993-01-15T00:00:00, 1993-01-14T00:00:00, 1993-01-13T00:00:00, 1993-01-12T00:00:00, 1993-01-11T00:00:00, 1993-01-10T00:00:00, 1993-01-09T00:00:00, 1993-01-08T00:00:00, 1993-01-07T00:00:00, 1993-01-06T00:00:00, 1993-01-05T00:00:00, 1993-01-04T00:00:00, 1993-01-03T00:00:00, 1993-01-02T00:00:00, 1993-01-01T00:00:00] +[1989-04-01T00:00:00, 1990-04-01T00:00:00, 1991-04-01T00:00:00, 1992-04-01T00:00:00] + +query ? +select generate_series(start, stop, INTERVAL '1 year') from timestamp_table; +---- +[1992-01-01T00:00:00, 1993-01-01T00:00:00] +[] +[1989-04-01T00:00:00, 1990-04-01T00:00:00, 1991-04-01T00:00:00, 1992-04-01T00:00:00] + +query ? +select generate_series(start, '1993-03-01T00:00:00'::timestamp, INTERVAL '1 year') from timestamp_table; +---- +[1992-01-01T00:00:00, 1993-01-01T00:00:00] +[1993-02-01T00:00:00] +[1989-04-01T00:00:00, 1990-04-01T00:00:00, 1991-04-01T00:00:00, 1992-04-01T00:00:00] + +# https://github.com/apache/datafusion/issues/11922 +query ? +select generate_series(start, '1993-03-01T00:00:00'::timestamp, INTERVAL '1 year') from timestamp_table; +---- +[1992-01-01T00:00:00, 1993-01-01T00:00:00] +[1993-02-01T00:00:00] +[1989-04-01T00:00:00, 1990-04-01T00:00:00, 1991-04-01T00:00:00, 1992-04-01T00:00:00] ## array_except @@ -5949,6 +6534,17 @@ select make_array(1,2,3) @> make_array(1,3), ---- true false true false false false true +# Make sure it is rewritten to function array_has_all() +query TT +explain select [1,2,3] @> [1,3]; +---- +logical_plan +01)Projection: Boolean(true) AS array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3))) +02)--EmptyRelation +physical_plan +01)ProjectionExec: expr=[true as array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3)))] +02)--PlaceholderRowExec + # array containment operator with scalars #2 (arrow at) query BBBBBBB select make_array(1,3) <@ make_array(1,2,3), @@ -5961,6 +6557,17 @@ select make_array(1,3) <@ make_array(1,2,3), ---- true false true false false false true +# Make sure it is rewritten to function array_has_all() +query TT +explain select [1,3] <@ [1,2,3]; +---- +logical_plan +01)Projection: Boolean(true) AS array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3))) +02)--EmptyRelation +physical_plan +01)ProjectionExec: expr=[true as array_has_all(make_array(Int64(1),Int64(2),Int64(3)),make_array(Int64(1),Int64(3)))] +02)--PlaceholderRowExec + ### Array casting tests @@ -5988,7 +6595,7 @@ select make_array(1, 2.0, null, 3) query ? select make_array(1.0, '2', null) ---- -[1.0, 2, ] +[1.0, 2.0, ] ### FixedSizeListArray @@ -6140,7 +6747,7 @@ select empty(make_array()); true query B -select empty(arrow_cast(make_array(), 'LargeList(Null)')); +select empty(arrow_cast(make_array(), 'LargeList(Int64)')); ---- true @@ -6157,12 +6764,12 @@ select empty(make_array(NULL)); false query B -select empty(arrow_cast(make_array(NULL), 'LargeList(Null)')); +select empty(arrow_cast(make_array(NULL), 'LargeList(Int64)')); ---- false query B -select empty(arrow_cast(make_array(NULL), 'FixedSizeList(1, Null)')); +select empty(arrow_cast(make_array(NULL), 'FixedSizeList(1, Int64)')); ---- false @@ -6226,7 +6833,7 @@ select array_empty(make_array()); true query B -select array_empty(arrow_cast(make_array(), 'LargeList(Null)')); +select array_empty(arrow_cast(make_array(), 'LargeList(Int64)')); ---- true @@ -6237,7 +6844,7 @@ select array_empty(make_array(NULL)); false query B -select array_empty(arrow_cast(make_array(NULL), 'LargeList(Null)')); +select array_empty(arrow_cast(make_array(NULL), 'LargeList(Int64)')); ---- false @@ -6260,7 +6867,7 @@ select list_empty(make_array()); true query B -select list_empty(arrow_cast(make_array(), 'LargeList(Null)')); +select list_empty(arrow_cast(make_array(), 'LargeList(Int64)')); ---- true @@ -6271,7 +6878,7 @@ select list_empty(make_array(NULL)); false query B -select list_empty(arrow_cast(make_array(NULL), 'LargeList(Null)')); +select list_empty(arrow_cast(make_array(NULL), 'LargeList(Int64)')); ---- false @@ -6378,7 +6985,7 @@ select array_resize(column1, column2, column3) from arrays_values; [11, 12, 13, 14, 15, 16, 17, 18, , 20, 2, 2] [21, 22, 23, , 25, 26, 27, 28, 29, 30, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3] [31, 32, 33, 34, 35, , 37, 38, 39, 40, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4] -[5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5] +NULL [] [51, 52, , 54, 55, 56, 57, 58, 59, 60, , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , ] [61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7] @@ -6390,7 +6997,7 @@ select array_resize(arrow_cast(column1, 'LargeList(Int64)'), column2, column3) f [11, 12, 13, 14, 15, 16, 17, 18, , 20, 2, 2] [21, 22, 23, , 25, 26, 27, 28, 29, 30, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3] [31, 32, 33, 34, 35, , 37, 38, 39, 40, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4] -[5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5] +NULL [] [51, 52, , 54, 55, 56, 57, 58, 59, 60, , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , , ] [61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7] @@ -6406,6 +7013,51 @@ select array_resize(arrow_cast([[1], [2], [3]], 'LargeList(List(Int64))'), 10, [ ---- [[1], [2], [3], [5], [5], [5], [5], [5], [5], [5]] +# array_resize null value +query ? +select array_resize(arrow_cast(NULL, 'List(Int8)'), 1); +---- +NULL + +statement ok +CREATE TABLE array_resize_values +AS VALUES + (make_array(1, NULL, 3, 4, 5, 6, 7, 8, 9, 10), 2, 1), + (make_array(11, 12, NULL, 14, 15, 16, 17, 18, 19, 20), 5, 2), + (make_array(21, 22, 23, 24, NULL, 26, 27, 28, 29, 30), 8, 3), + (make_array(31, 32, 33, 34, 35, 36, NULL, 38, 39, 40), 12, 4), + (NULL, 3, 0), + (make_array(41, 42, 43, 44, 45, 46, 47, 48, 49, 50), NULL, 6), + (make_array(51, 52, 53, 54, 55, NULL, 57, 58, 59, 60), 13, NULL), + (make_array(61, 62, 63, 64, 65, 66, 67, 68, 69, 70), 15, 7) +; + +# array_resize columnar test #1 +query ? +select array_resize(column1, column2, column3) from array_resize_values; +---- +[1, ] +[11, 12, , 14, 15] +[21, 22, 23, 24, , 26, 27, 28] +[31, 32, 33, 34, 35, 36, , 38, 39, 40, 4, 4] +NULL +[] +[51, 52, 53, 54, 55, , 57, 58, 59, 60, , , ] +[61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 7, 7, 7, 7, 7] + +# array_resize columnar test #2 +query ? +select array_resize(arrow_cast(column1, 'LargeList(Int64)'), column2, column3) from array_resize_values; +---- +[1, ] +[11, 12, , 14, 15] +[21, 22, 23, 24, , 26, 27, 28] +[31, 32, 33, 34, 35, 36, , 38, 39, 40, 4, 4] +NULL +[] +[51, 52, 53, 54, 55, , 57, 58, 59, 60, , , ] +[61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 7, 7, 7, 7, 7] + ## array_reverse query ?? select array_reverse(make_array(1, 2, 3)), array_reverse(make_array(1)); @@ -6451,7 +7103,7 @@ create table test_create_array_table( d int ); -query ???I +query I insert into test_create_array_table values ([1, 2, 3], ['a', 'b', 'c'], [[4,6], [6,7,8]], 1); ---- @@ -6479,6 +7131,29 @@ select [1,2,3]::int[], [['1']]::int[][], arrow_typeof([]::text[]); ---- [1, 2, 3] [[1]] List(Field { name: "item", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) +# test empty arrays return length +# issue: https://github.com/apache/datafusion/pull/12459 +statement ok +create table values_all_empty (a int[]) as values ([]), ([]); + +query B +select array_has(a, 1) from values_all_empty; +---- +false +false + +# Test create table with fixed sized array +statement ok +create table fixed_size_col_table (a int[3]) as values ([1,2,3]), ([4,5,6]); + +query T +select arrow_typeof(a) from fixed_size_col_table; +---- +FixedSizeList(Field { name: "item", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 3) +FixedSizeList(Field { name: "item", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 3) + +statement error +create table varying_fixed_size_col_table (a int[3]) as values ([1,2,3]), ([4,5]); ### Delete tables @@ -6652,3 +7327,9 @@ drop table fixed_size_arrays_values_without_nulls; statement ok drop table test_create_array_table; + +statement ok +drop table values_all_empty; + +statement ok +drop table fixed_size_col_table; diff --git a/datafusion/sqllogictest/test_files/array_query.slt b/datafusion/sqllogictest/test_files/array_query.slt index 24c99fc849b6..8fde295e6051 100644 --- a/datafusion/sqllogictest/test_files/array_query.slt +++ b/datafusion/sqllogictest/test_files/array_query.slt @@ -41,17 +41,68 @@ SELECT * FROM data; # Filtering ########### -query error DataFusion error: Arrow error: Invalid argument error: Invalid comparison operation: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) == List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +query ??I rowsort SELECT * FROM data WHERE column1 = [1,2,3]; +---- +[1, 2, 3] NULL 1 +[1, 2, 3] [4, 5] 1 -query error DataFusion error: Arrow error: Invalid argument error: Invalid comparison operation: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) == List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) -SELECT * FROM data WHERE column1 = column2 - -query error DataFusion error: Arrow error: Invalid argument error: Invalid comparison operation: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) != List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +query ??I SELECT * FROM data WHERE column1 != [1,2,3]; +---- +[2, 3] [2, 3] 1 -query error DataFusion error: Arrow error: Invalid argument error: Invalid comparison operation: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) != List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +query ??I SELECT * FROM data WHERE column1 != column2 +---- +[1, 2, 3] [4, 5] 1 + +query ??I rowsort +SELECT * FROM data WHERE column1 < [1,2,3,4]; +---- +[1, 2, 3] NULL 1 +[1, 2, 3] [4, 5] 1 + +query ??I rowsort +SELECT * FROM data WHERE column1 <= [2, 3]; +---- +[1, 2, 3] NULL 1 +[1, 2, 3] [4, 5] 1 +[2, 3] [2, 3] 1 + +query ??I rowsort +SELECT * FROM data WHERE column1 > [1,2]; +---- +[1, 2, 3] NULL 1 +[1, 2, 3] [4, 5] 1 +[2, 3] [2, 3] 1 + +query ??I rowsort +SELECT * FROM data WHERE column1 >= [1, 2, 3]; +---- +[1, 2, 3] NULL 1 +[1, 2, 3] [4, 5] 1 +[2, 3] [2, 3] 1 + +# test with scalar null +query ??I +SELECT * FROM data WHERE column2 = null; +---- + +query ??I +SELECT * FROM data WHERE null = column2; +---- + +query ??I rowsort +SELECT * FROM data WHERE column2 is distinct from null; +---- +[1, 2, 3] [4, 5] 1 +[2, 3] [2, 3] 1 + +query ??I +SELECT * FROM data WHERE column2 is not distinct from null; +---- +[1, 2, 3] NULL 1 ########### # Aggregates @@ -158,3 +209,68 @@ SELECT * FROM data ORDER BY column1, column3, column2; statement ok drop table data + + +# test filter column with all nulls +statement ok +create table data (a int) as values (null), (null), (null); + +query I +select * from data where a = null; +---- + +query I +select * from data where a is not distinct from null; +---- +NULL +NULL +NULL + +statement ok +drop table data; + +statement ok +create table data (a int[][], b int) as values ([[1,2,3]], 1), ([[2,3], [4,5]], 2), (null, 3); + +query ?I +select * from data; +---- +[[1, 2, 3]] 1 +[[2, 3], [4, 5]] 2 +NULL 3 + +query ?I +select * from data where a = [[1,2,3]]; +---- +[[1, 2, 3]] 1 + +query ?I +select * from data where a > [[1,2,3]]; +---- +[[2, 3], [4, 5]] 2 + +query ?I +select * from data where a > [[1,2]]; +---- +[[1, 2, 3]] 1 +[[2, 3], [4, 5]] 2 + +query ?I +select * from data where a < [[2, 3]]; +---- +[[1, 2, 3]] 1 + +# compare with null with eq results in null +query ?I +select * from data where a = null; +---- + +query ?I +select * from data where a != null; +---- + +# compare with null with distinct results in true/false +query ?I +select * from data where a is not distinct from null; +---- +NULL 3 diff --git a/datafusion/sqllogictest/test_files/arrow_files.slt b/datafusion/sqllogictest/test_files/arrow_files.slt index 8cf3550fdb25..e73acc384cb3 100644 --- a/datafusion/sqllogictest/test_files/arrow_files.slt +++ b/datafusion/sqllogictest/test_files/arrow_files.slt @@ -43,6 +43,11 @@ SELECT * FROM arrow_simple 3 baz false 4 NULL true +# Ensure that local files can not be read by default (a potential security issue) +# (url table is only supported when DynamicFileCatalog is enabled) +statement error DataFusion error: Error during planning: table 'datafusion.public.../core/tests/data/example.arrow' not found +SELECT * FROM '../core/tests/data/example.arrow'; + # ARROW partitioned table statement ok CREATE EXTERNAL TABLE arrow_partitioned ( @@ -113,3 +118,8 @@ EXPLAIN SELECT f0 FROM arrow_partitioned WHERE part = 456 ---- logical_plan TableScan: arrow_partitioned projection=[f0], full_filters=[arrow_partitioned.part = Int32(456)] physical_plan ArrowExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/partitioned_table_arrow/part=456/data.arrow]]}, projection=[f0] + + +# Errors in partition filters should be reported +query error Divide by zero error +SELECT f0 FROM arrow_partitioned WHERE CASE WHEN true THEN 1 / 0 ELSE part END = 1; diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt b/datafusion/sqllogictest/test_files/arrow_typeof.slt index 3e8694f3b2c2..77b10b41ccb3 100644 --- a/datafusion/sqllogictest/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -92,10 +92,9 @@ SELECT arrow_cast('1', 'Int16') 1 # Basic error test -query error DataFusion error: Error during planning: No function matches the given name and argument types 'arrow_cast\(Utf8\)'. You might need to add explicit type casts. +query error SELECT arrow_cast('1') - query error DataFusion error: Error during planning: arrow_cast requires its second argument to be a constant string, got Literal\(Int64\(43\)\) SELECT arrow_cast('1', 43) @@ -103,7 +102,7 @@ query error Error unrecognized word: unknown SELECT arrow_cast('1', 'unknown') # Round Trip tests: -query TTTTTTTTTTTTTTTTTTTTTTT +query TTTTTTTTTTTTTTTTTTTTTTTTT SELECT arrow_typeof(arrow_cast(1, 'Int8')) as col_i8, arrow_typeof(arrow_cast(1, 'Int16')) as col_i16, @@ -113,12 +112,12 @@ SELECT arrow_typeof(arrow_cast(1, 'UInt16')) as col_u16, arrow_typeof(arrow_cast(1, 'UInt32')) as col_u32, arrow_typeof(arrow_cast(1, 'UInt64')) as col_u64, - -- can't seem to cast to Float16 for some reason - -- arrow_typeof(arrow_cast(1, 'Float16')) as col_f16, + arrow_typeof(arrow_cast(1, 'Float16')) as col_f16, arrow_typeof(arrow_cast(1, 'Float32')) as col_f32, arrow_typeof(arrow_cast(1, 'Float64')) as col_f64, arrow_typeof(arrow_cast('foo', 'Utf8')) as col_utf8, arrow_typeof(arrow_cast('foo', 'LargeUtf8')) as col_large_utf8, + arrow_typeof(arrow_cast('foo', 'Utf8View')) as col_utf8_view, arrow_typeof(arrow_cast('foo', 'Binary')) as col_binary, arrow_typeof(arrow_cast('foo', 'LargeBinary')) as col_large_binary, arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Second, None)')) as col_ts_s, @@ -131,7 +130,7 @@ SELECT arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Nanosecond, Some("+08:00"))')) as col_tstz_ns, arrow_typeof(arrow_cast('foo', 'Dictionary(Int32, Utf8)')) as col_dict ---- -Int8 Int16 Int32 Int64 UInt8 UInt16 UInt32 UInt64 Float32 Float64 Utf8 LargeUtf8 Binary LargeBinary Timestamp(Second, None) Timestamp(Millisecond, None) Timestamp(Microsecond, None) Timestamp(Nanosecond, None) Timestamp(Second, Some("+08:00")) Timestamp(Millisecond, Some("+08:00")) Timestamp(Microsecond, Some("+08:00")) Timestamp(Nanosecond, Some("+08:00")) Dictionary(Int32, Utf8) +Int8 Int16 Int32 Int64 UInt8 UInt16 UInt32 UInt64 Float16 Float32 Float64 Utf8 LargeUtf8 Utf8View Binary LargeBinary Timestamp(Second, None) Timestamp(Millisecond, None) Timestamp(Microsecond, None) Timestamp(Nanosecond, None) Timestamp(Second, Some("+08:00")) Timestamp(Millisecond, Some("+08:00")) Timestamp(Microsecond, Some("+08:00")) Timestamp(Nanosecond, Some("+08:00")) Dictionary(Int32, Utf8) @@ -148,15 +147,14 @@ create table foo as select arrow_cast(1, 'UInt16') as col_u16, arrow_cast(1, 'UInt32') as col_u32, arrow_cast(1, 'UInt64') as col_u64, - -- can't seem to cast to Float16 for some reason - -- arrow_cast(1.0, 'Float16') as col_f16, + arrow_cast(1.0, 'Float16') as col_f16, arrow_cast(1.0, 'Float32') as col_f32, arrow_cast(1.0, 'Float64') as col_f64 ; ## Ensure each column in the table has the expected type -query TTTTTTTTTT +query TTTTTTTTTTT SELECT arrow_typeof(col_i8), arrow_typeof(col_i16), @@ -166,12 +164,12 @@ SELECT arrow_typeof(col_u16), arrow_typeof(col_u32), arrow_typeof(col_u64), - -- arrow_typeof(col_f16), + arrow_typeof(col_f16), arrow_typeof(col_f32), arrow_typeof(col_f64) FROM foo; ---- -Int8 Int16 Int32 Int64 UInt8 UInt16 UInt32 UInt64 Float32 Float64 +Int8 Int16 Int32 Int64 UInt8 UInt16 UInt32 UInt64 Float16 Float32 Float64 statement ok @@ -214,21 +212,23 @@ statement ok create table foo as select arrow_cast('foo', 'Utf8') as col_utf8, arrow_cast('foo', 'LargeUtf8') as col_large_utf8, + arrow_cast('foo', 'Utf8View') as col_utf8_view, arrow_cast('foo', 'Binary') as col_binary, arrow_cast('foo', 'LargeBinary') as col_large_binary ; ## Ensure each column in the table has the expected type -query TTTT +query TTTTT SELECT arrow_typeof(col_utf8), arrow_typeof(col_large_utf8), + arrow_typeof(col_utf8_view), arrow_typeof(col_binary), arrow_typeof(col_large_binary) FROM foo; ---- -Utf8 LargeUtf8 Binary LargeBinary +Utf8 LargeUtf8 Utf8View Binary LargeBinary statement ok @@ -290,22 +290,22 @@ query ? --- select arrow_cast(interval '30 minutes', 'Interval(MonthDayNano)'); ---- -0 years 0 mons 0 days 0 hours 30 mins 0.000000000 secs +30 mins query ? select arrow_cast('30 minutes', 'Interval(DayTime)'); ---- -0 years 0 mons 0 days 0 hours 30 mins 0.000 secs +30 mins query ? select arrow_cast('1 year 5 months', 'Interval(YearMonth)'); ---- -1 years 5 mons 0 days 0 hours 0 mins 0.00 secs +1 years 5 mons query ? select arrow_cast('30 minutes', 'Interval(MonthDayNano)'); ---- -0 years 0 mons 0 days 0 hours 30 mins 0.000000000 secs +30 mins ## Duration @@ -337,7 +337,7 @@ select arrow_cast(timestamp '2000-01-01T00:00:00Z', 'Timestamp(Nanosecond, Some( ---- 2000-01-01T00:00:00+08:00 -statement error DataFusion error: Arrow error: Parser error: Invalid timezone "\+25:00": '\+25:00' is not a valid timezone +statement error DataFusion error: Arrow error: Parser error: Invalid timezone "\+25:00": failed to parse timezone select arrow_cast(timestamp '2000-01-01T00:00:00', 'Timestamp(Nanosecond, Some( "+25:00" ))'); @@ -423,3 +423,16 @@ query ? select arrow_cast([1, 2, 3], 'FixedSizeList(3, Int64)'); ---- [1, 2, 3] + +# Tests for Utf8View +query TT +select arrow_cast('MyAwesomeString', 'Utf8View'), arrow_typeof(arrow_cast('MyAwesomeString', 'Utf8View')) +---- +MyAwesomeString Utf8View + +# Fails until we update to use the arrow-cast release with support for casting utf8 types to BinaryView +# refer to merge commit https://github.com/apache/arrow-rs/commit/4bd737dab2aa17aca200259347909d48ed793ba1 +query ?T +select arrow_cast('MyAwesomeString', 'BinaryView'), arrow_typeof(arrow_cast('MyAwesomeString', 'BinaryView')) +---- +4d79417765736f6d65537472696e67 BinaryView diff --git a/datafusion/sqllogictest/test_files/avro.slt b/datafusion/sqllogictest/test_files/avro.slt index 7b9fbe556fee..8282331f995e 100644 --- a/datafusion/sqllogictest/test_files/avro.slt +++ b/datafusion/sqllogictest/test_files/avro.slt @@ -31,8 +31,7 @@ CREATE EXTERNAL TABLE alltypes_plain ( timestamp_col TIMESTAMP NOT NULL, ) STORED AS AVRO -WITH HEADER ROW -LOCATION '../../testing/data/avro/alltypes_plain.avro' +LOCATION '../../testing/data/avro/alltypes_plain.avro'; statement ok CREATE EXTERNAL TABLE alltypes_plain_snappy ( @@ -49,8 +48,7 @@ CREATE EXTERNAL TABLE alltypes_plain_snappy ( timestamp_col TIMESTAMP NOT NULL, ) STORED AS AVRO -WITH HEADER ROW -LOCATION '../../testing/data/avro/alltypes_plain.snappy.avro' +LOCATION '../../testing/data/avro/alltypes_plain.snappy.avro'; statement ok CREATE EXTERNAL TABLE alltypes_plain_bzip2 ( @@ -67,8 +65,7 @@ CREATE EXTERNAL TABLE alltypes_plain_bzip2 ( timestamp_col TIMESTAMP NOT NULL, ) STORED AS AVRO -WITH HEADER ROW -LOCATION '../../testing/data/avro/alltypes_plain.bzip2.avro' +LOCATION '../../testing/data/avro/alltypes_plain.bzip2.avro'; statement ok CREATE EXTERNAL TABLE alltypes_plain_xz ( @@ -85,8 +82,7 @@ CREATE EXTERNAL TABLE alltypes_plain_xz ( timestamp_col TIMESTAMP NOT NULL, ) STORED AS AVRO -WITH HEADER ROW -LOCATION '../../testing/data/avro/alltypes_plain.xz.avro' +LOCATION '../../testing/data/avro/alltypes_plain.xz.avro'; statement ok CREATE EXTERNAL TABLE alltypes_plain_zstandard ( @@ -103,34 +99,29 @@ CREATE EXTERNAL TABLE alltypes_plain_zstandard ( timestamp_col TIMESTAMP NOT NULL, ) STORED AS AVRO -WITH HEADER ROW -LOCATION '../../testing/data/avro/alltypes_plain.zstandard.avro' +LOCATION '../../testing/data/avro/alltypes_plain.zstandard.avro'; statement ok CREATE EXTERNAL TABLE single_nan ( mycol FLOAT ) STORED AS AVRO -WITH HEADER ROW -LOCATION '../../testing/data/avro/single_nan.avro' +LOCATION '../../testing/data/avro/single_nan.avro'; statement ok CREATE EXTERNAL TABLE nested_records STORED AS AVRO -WITH HEADER ROW -LOCATION '../../testing/data/avro/nested_records.avro' +LOCATION '../../testing/data/avro/nested_records.avro'; statement ok CREATE EXTERNAL TABLE simple_enum STORED AS AVRO -WITH HEADER ROW -LOCATION '../../testing/data/avro/simple_enum.avro' +LOCATION '../../testing/data/avro/simple_enum.avro'; statement ok CREATE EXTERNAL TABLE simple_fixed STORED AS AVRO -WITH HEADER ROW -LOCATION '../../testing/data/avro/simple_fixed.avro' +LOCATION '../../testing/data/avro/simple_fixed.avro'; # test avro query query IT @@ -207,22 +198,22 @@ NULL query IT SELECT id, CAST(string_col AS varchar) FROM alltypes_plain_multi_files ---- -4 0 -5 1 -6 0 -7 1 -2 0 -3 1 -0 0 -1 1 -4 0 -5 1 -6 0 -7 1 -2 0 -3 1 -0 0 -1 1 +4 0 +5 1 +6 0 +7 1 +2 0 +3 1 +0 0 +1 1 +4 0 +5 1 +6 0 +7 1 +2 0 +3 1 +0 0 +1 1 # test avro nested records query ???? @@ -252,11 +243,11 @@ query TT EXPLAIN SELECT count(*) from alltypes_plain ---- logical_plan -01)Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +01)Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] 02)--TableScan: alltypes_plain projection=[] physical_plan -01)AggregateExec: mode=Final, gby=[], aggr=[COUNT(*)] +01)AggregateExec: mode=Final, gby=[], aggr=[count(*)] 02)--CoalescePartitionsExec -03)----AggregateExec: mode=Partial, gby=[], aggr=[COUNT(*)] +03)----AggregateExec: mode=Partial, gby=[], aggr=[count(*)] 04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 05)--------AvroExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/avro/alltypes_plain.avro]]} diff --git a/datafusion/sqllogictest/test_files/binary.slt b/datafusion/sqllogictest/test_files/binary.slt index 621cd3e528f1..5c5f9d510e55 100644 --- a/datafusion/sqllogictest/test_files/binary.slt +++ b/datafusion/sqllogictest/test_files/binary.slt @@ -25,7 +25,7 @@ SELECT X'FF01', arrow_typeof(X'FF01'); ---- ff01 Binary -# Invaid hex values +# Invalid hex values query error DataFusion error: Error during planning: Invalid HexStringLiteral 'Z' SELECT X'Z' diff --git a/datafusion/sqllogictest/test_files/binary_view.slt b/datafusion/sqllogictest/test_files/binary_view.slt new file mode 100644 index 000000000000..f973b909aeb6 --- /dev/null +++ b/datafusion/sqllogictest/test_files/binary_view.slt @@ -0,0 +1,217 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +######## +## Test setup +######## + +statement ok +create table test_source as values + ('Andrew', 'X'), + ('Xiangpeng', 'Xiangpeng'), + ('Raphael', 'R'), + (NULL, 'R') +; + +# Table with the different combination of column types +statement ok +CREATE TABLE test AS +SELECT + arrow_cast(column1, 'Utf8') as column1_utf8, + arrow_cast(column2, 'Utf8') as column2_utf8, + arrow_cast(column1, 'Binary') AS column1_binary, + arrow_cast(column2, 'Binary') AS column2_binary, + arrow_cast(column1, 'LargeBinary') AS column1_large_binary, + arrow_cast(column2, 'LargeBinary') AS column2_large_binary, + arrow_cast(arrow_cast(column1, 'Binary'), 'BinaryView') AS column1_binaryview, + arrow_cast(arrow_cast(column2, 'Binary'), 'BinaryView') AS column2_binaryview, + arrow_cast(column1, 'Dictionary(Int32, Binary)') AS column1_dict, + arrow_cast(column2, 'Dictionary(Int32, Binary)') AS column2_dict +FROM test_source; + +statement ok +drop table test_source + +######## +## BinaryView to BinaryView +######## + +# BinaryView scalar to BinaryView scalar + +query BBBB +SELECT + arrow_cast(arrow_cast('NULL', 'Binary'), 'BinaryView') = arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView') AS comparison1, + arrow_cast(arrow_cast('NULL', 'Binary'), 'BinaryView') <> arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView') AS comparison2, + arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView') = arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView') AS comparison3, + arrow_cast(arrow_cast('Xiangpeng', 'Binary'), 'BinaryView') <> arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView') AS comparison4; +---- +false true true true + + +# BinaryView column to BinaryView column comparison as filters + +query TT +select column1_utf8, column2_utf8 from test where column1_binaryview = column2_binaryview; +---- +Xiangpeng Xiangpeng + +query TT +select column1_utf8, column2_utf8 from test where column1_binaryview <> column2_binaryview; +---- +Andrew X +Raphael R + +# BinaryView column to BinaryView column +query TTBB +select + column1_utf8, column2_utf8, + column1_binaryview = column2_binaryview, + column1_binaryview <> column2_binaryview +from test; +---- +Andrew X false true +Xiangpeng Xiangpeng true false +Raphael R false true +NULL R NULL NULL + +# BinaryView column to BinaryView scalar comparison +query TTBBBB +select + column1_utf8, column2_utf8, + column1_binaryview = arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView'), + arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView') = column1_binaryview, + column1_binaryview <> arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView'), + arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView') <> column1_binaryview +from test; +---- +Andrew X true true false false +Xiangpeng Xiangpeng false false true true +Raphael R false false true true +NULL R NULL NULL NULL NULL + +######## +## BinaryView to Binary +######## + +# test BinaryViewArray with Binary columns +query TTBBBB +select + column1_utf8, column2_utf8, + column1_binaryview = column2_binary, + column2_binary = column1_binaryview, + column1_binaryview <> column2_binary, + column2_binary <> column1_binaryview +from test; +---- +Andrew X false false true true +Xiangpeng Xiangpeng true true false false +Raphael R false false true true +NULL R NULL NULL NULL NULL + +# test BinaryViewArray with LargeBinary columns +query TTBBBB +select + column1_utf8, column2_utf8, + column1_binaryview = column2_large_binary, + column2_large_binary = column1_binaryview, + column1_binaryview <> column2_large_binary, + column2_large_binary <> column1_binaryview +from test; +---- +Andrew X false false true true +Xiangpeng Xiangpeng true true false false +Raphael R false false true true +NULL R NULL NULL NULL NULL + +# BinaryView column to Binary scalar +query TTBBBB +select + column1_utf8, column2_utf8, + column1_binaryview = arrow_cast('Andrew', 'Binary'), + arrow_cast('Andrew', 'Binary') = column1_binaryview, + column1_binaryview <> arrow_cast('Andrew', 'Binary'), + arrow_cast('Andrew', 'Binary') <> column1_binaryview +from test; +---- +Andrew X true true false false +Xiangpeng Xiangpeng false false true true +Raphael R false false true true +NULL R NULL NULL NULL NULL + +# BinaryView column to LargeBinary scalar +query TTBBBB +select + column1_utf8, column2_utf8, + column1_binaryview = arrow_cast('Andrew', 'LargeBinary'), + arrow_cast('Andrew', 'LargeBinary') = column1_binaryview, + column1_binaryview <> arrow_cast('Andrew', 'LargeBinary'), + arrow_cast('Andrew', 'LargeBinary') <> column1_binaryview +from test; +---- +Andrew X true true false false +Xiangpeng Xiangpeng false false true true +Raphael R false false true true +NULL R NULL NULL NULL NULL + +# Binary column to BinaryView scalar +query TTBBBB +select + column1_utf8, column2_utf8, + column1_binary = arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView'), + arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView') = column1_binary, + column1_binary <> arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView'), + arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView') <> column1_binary +from test; +---- +Andrew X true true false false +Xiangpeng Xiangpeng false false true true +Raphael R false false true true +NULL R NULL NULL NULL NULL + + +# LargeBinary column to BinaryView scalar +query TTBBBB +select + column1_utf8, column2_utf8, + column1_large_binary = arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView'), + arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView') = column1_large_binary, + column1_large_binary <> arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView'), + arrow_cast(arrow_cast('Andrew', 'Binary'), 'BinaryView') <> column1_large_binary +from test; +---- +Andrew X true true false false +Xiangpeng Xiangpeng false false true true +Raphael R false false true true +NULL R NULL NULL NULL NULL + +statement ok +drop table test; + +statement ok +create table bv as values +( + arrow_cast('one', 'BinaryView'), + arrow_cast('two', 'BinaryView') +); + +query B +select column1 like 'o%' from bv; +---- +true + +statement ok +drop table bv; diff --git a/datafusion/sqllogictest/test_files/case.slt b/datafusion/sqllogictest/test_files/case.slt new file mode 100644 index 000000000000..3c967eed219a --- /dev/null +++ b/datafusion/sqllogictest/test_files/case.slt @@ -0,0 +1,204 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# create test data +statement ok +create table foo (a int, b int) as values (1, 2), (3, 4), (5, 6), (null, null), (6, null), (null, 7); + +# CASE WHEN with condition +query T +SELECT CASE a WHEN 1 THEN 'one' WHEN 3 THEN 'three' ELSE '?' END FROM foo +---- +one +three +? +? +? +? + +# CASE WHEN with no condition +query I +SELECT CASE WHEN a > 2 THEN a ELSE b END FROM foo +---- +2 +3 +5 +NULL +6 +7 + +# column or explicit null +query I +SELECT CASE WHEN a > 2 THEN b ELSE null END FROM foo +---- +NULL +4 +6 +NULL +NULL +7 + +# column or implicit null +query I +SELECT CASE WHEN a > 2 THEN b END FROM foo +---- +NULL +4 +6 +NULL +NULL +7 + +# scalar or scalar (string) +query T +SELECT CASE WHEN a > 2 THEN 'even' ELSE 'odd' END FROM foo +---- +odd +even +even +odd +even +odd + +# scalar or scalar (int) +query I +SELECT CASE WHEN a > 2 THEN 1 ELSE 0 END FROM foo +---- +0 +1 +1 +0 +1 +0 + +# predicate binary expression with scalars (does not make much sense because the expression in +# this case is always false, so this expression could be rewritten as a literal 0 during planning +query I +SELECT CASE WHEN 1 > 2 THEN 1 ELSE 0 END FROM foo +---- +0 +0 +0 +0 +0 +0 + +# predicate using boolean literal (does not make much sense because the expression in +# this case is always false, so this expression could be rewritten as a literal 0 during planning +query I +SELECT CASE WHEN false THEN 1 ELSE 0 END FROM foo +---- +0 +0 +0 +0 +0 +0 + +# test then value type coercion +# List(Utf8) will be casted to List(Int64) +query ? +SELECT CASE 1 WHEN 1 THEN ['1', '2', '3'] WHEN 2 THEN [1, 2, 3] ELSE null END; +---- +[1, 2, 3] + +query ? +SELECT CASE 1 WHEN 1 THEN [[1,2], [2,4]] WHEN 2 THEN [['1','2'], ['2','4']] ELSE null END; +---- +[[1, 2], [2, 4]] + +query ? +SELECT CASE 1 WHEN 1 THEN [1,2,3] WHEN 2 THEN arrow_cast([1,2,3], 'LargeList(Int64)') WHEN 3 THEN arrow_cast([1,2,3], 'FixedSizeList(3, Int32)') ELSE null END; +---- +[1, 2, 3] + +query ? +SELECT CASE 1 WHEN 1 THEN [[1,2], [2,4]] WHEN 2 THEN arrow_cast([[1,2], [2,4]], 'LargeList(LargeList(Int64))') WHEN 3 THEN arrow_cast([[1,2], [2,4]], 'FixedSizeList(2, FixedSizeList(2, Int32))') ELSE null END; +---- +[[1, 2], [2, 4]] + +query ? +SELECT CASE 1 WHEN 1 THEN [1,2,3] WHEN 2 THEN arrow_cast(['1','2','3'], 'LargeList(Utf8)') WHEN 3 THEN arrow_cast(['1','2','3'], 'FixedSizeList(3, Utf8)') ELSE null END; +---- +[1, 2, 3] + +query ? +SELECT CASE 1 WHEN 1 THEN [[1,2], [2,4]] WHEN 2 THEN arrow_cast([['1','2'], ['2','4']], 'LargeList(LargeList(Utf8))') WHEN 3 THEN arrow_cast([['1','2'], ['2','4']], 'FixedSizeList(2, FixedSizeList(2, Utf8))') ELSE null END; +---- +[[1, 2], [2, 4]] + +query ? +SELECT CASE 1 WHEN 1 THEN arrow_cast([1,2,3], 'LargeList(Int64)') WHEN 2 THEN arrow_cast(['1','2','3'], 'LargeList(Utf8)') ELSE null END; +---- +[1, 2, 3] + +query ? +SELECT CASE 1 WHEN 1 THEN arrow_cast([1, 2], 'FixedSizeList(2, Int64)') WHEN 2 THEN arrow_cast(['1', '2', '3'], 'FixedSizeList(3, Utf8)') ELSE null END; +---- +[1, 2] + +query error DataFusion error: type_coercion +SELECT CASE 1 WHEN 1 THEN [1,2,3] WHEN 2 THEN 'test' ELSE null END; + +# test case when type coercion +query I +SELECT CASE [1,2,3] WHEN arrow_cast([1,2,3], 'LargeList(Int64)') THEN 1 ELSE 0 END; +---- +1 + +query I +SELECT CASE [1,2,3] WHEN arrow_cast(['1','2','3'], 'LargeList(Int64)') THEN 1 ELSE 0 END; +---- +1 + +query I +SELECT CASE arrow_cast([1,2,3], 'LargeList(Int64)') WHEN [1,2,3] THEN 1 ELSE 0 END; +---- +1 + +query I +SELECT CASE [[1,2],[2,4]] WHEN arrow_cast([[1,2],[2,4]], 'LargeList(LargeList(Int64))') THEN 1 ELSE 0 END; +---- +1 + +query I +SELECT CASE arrow_cast([1,2,3], 'FixedSizeList(3, Int64)') WHEN [1,2,3] THEN 1 ELSE 0 END; +---- +1 + +query error DataFusion error: type_coercion +SELECT CASE [1,2,3] WHEN 'test' THEN 1 ELSE 0 END; + +query I +SELECT CASE arrow_cast([1,2], 'FixedSizeList(2, Int64)') WHEN arrow_cast([1,2,3], 'FixedSizeList(3, Int64)') THEN 1 ELSE 0 END; +---- +0 + +query I +SELECT CASE arrow_cast([1,2], 'FixedSizeList(2, Int64)') WHEN arrow_cast(['1','2','3'], 'FixedSizeList(3, Utf8)') THEN 1 ELSE 0 END; +---- +0 + +query I +SELECT CASE arrow_cast(['1','2'], 'FixedSizeList(2, Utf8)') WHEN arrow_cast([1,2,3], 'FixedSizeList(3, Int64)') THEN 1 ELSE 0 END; +---- +0 + +query I +SELECT CASE arrow_cast([1,2,3], 'FixedSizeList(3, Int64)') WHEN arrow_cast([1,2,3], 'FixedSizeList(3, Int64)') THEN 1 ELSE 0 END; +---- +1 diff --git a/datafusion/sqllogictest/test_files/cast.slt b/datafusion/sqllogictest/test_files/cast.slt index 73862be60d9b..3466354e54d7 100644 --- a/datafusion/sqllogictest/test_files/cast.slt +++ b/datafusion/sqllogictest/test_files/cast.slt @@ -56,3 +56,36 @@ query I SELECT 10::bigint unsigned ---- 10 + +# cast array +query ? +SELECT CAST(MAKE_ARRAY(1, 2, 3) AS VARCHAR[]) +---- +[1, 2, 3] + + +# cast empty array +query ? +SELECT CAST(MAKE_ARRAY() AS VARCHAR[]) +---- +[] + +statement ok +create table t0(v0 BIGINT); + +statement ok +insert into t0 values (1),(2),(3); + +query I +select * from t0 where v0>1e100; +---- + +query I +select * from t0 where v0<1e100; +---- +1 +2 +3 + +statement ok +drop table t0; diff --git a/datafusion/sqllogictest/test_files/clickbench.slt b/datafusion/sqllogictest/test_files/clickbench.slt index c2dba435263d..733c0a3cd972 100644 --- a/datafusion/sqllogictest/test_files/clickbench.slt +++ b/datafusion/sqllogictest/test_files/clickbench.slt @@ -274,5 +274,23 @@ query PI SELECT DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) AS M, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-14' AND "EventDate"::INT::DATE <= '2013-07-15' AND "IsRefresh" = 0 AND "DontCountHits" = 0 GROUP BY DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) ORDER BY DATE_TRUNC('minute', M) LIMIT 10 OFFSET 1000; ---- +# Clickbench "Extended" queries that test count distinct + +query III +SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DISTINCT "MobilePhoneModel") FROM hits; +---- +1 1 1 + +query III +SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage") FROM hits; +---- +1 1 1 + +query TIIII +SELECT "BrowserCountry", COUNT(DISTINCT "SocialNetwork"), COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserLanguage"), COUNT(DISTINCT "SocialAction") FROM hits GROUP BY 1 ORDER BY 2 DESC LIMIT 10; +---- +� 1 1 1 1 + + statement ok drop table hits; diff --git a/datafusion/sqllogictest/test_files/coalesce.slt b/datafusion/sqllogictest/test_files/coalesce.slt new file mode 100644 index 000000000000..97e77d0feb3d --- /dev/null +++ b/datafusion/sqllogictest/test_files/coalesce.slt @@ -0,0 +1,432 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +# Test coalesce function +query I +select coalesce(1, 2, 3); +---- +1 + +# test with first null +query IT +select coalesce(null, 3, 2, 1), arrow_typeof(coalesce(null, 3, 2, 1)); +---- +3 Int64 + +# test with null values +query ? +select coalesce(null, null); +---- +NULL + +# cast to float +query RT +select + coalesce(1, 2.0), + arrow_typeof(coalesce(1, 2.0)) +; +---- +1 Float64 + +query RT +select + coalesce(2.0, 1), + arrow_typeof(coalesce(2.0, 1)) +; +---- +2 Float64 + +query RT +select + coalesce(1, arrow_cast(2.0, 'Float32')), + arrow_typeof(coalesce(1, arrow_cast(2.0, 'Float32'))) +; +---- +1 Float32 + +# test with empty args +query error +select coalesce(); + +# test with different types +query I +select coalesce(arrow_cast(1, 'Int32'), arrow_cast(1, 'Int64')); +---- +1 + +# test with nulls +query ?T +select coalesce(null, null, null), arrow_typeof(coalesce(null, null)); +---- +NULL Null + +# i32 and u32, cast to wider type i64 +query IT +select + coalesce(arrow_cast(2, 'Int32'), arrow_cast(3, 'UInt32')), + arrow_typeof(coalesce(arrow_cast(2, 'Int32'), arrow_cast(3, 'UInt32'))); +---- +2 Int64 + +query IT +select + coalesce(arrow_cast(2, 'Int16'), arrow_cast(3, 'UInt16')), + arrow_typeof(coalesce(arrow_cast(2, 'Int16'), arrow_cast(3, 'UInt16'))); +---- +2 Int32 + +query IT +select + coalesce(arrow_cast(2, 'Int8'), arrow_cast(3, 'UInt8')), + arrow_typeof(coalesce(arrow_cast(2, 'Int8'), arrow_cast(3, 'UInt8'))); +---- +2 Int16 + +# i64 and u32, cast to wider type i64 +query IT +select + coalesce(2, arrow_cast(3, 'UInt32')), + arrow_typeof(coalesce(2, arrow_cast(3, 'UInt32'))); +---- +2 Int64 + +# TODO: Got types (i64, u64), casting to decimal or double or even i128 if supported +query IT +select + coalesce(2, arrow_cast(3, 'UInt64')), + arrow_typeof(coalesce(2, arrow_cast(3, 'UInt64'))); +---- +2 Int64 + +statement ok +create table t1 (a bigint, b int, c int) as values (null, null, 1), (null, 2, null); + +# Follow Postgres and DuckDB behavior, since a is bigint, although the value is null, all args are coerced to bigint +query IT +select + coalesce(a, b, c), + arrow_typeof(coalesce(a, b, c)) +from t1; +---- +1 Int64 +2 Int64 + +# b, c has the same type int, so the result is int +query IT +select + coalesce(b, c), + arrow_typeof(coalesce(b, c)) +from t1; +---- +1 Int32 +2 Int32 + +statement ok +drop table t1; + +# test multi rows +statement ok +CREATE TABLE t1( + c1 int, + c2 int +) as VALUES +(1, 2), +(NULL, 2), +(1, NULL), +(NULL, NULL); + +query I +SELECT COALESCE(c1, c2) FROM t1 +---- +1 +2 +1 +NULL + +statement ok +drop table t1; + +# Decimal128(7, 2) and int64 are coerced to common wider type Decimal128(22, 2) +query RT +select + coalesce(arrow_cast(2, 'Decimal128(7, 2)'), 0), + arrow_typeof(coalesce(arrow_cast(2, 'Decimal128(7, 2)'), 0)) +---- +2 Decimal128(22, 2) + +query RT +select + coalesce(arrow_cast(2, 'Decimal256(7, 2)'), 0), + arrow_typeof(coalesce(arrow_cast(2, 'Decimal256(7, 2)'), 0)); +---- +2 Decimal256(22, 2) + +# coalesce string +query TT +select + coalesce('', 'test'), + coalesce(null, 'test'); +---- +(empty) test + +# coalesce utf8 and large utf8 +query TT +select + coalesce('a', arrow_cast('b', 'LargeUtf8')), + arrow_typeof(coalesce('a', arrow_cast('b', 'LargeUtf8'))) +; +---- +a LargeUtf8 + +# coalesce array +query ?T +select + coalesce(array[1, 2], array[3, 4]), + arrow_typeof(coalesce(array[1, 2], array[3, 4])); +---- +[1, 2] List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) + +query ?T +select + coalesce(null, array[3, 4]), + arrow_typeof(coalesce(array[1, 2], array[3, 4])); +---- +[3, 4] List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) + +# coalesce with array +query ?T +select + coalesce(array[1, 2], array[arrow_cast(3, 'Int32'), arrow_cast(4, 'Int32')]), + arrow_typeof(coalesce(array[1, 2], array[arrow_cast(3, 'Int32'), arrow_cast(4, 'Int32')])); +---- +[1, 2] List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) + +# test dict(int32, utf8) +statement ok +create table test1 as values (arrow_cast('foo', 'Dictionary(Int32, Utf8)')), (null); + +query T +select coalesce(column1, 'none_set') from test1; +---- +foo +none_set + +query T +select coalesce(null, column1, 'none_set') from test1; +---- +foo +none_set + +# explicitly cast to Utf8 +query T +select coalesce(arrow_cast(column1, 'Utf8'), 'none_set') from test1; +---- +foo +none_set + +statement ok +drop table test1 + +# test dict coercion with value +statement ok +create table t(c varchar) as values ('a'), (null); + +query TT +select + coalesce(c, arrow_cast('b', 'Dictionary(Int32, Utf8)')), + arrow_typeof(coalesce(c, arrow_cast('b', 'Dictionary(Int32, Utf8)'))) +from t; +---- +a Dictionary(Int32, Utf8) +b Dictionary(Int32, Utf8) + +statement ok +drop table t; + +# test dict coercion with dict +statement ok +create table t as values + (arrow_cast('foo', 'Dictionary(Int32, Utf8)')), + (null); + +query TT +select + coalesce(column1, arrow_cast('bar', 'Dictionary(Int64, LargeUtf8)')), + arrow_typeof(coalesce(column1, arrow_cast('bar', 'Dictionary(Int64, LargeUtf8)'))) +from t; +---- +foo Dictionary(Int64, LargeUtf8) +bar Dictionary(Int64, LargeUtf8) + +query TT +select + coalesce(column1, arrow_cast('bar', 'Dictionary(Int32, LargeUtf8)')), + arrow_typeof(coalesce(column1, arrow_cast('bar', 'Dictionary(Int32, LargeUtf8)'))) +from t; +---- +foo Dictionary(Int32, LargeUtf8) +bar Dictionary(Int32, LargeUtf8) + +query TT +select + coalesce(column1, arrow_cast('bar', 'Dictionary(Int64, Utf8)')), + arrow_typeof(coalesce(column1, arrow_cast('bar', 'Dictionary(Int64, Utf8)'))) +from t; +---- +foo Dictionary(Int64, Utf8) +bar Dictionary(Int64, Utf8) + +statement ok +drop table t; + +# test dict(int32, int8) +query ? +select coalesce(34, arrow_cast(123, 'Dictionary(Int32, Int8)')); +---- +34 + +query ? +select coalesce(arrow_cast(123, 'Dictionary(Int32, Int8)'), 34); +---- +123 + +query ? +select coalesce(null, 34, arrow_cast(123, 'Dictionary(Int32, Int8)')); +---- +34 + +# numeric string coercion +query RT +select coalesce(2.0, 1, '3'), arrow_typeof(coalesce(2.0, 1, '3')); +---- +2 Float64 + +# explicitly cast to Int8, and it will implicitly cast to Int64 +query IT +select + coalesce(arrow_cast(123, 'Int8'), 34), + arrow_typeof(coalesce(arrow_cast(123, 'Int8'), 34)); +---- +123 Int64 + +statement ok +CREATE TABLE test( + c1 INT, + c2 INT +) as VALUES +(0, 1), +(NULL, 1), +(1, 0), +(NULL, 1), +(NULL, NULL); + +# coalesce result +query I rowsort +SELECT COALESCE(c1, c2) FROM test +---- +0 +1 +1 +1 +NULL + +# numeric string is coerced to numeric in both Postgres and DuckDB +query I +SELECT COALESCE(c1, c2, '-1') FROM test; +---- +0 +1 +1 +1 +-1 + +statement ok +drop table test + +statement ok +CREATE TABLE test( + c1 BIGINT, + c2 BIGINT +) as VALUES +(1, 2), +(NULL, 2), +(1, NULL), +(NULL, NULL); + +# coalesce sum with default value +query I +SELECT SUM(COALESCE(c1, c2, 0)) FROM test +---- +4 + +# coalesce mul with default value +query I +SELECT COALESCE(c1 * c2, 0) FROM test +---- +2 +0 +0 +0 + +statement ok +drop table test + +# coalesce date32 + +statement ok +CREATE TABLE test( + d1_date DATE, + d2_date DATE, + d3_date DATE +) as VALUES + ('2022-12-12','2022-12-12','2022-12-12'), + (NULL,'2022-12-11','2022-12-12'), + ('2022-12-12','2022-12-10','2022-12-12'), + ('2022-12-12',NULL,'2022-12-12'), + ('2022-12-12','2022-12-8','2022-12-12'), + ('2022-12-12','2022-12-7',NULL), + ('2022-12-12',NULL,'2022-12-12'), + (NULL,'2022-12-5','2022-12-12') +; + +query D +SELECT COALESCE(d1_date, d2_date, d3_date) FROM test +---- +2022-12-12 +2022-12-11 +2022-12-12 +2022-12-12 +2022-12-12 +2022-12-12 +2022-12-12 +2022-12-05 + +query T +SELECT arrow_typeof(COALESCE(d1_date, d2_date, d3_date)) FROM test +---- +Date32 +Date32 +Date32 +Date32 +Date32 +Date32 +Date32 +Date32 + +statement ok +drop table test diff --git a/datafusion/sqllogictest/test_files/copy.slt b/datafusion/sqllogictest/test_files/copy.slt index 502dfd4fa6bb..caa708483a11 100644 --- a/datafusion/sqllogictest/test_files/copy.slt +++ b/datafusion/sqllogictest/test_files/copy.slt @@ -20,13 +20,13 @@ statement ok create table source_table(col1 integer, col2 varchar) as values (1, 'Foo'), (2, 'Bar'); # Copy to directory as multiple files -query IT +query I COPY source_table TO 'test_files/scratch/copy/table/' STORED AS parquet OPTIONS ('format.compression' 'zstd(10)'); ---- 2 # Copy to directory as partitioned files -query IT +query I COPY source_table TO 'test_files/scratch/copy/partitioned_table1/' STORED AS parquet PARTITIONED BY (col2) OPTIONS ('format.compression' 'zstd(10)'); ---- 2 @@ -36,7 +36,7 @@ statement ok CREATE EXTERNAL TABLE validate_partitioned_parquet STORED AS PARQUET LOCATION 'test_files/scratch/copy/partitioned_table1/' PARTITIONED BY (col2); -query I? +query IT select * from validate_partitioned_parquet order by col1, col2; ---- 1 Foo @@ -53,7 +53,7 @@ select * from validate_partitioned_parquet_bar order by col1; 2 # Copy to directory as partitioned files -query ITT +query I COPY (values (1, 'a', 'x'), (2, 'b', 'y'), (3, 'c', 'z')) TO 'test_files/scratch/copy/partitioned_table2/' STORED AS parquet PARTITIONED BY (column2, column3) OPTIONS ('format.compression' 'zstd(10)'); ---- @@ -64,7 +64,7 @@ statement ok CREATE EXTERNAL TABLE validate_partitioned_parquet2 STORED AS PARQUET LOCATION 'test_files/scratch/copy/partitioned_table2/' PARTITIONED BY (column2, column3); -query I?? +query ITT select * from validate_partitioned_parquet2 order by column1,column2,column3; ---- 1 a x @@ -81,7 +81,7 @@ select * from validate_partitioned_parquet_a_x order by column1; 1 # Copy to directory as partitioned files -query TTT +query I COPY (values ('1', 'a', 'x'), ('2', 'b', 'y'), ('3', 'c', 'z')) TO 'test_files/scratch/copy/partitioned_table3/' STORED AS parquet PARTITIONED BY (column1, column3) OPTIONS ('format.compression' 'zstd(10)'); ---- @@ -92,7 +92,7 @@ statement ok CREATE EXTERNAL TABLE validate_partitioned_parquet3 STORED AS PARQUET LOCATION 'test_files/scratch/copy/partitioned_table3/' PARTITIONED BY (column1, column3); -query ?T? +query TTT select column1, column2, column3 from validate_partitioned_parquet3 order by column1,column2,column3; ---- 1 a x @@ -108,6 +108,28 @@ select * from validate_partitioned_parquet_1_x order by column2; ---- a +# Copy to directory as partitioned files +query I +COPY (values (1::int, 2::bigint, 19968::date, arrow_cast(1725235200000, 'Date64'), false, 'x'), + (11::int, 22::bigint, 19969::date, arrow_cast(1725148800000, 'Date64'), true, 'y') +) +TO 'test_files/scratch/copy/partitioned_table5/' STORED AS parquet PARTITIONED BY (column1, column2, column3, column4, column5) +OPTIONS ('format.compression' 'zstd(10)'); +---- +2 + +# validate partitioning +statement ok +CREATE EXTERNAL TABLE validate_partitioned_parquet5 (column1 int, column2 bigint, column3 date, column4 date, column5 boolean, column6 varchar) STORED AS PARQUET +LOCATION 'test_files/scratch/copy/partitioned_table5/' PARTITIONED BY (column1, column2, column3, column4, column5); + +query IIDDBT +select column1, column2, column3, column4, column5, column6 from validate_partitioned_parquet5 order by column1,column2,column3,column4,column5; +---- +1 2 2024-09-02 2024-09-02 false x +11 22 2024-09-03 2024-09-01 true y + + statement ok create table test ("'test'" varchar, "'test2'" varchar, "'test3'" varchar); @@ -166,8 +188,25 @@ physical_plan 01)DataSinkExec: sink=ParquetSink(file_groups=[]) 02)--MemoryExec: partitions=1, partition_sizes=[1] +# Copy to directory as partitioned files with keep_partition_by_columns enabled +query I +COPY (values ('1', 'a'), ('2', 'b'), ('3', 'c')) TO 'test_files/scratch/copy/partitioned_table4/' STORED AS parquet PARTITIONED BY (column1) +OPTIONS (execution.keep_partition_by_columns true); +---- +3 + +# validate generated file contains tables +statement ok +CREATE EXTERNAL TABLE validate_partitioned_parquet4 STORED AS PARQUET +LOCATION 'test_files/scratch/copy/partitioned_table4/column1=1/*.parquet'; + +query TT +select column1, column2 from validate_partitioned_parquet4 order by column1,column2; +---- +1 a + # Copy more files to directory via query -query IT +query I COPY (select * from source_table UNION ALL select * from source_table) to 'test_files/scratch/copy/table/' STORED AS PARQUET; ---- 4 @@ -186,7 +225,7 @@ select * from validate_parquet; 1 Foo 2 Bar -query ? +query I copy (values (struct(timestamp '2021-01-01 01:00:01', 1)), (struct(timestamp '2022-01-01 01:00:01', 2)), (struct(timestamp '2023-01-03 01:00:01', 3)), (struct(timestamp '2024-01-01 01:00:01', 4))) to 'test_files/scratch/copy/table_nested2/' STORED AS PARQUET; @@ -204,7 +243,7 @@ select * from validate_parquet_nested2; {c0: 2023-01-03T01:00:01, c1: 3} {c0: 2024-01-01T01:00:01, c1: 4} -query ?? +query I COPY (values (struct ('foo', (struct ('foo', make_array(struct('a',1), struct('b',2))))), make_array(timestamp '2023-01-01 01:00:01',timestamp '2023-01-01 01:00:01')), (struct('bar', (struct ('foo', make_array(struct('aa',10), struct('bb',20))))), make_array(timestamp '2024-01-01 01:00:01', timestamp '2024-01-01 01:00:01'))) @@ -222,7 +261,7 @@ select * from validate_parquet_nested; {c0: foo, c1: {c0: foo, c1: [{c0: a, c1: 1}, {c0: b, c1: 2}]}} [2023-01-01T01:00:01, 2023-01-01T01:00:01] {c0: bar, c1: {c0: foo, c1: [{c0: aa, c1: 10}, {c0: bb, c1: 20}]}} [2024-01-01T01:00:01, 2024-01-01T01:00:01] -query ? +query I copy (values ([struct('foo', 1), struct('bar', 2)])) to 'test_files/scratch/copy/array_of_struct/' STORED AS PARQUET; @@ -238,7 +277,7 @@ select * from validate_array_of_struct; ---- [{c0: foo, c1: 1}, {c0: bar, c1: 2}] -query ? +query I copy (values (struct('foo', [1,2,3], struct('bar', [2,3,4])))) to 'test_files/scratch/copy/struct_with_array/' STORED AS PARQUET; ---- @@ -254,8 +293,8 @@ select * from validate_struct_with_array; {c0: foo, c1: [1, 2, 3], c2: {c0: bar, c1: [2, 3, 4]}} -# Copy parquet with all supported statment overrides -query IT +# Copy parquet with all supported statement overrides +query I COPY source_table TO 'test_files/scratch/copy/table_with_options/' STORED AS PARQUET @@ -271,7 +310,7 @@ OPTIONS ( 'format.created_by' 'DF copy.slt', 'format.column_index_truncate_length' 123, 'format.data_page_row_count_limit' 1234, -'format.bloom_filter_enabled' true, +'format.bloom_filter_on_read' true, 'format.bloom_filter_enabled::col1' false, 'format.bloom_filter_fpp::col2' 0.456, 'format.bloom_filter_ndv::col2' 456, @@ -283,11 +322,73 @@ OPTIONS ( 'format.statistics_enabled::col2' none, 'format.max_statistics_size' 123, 'format.bloom_filter_fpp' 0.001, -'format.bloom_filter_ndv' 100 +'format.bloom_filter_ndv' 100, +'format.metadata::key' 'value' ) ---- 2 +# valid vs invalid metadata + +# accepts map with a single entry +statement ok +COPY source_table +TO 'test_files/scratch/copy/table_with_metadata/' +STORED AS PARQUET +OPTIONS ( + 'format.metadata::key' 'value' +) + +# accepts multiple entries (on different keys) +statement ok +COPY source_table +TO 'test_files/scratch/copy/table_with_metadata/' +STORED AS PARQUET +OPTIONS ( + 'format.metadata::key1' '', + 'format.metadata::key2' 'value', + 'format.metadata::key3' 'value with spaces', + 'format.metadata::key4' 'value with special chars :: :' +) + +# accepts multiple entries with the same key (will overwrite) +statement ok +COPY source_table +TO 'test_files/scratch/copy/table_with_metadata/' +STORED AS PARQUET +OPTIONS ( + 'format.metadata::key1' 'value', + 'format.metadata::key1' 'value' +) + +# errors if key is missing +statement error DataFusion error: Invalid or Unsupported Configuration: Invalid metadata key provided, missing key in metadata:: +COPY source_table +TO 'test_files/scratch/copy/table_with_metadata/' +STORED AS PARQUET +OPTIONS ( + 'format.metadata::' 'value' +) + +# errors if key contains internal '::' +statement error DataFusion error: Invalid or Unsupported Configuration: Invalid metadata key provided, found too many '::' in "metadata::key::extra" +COPY source_table +TO 'test_files/scratch/copy/table_with_metadata/' +STORED AS PARQUET +OPTIONS ( + 'format.metadata::key::extra' 'value' +) + +# errors for invalid property (not stating `format.metadata`) +statement error DataFusion error: Invalid or Unsupported Configuration: Config value "wrong-metadata" not found on ParquetColumnOptions +COPY source_table +TO 'test_files/scratch/copy/table_with_metadata/' +STORED AS PARQUET +OPTIONS ( + 'format.wrong-metadata::key' 'value' +) + + # validate multiple parquet file output with all options set statement ok CREATE EXTERNAL TABLE validate_parquet_with_options STORED AS PARQUET LOCATION 'test_files/scratch/copy/table_with_options/'; @@ -299,7 +400,7 @@ select * from validate_parquet_with_options; 2 Bar # Copy from table to single file -query IT +query I COPY source_table to 'test_files/scratch/copy/table.parquet'; ---- 2 @@ -315,14 +416,14 @@ select * from validate_parquet_single; 2 Bar # copy from table to folder of compressed json files -query IT +query I COPY source_table to 'test_files/scratch/copy/table_json_gz' STORED AS JSON OPTIONS ('format.compression' gzip); ---- 2 # validate folder of csv files statement ok -CREATE EXTERNAL TABLE validate_json_gz STORED AS json COMPRESSION TYPE gzip LOCATION 'test_files/scratch/copy/table_json_gz'; +CREATE EXTERNAL TABLE validate_json_gz STORED AS json LOCATION 'test_files/scratch/copy/table_json_gz' OPTIONS ('format.compression' 'gzip'); query IT select * from validate_json_gz; @@ -331,14 +432,14 @@ select * from validate_json_gz; 2 Bar # copy from table to folder of compressed csv files -query IT +query I COPY source_table to 'test_files/scratch/copy/table_csv' STORED AS CSV OPTIONS ('format.has_header' false, 'format.compression' gzip); ---- 2 # validate folder of csv files statement ok -CREATE EXTERNAL TABLE validate_csv STORED AS csv COMPRESSION TYPE gzip LOCATION 'test_files/scratch/copy/table_csv'; +CREATE EXTERNAL TABLE validate_csv STORED AS csv LOCATION 'test_files/scratch/copy/table_csv' OPTIONS ('format.has_header' false, 'format.compression' gzip); query IT select * from validate_csv; @@ -347,14 +448,14 @@ select * from validate_csv; 2 Bar # Copy from table to single csv -query IT -COPY source_table to 'test_files/scratch/copy/table.csv'; +query I +COPY source_table to 'test_files/scratch/copy/table.csv' OPTIONS ('format.has_header' false); ---- 2 # Validate single csv output statement ok -CREATE EXTERNAL TABLE validate_single_csv STORED AS csv WITH HEADER ROW LOCATION 'test_files/scratch/copy/table.csv'; +CREATE EXTERNAL TABLE validate_single_csv STORED AS csv LOCATION 'test_files/scratch/copy/table.csv' OPTIONS ('format.has_header' 'false'); query IT select * from validate_single_csv; @@ -363,7 +464,7 @@ select * from validate_single_csv; 2 Bar # Copy from table to folder of json -query IT +query I COPY source_table to 'test_files/scratch/copy/table_json' STORED AS JSON; ---- 2 @@ -379,7 +480,7 @@ select * from validate_json; 2 Bar # Copy from table to single json file -query IT +query I COPY source_table to 'test_files/scratch/copy/table.json' STORED AS JSON ; ---- 2 @@ -395,11 +496,11 @@ select * from validate_single_json; 2 Bar # COPY csv files with all options set -query IT +query I COPY source_table to 'test_files/scratch/copy/table_csv_with_options' STORED AS CSV OPTIONS ( -'format.has_header' false, +'format.has_header' true, 'format.compression' uncompressed, 'format.datetime_format' '%FT%H:%M:%S.%9f', 'format.delimiter' ';', @@ -420,7 +521,7 @@ select * from validate_csv_with_options; 2;Bar # Copy from table to single arrow file -query IT +query I COPY source_table to 'test_files/scratch/copy/table.arrow' STORED AS ARROW; ---- 2 @@ -438,7 +539,7 @@ select * from validate_arrow_file; 2 Bar # Copy from dict encoded values to single arrow file -query T? +query I COPY (values ('c', arrow_cast('foo', 'Dictionary(Int32, Utf8)')), ('d', arrow_cast('bar', 'Dictionary(Int32, Utf8)'))) to 'test_files/scratch/copy/table_dict.arrow' STORED AS ARROW; @@ -451,7 +552,7 @@ CREATE EXTERNAL TABLE validate_arrow_file_dict STORED AS arrow LOCATION 'test_files/scratch/copy/table_dict.arrow'; -query T? +query TT select * from validate_arrow_file_dict; ---- c foo @@ -459,7 +560,7 @@ d bar # Copy from table to folder of json -query IT +query I COPY source_table to 'test_files/scratch/copy/table_arrow' STORED AS ARROW; ---- 2 @@ -477,7 +578,7 @@ select * from validate_arrow; # Format Options Support without the 'format.' prefix # Copy with format options for Parquet without the 'format.' prefix -query IT +query I COPY source_table TO 'test_files/scratch/copy/format_table.parquet' OPTIONS ( compression snappy, @@ -487,14 +588,14 @@ OPTIONS ( 2 # Copy with format options for JSON without the 'format.' prefix -query IT +query I COPY source_table to 'test_files/scratch/copy/format_table' STORED AS JSON OPTIONS (compression gzip); ---- 2 # Copy with format options for CSV without the 'format.' prefix -query IT +query I COPY source_table to 'test_files/scratch/copy/format_table.csv' OPTIONS ( has_header false, @@ -521,9 +622,13 @@ query error DataFusion error: Invalid or Unsupported Configuration: Config value COPY source_table to 'test_files/scratch/copy/table.json' STORED AS JSON OPTIONS ('format.row_group_size' 55); # Incomplete statement -query error DataFusion error: SQL error: ParserError\("Expected \), found: EOF"\) +query error DataFusion error: SQL error: ParserError\("Expected: \), found: EOF"\) COPY (select col2, sum(col1) from source_table # Copy from table with non literal query error DataFusion error: SQL error: ParserError\("Unexpected token \("\) COPY source_table to '/tmp/table.parquet' (row_group_size 55 + 102); + +# Copy using execution.keep_partition_by_columns with an invalid value +query error DataFusion error: Invalid or Unsupported Configuration: provided value for 'execution.keep_partition_by_columns' was not recognized: "invalid_value" +COPY source_table to '/tmp/table.parquet' OPTIONS (execution.keep_partition_by_columns invalid_value); diff --git a/datafusion/sqllogictest/test_files/count_star_rule.slt b/datafusion/sqllogictest/test_files/count_star_rule.slt new file mode 100644 index 000000000000..3625da68b39e --- /dev/null +++ b/datafusion/sqllogictest/test_files/count_star_rule.slt @@ -0,0 +1,103 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +CREATE TABLE t1 (a INTEGER, b INTEGER, c INTEGER); + +statement ok +INSERT INTO t1 VALUES +(1, 2, 3), +(1, 5, 6), +(2, 3, 5); + +statement ok +CREATE TABLE t2 (a INTEGER, b INTEGER, c INTEGER); + +query TT +EXPLAIN SELECT COUNT() FROM (SELECT 1 AS a, 2 AS b) AS t; +---- +logical_plan +01)Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count()]] +02)--SubqueryAlias: t +03)----EmptyRelation +physical_plan +01)ProjectionExec: expr=[1 as count()] +02)--PlaceholderRowExec + +query TT +EXPLAIN SELECT t1.a, COUNT() FROM t1 GROUP BY t1.a; +---- +logical_plan +01)Aggregate: groupBy=[[t1.a]], aggr=[[count(Int64(1)) AS count()]] +02)--TableScan: t1 projection=[a] +physical_plan +01)AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count()] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 +04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +05)--------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count()] +06)----------MemoryExec: partitions=1, partition_sizes=[1] + +query TT +EXPLAIN SELECT t1.a, COUNT() AS cnt FROM t1 GROUP BY t1.a HAVING COUNT() > 0; +---- +logical_plan +01)Projection: t1.a, count() AS cnt +02)--Filter: count() > Int64(0) +03)----Aggregate: groupBy=[[t1.a]], aggr=[[count(Int64(1)) AS count()]] +04)------TableScan: t1 projection=[a] +physical_plan +01)ProjectionExec: expr=[a@0 as a, count()@1 as cnt] +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----FilterExec: count()@1 > 0 +04)------AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count()] +05)--------CoalesceBatchesExec: target_batch_size=8192 +06)----------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 +07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +08)--------------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count()] +09)----------------MemoryExec: partitions=1, partition_sizes=[1] + +query II +SELECT t1.a, COUNT() AS cnt FROM t1 GROUP BY t1.a HAVING COUNT() > 1; +---- +1 2 + +query TT +EXPLAIN SELECT a, COUNT() OVER (PARTITION BY a) AS count_a FROM t1; +---- +logical_plan +01)Projection: t1.a, count() PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS count_a +02)--WindowAggr: windowExpr=[[count(Int64(1)) PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS count() PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +03)----TableScan: t1 projection=[a] +physical_plan +01)ProjectionExec: expr=[a@0 as a, count() PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as count_a] +02)--WindowAggExec: wdw=[count() PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "count() PARTITION BY [t1.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] +03)----SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false] +04)------MemoryExec: partitions=1, partition_sizes=[1] + +query II +SELECT a, COUNT() OVER (PARTITION BY a) AS count_a FROM t1 ORDER BY a; +---- +1 2 +1 2 +2 1 + +statement ok +DROP TABLE t1; + +statement ok +DROP TABLE t2; diff --git a/datafusion/sqllogictest/test_files/create_external_table.slt b/datafusion/sqllogictest/test_files/create_external_table.slt index 8aeeb06c1909..ed001cf9f84c 100644 --- a/datafusion/sqllogictest/test_files/create_external_table.slt +++ b/datafusion/sqllogictest/test_files/create_external_table.slt @@ -33,33 +33,25 @@ statement error DataFusion error: SQL error: ParserError\("Missing LOCATION clau CREATE EXTERNAL TABLE t STORED AS CSV # Option value is missing -statement error DataFusion error: SQL error: ParserError\("Expected literal string, found: \)"\) +statement error DataFusion error: SQL error: ParserError\("Expected: string or numeric value, found: \)"\) CREATE EXTERNAL TABLE t STORED AS x OPTIONS ('k1' 'v1', k2 v2, k3) LOCATION 'blahblah' # Missing `(` in WITH ORDER clause -statement error DataFusion error: SQL error: ParserError\("Expected \(, found: c1"\) +statement error DataFusion error: SQL error: ParserError\("Expected: \(, found: c1"\) CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH ORDER c1 LOCATION 'foo.csv' # Missing `)` in WITH ORDER clause -statement error DataFusion error: SQL error: ParserError\("Expected \), found: LOCATION"\) +statement error DataFusion error: SQL error: ParserError\("Expected: \), found: LOCATION"\) CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH ORDER (c1 LOCATION 'foo.csv' # Missing `ROW` in WITH HEADER clause -statement error DataFusion error: SQL error: ParserError\("Expected ROW, found: LOCATION"\) +statement error DataFusion error: SQL error: ParserError\("Expected: ROW, found: LOCATION"\) CREATE EXTERNAL TABLE t STORED AS CSV WITH HEADER LOCATION 'abc' # Missing `BY` in PARTITIONED clause -statement error DataFusion error: SQL error: ParserError\("Expected BY, found: LOCATION"\) +statement error DataFusion error: SQL error: ParserError\("Expected: BY, found: LOCATION"\) CREATE EXTERNAL TABLE t STORED AS CSV PARTITIONED LOCATION 'abc' -# Missing `TYPE` in COMPRESSION clause -statement error DataFusion error: SQL error: ParserError\("Expected TYPE, found: LOCATION"\) -CREATE EXTERNAL TABLE t STORED AS CSV COMPRESSION LOCATION 'abc' - -# Invalid compression type -statement error DataFusion error: SQL error: ParserError\("Unsupported file compression type ZZZ"\) -CREATE EXTERNAL TABLE t STORED AS CSV COMPRESSION TYPE ZZZ LOCATION 'blahblah' - # Duplicate `STORED AS` clause statement error DataFusion error: SQL error: ParserError\("STORED AS specified more than once"\) CREATE EXTERNAL TABLE t STORED AS CSV STORED AS PARQUET LOCATION 'foo.parquet' @@ -68,18 +60,6 @@ CREATE EXTERNAL TABLE t STORED AS CSV STORED AS PARQUET LOCATION 'foo.parquet' statement error DataFusion error: SQL error: ParserError\("LOCATION specified more than once"\) CREATE EXTERNAL TABLE t STORED AS CSV LOCATION 'foo.csv' LOCATION 'bar.csv' -# Duplicate `WITH HEADER ROW` clause -statement error DataFusion error: SQL error: ParserError\("WITH HEADER ROW specified more than once"\) -CREATE EXTERNAL TABLE t STORED AS CSV WITH HEADER ROW WITH HEADER ROW LOCATION 'foo.csv' - -# Duplicate `DELIMITER` clause -statement error DataFusion error: SQL error: ParserError\("DELIMITER specified more than once"\) -CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV DELIMITER '-' DELIMITER '+' LOCATION 'foo.csv' - -# Duplicate `COMPRESSION TYPE` clause -statement error DataFusion error: SQL error: ParserError\("COMPRESSION TYPE specified more than once"\) -CREATE EXTERNAL TABLE t STORED AS CSV COMPRESSION TYPE BZIP2 COMPRESSION TYPE XZ COMPRESSION TYPE ZSTD COMPRESSION TYPE GZIP LOCATION 'foo.csv' - # Duplicate `PARTITIONED BY` clause statement error DataFusion error: SQL error: ParserError\("PARTITIONED BY specified more than once"\) create EXTERNAL TABLE t(c1 int, c2 int) STORED AS CSV PARTITIONED BY (c1) partitioned by (c2) LOCATION 'foo.csv' @@ -89,11 +69,11 @@ statement error DataFusion error: SQL error: ParserError\("OPTIONS specified mor CREATE EXTERNAL TABLE t STORED AS CSV OPTIONS ('k1' 'v1', 'k2' 'v2') OPTIONS ('k3' 'v3') LOCATION 'foo.csv' # With typo error -statement error DataFusion error: SQL error: ParserError\("Expected HEADER, found: HEAD"\) +statement error DataFusion error: SQL error: ParserError\("Expected: HEADER, found: HEAD"\) CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH HEAD ROW LOCATION 'foo.csv'; # Missing `anything` in WITH clause -statement error DataFusion error: SQL error: ParserError\("Expected HEADER, found: LOCATION"\) +statement error DataFusion error: SQL error: ParserError\("Expected: HEADER, found: LOCATION"\) CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH LOCATION 'foo.csv'; # Unrecognized random clause @@ -115,6 +95,13 @@ STORED AS CSV LOCATION 'foo.csv' OPTIONS ('format.delimiter' ';', 'format.column_index_truncate_length' '123') +# Creating Temporary tables +statement error DataFusion error: This feature is not implemented: Temporary tables not supported +CREATE TEMPORARY TABLE my_temp_table ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL +); + # Partitioned table on a single file query error DataFusion error: Error during planning: Can't create a partitioned table backed by a single file, perhaps the URL is missing a trailing slash\? CREATE EXTERNAL TABLE single_file_partition(c1 int) @@ -130,7 +117,7 @@ PARTITIONED BY (p1 string, p2 string) STORED AS parquet LOCATION 'test_files/scratch/create_external_table/bad_partitioning/'; -query ITT +query I INSERT INTO partitioned VALUES (1, 'x', 'y'); ---- 1 @@ -186,13 +173,13 @@ PARTITIONED BY (month string, year string) STORED AS parquet LOCATION 'test_files/scratch/create_external_table/manual_partitioning/'; -query TTT +query I -- creates year -> month partitions INSERT INTO test VALUES('name', '2024', '03'); ---- 1 -query TTT +query I -- creates month -> year partitions. -- now table have both partitions (year -> month and month -> year) INSERT INTO test2 VALUES('name', '2024', '03'); @@ -205,3 +192,94 @@ CREATE EXTERNAL TABLE test3(name string) PARTITIONED BY (month string, year string) STORED AS parquet LOCATION 'test_files/scratch/create_external_table/manual_partitioning/'; + +# Duplicate key assignment in OPTIONS clause +statement error DataFusion error: Error during planning: Option format.delimiter is specified multiple times +CREATE EXTERNAL TABLE t STORED AS CSV OPTIONS ( + 'format.delimiter' '*', + 'format.has_header' 'true', + 'format.delimiter' '|') +LOCATION 'foo.csv'; + +# If a config does not belong to any namespace, we assume it is a 'format' option and apply the 'format' prefix for backwards compatibility. +statement ok +CREATE EXTERNAL TABLE IF NOT EXISTS region ( + r_regionkey BIGINT, + r_name VARCHAR, + r_comment VARCHAR, + r_rev VARCHAR, +) STORED AS CSV LOCATION 'test_files/tpch/data/region.tbl' +OPTIONS ( + 'format.delimiter' '|', + 'has_header' 'false'); + +# Verify that we do not need quotations for simple namespaced keys. +statement ok +CREATE EXTERNAL TABLE IF NOT EXISTS region ( + r_regionkey BIGINT, + r_name VARCHAR, + r_comment VARCHAR, + r_rev VARCHAR, +) STORED AS CSV LOCATION 'test_files/tpch/data/region.tbl' +OPTIONS ( + format.delimiter '|', + has_header false, + compression gzip); + +# Create an external parquet table and infer schema to order by + +# query should succeed +statement ok +CREATE EXTERNAL TABLE t STORED AS parquet LOCATION '../../parquet-testing/data/alltypes_plain.parquet' WITH ORDER (id); + +## Verify that the table is created with a sort order. Explain should show output_ordering=[id@0 ASC] +query TT +EXPLAIN SELECT id FROM t ORDER BY id ASC; +---- +logical_plan +01)Sort: t.id ASC NULLS LAST +02)--TableScan: t projection=[id] +physical_plan ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id], output_ordering=[id@0 ASC NULLS LAST] + +## Test a DESC order and verify that output_ordering is ASC from the previous OBRDER BY +query TT +EXPLAIN SELECT id FROM t ORDER BY id DESC; +---- +logical_plan +01)Sort: t.id DESC NULLS FIRST +02)--TableScan: t projection=[id] +physical_plan +01)SortExec: expr=[id@0 DESC], preserve_partitioning=[false] +02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id], output_ordering=[id@0 ASC NULLS LAST] + +statement ok +DROP TABLE t; + +# Create table with non default sort order +statement ok +CREATE EXTERNAL TABLE t STORED AS parquet LOCATION '../../parquet-testing/data/alltypes_plain.parquet' WITH ORDER (id DESC NULLS FIRST); + +## Verify that the table is created with a sort order. Explain should show output_ordering=[id@0 DESC NULLS FIRST] +query TT +EXPLAIN SELECT id FROM t; +---- +logical_plan TableScan: t projection=[id] +physical_plan ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id], output_ordering=[id@0 DESC] + +statement ok +DROP TABLE t; + +# query should fail with bad column +statement error DataFusion error: Error during planning: Column foo is not in schema +CREATE EXTERNAL TABLE t STORED AS parquet LOCATION '../../parquet-testing/data/alltypes_plain.parquet' WITH ORDER (foo); + +# Create external table with qualified name should belong to the schema +statement ok +CREATE SCHEMA staging; + +statement ok +CREATE EXTERNAL TABLE staging.foo STORED AS parquet LOCATION '../../parquet-testing/data/alltypes_plain.parquet'; + +# Create external table with qualified name, but no schema should error +statement error DataFusion error: Error during planning: failed to resolve schema: release +CREATE EXTERNAL TABLE release.bar STORED AS parquet LOCATION '../../parquet-testing/data/alltypes_plain.parquet'; diff --git a/datafusion/sqllogictest/test_files/cse.slt b/datafusion/sqllogictest/test_files/cse.slt new file mode 100644 index 000000000000..c95e9a1309f8 --- /dev/null +++ b/datafusion/sqllogictest/test_files/cse.slt @@ -0,0 +1,233 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +CREATE TABLE IF NOT EXISTS t1(a DOUBLE, b DOUBLE) + +# Trivial common expression +query TT +EXPLAIN SELECT + a + 1 AS c1, + a + 1 AS c2 +FROM t1 +---- +logical_plan +01)Projection: __common_expr_1 AS c1, __common_expr_1 AS c2 +02)--Projection: t1.a + Float64(1) AS __common_expr_1 +03)----TableScan: t1 projection=[a] +physical_plan +01)ProjectionExec: expr=[__common_expr_1@0 as c1, __common_expr_1@0 as c2] +02)--ProjectionExec: expr=[a@0 + 1 as __common_expr_1] +03)----MemoryExec: partitions=1, partition_sizes=[0] + +# Common volatile expression +query TT +EXPLAIN SELECT + a + random() AS c1, + a + random() AS c2 +FROM t1 +---- +logical_plan +01)Projection: t1.a + random() AS c1, t1.a + random() AS c2 +02)--TableScan: t1 projection=[a] +physical_plan +01)ProjectionExec: expr=[a@0 + random() as c1, a@0 + random() as c2] +02)--MemoryExec: partitions=1, partition_sizes=[0] + +# Volatile expression with non-volatile common child +query TT +EXPLAIN SELECT + a + 1 + random() AS c1, + a + 1 + random() AS c2 +FROM t1 +---- +logical_plan +01)Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2 +02)--Projection: t1.a + Float64(1) AS __common_expr_1 +03)----TableScan: t1 projection=[a] +physical_plan +01)ProjectionExec: expr=[__common_expr_1@0 + random() as c1, __common_expr_1@0 + random() as c2] +02)--ProjectionExec: expr=[a@0 + 1 as __common_expr_1] +03)----MemoryExec: partitions=1, partition_sizes=[0] + +# Volatile expression with non-volatile common children +query TT +EXPLAIN SELECT + a + 1 + random() + (a + 2) AS c1, + a + 1 + random() + (a + 2) AS c2 +FROM t1 +---- +logical_plan +01)Projection: __common_expr_1 + random() + __common_expr_2 AS c1, __common_expr_1 + random() + __common_expr_2 AS c2 +02)--Projection: t1.a + Float64(1) AS __common_expr_1, t1.a + Float64(2) AS __common_expr_2 +03)----TableScan: t1 projection=[a] +physical_plan +01)ProjectionExec: expr=[__common_expr_1@0 + random() + __common_expr_2@1 as c1, __common_expr_1@0 + random() + __common_expr_2@1 as c2] +02)--ProjectionExec: expr=[a@0 + 1 as __common_expr_1, a@0 + 2 as __common_expr_2] +03)----MemoryExec: partitions=1, partition_sizes=[0] + +# Common short-circuit expression +query TT +EXPLAIN SELECT + a = 0 AND b = 0 AS c1, + a = 0 AND b = 0 AS c2, + a = 0 OR b = 0 AS c3, + a = 0 OR b = 0 AS c4, + CASE WHEN (a = 0) THEN 0 ELSE 1 END AS c5, + CASE WHEN (a = 0) THEN 0 ELSE 1 END AS c6 +FROM t1 +---- +logical_plan +01)Projection: __common_expr_1 AS c1, __common_expr_1 AS c2, __common_expr_2 AS c3, __common_expr_2 AS c4, __common_expr_3 AS c5, __common_expr_3 AS c6 +02)--Projection: __common_expr_4 AND t1.b = Float64(0) AS __common_expr_1, __common_expr_4 OR t1.b = Float64(0) AS __common_expr_2, CASE WHEN __common_expr_4 THEN Int64(0) ELSE Int64(1) END AS __common_expr_3 +03)----Projection: t1.a = Float64(0) AS __common_expr_4, t1.b +04)------TableScan: t1 projection=[a, b] +physical_plan +01)ProjectionExec: expr=[__common_expr_1@0 as c1, __common_expr_1@0 as c2, __common_expr_2@1 as c3, __common_expr_2@1 as c4, __common_expr_3@2 as c5, __common_expr_3@2 as c6] +02)--ProjectionExec: expr=[__common_expr_4@0 AND b@1 = 0 as __common_expr_1, __common_expr_4@0 OR b@1 = 0 as __common_expr_2, CASE WHEN __common_expr_4@0 THEN 0 ELSE 1 END as __common_expr_3] +03)----ProjectionExec: expr=[a@0 = 0 as __common_expr_4, b@1 as b] +04)------MemoryExec: partitions=1, partition_sizes=[0] + +# Common children of short-circuit expression +query TT +EXPLAIN SELECT + a = 0 AND b = 0 AS c1, + a = 0 AND b = 1 AS c2, + b = 2 AND a = 1 AS c3, + b = 3 AND a = 1 AS c4, + a = 2 OR b = 4 AS c5, + a = 2 OR b = 5 AS c6, + b = 6 OR a = 3 AS c7, + b = 7 OR a = 3 AS c8, + CASE WHEN (a = 4) THEN 0 ELSE 1 END AS c9, + CASE WHEN (a = 4) THEN 0 ELSE 2 END AS c10, + CASE WHEN (b = 8) THEN a + 1 ELSE 0 END AS c11, + CASE WHEN (b = 9) THEN a + 1 ELSE 0 END AS c12, + CASE WHEN (b = 10) THEN 0 ELSE a + 2 END AS c13, + CASE WHEN (b = 11) THEN 0 ELSE a + 2 END AS c14 +FROM t1 +---- +logical_plan +01)Projection: __common_expr_1 AND t1.b = Float64(0) AS c1, __common_expr_1 AND t1.b = Float64(1) AS c2, t1.b = Float64(2) AND t1.a = Float64(1) AS c3, t1.b = Float64(3) AND t1.a = Float64(1) AS c4, __common_expr_2 OR t1.b = Float64(4) AS c5, __common_expr_2 OR t1.b = Float64(5) AS c6, t1.b = Float64(6) OR t1.a = Float64(3) AS c7, t1.b = Float64(7) OR t1.a = Float64(3) AS c8, CASE WHEN __common_expr_3 THEN Int64(0) ELSE Int64(1) END AS c9, CASE WHEN __common_expr_3 THEN Int64(0) ELSE Int64(2) END AS c10, CASE WHEN t1.b = Float64(8) THEN t1.a + Float64(1) ELSE Float64(0) END AS c11, CASE WHEN t1.b = Float64(9) THEN t1.a + Float64(1) ELSE Float64(0) END AS c12, CASE WHEN t1.b = Float64(10) THEN Float64(0) ELSE t1.a + Float64(2) END AS c13, CASE WHEN t1.b = Float64(11) THEN Float64(0) ELSE t1.a + Float64(2) END AS c14 +02)--Projection: t1.a = Float64(0) AS __common_expr_1, t1.a = Float64(2) AS __common_expr_2, t1.a = Float64(4) AS __common_expr_3, t1.a, t1.b +03)----TableScan: t1 projection=[a, b] +physical_plan +01)ProjectionExec: expr=[__common_expr_1@0 AND b@4 = 0 as c1, __common_expr_1@0 AND b@4 = 1 as c2, b@4 = 2 AND a@3 = 1 as c3, b@4 = 3 AND a@3 = 1 as c4, __common_expr_2@1 OR b@4 = 4 as c5, __common_expr_2@1 OR b@4 = 5 as c6, b@4 = 6 OR a@3 = 3 as c7, b@4 = 7 OR a@3 = 3 as c8, CASE WHEN __common_expr_3@2 THEN 0 ELSE 1 END as c9, CASE WHEN __common_expr_3@2 THEN 0 ELSE 2 END as c10, CASE WHEN b@4 = 8 THEN a@3 + 1 ELSE 0 END as c11, CASE WHEN b@4 = 9 THEN a@3 + 1 ELSE 0 END as c12, CASE WHEN b@4 = 10 THEN 0 ELSE a@3 + 2 END as c13, CASE WHEN b@4 = 11 THEN 0 ELSE a@3 + 2 END as c14] +02)--ProjectionExec: expr=[a@0 = 0 as __common_expr_1, a@0 = 2 as __common_expr_2, a@0 = 4 as __common_expr_3, a@0 as a, b@1 as b] +03)----MemoryExec: partitions=1, partition_sizes=[0] + +# Common children of volatile, short-circuit expression +query TT +EXPLAIN SELECT + a = 0 AND b = random() AS c1, + a = 0 AND b = 1 + random() AS c2, + b = 2 + random() AND a = 1 AS c3, + b = 3 + random() AND a = 1 AS c4, + a = 2 OR b = 4 + random() AS c5, + a = 2 OR b = 5 + random() AS c6, + b = 6 + random() OR a = 3 AS c7, + b = 7 + random() OR a = 3 AS c8, + CASE WHEN (a = 4) THEN random() ELSE 1 END AS c9, + CASE WHEN (a = 4) THEN random() ELSE 2 END AS c10, + CASE WHEN (b = 8 + random()) THEN a + 1 ELSE 0 END AS c11, + CASE WHEN (b = 9 + random()) THEN a + 1 ELSE 0 END AS c12, + CASE WHEN (b = 10 + random()) THEN 0 ELSE a + 2 END AS c13, + CASE WHEN (b = 11 + random()) THEN 0 ELSE a + 2 END AS c14 +FROM t1 +---- +logical_plan +01)Projection: __common_expr_1 AND t1.b = random() AS c1, __common_expr_1 AND t1.b = Float64(1) + random() AS c2, t1.b = Float64(2) + random() AND t1.a = Float64(1) AS c3, t1.b = Float64(3) + random() AND t1.a = Float64(1) AS c4, __common_expr_2 OR t1.b = Float64(4) + random() AS c5, __common_expr_2 OR t1.b = Float64(5) + random() AS c6, t1.b = Float64(6) + random() OR t1.a = Float64(3) AS c7, t1.b = Float64(7) + random() OR t1.a = Float64(3) AS c8, CASE WHEN __common_expr_3 THEN random() ELSE Float64(1) END AS c9, CASE WHEN __common_expr_3 THEN random() ELSE Float64(2) END AS c10, CASE WHEN t1.b = Float64(8) + random() THEN t1.a + Float64(1) ELSE Float64(0) END AS c11, CASE WHEN t1.b = Float64(9) + random() THEN t1.a + Float64(1) ELSE Float64(0) END AS c12, CASE WHEN t1.b = Float64(10) + random() THEN Float64(0) ELSE t1.a + Float64(2) END AS c13, CASE WHEN t1.b = Float64(11) + random() THEN Float64(0) ELSE t1.a + Float64(2) END AS c14 +02)--Projection: t1.a = Float64(0) AS __common_expr_1, t1.a = Float64(2) AS __common_expr_2, t1.a = Float64(4) AS __common_expr_3, t1.a, t1.b +03)----TableScan: t1 projection=[a, b] +physical_plan +01)ProjectionExec: expr=[__common_expr_1@0 AND b@4 = random() as c1, __common_expr_1@0 AND b@4 = 1 + random() as c2, b@4 = 2 + random() AND a@3 = 1 as c3, b@4 = 3 + random() AND a@3 = 1 as c4, __common_expr_2@1 OR b@4 = 4 + random() as c5, __common_expr_2@1 OR b@4 = 5 + random() as c6, b@4 = 6 + random() OR a@3 = 3 as c7, b@4 = 7 + random() OR a@3 = 3 as c8, CASE WHEN __common_expr_3@2 THEN random() ELSE 1 END as c9, CASE WHEN __common_expr_3@2 THEN random() ELSE 2 END as c10, CASE WHEN b@4 = 8 + random() THEN a@3 + 1 ELSE 0 END as c11, CASE WHEN b@4 = 9 + random() THEN a@3 + 1 ELSE 0 END as c12, CASE WHEN b@4 = 10 + random() THEN 0 ELSE a@3 + 2 END as c13, CASE WHEN b@4 = 11 + random() THEN 0 ELSE a@3 + 2 END as c14] +02)--ProjectionExec: expr=[a@0 = 0 as __common_expr_1, a@0 = 2 as __common_expr_2, a@0 = 4 as __common_expr_3, a@0 as a, b@1 as b] +03)----MemoryExec: partitions=1, partition_sizes=[0] + +# Common volatile children of short-circuit expression +query TT +EXPLAIN SELECT + a = random() AND b = 0 AS c1, + a = random() AND b = 1 AS c2, + a = 2 + random() OR b = 4 AS c3, + a = 2 + random() OR b = 5 AS c4, + CASE WHEN (a = 4 + random()) THEN 0 ELSE 1 END AS c5, + CASE WHEN (a = 4 + random()) THEN 0 ELSE 2 END AS c6 +FROM t1 +---- +logical_plan +01)Projection: t1.a = random() AND t1.b = Float64(0) AS c1, t1.a = random() AND t1.b = Float64(1) AS c2, t1.a = Float64(2) + random() OR t1.b = Float64(4) AS c3, t1.a = Float64(2) + random() OR t1.b = Float64(5) AS c4, CASE WHEN t1.a = Float64(4) + random() THEN Int64(0) ELSE Int64(1) END AS c5, CASE WHEN t1.a = Float64(4) + random() THEN Int64(0) ELSE Int64(2) END AS c6 +02)--TableScan: t1 projection=[a, b] +physical_plan +01)ProjectionExec: expr=[a@0 = random() AND b@1 = 0 as c1, a@0 = random() AND b@1 = 1 as c2, a@0 = 2 + random() OR b@1 = 4 as c3, a@0 = 2 + random() OR b@1 = 5 as c4, CASE WHEN a@0 = 4 + random() THEN 0 ELSE 1 END as c5, CASE WHEN a@0 = 4 + random() THEN 0 ELSE 2 END as c6] +02)--MemoryExec: partitions=1, partition_sizes=[0] + +# Surely only once but also conditionally evaluated expressions +query TT +EXPLAIN SELECT + (a = 1 OR random() = 0) AND a = 2 AS c1, + (a = 2 AND random() = 0) OR a = 1 AS c2, + CASE WHEN a + 3 = 0 THEN a + 3 ELSE 0 END AS c3, + CASE WHEN a + 4 = 0 THEN 0 WHEN a + 4 THEN 0 ELSE 0 END AS c4, + CASE WHEN a + 5 = 0 THEN 0 WHEN random() = 0 THEN a + 5 ELSE 0 END AS c5, + CASE WHEN a + 6 = 0 THEN 0 ELSE a + 6 END AS c6 +FROM t1 +---- +logical_plan +01)Projection: (__common_expr_1 OR random() = Float64(0)) AND __common_expr_2 AS c1, __common_expr_2 AND random() = Float64(0) OR __common_expr_1 AS c2, CASE WHEN __common_expr_3 = Float64(0) THEN __common_expr_3 ELSE Float64(0) END AS c3, CASE WHEN __common_expr_4 = Float64(0) THEN Int64(0) WHEN CAST(__common_expr_4 AS Boolean) THEN Int64(0) ELSE Int64(0) END AS c4, CASE WHEN __common_expr_5 = Float64(0) THEN Float64(0) WHEN random() = Float64(0) THEN __common_expr_5 ELSE Float64(0) END AS c5, CASE WHEN __common_expr_6 = Float64(0) THEN Float64(0) ELSE __common_expr_6 END AS c6 +02)--Projection: t1.a = Float64(1) AS __common_expr_1, t1.a = Float64(2) AS __common_expr_2, t1.a + Float64(3) AS __common_expr_3, t1.a + Float64(4) AS __common_expr_4, t1.a + Float64(5) AS __common_expr_5, t1.a + Float64(6) AS __common_expr_6 +03)----TableScan: t1 projection=[a] +physical_plan +01)ProjectionExec: expr=[(__common_expr_1@0 OR random() = 0) AND __common_expr_2@1 as c1, __common_expr_2@1 AND random() = 0 OR __common_expr_1@0 as c2, CASE WHEN __common_expr_3@2 = 0 THEN __common_expr_3@2 ELSE 0 END as c3, CASE WHEN __common_expr_4@3 = 0 THEN 0 WHEN CAST(__common_expr_4@3 AS Boolean) THEN 0 ELSE 0 END as c4, CASE WHEN __common_expr_5@4 = 0 THEN 0 WHEN random() = 0 THEN __common_expr_5@4 ELSE 0 END as c5, CASE WHEN __common_expr_6@5 = 0 THEN 0 ELSE __common_expr_6@5 END as c6] +02)--ProjectionExec: expr=[a@0 = 1 as __common_expr_1, a@0 = 2 as __common_expr_2, a@0 + 3 as __common_expr_3, a@0 + 4 as __common_expr_4, a@0 + 5 as __common_expr_5, a@0 + 6 as __common_expr_6] +03)----MemoryExec: partitions=1, partition_sizes=[0] + +# Surely only once but also conditionally evaluated subexpressions +query TT +EXPLAIN SELECT + (a = 1 OR random() = 0) AND (a = 2 OR random() = 1) AS c1, + (a = 2 AND random() = 0) OR (a = 1 AND random() = 1) AS c2, + CASE WHEN a + 3 = 0 THEN a + 3 + random() ELSE 0 END AS c3, + CASE WHEN a + 4 = 0 THEN 0 ELSE a + 4 + random() END AS c4 +FROM t1 +---- +logical_plan +01)Projection: (__common_expr_1 OR random() = Float64(0)) AND (__common_expr_2 OR random() = Float64(1)) AS c1, __common_expr_2 AND random() = Float64(0) OR __common_expr_1 AND random() = Float64(1) AS c2, CASE WHEN __common_expr_3 = Float64(0) THEN __common_expr_3 + random() ELSE Float64(0) END AS c3, CASE WHEN __common_expr_4 = Float64(0) THEN Float64(0) ELSE __common_expr_4 + random() END AS c4 +02)--Projection: t1.a = Float64(1) AS __common_expr_1, t1.a = Float64(2) AS __common_expr_2, t1.a + Float64(3) AS __common_expr_3, t1.a + Float64(4) AS __common_expr_4 +03)----TableScan: t1 projection=[a] +physical_plan +01)ProjectionExec: expr=[(__common_expr_1@0 OR random() = 0) AND (__common_expr_2@1 OR random() = 1) as c1, __common_expr_2@1 AND random() = 0 OR __common_expr_1@0 AND random() = 1 as c2, CASE WHEN __common_expr_3@2 = 0 THEN __common_expr_3@2 + random() ELSE 0 END as c3, CASE WHEN __common_expr_4@3 = 0 THEN 0 ELSE __common_expr_4@3 + random() END as c4] +02)--ProjectionExec: expr=[a@0 = 1 as __common_expr_1, a@0 = 2 as __common_expr_2, a@0 + 3 as __common_expr_3, a@0 + 4 as __common_expr_4] +03)----MemoryExec: partitions=1, partition_sizes=[0] + +# Only conditionally evaluated expressions +query TT +EXPLAIN SELECT + (random() = 0 OR a = 1) AND a = 2 AS c1, + (random() = 0 AND a = 2) OR a = 1 AS c2, + CASE WHEN random() = 0 THEN a + 3 ELSE a + 3 END AS c3, + CASE WHEN random() = 0 THEN 0 WHEN a + 4 = 0 THEN a + 4 ELSE 0 END AS c4, + CASE WHEN random() = 0 THEN 0 WHEN a + 5 = 0 THEN 0 ELSE a + 5 END AS c5, + CASE WHEN random() = 0 THEN 0 WHEN random() = 0 THEN a + 6 ELSE a + 6 END AS c6 +FROM t1 +---- +logical_plan +01)Projection: (random() = Float64(0) OR t1.a = Float64(1)) AND t1.a = Float64(2) AS c1, random() = Float64(0) AND t1.a = Float64(2) OR t1.a = Float64(1) AS c2, CASE WHEN random() = Float64(0) THEN t1.a + Float64(3) ELSE t1.a + Float64(3) END AS c3, CASE WHEN random() = Float64(0) THEN Float64(0) WHEN t1.a + Float64(4) = Float64(0) THEN t1.a + Float64(4) ELSE Float64(0) END AS c4, CASE WHEN random() = Float64(0) THEN Float64(0) WHEN t1.a + Float64(5) = Float64(0) THEN Float64(0) ELSE t1.a + Float64(5) END AS c5, CASE WHEN random() = Float64(0) THEN Float64(0) WHEN random() = Float64(0) THEN t1.a + Float64(6) ELSE t1.a + Float64(6) END AS c6 +02)--TableScan: t1 projection=[a] +physical_plan +01)ProjectionExec: expr=[(random() = 0 OR a@0 = 1) AND a@0 = 2 as c1, random() = 0 AND a@0 = 2 OR a@0 = 1 as c2, CASE WHEN random() = 0 THEN a@0 + 3 ELSE a@0 + 3 END as c3, CASE WHEN random() = 0 THEN 0 WHEN a@0 + 4 = 0 THEN a@0 + 4 ELSE 0 END as c4, CASE WHEN random() = 0 THEN 0 WHEN a@0 + 5 = 0 THEN 0 ELSE a@0 + 5 END as c5, CASE WHEN random() = 0 THEN 0 WHEN random() = 0 THEN a@0 + 6 ELSE a@0 + 6 END as c6] +02)--MemoryExec: partitions=1, partition_sizes=[0] diff --git a/datafusion/sqllogictest/test_files/csv_files.slt b/datafusion/sqllogictest/test_files/csv_files.slt index f595cbe7f3b1..01d0f4ac39bd 100644 --- a/datafusion/sqllogictest/test_files/csv_files.slt +++ b/datafusion/sqllogictest/test_files/csv_files.slt @@ -21,19 +21,19 @@ CREATE EXTERNAL TABLE csv_with_quote ( c1 VARCHAR, c2 VARCHAR ) STORED AS CSV -WITH HEADER ROW -DELIMITER ',' -OPTIONS ('format.quote' '~') -LOCATION '../core/tests/data/quote.csv'; +LOCATION '../core/tests/data/quote.csv' +OPTIONS ('format.quote' '~', + 'format.delimiter' ',', + 'format.has_header' 'true'); statement ok CREATE EXTERNAL TABLE csv_with_escape ( c1 VARCHAR, c2 VARCHAR ) STORED AS CSV -WITH HEADER ROW -DELIMITER ',' -OPTIONS ('format.escape' '\') +OPTIONS ('format.escape' '\', + 'format.delimiter' ',', + 'format.has_header' 'true') LOCATION '../core/tests/data/escape.csv'; query TT @@ -50,6 +50,11 @@ id7 value7 id8 value8 id9 value9 +# Ensure that local files can not be read by default (a potential security issue) +# (url table is only supported when DynamicFileCatalog is enabled) +statement error DataFusion error: Error during planning: table 'datafusion.public.../core/tests/data/quote.csv' not found +select * from '../core/tests/data/quote.csv'; + query TT select * from csv_with_escape; ---- @@ -69,9 +74,9 @@ CREATE EXTERNAL TABLE csv_with_escape_2 ( c1 VARCHAR, c2 VARCHAR ) STORED AS CSV -WITH HEADER ROW -DELIMITER ',' -OPTIONS ('format.escape' '"') +OPTIONS ('format.escape' '"', + 'format.delimiter' ',', + 'format.has_header' 'true') LOCATION '../core/tests/data/escape.csv'; # TODO: Validate this with better data. @@ -115,16 +120,16 @@ CREATE TABLE src_table_2 ( (7, 'ggg', 700, 2), (8, 'hhh', 800, 2); -query ITII +query I COPY src_table_1 TO 'test_files/scratch/csv_files/csv_partitions/1.csv' -STORED AS CSV; +STORED AS CSV OPTIONS ('format.has_header' 'false'); ---- 4 -query ITII +query I COPY src_table_2 TO 'test_files/scratch/csv_files/csv_partitions/2.csv' -STORED AS CSV; +STORED AS CSV OPTIONS ('format.has_header' 'false'); ---- 4 @@ -136,8 +141,8 @@ CREATE EXTERNAL TABLE partitioned_table ( partition_col INT ) STORED AS CSV -WITH HEADER ROW -LOCATION 'test_files/scratch/csv_files/csv_partitions'; +LOCATION 'test_files/scratch/csv_files/csv_partitions' +OPTIONS ('format.has_header' 'false'); query ITII SELECT * FROM partitioned_table ORDER BY int_col; @@ -159,5 +164,201 @@ logical_plan 02)--TableScan: partitioned_table projection=[int_col, string_col, bigint_col, partition_col] physical_plan 01)SortPreservingMergeExec: [int_col@0 ASC NULLS LAST] -02)--SortExec: expr=[int_col@0 ASC NULLS LAST] -03)----CsvExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/csv_files/csv_partitions/1.csv], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/csv_files/csv_partitions/2.csv]]}, projection=[int_col, string_col, bigint_col, partition_col], has_header=true +02)--SortExec: expr=[int_col@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----CsvExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/csv_files/csv_partitions/1.csv], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/csv_files/csv_partitions/2.csv]]}, projection=[int_col, string_col, bigint_col, partition_col], has_header=false + + +# ensure that correct quote character is used when writing to csv +statement ok +CREATE TABLE table_with_necessary_quoting ( + int_col INT, + string_col TEXT +) AS VALUES +(1, 'e|e|e'), +(2, 'f|f|f'), +(3, 'g|g|g'), +(4, 'h|h|h'); + +# quote is required because `|` is delimiter and part of the data +query I +COPY table_with_necessary_quoting TO 'test_files/scratch/csv_files/table_with_necessary_quoting.csv' +STORED AS csv +OPTIONS ('format.quote' '~', + 'format.delimiter' '|', + 'format.has_header' 'true'); +---- +4 + +# read the stored csv file with quote character +statement ok +CREATE EXTERNAL TABLE stored_table_with_necessary_quoting ( +c1 VARCHAR, +c2 VARCHAR +) STORED AS CSV +LOCATION 'test_files/scratch/csv_files/table_with_necessary_quoting.csv' +OPTIONS ('format.quote' '~', + 'format.delimiter' '|', + 'format.has_header' 'true'); + +query TT +select * from stored_table_with_necessary_quoting; +---- +1 e|e|e +2 f|f|f +3 g|g|g +4 h|h|h + +# Read CSV file with comments +statement ok +COPY (VALUES + ('column1,column2'), + ('#second line is a comment'), + ('2,3')) +TO 'test_files/scratch/csv_files/file_with_comments.csv' +OPTIONS ('format.delimiter' '|', 'format.has_header' 'false'); + +statement ok +CREATE EXTERNAL TABLE stored_table_with_comments ( + c1 VARCHAR, + c2 VARCHAR +) STORED AS CSV +LOCATION 'test_files/scratch/csv_files/file_with_comments.csv' +OPTIONS ('format.comment' '#', + 'format.delimiter' ',', + 'format.has_header' 'false'); + +query TT +SELECT * from stored_table_with_comments; +---- +column1 column2 +2 3 + +# read csv with double quote +statement ok +CREATE EXTERNAL TABLE csv_with_double_quote ( +c1 VARCHAR, +c2 VARCHAR +) STORED AS CSV +OPTIONS ('format.delimiter' ',', + 'format.has_header' 'true', + 'format.double_quote' 'true') +LOCATION '../core/tests/data/double_quote.csv'; + +query TT +select * from csv_with_double_quote +---- +id0 "value0" +id1 "value1" +id2 "value2" +id3 "value3" + +# ensure that double quote option is used when writing to csv +query I +COPY csv_with_double_quote TO 'test_files/scratch/csv_files/table_with_double_quotes.csv' +STORED AS csv +OPTIONS ('format.double_quote' 'true'); +---- +4 + +statement ok +CREATE EXTERNAL TABLE stored_table_with_double_quotes ( +col1 TEXT, +col2 TEXT +) STORED AS CSV +LOCATION 'test_files/scratch/csv_files/table_with_double_quotes.csv' +OPTIONS ('format.double_quote' 'true'); + +query TT +select * from stored_table_with_double_quotes; +---- +id0 "value0" +id1 "value1" +id2 "value2" +id3 "value3" + +# ensure when double quote option is disabled that quotes are escaped instead +query I +COPY csv_with_double_quote TO 'test_files/scratch/csv_files/table_with_escaped_quotes.csv' +STORED AS csv +OPTIONS ('format.double_quote' 'false', 'format.escape' '#'); +---- +4 + +statement ok +CREATE EXTERNAL TABLE stored_table_with_escaped_quotes ( +col1 TEXT, +col2 TEXT +) STORED AS CSV +LOCATION 'test_files/scratch/csv_files/table_with_escaped_quotes.csv' +OPTIONS ('format.double_quote' 'false', 'format.escape' '#'); + +query TT +select * from stored_table_with_escaped_quotes; +---- +id0 "value0" +id1 "value1" +id2 "value2" +id3 "value3" + +# Handling of newlines in values + +statement ok +SET datafusion.optimizer.repartition_file_min_size = 1; + +statement ok +CREATE EXTERNAL TABLE stored_table_with_newlines_in_values_unsafe ( +col1 TEXT, +col2 TEXT +) STORED AS CSV +LOCATION '../core/tests/data/newlines_in_values.csv'; + +statement error incorrect number of fields +select * from stored_table_with_newlines_in_values_unsafe; + +statement ok +CREATE EXTERNAL TABLE stored_table_with_newlines_in_values_safe ( +col1 TEXT, +col2 TEXT +) STORED AS CSV +LOCATION '../core/tests/data/newlines_in_values.csv' +OPTIONS ('format.newlines_in_values' 'true', 'format.has_header' 'false'); + +query TT +select * from stored_table_with_newlines_in_values_safe; +---- +id message +1 +01)hello +02)world +2 +01)something +02)else +3 +01) +02)many +03)lines +04)make +05)good test +4 unquoted +value end + +statement ok +CREATE EXTERNAL TABLE stored_table_with_cr_terminator ( +col1 TEXT, +col2 TEXT +) STORED AS CSV +LOCATION '../core/tests/data/cr_terminator.csv' +OPTIONS ('format.terminator' E'\r', 'format.has_header' 'true'); + +# TODO: It should be passed but got the error: External error: query failed: DataFusion error: Object Store error: Generic LocalFileSystem error: Requested range was invalid +# See the issue: https://github.com/apache/datafusion/issues/12328 +# query TT +# select * from stored_table_with_cr_terminator; +# ---- +# id0 value0 +# id1 value1 +# id2 value2 +# id3 value3 + +statement ok +drop table stored_table_with_cr_terminator; diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt index 7d0f929962bd..53ca8d81b9e4 100644 --- a/datafusion/sqllogictest/test_files/cte.slt +++ b/datafusion/sqllogictest/test_files/cte.slt @@ -31,10 +31,9 @@ query TT EXPLAIN WITH "NUMBERS" AS (SELECT 1 as a, 2 as b, 3 as c) SELECT "NUMBERS".* FROM "NUMBERS" ---- logical_plan -01)Projection: NUMBERS.a, NUMBERS.b, NUMBERS.c -02)--SubqueryAlias: NUMBERS -03)----Projection: Int64(1) AS a, Int64(2) AS b, Int64(3) AS c -04)------EmptyRelation +01)SubqueryAlias: NUMBERS +02)--Projection: Int64(1) AS a, Int64(2) AS b, Int64(3) AS c +03)----EmptyRelation physical_plan 01)ProjectionExec: expr=[1 as a, 2 as b, 3 as c] 02)--PlaceholderRowExec @@ -105,14 +104,13 @@ EXPLAIN WITH RECURSIVE nodes AS ( SELECT * FROM nodes ---- logical_plan -01)Projection: nodes.id -02)--SubqueryAlias: nodes -03)----RecursiveQuery: is_distinct=false -04)------Projection: Int64(1) AS id -05)--------EmptyRelation -06)------Projection: nodes.id + Int64(1) AS id -07)--------Filter: nodes.id < Int64(10) -08)----------TableScan: nodes +01)SubqueryAlias: nodes +02)--RecursiveQuery: is_distinct=false +03)----Projection: Int64(1) AS id +04)------EmptyRelation +05)----Projection: nodes.id + Int64(1) AS id +06)------Filter: nodes.id < Int64(10) +07)--------TableScan: nodes physical_plan 01)RecursiveQueryExec: name=nodes, is_distinct=false 02)--ProjectionExec: expr=[1 as id] @@ -126,11 +124,11 @@ physical_plan # setup statement ok -CREATE EXTERNAL TABLE balance STORED as CSV WITH HEADER ROW LOCATION '../core/tests/data/recursive_cte/balance.csv' +CREATE EXTERNAL TABLE balance STORED as CSV LOCATION '../core/tests/data/recursive_cte/balance.csv' OPTIONS ('format.has_header' 'true'); # setup statement ok -CREATE EXTERNAL TABLE growth STORED as CSV WITH HEADER ROW LOCATION '../core/tests/data/recursive_cte/growth.csv' +CREATE EXTERNAL TABLE growth STORED as CSV LOCATION '../core/tests/data/recursive_cte/growth.csv' OPTIONS ('format.has_header' 'true'); # setup statement ok @@ -152,16 +150,15 @@ ORDER BY time, name, account_balance ---- logical_plan 01)Sort: balances.time ASC NULLS LAST, balances.name ASC NULLS LAST, balances.account_balance ASC NULLS LAST -02)--Projection: balances.time, balances.name, balances.account_balance -03)----SubqueryAlias: balances -04)------RecursiveQuery: is_distinct=false -05)--------Projection: balance.time, balance.name, balance.account_balance -06)----------TableScan: balance -07)--------Projection: balances.time + Int64(1) AS time, balances.name, balances.account_balance + Int64(10) AS account_balance -08)----------Filter: balances.time < Int64(10) -09)------------TableScan: balances +02)--SubqueryAlias: balances +03)----RecursiveQuery: is_distinct=false +04)------Projection: balance.time, balance.name, balance.account_balance +05)--------TableScan: balance +06)------Projection: balances.time + Int64(1) AS time, balances.name, balances.account_balance + Int64(10) AS account_balance +07)--------Filter: balances.time < Int64(10) +08)----------TableScan: balances physical_plan -01)SortExec: expr=[time@0 ASC NULLS LAST,name@1 ASC NULLS LAST,account_balance@2 ASC NULLS LAST] +01)SortExec: expr=[time@0 ASC NULLS LAST, name@1 ASC NULLS LAST, account_balance@2 ASC NULLS LAST], preserve_partitioning=[false] 02)--RecursiveQueryExec: name=balances, is_distinct=false 03)----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/recursive_cte/balance.csv]]}, projection=[time, name, account_balance], has_header=true 04)----CoalescePartitionsExec @@ -407,7 +404,7 @@ FROM # setup statement ok -CREATE EXTERNAL TABLE prices STORED as CSV WITH HEADER ROW LOCATION '../core/tests/data/recursive_cte/prices.csv' +CREATE EXTERNAL TABLE prices STORED as CSV LOCATION '../core/tests/data/recursive_cte/prices.csv' OPTIONS ('format.has_header' 'true'); # CTE within window function inside nested CTE works. This test demonstrates using a nested window function to recursively iterate over a column. query RRII @@ -598,11 +595,11 @@ ORDER BY # setup statement ok -CREATE EXTERNAL TABLE sales STORED as CSV WITH HEADER ROW LOCATION '../core/tests/data/recursive_cte/sales.csv' +CREATE EXTERNAL TABLE sales STORED as CSV LOCATION '../core/tests/data/recursive_cte/sales.csv' OPTIONS ('format.has_header' 'true'); # setup statement ok -CREATE EXTERNAL TABLE salespersons STORED as CSV WITH HEADER ROW LOCATION '../core/tests/data/recursive_cte/salespersons.csv' +CREATE EXTERNAL TABLE salespersons STORED as CSV LOCATION '../core/tests/data/recursive_cte/salespersons.csv' OPTIONS ('format.has_header' 'true'); # group by works within recursive cte. This test case demonstrates rolling up a hierarchy of salespeople to their managers. @@ -720,18 +717,17 @@ explain WITH RECURSIVE recursive_cte AS ( SELECT * FROM recursive_cte; ---- logical_plan -01)Projection: recursive_cte.val -02)--SubqueryAlias: recursive_cte -03)----RecursiveQuery: is_distinct=false -04)------Projection: Int64(1) AS val -05)--------EmptyRelation -06)------Projection: Int64(2) AS val -07)--------CrossJoin: -08)----------Filter: recursive_cte.val < Int64(2) -09)------------TableScan: recursive_cte -10)----------SubqueryAlias: sub_cte -11)------------Projection: Int64(2) AS val -12)--------------EmptyRelation +01)SubqueryAlias: recursive_cte +02)--RecursiveQuery: is_distinct=false +03)----Projection: Int64(1) AS val +04)------EmptyRelation +05)----Projection: Int64(2) AS val +06)------Cross Join: +07)--------Filter: recursive_cte.val < Int64(2) +08)----------TableScan: recursive_cte +09)--------SubqueryAlias: sub_cte +10)----------Projection: Int64(2) AS val +11)------------EmptyRelation physical_plan 01)RecursiveQueryExec: name=recursive_cte, is_distinct=false 02)--ProjectionExec: expr=[1 as val] @@ -832,3 +828,34 @@ SELECT * FROM non_recursive_cte, recursive_cte; ---- 1 1 1 3 + +# Name shadowing: +# The first `t` refers to the table, the second to the CTE. +query I +WITH t AS (SELECT * FROM t where t.a < 2) SELECT * FROM t +---- +1 + +# Issue: https://github.com/apache/datafusion/issues/10914 +# The CTE defined within the subquery is only visible inside that subquery. +query I rowsort +(WITH t AS (SELECT 400) SELECT * FROM t) UNION (SELECT * FROM t); +---- +1 +2 +3 +400 + +query error DataFusion error: Error during planning: table 'datafusion\.public\.cte' not found +(WITH cte AS (SELECT 400) SELECT * FROM cte) UNION (SELECT * FROM cte); + +# Test duplicate CTE names in different subqueries in the FROM clause. +query III rowsort +SELECT * FROM + (WITH t AS (select 400 as e) SELECT * FROM t) t1, + (WITH t AS (select 500 as e) SELECT * FROM t) t2, + t +---- +400 500 1 +400 500 2 +400 500 3 diff --git a/datafusion/sqllogictest/test_files/dates.slt b/datafusion/sqllogictest/test_files/dates.slt index 32c0bd14e7cc..4425eee33373 100644 --- a/datafusion/sqllogictest/test_files/dates.slt +++ b/datafusion/sqllogictest/test_files/dates.slt @@ -123,6 +123,7 @@ SELECT to_date(ts / 100000000) FROM to_date_t1 LIMIT 3 2003-11-02 2003-11-29 +# verify date with time zone, where the time zone date is already the next day, but result date in UTC is day before query D SELECT to_date('01-14-2023 01:01:30+05:30', '%q', '%d-%m-%Y %H/%M/%S', '%+', '%m-%d-%Y %H:%M:%S%#z'); ---- @@ -137,8 +138,15 @@ select to_date(arrow_cast(123, 'Int64')) ---- 1970-05-04 -statement error DataFusion error: Arrow error: +# Parse sequence of digits which yield a valid date, e.g. "21311111" would be "2131-11-11" +query D SELECT to_date('21311111'); +---- +2131-11-11 + +# Parse sequence of digits which do not make up a valid date +statement error DataFusion error: Arrow error: +SELECT to_date('213111111'); # verify date cast with integer input query DDDDDD @@ -186,6 +194,14 @@ create table ts_utf8_data(ts varchar(100), format varchar(100)) as values ('1926632005', '%s'), ('2000-01-01T01:01:01+07:00', '%+'); +statement ok +create table ts_largeutf8_data as +select arrow_cast(ts, 'LargeUtf8') as ts, arrow_cast(format, 'LargeUtf8') as format from ts_utf8_data; + +statement ok +create table ts_utf8view_data as +select arrow_cast(ts, 'Utf8View') as ts, arrow_cast(format, 'Utf8View') as format from ts_utf8_data; + # verify date data using tables with formatting options query D SELECT to_date(t.ts, t.format) from ts_utf8_data as t @@ -196,6 +212,24 @@ SELECT to_date(t.ts, t.format) from ts_utf8_data as t 2031-01-19 1999-12-31 +query D +SELECT to_date(t.ts, t.format) from ts_largeutf8_data as t +---- +2020-09-08 +2031-01-19 +2020-09-08 +2031-01-19 +1999-12-31 + +query D +SELECT to_date(t.ts, t.format) from ts_utf8view_data as t +---- +2020-09-08 +2031-01-19 +2020-09-08 +2031-01-19 +1999-12-31 + # verify date data using tables with formatting options query D SELECT to_date(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z') from ts_utf8_data as t @@ -206,6 +240,24 @@ SELECT to_date(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z') 2031-01-19 1999-12-31 +query D +SELECT to_date(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z') from ts_largeutf8_data as t +---- +2020-09-08 +2031-01-19 +2020-09-08 +2031-01-19 +1999-12-31 + +query D +SELECT to_date(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z') from ts_utf8view_data as t +---- +2020-09-08 +2031-01-19 +2020-09-08 +2031-01-19 +1999-12-31 + # verify date data using tables with formatting options where at least one column cannot be parsed query error Error parsing timestamp from '1926632005' using format '%d-%m-%Y %H:%M:%S%#z': input contains invalid characters SELECT to_date(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%d-%m-%Y %H:%M:%S%#z') from ts_utf8_data as t @@ -220,9 +272,69 @@ SELECT to_date(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%s', '%q', '%d-%m-%Y %H:%M:%S%#z', 2031-01-19 1999-12-31 +query D +SELECT to_date(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%s', '%q', '%d-%m-%Y %H:%M:%S%#z', '%+') from ts_largeutf8_data as t +---- +2020-09-08 +2031-01-19 +2020-09-08 +2031-01-19 +1999-12-31 + +query D +SELECT to_date(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%s', '%q', '%d-%m-%Y %H:%M:%S%#z', '%+') from ts_utf8view_data as t +---- +2020-09-08 +2031-01-19 +2020-09-08 +2031-01-19 +1999-12-31 + # timestamp data using tables with formatting options in an array is not supported at this time query error function unsupported data type at index 1: SELECT to_date(t.ts, make_array('%Y-%m-%d %H/%M/%S%#z', '%s', '%q', '%d-%m-%Y %H:%M:%S%#z', '%+')) from ts_utf8_data as t +# verify to_date with format +query D +select to_date('2022-01-23', '%Y-%m-%d'); +---- +2022-01-23 + +query PPPP +select + date_trunc('YEAR', to_date('2022-02-23', '%Y-%m-%d')), + date_trunc('MONTH', to_date('2022-02-23', '%Y-%m-%d')), + date_trunc('DAY', to_date('2022-02-23', '%Y-%m-%d')), + date_trunc('WEEK', to_date('2022-02-23', '%Y-%m-%d')); +---- +2022-01-01T00:00:00 2022-02-01T00:00:00 2022-02-23T00:00:00 2022-02-21T00:00:00 + +query PPPPP +select + date_trunc('HOUR', to_date('2022-02-23', '%Y-%m-%d')), + date_trunc('MINUTE', to_date('2022-02-23', '%Y-%m-%d')), + date_trunc('SECOND', to_date('2022-02-23', '%Y-%m-%d')), + date_trunc('MILLISECOND', to_date('2022-02-23', '%Y-%m-%d')), + date_trunc('MICROSECOND', to_date('2022-02-23', '%Y-%m-%d')); +---- +2022-02-23T00:00:00 2022-02-23T00:00:00 2022-02-23T00:00:00 2022-02-23T00:00:00 2022-02-23T00:00:00 + +query PPPP +select + date_trunc('YEAR', d2_date), + date_trunc('MONTH', d2_date), + date_trunc('DAY', d2_date), + date_trunc('WEEK', d2_date) +FROM test +---- +2022-01-01T00:00:00 2022-12-01T00:00:00 2022-12-12T00:00:00 2022-12-12T00:00:00 +2022-01-01T00:00:00 2022-12-01T00:00:00 2022-12-11T00:00:00 2022-12-05T00:00:00 +2022-01-01T00:00:00 2022-12-01T00:00:00 2022-12-10T00:00:00 2022-12-05T00:00:00 +2022-01-01T00:00:00 2022-12-01T00:00:00 2022-12-09T00:00:00 2022-12-05T00:00:00 +2022-01-01T00:00:00 2022-12-01T00:00:00 2022-12-08T00:00:00 2022-12-05T00:00:00 +2022-01-01T00:00:00 2022-12-01T00:00:00 2022-12-07T00:00:00 2022-12-05T00:00:00 +2022-01-01T00:00:00 2022-12-01T00:00:00 2022-12-06T00:00:00 2022-12-05T00:00:00 +2022-01-01T00:00:00 2022-12-01T00:00:00 2022-12-05T00:00:00 2022-12-05T00:00:00 + statement ok drop table ts_utf8_data diff --git a/datafusion/sqllogictest/test_files/ddl.slt b/datafusion/sqllogictest/test_files/ddl.slt index 682972b5572a..4a0ba87bfa1a 100644 --- a/datafusion/sqllogictest/test_files/ddl.slt +++ b/datafusion/sqllogictest/test_files/ddl.slt @@ -256,7 +256,7 @@ DROP VIEW non_existent_view ########## statement ok -CREATE external table aggregate_simple(c1 real, c2 double, c3 boolean) STORED as CSV WITH HEADER ROW LOCATION '../core/tests/data/aggregate_simple.csv'; +CREATE external table aggregate_simple(c1 real, c2 double, c3 boolean) STORED as CSV LOCATION '../core/tests/data/aggregate_simple.csv' OPTIONS ('format.has_header' 'true'); # create_table_as statement ok @@ -455,8 +455,8 @@ CREATE EXTERNAL TABLE aggregate_test_100 ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW -LOCATION '../../testing/data/csv/aggregate_test_100.csv'; +LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); query TIIIIIIIIIRRT SELECT c1, c2, c3, c4, c5, c6, c7, c8, c9, 10, c11, c12, c13 FROM aggregate_test_100 LIMIT 1; @@ -470,7 +470,9 @@ statement ok CREATE EXTERNAL TABLE csv_with_timestamps ( name VARCHAR, ts TIMESTAMP -) STORED AS CSV LOCATION '../core/tests/data/timestamps.csv'; +) STORED AS CSV +LOCATION '../core/tests/data/timestamps.csv' +OPTIONS('format.has_header' 'false'); query TP SELECT * from csv_with_timestamps @@ -496,7 +498,8 @@ CREATE EXTERNAL TABLE csv_with_timestamps ( ) STORED AS CSV PARTITIONED BY (c_date) -LOCATION '../core/tests/data/partitioned_table'; +LOCATION '../core/tests/data/partitioned_table' +OPTIONS('format.has_header' 'false'); query TPD SELECT * from csv_with_timestamps where c_date='2018-11-13' @@ -535,7 +538,7 @@ DROP VIEW y; # create_pipe_delimited_csv_table() statement ok -CREATE EXTERNAL TABLE aggregate_simple STORED AS CSV WITH HEADER ROW DELIMITER '|' LOCATION '../core/tests/data/aggregate_simple_pipe.csv'; +CREATE EXTERNAL TABLE aggregate_simple STORED AS CSV LOCATION '../core/tests/data/aggregate_simple_pipe.csv' OPTIONS ('format.delimiter' '|', 'format.has_header' 'true'); query RRB @@ -581,14 +584,14 @@ statement ok CREATE TABLE IF NOT EXISTS table_without_values(field1 BIGINT, field2 BIGINT); statement ok -CREATE EXTERNAL TABLE aggregate_simple STORED AS CSV WITH HEADER ROW LOCATION '../core/tests/data/aggregate_simple.csv' +CREATE EXTERNAL TABLE aggregate_simple STORED AS CSV LOCATION '../core/tests/data/aggregate_simple.csv' OPTIONS ('format.has_header' 'true'); # Should not recreate the same EXTERNAL table statement error Execution error: Table 'aggregate_simple' already exists -CREATE EXTERNAL TABLE aggregate_simple STORED AS CSV WITH HEADER ROW LOCATION '../core/tests/data/aggregate_simple.csv' +CREATE EXTERNAL TABLE aggregate_simple STORED AS CSV LOCATION '../core/tests/data/aggregate_simple.csv' OPTIONS ('format.has_header' 'true'); statement ok -CREATE EXTERNAL TABLE IF NOT EXISTS aggregate_simple STORED AS CSV WITH HEADER ROW LOCATION '../core/tests/data/aggregate_simple.csv' +CREATE EXTERNAL TABLE IF NOT EXISTS aggregate_simple STORED AS CSV LOCATION '../core/tests/data/aggregate_simple.csv' OPTIONS ('format.has_header' 'true'); # create bad custom table statement error DataFusion error: Execution error: Unable to find factory for DELTATABLE @@ -690,7 +693,7 @@ drop table foo; # create csv table with empty csv file statement ok -CREATE EXTERNAL TABLE empty STORED AS CSV WITH HEADER ROW LOCATION '../core/tests/data/empty.csv'; +CREATE EXTERNAL TABLE empty STORED AS CSV LOCATION '../core/tests/data/empty.csv' OPTIONS ('format.has_header' 'true'); query TTI select column_name, data_type, ordinal_position from information_schema.columns where table_name='empty';; @@ -707,7 +710,7 @@ create table t (i interval, x int) as values (interval '5 days 3 nanoseconds', C query ?I select * from t; ---- -0 years 0 mons 5 days 0 hours 0 mins 0.000000003 secs 1 +5 days 0.000000003 secs 1 statement ok drop table t; @@ -742,8 +745,8 @@ DROP SCHEMA empty_schema; statement ok CREATE UNBOUNDED external table t(c1 integer, c2 integer, c3 integer) STORED as CSV -WITH HEADER ROW -LOCATION '../core/tests/data/empty.csv'; +LOCATION '../core/tests/data/empty.csv' +OPTIONS ('format.has_header' 'true'); # should see infinite_source=true in the explain query TT @@ -760,8 +763,8 @@ drop table t; statement ok CREATE external table t(c1 integer, c2 integer, c3 integer) STORED as CSV -WITH HEADER ROW -LOCATION '../core/tests/data/empty.csv'; +LOCATION '../core/tests/data/empty.csv' +OPTIONS ('format.has_header' 'true'); # expect to see no infinite_source in the explain query TT @@ -772,3 +775,33 @@ physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/te statement ok drop table t; + +statement ok +set datafusion.explain.logical_plan_only=true; + +query TT +explain CREATE TEMPORARY VIEW z AS VALUES (1,2,3); +---- +logical_plan +01)CreateView: Bare { table: "z" } +02)--Values: (Int64(1), Int64(2), Int64(3)) + +query TT +explain CREATE EXTERNAL TEMPORARY TABLE tty STORED as ARROW LOCATION '../core/tests/data/example.arrow'; +---- +logical_plan CreateExternalTable: Bare { table: "tty" } + +statement ok +set datafusion.explain.logical_plan_only=false; + +statement error DataFusion error: This feature is not implemented: Temporary tables not supported +CREATE EXTERNAL TEMPORARY TABLE tty STORED as ARROW LOCATION '../core/tests/data/example.arrow'; + +statement error DataFusion error: This feature is not implemented: Temporary views not supported +CREATE TEMPORARY VIEW y AS VALUES (1,2,3); + +query error DataFusion error: Schema error: No field named a\. +EXPLAIN CREATE TABLE t(a int) AS VALUES (a + a); + +statement error DataFusion error: Schema error: No field named a\. +CREATE TABLE t(a int) AS SELECT x FROM (VALUES (a)) t(x) WHERE false; diff --git a/datafusion/sqllogictest/test_files/decimal.slt b/datafusion/sqllogictest/test_files/decimal.slt index 3f75e42d9304..8db28c32f13b 100644 --- a/datafusion/sqllogictest/test_files/decimal.slt +++ b/datafusion/sqllogictest/test_files/decimal.slt @@ -44,8 +44,8 @@ c4 BOOLEAN NOT NULL, c5 DECIMAL(12,7) NOT NULL ) STORED AS CSV -WITH HEADER ROW -LOCATION '../core/tests/data/decimal_data.csv'; +LOCATION '../core/tests/data/decimal_data.csv' +OPTIONS ('format.has_header' 'true'); query TT @@ -639,8 +639,8 @@ c4 BOOLEAN NOT NULL, c5 DECIMAL(52,7) NOT NULL ) STORED AS CSV -WITH HEADER ROW -LOCATION '../core/tests/data/decimal_data.csv'; +LOCATION '../core/tests/data/decimal_data.csv' +OPTIONS ('format.has_header' 'true'); query TT select arrow_typeof(c1), arrow_typeof(c5) from decimal256_simple limit 1; diff --git a/datafusion/sqllogictest/test_files/describe.slt b/datafusion/sqllogictest/test_files/describe.slt index f94a2e453884..077e8e6474d1 100644 --- a/datafusion/sqllogictest/test_files/describe.slt +++ b/datafusion/sqllogictest/test_files/describe.slt @@ -24,7 +24,7 @@ statement ok set datafusion.catalog.information_schema = true statement ok -CREATE external table aggregate_simple(c1 real, c2 double, c3 boolean) STORED as CSV WITH HEADER ROW LOCATION '../core/tests/data/aggregate_simple.csv'; +CREATE external table aggregate_simple(c1 real, c2 double, c3 boolean) STORED as CSV LOCATION '../core/tests/data/aggregate_simple.csv' OPTIONS ('format.has_header' 'true'); query TTT rowsort DESCRIBE aggregate_simple; @@ -44,7 +44,7 @@ statement ok set datafusion.catalog.information_schema = false statement ok -CREATE external table aggregate_simple(c1 real, c2 double, c3 boolean) STORED as CSV WITH HEADER ROW LOCATION '../core/tests/data/aggregate_simple.csv'; +CREATE external table aggregate_simple(c1 real, c2 double, c3 boolean) STORED as CSV LOCATION '../core/tests/data/aggregate_simple.csv' OPTIONS ('format.has_header' 'true'); query TTT rowsort DESCRIBE aggregate_simple; @@ -57,7 +57,7 @@ statement ok DROP TABLE aggregate_simple; ########## -# Describe file (currently we can only describe file in datafusion-cli, fix this after issue (#4850) has been done) +# Describe file (we can only describe file if the default catalog is `DynamicFileCatalog`) ########## statement error Error during planning: table 'datafusion.public.../core/tests/data/aggregate_simple.csv' not found diff --git a/datafusion/sqllogictest/test_files/dictionary.slt b/datafusion/sqllogictest/test_files/dictionary.slt index 06b726502664..b6923fcc944d 100644 --- a/datafusion/sqllogictest/test_files/dictionary.slt +++ b/datafusion/sqllogictest/test_files/dictionary.slt @@ -62,7 +62,7 @@ FROM ( ('1000', 32, 'foo', 'True', 10.0, 1703035800000000000) ); -query ?RTTRP +query TRTTRP SELECT * FROM m1; ---- 1000 32 foo True 1 2023-12-20T00:00:00 @@ -137,7 +137,7 @@ FROM ( ('passive', '1000', 1000, 1701653400000000000) ); -query ??RP +query TTRP SELECT * FROM m2; ---- active 1000 100 2023-12-04T00:00:00 @@ -208,7 +208,7 @@ true false NULL true true false true NULL # Reproducer for https://github.com/apache/datafusion/issues/8738 # This query should work correctly -query P?TT rowsort +query PTTT rowsort SELECT "data"."timestamp" as "time", "data"."tag_id", @@ -264,7 +264,7 @@ ORDER BY # deterministic sort (so we can avoid rowsort) -query P?TT +query PTTT SELECT "data"."timestamp" as "time", "data"."tag_id", @@ -348,7 +348,7 @@ create table m3 as from m3_source; # there are two values in column2 -query T?I rowsort +query TTI rowsort SELECT * FROM m3; ---- @@ -386,3 +386,67 @@ drop table m3; statement ok drop table m3_source; + + +## Test that filtering on dictionary columns coerces the filter value to the dictionary type +statement ok +create table test as values + ('row1', arrow_cast('1', 'Dictionary(Int32, Utf8)')), + ('row2', arrow_cast('2', 'Dictionary(Int32, Utf8)')), + ('row3', arrow_cast('3', 'Dictionary(Int32, Utf8)')) +; + +# query using an string '1' which must be coerced into a dictionary string +query TT +SELECT * from test where column2 = '1'; +---- +row1 1 + +# filter should not have a cast on column2 +query TT +explain SELECT * from test where column2 = '1'; +---- +logical_plan +01)Filter: test.column2 = Dictionary(Int32, Utf8("1")) +02)--TableScan: test projection=[column1, column2] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: column2@1 = 1 +03)----MemoryExec: partitions=1, partition_sizes=[1] + +# try literal = col to verify order doesn't matter +# filter should not cast column2 +query TT +explain SELECT * from test where '1' = column2 +---- +logical_plan +01)Filter: test.column2 = Dictionary(Int32, Utf8("1")) +02)--TableScan: test projection=[column1, column2] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: column2@1 = 1 +03)----MemoryExec: partitions=1, partition_sizes=[1] + + +# Now query using an integer which must be coerced into a dictionary string +query TT +SELECT * from test where column2 = 1; +---- +row1 1 + +query TT +explain SELECT * from test where column2 = 1; +---- +logical_plan +01)Filter: test.column2 = Dictionary(Int32, Utf8("1")) +02)--TableScan: test projection=[column1, column2] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: column2@1 = 1 +03)----MemoryExec: partitions=1, partition_sizes=[1] + +# Window Functions +query I +select dense_rank() over (order by arrow_cast('abc', 'Dictionary(UInt16, Utf8)')); +---- +1 diff --git a/datafusion/sqllogictest/test_files/distinct_on.slt b/datafusion/sqllogictest/test_files/distinct_on.slt index 972c935cee99..604ac95ff476 100644 --- a/datafusion/sqllogictest/test_files/distinct_on.slt +++ b/datafusion/sqllogictest/test_files/distinct_on.slt @@ -32,8 +32,8 @@ CREATE EXTERNAL TABLE aggregate_test_100 ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); # Basic example: distinct on the first column project the second one, and # order by the third @@ -89,18 +89,18 @@ query TT EXPLAIN SELECT DISTINCT ON (c1) c3, c2 FROM aggregate_test_100 ORDER BY c1, c3; ---- logical_plan -01)Projection: FIRST_VALUE(aggregate_test_100.c3) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST] AS c3, FIRST_VALUE(aggregate_test_100.c2) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST] AS c2 +01)Projection: first_value(aggregate_test_100.c3) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST] AS c3, first_value(aggregate_test_100.c2) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST] AS c2 02)--Sort: aggregate_test_100.c1 ASC NULLS LAST -03)----Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[FIRST_VALUE(aggregate_test_100.c3) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST], FIRST_VALUE(aggregate_test_100.c2) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST]]] +03)----Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[first_value(aggregate_test_100.c3) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST], first_value(aggregate_test_100.c2) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST]]] 04)------TableScan: aggregate_test_100 projection=[c1, c2, c3] physical_plan -01)ProjectionExec: expr=[FIRST_VALUE(aggregate_test_100.c3) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST]@1 as c3, FIRST_VALUE(aggregate_test_100.c2) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST]@2 as c2] +01)ProjectionExec: expr=[first_value(aggregate_test_100.c3) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST]@1 as c3, first_value(aggregate_test_100.c2) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST]@2 as c2] 02)--SortPreservingMergeExec: [c1@0 ASC NULLS LAST] -03)----SortExec: expr=[c1@0 ASC NULLS LAST] -04)------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[FIRST_VALUE(aggregate_test_100.c3) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST], FIRST_VALUE(aggregate_test_100.c2) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST]] +03)----SortExec: expr=[c1@0 ASC NULLS LAST], preserve_partitioning=[true] +04)------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[first_value(aggregate_test_100.c3) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST], first_value(aggregate_test_100.c2) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST]] 05)--------CoalesceBatchesExec: target_batch_size=8192 06)----------RepartitionExec: partitioning=Hash([c1@0], 4), input_partitions=4 -07)------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[FIRST_VALUE(aggregate_test_100.c3) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST], FIRST_VALUE(aggregate_test_100.c2) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST]] +07)------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[first_value(aggregate_test_100.c3) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST], first_value(aggregate_test_100.c2) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c3 ASC NULLS LAST]] 08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 09)----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3], has_header=true @@ -143,3 +143,52 @@ LIMIT 3; -25 15295 45 15673 -72 -11122 + +# use wildcard +query TIIIIIIIITRRT +SELECT DISTINCT ON (c1) * FROM aggregate_test_100 ORDER BY c1 LIMIT 3; +---- +a 1 -85 -15154 1171968280 1919439543497968449 77 52286 774637006 12101411955859039553 0.12285209 0.686439196277 0keZ5G8BffGwgF2RwQD59TFzMStxCB +b 1 29 -18218 994303988 5983957848665088916 204 9489 3275293996 14857091259186476033 0.53840446 0.179090351188 AyYVExXK6AR2qUTxNZ7qRHQOVGMLcz +c 2 1 18109 2033001162 -6513304855495910254 25 43062 1491205016 5863949479783605708 0.110830784 0.929409733247 6WfVFBVGJSQb7FhA7E0lBwdvjfZnSW + +# can't distinct on * +query error DataFusion error: SQL error: ParserError\("Expected: an expression:, found: \*"\) +SELECT DISTINCT ON (*) c1 FROM aggregate_test_100 ORDER BY c1 LIMIT 3; + + +# test distinct on +statement ok +create table t(a int, b int, c int) as values (1, 2, 3); + +statement ok +set datafusion.explain.logical_plan_only = true; + +query TT +explain select distinct on (a) b from t order by a desc, c; +---- +logical_plan +01)Projection: first_value(t.b) ORDER BY [t.a DESC NULLS FIRST, t.c ASC NULLS LAST] AS b +02)--Sort: t.a DESC NULLS FIRST +03)----Aggregate: groupBy=[[t.a]], aggr=[[first_value(t.b) ORDER BY [t.a DESC NULLS FIRST, t.c ASC NULLS LAST]]] +04)------TableScan: t projection=[a, b, c] + +statement ok +drop table t; + +# test distinct +statement ok +create table t(a int, b int) as values (1, 2); + +statement ok +set datafusion.explain.logical_plan_only = true; + +query TT +explain select distinct a, b from t; +---- +logical_plan +01)Aggregate: groupBy=[[t.a, t.b]], aggr=[[]] +02)--TableScan: t projection=[a, b] + +statement ok +drop table t; diff --git a/datafusion/sqllogictest/test_files/dynamic_file.slt b/datafusion/sqllogictest/test_files/dynamic_file.slt new file mode 100644 index 000000000000..69f9a43ad407 --- /dev/null +++ b/datafusion/sqllogictest/test_files/dynamic_file.slt @@ -0,0 +1,267 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# +# Note: This file runs with a SessionContext that has the `enable_url_table` flag set +# +# dynamic select arrow file in the folder +query ITB +SELECT * FROM '../core/tests/data/partitioned_table_arrow/part=123' ORDER BY f0; +---- +1 foo true +2 bar false + +# Read partitioned file +statement ok +CREATE TABLE src_table_1 ( + int_col INT, + string_col TEXT, + bigint_col BIGINT, + partition_col INT +) AS VALUES +(1, 'aaa', 100, 1), +(2, 'bbb', 200, 1), +(3, 'ccc', 300, 1), +(4, 'ddd', 400, 1); + +statement ok +CREATE TABLE src_table_2 ( + int_col INT, + string_col TEXT, + bigint_col BIGINT, + partition_col INT +) AS VALUES +(5, 'eee', 500, 2), +(6, 'fff', 600, 2), +(7, 'ggg', 700, 2), +(8, 'hhh', 800, 2); + +# Read partitioned csv file + +query I +COPY src_table_1 TO 'test_files/scratch/dynamic_file/csv_partitions' +STORED AS CSV +PARTITIONED BY (partition_col); +---- +4 + +query I +COPY src_table_2 TO 'test_files/scratch/dynamic_file/csv_partitions' +STORED AS CSV +PARTITIONED BY (partition_col); +---- +4 + +query ITIT rowsort +SELECT int_col, string_col, bigint_col, partition_col FROM 'test_files/scratch/dynamic_file/csv_partitions'; +---- +1 aaa 100 1 +2 bbb 200 1 +3 ccc 300 1 +4 ddd 400 1 +5 eee 500 2 +6 fff 600 2 +7 ggg 700 2 +8 hhh 800 2 + +# Read partitioned json file + +query I +COPY src_table_1 TO 'test_files/scratch/dynamic_file/json_partitions' +STORED AS JSON +PARTITIONED BY (partition_col); +---- +4 + +query I +COPY src_table_2 TO 'test_files/scratch/dynamic_file/json_partitions' +STORED AS JSON +PARTITIONED BY (partition_col); +---- +4 + +query ITIT rowsort +SELECT int_col, string_col, bigint_col, partition_col FROM 'test_files/scratch/dynamic_file/json_partitions'; +---- +1 aaa 100 1 +2 bbb 200 1 +3 ccc 300 1 +4 ddd 400 1 +5 eee 500 2 +6 fff 600 2 +7 ggg 700 2 +8 hhh 800 2 + +# Read partitioned arrow file + +query I +COPY src_table_1 TO 'test_files/scratch/dynamic_file/arrow_partitions' +STORED AS ARROW +PARTITIONED BY (partition_col); +---- +4 + +query I +COPY src_table_2 TO 'test_files/scratch/dynamic_file/arrow_partitions' +STORED AS ARROW +PARTITIONED BY (partition_col); +---- +4 + +query ITIT rowsort +SELECT int_col, string_col, bigint_col, partition_col FROM 'test_files/scratch/dynamic_file/arrow_partitions'; +---- +1 aaa 100 1 +2 bbb 200 1 +3 ccc 300 1 +4 ddd 400 1 +5 eee 500 2 +6 fff 600 2 +7 ggg 700 2 +8 hhh 800 2 + +# Read partitioned parquet file + +query I +COPY src_table_1 TO 'test_files/scratch/dynamic_file/parquet_partitions' +STORED AS PARQUET +PARTITIONED BY (partition_col); +---- +4 + +query I +COPY src_table_2 TO 'test_files/scratch/dynamic_file/parquet_partitions' +STORED AS PARQUET +PARTITIONED BY (partition_col); +---- +4 + +query ITIT rowsort +select * from 'test_files/scratch/dynamic_file/parquet_partitions'; +---- +1 aaa 100 1 +2 bbb 200 1 +3 ccc 300 1 +4 ddd 400 1 +5 eee 500 2 +6 fff 600 2 +7 ggg 700 2 +8 hhh 800 2 + +# Read partitioned parquet file with multiple partition columns + +query I +COPY src_table_1 TO 'test_files/scratch/dynamic_file/nested_partition' +STORED AS PARQUET +PARTITIONED BY (partition_col, string_col); +---- +4 + +query I +COPY src_table_2 TO 'test_files/scratch/dynamic_file/nested_partition' +STORED AS PARQUET +PARTITIONED BY (partition_col, string_col); +---- +4 + +query IITT rowsort +select * from 'test_files/scratch/dynamic_file/nested_partition'; +---- +1 100 1 aaa +2 200 1 bbb +3 300 1 ccc +4 400 1 ddd +5 500 2 eee +6 600 2 fff +7 700 2 ggg +8 800 2 hhh + +# read avro file +query IT +SELECT id, CAST(string_col AS varchar) FROM '../../testing/data/avro/alltypes_plain.avro' +---- +4 0 +5 1 +6 0 +7 1 +2 0 +3 1 +0 0 +1 1 + +# dynamic query snappy avro file +query IT +SELECT id, CAST(string_col AS varchar) FROM '../../testing/data/avro/alltypes_plain.snappy.avro' +---- +4 0 +5 1 +6 0 +7 1 +2 0 +3 1 +0 0 +1 1 + +# query the csv file dynamically with the config of current session +query TT +select * from '../core/tests/data/quote.csv'; +---- +~id0~ ~value0~ +~id1~ ~value1~ +~id2~ ~value2~ +~id3~ ~value3~ +~id4~ ~value4~ +~id5~ ~value5~ +~id6~ ~value6~ +~id7~ ~value7~ +~id8~ ~value8~ +~id9~ ~value9~ + +query TTT +DESCRIBE '../core/tests/data/aggregate_simple.csv'; +---- +c1 Float64 YES +c2 Float64 YES +c3 Boolean YES + +query IR rowsort +SELECT a, b FROM '../core/tests/data/2.json' +---- +-10 -3.5 +1 -3.5 +1 0.6 +1 0.6 +1 2 +1 2 +1 2 +1 2 +100000000000000 0.6 +2 0.6 +5 -3.5 +7 -3.5 + +query IT +SELECT id, CAST(string_col AS varchar) FROM '../../parquet-testing/data/alltypes_plain.parquet'; +---- +4 0 +5 1 +6 0 +7 1 +2 0 +3 1 +0 0 +1 1 diff --git a/datafusion/sqllogictest/test_files/encoding.slt b/datafusion/sqllogictest/test_files/encoding.slt index 9f4f508e23f3..68bdf78115aa 100644 --- a/datafusion/sqllogictest/test_files/encoding.slt +++ b/datafusion/sqllogictest/test_files/encoding.slt @@ -20,7 +20,7 @@ CREATE TABLE test( num INT, bin_field BYTEA, base64_field TEXT, - hex_field TEXT, + hex_field TEXT ) as VALUES (0, 'abc', encode('abc', 'base64'), encode('abc', 'hex')), (1, 'qweqwe', encode('qweqwe', 'base64'), encode('qweqwe', 'hex')), @@ -28,19 +28,19 @@ CREATE TABLE test( ; # errors -query error DataFusion error: Error during planning: The encode function can only accept utf8 or binary\. -select encode(12, 'hex') +query error 1st argument should be Utf8 or Binary or Null, got Int64 +select encode(12, 'hex'); query error DataFusion error: Error during planning: There is no built\-in encoding named 'non_encoding', currently supported encodings are: base64, hex select encode(bin_field, 'non_encoding') from test; -query error DataFusion error: Error during planning: The decode function can only accept utf8 or binary\. -select decode(12, 'hex') +query error 1st argument should be Utf8 or Binary or Null, got Int64 +select decode(12, 'hex'); query error DataFusion error: Error during planning: There is no built\-in encoding named 'non_encoding', currently supported encodings are: base64, hex select decode(hex_field, 'non_encoding') from test; -query error DataFusion error: Error during planning: No function matches the given name and argument types 'to_hex\(Utf8\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tto_hex\(Int64\) +query error select to_hex(hex_field) from test; # Arrays tests diff --git a/datafusion/sqllogictest/test_files/errors.slt b/datafusion/sqllogictest/test_files/errors.slt index ab281eac31f5..da46a7e5e679 100644 --- a/datafusion/sqllogictest/test_files/errors.slt +++ b/datafusion/sqllogictest/test_files/errors.slt @@ -34,11 +34,11 @@ CREATE EXTERNAL TABLE aggregate_test_100 ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); # csv_query_error -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'sin\(Utf8\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tsin\(Float64/Float32\) +statement error SELECT sin(c1) FROM aggregate_test_100 # cast_expressions_error @@ -46,7 +46,7 @@ statement error DataFusion error: Arrow error: Cast error: Cannot cast string 'c SELECT CAST(c1 AS INT) FROM aggregate_test_100 # aggregation_with_bad_arguments -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'COUNT\(\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tCOUNT\(Any, .., Any\) +query error SELECT COUNT(DISTINCT) FROM aggregate_test_100 # query_cte_incorrect @@ -80,23 +80,23 @@ SELECT COUNT(*) FROM way.too.many.namespaces.as.ident.prefixes.aggregate_test_10 # # error message for wrong function signature (Variadic: arbitrary number of args all from some common types) -statement error Error during planning: No function matches the given name and argument types 'concat\(\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tconcat\(Utf8, ..\) +statement error SELECT concat(); # error message for wrong function signature (Uniform: t args all from some common types) -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'nullif\(Int64\)'. You might need to add explicit type casts. +statement error SELECT nullif(1); # error message for wrong function signature (Exact: exact number of args of an exact type) -statement error Error during planning: No function matches the given name and argument types 'pi\(Float64\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tpi\(\) +statement error SELECT pi(3.14); # error message for wrong function signature (Any: fixed number of args of arbitrary types) -statement error Error during planning: No function matches the given name and argument types 'arrow_typeof\(Int64, Int64\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tarrow_typeof\(Any\) +statement error SELECT arrow_typeof(1, 1); # error message for wrong function signature (OneOf: fixed number of args of arbitrary types) -statement error Error during planning: No function matches the given name and argument types 'power\(Int64, Int64, Int64\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tpower\(Int64, Int64\)\n\tpower\(Float64, Float64\) +statement error SELECT power(1, 2, 3); # @@ -104,19 +104,15 @@ SELECT power(1, 2, 3); # # AggregateFunction with wrong number of arguments -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'COUNT\(\)'\. You might need to add explicit type casts\.\n\tCandidate functions:\n\tCOUNT\(Any, \.\., Any\) -select count(); - -# AggregateFunction with wrong number of arguments -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'AVG\(Utf8, Float64\)'\. You might need to add explicit type casts\.\n\tCandidate functions:\n\tAVG\(Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64\) +query error select avg(c1, c12) from aggregate_test_100; # AggregateFunction with wrong argument type -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'REGR_SLOPE\(Int64, Utf8\)'\. You might need to add explicit type casts\.\n\tCandidate functions:\n\tREGR_SLOPE\(Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64, Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64\) +statement error Coercion select regr_slope(1, '2'); # WindowFunction using AggregateFunction wrong signature -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'REGR_SLOPE\(Float32, Utf8\)'\. You might need to add explicit type casts\.\n\tCandidate functions:\n\tREGR_SLOPE\(Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64, Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64\) +statement error Coercion select c9, regr_slope(c11, '2') over () as min1 @@ -132,5 +128,12 @@ from aggregate_test_100 order by c9 -statement error Inconsistent data type across values list at row 1 column 0. Was Int64 but found Utf8 +query error DataFusion error: Arrow error: Cast error: Cannot cast string 'foo' to value of Int64 type create table foo as values (1), ('foo'); + +query error No function matches +select 1 group by substr(''); + +# Error in filter should be reported +query error Divide by zero +SELECT c2 from aggregate_test_100 where CASE WHEN true THEN 1 / 0 ELSE 0 END = 1; diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index d8c8dcd41b6a..1340fd490e06 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -32,8 +32,8 @@ CREATE EXTERNAL TABLE aggregate_test_100 ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW -LOCATION '../../testing/data/csv/aggregate_test_100.csv'; +LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); query TT explain SELECT c1 FROM aggregate_test_100 where c2 > 10 @@ -43,11 +43,10 @@ logical_plan 02)--Filter: aggregate_test_100.c2 > Int8(10) 03)----TableScan: aggregate_test_100 projection=[c1, c2], partial_filters=[aggregate_test_100.c2 > Int8(10)] physical_plan -01)ProjectionExec: expr=[c1@0 as c1] -02)--CoalesceBatchesExec: target_batch_size=8192 -03)----FilterExec: c2@1 > 10 -04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2], has_header=true +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: c2@1 > 10, projection=[c1@0] +03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2], has_header=true # explain_csv_exec_scan_config @@ -68,21 +67,17 @@ CREATE EXTERNAL TABLE aggregate_test_100_with_order ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW WITH ORDER (c1 ASC) -LOCATION '../core/tests/data/aggregate_test_100_order_by_c1_asc.csv'; +LOCATION '../core/tests/data/aggregate_test_100_order_by_c1_asc.csv' +OPTIONS ('format.has_header' 'true'); query TT explain SELECT c1 FROM aggregate_test_100_with_order order by c1 ASC limit 10 ---- logical_plan -01)Limit: skip=0, fetch=10 -02)--Sort: aggregate_test_100_with_order.c1 ASC NULLS LAST, fetch=10 -03)----TableScan: aggregate_test_100_with_order projection=[c1] -physical_plan -01)GlobalLimitExec: skip=0, fetch=10 -02)--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/aggregate_test_100_order_by_c1_asc.csv]]}, projection=[c1], output_ordering=[c1@0 ASC NULLS LAST], has_header=true - +01)Sort: aggregate_test_100_with_order.c1 ASC NULLS LAST, fetch=10 +02)--TableScan: aggregate_test_100_with_order projection=[c1] +physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/aggregate_test_100_order_by_c1_asc.csv]]}, projection=[c1], limit=10, output_ordering=[c1@0 ASC NULLS LAST], has_header=true ## explain_physical_plan_only @@ -93,7 +88,7 @@ query TT EXPLAIN select count(*) from (values ('a', 1, 100), ('a', 2, 150)) as t (c1,c2,c3) ---- physical_plan -01)ProjectionExec: expr=[2 as COUNT(*)] +01)ProjectionExec: expr=[2 as count(*)] 02)--PlaceholderRowExec statement ok @@ -128,8 +123,8 @@ CREATE EXTERNAL TABLE simple_explain_test ( c INT ) STORED AS CSV -WITH HEADER ROW LOCATION '../core/tests/data/example.csv' +OPTIONS ('format.has_header' 'true'); query TT EXPLAIN SELECT a, b, c FROM simple_explain_test @@ -156,8 +151,8 @@ CREATE UNBOUNDED EXTERNAL TABLE sink_table ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW -LOCATION '../../testing/data/csv/aggregate_test_100.csv'; +LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); query TT EXPLAIN INSERT INTO sink_table SELECT * FROM aggregate_test_100 ORDER by c1 @@ -169,7 +164,7 @@ logical_plan 04)------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13] physical_plan 01)DataSinkExec: sink=StreamWrite { location: "../../testing/data/csv/aggregate_test_100.csv", batch_size: 8192, encoding: Csv, header: true, .. } -02)--SortExec: expr=[c1@0 ASC NULLS LAST] +02)--SortExec: expr=[c1@0 ASC NULLS LAST], preserve_partitioning=[false] 03)----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], has_header=true # test EXPLAIN VERBOSE @@ -179,8 +174,9 @@ EXPLAIN VERBOSE SELECT a, b, c FROM simple_explain_test initial_logical_plan 01)Projection: simple_explain_test.a, simple_explain_test.b, simple_explain_test.c 02)--TableScan: simple_explain_test -logical_plan after apply_function_rewrites SAME TEXT AS ABOVE logical_plan after inline_table_scan SAME TEXT AS ABOVE +logical_plan after expand_wildcard_rule SAME TEXT AS ABOVE +logical_plan after resolve_grouping_function SAME TEXT AS ABOVE logical_plan after type_coercion SAME TEXT AS ABOVE logical_plan after count_wildcard_rule SAME TEXT AS ABOVE analyzed_logical_plan SAME TEXT AS ABOVE @@ -192,8 +188,6 @@ logical_plan after eliminate_join SAME TEXT AS ABOVE logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE -logical_plan after simplify_expressions SAME TEXT AS ABOVE -logical_plan after rewrite_disjunctive_predicate SAME TEXT AS ABOVE logical_plan after eliminate_duplicated_expr SAME TEXT AS ABOVE logical_plan after eliminate_filter SAME TEXT AS ABOVE logical_plan after eliminate_cross_join SAME TEXT AS ABOVE @@ -209,6 +203,7 @@ logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE +logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE logical_plan after optimize_projections TableScan: simple_explain_test projection=[a, b, c] logical_plan after eliminate_nested_union SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE @@ -218,8 +213,6 @@ logical_plan after eliminate_join SAME TEXT AS ABOVE logical_plan after decorrelate_predicate_subquery SAME TEXT AS ABOVE logical_plan after scalar_subquery_to_join SAME TEXT AS ABOVE logical_plan after extract_equijoin_predicate SAME TEXT AS ABOVE -logical_plan after simplify_expressions SAME TEXT AS ABOVE -logical_plan after rewrite_disjunctive_predicate SAME TEXT AS ABOVE logical_plan after eliminate_duplicated_expr SAME TEXT AS ABOVE logical_plan after eliminate_filter SAME TEXT AS ABOVE logical_plan after eliminate_cross_join SAME TEXT AS ABOVE @@ -235,10 +228,12 @@ logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after simplify_expressions SAME TEXT AS ABOVE logical_plan after unwrap_cast_in_comparison SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE +logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE logical_plan after optimize_projections SAME TEXT AS ABOVE logical_plan TableScan: simple_explain_test projection=[a, b, c] initial_physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true initial_physical_plan_with_stats CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] +initial_physical_plan_with_schema CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true, schema=[a:Int32;N, b:Int32;N, c:Int32;N] physical_plan after OutputRequirements 01)OutputRequirementExec 02)--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true @@ -252,11 +247,29 @@ physical_plan after OptimizeAggregateOrder SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE physical_plan after coalesce_batches SAME TEXT AS ABOVE physical_plan after OutputRequirements CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true -physical_plan after PipelineChecker SAME TEXT AS ABOVE physical_plan after LimitAggregation SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE +physical_plan after LimitPushdown SAME TEXT AS ABOVE +physical_plan after SanityCheckPlan SAME TEXT AS ABOVE physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true physical_plan_with_stats CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] +physical_plan_with_schema CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true, schema=[a:Int32;N, b:Int32;N, c:Int32;N] + +### tests for EXPLAIN with display schema enabled + +statement ok +set datafusion.explain.show_schema = true; + +# test EXPLAIN VERBOSE +query TT +EXPLAIN SELECT a, b, c FROM simple_explain_test; +---- +logical_plan TableScan: simple_explain_test projection=[a, b, c] +physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], has_header=true, schema=[a:Int32;N, b:Int32;N, c:Int32;N] + + +statement ok +set datafusion.explain.show_schema = false; ### tests for EXPLAIN with display statistics enabled @@ -270,9 +283,7 @@ set datafusion.explain.physical_plan_only = true; query TT EXPLAIN SELECT a, b, c FROM simple_explain_test limit 10; ---- -physical_plan -01)GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Inexact(10), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] -02)--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], limit=10, has_header=true, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] +physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], limit=10, has_header=true, statistics=[Rows=Absent, Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:)]] # Parquet scan with statistics collected statement ok @@ -284,9 +295,7 @@ CREATE EXTERNAL TABLE alltypes_plain STORED AS PARQUET LOCATION '../../parquet-t query TT EXPLAIN SELECT * FROM alltypes_plain limit 10; ---- -physical_plan -01)GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] -02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +physical_plan ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] # explain verbose with both collect & show statistics on query TT @@ -295,6 +304,9 @@ EXPLAIN VERBOSE SELECT * FROM alltypes_plain limit 10; initial_physical_plan 01)GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] 02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +initial_physical_plan_with_schema +01)GlobalLimitExec: skip=0, fetch=10, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:Binary;N, string_col:Binary;N, timestamp_col:Timestamp(Nanosecond, None);N] +02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:Binary;N, string_col:Binary;N, timestamp_col:Timestamp(Nanosecond, None);N] physical_plan after OutputRequirements 01)OutputRequirementExec, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] 02)--GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] @@ -311,12 +323,12 @@ physical_plan after coalesce_batches SAME TEXT AS ABOVE physical_plan after OutputRequirements 01)GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] 02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] -physical_plan after PipelineChecker SAME TEXT AS ABOVE physical_plan after LimitAggregation SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE -physical_plan -01)GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] -02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +physical_plan after LimitPushdown ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +physical_plan after SanityCheckPlan SAME TEXT AS ABOVE +physical_plan ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +physical_plan_with_schema ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:Binary;N, string_col:Binary;N, timestamp_col:Timestamp(Nanosecond, None);N] statement ok @@ -332,6 +344,9 @@ initial_physical_plan initial_physical_plan_with_stats 01)GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] 02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +initial_physical_plan_with_schema +01)GlobalLimitExec: skip=0, fetch=10, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:Binary;N, string_col:Binary;N, timestamp_col:Timestamp(Nanosecond, None);N] +02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:Binary;N, string_col:Binary;N, timestamp_col:Timestamp(Nanosecond, None);N] physical_plan after OutputRequirements 01)OutputRequirementExec 02)--GlobalLimitExec: skip=0, fetch=10 @@ -348,21 +363,19 @@ physical_plan after coalesce_batches SAME TEXT AS ABOVE physical_plan after OutputRequirements 01)GlobalLimitExec: skip=0, fetch=10 02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 -physical_plan after PipelineChecker SAME TEXT AS ABOVE physical_plan after LimitAggregation SAME TEXT AS ABOVE physical_plan after ProjectionPushdown SAME TEXT AS ABOVE -physical_plan -01)GlobalLimitExec: skip=0, fetch=10 -02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 -physical_plan_with_stats -01)GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] -02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +physical_plan after LimitPushdown ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 +physical_plan after SanityCheckPlan SAME TEXT AS ABOVE +physical_plan ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10 +physical_plan_with_stats ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, statistics=[Rows=Exact(8), Bytes=Absent, [(Col[0]:),(Col[1]:),(Col[2]:),(Col[3]:),(Col[4]:),(Col[5]:),(Col[6]:),(Col[7]:),(Col[8]:),(Col[9]:),(Col[10]:)]] +physical_plan_with_schema ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/parquet-testing/data/alltypes_plain.parquet]]}, projection=[id, bool_col, tinyint_col, smallint_col, int_col, bigint_col, float_col, double_col, date_string_col, string_col, timestamp_col], limit=10, schema=[id:Int32;N, bool_col:Boolean;N, tinyint_col:Int32;N, smallint_col:Int32;N, int_col:Int32;N, bigint_col:Int64;N, float_col:Float32;N, double_col:Float64;N, date_string_col:Binary;N, string_col:Binary;N, timestamp_col:Timestamp(Nanosecond, None);N] statement ok set datafusion.execution.collect_statistics = false; -# Explain ArrayFuncions +# Explain ArrayFunctions statement ok set datafusion.explain.physical_plan_only = false diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index adc577f12f91..182afff7a693 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -22,7 +22,7 @@ SELECT true, false, false = false, true = false true false true false # test_mathematical_expressions_with_null -query RRRRRRRRRRRRRRRRRR?RRRRRRRIRRRRRRBB +query RRRRRRRRRRRRRRRRRR?RRRRRIIIRRRRRRBB SELECT sqrt(NULL), cbrt(NULL), @@ -63,7 +63,7 @@ SELECT NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL # test_array_cast_invalid_timezone_will_panic -statement error Parser error: Invalid timezone "Foo": 'Foo' is not a valid timezone +statement error Parser error: Invalid timezone "Foo": failed to parse timezone SELECT arrow_cast('2021-01-02T03:04:00', 'Timestamp(Nanosecond, Some("Foo"))') # test_array_index @@ -122,203 +122,197 @@ SELECT query ? SELECT interval '1' ---- -0 years 0 mons 0 days 0 hours 0 mins 1.000000000 secs +1.000000000 secs query ? SELECT interval '1 second' ---- -0 years 0 mons 0 days 0 hours 0 mins 1.000000000 secs +1.000000000 secs query ? SELECT interval '500 milliseconds' ---- -0 years 0 mons 0 days 0 hours 0 mins 0.500000000 secs +0.500000000 secs query ? SELECT interval '5 second' ---- -0 years 0 mons 0 days 0 hours 0 mins 5.000000000 secs +5.000000000 secs query ? SELECT interval '0.5 minute' ---- -0 years 0 mons 0 days 0 hours 0 mins 30.000000000 secs +30.000000000 secs query ? SELECT interval '.5 minute' ---- -0 years 0 mons 0 days 0 hours 0 mins 30.000000000 secs +30.000000000 secs query ? SELECT interval '5 minute' ---- -0 years 0 mons 0 days 0 hours 5 mins 0.000000000 secs +5 mins query ? SELECT interval '5 minute 1 second' ---- -0 years 0 mons 0 days 0 hours 5 mins 1.000000000 secs +5 mins 1.000000000 secs query ? SELECT interval '1 hour' ---- -0 years 0 mons 0 days 1 hours 0 mins 0.000000000 secs +1 hours query ? SELECT interval '5 hour' ---- -0 years 0 mons 0 days 5 hours 0 mins 0.000000000 secs +5 hours query ? SELECT interval '1 day' ---- -0 years 0 mons 1 days 0 hours 0 mins 0.000000000 secs +1 days query ? SELECT interval '1 week' ---- -0 years 0 mons 7 days 0 hours 0 mins 0.000000000 secs +7 days query ? SELECT interval '2 weeks' ---- -0 years 0 mons 14 days 0 hours 0 mins 0.000000000 secs +14 days query ? SELECT interval '1 day 1' ---- -0 years 0 mons 1 days 0 hours 0 mins 1.000000000 secs +1 days 1.000000000 secs query ? SELECT interval '0.5' ---- -0 years 0 mons 0 days 0 hours 0 mins 0.500000000 secs +0.500000000 secs query ? SELECT interval '0.5 day 1' ---- -0 years 0 mons 0 days 12 hours 0 mins 1.000000000 secs +12 hours 1.000000000 secs query ? SELECT interval '0.49 day' ---- -0 years 0 mons 0 days 11 hours 45 mins 36.000000000 secs +11 hours 45 mins 36.000000000 secs query ? SELECT interval '0.499 day' ---- -0 years 0 mons 0 days 11 hours 58 mins 33.600000000 secs +11 hours 58 mins 33.600000000 secs query ? SELECT interval '0.4999 day' ---- -0 years 0 mons 0 days 11 hours 59 mins 51.360000000 secs +11 hours 59 mins 51.360000000 secs query ? SELECT interval '0.49999 day' ---- -0 years 0 mons 0 days 11 hours 59 mins 59.136000000 secs +11 hours 59 mins 59.136000000 secs query ? SELECT interval '0.49999999999 day' ---- -0 years 0 mons 0 days 11 hours 59 mins 59.999999136 secs +11 hours 59 mins 59.999999136 secs query ? SELECT interval '5 day' ---- -0 years 0 mons 5 days 0 hours 0 mins 0.000000000 secs - -# Hour is ignored, this matches PostgreSQL -query ? -SELECT interval '5 day' hour ----- -0 years 0 mons 5 days 0 hours 0 mins 0.000000000 secs +5 days query ? SELECT interval '5 day 4 hours 3 minutes 2 seconds 100 milliseconds' ---- -0 years 0 mons 5 days 4 hours 3 mins 2.100000000 secs +5 days 4 hours 3 mins 2.100000000 secs query ? SELECT interval '0.5 month' ---- -0 years 0 mons 15 days 0 hours 0 mins 0.000000000 secs +15 days query ? SELECT interval '0.5' month ---- -0 years 0 mons 15 days 0 hours 0 mins 0.000000000 secs +15 days query ? SELECT interval '1 month' ---- -0 years 1 mons 0 days 0 hours 0 mins 0.000000000 secs +1 mons query ? SELECT interval '1' MONTH ---- -0 years 1 mons 0 days 0 hours 0 mins 0.000000000 secs +1 mons query ? SELECT interval '5 month' ---- -0 years 5 mons 0 days 0 hours 0 mins 0.000000000 secs +5 mons query ? SELECT interval '13 month' ---- -0 years 13 mons 0 days 0 hours 0 mins 0.000000000 secs +13 mons query ? SELECT interval '0.5 year' ---- -0 years 6 mons 0 days 0 hours 0 mins 0.000000000 secs +6 mons query ? SELECT interval '1 year' ---- -0 years 12 mons 0 days 0 hours 0 mins 0.000000000 secs +12 mons query ? SELECT interval '1 decade' ---- -0 years 120 mons 0 days 0 hours 0 mins 0.000000000 secs +120 mons query ? SELECT interval '2 decades' ---- -0 years 240 mons 0 days 0 hours 0 mins 0.000000000 secs +240 mons query ? SELECT interval '1 century' ---- -0 years 1200 mons 0 days 0 hours 0 mins 0.000000000 secs +1200 mons query ? SELECT interval '2 year' ---- -0 years 24 mons 0 days 0 hours 0 mins 0.000000000 secs +24 mons query ? SELECT interval '1 year 1 day' ---- -0 years 12 mons 1 days 0 hours 0 mins 0.000000000 secs +12 mons 1 days query ? SELECT interval '1 year 1 day 1 hour' ---- -0 years 12 mons 1 days 1 hours 0 mins 0.000000000 secs +12 mons 1 days 1 hours query ? SELECT interval '1 year 1 day 1 hour 1 minute' ---- -0 years 12 mons 1 days 1 hours 1 mins 0.000000000 secs +12 mons 1 days 1 hours 1 mins query ? SELECT interval '1 year 1 day 1 hour 1 minute 1 second' ---- -0 years 12 mons 1 days 1 hours 1 mins 1.000000000 secs +12 mons 1 days 1 hours 1 mins 1.000000000 secs query I SELECT ascii('') @@ -365,7 +359,7 @@ SELECT bit_length('josé') ---- 40 -query ? +query I SELECT bit_length(NULL) ---- NULL @@ -395,7 +389,7 @@ SELECT btrim('\nxyxtrimyyx\n', 'xyz\n') ---- trim -query ? +query T SELECT btrim(NULL, 'xyz') ---- NULL @@ -476,7 +470,7 @@ SELECT initcap('hi THOMAS') ---- Hi Thomas -query ? +query T SELECT initcap(NULL) ---- NULL @@ -491,7 +485,7 @@ SELECT lower('TOM') ---- tom -query ? +query T SELECT lower(NULL) ---- NULL @@ -511,7 +505,7 @@ SELECT ltrim('zzzytest', 'xyz') ---- test -query ? +query T SELECT ltrim(NULL, 'xyz') ---- NULL @@ -531,7 +525,7 @@ SELECT octet_length('josé') ---- 5 -query ? +query I SELECT octet_length(NULL) ---- NULL @@ -541,12 +535,17 @@ SELECT repeat('Pg', 4) ---- PgPgPgPg +query T +SELECT repeat('Pg', -1) +---- +(empty) + query T SELECT repeat('Pg', CAST(NULL AS INT)) ---- NULL -query ? +query T SELECT repeat(NULL, 4) ---- NULL @@ -571,7 +570,7 @@ SELECT replace('abcdefabcdef', NULL, 'XX') ---- NULL -query ? +query T SELECT replace(NULL, 'cd', 'XX') ---- NULL @@ -591,7 +590,7 @@ SELECT rtrim('testxxzx', 'xyz') ---- test -query ? +query T SELECT rtrim(NULL, 'xyz') ---- NULL @@ -606,7 +605,7 @@ SELECT split_part('abc~@~def~@~ghi', '~@~', 20) ---- (empty) -query ? +query T SELECT split_part(NULL, '~@~', 20) ---- NULL @@ -621,6 +620,19 @@ SELECT split_part('abc~@~def~@~ghi', '~@~', CAST(NULL AS INT)) ---- NULL +query T +SELECT split_part('abc~@~def~@~ghi', '~@~', -1) +---- +ghi + +query T +SELECT split_part('abc~@~def~@~ghi', '~@~', -100) +---- +(empty) + +statement error DataFusion error: Execution error: field position must not be zero +SELECT split_part('abc~@~def~@~ghi', '~@~', 0) + query B SELECT starts_with('alphabet', 'alph') ---- @@ -770,7 +782,7 @@ SELECT upper('tom') ---- TOM -query ? +query T SELECT upper(NULL) ---- NULL @@ -809,6 +821,12 @@ SELECT # test_extract_date_part +query error +SELECT EXTRACT("'''year'''" FROM timestamp '2020-09-08T12:00:00+00:00') + +query error +SELECT EXTRACT("'year'" FROM timestamp '2020-09-08T12:00:00+00:00') + query R SELECT date_part('YEAR', CAST('2000-01-01' AS DATE)) ---- @@ -819,6 +837,16 @@ SELECT EXTRACT(year FROM timestamp '2020-09-08T12:00:00+00:00') ---- 2020 +query R +SELECT EXTRACT("year" FROM timestamp '2020-09-08T12:00:00+00:00') +---- +2020 + +query R +SELECT EXTRACT('year' FROM timestamp '2020-09-08T12:00:00+00:00') +---- +2020 + query R SELECT date_part('QUARTER', CAST('2000-01-01' AS DATE)) ---- @@ -829,6 +857,16 @@ SELECT EXTRACT(quarter FROM to_timestamp('2020-09-08T12:00:00+00:00')) ---- 3 +query R +SELECT EXTRACT("quarter" FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +3 + +query R +SELECT EXTRACT('quarter' FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +3 + query R SELECT date_part('MONTH', CAST('2000-01-01' AS DATE)) ---- @@ -839,6 +877,16 @@ SELECT EXTRACT(month FROM to_timestamp('2020-09-08T12:00:00+00:00')) ---- 9 +query R +SELECT EXTRACT("month" FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +9 + +query R +SELECT EXTRACT('month' FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +9 + query R SELECT date_part('WEEK', CAST('2003-01-01' AS DATE)) ---- @@ -849,6 +897,16 @@ SELECT EXTRACT(WEEK FROM to_timestamp('2020-09-08T12:00:00+00:00')) ---- 37 +query R +SELECT EXTRACT("WEEK" FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +37 + +query R +SELECT EXTRACT('WEEK' FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +37 + query R SELECT date_part('DAY', CAST('2000-01-01' AS DATE)) ---- @@ -859,6 +917,16 @@ SELECT EXTRACT(day FROM to_timestamp('2020-09-08T12:00:00+00:00')) ---- 8 +query R +SELECT EXTRACT("day" FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +8 + +query R +SELECT EXTRACT('day' FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +8 + query R SELECT date_part('DOY', CAST('2000-01-01' AS DATE)) ---- @@ -869,6 +937,16 @@ SELECT EXTRACT(doy FROM to_timestamp('2020-09-08T12:00:00+00:00')) ---- 252 +query R +SELECT EXTRACT("doy" FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +252 + +query R +SELECT EXTRACT('doy' FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +252 + query R SELECT date_part('DOW', CAST('2000-01-01' AS DATE)) ---- @@ -879,6 +957,16 @@ SELECT EXTRACT(dow FROM to_timestamp('2020-09-08T12:00:00+00:00')) ---- 2 +query R +SELECT EXTRACT("dow" FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +2 + +query R +SELECT EXTRACT('dow' FROM to_timestamp('2020-09-08T12:00:00+00:00')) +---- +2 + query R SELECT date_part('HOUR', CAST('2000-01-01' AS DATE)) ---- @@ -889,11 +977,31 @@ SELECT EXTRACT(hour FROM to_timestamp('2020-09-08T12:03:03+00:00')) ---- 12 +query R +SELECT EXTRACT("hour" FROM to_timestamp('2020-09-08T12:03:03+00:00')) +---- +12 + +query R +SELECT EXTRACT('hour' FROM to_timestamp('2020-09-08T12:03:03+00:00')) +---- +12 + query R SELECT EXTRACT(minute FROM to_timestamp('2020-09-08T12:12:00+00:00')) ---- 12 +query R +SELECT EXTRACT("minute" FROM to_timestamp('2020-09-08T12:12:00+00:00')) +---- +12 + +query R +SELECT EXTRACT('minute' FROM to_timestamp('2020-09-08T12:12:00+00:00')) +---- +12 + query R SELECT date_part('minute', to_timestamp('2020-09-08T12:12:00+00:00')) ---- @@ -919,6 +1027,46 @@ SELECT EXTRACT(nanosecond FROM timestamp '2020-09-08T12:00:12.12345678+00:00') ---- 12123456780 +query R +SELECT EXTRACT("second" FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12.12345678 + +query R +SELECT EXTRACT("millisecond" FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123.45678 + +query R +SELECT EXTRACT("microsecond" FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123456.78 + +query R +SELECT EXTRACT("nanosecond" FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123456780 + +query R +SELECT EXTRACT('second' FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12.12345678 + +query R +SELECT EXTRACT('millisecond' FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123.45678 + +query R +SELECT EXTRACT('microsecond' FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123456.78 + +query R +SELECT EXTRACT('nanosecond' FROM timestamp '2020-09-08T12:00:12.12345678+00:00') +---- +12123456780 + # Keep precision when coercing Utf8 to Timestamp query R SELECT date_part('second', timestamp '2020-09-08T12:00:12.12345678+00:00') @@ -1201,6 +1349,16 @@ SELECT date_part('second', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanose ---- 50.123456789 +query R +select extract(second from '2024-08-09T12:13:14') +---- +14 + +query R +select extract(seconds from '2024-08-09T12:13:14') +---- +14 + query R SELECT extract(second from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ---- @@ -1227,6 +1385,11 @@ SELECT extract(microsecond from arrow_cast('23:32:50.123456789'::time, 'Time64(N ---- 50123456.789000005 +query R +SELECT extract(us from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123456.789000005 + query R SELECT date_part('nanosecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) ---- @@ -1309,6 +1472,189 @@ SELECT extract(epoch from arrow_cast('1969-12-31', 'Date64')) ---- -86400 +# test_extract_interval + +query R +SELECT extract(year from arrow_cast('10 years', 'Interval(YearMonth)')) +---- +10 + +query R +SELECT extract(month from arrow_cast('10 years', 'Interval(YearMonth)')) +---- +0 + +query R +SELECT extract(year from arrow_cast('10 months', 'Interval(YearMonth)')) +---- +0 + +query R +SELECT extract(month from arrow_cast('10 months', 'Interval(YearMonth)')) +---- +10 + +query R +SELECT extract(year from arrow_cast('20 months', 'Interval(YearMonth)')) +---- +1 + +query R +SELECT extract(month from arrow_cast('20 months', 'Interval(YearMonth)')) +---- +8 + +query error DataFusion error: Arrow error: Compute error: Year does not support: Interval\(DayTime\) +SELECT extract(year from arrow_cast('10 days', 'Interval(DayTime)')) + +query error DataFusion error: Arrow error: Compute error: Month does not support: Interval\(DayTime\) +SELECT extract(month from arrow_cast('10 days', 'Interval(DayTime)')) + +query R +SELECT extract(day from arrow_cast('10 days', 'Interval(DayTime)')) +---- +10 + +query R +SELECT extract(day from arrow_cast('14400 minutes', 'Interval(DayTime)')) +---- +0 + +query R +SELECT extract(minute from arrow_cast('14400 minutes', 'Interval(DayTime)')) +---- +14400 + +query R +SELECT extract(second from arrow_cast('5.1 seconds', 'Interval(DayTime)')) +---- +5 + +query R +SELECT extract(second from arrow_cast('14400 minutes', 'Interval(DayTime)')) +---- +864000 + +query R +SELECT extract(second from arrow_cast('2 months', 'Interval(MonthDayNano)')) +---- +0 + +query R +SELECT extract(second from arrow_cast('2 days', 'Interval(MonthDayNano)')) +---- +0 + +query R +SELECT extract(second from arrow_cast('2 seconds', 'Interval(MonthDayNano)')) +---- +2 + +query R +SELECT extract(seconds from arrow_cast('2 seconds', 'Interval(MonthDayNano)')) +---- +2 + +query R +SELECT extract(epoch from arrow_cast('2 seconds', 'Interval(MonthDayNano)')) +---- +2 + +query R +SELECT extract(milliseconds from arrow_cast('2 seconds', 'Interval(MonthDayNano)')) +---- +2000 + +query R +SELECT extract(second from arrow_cast('2030 milliseconds', 'Interval(MonthDayNano)')) +---- +2.03 + +query R +SELECT extract(second from arrow_cast(NULL, 'Interval(MonthDayNano)')) +---- +NULL + +statement ok +create table t (id int, i interval) as values + (0, interval '5 months 1 day 10 nanoseconds'), + (1, interval '1 year 3 months'), + (2, interval '3 days 2 milliseconds'), + (3, interval '2 seconds'), + (4, interval '8 months'), + (5, NULL); + +query IRR rowsort +select + id, + extract(second from i), + extract(month from i) +from t +order by id; +---- +0 0.00000001 5 +1 0 15 +2 0.002 0 +3 2 0 +4 0 8 +5 NULL NULL + +statement ok +drop table t; + +# test_extract_duration + +query R +SELECT extract(second from arrow_cast(2, 'Duration(Second)')) +---- +2 + +query R +SELECT extract(seconds from arrow_cast(2, 'Duration(Second)')) +---- +2 + +query R +SELECT extract(epoch from arrow_cast(2, 'Duration(Second)')) +---- +2 + +query R +SELECT extract(millisecond from arrow_cast(2, 'Duration(Second)')) +---- +2000 + +query R +SELECT extract(second from arrow_cast(2, 'Duration(Millisecond)')) +---- +0.002 + +query R +SELECT extract(second from arrow_cast(2002, 'Duration(Millisecond)')) +---- +2.002 + +query R +SELECT extract(millisecond from arrow_cast(2002, 'Duration(Millisecond)')) +---- +2002 + +query R +SELECT extract(day from arrow_cast(864000, 'Duration(Second)')) +---- +10 + +query error DataFusion error: Arrow error: Compute error: Month does not support: Duration\(Second\) +SELECT extract(month from arrow_cast(864000, 'Duration(Second)')) + +query error DataFusion error: Arrow error: Compute error: Year does not support: Duration\(Second\) +SELECT extract(year from arrow_cast(864000, 'Duration(Second)')) + +query R +SELECT extract(day from arrow_cast(NULL, 'Duration(Second)')) +---- +NULL + # test_extract_date_part_func query B @@ -1646,7 +1992,7 @@ SELECT arrow_cast(decode(arrow_cast('746f6d', 'LargeBinary'),'hex'), 'Utf8'); ---- tom -query ? +query T SELECT encode(NULL,'base64'); ---- NULL @@ -1656,7 +2002,7 @@ SELECT decode(NULL,'base64'); ---- NULL -query ? +query T SELECT encode(NULL,'hex'); ---- NULL @@ -1701,7 +2047,7 @@ SELECT md5(''); ---- d41d8cd98f00b204e9800998ecf8427e -query ? +query T SELECT md5(NULL); ---- NULL @@ -1871,6 +2217,17 @@ SELECT digest('','blake3'); ---- af1349b9f5f9a1a6a0404dea36dcc9499bcb25c9adc112b7cc9a93cae41f3262 + +query T +SELECT substring('alphabet', 1) +---- +alphabet + +query T +SELECT substring('alphabet', 3, 2) +---- +ph + query T SELECT substring('alphabet' from 2 for 1); ---- @@ -1886,6 +2243,23 @@ SELECT substring('alphabet' for 1); ---- a +# The 'from' and 'for' parameters don't support string types, because they should be treated as +# regular expressions, which we have not implemented yet. +query error +SELECT substring('alphabet' FROM '3') + +query error +SELECT substring('alphabet' FROM '3' FOR '2') + +query error +SELECT substring('alphabet' FROM '3' FOR 2) + +query error +SELECT substring('alphabet' FROM 3 FOR '2') + +query error +SELECT substring('alphabet' FOR '2') + ##### csv_query_nullif_divide_by_0 @@ -1906,8 +2280,8 @@ CREATE EXTERNAL TABLE aggregate_test_100_by_sql ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); @@ -2200,7 +2574,7 @@ CREATE TABLE t_source( column1 String, column2 String, column3 String, - column4 String, + column4 String ) AS VALUES ('one', 'one', 'one', 'one'), ('two', 'two', '', 'two'), @@ -2246,13 +2620,13 @@ select f64, round(1.0 / f64) as i64_1, acos(round(1.0 / f64)) from doubles; 10.1 0 1.570796326795 # common subexpr with coalesce (short-circuited) -query RRR rowsort +query RRR select f64, coalesce(1.0 / f64, 0.0), acos(coalesce(1.0 / f64, 0.0)) from doubles; ---- 10.1 0.09900990099 1.471623942989 # common subexpr with coalesce (short-circuited) and alias -query RRR rowsort +query RRR select f64, coalesce(1.0 / f64, 0.0) as f64_1, acos(coalesce(1.0 / f64, 0.0)) from doubles; ---- 10.1 0.09900990099 1.471623942989 @@ -2296,7 +2670,42 @@ host3 3.3 # can have an aggregate function with an inner CASE WHEN query TR -select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host; +select + t2.server_host as host, + sum(( + case when t2.server_host is not null + then t2.server_load2 + end + )) + from ( + select + struct(time,load1,load2,host)['c2'] as server_load2, + struct(time,load1,load2,host)['c3'] as server_host + from t1 + ) t2 + where server_host IS NOT NULL + group by server_host order by host; +---- +host1 101 +host2 202 +host3 303 + +# TODO: Issue tracked in https://github.com/apache/datafusion/issues/10364 +query TR +select + t2.server['c3'] as host, + sum(( + case when t2.server['c3'] is not null + then t2.server['c2'] + end + )) + from ( + select + struct(time,load1,load2,host) as server + from t1 + ) t2 + where t2.server['c3'] IS NOT NULL + group by t2.server['c3'] order by host; ---- host1 101 host2 202 @@ -2304,15 +2713,94 @@ host3 303 # can have 2 projections with aggr(short_circuited), with different short-circuited expr query TRR -select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host; +select + t2.server_host as host, + sum(coalesce(server_load1)), + sum(( + case when t2.server_host is not null + then t2.server_load2 + end + )) + from ( + select + struct(time,load1,load2,host)['c1'] as server_load1, + struct(time,load1,load2,host)['c2'] as server_load2, + struct(time,load1,load2,host)['c3'] as server_host + from t1 + ) t2 + where server_host IS NOT NULL + group by server_host order by host; +---- +host1 1.1 101 +host2 2.2 202 +host3 3.3 303 + +# TODO: Issue tracked in https://github.com/apache/datafusion/issues/10364 +query error +select + t2.server['c3'] as host, + sum(coalesce(server['c1'])), + sum(( + case when t2.server['c3'] is not null + then t2.server['c2'] + end + )) + from ( + select + struct(time,load1,load2,host) as server, + from t1 + ) t2 + where server_host IS NOT NULL + group by server_host order by host; + +query TRR +select + t2.server_host as host, + sum(( + case when t2.server_host is not null + then server_load1 + end + )), + sum(( + case when server_host is not null + then server_load2 + end + )) + from ( + select + struct(time,load1,load2,host)['c1'] as server_load1, + struct(time,load1,load2,host)['c2'] as server_load2, + struct(time,load1,load2,host)['c3'] as server_host + from t1 + ) t2 + where server_host IS NOT NULL + group by server_host order by host; ---- host1 1.1 101 host2 2.2 202 host3 3.3 303 -# can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. CASE WHEN) +# TODO: Issue tracked in https://github.com/apache/datafusion/issues/10364 query TRR -select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host; +select + t2.server['c3'] as host, + sum(( + case when t2.server['c3'] is not null + then t2.server['c1'] + end + )), + sum(( + case when t2.server['c3'] is not null + then t2.server['c2'] + end + )) + from ( + select + struct(time,load1,load2,host) as server + from t1 + ) t2 + where t2.server['c3'] IS NOT NULL + group by t2.server['c3'] order by host; ---- host1 1.1 101 host2 2.2 202 @@ -2325,3 +2813,28 @@ select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesc host1 1.1 101 host2 2.2 202 host3 3.3 303 + +statement ok +set datafusion.sql_parser.dialect = 'Postgres'; + +statement ok +create table t (a float) as values (1), (2), (3); + +query TT +explain select min(a) filter (where a > 1) as x from t; +---- +logical_plan +01)Projection: min(t.a) FILTER (WHERE t.a > Int64(1)) AS x +02)--Aggregate: groupBy=[[]], aggr=[[min(t.a) FILTER (WHERE t.a > Float32(1)) AS min(t.a) FILTER (WHERE t.a > Int64(1))]] +03)----TableScan: t projection=[a] +physical_plan +01)ProjectionExec: expr=[min(t.a) FILTER (WHERE t.a > Int64(1))@0 as x] +02)--AggregateExec: mode=Single, gby=[], aggr=[min(t.a) FILTER (WHERE t.a > Int64(1))] +03)----MemoryExec: partitions=1, partition_sizes=[1] + + +statement ok +drop table t; + +statement ok +set datafusion.sql_parser.dialect = 'Generic'; diff --git a/datafusion/sqllogictest/test_files/filter_without_sort_exec.slt b/datafusion/sqllogictest/test_files/filter_without_sort_exec.slt index d5d3d87b5747..d96044fda8c0 100644 --- a/datafusion/sqllogictest/test_files/filter_without_sort_exec.slt +++ b/datafusion/sqllogictest/test_files/filter_without_sort_exec.slt @@ -37,7 +37,7 @@ logical_plan 02)--Filter: data.ticker = Utf8("A") 03)----TableScan: data projection=[date, ticker, time] physical_plan -01)SortPreservingMergeExec: [date@0 ASC NULLS LAST,time@2 ASC NULLS LAST] +01)SortPreservingMergeExec: [date@0 ASC NULLS LAST, time@2 ASC NULLS LAST] 02)--CoalesceBatchesExec: target_batch_size=8192 03)----FilterExec: ticker@1 = A 04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 @@ -105,7 +105,7 @@ logical_plan 02)--Filter: data.ticker = Utf8("A") AND CAST(data.time AS Date32) = data.date 03)----TableScan: data projection=[date, ticker, time] physical_plan -01)SortPreservingMergeExec: [time@2 ASC NULLS LAST,date@0 ASC NULLS LAST] +01)SortPreservingMergeExec: [time@2 ASC NULLS LAST, date@0 ASC NULLS LAST] 02)--CoalesceBatchesExec: target_batch_size=8192 03)----FilterExec: ticker@1 = A AND CAST(time@2 AS Date32) = date@0 04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 @@ -143,11 +143,11 @@ ORDER BY "ticker", "time"; ---- logical_plan 01)Sort: data.ticker ASC NULLS LAST, data.time ASC NULLS LAST -02)--Filter: data.date = Date32("13150") +02)--Filter: data.date = Date32("2006-01-02") 03)----TableScan: data projection=[date, ticker, time] physical_plan -01)SortPreservingMergeExec: [ticker@1 ASC NULLS LAST,time@2 ASC NULLS LAST] +01)SortPreservingMergeExec: [ticker@1 ASC NULLS LAST, time@2 ASC NULLS LAST] 02)--CoalesceBatchesExec: target_batch_size=8192 -03)----FilterExec: date@0 = 13150 +03)----FilterExec: date@0 = 2006-01-02 04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 05)--------StreamingTableExec: partition_sizes=1, projection=[date, ticker, time], infinite_source=true, output_ordering=[date@0 ASC NULLS LAST, ticker@1 ASC NULLS LAST, time@2 ASC NULLS LAST] diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index d0d2bac59e91..5b6017b08a00 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -18,46 +18,6 @@ # unicode expressions -query I -SELECT char_length('') ----- -0 - -query I -SELECT char_length('chars') ----- -5 - -query I -SELECT char_length('josé') ----- -4 - -query ? -SELECT char_length(NULL) ----- -NULL - -query I -SELECT character_length('') ----- -0 - -query I -SELECT character_length('chars') ----- -5 - -query I -SELECT character_length('josé') ----- -4 - -query ? -SELECT character_length(NULL) ----- -NULL - query T SELECT left('abcde', -2) ---- @@ -93,12 +53,12 @@ SELECT left('abcde', CAST(NULL AS INT)) ---- NULL -query ? +query T SELECT left(NULL, 2) ---- NULL -query ? +query T SELECT left(NULL, CAST(NULL AS INT)) ---- NULL @@ -128,101 +88,11 @@ SELECT length(arrow_cast('josé', 'Dictionary(Int32, Utf8)')) ---- 4 -query ? +query I SELECT length(NULL) ---- NULL -query T -SELECT lpad('hi', -1, 'xy') ----- -(empty) - -query T -SELECT lpad('hi', 5, 'xy') ----- -xyxhi - -query T -SELECT lpad('hi', -1) ----- -(empty) - -query T -SELECT lpad('hi', 0) ----- -(empty) - -query T -SELECT lpad('hi', 21, 'abcdef') ----- -abcdefabcdefabcdefahi - -query T -SELECT lpad('hi', 5, 'xy') ----- -xyxhi - -query T -SELECT lpad('hi', 5, NULL) ----- -NULL - -query T -SELECT lpad('hi', 5) ----- - hi - -query T -SELECT lpad(arrow_cast('hi', 'Dictionary(Int32, Utf8)'), 5) ----- - hi - -query T -SELECT lpad('hi', CAST(NULL AS INT), 'xy') ----- -NULL - -query T -SELECT lpad('hi', CAST(NULL AS INT)) ----- -NULL - -query T -SELECT lpad('xyxhi', 3) ----- -xyx - -query ? -SELECT lpad(NULL, 0) ----- -NULL - -query ? -SELECT lpad(NULL, 5, 'xy') ----- -NULL - -query T -SELECT reverse('abcde') ----- -edcba - -query T -SELECT reverse(arrow_cast('abcde', 'Dictionary(Int32, Utf8)')) ----- -edcba - -query T -SELECT reverse('loẅks') ----- -sk̈wol - -query ? -SELECT reverse(NULL) ----- -NULL - query T SELECT right('abcde', -2) ---- @@ -258,103 +128,13 @@ SELECT right('abcde', CAST(NULL AS INT)) ---- NULL -query ? -SELECT right(NULL, 2) ----- -NULL - -query ? -SELECT right(NULL, CAST(NULL AS INT)) ----- -NULL - -query T -SELECT rpad('hi', -1, 'xy') ----- -(empty) - -query T -SELECT rpad('hi', 5, 'xy') ----- -hixyx - -query T -SELECT rpad('hi', -1) ----- -(empty) - -query T -SELECT rpad('hi', 0) ----- -(empty) - -query T -SELECT rpad('hi', 21, 'abcdef') ----- -hiabcdefabcdefabcdefa - -query T -SELECT rpad('hi', 5, 'xy') ----- -hixyx - -query T -SELECT rpad(arrow_cast('hi', 'Dictionary(Int32, Utf8)'), 5, 'xy') ----- -hixyx - -query T -SELECT rpad('hi', 5, NULL) ----- -NULL - -query T -SELECT rpad('hi', 5) ----- -hi - -query T -SELECT rpad('hi', CAST(NULL AS INT), 'xy') ----- -NULL - query T -SELECT rpad('hi', CAST(NULL AS INT)) +SELECT right(NULL, 2) ---- NULL query T -SELECT rpad('xyxhi', 3) ----- -xyx - -query I -SELECT strpos('abc', 'c') ----- -3 - -query I -SELECT strpos('josé', 'é') ----- -4 - -query I -SELECT strpos('joséésoj', 'so') ----- -6 - -query I -SELECT strpos('joséésoj', 'abc') ----- -0 - -query ? -SELECT strpos(NULL, 'abc') ----- -NULL - -query I -SELECT strpos('joséésoj', NULL) +SELECT right(NULL, CAST(NULL AS INT)) ---- NULL @@ -413,10 +193,10 @@ SELECT substr('alphabet', 3, CAST(NULL AS int)) ---- NULL -statement error The SUBSTR function can only accept strings, but got Int64. +statement error The first argument of the substr function can only be a string, but got Int64 SELECT substr(1, 3) -statement error The SUBSTR function can only accept strings, but got Int64. +statement error The first argument of the substr function can only be a string, but got Int64 SELECT substr(1, 3, 4) query T @@ -429,7 +209,7 @@ SELECT translate(arrow_cast('12345', 'Dictionary(Int32, Utf8)'), '143', 'ax') ---- a2x5 -query ? +query T SELECT translate(NULL, '143', 'ax') ---- NULL @@ -487,23 +267,23 @@ statement error Did you mean 'to_timestamp_seconds'? SELECT to_TIMESTAMPS_second(v2) from test; # Aggregate function -statement error Did you mean 'COUNT'? +query error DataFusion error: Error during planning: Invalid function 'counter' SELECT counter(*) from test; # Aggregate function -statement error Did you mean 'STDDEV'? +statement error Did you mean 'stddev'? SELECT STDEV(v1) from test; # Aggregate function -statement error Did you mean 'COVAR'? +statement error DataFusion error: Error during planning: Invalid function 'covaria'.\nDid you mean 'covar'? SELECT COVARIA(1,1); # Window function -statement error Did you mean 'SUM'? +statement error SELECT v1, v2, SUMM(v2) OVER(ORDER BY v1) from test; # Window function -statement error Did you mean 'ROW_NUMBER'? +statement error Did you mean 'row_number'? SELECT v1, v2, ROWNUMBER() OVER(ORDER BY v1) from test; statement ok @@ -549,8 +329,8 @@ CREATE EXTERNAL TABLE aggregate_test_100 ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); # sqrt_f32_vs_f64 @@ -712,35 +492,6 @@ SELECT md5(arrow_cast('foo', 'Dictionary(Int32, Utf8)')) ---- acbd18db4cc2f85cedef654fccc4a4d8 -query T -SELECT regexp_replace('foobar', 'bar', 'xx', 'gi') ----- -fooxx - -query T -SELECT regexp_replace(arrow_cast('foobar', 'Dictionary(Int32, Utf8)'), 'bar', 'xx', 'gi') ----- -fooxx - -query T -SELECT repeat('foo', 3) ----- -foofoofoo - -query T -SELECT repeat(arrow_cast('foo', 'Dictionary(Int32, Utf8)'), 3) ----- -foofoofoo - -query T -SELECT replace('foobar', 'bar', 'hello') ----- -foohello - -query T -SELECT replace(arrow_cast('foobar', 'Dictionary(Int32, Utf8)'), 'bar', 'hello') ----- -foohello query T SELECT rtrim(' foo ') @@ -752,36 +503,6 @@ SELECT rtrim(arrow_cast(' foo ', 'Dictionary(Int32, Utf8)')) ---- foo -query T -SELECT split_part('foo_bar', '_', 2) ----- -bar - -query T -SELECT split_part(arrow_cast('foo_bar', 'Dictionary(Int32, Utf8)'), '_', 2) ----- -bar - -query B -SELECT starts_with('foobar', 'foo') ----- -true - -query B -SELECT starts_with('foobar', 'bar') ----- -false - -query B -SELECT ends_with('foobar', 'bar') ----- -true - -query B -SELECT ends_with('foobar', 'foo') ----- -false - query T SELECT trim(' foo ') ---- @@ -832,6 +553,16 @@ SELECT strpos(arrow_cast('helloworld', 'Dictionary(Int32, Utf8)'), 'world') ---- 6 +query I +SELECT strpos('helloworld', NULL) +---- +NULL + +query I +SELECT strpos(arrow_cast('helloworld', 'Dictionary(Int32, Utf8)'), NULL) +---- +NULL + statement ok CREATE TABLE products ( product_id INT PRIMARY KEY, @@ -871,7 +602,7 @@ SELECT products.* REPLACE (price*2 AS price, product_id+1000 AS product_id) FROM 1003 OldBrand Product 3 79.98 1004 OldBrand Product 4 99.98 -#overlay tests +# overlay tests statement ok CREATE TABLE over_test( str TEXT, @@ -913,205 +644,94 @@ NULL Thomxas NULL -query I -SELECT levenshtein('kitten', 'sitting') ----- -3 - -query I -SELECT levenshtein('kitten', NULL) +# overlay tests with utf8view +query T +SELECT overlay(arrow_cast(str, 'Utf8View') placing arrow_cast(characters, 'Utf8View') from pos for len) from over_test ---- +abc +qwertyasdfg +ijkz +Thomas +NULL +NULL NULL - -query ? -SELECT levenshtein(NULL, 'sitting') ----- NULL -query ? -SELECT levenshtein(NULL, NULL) +query T +SELECT overlay(arrow_cast(str, 'Utf8View') placing arrow_cast(characters, 'Utf8View') from pos) from over_test ---- +abc +qwertyasdfg +ijk +Thomxas +NULL +NULL +Thomxas NULL -# Test substring_index using '.' as delimiter -# This query is compatible with MySQL(8.0.19 or later), convenient for comparing results -query TIT -SELECT str, n, substring_index(str, '.', n) AS c FROM - (VALUES - ROW('arrow.apache.org'), - ROW('.'), - ROW('...'), - ROW(NULL) - ) AS strings(str), - (VALUES - ROW(1), - ROW(2), - ROW(3), - ROW(100), - ROW(-1), - ROW(-2), - ROW(-3), - ROW(-100) - ) AS occurrences(n) -ORDER BY str DESC, n; ----- -NULL -100 NULL -NULL -3 NULL -NULL -2 NULL -NULL -1 NULL -NULL 1 NULL -NULL 2 NULL -NULL 3 NULL -NULL 100 NULL -arrow.apache.org -100 arrow.apache.org -arrow.apache.org -3 arrow.apache.org -arrow.apache.org -2 apache.org -arrow.apache.org -1 org -arrow.apache.org 1 arrow -arrow.apache.org 2 arrow.apache -arrow.apache.org 3 arrow.apache.org -arrow.apache.org 100 arrow.apache.org -... -100 ... -... -3 .. -... -2 . -... -1 (empty) -... 1 (empty) -... 2 . -... 3 .. -... 100 ... -. -100 . -. -3 . -. -2 . -. -1 (empty) -. 1 (empty) -. 2 . -. 3 . -. 100 . - -# Test substring_index using 'ac' as delimiter -query TIT -SELECT str, n, substring_index(str, 'ac', n) AS c FROM - (VALUES - -- input string does not contain the delimiter - ROW('arrow'), - -- input string contains the delimiter - ROW('arrow.apache.org') - ) AS strings(str), - (VALUES - ROW(1), - ROW(2), - ROW(-1), - ROW(-2) - ) AS occurrences(n) -ORDER BY str DESC, n; ----- -arrow.apache.org -2 arrow.apache.org -arrow.apache.org -1 he.org -arrow.apache.org 1 arrow.ap -arrow.apache.org 2 arrow.apache.org -arrow -2 arrow -arrow -1 arrow -arrow 1 arrow -arrow 2 arrow - -# Test substring_index with NULL values -query ?TT? -SELECT - substring_index(NULL, '.', 1), - substring_index('arrow.apache.org', NULL, 1), - substring_index('arrow.apache.org', '.', NULL), - substring_index(NULL, NULL, NULL) ----- -NULL NULL NULL NULL - -# Test substring_index with empty strings -query TT -SELECT - -- input string is empty - substring_index('', '.', 1), - -- delimiter is empty - substring_index('arrow.apache.org', '', 1) ----- -(empty) (empty) - -# Test substring_index with 0 occurrence -query T -SELECT substring_index('arrow.apache.org', 'ac', 0) +# Verify that multiple calls to volatile functions like `random()` are not combined / optimized away +query B +SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random()+1 r1, random()+1 r2) WHERE r1 > 0 AND r2 > 0) ---- -(empty) +false -# Test substring_index with large occurrences -query TT -SELECT - -- i64::MIN - substring_index('arrow.apache.org', '.', -9223372036854775808) as c1, - -- i64::MAX - substring_index('arrow.apache.org', '.', 9223372036854775807) as c2; ----- -arrow.apache.org arrow.apache.org - -# Test substring_index issue https://github.com/apache/datafusion/issues/9472 -query TTT -SELECT - url, - substring_index(url, '.', 1) AS subdomain, - substring_index(url, '.', -1) AS tld -FROM - (VALUES ROW('docs.apache.com'), - ROW('community.influxdata.com'), - ROW('arrow.apache.org') - ) data(url) ----- -docs.apache.com docs com -community.influxdata.com community com -arrow.apache.org arrow org +####### +# verify that random() returns a different value for each row +####### +statement ok +create table t as values (1), (2); +statement ok +create table rand_table as select random() as r from t; +# should have 2 distinct values (not 1) query I -SELECT find_in_set('b', 'a,b,c,d') +select count(distinct r) from rand_table; ---- 2 +statement ok +drop table rand_table -query I -SELECT find_in_set('a', 'a,b,c,d,a') ----- -1 +statement ok +drop table t -query I -SELECT find_in_set('', 'a,b,c,d,a') ----- -0 -query I -SELECT find_in_set('a', '') ----- -0 +####### +# verify that uuid() returns a different value for each row +####### +statement ok +create table t as values (1), (2); +statement ok +create table uuid_table as select uuid() as u from t; +# should have 2 distinct values (not 1) query I -SELECT find_in_set('', '') +select count(distinct u) from uuid_table; ---- -1 +2 -query ? -SELECT find_in_set(NULL, 'a,b,c,d') ----- -NULL +statement ok +drop table uuid_table -query I -SELECT find_in_set('a', NULL) ----- -NULL +statement ok +drop table t -query ? -SELECT find_in_set(NULL, NULL) +# test for contains + +query B +select contains('alphabet', 'pha'); ---- -NULL +true -# Verify that multiple calls to volatile functions like `random()` are not combined / optimized away query B -SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random()+1 r1, random()+1 r2) WHERE r1 > 0 AND r2 > 0) +select contains('alphabet', 'dddd'); ---- false + +query B +select contains('', ''); +---- +true diff --git a/datafusion/sqllogictest/test_files/group.slt b/datafusion/sqllogictest/test_files/group.slt index 2a28efa73a62..a6b5f9b72a53 100644 --- a/datafusion/sqllogictest/test_files/group.slt +++ b/datafusion/sqllogictest/test_files/group.slt @@ -32,11 +32,11 @@ CREATE EXTERNAL TABLE aggregate_test_100 ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); statement ok -CREATE external table aggregate_simple(c1 real, c2 double, c3 boolean) STORED as CSV WITH HEADER ROW LOCATION '../core/tests/data/aggregate_simple.csv'; +CREATE external table aggregate_simple(c1 real, c2 double, c3 boolean) STORED as CSV LOCATION '../core/tests/data/aggregate_simple.csv' OPTIONS ('format.has_header' 'true'); # csv_query_group_by_int_min_max diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index e015f7b01d0c..daf270190870 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -1962,9 +1962,9 @@ GROUP BY ALL; 2 0 13 query IIR rowsort -SELECT sub.col1, sub.col0, sub."AVG(tab3.col2)" AS avg_col2 +SELECT sub.col1, sub.col0, sub."avg(tab3.col2)" AS avg_col2 FROM ( - SELECT col1, AVG(col2), col0 FROM tab3 GROUP BY ALL + SELECT col1, avg(col2), col0 FROM tab3 GROUP BY ALL ) AS sub GROUP BY ALL; ---- @@ -2005,8 +2005,8 @@ ORDER BY l.col0; ---- logical_plan 01)Sort: l.col0 ASC NULLS LAST -02)--Projection: l.col0, LAST_VALUE(r.col1) ORDER BY [r.col0 ASC NULLS LAST] AS last_col1 -03)----Aggregate: groupBy=[[l.col0, l.col1, l.col2]], aggr=[[LAST_VALUE(r.col1) ORDER BY [r.col0 ASC NULLS LAST]]] +02)--Projection: l.col0, last_value(r.col1) ORDER BY [r.col0 ASC NULLS LAST] AS last_col1 +03)----Aggregate: groupBy=[[l.col0, l.col1, l.col2]], aggr=[[last_value(r.col1) ORDER BY [r.col0 ASC NULLS LAST]]] 04)------Inner Join: l.col0 = r.col0 05)--------SubqueryAlias: l 06)----------TableScan: tab0 projection=[col0, col1, col2] @@ -2014,21 +2014,18 @@ logical_plan 08)----------TableScan: tab0 projection=[col0, col1] physical_plan 01)SortPreservingMergeExec: [col0@0 ASC NULLS LAST] -02)--SortExec: expr=[col0@0 ASC NULLS LAST] -03)----ProjectionExec: expr=[col0@0 as col0, LAST_VALUE(r.col1) ORDER BY [r.col0 ASC NULLS LAST]@3 as last_col1] -04)------AggregateExec: mode=FinalPartitioned, gby=[col0@0 as col0, col1@1 as col1, col2@2 as col2], aggr=[LAST_VALUE(r.col1) ORDER BY [r.col0 ASC NULLS LAST]] +02)--SortExec: expr=[col0@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[col0@0 as col0, last_value(r.col1) ORDER BY [r.col0 ASC NULLS LAST]@3 as last_col1] +04)------AggregateExec: mode=FinalPartitioned, gby=[col0@0 as col0, col1@1 as col1, col2@2 as col2], aggr=[last_value(r.col1) ORDER BY [r.col0 ASC NULLS LAST]] 05)--------CoalesceBatchesExec: target_batch_size=8192 06)----------RepartitionExec: partitioning=Hash([col0@0, col1@1, col2@2], 4), input_partitions=4 -07)------------AggregateExec: mode=Partial, gby=[col0@0 as col0, col1@1 as col1, col2@2 as col2], aggr=[LAST_VALUE(r.col1) ORDER BY [r.col0 ASC NULLS LAST]] -08)--------------ProjectionExec: expr=[col0@2 as col0, col1@3 as col1, col2@4 as col2, col0@0 as col0, col1@1 as col1] -09)----------------CoalesceBatchesExec: target_batch_size=8192 -10)------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col0@0, col0@0)] -11)--------------------CoalesceBatchesExec: target_batch_size=8192 -12)----------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=1 -13)------------------------MemoryExec: partitions=1, partition_sizes=[3] -14)--------------------CoalesceBatchesExec: target_batch_size=8192 -15)----------------------RepartitionExec: partitioning=Hash([col0@0], 4), input_partitions=1 -16)------------------------MemoryExec: partitions=1, partition_sizes=[3] +07)------------AggregateExec: mode=Partial, gby=[col0@0 as col0, col1@1 as col1, col2@2 as col2], aggr=[last_value(r.col1) ORDER BY [r.col0 ASC NULLS LAST]] +08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +09)----------------ProjectionExec: expr=[col0@2 as col0, col1@3 as col1, col2@4 as col2, col0@0 as col0, col1@1 as col1] +10)------------------CoalesceBatchesExec: target_batch_size=8192 +11)--------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(col0@0, col0@0)] +12)----------------------MemoryExec: partitions=1, partition_sizes=[3] +13)----------------------MemoryExec: partitions=1, partition_sizes=[3] # Columns in the table are a,b,c,d. Source is CsvExec which is ordered by # a,b,c column. Column a has cardinality 2, column b has cardinality 4. @@ -2042,9 +2039,9 @@ CREATE UNBOUNDED EXTERNAL TABLE annotated_data_infinite2 ( d INTEGER ) STORED AS CSV -WITH HEADER ROW WITH ORDER (a ASC, b ASC, c ASC) -LOCATION '../core/tests/data/window_2.csv'; +LOCATION '../core/tests/data/window_2.csv' +OPTIONS ('format.has_header' 'true'); # Create a table with 2 ordered columns. # In the next step, we will expect to observe the removed sort execs. @@ -2057,10 +2054,10 @@ CREATE EXTERNAL TABLE multiple_ordered_table ( d INTEGER ) STORED AS CSV -WITH HEADER ROW WITH ORDER (a ASC, b ASC) WITH ORDER (c ASC) -LOCATION '../core/tests/data/window_2.csv'; +LOCATION '../core/tests/data/window_2.csv' +OPTIONS ('format.has_header' 'true'); # Expected a sort exec for b DESC query TT @@ -2072,7 +2069,7 @@ logical_plan 03)----TableScan: multiple_ordered_table projection=[a, b] physical_plan 01)ProjectionExec: expr=[a@0 as a] -02)--SortExec: expr=[b@1 DESC] +02)--SortExec: expr=[b@1 DESC], preserve_partitioning=[false] 03)----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], has_header=true # Final plan shouldn't have SortExec c ASC, @@ -2109,12 +2106,12 @@ EXPLAIN SELECT a, b, GROUP BY b, a ---- logical_plan -01)Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, SUM(annotated_data_infinite2.c) AS summation1 -02)--Aggregate: groupBy=[[annotated_data_infinite2.b, annotated_data_infinite2.a]], aggr=[[SUM(CAST(annotated_data_infinite2.c AS Int64))]] +01)Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, sum(annotated_data_infinite2.c) AS summation1 +02)--Aggregate: groupBy=[[annotated_data_infinite2.b, annotated_data_infinite2.a]], aggr=[[sum(CAST(annotated_data_infinite2.c AS Int64))]] 03)----TableScan: annotated_data_infinite2 projection=[a, b, c] physical_plan -01)ProjectionExec: expr=[a@1 as a, b@0 as b, SUM(annotated_data_infinite2.c)@2 as summation1] -02)--AggregateExec: mode=Single, gby=[b@1 as b, a@0 as a], aggr=[SUM(annotated_data_infinite2.c)], ordering_mode=Sorted +01)ProjectionExec: expr=[a@1 as a, b@0 as b, sum(annotated_data_infinite2.c)@2 as summation1] +02)--AggregateExec: mode=Single, gby=[b@1 as b, a@0 as a], aggr=[sum(annotated_data_infinite2.c)], ordering_mode=Sorted 03)----StreamingTableExec: partition_sizes=1, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] @@ -2140,12 +2137,12 @@ EXPLAIN SELECT a, d, GROUP BY d, a ---- logical_plan -01)Projection: annotated_data_infinite2.a, annotated_data_infinite2.d, SUM(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST] AS summation1 -02)--Aggregate: groupBy=[[annotated_data_infinite2.d, annotated_data_infinite2.a]], aggr=[[SUM(CAST(annotated_data_infinite2.c AS Int64)) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]]] +01)Projection: annotated_data_infinite2.a, annotated_data_infinite2.d, sum(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST] AS summation1 +02)--Aggregate: groupBy=[[annotated_data_infinite2.d, annotated_data_infinite2.a]], aggr=[[sum(CAST(annotated_data_infinite2.c AS Int64)) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]]] 03)----TableScan: annotated_data_infinite2 projection=[a, c, d] physical_plan -01)ProjectionExec: expr=[a@1 as a, d@0 as d, SUM(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]@2 as summation1] -02)--AggregateExec: mode=Single, gby=[d@2 as d, a@0 as a], aggr=[SUM(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]], ordering_mode=PartiallySorted([1]) +01)ProjectionExec: expr=[a@1 as a, d@0 as d, sum(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]@2 as summation1] +02)--AggregateExec: mode=Single, gby=[d@2 as d, a@0 as a], aggr=[sum(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]], ordering_mode=PartiallySorted([1]) 03)----StreamingTableExec: partition_sizes=1, projection=[a, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST] query III @@ -2173,12 +2170,12 @@ EXPLAIN SELECT a, b, FIRST_VALUE(c ORDER BY a DESC) as first_c GROUP BY a, b ---- logical_plan -01)Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, FIRST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST] AS first_c -02)--Aggregate: groupBy=[[annotated_data_infinite2.a, annotated_data_infinite2.b]], aggr=[[FIRST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]]] +01)Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, first_value(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST] AS first_c +02)--Aggregate: groupBy=[[annotated_data_infinite2.a, annotated_data_infinite2.b]], aggr=[[first_value(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]]] 03)----TableScan: annotated_data_infinite2 projection=[a, b, c] physical_plan -01)ProjectionExec: expr=[a@0 as a, b@1 as b, FIRST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]@2 as first_c] -02)--AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[FIRST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]], ordering_mode=Sorted +01)ProjectionExec: expr=[a@0 as a, b@1 as b, first_value(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]@2 as first_c] +02)--AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[first_value(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]], ordering_mode=Sorted 03)----StreamingTableExec: partition_sizes=1, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] query III @@ -2199,12 +2196,12 @@ EXPLAIN SELECT a, b, LAST_VALUE(c ORDER BY a DESC) as last_c GROUP BY a, b ---- logical_plan -01)Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, LAST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST] AS last_c -02)--Aggregate: groupBy=[[annotated_data_infinite2.a, annotated_data_infinite2.b]], aggr=[[LAST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]]] +01)Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, last_value(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST] AS last_c +02)--Aggregate: groupBy=[[annotated_data_infinite2.a, annotated_data_infinite2.b]], aggr=[[last_value(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]]] 03)----TableScan: annotated_data_infinite2 projection=[a, b, c] physical_plan -01)ProjectionExec: expr=[a@0 as a, b@1 as b, LAST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]@2 as last_c] -02)--AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[LAST_VALUE(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]], ordering_mode=Sorted +01)ProjectionExec: expr=[a@0 as a, b@1 as b, last_value(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]@2 as last_c] +02)--AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[last_value(annotated_data_infinite2.c) ORDER BY [annotated_data_infinite2.a DESC NULLS FIRST]], ordering_mode=Sorted 03)----StreamingTableExec: partition_sizes=1, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] query III @@ -2226,12 +2223,12 @@ EXPLAIN SELECT a, b, LAST_VALUE(c) as last_c GROUP BY a, b ---- logical_plan -01)Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, LAST_VALUE(annotated_data_infinite2.c) AS last_c -02)--Aggregate: groupBy=[[annotated_data_infinite2.a, annotated_data_infinite2.b]], aggr=[[LAST_VALUE(annotated_data_infinite2.c)]] +01)Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, last_value(annotated_data_infinite2.c) AS last_c +02)--Aggregate: groupBy=[[annotated_data_infinite2.a, annotated_data_infinite2.b]], aggr=[[last_value(annotated_data_infinite2.c)]] 03)----TableScan: annotated_data_infinite2 projection=[a, b, c] physical_plan -01)ProjectionExec: expr=[a@0 as a, b@1 as b, LAST_VALUE(annotated_data_infinite2.c)@2 as last_c] -02)--AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[LAST_VALUE(annotated_data_infinite2.c)], ordering_mode=Sorted +01)ProjectionExec: expr=[a@0 as a, b@1 as b, last_value(annotated_data_infinite2.c)@2 as last_c] +02)--AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[last_value(annotated_data_infinite2.c)], ordering_mode=Sorted 03)----StreamingTableExec: partition_sizes=1, projection=[a, b, c], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] query III @@ -2253,7 +2250,7 @@ logical_plan 01)Sort: annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.d ASC NULLS LAST 02)--TableScan: annotated_data_infinite2 projection=[a0, a, b, c, d] physical_plan -01)PartialSortExec: expr=[a@1 ASC NULLS LAST,b@2 ASC NULLS LAST,d@4 ASC NULLS LAST], common_prefix_length=[2] +01)PartialSortExec: expr=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, d@4 ASC NULLS LAST], common_prefix_length=[2] 02)--StreamingTableExec: partition_sizes=1, projection=[a0, a, b, c, d], infinite_source=true, output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST] query TT @@ -2263,13 +2260,11 @@ ORDER BY a, b, d LIMIT 50; ---- logical_plan -01)Limit: skip=0, fetch=50 -02)--Sort: annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.d ASC NULLS LAST, fetch=50 -03)----TableScan: annotated_data_infinite2 projection=[a0, a, b, c, d] +01)Sort: annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.d ASC NULLS LAST, fetch=50 +02)--TableScan: annotated_data_infinite2 projection=[a0, a, b, c, d] physical_plan -01)GlobalLimitExec: skip=0, fetch=50 -02)--PartialSortExec: TopK(fetch=50), expr=[a@1 ASC NULLS LAST,b@2 ASC NULLS LAST,d@4 ASC NULLS LAST], common_prefix_length=[2] -03)----StreamingTableExec: partition_sizes=1, projection=[a0, a, b, c, d], infinite_source=true, output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST] +01)PartialSortExec: TopK(fetch=50), expr=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, d@4 ASC NULLS LAST], common_prefix_length=[2] +02)--StreamingTableExec: partition_sizes=1, projection=[a0, a, b, c, d], infinite_source=true, output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST] query TT EXPLAIN SELECT * @@ -2280,7 +2275,7 @@ logical_plan 01)Sort: multiple_ordered_table.a ASC NULLS LAST, multiple_ordered_table.b ASC NULLS LAST, multiple_ordered_table.d ASC NULLS LAST 02)--TableScan: multiple_ordered_table projection=[a0, a, b, c, d] physical_plan -01)SortExec: expr=[a@1 ASC NULLS LAST,b@2 ASC NULLS LAST,d@4 ASC NULLS LAST] +01)SortExec: expr=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, d@4 ASC NULLS LAST], preserve_partitioning=[false] 02)--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_orderings=[[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST], [c@3 ASC NULLS LAST]], has_header=true query TT @@ -2289,11 +2284,11 @@ FROM annotated_data_infinite2 GROUP BY a, b; ---- logical_plan -01)Aggregate: groupBy=[[annotated_data_infinite2.a, annotated_data_infinite2.b]], aggr=[[ARRAY_AGG(annotated_data_infinite2.d) ORDER BY [annotated_data_infinite2.d ASC NULLS LAST]]] +01)Aggregate: groupBy=[[annotated_data_infinite2.a, annotated_data_infinite2.b]], aggr=[[array_agg(annotated_data_infinite2.d) ORDER BY [annotated_data_infinite2.d ASC NULLS LAST]]] 02)--TableScan: annotated_data_infinite2 projection=[a, b, d] physical_plan -01)AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[ARRAY_AGG(annotated_data_infinite2.d) ORDER BY [annotated_data_infinite2.d ASC NULLS LAST]], ordering_mode=Sorted -02)--PartialSortExec: expr=[a@0 ASC NULLS LAST,b@1 ASC NULLS LAST,d@2 ASC NULLS LAST], common_prefix_length=[2] +01)AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[array_agg(annotated_data_infinite2.d) ORDER BY [annotated_data_infinite2.d ASC NULLS LAST]], ordering_mode=Sorted +02)--PartialSortExec: expr=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, d@2 ASC NULLS LAST], common_prefix_length=[2] 03)----StreamingTableExec: partition_sizes=1, projection=[a, b, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST] # as can be seen in the result below d is indeed ordered. @@ -2459,13 +2454,13 @@ EXPLAIN SELECT country, (ARRAY_AGG(amount ORDER BY amount ASC)) AS amounts GROUP BY country ---- logical_plan -01)Projection: sales_global.country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS amounts -02)--Aggregate: groupBy=[[sales_global.country]], aggr=[[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] +01)Projection: sales_global.country, array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS amounts +02)--Aggregate: groupBy=[[sales_global.country]], aggr=[[array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] 03)----TableScan: sales_global projection=[country, amount] physical_plan -01)ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as amounts] -02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] -03)----SortExec: expr=[amount@1 ASC NULLS LAST] +01)ProjectionExec: expr=[country@0 as country, array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as amounts] +02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] +03)----SortExec: expr=[amount@1 ASC NULLS LAST], preserve_partitioning=[false] 04)------MemoryExec: partitions=1, partition_sizes=[1] @@ -2488,14 +2483,14 @@ EXPLAIN SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, GROUP BY s.country ---- logical_plan -01)Projection: s.country, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST] AS amounts, SUM(s.amount) AS sum1 -02)--Aggregate: groupBy=[[s.country]], aggr=[[ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST], SUM(CAST(s.amount AS Float64))]] +01)Projection: s.country, array_agg(s.amount) ORDER BY [s.amount DESC NULLS FIRST] AS amounts, sum(s.amount) AS sum1 +02)--Aggregate: groupBy=[[s.country]], aggr=[[array_agg(s.amount) ORDER BY [s.amount DESC NULLS FIRST], sum(CAST(s.amount AS Float64))]] 03)----SubqueryAlias: s 04)------TableScan: sales_global projection=[country, amount] physical_plan -01)ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST]@1 as amounts, SUM(s.amount)@2 as sum1] -02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST], SUM(s.amount)] -03)----SortExec: expr=[amount@1 DESC] +01)ProjectionExec: expr=[country@0 as country, array_agg(s.amount) ORDER BY [s.amount DESC NULLS FIRST]@1 as amounts, sum(s.amount)@2 as sum1] +02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[array_agg(s.amount) ORDER BY [s.amount DESC NULLS FIRST], sum(s.amount)] +03)----SortExec: expr=[amount@1 DESC], preserve_partitioning=[false] 04)------MemoryExec: partitions=1, partition_sizes=[1] query T?R rowsort @@ -2527,21 +2522,23 @@ EXPLAIN SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, SUM(s.amount) AS sum1 FROM (SELECT * FROM sales_global - ORDER BY country) AS s + ORDER BY country + LIMIT 10) AS s GROUP BY s.country ---- logical_plan -01)Projection: s.country, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST] AS amounts, SUM(s.amount) AS sum1 -02)--Aggregate: groupBy=[[s.country]], aggr=[[ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST], SUM(CAST(s.amount AS Float64))]] +01)Projection: s.country, array_agg(s.amount) ORDER BY [s.amount DESC NULLS FIRST] AS amounts, sum(s.amount) AS sum1 +02)--Aggregate: groupBy=[[s.country]], aggr=[[array_agg(s.amount) ORDER BY [s.amount DESC NULLS FIRST], sum(CAST(s.amount AS Float64))]] 03)----SubqueryAlias: s -04)------Sort: sales_global.country ASC NULLS LAST +04)------Sort: sales_global.country ASC NULLS LAST, fetch=10 05)--------TableScan: sales_global projection=[country, amount] physical_plan -01)ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST]@1 as amounts, SUM(s.amount)@2 as sum1] -02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST], SUM(s.amount)], ordering_mode=Sorted -03)----SortExec: expr=[country@0 ASC NULLS LAST,amount@1 DESC] +01)ProjectionExec: expr=[country@0 as country, array_agg(s.amount) ORDER BY [s.amount DESC NULLS FIRST]@1 as amounts, sum(s.amount)@2 as sum1] +02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[array_agg(s.amount) ORDER BY [s.amount DESC NULLS FIRST], sum(s.amount)], ordering_mode=Sorted +03)----SortExec: TopK(fetch=10), expr=[country@0 ASC NULLS LAST, amount@1 DESC], preserve_partitioning=[false] 04)------MemoryExec: partitions=1, partition_sizes=[1] + query T?R rowsort SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.amount DESC) AS amounts, SUM(s.amount) AS sum1 @@ -2563,19 +2560,20 @@ EXPLAIN SELECT s.country, s.zip_code, ARRAY_AGG(s.amount ORDER BY s.amount DESC) SUM(s.amount) AS sum1 FROM (SELECT * FROM sales_global - ORDER BY country) AS s + ORDER BY country + LIMIT 10) AS s GROUP BY s.country, s.zip_code ---- logical_plan -01)Projection: s.country, s.zip_code, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST] AS amounts, SUM(s.amount) AS sum1 -02)--Aggregate: groupBy=[[s.country, s.zip_code]], aggr=[[ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST], SUM(CAST(s.amount AS Float64))]] +01)Projection: s.country, s.zip_code, array_agg(s.amount) ORDER BY [s.amount DESC NULLS FIRST] AS amounts, sum(s.amount) AS sum1 +02)--Aggregate: groupBy=[[s.country, s.zip_code]], aggr=[[array_agg(s.amount) ORDER BY [s.amount DESC NULLS FIRST], sum(CAST(s.amount AS Float64))]] 03)----SubqueryAlias: s -04)------Sort: sales_global.country ASC NULLS LAST +04)------Sort: sales_global.country ASC NULLS LAST, fetch=10 05)--------TableScan: sales_global projection=[zip_code, country, amount] physical_plan -01)ProjectionExec: expr=[country@0 as country, zip_code@1 as zip_code, ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST]@2 as amounts, SUM(s.amount)@3 as sum1] -02)--AggregateExec: mode=Single, gby=[country@1 as country, zip_code@0 as zip_code], aggr=[ARRAY_AGG(s.amount) ORDER BY [s.amount DESC NULLS FIRST], SUM(s.amount)], ordering_mode=PartiallySorted([0]) -03)----SortExec: expr=[country@1 ASC NULLS LAST,amount@2 DESC] +01)ProjectionExec: expr=[country@0 as country, zip_code@1 as zip_code, array_agg(s.amount) ORDER BY [s.amount DESC NULLS FIRST]@2 as amounts, sum(s.amount)@3 as sum1] +02)--AggregateExec: mode=Single, gby=[country@1 as country, zip_code@0 as zip_code], aggr=[array_agg(s.amount) ORDER BY [s.amount DESC NULLS FIRST], sum(s.amount)], ordering_mode=PartiallySorted([0]) +03)----SortExec: TopK(fetch=10), expr=[country@1 ASC NULLS LAST, amount@2 DESC], preserve_partitioning=[false] 04)------MemoryExec: partitions=1, partition_sizes=[1] query TI?R rowsort @@ -2599,19 +2597,20 @@ EXPLAIN SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.country DESC) AS amounts SUM(s.amount) AS sum1 FROM (SELECT * FROM sales_global - ORDER BY country) AS s + ORDER BY country + LIMIT 10) AS s GROUP BY s.country ---- logical_plan -01)Projection: s.country, ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST] AS amounts, SUM(s.amount) AS sum1 -02)--Aggregate: groupBy=[[s.country]], aggr=[[ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST], SUM(CAST(s.amount AS Float64))]] +01)Projection: s.country, array_agg(s.amount) ORDER BY [s.country DESC NULLS FIRST] AS amounts, sum(s.amount) AS sum1 +02)--Aggregate: groupBy=[[s.country]], aggr=[[array_agg(s.amount) ORDER BY [s.country DESC NULLS FIRST], sum(CAST(s.amount AS Float64))]] 03)----SubqueryAlias: s -04)------Sort: sales_global.country ASC NULLS LAST +04)------Sort: sales_global.country ASC NULLS LAST, fetch=10 05)--------TableScan: sales_global projection=[country, amount] physical_plan -01)ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST]@1 as amounts, SUM(s.amount)@2 as sum1] -02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST], SUM(s.amount)], ordering_mode=Sorted -03)----SortExec: expr=[country@0 ASC NULLS LAST] +01)ProjectionExec: expr=[country@0 as country, array_agg(s.amount) ORDER BY [s.country DESC NULLS FIRST]@1 as amounts, sum(s.amount)@2 as sum1] +02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[array_agg(s.amount) ORDER BY [s.country DESC NULLS FIRST], sum(s.amount)], ordering_mode=Sorted +03)----SortExec: TopK(fetch=10), expr=[country@0 ASC NULLS LAST], preserve_partitioning=[false] 04)------MemoryExec: partitions=1, partition_sizes=[1] query T?R rowsort @@ -2634,21 +2633,23 @@ EXPLAIN SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.country DESC, s.amount D SUM(s.amount) AS sum1 FROM (SELECT * FROM sales_global - ORDER BY country) AS s + ORDER BY country + LIMIT 10) AS s GROUP BY s.country ---- logical_plan -01)Projection: s.country, ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST, s.amount DESC NULLS FIRST] AS amounts, SUM(s.amount) AS sum1 -02)--Aggregate: groupBy=[[s.country]], aggr=[[ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST, s.amount DESC NULLS FIRST], SUM(CAST(s.amount AS Float64))]] +01)Projection: s.country, array_agg(s.amount) ORDER BY [s.country DESC NULLS FIRST, s.amount DESC NULLS FIRST] AS amounts, sum(s.amount) AS sum1 +02)--Aggregate: groupBy=[[s.country]], aggr=[[array_agg(s.amount) ORDER BY [s.country DESC NULLS FIRST, s.amount DESC NULLS FIRST], sum(CAST(s.amount AS Float64))]] 03)----SubqueryAlias: s -04)------Sort: sales_global.country ASC NULLS LAST +04)------Sort: sales_global.country ASC NULLS LAST, fetch=10 05)--------TableScan: sales_global projection=[country, amount] physical_plan -01)ProjectionExec: expr=[country@0 as country, ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST, s.amount DESC NULLS FIRST]@1 as amounts, SUM(s.amount)@2 as sum1] -02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(s.amount) ORDER BY [s.country DESC NULLS FIRST, s.amount DESC NULLS FIRST], SUM(s.amount)], ordering_mode=Sorted -03)----SortExec: expr=[country@0 ASC NULLS LAST,amount@1 DESC] +01)ProjectionExec: expr=[country@0 as country, array_agg(s.amount) ORDER BY [s.country DESC NULLS FIRST, s.amount DESC NULLS FIRST]@1 as amounts, sum(s.amount)@2 as sum1] +02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[array_agg(s.amount) ORDER BY [s.country DESC NULLS FIRST, s.amount DESC NULLS FIRST], sum(s.amount)], ordering_mode=Sorted +03)----SortExec: TopK(fetch=10), expr=[country@0 ASC NULLS LAST, amount@1 DESC], preserve_partitioning=[false] 04)------MemoryExec: partitions=1, partition_sizes=[1] + query T?R rowsort SELECT s.country, ARRAY_AGG(s.amount ORDER BY s.country DESC, s.amount DESC) AS amounts, SUM(s.amount) AS sum1 @@ -2672,13 +2673,13 @@ EXPLAIN SELECT country, ARRAY_AGG(amount ORDER BY amount DESC) AS amounts, GROUP BY country ---- logical_plan -01)Projection: sales_global.country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS amounts, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS fv2 -02)--Aggregate: groupBy=[[sales_global.country]], aggr=[[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]]] +01)Projection: sales_global.country, array_agg(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS amounts, first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS fv1, last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS fv2 +02)--Aggregate: groupBy=[[sales_global.country]], aggr=[[array_agg(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]]] 03)----TableScan: sales_global projection=[country, amount] physical_plan -01)ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@1 as amounts, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] -02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]] -03)----SortExec: expr=[amount@1 DESC] +01)ProjectionExec: expr=[country@0 as country, array_agg(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@1 as amounts, first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] +02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[array_agg(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]] +03)----SortExec: expr=[amount@1 DESC], preserve_partitioning=[false] 04)------MemoryExec: partitions=1, partition_sizes=[1] query T?RR rowsort @@ -2703,13 +2704,13 @@ EXPLAIN SELECT country, ARRAY_AGG(amount ORDER BY amount ASC) AS amounts, GROUP BY country ---- logical_plan -01)Projection: sales_global.country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS amounts, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS fv2 -02)--Aggregate: groupBy=[[sales_global.country]], aggr=[[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]]] +01)Projection: sales_global.country, array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS amounts, first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS fv1, last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS fv2 +02)--Aggregate: groupBy=[[sales_global.country]], aggr=[[array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]]] 03)----TableScan: sales_global projection=[country, amount] physical_plan -01)ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as amounts, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] -02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] -03)----SortExec: expr=[amount@1 ASC NULLS LAST] +01)ProjectionExec: expr=[country@0 as country, array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as amounts, first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] +02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] +03)----SortExec: expr=[amount@1 ASC NULLS LAST], preserve_partitioning=[false] 04)------MemoryExec: partitions=1, partition_sizes=[1] query T?RR @@ -2735,13 +2736,13 @@ EXPLAIN SELECT country, FIRST_VALUE(amount ORDER BY amount ASC) AS fv1, GROUP BY country ---- logical_plan -01)Projection: sales_global.country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS fv2, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS amounts -02)--Aggregate: groupBy=[[sales_global.country]], aggr=[[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] +01)Projection: sales_global.country, first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS fv1, last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS fv2, array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS amounts +02)--Aggregate: groupBy=[[sales_global.country]], aggr=[[first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] 03)----TableScan: sales_global projection=[country, amount] physical_plan -01)ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@2 as fv2, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@3 as amounts] -02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] -03)----SortExec: expr=[amount@1 ASC NULLS LAST] +01)ProjectionExec: expr=[country@0 as country, first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as fv1, last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@2 as fv2, array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@3 as amounts] +02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] +03)----SortExec: expr=[amount@1 ASC NULLS LAST], preserve_partitioning=[false] 04)------MemoryExec: partitions=1, partition_sizes=[1] query TRR? @@ -2765,13 +2766,13 @@ EXPLAIN SELECT country, SUM(amount ORDER BY ts DESC) AS sum1, GROUP BY country ---- logical_plan -01)Projection: sales_global.country, SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS sum1, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS amounts -02)--Aggregate: groupBy=[[sales_global.country]], aggr=[[SUM(CAST(sales_global.amount AS Float64)) ORDER BY [sales_global.ts DESC NULLS FIRST], ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] +01)Projection: sales_global.country, sum(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS sum1, array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS amounts +02)--Aggregate: groupBy=[[sales_global.country]], aggr=[[sum(CAST(sales_global.amount AS Float64)) ORDER BY [sales_global.ts DESC NULLS FIRST], array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] 03)----TableScan: sales_global projection=[country, ts, amount] physical_plan -01)ProjectionExec: expr=[country@0 as country, SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as sum1, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as amounts] -02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST], ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] -03)----SortExec: expr=[amount@2 ASC NULLS LAST] +01)ProjectionExec: expr=[country@0 as country, sum(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as sum1, array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as amounts] +02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[sum(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST], array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] +03)----SortExec: expr=[amount@2 ASC NULLS LAST], preserve_partitioning=[false] 04)------MemoryExec: partitions=1, partition_sizes=[1] query TR? @@ -2799,13 +2800,12 @@ EXPLAIN SELECT country, FIRST_VALUE(amount ORDER BY ts DESC) as fv1, GROUP BY country ---- logical_plan -01)Projection: sales_global.country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS lv1, SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS sum1 -02)--Aggregate: groupBy=[[sales_global.country]], aggr=[[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST], SUM(CAST(sales_global.amount AS Float64)) ORDER BY [sales_global.ts DESC NULLS FIRST]]] -03)----Sort: sales_global.ts ASC NULLS LAST -04)------TableScan: sales_global projection=[country, ts, amount] +01)Projection: sales_global.country, first_value(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS fv1, last_value(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS lv1, sum(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS sum1 +02)--Aggregate: groupBy=[[sales_global.country]], aggr=[[first_value(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST], last_value(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST], sum(CAST(sales_global.amount AS Float64)) ORDER BY [sales_global.ts DESC NULLS FIRST]]] +03)----TableScan: sales_global projection=[country, ts, amount] physical_plan -01)ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@2 as lv1, SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@3 as sum1] -02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST], SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]] +01)ProjectionExec: expr=[country@0 as country, first_value(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as fv1, last_value(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@2 as lv1, sum(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@3 as sum1] +02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[first_value(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST], last_value(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST], sum(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]] 03)----MemoryExec: partitions=1, partition_sizes=[1] query TRRR rowsort @@ -2833,12 +2833,12 @@ EXPLAIN SELECT country, FIRST_VALUE(amount ORDER BY ts DESC) as fv1, GROUP BY country ---- logical_plan -01)Projection: sales_global.country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS lv1, SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS sum1 -02)--Aggregate: groupBy=[[sales_global.country]], aggr=[[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST], SUM(CAST(sales_global.amount AS Float64)) ORDER BY [sales_global.ts DESC NULLS FIRST]]] +01)Projection: sales_global.country, first_value(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS fv1, last_value(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS lv1, sum(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS sum1 +02)--Aggregate: groupBy=[[sales_global.country]], aggr=[[first_value(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST], last_value(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST], sum(CAST(sales_global.amount AS Float64)) ORDER BY [sales_global.ts DESC NULLS FIRST]]] 03)----TableScan: sales_global projection=[country, ts, amount] physical_plan -01)ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@2 as lv1, SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@3 as sum1] -02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST], SUM(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]] +01)ProjectionExec: expr=[country@0 as country, first_value(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as fv1, last_value(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@2 as lv1, sum(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@3 as sum1] +02)--AggregateExec: mode=Single, gby=[country@0 as country], aggr=[first_value(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST], last_value(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST], sum(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]] 03)----MemoryExec: partitions=1, partition_sizes=[1] query TRRR rowsort @@ -2863,8 +2863,8 @@ ORDER BY s.sn ---- logical_plan 01)Sort: s.sn ASC NULLS LAST -02)--Projection: s.zip_code, s.country, s.sn, s.ts, s.currency, LAST_VALUE(e.amount) ORDER BY [e.sn ASC NULLS LAST] AS last_rate -03)----Aggregate: groupBy=[[s.sn, s.zip_code, s.country, s.ts, s.currency]], aggr=[[LAST_VALUE(e.amount) ORDER BY [e.sn ASC NULLS LAST]]] +02)--Projection: s.zip_code, s.country, s.sn, s.ts, s.currency, last_value(e.amount) ORDER BY [e.sn ASC NULLS LAST] AS last_rate +03)----Aggregate: groupBy=[[s.sn, s.zip_code, s.country, s.ts, s.currency]], aggr=[[last_value(e.amount) ORDER BY [e.sn ASC NULLS LAST]]] 04)------Projection: s.zip_code, s.country, s.sn, s.ts, s.currency, e.sn, e.amount 05)--------Inner Join: s.currency = e.currency Filter: s.ts >= e.ts 06)----------SubqueryAlias: s @@ -2872,9 +2872,9 @@ logical_plan 08)----------SubqueryAlias: e 09)------------TableScan: sales_global projection=[sn, ts, currency, amount] physical_plan -01)SortExec: expr=[sn@2 ASC NULLS LAST] -02)--ProjectionExec: expr=[zip_code@1 as zip_code, country@2 as country, sn@0 as sn, ts@3 as ts, currency@4 as currency, LAST_VALUE(e.amount) ORDER BY [e.sn ASC NULLS LAST]@5 as last_rate] -03)----AggregateExec: mode=Single, gby=[sn@2 as sn, zip_code@0 as zip_code, country@1 as country, ts@3 as ts, currency@4 as currency], aggr=[LAST_VALUE(e.amount) ORDER BY [e.sn ASC NULLS LAST]] +01)SortExec: expr=[sn@2 ASC NULLS LAST], preserve_partitioning=[false] +02)--ProjectionExec: expr=[zip_code@1 as zip_code, country@2 as country, sn@0 as sn, ts@3 as ts, currency@4 as currency, last_value(e.amount) ORDER BY [e.sn ASC NULLS LAST]@5 as last_rate] +03)----AggregateExec: mode=Single, gby=[sn@2 as sn, zip_code@0 as zip_code, country@1 as country, ts@3 as ts, currency@4 as currency], aggr=[last_value(e.amount) ORDER BY [e.sn ASC NULLS LAST]] 04)------ProjectionExec: expr=[zip_code@2 as zip_code, country@3 as country, sn@4 as sn, ts@5 as ts, currency@6 as currency, sn@0 as sn, amount@1 as amount] 05)--------CoalesceBatchesExec: target_batch_size=8192 06)----------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(currency@2, currency@4)], filter=ts@0 >= ts@1, projection=[sn@0, amount@3, zip_code@4, country@5, sn@6, ts@7, currency@8] @@ -2912,18 +2912,18 @@ EXPLAIN SELECT country, FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, ---- logical_plan 01)Sort: sales_global.country ASC NULLS LAST -02)--Projection: sales_global.country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST] AS fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST] AS fv2 -03)----Aggregate: groupBy=[[sales_global.country]], aggr=[[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]]] +02)--Projection: sales_global.country, first_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST] AS fv1, last_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST] AS fv2 +03)----Aggregate: groupBy=[[sales_global.country]], aggr=[[first_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST], last_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]]] 04)------TableScan: sales_global projection=[country, ts, amount] physical_plan 01)SortPreservingMergeExec: [country@0 ASC NULLS LAST] -02)--SortExec: expr=[country@0 ASC NULLS LAST] -03)----ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@2 as fv2] -04)------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]] +02)--SortExec: expr=[country@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[country@0 as country, first_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@1 as fv1, last_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@2 as fv2] +04)------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[first_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST], last_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]] 05)--------CoalesceBatchesExec: target_batch_size=8192 06)----------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 07)------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 -08)--------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]] +08)--------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[first_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST], last_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]] 09)----------------MemoryExec: partitions=1, partition_sizes=[1] query TRR @@ -2948,20 +2948,21 @@ EXPLAIN SELECT country, FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, ---- logical_plan 01)Sort: sales_global.country ASC NULLS LAST -02)--Projection: sales_global.country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST] AS fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS fv2 -03)----Aggregate: groupBy=[[sales_global.country]], aggr=[[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]]] +02)--Projection: sales_global.country, first_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST] AS fv1, last_value(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS fv2 +03)----Aggregate: groupBy=[[sales_global.country]], aggr=[[first_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST], last_value(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]]] 04)------TableScan: sales_global projection=[country, ts, amount] physical_plan 01)SortPreservingMergeExec: [country@0 ASC NULLS LAST] -02)--SortExec: expr=[country@0 ASC NULLS LAST] -03)----ProjectionExec: expr=[country@0 as country, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@1 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@2 as fv2] -04)------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]] +02)--SortExec: expr=[country@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[country@0 as country, first_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@1 as fv1, last_value(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@2 as fv2] +04)------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[first_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST], last_value(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]] 05)--------CoalesceBatchesExec: target_batch_size=8192 06)----------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 07)------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 -08)--------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]] +08)--------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[first_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST], last_value(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]] 09)----------------MemoryExec: partitions=1, partition_sizes=[1] + query TRR SELECT country, FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, LAST_VALUE(amount ORDER BY ts DESC) AS fv2 @@ -2986,14 +2987,14 @@ EXPLAIN SELECT FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, FROM sales_global ---- logical_plan -01)Projection: FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST] AS fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST] AS fv2 -02)--Aggregate: groupBy=[[]], aggr=[[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]]] +01)Projection: first_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST] AS fv1, last_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST] AS fv2 +02)--Aggregate: groupBy=[[]], aggr=[[first_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST], last_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]]] 03)----TableScan: sales_global projection=[ts, amount] physical_plan -01)ProjectionExec: expr=[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@0 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@1 as fv2] -02)--AggregateExec: mode=Final, gby=[], aggr=[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]] +01)ProjectionExec: expr=[first_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@0 as fv1, last_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@1 as fv2] +02)--AggregateExec: mode=Final, gby=[], aggr=[first_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST], last_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]] 03)----CoalescePartitionsExec -04)------AggregateExec: mode=Partial, gby=[], aggr=[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]] +04)------AggregateExec: mode=Partial, gby=[], aggr=[first_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST], last_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]] 05)--------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 06)----------MemoryExec: partitions=1, partition_sizes=[1] @@ -3012,14 +3013,14 @@ EXPLAIN SELECT FIRST_VALUE(amount ORDER BY ts ASC) AS fv1, FROM sales_global ---- logical_plan -01)Projection: FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST] AS fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS fv2 -02)--Aggregate: groupBy=[[]], aggr=[[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]]] +01)Projection: first_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST] AS fv1, last_value(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS fv2 +02)--Aggregate: groupBy=[[]], aggr=[[first_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST], last_value(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]]] 03)----TableScan: sales_global projection=[ts, amount] physical_plan -01)ProjectionExec: expr=[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@0 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as fv2] -02)--AggregateExec: mode=Final, gby=[], aggr=[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]] +01)ProjectionExec: expr=[first_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@0 as fv1, last_value(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@1 as fv2] +02)--AggregateExec: mode=Final, gby=[], aggr=[first_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST], last_value(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]] 03)----CoalescePartitionsExec -04)------AggregateExec: mode=Partial, gby=[], aggr=[FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]] +04)------AggregateExec: mode=Partial, gby=[], aggr=[first_value(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST], last_value(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]] 05)--------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 06)----------MemoryExec: partitions=1, partition_sizes=[1] @@ -3036,15 +3037,15 @@ EXPLAIN SELECT ARRAY_AGG(amount ORDER BY ts ASC) AS array_agg1 FROM sales_global ---- logical_plan -01)Projection: ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST] AS array_agg1 -02)--Aggregate: groupBy=[[]], aggr=[[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]]] +01)Projection: array_agg(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST] AS array_agg1 +02)--Aggregate: groupBy=[[]], aggr=[[array_agg(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]]] 03)----TableScan: sales_global projection=[ts, amount] physical_plan -01)ProjectionExec: expr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@0 as array_agg1] -02)--AggregateExec: mode=Final, gby=[], aggr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]] +01)ProjectionExec: expr=[array_agg(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]@0 as array_agg1] +02)--AggregateExec: mode=Final, gby=[], aggr=[array_agg(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]] 03)----CoalescePartitionsExec -04)------AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]] -05)--------SortExec: expr=[ts@0 ASC NULLS LAST] +04)------AggregateExec: mode=Partial, gby=[], aggr=[array_agg(sales_global.amount) ORDER BY [sales_global.ts ASC NULLS LAST]] +05)--------SortExec: expr=[ts@0 ASC NULLS LAST], preserve_partitioning=[true] 06)----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 07)------------MemoryExec: partitions=1, partition_sizes=[1] @@ -3060,15 +3061,15 @@ EXPLAIN SELECT ARRAY_AGG(amount ORDER BY ts DESC) AS array_agg1 FROM sales_global ---- logical_plan -01)Projection: ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS array_agg1 -02)--Aggregate: groupBy=[[]], aggr=[[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]]] +01)Projection: array_agg(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST] AS array_agg1 +02)--Aggregate: groupBy=[[]], aggr=[[array_agg(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]]] 03)----TableScan: sales_global projection=[ts, amount] physical_plan -01)ProjectionExec: expr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@0 as array_agg1] -02)--AggregateExec: mode=Final, gby=[], aggr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]] +01)ProjectionExec: expr=[array_agg(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]@0 as array_agg1] +02)--AggregateExec: mode=Final, gby=[], aggr=[array_agg(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]] 03)----CoalescePartitionsExec -04)------AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]] -05)--------SortExec: expr=[ts@0 DESC] +04)------AggregateExec: mode=Partial, gby=[], aggr=[array_agg(sales_global.amount) ORDER BY [sales_global.ts DESC NULLS FIRST]] +05)--------SortExec: expr=[ts@0 DESC], preserve_partitioning=[true] 06)----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 07)------------MemoryExec: partitions=1, partition_sizes=[1] @@ -3084,15 +3085,15 @@ EXPLAIN SELECT ARRAY_AGG(amount ORDER BY amount ASC) AS array_agg1 FROM sales_global ---- logical_plan -01)Projection: ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS array_agg1 -02)--Aggregate: groupBy=[[]], aggr=[[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] +01)Projection: array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS array_agg1 +02)--Aggregate: groupBy=[[]], aggr=[[array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] 03)----TableScan: sales_global projection=[amount] physical_plan -01)ProjectionExec: expr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@0 as array_agg1] -02)--AggregateExec: mode=Final, gby=[], aggr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] +01)ProjectionExec: expr=[array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@0 as array_agg1] +02)--AggregateExec: mode=Final, gby=[], aggr=[array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] 03)----CoalescePartitionsExec -04)------AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] -05)--------SortExec: expr=[amount@0 ASC NULLS LAST] +04)------AggregateExec: mode=Partial, gby=[], aggr=[array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] +05)--------SortExec: expr=[amount@0 ASC NULLS LAST], preserve_partitioning=[true] 06)----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 07)------------MemoryExec: partitions=1, partition_sizes=[1] @@ -3111,18 +3112,18 @@ EXPLAIN SELECT country, ARRAY_AGG(amount ORDER BY amount ASC) AS array_agg1 ---- logical_plan 01)Sort: sales_global.country ASC NULLS LAST -02)--Projection: sales_global.country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS array_agg1 -03)----Aggregate: groupBy=[[sales_global.country]], aggr=[[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] +02)--Projection: sales_global.country, array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS array_agg1 +03)----Aggregate: groupBy=[[sales_global.country]], aggr=[[array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]]] 04)------TableScan: sales_global projection=[country, amount] physical_plan 01)SortPreservingMergeExec: [country@0 ASC NULLS LAST] -02)--SortExec: expr=[country@0 ASC NULLS LAST] -03)----ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as array_agg1] -04)------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] +02)--SortExec: expr=[country@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[country@0 as country, array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@1 as array_agg1] +04)------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] 05)--------CoalesceBatchesExec: target_batch_size=4 06)----------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 -07)------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] -08)--------------SortExec: expr=[amount@1 ASC NULLS LAST] +07)------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[array_agg(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]] +08)--------------SortExec: expr=[amount@1 ASC NULLS LAST], preserve_partitioning=[true] 09)----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 10)------------------MemoryExec: partitions=1, partition_sizes=[1] @@ -3147,18 +3148,18 @@ EXPLAIN SELECT country, ARRAY_AGG(amount ORDER BY amount DESC) AS amounts, ---- logical_plan 01)Sort: sales_global.country ASC NULLS LAST -02)--Projection: sales_global.country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS amounts, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS fv2 -03)----Aggregate: groupBy=[[sales_global.country]], aggr=[[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]]] +02)--Projection: sales_global.country, array_agg(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS amounts, first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST] AS fv1, last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST] AS fv2 +03)----Aggregate: groupBy=[[sales_global.country]], aggr=[[array_agg(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]]] 04)------TableScan: sales_global projection=[country, amount] physical_plan 01)SortPreservingMergeExec: [country@0 ASC NULLS LAST] -02)--SortExec: expr=[country@0 ASC NULLS LAST] -03)----ProjectionExec: expr=[country@0 as country, ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@1 as amounts, FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] -04)------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], FIRST_VALUE(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]] +02)--SortExec: expr=[country@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[country@0 as country, array_agg(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@1 as amounts, first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST]@2 as fv1, last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]@3 as fv2] +04)------AggregateExec: mode=FinalPartitioned, gby=[country@0 as country], aggr=[array_agg(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], first_value(sales_global.amount) ORDER BY [sales_global.amount ASC NULLS LAST], last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]] 05)--------CoalesceBatchesExec: target_batch_size=4 06)----------RepartitionExec: partitioning=Hash([country@0], 8), input_partitions=8 -07)------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[ARRAY_AGG(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], LAST_VALUE(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]] -08)--------------SortExec: expr=[amount@1 DESC] +07)------------AggregateExec: mode=Partial, gby=[country@0 as country], aggr=[array_agg(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST], last_value(sales_global.amount) ORDER BY [sales_global.amount DESC NULLS FIRST]] +08)--------------SortExec: expr=[amount@1 DESC], preserve_partitioning=[true] 09)----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 10)------------------MemoryExec: partitions=1, partition_sizes=[1] @@ -3353,13 +3354,14 @@ logical_plan 05)--------TableScan: sales_global_with_pk projection=[sn, amount] physical_plan 01)SortPreservingMergeExec: [sn@0 ASC NULLS LAST] -02)--SortExec: expr=[sn@0 ASC NULLS LAST] +02)--SortExec: expr=[sn@0 ASC NULLS LAST], preserve_partitioning=[true] 03)----ProjectionExec: expr=[sn@0 as sn, amount@1 as amount, 2 * CAST(sn@0 AS Int64) as Int64(2) * s.sn] 04)------AggregateExec: mode=FinalPartitioned, gby=[sn@0 as sn, amount@1 as amount], aggr=[] 05)--------CoalesceBatchesExec: target_batch_size=4 06)----------RepartitionExec: partitioning=Hash([sn@0, amount@1], 8), input_partitions=8 07)------------AggregateExec: mode=Partial, gby=[sn@0 as sn, amount@1 as amount], aggr=[] -08)--------------MemoryExec: partitions=8, partition_sizes=[1, 0, 0, 0, 0, 0, 0, 0] +08)--------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +09)----------------MemoryExec: partitions=1, partition_sizes=[1] query IRI SELECT s.sn, s.amount, 2*s.sn @@ -3411,8 +3413,8 @@ EXPLAIN SELECT r.sn, SUM(l.amount), r.amount ---- logical_plan 01)Sort: r.sn ASC NULLS LAST -02)--Projection: r.sn, SUM(l.amount), r.amount -03)----Aggregate: groupBy=[[r.sn, r.amount]], aggr=[[SUM(CAST(l.amount AS Float64))]] +02)--Projection: r.sn, sum(l.amount), r.amount +03)----Aggregate: groupBy=[[r.sn, r.amount]], aggr=[[sum(CAST(l.amount AS Float64))]] 04)------Projection: l.amount, r.sn, r.amount 05)--------Inner Join: Filter: l.sn >= r.sn 06)----------SubqueryAlias: l @@ -3421,17 +3423,17 @@ logical_plan 09)------------TableScan: sales_global_with_pk projection=[sn, amount] physical_plan 01)SortPreservingMergeExec: [sn@0 ASC NULLS LAST] -02)--SortExec: expr=[sn@0 ASC NULLS LAST] -03)----ProjectionExec: expr=[sn@0 as sn, SUM(l.amount)@2 as SUM(l.amount), amount@1 as amount] -04)------AggregateExec: mode=FinalPartitioned, gby=[sn@0 as sn, amount@1 as amount], aggr=[SUM(l.amount)] +02)--SortExec: expr=[sn@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[sn@0 as sn, sum(l.amount)@2 as sum(l.amount), amount@1 as amount] +04)------AggregateExec: mode=FinalPartitioned, gby=[sn@0 as sn, amount@1 as amount], aggr=[sum(l.amount)] 05)--------CoalesceBatchesExec: target_batch_size=4 06)----------RepartitionExec: partitioning=Hash([sn@0, amount@1], 8), input_partitions=8 -07)------------AggregateExec: mode=Partial, gby=[sn@1 as sn, amount@2 as amount], aggr=[SUM(l.amount)] +07)------------AggregateExec: mode=Partial, gby=[sn@1 as sn, amount@2 as amount], aggr=[sum(l.amount)] 08)--------------ProjectionExec: expr=[amount@1 as amount, sn@2 as sn, amount@3 as amount] 09)----------------NestedLoopJoinExec: join_type=Inner, filter=sn@0 >= sn@1 -10)------------------CoalescePartitionsExec -11)--------------------MemoryExec: partitions=8, partition_sizes=[1, 0, 0, 0, 0, 0, 0, 0] -12)------------------MemoryExec: partitions=8, partition_sizes=[1, 0, 0, 0, 0, 0, 0, 0] +10)------------------MemoryExec: partitions=1, partition_sizes=[1] +11)------------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +12)--------------------MemoryExec: partitions=1, partition_sizes=[1] query IRR SELECT r.sn, SUM(l.amount), r.amount @@ -3467,7 +3469,7 @@ SELECT r.sn, SUM(l.amount), r.amount # to associate it with other fields, aggregate should contain all the composite columns # if any of the composite column is missing, we cannot use associated indices, inside select expression # below query should fail -statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression r.amount could not be resolved from available columns: r.sn, SUM\(l.amount\) +statement error DataFusion error: Error during planning: Projection references non\-aggregate values: Expression r\.amount could not be resolved from available columns: r\.sn, sum\(l\.amount\) SELECT r.sn, SUM(l.amount), r.amount FROM sales_global_with_composite_pk AS l JOIN sales_global_with_composite_pk AS r @@ -3495,7 +3497,7 @@ NULL NULL NULL # left join shouldn't propagate right side constraint, # if right side is a unique key (unique and can contain null) # Please note that, above query and this one is same except the constraint in the table. -statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression r.amount could not be resolved from available columns: r.sn, SUM\(r.amount\) +statement error DataFusion error: Error during planning: Projection references non\-aggregate values: Expression r\.amount could not be resolved from available columns: r\.sn, sum\(r\.amount\) SELECT r.sn, r.amount, SUM(r.amount) FROM (SELECT * FROM sales_global_with_unique as l @@ -3541,7 +3543,7 @@ SELECT column1, COUNT(*) as column2 FROM (VALUES (['a', 'b'], 1), (['c', 'd', 'e # primary key should be aware from which columns it is associated -statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression r.sn could not be resolved from available columns: l.sn, l.zip_code, l.country, l.ts, l.currency, l.amount, SUM\(l.amount\) +statement error DataFusion error: Error during planning: Projection references non\-aggregate values: Expression r\.sn could not be resolved from available columns: l\.sn, l\.zip_code, l\.country, l\.ts, l\.currency, l\.amount, sum\(l\.amount\) SELECT l.sn, r.sn, SUM(l.amount), r.amount FROM sales_global_with_pk AS l JOIN sales_global_with_pk AS r @@ -3563,23 +3565,22 @@ logical_plan 02)--Projection: l.zip_code, l.country, l.sn, l.ts, l.currency, l.amount, l.sum_amount 03)----Aggregate: groupBy=[[l.sn, l.zip_code, l.country, l.ts, l.currency, l.amount, l.sum_amount]], aggr=[[]] 04)------SubqueryAlias: l -05)--------Projection: l.zip_code, l.country, l.sn, l.ts, l.currency, l.amount, SUM(l.amount) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS sum_amount -06)----------WindowAggr: windowExpr=[[SUM(CAST(l.amount AS Float64)) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] +05)--------Projection: l.zip_code, l.country, l.sn, l.ts, l.currency, l.amount, sum(l.amount) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS sum_amount +06)----------WindowAggr: windowExpr=[[sum(CAST(l.amount AS Float64)) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] 07)------------SubqueryAlias: l 08)--------------TableScan: sales_global_with_pk projection=[zip_code, country, sn, ts, currency, amount] physical_plan 01)SortPreservingMergeExec: [sn@2 ASC NULLS LAST] -02)--SortExec: expr=[sn@2 ASC NULLS LAST] +02)--SortExec: expr=[sn@2 ASC NULLS LAST], preserve_partitioning=[true] 03)----ProjectionExec: expr=[zip_code@1 as zip_code, country@2 as country, sn@0 as sn, ts@3 as ts, currency@4 as currency, amount@5 as amount, sum_amount@6 as sum_amount] 04)------AggregateExec: mode=FinalPartitioned, gby=[sn@0 as sn, zip_code@1 as zip_code, country@2 as country, ts@3 as ts, currency@4 as currency, amount@5 as amount, sum_amount@6 as sum_amount], aggr=[] 05)--------CoalesceBatchesExec: target_batch_size=4 06)----------RepartitionExec: partitioning=Hash([sn@0, zip_code@1, country@2, ts@3, currency@4, amount@5, sum_amount@6], 8), input_partitions=8 07)------------AggregateExec: mode=Partial, gby=[sn@2 as sn, zip_code@0 as zip_code, country@1 as country, ts@3 as ts, currency@4 as currency, amount@5 as amount, sum_amount@6 as sum_amount], aggr=[] 08)--------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 -09)----------------ProjectionExec: expr=[zip_code@0 as zip_code, country@1 as country, sn@2 as sn, ts@3 as ts, currency@4 as currency, amount@5 as amount, SUM(l.amount) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@6 as sum_amount] -10)------------------BoundedWindowAggExec: wdw=[SUM(l.amount) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(l.amount) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -11)--------------------CoalescePartitionsExec -12)----------------------MemoryExec: partitions=8, partition_sizes=[1, 0, 0, 0, 0, 0, 0, 0] +09)----------------ProjectionExec: expr=[zip_code@0 as zip_code, country@1 as country, sn@2 as sn, ts@3 as ts, currency@4 as currency, amount@5 as amount, sum(l.amount) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@6 as sum_amount] +10)------------------BoundedWindowAggExec: wdw=[sum(l.amount) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(l.amount) ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +11)--------------------MemoryExec: partitions=1, partition_sizes=[1] query ITIPTRR @@ -3633,7 +3634,7 @@ ORDER BY r.sn 4 100 2022-01-03T10:00:00 # after join, new window expressions shouldn't be associated with primary keys -statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression rn1 could not be resolved from available columns: r.sn, r.ts, r.amount, SUM\(r.amount\) +statement error DataFusion error: Error during planning: Projection references non\-aggregate values: Expression rn1 could not be resolved from available columns: r\.sn, r\.ts, r\.amount, sum\(r\.amount\) SELECT r.sn, SUM(r.amount), rn1 FROM (SELECT r.ts, r.sn, r.amount, @@ -3755,12 +3756,12 @@ EXPLAIN SELECT LAST_VALUE(x) FROM FOO; ---- logical_plan -01)Aggregate: groupBy=[[]], aggr=[[LAST_VALUE(foo.x)]] +01)Aggregate: groupBy=[[]], aggr=[[last_value(foo.x)]] 02)--TableScan: foo projection=[x] physical_plan -01)AggregateExec: mode=Final, gby=[], aggr=[LAST_VALUE(foo.x)] +01)AggregateExec: mode=Final, gby=[], aggr=[last_value(foo.x)] 02)--CoalescePartitionsExec -03)----AggregateExec: mode=Partial, gby=[], aggr=[LAST_VALUE(foo.x)] +03)----AggregateExec: mode=Partial, gby=[], aggr=[last_value(foo.x)] 04)------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 05)--------MemoryExec: partitions=1, partition_sizes=[1] @@ -3777,12 +3778,12 @@ EXPLAIN SELECT FIRST_VALUE(x) FROM FOO; ---- logical_plan -01)Aggregate: groupBy=[[]], aggr=[[FIRST_VALUE(foo.x)]] +01)Aggregate: groupBy=[[]], aggr=[[first_value(foo.x)]] 02)--TableScan: foo projection=[x] physical_plan -01)AggregateExec: mode=Final, gby=[], aggr=[FIRST_VALUE(foo.x)] +01)AggregateExec: mode=Final, gby=[], aggr=[first_value(foo.x)] 02)--CoalescePartitionsExec -03)----AggregateExec: mode=Partial, gby=[], aggr=[FIRST_VALUE(foo.x)] +03)----AggregateExec: mode=Partial, gby=[], aggr=[first_value(foo.x)] 04)------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 05)--------MemoryExec: partitions=1, partition_sizes=[1] @@ -3795,15 +3796,15 @@ FROM multiple_ordered_table GROUP BY d; ---- logical_plan -01)Projection: FIRST_VALUE(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.a ASC NULLS LAST] AS first_a, LAST_VALUE(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] AS last_c -02)--Aggregate: groupBy=[[multiple_ordered_table.d]], aggr=[[FIRST_VALUE(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.a ASC NULLS LAST], LAST_VALUE(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST]]] +01)Projection: first_value(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.a ASC NULLS LAST] AS first_a, last_value(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] AS last_c +02)--Aggregate: groupBy=[[multiple_ordered_table.d]], aggr=[[first_value(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.a ASC NULLS LAST], last_value(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST]]] 03)----TableScan: multiple_ordered_table projection=[a, c, d] physical_plan -01)ProjectionExec: expr=[FIRST_VALUE(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.a ASC NULLS LAST]@1 as first_a, LAST_VALUE(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST]@2 as last_c] -02)--AggregateExec: mode=FinalPartitioned, gby=[d@0 as d], aggr=[FIRST_VALUE(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.a ASC NULLS LAST], LAST_VALUE(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST]] +01)ProjectionExec: expr=[first_value(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.a ASC NULLS LAST]@1 as first_a, last_value(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST]@2 as last_c] +02)--AggregateExec: mode=FinalPartitioned, gby=[d@0 as d], aggr=[first_value(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.a ASC NULLS LAST], last_value(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST]] 03)----CoalesceBatchesExec: target_batch_size=2 04)------RepartitionExec: partitioning=Hash([d@0], 8), input_partitions=8 -05)--------AggregateExec: mode=Partial, gby=[d@2 as d], aggr=[FIRST_VALUE(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.a ASC NULLS LAST], FIRST_VALUE(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]] +05)--------AggregateExec: mode=Partial, gby=[d@2 as d], aggr=[first_value(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.a ASC NULLS LAST], first_value(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]] 06)----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 07)------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true @@ -3858,24 +3859,24 @@ ORDER BY row_n logical_plan 01)Projection: amount_usd 02)--Sort: row_n ASC NULLS LAST -03)----Projection: LAST_VALUE(l.d) ORDER BY [l.a ASC NULLS LAST] AS amount_usd, row_n -04)------Aggregate: groupBy=[[row_n]], aggr=[[LAST_VALUE(l.d) ORDER BY [l.a ASC NULLS LAST]]] +03)----Projection: last_value(l.d) ORDER BY [l.a ASC NULLS LAST] AS amount_usd, row_n +04)------Aggregate: groupBy=[[row_n]], aggr=[[last_value(l.d) ORDER BY [l.a ASC NULLS LAST]]] 05)--------Projection: l.a, l.d, row_n 06)----------Inner Join: l.d = r.d Filter: CAST(l.a AS Int64) >= CAST(r.a AS Int64) - Int64(10) 07)------------SubqueryAlias: l 08)--------------TableScan: multiple_ordered_table projection=[a, d] -09)------------Projection: r.a, r.d, ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS row_n -10)--------------WindowAggr: windowExpr=[[ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +09)------------Projection: r.a, r.d, row_number() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS row_n +10)--------------WindowAggr: windowExpr=[[row_number() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] 11)----------------SubqueryAlias: r 12)------------------TableScan: multiple_ordered_table projection=[a, d] physical_plan -01)ProjectionExec: expr=[LAST_VALUE(l.d) ORDER BY [l.a ASC NULLS LAST]@1 as amount_usd] -02)--AggregateExec: mode=Single, gby=[row_n@2 as row_n], aggr=[LAST_VALUE(l.d) ORDER BY [l.a ASC NULLS LAST]], ordering_mode=Sorted +01)ProjectionExec: expr=[last_value(l.d) ORDER BY [l.a ASC NULLS LAST]@1 as amount_usd] +02)--AggregateExec: mode=Single, gby=[row_n@2 as row_n], aggr=[last_value(l.d) ORDER BY [l.a ASC NULLS LAST]], ordering_mode=Sorted 03)----CoalesceBatchesExec: target_batch_size=2 04)------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(d@1, d@1)], filter=CAST(a@0 AS Int64) >= CAST(a@1 AS Int64) - 10, projection=[a@0, d@1, row_n@4] 05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true -06)--------ProjectionExec: expr=[a@0 as a, d@1 as d, ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as row_n] -07)----------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +06)--------ProjectionExec: expr=[a@0 as a, d@1 as d, row_number() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as row_n] +07)----------BoundedWindowAggExec: wdw=[row_number() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "row_number() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] 08)------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true # reset partition number to 8. @@ -3894,10 +3895,10 @@ CREATE EXTERNAL TABLE multiple_ordered_table_with_pk ( primary key(c) ) STORED AS CSV -WITH HEADER ROW WITH ORDER (a ASC, b ASC) WITH ORDER (c ASC) -LOCATION '../core/tests/data/window_2.csv'; +LOCATION '../core/tests/data/window_2.csv' +OPTIONS ('format.has_header' 'true'); # We can use column b during selection # even if it is not among group by expressions @@ -3908,14 +3909,14 @@ FROM multiple_ordered_table_with_pk GROUP BY c; ---- logical_plan -01)Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +01)Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[sum(CAST(multiple_ordered_table_with_pk.d AS Int64))]] 02)--TableScan: multiple_ordered_table_with_pk projection=[b, c, d] physical_plan -01)AggregateExec: mode=FinalPartitioned, gby=[c@0 as c, b@1 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) -02)--SortExec: expr=[c@0 ASC NULLS LAST] +01)AggregateExec: mode=FinalPartitioned, gby=[c@0 as c, b@1 as b], aggr=[sum(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +02)--SortExec: expr=[c@0 ASC NULLS LAST], preserve_partitioning=[true] 03)----CoalesceBatchesExec: target_batch_size=2 04)------RepartitionExec: partitioning=Hash([c@0, b@1], 8), input_partitions=8 -05)--------AggregateExec: mode=Partial, gby=[c@1 as c, b@0 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +05)--------AggregateExec: mode=Partial, gby=[c@1 as c, b@0 as b], aggr=[sum(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) 06)----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 07)------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true @@ -3935,10 +3936,10 @@ CREATE EXTERNAL TABLE multiple_ordered_table_with_pk ( d INTEGER ) STORED AS CSV -WITH HEADER ROW WITH ORDER (a ASC, b ASC) WITH ORDER (c ASC) -LOCATION '../core/tests/data/window_2.csv'; +LOCATION '../core/tests/data/window_2.csv' +OPTIONS ('format.has_header' 'true'); # We can use column b during selection # even if it is not among group by expressions @@ -3949,14 +3950,14 @@ FROM multiple_ordered_table_with_pk GROUP BY c; ---- logical_plan -01)Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +01)Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[sum(CAST(multiple_ordered_table_with_pk.d AS Int64))]] 02)--TableScan: multiple_ordered_table_with_pk projection=[b, c, d] physical_plan -01)AggregateExec: mode=FinalPartitioned, gby=[c@0 as c, b@1 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) -02)--SortExec: expr=[c@0 ASC NULLS LAST] +01)AggregateExec: mode=FinalPartitioned, gby=[c@0 as c, b@1 as b], aggr=[sum(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +02)--SortExec: expr=[c@0 ASC NULLS LAST], preserve_partitioning=[true] 03)----CoalesceBatchesExec: target_batch_size=2 04)------RepartitionExec: partitioning=Hash([c@0, b@1], 8), input_partitions=8 -05)--------AggregateExec: mode=Partial, gby=[c@1 as c, b@0 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +05)--------AggregateExec: mode=Partial, gby=[c@1 as c, b@0 as b], aggr=[sum(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) 06)----------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 07)------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true @@ -3973,13 +3974,13 @@ GROUP BY c; ---- logical_plan 01)Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, sum1]], aggr=[[]] -02)--Projection: multiple_ordered_table_with_pk.c, SUM(multiple_ordered_table_with_pk.d) AS sum1 -03)----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +02)--Projection: multiple_ordered_table_with_pk.c, sum(multiple_ordered_table_with_pk.d) AS sum1 +03)----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c]], aggr=[[sum(CAST(multiple_ordered_table_with_pk.d AS Int64))]] 04)------TableScan: multiple_ordered_table_with_pk projection=[c, d] physical_plan 01)AggregateExec: mode=Single, gby=[c@0 as c, sum1@1 as sum1], aggr=[], ordering_mode=PartiallySorted([0]) -02)--ProjectionExec: expr=[c@0 as c, SUM(multiple_ordered_table_with_pk.d)@1 as sum1] -03)----AggregateExec: mode=Single, gby=[c@0 as c], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +02)--ProjectionExec: expr=[c@0 as c, sum(multiple_ordered_table_with_pk.d)@1 as sum1] +03)----AggregateExec: mode=Single, gby=[c@0 as c], aggr=[sum(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted 04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], output_ordering=[c@0 ASC NULLS LAST], has_header=true query TT @@ -3990,16 +3991,16 @@ EXPLAIN SELECT c, sum1, SUM(b) OVER() as sumb GROUP BY c); ---- logical_plan -01)Projection: multiple_ordered_table_with_pk.c, sum1, SUM(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS sumb -02)--WindowAggr: windowExpr=[[SUM(CAST(multiple_ordered_table_with_pk.b AS Int64)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] -03)----Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b, SUM(multiple_ordered_table_with_pk.d) AS sum1 -04)------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +01)Projection: multiple_ordered_table_with_pk.c, sum1, sum(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS sumb +02)--WindowAggr: windowExpr=[[sum(CAST(multiple_ordered_table_with_pk.b AS Int64)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +03)----Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b, sum(multiple_ordered_table_with_pk.d) AS sum1 +04)------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[sum(CAST(multiple_ordered_table_with_pk.d AS Int64))]] 05)--------TableScan: multiple_ordered_table_with_pk projection=[b, c, d] physical_plan -01)ProjectionExec: expr=[c@0 as c, sum1@2 as sum1, SUM(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@3 as sumb] -02)--WindowAggExec: wdw=[SUM(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] -03)----ProjectionExec: expr=[c@0 as c, b@1 as b, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] -04)------AggregateExec: mode=Single, gby=[c@1 as c, b@0 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +01)ProjectionExec: expr=[c@0 as c, sum1@2 as sum1, sum(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@3 as sumb] +02)--WindowAggExec: wdw=[sum(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(multiple_ordered_table_with_pk.b) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] +03)----ProjectionExec: expr=[c@0 as c, b@1 as b, sum(multiple_ordered_table_with_pk.d)@2 as sum1] +04)------AggregateExec: mode=Single, gby=[c@1 as c, b@0 as b], aggr=[sum(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) 05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true query TT @@ -4018,22 +4019,22 @@ logical_plan 01)Projection: lhs.c, rhs.c, lhs.sum1, rhs.sum1 02)--Inner Join: lhs.b = rhs.b 03)----SubqueryAlias: lhs -04)------Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b, SUM(multiple_ordered_table_with_pk.d) AS sum1 -05)--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +04)------Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b, sum(multiple_ordered_table_with_pk.d) AS sum1 +05)--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[sum(CAST(multiple_ordered_table_with_pk.d AS Int64))]] 06)----------TableScan: multiple_ordered_table_with_pk projection=[b, c, d] 07)----SubqueryAlias: rhs -08)------Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b, SUM(multiple_ordered_table_with_pk.d) AS sum1 -09)--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +08)------Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b, sum(multiple_ordered_table_with_pk.d) AS sum1 +09)--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.b]], aggr=[[sum(CAST(multiple_ordered_table_with_pk.d AS Int64))]] 10)----------TableScan: multiple_ordered_table_with_pk projection=[b, c, d] physical_plan 01)ProjectionExec: expr=[c@0 as c, c@2 as c, sum1@1 as sum1, sum1@3 as sum1] 02)--CoalesceBatchesExec: target_batch_size=2 03)----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(b@1, b@1)], projection=[c@0, sum1@2, c@3, sum1@5] -04)------ProjectionExec: expr=[c@0 as c, b@1 as b, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] -05)--------AggregateExec: mode=Single, gby=[c@1 as c, b@0 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +04)------ProjectionExec: expr=[c@0 as c, b@1 as b, sum(multiple_ordered_table_with_pk.d)@2 as sum1] +05)--------AggregateExec: mode=Single, gby=[c@1 as c, b@0 as b], aggr=[sum(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) 06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true -07)------ProjectionExec: expr=[c@0 as c, b@1 as b, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] -08)--------AggregateExec: mode=Single, gby=[c@1 as c, b@0 as b], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) +07)------ProjectionExec: expr=[c@0 as c, b@1 as b, sum(multiple_ordered_table_with_pk.d)@2 as sum1] +08)--------AggregateExec: mode=Single, gby=[c@1 as c, b@0 as b], aggr=[sum(multiple_ordered_table_with_pk.d)], ordering_mode=PartiallySorted([0]) 09)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[b, c, d], output_ordering=[c@1 ASC NULLS LAST], has_header=true query TT @@ -4049,23 +4050,23 @@ EXPLAIN SELECT lhs.c, rhs.c, lhs.sum1, rhs.sum1 ---- logical_plan 01)Projection: lhs.c, rhs.c, lhs.sum1, rhs.sum1 -02)--CrossJoin: +02)--Cross Join: 03)----SubqueryAlias: lhs -04)------Projection: multiple_ordered_table_with_pk.c, SUM(multiple_ordered_table_with_pk.d) AS sum1 -05)--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +04)------Projection: multiple_ordered_table_with_pk.c, sum(multiple_ordered_table_with_pk.d) AS sum1 +05)--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c]], aggr=[[sum(CAST(multiple_ordered_table_with_pk.d AS Int64))]] 06)----------TableScan: multiple_ordered_table_with_pk projection=[c, d] 07)----SubqueryAlias: rhs -08)------Projection: multiple_ordered_table_with_pk.c, SUM(multiple_ordered_table_with_pk.d) AS sum1 -09)--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +08)------Projection: multiple_ordered_table_with_pk.c, sum(multiple_ordered_table_with_pk.d) AS sum1 +09)--------Aggregate: groupBy=[[multiple_ordered_table_with_pk.c]], aggr=[[sum(CAST(multiple_ordered_table_with_pk.d AS Int64))]] 10)----------TableScan: multiple_ordered_table_with_pk projection=[c, d] physical_plan 01)ProjectionExec: expr=[c@0 as c, c@2 as c, sum1@1 as sum1, sum1@3 as sum1] 02)--CrossJoinExec -03)----ProjectionExec: expr=[c@0 as c, SUM(multiple_ordered_table_with_pk.d)@1 as sum1] -04)------AggregateExec: mode=Single, gby=[c@0 as c], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +03)----ProjectionExec: expr=[c@0 as c, sum(multiple_ordered_table_with_pk.d)@1 as sum1] +04)------AggregateExec: mode=Single, gby=[c@0 as c], aggr=[sum(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted 05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], output_ordering=[c@0 ASC NULLS LAST], has_header=true -06)----ProjectionExec: expr=[c@0 as c, SUM(multiple_ordered_table_with_pk.d)@1 as sum1] -07)------AggregateExec: mode=Single, gby=[c@0 as c], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +06)----ProjectionExec: expr=[c@0 as c, sum(multiple_ordered_table_with_pk.d)@1 as sum1] +07)------AggregateExec: mode=Single, gby=[c@0 as c], aggr=[sum(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted 08)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], output_ordering=[c@0 ASC NULLS LAST], has_header=true # we do not generate physical plan for Repartition yet (e.g Distribute By queries). @@ -4077,9 +4078,9 @@ FROM (SELECT c, b, a, SUM(d) as sum1 DISTRIBUTE BY a ---- logical_plan -01)Repartition: DistributeBy(a) -02)--Projection: multiple_ordered_table_with_pk.a, multiple_ordered_table_with_pk.b, SUM(multiple_ordered_table_with_pk.d) AS sum1 -03)----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, multiple_ordered_table_with_pk.b]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +01)Repartition: DistributeBy(multiple_ordered_table_with_pk.a) +02)--Projection: multiple_ordered_table_with_pk.a, multiple_ordered_table_with_pk.b, sum(multiple_ordered_table_with_pk.d) AS sum1 +03)----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, multiple_ordered_table_with_pk.b]], aggr=[[sum(CAST(multiple_ordered_table_with_pk.d AS Int64))]] 04)------TableScan: multiple_ordered_table_with_pk projection=[a, b, c, d] # union with aggregate @@ -4094,19 +4095,19 @@ UNION ALL ---- logical_plan 01)Union -02)--Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, SUM(multiple_ordered_table_with_pk.d) AS sum1 -03)----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +02)--Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, sum(multiple_ordered_table_with_pk.d) AS sum1 +03)----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[sum(CAST(multiple_ordered_table_with_pk.d AS Int64))]] 04)------TableScan: multiple_ordered_table_with_pk projection=[a, c, d] -05)--Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, SUM(multiple_ordered_table_with_pk.d) AS sum1 -06)----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +05)--Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, sum(multiple_ordered_table_with_pk.d) AS sum1 +06)----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[sum(CAST(multiple_ordered_table_with_pk.d AS Int64))]] 07)------TableScan: multiple_ordered_table_with_pk projection=[a, c, d] physical_plan 01)UnionExec -02)--ProjectionExec: expr=[c@0 as c, a@1 as a, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] -03)----AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +02)--ProjectionExec: expr=[c@0 as c, a@1 as a, sum(multiple_ordered_table_with_pk.d)@2 as sum1] +03)----AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[sum(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted 04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true -05)--ProjectionExec: expr=[c@0 as c, a@1 as a, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] -06)----AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +05)--ProjectionExec: expr=[c@0 as c, a@1 as a, sum(multiple_ordered_table_with_pk.d)@2 as sum1] +06)----AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[sum(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted 07)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true # table scan should be simplified. @@ -4116,12 +4117,12 @@ EXPLAIN SELECT c, a, SUM(d) as sum1 GROUP BY c ---- logical_plan -01)Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, SUM(multiple_ordered_table_with_pk.d) AS sum1 -02)--Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +01)Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, sum(multiple_ordered_table_with_pk.d) AS sum1 +02)--Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[sum(CAST(multiple_ordered_table_with_pk.d AS Int64))]] 03)----TableScan: multiple_ordered_table_with_pk projection=[a, c, d] physical_plan -01)ProjectionExec: expr=[c@0 as c, a@1 as a, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] -02)--AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +01)ProjectionExec: expr=[c@0 as c, a@1 as a, sum(multiple_ordered_table_with_pk.d)@2 as sum1] +02)--AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[sum(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted 03)----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true # limit should be simplified @@ -4133,14 +4134,14 @@ EXPLAIN SELECT * LIMIT 5) ---- logical_plan -01)Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, SUM(multiple_ordered_table_with_pk.d) AS sum1 +01)Projection: multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a, sum(multiple_ordered_table_with_pk.d) AS sum1 02)--Limit: skip=0, fetch=5 -03)----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[SUM(CAST(multiple_ordered_table_with_pk.d AS Int64))]] +03)----Aggregate: groupBy=[[multiple_ordered_table_with_pk.c, multiple_ordered_table_with_pk.a]], aggr=[[sum(CAST(multiple_ordered_table_with_pk.d AS Int64))]] 04)------TableScan: multiple_ordered_table_with_pk projection=[a, c, d] physical_plan -01)ProjectionExec: expr=[c@0 as c, a@1 as a, SUM(multiple_ordered_table_with_pk.d)@2 as sum1] +01)ProjectionExec: expr=[c@0 as c, a@1 as a, sum(multiple_ordered_table_with_pk.d)@2 as sum1] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[SUM(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted +03)----AggregateExec: mode=Single, gby=[c@1 as c, a@0 as a], aggr=[sum(multiple_ordered_table_with_pk.d)], ordering_mode=Sorted 04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true statement ok @@ -4169,39 +4170,39 @@ query TT EXPLAIN SELECT SUM(DISTINCT CAST(x AS DOUBLE)), MAX(DISTINCT x) FROM t1 GROUP BY y; ---- logical_plan -01)Projection: SUM(DISTINCT t1.x), MAX(DISTINCT t1.x) -02)--Aggregate: groupBy=[[t1.y]], aggr=[[SUM(DISTINCT CAST(t1.x AS Float64)), MAX(DISTINCT t1.x)]] +01)Projection: sum(DISTINCT t1.x), max(DISTINCT t1.x) +02)--Aggregate: groupBy=[[t1.y]], aggr=[[sum(DISTINCT CAST(t1.x AS Float64)), max(DISTINCT t1.x)]] 03)----TableScan: t1 projection=[x, y] physical_plan -01)ProjectionExec: expr=[SUM(DISTINCT t1.x)@1 as SUM(DISTINCT t1.x), MAX(DISTINCT t1.x)@2 as MAX(DISTINCT t1.x)] -02)--AggregateExec: mode=FinalPartitioned, gby=[y@0 as y], aggr=[SUM(DISTINCT t1.x), MAX(DISTINCT t1.x)] +01)ProjectionExec: expr=[sum(DISTINCT t1.x)@1 as sum(DISTINCT t1.x), max(DISTINCT t1.x)@2 as max(DISTINCT t1.x)] +02)--AggregateExec: mode=FinalPartitioned, gby=[y@0 as y], aggr=[sum(DISTINCT t1.x), max(DISTINCT t1.x)] 03)----CoalesceBatchesExec: target_batch_size=2 04)------RepartitionExec: partitioning=Hash([y@0], 8), input_partitions=8 05)--------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 -06)----------AggregateExec: mode=Partial, gby=[y@1 as y], aggr=[SUM(DISTINCT t1.x), MAX(DISTINCT t1.x)] +06)----------AggregateExec: mode=Partial, gby=[y@1 as y], aggr=[sum(DISTINCT t1.x), max(DISTINCT t1.x)] 07)------------MemoryExec: partitions=1, partition_sizes=[1] query TT EXPLAIN SELECT SUM(DISTINCT CAST(x AS DOUBLE)), MAX(DISTINCT CAST(x AS DOUBLE)) FROM t1 GROUP BY y; ---- logical_plan -01)Projection: SUM(alias1) AS SUM(DISTINCT t1.x), MAX(alias1) AS MAX(DISTINCT t1.x) -02)--Aggregate: groupBy=[[t1.y]], aggr=[[SUM(alias1), MAX(alias1)]] -03)----Aggregate: groupBy=[[t1.y, CAST(t1.x AS Float64)t1.x AS t1.x AS alias1]], aggr=[[]] -04)------Projection: CAST(t1.x AS Float64) AS CAST(t1.x AS Float64)t1.x, t1.y +01)Projection: sum(alias1) AS sum(DISTINCT t1.x), max(alias1) AS max(DISTINCT t1.x) +02)--Aggregate: groupBy=[[t1.y]], aggr=[[sum(alias1), max(alias1)]] +03)----Aggregate: groupBy=[[t1.y, __common_expr_1 AS t1.x AS alias1]], aggr=[[]] +04)------Projection: CAST(t1.x AS Float64) AS __common_expr_1, t1.y 05)--------TableScan: t1 projection=[x, y] physical_plan -01)ProjectionExec: expr=[SUM(alias1)@1 as SUM(DISTINCT t1.x), MAX(alias1)@2 as MAX(DISTINCT t1.x)] -02)--AggregateExec: mode=FinalPartitioned, gby=[y@0 as y], aggr=[SUM(alias1), MAX(alias1)] +01)ProjectionExec: expr=[sum(alias1)@1 as sum(DISTINCT t1.x), max(alias1)@2 as max(DISTINCT t1.x)] +02)--AggregateExec: mode=FinalPartitioned, gby=[y@0 as y], aggr=[sum(alias1), max(alias1)] 03)----CoalesceBatchesExec: target_batch_size=2 04)------RepartitionExec: partitioning=Hash([y@0], 8), input_partitions=8 -05)--------AggregateExec: mode=Partial, gby=[y@0 as y], aggr=[SUM(alias1), MAX(alias1)] +05)--------AggregateExec: mode=Partial, gby=[y@0 as y], aggr=[sum(alias1), max(alias1)] 06)----------AggregateExec: mode=FinalPartitioned, gby=[y@0 as y, alias1@1 as alias1], aggr=[] 07)------------CoalesceBatchesExec: target_batch_size=2 08)--------------RepartitionExec: partitioning=Hash([y@0, alias1@1], 8), input_partitions=8 09)----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 -10)------------------AggregateExec: mode=Partial, gby=[y@1 as y, CAST(t1.x AS Float64)t1.x@0 as alias1], aggr=[] -11)--------------------ProjectionExec: expr=[CAST(x@0 AS Float64) as CAST(t1.x AS Float64)t1.x, y@1 as y] +10)------------------AggregateExec: mode=Partial, gby=[y@1 as y, __common_expr_1@0 as alias1], aggr=[] +11)--------------------ProjectionExec: expr=[CAST(x@0 AS Float64) as __common_expr_1, y@1 as y] 12)----------------------MemoryExec: partitions=1, partition_sizes=[1] # create an unbounded table that contains ordered timestamp. @@ -4223,21 +4224,19 @@ EXPLAIN SELECT date_bin('15 minutes', ts) as time_chunks LIMIT 5; ---- logical_plan -01)Limit: skip=0, fetch=5 -02)--Sort: time_chunks DESC NULLS FIRST, fetch=5 -03)----Projection: date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts) AS time_chunks -04)------Aggregate: groupBy=[[date_bin(IntervalMonthDayNano("900000000000"), unbounded_csv_with_timestamps.ts) AS date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)]], aggr=[[]] -05)--------TableScan: unbounded_csv_with_timestamps projection=[ts] +01)Sort: time_chunks DESC NULLS FIRST, fetch=5 +02)--Projection: date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts) AS time_chunks +03)----Aggregate: groupBy=[[date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 900000000000 }"), unbounded_csv_with_timestamps.ts) AS date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)]], aggr=[[]] +04)------TableScan: unbounded_csv_with_timestamps projection=[ts] physical_plan -01)GlobalLimitExec: skip=0, fetch=5 -02)--SortPreservingMergeExec: [time_chunks@0 DESC], fetch=5 -03)----ProjectionExec: expr=[date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0 as time_chunks] -04)------AggregateExec: mode=FinalPartitioned, gby=[date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0 as date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)], aggr=[], ordering_mode=Sorted -05)--------CoalesceBatchesExec: target_batch_size=2 -06)----------RepartitionExec: partitioning=Hash([date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0], 8), input_partitions=8, preserve_order=true, sort_exprs=date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0 DESC -07)------------AggregateExec: mode=Partial, gby=[date_bin(900000000000, ts@0) as date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)], aggr=[], ordering_mode=Sorted -08)--------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 -09)----------------StreamingTableExec: partition_sizes=1, projection=[ts], infinite_source=true, output_ordering=[ts@0 DESC] +01)SortPreservingMergeExec: [time_chunks@0 DESC], fetch=5 +02)--ProjectionExec: expr=[date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0 as time_chunks] +03)----AggregateExec: mode=FinalPartitioned, gby=[date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0 as date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)], aggr=[], ordering_mode=Sorted +04)------CoalesceBatchesExec: target_batch_size=2 +05)--------RepartitionExec: partitioning=Hash([date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0], 8), input_partitions=8, preserve_order=true, sort_exprs=date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)@0 DESC +06)----------AggregateExec: mode=Partial, gby=[date_bin(IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 900000000000 }, ts@0) as date_bin(Utf8("15 minutes"),unbounded_csv_with_timestamps.ts)], aggr=[], ordering_mode=Sorted +07)------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +08)--------------StreamingTableExec: partition_sizes=1, projection=[ts], infinite_source=true, output_ordering=[ts@0 DESC] query P SELECT date_bin('15 minutes', ts) as time_chunks @@ -4266,7 +4265,8 @@ CREATE EXTERNAL TABLE csv_with_timestamps ( ) STORED AS CSV WITH ORDER (ts DESC) -LOCATION '../core/tests/data/timestamps.csv'; +LOCATION '../core/tests/data/timestamps.csv' +OPTIONS('format.has_header' 'false'); # below query should run since it operates on a bounded source and have a sort # at the top of its plan. @@ -4278,22 +4278,20 @@ EXPLAIN SELECT extract(month from ts) as months LIMIT 5; ---- logical_plan -01)Limit: skip=0, fetch=5 -02)--Sort: months DESC NULLS FIRST, fetch=5 -03)----Projection: date_part(Utf8("MONTH"),csv_with_timestamps.ts) AS months -04)------Aggregate: groupBy=[[date_part(Utf8("MONTH"), csv_with_timestamps.ts)]], aggr=[[]] -05)--------TableScan: csv_with_timestamps projection=[ts] +01)Sort: months DESC NULLS FIRST, fetch=5 +02)--Projection: date_part(Utf8("MONTH"),csv_with_timestamps.ts) AS months +03)----Aggregate: groupBy=[[date_part(Utf8("MONTH"), csv_with_timestamps.ts)]], aggr=[[]] +04)------TableScan: csv_with_timestamps projection=[ts] physical_plan -01)GlobalLimitExec: skip=0, fetch=5 -02)--SortPreservingMergeExec: [months@0 DESC], fetch=5 -03)----SortExec: TopK(fetch=5), expr=[months@0 DESC] -04)------ProjectionExec: expr=[date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0 as months] -05)--------AggregateExec: mode=FinalPartitioned, gby=[date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0 as date_part(Utf8("MONTH"),csv_with_timestamps.ts)], aggr=[] -06)----------CoalesceBatchesExec: target_batch_size=2 -07)------------RepartitionExec: partitioning=Hash([date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0], 8), input_partitions=8 -08)--------------AggregateExec: mode=Partial, gby=[date_part(MONTH, ts@0) as date_part(Utf8("MONTH"),csv_with_timestamps.ts)], aggr=[] -09)----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 -10)------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/timestamps.csv]]}, projection=[ts], output_ordering=[ts@0 DESC], has_header=false +01)SortPreservingMergeExec: [months@0 DESC], fetch=5 +02)--SortExec: TopK(fetch=5), expr=[months@0 DESC], preserve_partitioning=[true] +03)----ProjectionExec: expr=[date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0 as months] +04)------AggregateExec: mode=FinalPartitioned, gby=[date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0 as date_part(Utf8("MONTH"),csv_with_timestamps.ts)], aggr=[] +05)--------CoalesceBatchesExec: target_batch_size=2 +06)----------RepartitionExec: partitioning=Hash([date_part(Utf8("MONTH"),csv_with_timestamps.ts)@0], 8), input_partitions=8 +07)------------AggregateExec: mode=Partial, gby=[date_part(MONTH, ts@0) as date_part(Utf8("MONTH"),csv_with_timestamps.ts)], aggr=[] +08)--------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +09)----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/timestamps.csv]]}, projection=[ts], output_ordering=[ts@0 DESC], has_header=false query R SELECT extract(month from ts) as months @@ -4326,16 +4324,14 @@ EXPLAIN SELECT name, date_bin('15 minutes', ts) as time_chunks LIMIT 5; ---- logical_plan -01)Limit: skip=0, fetch=5 -02)--Sort: unbounded_csv_with_timestamps2.name DESC NULLS FIRST, time_chunks DESC NULLS FIRST, fetch=5 -03)----Projection: unbounded_csv_with_timestamps2.name, date_bin(IntervalMonthDayNano("900000000000"), unbounded_csv_with_timestamps2.ts) AS time_chunks -04)------TableScan: unbounded_csv_with_timestamps2 projection=[name, ts] +01)Sort: unbounded_csv_with_timestamps2.name DESC NULLS FIRST, time_chunks DESC NULLS FIRST, fetch=5 +02)--Projection: unbounded_csv_with_timestamps2.name, date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 900000000000 }"), unbounded_csv_with_timestamps2.ts) AS time_chunks +03)----TableScan: unbounded_csv_with_timestamps2 projection=[name, ts] physical_plan -01)GlobalLimitExec: skip=0, fetch=5 -02)--SortPreservingMergeExec: [name@0 DESC,time_chunks@1 DESC], fetch=5 -03)----ProjectionExec: expr=[name@0 as name, date_bin(900000000000, ts@1) as time_chunks] -04)------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 -05)--------StreamingTableExec: partition_sizes=1, projection=[name, ts], infinite_source=true, output_ordering=[name@0 DESC, ts@1 DESC] +01)SortPreservingMergeExec: [name@0 DESC, time_chunks@1 DESC], fetch=5 +02)--ProjectionExec: expr=[name@0 as name, date_bin(IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 900000000000 }, ts@1) as time_chunks] +03)----RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 +04)------StreamingTableExec: partition_sizes=1, projection=[name, ts], infinite_source=true, output_ordering=[name@0 DESC, ts@1 DESC] statement ok drop table t1 @@ -4377,8 +4373,8 @@ CREATE EXTERNAL TABLE aggregate_test_100 ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); query TIIII SELECT c1, count(distinct c2), min(distinct c2), min(c3), max(c4) FROM aggregate_test_100 GROUP BY c1 ORDER BY c1; @@ -4394,18 +4390,18 @@ EXPLAIN SELECT c1, count(distinct c2), min(distinct c2), sum(c3), max(c4) FROM a ---- logical_plan 01)Sort: aggregate_test_100.c1 ASC NULLS LAST -02)--Projection: aggregate_test_100.c1, COUNT(alias1) AS COUNT(DISTINCT aggregate_test_100.c2), MIN(alias1) AS MIN(DISTINCT aggregate_test_100.c2), SUM(alias2) AS SUM(aggregate_test_100.c3), MAX(alias3) AS MAX(aggregate_test_100.c4) -03)----Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[COUNT(alias1), MIN(alias1), SUM(alias2), MAX(alias3)]] -04)------Aggregate: groupBy=[[aggregate_test_100.c1, aggregate_test_100.c2 AS alias1]], aggr=[[SUM(CAST(aggregate_test_100.c3 AS Int64)) AS alias2, MAX(aggregate_test_100.c4) AS alias3]] +02)--Projection: aggregate_test_100.c1, count(alias1) AS count(DISTINCT aggregate_test_100.c2), min(alias1) AS min(DISTINCT aggregate_test_100.c2), sum(alias2) AS sum(aggregate_test_100.c3), max(alias3) AS max(aggregate_test_100.c4) +03)----Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[count(alias1), min(alias1), sum(alias2), max(alias3)]] +04)------Aggregate: groupBy=[[aggregate_test_100.c1, aggregate_test_100.c2 AS alias1]], aggr=[[sum(CAST(aggregate_test_100.c3 AS Int64)) AS alias2, max(aggregate_test_100.c4) AS alias3]] 05)--------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4] physical_plan 01)SortPreservingMergeExec: [c1@0 ASC NULLS LAST] -02)--SortExec: expr=[c1@0 ASC NULLS LAST] -03)----ProjectionExec: expr=[c1@0 as c1, COUNT(alias1)@1 as COUNT(DISTINCT aggregate_test_100.c2), MIN(alias1)@2 as MIN(DISTINCT aggregate_test_100.c2), SUM(alias2)@3 as SUM(aggregate_test_100.c3), MAX(alias3)@4 as MAX(aggregate_test_100.c4)] -04)------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[COUNT(alias1), MIN(alias1), SUM(alias2), MAX(alias3)] +02)--SortExec: expr=[c1@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[c1@0 as c1, count(alias1)@1 as count(DISTINCT aggregate_test_100.c2), min(alias1)@2 as min(DISTINCT aggregate_test_100.c2), sum(alias2)@3 as sum(aggregate_test_100.c3), max(alias3)@4 as max(aggregate_test_100.c4)] +04)------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[count(alias1), min(alias1), sum(alias2), max(alias3)] 05)--------CoalesceBatchesExec: target_batch_size=2 06)----------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 -07)------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[COUNT(alias1), MIN(alias1), SUM(alias2), MAX(alias3)] +07)------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[count(alias1), min(alias1), sum(alias2), max(alias3)] 08)--------------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1, alias1@1 as alias1], aggr=[alias2, alias3] 09)----------------CoalesceBatchesExec: target_batch_size=2 10)------------------RepartitionExec: partitioning=Hash([c1@0, alias1@1], 8), input_partitions=8 @@ -4454,10 +4450,10 @@ CREATE EXTERNAL TABLE unbounded_multiple_ordered_table_with_pk ( d INTEGER ) STORED AS CSV -WITH HEADER ROW WITH ORDER (a ASC, b ASC) WITH ORDER (c ASC) -LOCATION '../core/tests/data/window_2.csv'; +LOCATION '../core/tests/data/window_2.csv' +OPTIONS ('format.has_header' 'true'); # Query below can be executed, since c is primary key. query III rowsort @@ -4489,7 +4485,7 @@ LIMIT 5 statement ok CREATE TABLE src_table ( t1 TIMESTAMP, - c2 INT, + c2 INT ) AS VALUES ('2020-12-10T00:00:00.00Z', 0), ('2020-12-11T00:00:00.00Z', 1), @@ -4503,28 +4499,28 @@ CREATE TABLE src_table ( ('2020-12-19T00:00:00.00Z', 9); # Use src_table to create a partitioned file -query PI +query I COPY (SELECT * FROM src_table) TO 'test_files/scratch/group_by/timestamp_table/0.csv' STORED AS CSV; ---- 10 -query PI +query I COPY (SELECT * FROM src_table) TO 'test_files/scratch/group_by/timestamp_table/1.csv' STORED AS CSV; ---- 10 -query PI +query I COPY (SELECT * FROM src_table) TO 'test_files/scratch/group_by/timestamp_table/2.csv' STORED AS CSV; ---- 10 -query PI +query I COPY (SELECT * FROM src_table) TO 'test_files/scratch/group_by/timestamp_table/3.csv' STORED AS CSV; @@ -4538,8 +4534,8 @@ CREATE EXTERNAL TABLE timestamp_table ( c2 INT, ) STORED AS CSV -WITH HEADER ROW -LOCATION 'test_files/scratch/group_by/timestamp_table'; +LOCATION 'test_files/scratch/group_by/timestamp_table' +OPTIONS ('format.has_header' 'true'); # Group By using date_trunc query PI rowsort @@ -4573,20 +4569,18 @@ ORDER BY MAX(t1) DESC LIMIT 4; ---- logical_plan -01)Limit: skip=0, fetch=4 -02)--Sort: MAX(timestamp_table.t1) DESC NULLS FIRST, fetch=4 -03)----Aggregate: groupBy=[[timestamp_table.c2]], aggr=[[MAX(timestamp_table.t1)]] -04)------TableScan: timestamp_table projection=[t1, c2] +01)Sort: max(timestamp_table.t1) DESC NULLS FIRST, fetch=4 +02)--Aggregate: groupBy=[[timestamp_table.c2]], aggr=[[max(timestamp_table.t1)]] +03)----TableScan: timestamp_table projection=[t1, c2] physical_plan -01)GlobalLimitExec: skip=0, fetch=4 -02)--SortPreservingMergeExec: [MAX(timestamp_table.t1)@1 DESC], fetch=4 -03)----SortExec: TopK(fetch=4), expr=[MAX(timestamp_table.t1)@1 DESC] -04)------AggregateExec: mode=FinalPartitioned, gby=[c2@0 as c2], aggr=[MAX(timestamp_table.t1)], lim=[4] -05)--------CoalesceBatchesExec: target_batch_size=2 -06)----------RepartitionExec: partitioning=Hash([c2@0], 8), input_partitions=8 -07)------------AggregateExec: mode=Partial, gby=[c2@1 as c2], aggr=[MAX(timestamp_table.t1)], lim=[4] -08)--------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=4 -09)----------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/group_by/timestamp_table/0.csv], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/group_by/timestamp_table/1.csv], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/group_by/timestamp_table/2.csv], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/group_by/timestamp_table/3.csv]]}, projection=[t1, c2], has_header=true +01)SortPreservingMergeExec: [max(timestamp_table.t1)@1 DESC], fetch=4 +02)--SortExec: TopK(fetch=4), expr=[max(timestamp_table.t1)@1 DESC], preserve_partitioning=[true] +03)----AggregateExec: mode=FinalPartitioned, gby=[c2@0 as c2], aggr=[max(timestamp_table.t1)], lim=[4] +04)------CoalesceBatchesExec: target_batch_size=2 +05)--------RepartitionExec: partitioning=Hash([c2@0], 8), input_partitions=8 +06)----------AggregateExec: mode=Partial, gby=[c2@1 as c2], aggr=[max(timestamp_table.t1)], lim=[4] +07)------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=4 +08)--------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/group_by/timestamp_table/0.csv], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/group_by/timestamp_table/1.csv], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/group_by/timestamp_table/2.csv], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/group_by/timestamp_table/3.csv]]}, projection=[t1, c2], has_header=true # Clean up statement ok @@ -4611,7 +4605,7 @@ DROP TABLE timestamp_table; # Table with an int column and Dict column: statement ok -CREATE TABLE int8_dict AS VALUES +CREATE TABLE int8_dict AS VALUES (1, arrow_cast('A', 'Dictionary(Int8, Utf8)')), (2, arrow_cast('B', 'Dictionary(Int8, Utf8)')), (2, arrow_cast('A', 'Dictionary(Int8, Utf8)')), @@ -4620,7 +4614,7 @@ CREATE TABLE int8_dict AS VALUES (1, arrow_cast('A', 'Dictionary(Int8, Utf8)')); # Group by the non-dict column -query ?I rowsort +query TI rowsort SELECT column2, count(column1) FROM int8_dict GROUP BY column2; ---- A 4 @@ -4649,7 +4643,7 @@ DROP TABLE int8_dict; # Table with an int column and Dict column: statement ok -CREATE TABLE int16_dict AS VALUES +CREATE TABLE int16_dict AS VALUES (1, arrow_cast('A', 'Dictionary(Int16, Utf8)')), (2, arrow_cast('B', 'Dictionary(Int16, Utf8)')), (2, arrow_cast('A', 'Dictionary(Int16, Utf8)')), @@ -4658,7 +4652,7 @@ CREATE TABLE int16_dict AS VALUES (1, arrow_cast('A', 'Dictionary(Int16, Utf8)')); # Group by the non-dict column -query ?I rowsort +query TI rowsort SELECT column2, count(column1) FROM int16_dict GROUP BY column2; ---- A 4 @@ -4687,7 +4681,7 @@ DROP TABLE int16_dict; # Table with an int column and Dict column: statement ok -CREATE TABLE int32_dict AS VALUES +CREATE TABLE int32_dict AS VALUES (1, arrow_cast('A', 'Dictionary(Int32, Utf8)')), (2, arrow_cast('B', 'Dictionary(Int32, Utf8)')), (2, arrow_cast('A', 'Dictionary(Int32, Utf8)')), @@ -4696,7 +4690,7 @@ CREATE TABLE int32_dict AS VALUES (1, arrow_cast('A', 'Dictionary(Int32, Utf8)')); # Group by the non-dict column -query ?I rowsort +query TI rowsort SELECT column2, count(column1) FROM int32_dict GROUP BY column2; ---- A 4 @@ -4725,7 +4719,7 @@ DROP TABLE int32_dict; # Table with an int column and Dict column: statement ok -CREATE TABLE int64_dict AS VALUES +CREATE TABLE int64_dict AS VALUES (1, arrow_cast('A', 'Dictionary(Int64, Utf8)')), (2, arrow_cast('B', 'Dictionary(Int64, Utf8)')), (2, arrow_cast('A', 'Dictionary(Int64, Utf8)')), @@ -4734,7 +4728,7 @@ CREATE TABLE int64_dict AS VALUES (1, arrow_cast('A', 'Dictionary(Int64, Utf8)')); # Group by the non-dict column -query ?I rowsort +query TI rowsort SELECT column2, count(column1) FROM int64_dict GROUP BY column2; ---- A 4 @@ -4763,7 +4757,7 @@ DROP TABLE int64_dict; # Table with an int column and Dict column: statement ok -CREATE TABLE uint8_dict AS VALUES +CREATE TABLE uint8_dict AS VALUES (1, arrow_cast('A', 'Dictionary(UInt8, Utf8)')), (2, arrow_cast('B', 'Dictionary(UInt8, Utf8)')), (2, arrow_cast('A', 'Dictionary(UInt8, Utf8)')), @@ -4772,7 +4766,7 @@ CREATE TABLE uint8_dict AS VALUES (1, arrow_cast('A', 'Dictionary(UInt8, Utf8)')); # Group by the non-dict column -query ?I rowsort +query TI rowsort SELECT column2, count(column1) FROM uint8_dict GROUP BY column2; ---- A 4 @@ -4801,7 +4795,7 @@ DROP TABLE uint8_dict; # Table with an int column and Dict column: statement ok -CREATE TABLE uint16_dict AS VALUES +CREATE TABLE uint16_dict AS VALUES (1, arrow_cast('A', 'Dictionary(UInt16, Utf8)')), (2, arrow_cast('B', 'Dictionary(UInt16, Utf8)')), (2, arrow_cast('A', 'Dictionary(UInt16, Utf8)')), @@ -4810,7 +4804,7 @@ CREATE TABLE uint16_dict AS VALUES (1, arrow_cast('A', 'Dictionary(UInt16, Utf8)')); # Group by the non-dict column -query ?I rowsort +query TI rowsort SELECT column2, count(column1) FROM uint16_dict GROUP BY column2; ---- A 4 @@ -4839,7 +4833,7 @@ DROP TABLE uint16_dict; # Table with an int column and Dict column: statement ok -CREATE TABLE uint32_dict AS VALUES +CREATE TABLE uint32_dict AS VALUES (1, arrow_cast('A', 'Dictionary(UInt32, Utf8)')), (2, arrow_cast('B', 'Dictionary(UInt32, Utf8)')), (2, arrow_cast('A', 'Dictionary(UInt32, Utf8)')), @@ -4848,7 +4842,7 @@ CREATE TABLE uint32_dict AS VALUES (1, arrow_cast('A', 'Dictionary(UInt32, Utf8)')); # Group by the non-dict column -query ?I rowsort +query TI rowsort SELECT column2, count(column1) FROM uint32_dict GROUP BY column2; ---- A 4 @@ -4877,7 +4871,7 @@ DROP TABLE uint32_dict; # Table with an int column and Dict column: statement ok -CREATE TABLE uint64_dict AS VALUES +CREATE TABLE uint64_dict AS VALUES (1, arrow_cast('A', 'Dictionary(UInt64, Utf8)')), (2, arrow_cast('B', 'Dictionary(UInt64, Utf8)')), (2, arrow_cast('A', 'Dictionary(UInt64, Utf8)')), @@ -4886,7 +4880,7 @@ CREATE TABLE uint64_dict AS VALUES (1, arrow_cast('A', 'Dictionary(UInt64, Utf8)')); # Group by the non-dict column -query ?I rowsort +query TI rowsort SELECT column2, count(column1) FROM uint64_dict GROUP BY column2; ---- A 4 @@ -4971,10 +4965,10 @@ ORDER BY a, b; ---- logical_plan 01)Sort: multiple_ordered_table.a ASC NULLS LAST, multiple_ordered_table.b ASC NULLS LAST -02)--Aggregate: groupBy=[[multiple_ordered_table.a, multiple_ordered_table.b]], aggr=[[ARRAY_AGG(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST]]] +02)--Aggregate: groupBy=[[multiple_ordered_table.a, multiple_ordered_table.b]], aggr=[[array_agg(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST]]] 03)----TableScan: multiple_ordered_table projection=[a, b, c] physical_plan -01)AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[ARRAY_AGG(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST]], ordering_mode=Sorted +01)AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b], aggr=[array_agg(multiple_ordered_table.c) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST]], ordering_mode=Sorted 02)--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], output_orderings=[[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], [c@2 ASC NULLS LAST]], has_header=true query II? @@ -5059,3 +5053,220 @@ query II SELECT a + 1 AS d, a + 1 + b AS c FROM (SELECT 1 AS a, 2 AS b) GROUP BY a + 1, a + 1 + b; ---- 2 4 + +statement error DataFusion error: Error during planning: Cannot find column with position 4 in SELECT clause. Valid columns: 1 to 3 +SELECT a, b, COUNT(1) +FROM multiple_ordered_table +GROUP BY 1, 2, 4, 5, 6; + +statement ok +set datafusion.execution.target_partitions = 1; + +# Create a table that contains various keywords, with their corresponding timestamps +statement ok +CREATE TABLE keywords_stream ( + ts TIMESTAMP, + sn INTEGER PRIMARY KEY, + keyword VARCHAR NOT NULL +); + +statement ok +INSERT INTO keywords_stream(ts, sn, keyword) VALUES +('2024-01-01T00:00:00Z', '0', 'Drug'), +('2024-01-01T00:00:05Z', '1', 'Bomb'), +('2024-01-01T00:00:10Z', '2', 'Theft'), +('2024-01-01T00:00:15Z', '3', 'Gun'), +('2024-01-01T00:00:20Z', '4', 'Calm'); + +# Create a table that contains alert keywords +statement ok +CREATE TABLE ALERT_KEYWORDS(keyword VARCHAR NOT NULL); + +statement ok +INSERT INTO ALERT_KEYWORDS VALUES +('Drug'), +('Bomb'), +('Theft'), +('Gun'), +('Knife'), +('Fire'); + +query TT +explain SELECT + DATE_BIN(INTERVAL '2' MINUTE, ts, '2000-01-01') AS ts_chunk, + COUNT(keyword) AS alert_keyword_count +FROM + keywords_stream +WHERE + keywords_stream.keyword IN (SELECT keyword FROM ALERT_KEYWORDS) +GROUP BY + ts_chunk; +---- +logical_plan +01)Projection: date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"),keywords_stream.ts,Utf8("2000-01-01")) AS ts_chunk, count(keywords_stream.keyword) AS alert_keyword_count +02)--Aggregate: groupBy=[[date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"), keywords_stream.ts, TimestampNanosecond(946684800000000000, None)) AS date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"),keywords_stream.ts,Utf8("2000-01-01"))]], aggr=[[count(keywords_stream.keyword)]] +03)----LeftSemi Join: keywords_stream.keyword = __correlated_sq_1.keyword +04)------TableScan: keywords_stream projection=[ts, keyword] +05)------SubqueryAlias: __correlated_sq_1 +06)--------TableScan: alert_keywords projection=[keyword] +physical_plan +01)ProjectionExec: expr=[date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"),keywords_stream.ts,Utf8("2000-01-01"))@0 as ts_chunk, count(keywords_stream.keyword)@1 as alert_keyword_count] +02)--AggregateExec: mode=Single, gby=[date_bin(IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }, ts@0, 946684800000000000) as date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 120000000000 }"),keywords_stream.ts,Utf8("2000-01-01"))], aggr=[count(keywords_stream.keyword)] +03)----CoalesceBatchesExec: target_batch_size=2 +04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(keyword@0, keyword@1)] +05)--------MemoryExec: partitions=1, partition_sizes=[1] +06)--------MemoryExec: partitions=1, partition_sizes=[1] + +query PI +SELECT + DATE_BIN(INTERVAL '2' MINUTE, ts, '2000-01-01') AS ts_chunk, + COUNT(keyword) AS alert_keyword_count +FROM + keywords_stream +WHERE + keywords_stream.keyword IN (SELECT keyword FROM ALERT_KEYWORDS) +GROUP BY + ts_chunk; +---- +2024-01-01T00:00:00 4 + +# Issue: https://github.com/apache/datafusion/issues/11118 +statement ok +CREATE TABLE test_case_expr(a INT, b TEXT) AS VALUES (1,'hello'), (2,'world') + +query T +SELECT (CASE WHEN CONCAT(b, 'hello') = 'test' THEN 'good' ELSE 'bad' END) AS c + FROM test_case_expr GROUP BY c; +---- +bad + +query I rowsort +SELECT (CASE a::BIGINT WHEN 1 THEN 1 END) AS c FROM test_case_expr GROUP BY c; +---- +1 +NULL + +statement ok +drop table test_case_expr + +statement ok +drop table t; + +# test multi group by for binary type with nulls +statement ok +create table t(a int, b bytea) as values (1, 0xa), (1, 0xa), (2, null), (null, 0xb), (null, 0xb); + +query I?I +select a, b, count(*) from t group by grouping sets ((a, b), (a), (b)); +---- +1 0a 2 +2 NULL 1 +NULL 0b 2 +1 NULL 2 +2 NULL 1 +NULL NULL 2 +NULL 0a 2 +NULL NULL 1 +NULL 0b 2 + +statement ok +drop table t; + +# test multi group by for binary type without nulls +statement ok +create table t(a int, b bytea) as values (1, 0xa), (1, 0xa), (2, 0xb), (3, 0xb), (3, 0xb); + +query I?I +select a, b, count(*) from t group by grouping sets ((a, b), (a), (b)); +---- +1 0a 2 +2 0b 1 +3 0b 2 +1 NULL 2 +2 NULL 1 +3 NULL 2 +NULL 0a 2 +NULL 0b 3 + +statement ok +drop table t; + +# test multi group by int + utf8 +statement ok +create table t(a int, b varchar) as values (1, 'a'), (1, 'a'), (2, 'ab'), (3, 'abc'), (3, 'cba'), (null, null), (null, 'a'), (null, null), (null, 'a'), (1, 'null'); + +query ITI rowsort +select a, b, count(*) from t group by a, b; +---- +1 a 2 +1 null 1 +2 ab 1 +3 abc 1 +3 cba 1 +NULL NULL 2 +NULL a 2 + +statement ok +drop table t; + +# test multi group by int + utf8view +statement ok +create table source as values +-- use some strings that are larger than 12 characters as that goes through a different path +(1, 'a'), +(1, 'a'), +(2, 'thisstringislongerthan12'), +(2, 'thisstring'), +(3, 'abc'), +(3, 'cba'), +(2, 'thisstring'), +(null, null), +(null, 'a'), +(null, null), +(null, 'a'), +(2, 'thisstringisalsolongerthan12'), +(2, 'thisstringislongerthan12'), +(1, 'null') +; + +statement ok +create view t as select column1 as a, arrow_cast(column2, 'Utf8View') as b from source; + +query ITI +select a, b, count(*) from t group by a, b order by a, b; +---- +1 a 2 +1 null 1 +2 thisstring 2 +2 thisstringisalsolongerthan12 1 +2 thisstringislongerthan12 2 +3 abc 1 +3 cba 1 +NULL a 2 +NULL NULL 2 + +statement ok +drop view t + +# test with binary view +statement ok +create view t as select column1 as a, arrow_cast(column2, 'BinaryView') as b from source; + +query I?I +select a, b, count(*) from t group by a, b order by a, b; +---- +1 61 2 +1 6e756c6c 1 +2 74686973737472696e67 2 +2 74686973737472696e676973616c736f6c6f6e6765727468616e3132 1 +2 74686973737472696e6769736c6f6e6765727468616e3132 2 +3 616263 1 +3 636261 1 +NULL 61 2 +NULL NULL 2 + +statement ok +drop view t + +statement ok +drop table source; diff --git a/datafusion/sqllogictest/test_files/grouping.slt b/datafusion/sqllogictest/test_files/grouping.slt new file mode 100644 index 000000000000..64d040d012f9 --- /dev/null +++ b/datafusion/sqllogictest/test_files/grouping.slt @@ -0,0 +1,214 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +CREATE TABLE test (c1 VARCHAR,c2 VARCHAR,c3 INT) as values +('a','A',1), ('b','B',2) + +# grouping_with_grouping_sets +query TTIIII +select + c1, + c2, + grouping(c1) as g0, + grouping(c2) as g1, + grouping(c1, c2) as g2, + grouping(c2, c1) as g3 +from + test +group by + grouping sets ( + (c1, c2), + (c1), + (c2), + () + ) +order by + c1, c2, g0, g1, g2, g3; +---- +a A 0 0 0 0 +a NULL 0 1 1 2 +b B 0 0 0 0 +b NULL 0 1 1 2 +NULL A 1 0 2 1 +NULL B 1 0 2 1 +NULL NULL 1 1 3 3 + +# grouping_with_cube +query TTIIII +select + c1, + c2, + grouping(c1) as g0, + grouping(c2) as g1, + grouping(c1, c2) as g2, + grouping(c2, c1) as g3 +from + test +group by + cube(c1, c2) +order by + c1, c2, g0, g1, g2, g3; +---- +a A 0 0 0 0 +a NULL 0 1 1 2 +b B 0 0 0 0 +b NULL 0 1 1 2 +NULL A 1 0 2 1 +NULL B 1 0 2 1 +NULL NULL 1 1 3 3 + +# grouping_with_rollup +query TTIIII +select + c1, + c2, + grouping(c1) as g0, + grouping(c2) as g1, + grouping(c1, c2) as g2, + grouping(c2, c1) as g3 +from + test +group by + rollup(c1, c2) +order by + c1, c2, g0, g1, g2, g3; +---- +a A 0 0 0 0 +a NULL 0 1 1 2 +b B 0 0 0 0 +b NULL 0 1 1 2 +NULL NULL 1 1 3 3 + +query TTIIIIIIII +select + c1, + c2, + c3, + grouping(c1) as g0, + grouping(c2) as g1, + grouping(c1, c2) as g2, + grouping(c2, c1) as g3, + grouping(c1, c2, c3) as g4, + grouping(c2, c3, c1) as g5, + grouping(c3, c2, c1) as g6 +from + test +group by + rollup(c1, c2, c3) +order by + c1, c2, g0, g1, g2, g3, g4, g5, g6; +---- +a A 1 0 0 0 0 0 0 0 +a A NULL 0 0 0 0 1 2 4 +a NULL NULL 0 1 1 2 3 6 6 +b B 2 0 0 0 0 0 0 0 +b B NULL 0 0 0 0 1 2 4 +b NULL NULL 0 1 1 2 3 6 6 +NULL NULL NULL 1 1 3 3 7 7 7 + +# grouping_with_add +query TTI +select + c1, + c2, + grouping(c1)+grouping(c2) as g0 +from + test +group by + rollup(c1, c2) +order by + c1, c2, g0; +---- +a A 0 +a NULL 1 +b B 0 +b NULL 1 +NULL NULL 2 + +#grouping_with_windown_function +query TTIII +select + c1, + c2, + count(c1) as cnt, + grouping(c1)+ grouping(c2) as g0, + rank() over ( + partition by grouping(c1)+grouping(c2), + case when grouping(c2) = 0 then c1 end + order by + count(c1) desc + ) as rank_within_parent +from + test +group by + rollup(c1, c2) +order by + c1, + c2, + cnt, + g0 desc, + rank_within_parent; +---- +a A 1 0 1 +a NULL 1 1 1 +b B 1 0 1 +b NULL 1 1 1 +NULL NULL 2 2 1 + +# grouping_with_non_columns +query TIIIII +select + c1, + c3 + 1 as c3_add_one, + grouping(c1) as g0, + grouping(c3 + 1) as g1, + grouping(c1, c3 + 1) as g2, + grouping(c3 + 1, c1) as g3 +from + test +group by + grouping sets ( + (c1, c3 + 1), + (c3 + 1), + (c1) + ) +order by + c1, c3_add_one, g0, g1, g2, g3; +---- +a 2 0 0 0 0 +a NULL 0 1 1 2 +b 3 0 0 0 0 +b NULL 0 1 1 2 +NULL 2 1 0 2 1 +NULL 3 1 0 2 1 + +# postgres allows grouping function for GROUP BY without GROUPING SETS/ROLLUP/CUBE +query TI +select c1, grouping(c1) from test group by c1 order by c1; +---- +a 0 +b 0 + +statement error c2.*not in grouping columns +select c1, grouping(c2) from test group by c1; + +statement error c2.*not in grouping columns +select c1, grouping(c1, c2) from test group by CUBE(c1); + +statement error zero arguments +select c1, grouping() from test group by CUBE(c1); diff --git a/datafusion/sqllogictest/test_files/identifiers.slt b/datafusion/sqllogictest/test_files/identifiers.slt index f60d60b2bfe0..755d617e7a2a 100644 --- a/datafusion/sqllogictest/test_files/identifiers.slt +++ b/datafusion/sqllogictest/test_files/identifiers.slt @@ -22,8 +22,8 @@ CREATE EXTERNAL TABLE case_insensitive_test ( c INT ) STORED AS CSV -WITH HEADER ROW LOCATION '../core/tests/data/example.csv' +OPTIONS ('format.has_header' 'true'); # normalized column identifiers query II diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index 8f4b1a3816a3..03ab4a090e67 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -49,6 +49,12 @@ SELECT * from information_schema.schemata; ---- datafusion public NULL NULL NULL NULL NULL +# Table name case insensitive +query T rowsort +SELECT catalog_name from information_schema.SchEmaTa; +---- +datafusion + # Disable information_schema and verify it now errors again statement ok set datafusion.catalog.information_schema = false @@ -159,51 +165,62 @@ datafusion.catalog.create_default_catalog_and_schema true datafusion.catalog.default_catalog datafusion datafusion.catalog.default_schema public datafusion.catalog.format NULL -datafusion.catalog.has_header false +datafusion.catalog.has_header true datafusion.catalog.information_schema true datafusion.catalog.location NULL -datafusion.execution.aggregate.scalar_update_factor 10 +datafusion.catalog.newlines_in_values false datafusion.execution.batch_size 8192 datafusion.execution.coalesce_batches true datafusion.execution.collect_statistics false datafusion.execution.enable_recursive_ctes true +datafusion.execution.enforce_batch_size_in_joins false +datafusion.execution.keep_partition_by_columns false datafusion.execution.listing_table_ignore_subdirectory true datafusion.execution.max_buffered_batches_per_output_file 2 datafusion.execution.meta_fetch_concurrency 32 datafusion.execution.minimum_parallel_output_files 4 datafusion.execution.parquet.allow_single_file_parallelism true -datafusion.execution.parquet.bloom_filter_enabled false +datafusion.execution.parquet.binary_as_string false datafusion.execution.parquet.bloom_filter_fpp NULL datafusion.execution.parquet.bloom_filter_ndv NULL -datafusion.execution.parquet.column_index_truncate_length NULL +datafusion.execution.parquet.bloom_filter_on_read true +datafusion.execution.parquet.bloom_filter_on_write false +datafusion.execution.parquet.column_index_truncate_length 64 datafusion.execution.parquet.compression zstd(3) datafusion.execution.parquet.created_by datafusion -datafusion.execution.parquet.data_page_row_count_limit 18446744073709551615 +datafusion.execution.parquet.data_page_row_count_limit 20000 datafusion.execution.parquet.data_pagesize_limit 1048576 -datafusion.execution.parquet.dictionary_enabled NULL +datafusion.execution.parquet.dictionary_enabled true datafusion.execution.parquet.dictionary_page_size_limit 1048576 datafusion.execution.parquet.enable_page_index true datafusion.execution.parquet.encoding NULL datafusion.execution.parquet.max_row_group_size 1048576 -datafusion.execution.parquet.max_statistics_size NULL +datafusion.execution.parquet.max_statistics_size 4096 datafusion.execution.parquet.maximum_buffered_record_batches_per_stream 2 datafusion.execution.parquet.maximum_parallel_row_group_writers 1 datafusion.execution.parquet.metadata_size_hint NULL datafusion.execution.parquet.pruning true datafusion.execution.parquet.pushdown_filters false datafusion.execution.parquet.reorder_filters false +datafusion.execution.parquet.schema_force_view_types false datafusion.execution.parquet.skip_metadata true -datafusion.execution.parquet.statistics_enabled NULL +datafusion.execution.parquet.statistics_enabled page datafusion.execution.parquet.write_batch_size 1024 datafusion.execution.parquet.writer_version 1.0 datafusion.execution.planning_concurrency 13 +datafusion.execution.skip_partial_aggregation_probe_ratio_threshold 0.8 +datafusion.execution.skip_partial_aggregation_probe_rows_threshold 100000 +datafusion.execution.skip_physical_aggregate_schema_check false datafusion.execution.soft_max_rows_per_output_file 50000000 datafusion.execution.sort_in_place_threshold_bytes 1048576 datafusion.execution.sort_spill_reservation_bytes 10485760 +datafusion.execution.split_file_groups_by_statistics false datafusion.execution.target_partitions 7 datafusion.execution.time_zone +00:00 +datafusion.execution.use_row_number_estimates_to_optimize_partitioning false datafusion.explain.logical_plan_only false datafusion.explain.physical_plan_only false +datafusion.explain.show_schema false datafusion.explain.show_sizes true datafusion.explain.show_statistics false datafusion.optimizer.allow_symmetric_joins_without_pruning true @@ -211,11 +228,13 @@ datafusion.optimizer.default_filter_selectivity 20 datafusion.optimizer.enable_distinct_aggregation_soft_limit true datafusion.optimizer.enable_round_robin_repartition true datafusion.optimizer.enable_topk_aggregation true +datafusion.optimizer.expand_views_at_output false datafusion.optimizer.filter_null_join_keys false datafusion.optimizer.hash_join_single_partition_threshold 1048576 datafusion.optimizer.hash_join_single_partition_threshold_rows 131072 datafusion.optimizer.max_passes 3 datafusion.optimizer.prefer_existing_sort false +datafusion.optimizer.prefer_existing_union false datafusion.optimizer.prefer_hash_join true datafusion.optimizer.repartition_aggregations true datafusion.optimizer.repartition_file_min_size 10485760 @@ -227,7 +246,9 @@ datafusion.optimizer.skip_failed_rules false datafusion.optimizer.top_down_join_key_reordering true datafusion.sql_parser.dialect generic datafusion.sql_parser.enable_ident_normalization true +datafusion.sql_parser.enable_options_value_normalization true datafusion.sql_parser.parse_float_as_decimal false +datafusion.sql_parser.support_varchar_with_length true # show all variables with verbose query TTT rowsort @@ -237,51 +258,62 @@ datafusion.catalog.create_default_catalog_and_schema true Whether the default ca datafusion.catalog.default_catalog datafusion The default catalog name - this impacts what SQL queries use if not specified datafusion.catalog.default_schema public The default schema name - this impacts what SQL queries use if not specified datafusion.catalog.format NULL Type of `TableProvider` to use when loading `default` schema -datafusion.catalog.has_header false If the file has a header +datafusion.catalog.has_header true Default value for `format.has_header` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. datafusion.catalog.information_schema true Should DataFusion provide access to `information_schema` virtual tables for displaying schema information datafusion.catalog.location NULL Location scanned to load tables for `default` schema -datafusion.execution.aggregate.scalar_update_factor 10 Specifies the threshold for using `ScalarValue`s to update accumulators during high-cardinality aggregations for each input batch. The aggregation is considered high-cardinality if the number of affected groups is greater than or equal to `batch_size / scalar_update_factor`. In such cases, `ScalarValue`s are utilized for updating accumulators, rather than the default batch-slice approach. This can lead to performance improvements. By adjusting the `scalar_update_factor`, you can balance the trade-off between more efficient accumulator updates and the number of groups affected. +datafusion.catalog.newlines_in_values false Specifies whether newlines in (quoted) CSV values are supported. This is the default value for `format.newlines_in_values` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. Parsing newlines in quoted values may be affected by execution behaviour such as parallel file scanning. Setting this to `true` ensures that newlines in values are parsed successfully, which may reduce performance. datafusion.execution.batch_size 8192 Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption datafusion.execution.coalesce_batches true When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting datafusion.execution.collect_statistics false Should DataFusion collect statistics after listing files datafusion.execution.enable_recursive_ctes true Should DataFusion support recursive CTEs +datafusion.execution.enforce_batch_size_in_joins false Should DataFusion enforce batch size in joins or not. By default, DataFusion will not enforce batch size in joins. Enforcing batch size in joins can reduce memory usage when joining large tables with a highly-selective join filter, but is also slightly slower. +datafusion.execution.keep_partition_by_columns false Should DataFusion keep the columns used for partition_by in the output RecordBatches datafusion.execution.listing_table_ignore_subdirectory true Should sub directories be ignored when scanning directories for data files. Defaults to true (ignores subdirectories), consistent with Hive. Note that this setting does not affect reading partitioned tables (e.g. `/table/year=2021/month=01/data.parquet`). datafusion.execution.max_buffered_batches_per_output_file 2 This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption datafusion.execution.meta_fetch_concurrency 32 Number of files to read in parallel when inferring schema and statistics datafusion.execution.minimum_parallel_output_files 4 Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. -datafusion.execution.parquet.allow_single_file_parallelism true Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. -datafusion.execution.parquet.bloom_filter_enabled false Sets if bloom filter is enabled for any column -datafusion.execution.parquet.bloom_filter_fpp NULL Sets bloom filter false positive probability. If NULL, uses default parquet writer setting -datafusion.execution.parquet.bloom_filter_ndv NULL Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting -datafusion.execution.parquet.column_index_truncate_length NULL Sets column index truncate length -datafusion.execution.parquet.compression zstd(3) Sets default parquet compression codec Valid values are: uncompressed, snappy, gzip(level), lzo, brotli(level), lz4, zstd(level), and lz4_raw. These values are not case sensitive. If NULL, uses default parquet writer setting -datafusion.execution.parquet.created_by datafusion Sets "created by" property -datafusion.execution.parquet.data_page_row_count_limit 18446744073709551615 Sets best effort maximum number of rows in data page -datafusion.execution.parquet.data_pagesize_limit 1048576 Sets best effort maximum size of data page in bytes -datafusion.execution.parquet.dictionary_enabled NULL Sets if dictionary encoding is enabled. If NULL, uses default parquet writer setting -datafusion.execution.parquet.dictionary_page_size_limit 1048576 Sets best effort maximum dictionary page size, in bytes -datafusion.execution.parquet.enable_page_index true If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce the I/O and number of rows decoded. -datafusion.execution.parquet.encoding NULL Sets default encoding for any column Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting -datafusion.execution.parquet.max_row_group_size 1048576 Target maximum number of rows in each row group (defaults to 1M rows). Writing larger row groups requires more memory to write, but can get better compression and be faster to read. -datafusion.execution.parquet.max_statistics_size NULL Sets max statistics size for any column. If NULL, uses default parquet writer setting -datafusion.execution.parquet.maximum_buffered_record_batches_per_stream 2 By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. -datafusion.execution.parquet.maximum_parallel_row_group_writers 1 By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. -datafusion.execution.parquet.metadata_size_hint NULL If specified, the parquet reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. If not specified, two reads are required: One read to fetch the 8-byte parquet footer and another to fetch the metadata length encoded in the footer -datafusion.execution.parquet.pruning true If true, the parquet reader attempts to skip entire row groups based on the predicate in the query and the metadata (min/max values) stored in the parquet file -datafusion.execution.parquet.pushdown_filters false If true, filter expressions are be applied during the parquet decoding operation to reduce the number of rows decoded. This optimization is sometimes called "late materialization". -datafusion.execution.parquet.reorder_filters false If true, filter expressions evaluated during the parquet decoding operation will be reordered heuristically to minimize the cost of evaluation. If false, the filters are applied in the same order as written in the query -datafusion.execution.parquet.skip_metadata true If true, the parquet reader skip the optional embedded metadata that may be in the file Schema. This setting can help avoid schema conflicts when querying multiple parquet files with schemas containing compatible types but different metadata -datafusion.execution.parquet.statistics_enabled NULL Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting -datafusion.execution.parquet.write_batch_size 1024 Sets write_batch_size in bytes -datafusion.execution.parquet.writer_version 1.0 Sets parquet writer version valid values are "1.0" and "2.0" +datafusion.execution.parquet.allow_single_file_parallelism true (writing) Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. +datafusion.execution.parquet.binary_as_string false (reading) If true, parquet reader will read columns of `Binary/LargeBinary` with `Utf8`, and `BinaryView` with `Utf8View`. Parquet files generated by some legacy writers do not correctly set the UTF8 flag for strings, causing string columns to be loaded as BLOB instead. +datafusion.execution.parquet.bloom_filter_fpp NULL (writing) Sets bloom filter false positive probability. If NULL, uses default parquet writer setting +datafusion.execution.parquet.bloom_filter_ndv NULL (writing) Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting +datafusion.execution.parquet.bloom_filter_on_read true (writing) Use any available bloom filters when reading parquet files +datafusion.execution.parquet.bloom_filter_on_write false (writing) Write bloom filters for all columns when creating parquet files +datafusion.execution.parquet.column_index_truncate_length 64 (writing) Sets column index truncate length +datafusion.execution.parquet.compression zstd(3) (writing) Sets default parquet compression codec. Valid values are: uncompressed, snappy, gzip(level), lzo, brotli(level), lz4, zstd(level), and lz4_raw. These values are not case sensitive. If NULL, uses default parquet writer setting Note that this default setting is not the same as the default parquet writer setting. +datafusion.execution.parquet.created_by datafusion (writing) Sets "created by" property +datafusion.execution.parquet.data_page_row_count_limit 20000 (writing) Sets best effort maximum number of rows in data page +datafusion.execution.parquet.data_pagesize_limit 1048576 (writing) Sets best effort maximum size of data page in bytes +datafusion.execution.parquet.dictionary_enabled true (writing) Sets if dictionary encoding is enabled. If NULL, uses default parquet writer setting +datafusion.execution.parquet.dictionary_page_size_limit 1048576 (writing) Sets best effort maximum dictionary page size, in bytes +datafusion.execution.parquet.enable_page_index true (reading) If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce the I/O and number of rows decoded. +datafusion.execution.parquet.encoding NULL (writing) Sets default encoding for any column. Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting +datafusion.execution.parquet.max_row_group_size 1048576 (writing) Target maximum number of rows in each row group (defaults to 1M rows). Writing larger row groups requires more memory to write, but can get better compression and be faster to read. +datafusion.execution.parquet.max_statistics_size 4096 (writing) Sets max statistics size for any column. If NULL, uses default parquet writer setting +datafusion.execution.parquet.maximum_buffered_record_batches_per_stream 2 (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. +datafusion.execution.parquet.maximum_parallel_row_group_writers 1 (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. +datafusion.execution.parquet.metadata_size_hint NULL (reading) If specified, the parquet reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. If not specified, two reads are required: One read to fetch the 8-byte parquet footer and another to fetch the metadata length encoded in the footer +datafusion.execution.parquet.pruning true (reading) If true, the parquet reader attempts to skip entire row groups based on the predicate in the query and the metadata (min/max values) stored in the parquet file +datafusion.execution.parquet.pushdown_filters false (reading) If true, filter expressions are be applied during the parquet decoding operation to reduce the number of rows decoded. This optimization is sometimes called "late materialization". +datafusion.execution.parquet.reorder_filters false (reading) If true, filter expressions evaluated during the parquet decoding operation will be reordered heuristically to minimize the cost of evaluation. If false, the filters are applied in the same order as written in the query +datafusion.execution.parquet.schema_force_view_types false (reading) If true, parquet reader will read columns of `Utf8/Utf8Large` with `Utf8View`, and `Binary/BinaryLarge` with `BinaryView`. +datafusion.execution.parquet.skip_metadata true (reading) If true, the parquet reader skip the optional embedded metadata that may be in the file Schema. This setting can help avoid schema conflicts when querying multiple parquet files with schemas containing compatible types but different metadata +datafusion.execution.parquet.statistics_enabled page (writing) Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting +datafusion.execution.parquet.write_batch_size 1024 (writing) Sets write_batch_size in bytes +datafusion.execution.parquet.writer_version 1.0 (writing) Sets parquet writer version valid values are "1.0" and "2.0" datafusion.execution.planning_concurrency 13 Fan-out during initial physical planning. This is mostly use to plan `UNION` children in parallel. Defaults to the number of CPU cores on the system +datafusion.execution.skip_partial_aggregation_probe_ratio_threshold 0.8 Aggregation ratio (number of distinct groups / number of input rows) threshold for skipping partial aggregation. If the value is greater then partial aggregation will skip aggregation for further input +datafusion.execution.skip_partial_aggregation_probe_rows_threshold 100000 Number of input rows partial aggregation partition should process, before aggregation ratio check and trying to switch to skipping aggregation mode +datafusion.execution.skip_physical_aggregate_schema_check false When set to true, skips verifying that the schema produced by planning the input of `LogicalPlan::Aggregate` exactly matches the schema of the input plan. When set to false, if the schema does not match exactly (including nullability and metadata), a planning error will be raised. This is used to workaround bugs in the planner that are now caught by the new schema verification step. datafusion.execution.soft_max_rows_per_output_file 50000000 Target number of rows in output files when writing multiple. This is a soft max, so it can be exceeded slightly. There also will be one file smaller than the limit if the total number of rows written is not roughly divisible by the soft max datafusion.execution.sort_in_place_threshold_bytes 1048576 When sorting, below what size should data be concatenated and sorted in a single RecordBatch rather than sorted in batches and merged. datafusion.execution.sort_spill_reservation_bytes 10485760 Specifies the reserved memory for each spillable sort operation to facilitate an in-memory merge. When a sort operation spills to disk, the in-memory data must be sorted and merged before being written to a file. This setting reserves a specific amount of memory for that in-memory sort/merge process. Note: This setting is irrelevant if the sort operation cannot spill (i.e., if there's no `DiskManager` configured). +datafusion.execution.split_file_groups_by_statistics false Attempt to eliminate sorts by packing & sorting files with non-overlapping statistics into the same file groups. Currently experimental datafusion.execution.target_partitions 7 Number of partitions for query execution. Increasing partitions can increase concurrency. Defaults to the number of CPU cores on the system datafusion.execution.time_zone +00:00 The default time zone Some functions, e.g. `EXTRACT(HOUR from SOME_TIME)`, shift the underlying datetime according to this time zone, and then extract the hour +datafusion.execution.use_row_number_estimates_to_optimize_partitioning false Should DataFusion use row number estimates at the input to decide whether increasing parallelism is beneficial or not. By default, only exact row numbers (not estimates) are used for this decision. Setting this flag to `true` will likely produce better plans. if the source of statistics is accurate. We plan to make this the default in the future. datafusion.explain.logical_plan_only false When set to true, the explain statement will only print logical plans datafusion.explain.physical_plan_only false When set to true, the explain statement will only print physical plans +datafusion.explain.show_schema false When set to true, the explain statement will print schema information datafusion.explain.show_sizes true When set to true, the explain statement will print the partition sizes datafusion.explain.show_statistics false When set to true, the explain statement will print operator statistics for physical plans datafusion.optimizer.allow_symmetric_joins_without_pruning true Should DataFusion allow symmetric hash joins for unbounded data sources even when its inputs do not have any ordering or filtering If the flag is not enabled, the SymmetricHashJoin operator will be unable to prune its internal buffers, resulting in certain join types - such as Full, Left, LeftAnti, LeftSemi, Right, RightAnti, and RightSemi - being produced only at the end of the execution. This is not typical in stream processing. Additionally, without proper design for long runner execution, all types of joins may encounter out-of-memory errors. @@ -289,11 +321,13 @@ datafusion.optimizer.default_filter_selectivity 20 The default filter selectivit datafusion.optimizer.enable_distinct_aggregation_soft_limit true When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. datafusion.optimizer.enable_round_robin_repartition true When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores datafusion.optimizer.enable_topk_aggregation true When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible +datafusion.optimizer.expand_views_at_output false When set to true, if the returned type is a view type then the output will be coerced to a non-view. Coerces `Utf8View` to `LargeUtf8`, and `BinaryView` to `LargeBinary`. datafusion.optimizer.filter_null_join_keys false When set to true, the optimizer will insert filters before a join between a nullable and non-nullable column to filter out nulls on the nullable side. This filter can add additional overhead when the file format does not fully support predicate push down. datafusion.optimizer.hash_join_single_partition_threshold 1048576 The maximum estimated size in bytes for one input side of a HashJoin will be collected into a single partition datafusion.optimizer.hash_join_single_partition_threshold_rows 131072 The maximum estimated size in rows for one input side of a HashJoin will be collected into a single partition datafusion.optimizer.max_passes 3 Number of times that the optimizer will attempt to optimize the plan datafusion.optimizer.prefer_existing_sort false When true, DataFusion will opportunistically remove sorts when the data is already sorted, (i.e. setting `preserve_order` to true on `RepartitionExec` and using `SortPreservingMergeExec`) When false, DataFusion will maximize plan parallelism using `RepartitionExec` even if this requires subsequently resorting data using a `SortExec`. +datafusion.optimizer.prefer_existing_union false When set to true, the optimizer will not attempt to convert Union to Interleave datafusion.optimizer.prefer_hash_join true When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. HashJoin can work more efficiently than SortMergeJoin but consumes more memory datafusion.optimizer.repartition_aggregations true Should DataFusion repartition data using the aggregate keys to execute aggregates in parallel using the provided `target_partitions` level datafusion.optimizer.repartition_file_min_size 10485760 Minimum total files size in bytes to perform file scan repartitioning. @@ -305,7 +339,9 @@ datafusion.optimizer.skip_failed_rules false When set to true, the logical plan datafusion.optimizer.top_down_join_key_reordering true When set to true, the physical plan optimizer will run a top down process to reorder the join keys datafusion.sql_parser.dialect generic Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, and Ansi. datafusion.sql_parser.enable_ident_normalization true When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted) +datafusion.sql_parser.enable_options_value_normalization true When set to true, SQL parser will normalize options value (convert value to lowercase) datafusion.sql_parser.parse_float_as_decimal false When set to true, SQL parser will parse float as decimal type +datafusion.sql_parser.support_varchar_with_length true If true, permit lengths for `VARCHAR` such as `VARCHAR(20)`, but ignore the length. If false, error if a `VARCHAR` with a length is specified. The Arrow type system does not have a notion of maximum string length and thus DataFusion can not enforce such limits. # show_variable_in_config_options query TT @@ -350,9 +386,12 @@ datafusion.execution.time_zone +00:00 The default time zone Some functions, e.g. # show empty verbose -query TTT +statement error DataFusion error: Error during planning: '' is not a variable which can be viewed with 'SHOW' SHOW VERBOSE ----- + +# show nonsense verbose +statement error DataFusion error: Error during planning: 'nonsense' is not a variable which can be viewed with 'SHOW' +SHOW NONSENSE VERBOSE # information_schema_describe_table @@ -488,9 +527,7 @@ SHOW columns from datafusion.public.t2 # show_non_existing_variable -# FIXME -# currently we cannot know whether a variable exists, this will output 0 row instead -statement ok +statement error DataFusion error: Error during planning: 'something_unknown' is not a variable which can be viewed with 'SHOW' SHOW SOMETHING_UNKNOWN; statement ok @@ -559,7 +596,8 @@ DROP VIEW test.xyz statement ok CREATE EXTERNAL TABLE abc STORED AS CSV -WITH HEADER ROW LOCATION '../../testing/data/csv/aggregate_test_100.csv'; +LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); query TTTT SHOW CREATE TABLE abc; diff --git a/datafusion/sqllogictest/test_files/insert.slt b/datafusion/sqllogictest/test_files/insert.slt index d19581d0cc0b..804612287246 100644 --- a/datafusion/sqllogictest/test_files/insert.slt +++ b/datafusion/sqllogictest/test_files/insert.slt @@ -37,8 +37,8 @@ CREATE EXTERNAL TABLE aggregate_test_100 ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); # test_insert_into @@ -58,24 +58,24 @@ ORDER by c1 ---- logical_plan 01)Dml: op=[Insert Into] table=[table_without_values] -02)--Projection: SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field2 +02)--Projection: sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field1, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field2 03)----Sort: aggregate_test_100.c1 ASC NULLS LAST -04)------Projection: SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, aggregate_test_100.c1 -05)--------WindowAggr: windowExpr=[[SUM(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] +04)------Projection: sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, aggregate_test_100.c1 +05)--------WindowAggr: windowExpr=[[sum(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] 06)----------TableScan: aggregate_test_100 projection=[c1, c4, c9] physical_plan 01)DataSinkExec: sink=MemoryTable (partitions=1) -02)--ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as field2] +02)--ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as field1, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as field2] 03)----SortPreservingMergeExec: [c1@2 ASC NULLS LAST] -04)------ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, c1@0 as c1] -05)--------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -06)----------SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST] +04)------ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, c1@0 as c1] +05)--------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +06)----------SortExec: expr=[c1@0 ASC NULLS LAST, c9@2 ASC NULLS LAST], preserve_partitioning=[true] 07)------------CoalesceBatchesExec: target_batch_size=8192 08)--------------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 09)----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 10)------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c4, c9], has_header=true -query II +query I INSERT INTO table_without_values SELECT SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING), COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) @@ -121,15 +121,15 @@ FROM aggregate_test_100 ---- logical_plan 01)Dml: op=[Insert Into] table=[table_without_values] -02)--Projection: SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field2 -03)----WindowAggr: windowExpr=[[SUM(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] +02)--Projection: sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field1, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field2 +03)----WindowAggr: windowExpr=[[sum(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] 04)------TableScan: aggregate_test_100 projection=[c1, c4, c9] physical_plan 01)DataSinkExec: sink=MemoryTable (partitions=1) 02)--CoalescePartitionsExec -03)----ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as field2] -04)------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -05)--------SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST] +03)----ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as field1, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as field2] +04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +05)--------SortExec: expr=[c1@0 ASC NULLS LAST, c9@2 ASC NULLS LAST], preserve_partitioning=[true] 06)----------CoalesceBatchesExec: target_batch_size=8192 07)------------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 08)--------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 @@ -137,7 +137,7 @@ physical_plan -query II +query I INSERT INTO table_without_values SELECT SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a1, COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a2 @@ -171,23 +171,23 @@ logical_plan 01)Dml: op=[Insert Into] table=[table_without_values] 02)--Projection: a1 AS a1, a2 AS a2 03)----Sort: aggregate_test_100.c1 ASC NULLS LAST -04)------Projection: SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS a1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS a2, aggregate_test_100.c1 -05)--------WindowAggr: windowExpr=[[SUM(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] +04)------Projection: sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS a1, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS a2, aggregate_test_100.c1 +05)--------WindowAggr: windowExpr=[[sum(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] 06)----------TableScan: aggregate_test_100 projection=[c1, c4, c9] physical_plan 01)DataSinkExec: sink=MemoryTable (partitions=8) 02)--ProjectionExec: expr=[a1@0 as a1, a2@1 as a2] 03)----SortPreservingMergeExec: [c1@2 ASC NULLS LAST] -04)------ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as a1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as a2, c1@0 as c1] -05)--------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -06)----------SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST] +04)------ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as a1, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as a2, c1@0 as c1] +05)--------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +06)----------SortExec: expr=[c1@0 ASC NULLS LAST, c9@2 ASC NULLS LAST], preserve_partitioning=[true] 07)------------CoalesceBatchesExec: target_batch_size=8192 08)--------------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 09)----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 10)------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c4, c9], has_header=true -query II +query I INSERT INTO table_without_values SELECT SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a1, COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a2 @@ -218,10 +218,10 @@ logical_plan 04)------TableScan: aggregate_test_100 projection=[c1] physical_plan 01)DataSinkExec: sink=MemoryTable (partitions=1) -02)--SortExec: expr=[c1@0 ASC NULLS LAST] +02)--SortExec: expr=[c1@0 ASC NULLS LAST], preserve_partitioning=[false] 03)----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true -query T +query I insert into table_without_values select c1 from aggregate_test_100 order by c1; ---- 100 @@ -239,12 +239,12 @@ drop table table_without_values; statement ok CREATE TABLE table_without_values(id BIGINT, name varchar); -query IT +query I insert into table_without_values(id, name) values(1, 'foo'); ---- 1 -query IT +query I insert into table_without_values(name, id) values('bar', 2); ---- 1 @@ -259,7 +259,7 @@ statement error Error during planning: Column count doesn't match insert query! insert into table_without_values(id) values(4, 'zoo'); # insert NULL values for the missing column (name) -query IT +query I insert into table_without_values(id) values(4); ---- 1 @@ -279,18 +279,18 @@ drop table table_without_values; statement ok CREATE TABLE table_without_values(field1 BIGINT NOT NULL, field2 BIGINT NULL); -query II +query I insert into table_without_values values(1, 100); ---- 1 -query II +query I insert into table_without_values values(2, NULL); ---- 1 # insert NULL values for the missing column (field2) -query II +query I insert into table_without_values(field1) values(3); ---- 1 @@ -363,7 +363,7 @@ create table test_column_defaults( e timestamp default now() ) -query IIITP +query I insert into test_column_defaults values(1, 10, 100, 'ABC', now()) ---- 1 @@ -371,7 +371,7 @@ insert into test_column_defaults values(1, 10, 100, 'ABC', now()) statement error DataFusion error: Execution error: Invalid batch column at '1' has null but schema specifies non-nullable insert into test_column_defaults(a) values(2) -query IIITP +query I insert into test_column_defaults(b) values(20) ---- 1 @@ -383,7 +383,7 @@ select a,b,c,d from test_column_defaults NULL 20 500 default_text # fill the timestamp column with default value `now()` again, it should be different from the previous one -query IIITP +query I insert into test_column_defaults(a, b, c, d) values(2, 20, 200, 'DEF') ---- 1 @@ -417,7 +417,7 @@ create table test_column_defaults( e timestamp default now() ) as values(1, 10, 100, 'ABC', now()) -query IIITP +query I insert into test_column_defaults(b) values(20) ---- 1 diff --git a/datafusion/sqllogictest/test_files/insert_to_external.slt b/datafusion/sqllogictest/test_files/insert_to_external.slt index 2c7af6abe47c..35decd728eed 100644 --- a/datafusion/sqllogictest/test_files/insert_to_external.slt +++ b/datafusion/sqllogictest/test_files/insert_to_external.slt @@ -37,8 +37,8 @@ CREATE EXTERNAL TABLE aggregate_test_100 ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); statement ok @@ -60,7 +60,7 @@ STORED AS parquet LOCATION 'test_files/scratch/insert_to_external/parquet_types_partitioned/' PARTITIONED BY (b); -query TT +query I insert into dictionary_encoded_parquet_partitioned select * from dictionary_encoded_values ---- @@ -81,7 +81,7 @@ STORED AS arrow LOCATION 'test_files/scratch/insert_to_external/arrow_dict_partitioned/' PARTITIONED BY (b); -query TT +query I insert into dictionary_encoded_arrow_partitioned select * from dictionary_encoded_values ---- @@ -126,11 +126,11 @@ logical_plan 03)----Values: (Int64(5), Int64(1)), (Int64(4), Int64(2)), (Int64(7), Int64(7)), (Int64(7), Int64(8)), (Int64(7), Int64(9))... physical_plan 01)DataSinkExec: sink=CsvSink(file_groups=[]) -02)--SortExec: expr=[a@0 ASC NULLS LAST,b@1 DESC] +02)--SortExec: expr=[a@0 ASC NULLS LAST, b@1 DESC], preserve_partitioning=[false] 03)----ProjectionExec: expr=[column1@0 as a, column2@1 as b] 04)------ValuesExec -query II +query I INSERT INTO ordered_insert_test values (5, 1), (4, 2), (7,7), (7,8), (7,9), (7,10), (3, 3), (2, 4), (1, 5); ---- 9 @@ -158,7 +158,7 @@ LOCATION 'test_files/scratch/insert_to_external/insert_to_partitioned/' PARTITIONED BY (a, b); #note that partitioned cols are moved to the end so value tuples are (c, a, b) -query ITT +query I INSERT INTO partitioned_insert_test values (1, 10, 100), (1, 10, 200), (1, 20, 100), (1, 20, 200), (2, 20, 100), (2, 20, 200); ---- 6 @@ -192,7 +192,7 @@ STORED AS csv LOCATION 'test_files/scratch/insert_to_external/insert_to_partitioned' PARTITIONED BY (a string, b string); -query ITT +query I INSERT INTO partitioned_insert_test_hive VALUES (3,30,300); ---- 1 @@ -216,7 +216,7 @@ STORED AS json LOCATION 'test_files/scratch/insert_to_external/insert_to_partitioned_json/' PARTITIONED BY (a); -query TT +query I INSERT INTO partitioned_insert_test_json values (1, 2), (3, 4), (5, 6), (1, 2), (3, 4), (5, 6); ---- 6 @@ -250,7 +250,7 @@ STORED AS parquet LOCATION 'test_files/scratch/insert_to_external/insert_to_partitioned_pq/' PARTITIONED BY (a); -query IT +query I INSERT INTO partitioned_insert_test_pq values (1, 2), (3, 4), (5, 6), (1, 2), (3, 4), (5, 6); ---- 6 @@ -296,12 +296,12 @@ single_file_test(a bigint, b bigint) STORED AS csv LOCATION 'test_files/scratch/insert_to_external/single_csv_table.csv'; -query II +query I INSERT INTO single_file_test values (1, 2), (3, 4); ---- 2 -query II +query I INSERT INTO single_file_test values (4, 5), (6, 7); ---- 2 @@ -320,7 +320,7 @@ directory_test(a bigint, b bigint) STORED AS parquet LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q0/'; -query II +query I INSERT INTO directory_test values (1, 2), (3, 4); ---- 2 @@ -347,24 +347,24 @@ ORDER by c1 ---- logical_plan 01)Dml: op=[Insert Into] table=[table_without_values] -02)--Projection: SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field2 +02)--Projection: sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field1, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field2 03)----Sort: aggregate_test_100.c1 ASC NULLS LAST -04)------Projection: SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, aggregate_test_100.c1 -05)--------WindowAggr: windowExpr=[[SUM(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] +04)------Projection: sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, aggregate_test_100.c1 +05)--------WindowAggr: windowExpr=[[sum(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] 06)----------TableScan: aggregate_test_100 projection=[c1, c4, c9] physical_plan 01)DataSinkExec: sink=ParquetSink(file_groups=[]) -02)--ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as field2] +02)--ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as field1, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as field2] 03)----SortPreservingMergeExec: [c1@2 ASC NULLS LAST] -04)------ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, c1@0 as c1] -05)--------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -06)----------SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST] +04)------ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, c1@0 as c1] +05)--------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +06)----------SortExec: expr=[c1@0 ASC NULLS LAST, c9@2 ASC NULLS LAST], preserve_partitioning=[true] 07)------------CoalesceBatchesExec: target_batch_size=8192 08)--------------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 09)----------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 10)------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c4, c9], has_header=true -query II +query I INSERT INTO table_without_values SELECT SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING), COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) @@ -411,15 +411,15 @@ FROM aggregate_test_100 ---- logical_plan 01)Dml: op=[Insert Into] table=[table_without_values] -02)--Projection: SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field2 -03)----WindowAggr: windowExpr=[[SUM(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] +02)--Projection: sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field1, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS field2 +03)----WindowAggr: windowExpr=[[sum(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] 04)------TableScan: aggregate_test_100 projection=[c1, c4, c9] physical_plan 01)DataSinkExec: sink=ParquetSink(file_groups=[]) 02)--CoalescePartitionsExec -03)----ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as field1, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as field2] -04)------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -05)--------SortExec: expr=[c1@0 ASC NULLS LAST,c9@2 ASC NULLS LAST] +03)----ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as field1, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as field2] +04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +05)--------SortExec: expr=[c1@0 ASC NULLS LAST, c9@2 ASC NULLS LAST], preserve_partitioning=[true] 06)----------CoalesceBatchesExec: target_batch_size=8192 07)------------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8 08)--------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1 @@ -427,7 +427,7 @@ physical_plan -query II +query I INSERT INTO table_without_values SELECT SUM(c4) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a1, COUNT(*) OVER(PARTITION BY c1 ORDER BY c9 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as a2 @@ -459,10 +459,10 @@ logical_plan 04)------TableScan: aggregate_test_100 projection=[c1] physical_plan 01)DataSinkExec: sink=ParquetSink(file_groups=[]) -02)--SortExec: expr=[c1@0 ASC NULLS LAST] +02)--SortExec: expr=[c1@0 ASC NULLS LAST], preserve_partitioning=[false] 03)----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true -query T +query I insert into table_without_values select c1 from aggregate_test_100 order by c1; ---- 100 @@ -484,12 +484,12 @@ table_without_values(id BIGINT, name varchar) STORED AS parquet LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q4/'; -query IT +query I insert into table_without_values(id, name) values(1, 'foo'); ---- 1 -query IT +query I insert into table_without_values(name, id) values('bar', 2); ---- 1 @@ -504,7 +504,7 @@ statement error Error during planning: Column count doesn't match insert query! insert into table_without_values(id) values(4, 'zoo'); # insert NULL values for the missing column (name) -query IT +query I insert into table_without_values(id) values(4); ---- 1 @@ -526,18 +526,18 @@ table_without_values(field1 BIGINT NOT NULL, field2 BIGINT NULL) STORED AS parquet LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q5/'; -query II +query I insert into table_without_values values(1, 100); ---- 1 -query II +query I insert into table_without_values values(2, NULL); ---- 1 # insert NULL values for the missing column (field2) -query II +query I insert into table_without_values(field1) values(3); ---- 1 @@ -576,7 +576,7 @@ CREATE EXTERNAL TABLE test_column_defaults( LOCATION 'test_files/scratch/insert_to_external/external_parquet_table_q6/'; # fill in all column values -query IIITP +query I insert into test_column_defaults values(1, 10, 100, 'ABC', now()) ---- 1 @@ -584,7 +584,7 @@ insert into test_column_defaults values(1, 10, 100, 'ABC', now()) statement error DataFusion error: Execution error: Invalid batch column at '1' has null but schema specifies non-nullable insert into test_column_defaults(a) values(2) -query IIITP +query I insert into test_column_defaults(b) values(20) ---- 1 @@ -596,7 +596,7 @@ select a,b,c,d from test_column_defaults NULL 20 500 default_text # fill the timestamp column with default value `now()` again, it should be different from the previous one -query IIITP +query I insert into test_column_defaults(a, b, c, d) values(2, 20, 200, 'DEF') ---- 1 diff --git a/datafusion/sqllogictest/test_files/interval.slt b/datafusion/sqllogictest/test_files/interval.slt index eab4eed00269..db453adf12ba 100644 --- a/datafusion/sqllogictest/test_files/interval.slt +++ b/datafusion/sqllogictest/test_files/interval.slt @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. - # Use `interval` SQL literal syntax # the types should be the same: https://github.com/apache/datafusion/issues/5801 query TT @@ -45,278 +44,221 @@ Interval(MonthDayNano) Interval(MonthDayNano) query ? select interval '5' years ---- -0 years 0 mons 0 days 0 hours 0 mins 5.000000000 secs +5.000000000 secs # check all different kinds of intervals query ? select interval '5' year ---- -0 years 60 mons 0 days 0 hours 0 mins 0.000000000 secs +60 mons query ? select interval '5' month ---- -0 years 5 mons 0 days 0 hours 0 mins 0.000000000 secs +5 mons query ? select interval '5' months ---- -0 years 0 mons 0 days 0 hours 0 mins 5.000000000 secs +5.000000000 secs query ? select interval '5' week ---- -0 years 0 mons 35 days 0 hours 0 mins 0.000000000 secs +35 days query ? select interval '5' day ---- -0 years 0 mons 5 days 0 hours 0 mins 0.000000000 secs +5 days query ? select interval '5' hour ---- -0 years 0 mons 0 days 5 hours 0 mins 0.000000000 secs +5 hours ## This seems wrong (5 mons) query ? select interval '5' hours ---- -0 years 0 mons 0 days 0 hours 0 mins 5.000000000 secs +5.000000000 secs query ? select interval '5' minute ---- -0 years 0 mons 0 days 0 hours 5 mins 0.000000000 secs +5 mins query ? select interval '5' second ---- -0 years 0 mons 0 days 0 hours 0 mins 5.000000000 secs +5.000000000 secs query ? select interval '5' millisecond ---- -0 years 0 mons 0 days 0 hours 0 mins 0.005000000 secs +0.005000000 secs query ? select interval '5' milliseconds ---- -0 years 0 mons 0 days 0 hours 0 mins 0.005000000 secs +0.005000000 secs query ? select interval '5' microsecond ---- -0 years 0 mons 0 days 0 hours 0 mins 0.000005000 secs +0.000005000 secs query ? select interval '5' microseconds ---- -0 years 0 mons 0 days 0 hours 0 mins 0.000005000 secs +0.000005000 secs query ? select interval '5' nanosecond ---- -0 years 0 mons 0 days 0 hours 0 mins 0.000000005 secs +0.000000005 secs query ? select interval '5' nanoseconds ---- -0 years 0 mons 0 days 0 hours 0 mins 0.000000005 secs +0.000000005 secs query ? select interval '5 YEAR' ---- -0 years 60 mons 0 days 0 hours 0 mins 0.000000000 secs +60 mons query ? select interval '5 MONTH' ---- -0 years 5 mons 0 days 0 hours 0 mins 0.000000000 secs +5 mons query ? select interval '5 WEEK' ---- -0 years 0 mons 35 days 0 hours 0 mins 0.000000000 secs +35 days query ? select interval '5 DAY' ---- -0 years 0 mons 5 days 0 hours 0 mins 0.000000000 secs +5 days query ? select interval '5 HOUR' ---- -0 years 0 mons 0 days 5 hours 0 mins 0.000000000 secs +5 hours query ? select interval '5 HOURS' ---- -0 years 0 mons 0 days 5 hours 0 mins 0.000000000 secs +5 hours query ? select interval '5 MINUTE' ---- -0 years 0 mons 0 days 0 hours 5 mins 0.000000000 secs +5 mins query ? select interval '5 SECOND' ---- -0 years 0 mons 0 days 0 hours 0 mins 5.000000000 secs +5.000000000 secs query ? select interval '5 SECONDS' ---- -0 years 0 mons 0 days 0 hours 0 mins 5.000000000 secs +5.000000000 secs query ? select interval '5 MILLISECOND' ---- -0 years 0 mons 0 days 0 hours 0 mins 0.005000000 secs +0.005000000 secs query ? select interval '5 MILLISECONDS' ---- -0 years 0 mons 0 days 0 hours 0 mins 0.005000000 secs +0.005000000 secs query ? select interval '5 MICROSECOND' ---- -0 years 0 mons 0 days 0 hours 0 mins 0.000005000 secs +0.000005000 secs query ? select interval '5 MICROSECONDS' ---- -0 years 0 mons 0 days 0 hours 0 mins 0.000005000 secs +0.000005000 secs query ? select interval '5 NANOSECOND' ---- -0 years 0 mons 0 days 0 hours 0 mins 0.000000005 secs +0.000000005 secs query ? select interval '5 NANOSECONDS' ---- -0 years 0 mons 0 days 0 hours 0 mins 0.000000005 secs +0.000000005 secs query ? select interval '5 YEAR 5 MONTH 5 DAY 5 HOUR 5 MINUTE 5 SECOND 5 MILLISECOND 5 MICROSECOND 5 NANOSECOND' ---- -0 years 65 mons 5 days 5 hours 5 mins 5.005005005 secs - -# Interval with string literal addition -query ? -select interval '1 month' + '1 month' ----- -0 years 2 mons 0 days 0 hours 0 mins 0.000000000 secs - -# Interval with string literal addition and leading field -query ? -select interval '1' + '1' month ----- -0 years 2 mons 0 days 0 hours 0 mins 0.000000000 secs +65 mons 5 days 5 hours 5 mins 5.005005005 secs -# Interval with nested string literal addition +# Interval mega nested literal addition query ? -select interval '1 month' + '1 month' + '1 month' +select interval '1 year' + interval '1 month' + interval '1 day' + interval '1 hour' + interval '1 minute' + interval '1 second' + interval '1 millisecond' + interval '1 microsecond' + interval '1 nanosecond' ---- -0 years 3 mons 0 days 0 hours 0 mins 0.000000000 secs - -# Interval with nested string literal addition and leading field -query ? -select interval '1' + '1' + '1' month ----- -0 years 3 mons 0 days 0 hours 0 mins 0.000000000 secs - -# Interval mega nested string literal addition -query ? -select interval '1 year' + '1 month' + '1 day' + '1 hour' + '1 minute' + '1 second' + '1 millisecond' + '1 microsecond' + '1 nanosecond' ----- -0 years 13 mons 1 days 1 hours 1 mins 1.001001001 secs +13 mons 1 days 1 hours 1 mins 1.001001001 secs # Interval with string literal subtraction query ? -select interval '1 month' - '1 day'; +select interval '1 month' - interval '1 day'; ---- -0 years 1 mons -1 days 0 hours 0 mins 0.000000000 secs - -# Interval with string literal subtraction and leading field -query ? -select interval '5' - '1' - '2' year; ----- -0 years 24 mons 0 days 0 hours 0 mins 0.000000000 secs +1 mons -1 days # Interval with nested string literal subtraction query ? -select interval '1 month' - '1 day' - '1 hour'; +select interval '1 month' - interval '1 day' - interval '1 hour'; ---- -0 years 1 mons -1 days -1 hours 0 mins 0.000000000 secs - -# Interval with nested string literal subtraction and leading field -query ? -select interval '10' - '1' - '1' month; ----- -0 years 8 mons 0 days 0 hours 0 mins 0.000000000 secs +1 mons -1 days -1 hours # Interval mega nested string literal subtraction query ? -select interval '1 year' - '1 month' - '1 day' - '1 hour' - '1 minute' - '1 second' - '1 millisecond' - '1 microsecond' - '1 nanosecond' +select interval '1 year' - interval '1 month' - interval '1 day' - interval '1 hour' - interval '1 minute' - interval '1 second' - interval '1 millisecond' - interval '1 microsecond' - interval '1 nanosecond' ---- -0 years 11 mons -1 days -1 hours -1 mins -1.001001001 secs +11 mons -1 days -1 hours -1 mins -1.001001001 secs -# Interval with string literal negation and leading field +# Interval with nested literal negation query ? -select -interval '5' - '1' - '2' year; +select -interval '1 month' + interval '1 day' + interval '1 hour'; ---- -0 years -96 mons 0 days 0 hours 0 mins 0.000000000 secs +-1 mons 1 days 1 hours -# Interval with nested string literal negation +# Interval mega nested literal negation query ? -select -interval '1 month' + '1 day' + '1 hour'; +select -interval '1 year' - interval '1 month' - interval '1 day' - interval '1 hour' - interval '1 minute' - interval '1 second' - interval '1 millisecond' - interval '1 microsecond' - interval '1 nanosecond' ---- -0 years -1 mons 1 days 1 hours 0 mins 0.000000000 secs - -# Interval with nested string literal negation and leading field -query ? -select -interval '10' - '1' - '1' month; ----- -0 years -12 mons 0 days 0 hours 0 mins 0.000000000 secs - -# Interval mega nested string literal negation -query ? -select -interval '1 year' - '1 month' - '1 day' - '1 hour' - '1 minute' - '1 second' - '1 millisecond' - '1 microsecond' - '1 nanosecond' ----- -0 years -13 mons -1 days -1 hours -1 mins -1.001001001 secs +-13 mons -1 days -1 hours -1 mins -1.001001001 secs # Interval string literal + date query D -select interval '1 month' + '1 day' + '2012-01-01'::date; ----- -2012-02-02 - -# Interval string literal parenthesized + date -query D -select ( interval '1 month' + '1 day' ) + '2012-01-01'::date; +select interval 1 month + interval 1 day + '2012-01-01'::date; ---- 2012-02-02 # Interval nested string literal + date query D -select interval '1 year' + '1 month' + '1 day' + '2012-01-01'::date +select interval 1 year + interval 1 month + interval 1 day + '2012-01-01'::date ---- 2013-02-02 # Interval nested string literal subtraction + date query D -select interval '1 year' - '1 month' + '1 day' + '2012-01-01'::date +select interval 1 year - interval 1 month + interval 1 day + '2012-01-01'::date ---- 2012-12-02 - - - # Use interval SQL type query TT select @@ -325,7 +267,7 @@ select ---- Interval(MonthDayNano) Interval(MonthDayNano) -# cast with explicit cast sytax +# cast with explicit cast syntax query TT select arrow_typeof(cast ('5 months' as interval)), @@ -343,7 +285,7 @@ select arrow_typeof(i) from t; ---- -0 years 0 mons 5 days 0 hours 0 mins 0.000000003 secs Interval(MonthDayNano) +5 days 0.000000003 secs Interval(MonthDayNano) statement ok @@ -359,8 +301,8 @@ insert into t values ('6 days 7 nanoseconds'::interval) query ? rowsort select -i from t order by 1; ---- -0 years 0 mons -5 days 0 hours 0 mins -0.000000003 secs -0 years 0 mons -6 days 0 hours 0 mins -0.000000007 secs +-5 days -0.000000003 secs +-6 days -0.000000007 secs query ?T rowsort select @@ -368,8 +310,8 @@ select arrow_typeof(i) from t; ---- -0 years 0 mons 5 days 0 hours 0 mins 0.000000003 secs Interval(MonthDayNano) -0 years 0 mons 6 days 0 hours 0 mins 0.000000007 secs Interval(MonthDayNano) +5 days 0.000000003 secs Interval(MonthDayNano) +6 days 0.000000007 secs Interval(MonthDayNano) statement ok drop table t; @@ -544,6 +486,31 @@ select i - d from t; query error DataFusion error: Error during planning: Cannot coerce arithmetic expression Interval\(MonthDayNano\) \- Timestamp\(Nanosecond, None\) to valid types select i - ts from t; +# interval unit abreiviation and plurals +query ? +select interval '1s' +---- +1.000000000 secs + +query ? +select '1s'::interval +---- +1.000000000 secs + +query ? +select interval'1sec' +---- +1.000000000 secs + +query ? +select interval '1ms' +---- +0.001000000 secs + +query ? +select interval '1 y' + interval '1 year' +---- +24 mons # interval (scalar) + date / timestamp (array) query D @@ -560,6 +527,22 @@ select '1 month'::interval + ts from t; 2000-02-01T12:11:10 2000-03-01T00:00:00 +# trailing extra unit, this matches PostgreSQL +query ? +select interval '5 day 1' hour +---- +5 days 1 hours + +# trailing extra unit, this matches PostgreSQL +query ? +select interval '5 day 0' hour +---- +5 days + +# This is interpreted as "0 hours" with PostgreSQL, should be fixed with +query error DataFusion error: Arrow error: Parser error: Invalid input syntax for type interval: "5 day HOUR" +SELECT interval '5 day' hour + # expected error interval (scalar) - date / timestamp (array) query error DataFusion error: Error during planning: Cannot coerce arithmetic expression Interval\(MonthDayNano\) \- Date32 to valid types select '1 month'::interval - d from t; diff --git a/datafusion/sqllogictest/test_files/interval_mysql.slt b/datafusion/sqllogictest/test_files/interval_mysql.slt new file mode 100644 index 000000000000..c05bb007e5f1 --- /dev/null +++ b/datafusion/sqllogictest/test_files/interval_mysql.slt @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Use `interval` SQL literal syntax with MySQL dialect + +# this should fail with the generic dialect +query error DataFusion error: Error during planning: Cannot coerce arithmetic expression Interval\(MonthDayNano\) \+ Utf8 to valid types +select interval '1' + '1' month + +statement ok +set datafusion.sql_parser.dialect = 'Mysql'; + +# Interval with string literal addition and leading field +query ? +select interval '1' + '1' month +---- +2 mons + +# Interval with nested string literal addition +query ? +select interval 1 + 1 + 1 month +---- +3 mons + +# Interval with nested string literal addition and leading field +query ? +select interval '1' + '1' + '1' month +---- +3 mons + +# Interval with string literal subtraction and leading field +query ? +select interval '5' - '1' - '2' year; +---- +24 mons + +# Interval with nested string literal subtraction and leading field +query ? +select interval '10' - '1' - '1' month; +---- +8 mons + +# Interval with string literal negation and leading field +query ? +select -interval '5' - '1' - '2' year; +---- +-96 mons + +# Interval with nested string literal negation and leading field +query ? +select -interval '10' - '1' - '1' month; +---- +-12 mons + +# revert to standard dialect +statement ok +set datafusion.sql_parser.dialect = 'Generic'; diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index 6732d3e9108b..39f903a58714 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -671,7 +671,7 @@ query TT explain select * from t1 inner join t2 on true; ---- logical_plan -01)CrossJoin: +01)Cross Join: 02)--TableScan: t1 projection=[t1_id, t1_name, t1_int] 03)--TableScan: t2 projection=[t2_id, t2_name, t2_int] physical_plan @@ -793,3 +793,436 @@ DROP TABLE companies statement ok DROP TABLE leads + +#### +## Test ON clause predicates are not pushed past join for OUTER JOINs +#### + + +# create tables +statement ok +CREATE TABLE employees(emp_id INT, name VARCHAR); + +statement ok +CREATE TABLE department(emp_id INT, dept_name VARCHAR); + +statement ok +INSERT INTO employees (emp_id, name) VALUES (1, 'Alice'), (2, 'Bob'), (3, 'Carol'); + +statement ok +INSERT INTO department (emp_id, dept_name) VALUES (1, 'HR'), (3, 'Engineering'), (4, 'Sales'); + +# Can not push the ON filter below an OUTER JOIN +query TT +EXPLAIN SELECT e.emp_id, e.name, d.dept_name +FROM employees AS e +LEFT JOIN department AS d +ON (e.name = 'Alice' OR e.name = 'Bob'); +---- +logical_plan +01)Left Join: Filter: e.name = Utf8("Alice") OR e.name = Utf8("Bob") +02)--SubqueryAlias: e +03)----TableScan: employees projection=[emp_id, name] +04)--SubqueryAlias: d +05)----TableScan: department projection=[dept_name] +physical_plan +01)ProjectionExec: expr=[emp_id@1 as emp_id, name@2 as name, dept_name@0 as dept_name] +02)--NestedLoopJoinExec: join_type=Right, filter=name@0 = Alice OR name@0 = Bob +03)----MemoryExec: partitions=1, partition_sizes=[1] +04)----MemoryExec: partitions=1, partition_sizes=[1] + +query ITT +SELECT e.emp_id, e.name, d.dept_name +FROM employees AS e +LEFT JOIN department AS d +ON (e.name = 'Alice' OR e.name = 'Bob'); +---- +1 Alice HR +1 Alice Engineering +1 Alice Sales +2 Bob HR +2 Bob Engineering +2 Bob Sales +3 Carol NULL + +# neither RIGHT OUTER JOIN +query ITT +SELECT e.emp_id, e.name, d.dept_name +FROM department AS d +RIGHT JOIN employees AS e +ON (e.name = 'Alice' OR e.name = 'Bob'); +---- +1 Alice HR +1 Alice Engineering +1 Alice Sales +2 Bob HR +2 Bob Engineering +2 Bob Sales +3 Carol NULL + +# neither FULL OUTER JOIN +query ITT +SELECT e.emp_id, e.name, d.dept_name +FROM department AS d +FULL JOIN employees AS e +ON (e.name = 'Alice' OR e.name = 'Bob'); +---- +1 Alice HR +1 Alice Engineering +1 Alice Sales +2 Bob HR +2 Bob Engineering +2 Bob Sales +3 Carol NULL + +query ITT +SELECT e.emp_id, e.name, d.dept_name +FROM employees e +LEFT JOIN department d +ON (e.name = 'NotExist1' OR e.name = 'NotExist2'); +---- +1 Alice NULL +2 Bob NULL +3 Carol NULL + +query ITT +SELECT e.emp_id, e.name, d.dept_name +FROM employees e +LEFT JOIN department d +ON (e.name = 'Alice' OR e.name = 'NotExist'); +---- +1 Alice HR +1 Alice Engineering +1 Alice Sales +2 Bob NULL +3 Carol NULL + +# Can push the ON filter below the JOIN for INNER JOIN (expect to see a filter below the join) +query TT +EXPLAIN SELECT e.emp_id, e.name, d.dept_name +FROM employees AS e +JOIN department AS d +ON (e.name = 'Alice' OR e.name = 'Bob'); +---- +logical_plan +01)Cross Join: +02)--SubqueryAlias: e +03)----Filter: employees.name = Utf8("Alice") OR employees.name = Utf8("Bob") +04)------TableScan: employees projection=[emp_id, name] +05)--SubqueryAlias: d +06)----TableScan: department projection=[dept_name] +physical_plan +01)CrossJoinExec +02)--CoalesceBatchesExec: target_batch_size=8192 +03)----FilterExec: name@1 = Alice OR name@1 = Bob +04)------MemoryExec: partitions=1, partition_sizes=[1] +05)--MemoryExec: partitions=1, partition_sizes=[1] + +# expect no row for Carol +query ITT +SELECT e.emp_id, e.name, d.dept_name +FROM employees AS e +JOIN department AS d +ON (e.name = 'Alice' OR e.name = 'Bob'); +---- +1 Alice HR +1 Alice Engineering +1 Alice Sales +2 Bob HR +2 Bob Engineering +2 Bob Sales + +# OR conditions on Filter (not join filter) +query ITT +SELECT e.emp_id, e.name, d.dept_name +FROM employees AS e +LEFT JOIN department AS d +ON e.emp_id = d.emp_id +WHERE (e.name = 'Alice' OR e.name = 'Carol'); +---- +1 Alice HR +3 Carol Engineering + +# Push down OR conditions on Filter through LEFT JOIN if possible +query TT +EXPLAIN SELECT e.emp_id, e.name, d.dept_name +FROM employees AS e +LEFT JOIN department AS d +ON e.emp_id = d.emp_id +WHERE ((dept_name != 'Engineering' AND e.name = 'Alice') OR (name != 'Alice' AND e.name = 'Carol')); +---- +logical_plan +01)Filter: d.dept_name != Utf8("Engineering") AND e.name = Utf8("Alice") OR e.name != Utf8("Alice") AND e.name = Utf8("Carol") +02)--Projection: e.emp_id, e.name, d.dept_name +03)----Left Join: e.emp_id = d.emp_id +04)------SubqueryAlias: e +05)--------Filter: employees.name = Utf8("Alice") OR employees.name != Utf8("Alice") AND employees.name = Utf8("Carol") +06)----------TableScan: employees projection=[emp_id, name] +07)------SubqueryAlias: d +08)--------TableScan: department projection=[emp_id, dept_name] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: dept_name@2 != Engineering AND name@1 = Alice OR name@1 != Alice AND name@1 = Carol +03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +04)------CoalesceBatchesExec: target_batch_size=8192 +05)--------HashJoinExec: mode=CollectLeft, join_type=Left, on=[(emp_id@0, emp_id@0)], projection=[emp_id@0, name@1, dept_name@3] +06)----------CoalesceBatchesExec: target_batch_size=8192 +07)------------FilterExec: name@1 = Alice OR name@1 != Alice AND name@1 = Carol +08)--------------MemoryExec: partitions=1, partition_sizes=[1] +09)----------MemoryExec: partitions=1, partition_sizes=[1] + +query ITT +SELECT e.emp_id, e.name, d.dept_name +FROM employees AS e +LEFT JOIN department AS d +ON e.emp_id = d.emp_id +WHERE ((dept_name != 'Engineering' AND e.name = 'Alice') OR (name != 'Alice' AND e.name = 'Carol')); +---- +1 Alice HR +3 Carol Engineering + +statement ok +DROP TABLE employees + +statement ok +DROP TABLE department + + +statement ok +CREATE TABLE t1 (v0 BIGINT) AS VALUES (-503661263); + +statement ok +CREATE TABLE t2 (v0 DOUBLE) AS VALUES (-1.663563947387); + +statement ok +CREATE TABLE t3 (v0 DOUBLE) AS VALUES (0.05112015193508901); + +# Test issue: https://github.com/apache/datafusion/issues/11269 +query RR +SELECT t3.v0, t2.v0 FROM t1,t2,t3 WHERE t3.v0 >= t1.v0; +---- +0.051120151935 -1.663563947387 + +# Test issue: https://github.com/apache/datafusion/issues/11414 +query IRR +SELECT * FROM t1 INNER JOIN t2 ON NULL RIGHT JOIN t3 ON TRUE; +---- +NULL NULL 0.051120151935 + +# ON expression must be boolean type +query error DataFusion error: type_coercion\ncaused by\nError during planning: Join condition must be boolean type, but got Utf8 +SELECT * FROM t1 INNER JOIN t2 ON 'TRUE' + +statement ok +DROP TABLE t1; + +statement ok +DROP TABLE t2; + +statement ok +DROP TABLE t3; + + +statement ok +CREATE TABLE t0 (v1 BOOLEAN) AS VALUES (false), (null); + +statement ok +CREATE TABLE t1 (v1 BOOLEAN) AS VALUES (false), (null), (false); + +statement ok +CREATE TABLE t2 (v1 BOOLEAN) AS VALUES (false), (true); + +# Test issue: https://github.com/apache/datafusion/issues/11275 +query BB +SELECT t2.v1, t1.v1 FROM t0, t1, t2 WHERE t2.v1 IS DISTINCT FROM t0.v1 ORDER BY 1,2; +---- +false false +false false +false NULL +true false +true false +true false +true false +true NULL +true NULL + +# Test issue: https://github.com/apache/datafusion/issues/11621 +query BB +SELECT * FROM t1 JOIN t2 ON t1.v1 = t2.v1 WHERE (t1.v1 == t2.v1) OR t1.v1; +---- +false false +false false + +query BB +SELECT * FROM t1 JOIN t2 ON t1.v1 = t2.v1 WHERE t1.v1 OR (t1.v1 == t2.v1); +---- +false false +false false + +statement ok +DROP TABLE t0; + +statement ok +DROP TABLE t1; + +statement ok +DROP TABLE t2; + +# Join Using Issue with Cast Expr +# Found issue: https://github.com/apache/datafusion/issues/11412 + +statement ok +/*DML*/CREATE TABLE t60(v0 BIGINT, v1 BIGINT, v2 BOOLEAN, v3 BOOLEAN); + +statement ok +/*DML*/CREATE TABLE t0(v0 DOUBLE, v1 BIGINT); + +statement ok +/*DML*/CREATE TABLE t1(v0 DOUBLE); + +query I +SELECT COUNT(*) +FROM t1 +NATURAL JOIN t60 +INNER JOIN t0 +ON t60.v1 = t0.v0 +AND t0.v1 > t60.v1; +---- +0 + +query I +SELECT COUNT(*) +FROM t1 +JOIN t60 +USING (v0) +INNER JOIN t0 +ON t60.v1 = t0.v0 +AND t0.v1 > t60.v1; +---- +0 + +statement ok +DROP TABLE t60; + +statement ok +DROP TABLE t0; + +statement ok +DROP TABLE t1; + +# Test SQLancer issue: https://github.com/apache/datafusion/issues/11704 +query II +WITH + t1 AS (SELECT NULL::int AS a), + t2 AS (SELECT NULL::int AS a) +SELECT * FROM + (SELECT * FROM t1 CROSS JOIN t2) +WHERE t1.a == t2.a + AND t1.a + t2.a IS NULL; +---- + +# Similar to above test case, but without the equality predicate +query II +WITH + t1 AS (SELECT NULL::int AS a), + t2 AS (SELECT NULL::int AS a) +SELECT * FROM + (SELECT * FROM t1 CROSS JOIN t2) +WHERE t1.a + t2.a IS NULL; +---- +NULL NULL + +statement ok +CREATE TABLE t5(v0 BIGINT, v1 STRING, v2 BIGINT, v3 STRING, v4 BOOLEAN); + +statement ok +CREATE TABLE t1(v0 BIGINT, v1 STRING); + +statement ok +CREATE TABLE t0(v0 BIGINT, v1 DOUBLE); + +query TT +explain SELECT * +FROM t1 +NATURAL JOIN t5 +INNER JOIN t0 ON (t0.v1 + t5.v0) > 0 +WHERE t0.v1 = t1.v0; +---- +logical_plan +01)Projection: t1.v0, t1.v1, t5.v2, t5.v3, t5.v4, t0.v0, t0.v1 +02)--Inner Join: CAST(t1.v0 AS Float64) = t0.v1 Filter: t0.v1 + CAST(t5.v0 AS Float64) > Float64(0) +03)----Projection: t1.v0, t1.v1, t5.v0, t5.v2, t5.v3, t5.v4 +04)------Inner Join: t1.v0 = t5.v0, t1.v1 = t5.v1 +05)--------TableScan: t1 projection=[v0, v1] +06)--------TableScan: t5 projection=[v0, v1, v2, v3, v4] +07)----TableScan: t0 projection=[v0, v1] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(CAST(t1.v0 AS Float64)@6, v1@1)], filter=v1@1 + CAST(v0@0 AS Float64) > 0, projection=[v0@0, v1@1, v2@3, v3@4, v4@5, v0@7, v1@8] +03)----CoalescePartitionsExec +04)------ProjectionExec: expr=[v0@0 as v0, v1@1 as v1, v0@2 as v0, v2@3 as v2, v3@4 as v3, v4@5 as v4, CAST(v0@0 AS Float64) as CAST(t1.v0 AS Float64)] +05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +06)----------CoalesceBatchesExec: target_batch_size=8192 +07)------------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(v0@0, v0@0), (v1@1, v1@1)], projection=[v0@0, v1@1, v0@2, v2@4, v3@5, v4@6] +08)--------------MemoryExec: partitions=1, partition_sizes=[0] +09)--------------MemoryExec: partitions=1, partition_sizes=[0] +10)----MemoryExec: partitions=1, partition_sizes=[0] + + + +statement ok +drop table t5; + +statement ok +drop table t1; + +statement ok +drop table t0; + +# Test decorrelate query with the uppercase table name and column name +statement ok +create table "T1"("C1" int, "C2" int); + +statement ok +create table "T2"("C1" int, "C3" int); + +statement ok +select "C1" from "T1" where not exists (select 1 from "T2" where "T1"."C1" = "T2"."C1") + +statement ok +create table t1(c1 int, c2 int); + +statement ok +create table t2(c1 int, c3 int); + +statement ok +select "C1" from (select c1 as "C1", c2 as "C2" from t1) as "T1" where not exists (select 1 from (select c1 as "C1", c3 as "C3" from t2) as "T2" where "T1"."C1" = "T2"."C1") + +statement ok +drop table "T1"; + +statement ok +drop table "T2"; + +statement ok +drop table t1; + +statement ok +drop table t2; + +# Test SQLancer issue: https://github.com/apache/datafusion/issues/12337 +statement ok +create table t1(v1 int) as values(100); + +## Query with Ambiguous column reference +query error DataFusion error: Schema error: Schema contains duplicate qualified field name t1\.v1 +select count(*) +from t1 +right outer join t1 +on t1.v1 > 0; + +query error DataFusion error: Schema error: Schema contains duplicate qualified field name t1\.v1 +select t1.v1 from t1 join t1 using(v1) cross join (select struct('foo' as v1) as t1); + +statement ok +drop table t1; diff --git a/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt b/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt index 60a14f78bdf5..cf897d628da5 100644 --- a/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt +++ b/datafusion/sqllogictest/test_files/join_disable_repartition_joins.slt @@ -34,9 +34,9 @@ CREATE EXTERNAL TABLE annotated_data ( d INTEGER ) STORED AS CSV -WITH HEADER ROW WITH ORDER (a ASC, b ASC, c ASC) -LOCATION '../core/tests/data/window_2.csv'; +LOCATION '../core/tests/data/window_2.csv' +OPTIONS ('format.has_header' 'true'); query TT EXPLAIN SELECT t2.a @@ -46,22 +46,20 @@ EXPLAIN SELECT t2.a LIMIT 5 ---- logical_plan -01)Limit: skip=0, fetch=5 -02)--Sort: t2.a ASC NULLS LAST, fetch=5 -03)----Projection: t2.a -04)------Inner Join: t1.c = t2.c -05)--------SubqueryAlias: t1 -06)----------TableScan: annotated_data projection=[c] -07)--------SubqueryAlias: t2 -08)----------TableScan: annotated_data projection=[a, c] +01)Sort: t2.a ASC NULLS LAST, fetch=5 +02)--Projection: t2.a +03)----Inner Join: t1.c = t2.c +04)------SubqueryAlias: t1 +05)--------TableScan: annotated_data projection=[c] +06)------SubqueryAlias: t2 +07)--------TableScan: annotated_data projection=[a, c] physical_plan -01)GlobalLimitExec: skip=0, fetch=5 -02)--SortPreservingMergeExec: [a@0 ASC NULLS LAST], fetch=5 -03)----CoalesceBatchesExec: target_batch_size=8192 -04)------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(c@0, c@1)], projection=[a@1] -05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c], has_header=true -06)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -07)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_ordering=[a@0 ASC NULLS LAST], has_header=true +01)SortPreservingMergeExec: [a@0 ASC NULLS LAST], fetch=5 +02)--CoalesceBatchesExec: target_batch_size=8192, fetch=5 +03)----HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(c@0, c@1)], projection=[a@1] +04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c], has_header=true +05)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +06)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_ordering=[a@0 ASC NULLS LAST], has_header=true # preserve_inner_join query IIII nosort @@ -87,26 +85,24 @@ EXPLAIN SELECT t2.a as a2, t2.b LIMIT 10 ---- logical_plan -01)Limit: skip=0, fetch=10 -02)--Sort: a2 ASC NULLS LAST, t2.b ASC NULLS LAST, fetch=10 -03)----Projection: t2.a AS a2, t2.b -04)------RightSemi Join: t1.d = t2.d, t1.c = t2.c -05)--------SubqueryAlias: t1 -06)----------TableScan: annotated_data projection=[c, d] -07)--------SubqueryAlias: t2 -08)----------Filter: annotated_data.d = Int32(3) -09)------------TableScan: annotated_data projection=[a, b, c, d], partial_filters=[annotated_data.d = Int32(3)] +01)Sort: a2 ASC NULLS LAST, t2.b ASC NULLS LAST, fetch=10 +02)--Projection: t2.a AS a2, t2.b +03)----RightSemi Join: t1.d = t2.d, t1.c = t2.c +04)------SubqueryAlias: t1 +05)--------TableScan: annotated_data projection=[c, d] +06)------SubqueryAlias: t2 +07)--------Filter: annotated_data.d = Int32(3) +08)----------TableScan: annotated_data projection=[a, b, c, d], partial_filters=[annotated_data.d = Int32(3)] physical_plan -01)GlobalLimitExec: skip=0, fetch=10 -02)--SortPreservingMergeExec: [a2@0 ASC NULLS LAST,b@1 ASC NULLS LAST], fetch=10 -03)----ProjectionExec: expr=[a@0 as a2, b@1 as b] -04)------CoalesceBatchesExec: target_batch_size=8192 -05)--------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(d@1, d@3), (c@0, c@2)], projection=[a@0, b@1] -06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], has_header=true -07)----------CoalesceBatchesExec: target_batch_size=8192 -08)------------FilterExec: d@3 = 3 -09)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -10)----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true +01)SortPreservingMergeExec: [a2@0 ASC NULLS LAST, b@1 ASC NULLS LAST], fetch=10 +02)--ProjectionExec: expr=[a@0 as a2, b@1 as b] +03)----CoalesceBatchesExec: target_batch_size=8192, fetch=10 +04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(d@1, d@3), (c@0, c@2)], projection=[a@0, b@1] +05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], has_header=true +06)--------CoalesceBatchesExec: target_batch_size=8192 +07)----------FilterExec: d@3 = 3 +08)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +09)--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true # preserve_right_semi_join query II nosort diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 5ef33307b521..93bb1f1f548e 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -53,6 +53,20 @@ AS VALUES (44, 'x', 3), (55, 'w', 3); +statement ok +CREATE TABLE join_t3(s3 struct) + AS VALUES + (NULL), + (struct(1)), + (struct(2)); + +statement ok +CREATE TABLE join_t4(s4 struct) + AS VALUES + (NULL), + (struct(2)), + (struct(3)); + # Left semi anti join statement ok @@ -1336,6 +1350,44 @@ physical_plan 10)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 11)------------MemoryExec: partitions=1, partition_sizes=[1] +# Join on struct +query TT +explain select join_t3.s3, join_t4.s4 +from join_t3 +inner join join_t4 on join_t3.s3 = join_t4.s4 +---- +logical_plan +01)Inner Join: join_t3.s3 = join_t4.s4 +02)--TableScan: join_t3 projection=[s3] +03)--TableScan: join_t4 projection=[s4] +physical_plan +01)CoalesceBatchesExec: target_batch_size=2 +02)--HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s3@0, s4@0)] +03)----CoalesceBatchesExec: target_batch_size=2 +04)------RepartitionExec: partitioning=Hash([s3@0], 2), input_partitions=2 +05)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +06)----------MemoryExec: partitions=1, partition_sizes=[1] +07)----CoalesceBatchesExec: target_batch_size=2 +08)------RepartitionExec: partitioning=Hash([s4@0], 2), input_partitions=2 +09)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +10)----------MemoryExec: partitions=1, partition_sizes=[1] + +query ?? +select join_t3.s3, join_t4.s4 +from join_t3 +inner join join_t4 on join_t3.s3 = join_t4.s4 +---- +{id: 2} {id: 2} + +# join with struct key and nulls +# Note that intersect or except applies `null_equals_null` as true for Join. +query ? +SELECT * FROM join_t3 +EXCEPT +SELECT * FROM join_t4 +---- +{id: 1} + query TT EXPLAIN select count(*) @@ -1343,15 +1395,15 @@ from (select * from join_t1 inner join join_t2 on join_t1.t1_id = join_t2.t2_id) group by t1_id ---- logical_plan -01)Projection: COUNT(*) -02)--Aggregate: groupBy=[[join_t1.t1_id]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +01)Projection: count(*) +02)--Aggregate: groupBy=[[join_t1.t1_id]], aggr=[[count(Int64(1)) AS count(*)]] 03)----Projection: join_t1.t1_id 04)------Inner Join: join_t1.t1_id = join_t2.t2_id 05)--------TableScan: join_t1 projection=[t1_id] 06)--------TableScan: join_t2 projection=[t2_id] physical_plan -01)ProjectionExec: expr=[COUNT(*)@1 as COUNT(*)] -02)--AggregateExec: mode=SinglePartitioned, gby=[t1_id@0 as t1_id], aggr=[COUNT(*)] +01)ProjectionExec: expr=[count(*)@1 as count(*)] +02)--AggregateExec: mode=SinglePartitioned, gby=[t1_id@0 as t1_id], aggr=[count(*)] 03)----CoalesceBatchesExec: target_batch_size=2 04)------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(t1_id@0, t2_id@0)], projection=[t1_id@0] 05)--------CoalesceBatchesExec: target_batch_size=2 @@ -1370,30 +1422,29 @@ from join_t1 inner join join_t2 on join_t1.t1_id = join_t2.t2_id ---- logical_plan -01)Projection: COUNT(alias1) AS COUNT(DISTINCT join_t1.t1_id) -02)--Aggregate: groupBy=[[]], aggr=[[COUNT(alias1)]] +01)Projection: count(alias1) AS count(DISTINCT join_t1.t1_id) +02)--Aggregate: groupBy=[[]], aggr=[[count(alias1)]] 03)----Aggregate: groupBy=[[join_t1.t1_id AS alias1]], aggr=[[]] 04)------Projection: join_t1.t1_id 05)--------Inner Join: join_t1.t1_id = join_t2.t2_id 06)----------TableScan: join_t1 projection=[t1_id] 07)----------TableScan: join_t2 projection=[t2_id] physical_plan -01)ProjectionExec: expr=[COUNT(alias1)@0 as COUNT(DISTINCT join_t1.t1_id)] -02)--AggregateExec: mode=Final, gby=[], aggr=[COUNT(alias1)] +01)ProjectionExec: expr=[count(alias1)@0 as count(DISTINCT join_t1.t1_id)] +02)--AggregateExec: mode=Final, gby=[], aggr=[count(alias1)] 03)----CoalescePartitionsExec -04)------AggregateExec: mode=Partial, gby=[], aggr=[COUNT(alias1)] -05)--------AggregateExec: mode=FinalPartitioned, gby=[alias1@0 as alias1], aggr=[] -06)----------AggregateExec: mode=Partial, gby=[t1_id@0 as alias1], aggr=[] -07)------------CoalesceBatchesExec: target_batch_size=2 -08)--------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(t1_id@0, t2_id@0)], projection=[t1_id@0] -09)----------------CoalesceBatchesExec: target_batch_size=2 -10)------------------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 -11)--------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -12)----------------------MemoryExec: partitions=1, partition_sizes=[1] -13)----------------CoalesceBatchesExec: target_batch_size=2 -14)------------------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 -15)--------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -16)----------------------MemoryExec: partitions=1, partition_sizes=[1] +04)------AggregateExec: mode=Partial, gby=[], aggr=[count(alias1)] +05)--------AggregateExec: mode=SinglePartitioned, gby=[t1_id@0 as alias1], aggr=[] +06)----------CoalesceBatchesExec: target_batch_size=2 +07)------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(t1_id@0, t2_id@0)], projection=[t1_id@0] +08)--------------CoalesceBatchesExec: target_batch_size=2 +09)----------------RepartitionExec: partitioning=Hash([t1_id@0], 2), input_partitions=2 +10)------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +11)--------------------MemoryExec: partitions=1, partition_sizes=[1] +12)--------------CoalesceBatchesExec: target_batch_size=2 +13)----------------RepartitionExec: partitioning=Hash([t2_id@0], 2), input_partitions=2 +14)------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +15)--------------------MemoryExec: partitions=1, partition_sizes=[1] statement ok set datafusion.explain.logical_plan_only = true; @@ -2010,7 +2061,7 @@ set datafusion.explain.logical_plan_only = false; statement ok set datafusion.execution.target_partitions = 4; -# Planning inner nested loop join +# Planning inner nested loop join # inputs are swapped due to inexact statistics + join reordering caused additional projection query TT @@ -2031,15 +2082,14 @@ physical_plan 01)ProjectionExec: expr=[t1_id@1 as t1_id, t2_id@0 as t2_id] 02)--NestedLoopJoinExec: join_type=Inner, filter=t1_id@0 > t2_id@1 03)----CoalescePartitionsExec -04)------ProjectionExec: expr=[t2_id@0 as t2_id] -05)--------CoalesceBatchesExec: target_batch_size=2 -06)----------FilterExec: t2_int@1 > 1 -07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -08)--------------MemoryExec: partitions=1, partition_sizes=[1] -09)----CoalesceBatchesExec: target_batch_size=2 -10)------FilterExec: t1_id@0 > 10 -11)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -12)----------MemoryExec: partitions=1, partition_sizes=[1] +04)------CoalesceBatchesExec: target_batch_size=2 +05)--------FilterExec: t2_int@1 > 1, projection=[t2_id@0] +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------MemoryExec: partitions=1, partition_sizes=[1] +08)----CoalesceBatchesExec: target_batch_size=2 +09)------FilterExec: t1_id@0 > 10 +10)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +11)----------MemoryExec: partitions=1, partition_sizes=[1] query II SELECT join_t1.t1_id, join_t2.t2_id @@ -2622,7 +2672,7 @@ logical_plan 05)----TableScan: hashjoin_datatype_table_t2 projection=[c1, c2, c3, c4] # hash_join_with_date32 -query DDR?DDR? rowsort +query DDRTDDRT rowsort select * from hashjoin_datatype_table_t1 t1 join hashjoin_datatype_table_t2 t2 on t1.c1 = t2.c1 ---- 1970-01-02 1970-01-02T00:00:00 1.23 abc 1970-01-02 1970-01-02T00:00:00 -123.12 abc @@ -2641,7 +2691,7 @@ logical_plan 05)----TableScan: hashjoin_datatype_table_t2 projection=[c1, c2, c3, c4] # hash_join_with_date64 -query DDR?DDR? rowsort +query DDRTDDRT rowsort select * from hashjoin_datatype_table_t1 t1 left join hashjoin_datatype_table_t2 t2 on t1.c2 = t2.c2 ---- 1970-01-02 1970-01-02T00:00:00 1.23 abc 1970-01-02 1970-01-02T00:00:00 -123.12 abc @@ -2662,7 +2712,7 @@ logical_plan 05)----TableScan: hashjoin_datatype_table_t1 projection=[c1, c2, c3, c4] # hash_join_with_decimal -query DDR?DDR? rowsort +query DDRTDDRT rowsort select * from hashjoin_datatype_table_t1 t1 right join hashjoin_datatype_table_t1 t2 on t1.c3 = t2.c3 ---- 1970-01-02 1970-01-02T00:00:00 1.23 abc 1970-01-02 1970-01-02T00:00:00 1.23 abc @@ -2682,7 +2732,7 @@ logical_plan 05)----TableScan: hashjoin_datatype_table_t1 projection=[c1, c2, c3, c4] # hash_join_with_dictionary -query DDR?DDR? rowsort +query DDRTDDRT rowsort select * from hashjoin_datatype_table_t1 t1 join hashjoin_datatype_table_t2 t2 on t1.c4 = t2.c4 ---- 1970-01-02 1970-01-02T00:00:00 1.23 abc 1970-01-02 1970-01-02T00:00:00 -123.12 abc @@ -2721,19 +2771,19 @@ logical_plan 05)----TableScan: hashjoin_datatype_table_t2 projection=[c1, c2, c3, c4] physical_plan 01)SortMergeJoin: join_type=Inner, on=[(c1@0, c1@0)] -02)--SortExec: expr=[c1@0 ASC] +02)--SortExec: expr=[c1@0 ASC], preserve_partitioning=[true] 03)----CoalesceBatchesExec: target_batch_size=2 04)------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 05)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 06)----------MemoryExec: partitions=1, partition_sizes=[1] -07)--SortExec: expr=[c1@0 ASC] +07)--SortExec: expr=[c1@0 ASC], preserve_partitioning=[true] 08)----CoalesceBatchesExec: target_batch_size=2 09)------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 10)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 11)----------MemoryExec: partitions=1, partition_sizes=[1] # sort_merge_join_on_date32 inner sort merge join on data type (Date32) -query DDR?DDR? rowsort +query DDRTDDRT rowsort select * from hashjoin_datatype_table_t1 t1 join hashjoin_datatype_table_t2 t2 on t1.c1 = t2.c1 ---- 1970-01-02 1970-01-02T00:00:00 1.23 abc 1970-01-02 1970-01-02T00:00:00 -123.12 abc @@ -2752,20 +2802,20 @@ logical_plan physical_plan 01)ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, c1@5 as c1, c2@6 as c2, c3@7 as c3, c4@8 as c4] 02)--SortMergeJoin: join_type=Right, on=[(CAST(t1.c3 AS Decimal128(10, 2))@4, c3@2)] -03)----SortExec: expr=[CAST(t1.c3 AS Decimal128(10, 2))@4 ASC] +03)----SortExec: expr=[CAST(t1.c3 AS Decimal128(10, 2))@4 ASC], preserve_partitioning=[true] 04)------CoalesceBatchesExec: target_batch_size=2 05)--------RepartitionExec: partitioning=Hash([CAST(t1.c3 AS Decimal128(10, 2))@4], 2), input_partitions=2 06)----------ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, CAST(c3@2 AS Decimal128(10, 2)) as CAST(t1.c3 AS Decimal128(10, 2))] 07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 08)--------------MemoryExec: partitions=1, partition_sizes=[1] -09)----SortExec: expr=[c3@2 ASC] +09)----SortExec: expr=[c3@2 ASC], preserve_partitioning=[true] 10)------CoalesceBatchesExec: target_batch_size=2 11)--------RepartitionExec: partitioning=Hash([c3@2], 2), input_partitions=2 12)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 13)------------MemoryExec: partitions=1, partition_sizes=[1] # sort_merge_join_on_decimal right join on data type (Decimal) -query DDR?DDR? rowsort +query DDRTDDRT rowsort select * from hashjoin_datatype_table_t1 t1 right join hashjoin_datatype_table_t2 t2 on t1.c3 = t2.c3 ---- 1970-01-04 NULL -123.12 jkl 1970-01-02 1970-01-02T00:00:00 -123.12 abc @@ -2814,7 +2864,7 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 WHERE t1_id I ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST] +02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 03)----CoalesceBatchesExec: target_batch_size=2 04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] 05)--------CoalesceBatchesExec: target_batch_size=2 @@ -2855,7 +2905,7 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 LEFT SEMI JOI ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST] +02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 03)----CoalesceBatchesExec: target_batch_size=2 04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] 05)--------CoalesceBatchesExec: target_batch_size=2 @@ -2917,7 +2967,7 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 WHERE t1_id I ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST] +02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 03)----CoalesceBatchesExec: target_batch_size=2 04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] 05)--------MemoryExec: partitions=1, partition_sizes=[1] @@ -2953,7 +3003,7 @@ explain SELECT t1_id, t1_name FROM left_semi_anti_join_table_t1 t1 LEFT SEMI JOI ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST] +02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 03)----CoalesceBatchesExec: target_batch_size=2 04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)] 05)--------MemoryExec: partitions=1, partition_sizes=[1] @@ -3011,7 +3061,7 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t1 t1 WHER ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST] +02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 03)----CoalesceBatchesExec: target_batch_size=2 04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0 05)--------CoalesceBatchesExec: target_batch_size=2 @@ -3033,7 +3083,7 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t2 t2 RIGH ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST] +02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 03)----CoalesceBatchesExec: target_batch_size=2 04)------HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1 05)--------CoalesceBatchesExec: target_batch_size=2 @@ -3093,7 +3143,7 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t1 t1 WHER ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST] +02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 03)----CoalesceBatchesExec: target_batch_size=2 04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@1 != t1_name@0 05)--------MemoryExec: partitions=1, partition_sizes=[1] @@ -3110,7 +3160,7 @@ explain SELECT t1_id, t1_name, t1_int FROM right_semi_anti_join_table_t2 t2 RIGH ---- physical_plan 01)SortPreservingMergeExec: [t1_id@0 ASC NULLS LAST] -02)--SortExec: expr=[t1_id@0 ASC NULLS LAST] +02)--SortExec: expr=[t1_id@0 ASC NULLS LAST], preserve_partitioning=[true] 03)----CoalesceBatchesExec: target_batch_size=2 04)------HashJoinExec: mode=CollectLeft, join_type=RightSemi, on=[(t2_id@0, t1_id@0)], filter=t2_name@0 != t1_name@1 05)--------MemoryExec: partitions=1, partition_sizes=[1] @@ -3162,9 +3212,12 @@ CREATE EXTERNAL TABLE annotated_data ( d INTEGER ) STORED AS CSV -WITH HEADER ROW WITH ORDER (a ASC NULLS FIRST, b ASC, c ASC) -LOCATION '../core/tests/data/window_2.csv'; +LOCATION '../core/tests/data/window_2.csv' +OPTIONS ('format.has_header' 'true'); + +statement ok +set datafusion.optimizer.prefer_existing_sort = true; # sort merge join should propagate ordering equivalence of the left side # for inner join. Hence final requirement rn1 ASC is already satisfied at @@ -3181,26 +3234,24 @@ logical_plan 01)Sort: l_table.rn1 ASC NULLS LAST 02)--Inner Join: l_table.a = r_table.a 03)----SubqueryAlias: l_table -04)------Projection: annotated_data.a0, annotated_data.a, annotated_data.b, annotated_data.c, annotated_data.d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS rn1 -05)--------WindowAggr: windowExpr=[[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +04)------Projection: annotated_data.a0, annotated_data.a, annotated_data.b, annotated_data.c, annotated_data.d, row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS rn1 +05)--------WindowAggr: windowExpr=[[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] 06)----------TableScan: annotated_data projection=[a0, a, b, c, d] 07)----SubqueryAlias: r_table 08)------TableScan: annotated_data projection=[a0, a, b, c, d] physical_plan 01)SortPreservingMergeExec: [rn1@5 ASC NULLS LAST] 02)--SortMergeJoin: join_type=Inner, on=[(a@1, a@1)] -03)----SortExec: expr=[rn1@5 ASC NULLS LAST] -04)------CoalesceBatchesExec: target_batch_size=2 -05)--------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2 -06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -07)------------ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] -08)--------------BoundedWindowAggExec: wdw=[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] -09)----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true -10)----SortExec: expr=[a@1 ASC] -11)------CoalesceBatchesExec: target_batch_size=2 -12)--------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2 -13)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -14)------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +03)----CoalesceBatchesExec: target_batch_size=2 +04)------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2, preserve_order=true, sort_exprs=a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST, rn1@5 ASC NULLS LAST +05)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +06)----------ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] +07)------------BoundedWindowAggExec: wdw=[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +08)--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +09)----CoalesceBatchesExec: target_batch_size=2 +10)------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2, preserve_order=true, sort_exprs=a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST +11)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +12)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true # sort merge join should propagate ordering equivalence of the right side # for right join. Hence final requirement rn1 ASC is already satisfied at @@ -3219,24 +3270,25 @@ logical_plan 03)----SubqueryAlias: l_table 04)------TableScan: annotated_data projection=[a0, a, b, c, d] 05)----SubqueryAlias: r_table -06)------Projection: annotated_data.a0, annotated_data.a, annotated_data.b, annotated_data.c, annotated_data.d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS rn1 -07)--------WindowAggr: windowExpr=[[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +06)------Projection: annotated_data.a0, annotated_data.a, annotated_data.b, annotated_data.c, annotated_data.d, row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS rn1 +07)--------WindowAggr: windowExpr=[[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] 08)----------TableScan: annotated_data projection=[a0, a, b, c, d] physical_plan 01)SortPreservingMergeExec: [rn1@10 ASC NULLS LAST] 02)--SortMergeJoin: join_type=Right, on=[(a@1, a@1)] -03)----SortExec: expr=[a@1 ASC] -04)------CoalesceBatchesExec: target_batch_size=2 -05)--------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2 -06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -07)------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true -08)----SortExec: expr=[rn1@5 ASC NULLS LAST] -09)------CoalesceBatchesExec: target_batch_size=2 -10)--------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2 -11)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -12)------------ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] -13)--------------BoundedWindowAggExec: wdw=[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] -14)----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +03)----CoalesceBatchesExec: target_batch_size=2 +04)------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2, preserve_order=true, sort_exprs=a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST +05)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +07)----CoalesceBatchesExec: target_batch_size=2 +08)------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2, preserve_order=true, sort_exprs=a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST, rn1@5 ASC NULLS LAST +09)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +10)----------ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] +11)------------BoundedWindowAggExec: wdw=[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +12)--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true + +statement ok +set datafusion.optimizer.prefer_existing_sort = false; # SortMergeJoin should add ordering equivalences of # right table as lexicographical append to the global ordering @@ -3255,30 +3307,30 @@ logical_plan 01)Sort: l_table.a ASC NULLS FIRST, l_table.b ASC NULLS LAST, l_table.c ASC NULLS LAST, r_table.rn1 ASC NULLS LAST 02)--Inner Join: l_table.a = r_table.a 03)----SubqueryAlias: l_table -04)------Projection: annotated_data.a0, annotated_data.a, annotated_data.b, annotated_data.c, annotated_data.d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS rn1 -05)--------WindowAggr: windowExpr=[[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +04)------Projection: annotated_data.a0, annotated_data.a, annotated_data.b, annotated_data.c, annotated_data.d, row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS rn1 +05)--------WindowAggr: windowExpr=[[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] 06)----------TableScan: annotated_data projection=[a0, a, b, c, d] 07)----SubqueryAlias: r_table -08)------Projection: annotated_data.a0, annotated_data.a, annotated_data.b, annotated_data.c, annotated_data.d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS rn1 -09)--------WindowAggr: windowExpr=[[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +08)------Projection: annotated_data.a0, annotated_data.a, annotated_data.b, annotated_data.c, annotated_data.d, row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS rn1 +09)--------WindowAggr: windowExpr=[[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] 10)----------TableScan: annotated_data projection=[a0, a, b, c, d] physical_plan -01)SortPreservingMergeExec: [a@1 ASC,b@2 ASC NULLS LAST,c@3 ASC NULLS LAST,rn1@11 ASC NULLS LAST] -02)--SortExec: expr=[a@1 ASC,b@2 ASC NULLS LAST,c@3 ASC NULLS LAST,rn1@11 ASC NULLS LAST] +01)SortPreservingMergeExec: [a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST, rn1@11 ASC NULLS LAST] +02)--SortExec: expr=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST, rn1@11 ASC NULLS LAST], preserve_partitioning=[true] 03)----SortMergeJoin: join_type=Inner, on=[(a@1, a@1)] -04)------SortExec: expr=[a@1 ASC] +04)------SortExec: expr=[a@1 ASC], preserve_partitioning=[true] 05)--------CoalesceBatchesExec: target_batch_size=2 06)----------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2 07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -08)--------------ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] -09)----------------BoundedWindowAggExec: wdw=[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +08)--------------ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] +09)----------------BoundedWindowAggExec: wdw=[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] 10)------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true -11)------SortExec: expr=[a@1 ASC] +11)------SortExec: expr=[a@1 ASC], preserve_partitioning=[true] 12)--------CoalesceBatchesExec: target_batch_size=2 13)----------RepartitionExec: partitioning=Hash([a@1], 2), input_partitions=2 14)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -15)--------------ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] -16)----------------BoundedWindowAggExec: wdw=[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +15)--------------ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] +16)----------------BoundedWindowAggExec: wdw=[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] 17)------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true statement ok @@ -3305,15 +3357,15 @@ logical_plan 03)----SubqueryAlias: l_table 04)------TableScan: annotated_data projection=[a0, a, b, c, d] 05)----SubqueryAlias: r_table -06)------Projection: annotated_data.a0, annotated_data.a, annotated_data.b, annotated_data.c, annotated_data.d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS rn1 -07)--------WindowAggr: windowExpr=[[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +06)------Projection: annotated_data.a0, annotated_data.a, annotated_data.b, annotated_data.c, annotated_data.d, row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS rn1 +07)--------WindowAggr: windowExpr=[[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] 08)----------TableScan: annotated_data projection=[a0, a, b, c, d] physical_plan 01)CoalesceBatchesExec: target_batch_size=2 02)--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@1, a@1)] 03)----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true -04)----ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] -05)------BoundedWindowAggExec: wdw=[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +04)----ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] +05)------BoundedWindowAggExec: wdw=[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] 06)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true # hash join should propagate ordering equivalence of the right side for RIGHT ANTI join. @@ -3332,15 +3384,15 @@ logical_plan 03)----SubqueryAlias: l_table 04)------TableScan: annotated_data projection=[a] 05)----SubqueryAlias: r_table -06)------Projection: annotated_data.a0, annotated_data.a, annotated_data.b, annotated_data.c, annotated_data.d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS rn1 -07)--------WindowAggr: windowExpr=[[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +06)------Projection: annotated_data.a0, annotated_data.a, annotated_data.b, annotated_data.c, annotated_data.d, row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS rn1 +07)--------WindowAggr: windowExpr=[[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] 08)----------TableScan: annotated_data projection=[a0, a, b, c, d] physical_plan 01)CoalesceBatchesExec: target_batch_size=2 02)--HashJoinExec: mode=CollectLeft, join_type=RightAnti, on=[(a@0, a@1)] 03)----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a], output_ordering=[a@0 ASC], has_header=true -04)----ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] -05)------BoundedWindowAggExec: wdw=[ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "ROW_NUMBER() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +04)----ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@5 as rn1] +05)------BoundedWindowAggExec: wdw=[row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "row_number() ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] 06)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true query TT @@ -3353,16 +3405,16 @@ ORDER BY l.a ASC NULLS FIRST; ---- logical_plan 01)Sort: l.a ASC NULLS FIRST -02)--Projection: l.a, LAST_VALUE(r.b) ORDER BY [r.a ASC NULLS FIRST] AS last_col1 -03)----Aggregate: groupBy=[[l.a, l.b, l.c]], aggr=[[LAST_VALUE(r.b) ORDER BY [r.a ASC NULLS FIRST]]] +02)--Projection: l.a, last_value(r.b) ORDER BY [r.a ASC NULLS FIRST] AS last_col1 +03)----Aggregate: groupBy=[[l.a, l.b, l.c]], aggr=[[last_value(r.b) ORDER BY [r.a ASC NULLS FIRST]]] 04)------Inner Join: l.a = r.a 05)--------SubqueryAlias: l 06)----------TableScan: annotated_data projection=[a, b, c] 07)--------SubqueryAlias: r 08)----------TableScan: annotated_data projection=[a, b] physical_plan -01)ProjectionExec: expr=[a@0 as a, LAST_VALUE(r.b) ORDER BY [r.a ASC NULLS FIRST]@3 as last_col1] -02)--AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b, c@2 as c], aggr=[LAST_VALUE(r.b) ORDER BY [r.a ASC NULLS FIRST]], ordering_mode=PartiallySorted([0]) +01)ProjectionExec: expr=[a@0 as a, last_value(r.b) ORDER BY [r.a ASC NULLS FIRST]@3 as last_col1] +02)--AggregateExec: mode=Single, gby=[a@0 as a, b@1 as b, c@2 as c], aggr=[last_value(r.b) ORDER BY [r.a ASC NULLS FIRST]], ordering_mode=PartiallySorted([0]) 03)----CoalesceBatchesExec: target_batch_size=2 04)------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(a@0, a@0)] 05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c], output_ordering=[a@0 ASC, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true @@ -3379,10 +3431,10 @@ CREATE EXTERNAL TABLE multiple_ordered_table ( d INTEGER ) STORED AS CSV -WITH HEADER ROW WITH ORDER (a ASC, b ASC) WITH ORDER (c ASC) -LOCATION '../core/tests/data/window_2.csv'; +LOCATION '../core/tests/data/window_2.csv' +OPTIONS ('format.has_header' 'true'); query TT EXPLAIN SELECT LAST_VALUE(l.d ORDER BY l.a) AS amount_usd @@ -3398,24 +3450,24 @@ ORDER BY row_n logical_plan 01)Projection: amount_usd 02)--Sort: row_n ASC NULLS LAST -03)----Projection: LAST_VALUE(l.d) ORDER BY [l.a ASC NULLS LAST] AS amount_usd, row_n -04)------Aggregate: groupBy=[[row_n]], aggr=[[LAST_VALUE(l.d) ORDER BY [l.a ASC NULLS LAST]]] +03)----Projection: last_value(l.d) ORDER BY [l.a ASC NULLS LAST] AS amount_usd, row_n +04)------Aggregate: groupBy=[[row_n]], aggr=[[last_value(l.d) ORDER BY [l.a ASC NULLS LAST]]] 05)--------Projection: l.a, l.d, row_n 06)----------Inner Join: l.d = r.d Filter: CAST(l.a AS Int64) >= CAST(r.a AS Int64) - Int64(10) 07)------------SubqueryAlias: l 08)--------------TableScan: multiple_ordered_table projection=[a, d] -09)------------Projection: r.a, r.d, ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS row_n -10)--------------WindowAggr: windowExpr=[[ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +09)------------Projection: r.a, r.d, row_number() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS row_n +10)--------------WindowAggr: windowExpr=[[row_number() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] 11)----------------SubqueryAlias: r 12)------------------TableScan: multiple_ordered_table projection=[a, d] physical_plan -01)ProjectionExec: expr=[LAST_VALUE(l.d) ORDER BY [l.a ASC NULLS LAST]@1 as amount_usd] -02)--AggregateExec: mode=Single, gby=[row_n@2 as row_n], aggr=[LAST_VALUE(l.d) ORDER BY [l.a ASC NULLS LAST]], ordering_mode=Sorted +01)ProjectionExec: expr=[last_value(l.d) ORDER BY [l.a ASC NULLS LAST]@1 as amount_usd] +02)--AggregateExec: mode=Single, gby=[row_n@2 as row_n], aggr=[last_value(l.d) ORDER BY [l.a ASC NULLS LAST]], ordering_mode=Sorted 03)----CoalesceBatchesExec: target_batch_size=2 04)------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(d@1, d@1)], filter=CAST(a@0 AS Int64) >= CAST(a@1 AS Int64) - 10, projection=[a@0, d@1, row_n@4] 05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true -06)--------ProjectionExec: expr=[a@0 as a, d@1 as d, ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as row_n] -07)----------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +06)--------ProjectionExec: expr=[a@0 as a, d@1 as d, row_number() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as row_n] +07)----------BoundedWindowAggExec: wdw=[row_number() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "row_number() ORDER BY [r.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] 08)------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, d], output_ordering=[a@0 ASC NULLS LAST], has_header=true # run query above in multiple partitions @@ -3436,8 +3488,8 @@ ORDER BY l.a ASC NULLS FIRST; ---- logical_plan 01)Sort: l.a ASC NULLS FIRST -02)--Projection: l.a, LAST_VALUE(r.b) ORDER BY [r.a ASC NULLS FIRST] AS last_col1 -03)----Aggregate: groupBy=[[l.a, l.b, l.c]], aggr=[[LAST_VALUE(r.b) ORDER BY [r.a ASC NULLS FIRST]]] +02)--Projection: l.a, last_value(r.b) ORDER BY [r.a ASC NULLS FIRST] AS last_col1 +03)----Aggregate: groupBy=[[l.a, l.b, l.c]], aggr=[[last_value(r.b) ORDER BY [r.a ASC NULLS FIRST]]] 04)------Inner Join: l.a = r.a 05)--------SubqueryAlias: l 06)----------TableScan: annotated_data projection=[a, b, c] @@ -3445,12 +3497,12 @@ logical_plan 08)----------TableScan: annotated_data projection=[a, b] physical_plan 01)SortPreservingMergeExec: [a@0 ASC] -02)--SortExec: expr=[a@0 ASC] -03)----ProjectionExec: expr=[a@0 as a, LAST_VALUE(r.b) ORDER BY [r.a ASC NULLS FIRST]@3 as last_col1] -04)------AggregateExec: mode=FinalPartitioned, gby=[a@0 as a, b@1 as b, c@2 as c], aggr=[LAST_VALUE(r.b) ORDER BY [r.a ASC NULLS FIRST]] +02)--SortExec: expr=[a@0 ASC], preserve_partitioning=[true] +03)----ProjectionExec: expr=[a@0 as a, last_value(r.b) ORDER BY [r.a ASC NULLS FIRST]@3 as last_col1] +04)------AggregateExec: mode=FinalPartitioned, gby=[a@0 as a, b@1 as b, c@2 as c], aggr=[last_value(r.b) ORDER BY [r.a ASC NULLS FIRST]] 05)--------CoalesceBatchesExec: target_batch_size=2 06)----------RepartitionExec: partitioning=Hash([a@0, b@1, c@2], 2), input_partitions=2 -07)------------AggregateExec: mode=Partial, gby=[a@0 as a, b@1 as b, c@2 as c], aggr=[LAST_VALUE(r.b) ORDER BY [r.a ASC NULLS FIRST]] +07)------------AggregateExec: mode=Partial, gby=[a@0 as a, b@1 as b, c@2 as c], aggr=[last_value(r.b) ORDER BY [r.a ASC NULLS FIRST]] 08)--------------CoalesceBatchesExec: target_batch_size=2 09)----------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(a@0, a@0)] 10)------------------CoalesceBatchesExec: target_batch_size=2 @@ -3479,7 +3531,7 @@ physical_plan 03)--RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 04)----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true -# Currently datafusion cannot pushdown filter conditions with scalar UDF into +# Currently datafusion can pushdown filter conditions with scalar UDF into # cross join. query TT EXPLAIN SELECT * @@ -3487,19 +3539,16 @@ FROM annotated_data as t1, annotated_data as t2 WHERE EXAMPLE(t1.a, t2.a) > 3 ---- logical_plan -01)Filter: example(CAST(t1.a AS Float64), CAST(t2.a AS Float64)) > Float64(3) -02)--CrossJoin: -03)----SubqueryAlias: t1 -04)------TableScan: annotated_data projection=[a0, a, b, c, d] -05)----SubqueryAlias: t2 -06)------TableScan: annotated_data projection=[a0, a, b, c, d] +01)Inner Join: Filter: example(CAST(t1.a AS Float64), CAST(t2.a AS Float64)) > Float64(3) +02)--SubqueryAlias: t1 +03)----TableScan: annotated_data projection=[a0, a, b, c, d] +04)--SubqueryAlias: t2 +05)----TableScan: annotated_data projection=[a0, a, b, c, d] physical_plan -01)CoalesceBatchesExec: target_batch_size=2 -02)--FilterExec: example(CAST(a@1 AS Float64), CAST(a@6 AS Float64)) > 3 -03)----CrossJoinExec -04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true -05)------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -06)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +01)NestedLoopJoinExec: join_type=Inner, filter=example(CAST(a@0 AS Float64), CAST(a@1 AS Float64)) > 3 +02)--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true +03)--RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +04)----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true #### # Config teardown @@ -3602,17 +3651,16 @@ EXPLAIN SELECT * FROM ( ) as a FULL JOIN (SELECT 1 as e, 3 AS f) AS rhs ON a.c=rhs.e; ---- logical_plan -01)Projection: a.c, a.d, rhs.e, rhs.f -02)--Full Join: a.c = rhs.e -03)----SubqueryAlias: a -04)------Union -05)--------Projection: Int64(1) AS c, Int64(2) AS d -06)----------EmptyRelation -07)--------Projection: Int64(1) AS c, Int64(3) AS d -08)----------EmptyRelation -09)----SubqueryAlias: rhs -10)------Projection: Int64(1) AS e, Int64(3) AS f -11)--------EmptyRelation +01)Full Join: a.c = rhs.e +02)--SubqueryAlias: a +03)----Union +04)------Projection: Int64(1) AS c, Int64(2) AS d +05)--------EmptyRelation +06)------Projection: Int64(1) AS c, Int64(3) AS d +07)--------EmptyRelation +08)--SubqueryAlias: rhs +09)----Projection: Int64(1) AS e, Int64(3) AS f +10)------EmptyRelation physical_plan 01)ProjectionExec: expr=[c@2 as c, d@3 as d, e@0 as e, f@1 as f] 02)--CoalesceBatchesExec: target_batch_size=2 @@ -3650,17 +3698,16 @@ EXPLAIN SELECT * FROM ( ) as a FULL JOIN (SELECT 1 as e, 3 AS f) AS rhs ON a.c=rhs.e; ---- logical_plan -01)Projection: a.c, a.d, rhs.e, rhs.f -02)--Full Join: a.c = rhs.e -03)----SubqueryAlias: a -04)------Union -05)--------Projection: Int64(1) AS c, Int64(2) AS d -06)----------EmptyRelation -07)--------Projection: Int64(1) AS c, Int64(3) AS d -08)----------EmptyRelation -09)----SubqueryAlias: rhs -10)------Projection: Int64(1) AS e, Int64(3) AS f -11)--------EmptyRelation +01)Full Join: a.c = rhs.e +02)--SubqueryAlias: a +03)----Union +04)------Projection: Int64(1) AS c, Int64(2) AS d +05)--------EmptyRelation +06)------Projection: Int64(1) AS c, Int64(3) AS d +07)--------EmptyRelation +08)--SubqueryAlias: rhs +09)----Projection: Int64(1) AS e, Int64(3) AS f +10)------EmptyRelation physical_plan 01)ProjectionExec: expr=[c@2 as c, d@3 as d, e@0 as e, f@1 as f] 02)--CoalesceBatchesExec: target_batch_size=2 @@ -3688,3 +3735,558 @@ set datafusion.explain.logical_plan_only = true; statement ok set datafusion.execution.target_partitions = 2; + +# Inner join with empty left table +query TT +EXPLAIN SELECT * FROM ( + SELECT 1 as a WHERE 1=0 +) AS a INNER JOIN (SELECT 1 as a) AS b ON a.a=b.a; +---- +logical_plan EmptyRelation + +# Inner join with empty right table +query TT +EXPLAIN SELECT * FROM ( + SELECT 1 AS a +) AS a INNER JOIN (SELECT 1 AS a WHERE 1=0) AS b ON a.a=b.a; +---- +logical_plan EmptyRelation + +# Left join with empty left table +query TT +EXPLAIN SELECT * FROM ( + SELECT 1 as a WHERE 1=0 +) AS a LEFT JOIN (SELECT 1 as a) AS b ON a.a=b.a; +---- +logical_plan EmptyRelation + +# Left join with empty left and empty right table +query TT +EXPLAIN SELECT * FROM ( + SELECT 1 as a WHERE 1=0 +) AS a LEFT JOIN (SELECT 1 as a WHERE 1=0) AS b ON a.a=b.a; +---- +logical_plan EmptyRelation + +# Right join with empty right table +query TT +EXPLAIN SELECT * FROM ( + SELECT 1 AS a +) AS a RIGHT JOIN (SELECT 1 AS a WHERE 1=0) AS b ON a.a=b.a; +---- +logical_plan EmptyRelation + +# Right join with empty right and empty left table +query TT +EXPLAIN SELECT * FROM ( + SELECT 1 as a WHERE 1=0 +) AS a RIGHT JOIN (SELECT 1 as a WHERE 1=0) AS b ON a.a=b.a; +---- +logical_plan EmptyRelation + +# Left SEMI join with empty left table +query TT +EXPLAIN SELECT * FROM ( + SELECT 1 AS a +) AS a LEFT SEMI JOIN (SELECT 1 AS a WHERE 1=0) AS b ON a.a=b.a; +---- +logical_plan EmptyRelation + +# Left SEMI join with empty right table +query TT +EXPLAIN SELECT * FROM ( + SELECT 1 AS a WHERE 1=0 +) AS a LEFT SEMI JOIN (SELECT 1 AS a) AS b ON a.a=b.a; +---- +logical_plan EmptyRelation + +# Right SEMI join with empty left table +query TT +EXPLAIN SELECT * FROM ( + SELECT 1 AS a WHERE 1=0 +) AS a RIGHT SEMI JOIN (SELECT 1 AS a) AS b ON a.a=b.a; +---- +logical_plan EmptyRelation + +# Right SEMI join with empty right table +query TT +EXPLAIN SELECT * FROM ( + SELECT 1 AS a +) AS a RIGHT SEMI JOIN (SELECT 1 AS a WHERE 1=0) AS b ON a.a=b.a; +---- +logical_plan EmptyRelation + +# Left ANTI join with empty left table +query TT +EXPLAIN SELECT * FROM ( + SELECT 1 AS a WHERE 1=0 +) AS a LEFT ANTI JOIN (SELECT 1 AS a) AS b ON a.a=b.a; +---- +logical_plan EmptyRelation + +# Right ANTI join with empty right table +query TT +EXPLAIN SELECT * FROM ( + SELECT 1 AS a +) AS a RIGHT ANTI JOIN (SELECT 1 AS a WHERE 1=0) AS b ON a.a=b.a; +---- +logical_plan EmptyRelation + +# FULL OUTER join with empty left and empty right table +query TT +EXPLAIN SELECT * FROM ( + SELECT 1 as a WHERE 1=0 +) AS a FULL JOIN (SELECT 1 AS a WHERE 1=0) AS b ON a.a=b.a; +---- +logical_plan EmptyRelation + +# Left ANTI join with empty right table +query TT +EXPLAIN SELECT * FROM ( + SELECT 1 as a +) AS a LEFT ANTI JOIN (SELECT 1 AS a WHERE 1=0) as b ON a.a=b.a; +---- +logical_plan +01)SubqueryAlias: a +02)--Projection: Int64(1) AS a +03)----EmptyRelation + +# Right ANTI join with empty left table +query TT +EXPLAIN SELECT * FROM ( + SELECT 1 as a WHERE 1=0 +) AS a RIGHT ANTI JOIN (SELECT 1 AS a) as b ON a.a=b.a; +---- +logical_plan +01)SubqueryAlias: b +02)--Projection: Int64(1) AS a +03)----EmptyRelation + + +statement ok +set datafusion.execution.target_partitions = 1; + +statement ok +set datafusion.explain.logical_plan_only = false; + +statement ok +set datafusion.execution.batch_size = 3; + +# Right Hash Joins preserve the right ordering +# No nulls on build side: +statement ok +CREATE TABLE left_table_no_nulls(a INT UNSIGNED, b INT UNSIGNED) +AS VALUES +(11, 1), +(12, 3), +(13, 5), +(14, 2), +(15, 4); + +statement ok +CREATE TABLE right_table_no_nulls(a INT UNSIGNED, b INT UNSIGNED) +AS VALUES +(21, 1), +(22, 2), +(23, 3), +(24, 4); + +query IIII +SELECT * FROM ( + SELECT * from left_table_no_nulls +) as lhs RIGHT JOIN ( + SELECT * from right_table_no_nulls + ORDER BY b + LIMIT 10 +) AS rhs ON lhs.b=rhs.b +---- +11 1 21 1 +12 3 23 3 +14 2 22 2 +15 4 24 4 + +query TT +EXPLAIN SELECT * FROM ( + SELECT * from left_table_no_nulls +) as lhs RIGHT JOIN ( + SELECT * from right_table_no_nulls + ORDER BY b + LIMIT 10 +) AS rhs ON lhs.b=rhs.b +---- +logical_plan +01)Right Join: lhs.b = rhs.b +02)--SubqueryAlias: lhs +03)----TableScan: left_table_no_nulls projection=[a, b] +04)--SubqueryAlias: rhs +05)----Sort: right_table_no_nulls.b ASC NULLS LAST, fetch=10 +06)------TableScan: right_table_no_nulls projection=[a, b] +physical_plan +01)ProjectionExec: expr=[a@2 as a, b@3 as b, a@0 as a, b@1 as b] +02)--CoalesceBatchesExec: target_batch_size=3 +03)----HashJoinExec: mode=CollectLeft, join_type=Left, on=[(b@1, b@1)] +04)------SortExec: TopK(fetch=10), expr=[b@1 ASC NULLS LAST], preserve_partitioning=[false] +05)--------MemoryExec: partitions=1, partition_sizes=[1] +06)------MemoryExec: partitions=1, partition_sizes=[1] + + + +# Missing probe index in the middle of the batch: +statement ok +CREATE TABLE left_table_missing_probe(a INT UNSIGNED, b INT UNSIGNED) +AS VALUES +(11, 1), +(12, 2), +(13, 3), +(14, 6), +(15, 8); + +statement ok +CREATE TABLE right_table_missing_probe(a INT UNSIGNED, b INT UNSIGNED) +AS VALUES +(21, 1), +(22, 4), +(23, 6), +(24, 7), +(25, 8); + +query IIII +SELECT * FROM ( + SELECT * from left_table_missing_probe +) as lhs RIGHT JOIN ( + SELECT * from right_table_missing_probe + ORDER BY b + LIMIT 10 +) AS rhs ON lhs.b=rhs.b +---- +11 1 21 1 +NULL NULL 22 4 +14 6 23 6 +NULL NULL 24 7 +15 8 25 8 + +query TT +EXPLAIN SELECT * FROM ( + SELECT * from left_table_no_nulls +) as lhs RIGHT JOIN ( + SELECT * from right_table_no_nulls + ORDER BY b +) AS rhs ON lhs.b=rhs.b +---- +logical_plan +01)Right Join: lhs.b = rhs.b +02)--SubqueryAlias: lhs +03)----TableScan: left_table_no_nulls projection=[a, b] +04)--SubqueryAlias: rhs +05)----TableScan: right_table_no_nulls projection=[a, b] +physical_plan +01)ProjectionExec: expr=[a@2 as a, b@3 as b, a@0 as a, b@1 as b] +02)--CoalesceBatchesExec: target_batch_size=3 +03)----HashJoinExec: mode=CollectLeft, join_type=Left, on=[(b@1, b@1)] +04)------MemoryExec: partitions=1, partition_sizes=[1] +05)------MemoryExec: partitions=1, partition_sizes=[1] + + +# Null build indices: +statement ok +CREATE TABLE left_table_append_null_build(a INT UNSIGNED, b INT UNSIGNED) +AS VALUES +(11, 1), +(12, 1), +(13, 5), +(14, 5), +(15, 3); + +statement ok +CREATE TABLE right_table_append_null_build(a INT UNSIGNED, b INT UNSIGNED) +AS VALUES +(21, 4), +(22, 5), +(23, 6), +(24, 7), +(25, 8); + +query IIII +SELECT * FROM ( + SELECT * from left_table_append_null_build +) as lhs RIGHT JOIN ( + SELECT * from right_table_append_null_build + ORDER BY b + LIMIT 10 +) AS rhs ON lhs.b=rhs.b +---- +NULL NULL 21 4 +13 5 22 5 +14 5 22 5 +NULL NULL 23 6 +NULL NULL 24 7 +NULL NULL 25 8 + + +query TT +EXPLAIN SELECT * FROM ( + SELECT * from left_table_no_nulls +) as lhs RIGHT JOIN ( + SELECT * from right_table_no_nulls + ORDER BY b + LIMIT 10 +) AS rhs ON lhs.b=rhs.b +---- +logical_plan +01)Right Join: lhs.b = rhs.b +02)--SubqueryAlias: lhs +03)----TableScan: left_table_no_nulls projection=[a, b] +04)--SubqueryAlias: rhs +05)----Sort: right_table_no_nulls.b ASC NULLS LAST, fetch=10 +06)------TableScan: right_table_no_nulls projection=[a, b] +physical_plan +01)ProjectionExec: expr=[a@2 as a, b@3 as b, a@0 as a, b@1 as b] +02)--CoalesceBatchesExec: target_batch_size=3 +03)----HashJoinExec: mode=CollectLeft, join_type=Left, on=[(b@1, b@1)] +04)------SortExec: TopK(fetch=10), expr=[b@1 ASC NULLS LAST], preserve_partitioning=[false] +05)--------MemoryExec: partitions=1, partition_sizes=[1] +06)------MemoryExec: partitions=1, partition_sizes=[1] + + +# Test CROSS JOIN LATERAL syntax (planning) +query TT +explain select t1_id, t1_name, i from join_t1 t1 cross join lateral (select * from unnest(generate_series(1, t1_int))) as series(i); +---- +logical_plan +01)Cross Join: +02)--SubqueryAlias: t1 +03)----TableScan: join_t1 projection=[t1_id, t1_name] +04)--SubqueryAlias: series +05)----Subquery: +06)------Projection: unnest_placeholder(generate_series(Int64(1),outer_ref(t1.t1_int)),depth=1) AS i +07)--------Unnest: lists[unnest_placeholder(generate_series(Int64(1),outer_ref(t1.t1_int)))|depth=1] structs[] +08)----------Projection: generate_series(Int64(1), CAST(outer_ref(t1.t1_int) AS Int64)) AS unnest_placeholder(generate_series(Int64(1),outer_ref(t1.t1_int))) +09)------------EmptyRelation + + +# Test CROSS JOIN LATERAL syntax (execution) +# TODO: https://github.com/apache/datafusion/issues/10048 +query error DataFusion error: This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn\(UInt32, Column \{ relation: Some\(Bare \{ table: "t1" \}\), name: "t1_int" \}\) +select t1_id, t1_name, i from join_t1 t1 cross join lateral (select * from unnest(generate_series(1, t1_int))) as series(i); + + +# Test INNER JOIN LATERAL syntax (planning) +query TT +explain select t1_id, t1_name, i from join_t1 t2 inner join lateral (select * from unnest(generate_series(1, t1_int))) as series(i) on(t1_id > i); +---- +logical_plan +01)Inner Join: Filter: CAST(t2.t1_id AS Int64) > series.i +02)--SubqueryAlias: t2 +03)----TableScan: join_t1 projection=[t1_id, t1_name] +04)--SubqueryAlias: series +05)----Subquery: +06)------Projection: unnest_placeholder(generate_series(Int64(1),outer_ref(t2.t1_int)),depth=1) AS i +07)--------Unnest: lists[unnest_placeholder(generate_series(Int64(1),outer_ref(t2.t1_int)))|depth=1] structs[] +08)----------Projection: generate_series(Int64(1), CAST(outer_ref(t2.t1_int) AS Int64)) AS unnest_placeholder(generate_series(Int64(1),outer_ref(t2.t1_int))) +09)------------EmptyRelation + + +# Test INNER JOIN LATERAL syntax (execution) +# TODO: https://github.com/apache/datafusion/issues/10048 +query error DataFusion error: This feature is not implemented: Physical plan does not support logical expression OuterReferenceColumn\(UInt32, Column \{ relation: Some\(Bare \{ table: "t2" \}\), name: "t1_int" \}\) +select t1_id, t1_name, i from join_t1 t2 inner join lateral (select * from unnest(generate_series(1, t1_int))) as series(i) on(t1_id > i); + +# Test RIGHT JOIN LATERAL syntax (unsupported) +query error DataFusion error: This feature is not implemented: LATERAL syntax is not supported for FULL OUTER and RIGHT \[OUTER \| ANTI \| SEMI\] joins +select t1_id, t1_name, i from join_t1 t1 right join lateral (select * from unnest(generate_series(1, t1_int))) as series(i); + + +# Functional dependencies across a join +statement ok +CREATE TABLE sales_global ( + ts TIMESTAMP, + sn INTEGER, + amount INTEGER, + currency VARCHAR NOT NULL, + primary key(sn) +); + +statement ok +CREATE TABLE exchange_rates ( + ts TIMESTAMP, + sn INTEGER, + currency_from VARCHAR NOT NULL, + currency_to VARCHAR NOT NULL, + rate FLOAT, + primary key(sn) +); + +query TT +EXPLAIN SELECT s.*, s.amount * LAST_VALUE(e.rate) AS amount_usd +FROM sales_global AS s +JOIN exchange_rates AS e +ON s.currency = e.currency_from AND + e.currency_to = 'USD' AND + s.ts >= e.ts +GROUP BY s.sn +ORDER BY s.sn +---- +logical_plan +01)Sort: s.sn ASC NULLS LAST +02)--Projection: s.ts, s.sn, s.amount, s.currency, CAST(s.amount AS Float32) * last_value(e.rate) AS amount_usd +03)----Aggregate: groupBy=[[s.sn, s.ts, s.amount, s.currency]], aggr=[[last_value(e.rate)]] +04)------Projection: s.ts, s.sn, s.amount, s.currency, e.rate +05)--------Inner Join: s.currency = e.currency_from Filter: s.ts >= e.ts +06)----------SubqueryAlias: s +07)------------TableScan: sales_global projection=[ts, sn, amount, currency] +08)----------SubqueryAlias: e +09)------------Projection: exchange_rates.ts, exchange_rates.currency_from, exchange_rates.rate +10)--------------Filter: exchange_rates.currency_to = Utf8("USD") +11)----------------TableScan: exchange_rates projection=[ts, currency_from, currency_to, rate] +physical_plan +01)SortExec: expr=[sn@1 ASC NULLS LAST], preserve_partitioning=[false] +02)--ProjectionExec: expr=[ts@1 as ts, sn@0 as sn, amount@2 as amount, currency@3 as currency, CAST(amount@2 AS Float32) * last_value(e.rate)@4 as amount_usd] +03)----AggregateExec: mode=Single, gby=[sn@1 as sn, ts@0 as ts, amount@2 as amount, currency@3 as currency], aggr=[last_value(e.rate)] +04)------CoalesceBatchesExec: target_batch_size=3 +05)--------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(currency@3, currency_from@1)], filter=ts@0 >= ts@1, projection=[ts@0, sn@1, amount@2, currency@3, rate@6] +06)----------MemoryExec: partitions=1, partition_sizes=[0] +07)----------CoalesceBatchesExec: target_batch_size=3 +08)------------FilterExec: currency_to@2 = USD, projection=[ts@0, currency_from@1, rate@3] +09)--------------MemoryExec: partitions=1, partition_sizes=[0] + +statement ok +DROP TABLE sales_global; + +statement ok +DROP TABLE exchange_rates; + +# HashJoinExec and NestedLoopJoinExec can propagate SortExec down through its right child. + +statement ok +CREATE TABLE left_table(a INT, b INT, c INT) + +statement ok +CREATE TABLE right_table(x INT, y INT, z INT) + +query TT +EXPLAIN SELECT * FROM left_table JOIN right_table ON left_table.a= t1.c2 LIMIT 2; +---- +2 2 2 2 true +3 3 2 2 true + +query IIIIB +SELECT * FROM t0 FULL JOIN t1 ON t0.c1 = t1.c1 AND t0.c2 >= t1.c2 LIMIT 2; +---- +2 2 2 2 true +2 2 2 2 false + +## Test !join.on.is_empty() && join.filter.is_none() +query TT +EXPLAIN SELECT * FROM t0 FULL JOIN t1 ON t0.c1 = t1.c1 LIMIT 2; +---- +logical_plan +01)Limit: skip=0, fetch=2 +02)--Full Join: t0.c1 = t1.c1 +03)----Limit: skip=0, fetch=2 +04)------TableScan: t0 projection=[c1, c2], fetch=2 +05)----Limit: skip=0, fetch=2 +06)------TableScan: t1 projection=[c1, c2, c3], fetch=2 +physical_plan +01)CoalesceBatchesExec: target_batch_size=3, fetch=2 +02)--HashJoinExec: mode=CollectLeft, join_type=Full, on=[(c1@0, c1@0)] +03)----MemoryExec: partitions=1, partition_sizes=[1] +04)----MemoryExec: partitions=1, partition_sizes=[1] + +## Test join.on.is_empty() && join.filter.is_some() +query TT +EXPLAIN SELECT * FROM t0 FULL JOIN t1 ON t0.c2 >= t1.c2 LIMIT 2; +---- +logical_plan +01)Limit: skip=0, fetch=2 +02)--Full Join: Filter: t0.c2 >= t1.c2 +03)----Limit: skip=0, fetch=2 +04)------TableScan: t0 projection=[c1, c2], fetch=2 +05)----Limit: skip=0, fetch=2 +06)------TableScan: t1 projection=[c1, c2, c3], fetch=2 +physical_plan +01)GlobalLimitExec: skip=0, fetch=2 +02)--NestedLoopJoinExec: join_type=Full, filter=c2@0 >= c2@1 +03)----MemoryExec: partitions=1, partition_sizes=[1] +04)----MemoryExec: partitions=1, partition_sizes=[1] + +## Test !join.on.is_empty() && join.filter.is_some() +query TT +EXPLAIN SELECT * FROM t0 FULL JOIN t1 ON t0.c1 = t1.c1 AND t0.c2 >= t1.c2 LIMIT 2; +---- +logical_plan +01)Limit: skip=0, fetch=2 +02)--Full Join: t0.c1 = t1.c1 Filter: t0.c2 >= t1.c2 +03)----Limit: skip=0, fetch=2 +04)------TableScan: t0 projection=[c1, c2], fetch=2 +05)----Limit: skip=0, fetch=2 +06)------TableScan: t1 projection=[c1, c2, c3], fetch=2 +physical_plan +01)CoalesceBatchesExec: target_batch_size=3, fetch=2 +02)--HashJoinExec: mode=CollectLeft, join_type=Full, on=[(c1@0, c1@0)], filter=c2@0 >= c2@1 +03)----MemoryExec: partitions=1, partition_sizes=[1] +04)----MemoryExec: partitions=1, partition_sizes=[1] + +# Test Utf8View as Join Key +# Issue: https://github.com/apache/datafusion/issues/12468 +statement ok +CREATE TABLE table1(v1 STRING) AS VALUES ('foo'), (NULL); + +statement ok +CREATE TABLE table1_stringview AS SELECT arrow_cast(v1, 'Utf8View') AS v1 FROM table1; + +query T +select * from table1 as t1 natural join table1_stringview as t2; +---- +foo diff --git a/datafusion/sqllogictest/test_files/json.slt b/datafusion/sqllogictest/test_files/json.slt index 5d3c23d5130b..0903c2427649 100644 --- a/datafusion/sqllogictest/test_files/json.slt +++ b/datafusion/sqllogictest/test_files/json.slt @@ -45,16 +45,21 @@ SELECT a, b FROM json_test 5 -3.5 7 -3.5 +# Ensure that local files can not be read by default (a potential security issue) +# (url table is only supported when DynamicFileCatalog is enabled) +statement error DataFusion error: Error during planning: table 'datafusion.public.../core/tests/data/2.json' not found +SELECT a, b FROM '../core/tests/data/2.json' + query TT EXPLAIN SELECT count(*) from json_test ---- logical_plan -01)Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +01)Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] 02)--TableScan: json_test projection=[] physical_plan -01)AggregateExec: mode=Final, gby=[], aggr=[COUNT(*)] +01)AggregateExec: mode=Final, gby=[], aggr=[count(*)] 02)--CoalescePartitionsExec -03)----AggregateExec: mode=Partial, gby=[], aggr=[COUNT(*)] +03)----AggregateExec: mode=Partial, gby=[], aggr=[count(*)] 04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 05)--------JsonExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/2.json]]} diff --git a/datafusion/sqllogictest/test_files/limit.slt b/datafusion/sqllogictest/test_files/limit.slt index b4138f38ea2b..5b98392f1aa0 100644 --- a/datafusion/sqllogictest/test_files/limit.slt +++ b/datafusion/sqllogictest/test_files/limit.slt @@ -36,8 +36,8 @@ CREATE EXTERNAL TABLE aggregate_test_100 ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); # async fn csv_query_limit query T @@ -307,11 +307,11 @@ query TT EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 11); ---- logical_plan -01)Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +01)Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] 02)--Limit: skip=11, fetch=3 03)----TableScan: t1 projection=[], fetch=14 physical_plan -01)ProjectionExec: expr=[0 as COUNT(*)] +01)ProjectionExec: expr=[0 as count(*)] 02)--PlaceholderRowExec query I @@ -325,11 +325,11 @@ query TT EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8); ---- logical_plan -01)Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +01)Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] 02)--Limit: skip=8, fetch=3 03)----TableScan: t1 projection=[], fetch=11 physical_plan -01)ProjectionExec: expr=[2 as COUNT(*)] +01)ProjectionExec: expr=[2 as count(*)] 02)--PlaceholderRowExec query I @@ -343,11 +343,11 @@ query TT EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 OFFSET 8); ---- logical_plan -01)Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +01)Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] 02)--Limit: skip=8, fetch=None 03)----TableScan: t1 projection=[] physical_plan -01)ProjectionExec: expr=[2 as COUNT(*)] +01)ProjectionExec: expr=[2 as count(*)] 02)--PlaceholderRowExec query I @@ -360,19 +360,19 @@ query TT EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 WHERE a > 3 LIMIT 3 OFFSET 6); ---- logical_plan -01)Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +01)Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] 02)--Projection: 03)----Limit: skip=6, fetch=3 04)------Filter: t1.a > Int32(3) 05)--------TableScan: t1 projection=[a] physical_plan -01)AggregateExec: mode=Final, gby=[], aggr=[COUNT(*)] +01)AggregateExec: mode=Final, gby=[], aggr=[count(*)] 02)--CoalescePartitionsExec -03)----AggregateExec: mode=Partial, gby=[], aggr=[COUNT(*)] +03)----AggregateExec: mode=Partial, gby=[], aggr=[count(*)] 04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 05)--------ProjectionExec: expr=[] 06)----------GlobalLimitExec: skip=6, fetch=3 -07)------------CoalesceBatchesExec: target_batch_size=8192 +07)------------CoalesceBatchesExec: target_batch_size=8192, fetch=9 08)--------------FilterExec: a@0 > 3 09)----------------MemoryExec: partitions=1, partition_sizes=[1] @@ -390,8 +390,8 @@ SELECT ROW_NUMBER() OVER (PARTITION BY t1.column1) FROM t t1, t t2, t t3; statement ok set datafusion.explain.show_sizes = false; -# verify that there are multiple partitions in the input (i.e. MemoryExec says -# there are 4 partitions) so that this tests multi-partition limit. +# verify that there are multiple partitions in the input so that this tests +# multi-partition limit. query TT EXPLAIN SELECT DISTINCT i FROM t1000; ---- @@ -402,8 +402,9 @@ physical_plan 01)AggregateExec: mode=FinalPartitioned, gby=[i@0 as i], aggr=[] 02)--CoalesceBatchesExec: target_batch_size=8192 03)----RepartitionExec: partitioning=Hash([i@0], 4), input_partitions=4 -04)------AggregateExec: mode=Partial, gby=[i@0 as i], aggr=[] -05)--------MemoryExec: partitions=4 +04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +05)--------AggregateExec: mode=Partial, gby=[i@0 as i], aggr=[] +06)----------MemoryExec: partitions=1 statement ok set datafusion.explain.show_sizes = true; @@ -512,3 +513,201 @@ SELECT COUNT(*) FROM (SELECT i FROM t15 LIMIT 9); statement ok drop table aggregate_test_100; + + +## Test limit pushdown in StreamingTableExec + +## Create sorted table with 5 rows +query I +COPY (select * from (values + (1, 'a'), (2, 'b'), (3, 'c'), (4, 'd'), (5, 'e') +)) TO 'test_files/scratch/limit/data.csv' STORED AS CSV OPTIONS ('format.has_header' 'false'); +---- +5 + +statement ok +CREATE UNBOUNDED EXTERNAL TABLE data ( + "column1" INTEGER, + "column2" VARCHAR, +) STORED AS CSV +WITH ORDER ("column1", "column2") +LOCATION 'test_files/scratch/limit/data.csv' OPTIONS ('format.has_header' 'false'); + +query IT +SELECT * from data LIMIT 3; +---- +1 a +2 b +3 c + +# query +query TT +explain SELECT * FROM data LIMIT 3; +---- +logical_plan +01)Limit: skip=0, fetch=3 +02)--TableScan: data projection=[column1, column2], fetch=3 +physical_plan StreamingTableExec: partition_sizes=1, projection=[column1, column2], infinite_source=true, fetch=3, output_ordering=[column1@0 ASC NULLS LAST, column2@1 ASC NULLS LAST] + + +# Do not remove limit with Sort when skip is used +query TT +explain SELECT * FROM data ORDER BY column1 LIMIT 3,3; +---- +logical_plan +01)Limit: skip=3, fetch=3 +02)--Sort: data.column1 ASC NULLS LAST, fetch=6 +03)----TableScan: data projection=[column1, column2] +physical_plan +01)GlobalLimitExec: skip=3, fetch=3 +02)--StreamingTableExec: partition_sizes=1, projection=[column1, column2], infinite_source=true, fetch=6, output_ordering=[column1@0 ASC NULLS LAST, column2@1 ASC NULLS LAST] + + +statement ok +drop table data; + + +#################### +# Test issue: limit pushdown with offsets +# Ensure the offset is not lost: https://github.com/apache/datafusion/issues/12423 +#################### + +statement ok +CREATE EXTERNAL TABLE ordered_table ( + a0 INT, + a INT, + b INT, + c INT UNSIGNED, + d INT +) +STORED AS CSV +WITH ORDER (c ASC) +LOCATION '../core/tests/data/window_2.csv' +OPTIONS ('format.has_header' 'true'); + +# all results +query II +SELECT b, sum(a) FROM ordered_table GROUP BY b order by b desc; +---- +3 25 +2 25 +1 0 +0 0 + +# limit only +query II +SELECT b, sum(a) FROM ordered_table GROUP BY b order by b desc LIMIT 3; +---- +3 25 +2 25 +1 0 + +# offset only +query II +SELECT b, sum(a) FROM ordered_table GROUP BY b order by b desc OFFSET 1; +---- +2 25 +1 0 +0 0 + +# offset + limit +query II +SELECT b, sum(a) FROM ordered_table GROUP BY b order by b desc OFFSET 1 LIMIT 2; +---- +2 25 +1 0 + +# Applying offset & limit when multiple streams from groupby +# the plan must still have a global limit to apply the offset +query TT +EXPLAIN SELECT b, sum(a) FROM ordered_table GROUP BY b order by b desc OFFSET 1 LIMIT 2; +---- +logical_plan +01)Limit: skip=1, fetch=2 +02)--Sort: ordered_table.b DESC NULLS FIRST, fetch=3 +03)----Aggregate: groupBy=[[ordered_table.b]], aggr=[[sum(CAST(ordered_table.a AS Int64))]] +04)------TableScan: ordered_table projection=[a, b] +physical_plan +01)GlobalLimitExec: skip=1, fetch=2 +02)--SortPreservingMergeExec: [b@0 DESC], fetch=3 +03)----SortExec: TopK(fetch=3), expr=[b@0 DESC], preserve_partitioning=[true] +04)------AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[sum(ordered_table.a)] +05)--------CoalesceBatchesExec: target_batch_size=8192 +06)----------RepartitionExec: partitioning=Hash([b@0], 4), input_partitions=4 +07)------------AggregateExec: mode=Partial, gby=[b@1 as b], aggr=[sum(ordered_table.a)] +08)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +09)----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], has_header=true + +# Applying offset & limit when multiple streams from union +# the plan must still have a global limit to apply the offset +query TT +explain select * FROM ( + select c FROM ordered_table + UNION ALL + select d FROM ordered_table +) order by 1 desc LIMIT 10 OFFSET 4; +---- +logical_plan +01)Limit: skip=4, fetch=10 +02)--Sort: ordered_table.c DESC NULLS FIRST, fetch=14 +03)----Union +04)------Projection: CAST(ordered_table.c AS Int64) AS c +05)--------TableScan: ordered_table projection=[c] +06)------Projection: CAST(ordered_table.d AS Int64) AS c +07)--------TableScan: ordered_table projection=[d] +physical_plan +01)GlobalLimitExec: skip=4, fetch=10 +02)--SortPreservingMergeExec: [c@0 DESC], fetch=14 +03)----UnionExec +04)------SortExec: TopK(fetch=14), expr=[c@0 DESC], preserve_partitioning=[true] +05)--------ProjectionExec: expr=[CAST(c@0 AS Int64) as c] +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c], output_ordering=[c@0 ASC NULLS LAST], has_header=true +08)------SortExec: TopK(fetch=14), expr=[c@0 DESC], preserve_partitioning=[true] +09)--------ProjectionExec: expr=[CAST(d@0 AS Int64) as c] +10)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +11)------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[d], has_header=true + +# Applying LIMIT & OFFSET to subquery. +query III +select t1.b, c, c2 FROM ( + select b, c FROM ordered_table ORDER BY b desc, c desc OFFSET 1 LIMIT 4 +) as t1 INNER JOIN ( + select b, c as c2 FROM ordered_table ORDER BY b desc, d desc OFFSET 1 LIMIT 4 +) as t2 +ON t1.b = t2.b +ORDER BY t1.b desc, c desc, c2 desc; +---- +3 98 96 +3 98 89 +3 98 82 +3 98 79 +3 97 96 +3 97 89 +3 97 82 +3 97 79 +3 96 96 +3 96 89 +3 96 82 +3 96 79 +3 95 96 +3 95 89 +3 95 82 +3 95 79 + +# Apply OFFSET & LIMIT to both parent and child (subquery). +query III +select t1.b, c, c2 FROM ( + select b, c FROM ordered_table ORDER BY b desc, c desc OFFSET 1 LIMIT 4 +) as t1 INNER JOIN ( + select b, c as c2 FROM ordered_table ORDER BY b desc, d desc OFFSET 1 LIMIT 4 +) as t2 +ON t1.b = t2.b +ORDER BY t1.b desc, c desc, c2 desc +OFFSET 3 LIMIT 2; +---- +3 99 82 +3 99 79 + +statement ok +drop table ordered_table; diff --git a/datafusion/sqllogictest/test_files/map.slt b/datafusion/sqllogictest/test_files/map.slt index 415fabf224d7..d24b66aa5c30 100644 --- a/datafusion/sqllogictest/test_files/map.slt +++ b/datafusion/sqllogictest/test_files/map.slt @@ -15,11 +15,54 @@ # specific language governing permissions and limitations # under the License. +statement ok +CREATE TABLE map_array_table_1 +AS VALUES + (MAP {1: [1, NULL, 3], 2: [4, NULL, 6], 3: [7, 8, 9]}, 1, 1.0, '1'), + (MAP {4: [1, NULL, 3], 5: [4, NULL, 6], 6: [7, 8, 9]}, 5, 5.0, '5'), + (MAP {7: [1, NULL, 3], 8: [9, NULL, 6], 9: [7, 8, 9]}, 4, 4.0, '4') +; + +statement ok +CREATE TABLE map_array_table_2 +AS VALUES + (MAP {'1': [1, NULL, 3], '2': [4, NULL, 6], '3': [7, 8, 9]}, 1, 1.0, '1'), + (MAP {'4': [1, NULL, 3], '5': [4, NULL, 6], '6': [7, 8, 9]}, 5, 5.0, '5'), + (MAP {'7': [1, NULL, 3], '8': [9, NULL, 6], '9': [7, 8, 9]}, 4, 4.0, '4') +; + statement ok CREATE EXTERNAL TABLE data STORED AS PARQUET LOCATION '../core/tests/data/parquet_map.parquet'; +# Show shape of data: 3 columns, 209 rows +query TTT +describe data; +---- +ints Map(Field { name: "entries", data_type: Struct([Field { name: "key", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "value", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false) NO +strings Map(Field { name: "entries", data_type: Struct([Field { name: "key", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "value", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false) NO +timestamp Utf8 NO + +query ??T +SELECT * FROM data ORDER by ints['bytes'] DESC LIMIT 10; +---- +{bytes: 49960} {host: 21.169.210.169, method: GET, protocol: HTTP/1.1, referer: https://up.de/booper/bopper/mooper/mopper, request: /user/booperbot124, status: 500, user-identifier: shaneIxD} 06/Oct/2023:17:53:58 +{bytes: 49689} {host: 244.231.56.81, method: PATCH, protocol: HTTP/2.0, referer: https://names.de/this/endpoint/prints/money, request: /controller/setup, status: 500, user-identifier: ahmadajmi} 06/Oct/2023:17:53:54 +{bytes: 48768} {host: 127.152.34.105, method: POST, protocol: HTTP/1.1, referer: https://for.com/secret-info/open-sesame, request: /secret-info/open-sesame, status: 200, user-identifier: Karimmove} 06/Oct/2023:17:53:59 +{bytes: 48574} {host: 121.67.176.60, method: POST, protocol: HTTP/2.0, referer: https://names.com/this/endpoint/prints/money, request: /apps/deploy, status: 401, user-identifier: benefritz} 06/Oct/2023:17:54:02 +{bytes: 48274} {host: 39.37.198.203, method: DELETE, protocol: HTTP/1.0, referer: https://some.de/booper/bopper/mooper/mopper, request: /secret-info/open-sesame, status: 550, user-identifier: ahmadajmi} 06/Oct/2023:17:54:00 +{bytes: 47775} {host: 50.89.77.82, method: OPTION, protocol: HTTP/1.0, referer: https://random.com/observability/metrics/production, request: /controller/setup, status: 200, user-identifier: meln1ks} 06/Oct/2023:17:53:54 +{bytes: 47557} {host: 108.242.133.203, method: OPTION, protocol: HTTP/2.0, referer: https://we.org/observability/metrics/production, request: /apps/deploy, status: 500, user-identifier: meln1ks} 06/Oct/2023:17:53:48 +{bytes: 47552} {host: 206.248.141.240, method: HEAD, protocol: HTTP/1.1, referer: https://up.us/user/booperbot124, request: /wp-admin, status: 400, user-identifier: jesseddy} 06/Oct/2023:17:53:50 +{bytes: 47342} {host: 110.222.38.8, method: HEAD, protocol: HTTP/2.0, referer: https://we.com/controller/setup, request: /do-not-access/needs-work, status: 301, user-identifier: ahmadajmi} 06/Oct/2023:17:53:59 +{bytes: 47238} {host: 241.134.69.76, method: POST, protocol: HTTP/2.0, referer: https://up.de/do-not-access/needs-work, request: /controller/setup, status: 503, user-identifier: ahmadajmi} 06/Oct/2023:17:53:58 + +query I +SELECT COUNT(*) FROM data; +---- +209 + query I SELECT SUM(ints['bytes']) FROM data; ---- @@ -44,10 +87,20 @@ DELETE 24 query T SELECT strings['not_found'] FROM data LIMIT 1; ---- +NULL + +# Select non existent key, expect NULL for each row +query I +SELECT COUNT(CASE WHEN strings['not_found'] IS NULL THEN 1 ELSE 0 END) FROM data; +---- +209 statement ok drop table data; +query I? +select * from table_with_map where int_field > 0; +---- # Testing explain on a table with a map filter, registered in test_context.rs. query TT @@ -63,3 +116,602 @@ physical_plan statement ok drop table table_with_map; + +query ? +SELECT MAKE_MAP('POST', 41, 'HEAD', 33, 'PATCH', 30, 'OPTION', 29, 'GET', 27, 'PUT', 25, 'DELETE', 24) AS method_count; +---- +{POST: 41, HEAD: 33, PATCH: 30, OPTION: 29, GET: 27, PUT: 25, DELETE: 24} + +query I +SELECT MAKE_MAP('POST', 41, 'HEAD', 33)['POST']; +---- +41 + +query ? +SELECT MAKE_MAP('POST', 41, 'HEAD', 33, 'PATCH', null); +---- +{POST: 41, HEAD: 33, PATCH: } + +query ? +SELECT MAKE_MAP('POST', null, 'HEAD', 33, 'PATCH', null); +---- +{POST: , HEAD: 33, PATCH: } + +query ? +SELECT MAKE_MAP(1, null, 2, 33, 3, null); +---- +{1: , 2: 33, 3: } + +query ? +SELECT MAKE_MAP([1,2], ['a', 'b'], [3,4], ['b']); +---- +{[1, 2]: [a, b], [3, 4]: [b]} + +query ? +SELECT MAKE_MAP('POST', 41, 'HEAD', 53, 'PATCH', 30); +---- +{POST: 41, HEAD: 53, PATCH: 30} + +query error DataFusion error: Arrow error: Cast error: Cannot cast string 'ab' to value of Int64 type +SELECT MAKE_MAP('POST', 41, 'HEAD', 'ab', 'PATCH', 30); + +# Map keys can not be NULL +query error +SELECT MAKE_MAP('POST', 41, 'HEAD', 33, null, 30); + +query ? +SELECT MAKE_MAP() +---- +{} + +query error +SELECT MAKE_MAP('POST', 41, 'HEAD'); + +query ? +SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33, 30]); +---- +{POST: 41, HEAD: 33, PATCH: 30} + +query ? +SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33, null]); +---- +{POST: 41, HEAD: 33, PATCH: } + +query ? +SELECT MAP([[1,2], [3,4]], ['a', 'b']); +---- +{[1, 2]: a, [3, 4]: b} + +query error +SELECT MAP() + +query error DataFusion error: Execution error: map requires an even number of arguments, got 1 instead +SELECT MAP(['POST', 'HEAD']) + +query error DataFusion error: Execution error: Expected list, large_list or fixed_size_list, got Null +SELECT MAP(null, [41, 33, 30]); + +query error DataFusion error: Execution error: map requires key and value lists to have the same length +SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33]); + +query error DataFusion error: Execution error: map key cannot be null +SELECT MAP(['POST', 'HEAD', null], [41, 33, 30]); + +statement error DataFusion error: Execution error: map key cannot be null +CREATE TABLE duplicated_keys_table +AS VALUES + (MAP {1: [1, NULL, 3], NULL: [4, NULL, 6]}); + +# Test duplicate keys +# key is a scalar type +query error DataFusion error: Execution error: map key must be unique, duplicate key found: POST +SELECT MAP(['POST', 'HEAD', 'POST'], [41, 33, null]); + +query error DataFusion error: Execution error: map key must be unique, duplicate key found: POST +SELECT MAP(make_array('POST', 'HEAD', 'POST'), make_array(41, 33, 30)); + +query error DataFusion error: Execution error: map key must be unique, duplicate key found: POST +SELECT make_map('POST', 41, 'HEAD', 33, 'POST', 30); + +statement error DataFusion error: Execution error: map key must be unique, duplicate key found: 1 +CREATE TABLE duplicated_keys_table +AS VALUES + (MAP {1: [1, NULL, 3], 1: [4, NULL, 6]}); + +statement ok +create table duplicate_keys_table as values +('a', 1, 'a', 10, ['k1', 'k1'], [1, 2]); + +query error DataFusion error: Execution error: map key must be unique, duplicate key found: a +SELECT make_map(column1, column2, column3, column4) FROM duplicate_keys_table; + +query error DataFusion error: Execution error: map key must be unique, duplicate key found: k1 +SELECT map(column5, column6) FROM duplicate_keys_table; + +# key is a nested type +query error DataFusion error: Execution error: map key must be unique, duplicate key found: \[1, 2\] +SELECT MAP([[1,2], [1,2], []], [41, 33, null]); + +query error DataFusion error: Execution error: map key must be unique, duplicate key found: \[\{1:1\}\] +SELECT MAP([Map {1:'1'}, Map {1:'1'}, Map {2:'2'}], [41, 33, null]); + + +query ? +SELECT MAP(make_array('POST', 'HEAD', 'PATCH'), make_array(41, 33, 30)); +---- +{POST: 41, HEAD: 33, PATCH: 30} + +query ? +SELECT MAP(arrow_cast(make_array('POST', 'HEAD', 'PATCH'), 'FixedSizeList(3, Utf8)'), arrow_cast(make_array(41, 33, 30), 'FixedSizeList(3, Int64)')); +---- +{POST: 41, HEAD: 33, PATCH: 30} + +query ? +SELECT MAP(arrow_cast(make_array('POST', 'HEAD', 'PATCH'), 'LargeList(Utf8)'), arrow_cast(make_array(41, 33, 30), 'LargeList(Int64)')); +---- +{POST: 41, HEAD: 33, PATCH: 30} + +statement ok +create table t as values +('a', 1, 'k1', 10, ['k1', 'k2'], [1, 2], 'POST', [[1,2,3]], ['a']), +('b', 2, 'k3', 30, ['k3'], [3], 'PUT', [[4]], ['b']), +('d', 4, 'k5', 50, ['k5'], [5], null, [[1,2]], ['c']); + +query ? +SELECT make_map(column1, column2, column3, column4) FROM t; +---- +{a: 1, k1: 10} +{b: 2, k3: 30} +{d: 4, k5: 50} + +query ? +SELECT map(column5, column6) FROM t; +---- +{k1: 1, k2: 2} +{k3: 3} +{k5: 5} + +query ? +SELECT map(column8, column9) FROM t; +---- +{[1, 2, 3]: a} +{[4]: b} +{[1, 2]: c} + +query error +SELECT map(column6, column7) FROM t; + +query ? +select Map {column6: column7} from t; +---- +{[1, 2]: POST} +{[3]: PUT} +{[5]: } + +query ? +select Map {column8: column7} from t; +---- +{[[1, 2, 3]]: POST} +{[[4]]: PUT} +{[[1, 2]]: } + +query error +select Map {column7: column8} from t; + +query ? +SELECT MAKE_MAP('POST', 41, 'HEAD', 33, 'PATCH', 30, 'OPTION', 29, 'GET', 27, 'PUT', 25, 'DELETE', 24) AS method_count from t; +---- +{POST: 41, HEAD: 33, PATCH: 30, OPTION: 29, GET: 27, PUT: 25, DELETE: 24} +{POST: 41, HEAD: 33, PATCH: 30, OPTION: 29, GET: 27, PUT: 25, DELETE: 24} +{POST: 41, HEAD: 33, PATCH: 30, OPTION: 29, GET: 27, PUT: 25, DELETE: 24} + +query I +SELECT MAKE_MAP('POST', 41, 'HEAD', 33)['POST'] from t; +---- +41 +41 +41 + +query ? +SELECT MAKE_MAP('POST', 41, 'HEAD', 33, 'PATCH', null) from t; +---- +{POST: 41, HEAD: 33, PATCH: } +{POST: 41, HEAD: 33, PATCH: } +{POST: 41, HEAD: 33, PATCH: } + +query ? +SELECT MAKE_MAP('POST', null, 'HEAD', 33, 'PATCH', null) from t; +---- +{POST: , HEAD: 33, PATCH: } +{POST: , HEAD: 33, PATCH: } +{POST: , HEAD: 33, PATCH: } + +query ? +SELECT MAKE_MAP(1, null, 2, 33, 3, null) from t; +---- +{1: , 2: 33, 3: } +{1: , 2: 33, 3: } +{1: , 2: 33, 3: } + +query ? +SELECT MAKE_MAP([1,2], ['a', 'b'], [3,4], ['b']) from t; +---- +{[1, 2]: [a, b], [3, 4]: [b]} +{[1, 2]: [a, b], [3, 4]: [b]} +{[1, 2]: [a, b], [3, 4]: [b]} + +query ? +SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33, 30]) from t; +---- +{POST: 41, HEAD: 33, PATCH: 30} +{POST: 41, HEAD: 33, PATCH: 30} +{POST: 41, HEAD: 33, PATCH: 30} + +query ? +SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33, null]) from t; +---- +{POST: 41, HEAD: 33, PATCH: } +{POST: 41, HEAD: 33, PATCH: } +{POST: 41, HEAD: 33, PATCH: } + +query ? +SELECT MAP([[1,2], [3,4]], ['a', 'b']) from t; +---- +{[1, 2]: a, [3, 4]: b} +{[1, 2]: a, [3, 4]: b} +{[1, 2]: a, [3, 4]: b} + +query ? +SELECT MAP(make_array('POST', 'HEAD', 'PATCH'), make_array(41, 33, 30)) from t; +---- +{POST: 41, HEAD: 33, PATCH: 30} +{POST: 41, HEAD: 33, PATCH: 30} +{POST: 41, HEAD: 33, PATCH: 30} + +query ? +SELECT MAP(arrow_cast(make_array('POST', 'HEAD', 'PATCH'), 'FixedSizeList(3, Utf8)'), arrow_cast(make_array(41, 33, 30), 'FixedSizeList(3, Int64)')) from t; +---- +{POST: 41, HEAD: 33, PATCH: 30} +{POST: 41, HEAD: 33, PATCH: 30} +{POST: 41, HEAD: 33, PATCH: 30} + +query ? +SELECT MAP(arrow_cast(make_array('POST', 'HEAD', 'PATCH'), 'LargeList(Utf8)'), arrow_cast(make_array(41, 33, 30), 'LargeList(Int64)')) from t; +---- +{POST: 41, HEAD: 33, PATCH: 30} +{POST: 41, HEAD: 33, PATCH: 30} +{POST: 41, HEAD: 33, PATCH: 30} + + +query ? +VALUES (MAP(['a'], [1])), (MAP(['b'], [2])), (MAP(['c', 'a'], [3, 1])) +---- +{a: 1} +{b: 2} +{c: 3, a: 1} + +query ? +SELECT MAP {'a':1, 'b':2, 'c':3}; +---- +{a: 1, b: 2, c: 3} + +query ? +SELECT MAP {'a':1, 'b':2, 'c':3 } FROM t; +---- +{a: 1, b: 2, c: 3} +{a: 1, b: 2, c: 3} +{a: 1, b: 2, c: 3} + +query I +SELECT MAP {'a':1, 'b':2, 'c':3}['a']; +---- +1 + +query I +SELECT MAP {'a':1, 'b':2, 'c':3 }['a'] FROM t; +---- +1 +1 +1 + +query ? +SELECT MAP {}; +---- +{} + +# values contain null +query ? +SELECT MAP {'a': 1, 'b': null}; +---- +{a: 1, b: } + +# keys contain null +query error DataFusion error: Execution error: map key cannot be null +SELECT MAP {'a': 1, null: 2} + +# array as key +query ? +SELECT MAP {[1,2,3]:1, [2,4]:2}; +---- +{[1, 2, 3]: 1, [2, 4]: 2} + +# array with different type as key +# expect to fail due to type coercion error +query error +SELECT MAP {[1,2,3]:1, ['a', 'b']:2}; + +# array as value +query ? +SELECT MAP {'a':[1,2,3], 'b':[2,4]}; +---- +{a: [1, 2, 3], b: [2, 4]} + +# array with different type as value +# expect to fail due to type coercion error +query error +SELECT MAP {'a':[1,2,3], 'b':['a', 'b']}; + +# struct as key +query ? +SELECT MAP {{'a':1, 'b':2}:1, {'a':3, 'b':4}:2}; +---- +{{a: 1, b: 2}: 1, {a: 3, b: 4}: 2} + +# struct with different fields as key +# expect to fail due to type coercion error +query error +SELECT MAP {{'a':1, 'b':2}:1, {'c':3, 'd':4}:2}; + +# struct as value +query ? +SELECT MAP {'a':{'b':1, 'c':2}, 'b':{'b':3, 'c':4}}; +---- +{a: {b: 1, c: 2}, b: {b: 3, c: 4}} + +# struct with different fields as value +# expect to fail due to type coercion error +query error +SELECT MAP {'a':{'b':1, 'c':2}, 'b':{'c':3, 'd':4}}; + +# map as key +query ? +SELECT MAP { MAP {1:'a', 2:'b'}:1, MAP {1:'c', 2:'d'}:2 }; +---- +{{1: a, 2: b}: 1, {1: c, 2: d}: 2} + +# map with different keys as key +query ? +SELECT MAP { MAP {1:'a', 2:'b', 3:'c'}:1, MAP {2:'c', 4:'d'}:2 }; +---- +{{1: a, 2: b, 3: c}: 1, {2: c, 4: d}: 2} + +# map as value +query ? +SELECT MAP {1: MAP {1:'a', 2:'b'}, 2: MAP {1:'c', 2:'d'} }; +---- +{1: {1: a, 2: b}, 2: {1: c, 2: d}} + +# map with different keys as value +query ? +SELECT MAP {'a': MAP {1:'a', 2:'b', 3:'c'}, 'b': MAP {2:'c', 4:'d'} }; +---- +{a: {1: a, 2: b, 3: c}, b: {2: c, 4: d}} + +# complex map for each row +query ? +SELECT MAP {'a': MAP {1:'a', 2:'b', 3:'c'}, 'b': MAP {2:'c', 4:'d'} } from t; +---- +{a: {1: a, 2: b, 3: c}, b: {2: c, 4: d}} +{a: {1: a, 2: b, 3: c}, b: {2: c, 4: d}} +{a: {1: a, 2: b, 3: c}, b: {2: c, 4: d}} + +# access map with non-existent key +query ? +SELECT MAP {'a': MAP {1:'a', 2:'b', 3:'c'}, 'b': MAP {2:'c', 4:'d'} }['c']; +---- +NULL + +# access map with null key +query error +SELECT MAP {'a': MAP {1:'a', 2:'b', 3:'c'}, 'b': MAP {2:'c', 4:'d'} }[NULL]; + +query ? +SELECT MAP { 'a': 1, 'b': 3 }; +---- +{a: 1, b: 3} + +query error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Int64 type +SELECT MAP { 'a': 1, 2: 3 }; + +# TODO(https://github.com/apache/datafusion/issues/11785): fix accessing map with non-string key +# query ? +# SELECT MAP { 1: 'a', 2: 'b', 3: 'c' }[1]; +# ---- +# a + +# TODO(https://github.com/apache/datafusion/issues/11785): fix accessing map with non-string key +# query ? +# SELECT MAP { MAP {1:'a', 2:'b'}:1, MAP {1:'c', 2:'d'}:2 }[MAP {1:'a', 2:'b'}]; +# ---- +# 1 + +# TODO(https://github.com/apache/datafusion/issues/11785): fix accessing map with non-string key +# query ? +# SELECT MAKE_MAP(1, null, 2, 33, 3, null)[2]; +# ---- +# 33 + +## cardinality + +# cardinality scalar function +query IIII +select cardinality(map([1, 2, 3], ['a', 'b', 'c'])), cardinality(MAP {'a': 1, 'b': null}), cardinality(MAP([],[])), + cardinality(MAP {'a': MAP {1:'a', 2:'b', 3:'c'}, 'b': MAP {2:'c', 4:'d'} }); +---- +3 2 0 2 + +# map_extract +# key is string +query ???? +select map_extract(MAP {'a': 1, 'b': NULL, 'c': 3}, 'a'), map_extract(MAP {'a': 1, 'b': NULL, 'c': 3}, 'b'), + map_extract(MAP {'a': 1, 'b': NULL, 'c': 3}, 'c'), map_extract(MAP {'a': 1, 'b': NULL, 'c': 3}, 'd'); +---- +[1] [] [3] [] + +# key is integer +query ???? +select map_extract(MAP {1: 1, 2: NULL, 3:3}, 1), map_extract(MAP {1: 1, 2: NULL, 3:3}, 2), + map_extract(MAP {1: 1, 2: NULL, 3:3}, 3), map_extract(MAP {1: 1, 2: NULL, 3:3}, 4); +---- +[1] [] [3] [] + +# value is list +query ???? +select map_extract(MAP {1: [1, 2], 2: NULL, 3:[3]}, 1), map_extract(MAP {1: [1, 2], 2: NULL, 3:[3]}, 2), + map_extract(MAP {1: [1, 2], 2: NULL, 3:[3]}, 3), map_extract(MAP {1: [1, 2], 2: NULL, 3:[3]}, 4); +---- +[[1, 2]] [] [[3]] [] + +# key in map and query key are different types +query ????? +select map_extract(MAP {1: 1, 2: 2, 3:3}, '1'), map_extract(MAP {1: 1, 2: 2, 3:3}, 1.0), + map_extract(MAP {1.0: 1, 2: 2, 3:3}, '1'), map_extract(MAP {'1': 1, '2': 2, '3':3}, 1.0), + map_extract(MAP {arrow_cast('1', 'Utf8View'): 1, arrow_cast('2', 'Utf8View'): 2, arrow_cast('3', 'Utf8View'):3}, '1'); +---- +[1] [1] [1] [] [1] + +# map_extract with columns +query ??? +select map_extract(column1, 1), map_extract(column1, 5), map_extract(column1, 7) from map_array_table_1; +---- +[[1, , 3]] [] [] +[] [[4, , 6]] [] +[] [] [[1, , 3]] + +query ??? +select map_extract(column1, column2), map_extract(column1, column3), map_extract(column1, column4) from map_array_table_1; +---- +[[1, , 3]] [[1, , 3]] [[1, , 3]] +[[4, , 6]] [[4, , 6]] [[4, , 6]] +[] [] [] + +query ??? +select map_extract(column1, column2), map_extract(column1, column3), map_extract(column1, column4) from map_array_table_2; +---- +[[1, , 3]] [] [[1, , 3]] +[[4, , 6]] [] [[4, , 6]] +[] [] [] + +query ??? +select map_extract(column1, 1), map_extract(column1, 5), map_extract(column1, 7) from map_array_table_2; +---- +[[1, , 3]] [] [] +[] [[4, , 6]] [] +[] [] [[1, , 3]] + +# Tests for map_keys + +query ? +SELECT map_keys(MAP { 'a': 1, 'b': 3 }); +---- +[a, b] + +query error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Int64 type +SELECT map_keys(MAP { 'a': 1, 2: 3 }); + +query ? +SELECT map_keys(MAP {'a':1, 'b':2, 'c':3 }) FROM t; +---- +[a, b, c] +[a, b, c] +[a, b, c] + +query ? +SELECT map_keys(Map{column1: column2, column3: column4}) FROM t; +---- +[a, k1] +[b, k3] +[d, k5] + +query ? +SELECT map_keys(map(column5, column6)) FROM t; +---- +[k1, k2] +[k3] +[k5] + +query ? +SELECT map_keys(map(column8, column9)) FROM t; +---- +[[1, 2, 3]] +[[4]] +[[1, 2]] + +query ? +SELECT map_keys(Map{}); +---- +[] + +query ? +SELECT map_keys(column1) from map_array_table_1; +---- +[1, 2, 3] +[4, 5, 6] +[7, 8, 9] + + +# Tests for map_values + +query error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Int64 type +SELECT map_values(MAP { 'a': 1, 2: 3 }); + +query ? +SELECT map_values(MAP { 'a': 1, 'b': 3 }); +---- +[1, 3] + +query ? +SELECT map_values(MAP {'a':1, 'b':2, 'c':3 }) FROM t; +---- +[1, 2, 3] +[1, 2, 3] +[1, 2, 3] + +query ? +SELECT map_values(Map{column1: column2, column3: column4}) FROM t; +---- +[1, 10] +[2, 30] +[4, 50] + +query ? +SELECT map_values(map(column5, column6)) FROM t; +---- +[1, 2] +[3] +[5] + +query ? +SELECT map_values(map(column8, column9)) FROM t; +---- +[a] +[b] +[c] + +query ? +SELECT map_values(Map{}); +---- +[] + +query ? +SELECT map_values(column1) from map_array_table_1; +---- +[[1, , 3], [4, , 6], [7, 8, 9]] +[[1, , 3], [4, , 6], [7, 8, 9]] +[[1, , 3], [9, , 6], [7, 8, 9]] + +statement ok +drop table map_array_table_1; + +statement ok +drop table map_array_table_2; diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index 5f3e1dd9ee11..1bc972a3e37d 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -20,7 +20,7 @@ ########## statement ok -CREATE external table aggregate_simple(c1 real, c2 double, c3 boolean) STORED as CSV WITH HEADER ROW LOCATION '../core/tests/data/aggregate_simple.csv'; +CREATE external table aggregate_simple(c1 real, c2 double, c3 boolean) STORED as CSV LOCATION '../core/tests/data/aggregate_simple.csv' OPTIONS ('format.has_header' 'true'); # Round query R @@ -102,7 +102,12 @@ SELECT nanvl(asin(10), 1.0), nanvl(1.0, 2.0), nanvl(asin(10), asin(10)) # isnan query BBBB -SELECT isnan(1.0), isnan('NaN'::DOUBLE), isnan(-'NaN'::DOUBLE), isnan(NULL) +SELECT isnan(1.0::DOUBLE), isnan('NaN'::DOUBLE), isnan(-'NaN'::DOUBLE), isnan(NULL) +---- +false true true NULL + +query BBBB +SELECT isnan(1.0::FLOAT), isnan('NaN'::FLOAT), isnan(-'NaN'::FLOAT), isnan(NULL::FLOAT) ---- false true true NULL @@ -112,12 +117,12 @@ SELECT iszero(1.0), iszero(0.0), iszero(-0.0), iszero(NULL) ---- false true true NULL -# abs: empty argumnet -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'abs\(\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tabs\(Any\) +# abs: empty argument +statement error SELECT abs(); # abs: wrong number of arguments -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'abs\(Int64, Int64\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tabs\(Any\) +statement error SELECT abs(1, 2); # abs: unsupported argument type @@ -142,12 +147,12 @@ CREATE TABLE test_nullable_integer( (0, 0, 0, 0, 0, 0, 0, 0, 'zeros'), (1, 1, 1, 1, 1, 1, 1, 1, 'ones'); -query IIIIIIIIT +query I INSERT into test_nullable_integer values(-128, -32768, -2147483648, -9223372036854775808, 0, 0, 0, 0, 'mins'); ---- 1 -query IIIIIIIIT +query I INSERT into test_nullable_integer values(127, 32767, 2147483647, 9223372036854775807, 255, 65535, 4294967295, 18446744073709551615, 'maxs'); ---- 1 @@ -280,10 +285,10 @@ CREATE TABLE test_non_nullable_integer( c5 TINYINT UNSIGNED NOT NULL, c6 SMALLINT UNSIGNED NOT NULL, c7 INT UNSIGNED NOT NULL, - c8 BIGINT UNSIGNED NOT NULL, + c8 BIGINT UNSIGNED NOT NULL ); -query IIIIIIII +query I INSERT INTO test_non_nullable_integer VALUES(1, 1, 1, 1, 1, 1, 1, 1) ---- 1 @@ -348,7 +353,7 @@ drop table test_non_nullable_integer statement ok CREATE TABLE test_nullable_float( c1 float, - c2 double, + c2 double ) AS VALUES (-1.0, -1.0), (1.0, 1.0), @@ -415,10 +420,10 @@ drop table test_nullable_float statement ok CREATE TABLE test_non_nullable_float( c1 float NOT NULL, - c2 double NOT NULL, + c2 double NOT NULL ); -query RR +query I INSERT INTO test_non_nullable_float VALUES (-1.0, -1.0), (1.0, 1.0), @@ -473,7 +478,7 @@ CREATE TABLE test_nullable_decimal( (0, 0, 0, 0), (NULL, NULL, NULL, NULL); -query RRRR +query I INSERT into test_nullable_decimal values ( -99999999.99, @@ -546,7 +551,7 @@ drop table test_nullable_decimal statement ok CREATE TABLE test_non_nullable_decimal(c1 DECIMAL(9,2) NOT NULL); -query R +query I INSERT INTO test_non_nullable_decimal VALUES(1) ---- 1 @@ -564,3 +569,121 @@ SELECT c1%0 FROM test_non_nullable_decimal statement ok drop table test_non_nullable_decimal + +statement ok +CREATE TABLE signed_integers( + a INT, + b INT, + c INT, + d INT, + e INT, + f INT +) as VALUES + (-1, 100, -567, 1024, -4, 10), + (2, -1000, 123, -256, 5, -11), + (-3, 10000, -978, 2048, -6, 12), + (4, NULL, NULL, -512, NULL, NULL) +; + + +## gcd + +# gcd scalar function +query IIIIII rowsort +select gcd(0, 0), gcd(2, 0), gcd(0, 2), gcd(2, 3), gcd(15, 10), gcd(20, 1000) +---- +0 2 2 1 5 20 + +# gcd with negative values +query IIIII +select gcd(-100, 0), gcd(0, -100), gcd(-2, 3), gcd(15, -10), gcd(-20, -1000) +---- +100 100 1 5 20 + +# gcd scalar nulls +query III +select gcd(null, 64), gcd(2, null), gcd(null, null); +---- +NULL NULL NULL + +# scalar maxes and/or negative 1 +query III +select + gcd(9223372036854775807, -9223372036854775808), -- i64::MAX, i64::MIN + gcd(9223372036854775807, -1), -- i64::MAX, -1 + gcd(-9223372036854775808, -1); -- i64::MIN, -1 +---- +1 1 1 + +# gcd with columns and expresions +query II rowsort +select gcd(a, b), gcd(c*d + 1, abs(e)) + f from signed_integers; +---- +1 11 +1 13 +2 -10 +NULL NULL + +# gcd(i64::MIN, i64::MIN) +query error DataFusion error: Arrow error: Compute error: Signed integer overflow in GCD\(\-9223372036854775808, \-9223372036854775808\) +select gcd(-9223372036854775808, -9223372036854775808); + +# gcd(i64::MIN, 0) +query error DataFusion error: Arrow error: Compute error: Signed integer overflow in GCD\(\-9223372036854775808, 0\) +select gcd(-9223372036854775808, 0); + +# gcd(0, i64::MIN) +query error DataFusion error: Arrow error: Compute error: Signed integer overflow in GCD\(0, \-9223372036854775808\) +select gcd(0, -9223372036854775808); + + +## lcm + +# Basic cases +query IIIIII +select lcm(0, 0), lcm(0, 2), lcm(3, 0), lcm(2, 3), lcm(15, 10), lcm(20, 1000) +---- +0 0 0 6 30 1000 + +# Test lcm with negative numbers +query IIIII +select lcm(0, -2), lcm(-3, 0), lcm(-2, 3), lcm(15, -10), lcm(-15, -10) +---- +0 0 6 30 30 + +# Test lcm with Nulls +query III +select lcm(null, 64), lcm(16, null), lcm(null, null) +---- +NULL NULL NULL + +# Test lcm with columns +query III rowsort +select lcm(a, b), lcm(c, d), lcm(e, f) from signed_integers; +---- +100 580608 20 +1000 31488 55 +30000 1001472 12 +NULL NULL NULL + +# Result cannot fit in i64 +query error DataFusion error: Arrow error: Compute error: Signed integer overflow in LCM\(\-9223372036854775808, \-9223372036854775808\) +select lcm(-9223372036854775808, -9223372036854775808); + +query error DataFusion error: Arrow error: Compute error: Signed integer overflow in LCM\(1, \-9223372036854775808\) +select lcm(1, -9223372036854775808); + +# Overflow on multiplication +query error DataFusion error: Arrow error: Compute error: Signed integer overflow in LCM\(2, 9223372036854775803\) +select lcm(2, 9223372036854775803); + + +query error DataFusion error: Arrow error: Arithmetic overflow: Overflow happened on: 2107754225 \^ 1221660777 +select power(2107754225, 1221660777); + +# factorial overflow +query error DataFusion error: Arrow error: Compute error: Overflow happened on FACTORIAL\(350943270\) +select FACTORIAL(350943270); + +statement ok +drop table signed_integers diff --git a/datafusion/sqllogictest/test_files/metadata.slt b/datafusion/sqllogictest/test_files/metadata.slt index 3b2b219244f5..8f787254c096 100644 --- a/datafusion/sqllogictest/test_files/metadata.slt +++ b/datafusion/sqllogictest/test_files/metadata.slt @@ -25,7 +25,7 @@ ## with metadata in SQL. query IT -select * from table_with_metadata; +select id, name from table_with_metadata; ---- 1 NULL NULL bar @@ -58,5 +58,115 @@ WHERE "data"."id" = "samples"."id"; 1 3 + + +# Regression test: prevent field metadata loss per https://github.com/apache/datafusion/issues/12687 +query I +select count(distinct name) from table_with_metadata; +---- +2 + +# Regression test: prevent field metadata loss per https://github.com/apache/datafusion/issues/12687 +query I +select approx_median(distinct id) from table_with_metadata; +---- +2 + +# Regression test: prevent field metadata loss per https://github.com/apache/datafusion/issues/12687 +statement ok +select array_agg(distinct id) from table_with_metadata; + +query I +select distinct id from table_with_metadata order by id; +---- +1 +3 +NULL + +query I +select count(id) from table_with_metadata; +---- +2 + +query I +select count(id) cnt from table_with_metadata group by name order by cnt; +---- +0 +1 +1 + + + +# Regression test: missing schema metadata, when aggregate on cross join +query I +SELECT count("data"."id") +FROM + ( + SELECT "id" FROM "table_with_metadata" + ) as "data", + ( + SELECT "id" FROM "table_with_metadata" + ) as "samples"; +---- +6 + +# Regression test: missing field metadata, from the NULL field on the left side of the union +query ITT +(SELECT id, NULL::string as name, l_name FROM "table_with_metadata") + UNION +(SELECT id, name, NULL::string as l_name FROM "table_with_metadata") +ORDER BY id, name, l_name; +---- +1 NULL NULL +3 baz NULL +3 NULL l_baz +NULL bar NULL +NULL NULL l_bar + +# Regression test: missing field metadata from left side of the union when right side is chosen +query T +select name from ( + SELECT nonnull_name as name FROM "table_with_metadata" + UNION ALL + SELECT NULL::string as name +) group by name order by name; +---- +no_bar +no_baz +no_foo +NULL + +# Regression test: missing schema metadata from union when schema with metadata isn't the first one +# and also ensure it works fine with multiple unions +query T +select name from ( + SELECT NULL::string as name + UNION ALL + SELECT nonnull_name as name FROM "table_with_metadata" + UNION ALL + SELECT NULL::string as name +) group by name order by name; +---- +no_bar +no_baz +no_foo +NULL + +query P rowsort +SELECT ts +FROM (( + SELECT now() AS ts + FROM table_with_metadata +) UNION ALL ( + SELECT ts + FROM table_with_metadata +)) +GROUP BY ts +ORDER BY ts +LIMIT 1; +---- +2020-09-08T13:42:29.190855123Z + + statement ok drop table table_with_metadata; diff --git a/datafusion/sqllogictest/test_files/misc.slt b/datafusion/sqllogictest/test_files/misc.slt index 848cdc943914..9bd3023b56f7 100644 --- a/datafusion/sqllogictest/test_files/misc.slt +++ b/datafusion/sqllogictest/test_files/misc.slt @@ -24,3 +24,21 @@ query TT? select 'foo', '', NULL ---- foo (empty) NULL + +# Where clause accept NULL literal +query I +select 1 where NULL +---- + +# Where clause does not accept non boolean and has nice error message +query error Cannot create filter with non\-boolean predicate 'Utf8\("foo"\)' returning Utf8 +select 1 where 'foo' + +query I +select 1 where NULL and 1 = 1 +---- + +query I +select 1 where NULL or 1 = 1 +---- +1 diff --git a/datafusion/sqllogictest/test_files/monotonic_projection_test.slt b/datafusion/sqllogictest/test_files/monotonic_projection_test.slt index 943a7e07bb23..abf48fac5364 100644 --- a/datafusion/sqllogictest/test_files/monotonic_projection_test.slt +++ b/datafusion/sqllogictest/test_files/monotonic_projection_test.slt @@ -25,10 +25,10 @@ CREATE EXTERNAL TABLE multiple_ordered_table ( d INTEGER ) STORED AS CSV -WITH HEADER ROW WITH ORDER (a ASC, b ASC) WITH ORDER (c ASC) -LOCATION '../core/tests/data/window_2.csv'; +LOCATION '../core/tests/data/window_2.csv' +OPTIONS ('format.has_header' 'true'); # test for substitute CAST scenario query TT @@ -44,7 +44,7 @@ logical_plan 02)--Projection: CAST(multiple_ordered_table.a AS Int64) AS a_big, multiple_ordered_table.b 03)----TableScan: multiple_ordered_table projection=[a, b] physical_plan -01)SortPreservingMergeExec: [a_big@0 ASC NULLS LAST,b@1 ASC NULLS LAST] +01)SortPreservingMergeExec: [a_big@0 ASC NULLS LAST, b@1 ASC NULLS LAST] 02)--ProjectionExec: expr=[CAST(a@0 AS Int64) as a_big, b@1 as b] 03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], has_header=true @@ -60,7 +60,7 @@ logical_plan 02)--Projection: multiple_ordered_table.a, CAST(multiple_ordered_table.a AS Int64) AS a_big, multiple_ordered_table.b 03)----TableScan: multiple_ordered_table projection=[a, b] physical_plan -01)SortPreservingMergeExec: [a@0 ASC NULLS LAST,b@2 ASC NULLS LAST] +01)SortPreservingMergeExec: [a@0 ASC NULLS LAST, b@2 ASC NULLS LAST] 02)--ProjectionExec: expr=[a@0 as a, CAST(a@0 AS Int64) as a_big, b@1 as b] 03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], has_header=true @@ -81,7 +81,7 @@ logical_plan 02)--Projection: multiple_ordered_table.a, CAST(multiple_ordered_table.a AS Int64) AS a_big, multiple_ordered_table.b 03)----TableScan: multiple_ordered_table projection=[a, b] physical_plan -01)SortPreservingMergeExec: [a_big@1 ASC NULLS LAST,b@2 ASC NULLS LAST] +01)SortPreservingMergeExec: [a_big@1 ASC NULLS LAST, b@2 ASC NULLS LAST] 02)--ProjectionExec: expr=[a@0 as a, CAST(a@0 AS Int64) as a_big, b@1 as b] 03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], has_header=true @@ -132,8 +132,8 @@ logical_plan 02)--Projection: CAST(multiple_ordered_table.a AS Utf8) AS a_str, multiple_ordered_table.b 03)----TableScan: multiple_ordered_table projection=[a, b] physical_plan -01)SortPreservingMergeExec: [a_str@0 ASC NULLS LAST,b@1 ASC NULLS LAST] -02)--SortExec: expr=[a_str@0 ASC NULLS LAST,b@1 ASC NULLS LAST] +01)SortPreservingMergeExec: [a_str@0 ASC NULLS LAST, b@1 ASC NULLS LAST] +02)--SortExec: expr=[a_str@0 ASC NULLS LAST, b@1 ASC NULLS LAST], preserve_partitioning=[true] 03)----ProjectionExec: expr=[CAST(a@0 AS Utf8) as a_str, b@1 as b] 04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], has_header=true @@ -151,7 +151,7 @@ logical_plan 01)Sort: multiple_ordered_table.a + multiple_ordered_table.b ASC NULLS LAST 02)--TableScan: multiple_ordered_table projection=[a, b] physical_plan -01)SortExec: expr=[a@0 + b@1 ASC NULLS LAST] +01)SortExec: expr=[a@0 + b@1 ASC NULLS LAST], preserve_partitioning=[false] 02)--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], has_header=true # With similar reasoning above. It is not guaranteed sum_expr is ordered @@ -168,7 +168,7 @@ logical_plan 03)----TableScan: multiple_ordered_table projection=[a, b] physical_plan 01)SortPreservingMergeExec: [sum_expr@0 ASC NULLS LAST] -02)--SortExec: expr=[sum_expr@0 ASC NULLS LAST] +02)--SortExec: expr=[sum_expr@0 ASC NULLS LAST], preserve_partitioning=[true] 03)----ProjectionExec: expr=[CAST(a@0 + b@1 AS Int64) as sum_expr, a@0 as a, b@1 as b] 04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], has_header=true diff --git a/datafusion/sqllogictest/test_files/nvl.slt b/datafusion/sqllogictest/test_files/nvl.slt index c77214cc302a..81e79e1eb5b0 100644 --- a/datafusion/sqllogictest/test_files/nvl.slt +++ b/datafusion/sqllogictest/test_files/nvl.slt @@ -114,7 +114,7 @@ SELECT NVL(1, 3); ---- 1 -query ? +query I SELECT NVL(NULL, NULL); ---- NULL diff --git a/datafusion/sqllogictest/test_files/optimizer_group_by_constant.slt b/datafusion/sqllogictest/test_files/optimizer_group_by_constant.slt new file mode 100644 index 000000000000..de6a153f58d9 --- /dev/null +++ b/datafusion/sqllogictest/test_files/optimizer_group_by_constant.slt @@ -0,0 +1,118 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +CREATE EXTERNAL TABLE test_table ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT, + c5 INT, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 INT UNSIGNED NOT NULL, + c10 BIGINT UNSIGNED NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL +) +STORED AS CSV +LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); + +statement ok +SET datafusion.execution.target_partitions = 1; + +statement ok +SET datafusion.explain.logical_plan_only = true; + +query TT +EXPLAIN +SELECT c1, 99999, c5 + c8, 'test', count(1) +FROM test_table t +GROUP BY 1, 2, 3, 4 +---- +logical_plan +01)Projection: t.c1, Int64(99999), t.c5 + t.c8, Utf8("test"), count(Int64(1)) +02)--Aggregate: groupBy=[[t.c1, t.c5 + t.c8]], aggr=[[count(Int64(1))]] +03)----SubqueryAlias: t +04)------TableScan: test_table projection=[c1, c5, c8] + +query TT +EXPLAIN +SELECT 123, 456, 789, count(1), avg(c12) +FROM test_table t +group by 1, 2, 3 +---- +logical_plan +01)Projection: Int64(123), Int64(456), Int64(789), count(Int64(1)), avg(t.c12) +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1)), avg(t.c12)]] +03)----SubqueryAlias: t +04)------TableScan: test_table projection=[c12] + +query TT +EXPLAIN +SELECT to_date('2023-05-04') as dt, extract(day from now()) < 1000 as today_filter, count(1) +FROM test_table t +GROUP BY 1, 2 +---- +logical_plan +01)Projection: Date32("2023-05-04") AS dt, Boolean(true) AS today_filter, count(Int64(1)) +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +03)----SubqueryAlias: t +04)------TableScan: test_table projection=[] + +query TT +EXPLAIN +SELECT + not ( + cast( + extract(month from now()) AS INT + ) + between 50 and 60 + ), count(1) +FROM test_table t +GROUP BY 1 +---- +logical_plan +01)Projection: Boolean(true) AS NOT date_part(Utf8("MONTH"),now()) BETWEEN Int64(50) AND Int64(60), count(Int64(1)) +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +03)----SubqueryAlias: t +04)------TableScan: test_table projection=[] + +query TT +EXPLAIN +SELECT 123 +FROM test_table t +GROUP BY 1 +---- +logical_plan +01)Aggregate: groupBy=[[Int64(123)]], aggr=[[]] +02)--SubqueryAlias: t +03)----TableScan: test_table projection=[] + +query TT +EXPLAIN +SELECT random() +FROM test_table t +GROUP BY 1 +---- +logical_plan +01)Aggregate: groupBy=[[random()]], aggr=[[]] +02)--SubqueryAlias: t +03)----TableScan: test_table projection=[] diff --git a/datafusion/sqllogictest/test_files/options.slt b/datafusion/sqllogictest/test_files/options.slt index ba9eedcbbd34..aafaa054964e 100644 --- a/datafusion/sqllogictest/test_files/options.slt +++ b/datafusion/sqllogictest/test_files/options.slt @@ -42,7 +42,7 @@ physical_plan statement ok set datafusion.execution.coalesce_batches = false -# expect no coalsece +# expect no coalescence query TT explain SELECT * FROM a WHERE c0 < 1; ---- diff --git a/datafusion/sqllogictest/test_files/order.slt b/datafusion/sqllogictest/test_files/order.slt index 4121de91cb8d..a46040aa532e 100644 --- a/datafusion/sqllogictest/test_files/order.slt +++ b/datafusion/sqllogictest/test_files/order.slt @@ -36,8 +36,8 @@ CREATE EXTERNAL TABLE aggregate_test_100 ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); # test_sort_unprojected_col query I @@ -98,7 +98,8 @@ NULL three statement ok CREATE EXTERNAL TABLE test (c1 int, c2 bigint, c3 boolean) -STORED AS CSV LOCATION '../core/tests/data/partitioned_csv'; +STORED AS CSV LOCATION '../core/tests/data/partitioned_csv' +OPTIONS('format.has_header' 'false'); # Demonstrate types query TTT @@ -164,7 +165,7 @@ logical_plan 03)----TableScan: aggregate_test_100 projection=[c1, c2, c3] physical_plan 01)ProjectionExec: expr=[c1@0 as c1, c2@1 as c2] -02)--SortExec: expr=[c2@1 ASC NULLS LAST,c3@2 ASC NULLS LAST] +02)--SortExec: expr=[c2@1 ASC NULLS LAST, c3@2 ASC NULLS LAST], preserve_partitioning=[false] 03)----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3], has_header=true query II @@ -326,6 +327,21 @@ select column1 + column2 from foo group by column1, column2 ORDER BY column2 des 7 3 +# Test issue: https://github.com/apache/datafusion/issues/11549 +query I +select column1 from foo order by log(column2); +---- +1 +3 +5 + +# Test issue: https://github.com/apache/datafusion/issues/13157 +query I +select column1 from foo order by column2 % 2, column2; +---- +1 +3 +5 # Cleanup statement ok @@ -374,11 +390,11 @@ select * from t SORT BY time; # distinct on a column not in the select list should not work -statement error For SELECT DISTINCT, ORDER BY expressions time must appear in select list +statement error DataFusion error: Error during planning: For SELECT DISTINCT, ORDER BY expressions t\.time must appear in select list SELECT DISTINCT value FROM t ORDER BY time; # distinct on an expression of a column not in the select list should not work -statement error For SELECT DISTINCT, ORDER BY expressions time must appear in select list +statement error DataFusion error: Error during planning: For SELECT DISTINCT, ORDER BY expressions t\.time must appear in select list SELECT DISTINCT date_trunc('hour', time) FROM t ORDER BY time; # distinct on a column that is in the select list but aliasted should work @@ -422,11 +438,11 @@ CREATE EXTERNAL TABLE multiple_ordered_table ( d INTEGER ) STORED AS CSV -WITH HEADER ROW WITH ORDER (a ASC) WITH ORDER (b ASC) WITH ORDER (c ASC) -LOCATION '../core/tests/data/window_2.csv'; +LOCATION '../core/tests/data/window_2.csv' +OPTIONS ('format.has_header' 'true'); query TT EXPLAIN SELECT (b+a+c) AS result @@ -456,7 +472,8 @@ CREATE EXTERNAL TABLE csv_with_timestamps ( ) STORED AS CSV WITH ORDER (ts ASC NULLS LAST) -LOCATION '../core/tests/data/timestamps.csv'; +LOCATION '../core/tests/data/timestamps.csv' +OPTIONS('format.has_header' 'false'); query TT EXPLAIN SELECT DATE_BIN(INTERVAL '15 minutes', ts, TIMESTAMP '2022-08-03 14:40:00Z') as db15 @@ -465,11 +482,11 @@ ORDER BY db15; ---- logical_plan 01)Sort: db15 ASC NULLS LAST -02)--Projection: date_bin(IntervalMonthDayNano("900000000000"), csv_with_timestamps.ts, TimestampNanosecond(1659537600000000000, None)) AS db15 +02)--Projection: date_bin(IntervalMonthDayNano("IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 900000000000 }"), csv_with_timestamps.ts, TimestampNanosecond(1659537600000000000, None)) AS db15 03)----TableScan: csv_with_timestamps projection=[ts] physical_plan 01)SortPreservingMergeExec: [db15@0 ASC NULLS LAST] -02)--ProjectionExec: expr=[date_bin(900000000000, ts@0, 1659537600000000000) as db15] +02)--ProjectionExec: expr=[date_bin(IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 900000000000 }, ts@0, 1659537600000000000) as db15] 03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/timestamps.csv]]}, projection=[ts], output_ordering=[ts@0 ASC NULLS LAST], has_header=false @@ -511,10 +528,10 @@ CREATE EXTERNAL TABLE aggregate_test_100 ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW WITH ORDER(c11) -WITH ORDER(c12 DESC) +WITH ORDER(c12 DESC NULLS LAST) LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); query TT EXPLAIN SELECT ATAN(c11) as atan_c11 @@ -547,34 +564,34 @@ physical_plan 04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11], output_ordering=[c11@0 ASC NULLS LAST], has_header=true query TT - EXPLAIN SELECT LOG(c11, c12) as log_c11_base_c12 + EXPLAIN SELECT LOG(c12, c11) as log_c11_base_c12 FROM aggregate_test_100 ORDER BY log_c11_base_c12; ---- logical_plan 01)Sort: log_c11_base_c12 ASC NULLS LAST -02)--Projection: log(CAST(aggregate_test_100.c11 AS Float64), aggregate_test_100.c12) AS log_c11_base_c12 +02)--Projection: log(aggregate_test_100.c12, CAST(aggregate_test_100.c11 AS Float64)) AS log_c11_base_c12 03)----TableScan: aggregate_test_100 projection=[c11, c12] physical_plan 01)SortPreservingMergeExec: [log_c11_base_c12@0 ASC NULLS LAST] -02)--ProjectionExec: expr=[log(CAST(c11@0 AS Float64), c12@1) as log_c11_base_c12] +02)--ProjectionExec: expr=[log(c12@1, CAST(c11@0 AS Float64)) as log_c11_base_c12] 03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11, c12], output_orderings=[[c11@0 ASC NULLS LAST], [c12@1 DESC]], has_header=true +04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11, c12], output_orderings=[[c11@0 ASC NULLS LAST], [c12@1 DESC NULLS LAST]], has_header=true query TT -EXPLAIN SELECT LOG(c12, c11) as log_c12_base_c11 +EXPLAIN SELECT LOG(c11, c12) as log_c12_base_c11 FROM aggregate_test_100 -ORDER BY log_c12_base_c11 DESC; +ORDER BY log_c12_base_c11 DESC NULLS LAST; ---- logical_plan -01)Sort: log_c12_base_c11 DESC NULLS FIRST -02)--Projection: log(aggregate_test_100.c12, CAST(aggregate_test_100.c11 AS Float64)) AS log_c12_base_c11 +01)Sort: log_c12_base_c11 DESC NULLS LAST +02)--Projection: log(CAST(aggregate_test_100.c11 AS Float64), aggregate_test_100.c12) AS log_c12_base_c11 03)----TableScan: aggregate_test_100 projection=[c11, c12] physical_plan -01)SortPreservingMergeExec: [log_c12_base_c11@0 DESC] -02)--ProjectionExec: expr=[log(c12@1, CAST(c11@0 AS Float64)) as log_c12_base_c11] +01)SortPreservingMergeExec: [log_c12_base_c11@0 DESC NULLS LAST] +02)--ProjectionExec: expr=[log(CAST(c11@0 AS Float64), c12@1) as log_c12_base_c11] 03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11, c12], output_orderings=[[c11@0 ASC NULLS LAST], [c12@1 DESC]], has_header=true +04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c11, c12], output_orderings=[[c11@0 ASC NULLS LAST], [c12@1 DESC NULLS LAST]], has_header=true statement ok drop table aggregate_test_100; @@ -627,7 +644,8 @@ CREATE EXTERNAL TABLE IF NOT EXISTS orders ( o_clerk VARCHAR, o_shippriority INTEGER, o_comment VARCHAR, -) STORED AS CSV WITH ORDER (o_orderkey ASC) DELIMITER ',' WITH HEADER ROW LOCATION '../core/tests/tpch-csv/orders.csv'; +) STORED AS CSV WITH ORDER (o_orderkey ASC) LOCATION '../core/tests/tpch-csv/orders.csv' +OPTIONS ('format.delimiter' ',', 'format.has_header' 'true'); query TT EXPLAIN SELECT o_orderkey, o_orderstatus FROM orders ORDER BY o_orderkey ASC @@ -643,13 +661,6 @@ physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/te query error DataFusion error: Error during planning: Column a is not in schema CREATE EXTERNAL TABLE dt (a_id integer, a_str string, a_bool boolean) STORED AS CSV WITH ORDER (a ASC) LOCATION 'file://path/to/table'; - -# Create external table with DDL ordered columns without schema -# When schema is missing the query is expected to fail -query error DataFusion error: Error during planning: Provide a schema before specifying the order while creating a table\. -CREATE EXTERNAL TABLE dt STORED AS CSV WITH ORDER (a ASC) LOCATION 'file://path/to/table'; - - # Sort with duplicate sort expressions # Table is sorted multiple times on the same column name and should not fail statement ok @@ -680,7 +691,7 @@ logical_plan 01)Sort: t1.id DESC NULLS FIRST, t1.name ASC NULLS LAST 02)--TableScan: t1 projection=[id, name] physical_plan -01)SortExec: expr=[id@0 DESC,name@1 ASC NULLS LAST] +01)SortExec: expr=[id@0 DESC, name@1 ASC NULLS LAST], preserve_partitioning=[false] 02)--MemoryExec: partitions=1, partition_sizes=[1] query IT @@ -699,7 +710,7 @@ logical_plan 01)Sort: t1.id ASC NULLS LAST, t1.name ASC NULLS LAST 02)--TableScan: t1 projection=[id, name] physical_plan -01)SortExec: expr=[id@0 ASC NULLS LAST,name@1 ASC NULLS LAST] +01)SortExec: expr=[id@0 ASC NULLS LAST, name@1 ASC NULLS LAST], preserve_partitioning=[false] 02)--MemoryExec: partitions=1, partition_sizes=[1] @@ -754,34 +765,34 @@ logical_plan 02)--Union 03)----SubqueryAlias: u 04)------Projection: Int64(0) AS m, m0.t -05)--------Aggregate: groupBy=[[Int64(0), m0.t]], aggr=[[]] +05)--------Aggregate: groupBy=[[m0.t]], aggr=[[]] 06)----------SubqueryAlias: m0 07)------------Projection: column1 AS t 08)--------------Values: (Int64(0)), (Int64(1)), (Int64(2)) 09)----SubqueryAlias: v 10)------Projection: Int64(1) AS m, m1.t -11)--------Aggregate: groupBy=[[Int64(1), m1.t]], aggr=[[]] +11)--------Aggregate: groupBy=[[m1.t]], aggr=[[]] 12)----------SubqueryAlias: m1 13)------------Projection: column1 AS t 14)--------------Values: (Int64(0)), (Int64(1)) physical_plan -01)SortPreservingMergeExec: [m@0 ASC NULLS LAST,t@1 ASC NULLS LAST] -02)--SortExec: expr=[m@0 ASC NULLS LAST,t@1 ASC NULLS LAST] +01)SortPreservingMergeExec: [m@0 ASC NULLS LAST, t@1 ASC NULLS LAST] +02)--SortExec: expr=[m@0 ASC NULLS LAST, t@1 ASC NULLS LAST], preserve_partitioning=[true] 03)----InterleaveExec -04)------ProjectionExec: expr=[Int64(0)@0 as m, t@1 as t] -05)--------AggregateExec: mode=FinalPartitioned, gby=[Int64(0)@0 as Int64(0), t@1 as t], aggr=[], ordering_mode=PartiallySorted([0]) +04)------ProjectionExec: expr=[0 as m, t@0 as t] +05)--------AggregateExec: mode=FinalPartitioned, gby=[t@0 as t], aggr=[] 06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------RepartitionExec: partitioning=Hash([Int64(0)@0, t@1], 2), input_partitions=2 +07)------------RepartitionExec: partitioning=Hash([t@0], 2), input_partitions=2 08)--------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -09)----------------AggregateExec: mode=Partial, gby=[0 as Int64(0), t@0 as t], aggr=[], ordering_mode=PartiallySorted([0]) +09)----------------AggregateExec: mode=Partial, gby=[t@0 as t], aggr=[] 10)------------------ProjectionExec: expr=[column1@0 as t] 11)--------------------ValuesExec -12)------ProjectionExec: expr=[Int64(1)@0 as m, t@1 as t] -13)--------AggregateExec: mode=FinalPartitioned, gby=[Int64(1)@0 as Int64(1), t@1 as t], aggr=[], ordering_mode=PartiallySorted([0]) +12)------ProjectionExec: expr=[1 as m, t@0 as t] +13)--------AggregateExec: mode=FinalPartitioned, gby=[t@0 as t], aggr=[] 14)----------CoalesceBatchesExec: target_batch_size=8192 -15)------------RepartitionExec: partitioning=Hash([Int64(1)@0, t@1], 2), input_partitions=2 +15)------------RepartitionExec: partitioning=Hash([t@0], 2), input_partitions=2 16)--------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -17)----------------AggregateExec: mode=Partial, gby=[1 as Int64(1), t@0 as t], aggr=[], ordering_mode=PartiallySorted([0]) +17)----------------AggregateExec: mode=Partial, gby=[t@0 as t], aggr=[] 18)------------------ProjectionExec: expr=[column1@0 as t] 19)--------------------ValuesExec @@ -888,6 +899,410 @@ select column2, column1 from foo ORDER BY column2 desc, column1 desc; [1] 4 [0] 2 +# Test issue: https://github.com/apache/datafusion/issues/10013 +# There is both a HAVING clause and an ORDER BY clause in the query. +query I +SELECT + SUM(column1) +FROM foo +HAVING SUM(column1) > 0 +ORDER BY SUM(column1) +---- +16 + +# ORDER BY with a GROUP BY clause +query I +SELECT SUM(column1) + FROM foo +GROUP BY column2 +ORDER BY SUM(column1) +---- +0 +2 +2 +2 +3 +3 +4 + +# ORDER BY with a GROUP BY clause and a HAVING clause +query I +SELECT + SUM(column1) +FROM foo +GROUP BY column2 +HAVING SUM(column1) < 3 +ORDER BY SUM(column1) +---- +0 +2 +2 +2 + +# ORDER BY without a HAVING clause +query I +SELECT SUM(column1) FROM foo ORDER BY SUM(column1) +---- +16 + +# Order by unprojected aggregate expressions is not supported +query error DataFusion error: This feature is not implemented: Physical plan does not support logical expression AggregateFunction +SELECT column2 FROM foo ORDER BY SUM(column1) + +statement ok +create table ambiguity_test(a int, b int) as values (1,20), (2,10); + +# Order-by expressions prioritize referencing columns from the select list. +query II +select a as b, b as c2 from ambiguity_test order by b +---- +1 20 +2 10 + # Cleanup statement ok drop table foo; + +statement ok +drop table ambiguity_test; + +## reproducer for https://github.com/apache/datafusion/issues/12446 +# Ensure union ordering calculations with constants can be optimized + +statement ok +create table t(a0 int, a int, b int, c int) as values (1, 2, 3, 4), (5, 6, 7, 8); + +# expect this query to run successfully, not error +query III +select * from (select c, a, NULL::int as a0 from t order by a, c) t1 +union all +select * from (select c, NULL::int as a, a0 from t order by a0, c) t2 +order by c, a, a0, b +limit 2; +---- +4 2 NULL +4 NULL 1 + + +# Casting from numeric to string types breaks the ordering +statement ok +CREATE EXTERNAL TABLE ordered_table ( + a0 INT, + a INT, + b INT, + c INT, + d INT +) +STORED AS CSV +WITH ORDER (c ASC) +LOCATION '../core/tests/data/window_2.csv' +OPTIONS ('format.has_header' 'true'); + +query T +SELECT CAST(c as VARCHAR) as c_str +FROM ordered_table +ORDER BY c_str +limit 5; +---- +0 +1 +10 +11 +12 + +query TT +EXPLAIN SELECT CAST(c as VARCHAR) as c_str +FROM ordered_table +ORDER BY c_str +limit 5; +---- +logical_plan +01)Sort: c_str ASC NULLS LAST, fetch=5 +02)--Projection: CAST(ordered_table.c AS Utf8) AS c_str +03)----TableScan: ordered_table projection=[c] +physical_plan +01)SortPreservingMergeExec: [c_str@0 ASC NULLS LAST], fetch=5 +02)--SortExec: TopK(fetch=5), expr=[c_str@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[CAST(c@0 AS Utf8) as c_str] +04)------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c], output_ordering=[c@0 ASC NULLS LAST], has_header=true + + +# Casting from numeric to numeric types preserves the ordering +query I +SELECT CAST(c as BIGINT) as c_bigint +FROM ordered_table +ORDER BY c_bigint +limit 5; +---- +0 +1 +2 +3 +4 + +query TT +EXPLAIN SELECT CAST(c as BIGINT) as c_bigint +FROM ordered_table +ORDER BY c_bigint +limit 5; +---- +logical_plan +01)Sort: c_bigint ASC NULLS LAST, fetch=5 +02)--Projection: CAST(ordered_table.c AS Int64) AS c_bigint +03)----TableScan: ordered_table projection=[c] +physical_plan +01)SortPreservingMergeExec: [c_bigint@0 ASC NULLS LAST], fetch=5 +02)--ProjectionExec: expr=[CAST(c@0 AS Int64) as c_bigint] +03)----RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c], output_ordering=[c@0 ASC NULLS LAST], has_header=true + +statement ok +drop table ordered_table; + + +# ABS(x) breaks the ordering if x's range contains both negative and positive values. +# Since x is defined as INT, its range is assumed to be from NEG_INF to INF. +statement ok +CREATE EXTERNAL TABLE ordered_table ( + a0 INT, + a INT, + b INT, + c INT, + d INT +) +STORED AS CSV +WITH ORDER (c ASC) +LOCATION '../core/tests/data/window_2.csv' +OPTIONS ('format.has_header' 'true'); + +query TT +EXPLAIN SELECT ABS(c) as abs_c +FROM ordered_table +ORDER BY abs_c +limit 5; +---- +logical_plan +01)Sort: abs_c ASC NULLS LAST, fetch=5 +02)--Projection: abs(ordered_table.c) AS abs_c +03)----TableScan: ordered_table projection=[c] +physical_plan +01)SortPreservingMergeExec: [abs_c@0 ASC NULLS LAST], fetch=5 +02)--SortExec: TopK(fetch=5), expr=[abs_c@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[abs(c@0) as abs_c] +04)------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c], output_ordering=[c@0 ASC NULLS LAST], has_header=true + +statement ok +drop table ordered_table; + +# ABS(x) preserves the ordering if x's range falls into positive values. +# Since x is defined as INT UNSIGNED, its range is assumed to be from 0 to INF. +statement ok +CREATE EXTERNAL TABLE ordered_table ( + a0 INT, + a INT, + b INT, + c INT UNSIGNED, + d INT +) +STORED AS CSV +WITH ORDER (c ASC) +LOCATION '../core/tests/data/window_2.csv' +OPTIONS ('format.has_header' 'true'); + +query TT +EXPLAIN SELECT ABS(c) as abs_c +FROM ordered_table +ORDER BY abs_c +limit 5; +---- +logical_plan +01)Sort: abs_c ASC NULLS LAST, fetch=5 +02)--Projection: abs(ordered_table.c) AS abs_c +03)----TableScan: ordered_table projection=[c] +physical_plan +01)SortPreservingMergeExec: [abs_c@0 ASC NULLS LAST], fetch=5 +02)--ProjectionExec: expr=[abs(c@0) as abs_c] +03)----RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c], output_ordering=[c@0 ASC NULLS LAST], has_header=true + +# Boolean to integer casts preserve the order. +statement ok +CREATE EXTERNAL TABLE annotated_data_finite ( + ts INTEGER, + inc_col INTEGER, + desc_col INTEGER, +) +STORED AS CSV +WITH ORDER (inc_col ASC) +WITH ORDER (desc_col DESC) +LOCATION '../core/tests/data/window_1.csv' +OPTIONS ('format.has_header' 'true'); + +query TT +EXPLAIN SELECT CAST((inc_col>desc_col) as integer) as c from annotated_data_finite order by c; +---- +logical_plan +01)Sort: c ASC NULLS LAST +02)--Projection: CAST(annotated_data_finite.inc_col > annotated_data_finite.desc_col AS Int32) AS c +03)----TableScan: annotated_data_finite projection=[inc_col, desc_col] +physical_plan +01)SortPreservingMergeExec: [c@0 ASC NULLS LAST] +02)--ProjectionExec: expr=[CAST(inc_col@0 > desc_col@1 AS Int32) as c] +03)----RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[inc_col, desc_col], output_orderings=[[inc_col@0 ASC NULLS LAST], [desc_col@1 DESC]], has_header=true + +# Union a query with the actual data and one with a constant +query I +SELECT (SELECT c from ordered_table ORDER BY c LIMIT 1) UNION ALL (SELECT 23 as c from ordered_table ORDER BY c LIMIT 1) ORDER BY c; +---- +0 +23 + +# Do not increase partition number after fetch 1. As this will be unnecessary. +query TT +EXPLAIN SELECT a + b as sum1 FROM (SELECT a, b + FROM ordered_table + ORDER BY a ASC LIMIT 1 +); +---- +logical_plan +01)Projection: ordered_table.a + ordered_table.b AS sum1 +02)--Sort: ordered_table.a ASC NULLS LAST, fetch=1 +03)----TableScan: ordered_table projection=[a, b] +physical_plan +01)ProjectionExec: expr=[a@0 + b@1 as sum1] +02)--RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +03)----SortExec: TopK(fetch=1), expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false] +04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], has_header=true + +statement ok +set datafusion.execution.use_row_number_estimates_to_optimize_partitioning = true; + +# Do not increase the number of partitions after fetch one, as this will be unnecessary. +query TT +EXPLAIN SELECT a + b as sum1 FROM (SELECT a, b + FROM ordered_table + ORDER BY a ASC LIMIT 1 +); +---- +logical_plan +01)Projection: ordered_table.a + ordered_table.b AS sum1 +02)--Sort: ordered_table.a ASC NULLS LAST, fetch=1 +03)----TableScan: ordered_table projection=[a, b] +physical_plan +01)ProjectionExec: expr=[a@0 + b@1 as sum1] +02)--SortExec: TopK(fetch=1), expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false] +03)----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], has_header=true + +statement ok +set datafusion.execution.use_row_number_estimates_to_optimize_partitioning = false; + +# Here, we have multiple partitions after fetch one, since the row count estimate is not exact. +query TT +EXPLAIN SELECT a + b as sum1 FROM (SELECT a, b + FROM ordered_table + ORDER BY a ASC LIMIT 1 +); +---- +logical_plan +01)Projection: ordered_table.a + ordered_table.b AS sum1 +02)--Sort: ordered_table.a ASC NULLS LAST, fetch=1 +03)----TableScan: ordered_table projection=[a, b] +physical_plan +01)ProjectionExec: expr=[a@0 + b@1 as sum1] +02)--RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +03)----SortExec: TopK(fetch=1), expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false] +04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b], has_header=true + + +# Test: inputs into union with different orderings +query TT +explain select * from (select b, c, a, NULL::int as a0 from ordered_table order by a, c) t1 +union all +select * from (select b, c, NULL::int as a, a0 from ordered_table order by a0, c) t2 +order by d, c, a, a0, b +limit 2; +---- +logical_plan +01)Projection: t1.b, t1.c, t1.a, t1.a0 +02)--Sort: t1.d ASC NULLS LAST, t1.c ASC NULLS LAST, t1.a ASC NULLS LAST, t1.a0 ASC NULLS LAST, t1.b ASC NULLS LAST, fetch=2 +03)----Union +04)------SubqueryAlias: t1 +05)--------Projection: ordered_table.b, ordered_table.c, ordered_table.a, Int32(NULL) AS a0, ordered_table.d +06)----------TableScan: ordered_table projection=[a, b, c, d] +07)------SubqueryAlias: t2 +08)--------Projection: ordered_table.b, ordered_table.c, Int32(NULL) AS a, ordered_table.a0, ordered_table.d +09)----------TableScan: ordered_table projection=[a0, b, c, d] +physical_plan +01)ProjectionExec: expr=[b@0 as b, c@1 as c, a@2 as a, a0@3 as a0] +02)--SortPreservingMergeExec: [d@4 ASC NULLS LAST, c@1 ASC NULLS LAST, a@2 ASC NULLS LAST, a0@3 ASC NULLS LAST, b@0 ASC NULLS LAST], fetch=2 +03)----UnionExec +04)------SortExec: TopK(fetch=2), expr=[d@4 ASC NULLS LAST, c@1 ASC NULLS LAST, a@2 ASC NULLS LAST, b@0 ASC NULLS LAST], preserve_partitioning=[false] +05)--------ProjectionExec: expr=[b@1 as b, c@2 as c, a@0 as a, NULL as a0, d@3 as d] +06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_ordering=[c@2 ASC NULLS LAST], has_header=true +07)------SortExec: TopK(fetch=2), expr=[d@4 ASC NULLS LAST, c@1 ASC NULLS LAST, a0@3 ASC NULLS LAST, b@0 ASC NULLS LAST], preserve_partitioning=[false] +08)--------ProjectionExec: expr=[b@1 as b, c@2 as c, NULL as a, a0@0 as a0, d@3 as d] +09)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, b, c, d], output_ordering=[c@2 ASC NULLS LAST], has_header=true + +# Test: run the query from above +query IIII +select * from (select b, c, a, NULL::int as a0 from ordered_table order by a, c) t1 +union all +select * from (select b, c, NULL::int as a, a0 from ordered_table order by a0, c) t2 +order by d, c, a, a0, b +limit 2; +---- +0 0 0 NULL +0 0 NULL 1 + + +statement ok +drop table ordered_table; + +query TT +EXPLAIN SELECT + CASE + WHEN name = 'name1' THEN 0.0 + WHEN name = 'name2' THEN 0.5 + END AS a +FROM ( + SELECT 'name1' AS name + UNION ALL + SELECT 'name2' +) +ORDER BY a DESC; +---- +logical_plan +01)Sort: a DESC NULLS FIRST +02)--Projection: CASE WHEN name = Utf8("name1") THEN Float64(0) WHEN name = Utf8("name2") THEN Float64(0.5) END AS a +03)----Union +04)------Projection: Utf8("name1") AS name +05)--------EmptyRelation +06)------Projection: Utf8("name2") AS name +07)--------EmptyRelation +physical_plan +01)SortPreservingMergeExec: [a@0 DESC] +02)--ProjectionExec: expr=[CASE WHEN name@0 = name1 THEN 0 WHEN name@0 = name2 THEN 0.5 END as a] +03)----UnionExec +04)------ProjectionExec: expr=[name1 as name] +05)--------PlaceholderRowExec +06)------ProjectionExec: expr=[name2 as name] +07)--------PlaceholderRowExec + +query R +SELECT + CASE + WHEN name = 'name1' THEN 0.0 + WHEN name = 'name2' THEN 0.5 + END AS a +FROM ( + SELECT 'name1' AS name + UNION ALL + SELECT 'name2' +) +ORDER BY a DESC; +---- +0.5 +0 diff --git a/datafusion/sqllogictest/test_files/parquet.slt b/datafusion/sqllogictest/test_files/parquet.slt index 1b25406d8172..65341e67be87 100644 --- a/datafusion/sqllogictest/test_files/parquet.slt +++ b/datafusion/sqllogictest/test_files/parquet.slt @@ -42,7 +42,7 @@ CREATE TABLE src_table ( # Setup 2 files, i.e., as many as there are partitions: # File 1: -query ITID +query I COPY (SELECT * FROM src_table LIMIT 3) TO 'test_files/scratch/parquet/test_table/0.parquet' STORED AS PARQUET; @@ -50,7 +50,7 @@ STORED AS PARQUET; 3 # File 2: -query ITID +query I COPY (SELECT * FROM src_table WHERE int_col > 3 LIMIT 3) TO 'test_files/scratch/parquet/test_table/1.parquet' STORED AS PARQUET; @@ -66,7 +66,6 @@ CREATE EXTERNAL TABLE test_table ( date_col DATE ) STORED AS PARQUET -WITH HEADER ROW LOCATION 'test_files/scratch/parquet/test_table'; # Basic query: @@ -90,8 +89,8 @@ logical_plan 01)Sort: test_table.string_col ASC NULLS LAST, test_table.int_col ASC NULLS LAST 02)--TableScan: test_table projection=[int_col, string_col] physical_plan -01)SortPreservingMergeExec: [string_col@1 ASC NULLS LAST,int_col@0 ASC NULLS LAST] -02)--SortExec: expr=[string_col@1 ASC NULLS LAST,int_col@0 ASC NULLS LAST] +01)SortPreservingMergeExec: [string_col@1 ASC NULLS LAST, int_col@0 ASC NULLS LAST] +02)--SortExec: expr=[string_col@1 ASC NULLS LAST, int_col@0 ASC NULLS LAST], preserve_partitioning=[true] 03)----ParquetExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/0.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/1.parquet]]}, projection=[int_col, string_col] # Tear down test_table: @@ -107,7 +106,6 @@ CREATE EXTERNAL TABLE test_table ( date_col DATE ) STORED AS PARQUET -WITH HEADER ROW WITH ORDER (string_col ASC NULLS LAST, int_col ASC NULLS LAST) LOCATION 'test_files/scratch/parquet/test_table'; @@ -121,11 +119,11 @@ logical_plan 01)Sort: test_table.string_col ASC NULLS LAST, test_table.int_col ASC NULLS LAST 02)--TableScan: test_table projection=[int_col, string_col] physical_plan -01)SortPreservingMergeExec: [string_col@1 ASC NULLS LAST,int_col@0 ASC NULLS LAST] +01)SortPreservingMergeExec: [string_col@1 ASC NULLS LAST, int_col@0 ASC NULLS LAST] 02)--ParquetExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/0.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/1.parquet]]}, projection=[int_col, string_col], output_ordering=[string_col@1 ASC NULLS LAST, int_col@0 ASC NULLS LAST] # Add another file to the directory underlying test_table -query ITID +query I COPY (SELECT * FROM src_table WHERE int_col > 6 LIMIT 3) TO 'test_files/scratch/parquet/test_table/2.parquet' STORED AS PARQUET; @@ -143,8 +141,8 @@ logical_plan 01)Sort: test_table.string_col ASC NULLS LAST, test_table.int_col ASC NULLS LAST 02)--TableScan: test_table projection=[int_col, string_col] physical_plan -01)SortPreservingMergeExec: [string_col@1 ASC NULLS LAST,int_col@0 ASC NULLS LAST] -02)--SortExec: expr=[string_col@1 ASC NULLS LAST,int_col@0 ASC NULLS LAST] +01)SortPreservingMergeExec: [string_col@1 ASC NULLS LAST, int_col@0 ASC NULLS LAST] +02)--SortExec: expr=[string_col@1 ASC NULLS LAST, int_col@0 ASC NULLS LAST], preserve_partitioning=[true] 03)----ParquetExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/0.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/test_table/2.parquet]]}, projection=[int_col, string_col] @@ -189,8 +187,7 @@ CREATE EXTERNAL TABLE alltypes_plain ( timestamp_col TIMESTAMP NOT NULL, ) STORED AS PARQUET -WITH HEADER ROW -LOCATION '../../parquet-testing/data/alltypes_plain.parquet' +LOCATION '../../parquet-testing/data/alltypes_plain.parquet'; # Test a basic query with a CAST: query IT @@ -205,6 +202,11 @@ SELECT id, CAST(string_col AS varchar) FROM alltypes_plain 0 0 1 1 +# Ensure that local files can not be read by default (a potential security issue) +# (url table is only supported when DynamicFileCatalog is enabled) +statement error DataFusion error: Error during planning: table 'datafusion.public.../../parquet-testing/data/alltypes_plain.parquet' not found +SELECT id, CAST(string_col AS varchar) FROM '../../parquet-testing/data/alltypes_plain.parquet'; + # Clean up statement ok DROP TABLE alltypes_plain; @@ -214,7 +216,6 @@ DROP TABLE alltypes_plain; statement ok CREATE EXTERNAL TABLE test_binary STORED AS PARQUET -WITH HEADER ROW LOCATION '../core/tests/data/test_binary.parquet'; # Check size of table: @@ -247,7 +248,6 @@ DROP TABLE test_binary; statement ok CREATE EXTERNAL TABLE timestamp_with_tz STORED AS PARQUET -WITH HEADER ROW LOCATION '../core/tests/data/timestamp_with_tz.parquet'; # Check size of table: @@ -256,29 +256,25 @@ SELECT COUNT(*) FROM timestamp_with_tz; ---- 131072 -# Perform the query: -query IPT -SELECT - count, - LAG(timestamp, 1) OVER (ORDER BY timestamp), - arrow_typeof(LAG(timestamp, 1) OVER (ORDER BY timestamp)) -FROM timestamp_with_tz -LIMIT 10; ----- -0 NULL Timestamp(Millisecond, Some("UTC")) -0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) -0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) -4 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) -0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) -0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) -0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) -14 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) -0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) -0 2014-08-27T14:00:00Z Timestamp(Millisecond, Some("UTC")) +# Ensure that output timestamp columns preserve the timezone from the input +# and output record count will match input file record count +query TPI +SELECT arrow_typeof(lag_timestamp), + MIN(lag_timestamp), + COUNT(1) +FROM ( + SELECT + count, + LAG(timestamp, 1) OVER (ORDER BY timestamp) AS lag_timestamp + FROM timestamp_with_tz +) t +GROUP BY 1 +---- +Timestamp(Millisecond, Some("UTC")) 2014-08-27T14:00:00Z 131072 # Test config listing_table_ignore_subdirectory: -query ITID +query I COPY (SELECT * FROM src_table WHERE int_col > 6 LIMIT 3) TO 'test_files/scratch/parquet/test_table/subdir/3.parquet' STORED AS PARQUET; @@ -288,7 +284,6 @@ STORED AS PARQUET; statement ok CREATE EXTERNAL TABLE listing_table STORED AS PARQUET -WITH HEADER ROW LOCATION 'test_files/scratch/parquet/test_table/*.parquet'; statement ok @@ -317,7 +312,6 @@ DROP TABLE timestamp_with_tz; statement ok CREATE EXTERNAL TABLE single_nan STORED AS PARQUET -WITH HEADER ROW LOCATION '../../parquet-testing/data/single_nan.parquet'; # Check table size: @@ -339,7 +333,6 @@ DROP TABLE single_nan; statement ok CREATE EXTERNAL TABLE list_columns STORED AS PARQUET -WITH HEADER ROW LOCATION '../../parquet-testing/data/list_columns.parquet'; query ?? @@ -355,3 +348,253 @@ DROP TABLE list_columns; # Clean up statement ok DROP TABLE listing_table; + +### Tests for binary_ar_string + +# This scenario models the case where a column has been stored in parquet +# "binary" column (without a String logical type annotation) +# this is the case with the `hits_partitioned` ClickBench datasets +# see https://github.com/apache/datafusion/issues/12788 + +## Create a table with a binary column + +query I +COPY ( + SELECT + arrow_cast(string_col, 'Binary') as binary_col, + arrow_cast(string_col, 'LargeBinary') as largebinary_col, + arrow_cast(string_col, 'BinaryView') as binaryview_col + FROM src_table + ) +TO 'test_files/scratch/parquet/binary_as_string.parquet' +STORED AS PARQUET; +---- +9 + +# Test 1: Read table with default options +statement ok +CREATE EXTERNAL TABLE binary_as_string_default +STORED AS PARQUET LOCATION 'test_files/scratch/parquet/binary_as_string.parquet' + +# NB the data is read and displayed as binary +query T?T?T? +select + arrow_typeof(binary_col), binary_col, + arrow_typeof(largebinary_col), largebinary_col, + arrow_typeof(binaryview_col), binaryview_col + FROM binary_as_string_default; +---- +Binary 616161 Binary 616161 Binary 616161 +Binary 626262 Binary 626262 Binary 626262 +Binary 636363 Binary 636363 Binary 636363 +Binary 646464 Binary 646464 Binary 646464 +Binary 656565 Binary 656565 Binary 656565 +Binary 666666 Binary 666666 Binary 666666 +Binary 676767 Binary 676767 Binary 676767 +Binary 686868 Binary 686868 Binary 686868 +Binary 696969 Binary 696969 Binary 696969 + +# Run an explain plan to show the cast happens in the plan (a CAST is needed for the predicates) +query TT +EXPLAIN + SELECT binary_col, largebinary_col, binaryview_col + FROM binary_as_string_default + WHERE + binary_col LIKE '%a%' AND + largebinary_col LIKE '%a%' AND + binaryview_col LIKE '%a%'; +---- +logical_plan +01)Filter: CAST(binary_as_string_default.binary_col AS Utf8) LIKE Utf8("%a%") AND CAST(binary_as_string_default.largebinary_col AS Utf8) LIKE Utf8("%a%") AND CAST(binary_as_string_default.binaryview_col AS Utf8) LIKE Utf8("%a%") +02)--TableScan: binary_as_string_default projection=[binary_col, largebinary_col, binaryview_col], partial_filters=[CAST(binary_as_string_default.binary_col AS Utf8) LIKE Utf8("%a%"), CAST(binary_as_string_default.largebinary_col AS Utf8) LIKE Utf8("%a%"), CAST(binary_as_string_default.binaryview_col AS Utf8) LIKE Utf8("%a%")] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: CAST(binary_col@0 AS Utf8) LIKE %a% AND CAST(largebinary_col@1 AS Utf8) LIKE %a% AND CAST(binaryview_col@2 AS Utf8) LIKE %a% +03)----RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +04)------ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/binary_as_string.parquet]]}, projection=[binary_col, largebinary_col, binaryview_col], predicate=CAST(binary_col@0 AS Utf8) LIKE %a% AND CAST(largebinary_col@1 AS Utf8) LIKE %a% AND CAST(binaryview_col@2 AS Utf8) LIKE %a% + + +statement ok +DROP TABLE binary_as_string_default; + +## Test 2: Read table using the binary_as_string option + +statement ok +CREATE EXTERNAL TABLE binary_as_string_option +STORED AS PARQUET LOCATION 'test_files/scratch/parquet/binary_as_string.parquet' +OPTIONS ('binary_as_string' 'true'); + +# NB the data is read and displayed as string +query TTTTTT +select + arrow_typeof(binary_col), binary_col, + arrow_typeof(largebinary_col), largebinary_col, + arrow_typeof(binaryview_col), binaryview_col + FROM binary_as_string_option; +---- +Utf8 aaa Utf8 aaa Utf8 aaa +Utf8 bbb Utf8 bbb Utf8 bbb +Utf8 ccc Utf8 ccc Utf8 ccc +Utf8 ddd Utf8 ddd Utf8 ddd +Utf8 eee Utf8 eee Utf8 eee +Utf8 fff Utf8 fff Utf8 fff +Utf8 ggg Utf8 ggg Utf8 ggg +Utf8 hhh Utf8 hhh Utf8 hhh +Utf8 iii Utf8 iii Utf8 iii + +# Run an explain plan to show the cast happens in the plan (there should be no casts) +query TT +EXPLAIN + SELECT binary_col, largebinary_col, binaryview_col + FROM binary_as_string_option + WHERE + binary_col LIKE '%a%' AND + largebinary_col LIKE '%a%' AND + binaryview_col LIKE '%a%'; +---- +logical_plan +01)Filter: binary_as_string_option.binary_col LIKE Utf8("%a%") AND binary_as_string_option.largebinary_col LIKE Utf8("%a%") AND binary_as_string_option.binaryview_col LIKE Utf8("%a%") +02)--TableScan: binary_as_string_option projection=[binary_col, largebinary_col, binaryview_col], partial_filters=[binary_as_string_option.binary_col LIKE Utf8("%a%"), binary_as_string_option.largebinary_col LIKE Utf8("%a%"), binary_as_string_option.binaryview_col LIKE Utf8("%a%")] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: binary_col@0 LIKE %a% AND largebinary_col@1 LIKE %a% AND binaryview_col@2 LIKE %a% +03)----RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +04)------ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/binary_as_string.parquet]]}, projection=[binary_col, largebinary_col, binaryview_col], predicate=binary_col@0 LIKE %a% AND largebinary_col@1 LIKE %a% AND binaryview_col@2 LIKE %a% + + +statement ok +DROP TABLE binary_as_string_option; + +## Test 3: Read table with binary_as_string option AND schema_force_view_types + +statement ok +CREATE EXTERNAL TABLE binary_as_string_both +STORED AS PARQUET LOCATION 'test_files/scratch/parquet/binary_as_string.parquet' +OPTIONS ( + 'binary_as_string' 'true', + 'schema_force_view_types' 'true' +); + +# NB the data is read and displayed a StringView +query TTTTTT +select + arrow_typeof(binary_col), binary_col, + arrow_typeof(largebinary_col), largebinary_col, + arrow_typeof(binaryview_col), binaryview_col + FROM binary_as_string_both; +---- +Utf8View aaa Utf8View aaa Utf8View aaa +Utf8View bbb Utf8View bbb Utf8View bbb +Utf8View ccc Utf8View ccc Utf8View ccc +Utf8View ddd Utf8View ddd Utf8View ddd +Utf8View eee Utf8View eee Utf8View eee +Utf8View fff Utf8View fff Utf8View fff +Utf8View ggg Utf8View ggg Utf8View ggg +Utf8View hhh Utf8View hhh Utf8View hhh +Utf8View iii Utf8View iii Utf8View iii + +# Run an explain plan to show the cast happens in the plan (there should be no casts) +query TT +EXPLAIN + SELECT binary_col, largebinary_col, binaryview_col + FROM binary_as_string_both + WHERE + binary_col LIKE '%a%' AND + largebinary_col LIKE '%a%' AND + binaryview_col LIKE '%a%'; +---- +logical_plan +01)Filter: binary_as_string_both.binary_col LIKE Utf8View("%a%") AND binary_as_string_both.largebinary_col LIKE Utf8View("%a%") AND binary_as_string_both.binaryview_col LIKE Utf8View("%a%") +02)--TableScan: binary_as_string_both projection=[binary_col, largebinary_col, binaryview_col], partial_filters=[binary_as_string_both.binary_col LIKE Utf8View("%a%"), binary_as_string_both.largebinary_col LIKE Utf8View("%a%"), binary_as_string_both.binaryview_col LIKE Utf8View("%a%")] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: binary_col@0 LIKE %a% AND largebinary_col@1 LIKE %a% AND binaryview_col@2 LIKE %a% +03)----RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +04)------ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet/binary_as_string.parquet]]}, projection=[binary_col, largebinary_col, binaryview_col], predicate=binary_col@0 LIKE %a% AND largebinary_col@1 LIKE %a% AND binaryview_col@2 LIKE %a% + + +statement ok +drop table binary_as_string_both; + +# Read a parquet file with binary data in a FixedSizeBinary column + +# by default, the data is read as binary +statement ok +CREATE EXTERNAL TABLE test_non_utf8_binary +STORED AS PARQUET LOCATION '../core/tests/data/test_binary.parquet'; + +query T? +SELECT arrow_typeof(ids), ids FROM test_non_utf8_binary LIMIT 3; +---- +FixedSizeBinary(16) 008c7196f68089ab692e4739c5fd16b5 +FixedSizeBinary(16) 00a51a7bc5ff8eb1627f8f3dc959dce8 +FixedSizeBinary(16) 0166ce1d46129ad104fa4990c6057c91 + +statement ok +DROP TABLE test_non_utf8_binary; + + +# even with the binary_as_string option set, the data is read as binary +statement ok +CREATE EXTERNAL TABLE test_non_utf8_binary +STORED AS PARQUET LOCATION '../core/tests/data/test_binary.parquet' +OPTIONS ('binary_as_string' 'true'); + +query T? +SELECT arrow_typeof(ids), ids FROM test_non_utf8_binary LIMIT 3 +---- +FixedSizeBinary(16) 008c7196f68089ab692e4739c5fd16b5 +FixedSizeBinary(16) 00a51a7bc5ff8eb1627f8f3dc959dce8 +FixedSizeBinary(16) 0166ce1d46129ad104fa4990c6057c91 + +statement ok +DROP TABLE test_non_utf8_binary; + + +## Tests for https://github.com/apache/datafusion/issues/13186 +statement ok +create table cpu (time timestamp, usage_idle float, usage_user float, cpu int); + +statement ok +insert into cpu values ('1970-01-01 00:00:00', 1.0, 2.0, 3); + +# must put it into a parquet file to get statistics +statement ok +copy (select * from cpu) to 'test_files/scratch/parquet/cpu.parquet'; + +# Run queries against parquet files +statement ok +create external table cpu_parquet +stored as parquet +location 'test_files/scratch/parquet/cpu.parquet'; + +# Double filtering +# +# Expect 1 row for both queries +query PI +select time, rn +from ( + select time, row_number() OVER (ORDER BY usage_idle, time) as rn + from cpu + where cpu = 3 +) where rn > 0; +---- +1970-01-01T00:00:00 1 + +query PI +select time, rn +from ( + select time, row_number() OVER (ORDER BY usage_idle, time) as rn + from cpu_parquet + where cpu = 3 +) where rn > 0; +---- +1970-01-01T00:00:00 1 + + +# Clean up +statement ok +drop table cpu; + +statement ok +drop table cpu_parquet; diff --git a/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt b/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt new file mode 100644 index 000000000000..24ffb963bbe2 --- /dev/null +++ b/datafusion/sqllogictest/test_files/parquet_filter_pushdown.slt @@ -0,0 +1,158 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +########## +# Tests for parquet filter pushdown (filtering on data in the +# scan not just the metadata) +########## + +# File1 has only columns a and b +statement ok +COPY ( + SELECT column1 as a, column2 as b + FROM ( VALUES ('foo', 1), ('bar', 2), ('foo', 3), ('baz', 50) ) + ) TO 'test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet' +STORED AS PARQUET; + +# File2 has only b +statement ok +COPY ( + SELECT column1 as b + FROM ( VALUES (10), (20), (30) ) + ) TO 'test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet' +STORED AS PARQUET; + + +## Create table without filter pushdown +## (pushdown setting is part of the table, but is copied from the session settings) + +# pushdown_filters (currently) defaults to false, but we set it here to be explicit +statement ok +set datafusion.execution.parquet.pushdown_filters = false; + +statement ok +CREATE EXTERNAL TABLE t(a varchar, b int, c float) STORED AS PARQUET +LOCATION 'test_files/scratch/parquet_filter_pushdown/parquet_table/'; + +## Create table with pushdown enabled (pushdown setting is part of the table) + +statement ok +set datafusion.execution.parquet.pushdown_filters = true; + +## Create table without pushdown +statement ok +CREATE EXTERNAL TABLE t_pushdown(a varchar, b int, c float) STORED AS PARQUET +LOCATION 'test_files/scratch/parquet_filter_pushdown/parquet_table/'; + +# restore defaults +statement ok +set datafusion.execution.parquet.pushdown_filters = false; + +# When filter pushdown is not enabled, ParquetExec only filters based on +# metadata, so a FilterExec is required to filter the +# output of the `ParquetExec` + +query T +select a from t where b > 2 ORDER BY a; +---- +baz +foo +NULL +NULL +NULL + +query TT +EXPLAIN select a from t_pushdown where b > 2 ORDER BY a; +---- +logical_plan +01)Sort: t_pushdown.a ASC NULLS LAST +02)--TableScan: t_pushdown projection=[a], full_filters=[t_pushdown.b > Int32(2)] +physical_plan +01)SortPreservingMergeExec: [a@0 ASC NULLS LAST] +02)--SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----ParquetExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a], predicate=b@1 > 2, pruning_predicate=CASE WHEN b_null_count@1 = b_row_count@2 THEN false ELSE b_max@0 > 2 END, required_guarantees=[] + + +# When filter pushdown *is* enabled, ParquetExec can filter exactly, +# not just metadata, so we expect to see no FilterExec +query T +select a from t_pushdown where b > 2 ORDER BY a; +---- +baz +foo +NULL +NULL +NULL + +query TT +EXPLAIN select a from t where b > 2 ORDER BY a; +---- +logical_plan +01)Sort: t.a ASC NULLS LAST +02)--Projection: t.a +03)----Filter: t.b > Int32(2) +04)------TableScan: t projection=[a, b], partial_filters=[t.b > Int32(2)] +physical_plan +01)SortPreservingMergeExec: [a@0 ASC NULLS LAST] +02)--SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----CoalesceBatchesExec: target_batch_size=8192 +04)------FilterExec: b@1 > 2, projection=[a@0] +05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=2 +06)----------ParquetExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a, b], predicate=b@1 > 2, pruning_predicate=CASE WHEN b_null_count@1 = b_row_count@2 THEN false ELSE b_max@0 > 2 END, required_guarantees=[] + +# also test querying on columns that are not in all the files +query T +select a from t_pushdown where b > 2 AND a IS NOT NULL order by a; +---- +baz +foo + +query TT +EXPLAIN select a from t_pushdown where b > 2 AND a IS NOT NULL order by a; +---- +logical_plan +01)Sort: t_pushdown.a ASC NULLS LAST +02)--TableScan: t_pushdown projection=[a], full_filters=[t_pushdown.b > Int32(2), t_pushdown.a IS NOT NULL] +physical_plan +01)SortPreservingMergeExec: [a@0 ASC NULLS LAST] +02)--SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----ParquetExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[a], predicate=b@1 > 2 AND a@0 IS NOT NULL, pruning_predicate=CASE WHEN b_null_count@1 = b_row_count@2 THEN false ELSE b_max@0 > 2 END AND a_null_count@4 != a_row_count@3, required_guarantees=[] + + +query I +select b from t_pushdown where a = 'bar' order by b; +---- +2 + +query TT +EXPLAIN select b from t_pushdown where a = 'bar' order by b; +---- +logical_plan +01)Sort: t_pushdown.b ASC NULLS LAST +02)--TableScan: t_pushdown projection=[b], full_filters=[t_pushdown.a = Utf8("bar")] +physical_plan +01)SortPreservingMergeExec: [b@0 ASC NULLS LAST] +02)--SortExec: expr=[b@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----ParquetExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_filter_pushdown/parquet_table/2.parquet]]}, projection=[b], predicate=a@0 = bar, pruning_predicate=CASE WHEN a_null_count@2 = a_row_count@3 THEN false ELSE a_min@0 <= bar AND bar <= a_max@1 END, required_guarantees=[a in (bar)] + +## cleanup +statement ok +DROP TABLE t; + +statement ok +DROP TABLE t_pushdown; diff --git a/datafusion/sqllogictest/test_files/parquet_sorted_statistics.slt b/datafusion/sqllogictest/test_files/parquet_sorted_statistics.slt new file mode 100644 index 000000000000..b68d4f52d21c --- /dev/null +++ b/datafusion/sqllogictest/test_files/parquet_sorted_statistics.slt @@ -0,0 +1,262 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# TESTS FOR SORTED PARQUET FILES + +# Set 2 partitions for deterministic output plans +statement ok +set datafusion.execution.target_partitions = 2; + +# Collect statistics -- used for sorting files +statement ok +set datafusion.execution.collect_statistics = true; + +# Enable split_file_groups_by_statistics since it's currently disabled by default +statement ok +set datafusion.execution.split_file_groups_by_statistics = true; + +# Create a table as a data source +statement ok +CREATE TABLE src_table ( + int_col INT, + descending_col INT, + string_col TEXT, + bigint_col BIGINT, + date_col DATE, + overlapping_col INT, + constant_col INT +) AS VALUES +-- first file +(1, 3, 'aaa', 100, 1, 0, 0), +(2, 2, 'bbb', 200, 2, 1, 0), +(3, 1, 'ccc', 300, 3, 2, 0), +-- second file +(4, 6, 'ddd', 400, 4, 0, 0), +(5, 5, 'eee', 500, 5, 1, 0), +(6, 4, 'fff', 600, 6, 2, 0), +-- third file +(7, 9, 'ggg', 700, 7, 3, 0), +(8, 8, 'hhh', 800, 8, 4, 0), +(9, 7, 'iii', 900, 9, 5, 0); + +# Setup 3 files, in particular more files than there are partitions + +# File 1: +query I +COPY (SELECT * FROM src_table ORDER BY int_col LIMIT 3) +TO 'test_files/scratch/parquet_sorted_statistics/test_table/partition_col=A/0.parquet' +STORED AS PARQUET; +---- +3 + +# File 2: +query I +COPY (SELECT * FROM src_table WHERE int_col > 3 ORDER BY int_col LIMIT 3) +TO 'test_files/scratch/parquet_sorted_statistics/test_table/partition_col=B/1.parquet' +STORED AS PARQUET; +---- +3 + +# Add another file to the directory underlying test_table +query I +COPY (SELECT * FROM src_table WHERE int_col > 6 ORDER BY int_col LIMIT 3) +TO 'test_files/scratch/parquet_sorted_statistics/test_table/partition_col=C/2.parquet' +STORED AS PARQUET; +---- +3 + + +# Create a table from generated parquet files: +statement ok +CREATE EXTERNAL TABLE test_table ( + partition_col TEXT NOT NULL, + int_col INT NOT NULL, + descending_col INT NOT NULL, + string_col TEXT NOT NULL, + bigint_col BIGINT NOT NULL, + date_col DATE NOT NULL, + overlapping_col INT NOT NULL, + constant_col INT NOT NULL +) +STORED AS PARQUET +PARTITIONED BY (partition_col) +WITH ORDER (int_col ASC NULLS LAST, bigint_col ASC NULLS LAST) +LOCATION 'test_files/scratch/parquet_sorted_statistics/test_table'; + +# Order by numeric columns +# This is to exercise file group sorting, which uses file-level statistics +# DataFusion doesn't currently support string column statistics +# This should not require a sort. +query TT +EXPLAIN SELECT int_col, bigint_col +FROM test_table +ORDER BY int_col, bigint_col; +---- +logical_plan +01)Sort: test_table.int_col ASC NULLS LAST, test_table.bigint_col ASC NULLS LAST +02)--TableScan: test_table projection=[int_col, bigint_col] +physical_plan ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=A/0.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=B/1.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=C/2.parquet]]}, projection=[int_col, bigint_col], output_ordering=[int_col@0 ASC NULLS LAST, bigint_col@1 ASC NULLS LAST] + +# Another planning test, but project on a column with unsupported statistics +# We should be able to ignore this and look at only the relevant statistics +query TT +EXPLAIN SELECT string_col +FROM test_table +ORDER BY int_col, bigint_col; +---- +logical_plan +01)Projection: test_table.string_col +02)--Sort: test_table.int_col ASC NULLS LAST, test_table.bigint_col ASC NULLS LAST +03)----Projection: test_table.string_col, test_table.int_col, test_table.bigint_col +04)------TableScan: test_table projection=[int_col, string_col, bigint_col] +physical_plan +01)ProjectionExec: expr=[string_col@1 as string_col] +02)--ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=A/0.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=B/1.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=C/2.parquet]]}, projection=[int_col, string_col, bigint_col], output_ordering=[int_col@0 ASC NULLS LAST, bigint_col@2 ASC NULLS LAST] + +# Clean up & recreate but sort on descending column +statement ok +DROP TABLE test_table; + +statement ok +CREATE EXTERNAL TABLE test_table ( + partition_col TEXT NOT NULL, + int_col INT NOT NULL, + descending_col INT NOT NULL, + string_col TEXT NOT NULL, + bigint_col BIGINT NOT NULL, + date_col DATE NOT NULL, + overlapping_col INT NOT NULL, + constant_col INT NOT NULL +) +STORED AS PARQUET +PARTITIONED BY (partition_col) +WITH ORDER (descending_col DESC NULLS LAST, bigint_col ASC NULLS LAST) +LOCATION 'test_files/scratch/parquet_sorted_statistics/test_table'; + +# Query order by descending_col +# This should order the files like [C, B, A] +query TT +EXPLAIN SELECT descending_col, bigint_col +FROM test_table +ORDER BY descending_col DESC NULLS LAST, bigint_col ASC NULLS LAST; +---- +logical_plan +01)Sort: test_table.descending_col DESC NULLS LAST, test_table.bigint_col ASC NULLS LAST +02)--TableScan: test_table projection=[descending_col, bigint_col] +physical_plan ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=C/2.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=B/1.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=A/0.parquet]]}, projection=[descending_col, bigint_col], output_ordering=[descending_col@0 DESC NULLS LAST, bigint_col@1 ASC NULLS LAST] + +# Clean up & re-create with partition columns in sort order +statement ok +DROP TABLE test_table; + +statement ok +CREATE EXTERNAL TABLE test_table ( + partition_col TEXT NOT NULL, + int_col INT NOT NULL, + descending_col INT NOT NULL, + string_col TEXT NOT NULL, + bigint_col BIGINT NOT NULL, + date_col DATE NOT NULL, + overlapping_col INT NOT NULL, + constant_col INT NOT NULL +) +STORED AS PARQUET +PARTITIONED BY (partition_col) +WITH ORDER (partition_col ASC NULLS LAST, int_col ASC NULLS LAST, bigint_col ASC NULLS LAST) +LOCATION 'test_files/scratch/parquet_sorted_statistics/test_table'; + +# Order with partition column first +# In particular, the partition column is a string +# Even though statistics for string columns are not supported, +# string partition columns are common and we do support sorting file groups on them +query TT +EXPLAIN SELECT int_col, bigint_col, partition_col +FROM test_table +ORDER BY partition_col, int_col, bigint_col; +---- +logical_plan +01)Sort: test_table.partition_col ASC NULLS LAST, test_table.int_col ASC NULLS LAST, test_table.bigint_col ASC NULLS LAST +02)--TableScan: test_table projection=[int_col, bigint_col, partition_col] +physical_plan ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=A/0.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=B/1.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=C/2.parquet]]}, projection=[int_col, bigint_col, partition_col], output_ordering=[partition_col@2 ASC NULLS LAST, int_col@0 ASC NULLS LAST, bigint_col@1 ASC NULLS LAST] + +# Clean up & re-create with overlapping column in sort order +# This will test the ability to sort files with overlapping statistics +statement ok +DROP TABLE test_table; + +statement ok +CREATE EXTERNAL TABLE test_table ( + partition_col TEXT NOT NULL, + int_col INT NOT NULL, + descending_col INT NOT NULL, + string_col TEXT NOT NULL, + bigint_col BIGINT NOT NULL, + date_col DATE NOT NULL, + overlapping_col INT NOT NULL, + constant_col INT NOT NULL +) +STORED AS PARQUET +PARTITIONED BY (partition_col) +WITH ORDER (overlapping_col ASC NULLS LAST) +LOCATION 'test_files/scratch/parquet_sorted_statistics/test_table'; + +query TT +EXPLAIN SELECT int_col, bigint_col, overlapping_col +FROM test_table +ORDER BY overlapping_col; +---- +logical_plan +01)Sort: test_table.overlapping_col ASC NULLS LAST +02)--TableScan: test_table projection=[int_col, bigint_col, overlapping_col] +physical_plan +01)SortPreservingMergeExec: [overlapping_col@2 ASC NULLS LAST] +02)--ParquetExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=A/0.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=C/2.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=B/1.parquet]]}, projection=[int_col, bigint_col, overlapping_col], output_ordering=[overlapping_col@2 ASC NULLS LAST] + +# Clean up & re-create with constant column in sort order +# This will require a sort because the # of required file groups (3) +# exceeds the # of target partitions (2) +statement ok +DROP TABLE test_table; + +statement ok +CREATE EXTERNAL TABLE test_table ( + partition_col TEXT NOT NULL, + int_col INT NOT NULL, + descending_col INT NOT NULL, + string_col TEXT NOT NULL, + bigint_col BIGINT NOT NULL, + date_col DATE NOT NULL, + overlapping_col INT NOT NULL, + constant_col INT NOT NULL +) +STORED AS PARQUET +PARTITIONED BY (partition_col) +WITH ORDER (constant_col ASC NULLS LAST) +LOCATION 'test_files/scratch/parquet_sorted_statistics/test_table'; + +query TT +EXPLAIN SELECT constant_col +FROM test_table +ORDER BY constant_col; +---- +logical_plan +01)Sort: test_table.constant_col ASC NULLS LAST +02)--TableScan: test_table projection=[constant_col] +physical_plan +01)SortPreservingMergeExec: [constant_col@0 ASC NULLS LAST] +02)--SortExec: expr=[constant_col@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----ParquetExec: file_groups={2 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=A/0.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=B/1.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/parquet_sorted_statistics/test_table/partition_col=C/2.parquet]]}, projection=[constant_col] diff --git a/datafusion/sqllogictest/test_files/pg_compat/pg_compat_null.slt b/datafusion/sqllogictest/test_files/pg_compat/pg_compat_null.slt index f933c90acc73..d14b6ca81f67 100644 --- a/datafusion/sqllogictest/test_files/pg_compat/pg_compat_null.slt +++ b/datafusion/sqllogictest/test_files/pg_compat/pg_compat_null.slt @@ -66,8 +66,8 @@ CREATE EXTERNAL TABLE aggregate_test_100_by_sql ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); statement ok diff --git a/datafusion/sqllogictest/test_files/pg_compat/pg_compat_simple.slt b/datafusion/sqllogictest/test_files/pg_compat/pg_compat_simple.slt index b01ea73c8056..25b4924715ca 100644 --- a/datafusion/sqllogictest/test_files/pg_compat/pg_compat_simple.slt +++ b/datafusion/sqllogictest/test_files/pg_compat/pg_compat_simple.slt @@ -67,8 +67,8 @@ CREATE EXTERNAL TABLE aggregate_test_100_by_sql ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); diff --git a/datafusion/sqllogictest/test_files/pg_compat/pg_compat_types.slt b/datafusion/sqllogictest/test_files/pg_compat/pg_compat_types.slt index b7497429fa95..7e315a448b48 100644 --- a/datafusion/sqllogictest/test_files/pg_compat/pg_compat_types.slt +++ b/datafusion/sqllogictest/test_files/pg_compat/pg_compat_types.slt @@ -24,18 +24,18 @@ NULL query TT select 'a'::VARCHAR, ''::VARCHAR ---- -a (empty) +a (empty) skipif postgres query TT select 'a'::CHAR, ''::CHAR ---- -a (empty) +a (empty) query TT select 'a'::TEXT, ''::TEXT ---- -a (empty) +a (empty) skipif postgres query I diff --git a/datafusion/sqllogictest/test_files/pg_compat/pg_compat_union.slt b/datafusion/sqllogictest/test_files/pg_compat/pg_compat_union.slt index 05343de32268..e02c19016790 100644 --- a/datafusion/sqllogictest/test_files/pg_compat/pg_compat_union.slt +++ b/datafusion/sqllogictest/test_files/pg_compat/pg_compat_union.slt @@ -64,8 +64,8 @@ CREATE EXTERNAL TABLE aggregate_test_100_by_sql ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); query I rowsort SELECT * FROM ( diff --git a/datafusion/sqllogictest/test_files/pg_compat/pg_compat_window.slt b/datafusion/sqllogictest/test_files/pg_compat/pg_compat_window.slt index cec51d472075..edad3747a203 100644 --- a/datafusion/sqllogictest/test_files/pg_compat/pg_compat_window.slt +++ b/datafusion/sqllogictest/test_files/pg_compat/pg_compat_window.slt @@ -64,8 +64,8 @@ CREATE EXTERNAL TABLE aggregate_test_100_by_sql ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); query IIIIIIIIII SELECT diff --git a/datafusion/sqllogictest/test_files/predicates.slt b/datafusion/sqllogictest/test_files/predicates.slt index abb36c3c0858..878d7c8a4dfb 100644 --- a/datafusion/sqllogictest/test_files/predicates.slt +++ b/datafusion/sqllogictest/test_files/predicates.slt @@ -19,6 +19,9 @@ ## Predicates Tests ########## +statement ok +set datafusion.catalog.information_schema = true; + statement ok CREATE EXTERNAL TABLE aggregate_test_100 ( c1 VARCHAR NOT NULL, @@ -36,8 +39,8 @@ CREATE EXTERNAL TABLE aggregate_test_100 ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW -LOCATION '../../testing/data/csv/aggregate_test_100.csv' +LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); statement ok CREATE EXTERNAL TABLE alltypes_plain STORED AS PARQUET LOCATION '../../parquet-testing/data/alltypes_plain.parquet'; @@ -514,8 +517,38 @@ DROP TABLE t; statement ok CREATE EXTERNAL TABLE data_index_bloom_encoding_stats STORED AS PARQUET LOCATION '../../parquet-testing/data/data_index_bloom_encoding_stats.parquet'; +query TT +SHOW datafusion.execution.parquet.bloom_filter_on_read +---- +datafusion.execution.parquet.bloom_filter_on_read true + +query T +SELECT * FROM data_index_bloom_encoding_stats WHERE "String" = 'foo'; +---- + +query T +SELECT * FROM data_index_bloom_encoding_stats WHERE "String" = 'test'; +---- +test + +query T +SELECT * FROM data_index_bloom_encoding_stats WHERE "String" like '%e%'; +---- +Hello +test +are you +the quick +over +the lazy + + +######## +# Test query without bloom filter +# Refer to https://github.com/apache/datafusion/pull/7821#pullrequestreview-1688062599 +######## + statement ok -set datafusion.execution.parquet.bloom_filter_enabled=true; +set datafusion.execution.parquet.bloom_filter_on_read=false; query T SELECT * FROM data_index_bloom_encoding_stats WHERE "String" = 'foo'; @@ -537,7 +570,7 @@ over the lazy statement ok -set datafusion.execution.parquet.bloom_filter_enabled=false; +set datafusion.execution.parquet.bloom_filter_on_read=true; ######## @@ -551,7 +584,7 @@ DROP TABLE data_index_bloom_encoding_stats; # String coercion ######## -statement error DataFusion error: SQL error: ParserError\("Expected a data type name, found: ,"\) +statement error DataFusion error: SQL error: ParserError\("Expected: a data type name, found: ,"\) CREATE TABLE t(vendor_id_utf8, vendor_id_dict) AS VALUES (arrow_cast('124', 'Utf8'), arrow_cast('124', 'Dictionary(Int16, Utf8)')), @@ -588,7 +621,7 @@ CREATE EXTERNAL TABLE IF NOT EXISTS lineitem ( l_shipinstruct VARCHAR, l_shipmode VARCHAR, l_comment VARCHAR, -) STORED AS CSV DELIMITER ',' WITH HEADER ROW LOCATION '../core/tests/tpch-csv/lineitem.csv'; +) STORED AS CSV LOCATION '../core/tests/tpch-csv/lineitem.csv' OPTIONS ('format.delimiter' ',', 'format.has_header' 'true'); statement ok CREATE EXTERNAL TABLE IF NOT EXISTS part ( @@ -601,7 +634,7 @@ CREATE EXTERNAL TABLE IF NOT EXISTS part ( p_container VARCHAR, p_retailprice DECIMAL(15, 2), p_comment VARCHAR, -) STORED AS CSV DELIMITER ',' WITH HEADER ROW LOCATION '../core/tests/tpch-csv/part.csv'; +) STORED AS CSV LOCATION '../core/tests/tpch-csv/part.csv' OPTIONS ('format.delimiter' ',', 'format.has_header' 'true'); query TT EXPLAIN SELECT l_partkey FROM @@ -659,7 +692,7 @@ CREATE TABLE IF NOT EXISTS partsupp ( ps_suppkey BIGINT, ps_availqty INTEGER, ps_supplycost DECIMAL(15, 2), - ps_comment VARCHAR, + ps_comment VARCHAR ) AS VALUES (63700, 7311, 100, 993.49, 'ven ideas. quickly even packages print. pending multipliers must have to are fluff'); @@ -715,7 +748,7 @@ OR GROUP BY p_partkey; ---- logical_plan -01)Aggregate: groupBy=[[part.p_partkey]], aggr=[[SUM(lineitem.l_extendedprice), AVG(lineitem.l_discount), COUNT(DISTINCT partsupp.ps_suppkey)]] +01)Aggregate: groupBy=[[part.p_partkey]], aggr=[[sum(lineitem.l_extendedprice), avg(lineitem.l_discount), count(DISTINCT partsupp.ps_suppkey)]] 02)--Projection: lineitem.l_extendedprice, lineitem.l_discount, part.p_partkey, partsupp.ps_suppkey 03)----Inner Join: part.p_partkey = partsupp.ps_partkey 04)------Projection: lineitem.l_extendedprice, lineitem.l_discount, part.p_partkey @@ -726,7 +759,7 @@ logical_plan 09)--------------TableScan: part projection=[p_partkey, p_brand], partial_filters=[part.p_brand = Utf8("Brand#12") OR part.p_brand = Utf8("Brand#23")] 10)------TableScan: partsupp projection=[ps_partkey, ps_suppkey] physical_plan -01)AggregateExec: mode=SinglePartitioned, gby=[p_partkey@2 as p_partkey], aggr=[SUM(lineitem.l_extendedprice), AVG(lineitem.l_discount), COUNT(DISTINCT partsupp.ps_suppkey)] +01)AggregateExec: mode=SinglePartitioned, gby=[p_partkey@2 as p_partkey], aggr=[sum(lineitem.l_extendedprice), avg(lineitem.l_discount), count(DISTINCT partsupp.ps_suppkey)] 02)--CoalesceBatchesExec: target_batch_size=8192 03)----HashJoinExec: mode=Partitioned, join_type=Inner, on=[(p_partkey@2, ps_partkey@0)], projection=[l_extendedprice@0, l_discount@1, p_partkey@2, ps_suppkey@4] 04)------CoalesceBatchesExec: target_batch_size=8192 @@ -737,14 +770,13 @@ physical_plan 09)----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/lineitem.csv]]}, projection=[l_partkey, l_extendedprice, l_discount], has_header=true 10)----------CoalesceBatchesExec: target_batch_size=8192 11)------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 -12)--------------ProjectionExec: expr=[p_partkey@0 as p_partkey] -13)----------------CoalesceBatchesExec: target_batch_size=8192 -14)------------------FilterExec: p_brand@1 = Brand#12 OR p_brand@1 = Brand#23 -15)--------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -16)----------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/part.csv]]}, projection=[p_partkey, p_brand], has_header=true -17)------CoalesceBatchesExec: target_batch_size=8192 -18)--------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=1 -19)----------MemoryExec: partitions=1, partition_sizes=[1] +12)--------------CoalesceBatchesExec: target_batch_size=8192 +13)----------------FilterExec: p_brand@1 = Brand#12 OR p_brand@1 = Brand#23, projection=[p_partkey@0] +14)------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +15)--------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/part.csv]]}, projection=[p_partkey, p_brand], has_header=true +16)------CoalesceBatchesExec: target_batch_size=8192 +17)--------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=1 +18)----------MemoryExec: partitions=1, partition_sizes=[1] # Inlist simplification diff --git a/datafusion/sqllogictest/test_files/prepare.slt b/datafusion/sqllogictest/test_files/prepare.slt index ce4b7217f990..e306ec7767c7 100644 --- a/datafusion/sqllogictest/test_files/prepare.slt +++ b/datafusion/sqllogictest/test_files/prepare.slt @@ -80,3 +80,21 @@ PREPARE my_plan(INT, DOUBLE, DOUBLE, DOUBLE) AS SELECT id, SUM(age) FROM person statement error PREPARE my_plan(STRING, STRING) AS SELECT * FROM (VALUES(1, $1), (2, $2)) AS t (num, letter); + +# test creating logical plan for EXECUTE statements +query TT +EXPLAIN EXECUTE my_plan; +---- +logical_plan Execute: my_plan params=[] + +query TT +EXPLAIN EXECUTE my_plan(10*2 + 1, 'Foo'); +---- +logical_plan Execute: my_plan params=[Int64(21), Utf8("Foo")] + +query error DataFusion error: Schema error: No field named a\. +EXPLAIN EXECUTE my_plan(a); + +# TODO: support EXECUTE queries +query error DataFusion error: This feature is not implemented: Unsupported logical plan: Execute +EXECUTE my_plan; diff --git a/datafusion/sqllogictest/test_files/projection.slt b/datafusion/sqllogictest/test_files/projection.slt index b752f5644b7f..b5bcb5b4c6f7 100644 --- a/datafusion/sqllogictest/test_files/projection.slt +++ b/datafusion/sqllogictest/test_files/projection.slt @@ -37,8 +37,8 @@ CREATE EXTERNAL TABLE aggregate_test_100 ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); statement ok CREATE EXTERNAL TABLE aggregate_simple ( @@ -47,8 +47,8 @@ CREATE EXTERNAL TABLE aggregate_simple ( c3 BOOLEAN NOT NULL ) STORED AS CSV -WITH HEADER ROW LOCATION '../core/tests/data/aggregate_simple.csv' +OPTIONS ('format.has_header' 'true'); statement ok CREATE TABLE memory_table(a INT NOT NULL, b INT NOT NULL, c INT NOT NULL) AS VALUES @@ -64,11 +64,13 @@ CREATE TABLE cpu_load_short(host STRING NOT NULL) AS VALUES statement ok CREATE EXTERNAL TABLE test (c1 int, c2 bigint, c3 boolean) -STORED AS CSV LOCATION '../core/tests/data/partitioned_csv'; +STORED AS CSV LOCATION '../core/tests/data/partitioned_csv' +OPTIONS('format.has_header' 'false'); statement ok CREATE EXTERNAL TABLE test_simple (c1 int, c2 bigint, c3 boolean) -STORED AS CSV LOCATION '../core/tests/data/partitioned_csv/partition-0.csv'; +STORED AS CSV LOCATION '../core/tests/data/partitioned_csv/partition-0.csv' +OPTIONS('format.has_header' 'false'); # projection same fields query I rowsort @@ -233,3 +235,20 @@ DROP TABLE test; statement ok DROP TABLE test_simple; + +## projection push down with Struct +statement ok +create table t as values (struct(1)); + +query TT +explain select column1.c0 from t; +---- +logical_plan +01)Projection: get_field(t.column1, Utf8("c0")) +02)--TableScan: t projection=[column1] +physical_plan +01)ProjectionExec: expr=[get_field(column1@0, c0) as t.column1[c0]] +02)--MemoryExec: partitions=1, partition_sizes=[1] + +statement ok +drop table t; diff --git a/datafusion/sqllogictest/test_files/push_down_filter.slt b/datafusion/sqllogictest/test_files/push_down_filter.slt new file mode 100644 index 000000000000..86aa07b04ce1 --- /dev/null +++ b/datafusion/sqllogictest/test_files/push_down_filter.slt @@ -0,0 +1,124 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Test push down filter + +statement ok +set datafusion.explain.logical_plan_only = true; + +statement ok +CREATE TABLE IF NOT EXISTS v AS VALUES(1,[1,2,3]),(2,[3,4,5]); + +query I +select uc2 from (select unnest(column2) as uc2, column1 from v) where column1 = 2; +---- +3 +4 +5 + +# test push down filter for unnest with filter on non-unnest column +# filter plan is pushed down into projection plan +query TT +explain select uc2 from (select unnest(column2) as uc2, column1 from v) where column1 = 2; +---- +logical_plan +01)Projection: unnest_placeholder(v.column2,depth=1) AS uc2 +02)--Unnest: lists[unnest_placeholder(v.column2)|depth=1] structs[] +03)----Projection: v.column2 AS unnest_placeholder(v.column2), v.column1 +04)------Filter: v.column1 = Int64(2) +05)--------TableScan: v projection=[column1, column2] + +query I +select uc2 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3; +---- +4 +5 + +# test push down filter for unnest with filter on unnest column +query TT +explain select uc2 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3; +---- +logical_plan +01)Projection: unnest_placeholder(v.column2,depth=1) AS uc2 +02)--Filter: unnest_placeholder(v.column2,depth=1) > Int64(3) +03)----Projection: unnest_placeholder(v.column2,depth=1) +04)------Unnest: lists[unnest_placeholder(v.column2)|depth=1] structs[] +05)--------Projection: v.column2 AS unnest_placeholder(v.column2), v.column1 +06)----------TableScan: v projection=[column1, column2] + +query II +select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 AND column1 = 2; +---- +4 2 +5 2 + +# Could push the filter (column1 = 2) down below unnest +query TT +explain select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 AND column1 = 2; +---- +logical_plan +01)Projection: unnest_placeholder(v.column2,depth=1) AS uc2, v.column1 +02)--Filter: unnest_placeholder(v.column2,depth=1) > Int64(3) +03)----Unnest: lists[unnest_placeholder(v.column2)|depth=1] structs[] +04)------Projection: v.column2 AS unnest_placeholder(v.column2), v.column1 +05)--------Filter: v.column1 = Int64(2) +06)----------TableScan: v projection=[column1, column2] + +query II +select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 OR column1 = 2; +---- +3 2 +4 2 +5 2 + +# only non-unnest filter in AND clause could be pushed down +query TT +explain select uc2, column1 from (select unnest(column2) as uc2, column1 from v) where uc2 > 3 OR column1 = 2; +---- +logical_plan +01)Projection: unnest_placeholder(v.column2,depth=1) AS uc2, v.column1 +02)--Filter: unnest_placeholder(v.column2,depth=1) > Int64(3) OR v.column1 = Int64(2) +03)----Unnest: lists[unnest_placeholder(v.column2)|depth=1] structs[] +04)------Projection: v.column2 AS unnest_placeholder(v.column2), v.column1 +05)--------TableScan: v projection=[column1, column2] + +statement ok +drop table v; + +# test with unnest struct, should not push down filter +statement ok +CREATE TABLE d AS VALUES(1,[named_struct('a', 1, 'b', 2)]),(2,[named_struct('a', 3, 'b', 4), named_struct('a', 5, 'b', 6)]); + +query I? +select * from (select column1, unnest(column2) as o from d) where o['a'] = 1; +---- +1 {a: 1, b: 2} + +query TT +explain select * from (select column1, unnest(column2) as o from d) where o['a'] = 1; +---- +logical_plan +01)Projection: d.column1, unnest_placeholder(d.column2,depth=1) AS o +02)--Filter: get_field(unnest_placeholder(d.column2,depth=1), Utf8("a")) = Int64(1) +03)----Unnest: lists[unnest_placeholder(d.column2)|depth=1] structs[] +04)------Projection: d.column1, d.column2 AS unnest_placeholder(d.column2) +05)--------TableScan: d projection=[column1, column2] + + + +statement ok +drop table d; diff --git a/datafusion/sqllogictest/test_files/references.slt b/datafusion/sqllogictest/test_files/references.slt index e276e322acab..4c3ac68aebd1 100644 --- a/datafusion/sqllogictest/test_files/references.slt +++ b/datafusion/sqllogictest/test_files/references.slt @@ -39,8 +39,8 @@ CREATE EXTERNAL TABLE aggregate_test_100 ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); query I SELECT COUNT(*) FROM aggregate_test_100; @@ -105,7 +105,7 @@ logical_plan 02)--Projection: test....., test..... AS c3 03)----TableScan: test projection=[....] physical_plan -01)SortExec: expr=[....@0 ASC NULLS LAST] +01)SortExec: expr=[....@0 ASC NULLS LAST], preserve_partitioning=[false] 02)--ProjectionExec: expr=[....@0 as ...., ....@0 as c3] 03)----MemoryExec: partitions=1, partition_sizes=[1] diff --git a/datafusion/sqllogictest/test_files/regexp.slt b/datafusion/sqllogictest/test_files/regexp.slt index a45ce3718bc4..800026dd766d 100644 --- a/datafusion/sqllogictest/test_files/regexp.slt +++ b/datafusion/sqllogictest/test_files/regexp.slt @@ -16,18 +16,18 @@ # under the License. statement ok -CREATE TABLE t (str varchar, pattern varchar, flags varchar) AS VALUES - ('abc', '^(a)', 'i'), - ('ABC', '^(A).*', 'i'), - ('aBc', '(b|d)', 'i'), - ('AbC', '(B|D)', null), - ('aBC', '^(b|c)', null), - ('4000', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', null), - ('4010', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', null), - ('Düsseldorf','[\p{Letter}-]+', null), - ('Москва', '[\p{L}-]+', null), - ('Köln', '[a-zA-Z]ö[a-zA-Z]{2}', null), - ('إسرائيل', '^\p{Arabic}+$', null); +CREATE TABLE t (str varchar, pattern varchar, start int, flags varchar) AS VALUES + ('abc', '^(a)', 1, 'i'), + ('ABC', '^(A).*', 1, 'i'), + ('aBc', '(b|d)', 1, 'i'), + ('AbC', '(B|D)', 2, null), + ('aBC', '^(b|c)', 3, null), + ('4000', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', 1, null), + ('4010', '\b4([1-9]\d\d|\d[1-9]\d|\d\d[1-9])\b', 2, null), + ('Düsseldorf','[\p{Letter}-]+', 3, null), + ('Москва', '[\p{L}-]+', 4, null), + ('Köln', '[a-zA-Z]ö[a-zA-Z]{2}', 1, null), + ('إسرائيل', '^\p{Arabic}+$', 2, null); # # regexp_like tests @@ -48,6 +48,51 @@ true true true +query B +SELECT str ~ NULL FROM t; +---- +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL + +query B +select str ~ right('foo', NULL) FROM t; +---- +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL + +query B +select right('foo', NULL) !~ str FROM t; +---- +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL +NULL + query B SELECT regexp_like('foobarbequebaz', ''); ---- @@ -94,7 +139,7 @@ SELECT regexp_like('aa', '.*-(\d)'); ---- false -query ? +query B SELECT regexp_like(NULL, '.*-(\d)'); ---- NULL @@ -104,7 +149,7 @@ SELECT regexp_like('aaa-0', NULL); ---- NULL -query ? +query B SELECT regexp_like(null, '.*-(\d)'); ---- NULL @@ -230,6 +275,66 @@ SELECT regexp_match('aaa-555', '.*-(\d*)'); ---- [555] +query B +select 'abc' ~ null; +---- +NULL + +query B +select null ~ null; +---- +NULL + +query B +select null ~ 'abc'; +---- +NULL + +query B +select 'abc' ~* null; +---- +NULL + +query B +select null ~* null; +---- +NULL + +query B +select null ~* 'abc'; +---- +NULL + +query B +select 'abc' !~ null; +---- +NULL + +query B +select null !~ null; +---- +NULL + +query B +select null !~ 'abc'; +---- +NULL + +query B +select 'abc' !~* null; +---- +NULL + +query B +select null !~* null; +---- +NULL + +query B +select null !~* 'abc'; +---- +NULL + # # regexp_replace tests # @@ -294,7 +399,7 @@ SELECT regexp_replace('Thomas', '.[mN]a.', 'M'); ---- ThM -query ? +query T SELECT regexp_replace(NULL, 'b(..)', 'X\\1Y', 'g'); ---- NULL @@ -309,16 +414,439 @@ SELECT regexp_replace(arrow_cast('foobar', 'Dictionary(Int32, Utf8)'), 'bar', 'x ---- fooxx +query TTT +select + regexp_replace(col, NULL, 'c'), + regexp_replace(col, 'a', NULL), + regexp_replace(col, 'a', 'c', NULL) +from (values ('a'), ('b')) as tbl(col); +---- +NULL NULL NULL +NULL NULL NULL + # multiline string query B SELECT 'foo\nbar\nbaz' ~ 'bar'; ---- true +statement error +Error during planning: Cannot infer common argument type for regex operation List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata +: {} }) ~ List(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) +select [1,2] ~ [3]; + query B SELECT 'foo\nbar\nbaz' LIKE '%bar%'; ---- true +query B +SELECT NULL LIKE NULL; +---- +NULL + +query B +SELECT NULL iLIKE NULL; +---- +NULL + +query B +SELECT NULL not LIKE NULL; +---- +NULL + +query B +SELECT NULL not iLIKE NULL; +---- +NULL + +# regexp_count tests + +# regexp_count tests from postgresql +# https://github.com/postgres/postgres/blob/56d23855c864b7384970724f3ad93fb0fc319e51/src/test/regress/sql/strings.sql#L226-L235 + +query I +SELECT regexp_count('123123123123123', '(12)3'); +---- +5 + +query I +SELECT regexp_count('123123123123', '123', 1); +---- +4 + +query I +SELECT regexp_count('123123123123', '123', 3); +---- +3 + +query I +SELECT regexp_count('123123123123', '123', 33); +---- +0 + +query I +SELECT regexp_count('ABCABCABCABC', 'Abc', 1, ''); +---- +0 + +query I +SELECT regexp_count('ABCABCABCABC', 'Abc', 1, 'i'); +---- +4 + +statement error +External error: query failed: DataFusion error: Arrow error: Compute error: regexp_count() requires start to be 1 based +SELECT regexp_count('123123123123', '123', 0); + +statement error +External error: query failed: DataFusion error: Arrow error: Compute error: regexp_count() requires start to be 1 based +SELECT regexp_count('123123123123', '123', -3); + +statement error +External error: statement failed: DataFusion error: Arrow error: Compute error: regexp_count() does not support global flag +SELECT regexp_count('123123123123', '123', 1, 'g'); + +query I +SELECT regexp_count(str, '\w') from t; +---- +3 +3 +3 +3 +3 +4 +4 +10 +6 +4 +7 + +query I +SELECT regexp_count(str, '\w{2}', start) from t; +---- +1 +1 +1 +1 +0 +2 +1 +4 +1 +2 +3 + +query I +SELECT regexp_count(str, 'ab', 1, 'i') from t; +---- +1 +1 +1 +1 +1 +0 +0 +0 +0 +0 +0 + + +query I +SELECT regexp_count(str, pattern) from t; +---- +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start) from t; +---- +1 +1 +0 +0 +0 +0 +0 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start, flags) from t; +---- +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# test type coercion +query I +SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from t; +---- +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# test string views + +statement ok +CREATE TABLE t_stringview AS +SELECT arrow_cast(str, 'Utf8View') as str, arrow_cast(pattern, 'Utf8View') as pattern, arrow_cast(start, 'Int64') as start, arrow_cast(flags, 'Utf8View') as flags FROM t; + +query I +SELECT regexp_count(str, '\w') from t; +---- +3 +3 +3 +3 +3 +4 +4 +10 +6 +4 +7 + +query I +SELECT regexp_count(str, '\w{2}', start) from t; +---- +1 +1 +1 +1 +0 +2 +1 +4 +1 +2 +3 + +query I +SELECT regexp_count(str, 'ab', 1, 'i') from t; +---- +1 +1 +1 +1 +1 +0 +0 +0 +0 +0 +0 + + +query I +SELECT regexp_count(str, pattern) from t; +---- +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start) from t; +---- +1 +1 +0 +0 +0 +0 +0 +1 +1 +1 +1 + +query I +SELECT regexp_count(str, pattern, start, flags) from t; +---- +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# test type coercion +query I +SELECT regexp_count(arrow_cast(str, 'Utf8'), arrow_cast(pattern, 'LargeUtf8'), arrow_cast(start, 'Int32'), flags) from t; +---- +1 +1 +1 +0 +0 +0 +0 +1 +1 +1 +1 + +# NULL tests + +query I +SELECT regexp_count(NULL, NULL); +---- +0 + +query I +SELECT regexp_count(NULL, 'a'); +---- +0 + +query I +SELECT regexp_count('a', NULL); +---- +0 + +query I +SELECT regexp_count(NULL, NULL, NULL, NULL); +---- +0 + +statement ok +CREATE TABLE empty_table (str varchar, pattern varchar, start int, flags varchar); + +query I +SELECT regexp_count(str, pattern, start, flags) from empty_table; +---- + +statement ok +INSERT INTO empty_table VALUES ('a', NULL, 1, 'i'), (NULL, 'a', 1, 'i'), (NULL, NULL, 1, 'i'), (NULL, NULL, NULL, 'i'); + +query I +SELECT regexp_count(str, pattern, start, flags) from empty_table; +---- +0 +0 +0 +0 + statement ok drop table t; + +statement ok +create or replace table strings as values + ('FooBar'), + ('Foo'), + ('Foo'), + ('Bar'), + ('FooBar'), + ('Bar'), + ('Baz'); + +statement ok +create or replace table dict_table as +select arrow_cast(column1, 'Dictionary(Int32, Utf8)') as column1 +from strings; + +query T +select column1 from dict_table where column1 LIKE '%oo%'; +---- +FooBar +Foo +Foo +FooBar + +query T +select column1 from dict_table where column1 NOT LIKE '%oo%'; +---- +Bar +Bar +Baz + +query T +select column1 from dict_table where column1 ILIKE '%oO%'; +---- +FooBar +Foo +Foo +FooBar + +query T +select column1 from dict_table where column1 NOT ILIKE '%oO%'; +---- +Bar +Bar +Baz + + +# plan should not cast the column, instead it should use the dictionary directly +query TT +explain select column1 from dict_table where column1 LIKE '%oo%'; +---- +logical_plan +01)Filter: dict_table.column1 LIKE Utf8("%oo%") +02)--TableScan: dict_table projection=[column1] +physical_plan +01)CoalesceBatchesExec: target_batch_size=8192 +02)--FilterExec: column1@0 LIKE %oo% +03)----MemoryExec: partitions=1, partition_sizes=[1] + +# Ensure casting / coercion works for all operators +# (there should be no casts to Utf8) +query TT +explain select + column1 LIKE '%oo%', + column1 NOT LIKE '%oo%', + column1 ILIKE '%oo%', + column1 NOT ILIKE '%oo%' +from dict_table; +---- +logical_plan +01)Projection: dict_table.column1 LIKE Utf8("%oo%"), dict_table.column1 NOT LIKE Utf8("%oo%"), dict_table.column1 ILIKE Utf8("%oo%"), dict_table.column1 NOT ILIKE Utf8("%oo%") +02)--TableScan: dict_table projection=[column1] +physical_plan +01)ProjectionExec: expr=[column1@0 LIKE %oo% as dict_table.column1 LIKE Utf8("%oo%"), column1@0 NOT LIKE %oo% as dict_table.column1 NOT LIKE Utf8("%oo%"), column1@0 ILIKE %oo% as dict_table.column1 ILIKE Utf8("%oo%"), column1@0 NOT ILIKE %oo% as dict_table.column1 NOT ILIKE Utf8("%oo%")] +02)--MemoryExec: partitions=1, partition_sizes=[1] + +statement ok +drop table strings + +statement ok +drop table dict_table diff --git a/datafusion/sqllogictest/test_files/repartition.slt b/datafusion/sqllogictest/test_files/repartition.slt index 3f9e6e61f1d0..630674bb09ed 100644 --- a/datafusion/sqllogictest/test_files/repartition.slt +++ b/datafusion/sqllogictest/test_files/repartition.slt @@ -40,13 +40,13 @@ query TT EXPLAIN SELECT column1, SUM(column2) FROM parquet_table GROUP BY column1; ---- logical_plan -01)Aggregate: groupBy=[[parquet_table.column1]], aggr=[[SUM(CAST(parquet_table.column2 AS Int64))]] +01)Aggregate: groupBy=[[parquet_table.column1]], aggr=[[sum(CAST(parquet_table.column2 AS Int64))]] 02)--TableScan: parquet_table projection=[column1, column2] physical_plan -01)AggregateExec: mode=FinalPartitioned, gby=[column1@0 as column1], aggr=[SUM(parquet_table.column2)] +01)AggregateExec: mode=FinalPartitioned, gby=[column1@0 as column1], aggr=[sum(parquet_table.column2)] 02)--CoalesceBatchesExec: target_batch_size=8192 03)----RepartitionExec: partitioning=Hash([column1@0], 4), input_partitions=4 -04)------AggregateExec: mode=Partial, gby=[column1@0 as column1], aggr=[SUM(parquet_table.column2)] +04)------AggregateExec: mode=Partial, gby=[column1@0 as column1], aggr=[sum(parquet_table.column2)] 05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 06)----------ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition/parquet_table/2.parquet]]}, projection=[column1, column2] @@ -58,13 +58,13 @@ query TT EXPLAIN SELECT column1, SUM(column2) FROM parquet_table GROUP BY column1; ---- logical_plan -01)Aggregate: groupBy=[[parquet_table.column1]], aggr=[[SUM(CAST(parquet_table.column2 AS Int64))]] +01)Aggregate: groupBy=[[parquet_table.column1]], aggr=[[sum(CAST(parquet_table.column2 AS Int64))]] 02)--TableScan: parquet_table projection=[column1, column2] physical_plan -01)AggregateExec: mode=FinalPartitioned, gby=[column1@0 as column1], aggr=[SUM(parquet_table.column2)] +01)AggregateExec: mode=FinalPartitioned, gby=[column1@0 as column1], aggr=[sum(parquet_table.column2)] 02)--CoalesceBatchesExec: target_batch_size=8192 03)----RepartitionExec: partitioning=Hash([column1@0], 4), input_partitions=1 -04)------AggregateExec: mode=Partial, gby=[column1@0 as column1], aggr=[SUM(parquet_table.column2)] +04)------AggregateExec: mode=Partial, gby=[column1@0 as column1], aggr=[sum(parquet_table.column2)] 05)--------ParquetExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition/parquet_table/2.parquet]]}, projection=[column1, column2] @@ -95,8 +95,8 @@ CREATE UNBOUNDED EXTERNAL TABLE sink_table ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW -LOCATION '../../testing/data/csv/aggregate_test_100.csv'; +LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); query TII SELECT c1, c2, c3 FROM sink_table WHERE c3 > 0 LIMIT 5; @@ -123,7 +123,27 @@ logical_plan physical_plan 01)GlobalLimitExec: skip=0, fetch=5 02)--CoalescePartitionsExec -03)----CoalesceBatchesExec: target_batch_size=8192 +03)----CoalesceBatchesExec: target_batch_size=8192, fetch=5 04)------FilterExec: c3@2 > 0 05)--------RepartitionExec: partitioning=RoundRobinBatch(3), input_partitions=1 06)----------StreamingTableExec: partition_sizes=1, projection=[c1, c2, c3], infinite_source=true + +# Start repratition on empty column test. +# See https://github.com/apache/datafusion/issues/12057 + +statement ok +CREATE TABLE t1(v1 int); + +statement ok +INSERT INTO t1 values(42); + +query I +SELECT sum(1) OVER (PARTITION BY false=false) +FROM t1 WHERE ((false > (v1 = v1)) IS DISTINCT FROM true); +---- +1 + +statement ok +DROP TABLE t1; + +# End repartition on empty columns test diff --git a/datafusion/sqllogictest/test_files/repartition_scan.slt b/datafusion/sqllogictest/test_files/repartition_scan.slt index 7d2b6d4444ce..858e42106221 100644 --- a/datafusion/sqllogictest/test_files/repartition_scan.slt +++ b/datafusion/sqllogictest/test_files/repartition_scan.slt @@ -61,7 +61,7 @@ logical_plan physical_plan 01)CoalesceBatchesExec: target_batch_size=8192 02)--FilterExec: column1@0 != 42 -03)----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..104], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:104..208], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:208..312], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:312..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] +03)----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..88], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:88..176], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:176..264], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:264..351]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] # disable round robin repartitioning statement ok @@ -77,7 +77,7 @@ logical_plan physical_plan 01)CoalesceBatchesExec: target_batch_size=8192 02)--FilterExec: column1@0 != 42 -03)----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..104], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:104..208], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:208..312], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:312..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] +03)----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..88], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:88..176], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:176..264], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:264..351]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] # enable round robin repartitioning again statement ok @@ -99,10 +99,10 @@ logical_plan 03)----TableScan: parquet_table projection=[column1], partial_filters=[parquet_table.column1 != Int32(42)] physical_plan 01)SortPreservingMergeExec: [column1@0 ASC NULLS LAST] -02)--SortExec: expr=[column1@0 ASC NULLS LAST] +02)--SortExec: expr=[column1@0 ASC NULLS LAST], preserve_partitioning=[true] 03)----CoalesceBatchesExec: target_batch_size=8192 04)------FilterExec: column1@0 != 42 -05)--------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..205], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:205..405, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..5], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:5..210], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:210..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] +05)--------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..174], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:174..342, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..6], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:6..180], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:180..351]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] ## Read the files as though they are ordered @@ -138,7 +138,7 @@ physical_plan 01)SortPreservingMergeExec: [column1@0 ASC NULLS LAST] 02)--CoalesceBatchesExec: target_batch_size=8192 03)----FilterExec: column1@0 != 42 -04)------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..202], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..207], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:207..414], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:202..405]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] +04)------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..171], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..175], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:175..351], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:171..342]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] # Cleanup statement ok @@ -158,13 +158,13 @@ DROP TABLE parquet_table_with_order; # create a single csv file statement ok COPY (VALUES (1), (2), (3), (4), (5)) TO 'test_files/scratch/repartition_scan/csv_table/1.csv' -STORED AS CSV WITH HEADER ROW; +STORED AS CSV OPTIONS ('format.has_header' 'true'); statement ok CREATE EXTERNAL TABLE csv_table(column1 int) STORED AS csv -WITH HEADER ROW -LOCATION 'test_files/scratch/repartition_scan/csv_table/'; +LOCATION 'test_files/scratch/repartition_scan/csv_table/' +OPTIONS ('format.has_header' 'true'); query I select * from csv_table ORDER BY column1; @@ -277,8 +277,7 @@ DROP TABLE arrow_table; statement ok CREATE EXTERNAL TABLE avro_table STORED AS AVRO -WITH HEADER ROW -LOCATION '../../testing/data/avro/simple_enum.avro' +LOCATION '../../testing/data/avro/simple_enum.avro'; # It would be great to see the file read as "4" groups with even sizes (offsets) eventually diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 0c3fca446526..145172f31fd7 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -448,76 +448,6 @@ select floor(a), floor(b), floor(c) from signed_integers; 2 -1000 123 4 NULL NULL -## gcd - -# gcd scalar function -query III rowsort -select gcd(0, 0), gcd(2, 3), gcd(15, 10); ----- -0 1 5 - -# gcd scalar nulls -query I rowsort -select gcd(null, 64); ----- -NULL - -# gcd scalar nulls #1 -query I rowsort -select gcd(2, null); ----- -NULL - -# gcd scalar nulls #2 -query I rowsort -select gcd(null, null); ----- -NULL - -# gcd with columns -query III rowsort -select gcd(a, b), gcd(c, d), gcd(e, f) from signed_integers; ----- -1 1 2 -1 2 6 -2 1 1 -NULL NULL NULL - -## lcm - -# lcm scalar function -query III rowsort -select lcm(0, 0), lcm(2, 3), lcm(15, 10); ----- -0 6 30 - -# lcm scalar nulls -query I rowsort -select lcm(null, 64); ----- -NULL - -# lcm scalar nulls #1 -query I rowsort -select lcm(2, null); ----- -NULL - -# lcm scalar nulls #2 -query I rowsort -select lcm(null, null); ----- -NULL - -# lcm with columns -query III rowsort -select lcm(a, b), lcm(c, d), lcm(e, f) from signed_integers; ----- -100 580608 20 -1000 31488 55 -30000 1001472 12 -NULL NULL NULL - ## ln # ln scalar function @@ -606,6 +536,37 @@ select log(a, 64) a, log(b), log(10, b) from signed_integers; NaN 2 2 NaN 4 4 +# log overloaded base 10 float64 and float32 casting scalar +query RR rowsort +select log(arrow_cast(10, 'Float64')) a ,log(arrow_cast(100, 'Float32')) b; +---- +1 2 + +# log overloaded base 10 float64 and float32 casting with columns +query RR rowsort +select log(arrow_cast(a, 'Float64')), log(arrow_cast(b, 'Float32')) from signed_integers; +---- +0.301029995664 NaN +0.602059991328 NULL +NaN 2 +NaN 4 + +# log float64 and float32 casting scalar +query RR rowsort +select log(2,arrow_cast(8, 'Float64')) a, log(2,arrow_cast(16, 'Float32')) b; +---- +3 4 + +# log float64 and float32 casting with columns +query RR rowsort +select log(2,arrow_cast(a, 'Float64')), log(4,arrow_cast(b, 'Float32')) from signed_integers; +---- +1 NaN +2 NULL +NaN 3.321928 +NaN 6.643856 + + ## log10 # log10 scalar function @@ -776,7 +737,7 @@ select power(2, 0), power(2, 1), power(2, 2); 1 2 4 # power scalar nulls -query R rowsort +query I rowsort select power(null, 64); ---- NULL @@ -788,7 +749,7 @@ select power(2, null); NULL # power scalar nulls #2 -query R rowsort +query I rowsort select power(null, null); ---- NULL @@ -848,13 +809,23 @@ select round(a), round(b), round(c) from small_floats; 0 0 1 1 0 0 +# round with too large +# max Int32 is 2147483647 +query error DataFusion error: Execution error: Invalid values for decimal places: Cast error: Can't cast value 2147483648 to type Int32 +select round(3.14, 2147483648); + +# with array +query error DataFusion error: Execution error: Invalid values for decimal places: Cast error: Can't cast value 2147483649 to type Int32 +select round(column1, column2) from values (3.14, 2), (3.14, 3), (3.14, 2147483649); + + ## signum # signum scalar function query RRR rowsort select signum(-2), signum(0), signum(2); ---- --1 1 1 +-1 0 1 # signum scalar nulls query R rowsort @@ -1263,7 +1234,7 @@ FROM t1 999 999 -# case_when_else_with_null_contant() +# case_when_else_with_null_constant() query I SELECT CASE WHEN c1 = 'a' THEN 1 @@ -1298,27 +1269,27 @@ SELECT CASE WHEN NULL THEN 'foo' ELSE 'bar' END bar # case_expr_with_null() -query ? +query I select case when b is null then null else b end from (select a,b from (values (1,null),(2,3)) as t (a,b)) a; ---- NULL 3 -query ? +query I select case when b is null then null else b end from (select a,b from (values (1,1),(2,3)) as t (a,b)) a; ---- 1 3 # case_expr_with_nulls() -query ? +query I select case when b is null then null when b < 3 then null when b >=3 then b + 1 else b end from (select a,b from (values (1,null),(1,2),(2,3)) as t (a,b)) a ---- NULL NULL 4 -query ? +query I select case b when 1 then null when 2 then null when 3 then b + 1 else b end from (select a,b from (values (1,null),(1,2),(2,3)) as t (a,b)) a; ---- NULL @@ -1352,8 +1323,8 @@ CREATE EXTERNAL TABLE aggregate_test_100 ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); # c8 = i32; c6 = i64 query TTT @@ -1586,6 +1557,9 @@ NULL NULL query error DataFusion error: Error during planning: Negation only supports numeric, interval and timestamp types SELECT -'100' +query error DataFusion error: Error during planning: Unary operator '\+' only supports numeric, interval and timestamp types +SELECT +true + statement ok drop table test_boolean @@ -1638,7 +1612,7 @@ false statement ok CREATE TABLE t1( a boolean, - b boolean, + b boolean ) as VALUES (true, true), (true, null), @@ -1767,165 +1741,6 @@ SELECT make_array(1, 2, 3); ---- [1, 2, 3] -# coalesce static empty value -query T -SELECT COALESCE('', 'test') ----- -(empty) - -# coalesce static value with null -query T -SELECT COALESCE(NULL, 'test') ----- -test - - -statement ok -create table test1 as values (arrow_cast('foo', 'Dictionary(Int32, Utf8)')), (null); - -# test coercion string -query ? -select coalesce(column1, 'none_set') from test1; ----- -foo -none_set - -# test coercion Int -query I -select coalesce(34, arrow_cast(123, 'Dictionary(Int32, Int8)')); ----- -34 - -# test with Int -query I -select coalesce(arrow_cast(123, 'Dictionary(Int32, Int8)'),34); ----- -123 - -# test with null -query I -select coalesce(null, 34, arrow_cast(123, 'Dictionary(Int32, Int8)')); ----- -34 - -# test with null -query T -select coalesce(null, column1, 'none_set') from test1; ----- -foo -none_set - -statement ok -drop table test1 - - -statement ok -CREATE TABLE test( - c1 INT, - c2 INT -) as VALUES -(0, 1), -(NULL, 1), -(1, 0), -(NULL, 1), -(NULL, NULL); - -# coalesce result -query I rowsort -SELECT COALESCE(c1, c2) FROM test ----- -0 -1 -1 -1 -NULL - -# coalesce result with default value -query T rowsort -SELECT COALESCE(c1, c2, '-1') FROM test ----- --1 -0 -1 -1 -1 - -statement ok -drop table test - -statement ok -CREATE TABLE test( - c1 INT, - c2 INT -) as VALUES -(1, 2), -(NULL, 2), -(1, NULL), -(NULL, NULL); - -# coalesce sum with default value -query I -SELECT SUM(COALESCE(c1, c2, 0)) FROM test ----- -4 - -# coalesce mul with default value -query I -SELECT COALESCE(c1 * c2, 0) FROM test ----- -2 -0 -0 -0 - -statement ok -drop table test - -# coalesce date32 - -statement ok -CREATE TABLE test( - d1_date DATE, - d2_date DATE, - d3_date DATE -) as VALUES - ('2022-12-12','2022-12-12','2022-12-12'), - (NULL,'2022-12-11','2022-12-12'), - ('2022-12-12','2022-12-10','2022-12-12'), - ('2022-12-12',NULL,'2022-12-12'), - ('2022-12-12','2022-12-8','2022-12-12'), - ('2022-12-12','2022-12-7',NULL), - ('2022-12-12',NULL,'2022-12-12'), - (NULL,'2022-12-5','2022-12-12') -; - -query D -SELECT COALESCE(d1_date, d2_date, d3_date) FROM test ----- -2022-12-12 -2022-12-11 -2022-12-12 -2022-12-12 -2022-12-12 -2022-12-12 -2022-12-12 -2022-12-05 - -query T -SELECT arrow_typeof(COALESCE(d1_date, d2_date, d3_date)) FROM test ----- -Date32 -Date32 -Date32 -Date32 -Date32 -Date32 -Date32 -Date32 - -statement ok -drop table test - statement ok CREATE TABLE test( i32 INT, @@ -1939,7 +1754,7 @@ CREATE TABLE test( (-14, -14, -14.5, -14.5), (NULL, NULL, NULL, NULL); -query RRRRIR rowsort +query IRRRIR rowsort SELECT power(i32, exp_i) as power_i32, power(i64, exp_f) as power_i64, pow(f32, exp_i) as power_f32, @@ -1958,34 +1773,33 @@ statement ok drop table test # error message for wrong function signature (Variadic: arbitrary number of args all from some common types) -statement error Error during planning: No function matches the given name and argument types 'concat\(\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tconcat\(Utf8, ..\) +statement error SELECT concat(); # error message for wrong function signature (Uniform: t args all from some common types) -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'nullif\(Int64\)'. You might need to add explicit type casts. +statement error SELECT nullif(1); - # error message for wrong function signature (Exact: exact number of args of an exact type) -statement error Error during planning: No function matches the given name and argument types 'pi\(Float64\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tpi\(\) +statement error SELECT pi(3.14); # error message for wrong function signature (Any: fixed number of args of arbitrary types) -statement error Error during planning: No function matches the given name and argument types 'arrow_typeof\(Int64, Int64\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tarrow_typeof\(Any\) +statement error SELECT arrow_typeof(1, 1); # error message for wrong function signature (OneOf: fixed number of args of arbitrary types) -statement error Error during planning: No function matches the given name and argument types 'power\(Int64, Int64, Int64\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tpower\(Int64, Int64\)\n\tpower\(Float64, Float64\) +statement error SELECT power(1, 2, 3); # The following functions need 1 argument -statement error Error during planning: No function matches the given name and argument types 'abs\(\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tabs\(Any\) +statement error SELECT abs(); -statement error Error during planning: No function matches the given name and argument types 'acos\(\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tacos\(Float64/Float32\) +statement error SELECT acos(); -statement error Error during planning: No function matches the given name and argument types 'isnan\(\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tisnan\(Float32\)\n\tisnan\(Float64\) +statement error SELECT isnan(); # turn off enable_ident_normalization @@ -2115,7 +1929,7 @@ select 100000 where position('legend' in 'league of legend') = 11; 100000 # test null -query ? +query I select position(null in null) ---- NULL @@ -2126,11 +1940,9 @@ select position('' in '') ---- 1 - -query error DataFusion error: Execution error: The STRPOS/INSTR/POSITION function can only accept strings, but got Int64. +query error DataFusion error: Error during planning: Error during planning: Int64 and Int64 are not coercible to a common string select position(1 in 1) - query I select strpos('abc', 'c'); ---- @@ -2202,3 +2014,37 @@ query I select strpos('joséésoj', arrow_cast(null, 'Utf8')); ---- NULL + +statement ok +CREATE TABLE t1 (v1 int) AS VALUES (1), (2), (3); + +query I +SELECT * FROM t1 ORDER BY ACOS(SIN(v1)); +---- +2 +1 +3 + +query I +SELECT * FROM t1 ORDER BY ACOSH(SIN(v1)); +---- +1 +2 +3 + +query I +SELECT * FROM t1 ORDER BY ASIN(SIN(v1)); +---- +3 +1 +2 + +query I +SELECT * FROM t1 ORDER BY ATANH(SIN(v1)); +---- +3 +1 +2 + +statement ok +drop table t1; diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index a3a4b3bfc584..c096f6e692af 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -33,8 +33,8 @@ CREATE EXTERNAL TABLE aggregate_test_100 ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); statement ok CREATE EXTERNAL TABLE aggregate_simple ( @@ -43,8 +43,8 @@ CREATE EXTERNAL TABLE aggregate_simple ( c3 BOOLEAN NOT NULL ) STORED AS CSV -WITH HEADER ROW LOCATION '../core/tests/data/aggregate_simple.csv' +OPTIONS ('format.has_header' 'true'); ########## ## SELECT Tests @@ -67,7 +67,7 @@ AS SELECT arrow_cast(x, 'Dictionary(Int32, Utf8)') as d1, y as d2, arrow_cast(z, 'LargeUtf8') as d3 FROM window_null_string_value_prepare; -query ?I +query TI SELECT d1, row_number() OVER (partition by d1) as rn1 FROM window_null_string_table order by d1 asc; ---- one 1 @@ -101,7 +101,7 @@ statement ok CREATE TABLE test ( c1 BIGINT NOT NULL, c2 BIGINT NOT NULL, - c3 BOOLEAN NOT NULL, + c3 BOOLEAN NOT NULL ) AS VALUES (0, 1, false), (0, 10, true), (0, 2, true), @@ -241,7 +241,7 @@ AS SELECT arrow_cast(x, 'Dictionary(Int32, Utf8)') as d1, arrow_cast(y, 'Dictionary(Int32, Utf8)') as d2, z as d3 FROM value; -query ? +query T SELECT d1 FROM string_dictionary_table; ---- one @@ -249,38 +249,38 @@ NULL three # basic filtering -query ? +query T SELECT d1 FROM string_dictionary_table WHERE d1 IS NOT NULL; ---- one three # comparison with constant -query ? +query T SELECT d1 FROM string_dictionary_table WHERE d1 = 'three'; ---- three # comparison with another dictionary column -query ? +query T SELECT d1 FROM string_dictionary_table WHERE d1 = d2; ---- three # order comparison with another dictionary column -query ? +query T SELECT d1 FROM string_dictionary_table WHERE d1 <= d2; ---- three # comparison with a non dictionary column -query ? +query T SELECT d1 FROM string_dictionary_table WHERE d1 = d3; ---- three # filtering with constant -query ? +query T SELECT d1 FROM string_dictionary_table WHERE d1 = 'three'; ---- three @@ -320,7 +320,7 @@ SELECT MAX(d1) FROM string_dictionary_table; three # grouping -query ?I +query TI SELECT d1, COUNT(*) FROM string_dictionary_table group by d1 order by d1; ---- one 1 @@ -328,7 +328,7 @@ three 1 NULL 1 # window functions -query ?I +query TI SELECT d1, row_number() OVER (partition by d1) as rn1 FROM string_dictionary_table order by d1; ---- one 1 @@ -336,26 +336,35 @@ three 1 NULL 1 # select_values_list -statement error DataFusion error: SQL error: ParserError\("Expected \(, found: EOF"\) +statement error DataFusion error: SQL error: ParserError\("Expected: \(, found: EOF"\) VALUES -statement error DataFusion error: SQL error: ParserError\("Expected an expression:, found: \)"\) +statement error DataFusion error: SQL error: ParserError\("Expected: an expression:, found: \)"\) VALUES () -statement error DataFusion error: SQL error: ParserError\("Expected an expression:, found: \)"\) +statement error DataFusion error: SQL error: ParserError\("Expected: an expression:, found: \)"\) VALUES (1),() statement error DataFusion error: Error during planning: Inconsistent data length across values list: got 2 values in row 1 but expected 1 VALUES (1),(1,2) -statement error DataFusion error: Error during planning: Inconsistent data type across values list at row 1 column 0 +query I VALUES (1),('2') +---- +1 +2 -statement error DataFusion error: Error during planning: Inconsistent data type across values list at row 1 column 0 +query R VALUES (1),(2.0) +---- +1 +2 -statement error DataFusion error: Error during planning: Inconsistent data type across values list at row 1 column 1 +query II VALUES (1,2), (1,'2') +---- +1 2 +1 2 query IT VALUES (1,'a'),(NULL,'b'),(3,'c') @@ -371,7 +380,7 @@ NULL a NULL b 3 c -query TT +query ?T VALUES (NULL,'a'),(NULL,'b'),(NULL,'c') ---- NULL a @@ -392,7 +401,7 @@ VALUES (1,NULL),(2,NULL),(3,'c') 2 NULL 3 c -query IIIIIIIIIIIIITTR +query IIIIIIIIIIIII?TR VALUES (1,2,3,4,5,6,7,8,9,10,11,12,13,NULL,'F',3.5) ---- 1 2 3 4 5 6 7 8 9 10 11 12 13 NULL F 3.5 @@ -450,7 +459,7 @@ VALUES (-1) query IIB VALUES (2+1,2-1,2>1) ---- -3 1 true +3 1 true # multiple rows values query I rowsort @@ -463,8 +472,8 @@ VALUES (1),(2) query IT rowsort VALUES (1,'a'),(2,'b') ---- -1 a -2 b +1 a +2 b # table foo for distinct order by statement ok @@ -473,6 +482,28 @@ CREATE TABLE foo AS VALUES (3, 4), (5, 6); +# multiple rows and columns need type coercion +statement ok +CREATE TABLE foo2(c1 double, c2 double) AS VALUES +(1.1, 4.1), +(2, 5), +(3, 6); + +query T +SELECT arrow_typeof(COALESCE(column1, column2)) FROM VALUES (null, 1.2); +---- +Float64 + +# multiple rows and columns with null need type coercion +query TTT +select arrow_typeof(column1), arrow_typeof(column2), arrow_typeof(column3) from (SELECT column1, column2, column3 FROM VALUES +(null, 2, 'a'), +(1.1, null, 'b'), +(2, 5, null)) LIMIT 1; +---- +Float64 Int64 Utf8 + + # foo distinct query T select distinct '1' from foo; @@ -488,7 +519,7 @@ select '1' from foo order by column1; 1 # foo distinct order by -statement error DataFusion error: Error during planning: For SELECT DISTINCT, ORDER BY expressions column1 must appear in select list +statement error DataFusion error: Error during planning: For SELECT DISTINCT, ORDER BY expressions foo\.column1 must appear in select list select distinct '1' from foo order by column1; # distincts for float nan @@ -527,7 +558,7 @@ EXPLAIN SELECT * FROM ((SELECT column1 FROM foo) "T1" CROSS JOIN (SELECT column2 ---- logical_plan 01)SubqueryAlias: F -02)--CrossJoin: +02)--Cross Join: 03)----SubqueryAlias: T1 04)------TableScan: foo projection=[column1] 05)----SubqueryAlias: T2 @@ -550,9 +581,32 @@ select * from (select 1 a union all select 2) b order by a limit 1; 1 # select limit clause invalid -statement error DataFusion error: Error during planning: LIMIT must be >= 0, '\-1' was provided\. +statement error Error during planning: LIMIT must be >= 0, '-1' was provided select * from (select 1 a union all select 2) b order by a limit -1; +statement error Error during planning: OFFSET must be >=0, '-1' was provided +select * from (select 1 a union all select 2) b order by a offset -1; + +statement error Unsupported LIMIT expression +select * from (values(1),(2)) limit (select 1); + +statement error Unsupported OFFSET expression +select * from (values(1),(2)) offset (select 1); + +# disallow non-integer limit/offset +statement error Expected LIMIT to be an integer or null, but got Float64 +select * from (values(1),(2)) limit 0.5; + +statement error Expected OFFSET to be an integer or null, but got Utf8 +select * from (values(1),(2)) offset '1'; + +# test with different integer types +query I +select * from (values (1), (2), (3), (4)) limit 2::int OFFSET 1::tinyint +---- +2 +3 + # select limit with basic arithmetic query I select * from (select 1 a union all select 2) b order by a limit 1+1; @@ -566,13 +620,38 @@ select * from (values (1)) LIMIT 10*100; ---- 1 -# More complex expressions in the limit is not supported yet. -# See issue: https://github.com/apache/datafusion/issues/9821 -statement error DataFusion error: Error during planning: Unsupported operator for LIMIT clause +# select limit with complex arithmetic +query I select * from (values (1)) LIMIT 100/10; +---- +1 + +# test constant-folding of LIMIT expr +query I +select * from (values (1), (2), (3), (4)) LIMIT abs(-4) + 4 / -2; -- LIMIT 2 +---- +1 +2 + +# test constant-folding of OFFSET expr +query I +select * from (values (1), (2), (3), (4)) OFFSET abs(-4) + 4 / -2; -- OFFSET 2 +---- +3 +4 -# More complex expressions in the limit is not supported yet. -statement error DataFusion error: Error during planning: Unexpected expression in LIMIT clause +# test constant-folding of LIMIT and OFFSET +query I +select * from (values (1), (2), (3), (4)) + -- LIMIT 2 + LIMIT abs(-4) + -1 * 2 + -- OFFSET 1 + OFFSET case when 1 < 2 then 1 else 0 end; +---- +2 +3 + +statement error Schema error: No field named column1. select * from (values (1)) LIMIT cast(column1 as tinyint); # select limit clause @@ -582,6 +661,13 @@ select * from (select 1 a union all select 2) b order by a limit null; 1 2 +# offset null takes no effect +query I +select * from (select 1 a union all select 2) b order by a offset null; +---- +1 +2 + # select limit clause query I select * from (select 1 a union all select 2) b order by a limit 0; @@ -603,6 +689,33 @@ END; ---- 2 +# select case when type is null +query I +select CASE + WHEN NULL THEN 1 + ELSE 2 +END; +---- +2 + +# select case then type is null +query I +select CASE + WHEN 10 > 5 THEN NULL + ELSE 2 +END; +---- +NULL + +# select case else type is null +query I +select CASE + WHEN 10 = 5 THEN 1 + ELSE NULL +END; +---- +NULL + # Binary Expression for LargeUtf8 # issue: https://github.com/apache/datafusion/issues/5893 statement ok @@ -953,13 +1066,13 @@ FROM ( ) AS a ) AS b ---- -a 5 -101 -a 5 -54 a 5 -38 +a 5 -54 +a 6 36 +a 6 -31 a 5 65 +a 5 -101 a 6 -101 -a 6 -31 -a 6 36 # nested select without aliases query TII @@ -974,13 +1087,13 @@ FROM ( ) ) ---- -a 5 -101 -a 5 -54 a 5 -38 +a 5 -54 +a 6 36 +a 6 -31 a 5 65 +a 5 -101 a 6 -101 -a 6 -31 -a 6 36 # select with join unaliased subqueries query TIITII @@ -1067,9 +1180,9 @@ CREATE EXTERNAL TABLE annotated_data_finite2 ( d INTEGER ) STORED AS CSV -WITH HEADER ROW WITH ORDER (a ASC, b ASC, c ASC) -LOCATION '../core/tests/data/window_2.csv'; +LOCATION '../core/tests/data/window_2.csv' +OPTIONS ('format.has_header' 'true'); # test_source_projection @@ -1081,12 +1194,9 @@ EXPLAIN SELECT a FROM annotated_data_finite2 LIMIT 5 ---- logical_plan -01)Limit: skip=0, fetch=5 -02)--Sort: annotated_data_finite2.a ASC NULLS LAST, fetch=5 -03)----TableScan: annotated_data_finite2 projection=[a] -physical_plan -01)GlobalLimitExec: skip=0, fetch=5 -02)--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a], output_ordering=[a@0 ASC NULLS LAST], has_header=true +01)Sort: annotated_data_finite2.a ASC NULLS LAST, fetch=5 +02)--TableScan: annotated_data_finite2 projection=[a] +physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a], limit=5, output_ordering=[a@0 ASC NULLS LAST], has_header=true query I SELECT a FROM annotated_data_finite2 @@ -1160,12 +1270,12 @@ LIMIT 5 200 2000 # Trying to exclude non-existing column should give error -statement error DataFusion error: Schema error: No field named e. Valid fields are table1.a, table1.b, table1.c, table1.d. +statement error SELECT * EXCLUDE e FROM table1 # similarly, except should raise error if excluded column is not in the table -statement error DataFusion error: Schema error: No field named e. Valid fields are table1.a, table1.b, table1.c, table1.d. +statement error SELECT * EXCEPT(e) FROM table1 @@ -1179,7 +1289,7 @@ FROM table1 2 20 20 200 2000 # EXCEPT, or EXCLUDE shouldn't contain duplicate column names -statement error DataFusion error: Error during planning: EXCLUDE or EXCEPT contains duplicate column names +statement error SELECT * EXCLUDE(a, a) FROM table1 @@ -1188,6 +1298,63 @@ statement ok SELECT * EXCEPT(a, b, c, d) FROM table1 +# try zero column with LIMIT, 1 row but empty +statement ok +SELECT * EXCEPT (a, b, c, d) +FROM table1 +LIMIT 1 + +# try zero column with GROUP BY, 2 row but empty +statement ok +SELECT * EXCEPT (a, b, c, d) +FROM table1 +GROUP BY a + +# try zero column with WHERE, 1 row but empty +statement ok +SELECT * EXCEPT (a, b, c, d) +FROM table1 +WHERE a = 1 + +# create table2 the same with table1 +statement ok +CREATE TABLE table2 ( + a int, + b int, + c int, + d int +) as values + (1, 10, 100, 1000), + (2, 20, 200, 2000); + +# try zero column with inner JOIN, 2 row but empty +statement ok +WITH t1 AS (SELECT a AS t1_a FROM table1), t2 AS (SELECT a AS t2_a FROM table2) +SELECT * EXCEPT (t1_a, t2_a) +FROM t1 +JOIN t2 ON (t1_a = t2_a) + +# try zero column with more JOIN, 2 row but empty +statement ok +SELECT * EXCEPT (b1, b2) +FROM ( + SELECT b AS b1 FROM table1 +) +JOIN ( + SELECT b AS b2 FROM table2 +) ON b1 = b2 + +# try zero column with Window, 2 row but empty +statement ok +SELECT * EXCEPT (a, b, row_num) +FROM ( + SELECT + a, + b, + ROW_NUMBER() OVER (ORDER BY b) AS row_num + FROM table1 +) + # EXCLUDE order shouldn't matter query II SELECT * EXCLUDE(b, a) @@ -1211,7 +1378,7 @@ logical_plan 01)Sort: table1.a ASC NULLS LAST 02)--TableScan: table1 projection=[a] physical_plan -01)SortExec: expr=[a@0 ASC NULLS LAST] +01)SortExec: expr=[a@0 ASC NULLS LAST], preserve_partitioning=[false] 02)--MemoryExec: partitions=1, partition_sizes=[1] # ambiguous column references in on join @@ -1272,7 +1439,7 @@ logical_plan 02)--Filter: annotated_data_finite2.a = Int32(0) 03)----TableScan: annotated_data_finite2 projection=[a0, a, b, c, d], partial_filters=[annotated_data_finite2.a = Int32(0)] physical_plan -01)SortPreservingMergeExec: [b@2 ASC NULLS LAST,c@3 ASC NULLS LAST] +01)SortPreservingMergeExec: [b@2 ASC NULLS LAST, c@3 ASC NULLS LAST] 02)--CoalesceBatchesExec: target_batch_size=8192 03)----FilterExec: a@1 = 0 04)------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -1314,7 +1481,7 @@ logical_plan 02)--Filter: annotated_data_finite2.a = Int32(0) AND annotated_data_finite2.b = Int32(0) 03)----TableScan: annotated_data_finite2 projection=[a0, a, b, c, d], partial_filters=[annotated_data_finite2.a = Int32(0), annotated_data_finite2.b = Int32(0)] physical_plan -01)SortPreservingMergeExec: [b@2 ASC NULLS LAST,c@3 ASC NULLS LAST] +01)SortPreservingMergeExec: [b@2 ASC NULLS LAST, c@3 ASC NULLS LAST] 02)--CoalesceBatchesExec: target_batch_size=8192 03)----FilterExec: a@1 = 0 AND b@2 = 0 04)------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -1335,7 +1502,7 @@ logical_plan 02)--Filter: annotated_data_finite2.a = Int32(0) AND annotated_data_finite2.b = Int32(0) 03)----TableScan: annotated_data_finite2 projection=[a0, a, b, c, d], partial_filters=[annotated_data_finite2.a = Int32(0), annotated_data_finite2.b = Int32(0)] physical_plan -01)SortPreservingMergeExec: [a@1 ASC NULLS LAST,b@2 ASC NULLS LAST,c@3 ASC NULLS LAST] +01)SortPreservingMergeExec: [a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST] 02)--CoalesceBatchesExec: target_batch_size=8192 03)----FilterExec: a@1 = 0 AND b@2 = 0 04)------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -1357,7 +1524,7 @@ logical_plan 03)----TableScan: annotated_data_finite2 projection=[a0, a, b, c, d], partial_filters=[annotated_data_finite2.a = Int32(0) OR annotated_data_finite2.b = Int32(0)] physical_plan 01)SortPreservingMergeExec: [c@3 ASC NULLS LAST] -02)--SortExec: expr=[c@3 ASC NULLS LAST] +02)--SortExec: expr=[c@3 ASC NULLS LAST], preserve_partitioning=[true] 03)----CoalesceBatchesExec: target_batch_size=8192 04)------FilterExec: a@1 = 0 OR b@2 = 0 05)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -1369,22 +1536,25 @@ query TT EXPLAIN SELECT c2, COUNT(*) FROM (SELECT c2 FROM aggregate_test_100 -ORDER BY c1, c2) +ORDER BY c1, c2 +LIMIT 4) GROUP BY c2; ---- logical_plan -01)Aggregate: groupBy=[[aggregate_test_100.c2]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +01)Aggregate: groupBy=[[aggregate_test_100.c2]], aggr=[[count(Int64(1)) AS count(*)]] 02)--Projection: aggregate_test_100.c2 -03)----Sort: aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST +03)----Sort: aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST, fetch=4 04)------Projection: aggregate_test_100.c2, aggregate_test_100.c1 05)--------TableScan: aggregate_test_100 projection=[c1, c2] physical_plan -01)AggregateExec: mode=FinalPartitioned, gby=[c2@0 as c2], aggr=[COUNT(*)] +01)AggregateExec: mode=FinalPartitioned, gby=[c2@0 as c2], aggr=[count(*)] 02)--CoalesceBatchesExec: target_batch_size=8192 03)----RepartitionExec: partitioning=Hash([c2@0], 2), input_partitions=2 -04)------AggregateExec: mode=Partial, gby=[c2@0 as c2], aggr=[COUNT(*)] +04)------AggregateExec: mode=Partial, gby=[c2@0 as c2], aggr=[count(*)] 05)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2], has_header=true +06)----------ProjectionExec: expr=[c2@0 as c2] +07)------------SortExec: TopK(fetch=4), expr=[c1@1 ASC NULLS LAST, c2@0 ASC NULLS LAST], preserve_partitioning=[false] +08)--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c1], has_header=true # FilterExec can track equality of non-column expressions. # plan below shouldn't have a SortExec because given column 'a' is ordered. @@ -1426,12 +1596,12 @@ query TT EXPLAIN SELECT x/2, x/2+1 FROM t; ---- logical_plan -01)Projection: t.x / Int64(2)Int64(2)t.x AS t.x / Int64(2), t.x / Int64(2)Int64(2)t.x AS t.x / Int64(2) + Int64(1) -02)--Projection: t.x / Int64(2) AS t.x / Int64(2)Int64(2)t.x +01)Projection: __common_expr_1 AS t.x / Int64(2), __common_expr_1 AS t.x / Int64(2) + Int64(1) +02)--Projection: t.x / Int64(2) AS __common_expr_1 03)----TableScan: t projection=[x] physical_plan -01)ProjectionExec: expr=[t.x / Int64(2)Int64(2)t.x@0 as t.x / Int64(2), t.x / Int64(2)Int64(2)t.x@0 + 1 as t.x / Int64(2) + Int64(1)] -02)--ProjectionExec: expr=[x@0 / 2 as t.x / Int64(2)Int64(2)t.x] +01)ProjectionExec: expr=[__common_expr_1@0 as t.x / Int64(2), __common_expr_1@0 + 1 as t.x / Int64(2) + Int64(1)] +02)--ProjectionExec: expr=[x@0 / 2 as __common_expr_1] 03)----MemoryExec: partitions=1, partition_sizes=[1] query II @@ -1444,12 +1614,12 @@ query TT EXPLAIN SELECT abs(x), abs(x) + abs(y) FROM t; ---- logical_plan -01)Projection: abs(t.x)t.x AS abs(t.x), abs(t.x)t.x AS abs(t.x) + abs(t.y) -02)--Projection: abs(t.x) AS abs(t.x)t.x, t.y +01)Projection: __common_expr_1 AS abs(t.x), __common_expr_1 AS abs(t.x) + abs(t.y) +02)--Projection: abs(t.x) AS __common_expr_1, t.y 03)----TableScan: t projection=[x, y] physical_plan -01)ProjectionExec: expr=[abs(t.x)t.x@0 as abs(t.x), abs(t.x)t.x@0 + abs(y@1) as abs(t.x) + abs(t.y)] -02)--ProjectionExec: expr=[abs(x@0) as abs(t.x)t.x, y@1 as y] +01)ProjectionExec: expr=[__common_expr_1@0 as abs(t.x), __common_expr_1@0 + abs(y@1) as abs(t.x) + abs(t.y)] +02)--ProjectionExec: expr=[abs(x@0) as __common_expr_1, y@1 as y] 03)----MemoryExec: partitions=1, partition_sizes=[1] query II @@ -1494,21 +1664,25 @@ query TT EXPLAIN SELECT y > 0 and 1 / y < 1, x > 0 and y > 0 and 1 / y < 1 / x from t; ---- logical_plan -01)Projection: t.y > Int32(0) AND Int64(1) / CAST(t.y AS Int64) < Int64(1) AS t.y > Int64(0) AND Int64(1) / t.y < Int64(1), t.x > Int32(0) AND t.y > Int32(0) AND Int64(1) / CAST(t.y AS Int64) < Int64(1) / CAST(t.x AS Int64) AS t.x > Int64(0) AND t.y > Int64(0) AND Int64(1) / t.y < Int64(1) / t.x -02)--TableScan: t projection=[x, y] +01)Projection: __common_expr_1 AND Int64(1) / CAST(t.y AS Int64) < Int64(1) AS t.y > Int64(0) AND Int64(1) / t.y < Int64(1), t.x > Int32(0) AND __common_expr_1 AND Int64(1) / CAST(t.y AS Int64) < Int64(1) / CAST(t.x AS Int64) AS t.x > Int64(0) AND t.y > Int64(0) AND Int64(1) / t.y < Int64(1) / t.x +02)--Projection: t.y > Int32(0) AS __common_expr_1, t.x, t.y +03)----TableScan: t projection=[x, y] physical_plan -01)ProjectionExec: expr=[y@1 > 0 AND 1 / CAST(y@1 AS Int64) < 1 as t.y > Int64(0) AND Int64(1) / t.y < Int64(1), x@0 > 0 AND y@1 > 0 AND 1 / CAST(y@1 AS Int64) < 1 / CAST(x@0 AS Int64) as t.x > Int64(0) AND t.y > Int64(0) AND Int64(1) / t.y < Int64(1) / t.x] -02)--MemoryExec: partitions=1, partition_sizes=[1] +01)ProjectionExec: expr=[__common_expr_1@0 AND 1 / CAST(y@2 AS Int64) < 1 as t.y > Int64(0) AND Int64(1) / t.y < Int64(1), x@1 > 0 AND __common_expr_1@0 AND 1 / CAST(y@2 AS Int64) < 1 / CAST(x@1 AS Int64) as t.x > Int64(0) AND t.y > Int64(0) AND Int64(1) / t.y < Int64(1) / t.x] +02)--ProjectionExec: expr=[y@1 > 0 as __common_expr_1, x@0 as x, y@1 as y] +03)----MemoryExec: partitions=1, partition_sizes=[1] query TT EXPLAIN SELECT y = 0 or 1 / y < 1, x = 0 or y = 0 or 1 / y < 1 / x from t; ---- logical_plan -01)Projection: t.y = Int32(0) OR Int64(1) / CAST(t.y AS Int64) < Int64(1) AS t.y = Int64(0) OR Int64(1) / t.y < Int64(1), t.x = Int32(0) OR t.y = Int32(0) OR Int64(1) / CAST(t.y AS Int64) < Int64(1) / CAST(t.x AS Int64) AS t.x = Int64(0) OR t.y = Int64(0) OR Int64(1) / t.y < Int64(1) / t.x -02)--TableScan: t projection=[x, y] +01)Projection: __common_expr_1 OR Int64(1) / CAST(t.y AS Int64) < Int64(1) AS t.y = Int64(0) OR Int64(1) / t.y < Int64(1), t.x = Int32(0) OR __common_expr_1 OR Int64(1) / CAST(t.y AS Int64) < Int64(1) / CAST(t.x AS Int64) AS t.x = Int64(0) OR t.y = Int64(0) OR Int64(1) / t.y < Int64(1) / t.x +02)--Projection: t.y = Int32(0) AS __common_expr_1, t.x, t.y +03)----TableScan: t projection=[x, y] physical_plan -01)ProjectionExec: expr=[y@1 = 0 OR 1 / CAST(y@1 AS Int64) < 1 as t.y = Int64(0) OR Int64(1) / t.y < Int64(1), x@0 = 0 OR y@1 = 0 OR 1 / CAST(y@1 AS Int64) < 1 / CAST(x@0 AS Int64) as t.x = Int64(0) OR t.y = Int64(0) OR Int64(1) / t.y < Int64(1) / t.x] -02)--MemoryExec: partitions=1, partition_sizes=[1] +01)ProjectionExec: expr=[__common_expr_1@0 OR 1 / CAST(y@2 AS Int64) < 1 as t.y = Int64(0) OR Int64(1) / t.y < Int64(1), x@1 = 0 OR __common_expr_1@0 OR 1 / CAST(y@2 AS Int64) < 1 / CAST(x@1 AS Int64) as t.x = Int64(0) OR t.y = Int64(0) OR Int64(1) / t.y < Int64(1) / t.x] +02)--ProjectionExec: expr=[y@1 = 0 as __common_expr_1, x@0 as x, y@1 as y] +03)----MemoryExec: partitions=1, partition_sizes=[1] # due to the reason describe in https://github.com/apache/datafusion/issues/8927, # the following queries will fail @@ -1587,6 +1761,9 @@ SELECT i + i FROM test WHERE i > 2; ---- 6 +statement ok +DROP TABLE test; + query error DataFusion error: Arrow error: Parser error: Error parsing timestamp from 'I AM NOT A TIMESTAMP': error parsing date SELECT to_timestamp('I AM NOT A TIMESTAMP'); @@ -1613,6 +1790,23 @@ select count(1) from v; ---- 1 +# Ensure CSE resolves columns correctly +# should be 3, not 1 +# https://github.com/apache/datafusion/issues/10413 +query I +select a + b from (select 1 as a, 2 as b, 1 as "a + b"); +---- +3 + +# Can't reference an output column by expression over projection. +query error DataFusion error: Schema error: No field named a\. Valid fields are "a \+ Int64\(1\)"\. +select a + 1 from (select a+1 from (select 1 as a)); + +query I +select "a + Int64(1)" + 10 from (select a+1 from (select 1 as a)); +---- +12 + # run below query without logical optimizations statement ok set datafusion.optimizer.max_passes=0; @@ -1626,3 +1820,19 @@ select a from t; statement ok set datafusion.optimizer.max_passes=3; + +# Test issue: https://github.com/apache/datafusion/issues/12183 +statement ok +CREATE TABLE test(a BIGINT) AS VALUES (1); + +query I +SELECT "test.a" FROM (SELECT a AS "test.a" FROM test) +---- +1 + +statement ok +DROP TABLE test; + +# Can't reference an unqualified column by a qualified name +query error DataFusion error: Schema error: No field named t1\.v1\. Valid fields are "t1\.v1"\. +SELECT t1.v1 FROM (SELECT 1 AS "t1.v1"); diff --git a/datafusion/sqllogictest/test_files/set_variable.slt b/datafusion/sqllogictest/test_files/set_variable.slt index fccd144a37fb..6f19c9f4d42f 100644 --- a/datafusion/sqllogictest/test_files/set_variable.slt +++ b/datafusion/sqllogictest/test_files/set_variable.slt @@ -216,19 +216,19 @@ set datafusion.catalog.information_schema = true statement ok SET TIME ZONE = '+08:00:00' -statement error Arrow error: Parser error: Invalid timezone "\+08:00:00": '\+08:00:00' is not a valid timezone +statement error Arrow error: Parser error: Invalid timezone "\+08:00:00": failed to parse timezone SELECT '2000-01-01T00:00:00'::TIMESTAMP::TIMESTAMPTZ statement ok SET TIME ZONE = '08:00' -statement error Arrow error: Parser error: Invalid timezone "08:00": '08:00' is not a valid timezone +statement error Arrow error: Parser error: Invalid timezone "08:00": failed to parse timezone SELECT '2000-01-01T00:00:00'::TIMESTAMP::TIMESTAMPTZ statement ok SET TIME ZONE = '08' -statement error Arrow error: Parser error: Invalid timezone "08": '08' is not a valid timezone +statement error Arrow error: Parser error: Invalid timezone "08": failed to parse timezone SELECT '2000-01-01T00:00:00'::TIMESTAMP::TIMESTAMPTZ statement ok @@ -242,5 +242,5 @@ SELECT '2000-01-01T00:00:00'::TIMESTAMP::TIMESTAMPTZ statement ok SET TIME ZONE = 'Asia/Taipei2' -statement error Arrow error: Parser error: Invalid timezone "Asia/Taipei2": 'Asia/Taipei2' is not a valid timezone +statement error Arrow error: Parser error: Invalid timezone "Asia/Taipei2": failed to parse timezone SELECT '2000-01-01T00:00:00'::TIMESTAMP::TIMESTAMPTZ diff --git a/datafusion/sqllogictest/test_files/sort_merge_join.slt b/datafusion/sqllogictest/test_files/sort_merge_join.slt index 0361d4e5af7f..f4cc888d6b8e 100644 --- a/datafusion/sqllogictest/test_files/sort_merge_join.slt +++ b/datafusion/sqllogictest/test_files/sort_merge_join.slt @@ -38,14 +38,10 @@ logical_plan 03)--TableScan: t2 projection=[a, b] physical_plan 01)SortMergeJoin: join_type=Inner, on=[(a@0, a@0)], filter=CAST(b@1 AS Int64) * 50 <= CAST(b@0 AS Int64) -02)--SortExec: expr=[a@0 ASC] -03)----CoalesceBatchesExec: target_batch_size=8192 -04)------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 -05)--------MemoryExec: partitions=1, partition_sizes=[1] -06)--SortExec: expr=[a@0 ASC] -07)----CoalesceBatchesExec: target_batch_size=8192 -08)------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 -09)--------MemoryExec: partitions=1, partition_sizes=[1] +02)--SortExec: expr=[a@0 ASC], preserve_partitioning=[false] +03)----MemoryExec: partitions=1, partition_sizes=[1] +04)--SortExec: expr=[a@0 ASC], preserve_partitioning=[false] +05)----MemoryExec: partitions=1, partition_sizes=[1] # inner join with join filter query TITI rowsort @@ -84,7 +80,6 @@ SELECT * FROM t1 LEFT JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b Alice 100 Alice 1 Alice 100 Alice 2 Alice 50 Alice 1 -Alice 50 NULL NULL Bob 1 NULL NULL query TITI rowsort @@ -112,7 +107,6 @@ SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a AND t2.b * 50 <= t1.b Alice 100 Alice 1 Alice 100 Alice 2 Alice 50 Alice 1 -NULL NULL Alice 2 query TITI rowsort SELECT * FROM t1 RIGHT JOIN t2 ON t1.a = t2.a AND t1.b > t2.b @@ -132,29 +126,24 @@ Alice 50 Alice 1 Alice 50 Alice 2 Bob 1 NULL NULL +# Uncomment when filtered FULL moved # full join with join filter -query TITI rowsort -SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t2.b * 50 > t1.b ----- -Alice 100 NULL NULL -Alice 100 NULL NULL -Alice 50 Alice 2 -Alice 50 NULL NULL -Bob 1 NULL NULL -NULL NULL Alice 1 -NULL NULL Alice 1 -NULL NULL Alice 2 - -query TITI rowsort -SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t1.b > t2.b + 50 ----- -Alice 100 Alice 1 -Alice 100 Alice 2 -Alice 50 NULL NULL -Alice 50 NULL NULL -Bob 1 NULL NULL -NULL NULL Alice 1 -NULL NULL Alice 2 +#query TITI rowsort +#SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t2.b * 50 > t1.b +#---- +#Alice 100 NULL NULL +#Alice 50 Alice 2 +#Bob 1 NULL NULL +#NULL NULL Alice 1 + +# Uncomment when filtered FULL moved +#query TITI rowsort +#SELECT * FROM t1 FULL JOIN t2 ON t1.a = t2.a AND t1.b > t2.b + 50 +#---- +#Alice 100 Alice 1 +#Alice 100 Alice 2 +#Alice 50 NULL NULL +#Bob 1 NULL NULL statement ok DROP TABLE t1; @@ -263,5 +252,407 @@ DROP TABLE t1; statement ok DROP TABLE t2; +# LEFTSEMI join tests + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 13 b + ) + select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b = t1.b) +) order by 1, 2 +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 13 b + ) + select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b) +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select null a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 13 b + ) + select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b) +) order by 1, 2; +---- +11 13 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b) + select t1.* from t1 where exists (select 1 from t1 t2 where t2.a = t1.a and t2.b = t1.b) +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b) + select t1.* from t1 where exists (select 1 from t1 t2 where t2.a = t1.a and t2.b != t1.b) +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select null a, 12 b union all + select 11 a, 13 b) + select t1.* from t1 where exists (select 1 from t1 t2 where t2.a = t1.a and t2.b != t1.b) +) order by 1, 2; +---- + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 14 b + ) +select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b) +) order by 1, 2; +---- +11 12 +11 13 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 12 b union all + select 11 a, 14 b + ) +select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b) +) order by 1, 2; +---- +11 12 +11 13 + +# Set batch size to 1 for sort merge join to test scenario when data spread across multiple batches +statement ok +set datafusion.execution.batch_size = 1; + +query II +SELECT * FROM ( + WITH + t1 AS ( + SELECT 12 a, 12 b + ), + t2 AS ( + SELECT 12 a, 12 b + ) + SELECT t1.* FROM t1 JOIN t2 on t1.a = t2.b WHERE t1.a > t2.b +) ORDER BY 1, 2; +---- + + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b union all + select 11 a, 13 b), +t2 as ( + select 11 a, 12 b union all + select 11 a, 12 b union all + select 11 a, 14 b + ) +select t1.* from t1 where exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b) +) order by 1, 2; +---- +11 12 +11 13 + +#LEFTANTI tests +statement ok +set datafusion.execution.batch_size = 10; + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 13 c union all + select 11 a, 14 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- +11 12 + +query III +select * from ( +with +t1 as ( + select 11 a, 12 b, 1 c union all + select 11 a, 13 b, 2 c), +t2 as ( + select 11 a, 12 b, 3 c union all + select 11 a, 14 b, 4 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +) order by 1, 2; +---- +11 12 1 +11 13 2 + +query III +select * from ( +with +t1 as ( + select 11 a, 12 b, 1 c union all + select 11 a, 13 b, 2 c), +t2 as ( + select 11 a, 12 b, 3 c where false + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +) order by 1, 2; +---- +11 12 1 +11 13 2 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 13 c union all + select 11 a, 14 c union all + select 11 a, 15 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- +11 12 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 11 c union all + select 11 a, 14 c union all + select 11 a, 15 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 12 c union all + select 11 a, 11 c union all + select 11 a, 15 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- + + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 12 c union all + select 11 a, 14 c union all + select 11 a, 11 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- + + +# Test LEFT ANTI with cross batch data distribution +statement ok +set datafusion.execution.batch_size = 1; + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 13 c union all + select 11 a, 14 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- +11 12 + +query III +select * from ( +with +t1 as ( + select 11 a, 12 b, 1 c union all + select 11 a, 13 b, 2 c), +t2 as ( + select 11 a, 12 b, 3 c union all + select 11 a, 14 b, 4 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +) order by 1, 2; +---- +11 12 1 +11 13 2 + +query III +select * from ( +with +t1 as ( + select 11 a, 12 b, 1 c union all + select 11 a, 13 b, 2 c), +t2 as ( + select 11 a, 12 b, 3 c where false + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t2.b != t1.b and t1.c > t2.c) +) order by 1, 2; +---- +11 12 1 +11 13 2 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 13 c union all + select 11 a, 14 c union all + select 11 a, 15 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- +11 12 + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 12 c union all + select 11 a, 11 c union all + select 11 a, 15 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- + +query II +select * from ( +with +t1 as ( + select 11 a, 12 b), +t2 as ( + select 11 a, 12 c union all + select 11 a, 14 c union all + select 11 a, 11 c + ) +select t1.* from t1 where not exists (select 1 from t2 where t2.a = t1.a and t1.b > t2.c) +) order by 1, 2 +---- + +query IIII +select * from ( +with t as ( + select id, id % 5 id1 from (select unnest(range(0,10)) id) +), t1 as ( + select id % 10 id, id + 2 id1 from (select unnest(range(0,10)) id) +) +select * from t right join t1 on t.id1 = t1.id and t.id > t1.id1 +) order by 1, 2, 3, 4 +---- +5 0 0 2 +6 1 1 3 +7 2 2 4 +8 3 3 5 +9 4 4 6 +NULL NULL 5 7 +NULL NULL 6 8 +NULL NULL 7 9 +NULL NULL 8 10 +NULL NULL 9 11 + +query IIII +select * from ( +with t as ( + select id_a id_a_1, id_a % 5 id_a_2 from (select unnest(make_array(5, 6, 7, 8, 9, 0, 1, 2, 3, 4)) id_a) +), t1 as ( + select id_b % 10 id_b_1, id_b + 2 id_b_2 from (select unnest(make_array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)) id_b) +) +select * from t full join t1 on t.id_a_2 = t1.id_b_1 and t.id_a_1 > t1.id_b_2 +) order by 1, 2, 3, 4 +---- +0 0 NULL NULL +1 1 NULL NULL +2 2 NULL NULL +3 3 NULL NULL +4 4 NULL NULL +5 0 0 2 +6 1 1 3 +7 2 2 4 +8 3 3 5 +9 4 4 6 +NULL NULL 5 7 +NULL NULL 6 8 +NULL NULL 7 9 +NULL NULL 8 10 +NULL NULL 9 11 + +# return sql params back to default values statement ok set datafusion.optimizer.prefer_hash_join = true; + +statement ok +set datafusion.execution.batch_size = 8192; diff --git a/datafusion/sqllogictest/test_files/string/README.md b/datafusion/sqllogictest/test_files/string/README.md new file mode 100644 index 000000000000..8693ef16f9d7 --- /dev/null +++ b/datafusion/sqllogictest/test_files/string/README.md @@ -0,0 +1,44 @@ + + +# String Test Files + +This directory contains test files for the `string` test suite. +To ensure consistent behavior across different string types, we should run the same tests with the same inputs on all string types. +There is a framework in place to execute the same tests across different string types. + +See [#12415](https://github.com/apache/datafusion/issues/12415) for more background. + +## Directory Structure + +``` +string/ + - init_data.slt.part // generate the testing data + - string_query.slt.part // the sharing tests for all string type + - string.slt // the entrypoint for string type + - large_string.slt // the entrypoint for large_string type + - string_view.slt // the entrypoint for string_view type and the string_view specific tests + - string_literal.slt // the tests for string literal +``` + +## Pattern for Test Entry Point Files + +Any entry point file should include `init_data.slt.part` and `string_query.slt.part`. + +Planning-related tests (e.g., EXPLAIN ...) should be placed in their own entry point file (e.g., `string_view.slt`) as they are only used to assert planning behavior specific to that type. diff --git a/datafusion/sqllogictest/test_files/string/dictionary_utf8.slt b/datafusion/sqllogictest/test_files/string/dictionary_utf8.slt new file mode 100644 index 000000000000..c43f3a4cc16b --- /dev/null +++ b/datafusion/sqllogictest/test_files/string/dictionary_utf8.slt @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include ./init_data.slt.part + +# -------------------------------------- +# Setup test tables with different physical string types +# and repeat tests in `string_query.slt.part` +# -------------------------------------- +statement ok +create table test_basic_operator as +select + arrow_cast(column1, 'Dictionary(Int32, Utf8)') as ascii_1, + arrow_cast(column2, 'Dictionary(Int32, Utf8)') as ascii_2, + arrow_cast(column3, 'Dictionary(Int32, Utf8)') as unicode_1, + arrow_cast(column4, 'Dictionary(Int32, Utf8)') as unicode_2 +from test_source; + +statement ok +create table test_substr as +select arrow_cast(col1, 'Dictionary(Int32, Utf8)') as c1 from test_substr_base; + +statement ok +drop table test_source + +query T +SELECT arrow_cast('', 'Dictionary(Int32, Utf8)'); +---- +(empty) + +# TODO: move it back to `string_query.slt.part` after fixing the issue +# see detail: https://github.com/apache/datafusion/issues/12637 +# Test pattern with wildcard characters +query TTBBBB +select ascii_1, unicode_1, + ascii_1 like 'An%' as ascii_like, + unicode_1 like '%ion数据%' as unicode_like, + ascii_1 ilike 'An%' as ascii_ilike, + unicode_1 ilike '%ion数据%' as unicode_ilik +from test_basic_operator; +---- +Andrew datafusion📊🔥 true false true false +Xiangpeng datafusion数据融合 false true false true +Raphael datafusionДатаФусион false false false false +under_score un iść core false false false false +percent pan Tadeusz ma iść w kąt false false false false +(empty) (empty) false false false false +NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL NULL + +# +# common test for string-like functions and operators +# +include ./string_query.slt.part + +# +# Clean up +# +statement ok +drop table test_basic_operator; + +statement ok +drop table test_substr_base; diff --git a/datafusion/sqllogictest/test_files/string/init_data.slt.part b/datafusion/sqllogictest/test_files/string/init_data.slt.part new file mode 100644 index 000000000000..e3914ea49855 --- /dev/null +++ b/datafusion/sqllogictest/test_files/string/init_data.slt.part @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +create table test_source as values + ('Andrew', 'X', 'datafusion📊🔥', '🔥'), + ('Xiangpeng', 'Xiangpeng', 'datafusion数据融合', 'datafusion数据融合'), + ('Raphael', 'R', 'datafusionДатаФусион', 'аФус'), + ('under_score', 'un_____core', 'un iść core', 'chrząszcz na łące w 東京都'), + ('percent', 'p%t', 'pan Tadeusz ma iść w kąt', 'Pan Tadeusz ma frunąć stąd w kąt'), + ('', '%', '', ''), + (NULL, '%', NULL, NULL), + (NULL, 'R', NULL, '🔥'); + +# -------------------------------------- +# Setup test tables with different physical string types (Utf8/Utf8View/LargeUtf8) +# and repeat tests in `substr_table.slt.part` +# -------------------------------------- +statement ok +create table test_substr_base ( + col1 VARCHAR +) as values ('foo'), ('hello🌏世界'), ('💩'), ('ThisIsAVeryLongASCIIString'), (''), (NULL); diff --git a/datafusion/sqllogictest/test_files/string/large_string.slt b/datafusion/sqllogictest/test_files/string/large_string.slt new file mode 100644 index 000000000000..1cf906d7dc75 --- /dev/null +++ b/datafusion/sqllogictest/test_files/string/large_string.slt @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include ./init_data.slt.part + +# -------------------------------------- +# Setup test tables with different physical string types +# and repeat tests in `string_query.slt.part` +# -------------------------------------- +statement ok +create table test_basic_operator as +select + arrow_cast(column1, 'LargeUtf8') as ascii_1, + arrow_cast(column2, 'LargeUtf8') as ascii_2, + arrow_cast(column3, 'LargeUtf8') as unicode_1, + arrow_cast(column4, 'LargeUtf8') as unicode_2 +from test_source; + +statement ok +create table test_substr as +select arrow_cast(col1, 'LargeUtf8') as c1 from test_substr_base; + +# select +query TTTT +SELECT ascii_1, ascii_2, unicode_1, unicode_2 FROM test_basic_operator +---- +Andrew X datafusion📊🔥 🔥 +Xiangpeng Xiangpeng datafusion数据融合 datafusion数据融合 +Raphael R datafusionДатаФусион аФус +under_score un_____core un iść core chrząszcz na łące w 東京都 +percent p%t pan Tadeusz ma iść w kąt Pan Tadeusz ma frunąć stąd w kąt +(empty) % (empty) (empty) +NULL % NULL NULL +NULL R NULL 🔥 + +# TODO: move it back to `string_query.slt.part` after fixing the issue +# see detail: https://github.com/apache/datafusion/issues/12637 +# Test pattern with wildcard characters +query TTBBBB +select ascii_1, unicode_1, + ascii_1 like 'An%' as ascii_like, + unicode_1 like '%ion数据%' as unicode_like, + ascii_1 ilike 'An%' as ascii_ilike, + unicode_1 ilike '%ion数据%' as unicode_ilik +from test_basic_operator; +---- +Andrew datafusion📊🔥 true false true false +Xiangpeng datafusion数据融合 false true false true +Raphael datafusionДатаФусион false false false false +under_score un iść core false false false false +percent pan Tadeusz ma iść w kąt false false false false +(empty) (empty) false false false false +NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL NULL + +# +# common test for string-like functions and operators +# +include ./string_query.slt.part + +# +# Clean up +# + +statement ok +drop table test_basic_operator; + +statement ok +drop table test_substr_base; diff --git a/datafusion/sqllogictest/test_files/string/string.slt b/datafusion/sqllogictest/test_files/string/string.slt new file mode 100644 index 000000000000..9e97712b6871 --- /dev/null +++ b/datafusion/sqllogictest/test_files/string/string.slt @@ -0,0 +1,180 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include ./init_data.slt.part + +# -------------------------------------- +# Setup test tables with different physical string types +# and repeat tests in `string_query.slt.part` +# -------------------------------------- +statement ok +create table test_basic_operator as +select + arrow_cast(column1, 'Utf8') as ascii_1, + arrow_cast(column2, 'Utf8') as ascii_2, + arrow_cast(column3, 'Utf8') as unicode_1, + arrow_cast(column4, 'Utf8') as unicode_2 +from test_source; + +statement ok +create table test_substr as +select arrow_cast(col1, 'Utf8') as c1 from test_substr_base; + +# TODO: move it back to `string_query.slt.part` after fixing the issue +# see detail: https://github.com/apache/datafusion/issues/12637 +# Test pattern with wildcard characters +query TTBBBB +select ascii_1, unicode_1, + ascii_1 like 'An%' as ascii_like, + unicode_1 like '%ion数据%' as unicode_like, + ascii_1 ilike 'An%' as ascii_ilike, + unicode_1 ilike '%ion数据%' as unicode_ilik +from test_basic_operator; +---- +Andrew datafusion📊🔥 true false true false +Xiangpeng datafusion数据融合 false true false true +Raphael datafusionДатаФусион false false false false +under_score un iść core false false false false +percent pan Tadeusz ma iść w kąt false false false false +(empty) (empty) false false false false +NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL NULL + +# +# common test for string-like functions and operators +# +include ./string_query.slt.part + +# TODO support all String types in sql_like_to_expr and move this test to `string_query.slt.part` +# dynamic LIKE as filter +query TTT rowsort +SELECT ascii_1, 'is LIKE', ascii_2 FROM test_basic_operator WHERE ascii_1 LIKE ascii_2 +UNION ALL +SELECT ascii_1, 'is NOT LIKE', ascii_2 FROM test_basic_operator WHERE ascii_1 NOT LIKE ascii_2 +UNION ALL +SELECT unicode_1, 'is LIKE', ascii_2 FROM test_basic_operator WHERE unicode_1 LIKE ascii_2 +UNION ALL +SELECT unicode_1, 'is NOT LIKE', ascii_2 FROM test_basic_operator WHERE unicode_1 NOT LIKE ascii_2 +UNION ALL +SELECT unicode_2, 'is LIKE', ascii_2 FROM test_basic_operator WHERE unicode_2 LIKE ascii_2 +UNION ALL +SELECT unicode_2, 'is NOT LIKE', ascii_2 FROM test_basic_operator WHERE unicode_2 NOT LIKE ascii_2 +---- +(empty) is LIKE % +(empty) is LIKE % +(empty) is LIKE % +Andrew is NOT LIKE X +Pan Tadeusz ma frunąć stąd w kąt is NOT LIKE p%t +Raphael is NOT LIKE R +Xiangpeng is LIKE Xiangpeng +chrząszcz na łące w 東京都 is NOT LIKE un_____core +datafusionДатаФусион is NOT LIKE R +datafusion数据融合 is NOT LIKE Xiangpeng +datafusion数据融合 is NOT LIKE Xiangpeng +datafusion📊🔥 is NOT LIKE X +pan Tadeusz ma iść w kąt is LIKE p%t +percent is LIKE p%t +un iść core is LIKE un_____core +under_score is LIKE un_____core +аФус is NOT LIKE R +🔥 is NOT LIKE R +🔥 is NOT LIKE X + +# TODO support all String types in sql_like_to_expr and move this test to `string_query.slt.part` +# dynamic LIKE as projection +query TTTTBBBB rowsort +SELECT + ascii_1, ascii_2, unicode_1, unicode_2, + (ascii_1 LIKE ascii_2) AS ascii_1_like_ascii_2, + (ascii_2 LIKE ascii_1) AS ascii_2_like_ascii_1, + (unicode_1 LIKE ascii_2) AS unicode_1_like_ascii_2, + (unicode_2 LIKE ascii_2) AS unicode_2_like_ascii_2 +FROM test_basic_operator +---- +(empty) % (empty) (empty) true false true true +Andrew X datafusion📊🔥 🔥 false false false false +NULL % NULL NULL NULL NULL NULL NULL +NULL R NULL 🔥 NULL NULL NULL false +Raphael R datafusionДатаФусион аФус false false false false +Xiangpeng Xiangpeng datafusion数据融合 datafusion数据融合 true true false false +percent p%t pan Tadeusz ma iść w kąt Pan Tadeusz ma frunąć stąd w kąt true false true false +under_score un_____core un iść core chrząszcz na łące w 東京都 true false true false + +# TODO support all String types in sql_like_to_expr and move this test to `string_query.slt.part` +# dynamic ILIKE as filter +query TTT rowsort +SELECT ascii_1, 'is ILIKE', ascii_2 FROM test_basic_operator WHERE ascii_1 ILIKE ascii_2 +UNION ALL +SELECT ascii_1, 'is NOT ILIKE', ascii_2 FROM test_basic_operator WHERE ascii_1 NOT ILIKE ascii_2 +UNION ALL +SELECT unicode_1, 'is ILIKE', ascii_2 FROM test_basic_operator WHERE unicode_1 ILIKE ascii_2 +UNION ALL +SELECT unicode_1, 'is NOT ILIKE', ascii_2 FROM test_basic_operator WHERE unicode_1 NOT ILIKE ascii_2 +UNION ALL +SELECT unicode_2, 'is ILIKE', ascii_2 FROM test_basic_operator WHERE unicode_2 ILIKE ascii_2 +UNION ALL +SELECT unicode_2, 'is NOT ILIKE', ascii_2 FROM test_basic_operator WHERE unicode_2 NOT ILIKE ascii_2 +---- +(empty) is ILIKE % +(empty) is ILIKE % +(empty) is ILIKE % +Andrew is NOT ILIKE X +Pan Tadeusz ma frunąć stąd w kąt is ILIKE p%t +Raphael is NOT ILIKE R +Xiangpeng is ILIKE Xiangpeng +chrząszcz na łące w 東京都 is NOT ILIKE un_____core +datafusionДатаФусион is NOT ILIKE R +datafusion数据融合 is NOT ILIKE Xiangpeng +datafusion数据融合 is NOT ILIKE Xiangpeng +datafusion📊🔥 is NOT ILIKE X +pan Tadeusz ma iść w kąt is ILIKE p%t +percent is ILIKE p%t +un iść core is ILIKE un_____core +under_score is ILIKE un_____core +аФус is NOT ILIKE R +🔥 is NOT ILIKE R +🔥 is NOT ILIKE X + +# TODO support all String types in sql_like_to_expr and move this test to `string_query.slt.part` +# dynamic ILIKE as projection +query TTTTBBBB rowsort +SELECT + ascii_1, ascii_2, unicode_1, unicode_2, + (ascii_1 ILIKE ascii_2) AS ascii_1_ilike_ascii_2, + (ascii_2 ILIKE ascii_1) AS ascii_2_ilike_ascii_1, + (unicode_1 ILIKE ascii_2) AS unicode_1_ilike_ascii_2, + (unicode_2 ILIKE ascii_2) AS unicode_2_ilike_ascii_2 +FROM test_basic_operator +---- +(empty) % (empty) (empty) true false true true +Andrew X datafusion📊🔥 🔥 false false false false +NULL % NULL NULL NULL NULL NULL NULL +NULL R NULL 🔥 NULL NULL NULL false +Raphael R datafusionДатаФусион аФус false false false false +Xiangpeng Xiangpeng datafusion数据融合 datafusion数据融合 true true false false +percent p%t pan Tadeusz ma iść w kąt Pan Tadeusz ma frunąć stąd w kąt true false true true +under_score un_____core un iść core chrząszcz na łące w 東京都 true false true false + +# +# Clean up +# + +statement ok +drop table test_basic_operator; + +statement ok +drop table test_substr; diff --git a/datafusion/sqllogictest/test_files/string/string_literal.slt b/datafusion/sqllogictest/test_files/string/string_literal.slt new file mode 100644 index 000000000000..80bd7fc59c00 --- /dev/null +++ b/datafusion/sqllogictest/test_files/string/string_literal.slt @@ -0,0 +1,823 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +query T +SELECT substr('alphabet', -3) +---- +alphabet + +query T +SELECT substr('alphabet', 0) +---- +alphabet + +query T +SELECT substr('alphabet', 1) +---- +alphabet + +query T +SELECT substr('alphabet', 2) +---- +lphabet + +query T +SELECT substr('alphabet', 3) +---- +phabet + +query T +SELECT substr('alphabet', 30) +---- +(empty) + +query T +SELECT substr('alphabet', 3, 2) +---- +ph + +query T +SELECT substr('alphabet', 3, 20) +---- +phabet + +query TT +select + substr(arrow_cast('alphabet', 'LargeUtf8'), 3, 20), + substr(arrow_cast('alphabet', 'Utf8View'), 3, 20); +---- +phabet phabet + +# test range ouside of string length +query TTTTTTTTTTTT +SELECT + substr('hi🌏', 1, 3), + substr('hi🌏', 1, 4), + substr('hi🌏', 1, 100), + substr('hi🌏', 0, 1), + substr('hi🌏', 0, 2), + substr('hi🌏', 0, 4), + substr('hi🌏', 0, 5), + substr('hi🌏', -10, 100), + substr('hi🌏', -10, 12), + substr('hi🌏', -10, 5), + substr('hi🌏', 10, 0), + substr('hi🌏', 10, 10); +---- +hi🌏 hi🌏 hi🌏 (empty) h hi🌏 hi🌏 hi🌏 h (empty) (empty) (empty) + +query TTTTTTTTTTTT +SELECT + substr('', 1, 3), + substr('', 1, 4), + substr('', 1, 100), + substr('', 0, 1), + substr('', 0, 2), + substr('', 0, 4), + substr('', 0, 5), + substr('', -10, 100), + substr('', -10, 12), + substr('', -10, 5), + substr('', 10, 0), + substr('', 10, 10); +---- +(empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty) + +# Nulls +query TTTTTTTTTT +SELECT + substr('alphabet', NULL), + substr(NULL, 1), + substr(NULL, NULL), + substr('alphabet', CAST(NULL AS int), -20), + substr('alphabet', 3, CAST(NULL AS int)), + substr(NULL, 3, -4), + substr(NULL, NULL, 4), + substr(NULL, 1, NULL), + substr('', NULL, NULL), + substr(NULL, NULL, NULL); +---- +NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL + +query T +SELECT substr('Hello🌏世界', 5) +---- +o🌏世界 + +query T +SELECT substr('Hello🌏世界', 5, 3) +---- +o🌏世 + +statement error The first argument of the substr function can only be a string, but got Int64 +SELECT substr(1, 3) + +statement error The first argument of the substr function can only be a string, but got Int64 +SELECT substr(1, 3, 4) + +statement error Execution error: negative substring length not allowed +select substr(arrow_cast('foo', 'Utf8View'), 1, -1); + +statement error Execution error: negative substring length not allowed +select substr('', 1, -1); + +# StringView scalar to StringView scalar + +query BBBB +select + arrow_cast('NULL', 'Utf8View') = arrow_cast('Andrew', 'Utf8View'), + arrow_cast('NULL', 'Utf8View') <> arrow_cast('Andrew', 'Utf8View'), + arrow_cast('Andrew', 'Utf8View') = arrow_cast('Andrew', 'Utf8View'), + arrow_cast('Xiangpeng', 'Utf8View') <> arrow_cast('Andrew', 'Utf8View'); +---- +false true true true + + +query II +SELECT + ASCII('hello'), + ASCII(arrow_cast('world', 'Utf8View')) +---- +104 119 + +query III +SELECT + ASCII(arrow_cast('äöüß', 'Utf8View')) as c1, + ASCII(arrow_cast('', 'Utf8View')) as c2, + ASCII(arrow_cast(NULL, 'Utf8View')) as c3 +---- +228 0 NULL + +# coercion from stringview to integer, as input to make_date +query D +select make_date(arrow_cast('2024', 'Utf8View'), arrow_cast('01', 'Utf8View'), arrow_cast('23', 'Utf8View')) +---- +2024-01-23 + +query I +SELECT character_length('') +---- +0 + +query I +SELECT character_length('chars') +---- +5 + +query I +SELECT character_length('josé') +---- +4 + +query I +SELECT character_length(NULL) +---- +NULL + +query B +SELECT ends_with('foobar', 'bar') +---- +true + +query B +SELECT ends_with('foobar', 'foo') +---- +false + +query I +SELECT levenshtein('kitten', 'sitting') +---- +3 + +query I +SELECT levenshtein('kitten', NULL) +---- +NULL + +query I +SELECT levenshtein(NULL, 'sitting') +---- +NULL + +query I +SELECT levenshtein(NULL, NULL) +---- +NULL + + +query T +SELECT lpad('hi', -1, 'xy') +---- +(empty) + +query T +SELECT lpad('hi', 5, 'xy') +---- +xyxhi + +query T +SELECT lpad('hi', -1) +---- +(empty) + +query T +SELECT lpad('hi', 0) +---- +(empty) + +query T +SELECT lpad('hi', 21, 'abcdef') +---- +abcdefabcdefabcdefahi + +query T +SELECT lpad('hi', 5, 'xy') +---- +xyxhi + +query T +SELECT lpad('hi', 5, NULL) +---- +NULL + +query T +SELECT lpad('hi', 5) +---- + hi + +query T +SELECT lpad('hi', CAST(NULL AS INT), 'xy') +---- +NULL + +query T +SELECT lpad('hi', CAST(NULL AS INT)) +---- +NULL + +query T +SELECT lpad('xyxhi', 3) +---- +xyx + +query T +SELECT lpad(NULL, 0) +---- +NULL + +query T +SELECT lpad(NULL, 5, 'xy') +---- +NULL + +query T +SELECT regexp_replace('foobar', 'bar', 'xx', 'gi') +---- +fooxx + +query T +SELECT regexp_replace(arrow_cast('foobar', 'Dictionary(Int32, Utf8)'), 'bar', 'xx', 'gi') +---- +fooxx + +query T +SELECT repeat('foo', 3) +---- +foofoofoo + +query T +SELECT repeat(arrow_cast('foo', 'Dictionary(Int32, Utf8)'), 3) +---- +foofoofoo + + +query T +SELECT replace('foobar', 'bar', 'hello') +---- +foohello + +query T +SELECT replace(arrow_cast('foobar', 'Dictionary(Int32, Utf8)'), 'bar', 'hello') +---- +foohello + +query T +SELECT replace(arrow_cast('foobar', 'Utf8View'), arrow_cast('bar', 'Utf8View'), arrow_cast('hello', 'Utf8View')) +---- +foohello + +query T +SELECT replace(arrow_cast('foobar', 'LargeUtf8'), arrow_cast('bar', 'LargeUtf8'), arrow_cast('hello', 'LargeUtf8')) +---- +foohello + + +query T +SELECT reverse('abcde') +---- +edcba + +query T +SELECT reverse(arrow_cast('abcde', 'LargeUtf8')) +---- +edcba + +query T +SELECT reverse(arrow_cast('abcde', 'Utf8View')) +---- +edcba + +query T +SELECT reverse(arrow_cast('abcde', 'Dictionary(Int32, Utf8)')) +---- +edcba + +query T +SELECT reverse('loẅks') +---- +sk̈wol + +query T +SELECT reverse(arrow_cast('loẅks', 'LargeUtf8')) +---- +sk̈wol + +query T +SELECT reverse(arrow_cast('loẅks', 'Utf8View')) +---- +sk̈wol + +query T +SELECT reverse(NULL) +---- +NULL + +query T +SELECT reverse(arrow_cast(NULL, 'LargeUtf8')) +---- +NULL + +query T +SELECT reverse(arrow_cast(NULL, 'Utf8View')) +---- +NULL + + +query I +SELECT strpos('abc', 'c') +---- +3 + +query I +SELECT strpos('josé', 'é') +---- +4 + +query I +SELECT strpos('joséésoj', 'so') +---- +6 + +query I +SELECT strpos('joséésoj', 'abc') +---- +0 + +query I +SELECT strpos(NULL, 'abc') +---- +NULL + +query I +SELECT strpos('joséésoj', NULL) +---- +NULL + + + +query T +SELECT rpad('hi', -1, 'xy') +---- +(empty) + +query T +SELECT rpad('hi', 5, 'xy') +---- +hixyx + +query T +SELECT rpad('hi', -1) +---- +(empty) + +query T +SELECT rpad('hi', 0) +---- +(empty) + +query T +SELECT rpad('hi', 21, 'abcdef') +---- +hiabcdefabcdefabcdefa + +query T +SELECT rpad('hi', 5, 'xy') +---- +hixyx + +query T +SELECT rpad(arrow_cast('hi', 'Dictionary(Int32, Utf8)'), 5, 'xy') +---- +hixyx + +query T +SELECT rpad('hi', 5, NULL) +---- +NULL + +query T +SELECT rpad('hi', 5) +---- +hi + +query T +SELECT rpad('hi', CAST(NULL AS INT), 'xy') +---- +NULL + +query T +SELECT rpad('hi', CAST(NULL AS INT)) +---- +NULL + +query T +SELECT rpad('xyxhi', 3) +---- +xyx + +# test for rpad with largeutf8 and utf8View + +query T +SELECT rpad(arrow_cast('hi', 'LargeUtf8'), 5, 'xy') +---- +hixyx + +query T +SELECT rpad(arrow_cast('hi', 'Utf8View'), 5, 'xy') +---- +hixyx + +query T +SELECT rpad(arrow_cast('hi', 'LargeUtf8'), 5, arrow_cast('xy', 'LargeUtf8')) +---- +hixyx + +query T +SELECT rpad(arrow_cast('hi', 'Utf8View'), 5, arrow_cast('xy', 'Utf8View')) +---- +hixyx + +query T +SELECT rpad(arrow_cast(NULL, 'Utf8View'), 5, 'xy') +---- +NULL + +query I +SELECT char_length('') +---- +0 + +query I +SELECT char_length('chars') +---- +5 + +query I +SELECT char_length('josé') +---- +4 + +query I +SELECT char_length(NULL) +---- +NULL + +# Test substring_index using '.' as delimiter +# This query is compatible with MySQL(8.0.19 or later), convenient for comparing results +query TIT +SELECT str, n, substring_index(str, '.', n) AS c FROM + (VALUES + ROW('arrow.apache.org'), + ROW('.'), + ROW('...'), + ROW(NULL) + ) AS strings(str), + (VALUES + ROW(1), + ROW(2), + ROW(3), + ROW(100), + ROW(-1), + ROW(-2), + ROW(-3), + ROW(-100) + ) AS occurrences(n) +ORDER BY str DESC, n; +---- +NULL -100 NULL +NULL -3 NULL +NULL -2 NULL +NULL -1 NULL +NULL 1 NULL +NULL 2 NULL +NULL 3 NULL +NULL 100 NULL +arrow.apache.org -100 arrow.apache.org +arrow.apache.org -3 arrow.apache.org +arrow.apache.org -2 apache.org +arrow.apache.org -1 org +arrow.apache.org 1 arrow +arrow.apache.org 2 arrow.apache +arrow.apache.org 3 arrow.apache.org +arrow.apache.org 100 arrow.apache.org +... -100 ... +... -3 .. +... -2 . +... -1 (empty) +... 1 (empty) +... 2 . +... 3 .. +... 100 ... +. -100 . +. -3 . +. -2 . +. -1 (empty) +. 1 (empty) +. 2 . +. 3 . +. 100 . + +# Test substring_index using '.' as delimiter with utf8view +query TIT +SELECT str, n, substring_index(arrow_cast(str, 'Utf8View'), '.', n) AS c FROM + (VALUES + ROW('arrow.apache.org'), + ROW('.'), + ROW('...'), + ROW(NULL) + ) AS strings(str), + (VALUES + ROW(1), + ROW(2), + ROW(3), + ROW(100), + ROW(-1), + ROW(-2), + ROW(-3), + ROW(-100) + ) AS occurrences(n) +ORDER BY str DESC, n; +---- +NULL -100 NULL +NULL -3 NULL +NULL -2 NULL +NULL -1 NULL +NULL 1 NULL +NULL 2 NULL +NULL 3 NULL +NULL 100 NULL +arrow.apache.org -100 arrow.apache.org +arrow.apache.org -3 arrow.apache.org +arrow.apache.org -2 apache.org +arrow.apache.org -1 org +arrow.apache.org 1 arrow +arrow.apache.org 2 arrow.apache +arrow.apache.org 3 arrow.apache.org +arrow.apache.org 100 arrow.apache.org +... -100 ... +... -3 .. +... -2 . +... -1 (empty) +... 1 (empty) +... 2 . +... 3 .. +... 100 ... +. -100 . +. -3 . +. -2 . +. -1 (empty) +. 1 (empty) +. 2 . +. 3 . +. 100 . + +# Test substring_index using 'ac' as delimiter +query TIT +SELECT str, n, substring_index(str, 'ac', n) AS c FROM + (VALUES + -- input string does not contain the delimiter + ROW('arrow'), + -- input string contains the delimiter + ROW('arrow.apache.org') + ) AS strings(str), + (VALUES + ROW(1), + ROW(2), + ROW(-1), + ROW(-2) + ) AS occurrences(n) +ORDER BY str DESC, n; +---- +arrow.apache.org -2 arrow.apache.org +arrow.apache.org -1 he.org +arrow.apache.org 1 arrow.ap +arrow.apache.org 2 arrow.apache.org +arrow -2 arrow +arrow -1 arrow +arrow 1 arrow +arrow 2 arrow + +# Test substring_index with NULL values +query TTTT +SELECT + substring_index(NULL, '.', 1), + substring_index('arrow.apache.org', NULL, 1), + substring_index('arrow.apache.org', '.', NULL), + substring_index(NULL, NULL, NULL) +---- +NULL NULL NULL NULL + +# Test substring_index with empty strings +query TT +SELECT + -- input string is empty + substring_index('', '.', 1), + -- delimiter is empty + substring_index('arrow.apache.org', '', 1) +---- +(empty) (empty) + +# Test substring_index with 0 occurrence +query T +SELECT substring_index('arrow.apache.org', 'ac', 0) +---- +(empty) + +# Test substring_index with large occurrences +query TT +SELECT + -- i64::MIN + substring_index('arrow.apache.org', '.', -9223372036854775808) as c1, + -- i64::MAX + substring_index('arrow.apache.org', '.', 9223372036854775807) as c2; +---- +arrow.apache.org arrow.apache.org + +# Test substring_index issue https://github.com/apache/datafusion/issues/9472 +query TTT +SELECT + url, + substring_index(url, '.', 1) AS subdomain, + substring_index(url, '.', -1) AS tld +FROM + (VALUES ROW('docs.apache.com'), + ROW('community.influxdata.com'), + ROW('arrow.apache.org') + ) data(url) +---- +docs.apache.com docs com +community.influxdata.com community com +arrow.apache.org arrow org + + +# find_in_set tests +query I +SELECT find_in_set('b', 'a,b,c,d') +---- +2 + + +query I +SELECT find_in_set('a', 'a,b,c,d,a') +---- +1 + +query I +SELECT find_in_set('', 'a,b,c,d,a') +---- +0 + +query I +SELECT find_in_set('a', '') +---- +0 + + +query I +SELECT find_in_set('', '') +---- +1 + +query I +SELECT find_in_set(NULL, 'a,b,c,d') +---- +NULL + +query I +SELECT find_in_set('a', NULL) +---- +NULL + + +query I +SELECT find_in_set(NULL, NULL) +---- +NULL + +# find_in_set tests with utf8view +query I +SELECT find_in_set(arrow_cast('b', 'Utf8View'), 'a,b,c,d') +---- +2 + + +query I +SELECT find_in_set('a', arrow_cast('a,b,c,d,a', 'Utf8View')) +---- +1 + +query I +SELECT find_in_set(arrow_cast('', 'Utf8View'), arrow_cast('a,b,c,d,a', 'Utf8View')) +---- +0 + + +query T +SELECT split_part('foo_bar', '_', 2) +---- +bar + +query T +SELECT split_part(arrow_cast('foo_bar', 'Dictionary(Int32, Utf8)'), '_', 2) +---- +bar + +# test largeutf8, utf8view for split_part +query T +SELECT split_part(arrow_cast('large_apple_large_orange_large_banana', 'LargeUtf8'), '_', 3) +---- +large + +query T +SELECT split_part(arrow_cast('view_apple_view_orange_view_banana', 'Utf8View'), '_', 3); +---- +view + +query T +SELECT split_part('test_large_split_large_case', arrow_cast('_large', 'LargeUtf8'), 2) +---- +_split + +query T +SELECT split_part(arrow_cast('huge_large_apple_large_orange_large_banana', 'LargeUtf8'), arrow_cast('_', 'Utf8View'), 2) +---- +large + +query T +SELECT split_part(arrow_cast('view_apple_view_large_banana', 'Utf8View'), arrow_cast('_large', 'LargeUtf8'), 2) +---- +_banana + +query T +SELECT split_part(NULL, '_', 2) +---- +NULL + +query B +SELECT starts_with('foobar', 'foo') +---- +true + +query B +SELECT starts_with('foobar', 'bar') +---- +false + +query TT +select ' ', '|' +---- + | diff --git a/datafusion/sqllogictest/test_files/string/string_query.slt.part b/datafusion/sqllogictest/test_files/string/string_query.slt.part new file mode 100644 index 000000000000..c4975b5b8c8d --- /dev/null +++ b/datafusion/sqllogictest/test_files/string/string_query.slt.part @@ -0,0 +1,1209 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file is intended to be run with tables already defined +# with standard values, but different types in string columns +# (String, StringView, etc.) + +# select +query TTTT +SELECT ascii_1, ascii_2, unicode_1, unicode_2 FROM test_basic_operator +---- +Andrew X datafusion📊🔥 🔥 +Xiangpeng Xiangpeng datafusion数据融合 datafusion数据融合 +Raphael R datafusionДатаФусион аФус +under_score un_____core un iść core chrząszcz na łące w 東京都 +percent p%t pan Tadeusz ma iść w kąt Pan Tadeusz ma frunąć stąd w kąt +(empty) % (empty) (empty) +NULL % NULL NULL +NULL R NULL 🔥 + +# -------------------------------------- +# column comparison as filters +# -------------------------------------- + +query TT +select ascii_1, ascii_2 from test_basic_operator where ascii_1 = ascii_2 +---- +Xiangpeng Xiangpeng + +query TT +select ascii_1, ascii_2 from test_basic_operator where ascii_1 <> ascii_2 +---- +Andrew X +Raphael R +under_score un_____core +percent p%t +(empty) % + +query TT +select unicode_1, unicode_2 from test_basic_operator where unicode_1 = unicode_2 +---- +datafusion数据融合 datafusion数据融合 +(empty) (empty) + +query TT +select unicode_1, unicode_2 from test_basic_operator where unicode_1 <> unicode_2 +---- +datafusion📊🔥 🔥 +datafusionДатаФусион аФус +un iść core chrząszcz na łące w 東京都 +pan Tadeusz ma iść w kąt Pan Tadeusz ma frunąć stąd w kąt + +query TT +select ascii_1, unicode_1 from test_basic_operator where ascii_1 = unicode_1 +---- +(empty) (empty) + +query TT +select ascii_1, unicode_1 from test_basic_operator where ascii_1 <> unicode_1 +---- +Andrew datafusion📊🔥 +Xiangpeng datafusion数据融合 +Raphael datafusionДатаФусион +under_score un iść core +percent pan Tadeusz ma iść w kąt + +# -------------------------------------- +# column comparison +# -------------------------------------- +query TTTTBBBBBB +select + ascii_1, ascii_2, unicode_1, unicode_2, + ascii_1 = ascii_2, + ascii_1 <> ascii_2, + unicode_1 = unicode_2, + unicode_1 <> unicode_2, + ascii_1 = unicode_1, + ascii_1 <> unicode_1 +from test_basic_operator; +---- +Andrew X datafusion📊🔥 🔥 false true false true false true +Xiangpeng Xiangpeng datafusion数据融合 datafusion数据融合 true false true false false true +Raphael R datafusionДатаФусион аФус false true false true false true +under_score un_____core un iść core chrząszcz na łące w 東京都 false true false true false true +percent p%t pan Tadeusz ma iść w kąt Pan Tadeusz ma frunąć stąd w kąt false true false true false true +(empty) % (empty) (empty) false true true false true false +NULL % NULL NULL NULL NULL NULL NULL NULL NULL +NULL R NULL 🔥 NULL NULL NULL NULL NULL NULL + +# -------------------------------------- +# column to StringView scalar comparison +# -------------------------------------- +query TTBBBB +select + ascii_1, unicode_1, + ascii_1 = arrow_cast('Andrew', 'Utf8View'), + ascii_1 <> arrow_cast('Andrew', 'Utf8View'), + unicode_1 = arrow_cast('datafusion数据融合', 'Utf8View'), + unicode_1 <> arrow_cast('datafusion数据融合', 'Utf8View') +from test_basic_operator; +---- +Andrew datafusion📊🔥 true false false true +Xiangpeng datafusion数据融合 false true true false +Raphael datafusionДатаФусион false true false true +under_score un iść core false true false true +percent pan Tadeusz ma iść w kąt false true false true +(empty) (empty) false true false true +NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL NULL + +# -------------------------------------- +# column to String scalar +# -------------------------------------- +query TTBBBB +select + ascii_1, unicode_1, + ascii_1 = arrow_cast('Andrew', 'Utf8'), + ascii_1 <> arrow_cast('Andrew', 'Utf8'), + unicode_1 = arrow_cast('datafusion数据融合', 'Utf8'), + unicode_1 <> arrow_cast('datafusion数据融合', 'Utf8') +from test_basic_operator; +---- +Andrew datafusion📊🔥 true false false true +Xiangpeng datafusion数据融合 false true true false +Raphael datafusionДатаФусион false true false true +under_score un iść core false true false true +percent pan Tadeusz ma iść w kąt false true false true +(empty) (empty) false true false true +NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL NULL + +# -------------------------------------- +# column to LargeString scalar +# -------------------------------------- +query TTBBBB +select + ascii_1, unicode_1, + ascii_1 = arrow_cast('Andrew', 'LargeUtf8'), + ascii_1 <> arrow_cast('Andrew', 'LargeUtf8'), + unicode_1 = arrow_cast('datafusion数据融合', 'LargeUtf8'), + unicode_1 <> arrow_cast('datafusion数据融合', 'LargeUtf8') +from test_basic_operator; +---- +Andrew datafusion📊🔥 true false false true +Xiangpeng datafusion数据融合 false true true false +Raphael datafusionДатаФусион false true false true +under_score un iść core false true false true +percent pan Tadeusz ma iść w kąt false true false true +(empty) (empty) false true false true +NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL NULL + +# -------------------------------------- +# substr function +# -------------------------------------- + +query TTTTTTTTTTTTTT +select + substr(c1, 1), + substr(c1, 3), + substr(c1, 100), + substr(c1, -1), + substr(c1, 0, 0), + substr(c1, -1, 2), + substr(c1, -2, 10), + substr(c1, -100, 200), + substr(c1, -10, 10), + substr(c1, -100, 10), + substr(c1, 1, 100), + substr(c1, 5, 3), + substr(c1, 100, 200), + substr(c1, 8, 0) +from test_substr; +---- +foo o (empty) foo (empty) (empty) foo foo (empty) (empty) foo (empty) (empty) (empty) +hello🌏世界 llo🌏世界 (empty) hello🌏世界 (empty) (empty) hello🌏世 hello🌏世界 (empty) (empty) hello🌏世界 o🌏世 (empty) (empty) +💩 (empty) (empty) 💩 (empty) (empty) 💩 💩 (empty) (empty) 💩 (empty) (empty) (empty) +ThisIsAVeryLongASCIIString isIsAVeryLongASCIIString (empty) ThisIsAVeryLongASCIIString (empty) (empty) ThisIsA ThisIsAVeryLongASCIIString (empty) (empty) ThisIsAVeryLongASCIIString IsA (empty) (empty) +(empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty) (empty) +NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL + +query TTTT +SELECT + SUBSTR(ascii_1, 1, 3) as c1, + SUBSTR(ascii_2, 1, 3) as c2, + SUBSTR(unicode_1, 1, 3) as c3, + SUBSTR(unicode_2, 1, 3) as c4 +FROM test_basic_operator; +---- +And X dat 🔥 +Xia Xia dat dat +Rap R dat аФу +und un_ un chr +per p%t pan Pan +(empty) % (empty) (empty) +NULL % NULL NULL +NULL R NULL 🔥 + +# -------------------------------------- +# test distinct aggregate +# -------------------------------------- +query II +SELECT + COUNT(DISTINCT ascii_1), + COUNT(DISTINCT unicode_1) +FROM + test_basic_operator +---- +6 6 + +query II +SELECT + COUNT(DISTINCT ascii_1), + COUNT(DISTINCT unicode_1) +FROM + test_basic_operator +GROUP BY ascii_2; +---- +1 1 +1 1 +1 1 +1 1 +1 1 +1 1 + +query II rowsort +SELECT + COUNT(DISTINCT ascii_1), + COUNT(DISTINCT unicode_1) +FROM + test_basic_operator +GROUP BY unicode_2; +---- +0 0 +1 1 +1 1 +1 1 +1 1 +1 1 +1 1 + +# -------------------------------------- +# STARTS_WITH function +# -------------------------------------- + +query BBBB +SELECT + STARTS_WITH(ascii_1, ascii_2), + STARTS_WITH(unicode_1, unicode_2), + STARTS_WITH(ascii_1, unicode_2), + STARTS_WITH(unicode_1, ascii_2) +FROM test_basic_operator +---- +false false false false +true true false false +true false false false +false false false false +false false false false +false true true false +NULL NULL NULL NULL +NULL NULL NULL NULL + +query BBBB +SELECT + STARTS_WITH(ascii_1, 'And'), + STARTS_WITH(ascii_2, 'And'), + STARTS_WITH(unicode_1, 'data'), + STARTS_WITH(unicode_2, 'data') +FROM test_basic_operator +---- +true false true false +false false true true +false false true false +false false false false +false false false false +false false false false +NULL false NULL NULL +NULL false NULL false + +# -------------------------------------- +# Test TRANSLATE +# -------------------------------------- + +query T +SELECT + TRANSLATE(ascii_1, 'foo', 'bar') as c +FROM test_basic_operator; +---- +Andrew +Xiangpeng +Raphael +under_scrre +percent +(empty) +NULL +NULL + +query T +SELECT + TRANSLATE(unicode_1, 'foo', 'bar') as c +FROM test_basic_operator; +---- +databusirn📊🔥 +databusirn数据融合 +databusirnДатаФусион +un iść crre +pan Tadeusz ma iść w kąt +(empty) +NULL +NULL + +# -------------------------------------- +# Test REGEXP_REPLACE +# -------------------------------------- + +# Should run REGEXP_REPLACE with Scalar value for string +query T +SELECT + REGEXP_REPLACE(ascii_1, 'e', 'f') AS k +FROM test_basic_operator; +---- +Andrfw +Xiangpfng +Raphafl +undfr_score +pfrcent +(empty) +NULL +NULL + +# Should run REGEXP_REPLACE with Scalar value for string with flag +query T +SELECT + REGEXP_REPLACE(ascii_1, 'e', 'f', 'i') AS k +FROM test_basic_operator; +---- +Andrfw +Xiangpfng +Raphafl +undfr_score +pfrcent +(empty) +NULL +NULL + +# Should run REGEXP_REPLACE with ScalarArray value for string +query T +SELECT + REGEXP_REPLACE(ascii_1, lower(ascii_1), 'bar') AS k +FROM test_basic_operator; +---- +Andrew +Xiangpeng +Raphael +bar +bar +bar +NULL +NULL + +# Should run REGEXP_REPLACE with ScalarArray value for string with flag +query T +SELECT + REGEXP_REPLACE(ascii_1, lower(ascii_1), 'bar', 'g') AS k +FROM test_basic_operator; +---- +Andrew +Xiangpeng +Raphael +bar +bar +bar +NULL +NULL + +# -------------------------------------- +# Test Initcap +# -------------------------------------- +statement ok +CREATE TABLE test_lowercase AS SELECT + lower(ascii_1) as ascii_1_lower, + lower(unicode_1) as unicode_1_lower +FROM test_basic_operator; + +query TT +SELECT + INITCAP(ascii_1_lower) as c1, + INITCAP(unicode_1_lower) as c2 +FROM test_lowercase; +---- +Andrew Datafusion📊🔥 +Xiangpeng Datafusion数据融合 +Raphael Datafusionдатафусион +Under_Score Un Iść Core +Percent Pan Tadeusz Ma Iść W KąT +(empty) (empty) +NULL NULL +NULL NULL + +statement ok +drop table test_lowercase; + +# -------------------------------------- +# Test ASCII +# -------------------------------------- + +query IIII +SELECT + ASCII(ascii_1) as c1, + ASCII(ascii_2) as c2, + ASCII(unicode_1) as c3, + ASCII(unicode_2) as c4 +FROM test_basic_operator; +---- +65 88 100 128293 +88 88 100 100 +82 82 100 1072 +117 117 117 99 +112 112 112 80 +0 37 0 0 +NULL 37 NULL NULL +NULL 82 NULL 128293 + +# -------------------------------------- +# Test BTRIM +# -------------------------------------- + +# Test BTRIM outputs +query TTTTTT +SELECT + BTRIM(ascii_1, 'foo'), + BTRIM(ascii_1, 'A'), + BTRIM(ascii_1, NULL), + BTRIM(unicode_1), + BTRIM(unicode_1, '🔥'), + BTRIM(unicode_1, NULL) +FROM test_basic_operator; +---- +Andrew ndrew NULL datafusion📊🔥 datafusion📊 NULL +Xiangpeng Xiangpeng NULL datafusion数据融合 datafusion数据融合 NULL +Raphael Raphael NULL datafusionДатаФусион datafusionДатаФусион NULL +under_score under_score NULL un iść core un iść core NULL +percent percent NULL pan Tadeusz ma iść w kąt pan Tadeusz ma iść w kąt NULL +(empty) (empty) NULL (empty) (empty) NULL +NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL NULL + +# -------------------------------------- +# Test LTRIM +# -------------------------------------- + +# Test LTRIM outputs +query TTTTTT +SELECT + LTRIM(ascii_1, 'foo'), + LTRIM(ascii_1, ascii_2), + LTRIM(ascii_1, NULL), + LTRIM(unicode_1), + LTRIM(unicode_1, NULL), + LTRIM(unicode_1, '🔥') +FROM test_basic_operator; +---- +Andrew Andrew NULL datafusion📊🔥 NULL datafusion📊🔥 +Xiangpeng (empty) NULL datafusion数据融合 NULL datafusion数据融合 +Raphael aphael NULL datafusionДатаФусион NULL datafusionДатаФусион +under_score der_score NULL un iść core NULL un iść core +percent ercent NULL pan Tadeusz ma iść w kąt NULL pan Tadeusz ma iść w kąt +(empty) (empty) NULL (empty) NULL (empty) +NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL NULL + +# -------------------------------------- +# Test RTRIM +# -------------------------------------- + +# Test RTRIM outputs +query TTTTT +SELECT + RTRIM(ascii_1, 'rew'), + RTRIM(ascii_1, ascii_2), + RTRIM(ascii_1), + RTRIM(unicode_1, NULL), + RTRIM(unicode_1, '🔥') +FROM test_basic_operator; +---- +And Andrew Andrew NULL datafusion📊 +Xiangpeng (empty) Xiangpeng NULL datafusion数据融合 +Raphael Raphael Raphael NULL datafusionДатаФусион +under_sco under_s under_score NULL un iść core +percent percen percent NULL pan Tadeusz ma iść w kąt +(empty) (empty) (empty) NULL (empty) +NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL + +# -------------------------------------- +# Test CONTAINS +# -------------------------------------- + +query BBBBBB +SELECT + CONTAINS(ascii_1, 'foo') as c1, + CONTAINS(ascii_1, ascii_2) as c2, + CONTAINS(ascii_1, NULL) as c3, + CONTAINS(unicode_1, unicode_2) as c4, + CONTAINS(unicode_1, NULL) as c5, + CONTAINS(unicode_1, '🔥') as c6 +FROM test_basic_operator; +---- +false false NULL true NULL true +false true NULL true NULL false +false true NULL true NULL false +false false NULL false NULL false +false false NULL false NULL false +false false NULL true NULL false +NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL NULL + +# -------------------------------------- +# Test LOWER +# -------------------------------------- + +query TT +SELECT LOWER(ascii_1) as c1, LOWER(unicode_1) as c2 FROM test_basic_operator; +---- +andrew datafusion📊🔥 +xiangpeng datafusion数据融合 +raphael datafusionдатафусион +under_score un iść core +percent pan tadeusz ma iść w kąt +(empty) (empty) +NULL NULL +NULL NULL + +# -------------------------------------- +# Test UPPER +# -------------------------------------- + +query TT +SELECT UPPER(ascii_1) as c1, UPPER(unicode_1) as c2 FROM test_basic_operator; +---- +ANDREW DATAFUSION📊🔥 +XIANGPENG DATAFUSION数据融合 +RAPHAEL DATAFUSIONДАТАФУСИОН +UNDER_SCORE UN IŚĆ CORE +PERCENT PAN TADEUSZ MA IŚĆ W KĄT +(empty) (empty) +NULL NULL +NULL NULL + +# -------------------------------------- +# Test Concat +# -------------------------------------- + +query TTTTTTTTTTTT +SELECT + concat(ascii_1, ':Data'), + concat(ascii_1, ascii_2), + concat(ascii_1, NULL), + concat(ascii_1, unicode_1), + concat(ascii_1, unicode_2), + concat(unicode_1, ascii_1), + concat(unicode_1, unicode_2), + concat(unicode_1, NULL), + concat(unicode_1, '🔥'), + concat(NULL, '🔥'), + concat(NULL, NULL), + concat(ascii_1, ',', unicode_1) +FROM test_basic_operator; +---- +Andrew:Data AndrewX Andrew Andrewdatafusion📊🔥 Andrew🔥 datafusion📊🔥Andrew datafusion📊🔥🔥 datafusion📊🔥 datafusion📊🔥🔥 🔥 (empty) Andrew,datafusion📊🔥 +Xiangpeng:Data XiangpengXiangpeng Xiangpeng Xiangpengdatafusion数据融合 Xiangpengdatafusion数据融合 datafusion数据融合Xiangpeng datafusion数据融合datafusion数据融合 datafusion数据融合 datafusion数据融合🔥 🔥 (empty) Xiangpeng,datafusion数据融合 +Raphael:Data RaphaelR Raphael RaphaeldatafusionДатаФусион RaphaelаФус datafusionДатаФусионRaphael datafusionДатаФусионаФус datafusionДатаФусион datafusionДатаФусион🔥 🔥 (empty) Raphael,datafusionДатаФусион +under_score:Data under_scoreun_____core under_score under_scoreun iść core under_scorechrząszcz na łące w 東京都 un iść coreunder_score un iść corechrząszcz na łące w 東京都 un iść core un iść core🔥 🔥 (empty) under_score,un iść core +percent:Data percentp%t percent percentpan Tadeusz ma iść w kąt percentPan Tadeusz ma frunąć stąd w kąt pan Tadeusz ma iść w kątpercent pan Tadeusz ma iść w kątPan Tadeusz ma frunąć stąd w kąt pan Tadeusz ma iść w kąt pan Tadeusz ma iść w kąt🔥 🔥 (empty) percent,pan Tadeusz ma iść w kąt +:Data % (empty) (empty) (empty) (empty) (empty) (empty) 🔥 🔥 (empty) , +:Data % (empty) (empty) (empty) (empty) (empty) (empty) 🔥 🔥 (empty) , +:Data R (empty) (empty) 🔥 (empty) 🔥 (empty) 🔥 🔥 (empty) , + +# -------------------------------------- +# Test OVERLAY +# -------------------------------------- + +query TTTTTT +SELECT + OVERLAY(ascii_1 PLACING 'foo' FROM 2 ), + OVERLAY(unicode_1 PLACING 'foo' FROM 2), + OVERLAY(ascii_1 PLACING '🔥' FROM 2), + OVERLAY(unicode_1 PLACING '🔥' FROM 2), + OVERLAY(ascii_1 PLACING NULL FROM 2), + OVERLAY(unicode_1 PLACING NULL FROM 2) +FROM test_basic_operator; +---- +Afooew dfoofusion📊🔥 A🔥drew d🔥tafusion📊🔥 NULL NULL +Xfoogpeng dfoofusion数据融合 X🔥angpeng d🔥tafusion数据融合 NULL NULL +Rfooael dfoofusionДатаФусион R🔥phael d🔥tafusionДатаФусион NULL NULL +ufoor_score ufoość core u🔥der_score u🔥 iść core NULL NULL +pfooent pfooTadeusz ma iść w kąt p🔥rcent p🔥n Tadeusz ma iść w kąt NULL NULL +foo foo 🔥 🔥 NULL NULL +NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL NULL + +# -------------------------------------- +# Test REPLACE +# -------------------------------------- + +query TTTTTT +SELECT + REPLACE(ascii_1, 'foo', 'bar'), + REPLACE(ascii_1, ascii_2, 'bar'), + REPLACE(ascii_1, NULL, 'bar'), + REPLACE(unicode_1, unicode_2, 'bar'), + REPLACE(unicode_1, NULL, 'bar'), + REPLACE(unicode_1, '🔥', 'bar') +FROM test_basic_operator; +---- +Andrew Andrew NULL datafusion📊bar NULL datafusion📊bar +Xiangpeng bar NULL bar NULL datafusion数据融合 +Raphael baraphael NULL datafusionДатbarион NULL datafusionДатаФусион +under_score under_score NULL un iść core NULL un iść core +percent percent NULL pan Tadeusz ma iść w kąt NULL pan Tadeusz ma iść w kąt +(empty) (empty) NULL bar NULL (empty) +NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL NULL + +# -------------------------------------- +# Test RIGHT +# -------------------------------------- +# Test outputs of RIGHT +query TTTTTT +SELECT + RIGHT(ascii_1, 3), + RIGHT(ascii_1, 0), + RIGHT(ascii_1, -3), + RIGHT(unicode_1, 3), + RIGHT(unicode_1, 0), + RIGHT(unicode_1, -3) +FROM test_basic_operator; +---- +rew (empty) rew n📊🔥 (empty) afusion📊🔥 +eng (empty) ngpeng 据融合 (empty) afusion数据融合 +ael (empty) hael ион (empty) afusionДатаФусион +ore (empty) er_score ore (empty) iść core +ent (empty) cent kąt (empty) Tadeusz ma iść w kąt +(empty) (empty) (empty) (empty) (empty) (empty) +NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL NULL + +# -------------------------------------- +# Test LEFT +# -------------------------------------- + +# Test outputs of LEFT +query TTTTTT +SELECT + LEFT(ascii_1, 3), + LEFT(ascii_1, 0), + LEFT(ascii_1, -3), + LEFT(unicode_1, 3), + LEFT(unicode_1, 0), + LEFT(unicode_1, -3) +FROM test_basic_operator; +---- +And (empty) And dat (empty) datafusio +Xia (empty) Xiangp dat (empty) datafusion数 +Rap (empty) Raph dat (empty) datafusionДатаФус +und (empty) under_sc un (empty) un iść c +per (empty) perc pan (empty) pan Tadeusz ma iść w +(empty) (empty) (empty) (empty) (empty) (empty) +NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL NULL + +# -------------------------------------- +# Test SUBSTR_INDEX +# -------------------------------------- + +query TTTT +SELECT + SUBSTR_INDEX(ascii_1, 'a', 1), + SUBSTR_INDEX(ascii_1, 'a', 2), + SUBSTR_INDEX(unicode_1, 'а', 1), + SUBSTR_INDEX(unicode_1, 'а', 2) +FROM test_basic_operator; +---- +Andrew Andrew datafusion📊🔥 datafusion📊🔥 +Xi Xiangpeng datafusion数据融合 datafusion数据融合 +R Raph datafusionД datafusionДат +under_score under_score un iść core un iść core +percent percent pan Tadeusz ma iść w kąt pan Tadeusz ma iść w kąt +(empty) (empty) (empty) (empty) +NULL NULL NULL NULL +NULL NULL NULL NULL + +# -------------------------------------- +# Test FIND_IN_SET +# -------------------------------------- + +query IIII +SELECT + FIND_IN_SET(ascii_1, 'a,b,c,d'), + FIND_IN_SET(ascii_1, 'Andrew,Xiangpeng,Raphael'), + FIND_IN_SET(unicode_1, 'a,b,c,d'), + FIND_IN_SET(unicode_1, 'datafusion📊🔥,datafusion数据融合,datafusionДатаФусион') +FROM test_basic_operator; +---- +0 1 0 1 +0 2 0 2 +0 3 0 3 +0 0 0 0 +0 0 0 0 +0 0 0 0 +NULL NULL NULL NULL +NULL NULL NULL NULL + +# -------------------------------------- +# Test || operator +# -------------------------------------- + +# || constants +# expect all results to be the same for each row as they all have the same values +query TTTT +SELECT + ascii_1 || 'foo', + ascii_1 || '🔥', + unicode_1 || 'foo', + unicode_1 || '🔥' +FROM test_basic_operator; +---- +Andrewfoo Andrew🔥 datafusion📊🔥foo datafusion📊🔥🔥 +Xiangpengfoo Xiangpeng🔥 datafusion数据融合foo datafusion数据融合🔥 +Raphaelfoo Raphael🔥 datafusionДатаФусионfoo datafusionДатаФусион🔥 +under_scorefoo under_score🔥 un iść corefoo un iść core🔥 +percentfoo percent🔥 pan Tadeusz ma iść w kątfoo pan Tadeusz ma iść w kąt🔥 +foo 🔥 foo 🔥 +NULL NULL NULL NULL +NULL NULL NULL NULL + +# || same type (column1 has null, so also tests NULL || NULL) +# expect all results to be the same for each row as they all have the same values +query TTTT +SELECT + ascii_1 || ascii_2, + ascii_1 || unicode_2, + unicode_1 || ascii_2, + unicode_1 || unicode_2 +FROM test_basic_operator; +---- +AndrewX Andrew🔥 datafusion📊🔥X datafusion📊🔥🔥 +XiangpengXiangpeng Xiangpengdatafusion数据融合 datafusion数据融合Xiangpeng datafusion数据融合datafusion数据融合 +RaphaelR RaphaelаФус datafusionДатаФусионR datafusionДатаФусионаФус +under_scoreun_____core under_scorechrząszcz na łące w 東京都 un iść coreun_____core un iść corechrząszcz na łące w 東京都 +percentp%t percentPan Tadeusz ma frunąć stąd w kąt pan Tadeusz ma iść w kątp%t pan Tadeusz ma iść w kątPan Tadeusz ma frunąć stąd w kąt +% (empty) % (empty) +NULL NULL NULL NULL +NULL NULL NULL NULL + +# -------------------------------------- +# Test ~ operator +# -------------------------------------- + +query BB +SELECT + ascii_1 ~ 'an', + unicode_1 ~ 'таФ' +FROM test_basic_operator; +---- +false false +true false +false true +false false +false false +false false +NULL NULL +NULL NULL + +query BB +SELECT + ascii_1 ~* '^a.{3}e', + unicode_1 ~* '^d.*Фу' +FROM test_basic_operator; +---- +true false +false false +false true +false false +false false +false false +NULL NULL +NULL NULL + +query BB +SELECT + ascii_1 !~~ 'xia_g%g', + unicode_1 !~~ 'datafusion数据融合' +FROM test_basic_operator; +---- +true true +true false +true true +true true +true true +true true +NULL NULL +NULL NULL + +query BB +SELECT + ascii_1 !~~* 'xia_g%g', + unicode_1 !~~* 'datafusion数据融合' +FROM test_basic_operator; +---- +true true +false false +true true +true true +true true +true true +NULL NULL +NULL NULL + +# -------------------------------------- +# Test || operator +# -------------------------------------- + +query TTTTT +select + ascii_1 || ' nice', + ascii_1 || ' and ' || ascii_2, + unicode_1 || ' cool', + unicode_1 || ' and ' || unicode_2, + ascii_1 || ' 🔥 ' || unicode_1 +from test_basic_operator; +---- +Andrew nice Andrew and X datafusion📊🔥 cool datafusion📊🔥 and 🔥 Andrew 🔥 datafusion📊🔥 +Xiangpeng nice Xiangpeng and Xiangpeng datafusion数据融合 cool datafusion数据融合 and datafusion数据融合 Xiangpeng 🔥 datafusion数据融合 +Raphael nice Raphael and R datafusionДатаФусион cool datafusionДатаФусион and аФус Raphael 🔥 datafusionДатаФусион +under_score nice under_score and un_____core un iść core cool un iść core and chrząszcz na łące w 東京都 under_score 🔥 un iść core +percent nice percent and p%t pan Tadeusz ma iść w kąt cool pan Tadeusz ma iść w kąt and Pan Tadeusz ma frunąć stąd w kąt percent 🔥 pan Tadeusz ma iść w kąt + nice and % cool and 🔥 +NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL + +# -------------------------------------- +# Test LIKE / ILIKE +# -------------------------------------- + +# TODO: StringView has wrong behavior for LIKE/ILIKE. Enable this after fixing the issue +# see issue: https://github.com/apache/datafusion/issues/12637 +# Test pattern with wildcard characters +#query TTBBBB +#select ascii_1, unicode_1, +# ascii_1 like 'An%' as ascii_like, +# unicode_1 like '%ion数据%' as unicode_like, +# ascii_1 ilike 'An%' as ascii_ilike, +# unicode_1 ilike '%ion数据%' as unicode_ilik +#from test_basic_operator; +#---- +#Andrew datafusion📊🔥 true false true false +#Xiangpeng datafusion数据融合 false true false true +#Raphael datafusionДатаФусион false false false false +#NULL NULL NULL NULL NULL NULL + +# Test pattern without wildcard characters +query TTBBBB +select ascii_1, unicode_1, + ascii_1 like 'An' as ascii_like, + unicode_1 like 'ion数据' as unicode_like, + ascii_1 ilike 'An' as ascii_ilike, + unicode_1 ilike 'ion数据' as unicode_ilik +from test_basic_operator; +---- +Andrew datafusion📊🔥 false false false false +Xiangpeng datafusion数据融合 false false false false +Raphael datafusionДатаФусион false false false false +under_score un iść core false false false false +percent pan Tadeusz ma iść w kąt false false false false +(empty) (empty) false false false false +NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL NULL + +# -------------------------------------- +# Test CHARACTER_LENGTH +# -------------------------------------- + +query II +SELECT + CHARACTER_LENGTH(ascii_1), + CHARACTER_LENGTH(unicode_1) +FROM + test_basic_operator +---- +6 12 +9 14 +7 20 +11 11 +7 24 +0 0 +NULL NULL +NULL NULL + +# -------------------------------------- +# Test Start_With +# -------------------------------------- + +query BBBB +SELECT + STARTS_WITH(ascii_1, 'And'), + STARTS_WITH(unicode_1, 'data'), + STARTS_WITH(ascii_1, NULL), + STARTS_WITH(unicode_1, NULL) +FROM test_basic_operator; +---- +true true NULL NULL +false true NULL NULL +false true NULL NULL +false false NULL NULL +false false NULL NULL +false false NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL + +# -------------------------------------- +# Test ENDS_WITH +# -------------------------------------- + +query BBBB +SELECT + ENDS_WITH(ascii_1, 'w'), + ENDS_WITH(unicode_1, 'ион'), + ENDS_WITH(ascii_1, NULL), + ENDS_WITH(unicode_1, NULL) +FROM test_basic_operator; +---- +true false NULL NULL +false false NULL NULL +false true NULL NULL +false false NULL NULL +false false NULL NULL +false false NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL + +# -------------------------------------- +# Test LEVENSHTEIN +# -------------------------------------- + +query IIII +SELECT + LEVENSHTEIN(ascii_1, 'Andrew'), + LEVENSHTEIN(unicode_1, 'datafusion数据融合'), + LEVENSHTEIN(ascii_1, NULL), + LEVENSHTEIN(unicode_1, NULL) +FROM test_basic_operator; +---- +0 4 NULL NULL +7 0 NULL NULL +6 10 NULL NULL +8 13 NULL NULL +6 19 NULL NULL +6 14 NULL NULL +NULL NULL NULL NULL +NULL NULL NULL NULL + +# -------------------------------------- +# Test LPAD +# -------------------------------------- + +query TTTT +SELECT + LPAD(ascii_1, 20, 'x'), + LPAD(ascii_1, 20, NULL), + LPAD(unicode_1, 20, '🔥'), + LPAD(unicode_1, 20, NULL) +FROM test_basic_operator; +---- +xxxxxxxxxxxxxxAndrew NULL 🔥🔥🔥🔥🔥🔥🔥🔥datafusion📊🔥 NULL +xxxxxxxxxxxXiangpeng NULL 🔥🔥🔥🔥🔥🔥datafusion数据融合 NULL +xxxxxxxxxxxxxRaphael NULL datafusionДатаФусион NULL +xxxxxxxxxunder_score NULL 🔥🔥🔥🔥🔥🔥🔥🔥🔥un iść core NULL +xxxxxxxxxxxxxpercent NULL pan Tadeusz ma iść w NULL +xxxxxxxxxxxxxxxxxxxx NULL 🔥🔥🔥🔥🔥🔥🔥🔥🔥🔥🔥🔥🔥🔥🔥🔥🔥🔥🔥🔥 NULL +NULL NULL NULL NULL +NULL NULL NULL NULL + +query TTT +SELECT + LPAD(ascii_1, 20), + LPAD(unicode_1, 20), + '|' +FROM test_basic_operator; +---- + Andrew datafusion📊🔥 | + Xiangpeng datafusion数据融合 | + Raphael datafusionДатаФусион | + under_score un iść core | + percent pan Tadeusz ma iść w | + | +NULL NULL | +NULL NULL | + +# -------------------------------------- +# Test RPAD +# -------------------------------------- + +query TTTT +SELECT + RPAD(ascii_1, 20, 'x'), + RPAD(ascii_1, 20, NULL), + RPAD(unicode_1, 20, '🔥'), + RPAD(unicode_1, 20, NULL) +FROM test_basic_operator; +---- +Andrewxxxxxxxxxxxxxx NULL datafusion📊🔥🔥🔥🔥🔥🔥🔥🔥🔥 NULL +Xiangpengxxxxxxxxxxx NULL datafusion数据融合🔥🔥🔥🔥🔥🔥 NULL +Raphaelxxxxxxxxxxxxx NULL datafusionДатаФусион NULL +under_scorexxxxxxxxx NULL un iść core🔥🔥🔥🔥🔥🔥🔥🔥🔥 NULL +percentxxxxxxxxxxxxx NULL pan Tadeusz ma iść w NULL +xxxxxxxxxxxxxxxxxxxx NULL 🔥🔥🔥🔥🔥🔥🔥🔥🔥🔥🔥🔥🔥🔥🔥🔥🔥🔥🔥🔥 NULL +NULL NULL NULL NULL +NULL NULL NULL NULL + +query TT +SELECT + RPAD(ascii_1, 20), + RPAD(unicode_1, 20) +FROM test_basic_operator; +---- +Andrew datafusion📊🔥 +Xiangpeng datafusion数据融合 +Raphael datafusionДатаФусион +under_score un iść core +percent pan Tadeusz ma iść w + +NULL NULL +NULL NULL + +# -------------------------------------- +# Test REGEXP_LIKE +# -------------------------------------- + +query BBBBBBBB +SELECT + -- without flags + REGEXP_LIKE(ascii_1, 'an'), + REGEXP_LIKE(unicode_1, 'таФ'), + REGEXP_LIKE(ascii_1, NULL), + REGEXP_LIKE(unicode_1, NULL), + -- with flags + REGEXP_LIKE(ascii_1, 'AN', 'i'), + REGEXP_LIKE(unicode_1, 'ТаФ', 'i'), + REGEXP_LIKE(ascii_1, NULL, 'i'), + REGEXP_LIKE(unicode_1, NULL, 'i') + FROM test_basic_operator; +---- +false false NULL NULL true false NULL NULL +true false NULL NULL true false NULL NULL +false true NULL NULL false true NULL NULL +false false NULL NULL false false NULL NULL +false false NULL NULL false false NULL NULL +false false NULL NULL false false NULL NULL +NULL NULL NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL NULL NULL NULL + +# -------------------------------------- +# Test REGEXP_MATCH +# -------------------------------------- + +query ???????? +SELECT + -- without flags + REGEXP_MATCH(ascii_1, 'an'), + REGEXP_MATCH(unicode_1, 'ТаФ'), + REGEXP_MATCH(ascii_1, NULL), + REGEXP_MATCH(unicode_1, NULL), + -- with flags + REGEXP_MATCH(ascii_1, 'AN', 'i'), + REGEXP_MATCH(unicode_1, 'таФ', 'i'), + REGEXP_MATCH(ascii_1, NULL, 'i'), + REGEXP_MATCH(unicode_1, NULL, 'i') +FROM test_basic_operator; +---- +NULL NULL NULL NULL [An] NULL NULL NULL +[an] NULL NULL NULL [an] NULL NULL NULL +NULL NULL NULL NULL NULL [таФ] NULL NULL +NULL NULL NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL NULL NULL NULL + +# -------------------------------------- +# Test REPEAT +# -------------------------------------- + +query TT +SELECT + REPEAT(ascii_1, 3), + REPEAT(unicode_1, 3) +FROM test_basic_operator; +---- +AndrewAndrewAndrew datafusion📊🔥datafusion📊🔥datafusion📊🔥 +XiangpengXiangpengXiangpeng datafusion数据融合datafusion数据融合datafusion数据融合 +RaphaelRaphaelRaphael datafusionДатаФусионdatafusionДатаФусионdatafusionДатаФусион +under_scoreunder_scoreunder_score un iść coreun iść coreun iść core +percentpercentpercent pan Tadeusz ma iść w kątpan Tadeusz ma iść w kątpan Tadeusz ma iść w kąt +(empty) (empty) +NULL NULL +NULL NULL + +# -------------------------------------- +# Test SPLIT_PART +# -------------------------------------- + +query TTTTTT +SELECT + SPLIT_PART(ascii_1, 'e', 1), + SPLIT_PART(ascii_1, 'e', 2), + SPLIT_PART(ascii_1, NULL, 1), + SPLIT_PART(unicode_1, 'и', 1), + SPLIT_PART(unicode_1, 'и', 2), + SPLIT_PART(unicode_1, NULL, 1) +FROM test_basic_operator; +---- +Andr w NULL datafusion📊🔥 (empty) NULL +Xiangp ng NULL datafusion数据融合 (empty) NULL +Rapha l NULL datafusionДатаФус он NULL +und r_scor NULL un iść core (empty) NULL +p rc NULL pan Tadeusz ma iść w kąt (empty) NULL +(empty) (empty) NULL (empty) (empty) NULL +NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL NULL + +# -------------------------------------- +# Test REVERSE +# -------------------------------------- + +query TT +SELECT + REVERSE(ascii_1), + REVERSE(unicode_1) +FROM test_basic_operator; +---- +werdnA 🔥📊noisufatad +gnepgnaiX 合融据数noisufatad +leahpaR ноисуФатаДnoisufatad +erocs_rednu eroc ćśi nu +tnecrep tąk w ćśi am zsuedaT nap +(empty) (empty) +NULL NULL +NULL NULL + +# -------------------------------------- +# Test STRPOS +# -------------------------------------- + +query IIIIII +SELECT + STRPOS(ascii_1, 'e'), + STRPOS(ascii_1, 'ang'), + STRPOS(ascii_1, NULL), + STRPOS(unicode_1, 'и'), + STRPOS(unicode_1, 'ион'), + STRPOS(unicode_1, NULL) +FROM test_basic_operator; +---- +5 0 NULL 0 0 NULL +7 3 NULL 0 0 NULL +6 0 NULL 18 18 NULL +4 0 NULL 0 0 NULL +2 0 NULL 0 0 NULL +0 0 NULL 0 0 NULL +NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL NULL + +# -------------------------------------- +# Test SUBSTR_INDEX +# -------------------------------------- + +query TTTTTT +SELECT + SUBSTR_INDEX(ascii_1, 'e', 1), + SUBSTR_INDEX(ascii_1, 'ang', 1), + SUBSTR_INDEX(ascii_1, NULL, 1), + SUBSTR_INDEX(unicode_1, 'и', 1), + SUBSTR_INDEX(unicode_1, '据融', 1), + SUBSTR_INDEX(unicode_1, NULL, 1) +FROM test_basic_operator; +---- +Andr Andrew NULL datafusion📊🔥 datafusion📊🔥 NULL +Xiangp Xi NULL datafusion数据融合 datafusion数 NULL +Rapha Raphael NULL datafusionДатаФус datafusionДатаФусион NULL +und under_score NULL un iść core un iść core NULL +p percent NULL pan Tadeusz ma iść w kąt pan Tadeusz ma iść w kąt NULL +(empty) (empty) NULL (empty) (empty) NULL +NULL NULL NULL NULL NULL NULL +NULL NULL NULL NULL NULL NULL diff --git a/datafusion/sqllogictest/test_files/string/string_view.slt b/datafusion/sqllogictest/test_files/string/string_view.slt new file mode 100644 index 000000000000..43b08cb25f3f --- /dev/null +++ b/datafusion/sqllogictest/test_files/string/string_view.slt @@ -0,0 +1,1027 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include ./init_data.slt.part + +# -------------------------------------- +# Setup test tables with different physical string types +# and repeat tests in `string_query.slt.part` +# -------------------------------------- +statement ok +create table test_basic_operator as +select + arrow_cast(column1, 'Utf8View') as ascii_1, + arrow_cast(column2, 'Utf8View') as ascii_2, + arrow_cast(column3, 'Utf8View') as unicode_1, + arrow_cast(column4, 'Utf8View') as unicode_2 +from test_source; + +statement ok +create table test_substr as +select arrow_cast(col1, 'Utf8View') as c1 from test_substr_base; + +statement ok +drop table test_source + +# +# common test for string-like functions and operators +# +include ./string_query.slt.part + +# +# Clean up +# +statement ok +drop table test_basic_operator; + +statement ok +drop table test_substr_base; + + +# -------------------------------------- +# String_view specific tests +# -------------------------------------- +statement ok +create table test_source as values + ('Andrew', 'X'), + ('Xiangpeng', 'Xiangpeng'), + ('Raphael', 'R'), + ('', 'Warsaw'), + (NULL, 'R'); + +# Table with the different combination of column types +statement ok +create table test as +SELECT + arrow_cast(column1, 'Utf8') as column1_utf8, + arrow_cast(column2, 'Utf8') as column2_utf8, + arrow_cast(column1, 'LargeUtf8') as column1_large_utf8, + arrow_cast(column2, 'LargeUtf8') as column2_large_utf8, + arrow_cast(column1, 'Utf8View') as column1_utf8view, + arrow_cast(column2, 'Utf8View') as column2_utf8view, + arrow_cast(column1, 'Dictionary(Int32, Utf8)') as column1_dict, + arrow_cast(column2, 'Dictionary(Int32, Utf8)') as column2_dict +FROM test_source; + +statement ok +drop table test_source + +######## +## StringView Function test +######## + +query I +select octet_length(column1_utf8view) from test; +---- +6 +9 +7 +0 +NULL + +query error DataFusion error: Arrow error: Compute error: bit_length not supported for Utf8View +select bit_length(column1_utf8view) from test; + +query T +select btrim(column1_large_utf8) from test; +---- +Andrew +Xiangpeng +Raphael +(empty) +NULL + +######## +## StringView to Other Types column +######## + +# test StringViewArray with Utf8 columns +query TTBBBB +select + column1_utf8, column2_utf8, + column1_utf8view = column2_utf8, + column2_utf8 = column1_utf8view, + column1_utf8view <> column2_utf8, + column2_utf8 <> column1_utf8view +from test; +---- +Andrew X false false true true +Xiangpeng Xiangpeng true true false false +Raphael R false false true true +(empty) Warsaw false false true true +NULL R NULL NULL NULL NULL + +# test StringViewArray with LargeUtf8 columns +query TTBBBB +select + column1_utf8, column2_utf8, + column1_utf8view = column2_large_utf8, + column2_large_utf8 = column1_utf8view, + column1_utf8view <> column2_large_utf8, + column2_large_utf8 <> column1_utf8view +from test; +---- +Andrew X false false true true +Xiangpeng Xiangpeng true true false false +Raphael R false false true true +(empty) Warsaw false false true true +NULL R NULL NULL NULL NULL + +######## +## StringView to Dictionary +######## + +# test StringViewArray with Dictionary columns +query TTBBBB +select + column1_utf8, column2_utf8, + column1_utf8view = column2_dict, + column2_dict = column1_utf8view, + column1_utf8view <> column2_dict, + column2_dict <> column1_utf8view +from test; +---- +Andrew X false false true true +Xiangpeng Xiangpeng true true false false +Raphael R false false true true +(empty) Warsaw false false true true +NULL R NULL NULL NULL NULL + +# StringView column to Dict scalar +query TTBBBB +select + column1_utf8, column2_utf8, + column1_utf8view = arrow_cast('Andrew', 'Dictionary(Int32, Utf8)'), + arrow_cast('Andrew', 'Dictionary(Int32, Utf8)') = column1_utf8view, + column1_utf8view <> arrow_cast('Andrew', 'Dictionary(Int32, Utf8)'), + arrow_cast('Andrew', 'Dictionary(Int32, Utf8)') <> column1_utf8view +from test; +---- +Andrew X true true false false +Xiangpeng Xiangpeng false false true true +Raphael R false false true true +(empty) Warsaw false false true true +NULL R NULL NULL NULL NULL + +# Dict column to StringView scalar +query TTBBBB +select + column1_utf8, column2_utf8, + column1_dict = arrow_cast('Andrew', 'Utf8View'), + arrow_cast('Andrew', 'Utf8View') = column1_dict, + column1_dict <> arrow_cast('Andrew', 'Utf8View'), + arrow_cast('Andrew', 'Utf8View') <> column1_dict +from test; +---- +Andrew X true true false false +Xiangpeng Xiangpeng false false true true +Raphael R false false true true +(empty) Warsaw false false true true +NULL R NULL NULL NULL NULL + +######## +## Coercion Rules +######## + +statement ok +set datafusion.explain.logical_plan_only = true; + + +# Filter should have a StringView literal and no column cast +query TT +explain SELECT column1_utf8 from test where column1_utf8view = 'Andrew'; +---- +logical_plan +01)Projection: test.column1_utf8 +02)--Filter: test.column1_utf8view = Utf8View("Andrew") +03)----TableScan: test projection=[column1_utf8, column1_utf8view] + +# reverse order should be the same +query TT +explain SELECT column1_utf8 from test where 'Andrew' = column1_utf8view; +---- +logical_plan +01)Projection: test.column1_utf8 +02)--Filter: test.column1_utf8view = Utf8View("Andrew") +03)----TableScan: test projection=[column1_utf8, column1_utf8view] + +query TT +explain SELECT column1_utf8 from test where column1_utf8 = arrow_cast('Andrew', 'Utf8View'); +---- +logical_plan +01)Filter: test.column1_utf8 = Utf8("Andrew") +02)--TableScan: test projection=[column1_utf8] + +query TT +explain SELECT column1_utf8 from test where arrow_cast('Andrew', 'Utf8View') = column1_utf8; +---- +logical_plan +01)Filter: test.column1_utf8 = Utf8("Andrew") +02)--TableScan: test projection=[column1_utf8] + +query TT +explain SELECT column1_utf8 from test where column1_utf8view = arrow_cast('Andrew', 'Dictionary(Int32, Utf8)'); +---- +logical_plan +01)Projection: test.column1_utf8 +02)--Filter: test.column1_utf8view = Utf8View("Andrew") +03)----TableScan: test projection=[column1_utf8, column1_utf8view] + +query TT +explain SELECT column1_utf8 from test where arrow_cast('Andrew', 'Dictionary(Int32, Utf8)') = column1_utf8view; +---- +logical_plan +01)Projection: test.column1_utf8 +02)--Filter: test.column1_utf8view = Utf8View("Andrew") +03)----TableScan: test projection=[column1_utf8, column1_utf8view] + +# compare string / stringview +# Should cast string -> stringview (which is cheap), not stringview -> string (which is not) +query TT +explain SELECT column1_utf8 from test where column1_utf8view = column2_utf8; +---- +logical_plan +01)Projection: test.column1_utf8 +02)--Filter: test.column1_utf8view = CAST(test.column2_utf8 AS Utf8View) +03)----TableScan: test projection=[column1_utf8, column2_utf8, column1_utf8view] + +query TT +explain SELECT column1_utf8 from test where column2_utf8 = column1_utf8view; +---- +logical_plan +01)Projection: test.column1_utf8 +02)--Filter: CAST(test.column2_utf8 AS Utf8View) = test.column1_utf8view +03)----TableScan: test projection=[column1_utf8, column2_utf8, column1_utf8view] + +query TT +EXPLAIN SELECT + COUNT(DISTINCT column1_utf8), + COUNT(DISTINCT column1_utf8view), + COUNT(DISTINCT column1_dict) +FROM test; +---- +logical_plan +01)Aggregate: groupBy=[[]], aggr=[[count(DISTINCT test.column1_utf8), count(DISTINCT test.column1_utf8view), count(DISTINCT test.column1_dict)]] +02)--TableScan: test projection=[column1_utf8, column1_utf8view, column1_dict] + + +### `STARTS_WITH` + +# Test STARTS_WITH with utf8view against utf8view, utf8, and largeutf8 +# (should be no casts) +query TT +EXPLAIN SELECT + STARTS_WITH(column1_utf8view, column2_utf8view) as c1, + STARTS_WITH(column1_utf8view, column2_utf8) as c2, + STARTS_WITH(column1_utf8view, column2_large_utf8) as c3 +FROM test; +---- +logical_plan +01)Projection: starts_with(test.column1_utf8view, test.column2_utf8view) AS c1, starts_with(test.column1_utf8view, CAST(test.column2_utf8 AS Utf8View)) AS c2, starts_with(test.column1_utf8view, CAST(test.column2_large_utf8 AS Utf8View)) AS c3 +02)--TableScan: test projection=[column2_utf8, column2_large_utf8, column1_utf8view, column2_utf8view] + +query BBB +SELECT + STARTS_WITH(column1_utf8view, column2_utf8view) as c1, + STARTS_WITH(column1_utf8view, column2_utf8) as c2, + STARTS_WITH(column1_utf8view, column2_large_utf8) as c3 +FROM test; +---- +false false false +true true true +true true true +false false false +NULL NULL NULL + +# Test STARTS_WITH with utf8 against utf8view, utf8, and largeutf8 +# Should work, but will have to cast to common types +# should cast utf8 -> utf8view and largeutf8 -> utf8view +query TT +EXPLAIN SELECT + STARTS_WITH(column1_utf8, column2_utf8view) as c1, + STARTS_WITH(column1_utf8, column2_utf8) as c3, + STARTS_WITH(column1_utf8, column2_large_utf8) as c4 +FROM test; +---- +logical_plan +01)Projection: starts_with(CAST(test.column1_utf8 AS Utf8View), test.column2_utf8view) AS c1, starts_with(test.column1_utf8, test.column2_utf8) AS c3, starts_with(CAST(test.column1_utf8 AS LargeUtf8), test.column2_large_utf8) AS c4 +02)--TableScan: test projection=[column1_utf8, column2_utf8, column2_large_utf8, column2_utf8view] + +query BBB + SELECT + STARTS_WITH(column1_utf8, column2_utf8view) as c1, + STARTS_WITH(column1_utf8, column2_utf8) as c3, + STARTS_WITH(column1_utf8, column2_large_utf8) as c4 +FROM test; +---- +false false false +true true true +true true true +false false false +NULL NULL NULL + + +# Test STARTS_WITH with utf8view against literals +# In this case, the literals should be cast to utf8view. The columns +# should not be cast to utf8. +query TT +EXPLAIN SELECT + STARTS_WITH(column1_utf8view, 'äöüß') as c1, + STARTS_WITH(column1_utf8view, '') as c2, + STARTS_WITH(column1_utf8view, NULL) as c3, + STARTS_WITH(NULL, column1_utf8view) as c4 +FROM test; +---- +logical_plan +01)Projection: starts_with(test.column1_utf8view, Utf8View("äöüß")) AS c1, starts_with(test.column1_utf8view, Utf8View("")) AS c2, starts_with(test.column1_utf8view, Utf8View(NULL)) AS c3, starts_with(Utf8View(NULL), test.column1_utf8view) AS c4 +02)--TableScan: test projection=[column1_utf8view] + +query TT +EXPLAIN SELECT + INITCAP(column1_utf8view) as c +FROM test; +---- +logical_plan +01)Projection: initcap(test.column1_utf8view) AS c +02)--TableScan: test projection=[column1_utf8view] + + +# Create a table with lowercase strings +statement ok +CREATE TABLE test_lowercase AS SELECT + lower(column1_utf8) as column1_utf8_lower, + lower(column1_large_utf8) as column1_large_utf8_lower, + lower(column1_utf8view) as column1_utf8view_lower +FROM test; + +# Test INITCAP with utf8view, utf8, and largeutf8 +# Should not cast anything +query TT +EXPLAIN SELECT + INITCAP(column1_utf8view_lower) as c1, + INITCAP(column1_utf8_lower) as c2, + INITCAP(column1_large_utf8_lower) as c3 +FROM test_lowercase; +---- +logical_plan +01)Projection: initcap(test_lowercase.column1_utf8view_lower) AS c1, initcap(test_lowercase.column1_utf8_lower) AS c2, initcap(test_lowercase.column1_large_utf8_lower) AS c3 +02)--TableScan: test_lowercase projection=[column1_utf8_lower, column1_large_utf8_lower, column1_utf8view_lower] + +statement ok +drop table test_lowercase + +# Ensure string functions use native StringView implementation +# and do not fall back to Utf8 or LargeUtf8 +# Should see no casts to Utf8 in the plans below + +## Ensure no casts for LIKE/ILIKE +query TT +EXPLAIN SELECT + column1_utf8view like 'foo' as "like", + column1_utf8view ilike 'foo' as "ilike" +FROM test; +---- +logical_plan +01)Projection: test.column1_utf8view LIKE Utf8View("foo") AS like, test.column1_utf8view ILIKE Utf8View("foo") AS ilike +02)--TableScan: test projection=[column1_utf8view] + + +query TT +EXPLAIN SELECT + SUBSTR(column1_utf8view, 1, 3) as c1, + SUBSTR(column2_utf8, 1, 3) as c2, + SUBSTR(column2_large_utf8, 1, 3) as c3 +FROM test; +---- +logical_plan +01)Projection: substr(test.column1_utf8view, Int64(1), Int64(3)) AS c1, substr(test.column2_utf8, Int64(1), Int64(3)) AS c2, substr(test.column2_large_utf8, Int64(1), Int64(3)) AS c3 +02)--TableScan: test projection=[column2_utf8, column2_large_utf8, column1_utf8view] + +## Ensure no casts for SUBSTR + +query TT +EXPLAIN SELECT + SUBSTR(column1_utf8view, 1, 3) as c1, + SUBSTR(column2_utf8, 1, 3) as c2, + SUBSTR(column2_large_utf8, 1, 3) as c3 +FROM test; +---- +logical_plan +01)Projection: substr(test.column1_utf8view, Int64(1), Int64(3)) AS c1, substr(test.column2_utf8, Int64(1), Int64(3)) AS c2, substr(test.column2_large_utf8, Int64(1), Int64(3)) AS c3 +02)--TableScan: test projection=[column2_utf8, column2_large_utf8, column1_utf8view] + +# Test ASCII with utf8view against utf8view, utf8, and largeutf8 +# (should be no casts) +query TT +EXPLAIN SELECT + ASCII(column1_utf8view) as c1, + ASCII(column2_utf8) as c2, + ASCII(column2_large_utf8) as c3 +FROM test; +---- +logical_plan +01)Projection: ascii(test.column1_utf8view) AS c1, ascii(test.column2_utf8) AS c2, ascii(test.column2_large_utf8) AS c3 +02)--TableScan: test projection=[column2_utf8, column2_large_utf8, column1_utf8view] + +query TT +EXPLAIN SELECT + ASCII(column1_utf8) as c1, + ASCII(column1_large_utf8) as c2, + ASCII(column2_utf8view) as c3, + ASCII('hello') as c4, + ASCII(arrow_cast('world', 'Utf8View')) as c5 +FROM test; +---- +logical_plan +01)Projection: ascii(test.column1_utf8) AS c1, ascii(test.column1_large_utf8) AS c2, ascii(test.column2_utf8view) AS c3, Int32(104) AS c4, Int32(119) AS c5 +02)--TableScan: test projection=[column1_utf8, column1_large_utf8, column2_utf8view] + +# Test ASCII with literals cast to Utf8View +query TT +EXPLAIN SELECT + ASCII(arrow_cast('äöüß', 'Utf8View')) as c1, + ASCII(arrow_cast('', 'Utf8View')) as c2, + ASCII(arrow_cast(NULL, 'Utf8View')) as c3 +FROM test; +---- +logical_plan +01)Projection: Int32(228) AS c1, Int32(0) AS c2, Int32(NULL) AS c3 +02)--TableScan: test projection=[] + +## Ensure no casts for BTRIM +# Test BTRIM with Utf8View input +query TT +EXPLAIN SELECT + BTRIM(column1_utf8view) AS l +FROM test; +---- +logical_plan +01)Projection: btrim(test.column1_utf8view) AS l +02)--TableScan: test projection=[column1_utf8view] + +# Test BTRIM with Utf8View input and Utf8View pattern +query TT +EXPLAIN SELECT + BTRIM(column1_utf8view, 'foo') AS l +FROM test; +---- +logical_plan +01)Projection: btrim(test.column1_utf8view, Utf8View("foo")) AS l +02)--TableScan: test projection=[column1_utf8view] + +# Test BTRIM with Utf8View bytes longer than 12 +query TT +EXPLAIN SELECT + BTRIM(column1_utf8view, 'this is longer than 12') AS l +FROM test; +---- +logical_plan +01)Projection: btrim(test.column1_utf8view, Utf8View("this is longer than 12")) AS l +02)--TableScan: test projection=[column1_utf8view] + +## Ensure no casts for LTRIM +# Test LTRIM with Utf8View input +query TT +EXPLAIN SELECT + LTRIM(column1_utf8view) AS l +FROM test; +---- +logical_plan +01)Projection: ltrim(test.column1_utf8view) AS l +02)--TableScan: test projection=[column1_utf8view] + +# Test LTRIM with Utf8View input and Utf8View pattern +query TT +EXPLAIN SELECT + LTRIM(column1_utf8view, 'foo') AS l +FROM test; +---- +logical_plan +01)Projection: ltrim(test.column1_utf8view, Utf8View("foo")) AS l +02)--TableScan: test projection=[column1_utf8view] + +# Test LTRIM with Utf8View bytes longer than 12 +query TT +EXPLAIN SELECT + LTRIM(column1_utf8view, 'this is longer than 12') AS l +FROM test; +---- +logical_plan +01)Projection: ltrim(test.column1_utf8view, Utf8View("this is longer than 12")) AS l +02)--TableScan: test projection=[column1_utf8view] + +## ensure no casts for RTRIM +# Test RTRIM with Utf8View input +query TT +EXPLAIN SELECT + RTRIM(column1_utf8view) AS l +FROM test; +---- +logical_plan +01)Projection: rtrim(test.column1_utf8view) AS l +02)--TableScan: test projection=[column1_utf8view] + +# Test RTRIM with Utf8View input and Utf8View pattern +query TT +EXPLAIN SELECT + RTRIM(column1_utf8view, 'foo') AS l +FROM test; +---- +logical_plan +01)Projection: rtrim(test.column1_utf8view, Utf8View("foo")) AS l +02)--TableScan: test projection=[column1_utf8view] + +# Test RTRIM with Utf8View bytes longer than 12 +query TT +EXPLAIN SELECT + RTRIM(column1_utf8view, 'this is longer than 12') AS l +FROM test; +---- +logical_plan +01)Projection: rtrim(test.column1_utf8view, Utf8View("this is longer than 12")) AS l +02)--TableScan: test projection=[column1_utf8view] + +## Ensure no casts for CHARACTER_LENGTH +query TT +EXPLAIN SELECT + CHARACTER_LENGTH(column1_utf8view) AS l +FROM test; +---- +logical_plan +01)Projection: character_length(test.column1_utf8view) AS l +02)--TableScan: test projection=[column1_utf8view] + +## Ensure no casts for CONCAT Utf8View +query TT +EXPLAIN SELECT + concat(column1_utf8view, column2_utf8view) as c +FROM test; +---- +logical_plan +01)Projection: concat(test.column1_utf8view, test.column2_utf8view) AS c +02)--TableScan: test projection=[column1_utf8view, column2_utf8view] + +## Ensure no casts for CONCAT LargeUtf8 +query TT +EXPLAIN SELECT + concat(column1_large_utf8, column2_large_utf8) as c +FROM test; +---- +logical_plan +01)Projection: concat(test.column1_large_utf8, test.column2_large_utf8) AS c +02)--TableScan: test projection=[column1_large_utf8, column2_large_utf8] + +## Ensure no casts for CONCAT_WS +query TT +EXPLAIN SELECT + concat_ws(', ', column1_utf8view, column2_utf8view) as c +FROM test; +---- +logical_plan +01)Projection: concat_ws(Utf8(", "), test.column1_utf8view, test.column2_utf8view) AS c +02)--TableScan: test projection=[column1_utf8view, column2_utf8view] + +## Ensure no casts for CONTAINS +query TT +EXPLAIN SELECT + CONTAINS(column1_utf8view, 'foo') as c1, + CONTAINS(column1_utf8view, column2_utf8view) as c2, + CONTAINS(column1_utf8view, column2_large_utf8) as c3, + CONTAINS(column1_utf8, column2_utf8view) as c4, + CONTAINS(column1_utf8, column2_utf8) as c5, + CONTAINS(column1_utf8, column2_large_utf8) as c6, + CONTAINS(column1_large_utf8, column1_utf8view) as c7, + CONTAINS(column1_large_utf8, column2_utf8) as c8, + CONTAINS(column1_large_utf8, column2_large_utf8) as c9 +FROM test; +---- +logical_plan +01)Projection: contains(test.column1_utf8view, Utf8View("foo")) AS c1, contains(test.column1_utf8view, test.column2_utf8view) AS c2, contains(test.column1_utf8view, CAST(test.column2_large_utf8 AS Utf8View)) AS c3, contains(CAST(test.column1_utf8 AS Utf8View), test.column2_utf8view) AS c4, contains(test.column1_utf8, test.column2_utf8) AS c5, contains(CAST(test.column1_utf8 AS LargeUtf8), test.column2_large_utf8) AS c6, contains(CAST(test.column1_large_utf8 AS Utf8View), test.column1_utf8view) AS c7, contains(test.column1_large_utf8, CAST(test.column2_utf8 AS LargeUtf8)) AS c8, contains(test.column1_large_utf8, test.column2_large_utf8) AS c9 +02)--TableScan: test projection=[column1_utf8, column2_utf8, column1_large_utf8, column2_large_utf8, column1_utf8view, column2_utf8view] + +## Ensure no casts for ENDS_WITH +query TT +EXPLAIN SELECT + ENDS_WITH(column1_utf8view, 'foo') as c1, + ENDS_WITH(column2_utf8view, column2_utf8view) as c2 +FROM test; +---- +logical_plan +01)Projection: ends_with(test.column1_utf8view, Utf8View("foo")) AS c1, ends_with(test.column2_utf8view, test.column2_utf8view) AS c2 +02)--TableScan: test projection=[column1_utf8view, column2_utf8view] + +## Ensure no casts for LEVENSHTEIN +query TT +EXPLAIN SELECT + levenshtein(column1_utf8view, 'foo') as c1, + levenshtein(column1_utf8view, column2_utf8view) as c2 +FROM test; +---- +logical_plan +01)Projection: levenshtein(test.column1_utf8view, Utf8View("foo")) AS c1, levenshtein(test.column1_utf8view, test.column2_utf8view) AS c2 +02)--TableScan: test projection=[column1_utf8view, column2_utf8view] + +## Ensure no casts for LOWER +query TT +EXPLAIN SELECT + LOWER(column1_utf8view) as c1 +FROM test; +---- +logical_plan +01)Projection: lower(test.column1_utf8view) AS c1 +02)--TableScan: test projection=[column1_utf8view] + +## Ensure no casts for UPPER +query TT +EXPLAIN SELECT + UPPER(column1_utf8view) as c1 +FROM test; +---- +logical_plan +01)Projection: upper(test.column1_utf8view) AS c1 +02)--TableScan: test projection=[column1_utf8view] + +## Ensure no casts for LPAD +query TT +EXPLAIN SELECT + LPAD(column1_utf8view, 12, ' ') as c1 +FROM test; +---- +logical_plan +01)Projection: lpad(test.column1_utf8view, Int64(12), Utf8(" ")) AS c1 +02)--TableScan: test projection=[column1_utf8view] + +query TT +EXPLAIN SELECT + LPAD(column1_utf8view, 12, column2_large_utf8) as c1 +FROM test; +---- +logical_plan +01)Projection: lpad(test.column1_utf8view, Int64(12), test.column2_large_utf8) AS c1 +02)--TableScan: test projection=[column2_large_utf8, column1_utf8view] + +query TT +EXPLAIN SELECT + LPAD(column1_utf8view, 12, column2_utf8view) as c1 +FROM test; +---- +logical_plan +01)Projection: lpad(test.column1_utf8view, Int64(12), test.column2_utf8view) AS c1 +02)--TableScan: test projection=[column1_utf8view, column2_utf8view] + +## Ensure no casts for OCTET_LENGTH +query TT +EXPLAIN SELECT + OCTET_LENGTH(column1_utf8view) as c1 +FROM test; +---- +logical_plan +01)Projection: octet_length(test.column1_utf8view) AS c1 +02)--TableScan: test projection=[column1_utf8view] + +## Ensure no casts for OVERLAY +query TT +EXPLAIN SELECT + OVERLAY(column1_utf8view PLACING 'foo' FROM 2 ) as c1 +FROM test; +---- +logical_plan +01)Projection: overlay(test.column1_utf8view, Utf8View("foo"), Int64(2)) AS c1 +02)--TableScan: test projection=[column1_utf8view] + +## Should run CONCAT successfully with utf8 and utf8view +query T +SELECT + concat(column1_utf8view, column2_utf8) as c +FROM test; +---- +AndrewX +XiangpengXiangpeng +RaphaelR +Warsaw +R + +## Should run CONCAT successfully with utf8 utf8view and largeutf8 +query T +SELECT + concat(column1_utf8view, column2_utf8, column2_large_utf8) as c +FROM test; +---- +AndrewXX +XiangpengXiangpengXiangpeng +RaphaelRR +WarsawWarsaw +RR + +## Ensure no casts for REGEXP_LIKE +query TT +EXPLAIN SELECT + REGEXP_LIKE(column1_utf8view, '^https?://(?:www\.)?([^/]+)/.*$') AS k +FROM test; +---- +logical_plan +01)Projection: regexp_like(test.column1_utf8view, Utf8("^https?://(?:www\.)?([^/]+)/.*$")) AS k +02)--TableScan: test projection=[column1_utf8view] + +## Ensure no casts for REGEXP_MATCH +query TT +EXPLAIN SELECT + REGEXP_MATCH(column1_utf8view, '^https?://(?:www\.)?([^/]+)/.*$') AS k +FROM test; +---- +logical_plan +01)Projection: regexp_match(CAST(test.column1_utf8view AS Utf8), Utf8("^https?://(?:www\.)?([^/]+)/.*$")) AS k +02)--TableScan: test projection=[column1_utf8view] + +## Ensure no casts for REGEXP_REPLACE +query TT +EXPLAIN SELECT + REGEXP_REPLACE(column1_utf8view, '^https?://(?:www\.)?([^/]+)/.*$', '\1') AS k +FROM test; +---- +logical_plan +01)Projection: regexp_replace(test.column1_utf8view, Utf8("^https?://(?:www\.)?([^/]+)/.*$"), Utf8("\1")) AS k +02)--TableScan: test projection=[column1_utf8view] + +## Ensure no casts for REPEAT +query TT +EXPLAIN SELECT + REPEAT(column1_utf8view, 2) as c1 +FROM test; +---- +logical_plan +01)Projection: repeat(test.column1_utf8view, Int64(2)) AS c1 +02)--TableScan: test projection=[column1_utf8view] + +## Ensure no casts for REPLACE +query TT +EXPLAIN SELECT + REPLACE(column1_utf8view, 'foo', 'bar') as c1, + REPLACE(column1_utf8view, column2_utf8view, 'bar') as c2 +FROM test; +---- +logical_plan +01)Projection: replace(test.column1_utf8view, Utf8View("foo"), Utf8View("bar")) AS c1, replace(test.column1_utf8view, test.column2_utf8view, Utf8View("bar")) AS c2 +02)--TableScan: test projection=[column1_utf8view, column2_utf8view] + +## Ensure no casts for REVERSE +query TT +EXPLAIN SELECT + REVERSE(column1_utf8view) as c1 +FROM test; +---- +logical_plan +01)Projection: reverse(test.column1_utf8view) AS c1 +02)--TableScan: test projection=[column1_utf8view] + +## Ensure no casts for RIGHT +query TT +EXPLAIN SELECT + RIGHT(column1_utf8view, 3) as c2 +FROM test; +---- +logical_plan +01)Projection: right(test.column1_utf8view, Int64(3)) AS c2 +02)--TableScan: test projection=[column1_utf8view] + +## Ensure no casts for LEFT +query TT +EXPLAIN SELECT + LEFT(column1_utf8view, 3) as c2 +FROM test; +---- +logical_plan +01)Projection: left(test.column1_utf8view, Int64(3)) AS c2 +02)--TableScan: test projection=[column1_utf8view] + +## Ensure no casts for RPAD +query TT +EXPLAIN SELECT + RPAD(column1_utf8view, 1) as c1, + RPAD(column1_utf8view, 2, column2_utf8view) as c2 +FROM test; +---- +logical_plan +01)Projection: rpad(test.column1_utf8view, Int64(1)) AS c1, rpad(test.column1_utf8view, Int64(2), test.column2_utf8view) AS c2 +02)--TableScan: test projection=[column1_utf8view, column2_utf8view] + +query TT +EXPLAIN SELECT + RPAD(column1_utf8view, 12, column2_large_utf8) as c1 +FROM test; +---- +logical_plan +01)Projection: rpad(test.column1_utf8view, Int64(12), test.column2_large_utf8) AS c1 +02)--TableScan: test projection=[column2_large_utf8, column1_utf8view] + +query TT +EXPLAIN SELECT + RPAD(column1_utf8view, 12, column2_utf8view) as c1 +FROM test; +---- +logical_plan +01)Projection: rpad(test.column1_utf8view, Int64(12), test.column2_utf8view) AS c1 +02)--TableScan: test projection=[column1_utf8view, column2_utf8view] + +## Ensure no casts for SPLIT_PART +query TT +EXPLAIN SELECT + SPLIT_PART(column1_utf8view, 'f', 1) as c1, + SPLIT_PART('testtesttest',column1_utf8view, 1) as c2 +FROM test; +---- +logical_plan +01)Projection: split_part(test.column1_utf8view, Utf8("f"), Int64(1)) AS c1, split_part(Utf8("testtesttest"), test.column1_utf8view, Int64(1)) AS c2 +02)--TableScan: test projection=[column1_utf8view] + +## Ensure no casts for STRPOS +query TT +EXPLAIN SELECT + STRPOS(column1_utf8view, 'f') as c, + STRPOS(column1_utf8view, column2_utf8view) as c2 +FROM test; +---- +logical_plan +01)Projection: strpos(test.column1_utf8view, Utf8View("f")) AS c, strpos(test.column1_utf8view, test.column2_utf8view) AS c2 +02)--TableScan: test projection=[column1_utf8view, column2_utf8view] + +## Ensure no casts for SUBSTR +query TT +EXPLAIN SELECT + SUBSTR(column1_utf8view, 1) as c, + SUBSTR(column1_utf8view, 1 ,2) as c2 +FROM test; +---- +logical_plan +01)Projection: substr(test.column1_utf8view, Int64(1)) AS c, substr(test.column1_utf8view, Int64(1), Int64(2)) AS c2 +02)--TableScan: test projection=[column1_utf8view] + +## Ensure no casts for SUBSTRINDEX +query TT +EXPLAIN SELECT + SUBSTR_INDEX(column1_utf8view, 'a', 1) as c, + SUBSTR_INDEX(column1_utf8view, 'a', 2) as c2 +FROM test; +---- +logical_plan +01)Projection: substr_index(test.column1_utf8view, Utf8View("a"), Int64(1)) AS c, substr_index(test.column1_utf8view, Utf8View("a"), Int64(2)) AS c2 +02)--TableScan: test projection=[column1_utf8view] + + +## Ensure no casts on columns for STARTS_WITH +query TT +EXPLAIN SELECT + STARTS_WITH(column1_utf8view, 'foo') as c, + STARTS_WITH(column1_utf8view, column2_utf8view) as c2 +FROM test; +---- +logical_plan +01)Projection: starts_with(test.column1_utf8view, Utf8View("foo")) AS c, starts_with(test.column1_utf8view, test.column2_utf8view) AS c2 +02)--TableScan: test projection=[column1_utf8view, column2_utf8view] + +## Ensure no casts for TRANSLATE +query TT +EXPLAIN SELECT + TRANSLATE(column1_utf8view, 'foo', 'bar') as c +FROM test; +---- +logical_plan +01)Projection: translate(test.column1_utf8view, Utf8("foo"), Utf8("bar")) AS c +02)--TableScan: test projection=[column1_utf8view] + +## Ensure no casts for FIND_IN_SET +query TT +EXPLAIN SELECT + FIND_IN_SET(column1_utf8view, 'a,b,c,d') as c +FROM test; +---- +logical_plan +01)Projection: find_in_set(test.column1_utf8view, Utf8View("a,b,c,d")) AS c +02)--TableScan: test projection=[column1_utf8view] + +## Ensure no casts for to_date +query TT +EXPLAIN SELECT + to_date(column1_utf8view, 'a,b,c,d') as c +FROM test; +---- +logical_plan +01)Projection: to_date(test.column1_utf8view, Utf8("a,b,c,d")) AS c +02)--TableScan: test projection=[column1_utf8view] + +## Ensure no casts for to_timestamp +query TT +EXPLAIN SELECT + to_timestamp(column1_utf8view, 'a,b,c,d') as c +FROM test; +---- +logical_plan +01)Projection: to_timestamp(test.column1_utf8view, Utf8("a,b,c,d")) AS c +02)--TableScan: test projection=[column1_utf8view] + +## Ensure no casts for binary operators +# `~` operator (regex match) +query TT +EXPLAIN SELECT + column1_utf8view ~ 'an' AS c1 +FROM test; +---- +logical_plan +01)Projection: CAST(test.column1_utf8view AS Utf8) LIKE Utf8("%an%") AS c1 +02)--TableScan: test projection=[column1_utf8view] + +# `~*` operator (regex match case-insensitive) +query TT +EXPLAIN SELECT + column1_utf8view ~* '^a.{3}e' AS c1 +FROM test; +---- +logical_plan +01)Projection: CAST(test.column1_utf8view AS Utf8) ~* Utf8("^a.{3}e") AS c1 +02)--TableScan: test projection=[column1_utf8view] + +# `!~~` operator (not like match) +query TT +EXPLAIN SELECT + column1_utf8view !~~ 'xia_g%g' AS c1 +FROM test; +---- +logical_plan +01)Projection: CAST(test.column1_utf8view AS Utf8) !~~ Utf8("xia_g%g") AS c1 +02)--TableScan: test projection=[column1_utf8view] + +# `!~~*` operator (not like match case-insensitive) +query TT +EXPLAIN SELECT + column1_utf8view !~~* 'xia_g%g' AS c1 +FROM test; +---- +logical_plan +01)Projection: CAST(test.column1_utf8view AS Utf8) !~~* Utf8("xia_g%g") AS c1 +02)--TableScan: test projection=[column1_utf8view] + +# coercions between stringview and date types +statement ok +create table dates (dt date) as values + (date '2024-01-23'), + (date '2023-11-30'); + +query D +select t.dt from dates t where arrow_cast('2024-01-01', 'Utf8View') < t.dt; +---- +2024-01-23 + +statement ok +drop table dates; + +### Tests for `||` with Utf8View specifically + +statement ok +create table temp as values +('value1', arrow_cast('rust', 'Utf8View'), arrow_cast('fast', 'Utf8View')), +('value2', arrow_cast('datafusion', 'Utf8View'), arrow_cast('cool', 'Utf8View')); + +query TTT +select arrow_typeof(column1), arrow_typeof(column2), arrow_typeof(column3) from temp; +---- +Utf8 Utf8View Utf8View +Utf8 Utf8View Utf8View + +query TT +explain select column2 || 'is' || column3 from temp; +---- +logical_plan +01)Projection: temp.column2 || Utf8View("is") || temp.column3 AS temp.column2 || Utf8("is") || temp.column3 +02)--TableScan: temp projection=[column2, column3] + +# should not cast the column2 to utf8 +query TT +explain select column2||' is fast' from temp; +---- +logical_plan +01)Projection: temp.column2 || Utf8View(" is fast") AS temp.column2 || Utf8(" is fast") +02)--TableScan: temp projection=[column2] + +query TT +explain select column2||column3 from temp; +---- +logical_plan +01)Projection: temp.column2 || temp.column3 +02)--TableScan: temp projection=[column2, column3] + +statement ok +drop table test diff --git a/datafusion/sqllogictest/test_files/strings.slt b/datafusion/sqllogictest/test_files/strings.slt index 27ed0e2d0983..81b8f4b2da9a 100644 --- a/datafusion/sqllogictest/test_files/strings.slt +++ b/datafusion/sqllogictest/test_files/strings.slt @@ -17,7 +17,7 @@ statement ok CREATE TABLE test( - s TEXT, + s TEXT ) as VALUES ('p1'), ('p1e1'), @@ -46,6 +46,51 @@ P1m1e1 p1m1e1 p2m1e1 +# REGEX +query T rowsort +SELECT s FROM test WHERE s ~ 'p[12].*'; +---- +p1 +p1e1 +p1m1e1 +p2 +p2e1 +p2m1e1 + +# REGEX nocase +query T rowsort +SELECT s FROM test WHERE s ~* 'p[12].*'; +---- +P1 +P1e1 +P1m1e1 +p1 +p1e1 +p1m1e1 +p2 +p2e1 +p2m1e1 + +# SIMILAR TO +query T rowsort +SELECT s FROM test WHERE s SIMILAR TO 'p[12].*'; +---- +p1 +p1e1 +p1m1e1 +p2 +p2e1 +p2m1e1 + +# NOT SIMILAR TO +query T rowsort +SELECT s FROM test WHERE s NOT SIMILAR TO 'p[12].*'; +---- +P1 +P1e1 +P1m1e1 +e1 + # NOT LIKE query T rowsort SELECT s FROM test WHERE s NOT LIKE 'p1%'; @@ -78,3 +123,52 @@ e1 p2 p2e1 p2m1e1 + +## VARCHAR with length support + +# Lengths can be used by default +query T +SELECT '12345'::VARCHAR(2); +---- +12345 + +# Lengths can not be used when the config setting is disabled + +statement ok +set datafusion.sql_parser.support_varchar_with_length = false; + +query error +SELECT '12345'::VARCHAR(2); + +query error +SELECT s::VARCHAR(2) FROM (VALUES ('12345')) t(s); + +statement ok +create table vals(s char) as values('abc'), ('def'); + +query error +SELECT s::VARCHAR(2) FROM vals + +# Lengths can be used when the config setting is enabled + +statement ok +set datafusion.sql_parser.support_varchar_with_length = true; + +query T +SELECT '12345'::VARCHAR(2) +---- +12345 + +query T +SELECT s::VARCHAR(2) FROM (VALUES ('12345')) t(s) +---- +12345 + +query T +SELECT s::VARCHAR(2) FROM vals +---- +abc +def + +statement ok +drop table vals; diff --git a/datafusion/sqllogictest/test_files/struct.slt b/datafusion/sqllogictest/test_files/struct.slt index 3e685cbb45a0..7596b820c688 100644 --- a/datafusion/sqllogictest/test_files/struct.slt +++ b/datafusion/sqllogictest/test_files/struct.slt @@ -24,13 +24,40 @@ CREATE TABLE values( a INT, b FLOAT, c VARCHAR, - n VARCHAR, + n VARCHAR ) AS VALUES (1, 1.1, 'a', NULL), (2, 2.2, 'b', NULL), (3, 3.3, 'c', NULL) ; + +# named and named less struct fields +statement ok +CREATE TABLE struct_values ( + s1 struct, + s2 struct +) AS VALUES + (struct(1), struct(1, 'string1')), + (struct(2), struct(2, 'string2')), + (struct(3), struct(3, 'string3')) +; + +query ?? +select * from struct_values; +---- +{c0: 1} {a: 1, b: string1} +{c0: 2} {a: 2, b: string2} +{c0: 3} {a: 3, b: string3} + +query TT +select arrow_typeof(s1), arrow_typeof(s2) from struct_values; +---- +Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct([Field { name: "c0", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) Struct([Field { name: "a", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) + + # struct[i] query IRT select struct(1, 3.14, 'h')['c0'], struct(3, 2.55, 'b')['c1'], struct(2, 6.43, 'a')['c2']; @@ -45,6 +72,14 @@ select struct(a, b, c)['c1'] from values; 2.2 3.3 +# explicit invocation of get_field +query R +select get_field(struct(a, b, c), 'c1') from values; +---- +1.1 +2.2 +3.3 + # struct scalar function #1 query ? select struct(1, 3.14, 'e'); @@ -92,9 +127,13 @@ physical_plan 02)--MemoryExec: partitions=1, partition_sizes=[1] # error on 0 arguments -query error DataFusion error: Error during planning: No function matches the given name and argument types 'named_struct\(\)'. You might need to add explicit type casts. +query error select named_struct(); +# error on duplicate field names +query error +select named_struct('c0': 1, 'c1': 2, 'c1': 3); + # error on odd number of arguments #1 query error DataFusion error: Execution error: named_struct requires an even number of arguments, got 1 instead select named_struct('a'); @@ -135,6 +174,13 @@ select named_struct('scalar', 27, 'array', values.a, 'null', NULL) from values; {scalar: 27, array: 2, null: } {scalar: 27, array: 3, null: } +query ? +select {'scalar': 27, 'array': values.a, 'null': NULL} from values; +---- +{scalar: 27, array: 1, null: } +{scalar: 27, array: 2, null: } +{scalar: 27, array: 3, null: } + # named_struct with mixed scalar and array values #2 query ? select named_struct('array', values.a, 'scalar', 27, 'null', NULL) from values; @@ -143,6 +189,13 @@ select named_struct('array', values.a, 'scalar', 27, 'null', NULL) from values; {array: 2, scalar: 27, null: } {array: 3, scalar: 27, null: } +query ? +select {'array': values.a, 'scalar': 27, 'null': NULL} from values; +---- +{array: 1, scalar: 27, null: } +{array: 2, scalar: 27, null: } +{array: 3, scalar: 27, null: } + # named_struct with mixed scalar and array values #3 query ? select named_struct('null', NULL, 'array', values.a, 'scalar', 27) from values; @@ -173,10 +226,372 @@ select named_struct('field_a', 1, 'field_b', 2); ---- {field_a: 1, field_b: 2} +query T +select arrow_typeof(named_struct('first', 1, 'second', 2, 'third', 3)); +---- +Struct([Field { name: "first", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "second", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "third", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) + +query T +select arrow_typeof({'first': 1, 'second': 2, 'third': 3}); +---- +Struct([Field { name: "first", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "second", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "third", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) + +# test nested struct literal +query ? +select {'animal': {'cat': 1, 'dog': 2, 'bird': {'parrot': 3, 'canary': 1}}, 'genre': {'fiction': ['mystery', 'sci-fi', 'fantasy'], 'non-fiction': {'biography': 5, 'history': 7, 'science': {'physics': 2, 'biology': 3}}}, 'vehicle': {'car': {'sedan': 4, 'suv': 2}, 'bicycle': 3, 'boat': ['sailboat', 'motorboat']}, 'weather': {'sunny': True, 'temperature': 25.5, 'wind': {'speed': 10, 'direction': 'NW'}}}; +---- +{animal: {cat: 1, dog: 2, bird: {parrot: 3, canary: 1}}, genre: {fiction: [mystery, sci-fi, fantasy], non-fiction: {biography: 5, history: 7, science: {physics: 2, biology: 3}}}, vehicle: {car: {sedan: 4, suv: 2}, bicycle: 3, boat: [sailboat, motorboat]}, weather: {sunny: true, temperature: 25.5, wind: {speed: 10, direction: NW}}} + +# test tuple as struct +query B +select ('x', 'y') = ('x', 'y'); +---- +true + +query B +select ('x', 'y') = ('y', 'x'); +---- +false + +query error DataFusion error: Error during planning: Cannot infer common argument type for comparison operation Struct.* +select ('x', 'y') = ('x', 'y', 'z'); + +query B +select ('x', 'y') IN (('x', 'y')); +---- +true + +query B +select ('x', 'y') IN (('x', 'y'), ('y', 'x')); +---- +true + +query I +select a from values where (a, c) = (1, 'a'); +---- +1 + +query I +select a from values where (a, c) IN ((1, 'a'), (2, 'b')); +---- +1 +2 + statement ok drop table values; +statement ok +drop table struct_values; + +statement ok +CREATE OR REPLACE VIEW complex_view AS +SELECT { + 'user': { + 'info': { + 'personal': { + 'name': 'John Doe', + 'age': 30, + 'email': 'john.doe@example.com' + }, + 'address': { + 'street': '123 Main St', + 'city': 'Anytown', + 'country': 'Countryland', + 'coordinates': [40.7128, -74.0060] + } + }, + 'preferences': { + 'theme': 'dark', + 'notifications': true, + 'languages': ['en', 'es', 'fr'] + }, + 'stats': { + 'logins': 42, + 'last_active': '2023-09-15', + 'scores': [85, 92, 78, 95], + 'achievements': { + 'badges': ['early_bird', 'top_contributor'], + 'levels': { + 'beginner': true, + 'intermediate': true, + 'advanced': false + } + } + } + }, + 'metadata': { + 'version': '1.0', + 'created_at': '2023-09-01T12:00:00Z' + }, + 'deep_nested': { + 'level1': { + 'level2': { + 'level3': { + 'level4': { + 'level5': { + 'level6': { + 'level7': { + 'level8': { + 'level9': { + 'level10': 'You reached the bottom!' + } + } + } + } + } + } + } + } + } + } +} AS complex_data; + query T -select arrow_typeof(named_struct('first', 1, 'second', 2, 'third', 3)); +SELECT complex_data.user.info.personal.name FROM complex_view; ---- -Struct([Field { name: "first", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "second", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "third", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +John Doe + +query I +SELECT complex_data.user.info.personal.age FROM complex_view; +---- +30 + +query T +SELECT complex_data.user.info.address.city FROM complex_view; +---- +Anytown + +query T +SELECT complex_data.user.preferences.languages[2] FROM complex_view; +---- +es + +query T +SELECT complex_data.deep_nested.level1.level2.level3.level4.level5.level6.level7.level8.level9.level10 FROM complex_view; +---- +You reached the bottom! + +statement ok +drop view complex_view; + +# struct with different keys r1 and r2 is not valid +statement ok +create table t(a struct, b struct) as values (struct('red', 1), struct('blue', 2.3)); + +# Expect same keys for struct type but got mismatched pair r1,c and r2,c +query error +select [a, b] from t; + +statement ok +drop table t; + +# struct with the same key +statement ok +create table t(a struct, b struct) as values (struct('red', 1), struct('blue', 2.3)); + +query T +select arrow_typeof([a, b]) from t; +---- +List(Field { name: "item", data_type: Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) + +query ? +select [a, b] from t; +---- +[{r: red, c: 1.0}, {r: blue, c: 2.3}] + +statement ok +drop table t; + +# Test row alias + +query ? +select row('a', 'b'); +---- +{c0: a, c1: b} + +################################## +# Switch Dialect to DuckDB +################################## + +statement ok +set datafusion.sql_parser.dialect = 'DuckDB'; + +statement ok +CREATE TABLE struct_values ( + s1 struct(a int, b varchar), + s2 struct(a int, b varchar) +) AS VALUES + (row(1, 'red'), row(1, 'string1')), + (row(2, 'blue'), row(2, 'string2')), + (row(3, 'green'), row(3, 'string3')) +; + +statement ok +drop table struct_values; + +statement ok +create table t (c1 struct(r varchar, b int), c2 struct(r varchar, b float)) as values ( + row('red', 2), + row('blue', 2.3) +); + +query ?? +select * from t; +---- +{r: red, b: 2} {r: blue, b: 2.3} + +query T +select arrow_typeof(c1) from t; +---- +Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) + +query T +select arrow_typeof(c2) from t; +---- +Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) + +statement ok +drop table t; + +statement ok +create table t as values({r: 'a', c: 1}), ({r: 'b', c: 2.3}); + +query ? +select * from t; +---- +{c0: a, c1: 1.0} +{c0: b, c1: 2.3} + +query T +select arrow_typeof(column1) from t; +---- +Struct([Field { name: "c0", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct([Field { name: "c0", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c1", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) + +statement ok +drop table t; + +query error DataFusion error: Arrow error: Cast error: Cannot cast string 'a' to value of Float64 type +create table t as values({r: 'a', c: 1}), ({c: 2.3, r: 'b'}); + +################################## +## Test Coalesce with Struct +################################## + +statement ok +CREATE TABLE t ( + s1 struct(a int, b varchar), + s2 struct(a float, b varchar) +) AS VALUES + (row(1, 'red'), row(1.1, 'string1')), + (row(2, 'blue'), row(2.2, 'string2')), + (row(3, 'green'), row(33.2, 'string3')) +; + +query ? +select coalesce(s1) from t; +---- +{a: 1, b: red} +{a: 2, b: blue} +{a: 3, b: green} + +query T +select arrow_typeof(coalesce(s1, s2)) from t; +---- +Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) + +statement ok +drop table t; + +statement ok +CREATE TABLE t ( + s1 struct(a int, b varchar), + s2 struct(a float, b varchar) +) AS VALUES + (row(1, 'red'), row(1.1, 'string1')), + (null, row(2.2, 'string2')), + (row(3, 'green'), row(33.2, 'string3')) +; + +query ? +select coalesce(s1, s2) from t; +---- +{a: 1.0, b: red} +{a: 2.2, b: string2} +{a: 3.0, b: green} + +query T +select arrow_typeof(coalesce(s1, s2)) from t; +---- +Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) +Struct([Field { name: "a", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "b", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) + +statement ok +drop table t; + +# row() with incorrect order +statement error DataFusion error: Arrow error: Cast error: Cannot cast string 'blue' to value of Float32 type +create table t(a struct(r varchar, c int), b struct(r varchar, c float)) as values + (row('red', 1), row(2.3, 'blue')), + (row('purple', 1), row('green', 2.3)); + +# out of order struct literal +# TODO: This query should not fail +statement error DataFusion error: Arrow error: Cast error: Cannot cast string 'b' to value of Int32 type +create table t(a struct(r varchar, c int)) as values ({r: 'a', c: 1}), ({c: 2, r: 'b'}); + +################################## +## Test Array of Struct +################################## + +query ? +select [{r: 'a', c: 1}, {r: 'b', c: 2}]; +---- +[{r: a, c: 1}, {r: b, c: 2}] + +# Can't create a list of struct with different field types +query error +select [{r: 'a', c: 1}, {c: 2, r: 'b'}]; + +statement ok +create table t(a struct(r varchar, c int), b struct(r varchar, c float)) as values (row('a', 1), row('b', 2.3)); + +query T +select arrow_typeof([a, b]) from t; +---- +List(Field { name: "item", data_type: Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) + +statement ok +drop table t; + +# create table with different struct type is fine +statement ok +create table t(a struct(r varchar, c int), b struct(c float, r varchar)) as values (row('a', 1), row(2.3, 'b')); + +# create array with different struct type is not valid +query error +select arrow_typeof([a, b]) from t; + +statement ok +drop table t; + +statement ok +create table t(a struct(r varchar, c int, g float), b struct(r varchar, c float, g int)) as values (row('a', 1, 2.3), row('b', 2.3, 2)); + +# type of each column should not coerced but perserve as it is +query T +select arrow_typeof(a) from t; +---- +Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "g", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) + +# type of each column should not coerced but perserve as it is +query T +select arrow_typeof(b) from t; +---- +Struct([Field { name: "r", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "c", data_type: Float32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "g", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]) + +statement ok +drop table t; diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 7196418af1c7..027b5ca8dcfb 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -66,7 +66,7 @@ CREATE EXTERNAL TABLE IF NOT EXISTS customer ( c_acctbal DECIMAL(15, 2), c_mktsegment VARCHAR, c_comment VARCHAR, -) STORED AS CSV DELIMITER ',' WITH HEADER ROW LOCATION '../core/tests/tpch-csv/customer.csv'; +) STORED AS CSV LOCATION '../core/tests/tpch-csv/customer.csv' OPTIONS ('format.delimiter' ',', 'format.has_header' 'true'); statement ok CREATE EXTERNAL TABLE IF NOT EXISTS orders ( @@ -79,7 +79,7 @@ CREATE EXTERNAL TABLE IF NOT EXISTS orders ( o_clerk VARCHAR, o_shippriority INTEGER, o_comment VARCHAR, -) STORED AS CSV DELIMITER ',' WITH HEADER ROW LOCATION '../core/tests/tpch-csv/orders.csv'; +) STORED AS CSV LOCATION '../core/tests/tpch-csv/orders.csv' OPTIONS ('format.delimiter' ',', 'format.has_header' 'true'); statement ok CREATE EXTERNAL TABLE IF NOT EXISTS lineitem ( @@ -99,7 +99,7 @@ CREATE EXTERNAL TABLE IF NOT EXISTS lineitem ( l_shipinstruct VARCHAR, l_shipmode VARCHAR, l_comment VARCHAR, -) STORED AS CSV DELIMITER ',' WITH HEADER ROW LOCATION '../core/tests/tpch-csv/lineitem.csv'; +) STORED AS CSV LOCATION '../core/tests/tpch-csv/lineitem.csv' OPTIONS ('format.delimiter' ',', 'format.has_header' 'true'); # in_subquery_to_join_with_correlated_outer_filter query ITI rowsort @@ -127,6 +127,19 @@ where t1.t1_id + 12 not in ( ---- 22 b 2 +# wrapped_not_in_subquery_to_join_with_correlated_outer_filter +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where not t1.t1_id + 12 in ( + select t2.t2_id + 1 from t2 where t1.t1_int > 0 + ) +---- +22 b 2 + + # in subquery with two parentheses, see #5529 query ITI rowsort select t1.t1_id, @@ -179,26 +192,28 @@ query TT explain SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id) as t2_sum from t1 ---- logical_plan -01)Projection: t1.t1_id, __scalar_sq_1.SUM(t2.t2_int) AS t2_sum +01)Projection: t1.t1_id, __scalar_sq_1.sum(t2.t2_int) AS t2_sum 02)--Left Join: t1.t1_id = __scalar_sq_1.t2_id 03)----TableScan: t1 projection=[t1_id] 04)----SubqueryAlias: __scalar_sq_1 -05)------Projection: SUM(t2.t2_int), t2.t2_id -06)--------Aggregate: groupBy=[[t2.t2_id]], aggr=[[SUM(CAST(t2.t2_int AS Int64))]] +05)------Projection: sum(t2.t2_int), t2.t2_id +06)--------Aggregate: groupBy=[[t2.t2_id]], aggr=[[sum(CAST(t2.t2_int AS Int64))]] 07)----------TableScan: t2 projection=[t2_id, t2_int] physical_plan -01)ProjectionExec: expr=[t1_id@1 as t1_id, SUM(t2.t2_int)@0 as t2_sum] +01)ProjectionExec: expr=[t1_id@1 as t1_id, sum(t2.t2_int)@0 as t2_sum] 02)--CoalesceBatchesExec: target_batch_size=2 -03)----HashJoinExec: mode=Partitioned, join_type=Right, on=[(t2_id@1, t1_id@0)], projection=[SUM(t2.t2_int)@0, t1_id@2] -04)------ProjectionExec: expr=[SUM(t2.t2_int)@1 as SUM(t2.t2_int), t2_id@0 as t2_id] -05)--------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] +03)----HashJoinExec: mode=Partitioned, join_type=Right, on=[(t2_id@1, t1_id@0)], projection=[sum(t2.t2_int)@0, t1_id@2] +04)------ProjectionExec: expr=[sum(t2.t2_int)@1 as sum(t2.t2_int), t2_id@0 as t2_id] +05)--------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id], aggr=[sum(t2.t2_int)] 06)----------CoalesceBatchesExec: target_batch_size=2 07)------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 -08)--------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] -09)----------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] -10)------CoalesceBatchesExec: target_batch_size=2 -11)--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 -12)----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +08)--------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[sum(t2.t2_int)] +09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +10)------------------MemoryExec: partitions=1, partition_sizes=[1] +11)------CoalesceBatchesExec: target_batch_size=2 +12)--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 +13)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +14)------------MemoryExec: partitions=1, partition_sizes=[1] query II rowsort SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id) as t2_sum from t1 @@ -213,26 +228,28 @@ query TT explain SELECT t1_id, (SELECT sum(t2_int * 1.0) + 1 FROM t2 WHERE t2.t2_id = t1.t1_id) as t2_sum from t1 ---- logical_plan -01)Projection: t1.t1_id, __scalar_sq_1.SUM(t2.t2_int * Float64(1)) + Int64(1) AS t2_sum +01)Projection: t1.t1_id, __scalar_sq_1.sum(t2.t2_int * Float64(1)) + Int64(1) AS t2_sum 02)--Left Join: t1.t1_id = __scalar_sq_1.t2_id 03)----TableScan: t1 projection=[t1_id] 04)----SubqueryAlias: __scalar_sq_1 -05)------Projection: SUM(t2.t2_int * Float64(1)) + Float64(1) AS SUM(t2.t2_int * Float64(1)) + Int64(1), t2.t2_id -06)--------Aggregate: groupBy=[[t2.t2_id]], aggr=[[SUM(CAST(t2.t2_int AS Float64)) AS SUM(t2.t2_int * Float64(1))]] +05)------Projection: sum(t2.t2_int * Float64(1)) + Float64(1) AS sum(t2.t2_int * Float64(1)) + Int64(1), t2.t2_id +06)--------Aggregate: groupBy=[[t2.t2_id]], aggr=[[sum(CAST(t2.t2_int AS Float64)) AS sum(t2.t2_int * Float64(1))]] 07)----------TableScan: t2 projection=[t2_id, t2_int] physical_plan -01)ProjectionExec: expr=[t1_id@1 as t1_id, SUM(t2.t2_int * Float64(1)) + Int64(1)@0 as t2_sum] +01)ProjectionExec: expr=[t1_id@1 as t1_id, sum(t2.t2_int * Float64(1)) + Int64(1)@0 as t2_sum] 02)--CoalesceBatchesExec: target_batch_size=2 -03)----HashJoinExec: mode=Partitioned, join_type=Right, on=[(t2_id@1, t1_id@0)], projection=[SUM(t2.t2_int * Float64(1)) + Int64(1)@0, t1_id@2] -04)------ProjectionExec: expr=[SUM(t2.t2_int * Float64(1))@1 + 1 as SUM(t2.t2_int * Float64(1)) + Int64(1), t2_id@0 as t2_id] -05)--------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int * Float64(1))] +03)----HashJoinExec: mode=Partitioned, join_type=Right, on=[(t2_id@1, t1_id@0)], projection=[sum(t2.t2_int * Float64(1)) + Int64(1)@0, t1_id@2] +04)------ProjectionExec: expr=[sum(t2.t2_int * Float64(1))@1 + 1 as sum(t2.t2_int * Float64(1)) + Int64(1), t2_id@0 as t2_id] +05)--------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id], aggr=[sum(t2.t2_int * Float64(1))] 06)----------CoalesceBatchesExec: target_batch_size=2 07)------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 -08)--------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int * Float64(1))] -09)----------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] -10)------CoalesceBatchesExec: target_batch_size=2 -11)--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 -12)----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +08)--------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[sum(t2.t2_int * Float64(1))] +09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +10)------------------MemoryExec: partitions=1, partition_sizes=[1] +11)------CoalesceBatchesExec: target_batch_size=2 +12)--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 +13)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +14)------------MemoryExec: partitions=1, partition_sizes=[1] query IR rowsort SELECT t1_id, (SELECT sum(t2_int * 1.0) + 1 FROM t2 WHERE t2.t2_id = t1.t1_id) as t2_sum from t1 @@ -247,28 +264,28 @@ query TT explain SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id group by t2_id, 'a') as t2_sum from t1 ---- logical_plan -01)Projection: t1.t1_id, __scalar_sq_1.SUM(t2.t2_int) AS t2_sum +01)Projection: t1.t1_id, __scalar_sq_1.sum(t2.t2_int) AS t2_sum 02)--Left Join: t1.t1_id = __scalar_sq_1.t2_id 03)----TableScan: t1 projection=[t1_id] 04)----SubqueryAlias: __scalar_sq_1 -05)------Projection: SUM(t2.t2_int), t2.t2_id -06)--------Aggregate: groupBy=[[t2.t2_id, Utf8("a")]], aggr=[[SUM(CAST(t2.t2_int AS Int64))]] +05)------Projection: sum(t2.t2_int), t2.t2_id +06)--------Aggregate: groupBy=[[t2.t2_id]], aggr=[[sum(CAST(t2.t2_int AS Int64))]] 07)----------TableScan: t2 projection=[t2_id, t2_int] physical_plan -01)ProjectionExec: expr=[t1_id@0 as t1_id, SUM(t2.t2_int)@1 as t2_sum] +01)ProjectionExec: expr=[t1_id@1 as t1_id, sum(t2.t2_int)@0 as t2_sum] 02)--CoalesceBatchesExec: target_batch_size=2 -03)----HashJoinExec: mode=Partitioned, join_type=Left, on=[(t1_id@0, t2_id@1)], projection=[t1_id@0, SUM(t2.t2_int)@1] -04)------CoalesceBatchesExec: target_batch_size=2 -05)--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 -06)----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] -07)------CoalesceBatchesExec: target_batch_size=2 -08)--------RepartitionExec: partitioning=Hash([t2_id@1], 4), input_partitions=4 -09)----------ProjectionExec: expr=[SUM(t2.t2_int)@2 as SUM(t2.t2_int), t2_id@0 as t2_id] -10)------------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id, Utf8("a")@1 as Utf8("a")], aggr=[SUM(t2.t2_int)], ordering_mode=PartiallySorted([1]) -11)--------------CoalesceBatchesExec: target_batch_size=2 -12)----------------RepartitionExec: partitioning=Hash([t2_id@0, Utf8("a")@1], 4), input_partitions=4 -13)------------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id, a as Utf8("a")], aggr=[SUM(t2.t2_int)], ordering_mode=PartiallySorted([1]) -14)--------------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +03)----HashJoinExec: mode=Partitioned, join_type=Right, on=[(t2_id@1, t1_id@0)], projection=[sum(t2.t2_int)@0, t1_id@2] +04)------ProjectionExec: expr=[sum(t2.t2_int)@1 as sum(t2.t2_int), t2_id@0 as t2_id] +05)--------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id], aggr=[sum(t2.t2_int)] +06)----------CoalesceBatchesExec: target_batch_size=2 +07)------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 +08)--------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[sum(t2.t2_int)] +09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +10)------------------MemoryExec: partitions=1, partition_sizes=[1] +11)------CoalesceBatchesExec: target_batch_size=2 +12)--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 +13)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +14)------------MemoryExec: partitions=1, partition_sizes=[1] query II rowsort SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id group by t2_id, 'a') as t2_sum from t1 @@ -283,29 +300,31 @@ query TT explain SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id having sum(t2_int) < 3) as t2_sum from t1 ---- logical_plan -01)Projection: t1.t1_id, __scalar_sq_1.SUM(t2.t2_int) AS t2_sum +01)Projection: t1.t1_id, __scalar_sq_1.sum(t2.t2_int) AS t2_sum 02)--Left Join: t1.t1_id = __scalar_sq_1.t2_id 03)----TableScan: t1 projection=[t1_id] 04)----SubqueryAlias: __scalar_sq_1 -05)------Projection: SUM(t2.t2_int), t2.t2_id -06)--------Filter: SUM(t2.t2_int) < Int64(3) -07)----------Aggregate: groupBy=[[t2.t2_id]], aggr=[[SUM(CAST(t2.t2_int AS Int64))]] +05)------Projection: sum(t2.t2_int), t2.t2_id +06)--------Filter: sum(t2.t2_int) < Int64(3) +07)----------Aggregate: groupBy=[[t2.t2_id]], aggr=[[sum(CAST(t2.t2_int AS Int64))]] 08)------------TableScan: t2 projection=[t2_id, t2_int] physical_plan -01)ProjectionExec: expr=[t1_id@1 as t1_id, SUM(t2.t2_int)@0 as t2_sum] +01)ProjectionExec: expr=[t1_id@1 as t1_id, sum(t2.t2_int)@0 as t2_sum] 02)--CoalesceBatchesExec: target_batch_size=2 -03)----HashJoinExec: mode=Partitioned, join_type=Right, on=[(t2_id@1, t1_id@0)], projection=[SUM(t2.t2_int)@0, t1_id@2] -04)------ProjectionExec: expr=[SUM(t2.t2_int)@1 as SUM(t2.t2_int), t2_id@0 as t2_id] +03)----HashJoinExec: mode=Partitioned, join_type=Right, on=[(t2_id@1, t1_id@0)], projection=[sum(t2.t2_int)@0, t1_id@2] +04)------ProjectionExec: expr=[sum(t2.t2_int)@1 as sum(t2.t2_int), t2_id@0 as t2_id] 05)--------CoalesceBatchesExec: target_batch_size=2 -06)----------FilterExec: SUM(t2.t2_int)@1 < 3 -07)------------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] +06)----------FilterExec: sum(t2.t2_int)@1 < 3 +07)------------AggregateExec: mode=FinalPartitioned, gby=[t2_id@0 as t2_id], aggr=[sum(t2.t2_int)] 08)--------------CoalesceBatchesExec: target_batch_size=2 09)----------------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 -10)------------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[SUM(t2.t2_int)] -11)--------------------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] -12)------CoalesceBatchesExec: target_batch_size=2 -13)--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 -14)----------MemoryExec: partitions=4, partition_sizes=[1, 0, 0, 0] +10)------------------AggregateExec: mode=Partial, gby=[t2_id@0 as t2_id], aggr=[sum(t2.t2_int)] +11)--------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +12)----------------------MemoryExec: partitions=1, partition_sizes=[1] +13)------CoalesceBatchesExec: target_batch_size=2 +14)--------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 +15)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +16)------------MemoryExec: partitions=1, partition_sizes=[1] query II rowsort SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id having sum(t2_int) < 3) as t2_sum from t1 @@ -333,17 +352,17 @@ where c_acctbal < ( logical_plan 01)Sort: customer.c_custkey ASC NULLS LAST 02)--Projection: customer.c_custkey -03)----Inner Join: customer.c_custkey = __scalar_sq_1.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_1.SUM(orders.o_totalprice) +03)----Inner Join: customer.c_custkey = __scalar_sq_1.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_1.sum(orders.o_totalprice) 04)------TableScan: customer projection=[c_custkey, c_acctbal] 05)------SubqueryAlias: __scalar_sq_1 -06)--------Projection: SUM(orders.o_totalprice), orders.o_custkey -07)----------Aggregate: groupBy=[[orders.o_custkey]], aggr=[[SUM(orders.o_totalprice)]] +06)--------Projection: sum(orders.o_totalprice), orders.o_custkey +07)----------Aggregate: groupBy=[[orders.o_custkey]], aggr=[[sum(orders.o_totalprice)]] 08)------------Projection: orders.o_custkey, orders.o_totalprice 09)--------------Inner Join: orders.o_orderkey = __scalar_sq_2.l_orderkey Filter: CAST(orders.o_totalprice AS Decimal128(25, 2)) < __scalar_sq_2.price 10)----------------TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice] 11)----------------SubqueryAlias: __scalar_sq_2 -12)------------------Projection: SUM(lineitem.l_extendedprice) AS price, lineitem.l_orderkey -13)--------------------Aggregate: groupBy=[[lineitem.l_orderkey]], aggr=[[SUM(lineitem.l_extendedprice)]] +12)------------------Projection: sum(lineitem.l_extendedprice) AS price, lineitem.l_orderkey +13)--------------------Aggregate: groupBy=[[lineitem.l_orderkey]], aggr=[[sum(lineitem.l_extendedprice)]] 14)----------------------TableScan: lineitem projection=[l_orderkey, l_extendedprice] # correlated_where_in @@ -380,7 +399,7 @@ logical_plan 01)Filter: EXISTS () 02)--Subquery: 03)----Projection: t1.t1_int -04)------Filter: t1.t1_id > t1.t1_int +04)------Filter: t1.t1_int < t1.t1_id 05)--------TableScan: t1 06)--TableScan: t1 projection=[t1_id, t1_name, t1_int] @@ -404,13 +423,13 @@ query TT explain SELECT t1_id, t1_name, t1_int FROM t1 WHERE t1_id IN(SELECT t2_id FROM t2 WHERE EXISTS(select * from t1 WHERE t1.t1_int > t2.t2_int)) ---- logical_plan -01)LeftSemi Join: t1.t1_id = __correlated_sq_1.t2_id +01)LeftSemi Join: t1.t1_id = __correlated_sq_2.t2_id 02)--TableScan: t1 projection=[t1_id, t1_name, t1_int] -03)--SubqueryAlias: __correlated_sq_1 +03)--SubqueryAlias: __correlated_sq_2 04)----Projection: t2.t2_id -05)------LeftSemi Join: Filter: __correlated_sq_2.t1_int > t2.t2_int +05)------LeftSemi Join: Filter: __correlated_sq_1.t1_int > t2.t2_int 06)--------TableScan: t2 projection=[t2_id, t2_int] -07)--------SubqueryAlias: __correlated_sq_2 +07)--------SubqueryAlias: __correlated_sq_1 08)----------TableScan: t1 projection=[t1_int] #invalid_scalar_subquery @@ -419,7 +438,7 @@ SELECT t1_id, t1_name, t1_int, (select t2_id, t2_name FROM t2 WHERE t2.t2_id = t #subquery_not_allowed #In/Exist Subquery is not allowed in ORDER BY clause. -statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: In/Exist subquery can only be used in Projection, Filter, Window functions, Aggregate and Join plan nodes +statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: In/Exist subquery can only be used in Projection, Filter, Window functions, Aggregate and Join plan nodes, but was used in \[Sort: t1.t1_int IN \(\) ASC NULLS LAST\] SELECT t1_id, t1_name, t1_int FROM t1 order by t1_int in (SELECT t2_int FROM t2 WHERE t1.t1_id > t1.t1_int) #non_aggregated_correlated_scalar_subquery @@ -451,8 +470,8 @@ explain SELECT t1_id, (SELECT t2_int FROM t2 WHERE t2.t2_int = t1.t1_int limit 1 logical_plan 01)Projection: t1.t1_id, () AS t2_int 02)--Subquery: -03)----Limit: skip=0, fetch=1 -04)------Projection: t2.t2_int +03)----Projection: t2.t2_int +04)------Limit: skip=0, fetch=1 05)--------Filter: t2.t2_int = outer_ref(t1.t1_int) 06)----------TableScan: t2 07)--TableScan: t1 projection=[t1_id, t1_int] @@ -464,8 +483,8 @@ logical_plan 01)Projection: t1.t1_id 02)--Filter: t1.t1_int = () 03)----Subquery: -04)------Limit: skip=0, fetch=1 -05)--------Projection: t2.t2_int +04)------Projection: t2.t2_int +05)--------Limit: skip=0, fetch=1 06)----------Filter: t2.t2_int = outer_ref(t1.t1_int) 07)------------TableScan: t2 08)----TableScan: t1 projection=[t1_id, t1_int] @@ -490,8 +509,18 @@ SELECT t1_id, (SELECT a FROM (select 1 as a) WHERE a = t1.t1_int) as t2_int from 44 NULL #non_equal_correlated_scalar_subquery -statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: Correlated column is not allowed in predicate: t2\.t2_id < outer_ref\(t1\.t1_id\) -SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id < t1.t1_id) as t2_sum from t1 +# Currently not supported and should not be decorrelated +query TT +explain SELECT t1_id, (SELECT sum(t2_int) FROM t2 WHERE t2.t2_id < t1.t1_id) as t2_sum from t1 +---- +logical_plan +01)Projection: t1.t1_id, () AS t2_sum +02)--Subquery: +03)----Projection: sum(t2.t2_int) +04)------Aggregate: groupBy=[[]], aggr=[[sum(CAST(t2.t2_int AS Int64))]] +05)--------Filter: t2.t2_id < outer_ref(t1.t1_id) +06)----------TableScan: t2 +07)--TableScan: t1 projection=[t1_id] #aggregated_correlated_scalar_subquery_with_extra_group_by_columns statement error DataFusion error: check_analyzed_plan\ncaused by\nError during planning: A GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns @@ -505,8 +534,8 @@ logical_plan 01)Projection: t1.t1_id, t1.t1_name 02)--Filter: EXISTS () 03)----Subquery: -04)------Projection: SUM(outer_ref(t1.t1_int) + t2.t2_id) -05)--------Aggregate: groupBy=[[]], aggr=[[SUM(CAST(outer_ref(t1.t1_int) + t2.t2_id AS Int64))]] +04)------Projection: sum(outer_ref(t1.t1_int) + t2.t2_id) +05)--------Aggregate: groupBy=[[]], aggr=[[sum(CAST(outer_ref(t1.t1_int) + t2.t2_id AS Int64))]] 06)----------Filter: outer_ref(t1.t1_name) = t2.t2_name 07)------------TableScan: t2 08)----TableScan: t1 projection=[t1_id, t1_name, t1_int] @@ -519,9 +548,9 @@ logical_plan 01)Projection: t1.t1_id, t1.t1_name 02)--Filter: EXISTS () 03)----Subquery: -04)------Projection: COUNT(*) -05)--------Filter: SUM(outer_ref(t1.t1_int) + t2.t2_id) > Int64(0) -06)----------Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*), SUM(CAST(outer_ref(t1.t1_int) + t2.t2_id AS Int64))]] +04)------Projection: count(*) +05)--------Filter: sum(outer_ref(t1.t1_int) + t2.t2_id) > Int64(0) +06)----------Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*), sum(CAST(outer_ref(t1.t1_int) + t2.t2_id AS Int64))]] 07)------------Filter: outer_ref(t1.t1_name) = t2.t2_name 08)--------------TableScan: t2 09)----TableScan: t1 projection=[t1_id, t1_name, t1_int] @@ -531,13 +560,13 @@ query TT explain SELECT t0_id, t0_name FROM t0 WHERE EXISTS (SELECT 1 FROM t1 INNER JOIN t2 ON(t1.t1_id = t2.t2_id and t1.t1_name = t0.t0_name)) ---- logical_plan -01)Filter: EXISTS () -02)--Subquery: -03)----Projection: Int64(1) -04)------Inner Join: Filter: t1.t1_id = t2.t2_id AND t1.t1_name = outer_ref(t0.t0_name) -05)--------TableScan: t1 -06)--------TableScan: t2 -07)--TableScan: t0 projection=[t0_id, t0_name] +01)LeftSemi Join: t0.t0_name = __correlated_sq_2.t1_name +02)--TableScan: t0 projection=[t0_id, t0_name] +03)--SubqueryAlias: __correlated_sq_2 +04)----Projection: t1.t1_name +05)------Inner Join: t1.t1_id = t2.t2_id +06)--------TableScan: t1 projection=[t1_id, t1_name] +07)--------TableScan: t2 projection=[t2_id] #subquery_contains_join_contains_correlated_columns query TT @@ -615,10 +644,7 @@ SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT * FROM t2 WHERE t2_id = t1_id query TT explain SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT * FROM t2 WHERE t2_id = t1_id limit 0) ---- -logical_plan -01)LeftSemi Join: t1.t1_id = __correlated_sq_1.t2_id -02)--TableScan: t1 projection=[t1_id, t1_name] -03)--EmptyRelation +logical_plan EmptyRelation query IT rowsort SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT * FROM t2 WHERE t2_id = t1_id limit 0) @@ -630,10 +656,7 @@ SELECT t1_id, t1_name FROM t1 WHERE EXISTS (SELECT * FROM t2 WHERE t2_id = t1_id query TT explain SELECT t1_id, t1_name FROM t1 WHERE NOT EXISTS (SELECT * FROM t2 WHERE t2_id = t1_id limit 0) ---- -logical_plan -01)LeftAnti Join: t1.t1_id = __correlated_sq_1.t2_id -02)--TableScan: t1 projection=[t1_id, t1_name] -03)--EmptyRelation +logical_plan TableScan: t1 projection=[t1_id, t1_name] query IT rowsort SELECT t1_id, t1_name FROM t1 WHERE NOT EXISTS (SELECT * FROM t2 WHERE t2_id = t1_id limit 0) @@ -651,8 +674,8 @@ explain SELECT t1_id, t1_name FROM t1 WHERE t1_id in (SELECT t2_id FROM t2 where logical_plan 01)Filter: t1.t1_id IN () 02)--Subquery: -03)----Limit: skip=0, fetch=10 -04)------Projection: t2.t2_id +03)----Projection: t2.t2_id +04)------Limit: skip=0, fetch=10 05)--------Filter: outer_ref(t1.t1_name) = t2.t2_name 06)----------TableScan: t2 07)--TableScan: t1 projection=[t1_id, t1_name] @@ -709,9 +732,9 @@ query TT explain select (select count(*) from t1) as b ---- logical_plan -01)Projection: __scalar_sq_1.COUNT(*) AS b +01)Projection: __scalar_sq_1.count(*) AS b 02)--SubqueryAlias: __scalar_sq_1 -03)----Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +03)----Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] 04)------TableScan: t1 projection=[] #simple_uncorrelated_scalar_subquery2 @@ -719,13 +742,13 @@ query TT explain select (select count(*) from t1) as b, (select count(1) from t2) ---- logical_plan -01)Projection: __scalar_sq_1.COUNT(*) AS b, __scalar_sq_2.COUNT(Int64(1)) AS COUNT(Int64(1)) +01)Projection: __scalar_sq_1.count(*) AS b, __scalar_sq_2.count(Int64(1)) AS count(Int64(1)) 02)--Left Join: 03)----SubqueryAlias: __scalar_sq_1 -04)------Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +04)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] 05)--------TableScan: t1 projection=[] 06)----SubqueryAlias: __scalar_sq_2 -07)------Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]] +07)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] 08)--------TableScan: t2 projection=[] statement ok @@ -735,20 +758,20 @@ query TT explain select (select count(*) from t1) as b, (select count(1) from t2) ---- logical_plan -01)Projection: __scalar_sq_1.COUNT(*) AS b, __scalar_sq_2.COUNT(Int64(1)) AS COUNT(Int64(1)) +01)Projection: __scalar_sq_1.count(*) AS b, __scalar_sq_2.count(Int64(1)) AS count(Int64(1)) 02)--Left Join: 03)----SubqueryAlias: __scalar_sq_1 -04)------Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +04)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] 05)--------TableScan: t1 projection=[] 06)----SubqueryAlias: __scalar_sq_2 -07)------Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]] +07)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] 08)--------TableScan: t2 projection=[] physical_plan -01)ProjectionExec: expr=[COUNT(*)@0 as b, COUNT(Int64(1))@1 as COUNT(Int64(1))] +01)ProjectionExec: expr=[count(*)@0 as b, count(Int64(1))@1 as count(Int64(1))] 02)--NestedLoopJoinExec: join_type=Left -03)----ProjectionExec: expr=[4 as COUNT(*)] +03)----ProjectionExec: expr=[4 as count(*)] 04)------PlaceholderRowExec -05)----ProjectionExec: expr=[4 as COUNT(Int64(1))] +05)----ProjectionExec: expr=[4 as count(Int64(1))] 06)------PlaceholderRowExec statement ok @@ -764,12 +787,12 @@ query TT explain SELECT t1_id, (SELECT count(*) FROM t2 WHERE t2.t2_int = t1.t1_int) from t1 ---- logical_plan -01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.COUNT(*) END AS COUNT(*) +01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END AS count(*) 02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int 03)----TableScan: t1 projection=[t1_id, t1_int] 04)----SubqueryAlias: __scalar_sq_1 -05)------Projection: COUNT(*), t2.t2_int, __always_true -06)--------Aggregate: groupBy=[[t2.t2_int, Boolean(true) AS __always_true]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +05)------Projection: count(*), t2.t2_int, Boolean(true) AS __always_true +06)--------Aggregate: groupBy=[[t2.t2_int]], aggr=[[count(Int64(1)) AS count(*)]] 07)----------TableScan: t2 projection=[t2_int] query II rowsort @@ -786,12 +809,12 @@ query TT explain SELECT t1_id, (SELECT count(*) FROM t2 WHERE t2.t2_int = t1.t1_int) as cnt from t1 ---- logical_plan -01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.COUNT(*) END AS cnt +01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END AS cnt 02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int 03)----TableScan: t1 projection=[t1_id, t1_int] 04)----SubqueryAlias: __scalar_sq_1 -05)------Projection: COUNT(*), t2.t2_int, __always_true -06)--------Aggregate: groupBy=[[t2.t2_int, Boolean(true) AS __always_true]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +05)------Projection: count(*), t2.t2_int, Boolean(true) AS __always_true +06)--------Aggregate: groupBy=[[t2.t2_int]], aggr=[[count(Int64(1)) AS count(*)]] 07)----------TableScan: t2 projection=[t2_int] query II rowsort @@ -811,8 +834,8 @@ logical_plan 02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int 03)----TableScan: t1 projection=[t1_id, t1_int] 04)----SubqueryAlias: __scalar_sq_1 -05)------Projection: COUNT(*) AS _cnt, t2.t2_int, __always_true -06)--------Aggregate: groupBy=[[t2.t2_int, Boolean(true) AS __always_true]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +05)------Projection: count(*) AS _cnt, t2.t2_int, Boolean(true) AS __always_true +06)--------Aggregate: groupBy=[[t2.t2_int]], aggr=[[count(Int64(1)) AS count(*)]] 07)----------TableScan: t2 projection=[t2_int] query II rowsort @@ -832,8 +855,8 @@ logical_plan 02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int 03)----TableScan: t1 projection=[t1_id, t1_int] 04)----SubqueryAlias: __scalar_sq_1 -05)------Projection: COUNT(*) + Int64(2) AS _cnt, t2.t2_int, __always_true -06)--------Aggregate: groupBy=[[t2.t2_int, Boolean(true) AS __always_true]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +05)------Projection: count(*) + Int64(2) AS _cnt, t2.t2_int, Boolean(true) AS __always_true +06)--------Aggregate: groupBy=[[t2.t2_int]], aggr=[[count(Int64(1)) AS count(*)]] 07)----------TableScan: t2 projection=[t2_int] query II rowsort @@ -850,13 +873,13 @@ explain select t1.t1_int from t1 where (select count(*) from t2 where t1.t1_id = ---- logical_plan 01)Projection: t1.t1_int -02)--Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.COUNT(*) END < CAST(t1.t1_int AS Int64) -03)----Projection: t1.t1_int, __scalar_sq_1.COUNT(*), __scalar_sq_1.__always_true +02)--Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END < CAST(t1.t1_int AS Int64) +03)----Projection: t1.t1_int, __scalar_sq_1.count(*), __scalar_sq_1.__always_true 04)------Left Join: t1.t1_id = __scalar_sq_1.t2_id 05)--------TableScan: t1 projection=[t1_id, t1_int] 06)--------SubqueryAlias: __scalar_sq_1 -07)----------Projection: COUNT(*), t2.t2_id, __always_true -08)------------Aggregate: groupBy=[[t2.t2_id, Boolean(true) AS __always_true]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +07)----------Projection: count(*), t2.t2_id, Boolean(true) AS __always_true +08)------------Aggregate: groupBy=[[t2.t2_id]], aggr=[[count(Int64(1)) AS count(*)]] 09)--------------TableScan: t2 projection=[t2_id] query I rowsort @@ -876,11 +899,10 @@ logical_plan 02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int 03)----TableScan: t1 projection=[t1_id, t1_int] 04)----SubqueryAlias: __scalar_sq_1 -05)------Projection: COUNT(*) + Int64(2) AS cnt_plus_2, t2.t2_int -06)--------Filter: COUNT(*) > Int64(1) -07)----------Projection: t2.t2_int, COUNT(*) -08)------------Aggregate: groupBy=[[t2.t2_int, Boolean(true) AS __always_true]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] -09)--------------TableScan: t2 projection=[t2_int] +05)------Projection: count(*) + Int64(2) AS cnt_plus_2, t2.t2_int +06)--------Filter: count(*) > Int64(1) +07)----------Aggregate: groupBy=[[t2.t2_int]], aggr=[[count(Int64(1)) AS count(*)]] +08)------------TableScan: t2 projection=[t2_int] query II rowsort SELECT t1_id, (SELECT count(*) + 2 as cnt_plus_2 FROM t2 WHERE t2.t2_int = t1.t1_int having count(*) >1) from t1 @@ -896,12 +918,12 @@ query TT explain SELECT t1_id, (SELECT count(*) + 2 as cnt_plus_2 FROM t2 WHERE t2.t2_int = t1.t1_int having count(*) = 0) from t1 ---- logical_plan -01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) AS cnt_plus_2 WHEN __scalar_sq_1.COUNT(*) != Int64(0) THEN NULL ELSE __scalar_sq_1.cnt_plus_2 END AS cnt_plus_2 +01)Projection: t1.t1_id, CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) AS cnt_plus_2 WHEN __scalar_sq_1.count(*) != Int64(0) THEN NULL ELSE __scalar_sq_1.cnt_plus_2 END AS cnt_plus_2 02)--Left Join: t1.t1_int = __scalar_sq_1.t2_int 03)----TableScan: t1 projection=[t1_id, t1_int] 04)----SubqueryAlias: __scalar_sq_1 -05)------Projection: COUNT(*) + Int64(2) AS cnt_plus_2, t2.t2_int, COUNT(*), __always_true -06)--------Aggregate: groupBy=[[t2.t2_int, Boolean(true) AS __always_true]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +05)------Projection: count(*) + Int64(2) AS cnt_plus_2, t2.t2_int, count(*), Boolean(true) AS __always_true +06)--------Aggregate: groupBy=[[t2.t2_int]], aggr=[[count(Int64(1)) AS count(*)]] 07)----------TableScan: t2 projection=[t2_int] query II rowsort @@ -918,14 +940,14 @@ explain select t1.t1_int from t1 group by t1.t1_int having (select count(*) from ---- logical_plan 01)Projection: t1.t1_int -02)--Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.COUNT(*) END = Int64(0) -03)----Projection: t1.t1_int, __scalar_sq_1.COUNT(*), __scalar_sq_1.__always_true +02)--Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END = Int64(0) +03)----Projection: t1.t1_int, __scalar_sq_1.count(*), __scalar_sq_1.__always_true 04)------Left Join: t1.t1_int = __scalar_sq_1.t2_int 05)--------Aggregate: groupBy=[[t1.t1_int]], aggr=[[]] 06)----------TableScan: t1 projection=[t1_int] 07)--------SubqueryAlias: __scalar_sq_1 -08)----------Projection: COUNT(*), t2.t2_int, __always_true -09)------------Aggregate: groupBy=[[t2.t2_int, Boolean(true) AS __always_true]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +08)----------Projection: count(*), t2.t2_int, Boolean(true) AS __always_true +09)------------Aggregate: groupBy=[[t2.t2_int]], aggr=[[count(Int64(1)) AS count(*)]] 10)--------------TableScan: t2 projection=[t2_int] query I rowsort @@ -945,8 +967,8 @@ logical_plan 04)------Left Join: t1.t1_int = __scalar_sq_1.t2_int 05)--------TableScan: t1 projection=[t1_int] 06)--------SubqueryAlias: __scalar_sq_1 -07)----------Projection: COUNT(*) AS cnt, t2.t2_int, __always_true -08)------------Aggregate: groupBy=[[t2.t2_int, Boolean(true) AS __always_true]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +07)----------Projection: count(*) AS cnt, t2.t2_int, Boolean(true) AS __always_true +08)------------Aggregate: groupBy=[[t2.t2_int]], aggr=[[count(Int64(1)) AS count(*)]] 09)--------------TableScan: t2 projection=[t2_int] @@ -970,13 +992,13 @@ select t1.t1_int from t1 where ( ---- logical_plan 01)Projection: t1.t1_int -02)--Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) WHEN __scalar_sq_1.COUNT(*) != Int64(0) THEN NULL ELSE __scalar_sq_1.cnt_plus_two END = Int64(2) -03)----Projection: t1.t1_int, __scalar_sq_1.cnt_plus_two, __scalar_sq_1.COUNT(*), __scalar_sq_1.__always_true +02)--Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(2) WHEN __scalar_sq_1.count(*) != Int64(0) THEN NULL ELSE __scalar_sq_1.cnt_plus_two END = Int64(2) +03)----Projection: t1.t1_int, __scalar_sq_1.cnt_plus_two, __scalar_sq_1.count(*), __scalar_sq_1.__always_true 04)------Left Join: t1.t1_int = __scalar_sq_1.t2_int 05)--------TableScan: t1 projection=[t1_int] 06)--------SubqueryAlias: __scalar_sq_1 -07)----------Projection: COUNT(*) + Int64(1) + Int64(1) AS cnt_plus_two, t2.t2_int, COUNT(*), __always_true -08)------------Aggregate: groupBy=[[t2.t2_int, Boolean(true) AS __always_true]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +07)----------Projection: count(*) + Int64(1) + Int64(1) AS cnt_plus_two, t2.t2_int, count(*), Boolean(true) AS __always_true +08)------------Aggregate: groupBy=[[t2.t2_int]], aggr=[[count(Int64(1)) AS count(*)]] 09)--------------TableScan: t2 projection=[t2_int] query I rowsort @@ -1004,8 +1026,8 @@ logical_plan 04)------Left Join: t1.t1_int = __scalar_sq_1.t2_int 05)--------TableScan: t1 projection=[t1_int] 06)--------SubqueryAlias: __scalar_sq_1 -07)----------Projection: CASE WHEN COUNT(*) = Int64(1) THEN Int64(NULL) ELSE COUNT(*) END AS cnt, t2.t2_int, __always_true -08)------------Aggregate: groupBy=[[t2.t2_int, Boolean(true) AS __always_true]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +07)----------Projection: CASE WHEN count(*) = Int64(1) THEN Int64(NULL) ELSE count(*) END AS cnt, t2.t2_int, Boolean(true) AS __always_true +08)------------Aggregate: groupBy=[[t2.t2_int]], aggr=[[count(Int64(1)) AS count(*)]] 09)--------------TableScan: t2 projection=[t2_int] @@ -1024,6 +1046,190 @@ false true true +# in_subquery_to_join_with_correlated_outer_filter_disjunction +query TT +explain select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or t1.t1_id in (select t2.t2_id from t2 where t1.t1_int > 0) +---- +logical_plan +01)Projection: t1.t1_id, t1.t1_name, t1.t1_int +02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.mark +03)----LeftMark Join: t1.t1_id = __correlated_sq_1.t2_id Filter: t1.t1_int > Int32(0) +04)------TableScan: t1 projection=[t1_id, t1_name, t1_int] +05)------SubqueryAlias: __correlated_sq_1 +06)--------TableScan: t2 projection=[t2_id] + +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or t1.t1_id in (select t2.t2_id from t2 where t1.t1_int > 0) +---- +11 a 1 +22 b 2 +44 d 4 + +# not_in_subquery_to_join_with_correlated_outer_filter_disjunction +query TT +explain select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id = 11 or t1.t1_id + 12 not in (select t2.t2_id + 1 from t2 where t1.t1_int > 0) +---- +logical_plan +01)Projection: t1.t1_id, t1.t1_name, t1.t1_int +02)--Filter: t1.t1_id = Int32(11) OR NOT __correlated_sq_1.mark +03)----LeftMark Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.t2.t2_id + Int64(1) Filter: t1.t1_int > Int32(0) +04)------TableScan: t1 projection=[t1_id, t1_name, t1_int] +05)------SubqueryAlias: __correlated_sq_1 +06)--------Projection: CAST(t2.t2_id AS Int64) + Int64(1) +07)----------TableScan: t2 projection=[t2_id] + +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id = 11 or t1.t1_id + 12 not in (select t2.t2_id + 1 from t2 where t1.t1_int > 0) +---- +11 a 1 +22 b 2 + +# exists_subquery_to_join_with_correlated_outer_filter_disjunction +query TT +explain select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or exists (select * from t2 where t1.t1_id = t2.t2_id) +---- +logical_plan +01)Projection: t1.t1_id, t1.t1_name, t1.t1_int +02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_1.mark +03)----LeftMark Join: t1.t1_id = __correlated_sq_1.t2_id +04)------TableScan: t1 projection=[t1_id, t1_name, t1_int] +05)------SubqueryAlias: __correlated_sq_1 +06)--------TableScan: t2 projection=[t2_id] + +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or exists (select * from t2 where t1.t1_id = t2.t2_id) +---- +11 a 1 +22 b 2 +44 d 4 + +statement ok +set datafusion.explain.logical_plan_only = false; + +# not_exists_subquery_to_join_with_correlated_outer_filter_disjunction +query TT +explain select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or not exists (select * from t2 where t1.t1_id = t2.t2_id) +---- +logical_plan +01)Projection: t1.t1_id, t1.t1_name, t1.t1_int +02)--Filter: t1.t1_id > Int32(40) OR NOT __correlated_sq_1.mark +03)----LeftMark Join: t1.t1_id = __correlated_sq_1.t2_id +04)------TableScan: t1 projection=[t1_id, t1_name, t1_int] +05)------SubqueryAlias: __correlated_sq_1 +06)--------TableScan: t2 projection=[t2_id] +physical_plan +01)CoalesceBatchesExec: target_batch_size=2 +02)--FilterExec: t1_id@0 > 40 OR NOT mark@3, projection=[t1_id@0, t1_name@1, t1_int@2] +03)----CoalesceBatchesExec: target_batch_size=2 +04)------HashJoinExec: mode=Partitioned, join_type=LeftMark, on=[(t1_id@0, t2_id@0)] +05)--------CoalesceBatchesExec: target_batch_size=2 +06)----------RepartitionExec: partitioning=Hash([t1_id@0], 4), input_partitions=4 +07)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +08)--------------MemoryExec: partitions=1, partition_sizes=[1] +09)--------CoalesceBatchesExec: target_batch_size=2 +10)----------RepartitionExec: partitioning=Hash([t2_id@0], 4), input_partitions=4 +11)------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +12)--------------MemoryExec: partitions=1, partition_sizes=[1] + +statement ok +set datafusion.explain.logical_plan_only = true; + +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or not exists (select * from t2 where t1.t1_id = t2.t2_id) +---- +33 c 3 +44 d 4 + +# in_subquery_to_join_with_correlated_outer_filter_and_or +query TT +explain select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id in (select t3.t3_id from t3) and (t1.t1_id > 40 or t1.t1_id in (select t2.t2_id from t2 where t1.t1_int > 0)) +---- +logical_plan +01)Projection: t1.t1_id, t1.t1_name, t1.t1_int +02)--Filter: t1.t1_id > Int32(40) OR __correlated_sq_2.mark +03)----LeftMark Join: t1.t1_id = __correlated_sq_2.t2_id Filter: t1.t1_int > Int32(0) +04)------LeftSemi Join: t1.t1_id = __correlated_sq_1.t3_id +05)--------TableScan: t1 projection=[t1_id, t1_name, t1_int] +06)--------SubqueryAlias: __correlated_sq_1 +07)----------TableScan: t3 projection=[t3_id] +08)------SubqueryAlias: __correlated_sq_2 +09)--------TableScan: t2 projection=[t2_id] + +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id in (select t3.t3_id from t3) and (t1.t1_id > 40 or t1.t1_id in (select t2.t2_id from t2 where t1.t1_int > 0)) +---- +11 a 1 +22 b 2 +44 d 4 + +# Handle duplicate values in exists query +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where t1.t1_id > 40 or exists (select * from t2 cross join t3 where t1.t1_id = t2.t2_id) +---- +11 a 1 +22 b 2 +44 d 4 + +# Nested subqueries +query ITI rowsort +select t1.t1_id, + t1.t1_name, + t1.t1_int +from t1 +where exists ( + select * from t2 where t1.t1_id = t2.t2_id OR exists ( + select * from t3 where t2.t2_id = t3.t3_id + ) +) +---- +11 a 1 +22 b 2 +33 c 3 +44 d 4 # issue: https://github.com/apache/datafusion/issues/7027 query TTTT rowsort @@ -1072,8 +1278,8 @@ query TT explain select a/2, a/2 + 1 from t ---- logical_plan -01)Projection: t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2), t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2) + Int64(1) -02)--Projection: t.a / Int64(2) AS t.a / Int64(2)Int64(2)t.a +01)Projection: __common_expr_1 AS t.a / Int64(2), __common_expr_1 AS t.a / Int64(2) + Int64(1) +02)--Projection: t.a / Int64(2) AS __common_expr_1 03)----TableScan: t projection=[a] statement ok @@ -1083,8 +1289,8 @@ query TT explain select a/2, a/2 + 1 from t ---- logical_plan -01)Projection: t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2), t.a / Int64(2)Int64(2)t.a AS t.a / Int64(2) + Int64(1) -02)--Projection: t.a / Int64(2) AS t.a / Int64(2)Int64(2)t.a +01)Projection: __common_expr_1 AS t.a / Int64(2), __common_expr_1 AS t.a / Int64(2) + Int64(1) +02)--Projection: t.a / Int64(2) AS __common_expr_1 03)----TableScan: t projection=[a] ### diff --git a/datafusion/sqllogictest/test_files/subquery_sort.slt b/datafusion/sqllogictest/test_files/subquery_sort.slt new file mode 100644 index 000000000000..a3717dd838d6 --- /dev/null +++ b/datafusion/sqllogictest/test_files/subquery_sort.slt @@ -0,0 +1,149 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +statement ok +CREATE EXTERNAL TABLE sink_table ( + c1 VARCHAR NOT NULL, + c2 TINYINT NOT NULL, + c3 SMALLINT NOT NULL, + c4 SMALLINT NOT NULL, + c5 INTEGER NOT NULL, + c6 BIGINT NOT NULL, + c7 SMALLINT NOT NULL, + c8 INT NOT NULL, + c9 INT UNSIGNED NOT NULL, + c10 BIGINT UNSIGNED NOT NULL, + c11 FLOAT NOT NULL, + c12 DOUBLE NOT NULL, + c13 VARCHAR NOT NULL + ) +STORED AS CSV +LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); + +# Remove the redundant ordering in the subquery + +query TT +EXPLAIN SELECT c1 FROM (SELECT c1 FROM sink_table ORDER BY c2) AS ttt +---- +logical_plan +01)SubqueryAlias: ttt +02)--TableScan: sink_table projection=[c1] +physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true + +query TT +EXPLAIN SELECT c1 FROM (SELECT c1 FROM sink_table ORDER BY c2) +---- +logical_plan TableScan: sink_table projection=[c1] +physical_plan CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1], has_header=true + + +# Do not remove ordering when it's with limit + +query TT +EXPLAIN SELECT c1, c2 FROM (SELECT c1, c2, c3, c9 FROM sink_table ORDER BY c1 DESC, c3 LIMIT 2) AS t2 ORDER BY t2.c1, t2.c3, t2.c9; +---- +logical_plan +01)Projection: t2.c1, t2.c2 +02)--Sort: t2.c1 ASC NULLS LAST, t2.c3 ASC NULLS LAST, t2.c9 ASC NULLS LAST +03)----SubqueryAlias: t2 +04)------Sort: sink_table.c1 DESC NULLS FIRST, sink_table.c3 ASC NULLS LAST, fetch=2 +05)--------TableScan: sink_table projection=[c1, c2, c3, c9] +physical_plan +01)ProjectionExec: expr=[c1@0 as c1, c2@1 as c2] +02)--SortExec: expr=[c1@0 ASC NULLS LAST, c3@2 ASC NULLS LAST, c9@3 ASC NULLS LAST], preserve_partitioning=[false] +03)----SortExec: TopK(fetch=2), expr=[c1@0 DESC, c3@2 ASC NULLS LAST], preserve_partitioning=[false] +04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c9], has_header=true + + +query TI +SELECT c1, c2 FROM (SELECT c1, c2, c3, c9 FROM sink_table ORDER BY c1, c3 LIMIT 2) AS t2 ORDER BY t2.c1, t2.c3, t2.c9; +---- +a 4 +a 5 + +query TI +SELECT c1, c2 FROM (SELECT c1, c2, c3, c9 FROM sink_table ORDER BY c1 DESC, c3 LIMIT 2) AS t2 ORDER BY t2.c1, t2.c3, t2.c9; +---- +e 3 +e 5 + + +# Do not remove ordering when it's a part of an aggregation in subquery + +query TT +EXPLAIN SELECT t2.c1, t2.r FROM (SELECT c1, RANK() OVER (ORDER BY c1 DESC) AS r, c3, c9 FROM sink_table ORDER BY c1, c3 LIMIT 2) AS t2 ORDER BY t2.c1, t2.c3, t2.c9; +---- +logical_plan +01)Projection: t2.c1, t2.r +02)--Sort: t2.c1 ASC NULLS LAST, t2.c3 ASC NULLS LAST, t2.c9 ASC NULLS LAST +03)----SubqueryAlias: t2 +04)------Sort: sink_table.c1 ASC NULLS LAST, sink_table.c3 ASC NULLS LAST, fetch=2 +05)--------Projection: sink_table.c1, rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS r, sink_table.c3, sink_table.c9 +06)----------WindowAggr: windowExpr=[[rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +07)------------TableScan: sink_table projection=[c1, c3, c9] +physical_plan +01)ProjectionExec: expr=[c1@0 as c1, r@1 as r] +02)--SortExec: TopK(fetch=2), expr=[c1@0 ASC NULLS LAST, c3@2 ASC NULLS LAST, c9@3 ASC NULLS LAST], preserve_partitioning=[false] +03)----ProjectionExec: expr=[c1@0 as c1, rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as r, c3@1 as c3, c9@2 as c9] +04)------BoundedWindowAggExec: wdw=[rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "rank() ORDER BY [sink_table.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Utf8(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +05)--------SortExec: expr=[c1@0 DESC], preserve_partitioning=[false] +06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c3, c9], has_header=true + + +query TT +EXPLAIN SELECT c1, c2 FROM (SELECT DISTINCT ON (c1) c1, c2, c3, c9 FROM sink_table ORDER BY c1, c3 DESC, c9) AS t2 ORDER BY t2.c1, t2.c3 DESC, t2.c9 +---- +logical_plan +01)Projection: t2.c1, t2.c2 +02)--Sort: t2.c1 ASC NULLS LAST, t2.c3 DESC NULLS FIRST, t2.c9 ASC NULLS LAST +03)----SubqueryAlias: t2 +04)------Projection: first_value(sink_table.c1) ORDER BY [sink_table.c1 ASC NULLS LAST, sink_table.c3 DESC NULLS FIRST, sink_table.c9 ASC NULLS LAST] AS c1, first_value(sink_table.c2) ORDER BY [sink_table.c1 ASC NULLS LAST, sink_table.c3 DESC NULLS FIRST, sink_table.c9 ASC NULLS LAST] AS c2, first_value(sink_table.c3) ORDER BY [sink_table.c1 ASC NULLS LAST, sink_table.c3 DESC NULLS FIRST, sink_table.c9 ASC NULLS LAST] AS c3, first_value(sink_table.c9) ORDER BY [sink_table.c1 ASC NULLS LAST, sink_table.c3 DESC NULLS FIRST, sink_table.c9 ASC NULLS LAST] AS c9 +05)--------Sort: sink_table.c1 ASC NULLS LAST +06)----------Aggregate: groupBy=[[sink_table.c1]], aggr=[[first_value(sink_table.c1) ORDER BY [sink_table.c1 ASC NULLS LAST, sink_table.c3 DESC NULLS FIRST, sink_table.c9 ASC NULLS LAST], first_value(sink_table.c2) ORDER BY [sink_table.c1 ASC NULLS LAST, sink_table.c3 DESC NULLS FIRST, sink_table.c9 ASC NULLS LAST], first_value(sink_table.c3) ORDER BY [sink_table.c1 ASC NULLS LAST, sink_table.c3 DESC NULLS FIRST, sink_table.c9 ASC NULLS LAST], first_value(sink_table.c9) ORDER BY [sink_table.c1 ASC NULLS LAST, sink_table.c3 DESC NULLS FIRST, sink_table.c9 ASC NULLS LAST]]] +07)------------TableScan: sink_table projection=[c1, c2, c3, c9] +physical_plan +01)ProjectionExec: expr=[c1@0 as c1, c2@1 as c2] +02)--SortPreservingMergeExec: [c1@0 ASC NULLS LAST, c3@2 DESC, c9@3 ASC NULLS LAST] +03)----SortExec: expr=[c1@0 ASC NULLS LAST, c3@2 DESC, c9@3 ASC NULLS LAST], preserve_partitioning=[true] +04)------ProjectionExec: expr=[first_value(sink_table.c1) ORDER BY [sink_table.c1 ASC NULLS LAST, sink_table.c3 DESC NULLS FIRST, sink_table.c9 ASC NULLS LAST]@1 as c1, first_value(sink_table.c2) ORDER BY [sink_table.c1 ASC NULLS LAST, sink_table.c3 DESC NULLS FIRST, sink_table.c9 ASC NULLS LAST]@2 as c2, first_value(sink_table.c3) ORDER BY [sink_table.c1 ASC NULLS LAST, sink_table.c3 DESC NULLS FIRST, sink_table.c9 ASC NULLS LAST]@3 as c3, first_value(sink_table.c9) ORDER BY [sink_table.c1 ASC NULLS LAST, sink_table.c3 DESC NULLS FIRST, sink_table.c9 ASC NULLS LAST]@4 as c9] +05)--------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[first_value(sink_table.c1) ORDER BY [sink_table.c1 ASC NULLS LAST, sink_table.c3 DESC NULLS FIRST, sink_table.c9 ASC NULLS LAST], first_value(sink_table.c2) ORDER BY [sink_table.c1 ASC NULLS LAST, sink_table.c3 DESC NULLS FIRST, sink_table.c9 ASC NULLS LAST], first_value(sink_table.c3) ORDER BY [sink_table.c1 ASC NULLS LAST, sink_table.c3 DESC NULLS FIRST, sink_table.c9 ASC NULLS LAST], first_value(sink_table.c9) ORDER BY [sink_table.c1 ASC NULLS LAST, sink_table.c3 DESC NULLS FIRST, sink_table.c9 ASC NULLS LAST]] +06)----------CoalesceBatchesExec: target_batch_size=8192 +07)------------RepartitionExec: partitioning=Hash([c1@0], 4), input_partitions=4 +08)--------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[first_value(sink_table.c1) ORDER BY [sink_table.c1 ASC NULLS LAST, sink_table.c3 DESC NULLS FIRST, sink_table.c9 ASC NULLS LAST], first_value(sink_table.c2) ORDER BY [sink_table.c1 ASC NULLS LAST, sink_table.c3 DESC NULLS FIRST, sink_table.c9 ASC NULLS LAST], first_value(sink_table.c3) ORDER BY [sink_table.c1 ASC NULLS LAST, sink_table.c3 DESC NULLS FIRST, sink_table.c9 ASC NULLS LAST], first_value(sink_table.c9) ORDER BY [sink_table.c1 ASC NULLS LAST, sink_table.c3 DESC NULLS FIRST, sink_table.c9 ASC NULLS LAST]] +09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +10)------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c9], has_header=true + + +query TI +SELECT c1, c2 FROM (SELECT DISTINCT ON (c1) c1, c2, c3, c9 FROM sink_table ORDER BY c1, c3, c9) AS t2 ORDER BY t2.c1, t2.c3, t2.c9; +---- +a 4 +b 4 +c 2 +d 1 +e 3 + + +query TI +SELECT c1, c2 FROM (SELECT DISTINCT ON (c1) c1, c2, c3, c9 FROM sink_table ORDER BY c1, c3 DESC, c9) AS t2 ORDER BY t2.c1, t2.c3 DESC, t2.c9 +---- +a 1 +b 5 +c 4 +d 1 +e 1 diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index 32a28231d034..70f7dedeaca0 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -84,6 +84,11 @@ select case when current_time() = (now()::bigint % 86400000000000)::time then 'O ---- OK +query B +select now() = current_timestamp; +---- +true + ########## ## Timestamp Handling Tests ########## @@ -303,6 +308,29 @@ SELECT from_unixtime(ts / 1000000000) FROM ts_data LIMIT 3; 2020-09-08T12:42:29 2020-09-08T11:42:29 +# from_unixtime single + +query P +SELECT from_unixtime(1599572549190855123 / 1000000000, 'America/New_York'); +---- +2020-09-08T09:42:29-04:00 + +# from_unixtime with timezone +query P +SELECT from_unixtime(ts / 1000000000, 'Asia/Istanbul') FROM ts_data LIMIT 3; +---- +2020-09-08T16:42:29+03:00 +2020-09-08T15:42:29+03:00 +2020-09-08T14:42:29+03:00 + +# from_unixtime with utc timezone +query P +SELECT from_unixtime(ts / 1000000000, 'UTC') FROM ts_data LIMIT 3; +---- +2020-09-08T13:42:29Z +2020-09-08T12:42:29Z +2020-09-08T11:42:29Z + # to_timestamp query I @@ -397,6 +425,41 @@ SELECT COUNT(*) FROM ts_data_secs where ts > from_unixtime(1599566400) ---- 2 +query P rowsort +SELECT ts FROM ts_data_nanos; +---- +2020-09-08T11:42:29.190855123 +2020-09-08T12:42:29.190855123 +2020-09-08T13:42:29.190855123 + +query P rowsort +SELECT CAST(ts AS timestamp(0)) FROM ts_data_nanos; +---- +2020-09-08T11:42:29 +2020-09-08T12:42:29 +2020-09-08T13:42:29 + +query P rowsort +SELECT CAST(ts AS timestamp(3)) FROM ts_data_nanos; +---- +2020-09-08T11:42:29.190 +2020-09-08T12:42:29.190 +2020-09-08T13:42:29.190 + +query P rowsort +SELECT CAST(ts AS timestamp(6)) FROM ts_data_nanos; +---- +2020-09-08T11:42:29.190855 +2020-09-08T12:42:29.190855 +2020-09-08T13:42:29.190855 + +query P rowsort +SELECT CAST(ts AS timestamp(9)) FROM ts_data_nanos; +---- +2020-09-08T11:42:29.190855123 +2020-09-08T12:42:29.190855123 +2020-09-08T13:42:29.190855123 + # count_distinct_timestamps query P rowsort @@ -466,7 +529,7 @@ query error Cannot cast string '24:01:02' to value of Time64\(Nanosecond\) type SELECT TIME '24:01:02' as time; # invalid timezone -query error Arrow error: Parser error: Invalid timezone "ZZ": 'ZZ' is not a valid timezone +query error Arrow error: Parser error: Invalid timezone "ZZ": failed to parse timezone SELECT TIMESTAMP '2023-12-05T21:58:10.45ZZ'; statement ok @@ -538,7 +601,7 @@ select to_timestamp_seconds(cast (1 as int)); ########## # invalid second arg type -query error DataFusion error: Error during planning: No function matches the given name and argument types 'date_bin\(Interval\(MonthDayNano\), Int64, Timestamp\(Nanosecond, None\)\)'\. +query error SELECT DATE_BIN(INTERVAL '0 second', 25, TIMESTAMP '1970-01-01T00:00:00Z') # not support interval 0 @@ -975,6 +1038,23 @@ SELECT DATE_BIN('3 years 1 months', '2022-09-01 00:00:00Z'); ---- 2022-06-01T00:00:00 +# Times before the unix epoch +query P +select date_bin('1 hour', column1) +from (values + (timestamp '1969-01-01 00:00:00'), + (timestamp '1969-01-01 00:15:00'), + (timestamp '1969-01-01 00:30:00'), + (timestamp '1969-01-01 00:45:00'), + (timestamp '1969-01-01 01:00:00') +) as sq +---- +1969-01-01T00:00:00 +1969-01-01T00:00:00 +1969-01-01T00:00:00 +1969-01-01T00:00:00 +1969-01-01T01:00:00 + ### ## test date_trunc function ### @@ -1161,7 +1241,7 @@ ts_data_secs 2020-09-08T00:00:00 ts_data_secs 2020-09-08T00:00:00 ts_data_secs 2020-09-08T00:00:00 -# Test date trun on different granularity +# Test date turn on different granularity query TP rowsort SELECT 'millisecond', DATE_TRUNC('millisecond', ts) FROM ts_data_nanos UNION ALL @@ -1509,19 +1589,19 @@ SELECT val, ts1 - ts2 FROM foo ORDER BY ts2 - ts1; query ? SELECT i1 - i2 FROM bar; ---- -0 years 0 mons -1 days 0 hours 0 mins 0.000000000 secs -0 years 2 mons -13 days 0 hours 0 mins 0.000000000 secs -0 years 0 mons 1 days 2 hours 56 mins 0.000000000 secs -0 years 0 mons 1 days 0 hours 0 mins -3.999999993 secs +-1 days +2 mons -13 days +1 days 2 hours 56 mins +1 days -3.999999993 secs # Interval + Interval query ? SELECT i1 + i2 FROM bar; ---- -0 years 0 mons 3 days 0 hours 0 mins 0.000000000 secs -0 years 2 mons 13 days 0 hours 0 mins 0.000000000 secs -0 years 0 mons 1 days 3 hours 4 mins 0.000000000 secs -0 years 0 mons 1 days 0 hours 0 mins 4.000000007 secs +3 days +2 mons 13 days +1 days 3 hours 4 mins +1 days 4.000000007 secs # Timestamp - Interval query P @@ -2191,6 +2271,14 @@ create table ts_utf8_data(ts varchar(100), format varchar(100)) as values ('1926632005', '%s'), ('2000-01-01T01:01:01+07:00', '%+'); +statement ok +create table ts_largeutf8_data as +select arrow_cast(ts, 'LargeUtf8') as ts, arrow_cast(format, 'LargeUtf8') as format from ts_utf8_data; + +statement ok +create table ts_utf8view_data as +select arrow_cast(ts, 'Utf8View') as ts, arrow_cast(format, 'Utf8View') as format from ts_utf8_data; + # verify timestamp data using tables with formatting options query P SELECT to_timestamp(t.ts, t.format) from ts_utf8_data as t @@ -2201,9 +2289,84 @@ SELECT to_timestamp(t.ts, t.format) from ts_utf8_data as t 2031-01-19T23:33:25 1999-12-31T18:01:01 +query PPPPP +SELECT to_timestamp(t.ts, t.format), + to_timestamp_seconds(t.ts, t.format), + to_timestamp_millis(t.ts, t.format), + to_timestamp_micros(t.ts, t.format), + to_timestamp_nanos(t.ts, t.format) + from ts_largeutf8_data as t +---- +2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 +2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 +2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 +2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 +1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 + +query PPPPP +SELECT to_timestamp(t.ts, t.format), + to_timestamp_seconds(t.ts, t.format), + to_timestamp_millis(t.ts, t.format), + to_timestamp_micros(t.ts, t.format), + to_timestamp_nanos(t.ts, t.format) + from ts_utf8view_data as t +---- +2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 +2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 +2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 +2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 +1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 + # verify timestamp data using tables with formatting options +query PPPPP +SELECT to_timestamp(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z'), + to_timestamp_seconds(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z'), + to_timestamp_millis(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z'), + to_timestamp_micros(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z'), + to_timestamp_nanos(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z') + from ts_utf8_data as t +---- +2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 +2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 +2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 +2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 +1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 + +query PPPPP +SELECT to_timestamp(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z'), + to_timestamp_seconds(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z'), + to_timestamp_millis(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z'), + to_timestamp_micros(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z'), + to_timestamp_nanos(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z') + from ts_largeutf8_data as t +---- +2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 +2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 +2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 +2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 +1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 + +query PPPPP +SELECT to_timestamp(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z'), + to_timestamp_seconds(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z'), + to_timestamp_millis(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z'), + to_timestamp_micros(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z'), + to_timestamp_nanos(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z') + from ts_utf8view_data as t +---- +2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 +2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 2031-01-19T18:33:25 +2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 2020-09-08T12:00:00 +2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 2031-01-19T23:33:25 +1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 1999-12-31T18:01:01 + +# verify timestamp data using tables with formatting options where at least one column cannot be parsed +query error Error parsing timestamp from '1926632005' using format '%d-%m-%Y %H:%M:%S%#z': input contains invalid characters +SELECT to_timestamp(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%d-%m-%Y %H:%M:%S%#z') from ts_utf8_data as t + +# verify timestamp data using tables with formatting options where one of the formats is invalid query P -SELECT to_timestamp(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S%#z') from ts_utf8_data as t +SELECT to_timestamp(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%s', '%q', '%d-%m-%Y %H:%M:%S%#z', '%+') from ts_utf8_data as t ---- 2020-09-08T12:00:00 2031-01-19T18:33:25 @@ -2211,13 +2374,17 @@ SELECT to_timestamp(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%s', '%d-%m-%Y %H:%M:%S 2031-01-19T23:33:25 1999-12-31T18:01:01 -# verify timestamp data using tables with formatting options where at least one column cannot be parsed -query error Error parsing timestamp from '1926632005' using format '%d-%m-%Y %H:%M:%S%#z': input contains invalid characters -SELECT to_timestamp(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%+', '%d-%m-%Y %H:%M:%S%#z') from ts_utf8_data as t +query P +SELECT to_timestamp(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%s', '%q', '%d-%m-%Y %H:%M:%S%#z', '%+') from ts_largeutf8_data as t +---- +2020-09-08T12:00:00 +2031-01-19T18:33:25 +2020-09-08T12:00:00 +2031-01-19T23:33:25 +1999-12-31T18:01:01 -# verify timestamp data using tables with formatting options where one of the formats is invalid query P -SELECT to_timestamp(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%s', '%q', '%d-%m-%Y %H:%M:%S%#z', '%+') from ts_utf8_data as t +SELECT to_timestamp(t.ts, '%Y-%m-%d %H/%M/%S%#z', '%s', '%q', '%d-%m-%Y %H:%M:%S%#z', '%+') from ts_utf8view_data as t ---- 2020-09-08T12:00:00 2031-01-19T18:33:25 @@ -2688,6 +2855,11 @@ FROM NULL 01:01:2025 23-59-58 +query T +select to_char('2020-01-01 00:10:20.123'::timestamp at time zone 'America/New_York', '%Y-%m-%d %H:%M:%S.%3f'); +---- +2020-01-01 00:10:20.123 + statement ok drop table formats; @@ -2774,6 +2946,26 @@ SELECT '2000-12-01 04:04:12' AT TIME ZONE 'America/New_York'; ---- 2000-12-01T04:04:12-05:00 +query P +SELECT '2024-03-30 00:00:20' AT TIME ZONE 'Europe/Brussels'; +---- +2024-03-30T00:00:20+01:00 + +query P +SELECT '2024-03-30 00:00:20'::timestamp AT TIME ZONE 'Europe/Brussels'; +---- +2024-03-30T00:00:20+01:00 + +query P +SELECT '2024-03-30 00:00:20Z' AT TIME ZONE 'Europe/Brussels'; +---- +2024-03-30T01:00:20+01:00 + +query P +SELECT '2024-03-30 00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels'; +---- +2024-03-30T00:00:20+01:00 + ## date-time strings that already have a explicit timezone can be used with AT TIME ZONE # same time zone as provided date-time @@ -2795,3 +2987,367 @@ SELECT '2000-12-01 04:04:12' AT TIME ZONE 'America/New York'; # abbreviated timezone is not supported statement error SELECT '2023-03-12 02:00:00' AT TIME ZONE 'EDT'; + +# Test current_time without parentheses +query B +select current_time = current_time; +---- +true + +# Test temporal coercion for UTC +query ? +select arrow_cast('2024-06-17T11:00:00', 'Timestamp(Nanosecond, Some("UTC"))') - arrow_cast('2024-06-17T12:00:00', 'Timestamp(Microsecond, Some("UTC"))'); +---- +0 days -1 hours 0 mins 0.000000 secs + +query ? +select arrow_cast('2024-06-17T13:00:00', 'Timestamp(Nanosecond, Some("+00:00"))') - arrow_cast('2024-06-17T12:00:00', 'Timestamp(Microsecond, Some("UTC"))'); +---- +0 days 1 hours 0 mins 0.000000 secs + +query ? +select arrow_cast('2024-06-17T13:00:00', 'Timestamp(Nanosecond, Some("UTC"))') - arrow_cast('2024-06-17T12:00:00', 'Timestamp(Microsecond, Some("+00:00"))'); +---- +0 days 1 hours 0 mins 0.000000 secs + +# not supported: coercion across timezones +query error +select arrow_cast('2024-06-17T13:00:00', 'Timestamp(Nanosecond, Some("UTC"))') - arrow_cast('2024-06-17T12:00:00', 'Timestamp(Microsecond, Some("+01:00"))'); + +query error +select arrow_cast('2024-06-17T13:00:00', 'Timestamp(Nanosecond, Some("+00:00"))') - arrow_cast('2024-06-17T12:00:00', 'Timestamp(Microsecond, Some("+01:00"))'); + +########## +## Test to_local_time function +########## + +# invalid number of arguments -- no argument +statement error +select to_local_time(); + +# invalid number of arguments -- more than 1 argument +statement error +select to_local_time('2024-04-01T00:00:20Z'::timestamp, 'some string'); + +# invalid argument data type +statement error The to_local_time function can only accept Timestamp as the arg got Utf8 +select to_local_time('2024-04-01T00:00:20Z'); + +# invalid timezone +statement error DataFusion error: Arrow error: Parser error: Invalid timezone "Europe/timezone": failed to parse timezone +select to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/timezone'); + +# valid query +query P +select to_local_time('2024-04-01T00:00:20Z'::timestamp); +---- +2024-04-01T00:00:20 + +query P +select to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE '+05:00'); +---- +2024-04-01T00:00:20 + +query P +select to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels'); +---- +2024-04-01T00:00:20 + +query PTPT +select + time, + arrow_typeof(time) as type, + to_local_time(time) as to_local_time, + arrow_typeof(to_local_time(time)) as to_local_time_type +from ( + select '2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels' as time +); +---- +2024-04-01T00:00:20+02:00 Timestamp(Nanosecond, Some("Europe/Brussels")) 2024-04-01T00:00:20 Timestamp(Nanosecond, None) + +# use to_local_time() in date_bin() +query P +select date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')); +---- +2024-04-01T00:00:00 + +query P +select date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AT TIME ZONE 'Europe/Brussels'; +---- +2024-04-01T00:00:00+02:00 + +# test using to_local_time() on array values +statement ok +create table t AS +VALUES + ('2024-01-01T00:00:01Z'), + ('2024-02-01T00:00:01Z'), + ('2024-03-01T00:00:01Z'), + ('2024-04-01T00:00:01Z'), + ('2024-05-01T00:00:01Z'), + ('2024-06-01T00:00:01Z'), + ('2024-07-01T00:00:01Z'), + ('2024-08-01T00:00:01Z'), + ('2024-09-01T00:00:01Z'), + ('2024-10-01T00:00:01Z'), + ('2024-11-01T00:00:01Z'), + ('2024-12-01T00:00:01Z') +; + +statement ok +create view t_utc as +select column1::timestamp AT TIME ZONE 'UTC' as "column1" +from t; + +statement ok +create view t_timezone as +select column1::timestamp AT TIME ZONE 'Europe/Brussels' as "column1" +from t; + +query PPT +select column1, to_local_time(column1::timestamp), arrow_typeof(to_local_time(column1::timestamp)) from t_utc; +---- +2024-01-01T00:00:01Z 2024-01-01T00:00:01 Timestamp(Nanosecond, None) +2024-02-01T00:00:01Z 2024-02-01T00:00:01 Timestamp(Nanosecond, None) +2024-03-01T00:00:01Z 2024-03-01T00:00:01 Timestamp(Nanosecond, None) +2024-04-01T00:00:01Z 2024-04-01T00:00:01 Timestamp(Nanosecond, None) +2024-05-01T00:00:01Z 2024-05-01T00:00:01 Timestamp(Nanosecond, None) +2024-06-01T00:00:01Z 2024-06-01T00:00:01 Timestamp(Nanosecond, None) +2024-07-01T00:00:01Z 2024-07-01T00:00:01 Timestamp(Nanosecond, None) +2024-08-01T00:00:01Z 2024-08-01T00:00:01 Timestamp(Nanosecond, None) +2024-09-01T00:00:01Z 2024-09-01T00:00:01 Timestamp(Nanosecond, None) +2024-10-01T00:00:01Z 2024-10-01T00:00:01 Timestamp(Nanosecond, None) +2024-11-01T00:00:01Z 2024-11-01T00:00:01 Timestamp(Nanosecond, None) +2024-12-01T00:00:01Z 2024-12-01T00:00:01 Timestamp(Nanosecond, None) + +query PPT +select column1, to_local_time(column1), arrow_typeof(to_local_time(column1)) from t_utc; +---- +2024-01-01T00:00:01Z 2024-01-01T00:00:01 Timestamp(Nanosecond, None) +2024-02-01T00:00:01Z 2024-02-01T00:00:01 Timestamp(Nanosecond, None) +2024-03-01T00:00:01Z 2024-03-01T00:00:01 Timestamp(Nanosecond, None) +2024-04-01T00:00:01Z 2024-04-01T00:00:01 Timestamp(Nanosecond, None) +2024-05-01T00:00:01Z 2024-05-01T00:00:01 Timestamp(Nanosecond, None) +2024-06-01T00:00:01Z 2024-06-01T00:00:01 Timestamp(Nanosecond, None) +2024-07-01T00:00:01Z 2024-07-01T00:00:01 Timestamp(Nanosecond, None) +2024-08-01T00:00:01Z 2024-08-01T00:00:01 Timestamp(Nanosecond, None) +2024-09-01T00:00:01Z 2024-09-01T00:00:01 Timestamp(Nanosecond, None) +2024-10-01T00:00:01Z 2024-10-01T00:00:01 Timestamp(Nanosecond, None) +2024-11-01T00:00:01Z 2024-11-01T00:00:01 Timestamp(Nanosecond, None) +2024-12-01T00:00:01Z 2024-12-01T00:00:01 Timestamp(Nanosecond, None) + +query PPT +select column1, to_local_time(column1), arrow_typeof(to_local_time(column1)) from t_timezone; +---- +2024-01-01T00:00:01+01:00 2024-01-01T00:00:01 Timestamp(Nanosecond, None) +2024-02-01T00:00:01+01:00 2024-02-01T00:00:01 Timestamp(Nanosecond, None) +2024-03-01T00:00:01+01:00 2024-03-01T00:00:01 Timestamp(Nanosecond, None) +2024-04-01T00:00:01+02:00 2024-04-01T00:00:01 Timestamp(Nanosecond, None) +2024-05-01T00:00:01+02:00 2024-05-01T00:00:01 Timestamp(Nanosecond, None) +2024-06-01T00:00:01+02:00 2024-06-01T00:00:01 Timestamp(Nanosecond, None) +2024-07-01T00:00:01+02:00 2024-07-01T00:00:01 Timestamp(Nanosecond, None) +2024-08-01T00:00:01+02:00 2024-08-01T00:00:01 Timestamp(Nanosecond, None) +2024-09-01T00:00:01+02:00 2024-09-01T00:00:01 Timestamp(Nanosecond, None) +2024-10-01T00:00:01+02:00 2024-10-01T00:00:01 Timestamp(Nanosecond, None) +2024-11-01T00:00:01+01:00 2024-11-01T00:00:01 Timestamp(Nanosecond, None) +2024-12-01T00:00:01+01:00 2024-12-01T00:00:01 Timestamp(Nanosecond, None) + +# combine to_local_time() with date_bin() +query P +select date_bin(interval '1 day', to_local_time(column1)) AT TIME ZONE 'Europe/Brussels' as date_bin from t_utc; +---- +2024-01-01T00:00:00+01:00 +2024-02-01T00:00:00+01:00 +2024-03-01T00:00:00+01:00 +2024-04-01T00:00:00+02:00 +2024-05-01T00:00:00+02:00 +2024-06-01T00:00:00+02:00 +2024-07-01T00:00:00+02:00 +2024-08-01T00:00:00+02:00 +2024-09-01T00:00:00+02:00 +2024-10-01T00:00:00+02:00 +2024-11-01T00:00:00+01:00 +2024-12-01T00:00:00+01:00 + +query P +select date_bin(interval '1 day', to_local_time(column1)) AT TIME ZONE 'Europe/Brussels' as date_bin from t_timezone; +---- +2024-01-01T00:00:00+01:00 +2024-02-01T00:00:00+01:00 +2024-03-01T00:00:00+01:00 +2024-04-01T00:00:00+02:00 +2024-05-01T00:00:00+02:00 +2024-06-01T00:00:00+02:00 +2024-07-01T00:00:00+02:00 +2024-08-01T00:00:00+02:00 +2024-09-01T00:00:00+02:00 +2024-10-01T00:00:00+02:00 +2024-11-01T00:00:00+01:00 +2024-12-01T00:00:00+01:00 + +statement ok +drop table t; + +statement ok +drop view t_utc; + +statement ok +drop view t_timezone; + +# test comparisons across timestamps +statement ok +create table t AS +VALUES + ('2024-01-01T00:00:01Z'), + ('2024-02-01T00:00:01Z'), + ('2024-03-01T00:00:01Z') +; + +statement ok +create view t_utc as +select column1::timestamp AT TIME ZONE 'UTC' as "column1" +from t; + +statement ok +create view t_europe as +select column1::timestamp AT TIME ZONE 'Europe/Brussels' as "column1" +from t; + +query P +SELECT column1 FROM t_utc WHERE column1 < '2024-02-01T00:00:00' AT TIME ZONE 'America/Los_Angeles'; +---- +2024-01-01T00:00:01Z +2024-02-01T00:00:01Z + +query P +SELECT column1 FROM t_europe WHERE column1 = '2024-01-31T16:00:01' AT TIME ZONE 'America/Los_Angeles'; +---- +2024-02-01T00:00:01+01:00 + +query P +SELECT column1 FROM t_europe WHERE column1 BETWEEN '2020-01-01T00:00:00' AT TIME ZONE 'Australia/Brisbane' AND '2024-02-01T00:00:00' AT TIME ZONE 'America/Los_Angeles'; +---- +2024-01-01T00:00:01+01:00 +2024-02-01T00:00:01+01:00 + +query P +SELECT column1 FROM t_utc WHERE column1 IN ('2024-01-31T16:00:01' AT TIME ZONE 'America/Los_Angeles'); +---- +2024-02-01T00:00:01Z + +query P +SELECT column1 as u from t_utc UNION SELECT column1 from t_europe ORDER BY u; +---- +2023-12-31T23:00:01Z +2024-01-01T00:00:01Z +2024-01-31T23:00:01Z +2024-02-01T00:00:01Z +2024-02-29T23:00:01Z +2024-03-01T00:00:01Z + +query P +SELECT column1 as e from t_europe UNION SELECT column1 from t_utc ORDER BY e; +---- +2024-01-01T00:00:01+01:00 +2024-01-01T01:00:01+01:00 +2024-02-01T00:00:01+01:00 +2024-02-01T01:00:01+01:00 +2024-03-01T00:00:01+01:00 +2024-03-01T01:00:01+01:00 + +query P +SELECT nvl2(null, '2020-01-01T00:00:00-04:00'::timestamp, '2021-02-03T04:05:06Z'::timestamp) +---- +2021-02-03T04:05:06 + +query ? +SELECT make_array('2020-01-01T00:00:00-04:00'::timestamp, '2021-01-01T01:02:03Z'::timestamp); +---- +[2020-01-01T04:00:00, 2021-01-01T01:02:03] + +query P +SELECT * FROM VALUES + ('2023-12-31T23:00:00Z' AT TIME ZONE 'UTC'), + ('2024-02-01T00:00:00' AT TIME ZONE 'America/Los_Angeles'); +---- +2023-12-31T15:00:00-08:00 +2024-02-01T00:00:00-08:00 + +query P +SELECT * FROM VALUES + ('2024-02-01T00:00:00' AT TIME ZONE 'America/Los_Angeles'), + ('2023-12-31T23:00:00' AT TIME ZONE 'UTC'); +---- +2024-02-01T08:00:00Z +2023-12-31T23:00:00Z + +# interval vs. duration comparison +query B +select (now() - now()) < interval '1 seconds'; +---- +true + +query B +select (now() - now()) <= interval '1 seconds'; +---- +true + +query B +select (now() - now()) = interval '0 seconds'; +---- +true + +query B +select (now() - now()) != interval '1 seconds'; +---- +true + +query B +select (now() - now()) > interval '-1 seconds'; +---- +true + +query B +select (now() - now()) >= interval '-1 seconds'; +---- +true + +query B +select arrow_cast(123, 'Duration(Nanosecond)') < interval '200 nanoseconds'; +---- +true + +query B +select arrow_cast(123, 'Duration(Nanosecond)') < interval '100 nanoseconds'; +---- +false + +query B +select arrow_cast(123, 'Duration(Nanosecond)') < interval '1 seconds'; +---- +true + +query B +select interval '1 seconds' < arrow_cast(123, 'Duration(Nanosecond)') +---- +false + +# interval as LHS +query B +select interval '2 seconds' = interval '2 seconds'; +---- +true + +query B +select interval '1 seconds' < interval '2 seconds'; +---- +true + +statement ok +drop table t; + +statement ok +drop view t_utc; + +statement ok +drop view t_europe; diff --git a/datafusion/sqllogictest/test_files/topk.slt b/datafusion/sqllogictest/test_files/topk.slt index 1035c49d1fc0..1dbce79e0f1a 100644 --- a/datafusion/sqllogictest/test_files/topk.slt +++ b/datafusion/sqllogictest/test_files/topk.slt @@ -69,20 +69,18 @@ CREATE EXTERNAL TABLE aggregate_test_100 ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); query TT explain select * from aggregate_test_100 ORDER BY c13 desc limit 5; ---- logical_plan -01)Limit: skip=0, fetch=5 -02)--Sort: aggregate_test_100.c13 DESC NULLS FIRST, fetch=5 -03)----TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13] +01)Sort: aggregate_test_100.c13 DESC NULLS FIRST, fetch=5 +02)--TableScan: aggregate_test_100 projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13] physical_plan -01)GlobalLimitExec: skip=0, fetch=5 -02)--SortExec: TopK(fetch=5), expr=[c13@12 DESC] -03)----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], has_header=true +01)SortExec: TopK(fetch=5), expr=[c13@12 DESC], preserve_partitioning=[false] +02)--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13], has_header=true @@ -222,7 +220,7 @@ a 1 -5 12636 794623392 2909750622865366631 15 24022 2669374863 47766797847015095 statement ok create table dict as select c1, c2, c3, c13, arrow_cast(c13, 'Dictionary(Int32, Utf8)') as c13_dict from aggregate_test_100; -query TIIT? +query TIITT select * from dict order by c13 desc limit 5; ---- a 4 -38 ydkwycaISlYSlEq3TlkS2m15I2pcp8 ydkwycaISlYSlEq3TlkS2m15I2pcp8 diff --git a/datafusion/sqllogictest/test_files/tpch/create_tables.slt.part b/datafusion/sqllogictest/test_files/tpch/create_tables.slt.part index 2f5e2d5a7616..d6249cb57990 100644 --- a/datafusion/sqllogictest/test_files/tpch/create_tables.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/create_tables.slt.part @@ -31,7 +31,7 @@ CREATE EXTERNAL TABLE IF NOT EXISTS supplier ( s_acctbal DECIMAL(15, 2), s_comment VARCHAR, s_rev VARCHAR, -) STORED AS CSV DELIMITER '|' LOCATION 'test_files/tpch/data/supplier.tbl'; +) STORED AS CSV LOCATION 'test_files/tpch/data/supplier.tbl' OPTIONS ('format.delimiter' '|', 'format.has_header' 'false'); statement ok CREATE EXTERNAL TABLE IF NOT EXISTS part ( @@ -45,7 +45,7 @@ CREATE EXTERNAL TABLE IF NOT EXISTS part ( p_retailprice DECIMAL(15, 2), p_comment VARCHAR, p_rev VARCHAR, -) STORED AS CSV DELIMITER '|' LOCATION 'test_files/tpch/data/part.tbl'; +) STORED AS CSV LOCATION 'test_files/tpch/data/part.tbl' OPTIONS ('format.delimiter' '|', 'format.has_header' 'false'); statement ok @@ -56,7 +56,7 @@ CREATE EXTERNAL TABLE IF NOT EXISTS partsupp ( ps_supplycost DECIMAL(15, 2), ps_comment VARCHAR, ps_rev VARCHAR, -) STORED AS CSV DELIMITER '|' LOCATION 'test_files/tpch/data/partsupp.tbl'; +) STORED AS CSV LOCATION 'test_files/tpch/data/partsupp.tbl' OPTIONS ('format.delimiter' '|', 'format.has_header' 'false'); statement ok CREATE EXTERNAL TABLE IF NOT EXISTS customer ( @@ -69,7 +69,7 @@ CREATE EXTERNAL TABLE IF NOT EXISTS customer ( c_mktsegment VARCHAR, c_comment VARCHAR, c_rev VARCHAR, -) STORED AS CSV DELIMITER '|' LOCATION 'test_files/tpch/data/customer.tbl'; +) STORED AS CSV LOCATION 'test_files/tpch/data/customer.tbl' OPTIONS ('format.delimiter' '|', 'format.has_header' 'false'); statement ok CREATE EXTERNAL TABLE IF NOT EXISTS orders ( @@ -83,7 +83,7 @@ CREATE EXTERNAL TABLE IF NOT EXISTS orders ( o_shippriority INTEGER, o_comment VARCHAR, o_rev VARCHAR, -) STORED AS CSV DELIMITER '|' LOCATION 'test_files/tpch/data/orders.tbl'; +) STORED AS CSV LOCATION 'test_files/tpch/data/orders.tbl' OPTIONS ('format.delimiter' '|', 'format.has_header' 'false'); statement ok CREATE EXTERNAL TABLE IF NOT EXISTS lineitem ( @@ -104,7 +104,7 @@ CREATE EXTERNAL TABLE IF NOT EXISTS lineitem ( l_shipmode VARCHAR, l_comment VARCHAR, l_rev VARCHAR, -) STORED AS CSV DELIMITER '|' LOCATION 'test_files/tpch/data/lineitem.tbl'; +) STORED AS CSV LOCATION 'test_files/tpch/data/lineitem.tbl' OPTIONS ('format.delimiter' '|', 'format.has_header' 'false'); statement ok CREATE EXTERNAL TABLE IF NOT EXISTS nation ( @@ -113,7 +113,7 @@ CREATE EXTERNAL TABLE IF NOT EXISTS nation ( n_regionkey BIGINT, n_comment VARCHAR, n_rev VARCHAR, -) STORED AS CSV DELIMITER '|' LOCATION 'test_files/tpch/data/nation.tbl'; +) STORED AS CSV LOCATION 'test_files/tpch/data/nation.tbl' OPTIONS ('format.delimiter' '|', 'format.has_header' 'false'); statement ok CREATE EXTERNAL TABLE IF NOT EXISTS region ( @@ -121,4 +121,4 @@ CREATE EXTERNAL TABLE IF NOT EXISTS region ( r_name VARCHAR, r_comment VARCHAR, r_rev VARCHAR, -) STORED AS CSV DELIMITER '|' LOCATION 'test_files/tpch/data/region.tbl'; +) STORED AS CSV LOCATION 'test_files/tpch/data/region.tbl' OPTIONS ('format.delimiter' '|', 'format.has_header' 'false'); diff --git a/datafusion/sqllogictest/test_files/tpch/q1.slt.part b/datafusion/sqllogictest/test_files/tpch/q1.slt.part index 175040420160..4d4323e93e9e 100644 --- a/datafusion/sqllogictest/test_files/tpch/q1.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q1.slt.part @@ -41,22 +41,22 @@ explain select ---- logical_plan 01)Sort: lineitem.l_returnflag ASC NULLS LAST, lineitem.l_linestatus ASC NULLS LAST -02)--Projection: lineitem.l_returnflag, lineitem.l_linestatus, SUM(lineitem.l_quantity) AS sum_qty, SUM(lineitem.l_extendedprice) AS sum_base_price, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS sum_disc_price, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax) AS sum_charge, AVG(lineitem.l_quantity) AS avg_qty, AVG(lineitem.l_extendedprice) AS avg_price, AVG(lineitem.l_discount) AS avg_disc, COUNT(*) AS count_order -03)----Aggregate: groupBy=[[lineitem.l_returnflag, lineitem.l_linestatus]], aggr=[[SUM(lineitem.l_quantity), SUM(lineitem.l_extendedprice), SUM(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice AS lineitem.l_extendedprice * Decimal128(Some(1),20,0) - lineitem.l_discount) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), SUM(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice AS lineitem.l_extendedprice * Decimal128(Some(1),20,0) - lineitem.l_discount * (Decimal128(Some(1),20,0) + lineitem.l_tax)) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax), AVG(lineitem.l_quantity), AVG(lineitem.l_extendedprice), AVG(lineitem.l_discount), COUNT(Int64(1)) AS COUNT(*)]] -04)------Projection: lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, lineitem.l_tax, lineitem.l_returnflag, lineitem.l_linestatus -05)--------Filter: lineitem.l_shipdate <= Date32("10471") -06)----------TableScan: lineitem projection=[l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate], partial_filters=[lineitem.l_shipdate <= Date32("10471")] +02)--Projection: lineitem.l_returnflag, lineitem.l_linestatus, sum(lineitem.l_quantity) AS sum_qty, sum(lineitem.l_extendedprice) AS sum_base_price, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS sum_disc_price, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax) AS sum_charge, avg(lineitem.l_quantity) AS avg_qty, avg(lineitem.l_extendedprice) AS avg_price, avg(lineitem.l_discount) AS avg_disc, count(*) AS count_order +03)----Aggregate: groupBy=[[lineitem.l_returnflag, lineitem.l_linestatus]], aggr=[[sum(lineitem.l_quantity), sum(lineitem.l_extendedprice), sum(__common_expr_1) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), sum(__common_expr_1 * (Decimal128(Some(1),20,0) + lineitem.l_tax)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax), avg(lineitem.l_quantity), avg(lineitem.l_extendedprice), avg(lineitem.l_discount), count(Int64(1)) AS count(*)]] +04)------Projection: lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS __common_expr_1, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, lineitem.l_tax, lineitem.l_returnflag, lineitem.l_linestatus +05)--------Filter: lineitem.l_shipdate <= Date32("1998-09-02") +06)----------TableScan: lineitem projection=[l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate], partial_filters=[lineitem.l_shipdate <= Date32("1998-09-02")] physical_plan -01)SortPreservingMergeExec: [l_returnflag@0 ASC NULLS LAST,l_linestatus@1 ASC NULLS LAST] -02)--SortExec: expr=[l_returnflag@0 ASC NULLS LAST,l_linestatus@1 ASC NULLS LAST] -03)----ProjectionExec: expr=[l_returnflag@0 as l_returnflag, l_linestatus@1 as l_linestatus, SUM(lineitem.l_quantity)@2 as sum_qty, SUM(lineitem.l_extendedprice)@3 as sum_base_price, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@4 as sum_disc_price, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax)@5 as sum_charge, AVG(lineitem.l_quantity)@6 as avg_qty, AVG(lineitem.l_extendedprice)@7 as avg_price, AVG(lineitem.l_discount)@8 as avg_disc, COUNT(*)@9 as count_order] -04)------AggregateExec: mode=FinalPartitioned, gby=[l_returnflag@0 as l_returnflag, l_linestatus@1 as l_linestatus], aggr=[SUM(lineitem.l_quantity), SUM(lineitem.l_extendedprice), SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax), AVG(lineitem.l_quantity), AVG(lineitem.l_extendedprice), AVG(lineitem.l_discount), COUNT(*)] +01)SortPreservingMergeExec: [l_returnflag@0 ASC NULLS LAST, l_linestatus@1 ASC NULLS LAST] +02)--SortExec: expr=[l_returnflag@0 ASC NULLS LAST, l_linestatus@1 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[l_returnflag@0 as l_returnflag, l_linestatus@1 as l_linestatus, sum(lineitem.l_quantity)@2 as sum_qty, sum(lineitem.l_extendedprice)@3 as sum_base_price, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@4 as sum_disc_price, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax)@5 as sum_charge, avg(lineitem.l_quantity)@6 as avg_qty, avg(lineitem.l_extendedprice)@7 as avg_price, avg(lineitem.l_discount)@8 as avg_disc, count(*)@9 as count_order] +04)------AggregateExec: mode=FinalPartitioned, gby=[l_returnflag@0 as l_returnflag, l_linestatus@1 as l_linestatus], aggr=[sum(lineitem.l_quantity), sum(lineitem.l_extendedprice), sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax), avg(lineitem.l_quantity), avg(lineitem.l_extendedprice), avg(lineitem.l_discount), count(*)] 05)--------CoalesceBatchesExec: target_batch_size=8192 06)----------RepartitionExec: partitioning=Hash([l_returnflag@0, l_linestatus@1], 4), input_partitions=4 -07)------------AggregateExec: mode=Partial, gby=[l_returnflag@5 as l_returnflag, l_linestatus@6 as l_linestatus], aggr=[SUM(lineitem.l_quantity), SUM(lineitem.l_extendedprice), SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax), AVG(lineitem.l_quantity), AVG(lineitem.l_extendedprice), AVG(lineitem.l_discount), COUNT(*)] -08)--------------ProjectionExec: expr=[l_extendedprice@1 * (Some(1),20,0 - l_discount@2) as lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice, l_quantity@0 as l_quantity, l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount, l_tax@3 as l_tax, l_returnflag@4 as l_returnflag, l_linestatus@5 as l_linestatus] +07)------------AggregateExec: mode=Partial, gby=[l_returnflag@5 as l_returnflag, l_linestatus@6 as l_linestatus], aggr=[sum(lineitem.l_quantity), sum(lineitem.l_extendedprice), sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount * Int64(1) + lineitem.l_tax), avg(lineitem.l_quantity), avg(lineitem.l_extendedprice), avg(lineitem.l_discount), count(*)] +08)--------------ProjectionExec: expr=[l_extendedprice@1 * (Some(1),20,0 - l_discount@2) as __common_expr_1, l_quantity@0 as l_quantity, l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount, l_tax@3 as l_tax, l_returnflag@4 as l_returnflag, l_linestatus@5 as l_linestatus] 09)----------------CoalesceBatchesExec: target_batch_size=8192 -10)------------------FilterExec: l_shipdate@6 <= 10471 +10)------------------FilterExec: l_shipdate@6 <= 1998-09-02, projection=[l_quantity@0, l_extendedprice@1, l_discount@2, l_tax@3, l_returnflag@4, l_linestatus@5] 11)--------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_quantity, l_extendedprice, l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate], has_header=false query TTRRRRRRRI @@ -80,7 +80,7 @@ group by l_linestatus order by l_returnflag, - l_linestatus; + l_linestatus; ---- A F 3774200 5320753880.69 5054096266.6828 5256751331.449234 25.537587 36002.123829 0.050144 147790 N F 95257 133737795.84 127132372.6512 132286291.229445 25.300664 35521.326916 0.049394 3765 diff --git a/datafusion/sqllogictest/test_files/tpch/q10.slt.part b/datafusion/sqllogictest/test_files/tpch/q10.slt.part index 2a3168b5c1bf..73593a470c9a 100644 --- a/datafusion/sqllogictest/test_files/tpch/q10.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q10.slt.part @@ -51,63 +51,59 @@ order by limit 10; ---- logical_plan -01)Limit: skip=0, fetch=10 -02)--Sort: revenue DESC NULLS FIRST, fetch=10 -03)----Projection: customer.c_custkey, customer.c_name, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue, customer.c_acctbal, nation.n_name, customer.c_address, customer.c_phone, customer.c_comment -04)------Aggregate: groupBy=[[customer.c_custkey, customer.c_name, customer.c_acctbal, customer.c_phone, nation.n_name, customer.c_address, customer.c_comment]], aggr=[[SUM(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] -05)--------Projection: customer.c_custkey, customer.c_name, customer.c_address, customer.c_phone, customer.c_acctbal, customer.c_comment, lineitem.l_extendedprice, lineitem.l_discount, nation.n_name -06)----------Inner Join: customer.c_nationkey = nation.n_nationkey -07)------------Projection: customer.c_custkey, customer.c_name, customer.c_address, customer.c_nationkey, customer.c_phone, customer.c_acctbal, customer.c_comment, lineitem.l_extendedprice, lineitem.l_discount -08)--------------Inner Join: orders.o_orderkey = lineitem.l_orderkey -09)----------------Projection: customer.c_custkey, customer.c_name, customer.c_address, customer.c_nationkey, customer.c_phone, customer.c_acctbal, customer.c_comment, orders.o_orderkey -10)------------------Inner Join: customer.c_custkey = orders.o_custkey -11)--------------------TableScan: customer projection=[c_custkey, c_name, c_address, c_nationkey, c_phone, c_acctbal, c_comment] -12)--------------------Projection: orders.o_orderkey, orders.o_custkey -13)----------------------Filter: orders.o_orderdate >= Date32("8674") AND orders.o_orderdate < Date32("8766") -14)------------------------TableScan: orders projection=[o_orderkey, o_custkey, o_orderdate], partial_filters=[orders.o_orderdate >= Date32("8674"), orders.o_orderdate < Date32("8766")] -15)----------------Projection: lineitem.l_orderkey, lineitem.l_extendedprice, lineitem.l_discount -16)------------------Filter: lineitem.l_returnflag = Utf8("R") -17)--------------------TableScan: lineitem projection=[l_orderkey, l_extendedprice, l_discount, l_returnflag], partial_filters=[lineitem.l_returnflag = Utf8("R")] -18)------------TableScan: nation projection=[n_nationkey, n_name] +01)Sort: revenue DESC NULLS FIRST, fetch=10 +02)--Projection: customer.c_custkey, customer.c_name, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue, customer.c_acctbal, nation.n_name, customer.c_address, customer.c_phone, customer.c_comment +03)----Aggregate: groupBy=[[customer.c_custkey, customer.c_name, customer.c_acctbal, customer.c_phone, nation.n_name, customer.c_address, customer.c_comment]], aggr=[[sum(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] +04)------Projection: customer.c_custkey, customer.c_name, customer.c_address, customer.c_phone, customer.c_acctbal, customer.c_comment, lineitem.l_extendedprice, lineitem.l_discount, nation.n_name +05)--------Inner Join: customer.c_nationkey = nation.n_nationkey +06)----------Projection: customer.c_custkey, customer.c_name, customer.c_address, customer.c_nationkey, customer.c_phone, customer.c_acctbal, customer.c_comment, lineitem.l_extendedprice, lineitem.l_discount +07)------------Inner Join: orders.o_orderkey = lineitem.l_orderkey +08)--------------Projection: customer.c_custkey, customer.c_name, customer.c_address, customer.c_nationkey, customer.c_phone, customer.c_acctbal, customer.c_comment, orders.o_orderkey +09)----------------Inner Join: customer.c_custkey = orders.o_custkey +10)------------------TableScan: customer projection=[c_custkey, c_name, c_address, c_nationkey, c_phone, c_acctbal, c_comment] +11)------------------Projection: orders.o_orderkey, orders.o_custkey +12)--------------------Filter: orders.o_orderdate >= Date32("1993-10-01") AND orders.o_orderdate < Date32("1994-01-01") +13)----------------------TableScan: orders projection=[o_orderkey, o_custkey, o_orderdate], partial_filters=[orders.o_orderdate >= Date32("1993-10-01"), orders.o_orderdate < Date32("1994-01-01")] +14)--------------Projection: lineitem.l_orderkey, lineitem.l_extendedprice, lineitem.l_discount +15)----------------Filter: lineitem.l_returnflag = Utf8("R") +16)------------------TableScan: lineitem projection=[l_orderkey, l_extendedprice, l_discount, l_returnflag], partial_filters=[lineitem.l_returnflag = Utf8("R")] +17)----------TableScan: nation projection=[n_nationkey, n_name] physical_plan -01)GlobalLimitExec: skip=0, fetch=10 -02)--SortPreservingMergeExec: [revenue@2 DESC], fetch=10 -03)----SortExec: TopK(fetch=10), expr=[revenue@2 DESC] -04)------ProjectionExec: expr=[c_custkey@0 as c_custkey, c_name@1 as c_name, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@7 as revenue, c_acctbal@2 as c_acctbal, n_name@4 as n_name, c_address@5 as c_address, c_phone@3 as c_phone, c_comment@6 as c_comment] -05)--------AggregateExec: mode=FinalPartitioned, gby=[c_custkey@0 as c_custkey, c_name@1 as c_name, c_acctbal@2 as c_acctbal, c_phone@3 as c_phone, n_name@4 as n_name, c_address@5 as c_address, c_comment@6 as c_comment], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] -06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------RepartitionExec: partitioning=Hash([c_custkey@0, c_name@1, c_acctbal@2, c_phone@3, n_name@4, c_address@5, c_comment@6], 4), input_partitions=4 -08)--------------AggregateExec: mode=Partial, gby=[c_custkey@0 as c_custkey, c_name@1 as c_name, c_acctbal@4 as c_acctbal, c_phone@3 as c_phone, n_name@8 as n_name, c_address@2 as c_address, c_comment@5 as c_comment], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] -09)----------------CoalesceBatchesExec: target_batch_size=8192 -10)------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c_nationkey@3, n_nationkey@0)], projection=[c_custkey@0, c_name@1, c_address@2, c_phone@4, c_acctbal@5, c_comment@6, l_extendedprice@7, l_discount@8, n_name@10] -11)--------------------CoalesceBatchesExec: target_batch_size=8192 -12)----------------------RepartitionExec: partitioning=Hash([c_nationkey@3], 4), input_partitions=4 -13)------------------------CoalesceBatchesExec: target_batch_size=8192 -14)--------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(o_orderkey@7, l_orderkey@0)], projection=[c_custkey@0, c_name@1, c_address@2, c_nationkey@3, c_phone@4, c_acctbal@5, c_comment@6, l_extendedprice@9, l_discount@10] -15)----------------------------CoalesceBatchesExec: target_batch_size=8192 -16)------------------------------RepartitionExec: partitioning=Hash([o_orderkey@7], 4), input_partitions=4 -17)--------------------------------CoalesceBatchesExec: target_batch_size=8192 -18)----------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c_custkey@0, o_custkey@1)], projection=[c_custkey@0, c_name@1, c_address@2, c_nationkey@3, c_phone@4, c_acctbal@5, c_comment@6, o_orderkey@7] -19)------------------------------------CoalesceBatchesExec: target_batch_size=8192 -20)--------------------------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4 -21)----------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -22)------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_name, c_address, c_nationkey, c_phone, c_acctbal, c_comment], has_header=false -23)------------------------------------CoalesceBatchesExec: target_batch_size=8192 -24)--------------------------------------RepartitionExec: partitioning=Hash([o_custkey@1], 4), input_partitions=4 -25)----------------------------------------ProjectionExec: expr=[o_orderkey@0 as o_orderkey, o_custkey@1 as o_custkey] -26)------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -27)--------------------------------------------FilterExec: o_orderdate@2 >= 8674 AND o_orderdate@2 < 8766 -28)----------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_custkey, o_orderdate], has_header=false -29)----------------------------CoalesceBatchesExec: target_batch_size=8192 -30)------------------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4 -31)--------------------------------ProjectionExec: expr=[l_orderkey@0 as l_orderkey, l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount] -32)----------------------------------CoalesceBatchesExec: target_batch_size=8192 -33)------------------------------------FilterExec: l_returnflag@3 = R -34)--------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_extendedprice, l_discount, l_returnflag], has_header=false -35)--------------------CoalesceBatchesExec: target_batch_size=8192 -36)----------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 -37)------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -38)--------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false +01)SortPreservingMergeExec: [revenue@2 DESC], fetch=10 +02)--SortExec: TopK(fetch=10), expr=[revenue@2 DESC], preserve_partitioning=[true] +03)----ProjectionExec: expr=[c_custkey@0 as c_custkey, c_name@1 as c_name, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@7 as revenue, c_acctbal@2 as c_acctbal, n_name@4 as n_name, c_address@5 as c_address, c_phone@3 as c_phone, c_comment@6 as c_comment] +04)------AggregateExec: mode=FinalPartitioned, gby=[c_custkey@0 as c_custkey, c_name@1 as c_name, c_acctbal@2 as c_acctbal, c_phone@3 as c_phone, n_name@4 as n_name, c_address@5 as c_address, c_comment@6 as c_comment], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +05)--------CoalesceBatchesExec: target_batch_size=8192 +06)----------RepartitionExec: partitioning=Hash([c_custkey@0, c_name@1, c_acctbal@2, c_phone@3, n_name@4, c_address@5, c_comment@6], 4), input_partitions=4 +07)------------AggregateExec: mode=Partial, gby=[c_custkey@0 as c_custkey, c_name@1 as c_name, c_acctbal@4 as c_acctbal, c_phone@3 as c_phone, n_name@8 as n_name, c_address@2 as c_address, c_comment@5 as c_comment], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +08)--------------CoalesceBatchesExec: target_batch_size=8192 +09)----------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c_nationkey@3, n_nationkey@0)], projection=[c_custkey@0, c_name@1, c_address@2, c_phone@4, c_acctbal@5, c_comment@6, l_extendedprice@7, l_discount@8, n_name@10] +10)------------------CoalesceBatchesExec: target_batch_size=8192 +11)--------------------RepartitionExec: partitioning=Hash([c_nationkey@3], 4), input_partitions=4 +12)----------------------CoalesceBatchesExec: target_batch_size=8192 +13)------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(o_orderkey@7, l_orderkey@0)], projection=[c_custkey@0, c_name@1, c_address@2, c_nationkey@3, c_phone@4, c_acctbal@5, c_comment@6, l_extendedprice@9, l_discount@10] +14)--------------------------CoalesceBatchesExec: target_batch_size=8192 +15)----------------------------RepartitionExec: partitioning=Hash([o_orderkey@7], 4), input_partitions=4 +16)------------------------------CoalesceBatchesExec: target_batch_size=8192 +17)--------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c_custkey@0, o_custkey@1)], projection=[c_custkey@0, c_name@1, c_address@2, c_nationkey@3, c_phone@4, c_acctbal@5, c_comment@6, o_orderkey@7] +18)----------------------------------CoalesceBatchesExec: target_batch_size=8192 +19)------------------------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4 +20)--------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +21)----------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_name, c_address, c_nationkey, c_phone, c_acctbal, c_comment], has_header=false +22)----------------------------------CoalesceBatchesExec: target_batch_size=8192 +23)------------------------------------RepartitionExec: partitioning=Hash([o_custkey@1], 4), input_partitions=4 +24)--------------------------------------CoalesceBatchesExec: target_batch_size=8192 +25)----------------------------------------FilterExec: o_orderdate@2 >= 1993-10-01 AND o_orderdate@2 < 1994-01-01, projection=[o_orderkey@0, o_custkey@1] +26)------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_custkey, o_orderdate], has_header=false +27)--------------------------CoalesceBatchesExec: target_batch_size=8192 +28)----------------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4 +29)------------------------------CoalesceBatchesExec: target_batch_size=8192 +30)--------------------------------FilterExec: l_returnflag@3 = R, projection=[l_orderkey@0, l_extendedprice@1, l_discount@2] +31)----------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_extendedprice, l_discount, l_returnflag], has_header=false +32)------------------CoalesceBatchesExec: target_batch_size=8192 +33)--------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 +34)----------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +35)------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/q11.slt.part b/datafusion/sqllogictest/test_files/tpch/q11.slt.part index 3050af6f89a2..adaf391de0a2 100644 --- a/datafusion/sqllogictest/test_files/tpch/q11.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q11.slt.part @@ -47,86 +47,82 @@ order by limit 10; ---- logical_plan -01)Limit: skip=0, fetch=10 -02)--Sort: value DESC NULLS FIRST, fetch=10 -03)----Projection: partsupp.ps_partkey, SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS value -04)------Inner Join: Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 15)) > __scalar_sq_1.SUM(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001) -05)--------Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[SUM(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]] -06)----------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost -07)------------Inner Join: supplier.s_nationkey = nation.n_nationkey -08)--------------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey -09)----------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey -10)------------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost] -11)------------------TableScan: supplier projection=[s_suppkey, s_nationkey] -12)--------------Projection: nation.n_nationkey -13)----------------Filter: nation.n_name = Utf8("GERMANY") -14)------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")] -15)--------SubqueryAlias: __scalar_sq_1 -16)----------Projection: CAST(CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS Decimal128(38, 15)) -17)------------Aggregate: groupBy=[[]], aggr=[[SUM(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]] -18)--------------Projection: partsupp.ps_availqty, partsupp.ps_supplycost -19)----------------Inner Join: supplier.s_nationkey = nation.n_nationkey -20)------------------Projection: partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey -21)--------------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey -22)----------------------TableScan: partsupp projection=[ps_suppkey, ps_availqty, ps_supplycost] -23)----------------------TableScan: supplier projection=[s_suppkey, s_nationkey] -24)------------------Projection: nation.n_nationkey -25)--------------------Filter: nation.n_name = Utf8("GERMANY") -26)----------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")] +01)Sort: value DESC NULLS FIRST, fetch=10 +02)--Projection: partsupp.ps_partkey, sum(partsupp.ps_supplycost * partsupp.ps_availqty) AS value +03)----Inner Join: Filter: CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 15)) > __scalar_sq_1.sum(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001) +04)------Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[sum(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]] +05)--------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost +06)----------Inner Join: supplier.s_nationkey = nation.n_nationkey +07)------------Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey +08)--------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey +09)----------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost] +10)----------------TableScan: supplier projection=[s_suppkey, s_nationkey] +11)------------Projection: nation.n_nationkey +12)--------------Filter: nation.n_name = Utf8("GERMANY") +13)----------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")] +14)------SubqueryAlias: __scalar_sq_1 +15)--------Projection: CAST(CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS Decimal128(38, 15)) +16)----------Aggregate: groupBy=[[]], aggr=[[sum(partsupp.ps_supplycost * CAST(partsupp.ps_availqty AS Decimal128(10, 0)))]] +17)------------Projection: partsupp.ps_availqty, partsupp.ps_supplycost +18)--------------Inner Join: supplier.s_nationkey = nation.n_nationkey +19)----------------Projection: partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey +20)------------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey +21)--------------------TableScan: partsupp projection=[ps_suppkey, ps_availqty, ps_supplycost] +22)--------------------TableScan: supplier projection=[s_suppkey, s_nationkey] +23)----------------Projection: nation.n_nationkey +24)------------------Filter: nation.n_name = Utf8("GERMANY") +25)--------------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")] physical_plan -01)GlobalLimitExec: skip=0, fetch=10 -02)--SortExec: TopK(fetch=10), expr=[value@1 DESC] -03)----ProjectionExec: expr=[ps_partkey@0 as ps_partkey, SUM(partsupp.ps_supplycost * partsupp.ps_availqty)@1 as value] -04)------NestedLoopJoinExec: join_type=Inner, filter=CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty)@0 AS Decimal128(38, 15)) > SUM(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001)@1 -05)--------CoalescePartitionsExec -06)----------AggregateExec: mode=FinalPartitioned, gby=[ps_partkey@0 as ps_partkey], aggr=[SUM(partsupp.ps_supplycost * partsupp.ps_availqty)] -07)------------CoalesceBatchesExec: target_batch_size=8192 -08)--------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4 -09)----------------AggregateExec: mode=Partial, gby=[ps_partkey@0 as ps_partkey], aggr=[SUM(partsupp.ps_supplycost * partsupp.ps_availqty)] -10)------------------CoalesceBatchesExec: target_batch_size=8192 -11)--------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@3, n_nationkey@0)], projection=[ps_partkey@0, ps_availqty@1, ps_supplycost@2] -12)----------------------CoalesceBatchesExec: target_batch_size=8192 -13)------------------------RepartitionExec: partitioning=Hash([s_nationkey@3], 4), input_partitions=4 -14)--------------------------CoalesceBatchesExec: target_batch_size=8192 -15)----------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@1, s_suppkey@0)], projection=[ps_partkey@0, ps_availqty@2, ps_supplycost@3, s_nationkey@5] -16)------------------------------CoalesceBatchesExec: target_batch_size=8192 -17)--------------------------------RepartitionExec: partitioning=Hash([ps_suppkey@1], 4), input_partitions=4 -18)----------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost], has_header=false -19)------------------------------CoalesceBatchesExec: target_batch_size=8192 -20)--------------------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 -21)----------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -22)------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], has_header=false -23)----------------------CoalesceBatchesExec: target_batch_size=8192 -24)------------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 -25)--------------------------ProjectionExec: expr=[n_nationkey@0 as n_nationkey] -26)----------------------------CoalesceBatchesExec: target_batch_size=8192 -27)------------------------------FilterExec: n_name@1 = GERMANY -28)--------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -29)----------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false -30)--------ProjectionExec: expr=[CAST(CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty)@0 AS Float64) * 0.0001 AS Decimal128(38, 15)) as SUM(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001)] -31)----------AggregateExec: mode=Final, gby=[], aggr=[SUM(partsupp.ps_supplycost * partsupp.ps_availqty)] -32)------------CoalescePartitionsExec -33)--------------AggregateExec: mode=Partial, gby=[], aggr=[SUM(partsupp.ps_supplycost * partsupp.ps_availqty)] -34)----------------CoalesceBatchesExec: target_batch_size=8192 -35)------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@2, n_nationkey@0)], projection=[ps_availqty@0, ps_supplycost@1] -36)--------------------CoalesceBatchesExec: target_batch_size=8192 -37)----------------------RepartitionExec: partitioning=Hash([s_nationkey@2], 4), input_partitions=4 -38)------------------------CoalesceBatchesExec: target_batch_size=8192 -39)--------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@0, s_suppkey@0)], projection=[ps_availqty@1, ps_supplycost@2, s_nationkey@4] -40)----------------------------CoalesceBatchesExec: target_batch_size=8192 -41)------------------------------RepartitionExec: partitioning=Hash([ps_suppkey@0], 4), input_partitions=4 -42)--------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_suppkey, ps_availqty, ps_supplycost], has_header=false -43)----------------------------CoalesceBatchesExec: target_batch_size=8192 -44)------------------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 -45)--------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -46)----------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], has_header=false -47)--------------------CoalesceBatchesExec: target_batch_size=8192 -48)----------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 -49)------------------------ProjectionExec: expr=[n_nationkey@0 as n_nationkey] -50)--------------------------CoalesceBatchesExec: target_batch_size=8192 -51)----------------------------FilterExec: n_name@1 = GERMANY -52)------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -53)--------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false +01)SortExec: TopK(fetch=10), expr=[value@1 DESC], preserve_partitioning=[false] +02)--ProjectionExec: expr=[ps_partkey@0 as ps_partkey, sum(partsupp.ps_supplycost * partsupp.ps_availqty)@1 as value] +03)----NestedLoopJoinExec: join_type=Inner, filter=CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty)@0 AS Decimal128(38, 15)) > sum(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001)@1 +04)------CoalescePartitionsExec +05)--------AggregateExec: mode=FinalPartitioned, gby=[ps_partkey@0 as ps_partkey], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] +06)----------CoalesceBatchesExec: target_batch_size=8192 +07)------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4 +08)--------------AggregateExec: mode=Partial, gby=[ps_partkey@0 as ps_partkey], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] +09)----------------CoalesceBatchesExec: target_batch_size=8192 +10)------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@3, n_nationkey@0)], projection=[ps_partkey@0, ps_availqty@1, ps_supplycost@2] +11)--------------------CoalesceBatchesExec: target_batch_size=8192 +12)----------------------RepartitionExec: partitioning=Hash([s_nationkey@3], 4), input_partitions=4 +13)------------------------CoalesceBatchesExec: target_batch_size=8192 +14)--------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@1, s_suppkey@0)], projection=[ps_partkey@0, ps_availqty@2, ps_supplycost@3, s_nationkey@5] +15)----------------------------CoalesceBatchesExec: target_batch_size=8192 +16)------------------------------RepartitionExec: partitioning=Hash([ps_suppkey@1], 4), input_partitions=4 +17)--------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost], has_header=false +18)----------------------------CoalesceBatchesExec: target_batch_size=8192 +19)------------------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 +20)--------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +21)----------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], has_header=false +22)--------------------CoalesceBatchesExec: target_batch_size=8192 +23)----------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 +24)------------------------CoalesceBatchesExec: target_batch_size=8192 +25)--------------------------FilterExec: n_name@1 = GERMANY, projection=[n_nationkey@0] +26)----------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +27)------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false +28)------ProjectionExec: expr=[CAST(CAST(sum(partsupp.ps_supplycost * partsupp.ps_availqty)@0 AS Float64) * 0.0001 AS Decimal128(38, 15)) as sum(partsupp.ps_supplycost * partsupp.ps_availqty) * Float64(0.0001)] +29)--------AggregateExec: mode=Final, gby=[], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] +30)----------CoalescePartitionsExec +31)------------AggregateExec: mode=Partial, gby=[], aggr=[sum(partsupp.ps_supplycost * partsupp.ps_availqty)] +32)--------------CoalesceBatchesExec: target_batch_size=8192 +33)----------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@2, n_nationkey@0)], projection=[ps_availqty@0, ps_supplycost@1] +34)------------------CoalesceBatchesExec: target_batch_size=8192 +35)--------------------RepartitionExec: partitioning=Hash([s_nationkey@2], 4), input_partitions=4 +36)----------------------CoalesceBatchesExec: target_batch_size=8192 +37)------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@0, s_suppkey@0)], projection=[ps_availqty@1, ps_supplycost@2, s_nationkey@4] +38)--------------------------CoalesceBatchesExec: target_batch_size=8192 +39)----------------------------RepartitionExec: partitioning=Hash([ps_suppkey@0], 4), input_partitions=4 +40)------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_suppkey, ps_availqty, ps_supplycost], has_header=false +41)--------------------------CoalesceBatchesExec: target_batch_size=8192 +42)----------------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 +43)------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +44)--------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], has_header=false +45)------------------CoalesceBatchesExec: target_batch_size=8192 +46)--------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 +47)----------------------CoalesceBatchesExec: target_batch_size=8192 +48)------------------------FilterExec: n_name@1 = GERMANY, projection=[n_nationkey@0] +49)--------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +50)----------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/q12.slt.part b/datafusion/sqllogictest/test_files/tpch/q12.slt.part index 95cde1b0f105..b0d0baba90b0 100644 --- a/datafusion/sqllogictest/test_files/tpch/q12.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q12.slt.part @@ -50,33 +50,32 @@ order by ---- logical_plan 01)Sort: lineitem.l_shipmode ASC NULLS LAST -02)--Projection: lineitem.l_shipmode, SUM(CASE WHEN orders.o_orderpriority = Utf8("1-URGENT") OR orders.o_orderpriority = Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END) AS high_line_count, SUM(CASE WHEN orders.o_orderpriority != Utf8("1-URGENT") AND orders.o_orderpriority != Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END) AS low_line_count -03)----Aggregate: groupBy=[[lineitem.l_shipmode]], aggr=[[SUM(CASE WHEN orders.o_orderpriority = Utf8("1-URGENT") OR orders.o_orderpriority = Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END), SUM(CASE WHEN orders.o_orderpriority != Utf8("1-URGENT") AND orders.o_orderpriority != Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END)]] +02)--Projection: lineitem.l_shipmode, sum(CASE WHEN orders.o_orderpriority = Utf8("1-URGENT") OR orders.o_orderpriority = Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END) AS high_line_count, sum(CASE WHEN orders.o_orderpriority != Utf8("1-URGENT") AND orders.o_orderpriority != Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END) AS low_line_count +03)----Aggregate: groupBy=[[lineitem.l_shipmode]], aggr=[[sum(CASE WHEN orders.o_orderpriority = Utf8("1-URGENT") OR orders.o_orderpriority = Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END), sum(CASE WHEN orders.o_orderpriority != Utf8("1-URGENT") AND orders.o_orderpriority != Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END)]] 04)------Projection: lineitem.l_shipmode, orders.o_orderpriority 05)--------Inner Join: lineitem.l_orderkey = orders.o_orderkey 06)----------Projection: lineitem.l_orderkey, lineitem.l_shipmode -07)------------Filter: (lineitem.l_shipmode = Utf8("MAIL") OR lineitem.l_shipmode = Utf8("SHIP")) AND lineitem.l_receiptdate > lineitem.l_commitdate AND lineitem.l_shipdate < lineitem.l_commitdate AND lineitem.l_receiptdate >= Date32("8766") AND lineitem.l_receiptdate < Date32("9131") -08)--------------TableScan: lineitem projection=[l_orderkey, l_shipdate, l_commitdate, l_receiptdate, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8("MAIL") OR lineitem.l_shipmode = Utf8("SHIP"), lineitem.l_receiptdate > lineitem.l_commitdate, lineitem.l_shipdate < lineitem.l_commitdate, lineitem.l_receiptdate >= Date32("8766"), lineitem.l_receiptdate < Date32("9131")] +07)------------Filter: (lineitem.l_shipmode = Utf8("MAIL") OR lineitem.l_shipmode = Utf8("SHIP")) AND lineitem.l_receiptdate > lineitem.l_commitdate AND lineitem.l_shipdate < lineitem.l_commitdate AND lineitem.l_receiptdate >= Date32("1994-01-01") AND lineitem.l_receiptdate < Date32("1995-01-01") +08)--------------TableScan: lineitem projection=[l_orderkey, l_shipdate, l_commitdate, l_receiptdate, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8("MAIL") OR lineitem.l_shipmode = Utf8("SHIP"), lineitem.l_receiptdate > lineitem.l_commitdate, lineitem.l_shipdate < lineitem.l_commitdate, lineitem.l_receiptdate >= Date32("1994-01-01"), lineitem.l_receiptdate < Date32("1995-01-01")] 09)----------TableScan: orders projection=[o_orderkey, o_orderpriority] physical_plan 01)SortPreservingMergeExec: [l_shipmode@0 ASC NULLS LAST] -02)--SortExec: expr=[l_shipmode@0 ASC NULLS LAST] -03)----ProjectionExec: expr=[l_shipmode@0 as l_shipmode, SUM(CASE WHEN orders.o_orderpriority = Utf8("1-URGENT") OR orders.o_orderpriority = Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END)@1 as high_line_count, SUM(CASE WHEN orders.o_orderpriority != Utf8("1-URGENT") AND orders.o_orderpriority != Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END)@2 as low_line_count] -04)------AggregateExec: mode=FinalPartitioned, gby=[l_shipmode@0 as l_shipmode], aggr=[SUM(CASE WHEN orders.o_orderpriority = Utf8("1-URGENT") OR orders.o_orderpriority = Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END), SUM(CASE WHEN orders.o_orderpriority != Utf8("1-URGENT") AND orders.o_orderpriority != Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END)] +02)--SortExec: expr=[l_shipmode@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[l_shipmode@0 as l_shipmode, sum(CASE WHEN orders.o_orderpriority = Utf8("1-URGENT") OR orders.o_orderpriority = Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END)@1 as high_line_count, sum(CASE WHEN orders.o_orderpriority != Utf8("1-URGENT") AND orders.o_orderpriority != Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END)@2 as low_line_count] +04)------AggregateExec: mode=FinalPartitioned, gby=[l_shipmode@0 as l_shipmode], aggr=[sum(CASE WHEN orders.o_orderpriority = Utf8("1-URGENT") OR orders.o_orderpriority = Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END), sum(CASE WHEN orders.o_orderpriority != Utf8("1-URGENT") AND orders.o_orderpriority != Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END)] 05)--------CoalesceBatchesExec: target_batch_size=8192 06)----------RepartitionExec: partitioning=Hash([l_shipmode@0], 4), input_partitions=4 -07)------------AggregateExec: mode=Partial, gby=[l_shipmode@0 as l_shipmode], aggr=[SUM(CASE WHEN orders.o_orderpriority = Utf8("1-URGENT") OR orders.o_orderpriority = Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END), SUM(CASE WHEN orders.o_orderpriority != Utf8("1-URGENT") AND orders.o_orderpriority != Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END)] +07)------------AggregateExec: mode=Partial, gby=[l_shipmode@0 as l_shipmode], aggr=[sum(CASE WHEN orders.o_orderpriority = Utf8("1-URGENT") OR orders.o_orderpriority = Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END), sum(CASE WHEN orders.o_orderpriority != Utf8("1-URGENT") AND orders.o_orderpriority != Utf8("2-HIGH") THEN Int64(1) ELSE Int64(0) END)] 08)--------------CoalesceBatchesExec: target_batch_size=8192 09)----------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_orderkey@0, o_orderkey@0)], projection=[l_shipmode@1, o_orderpriority@3] 10)------------------CoalesceBatchesExec: target_batch_size=8192 11)--------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4 -12)----------------------ProjectionExec: expr=[l_orderkey@0 as l_orderkey, l_shipmode@4 as l_shipmode] -13)------------------------CoalesceBatchesExec: target_batch_size=8192 -14)--------------------------FilterExec: (l_shipmode@4 = MAIL OR l_shipmode@4 = SHIP) AND l_receiptdate@3 > l_commitdate@2 AND l_shipdate@1 < l_commitdate@2 AND l_receiptdate@3 >= 8766 AND l_receiptdate@3 < 9131 -15)----------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_shipdate, l_commitdate, l_receiptdate, l_shipmode], has_header=false -16)------------------CoalesceBatchesExec: target_batch_size=8192 -17)--------------------RepartitionExec: partitioning=Hash([o_orderkey@0], 4), input_partitions=4 -18)----------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_orderpriority], has_header=false +12)----------------------CoalesceBatchesExec: target_batch_size=8192 +13)------------------------FilterExec: (l_shipmode@4 = MAIL OR l_shipmode@4 = SHIP) AND l_receiptdate@3 > l_commitdate@2 AND l_shipdate@1 < l_commitdate@2 AND l_receiptdate@3 >= 1994-01-01 AND l_receiptdate@3 < 1995-01-01, projection=[l_orderkey@0, l_shipmode@4] +14)--------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_shipdate, l_commitdate, l_receiptdate, l_shipmode], has_header=false +15)------------------CoalesceBatchesExec: target_batch_size=8192 +16)--------------------RepartitionExec: partitioning=Hash([o_orderkey@0], 4), input_partitions=4 +17)----------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_orderpriority], has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/q13.slt.part b/datafusion/sqllogictest/test_files/tpch/q13.slt.part index cf4983fb0431..2a9fb12a31c2 100644 --- a/datafusion/sqllogictest/test_files/tpch/q13.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q13.slt.part @@ -40,42 +40,39 @@ order by limit 10; ---- logical_plan -01)Limit: skip=0, fetch=10 -02)--Sort: custdist DESC NULLS FIRST, c_orders.c_count DESC NULLS FIRST, fetch=10 -03)----Projection: c_orders.c_count, COUNT(*) AS custdist -04)------Aggregate: groupBy=[[c_orders.c_count]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] -05)--------SubqueryAlias: c_orders -06)----------Projection: COUNT(orders.o_orderkey) AS c_count -07)------------Aggregate: groupBy=[[customer.c_custkey]], aggr=[[COUNT(orders.o_orderkey)]] -08)--------------Projection: customer.c_custkey, orders.o_orderkey -09)----------------Left Join: customer.c_custkey = orders.o_custkey -10)------------------TableScan: customer projection=[c_custkey] -11)------------------Projection: orders.o_orderkey, orders.o_custkey -12)--------------------Filter: orders.o_comment NOT LIKE Utf8("%special%requests%") -13)----------------------TableScan: orders projection=[o_orderkey, o_custkey, o_comment], partial_filters=[orders.o_comment NOT LIKE Utf8("%special%requests%")] +01)Sort: custdist DESC NULLS FIRST, c_orders.c_count DESC NULLS FIRST, fetch=10 +02)--Projection: c_orders.c_count, count(*) AS custdist +03)----Aggregate: groupBy=[[c_orders.c_count]], aggr=[[count(Int64(1)) AS count(*)]] +04)------SubqueryAlias: c_orders +05)--------Projection: count(orders.o_orderkey) AS c_count +06)----------Aggregate: groupBy=[[customer.c_custkey]], aggr=[[count(orders.o_orderkey)]] +07)------------Projection: customer.c_custkey, orders.o_orderkey +08)--------------Left Join: customer.c_custkey = orders.o_custkey +09)----------------TableScan: customer projection=[c_custkey] +10)----------------Projection: orders.o_orderkey, orders.o_custkey +11)------------------Filter: orders.o_comment NOT LIKE Utf8("%special%requests%") +12)--------------------TableScan: orders projection=[o_orderkey, o_custkey, o_comment], partial_filters=[orders.o_comment NOT LIKE Utf8("%special%requests%")] physical_plan -01)GlobalLimitExec: skip=0, fetch=10 -02)--SortPreservingMergeExec: [custdist@1 DESC,c_count@0 DESC], fetch=10 -03)----SortExec: TopK(fetch=10), expr=[custdist@1 DESC,c_count@0 DESC] -04)------ProjectionExec: expr=[c_count@0 as c_count, COUNT(*)@1 as custdist] -05)--------AggregateExec: mode=FinalPartitioned, gby=[c_count@0 as c_count], aggr=[COUNT(*)] -06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------RepartitionExec: partitioning=Hash([c_count@0], 4), input_partitions=4 -08)--------------AggregateExec: mode=Partial, gby=[c_count@0 as c_count], aggr=[COUNT(*)] -09)----------------ProjectionExec: expr=[COUNT(orders.o_orderkey)@1 as c_count] -10)------------------AggregateExec: mode=SinglePartitioned, gby=[c_custkey@0 as c_custkey], aggr=[COUNT(orders.o_orderkey)] -11)--------------------CoalesceBatchesExec: target_batch_size=8192 -12)----------------------HashJoinExec: mode=Partitioned, join_type=Left, on=[(c_custkey@0, o_custkey@1)], projection=[c_custkey@0, o_orderkey@1] -13)------------------------CoalesceBatchesExec: target_batch_size=8192 -14)--------------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4 -15)----------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -16)------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey], has_header=false -17)------------------------CoalesceBatchesExec: target_batch_size=8192 -18)--------------------------RepartitionExec: partitioning=Hash([o_custkey@1], 4), input_partitions=4 -19)----------------------------ProjectionExec: expr=[o_orderkey@0 as o_orderkey, o_custkey@1 as o_custkey] -20)------------------------------CoalesceBatchesExec: target_batch_size=8192 -21)--------------------------------FilterExec: o_comment@2 NOT LIKE %special%requests% -22)----------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_custkey, o_comment], has_header=false +01)SortPreservingMergeExec: [custdist@1 DESC, c_count@0 DESC], fetch=10 +02)--SortExec: TopK(fetch=10), expr=[custdist@1 DESC, c_count@0 DESC], preserve_partitioning=[true] +03)----ProjectionExec: expr=[c_count@0 as c_count, count(*)@1 as custdist] +04)------AggregateExec: mode=FinalPartitioned, gby=[c_count@0 as c_count], aggr=[count(*)] +05)--------CoalesceBatchesExec: target_batch_size=8192 +06)----------RepartitionExec: partitioning=Hash([c_count@0], 4), input_partitions=4 +07)------------AggregateExec: mode=Partial, gby=[c_count@0 as c_count], aggr=[count(*)] +08)--------------ProjectionExec: expr=[count(orders.o_orderkey)@1 as c_count] +09)----------------AggregateExec: mode=SinglePartitioned, gby=[c_custkey@0 as c_custkey], aggr=[count(orders.o_orderkey)] +10)------------------CoalesceBatchesExec: target_batch_size=8192 +11)--------------------HashJoinExec: mode=Partitioned, join_type=Left, on=[(c_custkey@0, o_custkey@1)], projection=[c_custkey@0, o_orderkey@1] +12)----------------------CoalesceBatchesExec: target_batch_size=8192 +13)------------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4 +14)--------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +15)----------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey], has_header=false +16)----------------------CoalesceBatchesExec: target_batch_size=8192 +17)------------------------RepartitionExec: partitioning=Hash([o_custkey@1], 4), input_partitions=4 +18)--------------------------CoalesceBatchesExec: target_batch_size=8192 +19)----------------------------FilterExec: o_comment@2 NOT LIKE %special%requests%, projection=[o_orderkey@0, o_custkey@1] +20)------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_custkey, o_comment], has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/q14.slt.part b/datafusion/sqllogictest/test_files/tpch/q14.slt.part index 3d598fdfb63a..eee10cb3f8e2 100644 --- a/datafusion/sqllogictest/test_files/tpch/q14.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q14.slt.part @@ -32,31 +32,31 @@ where and l_shipdate < date '1995-10-01'; ---- logical_plan -01)Projection: Float64(100) * CAST(SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END) AS Float64) / CAST(SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS Float64) AS promo_revenue -02)--Aggregate: groupBy=[[]], aggr=[[SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) ELSE Decimal128(Some(0),38,4) END) AS SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), SUM(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] -03)----Projection: lineitem.l_extendedprice, lineitem.l_discount, part.p_type +01)Projection: Float64(100) * CAST(sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END) AS Float64) / CAST(sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS Float64) AS promo_revenue +02)--Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN __common_expr_1 ELSE Decimal128(Some(0),38,4) END) AS sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), sum(__common_expr_1) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] +03)----Projection: lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS __common_expr_1, part.p_type 04)------Inner Join: lineitem.l_partkey = part.p_partkey 05)--------Projection: lineitem.l_partkey, lineitem.l_extendedprice, lineitem.l_discount -06)----------Filter: lineitem.l_shipdate >= Date32("9374") AND lineitem.l_shipdate < Date32("9404") -07)------------TableScan: lineitem projection=[l_partkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("9374"), lineitem.l_shipdate < Date32("9404")] +06)----------Filter: lineitem.l_shipdate >= Date32("1995-09-01") AND lineitem.l_shipdate < Date32("1995-10-01") +07)------------TableScan: lineitem projection=[l_partkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("1995-09-01"), lineitem.l_shipdate < Date32("1995-10-01")] 08)--------TableScan: part projection=[p_partkey, p_type] physical_plan -01)ProjectionExec: expr=[100 * CAST(SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END)@0 AS Float64) / CAST(SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 AS Float64) as promo_revenue] -02)--AggregateExec: mode=Final, gby=[], aggr=[SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +01)ProjectionExec: expr=[100 * CAST(sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END)@0 AS Float64) / CAST(sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 AS Float64) as promo_revenue] +02)--AggregateExec: mode=Final, gby=[], aggr=[sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] 03)----CoalescePartitionsExec -04)------AggregateExec: mode=Partial, gby=[], aggr=[SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] -05)--------CoalesceBatchesExec: target_batch_size=8192 -06)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], projection=[l_extendedprice@1, l_discount@2, p_type@4] -07)------------CoalesceBatchesExec: target_batch_size=8192 -08)--------------RepartitionExec: partitioning=Hash([l_partkey@0], 4), input_partitions=4 -09)----------------ProjectionExec: expr=[l_partkey@0 as l_partkey, l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount] +04)------AggregateExec: mode=Partial, gby=[], aggr=[sum(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +05)--------ProjectionExec: expr=[l_extendedprice@0 * (Some(1),20,0 - l_discount@1) as __common_expr_1, p_type@2 as p_type] +06)----------CoalesceBatchesExec: target_batch_size=8192 +07)------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], projection=[l_extendedprice@1, l_discount@2, p_type@4] +08)--------------CoalesceBatchesExec: target_batch_size=8192 +09)----------------RepartitionExec: partitioning=Hash([l_partkey@0], 4), input_partitions=4 10)------------------CoalesceBatchesExec: target_batch_size=8192 -11)--------------------FilterExec: l_shipdate@3 >= 9374 AND l_shipdate@3 < 9404 +11)--------------------FilterExec: l_shipdate@3 >= 1995-09-01 AND l_shipdate@3 < 1995-10-01, projection=[l_partkey@0, l_extendedprice@1, l_discount@2] 12)----------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_partkey, l_extendedprice, l_discount, l_shipdate], has_header=false -13)------------CoalesceBatchesExec: target_batch_size=8192 -14)--------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 -15)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -16)------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_type], has_header=false +13)--------------CoalesceBatchesExec: target_batch_size=8192 +14)----------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 +15)------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +16)--------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_type], has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/q15.slt.part b/datafusion/sqllogictest/test_files/tpch/q15.slt.part index c01cc52ce4bc..2374fd8430a4 100644 --- a/datafusion/sqllogictest/test_files/tpch/q15.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q15.slt.part @@ -52,29 +52,29 @@ order by logical_plan 01)Sort: supplier.s_suppkey ASC NULLS LAST 02)--Projection: supplier.s_suppkey, supplier.s_name, supplier.s_address, supplier.s_phone, revenue0.total_revenue -03)----Inner Join: revenue0.total_revenue = __scalar_sq_1.MAX(revenue0.total_revenue) +03)----Inner Join: revenue0.total_revenue = __scalar_sq_1.max(revenue0.total_revenue) 04)------Projection: supplier.s_suppkey, supplier.s_name, supplier.s_address, supplier.s_phone, revenue0.total_revenue 05)--------Inner Join: supplier.s_suppkey = revenue0.supplier_no 06)----------TableScan: supplier projection=[s_suppkey, s_name, s_address, s_phone] 07)----------SubqueryAlias: revenue0 -08)------------Projection: lineitem.l_suppkey AS supplier_no, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS total_revenue -09)--------------Aggregate: groupBy=[[lineitem.l_suppkey]], aggr=[[SUM(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] +08)------------Projection: lineitem.l_suppkey AS supplier_no, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS total_revenue +09)--------------Aggregate: groupBy=[[lineitem.l_suppkey]], aggr=[[sum(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] 10)----------------Projection: lineitem.l_suppkey, lineitem.l_extendedprice, lineitem.l_discount -11)------------------Filter: lineitem.l_shipdate >= Date32("9496") AND lineitem.l_shipdate < Date32("9587") -12)--------------------TableScan: lineitem projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("9496"), lineitem.l_shipdate < Date32("9587")] +11)------------------Filter: lineitem.l_shipdate >= Date32("1996-01-01") AND lineitem.l_shipdate < Date32("1996-04-01") +12)--------------------TableScan: lineitem projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("1996-01-01"), lineitem.l_shipdate < Date32("1996-04-01")] 13)------SubqueryAlias: __scalar_sq_1 -14)--------Aggregate: groupBy=[[]], aggr=[[MAX(revenue0.total_revenue)]] +14)--------Aggregate: groupBy=[[]], aggr=[[max(revenue0.total_revenue)]] 15)----------SubqueryAlias: revenue0 -16)------------Projection: SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS total_revenue -17)--------------Aggregate: groupBy=[[lineitem.l_suppkey]], aggr=[[SUM(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] +16)------------Projection: sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS total_revenue +17)--------------Aggregate: groupBy=[[lineitem.l_suppkey]], aggr=[[sum(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] 18)----------------Projection: lineitem.l_suppkey, lineitem.l_extendedprice, lineitem.l_discount -19)------------------Filter: lineitem.l_shipdate >= Date32("9496") AND lineitem.l_shipdate < Date32("9587") -20)--------------------TableScan: lineitem projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("9496"), lineitem.l_shipdate < Date32("9587")] +19)------------------Filter: lineitem.l_shipdate >= Date32("1996-01-01") AND lineitem.l_shipdate < Date32("1996-04-01") +20)--------------------TableScan: lineitem projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("1996-01-01"), lineitem.l_shipdate < Date32("1996-04-01")] physical_plan 01)SortPreservingMergeExec: [s_suppkey@0 ASC NULLS LAST] -02)--SortExec: expr=[s_suppkey@0 ASC NULLS LAST] +02)--SortExec: expr=[s_suppkey@0 ASC NULLS LAST], preserve_partitioning=[true] 03)----CoalesceBatchesExec: target_batch_size=8192 -04)------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(total_revenue@4, MAX(revenue0.total_revenue)@0)], projection=[s_suppkey@0, s_name@1, s_address@2, s_phone@3, total_revenue@4] +04)------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(total_revenue@4, max(revenue0.total_revenue)@0)], projection=[s_suppkey@0, s_name@1, s_address@2, s_phone@3, total_revenue@4] 05)--------CoalesceBatchesExec: target_batch_size=8192 06)----------RepartitionExec: partitioning=Hash([total_revenue@4], 4), input_partitions=4 07)------------CoalesceBatchesExec: target_batch_size=8192 @@ -83,29 +83,27 @@ physical_plan 10)------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 11)--------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 12)----------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_name, s_address, s_phone], has_header=false -13)----------------ProjectionExec: expr=[l_suppkey@0 as supplier_no, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 as total_revenue] -14)------------------AggregateExec: mode=FinalPartitioned, gby=[l_suppkey@0 as l_suppkey], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +13)----------------ProjectionExec: expr=[l_suppkey@0 as supplier_no, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 as total_revenue] +14)------------------AggregateExec: mode=FinalPartitioned, gby=[l_suppkey@0 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] 15)--------------------CoalesceBatchesExec: target_batch_size=8192 16)----------------------RepartitionExec: partitioning=Hash([l_suppkey@0], 4), input_partitions=4 -17)------------------------AggregateExec: mode=Partial, gby=[l_suppkey@0 as l_suppkey], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] -18)--------------------------ProjectionExec: expr=[l_suppkey@0 as l_suppkey, l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount] -19)----------------------------CoalesceBatchesExec: target_batch_size=8192 -20)------------------------------FilterExec: l_shipdate@3 >= 9496 AND l_shipdate@3 < 9587 -21)--------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], has_header=false -22)--------CoalesceBatchesExec: target_batch_size=8192 -23)----------RepartitionExec: partitioning=Hash([MAX(revenue0.total_revenue)@0], 4), input_partitions=1 -24)------------AggregateExec: mode=Final, gby=[], aggr=[MAX(revenue0.total_revenue)] -25)--------------CoalescePartitionsExec -26)----------------AggregateExec: mode=Partial, gby=[], aggr=[MAX(revenue0.total_revenue)] -27)------------------ProjectionExec: expr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 as total_revenue] -28)--------------------AggregateExec: mode=FinalPartitioned, gby=[l_suppkey@0 as l_suppkey], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] -29)----------------------CoalesceBatchesExec: target_batch_size=8192 -30)------------------------RepartitionExec: partitioning=Hash([l_suppkey@0], 4), input_partitions=4 -31)--------------------------AggregateExec: mode=Partial, gby=[l_suppkey@0 as l_suppkey], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] -32)----------------------------ProjectionExec: expr=[l_suppkey@0 as l_suppkey, l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount] -33)------------------------------CoalesceBatchesExec: target_batch_size=8192 -34)--------------------------------FilterExec: l_shipdate@3 >= 9496 AND l_shipdate@3 < 9587 -35)----------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], has_header=false +17)------------------------AggregateExec: mode=Partial, gby=[l_suppkey@0 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +18)--------------------------CoalesceBatchesExec: target_batch_size=8192 +19)----------------------------FilterExec: l_shipdate@3 >= 1996-01-01 AND l_shipdate@3 < 1996-04-01, projection=[l_suppkey@0, l_extendedprice@1, l_discount@2] +20)------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], has_header=false +21)--------CoalesceBatchesExec: target_batch_size=8192 +22)----------RepartitionExec: partitioning=Hash([max(revenue0.total_revenue)@0], 4), input_partitions=1 +23)------------AggregateExec: mode=Final, gby=[], aggr=[max(revenue0.total_revenue)] +24)--------------CoalescePartitionsExec +25)----------------AggregateExec: mode=Partial, gby=[], aggr=[max(revenue0.total_revenue)] +26)------------------ProjectionExec: expr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 as total_revenue] +27)--------------------AggregateExec: mode=FinalPartitioned, gby=[l_suppkey@0 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +28)----------------------CoalesceBatchesExec: target_batch_size=8192 +29)------------------------RepartitionExec: partitioning=Hash([l_suppkey@0], 4), input_partitions=4 +30)--------------------------AggregateExec: mode=Partial, gby=[l_suppkey@0 as l_suppkey], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +31)----------------------------CoalesceBatchesExec: target_batch_size=8192 +32)------------------------------FilterExec: l_shipdate@3 >= 1996-01-01 AND l_shipdate@3 < 1996-04-01, projection=[l_suppkey@0, l_extendedprice@1, l_discount@2] +33)--------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate], has_header=false query ITTTR with revenue0 (supplier_no, total_revenue) as ( diff --git a/datafusion/sqllogictest/test_files/tpch/q16.slt.part b/datafusion/sqllogictest/test_files/tpch/q16.slt.part index 16f808765228..6b2c2f7fdc3e 100644 --- a/datafusion/sqllogictest/test_files/tpch/q16.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q16.slt.part @@ -50,56 +50,53 @@ order by limit 10; ---- logical_plan -01)Limit: skip=0, fetch=10 -02)--Sort: supplier_cnt DESC NULLS FIRST, part.p_brand ASC NULLS LAST, part.p_type ASC NULLS LAST, part.p_size ASC NULLS LAST, fetch=10 -03)----Projection: part.p_brand, part.p_type, part.p_size, COUNT(alias1) AS supplier_cnt -04)------Aggregate: groupBy=[[part.p_brand, part.p_type, part.p_size]], aggr=[[COUNT(alias1)]] -05)--------Aggregate: groupBy=[[part.p_brand, part.p_type, part.p_size, partsupp.ps_suppkey AS alias1]], aggr=[[]] -06)----------LeftAnti Join: partsupp.ps_suppkey = __correlated_sq_1.s_suppkey -07)------------Projection: partsupp.ps_suppkey, part.p_brand, part.p_type, part.p_size -08)--------------Inner Join: partsupp.ps_partkey = part.p_partkey -09)----------------TableScan: partsupp projection=[ps_partkey, ps_suppkey] -10)----------------Filter: part.p_brand != Utf8("Brand#45") AND part.p_type NOT LIKE Utf8("MEDIUM POLISHED%") AND part.p_size IN ([Int32(49), Int32(14), Int32(23), Int32(45), Int32(19), Int32(3), Int32(36), Int32(9)]) -11)------------------TableScan: part projection=[p_partkey, p_brand, p_type, p_size], partial_filters=[part.p_brand != Utf8("Brand#45"), part.p_type NOT LIKE Utf8("MEDIUM POLISHED%"), part.p_size IN ([Int32(49), Int32(14), Int32(23), Int32(45), Int32(19), Int32(3), Int32(36), Int32(9)])] -12)------------SubqueryAlias: __correlated_sq_1 -13)--------------Projection: supplier.s_suppkey -14)----------------Filter: supplier.s_comment LIKE Utf8("%Customer%Complaints%") -15)------------------TableScan: supplier projection=[s_suppkey, s_comment], partial_filters=[supplier.s_comment LIKE Utf8("%Customer%Complaints%")] +01)Sort: supplier_cnt DESC NULLS FIRST, part.p_brand ASC NULLS LAST, part.p_type ASC NULLS LAST, part.p_size ASC NULLS LAST, fetch=10 +02)--Projection: part.p_brand, part.p_type, part.p_size, count(alias1) AS supplier_cnt +03)----Aggregate: groupBy=[[part.p_brand, part.p_type, part.p_size]], aggr=[[count(alias1)]] +04)------Aggregate: groupBy=[[part.p_brand, part.p_type, part.p_size, partsupp.ps_suppkey AS alias1]], aggr=[[]] +05)--------LeftAnti Join: partsupp.ps_suppkey = __correlated_sq_1.s_suppkey +06)----------Projection: partsupp.ps_suppkey, part.p_brand, part.p_type, part.p_size +07)------------Inner Join: partsupp.ps_partkey = part.p_partkey +08)--------------TableScan: partsupp projection=[ps_partkey, ps_suppkey] +09)--------------Filter: part.p_brand != Utf8("Brand#45") AND part.p_type NOT LIKE Utf8("MEDIUM POLISHED%") AND part.p_size IN ([Int32(49), Int32(14), Int32(23), Int32(45), Int32(19), Int32(3), Int32(36), Int32(9)]) +10)----------------TableScan: part projection=[p_partkey, p_brand, p_type, p_size], partial_filters=[part.p_brand != Utf8("Brand#45"), part.p_type NOT LIKE Utf8("MEDIUM POLISHED%"), part.p_size IN ([Int32(49), Int32(14), Int32(23), Int32(45), Int32(19), Int32(3), Int32(36), Int32(9)])] +11)----------SubqueryAlias: __correlated_sq_1 +12)------------Projection: supplier.s_suppkey +13)--------------Filter: supplier.s_comment LIKE Utf8("%Customer%Complaints%") +14)----------------TableScan: supplier projection=[s_suppkey, s_comment], partial_filters=[supplier.s_comment LIKE Utf8("%Customer%Complaints%")] physical_plan -01)GlobalLimitExec: skip=0, fetch=10 -02)--SortPreservingMergeExec: [supplier_cnt@3 DESC,p_brand@0 ASC NULLS LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST], fetch=10 -03)----SortExec: TopK(fetch=10), expr=[supplier_cnt@3 DESC,p_brand@0 ASC NULLS LAST,p_type@1 ASC NULLS LAST,p_size@2 ASC NULLS LAST] -04)------ProjectionExec: expr=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, COUNT(alias1)@3 as supplier_cnt] -05)--------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[COUNT(alias1)] -06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------RepartitionExec: partitioning=Hash([p_brand@0, p_type@1, p_size@2], 4), input_partitions=4 -08)--------------AggregateExec: mode=Partial, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[COUNT(alias1)] -09)----------------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, alias1@3 as alias1], aggr=[] -10)------------------CoalesceBatchesExec: target_batch_size=8192 -11)--------------------RepartitionExec: partitioning=Hash([p_brand@0, p_type@1, p_size@2, alias1@3], 4), input_partitions=4 -12)----------------------AggregateExec: mode=Partial, gby=[p_brand@1 as p_brand, p_type@2 as p_type, p_size@3 as p_size, ps_suppkey@0 as alias1], aggr=[] -13)------------------------CoalesceBatchesExec: target_batch_size=8192 -14)--------------------------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(ps_suppkey@0, s_suppkey@0)] -15)----------------------------CoalesceBatchesExec: target_batch_size=8192 -16)------------------------------RepartitionExec: partitioning=Hash([ps_suppkey@0], 4), input_partitions=4 -17)--------------------------------CoalesceBatchesExec: target_batch_size=8192 -18)----------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_partkey@0, p_partkey@0)], projection=[ps_suppkey@1, p_brand@3, p_type@4, p_size@5] -19)------------------------------------CoalesceBatchesExec: target_batch_size=8192 -20)--------------------------------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4 -21)----------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey], has_header=false -22)------------------------------------CoalesceBatchesExec: target_batch_size=8192 -23)--------------------------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 -24)----------------------------------------CoalesceBatchesExec: target_batch_size=8192 -25)------------------------------------------FilterExec: p_brand@1 != Brand#45 AND p_type@2 NOT LIKE MEDIUM POLISHED% AND Use p_size@3 IN (SET) ([Literal { value: Int32(49) }, Literal { value: Int32(14) }, Literal { value: Int32(23) }, Literal { value: Int32(45) }, Literal { value: Int32(19) }, Literal { value: Int32(3) }, Literal { value: Int32(36) }, Literal { value: Int32(9) }]) -26)--------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -27)----------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_type, p_size], has_header=false -28)----------------------------CoalesceBatchesExec: target_batch_size=8192 -29)------------------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 -30)--------------------------------ProjectionExec: expr=[s_suppkey@0 as s_suppkey] -31)----------------------------------CoalesceBatchesExec: target_batch_size=8192 -32)------------------------------------FilterExec: s_comment@1 LIKE %Customer%Complaints% -33)--------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -34)----------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_comment], has_header=false +01)SortPreservingMergeExec: [supplier_cnt@3 DESC, p_brand@0 ASC NULLS LAST, p_type@1 ASC NULLS LAST, p_size@2 ASC NULLS LAST], fetch=10 +02)--SortExec: TopK(fetch=10), expr=[supplier_cnt@3 DESC, p_brand@0 ASC NULLS LAST, p_type@1 ASC NULLS LAST, p_size@2 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, count(alias1)@3 as supplier_cnt] +04)------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[count(alias1)] +05)--------CoalesceBatchesExec: target_batch_size=8192 +06)----------RepartitionExec: partitioning=Hash([p_brand@0, p_type@1, p_size@2], 4), input_partitions=4 +07)------------AggregateExec: mode=Partial, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size], aggr=[count(alias1)] +08)--------------AggregateExec: mode=FinalPartitioned, gby=[p_brand@0 as p_brand, p_type@1 as p_type, p_size@2 as p_size, alias1@3 as alias1], aggr=[] +09)----------------CoalesceBatchesExec: target_batch_size=8192 +10)------------------RepartitionExec: partitioning=Hash([p_brand@0, p_type@1, p_size@2, alias1@3], 4), input_partitions=4 +11)--------------------AggregateExec: mode=Partial, gby=[p_brand@1 as p_brand, p_type@2 as p_type, p_size@3 as p_size, ps_suppkey@0 as alias1], aggr=[] +12)----------------------CoalesceBatchesExec: target_batch_size=8192 +13)------------------------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(ps_suppkey@0, s_suppkey@0)] +14)--------------------------CoalesceBatchesExec: target_batch_size=8192 +15)----------------------------RepartitionExec: partitioning=Hash([ps_suppkey@0], 4), input_partitions=4 +16)------------------------------CoalesceBatchesExec: target_batch_size=8192 +17)--------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_partkey@0, p_partkey@0)], projection=[ps_suppkey@1, p_brand@3, p_type@4, p_size@5] +18)----------------------------------CoalesceBatchesExec: target_batch_size=8192 +19)------------------------------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4 +20)--------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey], has_header=false +21)----------------------------------CoalesceBatchesExec: target_batch_size=8192 +22)------------------------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 +23)--------------------------------------CoalesceBatchesExec: target_batch_size=8192 +24)----------------------------------------FilterExec: p_brand@1 != Brand#45 AND p_type@2 NOT LIKE MEDIUM POLISHED% AND Use p_size@3 IN (SET) ([Literal { value: Int32(49) }, Literal { value: Int32(14) }, Literal { value: Int32(23) }, Literal { value: Int32(45) }, Literal { value: Int32(19) }, Literal { value: Int32(3) }, Literal { value: Int32(36) }, Literal { value: Int32(9) }]) +25)------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +26)--------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_type, p_size], has_header=false +27)--------------------------CoalesceBatchesExec: target_batch_size=8192 +28)----------------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 +29)------------------------------CoalesceBatchesExec: target_batch_size=8192 +30)--------------------------------FilterExec: s_comment@1 LIKE %Customer%Complaints%, projection=[s_suppkey@0] +31)----------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +32)------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_comment], has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/q17.slt.part b/datafusion/sqllogictest/test_files/tpch/q17.slt.part index 19fd2375f66c..058bcb5f4962 100644 --- a/datafusion/sqllogictest/test_files/tpch/q17.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q17.slt.part @@ -36,10 +36,10 @@ where ); ---- logical_plan -01)Projection: CAST(SUM(lineitem.l_extendedprice) AS Float64) / Float64(7) AS avg_yearly -02)--Aggregate: groupBy=[[]], aggr=[[SUM(lineitem.l_extendedprice)]] +01)Projection: CAST(sum(lineitem.l_extendedprice) AS Float64) / Float64(7) AS avg_yearly +02)--Aggregate: groupBy=[[]], aggr=[[sum(lineitem.l_extendedprice)]] 03)----Projection: lineitem.l_extendedprice -04)------Inner Join: part.p_partkey = __scalar_sq_1.l_partkey Filter: CAST(lineitem.l_quantity AS Decimal128(30, 15)) < __scalar_sq_1.Float64(0.2) * AVG(lineitem.l_quantity) +04)------Inner Join: part.p_partkey = __scalar_sq_1.l_partkey Filter: CAST(lineitem.l_quantity AS Decimal128(30, 15)) < __scalar_sq_1.Float64(0.2) * avg(lineitem.l_quantity) 05)--------Projection: lineitem.l_quantity, lineitem.l_extendedprice, part.p_partkey 06)----------Inner Join: lineitem.l_partkey = part.p_partkey 07)------------TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice] @@ -47,16 +47,16 @@ logical_plan 09)--------------Filter: part.p_brand = Utf8("Brand#23") AND part.p_container = Utf8("MED BOX") 10)----------------TableScan: part projection=[p_partkey, p_brand, p_container], partial_filters=[part.p_brand = Utf8("Brand#23"), part.p_container = Utf8("MED BOX")] 11)--------SubqueryAlias: __scalar_sq_1 -12)----------Projection: CAST(Float64(0.2) * CAST(AVG(lineitem.l_quantity) AS Float64) AS Decimal128(30, 15)), lineitem.l_partkey -13)------------Aggregate: groupBy=[[lineitem.l_partkey]], aggr=[[AVG(lineitem.l_quantity)]] +12)----------Projection: CAST(Float64(0.2) * CAST(avg(lineitem.l_quantity) AS Float64) AS Decimal128(30, 15)), lineitem.l_partkey +13)------------Aggregate: groupBy=[[lineitem.l_partkey]], aggr=[[avg(lineitem.l_quantity)]] 14)--------------TableScan: lineitem projection=[l_partkey, l_quantity] physical_plan -01)ProjectionExec: expr=[CAST(SUM(lineitem.l_extendedprice)@0 AS Float64) / 7 as avg_yearly] -02)--AggregateExec: mode=Final, gby=[], aggr=[SUM(lineitem.l_extendedprice)] +01)ProjectionExec: expr=[CAST(sum(lineitem.l_extendedprice)@0 AS Float64) / 7 as avg_yearly] +02)--AggregateExec: mode=Final, gby=[], aggr=[sum(lineitem.l_extendedprice)] 03)----CoalescePartitionsExec -04)------AggregateExec: mode=Partial, gby=[], aggr=[SUM(lineitem.l_extendedprice)] +04)------AggregateExec: mode=Partial, gby=[], aggr=[sum(lineitem.l_extendedprice)] 05)--------CoalesceBatchesExec: target_batch_size=8192 -06)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(p_partkey@2, l_partkey@1)], filter=CAST(l_quantity@0 AS Decimal128(30, 15)) < Float64(0.2) * AVG(lineitem.l_quantity)@1, projection=[l_extendedprice@1] +06)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(p_partkey@2, l_partkey@1)], filter=CAST(l_quantity@0 AS Decimal128(30, 15)) < Float64(0.2) * avg(lineitem.l_quantity)@1, projection=[l_extendedprice@1] 07)------------CoalesceBatchesExec: target_batch_size=8192 08)--------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], projection=[l_quantity@1, l_extendedprice@2, p_partkey@3] 09)----------------CoalesceBatchesExec: target_batch_size=8192 @@ -64,17 +64,16 @@ physical_plan 11)--------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_partkey, l_quantity, l_extendedprice], has_header=false 12)----------------CoalesceBatchesExec: target_batch_size=8192 13)------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 -14)--------------------ProjectionExec: expr=[p_partkey@0 as p_partkey] -15)----------------------CoalesceBatchesExec: target_batch_size=8192 -16)------------------------FilterExec: p_brand@1 = Brand#23 AND p_container@2 = MED BOX -17)--------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -18)----------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_container], has_header=false -19)------------ProjectionExec: expr=[CAST(0.2 * CAST(AVG(lineitem.l_quantity)@1 AS Float64) AS Decimal128(30, 15)) as Float64(0.2) * AVG(lineitem.l_quantity), l_partkey@0 as l_partkey] -20)--------------AggregateExec: mode=FinalPartitioned, gby=[l_partkey@0 as l_partkey], aggr=[AVG(lineitem.l_quantity)] -21)----------------CoalesceBatchesExec: target_batch_size=8192 -22)------------------RepartitionExec: partitioning=Hash([l_partkey@0], 4), input_partitions=4 -23)--------------------AggregateExec: mode=Partial, gby=[l_partkey@0 as l_partkey], aggr=[AVG(lineitem.l_quantity)] -24)----------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_partkey, l_quantity], has_header=false +14)--------------------CoalesceBatchesExec: target_batch_size=8192 +15)----------------------FilterExec: p_brand@1 = Brand#23 AND p_container@2 = MED BOX, projection=[p_partkey@0] +16)------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +17)--------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_container], has_header=false +18)------------ProjectionExec: expr=[CAST(0.2 * CAST(avg(lineitem.l_quantity)@1 AS Float64) AS Decimal128(30, 15)) as Float64(0.2) * avg(lineitem.l_quantity), l_partkey@0 as l_partkey] +19)--------------AggregateExec: mode=FinalPartitioned, gby=[l_partkey@0 as l_partkey], aggr=[avg(lineitem.l_quantity)] +20)----------------CoalesceBatchesExec: target_batch_size=8192 +21)------------------RepartitionExec: partitioning=Hash([l_partkey@0], 4), input_partitions=4 +22)--------------------AggregateExec: mode=Partial, gby=[l_partkey@0 as l_partkey], aggr=[avg(lineitem.l_quantity)] +23)----------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_partkey, l_quantity], has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/q18.slt.part b/datafusion/sqllogictest/test_files/tpch/q18.slt.part index 28ff0e07f363..c80352c5d36a 100644 --- a/datafusion/sqllogictest/test_files/tpch/q18.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q18.slt.part @@ -52,7 +52,7 @@ order by ---- logical_plan 01)Sort: orders.o_totalprice DESC NULLS FIRST, orders.o_orderdate ASC NULLS LAST -02)--Aggregate: groupBy=[[customer.c_name, customer.c_custkey, orders.o_orderkey, orders.o_orderdate, orders.o_totalprice]], aggr=[[SUM(lineitem.l_quantity)]] +02)--Aggregate: groupBy=[[customer.c_name, customer.c_custkey, orders.o_orderkey, orders.o_orderdate, orders.o_totalprice]], aggr=[[sum(lineitem.l_quantity)]] 03)----LeftSemi Join: orders.o_orderkey = __correlated_sq_1.l_orderkey 04)------Projection: customer.c_custkey, customer.c_name, orders.o_orderkey, orders.o_totalprice, orders.o_orderdate, lineitem.l_quantity 05)--------Inner Join: orders.o_orderkey = lineitem.l_orderkey @@ -63,16 +63,16 @@ logical_plan 10)----------TableScan: lineitem projection=[l_orderkey, l_quantity] 11)------SubqueryAlias: __correlated_sq_1 12)--------Projection: lineitem.l_orderkey -13)----------Filter: SUM(lineitem.l_quantity) > Decimal128(Some(30000),25,2) -14)------------Aggregate: groupBy=[[lineitem.l_orderkey]], aggr=[[SUM(lineitem.l_quantity)]] +13)----------Filter: sum(lineitem.l_quantity) > Decimal128(Some(30000),25,2) +14)------------Aggregate: groupBy=[[lineitem.l_orderkey]], aggr=[[sum(lineitem.l_quantity)]] 15)--------------TableScan: lineitem projection=[l_orderkey, l_quantity] physical_plan -01)SortPreservingMergeExec: [o_totalprice@4 DESC,o_orderdate@3 ASC NULLS LAST] -02)--SortExec: expr=[o_totalprice@4 DESC,o_orderdate@3 ASC NULLS LAST] -03)----AggregateExec: mode=FinalPartitioned, gby=[c_name@0 as c_name, c_custkey@1 as c_custkey, o_orderkey@2 as o_orderkey, o_orderdate@3 as o_orderdate, o_totalprice@4 as o_totalprice], aggr=[SUM(lineitem.l_quantity)] +01)SortPreservingMergeExec: [o_totalprice@4 DESC, o_orderdate@3 ASC NULLS LAST] +02)--SortExec: expr=[o_totalprice@4 DESC, o_orderdate@3 ASC NULLS LAST], preserve_partitioning=[true] +03)----AggregateExec: mode=FinalPartitioned, gby=[c_name@0 as c_name, c_custkey@1 as c_custkey, o_orderkey@2 as o_orderkey, o_orderdate@3 as o_orderdate, o_totalprice@4 as o_totalprice], aggr=[sum(lineitem.l_quantity)] 04)------CoalesceBatchesExec: target_batch_size=8192 05)--------RepartitionExec: partitioning=Hash([c_name@0, c_custkey@1, o_orderkey@2, o_orderdate@3, o_totalprice@4], 4), input_partitions=4 -06)----------AggregateExec: mode=Partial, gby=[c_name@1 as c_name, c_custkey@0 as c_custkey, o_orderkey@2 as o_orderkey, o_orderdate@4 as o_orderdate, o_totalprice@3 as o_totalprice], aggr=[SUM(lineitem.l_quantity)] +06)----------AggregateExec: mode=Partial, gby=[c_name@1 as c_name, c_custkey@0 as c_custkey, o_orderkey@2 as o_orderkey, o_orderdate@4 as o_orderdate, o_totalprice@3 as o_totalprice], aggr=[sum(lineitem.l_quantity)] 07)------------CoalesceBatchesExec: target_batch_size=8192 08)--------------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(o_orderkey@2, l_orderkey@0)] 09)----------------CoalesceBatchesExec: target_batch_size=8192 @@ -91,14 +91,13 @@ physical_plan 22)--------------------CoalesceBatchesExec: target_batch_size=8192 23)----------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4 24)------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_quantity], has_header=false -25)----------------ProjectionExec: expr=[l_orderkey@0 as l_orderkey] -26)------------------CoalesceBatchesExec: target_batch_size=8192 -27)--------------------FilterExec: SUM(lineitem.l_quantity)@1 > Some(30000),25,2 -28)----------------------AggregateExec: mode=FinalPartitioned, gby=[l_orderkey@0 as l_orderkey], aggr=[SUM(lineitem.l_quantity)] -29)------------------------CoalesceBatchesExec: target_batch_size=8192 -30)--------------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4 -31)----------------------------AggregateExec: mode=Partial, gby=[l_orderkey@0 as l_orderkey], aggr=[SUM(lineitem.l_quantity)] -32)------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_quantity], has_header=false +25)----------------CoalesceBatchesExec: target_batch_size=8192 +26)------------------FilterExec: sum(lineitem.l_quantity)@1 > Some(30000),25,2, projection=[l_orderkey@0] +27)--------------------AggregateExec: mode=FinalPartitioned, gby=[l_orderkey@0 as l_orderkey], aggr=[sum(lineitem.l_quantity)] +28)----------------------CoalesceBatchesExec: target_batch_size=8192 +29)------------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4 +30)--------------------------AggregateExec: mode=Partial, gby=[l_orderkey@0 as l_orderkey], aggr=[sum(lineitem.l_quantity)] +31)----------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_quantity], has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/q19.slt.part b/datafusion/sqllogictest/test_files/tpch/q19.slt.part index 9a49fc424eb6..70465ea065a1 100644 --- a/datafusion/sqllogictest/test_files/tpch/q19.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q19.slt.part @@ -54,8 +54,8 @@ where ); ---- logical_plan -01)Projection: SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue -02)--Aggregate: groupBy=[[]], aggr=[[SUM(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] +01)Projection: sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue +02)--Aggregate: groupBy=[[]], aggr=[[sum(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] 03)----Projection: lineitem.l_extendedprice, lineitem.l_discount 04)------Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) 05)--------Projection: lineitem.l_partkey, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount @@ -64,24 +64,23 @@ logical_plan 08)--------Filter: (part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) 09)----------TableScan: part projection=[p_partkey, p_brand, p_size, p_container], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND part.p_size <= Int32(15)] physical_plan -01)ProjectionExec: expr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@0 as revenue] -02)--AggregateExec: mode=Final, gby=[], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +01)ProjectionExec: expr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@0 as revenue] +02)--AggregateExec: mode=Final, gby=[], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] 03)----CoalescePartitionsExec -04)------AggregateExec: mode=Partial, gby=[], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +04)------AggregateExec: mode=Partial, gby=[], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] 05)--------CoalesceBatchesExec: target_batch_size=8192 06)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], filter=p_brand@1 = Brand#12 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("SM CASE") }, Literal { value: Utf8("SM BOX") }, Literal { value: Utf8("SM PACK") }, Literal { value: Utf8("SM PKG") }]) AND l_quantity@0 >= Some(100),15,2 AND l_quantity@0 <= Some(1100),15,2 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("MED BAG") }, Literal { value: Utf8("MED BOX") }, Literal { value: Utf8("MED PKG") }, Literal { value: Utf8("MED PACK") }]) AND l_quantity@0 >= Some(1000),15,2 AND l_quantity@0 <= Some(2000),15,2 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("LG CASE") }, Literal { value: Utf8("LG BOX") }, Literal { value: Utf8("LG PACK") }, Literal { value: Utf8("LG PKG") }]) AND l_quantity@0 >= Some(2000),15,2 AND l_quantity@0 <= Some(3000),15,2 AND p_size@2 <= 15, projection=[l_extendedprice@2, l_discount@3] 07)------------CoalesceBatchesExec: target_batch_size=8192 08)--------------RepartitionExec: partitioning=Hash([l_partkey@0], 4), input_partitions=4 -09)----------------ProjectionExec: expr=[l_partkey@0 as l_partkey, l_quantity@1 as l_quantity, l_extendedprice@2 as l_extendedprice, l_discount@3 as l_discount] -10)------------------CoalesceBatchesExec: target_batch_size=8192 -11)--------------------FilterExec: (l_quantity@1 >= Some(100),15,2 AND l_quantity@1 <= Some(1100),15,2 OR l_quantity@1 >= Some(1000),15,2 AND l_quantity@1 <= Some(2000),15,2 OR l_quantity@1 >= Some(2000),15,2 AND l_quantity@1 <= Some(3000),15,2) AND (l_shipmode@5 = AIR OR l_shipmode@5 = AIR REG) AND l_shipinstruct@4 = DELIVER IN PERSON -12)----------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode], has_header=false -13)------------CoalesceBatchesExec: target_batch_size=8192 -14)--------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 -15)----------------CoalesceBatchesExec: target_batch_size=8192 -16)------------------FilterExec: (p_brand@1 = Brand#12 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("SM CASE") }, Literal { value: Utf8("SM BOX") }, Literal { value: Utf8("SM PACK") }, Literal { value: Utf8("SM PKG") }]) AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("MED BAG") }, Literal { value: Utf8("MED BOX") }, Literal { value: Utf8("MED PKG") }, Literal { value: Utf8("MED PACK") }]) AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("LG CASE") }, Literal { value: Utf8("LG BOX") }, Literal { value: Utf8("LG PACK") }, Literal { value: Utf8("LG PKG") }]) AND p_size@2 <= 15) AND p_size@2 >= 1 -17)--------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -18)----------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_size, p_container], has_header=false +09)----------------CoalesceBatchesExec: target_batch_size=8192 +10)------------------FilterExec: (l_quantity@1 >= Some(100),15,2 AND l_quantity@1 <= Some(1100),15,2 OR l_quantity@1 >= Some(1000),15,2 AND l_quantity@1 <= Some(2000),15,2 OR l_quantity@1 >= Some(2000),15,2 AND l_quantity@1 <= Some(3000),15,2) AND (l_shipmode@5 = AIR OR l_shipmode@5 = AIR REG) AND l_shipinstruct@4 = DELIVER IN PERSON, projection=[l_partkey@0, l_quantity@1, l_extendedprice@2, l_discount@3] +11)--------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode], has_header=false +12)------------CoalesceBatchesExec: target_batch_size=8192 +13)--------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 +14)----------------CoalesceBatchesExec: target_batch_size=8192 +15)------------------FilterExec: (p_brand@1 = Brand#12 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("SM CASE") }, Literal { value: Utf8("SM BOX") }, Literal { value: Utf8("SM PACK") }, Literal { value: Utf8("SM PKG") }]) AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("MED BAG") }, Literal { value: Utf8("MED BOX") }, Literal { value: Utf8("MED PKG") }, Literal { value: Utf8("MED PACK") }]) AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND Use p_container@3 IN (SET) ([Literal { value: Utf8("LG CASE") }, Literal { value: Utf8("LG BOX") }, Literal { value: Utf8("LG PACK") }, Literal { value: Utf8("LG PKG") }]) AND p_size@2 <= 15) AND p_size@2 >= 1 +16)--------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +17)----------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_size, p_container], has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/q2.slt.part b/datafusion/sqllogictest/test_files/tpch/q2.slt.part index 088f96f32928..23ffa0d226b8 100644 --- a/datafusion/sqllogictest/test_files/tpch/q2.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q2.slt.part @@ -63,126 +63,121 @@ order by limit 10; ---- logical_plan -01)Limit: skip=0, fetch=10 -02)--Sort: supplier.s_acctbal DESC NULLS FIRST, nation.n_name ASC NULLS LAST, supplier.s_name ASC NULLS LAST, part.p_partkey ASC NULLS LAST, fetch=10 -03)----Projection: supplier.s_acctbal, supplier.s_name, nation.n_name, part.p_partkey, part.p_mfgr, supplier.s_address, supplier.s_phone, supplier.s_comment -04)------Inner Join: part.p_partkey = __scalar_sq_1.ps_partkey, partsupp.ps_supplycost = __scalar_sq_1.MIN(partsupp.ps_supplycost) -05)--------Projection: part.p_partkey, part.p_mfgr, supplier.s_name, supplier.s_address, supplier.s_phone, supplier.s_acctbal, supplier.s_comment, partsupp.ps_supplycost, nation.n_name -06)----------Inner Join: nation.n_regionkey = region.r_regionkey -07)------------Projection: part.p_partkey, part.p_mfgr, supplier.s_name, supplier.s_address, supplier.s_phone, supplier.s_acctbal, supplier.s_comment, partsupp.ps_supplycost, nation.n_name, nation.n_regionkey -08)--------------Inner Join: supplier.s_nationkey = nation.n_nationkey -09)----------------Projection: part.p_partkey, part.p_mfgr, supplier.s_name, supplier.s_address, supplier.s_nationkey, supplier.s_phone, supplier.s_acctbal, supplier.s_comment, partsupp.ps_supplycost -10)------------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey -11)--------------------Projection: part.p_partkey, part.p_mfgr, partsupp.ps_suppkey, partsupp.ps_supplycost -12)----------------------Inner Join: part.p_partkey = partsupp.ps_partkey -13)------------------------Projection: part.p_partkey, part.p_mfgr -14)--------------------------Filter: part.p_size = Int32(15) AND part.p_type LIKE Utf8("%BRASS") -15)----------------------------TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[part.p_size = Int32(15), part.p_type LIKE Utf8("%BRASS")] -16)------------------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost] -17)--------------------TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment] -18)----------------TableScan: nation projection=[n_nationkey, n_name, n_regionkey] -19)------------Projection: region.r_regionkey -20)--------------Filter: region.r_name = Utf8("EUROPE") -21)----------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("EUROPE")] -22)--------SubqueryAlias: __scalar_sq_1 -23)----------Projection: MIN(partsupp.ps_supplycost), partsupp.ps_partkey -24)------------Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[MIN(partsupp.ps_supplycost)]] -25)--------------Projection: partsupp.ps_partkey, partsupp.ps_supplycost -26)----------------Inner Join: nation.n_regionkey = region.r_regionkey -27)------------------Projection: partsupp.ps_partkey, partsupp.ps_supplycost, nation.n_regionkey -28)--------------------Inner Join: supplier.s_nationkey = nation.n_nationkey -29)----------------------Projection: partsupp.ps_partkey, partsupp.ps_supplycost, supplier.s_nationkey -30)------------------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey -31)--------------------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost] -32)--------------------------TableScan: supplier projection=[s_suppkey, s_nationkey] -33)----------------------TableScan: nation projection=[n_nationkey, n_regionkey] -34)------------------Projection: region.r_regionkey -35)--------------------Filter: region.r_name = Utf8("EUROPE") -36)----------------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("EUROPE")] +01)Sort: supplier.s_acctbal DESC NULLS FIRST, nation.n_name ASC NULLS LAST, supplier.s_name ASC NULLS LAST, part.p_partkey ASC NULLS LAST, fetch=10 +02)--Projection: supplier.s_acctbal, supplier.s_name, nation.n_name, part.p_partkey, part.p_mfgr, supplier.s_address, supplier.s_phone, supplier.s_comment +03)----Inner Join: part.p_partkey = __scalar_sq_1.ps_partkey, partsupp.ps_supplycost = __scalar_sq_1.min(partsupp.ps_supplycost) +04)------Projection: part.p_partkey, part.p_mfgr, supplier.s_name, supplier.s_address, supplier.s_phone, supplier.s_acctbal, supplier.s_comment, partsupp.ps_supplycost, nation.n_name +05)--------Inner Join: nation.n_regionkey = region.r_regionkey +06)----------Projection: part.p_partkey, part.p_mfgr, supplier.s_name, supplier.s_address, supplier.s_phone, supplier.s_acctbal, supplier.s_comment, partsupp.ps_supplycost, nation.n_name, nation.n_regionkey +07)------------Inner Join: supplier.s_nationkey = nation.n_nationkey +08)--------------Projection: part.p_partkey, part.p_mfgr, supplier.s_name, supplier.s_address, supplier.s_nationkey, supplier.s_phone, supplier.s_acctbal, supplier.s_comment, partsupp.ps_supplycost +09)----------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey +10)------------------Projection: part.p_partkey, part.p_mfgr, partsupp.ps_suppkey, partsupp.ps_supplycost +11)--------------------Inner Join: part.p_partkey = partsupp.ps_partkey +12)----------------------Projection: part.p_partkey, part.p_mfgr +13)------------------------Filter: part.p_size = Int32(15) AND part.p_type LIKE Utf8("%BRASS") +14)--------------------------TableScan: part projection=[p_partkey, p_mfgr, p_type, p_size], partial_filters=[part.p_size = Int32(15), part.p_type LIKE Utf8("%BRASS")] +15)----------------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost] +16)------------------TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment] +17)--------------TableScan: nation projection=[n_nationkey, n_name, n_regionkey] +18)----------Projection: region.r_regionkey +19)------------Filter: region.r_name = Utf8("EUROPE") +20)--------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("EUROPE")] +21)------SubqueryAlias: __scalar_sq_1 +22)--------Projection: min(partsupp.ps_supplycost), partsupp.ps_partkey +23)----------Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[min(partsupp.ps_supplycost)]] +24)------------Projection: partsupp.ps_partkey, partsupp.ps_supplycost +25)--------------Inner Join: nation.n_regionkey = region.r_regionkey +26)----------------Projection: partsupp.ps_partkey, partsupp.ps_supplycost, nation.n_regionkey +27)------------------Inner Join: supplier.s_nationkey = nation.n_nationkey +28)--------------------Projection: partsupp.ps_partkey, partsupp.ps_supplycost, supplier.s_nationkey +29)----------------------Inner Join: partsupp.ps_suppkey = supplier.s_suppkey +30)------------------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost] +31)------------------------TableScan: supplier projection=[s_suppkey, s_nationkey] +32)--------------------TableScan: nation projection=[n_nationkey, n_regionkey] +33)----------------Projection: region.r_regionkey +34)------------------Filter: region.r_name = Utf8("EUROPE") +35)--------------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("EUROPE")] physical_plan -01)GlobalLimitExec: skip=0, fetch=10 -02)--SortPreservingMergeExec: [s_acctbal@0 DESC,n_name@2 ASC NULLS LAST,s_name@1 ASC NULLS LAST,p_partkey@3 ASC NULLS LAST], fetch=10 -03)----SortExec: TopK(fetch=10), expr=[s_acctbal@0 DESC,n_name@2 ASC NULLS LAST,s_name@1 ASC NULLS LAST,p_partkey@3 ASC NULLS LAST] -04)------ProjectionExec: expr=[s_acctbal@5 as s_acctbal, s_name@2 as s_name, n_name@7 as n_name, p_partkey@0 as p_partkey, p_mfgr@1 as p_mfgr, s_address@3 as s_address, s_phone@4 as s_phone, s_comment@6 as s_comment] -05)--------CoalesceBatchesExec: target_batch_size=8192 -06)----------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(p_partkey@0, ps_partkey@1), (ps_supplycost@7, MIN(partsupp.ps_supplycost)@0)], projection=[p_partkey@0, p_mfgr@1, s_name@2, s_address@3, s_phone@4, s_acctbal@5, s_comment@6, n_name@8] -07)------------CoalesceBatchesExec: target_batch_size=8192 -08)--------------RepartitionExec: partitioning=Hash([p_partkey@0, ps_supplycost@7], 4), input_partitions=4 -09)----------------CoalesceBatchesExec: target_batch_size=8192 -10)------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(n_regionkey@9, r_regionkey@0)], projection=[p_partkey@0, p_mfgr@1, s_name@2, s_address@3, s_phone@4, s_acctbal@5, s_comment@6, ps_supplycost@7, n_name@8] -11)--------------------CoalesceBatchesExec: target_batch_size=8192 -12)----------------------RepartitionExec: partitioning=Hash([n_regionkey@9], 4), input_partitions=4 -13)------------------------CoalesceBatchesExec: target_batch_size=8192 -14)--------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@4, n_nationkey@0)], projection=[p_partkey@0, p_mfgr@1, s_name@2, s_address@3, s_phone@5, s_acctbal@6, s_comment@7, ps_supplycost@8, n_name@10, n_regionkey@11] -15)----------------------------CoalesceBatchesExec: target_batch_size=8192 -16)------------------------------RepartitionExec: partitioning=Hash([s_nationkey@4], 4), input_partitions=4 -17)--------------------------------ProjectionExec: expr=[p_partkey@0 as p_partkey, p_mfgr@1 as p_mfgr, s_name@3 as s_name, s_address@4 as s_address, s_nationkey@5 as s_nationkey, s_phone@6 as s_phone, s_acctbal@7 as s_acctbal, s_comment@8 as s_comment, ps_supplycost@2 as ps_supplycost] -18)----------------------------------CoalesceBatchesExec: target_batch_size=8192 -19)------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@2, s_suppkey@0)], projection=[p_partkey@0, p_mfgr@1, ps_supplycost@3, s_name@5, s_address@6, s_nationkey@7, s_phone@8, s_acctbal@9, s_comment@10] -20)--------------------------------------CoalesceBatchesExec: target_batch_size=8192 -21)----------------------------------------RepartitionExec: partitioning=Hash([ps_suppkey@2], 4), input_partitions=4 -22)------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -23)--------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(p_partkey@0, ps_partkey@0)], projection=[p_partkey@0, p_mfgr@1, ps_suppkey@3, ps_supplycost@4] -24)----------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -25)------------------------------------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 -26)--------------------------------------------------ProjectionExec: expr=[p_partkey@0 as p_partkey, p_mfgr@1 as p_mfgr] -27)----------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -28)------------------------------------------------------FilterExec: p_size@3 = 15 AND p_type@2 LIKE %BRASS -29)--------------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -30)----------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_mfgr, p_type, p_size], has_header=false -31)----------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -32)------------------------------------------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4 -33)--------------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey, ps_supplycost], has_header=false -34)--------------------------------------CoalesceBatchesExec: target_batch_size=8192 -35)----------------------------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 -36)------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -37)--------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment], has_header=false -38)----------------------------CoalesceBatchesExec: target_batch_size=8192 -39)------------------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 -40)--------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -41)----------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name, n_regionkey], has_header=false -42)--------------------CoalesceBatchesExec: target_batch_size=8192 -43)----------------------RepartitionExec: partitioning=Hash([r_regionkey@0], 4), input_partitions=4 -44)------------------------ProjectionExec: expr=[r_regionkey@0 as r_regionkey] -45)--------------------------CoalesceBatchesExec: target_batch_size=8192 -46)----------------------------FilterExec: r_name@1 = EUROPE -47)------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -48)--------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/region.tbl]]}, projection=[r_regionkey, r_name], has_header=false -49)------------CoalesceBatchesExec: target_batch_size=8192 -50)--------------RepartitionExec: partitioning=Hash([ps_partkey@1, MIN(partsupp.ps_supplycost)@0], 4), input_partitions=4 -51)----------------ProjectionExec: expr=[MIN(partsupp.ps_supplycost)@1 as MIN(partsupp.ps_supplycost), ps_partkey@0 as ps_partkey] -52)------------------AggregateExec: mode=FinalPartitioned, gby=[ps_partkey@0 as ps_partkey], aggr=[MIN(partsupp.ps_supplycost)] -53)--------------------CoalesceBatchesExec: target_batch_size=8192 -54)----------------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4 -55)------------------------AggregateExec: mode=Partial, gby=[ps_partkey@0 as ps_partkey], aggr=[MIN(partsupp.ps_supplycost)] -56)--------------------------CoalesceBatchesExec: target_batch_size=8192 -57)----------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(n_regionkey@2, r_regionkey@0)], projection=[ps_partkey@0, ps_supplycost@1] -58)------------------------------CoalesceBatchesExec: target_batch_size=8192 -59)--------------------------------RepartitionExec: partitioning=Hash([n_regionkey@2], 4), input_partitions=4 -60)----------------------------------CoalesceBatchesExec: target_batch_size=8192 -61)------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@2, n_nationkey@0)], projection=[ps_partkey@0, ps_supplycost@1, n_regionkey@4] -62)--------------------------------------CoalesceBatchesExec: target_batch_size=8192 -63)----------------------------------------RepartitionExec: partitioning=Hash([s_nationkey@2], 4), input_partitions=4 -64)------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -65)--------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@1, s_suppkey@0)], projection=[ps_partkey@0, ps_supplycost@2, s_nationkey@4] -66)----------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -67)------------------------------------------------RepartitionExec: partitioning=Hash([ps_suppkey@1], 4), input_partitions=4 -68)--------------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey, ps_supplycost], has_header=false -69)----------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -70)------------------------------------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 -71)--------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -72)----------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], has_header=false -73)--------------------------------------CoalesceBatchesExec: target_batch_size=8192 -74)----------------------------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 -75)------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -76)--------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_regionkey], has_header=false -77)------------------------------CoalesceBatchesExec: target_batch_size=8192 -78)--------------------------------RepartitionExec: partitioning=Hash([r_regionkey@0], 4), input_partitions=4 -79)----------------------------------ProjectionExec: expr=[r_regionkey@0 as r_regionkey] -80)------------------------------------CoalesceBatchesExec: target_batch_size=8192 -81)--------------------------------------FilterExec: r_name@1 = EUROPE -82)----------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -83)------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/region.tbl]]}, projection=[r_regionkey, r_name], has_header=false +01)SortPreservingMergeExec: [s_acctbal@0 DESC, n_name@2 ASC NULLS LAST, s_name@1 ASC NULLS LAST, p_partkey@3 ASC NULLS LAST], fetch=10 +02)--SortExec: TopK(fetch=10), expr=[s_acctbal@0 DESC, n_name@2 ASC NULLS LAST, s_name@1 ASC NULLS LAST, p_partkey@3 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[s_acctbal@5 as s_acctbal, s_name@2 as s_name, n_name@7 as n_name, p_partkey@0 as p_partkey, p_mfgr@1 as p_mfgr, s_address@3 as s_address, s_phone@4 as s_phone, s_comment@6 as s_comment] +04)------CoalesceBatchesExec: target_batch_size=8192 +05)--------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(p_partkey@0, ps_partkey@1), (ps_supplycost@7, min(partsupp.ps_supplycost)@0)], projection=[p_partkey@0, p_mfgr@1, s_name@2, s_address@3, s_phone@4, s_acctbal@5, s_comment@6, n_name@8] +06)----------CoalesceBatchesExec: target_batch_size=8192 +07)------------RepartitionExec: partitioning=Hash([p_partkey@0, ps_supplycost@7], 4), input_partitions=4 +08)--------------CoalesceBatchesExec: target_batch_size=8192 +09)----------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(n_regionkey@9, r_regionkey@0)], projection=[p_partkey@0, p_mfgr@1, s_name@2, s_address@3, s_phone@4, s_acctbal@5, s_comment@6, ps_supplycost@7, n_name@8] +10)------------------CoalesceBatchesExec: target_batch_size=8192 +11)--------------------RepartitionExec: partitioning=Hash([n_regionkey@9], 4), input_partitions=4 +12)----------------------CoalesceBatchesExec: target_batch_size=8192 +13)------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@4, n_nationkey@0)], projection=[p_partkey@0, p_mfgr@1, s_name@2, s_address@3, s_phone@5, s_acctbal@6, s_comment@7, ps_supplycost@8, n_name@10, n_regionkey@11] +14)--------------------------CoalesceBatchesExec: target_batch_size=8192 +15)----------------------------RepartitionExec: partitioning=Hash([s_nationkey@4], 4), input_partitions=4 +16)------------------------------ProjectionExec: expr=[p_partkey@0 as p_partkey, p_mfgr@1 as p_mfgr, s_name@3 as s_name, s_address@4 as s_address, s_nationkey@5 as s_nationkey, s_phone@6 as s_phone, s_acctbal@7 as s_acctbal, s_comment@8 as s_comment, ps_supplycost@2 as ps_supplycost] +17)--------------------------------CoalesceBatchesExec: target_batch_size=8192 +18)----------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@2, s_suppkey@0)], projection=[p_partkey@0, p_mfgr@1, ps_supplycost@3, s_name@5, s_address@6, s_nationkey@7, s_phone@8, s_acctbal@9, s_comment@10] +19)------------------------------------CoalesceBatchesExec: target_batch_size=8192 +20)--------------------------------------RepartitionExec: partitioning=Hash([ps_suppkey@2], 4), input_partitions=4 +21)----------------------------------------CoalesceBatchesExec: target_batch_size=8192 +22)------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(p_partkey@0, ps_partkey@0)], projection=[p_partkey@0, p_mfgr@1, ps_suppkey@3, ps_supplycost@4] +23)--------------------------------------------CoalesceBatchesExec: target_batch_size=8192 +24)----------------------------------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 +25)------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 +26)--------------------------------------------------FilterExec: p_size@3 = 15 AND p_type@2 LIKE %BRASS, projection=[p_partkey@0, p_mfgr@1] +27)----------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +28)------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_mfgr, p_type, p_size], has_header=false +29)--------------------------------------------CoalesceBatchesExec: target_batch_size=8192 +30)----------------------------------------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4 +31)------------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey, ps_supplycost], has_header=false +32)------------------------------------CoalesceBatchesExec: target_batch_size=8192 +33)--------------------------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 +34)----------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +35)------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment], has_header=false +36)--------------------------CoalesceBatchesExec: target_batch_size=8192 +37)----------------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 +38)------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +39)--------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name, n_regionkey], has_header=false +40)------------------CoalesceBatchesExec: target_batch_size=8192 +41)--------------------RepartitionExec: partitioning=Hash([r_regionkey@0], 4), input_partitions=4 +42)----------------------CoalesceBatchesExec: target_batch_size=8192 +43)------------------------FilterExec: r_name@1 = EUROPE, projection=[r_regionkey@0] +44)--------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +45)----------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/region.tbl]]}, projection=[r_regionkey, r_name], has_header=false +46)----------CoalesceBatchesExec: target_batch_size=8192 +47)------------RepartitionExec: partitioning=Hash([ps_partkey@1, min(partsupp.ps_supplycost)@0], 4), input_partitions=4 +48)--------------ProjectionExec: expr=[min(partsupp.ps_supplycost)@1 as min(partsupp.ps_supplycost), ps_partkey@0 as ps_partkey] +49)----------------AggregateExec: mode=FinalPartitioned, gby=[ps_partkey@0 as ps_partkey], aggr=[min(partsupp.ps_supplycost)] +50)------------------CoalesceBatchesExec: target_batch_size=8192 +51)--------------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4 +52)----------------------AggregateExec: mode=Partial, gby=[ps_partkey@0 as ps_partkey], aggr=[min(partsupp.ps_supplycost)] +53)------------------------CoalesceBatchesExec: target_batch_size=8192 +54)--------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(n_regionkey@2, r_regionkey@0)], projection=[ps_partkey@0, ps_supplycost@1] +55)----------------------------CoalesceBatchesExec: target_batch_size=8192 +56)------------------------------RepartitionExec: partitioning=Hash([n_regionkey@2], 4), input_partitions=4 +57)--------------------------------CoalesceBatchesExec: target_batch_size=8192 +58)----------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@2, n_nationkey@0)], projection=[ps_partkey@0, ps_supplycost@1, n_regionkey@4] +59)------------------------------------CoalesceBatchesExec: target_batch_size=8192 +60)--------------------------------------RepartitionExec: partitioning=Hash([s_nationkey@2], 4), input_partitions=4 +61)----------------------------------------CoalesceBatchesExec: target_batch_size=8192 +62)------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_suppkey@1, s_suppkey@0)], projection=[ps_partkey@0, ps_supplycost@2, s_nationkey@4] +63)--------------------------------------------CoalesceBatchesExec: target_batch_size=8192 +64)----------------------------------------------RepartitionExec: partitioning=Hash([ps_suppkey@1], 4), input_partitions=4 +65)------------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey, ps_supplycost], has_header=false +66)--------------------------------------------CoalesceBatchesExec: target_batch_size=8192 +67)----------------------------------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 +68)------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +69)--------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], has_header=false +70)------------------------------------CoalesceBatchesExec: target_batch_size=8192 +71)--------------------------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 +72)----------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +73)------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_regionkey], has_header=false +74)----------------------------CoalesceBatchesExec: target_batch_size=8192 +75)------------------------------RepartitionExec: partitioning=Hash([r_regionkey@0], 4), input_partitions=4 +76)--------------------------------CoalesceBatchesExec: target_batch_size=8192 +77)----------------------------------FilterExec: r_name@1 = EUROPE, projection=[r_regionkey@0] +78)------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +79)--------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/region.tbl]]}, projection=[r_regionkey, r_name], has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/q20.slt.part b/datafusion/sqllogictest/test_files/tpch/q20.slt.part index 4b01f32a94ca..177e38e51ca4 100644 --- a/datafusion/sqllogictest/test_files/tpch/q20.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q20.slt.part @@ -58,31 +58,31 @@ order by logical_plan 01)Sort: supplier.s_name ASC NULLS LAST 02)--Projection: supplier.s_name, supplier.s_address -03)----LeftSemi Join: supplier.s_suppkey = __correlated_sq_1.ps_suppkey +03)----LeftSemi Join: supplier.s_suppkey = __correlated_sq_2.ps_suppkey 04)------Projection: supplier.s_suppkey, supplier.s_name, supplier.s_address 05)--------Inner Join: supplier.s_nationkey = nation.n_nationkey 06)----------TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey] 07)----------Projection: nation.n_nationkey 08)------------Filter: nation.n_name = Utf8("CANADA") 09)--------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("CANADA")] -10)------SubqueryAlias: __correlated_sq_1 +10)------SubqueryAlias: __correlated_sq_2 11)--------Projection: partsupp.ps_suppkey -12)----------Inner Join: partsupp.ps_partkey = __scalar_sq_3.l_partkey, partsupp.ps_suppkey = __scalar_sq_3.l_suppkey Filter: CAST(partsupp.ps_availqty AS Float64) > __scalar_sq_3.Float64(0.5) * SUM(lineitem.l_quantity) -13)------------LeftSemi Join: partsupp.ps_partkey = __correlated_sq_2.p_partkey +12)----------Inner Join: partsupp.ps_partkey = __scalar_sq_3.l_partkey, partsupp.ps_suppkey = __scalar_sq_3.l_suppkey Filter: CAST(partsupp.ps_availqty AS Float64) > __scalar_sq_3.Float64(0.5) * sum(lineitem.l_quantity) +13)------------LeftSemi Join: partsupp.ps_partkey = __correlated_sq_1.p_partkey 14)--------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty] -15)--------------SubqueryAlias: __correlated_sq_2 +15)--------------SubqueryAlias: __correlated_sq_1 16)----------------Projection: part.p_partkey 17)------------------Filter: part.p_name LIKE Utf8("forest%") 18)--------------------TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8("forest%")] 19)------------SubqueryAlias: __scalar_sq_3 -20)--------------Projection: Float64(0.5) * CAST(SUM(lineitem.l_quantity) AS Float64), lineitem.l_partkey, lineitem.l_suppkey -21)----------------Aggregate: groupBy=[[lineitem.l_partkey, lineitem.l_suppkey]], aggr=[[SUM(lineitem.l_quantity)]] +20)--------------Projection: Float64(0.5) * CAST(sum(lineitem.l_quantity) AS Float64), lineitem.l_partkey, lineitem.l_suppkey +21)----------------Aggregate: groupBy=[[lineitem.l_partkey, lineitem.l_suppkey]], aggr=[[sum(lineitem.l_quantity)]] 22)------------------Projection: lineitem.l_partkey, lineitem.l_suppkey, lineitem.l_quantity -23)--------------------Filter: lineitem.l_shipdate >= Date32("8766") AND lineitem.l_shipdate < Date32("9131") -24)----------------------TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("8766"), lineitem.l_shipdate < Date32("9131")] +23)--------------------Filter: lineitem.l_shipdate >= Date32("1994-01-01") AND lineitem.l_shipdate < Date32("1995-01-01") +24)----------------------TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("1994-01-01"), lineitem.l_shipdate < Date32("1995-01-01")] physical_plan 01)SortPreservingMergeExec: [s_name@0 ASC NULLS LAST] -02)--SortExec: expr=[s_name@0 ASC NULLS LAST] +02)--SortExec: expr=[s_name@0 ASC NULLS LAST], preserve_partitioning=[true] 03)----CoalesceBatchesExec: target_batch_size=8192 04)------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(s_suppkey@0, ps_suppkey@0)], projection=[s_name@1, s_address@2] 05)--------CoalesceBatchesExec: target_batch_size=8192 @@ -95,38 +95,35 @@ physical_plan 12)----------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_name, s_address, s_nationkey], has_header=false 13)----------------CoalesceBatchesExec: target_batch_size=8192 14)------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 -15)--------------------ProjectionExec: expr=[n_nationkey@0 as n_nationkey] -16)----------------------CoalesceBatchesExec: target_batch_size=8192 -17)------------------------FilterExec: n_name@1 = CANADA -18)--------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -19)----------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false -20)--------CoalesceBatchesExec: target_batch_size=8192 -21)----------RepartitionExec: partitioning=Hash([ps_suppkey@0], 4), input_partitions=4 -22)------------CoalesceBatchesExec: target_batch_size=8192 -23)--------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_partkey@0, l_partkey@1), (ps_suppkey@1, l_suppkey@2)], filter=CAST(ps_availqty@0 AS Float64) > Float64(0.5) * SUM(lineitem.l_quantity)@1, projection=[ps_suppkey@1] -24)----------------CoalesceBatchesExec: target_batch_size=8192 -25)------------------RepartitionExec: partitioning=Hash([ps_partkey@0, ps_suppkey@1], 4), input_partitions=4 -26)--------------------CoalesceBatchesExec: target_batch_size=8192 -27)----------------------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(ps_partkey@0, p_partkey@0)] -28)------------------------CoalesceBatchesExec: target_batch_size=8192 -29)--------------------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4 -30)----------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey, ps_availqty], has_header=false -31)------------------------CoalesceBatchesExec: target_batch_size=8192 -32)--------------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 -33)----------------------------ProjectionExec: expr=[p_partkey@0 as p_partkey] -34)------------------------------CoalesceBatchesExec: target_batch_size=8192 -35)--------------------------------FilterExec: p_name@1 LIKE forest% -36)----------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -37)------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_name], has_header=false -38)----------------ProjectionExec: expr=[0.5 * CAST(SUM(lineitem.l_quantity)@2 AS Float64) as Float64(0.5) * SUM(lineitem.l_quantity), l_partkey@0 as l_partkey, l_suppkey@1 as l_suppkey] -39)------------------AggregateExec: mode=FinalPartitioned, gby=[l_partkey@0 as l_partkey, l_suppkey@1 as l_suppkey], aggr=[SUM(lineitem.l_quantity)] -40)--------------------CoalesceBatchesExec: target_batch_size=8192 -41)----------------------RepartitionExec: partitioning=Hash([l_partkey@0, l_suppkey@1], 4), input_partitions=4 -42)------------------------AggregateExec: mode=Partial, gby=[l_partkey@0 as l_partkey, l_suppkey@1 as l_suppkey], aggr=[SUM(lineitem.l_quantity)] -43)--------------------------ProjectionExec: expr=[l_partkey@0 as l_partkey, l_suppkey@1 as l_suppkey, l_quantity@2 as l_quantity] -44)----------------------------CoalesceBatchesExec: target_batch_size=8192 -45)------------------------------FilterExec: l_shipdate@3 >= 8766 AND l_shipdate@3 < 9131 -46)--------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], has_header=false +15)--------------------CoalesceBatchesExec: target_batch_size=8192 +16)----------------------FilterExec: n_name@1 = CANADA, projection=[n_nationkey@0] +17)------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +18)--------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false +19)--------CoalesceBatchesExec: target_batch_size=8192 +20)----------RepartitionExec: partitioning=Hash([ps_suppkey@0], 4), input_partitions=4 +21)------------CoalesceBatchesExec: target_batch_size=8192 +22)--------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(ps_partkey@0, l_partkey@1), (ps_suppkey@1, l_suppkey@2)], filter=CAST(ps_availqty@0 AS Float64) > Float64(0.5) * sum(lineitem.l_quantity)@1, projection=[ps_suppkey@1] +23)----------------CoalesceBatchesExec: target_batch_size=8192 +24)------------------RepartitionExec: partitioning=Hash([ps_partkey@0, ps_suppkey@1], 4), input_partitions=4 +25)--------------------CoalesceBatchesExec: target_batch_size=8192 +26)----------------------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(ps_partkey@0, p_partkey@0)] +27)------------------------CoalesceBatchesExec: target_batch_size=8192 +28)--------------------------RepartitionExec: partitioning=Hash([ps_partkey@0], 4), input_partitions=4 +29)----------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey, ps_availqty], has_header=false +30)------------------------CoalesceBatchesExec: target_batch_size=8192 +31)--------------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 +32)----------------------------CoalesceBatchesExec: target_batch_size=8192 +33)------------------------------FilterExec: p_name@1 LIKE forest%, projection=[p_partkey@0] +34)--------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +35)----------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_name], has_header=false +36)----------------ProjectionExec: expr=[0.5 * CAST(sum(lineitem.l_quantity)@2 AS Float64) as Float64(0.5) * sum(lineitem.l_quantity), l_partkey@0 as l_partkey, l_suppkey@1 as l_suppkey] +37)------------------AggregateExec: mode=FinalPartitioned, gby=[l_partkey@0 as l_partkey, l_suppkey@1 as l_suppkey], aggr=[sum(lineitem.l_quantity)] +38)--------------------CoalesceBatchesExec: target_batch_size=8192 +39)----------------------RepartitionExec: partitioning=Hash([l_partkey@0, l_suppkey@1], 4), input_partitions=4 +40)------------------------AggregateExec: mode=Partial, gby=[l_partkey@0 as l_partkey, l_suppkey@1 as l_suppkey], aggr=[sum(lineitem.l_quantity)] +41)--------------------------CoalesceBatchesExec: target_batch_size=8192 +42)----------------------------FilterExec: l_shipdate@3 >= 1994-01-01 AND l_shipdate@3 < 1995-01-01, projection=[l_partkey@0, l_suppkey@1, l_quantity@2] +43)------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/q21.slt.part b/datafusion/sqllogictest/test_files/tpch/q21.slt.part index 944a68abd327..93dcd4c68052 100644 --- a/datafusion/sqllogictest/test_files/tpch/q21.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q21.slt.part @@ -59,8 +59,8 @@ order by ---- logical_plan 01)Sort: numwait DESC NULLS FIRST, supplier.s_name ASC NULLS LAST -02)--Projection: supplier.s_name, COUNT(*) AS numwait -03)----Aggregate: groupBy=[[supplier.s_name]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +02)--Projection: supplier.s_name, count(*) AS numwait +03)----Aggregate: groupBy=[[supplier.s_name]], aggr=[[count(Int64(1)) AS count(*)]] 04)------Projection: supplier.s_name 05)--------LeftAnti Join: l1.l_orderkey = __correlated_sq_2.l_orderkey Filter: __correlated_sq_2.l_suppkey != l1.l_suppkey 06)----------LeftSemi Join: l1.l_orderkey = __correlated_sq_1.l_orderkey Filter: __correlated_sq_1.l_suppkey != l1.l_suppkey @@ -90,13 +90,13 @@ logical_plan 30)----------------Filter: lineitem.l_receiptdate > lineitem.l_commitdate 31)------------------TableScan: lineitem projection=[l_orderkey, l_suppkey, l_commitdate, l_receiptdate], partial_filters=[lineitem.l_receiptdate > lineitem.l_commitdate] physical_plan -01)SortPreservingMergeExec: [numwait@1 DESC,s_name@0 ASC NULLS LAST] -02)--SortExec: expr=[numwait@1 DESC,s_name@0 ASC NULLS LAST] -03)----ProjectionExec: expr=[s_name@0 as s_name, COUNT(*)@1 as numwait] -04)------AggregateExec: mode=FinalPartitioned, gby=[s_name@0 as s_name], aggr=[COUNT(*)] +01)SortPreservingMergeExec: [numwait@1 DESC, s_name@0 ASC NULLS LAST] +02)--SortExec: expr=[numwait@1 DESC, s_name@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[s_name@0 as s_name, count(*)@1 as numwait] +04)------AggregateExec: mode=FinalPartitioned, gby=[s_name@0 as s_name], aggr=[count(*)] 05)--------CoalesceBatchesExec: target_batch_size=8192 06)----------RepartitionExec: partitioning=Hash([s_name@0], 4), input_partitions=4 -07)------------AggregateExec: mode=Partial, gby=[s_name@0 as s_name], aggr=[COUNT(*)] +07)------------AggregateExec: mode=Partial, gby=[s_name@0 as s_name], aggr=[count(*)] 08)--------------CoalesceBatchesExec: target_batch_size=8192 09)----------------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(l_orderkey@1, l_orderkey@0)], filter=l_suppkey@1 != l_suppkey@0, projection=[s_name@0] 10)------------------CoalesceBatchesExec: target_batch_size=8192 @@ -119,32 +119,28 @@ physical_plan 27)----------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_name, s_nationkey], has_header=false 28)----------------------------------------------CoalesceBatchesExec: target_batch_size=8192 29)------------------------------------------------RepartitionExec: partitioning=Hash([l_suppkey@1], 4), input_partitions=4 -30)--------------------------------------------------ProjectionExec: expr=[l_orderkey@0 as l_orderkey, l_suppkey@1 as l_suppkey] -31)----------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -32)------------------------------------------------------FilterExec: l_receiptdate@3 > l_commitdate@2 -33)--------------------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_suppkey, l_commitdate, l_receiptdate], has_header=false -34)--------------------------------------CoalesceBatchesExec: target_batch_size=8192 -35)----------------------------------------RepartitionExec: partitioning=Hash([o_orderkey@0], 4), input_partitions=4 -36)------------------------------------------ProjectionExec: expr=[o_orderkey@0 as o_orderkey] -37)--------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -38)----------------------------------------------FilterExec: o_orderstatus@1 = F -39)------------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_orderstatus], has_header=false -40)------------------------------CoalesceBatchesExec: target_batch_size=8192 -41)--------------------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 -42)----------------------------------ProjectionExec: expr=[n_nationkey@0 as n_nationkey] -43)------------------------------------CoalesceBatchesExec: target_batch_size=8192 -44)--------------------------------------FilterExec: n_name@1 = SAUDI ARABIA -45)----------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -46)------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false -47)----------------------CoalesceBatchesExec: target_batch_size=8192 -48)------------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4 -49)--------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_suppkey], has_header=false -50)------------------CoalesceBatchesExec: target_batch_size=8192 -51)--------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4 -52)----------------------ProjectionExec: expr=[l_orderkey@0 as l_orderkey, l_suppkey@1 as l_suppkey] -53)------------------------CoalesceBatchesExec: target_batch_size=8192 -54)--------------------------FilterExec: l_receiptdate@3 > l_commitdate@2 -55)----------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_suppkey, l_commitdate, l_receiptdate], has_header=false +30)--------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 +31)----------------------------------------------------FilterExec: l_receiptdate@3 > l_commitdate@2, projection=[l_orderkey@0, l_suppkey@1] +32)------------------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_suppkey, l_commitdate, l_receiptdate], has_header=false +33)--------------------------------------CoalesceBatchesExec: target_batch_size=8192 +34)----------------------------------------RepartitionExec: partitioning=Hash([o_orderkey@0], 4), input_partitions=4 +35)------------------------------------------CoalesceBatchesExec: target_batch_size=8192 +36)--------------------------------------------FilterExec: o_orderstatus@1 = F, projection=[o_orderkey@0] +37)----------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_orderstatus], has_header=false +38)------------------------------CoalesceBatchesExec: target_batch_size=8192 +39)--------------------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 +40)----------------------------------CoalesceBatchesExec: target_batch_size=8192 +41)------------------------------------FilterExec: n_name@1 = SAUDI ARABIA, projection=[n_nationkey@0] +42)--------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +43)----------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false +44)----------------------CoalesceBatchesExec: target_batch_size=8192 +45)------------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4 +46)--------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_suppkey], has_header=false +47)------------------CoalesceBatchesExec: target_batch_size=8192 +48)--------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4 +49)----------------------CoalesceBatchesExec: target_batch_size=8192 +50)------------------------FilterExec: l_receiptdate@3 > l_commitdate@2, projection=[l_orderkey@0, l_suppkey@1] +51)--------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_suppkey, l_commitdate, l_receiptdate], has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/q22.slt.part b/datafusion/sqllogictest/test_files/tpch/q22.slt.part index 98c8ba396552..2955748160ea 100644 --- a/datafusion/sqllogictest/test_files/tpch/q22.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q22.slt.part @@ -57,11 +57,11 @@ order by ---- logical_plan 01)Sort: custsale.cntrycode ASC NULLS LAST -02)--Projection: custsale.cntrycode, COUNT(*) AS numcust, SUM(custsale.c_acctbal) AS totacctbal -03)----Aggregate: groupBy=[[custsale.cntrycode]], aggr=[[COUNT(Int64(1)) AS COUNT(*), SUM(custsale.c_acctbal)]] +02)--Projection: custsale.cntrycode, count(*) AS numcust, sum(custsale.c_acctbal) AS totacctbal +03)----Aggregate: groupBy=[[custsale.cntrycode]], aggr=[[count(Int64(1)) AS count(*), sum(custsale.c_acctbal)]] 04)------SubqueryAlias: custsale 05)--------Projection: substr(customer.c_phone, Int64(1), Int64(2)) AS cntrycode, customer.c_acctbal -06)----------Inner Join: Filter: CAST(customer.c_acctbal AS Decimal128(19, 6)) > __scalar_sq_2.AVG(customer.c_acctbal) +06)----------Inner Join: Filter: CAST(customer.c_acctbal AS Decimal128(19, 6)) > __scalar_sq_2.avg(customer.c_acctbal) 07)------------Projection: customer.c_phone, customer.c_acctbal 08)--------------LeftAnti Join: customer.c_custkey = __correlated_sq_1.o_custkey 09)----------------Filter: substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]) @@ -69,21 +69,21 @@ logical_plan 11)----------------SubqueryAlias: __correlated_sq_1 12)------------------TableScan: orders projection=[o_custkey] 13)------------SubqueryAlias: __scalar_sq_2 -14)--------------Aggregate: groupBy=[[]], aggr=[[AVG(customer.c_acctbal)]] +14)--------------Aggregate: groupBy=[[]], aggr=[[avg(customer.c_acctbal)]] 15)----------------Projection: customer.c_acctbal 16)------------------Filter: customer.c_acctbal > Decimal128(Some(0),15,2) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]) -17)--------------------TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2) AS customer.c_acctbal > Decimal128(Some(0),30,15), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]), customer.c_acctbal > Decimal128(Some(0),15,2)] +17)--------------------TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])] physical_plan 01)SortPreservingMergeExec: [cntrycode@0 ASC NULLS LAST] -02)--SortExec: expr=[cntrycode@0 ASC NULLS LAST] -03)----ProjectionExec: expr=[cntrycode@0 as cntrycode, COUNT(*)@1 as numcust, SUM(custsale.c_acctbal)@2 as totacctbal] -04)------AggregateExec: mode=FinalPartitioned, gby=[cntrycode@0 as cntrycode], aggr=[COUNT(*), SUM(custsale.c_acctbal)] +02)--SortExec: expr=[cntrycode@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[cntrycode@0 as cntrycode, count(*)@1 as numcust, sum(custsale.c_acctbal)@2 as totacctbal] +04)------AggregateExec: mode=FinalPartitioned, gby=[cntrycode@0 as cntrycode], aggr=[count(*), sum(custsale.c_acctbal)] 05)--------CoalesceBatchesExec: target_batch_size=8192 06)----------RepartitionExec: partitioning=Hash([cntrycode@0], 4), input_partitions=4 -07)------------AggregateExec: mode=Partial, gby=[cntrycode@0 as cntrycode], aggr=[COUNT(*), SUM(custsale.c_acctbal)] +07)------------AggregateExec: mode=Partial, gby=[cntrycode@0 as cntrycode], aggr=[count(*), sum(custsale.c_acctbal)] 08)--------------ProjectionExec: expr=[substr(c_phone@0, 1, 2) as cntrycode, c_acctbal@1 as c_acctbal] 09)----------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -10)------------------NestedLoopJoinExec: join_type=Inner, filter=CAST(c_acctbal@0 AS Decimal128(19, 6)) > AVG(customer.c_acctbal)@1 +10)------------------NestedLoopJoinExec: join_type=Inner, filter=CAST(c_acctbal@0 AS Decimal128(19, 6)) > avg(customer.c_acctbal)@1 11)--------------------CoalescePartitionsExec 12)----------------------CoalesceBatchesExec: target_batch_size=8192 13)------------------------HashJoinExec: mode=Partitioned, join_type=LeftAnti, on=[(c_custkey@0, o_custkey@0)], projection=[c_phone@1, c_acctbal@2] @@ -96,14 +96,13 @@ physical_plan 20)--------------------------CoalesceBatchesExec: target_batch_size=8192 21)----------------------------RepartitionExec: partitioning=Hash([o_custkey@0], 4), input_partitions=4 22)------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_custkey], has_header=false -23)--------------------AggregateExec: mode=Final, gby=[], aggr=[AVG(customer.c_acctbal)] +23)--------------------AggregateExec: mode=Final, gby=[], aggr=[avg(customer.c_acctbal)] 24)----------------------CoalescePartitionsExec -25)------------------------AggregateExec: mode=Partial, gby=[], aggr=[AVG(customer.c_acctbal)] -26)--------------------------ProjectionExec: expr=[c_acctbal@1 as c_acctbal] -27)----------------------------CoalesceBatchesExec: target_batch_size=8192 -28)------------------------------FilterExec: c_acctbal@1 > Some(0),15,2 AND Use substr(c_phone@0, 1, 2) IN (SET) ([Literal { value: Utf8("13") }, Literal { value: Utf8("31") }, Literal { value: Utf8("23") }, Literal { value: Utf8("29") }, Literal { value: Utf8("30") }, Literal { value: Utf8("18") }, Literal { value: Utf8("17") }]) -29)--------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -30)----------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_phone, c_acctbal], has_header=false +25)------------------------AggregateExec: mode=Partial, gby=[], aggr=[avg(customer.c_acctbal)] +26)--------------------------CoalesceBatchesExec: target_batch_size=8192 +27)----------------------------FilterExec: c_acctbal@1 > Some(0),15,2 AND Use substr(c_phone@0, 1, 2) IN (SET) ([Literal { value: Utf8("13") }, Literal { value: Utf8("31") }, Literal { value: Utf8("23") }, Literal { value: Utf8("29") }, Literal { value: Utf8("30") }, Literal { value: Utf8("18") }, Literal { value: Utf8("17") }]), projection=[c_acctbal@1] +28)------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +29)--------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_phone, c_acctbal], has_header=false query TIR diff --git a/datafusion/sqllogictest/test_files/tpch/q3.slt.part b/datafusion/sqllogictest/test_files/tpch/q3.slt.part index fe7816715632..289e9c7732bb 100644 --- a/datafusion/sqllogictest/test_files/tpch/q3.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q3.slt.part @@ -42,55 +42,51 @@ order by limit 10; ---- logical_plan -01)Limit: skip=0, fetch=10 -02)--Sort: revenue DESC NULLS FIRST, orders.o_orderdate ASC NULLS LAST, fetch=10 -03)----Projection: lineitem.l_orderkey, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue, orders.o_orderdate, orders.o_shippriority -04)------Aggregate: groupBy=[[lineitem.l_orderkey, orders.o_orderdate, orders.o_shippriority]], aggr=[[SUM(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] -05)--------Projection: orders.o_orderdate, orders.o_shippriority, lineitem.l_orderkey, lineitem.l_extendedprice, lineitem.l_discount -06)----------Inner Join: orders.o_orderkey = lineitem.l_orderkey -07)------------Projection: orders.o_orderkey, orders.o_orderdate, orders.o_shippriority -08)--------------Inner Join: customer.c_custkey = orders.o_custkey -09)----------------Projection: customer.c_custkey -10)------------------Filter: customer.c_mktsegment = Utf8("BUILDING") -11)--------------------TableScan: customer projection=[c_custkey, c_mktsegment], partial_filters=[customer.c_mktsegment = Utf8("BUILDING")] -12)----------------Filter: orders.o_orderdate < Date32("9204") -13)------------------TableScan: orders projection=[o_orderkey, o_custkey, o_orderdate, o_shippriority], partial_filters=[orders.o_orderdate < Date32("9204")] -14)------------Projection: lineitem.l_orderkey, lineitem.l_extendedprice, lineitem.l_discount -15)--------------Filter: lineitem.l_shipdate > Date32("9204") -16)----------------TableScan: lineitem projection=[l_orderkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate > Date32("9204")] +01)Sort: revenue DESC NULLS FIRST, orders.o_orderdate ASC NULLS LAST, fetch=10 +02)--Projection: lineitem.l_orderkey, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue, orders.o_orderdate, orders.o_shippriority +03)----Aggregate: groupBy=[[lineitem.l_orderkey, orders.o_orderdate, orders.o_shippriority]], aggr=[[sum(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] +04)------Projection: orders.o_orderdate, orders.o_shippriority, lineitem.l_orderkey, lineitem.l_extendedprice, lineitem.l_discount +05)--------Inner Join: orders.o_orderkey = lineitem.l_orderkey +06)----------Projection: orders.o_orderkey, orders.o_orderdate, orders.o_shippriority +07)------------Inner Join: customer.c_custkey = orders.o_custkey +08)--------------Projection: customer.c_custkey +09)----------------Filter: customer.c_mktsegment = Utf8("BUILDING") +10)------------------TableScan: customer projection=[c_custkey, c_mktsegment], partial_filters=[customer.c_mktsegment = Utf8("BUILDING")] +11)--------------Filter: orders.o_orderdate < Date32("1995-03-15") +12)----------------TableScan: orders projection=[o_orderkey, o_custkey, o_orderdate, o_shippriority], partial_filters=[orders.o_orderdate < Date32("1995-03-15")] +13)----------Projection: lineitem.l_orderkey, lineitem.l_extendedprice, lineitem.l_discount +14)------------Filter: lineitem.l_shipdate > Date32("1995-03-15") +15)--------------TableScan: lineitem projection=[l_orderkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate > Date32("1995-03-15")] physical_plan -01)GlobalLimitExec: skip=0, fetch=10 -02)--SortPreservingMergeExec: [revenue@1 DESC,o_orderdate@2 ASC NULLS LAST], fetch=10 -03)----SortExec: TopK(fetch=10), expr=[revenue@1 DESC,o_orderdate@2 ASC NULLS LAST] -04)------ProjectionExec: expr=[l_orderkey@0 as l_orderkey, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@3 as revenue, o_orderdate@1 as o_orderdate, o_shippriority@2 as o_shippriority] -05)--------AggregateExec: mode=FinalPartitioned, gby=[l_orderkey@0 as l_orderkey, o_orderdate@1 as o_orderdate, o_shippriority@2 as o_shippriority], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] -06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------RepartitionExec: partitioning=Hash([l_orderkey@0, o_orderdate@1, o_shippriority@2], 4), input_partitions=4 -08)--------------AggregateExec: mode=Partial, gby=[l_orderkey@2 as l_orderkey, o_orderdate@0 as o_orderdate, o_shippriority@1 as o_shippriority], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] -09)----------------CoalesceBatchesExec: target_batch_size=8192 -10)------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(o_orderkey@0, l_orderkey@0)], projection=[o_orderdate@1, o_shippriority@2, l_orderkey@3, l_extendedprice@4, l_discount@5] -11)--------------------CoalesceBatchesExec: target_batch_size=8192 -12)----------------------RepartitionExec: partitioning=Hash([o_orderkey@0], 4), input_partitions=4 -13)------------------------CoalesceBatchesExec: target_batch_size=8192 -14)--------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c_custkey@0, o_custkey@1)], projection=[o_orderkey@1, o_orderdate@3, o_shippriority@4] -15)----------------------------CoalesceBatchesExec: target_batch_size=8192 -16)------------------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4 -17)--------------------------------ProjectionExec: expr=[c_custkey@0 as c_custkey] -18)----------------------------------CoalesceBatchesExec: target_batch_size=8192 -19)------------------------------------FilterExec: c_mktsegment@1 = BUILDING -20)--------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -21)----------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_mktsegment], has_header=false -22)----------------------------CoalesceBatchesExec: target_batch_size=8192 -23)------------------------------RepartitionExec: partitioning=Hash([o_custkey@1], 4), input_partitions=4 -24)--------------------------------CoalesceBatchesExec: target_batch_size=8192 -25)----------------------------------FilterExec: o_orderdate@2 < 9204 -26)------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_custkey, o_orderdate, o_shippriority], has_header=false -27)--------------------CoalesceBatchesExec: target_batch_size=8192 -28)----------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4 -29)------------------------ProjectionExec: expr=[l_orderkey@0 as l_orderkey, l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount] -30)--------------------------CoalesceBatchesExec: target_batch_size=8192 -31)----------------------------FilterExec: l_shipdate@3 > 9204 -32)------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_extendedprice, l_discount, l_shipdate], has_header=false +01)SortPreservingMergeExec: [revenue@1 DESC, o_orderdate@2 ASC NULLS LAST], fetch=10 +02)--SortExec: TopK(fetch=10), expr=[revenue@1 DESC, o_orderdate@2 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[l_orderkey@0 as l_orderkey, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@3 as revenue, o_orderdate@1 as o_orderdate, o_shippriority@2 as o_shippriority] +04)------AggregateExec: mode=FinalPartitioned, gby=[l_orderkey@0 as l_orderkey, o_orderdate@1 as o_orderdate, o_shippriority@2 as o_shippriority], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +05)--------CoalesceBatchesExec: target_batch_size=8192 +06)----------RepartitionExec: partitioning=Hash([l_orderkey@0, o_orderdate@1, o_shippriority@2], 4), input_partitions=4 +07)------------AggregateExec: mode=Partial, gby=[l_orderkey@2 as l_orderkey, o_orderdate@0 as o_orderdate, o_shippriority@1 as o_shippriority], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +08)--------------CoalesceBatchesExec: target_batch_size=8192 +09)----------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(o_orderkey@0, l_orderkey@0)], projection=[o_orderdate@1, o_shippriority@2, l_orderkey@3, l_extendedprice@4, l_discount@5] +10)------------------CoalesceBatchesExec: target_batch_size=8192 +11)--------------------RepartitionExec: partitioning=Hash([o_orderkey@0], 4), input_partitions=4 +12)----------------------CoalesceBatchesExec: target_batch_size=8192 +13)------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c_custkey@0, o_custkey@1)], projection=[o_orderkey@1, o_orderdate@3, o_shippriority@4] +14)--------------------------CoalesceBatchesExec: target_batch_size=8192 +15)----------------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4 +16)------------------------------CoalesceBatchesExec: target_batch_size=8192 +17)--------------------------------FilterExec: c_mktsegment@1 = BUILDING, projection=[c_custkey@0] +18)----------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +19)------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_mktsegment], has_header=false +20)--------------------------CoalesceBatchesExec: target_batch_size=8192 +21)----------------------------RepartitionExec: partitioning=Hash([o_custkey@1], 4), input_partitions=4 +22)------------------------------CoalesceBatchesExec: target_batch_size=8192 +23)--------------------------------FilterExec: o_orderdate@2 < 1995-03-15 +24)----------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_custkey, o_orderdate, o_shippriority], has_header=false +25)------------------CoalesceBatchesExec: target_batch_size=8192 +26)--------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4 +27)----------------------CoalesceBatchesExec: target_batch_size=8192 +28)------------------------FilterExec: l_shipdate@3 > 1995-03-15, projection=[l_orderkey@0, l_extendedprice@1, l_discount@2] +29)--------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_extendedprice, l_discount, l_shipdate], has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/q4.slt.part b/datafusion/sqllogictest/test_files/tpch/q4.slt.part index eb069cdce081..a68b745c366c 100644 --- a/datafusion/sqllogictest/test_files/tpch/q4.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q4.slt.part @@ -41,39 +41,37 @@ order by ---- logical_plan 01)Sort: orders.o_orderpriority ASC NULLS LAST -02)--Projection: orders.o_orderpriority, COUNT(*) AS order_count -03)----Aggregate: groupBy=[[orders.o_orderpriority]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +02)--Projection: orders.o_orderpriority, count(*) AS order_count +03)----Aggregate: groupBy=[[orders.o_orderpriority]], aggr=[[count(Int64(1)) AS count(*)]] 04)------Projection: orders.o_orderpriority 05)--------LeftSemi Join: orders.o_orderkey = __correlated_sq_1.l_orderkey 06)----------Projection: orders.o_orderkey, orders.o_orderpriority -07)------------Filter: orders.o_orderdate >= Date32("8582") AND orders.o_orderdate < Date32("8674") -08)--------------TableScan: orders projection=[o_orderkey, o_orderdate, o_orderpriority], partial_filters=[orders.o_orderdate >= Date32("8582"), orders.o_orderdate < Date32("8674")] +07)------------Filter: orders.o_orderdate >= Date32("1993-07-01") AND orders.o_orderdate < Date32("1993-10-01") +08)--------------TableScan: orders projection=[o_orderkey, o_orderdate, o_orderpriority], partial_filters=[orders.o_orderdate >= Date32("1993-07-01"), orders.o_orderdate < Date32("1993-10-01")] 09)----------SubqueryAlias: __correlated_sq_1 10)------------Projection: lineitem.l_orderkey 11)--------------Filter: lineitem.l_receiptdate > lineitem.l_commitdate 12)----------------TableScan: lineitem projection=[l_orderkey, l_commitdate, l_receiptdate], partial_filters=[lineitem.l_receiptdate > lineitem.l_commitdate] physical_plan 01)SortPreservingMergeExec: [o_orderpriority@0 ASC NULLS LAST] -02)--SortExec: expr=[o_orderpriority@0 ASC NULLS LAST] -03)----ProjectionExec: expr=[o_orderpriority@0 as o_orderpriority, COUNT(*)@1 as order_count] -04)------AggregateExec: mode=FinalPartitioned, gby=[o_orderpriority@0 as o_orderpriority], aggr=[COUNT(*)] +02)--SortExec: expr=[o_orderpriority@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[o_orderpriority@0 as o_orderpriority, count(*)@1 as order_count] +04)------AggregateExec: mode=FinalPartitioned, gby=[o_orderpriority@0 as o_orderpriority], aggr=[count(*)] 05)--------CoalesceBatchesExec: target_batch_size=8192 06)----------RepartitionExec: partitioning=Hash([o_orderpriority@0], 4), input_partitions=4 -07)------------AggregateExec: mode=Partial, gby=[o_orderpriority@0 as o_orderpriority], aggr=[COUNT(*)] +07)------------AggregateExec: mode=Partial, gby=[o_orderpriority@0 as o_orderpriority], aggr=[count(*)] 08)--------------CoalesceBatchesExec: target_batch_size=8192 09)----------------HashJoinExec: mode=Partitioned, join_type=LeftSemi, on=[(o_orderkey@0, l_orderkey@0)], projection=[o_orderpriority@1] 10)------------------CoalesceBatchesExec: target_batch_size=8192 11)--------------------RepartitionExec: partitioning=Hash([o_orderkey@0], 4), input_partitions=4 -12)----------------------ProjectionExec: expr=[o_orderkey@0 as o_orderkey, o_orderpriority@2 as o_orderpriority] -13)------------------------CoalesceBatchesExec: target_batch_size=8192 -14)--------------------------FilterExec: o_orderdate@1 >= 8582 AND o_orderdate@1 < 8674 -15)----------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_orderdate, o_orderpriority], has_header=false -16)------------------CoalesceBatchesExec: target_batch_size=8192 -17)--------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4 -18)----------------------ProjectionExec: expr=[l_orderkey@0 as l_orderkey] -19)------------------------CoalesceBatchesExec: target_batch_size=8192 -20)--------------------------FilterExec: l_receiptdate@2 > l_commitdate@1 -21)----------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_commitdate, l_receiptdate], has_header=false +12)----------------------CoalesceBatchesExec: target_batch_size=8192 +13)------------------------FilterExec: o_orderdate@1 >= 1993-07-01 AND o_orderdate@1 < 1993-10-01, projection=[o_orderkey@0, o_orderpriority@2] +14)--------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_orderdate, o_orderpriority], has_header=false +15)------------------CoalesceBatchesExec: target_batch_size=8192 +16)--------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4 +17)----------------------CoalesceBatchesExec: target_batch_size=8192 +18)------------------------FilterExec: l_receiptdate@2 > l_commitdate@1, projection=[l_orderkey@0] +19)--------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_commitdate, l_receiptdate], has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/q5.slt.part b/datafusion/sqllogictest/test_files/tpch/q5.slt.part index 188573069269..e59daf4943e8 100644 --- a/datafusion/sqllogictest/test_files/tpch/q5.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q5.slt.part @@ -44,8 +44,8 @@ order by ---- logical_plan 01)Sort: revenue DESC NULLS FIRST -02)--Projection: nation.n_name, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue -03)----Aggregate: groupBy=[[nation.n_name]], aggr=[[SUM(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] +02)--Projection: nation.n_name, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue +03)----Aggregate: groupBy=[[nation.n_name]], aggr=[[sum(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] 04)------Projection: lineitem.l_extendedprice, lineitem.l_discount, nation.n_name 05)--------Inner Join: nation.n_regionkey = region.r_regionkey 06)----------Projection: lineitem.l_extendedprice, lineitem.l_discount, nation.n_name, nation.n_regionkey @@ -58,8 +58,8 @@ logical_plan 13)------------------------Inner Join: customer.c_custkey = orders.o_custkey 14)--------------------------TableScan: customer projection=[c_custkey, c_nationkey] 15)--------------------------Projection: orders.o_orderkey, orders.o_custkey -16)----------------------------Filter: orders.o_orderdate >= Date32("8766") AND orders.o_orderdate < Date32("9131") -17)------------------------------TableScan: orders projection=[o_orderkey, o_custkey, o_orderdate], partial_filters=[orders.o_orderdate >= Date32("8766"), orders.o_orderdate < Date32("9131")] +16)----------------------------Filter: orders.o_orderdate >= Date32("1994-01-01") AND orders.o_orderdate < Date32("1995-01-01") +17)------------------------------TableScan: orders projection=[o_orderkey, o_custkey, o_orderdate], partial_filters=[orders.o_orderdate >= Date32("1994-01-01"), orders.o_orderdate < Date32("1995-01-01")] 18)----------------------TableScan: lineitem projection=[l_orderkey, l_suppkey, l_extendedprice, l_discount] 19)------------------TableScan: supplier projection=[s_suppkey, s_nationkey] 20)--------------TableScan: nation projection=[n_nationkey, n_name, n_regionkey] @@ -68,12 +68,12 @@ logical_plan 23)--------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("ASIA")] physical_plan 01)SortPreservingMergeExec: [revenue@1 DESC] -02)--SortExec: expr=[revenue@1 DESC] -03)----ProjectionExec: expr=[n_name@0 as n_name, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 as revenue] -04)------AggregateExec: mode=FinalPartitioned, gby=[n_name@0 as n_name], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +02)--SortExec: expr=[revenue@1 DESC], preserve_partitioning=[true] +03)----ProjectionExec: expr=[n_name@0 as n_name, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@1 as revenue] +04)------AggregateExec: mode=FinalPartitioned, gby=[n_name@0 as n_name], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] 05)--------CoalesceBatchesExec: target_batch_size=8192 06)----------RepartitionExec: partitioning=Hash([n_name@0], 4), input_partitions=4 -07)------------AggregateExec: mode=Partial, gby=[n_name@2 as n_name], aggr=[SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] +07)------------AggregateExec: mode=Partial, gby=[n_name@2 as n_name], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] 08)--------------CoalesceBatchesExec: target_batch_size=8192 09)----------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(n_regionkey@3, r_regionkey@0)], projection=[l_extendedprice@0, l_discount@1, n_name@2] 10)------------------CoalesceBatchesExec: target_batch_size=8192 @@ -98,28 +98,26 @@ physical_plan 29)--------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_nationkey], has_header=false 30)--------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 31)----------------------------------------------------RepartitionExec: partitioning=Hash([o_custkey@1], 4), input_partitions=4 -32)------------------------------------------------------ProjectionExec: expr=[o_orderkey@0 as o_orderkey, o_custkey@1 as o_custkey] -33)--------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -34)----------------------------------------------------------FilterExec: o_orderdate@2 >= 8766 AND o_orderdate@2 < 9131 -35)------------------------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_custkey, o_orderdate], has_header=false -36)------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -37)--------------------------------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4 -38)----------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_suppkey, l_extendedprice, l_discount], has_header=false -39)----------------------------------CoalesceBatchesExec: target_batch_size=8192 -40)------------------------------------RepartitionExec: partitioning=Hash([s_suppkey@0, s_nationkey@1], 4), input_partitions=4 -41)--------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -42)----------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], has_header=false -43)--------------------------CoalesceBatchesExec: target_batch_size=8192 -44)----------------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 -45)------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -46)--------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name, n_regionkey], has_header=false -47)------------------CoalesceBatchesExec: target_batch_size=8192 -48)--------------------RepartitionExec: partitioning=Hash([r_regionkey@0], 4), input_partitions=4 -49)----------------------ProjectionExec: expr=[r_regionkey@0 as r_regionkey] -50)------------------------CoalesceBatchesExec: target_batch_size=8192 -51)--------------------------FilterExec: r_name@1 = ASIA -52)----------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -53)------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/region.tbl]]}, projection=[r_regionkey, r_name], has_header=false +32)------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 +33)--------------------------------------------------------FilterExec: o_orderdate@2 >= 1994-01-01 AND o_orderdate@2 < 1995-01-01, projection=[o_orderkey@0, o_custkey@1] +34)----------------------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_custkey, o_orderdate], has_header=false +35)------------------------------------------CoalesceBatchesExec: target_batch_size=8192 +36)--------------------------------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4 +37)----------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_suppkey, l_extendedprice, l_discount], has_header=false +38)----------------------------------CoalesceBatchesExec: target_batch_size=8192 +39)------------------------------------RepartitionExec: partitioning=Hash([s_suppkey@0, s_nationkey@1], 4), input_partitions=4 +40)--------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +41)----------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], has_header=false +42)--------------------------CoalesceBatchesExec: target_batch_size=8192 +43)----------------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 +44)------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +45)--------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name, n_regionkey], has_header=false +46)------------------CoalesceBatchesExec: target_batch_size=8192 +47)--------------------RepartitionExec: partitioning=Hash([r_regionkey@0], 4), input_partitions=4 +48)----------------------CoalesceBatchesExec: target_batch_size=8192 +49)------------------------FilterExec: r_name@1 = ASIA, projection=[r_regionkey@0] +50)--------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +51)----------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/region.tbl]]}, projection=[r_regionkey, r_name], has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/q6.slt.part b/datafusion/sqllogictest/test_files/tpch/q6.slt.part index e54b3c1ccd03..548d26972f14 100644 --- a/datafusion/sqllogictest/test_files/tpch/q6.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q6.slt.part @@ -28,20 +28,19 @@ where and l_quantity < 24; ---- logical_plan -01)Projection: SUM(lineitem.l_extendedprice * lineitem.l_discount) AS revenue -02)--Aggregate: groupBy=[[]], aggr=[[SUM(lineitem.l_extendedprice * lineitem.l_discount)]] +01)Projection: sum(lineitem.l_extendedprice * lineitem.l_discount) AS revenue +02)--Aggregate: groupBy=[[]], aggr=[[sum(lineitem.l_extendedprice * lineitem.l_discount)]] 03)----Projection: lineitem.l_extendedprice, lineitem.l_discount -04)------Filter: lineitem.l_shipdate >= Date32("8766") AND lineitem.l_shipdate < Date32("9131") AND lineitem.l_discount >= Decimal128(Some(5),15,2) AND lineitem.l_discount <= Decimal128(Some(7),15,2) AND lineitem.l_quantity < Decimal128(Some(2400),15,2) -05)--------TableScan: lineitem projection=[l_quantity, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("8766"), lineitem.l_shipdate < Date32("9131"), lineitem.l_discount >= Decimal128(Some(5),15,2), lineitem.l_discount <= Decimal128(Some(7),15,2), lineitem.l_quantity < Decimal128(Some(2400),15,2)] +04)------Filter: lineitem.l_shipdate >= Date32("1994-01-01") AND lineitem.l_shipdate < Date32("1995-01-01") AND lineitem.l_discount >= Decimal128(Some(5),15,2) AND lineitem.l_discount <= Decimal128(Some(7),15,2) AND lineitem.l_quantity < Decimal128(Some(2400),15,2) +05)--------TableScan: lineitem projection=[l_quantity, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("1994-01-01"), lineitem.l_shipdate < Date32("1995-01-01"), lineitem.l_discount >= Decimal128(Some(5),15,2), lineitem.l_discount <= Decimal128(Some(7),15,2), lineitem.l_quantity < Decimal128(Some(2400),15,2)] physical_plan -01)ProjectionExec: expr=[SUM(lineitem.l_extendedprice * lineitem.l_discount)@0 as revenue] -02)--AggregateExec: mode=Final, gby=[], aggr=[SUM(lineitem.l_extendedprice * lineitem.l_discount)] +01)ProjectionExec: expr=[sum(lineitem.l_extendedprice * lineitem.l_discount)@0 as revenue] +02)--AggregateExec: mode=Final, gby=[], aggr=[sum(lineitem.l_extendedprice * lineitem.l_discount)] 03)----CoalescePartitionsExec -04)------AggregateExec: mode=Partial, gby=[], aggr=[SUM(lineitem.l_extendedprice * lineitem.l_discount)] -05)--------ProjectionExec: expr=[l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount] -06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------FilterExec: l_shipdate@3 >= 8766 AND l_shipdate@3 < 9131 AND l_discount@2 >= Some(5),15,2 AND l_discount@2 <= Some(7),15,2 AND l_quantity@0 < Some(2400),15,2 -08)--------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_quantity, l_extendedprice, l_discount, l_shipdate], has_header=false +04)------AggregateExec: mode=Partial, gby=[], aggr=[sum(lineitem.l_extendedprice * lineitem.l_discount)] +05)--------CoalesceBatchesExec: target_batch_size=8192 +06)----------FilterExec: l_shipdate@3 >= 1994-01-01 AND l_shipdate@3 < 1995-01-01 AND l_discount@2 >= Some(5),15,2 AND l_discount@2 <= Some(7),15,2 AND l_quantity@0 < Some(2400),15,2, projection=[l_extendedprice@1, l_discount@2] +07)------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_quantity, l_extendedprice, l_discount, l_shipdate], has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/q7.slt.part b/datafusion/sqllogictest/test_files/tpch/q7.slt.part index f699c3ee6734..a16af4710478 100644 --- a/datafusion/sqllogictest/test_files/tpch/q7.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q7.slt.part @@ -59,8 +59,8 @@ order by ---- logical_plan 01)Sort: shipping.supp_nation ASC NULLS LAST, shipping.cust_nation ASC NULLS LAST, shipping.l_year ASC NULLS LAST -02)--Projection: shipping.supp_nation, shipping.cust_nation, shipping.l_year, SUM(shipping.volume) AS revenue -03)----Aggregate: groupBy=[[shipping.supp_nation, shipping.cust_nation, shipping.l_year]], aggr=[[SUM(shipping.volume)]] +02)--Projection: shipping.supp_nation, shipping.cust_nation, shipping.l_year, sum(shipping.volume) AS revenue +03)----Aggregate: groupBy=[[shipping.supp_nation, shipping.cust_nation, shipping.l_year]], aggr=[[sum(shipping.volume)]] 04)------SubqueryAlias: shipping 05)--------Projection: n1.n_name AS supp_nation, n2.n_name AS cust_nation, date_part(Utf8("YEAR"), lineitem.l_shipdate) AS l_year, lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS volume 06)----------Inner Join: customer.c_nationkey = n2.n_nationkey Filter: n1.n_name = Utf8("FRANCE") AND n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY") AND n2.n_name = Utf8("FRANCE") @@ -73,8 +73,8 @@ logical_plan 13)------------------------Projection: supplier.s_nationkey, lineitem.l_orderkey, lineitem.l_extendedprice, lineitem.l_discount, lineitem.l_shipdate 14)--------------------------Inner Join: supplier.s_suppkey = lineitem.l_suppkey 15)----------------------------TableScan: supplier projection=[s_suppkey, s_nationkey] -16)----------------------------Filter: lineitem.l_shipdate >= Date32("9131") AND lineitem.l_shipdate <= Date32("9861") -17)------------------------------TableScan: lineitem projection=[l_orderkey, l_suppkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("9131"), lineitem.l_shipdate <= Date32("9861")] +16)----------------------------Filter: lineitem.l_shipdate >= Date32("1995-01-01") AND lineitem.l_shipdate <= Date32("1996-12-31") +17)------------------------------TableScan: lineitem projection=[l_orderkey, l_suppkey, l_extendedprice, l_discount, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("1995-01-01"), lineitem.l_shipdate <= Date32("1996-12-31")] 18)------------------------TableScan: orders projection=[o_orderkey, o_custkey] 19)--------------------TableScan: customer projection=[c_custkey, c_nationkey] 20)----------------SubqueryAlias: n1 @@ -84,13 +84,13 @@ logical_plan 24)--------------Filter: nation.n_name = Utf8("GERMANY") OR nation.n_name = Utf8("FRANCE") 25)----------------TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY") OR nation.n_name = Utf8("FRANCE")] physical_plan -01)SortPreservingMergeExec: [supp_nation@0 ASC NULLS LAST,cust_nation@1 ASC NULLS LAST,l_year@2 ASC NULLS LAST] -02)--SortExec: expr=[supp_nation@0 ASC NULLS LAST,cust_nation@1 ASC NULLS LAST,l_year@2 ASC NULLS LAST] -03)----ProjectionExec: expr=[supp_nation@0 as supp_nation, cust_nation@1 as cust_nation, l_year@2 as l_year, SUM(shipping.volume)@3 as revenue] -04)------AggregateExec: mode=FinalPartitioned, gby=[supp_nation@0 as supp_nation, cust_nation@1 as cust_nation, l_year@2 as l_year], aggr=[SUM(shipping.volume)] +01)SortPreservingMergeExec: [supp_nation@0 ASC NULLS LAST, cust_nation@1 ASC NULLS LAST, l_year@2 ASC NULLS LAST] +02)--SortExec: expr=[supp_nation@0 ASC NULLS LAST, cust_nation@1 ASC NULLS LAST, l_year@2 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[supp_nation@0 as supp_nation, cust_nation@1 as cust_nation, l_year@2 as l_year, sum(shipping.volume)@3 as revenue] +04)------AggregateExec: mode=FinalPartitioned, gby=[supp_nation@0 as supp_nation, cust_nation@1 as cust_nation, l_year@2 as l_year], aggr=[sum(shipping.volume)] 05)--------CoalesceBatchesExec: target_batch_size=8192 06)----------RepartitionExec: partitioning=Hash([supp_nation@0, cust_nation@1, l_year@2], 4), input_partitions=4 -07)------------AggregateExec: mode=Partial, gby=[supp_nation@0 as supp_nation, cust_nation@1 as cust_nation, l_year@2 as l_year], aggr=[SUM(shipping.volume)] +07)------------AggregateExec: mode=Partial, gby=[supp_nation@0 as supp_nation, cust_nation@1 as cust_nation, l_year@2 as l_year], aggr=[sum(shipping.volume)] 08)--------------ProjectionExec: expr=[n_name@3 as supp_nation, n_name@4 as cust_nation, date_part(YEAR, l_shipdate@2) as l_year, l_extendedprice@0 * (Some(1),20,0 - l_discount@1) as volume] 09)----------------CoalesceBatchesExec: target_batch_size=8192 10)------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(c_nationkey@3, n_nationkey@0)], filter=n_name@0 = FRANCE AND n_name@1 = GERMANY OR n_name@0 = GERMANY AND n_name@1 = FRANCE, projection=[l_extendedprice@0, l_discount@1, l_shipdate@2, n_name@4, n_name@6] @@ -117,7 +117,7 @@ physical_plan 31)----------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 32)------------------------------------------------------RepartitionExec: partitioning=Hash([l_suppkey@1], 4), input_partitions=4 33)--------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -34)----------------------------------------------------------FilterExec: l_shipdate@4 >= 9131 AND l_shipdate@4 <= 9861 +34)----------------------------------------------------------FilterExec: l_shipdate@4 >= 1995-01-01 AND l_shipdate@4 <= 1996-12-31 35)------------------------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_suppkey, l_extendedprice, l_discount, l_shipdate], has_header=false 36)--------------------------------------------CoalesceBatchesExec: target_batch_size=8192 37)----------------------------------------------RepartitionExec: partitioning=Hash([o_orderkey@0], 4), input_partitions=4 diff --git a/datafusion/sqllogictest/test_files/tpch/q8.slt.part b/datafusion/sqllogictest/test_files/tpch/q8.slt.part index ec89bd0c0806..fd5773438466 100644 --- a/datafusion/sqllogictest/test_files/tpch/q8.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q8.slt.part @@ -57,8 +57,8 @@ order by ---- logical_plan 01)Sort: all_nations.o_year ASC NULLS LAST -02)--Projection: all_nations.o_year, CAST(CAST(SUM(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Int64(0) END) AS Decimal128(12, 2)) / CAST(SUM(all_nations.volume) AS Decimal128(12, 2)) AS Decimal128(15, 2)) AS mkt_share -03)----Aggregate: groupBy=[[all_nations.o_year]], aggr=[[SUM(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Decimal128(Some(0),38,4) END) AS SUM(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Int64(0) END), SUM(all_nations.volume)]] +02)--Projection: all_nations.o_year, CAST(CAST(sum(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Int64(0) END) AS Decimal128(12, 2)) / CAST(sum(all_nations.volume) AS Decimal128(12, 2)) AS Decimal128(15, 2)) AS mkt_share +03)----Aggregate: groupBy=[[all_nations.o_year]], aggr=[[sum(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Decimal128(Some(0),38,4) END) AS sum(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Int64(0) END), sum(all_nations.volume)]] 04)------SubqueryAlias: all_nations 05)--------Projection: date_part(Utf8("YEAR"), orders.o_orderdate) AS o_year, lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS volume, n2.n_name AS nation 06)----------Inner Join: n1.n_regionkey = region.r_regionkey @@ -79,8 +79,8 @@ logical_plan 21)----------------------------------------TableScan: part projection=[p_partkey, p_type], partial_filters=[part.p_type = Utf8("ECONOMY ANODIZED STEEL")] 22)------------------------------------TableScan: lineitem projection=[l_orderkey, l_partkey, l_suppkey, l_extendedprice, l_discount] 23)--------------------------------TableScan: supplier projection=[s_suppkey, s_nationkey] -24)----------------------------Filter: orders.o_orderdate >= Date32("9131") AND orders.o_orderdate <= Date32("9861") -25)------------------------------TableScan: orders projection=[o_orderkey, o_custkey, o_orderdate], partial_filters=[orders.o_orderdate >= Date32("9131"), orders.o_orderdate <= Date32("9861")] +24)----------------------------Filter: orders.o_orderdate >= Date32("1995-01-01") AND orders.o_orderdate <= Date32("1996-12-31") +25)------------------------------TableScan: orders projection=[o_orderkey, o_custkey, o_orderdate], partial_filters=[orders.o_orderdate >= Date32("1995-01-01"), orders.o_orderdate <= Date32("1996-12-31")] 26)------------------------TableScan: customer projection=[c_custkey, c_nationkey] 27)--------------------SubqueryAlias: n1 28)----------------------TableScan: nation projection=[n_nationkey, n_regionkey] @@ -91,12 +91,12 @@ logical_plan 33)----------------TableScan: region projection=[r_regionkey, r_name], partial_filters=[region.r_name = Utf8("AMERICA")] physical_plan 01)SortPreservingMergeExec: [o_year@0 ASC NULLS LAST] -02)--SortExec: expr=[o_year@0 ASC NULLS LAST] -03)----ProjectionExec: expr=[o_year@0 as o_year, CAST(CAST(SUM(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Int64(0) END)@1 AS Decimal128(12, 2)) / CAST(SUM(all_nations.volume)@2 AS Decimal128(12, 2)) AS Decimal128(15, 2)) as mkt_share] -04)------AggregateExec: mode=FinalPartitioned, gby=[o_year@0 as o_year], aggr=[SUM(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Int64(0) END), SUM(all_nations.volume)] +02)--SortExec: expr=[o_year@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[o_year@0 as o_year, CAST(CAST(sum(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Int64(0) END)@1 AS Decimal128(12, 2)) / CAST(sum(all_nations.volume)@2 AS Decimal128(12, 2)) AS Decimal128(15, 2)) as mkt_share] +04)------AggregateExec: mode=FinalPartitioned, gby=[o_year@0 as o_year], aggr=[sum(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Int64(0) END), sum(all_nations.volume)] 05)--------CoalesceBatchesExec: target_batch_size=8192 06)----------RepartitionExec: partitioning=Hash([o_year@0], 4), input_partitions=4 -07)------------AggregateExec: mode=Partial, gby=[o_year@0 as o_year], aggr=[SUM(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Int64(0) END), SUM(all_nations.volume)] +07)------------AggregateExec: mode=Partial, gby=[o_year@0 as o_year], aggr=[sum(CASE WHEN all_nations.nation = Utf8("BRAZIL") THEN all_nations.volume ELSE Int64(0) END), sum(all_nations.volume)] 08)--------------ProjectionExec: expr=[date_part(YEAR, o_orderdate@2) as o_year, l_extendedprice@0 * (Some(1),20,0 - l_discount@1) as volume, n_name@3 as nation] 09)----------------CoalesceBatchesExec: target_batch_size=8192 10)------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(n_regionkey@3, r_regionkey@0)], projection=[l_extendedprice@0, l_discount@1, o_orderdate@2, n_name@4] @@ -126,42 +126,40 @@ physical_plan 34)------------------------------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(p_partkey@0, l_partkey@1)], projection=[l_orderkey@1, l_suppkey@3, l_extendedprice@4, l_discount@5] 35)--------------------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 36)----------------------------------------------------------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 -37)------------------------------------------------------------------------ProjectionExec: expr=[p_partkey@0 as p_partkey] -38)--------------------------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -39)----------------------------------------------------------------------------FilterExec: p_type@1 = ECONOMY ANODIZED STEEL -40)------------------------------------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -41)--------------------------------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_type], has_header=false -42)--------------------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -43)----------------------------------------------------------------------RepartitionExec: partitioning=Hash([l_partkey@1], 4), input_partitions=4 -44)------------------------------------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_partkey, l_suppkey, l_extendedprice, l_discount], has_header=false -45)------------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -46)--------------------------------------------------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 -47)----------------------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -48)------------------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], has_header=false -49)----------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -50)------------------------------------------------------RepartitionExec: partitioning=Hash([o_orderkey@0], 4), input_partitions=4 -51)--------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -52)----------------------------------------------------------FilterExec: o_orderdate@2 >= 9131 AND o_orderdate@2 <= 9861 -53)------------------------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_custkey, o_orderdate], has_header=false -54)--------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -55)----------------------------------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4 -56)------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -57)--------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_nationkey], has_header=false -58)------------------------------------CoalesceBatchesExec: target_batch_size=8192 -59)--------------------------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 -60)----------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -61)------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_regionkey], has_header=false -62)----------------------------CoalesceBatchesExec: target_batch_size=8192 -63)------------------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 -64)--------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -65)----------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false -66)--------------------CoalesceBatchesExec: target_batch_size=8192 -67)----------------------RepartitionExec: partitioning=Hash([r_regionkey@0], 4), input_partitions=4 -68)------------------------ProjectionExec: expr=[r_regionkey@0 as r_regionkey] -69)--------------------------CoalesceBatchesExec: target_batch_size=8192 -70)----------------------------FilterExec: r_name@1 = AMERICA -71)------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -72)--------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/region.tbl]]}, projection=[r_regionkey, r_name], has_header=false +37)------------------------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 +38)--------------------------------------------------------------------------FilterExec: p_type@1 = ECONOMY ANODIZED STEEL, projection=[p_partkey@0] +39)----------------------------------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +40)------------------------------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_type], has_header=false +41)--------------------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 +42)----------------------------------------------------------------------RepartitionExec: partitioning=Hash([l_partkey@1], 4), input_partitions=4 +43)------------------------------------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_partkey, l_suppkey, l_extendedprice, l_discount], has_header=false +44)------------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 +45)--------------------------------------------------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 +46)----------------------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +47)------------------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], has_header=false +48)----------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 +49)------------------------------------------------------RepartitionExec: partitioning=Hash([o_orderkey@0], 4), input_partitions=4 +50)--------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 +51)----------------------------------------------------------FilterExec: o_orderdate@2 >= 1995-01-01 AND o_orderdate@2 <= 1996-12-31 +52)------------------------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_custkey, o_orderdate], has_header=false +53)--------------------------------------------CoalesceBatchesExec: target_batch_size=8192 +54)----------------------------------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4 +55)------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +56)--------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_nationkey], has_header=false +57)------------------------------------CoalesceBatchesExec: target_batch_size=8192 +58)--------------------------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 +59)----------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +60)------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_regionkey], has_header=false +61)----------------------------CoalesceBatchesExec: target_batch_size=8192 +62)------------------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 +63)--------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +64)----------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false +65)--------------------CoalesceBatchesExec: target_batch_size=8192 +66)----------------------RepartitionExec: partitioning=Hash([r_regionkey@0], 4), input_partitions=4 +67)------------------------CoalesceBatchesExec: target_batch_size=8192 +68)--------------------------FilterExec: r_name@1 = AMERICA, projection=[r_regionkey@0] +69)----------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +70)------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/region.tbl]]}, projection=[r_regionkey, r_name], has_header=false diff --git a/datafusion/sqllogictest/test_files/tpch/q9.slt.part b/datafusion/sqllogictest/test_files/tpch/q9.slt.part index 61ed162aa712..c4910beb842b 100644 --- a/datafusion/sqllogictest/test_files/tpch/q9.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q9.slt.part @@ -52,81 +52,78 @@ order by limit 10; ---- logical_plan -01)Limit: skip=0, fetch=10 -02)--Sort: profit.nation ASC NULLS LAST, profit.o_year DESC NULLS FIRST, fetch=10 -03)----Projection: profit.nation, profit.o_year, SUM(profit.amount) AS sum_profit -04)------Aggregate: groupBy=[[profit.nation, profit.o_year]], aggr=[[SUM(profit.amount)]] -05)--------SubqueryAlias: profit -06)----------Projection: nation.n_name AS nation, date_part(Utf8("YEAR"), orders.o_orderdate) AS o_year, lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) - partsupp.ps_supplycost * lineitem.l_quantity AS amount -07)------------Inner Join: supplier.s_nationkey = nation.n_nationkey -08)--------------Projection: lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, supplier.s_nationkey, partsupp.ps_supplycost, orders.o_orderdate -09)----------------Inner Join: lineitem.l_orderkey = orders.o_orderkey -10)------------------Projection: lineitem.l_orderkey, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, supplier.s_nationkey, partsupp.ps_supplycost -11)--------------------Inner Join: lineitem.l_suppkey = partsupp.ps_suppkey, lineitem.l_partkey = partsupp.ps_partkey -12)----------------------Projection: lineitem.l_orderkey, lineitem.l_partkey, lineitem.l_suppkey, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, supplier.s_nationkey -13)------------------------Inner Join: lineitem.l_suppkey = supplier.s_suppkey -14)--------------------------Projection: lineitem.l_orderkey, lineitem.l_partkey, lineitem.l_suppkey, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount -15)----------------------------Inner Join: part.p_partkey = lineitem.l_partkey -16)------------------------------Projection: part.p_partkey -17)--------------------------------Filter: part.p_name LIKE Utf8("%green%") -18)----------------------------------TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8("%green%")] -19)------------------------------TableScan: lineitem projection=[l_orderkey, l_partkey, l_suppkey, l_quantity, l_extendedprice, l_discount] -20)--------------------------TableScan: supplier projection=[s_suppkey, s_nationkey] -21)----------------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost] -22)------------------TableScan: orders projection=[o_orderkey, o_orderdate] -23)--------------TableScan: nation projection=[n_nationkey, n_name] +01)Sort: profit.nation ASC NULLS LAST, profit.o_year DESC NULLS FIRST, fetch=10 +02)--Projection: profit.nation, profit.o_year, sum(profit.amount) AS sum_profit +03)----Aggregate: groupBy=[[profit.nation, profit.o_year]], aggr=[[sum(profit.amount)]] +04)------SubqueryAlias: profit +05)--------Projection: nation.n_name AS nation, date_part(Utf8("YEAR"), orders.o_orderdate) AS o_year, lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) - partsupp.ps_supplycost * lineitem.l_quantity AS amount +06)----------Inner Join: supplier.s_nationkey = nation.n_nationkey +07)------------Projection: lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, supplier.s_nationkey, partsupp.ps_supplycost, orders.o_orderdate +08)--------------Inner Join: lineitem.l_orderkey = orders.o_orderkey +09)----------------Projection: lineitem.l_orderkey, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, supplier.s_nationkey, partsupp.ps_supplycost +10)------------------Inner Join: lineitem.l_suppkey = partsupp.ps_suppkey, lineitem.l_partkey = partsupp.ps_partkey +11)--------------------Projection: lineitem.l_orderkey, lineitem.l_partkey, lineitem.l_suppkey, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, supplier.s_nationkey +12)----------------------Inner Join: lineitem.l_suppkey = supplier.s_suppkey +13)------------------------Projection: lineitem.l_orderkey, lineitem.l_partkey, lineitem.l_suppkey, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount +14)--------------------------Inner Join: part.p_partkey = lineitem.l_partkey +15)----------------------------Projection: part.p_partkey +16)------------------------------Filter: part.p_name LIKE Utf8("%green%") +17)--------------------------------TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8("%green%")] +18)----------------------------TableScan: lineitem projection=[l_orderkey, l_partkey, l_suppkey, l_quantity, l_extendedprice, l_discount] +19)------------------------TableScan: supplier projection=[s_suppkey, s_nationkey] +20)--------------------TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_supplycost] +21)----------------TableScan: orders projection=[o_orderkey, o_orderdate] +22)------------TableScan: nation projection=[n_nationkey, n_name] physical_plan -01)GlobalLimitExec: skip=0, fetch=10 -02)--SortPreservingMergeExec: [nation@0 ASC NULLS LAST,o_year@1 DESC], fetch=10 -03)----SortExec: TopK(fetch=10), expr=[nation@0 ASC NULLS LAST,o_year@1 DESC] -04)------ProjectionExec: expr=[nation@0 as nation, o_year@1 as o_year, SUM(profit.amount)@2 as sum_profit] -05)--------AggregateExec: mode=FinalPartitioned, gby=[nation@0 as nation, o_year@1 as o_year], aggr=[SUM(profit.amount)] -06)----------CoalesceBatchesExec: target_batch_size=8192 -07)------------RepartitionExec: partitioning=Hash([nation@0, o_year@1], 4), input_partitions=4 -08)--------------AggregateExec: mode=Partial, gby=[nation@0 as nation, o_year@1 as o_year], aggr=[SUM(profit.amount)] -09)----------------ProjectionExec: expr=[n_name@5 as nation, date_part(YEAR, o_orderdate@4) as o_year, l_extendedprice@1 * (Some(1),20,0 - l_discount@2) - ps_supplycost@3 * l_quantity@0 as amount] -10)------------------CoalesceBatchesExec: target_batch_size=8192 -11)--------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@3, n_nationkey@0)], projection=[l_quantity@0, l_extendedprice@1, l_discount@2, ps_supplycost@4, o_orderdate@5, n_name@7] -12)----------------------CoalesceBatchesExec: target_batch_size=8192 -13)------------------------RepartitionExec: partitioning=Hash([s_nationkey@3], 4), input_partitions=4 -14)--------------------------CoalesceBatchesExec: target_batch_size=8192 -15)----------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_orderkey@0, o_orderkey@0)], projection=[l_quantity@1, l_extendedprice@2, l_discount@3, s_nationkey@4, ps_supplycost@5, o_orderdate@7] -16)------------------------------CoalesceBatchesExec: target_batch_size=8192 -17)--------------------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4 -18)----------------------------------CoalesceBatchesExec: target_batch_size=8192 -19)------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_suppkey@2, ps_suppkey@1), (l_partkey@1, ps_partkey@0)], projection=[l_orderkey@0, l_quantity@3, l_extendedprice@4, l_discount@5, s_nationkey@6, ps_supplycost@9] -20)--------------------------------------CoalesceBatchesExec: target_batch_size=8192 -21)----------------------------------------RepartitionExec: partitioning=Hash([l_suppkey@2, l_partkey@1], 4), input_partitions=4 -22)------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -23)--------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_suppkey@2, s_suppkey@0)], projection=[l_orderkey@0, l_partkey@1, l_suppkey@2, l_quantity@3, l_extendedprice@4, l_discount@5, s_nationkey@7] -24)----------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -25)------------------------------------------------RepartitionExec: partitioning=Hash([l_suppkey@2], 4), input_partitions=4 -26)--------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -27)----------------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(p_partkey@0, l_partkey@1)], projection=[l_orderkey@1, l_partkey@2, l_suppkey@3, l_quantity@4, l_extendedprice@5, l_discount@6] -28)------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -29)--------------------------------------------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 -30)----------------------------------------------------------ProjectionExec: expr=[p_partkey@0 as p_partkey] -31)------------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -32)--------------------------------------------------------------FilterExec: p_name@1 LIKE %green% -33)----------------------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -34)------------------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_name], has_header=false -35)------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -36)--------------------------------------------------------RepartitionExec: partitioning=Hash([l_partkey@1], 4), input_partitions=4 -37)----------------------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_partkey, l_suppkey, l_quantity, l_extendedprice, l_discount], has_header=false -38)----------------------------------------------CoalesceBatchesExec: target_batch_size=8192 -39)------------------------------------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 -40)--------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -41)----------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], has_header=false -42)--------------------------------------CoalesceBatchesExec: target_batch_size=8192 -43)----------------------------------------RepartitionExec: partitioning=Hash([ps_suppkey@1, ps_partkey@0], 4), input_partitions=4 -44)------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey, ps_supplycost], has_header=false -45)------------------------------CoalesceBatchesExec: target_batch_size=8192 -46)--------------------------------RepartitionExec: partitioning=Hash([o_orderkey@0], 4), input_partitions=4 -47)----------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_orderdate], has_header=false -48)----------------------CoalesceBatchesExec: target_batch_size=8192 -49)------------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 -50)--------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -51)----------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false +01)SortPreservingMergeExec: [nation@0 ASC NULLS LAST, o_year@1 DESC], fetch=10 +02)--SortExec: TopK(fetch=10), expr=[nation@0 ASC NULLS LAST, o_year@1 DESC], preserve_partitioning=[true] +03)----ProjectionExec: expr=[nation@0 as nation, o_year@1 as o_year, sum(profit.amount)@2 as sum_profit] +04)------AggregateExec: mode=FinalPartitioned, gby=[nation@0 as nation, o_year@1 as o_year], aggr=[sum(profit.amount)] +05)--------CoalesceBatchesExec: target_batch_size=8192 +06)----------RepartitionExec: partitioning=Hash([nation@0, o_year@1], 4), input_partitions=4 +07)------------AggregateExec: mode=Partial, gby=[nation@0 as nation, o_year@1 as o_year], aggr=[sum(profit.amount)] +08)--------------ProjectionExec: expr=[n_name@5 as nation, date_part(YEAR, o_orderdate@4) as o_year, l_extendedprice@1 * (Some(1),20,0 - l_discount@2) - ps_supplycost@3 * l_quantity@0 as amount] +09)----------------CoalesceBatchesExec: target_batch_size=8192 +10)------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s_nationkey@3, n_nationkey@0)], projection=[l_quantity@0, l_extendedprice@1, l_discount@2, ps_supplycost@4, o_orderdate@5, n_name@7] +11)--------------------CoalesceBatchesExec: target_batch_size=8192 +12)----------------------RepartitionExec: partitioning=Hash([s_nationkey@3], 4), input_partitions=4 +13)------------------------CoalesceBatchesExec: target_batch_size=8192 +14)--------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_orderkey@0, o_orderkey@0)], projection=[l_quantity@1, l_extendedprice@2, l_discount@3, s_nationkey@4, ps_supplycost@5, o_orderdate@7] +15)----------------------------CoalesceBatchesExec: target_batch_size=8192 +16)------------------------------RepartitionExec: partitioning=Hash([l_orderkey@0], 4), input_partitions=4 +17)--------------------------------CoalesceBatchesExec: target_batch_size=8192 +18)----------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_suppkey@2, ps_suppkey@1), (l_partkey@1, ps_partkey@0)], projection=[l_orderkey@0, l_quantity@3, l_extendedprice@4, l_discount@5, s_nationkey@6, ps_supplycost@9] +19)------------------------------------CoalesceBatchesExec: target_batch_size=8192 +20)--------------------------------------RepartitionExec: partitioning=Hash([l_suppkey@2, l_partkey@1], 4), input_partitions=4 +21)----------------------------------------CoalesceBatchesExec: target_batch_size=8192 +22)------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_suppkey@2, s_suppkey@0)], projection=[l_orderkey@0, l_partkey@1, l_suppkey@2, l_quantity@3, l_extendedprice@4, l_discount@5, s_nationkey@7] +23)--------------------------------------------CoalesceBatchesExec: target_batch_size=8192 +24)----------------------------------------------RepartitionExec: partitioning=Hash([l_suppkey@2], 4), input_partitions=4 +25)------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 +26)--------------------------------------------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(p_partkey@0, l_partkey@1)], projection=[l_orderkey@1, l_partkey@2, l_suppkey@3, l_quantity@4, l_extendedprice@5, l_discount@6] +27)----------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 +28)------------------------------------------------------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4 +29)--------------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 +30)----------------------------------------------------------FilterExec: p_name@1 LIKE %green%, projection=[p_partkey@0] +31)------------------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +32)--------------------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_name], has_header=false +33)----------------------------------------------------CoalesceBatchesExec: target_batch_size=8192 +34)------------------------------------------------------RepartitionExec: partitioning=Hash([l_partkey@1], 4), input_partitions=4 +35)--------------------------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_orderkey, l_partkey, l_suppkey, l_quantity, l_extendedprice, l_discount], has_header=false +36)--------------------------------------------CoalesceBatchesExec: target_batch_size=8192 +37)----------------------------------------------RepartitionExec: partitioning=Hash([s_suppkey@0], 4), input_partitions=4 +38)------------------------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +39)--------------------------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/supplier.tbl]]}, projection=[s_suppkey, s_nationkey], has_header=false +40)------------------------------------CoalesceBatchesExec: target_batch_size=8192 +41)--------------------------------------RepartitionExec: partitioning=Hash([ps_suppkey@1, ps_partkey@0], 4), input_partitions=4 +42)----------------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:0..2932049], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:2932049..5864098], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:5864098..8796147], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/partsupp.tbl:8796147..11728193]]}, projection=[ps_partkey, ps_suppkey, ps_supplycost], has_header=false +43)----------------------------CoalesceBatchesExec: target_batch_size=8192 +44)------------------------------RepartitionExec: partitioning=Hash([o_orderkey@0], 4), input_partitions=4 +45)--------------------------------CsvExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:0..4223281], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:4223281..8446562], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:8446562..12669843], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/orders.tbl:12669843..16893122]]}, projection=[o_orderkey, o_orderdate], has_header=false +46)--------------------CoalesceBatchesExec: target_batch_size=8192 +47)----------------------RepartitionExec: partitioning=Hash([n_nationkey@0], 4), input_partitions=4 +48)------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +49)--------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/nation.tbl]]}, projection=[n_nationkey, n_name], has_header=false diff --git a/datafusion/sqllogictest/test_files/type_coercion.slt b/datafusion/sqllogictest/test_files/type_coercion.slt index aa1e6826eca5..43e7c2f7bc25 100644 --- a/datafusion/sqllogictest/test_files/type_coercion.slt +++ b/datafusion/sqllogictest/test_files/type_coercion.slt @@ -49,3 +49,208 @@ select interval '1 month' - '2023-05-01'::date; # interval - timestamp query error DataFusion error: Error during planning: Cannot coerce arithmetic expression Interval\(MonthDayNano\) \- Timestamp\(Nanosecond, None\) to valid types SELECT interval '1 month' - '2023-05-01 12:30:00'::timestamp; + +# dictionary(int32, utf8) -> utf8 +query T +select arrow_cast('foo', 'Dictionary(Int32, Utf8)') || arrow_cast('bar', 'Dictionary(Int32, Utf8)'); +---- +foobar + +# dictionary(int32, largeUtf8) -> largeUtf8 +query T +select arrow_cast('foo', 'Dictionary(Int32, LargeUtf8)') || arrow_cast('bar', 'Dictionary(Int32, LargeUtf8)'); +---- +foobar + +#################################### +## Concat column dictionary test ## +#################################### +statement ok +create table t as values (arrow_cast('foo', 'Dictionary(Int32, Utf8)'), arrow_cast('bar', 'Dictionary(Int32, Utf8)')); + +query T +select column1 || column2 from t; +---- +foobar + +statement ok +DROP TABLE t + +####################################### +## Concat column dictionary test end ## +####################################### + +#################################### +## Test type coercion with UNIONs ## +#################################### + +# Disable optimizer to test only the analyzer with type coercion +statement ok +set datafusion.optimizer.max_passes = 0; + +statement ok +set datafusion.explain.logical_plan_only = true; + +# Create test table +statement ok +CREATE TABLE orders( + order_id INT UNSIGNED NOT NULL, + customer_id INT UNSIGNED NOT NULL, + o_item_id VARCHAR NOT NULL, + qty INT NOT NULL, + price DOUBLE NOT NULL, + delivered BOOLEAN NOT NULL +); + +# union_different_num_columns_error() / UNION +query error DataFusion error: Error during planning: UNION queries have different number of columns: left has 1 columns whereas right has 2 columns +SELECT order_id FROM orders UNION SELECT customer_id, o_item_id FROM orders + +# union_different_num_columns_error() / UNION ALL +query error DataFusion error: Error during planning: UNION queries have different number of columns: left has 1 columns whereas right has 2 columns +SELECT order_id FROM orders UNION ALL SELECT customer_id, o_item_id FROM orders + +# union_with_different_column_names() +query TT +EXPLAIN SELECT order_id from orders UNION ALL SELECT customer_id FROM orders +---- +logical_plan +01)Union +02)--Projection: orders.order_id +03)----TableScan: orders +04)--Projection: orders.customer_id AS order_id +05)----TableScan: orders + +# union_values_with_no_alias() +query TT +EXPLAIN SELECT 1, 2 UNION ALL SELECT 3, 4 +---- +logical_plan +01)Union +02)--Projection: Int64(1) AS Int64(1), Int64(2) AS Int64(2) +03)----EmptyRelation +04)--Projection: Int64(3) AS Int64(1), Int64(4) AS Int64(2) +05)----EmptyRelation + +# union_with_incompatible_data_type() +query error Incompatible inputs for Union: Previous inputs were of type Interval\(MonthDayNano\), but got incompatible type Int64 on column 'Int64\(1\)' +SELECT interval '1 year 1 day' UNION ALL SELECT 1 + +# union_with_different_decimal_data_types() +query TT +EXPLAIN SELECT 1 a UNION ALL SELECT 1.1 a +---- +logical_plan +01)Union +02)--Projection: CAST(Int64(1) AS Float64) AS a +03)----EmptyRelation +04)--Projection: Float64(1.1) AS a +05)----EmptyRelation + +# union_with_null() +query TT +EXPLAIN SELECT NULL a UNION ALL SELECT 1.1 a +---- +logical_plan +01)Union +02)--Projection: CAST(NULL AS Float64) AS a +03)----EmptyRelation +04)--Projection: Float64(1.1) AS a +05)----EmptyRelation + +# union_with_float_and_string() +query TT +EXPLAIN SELECT 'a' a UNION ALL SELECT 1.1 a +---- +logical_plan +01)Union +02)--Projection: Utf8("a") AS a +03)----EmptyRelation +04)--Projection: CAST(Float64(1.1) AS Utf8) AS a +05)----EmptyRelation + +# union_with_multiply_cols() +query TT +EXPLAIN SELECT 'a' a, 1 b UNION ALL SELECT 1.1 a, 1.1 b +---- +logical_plan +01)Union +02)--Projection: Utf8("a") AS a, CAST(Int64(1) AS Float64) AS b +03)----EmptyRelation +04)--Projection: CAST(Float64(1.1) AS Utf8) AS a, Float64(1.1) AS b +05)----EmptyRelation + +# sorted_union_with_different_types_and_group_by() +query TT +EXPLAIN SELECT a FROM (select 1 a) x GROUP BY 1 + UNION ALL +(SELECT a FROM (select 1.1 a) x GROUP BY 1) ORDER BY 1 +---- +logical_plan +01)Sort: x.a ASC NULLS LAST +02)--Union +03)----Projection: CAST(x.a AS Float64) AS a +04)------Aggregate: groupBy=[[x.a]], aggr=[[]] +05)--------SubqueryAlias: x +06)----------Projection: Int64(1) AS a +07)------------EmptyRelation +08)----Projection: x.a +09)------Aggregate: groupBy=[[x.a]], aggr=[[]] +10)--------SubqueryAlias: x +11)----------Projection: Float64(1.1) AS a +12)------------EmptyRelation + +# union_with_binary_expr_and_cast() +query TT +EXPLAIN SELECT cast(0.0 + a as integer) FROM (select 1 a) x GROUP BY 1 + UNION ALL +(SELECT 2.1 + a FROM (select 1 a) x GROUP BY 1) +---- +logical_plan +01)Union +02)--Projection: CAST(Float64(0) + x.a AS Float64) AS Float64(0) + x.a +03)----Aggregate: groupBy=[[CAST(Float64(0) + CAST(x.a AS Float64) AS Int32)]], aggr=[[]] +04)------SubqueryAlias: x +05)--------Projection: Int64(1) AS a +06)----------EmptyRelation +07)--Projection: Float64(2.1) + x.a AS Float64(0) + x.a +08)----Aggregate: groupBy=[[Float64(2.1) + CAST(x.a AS Float64)]], aggr=[[]] +09)------SubqueryAlias: x +10)--------Projection: Int64(1) AS a +11)----------EmptyRelation + +# union_with_aliases() +query TT +EXPLAIN SELECT a as a1 FROM (select 1 a) x GROUP BY 1 + UNION ALL +(SELECT a as a1 FROM (select 1.1 a) x GROUP BY 1) +---- +logical_plan +01)Union +02)--Projection: CAST(x.a AS Float64) AS a1 +03)----Aggregate: groupBy=[[x.a]], aggr=[[]] +04)------SubqueryAlias: x +05)--------Projection: Int64(1) AS a +06)----------EmptyRelation +07)--Projection: x.a AS a1 +08)----Aggregate: groupBy=[[x.a]], aggr=[[]] +09)------SubqueryAlias: x +10)--------Projection: Float64(1.1) AS a +11)----------EmptyRelation + +# union_with_incompatible_data_types() +query error Incompatible inputs for Union: Previous inputs were of type Utf8, but got incompatible type Boolean on column 'a' +SELECT 'a' a UNION ALL SELECT true a + +statement ok +SET datafusion.optimizer.max_passes = 3; + +statement ok +SET datafusion.explain.logical_plan_only = false; + +statement ok +DROP TABLE orders; + +######################################## +## Test type coercion with UNIONs end ## +######################################## diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index ba2563fd5a4d..fb7afdda2ea8 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -22,7 +22,7 @@ statement ok CREATE TABLE t1( id INT, - name TEXT, + name TEXT ) as VALUES (1, 'Alex'), (2, 'Bob'), @@ -32,7 +32,7 @@ CREATE TABLE t1( statement ok CREATE TABLE t2( id TINYINT, - name TEXT, + name TEXT ) as VALUES (1, 'Alex'), (2, 'Bob'), @@ -105,8 +105,8 @@ CREATE EXTERNAL TABLE aggregate_test_100 ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); query I select COUNT(*) from ( @@ -135,6 +135,36 @@ SELECT SUM(d) FROM ( ---- 5 +# three way union with aggregate and type coercion +query II rowsort +SELECT c1, SUM(c2) FROM ( + SELECT 1 as c1, 1::int as c2 + UNION + SELECT 2 as c1, 2::int as c2 + UNION + SELECT 3 as c1, COALESCE(3::int, 0) as c2 +) as a +GROUP BY c1 +---- +1 1 +2 2 +3 3 + +# This test goes through schema check in aggregate plan, if count's nullable is not matched, this test failed +query II rowsort +SELECT c1, SUM(c2) FROM ( + SELECT 1 as c1, 1::int as c2 + UNION + SELECT 2 as c1, 2::int as c2 + UNION + SELECT 3 as c1, count(1) as c2 +) as a +GROUP BY c1 +---- +1 1 +2 2 +3 1 + # union_all_with_count statement ok CREATE table t as SELECT 1 as a @@ -381,25 +411,23 @@ query TT explain SELECT c1, c9 FROM aggregate_test_100 UNION ALL SELECT c1, c3 FROM aggregate_test_100 ORDER BY c9 DESC LIMIT 5 ---- logical_plan -01)Limit: skip=0, fetch=5 -02)--Sort: aggregate_test_100.c9 DESC NULLS FIRST, fetch=5 -03)----Union -04)------Projection: aggregate_test_100.c1, CAST(aggregate_test_100.c9 AS Int64) AS c9 -05)--------TableScan: aggregate_test_100 projection=[c1, c9] -06)------Projection: aggregate_test_100.c1, CAST(aggregate_test_100.c3 AS Int64) AS c9 -07)--------TableScan: aggregate_test_100 projection=[c1, c3] +01)Sort: aggregate_test_100.c9 DESC NULLS FIRST, fetch=5 +02)--Union +03)----Projection: aggregate_test_100.c1, CAST(aggregate_test_100.c9 AS Int64) AS c9 +04)------TableScan: aggregate_test_100 projection=[c1, c9] +05)----Projection: aggregate_test_100.c1, CAST(aggregate_test_100.c3 AS Int64) AS c9 +06)------TableScan: aggregate_test_100 projection=[c1, c3] physical_plan -01)GlobalLimitExec: skip=0, fetch=5 -02)--SortPreservingMergeExec: [c9@1 DESC], fetch=5 -03)----UnionExec -04)------SortExec: expr=[c9@1 DESC] -05)--------ProjectionExec: expr=[c1@0 as c1, CAST(c9@1 AS Int64) as c9] -06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -07)------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c9], has_header=true -08)------SortExec: expr=[c9@1 DESC] -09)--------ProjectionExec: expr=[c1@0 as c1, CAST(c3@1 AS Int64) as c9] -10)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -11)------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c3], has_header=true +01)SortPreservingMergeExec: [c9@1 DESC], fetch=5 +02)--UnionExec +03)----SortExec: TopK(fetch=5), expr=[c9@1 DESC], preserve_partitioning=[true] +04)------ProjectionExec: expr=[c1@0 as c1, CAST(c9@1 AS Int64) as c9] +05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c9], has_header=true +07)----SortExec: TopK(fetch=5), expr=[c9@1 DESC], preserve_partitioning=[true] +08)------ProjectionExec: expr=[c1@0 as c1, CAST(c3@1 AS Int64) as c9] +09)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +10)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c3], has_header=true query TI SELECT c1, c9 FROM aggregate_test_100 UNION ALL SELECT c1, c3 FROM aggregate_test_100 ORDER BY c9 DESC LIMIT 5 @@ -420,16 +448,16 @@ SELECT count(*) FROM ( ) GROUP BY name ---- logical_plan -01)Projection: COUNT(*) -02)--Aggregate: groupBy=[[t1.name]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +01)Projection: count(*) +02)--Aggregate: groupBy=[[t1.name]], aggr=[[count(Int64(1)) AS count(*)]] 03)----Union 04)------Aggregate: groupBy=[[t1.name]], aggr=[[]] 05)--------TableScan: t1 projection=[name] 06)------Aggregate: groupBy=[[t2.name]], aggr=[[]] 07)--------TableScan: t2 projection=[name] physical_plan -01)ProjectionExec: expr=[COUNT(*)@1 as COUNT(*)] -02)--AggregateExec: mode=SinglePartitioned, gby=[name@0 as name], aggr=[COUNT(*)] +01)ProjectionExec: expr=[count(*)@1 as count(*)] +02)--AggregateExec: mode=SinglePartitioned, gby=[name@0 as name], aggr=[count(*)] 03)----InterleaveExec 04)------AggregateExec: mode=FinalPartitioned, gby=[name@0 as name], aggr=[] 05)--------CoalesceBatchesExec: target_batch_size=2 @@ -444,6 +472,67 @@ physical_plan 14)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 15)----------------MemoryExec: partitions=1, partition_sizes=[1] +# Union with limit push down 3 children test case +query TT +EXPLAIN + SELECT count(*) as cnt FROM + (SELECT count(*), c1 + FROM aggregate_test_100 + WHERE c13 != 'C2GT5KVyOPZpgKVl110TyZO0NcJ434' + GROUP BY c1 + ORDER BY c1 + ) AS a + UNION ALL + SELECT 1 as cnt + UNION ALL + SELECT lead(c1, 1) OVER () as cnt FROM (select 1 as c1) AS b + LIMIT 3 +---- +logical_plan +01)Limit: skip=0, fetch=3 +02)--Union +03)----Projection: count(*) AS cnt +04)------Limit: skip=0, fetch=3 +05)--------Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] +06)----------SubqueryAlias: a +07)------------Projection: +08)--------------Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[]] +09)----------------Projection: aggregate_test_100.c1 +10)------------------Filter: aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434") +11)--------------------TableScan: aggregate_test_100 projection=[c1, c13], partial_filters=[aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434")] +12)----Projection: Int64(1) AS cnt +13)------Limit: skip=0, fetch=3 +14)--------EmptyRelation +15)----Projection: lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS cnt +16)------Limit: skip=0, fetch=3 +17)--------WindowAggr: windowExpr=[[lead(b.c1, Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +18)----------SubqueryAlias: b +19)------------Projection: Int64(1) AS c1 +20)--------------EmptyRelation +physical_plan +01)GlobalLimitExec: skip=0, fetch=3 +02)--CoalescePartitionsExec +03)----UnionExec +04)------ProjectionExec: expr=[count(*)@0 as cnt] +05)--------AggregateExec: mode=Final, gby=[], aggr=[count(*)] +06)----------CoalescePartitionsExec +07)------------AggregateExec: mode=Partial, gby=[], aggr=[count(*)] +08)--------------ProjectionExec: expr=[] +09)----------------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[] +10)------------------CoalesceBatchesExec: target_batch_size=2 +11)--------------------RepartitionExec: partitioning=Hash([c1@0], 4), input_partitions=4 +12)----------------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[] +13)------------------------CoalesceBatchesExec: target_batch_size=2 +14)--------------------------FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434, projection=[c1@0] +15)----------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +16)------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c13], has_header=true +17)------ProjectionExec: expr=[1 as cnt] +18)--------PlaceholderRowExec +19)------ProjectionExec: expr=[lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as cnt] +20)--------BoundedWindowAggExec: wdw=[lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +21)----------ProjectionExec: expr=[1 as c1] +22)------------PlaceholderRowExec + ######## # Clean up after the test @@ -475,9 +564,9 @@ CREATE EXTERNAL TABLE t1 ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW WITH ORDER (c1 ASC) -LOCATION '../../testing/data/csv/aggregate_test_100.csv'; +LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); statement ok CREATE EXTERNAL TABLE t2 ( @@ -496,9 +585,9 @@ CREATE EXTERNAL TABLE t2 ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW WITH ORDER (c1a ASC) -LOCATION '../../testing/data/csv/aggregate_test_100.csv'; +LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); query TT explain @@ -547,15 +636,12 @@ logical_plan physical_plan 01)UnionExec 02)--ProjectionExec: expr=[Int64(1)@0 as a] -03)----AggregateExec: mode=FinalPartitioned, gby=[Int64(1)@0 as Int64(1)], aggr=[], ordering_mode=Sorted -04)------CoalesceBatchesExec: target_batch_size=2 -05)--------RepartitionExec: partitioning=Hash([Int64(1)@0], 4), input_partitions=1 -06)----------AggregateExec: mode=Partial, gby=[1 as Int64(1)], aggr=[], ordering_mode=Sorted -07)------------PlaceholderRowExec -08)--ProjectionExec: expr=[2 as a] -09)----PlaceholderRowExec -10)--ProjectionExec: expr=[3 as a] -11)----PlaceholderRowExec +03)----AggregateExec: mode=SinglePartitioned, gby=[1 as Int64(1)], aggr=[], ordering_mode=Sorted +04)------PlaceholderRowExec +05)--ProjectionExec: expr=[2 as a] +06)----PlaceholderRowExec +07)--ProjectionExec: expr=[3 as a] +08)----PlaceholderRowExec # test UNION ALL aliases correctly with aliased subquery query TT @@ -565,25 +651,113 @@ select x, y from (select 1 as x , max(10) as y) b ---- logical_plan 01)Union -02)--Projection: COUNT(*) AS count, a.n -03)----Aggregate: groupBy=[[a.n]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +02)--Projection: count(*) AS count, a.n +03)----Aggregate: groupBy=[[a.n]], aggr=[[count(Int64(1)) AS count(*)]] 04)------SubqueryAlias: a 05)--------Projection: Int64(5) AS n 06)----------EmptyRelation 07)--Projection: b.x AS count, b.y AS n 08)----SubqueryAlias: b -09)------Projection: Int64(1) AS x, MAX(Int64(10)) AS y -10)--------Aggregate: groupBy=[[]], aggr=[[MAX(Int64(10))]] +09)------Projection: Int64(1) AS x, max(Int64(10)) AS y +10)--------Aggregate: groupBy=[[]], aggr=[[max(Int64(10))]] 11)----------EmptyRelation physical_plan 01)UnionExec -02)--ProjectionExec: expr=[COUNT(*)@1 as count, n@0 as n] -03)----AggregateExec: mode=FinalPartitioned, gby=[n@0 as n], aggr=[COUNT(*)], ordering_mode=Sorted -04)------CoalesceBatchesExec: target_batch_size=2 -05)--------RepartitionExec: partitioning=Hash([n@0], 4), input_partitions=1 -06)----------AggregateExec: mode=Partial, gby=[n@0 as n], aggr=[COUNT(*)], ordering_mode=Sorted -07)------------ProjectionExec: expr=[5 as n] -08)--------------PlaceholderRowExec -09)--ProjectionExec: expr=[1 as count, MAX(Int64(10))@0 as n] -10)----AggregateExec: mode=Single, gby=[], aggr=[MAX(Int64(10))] -11)------PlaceholderRowExec +02)--ProjectionExec: expr=[count(*)@1 as count, n@0 as n] +03)----AggregateExec: mode=SinglePartitioned, gby=[n@0 as n], aggr=[count(*)], ordering_mode=Sorted +04)------ProjectionExec: expr=[5 as n] +05)--------PlaceholderRowExec +06)--ProjectionExec: expr=[1 as count, max(Int64(10))@0 as n] +07)----AggregateExec: mode=Single, gby=[], aggr=[max(Int64(10))] +08)------PlaceholderRowExec + + +# Test issue: https://github.com/apache/datafusion/issues/11409 +statement ok +CREATE TABLE t1(v0 BIGINT, v1 BIGINT, v2 BIGINT, v3 BOOLEAN); + +statement ok +CREATE TABLE t2(v0 DOUBLE); + +query I +INSERT INTO t1(v0, v2, v1) VALUES (-1229445667, -342312412, -1507138076); +---- +1 + +query I +INSERT INTO t1(v0, v1) VALUES (1541512604, -1229445667); +---- +1 + +query I +INSERT INTO t1(v1, v3, v0, v2) VALUES (-1020641465, false, -1493773377, 1751276473); +---- +1 + +query I +INSERT INTO t1(v3) VALUES (true), (true), (false); +---- +3 + +query I +INSERT INTO t2(v0) VALUES (0.28014577292925047); +---- +1 + +query II +SELECT t1.v2, t1.v0 FROM t2 NATURAL JOIN t1 + UNION ALL +SELECT t1.v2, t1.v0 FROM t2 NATURAL JOIN t1 WHERE (t1.v2 IS NULL); +---- + +statement ok +CREATE TABLE t3 ( + id INT +) as VALUES + (1), + (2), + (3) +; + +statement ok +CREATE TABLE t4 ( + id TEXT +) as VALUES + ('4'), + ('5'), + ('6') +; + +# test type coersion for wildcard expansion +query T rowsort +(SELECT * FROM t3 ) UNION ALL (SELECT * FROM t4) +---- +1 +2 +3 +4 +5 +6 + +statement ok +DROP TABLE t1; + +statement ok +DROP TABLE t2; + +statement ok +DROP TABLE t3; + +statement ok +DROP TABLE t4; + +# Test issue: https://github.com/apache/datafusion/issues/11742 +query R rowsort +WITH + tt(v1) AS (VALUES (1::INT),(NULL::INT)) +SELECT NVL(v1, 0.5) FROM tt + UNION ALL +SELECT NULL WHERE FALSE; +---- +0.5 +1 diff --git a/datafusion/sqllogictest/test_files/unnest.slt b/datafusion/sqllogictest/test_files/unnest.slt index 28f3369ca0a2..947eb8630b52 100644 --- a/datafusion/sqllogictest/test_files/unnest.slt +++ b/datafusion/sqllogictest/test_files/unnest.slt @@ -22,12 +22,26 @@ statement ok CREATE TABLE unnest_table AS VALUES - ([1,2,3], [7], 1, [13, 14]), - ([4,5], [8,9,10], 2, [15, 16]), - ([6], [11,12], 3, null), - ([12], [null, 42, null], null, null), + ([1,2,3], [7], 1, [13, 14], struct(1,2)), + ([4,5], [8,9,10], 2, [15, 16], struct(3,4)), + ([6], [11,12], 3, null, null), + ([12], [null, 42, null], null, null, struct(7,8)), -- null array to verify the `preserve_nulls` option - (null, null, 4, [17, 18]) + (null, null, 4, [17, 18], null) +; + +statement ok +CREATE TABLE nested_unnest_table +AS VALUES + (struct('a', 'b', struct('c')), (struct('a', 'b', [10,20])), [struct('a', 'b')]), + (struct('d', 'e', struct('f')), (struct('x', 'y', [30,40, 50])), null) +; + +statement ok +CREATE TABLE recursive_unnest_table +AS VALUES + (struct([1], 'a'), [[[1],[2]],[[1,1]]], [struct([1],[[1,2]])]), + (struct([2], 'b'), [[[3,4],[5]],[[null,6],null,[7,8]]], [struct([2],[[3],[4]])]) ; ## Basic unnest expression in select list @@ -38,7 +52,13 @@ select unnest([1,2,3]); 2 3 -## Basic unnest expression in from clause +## Basic unnest expression in select struct +query III +select unnest(struct(1,2,3)); +---- +1 2 3 + +## Basic unnest list expression in from clause query I select * from unnest([1,2,3]); ---- @@ -46,6 +66,20 @@ select * from unnest([1,2,3]); 2 3 +## Basic unnest struct expression in from clause +query III +select * from unnest(struct(1,2,3)); +---- +1 2 3 + +## Multiple unnest expression in from clause +query IIII +select * from unnest(struct(1,2,3)),unnest([4,5,6]); +---- +1 2 3 4 +1 2 3 5 +1 2 3 6 + ## Unnest null in select list query error DataFusion error: This feature is not implemented: unnest\(\) does not support null yet @@ -57,12 +91,12 @@ select * from unnest(null); ## Unnest empty array in select list -query ? +query I select unnest([]); ---- ## Unnest empty array in from clause -query ? +query I select * from unnest([]); ---- @@ -131,6 +165,65 @@ select unnest(column1), column1 from unnest_table; 6 [6] 12 [12] +# unnest at different level at the same time +query II +select unnest([1,2,3]), unnest(unnest([[1,2,3]])); +---- +1 1 +2 2 +3 3 + +# binary expr linking different unnest exprs +query II +select unnest([1,2,3]) + unnest([1,2,3]), unnest([1,2,3]) + unnest([4,5]); +---- +2 5 +4 7 +6 NULL + + +# binary expr linking different recursive unnest exprs +query III +select unnest(unnest([[1,2,3]])) + unnest(unnest([[1,2,3]])), unnest(unnest([[1,2,3]])) + unnest([4,5]), unnest([4,5]); +---- +2 5 4 +4 7 5 +6 NULL NULL + + + + +## unnest as children of other expr +query I? +select unnest(column1) + 1 , column1 from unnest_table; +---- +2 [1, 2, 3] +3 [1, 2, 3] +4 [1, 2, 3] +5 [4, 5] +6 [4, 5] +7 [6] +13 [12] + +## unnest on multiple columns +query II +select unnest(column1), unnest(column2) from unnest_table; +---- +1 7 +2 NULL +3 NULL +4 8 +5 9 +NULL 10 +6 11 +NULL 12 +12 NULL +NULL 42 +NULL NULL + +query error DataFusion error: Error during planning: unnest\(\) can only be applied to array, struct and null +select unnest('foo'); + query ?II select array_remove(column1, 4), unnest(column2), column3 * 10 from unnest_table; ---- @@ -145,16 +238,12 @@ select array_remove(column1, 4), unnest(column2), column3 * 10 from unnest_table [12] NULL NULL -## Unnest column with scalars -query error DataFusion error: Error during planning: unnest\(\) can only be applied to array, struct and null -select unnest(column3) from unnest_table; - ## Unnest doesn't work with untyped nulls query error DataFusion error: This feature is not implemented: unnest\(\) does not support null yet select unnest(null) from unnest_table; ## Multiple unnest functions in selection -query ?I +query II select unnest([]), unnest(NULL::int[]); ---- @@ -174,7 +263,7 @@ NULL 10 NULL NULL NULL 17 NULL NULL 18 -query IIII +query IIIT select unnest(column1), unnest(column2) + 2, column3 * 10, unnest(array_remove(column1, '4')) @@ -206,7 +295,7 @@ query error DataFusion error: Error during planning: unnest\(\) requires exactly select unnest(); ## Unnest empty expression in from clause -query error DataFusion error: SQL error: ParserError\("Expected an expression:, found: \)"\) +query error DataFusion error: SQL error: ParserError\("Expected: an expression:, found: \)"\) select * from unnest(); @@ -227,18 +316,22 @@ select * from unnest( 2 b NULL NULL NULL c NULL NULL -query ?I +query II select * from unnest([], NULL::int[]); ---- ## Unnest struct expression in select list -query error DataFusion error: This feature is not implemented: unnest\(\) does not support struct yet +query ? select unnest(struct(null)); +---- +NULL ## Unnest struct expression in from clause -query error DataFusion error: This feature is not implemented: unnest\(\) does not support struct yet +query ? select * from unnest(struct(null)); +---- +NULL ## Unnest array expression @@ -288,6 +381,18 @@ select unnest(array_remove(column1, 12)) from unnest_table; 5 6 +## unnest struct-typed column and list-typed column at the same time +query I?II? +select unnest(column1), column1, unnest(column5), column5 from unnest_table; +---- +1 [1, 2, 3] 1 2 {c0: 1, c1: 2} +2 [1, 2, 3] 1 2 {c0: 1, c1: 2} +3 [1, 2, 3] 1 2 {c0: 1, c1: 2} +4 [4, 5] 3 4 {c0: 3, c1: 4} +5 [4, 5] 3 4 {c0: 3, c1: 4} +6 [6] NULL NULL NULL +12 [12] 7 8 {c0: 7, c1: 8} + ## Unnest in from clause with alias query I @@ -383,5 +488,368 @@ select unnest(array_remove(column1, 3)) - 1 as c1, column3 from unnest_table; 5 3 11 NULL +## unnest on nested(list(struct)) +query ? +select unnest(column3) as struct_elem from nested_unnest_table; +---- +{c0: a, c1: b} + +## unnest for nested struct(struct) +query TT? +select unnest(column1) from nested_unnest_table; +---- +a b {c0: c} +d e {c0: f} + +## unnest for nested(struct(list)) +query TT? +select unnest(column2) from nested_unnest_table; +---- +a b [10, 20] +x y [30, 40, 50] + +query error DataFusion error: type_coercion\ncaused by\nThis feature is not implemented: Unnest should be rewritten to LogicalPlan::Unnest before type coercion +select sum(unnest(generate_series(1,10))); + +query error DataFusion error: Internal error: unnest on struct can only be applied at the root level of select expression +select arrow_typeof(unnest(column5)) from unnest_table; + +query T +select arrow_typeof(unnest(column1)) from unnest_table; +---- +Int64 +Int64 +Int64 +Int64 +Int64 +Int64 +Int64 + +## unnest from a result of a logical plan with limit and offset +query I +select unnest(column1) from (select * from (values([1,2,3]), ([4,5,6])) limit 1 offset 1); +---- +4 +5 +6 + +query error DataFusion error: Error during planning: Projections require unique expression names but the expression "UNNEST\(unnest_table.column1\)" at position 0 and "UNNEST\(unnest_table.column1\)" at position 1 have the same name. Consider aliasing \("AS"\) one of them. +select unnest(column1), unnest(column1) from unnest_table; + +query II +select unnest(column1), unnest(column1) u1 from unnest_table; +---- +1 1 +2 2 +3 3 +4 4 +5 5 +6 6 +12 12 + +## the same unnest expr is referened multiple times (unnest is the bottom-most expr) +query ??II +select unnest(column2), unnest(unnest(column2)), unnest(unnest(unnest(column2))), unnest(unnest(unnest(column2))) + 1 from recursive_unnest_table; +---- +[[1], [2]] [1] 1 2 +[[1, 1]] [2] NULL NULL +[[1], [2]] [1, 1] 2 3 +[[1, 1]] NULL NULL NULL +[[1], [2]] [1] 1 2 +[[1, 1]] [2] 1 2 +[[1], [2]] [1, 1] NULL NULL +[[1, 1]] NULL NULL NULL +[[3, 4], [5]] [3, 4] 3 4 +[[, 6], , [7, 8]] [5] 4 5 +[[3, 4], [5]] [, 6] 5 6 +[[, 6], , [7, 8]] NULL NULL NULL +NULL [7, 8] NULL NULL +[[3, 4], [5]] [3, 4] NULL NULL +[[, 6], , [7, 8]] [5] 6 7 +[[3, 4], [5]] [, 6] NULL NULL +[[, 6], , [7, 8]] NULL NULL NULL +NULL [7, 8] NULL NULL +[[3, 4], [5]] NULL 7 8 +[[, 6], , [7, 8]] NULL 8 9 + +## the same composite expr (unnest(field_access(unnest(col)))) which containing unnest is referened multiple times +query ??II +select unnest(column3), unnest(column3)['c0'], unnest(unnest(column3)['c0']), unnest(unnest(column3)['c0']) + unnest(unnest(column3)['c0']) from recursive_unnest_table; +---- +{c0: [1], c1: [[1, 2]]} [1] 1 2 +{c0: [2], c1: [[3], [4]]} [2] 2 4 + + + + +## unnest list followed by unnest struct +query ??? +select unnest(unnest(column3)), column3 from recursive_unnest_table; +---- +[1] [[1, 2]] [{c0: [1], c1: [[1, 2]]}] +[2] [[3], [4]] [{c0: [2], c1: [[3], [4]]}] + + +query TT +explain select unnest(unnest(column3)), column3 from recursive_unnest_table; +---- +logical_plan +01)Unnest: lists[] structs[unnest_placeholder(UNNEST(recursive_unnest_table.column3))] +02)--Projection: unnest_placeholder(recursive_unnest_table.column3,depth=1) AS UNNEST(recursive_unnest_table.column3) AS unnest_placeholder(UNNEST(recursive_unnest_table.column3)), recursive_unnest_table.column3 +03)----Unnest: lists[unnest_placeholder(recursive_unnest_table.column3)|depth=1] structs[] +04)------Projection: recursive_unnest_table.column3 AS unnest_placeholder(recursive_unnest_table.column3), recursive_unnest_table.column3 +05)--------TableScan: recursive_unnest_table projection=[column3] +physical_plan +01)UnnestExec +02)--RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +03)----ProjectionExec: expr=[unnest_placeholder(recursive_unnest_table.column3,depth=1)@0 as unnest_placeholder(UNNEST(recursive_unnest_table.column3)), column3@1 as column3] +04)------UnnestExec +05)--------ProjectionExec: expr=[column3@0 as unnest_placeholder(recursive_unnest_table.column3), column3@0 as column3] +06)----------MemoryExec: partitions=1, partition_sizes=[1] + +## unnest->field_access->unnest->unnest +query I? +select unnest(unnest(unnest(column3)['c1'])), column3 from recursive_unnest_table; +---- +1 [{c0: [1], c1: [[1, 2]]}] +2 [{c0: [1], c1: [[1, 2]]}] +3 [{c0: [2], c1: [[3], [4]]}] +4 [{c0: [2], c1: [[3], [4]]}] + +## triple list unnest +query I? +select unnest(unnest(unnest(column2))), column2 from recursive_unnest_table; +---- +1 [[[1], [2]], [[1, 1]]] +2 [[[1], [2]], [[1, 1]]] +1 [[[1], [2]], [[1, 1]]] +1 [[[1], [2]], [[1, 1]]] +3 [[[3, 4], [5]], [[, 6], , [7, 8]]] +4 [[[3, 4], [5]], [[, 6], , [7, 8]]] +5 [[[3, 4], [5]], [[, 6], , [7, 8]]] +NULL [[[3, 4], [5]], [[, 6], , [7, 8]]] +6 [[[3, 4], [5]], [[, 6], , [7, 8]]] +7 [[[3, 4], [5]], [[, 6], , [7, 8]]] +8 [[[3, 4], [5]], [[, 6], , [7, 8]]] + + +query I?? +select unnest(unnest(unnest(column3)['c1'])), unnest(unnest(column3)['c1']), column3 from recursive_unnest_table; +---- +1 [1, 2] [{c0: [1], c1: [[1, 2]]}] +2 NULL [{c0: [1], c1: [[1, 2]]}] +3 [3] [{c0: [2], c1: [[3], [4]]}] +NULL [4] [{c0: [2], c1: [[3], [4]]}] +4 [3] [{c0: [2], c1: [[3], [4]]}] +NULL [4] [{c0: [2], c1: [[3], [4]]}] + +## demonstrate where recursive unnest is impossible +## and need multiple unnesting logical plans +## e.g unnest -> field_access -> unnest +query TT +explain select unnest(unnest(unnest(column3)['c1'])), column3 from recursive_unnest_table; +---- +logical_plan +01)Projection: unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1],depth=2) AS UNNEST(UNNEST(UNNEST(recursive_unnest_table.column3)[c1])), recursive_unnest_table.column3 +02)--Unnest: lists[unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1])|depth=2] structs[] +03)----Projection: get_field(unnest_placeholder(recursive_unnest_table.column3,depth=1) AS UNNEST(recursive_unnest_table.column3), Utf8("c1")) AS unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1]), recursive_unnest_table.column3 +04)------Unnest: lists[unnest_placeholder(recursive_unnest_table.column3)|depth=1] structs[] +05)--------Projection: recursive_unnest_table.column3 AS unnest_placeholder(recursive_unnest_table.column3), recursive_unnest_table.column3 +06)----------TableScan: recursive_unnest_table projection=[column3] +physical_plan +01)ProjectionExec: expr=[unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1],depth=2)@0 as UNNEST(UNNEST(UNNEST(recursive_unnest_table.column3)[c1])), column3@1 as column3] +02)--UnnestExec +03)----ProjectionExec: expr=[get_field(unnest_placeholder(recursive_unnest_table.column3,depth=1)@0, c1) as unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1]), column3@1 as column3] +04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +05)--------UnnestExec +06)----------ProjectionExec: expr=[column3@0 as unnest_placeholder(recursive_unnest_table.column3), column3@0 as column3] +07)------------MemoryExec: partitions=1, partition_sizes=[1] + + + +## group by unnest + +### without agg exprs +query I +select unnest(column1) c1 from unnest_table group by c1 order by c1; +---- +1 +2 +3 +4 +5 +6 +12 + +query II +select unnest(column1) c1, unnest(column2) c2 from unnest_table group by c1, c2 order by c1, c2; +---- +1 7 +2 NULL +3 NULL +4 8 +5 9 +6 11 +12 NULL +NULL 10 +NULL 12 +NULL 42 +NULL NULL + +query III +select unnest(column1) c1, unnest(column2) c2, column3 c3 from unnest_table group by c1, c2, c3 order by c1, c2, c3; +---- +1 7 1 +2 NULL 1 +3 NULL 1 +4 8 2 +5 9 2 +6 11 3 +12 NULL NULL +NULL 10 2 +NULL 12 3 +NULL 42 NULL +NULL NULL NULL + +### with agg exprs + +query IIII +select unnest(column1) c1, unnest(column2) c2, column3 c3, count(1) from unnest_table group by c1, c2, c3 order by c1, c2, c3; +---- +1 7 1 1 +2 NULL 1 1 +3 NULL 1 1 +4 8 2 1 +5 9 2 1 +6 11 3 1 +12 NULL NULL 1 +NULL 10 2 1 +NULL 12 3 1 +NULL 42 NULL 1 +NULL NULL NULL 1 + +query IIII +select unnest(column1) c1, unnest(column2) c2, column3 c3, count(column4) from unnest_table group by c1, c2, c3 order by c1, c2, c3; +---- +1 7 1 1 +2 NULL 1 1 +3 NULL 1 1 +4 8 2 1 +5 9 2 1 +6 11 3 0 +12 NULL NULL 0 +NULL 10 2 1 +NULL 12 3 0 +NULL 42 NULL 0 +NULL NULL NULL 0 + +query IIIII +select unnest(column1) c1, unnest(column2) c2, column3 c3, count(column4), sum(column3) from unnest_table group by c1, c2, c3 order by c1, c2, c3; +---- +1 7 1 1 1 +2 NULL 1 1 1 +3 NULL 1 1 1 +4 8 2 1 2 +5 9 2 1 2 +6 11 3 0 3 +12 NULL NULL 0 NULL +NULL 10 2 1 2 +NULL 12 3 0 3 +NULL 42 NULL 0 NULL +NULL NULL NULL 0 NULL + +query II +select unnest(column1), count(*) from unnest_table group by unnest(column1) order by unnest(column1) desc; +---- +12 1 +6 1 +5 1 +4 1 +3 1 +2 1 +1 1 + +### group by recursive unnest list + +query ? +select unnest(unnest(column2)) c2 from recursive_unnest_table group by c2 order by c2; +---- +[1] +[1, 1] +[2] +[3, 4] +[5] +[7, 8] +[, 6] +NULL + +query ?I +select unnest(unnest(column2)) c2, count(column3) from recursive_unnest_table group by c2 order by c2; +---- +[1] 1 +[1, 1] 1 +[2] 1 +[3, 4] 1 +[5] 1 +[7, 8] 1 +[, 6] 1 +NULL 1 + +query error DataFusion error: Error during planning: Projection references non\-aggregate values +select unnest(column1) c1 from nested_unnest_table group by c1.c0; + +# TODO: this query should work. see issue: https://github.com/apache/datafusion/issues/12794 +query error DataFusion error: Internal error: unnest on struct can only be applied at the root level of select expression +select unnest(column1) c1 from nested_unnest_table + +query II??I?? +select unnest(column5), * from unnest_table; +---- +1 2 [1, 2, 3] [7] 1 [13, 14] {c0: 1, c1: 2} +3 4 [4, 5] [8, 9, 10] 2 [15, 16] {c0: 3, c1: 4} +NULL NULL [6] [11, 12] 3 NULL NULL +7 8 [12] [, 42, ] NULL NULL {c0: 7, c1: 8} +NULL NULL NULL NULL 4 [17, 18] NULL + +query TT???? +select unnest(column1), * from nested_unnest_table +---- +a b {c0: c} {c0: a, c1: b, c2: {c0: c}} {c0: a, c1: b, c2: [10, 20]} [{c0: a, c1: b}] +d e {c0: f} {c0: d, c1: e, c2: {c0: f}} {c0: x, c1: y, c2: [30, 40, 50]} NULL + +query ????? +select unnest(unnest(column3)), * from recursive_unnest_table +---- +[1] [[1, 2]] {c0: [1], c1: a} [[[1], [2]], [[1, 1]]] [{c0: [1], c1: [[1, 2]]}] +[2] [[3], [4]] {c0: [2], c1: b} [[[3, 4], [5]], [[, 6], , [7, 8]]] [{c0: [2], c1: [[3], [4]]}] + statement ok -drop table unnest_table; +CREATE TABLE join_table +AS VALUES + (1, 2, 3), + (2, 3, 4), + (4, 5, 6) +; + +query IIIII +select unnest(u.column5), j.* from unnest_table u join join_table j on u.column3 = j.column1 +---- +1 2 1 2 3 +3 4 2 3 4 +NULL NULL 4 5 6 + +query II?I? +select unnest(column5), * except (column5, column1) from unnest_table; +---- +1 2 [7] 1 [13, 14] +3 4 [8, 9, 10] 2 [15, 16] +NULL NULL [11, 12] 3 NULL +7 8 [, 42, ] NULL NULL +NULL NULL NULL 4 [17, 18] + +query III +select unnest(u.column5), j.* except(column2, column3) from unnest_table u join join_table j on u.column3 = j.column1 +---- +1 2 1 +3 4 2 +NULL NULL 4 diff --git a/datafusion/sqllogictest/test_files/update.slt b/datafusion/sqllogictest/test_files/update.slt index 49b2bd9aa0b5..aaba6998ee63 100644 --- a/datafusion/sqllogictest/test_files/update.slt +++ b/datafusion/sqllogictest/test_files/update.slt @@ -53,8 +53,8 @@ logical_plan 01)Dml: op=[Update] table=[t1] 02)--Projection: t1.a AS a, () AS b, t1.c AS c, t1.d AS d 03)----Subquery: -04)------Projection: MAX(t2.b) -05)--------Aggregate: groupBy=[[]], aggr=[[MAX(t2.b)]] +04)------Projection: max(t2.b) +05)--------Aggregate: groupBy=[[]], aggr=[[max(t2.b)]] 06)----------Filter: outer_ref(t1.a) = t2.a 07)------------TableScan: t2 08)----TableScan: t1 @@ -67,14 +67,14 @@ logical_plan 01)Dml: op=[Update] table=[t1] 02)--Projection: t1.a AS a, t2.b AS b, CAST(t2.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d 03)----Filter: t1.a = t2.a AND t1.b > Utf8("foo") AND t2.c > Float64(1) -04)------CrossJoin: +04)------Cross Join: 05)--------TableScan: t1 06)--------TableScan: t2 statement ok create table t3(a int, b varchar, c double, d int); -# set from mutiple tables, sqlparser only supports from one table +# set from multiple tables, sqlparser only supports from one table query error DataFusion error: SQL error: ParserError\("Expected end of statement, found: ,"\) explain update t1 set b = t2.b, c = t3.a, d = 1 from t2, t3 where t1.a = t2.a and t1.a = t3.a; @@ -86,7 +86,7 @@ logical_plan 01)Dml: op=[Update] table=[t1] 02)--Projection: t.a AS a, t2.b AS b, CAST(t.a AS Float64) AS c, CAST(Int64(1) AS Int32) AS d 03)----Filter: t.a = t2.a AND t.b > Utf8("foo") AND t2.c > Float64(1) -04)------CrossJoin: +04)------Cross Join: 05)--------SubqueryAlias: t 06)----------TableScan: t1 07)--------TableScan: t2 diff --git a/datafusion/sqllogictest/test_files/wildcard.slt b/datafusion/sqllogictest/test_files/wildcard.slt index f83e84804a37..7c076f040feb 100644 --- a/datafusion/sqllogictest/test_files/wildcard.slt +++ b/datafusion/sqllogictest/test_files/wildcard.slt @@ -40,8 +40,8 @@ CREATE EXTERNAL TABLE aggregate_simple ( c3 BOOLEAN NOT NULL, ) STORED AS CSV -WITH HEADER ROW LOCATION '../core/tests/data/aggregate_simple.csv' +OPTIONS ('format.has_header' 'true'); ########## @@ -108,6 +108,31 @@ SELECT t1.*, tb2.* FROM t1 JOIN t2 tb2 ON t2_id = t1_id ORDER BY t1_id statement error Error during planning: Invalid qualifier agg SELECT agg.* FROM aggregate_simple ORDER BY c1 +# select_upper_case_qualified_wildcard +query ITI +SELECT PUBLIC.t1.* FROM PUBLIC.t1 +---- +11 a 1 +22 b 2 +33 c 3 +44 d 4 + +query ITI +SELECT PUBLIC.t1.* FROM public.t1 +---- +11 a 1 +22 b 2 +33 c 3 +44 d 4 + +query ITI +SELECT public.t1.* FROM PUBLIC.t1 +---- +11 a 1 +22 b 2 +33 c 3 +44 d 4 + ######## # Clean up after the test ######## diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt index 7320688cff1c..d593a985c458 100644 --- a/datafusion/sqllogictest/test_files/window.slt +++ b/datafusion/sqllogictest/test_files/window.slt @@ -32,8 +32,8 @@ CREATE EXTERNAL TABLE aggregate_test_100 ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW LOCATION '../../testing/data/csv/aggregate_test_100.csv' +OPTIONS ('format.has_header' 'true'); statement ok CREATE EXTERNAL TABLE null_cases( @@ -42,14 +42,15 @@ CREATE EXTERNAL TABLE null_cases( c3 BIGINT NULL ) STORED AS CSV -WITH HEADER ROW -LOCATION '../core/tests/data/null_cases.csv'; +LOCATION '../core/tests/data/null_cases.csv' +OPTIONS ('format.has_header' 'true'); ### This is the same table as ### execute_with_partition with 4 partitions statement ok CREATE EXTERNAL TABLE test (c1 int, c2 bigint, c3 boolean) -STORED AS CSV LOCATION '../core/tests/data/partitioned_csv'; +STORED AS CSV LOCATION '../core/tests/data/partitioned_csv' +OPTIONS('format.has_header' 'false'); # for window functions without order by the first, last, and nth function call does not make sense @@ -254,8 +255,8 @@ WITH _sample_data AS ( ---- logical_plan 01)Sort: d.b ASC NULLS LAST -02)--Projection: d.b, MAX(d.a) AS max_a -03)----Aggregate: groupBy=[[d.b]], aggr=[[MAX(d.a)]] +02)--Projection: d.b, max(d.a) AS max_a +03)----Aggregate: groupBy=[[d.b]], aggr=[[max(d.a)]] 04)------SubqueryAlias: d 05)--------SubqueryAlias: _data2 06)----------SubqueryAlias: s @@ -271,12 +272,12 @@ logical_plan 16)------------------EmptyRelation physical_plan 01)SortPreservingMergeExec: [b@0 ASC NULLS LAST] -02)--SortExec: expr=[b@0 ASC NULLS LAST] -03)----ProjectionExec: expr=[b@0 as b, MAX(d.a)@1 as max_a] -04)------AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[MAX(d.a)] +02)--SortExec: expr=[b@0 ASC NULLS LAST], preserve_partitioning=[true] +03)----ProjectionExec: expr=[b@0 as b, max(d.a)@1 as max_a] +04)------AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[max(d.a)] 05)--------CoalesceBatchesExec: target_batch_size=8192 06)----------RepartitionExec: partitioning=Hash([b@0], 4), input_partitions=4 -07)------------AggregateExec: mode=Partial, gby=[b@1 as b], aggr=[MAX(d.a)] +07)------------AggregateExec: mode=Partial, gby=[b@1 as b], aggr=[max(d.a)], ordering_mode=Sorted 08)--------------UnionExec 09)----------------ProjectionExec: expr=[1 as a, aa as b] 10)------------------PlaceholderRowExec @@ -337,12 +338,12 @@ WITH _sample_data AS ( ---- logical_plan 01)Sort: d.b ASC NULLS LAST -02)--Projection: d.b, MAX(d.a) AS max_a, MAX(d.seq) -03)----Aggregate: groupBy=[[d.b]], aggr=[[MAX(d.a), MAX(d.seq)]] +02)--Projection: d.b, max(d.a) AS max_a, max(d.seq) +03)----Aggregate: groupBy=[[d.b]], aggr=[[max(d.a), max(d.seq)]] 04)------SubqueryAlias: d 05)--------SubqueryAlias: _data2 -06)----------Projection: ROW_NUMBER() PARTITION BY [s.b] ORDER BY [s.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS seq, s.a, s.b -07)------------WindowAggr: windowExpr=[[ROW_NUMBER() PARTITION BY [s.b] ORDER BY [s.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +06)----------Projection: row_number() PARTITION BY [s.b] ORDER BY [s.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS seq, s.a, s.b +07)------------WindowAggr: windowExpr=[[row_number() PARTITION BY [s.b] ORDER BY [s.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] 08)--------------SubqueryAlias: s 09)----------------SubqueryAlias: _sample_data 10)------------------Union @@ -356,11 +357,11 @@ logical_plan 18)----------------------EmptyRelation physical_plan 01)SortPreservingMergeExec: [b@0 ASC NULLS LAST] -02)--ProjectionExec: expr=[b@0 as b, MAX(d.a)@1 as max_a, MAX(d.seq)@2 as MAX(d.seq)] -03)----AggregateExec: mode=SinglePartitioned, gby=[b@2 as b], aggr=[MAX(d.a), MAX(d.seq)], ordering_mode=Sorted -04)------ProjectionExec: expr=[ROW_NUMBER() PARTITION BY [s.b] ORDER BY [s.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as seq, a@0 as a, b@1 as b] -05)--------BoundedWindowAggExec: wdw=[ROW_NUMBER() PARTITION BY [s.b] ORDER BY [s.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() PARTITION BY [s.b] ORDER BY [s.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -06)----------SortExec: expr=[b@1 ASC NULLS LAST,a@0 ASC NULLS LAST] +02)--ProjectionExec: expr=[b@0 as b, max(d.a)@1 as max_a, max(d.seq)@2 as max(d.seq)] +03)----AggregateExec: mode=SinglePartitioned, gby=[b@2 as b], aggr=[max(d.a), max(d.seq)], ordering_mode=Sorted +04)------ProjectionExec: expr=[row_number() PARTITION BY [s.b] ORDER BY [s.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as seq, a@0 as a, b@1 as b] +05)--------BoundedWindowAggExec: wdw=[row_number() PARTITION BY [s.b] ORDER BY [s.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "row_number() PARTITION BY [s.b] ORDER BY [s.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +06)----------SortExec: expr=[b@1 ASC NULLS LAST, a@0 ASC NULLS LAST], preserve_partitioning=[true] 07)------------CoalesceBatchesExec: target_batch_size=8192 08)--------------RepartitionExec: partitioning=Hash([b@1], 4), input_partitions=4 09)----------------UnionExec @@ -925,12 +926,43 @@ SELECT 2022-09-29T15:16:34 2 1 1 2022-09-30T19:03:14 1 1 1 + statement ok drop table t statement ok drop table temp +statement ok +CREATE TABLE table1 ( + bar DECIMAL(10,1), + foo VARCHAR(10), + time TIMESTAMP WITH TIME ZONE +); + +statement ok +INSERT INTO table1 (bar, foo, time) VALUES +(200.0, 'me', '1970-01-01T00:00:00.000000010Z'), +(1.0, 'me', '1970-01-01T00:00:00.000000030Z'), +(1.0, 'me', '1970-01-01T00:00:00.000000040Z'), +(2.0, 'you', '1970-01-01T00:00:00.000000020Z'); + +query TP +SELECT foo, first_value(time ORDER BY time DESC NULLS LAST) AS time FROM table1 GROUP BY foo ORDER BY foo; +---- +me 1970-01-01T00:00:00.000000040Z +you 1970-01-01T00:00:00.000000020Z + +query TP +SELECT foo, last_value(time ORDER BY time DESC NULLS LAST) AS time FROM table1 GROUP BY foo ORDER BY foo; +---- +me 1970-01-01T00:00:00.000000010Z +you 1970-01-01T00:00:00.000000020Z + +statement ok +drop table table1; + + #fn window_frame_ranges_unbounded_preceding_err statement error DataFusion error: Error during planning: Invalid window frame: end bound cannot be UNBOUNDED PRECEDING @@ -1103,8 +1135,8 @@ SELECT query IRR SELECT c8, - CUME_DIST() OVER(ORDER BY c9) as cd1, - CUME_DIST() OVER(ORDER BY c9 ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as cd2 + cume_dist() OVER(ORDER BY c9) as cd1, + cume_dist() OVER(ORDER BY c9 ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as cd2 FROM aggregate_test_100 ORDER BY c8 LIMIT 5 @@ -1202,17 +1234,17 @@ EXPLAIN SELECT FROM aggregate_test_100 ---- logical_plan -01)Projection: aggregate_test_100.c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum2 -02)--WindowAggr: windowExpr=[[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -03)----Projection: aggregate_test_100.c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW -04)------WindowAggr: windowExpr=[[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +01)Projection: aggregate_test_100.c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum1, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum2 +02)--WindowAggr: windowExpr=[[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +03)----Projection: aggregate_test_100.c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +04)------WindowAggr: windowExpr=[[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] 05)--------TableScan: aggregate_test_100 projection=[c8, c9] physical_plan -01)ProjectionExec: expr=[c9@0 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as sum2] -02)--BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -03)----ProjectionExec: expr=[c9@1 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] -04)------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -05)--------SortExec: expr=[c9@1 ASC NULLS LAST,c8@0 ASC NULLS LAST] +01)ProjectionExec: expr=[c9@0 as c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as sum1, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as sum2] +02)--BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +03)----ProjectionExec: expr=[c9@1 as c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +05)--------SortExec: expr=[c9@1 ASC NULLS LAST, c8@0 ASC NULLS LAST], preserve_partitioning=[false] 06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c8, c9], has_header=true @@ -1223,17 +1255,17 @@ query TT EXPLAIN SELECT c2, MAX(c9) OVER (ORDER BY c2), SUM(c9) OVER (), MIN(c9) OVER (ORDER BY c2, c9) from aggregate_test_100 ---- logical_plan -01)Projection: aggregate_test_100.c2, MAX(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW -02)--WindowAggr: windowExpr=[[SUM(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] -03)----WindowAggr: windowExpr=[[MAX(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -04)------WindowAggr: windowExpr=[[MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +01)Projection: aggregate_test_100.c2, max(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, min(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +02)--WindowAggr: windowExpr=[[sum(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +03)----WindowAggr: windowExpr=[[max(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------WindowAggr: windowExpr=[[min(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] 05)--------TableScan: aggregate_test_100 projection=[c2, c9] physical_plan -01)ProjectionExec: expr=[c2@0 as c2, MAX(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as MAX(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@4 as SUM(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] -02)--WindowAggExec: wdw=[SUM(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] -03)----BoundedWindowAggExec: wdw=[MAX(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MAX(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int8(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -04)------BoundedWindowAggExec: wdw=[MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int8(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -05)--------SortExec: expr=[c2@0 ASC NULLS LAST,c9@1 ASC NULLS LAST] +01)ProjectionExec: expr=[c2@0 as c2, max(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as max(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@4 as sum(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, min(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as min(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +02)--WindowAggExec: wdw=[sum(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] +03)----BoundedWindowAggExec: wdw=[max(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "max(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int8(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[min(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "min(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int8(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +05)--------SortExec: expr=[c2@0 ASC NULLS LAST, c9@1 ASC NULLS LAST], preserve_partitioning=[false] 06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c9], has_header=true @@ -1246,19 +1278,19 @@ EXPLAIN SELECT c2, MAX(c9) OVER (ORDER BY c9, c2), SUM(c9) OVER (), MIN(c9) OVER ---- logical_plan 01)Sort: aggregate_test_100.c2 ASC NULLS LAST -02)--Projection: aggregate_test_100.c2, MAX(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW -03)----WindowAggr: windowExpr=[[SUM(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] -04)------WindowAggr: windowExpr=[[MAX(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -05)--------WindowAggr: windowExpr=[[MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +02)--Projection: aggregate_test_100.c2, max(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, min(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +03)----WindowAggr: windowExpr=[[sum(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +04)------WindowAggr: windowExpr=[[max(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +05)--------WindowAggr: windowExpr=[[min(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] 06)----------TableScan: aggregate_test_100 projection=[c2, c9] physical_plan -01)SortExec: expr=[c2@0 ASC NULLS LAST] -02)--ProjectionExec: expr=[c2@0 as c2, MAX(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as MAX(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@4 as SUM(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] -03)----WindowAggExec: wdw=[SUM(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] -04)------BoundedWindowAggExec: wdw=[MAX(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MAX(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -05)--------SortExec: expr=[c9@1 ASC NULLS LAST,c2@0 ASC NULLS LAST] -06)----------BoundedWindowAggExec: wdw=[MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MIN(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int8(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -07)------------SortExec: expr=[c2@0 ASC NULLS LAST,c9@1 ASC NULLS LAST] +01)SortExec: expr=[c2@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--ProjectionExec: expr=[c2@0 as c2, max(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as max(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@4 as sum(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, min(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as min(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +03)----WindowAggExec: wdw=[sum(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] +04)------BoundedWindowAggExec: wdw=[max(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "max(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +05)--------SortExec: expr=[c9@1 ASC NULLS LAST, c2@0 ASC NULLS LAST], preserve_partitioning=[false] +06)----------BoundedWindowAggExec: wdw=[min(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "min(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c2 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int8(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +07)------------SortExec: expr=[c2@0 ASC NULLS LAST, c9@1 ASC NULLS LAST], preserve_partitioning=[false] 08)--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c9], has_header=true # test_window_partition_by_order_by @@ -1272,20 +1304,20 @@ EXPLAIN SELECT FROM aggregate_test_100 ---- logical_plan -01)Projection: SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING -02)--WindowAggr: windowExpr=[[COUNT(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] -03)----Projection: aggregate_test_100.c1, aggregate_test_100.c2, SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING -04)------WindowAggr: windowExpr=[[SUM(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] +01)Projection: sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING +02)--WindowAggr: windowExpr=[[count(Int64(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] +03)----Projection: aggregate_test_100.c1, aggregate_test_100.c2, sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING +04)------WindowAggr: windowExpr=[[sum(CAST(aggregate_test_100.c4 AS Int64)) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] 05)--------TableScan: aggregate_test_100 projection=[c1, c2, c4] physical_plan -01)ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@2 as SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING] -02)--BoundedWindowAggExec: wdw=[COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -03)----SortExec: expr=[c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST] +01)ProjectionExec: expr=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@2 as sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING, count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING] +02)--BoundedWindowAggExec: wdw=[count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(*) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +03)----SortExec: expr=[c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST], preserve_partitioning=[true] 04)------CoalesceBatchesExec: target_batch_size=4096 05)--------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 -06)----------ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING] -07)------------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -08)--------------SortExec: expr=[c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST] +06)----------ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING] +07)------------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +08)--------------SortExec: expr=[c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST], preserve_partitioning=[true] 09)----------------CoalesceBatchesExec: target_batch_size=4096 10)------------------RepartitionExec: partitioning=Hash([c1@0, c2@1], 2), input_partitions=2 11)--------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -1303,17 +1335,17 @@ EXPLAIN SELECT LIMIT 5 ---- logical_plan -01)Projection: aggregate_test_100.c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum2 +01)Projection: aggregate_test_100.c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum1, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum2 02)--Limit: skip=0, fetch=5 -03)----WindowAggr: windowExpr=[[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] -04)------WindowAggr: windowExpr=[[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] +03)----WindowAggr: windowExpr=[[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] +04)------WindowAggr: windowExpr=[[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] 05)--------TableScan: aggregate_test_100 projection=[c9] physical_plan -01)ProjectionExec: expr=[c9@0 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2] +01)ProjectionExec: expr=[c9@0 as c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as sum1, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -04)------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] -05)--------SortExec: expr=[c9@0 DESC] +03)----BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +05)--------SortExec: expr=[c9@0 DESC], preserve_partitioning=[false] 06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], has_header=true query III @@ -1344,17 +1376,17 @@ EXPLAIN SELECT LIMIT 5 ---- logical_plan -01)Projection: aggregate_test_100.c9, FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS fv1, FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS fv2, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS lag1, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lag2, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS lead1, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lead2 +01)Projection: aggregate_test_100.c9, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS fv1, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS fv2, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS lag1, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lag2, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS lead1, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lead2 02)--Limit: skip=0, fetch=5 -03)----WindowAggr: windowExpr=[[FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, LAG(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, LEAD(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -04)------WindowAggr: windowExpr=[[FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, LAG(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LEAD(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] +03)----WindowAggr: windowExpr=[[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, lag(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, lead(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------WindowAggr: windowExpr=[[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, lag(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, lead(aggregate_test_100.c9, Int64(2), Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] 05)--------TableScan: aggregate_test_100 projection=[c9] physical_plan -01)ProjectionExec: expr=[c9@0 as c9, FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as fv1, FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as fv2, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as lag1, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as lag2, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as lead1, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as lead2] +01)ProjectionExec: expr=[c9@0 as c9, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as fv1, first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as fv2, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as lag1, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as lag2, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as lead1, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as lead2] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----BoundedWindowAggExec: wdw=[FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)), is_causal: false }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] -04)------BoundedWindowAggExec: wdw=[FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -05)--------SortExec: expr=[c9@0 DESC] +03)----BoundedWindowAggExec: wdw=[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)), is_causal: false }, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "first_value(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }, lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lag(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lead(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +05)--------SortExec: expr=[c9@0 DESC], preserve_partitioning=[false] 06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], has_header=true query IIIIIII @@ -1387,18 +1419,18 @@ EXPLAIN SELECT LIMIT 5 ---- logical_plan -01)Projection: aggregate_test_100.c9, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS rn1, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS rn2 +01)Projection: aggregate_test_100.c9, row_number() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS rn1, row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS rn2 02)--Limit: skip=0, fetch=5 -03)----WindowAggr: windowExpr=[[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] -04)------WindowAggr: windowExpr=[[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] +03)----WindowAggr: windowExpr=[[row_number() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] +04)------WindowAggr: windowExpr=[[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] 05)--------TableScan: aggregate_test_100 projection=[c9] physical_plan -01)ProjectionExec: expr=[c9@0 as c9, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as rn1, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as rn2] +01)ProjectionExec: expr=[c9@0 as c9, row_number() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as rn1, row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as rn2] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "ROW_NUMBER() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] -04)------SortExec: expr=[c9@0 ASC NULLS LAST] -05)--------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] -06)----------SortExec: expr=[c9@0 DESC] +03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "row_number() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +04)------SortExec: expr=[c9@0 ASC NULLS LAST], preserve_partitioning=[false] +05)--------BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +06)----------SortExec: expr=[c9@0 DESC], preserve_partitioning=[false] 07)------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], has_header=true @@ -1428,20 +1460,20 @@ EXPLAIN SELECT LIMIT 5 ---- logical_plan -01)Projection: aggregate_test_100.c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum2, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS rn2 +01)Projection: aggregate_test_100.c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum1, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum2, row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS rn2 02)--Limit: skip=0, fetch=5 -03)----WindowAggr: windowExpr=[[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] -04)------WindowAggr: windowExpr=[[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] -05)--------WindowAggr: windowExpr=[[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] +03)----WindowAggr: windowExpr=[[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] +04)------WindowAggr: windowExpr=[[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] +05)--------WindowAggr: windowExpr=[[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] 06)----------TableScan: aggregate_test_100 projection=[c1, c2, c9] physical_plan -01)ProjectionExec: expr=[c9@2 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@5 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@3 as sum2, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as rn2] +01)ProjectionExec: expr=[c9@2 as c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@5 as sum1, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@3 as sum2, row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as rn2] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] -04)------SortExec: expr=[c9@2 ASC NULLS LAST,c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST] -05)--------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] -06)----------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] -07)------------SortExec: expr=[c9@2 DESC,c1@0 DESC] +03)----BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +04)------SortExec: expr=[c9@2 ASC NULLS LAST, c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST], preserve_partitioning=[false] +05)--------BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +06)----------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +07)------------SortExec: expr=[c9@2 DESC, c1@0 DESC], preserve_partitioning=[false] 08)--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c9], has_header=true query IIII @@ -1504,35 +1536,35 @@ EXPLAIN SELECT LIMIT 5 ---- logical_plan -01)Projection: SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING AS a, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING AS b, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING AS c, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING AS d, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING AS e, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING AS f, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING AS g, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS h, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS i, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS j, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS k, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS l, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS m, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS n, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS o, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS p, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING AS a1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING AS b1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING AS c1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING AS d1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING AS e1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING AS f1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING AS g1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS h1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS j1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS k1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS l1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS m1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS n1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS o1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING AS h11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING AS j11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING AS k11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING AS l11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING AS m11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING AS n11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING AS o11 +01)Projection: sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING AS a, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING AS b, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING AS c, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING AS d, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING AS e, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING AS f, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING AS g, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS h, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS i, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS j, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS k, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS l, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS m, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS n, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS o, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS p, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING AS a1, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING AS b1, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING AS c1, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING AS d1, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING AS e1, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING AS f1, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING AS g1, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS h1, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS j1, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS k1, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS l1, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS m1, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS n1, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS o1, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING AS h11, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING AS j11, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING AS k11, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING AS l11, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING AS m11, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING AS n11, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING AS o11 02)--Limit: skip=0, fetch=5 -03)----WindowAggr: windowExpr=[[SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING]] -04)------Projection: null_cases.c1, null_cases.c3, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW -05)--------WindowAggr: windowExpr=[[SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -06)----------WindowAggr: windowExpr=[[SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -07)------------WindowAggr: windowExpr=[[SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -08)--------------WindowAggr: windowExpr=[[SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING]] -09)----------------WindowAggr: windowExpr=[[SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING]] -10)------------------WindowAggr: windowExpr=[[SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING]] -11)--------------------WindowAggr: windowExpr=[[SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +03)----WindowAggr: windowExpr=[[sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING]] +04)------Projection: null_cases.c1, null_cases.c3, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +05)--------WindowAggr: windowExpr=[[sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +06)----------WindowAggr: windowExpr=[[sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +07)------------WindowAggr: windowExpr=[[sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +08)--------------WindowAggr: windowExpr=[[sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING]] +09)----------------WindowAggr: windowExpr=[[sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING]] +10)------------------WindowAggr: windowExpr=[[sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING]] +11)--------------------WindowAggr: windowExpr=[[sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] 12)----------------------TableScan: null_cases projection=[c1, c2, c3] physical_plan -01)ProjectionExec: expr=[SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@18 as a, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@18 as b, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@3 as c, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@11 as d, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@7 as e, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@3 as f, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@11 as g, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@19 as h, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as i, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@12 as j, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as k, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@8 as l, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@17 as m, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@15 as n, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as o, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@16 as p, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@20 as a1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@20 as b1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@5 as c1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@13 as d1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@9 as e1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@5 as f1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@13 as g1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@19 as h1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@19 as j1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as k1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@12 as l1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@8 as m1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as n1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@12 as o1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@21 as h11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@21 as j11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@6 as k11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@14 as l11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@10 as m11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@6 as n11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@14 as o11] +01)ProjectionExec: expr=[sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@18 as a, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@18 as b, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@3 as c, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@11 as d, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@7 as e, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@3 as f, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@11 as g, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@19 as h, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as i, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@12 as j, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as k, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@8 as l, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@17 as m, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@15 as n, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as o, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@16 as p, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@20 as a1, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@20 as b1, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@5 as c1, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@13 as d1, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@9 as e1, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@5 as f1, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@13 as g1, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@19 as h1, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@19 as j1, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as k1, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@12 as l1, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@8 as m1, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as n1, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@12 as o1, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@21 as h11, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@21 as j11, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@6 as k11, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@14 as l11, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@10 as m11, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@6 as n11, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@14 as o11] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----WindowAggExec: wdw=[SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)), is_causal: false }, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)), is_causal: false }, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)), is_causal: false }] -04)------ProjectionExec: expr=[c1@0 as c1, c3@2 as c3, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@4 as SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@6 as SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@7 as SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@8 as SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@9 as SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@10 as SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@11 as SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@12 as SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@13 as SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@14 as SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@15 as SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@16 as SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@17 as SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@18 as SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] -05)--------BoundedWindowAggExec: wdw=[SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -06)----------SortExec: expr=[c3@2 ASC NULLS LAST,c2@1 ASC NULLS LAST] -07)------------BoundedWindowAggExec: wdw=[SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -08)--------------SortExec: expr=[c3@2 ASC NULLS LAST,c1@0 ASC] -09)----------------BoundedWindowAggExec: wdw=[SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -10)------------------SortExec: expr=[c3@2 ASC NULLS LAST,c1@0 DESC] -11)--------------------WindowAggExec: wdw=[SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(10)), is_causal: false }, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)), is_causal: false }, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(NULL)), is_causal: false }, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }] -12)----------------------WindowAggExec: wdw=[SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)), is_causal: false }, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)), is_causal: false }, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)), is_causal: false }] -13)------------------------SortExec: expr=[c3@2 DESC NULLS LAST] -14)--------------------------WindowAggExec: wdw=[SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)), is_causal: false }, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)), is_causal: false }, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)), is_causal: false }] -15)----------------------------BoundedWindowAggExec: wdw=[SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -16)------------------------------SortExec: expr=[c3@2 DESC,c1@0 ASC NULLS LAST] +03)----WindowAggExec: wdw=[sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)), is_causal: false }, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)), is_causal: false }, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)), is_causal: false }] +04)------ProjectionExec: expr=[c1@0 as c1, c3@2 as c3, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@4 as sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@6 as sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@7 as sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@8 as sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@9 as sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@10 as sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@11 as sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@12 as sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@13 as sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@14 as sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@15 as sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@16 as sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@17 as sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@18 as sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +05)--------BoundedWindowAggExec: wdw=[sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +06)----------SortExec: expr=[c3@2 ASC NULLS LAST, c2@1 ASC NULLS LAST], preserve_partitioning=[false] +07)------------BoundedWindowAggExec: wdw=[sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +08)--------------SortExec: expr=[c3@2 ASC NULLS LAST, c1@0 ASC], preserve_partitioning=[false] +09)----------------BoundedWindowAggExec: wdw=[sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +10)------------------SortExec: expr=[c3@2 ASC NULLS LAST, c1@0 DESC], preserve_partitioning=[false] +11)--------------------WindowAggExec: wdw=[sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(10)), is_causal: false }, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)), is_causal: false }, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(NULL)), is_causal: false }, sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }] +12)----------------------WindowAggExec: wdw=[sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)), is_causal: false }, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)), is_causal: false }, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)), is_causal: false }] +13)------------------------SortExec: expr=[c3@2 DESC NULLS LAST], preserve_partitioning=[false] +14)--------------------------WindowAggExec: wdw=[sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)), is_causal: false }, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)), is_causal: false }, sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)), is_causal: false }] +15)----------------------------BoundedWindowAggExec: wdw=[sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +16)------------------------------SortExec: expr=[c3@2 DESC, c1@0 ASC NULLS LAST], preserve_partitioning=[false] 17)--------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/null_cases.csv]]}, projection=[c1, c2, c3], has_header=true query IIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIII @@ -1597,17 +1629,17 @@ EXPLAIN SELECT LIMIT 5 ---- logical_plan -01)Projection: aggregate_test_100.c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum2 +01)Projection: aggregate_test_100.c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum1, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum2 02)--Limit: skip=0, fetch=5 -03)----WindowAggr: windowExpr=[[SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] -04)------WindowAggr: windowExpr=[[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] +03)----WindowAggr: windowExpr=[[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] +04)------WindowAggr: windowExpr=[[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] 05)--------TableScan: aggregate_test_100 projection=[c1, c9] physical_plan -01)ProjectionExec: expr=[c9@1 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@3 as sum2] +01)ProjectionExec: expr=[c9@1 as c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as sum1, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@3 as sum2] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] -04)------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] -05)--------SortExec: expr=[c1@0 ASC NULLS LAST,c9@1 DESC] +03)----BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +05)--------SortExec: expr=[c1@0 ASC NULLS LAST, c9@1 DESC], preserve_partitioning=[false] 06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c9], has_header=true @@ -1641,17 +1673,17 @@ EXPLAIN SELECT LIMIT 5 ---- logical_plan -01)Projection: aggregate_test_100.c9, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum2 +01)Projection: aggregate_test_100.c9, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum1, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum2 02)--Limit: skip=0, fetch=5 -03)----WindowAggr: windowExpr=[[SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] -04)------WindowAggr: windowExpr=[[SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] +03)----WindowAggr: windowExpr=[[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] +04)------WindowAggr: windowExpr=[[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] 05)--------TableScan: aggregate_test_100 projection=[c1, c9] physical_plan -01)ProjectionExec: expr=[c9@1 as c9, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@3 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as sum2] +01)ProjectionExec: expr=[c9@1 as c9, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@3 as sum1, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as sum2] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -04)------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] -05)--------SortExec: expr=[c1@0 ASC NULLS LAST,c9@1 DESC] +03)----BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +05)--------SortExec: expr=[c1@0 ASC NULLS LAST, c9@1 DESC], preserve_partitioning=[false] 06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c9], has_header=true query III @@ -1685,22 +1717,22 @@ EXPLAIN SELECT c3, LIMIT 5 ---- logical_plan -01)Projection: aggregate_test_100.c3, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum2 +01)Projection: aggregate_test_100.c3, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum1, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum2 02)--Limit: skip=0, fetch=5 -03)----WindowAggr: windowExpr=[[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4aggregate_test_100.c4aggregate_test_100.c3 AS aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -04)------Projection: aggregate_test_100.c3 + aggregate_test_100.c4aggregate_test_100.c4aggregate_test_100.c3, aggregate_test_100.c3, aggregate_test_100.c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW -05)--------WindowAggr: windowExpr=[[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4aggregate_test_100.c4aggregate_test_100.c3 AS aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -06)----------Projection: aggregate_test_100.c3 + aggregate_test_100.c4 AS aggregate_test_100.c3 + aggregate_test_100.c4aggregate_test_100.c4aggregate_test_100.c3, aggregate_test_100.c2, aggregate_test_100.c3, aggregate_test_100.c9 +03)----WindowAggr: windowExpr=[[sum(aggregate_test_100.c9) ORDER BY [__common_expr_1 AS aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------Projection: __common_expr_1, aggregate_test_100.c3, aggregate_test_100.c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +05)--------WindowAggr: windowExpr=[[sum(aggregate_test_100.c9) ORDER BY [__common_expr_1 AS aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +06)----------Projection: aggregate_test_100.c3 + aggregate_test_100.c4 AS __common_expr_1, aggregate_test_100.c2, aggregate_test_100.c3, aggregate_test_100.c9 07)------------TableScan: aggregate_test_100 projection=[c2, c3, c4, c9] physical_plan -01)ProjectionExec: expr=[c3@1 as c3, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum2] +01)ProjectionExec: expr=[c3@1 as c3, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum1, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum2] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----WindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int16(NULL)), is_causal: false }] -04)------ProjectionExec: expr=[aggregate_test_100.c3 + aggregate_test_100.c4aggregate_test_100.c4aggregate_test_100.c3@0 as aggregate_test_100.c3 + aggregate_test_100.c4aggregate_test_100.c4aggregate_test_100.c3, c3@2 as c3, c9@3 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] -05)--------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int16(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -06)----------SortPreservingMergeExec: [aggregate_test_100.c3 + aggregate_test_100.c4aggregate_test_100.c4aggregate_test_100.c3@0 DESC,c9@3 DESC,c2@1 ASC NULLS LAST] -07)------------SortExec: expr=[aggregate_test_100.c3 + aggregate_test_100.c4aggregate_test_100.c4aggregate_test_100.c3@0 DESC,c9@3 DESC,c2@1 ASC NULLS LAST] -08)--------------ProjectionExec: expr=[c3@1 + c4@2 as aggregate_test_100.c3 + aggregate_test_100.c4aggregate_test_100.c4aggregate_test_100.c3, c2@0 as c2, c3@1 as c3, c9@3 as c9] +03)----WindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int16(NULL)), is_causal: false }] +04)------ProjectionExec: expr=[__common_expr_1@0 as __common_expr_1, c3@2 as c3, c9@3 as c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +05)--------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int16(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +06)----------SortPreservingMergeExec: [__common_expr_1@0 DESC, c9@3 DESC, c2@1 ASC NULLS LAST] +07)------------SortExec: expr=[__common_expr_1@0 DESC, c9@3 DESC, c2@1 ASC NULLS LAST], preserve_partitioning=[true] +08)--------------ProjectionExec: expr=[c3@1 + c4@2 as __common_expr_1, c2@0 as c2, c3@1 as c3, c9@3 as c9] 09)----------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 10)------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3, c4, c9], has_header=true @@ -1732,31 +1764,29 @@ EXPLAIN SELECT count(*) as global_count FROM ORDER BY c1 ) AS a ---- logical_plan -01)Projection: COUNT(*) AS global_count -02)--Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1)) AS COUNT(*)]] +01)Projection: count(*) AS global_count +02)--Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] 03)----SubqueryAlias: a 04)------Projection: -05)--------Sort: aggregate_test_100.c1 ASC NULLS LAST -06)----------Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[]] -07)------------Projection: aggregate_test_100.c1 -08)--------------Filter: aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434") -09)----------------TableScan: aggregate_test_100 projection=[c1, c13], partial_filters=[aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434")] +05)--------Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[]] +06)----------Projection: aggregate_test_100.c1 +07)------------Filter: aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434") +08)--------------TableScan: aggregate_test_100 projection=[c1, c13], partial_filters=[aggregate_test_100.c13 != Utf8("C2GT5KVyOPZpgKVl110TyZO0NcJ434")] physical_plan -01)ProjectionExec: expr=[COUNT(*)@0 as global_count] -02)--AggregateExec: mode=Final, gby=[], aggr=[COUNT(*)] +01)ProjectionExec: expr=[count(*)@0 as global_count] +02)--AggregateExec: mode=Final, gby=[], aggr=[count(*)] 03)----CoalescePartitionsExec -04)------AggregateExec: mode=Partial, gby=[], aggr=[COUNT(*)] -05)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=2 -06)----------ProjectionExec: expr=[] -07)------------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[] -08)--------------CoalesceBatchesExec: target_batch_size=4096 -09)----------------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 -10)------------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[] -11)--------------------ProjectionExec: expr=[c1@0 as c1] -12)----------------------CoalesceBatchesExec: target_batch_size=4096 -13)------------------------FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434 -14)--------------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -15)----------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c13], has_header=true +04)------AggregateExec: mode=Partial, gby=[], aggr=[count(*)] +05)--------ProjectionExec: expr=[] +06)----------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[] +07)------------CoalesceBatchesExec: target_batch_size=4096 +08)--------------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 +09)----------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[] +10)------------------CoalesceBatchesExec: target_batch_size=4096 +11)--------------------FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434, projection=[c1@0] +12)----------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +13)------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c13], has_header=true + query I SELECT count(*) as global_count FROM @@ -1783,26 +1813,24 @@ EXPLAIN SELECT c3, LIMIT 5 ---- logical_plan -01)Limit: skip=0, fetch=5 -02)--Sort: aggregate_test_100.c3 ASC NULLS LAST, fetch=5 -03)----Projection: aggregate_test_100.c3, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c3] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum2 -04)------WindowAggr: windowExpr=[[SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c3] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -05)--------Projection: aggregate_test_100.c3, aggregate_test_100.c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW -06)----------WindowAggr: windowExpr=[[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -07)------------TableScan: aggregate_test_100 projection=[c2, c3, c9] +01)Sort: aggregate_test_100.c3 ASC NULLS LAST, fetch=5 +02)--Projection: aggregate_test_100.c3, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum1, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c3] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum2 +03)----WindowAggr: windowExpr=[[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c3] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------Projection: aggregate_test_100.c3, aggregate_test_100.c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +05)--------WindowAggr: windowExpr=[[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +06)----------TableScan: aggregate_test_100 projection=[c2, c3, c9] physical_plan -01)GlobalLimitExec: skip=0, fetch=5 -02)--SortPreservingMergeExec: [c3@0 ASC NULLS LAST], fetch=5 -03)----ProjectionExec: expr=[c3@0 as c3, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c3] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum2] -04)------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c3] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c3] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -05)--------SortExec: expr=[c3@0 ASC NULLS LAST,c9@1 DESC] -06)----------CoalesceBatchesExec: target_batch_size=4096 -07)------------RepartitionExec: partitioning=Hash([c3@0], 2), input_partitions=2 -08)--------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -09)----------------ProjectionExec: expr=[c3@1 as c3, c9@2 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] -10)------------------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int16(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -11)--------------------SortExec: expr=[c3@1 DESC,c9@2 DESC,c2@0 ASC NULLS LAST] -12)----------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3, c9], has_header=true +01)SortPreservingMergeExec: [c3@0 ASC NULLS LAST], fetch=5 +02)--ProjectionExec: expr=[c3@0 as c3, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as sum1, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c3] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum2] +03)----BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c3] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c3] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +04)------SortExec: expr=[c3@0 ASC NULLS LAST, c9@1 DESC], preserve_partitioning=[true] +05)--------CoalesceBatchesExec: target_batch_size=4096 +06)----------RepartitionExec: partitioning=Hash([c3@0], 2), input_partitions=2 +07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 +08)--------------ProjectionExec: expr=[c3@1 as c3, c9@2 as c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +09)----------------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int16(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +10)------------------SortExec: expr=[c3@1 DESC, c9@2 DESC, c2@0 ASC NULLS LAST], preserve_partitioning=[false] +11)--------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2, c3, c9], has_header=true @@ -1830,14 +1858,14 @@ EXPLAIN SELECT c1, ROW_NUMBER() OVER (PARTITION BY c1) as rn1 FROM aggregate_tes ---- logical_plan 01)Sort: aggregate_test_100.c1 ASC NULLS LAST -02)--Projection: aggregate_test_100.c1, ROW_NUMBER() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS rn1 -03)----WindowAggr: windowExpr=[[ROW_NUMBER() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +02)--Projection: aggregate_test_100.c1, row_number() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS rn1 +03)----WindowAggr: windowExpr=[[row_number() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] 04)------TableScan: aggregate_test_100 projection=[c1] physical_plan 01)SortPreservingMergeExec: [c1@0 ASC NULLS LAST] -02)--ProjectionExec: expr=[c1@0 as c1, ROW_NUMBER() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as rn1] -03)----BoundedWindowAggExec: wdw=[ROW_NUMBER() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "ROW_NUMBER() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] -04)------SortExec: expr=[c1@0 ASC NULLS LAST] +02)--ProjectionExec: expr=[c1@0 as c1, row_number() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as rn1] +03)----BoundedWindowAggExec: wdw=[row_number() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "row_number() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +04)------SortExec: expr=[c1@0 ASC NULLS LAST], preserve_partitioning=[true] 05)--------CoalesceBatchesExec: target_batch_size=4096 06)----------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -1959,14 +1987,14 @@ EXPLAIN SELECT c1, ROW_NUMBER() OVER (PARTITION BY c1) as rn1 FROM aggregate_tes ---- logical_plan 01)Sort: aggregate_test_100.c1 ASC NULLS LAST -02)--Projection: aggregate_test_100.c1, ROW_NUMBER() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS rn1 -03)----WindowAggr: windowExpr=[[ROW_NUMBER() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +02)--Projection: aggregate_test_100.c1, row_number() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS rn1 +03)----WindowAggr: windowExpr=[[row_number() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] 04)------TableScan: aggregate_test_100 projection=[c1] physical_plan -01)SortPreservingMergeExec: [c1@0 ASC NULLS LAST,rn1@1 ASC NULLS LAST] -02)--ProjectionExec: expr=[c1@0 as c1, ROW_NUMBER() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as rn1] -03)----BoundedWindowAggExec: wdw=[ROW_NUMBER() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "ROW_NUMBER() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] -04)------SortExec: expr=[c1@0 ASC NULLS LAST] +01)SortPreservingMergeExec: [c1@0 ASC NULLS LAST, rn1@1 ASC NULLS LAST] +02)--ProjectionExec: expr=[c1@0 as c1, row_number() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as rn1] +03)----BoundedWindowAggExec: wdw=[row_number() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "row_number() PARTITION BY [aggregate_test_100.c1] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +04)------SortExec: expr=[c1@0 ASC NULLS LAST], preserve_partitioning=[true] 05)--------CoalesceBatchesExec: target_batch_size=4096 06)----------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -1984,18 +2012,18 @@ EXPLAIN SELECT c1, ---- logical_plan 01)Sort: aggregate_test_100.c1 ASC NULLS LAST -02)--Projection: aggregate_test_100.c1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 3 FOLLOWING AS sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum2 -03)----WindowAggr: windowExpr=[[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] -04)------WindowAggr: windowExpr=[[SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 3 FOLLOWING]] +02)--Projection: aggregate_test_100.c1, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 3 FOLLOWING AS sum1, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum2 +03)----WindowAggr: windowExpr=[[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] +04)------WindowAggr: windowExpr=[[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 3 FOLLOWING]] 05)--------TableScan: aggregate_test_100 projection=[c1, c9] physical_plan -01)SortExec: expr=[c1@0 ASC NULLS LAST] -02)--ProjectionExec: expr=[c1@0 as c1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 3 FOLLOWING@2 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@3 as sum2] -03)----BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +01)SortExec: expr=[c1@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--ProjectionExec: expr=[c1@0 as c1, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 3 FOLLOWING@2 as sum1, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@3 as sum2] +03)----BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] 04)------SortPreservingMergeExec: [c9@1 ASC NULLS LAST] -05)--------SortExec: expr=[c9@1 ASC NULLS LAST] -06)----------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 3 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 3 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(3)), is_causal: false }], mode=[Sorted] -07)------------SortExec: expr=[c1@0 ASC NULLS LAST,c9@1 ASC NULLS LAST] +05)--------SortExec: expr=[c9@1 ASC NULLS LAST], preserve_partitioning=[true] +06)----------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 3 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 3 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(3)), is_causal: false }], mode=[Sorted] +07)------------SortExec: expr=[c1@0 ASC NULLS LAST, c9@1 ASC NULLS LAST], preserve_partitioning=[true] 08)--------------CoalesceBatchesExec: target_batch_size=4096 09)----------------RepartitionExec: partitioning=Hash([c1@0], 2), input_partitions=2 10)------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -2009,20 +2037,18 @@ query TT EXPLAIN SELECT ARRAY_AGG(c13) as array_agg1 FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT 1) ---- logical_plan -01)Projection: ARRAY_AGG(aggregate_test_100.c13) AS array_agg1 -02)--Aggregate: groupBy=[[]], aggr=[[ARRAY_AGG(aggregate_test_100.c13)]] -03)----Limit: skip=0, fetch=1 -04)------Sort: aggregate_test_100.c13 ASC NULLS LAST, fetch=1 -05)--------TableScan: aggregate_test_100 projection=[c13] +01)Projection: array_agg(aggregate_test_100.c13) AS array_agg1 +02)--Aggregate: groupBy=[[]], aggr=[[array_agg(aggregate_test_100.c13)]] +03)----Sort: aggregate_test_100.c13 ASC NULLS LAST, fetch=1 +04)------TableScan: aggregate_test_100 projection=[c13] physical_plan -01)ProjectionExec: expr=[ARRAY_AGG(aggregate_test_100.c13)@0 as array_agg1] -02)--AggregateExec: mode=Final, gby=[], aggr=[ARRAY_AGG(aggregate_test_100.c13)] +01)ProjectionExec: expr=[array_agg(aggregate_test_100.c13)@0 as array_agg1] +02)--AggregateExec: mode=Final, gby=[], aggr=[array_agg(aggregate_test_100.c13)] 03)----CoalescePartitionsExec -04)------AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(aggregate_test_100.c13)] +04)------AggregateExec: mode=Partial, gby=[], aggr=[array_agg(aggregate_test_100.c13)] 05)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 -06)----------GlobalLimitExec: skip=0, fetch=1 -07)------------SortExec: TopK(fetch=1), expr=[c13@0 ASC NULLS LAST] -08)--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c13], has_header=true +06)----------SortExec: TopK(fetch=1), expr=[c13@0 ASC NULLS LAST], preserve_partitioning=[false] +07)------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c13], has_header=true query ? @@ -2070,26 +2096,24 @@ EXPLAIN SELECT LIMIT 5 ---- logical_plan -01)Limit: skip=0, fetch=5 -02)--Sort: aggregate_test_100.c9 ASC NULLS LAST, fetch=5 -03)----Projection: aggregate_test_100.c9, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum2, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS sum3, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS sum4 -04)------WindowAggr: windowExpr=[[SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] -05)--------Projection: aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c9, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING -06)----------WindowAggr: windowExpr=[[SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]] -07)------------WindowAggr: windowExpr=[[SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] -08)--------------WindowAggr: windowExpr=[[SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]] -09)----------------TableScan: aggregate_test_100 projection=[c1, c2, c8, c9] +01)Sort: aggregate_test_100.c9 ASC NULLS LAST, fetch=5 +02)--Projection: aggregate_test_100.c9, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum1, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum2, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS sum3, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS sum4 +03)----WindowAggr: windowExpr=[[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] +04)------Projection: aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c9, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING +05)--------WindowAggr: windowExpr=[[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]] +06)----------WindowAggr: windowExpr=[[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] +07)------------WindowAggr: windowExpr=[[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]] +08)--------------TableScan: aggregate_test_100 projection=[c1, c2, c8, c9] physical_plan -01)GlobalLimitExec: skip=0, fetch=5 -02)--SortExec: TopK(fetch=5), expr=[c9@0 ASC NULLS LAST] -03)----ProjectionExec: expr=[c9@2 as c9, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@6 as sum2, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@3 as sum3, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@5 as sum4] -04)------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] -05)--------ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c9@3 as c9, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@4 as SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@5 as SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@6 as SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING] -06)----------WindowAggExec: wdw=[SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(NULL)), is_causal: false }] -07)------------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] -08)--------------WindowAggExec: wdw=[SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(NULL)), is_causal: false }] -09)----------------SortExec: expr=[c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST,c9@3 ASC NULLS LAST,c8@2 ASC NULLS LAST] -10)------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c8, c9], has_header=true +01)SortExec: TopK(fetch=5), expr=[c9@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--ProjectionExec: expr=[c9@2 as c9, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as sum1, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@6 as sum2, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@3 as sum3, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@5 as sum4] +03)----BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +04)------ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c9@3 as c9, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@4 as sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@5 as sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@6 as sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING] +05)--------WindowAggExec: wdw=[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c2, aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(NULL)), is_causal: false }] +06)----------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +07)------------WindowAggExec: wdw=[sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(NULL)), is_causal: false }] +08)--------------SortExec: expr=[c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST, c9@3 ASC NULLS LAST, c8@2 ASC NULLS LAST], preserve_partitioning=[false] +09)----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c8, c9], has_header=true @@ -2124,28 +2148,27 @@ EXPLAIN SELECT c9, LIMIT 5 ---- logical_plan -01)Projection: t1.c9, SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum1, SUM(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum2, SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS sum3, SUM(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS sum4 +01)Projection: t1.c9, sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum1, sum(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sum2, sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS sum3, sum(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS sum4 02)--Limit: skip=0, fetch=5 -03)----WindowAggr: windowExpr=[[SUM(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] -04)------Projection: t1.c2, t1.c9, t1.c1_alias, SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING, SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, SUM(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING -05)--------WindowAggr: windowExpr=[[SUM(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]] -06)----------Projection: t1.c2, t1.c8, t1.c9, t1.c1_alias, SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING, SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING -07)------------WindowAggr: windowExpr=[[SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] -08)--------------WindowAggr: windowExpr=[[SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]] +03)----WindowAggr: windowExpr=[[sum(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] +04)------Projection: t1.c2, t1.c9, t1.c1_alias, sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING, sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, sum(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING +05)--------WindowAggr: windowExpr=[[sum(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]] +06)----------Projection: t1.c2, t1.c8, t1.c9, t1.c1_alias, sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING, sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING +07)------------WindowAggr: windowExpr=[[sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING]] +08)--------------WindowAggr: windowExpr=[[sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]] 09)----------------SubqueryAlias: t1 -10)------------------Sort: aggregate_test_100.c9 ASC NULLS LAST -11)--------------------Projection: aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c8, aggregate_test_100.c9, aggregate_test_100.c1 AS c1_alias -12)----------------------TableScan: aggregate_test_100 projection=[c1, c2, c8, c9] +10)------------------Projection: aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c8, aggregate_test_100.c9, aggregate_test_100.c1 AS c1_alias +11)--------------------TableScan: aggregate_test_100 projection=[c1, c2, c8, c9] physical_plan -01)ProjectionExec: expr=[c9@1 as c9, SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as sum1, SUM(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@6 as sum2, SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@3 as sum3, SUM(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@5 as sum4] +01)ProjectionExec: expr=[c9@1 as c9, sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as sum1, sum(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@6 as sum2, sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@3 as sum3, sum(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@5 as sum4] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----BoundedWindowAggExec: wdw=[SUM(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] -04)------ProjectionExec: expr=[c2@0 as c2, c9@2 as c9, c1_alias@3 as c1_alias, SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@4 as SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING, SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@5 as SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, SUM(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@6 as SUM(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING] -05)--------WindowAggExec: wdw=[SUM(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(NULL)), is_causal: false }] -06)----------ProjectionExec: expr=[c2@1 as c2, c8@2 as c8, c9@3 as c9, c1_alias@4 as c1_alias, SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@5 as SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING, SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@6 as SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING] -07)------------BoundedWindowAggExec: wdw=[SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] -08)--------------WindowAggExec: wdw=[SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(NULL)), is_causal: false }] -09)----------------SortExec: expr=[c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST,c9@3 ASC NULLS LAST,c8@2 ASC NULLS LAST] +03)----BoundedWindowAggExec: wdw=[sum(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +04)------ProjectionExec: expr=[c2@0 as c2, c9@2 as c9, c1_alias@3 as c1_alias, sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@4 as sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING, sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@5 as sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, sum(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@6 as sum(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING] +05)--------WindowAggExec: wdw=[sum(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(t1.c9) PARTITION BY [t1.c2, t1.c1_alias] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(NULL)), is_causal: false }] +06)----------ProjectionExec: expr=[c2@1 as c2, c8@2 as c8, c9@3 as c9, c1_alias@4 as c1_alias, sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING@5 as sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING, sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@6 as sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING] +07)------------BoundedWindowAggExec: wdw=[sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +08)--------------WindowAggExec: wdw=[sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(t1.c9) PARTITION BY [t1.c1, t1.c2] ORDER BY [t1.c9 ASC NULLS LAST, t1.c8 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(NULL)), is_causal: false }] +09)----------------SortExec: expr=[c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST, c9@3 ASC NULLS LAST, c8@2 ASC NULLS LAST], preserve_partitioning=[false] 10)------------------ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c8@2 as c8, c9@3 as c9, c1@0 as c1_alias] 11)--------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c8, c9], has_header=true @@ -2175,23 +2198,21 @@ EXPLAIN SELECT SUM(c12) OVER(ORDER BY c1, c2 GROUPS BETWEEN 1 PRECEDING AND 1 FO ---- logical_plan 01)Projection: sum1, sum2 -02)--Limit: skip=0, fetch=5 -03)----Sort: aggregate_test_100.c9 ASC NULLS LAST, fetch=5 -04)------Projection: SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS sum1, SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING AS sum2, aggregate_test_100.c9 -05)--------WindowAggr: windowExpr=[[SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING]] -06)----------Projection: aggregate_test_100.c1, aggregate_test_100.c9, aggregate_test_100.c12, SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING -07)------------WindowAggr: windowExpr=[[SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] -08)--------------TableScan: aggregate_test_100 projection=[c1, c2, c9, c12] +02)--Sort: aggregate_test_100.c9 ASC NULLS LAST, fetch=5 +03)----Projection: sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING AS sum1, sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING AS sum2, aggregate_test_100.c9 +04)------WindowAggr: windowExpr=[[sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING]] +05)--------Projection: aggregate_test_100.c1, aggregate_test_100.c9, aggregate_test_100.c12, sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING +06)----------WindowAggr: windowExpr=[[sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING]] +07)------------TableScan: aggregate_test_100 projection=[c1, c2, c9, c12] physical_plan 01)ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2] -02)--GlobalLimitExec: skip=0, fetch=5 -03)----SortExec: TopK(fetch=5), expr=[c9@2 ASC NULLS LAST] -04)------ProjectionExec: expr=[SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as sum1, SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING@4 as sum2, c9@1 as c9] -05)--------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING: Ok(Field { name: "SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Groups, start_bound: Preceding(UInt64(5)), end_bound: Preceding(UInt64(3)), is_causal: false }], mode=[Sorted] -06)----------ProjectionExec: expr=[c1@0 as c1, c9@2 as c9, c12@3 as c12, SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING] -07)------------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Groups, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -08)--------------SortExec: expr=[c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST] -09)----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c9, c12], has_header=true +02)--SortExec: TopK(fetch=5), expr=[c9@2 ASC NULLS LAST], preserve_partitioning=[false] +03)----ProjectionExec: expr=[sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING@3 as sum1, sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING@4 as sum2, c9@1 as c9] +04)------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING: Ok(Field { name: "sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST] GROUPS BETWEEN 5 PRECEDING AND 3 PRECEDING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Groups, start_bound: Preceding(UInt64(5)), end_bound: Preceding(UInt64(3)), is_causal: true }], mode=[Sorted] +05)--------ProjectionExec: expr=[c1@0 as c1, c9@2 as c9, c12@3 as c12, sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING@4 as sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING] +06)----------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Groups, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +07)------------SortExec: expr=[c1@0 ASC NULLS LAST, c2@1 ASC NULLS LAST], preserve_partitioning=[false] +08)--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c9, c12], has_header=true query RR SELECT SUM(c12) OVER(ORDER BY c1, c2 GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as sum1, @@ -2205,7 +2226,7 @@ SELECT SUM(c12) OVER(ORDER BY c1, c2 GROUPS BETWEEN 1 PRECEDING AND 1 FOLLOWING) 7.728066219895 NULL # test_c9_rn_ordering_alias -# These tests check whether Datafusion is aware of the ordering generated by the ROW_NUMBER() window function. +# These tests check whether DataFusion is aware of the ordering generated by the ROW_NUMBER() window function. # Physical plan shouldn't have a SortExec after the BoundedWindowAggExec since the table after BoundedWindowAggExec is already ordered by rn1 ASC and c9 DESC. query TT EXPLAIN SELECT c9, rn1 FROM (SELECT c9, @@ -2216,17 +2237,15 @@ EXPLAIN SELECT c9, rn1 FROM (SELECT c9, LIMIT 5 ---- logical_plan -01)Limit: skip=0, fetch=5 -02)--Sort: rn1 ASC NULLS LAST, fetch=5 -03)----Sort: aggregate_test_100.c9 ASC NULLS LAST -04)------Projection: aggregate_test_100.c9, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rn1 -05)--------WindowAggr: windowExpr=[[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -06)----------TableScan: aggregate_test_100 projection=[c9] +01)Sort: rn1 ASC NULLS LAST, fetch=5 +02)--Projection: aggregate_test_100.c9, row_number() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rn1 +03)----WindowAggr: windowExpr=[[row_number() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------TableScan: aggregate_test_100 projection=[c9] physical_plan -01)GlobalLimitExec: skip=0, fetch=5 -02)--ProjectionExec: expr=[c9@0 as c9, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as rn1] -03)----BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -04)------SortExec: expr=[c9@0 ASC NULLS LAST] +01)ProjectionExec: expr=[c9@0 as c9, row_number() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as rn1] +02)--GlobalLimitExec: skip=0, fetch=5 +03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "row_number() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +04)------SortExec: expr=[c9@0 ASC NULLS LAST], preserve_partitioning=[false] 05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], has_header=true query II @@ -2244,7 +2263,7 @@ SELECT c9, rn1 FROM (SELECT c9, 145294611 5 # test_c9_rn_ordering_alias_opposite_direction -# These tests check whether Datafusion is aware of the ordering generated by the ROW_NUMBER() window function. +# These tests check whether DataFusion is aware of the ordering generated by the ROW_NUMBER() window function. # Physical plan shouldn't have a SortExec after the BoundedWindowAggExec since the table after BoundedWindowAggExec is already ordered by rn1 ASC and c9 DESC. query TT EXPLAIN SELECT c9, rn1 FROM (SELECT c9, @@ -2255,17 +2274,15 @@ EXPLAIN SELECT c9, rn1 FROM (SELECT c9, LIMIT 5 ---- logical_plan -01)Limit: skip=0, fetch=5 -02)--Sort: rn1 ASC NULLS LAST, fetch=5 -03)----Sort: aggregate_test_100.c9 DESC NULLS FIRST -04)------Projection: aggregate_test_100.c9, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rn1 -05)--------WindowAggr: windowExpr=[[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -06)----------TableScan: aggregate_test_100 projection=[c9] +01)Sort: rn1 ASC NULLS LAST, fetch=5 +02)--Projection: aggregate_test_100.c9, row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rn1 +03)----WindowAggr: windowExpr=[[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------TableScan: aggregate_test_100 projection=[c9] physical_plan -01)GlobalLimitExec: skip=0, fetch=5 -02)--ProjectionExec: expr=[c9@0 as c9, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as rn1] -03)----BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -04)------SortExec: expr=[c9@0 DESC] +01)ProjectionExec: expr=[c9@0 as c9, row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as rn1] +02)--GlobalLimitExec: skip=0, fetch=5 +03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +04)------SortExec: expr=[c9@0 DESC], preserve_partitioning=[false] 05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], has_header=true query II @@ -2283,7 +2300,7 @@ SELECT c9, rn1 FROM (SELECT c9, 4076864659 5 # test_c9_rn_ordering_alias_opposite_direction2 -# These tests check whether Datafusion is aware of the ordering generated by the ROW_NUMBER() window function. +# These tests check whether DataFusion is aware of the ordering generated by the ROW_NUMBER() window function. # Physical plan _should_ have a SortExec after BoundedWindowAggExec since the table after BoundedWindowAggExec is ordered by rn1 ASC and c9 DESC, which is conflicting with the requirement rn1 DESC. query TT EXPLAIN SELECT c9, rn1 FROM (SELECT c9, @@ -2294,19 +2311,16 @@ EXPLAIN SELECT c9, rn1 FROM (SELECT c9, LIMIT 5 ---- logical_plan -01)Limit: skip=0, fetch=5 -02)--Sort: rn1 DESC NULLS FIRST, fetch=5 -03)----Sort: aggregate_test_100.c9 DESC NULLS FIRST -04)------Projection: aggregate_test_100.c9, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rn1 -05)--------WindowAggr: windowExpr=[[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -06)----------TableScan: aggregate_test_100 projection=[c9] +01)Sort: rn1 DESC NULLS FIRST, fetch=5 +02)--Projection: aggregate_test_100.c9, row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rn1 +03)----WindowAggr: windowExpr=[[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------TableScan: aggregate_test_100 projection=[c9] physical_plan -01)GlobalLimitExec: skip=0, fetch=5 -02)--SortExec: TopK(fetch=5), expr=[rn1@1 DESC] -03)----ProjectionExec: expr=[c9@0 as c9, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as rn1] -04)------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -05)--------SortExec: expr=[c9@0 DESC] -06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], has_header=true +01)SortExec: TopK(fetch=5), expr=[rn1@1 DESC], preserve_partitioning=[false] +02)--ProjectionExec: expr=[c9@0 as c9, row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as rn1] +03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +04)------SortExec: expr=[c9@0 DESC], preserve_partitioning=[false] +05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], has_header=true query II SELECT c9, rn1 FROM (SELECT c9, @@ -2337,19 +2351,16 @@ EXPLAIN SELECT c9, rn1 FROM (SELECT c9, LIMIT 5 ---- logical_plan -01)Limit: skip=0, fetch=5 -02)--Sort: rn1 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST, fetch=5 -03)----Sort: aggregate_test_100.c9 DESC NULLS FIRST -04)------Projection: aggregate_test_100.c9, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rn1 -05)--------WindowAggr: windowExpr=[[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -06)----------TableScan: aggregate_test_100 projection=[c9] +01)Sort: rn1 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST, fetch=5 +02)--Projection: aggregate_test_100.c9, row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rn1 +03)----WindowAggr: windowExpr=[[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------TableScan: aggregate_test_100 projection=[c9] physical_plan -01)GlobalLimitExec: skip=0, fetch=5 -02)--SortExec: TopK(fetch=5), expr=[rn1@1 ASC NULLS LAST,c9@0 ASC NULLS LAST] -03)----ProjectionExec: expr=[c9@0 as c9, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as rn1] -04)------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -05)--------SortExec: expr=[c9@0 DESC] -06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], has_header=true +01)SortExec: TopK(fetch=5), expr=[rn1@1 ASC NULLS LAST, c9@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--ProjectionExec: expr=[c9@0 as c9, row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as rn1] +03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +04)------SortExec: expr=[c9@0 DESC], preserve_partitioning=[false] +05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], has_header=true query II SELECT c9, rn1 FROM (SELECT c9, @@ -2367,17 +2378,41 @@ SELECT c9, rn1 FROM (SELECT c9, # invalid window frame. null as preceding -statement error DataFusion error: Error during planning: Invalid window frame: frame offsets must be non negative integers +statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers select row_number() over (rows between null preceding and current row) from (select 1 a) x # invalid window frame. null as preceding -statement error DataFusion error: Error during planning: Invalid window frame: frame offsets must be non negative integers +statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers select row_number() over (rows between null preceding and current row) from (select 1 a) x # invalid window frame. negative as following -statement error DataFusion error: Error during planning: Invalid window frame: frame offsets must be non negative integers +statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers select row_number() over (rows between current row and -1 following) from (select 1 a) x +# invalid window frame. null as preceding +statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers +select row_number() over (order by a groups between null preceding and current row) from (select 1 a) x + +# invalid window frame. null as preceding +statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers +select row_number() over (order by a groups between null preceding and current row) from (select 1 a) x + +# invalid window frame. negative as following +statement error DataFusion error: Error during planning: Invalid window frame: frame offsets for ROWS / GROUPS must be non negative integers +select row_number() over (order by a groups between current row and -1 following) from (select 1 a) x + +# interval for rows +query I +select row_number() over (rows between '1' preceding and current row) from (select 1 a) x +---- +1 + +# interval for groups +query I +select row_number() over (order by a groups between '1' preceding and current row) from (select 1 a) x +---- +1 + # This test shows that ordering satisfy considers ordering equivalences, # and can simplify (reduce expression size) multi expression requirements during normalization # For the example below, requirement rn1 ASC, c9 DESC should be simplified to the rn1 ASC. @@ -2391,17 +2426,15 @@ EXPLAIN SELECT c9, rn1 FROM (SELECT c9, LIMIT 5 ---- logical_plan -01)Limit: skip=0, fetch=5 -02)--Sort: rn1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST, fetch=5 -03)----Sort: aggregate_test_100.c9 DESC NULLS FIRST -04)------Projection: aggregate_test_100.c9, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rn1 -05)--------WindowAggr: windowExpr=[[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -06)----------TableScan: aggregate_test_100 projection=[c9] +01)Sort: rn1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST, fetch=5 +02)--Projection: aggregate_test_100.c9, row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rn1 +03)----WindowAggr: windowExpr=[[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------TableScan: aggregate_test_100 projection=[c9] physical_plan -01)GlobalLimitExec: skip=0, fetch=5 -02)--ProjectionExec: expr=[c9@0 as c9, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as rn1] -03)----BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -04)------SortExec: expr=[c9@0 DESC] +01)ProjectionExec: expr=[c9@0 as c9, row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as rn1] +02)--GlobalLimitExec: skip=0, fetch=5 +03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +04)------SortExec: expr=[c9@0 DESC], preserve_partitioning=[false] 05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], has_header=true # This test shows that ordering equivalence can keep track of complex expressions (not just Column expressions) @@ -2415,17 +2448,15 @@ EXPLAIN SELECT c5, c9, rn1 FROM (SELECT c5, c9, LIMIT 5 ---- logical_plan -01)Limit: skip=0, fetch=5 -02)--Sort: rn1 ASC NULLS LAST, CAST(aggregate_test_100.c9 AS Int32) + aggregate_test_100.c5 DESC NULLS FIRST, fetch=5 -03)----Sort: CAST(aggregate_test_100.c9 AS Int32) + aggregate_test_100.c5 DESC NULLS FIRST -04)------Projection: aggregate_test_100.c5, aggregate_test_100.c9, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rn1 -05)--------WindowAggr: windowExpr=[[ROW_NUMBER() ORDER BY [CAST(aggregate_test_100.c9 AS Int32) + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS ROW_NUMBER() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -06)----------TableScan: aggregate_test_100 projection=[c5, c9] +01)Sort: rn1 ASC NULLS LAST, CAST(aggregate_test_100.c9 AS Int32) + aggregate_test_100.c5 DESC NULLS FIRST, fetch=5 +02)--Projection: aggregate_test_100.c5, aggregate_test_100.c9, row_number() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rn1 +03)----WindowAggr: windowExpr=[[row_number() ORDER BY [CAST(aggregate_test_100.c9 AS Int32) + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS row_number() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------TableScan: aggregate_test_100 projection=[c5, c9] physical_plan -01)GlobalLimitExec: skip=0, fetch=5 -02)--ProjectionExec: expr=[c5@0 as c5, c9@1 as c9, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as rn1] -03)----BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -04)------SortExec: expr=[CAST(c9@1 AS Int32) + c5@0 DESC] +01)ProjectionExec: expr=[c5@0 as c5, c9@1 as c9, row_number() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as rn1] +02)--GlobalLimitExec: skip=0, fetch=5 +03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "row_number() ORDER BY [aggregate_test_100.c9 + aggregate_test_100.c5 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +04)------SortExec: expr=[CAST(c9@1 AS Int32) + c5@0 DESC], preserve_partitioning=[false] 05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c5, c9], has_header=true # Ordering equivalence should be preserved during cast expression @@ -2438,17 +2469,15 @@ EXPLAIN SELECT c9, rn1 FROM (SELECT c9, LIMIT 5 ---- logical_plan -01)Limit: skip=0, fetch=5 -02)--Sort: rn1 ASC NULLS LAST, fetch=5 -03)----Sort: aggregate_test_100.c9 DESC NULLS FIRST -04)------Projection: aggregate_test_100.c9, CAST(ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS Int64) AS rn1 -05)--------WindowAggr: windowExpr=[[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -06)----------TableScan: aggregate_test_100 projection=[c9] +01)Sort: rn1 ASC NULLS LAST, fetch=5 +02)--Projection: aggregate_test_100.c9, CAST(row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS Int64) AS rn1 +03)----WindowAggr: windowExpr=[[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------TableScan: aggregate_test_100 projection=[c9] physical_plan -01)GlobalLimitExec: skip=0, fetch=5 -02)--ProjectionExec: expr=[c9@0 as c9, CAST(ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 AS Int64) as rn1] -03)----BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -04)------SortExec: expr=[c9@0 DESC] +01)ProjectionExec: expr=[c9@0 as c9, CAST(row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 AS Int64) as rn1] +02)--GlobalLimitExec: skip=0, fetch=5 +03)----BoundedWindowAggExec: wdw=[row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "row_number() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +04)------SortExec: expr=[c9@0 DESC], preserve_partitioning=[false] 05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], has_header=true # The following query has type error. We should test the error could be detected @@ -2488,10 +2517,9 @@ CREATE EXTERNAL TABLE annotated_data_finite ( desc_col INTEGER, ) STORED AS CSV -WITH HEADER ROW WITH ORDER (ts ASC) LOCATION '../core/tests/data/window_1.csv' -; +OPTIONS ('format.has_header' 'true'); # 100 rows. Columns in the table are ts, inc_col, desc_col. # Source is CsvExec which is ordered by ts column. @@ -2503,9 +2531,9 @@ CREATE UNBOUNDED EXTERNAL TABLE annotated_data_infinite ( desc_col INTEGER, ) STORED AS CSV -WITH HEADER ROW WITH ORDER (ts ASC) -LOCATION '../core/tests/data/window_1.csv'; +LOCATION '../core/tests/data/window_1.csv' +OPTIONS ('format.has_header' 'true'); # test_source_sorted_aggregate @@ -2541,26 +2569,24 @@ EXPLAIN SELECT ---- logical_plan 01)Projection: sum1, sum2, sum3, min1, min2, min3, max1, max2, max3, cnt1, cnt2, sumr1, sumr2, sumr3, minr1, minr2, minr3, maxr1, maxr2, maxr3, cntr1, cntr2, sum4, cnt3 -02)--Limit: skip=0, fetch=5 -03)----Sort: annotated_data_finite.inc_col DESC NULLS FIRST, fetch=5 -04)------Projection: SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS sum1, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS sum2, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING AS sum3, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS min1, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS min2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING AS min3, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS max1, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS max2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING AS max3, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING AS cnt1, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS cnt2, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING AS sumr1, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING AS sumr2, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sumr3, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS minr1, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS minr2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING AS minr3, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS maxr1, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS maxr2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING AS maxr3, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS cntr1, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS cntr2, SUM(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS sum4, COUNT(*) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS cnt3, annotated_data_finite.inc_col -05)--------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_finite.desc_col AS Int64)annotated_data_finite.desc_col AS annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING, COUNT(Int64(1)) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS COUNT(*) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING]] -06)----------Projection: CAST(annotated_data_finite.desc_col AS Int64)annotated_data_finite.desc_col, annotated_data_finite.inc_col, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING -07)------------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_finite.inc_col AS Int64)annotated_data_finite.inc_col AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, SUM(CAST(annotated_data_finite.desc_col AS Int64)annotated_data_finite.desc_col AS annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, SUM(CAST(annotated_data_finite.inc_col AS Int64)annotated_data_finite.inc_col AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, COUNT(Int64(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING AS COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING, COUNT(Int64(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING]] -08)--------------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_finite.inc_col AS Int64)annotated_data_finite.inc_col AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING, SUM(CAST(annotated_data_finite.desc_col AS Int64)annotated_data_finite.desc_col AS annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING, SUM(CAST(annotated_data_finite.desc_col AS Int64)annotated_data_finite.desc_col AS annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, COUNT(Int64(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING, COUNT(Int64(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING]] -09)----------------Projection: CAST(annotated_data_finite.desc_col AS Int64) AS CAST(annotated_data_finite.desc_col AS Int64)annotated_data_finite.desc_col, CAST(annotated_data_finite.inc_col AS Int64) AS CAST(annotated_data_finite.inc_col AS Int64)annotated_data_finite.inc_col, annotated_data_finite.ts, annotated_data_finite.inc_col, annotated_data_finite.desc_col -10)------------------TableScan: annotated_data_finite projection=[ts, inc_col, desc_col] +02)--Sort: annotated_data_finite.inc_col DESC NULLS FIRST, fetch=5 +03)----Projection: sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS sum1, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS sum2, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING AS sum3, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS min1, min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS min2, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING AS min3, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS max1, max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS max2, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING AS max3, count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING AS cnt1, count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS cnt2, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING AS sumr1, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING AS sumr2, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING AS sumr3, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS minr1, min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS minr2, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING AS minr3, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS maxr1, max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING AS maxr2, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING AS maxr3, count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS cntr1, count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS cntr2, sum(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS sum4, count(*) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS cnt3, annotated_data_finite.inc_col +04)------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING, count(Int64(1)) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS count(*) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING]] +05)--------Projection: __common_expr_1, annotated_data_finite.inc_col, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING, count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING, count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING +06)----------WindowAggr: windowExpr=[[sum(__common_expr_2 AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, sum(__common_expr_1 AS annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, sum(__common_expr_2 AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, count(Int64(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING AS count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING, count(Int64(1)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING]] +07)------------WindowAggr: windowExpr=[[sum(__common_expr_2 AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING, sum(__common_expr_1 AS annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING, sum(__common_expr_1 AS annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, count(Int64(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING, count(Int64(1)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING AS count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING]] +08)--------------Projection: CAST(annotated_data_finite.desc_col AS Int64) AS __common_expr_1, CAST(annotated_data_finite.inc_col AS Int64) AS __common_expr_2, annotated_data_finite.ts, annotated_data_finite.inc_col, annotated_data_finite.desc_col +09)----------------TableScan: annotated_data_finite projection=[ts, inc_col, desc_col] physical_plan 01)ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, sum3@2 as sum3, min1@3 as min1, min2@4 as min2, min3@5 as min3, max1@6 as max1, max2@7 as max2, max3@8 as max3, cnt1@9 as cnt1, cnt2@10 as cnt2, sumr1@11 as sumr1, sumr2@12 as sumr2, sumr3@13 as sumr3, minr1@14 as minr1, minr2@15 as minr2, minr3@16 as minr3, maxr1@17 as maxr1, maxr2@18 as maxr2, maxr3@19 as maxr3, cntr1@20 as cntr1, cntr2@21 as cntr2, sum4@22 as sum4, cnt3@23 as cnt3] -02)--GlobalLimitExec: skip=0, fetch=5 -03)----SortExec: TopK(fetch=5), expr=[inc_col@24 DESC] -04)------ProjectionExec: expr=[SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@13 as sum1, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@14 as sum2, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@15 as sum3, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@16 as min1, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@17 as min2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as min3, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as max1, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@20 as max2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@21 as max3, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING@22 as cnt1, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@23 as cnt2, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING@2 as sumr1, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING@3 as sumr2, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as sumr3, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as minr1, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@6 as minr2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@7 as minr3, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@8 as maxr1, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@9 as maxr2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@10 as maxr3, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@11 as cntr1, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@12 as cntr2, SUM(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@24 as sum4, COUNT(*) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@25 as cnt3, inc_col@1 as inc_col] -05)--------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)), is_causal: false }, COUNT(*) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -06)----------ProjectionExec: expr=[CAST(annotated_data_finite.desc_col AS Int64)annotated_data_finite.desc_col@0 as CAST(annotated_data_finite.desc_col AS Int64)annotated_data_finite.desc_col, inc_col@3 as inc_col, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING@5 as SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING@6 as SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@7 as SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@8 as MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@9 as MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@10 as MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@11 as MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@12 as MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@13 as MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@14 as COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@15 as COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@16 as SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@17 as SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@20 as MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@21 as MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@22 as MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@23 as MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@24 as MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING@25 as COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@26 as COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING] -07)------------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)), is_causal: false }, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)), is_causal: false }, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)), is_causal: false }, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING: Ok(Field { name: "COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(4)), end_bound: Following(Int32(8)), is_causal: false }, COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -08)--------------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(4)), end_bound: Following(Int32(1)), is_causal: false }, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(8)), end_bound: Following(Int32(1)), is_causal: false }, SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)), is_causal: false }, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(5)), is_causal: false }, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(5)), is_causal: false }, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING: Ok(Field { name: "COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(2)), end_bound: Following(Int32(6)), is_causal: false }, COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(8)), is_causal: false }], mode=[Sorted] -09)----------------ProjectionExec: expr=[CAST(desc_col@2 AS Int64) as CAST(annotated_data_finite.desc_col AS Int64)annotated_data_finite.desc_col, CAST(inc_col@1 AS Int64) as CAST(annotated_data_finite.inc_col AS Int64)annotated_data_finite.inc_col, ts@0 as ts, inc_col@1 as inc_col, desc_col@2 as desc_col] -10)------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col, desc_col], output_ordering=[ts@0 ASC NULLS LAST], has_header=true +02)--SortExec: TopK(fetch=5), expr=[inc_col@24 DESC], preserve_partitioning=[false] +03)----ProjectionExec: expr=[sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@13 as sum1, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@14 as sum2, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@15 as sum3, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@16 as min1, min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@17 as min2, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as min3, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as max1, max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@20 as max2, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@21 as max3, count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING@22 as cnt1, count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@23 as cnt2, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING@2 as sumr1, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING@3 as sumr2, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@4 as sumr3, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as minr1, min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@6 as minr2, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@7 as minr3, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@8 as maxr1, max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@9 as maxr2, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@10 as maxr3, count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@11 as cntr1, count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@12 as cntr2, sum(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@24 as sum4, count(*) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@25 as cnt3, inc_col@1 as inc_col] +04)------BoundedWindowAggExec: wdw=[sum(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.desc_col) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)), is_causal: false }, count(*) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(*) ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +05)--------ProjectionExec: expr=[__common_expr_1@0 as __common_expr_1, inc_col@3 as inc_col, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING@5 as sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING@6 as sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@7 as sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@8 as min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@9 as min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@10 as min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@11 as max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@12 as max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@13 as max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@14 as count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING, count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@15 as count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@16 as sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@17 as sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@20 as min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@21 as min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@22 as max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING@23 as max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING@24 as max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING, count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING@25 as count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING, count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING@26 as count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING] +06)----------BoundedWindowAggExec: wdw=[sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)), is_causal: false }, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)), is_causal: false }, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(5)), end_bound: Following(Int32(1)), is_causal: false }, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING: Ok(Field { name: "count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 4 PRECEDING AND 8 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(4)), end_bound: Following(Int32(8)), is_causal: false }, count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(*) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(8)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +07)------------BoundedWindowAggExec: wdw=[sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 4 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(4)), end_bound: Following(Int32(1)), is_causal: false }, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 8 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(8)), end_bound: Following(Int32(1)), is_causal: false }, sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)), is_causal: false }, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "min(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(5)), is_causal: false }, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "max(annotated_data_finite.desc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 5 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(5)), is_causal: false }, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING: Ok(Field { name: "count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(2)), end_bound: Following(Int32(6)), is_causal: false }, count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(*) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 8 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(8)), is_causal: false }], mode=[Sorted] +08)--------------ProjectionExec: expr=[CAST(desc_col@2 AS Int64) as __common_expr_1, CAST(inc_col@1 AS Int64) as __common_expr_2, ts@0 as ts, inc_col@1 as inc_col, desc_col@2 as desc_col] +09)----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col, desc_col], output_ordering=[ts@0 ASC NULLS LAST], has_header=true query IIIIIIIIIIIIIIIIIIIIIIII SELECT @@ -2633,19 +2659,17 @@ EXPLAIN SELECT LIMIT 5; ---- logical_plan -01)Limit: skip=0, fetch=5 -02)--Sort: annotated_data_finite.ts DESC NULLS FIRST, fetch=5 -03)----Projection: annotated_data_finite.ts, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv2, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rn1, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rn2, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rank1, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rank2, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS dense_rank1, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS dense_rank2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lag1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lag2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lead1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lead2, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lagr1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lagr2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS leadr1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS leadr2 -04)------WindowAggr: windowExpr=[[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] -05)--------WindowAggr: windowExpr=[[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LAG(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, LEAD(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] -06)----------TableScan: annotated_data_finite projection=[ts, inc_col] +01)Sort: annotated_data_finite.ts DESC NULLS FIRST, fetch=5 +02)--Projection: annotated_data_finite.ts, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fv2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS nv2, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rn1, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rn2, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS rank1, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS rank2, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS dense_rank1, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS dense_rank2, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lag1, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lag2, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lead1, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lead2, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS fvr2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lvr2, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS lagr1, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS lagr2, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING AS leadr1, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING AS leadr2 +03)----WindowAggr: windowExpr=[[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, lag(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, lag(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, lead(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, lead(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] +04)------WindowAggr: windowExpr=[[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, lag(annotated_data_finite.inc_col, Int64(1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, lag(annotated_data_finite.inc_col, Int64(2), Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING, lead(annotated_data_finite.inc_col, Int64(-1), Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING, lead(annotated_data_finite.inc_col, Int64(4), Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING]] +05)--------TableScan: annotated_data_finite projection=[ts, inc_col] physical_plan -01)GlobalLimitExec: skip=0, fetch=5 -02)--SortExec: TopK(fetch=5), expr=[ts@0 DESC] -03)----ProjectionExec: expr=[ts@0 as ts, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@10 as fv1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@11 as fv2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@12 as lv1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@13 as lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@14 as nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@15 as nv2, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@16 as rn1, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@17 as rn2, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as rank1, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as rank2, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@20 as dense_rank1, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@21 as dense_rank2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@22 as lag1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@23 as lag2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@24 as lead1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@25 as lead2, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as fvr1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as fvr2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@4 as lvr1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as lvr2, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@6 as lagr1, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@7 as lagr2, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@8 as leadr1, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@9 as leadr2] -04)------BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "ROW_NUMBER() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "DENSE_RANK() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -05)--------BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAG(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LEAD(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }], mode=[Sorted] -06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], output_ordering=[ts@0 ASC NULLS LAST], has_header=true +01)SortExec: TopK(fetch=5), expr=[ts@0 DESC], preserve_partitioning=[false] +02)--ProjectionExec: expr=[ts@0 as ts, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@10 as fv1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@11 as fv2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@12 as lv1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@13 as lv2, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@14 as nv1, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@15 as nv2, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@16 as rn1, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@17 as rn2, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@18 as rank1, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@19 as rank2, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@20 as dense_rank1, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@21 as dense_rank2, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@22 as lag1, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@23 as lag2, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@24 as lead1, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@25 as lead2, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@2 as fvr1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@3 as fvr2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING@4 as lvr1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as lvr2, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@6 as lagr1, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@7 as lagr2, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING@8 as leadr1, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@9 as leadr2] +03)----BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(5)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "row_number() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "dense_rank() ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +04)------BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(1)), end_bound: Following(Int32(10)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "lag(annotated_data_finite.inc_col,Int64(1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lag(annotated_data_finite.inc_col,Int64(2),Int64(1002)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }, lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING: Ok(Field { name: "lead(annotated_data_finite.inc_col,Int64(-1),Int64(1001)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 1 PRECEDING AND 10 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: Following(Int32(1)), is_causal: false }, lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "lead(annotated_data_finite.inc_col,Int64(4),Int64(1004)) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(10)), is_causal: false }], mode=[Sorted] +05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], output_ordering=[ts@0 ASC NULLS LAST], has_header=true query IIIIIIIIIIIIIIIIIIIIIIIII SELECT @@ -2697,30 +2721,28 @@ EXPLAIN SELECT MAX(inc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING) as max2, COUNT(inc_col) OVER(ORDER BY ts ASC RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING) as count1, COUNT(inc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING) as count2, - AVG(inc_col) OVER(ORDER BY ts ASC RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING) as avg1, - AVG(inc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING) as avg2 + avg(inc_col) OVER(ORDER BY ts ASC RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING) as avg1, + avg(inc_col) OVER(ORDER BY ts DESC RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING) as avg2 FROM annotated_data_finite ORDER BY inc_col ASC LIMIT 5 ---- logical_plan 01)Projection: sum1, sum2, min1, min2, max1, max2, count1, count2, avg1, avg2 -02)--Limit: skip=0, fetch=5 -03)----Sort: annotated_data_finite.inc_col ASC NULLS LAST, fetch=5 -04)------Projection: SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING AS sum1, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS sum2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING AS min1, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS min2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING AS max1, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS max2, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING AS count1, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS count2, AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING AS avg1, AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS avg2, annotated_data_finite.inc_col -05)--------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_finite.inc_col AS Int64)annotated_data_finite.inc_col AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING, AVG(CAST(annotated_data_finite.inc_col AS Float64)annotated_data_finite.inc_col AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING]] -06)----------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_finite.inc_col AS Int64)annotated_data_finite.inc_col AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, AVG(CAST(annotated_data_finite.inc_col AS Float64)annotated_data_finite.inc_col AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING]] -07)------------Projection: CAST(annotated_data_finite.inc_col AS Float64) AS CAST(annotated_data_finite.inc_col AS Float64)annotated_data_finite.inc_col, CAST(annotated_data_finite.inc_col AS Int64) AS CAST(annotated_data_finite.inc_col AS Int64)annotated_data_finite.inc_col, annotated_data_finite.ts, annotated_data_finite.inc_col -08)--------------TableScan: annotated_data_finite projection=[ts, inc_col] +02)--Sort: annotated_data_finite.inc_col ASC NULLS LAST, fetch=5 +03)----Projection: sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING AS sum1, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS sum2, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING AS min1, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS min2, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING AS max1, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS max2, count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING AS count1, count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS count2, avg(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING AS avg1, avg(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS avg2, annotated_data_finite.inc_col +04)------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING, count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING, avg(__common_expr_2 AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING]] +05)--------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, avg(__common_expr_2 AS annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING]] +06)----------Projection: CAST(annotated_data_finite.inc_col AS Int64) AS __common_expr_1, CAST(annotated_data_finite.inc_col AS Float64) AS __common_expr_2, annotated_data_finite.ts, annotated_data_finite.inc_col +07)------------TableScan: annotated_data_finite projection=[ts, inc_col] physical_plan 01)ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, min1@2 as min1, min2@3 as min2, max1@4 as max1, max2@5 as max2, count1@6 as count1, count2@7 as count2, avg1@8 as avg1, avg2@9 as avg2] -02)--GlobalLimitExec: skip=0, fetch=5 -03)----SortExec: TopK(fetch=5), expr=[inc_col@10 ASC NULLS LAST] -04)------ProjectionExec: expr=[SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@9 as sum1, SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@4 as sum2, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@10 as min1, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@5 as min2, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@11 as max1, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@6 as max2, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@12 as count1, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@7 as count2, AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@13 as avg1, AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@8 as avg2, inc_col@3 as inc_col] -05)--------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)), is_causal: false }, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)), is_causal: false }, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)), is_causal: false }, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)), is_causal: false }, AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)), is_causal: false }], mode=[Sorted] -06)----------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)), is_causal: false }, MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "MIN(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)), is_causal: false }, MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "MAX(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)), is_causal: false }, COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "COUNT(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)), is_causal: false }, AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "AVG(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)), is_causal: false }], mode=[Sorted] -07)------------ProjectionExec: expr=[CAST(inc_col@1 AS Float64) as CAST(annotated_data_finite.inc_col AS Float64)annotated_data_finite.inc_col, CAST(inc_col@1 AS Int64) as CAST(annotated_data_finite.inc_col AS Int64)annotated_data_finite.inc_col, ts@0 as ts, inc_col@1 as inc_col] -08)--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], output_ordering=[ts@0 ASC NULLS LAST], has_header=true +02)--SortExec: TopK(fetch=5), expr=[inc_col@10 ASC NULLS LAST], preserve_partitioning=[false] +03)----ProjectionExec: expr=[sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@9 as sum1, sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@4 as sum2, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@10 as min1, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@5 as min2, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@11 as max1, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@6 as max2, count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@12 as count1, count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@7 as count2, avg(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING@13 as avg1, avg(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@8 as avg2, inc_col@3 as inc_col] +04)------BoundedWindowAggExec: wdw=[sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)), is_causal: false }, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)), is_causal: false }, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)), is_causal: false }, count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)), is_causal: false }, avg(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING: Ok(Field { name: "avg(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 5 FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(5)), is_causal: false }], mode=[Sorted] +05)--------BoundedWindowAggExec: wdw=[sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)), is_causal: false }, min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "min(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)), is_causal: false }, max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "max(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)), is_causal: false }, count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "count(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)), is_causal: false }, avg(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "avg(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] RANGE BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: Following(Int32(3)), is_causal: false }], mode=[Sorted] +06)----------ProjectionExec: expr=[CAST(inc_col@1 AS Int64) as __common_expr_1, CAST(inc_col@1 AS Float64) as __common_expr_2, ts@0 as ts, inc_col@1 as inc_col] +07)------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], output_ordering=[ts@0 ASC NULLS LAST], has_header=true query IIIIIIIIRR SELECT @@ -2760,20 +2782,18 @@ EXPLAIN SELECT ---- logical_plan 01)Projection: first_value1, first_value2, last_value1, last_value2, nth_value1 -02)--Limit: skip=0, fetch=5 -03)----Sort: annotated_data_finite.inc_col ASC NULLS LAST, fetch=5 -04)------Projection: FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS first_value1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS first_value2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS last_value1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS last_value2, NTH_VALUE(annotated_data_finite.inc_col,Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS nth_value1, annotated_data_finite.inc_col -05)--------WindowAggr: windowExpr=[[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING]] -06)----------WindowAggr: windowExpr=[[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING]] -07)------------TableScan: annotated_data_finite projection=[ts, inc_col] +02)--Sort: annotated_data_finite.inc_col ASC NULLS LAST, fetch=5 +03)----Projection: first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS first_value1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS first_value2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS last_value1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS last_value2, NTH_VALUE(annotated_data_finite.inc_col,Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS nth_value1, annotated_data_finite.inc_col +04)------WindowAggr: windowExpr=[[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING, NTH_VALUE(annotated_data_finite.inc_col, Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING]] +05)--------WindowAggr: windowExpr=[[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING]] +06)----------TableScan: annotated_data_finite projection=[ts, inc_col] physical_plan 01)ProjectionExec: expr=[first_value1@0 as first_value1, first_value2@1 as first_value2, last_value1@2 as last_value1, last_value2@3 as last_value2, nth_value1@4 as nth_value1] -02)--GlobalLimitExec: skip=0, fetch=5 -03)----SortExec: TopK(fetch=5), expr=[inc_col@5 ASC NULLS LAST] -04)------ProjectionExec: expr=[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@4 as first_value1, FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@2 as first_value2, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@5 as last_value1, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@3 as last_value2, NTH_VALUE(annotated_data_finite.inc_col,Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@6 as nth_value1, inc_col@1 as inc_col] -05)--------BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }, NTH_VALUE(annotated_data_finite.inc_col,Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -06)----------BoundedWindowAggExec: wdw=[FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "FIRST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }, LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "LAST_VALUE(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }], mode=[Sorted] -07)------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], output_ordering=[ts@0 ASC NULLS LAST], has_header=true +02)--SortExec: TopK(fetch=5), expr=[inc_col@5 ASC NULLS LAST], preserve_partitioning=[false] +03)----ProjectionExec: expr=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@4 as first_value1, first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@2 as first_value2, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@5 as last_value1, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@3 as last_value2, NTH_VALUE(annotated_data_finite.inc_col,Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@6 as nth_value1, inc_col@1 as inc_col] +04)------BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }, NTH_VALUE(annotated_data_finite.inc_col,Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "NTH_VALUE(annotated_data_finite.inc_col,Int64(2)) ORDER BY [annotated_data_finite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +05)--------BoundedWindowAggExec: wdw=[first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "first_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }, last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "last_value(annotated_data_finite.inc_col) ORDER BY [annotated_data_finite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }], mode=[Sorted] +06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_1.csv]]}, projection=[ts, inc_col], output_ordering=[ts@0 ASC NULLS LAST], has_header=true query IIIII SELECT @@ -2806,20 +2826,19 @@ EXPLAIN SELECT ---- logical_plan 01)Projection: sum1, sum2, count1, count2 -02)--Limit: skip=0, fetch=5 -03)----Sort: annotated_data_infinite.ts ASC NULLS LAST, fetch=5 -04)------Projection: SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS sum1, SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS sum2, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS count1, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS count2, annotated_data_infinite.ts -05)--------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite.inc_col AS Int64)annotated_data_infinite.inc_col AS annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING]] -06)----------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite.inc_col AS Int64)annotated_data_infinite.inc_col AS annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING]] -07)------------Projection: CAST(annotated_data_infinite.inc_col AS Int64) AS CAST(annotated_data_infinite.inc_col AS Int64)annotated_data_infinite.inc_col, annotated_data_infinite.ts, annotated_data_infinite.inc_col -08)--------------TableScan: annotated_data_infinite projection=[ts, inc_col] +02)--Sort: annotated_data_infinite.ts ASC NULLS LAST, fetch=5 +03)----Projection: sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS sum1, sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS sum2, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS count1, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS count2, annotated_data_infinite.ts +04)------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING]] +05)--------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING]] +06)----------Projection: CAST(annotated_data_infinite.inc_col AS Int64) AS __common_expr_1, annotated_data_infinite.ts, annotated_data_infinite.inc_col +07)------------TableScan: annotated_data_infinite projection=[ts, inc_col] physical_plan 01)ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, count1@2 as count1, count2@3 as count2] -02)--GlobalLimitExec: skip=0, fetch=5 -03)----ProjectionExec: expr=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@5 as sum1, SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@3 as sum2, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@6 as count1, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@4 as count2, ts@1 as ts] -04)------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -05)--------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }], mode=[Sorted] -06)----------ProjectionExec: expr=[CAST(inc_col@1 AS Int64) as CAST(annotated_data_infinite.inc_col AS Int64)annotated_data_infinite.inc_col, ts@0 as ts, inc_col@1 as inc_col] +02)--ProjectionExec: expr=[sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@5 as sum1, sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@3 as sum2, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@6 as count1, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@4 as count2, ts@1 as ts] +03)----GlobalLimitExec: skip=0, fetch=5 +04)------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +05)--------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }], mode=[Sorted] +06)----------ProjectionExec: expr=[CAST(inc_col@1 AS Int64) as __common_expr_1, ts@0 as ts, inc_col@1 as inc_col] 07)------------StreamingTableExec: partition_sizes=1, projection=[ts, inc_col], infinite_source=true, output_ordering=[ts@0 ASC NULLS LAST] query IIII @@ -2853,20 +2872,19 @@ EXPLAIN SELECT ---- logical_plan 01)Projection: sum1, sum2, count1, count2 -02)--Limit: skip=0, fetch=5 -03)----Sort: annotated_data_infinite.ts ASC NULLS LAST, fetch=5 -04)------Projection: SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS sum1, SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS sum2, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS count1, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS count2, annotated_data_infinite.ts -05)--------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite.inc_col AS Int64)annotated_data_infinite.inc_col AS annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING]] -06)----------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite.inc_col AS Int64)annotated_data_infinite.inc_col AS annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING]] -07)------------Projection: CAST(annotated_data_infinite.inc_col AS Int64) AS CAST(annotated_data_infinite.inc_col AS Int64)annotated_data_infinite.inc_col, annotated_data_infinite.ts, annotated_data_infinite.inc_col -08)--------------TableScan: annotated_data_infinite projection=[ts, inc_col] +02)--Sort: annotated_data_infinite.ts ASC NULLS LAST, fetch=5 +03)----Projection: sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS sum1, sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS sum2, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING AS count1, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING AS count2, annotated_data_infinite.ts +04)------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING]] +05)--------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING]] +06)----------Projection: CAST(annotated_data_infinite.inc_col AS Int64) AS __common_expr_1, annotated_data_infinite.ts, annotated_data_infinite.inc_col +07)------------TableScan: annotated_data_infinite projection=[ts, inc_col] physical_plan 01)ProjectionExec: expr=[sum1@0 as sum1, sum2@1 as sum2, count1@2 as count1, count2@3 as count2] -02)--GlobalLimitExec: skip=0, fetch=5 -03)----ProjectionExec: expr=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@5 as sum1, SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@3 as sum2, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@6 as count1, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@4 as count2, ts@1 as ts] -04)------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -05)--------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }, COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "COUNT(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }], mode=[Sorted] -06)----------ProjectionExec: expr=[CAST(inc_col@1 AS Int64) as CAST(annotated_data_infinite.inc_col AS Int64)annotated_data_infinite.inc_col, ts@0 as ts, inc_col@1 as inc_col] +02)--ProjectionExec: expr=[sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@5 as sum1, sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@3 as sum2, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING@6 as count1, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING@4 as count2, ts@1 as ts] +03)----GlobalLimitExec: skip=0, fetch=5 +04)------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING: Ok(Field { name: "count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +05)--------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }, count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "count(annotated_data_infinite.inc_col) ORDER BY [annotated_data_infinite.ts DESC NULLS FIRST] ROWS BETWEEN 3 PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(3)), is_causal: false }], mode=[Sorted] +06)----------ProjectionExec: expr=[CAST(inc_col@1 AS Int64) as __common_expr_1, ts@0 as ts, inc_col@1 as inc_col] 07)------------StreamingTableExec: partition_sizes=1, projection=[ts, inc_col], infinite_source=true, output_ordering=[ts@0 ASC NULLS LAST] @@ -2910,9 +2928,9 @@ CREATE EXTERNAL TABLE annotated_data_finite2 ( d INTEGER ) STORED AS CSV -WITH HEADER ROW WITH ORDER (a ASC, b ASC, c ASC) -LOCATION '../core/tests/data/window_2.csv'; +LOCATION '../core/tests/data/window_2.csv' +OPTIONS ('format.has_header' 'true'); # Columns in the table are a,b,c,d. Source is CsvExec which is ordered by # a,b,c column. Column a has cardinality 2, column b has cardinality 4. @@ -2926,9 +2944,9 @@ CREATE UNBOUNDED EXTERNAL TABLE annotated_data_infinite2 ( d INTEGER ) STORED AS CSV -WITH HEADER ROW WITH ORDER (a ASC, b ASC, c ASC) -LOCATION '../core/tests/data/window_2.csv'; +LOCATION '../core/tests/data/window_2.csv' +OPTIONS ('format.has_header' 'true'); # test_infinite_source_partition_by @@ -2951,26 +2969,26 @@ EXPLAIN SELECT a, b, c, LIMIT 5 ---- logical_plan -01)Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.c, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum1, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING AS sum2, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum3, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING AS sum4, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum5, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING AS sum6, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum7, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING AS sum8, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum9, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW AS sum10, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum11, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING AS sum12 +01)Projection: annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.c, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum1, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING AS sum2, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum3, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING AS sum4, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum5, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING AS sum6, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum7, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING AS sum8, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum9, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW AS sum10, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum11, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING AS sum12 02)--Limit: skip=0, fetch=5 -03)----WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING]] -04)------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING]] -05)--------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING]] -06)----------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING]] -07)------------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW]] -08)--------------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING]] -09)----------------Projection: CAST(annotated_data_infinite2.c AS Int64) AS CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c, annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.c, annotated_data_infinite2.d +03)----WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, sum(__common_expr_1 AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING]] +04)------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, sum(__common_expr_1 AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING]] +05)--------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, sum(__common_expr_1 AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING]] +06)----------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, sum(__common_expr_1 AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING]] +07)------------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, sum(__common_expr_1 AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW]] +08)--------------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, sum(__common_expr_1 AS annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING]] +09)----------------Projection: CAST(annotated_data_infinite2.c AS Int64) AS __common_expr_1, annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.c, annotated_data_infinite2.d 10)------------------TableScan: annotated_data_infinite2 projection=[a, b, c, d] physical_plan -01)ProjectionExec: expr=[a@1 as a, b@2 as b, c@3 as c, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@9 as sum1, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING@10 as sum2, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@15 as sum3, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING@16 as sum4, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@5 as sum5, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING@6 as sum6, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@11 as sum7, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING@12 as sum8, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@7 as sum9, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW@8 as sum10, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@13 as sum11, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING@14 as sum12] +01)ProjectionExec: expr=[a@1 as a, b@2 as b, c@3 as c, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@9 as sum1, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING@10 as sum2, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@15 as sum3, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING@16 as sum4, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@5 as sum5, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING@6 as sum6, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@11 as sum7, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING@12 as sum8, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@7 as sum9, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW@8 as sum10, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@13 as sum11, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING@14 as sum12] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Preceding(UInt64(1)), is_causal: true }], mode=[Linear] -04)------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(1)), is_causal: false }], mode=[PartiallySorted([1, 0])] -05)--------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] -06)----------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Following(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[PartiallySorted([0])] -07)------------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: CurrentRow, is_causal: true }], mode=[PartiallySorted([0, 1])] -08)--------------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] -09)----------------ProjectionExec: expr=[CAST(c@2 AS Int64) as CAST(annotated_data_infinite2.c AS Int64)annotated_data_infinite2.c, a@0 as a, b@1 as b, c@2 as c, d@3 as d] +03)----BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING: Ok(Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST, annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Preceding(UInt64(1)), is_causal: true }], mode=[Linear] +04)------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(1)), is_causal: false }], mode=[PartiallySorted([1, 0])] +05)--------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +06)----------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST, annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Following(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[PartiallySorted([0])] +07)------------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: CurrentRow, is_causal: true }], mode=[PartiallySorted([0, 1])] +08)--------------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(annotated_data_infinite2.c) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +09)----------------ProjectionExec: expr=[CAST(c@2 AS Int64) as __common_expr_1, a@0 as a, b@1 as b, c@2 as c, d@3 as d] 10)------------------StreamingTableExec: partition_sizes=1, projection=[a, b, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] query IIIIIIIIIIIIIII @@ -3019,34 +3037,32 @@ EXPLAIN SELECT a, b, c, LIMIT 5 ---- logical_plan -01)Limit: skip=0, fetch=5 -02)--Sort: annotated_data_finite2.c ASC NULLS LAST, fetch=5 -03)----Projection: annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.c, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum1, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING AS sum2, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum3, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING AS sum4, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum5, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING AS sum6, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum7, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING AS sum8, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum9, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW AS sum10, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum11, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING AS sum12 -04)------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_finite2.c AS Int64)annotated_data_finite2.c AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(CAST(annotated_data_finite2.c AS Int64)annotated_data_finite2.c AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING]] -05)--------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_finite2.c AS Int64)annotated_data_finite2.c AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(CAST(annotated_data_finite2.c AS Int64)annotated_data_finite2.c AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING]] -06)----------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_finite2.c AS Int64)annotated_data_finite2.c AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(CAST(annotated_data_finite2.c AS Int64)annotated_data_finite2.c AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING]] -07)------------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_finite2.c AS Int64)annotated_data_finite2.c AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(CAST(annotated_data_finite2.c AS Int64)annotated_data_finite2.c AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING]] -08)--------------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_finite2.c AS Int64)annotated_data_finite2.c AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(CAST(annotated_data_finite2.c AS Int64)annotated_data_finite2.c AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW]] -09)----------------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_finite2.c AS Int64)annotated_data_finite2.c AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, SUM(CAST(annotated_data_finite2.c AS Int64)annotated_data_finite2.c AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING]] -10)------------------Projection: CAST(annotated_data_finite2.c AS Int64) AS CAST(annotated_data_finite2.c AS Int64)annotated_data_finite2.c, annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.c, annotated_data_finite2.d -11)--------------------TableScan: annotated_data_finite2 projection=[a, b, c, d] +01)Sort: annotated_data_finite2.c ASC NULLS LAST, fetch=5 +02)--Projection: annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.c, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum1, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING AS sum2, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum3, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING AS sum4, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum5, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING AS sum6, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum7, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING AS sum8, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum9, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW AS sum10, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING AS sum11, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING AS sum12 +03)----WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, sum(__common_expr_1 AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING]] +04)------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, sum(__common_expr_1 AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING]] +05)--------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, sum(__common_expr_1 AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING]] +06)----------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, sum(__common_expr_1 AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING]] +07)------------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, sum(__common_expr_1 AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW]] +08)--------------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING, sum(__common_expr_1 AS annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING]] +09)----------------Projection: CAST(annotated_data_finite2.c AS Int64) AS __common_expr_1, annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.c, annotated_data_finite2.d +10)------------------TableScan: annotated_data_finite2 projection=[a, b, c, d] physical_plan -01)GlobalLimitExec: skip=0, fetch=5 -02)--SortExec: TopK(fetch=5), expr=[c@2 ASC NULLS LAST] -03)----ProjectionExec: expr=[a@1 as a, b@2 as b, c@3 as c, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@9 as sum1, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING@10 as sum2, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@15 as sum3, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING@16 as sum4, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@5 as sum5, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING@6 as sum6, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@11 as sum7, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING@12 as sum8, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@7 as sum9, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW@8 as sum10, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@13 as sum11, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING@14 as sum12] -04)------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING: Ok(Field { name: "SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Preceding(UInt64(1)), is_causal: true }], mode=[Sorted] -05)--------SortExec: expr=[d@4 ASC NULLS LAST,a@1 ASC NULLS LAST,b@2 ASC NULLS LAST,c@3 ASC NULLS LAST] -06)----------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] -07)------------SortExec: expr=[b@2 ASC NULLS LAST,a@1 ASC NULLS LAST,d@4 ASC NULLS LAST,c@3 ASC NULLS LAST] -08)--------------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] -09)----------------SortExec: expr=[b@2 ASC NULLS LAST,a@1 ASC NULLS LAST,c@3 ASC NULLS LAST] -10)------------------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Following(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] -11)--------------------SortExec: expr=[a@1 ASC NULLS LAST,d@4 ASC NULLS LAST,b@2 ASC NULLS LAST,c@3 ASC NULLS LAST] -12)----------------------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted] -13)------------------------SortExec: expr=[a@1 ASC NULLS LAST,b@2 ASC NULLS LAST,d@4 ASC NULLS LAST,c@3 ASC NULLS LAST] -14)--------------------------BoundedWindowAggExec: wdw=[SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "SUM(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] -15)----------------------------ProjectionExec: expr=[CAST(c@2 AS Int64) as CAST(annotated_data_finite2.c AS Int64)annotated_data_finite2.c, a@0 as a, b@1 as b, c@2 as c, d@3 as d] -16)------------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true +01)SortExec: TopK(fetch=5), expr=[c@2 ASC NULLS LAST], preserve_partitioning=[false] +02)--ProjectionExec: expr=[a@1 as a, b@2 as b, c@3 as c, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@9 as sum1, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING@10 as sum2, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@15 as sum3, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING@16 as sum4, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@5 as sum5, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING@6 as sum6, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@11 as sum7, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING@12 as sum8, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@7 as sum9, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW@8 as sum10, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING@13 as sum11, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING@14 as sum12] +03)----BoundedWindowAggExec: wdw=[sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING: Ok(Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.d] ORDER BY [annotated_data_finite2.a ASC NULLS LAST, annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 1 PRECEDING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Preceding(UInt64(1)), is_causal: true }], mode=[Sorted] +04)------SortExec: expr=[d@4 ASC NULLS LAST, a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], preserve_partitioning=[false] +05)--------BoundedWindowAggExec: wdw=[sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN CURRENT ROW AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: CurrentRow, end_bound: Following(UInt64(1)), is_causal: false }], mode=[Sorted] +06)----------SortExec: expr=[b@2 ASC NULLS LAST, a@1 ASC NULLS LAST, d@4 ASC NULLS LAST, c@3 ASC NULLS LAST], preserve_partitioning=[false] +07)------------BoundedWindowAggExec: wdw=[sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.b, annotated_data_finite2.a] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +08)--------------SortExec: expr=[b@2 ASC NULLS LAST, a@1 ASC NULLS LAST, c@3 ASC NULLS LAST], preserve_partitioning=[false] +09)----------------BoundedWindowAggExec: wdw=[sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.b ASC NULLS LAST, annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 1 FOLLOWING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Following(UInt64(1)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +10)------------------SortExec: expr=[a@1 ASC NULLS LAST, d@4 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], preserve_partitioning=[false] +11)--------------------BoundedWindowAggExec: wdw=[sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b, annotated_data_finite2.d] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted] +12)----------------------SortExec: expr=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, d@4 ASC NULLS LAST, c@3 ASC NULLS LAST], preserve_partitioning=[false] +13)------------------------BoundedWindowAggExec: wdw=[sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(2)), end_bound: Following(UInt64(1)), is_causal: false }, sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING: Ok(Field { name: "sum(annotated_data_finite2.c) PARTITION BY [annotated_data_finite2.a, annotated_data_finite2.b] ORDER BY [annotated_data_finite2.c ASC NULLS LAST] ROWS BETWEEN 5 PRECEDING AND 5 FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(5)), is_causal: false }], mode=[Sorted] +14)--------------------------ProjectionExec: expr=[CAST(c@2 AS Int64) as __common_expr_1, a@0 as a, b@1 as b, c@2 as c, d@3 as d] +15)----------------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST], has_header=true query IIIIIIIIIIIIIII SELECT a, b, c, @@ -3099,19 +3115,17 @@ EXPLAIN SELECT * FROM (SELECT *, ROW_NUMBER() OVER(ORDER BY a ASC) as rn1 ---- logical_plan 01)Sort: rn1 ASC NULLS LAST -02)--Filter: rn1 < UInt64(50) -03)----Limit: skip=0, fetch=5 -04)------Sort: rn1 ASC NULLS LAST, fetch=5 -05)--------Projection: annotated_data_infinite2.a0, annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.c, annotated_data_infinite2.d, ROW_NUMBER() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rn1 -06)----------WindowAggr: windowExpr=[[ROW_NUMBER() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -07)------------TableScan: annotated_data_infinite2 projection=[a0, a, b, c, d] +02)--Sort: rn1 ASC NULLS LAST, fetch=5 +03)----Projection: annotated_data_infinite2.a0, annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.c, annotated_data_infinite2.d, row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS rn1 +04)------Filter: row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW < UInt64(50) +05)--------WindowAggr: windowExpr=[[row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +06)----------TableScan: annotated_data_infinite2 projection=[a0, a, b, c, d] physical_plan -01)CoalesceBatchesExec: target_batch_size=4096 -02)--FilterExec: rn1@5 < 50 -03)----GlobalLimitExec: skip=0, fetch=5 -04)------ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, ROW_NUMBER() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as rn1] -05)--------BoundedWindowAggExec: wdw=[ROW_NUMBER() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "ROW_NUMBER() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -06)----------StreamingTableExec: partition_sizes=1, projection=[a0, a, b, c, d], infinite_source=true, output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST] +01)ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as rn1] +02)--CoalesceBatchesExec: target_batch_size=4096, fetch=5 +03)----FilterExec: row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 < 50 +04)------BoundedWindowAggExec: wdw=[row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "row_number() ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +05)--------StreamingTableExec: partition_sizes=1, projection=[a0, a, b, c, d], infinite_source=true, output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST] # this is a negative test for asserting that window functions (other than ROW_NUMBER) # are not added to ordering equivalence @@ -3125,19 +3139,16 @@ EXPLAIN SELECT c9, sum1 FROM (SELECT c9, LIMIT 5 ---- logical_plan -01)Limit: skip=0, fetch=5 -02)--Sort: sum1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST, fetch=5 -03)----Sort: aggregate_test_100.c9 DESC NULLS FIRST -04)------Projection: aggregate_test_100.c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum1 -05)--------WindowAggr: windowExpr=[[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -06)----------TableScan: aggregate_test_100 projection=[c9] +01)Sort: sum1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST, fetch=5 +02)--Projection: aggregate_test_100.c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum1 +03)----WindowAggr: windowExpr=[[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------TableScan: aggregate_test_100 projection=[c9] physical_plan -01)GlobalLimitExec: skip=0, fetch=5 -02)--SortExec: TopK(fetch=5), expr=[sum1@1 ASC NULLS LAST,c9@0 DESC] -03)----ProjectionExec: expr=[c9@0 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as sum1] -04)------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -05)--------SortExec: expr=[c9@0 DESC] -06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], has_header=true +01)SortExec: TopK(fetch=5), expr=[sum1@1 ASC NULLS LAST, c9@0 DESC], preserve_partitioning=[false] +02)--ProjectionExec: expr=[c9@0 as c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as sum1] +03)----BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +04)------SortExec: expr=[c9@0 DESC], preserve_partitioning=[false] +05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c9], has_header=true # Query below should work when its input is unbounded # because ordering of ROW_NUMBER, RANK result is added to the ordering equivalence @@ -3160,8 +3171,8 @@ SELECT a, d, rn1, rank1 FROM (SELECT a, d, # this is a negative test for asserting that ROW_NUMBER is not # added to the ordering equivalence when it contains partition by. # physical plan should contain SortExec. Since source is unbounded -# pipeline checker should raise error, when plan contains SortExec. -statement error DataFusion error: PipelineChecker +# sanity checker should raise error, when plan contains SortExec. +statement error DataFusion error: SanityCheckPlan SELECT a, d, rn1 FROM (SELECT a, d, ROW_NUMBER() OVER(PARTITION BY d ORDER BY a ASC) as rn1 FROM annotated_data_infinite2 @@ -3210,22 +3221,22 @@ SUM(a) OVER(partition by b, a order by c) as sum2, FROM annotated_data_infinite2; ---- logical_plan -01)Projection: SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum1, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum2, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum3, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum4 -02)--WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.a AS Int64)annotated_data_infinite2.a AS annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -03)----Projection: CAST(annotated_data_infinite2.a AS Int64)annotated_data_infinite2.a, annotated_data_infinite2.a, annotated_data_infinite2.d, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW -04)------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.a AS Int64)annotated_data_infinite2.a AS annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -05)--------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.a AS Int64)annotated_data_infinite2.a AS annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -06)----------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.a AS Int64)annotated_data_infinite2.a AS annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -07)------------Projection: CAST(annotated_data_infinite2.a AS Int64) AS CAST(annotated_data_infinite2.a AS Int64)annotated_data_infinite2.a, annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.c, annotated_data_infinite2.d +01)Projection: sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum1, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum2, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum3, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum4 +02)--WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +03)----Projection: __common_expr_1, annotated_data_infinite2.a, annotated_data_infinite2.d, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +04)------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +05)--------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +06)----------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +07)------------Projection: CAST(annotated_data_infinite2.a AS Int64) AS __common_expr_1, annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.c, annotated_data_infinite2.d 08)--------------TableScan: annotated_data_infinite2 projection=[a, b, c, d] physical_plan -01)ProjectionExec: expr=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum1, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as sum2, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum3, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as sum4] -02)--BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear] -03)----ProjectionExec: expr=[CAST(annotated_data_infinite2.a AS Int64)annotated_data_infinite2.a@0 as CAST(annotated_data_infinite2.a AS Int64)annotated_data_infinite2.a, a@1 as a, d@4 as d, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@7 as SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] -04)------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -05)--------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[PartiallySorted([0])] -06)----------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -07)------------ProjectionExec: expr=[CAST(a@0 AS Int64) as CAST(annotated_data_infinite2.a AS Int64)annotated_data_infinite2.a, a@0 as a, b@1 as b, c@2 as c, d@3 as d] +01)ProjectionExec: expr=[sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum1, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as sum2, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum3, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as sum4] +02)--BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear] +03)----ProjectionExec: expr=[__common_expr_1@0 as __common_expr_1, a@1 as a, d@4 as d, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@7 as sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +04)------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +05)--------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[PartiallySorted([0])] +06)----------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +07)------------ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1, a@0 as a, b@1 as b, c@2 as c, d@3 as d] 08)--------------StreamingTableExec: partition_sizes=1, projection=[a, b, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] statement ok @@ -3241,30 +3252,30 @@ EXPLAIN SELECT SUM(a) OVER(partition by a, b order by c) as sum1, FROM annotated_data_infinite2; ---- logical_plan -01)Projection: SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum1, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum2, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum3, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum4 -02)--WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.a AS Int64)annotated_data_infinite2.a AS annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -03)----Projection: CAST(annotated_data_infinite2.a AS Int64)annotated_data_infinite2.a, annotated_data_infinite2.a, annotated_data_infinite2.d, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW -04)------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.a AS Int64)annotated_data_infinite2.a AS annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -05)--------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.a AS Int64)annotated_data_infinite2.a AS annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -06)----------WindowAggr: windowExpr=[[SUM(CAST(annotated_data_infinite2.a AS Int64)annotated_data_infinite2.a AS annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -07)------------Projection: CAST(annotated_data_infinite2.a AS Int64) AS CAST(annotated_data_infinite2.a AS Int64)annotated_data_infinite2.a, annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.c, annotated_data_infinite2.d +01)Projection: sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum1, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum2, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum3, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum4 +02)--WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +03)----Projection: __common_expr_1, annotated_data_infinite2.a, annotated_data_infinite2.d, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +04)------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +05)--------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +06)----------WindowAggr: windowExpr=[[sum(__common_expr_1 AS annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +07)------------Projection: CAST(annotated_data_infinite2.a AS Int64) AS __common_expr_1, annotated_data_infinite2.a, annotated_data_infinite2.b, annotated_data_infinite2.c, annotated_data_infinite2.d 08)--------------TableScan: annotated_data_infinite2 projection=[a, b, c, d] physical_plan -01)ProjectionExec: expr=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum1, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as sum2, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum3, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as sum4] -02)--BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear] +01)ProjectionExec: expr=[sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum1, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as sum2, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum3, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as sum4] +02)--BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Linear] 03)----CoalesceBatchesExec: target_batch_size=4096 -04)------RepartitionExec: partitioning=Hash([d@2], 2), input_partitions=2, preserve_order=true, sort_exprs=CAST(annotated_data_infinite2.a AS Int64)annotated_data_infinite2.a@0 ASC NULLS LAST,a@1 ASC NULLS LAST -05)--------ProjectionExec: expr=[CAST(annotated_data_infinite2.a AS Int64)annotated_data_infinite2.a@0 as CAST(annotated_data_infinite2.a AS Int64)annotated_data_infinite2.a, a@1 as a, d@4 as d, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@7 as SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] -06)----------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +04)------RepartitionExec: partitioning=Hash([d@2], 2), input_partitions=2, preserve_order=true, sort_exprs=__common_expr_1@0 ASC NULLS LAST, a@1 ASC NULLS LAST +05)--------ProjectionExec: expr=[__common_expr_1@0 as __common_expr_1, a@1 as a, d@4 as d, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@7 as sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +06)----------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.b, annotated_data_infinite2.a] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] 07)------------CoalesceBatchesExec: target_batch_size=4096 -08)--------------RepartitionExec: partitioning=Hash([b@2, a@1], 2), input_partitions=2, preserve_order=true, sort_exprs=a@1 ASC NULLS LAST,b@2 ASC NULLS LAST,c@3 ASC NULLS LAST,CAST(annotated_data_infinite2.a AS Int64)annotated_data_infinite2.a@0 ASC NULLS LAST -09)----------------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[PartiallySorted([0])] +08)--------------RepartitionExec: partitioning=Hash([b@2, a@1], 2), input_partitions=2, preserve_order=true, sort_exprs=a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST, __common_expr_1@0 ASC NULLS LAST +09)----------------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.d] ORDER BY [annotated_data_infinite2.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[PartiallySorted([0])] 10)------------------CoalesceBatchesExec: target_batch_size=4096 -11)--------------------RepartitionExec: partitioning=Hash([a@1, d@4], 2), input_partitions=2, preserve_order=true, sort_exprs=a@1 ASC NULLS LAST,b@2 ASC NULLS LAST,c@3 ASC NULLS LAST,CAST(annotated_data_infinite2.a AS Int64)annotated_data_infinite2.a@0 ASC NULLS LAST -12)----------------------BoundedWindowAggExec: wdw=[SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +11)--------------------RepartitionExec: partitioning=Hash([a@1, d@4], 2), input_partitions=2, preserve_order=true, sort_exprs=a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST, __common_expr_1@0 ASC NULLS LAST +12)----------------------BoundedWindowAggExec: wdw=[sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(annotated_data_infinite2.a) PARTITION BY [annotated_data_infinite2.a, annotated_data_infinite2.b] ORDER BY [annotated_data_infinite2.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] 13)------------------------CoalesceBatchesExec: target_batch_size=4096 -14)--------------------------RepartitionExec: partitioning=Hash([a@1, b@2], 2), input_partitions=2, preserve_order=true, sort_exprs=a@1 ASC NULLS LAST,b@2 ASC NULLS LAST,c@3 ASC NULLS LAST,CAST(annotated_data_infinite2.a AS Int64)annotated_data_infinite2.a@0 ASC NULLS LAST -15)----------------------------ProjectionExec: expr=[CAST(a@0 AS Int64) as CAST(annotated_data_infinite2.a AS Int64)annotated_data_infinite2.a, a@0 as a, b@1 as b, c@2 as c, d@3 as d] +14)--------------------------RepartitionExec: partitioning=Hash([a@1, b@2], 2), input_partitions=2, preserve_order=true, sort_exprs=a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST, __common_expr_1@0 ASC NULLS LAST +15)----------------------------ProjectionExec: expr=[CAST(a@0 AS Int64) as __common_expr_1, a@0 as a, b@1 as b, c@2 as c, d@3 as d] 16)------------------------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 17)--------------------------------StreamingTableExec: partition_sizes=1, projection=[a, b, c, d], infinite_source=true, output_ordering=[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST, c@2 ASC NULLS LAST] @@ -3311,23 +3322,21 @@ EXPLAIN SELECT LIMIT 5 ---- logical_plan -01)Limit: skip=0, fetch=5 -02)--Sort: aggregate_test_100.c3 ASC NULLS LAST, fetch=5 -03)----Projection: aggregate_test_100.c3, MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS min1, MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS max1 -04)------WindowAggr: windowExpr=[[MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -05)--------Projection: aggregate_test_100.c3, aggregate_test_100.c12, MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING -06)----------WindowAggr: windowExpr=[[MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] -07)------------TableScan: aggregate_test_100 projection=[c3, c11, c12] +01)Sort: aggregate_test_100.c3 ASC NULLS LAST, fetch=5 +02)--Projection: aggregate_test_100.c3, max(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS min1, min(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS max1 +03)----WindowAggr: windowExpr=[[max(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------Projection: aggregate_test_100.c3, aggregate_test_100.c12, min(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING +05)--------WindowAggr: windowExpr=[[min(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +06)----------TableScan: aggregate_test_100 projection=[c3, c11, c12] physical_plan -01)GlobalLimitExec: skip=0, fetch=5 -02)--SortExec: TopK(fetch=5), expr=[c3@0 ASC NULLS LAST] -03)----ProjectionExec: expr=[c3@0 as c3, MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as min1, MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@2 as max1] -04)------BoundedWindowAggExec: wdw=[MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Float64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -05)--------SortExec: expr=[c12@1 ASC NULLS LAST] -06)----------ProjectionExec: expr=[c3@0 as c3, c12@2 as c12, MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@3 as MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING] -07)------------WindowAggExec: wdw=[MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "MIN(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] -08)--------------SortExec: expr=[c11@1 ASC NULLS LAST] -09)----------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3, c11, c12], has_header=true +01)SortExec: TopK(fetch=5), expr=[c3@0 ASC NULLS LAST], preserve_partitioning=[false] +02)--ProjectionExec: expr=[c3@0 as c3, max(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as min1, min(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@2 as max1] +03)----BoundedWindowAggExec: wdw=[max(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "max(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Float64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +04)------SortExec: expr=[c12@1 ASC NULLS LAST], preserve_partitioning=[false] +05)--------ProjectionExec: expr=[c3@0 as c3, c12@2 as c12, min(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@3 as min(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING] +06)----------WindowAggExec: wdw=[min(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "min(aggregate_test_100.c12) PARTITION BY [aggregate_test_100.c11] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] +07)------------SortExec: expr=[c11@1 ASC NULLS LAST], preserve_partitioning=[false] +08)--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3, c11, c12], has_header=true # window1 spec is used multiple times under different aggregations. # The query should still work. @@ -3358,19 +3367,17 @@ EXPLAIN SELECT ---- logical_plan 01)Projection: min1, max1 -02)--Limit: skip=0, fetch=5 -03)----Sort: aggregate_test_100.c3 ASC NULLS LAST, fetch=5 -04)------Projection: MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS min1, MIN(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS max1, aggregate_test_100.c3 -05)--------WindowAggr: windowExpr=[[MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, MIN(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -06)----------TableScan: aggregate_test_100 projection=[c3, c12] +02)--Sort: aggregate_test_100.c3 ASC NULLS LAST, fetch=5 +03)----Projection: max(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS min1, min(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS max1, aggregate_test_100.c3 +04)------WindowAggr: windowExpr=[[max(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, min(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +05)--------TableScan: aggregate_test_100 projection=[c3, c12] physical_plan 01)ProjectionExec: expr=[min1@0 as min1, max1@1 as max1] -02)--GlobalLimitExec: skip=0, fetch=5 -03)----SortExec: TopK(fetch=5), expr=[c3@2 ASC NULLS LAST] -04)------ProjectionExec: expr=[MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as min1, MIN(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as max1, c3@0 as c3] -05)--------BoundedWindowAggExec: wdw=[MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MAX(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Float64(NULL)), end_bound: CurrentRow, is_causal: false }, MIN(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MIN(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Float64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -06)----------SortExec: expr=[c12@1 ASC NULLS LAST] -07)------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3, c12], has_header=true +02)--SortExec: TopK(fetch=5), expr=[c3@2 ASC NULLS LAST], preserve_partitioning=[false] +03)----ProjectionExec: expr=[max(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as min1, min(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as max1, c3@0 as c3] +04)------BoundedWindowAggExec: wdw=[max(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "max(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Float64(NULL)), end_bound: CurrentRow, is_causal: false }, min(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "min(aggregate_test_100.c12) ORDER BY [aggregate_test_100.c12 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Float64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +05)--------SortExec: expr=[c12@1 ASC NULLS LAST], preserve_partitioning=[false] +06)----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3, c12], has_header=true # window2 spec is not defined statement error DataFusion error: Error during planning: The window window2 is not defined! @@ -3385,7 +3392,7 @@ SELECT # window1 spec is defined multiple times statement error DataFusion error: Error during planning: The window window1 is defined multiple times! SELECT - MAX(c12) OVER window1 as min1, + MAX(c12) OVER window1 as min1 FROM aggregate_test_100 WINDOW window1 AS (ORDER BY C12), window1 AS (ORDER BY C3) @@ -3402,10 +3409,10 @@ CREATE EXTERNAL TABLE multiple_ordered_table ( d INTEGER ) STORED AS CSV -WITH HEADER ROW WITH ORDER (a ASC, b ASC) WITH ORDER (c ASC) -LOCATION '../core/tests/data/window_2.csv'; +LOCATION '../core/tests/data/window_2.csv' +OPTIONS ('format.has_header' 'true'); # Since column b is constant after filter b=0, # There should be no SortExec(b ASC) in the plan @@ -3416,11 +3423,11 @@ FROM multiple_ordered_table where b=0 ---- logical_plan -01)WindowAggr: windowExpr=[[SUM(CAST(multiple_ordered_table.a AS Int64)) ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +01)WindowAggr: windowExpr=[[sum(CAST(multiple_ordered_table.a AS Int64)) ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] 02)--Filter: multiple_ordered_table.b = Int32(0) 03)----TableScan: multiple_ordered_table projection=[a0, a, b, c, d], partial_filters=[multiple_ordered_table.b = Int32(0)] physical_plan -01)BoundedWindowAggExec: wdw=[SUM(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +01)BoundedWindowAggExec: wdw=[sum(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] 02)--CoalesceBatchesExec: target_batch_size=4096 03)----FilterExec: b@2 = 0 04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_orderings=[[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST], [c@3 ASC NULLS LAST]], has_header=true @@ -3434,12 +3441,12 @@ FROM multiple_ordered_table where b=0 ---- logical_plan -01)WindowAggr: windowExpr=[[SUM(CAST(multiple_ordered_table.a AS Int64)) ORDER BY [multiple_ordered_table.b ASC NULLS LAST, multiple_ordered_table.d ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +01)WindowAggr: windowExpr=[[sum(CAST(multiple_ordered_table.a AS Int64)) ORDER BY [multiple_ordered_table.b ASC NULLS LAST, multiple_ordered_table.d ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] 02)--Filter: multiple_ordered_table.b = Int32(0) 03)----TableScan: multiple_ordered_table projection=[a0, a, b, c, d], partial_filters=[multiple_ordered_table.b = Int32(0)] physical_plan -01)BoundedWindowAggExec: wdw=[SUM(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.b ASC NULLS LAST, multiple_ordered_table.d ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.b ASC NULLS LAST, multiple_ordered_table.d ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -02)--SortExec: expr=[d@4 ASC NULLS LAST] +01)BoundedWindowAggExec: wdw=[sum(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.b ASC NULLS LAST, multiple_ordered_table.d ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(multiple_ordered_table.a) ORDER BY [multiple_ordered_table.b ASC NULLS LAST, multiple_ordered_table.d ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +02)--SortExec: expr=[d@4 ASC NULLS LAST], preserve_partitioning=[false] 03)----CoalesceBatchesExec: target_batch_size=4096 04)------FilterExec: b@2 = 0 05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_orderings=[[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST], [c@3 ASC NULLS LAST]], has_header=true @@ -3455,10 +3462,10 @@ CREATE UNBOUNDED EXTERNAL TABLE multiple_ordered_table_inf ( d INTEGER ) STORED AS CSV -WITH HEADER ROW WITH ORDER (a ASC, b ASC) WITH ORDER (c ASC) -LOCATION '../core/tests/data/window_2.csv'; +LOCATION '../core/tests/data/window_2.csv' +OPTIONS ('format.has_header' 'true'); # All of the window execs in the physical plan should work in the # sorted mode. @@ -3468,16 +3475,16 @@ EXPLAIN SELECT MIN(d) OVER(ORDER BY c ASC) as min1, FROM multiple_ordered_table ---- logical_plan -01)Projection: MIN(multiple_ordered_table.d) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS min1, MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS max1 -02)--WindowAggr: windowExpr=[[MIN(multiple_ordered_table.d) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -03)----Projection: multiple_ordered_table.c, multiple_ordered_table.d, MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW -04)------WindowAggr: windowExpr=[[MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +01)Projection: min(multiple_ordered_table.d) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS min1, max(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS max1 +02)--WindowAggr: windowExpr=[[min(multiple_ordered_table.d) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +03)----Projection: multiple_ordered_table.c, multiple_ordered_table.d, max(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +04)------WindowAggr: windowExpr=[[max(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] 05)--------TableScan: multiple_ordered_table projection=[a, b, c, d] physical_plan -01)ProjectionExec: expr=[MIN(multiple_ordered_table.d) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as min1, MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as max1] -02)--BoundedWindowAggExec: wdw=[MIN(multiple_ordered_table.d) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MIN(multiple_ordered_table.d) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -03)----ProjectionExec: expr=[c@2 as c, d@3 as d, MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] -04)------BoundedWindowAggExec: wdw=[MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MAX(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +01)ProjectionExec: expr=[min(multiple_ordered_table.d) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as min1, max(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as max1] +02)--BoundedWindowAggExec: wdw=[min(multiple_ordered_table.d) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "min(multiple_ordered_table.d) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +03)----ProjectionExec: expr=[c@2 as c, d@3 as d, max(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as max(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +04)------BoundedWindowAggExec: wdw=[max(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "max(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.b, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] 05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_orderings=[[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], [c@2 ASC NULLS LAST]], has_header=true query TT @@ -3488,13 +3495,13 @@ FROM( WHERE d=0) ---- logical_plan -01)Projection: MAX(multiple_ordered_table.c) PARTITION BY [multiple_ordered_table.d] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS max_c -02)--WindowAggr: windowExpr=[[MAX(multiple_ordered_table.c) PARTITION BY [multiple_ordered_table.d] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +01)Projection: max(multiple_ordered_table.c) PARTITION BY [multiple_ordered_table.d] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS max_c +02)--WindowAggr: windowExpr=[[max(multiple_ordered_table.c) PARTITION BY [multiple_ordered_table.d] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] 03)----Filter: multiple_ordered_table.d = Int32(0) 04)------TableScan: multiple_ordered_table projection=[c, d], partial_filters=[multiple_ordered_table.d = Int32(0)] physical_plan -01)ProjectionExec: expr=[MAX(multiple_ordered_table.c) PARTITION BY [multiple_ordered_table.d] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as max_c] -02)--BoundedWindowAggExec: wdw=[MAX(multiple_ordered_table.c) PARTITION BY [multiple_ordered_table.d] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "MAX(multiple_ordered_table.c) PARTITION BY [multiple_ordered_table.d] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +01)ProjectionExec: expr=[max(multiple_ordered_table.c) PARTITION BY [multiple_ordered_table.d] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as max_c] +02)--BoundedWindowAggExec: wdw=[max(multiple_ordered_table.c) PARTITION BY [multiple_ordered_table.d] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "max(multiple_ordered_table.c) PARTITION BY [multiple_ordered_table.d] ORDER BY [multiple_ordered_table.c ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] 03)----CoalesceBatchesExec: target_batch_size=4096 04)------FilterExec: d@1 = 0 05)--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c, d], output_ordering=[c@0 ASC NULLS LAST], has_header=true @@ -3504,12 +3511,12 @@ explain SELECT SUM(d) OVER(PARTITION BY c ORDER BY a ASC) FROM multiple_ordered_table; ---- logical_plan -01)Projection: SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW -02)--WindowAggr: windowExpr=[[SUM(CAST(multiple_ordered_table.d AS Int64)) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +01)Projection: sum(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +02)--WindowAggr: windowExpr=[[sum(CAST(multiple_ordered_table.d AS Int64)) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] 03)----TableScan: multiple_ordered_table projection=[a, c, d] physical_plan -01)ProjectionExec: expr=[SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] -02)--BoundedWindowAggExec: wdw=[SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +01)ProjectionExec: expr=[sum(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +02)--BoundedWindowAggExec: wdw=[sum(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c] ORDER BY [multiple_ordered_table.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] 03)----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c, d], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true query TT @@ -3517,12 +3524,12 @@ explain SELECT SUM(d) OVER(PARTITION BY c, a ORDER BY b ASC) FROM multiple_ordered_table; ---- logical_plan -01)Projection: SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW -02)--WindowAggr: windowExpr=[[SUM(CAST(multiple_ordered_table.d AS Int64)) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +01)Projection: sum(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +02)--WindowAggr: windowExpr=[[sum(CAST(multiple_ordered_table.d AS Int64)) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] 03)----TableScan: multiple_ordered_table projection=[a, b, c, d] physical_plan -01)ProjectionExec: expr=[SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] -02)--BoundedWindowAggExec: wdw=[SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +01)ProjectionExec: expr=[sum(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +02)--BoundedWindowAggExec: wdw=[sum(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(multiple_ordered_table.d) PARTITION BY [multiple_ordered_table.c, multiple_ordered_table.a] ORDER BY [multiple_ordered_table.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] 03)----CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, b, c, d], output_orderings=[[a@0 ASC NULLS LAST, b@1 ASC NULLS LAST], [c@2 ASC NULLS LAST]], has_header=true query I @@ -3557,14 +3564,13 @@ EXPLAIN SELECT c, NTH_VALUE(c, 2) OVER(order by c DESC) as nv1 LIMIT 5 ---- logical_plan -01)Limit: skip=0, fetch=5 -02)--Sort: multiple_ordered_table.c ASC NULLS LAST, fetch=5 -03)----Projection: multiple_ordered_table.c, NTH_VALUE(multiple_ordered_table.c,Int64(2)) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS nv1 -04)------WindowAggr: windowExpr=[[NTH_VALUE(multiple_ordered_table.c, Int64(2)) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -05)--------TableScan: multiple_ordered_table projection=[c] +01)Sort: multiple_ordered_table.c ASC NULLS LAST, fetch=5 +02)--Projection: multiple_ordered_table.c, NTH_VALUE(multiple_ordered_table.c,Int64(2)) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS nv1 +03)----WindowAggr: windowExpr=[[NTH_VALUE(multiple_ordered_table.c, Int64(2)) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------TableScan: multiple_ordered_table projection=[c] physical_plan -01)GlobalLimitExec: skip=0, fetch=5 -02)--ProjectionExec: expr=[c@0 as c, NTH_VALUE(multiple_ordered_table.c,Int64(2)) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as nv1] +01)ProjectionExec: expr=[c@0 as c, NTH_VALUE(multiple_ordered_table.c,Int64(2)) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as nv1] +02)--GlobalLimitExec: skip=0, fetch=5 03)----WindowAggExec: wdw=[NTH_VALUE(multiple_ordered_table.c,Int64(2)) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "NTH_VALUE(multiple_ordered_table.c,Int64(2)) ORDER BY [multiple_ordered_table.c DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int32(NULL)), is_causal: false }] 04)------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[c], output_ordering=[c@0 ASC NULLS LAST], has_header=true @@ -3600,7 +3606,7 @@ set datafusion.execution.target_partitions = 2; # we should still have the orderings [a ASC, b ASC], [c ASC]. query TT EXPLAIN SELECT *, - AVG(d) OVER sliding_window AS avg_d + avg(d) OVER sliding_window AS avg_d FROM multiple_ordered_table_inf WINDOW sliding_window AS ( PARTITION BY d @@ -3610,15 +3616,15 @@ ORDER BY c ---- logical_plan 01)Sort: multiple_ordered_table_inf.c ASC NULLS LAST -02)--Projection: multiple_ordered_table_inf.a0, multiple_ordered_table_inf.a, multiple_ordered_table_inf.b, multiple_ordered_table_inf.c, multiple_ordered_table_inf.d, AVG(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW AS avg_d -03)----WindowAggr: windowExpr=[[AVG(CAST(multiple_ordered_table_inf.d AS Float64)) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW]] +02)--Projection: multiple_ordered_table_inf.a0, multiple_ordered_table_inf.a, multiple_ordered_table_inf.b, multiple_ordered_table_inf.c, multiple_ordered_table_inf.d, avg(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW AS avg_d +03)----WindowAggr: windowExpr=[[avg(CAST(multiple_ordered_table_inf.d AS Float64)) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW]] 04)------TableScan: multiple_ordered_table_inf projection=[a0, a, b, c, d] physical_plan 01)SortPreservingMergeExec: [c@3 ASC NULLS LAST] -02)--ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, AVG(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW@5 as avg_d] -03)----BoundedWindowAggExec: wdw=[AVG(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW: Ok(Field { name: "AVG(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: CurrentRow, is_causal: false }], mode=[Linear] +02)--ProjectionExec: expr=[a0@0 as a0, a@1 as a, b@2 as b, c@3 as c, d@4 as d, avg(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW@5 as avg_d] +03)----BoundedWindowAggExec: wdw=[avg(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW: Ok(Field { name: "avg(multiple_ordered_table_inf.d) PARTITION BY [multiple_ordered_table_inf.d] ORDER BY [multiple_ordered_table_inf.a ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int32(10)), end_bound: CurrentRow, is_causal: false }], mode=[Linear] 04)------CoalesceBatchesExec: target_batch_size=4096 -05)--------RepartitionExec: partitioning=Hash([d@4], 2), input_partitions=2, preserve_order=true, sort_exprs=a@1 ASC NULLS LAST,b@2 ASC NULLS LAST,c@3 ASC NULLS LAST +05)--------RepartitionExec: partitioning=Hash([d@4], 2), input_partitions=2, preserve_order=true, sort_exprs=a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST 06)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 07)------------StreamingTableExec: partition_sizes=1, projection=[a0, a, b, c, d], infinite_source=true, output_orderings=[[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST], [c@3 ASC NULLS LAST]] @@ -3673,14 +3679,14 @@ SELECT FROM score_board s ORDER BY team_name, score; ---- -Mongrels Apu 350 1 -Mongrels Ned 666 1 -Mongrels Meg 1030 2 -Mongrels Burns 1270 2 -Simpsons Homer 1 1 -Simpsons Lisa 710 1 -Simpsons Marge 990 2 -Simpsons Bart 2010 2 +Mongrels Apu 350 1 +Mongrels Ned 666 1 +Mongrels Meg 1030 2 +Mongrels Burns 1270 2 +Simpsons Homer 1 1 +Simpsons Lisa 710 1 +Simpsons Marge 990 2 +Simpsons Bart 2010 2 query TTII SELECT @@ -3691,14 +3697,14 @@ SELECT FROM score_board s ORDER BY score; ---- -Simpsons Homer 1 1 -Mongrels Apu 350 1 -Mongrels Ned 666 1 -Simpsons Lisa 710 1 -Simpsons Marge 990 2 -Mongrels Meg 1030 2 -Mongrels Burns 1270 2 -Simpsons Bart 2010 2 +Simpsons Homer 1 1 +Mongrels Apu 350 1 +Mongrels Ned 666 1 +Simpsons Lisa 710 1 +Simpsons Marge 990 2 +Mongrels Meg 1030 2 +Mongrels Burns 1270 2 +Simpsons Bart 2010 2 query TTII SELECT @@ -3709,14 +3715,14 @@ SELECT FROM score_board s ORDER BY team_name, score; ---- -Mongrels Apu 350 1 -Mongrels Ned 666 2 -Mongrels Meg 1030 3 -Mongrels Burns 1270 4 -Simpsons Homer 1 1 -Simpsons Lisa 710 2 -Simpsons Marge 990 3 -Simpsons Bart 2010 4 +Mongrels Apu 350 1 +Mongrels Ned 666 2 +Mongrels Meg 1030 3 +Mongrels Burns 1270 4 +Simpsons Homer 1 1 +Simpsons Lisa 710 2 +Simpsons Marge 990 3 +Simpsons Bart 2010 4 query TTII SELECT @@ -3727,14 +3733,14 @@ SELECT FROM score_board s ORDER BY team_name, score; ---- -Mongrels Apu 350 1 -Mongrels Ned 666 1 -Mongrels Meg 1030 1 -Mongrels Burns 1270 1 -Simpsons Homer 1 1 -Simpsons Lisa 710 1 -Simpsons Marge 990 1 -Simpsons Bart 2010 1 +Mongrels Apu 350 1 +Mongrels Ned 666 1 +Mongrels Meg 1030 1 +Mongrels Burns 1270 1 +Simpsons Homer 1 1 +Simpsons Lisa 710 1 +Simpsons Marge 990 1 +Simpsons Bart 2010 1 # incorrect number of parameters for ntile query error DataFusion error: Execution error: NTILE requires a positive integer, but finds NULL @@ -3911,7 +3917,8 @@ b 1 3 a 1 4 b 5 5 -statement error DataFusion error: Error during planning: Projection references non-aggregate values: Expression aggregate_test_100.c1 could not be resolved from available columns: rn +# Schema error: No field named aggregate_test_100.c1. Valid fields are rn. +statement error SELECT * FROM (SELECT c1, c2, ROW_NUMBER() OVER(PARTITION BY c1) as rn FROM aggregate_test_100 @@ -3938,20 +3945,20 @@ CREATE TABLE table_with_pk ( # However, if we know that contains a unique column (e.g. a PRIMARY KEY), # it can be treated as `OVER (ORDER BY ROWS BETWEEN UNBOUNDED PRECEDING # AND CURRENT ROW)` where window frame units change from `RANGE` to `ROWS`. This -# conversion makes the window frame manifestly causal by eliminating the possiblity +# conversion makes the window frame manifestly causal by eliminating the possibility # of ties explicitly (see window frame documentation for a discussion of causality # in this context). The Query below should have `ROWS` in its window frame. query TT EXPLAIN SELECT *, SUM(amount) OVER (ORDER BY sn) as sum1 FROM table_with_pk; ---- logical_plan -01)Projection: table_with_pk.sn, table_with_pk.ts, table_with_pk.currency, table_with_pk.amount, SUM(table_with_pk.amount) ORDER BY [table_with_pk.sn ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum1 -02)--WindowAggr: windowExpr=[[SUM(CAST(table_with_pk.amount AS Float64)) ORDER BY [table_with_pk.sn ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +01)Projection: table_with_pk.sn, table_with_pk.ts, table_with_pk.currency, table_with_pk.amount, sum(table_with_pk.amount) ORDER BY [table_with_pk.sn ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum1 +02)--WindowAggr: windowExpr=[[sum(CAST(table_with_pk.amount AS Float64)) ORDER BY [table_with_pk.sn ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] 03)----TableScan: table_with_pk projection=[sn, ts, currency, amount] physical_plan -01)ProjectionExec: expr=[sn@0 as sn, ts@1 as ts, currency@2 as currency, amount@3 as amount, SUM(table_with_pk.amount) ORDER BY [table_with_pk.sn ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum1] -02)--BoundedWindowAggExec: wdw=[SUM(table_with_pk.amount) ORDER BY [table_with_pk.sn ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(table_with_pk.amount) ORDER BY [table_with_pk.sn ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted] -03)----SortExec: expr=[sn@0 ASC NULLS LAST] +01)ProjectionExec: expr=[sn@0 as sn, ts@1 as ts, currency@2 as currency, amount@3 as amount, sum(table_with_pk.amount) ORDER BY [table_with_pk.sn ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum1] +02)--BoundedWindowAggExec: wdw=[sum(table_with_pk.amount) ORDER BY [table_with_pk.sn ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(table_with_pk.amount) ORDER BY [table_with_pk.sn ASC NULLS LAST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Float64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: CurrentRow, is_causal: true }], mode=[Sorted] +03)----SortExec: expr=[sn@0 ASC NULLS LAST], preserve_partitioning=[false] 04)------MemoryExec: partitions=1, partition_sizes=[1] # test ROW_NUMBER window function returns correct data_type @@ -4060,19 +4067,19 @@ explain SELECT c3, limit 5 ---- logical_plan -01)Projection: aggregate_test_100.c3, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum2, sum1 +01)Projection: aggregate_test_100.c3, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum2, sum1 02)--Limit: skip=0, fetch=5 -03)----WindowAggr: windowExpr=[[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] -04)------Projection: aggregate_test_100.c3, aggregate_test_100.c4, aggregate_test_100.c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum1 -05)--------WindowAggr: windowExpr=[[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +03)----WindowAggr: windowExpr=[[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +04)------Projection: aggregate_test_100.c3, aggregate_test_100.c4, aggregate_test_100.c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS sum1 +05)--------WindowAggr: windowExpr=[[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] 06)----------TableScan: aggregate_test_100 projection=[c3, c4, c9] physical_plan -01)ProjectionExec: expr=[c3@0 as c3, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum2, sum1@3 as sum1] +01)ProjectionExec: expr=[c3@0 as c3, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as sum2, sum1@3 as sum1] 02)--GlobalLimitExec: skip=0, fetch=5 -03)----WindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int16(NULL)), is_causal: false }] -04)------ProjectionExec: expr=[c3@0 as c3, c4@1 as c4, c9@2 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum1] -05)--------BoundedWindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int16(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] -06)----------SortExec: expr=[c3@0 + c4@1 DESC] +03)----WindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int16(NULL)), is_causal: false }] +04)------ProjectionExec: expr=[c3@0 as c3, c4@1 as c4, c9@2 as c9, sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@3 as sum1] +05)--------BoundedWindowAggExec: wdw=[sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "sum(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int16(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +06)----------SortExec: expr=[c3@0 + c4@1 DESC], preserve_partitioning=[false] 07)------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c3, c4, c9], has_header=true query III @@ -4104,13 +4111,13 @@ query TT EXPLAIN select count(*) over (partition by a order by a) from (select * from a where a = 1); ---- logical_plan -01)Projection: COUNT(*) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW -02)--WindowAggr: windowExpr=[[COUNT(Int64(1)) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS COUNT(*) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] +01)Projection: count(*) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW +02)--WindowAggr: windowExpr=[[count(Int64(1)) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW AS count(*) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]] 03)----Filter: a.a = Int64(1) 04)------TableScan: a projection=[a] physical_plan -01)ProjectionExec: expr=[COUNT(*) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as COUNT(*) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] -02)--BoundedWindowAggExec: wdw=[COUNT(*) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "COUNT(*) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] +01)ProjectionExec: expr=[count(*) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as count(*) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW] +02)--BoundedWindowAggExec: wdw=[count(*) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: "count(*) PARTITION BY [a.a] ORDER BY [a.a ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow, is_causal: false }], mode=[Sorted] 03)----CoalesceBatchesExec: target_batch_size=4096 04)------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 05)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -4127,13 +4134,13 @@ query TT EXPLAIN select ROW_NUMBER() over (partition by a) from (select * from a where a = 1); ---- logical_plan -01)Projection: ROW_NUMBER() PARTITION BY [a.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING -02)--WindowAggr: windowExpr=[[ROW_NUMBER() PARTITION BY [a.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] +01)Projection: row_number() PARTITION BY [a.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING +02)--WindowAggr: windowExpr=[[row_number() PARTITION BY [a.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] 03)----Filter: a.a = Int64(1) 04)------TableScan: a projection=[a] physical_plan -01)ProjectionExec: expr=[ROW_NUMBER() PARTITION BY [a.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as ROW_NUMBER() PARTITION BY [a.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING] -02)--BoundedWindowAggExec: wdw=[ROW_NUMBER() PARTITION BY [a.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "ROW_NUMBER() PARTITION BY [a.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] +01)ProjectionExec: expr=[row_number() PARTITION BY [a.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as row_number() PARTITION BY [a.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING] +02)--BoundedWindowAggExec: wdw=[row_number() PARTITION BY [a.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "row_number() PARTITION BY [a.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted] 03)----CoalesceBatchesExec: target_batch_size=4096 04)------RepartitionExec: partitioning=Hash([a@0], 2), input_partitions=2 05)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 @@ -4799,3 +4806,243 @@ NULL 4 NULL 3 NULL 2 NULL 1 + +statement ok +drop table t + +### Test for window functions with arrays +statement ok +create table array_data +as values + ('a',make_array(1, 2, 3)), + ('b', make_array(4, 5, 6)), + ('c', make_array(7, 8, 9)); + +query T?? +SELECT + column1, + lag(column2) OVER (order by column1), + lead(column2) OVER (order by column1) +FROM array_data; +---- +a NULL [4, 5, 6] +b [1, 2, 3] [7, 8, 9] +c [4, 5, 6] NULL + +statement ok +drop table array_data + +# Test for non-i64 offsets for NTILE, LAG, LEAD, NTH_VALUE +statement ok +CREATE TABLE t AS VALUES (3, 3), (4, 4), (5, 5), (6, 6); + +query IIIIIIIII +SELECT + column1, + ntile(2) OVER (order by column1), + ntile(arrow_cast(2, 'Int32')) OVER (order by column1), + lag(column2, -1) OVER (order by column1), + lag(column2, arrow_cast(-1, 'Int32')) OVER (order by column1), + lead(column2, -1) OVER (order by column1), + lead(column2, arrow_cast(-1, 'Int32')) OVER (order by column1), + nth_value(column2, 2) OVER (order by column1), + nth_value(column2, arrow_cast(2, 'Int32')) OVER (order by column1) +FROM t; +---- +3 1 1 4 4 NULL NULL NULL NULL +4 1 1 5 5 3 3 4 4 +5 2 2 6 6 4 4 4 4 +6 2 2 NULL NULL 5 5 4 4 + +# NTILE specifies the argument types so the error is different +query error +SELECT ntile(1.1) OVER (order by column1) FROM t; + +query error DataFusion error: Execution error: Expected an integer value +SELECT lag(column2, 1.1) OVER (order by column1) FROM t; + +query error DataFusion error: Execution error: Expected an integer value +SELECT lead(column2, 1.1) OVER (order by column1) FROM t; + +query error DataFusion error: Execution error: Expected an integer value +SELECT nth_value(column2, 1.1) OVER (order by column1) FROM t; + +statement ok +drop table t; + +statement ok +create table t(a int, b int) as values (1, 2) + +query II +select a, row_number() over (order by b) as rn from t; +---- +1 1 + +# RowNumber expect 0 args. +query error +select a, row_number(a) over (order by b) as rn from t; + +statement ok +drop table t; + +statement ok +DROP TABLE t1; + +# https://github.com/apache/datafusion/issues/12073 +statement ok +CREATE TABLE t1(v1 BIGINT); + +query error DataFusion error: Execution error: Expected a signed integer literal for the second argument of nth_value, got v1@0 +SELECT NTH_VALUE('+Inf'::Double, v1) OVER (PARTITION BY v1) FROM t1; + +statement ok +DROP TABLE t1; + +statement ok +create table t(c1 int, c2 varchar) as values (1, 'a'), (2, 'b'), (1, 'a'), (3, null), (null, 'a4'), (null, 'de'); + +# test multi group FirstN mode with nulls +query ITI +SELECT * +FROM (SELECT c1, c2, ROW_NUMBER() OVER() as rn + FROM t + LIMIT 5) +GROUP BY rn +ORDER BY rn; +---- +1 a 1 +2 b 2 +1 a 3 +3 NULL 4 +NULL a4 5 + +statement ok +drop table t + +## test handle NULL and 0 value of nth_value +statement ok +CREATE TABLE t(v1 int, v2 int); + +statement ok +INSERT INTO t VALUES (1,1), (1,2),(1,3),(2,1),(2,2); + +query II +SELECT v1, NTH_VALUE(v2, null) OVER (PARTITION BY v1 ORDER BY v2) FROM t; +---- +1 NULL +1 NULL +1 NULL +2 NULL +2 NULL + +query II +SELECT v1, NTH_VALUE(v2, v2*null) OVER (PARTITION BY v1 ORDER BY v2) FROM t; +---- +1 NULL +1 NULL +1 NULL +2 NULL +2 NULL + +query II +SELECT v1, NTH_VALUE(v2, 0) OVER (PARTITION BY v1 ORDER BY v2) FROM t; +---- +1 NULL +1 NULL +1 NULL +2 NULL +2 NULL + +query I +SELECT NTH_VALUE(tt0.v1, NULL) OVER (PARTITION BY tt0.v2 ORDER BY tt0.v1) FROM t AS tt0; +---- +NULL +NULL +NULL +NULL +NULL + +statement ok +DROP TABLE t; + +## end test handle NULL and 0 of NTH_VALUE + +## test handle NULL of lead + +statement ok +create table t1(v1 int); + +statement ok +insert into t1 values (1); + +query B +SELECT LEAD(NULL, 0, false) OVER () FROM t1; +---- +NULL + +query B +SELECT LAG(NULL, 0, false) OVER () FROM t1; +---- +NULL + +query B +SELECT LEAD(NULL, 1, false) OVER () FROM t1; +---- +false + +query B +SELECT LAG(NULL, 1, false) OVER () FROM t1; +---- +false + +query B +SELECT LEAD(NULL, 0, true) OVER () FROM t1; +---- +NULL + +query B +SELECT LAG(NULL, 0, true) OVER () FROM t1; +---- +NULL + +query B +SELECT LEAD(NULL, 1, true) OVER () FROM t1; +---- +true + +query B +SELECT LAG(NULL, 1, true) OVER () FROM t1; +---- +true + +statement ok +insert into t1 values (2); + +query B +SELECT LEAD(NULL, 1, false) OVER () FROM t1; +---- +NULL +false + +query B +SELECT LAG(NULL, 1, false) OVER () FROM t1; +---- +false +NULL + +query B +SELECT LEAD(NULL, 1, true) OVER () FROM t1; +---- +NULL +true + +query B +SELECT LAG(NULL, 1, true) OVER () FROM t1; +---- +true +NULL + +statement ok +DROP TABLE t1; + +## end test handle NULL of lead diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index a947ac2c51c3..b0aa6acf3c7c 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -26,22 +26,27 @@ repository = { workspace = true } license = { workspace = true } authors = { workspace = true } # Specify MSRV here as `cargo msrv` doesn't support workspace version -rust-version = "1.73" +rust-version = "1.79" [lints] workspace = true [dependencies] +arrow-buffer = { workspace = true } async-recursion = "1.0" chrono = { workspace = true } datafusion = { workspace = true, default-features = true } itertools = { workspace = true } object_store = { workspace = true } -prost = "0.12" -prost-types = "0.12" -substrait = "0.31.0" +pbjson-types = "0.7" +# TODO use workspace version +prost = "0.13" +substrait = { version = "0.45", features = ["serde"] } +url = { workspace = true } [dev-dependencies] +datafusion-functions-aggregate = { workspace = true } +serde_json = "1.0" tokio = { workspace = true } [features] diff --git a/datafusion/substrait/src/extensions.rs b/datafusion/substrait/src/extensions.rs new file mode 100644 index 000000000000..459d0e0c5ae5 --- /dev/null +++ b/datafusion/substrait/src/extensions.rs @@ -0,0 +1,157 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::common::{plan_err, DataFusionError}; +use std::collections::HashMap; +use substrait::proto::extensions::simple_extension_declaration::{ + ExtensionFunction, ExtensionType, ExtensionTypeVariation, MappingType, +}; +use substrait::proto::extensions::SimpleExtensionDeclaration; + +/// Substrait uses [SimpleExtensions](https://substrait.io/extensions/#simple-extensions) to define +/// behavior of plans in addition to what's supported directly by the protobuf definitions. +/// That includes functions, but also provides support for custom types and variations for existing +/// types. This structs facilitates the use of these extensions in DataFusion. +/// TODO: DF doesn't yet use extensions for type variations +/// TODO: DF doesn't yet provide valid extensionUris +#[derive(Default, Debug, PartialEq)] +pub struct Extensions { + pub functions: HashMap, // anchor -> function name + pub types: HashMap, // anchor -> type name + pub type_variations: HashMap, // anchor -> type variation name +} + +impl Extensions { + /// Registers a function and returns the anchor (reference) to it. If the function has already + /// been registered, it returns the existing anchor. + /// Function names are case-insensitive (converted to lowercase). + pub fn register_function(&mut self, function_name: String) -> u32 { + let function_name = function_name.to_lowercase(); + + // Some functions are named differently in Substrait default extensions than in DF + // Rename those to match the Substrait extensions for interoperability + let function_name = match function_name.as_str() { + "substr" => "substring".to_string(), + _ => function_name, + }; + + match self.functions.iter().find(|(_, f)| *f == &function_name) { + Some((function_anchor, _)) => *function_anchor, // Function has been registered + None => { + // Function has NOT been registered + let function_anchor = self.functions.len() as u32; + self.functions + .insert(function_anchor, function_name.clone()); + function_anchor + } + } + } + + /// Registers a type and returns the anchor (reference) to it. If the type has already + /// been registered, it returns the existing anchor. + pub fn register_type(&mut self, type_name: String) -> u32 { + let type_name = type_name.to_lowercase(); + match self.types.iter().find(|(_, t)| *t == &type_name) { + Some((type_anchor, _)) => *type_anchor, // Type has been registered + None => { + // Type has NOT been registered + let type_anchor = self.types.len() as u32; + self.types.insert(type_anchor, type_name.clone()); + type_anchor + } + } + } +} + +impl TryFrom<&Vec> for Extensions { + type Error = DataFusionError; + + fn try_from( + value: &Vec, + ) -> datafusion::common::Result { + let mut functions = HashMap::new(); + let mut types = HashMap::new(); + let mut type_variations = HashMap::new(); + + for ext in value { + match &ext.mapping_type { + Some(MappingType::ExtensionFunction(ext_f)) => { + functions.insert(ext_f.function_anchor, ext_f.name.to_owned()); + } + Some(MappingType::ExtensionType(ext_t)) => { + types.insert(ext_t.type_anchor, ext_t.name.to_owned()); + } + Some(MappingType::ExtensionTypeVariation(ext_v)) => { + type_variations + .insert(ext_v.type_variation_anchor, ext_v.name.to_owned()); + } + None => return plan_err!("Cannot parse empty extension"), + } + } + + Ok(Extensions { + functions, + types, + type_variations, + }) + } +} + +impl From for Vec { + fn from(val: Extensions) -> Vec { + let mut extensions = vec![]; + for (f_anchor, f_name) in val.functions { + let function_extension = ExtensionFunction { + extension_uri_reference: u32::MAX, + function_anchor: f_anchor, + name: f_name, + }; + let simple_extension = SimpleExtensionDeclaration { + mapping_type: Some(MappingType::ExtensionFunction(function_extension)), + }; + extensions.push(simple_extension); + } + + for (t_anchor, t_name) in val.types { + let type_extension = ExtensionType { + extension_uri_reference: u32::MAX, // https://github.com/apache/datafusion/issues/11545 + type_anchor: t_anchor, + name: t_name, + }; + let simple_extension = SimpleExtensionDeclaration { + mapping_type: Some(MappingType::ExtensionType(type_extension)), + }; + extensions.push(simple_extension); + } + + for (tv_anchor, tv_name) in val.type_variations { + let type_variation_extension = ExtensionTypeVariation { + extension_uri_reference: u32::MAX, // We don't register proper extension URIs yet + type_variation_anchor: tv_anchor, + name: tv_name, + }; + let simple_extension = SimpleExtensionDeclaration { + mapping_type: Some(MappingType::ExtensionTypeVariation( + type_variation_extension, + )), + }; + extensions.push(simple_extension); + } + + extensions + } +} diff --git a/datafusion/substrait/src/lib.rs b/datafusion/substrait/src/lib.rs index 454f0e7b7cb9..a6f7c033f9d0 100644 --- a/datafusion/substrait/src/lib.rs +++ b/datafusion/substrait/src/lib.rs @@ -68,10 +68,12 @@ //! //! // Receive a substrait protobuf from somewhere, and turn it into a LogicalPlan //! let logical_round_trip = logical_plan::consumer::from_substrait_plan(&ctx, &substrait_plan).await?; +//! let logical_round_trip = ctx.state().optimize(&logical_round_trip)?; //! assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip)); //! # Ok(()) //! # } //! ``` +pub mod extensions; pub mod logical_plan; pub mod physical_plan; pub mod serializer; diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index fab4528c0b42..890da7361d7c 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -15,32 +15,74 @@ // specific language governing permissions and limitations // under the License. +use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano, OffsetBuffer}; use async_recursion::async_recursion; -use datafusion::arrow::datatypes::{DataType, Field, TimeUnit}; +use datafusion::arrow::array::{GenericListArray, MapArray}; +use datafusion::arrow::datatypes::{ + DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit, +}; +use datafusion::common::plan_err; use datafusion::common::{ - not_impl_err, substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef, + not_impl_err, plan_datafusion_err, substrait_datafusion_err, substrait_err, DFSchema, + DFSchemaRef, }; - use datafusion::execution::FunctionRegistry; +use datafusion::logical_expr::expr::{Exists, InSubquery, Sort}; + use datafusion::logical_expr::{ - aggregate_function, expr::find_df_window_func, BinaryExpr, Case, Expr, LogicalPlan, - Operator, ScalarUDF, + expr::find_df_window_func, Aggregate, BinaryExpr, Case, EmptyRelation, Expr, + ExprSchemable, LogicalPlan, Operator, Projection, SortExpr, Values, +}; +use substrait::proto::aggregate_rel::Grouping; +use substrait::proto::expression::subquery::set_predicate::PredicateOp; +use substrait::proto::expression_reference::ExprType; +use url::Url; + +use crate::extensions::Extensions; +use crate::variation_const::{ + DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, + DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, + DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, + LARGE_CONTAINER_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, + VIEW_CONTAINER_TYPE_VARIATION_REF, }; +#[allow(deprecated)] +use crate::variation_const::{ + INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_NAME, + INTERVAL_MONTH_DAY_NANO_TYPE_REF, INTERVAL_YEAR_MONTH_TYPE_REF, + TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF, + TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF, +}; +use datafusion::arrow::array::{new_empty_array, AsArray}; +use datafusion::arrow::temporal_conversions::NANOSECONDS; +use datafusion::common::scalar::ScalarStructBuilder; +use datafusion::dataframe::DataFrame; +use datafusion::logical_expr::builder::project; +use datafusion::logical_expr::expr::InList; use datafusion::logical_expr::{ - expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning, - Repartition, Subquery, WindowFrameBound, WindowFrameUnits, + col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning, + Repartition, Subquery, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; use datafusion::prelude::JoinType; use datafusion::sql::TableReference; use datafusion::{ - error::{DataFusionError, Result}, + error::Result, logical_expr::utils::split_conjunction, prelude::{Column, SessionContext}, scalar::ScalarValue, }; +use std::collections::HashSet; +use std::sync::Arc; use substrait::proto::exchange_rel::ExchangeKind; +use substrait::proto::expression::literal::user_defined::Val; +use substrait::proto::expression::literal::{ + interval_day_to_second, IntervalCompound, IntervalDayToSecond, IntervalYearToMonth, + UserDefined, +}; use substrait::proto::expression::subquery::SubqueryType; use substrait::proto::expression::{FieldReference, Literal, ScalarFunction}; +use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile; +use substrait::proto::rel_common::{Emit, EmitKind}; use substrait::proto::{ aggregate_function::AggregationInvocation, expression::{ @@ -48,89 +90,69 @@ use substrait::proto::{ reference_segment::ReferenceType::StructField, window_function::bound as SubstraitBound, window_function::bound::Kind as BoundKind, window_function::Bound, - MaskExpression, RexType, + window_function::BoundsType, MaskExpression, RexType, }, - extensions::simple_extension_declaration::MappingType, function_argument::ArgType, join_rel, plan_rel, r#type, read_rel::ReadType, rel::RelType, - set_rel, + rel_common, set_rel, sort_field::{SortDirection, SortKind::*}, - AggregateFunction, Expression, Plan, Rel, Type, + AggregateFunction, Expression, NamedStruct, Plan, Rel, RelCommon, Type, }; -use substrait::proto::{FunctionArgument, SortField}; +use substrait::proto::{ExtendedExpression, FunctionArgument, SortField}; -use datafusion::common::plan_err; -use datafusion::logical_expr::expr::{InList, InSubquery, Sort}; -use std::collections::HashMap; -use std::str::FromStr; -use std::sync::Arc; +// Substrait PrecisionTimestampTz indicates that the timestamp is relative to UTC, which +// is the same as the expectation for any non-empty timezone in DF, so any non-empty timezone +// results in correct points on the timeline, and we pick UTC as a reasonable default. +// However, DF uses the timezone also for some arithmetic and display purposes (see e.g. +// https://github.com/apache/arrow-rs/blob/ee5694078c86c8201549654246900a4232d531a9/arrow-cast/src/cast/mod.rs#L1749). +const DEFAULT_TIMEZONE: &str = "UTC"; -use crate::variation_const::{ - DATE_32_TYPE_REF, DATE_64_TYPE_REF, DECIMAL_128_TYPE_REF, DECIMAL_256_TYPE_REF, - DEFAULT_CONTAINER_TYPE_REF, DEFAULT_TYPE_REF, LARGE_CONTAINER_TYPE_REF, - TIMESTAMP_MICRO_TYPE_REF, TIMESTAMP_MILLI_TYPE_REF, TIMESTAMP_NANO_TYPE_REF, - TIMESTAMP_SECOND_TYPE_REF, UNSIGNED_INTEGER_TYPE_REF, -}; - -enum ScalarFunctionType { - Op(Operator), - Expr(BuiltinExprBuilder), - Udf(Arc), -} - -pub fn name_to_op(name: &str) -> Result { +pub fn name_to_op(name: &str) -> Option { match name { - "equal" => Ok(Operator::Eq), - "not_equal" => Ok(Operator::NotEq), - "lt" => Ok(Operator::Lt), - "lte" => Ok(Operator::LtEq), - "gt" => Ok(Operator::Gt), - "gte" => Ok(Operator::GtEq), - "add" => Ok(Operator::Plus), - "subtract" => Ok(Operator::Minus), - "multiply" => Ok(Operator::Multiply), - "divide" => Ok(Operator::Divide), - "mod" => Ok(Operator::Modulo), - "and" => Ok(Operator::And), - "or" => Ok(Operator::Or), - "is_distinct_from" => Ok(Operator::IsDistinctFrom), - "is_not_distinct_from" => Ok(Operator::IsNotDistinctFrom), - "regex_match" => Ok(Operator::RegexMatch), - "regex_imatch" => Ok(Operator::RegexIMatch), - "regex_not_match" => Ok(Operator::RegexNotMatch), - "regex_not_imatch" => Ok(Operator::RegexNotIMatch), - "bitwise_and" => Ok(Operator::BitwiseAnd), - "bitwise_or" => Ok(Operator::BitwiseOr), - "str_concat" => Ok(Operator::StringConcat), - "at_arrow" => Ok(Operator::AtArrow), - "arrow_at" => Ok(Operator::ArrowAt), - "bitwise_xor" => Ok(Operator::BitwiseXor), - "bitwise_shift_right" => Ok(Operator::BitwiseShiftRight), - "bitwise_shift_left" => Ok(Operator::BitwiseShiftLeft), - _ => not_impl_err!("Unsupported function name: {name:?}"), + "equal" => Some(Operator::Eq), + "not_equal" => Some(Operator::NotEq), + "lt" => Some(Operator::Lt), + "lte" => Some(Operator::LtEq), + "gt" => Some(Operator::Gt), + "gte" => Some(Operator::GtEq), + "add" => Some(Operator::Plus), + "subtract" => Some(Operator::Minus), + "multiply" => Some(Operator::Multiply), + "divide" => Some(Operator::Divide), + "mod" => Some(Operator::Modulo), + "modulus" => Some(Operator::Modulo), + "and" => Some(Operator::And), + "or" => Some(Operator::Or), + "is_distinct_from" => Some(Operator::IsDistinctFrom), + "is_not_distinct_from" => Some(Operator::IsNotDistinctFrom), + "regex_match" => Some(Operator::RegexMatch), + "regex_imatch" => Some(Operator::RegexIMatch), + "regex_not_match" => Some(Operator::RegexNotMatch), + "regex_not_imatch" => Some(Operator::RegexNotIMatch), + "bitwise_and" => Some(Operator::BitwiseAnd), + "bitwise_or" => Some(Operator::BitwiseOr), + "str_concat" => Some(Operator::StringConcat), + "at_arrow" => Some(Operator::AtArrow), + "arrow_at" => Some(Operator::ArrowAt), + "bitwise_xor" => Some(Operator::BitwiseXor), + "bitwise_shift_right" => Some(Operator::BitwiseShiftRight), + "bitwise_shift_left" => Some(Operator::BitwiseShiftLeft), + _ => None, } } -fn scalar_function_type_from_str( - ctx: &SessionContext, - name: &str, -) -> Result { - let s = ctx.state(); - if let Some(func) = s.scalar_functions().get(name) { - return Ok(ScalarFunctionType::Udf(func.to_owned())); - } - - if let Ok(op) = name_to_op(name) { - return Ok(ScalarFunctionType::Op(op)); - } - - if let Some(builder) = BuiltinExprBuilder::try_from_name(name) { - return Ok(ScalarFunctionType::Expr(builder)); - } - - not_impl_err!("Unsupported function name: {name:?}") +pub fn substrait_fun_name(name: &str) -> &str { + let name = match name.rsplit_once(':') { + // Since 0.32.0, Substrait requires the function names to be in a compound format + // https://substrait.io/extensions/#function-signature-compound-names + // for example, `add:i8_i8`. + // On the consumer side, we don't really care about the signature though, just the name. + Some((name, _)) => name, + None => name, + }; + name } fn split_eq_and_noneq_join_predicate_with_nulls_equality( @@ -143,6 +165,7 @@ fn split_eq_and_noneq_join_predicate_with_nulls_equality( let mut nulls_equal_nulls = false; for expr in exprs { + #[allow(clippy::collapsible_match)] match expr { Expr::BinaryExpr(binary_expr) => match binary_expr { x @ (BinaryExpr { @@ -178,35 +201,109 @@ fn split_eq_and_noneq_join_predicate_with_nulls_equality( (accum_join_keys, nulls_equal_nulls, join_filter) } +async fn union_rels( + rels: &[Rel], + ctx: &SessionContext, + extensions: &Extensions, + is_all: bool, +) -> Result { + let mut union_builder = Ok(LogicalPlanBuilder::from( + from_substrait_rel(ctx, &rels[0], extensions).await?, + )); + for input in &rels[1..] { + let rel_plan = from_substrait_rel(ctx, input, extensions).await?; + + union_builder = if is_all { + union_builder?.union(rel_plan) + } else { + union_builder?.union_distinct(rel_plan) + }; + } + union_builder?.build() +} + +async fn intersect_rels( + rels: &[Rel], + ctx: &SessionContext, + extensions: &Extensions, + is_all: bool, +) -> Result { + let mut rel = from_substrait_rel(ctx, &rels[0], extensions).await?; + + for input in &rels[1..] { + rel = LogicalPlanBuilder::intersect( + rel, + from_substrait_rel(ctx, input, extensions).await?, + is_all, + )? + } + + Ok(rel) +} + +async fn except_rels( + rels: &[Rel], + ctx: &SessionContext, + extensions: &Extensions, + is_all: bool, +) -> Result { + let mut rel = from_substrait_rel(ctx, &rels[0], extensions).await?; + + for input in &rels[1..] { + rel = LogicalPlanBuilder::except( + rel, + from_substrait_rel(ctx, input, extensions).await?, + is_all, + )? + } + + Ok(rel) +} + /// Convert Substrait Plan to DataFusion LogicalPlan pub async fn from_substrait_plan( ctx: &SessionContext, plan: &Plan, ) -> Result { // Register function extension - let function_extension = plan - .extensions - .iter() - .map(|e| match &e.mapping_type { - Some(ext) => match ext { - MappingType::ExtensionFunction(ext_f) => { - Ok((ext_f.function_anchor, &ext_f.name)) - } - _ => not_impl_err!("Extension type not supported: {ext:?}"), - }, - None => not_impl_err!("Cannot parse empty extension"), - }) - .collect::>>()?; + let extensions = Extensions::try_from(&plan.extensions)?; + if !extensions.type_variations.is_empty() { + return not_impl_err!("Type variation extensions are not supported"); + } + // Parse relations match plan.relations.len() { 1 => { match plan.relations[0].rel_type.as_ref() { Some(rt) => match rt { plan_rel::RelType::Rel(rel) => { - Ok(from_substrait_rel(ctx, rel, &function_extension).await?) + Ok(from_substrait_rel(ctx, rel, &extensions).await?) }, plan_rel::RelType::Root(root) => { - Ok(from_substrait_rel(ctx, root.input.as_ref().unwrap(), &function_extension).await?) + let plan = from_substrait_rel(ctx, root.input.as_ref().unwrap(), &extensions).await?; + if root.names.is_empty() { + // Backwards compatibility for plans missing names + return Ok(plan); + } + let renamed_schema = make_renamed_schema(plan.schema(), &root.names)?; + if renamed_schema.equivalent_names_and_types(plan.schema()) { + // Nothing to do if the schema is already equivalent + return Ok(plan); + } + match plan { + // If the last node of the plan produces expressions, bake the renames into those expressions. + // This isn't necessary for correctness, but helps with roundtrip tests. + LogicalPlan::Projection(p) => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(p.expr, p.input.schema(), renamed_schema.fields())?, p.input)?)), + LogicalPlan::Aggregate(a) => { + let (group_fields, expr_fields) = renamed_schema.fields().split_at(a.group_expr.len()); + let new_group_exprs = rename_expressions(a.group_expr, a.input.schema(), group_fields)?; + let new_aggr_exprs = rename_expressions(a.aggr_expr, a.input.schema(), expr_fields)?; + Ok(LogicalPlan::Aggregate(Aggregate::try_new(a.input, new_group_exprs, new_aggr_exprs)?)) + }, + // There are probably more plans where we could bake things in, can add them later as needed. + // Otherwise, add a new Project to handle the renaming. + _ => Ok(LogicalPlan::Projection(Projection::try_new(rename_expressions(plan.schema().columns().iter().map(|c| col(c.to_owned())), plan.schema(), renamed_schema.fields())?, Arc::new(plan))?)) + } } }, None => plan_err!("Cannot parse plan relation: None") @@ -219,38 +316,300 @@ pub async fn from_substrait_plan( } } +/// An ExprContainer is a container for a collection of expressions with a common input schema +/// +/// In addition, each expression is associated with a field, which defines the +/// expression's output. The data type and nullability of the field are calculated from the +/// expression and the input schema. However the names of the field (and its nested fields) are +/// derived from the Substrait message. +pub struct ExprContainer { + /// The input schema for the expressions + pub input_schema: DFSchemaRef, + /// The expressions + /// + /// Each item contains an expression and the field that defines the expected nullability and name of the expr's output + pub exprs: Vec<(Expr, Field)>, +} + +/// Convert Substrait ExtendedExpression to ExprContainer +/// +/// A Substrait ExtendedExpression message contains one or more expressions, +/// with names for the outputs, and an input schema. These pieces are all included +/// in the ExprContainer. +/// +/// This is a top-level message and can be used to send expressions (not plans) +/// between systems. This is often useful for scenarios like pushdown where filter +/// expressions need to be sent to remote systems. +pub async fn from_substrait_extended_expr( + ctx: &SessionContext, + extended_expr: &ExtendedExpression, +) -> Result { + // Register function extension + let extensions = Extensions::try_from(&extended_expr.extensions)?; + if !extensions.type_variations.is_empty() { + return not_impl_err!("Type variation extensions are not supported"); + } + + let input_schema = DFSchemaRef::new(match &extended_expr.base_schema { + Some(base_schema) => from_substrait_named_struct(base_schema, &extensions), + None => { + plan_err!("required property `base_schema` missing from Substrait ExtendedExpression message") + } + }?); + + // Parse expressions + let mut exprs = Vec::with_capacity(extended_expr.referred_expr.len()); + for (expr_idx, substrait_expr) in extended_expr.referred_expr.iter().enumerate() { + let scalar_expr = match &substrait_expr.expr_type { + Some(ExprType::Expression(scalar_expr)) => Ok(scalar_expr), + Some(ExprType::Measure(_)) => { + not_impl_err!("Measure expressions are not yet supported") + } + None => { + plan_err!("required property `expr_type` missing from Substrait ExpressionReference message") + } + }?; + let expr = + from_substrait_rex(ctx, scalar_expr, &input_schema, &extensions).await?; + let (output_type, expected_nullability) = + expr.data_type_and_nullable(&input_schema)?; + let output_field = Field::new("", output_type, expected_nullability); + let mut names_idx = 0; + let output_field = rename_field( + &output_field, + &substrait_expr.output_names, + expr_idx, + &mut names_idx, + /*rename_self=*/ true, + )?; + exprs.push((expr, output_field)); + } + + Ok(ExprContainer { + input_schema, + exprs, + }) +} + +pub fn apply_masking( + schema: DFSchema, + mask_expression: &::core::option::Option, +) -> Result { + match mask_expression { + Some(MaskExpression { select, .. }) => match &select.as_ref() { + Some(projection) => { + let column_indices: Vec = projection + .struct_items + .iter() + .map(|item| item.field as usize) + .collect(); + + let fields = column_indices + .iter() + .map(|i| schema.qualified_field(*i)) + .map(|(qualifier, field)| { + (qualifier.cloned(), Arc::new(field.clone())) + }) + .collect(); + + Ok(DFSchema::new_with_metadata( + fields, + schema.metadata().clone(), + )?) + } + None => Ok(schema), + }, + None => Ok(schema), + } +} + +/// Ensure the expressions have the right name(s) according to the new schema. +/// This includes the top-level (column) name, which will be renamed through aliasing if needed, +/// as well as nested names (if the expression produces any struct types), which will be renamed +/// through casting if needed. +fn rename_expressions( + exprs: impl IntoIterator, + input_schema: &DFSchema, + new_schema_fields: &[Arc], +) -> Result> { + exprs + .into_iter() + .zip(new_schema_fields) + .map(|(old_expr, new_field)| { + // Check if type (i.e. nested struct field names) match, use Cast to rename if needed + let new_expr = if &old_expr.get_type(input_schema)? != new_field.data_type() { + Expr::Cast(Cast::new( + Box::new(old_expr), + new_field.data_type().to_owned(), + )) + } else { + old_expr + }; + // Alias column if needed to fix the top-level name + match &new_expr { + // If expr is a column reference, alias_if_changed would cause an aliasing if the old expr has a qualifier + Expr::Column(c) if &c.name == new_field.name() => Ok(new_expr), + _ => new_expr.alias_if_changed(new_field.name().to_owned()), + } + }) + .collect() +} + +fn rename_field( + field: &Field, + dfs_names: &Vec, + unnamed_field_suffix: usize, // If Substrait doesn't provide a name, we'll use this "c{unnamed_field_suffix}" + name_idx: &mut usize, // Index into dfs_names + rename_self: bool, // Some fields (e.g. list items) don't have names in Substrait and this will be false to keep old name +) -> Result { + let name = if rename_self { + next_struct_field_name(unnamed_field_suffix, dfs_names, name_idx)? + } else { + field.name().to_string() + }; + match field.data_type() { + DataType::Struct(children) => { + let children = children + .iter() + .enumerate() + .map(|(child_idx, f)| { + rename_field( + f.as_ref(), + dfs_names, + child_idx, + name_idx, + /*rename_self=*/ true, + ) + }) + .collect::>()?; + Ok(field + .to_owned() + .with_name(name) + .with_data_type(DataType::Struct(children))) + } + DataType::List(inner) => { + let renamed_inner = rename_field( + inner.as_ref(), + dfs_names, + 0, + name_idx, + /*rename_self=*/ false, + )?; + Ok(field + .to_owned() + .with_data_type(DataType::List(FieldRef::new(renamed_inner))) + .with_name(name)) + } + DataType::LargeList(inner) => { + let renamed_inner = rename_field( + inner.as_ref(), + dfs_names, + 0, + name_idx, + /*rename_self= */ false, + )?; + Ok(field + .to_owned() + .with_data_type(DataType::LargeList(FieldRef::new(renamed_inner))) + .with_name(name)) + } + _ => Ok(field.to_owned().with_name(name)), + } +} + +/// Produce a version of the given schema with names matching the given list of names. +/// Substrait doesn't deal with column (incl. nested struct field) names within the schema, +/// but it does give us the list of expected names at the end of the plan, so we use this +/// to rename the schema to match the expected names. +fn make_renamed_schema( + schema: &DFSchemaRef, + dfs_names: &Vec, +) -> Result { + let mut name_idx = 0; + + let (qualifiers, fields): (_, Vec) = schema + .iter() + .enumerate() + .map(|(field_idx, (q, f))| { + let renamed_f = rename_field( + f.as_ref(), + dfs_names, + field_idx, + &mut name_idx, + /*rename_self=*/ true, + )?; + Ok((q.cloned(), renamed_f)) + }) + .collect::>>()? + .into_iter() + .unzip(); + + if name_idx != dfs_names.len() { + return substrait_err!( + "Names list must match exactly to nested schema, but found {} uses for {} names", + name_idx, + dfs_names.len()); + } + + DFSchema::from_field_specific_qualified_schema( + qualifiers, + &Arc::new(Schema::new(fields)), + ) +} + /// Convert Substrait Rel to DataFusion DataFrame +#[allow(deprecated)] #[async_recursion] pub async fn from_substrait_rel( ctx: &SessionContext, rel: &Rel, - extensions: &HashMap, + extensions: &Extensions, ) -> Result { - match &rel.rel_type { + let plan: Result = match &rel.rel_type { Some(RelType::Project(p)) => { if let Some(input) = p.input.as_ref() { let mut input = LogicalPlanBuilder::from( from_substrait_rel(ctx, input, extensions).await?, ); - let mut exprs: Vec = vec![]; - for e in &p.expressions { - let x = - from_substrait_rex(ctx, e, input.clone().schema(), extensions) + let original_schema = input.schema().clone(); + + // Ensure that all expressions have a unique display name, so that + // validate_unique_names does not fail when constructing the project. + let mut name_tracker = NameTracker::new(); + + // By default, a Substrait Project emits all inputs fields followed by all expressions. + // We build the explicit expressions first, and then the input expressions to avoid + // adding aliases to the explicit expressions (as part of ensuring unique names). + // + // This is helpful for plan visualization and tests, because when DataFusion produces + // Substrait Projects it adds an output mapping that excludes all input columns + // leaving only explicit expressions. + + let mut explicit_exprs: Vec = vec![]; + for expr in &p.expressions { + let e = + from_substrait_rex(ctx, expr, input.clone().schema(), extensions) .await?; // if the expression is WindowFunction, wrap in a Window relation - // before returning and do not add to list of this Projection's expression list - // otherwise, add expression to the Projection's expression list - match &*x { - Expr::WindowFunction(_) => { - input = input.window(vec![x.as_ref().clone()])?; - exprs.push(x.as_ref().clone()); - } - _ => { - exprs.push(x.as_ref().clone()); - } + if let Expr::WindowFunction(_) = &e { + // Adding the same expression here and in the project below + // works because the project's builder uses columnize_expr(..) + // to transform it into a column reference + input = input.window(vec![e.clone()])? } + explicit_exprs.push(name_tracker.get_uniquely_named_expr(e)?); + } + + let mut final_exprs: Vec = vec![]; + for index in 0..original_schema.fields().len() { + let e = Expr::Column(Column::from( + original_schema.qualified_field(index), + )); + final_exprs.push(name_tracker.get_uniquely_named_expr(e)?); } - input.project(exprs)?.build() + final_exprs.append(&mut explicit_exprs); + + input.project(final_exprs)?.build() } else { not_impl_err!("Projection without an input is not supported") } @@ -264,7 +623,7 @@ pub async fn from_substrait_rel( let expr = from_substrait_rex(ctx, condition, input.schema(), extensions) .await?; - input.filter(expr.as_ref().clone())?.build() + input.filter(expr)?.build() } else { not_impl_err!("Filter without an condition is not valid") } @@ -278,8 +637,8 @@ pub async fn from_substrait_rel( from_substrait_rel(ctx, input, extensions).await?, ); let offset = fetch.offset as usize; - // Since protobuf can't directly distinguish `None` vs `0` `None` is encoded as `MAX` - let count = if fetch.count as usize == usize::MAX { + // -1 means that ALL records should be returned + let count = if fetch.count == -1 { None } else { Some(fetch.count as usize) @@ -307,39 +666,48 @@ pub async fn from_substrait_rel( let input = LogicalPlanBuilder::from( from_substrait_rel(ctx, input, extensions).await?, ); - let mut group_expr = vec![]; - let mut aggr_expr = vec![]; + let mut ref_group_exprs = vec![]; + + for e in &agg.grouping_expressions { + let x = + from_substrait_rex(ctx, e, input.schema(), extensions).await?; + ref_group_exprs.push(x); + } + + let mut group_exprs = vec![]; + let mut aggr_exprs = vec![]; match agg.groupings.len() { 1 => { - for e in &agg.groupings[0].grouping_expressions { - let x = - from_substrait_rex(ctx, e, input.schema(), extensions) - .await?; - group_expr.push(x.as_ref().clone()); - } + group_exprs.extend_from_slice( + &from_substrait_grouping( + ctx, + &agg.groupings[0], + &ref_group_exprs, + input.schema(), + extensions, + ) + .await?, + ); } _ => { let mut grouping_sets = vec![]; for grouping in &agg.groupings { - let mut grouping_set = vec![]; - for e in &grouping.grouping_expressions { - let x = from_substrait_rex( - ctx, - e, - input.schema(), - extensions, - ) - .await?; - grouping_set.push(x.as_ref().clone()); - } + let grouping_set = from_substrait_grouping( + ctx, + grouping, + &ref_group_exprs, + input.schema(), + extensions, + ) + .await?; grouping_sets.push(grouping_set); } // Single-element grouping expression of type Expr::GroupingSet. // Note that GroupingSet::Rollup would become GroupingSet::GroupingSets, when // parsed by the producer and consumer, since Substrait does not have a type dedicated // to ROLLUP. Only vector of Groupings (grouping sets) is available. - group_expr.push(Expr::GroupingSet(GroupingSet::GroupingSets( + group_exprs.push(Expr::GroupingSet(GroupingSet::GroupingSets( grouping_sets, ))); } @@ -349,9 +717,7 @@ pub async fn from_substrait_rel( let filter = match &m.filter { Some(fil) => Some(Box::new( from_substrait_rex(ctx, fil, input.schema(), extensions) - .await? - .as_ref() - .clone(), + .await?, )), None => None, }; @@ -370,14 +736,27 @@ pub async fn from_substrait_rel( } _ => false, }; + let order_by = if !f.sorts.is_empty() { + Some( + from_substrait_sorts( + ctx, + &f.sorts, + input.schema(), + extensions, + ) + .await?, + ) + } else { + None + }; + from_substrait_agg_func( ctx, f, input.schema(), extensions, filter, - // TODO: Add parsing of order_by also - None, + order_by, distinct, ) .await @@ -386,10 +765,9 @@ pub async fn from_substrait_rel( "Aggregate without aggregate function is not supported" ), }; - aggr_expr.push(agg_func?.as_ref().clone()); + aggr_exprs.push(agg_func?.as_ref().clone()); } - - input.aggregate(group_expr, aggr_expr)?.build() + input.aggregate(group_exprs, aggr_exprs)?.build() } else { not_impl_err!("Aggregate without an input is not valid") } @@ -407,6 +785,8 @@ pub async fn from_substrait_rel( let right = LogicalPlanBuilder::from( from_substrait_rel(ctx, join.right.as_ref().unwrap(), extensions).await?, ); + let (left, right) = requalify_sides_if_needed(left, right)?; + let join_type = from_substrait_jointype(join.r#type)?; // The join condition expression needs full input schema and not the output schema from join since we lose columns from // certain join types such as semi and anti joins @@ -436,95 +816,199 @@ pub async fn from_substrait_rel( )? .build() } - None => plan_err!("JoinRel without join condition is not allowed"), + None => { + let on: Vec = vec![]; + left.join_detailed( + right.build()?, + join_type, + (on.clone(), on), + None, + false, + )? + .build() + } } } Some(RelType::Cross(cross)) => { - let left: LogicalPlanBuilder = LogicalPlanBuilder::from( + let left = LogicalPlanBuilder::from( from_substrait_rel(ctx, cross.left.as_ref().unwrap(), extensions).await?, ); - let right = + let right = LogicalPlanBuilder::from( from_substrait_rel(ctx, cross.right.as_ref().unwrap(), extensions) - .await?; - left.cross_join(right)?.build() + .await?, + ); + let (left, right) = requalify_sides_if_needed(left, right)?; + left.cross_join(right.build()?)?.build() } - Some(RelType::Read(read)) => match &read.as_ref().read_type { - Some(ReadType::NamedTable(nt)) => { - let table_reference = match nt.names.len() { - 0 => { - return plan_err!("No table name found in NamedTable"); + Some(RelType::Read(read)) => { + fn read_with_schema( + df: DataFrame, + schema: DFSchema, + projection: &Option, + ) -> Result { + ensure_schema_compatability(df.schema().to_owned(), schema.clone())?; + + let schema = apply_masking(schema, projection)?; + + apply_projection(df, schema) + } + + let named_struct = read.base_schema.as_ref().ok_or_else(|| { + substrait_datafusion_err!("No base schema provided for Read Relation") + })?; + + let substrait_schema = from_substrait_named_struct(named_struct, extensions)?; + + match &read.as_ref().read_type { + Some(ReadType::NamedTable(nt)) => { + let table_reference = match nt.names.len() { + 0 => { + return plan_err!("No table name found in NamedTable"); + } + 1 => TableReference::Bare { + table: nt.names[0].clone().into(), + }, + 2 => TableReference::Partial { + schema: nt.names[0].clone().into(), + table: nt.names[1].clone().into(), + }, + _ => TableReference::Full { + catalog: nt.names[0].clone().into(), + schema: nt.names[1].clone().into(), + table: nt.names[2].clone().into(), + }, + }; + + let t = ctx.table(table_reference.clone()).await?; + + let substrait_schema = + substrait_schema.replace_qualifier(table_reference); + + read_with_schema(t, substrait_schema, &read.projection) + } + Some(ReadType::VirtualTable(vt)) => { + if vt.values.is_empty() { + return Ok(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: DFSchemaRef::new(substrait_schema), + })); } - 1 => TableReference::Bare { - table: nt.names[0].clone().into(), - }, - 2 => TableReference::Partial { - schema: nt.names[0].clone().into(), - table: nt.names[1].clone().into(), - }, - _ => TableReference::Full { - catalog: nt.names[0].clone().into(), - schema: nt.names[1].clone().into(), - table: nt.names[2].clone().into(), - }, - }; - let t = ctx.table(table_reference).await?; - let t = t.into_optimized_plan()?; - match &read.projection { - Some(MaskExpression { select, .. }) => match &select.as_ref() { - Some(projection) => { - let column_indices: Vec = projection - .struct_items - .iter() - .map(|item| item.field as usize) - .collect(); - match &t { - LogicalPlan::TableScan(scan) => { - let fields = column_indices - .iter() - .map(|i| { - scan.projected_schema.qualified_field(*i) - }) - .map(|(qualifier, field)| { - (qualifier.cloned(), Arc::new(field.clone())) - }) - .collect(); - let mut scan = scan.clone(); - scan.projection = Some(column_indices); - scan.projected_schema = - DFSchemaRef::new(DFSchema::new_with_metadata( - fields, - HashMap::new(), - )?); - Ok(LogicalPlan::TableScan(scan)) - } - _ => plan_err!("unexpected plan for table"), - } + + let values = vt + .values + .iter() + .map(|row| { + let mut name_idx = 0; + let lits = row + .fields + .iter() + .map(|lit| { + name_idx += 1; // top-level names are provided through schema + Ok(Expr::Literal(from_substrait_literal( + lit, + extensions, + &named_struct.names, + &mut name_idx, + )?)) + }) + .collect::>()?; + if name_idx != named_struct.names.len() { + return substrait_err!( + "Names list must match exactly to nested schema, but found {} uses for {} names", + name_idx, + named_struct.names.len() + ); } - _ => Ok(t), - }, - _ => Ok(t), + Ok(lits) + }) + .collect::>()?; + + Ok(LogicalPlan::Values(Values { + schema: DFSchemaRef::new(substrait_schema), + values, + })) + } + Some(ReadType::LocalFiles(lf)) => { + fn extract_filename(name: &str) -> Option { + let corrected_url = if name.starts_with("file://") + && !name.starts_with("file:///") + { + name.replacen("file://", "file:///", 1) + } else { + name.to_string() + }; + + Url::parse(&corrected_url).ok().and_then(|url| { + let path = url.path(); + std::path::Path::new(path) + .file_name() + .map(|filename| filename.to_string_lossy().to_string()) + }) + } + + // we could use the file name to check the original table provider + // TODO: currently does not support multiple local files + let filename: Option = + lf.items.first().and_then(|x| match x.path_type.as_ref() { + Some(UriFile(name)) => extract_filename(name), + _ => None, + }); + + if lf.items.len() > 1 || filename.is_none() { + return not_impl_err!("Only single file reads are supported"); + } + let name = filename.unwrap(); + // directly use unwrap here since we could determine it is a valid one + let table_reference = TableReference::Bare { table: name.into() }; + let t = ctx.table(table_reference.clone()).await?; + + let substrait_schema = + substrait_schema.replace_qualifier(table_reference); + + read_with_schema(t, substrait_schema, &read.projection) + } + _ => { + not_impl_err!("Unsupported ReadType: {:?}", &read.as_ref().read_type) } } - _ => not_impl_err!("Only NamedTable reads are supported"), - }, + } Some(RelType::Set(set)) => match set_rel::SetOp::try_from(set.op) { - Ok(set_op) => match set_op { - set_rel::SetOp::UnionAll => { - if !set.inputs.is_empty() { - let mut union_builder = Ok(LogicalPlanBuilder::from( - from_substrait_rel(ctx, &set.inputs[0], extensions).await?, - )); - for input in &set.inputs[1..] { - union_builder = union_builder? - .union(from_substrait_rel(ctx, input, extensions).await?); + Ok(set_op) => { + if set.inputs.len() < 2 { + substrait_err!("Set operation requires at least two inputs") + } else { + match set_op { + set_rel::SetOp::UnionAll => { + union_rels(&set.inputs, ctx, extensions, true).await } - union_builder?.build() - } else { - not_impl_err!("Union relation requires at least one input") + set_rel::SetOp::UnionDistinct => { + union_rels(&set.inputs, ctx, extensions, false).await + } + set_rel::SetOp::IntersectionPrimary => { + LogicalPlanBuilder::intersect( + from_substrait_rel(ctx, &set.inputs[0], extensions) + .await?, + union_rels(&set.inputs[1..], ctx, extensions, true) + .await?, + false, + ) + } + set_rel::SetOp::IntersectionMultiset => { + intersect_rels(&set.inputs, ctx, extensions, false).await + } + set_rel::SetOp::IntersectionMultisetAll => { + intersect_rels(&set.inputs, ctx, extensions, true).await + } + set_rel::SetOp::MinusPrimary => { + except_rels(&set.inputs, ctx, extensions, false).await + } + set_rel::SetOp::MinusPrimaryAll => { + except_rels(&set.inputs, ctx, extensions, true).await + } + _ => not_impl_err!("Unsupported set operator: {set_op:?}"), } } - _ => not_impl_err!("Unsupported set operator: {set_op:?}"), - }, + } Err(e) => not_impl_err!("Invalid set operation type {}: {e}", set.op), }, Some(RelType::ExtensionLeaf(extension)) => { @@ -551,7 +1035,8 @@ pub async fn from_substrait_rel( ); }; let input_plan = from_substrait_rel(ctx, input_rel, extensions).await?; - let plan = plan.from_template(&plan.expressions(), &[input_plan]); + let plan = + plan.with_exprs_and_inputs(plan.expressions(), vec![input_plan])?; Ok(LogicalPlan::Extension(Extension { node: plan })) } Some(RelType::ExtensionMulti(extension)) => { @@ -567,7 +1052,7 @@ pub async fn from_substrait_rel( let input_plan = from_substrait_rel(ctx, input, extensions).await?; inputs.push(input_plan); } - let plan = plan.from_template(&plan.expressions(), &inputs); + let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?; Ok(LogicalPlan::Extension(Extension { node: plan })) } Some(RelType::Exchange(exchange)) => { @@ -610,6 +1095,278 @@ pub async fn from_substrait_rel( })) } _ => not_impl_err!("Unsupported RelType: {:?}", rel.rel_type), + }; + apply_emit_kind(retrieve_rel_common(rel), plan?) +} + +fn retrieve_rel_common(rel: &Rel) -> Option<&RelCommon> { + match rel.rel_type.as_ref() { + None => None, + Some(rt) => match rt { + RelType::Read(r) => r.common.as_ref(), + RelType::Filter(f) => f.common.as_ref(), + RelType::Fetch(f) => f.common.as_ref(), + RelType::Aggregate(a) => a.common.as_ref(), + RelType::Sort(s) => s.common.as_ref(), + RelType::Join(j) => j.common.as_ref(), + RelType::Project(p) => p.common.as_ref(), + RelType::Set(s) => s.common.as_ref(), + RelType::ExtensionSingle(e) => e.common.as_ref(), + RelType::ExtensionMulti(e) => e.common.as_ref(), + RelType::ExtensionLeaf(e) => e.common.as_ref(), + RelType::Cross(c) => c.common.as_ref(), + RelType::Reference(_) => None, + RelType::Write(w) => w.common.as_ref(), + RelType::Ddl(d) => d.common.as_ref(), + RelType::HashJoin(j) => j.common.as_ref(), + RelType::MergeJoin(j) => j.common.as_ref(), + RelType::NestedLoopJoin(j) => j.common.as_ref(), + RelType::Window(w) => w.common.as_ref(), + RelType::Exchange(e) => e.common.as_ref(), + RelType::Expand(e) => e.common.as_ref(), + }, + } +} + +fn retrieve_emit_kind(rel_common: Option<&RelCommon>) -> EmitKind { + // the default EmitKind is Direct if it is not set explicitly + let default = EmitKind::Direct(rel_common::Direct {}); + rel_common + .and_then(|rc| rc.emit_kind.as_ref()) + .map_or(default, |ek| ek.clone()) +} + +fn contains_volatile_expr(proj: &Projection) -> bool { + proj.expr.iter().any(|e| e.is_volatile()) +} + +fn apply_emit_kind( + rel_common: Option<&RelCommon>, + plan: LogicalPlan, +) -> Result { + match retrieve_emit_kind(rel_common) { + EmitKind::Direct(_) => Ok(plan), + EmitKind::Emit(Emit { output_mapping }) => { + // It is valid to reference the same field multiple times in the Emit + // In this case, we need to provide unique names to avoid collisions + let mut name_tracker = NameTracker::new(); + match plan { + // To avoid adding a projection on top of a projection, we apply special case + // handling to flatten Substrait Emits. This is only applicable if none of the + // expressions in the projection are volatile. This is to avoid issues like + // converting a single call of the random() function into multiple calls due to + // duplicate fields in the output_mapping. + LogicalPlan::Projection(proj) if !contains_volatile_expr(&proj) => { + let mut exprs: Vec = vec![]; + for field in output_mapping { + let expr = proj.expr + .get(field as usize) + .ok_or_else(|| substrait_datafusion_err!( + "Emit output field {} cannot be resolved in input schema {}", + field, proj.input.schema().clone() + ))?; + exprs.push(name_tracker.get_uniquely_named_expr(expr.clone())?); + } + + let input = Arc::unwrap_or_clone(proj.input); + project(input, exprs) + } + // Otherwise we just handle the output_mapping as a projection + _ => { + let input_schema = plan.schema(); + + let mut exprs: Vec = vec![]; + for index in output_mapping.into_iter() { + let column = Expr::Column(Column::from( + input_schema.qualified_field(index as usize), + )); + let expr = name_tracker.get_uniquely_named_expr(column)?; + exprs.push(expr); + } + + project(plan, exprs) + } + } + } + } +} + +struct NameTracker { + seen_names: HashSet, +} + +enum NameTrackerStatus { + NeverSeen, + SeenBefore, +} + +impl NameTracker { + fn new() -> Self { + NameTracker { + seen_names: HashSet::default(), + } + } + fn get_unique_name(&mut self, name: String) -> (String, NameTrackerStatus) { + match self.seen_names.insert(name.clone()) { + true => (name, NameTrackerStatus::NeverSeen), + false => { + let mut counter = 0; + loop { + let candidate_name = format!("{}__temp__{}", name, counter); + if self.seen_names.insert(candidate_name.clone()) { + return (candidate_name, NameTrackerStatus::SeenBefore); + } + counter += 1; + } + } + } + } + + fn get_uniquely_named_expr(&mut self, expr: Expr) -> Result { + match self.get_unique_name(expr.name_for_alias()?) { + (_, NameTrackerStatus::NeverSeen) => Ok(expr), + (name, NameTrackerStatus::SeenBefore) => Ok(expr.alias(name)), + } + } +} + +/// Ensures that the given Substrait schema is compatible with the schema as given by DataFusion +/// +/// This means: +/// 1. All fields present in the Substrait schema are present in the DataFusion schema. The +/// DataFusion schema may have MORE fields, but not the other way around. +/// 2. All fields are compatible. See [`ensure_field_compatability`] for details +fn ensure_schema_compatability( + table_schema: DFSchema, + substrait_schema: DFSchema, +) -> Result<()> { + substrait_schema + .strip_qualifiers() + .fields() + .iter() + .try_for_each(|substrait_field| { + let df_field = + table_schema.field_with_unqualified_name(substrait_field.name())?; + ensure_field_compatability(df_field, substrait_field) + }) +} + +/// This function returns a DataFrame with fields adjusted if necessary in the event that the +/// Substrait schema is a subset of the DataFusion schema. +fn apply_projection(table: DataFrame, substrait_schema: DFSchema) -> Result { + let df_schema = table.schema().to_owned(); + + let t = table.into_unoptimized_plan(); + + if df_schema.logically_equivalent_names_and_types(&substrait_schema) { + return Ok(t); + } + + match t { + LogicalPlan::TableScan(mut scan) => { + let column_indices: Vec = substrait_schema + .strip_qualifiers() + .fields() + .iter() + .map(|substrait_field| { + Ok(df_schema + .index_of_column_by_name(None, substrait_field.name().as_str()) + .unwrap()) + }) + .collect::>()?; + + let fields = column_indices + .iter() + .map(|i| df_schema.qualified_field(*i)) + .map(|(qualifier, field)| (qualifier.cloned(), Arc::new(field.clone()))) + .collect(); + + scan.projected_schema = DFSchemaRef::new(DFSchema::new_with_metadata( + fields, + df_schema.metadata().clone(), + )?); + scan.projection = Some(column_indices); + + Ok(LogicalPlan::TableScan(scan)) + } + _ => plan_err!("DataFrame passed to apply_projection must be a TableScan"), + } +} + +/// Ensures that the given Substrait field is compatible with the given DataFusion field +/// +/// A field is compatible between Substrait and DataFusion if: +/// 1. They have logically equivalent types. +/// 2. They have the same nullability OR the Substrait field is nullable and the DataFusion fields +/// is not nullable. +/// +/// If a Substrait field is not nullable, the Substrait plan may be built around assuming it is not +/// nullable. As such if DataFusion has that field as nullable the plan should be rejected. +fn ensure_field_compatability( + datafusion_field: &Field, + substrait_field: &Field, +) -> Result<()> { + if !DFSchema::datatype_is_logically_equal( + datafusion_field.data_type(), + substrait_field.data_type(), + ) { + return substrait_err!( + "Field '{}' in Substrait schema has a different type ({}) than the corresponding field in the table schema ({}).", + substrait_field.name(), + substrait_field.data_type(), + datafusion_field.data_type() + ); + } + + if !compatible_nullabilities( + datafusion_field.is_nullable(), + substrait_field.is_nullable(), + ) { + // TODO: from_substrait_struct_type needs to be updated to set the nullability correctly. It defaults to true for now. + return substrait_err!( + "Field '{}' is nullable in the DataFusion schema but not nullable in the Substrait schema.", + substrait_field.name() + ); + } + Ok(()) +} + +/// Returns true if the DataFusion and Substrait nullabilities are compatible, false otherwise +fn compatible_nullabilities( + datafusion_nullability: bool, + substrait_nullability: bool, +) -> bool { + // DataFusion and Substrait have the same nullability + (datafusion_nullability == substrait_nullability) + // DataFusion is not nullable and Substrait is nullable + || (!datafusion_nullability && substrait_nullability) +} + +/// (Re)qualify the sides of a join if needed, i.e. if the columns from one side would otherwise +/// conflict with the columns from the other. +/// Substrait doesn't currently allow specifying aliases, neither for columns nor for tables. For +/// Substrait the names don't matter since it only refers to columns by indices, however DataFusion +/// requires columns to be uniquely identifiable, in some places (see e.g. DFSchema::check_names). +fn requalify_sides_if_needed( + left: LogicalPlanBuilder, + right: LogicalPlanBuilder, +) -> Result<(LogicalPlanBuilder, LogicalPlanBuilder)> { + let left_cols = left.schema().columns(); + let right_cols = right.schema().columns(); + if left_cols.iter().any(|l| { + right_cols.iter().any(|r| { + l == r || (l.name == r.name && (l.relation.is_none() || r.relation.is_none())) + }) + }) { + // These names have no connection to the original plan, but they'll make the columns + // (mostly) unique. There may be cases where this still causes duplicates, if either left + // or right side itself contains duplicate names with different qualifiers. + Ok(( + left.alias(TableReference::bare("left"))?, + right.alias(TableReference::bare("right"))?, + )) + } else { + Ok((left, right)) } } @@ -620,8 +1377,9 @@ fn from_substrait_jointype(join_type: i32) -> Result { join_rel::JoinType::Left => Ok(JoinType::Left), join_rel::JoinType::Right => Ok(JoinType::Right), join_rel::JoinType::Outer => Ok(JoinType::Full), - join_rel::JoinType::Anti => Ok(JoinType::LeftAnti), - join_rel::JoinType::Semi => Ok(JoinType::LeftSemi), + join_rel::JoinType::LeftAnti => Ok(JoinType::LeftAnti), + join_rel::JoinType::LeftSemi => Ok(JoinType::LeftSemi), + join_rel::JoinType::LeftMark => Ok(JoinType::LeftMark), _ => plan_err!("unsupported join type {substrait_join_type:?}"), } } else { @@ -634,9 +1392,9 @@ pub async fn from_substrait_sorts( ctx: &SessionContext, substrait_sorts: &Vec, input_schema: &DFSchema, - extensions: &HashMap, -) -> Result> { - let mut sorts: Vec = vec![]; + extensions: &Extensions, +) -> Result> { + let mut sorts: Vec = vec![]; for s in substrait_sorts { let expr = from_substrait_rex(ctx, s.expr.as_ref().unwrap(), input_schema, extensions) @@ -670,11 +1428,11 @@ pub async fn from_substrait_sorts( None => not_impl_err!("Sort without sort kind is invalid"), }; let (asc, nulls_first) = asc_nullfirst.unwrap(); - sorts.push(Expr::Sort(Sort { - expr: Box::new(expr.as_ref().clone()), + sorts.push(Sort { + expr, asc, nulls_first, - })); + }); } Ok(sorts) } @@ -684,22 +1442,22 @@ pub async fn from_substrait_rex_vec( ctx: &SessionContext, exprs: &Vec, input_schema: &DFSchema, - extensions: &HashMap, + extensions: &Extensions, ) -> Result> { let mut expressions: Vec = vec![]; for expr in exprs { let expression = from_substrait_rex(ctx, expr, input_schema, extensions).await?; - expressions.push(expression.as_ref().clone()); + expressions.push(expression); } Ok(expressions) } /// Convert Substrait FunctionArguments to DataFusion Exprs -pub async fn from_substriat_func_args( +pub async fn from_substrait_func_args( ctx: &SessionContext, arguments: &Vec, input_schema: &DFSchema, - extensions: &HashMap, + extensions: &Extensions, ) -> Result> { let mut args: Vec = vec![]; for arg in arguments { @@ -707,11 +1465,9 @@ pub async fn from_substriat_func_args( Some(ArgType::Value(e)) => { from_substrait_rex(ctx, e, input_schema, extensions).await } - _ => { - not_impl_err!("Aggregated function argument non-Value type not supported") - } + _ => not_impl_err!("Function argument non-Value type not supported"), }; - args.push(arg_expr?.as_ref().clone()); + args.push(arg_expr?); } Ok(args) } @@ -721,44 +1477,37 @@ pub async fn from_substrait_agg_func( ctx: &SessionContext, f: &AggregateFunction, input_schema: &DFSchema, - extensions: &HashMap, + extensions: &Extensions, filter: Option>, - order_by: Option>, + order_by: Option>, distinct: bool, ) -> Result> { - let mut args: Vec = vec![]; - for arg in &f.arguments { - let arg_expr = match &arg.arg_type { - Some(ArgType::Value(e)) => { - from_substrait_rex(ctx, e, input_schema, extensions).await - } - _ => { - not_impl_err!("Aggregated function argument non-Value type not supported") - } - }; - args.push(arg_expr?.as_ref().clone()); - } + let args = + from_substrait_func_args(ctx, &f.arguments, input_schema, extensions).await?; - let Some(function_name) = extensions.get(&f.function_reference) else { + let Some(function_name) = extensions.functions.get(&f.function_reference) else { return plan_err!( "Aggregate function not registered: function anchor = {:?}", f.function_reference ); }; + let function_name = substrait_fun_name(function_name); // try udaf first, then built-in aggr fn. if let Ok(fun) = ctx.udaf(function_name) { + // deal with situation that count(*) got no arguments + let args = if fun.name() == "count" && args.is_empty() { + vec![Expr::Literal(ScalarValue::Int64(Some(1)))] + } else { + args + }; + Ok(Arc::new(Expr::AggregateFunction( expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by, None), ))) - } else if let Ok(fun) = aggregate_function::AggregateFunction::from_str(function_name) - { - Ok(Arc::new(Expr::AggregateFunction( - expr::AggregateFunction::new(fun, args, distinct, filter, order_by, None), - ))) } else { not_impl_err!( - "Aggregated function {} is not supported: function anchor = {:?}", + "Aggregate function {} is not supported: function anchor = {:?}", function_name, f.function_reference ) @@ -771,18 +1520,16 @@ pub async fn from_substrait_rex( ctx: &SessionContext, e: &Expression, input_schema: &DFSchema, - extensions: &HashMap, -) -> Result> { + extensions: &Extensions, +) -> Result { match &e.rex_type { Some(RexType::SingularOrList(s)) => { let substrait_expr = s.value.as_ref().unwrap(); let substrait_list = s.options.as_ref(); - Ok(Arc::new(Expr::InList(InList { + Ok(Expr::InList(InList { expr: Box::new( from_substrait_rex(ctx, substrait_expr, input_schema, extensions) - .await? - .as_ref() - .clone(), + .await?, ), list: from_substrait_rex_vec( ctx, @@ -792,11 +1539,11 @@ pub async fn from_substrait_rex( ) .await?, negated: false, - }))) + })) + } + Some(RexType::Selection(field_ref)) => { + Ok(from_substrait_field_reference(field_ref, input_schema)?) } - Some(RexType::Selection(field_ref)) => Ok(Arc::new( - from_substrait_field_reference(field_ref, input_schema)?, - )), Some(RexType::IfThen(if_then)) => { // Parse `ifs` // If the first element does not have a `then` part, then we can assume it's a base expression @@ -813,9 +1560,7 @@ pub async fn from_substrait_rex( input_schema, extensions, ) - .await? - .as_ref() - .clone(), + .await?, )); continue; } @@ -828,9 +1573,7 @@ pub async fn from_substrait_rex( input_schema, extensions, ) - .await? - .as_ref() - .clone(), + .await?, ), Box::new( from_substrait_rex( @@ -839,115 +1582,79 @@ pub async fn from_substrait_rex( input_schema, extensions, ) - .await? - .as_ref() - .clone(), + .await?, ), )); } // Parse `else` let else_expr = match &if_then.r#else { Some(e) => Some(Box::new( - from_substrait_rex(ctx, e, input_schema, extensions) - .await? - .as_ref() - .clone(), + from_substrait_rex(ctx, e, input_schema, extensions).await?, )), None => None, }; - Ok(Arc::new(Expr::Case(Case { + Ok(Expr::Case(Case { expr, when_then_expr, else_expr, - }))) + })) } Some(RexType::ScalarFunction(f)) => { - let fn_name = extensions.get(&f.function_reference).ok_or_else(|| { - DataFusionError::NotImplemented(format!( - "Aggregated function not found: function reference = {:?}", - f.function_reference - )) - })?; - - // Convert function arguments from Substrait to DataFusion - async fn decode_arguments( - ctx: &SessionContext, - input_schema: &DFSchema, - extensions: &HashMap, - function_args: &[FunctionArgument], - ) -> Result> { - let mut args = Vec::with_capacity(function_args.len()); - for arg in function_args { - let arg_expr = match &arg.arg_type { - Some(ArgType::Value(e)) => { - from_substrait_rex(ctx, e, input_schema, extensions).await - } - _ => not_impl_err!( - "Aggregated function argument non-Value type not supported" - ), - }?; - args.push(arg_expr.as_ref().clone()); - } - Ok(args) - } + let Some(fn_name) = extensions.functions.get(&f.function_reference) else { + return plan_err!( + "Scalar function not found: function reference = {:?}", + f.function_reference + ); + }; + let fn_name = substrait_fun_name(fn_name); - let fn_type = scalar_function_type_from_str(ctx, fn_name)?; - match fn_type { - ScalarFunctionType::Udf(fun) => { - let args = decode_arguments( - ctx, - input_schema, - extensions, - f.arguments.as_slice(), - ) + let args = + from_substrait_func_args(ctx, &f.arguments, input_schema, extensions) .await?; - Ok(Arc::new(Expr::ScalarFunction( - expr::ScalarFunction::new_udf(fun, args), - ))) + + // try to first match the requested function into registered udfs, then built-in ops + // and finally built-in expressions + if let Some(func) = ctx.state().scalar_functions().get(fn_name) { + Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( + func.to_owned(), + args, + ))) + } else if let Some(op) = name_to_op(fn_name) { + if f.arguments.len() < 2 { + return not_impl_err!( + "Expect at least two arguments for binary operator {op:?}, the provided number of operators is {:?}", + f.arguments.len() + ); } - ScalarFunctionType::Op(op) => { - if f.arguments.len() != 2 { - return not_impl_err!( - "Expect two arguments for binary operator {op:?}" - ); - } - let lhs = &f.arguments[0].arg_type; - let rhs = &f.arguments[1].arg_type; - - match (lhs, rhs) { - (Some(ArgType::Value(l)), Some(ArgType::Value(r))) => { - Ok(Arc::new(Expr::BinaryExpr(BinaryExpr { - left: Box::new( - from_substrait_rex(ctx, l, input_schema, extensions) - .await? - .as_ref() - .clone(), - ), + // Some expressions are binary in DataFusion but take in a variadic number of args in Substrait. + // In those cases we iterate through all the arguments, applying the binary expression against them all + let combined_expr = args + .into_iter() + .fold(None, |combined_expr: Option, arg: Expr| { + Some(match combined_expr { + Some(expr) => Expr::BinaryExpr(BinaryExpr { + left: Box::new(expr), op, - right: Box::new( - from_substrait_rex(ctx, r, input_schema, extensions) - .await? - .as_ref() - .clone(), - ), - }))) - } - (l, r) => not_impl_err!( - "Invalid arguments for binary expression: {l:?} and {r:?}" - ), - } - } - ScalarFunctionType::Expr(builder) => { - builder.build(ctx, f, input_schema, extensions).await - } + right: Box::new(arg), + }), + None => arg, + }) + }) + .unwrap(); + + Ok(combined_expr) + } else if let Some(builder) = BuiltinExprBuilder::try_from_name(fn_name) { + builder.build(ctx, f, input_schema, extensions).await + } else { + not_impl_err!("Unsupported function name: {fn_name:?}") } } Some(RexType::Literal(lit)) => { - let scalar_value = from_substrait_literal(lit)?; - Ok(Arc::new(Expr::Literal(scalar_value))) + let scalar_value = from_substrait_literal_without_names(lit, extensions)?; + Ok(Expr::Literal(scalar_value)) } Some(RexType::Cast(cast)) => match cast.as_ref().r#type.as_ref() { - Some(output_type) => Ok(Arc::new(Expr::Cast(Cast::new( + Some(output_type) => Ok(Expr::Cast(Cast::new( Box::new( from_substrait_rex( ctx, @@ -955,37 +1662,61 @@ pub async fn from_substrait_rex( input_schema, extensions, ) - .await? - .as_ref() - .clone(), + .await?, ), - from_substrait_type(output_type)?, - )))), - None => substrait_err!("Cast experssion without output type is not allowed"), + from_substrait_type_without_names(output_type, extensions)?, + ))), + None => substrait_err!("Cast expression without output type is not allowed"), }, Some(RexType::WindowFunction(window)) => { - let fun = match extensions.get(&window.function_reference) { - Some(function_name) => Ok(find_df_window_func(function_name)), - None => not_impl_err!( - "Window function not found: function anchor = {:?}", - &window.function_reference - ), + let Some(fn_name) = extensions.functions.get(&window.function_reference) + else { + return plan_err!( + "Window function not found: function reference = {:?}", + window.function_reference + ); }; + let fn_name = substrait_fun_name(fn_name); + + // check udwf first, then udaf, then built-in window and aggregate functions + let fun = if let Ok(udwf) = ctx.udwf(fn_name) { + Ok(WindowFunctionDefinition::WindowUDF(udwf)) + } else if let Ok(udaf) = ctx.udaf(fn_name) { + Ok(WindowFunctionDefinition::AggregateUDF(udaf)) + } else if let Some(fun) = find_df_window_func(fn_name) { + Ok(fun) + } else { + not_impl_err!( + "Window function {} is not supported: function anchor = {:?}", + fn_name, + window.function_reference + ) + }?; + let order_by = from_substrait_sorts(ctx, &window.sorts, input_schema, extensions) .await?; - // Substrait does not encode WindowFrameUnits so we're using a simple logic to determine the units - // If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary - // If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row - // TODO: Consider the cases where window frame is specified in query and is different from default - let units = if order_by.is_empty() { - WindowFrameUnits::Rows - } else { - WindowFrameUnits::Range - }; - Ok(Arc::new(Expr::WindowFunction(expr::WindowFunction { - fun: fun?.unwrap(), - args: from_substriat_func_args( + + let bound_units = + match BoundsType::try_from(window.bounds_type).map_err(|e| { + plan_datafusion_err!("Invalid bound type {}: {e}", window.bounds_type) + })? { + BoundsType::Rows => WindowFrameUnits::Rows, + BoundsType::Range => WindowFrameUnits::Range, + BoundsType::Unspecified => { + // If the plan does not specify the bounds type, then we use a simple logic to determine the units + // If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary + // If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row + if order_by.is_empty() { + WindowFrameUnits::Rows + } else { + WindowFrameUnits::Range + } + } + }; + Ok(Expr::WindowFunction(expr::WindowFunction { + fun, + args: from_substrait_func_args( ctx, &window.arguments, input_schema, @@ -1001,21 +1732,18 @@ pub async fn from_substrait_rex( .await?, order_by, window_frame: datafusion::logical_expr::WindowFrame::new_bounds( - units, + bound_units, from_substrait_bound(&window.lower_bound, true)?, from_substrait_bound(&window.upper_bound, false)?, ), null_treatment: None, - }))) + })) } Some(RexType::Subquery(subquery)) => match &subquery.as_ref().subquery_type { Some(subquery_type) => match subquery_type { SubqueryType::InPredicate(in_predicate) => { if in_predicate.needles.len() != 1 { - Err(DataFusionError::Substrait( - "InPredicate Subquery type must have exactly one Needle expression" - .to_string(), - )) + substrait_err!("InPredicate Subquery type must have exactly one Needle expression") } else { let needle_expr = &in_predicate.needles[0]; let haystack_expr = &in_predicate.haystack; @@ -1024,7 +1752,7 @@ pub async fn from_substrait_rex( from_substrait_rel(ctx, haystack_expr, extensions) .await?; let outer_refs = haystack_expr.all_out_ref_exprs(); - Ok(Arc::new(Expr::InSubquery(InSubquery { + Ok(Expr::InSubquery(InSubquery { expr: Box::new( from_substrait_rex( ctx, @@ -1032,92 +1760,172 @@ pub async fn from_substrait_rex( input_schema, extensions, ) - .await? - .as_ref() - .clone(), + .await?, ), subquery: Subquery { subquery: Arc::new(haystack_expr), outer_ref_columns: outer_refs, }, negated: false, - }))) + })) } else { substrait_err!("InPredicate Subquery type must have a Haystack expression") } } } - _ => substrait_err!("Subquery type not implemented"), + SubqueryType::Scalar(query) => { + let plan = from_substrait_rel( + ctx, + &(query.input.clone()).unwrap_or_default(), + extensions, + ) + .await?; + let outer_ref_columns = plan.all_out_ref_exprs(); + Ok(Expr::ScalarSubquery(Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + })) + } + SubqueryType::SetPredicate(predicate) => { + match predicate.predicate_op() { + // exist + PredicateOp::Exists => { + let relation = &predicate.tuples; + let plan = from_substrait_rel( + ctx, + &relation.clone().unwrap_or_default(), + extensions, + ) + .await?; + let outer_ref_columns = plan.all_out_ref_exprs(); + Ok(Expr::Exists(Exists::new( + Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + }, + false, + ))) + } + other_type => substrait_err!( + "unimplemented type {:?} for set predicate", + other_type + ), + } + } + other_type => { + substrait_err!("Subquery type {:?} not implemented", other_type) + } }, None => { - substrait_err!("Subquery experssion without SubqueryType is not allowed") + substrait_err!("Subquery expression without SubqueryType is not allowed") } }, _ => not_impl_err!("unsupported rex_type"), } } -fn from_substrait_type(dt: &substrait::proto::Type) -> Result { +pub(crate) fn from_substrait_type_without_names( + dt: &Type, + extensions: &Extensions, +) -> Result { + from_substrait_type(dt, extensions, &[], &mut 0) +} + +fn from_substrait_type( + dt: &Type, + extensions: &Extensions, + dfs_names: &[String], + name_idx: &mut usize, +) -> Result { match &dt.kind { Some(s_kind) => match s_kind { r#type::Kind::Bool(_) => Ok(DataType::Boolean), r#type::Kind::I8(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_REF => Ok(DataType::Int8), - UNSIGNED_INTEGER_TYPE_REF => Ok(DataType::UInt8), + DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int8), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt8), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" ), }, r#type::Kind::I16(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_REF => Ok(DataType::Int16), - UNSIGNED_INTEGER_TYPE_REF => Ok(DataType::UInt16), + DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int16), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt16), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" ), }, r#type::Kind::I32(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_REF => Ok(DataType::Int32), - UNSIGNED_INTEGER_TYPE_REF => Ok(DataType::UInt32), + DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int32), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt32), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" ), }, r#type::Kind::I64(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_REF => Ok(DataType::Int64), - UNSIGNED_INTEGER_TYPE_REF => Ok(DataType::UInt64), + DEFAULT_TYPE_VARIATION_REF => Ok(DataType::Int64), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(DataType::UInt64), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" ), }, r#type::Kind::Fp32(_) => Ok(DataType::Float32), r#type::Kind::Fp64(_) => Ok(DataType::Float64), - r#type::Kind::Timestamp(ts) => match ts.type_variation_reference { - TIMESTAMP_SECOND_TYPE_REF => { - Ok(DataType::Timestamp(TimeUnit::Second, None)) - } - TIMESTAMP_MILLI_TYPE_REF => { - Ok(DataType::Timestamp(TimeUnit::Millisecond, None)) - } - TIMESTAMP_MICRO_TYPE_REF => { - Ok(DataType::Timestamp(TimeUnit::Microsecond, None)) - } - TIMESTAMP_NANO_TYPE_REF => { - Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) + r#type::Kind::Timestamp(ts) => { + // Kept for backwards compatibility, new plans should use PrecisionTimestamp(Tz) instead + #[allow(deprecated)] + match ts.type_variation_reference { + TIMESTAMP_SECOND_TYPE_VARIATION_REF => { + Ok(DataType::Timestamp(TimeUnit::Second, None)) + } + TIMESTAMP_MILLI_TYPE_VARIATION_REF => { + Ok(DataType::Timestamp(TimeUnit::Millisecond, None)) + } + TIMESTAMP_MICRO_TYPE_VARIATION_REF => { + Ok(DataType::Timestamp(TimeUnit::Microsecond, None)) + } + TIMESTAMP_NANO_TYPE_VARIATION_REF => { + Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)) + } + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {s_kind:?}" + ), } - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {s_kind:?}" - ), - }, + } + r#type::Kind::PrecisionTimestamp(pts) => { + let unit = match pts.precision { + 0 => Ok(TimeUnit::Second), + 3 => Ok(TimeUnit::Millisecond), + 6 => Ok(TimeUnit::Microsecond), + 9 => Ok(TimeUnit::Nanosecond), + p => not_impl_err!( + "Unsupported Substrait precision {p} for PrecisionTimestamp" + ), + }?; + Ok(DataType::Timestamp(unit, None)) + } + r#type::Kind::PrecisionTimestampTz(pts) => { + let unit = match pts.precision { + 0 => Ok(TimeUnit::Second), + 3 => Ok(TimeUnit::Millisecond), + 6 => Ok(TimeUnit::Microsecond), + 9 => Ok(TimeUnit::Nanosecond), + p => not_impl_err!( + "Unsupported Substrait precision {p} for PrecisionTimestampTz" + ), + }?; + Ok(DataType::Timestamp(unit, Some(DEFAULT_TIMEZONE.into()))) + } r#type::Kind::Date(date) => match date.type_variation_reference { - DATE_32_TYPE_REF => Ok(DataType::Date32), - DATE_64_TYPE_REF => Ok(DataType::Date64), + DATE_32_TYPE_VARIATION_REF => Ok(DataType::Date32), + DATE_64_TYPE_VARIATION_REF => Ok(DataType::Date64), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" ), }, r#type::Kind::Binary(binary) => match binary.type_variation_reference { - DEFAULT_CONTAINER_TYPE_REF => Ok(DataType::Binary), - LARGE_CONTAINER_TYPE_REF => Ok(DataType::LargeBinary), + DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Binary), + LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeBinary), + VIEW_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::BinaryView), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" ), @@ -1126,43 +1934,181 @@ fn from_substrait_type(dt: &substrait::proto::Type) -> Result { Ok(DataType::FixedSizeBinary(fixed.length)) } r#type::Kind::String(string) => match string.type_variation_reference { - DEFAULT_CONTAINER_TYPE_REF => Ok(DataType::Utf8), - LARGE_CONTAINER_TYPE_REF => Ok(DataType::LargeUtf8), + DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Utf8), + LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeUtf8), + VIEW_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Utf8View), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" ), }, r#type::Kind::List(list) => { - let inner_type = - from_substrait_type(list.r#type.as_ref().ok_or_else(|| { - substrait_datafusion_err!("List type must have inner type") - })?)?; - let field = Arc::new(Field::new("list_item", inner_type, true)); + let inner_type = list.r#type.as_ref().ok_or_else(|| { + substrait_datafusion_err!("List type must have inner type") + })?; + let field = Arc::new(Field::new_list_field( + from_substrait_type(inner_type, extensions, dfs_names, name_idx)?, + // We ignore Substrait's nullability here to match to_substrait_literal + // which always creates nullable lists + true, + )); match list.type_variation_reference { - DEFAULT_CONTAINER_TYPE_REF => Ok(DataType::List(field)), - LARGE_CONTAINER_TYPE_REF => Ok(DataType::LargeList(field)), + DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::List(field)), + LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeList(field)), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" )?, } } + r#type::Kind::Map(map) => { + let key_type = map.key.as_ref().ok_or_else(|| { + substrait_datafusion_err!("Map type must have key type") + })?; + let value_type = map.value.as_ref().ok_or_else(|| { + substrait_datafusion_err!("Map type must have value type") + })?; + let key_field = Arc::new(Field::new( + "key", + from_substrait_type(key_type, extensions, dfs_names, name_idx)?, + false, + )); + let value_field = Arc::new(Field::new( + "value", + from_substrait_type(value_type, extensions, dfs_names, name_idx)?, + true, + )); + Ok(DataType::Map( + Arc::new(Field::new_struct( + "entries", + [key_field, value_field], + false, // The inner map field is always non-nullable (Arrow #1697), + )), + false, // whether keys are sorted + )) + } r#type::Kind::Decimal(d) => match d.type_variation_reference { - DECIMAL_128_TYPE_REF => { + DECIMAL_128_TYPE_VARIATION_REF => { Ok(DataType::Decimal128(d.precision as u8, d.scale as i8)) } - DECIMAL_256_TYPE_REF => { + DECIMAL_256_TYPE_VARIATION_REF => { Ok(DataType::Decimal256(d.precision as u8, d.scale as i8)) } v => not_impl_err!( "Unsupported Substrait type variation {v} of type {s_kind:?}" ), }, + r#type::Kind::IntervalYear(_) => { + Ok(DataType::Interval(IntervalUnit::YearMonth)) + } + r#type::Kind::IntervalDay(_) => Ok(DataType::Interval(IntervalUnit::DayTime)), + r#type::Kind::IntervalCompound(_) => { + Ok(DataType::Interval(IntervalUnit::MonthDayNano)) + } + r#type::Kind::UserDefined(u) => { + if let Some(name) = extensions.types.get(&u.type_reference) { + #[allow(deprecated)] + match name.as_ref() { + // Kept for backwards compatibility, producers should use IntervalCompound instead + INTERVAL_MONTH_DAY_NANO_TYPE_NAME => Ok(DataType::Interval(IntervalUnit::MonthDayNano)), + _ => not_impl_err!( + "Unsupported Substrait user defined type with ref {} and variation {}", + u.type_reference, + u.type_variation_reference + ), + } + } else { + #[allow(deprecated)] + match u.type_reference { + // Kept for backwards compatibility, producers should use IntervalYear instead + INTERVAL_YEAR_MONTH_TYPE_REF => { + Ok(DataType::Interval(IntervalUnit::YearMonth)) + } + // Kept for backwards compatibility, producers should use IntervalDay instead + INTERVAL_DAY_TIME_TYPE_REF => { + Ok(DataType::Interval(IntervalUnit::DayTime)) + } + // Kept for backwards compatibility, producers should use IntervalCompound instead + INTERVAL_MONTH_DAY_NANO_TYPE_REF => { + Ok(DataType::Interval(IntervalUnit::MonthDayNano)) + } + _ => not_impl_err!( + "Unsupported Substrait user defined type with ref {} and variation {}", + u.type_reference, + u.type_variation_reference + ), + } + } + } + r#type::Kind::Struct(s) => Ok(DataType::Struct(from_substrait_struct_type( + s, extensions, dfs_names, name_idx, + )?)), + r#type::Kind::Varchar(_) => Ok(DataType::Utf8), + r#type::Kind::FixedChar(_) => Ok(DataType::Utf8), _ => not_impl_err!("Unsupported Substrait type: {s_kind:?}"), }, _ => not_impl_err!("`None` Substrait kind is not supported"), } } +fn from_substrait_struct_type( + s: &r#type::Struct, + extensions: &Extensions, + dfs_names: &[String], + name_idx: &mut usize, +) -> Result { + let mut fields = vec![]; + for (i, f) in s.types.iter().enumerate() { + let field = Field::new( + next_struct_field_name(i, dfs_names, name_idx)?, + from_substrait_type(f, extensions, dfs_names, name_idx)?, + true, // We assume everything to be nullable since that's easier than ensuring it matches + ); + fields.push(field); + } + Ok(fields.into()) +} + +fn next_struct_field_name( + column_idx: usize, + dfs_names: &[String], + name_idx: &mut usize, +) -> Result { + if dfs_names.is_empty() { + // If names are not given, create dummy names + // c0, c1, ... align with e.g. SqlToRel::create_named_struct + Ok(format!("c{column_idx}")) + } else { + let name = dfs_names.get(*name_idx).cloned().ok_or_else(|| { + substrait_datafusion_err!("Named schema must contain names for all fields") + })?; + *name_idx += 1; + Ok(name) + } +} + +/// Convert Substrait NamedStruct to DataFusion DFSchemaRef +pub fn from_substrait_named_struct( + base_schema: &NamedStruct, + extensions: &Extensions, +) -> Result { + let mut name_idx = 0; + let fields = from_substrait_struct_type( + base_schema.r#struct.as_ref().ok_or_else(|| { + substrait_datafusion_err!("Named struct must contain a struct") + })?, + extensions, + &base_schema.names, + &mut name_idx, + ); + if name_idx != base_schema.names.len() { + return substrait_err!( + "Names list must match exactly to nested schema, but found {} uses for {} names", + name_idx, + base_schema.names.len() + ); + } + DFSchema::try_from(Schema::new(fields?)) +} + fn from_substrait_bound( bound: &Option, is_lower: bool, @@ -1173,12 +2119,22 @@ fn from_substrait_bound( BoundKind::CurrentRow(SubstraitBound::CurrentRow {}) => { Ok(WindowFrameBound::CurrentRow) } - BoundKind::Preceding(SubstraitBound::Preceding { offset }) => Ok( - WindowFrameBound::Preceding(ScalarValue::Int64(Some(*offset))), - ), - BoundKind::Following(SubstraitBound::Following { offset }) => Ok( - WindowFrameBound::Following(ScalarValue::Int64(Some(*offset))), - ), + BoundKind::Preceding(SubstraitBound::Preceding { offset }) => { + if *offset <= 0 { + return plan_err!("Preceding bound must be positive"); + } + Ok(WindowFrameBound::Preceding(ScalarValue::UInt64(Some( + *offset as u64, + )))) + } + BoundKind::Following(SubstraitBound::Following { offset }) => { + if *offset <= 0 { + return plan_err!("Following bound must be positive"); + } + Ok(WindowFrameBound::Following(ScalarValue::UInt64(Some( + *offset as u64, + )))) + } BoundKind::Unbounded(SubstraitBound::Unbounded {}) => { if is_lower { Ok(WindowFrameBound::Preceding(ScalarValue::Null)) @@ -1199,59 +2155,121 @@ fn from_substrait_bound( } } -pub(crate) fn from_substrait_literal(lit: &Literal) -> Result { +pub(crate) fn from_substrait_literal_without_names( + lit: &Literal, + extensions: &Extensions, +) -> Result { + from_substrait_literal(lit, extensions, &vec![], &mut 0) +} + +fn from_substrait_literal( + lit: &Literal, + extensions: &Extensions, + dfs_names: &Vec, + name_idx: &mut usize, +) -> Result { let scalar_value = match &lit.literal_type { Some(LiteralType::Boolean(b)) => ScalarValue::Boolean(Some(*b)), Some(LiteralType::I8(n)) => match lit.type_variation_reference { - DEFAULT_TYPE_REF => ScalarValue::Int8(Some(*n as i8)), - UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt8(Some(*n as u8)), + DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int8(Some(*n as i8)), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt8(Some(*n as u8)), others => { return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::I16(n)) => match lit.type_variation_reference { - DEFAULT_TYPE_REF => ScalarValue::Int16(Some(*n as i16)), - UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt16(Some(*n as u16)), + DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int16(Some(*n as i16)), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt16(Some(*n as u16)), others => { return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::I32(n)) => match lit.type_variation_reference { - DEFAULT_TYPE_REF => ScalarValue::Int32(Some(*n)), - UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt32(Some(*n as u32)), + DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int32(Some(*n)), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt32(Some(*n as u32)), others => { return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::I64(n)) => match lit.type_variation_reference { - DEFAULT_TYPE_REF => ScalarValue::Int64(Some(*n)), - UNSIGNED_INTEGER_TYPE_REF => ScalarValue::UInt64(Some(*n as u64)), + DEFAULT_TYPE_VARIATION_REF => ScalarValue::Int64(Some(*n)), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => ScalarValue::UInt64(Some(*n as u64)), others => { return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::Fp32(f)) => ScalarValue::Float32(Some(*f)), Some(LiteralType::Fp64(f)) => ScalarValue::Float64(Some(*f)), - Some(LiteralType::Timestamp(t)) => match lit.type_variation_reference { - TIMESTAMP_SECOND_TYPE_REF => ScalarValue::TimestampSecond(Some(*t), None), - TIMESTAMP_MILLI_TYPE_REF => ScalarValue::TimestampMillisecond(Some(*t), None), - TIMESTAMP_MICRO_TYPE_REF => ScalarValue::TimestampMicrosecond(Some(*t), None), - TIMESTAMP_NANO_TYPE_REF => ScalarValue::TimestampNanosecond(Some(*t), None), - others => { - return substrait_err!("Unknown type variation reference {others}"); + Some(LiteralType::Timestamp(t)) => { + // Kept for backwards compatibility, new plans should use PrecisionTimestamp(Tz) instead + #[allow(deprecated)] + match lit.type_variation_reference { + TIMESTAMP_SECOND_TYPE_VARIATION_REF => { + ScalarValue::TimestampSecond(Some(*t), None) + } + TIMESTAMP_MILLI_TYPE_VARIATION_REF => { + ScalarValue::TimestampMillisecond(Some(*t), None) + } + TIMESTAMP_MICRO_TYPE_VARIATION_REF => { + ScalarValue::TimestampMicrosecond(Some(*t), None) + } + TIMESTAMP_NANO_TYPE_VARIATION_REF => { + ScalarValue::TimestampNanosecond(Some(*t), None) + } + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + } + } + Some(LiteralType::PrecisionTimestamp(pt)) => match pt.precision { + 0 => ScalarValue::TimestampSecond(Some(pt.value), None), + 3 => ScalarValue::TimestampMillisecond(Some(pt.value), None), + 6 => ScalarValue::TimestampMicrosecond(Some(pt.value), None), + 9 => ScalarValue::TimestampNanosecond(Some(pt.value), None), + p => { + return not_impl_err!( + "Unsupported Substrait precision {p} for PrecisionTimestamp" + ); + } + }, + Some(LiteralType::PrecisionTimestampTz(pt)) => match pt.precision { + 0 => ScalarValue::TimestampSecond( + Some(pt.value), + Some(DEFAULT_TIMEZONE.into()), + ), + 3 => ScalarValue::TimestampMillisecond( + Some(pt.value), + Some(DEFAULT_TIMEZONE.into()), + ), + 6 => ScalarValue::TimestampMicrosecond( + Some(pt.value), + Some(DEFAULT_TIMEZONE.into()), + ), + 9 => ScalarValue::TimestampNanosecond( + Some(pt.value), + Some(DEFAULT_TIMEZONE.into()), + ), + p => { + return not_impl_err!( + "Unsupported Substrait precision {p} for PrecisionTimestamp" + ); } }, Some(LiteralType::Date(d)) => ScalarValue::Date32(Some(*d)), Some(LiteralType::String(s)) => match lit.type_variation_reference { - DEFAULT_CONTAINER_TYPE_REF => ScalarValue::Utf8(Some(s.clone())), - LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeUtf8(Some(s.clone())), + DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Utf8(Some(s.clone())), + LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeUtf8(Some(s.clone())), + VIEW_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Utf8View(Some(s.clone())), others => { return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::Binary(b)) => match lit.type_variation_reference { - DEFAULT_CONTAINER_TYPE_REF => ScalarValue::Binary(Some(b.clone())), - LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeBinary(Some(b.clone())), + DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Binary(Some(b.clone())), + LARGE_CONTAINER_TYPE_VARIATION_REF => { + ScalarValue::LargeBinary(Some(b.clone())) + } + VIEW_CONTAINER_TYPE_VARIATION_REF => ScalarValue::BinaryView(Some(b.clone())), others => { return substrait_err!("Unknown type variation reference {others}"); } @@ -1271,86 +2289,423 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result { let s = d.scale.try_into().map_err(|e| { substrait_datafusion_err!("Failed to parse decimal scale: {e}") })?; - ScalarValue::Decimal128( - Some(std::primitive::i128::from_le_bytes(value)), - p, - s, - ) + ScalarValue::Decimal128(Some(i128::from_le_bytes(value)), p, s) + } + Some(LiteralType::List(l)) => { + // Each element should start the name index from the same value, then we increase it + // once at the end + let mut element_name_idx = *name_idx; + let elements = l + .values + .iter() + .map(|el| { + element_name_idx = *name_idx; + from_substrait_literal( + el, + extensions, + dfs_names, + &mut element_name_idx, + ) + }) + .collect::>>()?; + *name_idx = element_name_idx; + if elements.is_empty() { + return substrait_err!( + "Empty list must be encoded as EmptyList literal type, not List" + ); + } + let element_type = elements[0].data_type(); + match lit.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::List( + ScalarValue::new_list_nullable(elements.as_slice(), &element_type), + ), + LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeList( + ScalarValue::new_large_list(elements.as_slice(), &element_type), + ), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + } + } + Some(LiteralType::EmptyList(l)) => { + let element_type = from_substrait_type( + l.r#type.clone().unwrap().as_ref(), + extensions, + dfs_names, + name_idx, + )?; + match lit.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => { + ScalarValue::List(ScalarValue::new_list_nullable(&[], &element_type)) + } + LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeList( + ScalarValue::new_large_list(&[], &element_type), + ), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + } + } + Some(LiteralType::Map(m)) => { + // Each entry should start the name index from the same value, then we increase it + // once at the end + let mut entry_name_idx = *name_idx; + let entries = m + .key_values + .iter() + .map(|kv| { + entry_name_idx = *name_idx; + let key_sv = from_substrait_literal( + kv.key.as_ref().unwrap(), + extensions, + dfs_names, + &mut entry_name_idx, + )?; + let value_sv = from_substrait_literal( + kv.value.as_ref().unwrap(), + extensions, + dfs_names, + &mut entry_name_idx, + )?; + ScalarStructBuilder::new() + .with_scalar(Field::new("key", key_sv.data_type(), false), key_sv) + .with_scalar( + Field::new("value", value_sv.data_type(), true), + value_sv, + ) + .build() + }) + .collect::>>()?; + *name_idx = entry_name_idx; + + if entries.is_empty() { + return substrait_err!( + "Empty map must be encoded as EmptyMap literal type, not Map" + ); + } + + ScalarValue::Map(Arc::new(MapArray::new( + Arc::new(Field::new("entries", entries[0].data_type(), false)), + OffsetBuffer::new(vec![0, entries.len() as i32].into()), + ScalarValue::iter_to_array(entries)?.as_struct().to_owned(), + None, + false, + ))) + } + Some(LiteralType::EmptyMap(m)) => { + let key = match &m.key { + Some(k) => Ok(k), + _ => plan_err!("Missing key type for empty map"), + }?; + let value = match &m.value { + Some(v) => Ok(v), + _ => plan_err!("Missing value type for empty map"), + }?; + let key_type = from_substrait_type(key, extensions, dfs_names, name_idx)?; + let value_type = from_substrait_type(value, extensions, dfs_names, name_idx)?; + + // new_empty_array on a MapType creates a too empty array + // We want it to contain an empty struct array to align with an empty MapBuilder one + let entries = Field::new_struct( + "entries", + vec![ + Field::new("key", key_type, false), + Field::new("value", value_type, true), + ], + false, + ); + let struct_array = + new_empty_array(entries.data_type()).as_struct().to_owned(); + ScalarValue::Map(Arc::new(MapArray::new( + Arc::new(entries), + OffsetBuffer::new(vec![0, 0].into()), + struct_array, + None, + false, + ))) + } + Some(LiteralType::Struct(s)) => { + let mut builder = ScalarStructBuilder::new(); + for (i, field) in s.fields.iter().enumerate() { + let name = next_struct_field_name(i, dfs_names, name_idx)?; + let sv = from_substrait_literal(field, extensions, dfs_names, name_idx)?; + // We assume everything to be nullable, since Arrow's strict about things matching + // and it's hard to match otherwise. + builder = builder.with_scalar(Field::new(name, sv.data_type(), true), sv); + } + builder.build()? + } + Some(LiteralType::Null(ntype)) => { + from_substrait_null(ntype, extensions, dfs_names, name_idx)? + } + Some(LiteralType::IntervalDayToSecond(IntervalDayToSecond { + days, + seconds, + subseconds, + precision_mode, + })) => { + use interval_day_to_second::PrecisionMode; + // DF only supports millisecond precision, so for any more granular type we lose precision + let milliseconds = match precision_mode { + Some(PrecisionMode::Microseconds(ms)) => ms / 1000, + None => + if *subseconds != 0 { + return substrait_err!("Cannot set subseconds field of IntervalDayToSecond without setting precision"); + } else { + 0_i32 + } + Some(PrecisionMode::Precision(0)) => *subseconds as i32 * 1000, + Some(PrecisionMode::Precision(3)) => *subseconds as i32, + Some(PrecisionMode::Precision(6)) => (subseconds / 1000) as i32, + Some(PrecisionMode::Precision(9)) => (subseconds / 1000 / 1000) as i32, + _ => { + return not_impl_err!( + "Unsupported Substrait interval day to second precision mode: {precision_mode:?}") + } + }; + + ScalarValue::new_interval_dt(*days, (seconds * 1000) + milliseconds) + } + Some(LiteralType::IntervalYearToMonth(IntervalYearToMonth { years, months })) => { + ScalarValue::new_interval_ym(*years, *months) + } + Some(LiteralType::IntervalCompound(IntervalCompound { + interval_year_to_month, + interval_day_to_second, + })) => match (interval_year_to_month, interval_day_to_second) { + ( + Some(IntervalYearToMonth { years, months }), + Some(IntervalDayToSecond { + days, + seconds, + subseconds, + precision_mode: + Some(interval_day_to_second::PrecisionMode::Precision(p)), + }), + ) => { + if *p < 0 || *p > 9 { + return plan_err!( + "Unsupported Substrait interval day to second precision: {}", + p + ); + } + let nanos = *subseconds * i64::pow(10, (9 - p) as u32); + ScalarValue::new_interval_mdn( + *years * 12 + months, + *days, + *seconds as i64 * NANOSECONDS + nanos, + ) + } + _ => return plan_err!("Substrait compound interval missing components"), + }, + Some(LiteralType::FixedChar(c)) => ScalarValue::Utf8(Some(c.clone())), + Some(LiteralType::UserDefined(user_defined)) => { + // Helper function to prevent duplicating this code - can be inlined once the non-extension path is removed + let interval_month_day_nano = + |user_defined: &UserDefined| -> Result { + let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { + return substrait_err!("Interval month day nano value is empty"); + }; + let value_slice: [u8; 16] = + (*raw_val.value).try_into().map_err(|_| { + substrait_datafusion_err!( + "Failed to parse interval month day nano value" + ) + })?; + let months = + i32::from_le_bytes(value_slice[0..4].try_into().unwrap()); + let days = i32::from_le_bytes(value_slice[4..8].try_into().unwrap()); + let nanoseconds = + i64::from_le_bytes(value_slice[8..16].try_into().unwrap()); + Ok(ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNano { + months, + days, + nanoseconds, + }, + ))) + }; + + if let Some(name) = extensions.types.get(&user_defined.type_reference) { + match name.as_ref() { + // Kept for backwards compatibility - producers should use IntervalCompound instead + #[allow(deprecated)] + INTERVAL_MONTH_DAY_NANO_TYPE_NAME => { + interval_month_day_nano(user_defined)? + } + _ => { + return not_impl_err!( + "Unsupported Substrait user defined type with ref {} and name {}", + user_defined.type_reference, + name + ) + } + } + } else { + #[allow(deprecated)] + match user_defined.type_reference { + // Kept for backwards compatibility, producers should useIntervalYearToMonth instead + INTERVAL_YEAR_MONTH_TYPE_REF => { + let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { + return substrait_err!("Interval year month value is empty"); + }; + let value_slice: [u8; 4] = + (*raw_val.value).try_into().map_err(|_| { + substrait_datafusion_err!( + "Failed to parse interval year month value" + ) + })?; + ScalarValue::IntervalYearMonth(Some(i32::from_le_bytes( + value_slice, + ))) + } + // Kept for backwards compatibility, producers should useIntervalDayToSecond instead + INTERVAL_DAY_TIME_TYPE_REF => { + let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { + return substrait_err!("Interval day time value is empty"); + }; + let value_slice: [u8; 8] = + (*raw_val.value).try_into().map_err(|_| { + substrait_datafusion_err!( + "Failed to parse interval day time value" + ) + })?; + let days = + i32::from_le_bytes(value_slice[0..4].try_into().unwrap()); + let milliseconds = + i32::from_le_bytes(value_slice[4..8].try_into().unwrap()); + ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days, + milliseconds, + })) + } + // Kept for backwards compatibility, producers should useIntervalCompound instead + INTERVAL_MONTH_DAY_NANO_TYPE_REF => { + interval_month_day_nano(user_defined)? + } + _ => { + return not_impl_err!( + "Unsupported Substrait user defined type literal with ref {}", + user_defined.type_reference + ) + } + } + } } - Some(LiteralType::Null(ntype)) => from_substrait_null(ntype)?, _ => return not_impl_err!("Unsupported literal_type: {:?}", lit.literal_type), }; Ok(scalar_value) } -fn from_substrait_null(null_type: &Type) -> Result { +fn from_substrait_null( + null_type: &Type, + extensions: &Extensions, + dfs_names: &[String], + name_idx: &mut usize, +) -> Result { if let Some(kind) = &null_type.kind { match kind { r#type::Kind::Bool(_) => Ok(ScalarValue::Boolean(None)), r#type::Kind::I8(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_REF => Ok(ScalarValue::Int8(None)), - UNSIGNED_INTEGER_TYPE_REF => Ok(ScalarValue::UInt8(None)), + DEFAULT_TYPE_VARIATION_REF => Ok(ScalarValue::Int8(None)), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(ScalarValue::UInt8(None)), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {kind:?}" ), }, r#type::Kind::I16(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_REF => Ok(ScalarValue::Int16(None)), - UNSIGNED_INTEGER_TYPE_REF => Ok(ScalarValue::UInt16(None)), + DEFAULT_TYPE_VARIATION_REF => Ok(ScalarValue::Int16(None)), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(ScalarValue::UInt16(None)), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {kind:?}" ), }, r#type::Kind::I32(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_REF => Ok(ScalarValue::Int32(None)), - UNSIGNED_INTEGER_TYPE_REF => Ok(ScalarValue::UInt32(None)), + DEFAULT_TYPE_VARIATION_REF => Ok(ScalarValue::Int32(None)), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(ScalarValue::UInt32(None)), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {kind:?}" ), }, r#type::Kind::I64(integer) => match integer.type_variation_reference { - DEFAULT_TYPE_REF => Ok(ScalarValue::Int64(None)), - UNSIGNED_INTEGER_TYPE_REF => Ok(ScalarValue::UInt64(None)), + DEFAULT_TYPE_VARIATION_REF => Ok(ScalarValue::Int64(None)), + UNSIGNED_INTEGER_TYPE_VARIATION_REF => Ok(ScalarValue::UInt64(None)), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {kind:?}" ), }, r#type::Kind::Fp32(_) => Ok(ScalarValue::Float32(None)), r#type::Kind::Fp64(_) => Ok(ScalarValue::Float64(None)), - r#type::Kind::Timestamp(ts) => match ts.type_variation_reference { - TIMESTAMP_SECOND_TYPE_REF => Ok(ScalarValue::TimestampSecond(None, None)), - TIMESTAMP_MILLI_TYPE_REF => { - Ok(ScalarValue::TimestampMillisecond(None, None)) - } - TIMESTAMP_MICRO_TYPE_REF => { - Ok(ScalarValue::TimestampMicrosecond(None, None)) - } - TIMESTAMP_NANO_TYPE_REF => { - Ok(ScalarValue::TimestampNanosecond(None, None)) + r#type::Kind::Timestamp(ts) => { + // Kept for backwards compatibility, new plans should use PrecisionTimestamp(Tz) instead + #[allow(deprecated)] + match ts.type_variation_reference { + TIMESTAMP_SECOND_TYPE_VARIATION_REF => { + Ok(ScalarValue::TimestampSecond(None, None)) + } + TIMESTAMP_MILLI_TYPE_VARIATION_REF => { + Ok(ScalarValue::TimestampMillisecond(None, None)) + } + TIMESTAMP_MICRO_TYPE_VARIATION_REF => { + Ok(ScalarValue::TimestampMicrosecond(None, None)) + } + TIMESTAMP_NANO_TYPE_VARIATION_REF => { + Ok(ScalarValue::TimestampNanosecond(None, None)) + } + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {kind:?}" + ), } - v => not_impl_err!( - "Unsupported Substrait type variation {v} of type {kind:?}" + } + r#type::Kind::PrecisionTimestamp(pts) => match pts.precision { + 0 => Ok(ScalarValue::TimestampSecond(None, None)), + 3 => Ok(ScalarValue::TimestampMillisecond(None, None)), + 6 => Ok(ScalarValue::TimestampMicrosecond(None, None)), + 9 => Ok(ScalarValue::TimestampNanosecond(None, None)), + p => not_impl_err!( + "Unsupported Substrait precision {p} for PrecisionTimestamp" + ), + }, + r#type::Kind::PrecisionTimestampTz(pts) => match pts.precision { + 0 => Ok(ScalarValue::TimestampSecond( + None, + Some(DEFAULT_TIMEZONE.into()), + )), + 3 => Ok(ScalarValue::TimestampMillisecond( + None, + Some(DEFAULT_TIMEZONE.into()), + )), + 6 => Ok(ScalarValue::TimestampMicrosecond( + None, + Some(DEFAULT_TIMEZONE.into()), + )), + 9 => Ok(ScalarValue::TimestampNanosecond( + None, + Some(DEFAULT_TIMEZONE.into()), + )), + p => not_impl_err!( + "Unsupported Substrait precision {p} for PrecisionTimestamp" ), }, r#type::Kind::Date(date) => match date.type_variation_reference { - DATE_32_TYPE_REF => Ok(ScalarValue::Date32(None)), - DATE_64_TYPE_REF => Ok(ScalarValue::Date64(None)), + DATE_32_TYPE_VARIATION_REF => Ok(ScalarValue::Date32(None)), + DATE_64_TYPE_VARIATION_REF => Ok(ScalarValue::Date64(None)), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {kind:?}" ), }, r#type::Kind::Binary(binary) => match binary.type_variation_reference { - DEFAULT_CONTAINER_TYPE_REF => Ok(ScalarValue::Binary(None)), - LARGE_CONTAINER_TYPE_REF => Ok(ScalarValue::LargeBinary(None)), + DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(ScalarValue::Binary(None)), + LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(ScalarValue::LargeBinary(None)), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {kind:?}" ), }, // FixedBinary is not supported because `None` doesn't have length r#type::Kind::String(string) => match string.type_variation_reference { - DEFAULT_CONTAINER_TYPE_REF => Ok(ScalarValue::Utf8(None)), - LARGE_CONTAINER_TYPE_REF => Ok(ScalarValue::LargeUtf8(None)), + DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(ScalarValue::Utf8(None)), + LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(ScalarValue::LargeUtf8(None)), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {kind:?}" ), @@ -1360,13 +2715,86 @@ fn from_substrait_null(null_type: &Type) -> Result { d.precision as u8, d.scale as i8, )), - _ => not_impl_err!("Unsupported Substrait type: {kind:?}"), + r#type::Kind::List(l) => { + let field = Field::new_list_field( + from_substrait_type( + l.r#type.clone().unwrap().as_ref(), + extensions, + dfs_names, + name_idx, + )?, + true, + ); + match l.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(ScalarValue::List( + Arc::new(GenericListArray::new_null(field.into(), 1)), + )), + LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(ScalarValue::LargeList( + Arc::new(GenericListArray::new_null(field.into(), 1)), + )), + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {kind:?}" + ), + } + } + r#type::Kind::Map(map) => { + let key_type = map.key.as_ref().ok_or_else(|| { + substrait_datafusion_err!("Map type must have key type") + })?; + let value_type = map.value.as_ref().ok_or_else(|| { + substrait_datafusion_err!("Map type must have value type") + })?; + + let key_type = + from_substrait_type(key_type, extensions, dfs_names, name_idx)?; + let value_type = + from_substrait_type(value_type, extensions, dfs_names, name_idx)?; + let entries_field = Arc::new(Field::new_struct( + "entries", + vec![ + Field::new("key", key_type, false), + Field::new("value", value_type, true), + ], + false, + )); + + DataType::Map(entries_field, false /* keys sorted */).try_into() + } + r#type::Kind::Struct(s) => { + let fields = + from_substrait_struct_type(s, extensions, dfs_names, name_idx)?; + Ok(ScalarStructBuilder::new_null(fields)) + } + _ => not_impl_err!("Unsupported Substrait type for null: {kind:?}"), } } else { not_impl_err!("Null type without kind is not supported") } } +#[allow(deprecated)] +async fn from_substrait_grouping( + ctx: &SessionContext, + grouping: &Grouping, + expressions: &[Expr], + input_schema: &DFSchemaRef, + extensions: &Extensions, +) -> Result> { + let mut group_exprs = vec![]; + if !grouping.grouping_expressions.is_empty() { + for e in &grouping.grouping_expressions { + let expr = from_substrait_rex(ctx, e, input_schema, extensions).await?; + group_exprs.push(expr); + } + return Ok(group_exprs); + } + for idx in &grouping.expression_references { + let e = &expressions[*idx as usize]; + group_exprs.push(e.clone()); + } + Ok(group_exprs) +} + fn from_substrait_field_reference( field_ref: &FieldReference, input_schema: &DFSchema, @@ -1399,7 +2827,7 @@ impl BuiltinExprBuilder { match name { "not" | "like" | "ilike" | "is_null" | "is_not_null" | "is_true" | "is_false" | "is_not_true" | "is_not_false" | "is_unknown" - | "is_not_unknown" | "negative" => Some(Self { + | "is_not_unknown" | "negative" | "negate" => Some(Self { expr_name: name.to_string(), }), _ => None, @@ -1411,8 +2839,8 @@ impl BuiltinExprBuilder { ctx: &SessionContext, f: &ScalarFunction, input_schema: &DFSchema, - extensions: &HashMap, - ) -> Result> { + extensions: &Extensions, + ) -> Result { match self.expr_name.as_str() { "like" => { Self::build_like_expr(ctx, false, f, input_schema, extensions).await @@ -1420,8 +2848,9 @@ impl BuiltinExprBuilder { "ilike" => { Self::build_like_expr(ctx, true, f, input_schema, extensions).await } - "not" | "negative" | "is_null" | "is_not_null" | "is_true" | "is_false" - | "is_not_true" | "is_not_false" | "is_unknown" | "is_not_unknown" => { + "not" | "negative" | "negate" | "is_null" | "is_not_null" | "is_true" + | "is_false" | "is_not_true" | "is_not_false" | "is_unknown" + | "is_not_unknown" => { Self::build_unary_expr(ctx, &self.expr_name, f, input_schema, extensions) .await } @@ -1436,23 +2865,21 @@ impl BuiltinExprBuilder { fn_name: &str, f: &ScalarFunction, input_schema: &DFSchema, - extensions: &HashMap, - ) -> Result> { + extensions: &Extensions, + ) -> Result { if f.arguments.len() != 1 { return substrait_err!("Expect one argument for {fn_name} expr"); } let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { return substrait_err!("Invalid arguments type for {fn_name} expr"); }; - let arg = from_substrait_rex(ctx, expr_substrait, input_schema, extensions) - .await? - .as_ref() - .clone(); + let arg = + from_substrait_rex(ctx, expr_substrait, input_schema, extensions).await?; let arg = Box::new(arg); let expr = match fn_name { "not" => Expr::Not(arg), - "negative" => Expr::Negative(arg), + "negative" | "negate" => Expr::Negative(arg), "is_null" => Expr::IsNull(arg), "is_not_null" => Expr::IsNotNull(arg), "is_true" => Expr::IsTrue(arg), @@ -1464,7 +2891,7 @@ impl BuiltinExprBuilder { _ => return not_impl_err!("Unsupported builtin expression: {}", fn_name), }; - Ok(Arc::new(expr)) + Ok(expr) } async fn build_like_expr( @@ -1472,48 +2899,105 @@ impl BuiltinExprBuilder { case_insensitive: bool, f: &ScalarFunction, input_schema: &DFSchema, - extensions: &HashMap, - ) -> Result> { + extensions: &Extensions, + ) -> Result { let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" }; - if f.arguments.len() != 3 { - return substrait_err!("Expect three arguments for `{fn_name}` expr"); + if f.arguments.len() != 2 && f.arguments.len() != 3 { + return substrait_err!("Expect two or three arguments for `{fn_name}` expr"); } let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; - let expr = from_substrait_rex(ctx, expr_substrait, input_schema, extensions) - .await? - .as_ref() - .clone(); + let expr = + from_substrait_rex(ctx, expr_substrait, input_schema, extensions).await?; let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; let pattern = - from_substrait_rex(ctx, pattern_substrait, input_schema, extensions) - .await? - .as_ref() - .clone(); - let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type else { - return substrait_err!("Invalid arguments type for `{fn_name}` expr"); - }; - let escape_char_expr = - from_substrait_rex(ctx, escape_char_substrait, input_schema, extensions) - .await? - .as_ref() - .clone(); - let Expr::Literal(ScalarValue::Utf8(escape_char)) = escape_char_expr else { - return substrait_err!( - "Expect Utf8 literal for escape char, but found {escape_char_expr:?}" - ); + from_substrait_rex(ctx, pattern_substrait, input_schema, extensions).await?; + + // Default case: escape character is Literal(Utf8(None)) + let escape_char = if f.arguments.len() == 3 { + let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type + else { + return substrait_err!("Invalid arguments type for `{fn_name}` expr"); + }; + + let escape_char_expr = + from_substrait_rex(ctx, escape_char_substrait, input_schema, extensions) + .await?; + + match escape_char_expr { + Expr::Literal(ScalarValue::Utf8(escape_char_string)) => { + // Convert Option to Option + escape_char_string.and_then(|s| s.chars().next()) + } + _ => { + return substrait_err!( + "Expect Utf8 literal for escape char, but found {escape_char_expr:?}" + ) + } + } + } else { + None }; - Ok(Arc::new(Expr::Like(Like { + Ok(Expr::Like(Like { negated: false, expr: Box::new(expr), pattern: Box::new(pattern), - escape_char: escape_char.map(|c| c.chars().next().unwrap()), + escape_char, case_insensitive, - }))) + })) + } +} + +#[cfg(test)] +mod test { + use crate::extensions::Extensions; + use crate::logical_plan::consumer::from_substrait_literal_without_names; + use arrow_buffer::IntervalMonthDayNano; + use datafusion::error::Result; + use datafusion::scalar::ScalarValue; + use substrait::proto::expression::literal::{ + interval_day_to_second, IntervalCompound, IntervalDayToSecond, + IntervalYearToMonth, LiteralType, + }; + use substrait::proto::expression::Literal; + + #[test] + fn interval_compound_different_precision() -> Result<()> { + // DF producer (and thus roundtrip) always uses precision = 9, + // this test exists to test with some other value. + let substrait = Literal { + nullable: false, + type_variation_reference: 0, + literal_type: Some(LiteralType::IntervalCompound(IntervalCompound { + interval_year_to_month: Some(IntervalYearToMonth { + years: 1, + months: 2, + }), + interval_day_to_second: Some(IntervalDayToSecond { + days: 3, + seconds: 4, + subseconds: 5, + precision_mode: Some( + interval_day_to_second::PrecisionMode::Precision(6), + ), + }), + })), + }; + + assert_eq!( + from_substrait_literal_without_names(&substrait, &Extensions::default())?, + ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano { + months: 14, + days: 3, + nanoseconds: 4_000_005_000 + })) + ); + + Ok(()) } } diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index a6a38ab6145c..4d864e4334ce 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -15,12 +15,15 @@ // specific language governing permissions and limitations // under the License. -use std::collections::HashMap; -use std::ops::Deref; +use datafusion::config::ConfigOptions; +use datafusion::optimizer::analyzer::expand_wildcard_rule::ExpandWildcardRule; +use datafusion::optimizer::AnalyzerRule; use std::sync::Arc; +use substrait::proto::expression_reference::ExprType; +use datafusion::arrow::datatypes::{Field, IntervalUnit}; use datafusion::logical_expr::{ - CrossJoin, Distinct, Like, Partitioning, WindowFrameUnits, + Distinct, FetchType, Like, Partitioning, SkipType, WindowFrameUnits, }; use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, @@ -30,21 +33,42 @@ use datafusion::{ scalar::ScalarValue, }; -use datafusion::common::{exec_err, internal_err, not_impl_err}; -use datafusion::common::{substrait_err, DFSchemaRef}; +use crate::extensions::Extensions; +use crate::variation_const::{ + DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, + DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF, + DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, + LARGE_CONTAINER_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF, + VIEW_CONTAINER_TYPE_VARIATION_REF, +}; +use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait}; +use datafusion::arrow::temporal_conversions::NANOSECONDS; +use datafusion::common::{ + exec_err, internal_err, not_impl_err, plan_err, substrait_datafusion_err, + substrait_err, DFSchemaRef, ToDFSchema, +}; #[allow(unused_imports)] -use datafusion::logical_expr::aggregate_function; use datafusion::logical_expr::expr::{ - AggregateFunctionDefinition, Alias, BinaryExpr, Case, Cast, GroupingSet, InList, - InSubquery, ScalarFunctionDefinition, Sort, WindowFunction, + Alias, BinaryExpr, Case, Cast, GroupingSet, InList, InSubquery, Sort, WindowFunction, }; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::Expr; -use prost_types::Any as ProtoAny; +use pbjson_types::Any as ProtoAny; use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields}; +use substrait::proto::expression::literal::interval_day_to_second::PrecisionMode; +use substrait::proto::expression::literal::map::KeyValue; +use substrait::proto::expression::literal::{ + IntervalCompound, IntervalDayToSecond, IntervalYearToMonth, List, Map, + PrecisionTimestamp, Struct, +}; use substrait::proto::expression::subquery::InPredicate; use substrait::proto::expression::window_function::BoundsType; -use substrait::proto::{CrossRel, ExchangeRel}; +use substrait::proto::read_rel::VirtualTable; +use substrait::proto::rel_common::EmitKind; +use substrait::proto::rel_common::EmitKind::Emit; +use substrait::proto::{ + rel_common, ExchangeRel, ExpressionReference, ExtendedExpression, RelCommon, +}; use substrait::{ proto::{ aggregate_function::AggregationInvocation, @@ -62,10 +86,6 @@ use substrait::{ ScalarFunction, SingularOrList, Subquery, WindowFunction as SubstraitWindowFunction, }, - extensions::{ - self, - simple_extension_declaration::{ExtensionFunction, MappingType}, - }, function_argument::ArgType, join_rel, plan_rel, r#type, read_rel::{NamedTable, ReadType}, @@ -80,50 +100,91 @@ use substrait::{ version, }; -use crate::variation_const::{ - DATE_32_TYPE_REF, DATE_64_TYPE_REF, DECIMAL_128_TYPE_REF, DECIMAL_256_TYPE_REF, - DEFAULT_CONTAINER_TYPE_REF, DEFAULT_TYPE_REF, LARGE_CONTAINER_TYPE_REF, - TIMESTAMP_MICRO_TYPE_REF, TIMESTAMP_MILLI_TYPE_REF, TIMESTAMP_NANO_TYPE_REF, - TIMESTAMP_SECOND_TYPE_REF, UNSIGNED_INTEGER_TYPE_REF, -}; - /// Convert DataFusion LogicalPlan to Substrait Plan pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) -> Result> { + let mut extensions = Extensions::default(); // Parse relation nodes - let mut extension_info: ( - Vec, - HashMap, - ) = (vec![], HashMap::new()); // Generate PlanRel(s) // Note: Only 1 relation tree is currently supported + + // We have to expand wildcard expressions first as wildcards can't be represented in substrait + let plan = Arc::new(ExpandWildcardRule::new()) + .analyze(plan.clone(), &ConfigOptions::default())?; + let plan_rels = vec![PlanRel { rel_type: Some(plan_rel::RelType::Root(RelRoot { - input: Some(*to_substrait_rel(plan, ctx, &mut extension_info)?), - names: plan.schema().field_names(), + input: Some(*to_substrait_rel(&plan, ctx, &mut extensions)?), + names: to_substrait_named_struct(plan.schema())?.names, })), }]; - let (function_extensions, _) = extension_info; - // Return parsed plan Ok(Box::new(Plan { version: Some(version::version_with_producer("datafusion")), extension_uris: vec![], - extensions: function_extensions, + extensions: extensions.into(), relations: plan_rels, advanced_extensions: None, expected_type_urls: vec![], })) } +/// Serializes a collection of expressions to a Substrait ExtendedExpression message +/// +/// The ExtendedExpression message is a top-level message that can be used to send +/// expressions (not plans) between systems. +/// +/// Each expression is also given names for the output type. These are provided as a +/// field and not a String (since the names may be nested, e.g. a struct). The data +/// type and nullability of this field is redundant (those can be determined by the +/// Expr) and will be ignored. +/// +/// Substrait also requires the input schema of the expressions to be included in the +/// message. The field names of the input schema will be serialized. +pub fn to_substrait_extended_expr( + exprs: &[(&Expr, &Field)], + schema: &DFSchemaRef, + ctx: &SessionContext, +) -> Result> { + let mut extensions = Extensions::default(); + + let substrait_exprs = exprs + .iter() + .map(|(expr, field)| { + let substrait_expr = to_substrait_rex( + ctx, + expr, + schema, + /*col_ref_offset=*/ 0, + &mut extensions, + )?; + let mut output_names = Vec::new(); + flatten_names(field, false, &mut output_names)?; + Ok(ExpressionReference { + output_names, + expr_type: Some(ExprType::Expression(substrait_expr)), + }) + }) + .collect::>>()?; + let substrait_schema = to_substrait_named_struct(schema)?; + + Ok(Box::new(ExtendedExpression { + advanced_extensions: None, + expected_type_urls: vec![], + extension_uris: vec![], + extensions: extensions.into(), + version: Some(version::version_with_producer("datafusion")), + referred_expr: substrait_exprs, + base_schema: Some(substrait_schema), + })) +} + /// Convert DataFusion LogicalPlan to Substrait Rel +#[allow(deprecated)] pub fn to_substrait_rel( plan: &LogicalPlan, ctx: &SessionContext, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result> { match plan { LogicalPlan::TableScan(scan) => { @@ -141,19 +202,13 @@ pub fn to_substrait_rel( maintain_singular_struct: false, }); + let table_schema = scan.source.schema().to_dfschema_ref()?; + let base_schema = to_substrait_named_struct(&table_schema)?; + Ok(Box::new(Rel { rel_type: Some(RelType::Read(Box::new(ReadRel { common: None, - base_schema: Some(NamedStruct { - names: scan - .source - .schema() - .fields() - .iter() - .map(|f| f.name().to_owned()) - .collect(), - r#struct: None, - }), + base_schema: Some(base_schema), filter: None, best_effort_filter: None, projection, @@ -165,29 +220,100 @@ pub fn to_substrait_rel( }))), })) } + LogicalPlan::EmptyRelation(e) => { + if e.produce_one_row { + return not_impl_err!( + "Producing a row from empty relation is unsupported" + ); + } + Ok(Box::new(Rel { + rel_type: Some(RelType::Read(Box::new(ReadRel { + common: None, + base_schema: Some(to_substrait_named_struct(&e.schema)?), + filter: None, + best_effort_filter: None, + projection: None, + advanced_extension: None, + read_type: Some(ReadType::VirtualTable(VirtualTable { + values: vec![], + expressions: vec![], + })), + }))), + })) + } + LogicalPlan::Values(v) => { + let values = v + .values + .iter() + .map(|row| { + let fields = row + .iter() + .map(|v| match v { + Expr::Literal(sv) => to_substrait_literal(sv, extensions), + Expr::Alias(alias) => match alias.expr.as_ref() { + // The schema gives us the names, so we can skip aliases + Expr::Literal(sv) => to_substrait_literal(sv, extensions), + _ => Err(substrait_datafusion_err!( + "Only literal types can be aliased in Virtual Tables, got: {}", alias.expr.variant_name() + )), + }, + _ => Err(substrait_datafusion_err!( + "Only literal types and aliases are supported in Virtual Tables, got: {}", v.variant_name() + )), + }) + .collect::>()?; + Ok(Struct { fields }) + }) + .collect::>()?; + Ok(Box::new(Rel { + rel_type: Some(RelType::Read(Box::new(ReadRel { + common: None, + base_schema: Some(to_substrait_named_struct(&v.schema)?), + filter: None, + best_effort_filter: None, + projection: None, + advanced_extension: None, + read_type: Some(ReadType::VirtualTable(VirtualTable { + values, + expressions: vec![], + })), + }))), + })) + } LogicalPlan::Projection(p) => { let expressions = p .expr .iter() - .map(|e| to_substrait_rex(ctx, e, p.input.schema(), 0, extension_info)) + .map(|e| to_substrait_rex(ctx, e, p.input.schema(), 0, extensions)) .collect::>>()?; + + let emit_kind = create_project_remapping( + expressions.len(), + p.input.as_ref().schema().fields().len(), + ); + let common = RelCommon { + emit_kind: Some(emit_kind), + hint: None, + advanced_extension: None, + }; + Ok(Box::new(Rel { rel_type: Some(RelType::Project(Box::new(ProjectRel { - common: None, - input: Some(to_substrait_rel(p.input.as_ref(), ctx, extension_info)?), + common: Some(common), + input: Some(to_substrait_rel(p.input.as_ref(), ctx, extensions)?), expressions, advanced_extension: None, }))), })) } LogicalPlan::Filter(filter) => { - let input = to_substrait_rel(filter.input.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(filter.input.as_ref(), ctx, extensions)?; let filter_expr = to_substrait_rex( ctx, &filter.predicate, filter.input.schema(), 0, - extension_info, + extensions, )?; Ok(Box::new(Rel { rel_type: Some(RelType::Filter(Box::new(FilterRel { @@ -199,27 +325,30 @@ pub fn to_substrait_rel( })) } LogicalPlan::Limit(limit) => { - let input = to_substrait_rel(limit.input.as_ref(), ctx, extension_info)?; - // Since protobuf can't directly distinguish `None` vs `0` encode `None` as `MAX` - let limit_fetch = limit.fetch.unwrap_or(usize::MAX); + let input = to_substrait_rel(limit.input.as_ref(), ctx, extensions)?; + let FetchType::Literal(fetch) = limit.get_fetch_type()? else { + return not_impl_err!("Non-literal limit fetch"); + }; + let SkipType::Literal(skip) = limit.get_skip_type()? else { + return not_impl_err!("Non-literal limit skip"); + }; Ok(Box::new(Rel { rel_type: Some(RelType::Fetch(Box::new(FetchRel { common: None, input: Some(input), - offset: limit.skip as i64, - count: limit_fetch as i64, + offset: skip as i64, + // use -1 to signal that ALL records should be returned + count: fetch.map(|f| f as i64).unwrap_or(-1), advanced_extension: None, }))), })) } LogicalPlan::Sort(sort) => { - let input = to_substrait_rel(sort.input.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(sort.input.as_ref(), ctx, extensions)?; let sort_fields = sort .expr .iter() - .map(|e| { - substrait_sort_field(ctx, e, sort.input.schema(), extension_info) - }) + .map(|e| substrait_sort_field(ctx, e, sort.input.schema(), extensions)) .collect::>>()?; Ok(Box::new(Rel { rel_type: Some(RelType::Sort(Box::new(SortRel { @@ -231,25 +360,24 @@ pub fn to_substrait_rel( })) } LogicalPlan::Aggregate(agg) => { - let input = to_substrait_rel(agg.input.as_ref(), ctx, extension_info)?; - let groupings = to_substrait_groupings( + let input = to_substrait_rel(agg.input.as_ref(), ctx, extensions)?; + let (grouping_expressions, groupings) = to_substrait_groupings( ctx, &agg.group_expr, agg.input.schema(), - extension_info, + extensions, )?; let measures = agg .aggr_expr .iter() - .map(|e| { - to_substrait_agg_measure(ctx, e, agg.input.schema(), extension_info) - }) + .map(|e| to_substrait_agg_measure(ctx, e, agg.input.schema(), extensions)) .collect::>>()?; Ok(Box::new(Rel { rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { common: None, input: Some(input), + grouping_expressions, groupings, measures, advanced_extension: None, @@ -258,7 +386,7 @@ pub fn to_substrait_rel( } LogicalPlan::Distinct(Distinct::All(plan)) => { // Use Substrait's AggregateRel with empty measures to represent `select distinct` - let input = to_substrait_rel(plan.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(plan.as_ref(), ctx, extensions)?; // Get grouping keys from the input relation's number of output fields let grouping = (0..plan.schema().fields().len()) .map(substrait_field_ref) @@ -268,8 +396,10 @@ pub fn to_substrait_rel( rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { common: None, input: Some(input), + grouping_expressions: vec![], groupings: vec![Grouping { grouping_expressions: grouping, + expression_references: vec![], }], measures: vec![], advanced_extension: None, @@ -277,8 +407,8 @@ pub fn to_substrait_rel( })) } LogicalPlan::Join(join) => { - let left = to_substrait_rel(join.left.as_ref(), ctx, extension_info)?; - let right = to_substrait_rel(join.right.as_ref(), ctx, extension_info)?; + let left = to_substrait_rel(join.left.as_ref(), ctx, extensions)?; + let right = to_substrait_rel(join.right.as_ref(), ctx, extensions)?; let join_type = to_substrait_jointype(join.join_type); // we only support basic joins so return an error for anything not yet supported match join.join_constraint { @@ -295,7 +425,7 @@ pub fn to_substrait_rel( filter, &Arc::new(in_join_schema), 0, - extension_info, + extensions, )?), None => None, }; @@ -313,7 +443,7 @@ pub fn to_substrait_rel( eq_op, join.left.schema(), join.right.schema(), - extension_info, + extensions, )?; // create conjunction between `join_on` and `join_filter` to embed all join conditions, @@ -324,7 +454,7 @@ pub fn to_substrait_rel( on_expr, filter, Operator::And, - extension_info, + extensions, ))), None => join_on.map(Box::new), // the join expression will only contain `join_on` if filter doesn't exist }, @@ -346,39 +476,22 @@ pub fn to_substrait_rel( }))), })) } - LogicalPlan::CrossJoin(cross_join) => { - let CrossJoin { - left, - right, - schema: _, - } = cross_join; - let left = to_substrait_rel(left.as_ref(), ctx, extension_info)?; - let right = to_substrait_rel(right.as_ref(), ctx, extension_info)?; - Ok(Box::new(Rel { - rel_type: Some(RelType::Cross(Box::new(CrossRel { - common: None, - left: Some(left), - right: Some(right), - advanced_extension: None, - }))), - })) - } LogicalPlan::SubqueryAlias(alias) => { // Do nothing if encounters SubqueryAlias // since there is no corresponding relation type in Substrait - to_substrait_rel(alias.input.as_ref(), ctx, extension_info) + to_substrait_rel(alias.input.as_ref(), ctx, extensions) } LogicalPlan::Union(union) => { let input_rels = union .inputs .iter() - .map(|input| to_substrait_rel(input.as_ref(), ctx, extension_info)) + .map(|input| to_substrait_rel(input.as_ref(), ctx, extensions)) .collect::>>()? .into_iter() .map(|ptr| *ptr) .collect(); Ok(Box::new(Rel { - rel_type: Some(substrait::proto::rel::RelType::Set(SetRel { + rel_type: Some(RelType::Set(SetRel { common: None, inputs: input_rels, op: set_rel::SetOp::UnionAll as i32, // UNION DISTINCT gets translated to AGGREGATION + UNION ALL @@ -387,46 +500,46 @@ pub fn to_substrait_rel( })) } LogicalPlan::Window(window) => { - let input = to_substrait_rel(window.input.as_ref(), ctx, extension_info)?; - // If the input is a Project relation, we can just append the WindowFunction expressions - // before returning - // Otherwise, wrap the input in a Project relation before appending the WindowFunction - // expressions - let mut project_rel: Box = match &input.as_ref().rel_type { - Some(RelType::Project(p)) => Box::new(*p.clone()), - _ => { - // Create Projection with field referencing all output fields in the input relation - let expressions = (0..window.input.schema().fields().len()) - .map(substrait_field_ref) - .collect::>>()?; - Box::new(ProjectRel { - common: None, - input: Some(input), - expressions, - advanced_extension: None, - }) - } - }; - // Parse WindowFunction expression - let mut window_exprs = vec![]; + let input = to_substrait_rel(window.input.as_ref(), ctx, extensions)?; + + // create a field reference for each input field + let mut expressions = (0..window.input.schema().fields().len()) + .map(substrait_field_ref) + .collect::>>()?; + + // process and add each window function expression for expr in &window.window_expr { - window_exprs.push(to_substrait_rex( + expressions.push(to_substrait_rex( ctx, expr, window.input.schema(), 0, - extension_info, + extensions, )?); } - // Append parsed WindowFunction expressions - project_rel.expressions.extend(window_exprs); + + let emit_kind = create_project_remapping( + expressions.len(), + window.input.schema().fields().len(), + ); + let common = RelCommon { + emit_kind: Some(emit_kind), + hint: None, + advanced_extension: None, + }; + let project_rel = Box::new(ProjectRel { + common: Some(common), + input: Some(input), + expressions, + advanced_extension: None, + }); + Ok(Box::new(Rel { rel_type: Some(RelType::Project(project_rel)), })) } LogicalPlan::Repartition(repartition) => { - let input = - to_substrait_rel(repartition.input.as_ref(), ctx, extension_info)?; + let input = to_substrait_rel(repartition.input.as_ref(), ctx, extensions)?; let partition_count = match repartition.partitioning_scheme { Partitioning::RoundRobinBatch(num) => num, Partitioning::Hash(_, num) => num, @@ -478,13 +591,13 @@ pub fn to_substrait_rel( .serialize_logical_plan(extension_plan.node.as_ref())?; let detail = ProtoAny { type_url: extension_plan.node.name().to_string(), - value: extension_bytes, + value: extension_bytes.into(), }; let mut inputs_rel = extension_plan .node .inputs() .into_iter() - .map(|plan| to_substrait_rel(plan, ctx, extension_info)) + .map(|plan| to_substrait_rel(plan, ctx, extensions)) .collect::>>()?; let rel_type = match inputs_rel.len() { 0 => RelType::ExtensionLeaf(ExtensionLeafRel { @@ -506,8 +619,71 @@ pub fn to_substrait_rel( rel_type: Some(rel_type), })) } - _ => not_impl_err!("Unsupported operator: {plan:?}"), + _ => not_impl_err!("Unsupported operator: {plan}"), + } +} + +/// By default, a Substrait Project outputs all input fields followed by all expressions. +/// A DataFusion Projection only outputs expressions. In order to keep the Substrait +/// plan consistent with DataFusion, we must apply an output mapping that skips the input +/// fields so that the Substrait Project will only output the expression fields. +fn create_project_remapping(expr_count: usize, input_field_count: usize) -> EmitKind { + let expression_field_start = input_field_count; + let expression_field_end = expression_field_start + expr_count; + let output_mapping = (expression_field_start..expression_field_end) + .map(|i| i as i32) + .collect(); + Emit(rel_common::Emit { output_mapping }) +} + +// Substrait wants a list of all field names, including nested fields from structs, +// also from within e.g. lists and maps. However, it does not want the list and map field names +// themselves - only proper structs fields are considered to have useful names. +fn flatten_names(field: &Field, skip_self: bool, names: &mut Vec) -> Result<()> { + if !skip_self { + names.push(field.name().to_string()); + } + match field.data_type() { + DataType::Struct(fields) => { + for field in fields { + flatten_names(field, false, names)?; + } + Ok(()) + } + DataType::List(l) => flatten_names(l, true, names), + DataType::LargeList(l) => flatten_names(l, true, names), + DataType::Map(m, _) => match m.data_type() { + DataType::Struct(key_and_value) if key_and_value.len() == 2 => { + flatten_names(&key_and_value[0], true, names)?; + flatten_names(&key_and_value[1], true, names) + } + _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), + }, + _ => Ok(()), + }?; + Ok(()) +} + +fn to_substrait_named_struct(schema: &DFSchemaRef) -> Result { + let mut names = Vec::with_capacity(schema.fields().len()); + for field in schema.fields() { + flatten_names(field, false, &mut names)?; } + + let field_types = r#type::Struct { + types: schema + .fields() + .iter() + .map(|f| to_substrait_type(f.data_type(), f.is_nullable())) + .collect::>()?, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability: r#type::Nullability::Unspecified as i32, + }; + + Ok(NamedStruct { + names, + r#struct: Some(field_types), + }) } fn to_substrait_join_expr( @@ -516,30 +692,27 @@ fn to_substrait_join_expr( eq_op: Operator, left_schema: &DFSchemaRef, right_schema: &DFSchemaRef, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result> { // Only support AND conjunction for each binary expression in join conditions let mut exprs: Vec = vec![]; for (left, right) in join_conditions { // Parse left - let l = to_substrait_rex(ctx, left, left_schema, 0, extension_info)?; + let l = to_substrait_rex(ctx, left, left_schema, 0, extensions)?; // Parse right let r = to_substrait_rex( ctx, right, right_schema, left_schema.fields().len(), // offset to return the correct index - extension_info, + extensions, )?; // AND with existing expression - exprs.push(make_binary_op_scalar_func(&l, &r, eq_op, extension_info)); + exprs.push(make_binary_op_scalar_func(&l, &r, eq_op, extensions)); } let join_expr: Option = exprs.into_iter().reduce(|acc: Expression, e: Expression| { - make_binary_op_scalar_func(&acc, &e, Operator::And, extension_info) + make_binary_op_scalar_func(&acc, &e, Operator::And, extensions) }); Ok(join_expr) } @@ -550,9 +723,12 @@ fn to_substrait_jointype(join_type: JoinType) -> join_rel::JoinType { JoinType::Left => join_rel::JoinType::Left, JoinType::Right => join_rel::JoinType::Right, JoinType::Full => join_rel::JoinType::Outer, - JoinType::LeftAnti => join_rel::JoinType::Anti, - JoinType::LeftSemi => join_rel::JoinType::Semi, - JoinType::RightAnti | JoinType::RightSemi => unimplemented!(), + JoinType::LeftAnti => join_rel::JoinType::LeftAnti, + JoinType::LeftSemi => join_rel::JoinType::LeftSemi, + JoinType::LeftMark => join_rel::JoinType::LeftMark, + JoinType::RightAnti | JoinType::RightSemi => { + unimplemented!() + } } } @@ -568,7 +744,7 @@ pub fn operator_to_name(op: Operator) -> &'static str { Operator::Minus => "subtract", Operator::Multiply => "multiply", Operator::Divide => "divide", - Operator::Modulo => "mod", + Operator::Modulo => "modulus", Operator::And => "and", Operator::Or => "or", Operator::IsDistinctFrom => "is_distinct_from", @@ -592,21 +768,26 @@ pub fn operator_to_name(op: Operator) -> &'static str { } } +#[allow(deprecated)] pub fn parse_flat_grouping_exprs( ctx: &SessionContext, exprs: &[Expr], schema: &DFSchemaRef, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, + ref_group_exprs: &mut Vec, ) -> Result { - let grouping_expressions = exprs - .iter() - .map(|e| to_substrait_rex(ctx, e, schema, 0, extension_info)) - .collect::>>()?; + let mut expression_references = vec![]; + let mut grouping_expressions = vec![]; + + for e in exprs { + let rex = to_substrait_rex(ctx, e, schema, 0, extensions)?; + grouping_expressions.push(rex.clone()); + ref_group_exprs.push(rex); + expression_references.push((ref_group_exprs.len() - 1) as u32); + } Ok(Grouping { grouping_expressions, + expression_references, }) } @@ -614,12 +795,10 @@ pub fn to_substrait_groupings( ctx: &SessionContext, exprs: &[Expr], schema: &DFSchemaRef, - extension_info: &mut ( - Vec, - HashMap, - ), -) -> Result> { - match exprs.len() { + extensions: &mut Extensions, +) -> Result<(Vec, Vec)> { + let mut ref_group_exprs = vec![]; + let groupings = match exprs.len() { 1 => match &exprs[0] { Expr::GroupingSet(gs) => match gs { GroupingSet::Cube(_) => Err(DataFusionError::NotImplemented( @@ -628,7 +807,13 @@ pub fn to_substrait_groupings( GroupingSet::GroupingSets(sets) => Ok(sets .iter() .map(|set| { - parse_flat_grouping_exprs(ctx, set, schema, extension_info) + parse_flat_grouping_exprs( + ctx, + set, + schema, + extensions, + &mut ref_group_exprs, + ) }) .collect::>>()?), GroupingSet::Rollup(set) => { @@ -640,7 +825,13 @@ pub fn to_substrait_groupings( .iter() .rev() .map(|set| { - parse_flat_grouping_exprs(ctx, set, schema, extension_info) + parse_flat_grouping_exprs( + ctx, + set, + schema, + extensions, + &mut ref_group_exprs, + ) }) .collect::>>()?) } @@ -649,16 +840,19 @@ pub fn to_substrait_groupings( ctx, exprs, schema, - extension_info, + extensions, + &mut ref_group_exprs, )?]), }, _ => Ok(vec![parse_flat_grouping_exprs( ctx, exprs, schema, - extension_info, + extensions, + &mut ref_group_exprs, )?]), - } + }?; + Ok((ref_group_exprs, groupings)) } #[allow(deprecated)] @@ -666,25 +860,20 @@ pub fn to_substrait_agg_measure( ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result { match expr { - Expr::AggregateFunction(expr::AggregateFunction { func_def, args, distinct, filter, order_by, null_treatment: _, }) => { - match func_def { - AggregateFunctionDefinition::BuiltIn (fun) => { + Expr::AggregateFunction(expr::AggregateFunction { func, args, distinct, filter, order_by, null_treatment: _, }) => { let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extension_info)).collect::>>()? + order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extensions)).collect::>>()? } else { vec![] }; let mut arguments: Vec = vec![]; for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) }); + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extensions)?)) }); } - let function_anchor = _register_function(fun.to_string(), extension_info); + let function_anchor = extensions.register_function(func.name().to_string()); Ok(Measure { measure: Some(AggregateFunction { function_reference: function_anchor, @@ -700,47 +889,14 @@ pub fn to_substrait_agg_measure( options: vec![], }), filter: match filter { - Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extension_info)?), + Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extensions)?), None => None } }) - } - AggregateFunctionDefinition::UDF(fun) => { - let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(ctx, expr, schema, extension_info)).collect::>>()? - } else { - vec![] - }; - let mut arguments: Vec = vec![]; - for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extension_info)?)) }); - } - let function_anchor = _register_function(fun.name().to_string(), extension_info); - Ok(Measure { - measure: Some(AggregateFunction { - function_reference: function_anchor, - arguments, - sorts, - output_type: None, - invocation: AggregationInvocation::All as i32, - phase: AggregationPhase::Unspecified as i32, - args: vec![], - options: vec![], - }), - filter: match filter { - Some(f) => Some(to_substrait_rex(ctx, f, schema, 0, extension_info)?), - None => None - } - }) - } - AggregateFunctionDefinition::Name(name) => { - internal_err!("AggregateFunctionDefinition::Name({:?}) should be resolved during `AnalyzerRule`", name) - } - } } Expr::Alias(Alias{expr,..})=> { - to_substrait_agg_measure(ctx, expr, schema, extension_info) + to_substrait_agg_measure(ctx, expr, schema, extensions) } _ => internal_err!( "Expression must be compatible with aggregation. Unsupported expression: {:?}. ExpressionType: {:?}", @@ -753,74 +909,20 @@ pub fn to_substrait_agg_measure( /// Converts sort expression to corresponding substrait `SortField` fn to_substrait_sort_field( ctx: &SessionContext, - expr: &Expr, + sort: &Sort, schema: &DFSchemaRef, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result { - match expr { - Expr::Sort(sort) => { - let sort_kind = match (sort.asc, sort.nulls_first) { - (true, true) => SortDirection::AscNullsFirst, - (true, false) => SortDirection::AscNullsLast, - (false, true) => SortDirection::DescNullsFirst, - (false, false) => SortDirection::DescNullsLast, - }; - Ok(SortField { - expr: Some(to_substrait_rex( - ctx, - sort.expr.deref(), - schema, - 0, - extension_info, - )?), - sort_kind: Some(SortKind::Direction(sort_kind.into())), - }) - } - _ => exec_err!("expects to receive sort expression"), - } -} - -fn _register_function( - function_name: String, - extension_info: &mut ( - Vec, - HashMap, - ), -) -> u32 { - let (function_extensions, function_set) = extension_info; - let function_name = function_name.to_lowercase(); - // To prevent ambiguous references between ScalarFunctions and AggregateFunctions, - // a plan-relative identifier starting from 0 is used as the function_anchor. - // The consumer is responsible for correctly registering - // mapping info stored in the extensions by the producer. - let function_anchor = match function_set.get(&function_name) { - Some(function_anchor) => { - // Function has been registered - *function_anchor - } - None => { - // Function has NOT been registered - let function_anchor = function_set.len() as u32; - function_set.insert(function_name.clone(), function_anchor); - - let function_extension = ExtensionFunction { - extension_uri_reference: u32::MAX, - function_anchor, - name: function_name, - }; - let simple_extension = extensions::SimpleExtensionDeclaration { - mapping_type: Some(MappingType::ExtensionFunction(function_extension)), - }; - function_extensions.push(simple_extension); - function_anchor - } + let sort_kind = match (sort.asc, sort.nulls_first) { + (true, true) => SortDirection::AscNullsFirst, + (true, false) => SortDirection::AscNullsLast, + (false, true) => SortDirection::DescNullsFirst, + (false, false) => SortDirection::DescNullsLast, }; - - // Return function anchor - function_anchor + Ok(SortField { + expr: Some(to_substrait_rex(ctx, &sort.expr, schema, 0, extensions)?), + sort_kind: Some(SortKind::Direction(sort_kind.into())), + }) } /// Return Substrait scalar function with two arguments @@ -829,13 +931,9 @@ pub fn make_binary_op_scalar_func( lhs: &Expression, rhs: &Expression, op: Operator, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Expression { - let function_anchor = - _register_function(operator_to_name(op).to_string(), extension_info); + let function_anchor = extensions.register_function(operator_to_name(op).to_string()); Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, @@ -860,7 +958,7 @@ pub fn make_binary_op_scalar_func( /// /// * `expr` - DataFusion expression to be parse into a Substrait expression /// * `schema` - DataFusion input schema for looking up field qualifiers -/// * `col_ref_offset` - Offset for caculating Substrait field reference indices. +/// * `col_ref_offset` - Offset for calculating Substrait field reference indices. /// This should only be set by caller with more than one input relations i.e. Join. /// Substrait expects one set of indices when joining two relations. /// Let's say `left` and `right` have `m` and `n` columns, respectively. The `right` @@ -876,17 +974,14 @@ pub fn make_binary_op_scalar_func( /// `col_ref(1) = col_ref(3 + 0)` /// , where `3` is the number of `left` columns (`col_ref_offset`) and `0` is the index /// of the join key column from `right` -/// * `extension_info` - Substrait extension info. Contains registered function information +/// * `extensions` - Substrait extension info. Contains registered function information #[allow(deprecated)] pub fn to_substrait_rex( ctx: &SessionContext, expr: &Expr, schema: &DFSchemaRef, col_ref_offset: usize, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result { match expr { Expr::InList(InList { @@ -896,10 +991,10 @@ pub fn to_substrait_rex( }) => { let substrait_list = list .iter() - .map(|x| to_substrait_rex(ctx, x, schema, col_ref_offset, extension_info)) + .map(|x| to_substrait_rex(ctx, x, schema, col_ref_offset, extensions)) .collect::>>()?; let substrait_expr = - to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; let substrait_or_list = Expression { rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList { @@ -909,8 +1004,7 @@ pub fn to_substrait_rex( }; if *negated { - let function_anchor = - _register_function("not".to_string(), extension_info); + let function_anchor = extensions.register_function("not".to_string()); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -936,18 +1030,12 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, )?)), }); } - // function should be resolved during `AnalyzerRule` - if let ScalarFunctionDefinition::Name(_) = fun.func_def { - return internal_err!("Function `Expr` with name should be resolved."); - } - - let function_anchor = - _register_function(fun.name().to_string(), extension_info); + let function_anchor = extensions.register_function(fun.name().to_string()); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { function_reference: function_anchor, @@ -967,58 +1055,58 @@ pub fn to_substrait_rex( if *negated { // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) let substrait_expr = - to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; let substrait_low = - to_substrait_rex(ctx, low, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, low, schema, col_ref_offset, extensions)?; let substrait_high = - to_substrait_rex(ctx, high, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, high, schema, col_ref_offset, extensions)?; let l_expr = make_binary_op_scalar_func( &substrait_expr, &substrait_low, Operator::Lt, - extension_info, + extensions, ); let r_expr = make_binary_op_scalar_func( &substrait_high, &substrait_expr, Operator::Lt, - extension_info, + extensions, ); Ok(make_binary_op_scalar_func( &l_expr, &r_expr, Operator::Or, - extension_info, + extensions, )) } else { // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) let substrait_expr = - to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; let substrait_low = - to_substrait_rex(ctx, low, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, low, schema, col_ref_offset, extensions)?; let substrait_high = - to_substrait_rex(ctx, high, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, high, schema, col_ref_offset, extensions)?; let l_expr = make_binary_op_scalar_func( &substrait_low, &substrait_expr, Operator::LtEq, - extension_info, + extensions, ); let r_expr = make_binary_op_scalar_func( &substrait_expr, &substrait_high, Operator::LtEq, - extension_info, + extensions, ); Ok(make_binary_op_scalar_func( &l_expr, &r_expr, Operator::And, - extension_info, + extensions, )) } } @@ -1027,10 +1115,10 @@ pub fn to_substrait_rex( substrait_field_ref(index + col_ref_offset) } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let l = to_substrait_rex(ctx, left, schema, col_ref_offset, extension_info)?; - let r = to_substrait_rex(ctx, right, schema, col_ref_offset, extension_info)?; + let l = to_substrait_rex(ctx, left, schema, col_ref_offset, extensions)?; + let r = to_substrait_rex(ctx, right, schema, col_ref_offset, extensions)?; - Ok(make_binary_op_scalar_func(&l, &r, *op, extension_info)) + Ok(make_binary_op_scalar_func(&l, &r, *op, extensions)) } Expr::Case(Case { expr, @@ -1047,7 +1135,7 @@ pub fn to_substrait_rex( e, schema, col_ref_offset, - extension_info, + extensions, )?), then: None, }); @@ -1060,14 +1148,14 @@ pub fn to_substrait_rex( r#if, schema, col_ref_offset, - extension_info, + extensions, )?), then: Some(to_substrait_rex( ctx, then, schema, col_ref_offset, - extension_info, + extensions, )?), }); } @@ -1079,7 +1167,7 @@ pub fn to_substrait_rex( e, schema, col_ref_offset, - extension_info, + extensions, )?)), None => None, }; @@ -1092,22 +1180,22 @@ pub fn to_substrait_rex( Ok(Expression { rex_type: Some(RexType::Cast(Box::new( substrait::proto::expression::Cast { - r#type: Some(to_substrait_type(data_type)?), + r#type: Some(to_substrait_type(data_type, true)?), input: Some(Box::new(to_substrait_rex( ctx, expr, schema, col_ref_offset, - extension_info, + extensions, )?)), failure_behavior: 0, // FAILURE_BEHAVIOR_UNSPECIFIED }, ))), }) } - Expr::Literal(value) => to_substrait_literal(value), + Expr::Literal(value) => to_substrait_literal_expr(value, extensions), Expr::Alias(Alias { expr, .. }) => { - to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info) + to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions) } Expr::WindowFunction(WindowFunction { fun, @@ -1118,7 +1206,7 @@ pub fn to_substrait_rex( null_treatment: _, }) => { // function reference - let function_anchor = _register_function(fun.to_string(), extension_info); + let function_anchor = extensions.register_function(fun.to_string()); // arguments let mut arguments: Vec = vec![]; for arg in args { @@ -1128,19 +1216,19 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, )?)), }); } // partition by expressions let partition_by = partition_by .iter() - .map(|e| to_substrait_rex(ctx, e, schema, col_ref_offset, extension_info)) + .map(|e| to_substrait_rex(ctx, e, schema, col_ref_offset, extensions)) .collect::>>()?; // order by expressions let order_by = order_by .iter() - .map(|e| substrait_sort_field(ctx, e, schema, extension_info)) + .map(|e| substrait_sort_field(ctx, e, schema, extensions)) .collect::>>()?; // window frame let bounds = to_substrait_bounds(window_frame)?; @@ -1169,7 +1257,7 @@ pub fn to_substrait_rex( *escape_char, schema, col_ref_offset, - extension_info, + extensions, ), Expr::InSubquery(InSubquery { expr, @@ -1177,10 +1265,10 @@ pub fn to_substrait_rex( negated, }) => { let substrait_expr = - to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; + to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; let subquery_plan = - to_substrait_rel(subquery.subquery.as_ref(), ctx, extension_info)?; + to_substrait_rel(subquery.subquery.as_ref(), ctx, extensions)?; let substrait_subquery = Expression { rex_type: Some(RexType::Subquery(Box::new(Subquery { @@ -1195,8 +1283,7 @@ pub fn to_substrait_rex( }))), }; if *negated { - let function_anchor = - _register_function("not".to_string(), extension_info); + let function_anchor = extensions.register_function("not".to_string()); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -1219,7 +1306,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsNull(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1227,7 +1314,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsNotNull(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1235,7 +1322,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsTrue(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1243,7 +1330,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsFalse(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1251,7 +1338,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsUnknown(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1259,7 +1346,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsNotTrue(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1267,7 +1354,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsNotFalse(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1275,7 +1362,7 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::IsNotUnknown(arg) => to_substrait_unary_scalar_fn( ctx, @@ -1283,15 +1370,15 @@ pub fn to_substrait_rex( arg, schema, col_ref_offset, - extension_info, + extensions, ), Expr::Negative(arg) => to_substrait_unary_scalar_fn( ctx, - "negative", + "negate", arg, schema, col_ref_offset, - extension_info, + extensions, ), _ => { not_impl_err!("Unsupported expression: {expr:?}") @@ -1299,180 +1386,256 @@ pub fn to_substrait_rex( } } -fn to_substrait_type(dt: &DataType) -> Result { - let default_nullability = r#type::Nullability::Required as i32; +fn to_substrait_type(dt: &DataType, nullable: bool) -> Result { + let nullability = if nullable { + r#type::Nullability::Nullable as i32 + } else { + r#type::Nullability::Required as i32 + }; match dt { DataType::Null => internal_err!("Null cast is not valid"), DataType::Boolean => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Bool(r#type::Boolean { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, })), }), DataType::Int8 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I8(r#type::I8 { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, })), }), DataType::UInt8 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I8(r#type::I8 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, - nullability: default_nullability, + type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, + nullability, })), }), DataType::Int16 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I16(r#type::I16 { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, })), }), DataType::UInt16 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I16(r#type::I16 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, - nullability: default_nullability, + type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, + nullability, })), }), DataType::Int32 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I32(r#type::I32 { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, })), }), DataType::UInt32 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I32(r#type::I32 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, - nullability: default_nullability, + type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, + nullability, })), }), DataType::Int64 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I64(r#type::I64 { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, })), }), DataType::UInt64 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::I64(r#type::I64 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, - nullability: default_nullability, + type_variation_reference: UNSIGNED_INTEGER_TYPE_VARIATION_REF, + nullability, })), }), // Float16 is not supported in Substrait DataType::Float32 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Fp32(r#type::Fp32 { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, })), }), DataType::Float64 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Fp64(r#type::Fp64 { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, })), }), - // Timezone is ignored. - DataType::Timestamp(unit, _) => { - let type_variation_reference = match unit { - TimeUnit::Second => TIMESTAMP_SECOND_TYPE_REF, - TimeUnit::Millisecond => TIMESTAMP_MILLI_TYPE_REF, - TimeUnit::Microsecond => TIMESTAMP_MICRO_TYPE_REF, - TimeUnit::Nanosecond => TIMESTAMP_NANO_TYPE_REF, + DataType::Timestamp(unit, tz) => { + let precision = match unit { + TimeUnit::Second => 0, + TimeUnit::Millisecond => 3, + TimeUnit::Microsecond => 6, + TimeUnit::Nanosecond => 9, }; - Ok(substrait::proto::Type { - kind: Some(r#type::Kind::Timestamp(r#type::Timestamp { - type_variation_reference, - nullability: default_nullability, - })), - }) + let kind = match tz { + None => r#type::Kind::PrecisionTimestamp(r#type::PrecisionTimestamp { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + precision, + }), + Some(_) => { + // If timezone is present, no matter what the actual tz value is, it indicates the + // value of the timestamp is tied to UTC epoch. That's all that Substrait cares about. + // As the timezone is lost, this conversion may be lossy for downstream use of the value. + r#type::Kind::PrecisionTimestampTz(r#type::PrecisionTimestampTz { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + precision, + }) + } + }; + Ok(substrait::proto::Type { kind: Some(kind) }) } DataType::Date32 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Date(r#type::Date { - type_variation_reference: DATE_32_TYPE_REF, - nullability: default_nullability, + type_variation_reference: DATE_32_TYPE_VARIATION_REF, + nullability, })), }), DataType::Date64 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Date(r#type::Date { - type_variation_reference: DATE_64_TYPE_REF, - nullability: default_nullability, + type_variation_reference: DATE_64_TYPE_VARIATION_REF, + nullability, })), }), + DataType::Interval(interval_unit) => { + match interval_unit { + IntervalUnit::YearMonth => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::IntervalYear(r#type::IntervalYear { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + })), + }), + IntervalUnit::DayTime => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::IntervalDay(r#type::IntervalDay { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + precision: Some(3), // DayTime precision is always milliseconds + })), + }), + IntervalUnit::MonthDayNano => { + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::IntervalCompound( + r#type::IntervalCompound { + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, + precision: 9, // nanos + }, + )), + }) + } + } + } DataType::Binary => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Binary(r#type::Binary { - type_variation_reference: DEFAULT_CONTAINER_TYPE_REF, - nullability: default_nullability, + type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, + nullability, })), }), DataType::FixedSizeBinary(length) => Ok(substrait::proto::Type { kind: Some(r#type::Kind::FixedBinary(r#type::FixedBinary { length: *length, - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, })), }), DataType::LargeBinary => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Binary(r#type::Binary { - type_variation_reference: LARGE_CONTAINER_TYPE_REF, - nullability: default_nullability, + type_variation_reference: LARGE_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::BinaryView => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Binary(r#type::Binary { + type_variation_reference: VIEW_CONTAINER_TYPE_VARIATION_REF, + nullability, })), }), DataType::Utf8 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::String(r#type::String { - type_variation_reference: DEFAULT_CONTAINER_TYPE_REF, - nullability: default_nullability, + type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, + nullability, })), }), DataType::LargeUtf8 => Ok(substrait::proto::Type { kind: Some(r#type::Kind::String(r#type::String { - type_variation_reference: LARGE_CONTAINER_TYPE_REF, - nullability: default_nullability, + type_variation_reference: LARGE_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Utf8View => Ok(substrait::proto::Type { + kind: Some(r#type::Kind::String(r#type::String { + type_variation_reference: VIEW_CONTAINER_TYPE_VARIATION_REF, + nullability, })), }), DataType::List(inner) => { - let inner_type = to_substrait_type(inner.data_type())?; + let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::List(Box::new(r#type::List { r#type: Some(Box::new(inner_type)), - type_variation_reference: DEFAULT_CONTAINER_TYPE_REF, - nullability: default_nullability, + type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, + nullability, }))), }) } DataType::LargeList(inner) => { - let inner_type = to_substrait_type(inner.data_type())?; + let inner_type = to_substrait_type(inner.data_type(), inner.is_nullable())?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::List(Box::new(r#type::List { r#type: Some(Box::new(inner_type)), - type_variation_reference: LARGE_CONTAINER_TYPE_REF, - nullability: default_nullability, + type_variation_reference: LARGE_CONTAINER_TYPE_VARIATION_REF, + nullability, }))), }) } + DataType::Map(inner, _) => match inner.data_type() { + DataType::Struct(key_and_value) if key_and_value.len() == 2 => { + let key_type = to_substrait_type( + key_and_value[0].data_type(), + key_and_value[0].is_nullable(), + )?; + let value_type = to_substrait_type( + key_and_value[1].data_type(), + key_and_value[1].is_nullable(), + )?; + Ok(substrait::proto::Type { + kind: Some(r#type::Kind::Map(Box::new(r#type::Map { + key: Some(Box::new(key_type)), + value: Some(Box::new(value_type)), + type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, + nullability, + }))), + }) + } + _ => plan_err!("Map fields must contain a Struct with exactly 2 fields"), + }, DataType::Struct(fields) => { let field_types = fields .iter() - .map(|field| to_substrait_type(field.data_type())) + .map(|field| to_substrait_type(field.data_type(), field.is_nullable())) .collect::>>()?; Ok(substrait::proto::Type { kind: Some(r#type::Kind::Struct(r#type::Struct { types: field_types, - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + nullability, })), }) } DataType::Decimal128(p, s) => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Decimal(r#type::Decimal { - type_variation_reference: DECIMAL_128_TYPE_REF, - nullability: default_nullability, + type_variation_reference: DECIMAL_128_TYPE_VARIATION_REF, + nullability, scale: *s as i32, precision: *p as i32, })), }), DataType::Decimal256(p, s) => Ok(substrait::proto::Type { kind: Some(r#type::Kind::Decimal(r#type::Decimal { - type_variation_reference: DECIMAL_256_TYPE_REF, - nullability: default_nullability, + type_variation_reference: DECIMAL_256_TYPE_VARIATION_REF, + nullability, scale: *s as i32, precision: *p as i32, })), @@ -1519,20 +1682,19 @@ fn make_substrait_like_expr( escape_char: Option, schema: &DFSchemaRef, col_ref_offset: usize, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result { let function_anchor = if ignore_case { - _register_function("ilike".to_string(), extension_info) + extensions.register_function("ilike".to_string()) } else { - _register_function("like".to_string(), extension_info) + extensions.register_function("like".to_string()) }; - let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; - let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, extension_info)?; - let escape_char = - to_substrait_literal(&ScalarValue::Utf8(escape_char.map(|c| c.to_string())))?; + let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)?; + let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, extensions)?; + let escape_char = to_substrait_literal_expr( + &ScalarValue::Utf8(escape_char.map(|c| c.to_string())), + extensions, + )?; let arguments = vec![ FunctionArgument { arg_type: Some(ArgType::Value(expr)), @@ -1556,7 +1718,7 @@ fn make_substrait_like_expr( }; if negated { - let function_anchor = _register_function("not".to_string(), extension_info); + let function_anchor = extensions.register_function("not".to_string()); Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -1574,98 +1736,38 @@ fn make_substrait_like_expr( } } +fn to_substrait_bound_offset(value: &ScalarValue) -> Option { + match value { + ScalarValue::UInt8(Some(v)) => Some(*v as i64), + ScalarValue::UInt16(Some(v)) => Some(*v as i64), + ScalarValue::UInt32(Some(v)) => Some(*v as i64), + ScalarValue::UInt64(Some(v)) => Some(*v as i64), + ScalarValue::Int8(Some(v)) => Some(*v as i64), + ScalarValue::Int16(Some(v)) => Some(*v as i64), + ScalarValue::Int32(Some(v)) => Some(*v as i64), + ScalarValue::Int64(Some(v)) => Some(*v), + _ => None, + } +} + fn to_substrait_bound(bound: &WindowFrameBound) -> Bound { match bound { WindowFrameBound::CurrentRow => Bound { kind: Some(BoundKind::CurrentRow(SubstraitBound::CurrentRow {})), }, - WindowFrameBound::Preceding(s) => match s { - ScalarValue::UInt8(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v as i64, - })), - }, - ScalarValue::UInt16(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v as i64, - })), - }, - ScalarValue::UInt32(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v as i64, - })), - }, - ScalarValue::UInt64(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v as i64, - })), - }, - ScalarValue::Int8(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v as i64, - })), - }, - ScalarValue::Int16(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v as i64, - })), - }, - ScalarValue::Int32(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v as i64, - })), + WindowFrameBound::Preceding(s) => match to_substrait_bound_offset(s) { + Some(offset) => Bound { + kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { offset })), }, - ScalarValue::Int64(Some(v)) => Bound { - kind: Some(BoundKind::Preceding(SubstraitBound::Preceding { - offset: *v, - })), - }, - _ => Bound { + None => Bound { kind: Some(BoundKind::Unbounded(SubstraitBound::Unbounded {})), }, }, - WindowFrameBound::Following(s) => match s { - ScalarValue::UInt8(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v as i64, - })), - }, - ScalarValue::UInt16(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v as i64, - })), - }, - ScalarValue::UInt32(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v as i64, - })), + WindowFrameBound::Following(s) => match to_substrait_bound_offset(s) { + Some(offset) => Bound { + kind: Some(BoundKind::Following(SubstraitBound::Following { offset })), }, - ScalarValue::UInt64(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v as i64, - })), - }, - ScalarValue::Int8(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v as i64, - })), - }, - ScalarValue::Int16(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v as i64, - })), - }, - ScalarValue::Int32(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v as i64, - })), - }, - ScalarValue::Int64(Some(v)) => Bound { - kind: Some(BoundKind::Following(SubstraitBound::Following { - offset: *v, - })), - }, - _ => Bound { + None => Bound { kind: Some(BoundKind::Unbounded(SubstraitBound::Unbounded {})), }, }, @@ -1688,73 +1790,299 @@ fn to_substrait_bounds(window_frame: &WindowFrame) -> Result<(Bound, Bound)> { )) } -fn to_substrait_literal(value: &ScalarValue) -> Result { +fn to_substrait_literal( + value: &ScalarValue, + extensions: &mut Extensions, +) -> Result { + if value.is_null() { + return Ok(Literal { + nullable: true, + type_variation_reference: DEFAULT_TYPE_VARIATION_REF, + literal_type: Some(LiteralType::Null(to_substrait_type( + &value.data_type(), + true, + )?)), + }); + } let (literal_type, type_variation_reference) = match value { - ScalarValue::Boolean(Some(b)) => (LiteralType::Boolean(*b), DEFAULT_TYPE_REF), - ScalarValue::Int8(Some(n)) => (LiteralType::I8(*n as i32), DEFAULT_TYPE_REF), - ScalarValue::UInt8(Some(n)) => { - (LiteralType::I8(*n as i32), UNSIGNED_INTEGER_TYPE_REF) - } - ScalarValue::Int16(Some(n)) => (LiteralType::I16(*n as i32), DEFAULT_TYPE_REF), - ScalarValue::UInt16(Some(n)) => { - (LiteralType::I16(*n as i32), UNSIGNED_INTEGER_TYPE_REF) + ScalarValue::Boolean(Some(b)) => { + (LiteralType::Boolean(*b), DEFAULT_TYPE_VARIATION_REF) } - ScalarValue::Int32(Some(n)) => (LiteralType::I32(*n), DEFAULT_TYPE_REF), - ScalarValue::UInt32(Some(n)) => { - (LiteralType::I32(*n as i32), UNSIGNED_INTEGER_TYPE_REF) + ScalarValue::Int8(Some(n)) => { + (LiteralType::I8(*n as i32), DEFAULT_TYPE_VARIATION_REF) } - ScalarValue::Int64(Some(n)) => (LiteralType::I64(*n), DEFAULT_TYPE_REF), - ScalarValue::UInt64(Some(n)) => { - (LiteralType::I64(*n as i64), UNSIGNED_INTEGER_TYPE_REF) - } - ScalarValue::Float32(Some(f)) => (LiteralType::Fp32(*f), DEFAULT_TYPE_REF), - ScalarValue::Float64(Some(f)) => (LiteralType::Fp64(*f), DEFAULT_TYPE_REF), - ScalarValue::TimestampSecond(Some(t), _) => { - (LiteralType::Timestamp(*t), TIMESTAMP_SECOND_TYPE_REF) + ScalarValue::UInt8(Some(n)) => ( + LiteralType::I8(*n as i32), + UNSIGNED_INTEGER_TYPE_VARIATION_REF, + ), + ScalarValue::Int16(Some(n)) => { + (LiteralType::I16(*n as i32), DEFAULT_TYPE_VARIATION_REF) } - ScalarValue::TimestampMillisecond(Some(t), _) => { - (LiteralType::Timestamp(*t), TIMESTAMP_MILLI_TYPE_REF) + ScalarValue::UInt16(Some(n)) => ( + LiteralType::I16(*n as i32), + UNSIGNED_INTEGER_TYPE_VARIATION_REF, + ), + ScalarValue::Int32(Some(n)) => (LiteralType::I32(*n), DEFAULT_TYPE_VARIATION_REF), + ScalarValue::UInt32(Some(n)) => ( + LiteralType::I32(*n as i32), + UNSIGNED_INTEGER_TYPE_VARIATION_REF, + ), + ScalarValue::Int64(Some(n)) => (LiteralType::I64(*n), DEFAULT_TYPE_VARIATION_REF), + ScalarValue::UInt64(Some(n)) => ( + LiteralType::I64(*n as i64), + UNSIGNED_INTEGER_TYPE_VARIATION_REF, + ), + ScalarValue::Float32(Some(f)) => { + (LiteralType::Fp32(*f), DEFAULT_TYPE_VARIATION_REF) } - ScalarValue::TimestampMicrosecond(Some(t), _) => { - (LiteralType::Timestamp(*t), TIMESTAMP_MICRO_TYPE_REF) + ScalarValue::Float64(Some(f)) => { + (LiteralType::Fp64(*f), DEFAULT_TYPE_VARIATION_REF) } - ScalarValue::TimestampNanosecond(Some(t), _) => { - (LiteralType::Timestamp(*t), TIMESTAMP_NANO_TYPE_REF) + ScalarValue::TimestampSecond(Some(t), None) => ( + LiteralType::PrecisionTimestamp(PrecisionTimestamp { + precision: 0, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampMillisecond(Some(t), None) => ( + LiteralType::PrecisionTimestamp(PrecisionTimestamp { + precision: 3, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampMicrosecond(Some(t), None) => ( + LiteralType::PrecisionTimestamp(PrecisionTimestamp { + precision: 6, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampNanosecond(Some(t), None) => ( + LiteralType::PrecisionTimestamp(PrecisionTimestamp { + precision: 9, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + // If timezone is present, no matter what the actual tz value is, it indicates the + // value of the timestamp is tied to UTC epoch. That's all that Substrait cares about. + // As the timezone is lost, this conversion may be lossy for downstream use of the value. + ScalarValue::TimestampSecond(Some(t), Some(_)) => ( + LiteralType::PrecisionTimestampTz(PrecisionTimestamp { + precision: 0, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampMillisecond(Some(t), Some(_)) => ( + LiteralType::PrecisionTimestampTz(PrecisionTimestamp { + precision: 3, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampMicrosecond(Some(t), Some(_)) => ( + LiteralType::PrecisionTimestampTz(PrecisionTimestamp { + precision: 6, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::TimestampNanosecond(Some(t), Some(_)) => ( + LiteralType::PrecisionTimestampTz(PrecisionTimestamp { + precision: 9, + value: *t, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::Date32(Some(d)) => { + (LiteralType::Date(*d), DATE_32_TYPE_VARIATION_REF) } - ScalarValue::Date32(Some(d)) => (LiteralType::Date(*d), DATE_32_TYPE_REF), // Date64 literal is not supported in Substrait - ScalarValue::Binary(Some(b)) => { - (LiteralType::Binary(b.clone()), DEFAULT_CONTAINER_TYPE_REF) - } - ScalarValue::LargeBinary(Some(b)) => { - (LiteralType::Binary(b.clone()), LARGE_CONTAINER_TYPE_REF) - } - ScalarValue::FixedSizeBinary(_, Some(b)) => { - (LiteralType::FixedBinary(b.clone()), DEFAULT_TYPE_REF) - } - ScalarValue::Utf8(Some(s)) => { - (LiteralType::String(s.clone()), DEFAULT_CONTAINER_TYPE_REF) - } - ScalarValue::LargeUtf8(Some(s)) => { - (LiteralType::String(s.clone()), LARGE_CONTAINER_TYPE_REF) - } + ScalarValue::IntervalYearMonth(Some(i)) => ( + LiteralType::IntervalYearToMonth(IntervalYearToMonth { + // DF only tracks total months, but there should always be 12 months in a year + years: *i / 12, + months: *i % 12, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::IntervalMonthDayNano(Some(i)) => ( + LiteralType::IntervalCompound(IntervalCompound { + interval_year_to_month: Some(IntervalYearToMonth { + years: i.months / 12, + months: i.months % 12, + }), + interval_day_to_second: Some(IntervalDayToSecond { + days: i.days, + seconds: (i.nanoseconds / NANOSECONDS) as i32, + subseconds: i.nanoseconds % NANOSECONDS, + precision_mode: Some(PrecisionMode::Precision(9)), // nanoseconds + }), + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::IntervalDayTime(Some(i)) => ( + LiteralType::IntervalDayToSecond(IntervalDayToSecond { + days: i.days, + seconds: i.milliseconds / 1000, + subseconds: (i.milliseconds % 1000) as i64, + precision_mode: Some(PrecisionMode::Precision(3)), // 3 for milliseconds + }), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::Binary(Some(b)) => ( + LiteralType::Binary(b.clone()), + DEFAULT_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::LargeBinary(Some(b)) => ( + LiteralType::Binary(b.clone()), + LARGE_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::BinaryView(Some(b)) => ( + LiteralType::Binary(b.clone()), + VIEW_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::FixedSizeBinary(_, Some(b)) => ( + LiteralType::FixedBinary(b.clone()), + DEFAULT_TYPE_VARIATION_REF, + ), + ScalarValue::Utf8(Some(s)) => ( + LiteralType::String(s.clone()), + DEFAULT_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::LargeUtf8(Some(s)) => ( + LiteralType::String(s.clone()), + LARGE_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::Utf8View(Some(s)) => ( + LiteralType::String(s.clone()), + VIEW_CONTAINER_TYPE_VARIATION_REF, + ), ScalarValue::Decimal128(v, p, s) if v.is_some() => ( LiteralType::Decimal(Decimal { value: v.unwrap().to_le_bytes().to_vec(), precision: *p as i32, scale: *s as i32, }), - DECIMAL_128_TYPE_REF, + DECIMAL_128_TYPE_VARIATION_REF, + ), + ScalarValue::List(l) => ( + convert_array_to_literal_list(l, extensions)?, + DEFAULT_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::LargeList(l) => ( + convert_array_to_literal_list(l, extensions)?, + LARGE_CONTAINER_TYPE_VARIATION_REF, + ), + ScalarValue::Map(m) => { + let map = if m.is_empty() || m.value(0).is_empty() { + let mt = to_substrait_type(m.data_type(), m.is_nullable())?; + let mt = match mt { + substrait::proto::Type { + kind: Some(r#type::Kind::Map(mt)), + } => Ok(mt.as_ref().to_owned()), + _ => exec_err!("Unexpected type for a map: {mt:?}"), + }?; + LiteralType::EmptyMap(mt) + } else { + let keys = (0..m.keys().len()) + .map(|i| { + to_substrait_literal( + &ScalarValue::try_from_array(&m.keys(), i)?, + extensions, + ) + }) + .collect::>>()?; + let values = (0..m.values().len()) + .map(|i| { + to_substrait_literal( + &ScalarValue::try_from_array(&m.values(), i)?, + extensions, + ) + }) + .collect::>>()?; + + let key_values = keys + .into_iter() + .zip(values.into_iter()) + .map(|(k, v)| { + Ok(KeyValue { + key: Some(k), + value: Some(v), + }) + }) + .collect::>>()?; + LiteralType::Map(Map { key_values }) + }; + (map, DEFAULT_CONTAINER_TYPE_VARIATION_REF) + } + ScalarValue::Struct(s) => ( + LiteralType::Struct(Struct { + fields: s + .columns() + .iter() + .map(|col| { + to_substrait_literal( + &ScalarValue::try_from_array(col, 0)?, + extensions, + ) + }) + .collect::>>()?, + }), + DEFAULT_TYPE_VARIATION_REF, + ), + _ => ( + not_impl_err!("Unsupported literal: {value:?}")?, + DEFAULT_TYPE_VARIATION_REF, ), - _ => (try_to_substrait_null(value)?, DEFAULT_TYPE_REF), }; + Ok(Literal { + nullable: false, + type_variation_reference, + literal_type: Some(literal_type), + }) +} + +fn convert_array_to_literal_list( + array: &GenericListArray, + extensions: &mut Extensions, +) -> Result { + assert_eq!(array.len(), 1); + let nested_array = array.value(0); + + let values = (0..nested_array.len()) + .map(|i| { + to_substrait_literal( + &ScalarValue::try_from_array(&nested_array, i)?, + extensions, + ) + }) + .collect::>>()?; + + if values.is_empty() { + let lt = match to_substrait_type(array.data_type(), array.is_nullable())? { + substrait::proto::Type { + kind: Some(r#type::Kind::List(lt)), + } => lt.as_ref().to_owned(), + _ => unreachable!(), + }; + Ok(LiteralType::EmptyList(lt)) + } else { + Ok(LiteralType::List(List { values })) + } +} + +fn to_substrait_literal_expr( + value: &ScalarValue, + extensions: &mut Extensions, +) -> Result { + let literal = to_substrait_literal(value, extensions)?; Ok(Expression { - rex_type: Some(RexType::Literal(Literal { - nullable: true, - type_variation_reference, - literal_type: Some(literal_type), - })), + rex_type: Some(RexType::Literal(literal)), }) } @@ -1765,14 +2093,10 @@ fn to_substrait_unary_scalar_fn( arg: &Expr, schema: &DFSchemaRef, col_ref_offset: usize, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result { - let function_anchor = _register_function(fn_name.to_string(), extension_info); - let substrait_expr = - to_substrait_rex(ctx, arg, schema, col_ref_offset, extension_info)?; + let function_anchor = extensions.register_function(fn_name.to_string()); + let substrait_expr = to_substrait_rex(ctx, arg, schema, col_ref_offset, extensions)?; Ok(Expression { rex_type: Some(RexType::ScalarFunction(ScalarFunction { @@ -1787,166 +2111,6 @@ fn to_substrait_unary_scalar_fn( }) } -fn try_to_substrait_null(v: &ScalarValue) -> Result { - let default_nullability = r#type::Nullability::Nullable as i32; - match v { - ScalarValue::Boolean(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Bool(r#type::Boolean { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::Int8(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::I8(r#type::I8 { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::UInt8(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::I8(r#type::I8 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::Int16(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::I16(r#type::I16 { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::UInt16(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::I16(r#type::I16 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::Int32(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::I32(r#type::I32 { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::UInt32(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::I32(r#type::I32 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::Int64(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::I64(r#type::I64 { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::UInt64(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::I64(r#type::I64 { - type_variation_reference: UNSIGNED_INTEGER_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::Float32(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Fp32(r#type::Fp32 { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::Float64(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Fp64(r#type::Fp64 { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::TimestampSecond(None, _) => { - Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Timestamp(r#type::Timestamp { - type_variation_reference: TIMESTAMP_SECOND_TYPE_REF, - nullability: default_nullability, - })), - })) - } - ScalarValue::TimestampMillisecond(None, _) => { - Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Timestamp(r#type::Timestamp { - type_variation_reference: TIMESTAMP_MILLI_TYPE_REF, - nullability: default_nullability, - })), - })) - } - ScalarValue::TimestampMicrosecond(None, _) => { - Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Timestamp(r#type::Timestamp { - type_variation_reference: TIMESTAMP_MICRO_TYPE_REF, - nullability: default_nullability, - })), - })) - } - ScalarValue::TimestampNanosecond(None, _) => { - Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Timestamp(r#type::Timestamp { - type_variation_reference: TIMESTAMP_NANO_TYPE_REF, - nullability: default_nullability, - })), - })) - } - ScalarValue::Date32(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Date(r#type::Date { - type_variation_reference: DATE_32_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::Date64(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Date(r#type::Date { - type_variation_reference: DATE_64_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::Binary(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Binary(r#type::Binary { - type_variation_reference: DEFAULT_CONTAINER_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::LargeBinary(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Binary(r#type::Binary { - type_variation_reference: LARGE_CONTAINER_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::FixedSizeBinary(_, None) => { - Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Binary(r#type::Binary { - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, - })), - })) - } - ScalarValue::Utf8(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::String(r#type::String { - type_variation_reference: DEFAULT_CONTAINER_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::LargeUtf8(None) => Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::String(r#type::String { - type_variation_reference: LARGE_CONTAINER_TYPE_REF, - nullability: default_nullability, - })), - })), - ScalarValue::Decimal128(None, p, s) => { - Ok(LiteralType::Null(substrait::proto::Type { - kind: Some(r#type::Kind::Decimal(r#type::Decimal { - scale: *s as i32, - precision: *p as i32, - type_variation_reference: DEFAULT_TYPE_REF, - nullability: default_nullability, - })), - })) - } - // TODO: Extend support for remaining data types - _ => not_impl_err!("Unsupported literal: {v:?}"), - } -} - /// Try to convert an [Expr] to a [FieldReference]. /// Returns `Err` if the [Expr] is not a [Expr::Column]. fn try_to_substrait_field_reference( @@ -1974,33 +2138,26 @@ fn try_to_substrait_field_reference( fn substrait_sort_field( ctx: &SessionContext, - expr: &Expr, + sort: &Sort, schema: &DFSchemaRef, - extension_info: &mut ( - Vec, - HashMap, - ), + extensions: &mut Extensions, ) -> Result { - match expr { - Expr::Sort(Sort { - expr, - asc, - nulls_first, - }) => { - let e = to_substrait_rex(ctx, expr, schema, 0, extension_info)?; - let d = match (asc, nulls_first) { - (true, true) => SortDirection::AscNullsFirst, - (true, false) => SortDirection::AscNullsLast, - (false, true) => SortDirection::DescNullsFirst, - (false, false) => SortDirection::DescNullsLast, - }; - Ok(SortField { - expr: Some(e), - sort_kind: Some(SortKind::Direction(d as i32)), - }) - } - _ => not_impl_err!("Expecting sort expression but got {expr:?}"), - } + let Sort { + expr, + asc, + nulls_first, + } = sort; + let e = to_substrait_rex(ctx, expr, schema, 0, extensions)?; + let d = match (asc, nulls_first) { + (true, true) => SortDirection::AscNullsFirst, + (true, false) => SortDirection::AscNullsLast, + (false, true) => SortDirection::DescNullsFirst, + (false, false) => SortDirection::DescNullsLast, + }; + Ok(SortField { + expr: Some(e), + sort_kind: Some(SortKind::Direction(d as i32)), + }) } fn substrait_field_ref(index: usize) -> Result { @@ -2021,9 +2178,18 @@ fn substrait_field_ref(index: usize) -> Result { #[cfg(test)] mod test { - use crate::logical_plan::consumer::from_substrait_literal; - use super::*; + use crate::logical_plan::consumer::{ + from_substrait_extended_expr, from_substrait_literal_without_names, + from_substrait_named_struct, from_substrait_type_without_names, + }; + use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; + use datafusion::arrow::array::{ + GenericListArray, Int64Builder, MapBuilder, StringBuilder, + }; + use datafusion::arrow::datatypes::{Field, Fields, Schema}; + use datafusion::common::scalar::ScalarStructBuilder; + use datafusion::common::DFSchema; #[test] fn round_trip_literals() -> Result<()> { @@ -2059,22 +2225,275 @@ mod test { round_trip_literal(ScalarValue::UInt64(Some(u64::MIN)))?; round_trip_literal(ScalarValue::UInt64(Some(u64::MAX)))?; + for (ts, tz) in [ + (Some(12345), None), + (None, None), + (Some(12345), Some("UTC".into())), + (None, Some("UTC".into())), + ] { + round_trip_literal(ScalarValue::TimestampSecond(ts, tz.clone()))?; + round_trip_literal(ScalarValue::TimestampMillisecond(ts, tz.clone()))?; + round_trip_literal(ScalarValue::TimestampMicrosecond(ts, tz.clone()))?; + round_trip_literal(ScalarValue::TimestampNanosecond(ts, tz))?; + } + + round_trip_literal(ScalarValue::List(ScalarValue::new_list_nullable( + &[ScalarValue::Float32(Some(1.0))], + &DataType::Float32, + )))?; + round_trip_literal(ScalarValue::List(ScalarValue::new_list_nullable( + &[], + &DataType::Float32, + )))?; + round_trip_literal(ScalarValue::List(Arc::new(GenericListArray::new_null( + Field::new_list_field(DataType::Float32, true).into(), + 1, + ))))?; + round_trip_literal(ScalarValue::LargeList(ScalarValue::new_large_list( + &[ScalarValue::Float32(Some(1.0))], + &DataType::Float32, + )))?; + round_trip_literal(ScalarValue::LargeList(ScalarValue::new_large_list( + &[], + &DataType::Float32, + )))?; + round_trip_literal(ScalarValue::LargeList(Arc::new( + GenericListArray::new_null( + Field::new_list_field(DataType::Float32, true).into(), + 1, + ), + )))?; + + // Null map + let mut map_builder = + MapBuilder::new(None, StringBuilder::new(), Int64Builder::new()); + map_builder.append(false)?; + round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?; + + // Empty map + let mut map_builder = + MapBuilder::new(None, StringBuilder::new(), Int64Builder::new()); + map_builder.append(true)?; + round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?; + + // Valid map + let mut map_builder = + MapBuilder::new(None, StringBuilder::new(), Int64Builder::new()); + map_builder.keys().append_value("key1"); + map_builder.keys().append_value("key2"); + map_builder.values().append_value(1); + map_builder.values().append_value(2); + map_builder.append(true)?; + round_trip_literal(ScalarValue::Map(Arc::new(map_builder.finish())))?; + + let c0 = Field::new("c0", DataType::Boolean, true); + let c1 = Field::new("c1", DataType::Int32, true); + let c2 = Field::new("c2", DataType::Utf8, true); + round_trip_literal( + ScalarStructBuilder::new() + .with_scalar(c0.to_owned(), ScalarValue::Boolean(Some(true))) + .with_scalar(c1.to_owned(), ScalarValue::Int32(Some(1))) + .with_scalar(c2.to_owned(), ScalarValue::Utf8(None)) + .build()?, + )?; + round_trip_literal(ScalarStructBuilder::new_null(vec![c0, c1, c2]))?; + + round_trip_literal(ScalarValue::IntervalYearMonth(Some(17)))?; + round_trip_literal(ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNano::new(17, 25, 1234567890), + )))?; + round_trip_literal(ScalarValue::IntervalDayTime(Some(IntervalDayTime::new( + 57, 123456, + ))))?; + Ok(()) } fn round_trip_literal(scalar: ScalarValue) -> Result<()> { println!("Checking round trip of {scalar:?}"); - let substrait = to_substrait_literal(&scalar)?; - let Expression { - rex_type: Some(RexType::Literal(substrait_literal)), - } = substrait - else { - panic!("Expected Literal expression, got {substrait:?}"); - }; - - let roundtrip_scalar = from_substrait_literal(&substrait_literal)?; + let mut extensions = Extensions::default(); + let substrait_literal = to_substrait_literal(&scalar, &mut extensions)?; + let roundtrip_scalar = + from_substrait_literal_without_names(&substrait_literal, &extensions)?; assert_eq!(scalar, roundtrip_scalar); Ok(()) } + + #[test] + fn round_trip_types() -> Result<()> { + round_trip_type(DataType::Boolean)?; + round_trip_type(DataType::Int8)?; + round_trip_type(DataType::UInt8)?; + round_trip_type(DataType::Int16)?; + round_trip_type(DataType::UInt16)?; + round_trip_type(DataType::Int32)?; + round_trip_type(DataType::UInt32)?; + round_trip_type(DataType::Int64)?; + round_trip_type(DataType::UInt64)?; + round_trip_type(DataType::Float32)?; + round_trip_type(DataType::Float64)?; + + for tz in [None, Some("UTC".into())] { + round_trip_type(DataType::Timestamp(TimeUnit::Second, tz.clone()))?; + round_trip_type(DataType::Timestamp(TimeUnit::Millisecond, tz.clone()))?; + round_trip_type(DataType::Timestamp(TimeUnit::Microsecond, tz.clone()))?; + round_trip_type(DataType::Timestamp(TimeUnit::Nanosecond, tz))?; + } + + round_trip_type(DataType::Date32)?; + round_trip_type(DataType::Date64)?; + round_trip_type(DataType::Binary)?; + round_trip_type(DataType::FixedSizeBinary(10))?; + round_trip_type(DataType::LargeBinary)?; + round_trip_type(DataType::BinaryView)?; + round_trip_type(DataType::Utf8)?; + round_trip_type(DataType::LargeUtf8)?; + round_trip_type(DataType::Utf8View)?; + round_trip_type(DataType::Decimal128(10, 2))?; + round_trip_type(DataType::Decimal256(30, 2))?; + + round_trip_type(DataType::List( + Field::new_list_field(DataType::Int32, true).into(), + ))?; + round_trip_type(DataType::LargeList( + Field::new_list_field(DataType::Int32, true).into(), + ))?; + + round_trip_type(DataType::Map( + Field::new_struct( + "entries", + [ + Field::new("key", DataType::Utf8, false).into(), + Field::new("value", DataType::Int32, true).into(), + ], + false, + ) + .into(), + false, + ))?; + + round_trip_type(DataType::Struct( + vec![ + Field::new("c0", DataType::Int32, true), + Field::new("c1", DataType::Utf8, true), + ] + .into(), + ))?; + + round_trip_type(DataType::Interval(IntervalUnit::YearMonth))?; + round_trip_type(DataType::Interval(IntervalUnit::MonthDayNano))?; + round_trip_type(DataType::Interval(IntervalUnit::DayTime))?; + + Ok(()) + } + + fn round_trip_type(dt: DataType) -> Result<()> { + println!("Checking round trip of {dt:?}"); + + // As DataFusion doesn't consider nullability as a property of the type, but field, + // it doesn't matter if we set nullability to true or false here. + let substrait = to_substrait_type(&dt, true)?; + let roundtrip_dt = + from_substrait_type_without_names(&substrait, &Extensions::default())?; + assert_eq!(dt, roundtrip_dt); + Ok(()) + } + + #[test] + fn named_struct_names() -> Result<()> { + let schema = DFSchemaRef::new(DFSchema::try_from(Schema::new(vec![ + Field::new("int", DataType::Int32, true), + Field::new( + "struct", + DataType::Struct(Fields::from(vec![Field::new( + "inner", + DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))), + true, + )])), + true, + ), + Field::new("trailer", DataType::Float64, true), + ]))?); + + let named_struct = to_substrait_named_struct(&schema)?; + + // Struct field names should be flattened DFS style + // List field names should be omitted + assert_eq!( + named_struct.names, + vec!["int", "struct", "inner", "trailer"] + ); + + let roundtrip_schema = + from_substrait_named_struct(&named_struct, &Extensions::default())?; + assert_eq!(schema.as_ref(), &roundtrip_schema); + Ok(()) + } + + #[tokio::test] + async fn extended_expressions() -> Result<()> { + let ctx = SessionContext::new(); + + // One expression, empty input schema + let expr = Expr::Literal(ScalarValue::Int32(Some(42))); + let field = Field::new("out", DataType::Int32, false); + let empty_schema = DFSchemaRef::new(DFSchema::empty()); + let substrait = + to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &ctx)?; + let roundtrip_expr = from_substrait_extended_expr(&ctx, &substrait).await?; + + assert_eq!(roundtrip_expr.input_schema, empty_schema); + assert_eq!(roundtrip_expr.exprs.len(), 1); + + let (rt_expr, rt_field) = roundtrip_expr.exprs.first().unwrap(); + assert_eq!(rt_field, &field); + assert_eq!(rt_expr, &expr); + + // Multiple expressions, with column references + let expr1 = Expr::Column("c0".into()); + let expr2 = Expr::Column("c1".into()); + let out1 = Field::new("out1", DataType::Int32, true); + let out2 = Field::new("out2", DataType::Utf8, true); + let input_schema = DFSchemaRef::new(DFSchema::try_from(Schema::new(vec![ + Field::new("c0", DataType::Int32, true), + Field::new("c1", DataType::Utf8, true), + ]))?); + + let substrait = to_substrait_extended_expr( + &[(&expr1, &out1), (&expr2, &out2)], + &input_schema, + &ctx, + )?; + let roundtrip_expr = from_substrait_extended_expr(&ctx, &substrait).await?; + + assert_eq!(roundtrip_expr.input_schema, input_schema); + assert_eq!(roundtrip_expr.exprs.len(), 2); + + let mut exprs = roundtrip_expr.exprs.into_iter(); + + let (rt_expr, rt_field) = exprs.next().unwrap(); + assert_eq!(rt_field, out1); + assert_eq!(rt_expr, expr1); + + let (rt_expr, rt_field) = exprs.next().unwrap(); + assert_eq!(rt_field, out2); + assert_eq!(rt_expr, expr2); + + Ok(()) + } + + #[tokio::test] + async fn invalid_extended_expression() { + let ctx = SessionContext::new(); + + // Not ok if input schema is missing field referenced by expr + let expr = Expr::Column("missing".into()); + let field = Field::new("out", DataType::Int32, false); + let empty_schema = DFSchemaRef::new(DFSchema::empty()); + + let err = to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, &ctx); + + assert!(matches!(err, Err(DataFusionError::SchemaError(_, _)))); + } } diff --git a/datafusion/substrait/src/physical_plan/consumer.rs b/datafusion/substrait/src/physical_plan/consumer.rs index 11ddb91ad391..a8f8ce048e0f 100644 --- a/datafusion/substrait/src/physical_plan/consumer.rs +++ b/datafusion/substrait/src/physical_plan/consumer.rs @@ -18,23 +18,30 @@ use std::collections::HashMap; use std::sync::Arc; -use datafusion::arrow::datatypes::Schema; -use datafusion::common::not_impl_err; +use datafusion::arrow::datatypes::{DataType, Field, Schema}; +use datafusion::common::{not_impl_err, substrait_err}; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec}; use datafusion::error::{DataFusionError, Result}; -use datafusion::physical_plan::{ExecutionPlan, Statistics}; +use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionContext; use async_recursion::async_recursion; use chrono::DateTime; use object_store::ObjectMeta; +use substrait::proto::r#type::{Kind, Nullability}; use substrait::proto::read_rel::local_files::file_or_files::PathType; +use substrait::proto::Type; use substrait::proto::{ expression::MaskExpression, read_rel::ReadType, rel::RelType, Rel, }; +use crate::variation_const::{ + DEFAULT_CONTAINER_TYPE_VARIATION_REF, LARGE_CONTAINER_TYPE_VARIATION_REF, + VIEW_CONTAINER_TYPE_VARIATION_REF, +}; + /// Convert Substrait Rel to DataFusion ExecutionPlan #[async_recursion] pub async fn from_substrait_rel( @@ -42,17 +49,42 @@ pub async fn from_substrait_rel( rel: &Rel, _extensions: &HashMap, ) -> Result> { + let mut base_config; + match &rel.rel_type { Some(RelType::Read(read)) => { if read.filter.is_some() || read.best_effort_filter.is_some() { return not_impl_err!("Read with filter is not supported"); } - if read.base_schema.is_some() { - return not_impl_err!("Read with schema is not supported"); - } + if read.advanced_extension.is_some() { return not_impl_err!("Read with AdvancedExtension is not supported"); } + + let Some(schema) = read.base_schema.as_ref() else { + return substrait_err!("Missing base schema in the read"); + }; + + let Some(r#struct) = schema.r#struct.as_ref() else { + return substrait_err!("Missing struct in the schema"); + }; + + match schema + .names + .iter() + .zip(r#struct.types.iter()) + .map(|(name, r#type)| to_field(name, r#type)) + .collect::>>() + { + Ok(fields) => { + base_config = FileScanConfig::new( + ObjectStoreUrl::local_filesystem(), + Arc::new(Schema::new(fields)), + ); + } + Err(e) => return Err(e), + }; + match &read.as_ref().read_type { Some(ReadType::LocalFiles(files)) => { let mut file_groups = vec![]; @@ -93,6 +125,7 @@ pub async fn from_substrait_rel( }, partition_values: vec![], range: None, + statistics: None, extensions: None, }; @@ -103,16 +136,7 @@ pub async fn from_substrait_rel( file_groups[part_index].push(partitioned_file) } - let mut base_config = FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_schema: Arc::new(Schema::empty()), - file_groups, - statistics: Statistics::new_unknown(&Schema::empty()), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - }; + base_config = base_config.with_file_groups(file_groups); if let Some(MaskExpression { select, .. }) = &read.projection { if let Some(projection) = &select.as_ref() { @@ -125,12 +149,8 @@ pub async fn from_substrait_rel( } } - Ok(Arc::new(ParquetExec::new( - base_config, - None, - None, - Default::default(), - )) as Arc) + Ok(ParquetExec::builder(base_config).build_arc() + as Arc) } _ => not_impl_err!( "Only LocalFile reads are supported when parsing physical" @@ -140,3 +160,67 @@ pub async fn from_substrait_rel( _ => not_impl_err!("Unsupported RelType: {:?}", rel.rel_type), } } + +fn to_field(name: &String, r#type: &Type) -> Result { + let Some(kind) = r#type.kind.as_ref() else { + return substrait_err!("Missing kind in the type with name {}", name); + }; + + let mut nullable = false; + let data_type = match kind { + Kind::Bool(boolean) => { + nullable = is_nullable(boolean.nullability); + Ok(DataType::Boolean) + } + Kind::I64(i64) => { + nullable = is_nullable(i64.nullability); + Ok(DataType::Int64) + } + Kind::Fp64(fp64) => { + nullable = is_nullable(fp64.nullability); + Ok(DataType::Float64) + } + Kind::String(string) => { + nullable = is_nullable(string.nullability); + match string.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Utf8), + LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeUtf8), + VIEW_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Utf8View), + _ => substrait_err!( + "Invalid type variation found for substrait string type class: {}", + string.type_variation_reference + ), + } + } + Kind::Binary(binary) => { + nullable = is_nullable(binary.nullability); + match binary.type_variation_reference { + DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::Binary), + LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::LargeBinary), + VIEW_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::BinaryView), + _ => substrait_err!( + "Invalid type variation found for substrait binary type class: {}", + binary.type_variation_reference + ), + } + } + _ => substrait_err!( + "Unsupported kind: {:?} in the type with name {}", + kind, + name + ), + }?; + + Ok(Field::new(name, data_type, nullable)) +} + +fn is_nullable(nullability: i32) -> bool { + let Ok(nullability) = Nullability::try_from(nullability) else { + return true; + }; + + match nullability { + Nullability::Nullable | Nullability::Unspecified => true, + Nullability::Required => false, + } +} diff --git a/datafusion/substrait/src/physical_plan/producer.rs b/datafusion/substrait/src/physical_plan/producer.rs index ad87d7afb058..7279785ae873 100644 --- a/datafusion/substrait/src/physical_plan/producer.rs +++ b/datafusion/substrait/src/physical_plan/producer.rs @@ -15,12 +15,16 @@ // specific language governing permissions and limitations // under the License. +use datafusion::arrow::datatypes::DataType; use datafusion::datasource::physical_plan::ParquetExec; use datafusion::error::{DataFusionError, Result}; use datafusion::physical_plan::{displayable, ExecutionPlan}; use std::collections::HashMap; +use substrait::proto::expression::mask_expression::{StructItem, StructSelect}; use substrait::proto::expression::MaskExpression; -use substrait::proto::extensions; +use substrait::proto::r#type::{ + Binary, Boolean, Fp64, Kind, Nullability, String as SubstraitString, Struct, I64, +}; use substrait::proto::read_rel::local_files::file_or_files::ParquetReadOptions; use substrait::proto::read_rel::local_files::file_or_files::{FileFormat, PathType}; use substrait::proto::read_rel::local_files::FileOrFiles; @@ -29,6 +33,12 @@ use substrait::proto::read_rel::ReadType; use substrait::proto::rel::RelType; use substrait::proto::ReadRel; use substrait::proto::Rel; +use substrait::proto::{extensions, NamedStruct, Type}; + +use crate::variation_const::{ + DEFAULT_CONTAINER_TYPE_VARIATION_REF, LARGE_CONTAINER_TYPE_VARIATION_REF, + VIEW_CONTAINER_TYPE_VARIATION_REF, +}; /// Convert DataFusion ExecutionPlan to Substrait Rel pub fn to_substrait_rel( @@ -55,15 +65,56 @@ pub fn to_substrait_rel( } } + let mut names = vec![]; + let mut types = vec![]; + + for field in base_config.file_schema.fields.iter() { + match to_substrait_type(field.data_type(), field.is_nullable()) { + Ok(t) => { + names.push(field.name().clone()); + types.push(t); + } + Err(e) => return Err(e), + } + } + + let type_info = Struct { + types, + // FIXME: duckdb doesn't set this field, keep it as default variant 0. + // https://github.com/duckdb/substrait/blob/b6f56643cb11d52de0e32c24a01dfd5947df62be/src/to_substrait.cpp#L1106-L1127 + type_variation_reference: 0, + nullability: Nullability::Required.into(), + }; + + let mut select_struct = None; + if let Some(projection) = base_config.projection.as_ref() { + let struct_items = projection + .iter() + .map(|index| StructItem { + field: *index as i32, + // FIXME: duckdb sets this to None, but it's not clear why. + // https://github.com/duckdb/substrait/blob/b6f56643cb11d52de0e32c24a01dfd5947df62be/src/to_substrait.cpp#L1191 + child: None, + }) + .collect(); + + select_struct = Some(StructSelect { struct_items }); + } + Ok(Box::new(Rel { rel_type: Some(RelType::Read(Box::new(ReadRel { common: None, - base_schema: None, + base_schema: Some(NamedStruct { + names, + r#struct: Some(type_info), + }), filter: None, best_effort_filter: None, projection: Some(MaskExpression { - select: None, - maintain_singular_struct: false, + select: select_struct, + // FIXME: duckdb set this to true, but it's not clear why. + // https://github.com/duckdb/substrait/blob/b6f56643cb11d52de0e32c24a01dfd5947df62be/src/to_substrait.cpp#L1186. + maintain_singular_struct: true, }), advanced_extension: None, read_type: Some(ReadType::LocalFiles(LocalFiles { @@ -79,3 +130,72 @@ pub fn to_substrait_rel( ))) } } + +// see https://github.com/duckdb/substrait/blob/b6f56643cb11d52de0e32c24a01dfd5947df62be/src/to_substrait.cpp#L954-L1094. +fn to_substrait_type(data_type: &DataType, nullable: bool) -> Result { + let nullability = if nullable { + Nullability::Nullable.into() + } else { + Nullability::Required.into() + }; + + match data_type { + DataType::Boolean => Ok(Type { + kind: Some(Kind::Bool(Boolean { + type_variation_reference: 0, + nullability, + })), + }), + DataType::Int64 => Ok(Type { + kind: Some(Kind::I64(I64 { + type_variation_reference: 0, + nullability, + })), + }), + DataType::Float64 => Ok(Type { + kind: Some(Kind::Fp64(Fp64 { + type_variation_reference: 0, + nullability, + })), + }), + DataType::Utf8 => Ok(Type { + kind: Some(Kind::String(SubstraitString { + type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::LargeUtf8 => Ok(Type { + kind: Some(Kind::String(SubstraitString { + type_variation_reference: LARGE_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Utf8View => Ok(Type { + kind: Some(Kind::String(SubstraitString { + type_variation_reference: VIEW_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::Binary => Ok(Type { + kind: Some(Kind::Binary(Binary { + type_variation_reference: DEFAULT_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::LargeBinary => Ok(Type { + kind: Some(Kind::Binary(Binary { + type_variation_reference: LARGE_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + DataType::BinaryView => Ok(Type { + kind: Some(Kind::Binary(Binary { + type_variation_reference: VIEW_CONTAINER_TYPE_VARIATION_REF, + nullability, + })), + }), + _ => Err(DataFusionError::Substrait(format!( + "Logical type {data_type} not implemented as substrait type" + ))), + } +} diff --git a/datafusion/substrait/src/variation_const.rs b/datafusion/substrait/src/variation_const.rs index 27ef15153bd8..58774db424da 100644 --- a/datafusion/substrait/src/variation_const.rs +++ b/datafusion/substrait/src/variation_const.rs @@ -18,22 +18,94 @@ //! Type variation constants //! //! To add support for types not in the [core specification](https://substrait.io/types/type_classes/), -//! we make use of the [simple extensions](https://substrait.io/extensions/#simple-extensions) of substrait -//! type. This module contains the constants used to identify the type variation. +//! we make use of the [simple extensions] of substrait type. This module contains the constants used +//! to identify the type variation. //! //! The rules of type variations here are: //! - Default type reference is 0. It is used when the actual type is the same with the original type. //! - Extended variant type references start from 1, and ususlly increase by 1. +//! +//! TODO: Definitions here are not the final form. All the non-system-preferred variations will be defined +//! using [simple extensions] as per the [spec of type_variations](https://substrait.io/types/type_variations/) +//! +//! +//! [simple extensions]: (https://substrait.io/extensions/#simple-extensions) + +// For [type variations](https://substrait.io/types/type_variations/#type-variations) in substrait. +// Type variations are used to represent different types based on one type class. +// TODO: Define as extensions: + +/// The "system-preferred" variation (i.e., no variation). +pub const DEFAULT_TYPE_VARIATION_REF: u32 = 0; +pub const UNSIGNED_INTEGER_TYPE_VARIATION_REF: u32 = 1; + +#[deprecated(since = "42.0.0", note = "Use `PrecisionTimestamp(Tz)` type instead")] +pub const TIMESTAMP_SECOND_TYPE_VARIATION_REF: u32 = 0; +#[deprecated(since = "42.0.0", note = "Use `PrecisionTimestamp(Tz)` type instead")] +pub const TIMESTAMP_MILLI_TYPE_VARIATION_REF: u32 = 1; +#[deprecated(since = "42.0.0", note = "Use `PrecisionTimestamp(Tz)` type instead")] +pub const TIMESTAMP_MICRO_TYPE_VARIATION_REF: u32 = 2; +#[deprecated(since = "42.0.0", note = "Use `PrecisionTimestamp(Tz)` type instead")] +pub const TIMESTAMP_NANO_TYPE_VARIATION_REF: u32 = 3; + +pub const DATE_32_TYPE_VARIATION_REF: u32 = 0; +pub const DATE_64_TYPE_VARIATION_REF: u32 = 1; +pub const DEFAULT_CONTAINER_TYPE_VARIATION_REF: u32 = 0; +pub const LARGE_CONTAINER_TYPE_VARIATION_REF: u32 = 1; +pub const VIEW_CONTAINER_TYPE_VARIATION_REF: u32 = 2; +pub const DECIMAL_128_TYPE_VARIATION_REF: u32 = 0; +pub const DECIMAL_256_TYPE_VARIATION_REF: u32 = 1; + +// For [user-defined types](https://substrait.io/types/type_classes/#user-defined-types). +/// For [`DataType::Interval`] with [`IntervalUnit::YearMonth`]. +/// +/// An `i32` for elapsed whole months. See also [`ScalarValue::IntervalYearMonth`] +/// for the literal definition in DataFusion. +/// +/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval +/// [`IntervalUnit::YearMonth`]: datafusion::arrow::datatypes::IntervalUnit::YearMonth +/// [`ScalarValue::IntervalYearMonth`]: datafusion::common::ScalarValue::IntervalYearMonth +#[deprecated(since = "41.0.0", note = "Use Substrait `IntervalYear` type instead")] +pub const INTERVAL_YEAR_MONTH_TYPE_REF: u32 = 1; + +/// For [`DataType::Interval`] with [`IntervalUnit::DayTime`]. +/// +/// An `i64` as: +/// - days: `i32` +/// - milliseconds: `i32` +/// +/// See also [`ScalarValue::IntervalDayTime`] for the literal definition in DataFusion. +/// +/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval +/// [`IntervalUnit::DayTime`]: datafusion::arrow::datatypes::IntervalUnit::DayTime +/// [`ScalarValue::IntervalDayTime`]: datafusion::common::ScalarValue::IntervalDayTime +#[deprecated(since = "41.0.0", note = "Use Substrait `IntervalDay` type instead")] +pub const INTERVAL_DAY_TIME_TYPE_REF: u32 = 2; + +/// For [`DataType::Interval`] with [`IntervalUnit::MonthDayNano`]. +/// +/// An `i128` as: +/// - months: `i32` +/// - days: `i32` +/// - nanoseconds: `i64` +/// +/// See also [`ScalarValue::IntervalMonthDayNano`] for the literal definition in DataFusion. +/// +/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval +/// [`IntervalUnit::MonthDayNano`]: datafusion::arrow::datatypes::IntervalUnit::MonthDayNano +/// [`ScalarValue::IntervalMonthDayNano`]: datafusion::common::ScalarValue::IntervalMonthDayNano +#[deprecated( + since = "41.0.0", + note = "Use Substrait `IntervalCompund` type instead" +)] +pub const INTERVAL_MONTH_DAY_NANO_TYPE_REF: u32 = 3; -pub const DEFAULT_TYPE_REF: u32 = 0; -pub const UNSIGNED_INTEGER_TYPE_REF: u32 = 1; -pub const TIMESTAMP_SECOND_TYPE_REF: u32 = 0; -pub const TIMESTAMP_MILLI_TYPE_REF: u32 = 1; -pub const TIMESTAMP_MICRO_TYPE_REF: u32 = 2; -pub const TIMESTAMP_NANO_TYPE_REF: u32 = 3; -pub const DATE_32_TYPE_REF: u32 = 0; -pub const DATE_64_TYPE_REF: u32 = 1; -pub const DEFAULT_CONTAINER_TYPE_REF: u32 = 0; -pub const LARGE_CONTAINER_TYPE_REF: u32 = 1; -pub const DECIMAL_128_TYPE_REF: u32 = 0; -pub const DECIMAL_256_TYPE_REF: u32 = 1; +/// For [`DataType::Interval`] with [`IntervalUnit::MonthDayNano`]. +/// +/// [`DataType::Interval`]: datafusion::arrow::datatypes::DataType::Interval +/// [`IntervalUnit::MonthDayNano`]: datafusion::arrow::datatypes::IntervalUnit::MonthDayNano +#[deprecated( + since = "43.0.0", + note = "Use Substrait `IntervalCompund` type instead" +)] +pub const INTERVAL_MONTH_DAY_NANO_TYPE_NAME: &str = "interval-month-day-nano"; diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs b/datafusion/substrait/tests/cases/consumer_integration.rs new file mode 100644 index 000000000000..bc38ef82977f --- /dev/null +++ b/datafusion/substrait/tests/cases/consumer_integration.rs @@ -0,0 +1,458 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! TPCH `substrait_consumer` tests +//! +//! This module tests that substrait plans as json encoded protobuf can be +//! correctly read as DataFusion plans. +//! +//! The input data comes from + +#[cfg(test)] +mod tests { + use crate::utils::test::add_plan_schemas_to_ctx; + use datafusion::common::Result; + use datafusion::prelude::SessionContext; + use datafusion_substrait::logical_plan::consumer::from_substrait_plan; + use std::fs::File; + use std::io::BufReader; + use substrait::proto::Plan; + + async fn tpch_plan_to_string(query_id: i32) -> Result { + let path = + format!("tests/testdata/tpch_substrait_plans/query_{query_id:02}_plan.json"); + let proto = serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json"); + + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto)?; + let plan = from_substrait_plan(&ctx, &proto).await?; + Ok(format!("{}", plan)) + } + + #[tokio::test] + async fn tpch_test_01() -> Result<()> { + let plan_str = tpch_plan_to_string(1).await?; + assert_eq!( + plan_str, + "Projection: LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS, sum(LINEITEM.L_QUANTITY) AS SUM_QTY, sum(LINEITEM.L_EXTENDEDPRICE) AS SUM_BASE_PRICE, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS SUM_DISC_PRICE, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT * Int32(1) + LINEITEM.L_TAX) AS SUM_CHARGE, avg(LINEITEM.L_QUANTITY) AS AVG_QTY, avg(LINEITEM.L_EXTENDEDPRICE) AS AVG_PRICE, avg(LINEITEM.L_DISCOUNT) AS AVG_DISC, count(Int64(1)) AS COUNT_ORDER\ + \n Sort: LINEITEM.L_RETURNFLAG ASC NULLS LAST, LINEITEM.L_LINESTATUS ASC NULLS LAST\ + \n Aggregate: groupBy=[[LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS]], aggr=[[sum(LINEITEM.L_QUANTITY), sum(LINEITEM.L_EXTENDEDPRICE), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT * Int32(1) + LINEITEM.L_TAX), avg(LINEITEM.L_QUANTITY), avg(LINEITEM.L_EXTENDEDPRICE), avg(LINEITEM.L_DISCOUNT), count(Int64(1))]]\ + \n Projection: LINEITEM.L_RETURNFLAG, LINEITEM.L_LINESTATUS, LINEITEM.L_QUANTITY, LINEITEM.L_EXTENDEDPRICE, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT), LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) * (CAST(Int32(1) AS Decimal128(15, 2)) + LINEITEM.L_TAX), LINEITEM.L_DISCOUNT\ + \n Filter: LINEITEM.L_SHIPDATE <= Date32(\"1998-12-01\") - IntervalDayTime(\"IntervalDayTime { days: 0, milliseconds: 10368000 }\")\ + \n TableScan: LINEITEM" + ); + Ok(()) + } + + #[tokio::test] + async fn tpch_test_02() -> Result<()> { + let plan_str = tpch_plan_to_string(2).await?; + assert_eq!( + plan_str, + "Limit: skip=0, fetch=100\ + \n Sort: SUPPLIER.S_ACCTBAL DESC NULLS FIRST, NATION.N_NAME ASC NULLS LAST, SUPPLIER.S_NAME ASC NULLS LAST, PART.P_PARTKEY ASC NULLS LAST\ + \n Projection: SUPPLIER.S_ACCTBAL, SUPPLIER.S_NAME, NATION.N_NAME, PART.P_PARTKEY, PART.P_MFGR, SUPPLIER.S_ADDRESS, SUPPLIER.S_PHONE, SUPPLIER.S_COMMENT\ + \n Filter: PART.P_PARTKEY = PARTSUPP.PS_PARTKEY AND SUPPLIER.S_SUPPKEY = PARTSUPP.PS_SUPPKEY AND PART.P_SIZE = Int32(15) AND PART.P_TYPE LIKE CAST(Utf8(\"%BRASS\") AS Utf8) AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_REGIONKEY = REGION.R_REGIONKEY AND REGION.R_NAME = Utf8(\"EUROPE\") AND PARTSUPP.PS_SUPPLYCOST = ()\ + \n Subquery:\ + \n Aggregate: groupBy=[[]], aggr=[[min(PARTSUPP.PS_SUPPLYCOST)]]\ + \n Projection: PARTSUPP.PS_SUPPLYCOST\ + \n Filter: PARTSUPP.PS_PARTKEY = PARTSUPP.PS_PARTKEY AND SUPPLIER.S_SUPPKEY = PARTSUPP.PS_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_REGIONKEY = REGION.R_REGIONKEY AND REGION.R_NAME = Utf8(\"EUROPE\")\ + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ + \n TableScan: PARTSUPP\ + \n TableScan: SUPPLIER\ + \n TableScan: NATION\ + \n TableScan: REGION\ + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ + \n TableScan: PART\ + \n TableScan: SUPPLIER\ + \n TableScan: PARTSUPP\ + \n TableScan: NATION\ + \n TableScan: REGION" + ); + Ok(()) + } + + #[tokio::test] + async fn tpch_test_03() -> Result<()> { + let plan_str = tpch_plan_to_string(3).await?; + assert_eq!( + plan_str, + "Projection: LINEITEM.L_ORDERKEY, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS REVENUE, ORDERS.O_ORDERDATE, ORDERS.O_SHIPPRIORITY\ + \n Limit: skip=0, fetch=10\ + \n Sort: sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) DESC NULLS FIRST, ORDERS.O_ORDERDATE ASC NULLS LAST\ + \n Projection: LINEITEM.L_ORDERKEY, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT), ORDERS.O_ORDERDATE, ORDERS.O_SHIPPRIORITY\ + \n Aggregate: groupBy=[[LINEITEM.L_ORDERKEY, ORDERS.O_ORDERDATE, ORDERS.O_SHIPPRIORITY]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]]\ + \n Projection: LINEITEM.L_ORDERKEY, ORDERS.O_ORDERDATE, ORDERS.O_SHIPPRIORITY, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT)\ + \n Filter: CUSTOMER.C_MKTSEGMENT = Utf8(\"BUILDING\") AND CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY AND LINEITEM.L_ORDERKEY = ORDERS.O_ORDERKEY AND ORDERS.O_ORDERDATE < CAST(Utf8(\"1995-03-15\") AS Date32) AND LINEITEM.L_SHIPDATE > CAST(Utf8(\"1995-03-15\") AS Date32)\ + \n Cross Join: \ + \n Cross Join: \ + \n TableScan: LINEITEM\ + \n TableScan: CUSTOMER\ + \n TableScan: ORDERS" + ); + Ok(()) + } + + #[tokio::test] + async fn tpch_test_04() -> Result<()> { + let plan_str = tpch_plan_to_string(4).await?; + assert_eq!( + plan_str, + "Projection: ORDERS.O_ORDERPRIORITY, count(Int64(1)) AS ORDER_COUNT\ + \n Sort: ORDERS.O_ORDERPRIORITY ASC NULLS LAST\ + \n Aggregate: groupBy=[[ORDERS.O_ORDERPRIORITY]], aggr=[[count(Int64(1))]]\ + \n Projection: ORDERS.O_ORDERPRIORITY\ + \n Filter: ORDERS.O_ORDERDATE >= CAST(Utf8(\"1993-07-01\") AS Date32) AND ORDERS.O_ORDERDATE < CAST(Utf8(\"1993-10-01\") AS Date32) AND EXISTS ()\ + \n Subquery:\ + \n Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_ORDERKEY AND LINEITEM.L_COMMITDATE < LINEITEM.L_RECEIPTDATE\ + \n TableScan: LINEITEM\ + \n TableScan: ORDERS" + ); + Ok(()) + } + + #[tokio::test] + async fn tpch_test_05() -> Result<()> { + let plan_str = tpch_plan_to_string(5).await?; + assert_eq!( + plan_str, + "Projection: NATION.N_NAME, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS REVENUE\ + \n Sort: sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) DESC NULLS FIRST\ + \n Aggregate: groupBy=[[NATION.N_NAME]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]]\ + \n Projection: NATION.N_NAME, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT)\ + \n Filter: CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY AND LINEITEM.L_ORDERKEY = ORDERS.O_ORDERKEY AND LINEITEM.L_SUPPKEY = SUPPLIER.S_SUPPKEY AND CUSTOMER.C_NATIONKEY = SUPPLIER.S_NATIONKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_REGIONKEY = REGION.R_REGIONKEY AND REGION.R_NAME = Utf8(\"ASIA\") AND ORDERS.O_ORDERDATE >= CAST(Utf8(\"1994-01-01\") AS Date32) AND ORDERS.O_ORDERDATE < CAST(Utf8(\"1995-01-01\") AS Date32)\ + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ + \n TableScan: CUSTOMER\ + \n TableScan: ORDERS\ + \n TableScan: LINEITEM\ + \n TableScan: SUPPLIER\ + \n TableScan: NATION\ + \n TableScan: REGION" + ); + Ok(()) + } + + #[tokio::test] + async fn tpch_test_06() -> Result<()> { + let plan_str = tpch_plan_to_string(6).await?; + assert_eq!( + plan_str, + "Aggregate: groupBy=[[]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * LINEITEM.L_DISCOUNT) AS REVENUE]]\ + \n Projection: LINEITEM.L_EXTENDEDPRICE * LINEITEM.L_DISCOUNT\ + \n Filter: LINEITEM.L_SHIPDATE >= CAST(Utf8(\"1994-01-01\") AS Date32) AND LINEITEM.L_SHIPDATE < CAST(Utf8(\"1995-01-01\") AS Date32) AND LINEITEM.L_DISCOUNT >= Decimal128(Some(5),3,2) AND LINEITEM.L_DISCOUNT <= Decimal128(Some(7),3,2) AND LINEITEM.L_QUANTITY < CAST(Int32(24) AS Decimal128(15, 2))\ + \n TableScan: LINEITEM" + ); + Ok(()) + } + + #[ignore] + #[tokio::test] + async fn tpch_test_07() -> Result<()> { + let plan_str = tpch_plan_to_string(7).await?; + assert_eq!(plan_str, "Missing support for enum function arguments"); + Ok(()) + } + + #[ignore] + #[tokio::test] + async fn tpch_test_08() -> Result<()> { + let plan_str = tpch_plan_to_string(8).await?; + assert_eq!(plan_str, "Missing support for enum function arguments"); + Ok(()) + } + + #[ignore] + #[tokio::test] + async fn tpch_test_09() -> Result<()> { + let plan_str = tpch_plan_to_string(9).await?; + assert_eq!(plan_str, "Missing support for enum function arguments"); + Ok(()) + } + + #[tokio::test] + async fn tpch_test_10() -> Result<()> { + let plan_str = tpch_plan_to_string(10).await?; + assert_eq!( + plan_str, + "Projection: CUSTOMER.C_CUSTKEY, CUSTOMER.C_NAME, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS REVENUE, CUSTOMER.C_ACCTBAL, NATION.N_NAME, CUSTOMER.C_ADDRESS, CUSTOMER.C_PHONE, CUSTOMER.C_COMMENT\ + \n Limit: skip=0, fetch=20\ + \n Sort: sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) DESC NULLS FIRST\ + \n Projection: CUSTOMER.C_CUSTKEY, CUSTOMER.C_NAME, sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT), CUSTOMER.C_ACCTBAL, NATION.N_NAME, CUSTOMER.C_ADDRESS, CUSTOMER.C_PHONE, CUSTOMER.C_COMMENT\ + \n Aggregate: groupBy=[[CUSTOMER.C_CUSTKEY, CUSTOMER.C_NAME, CUSTOMER.C_ACCTBAL, CUSTOMER.C_PHONE, NATION.N_NAME, CUSTOMER.C_ADDRESS, CUSTOMER.C_COMMENT]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]]\ + \n Projection: CUSTOMER.C_CUSTKEY, CUSTOMER.C_NAME, CUSTOMER.C_ACCTBAL, CUSTOMER.C_PHONE, NATION.N_NAME, CUSTOMER.C_ADDRESS, CUSTOMER.C_COMMENT, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT)\ + \n Filter: CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY AND LINEITEM.L_ORDERKEY = ORDERS.O_ORDERKEY AND ORDERS.O_ORDERDATE >= CAST(Utf8(\"1993-10-01\") AS Date32) AND ORDERS.O_ORDERDATE < CAST(Utf8(\"1994-01-01\") AS Date32) AND LINEITEM.L_RETURNFLAG = Utf8(\"R\") AND CUSTOMER.C_NATIONKEY = NATION.N_NATIONKEY\ + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ + \n TableScan: CUSTOMER\ + \n TableScan: ORDERS\ + \n TableScan: LINEITEM\ + \n TableScan: NATION" + ); + Ok(()) + } + + #[tokio::test] + async fn tpch_test_11() -> Result<()> { + let plan_str = tpch_plan_to_string(11).await?; + assert_eq!( + plan_str, + "Projection: PARTSUPP.PS_PARTKEY, sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY) AS value\ + \n Sort: sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY) DESC NULLS FIRST\ + \n Filter: sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY) > ()\ + \n Subquery:\ + \n Projection: sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY) * Decimal128(Some(1000000),11,10)\ + \n Aggregate: groupBy=[[]], aggr=[[sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY)]]\ + \n Projection: PARTSUPP.PS_SUPPLYCOST * CAST(PARTSUPP.PS_AVAILQTY AS Decimal128(19, 0))\ + \n Filter: PARTSUPP.PS_SUPPKEY = SUPPLIER.S_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8(\"JAPAN\")\ + \n Cross Join: \ + \n Cross Join: \ + \n TableScan: PARTSUPP\ + \n TableScan: SUPPLIER\ + \n TableScan: NATION\ + \n Aggregate: groupBy=[[PARTSUPP.PS_PARTKEY]], aggr=[[sum(PARTSUPP.PS_SUPPLYCOST * PARTSUPP.PS_AVAILQTY)]]\ + \n Projection: PARTSUPP.PS_PARTKEY, PARTSUPP.PS_SUPPLYCOST * CAST(PARTSUPP.PS_AVAILQTY AS Decimal128(19, 0))\ + \n Filter: PARTSUPP.PS_SUPPKEY = SUPPLIER.S_SUPPKEY AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8(\"JAPAN\")\ + \n Cross Join: \ + \n Cross Join: \ + \n TableScan: PARTSUPP\ + \n TableScan: SUPPLIER\ + \n TableScan: NATION" + ); + Ok(()) + } + + #[tokio::test] + async fn tpch_test_12() -> Result<()> { + let plan_str = tpch_plan_to_string(12).await?; + assert_eq!( + plan_str, + "Projection: LINEITEM.L_SHIPMODE, sum(CASE WHEN ORDERS.O_ORDERPRIORITY = Utf8(\"1-URGENT\") OR ORDERS.O_ORDERPRIORITY = Utf8(\"2-HIGH\") THEN Int32(1) ELSE Int32(0) END) AS HIGH_LINE_COUNT, sum(CASE WHEN ORDERS.O_ORDERPRIORITY != Utf8(\"1-URGENT\") AND ORDERS.O_ORDERPRIORITY != Utf8(\"2-HIGH\") THEN Int32(1) ELSE Int32(0) END) AS LOW_LINE_COUNT\ + \n Sort: LINEITEM.L_SHIPMODE ASC NULLS LAST\ + \n Aggregate: groupBy=[[LINEITEM.L_SHIPMODE]], aggr=[[sum(CASE WHEN ORDERS.O_ORDERPRIORITY = Utf8(\"1-URGENT\") OR ORDERS.O_ORDERPRIORITY = Utf8(\"2-HIGH\") THEN Int32(1) ELSE Int32(0) END), sum(CASE WHEN ORDERS.O_ORDERPRIORITY != Utf8(\"1-URGENT\") AND ORDERS.O_ORDERPRIORITY != Utf8(\"2-HIGH\") THEN Int32(1) ELSE Int32(0) END)]]\ + \n Projection: LINEITEM.L_SHIPMODE, CASE WHEN ORDERS.O_ORDERPRIORITY = Utf8(\"1-URGENT\") OR ORDERS.O_ORDERPRIORITY = Utf8(\"2-HIGH\") THEN Int32(1) ELSE Int32(0) END, CASE WHEN ORDERS.O_ORDERPRIORITY != Utf8(\"1-URGENT\") AND ORDERS.O_ORDERPRIORITY != Utf8(\"2-HIGH\") THEN Int32(1) ELSE Int32(0) END\ + \n Filter: ORDERS.O_ORDERKEY = LINEITEM.L_ORDERKEY AND (LINEITEM.L_SHIPMODE = CAST(Utf8(\"MAIL\") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8(\"SHIP\") AS Utf8)) AND LINEITEM.L_COMMITDATE < LINEITEM.L_RECEIPTDATE AND LINEITEM.L_SHIPDATE < LINEITEM.L_COMMITDATE AND LINEITEM.L_RECEIPTDATE >= CAST(Utf8(\"1994-01-01\") AS Date32) AND LINEITEM.L_RECEIPTDATE < CAST(Utf8(\"1995-01-01\") AS Date32)\ + \n Cross Join: \ + \n TableScan: ORDERS\ + \n TableScan: LINEITEM" + ); + Ok(()) + } + + #[tokio::test] + async fn tpch_test_13() -> Result<()> { + let plan_str = tpch_plan_to_string(13).await?; + assert_eq!( + plan_str, + "Projection: count(ORDERS.O_ORDERKEY) AS C_COUNT, count(Int64(1)) AS CUSTDIST\ + \n Sort: count(Int64(1)) DESC NULLS FIRST, count(ORDERS.O_ORDERKEY) DESC NULLS FIRST\ + \n Projection: count(ORDERS.O_ORDERKEY), count(Int64(1))\ + \n Aggregate: groupBy=[[count(ORDERS.O_ORDERKEY)]], aggr=[[count(Int64(1))]]\ + \n Projection: count(ORDERS.O_ORDERKEY)\ + \n Aggregate: groupBy=[[CUSTOMER.C_CUSTKEY]], aggr=[[count(ORDERS.O_ORDERKEY)]]\ + \n Projection: CUSTOMER.C_CUSTKEY, ORDERS.O_ORDERKEY\ + \n Left Join: CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY Filter: NOT ORDERS.O_COMMENT LIKE CAST(Utf8(\"%special%requests%\") AS Utf8)\ + \n TableScan: CUSTOMER\ + \n TableScan: ORDERS" + ); + Ok(()) + } + + #[tokio::test] + async fn tpch_test_14() -> Result<()> { + let plan_str = tpch_plan_to_string(14).await?; + assert_eq!( + plan_str, + "Projection: Decimal128(Some(10000),5,2) * sum(CASE WHEN PART.P_TYPE LIKE Utf8(\"PROMO%\") THEN LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT ELSE Decimal128(Some(0),19,4) END) / sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS PROMO_REVENUE\ + \n Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN PART.P_TYPE LIKE Utf8(\"PROMO%\") THEN LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT ELSE Decimal128(Some(0),19,4) END), sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT)]]\ + \n Projection: CASE WHEN PART.P_TYPE LIKE CAST(Utf8(\"PROMO%\") AS Utf8) THEN LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT) ELSE Decimal128(Some(0),19,4) END, LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT)\ + \n Filter: LINEITEM.L_PARTKEY = PART.P_PARTKEY AND LINEITEM.L_SHIPDATE >= Date32(\"1995-09-01\") AND LINEITEM.L_SHIPDATE < CAST(Utf8(\"1995-10-01\") AS Date32)\ + \n Cross Join: \ + \n TableScan: LINEITEM\ + \n TableScan: PART" + ); + Ok(()) + } + + #[ignore] + #[tokio::test] + async fn tpch_test_15() -> Result<()> { + let plan_str = tpch_plan_to_string(15).await?; + assert_eq!(plan_str, "Test file is empty"); + Ok(()) + } + + #[tokio::test] + async fn tpch_test_16() -> Result<()> { + let plan_str = tpch_plan_to_string(16).await?; + assert_eq!( + plan_str, + "Projection: PART.P_BRAND, PART.P_TYPE, PART.P_SIZE, count(DISTINCT PARTSUPP.PS_SUPPKEY) AS SUPPLIER_CNT\ + \n Sort: count(DISTINCT PARTSUPP.PS_SUPPKEY) DESC NULLS FIRST, PART.P_BRAND ASC NULLS LAST, PART.P_TYPE ASC NULLS LAST, PART.P_SIZE ASC NULLS LAST\ + \n Aggregate: groupBy=[[PART.P_BRAND, PART.P_TYPE, PART.P_SIZE]], aggr=[[count(DISTINCT PARTSUPP.PS_SUPPKEY)]]\ + \n Projection: PART.P_BRAND, PART.P_TYPE, PART.P_SIZE, PARTSUPP.PS_SUPPKEY\ + \n Filter: PART.P_PARTKEY = PARTSUPP.PS_PARTKEY AND PART.P_BRAND != Utf8(\"Brand#45\") AND NOT PART.P_TYPE LIKE CAST(Utf8(\"MEDIUM POLISHED%\") AS Utf8) AND (PART.P_SIZE = Int32(49) OR PART.P_SIZE = Int32(14) OR PART.P_SIZE = Int32(23) OR PART.P_SIZE = Int32(45) OR PART.P_SIZE = Int32(19) OR PART.P_SIZE = Int32(3) OR PART.P_SIZE = Int32(36) OR PART.P_SIZE = Int32(9)) AND NOT PARTSUPP.PS_SUPPKEY IN ()\ + \n Subquery:\ + \n Projection: SUPPLIER.S_SUPPKEY\ + \n Filter: SUPPLIER.S_COMMENT LIKE CAST(Utf8(\"%Customer%Complaints%\") AS Utf8)\ + \n TableScan: SUPPLIER\ + \n Cross Join: \ + \n TableScan: PARTSUPP\ + \n TableScan: PART" + ); + Ok(()) + } + + #[ignore] + #[tokio::test] + async fn tpch_test_17() -> Result<()> { + let plan_str = tpch_plan_to_string(17).await?; + assert_eq!(plan_str, "panics due to out of bounds field access"); + Ok(()) + } + + #[tokio::test] + async fn tpch_test_18() -> Result<()> { + let plan_str = tpch_plan_to_string(18).await?; + assert_eq!( + plan_str, + "Projection: CUSTOMER.C_NAME, CUSTOMER.C_CUSTKEY, ORDERS.O_ORDERKEY, ORDERS.O_ORDERDATE, ORDERS.O_TOTALPRICE, sum(LINEITEM.L_QUANTITY) AS EXPR$5\ + \n Limit: skip=0, fetch=100\ + \n Sort: ORDERS.O_TOTALPRICE DESC NULLS FIRST, ORDERS.O_ORDERDATE ASC NULLS LAST\ + \n Aggregate: groupBy=[[CUSTOMER.C_NAME, CUSTOMER.C_CUSTKEY, ORDERS.O_ORDERKEY, ORDERS.O_ORDERDATE, ORDERS.O_TOTALPRICE]], aggr=[[sum(LINEITEM.L_QUANTITY)]]\ + \n Projection: CUSTOMER.C_NAME, CUSTOMER.C_CUSTKEY, ORDERS.O_ORDERKEY, ORDERS.O_ORDERDATE, ORDERS.O_TOTALPRICE, LINEITEM.L_QUANTITY\ + \n Filter: ORDERS.O_ORDERKEY IN () AND CUSTOMER.C_CUSTKEY = ORDERS.O_CUSTKEY AND ORDERS.O_ORDERKEY = LINEITEM.L_ORDERKEY\ + \n Subquery:\ + \n Projection: LINEITEM.L_ORDERKEY\ + \n Filter: sum(LINEITEM.L_QUANTITY) > CAST(Int32(300) AS Decimal128(15, 2))\ + \n Aggregate: groupBy=[[LINEITEM.L_ORDERKEY]], aggr=[[sum(LINEITEM.L_QUANTITY)]]\ + \n Projection: LINEITEM.L_ORDERKEY, LINEITEM.L_QUANTITY\ + \n TableScan: LINEITEM\ + \n Cross Join: \ + \n Cross Join: \ + \n TableScan: CUSTOMER\ + \n TableScan: ORDERS\ + \n TableScan: LINEITEM" + ); + Ok(()) + } + #[tokio::test] + async fn tpch_test_19() -> Result<()> { + let plan_str = tpch_plan_to_string(19).await?; + assert_eq!( + plan_str, + "Aggregate: groupBy=[[]], aggr=[[sum(LINEITEM.L_EXTENDEDPRICE * Int32(1) - LINEITEM.L_DISCOUNT) AS REVENUE]]\ + \n Projection: LINEITEM.L_EXTENDEDPRICE * (CAST(Int32(1) AS Decimal128(15, 2)) - LINEITEM.L_DISCOUNT)\ + \n Filter: PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8(\"Brand#12\") AND (PART.P_CONTAINER = CAST(Utf8(\"SM CASE\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"SM BOX\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"SM PACK\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"SM PKG\") AS Utf8)) AND LINEITEM.L_QUANTITY >= CAST(Int32(1) AS Decimal128(15, 2)) AND LINEITEM.L_QUANTITY <= CAST(Int32(1) + Int32(10) AS Decimal128(15, 2)) AND PART.P_SIZE >= Int32(1) AND PART.P_SIZE <= Int32(5) AND (LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR\") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR REG\") AS Utf8)) AND LINEITEM.L_SHIPINSTRUCT = Utf8(\"DELIVER IN PERSON\") OR PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8(\"Brand#23\") AND (PART.P_CONTAINER = CAST(Utf8(\"MED BAG\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"MED BOX\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"MED PKG\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"MED PACK\") AS Utf8)) AND LINEITEM.L_QUANTITY >= CAST(Int32(10) AS Decimal128(15, 2)) AND LINEITEM.L_QUANTITY <= CAST(Int32(10) + Int32(10) AS Decimal128(15, 2)) AND PART.P_SIZE >= Int32(1) AND PART.P_SIZE <= Int32(10) AND (LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR\") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR REG\") AS Utf8)) AND LINEITEM.L_SHIPINSTRUCT = Utf8(\"DELIVER IN PERSON\") OR PART.P_PARTKEY = LINEITEM.L_PARTKEY AND PART.P_BRAND = Utf8(\"Brand#34\") AND (PART.P_CONTAINER = CAST(Utf8(\"LG CASE\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"LG BOX\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"LG PACK\") AS Utf8) OR PART.P_CONTAINER = CAST(Utf8(\"LG PKG\") AS Utf8)) AND LINEITEM.L_QUANTITY >= CAST(Int32(20) AS Decimal128(15, 2)) AND LINEITEM.L_QUANTITY <= CAST(Int32(20) + Int32(10) AS Decimal128(15, 2)) AND PART.P_SIZE >= Int32(1) AND PART.P_SIZE <= Int32(15) AND (LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR\") AS Utf8) OR LINEITEM.L_SHIPMODE = CAST(Utf8(\"AIR REG\") AS Utf8)) AND LINEITEM.L_SHIPINSTRUCT = Utf8(\"DELIVER IN PERSON\")\ + \n Cross Join: \ + \n TableScan: LINEITEM\ + \n TableScan: PART" + ); + Ok(()) + } + + #[tokio::test] + async fn tpch_test_20() -> Result<()> { + let plan_str = tpch_plan_to_string(20).await?; + assert_eq!( + plan_str, + "Sort: SUPPLIER.S_NAME ASC NULLS LAST\ + \n Projection: SUPPLIER.S_NAME, SUPPLIER.S_ADDRESS\ + \n Filter: SUPPLIER.S_SUPPKEY IN () AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8(\"CANADA\")\ + \n Subquery:\ + \n Projection: PARTSUPP.PS_SUPPKEY\ + \n Filter: PARTSUPP.PS_PARTKEY IN () AND CAST(PARTSUPP.PS_AVAILQTY AS Decimal128(19, 0)) > ()\ + \n Subquery:\ + \n Projection: PART.P_PARTKEY\ + \n Filter: PART.P_NAME LIKE CAST(Utf8(\"forest%\") AS Utf8)\ + \n TableScan: PART\ + \n Subquery:\ + \n Projection: Decimal128(Some(5),2,1) * sum(LINEITEM.L_QUANTITY)\ + \n Aggregate: groupBy=[[]], aggr=[[sum(LINEITEM.L_QUANTITY)]]\ + \n Projection: LINEITEM.L_QUANTITY\ + \n Filter: LINEITEM.L_PARTKEY = LINEITEM.L_ORDERKEY AND LINEITEM.L_SUPPKEY = LINEITEM.L_PARTKEY AND LINEITEM.L_SHIPDATE >= CAST(Utf8(\"1994-01-01\") AS Date32) AND LINEITEM.L_SHIPDATE < CAST(Utf8(\"1995-01-01\") AS Date32)\ + \n TableScan: LINEITEM\ + \n TableScan: PARTSUPP\ + \n Cross Join: \ + \n TableScan: SUPPLIER\ + \n TableScan: NATION" + ); + Ok(()) + } + + #[tokio::test] + async fn tpch_test_21() -> Result<()> { + let plan_str = tpch_plan_to_string(21).await?; + assert_eq!( + plan_str, + "Projection: SUPPLIER.S_NAME, count(Int64(1)) AS NUMWAIT\ + \n Limit: skip=0, fetch=100\ + \n Sort: count(Int64(1)) DESC NULLS FIRST, SUPPLIER.S_NAME ASC NULLS LAST\ + \n Aggregate: groupBy=[[SUPPLIER.S_NAME]], aggr=[[count(Int64(1))]]\ + \n Projection: SUPPLIER.S_NAME\ + \n Filter: SUPPLIER.S_SUPPKEY = LINEITEM.L_SUPPKEY AND ORDERS.O_ORDERKEY = LINEITEM.L_ORDERKEY AND ORDERS.O_ORDERSTATUS = Utf8(\"F\") AND LINEITEM.L_RECEIPTDATE > LINEITEM.L_COMMITDATE AND EXISTS () AND NOT EXISTS () AND SUPPLIER.S_NATIONKEY = NATION.N_NATIONKEY AND NATION.N_NAME = Utf8(\"SAUDI ARABIA\")\ + \n Subquery:\ + \n Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_TAX AND LINEITEM.L_SUPPKEY != LINEITEM.L_LINESTATUS\ + \n TableScan: LINEITEM\ + \n Subquery:\ + \n Filter: LINEITEM.L_ORDERKEY = LINEITEM.L_TAX AND LINEITEM.L_SUPPKEY != LINEITEM.L_LINESTATUS AND LINEITEM.L_RECEIPTDATE > LINEITEM.L_COMMITDATE\ + \n TableScan: LINEITEM\ + \n Cross Join: \ + \n Cross Join: \ + \n Cross Join: \ + \n TableScan: SUPPLIER\ + \n TableScan: LINEITEM\ + \n TableScan: ORDERS\ + \n TableScan: NATION" + ); + Ok(()) + } + + #[tokio::test] + async fn tpch_test_22() -> Result<()> { + let plan_str = tpch_plan_to_string(22).await?; + assert_eq!( + plan_str, + "Projection: substr(CUSTOMER.C_PHONE,Int32(1),Int32(2)) AS CNTRYCODE, count(Int64(1)) AS NUMCUST, sum(CUSTOMER.C_ACCTBAL) AS TOTACCTBAL\ + \n Sort: substr(CUSTOMER.C_PHONE,Int32(1),Int32(2)) ASC NULLS LAST\ + \n Aggregate: groupBy=[[substr(CUSTOMER.C_PHONE,Int32(1),Int32(2))]], aggr=[[count(Int64(1)), sum(CUSTOMER.C_ACCTBAL)]]\ + \n Projection: substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)), CUSTOMER.C_ACCTBAL\ + \n Filter: (substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"13\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"31\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"23\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"29\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"30\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"18\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"17\") AS Utf8)) AND CUSTOMER.C_ACCTBAL > () AND NOT EXISTS ()\ + \n Subquery:\ + \n Aggregate: groupBy=[[]], aggr=[[avg(CUSTOMER.C_ACCTBAL)]]\ + \n Projection: CUSTOMER.C_ACCTBAL\ + \n Filter: CUSTOMER.C_ACCTBAL > Decimal128(Some(0),3,2) AND (substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"13\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"31\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"23\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"29\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"30\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"18\") AS Utf8) OR substr(CUSTOMER.C_PHONE, Int32(1), Int32(2)) = CAST(Utf8(\"17\") AS Utf8))\ + \n TableScan: CUSTOMER\ + \n Subquery:\ + \n Filter: ORDERS.O_CUSTKEY = ORDERS.O_ORDERKEY\ + \n TableScan: ORDERS\ + \n TableScan: CUSTOMER" + ); + Ok(()) + } +} diff --git a/datafusion/substrait/tests/cases/emit_kind_tests.rs b/datafusion/substrait/tests/cases/emit_kind_tests.rs new file mode 100644 index 000000000000..ac66177ed796 --- /dev/null +++ b/datafusion/substrait/tests/cases/emit_kind_tests.rs @@ -0,0 +1,127 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for Emit Kind usage + +#[cfg(test)] +mod tests { + use crate::utils::test::{add_plan_schemas_to_ctx, read_json}; + + use datafusion::common::Result; + use datafusion::execution::SessionStateBuilder; + use datafusion::prelude::{CsvReadOptions, SessionConfig, SessionContext}; + use datafusion_substrait::logical_plan::consumer::from_substrait_plan; + use datafusion_substrait::logical_plan::producer::to_substrait_plan; + + #[tokio::test] + async fn project_respects_direct_emit_kind() -> Result<()> { + let proto_plan = read_json( + "tests/testdata/test_plans/emit_kind/direct_on_project.substrait.json", + ); + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; + let plan = from_substrait_plan(&ctx, &proto_plan).await?; + + let plan_str = format!("{}", plan); + + assert_eq!( + plan_str, + "Projection: DATA.A AS a, DATA.B AS b, DATA.A + Int64(1) AS add1\ + \n TableScan: DATA" + ); + Ok(()) + } + + #[tokio::test] + async fn handle_emit_as_project() -> Result<()> { + let proto_plan = read_json( + "tests/testdata/test_plans/emit_kind/emit_on_filter.substrait.json", + ); + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; + let plan = from_substrait_plan(&ctx, &proto_plan).await?; + + let plan_str = format!("{}", plan); + + assert_eq!( + plan_str, + // Note that duplicate references in the remap are aliased + "Projection: DATA.B, DATA.A AS A1, DATA.A AS DATA.A__temp__0 AS A2\ + \n Filter: DATA.B = Int64(2)\ + \n TableScan: DATA" + ); + Ok(()) + } + + async fn make_context() -> Result { + let state = SessionStateBuilder::new() + .with_config(SessionConfig::default()) + .with_default_features() + .build(); + let ctx = SessionContext::new_with_state(state); + ctx.register_csv("data", "tests/testdata/data.csv", CsvReadOptions::default()) + .await?; + Ok(ctx) + } + + #[tokio::test] + async fn handle_emit_as_project_with_volatile_expr() -> Result<()> { + let ctx = make_context().await?; + + let df = ctx + .sql("SELECT random() AS c1, a + 1 AS c2 FROM data") + .await?; + + let plan = df.into_unoptimized_plan(); + assert_eq!( + format!("{}", plan), + "Projection: random() AS c1, data.a + Int64(1) AS c2\ + \n TableScan: data" + ); + + let proto = to_substrait_plan(&plan, &ctx)?; + let plan2 = from_substrait_plan(&ctx, &proto).await?; + // note how the Projections are not flattened + assert_eq!( + format!("{}", plan2), + "Projection: random() AS c1, data.a + Int64(1) AS c2\ + \n Projection: data.a, data.b, data.c, data.d, data.e, data.f, random(), data.a + Int64(1)\ + \n TableScan: data" + ); + Ok(()) + } + + #[tokio::test] + async fn handle_emit_as_project_without_volatile_exprs() -> Result<()> { + let ctx = make_context().await?; + let df = ctx.sql("SELECT a + 1, b + 2 FROM data").await?; + + let plan = df.into_unoptimized_plan(); + assert_eq!( + format!("{}", plan), + "Projection: data.a + Int64(1), data.b + Int64(2)\ + \n TableScan: data" + ); + + let proto = to_substrait_plan(&plan, &ctx)?; + let plan2 = from_substrait_plan(&ctx, &proto).await?; + + let plan1str = format!("{plan}"); + let plan2str = format!("{plan2}"); + assert_eq!(plan1str, plan2str); + + Ok(()) + } +} diff --git a/datafusion/substrait/tests/cases/function_test.rs b/datafusion/substrait/tests/cases/function_test.rs new file mode 100644 index 000000000000..b136b0af19c2 --- /dev/null +++ b/datafusion/substrait/tests/cases/function_test.rs @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for Function Compatibility + +#[cfg(test)] +mod tests { + use crate::utils::test::{add_plan_schemas_to_ctx, read_json}; + + use datafusion::common::Result; + use datafusion::prelude::SessionContext; + use datafusion_substrait::logical_plan::consumer::from_substrait_plan; + + #[tokio::test] + async fn contains_function_test() -> Result<()> { + let proto_plan = read_json("tests/testdata/contains_plan.substrait.json"); + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; + let plan = from_substrait_plan(&ctx, &proto_plan).await?; + + let plan_str = format!("{}", plan); + + assert_eq!( + plan_str, + "Projection: nation.n_name\ + \n Filter: contains(nation.n_name, Utf8(\"IA\"))\ + \n TableScan: nation" + ); + Ok(()) + } +} diff --git a/datafusion/substrait/tests/cases/logical_plans.rs b/datafusion/substrait/tests/cases/logical_plans.rs new file mode 100644 index 000000000000..f4e34af35d78 --- /dev/null +++ b/datafusion/substrait/tests/cases/logical_plans.rs @@ -0,0 +1,94 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for reading substrait plans produced by other systems + +#[cfg(test)] +mod tests { + use crate::utils::test::{add_plan_schemas_to_ctx, read_json}; + use datafusion::common::Result; + use datafusion::dataframe::DataFrame; + use datafusion::prelude::SessionContext; + use datafusion_substrait::logical_plan::consumer::from_substrait_plan; + + #[tokio::test] + async fn scalar_function_compound_signature() -> Result<()> { + // DataFusion currently produces Substrait that refers to functions only by their name. + // However, the Substrait spec requires that functions be identified by their compound signature. + // This test confirms that DataFusion is able to consume plans following the spec, even though + // we don't yet produce such plans. + // Once we start producing plans with compound signatures, this test can be replaced by the roundtrip tests. + + // File generated with substrait-java's Isthmus: + // ./isthmus-cli/build/graal/isthmus --create "create table data (d boolean)" "select not d from data" + let proto_plan = + read_json("tests/testdata/test_plans/select_not_bool.substrait.json"); + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; + let plan = from_substrait_plan(&ctx, &proto_plan).await?; + + assert_eq!( + format!("{}", plan), + "Projection: NOT DATA.D AS EXPR$0\ + \n TableScan: DATA" + ); + Ok(()) + } + + // Aggregate function compound signature is tested through TPCH plans + + #[tokio::test] + async fn window_function_compound_signature() -> Result<()> { + // DataFusion currently produces Substrait that refers to functions only by their name. + // However, the Substrait spec requires that functions be identified by their compound signature. + // This test confirms that DataFusion is able to consume plans following the spec, even though + // we don't yet produce such plans. + // Once we start producing plans with compound signatures, this test can be replaced by the roundtrip tests. + + // File generated with substrait-java's Isthmus: + // ./isthmus-cli/build/graal/isthmus --create "create table data (d int, part int, ord int)" "select sum(d) OVER (PARTITION BY part ORDER BY ord ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING) AS lead_expr from data" + let proto_plan = + read_json("tests/testdata/test_plans/select_window.substrait.json"); + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; + let plan = from_substrait_plan(&ctx, &proto_plan).await?; + + assert_eq!( + format!("{}", plan), + "Projection: sum(DATA.D) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING AS LEAD_EXPR\ + \n WindowAggr: windowExpr=[[sum(DATA.D) PARTITION BY [DATA.PART] ORDER BY [DATA.ORD ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND UNBOUNDED FOLLOWING]]\ + \n TableScan: DATA" + ); + Ok(()) + } + + #[tokio::test] + async fn non_nullable_lists() -> Result<()> { + // DataFusion's Substrait consumer treats all lists as nullable, even if the Substrait plan specifies them as non-nullable. + // That's because implementing the non-nullability consistently is non-trivial. + // This test confirms that reading a plan with non-nullable lists works as expected. + let proto_plan = + read_json("tests/testdata/test_plans/non_nullable_lists.substrait.json"); + let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?; + let plan = from_substrait_plan(&ctx, &proto_plan).await?; + + assert_eq!(format!("{}", &plan), "Values: (List([1, 2]))"); + + // Need to trigger execution to ensure that Arrow has validated the plan + DataFrame::new(ctx.state(), plan).show().await?; + + Ok(()) + } +} diff --git a/datafusion/substrait/tests/cases/mod.rs b/datafusion/substrait/tests/cases/mod.rs index b17289205f3d..b1f4b95df66f 100644 --- a/datafusion/substrait/tests/cases/mod.rs +++ b/datafusion/substrait/tests/cases/mod.rs @@ -15,6 +15,11 @@ // specific language governing permissions and limitations // under the License. +mod consumer_integration; +mod emit_kind_tests; +mod function_test; +mod logical_plans; mod roundtrip_logical_plan; mod roundtrip_physical_plan; mod serialize; +mod substrait_validations; diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 28c0de1c9973..5687c9af540a 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -15,32 +15,35 @@ // specific language governing permissions and limitations // under the License. +use crate::utils::test::read_json; use datafusion::arrow::array::ArrayRef; use datafusion::physical_plan::Accumulator; use datafusion::scalar::ScalarValue; use datafusion_substrait::logical_plan::{ consumer::from_substrait_plan, producer::to_substrait_plan, }; +use std::cmp::Ordering; -use std::hash::Hash; -use std::sync::Arc; - -use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; +use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}; use datafusion::common::{not_impl_err, plan_err, DFSchema, DFSchemaRef}; use datafusion::error::Result; -use datafusion::execution::context::SessionState; use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::logical_expr::{ - Extension, LogicalPlan, Repartition, UserDefinedLogicalNode, Volatility, + Extension, LogicalPlan, PartitionEvaluator, Repartition, UserDefinedLogicalNode, + Values, Volatility, }; use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST; use datafusion::prelude::*; +use std::hash::Hash; +use std::sync::Arc; +use datafusion::execution::session_state::SessionStateBuilder; use substrait::proto::extensions::simple_extension_declaration::MappingType; use substrait::proto::rel::RelType; use substrait::proto::{plan_rel, Plan, Rel}; +#[derive(Debug)] struct MockSerializerRegistry; impl SerializerRegistry for MockSerializerRegistry { @@ -63,8 +66,7 @@ impl SerializerRegistry for MockSerializerRegistry { &self, name: &str, bytes: &[u8], - ) -> Result> - { + ) -> Result> { if name == "MockUserDefinedLogicalPlan" { MockUserDefinedLogicalPlan::deserialize(bytes) } else { @@ -81,6 +83,17 @@ struct MockUserDefinedLogicalPlan { empty_schema: DFSchemaRef, } +// `PartialOrd` needed for `UserDefinedLogicalNodeCore`, manual implementation necessary due to +// the `empty_schema` field. +impl PartialOrd for MockUserDefinedLogicalPlan { + fn partial_cmp(&self, other: &Self) -> Option { + match self.validation_bytes.partial_cmp(&other.validation_bytes) { + Some(Ordering::Equal) => self.inputs.partial_cmp(&other.inputs), + cmp => cmp, + } + } +} + impl UserDefinedLogicalNode for MockUserDefinedLogicalPlan { fn as_any(&self) -> &dyn std::any::Any { self @@ -110,16 +123,16 @@ impl UserDefinedLogicalNode for MockUserDefinedLogicalPlan { ) } - fn from_template( + fn with_exprs_and_inputs( &self, - _: &[Expr], - inputs: &[LogicalPlan], - ) -> Arc { - Arc::new(Self { + _: Vec, + inputs: Vec, + ) -> Result> { + Ok(Arc::new(Self { validation_bytes: self.validation_bytes.clone(), - inputs: inputs.to_vec(), + inputs, empty_schema: Arc::new(DFSchema::empty()), - }) + })) } fn dyn_hash(&self, _: &mut dyn std::hash::Hasher) { @@ -129,6 +142,14 @@ impl UserDefinedLogicalNode for MockUserDefinedLogicalPlan { fn dyn_eq(&self, _: &dyn UserDefinedLogicalNode) -> bool { unimplemented!() } + + fn dyn_ord(&self, _: &dyn UserDefinedLogicalNode) -> Option { + unimplemented!() + } + + fn supports_limit_pushdown(&self) -> bool { + false // Disallow limit push-down by default + } } impl MockUserDefinedLogicalPlan { @@ -159,7 +180,18 @@ async fn simple_select() -> Result<()> { #[tokio::test] async fn wildcard_select() -> Result<()> { - roundtrip("SELECT * FROM data").await + assert_expected_plan_unoptimized( + "SELECT * FROM data", + "Projection: data.a, data.b, data.c, data.d, data.e, data.f\ + \n TableScan: data", + true, + ) + .await +} + +#[tokio::test] +async fn select_with_alias() -> Result<()> { + roundtrip("SELECT a AS aliased_a FROM data").await } #[tokio::test] @@ -169,14 +201,28 @@ async fn select_with_filter() -> Result<()> { #[tokio::test] async fn select_with_reused_functions() -> Result<()> { + let ctx = create_context().await?; let sql = "SELECT * FROM data WHERE a > 1 AND a < 10 AND b > 0"; - roundtrip(sql).await?; - let (mut function_names, mut function_anchors) = function_extension_info(sql).await?; - function_names.sort(); - function_anchors.sort(); + let proto = roundtrip_with_ctx(sql, ctx).await?; + let mut functions = proto + .extensions + .iter() + .map(|e| match e.mapping_type.as_ref().unwrap() { + MappingType::ExtensionFunction(ext_f) => { + (ext_f.function_anchor, ext_f.name.to_owned()) + } + _ => unreachable!("Non-function extensions not expected"), + }) + .collect::>(); + functions.sort_by_key(|(anchor, _)| *anchor); - assert_eq!(function_names, ["and", "gt", "lt"]); - assert_eq!(function_anchors, [0, 1, 2]); + // Functions are encountered (and thus registered) depth-first + let expected = vec![ + (0, "gt".to_string()), + (1, "lt".to_string()), + (2, "and".to_string()), + ]; + assert_eq!(functions, expected); Ok(()) } @@ -234,8 +280,10 @@ async fn aggregate_grouping_sets() -> Result<()> { async fn aggregate_grouping_rollup() -> Result<()> { assert_expected_plan( "SELECT a, c, e, avg(b) FROM data GROUP BY ROLLUP (a, c, e)", - "Aggregate: groupBy=[[GROUPING SETS ((data.a, data.c, data.e), (data.a, data.c), (data.a), ())]], aggr=[[AVG(data.b)]]\ - \n TableScan: data projection=[a, b, c, e]" + "Projection: data.a, data.c, data.e, avg(data.b)\ + \n Aggregate: groupBy=[[GROUPING SETS ((data.a, data.c, data.e), (data.a, data.c), (data.a), ())]], aggr=[[avg(data.b)]]\ + \n TableScan: data projection=[a, b, c, e]", + true ).await } @@ -321,7 +369,7 @@ async fn simple_scalar_function_pow() -> Result<()> { #[tokio::test] async fn simple_scalar_function_substr() -> Result<()> { - roundtrip("SELECT * FROM data WHERE a = SUBSTR('datafusion', 0, 3)").await + roundtrip("SELECT SUBSTR(f, 1, 3) FROM data").await } #[tokio::test] @@ -365,9 +413,10 @@ async fn implicit_cast() -> Result<()> { #[tokio::test] async fn aggregate_case() -> Result<()> { assert_expected_plan( - "SELECT SUM(CASE WHEN a > 0 THEN 1 ELSE NULL END) FROM data", - "Aggregate: groupBy=[[]], aggr=[[SUM(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE Int64(NULL) END)]]\ + "SELECT sum(CASE WHEN a > 0 THEN 1 ELSE NULL END) FROM data", + "Aggregate: groupBy=[[]], aggr=[[sum(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE Int64(NULL) END) AS sum(CASE WHEN data.a > Int64(0) THEN Int64(1) ELSE NULL END)]]\ \n TableScan: data projection=[a]", + true ) .await } @@ -404,17 +453,16 @@ async fn roundtrip_inlist_5() -> Result<()> { // on roundtrip there is an additional projection during TableScan which includes all column of the table, // using assert_expected_plan here as a workaround assert_expected_plan( - "SELECT a, f FROM data WHERE (f IN ('a', 'b', 'c') OR a in (SELECT data2.a FROM data2 WHERE f IN ('b', 'c', 'd')))", - "Filter: data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR data.a IN ()\ - \n Subquery:\ - \n Projection: data2.a\ - \n Filter: data2.f IN ([Utf8(\"b\"), Utf8(\"c\"), Utf8(\"d\")])\ - \n TableScan: data2 projection=[a, b, c, d, e, f]\ - \n TableScan: data projection=[a, f], partial_filters=[data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR data.a IN ()]\ - \n Subquery:\ - \n Projection: data2.a\ - \n Filter: data2.f IN ([Utf8(\"b\"), Utf8(\"c\"), Utf8(\"d\")])\ - \n TableScan: data2 projection=[a, b, c, d, e, f]").await + "SELECT a, f FROM data WHERE (f IN ('a', 'b', 'c') OR a in (SELECT data2.a FROM data2 WHERE f IN ('b', 'c', 'd')))", + + "Projection: data.a, data.f\ + \n Filter: data.f = Utf8(\"a\") OR data.f = Utf8(\"b\") OR data.f = Utf8(\"c\") OR data2.mark\ + \n LeftMark Join: data.a = data2.a\ + \n TableScan: data projection=[a, f]\ + \n Projection: data2.a\ + \n Filter: data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f = Utf8(\"d\")\ + \n TableScan: data2 projection=[a, f], partial_filters=[data2.f = Utf8(\"b\") OR data2.f = Utf8(\"c\") OR data2.f = Utf8(\"d\")]", + true).await } #[tokio::test] @@ -450,7 +498,8 @@ async fn roundtrip_exists_filter() -> Result<()> { "Projection: data.b\ \n LeftSemi Join: data.a = data2.a Filter: data2.e != CAST(data.e AS Int64)\ \n TableScan: data projection=[a, b, e]\ - \n TableScan: data2 projection=[a, e]" + \n TableScan: data2 projection=[a, e]", + false // "d1" vs "data" field qualifier ).await } @@ -462,6 +511,7 @@ async fn inner_join() -> Result<()> { \n Inner Join: data.a = data2.a\ \n TableScan: data projection=[a]\ \n TableScan: data2 projection=[a]", + true, ) .await } @@ -481,6 +531,23 @@ async fn roundtrip_outer_join() -> Result<()> { roundtrip("SELECT data.a FROM data FULL OUTER JOIN data2 ON data.a = data2.a").await } +#[tokio::test] +async fn roundtrip_self_join() -> Result<()> { + // Substrait does currently NOT maintain the alias of the tables. + // Instead, when we consume Substrait, we add aliases before a join that'd otherwise collide. + // This roundtrip works because we set aliases to what the Substrait consumer will generate. + roundtrip("SELECT left.a as left_a, left.b, right.a as right_a, right.c FROM data AS left JOIN data AS right ON left.a = right.a").await?; + roundtrip("SELECT left.a as left_a, left.b, right.a as right_a, right.c FROM data AS left JOIN data AS right ON left.b = right.b").await +} + +#[tokio::test] +async fn roundtrip_self_implicit_cross_join() -> Result<()> { + // Substrait does currently NOT maintain the alias of the tables. + // Instead, when we consume Substrait, we add aliases before a join that'd otherwise collide. + // This roundtrip works because we set aliases to what the Substrait consumer will generate. + roundtrip("SELECT left.a left_a, left.b, right.a right_a, right.c FROM data AS left, data AS right").await +} + #[tokio::test] async fn roundtrip_arithmetic_ops() -> Result<()> { roundtrip("SELECT a - a FROM data").await?; @@ -506,6 +573,11 @@ async fn roundtrip_ilike() -> Result<()> { roundtrip("SELECT f FROM data WHERE f ILIKE 'a%b'").await } +#[tokio::test] +async fn roundtrip_modulus() -> Result<()> { + roundtrip("SELECT a%3 from data").await +} + #[tokio::test] async fn roundtrip_not() -> Result<()> { roundtrip("SELECT * FROM data WHERE NOT d").await @@ -566,71 +638,182 @@ async fn roundtrip_union_all() -> Result<()> { #[tokio::test] async fn simple_intersect() -> Result<()> { + // Substrait treats both count(*) and count(1) the same assert_expected_plan( - "SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data2.a FROM data2);", - "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\ + "SELECT count(*) FROM (SELECT data.a FROM data INTERSECT SELECT data2.a FROM data2);", + "Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]]\ \n Projection: \ \n LeftSemi Join: data.a = data2.a\ \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ \n TableScan: data projection=[a]\ \n TableScan: data2 projection=[a]", + true ) .await } #[tokio::test] -async fn simple_intersect_table_reuse() -> Result<()> { - assert_expected_plan( - "SELECT COUNT(*) FROM (SELECT data.a FROM data INTERSECT SELECT data.a FROM data);", - "Aggregate: groupBy=[[]], aggr=[[COUNT(Int64(1))]]\ - \n Projection: \ - \n LeftSemi Join: data.a = data.a\ - \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ - \n TableScan: data projection=[a]\ - \n TableScan: data projection=[a]", +async fn aggregate_wo_projection_consume() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/aggregate_no_project.substrait.json"); + + assert_expected_plan_substrait( + proto_plan, + "Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) AS countA]]\ + \n TableScan: data projection=[a]", ) - .await + .await } #[tokio::test] -async fn simple_window_function() -> Result<()> { - roundtrip("SELECT RANK() OVER (PARTITION BY a ORDER BY b), d, SUM(b) OVER (PARTITION BY a) FROM data;").await +async fn aggregate_wo_projection_group_expression_ref_consume() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/aggregate_no_project_group_expression_ref.substrait.json"); + + assert_expected_plan_substrait( + proto_plan, + "Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) AS countA]]\ + \n TableScan: data projection=[a]", + ) + .await } #[tokio::test] -async fn qualified_schema_table_reference() -> Result<()> { - roundtrip("SELECT * FROM public.data;").await +async fn aggregate_wo_projection_sorted_consume() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/aggregate_sorted_no_project.substrait.json"); + + assert_expected_plan_substrait( + proto_plan, + "Aggregate: groupBy=[[data.a]], aggr=[[count(data.a) ORDER BY [data.a DESC NULLS FIRST] AS countA]]\ + \n TableScan: data projection=[a]", + ) + .await } #[tokio::test] -async fn qualified_catalog_schema_table_reference() -> Result<()> { - roundtrip("SELECT a,b,c,d,e FROM datafusion.public.data;").await +async fn simple_intersect_consume() -> Result<()> { + let proto_plan = read_json("tests/testdata/test_plans/intersect.substrait.json"); + + assert_substrait_sql( + proto_plan, + "SELECT a FROM data INTERSECT SELECT a FROM data2", + ) + .await } #[tokio::test] -async fn roundtrip_inner_join_table_reuse_zero_index() -> Result<()> { - assert_expected_plan( - "SELECT d1.b, d2.c FROM data d1 JOIN data d2 ON d1.a = d2.a", - "Projection: data.b, data.c\ - \n Inner Join: data.a = data.a\ - \n TableScan: data projection=[a, b]\ - \n TableScan: data projection=[a, c]", +async fn primary_intersect_consume() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/intersect_primary.substrait.json"); + + assert_substrait_sql( + proto_plan, + "SELECT a FROM data INTERSECT (SELECT a FROM data2 UNION ALL SELECT a FROM data2)", ) .await } #[tokio::test] -async fn roundtrip_inner_join_table_reuse_non_zero_index() -> Result<()> { - assert_expected_plan( - "SELECT d1.b, d2.c FROM data d1 JOIN data d2 ON d1.b = d2.b", - "Projection: data.b, data.c\ - \n Inner Join: data.b = data.b\ - \n TableScan: data projection=[b]\ - \n TableScan: data projection=[b, c]", +async fn multiset_intersect_consume() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/intersect_multiset.substrait.json"); + + assert_substrait_sql( + proto_plan, + "SELECT a FROM data INTERSECT SELECT a FROM data2 INTERSECT SELECT a FROM data2", ) .await } +#[tokio::test] +async fn multiset_intersect_all_consume() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/intersect_multiset_all.substrait.json"); + + assert_substrait_sql( + proto_plan, + "SELECT a FROM data INTERSECT ALL SELECT a FROM data2 INTERSECT ALL SELECT a FROM data2", + ) + .await +} + +#[tokio::test] +async fn primary_except_consume() -> Result<()> { + let proto_plan = read_json("tests/testdata/test_plans/minus_primary.substrait.json"); + + assert_substrait_sql( + proto_plan, + "SELECT a FROM data EXCEPT SELECT a FROM data2 EXCEPT SELECT a FROM data2", + ) + .await +} + +#[tokio::test] +async fn primary_except_all_consume() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/minus_primary_all.substrait.json"); + + assert_substrait_sql( + proto_plan, + "SELECT a FROM data EXCEPT ALL SELECT a FROM data2 EXCEPT ALL SELECT a FROM data2", + ) + .await +} + +#[tokio::test] +async fn union_distinct_consume() -> Result<()> { + let proto_plan = read_json("tests/testdata/test_plans/union_distinct.substrait.json"); + + assert_substrait_sql(proto_plan, "SELECT a FROM data UNION SELECT a FROM data2").await +} + +#[tokio::test] +async fn simple_intersect_table_reuse() -> Result<()> { + // Substrait does currently NOT maintain the alias of the tables. + // Instead, when we consume Substrait, we add aliases before a join that'd otherwise collide. + // In this case the aliasing happens at a different point in the plan, so we cannot use roundtrip. + // Schema check works because we set aliases to what the Substrait consumer will generate. + assert_expected_plan( + "SELECT count(1) FROM (SELECT left.a FROM data AS left INTERSECT SELECT right.a FROM data AS right);", + "Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\ + \n Projection: \ + \n LeftSemi Join: left.a = right.a\ + \n SubqueryAlias: left\ + \n Aggregate: groupBy=[[data.a]], aggr=[[]]\ + \n TableScan: data projection=[a]\ + \n SubqueryAlias: right\ + \n TableScan: data projection=[a]", + true + ).await +} + +#[tokio::test] +async fn simple_window_function() -> Result<()> { + roundtrip("SELECT RANK() OVER (PARTITION BY a ORDER BY b), d, sum(b) OVER (PARTITION BY a) FROM data;").await +} + +#[tokio::test] +async fn window_with_rows() -> Result<()> { + roundtrip("SELECT sum(b) OVER (PARTITION BY a ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) FROM data;").await?; + roundtrip("SELECT sum(b) OVER (PARTITION BY a ROWS BETWEEN CURRENT ROW AND 2 FOLLOWING) FROM data;").await?; + roundtrip("SELECT sum(b) OVER (PARTITION BY a ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING) FROM data;").await?; + roundtrip("SELECT sum(b) OVER (PARTITION BY a ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING) FROM data;").await?; + roundtrip("SELECT sum(b) OVER (PARTITION BY a ROWS BETWEEN 2 PRECEDING AND UNBOUNDED FOLLOWING) FROM data;").await?; + roundtrip("SELECT sum(b) OVER (PARTITION BY a ROWS BETWEEN 2 FOLLOWING AND 4 FOLLOWING) FROM data;").await?; + roundtrip("SELECT sum(b) OVER (PARTITION BY a ROWS BETWEEN 4 PRECEDING AND 2 PRECEDING) FROM data;").await +} + +#[tokio::test] +async fn qualified_schema_table_reference() -> Result<()> { + roundtrip("SELECT * FROM public.data;").await +} + +#[tokio::test] +async fn qualified_catalog_schema_table_reference() -> Result<()> { + roundtrip("SELECT a,b,c,d,e FROM datafusion.public.data;").await +} + /// Construct a plan that contains several literals of types that are currently supported. /// This case ignores: /// - Date64, for this literal is not supported @@ -659,12 +842,106 @@ async fn all_type_literal() -> Result<()> { date32_col = arrow_cast('2020-01-01', 'Date32') AND binary_col = arrow_cast('binary', 'Binary') AND large_binary_col = arrow_cast('large_binary', 'LargeBinary') AND + view_binary_col = arrow_cast('binary_view', 'BinaryView') AND utf8_col = arrow_cast('utf8', 'Utf8') AND - large_utf8_col = arrow_cast('large_utf8', 'LargeUtf8');", + large_utf8_col = arrow_cast('large_utf8', 'LargeUtf8') AND + view_utf8_col = arrow_cast('utf8_view', 'Utf8View');", ) .await } +#[tokio::test] +async fn roundtrip_literal_list() -> Result<()> { + roundtrip("SELECT [[1,2,3], [], NULL, [NULL]] FROM data").await +} + +#[tokio::test] +async fn roundtrip_literal_struct() -> Result<()> { + assert_expected_plan( + "SELECT STRUCT(1, true, CAST(NULL AS STRING)) FROM data", + "Projection: Struct({c0:1,c1:true,c2:}) AS struct(Int64(1),Boolean(true),NULL)\ + \n TableScan: data projection=[]", + false, // "Struct(..)" vs "struct(..)" + ) + .await +} + +#[tokio::test] +async fn roundtrip_values() -> Result<()> { + // TODO: would be nice to have a struct inside the LargeList, but arrow_cast doesn't support that currently + assert_expected_plan( + "VALUES \ + (\ + 1, \ + 'a', \ + [[-213.1, NULL, 5.5, 2.0, 1.0], []], \ + arrow_cast([1,2,3], 'LargeList(Int64)'), \ + STRUCT(true, 1 AS int_field, CAST(NULL AS STRING)), \ + [STRUCT(STRUCT('a' AS string_field) AS struct_field), STRUCT(STRUCT('b' AS string_field) AS struct_field)]\ + ), \ + (NULL, NULL, NULL, NULL, NULL, NULL)", + "Values: \ + (\ + Int64(1), \ + Utf8(\"a\"), \ + List([[-213.1, , 5.5, 2.0, 1.0], []]), \ + LargeList([1, 2, 3]), \ + Struct({c0:true,int_field:1,c2:}), \ + List([{struct_field: {string_field: a}}, {struct_field: {string_field: b}}])\ + ), \ + (Int64(NULL), Utf8(NULL), List(), LargeList(), Struct({c0:,int_field:,c2:}), List())", + true).await +} + +#[tokio::test] +async fn roundtrip_values_no_columns() -> Result<()> { + let ctx = create_context().await?; + // "VALUES ()" is not yet supported by the SQL parser, so we construct the plan manually + let plan = LogicalPlan::Values(Values { + values: vec![vec![], vec![]], // two rows, no columns + schema: DFSchemaRef::new(DFSchema::empty()), + }); + roundtrip_logical_plan_with_ctx(plan, ctx).await?; + Ok(()) +} + +#[tokio::test] +async fn roundtrip_values_empty_relation() -> Result<()> { + roundtrip("SELECT * FROM (VALUES ('a')) LIMIT 0").await +} + +#[tokio::test] +async fn roundtrip_values_duplicate_column_join() -> Result<()> { + // Substrait does currently NOT maintain the alias of the tables. + // Instead, when we consume Substrait, we add aliases before a join that'd otherwise collide. + // This roundtrip works because we set aliases to what the Substrait consumer will generate. + roundtrip( + "SELECT left.column1 as c1, right.column1 as c2 \ + FROM \ + (VALUES (1)) AS left \ + JOIN \ + (VALUES (2)) AS right \ + ON left.column1 == right.column1", + ) + .await +} + +#[tokio::test] +async fn duplicate_column() -> Result<()> { + // Substrait does not keep column names (aliases) in the plan, rather it operates on column indices + // only. DataFusion however, is strict about not having duplicate column names appear in the plan. + // This test confirms that we generate aliases for columns in the plan which would otherwise have + // colliding names. + assert_expected_plan( + "SELECT a + 1 as sum_a, a + 1 as sum_a_2 FROM data", + "Projection: data.a + Int64(1) AS sum_a, data.a + Int64(1) AS data.a + Int64(1)__temp__0 AS sum_a_2\ + \n Projection: data.a + Int64(1)\ + \n TableScan: data projection=[a]", + true, + ) + .await +} + /// Construct a plan that cast columns. Only those SQL types are supported for now. #[tokio::test] async fn new_test_grammar() -> Result<()> { @@ -704,8 +981,8 @@ async fn extension_logical_plan() -> Result<()> { let proto = to_substrait_plan(&ext_plan, &ctx)?; let plan2 = from_substrait_plan(&ctx, &proto).await?; - let plan1str = format!("{ext_plan:?}"); - let plan2str = format!("{plan2:?}"); + let plan1str = format!("{ext_plan}"); + let plan2str = format!("{plan2}"); assert_eq!(plan1str, plan2str); Ok(()) @@ -717,27 +994,24 @@ async fn roundtrip_aggregate_udf() -> Result<()> { struct Dummy {} impl Accumulator for Dummy { - fn state(&mut self) -> datafusion::error::Result> { - Ok(vec![]) + fn state(&mut self) -> Result> { + Ok(vec![ScalarValue::Float64(None), ScalarValue::UInt32(None)]) } - fn update_batch( - &mut self, - _values: &[ArrayRef], - ) -> datafusion::error::Result<()> { + fn update_batch(&mut self, _values: &[ArrayRef]) -> Result<()> { Ok(()) } - fn merge_batch(&mut self, _states: &[ArrayRef]) -> datafusion::error::Result<()> { + fn merge_batch(&mut self, _states: &[ArrayRef]) -> Result<()> { Ok(()) } - fn evaluate(&mut self) -> datafusion::error::Result { - Ok(ScalarValue::Float64(None)) + fn evaluate(&mut self) -> Result { + Ok(ScalarValue::Int64(None)) } fn size(&self) -> usize { - std::mem::size_of_val(self) + size_of_val(self) } } @@ -757,8 +1031,44 @@ async fn roundtrip_aggregate_udf() -> Result<()> { let ctx = create_context().await?; ctx.register_udaf(dummy_agg); + roundtrip_with_ctx("select dummy_agg(a) from data", ctx.clone()).await?; + roundtrip_with_ctx("select dummy_agg(a order by a) from data", ctx.clone()).await?; + + Ok(()) +} + +#[tokio::test] +async fn roundtrip_window_udf() -> Result<()> { + #[derive(Debug)] + struct Dummy {} + + impl PartitionEvaluator for Dummy { + fn evaluate_all( + &mut self, + values: &[ArrayRef], + _num_rows: usize, + ) -> Result { + Ok(values[0].to_owned()) + } + } + + fn make_partition_evaluator() -> Result> { + Ok(Box::new(Dummy {})) + } - roundtrip_with_ctx("select dummy_agg(a) from data", ctx).await + let dummy_agg = create_udwf( + "dummy_window", // name + DataType::Int64, // input type + Arc::new(DataType::Int64), // return type + Volatility::Immutable, + Arc::new(make_partition_evaluator), + ); + + let ctx = create_context().await?; + ctx.register_udwf(dummy_agg); + + roundtrip_with_ctx("select dummy_window(a) OVER () from data", ctx).await?; + Ok(()) } #[tokio::test] @@ -774,7 +1084,7 @@ async fn roundtrip_repartition_roundrobin() -> Result<()> { let plan2 = from_substrait_plan(&ctx, &proto).await?; let plan2 = ctx.state().optimize(&plan2)?; - assert_eq!(format!("{plan:?}"), format!("{plan2:?}")); + assert_eq!(format!("{plan}"), format!("{plan2}")); Ok(()) } @@ -791,7 +1101,7 @@ async fn roundtrip_repartition_hash() -> Result<()> { let plan2 = from_substrait_plan(&ctx, &proto).await?; let plan2 = ctx.state().optimize(&plan2)?; - assert_eq!(format!("{plan:?}"), format!("{plan2:?}")); + assert_eq!(format!("{plan}"), format!("{plan2}")); Ok(()) } @@ -880,31 +1190,106 @@ async fn verify_post_join_filter_value(proto: Box) -> Result<()> { Ok(()) } -async fn assert_expected_plan(sql: &str, expected_plan_str: &str) -> Result<()> { +async fn assert_expected_plan_unoptimized( + sql: &str, + expected_plan_str: &str, + assert_schema: bool, +) -> Result<()> { + let ctx = create_context().await?; + let df = ctx.sql(sql).await?; + let plan = df.into_unoptimized_plan(); + let proto = to_substrait_plan(&plan, &ctx)?; + let plan2 = from_substrait_plan(&ctx, &proto).await?; + + println!("{plan}"); + println!("{plan2}"); + + println!("{proto:?}"); + + if assert_schema { + assert_eq!(plan.schema(), plan2.schema()); + } + + let plan2str = format!("{plan2}"); + assert_eq!(expected_plan_str, &plan2str); + + Ok(()) +} + +async fn assert_expected_plan( + sql: &str, + expected_plan_str: &str, + assert_schema: bool, +) -> Result<()> { let ctx = create_context().await?; let df = ctx.sql(sql).await?; let plan = df.into_optimized_plan()?; let proto = to_substrait_plan(&plan, &ctx)?; let plan2 = from_substrait_plan(&ctx, &proto).await?; let plan2 = ctx.state().optimize(&plan2)?; - let plan2str = format!("{plan2:?}"); + + println!("{plan}"); + println!("{plan2}"); + + println!("{proto:?}"); + + if assert_schema { + assert_eq!(plan.schema(), plan2.schema()); + } + + let plan2str = format!("{plan2}"); assert_eq!(expected_plan_str, &plan2str); + + Ok(()) +} + +async fn assert_expected_plan_substrait( + substrait_plan: Plan, + expected_plan_str: &str, +) -> Result<()> { + let ctx = create_context().await?; + + let plan = from_substrait_plan(&ctx, &substrait_plan).await?; + + let plan = ctx.state().optimize(&plan)?; + + let planstr = format!("{plan}"); + assert_eq!(planstr, expected_plan_str); + + Ok(()) +} + +async fn assert_substrait_sql(substrait_plan: Plan, sql: &str) -> Result<()> { + let ctx = create_context().await?; + + let expected = ctx.sql(sql).await?.into_optimized_plan()?; + + let plan = from_substrait_plan(&ctx, &substrait_plan).await?; + + let plan = ctx.state().optimize(&plan)?; + + let planstr = format!("{plan}"); + let expectedstr = format!("{expected}"); + assert_eq!(planstr, expectedstr); + Ok(()) } async fn roundtrip_fill_na(sql: &str) -> Result<()> { let ctx = create_context().await?; let df = ctx.sql(sql).await?; - let plan1 = df.into_optimized_plan()?; - let proto = to_substrait_plan(&plan1, &ctx)?; + let plan = df.into_optimized_plan()?; + let proto = to_substrait_plan(&plan, &ctx)?; let plan2 = from_substrait_plan(&ctx, &proto).await?; let plan2 = ctx.state().optimize(&plan2)?; // Format plan string and replace all None's with 0 - let plan1str = format!("{plan1:?}").replace("None", "0"); - let plan2str = format!("{plan2:?}").replace("None", "0"); + let plan1str = format!("{plan}").replace("None", "0"); + let plan2str = format!("{plan2}").replace("None", "0"); assert_eq!(plan1str, plan2str); + + assert_eq!(plan.schema(), plan2.schema()); Ok(()) } @@ -922,112 +1307,93 @@ async fn test_alias(sql_with_alias: &str, sql_no_alias: &str) -> Result<()> { let proto = to_substrait_plan(&df.into_optimized_plan()?, &ctx)?; let plan = from_substrait_plan(&ctx, &proto).await?; - println!("{plan_with_alias:#?}"); - println!("{plan:#?}"); + println!("{plan_with_alias}"); + println!("{plan}"); - let plan1str = format!("{plan_with_alias:?}"); - let plan2str = format!("{plan:?}"); + let plan1str = format!("{plan_with_alias}"); + let plan2str = format!("{plan}"); assert_eq!(plan1str, plan2str); + + assert_eq!(plan_with_alias.schema(), plan.schema()); Ok(()) } -async fn roundtrip_with_ctx(sql: &str, ctx: SessionContext) -> Result<()> { - let df = ctx.sql(sql).await?; - let plan = df.into_optimized_plan()?; +async fn roundtrip_logical_plan_with_ctx( + plan: LogicalPlan, + ctx: SessionContext, +) -> Result> { let proto = to_substrait_plan(&plan, &ctx)?; let plan2 = from_substrait_plan(&ctx, &proto).await?; let plan2 = ctx.state().optimize(&plan2)?; - println!("{plan:#?}"); - println!("{plan2:#?}"); + println!("{plan}"); + println!("{plan2}"); - let plan1str = format!("{plan:?}"); - let plan2str = format!("{plan2:?}"); + println!("{proto:?}"); + + let plan1str = format!("{plan}"); + let plan2str = format!("{plan2}"); assert_eq!(plan1str, plan2str); - Ok(()) -} -async fn roundtrip(sql: &str) -> Result<()> { - roundtrip_with_ctx(sql, create_context().await?).await + assert_eq!(plan.schema(), plan2.schema()); + + DataFrame::new(ctx.state(), plan2).show().await?; + Ok(proto) } -async fn roundtrip_verify_post_join_filter(sql: &str) -> Result<()> { - let ctx = create_context().await?; +async fn roundtrip_with_ctx(sql: &str, ctx: SessionContext) -> Result> { let df = ctx.sql(sql).await?; let plan = df.into_optimized_plan()?; - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; - let plan2 = ctx.state().optimize(&plan2)?; + roundtrip_logical_plan_with_ctx(plan, ctx).await +} - println!("{plan:#?}"); - println!("{plan2:#?}"); +async fn roundtrip(sql: &str) -> Result<()> { + roundtrip_with_ctx(sql, create_context().await?).await?; + Ok(()) +} - let plan1str = format!("{plan:?}"); - let plan2str = format!("{plan2:?}"); - assert_eq!(plan1str, plan2str); +async fn roundtrip_verify_post_join_filter(sql: &str) -> Result<()> { + let ctx = create_context().await?; + let proto = roundtrip_with_ctx(sql, ctx).await?; // verify that the join filters are None verify_post_join_filter_value(proto).await } async fn roundtrip_all_types(sql: &str) -> Result<()> { - let ctx = create_all_type_context().await?; - let df = ctx.sql(sql).await?; - let plan = df.into_optimized_plan()?; - let proto = to_substrait_plan(&plan, &ctx)?; - let plan2 = from_substrait_plan(&ctx, &proto).await?; - let plan2 = ctx.state().optimize(&plan2)?; - - println!("{plan:#?}"); - println!("{plan2:#?}"); - - let plan1str = format!("{plan:?}"); - let plan2str = format!("{plan2:?}"); - assert_eq!(plan1str, plan2str); + roundtrip_with_ctx(sql, create_all_type_context().await?).await?; Ok(()) } -async fn function_extension_info(sql: &str) -> Result<(Vec, Vec)> { - let ctx = create_context().await?; - let df = ctx.sql(sql).await?; - let plan = df.into_optimized_plan()?; - let proto = to_substrait_plan(&plan, &ctx)?; - - let mut function_names: Vec = vec![]; - let mut function_anchors: Vec = vec![]; - for e in &proto.extensions { - let (function_anchor, function_name) = match e.mapping_type.as_ref().unwrap() { - MappingType::ExtensionFunction(ext_f) => (ext_f.function_anchor, &ext_f.name), - _ => unreachable!("Producer does not generate a non-function extension"), - }; - function_names.push(function_name.to_string()); - function_anchors.push(function_anchor); - } +async fn create_context() -> Result { + let mut state = SessionStateBuilder::new() + .with_config(SessionConfig::default()) + .with_runtime_env(Arc::new(RuntimeEnv::default())) + .with_default_features() + .with_serializer_registry(Arc::new(MockSerializerRegistry)) + .build(); - Ok((function_names, function_anchors)) -} + // register udaf for test, e.g. `sum()` + datafusion_functions_aggregate::register_all(&mut state) + .expect("can not register aggregate functions"); -async fn create_context() -> Result { - let state = SessionState::new_with_config_rt( - SessionConfig::default(), - Arc::new(RuntimeEnv::default()), - ) - .with_serializer_registry(Arc::new(MockSerializerRegistry)); let ctx = SessionContext::new_with_state(state); let mut explicit_options = CsvReadOptions::new(); - let schema = Schema::new(vec![ + let fields = vec![ Field::new("a", DataType::Int64, true), Field::new("b", DataType::Decimal128(5, 2), true), Field::new("c", DataType::Date32, true), Field::new("d", DataType::Boolean, true), Field::new("e", DataType::UInt32, true), Field::new("f", DataType::Utf8, true), - ]); + ]; + let schema = Schema::new(fields); explicit_options.schema = Some(&schema); ctx.register_csv("data", "tests/testdata/data.csv", explicit_options) .await?; ctx.register_csv("data2", "tests/testdata/data.csv", CsvReadOptions::new()) .await?; + Ok(ctx) } @@ -1071,9 +1437,11 @@ async fn create_all_type_context() -> Result { Field::new("date64_col", DataType::Date64, true), Field::new("binary_col", DataType::Binary, true), Field::new("large_binary_col", DataType::LargeBinary, true), + Field::new("view_binary_col", DataType::BinaryView, true), Field::new("fixed_size_binary_col", DataType::FixedSizeBinary(42), true), Field::new("utf8_col", DataType::Utf8, true), Field::new("large_utf8_col", DataType::LargeUtf8, true), + Field::new("view_utf8_col", DataType::Utf8View, true), Field::new_list("list_col", Field::new("item", DataType::Int64, true), true), Field::new_list( "large_list_col", @@ -1082,6 +1450,11 @@ async fn create_all_type_context() -> Result { ), Field::new("decimal_128_col", DataType::Decimal128(10, 2), true), Field::new("decimal_256_col", DataType::Decimal256(10, 2), true), + Field::new( + "interval_day_time_col", + DataType::Interval(IntervalUnit::DayTime), + true, + ), ]); explicit_options.schema = Some(&schema); explicit_options.has_header = false; diff --git a/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs index 70887e393491..57fb3e2ee7cc 100644 --- a/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_physical_plan.rs @@ -19,43 +19,35 @@ use std::collections::HashMap; use std::sync::Arc; use datafusion::arrow::datatypes::Schema; +use datafusion::dataframe::DataFrame; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec}; use datafusion::error::Result; -use datafusion::physical_plan::{displayable, ExecutionPlan, Statistics}; -use datafusion::prelude::SessionContext; +use datafusion::physical_plan::{displayable, ExecutionPlan}; +use datafusion::prelude::{ParquetReadOptions, SessionContext}; use datafusion_substrait::physical_plan::{consumer, producer}; use substrait::proto::extensions; #[tokio::test] async fn parquet_exec() -> Result<()> { - let scan_config = FileScanConfig { - object_store_url: ObjectStoreUrl::local_filesystem(), - file_schema: Arc::new(Schema::empty()), - file_groups: vec![ - vec![PartitionedFile::new( - "file://foo/part-0.parquet".to_string(), - 123, - )], - vec![PartitionedFile::new( - "file://foo/part-1.parquet".to_string(), - 123, - )], - ], - statistics: Statistics::new_unknown(&Schema::empty()), - projection: None, - limit: None, - table_partition_cols: vec![], - output_ordering: vec![], - }; - let parquet_exec: Arc = Arc::new(ParquetExec::new( - scan_config, - None, - None, - Default::default(), - )); + let scan_config = FileScanConfig::new( + ObjectStoreUrl::local_filesystem(), + Arc::new(Schema::empty()), + ) + .with_file_groups(vec![ + vec![PartitionedFile::new( + "file://foo/part-0.parquet".to_string(), + 123, + )], + vec![PartitionedFile::new( + "file://foo/part-1.parquet".to_string(), + 123, + )], + ]); + let parquet_exec: Arc = + ParquetExec::builder(scan_config).build_arc(); let mut extension_info: ( Vec, @@ -80,3 +72,92 @@ async fn parquet_exec() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn simple_select() -> Result<()> { + roundtrip("SELECT a, b FROM data").await +} + +#[tokio::test] +#[ignore = "This test is failing because the translation of the substrait plan to the physical plan is not implemented yet"] +async fn simple_select_alltypes() -> Result<()> { + roundtrip_alltypes("SELECT bool_col, int_col FROM alltypes_plain").await +} + +#[tokio::test] +async fn wildcard_select() -> Result<()> { + roundtrip("SELECT * FROM data").await +} + +#[tokio::test] +#[ignore = "This test is failing because the translation of the substrait plan to the physical plan is not implemented yet"] +async fn wildcard_select_alltypes() -> Result<()> { + roundtrip_alltypes("SELECT * FROM alltypes_plain").await +} + +async fn roundtrip(sql: &str) -> Result<()> { + let ctx = create_parquet_context().await?; + let df = ctx.sql(sql).await?; + + roundtrip_parquet(df).await?; + + Ok(()) +} + +async fn roundtrip_alltypes(sql: &str) -> Result<()> { + let ctx = create_all_types_context().await?; + let df = ctx.sql(sql).await?; + + roundtrip_parquet(df).await?; + + Ok(()) +} + +async fn roundtrip_parquet(df: DataFrame) -> Result<()> { + let physical_plan = df.create_physical_plan().await?; + + // Convert the plan into a substrait (protobuf) Rel + let mut extension_info = (vec![], HashMap::new()); + let substrait_plan = + producer::to_substrait_rel(physical_plan.as_ref(), &mut extension_info)?; + + // Convert the substrait Rel back into a physical plan + let ctx = create_parquet_context().await?; + let physical_plan_roundtrip = + consumer::from_substrait_rel(&ctx, substrait_plan.as_ref(), &HashMap::new()) + .await?; + + // Compare the original and roundtrip physical plans + let expected = format!("{}", displayable(physical_plan.as_ref()).indent(true)); + let actual = format!( + "{}", + displayable(physical_plan_roundtrip.as_ref()).indent(true) + ); + assert_eq!(expected, actual); + + Ok(()) +} + +async fn create_parquet_context() -> Result { + let ctx = SessionContext::new(); + let explicit_options = ParquetReadOptions::default(); + + ctx.register_parquet("data", "tests/testdata/data.parquet", explicit_options) + .await?; + + Ok(ctx) +} + +async fn create_all_types_context() -> Result { + let ctx = SessionContext::new(); + + let testdata = datafusion::test_util::parquet_test_data(); + ctx.register_parquet( + "alltypes_plain", + &format!("{testdata}/alltypes_plain.parquet"), + ParquetReadOptions::default(), + ) + .await?; + + Ok(ctx) +} diff --git a/datafusion/substrait/tests/cases/serialize.rs b/datafusion/substrait/tests/cases/serialize.rs index f6736ca22279..54d55d1b6f10 100644 --- a/datafusion/substrait/tests/cases/serialize.rs +++ b/datafusion/substrait/tests/cases/serialize.rs @@ -20,13 +20,16 @@ mod tests { use datafusion::datasource::provider_as_source; use datafusion::logical_expr::LogicalPlanBuilder; use datafusion_substrait::logical_plan::consumer::from_substrait_plan; - use datafusion_substrait::logical_plan::producer; + use datafusion_substrait::logical_plan::producer::to_substrait_plan; use datafusion_substrait::serializer; use datafusion::error::Result; use datafusion::prelude::*; use std::fs; + use substrait::proto::plan_rel::RelType; + use substrait::proto::rel_common::{Emit, EmitKind}; + use substrait::proto::{rel, RelCommon}; #[tokio::test] async fn serialize_simple_select() -> Result<()> { @@ -43,8 +46,8 @@ mod tests { let proto = serializer::deserialize(path).await?; // Check plan equality let plan = from_substrait_plan(&ctx, &proto).await?; - let plan_str_ref = format!("{plan_ref:?}"); - let plan_str = format!("{plan:?}"); + let plan_str_ref = format!("{plan_ref}"); + let plan_str = format!("{plan}"); assert_eq!(plan_str_ref, plan_str); // Delete test binary file fs::remove_file(path)?; @@ -57,12 +60,109 @@ mod tests { let ctx = create_context().await?; let table = provider_as_source(ctx.table_provider("data").await?); let table_scan = LogicalPlanBuilder::scan("data", table, None)?.build()?; - let convert_result = producer::to_substrait_plan(&table_scan, &ctx); + let convert_result = to_substrait_plan(&table_scan, &ctx); assert!(convert_result.is_ok()); Ok(()) } + #[tokio::test] + async fn include_remaps_for_projects() -> Result<()> { + let ctx = create_context().await?; + let df = ctx.sql("SELECT b, a + a, a FROM data").await?; + let datafusion_plan = df.into_optimized_plan()?; + + assert_eq!( + format!("{}", datafusion_plan), + "Projection: data.b, data.a + data.a, data.a\ + \n TableScan: data projection=[a, b]", + ); + + let plan = to_substrait_plan(&datafusion_plan, &ctx)?.as_ref().clone(); + + let relation = plan.relations.first().unwrap().rel_type.as_ref(); + let root_rel = match relation { + Some(RelType::Root(root)) => root.input.as_ref().unwrap(), + _ => panic!("expected Root"), + }; + if let Some(rel::RelType::Project(p)) = root_rel.rel_type.as_ref() { + // The input has 2 columns [a, b], the Projection has 3 expressions [b, a + a, a] + // The required output mapping is [2,3,4], which skips the 2 input columns. + assert_emit(p.common.as_ref(), vec![2, 3, 4]); + + if let Some(rel::RelType::Read(r)) = + p.input.as_ref().unwrap().rel_type.as_ref() + { + let mask_expression = r.projection.as_ref().unwrap(); + let select = mask_expression.select.as_ref().unwrap(); + assert_eq!( + 2, + select.struct_items.len(), + "Read outputs two columns: a, b" + ); + return Ok(()); + } + } + panic!("plan did not match expected structure") + } + + #[tokio::test] + async fn include_remaps_for_windows() -> Result<()> { + let ctx = create_context().await?; + // let df = ctx.sql("SELECT a, b, lead(b) OVER (PARTITION BY a) FROM data").await?; + let df = ctx + .sql("SELECT b, RANK() OVER (PARTITION BY a), c FROM data;") + .await?; + let datafusion_plan = df.into_optimized_plan()?; + assert_eq!( + format!("{}", datafusion_plan), + "Projection: data.b, rank() PARTITION BY [data.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING, data.c\ + \n WindowAggr: windowExpr=[[rank() PARTITION BY [data.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ + \n TableScan: data projection=[a, b, c]", + ); + + let plan = to_substrait_plan(&datafusion_plan, &ctx)?.as_ref().clone(); + + let relation = plan.relations.first().unwrap().rel_type.as_ref(); + let root_rel = match relation { + Some(RelType::Root(root)) => root.input.as_ref().unwrap(), + _ => panic!("expected Root"), + }; + + if let Some(rel::RelType::Project(p1)) = root_rel.rel_type.as_ref() { + // The WindowAggr outputs 4 columns, the Projection has 4 columns + assert_emit(p1.common.as_ref(), vec![4, 5, 6]); + + if let Some(rel::RelType::Project(p2)) = + p1.input.as_ref().unwrap().rel_type.as_ref() + { + // The input has 3 columns, the WindowAggr has 4 expression + assert_emit(p2.common.as_ref(), vec![3, 4, 5, 6]); + + if let Some(rel::RelType::Read(r)) = + p2.input.as_ref().unwrap().rel_type.as_ref() + { + let mask_expression = r.projection.as_ref().unwrap(); + let select = mask_expression.select.as_ref().unwrap(); + assert_eq!( + 3, + select.struct_items.len(), + "Read outputs three columns: a, b, c" + ); + return Ok(()); + } + } + } + panic!("plan did not match expected structure") + } + + fn assert_emit(rel_common: Option<&RelCommon>, output_mapping: Vec) { + assert_eq!( + rel_common.unwrap().emit_kind.clone(), + Some(EmitKind::Emit(Emit { output_mapping })) + ); + } + async fn create_context() -> Result { let ctx = SessionContext::new(); ctx.register_csv("data", "tests/testdata/data.csv", CsvReadOptions::new()) diff --git a/datafusion/substrait/tests/cases/substrait_validations.rs b/datafusion/substrait/tests/cases/substrait_validations.rs new file mode 100644 index 000000000000..5ae586afe56f --- /dev/null +++ b/datafusion/substrait/tests/cases/substrait_validations.rs @@ -0,0 +1,148 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[cfg(test)] +mod tests { + + // verify the schema compatability validations + mod schema_compatability { + use crate::utils::test::read_json; + use datafusion::arrow::datatypes::{DataType, Field}; + use datafusion::catalog_common::TableReference; + use datafusion::common::{DFSchema, Result}; + use datafusion::datasource::empty::EmptyTable; + use datafusion::prelude::SessionContext; + use datafusion_substrait::logical_plan::consumer::from_substrait_plan; + use std::collections::HashMap; + use std::sync::Arc; + + fn generate_context_with_table( + table_name: &str, + fields: Vec<(&str, DataType, bool)>, + ) -> Result { + let table_ref = TableReference::bare(table_name); + let fields: Vec<(Option, Arc)> = fields + .into_iter() + .map(|pair| { + let (field_name, data_type, nullable) = pair; + ( + Some(table_ref.clone()), + Arc::new(Field::new(field_name, data_type, nullable)), + ) + }) + .collect(); + + let df_schema = DFSchema::new_with_metadata(fields, HashMap::default())?; + + let ctx = SessionContext::new(); + ctx.register_table( + table_ref, + Arc::new(EmptyTable::new(df_schema.inner().clone())), + )?; + Ok(ctx) + } + + #[tokio::test] + async fn ensure_schema_match_exact() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/simple_select.substrait.json"); + // this is the exact schema of the Substrait plan + let df_schema = + vec![("a", DataType::Int32, false), ("b", DataType::Int32, true)]; + + let ctx = generate_context_with_table("DATA", df_schema)?; + let plan = from_substrait_plan(&ctx, &proto_plan).await?; + + assert_eq!( + format!("{}", plan), + "Projection: DATA.a, DATA.b\ + \n TableScan: DATA" + ); + Ok(()) + } + + #[tokio::test] + async fn ensure_schema_match_subset() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/simple_select.substrait.json"); + // the DataFusion schema { b, a, c } contains the Substrait schema { a, b } + let df_schema = vec![ + ("b", DataType::Int32, true), + ("a", DataType::Int32, false), + ("c", DataType::Int32, false), + ]; + let ctx = generate_context_with_table("DATA", df_schema)?; + let plan = from_substrait_plan(&ctx, &proto_plan).await?; + + assert_eq!( + format!("{}", plan), + "Projection: DATA.a, DATA.b\ + \n TableScan: DATA projection=[a, b]" + ); + Ok(()) + } + + #[tokio::test] + async fn ensure_schema_match_subset_with_mask() -> Result<()> { + let proto_plan = read_json( + "tests/testdata/test_plans/simple_select_with_mask.substrait.json", + ); + // the DataFusion schema { d, a, c, b } contains the Substrait schema { a, b, c } + let df_schema = vec![ + ("d", DataType::Int32, true), + ("a", DataType::Int32, false), + ("c", DataType::Int32, false), + ("b", DataType::Int32, false), + ]; + let ctx = generate_context_with_table("DATA", df_schema)?; + let plan = from_substrait_plan(&ctx, &proto_plan).await?; + + assert_eq!( + format!("{}", plan), + "Projection: DATA.a, DATA.b\ + \n TableScan: DATA projection=[a, b]" + ); + Ok(()) + } + + #[tokio::test] + async fn ensure_schema_match_not_subset() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/simple_select.substrait.json"); + // the substrait plans contains a field b which is not in the schema + let df_schema = + vec![("a", DataType::Int32, false), ("c", DataType::Int32, true)]; + + let ctx = generate_context_with_table("DATA", df_schema)?; + let res = from_substrait_plan(&ctx, &proto_plan).await; + assert!(res.is_err()); + Ok(()) + } + + #[tokio::test] + async fn reject_plans_with_incompatible_field_types() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/simple_select.substrait.json"); + + let ctx = + generate_context_with_table("DATA", vec![("a", DataType::Date32, true)])?; + let res = from_substrait_plan(&ctx, &proto_plan).await; + assert!(res.is_err()); + Ok(()) + } + } +} diff --git a/datafusion/substrait/tests/substrait_integration.rs b/datafusion/substrait/tests/substrait_integration.rs index 6ce41c9de71a..eedd4da373e0 100644 --- a/datafusion/substrait/tests/substrait_integration.rs +++ b/datafusion/substrait/tests/substrait_integration.rs @@ -17,3 +17,4 @@ /// Run all tests that are found in the `cases` directory mod cases; +mod utils; diff --git a/datafusion/substrait/tests/testdata/Readme.md b/datafusion/substrait/tests/testdata/Readme.md new file mode 100644 index 000000000000..c1bd48abf96e --- /dev/null +++ b/datafusion/substrait/tests/testdata/Readme.md @@ -0,0 +1,51 @@ + + +# Apache DataFusion Substrait Test Data + +This folder contains test data for the [substrait] crate. + +The substrait crate is at an init stage and many functions not implemented yet. Compared to the [parquet-testing](https://github.com/apache/parquet-testing) submodule, this folder contains only simple test data evolving around the substrait producers and consumers for [logical plans](https://github.com/apache/datafusion/tree/main/datafusion/substrait/src/logical_plan) and [physical plans](https://github.com/apache/datafusion/tree/main/datafusion/substrait/src/physical_plan). + +## Test Data + +### Example Data + +- [empty.csv](https://github.com/apache/datafusion/blob/main/datafusion/substrait/tests/testdata/empty.csv): An empty CSV file. +- [empty.parquet](https://github.com/apache/datafusion/blob/main/datafusion/substrait/tests/testdata/empty.parquet): An empty Parquet file with metadata only. +- [data.csv](https://github.com/apache/datafusion/blob/main/datafusion/substrait/tests/testdata/data.csv): A simple CSV file with 6 columns and 2 rows. +- [data.parquet](https://github.com/apache/datafusion/blob/main/datafusion/substrait/tests/testdata/data.parquet): A simple Parquet generated from the CSV file using `pandas`, e.g., + + ```python + import pandas as pd + + df = pandas.read_csv('data.csv') + df.to_parquet('data.parquet') + ``` + +### Add new test data + +To add a new test data, create a new file in this folder, reference it in the test source file, e.g., + +```rust +let ctx = SessionContext::new(); +let explicit_options = ParquetReadOptions::default(); + +ctx.register_parquet("data", "tests/testdata/data.parquet", explicit_options) +``` diff --git a/datafusion/substrait/tests/testdata/contains_plan.substrait.json b/datafusion/substrait/tests/testdata/contains_plan.substrait.json new file mode 100644 index 000000000000..76edde34e3b0 --- /dev/null +++ b/datafusion/substrait/tests/testdata/contains_plan.substrait.json @@ -0,0 +1,133 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_string.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "contains:str_str" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 4 + ] + } + }, + "input": { + "filter": { + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "n_nationkey", + "n_name", + "n_regionkey", + "n_comment" + ], + "struct": { + "types": [ + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "nation" + ] + } + } + }, + "condition": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + } + }, + { + "value": { + "literal": { + "string": "IA" + } + } + } + ] + } + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + } + ] + } + }, + "names": [ + "n_name" + ] + } + } + ], + "version": { + "minorNumber": 38, + "producer": "ibis-substrait" + } +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/data.csv b/datafusion/substrait/tests/testdata/data.csv index 1b85b166b1df..ef2766d29565 100644 --- a/datafusion/substrait/tests/testdata/data.csv +++ b/datafusion/substrait/tests/testdata/data.csv @@ -1,3 +1,3 @@ a,b,c,d,e,f -1,2.0,2020-01-01,false,4294967296,'a' -3,4.5,2020-01-01,true,2147483648,'b' \ No newline at end of file +1,2.0,2020-01-01,false,4294967295,'a' +3,4.5,2020-01-01,true,2147483648,'b' diff --git a/datafusion/substrait/tests/testdata/data.parquet b/datafusion/substrait/tests/testdata/data.parquet new file mode 100644 index 000000000000..f9c03394db43 Binary files /dev/null and b/datafusion/substrait/tests/testdata/data.parquet differ diff --git a/datafusion/substrait/tests/testdata/empty.parquet b/datafusion/substrait/tests/testdata/empty.parquet new file mode 100644 index 000000000000..3f135e77f498 Binary files /dev/null and b/datafusion/substrait/tests/testdata/empty.parquet differ diff --git a/datafusion/substrait/tests/testdata/test_plans/aggregate_no_project.substrait.json b/datafusion/substrait/tests/testdata/test_plans/aggregate_no_project.substrait.json new file mode 100644 index 000000000000..ed8675b96826 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/aggregate_no_project.substrait.json @@ -0,0 +1,97 @@ +{ + "extensionUris": [ + { + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate_generic.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 185, + "name": "count:any" + } + } + ], + "relations": [ + { + "root": { + "input": { + "aggregate": { + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "groupings": [ + { + "groupingExpressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + ], + "measures": [ + { + "measure": { + "functionReference": 185, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": {} + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + } + ] + } + } + ] + } + }, + "names": [ + "a", + "countA" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/aggregate_no_project_group_expression_ref.substrait.json b/datafusion/substrait/tests/testdata/test_plans/aggregate_no_project_group_expression_ref.substrait.json new file mode 100644 index 000000000000..b6f14afd6fa9 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/aggregate_no_project_group_expression_ref.substrait.json @@ -0,0 +1,98 @@ +{ + "extensionUris": [ + { + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate_generic.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 185, + "name": "count:any" + } + } + ], + "relations": [ + { + "root": { + "input": { + "aggregate": { + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "grouping_expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ], + "groupings": [ + { + "expression_references": [0] + } + ], + "measures": [ + { + "measure": { + "functionReference": 185, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": {} + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + } + ] + } + } + ] + } + }, + "names": [ + "a", + "countA" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/aggregate_sorted_no_project.substrait.json b/datafusion/substrait/tests/testdata/test_plans/aggregate_sorted_no_project.substrait.json new file mode 100644 index 000000000000..d5170223cd65 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/aggregate_sorted_no_project.substrait.json @@ -0,0 +1,113 @@ +{ + "extensionUris": [ + { + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_aggregate_generic.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "functionAnchor": 185, + "name": "count:any" + } + } + ], + "relations": [ + { + "root": { + "input": { + "aggregate": { + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "groupings": [ + { + "groupingExpressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + ], + "measures": [ + { + "measure": { + "functionReference": 185, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": {} + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + } + ], + "sorts": [ + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_DESC_NULLS_FIRST" + } + ] + } + } + ] + } + }, + "names": [ + "a", + "countA" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "manual" + } +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/emit_kind/direct_on_project.substrait.json b/datafusion/substrait/tests/testdata/test_plans/emit_kind/direct_on_project.substrait.json new file mode 100644 index 000000000000..63b275e1723f --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/emit_kind/direct_on_project.substrait.json @@ -0,0 +1,90 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_arithmetic.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "add:i64_i64" + } + }], + "relations": [{ + "root": { + "input": { + "project": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["A", "B"], + "struct": { + "types": [{ + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["DATA"] + } + } + }, + "expressions": [{ + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i64": 1, + "nullable": false, + "typeVariationReference": 0 + } + } + }], + "options": [] + } + }] + } + }, + "names": ["a", "b", "add1"] + } + }], + "expectedTypeUrls": [] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/emit_kind/emit_on_filter.substrait.json b/datafusion/substrait/tests/testdata/test_plans/emit_kind/emit_on_filter.substrait.json new file mode 100644 index 000000000000..2fc970155955 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/emit_kind/emit_on_filter.substrait.json @@ -0,0 +1,91 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "equal:any_any" + } + }], + "relations": [{ + "root": { + "input": { + "filter": { + "common": { + "emit": { + "outputMapping": [1, 0, 0] + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["A", "B"], + "struct": { + "types": [{ + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "i64": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["DATA"] + } + } + }, + "condition": { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i64": "2", + "nullable": false, + "typeVariationReference": 0 + } + } + }], + "options": [] + } + } + } + }, + "names": ["B", "A1", "A2"] + } + }], + "expectedTypeUrls": [] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/intersect.substrait.json b/datafusion/substrait/tests/testdata/test_plans/intersect.substrait.json new file mode 100644 index 000000000000..b9a2e4ad1403 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/intersect.substrait.json @@ -0,0 +1,118 @@ +{ + "relations": [ + { + "root": { + "input": { + "set": { + "inputs": [ + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + } + ], + "op": "SET_OP_INTERSECTION_PRIMARY" + } + }, + "names": [ + "a" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/intersect_multiset.substrait.json b/datafusion/substrait/tests/testdata/test_plans/intersect_multiset.substrait.json new file mode 100644 index 000000000000..8ff69bd82c3a --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/intersect_multiset.substrait.json @@ -0,0 +1,166 @@ +{ + "relations": [ + { + "root": { + "input": { + "set": { + "inputs": [ + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + } + ], + "op": "SET_OP_INTERSECTION_MULTISET" + } + }, + "names": [ + "a" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } + } \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/intersect_multiset_all.substrait.json b/datafusion/substrait/tests/testdata/test_plans/intersect_multiset_all.substrait.json new file mode 100644 index 000000000000..56daf6ed46f4 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/intersect_multiset_all.substrait.json @@ -0,0 +1,166 @@ +{ + "relations": [ + { + "root": { + "input": { + "set": { + "inputs": [ + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + } + ], + "op": "SET_OP_INTERSECTION_MULTISET_ALL" + } + }, + "names": [ + "a" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } + } \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/intersect_primary.substrait.json b/datafusion/substrait/tests/testdata/test_plans/intersect_primary.substrait.json new file mode 100644 index 000000000000..229dd7251705 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/intersect_primary.substrait.json @@ -0,0 +1,166 @@ +{ + "relations": [ + { + "root": { + "input": { + "set": { + "inputs": [ + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + } + ], + "op": "SET_OP_INTERSECTION_PRIMARY" + } + }, + "names": [ + "a" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } + } \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/minus_primary.substrait.json b/datafusion/substrait/tests/testdata/test_plans/minus_primary.substrait.json new file mode 100644 index 000000000000..33b0e2ab8c80 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/minus_primary.substrait.json @@ -0,0 +1,166 @@ +{ + "relations": [ + { + "root": { + "input": { + "set": { + "inputs": [ + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + } + ], + "op": "SET_OP_MINUS_PRIMARY" + } + }, + "names": [ + "a" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } + } \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/minus_primary_all.substrait.json b/datafusion/substrait/tests/testdata/test_plans/minus_primary_all.substrait.json new file mode 100644 index 000000000000..229f78ab5bf6 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/minus_primary_all.substrait.json @@ -0,0 +1,166 @@ +{ + "relations": [ + { + "root": { + "input": { + "set": { + "inputs": [ + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + } + ], + "op": "SET_OP_MINUS_PRIMARY_ALL" + } + }, + "names": [ + "a" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } + } \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/non_nullable_lists.substrait.json b/datafusion/substrait/tests/testdata/test_plans/non_nullable_lists.substrait.json new file mode 100644 index 000000000000..e1c5574f8bec --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/non_nullable_lists.substrait.json @@ -0,0 +1,71 @@ +{ + "extensionUris": [], + "extensions": [], + "relations": [ + { + "root": { + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "col" + ], + "struct": { + "types": [ + { + "list": { + "type": { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "virtualTable": { + "values": [ + { + "fields": [ + { + "list": { + "values": [ + { + "i32": 1, + "nullable": false, + "typeVariationReference": 0 + }, + { + "i32": 2, + "nullable": false, + "typeVariationReference": 0 + } + ] + }, + "nullable": false, + "typeVariationReference": 0 + } + ] + } + ] + } + } + }, + "names": [ + "col" + ] + } + } + ], + "expectedTypeUrls": [] +} diff --git a/datafusion/substrait/tests/testdata/test_plans/select_not_bool.substrait.json b/datafusion/substrait/tests/testdata/test_plans/select_not_bool.substrait.json new file mode 100644 index 000000000000..e52cf87d5028 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/select_not_bool.substrait.json @@ -0,0 +1,98 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "not:bool" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "D" + ], + "struct": { + "types": [ + { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "DATA" + ] + } + } + }, + "expressions": [ + { + "scalarFunction": { + "functionReference": 0, + "args": [], + "outputType": { + "bool": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + } + ], + "options": [] + } + } + ] + } + }, + "names": [ + "EXPR$0" + ] + } + } + ], + "expectedTypeUrls": [] +} diff --git a/datafusion/substrait/tests/testdata/test_plans/select_window.substrait.json b/datafusion/substrait/tests/testdata/test_plans/select_window.substrait.json new file mode 100644 index 000000000000..3082c4258f83 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/select_window.substrait.json @@ -0,0 +1,153 @@ +{ + "extensionUris": [ + { + "extensionUriAnchor": 1, + "uri": "/functions_arithmetic.yaml" + } + ], + "extensions": [ + { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 0, + "name": "sum:i32" + } + } + ], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 3 + ] + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": [ + "D", + "PART", + "ORD" + ], + "struct": { + "types": [ + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "DATA" + ] + } + } + }, + "expressions": [ + { + "windowFunction": { + "functionReference": 0, + "partitions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + ], + "sorts": [ + { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + } + ], + "upperBound": { + "unbounded": { + } + }, + "lowerBound": { + "preceding": { + "offset": "1" + } + }, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "args": [], + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + } + } + ], + "invocation": "AGGREGATION_INVOCATION_ALL", + "options": [], + "boundsType": "BOUNDS_TYPE_ROWS" + } + } + ] + } + }, + "names": [ + "LEAD_EXPR" + ] + } + } + ], + "expectedTypeUrls": [] +} diff --git a/datafusion/substrait/tests/testdata/test_plans/simple_select.substrait.json b/datafusion/substrait/tests/testdata/test_plans/simple_select.substrait.json new file mode 100644 index 000000000000..aee27ef3b417 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/simple_select.substrait.json @@ -0,0 +1,69 @@ +{ + "extensionUris": [], + "extensions": [], + "relations": [{ + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [2, 3] + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["a", "b"], + "struct": { + "types": [{ + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["DATA"] + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }] + } + }, + "names": ["a", "b"] + } + }], + "expectedTypeUrls": [] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/simple_select_with_mask.substrait.json b/datafusion/substrait/tests/testdata/test_plans/simple_select_with_mask.substrait.json new file mode 100644 index 000000000000..774126ca3836 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/simple_select_with_mask.substrait.json @@ -0,0 +1,104 @@ +{ + "extensionUris": [], + "extensions": [], + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 2, + 3 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a", + "b", + "c" + ], + "struct": { + "types": [ + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "i32": { + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "typeVariationReference": 0, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "DATA" + ] + }, + "projection": { + "select": { + "struct_items": [ + { + "field": 0 + }, + { + "field": 1 + } + ] + } + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": { + "field": 0 + } + }, + "rootReference": {} + } + }, + { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": {} + } + } + ] + } + }, + "names": [ + "a", + "b" + ] + } + } + ], + "expectedTypeUrls": [] +} diff --git a/datafusion/substrait/tests/testdata/test_plans/union_distinct.substrait.json b/datafusion/substrait/tests/testdata/test_plans/union_distinct.substrait.json new file mode 100644 index 000000000000..e8b02749660d --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/union_distinct.substrait.json @@ -0,0 +1,118 @@ +{ + "relations": [ + { + "root": { + "input": { + "set": { + "inputs": [ + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + } + ], + "op": "SET_OP_UNION_DISTINCT" + } + }, + "names": [ + "a" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/README.md b/datafusion/substrait/tests/testdata/tpch_substrait_plans/README.md new file mode 100644 index 000000000000..ffcd38dfb88d --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/README.md @@ -0,0 +1,22 @@ + + +# Apache DataFusion Substrait consumer integration test + +these test json files come from [consumer-testing](https://github.com/substrait-io/consumer-testing/tree/main/substrait_consumer/tests/integration/queries/tpch_substrait_plans) diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_01_plan.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_01_plan.json new file mode 100644 index 000000000000..3738a50a6238 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_01_plan.json @@ -0,0 +1,723 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 3, + "uri": "/functions_aggregate_generic.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_arithmetic_decimal.yaml" + }, { + "extensionUriAnchor": 1, + "uri": "/functions_datetime.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "name": "lte:date_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "subtract:date_iday" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 2, + "name": "multiply:dec_dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 3, + "name": "subtract:dec_dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 4, + "name": "add:dec_dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 5, + "name": "sum:dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 6, + "name": "avg:dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 7, + "name": "count:" + } + }], + "relations": [{ + "root": { + "input": { + "sort": { + "common": { + "direct": { + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [16, 17, 18, 19, 20, 21, 22] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["L_ORDERKEY", "L_PARTKEY", "L_SUPPKEY", "L_LINENUMBER", "L_QUANTITY", "L_EXTENDEDPRICE", "L_DISCOUNT", "L_TAX", "L_RETURNFLAG", "L_LINESTATUS", "L_SHIPDATE", "L_COMMITDATE", "L_RECEIPTDATE", "L_SHIPINSTRUCT", "L_SHIPMODE", "L_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["LINEITEM"] + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "literal": { + "date": 10561 + } + } + }, { + "value": { + "literal": { + "intervalDayToSecond": { + "seconds": 10368 + } + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 9 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + }, { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "decimal": { + "scale": 4, + "precision": 19, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "decimal": { + "scale": 2, + "precision": 16, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "cast": { + "type": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "i32": 1 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": { + } + } + } + }] + } + } + }] + } + }, { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "decimal": { + "scale": 6, + "precision": 19, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "decimal": { + "scale": 4, + "precision": 19, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "decimal": { + "scale": 2, + "precision": 16, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "cast": { + "type": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "i32": 1 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": { + } + } + } + }] + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 4, + "outputType": { + "decimal": { + "scale": 2, + "precision": 16, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "cast": { + "type": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "i32": 1 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "rootReference": { + } + } + } + }] + } + } + }] + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": { + } + } + }] + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 5, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + } + }] + } + }, { + "measure": { + "functionReference": 5, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + } + }] + } + }, { + "measure": { + "functionReference": 5, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 4, + "precision": 19, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }] + } + }, { + "measure": { + "functionReference": 5, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 6, + "precision": 19, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + }] + } + }, { + "measure": { + "functionReference": 6, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + } + }] + } + }, { + "measure": { + "functionReference": 6, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + } + }] + } + }, { + "measure": { + "functionReference": 6, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": { + } + } + } + }] + } + }, { + "measure": { + "functionReference": 7, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL" + } + }] + } + }, + "sorts": [{ + "expr": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }, { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }] + } + }, + "names": ["L_RETURNFLAG", "L_LINESTATUS", "SUM_QTY", "SUM_BASE_PRICE", "SUM_DISC_PRICE", "SUM_CHARGE", "AVG_QTY", "AVG_PRICE", "AVG_DISC", "COUNT_ORDER"] + } + }] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_02_plan.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_02_plan.json new file mode 100644 index 000000000000..f6c5e802a5e3 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_02_plan.json @@ -0,0 +1,1157 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, { + "extensionUriAnchor": 3, + "uri": "/functions_string.yaml" + }, { + "extensionUriAnchor": 4, + "uri": "/functions_arithmetic_decimal.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "name": "and:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "equal:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 2, + "name": "like:str_str" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 3, + "name": "min:dec" + } + }], + "relations": [{ + "root": { + "input": { + "fetch": { + "common": { + "direct": { + } + }, + "input": { + "sort": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [28, 29, 30, 31, 32, 33, 34, 35] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["P_PARTKEY", "P_NAME", "P_MFGR", "P_BRAND", "P_TYPE", "P_SIZE", "P_CONTAINER", "P_RETAILPRICE", "P_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["PART"] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["S_SUPPKEY", "S_NAME", "S_ADDRESS", "S_NATIONKEY", "S_PHONE", "S_ACCTBAL", "S_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["SUPPLIER"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["PS_PARTKEY", "PS_SUPPKEY", "PS_AVAILQTY", "PS_SUPPLYCOST", "PS_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["PARTSUPP"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"], + "struct": { + "types": [{ + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["NATION"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["R_REGIONKEY", "R_NAME", "R_COMMENT"], + "struct": { + "types": [{ + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["REGION"] + } + } + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 16 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 9 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 17 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 15 + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "%BRASS" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 12 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 21 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 23 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 25 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 26 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "EUROPE" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 19 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "subquery": { + "scalar": { + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [19] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["PS_PARTKEY", "PS_SUPPKEY", "PS_AVAILQTY", "PS_SUPPLYCOST", "PS_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["PARTSUPP"] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["S_SUPPKEY", "S_NAME", "S_ADDRESS", "S_NATIONKEY", "S_PHONE", "S_ACCTBAL", "S_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["SUPPLIER"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"], + "struct": { + "types": [{ + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["NATION"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["R_REGIONKEY", "R_NAME", "R_COMMENT"], + "struct": { + "types": [{ + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["REGION"] + } + } + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "outerReference": { + "stepsOut": 1 + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 12 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 14 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 16 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 17 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "EUROPE" + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + }] + } + }, + "groupings": [{ + }], + "measures": [{ + "measure": { + "functionReference": 3, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + } + } + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 14 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 11 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 13 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 15 + } + }, + "rootReference": { + } + } + }] + } + }, + "sorts": [{ + "expr": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_DESC_NULLS_FIRST" + }, { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }, { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }, { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }] + } + }, + "count": "100" + } + }, + "names": ["S_ACCTBAL", "S_NAME", "N_NAME", "P_PARTKEY", "P_MFGR", "S_ADDRESS", "S_PHONE", "S_COMMENT"] + } + }] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_03_plan.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_03_plan.json new file mode 100644 index 000000000000..d4dea1d03c46 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_03_plan.json @@ -0,0 +1,742 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, { + "extensionUriAnchor": 4, + "uri": "/functions_arithmetic_decimal.yaml" + }, { + "extensionUriAnchor": 3, + "uri": "/functions_datetime.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "name": "and:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "equal:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 2, + "name": "lt:date_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 3, + "name": "gt:date_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 4, + "name": "multiply:dec_dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 5, + "name": "subtract:dec_dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 6, + "name": "sum:dec" + } + }], + "relations": [{ + "root": { + "input": { + "fetch": { + "common": { + "direct": { + } + }, + "input": { + "sort": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [4, 5, 6, 7] + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [33, 34, 35, 36] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["L_ORDERKEY", "L_PARTKEY", "L_SUPPKEY", "L_LINENUMBER", "L_QUANTITY", "L_EXTENDEDPRICE", "L_DISCOUNT", "L_TAX", "L_RETURNFLAG", "L_LINESTATUS", "L_SHIPDATE", "L_COMMITDATE", "L_RECEIPTDATE", "L_SHIPINSTRUCT", "L_SHIPMODE", "L_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["LINEITEM"] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["C_CUSTKEY", "C_NAME", "C_ADDRESS", "C_NATIONKEY", "C_PHONE", "C_ACCTBAL", "C_MKTSEGMENT", "C_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["CUSTOMER"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["O_ORDERKEY", "O_CUSTKEY", "O_ORDERSTATUS", "O_TOTALPRICE", "O_ORDERDATE", "O_ORDERPRIORITY", "O_CLERK", "O_SHIPPRIORITY", "O_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["ORDERS"] + } + } + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "BUILDING" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 16 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 25 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 24 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 28 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "1995-03-15" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "1995-03-15" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 28 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 31 + } + }, + "rootReference": { + } + } + }, { + "scalarFunction": { + "functionReference": 4, + "outputType": { + "decimal": { + "scale": 4, + "precision": 19, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 5, + "outputType": { + "decimal": { + "scale": 2, + "precision": 16, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "cast": { + "type": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "i32": 1 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": { + } + } + } + }] + } + } + }] + } + }] + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 6, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 4, + "precision": 19, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }] + } + }, + "sorts": [{ + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_DESC_NULLS_FIRST" + }, { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }] + } + }, + "count": "10" + } + }, + "names": ["L_ORDERKEY", "REVENUE", "O_ORDERDATE", "O_SHIPPRIORITY"] + } + }] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_04_plan.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_04_plan.json new file mode 100644 index 000000000000..3e665f50f320 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_04_plan.json @@ -0,0 +1,464 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 4, + "uri": "/functions_aggregate_generic.yaml" + }, { + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_datetime.yaml" + }, { + "extensionUriAnchor": 3, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "name": "and:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "gte:date_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 2, + "name": "lt:date_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 3, + "name": "equal:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 4, + "name": "count:" + } + }], + "relations": [{ + "root": { + "input": { + "sort": { + "common": { + "direct": { + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [9] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["O_ORDERKEY", "O_CUSTKEY", "O_ORDERSTATUS", "O_TOTALPRICE", "O_ORDERDATE", "O_ORDERPRIORITY", "O_CLERK", "O_SHIPPRIORITY", "O_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["ORDERS"] + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "1993-07-01" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "1993-10-01" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "subquery": { + "setPredicate": { + "predicateOp": "PREDICATE_OP_EXISTS", + "tuples": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["L_ORDERKEY", "L_PARTKEY", "L_SUPPKEY", "L_LINENUMBER", "L_QUANTITY", "L_EXTENDEDPRICE", "L_DISCOUNT", "L_TAX", "L_RETURNFLAG", "L_LINESTATUS", "L_SHIPDATE", "L_COMMITDATE", "L_RECEIPTDATE", "L_SHIPINSTRUCT", "L_SHIPMODE", "L_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["LINEITEM"] + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "outerReference": { + "stepsOut": 1 + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 11 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 12 + } + }, + "rootReference": { + } + } + } + }] + } + } + }] + } + } + } + } + } + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + }] + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 4, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL" + } + }] + } + }, + "sorts": [{ + "expr": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }] + } + }, + "names": ["O_ORDERPRIORITY", "ORDER_COUNT"] + } + }] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_05_plan.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_05_plan.json new file mode 100644 index 000000000000..d42975d3326d --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_05_plan.json @@ -0,0 +1,912 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, { + "extensionUriAnchor": 4, + "uri": "/functions_arithmetic_decimal.yaml" + }, { + "extensionUriAnchor": 3, + "uri": "/functions_datetime.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "name": "and:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "equal:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 2, + "name": "gte:date_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 3, + "name": "lt:date_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 4, + "name": "multiply:dec_dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 5, + "name": "subtract:dec_dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 6, + "name": "sum:dec" + } + }], + "relations": [{ + "root": { + "input": { + "sort": { + "common": { + "direct": { + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [47, 48] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["C_CUSTKEY", "C_NAME", "C_ADDRESS", "C_NATIONKEY", "C_PHONE", "C_ACCTBAL", "C_MKTSEGMENT", "C_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["CUSTOMER"] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["O_ORDERKEY", "O_CUSTKEY", "O_ORDERSTATUS", "O_TOTALPRICE", "O_ORDERDATE", "O_ORDERPRIORITY", "O_CLERK", "O_SHIPPRIORITY", "O_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["ORDERS"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["L_ORDERKEY", "L_PARTKEY", "L_SUPPKEY", "L_LINENUMBER", "L_QUANTITY", "L_EXTENDEDPRICE", "L_DISCOUNT", "L_TAX", "L_RETURNFLAG", "L_LINESTATUS", "L_SHIPDATE", "L_COMMITDATE", "L_RECEIPTDATE", "L_SHIPINSTRUCT", "L_SHIPMODE", "L_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["LINEITEM"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["S_SUPPKEY", "S_NAME", "S_ADDRESS", "S_NATIONKEY", "S_PHONE", "S_ACCTBAL", "S_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["SUPPLIER"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"], + "struct": { + "types": [{ + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["NATION"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["R_REGIONKEY", "R_NAME", "R_COMMENT"], + "struct": { + "types": [{ + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["REGION"] + } + } + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 9 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 17 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 19 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 33 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 36 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 36 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 40 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 42 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 44 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 45 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "ASIA" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 12 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "1994-01-01" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 12 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "1995-01-01" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 41 + } + }, + "rootReference": { + } + } + }, { + "scalarFunction": { + "functionReference": 4, + "outputType": { + "decimal": { + "scale": 4, + "precision": 19, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 5, + "outputType": { + "decimal": { + "scale": 2, + "precision": 16, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "cast": { + "type": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "i32": 1 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 23 + } + }, + "rootReference": { + } + } + } + }] + } + } + }] + } + }] + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 6, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 4, + "precision": 19, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "sorts": [{ + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_DESC_NULLS_FIRST" + }] + } + }, + "names": ["N_NAME", "REVENUE"] + } + }] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_06_plan.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_06_plan.json new file mode 100644 index 000000000000..c26f2861e0d1 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_06_plan.json @@ -0,0 +1,448 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, { + "extensionUriAnchor": 4, + "uri": "/functions_arithmetic_decimal.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_datetime.yaml" + }, { + "extensionUriAnchor": 3, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "name": "and:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "gte:date_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 2, + "name": "lt:date_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 3, + "name": "gte:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 4, + "name": "lte:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 5, + "name": "lt:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 6, + "name": "multiply:dec_dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 7, + "name": "sum:dec" + } + }], + "relations": [{ + "root": { + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [16] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["L_ORDERKEY", "L_PARTKEY", "L_SUPPKEY", "L_LINENUMBER", "L_QUANTITY", "L_EXTENDEDPRICE", "L_DISCOUNT", "L_TAX", "L_RETURNFLAG", "L_LINESTATUS", "L_SHIPDATE", "L_COMMITDATE", "L_RECEIPTDATE", "L_SHIPINSTRUCT", "L_SHIPMODE", "L_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["LINEITEM"] + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "1994-01-01" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "1995-01-01" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "decimal": { + "value": "BQAAAAAAAAAAAAAAAAAAAA==", + "precision": 3, + "scale": 2 + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 4, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "decimal": { + "value": "BwAAAAAAAAAAAAAAAAAAAA==", + "precision": 3, + "scale": 2 + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 5, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "i32": 24 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "scalarFunction": { + "functionReference": 6, + "outputType": { + "decimal": { + "scale": 4, + "precision": 19, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "groupings": [{ + }], + "measures": [{ + "measure": { + "functionReference": 7, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 4, + "precision": 19, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "names": ["REVENUE"] + } + }] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_07_plan.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_07_plan.json new file mode 100644 index 000000000000..82740fb3d87b --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_07_plan.json @@ -0,0 +1,1095 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, { + "extensionUriAnchor": 4, + "uri": "/functions_arithmetic_decimal.yaml" + }, { + "extensionUriAnchor": 3, + "uri": "/functions_datetime.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "name": "and:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "equal:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 2, + "name": "or:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 3, + "name": "gte:date_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 4, + "name": "lte:date_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 5, + "name": "extract:req_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 6, + "name": "multiply:dec_dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 7, + "name": "subtract:dec_dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 8, + "name": "sum:dec" + } + }], + "relations": [{ + "root": { + "input": { + "sort": { + "common": { + "direct": { + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [48, 49, 50, 51] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["S_SUPPKEY", "S_NAME", "S_ADDRESS", "S_NATIONKEY", "S_PHONE", "S_ACCTBAL", "S_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["SUPPLIER"] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["L_ORDERKEY", "L_PARTKEY", "L_SUPPKEY", "L_LINENUMBER", "L_QUANTITY", "L_EXTENDEDPRICE", "L_DISCOUNT", "L_TAX", "L_RETURNFLAG", "L_LINESTATUS", "L_SHIPDATE", "L_COMMITDATE", "L_RECEIPTDATE", "L_SHIPINSTRUCT", "L_SHIPMODE", "L_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["LINEITEM"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["O_ORDERKEY", "O_CUSTKEY", "O_ORDERSTATUS", "O_TOTALPRICE", "O_ORDERDATE", "O_ORDERPRIORITY", "O_CLERK", "O_SHIPPRIORITY", "O_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["ORDERS"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["C_CUSTKEY", "C_NAME", "C_ADDRESS", "C_NATIONKEY", "C_PHONE", "C_ACCTBAL", "C_MKTSEGMENT", "C_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["CUSTOMER"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"], + "struct": { + "types": [{ + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["NATION"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"], + "struct": { + "types": [{ + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["NATION"] + } + } + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 9 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 23 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 32 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 24 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 40 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 35 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 44 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 41 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "FRANCE" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 45 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "GERMANY" + } + } + }] + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 41 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "GERMANY" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 45 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "FRANCE" + } + } + }] + } + } + }] + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 17 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "1995-01-01" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 4, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 17 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "1996-12-31" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 41 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 45 + } + }, + "rootReference": { + } + } + }, { + "scalarFunction": { + "functionReference": 5, + "outputType": { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "enum": "YEAR" + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 17 + } + }, + "rootReference": { + } + } + } + }] + } + }, { + "scalarFunction": { + "functionReference": 6, + "outputType": { + "decimal": { + "scale": 4, + "precision": 19, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 12 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 7, + "outputType": { + "decimal": { + "scale": 2, + "precision": 16, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "cast": { + "type": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "i32": 1 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 13 + } + }, + "rootReference": { + } + } + } + }] + } + } + }] + } + }] + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 8, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 4, + "precision": 19, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "sorts": [{ + "expr": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }, { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }, { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }] + } + }, + "names": ["SUPP_NATION", "CUST_NATION", "L_YEAR", "REVENUE"] + } + }] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_08_plan.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_08_plan.json new file mode 100644 index 000000000000..8c886f84ed16 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_08_plan.json @@ -0,0 +1,1301 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, { + "extensionUriAnchor": 4, + "uri": "/functions_arithmetic_decimal.yaml" + }, { + "extensionUriAnchor": 3, + "uri": "/functions_datetime.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "name": "and:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "equal:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 2, + "name": "gte:date_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 3, + "name": "lte:date_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 4, + "name": "extract:req_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 5, + "name": "multiply:dec_dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 6, + "name": "subtract:dec_dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 7, + "name": "sum:dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 8, + "name": "divide:dec_dec" + } + }], + "relations": [{ + "root": { + "input": { + "sort": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [3, 4] + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [60, 61, 62] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["P_PARTKEY", "P_NAME", "P_MFGR", "P_BRAND", "P_TYPE", "P_SIZE", "P_CONTAINER", "P_RETAILPRICE", "P_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["PART"] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["S_SUPPKEY", "S_NAME", "S_ADDRESS", "S_NATIONKEY", "S_PHONE", "S_ACCTBAL", "S_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["SUPPLIER"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["L_ORDERKEY", "L_PARTKEY", "L_SUPPKEY", "L_LINENUMBER", "L_QUANTITY", "L_EXTENDEDPRICE", "L_DISCOUNT", "L_TAX", "L_RETURNFLAG", "L_LINESTATUS", "L_SHIPDATE", "L_COMMITDATE", "L_RECEIPTDATE", "L_SHIPINSTRUCT", "L_SHIPMODE", "L_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["LINEITEM"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["O_ORDERKEY", "O_CUSTKEY", "O_ORDERSTATUS", "O_TOTALPRICE", "O_ORDERDATE", "O_ORDERPRIORITY", "O_CLERK", "O_SHIPPRIORITY", "O_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["ORDERS"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["C_CUSTKEY", "C_NAME", "C_ADDRESS", "C_NATIONKEY", "C_PHONE", "C_ACCTBAL", "C_MKTSEGMENT", "C_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["CUSTOMER"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"], + "struct": { + "types": [{ + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["NATION"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"], + "struct": { + "types": [{ + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["NATION"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["R_REGIONKEY", "R_NAME", "R_COMMENT"], + "struct": { + "types": [{ + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["REGION"] + } + } + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 17 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 9 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 18 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 16 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 32 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 33 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 41 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 44 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 49 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 51 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 57 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 58 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "AMERICA" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 12 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 53 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 36 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "1995-01-01" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 36 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "1996-12-31" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "ECONOMY ANODIZED STEEL" + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "scalarFunction": { + "functionReference": 4, + "outputType": { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "enum": "YEAR" + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 36 + } + }, + "rootReference": { + } + } + } + }] + } + }, { + "ifThen": { + "ifs": [{ + "if": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 54 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "BRAZIL" + } + } + }] + } + }, + "then": { + "scalarFunction": { + "functionReference": 5, + "outputType": { + "decimal": { + "scale": 4, + "precision": 19, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 21 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 6, + "outputType": { + "decimal": { + "scale": 2, + "precision": 16, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "cast": { + "type": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "i32": 1 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }] + } + } + }] + } + } + }], + "else": { + "literal": { + "decimal": { + "value": "AAAAAAAAAAAAAAAAAAAAAA==", + "precision": 19, + "scale": 4 + } + } + } + } + }, { + "scalarFunction": { + "functionReference": 5, + "outputType": { + "decimal": { + "scale": 4, + "precision": 19, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 21 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 6, + "outputType": { + "decimal": { + "scale": 2, + "precision": 16, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "cast": { + "type": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "i32": 1 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }] + } + } + }] + } + }] + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 7, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 4, + "precision": 19, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }] + } + }, { + "measure": { + "functionReference": 7, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 4, + "precision": 19, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, { + "scalarFunction": { + "functionReference": 8, + "outputType": { + "decimal": { + "precision": 19, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "sorts": [{ + "expr": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }] + } + }, + "names": ["O_YEAR", "MKT_SHARE"] + } + }] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_09_plan.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_09_plan.json new file mode 100644 index 000000000000..04b367a0b5bf --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_09_plan.json @@ -0,0 +1,957 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, { + "extensionUriAnchor": 3, + "uri": "/functions_string.yaml" + }, { + "extensionUriAnchor": 5, + "uri": "/functions_arithmetic_decimal.yaml" + }, { + "extensionUriAnchor": 4, + "uri": "/functions_datetime.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "name": "and:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "equal:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 2, + "name": "like:str_str" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 3, + "name": "extract:req_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 5, + "functionAnchor": 4, + "name": "subtract:dec_dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 5, + "functionAnchor": 5, + "name": "multiply:dec_dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 5, + "functionAnchor": 6, + "name": "sum:dec" + } + }], + "relations": [{ + "root": { + "input": { + "sort": { + "common": { + "direct": { + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [50, 51, 52] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["P_PARTKEY", "P_NAME", "P_MFGR", "P_BRAND", "P_TYPE", "P_SIZE", "P_CONTAINER", "P_RETAILPRICE", "P_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["PART"] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["S_SUPPKEY", "S_NAME", "S_ADDRESS", "S_NATIONKEY", "S_PHONE", "S_ACCTBAL", "S_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["SUPPLIER"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["L_ORDERKEY", "L_PARTKEY", "L_SUPPKEY", "L_LINENUMBER", "L_QUANTITY", "L_EXTENDEDPRICE", "L_DISCOUNT", "L_TAX", "L_RETURNFLAG", "L_LINESTATUS", "L_SHIPDATE", "L_COMMITDATE", "L_RECEIPTDATE", "L_SHIPINSTRUCT", "L_SHIPMODE", "L_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["LINEITEM"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["PS_PARTKEY", "PS_SUPPKEY", "PS_AVAILQTY", "PS_SUPPLYCOST", "PS_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["PARTSUPP"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["O_ORDERKEY", "O_CUSTKEY", "O_ORDERSTATUS", "O_TOTALPRICE", "O_ORDERDATE", "O_ORDERPRIORITY", "O_CLERK", "O_SHIPPRIORITY", "O_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["ORDERS"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"], + "struct": { + "types": [{ + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["NATION"] + } + } + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 9 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 18 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 33 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 18 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 32 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 17 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 17 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 37 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 16 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 12 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 46 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "%green%" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 47 + } + }, + "rootReference": { + } + } + }, { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "enum": "YEAR" + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 41 + } + }, + "rootReference": { + } + } + } + }] + } + }, { + "scalarFunction": { + "functionReference": 4, + "outputType": { + "decimal": { + "scale": 4, + "precision": 19, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 5, + "outputType": { + "decimal": { + "scale": 4, + "precision": 19, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 21 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 4, + "outputType": { + "decimal": { + "scale": 2, + "precision": 16, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "cast": { + "type": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "i32": 1 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }] + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 5, + "outputType": { + "decimal": { + "scale": 4, + "precision": 19, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 35 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 20 + } + }, + "rootReference": { + } + } + } + }] + } + } + }] + } + }] + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 6, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 4, + "precision": 19, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "sorts": [{ + "expr": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }, { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_DESC_NULLS_FIRST" + }] + } + }, + "names": ["NATION", "O_YEAR", "SUM_PROFIT"] + } + }] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_10_plan.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_10_plan.json new file mode 100644 index 000000000000..2daa1dabb423 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_10_plan.json @@ -0,0 +1,927 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, { + "extensionUriAnchor": 4, + "uri": "/functions_arithmetic_decimal.yaml" + }, { + "extensionUriAnchor": 3, + "uri": "/functions_datetime.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "name": "and:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "equal:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 2, + "name": "gte:date_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 3, + "name": "lt:date_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 4, + "name": "multiply:dec_dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 5, + "name": "subtract:dec_dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 6, + "name": "sum:dec" + } + }], + "relations": [{ + "root": { + "input": { + "fetch": { + "common": { + "direct": { + } + }, + "input": { + "sort": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [8, 9, 10, 11, 12, 13, 14, 15] + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [37, 38, 39, 40, 41, 42, 43, 44] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["C_CUSTKEY", "C_NAME", "C_ADDRESS", "C_NATIONKEY", "C_PHONE", "C_ACCTBAL", "C_MKTSEGMENT", "C_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["CUSTOMER"] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["O_ORDERKEY", "O_CUSTKEY", "O_ORDERSTATUS", "O_TOTALPRICE", "O_ORDERDATE", "O_ORDERPRIORITY", "O_CLERK", "O_SHIPPRIORITY", "O_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["ORDERS"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["L_ORDERKEY", "L_PARTKEY", "L_SUPPKEY", "L_LINENUMBER", "L_QUANTITY", "L_EXTENDEDPRICE", "L_DISCOUNT", "L_TAX", "L_RETURNFLAG", "L_LINESTATUS", "L_SHIPDATE", "L_COMMITDATE", "L_RECEIPTDATE", "L_SHIPINSTRUCT", "L_SHIPMODE", "L_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["LINEITEM"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"], + "struct": { + "types": [{ + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["NATION"] + } + } + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 9 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 17 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 12 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "1993-10-01" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 12 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "1994-01-01" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 25 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "R" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 33 + } + }, + "rootReference": { + } + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 34 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "rootReference": { + } + } + }, { + "scalarFunction": { + "functionReference": 4, + "outputType": { + "decimal": { + "scale": 4, + "precision": 19, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 5, + "outputType": { + "decimal": { + "scale": 2, + "precision": 16, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "cast": { + "type": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "i32": 1 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 23 + } + }, + "rootReference": { + } + } + } + }] + } + } + }] + } + }] + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": { + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 6, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 4, + "precision": 19, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": { + } + } + }] + } + }, + "sorts": [{ + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_DESC_NULLS_FIRST" + }] + } + }, + "count": "20" + } + }, + "names": ["C_CUSTKEY", "C_NAME", "REVENUE", "C_ACCTBAL", "N_NAME", "C_ADDRESS", "C_PHONE", "C_COMMENT"] + } + }] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_11_plan.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_11_plan.json new file mode 100644 index 000000000000..d79b065403d5 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_11_plan.json @@ -0,0 +1,872 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, { + "extensionUriAnchor": 3, + "uri": "/functions_arithmetic_decimal.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "name": "and:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "equal:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 2, + "name": "multiply:dec_dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 3, + "name": "sum:dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 4, + "name": "gt:any_any" + } + }], + "relations": [{ + "root": { + "input": { + "sort": { + "common": { + "direct": { + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [16, 17] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["PS_PARTKEY", "PS_SUPPKEY", "PS_AVAILQTY", "PS_SUPPLYCOST", "PS_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["PARTSUPP"] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["S_SUPPKEY", "S_NAME", "S_ADDRESS", "S_NATIONKEY", "S_PHONE", "S_ACCTBAL", "S_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["SUPPLIER"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"], + "struct": { + "types": [{ + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["NATION"] + } + } + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 12 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 13 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "JAPAN" + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "decimal": { + "scale": 2, + "precision": 19, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "decimal": { + "precision": 19, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + }] + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 3, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 2, + "precision": 19, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "condition": { + "scalarFunction": { + "functionReference": 4, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "subquery": { + "scalar": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [1] + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [16] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["PS_PARTKEY", "PS_SUPPKEY", "PS_AVAILQTY", "PS_SUPPLYCOST", "PS_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["PARTSUPP"] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["S_SUPPKEY", "S_NAME", "S_ADDRESS", "S_NATIONKEY", "S_PHONE", "S_ACCTBAL", "S_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["SUPPLIER"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"], + "struct": { + "types": [{ + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["NATION"] + } + } + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 12 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 13 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "JAPAN" + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "scalarFunction": { + "functionReference": 2, + "outputType": { + "decimal": { + "scale": 2, + "precision": 19, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "decimal": { + "precision": 19, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + }] + } + }, + "groupings": [{ + }], + "measures": [{ + "measure": { + "functionReference": 3, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 2, + "precision": 19, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "expressions": [{ + "scalarFunction": { + "functionReference": 2, + "outputType": { + "decimal": { + "scale": 12, + "precision": 19, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "decimal": { + "value": "QEIPAAAAAAAAAAAAAAAAAA==", + "precision": 11, + "scale": 10 + } + } + } + }] + } + }] + } + } + } + } + } + }] + } + } + } + }, + "sorts": [{ + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_DESC_NULLS_FIRST" + }] + } + }, + "names": ["PS_PARTKEY", "value"] + } + }] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_12_plan.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_12_plan.json new file mode 100644 index 000000000000..db3100052704 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_12_plan.json @@ -0,0 +1,794 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 4, + "uri": "/functions_arithmetic.yaml" + }, { + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, { + "extensionUriAnchor": 3, + "uri": "/functions_datetime.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "name": "and:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "equal:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 2, + "name": "or:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 3, + "name": "lt:date_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 4, + "name": "gte:date_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 5, + "name": "not_equal:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 6, + "name": "sum:i32" + } + }], + "relations": [{ + "root": { + "input": { + "sort": { + "common": { + "direct": { + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [25, 26, 27] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["O_ORDERKEY", "O_CUSTKEY", "O_ORDERSTATUS", "O_TOTALPRICE", "O_ORDERDATE", "O_ORDERPRIORITY", "O_CLERK", "O_SHIPPRIORITY", "O_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["ORDERS"] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["L_ORDERKEY", "L_PARTKEY", "L_SUPPKEY", "L_LINENUMBER", "L_QUANTITY", "L_EXTENDEDPRICE", "L_DISCOUNT", "L_TAX", "L_RETURNFLAG", "L_LINESTATUS", "L_SHIPDATE", "L_COMMITDATE", "L_RECEIPTDATE", "L_SHIPINSTRUCT", "L_SHIPMODE", "L_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["LINEITEM"] + } + } + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 9 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 23 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "MAIL" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 23 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "SHIP" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 20 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 21 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 19 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 20 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 4, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 21 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "1994-01-01" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 21 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "1995-01-01" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 23 + } + }, + "rootReference": { + } + } + }, { + "ifThen": { + "ifs": [{ + "if": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "1-URGENT" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "2-HIGH" + } + } + }] + } + } + }] + } + }, + "then": { + "literal": { + "i32": 1 + } + } + }], + "else": { + "literal": { + "i32": 0 + } + } + } + }, { + "ifThen": { + "ifs": [{ + "if": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 5, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "1-URGENT" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 5, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "2-HIGH" + } + } + }] + } + } + }] + } + }, + "then": { + "literal": { + "i32": 1 + } + } + }], + "else": { + "literal": { + "i32": 0 + } + } + } + }] + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 6, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i32": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }] + } + }, { + "measure": { + "functionReference": 6, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i32": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "sorts": [{ + "expr": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }] + } + }, + "names": ["L_SHIPMODE", "HIGH_LINE_COUNT", "LOW_LINE_COUNT"] + } + }] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_13_plan.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_13_plan.json new file mode 100644 index 000000000000..19b80b0aac73 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_13_plan.json @@ -0,0 +1,459 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 4, + "uri": "/functions_aggregate_generic.yaml" + }, { + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, { + "extensionUriAnchor": 3, + "uri": "/functions_string.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "name": "and:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "equal:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 2, + "name": "not:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 3, + "name": "like:str_str" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 4, + "name": "count:any" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 5, + "name": "count:" + } + }], + "relations": [{ + "root": { + "input": { + "sort": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [2, 3] + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [2] + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [17, 18] + } + }, + "input": { + "join": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["C_CUSTKEY", "C_NAME", "C_ADDRESS", "C_NATIONKEY", "C_PHONE", "C_ACCTBAL", "C_MKTSEGMENT", "C_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["CUSTOMER"] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["O_ORDERKEY", "O_CUSTKEY", "O_ORDERSTATUS", "O_TOTALPRICE", "O_ORDERDATE", "O_ORDERPRIORITY", "O_CLERK", "O_SHIPPRIORITY", "O_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["ORDERS"] + } + } + }, + "expression": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 9 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 16 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "%special%requests%" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }] + } + } + }] + } + }, + "type": "JOIN_TYPE_LEFT" + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": { + } + } + }] + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 4, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }] + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 5, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL" + } + }] + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }] + } + }, + "sorts": [{ + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_DESC_NULLS_FIRST" + }, { + "expr": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_DESC_NULLS_FIRST" + }] + } + }, + "names": ["C_COUNT", "CUSTDIST"] + } + }] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_14_plan.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_14_plan.json new file mode 100644 index 000000000000..81daf41caa81 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_14_plan.json @@ -0,0 +1,686 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, { + "extensionUriAnchor": 4, + "uri": "/functions_string.yaml" + }, { + "extensionUriAnchor": 5, + "uri": "/functions_arithmetic_decimal.yaml" + }, { + "extensionUriAnchor": 3, + "uri": "/functions_datetime.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "name": "and:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "equal:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 2, + "name": "gte:date_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 3, + "name": "lt:date_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 4, + "name": "like:str_str" + } + }, { + "extensionFunction": { + "extensionUriReference": 5, + "functionAnchor": 5, + "name": "multiply:dec_dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 5, + "functionAnchor": 6, + "name": "subtract:dec_dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 5, + "functionAnchor": 7, + "name": "sum:dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 5, + "functionAnchor": 8, + "name": "divide:dec_dec" + } + }], + "relations": [{ + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [2] + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [25, 26] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["L_ORDERKEY", "L_PARTKEY", "L_SUPPKEY", "L_LINENUMBER", "L_QUANTITY", "L_EXTENDEDPRICE", "L_DISCOUNT", "L_TAX", "L_RETURNFLAG", "L_LINESTATUS", "L_SHIPDATE", "L_COMMITDATE", "L_RECEIPTDATE", "L_SHIPINSTRUCT", "L_SHIPMODE", "L_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["LINEITEM"] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["P_PARTKEY", "P_NAME", "P_MFGR", "P_BRAND", "P_TYPE", "P_SIZE", "P_CONTAINER", "P_RETAILPRICE", "P_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["PART"] + } + } + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 16 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "date": 9374 + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "1995-10-01" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "ifThen": { + "ifs": [{ + "if": { + "scalarFunction": { + "functionReference": 4, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 20 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "PROMO%" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + }, + "then": { + "scalarFunction": { + "functionReference": 5, + "outputType": { + "decimal": { + "scale": 4, + "precision": 19, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 6, + "outputType": { + "decimal": { + "scale": 2, + "precision": 16, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "cast": { + "type": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "i32": 1 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": { + } + } + } + }] + } + } + }] + } + } + }], + "else": { + "literal": { + "decimal": { + "value": "AAAAAAAAAAAAAAAAAAAAAA==", + "precision": 19, + "scale": 4 + } + } + } + } + }, { + "scalarFunction": { + "functionReference": 5, + "outputType": { + "decimal": { + "scale": 4, + "precision": 19, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 6, + "outputType": { + "decimal": { + "scale": 2, + "precision": 16, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "cast": { + "type": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "i32": 1 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": { + } + } + } + }] + } + } + }] + } + }] + } + }, + "groupings": [{ + }], + "measures": [{ + "measure": { + "functionReference": 7, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 4, + "precision": 19, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }] + } + }, { + "measure": { + "functionReference": 7, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 4, + "precision": 19, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "expressions": [{ + "scalarFunction": { + "functionReference": 8, + "outputType": { + "decimal": { + "scale": 2, + "precision": 19, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 5, + "outputType": { + "decimal": { + "scale": 6, + "precision": 19, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "literal": { + "decimal": { + "value": "ECcAAAAAAAAAAAAAAAAAAA==", + "precision": 5, + "scale": 2 + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "names": ["PROMO_REVENUE"] + } + }] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_15_plan.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_15_plan.json new file mode 100644 index 000000000000..0967ef424bce --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_15_plan.json @@ -0,0 +1 @@ +{} diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_16_plan.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_16_plan.json new file mode 100644 index 000000000000..bf97fb918571 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_16_plan.json @@ -0,0 +1,872 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 4, + "uri": "/functions_aggregate_generic.yaml" + }, { + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, { + "extensionUriAnchor": 3, + "uri": "/functions_string.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "name": "and:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "equal:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 2, + "name": "not_equal:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 3, + "name": "not:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 4, + "name": "like:str_str" + } + }, { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 5, + "name": "or:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 6, + "name": "count:any" + } + }], + "relations": [{ + "root": { + "input": { + "sort": { + "common": { + "direct": { + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [14, 15, 16, 17] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["PS_PARTKEY", "PS_SUPPKEY", "PS_AVAILQTY", "PS_SUPPLYCOST", "PS_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["PARTSUPP"] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["P_PARTKEY", "P_NAME", "P_MFGR", "P_BRAND", "P_TYPE", "P_SIZE", "P_CONTAINER", "P_RETAILPRICE", "P_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["PART"] + } + } + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "Brand#45" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 4, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 9 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "MEDIUM POLISHED%" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 5, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 49 + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 14 + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 23 + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 45 + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 19 + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 3 + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 36 + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 9 + } + } + }] + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "subquery": { + "inPredicate": { + "needles": [{ + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }], + "haystack": { + "project": { + "common": { + "emit": { + "outputMapping": [7] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["S_SUPPKEY", "S_NAME", "S_ADDRESS", "S_NATIONKEY", "S_PHONE", "S_ACCTBAL", "S_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["SUPPLIER"] + } + } + }, + "condition": { + "scalarFunction": { + "functionReference": 4, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "%Customer%Complaints%" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }] + } + } + } + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 9 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }] + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 6, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "invocation": "AGGREGATION_INVOCATION_DISTINCT", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "sorts": [{ + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_DESC_NULLS_FIRST" + }, { + "expr": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }, { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }, { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }] + } + }, + "names": ["P_BRAND", "P_TYPE", "P_SIZE", "SUPPLIER_CNT"] + } + }] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_17_plan.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_17_plan.json new file mode 100644 index 000000000000..3135e68fd527 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_17_plan.json @@ -0,0 +1,690 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, { + "extensionUriAnchor": 3, + "uri": "/functions_arithmetic_decimal.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "name": "and:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "equal:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 2, + "name": "lt:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 3, + "name": "avg:dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 4, + "name": "multiply:dec_dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 5, + "name": "sum:dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 6, + "name": "divide:dec_dec" + } + }], + "relations": [{ + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [1] + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [25] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["L_ORDERKEY", "L_PARTKEY", "L_SUPPKEY", "L_LINENUMBER", "L_QUANTITY", "L_EXTENDEDPRICE", "L_DISCOUNT", "L_TAX", "L_RETURNFLAG", "L_LINESTATUS", "L_SHIPDATE", "L_COMMITDATE", "L_RECEIPTDATE", "L_SHIPINSTRUCT", "L_SHIPMODE", "L_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["LINEITEM"] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["P_PARTKEY", "P_NAME", "P_MFGR", "P_BRAND", "P_TYPE", "P_SIZE", "P_CONTAINER", "P_RETAILPRICE", "P_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["PART"] + } + } + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 16 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 19 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "Brand#23" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "MED BOX" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "subquery": { + "scalar": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [1] + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [16] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["L_ORDERKEY", "L_PARTKEY", "L_SUPPKEY", "L_LINENUMBER", "L_QUANTITY", "L_EXTENDEDPRICE", "L_DISCOUNT", "L_TAX", "L_RETURNFLAG", "L_LINESTATUS", "L_SHIPDATE", "L_COMMITDATE", "L_RECEIPTDATE", "L_SHIPINSTRUCT", "L_SHIPMODE", "L_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["LINEITEM"] + } + } + }, + "condition": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 16 + } + }, + "outerReference": { + "stepsOut": 1 + } + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + }] + } + }, + "groupings": [{ + }], + "measures": [{ + "measure": { + "functionReference": 3, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "expressions": [{ + "scalarFunction": { + "functionReference": 4, + "outputType": { + "decimal": { + "scale": 3, + "precision": 17, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "literal": { + "decimal": { + "value": "AgAAAAAAAAAAAAAAAAAAAA==", + "precision": 2, + "scale": 1 + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + } + } + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + }] + } + }, + "groupings": [{ + }], + "measures": [{ + "measure": { + "functionReference": 5, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "expressions": [{ + "scalarFunction": { + "functionReference": 6, + "outputType": { + "decimal": { + "scale": 5, + "precision": 19, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "decimal": { + "value": "RgAAAAAAAAAAAAAAAAAAAA==", + "precision": 2, + "scale": 1 + } + } + } + }] + } + }] + } + }, + "names": ["AVG_YEARLY"] + } + }] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_18_plan.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_18_plan.json new file mode 100644 index 000000000000..7f0ff438db78 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_18_plan.json @@ -0,0 +1,796 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_arithmetic_decimal.yaml" + }, { + "extensionUriAnchor": 3, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "name": "and:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "sum:dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 2, + "name": "gt:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 3, + "name": "equal:any_any" + } + }], + "relations": [{ + "root": { + "input": { + "fetch": { + "common": { + "direct": { + } + }, + "input": { + "sort": { + "common": { + "direct": { + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [33, 34, 35, 36, 37, 38] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["C_CUSTKEY", "C_NAME", "C_ADDRESS", "C_NATIONKEY", "C_PHONE", "C_ACCTBAL", "C_MKTSEGMENT", "C_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["CUSTOMER"] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["O_ORDERKEY", "O_CUSTKEY", "O_ORDERSTATUS", "O_TOTALPRICE", "O_ORDERDATE", "O_ORDERPRIORITY", "O_CLERK", "O_SHIPPRIORITY", "O_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["ORDERS"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["L_ORDERKEY", "L_PARTKEY", "L_SUPPKEY", "L_LINENUMBER", "L_QUANTITY", "L_EXTENDEDPRICE", "L_DISCOUNT", "L_TAX", "L_RETURNFLAG", "L_LINESTATUS", "L_SHIPDATE", "L_COMMITDATE", "L_RECEIPTDATE", "L_SHIPINSTRUCT", "L_SHIPMODE", "L_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["LINEITEM"] + } + } + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "subquery": { + "inPredicate": { + "needles": [{ + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": { + } + } + }], + "haystack": { + "project": { + "common": { + "emit": { + "outputMapping": [2] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [16, 17] + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["L_ORDERKEY", "L_PARTKEY", "L_SUPPKEY", "L_LINENUMBER", "L_QUANTITY", "L_EXTENDEDPRICE", "L_DISCOUNT", "L_TAX", "L_RETURNFLAG", "L_LINESTATUS", "L_SHIPDATE", "L_COMMITDATE", "L_RECEIPTDATE", "L_SHIPINSTRUCT", "L_SHIPMODE", "L_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["LINEITEM"] + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + }] + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 1, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "condition": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "literal": { + "i32": 300 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }] + } + } + } + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 9 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 17 + } + }, + "rootReference": { + } + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 12 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 11 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 21 + } + }, + "rootReference": { + } + } + }] + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 1, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "sorts": [{ + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_DESC_NULLS_FIRST" + }, { + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }] + } + }, + "count": "100" + } + }, + "names": ["C_NAME", "C_CUSTKEY", "O_ORDERKEY", "O_ORDERDATE", "O_TOTALPRICE", "EXPR$5"] + } + }] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_19_plan.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_19_plan.json new file mode 100644 index 000000000000..8ea0bc881c55 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_19_plan.json @@ -0,0 +1,1956 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 3, + "uri": "/functions_arithmetic.yaml" + }, { + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, { + "extensionUriAnchor": 4, + "uri": "/functions_arithmetic_decimal.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "name": "or:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "and:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 2, + "name": "equal:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 3, + "name": "gte:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 4, + "name": "lte:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 5, + "name": "add:i32_i32" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 6, + "name": "multiply:dec_dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 7, + "name": "subtract:dec_dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 8, + "name": "sum:dec" + } + }], + "relations": [{ + "root": { + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [25] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["L_ORDERKEY", "L_PARTKEY", "L_SUPPKEY", "L_LINENUMBER", "L_QUANTITY", "L_EXTENDEDPRICE", "L_DISCOUNT", "L_TAX", "L_RETURNFLAG", "L_LINESTATUS", "L_SHIPDATE", "L_COMMITDATE", "L_RECEIPTDATE", "L_SHIPINSTRUCT", "L_SHIPMODE", "L_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["LINEITEM"] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["P_PARTKEY", "P_NAME", "P_MFGR", "P_BRAND", "P_TYPE", "P_SIZE", "P_CONTAINER", "P_RETAILPRICE", "P_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["PART"] + } + } + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 16 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 19 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "Brand#12" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "SM CASE" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "SM BOX" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "SM PACK" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "SM PKG" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "i32": 1 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 4, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "scalarFunction": { + "functionReference": 5, + "outputType": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "literal": { + "i32": 1 + } + } + }, { + "value": { + "literal": { + "i32": 10 + } + } + }] + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 21 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 1 + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 4, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 21 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 5 + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 14 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "AIR" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 14 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "AIR REG" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 13 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "DELIVER IN PERSON" + } + } + }] + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 16 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 19 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "Brand#23" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "MED BAG" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "MED BOX" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "MED PKG" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "MED PACK" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "i32": 10 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 4, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "scalarFunction": { + "functionReference": 5, + "outputType": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "literal": { + "i32": 10 + } + } + }, { + "value": { + "literal": { + "i32": 10 + } + } + }] + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 21 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 1 + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 4, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 21 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 10 + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 14 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "AIR" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 14 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "AIR REG" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 13 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "DELIVER IN PERSON" + } + } + }] + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 16 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 19 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "Brand#34" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "LG CASE" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "LG BOX" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "LG PACK" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 22 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "LG PKG" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "i32": 20 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 4, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "scalarFunction": { + "functionReference": 5, + "outputType": { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "literal": { + "i32": 20 + } + } + }, { + "value": { + "literal": { + "i32": 10 + } + } + }] + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 21 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 1 + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 4, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 21 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 15 + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 14 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "AIR" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 14 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "AIR REG" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 13 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "DELIVER IN PERSON" + } + } + }] + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "scalarFunction": { + "functionReference": 6, + "outputType": { + "decimal": { + "scale": 4, + "precision": 19, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 7, + "outputType": { + "decimal": { + "scale": 2, + "precision": 16, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "cast": { + "type": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "i32": 1 + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 6 + } + }, + "rootReference": { + } + } + } + }] + } + } + }] + } + }] + } + }, + "groupings": [{ + }], + "measures": [{ + "measure": { + "functionReference": 8, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 4, + "precision": 19, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "names": ["REVENUE"] + } + }] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_20_plan.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_20_plan.json new file mode 100644 index 000000000000..a616e3fc066d --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_20_plan.json @@ -0,0 +1,932 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_string.yaml" + }, { + "extensionUriAnchor": 5, + "uri": "/functions_arithmetic_decimal.yaml" + }, { + "extensionUriAnchor": 4, + "uri": "/functions_datetime.yaml" + }, { + "extensionUriAnchor": 3, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "name": "and:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "like:str_str" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 2, + "name": "gt:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 3, + "name": "equal:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 4, + "name": "gte:date_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 5, + "name": "lt:date_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 5, + "functionAnchor": 6, + "name": "sum:dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 5, + "functionAnchor": 7, + "name": "multiply:dec_dec" + } + }], + "relations": [{ + "root": { + "input": { + "sort": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [11, 12] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["S_SUPPKEY", "S_NAME", "S_ADDRESS", "S_NATIONKEY", "S_PHONE", "S_ACCTBAL", "S_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["SUPPLIER"] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"], + "struct": { + "types": [{ + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["NATION"] + } + } + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "subquery": { + "inPredicate": { + "needles": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }], + "haystack": { + "project": { + "common": { + "emit": { + "outputMapping": [5] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["PS_PARTKEY", "PS_SUPPKEY", "PS_AVAILQTY", "PS_SUPPLYCOST", "PS_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["PARTSUPP"] + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "subquery": { + "inPredicate": { + "needles": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }], + "haystack": { + "project": { + "common": { + "emit": { + "outputMapping": [9] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["P_PARTKEY", "P_NAME", "P_MFGR", "P_BRAND", "P_TYPE", "P_SIZE", "P_CONTAINER", "P_RETAILPRICE", "P_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["PART"] + } + } + }, + "condition": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "forest%" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }] + } + } + } + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "cast": { + "type": { + "decimal": { + "precision": 19, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "input": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }, { + "value": { + "subquery": { + "scalar": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [1] + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [16] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["L_ORDERKEY", "L_PARTKEY", "L_SUPPKEY", "L_LINENUMBER", "L_QUANTITY", "L_EXTENDEDPRICE", "L_DISCOUNT", "L_TAX", "L_RETURNFLAG", "L_LINESTATUS", "L_SHIPDATE", "L_COMMITDATE", "L_RECEIPTDATE", "L_SHIPINSTRUCT", "L_SHIPMODE", "L_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["LINEITEM"] + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "outerReference": { + "stepsOut": 1 + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "outerReference": { + "stepsOut": 1 + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 4, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "1994-01-01" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 5, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 10 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "cast": { + "type": { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "1995-01-01" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + }] + } + }, + "groupings": [{ + }], + "measures": [{ + "measure": { + "functionReference": 6, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "expressions": [{ + "scalarFunction": { + "functionReference": 7, + "outputType": { + "decimal": { + "scale": 3, + "precision": 17, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "literal": { + "decimal": { + "value": "BQAAAAAAAAAAAAAAAAAAAA==", + "precision": 2, + "scale": 1 + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + } + } + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }] + } + } + } + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 8 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "CANADA" + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + }] + } + }, + "sorts": [{ + "expr": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }] + } + }, + "names": ["S_NAME", "S_ADDRESS"] + } + }] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_21_plan.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_21_plan.json new file mode 100644 index 000000000000..c3d4fc3bcb87 --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_21_plan.json @@ -0,0 +1,1050 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 4, + "uri": "/functions_aggregate_generic.yaml" + }, { + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, { + "extensionUriAnchor": 3, + "uri": "/functions_datetime.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "name": "and:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 1, + "name": "equal:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 2, + "name": "gt:date_date" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 3, + "name": "not_equal:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 4, + "name": "not:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 5, + "name": "count:" + } + }], + "relations": [{ + "root": { + "input": { + "fetch": { + "common": { + "direct": { + } + }, + "input": { + "sort": { + "common": { + "direct": { + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [36] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "cross": { + "common": { + "direct": { + } + }, + "left": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["S_SUPPKEY", "S_NAME", "S_ADDRESS", "S_NATIONKEY", "S_PHONE", "S_ACCTBAL", "S_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["SUPPLIER"] + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["L_ORDERKEY", "L_PARTKEY", "L_SUPPKEY", "L_LINENUMBER", "L_QUANTITY", "L_EXTENDEDPRICE", "L_DISCOUNT", "L_TAX", "L_RETURNFLAG", "L_LINESTATUS", "L_SHIPDATE", "L_COMMITDATE", "L_RECEIPTDATE", "L_SHIPINSTRUCT", "L_SHIPMODE", "L_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["LINEITEM"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["O_ORDERKEY", "O_CUSTKEY", "O_ORDERSTATUS", "O_TOTALPRICE", "O_ORDERDATE", "O_ORDERPRIORITY", "O_CLERK", "O_SHIPPRIORITY", "O_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["ORDERS"] + } + } + } + } + }, + "right": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["N_NATIONKEY", "N_NAME", "N_REGIONKEY", "N_COMMENT"], + "struct": { + "types": [{ + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["NATION"] + } + } + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 9 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 23 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 25 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "F" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 19 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 18 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "subquery": { + "setPredicate": { + "predicateOp": "PREDICATE_OP_EXISTS", + "tuples": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["L_ORDERKEY", "L_PARTKEY", "L_SUPPKEY", "L_LINENUMBER", "L_QUANTITY", "L_EXTENDEDPRICE", "L_DISCOUNT", "L_TAX", "L_RETURNFLAG", "L_LINESTATUS", "L_SHIPDATE", "L_COMMITDATE", "L_RECEIPTDATE", "L_SHIPINSTRUCT", "L_SHIPMODE", "L_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["LINEITEM"] + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "outerReference": { + "stepsOut": 1 + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 9 + } + }, + "outerReference": { + "stepsOut": 1 + } + } + } + }] + } + } + }] + } + } + } + } + } + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 4, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "subquery": { + "setPredicate": { + "predicateOp": "PREDICATE_OP_EXISTS", + "tuples": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["L_ORDERKEY", "L_PARTKEY", "L_SUPPKEY", "L_LINENUMBER", "L_QUANTITY", "L_EXTENDEDPRICE", "L_DISCOUNT", "L_TAX", "L_RETURNFLAG", "L_LINESTATUS", "L_SHIPDATE", "L_COMMITDATE", "L_RECEIPTDATE", "L_SHIPINSTRUCT", "L_SHIPMODE", "L_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["LINEITEM"] + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 7 + } + }, + "outerReference": { + "stepsOut": 1 + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 2 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 9 + } + }, + "outerReference": { + "stepsOut": 1 + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 12 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 11 + } + }, + "rootReference": { + } + } + } + }] + } + } + }] + } + } + } + } + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 3 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + "field": 32 + } + }, + "rootReference": { + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 33 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "string": "SAUDI ARABIA" + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }] + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 5, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL" + } + }] + } + }, + "sorts": [{ + "expr": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_DESC_NULLS_FIRST" + }, { + "expr": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }] + } + }, + "count": "100" + } + }, + "names": ["S_NAME", "NUMWAIT"] + } + }] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_22_plan.json b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_22_plan.json new file mode 100644 index 000000000000..fcd61b23ae2d --- /dev/null +++ b/datafusion/substrait/tests/testdata/tpch_substrait_plans/query_22_plan.json @@ -0,0 +1,1510 @@ +{ + "extensionUris": [{ + "extensionUriAnchor": 5, + "uri": "/functions_aggregate_generic.yaml" + }, { + "extensionUriAnchor": 1, + "uri": "/functions_boolean.yaml" + }, { + "extensionUriAnchor": 3, + "uri": "/functions_string.yaml" + }, { + "extensionUriAnchor": 4, + "uri": "/functions_arithmetic_decimal.yaml" + }, { + "extensionUriAnchor": 2, + "uri": "/functions_comparison.yaml" + }], + "extensions": [{ + "extensionFunction": { + "extensionUriReference": 1, + "name": "and:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 1, + "name": "or:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 2, + "name": "equal:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 3, + "functionAnchor": 3, + "name": "substring:str_i32_i32" + } + }, { + "extensionFunction": { + "extensionUriReference": 2, + "functionAnchor": 4, + "name": "gt:any_any" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 5, + "name": "avg:dec" + } + }, { + "extensionFunction": { + "extensionUriReference": 1, + "functionAnchor": 6, + "name": "not:bool" + } + }, { + "extensionFunction": { + "extensionUriReference": 5, + "functionAnchor": 7, + "name": "count:" + } + }, { + "extensionFunction": { + "extensionUriReference": 4, + "functionAnchor": 8, + "name": "sum:dec" + } + }], + "relations": [{ + "root": { + "input": { + "sort": { + "common": { + "direct": { + } + }, + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [8, 9] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["C_CUSTKEY", "C_NAME", "C_ADDRESS", "C_NATIONKEY", "C_PHONE", "C_ACCTBAL", "C_MKTSEGMENT", "C_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["CUSTOMER"] + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 1 + } + } + }, { + "value": { + "literal": { + "i32": 2 + } + } + }] + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "13" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 1 + } + } + }, { + "value": { + "literal": { + "i32": 2 + } + } + }] + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "31" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 1 + } + } + }, { + "value": { + "literal": { + "i32": 2 + } + } + }] + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "23" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 1 + } + } + }, { + "value": { + "literal": { + "i32": 2 + } + } + }] + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "29" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 1 + } + } + }, { + "value": { + "literal": { + "i32": 2 + } + } + }] + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "30" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 1 + } + } + }, { + "value": { + "literal": { + "i32": 2 + } + } + }] + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "18" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 1 + } + } + }, { + "value": { + "literal": { + "i32": 2 + } + } + }] + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "17" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 4, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "subquery": { + "scalar": { + "input": { + "aggregate": { + "common": { + "direct": { + } + }, + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [8] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["C_CUSTKEY", "C_NAME", "C_ADDRESS", "C_NATIONKEY", "C_PHONE", "C_ACCTBAL", "C_MKTSEGMENT", "C_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["CUSTOMER"] + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 4, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "decimal": { + "value": "AAAAAAAAAAAAAAAAAAAAAA==", + "precision": 3, + "scale": 2 + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 1 + } + } + }, { + "value": { + "literal": { + "i32": 2 + } + } + }] + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "13" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 1 + } + } + }, { + "value": { + "literal": { + "i32": 2 + } + } + }] + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "31" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 1 + } + } + }, { + "value": { + "literal": { + "i32": 2 + } + } + }] + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "23" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 1 + } + } + }, { + "value": { + "literal": { + "i32": 2 + } + } + }] + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "29" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 1 + } + } + }, { + "value": { + "literal": { + "i32": 2 + } + } + }] + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "30" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 1 + } + } + }, { + "value": { + "literal": { + "i32": 2 + } + } + }] + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "18" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 1 + } + } + }, { + "value": { + "literal": { + "i32": 2 + } + } + }] + } + } + }, { + "value": { + "cast": { + "type": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "input": { + "literal": { + "fixedChar": "17" + } + }, + "failureBehavior": "FAILURE_BEHAVIOR_THROW_EXCEPTION" + } + } + }] + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + }] + } + }, + "groupings": [{ + }], + "measures": [{ + "measure": { + "functionReference": 5, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + } + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 6, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "subquery": { + "setPredicate": { + "predicateOp": "PREDICATE_OP_EXISTS", + "tuples": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["O_ORDERKEY", "O_CUSTKEY", "O_ORDERSTATUS", "O_TOTALPRICE", "O_ORDERDATE", "O_ORDERPRIORITY", "O_CLERK", "O_SHIPPRIORITY", "O_COMMENT"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "date": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "i32": { + "nullability": "NULLABILITY_REQUIRED" + } + }, { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": ["ORDERS"] + } + } + }, + "condition": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "selection": { + "directReference": { + "structField": { + } + }, + "outerReference": { + "stepsOut": 1 + } + } + } + }] + } + } + } + } + } + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "scalarFunction": { + "functionReference": 3, + "outputType": { + "string": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 4 + } + }, + "rootReference": { + } + } + } + }, { + "value": { + "literal": { + "i32": 1 + } + } + }, { + "value": { + "literal": { + "i32": 2 + } + } + }] + } + }, { + "selection": { + "directReference": { + "structField": { + "field": 5 + } + }, + "rootReference": { + } + } + }] + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 7, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "i64": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL" + } + }, { + "measure": { + "functionReference": 8, + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "outputType": { + "decimal": { + "scale": 2, + "precision": 15, + "nullability": "NULLABILITY_NULLABLE" + } + }, + "invocation": "AGGREGATION_INVOCATION_ALL", + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + }, + "rootReference": { + } + } + } + }] + } + }] + } + }, + "sorts": [{ + "expr": { + "selection": { + "directReference": { + "structField": { + } + }, + "rootReference": { + } + } + }, + "direction": "SORT_DIRECTION_ASC_NULLS_LAST" + }] + } + }, + "names": ["CNTRYCODE", "NUMCUST", "TOTACCTBAL"] + } + }] +} \ No newline at end of file diff --git a/datafusion/substrait/tests/utils.rs b/datafusion/substrait/tests/utils.rs new file mode 100644 index 000000000000..00cbfb0c412c --- /dev/null +++ b/datafusion/substrait/tests/utils.rs @@ -0,0 +1,492 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[cfg(test)] +pub mod test { + use datafusion::catalog_common::TableReference; + use datafusion::common::{substrait_datafusion_err, substrait_err}; + use datafusion::datasource::empty::EmptyTable; + use datafusion::datasource::TableProvider; + use datafusion::error::Result; + use datafusion::prelude::SessionContext; + use datafusion_substrait::extensions::Extensions; + use datafusion_substrait::logical_plan::consumer::from_substrait_named_struct; + use std::collections::HashMap; + use std::fs::File; + use std::io::BufReader; + use std::sync::Arc; + use substrait::proto::exchange_rel::ExchangeKind; + use substrait::proto::expand_rel::expand_field::FieldType; + use substrait::proto::expression::nested::NestedType; + use substrait::proto::expression::subquery::SubqueryType; + use substrait::proto::expression::RexType; + use substrait::proto::function_argument::ArgType; + use substrait::proto::read_rel::{NamedTable, ReadType}; + use substrait::proto::rel::RelType; + use substrait::proto::{Expression, FunctionArgument, Plan, ReadRel, Rel}; + + pub fn read_json(path: &str) -> Plan { + serde_json::from_reader::<_, Plan>(BufReader::new( + File::open(path).expect("file not found"), + )) + .expect("failed to parse json") + } + + pub fn add_plan_schemas_to_ctx( + ctx: SessionContext, + plan: &Plan, + ) -> Result { + let schemas = TestSchemaCollector::collect_schemas(plan)?; + let mut schema_map: HashMap> = + HashMap::new(); + for (table_reference, table) in schemas.into_iter() { + let schema = table.schema(); + if let Some(existing_table) = + schema_map.insert(table_reference.clone(), table) + { + if existing_table.schema() != schema { + return substrait_err!( + "Substrait plan contained the same table {} with different schemas.\nSchema 1: {}\nSchema 2: {}", + table_reference, existing_table.schema(), schema); + } + } + } + for (table_reference, table) in schema_map.into_iter() { + ctx.register_table(table_reference, table)?; + } + Ok(ctx) + } + + pub struct TestSchemaCollector { + schemas: Vec<(TableReference, Arc)>, + } + + impl TestSchemaCollector { + fn new() -> Self { + TestSchemaCollector { + schemas: Vec::new(), + } + } + + fn collect_schemas( + plan: &Plan, + ) -> Result)>> { + let mut schema_collector = Self::new(); + + for plan_rel in plan.relations.iter() { + let rel_type = plan_rel + .rel_type + .as_ref() + .ok_or(substrait_datafusion_err!("PlanRel must set rel_type"))?; + match rel_type { + substrait::proto::plan_rel::RelType::Rel(r) => { + schema_collector.collect_schemas_from_rel(r)? + } + substrait::proto::plan_rel::RelType::Root(r) => { + let input = r + .input + .as_ref() + .ok_or(substrait_datafusion_err!("RelRoot must set input"))?; + schema_collector.collect_schemas_from_rel(input)? + } + } + } + Ok(schema_collector.schemas) + } + + fn collect_named_table(&mut self, read: &ReadRel, nt: &NamedTable) -> Result<()> { + let table_reference = match nt.names.len() { + 0 => { + panic!("No table name found in NamedTable"); + } + 1 => TableReference::Bare { + table: nt.names[0].clone().into(), + }, + 2 => TableReference::Partial { + schema: nt.names[0].clone().into(), + table: nt.names[1].clone().into(), + }, + _ => TableReference::Full { + catalog: nt.names[0].clone().into(), + schema: nt.names[1].clone().into(), + table: nt.names[2].clone().into(), + }, + }; + + let substrait_schema = + read.base_schema.as_ref().ok_or(substrait_datafusion_err!( + "No base schema found for NamedTable: {}", + table_reference + ))?; + let empty_extensions = Extensions { + functions: Default::default(), + types: Default::default(), + type_variations: Default::default(), + }; + + let df_schema = + from_substrait_named_struct(substrait_schema, &empty_extensions)? + .replace_qualifier(table_reference.clone()); + + let table = EmptyTable::new(df_schema.inner().clone()); + self.schemas.push((table_reference, Arc::new(table))); + Ok(()) + } + + #[allow(deprecated)] + fn collect_schemas_from_rel(&mut self, rel: &Rel) -> Result<()> { + let rel_type = rel + .rel_type + .as_ref() + .ok_or(substrait_datafusion_err!("Rel must set rel_type"))?; + match rel_type { + RelType::Read(r) => { + let read_type = r + .read_type + .as_ref() + .ok_or(substrait_datafusion_err!("Read must set read_type"))?; + match read_type { + // Virtual Tables do not contribute to the schema + ReadType::VirtualTable(_) => (), + ReadType::LocalFiles(_) => todo!(), + ReadType::NamedTable(nt) => self.collect_named_table(r, nt)?, + ReadType::ExtensionTable(_) => todo!(), + } + if let Some(expr) = r.filter.as_ref() { + self.collect_schemas_from_expr(expr)? + }; + if let Some(expr) = r.best_effort_filter.as_ref() { + self.collect_schemas_from_expr(expr)? + }; + } + RelType::Filter(f) => { + self.apply(f.input.as_ref().map(|b| b.as_ref()))?; + for expr in f.condition.iter() { + self.collect_schemas_from_expr(expr)?; + } + } + RelType::Fetch(f) => { + self.apply(f.input.as_ref().map(|b| b.as_ref()))?; + } + RelType::Aggregate(a) => { + self.apply(a.input.as_ref().map(|b| b.as_ref()))?; + for grouping in a.groupings.iter() { + for expr in grouping.grouping_expressions.iter() { + self.collect_schemas_from_expr(expr)? + } + } + for measure in a.measures.iter() { + if let Some(agg_fn) = measure.measure.as_ref() { + for arg in agg_fn.arguments.iter() { + self.collect_schemas_from_arg(arg)? + } + for sort in agg_fn.sorts.iter() { + if let Some(expr) = sort.expr.as_ref() { + self.collect_schemas_from_expr(expr)? + } + } + } + if let Some(expr) = measure.filter.as_ref() { + self.collect_schemas_from_expr(expr)? + } + } + } + RelType::Sort(s) => { + self.apply(s.input.as_ref().map(|b| b.as_ref()))?; + for sort_field in s.sorts.iter() { + if let Some(expr) = sort_field.expr.as_ref() { + self.collect_schemas_from_expr(expr)? + } + } + } + RelType::Join(j) => { + self.apply(j.left.as_ref().map(|b| b.as_ref()))?; + self.apply(j.right.as_ref().map(|b| b.as_ref()))?; + if let Some(expr) = j.expression.as_ref() { + self.collect_schemas_from_expr(expr)?; + } + if let Some(expr) = j.post_join_filter.as_ref() { + self.collect_schemas_from_expr(expr)?; + } + } + RelType::Project(p) => { + self.apply(p.input.as_ref().map(|b| b.as_ref()))? + } + RelType::Set(s) => { + for input in s.inputs.iter() { + self.collect_schemas_from_rel(input)?; + } + } + RelType::ExtensionSingle(s) => { + self.apply(s.input.as_ref().map(|b| b.as_ref()))? + } + + RelType::ExtensionMulti(m) => { + for input in m.inputs.iter() { + self.collect_schemas_from_rel(input)? + } + } + RelType::ExtensionLeaf(_) => {} + RelType::Cross(c) => { + self.apply(c.left.as_ref().map(|b| b.as_ref()))?; + self.apply(c.right.as_ref().map(|b| b.as_ref()))?; + } + // RelType::Reference(_) => {} + // RelType::Write(_) => {} + // RelType::Ddl(_) => {} + RelType::HashJoin(j) => { + self.apply(j.left.as_ref().map(|b| b.as_ref()))?; + self.apply(j.right.as_ref().map(|b| b.as_ref()))?; + if let Some(expr) = j.post_join_filter.as_ref() { + self.collect_schemas_from_expr(expr)?; + } + } + RelType::MergeJoin(j) => { + self.apply(j.left.as_ref().map(|b| b.as_ref()))?; + self.apply(j.right.as_ref().map(|b| b.as_ref()))?; + if let Some(expr) = j.post_join_filter.as_ref() { + self.collect_schemas_from_expr(expr)?; + } + } + RelType::NestedLoopJoin(j) => { + self.apply(j.left.as_ref().map(|b| b.as_ref()))?; + self.apply(j.right.as_ref().map(|b| b.as_ref()))?; + if let Some(expr) = j.expression.as_ref() { + self.collect_schemas_from_expr(expr)?; + } + } + RelType::Window(w) => { + self.apply(w.input.as_ref().map(|b| b.as_ref()))?; + for wf in w.window_functions.iter() { + for arg in wf.arguments.iter() { + self.collect_schemas_from_arg(arg)?; + } + } + for expr in w.partition_expressions.iter() { + self.collect_schemas_from_expr(expr)?; + } + for sort_field in w.sorts.iter() { + if let Some(expr) = sort_field.expr.as_ref() { + self.collect_schemas_from_expr(expr)?; + } + } + } + RelType::Exchange(e) => { + self.apply(e.input.as_ref().map(|b| b.as_ref()))?; + let exchange_kind = e.exchange_kind.as_ref().ok_or( + substrait_datafusion_err!("Exhange must set exchange_kind"), + )?; + match exchange_kind { + ExchangeKind::ScatterByFields(_) => {} + ExchangeKind::SingleTarget(st) => { + if let Some(expr) = st.expression.as_ref() { + self.collect_schemas_from_expr(expr)? + } + } + ExchangeKind::MultiTarget(mt) => { + if let Some(expr) = mt.expression.as_ref() { + self.collect_schemas_from_expr(expr)? + } + } + ExchangeKind::RoundRobin(_) => {} + ExchangeKind::Broadcast(_) => {} + } + } + RelType::Expand(e) => { + self.apply(e.input.as_ref().map(|b| b.as_ref()))?; + for expand_field in e.fields.iter() { + let expand_type = expand_field.field_type.as_ref().ok_or( + substrait_datafusion_err!("ExpandField must set field_type"), + )?; + match expand_type { + FieldType::SwitchingField(sf) => { + for expr in sf.duplicates.iter() { + self.collect_schemas_from_expr(expr)?; + } + } + FieldType::ConsistentField(expr) => { + self.collect_schemas_from_expr(expr)? + } + } + } + } + _ => todo!(), + } + Ok(()) + } + + fn apply(&mut self, input: Option<&Rel>) -> Result<()> { + match input { + None => Ok(()), + Some(rel) => self.collect_schemas_from_rel(rel), + } + } + + fn collect_schemas_from_expr(&mut self, e: &Expression) -> Result<()> { + let rex_type = e.rex_type.as_ref().ok_or(substrait_datafusion_err!( + "rex_type must be set on Expression" + ))?; + match rex_type { + RexType::Literal(_) => {} + RexType::Selection(_) => {} + RexType::ScalarFunction(sf) => { + for arg in sf.arguments.iter() { + self.collect_schemas_from_arg(arg)? + } + } + RexType::WindowFunction(wf) => { + for arg in wf.arguments.iter() { + self.collect_schemas_from_arg(arg)? + } + for sort_field in wf.sorts.iter() { + if let Some(expr) = sort_field.expr.as_ref() { + self.collect_schemas_from_expr(expr)? + } + } + for expr in wf.partitions.iter() { + self.collect_schemas_from_expr(expr)? + } + } + RexType::IfThen(it) => { + for if_clause in it.ifs.iter() { + if let Some(expr) = if_clause.r#if.as_ref() { + self.collect_schemas_from_expr(expr)?; + }; + if let Some(expr) = if_clause.then.as_ref() { + self.collect_schemas_from_expr(expr)?; + }; + } + if let Some(expr) = it.r#else.as_ref() { + self.collect_schemas_from_expr(expr)?; + }; + } + RexType::SwitchExpression(se) => { + if let Some(expr) = se.r#match.as_ref() { + self.collect_schemas_from_expr(expr)? + } + for if_value in se.ifs.iter() { + if let Some(expr) = if_value.then.as_ref() { + self.collect_schemas_from_expr(expr)? + } + } + if let Some(expr) = se.r#else.as_ref() { + self.collect_schemas_from_expr(expr)? + } + } + RexType::SingularOrList(sol) => { + if let Some(expr) = sol.value.as_ref() { + self.collect_schemas_from_expr(expr)? + } + for expr in sol.options.iter() { + self.collect_schemas_from_expr(expr)? + } + } + RexType::MultiOrList(mol) => { + for expr in mol.value.iter() { + self.collect_schemas_from_expr(expr)? + } + for record in mol.options.iter() { + for expr in record.fields.iter() { + self.collect_schemas_from_expr(expr)? + } + } + } + RexType::Cast(c) => { + if let Some(expr) = c.input.as_ref() { + self.collect_schemas_from_expr(expr)? + } + } + RexType::Subquery(subquery) => { + let subquery_type = subquery + .subquery_type + .as_ref() + .ok_or(substrait_datafusion_err!("subquery_type must be set"))?; + match subquery_type { + SubqueryType::Scalar(s) => { + if let Some(rel) = s.input.as_ref() { + self.collect_schemas_from_rel(rel)?; + } + } + SubqueryType::InPredicate(ip) => { + for expr in ip.needles.iter() { + self.collect_schemas_from_expr(expr)?; + } + if let Some(rel) = ip.haystack.as_ref() { + self.collect_schemas_from_rel(rel)?; + } + } + SubqueryType::SetPredicate(sp) => { + if let Some(rel) = sp.tuples.as_ref() { + self.collect_schemas_from_rel(rel)?; + } + } + SubqueryType::SetComparison(sc) => { + if let Some(expr) = sc.left.as_ref() { + self.collect_schemas_from_expr(expr)?; + } + if let Some(rel) = sc.right.as_ref() { + self.collect_schemas_from_rel(rel)?; + } + } + } + } + RexType::Nested(n) => { + let nested_type = n.nested_type.as_ref().ok_or( + substrait_datafusion_err!("Nested must set nested_type"), + )?; + match nested_type { + NestedType::Struct(s) => { + for expr in s.fields.iter() { + self.collect_schemas_from_expr(expr)?; + } + } + NestedType::List(l) => { + for expr in l.values.iter() { + self.collect_schemas_from_expr(expr)?; + } + } + NestedType::Map(m) => { + for key_value in m.key_values.iter() { + if let Some(expr) = key_value.key.as_ref() { + self.collect_schemas_from_expr(expr)?; + } + if let Some(expr) = key_value.value.as_ref() { + self.collect_schemas_from_expr(expr)?; + } + } + } + } + } + // Enum is deprecated + RexType::Enum(_) => {} + } + Ok(()) + } + + fn collect_schemas_from_arg(&mut self, fa: &FunctionArgument) -> Result<()> { + let arg_type = fa.arg_type.as_ref().ok_or(substrait_datafusion_err!( + "FunctionArgument must set arg_type" + ))?; + match arg_type { + ArgType::Enum(_) => {} + ArgType::Type(_) => {} + ArgType::Value(expr) => self.collect_schemas_from_expr(expr)?, + } + Ok(()) + } + } +} diff --git a/datafusion/wasmtest/Cargo.toml b/datafusion/wasmtest/Cargo.toml index 46e157aecfd9..2440244d08c3 100644 --- a/datafusion/wasmtest/Cargo.toml +++ b/datafusion/wasmtest/Cargo.toml @@ -60,4 +60,5 @@ wasm-bindgen = "0.2.87" wasm-bindgen-futures = "0.4.40" [dev-dependencies] -wasm-bindgen-test = "0.3" +tokio = { workspace = true } +wasm-bindgen-test = "0.3.44" diff --git a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json index 7d324d074c9d..37512e8278a7 100644 --- a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json +++ b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json @@ -13,7 +13,7 @@ }, "devDependencies": { "copy-webpack-plugin": "6.4.1", - "webpack": "5.88.2", + "webpack": "5.94.0", "webpack-cli": "5.1.4", "webpack-dev-server": "4.15.1" } @@ -38,57 +38,57 @@ "dev": true }, "node_modules/@jridgewell/gen-mapping": { - "version": "0.3.3", - "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.3.tgz", - "integrity": "sha512-HLhSWOLRi875zjjMG/r+Nv0oCW8umGb0BgEhyX3dDX3egwZtB8PqLnjz3yedt8R5StBrzcg4aBpnh8UA9D1BoQ==", + "version": "0.3.5", + "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.5.tgz", + "integrity": "sha512-IzL8ZoEDIBRWEzlCcRhOaCupYyN5gdIK+Q6fbFdPDg6HqX6jpkItn7DFIpW9LQzXG6Df9sA7+OKnq0qlz/GaQg==", "dev": true, "dependencies": { - "@jridgewell/set-array": "^1.0.1", + "@jridgewell/set-array": "^1.2.1", "@jridgewell/sourcemap-codec": "^1.4.10", - "@jridgewell/trace-mapping": "^0.3.9" + "@jridgewell/trace-mapping": "^0.3.24" }, "engines": { "node": ">=6.0.0" } }, "node_modules/@jridgewell/resolve-uri": { - "version": "3.1.1", - "resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.1.tgz", - "integrity": "sha512-dSYZh7HhCDtCKm4QakX0xFpsRDqjjtZf/kjI/v3T3Nwt5r8/qz/M19F9ySyOqU94SXBmeG9ttTul+YnR4LOxFA==", + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.2.tgz", + "integrity": "sha512-bRISgCIjP20/tbWSPWMEi54QVPRZExkuD9lJL+UIxUKtwVJA8wW1Trb1jMs1RFXo1CBTNZ/5hpC9QvmKWdopKw==", "dev": true, "engines": { "node": ">=6.0.0" } }, "node_modules/@jridgewell/set-array": { - "version": "1.1.2", - "resolved": "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.1.2.tgz", - "integrity": "sha512-xnkseuNADM0gt2bs+BvhO0p78Mk762YnZdsuzFV018NoG1Sj1SCQvpSqa7XUaTam5vAGasABV9qXASMKnFMwMw==", + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.2.1.tgz", + "integrity": "sha512-R8gLRTZeyp03ymzP/6Lil/28tGeGEzhx1q2k703KGWRAI1VdvPIXdG70VJc2pAMw3NA6JKL5hhFu1sJX0Mnn/A==", "dev": true, "engines": { "node": ">=6.0.0" } }, "node_modules/@jridgewell/source-map": { - "version": "0.3.5", - "resolved": "https://registry.npmjs.org/@jridgewell/source-map/-/source-map-0.3.5.tgz", - "integrity": "sha512-UTYAUj/wviwdsMfzoSJspJxbkH5o1snzwX0//0ENX1u/55kkZZkcTZP6u9bwKGkv+dkk9at4m1Cpt0uY80kcpQ==", + "version": "0.3.6", + "resolved": "https://registry.npmjs.org/@jridgewell/source-map/-/source-map-0.3.6.tgz", + "integrity": "sha512-1ZJTZebgqllO79ue2bm3rIGud/bOe0pP5BjSRCRxxYkEZS8STV7zN84UBbiYu7jy+eCKSnVIUgoWWE/tt+shMQ==", "dev": true, "dependencies": { - "@jridgewell/gen-mapping": "^0.3.0", - "@jridgewell/trace-mapping": "^0.3.9" + "@jridgewell/gen-mapping": "^0.3.5", + "@jridgewell/trace-mapping": "^0.3.25" } }, "node_modules/@jridgewell/sourcemap-codec": { - "version": "1.4.15", - "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.4.15.tgz", - "integrity": "sha512-eF2rxCRulEKXHTRiDrDy6erMYWqNw4LPdQ8UQA4huuxaQsVeRPFl2oM8oDGxMFhJUWZf9McpLtJasDDZb/Bpeg==", + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.0.tgz", + "integrity": "sha512-gv3ZRaISU3fjPAgNsriBRqGWQL6quFx04YMPW/zD8XMLsU32mhCCbfbO6KZFLjvYpCZ8zyDEgqsgf+PwPaM7GQ==", "dev": true }, "node_modules/@jridgewell/trace-mapping": { - "version": "0.3.19", - "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.19.tgz", - "integrity": "sha512-kf37QtfW+Hwx/buWGMPcR60iF9ziHa6r/CZJIHbmcm4+0qrXiVdxegAH0F6yddEVQ7zdkjcGCgCzUu+BcbhQxw==", + "version": "0.3.25", + "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.25.tgz", + "integrity": "sha512-vNk6aEwybGtawWmy/PzwnGDOjCkLWSD2wqvjGGAgOAwCGWySYXfYoxt00IJkTF+8Lb57DwOb3Aa0o9CApepiYQ==", "dev": true, "dependencies": { "@jridgewell/resolve-uri": "^3.1.0", @@ -198,30 +198,10 @@ "@types/node": "*" } }, - "node_modules/@types/eslint": { - "version": "8.44.2", - "resolved": "https://registry.npmjs.org/@types/eslint/-/eslint-8.44.2.tgz", - "integrity": "sha512-sdPRb9K6iL5XZOmBubg8yiFp5yS/JdUDQsq5e6h95km91MCYMuvp7mh1fjPEYUhvHepKpZOjnEaMBR4PxjWDzg==", - "dev": true, - "dependencies": { - "@types/estree": "*", - "@types/json-schema": "*" - } - }, - "node_modules/@types/eslint-scope": { - "version": "3.7.4", - "resolved": "https://registry.npmjs.org/@types/eslint-scope/-/eslint-scope-3.7.4.tgz", - "integrity": "sha512-9K4zoImiZc3HlIp6AVUDE4CWYx22a+lhSZMYNpbjW04+YF0KWj4pJXnEMjdnFTiQibFFmElcsasJXDbdI/EPhA==", - "dev": true, - "dependencies": { - "@types/eslint": "*", - "@types/estree": "*" - } - }, "node_modules/@types/estree": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.1.tgz", - "integrity": "sha512-LG4opVs2ANWZ1TJoKc937iMmNstM/d0ae1vNbnBvBhqCSezgVUOzcLCqbI5elV8Vy6WKwKjaqR+zO9VKirBBCA==", + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.5.tgz", + "integrity": "sha512-/kYRxGDLWzHOB7q+wtSUQlFrtcdUccpfy+X+9iMBpHK8QLLhx2wIPYuS5DYtR9Wa/YlZAbIovy7qVdB1Aq6Lyw==", "dev": true }, "node_modules/@types/express": { @@ -348,9 +328,9 @@ } }, "node_modules/@webassemblyjs/ast": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/ast/-/ast-1.11.6.tgz", - "integrity": "sha512-IN1xI7PwOvLPgjcf180gC1bqn3q/QaOCwYUahIOhbYUu8KA/3tw2RT/T0Gidi1l7Hhj5D/INhJxiICObqpMu4Q==", + "version": "1.12.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/ast/-/ast-1.12.1.tgz", + "integrity": "sha512-EKfMUOPRRUTy5UII4qJDGPpqfwjOmZ5jeGFwid9mnoqIFK+e0vqoi1qH56JpmZSzEL53jKnNzScdmftJyG5xWg==", "dev": true, "dependencies": { "@webassemblyjs/helper-numbers": "1.11.6", @@ -370,9 +350,9 @@ "dev": true }, "node_modules/@webassemblyjs/helper-buffer": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-buffer/-/helper-buffer-1.11.6.tgz", - "integrity": "sha512-z3nFzdcp1mb8nEOFFk8DrYLpHvhKC3grJD2ardfKOzmbmJvEf/tPIqCY+sNcwZIY8ZD7IkB2l7/pqhUhqm7hLA==", + "version": "1.12.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-buffer/-/helper-buffer-1.12.1.tgz", + "integrity": "sha512-nzJwQw99DNDKr9BVCOZcLuJJUlqkJh+kVzVl6Fmq/tI5ZtEyWT1KZMyOXltXLZJmDtvLCDgwsyrkohEtopTXCw==", "dev": true }, "node_modules/@webassemblyjs/helper-numbers": { @@ -393,15 +373,15 @@ "dev": true }, "node_modules/@webassemblyjs/helper-wasm-section": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-section/-/helper-wasm-section-1.11.6.tgz", - "integrity": "sha512-LPpZbSOwTpEC2cgn4hTydySy1Ke+XEu+ETXuoyvuyezHO3Kjdu90KK95Sh9xTbmjrCsUwvWwCOQQNta37VrS9g==", + "version": "1.12.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-section/-/helper-wasm-section-1.12.1.tgz", + "integrity": "sha512-Jif4vfB6FJlUlSbgEMHUyk1j234GTNG9dBJ4XJdOySoj518Xj0oGsNi59cUQF4RRMS9ouBUxDDdyBVfPTypa5g==", "dev": true, "dependencies": { - "@webassemblyjs/ast": "1.11.6", - "@webassemblyjs/helper-buffer": "1.11.6", + "@webassemblyjs/ast": "1.12.1", + "@webassemblyjs/helper-buffer": "1.12.1", "@webassemblyjs/helper-wasm-bytecode": "1.11.6", - "@webassemblyjs/wasm-gen": "1.11.6" + "@webassemblyjs/wasm-gen": "1.12.1" } }, "node_modules/@webassemblyjs/ieee754": { @@ -429,28 +409,28 @@ "dev": true }, "node_modules/@webassemblyjs/wasm-edit": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-edit/-/wasm-edit-1.11.6.tgz", - "integrity": "sha512-Ybn2I6fnfIGuCR+Faaz7YcvtBKxvoLV3Lebn1tM4o/IAJzmi9AWYIPWpyBfU8cC+JxAO57bk4+zdsTjJR+VTOw==", + "version": "1.12.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-edit/-/wasm-edit-1.12.1.tgz", + "integrity": "sha512-1DuwbVvADvS5mGnXbE+c9NfA8QRcZ6iKquqjjmR10k6o+zzsRVesil54DKexiowcFCPdr/Q0qaMgB01+SQ1u6g==", "dev": true, "dependencies": { - "@webassemblyjs/ast": "1.11.6", - "@webassemblyjs/helper-buffer": "1.11.6", + "@webassemblyjs/ast": "1.12.1", + "@webassemblyjs/helper-buffer": "1.12.1", "@webassemblyjs/helper-wasm-bytecode": "1.11.6", - "@webassemblyjs/helper-wasm-section": "1.11.6", - "@webassemblyjs/wasm-gen": "1.11.6", - "@webassemblyjs/wasm-opt": "1.11.6", - "@webassemblyjs/wasm-parser": "1.11.6", - "@webassemblyjs/wast-printer": "1.11.6" + "@webassemblyjs/helper-wasm-section": "1.12.1", + "@webassemblyjs/wasm-gen": "1.12.1", + "@webassemblyjs/wasm-opt": "1.12.1", + "@webassemblyjs/wasm-parser": "1.12.1", + "@webassemblyjs/wast-printer": "1.12.1" } }, "node_modules/@webassemblyjs/wasm-gen": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-gen/-/wasm-gen-1.11.6.tgz", - "integrity": "sha512-3XOqkZP/y6B4F0PBAXvI1/bky7GryoogUtfwExeP/v7Nzwo1QLcq5oQmpKlftZLbT+ERUOAZVQjuNVak6UXjPA==", + "version": "1.12.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-gen/-/wasm-gen-1.12.1.tgz", + "integrity": "sha512-TDq4Ojh9fcohAw6OIMXqiIcTq5KUXTGRkVxbSo1hQnSy6lAM5GSdfwWeSxpAo0YzgsgF182E/U0mDNhuA0tW7w==", "dev": true, "dependencies": { - "@webassemblyjs/ast": "1.11.6", + "@webassemblyjs/ast": "1.12.1", "@webassemblyjs/helper-wasm-bytecode": "1.11.6", "@webassemblyjs/ieee754": "1.11.6", "@webassemblyjs/leb128": "1.11.6", @@ -458,24 +438,24 @@ } }, "node_modules/@webassemblyjs/wasm-opt": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-opt/-/wasm-opt-1.11.6.tgz", - "integrity": "sha512-cOrKuLRE7PCe6AsOVl7WasYf3wbSo4CeOk6PkrjS7g57MFfVUF9u6ysQBBODX0LdgSvQqRiGz3CXvIDKcPNy4g==", + "version": "1.12.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-opt/-/wasm-opt-1.12.1.tgz", + "integrity": "sha512-Jg99j/2gG2iaz3hijw857AVYekZe2SAskcqlWIZXjji5WStnOpVoat3gQfT/Q5tb2djnCjBtMocY/Su1GfxPBg==", "dev": true, "dependencies": { - "@webassemblyjs/ast": "1.11.6", - "@webassemblyjs/helper-buffer": "1.11.6", - "@webassemblyjs/wasm-gen": "1.11.6", - "@webassemblyjs/wasm-parser": "1.11.6" + "@webassemblyjs/ast": "1.12.1", + "@webassemblyjs/helper-buffer": "1.12.1", + "@webassemblyjs/wasm-gen": "1.12.1", + "@webassemblyjs/wasm-parser": "1.12.1" } }, "node_modules/@webassemblyjs/wasm-parser": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-parser/-/wasm-parser-1.11.6.tgz", - "integrity": "sha512-6ZwPeGzMJM3Dqp3hCsLgESxBGtT/OeCvCZ4TA1JUPYgmhAx38tTPR9JaKy0S5H3evQpO/h2uWs2j6Yc/fjkpTQ==", + "version": "1.12.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-parser/-/wasm-parser-1.12.1.tgz", + "integrity": "sha512-xikIi7c2FHXysxXe3COrVUPSheuBtpcfhbpFj4gmu7KRLYOzANztwUU0IbsqvMqzuNK2+glRGWCEqZo1WCLyAQ==", "dev": true, "dependencies": { - "@webassemblyjs/ast": "1.11.6", + "@webassemblyjs/ast": "1.12.1", "@webassemblyjs/helper-api-error": "1.11.6", "@webassemblyjs/helper-wasm-bytecode": "1.11.6", "@webassemblyjs/ieee754": "1.11.6", @@ -484,12 +464,12 @@ } }, "node_modules/@webassemblyjs/wast-printer": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wast-printer/-/wast-printer-1.11.6.tgz", - "integrity": "sha512-JM7AhRcE+yW2GWYaKeHL5vt4xqee5N2WcezptmgyhNS+ScggqcT1OtXykhAb13Sn5Yas0j2uv9tHgrjwvzAP4A==", + "version": "1.12.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wast-printer/-/wast-printer-1.12.1.tgz", + "integrity": "sha512-+X4WAlOisVWQMikjbcvY2e0rwPsKQ9F688lksZhBcPycBBuii3O7m8FACbDMWDojpAqvjIncrG8J0XHKyQfVeA==", "dev": true, "dependencies": { - "@webassemblyjs/ast": "1.11.6", + "@webassemblyjs/ast": "1.12.1", "@xtuc/long": "4.2.2" } }, @@ -563,9 +543,9 @@ } }, "node_modules/acorn": { - "version": "8.10.0", - "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.10.0.tgz", - "integrity": "sha512-F0SAmZ8iUtS//m8DmCTA0jlh6TDKkHQyK6xc6V4KDTyZKA9dnvX9/3sRTVQrWm79glUAZbnmmNcdYwUIHWVybw==", + "version": "8.12.1", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.12.1.tgz", + "integrity": "sha512-tcpGyI9zbizT9JbV6oYE477V6mTlXvvi0T0G3SNIYE2apm/G5huBa1+K89VGeovbg+jycCrfhl3ADxErOuO6Jg==", "dev": true, "bin": { "acorn": "bin/acorn" @@ -574,10 +554,10 @@ "node": ">=0.4.0" } }, - "node_modules/acorn-import-assertions": { - "version": "1.9.0", - "resolved": "https://registry.npmjs.org/acorn-import-assertions/-/acorn-import-assertions-1.9.0.tgz", - "integrity": "sha512-cmMwop9x+8KFhxvKrKfPYmN6/pKTYYHBqLa0DfvVZcKMJWNyWLnaqND7dx/qn66R7ewM1UX5XMaDVP5wlVTaVA==", + "node_modules/acorn-import-attributes": { + "version": "1.9.5", + "resolved": "https://registry.npmjs.org/acorn-import-attributes/-/acorn-import-attributes-1.9.5.tgz", + "integrity": "sha512-n02Vykv5uA3eHGM/Z2dQrcD56kL8TyDb2p1+0P83PClMnC/nc+anbQRhIOWnSq4Ke/KvDPrY3C9hDtC/A3eHnQ==", "dev": true, "peerDependencies": { "acorn": "^8" @@ -731,9 +711,9 @@ } }, "node_modules/body-parser": { - "version": "1.20.2", - "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.20.2.tgz", - "integrity": "sha512-ml9pReCu3M61kGlqoTm2umSXTlRTuGTx0bfYj+uIUKKYycG5NtSbeetV3faSU6R7ajOPw0g/J1PvK4qNy7s5bA==", + "version": "1.20.3", + "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.20.3.tgz", + "integrity": "sha512-7rAxByjUMqQ3/bHJy7D6OGXvx/MMc4IqBn/X0fcM1QUcAItpZrBEYhWGem+tzXH90c+G01ypMcYJBO9Y30203g==", "dev": true, "dependencies": { "bytes": "3.1.2", @@ -744,7 +724,7 @@ "http-errors": "2.0.0", "iconv-lite": "0.4.24", "on-finished": "2.4.1", - "qs": "6.11.0", + "qs": "6.13.0", "raw-body": "2.5.2", "type-is": "~1.6.18", "unpipe": "1.0.0" @@ -804,12 +784,12 @@ } }, "node_modules/braces": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz", - "integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==", + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz", + "integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==", "dev": true, "dependencies": { - "fill-range": "^7.0.1" + "fill-range": "^7.1.1" }, "engines": { "node": ">=8" @@ -1115,9 +1095,9 @@ } }, "node_modules/cookie": { - "version": "0.6.0", - "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.6.0.tgz", - "integrity": "sha512-U71cyTamuh1CRNCfpGY6to28lxvNwPG4Guz/EVjgf3Jmzv0vlDp1atT9eS5dDjMYHucpHbWns6Lwf3BKz6svdw==", + "version": "0.7.1", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.7.1.tgz", + "integrity": "sha512-6DnInpx7SJ2AK3+CTUE/ZM0vWTUboZCegxhC2xiIydHR9jNuTAASBrfEpHhiGOZw/nX51bHt6YQl8jsGo4y/0w==", "dev": true, "engines": { "node": ">= 0.6" @@ -1322,9 +1302,9 @@ } }, "node_modules/enhanced-resolve": { - "version": "5.15.0", - "resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.15.0.tgz", - "integrity": "sha512-LXYT42KJ7lpIKECr2mAXIaMldcNCh/7E0KBKOu4KSfkHmP+mZmSs+8V5gBAqisWBy0OO4W5Oyys0GO1Y8KtdKg==", + "version": "5.17.1", + "resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.17.1.tgz", + "integrity": "sha512-LMHl3dXhTcfv8gM4kEzIUeTQ+7fpdA0l2tUf34BddXPkz2A5xJ5L/Pchd5BL6rdccM9QGvu0sWZzK1Z1t4wwyg==", "dev": true, "dependencies": { "graceful-fs": "^4.2.4", @@ -1479,37 +1459,37 @@ } }, "node_modules/express": { - "version": "4.19.2", - "resolved": "https://registry.npmjs.org/express/-/express-4.19.2.tgz", - "integrity": "sha512-5T6nhjsT+EOMzuck8JjBHARTHfMht0POzlA60WV2pMD3gyXw2LZnZ+ueGdNxG+0calOJcWKbpFcuzLZ91YWq9Q==", + "version": "4.21.1", + "resolved": "https://registry.npmjs.org/express/-/express-4.21.1.tgz", + "integrity": "sha512-YSFlK1Ee0/GC8QaO91tHcDxJiE/X4FbpAyQWkxAvG6AXCuR65YzK8ua6D9hvi/TzUfZMpc+BwuM1IPw8fmQBiQ==", "dev": true, "dependencies": { "accepts": "~1.3.8", "array-flatten": "1.1.1", - "body-parser": "1.20.2", + "body-parser": "1.20.3", "content-disposition": "0.5.4", "content-type": "~1.0.4", - "cookie": "0.6.0", + "cookie": "0.7.1", "cookie-signature": "1.0.6", "debug": "2.6.9", "depd": "2.0.0", - "encodeurl": "~1.0.2", + "encodeurl": "~2.0.0", "escape-html": "~1.0.3", "etag": "~1.8.1", - "finalhandler": "1.2.0", + "finalhandler": "1.3.1", "fresh": "0.5.2", "http-errors": "2.0.0", - "merge-descriptors": "1.0.1", + "merge-descriptors": "1.0.3", "methods": "~1.1.2", "on-finished": "2.4.1", "parseurl": "~1.3.3", - "path-to-regexp": "0.1.7", + "path-to-regexp": "0.1.10", "proxy-addr": "~2.0.7", - "qs": "6.11.0", + "qs": "6.13.0", "range-parser": "~1.2.1", "safe-buffer": "5.2.1", - "send": "0.18.0", - "serve-static": "1.15.0", + "send": "0.19.0", + "serve-static": "1.16.2", "setprototypeof": "1.2.0", "statuses": "2.0.1", "type-is": "~1.6.18", @@ -1544,6 +1524,15 @@ "node": ">= 0.8" } }, + "node_modules/express/node_modules/encodeurl": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-2.0.0.tgz", + "integrity": "sha512-Q0n9HRi4m6JuGIV1eFlmvJB7ZEVxu93IrMyiMsGC0lrMJMWzRgx6WGquyfQgZVb31vhGgXnfmPNNXmxnOkRBrg==", + "dev": true, + "engines": { + "node": ">= 0.8" + } + }, "node_modules/express/node_modules/safe-buffer": { "version": "5.2.1", "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz", @@ -1632,9 +1621,9 @@ } }, "node_modules/fill-range": { - "version": "7.0.1", - "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz", - "integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==", + "version": "7.1.1", + "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz", + "integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==", "dev": true, "dependencies": { "to-regex-range": "^5.0.1" @@ -1644,13 +1633,13 @@ } }, "node_modules/finalhandler": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/finalhandler/-/finalhandler-1.2.0.tgz", - "integrity": "sha512-5uXcUVftlQMFnWC9qu/svkWv3GTd2PfUhK/3PLkYNAe7FbqJMt3515HaxE6eRL74GdsriiwujiawdaB1BpEISg==", + "version": "1.3.1", + "resolved": "https://registry.npmjs.org/finalhandler/-/finalhandler-1.3.1.tgz", + "integrity": "sha512-6BN9trH7bp3qvnrRyzsBz+g3lZxTNZTbVO2EV1CS0WIcDbawYVdYvGflME/9QP0h0pYlCDBCTjYa9nZzMDpyxQ==", "dev": true, "dependencies": { "debug": "2.6.9", - "encodeurl": "~1.0.2", + "encodeurl": "~2.0.0", "escape-html": "~1.0.3", "on-finished": "2.4.1", "parseurl": "~1.3.3", @@ -1670,6 +1659,15 @@ "ms": "2.0.0" } }, + "node_modules/finalhandler/node_modules/encodeurl": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-2.0.0.tgz", + "integrity": "sha512-Q0n9HRi4m6JuGIV1eFlmvJB7ZEVxu93IrMyiMsGC0lrMJMWzRgx6WGquyfQgZVb31vhGgXnfmPNNXmxnOkRBrg==", + "dev": true, + "engines": { + "node": ">= 0.8" + } + }, "node_modules/finalhandler/node_modules/statuses": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.1.tgz", @@ -2489,10 +2487,13 @@ } }, "node_modules/merge-descriptors": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/merge-descriptors/-/merge-descriptors-1.0.1.tgz", - "integrity": "sha512-cCi6g3/Zr1iqQi6ySbseM1Xvooa98N0w31jzUYrXPX2xqObmFGHJ0tQ5u74H3mVh7wLouTseZyYIq39g8cNp1w==", - "dev": true + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/merge-descriptors/-/merge-descriptors-1.0.3.tgz", + "integrity": "sha512-gaNvAS7TZ897/rVaZ0nMtAyxNyi/pdbjbAwUpFQpN70GqnVfOiXpeUUMKRBmzXaSQ8DdTX4/0ms62r2K+hE6mQ==", + "dev": true, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } }, "node_modules/merge-stream": { "version": "2.0.0", @@ -2735,10 +2736,13 @@ } }, "node_modules/object-inspect": { - "version": "1.13.1", - "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.1.tgz", - "integrity": "sha512-5qoj1RUiKOMsCCNLV1CBiPYE10sziTsnmNxkAI/rZhiD63CF7IqdFGC/XzjWjpSgLf0LxXX3bDFIh0E18f6UhQ==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.2.tgz", + "integrity": "sha512-IRZSRuzJiynemAXPYtPe5BoI/RESNYR7TYm50MC5Mqbd3Jmw5y790sErYw3V6SryFJD64b74qQQs9wn5Bg/k3g==", "dev": true, + "engines": { + "node": ">= 0.4" + }, "funding": { "url": "https://github.com/sponsors/ljharb" } @@ -2933,9 +2937,9 @@ "dev": true }, "node_modules/path-to-regexp": { - "version": "0.1.7", - "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-0.1.7.tgz", - "integrity": "sha512-5DFkuoqlv1uYQKxy8omFBeJPQcdoE07Kv2sferDCrAq1ohOU+MSDswDIbnx3YAM60qIOnYa53wBhXW0EbMonrQ==", + "version": "0.1.10", + "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-0.1.10.tgz", + "integrity": "sha512-7lf7qcQidTku0Gu3YDPc8DJ1q7OOucfa/BSsIwjuh56VU7katFvuM8hULfkwB3Fns/rsVF7PwPKVw1sl5KQS9w==", "dev": true }, "node_modules/path-type": { @@ -3021,12 +3025,12 @@ } }, "node_modules/qs": { - "version": "6.11.0", - "resolved": "https://registry.npmjs.org/qs/-/qs-6.11.0.tgz", - "integrity": "sha512-MvjoMCJwEarSbUYk5O+nmoSzSutSsTwF85zcHPQ9OrlFoZOYIjaqBAJIqIXjptyD5vThxGq52Xu/MaJzRkIk4Q==", + "version": "6.13.0", + "resolved": "https://registry.npmjs.org/qs/-/qs-6.13.0.tgz", + "integrity": "sha512-+38qI9SOr8tfZ4QmJNplMUxqjbe7LKvvZgWdExBOmd+egZTtjLB67Gu0HRX3u/XOq7UU2Nx6nsjvS16Z9uwfpg==", "dev": true, "dependencies": { - "side-channel": "^1.0.4" + "side-channel": "^1.0.6" }, "engines": { "node": ">=0.6" @@ -3310,9 +3314,9 @@ } }, "node_modules/send": { - "version": "0.18.0", - "resolved": "https://registry.npmjs.org/send/-/send-0.18.0.tgz", - "integrity": "sha512-qqWzuOjSFOuqPjFe4NOsMLafToQQwBSOEpS+FwEt3A2V3vKubTquT3vmLTQpFgMXp8AlFWFuP1qKaJZOtPpVXg==", + "version": "0.19.0", + "resolved": "https://registry.npmjs.org/send/-/send-0.19.0.tgz", + "integrity": "sha512-dW41u5VfLXu8SJh5bwRmyYUbAoSB3c9uQh6L8h/KtsFREPWpbX1lrljJo186Jc4nmci/sGUZ9a0a0J2zgfq2hw==", "dev": true, "dependencies": { "debug": "2.6.9", @@ -3430,20 +3434,29 @@ "dev": true }, "node_modules/serve-static": { - "version": "1.15.0", - "resolved": "https://registry.npmjs.org/serve-static/-/serve-static-1.15.0.tgz", - "integrity": "sha512-XGuRDNjXUijsUL0vl6nSD7cwURuzEgglbOaFuZM9g3kwDXOWVTck0jLzjPzGD+TazWbboZYu52/9/XPdUgne9g==", + "version": "1.16.2", + "resolved": "https://registry.npmjs.org/serve-static/-/serve-static-1.16.2.tgz", + "integrity": "sha512-VqpjJZKadQB/PEbEwvFdO43Ax5dFBZ2UECszz8bQ7pi7wt//PWe1P6MN7eCnjsatYtBT6EuiClbjSWP2WrIoTw==", "dev": true, "dependencies": { - "encodeurl": "~1.0.2", + "encodeurl": "~2.0.0", "escape-html": "~1.0.3", "parseurl": "~1.3.3", - "send": "0.18.0" + "send": "0.19.0" }, "engines": { "node": ">= 0.8.0" } }, + "node_modules/serve-static/node_modules/encodeurl": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-2.0.0.tgz", + "integrity": "sha512-Q0n9HRi4m6JuGIV1eFlmvJB7ZEVxu93IrMyiMsGC0lrMJMWzRgx6WGquyfQgZVb31vhGgXnfmPNNXmxnOkRBrg==", + "dev": true, + "engines": { + "node": ">= 0.8" + } + }, "node_modules/set-function-length": { "version": "1.2.2", "resolved": "https://registry.npmjs.org/set-function-length/-/set-function-length-1.2.2.tgz", @@ -3724,9 +3737,9 @@ } }, "node_modules/terser": { - "version": "5.20.0", - "resolved": "https://registry.npmjs.org/terser/-/terser-5.20.0.tgz", - "integrity": "sha512-e56ETryaQDyebBwJIWYB2TT6f2EZ0fL0sW/JRXNMN26zZdKi2u/E/5my5lG6jNxym6qsrVXfFRmOdV42zlAgLQ==", + "version": "5.31.6", + "resolved": "https://registry.npmjs.org/terser/-/terser-5.31.6.tgz", + "integrity": "sha512-PQ4DAriWzKj+qgehQ7LK5bQqCFNMmlhjR2PFFLuqGCpuCAauxemVBWwWOxo3UIwWQx8+Pr61Df++r76wDmkQBg==", "dev": true, "dependencies": { "@jridgewell/source-map": "^0.3.3", @@ -3742,16 +3755,16 @@ } }, "node_modules/terser-webpack-plugin": { - "version": "5.3.9", - "resolved": "https://registry.npmjs.org/terser-webpack-plugin/-/terser-webpack-plugin-5.3.9.tgz", - "integrity": "sha512-ZuXsqE07EcggTWQjXUj+Aot/OMcD0bMKGgF63f7UxYcu5/AJF53aIpK1YoP5xR9l6s/Hy2b+t1AM0bLNPRuhwA==", + "version": "5.3.10", + "resolved": "https://registry.npmjs.org/terser-webpack-plugin/-/terser-webpack-plugin-5.3.10.tgz", + "integrity": "sha512-BKFPWlPDndPs+NGGCr1U59t0XScL5317Y0UReNrHaw9/FwhPENlq6bfgs+4yPfyP51vqC1bQ4rp1EfXW5ZSH9w==", "dev": true, "dependencies": { - "@jridgewell/trace-mapping": "^0.3.17", + "@jridgewell/trace-mapping": "^0.3.20", "jest-worker": "^27.4.5", "schema-utils": "^3.1.1", "serialize-javascript": "^6.0.1", - "terser": "^5.16.8" + "terser": "^5.26.0" }, "engines": { "node": ">= 10.13.0" @@ -3776,9 +3789,9 @@ } }, "node_modules/terser-webpack-plugin/node_modules/serialize-javascript": { - "version": "6.0.1", - "resolved": "https://registry.npmjs.org/serialize-javascript/-/serialize-javascript-6.0.1.tgz", - "integrity": "sha512-owoXEFjWRllis8/M1Q+Cw5k8ZH40e3zhp/ovX+Xr/vi1qj6QesbyXXViFbpNvWvPNAD62SutwEXavefrLJWj7w==", + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/serialize-javascript/-/serialize-javascript-6.0.2.tgz", + "integrity": "sha512-Saa1xPByTTq2gdeFZYLLo+RFE35NHZkAbqZeWNd3BpzppeVisAqpDjcp8dyf6uIvEqJRd46jemmyA4iFIeVk8g==", "dev": true, "dependencies": { "randombytes": "^2.1.0" @@ -3930,9 +3943,9 @@ } }, "node_modules/watchpack": { - "version": "2.4.0", - "resolved": "https://registry.npmjs.org/watchpack/-/watchpack-2.4.0.tgz", - "integrity": "sha512-Lcvm7MGST/4fup+ifyKi2hjyIAwcdI4HRgtvTpIUxBRhB+RFtUh8XtDOxUfctVCnhVi+QQj49i91OyvzkJl6cg==", + "version": "2.4.2", + "resolved": "https://registry.npmjs.org/watchpack/-/watchpack-2.4.2.tgz", + "integrity": "sha512-TnbFSbcOCcDgjZ4piURLCbJ3nJhznVh9kw6F6iokjiFPl8ONxe9A6nMDVXDiNbrSfLILs6vB07F7wLBrwPYzJw==", "dev": true, "dependencies": { "glob-to-regexp": "^0.4.1", @@ -3952,34 +3965,33 @@ } }, "node_modules/webpack": { - "version": "5.88.2", - "resolved": "https://registry.npmjs.org/webpack/-/webpack-5.88.2.tgz", - "integrity": "sha512-JmcgNZ1iKj+aiR0OvTYtWQqJwq37Pf683dY9bVORwVbUrDhLhdn/PlO2sHsFHPkj7sHNQF3JwaAkp49V+Sq1tQ==", + "version": "5.94.0", + "resolved": "https://registry.npmjs.org/webpack/-/webpack-5.94.0.tgz", + "integrity": "sha512-KcsGn50VT+06JH/iunZJedYGUJS5FGjow8wb9c0v5n1Om8O1g4L6LjtfxwlXIATopoQu+vOXXa7gYisWxCoPyg==", "dev": true, "dependencies": { - "@types/eslint-scope": "^3.7.3", - "@types/estree": "^1.0.0", - "@webassemblyjs/ast": "^1.11.5", - "@webassemblyjs/wasm-edit": "^1.11.5", - "@webassemblyjs/wasm-parser": "^1.11.5", + "@types/estree": "^1.0.5", + "@webassemblyjs/ast": "^1.12.1", + "@webassemblyjs/wasm-edit": "^1.12.1", + "@webassemblyjs/wasm-parser": "^1.12.1", "acorn": "^8.7.1", - "acorn-import-assertions": "^1.9.0", - "browserslist": "^4.14.5", + "acorn-import-attributes": "^1.9.5", + "browserslist": "^4.21.10", "chrome-trace-event": "^1.0.2", - "enhanced-resolve": "^5.15.0", + "enhanced-resolve": "^5.17.1", "es-module-lexer": "^1.2.1", "eslint-scope": "5.1.1", "events": "^3.2.0", "glob-to-regexp": "^0.4.1", - "graceful-fs": "^4.2.9", + "graceful-fs": "^4.2.11", "json-parse-even-better-errors": "^2.3.1", "loader-runner": "^4.2.0", "mime-types": "^2.1.27", "neo-async": "^2.6.2", "schema-utils": "^3.2.0", "tapable": "^2.1.1", - "terser-webpack-plugin": "^5.3.7", - "watchpack": "^2.4.0", + "terser-webpack-plugin": "^5.3.10", + "watchpack": "^2.4.1", "webpack-sources": "^3.2.3" }, "bin": { @@ -4323,9 +4335,9 @@ "dev": true }, "node_modules/ws": { - "version": "8.14.2", - "resolved": "https://registry.npmjs.org/ws/-/ws-8.14.2.tgz", - "integrity": "sha512-wEBG1ftX4jcglPxgFCMJmZ2PLtSbJ2Peg6TmpJFTbe9GZYOQCDPdMYu/Tm0/bGZkw8paZnJY45J4K2PZrLYq8g==", + "version": "8.17.1", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.17.1.tgz", + "integrity": "sha512-6XQFvXTkbfUOZOKKILFG1PDK2NDQs4azKQl26T0YS5CxqWLgXajbPZ+h4gZekJyRqFU8pvnbAbbs/3TgRPy+GQ==", "dev": true, "engines": { "node": ">=10.0.0" @@ -4376,48 +4388,48 @@ "dev": true }, "@jridgewell/gen-mapping": { - "version": "0.3.3", - "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.3.tgz", - "integrity": "sha512-HLhSWOLRi875zjjMG/r+Nv0oCW8umGb0BgEhyX3dDX3egwZtB8PqLnjz3yedt8R5StBrzcg4aBpnh8UA9D1BoQ==", + "version": "0.3.5", + "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.5.tgz", + "integrity": "sha512-IzL8ZoEDIBRWEzlCcRhOaCupYyN5gdIK+Q6fbFdPDg6HqX6jpkItn7DFIpW9LQzXG6Df9sA7+OKnq0qlz/GaQg==", "dev": true, "requires": { - "@jridgewell/set-array": "^1.0.1", + "@jridgewell/set-array": "^1.2.1", "@jridgewell/sourcemap-codec": "^1.4.10", - "@jridgewell/trace-mapping": "^0.3.9" + "@jridgewell/trace-mapping": "^0.3.24" } }, "@jridgewell/resolve-uri": { - "version": "3.1.1", - "resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.1.tgz", - "integrity": "sha512-dSYZh7HhCDtCKm4QakX0xFpsRDqjjtZf/kjI/v3T3Nwt5r8/qz/M19F9ySyOqU94SXBmeG9ttTul+YnR4LOxFA==", + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.2.tgz", + "integrity": "sha512-bRISgCIjP20/tbWSPWMEi54QVPRZExkuD9lJL+UIxUKtwVJA8wW1Trb1jMs1RFXo1CBTNZ/5hpC9QvmKWdopKw==", "dev": true }, "@jridgewell/set-array": { - "version": "1.1.2", - "resolved": "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.1.2.tgz", - "integrity": "sha512-xnkseuNADM0gt2bs+BvhO0p78Mk762YnZdsuzFV018NoG1Sj1SCQvpSqa7XUaTam5vAGasABV9qXASMKnFMwMw==", + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.2.1.tgz", + "integrity": "sha512-R8gLRTZeyp03ymzP/6Lil/28tGeGEzhx1q2k703KGWRAI1VdvPIXdG70VJc2pAMw3NA6JKL5hhFu1sJX0Mnn/A==", "dev": true }, "@jridgewell/source-map": { - "version": "0.3.5", - "resolved": "https://registry.npmjs.org/@jridgewell/source-map/-/source-map-0.3.5.tgz", - "integrity": "sha512-UTYAUj/wviwdsMfzoSJspJxbkH5o1snzwX0//0ENX1u/55kkZZkcTZP6u9bwKGkv+dkk9at4m1Cpt0uY80kcpQ==", + "version": "0.3.6", + "resolved": "https://registry.npmjs.org/@jridgewell/source-map/-/source-map-0.3.6.tgz", + "integrity": "sha512-1ZJTZebgqllO79ue2bm3rIGud/bOe0pP5BjSRCRxxYkEZS8STV7zN84UBbiYu7jy+eCKSnVIUgoWWE/tt+shMQ==", "dev": true, "requires": { - "@jridgewell/gen-mapping": "^0.3.0", - "@jridgewell/trace-mapping": "^0.3.9" + "@jridgewell/gen-mapping": "^0.3.5", + "@jridgewell/trace-mapping": "^0.3.25" } }, "@jridgewell/sourcemap-codec": { - "version": "1.4.15", - "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.4.15.tgz", - "integrity": "sha512-eF2rxCRulEKXHTRiDrDy6erMYWqNw4LPdQ8UQA4huuxaQsVeRPFl2oM8oDGxMFhJUWZf9McpLtJasDDZb/Bpeg==", + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.0.tgz", + "integrity": "sha512-gv3ZRaISU3fjPAgNsriBRqGWQL6quFx04YMPW/zD8XMLsU32mhCCbfbO6KZFLjvYpCZ8zyDEgqsgf+PwPaM7GQ==", "dev": true }, "@jridgewell/trace-mapping": { - "version": "0.3.19", - "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.19.tgz", - "integrity": "sha512-kf37QtfW+Hwx/buWGMPcR60iF9ziHa6r/CZJIHbmcm4+0qrXiVdxegAH0F6yddEVQ7zdkjcGCgCzUu+BcbhQxw==", + "version": "0.3.25", + "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.25.tgz", + "integrity": "sha512-vNk6aEwybGtawWmy/PzwnGDOjCkLWSD2wqvjGGAgOAwCGWySYXfYoxt00IJkTF+8Lb57DwOb3Aa0o9CApepiYQ==", "dev": true, "requires": { "@jridgewell/resolve-uri": "^3.1.0", @@ -4514,30 +4526,10 @@ "@types/node": "*" } }, - "@types/eslint": { - "version": "8.44.2", - "resolved": "https://registry.npmjs.org/@types/eslint/-/eslint-8.44.2.tgz", - "integrity": "sha512-sdPRb9K6iL5XZOmBubg8yiFp5yS/JdUDQsq5e6h95km91MCYMuvp7mh1fjPEYUhvHepKpZOjnEaMBR4PxjWDzg==", - "dev": true, - "requires": { - "@types/estree": "*", - "@types/json-schema": "*" - } - }, - "@types/eslint-scope": { - "version": "3.7.4", - "resolved": "https://registry.npmjs.org/@types/eslint-scope/-/eslint-scope-3.7.4.tgz", - "integrity": "sha512-9K4zoImiZc3HlIp6AVUDE4CWYx22a+lhSZMYNpbjW04+YF0KWj4pJXnEMjdnFTiQibFFmElcsasJXDbdI/EPhA==", - "dev": true, - "requires": { - "@types/eslint": "*", - "@types/estree": "*" - } - }, "@types/estree": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.1.tgz", - "integrity": "sha512-LG4opVs2ANWZ1TJoKc937iMmNstM/d0ae1vNbnBvBhqCSezgVUOzcLCqbI5elV8Vy6WKwKjaqR+zO9VKirBBCA==", + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.5.tgz", + "integrity": "sha512-/kYRxGDLWzHOB7q+wtSUQlFrtcdUccpfy+X+9iMBpHK8QLLhx2wIPYuS5DYtR9Wa/YlZAbIovy7qVdB1Aq6Lyw==", "dev": true }, "@types/express": { @@ -4664,9 +4656,9 @@ } }, "@webassemblyjs/ast": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/ast/-/ast-1.11.6.tgz", - "integrity": "sha512-IN1xI7PwOvLPgjcf180gC1bqn3q/QaOCwYUahIOhbYUu8KA/3tw2RT/T0Gidi1l7Hhj5D/INhJxiICObqpMu4Q==", + "version": "1.12.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/ast/-/ast-1.12.1.tgz", + "integrity": "sha512-EKfMUOPRRUTy5UII4qJDGPpqfwjOmZ5jeGFwid9mnoqIFK+e0vqoi1qH56JpmZSzEL53jKnNzScdmftJyG5xWg==", "dev": true, "requires": { "@webassemblyjs/helper-numbers": "1.11.6", @@ -4686,9 +4678,9 @@ "dev": true }, "@webassemblyjs/helper-buffer": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-buffer/-/helper-buffer-1.11.6.tgz", - "integrity": "sha512-z3nFzdcp1mb8nEOFFk8DrYLpHvhKC3grJD2ardfKOzmbmJvEf/tPIqCY+sNcwZIY8ZD7IkB2l7/pqhUhqm7hLA==", + "version": "1.12.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-buffer/-/helper-buffer-1.12.1.tgz", + "integrity": "sha512-nzJwQw99DNDKr9BVCOZcLuJJUlqkJh+kVzVl6Fmq/tI5ZtEyWT1KZMyOXltXLZJmDtvLCDgwsyrkohEtopTXCw==", "dev": true }, "@webassemblyjs/helper-numbers": { @@ -4709,15 +4701,15 @@ "dev": true }, "@webassemblyjs/helper-wasm-section": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-section/-/helper-wasm-section-1.11.6.tgz", - "integrity": "sha512-LPpZbSOwTpEC2cgn4hTydySy1Ke+XEu+ETXuoyvuyezHO3Kjdu90KK95Sh9xTbmjrCsUwvWwCOQQNta37VrS9g==", + "version": "1.12.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/helper-wasm-section/-/helper-wasm-section-1.12.1.tgz", + "integrity": "sha512-Jif4vfB6FJlUlSbgEMHUyk1j234GTNG9dBJ4XJdOySoj518Xj0oGsNi59cUQF4RRMS9ouBUxDDdyBVfPTypa5g==", "dev": true, "requires": { - "@webassemblyjs/ast": "1.11.6", - "@webassemblyjs/helper-buffer": "1.11.6", + "@webassemblyjs/ast": "1.12.1", + "@webassemblyjs/helper-buffer": "1.12.1", "@webassemblyjs/helper-wasm-bytecode": "1.11.6", - "@webassemblyjs/wasm-gen": "1.11.6" + "@webassemblyjs/wasm-gen": "1.12.1" } }, "@webassemblyjs/ieee754": { @@ -4745,28 +4737,28 @@ "dev": true }, "@webassemblyjs/wasm-edit": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-edit/-/wasm-edit-1.11.6.tgz", - "integrity": "sha512-Ybn2I6fnfIGuCR+Faaz7YcvtBKxvoLV3Lebn1tM4o/IAJzmi9AWYIPWpyBfU8cC+JxAO57bk4+zdsTjJR+VTOw==", + "version": "1.12.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-edit/-/wasm-edit-1.12.1.tgz", + "integrity": "sha512-1DuwbVvADvS5mGnXbE+c9NfA8QRcZ6iKquqjjmR10k6o+zzsRVesil54DKexiowcFCPdr/Q0qaMgB01+SQ1u6g==", "dev": true, "requires": { - "@webassemblyjs/ast": "1.11.6", - "@webassemblyjs/helper-buffer": "1.11.6", + "@webassemblyjs/ast": "1.12.1", + "@webassemblyjs/helper-buffer": "1.12.1", "@webassemblyjs/helper-wasm-bytecode": "1.11.6", - "@webassemblyjs/helper-wasm-section": "1.11.6", - "@webassemblyjs/wasm-gen": "1.11.6", - "@webassemblyjs/wasm-opt": "1.11.6", - "@webassemblyjs/wasm-parser": "1.11.6", - "@webassemblyjs/wast-printer": "1.11.6" + "@webassemblyjs/helper-wasm-section": "1.12.1", + "@webassemblyjs/wasm-gen": "1.12.1", + "@webassemblyjs/wasm-opt": "1.12.1", + "@webassemblyjs/wasm-parser": "1.12.1", + "@webassemblyjs/wast-printer": "1.12.1" } }, "@webassemblyjs/wasm-gen": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-gen/-/wasm-gen-1.11.6.tgz", - "integrity": "sha512-3XOqkZP/y6B4F0PBAXvI1/bky7GryoogUtfwExeP/v7Nzwo1QLcq5oQmpKlftZLbT+ERUOAZVQjuNVak6UXjPA==", + "version": "1.12.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-gen/-/wasm-gen-1.12.1.tgz", + "integrity": "sha512-TDq4Ojh9fcohAw6OIMXqiIcTq5KUXTGRkVxbSo1hQnSy6lAM5GSdfwWeSxpAo0YzgsgF182E/U0mDNhuA0tW7w==", "dev": true, "requires": { - "@webassemblyjs/ast": "1.11.6", + "@webassemblyjs/ast": "1.12.1", "@webassemblyjs/helper-wasm-bytecode": "1.11.6", "@webassemblyjs/ieee754": "1.11.6", "@webassemblyjs/leb128": "1.11.6", @@ -4774,24 +4766,24 @@ } }, "@webassemblyjs/wasm-opt": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-opt/-/wasm-opt-1.11.6.tgz", - "integrity": "sha512-cOrKuLRE7PCe6AsOVl7WasYf3wbSo4CeOk6PkrjS7g57MFfVUF9u6ysQBBODX0LdgSvQqRiGz3CXvIDKcPNy4g==", + "version": "1.12.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-opt/-/wasm-opt-1.12.1.tgz", + "integrity": "sha512-Jg99j/2gG2iaz3hijw857AVYekZe2SAskcqlWIZXjji5WStnOpVoat3gQfT/Q5tb2djnCjBtMocY/Su1GfxPBg==", "dev": true, "requires": { - "@webassemblyjs/ast": "1.11.6", - "@webassemblyjs/helper-buffer": "1.11.6", - "@webassemblyjs/wasm-gen": "1.11.6", - "@webassemblyjs/wasm-parser": "1.11.6" + "@webassemblyjs/ast": "1.12.1", + "@webassemblyjs/helper-buffer": "1.12.1", + "@webassemblyjs/wasm-gen": "1.12.1", + "@webassemblyjs/wasm-parser": "1.12.1" } }, "@webassemblyjs/wasm-parser": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-parser/-/wasm-parser-1.11.6.tgz", - "integrity": "sha512-6ZwPeGzMJM3Dqp3hCsLgESxBGtT/OeCvCZ4TA1JUPYgmhAx38tTPR9JaKy0S5H3evQpO/h2uWs2j6Yc/fjkpTQ==", + "version": "1.12.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wasm-parser/-/wasm-parser-1.12.1.tgz", + "integrity": "sha512-xikIi7c2FHXysxXe3COrVUPSheuBtpcfhbpFj4gmu7KRLYOzANztwUU0IbsqvMqzuNK2+glRGWCEqZo1WCLyAQ==", "dev": true, "requires": { - "@webassemblyjs/ast": "1.11.6", + "@webassemblyjs/ast": "1.12.1", "@webassemblyjs/helper-api-error": "1.11.6", "@webassemblyjs/helper-wasm-bytecode": "1.11.6", "@webassemblyjs/ieee754": "1.11.6", @@ -4800,12 +4792,12 @@ } }, "@webassemblyjs/wast-printer": { - "version": "1.11.6", - "resolved": "https://registry.npmjs.org/@webassemblyjs/wast-printer/-/wast-printer-1.11.6.tgz", - "integrity": "sha512-JM7AhRcE+yW2GWYaKeHL5vt4xqee5N2WcezptmgyhNS+ScggqcT1OtXykhAb13Sn5Yas0j2uv9tHgrjwvzAP4A==", + "version": "1.12.1", + "resolved": "https://registry.npmjs.org/@webassemblyjs/wast-printer/-/wast-printer-1.12.1.tgz", + "integrity": "sha512-+X4WAlOisVWQMikjbcvY2e0rwPsKQ9F688lksZhBcPycBBuii3O7m8FACbDMWDojpAqvjIncrG8J0XHKyQfVeA==", "dev": true, "requires": { - "@webassemblyjs/ast": "1.11.6", + "@webassemblyjs/ast": "1.12.1", "@xtuc/long": "4.2.2" } }, @@ -4853,15 +4845,15 @@ } }, "acorn": { - "version": "8.10.0", - "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.10.0.tgz", - "integrity": "sha512-F0SAmZ8iUtS//m8DmCTA0jlh6TDKkHQyK6xc6V4KDTyZKA9dnvX9/3sRTVQrWm79glUAZbnmmNcdYwUIHWVybw==", + "version": "8.12.1", + "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.12.1.tgz", + "integrity": "sha512-tcpGyI9zbizT9JbV6oYE477V6mTlXvvi0T0G3SNIYE2apm/G5huBa1+K89VGeovbg+jycCrfhl3ADxErOuO6Jg==", "dev": true }, - "acorn-import-assertions": { - "version": "1.9.0", - "resolved": "https://registry.npmjs.org/acorn-import-assertions/-/acorn-import-assertions-1.9.0.tgz", - "integrity": "sha512-cmMwop9x+8KFhxvKrKfPYmN6/pKTYYHBqLa0DfvVZcKMJWNyWLnaqND7dx/qn66R7ewM1UX5XMaDVP5wlVTaVA==", + "acorn-import-attributes": { + "version": "1.9.5", + "resolved": "https://registry.npmjs.org/acorn-import-attributes/-/acorn-import-attributes-1.9.5.tgz", + "integrity": "sha512-n02Vykv5uA3eHGM/Z2dQrcD56kL8TyDb2p1+0P83PClMnC/nc+anbQRhIOWnSq4Ke/KvDPrY3C9hDtC/A3eHnQ==", "dev": true, "requires": {} }, @@ -4976,9 +4968,9 @@ "dev": true }, "body-parser": { - "version": "1.20.2", - "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.20.2.tgz", - "integrity": "sha512-ml9pReCu3M61kGlqoTm2umSXTlRTuGTx0bfYj+uIUKKYycG5NtSbeetV3faSU6R7ajOPw0g/J1PvK4qNy7s5bA==", + "version": "1.20.3", + "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.20.3.tgz", + "integrity": "sha512-7rAxByjUMqQ3/bHJy7D6OGXvx/MMc4IqBn/X0fcM1QUcAItpZrBEYhWGem+tzXH90c+G01ypMcYJBO9Y30203g==", "dev": true, "requires": { "bytes": "3.1.2", @@ -4989,7 +4981,7 @@ "http-errors": "2.0.0", "iconv-lite": "0.4.24", "on-finished": "2.4.1", - "qs": "6.11.0", + "qs": "6.13.0", "raw-body": "2.5.2", "type-is": "~1.6.18", "unpipe": "1.0.0" @@ -5041,12 +5033,12 @@ } }, "braces": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz", - "integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==", + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz", + "integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==", "dev": true, "requires": { - "fill-range": "^7.0.1" + "fill-range": "^7.1.1" } }, "browserslist": { @@ -5255,9 +5247,9 @@ "dev": true }, "cookie": { - "version": "0.6.0", - "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.6.0.tgz", - "integrity": "sha512-U71cyTamuh1CRNCfpGY6to28lxvNwPG4Guz/EVjgf3Jmzv0vlDp1atT9eS5dDjMYHucpHbWns6Lwf3BKz6svdw==", + "version": "0.7.1", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.7.1.tgz", + "integrity": "sha512-6DnInpx7SJ2AK3+CTUE/ZM0vWTUboZCegxhC2xiIydHR9jNuTAASBrfEpHhiGOZw/nX51bHt6YQl8jsGo4y/0w==", "dev": true }, "cookie-signature": { @@ -5415,9 +5407,9 @@ "dev": true }, "enhanced-resolve": { - "version": "5.15.0", - "resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.15.0.tgz", - "integrity": "sha512-LXYT42KJ7lpIKECr2mAXIaMldcNCh/7E0KBKOu4KSfkHmP+mZmSs+8V5gBAqisWBy0OO4W5Oyys0GO1Y8KtdKg==", + "version": "5.17.1", + "resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.17.1.tgz", + "integrity": "sha512-LMHl3dXhTcfv8gM4kEzIUeTQ+7fpdA0l2tUf34BddXPkz2A5xJ5L/Pchd5BL6rdccM9QGvu0sWZzK1Z1t4wwyg==", "dev": true, "requires": { "graceful-fs": "^4.2.4", @@ -5532,37 +5524,37 @@ } }, "express": { - "version": "4.19.2", - "resolved": "https://registry.npmjs.org/express/-/express-4.19.2.tgz", - "integrity": "sha512-5T6nhjsT+EOMzuck8JjBHARTHfMht0POzlA60WV2pMD3gyXw2LZnZ+ueGdNxG+0calOJcWKbpFcuzLZ91YWq9Q==", + "version": "4.21.1", + "resolved": "https://registry.npmjs.org/express/-/express-4.21.1.tgz", + "integrity": "sha512-YSFlK1Ee0/GC8QaO91tHcDxJiE/X4FbpAyQWkxAvG6AXCuR65YzK8ua6D9hvi/TzUfZMpc+BwuM1IPw8fmQBiQ==", "dev": true, "requires": { "accepts": "~1.3.8", "array-flatten": "1.1.1", - "body-parser": "1.20.2", + "body-parser": "1.20.3", "content-disposition": "0.5.4", "content-type": "~1.0.4", - "cookie": "0.6.0", + "cookie": "0.7.1", "cookie-signature": "1.0.6", "debug": "2.6.9", "depd": "2.0.0", - "encodeurl": "~1.0.2", + "encodeurl": "~2.0.0", "escape-html": "~1.0.3", "etag": "~1.8.1", - "finalhandler": "1.2.0", + "finalhandler": "1.3.1", "fresh": "0.5.2", "http-errors": "2.0.0", - "merge-descriptors": "1.0.1", + "merge-descriptors": "1.0.3", "methods": "~1.1.2", "on-finished": "2.4.1", "parseurl": "~1.3.3", - "path-to-regexp": "0.1.7", + "path-to-regexp": "0.1.10", "proxy-addr": "~2.0.7", - "qs": "6.11.0", + "qs": "6.13.0", "range-parser": "~1.2.1", "safe-buffer": "5.2.1", - "send": "0.18.0", - "serve-static": "1.15.0", + "send": "0.19.0", + "serve-static": "1.16.2", "setprototypeof": "1.2.0", "statuses": "2.0.1", "type-is": "~1.6.18", @@ -5591,6 +5583,12 @@ "integrity": "sha512-g7nH6P6dyDioJogAAGprGpCtVImJhpPk/roCzdb3fIh61/s/nPsfR6onyMwkCAR/OlC3yBC0lESvUoQEAssIrw==", "dev": true }, + "encodeurl": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-2.0.0.tgz", + "integrity": "sha512-Q0n9HRi4m6JuGIV1eFlmvJB7ZEVxu93IrMyiMsGC0lrMJMWzRgx6WGquyfQgZVb31vhGgXnfmPNNXmxnOkRBrg==", + "dev": true + }, "safe-buffer": { "version": "5.2.1", "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz", @@ -5655,22 +5653,22 @@ } }, "fill-range": { - "version": "7.0.1", - "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz", - "integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==", + "version": "7.1.1", + "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz", + "integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==", "dev": true, "requires": { "to-regex-range": "^5.0.1" } }, "finalhandler": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/finalhandler/-/finalhandler-1.2.0.tgz", - "integrity": "sha512-5uXcUVftlQMFnWC9qu/svkWv3GTd2PfUhK/3PLkYNAe7FbqJMt3515HaxE6eRL74GdsriiwujiawdaB1BpEISg==", + "version": "1.3.1", + "resolved": "https://registry.npmjs.org/finalhandler/-/finalhandler-1.3.1.tgz", + "integrity": "sha512-6BN9trH7bp3qvnrRyzsBz+g3lZxTNZTbVO2EV1CS0WIcDbawYVdYvGflME/9QP0h0pYlCDBCTjYa9nZzMDpyxQ==", "dev": true, "requires": { "debug": "2.6.9", - "encodeurl": "~1.0.2", + "encodeurl": "~2.0.0", "escape-html": "~1.0.3", "on-finished": "2.4.1", "parseurl": "~1.3.3", @@ -5687,6 +5685,12 @@ "ms": "2.0.0" } }, + "encodeurl": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-2.0.0.tgz", + "integrity": "sha512-Q0n9HRi4m6JuGIV1eFlmvJB7ZEVxu93IrMyiMsGC0lrMJMWzRgx6WGquyfQgZVb31vhGgXnfmPNNXmxnOkRBrg==", + "dev": true + }, "statuses": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.1.tgz", @@ -6266,9 +6270,9 @@ } }, "merge-descriptors": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/merge-descriptors/-/merge-descriptors-1.0.1.tgz", - "integrity": "sha512-cCi6g3/Zr1iqQi6ySbseM1Xvooa98N0w31jzUYrXPX2xqObmFGHJ0tQ5u74H3mVh7wLouTseZyYIq39g8cNp1w==", + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/merge-descriptors/-/merge-descriptors-1.0.3.tgz", + "integrity": "sha512-gaNvAS7TZ897/rVaZ0nMtAyxNyi/pdbjbAwUpFQpN70GqnVfOiXpeUUMKRBmzXaSQ8DdTX4/0ms62r2K+hE6mQ==", "dev": true }, "merge-stream": { @@ -6449,9 +6453,9 @@ } }, "object-inspect": { - "version": "1.13.1", - "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.1.tgz", - "integrity": "sha512-5qoj1RUiKOMsCCNLV1CBiPYE10sziTsnmNxkAI/rZhiD63CF7IqdFGC/XzjWjpSgLf0LxXX3bDFIh0E18f6UhQ==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.2.tgz", + "integrity": "sha512-IRZSRuzJiynemAXPYtPe5BoI/RESNYR7TYm50MC5Mqbd3Jmw5y790sErYw3V6SryFJD64b74qQQs9wn5Bg/k3g==", "dev": true }, "obuf": { @@ -6589,9 +6593,9 @@ "dev": true }, "path-to-regexp": { - "version": "0.1.7", - "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-0.1.7.tgz", - "integrity": "sha512-5DFkuoqlv1uYQKxy8omFBeJPQcdoE07Kv2sferDCrAq1ohOU+MSDswDIbnx3YAM60qIOnYa53wBhXW0EbMonrQ==", + "version": "0.1.10", + "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-0.1.10.tgz", + "integrity": "sha512-7lf7qcQidTku0Gu3YDPc8DJ1q7OOucfa/BSsIwjuh56VU7katFvuM8hULfkwB3Fns/rsVF7PwPKVw1sl5KQS9w==", "dev": true }, "path-type": { @@ -6658,12 +6662,12 @@ "dev": true }, "qs": { - "version": "6.11.0", - "resolved": "https://registry.npmjs.org/qs/-/qs-6.11.0.tgz", - "integrity": "sha512-MvjoMCJwEarSbUYk5O+nmoSzSutSsTwF85zcHPQ9OrlFoZOYIjaqBAJIqIXjptyD5vThxGq52Xu/MaJzRkIk4Q==", + "version": "6.13.0", + "resolved": "https://registry.npmjs.org/qs/-/qs-6.13.0.tgz", + "integrity": "sha512-+38qI9SOr8tfZ4QmJNplMUxqjbe7LKvvZgWdExBOmd+egZTtjLB67Gu0HRX3u/XOq7UU2Nx6nsjvS16Z9uwfpg==", "dev": true, "requires": { - "side-channel": "^1.0.4" + "side-channel": "^1.0.6" } }, "queue-microtask": { @@ -6856,9 +6860,9 @@ } }, "send": { - "version": "0.18.0", - "resolved": "https://registry.npmjs.org/send/-/send-0.18.0.tgz", - "integrity": "sha512-qqWzuOjSFOuqPjFe4NOsMLafToQQwBSOEpS+FwEt3A2V3vKubTquT3vmLTQpFgMXp8AlFWFuP1qKaJZOtPpVXg==", + "version": "0.19.0", + "resolved": "https://registry.npmjs.org/send/-/send-0.19.0.tgz", + "integrity": "sha512-dW41u5VfLXu8SJh5bwRmyYUbAoSB3c9uQh6L8h/KtsFREPWpbX1lrljJo186Jc4nmci/sGUZ9a0a0J2zgfq2hw==", "dev": true, "requires": { "debug": "2.6.9", @@ -6967,15 +6971,23 @@ } }, "serve-static": { - "version": "1.15.0", - "resolved": "https://registry.npmjs.org/serve-static/-/serve-static-1.15.0.tgz", - "integrity": "sha512-XGuRDNjXUijsUL0vl6nSD7cwURuzEgglbOaFuZM9g3kwDXOWVTck0jLzjPzGD+TazWbboZYu52/9/XPdUgne9g==", + "version": "1.16.2", + "resolved": "https://registry.npmjs.org/serve-static/-/serve-static-1.16.2.tgz", + "integrity": "sha512-VqpjJZKadQB/PEbEwvFdO43Ax5dFBZ2UECszz8bQ7pi7wt//PWe1P6MN7eCnjsatYtBT6EuiClbjSWP2WrIoTw==", "dev": true, "requires": { - "encodeurl": "~1.0.2", + "encodeurl": "~2.0.0", "escape-html": "~1.0.3", "parseurl": "~1.3.3", - "send": "0.18.0" + "send": "0.19.0" + }, + "dependencies": { + "encodeurl": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-2.0.0.tgz", + "integrity": "sha512-Q0n9HRi4m6JuGIV1eFlmvJB7ZEVxu93IrMyiMsGC0lrMJMWzRgx6WGquyfQgZVb31vhGgXnfmPNNXmxnOkRBrg==", + "dev": true + } } }, "set-function-length": { @@ -7199,9 +7211,9 @@ } }, "terser": { - "version": "5.20.0", - "resolved": "https://registry.npmjs.org/terser/-/terser-5.20.0.tgz", - "integrity": "sha512-e56ETryaQDyebBwJIWYB2TT6f2EZ0fL0sW/JRXNMN26zZdKi2u/E/5my5lG6jNxym6qsrVXfFRmOdV42zlAgLQ==", + "version": "5.31.6", + "resolved": "https://registry.npmjs.org/terser/-/terser-5.31.6.tgz", + "integrity": "sha512-PQ4DAriWzKj+qgehQ7LK5bQqCFNMmlhjR2PFFLuqGCpuCAauxemVBWwWOxo3UIwWQx8+Pr61Df++r76wDmkQBg==", "dev": true, "requires": { "@jridgewell/source-map": "^0.3.3", @@ -7211,22 +7223,22 @@ } }, "terser-webpack-plugin": { - "version": "5.3.9", - "resolved": "https://registry.npmjs.org/terser-webpack-plugin/-/terser-webpack-plugin-5.3.9.tgz", - "integrity": "sha512-ZuXsqE07EcggTWQjXUj+Aot/OMcD0bMKGgF63f7UxYcu5/AJF53aIpK1YoP5xR9l6s/Hy2b+t1AM0bLNPRuhwA==", + "version": "5.3.10", + "resolved": "https://registry.npmjs.org/terser-webpack-plugin/-/terser-webpack-plugin-5.3.10.tgz", + "integrity": "sha512-BKFPWlPDndPs+NGGCr1U59t0XScL5317Y0UReNrHaw9/FwhPENlq6bfgs+4yPfyP51vqC1bQ4rp1EfXW5ZSH9w==", "dev": true, "requires": { - "@jridgewell/trace-mapping": "^0.3.17", + "@jridgewell/trace-mapping": "^0.3.20", "jest-worker": "^27.4.5", "schema-utils": "^3.1.1", "serialize-javascript": "^6.0.1", - "terser": "^5.16.8" + "terser": "^5.26.0" }, "dependencies": { "serialize-javascript": { - "version": "6.0.1", - "resolved": "https://registry.npmjs.org/serialize-javascript/-/serialize-javascript-6.0.1.tgz", - "integrity": "sha512-owoXEFjWRllis8/M1Q+Cw5k8ZH40e3zhp/ovX+Xr/vi1qj6QesbyXXViFbpNvWvPNAD62SutwEXavefrLJWj7w==", + "version": "6.0.2", + "resolved": "https://registry.npmjs.org/serialize-javascript/-/serialize-javascript-6.0.2.tgz", + "integrity": "sha512-Saa1xPByTTq2gdeFZYLLo+RFE35NHZkAbqZeWNd3BpzppeVisAqpDjcp8dyf6uIvEqJRd46jemmyA4iFIeVk8g==", "dev": true, "requires": { "randombytes": "^2.1.0" @@ -7339,9 +7351,9 @@ "dev": true }, "watchpack": { - "version": "2.4.0", - "resolved": "https://registry.npmjs.org/watchpack/-/watchpack-2.4.0.tgz", - "integrity": "sha512-Lcvm7MGST/4fup+ifyKi2hjyIAwcdI4HRgtvTpIUxBRhB+RFtUh8XtDOxUfctVCnhVi+QQj49i91OyvzkJl6cg==", + "version": "2.4.2", + "resolved": "https://registry.npmjs.org/watchpack/-/watchpack-2.4.2.tgz", + "integrity": "sha512-TnbFSbcOCcDgjZ4piURLCbJ3nJhznVh9kw6F6iokjiFPl8ONxe9A6nMDVXDiNbrSfLILs6vB07F7wLBrwPYzJw==", "dev": true, "requires": { "glob-to-regexp": "^0.4.1", @@ -7358,34 +7370,33 @@ } }, "webpack": { - "version": "5.88.2", - "resolved": "https://registry.npmjs.org/webpack/-/webpack-5.88.2.tgz", - "integrity": "sha512-JmcgNZ1iKj+aiR0OvTYtWQqJwq37Pf683dY9bVORwVbUrDhLhdn/PlO2sHsFHPkj7sHNQF3JwaAkp49V+Sq1tQ==", + "version": "5.94.0", + "resolved": "https://registry.npmjs.org/webpack/-/webpack-5.94.0.tgz", + "integrity": "sha512-KcsGn50VT+06JH/iunZJedYGUJS5FGjow8wb9c0v5n1Om8O1g4L6LjtfxwlXIATopoQu+vOXXa7gYisWxCoPyg==", "dev": true, "requires": { - "@types/eslint-scope": "^3.7.3", - "@types/estree": "^1.0.0", - "@webassemblyjs/ast": "^1.11.5", - "@webassemblyjs/wasm-edit": "^1.11.5", - "@webassemblyjs/wasm-parser": "^1.11.5", + "@types/estree": "^1.0.5", + "@webassemblyjs/ast": "^1.12.1", + "@webassemblyjs/wasm-edit": "^1.12.1", + "@webassemblyjs/wasm-parser": "^1.12.1", "acorn": "^8.7.1", - "acorn-import-assertions": "^1.9.0", - "browserslist": "^4.14.5", + "acorn-import-attributes": "^1.9.5", + "browserslist": "^4.21.10", "chrome-trace-event": "^1.0.2", - "enhanced-resolve": "^5.15.0", + "enhanced-resolve": "^5.17.1", "es-module-lexer": "^1.2.1", "eslint-scope": "5.1.1", "events": "^3.2.0", "glob-to-regexp": "^0.4.1", - "graceful-fs": "^4.2.9", + "graceful-fs": "^4.2.11", "json-parse-even-better-errors": "^2.3.1", "loader-runner": "^4.2.0", "mime-types": "^2.1.27", "neo-async": "^2.6.2", "schema-utils": "^3.2.0", "tapable": "^2.1.1", - "terser-webpack-plugin": "^5.3.7", - "watchpack": "^2.4.0", + "terser-webpack-plugin": "^5.3.10", + "watchpack": "^2.4.1", "webpack-sources": "^3.2.3" }, "dependencies": { @@ -7618,9 +7629,9 @@ "dev": true }, "ws": { - "version": "8.14.2", - "resolved": "https://registry.npmjs.org/ws/-/ws-8.14.2.tgz", - "integrity": "sha512-wEBG1ftX4jcglPxgFCMJmZ2PLtSbJ2Peg6TmpJFTbe9GZYOQCDPdMYu/Tm0/bGZkw8paZnJY45J4K2PZrLYq8g==", + "version": "8.17.1", + "resolved": "https://registry.npmjs.org/ws/-/ws-8.17.1.tgz", + "integrity": "sha512-6XQFvXTkbfUOZOKKILFG1PDK2NDQs4azKQl26T0YS5CxqWLgXajbPZ+h4gZekJyRqFU8pvnbAbbs/3TgRPy+GQ==", "dev": true, "requires": {} }, diff --git a/datafusion/wasmtest/datafusion-wasm-app/package.json b/datafusion/wasmtest/datafusion-wasm-app/package.json index cd32070fa0bc..0860473276ea 100644 --- a/datafusion/wasmtest/datafusion-wasm-app/package.json +++ b/datafusion/wasmtest/datafusion-wasm-app/package.json @@ -27,7 +27,7 @@ "datafusion-wasmtest": "../pkg" }, "devDependencies": { - "webpack": "5.88.2", + "webpack": "5.94.0", "webpack-cli": "5.1.4", "webpack-dev-server": "4.15.1", "copy-webpack-plugin": "6.4.1" diff --git a/datafusion/wasmtest/src/lib.rs b/datafusion/wasmtest/src/lib.rs index a74cce72ac64..085064d16d94 100644 --- a/datafusion/wasmtest/src/lib.rs +++ b/datafusion/wasmtest/src/lib.rs @@ -78,9 +78,8 @@ mod test { use super::*; use datafusion::execution::context::SessionContext; use datafusion_execution::{ - config::SessionConfig, - disk_manager::DiskManagerConfig, - runtime_env::{RuntimeConfig, RuntimeEnv}, + config::SessionConfig, disk_manager::DiskManagerConfig, + runtime_env::RuntimeEnvBuilder, }; use datafusion_physical_plan::collect; use datafusion_sql::parser::DFParser; @@ -88,23 +87,22 @@ mod test { wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); - #[wasm_bindgen_test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + #[cfg_attr(not(target_arch = "wasm32"), allow(dead_code))] fn datafusion_test() { basic_exprs(); basic_parse(); } - #[wasm_bindgen_test] + #[wasm_bindgen_test(unsupported = tokio::test)] async fn basic_execute() { let sql = "SELECT 2 + 2;"; // Execute SQL (using datafusion) - let rt = Arc::new( - RuntimeEnv::new( - RuntimeConfig::new().with_disk_manager(DiskManagerConfig::Disabled), - ) - .unwrap(), - ); + let rt = RuntimeEnvBuilder::new() + .with_disk_manager(DiskManagerConfig::Disabled) + .build_arc() + .unwrap(); let session_config = SessionConfig::new().with_target_partitions(1); let session_context = Arc::new(SessionContext::new_with_config_rt(session_config, rt)); diff --git a/dev/changelog/38.0.0.md b/dev/changelog/38.0.0.md new file mode 100644 index 000000000000..ca06ad5de475 --- /dev/null +++ b/dev/changelog/38.0.0.md @@ -0,0 +1,370 @@ + + +## [38.0.0](https://github.com/apache/datafusion/tree/38.0.0) (2024-05-07) + +**Breaking changes:** + +- refactor: make dfschema wrap schemaref [#9595](https://github.com/apache/datafusion/pull/9595) (haohuaijin) +- Make FirstValue an UDAF, Change `AggregateUDFImpl::accumulator` signature, support ORDER BY for UDAFs [#9874](https://github.com/apache/datafusion/pull/9874) (jayzhan211) +- Remove `OwnedTableReference` and `OwnedSchemaReference` [#9933](https://github.com/apache/datafusion/pull/9933) (comphead) +- Consistent LogicalPlan subquery handling in TreeNode::apply and TreeNode::visit [#9913](https://github.com/apache/datafusion/pull/9913) (peter-toth) +- Refactor `Optimizer` to use owned plans and `TreeNode` API (10% faster planning) [#9948](https://github.com/apache/datafusion/pull/9948) (alamb) +- Stop copying plans in `LogicalPlan::with_param_values` [#10016](https://github.com/apache/datafusion/pull/10016) (alamb) +- Move coalesce to datafusion-functions and remove BuiltInScalarFunction [#10098](https://github.com/apache/datafusion/pull/10098) (Omega359) +- Refactor sessionconfig set fns to avoid an unnecessary enum to string conversion [#10141](https://github.com/apache/datafusion/pull/10141) (psvri) +- ScalarUDF: Remove `supports_zero_argument` and avoid creating null array for empty args [#10193](https://github.com/apache/datafusion/pull/10193) (jayzhan211) +- Clean-up: Remove AggregateExec::group_by() [#10297](https://github.com/apache/datafusion/pull/10297) (berkaysynnada) +- Remove `ScalarFunctionDefinition::Name` [#10277](https://github.com/apache/datafusion/pull/10277) (lewiszlw) +- feat: Determine ordering of file groups [#9593](https://github.com/apache/datafusion/pull/9593) (suremarc) +- Split parquet bloom filter config and enable bloom filter on read by default [#10306](https://github.com/apache/datafusion/pull/10306) (lewiszlw) +- Improve coerce API so it does not need DFSchema [#10331](https://github.com/apache/datafusion/pull/10331) (alamb) +- Minor: Do not force analyzer to copy logical plans [#10367](https://github.com/apache/datafusion/pull/10367) (alamb) +- Move `Covariance` (Sample) `covar` / `covar_samp` to be a User Defined Aggregate Function [#10372](https://github.com/apache/datafusion/pull/10372) (jayzhan211) + +**Performance related:** + +- perf: Use `Arc` instead of `Cow<&'a>` in the analyzer [#9824](https://github.com/apache/datafusion/pull/9824) (comphead) + +**Implemented enhancements:** + +- feat: Add display_pg_json for LogicalPlan [#9789](https://github.com/apache/datafusion/pull/9789) (liurenjie1024) +- feat: eliminate redundant sorts on monotonic expressions [#9813](https://github.com/apache/datafusion/pull/9813) (suremarc) +- feat: optimize `lower` and `upper` functions [#9971](https://github.com/apache/datafusion/pull/9971) (JasonLi-cn) +- feat: support `unnest` multiple arrays [#10044](https://github.com/apache/datafusion/pull/10044) (jonahgao) +- feat: `DataFrame` supports unnesting multiple columns [#10118](https://github.com/apache/datafusion/pull/10118) (jonahgao) +- feat: support input reordering for `NestedLoopJoinExec` [#9676](https://github.com/apache/datafusion/pull/9676) (korowa) +- feat: add static_name() to ExecutionPlan [#10266](https://github.com/apache/datafusion/pull/10266) (waynexia) +- feat: add optimizer config param to avoid grouping partitions `prefer_existing_union` [#10259](https://github.com/apache/datafusion/pull/10259) (NGA-TRAN) +- feat: unwrap casts of string and dictionary columns [#10323](https://github.com/apache/datafusion/pull/10323) (erratic-pattern) +- feat: Add CrossJoin match case to unparser [#10371](https://github.com/apache/datafusion/pull/10371) (sardination) +- feat: run expression simplifier in a loop until a fixedpoint or 3 cycles [#10358](https://github.com/apache/datafusion/pull/10358) (erratic-pattern) + +**Fixed bugs:** + +- fix: detect non-recursive CTEs in the recursive `WITH` clause [#9836](https://github.com/apache/datafusion/pull/9836) (jonahgao) +- fix: improve `unnest_generic_list` handling of null list [#9975](https://github.com/apache/datafusion/pull/9975) (jonahgao) +- fix: reduce lock contention in `RepartitionExec::execute` [#10009](https://github.com/apache/datafusion/pull/10009) (crepererum) +- fix: `RepartitionExec` metrics [#10025](https://github.com/apache/datafusion/pull/10025) (crepererum) +- fix: Support Dict types in `in_list` physical plans [#10031](https://github.com/apache/datafusion/pull/10031) (advancedxy) +- fix: Specify row count in sort_batch for batch with no columns [#10094](https://github.com/apache/datafusion/pull/10094) (viirya) +- fix: another non-deterministic test in `joins.slt` [#10122](https://github.com/apache/datafusion/pull/10122) (korowa) +- fix: duplicate output for HashJoinExec in CollectLeft mode [#9757](https://github.com/apache/datafusion/pull/9757) (korowa) +- fix: cargo warnings of import item [#10196](https://github.com/apache/datafusion/pull/10196) (waynexia) +- fix: reduce lock contention in distributor channels [#10026](https://github.com/apache/datafusion/pull/10026) (crepererum) +- fix: no longer support the `substring` function [#10242](https://github.com/apache/datafusion/pull/10242) (jonahgao) +- fix: Correct null_count in describe() [#10260](https://github.com/apache/datafusion/pull/10260) (Weijun-H) +- fix: schema error when parsing order-by expressions [#10234](https://github.com/apache/datafusion/pull/10234) (jonahgao) +- fix: LogFunc simplify swaps arguments [#10360](https://github.com/apache/datafusion/pull/10360) (erratic-pattern) + +**Documentation updates:** + +- Update `COPY` documentation to reflect changes [#9754](https://github.com/apache/datafusion/pull/9754) (alamb) +- doc: Add `datafusion-federation` to Integrations [#9853](https://github.com/apache/datafusion/pull/9853) (phillipleblanc) +- Improve `AggregateUDFImpl::state_fields` documentation [#9919](https://github.com/apache/datafusion/pull/9919) (alamb) +- Update datafusion-cli docs, split up [#10078](https://github.com/apache/datafusion/pull/10078) (alamb) +- Fix large futures causing stack overflows [#10033](https://github.com/apache/datafusion/pull/10033) (sergiimk) +- Update documentation to replace Apache Arrow DataFusion with Apache DataFusion [#10130](https://github.com/apache/datafusion/pull/10130) (andygrove) +- Update github repo links [#10167](https://github.com/apache/datafusion/pull/10167) (lewiszlw) +- minor: fix installation section link [#10179](https://github.com/apache/datafusion/pull/10179) (comphead) +- Improve documentation on `TreeNode` [#10035](https://github.com/apache/datafusion/pull/10035) (alamb) +- Update .asf.yaml to publish docs to datafusion.apache.org [#10190](https://github.com/apache/datafusion/pull/10190) (phillipleblanc) +- Update links to point to datafusion.apache.org [#10195](https://github.com/apache/datafusion/pull/10195) (phillipleblanc) +- doc: fix subscribe mail link to datafusion mailing lists [#10225](https://github.com/apache/datafusion/pull/10225) (jackwener) +- Fix docs.rs build for datafusion-proto (hopefully) [#10254](https://github.com/apache/datafusion/pull/10254) (alamb) +- docs: add download page [#10271](https://github.com/apache/datafusion/pull/10271) (tisonkun) +- Clarify docs explaining the relationship between `SessionState` and `SessionContext` [#10350](https://github.com/apache/datafusion/pull/10350) (alamb) +- docs: Add DataFusion subprojects to navigation menu, other minor updates [#10362](https://github.com/apache/datafusion/pull/10362) (andygrove) + +**Merged pull requests:** + +- Prepare 37.0.0 Release [#9697](https://github.com/apache/datafusion/pull/9697) (andygrove) +- move Left, Lpad, Reverse, Right, Rpad functions to datafusion_functions [#9841](https://github.com/apache/datafusion/pull/9841) (Omega359) +- Add non-column expression equality tracking to filter exec [#9819](https://github.com/apache/datafusion/pull/9819) (mustafasrepo) +- datafusion-cli support for multiple commands in a single line [#9831](https://github.com/apache/datafusion/pull/9831) (berkaysynnada) +- Add tests for filtering, grouping, aggregation of ARRAYs [#9695](https://github.com/apache/datafusion/pull/9695) (alamb) +- Remove vestigal conbench integration [#9855](https://github.com/apache/datafusion/pull/9855) (alamb) +- feat: Add display_pg_json for LogicalPlan [#9789](https://github.com/apache/datafusion/pull/9789) (liurenjie1024) +- Update `COPY` documentation to reflect changes [#9754](https://github.com/apache/datafusion/pull/9754) (alamb) +- Minor: Remove the bench most likely to cause OOM in CI [#9858](https://github.com/apache/datafusion/pull/9858) (gruuya) +- Minor: make uuid an optional dependency on datafusion-functions [#9771](https://github.com/apache/datafusion/pull/9771) (alamb) +- doc: Add `Spice.ai` to Known Users [#9852](https://github.com/apache/datafusion/pull/9852) (phillipleblanc) +- minor: add a hint how to adjust max rows displayed [#9845](https://github.com/apache/datafusion/pull/9845) (comphead) +- Exclude .github directory from release tarball [#9850](https://github.com/apache/datafusion/pull/9850) (andygrove) +- move strpos, substr functions to datafusion_functions [#9849](https://github.com/apache/datafusion/pull/9849) (Omega359) +- doc: Add `datafusion-federation` to Integrations [#9853](https://github.com/apache/datafusion/pull/9853) (phillipleblanc) +- chore(deps): update cargo requirement from 0.77.0 to 0.78.1 [#9844](https://github.com/apache/datafusion/pull/9844) (dependabot[bot]) +- chore(deps-dev): bump webpack-dev-middleware from 5.3.3 to 5.3.4 in /datafusion/wasmtest/datafusion-wasm-app [#9741](https://github.com/apache/datafusion/pull/9741) (dependabot[bot]) +- Implement semi/anti join output statistics estimation [#9800](https://github.com/apache/datafusion/pull/9800) (korowa) +- move Log2, Log10, Ln to datafusion-functions [#9869](https://github.com/apache/datafusion/pull/9869) (tinfoil-knight) +- Add CI compile checks for feature flags in datafusion-functions [#9772](https://github.com/apache/datafusion/pull/9772) (alamb) +- move the Translate, SubstrIndex, FindInSet functions to datafusion-functions [#9864](https://github.com/apache/datafusion/pull/9864) (Omega359) +- Support custom struct field names with new scalar function named_struct [#9743](https://github.com/apache/datafusion/pull/9743) (gstvg) +- Allow declaring partition columns in `PARTITION BY` clause, backwards compatible [#9599](https://github.com/apache/datafusion/pull/9599) (MohamedAbdeen21) +- Minor: Move depcheck out of datafusion crate (200 less crates to compile) [#9865](https://github.com/apache/datafusion/pull/9865) (alamb) +- Minor: delete duplicate bench test [#9866](https://github.com/apache/datafusion/pull/9866) (Lordworms) +- parquet: Add tests for pruning on Int8/Int16/Int64 columns [#9778](https://github.com/apache/datafusion/pull/9778) (progval) +- move `Atan2`, `Atan`, `Acosh`, `Asinh`, `Atanh` to `datafusion-function` [#9872](https://github.com/apache/datafusion/pull/9872) (Weijun-H) +- minor(doc): fix dead link for catalogs example [#9883](https://github.com/apache/datafusion/pull/9883) (yjshen) +- parquet: Add tests for page pruning on unsigned integers [#9888](https://github.com/apache/datafusion/pull/9888) (progval) +- fix(9870): common expression elimination optimization, should always re-find the correct expression during re-write. [#9871](https://github.com/apache/datafusion/pull/9871) (wiedld) +- [CI] Use alias for table.struct [#9894](https://github.com/apache/datafusion/pull/9894) (jayzhan211) +- fix: detect non-recursive CTEs in the recursive `WITH` clause [#9836](https://github.com/apache/datafusion/pull/9836) (jonahgao) +- Minor: Add SIGMOD paper reference to architecture guide [#9886](https://github.com/apache/datafusion/pull/9886) (alamb) +- refactor: add macro for the binary math function in `datafusion-function` [#9889](https://github.com/apache/datafusion/pull/9889) (Weijun-H) +- Add benchmark for substr_index [#9878](https://github.com/apache/datafusion/pull/9878) (Omega359) +- Add test for reading back file created with `COPY ... OPTIONS (FORMAT..)` options [#9753](https://github.com/apache/datafusion/pull/9753) (alamb) +- Add Expr->String for SimilarTo, IsNotTrue, IsNotUnknown,Negative [#9902](https://github.com/apache/datafusion/pull/9902) (yyy1000) +- refactor: make dfschema wrap schemaref [#9595](https://github.com/apache/datafusion/pull/9595) (haohuaijin) +- Add `spilled_rows` metric to `ExternalSorter` by `IPCWriter` [#9885](https://github.com/apache/datafusion/pull/9885) (erenavsarogullari) +- Minor: Add ParquetExec::table_parquet_options accessor [#9909](https://github.com/apache/datafusion/pull/9909) (alamb) +- Add support for Bloom filters on unsigned integer columns in Parquet tables [#9770](https://github.com/apache/datafusion/pull/9770) (progval) +- Move `radians`, `signum`, `sin`, `sinh` and `sqrt` functions to `datafusion-functions` crate [#9882](https://github.com/apache/datafusion/pull/9882) (erenavsarogullari) +- refactor: make all udf function impls public [#9903](https://github.com/apache/datafusion/pull/9903) (universalmind303) +- Minor: Improve math expr description [#9911](https://github.com/apache/datafusion/pull/9911) (caicancai) +- perf: Use `Arc` instead of `Cow<&'a>` in the analyzer [#9824](https://github.com/apache/datafusion/pull/9824) (comphead) +- Use `struct` instead of `named_struct` when there are no aliases [#9897](https://github.com/apache/datafusion/pull/9897) (alamb) +- Improve planning speed using `impl Into>` to create Arc rather than `&str` [#9916](https://github.com/apache/datafusion/pull/9916) (alamb) +- Make FirstValue an UDAF, Change `AggregateUDFImpl::accumulator` signature, support ORDER BY for UDAFs [#9874](https://github.com/apache/datafusion/pull/9874) (jayzhan211) +- Add TPCH-DS planning benchmark [#9907](https://github.com/apache/datafusion/pull/9907) (alamb) +- Simplify Expr::map_children [#9876](https://github.com/apache/datafusion/pull/9876) (peter-toth) +- CrossJoin Refactor [#9830](https://github.com/apache/datafusion/pull/9830) (berkaysynnada) +- Optimization: concat function [#9732](https://github.com/apache/datafusion/pull/9732) (JasonLi-cn) +- Improve `AggregateUDFImpl::state_fields` documentation [#9919](https://github.com/apache/datafusion/pull/9919) (alamb) +- chore(deps): update substrait requirement from 0.28.0 to 0.29.0 [#9942](https://github.com/apache/datafusion/pull/9942) (dependabot[bot]) +- test: fix intermittent failure in cte.slt [#9934](https://github.com/apache/datafusion/pull/9934) (jonahgao) +- Move `cbrt`, `cos`, `cosh`, `degrees` to `datafusion-functions` [#9938](https://github.com/apache/datafusion/pull/9938) (erenavsarogullari) +- Add Expr->String for Exists, Sort [#9936](https://github.com/apache/datafusion/pull/9936) (kevinmingtarja) +- Remove `OwnedTableReference` and `OwnedSchemaReference` [#9933](https://github.com/apache/datafusion/pull/9933) (comphead) +- Prune out constant expressions from output ordering. [#9947](https://github.com/apache/datafusion/pull/9947) (mustafasrepo) +- Move `AggregateExpr`, `PhysicalExpr` and `PhysicalSortExpr` to physical-expr-core [#9926](https://github.com/apache/datafusion/pull/9926) (jayzhan211) +- Minor: Update release README [#9956](https://github.com/apache/datafusion/pull/9956) (alamb) +- Optimize `COUNT(1)`: Change the sentinel value's type for COUNT(\*) to Int64 [#9944](https://github.com/apache/datafusion/pull/9944) (gruuya) +- Improve docs for `TableProvider::supports_filters_pushdown` and remove deprecated function [#9923](https://github.com/apache/datafusion/pull/9923) (alamb) +- Minor: Improve documentation for AggregateUDFImpl::accumulator and `AccumulatorArgs` [#9920](https://github.com/apache/datafusion/pull/9920) (alamb) +- Minor: improve TableReference docs [#9952](https://github.com/apache/datafusion/pull/9952) (alamb) +- Fix datafusion-cli publishing [#9955](https://github.com/apache/datafusion/pull/9955) (alamb) +- Simplify TreeNode recursions [#9965](https://github.com/apache/datafusion/pull/9965) (peter-toth) +- Validate partitions columns in `CREATE EXTERNAL TABLE` if table already exists. [#9912](https://github.com/apache/datafusion/pull/9912) (MohamedAbdeen21) +- Minor: Add additional documentation to `CommonSubexprEliminate` [#9959](https://github.com/apache/datafusion/pull/9959) (alamb) +- Fix tpcds planning stack overflows - Join planning refactoring [#9962](https://github.com/apache/datafusion/pull/9962) (Jefffrey) +- coercion vec[Dictionary, Utf8] to Dictionary for coalesce function [#9958](https://github.com/apache/datafusion/pull/9958) (Lordworms) +- Minor: Update library documentation with new crates [#9966](https://github.com/apache/datafusion/pull/9966) (alamb) +- Minor: Return InternalError rather than panic for `NamedStructField should be rewritten in OperatorToFunction` [#9968](https://github.com/apache/datafusion/pull/9968) (alamb) +- minor: update MSRV 1.73 [#9977](https://github.com/apache/datafusion/pull/9977) (comphead) +- Move First Value UDAF and builtin first / last function to `aggregate-functions` [#9960](https://github.com/apache/datafusion/pull/9960) (jayzhan211) +- Minor: Avoid copying all expressions in `Analzyer` / `check_plan` [#9974](https://github.com/apache/datafusion/pull/9974) (alamb) +- Minor: Improve documentation about optimizer [#9967](https://github.com/apache/datafusion/pull/9967) (alamb) +- Minor: Use `Expr::apply()` instead of `inspect_expr_pre()` [#9984](https://github.com/apache/datafusion/pull/9984) (peter-toth) +- Update documentation for COPY command [#9931](https://github.com/apache/datafusion/pull/9931) (alamb) +- Minor: fix bug in pruning predicate doc [#9986](https://github.com/apache/datafusion/pull/9986) (alamb) +- fix: improve `unnest_generic_list` handling of null list [#9975](https://github.com/apache/datafusion/pull/9975) (jonahgao) +- Consistent LogicalPlan subquery handling in TreeNode::apply and TreeNode::visit [#9913](https://github.com/apache/datafusion/pull/9913) (peter-toth) +- Remove unnecessary result in `DFSchema::index_of_column_by_name` [#9990](https://github.com/apache/datafusion/pull/9990) (lewiszlw) +- Removes Bloom filter for Int8/Int16/Uint8/Uint16 [#9969](https://github.com/apache/datafusion/pull/9969) (edmondop) +- Move LogicalPlan `tree_node` module [#9995](https://github.com/apache/datafusion/pull/9995) (alamb) +- Optimize performance of substr_index and add tests [#9973](https://github.com/apache/datafusion/pull/9973) (kevinmingtarja) +- move Floor, Gcd, Lcm, Pi to datafusion-functions [#9976](https://github.com/apache/datafusion/pull/9976) (Omega359) +- Minor: Improve documentation on `LogicalPlan::apply*` and `LogicalPlan::map*` [#9996](https://github.com/apache/datafusion/pull/9996) (alamb) +- move the Log, Power functions to datafusion-functions [#9983](https://github.com/apache/datafusion/pull/9983) (tinfoil-knight) +- Remove FORMAT <..> backwards compatibility options from COPY [#9985](https://github.com/apache/datafusion/pull/9985) (tinfoil-knight) +- move Trunc, Cot, Round, iszero functions to datafusion-functions [#10000](https://github.com/apache/datafusion/pull/10000) (Omega359) +- Minor: Clarify documentation on `PruningStatistics::row_counts` and `PruningStatistics::null_counts` and make test match [#10004](https://github.com/apache/datafusion/pull/10004) (alamb) +- Avoid `LogicalPlan::clone()` in `LogicalPlan::map_children` when possible [#9999](https://github.com/apache/datafusion/pull/9999) (alamb) +- Introduce `TreeNode::exists()` API, avoid copying expressions [#10008](https://github.com/apache/datafusion/pull/10008) (peter-toth) +- Minor: Make `LogicalPlan::apply_subqueries` and `LogicalPlan::map_subqueries` pub [#9998](https://github.com/apache/datafusion/pull/9998) (alamb) +- Move Nanvl and random functions to datafusion-functions [#10017](https://github.com/apache/datafusion/pull/10017) (Omega359) +- fix: reduce lock contention in `RepartitionExec::execute` [#10009](https://github.com/apache/datafusion/pull/10009) (crepererum) +- chore(deps): update rstest requirement from 0.18.0 to 0.19.0 [#10021](https://github.com/apache/datafusion/pull/10021) (dependabot[bot]) +- Minor: Document LogicalPlan tree node transformations [#10010](https://github.com/apache/datafusion/pull/10010) (alamb) +- Refactor `Optimizer` to use owned plans and `TreeNode` API (10% faster planning) [#9948](https://github.com/apache/datafusion/pull/9948) (alamb) +- Further clarification of the supports_filters_pushdown documentation [#9988](https://github.com/apache/datafusion/pull/9988) (cisaacson) +- Prune columns are all null in ParquetExec by row_counts , handle IS NOT NULL [#9989](https://github.com/apache/datafusion/pull/9989) (Ted-Jiang) +- Improve the performance of ltrim/rtrim/btrim [#10006](https://github.com/apache/datafusion/pull/10006) (JasonLi-cn) +- fix: `RepartitionExec` metrics [#10025](https://github.com/apache/datafusion/pull/10025) (crepererum) +- modify emit() of TopK to emit on `batch_size` rather than `batch_size-1` [#10030](https://github.com/apache/datafusion/pull/10030) (JasonLi-cn) +- Consolidate LogicalPlan tree node walking/rewriting code into one module [#10034](https://github.com/apache/datafusion/pull/10034) (alamb) +- Introduce `OptimizerRule::rewrite` to rewrite in place, rewrite `ExprSimplifier` (20% faster planning) [#9954](https://github.com/apache/datafusion/pull/9954) (alamb) +- Fix DistinctCount for timestamps with time zone [#10043](https://github.com/apache/datafusion/pull/10043) (joroKr21) +- Improve documentation on `LogicalPlan` TreeNode methods [#10037](https://github.com/apache/datafusion/pull/10037) (alamb) +- chore(deps): update prost-build requirement from =0.12.3 to =0.12.4 [#10045](https://github.com/apache/datafusion/pull/10045) (crepererum) +- Fix datafusion-cli cursor isn't on the right position in windows 7 cmd [#10028](https://github.com/apache/datafusion/pull/10028) (colommar) +- Always pass DataType to PrimitiveDistinctCountAccumulator [#10047](https://github.com/apache/datafusion/pull/10047) (joroKr21) +- Stop copying plans in `LogicalPlan::with_param_values` [#10016](https://github.com/apache/datafusion/pull/10016) (alamb) +- fix `NamedStructField should be rewritten in OperatorToFunction` in subquery regression (change `ApplyFunctionRewrites` to use TreeNode API [#10032](https://github.com/apache/datafusion/pull/10032) (alamb) +- Avoid copies in `InlineTableScan` via TreeNode API [#10038](https://github.com/apache/datafusion/pull/10038) (alamb) +- Bump sccache-action to v0.0.4 [#10060](https://github.com/apache/datafusion/pull/10060) (phillipleblanc) +- chore: add GitHub workflow to close stale PRs [#10046](https://github.com/apache/datafusion/pull/10046) (andygrove) +- feat: eliminate redundant sorts on monotonic expressions [#9813](https://github.com/apache/datafusion/pull/9813) (suremarc) +- Disable `crypto_expressions` feature properly for --no-default-features [#10059](https://github.com/apache/datafusion/pull/10059) (phillipleblanc) +- Return self in EmptyExec and PlaceholderRowExec with_new_children [#10052](https://github.com/apache/datafusion/pull/10052) (joroKr21) +- chore(deps): update sqllogictest requirement from 0.19.0 to 0.20.0 [#10057](https://github.com/apache/datafusion/pull/10057) (dependabot[bot]) +- Rename `FileSinkExec` to `DataSinkExec` [#10065](https://github.com/apache/datafusion/pull/10065) (phillipleblanc) +- fix: Support Dict types in `in_list` physical plans [#10031](https://github.com/apache/datafusion/pull/10031) (advancedxy) +- Prune pages are all null in ParquetExec by row_counts and fix NOT NULL prune [#10051](https://github.com/apache/datafusion/pull/10051) (Ted-Jiang) +- Refactor `EliminateOuterJoin` to implement `OptimizerRule::rewrite()` [#10081](https://github.com/apache/datafusion/pull/10081) (peter-toth) +- chore(deps): update substrait requirement from 0.29.0 to 0.30.0 [#10084](https://github.com/apache/datafusion/pull/10084) (dependabot[bot]) +- feat: optimize `lower` and `upper` functions [#9971](https://github.com/apache/datafusion/pull/9971) (JasonLi-cn) +- Prepend sqllogictest explain result with line number [#10019](https://github.com/apache/datafusion/pull/10019) (duongcongtoai) +- Use PhysicalExtensionCodec consistently [#10075](https://github.com/apache/datafusion/pull/10075) (joroKr21) +- Minor: Do not truncate `SHOW ALL` in datafusion-cli [#10079](https://github.com/apache/datafusion/pull/10079) (alamb) +- Minor: get mutable ref to `SessionConfig` in `SessionState` [#10050](https://github.com/apache/datafusion/pull/10050) (MichaelScofield) +- Move `ceil`, `exp`, `factorial` to `datafusion-functions` crate [#10083](https://github.com/apache/datafusion/pull/10083) (erenavsarogullari) +- feat: support `unnest` multiple arrays [#10044](https://github.com/apache/datafusion/pull/10044) (jonahgao) +- cleanup(tests): Move tests from `push_down_projections.rs` to `optimize_projections.rs` [#10071](https://github.com/apache/datafusion/pull/10071) (kavirajk) +- Move conversion of FIRST/LAST Aggregate function to independent physical optimizer rule [#10061](https://github.com/apache/datafusion/pull/10061) (jayzhan211) +- Avoid copies in `CountWildcardRule` via TreeNode API [#10066](https://github.com/apache/datafusion/pull/10066) (alamb) +- Coerce Dictionary types for scalar functions [#10077](https://github.com/apache/datafusion/pull/10077) (viirya) +- Refactor `UnwrapCastInComparison` to implement `OptimizerRule::rewrite()` [#10087](https://github.com/apache/datafusion/pull/10087) (peter-toth) +- Improve ApproxPercentileAccumulator merge api and fix bug [#10056](https://github.com/apache/datafusion/pull/10056) (Ted-Jiang) +- Support http s3 endpoints in datafusion-cli via `CREATE EXTERNAL TABLE` [#10080](https://github.com/apache/datafusion/pull/10080) (alamb) +- [Bug Fix]: Deem hash repartition unnecessary when input and output has 1 partition [#10095](https://github.com/apache/datafusion/pull/10095) (mustafasrepo) +- fix: Specify row count in sort_batch for batch with no columns [#10094](https://github.com/apache/datafusion/pull/10094) (viirya) +- Move concat, concat_ws, ends_with, initcap to datafusion-functions [#10089](https://github.com/apache/datafusion/pull/10089) (Omega359) +- Update datafusion-cli docs, split up [#10078](https://github.com/apache/datafusion/pull/10078) (alamb) +- Refactor physical create_initial_plan to iteratively & concurrently construct plan from the bottom up [#10023](https://github.com/apache/datafusion/pull/10023) (Jefffrey) +- Adding TPCH benchmarks for Sort Merge Join [#10092](https://github.com/apache/datafusion/pull/10092) (comphead) +- [minor] make parquet prune tests more readable [#10112](https://github.com/apache/datafusion/pull/10112) (Ted-Jiang) +- Fix intermittent CI test failure in `joins.slt` [#10120](https://github.com/apache/datafusion/pull/10120) (alamb) +- Update dependabot to consider datafusion-cli [#10108](https://github.com/apache/datafusion/pull/10108) (Jefffrey) +- fix: another non-deterministic test in `joins.slt` [#10122](https://github.com/apache/datafusion/pull/10122) (korowa) +- Minor: only trigger dependency check on changes to Cargo.toml [#10099](https://github.com/apache/datafusion/pull/10099) (alamb) +- Refactor `UnwrapCastInComparison` to remove `Expr` clones [#10115](https://github.com/apache/datafusion/pull/10115) (peter-toth) +- Fix large futures causing stack overflows [#10033](https://github.com/apache/datafusion/pull/10033) (sergiimk) +- Avoid cloning in `log::simplify` and `power::simplify` [#10086](https://github.com/apache/datafusion/pull/10086) (alamb) +- feat: `DataFrame` supports unnesting multiple columns [#10118](https://github.com/apache/datafusion/pull/10118) (jonahgao) +- Minor: Refine dev/release/README.md [#10129](https://github.com/apache/datafusion/pull/10129) (alamb) +- Minor: Add default for `Expr` [#10127](https://github.com/apache/datafusion/pull/10127) (peter-toth) +- Update documentation to replace Apache Arrow DataFusion with Apache DataFusion [#10130](https://github.com/apache/datafusion/pull/10130) (andygrove) +- Fix AVG groups accummulator ignoring return type [#10114](https://github.com/apache/datafusion/pull/10114) (gruuya) +- Port `37.1.0` changes to main [#10136](https://github.com/apache/datafusion/pull/10136) (alamb) +- chore(deps): update substrait requirement from 0.30.0 to 0.31.0 [#10140](https://github.com/apache/datafusion/pull/10140) (dependabot[bot]) +- Minor: Support more args for udaf [#10146](https://github.com/apache/datafusion/pull/10146) (jayzhan211) +- Minor: Signature check for UDAF [#10147](https://github.com/apache/datafusion/pull/10147) (jayzhan211) +- minor: avoid cloning the `SetExpr` during planning of `SelectInto` [#10152](https://github.com/apache/datafusion/pull/10152) (jonahgao) +- Add distinct aggregate tests to sqllogictest [#10158](https://github.com/apache/datafusion/pull/10158) (Jefffrey) +- Add test for LIKE newline handling [#10160](https://github.com/apache/datafusion/pull/10160) (Jefffrey) +- minor: unparser cleanup and new roundtrip test [#10150](https://github.com/apache/datafusion/pull/10150) (devinjdangelo) +- Support Duration and Union types in ScalarValue::iter_to_array [#10139](https://github.com/apache/datafusion/pull/10139) (joroKr21) +- chore(deps): update sqlparser requirement from 0.44.0 to 0.45.0 [#10137](https://github.com/apache/datafusion/pull/10137) (Jefffrey) +- fix: duplicate output for HashJoinExec in CollectLeft mode [#9757](https://github.com/apache/datafusion/pull/9757) (korowa) +- Move coalesce to datafusion-functions and remove BuiltInScalarFunction [#10098](https://github.com/apache/datafusion/pull/10098) (Omega359) +- [DOC] Add test example for backtraces [#10143](https://github.com/apache/datafusion/pull/10143) (comphead) +- Update github repo links [#10167](https://github.com/apache/datafusion/pull/10167) (lewiszlw) +- feat: support input reordering for `NestedLoopJoinExec` [#9676](https://github.com/apache/datafusion/pull/9676) (korowa) +- minor: fix installation section link [#10179](https://github.com/apache/datafusion/pull/10179) (comphead) +- Improve `TreeNode` and `LogicalPlan` APIs to accept owned closures, deprecate `transform_down_mut()` and `transform_up_mut()` [#10126](https://github.com/apache/datafusion/pull/10126) (peter-toth) +- Projection Expression - Input Field Inconsistencies during Projection [#10088](https://github.com/apache/datafusion/pull/10088) (berkaysynnada) +- implement short_circuits function for ScalarUDFImpl trait [#10168](https://github.com/apache/datafusion/pull/10168) (Lordworms) +- Improve documentation on `TreeNode` [#10035](https://github.com/apache/datafusion/pull/10035) (alamb) +- implement rewrite for ExtractEquijoinPredicate and avoid clone in filter [#10165](https://github.com/apache/datafusion/pull/10165) (Lordworms) +- Update .asf.yaml to point to new mailing list [#10189](https://github.com/apache/datafusion/pull/10189) (phillipleblanc) +- Update NOTICE.txt to be relevant to DataFusion [#10185](https://github.com/apache/datafusion/pull/10185) (alamb) +- Update .asf.yaml to publish docs to datafusion.apache.org [#10190](https://github.com/apache/datafusion/pull/10190) (phillipleblanc) +- Minor: Add `Column::from(Tableref, &FieldRef)`, `Expr::from(Column)` and `Expr::from(Tableref, &FieldRef)` [#10178](https://github.com/apache/datafusion/pull/10178) (alamb) +- implement rewrite for FilterNullJoinKeys [#10166](https://github.com/apache/datafusion/pull/10166) (Lordworms) +- Implement rewrite for EliminateOneUnion and EliminateJoin [#10184](https://github.com/apache/datafusion/pull/10184) (Lordworms) +- Update links to point to datafusion.apache.org [#10195](https://github.com/apache/datafusion/pull/10195) (phillipleblanc) +- Minor: Introduce `Expr::is_volatile()`, adjust `TreeNode::exists()` [#10191](https://github.com/apache/datafusion/pull/10191) (peter-toth) +- Doc: Modify docs to fix old naming [#10199](https://github.com/apache/datafusion/pull/10199) (comphead) +- [MINOR] Remove ScalarFunction from datafusion.proto #10173 [#10202](https://github.com/apache/datafusion/pull/10202) (dmitrybugakov) +- Allow expr_to_sql unparsing with no quotes [#10198](https://github.com/apache/datafusion/pull/10198) (phillipleblanc) +- Minor: Avoid a clone in ArrayFunctionRewriter [#10204](https://github.com/apache/datafusion/pull/10204) (alamb) +- Move coalesce function from math to core [#10201](https://github.com/apache/datafusion/pull/10201) (xxxuuu) +- fix: cargo warnings of import item [#10196](https://github.com/apache/datafusion/pull/10196) (waynexia) +- Minor: Remove some clone in `TypeCoercion` [#10203](https://github.com/apache/datafusion/pull/10203) (alamb) +- doc: fix subscribe mail link to datafusion mailing lists [#10225](https://github.com/apache/datafusion/pull/10225) (jackwener) +- Minor: Prevent empty datafusion-cli commands [#10219](https://github.com/apache/datafusion/pull/10219) (comphead) +- Optimize date_bin (2x faster) [#10215](https://github.com/apache/datafusion/pull/10215) (simonvandel) +- Refactor sessionconfig set fns to avoid an unnecessary enum to string conversion [#10141](https://github.com/apache/datafusion/pull/10141) (psvri) +- fix: reduce lock contention in distributor channels [#10026](https://github.com/apache/datafusion/pull/10026) (crepererum) +- Avoid `Expr` copies `OptimizeProjection`, 12% faster planning, encapsulate indicies [#10216](https://github.com/apache/datafusion/pull/10216) (alamb) +- chore: Create a doap file [#10233](https://github.com/apache/datafusion/pull/10233) (tisonkun) +- Allow adding user defined metadata to `ParquetSink` [#10224](https://github.com/apache/datafusion/pull/10224) (wiedld) +- refactor `EliminateDuplicatedExpr` optimizer pass to avoid clone [#10218](https://github.com/apache/datafusion/pull/10218) (Lordworms) +- Support for median(distinct) aggregation function [#10226](https://github.com/apache/datafusion/pull/10226) (Jefffrey) +- Add tests that `random()` and `uuid()` produce unique values for each row [#10248](https://github.com/apache/datafusion/pull/10248) (alamb) +- ScalarUDF: Remove `supports_zero_argument` and avoid creating null array for empty args [#10193](https://github.com/apache/datafusion/pull/10193) (jayzhan211) +- Add Expr->String for WindowFunction [#10243](https://github.com/apache/datafusion/pull/10243) (yyy1000) +- Make function modules public, add Default impl's. [#10239](https://github.com/apache/datafusion/pull/10239) (Omega359) +- chore: Update release scripts to reflect move to TLP [#10235](https://github.com/apache/datafusion/pull/10235) (andygrove) +- Stop copying plans in `EliminateLimit` [#10253](https://github.com/apache/datafusion/pull/10253) (kevinmingtarja) +- Minor Clean-up in JoinSelection Tests [#10249](https://github.com/apache/datafusion/pull/10249) (berkaysynnada) +- fix: no longer support the `substring` function [#10242](https://github.com/apache/datafusion/pull/10242) (jonahgao) +- Fix docs.rs build for datafusion-proto (hopefully) [#10254](https://github.com/apache/datafusion/pull/10254) (alamb) +- Minor: Possibility to strip datafusion error name [#10186](https://github.com/apache/datafusion/pull/10186) (comphead) +- Docs: Add governance page to contributor guide [#10238](https://github.com/apache/datafusion/pull/10238) (alamb) +- Improve documentation on `ColumnarValue` [#10265](https://github.com/apache/datafusion/pull/10265) (alamb) +- Minor: Add comments for removed protobuf nodes [#10252](https://github.com/apache/datafusion/pull/10252) (alamb) +- feat: add static_name() to ExecutionPlan [#10266](https://github.com/apache/datafusion/pull/10266) (waynexia) +- Zero-copy conversion from SchemaRef to DfSchema [#10298](https://github.com/apache/datafusion/pull/10298) (tustvold) +- chore: Update Error for Unnest Rewritter [#10263](https://github.com/apache/datafusion/pull/10263) (Weijun-H) +- feat(CLI): print column headers for empty query results [#10300](https://github.com/apache/datafusion/pull/10300) (jonahgao) +- Clean-up: Remove AggregateExec::group_by() [#10297](https://github.com/apache/datafusion/pull/10297) (berkaysynnada) +- Add mailing list descriptions to documentation [#10284](https://github.com/apache/datafusion/pull/10284) (alamb) +- chore(deps): update substrait requirement from 0.31.0 to 0.32.0 [#10279](https://github.com/apache/datafusion/pull/10279) (dependabot[bot]) +- refactor: Convert `IPCWriter` metrics from `u64` to `usize` [#10278](https://github.com/apache/datafusion/pull/10278) (erenavsarogullari) +- Validate ScalarUDF output rows and fix nulls for `array_has` and `get_field` for `Map` [#10148](https://github.com/apache/datafusion/pull/10148) (duongcongtoai) +- Minor: return NULL for range and generate_series [#10275](https://github.com/apache/datafusion/pull/10275) (Lordworms) +- docs: add download page [#10271](https://github.com/apache/datafusion/pull/10271) (tisonkun) +- Minor: Add some more tests to map.slt [#10301](https://github.com/apache/datafusion/pull/10301) (alamb) +- fix: Correct null_count in describe() [#10260](https://github.com/apache/datafusion/pull/10260) (Weijun-H) +- chore: Add datatype info to error message [#10307](https://github.com/apache/datafusion/pull/10307) (viirya) +- feat: add optimizer config param to avoid grouping partitions `prefer_existing_union` [#10259](https://github.com/apache/datafusion/pull/10259) (NGA-TRAN) +- Remove `ScalarFunctionDefinition::Name` [#10277](https://github.com/apache/datafusion/pull/10277) (lewiszlw) +- Display: Support `preserve_partitioning` on SortExec physical plan. [#10153](https://github.com/apache/datafusion/pull/10153) (kavirajk) +- Fix build with missing `use` ( `" return internal_err!("UDF returned a different ..."`) [#10317](https://github.com/apache/datafusion/pull/10317) (alamb) +- [Minor] Update link to list of committers in contributor guide [#10312](https://github.com/apache/datafusion/pull/10312) (alamb) +- Optimize EliminateFilter to avoid unnecessary copies #10288 [#10302](https://github.com/apache/datafusion/pull/10302) (dmitrybugakov) +- chore: add function to set prefer_existing_union [#10322](https://github.com/apache/datafusion/pull/10322) (NGA-TRAN) +- `ExecutionPlan` visitor example documentation [#10286](https://github.com/apache/datafusion/pull/10286) (matthewmturner) +- fix: schema error when parsing order-by expressions [#10234](https://github.com/apache/datafusion/pull/10234) (jonahgao) +- Stop copying LogicalPlan and Exprs in `RewriteDisjunctivePredicate` [#10305](https://github.com/apache/datafusion/pull/10305) (rohitrastogi) +- feat: unwrap casts of string and dictionary columns [#10323](https://github.com/apache/datafusion/pull/10323) (erratic-pattern) +- feat: Determine ordering of file groups [#9593](https://github.com/apache/datafusion/pull/9593) (suremarc) +- Stop copying LogicalPlan and Exprs in `DecorrelatePredicateSubquery` [#10318](https://github.com/apache/datafusion/pull/10318) (alamb) +- Minor: Add additional coalesce tests [#10334](https://github.com/apache/datafusion/pull/10334) (alamb) +- Minor: add a few more dictionary unwrap tests [#10335](https://github.com/apache/datafusion/pull/10335) (alamb) +- Check list size before concat in ScalarValue [#10329](https://github.com/apache/datafusion/pull/10329) (timsaucer) +- Split parquet bloom filter config and enable bloom filter on read by default [#10306](https://github.com/apache/datafusion/pull/10306) (lewiszlw) +- Improve coerce API so it does not need DFSchema [#10331](https://github.com/apache/datafusion/pull/10331) (alamb) +- Stop copying LogicalPlan and Exprs in `PropagateEmptyRelation` [#10332](https://github.com/apache/datafusion/pull/10332) (dmitrybugakov) +- Stop copying LogicalPlan and Exprs in EliminateNestedUnion [#10319](https://github.com/apache/datafusion/pull/10319) (emgeee) +- Fix clippy lints found by Clippy in Rust `1.78` [#10353](https://github.com/apache/datafusion/pull/10353) (alamb) +- Minor: Add sql level test for lead/lag on arrays [#10345](https://github.com/apache/datafusion/pull/10345) (alamb) +- fix: LogFunc simplify swaps arguments [#10360](https://github.com/apache/datafusion/pull/10360) (erratic-pattern) +- Refine documentation for `Transformed::{update,map,transform})_data` [#10355](https://github.com/apache/datafusion/pull/10355) (alamb) +- Clarify docs explaining the relationship between `SessionState` and `SessionContext` [#10350](https://github.com/apache/datafusion/pull/10350) (alamb) +- Optimized push down filter #10291 [#10366](https://github.com/apache/datafusion/pull/10366) (dmitrybugakov) +- Unparser: Support `ORDER BY` in window function definition [#10370](https://github.com/apache/datafusion/pull/10370) (yyy1000) +- docs: Add DataFusion subprojects to navigation menu, other minor updates [#10362](https://github.com/apache/datafusion/pull/10362) (andygrove) +- feat: Add CrossJoin match case to unparser [#10371](https://github.com/apache/datafusion/pull/10371) (sardination) +- Minor: Do not force analyzer to copy logical plans [#10367](https://github.com/apache/datafusion/pull/10367) (alamb) +- Minor: Move Sum aggregate function test to slt [#10382](https://github.com/apache/datafusion/pull/10382) (jayzhan211) +- chore: remove DataPtr trait since Arc::ptr_eq ignores pointer metadata [#10378](https://github.com/apache/datafusion/pull/10378) (intoraw) +- Move `Covariance` (Sample) `covar` / `covar_samp` to be a User Defined Aggregate Function [#10372](https://github.com/apache/datafusion/pull/10372) (jayzhan211) +- Support limit in StreamingTableExec [#10309](https://github.com/apache/datafusion/pull/10309) (lewiszlw) +- Minor: Move count test to slt [#10383](https://github.com/apache/datafusion/pull/10383) (jayzhan211) +- [MINOR]: Reduce test run time [#10390](https://github.com/apache/datafusion/pull/10390) (mustafasrepo) +- Fix `coalesce`, `struct` and `named_strct` expr_fn function to take multiple arguments [#10321](https://github.com/apache/datafusion/pull/10321) (alamb) +- Minor: remove old `create_physical_expr` to `scalar_function` [#10387](https://github.com/apache/datafusion/pull/10387) (jayzhan211) +- Move average unit tests to slt [#10401](https://github.com/apache/datafusion/pull/10401) (lewiszlw) +- Move array_agg unit tests to slt [#10402](https://github.com/apache/datafusion/pull/10402) (lewiszlw) +- feat: run expression simplifier in a loop until a fixedpoint or 3 cycles [#10358](https://github.com/apache/datafusion/pull/10358) (erratic-pattern) +- Add `SessionContext`/`SessionState::create_physical_expr()` to create `PhysicalExpressions` from `Expr`s [#10330](https://github.com/apache/datafusion/pull/10330) (alamb) diff --git a/dev/changelog/39.0.0.md b/dev/changelog/39.0.0.md new file mode 100644 index 000000000000..ff27b4ba24f5 --- /dev/null +++ b/dev/changelog/39.0.0.md @@ -0,0 +1,333 @@ + + +# Apache DataFusion 39.0.0 Changelog + +This release consists of 234 commits from 59 contributors. See credits at the end of this changelog for more information. + +**Breaking changes:** + +- Remove ScalarFunctionDefinition [#10325](https://github.com/apache/datafusion/pull/10325) (lewiszlw) +- Introduce user-defined signature [#10439](https://github.com/apache/datafusion/pull/10439) (jayzhan211) +- Remove `AggregateFunctionDefinition::Name` [#10441](https://github.com/apache/datafusion/pull/10441) (lewiszlw) +- Make `CREATE EXTERNAL TABLE` format options consistent, remove special syntax for `HEADER ROW`, `DELIMITER` and `COMPRESSION` [#10404](https://github.com/apache/datafusion/pull/10404) (berkaysynnada) +- feat: allow `array_slice` to take an optional stride parameter [#10469](https://github.com/apache/datafusion/pull/10469) (jonahgao) +- Minor: Extend more style of udaf `expr_fn`, Remove order args for`covar_samp` and `covar_pop` [#10492](https://github.com/apache/datafusion/pull/10492) (jayzhan211) +- Remove `file_type()` from `FileFormat` [#10499](https://github.com/apache/datafusion/pull/10499) (Jefffrey) +- UDAF: Extend more args to `state_fields` and `groups_accumulator_supported` and introduce `ReversedUDAF` [#10525](https://github.com/apache/datafusion/pull/10525) (jayzhan211) +- Remove `Expr::GetIndexedField`, replace `Expr::{field,index,range}` with `FieldAccessor`, `IndexAccessor`, and `SliceAccessor` [#10568](https://github.com/apache/datafusion/pull/10568) (jayzhan211) +- Improve ContextProvider [#10577](https://github.com/apache/datafusion/pull/10577) (lewiszlw) +- Minor: Use slice in `ConcreteTreeNode` [#10666](https://github.com/apache/datafusion/pull/10666) (peter-toth) +- Add reference visitor `TreeNode` APIs, change `ExecutionPlan::children()` and `PhysicalExpr::children()` return references [#10543](https://github.com/apache/datafusion/pull/10543) (peter-toth) +- Introduce Sum UDAF [#10651](https://github.com/apache/datafusion/pull/10651) (jayzhan211) + +**Implemented enhancements:** + +- feat: optional args for regexp\_\* UDFs [#10514](https://github.com/apache/datafusion/pull/10514) (Michael-J-Ward) +- feat: Expose Parquet Schema Adapter [#10515](https://github.com/apache/datafusion/pull/10515) (HawaiianSpork) +- feat: API for collecting statistics/index for metadata of a parquet file + tests [#10537](https://github.com/apache/datafusion/pull/10537) (NGA-TRAN) +- feat: Add eliminate group by constant optimizer rule [#10591](https://github.com/apache/datafusion/pull/10591) (korowa) +- feat: extend `unnest` to support Struct datatype [#10429](https://github.com/apache/datafusion/pull/10429) (duongcongtoai) +- feat: add substrait support for Interval types and literals [#10646](https://github.com/apache/datafusion/pull/10646) (waynexia) +- feat: support unparsing LogicalPlan::Window nodes [#10767](https://github.com/apache/datafusion/pull/10767) (devinjdangelo) +- feat: Update Parquet row filtering to handle type coercion [#10716](https://github.com/apache/datafusion/pull/10716) (jeffreyssmith2nd) + +**Fixed bugs:** + +- fix: make `columnize_expr` resistant to display_name collisions [#10459](https://github.com/apache/datafusion/pull/10459) (jonahgao) +- fix: avoid compressed json files repartitioning [#10470](https://github.com/apache/datafusion/pull/10470) (korowa) +- fix: parsing timestamp with date format [#10476](https://github.com/apache/datafusion/pull/10476) (shanretoo) +- fix: `array_slice` panics [#10547](https://github.com/apache/datafusion/pull/10547) (jonahgao) +- fix: pass `quote` parameter to CSV writer [#10671](https://github.com/apache/datafusion/pull/10671) (DDtKey) +- fix: CI compilation failed on substrait [#10683](https://github.com/apache/datafusion/pull/10683) (jonahgao) +- fix: fix string repeat for negative numbers [#10760](https://github.com/apache/datafusion/pull/10760) (tshauck) +- fix: `array_slice` and `array_element` panicked on empty args [#10804](https://github.com/apache/datafusion/pull/10804) (jonahgao) + +**Documentation updates:** + +- Prepare 38.0.0 release candidate 1 [#10407](https://github.com/apache/datafusion/pull/10407) (andygrove) +- chore(docs): update subquery documentation with more information [#10361](https://github.com/apache/datafusion/pull/10361) (sanderson) +- minor: Remove docs archive [#10416](https://github.com/apache/datafusion/pull/10416) (andygrove) +- Minor: format comments in `PushDownFilter` rule [#10437](https://github.com/apache/datafusion/pull/10437) (alamb) +- Minor: Add usecase to comments in `LogicalPlan::recompute_schema` [#10443](https://github.com/apache/datafusion/pull/10443) (alamb) +- doc: fix old master branch references to main [#10458](https://github.com/apache/datafusion/pull/10458) (Jefffrey) +- Minor: Improved document string for `LogicalPlanBuilder` [#10496](https://github.com/apache/datafusion/pull/10496) (AbrarNitk) +- Add to_date function to scalar functions doc [#10601](https://github.com/apache/datafusion/pull/10601) (Omega359) +- Docs: Update PR workflow documentation [#10532](https://github.com/apache/datafusion/pull/10532) (alamb) +- Minor: Add examples of using TreeNode with `Expr` [#10686](https://github.com/apache/datafusion/pull/10686) (alamb) +- docs: add documents to substrait type variation consts [#10719](https://github.com/apache/datafusion/pull/10719) (waynexia) +- Minor: (Doc) Enable rt-multi-thread feature for sample code [#10770](https://github.com/apache/datafusion/pull/10770) (hsiang-c) + +**Other:** + +- Minor: Add more docs and examples for `Expr::unalias` [#10406](https://github.com/apache/datafusion/pull/10406) (alamb) +- minor: Remove [RUST][datafusion] from release vote email subject line [#10411](https://github.com/apache/datafusion/pull/10411) (andygrove) +- fix dml logical plan output schema [#10394](https://github.com/apache/datafusion/pull/10394) (leoyvens) +- [MINOR]: Move transpose code to under common [#10409](https://github.com/apache/datafusion/pull/10409) (mustafasrepo) +- Fix incorrect Schema over aggregate function, Remove unnecessary `exprlist_to_fields_aggregate` [#10408](https://github.com/apache/datafusion/pull/10408) (jonahgao) +- Enable user defined display_name for ScalarUDF [#10417](https://github.com/apache/datafusion/pull/10417) (yyy1000) +- Fix and improve `CommonSubexprEliminate` rule [#10396](https://github.com/apache/datafusion/pull/10396) (peter-toth) +- Simplify making information_schame tables [#10420](https://github.com/apache/datafusion/pull/10420) (lewiszlw) +- only consider main part of the url when deciding is_collection in listing table [#10419](https://github.com/apache/datafusion/pull/10419) (y-f-u) +- make common expression alias human-readable [#10333](https://github.com/apache/datafusion/pull/10333) (MohamedAbdeen21) +- Minor: Simplify + document `EliminateCrossJoin` better [#10427](https://github.com/apache/datafusion/pull/10427) (alamb) +- During expression equality, check for new ordering information [#10434](https://github.com/apache/datafusion/pull/10434) (mustafasrepo) +- Revert 10333 / changes to aliasing in CommonSubExprEliminate [#10436](https://github.com/apache/datafusion/pull/10436) (MohamedAbdeen21) +- Improve flight sql examples [#10432](https://github.com/apache/datafusion/pull/10432) (lewiszlw) +- Move Covariance (Population) covar_pop to be a User Defined Aggregate Function [#10418](https://github.com/apache/datafusion/pull/10418) (yyy1000) +- Stop copying LogicalPlan and Exprs in `OptimizeProjections` (2% faster planning) [#10405](https://github.com/apache/datafusion/pull/10405) (alamb) +- chore: Improve release process for next time [#10447](https://github.com/apache/datafusion/pull/10447) (andygrove) +- Move bit_and_or_xor unit tests to slt [#10457](https://github.com/apache/datafusion/pull/10457) (NoeB) +- Remove some Expr clones in `EliminateCrossJoin`(3%-5% faster planning) [#10430](https://github.com/apache/datafusion/pull/10430) (alamb) +- refactor: Reduce string allocations in Expr::display_name (use write instead of format!) [#10454](https://github.com/apache/datafusion/pull/10454) (erratic-pattern) +- Add `simplify` method to aggregate function [#10354](https://github.com/apache/datafusion/pull/10354) (milenkovicm) +- Add cast array test to sqllogictest [#10474](https://github.com/apache/datafusion/pull/10474) (viirya) +- Add `Expr::try_as_col`, deprecate `Expr::try_into_col` (speed up optimizer) [#10448](https://github.com/apache/datafusion/pull/10448) (alamb) +- Implement `From>` for `LogicalPlanBuilder` [#10466](https://github.com/apache/datafusion/pull/10466) (AbrarNitk) +- Minor: Improve documentation for `catalog.has_header` config option [#10452](https://github.com/apache/datafusion/pull/10452) (alamb) +- Minor: Simplify conjunction and disjunction, improve docs [#10446](https://github.com/apache/datafusion/pull/10446) (alamb) +- Stop copying LogicalPlan and Exprs in `ReplaceDistinctWithAggregate` [#10460](https://github.com/apache/datafusion/pull/10460) (ClSlaid) +- Stop copying LogicalPlan and Exprs in `EliminateCrossJoin` (4% faster planning) [#10431](https://github.com/apache/datafusion/pull/10431) (alamb) +- Improved ergonomy for `CREATE EXTERNAL TABLE OPTIONS`: Don't require quotations for simple namespaced keys like `foo.bar` [#10483](https://github.com/apache/datafusion/pull/10483) (ozankabak) +- Replace `GetFieldAccess` with indexing function in `SqlToRel ` [#10375](https://github.com/apache/datafusion/pull/10375) (jayzhan211) +- Fix values with different data types caused failure [#10445](https://github.com/apache/datafusion/pull/10445) (b41sh) +- Fix SortMergeJoin with join filter filtering all rows out [#10495](https://github.com/apache/datafusion/pull/10495) (viirya) +- chore: use fullpath in macro to avoid declaring in other module [#10503](https://github.com/apache/datafusion/pull/10503) (jayzhan211) +- Minor: remove unused source file `udf.rs` [#10497](https://github.com/apache/datafusion/pull/10497) (jonahgao) +- Support UDAF to align Builtin aggregate function [#10493](https://github.com/apache/datafusion/pull/10493) (jayzhan211) +- Minor: add a test for `current_time` (no args) [#10509](https://github.com/apache/datafusion/pull/10509) (alamb) +- [MINOR]: Move pipeline checker rule to the end [#10502](https://github.com/apache/datafusion/pull/10502) (mustafasrepo) +- Minor: Extract parent/child limit calculation into a function, improve docs [#10501](https://github.com/apache/datafusion/pull/10501) (alamb) +- Fix window expr deserialization [#10506](https://github.com/apache/datafusion/pull/10506) (lewiszlw) +- Update substrait requirement from 0.32.0 to 0.33.3 [#10516](https://github.com/apache/datafusion/pull/10516) (dependabot[bot]) +- Stop copying LogicalPlan and Exprs in `TypeCoercion` (10% faster planning) [#10356](https://github.com/apache/datafusion/pull/10356) (alamb) +- Implement unparse `IS_NULL` to String and enhance the tests [#10529](https://github.com/apache/datafusion/pull/10529) (goldmedal) +- Fix panic in array_agg(distinct) query [#10526](https://github.com/apache/datafusion/pull/10526) (jayzhan211) +- Move min_max unit tests to slt [#10539](https://github.com/apache/datafusion/pull/10539) (xinlifoobar) +- Implement unparse `IsNotFalse` to String [#10538](https://github.com/apache/datafusion/pull/10538) (goldmedal) +- Implement Unparse TryCast Expr --> String Support [#10542](https://github.com/apache/datafusion/pull/10542) (xinlifoobar) +- Implement unparse `Placeholder` to String [#10540](https://github.com/apache/datafusion/pull/10540) (reswqa) +- Implement unparse `OuterReferenceColumn` to String [#10544](https://github.com/apache/datafusion/pull/10544) (goldmedal) +- Stop copying LogicalPlan and Exprs in `PushDownFilter` (4%-6% faster planning) [#10444](https://github.com/apache/datafusion/pull/10444) (alamb) +- Stop most copying LogicalPlan and Exprs in `ScalarSubqueryToJoin` [#10489](https://github.com/apache/datafusion/pull/10489) (alamb) +- Example for simple Expr --> SQL conversion [#10528](https://github.com/apache/datafusion/pull/10528) (edmondop) +- fix `null_count` on `compute_record_batch_statistics` to report null counts across partitions [#10468](https://github.com/apache/datafusion/pull/10468) (samuelcolvin) +- Minor: Add `PullUpCorrelatedExpr::new` and improve documentation [#10500](https://github.com/apache/datafusion/pull/10500) (alamb) +- Stop copying LogicalPlan and Exprs in `PushDownLimit` [#10508](https://github.com/apache/datafusion/pull/10508) (alamb) +- Break up contributing guide into smaller pages [#10533](https://github.com/apache/datafusion/pull/10533) (alamb) +- PhysicalExpr Orderings with Range Information [#10504](https://github.com/apache/datafusion/pull/10504) (berkaysynnada) +- Implement unparse `ScalarVariable` to String [#10541](https://github.com/apache/datafusion/pull/10541) (reswqa) +- Handle dictionary values in ScalarValue serde [#10563](https://github.com/apache/datafusion/pull/10563) (thinkharderdev) +- Improve signature of `get_field` function [#10569](https://github.com/apache/datafusion/pull/10569) (lewiszlw) +- Implement Unparse `GroupingSet` Expr --> String Support sql [#10555](https://github.com/apache/datafusion/pull/10555) (xinlifoobar) +- Minor: Move proxy to datafusion common [#10561](https://github.com/apache/datafusion/pull/10561) (jayzhan211) +- Update prost-build requirement from =0.12.4 to =0.12.6 [#10578](https://github.com/apache/datafusion/pull/10578) (dependabot[bot]) +- Add examples of how to convert logical plan to/from sql strings [#10558](https://github.com/apache/datafusion/pull/10558) (xinlifoobar) +- Fix: Sort Merge Join LeftSemi issues when JoinFilter is set [#10304](https://github.com/apache/datafusion/pull/10304) (comphead) +- Minor: Fix `ArrayFunctionRewriter` name reporting [#10581](https://github.com/apache/datafusion/pull/10581) (alamb) +- Improve `UserDefinedLogicalNode::from_template` API to return `Result` [#10575](https://github.com/apache/datafusion/pull/10575) (lewiszlw) +- Migrate testing optimizer rules to use `rewrite` API [#10576](https://github.com/apache/datafusion/pull/10576) (lewiszlw) +- test: add more tests for statistics reading [#10592](https://github.com/apache/datafusion/pull/10592) (NGA-TRAN) +- refactor: reduce allocations in push down filter [#10567](https://github.com/apache/datafusion/pull/10567) (erratic-pattern) +- Fix compilation of datafusion-cli on 32bit targets [#10594](https://github.com/apache/datafusion/pull/10594) (nathaniel-daniel) +- Rename monotonicity as output_ordering in ScalarUDF's [#10596](https://github.com/apache/datafusion/pull/10596) (berkaysynnada) +- Implement Unparser for `UNION ALL` [#10603](https://github.com/apache/datafusion/pull/10603) (phillipleblanc) +- Improve `UserDefinedLogicalNodeCore::from_template` API to return Result [#10597](https://github.com/apache/datafusion/pull/10597) (lewiszlw) +- Minor: Move group accumulator for aggregate function to physical-expr-common, and add ahash physical-expr-common [#10574](https://github.com/apache/datafusion/pull/10574) (jayzhan211) +- Minor: Consolidate some integration tests into `core_integration` [#10588](https://github.com/apache/datafusion/pull/10588) (alamb) +- Stop copying LogicalPlan and Exprs in `SingleDistinctToGroupBy` [#10527](https://github.com/apache/datafusion/pull/10527) (appletreeisyellow) +- [MINOR]: Update get range implementation for lead lag window functions [#10614](https://github.com/apache/datafusion/pull/10614) (mustafasrepo) +- Minor: Improve documentation in sql_to_plan example [#10582](https://github.com/apache/datafusion/pull/10582) (alamb) +- Docs: add examples for `RuntimeEnv::register_object_store`, improve error messages [#10617](https://github.com/apache/datafusion/pull/10617) (aditanase) +- Add support for Substrait List/EmptyList literals [#10615](https://github.com/apache/datafusion/pull/10615) (Blizzara) +- Add to_unixtime function to scalar functions doc [#10620](https://github.com/apache/datafusion/pull/10620) (Omega359) +- Test for reading read statistics from parquet files without statistics and boolean & struct data type [#10608](https://github.com/apache/datafusion/pull/10608) (NGA-TRAN) +- adding benchmark for extracting arrow statistics from parquet [#10610](https://github.com/apache/datafusion/pull/10610) (Lordworms) +- Implement a dialect-specific rule for unparsing an identifier with or without quotes [#10573](https://github.com/apache/datafusion/pull/10573) (goldmedal) +- add catalog as part of the table path in plan_to_sql [#10612](https://github.com/apache/datafusion/pull/10612) (y-f-u) +- Refactor parquet row group pruning into a struct (use new statistics API, part 1) [#10607](https://github.com/apache/datafusion/pull/10607) (alamb) +- Extract `Date32` parquet statistics as `Date32Array` rather than `Int32Array` [#10593](https://github.com/apache/datafusion/pull/10593) (xinlifoobar) +- Omit NULLS FIRST/LAST when unparsing ORDER BY clauses for MySQL [#10625](https://github.com/apache/datafusion/pull/10625) (phillipleblanc) +- Fix broken build/test from merge [#10637](https://github.com/apache/datafusion/pull/10637) (phillipleblanc) +- Add SessionContext::register_object_store [#10621](https://github.com/apache/datafusion/pull/10621) (alamb) +- Minor: Move median test [#10611](https://github.com/apache/datafusion/pull/10611) (jayzhan211) +- Add support for Substrait Struct literals and type [#10622](https://github.com/apache/datafusion/pull/10622) (Blizzara) +- fix Incorrect statistics read for i8 i16 columns in parquet [#10629](https://github.com/apache/datafusion/pull/10629) (Lordworms) +- Minor: add runtime asserts to `RowGroup` [#10641](https://github.com/apache/datafusion/pull/10641) (alamb) +- Update cli Dockerfile to a newer ubuntu release, newer rust release [#10638](https://github.com/apache/datafusion/pull/10638) (Omega359) +- More properly handle nullability of types/literals in Substrait [#10640](https://github.com/apache/datafusion/pull/10640) (Blizzara) +- fix wrong type validation on unnest expr [#10657](https://github.com/apache/datafusion/pull/10657) (duongcongtoai) +- Fix incorrect statistics read for binary columns in parquet [#10645](https://github.com/apache/datafusion/pull/10645) (xinlifoobar) +- Fix `NULL["field"]` for expr_API [#10655](https://github.com/apache/datafusion/pull/10655) (alamb) +- Update substrait requirement from 0.33.3 to 0.34.0 [#10632](https://github.com/apache/datafusion/pull/10632) (dependabot[bot]) +- Fix typo in Cargo.toml (unused manifest key: dependencies.regex.worksapce) [#10662](https://github.com/apache/datafusion/pull/10662) (alamb) +- Add `FileScanConfig::new()` API [#10623](https://github.com/apache/datafusion/pull/10623) (alamb) +- Minor: Remove `GetFieldAccessSchema` [#10665](https://github.com/apache/datafusion/pull/10665) (jayzhan211) +- Move Median to `functions-aggregate` and Introduce Numeric signature [#10644](https://github.com/apache/datafusion/pull/10644) (jayzhan211) +- Fix `Coalesce` casting logic to follows what Postgres and DuckDB do. Introduce signature that do non-comparison coercion [#10268](https://github.com/apache/datafusion/pull/10268) (jayzhan211) +- Fix compilation "comparison_binary_numeric_coercion not found" [#10677](https://github.com/apache/datafusion/pull/10677) (alamb) +- refactor: simplify converting List DataTypes to `ScalarValue` [#10675](https://github.com/apache/datafusion/pull/10675) (jonahgao) +- Minor: Improve ObjectStoreUrl docs + examples [#10619](https://github.com/apache/datafusion/pull/10619) (alamb) +- Add tests for reading numeric limits in parquet statistics [#10642](https://github.com/apache/datafusion/pull/10642) (alamb) +- Update nix requirement from 0.28.0 to 0.29.0 [#10684](https://github.com/apache/datafusion/pull/10684) (dependabot[bot]) +- refactor: Move SchemaAdapter from parquet module to data source [#10680](https://github.com/apache/datafusion/pull/10680) (HawaiianSpork) +- Convert first, last aggregate function to UDAF [#10648](https://github.com/apache/datafusion/pull/10648) (mustafasrepo) +- Minor: CastExpr Ordering Handle [#10650](https://github.com/apache/datafusion/pull/10650) (berkaysynnada) +- Factor out common datafusion types into another proto file [#10649](https://github.com/apache/datafusion/pull/10649) (mustafasrepo) +- Minor: Add tests showing aggregate behavior for NaNs [#10634](https://github.com/apache/datafusion/pull/10634) (alamb) +- Improve `ParquetExec` and related documentation [#10647](https://github.com/apache/datafusion/pull/10647) (alamb) +- minor: inconsistent group by position planning [#10679](https://github.com/apache/datafusion/pull/10679) (korowa) +- Remove duplicate function name in its aliases list [#10661](https://github.com/apache/datafusion/pull/10661) (goldmedal) +- Add protobuf serde support for `LogicalPlan::Unnest` [#10681](https://github.com/apache/datafusion/pull/10681) (akoshchiy) +- Support Substrait's VirtualTables [#10531](https://github.com/apache/datafusion/pull/10531) (Blizzara) +- support serialization and deserialization limit in the aggregation exec [#10692](https://github.com/apache/datafusion/pull/10692) (liukun4515) +- Display date32/64 in YYYY-MM-DD format [#10691](https://github.com/apache/datafusion/pull/10691) (houqp) +- Fix: array list values are leaked on nested `unnest` operators [#10689](https://github.com/apache/datafusion/pull/10689) (duongcongtoai) +- Support LogicalPlan::Distinct in unparser [#10690](https://github.com/apache/datafusion/pull/10690) (yyy1000) +- Remove redundant upper case aliases for `median`, `first_value` and `last_value` [#10696](https://github.com/apache/datafusion/pull/10696) (goldmedal) +- Minor: improve Expr documentation [#10685](https://github.com/apache/datafusion/pull/10685) (alamb) +- chore: align re-exports in functions-aggregate [#10705](https://github.com/apache/datafusion/pull/10705) (waynexia) +- Fix typo in bench.sh [#10698](https://github.com/apache/datafusion/pull/10698) (vimt) +- Fix incorrect statistics read for unsigned integers columns in parquet [#10704](https://github.com/apache/datafusion/pull/10704) (xinlifoobar) +- Separate `Partitioning` protobuf serialization code [#10708](https://github.com/apache/datafusion/pull/10708) (lewiszlw) +- Support consuming Substrait with compound signature function names [#10653](https://github.com/apache/datafusion/pull/10653) (Blizzara) +- Minor: Add examples of using TreeNode with `LogicalPlan` [#10687](https://github.com/apache/datafusion/pull/10687) (alamb) +- Add `ParquetExec::builder()`, deprecate `ParquetExec::new` [#10636](https://github.com/apache/datafusion/pull/10636) (alamb) +- feature: Add a WindowUDFImpl::simplify() API [#9906](https://github.com/apache/datafusion/pull/9906) (guojidan) +- Chore: clean up udwf example && remove redundant import [#10718](https://github.com/apache/datafusion/pull/10718) (guojidan) +- Push down filter as table partition list prefix [#10693](https://github.com/apache/datafusion/pull/10693) (houqp) +- Make swap_hash_join public API [#10702](https://github.com/apache/datafusion/pull/10702) (viirya) +- ci: fix clippy error on main [#10723](https://github.com/apache/datafusion/pull/10723) (jonahgao) +- CI: Fix complaints from newer Clippy versions [#10725](https://github.com/apache/datafusion/pull/10725) (comphead) +- Remove Eager Trait for Joins [#10721](https://github.com/apache/datafusion/pull/10721) (berkaysynnada) +- Minor: fix signature `fn octect_length()` [#10726](https://github.com/apache/datafusion/pull/10726) (marvinlanhenke) +- Update rstest requirement from 0.19.0 to 0.20.0 [#10734](https://github.com/apache/datafusion/pull/10734) (dependabot[bot]) +- Update rstest_reuse requirement from 0.6.0 to 0.7.0 [#10733](https://github.com/apache/datafusion/pull/10733) (dependabot[bot]) +- Add example for building an external secondary index for parquet files [#10549](https://github.com/apache/datafusion/pull/10549) (alamb) +- Minor: move stddev test to slt [#10741](https://github.com/apache/datafusion/pull/10741) (marvinlanhenke) +- fix(CLI): can not create external tables with format options [#10739](https://github.com/apache/datafusion/pull/10739) (jonahgao) +- Add support for `AggregateExpr`, `WindowExpr` rewrite. [#10742](https://github.com/apache/datafusion/pull/10742) (mustafasrepo) +- Fix SMJ Left Anti Join when the join filter is set [#10724](https://github.com/apache/datafusion/pull/10724) (comphead) +- Introduce FunctionRegistry dependency to optimize and rewrite rule [#10714](https://github.com/apache/datafusion/pull/10714) (jayzhan211) +- Minor: Add SMJ to TPCH benchmark usage [#10747](https://github.com/apache/datafusion/pull/10747) (comphead) +- Minor: Split physical_plan/parquet/mod.rs into smaller modules [#10727](https://github.com/apache/datafusion/pull/10727) (alamb) +- minor: consolidate unparser integration tests [#10736](https://github.com/apache/datafusion/pull/10736) (devinjdangelo) +- Minor: Move aggregate variance to slt [#10750](https://github.com/apache/datafusion/pull/10750) (marvinlanhenke) +- Extract parquet statistics from timestamps with timezones [#10766](https://github.com/apache/datafusion/pull/10766) (xinlifoobar) +- Minor: Add tests for extracting dictionary parquet statistics [#10729](https://github.com/apache/datafusion/pull/10729) (alamb) +- Update rstest requirement from 0.20.0 to 0.21.0 [#10774](https://github.com/apache/datafusion/pull/10774) (dependabot[bot]) +- Minor: Refactor memory size estimation for HashTable [#10748](https://github.com/apache/datafusion/pull/10748) (marvinlanhenke) +- Reduce code repetition in `datafusion/functions` mod files [#10700](https://github.com/apache/datafusion/pull/10700) (MohamedAbdeen21) +- Support negatives in split part [#10780](https://github.com/apache/datafusion/pull/10780) (tshauck) +- Extract parquet statistics from `LargeUtf8` columns and Add tests for `UTF8` And `LargeUTF8` [#10762](https://github.com/apache/datafusion/pull/10762) (Weijun-H) +- Cleanup GetIndexedField [#10769](https://github.com/apache/datafusion/pull/10769) (lewiszlw) +- Extract parquet statistics from f16 columns, add `ScalarValue::Float16` [#10763](https://github.com/apache/datafusion/pull/10763) (Lordworms) +- Handle empty rows for `array_sort` [#10786](https://github.com/apache/datafusion/pull/10786) (jayzhan211) +- Fix extract parquet statistics from LargeBinary columns [#10775](https://github.com/apache/datafusion/pull/10775) (xinlifoobar) +- Extract parquet statistics from Time32 and Time64 columns [#10771](https://github.com/apache/datafusion/pull/10771) (Lordworms) +- chore: fix `last_value` coercion [#10783](https://github.com/apache/datafusion/pull/10783) (appletreeisyellow) +- Fix extract parquet statistics from Decimal256 columns [#10777](https://github.com/apache/datafusion/pull/10777) (xinlifoobar) +- Speed up arrow_statistics test [#10735](https://github.com/apache/datafusion/pull/10735) (alamb) +- minor: Refactor some unparser methods to improve readability [#10788](https://github.com/apache/datafusion/pull/10788) (devinjdangelo) +- Convert variance sample to udaf [#10713](https://github.com/apache/datafusion/pull/10713) (yyin-dev) +- Improve docs and fix a typo [#10798](https://github.com/apache/datafusion/pull/10798) (lewiszlw) +- Avoid the usage of intermediate ScalarValue to improve performance of extracting statistics from parquet files [#10711](https://github.com/apache/datafusion/pull/10711) (xinlifoobar) +- SMJ: Add more tests and improve comments [#10784](https://github.com/apache/datafusion/pull/10784) (comphead) +- Handle EmptyRelation during SQL unparsing [#10803](https://github.com/apache/datafusion/pull/10803) (goldmedal) +- Document Committer and PMC process [#10778](https://github.com/apache/datafusion/pull/10778) (alamb) +- Int64 as default type for make_array function empty or null case [#10790](https://github.com/apache/datafusion/pull/10790) (jayzhan211) +- Split `SessionState` into its own module [#10794](https://github.com/apache/datafusion/pull/10794) (alamb) +- Add `StreamProvider` for configuring `StreamTable` [#10600](https://github.com/apache/datafusion/pull/10600) (matthewmturner) +- Bench: Add `PREFER_HASH_JOIN` env variable [#10809](https://github.com/apache/datafusion/pull/10809) (comphead) +- Add `ParquetAccessPlan`, unify RowGroup selection and PagePruning selection [#10738](https://github.com/apache/datafusion/pull/10738) (alamb) +- Fix `ScalarUDFImpl::propagate_constraints` doc [#10810](https://github.com/apache/datafusion/pull/10810) (lewiszlw) +- Extract Parquet statistics from `Interval` column [#10801](https://github.com/apache/datafusion/pull/10801) (marvinlanhenke) +- build(deps): upgrade sqlparser to 0.47.0 [#10392](https://github.com/apache/datafusion/pull/10392) (tisonkun) +- Refactor and simplify the SQL unparser [#10811](https://github.com/apache/datafusion/pull/10811) (goldmedal) +- Minor: Remove code duplication in `memory_limit` derivation for datafusion-cli [#10814](https://github.com/apache/datafusion/pull/10814) (comphead) +- build(deps): update Arrow/Parquet to `52.0`, object-store to `0.10` [#10765](https://github.com/apache/datafusion/pull/10765) (waynexia) +- chore: Prepare 39.0.0-rc1 [#10828](https://github.com/apache/datafusion/pull/10828) (andygrove) + +## Credits + +Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. + +``` + 44 Andrew Lamb + 18 Jay Zhan + 14 张林伟 + 11 Andy Grove + 11 Xin Li + 10 Jonah Gao + 8 Jax Liu + 7 Mustafa Akur + 7 Oleks V + 7 dependabot[bot] + 5 Arttu + 5 Berkay Şahin + 5 Marvin Lanhenke + 4 Lordworms + 4 Ruihang Xia + 3 Bruce Ritchie + 3 Devin D'Angelo + 3 Duong Cong Toai + 3 Eduard Karacharov + 3 Junhao Liu + 3 Liang-Chi Hsieh + 3 Mohamed Abdeen + 3 Nga Tran + 3 Peter Toth + 3 Phillip LeBlanc + 2 Abrar Khan + 2 Adam Curtis + 2 Chunchun Ye + 2 Jeffrey Vo + 2 Michael Maletich + 2 QP Hou + 2 Trent Hauck + 2 Weijie Guo + 2 junxiangMu + 2 yfu + 1 Adrian Tanase + 1 Alex Huang + 1 Andrey Koshchiy + 1 Artem Medvedev + 1 ClSlaid + 1 Dan Harris + 1 Edmondo Porcu + 1 Jeffrey Smith II + 1 Kun Liu + 1 Leonardo Yvens + 1 Marko Milenković + 1 Matthew Turner + 1 Mehmet Ozan Kabak + 1 Michael J Ward + 1 NoeB + 1 Samuel Colvin + 1 Scott Anderson + 1 VimT + 1 Yue Yin + 1 baishen + 1 hsiang-c + 1 nathaniel-daniel + 1 shanretoo + 1 tison +``` + +Thank you also to everyone who contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. diff --git a/dev/changelog/40.0.0.md b/dev/changelog/40.0.0.md new file mode 100644 index 000000000000..72143ae48b28 --- /dev/null +++ b/dev/changelog/40.0.0.md @@ -0,0 +1,371 @@ + + +# Apache DataFusion 40.0.0 Changelog + +This release consists of 263 commits from 64 contributors. See credits at the end of this changelog for more information. + +**Breaking changes:** + +- Convert `StringAgg` to UDAF [#10945](https://github.com/apache/datafusion/pull/10945) (lewiszlw) +- Convert `bool_and` & `bool_or` to UDAF [#11009](https://github.com/apache/datafusion/pull/11009) (jcsherin) +- Convert Average to UDAF #10942 [#10964](https://github.com/apache/datafusion/pull/10964) (dharanad) +- fix: remove the Sized requirement on ExecutionPlan::name() [#11047](https://github.com/apache/datafusion/pull/11047) (waynexia) +- Return `&Arc` reference to inner trait object [#11103](https://github.com/apache/datafusion/pull/11103) (linhr) +- Support COPY TO Externally Defined File Formats, add FileType trait [#11060](https://github.com/apache/datafusion/pull/11060) (devinjdangelo) +- expose table name in proto extension codec [#11139](https://github.com/apache/datafusion/pull/11139) (leoyvens) +- fix(typo): unqualifed to unqualified [#11159](https://github.com/apache/datafusion/pull/11159) (waynexia) +- Consolidate `Filter::remove_aliases` into `Expr::unalias_nested` [#11001](https://github.com/apache/datafusion/pull/11001) (alamb) +- Convert `nth_value` to UDAF [#11287](https://github.com/apache/datafusion/pull/11287) (jcsherin) + +**Implemented enhancements:** + +- feat: Add support for Int8 and Int16 data types in data page statistics [#10931](https://github.com/apache/datafusion/pull/10931) (Weijun-H) +- feat: add CliSessionContext trait for cli [#10890](https://github.com/apache/datafusion/pull/10890) (tshauck) +- feat(optimizer): handle partial anchored regex cases and improve doc [#10977](https://github.com/apache/datafusion/pull/10977) (waynexia) +- feat: support uint data page extraction [#11018](https://github.com/apache/datafusion/pull/11018) (tshauck) +- feat: propagate EmptyRelation for more join types [#10963](https://github.com/apache/datafusion/pull/10963) (tshauck) +- feat: Add method to add analyzer rules to SessionContext [#10849](https://github.com/apache/datafusion/pull/10849) (pingsutw) +- feat: Support duplicate column names in Joins in Substrait consumer [#11049](https://github.com/apache/datafusion/pull/11049) (Blizzara) +- feat: Add support for Timestamp data types in data page statistics. [#11123](https://github.com/apache/datafusion/pull/11123) (efredine) +- feat: Add support for `Binary`/`LargeBinary`/`Utf8`/`LargeUtf8` data types in data page statistics [#11136](https://github.com/apache/datafusion/pull/11136) (PsiACE) +- feat: Support Map type in Substrait conversions [#11129](https://github.com/apache/datafusion/pull/11129) (Blizzara) +- feat: Conditionally allow to keep partition_by columns when using PARTITIONED BY enhancement [#11107](https://github.com/apache/datafusion/pull/11107) (hveiga) +- feat: enable "substring" as a UDF in addition to "substr" [#11277](https://github.com/apache/datafusion/pull/11277) (Blizzara) + +**Fixed bugs:** + +- fix: use total ordering in the min & max accumulator for floats [#10627](https://github.com/apache/datafusion/pull/10627) (westonpace) +- fix: Support double quotes in `date_part` [#10833](https://github.com/apache/datafusion/pull/10833) (Weijun-H) +- fix: Ignore nullability of list elements when consuming Substrait [#10874](https://github.com/apache/datafusion/pull/10874) (Blizzara) +- fix: Support `NOT IN ()` via anti join [#10936](https://github.com/apache/datafusion/pull/10936) (akoshchiy) +- fix: CTEs defined in a subquery can escape their scope [#10954](https://github.com/apache/datafusion/pull/10954) (jonahgao) +- fix: Fix the incorrect null joined rows for SMJ outer join with join filter [#10892](https://github.com/apache/datafusion/pull/10892) (viirya) +- fix: gcd returns negative results [#11099](https://github.com/apache/datafusion/pull/11099) (jonahgao) +- fix: LCM panicked due to overflow [#11131](https://github.com/apache/datafusion/pull/11131) (jonahgao) +- fix: Support dictionary type in parquet metadata statistics. [#11169](https://github.com/apache/datafusion/pull/11169) (efredine) +- fix: Ignore nullability in Substrait structs [#11130](https://github.com/apache/datafusion/pull/11130) (Blizzara) +- fix: typo in comment about FinalPhysicalPlan [#11181](https://github.com/apache/datafusion/pull/11181) (c8ef) +- fix: Support Substrait's compound names also for window functions [#11163](https://github.com/apache/datafusion/pull/11163) (Blizzara) +- fix: Incorrect LEFT JOIN evaluation result on OR conditions [#11203](https://github.com/apache/datafusion/pull/11203) (viirya) +- fix: Be more lenient in interpreting input args for builtin window functions [#11199](https://github.com/apache/datafusion/pull/11199) (Blizzara) +- fix: correctly handle Substrait windows with rows bounds (and validate executability of test plans) [#11278](https://github.com/apache/datafusion/pull/11278) (Blizzara) +- fix: When consuming Substrait, temporarily rename clashing duplicate columns [#11329](https://github.com/apache/datafusion/pull/11329) (Blizzara) + +**Documentation updates:** + +- Minor: Clarify `SessionContext::state` docs [#10847](https://github.com/apache/datafusion/pull/10847) (alamb) +- Minor: Update SIGMOD paper reference url [#10860](https://github.com/apache/datafusion/pull/10860) (alamb) +- docs(variance): Correct typos in comments [#10844](https://github.com/apache/datafusion/pull/10844) (pingsutw) +- Add missing code close tick in LiteralGuarantee docs [#10859](https://github.com/apache/datafusion/pull/10859) (adriangb) +- Minor: Add more docs and examples for `Transformed` and `TransformedResult` [#11003](https://github.com/apache/datafusion/pull/11003) (alamb) +- doc: Update links in the documantation [#11044](https://github.com/apache/datafusion/pull/11044) (Weijun-H) +- Minor: Examples cleanup + more docs in pruning example [#11086](https://github.com/apache/datafusion/pull/11086) (alamb) +- Minor: refine documentation pointing to examples [#11110](https://github.com/apache/datafusion/pull/11110) (alamb) +- Fix running in Docker instructions [#11141](https://github.com/apache/datafusion/pull/11141) (findepi) +- docs: add example for custom file format with `COPY TO` [#11174](https://github.com/apache/datafusion/pull/11174) (tshauck) +- Fix docs wordings [#11226](https://github.com/apache/datafusion/pull/11226) (findepi) +- Fix count() docs around including null values [#11293](https://github.com/apache/datafusion/pull/11293) (findepi) + +**Other:** + +- chore: Prepare 39.0.0-rc1 [#10828](https://github.com/apache/datafusion/pull/10828) (andygrove) +- Remove expr_fn::sum and replace them with function stub [#10816](https://github.com/apache/datafusion/pull/10816) (jayzhan211) +- Debug print as many fields as possible for `SessionState` [#10818](https://github.com/apache/datafusion/pull/10818) (lewiszlw) +- Prune Parquet RowGroup in a single call to `PruningPredicate::prune`, update StatisticsExtractor API [#10802](https://github.com/apache/datafusion/pull/10802) (alamb) +- Remove Built-in sum and Rename to lowercase `sum` [#10831](https://github.com/apache/datafusion/pull/10831) (jayzhan211) +- Convert `stddev` and `stddev_pop` to UDAF [#10834](https://github.com/apache/datafusion/pull/10834) (goldmedal) +- Introduce expr builder for aggregate function [#10560](https://github.com/apache/datafusion/pull/10560) (jayzhan211) +- chore: Improve change log generator [#10841](https://github.com/apache/datafusion/pull/10841) (andygrove) +- Support user defined `ParquetAccessPlan` in `ParquetExec`, validation to `ParquetAccessPlan::select` [#10813](https://github.com/apache/datafusion/pull/10813) (alamb) +- Convert `VariancePopulation` to UDAF [#10836](https://github.com/apache/datafusion/pull/10836) (mknaw) +- Convert `approx_median` to UDAF [#10840](https://github.com/apache/datafusion/pull/10840) (goldmedal) +- MINOR: use workspace deps in proto-common (upgrade object store dependency) [#10848](https://github.com/apache/datafusion/pull/10848) (waynexia) +- Minor: add `Window::try_new_with_schema` constructor [#10850](https://github.com/apache/datafusion/pull/10850) (sadboy) +- Add support for reading CSV files with comments [#10467](https://github.com/apache/datafusion/pull/10467) (bbannier) +- Convert approx_distinct to UDAF [#10851](https://github.com/apache/datafusion/pull/10851) (Lordworms) +- minor: add proto-common crate to release instructions [#10858](https://github.com/apache/datafusion/pull/10858) (andygrove) +- Implement TPCH substrait integration teset, support tpch_1 [#10842](https://github.com/apache/datafusion/pull/10842) (Lordworms) +- Remove unecessary passing around of `suffix: &str` in `pruning.rs`'s `RequiredColumns` [#10863](https://github.com/apache/datafusion/pull/10863) (adriangb) +- chore: Make DFSchema::datatype_is_logically_equal function public [#10867](https://github.com/apache/datafusion/pull/10867) (advancedxy) +- Bump braces from 3.0.2 to 3.0.3 in /datafusion/wasmtest/datafusion-wasm-app [#10865](https://github.com/apache/datafusion/pull/10865) (dependabot[bot]) +- Docs: Add `unnest` to SQL Reference [#10839](https://github.com/apache/datafusion/pull/10839) (gloomweaver) +- Support correct output column names and struct field names when consuming/producing Substrait [#10829](https://github.com/apache/datafusion/pull/10829) (Blizzara) +- Make Logical Plans more readable by removing extra aliases [#10832](https://github.com/apache/datafusion/pull/10832) (MohamedAbdeen21) +- Minor: Improve `ListingTable` documentation [#10854](https://github.com/apache/datafusion/pull/10854) (alamb) +- Extending join fuzz tests to support join filtering [#10728](https://github.com/apache/datafusion/pull/10728) (edmondop) +- replace and(_, not(_)) with and_not(\*) [#10885](https://github.com/apache/datafusion/pull/10885) (RTEnzyme) +- Disabling test for semi join with filters [#10887](https://github.com/apache/datafusion/pull/10887) (edmondop) +- Minor: Update `min_statistics` and `max_statistics` to be helpers, update docs [#10866](https://github.com/apache/datafusion/pull/10866) (alamb) +- Remove `Interval` column test // parquet extraction [#10888](https://github.com/apache/datafusion/pull/10888) (marvinlanhenke) +- Minor: SMJ fuzz tests fix for rowcounts [#10891](https://github.com/apache/datafusion/pull/10891) (comphead) +- Move `Count` to `functions-aggregate`, update MSRV to rust 1.75 [#10484](https://github.com/apache/datafusion/pull/10484) (jayzhan211) +- refactor: fetch statistics for a given ParquetMetaData [#10880](https://github.com/apache/datafusion/pull/10880) (NGA-TRAN) +- Move FileSinkExec::metrics to the correct place [#10901](https://github.com/apache/datafusion/pull/10901) (joroKr21) +- Refine ParquetAccessPlan comments and tests [#10896](https://github.com/apache/datafusion/pull/10896) (alamb) +- ci: fix clippy failures on main [#10903](https://github.com/apache/datafusion/pull/10903) (jonahgao) +- Minor: disable flaky fuzz test [#10904](https://github.com/apache/datafusion/pull/10904) (comphead) +- Remove builtin count [#10893](https://github.com/apache/datafusion/pull/10893) (jayzhan211) +- Move Regr\_\* functions to use UDAF [#10898](https://github.com/apache/datafusion/pull/10898) (eejbyfeldt) +- Docs: clarify when the parquet reader will read from object store when using cached metadata [#10909](https://github.com/apache/datafusion/pull/10909) (alamb) +- Minor: Fix `bench.sh tpch data` [#10905](https://github.com/apache/datafusion/pull/10905) (alamb) +- Minor: use venv in benchmark compare [#10894](https://github.com/apache/datafusion/pull/10894) (tmi) +- Support explicit type and name during table creation [#10273](https://github.com/apache/datafusion/pull/10273) (duongcongtoai) +- Simplify Join Partition Rules [#10911](https://github.com/apache/datafusion/pull/10911) (berkaysynnada) +- Move `Literal` to `physical-expr-common` [#10910](https://github.com/apache/datafusion/pull/10910) (lewiszlw) +- chore: update some error messages for clarity [#10916](https://github.com/apache/datafusion/pull/10916) (jeffreyssmith2nd) +- Initial Extract parquet data page statistics API [#10852](https://github.com/apache/datafusion/pull/10852) (marvinlanhenke) +- Add contains function, and support in datafusion substrait consumer [#10879](https://github.com/apache/datafusion/pull/10879) (Lordworms) +- Minor: Improve `arrow_statistics` tests [#10927](https://github.com/apache/datafusion/pull/10927) (alamb) +- Minor: Remove `prefer_hash_join` env variable for clickbench [#10933](https://github.com/apache/datafusion/pull/10933) (jayzhan211) +- Convert ApproxPercentileCont and ApproxPercentileContWithWeight to UDAF [#10917](https://github.com/apache/datafusion/pull/10917) (goldmedal) +- refactor: remove extra default in max rows [#10941](https://github.com/apache/datafusion/pull/10941) (tshauck) +- chore: Improve performance of Parquet statistics conversion [#10932](https://github.com/apache/datafusion/pull/10932) (Weijun-H) +- Add catalog::resolve_table_references [#10876](https://github.com/apache/datafusion/pull/10876) (leoyvens) +- Convert BitAnd, BitOr, BitXor to UDAF [#10930](https://github.com/apache/datafusion/pull/10930) (dharanad) +- refactor: improve PoolType argument handling for CLI [#10940](https://github.com/apache/datafusion/pull/10940) (tshauck) +- Minor: remove potential string copy from Column::from_qualified_name [#10947](https://github.com/apache/datafusion/pull/10947) (alamb) +- Fix: StatisticsConverter `counts` for missing columns [#10946](https://github.com/apache/datafusion/pull/10946) (marvinlanhenke) +- Add initial support for Utf8View and BinaryView types [#10925](https://github.com/apache/datafusion/pull/10925) (XiangpengHao) +- Use shorter aliases in CSE [#10939](https://github.com/apache/datafusion/pull/10939) (peter-toth) +- Substrait support for ParquetExec round trip for simple select [#10949](https://github.com/apache/datafusion/pull/10949) (xinlifoobar) +- Support to unparse `ScalarValue::IntervalMonthDayNano` to String [#10956](https://github.com/apache/datafusion/pull/10956) (goldmedal) +- Minor: Return option from row_group_row_count [#10973](https://github.com/apache/datafusion/pull/10973) (marvinlanhenke) +- Minor: Add routine to debug join fuzz tests [#10970](https://github.com/apache/datafusion/pull/10970) (comphead) +- Support to unparse `ScalarValue::TimestampNanosecond` to String [#10984](https://github.com/apache/datafusion/pull/10984) (goldmedal) +- build(deps-dev): bump ws from 8.14.2 to 8.17.1 in /datafusion/wasmtest/datafusion-wasm-app [#10988](https://github.com/apache/datafusion/pull/10988) (dependabot[bot]) +- Minor: reuse Rows buffer in GroupValuesRows [#10980](https://github.com/apache/datafusion/pull/10980) (alamb) +- Add example for writing SQL analysis using DataFusion structures [#10938](https://github.com/apache/datafusion/pull/10938) (LorrensP-2158466) +- Push down filter for Unnest plan [#10974](https://github.com/apache/datafusion/pull/10974) (jayzhan211) +- Add parquet page stats for float{16, 32, 64} [#10982](https://github.com/apache/datafusion/pull/10982) (tmi) +- Fix `file_stream_provider` example compilation failure on windows [#10975](https://github.com/apache/datafusion/pull/10975) (lewiszlw) +- Stop copying LogicalPlan and Exprs in `CommonSubexprEliminate` (2-3% planning speed improvement) [#10835](https://github.com/apache/datafusion/pull/10835) (alamb) +- chore: Update documentation link in `PhysicalOptimizerRule` comment [#11002](https://github.com/apache/datafusion/pull/11002) (Weijun-H) +- Push down filter plan for unnest on non-unnest column only [#10991](https://github.com/apache/datafusion/pull/10991) (jayzhan211) +- Minor: add test for pushdown past unnest [#11017](https://github.com/apache/datafusion/pull/11017) (alamb) +- Update docs for `protoc` minimum installed version [#11006](https://github.com/apache/datafusion/pull/11006) (jcsherin) +- propagate error instead of panicking on out of bounds in physical-expr/src/analysis.rs [#10992](https://github.com/apache/datafusion/pull/10992) (LorrensP-2158466) +- Add drop_columns to dataframe api [#11010](https://github.com/apache/datafusion/pull/11010) (Omega359) +- Push down filter plan for non-unnest column [#11019](https://github.com/apache/datafusion/pull/11019) (jayzhan211) +- Consider timezones with `UTC` and `+00:00` to be the same [#10960](https://github.com/apache/datafusion/pull/10960) (marvinlanhenke) +- Deprecate `OptimizerRule::try_optimize` [#11022](https://github.com/apache/datafusion/pull/11022) (lewiszlw) +- Relax combine partial final rule [#10913](https://github.com/apache/datafusion/pull/10913) (mustafasrepo) +- Compute gcd with u64 instead of i64 because of overflows [#11036](https://github.com/apache/datafusion/pull/11036) (LorrensP-2158466) +- Add distinct_on to dataframe api [#11012](https://github.com/apache/datafusion/pull/11012) (Omega359) +- chore: add test to show current behavior of `AT TIME ZONE` for string vs. timestamp [#11056](https://github.com/apache/datafusion/pull/11056) (appletreeisyellow) +- Boolean parquet get datapage stat [#11054](https://github.com/apache/datafusion/pull/11054) (LorrensP-2158466) +- Using display_name for Expr::Aggregation [#11020](https://github.com/apache/datafusion/pull/11020) (Lordworms) +- Minor: Convert `Count`'s name to lowercase [#11028](https://github.com/apache/datafusion/pull/11028) (jayzhan211) +- Minor: Move `function::Hint` to `datafusion-expr` crate to avoid physical-expr dependency for `datafusion-function` crate [#11061](https://github.com/apache/datafusion/pull/11061) (jayzhan211) +- Support to unparse ScalarValue::TimestampMillisecond to String [#11046](https://github.com/apache/datafusion/pull/11046) (pingsutw) +- Support to unparse IntervalYearMonth and IntervalDayTime to String [#11065](https://github.com/apache/datafusion/pull/11065) (goldmedal) +- SMJ: fix streaming row concurrency issue for LEFT SEMI filtered join [#11041](https://github.com/apache/datafusion/pull/11041) (comphead) +- Add `advanced_parquet_index.rs` example of index in into parquet files [#10701](https://github.com/apache/datafusion/pull/10701) (alamb) +- Add Expr::column_refs to find column references without copying [#10948](https://github.com/apache/datafusion/pull/10948) (alamb) +- Give `OptimizerRule::try_optimize` default implementation and cleanup duplicated custom implementations [#11059](https://github.com/apache/datafusion/pull/11059) (lewiszlw) +- Fix `FormatOptions::CSV` propagation [#10912](https://github.com/apache/datafusion/pull/10912) (svranesevic) +- Support parsing SQL strings to Exprs [#10995](https://github.com/apache/datafusion/pull/10995) (xinlifoobar) +- Support dictionary data type in array_to_string [#10908](https://github.com/apache/datafusion/pull/10908) (EduardoVega) +- Implement min/max for interval types [#11015](https://github.com/apache/datafusion/pull/11015) (maxburke) +- Improve LIKE performance for Dictionary arrays [#11058](https://github.com/apache/datafusion/pull/11058) (Lordworms) +- handle overflow in gcd and return this as an error [#11057](https://github.com/apache/datafusion/pull/11057) (LorrensP-2158466) +- Convert Correlation to UDAF [#11064](https://github.com/apache/datafusion/pull/11064) (pingsutw) +- Migrate more code from `Expr::to_columns` to `Expr::column_refs` [#11067](https://github.com/apache/datafusion/pull/11067) (alamb) +- decimal support for unparser [#11092](https://github.com/apache/datafusion/pull/11092) (y-f-u) +- Improve `CommonSubexprEliminate` identifier management (10% faster planning) [#10473](https://github.com/apache/datafusion/pull/10473) (peter-toth) +- Change wildcard qualifier type from `String` to `TableReference` [#11073](https://github.com/apache/datafusion/pull/11073) (linhr) +- Allow access to UDTF in `SessionContext` [#11071](https://github.com/apache/datafusion/pull/11071) (linhr) +- Strip table qualifiers from schema in `UNION ALL` for unparser [#11082](https://github.com/apache/datafusion/pull/11082) (phillipleblanc) +- Update ListingTable to use StatisticsConverter [#11068](https://github.com/apache/datafusion/pull/11068) (xinlifoobar) +- to_timestamp functions should preserve timezone [#11038](https://github.com/apache/datafusion/pull/11038) (maxburke) +- Rewrite array operator to function in parser [#11101](https://github.com/apache/datafusion/pull/11101) (jayzhan211) +- Resolve empty relation opt for join types [#11066](https://github.com/apache/datafusion/pull/11066) (LorrensP-2158466) +- Add composed extension codec example [#11095](https://github.com/apache/datafusion/pull/11095) (lewiszlw) +- Minor: Avoid some repetition in to_timestamp [#11116](https://github.com/apache/datafusion/pull/11116) (alamb) +- Minor: fix ScalarValue::new_ten error message (cites one not ten) [#11126](https://github.com/apache/datafusion/pull/11126) (gstvg) +- Deprecate Expr::column_refs [#11115](https://github.com/apache/datafusion/pull/11115) (alamb) +- Overflow in negate operator [#11084](https://github.com/apache/datafusion/pull/11084) (LorrensP-2158466) +- Minor: Add Architectural Goals to the docs [#11109](https://github.com/apache/datafusion/pull/11109) (alamb) +- Fix overflow in pow [#11124](https://github.com/apache/datafusion/pull/11124) (LorrensP-2158466) +- Support to unparse Time scalar value to String [#11121](https://github.com/apache/datafusion/pull/11121) (goldmedal) +- Support to unparse `TimestampSecond` and `TimestampMicrosecond` to String [#11120](https://github.com/apache/datafusion/pull/11120) (goldmedal) +- Add standalone example for `OptimizerRule` [#11087](https://github.com/apache/datafusion/pull/11087) (alamb) +- Fix overflow in factorial [#11134](https://github.com/apache/datafusion/pull/11134) (LorrensP-2158466) +- Temporary Fix: Query error when grouping by case expressions [#11133](https://github.com/apache/datafusion/pull/11133) (jonahgao) +- Fix nullability of return value of array_agg [#11093](https://github.com/apache/datafusion/pull/11093) (eejbyfeldt) +- Support filter for List [#11091](https://github.com/apache/datafusion/pull/11091) (jayzhan211) +- [MINOR]: Fix some minor silent bugs [#11127](https://github.com/apache/datafusion/pull/11127) (mustafasrepo) +- Minor Fix for Logical and Physical Expr Conversions [#11142](https://github.com/apache/datafusion/pull/11142) (berkaysynnada) +- Support Date Parquet Data Page Statistics [#11135](https://github.com/apache/datafusion/pull/11135) (dharanad) +- fix flaky array query slt test [#11140](https://github.com/apache/datafusion/pull/11140) (leoyvens) +- Support Decimal and Decimal256 Parquet Data Page Statistics [#11138](https://github.com/apache/datafusion/pull/11138) (Lordworms) +- Implement comparisons on nested data types such that distinct/except would work [#11117](https://github.com/apache/datafusion/pull/11117) (rtyler) +- Minor: dont panic with bad arguments to round [#10899](https://github.com/apache/datafusion/pull/10899) (tmi) +- Minor: reduce replication for nested comparison [#11149](https://github.com/apache/datafusion/pull/11149) (alamb) +- [Minor]: Remove datafusion-functions-aggregate dependency from physical-expr crate [#11158](https://github.com/apache/datafusion/pull/11158) (mustafasrepo) +- adding config to control Varchar behavior [#11090](https://github.com/apache/datafusion/pull/11090) (Lordworms) +- minor: consolidate `gcd` related tests [#11164](https://github.com/apache/datafusion/pull/11164) (jonahgao) +- Minor: move batch spilling methods to `lib.rs` to make it reusable [#11154](https://github.com/apache/datafusion/pull/11154) (comphead) +- Move schema projection to where it's used in ListingTable [#11167](https://github.com/apache/datafusion/pull/11167) (adriangb) +- Make running in docker instruction be copy-pastable [#11148](https://github.com/apache/datafusion/pull/11148) (findepi) +- Rewrite `array @> array` and `array <@ array` in sql_expr_to_logical_expr [#11155](https://github.com/apache/datafusion/pull/11155) (jayzhan211) +- Minor: make some physical_optimizer rules public [#11171](https://github.com/apache/datafusion/pull/11171) (askalt) +- Remove pr_benchmarks.yml [#11165](https://github.com/apache/datafusion/pull/11165) (alamb) +- Optionally display schema in explain plan [#11177](https://github.com/apache/datafusion/pull/11177) (alamb) +- Minor: Add more support for ScalarValue::Float16 [#11156](https://github.com/apache/datafusion/pull/11156) (Lordworms) +- Minor: fix SQLOptions::with_allow_ddl comments [#11166](https://github.com/apache/datafusion/pull/11166) (alamb) +- Update sqllogictest requirement from 0.20.0 to 0.21.0 [#11189](https://github.com/apache/datafusion/pull/11189) (dependabot[bot]) +- Support Time Parquet Data Page Statistics [#11187](https://github.com/apache/datafusion/pull/11187) (dharanad) +- Adds support for Dictionary data type statistics from parquet data pages. [#11195](https://github.com/apache/datafusion/pull/11195) (efredine) +- [Minor]: Make sort_batch public [#11191](https://github.com/apache/datafusion/pull/11191) (mustafasrepo) +- Introduce user defined SQL planner API [#11180](https://github.com/apache/datafusion/pull/11180) (jayzhan211) +- Covert grouping to udaf [#11147](https://github.com/apache/datafusion/pull/11147) (Rachelint) +- Make statistics_from_parquet_meta a sync function [#11205](https://github.com/apache/datafusion/pull/11205) (adriangb) +- Allow user defined SQL planners to be registered [#11208](https://github.com/apache/datafusion/pull/11208) (samuelcolvin) +- Recursive `unnest` [#11062](https://github.com/apache/datafusion/pull/11062) (duongcongtoai) +- Document how to test examples in user guide, add some more coverage [#11178](https://github.com/apache/datafusion/pull/11178) (alamb) +- Minor: Move MemoryCatalog\*Provider into a module, improve comments [#11183](https://github.com/apache/datafusion/pull/11183) (alamb) +- Add standalone example of using the SQL frontend [#11088](https://github.com/apache/datafusion/pull/11088) (alamb) +- Add Optimizer Sanity Checker, improve sortedness equivalence properties [#11196](https://github.com/apache/datafusion/pull/11196) (mustafasrepo) +- Implement user defined planner for extract [#11215](https://github.com/apache/datafusion/pull/11215) (xinlifoobar) +- Move basic SQL query examples to user guide [#11217](https://github.com/apache/datafusion/pull/11217) (alamb) +- Support FixedSizedBinaryArray Parquet Data Page Statistics [#11200](https://github.com/apache/datafusion/pull/11200) (dharanad) +- Implement ScalarValue::Map [#11224](https://github.com/apache/datafusion/pull/11224) (goldmedal) +- Remove unmaintained python pre-commit configuration [#11255](https://github.com/apache/datafusion/pull/11255) (findepi) +- Enable `clone_on_ref_ptr` clippy lint on execution crate [#11239](https://github.com/apache/datafusion/pull/11239) (lewiszlw) +- Minor: Improve documentation about pushdown join predicates [#11209](https://github.com/apache/datafusion/pull/11209) (alamb) +- Minor: clean up data page statistics tests and fix bugs [#11236](https://github.com/apache/datafusion/pull/11236) (efredine) +- Replacing pattern matching through downcast with trait method [#11257](https://github.com/apache/datafusion/pull/11257) (edmondop) +- Update substrait requirement from 0.34.0 to 0.35.0 [#11206](https://github.com/apache/datafusion/pull/11206) (dependabot[bot]) +- Enhance short circuit handling in `CommonSubexprEliminate` [#11197](https://github.com/apache/datafusion/pull/11197) (peter-toth) +- Add bench for data page statistics parquet extraction [#10950](https://github.com/apache/datafusion/pull/10950) (marvinlanhenke) +- Register SQL planners in `SessionState` constructor [#11253](https://github.com/apache/datafusion/pull/11253) (dharanad) +- Support DuckDB style struct syntax [#11214](https://github.com/apache/datafusion/pull/11214) (jayzhan211) +- Enable `clone_on_ref_ptr` clippy lint on expr crate [#11238](https://github.com/apache/datafusion/pull/11238) (lewiszlw) +- Optimize PushDownFilter to avoid recreating schema columns [#11211](https://github.com/apache/datafusion/pull/11211) (alamb) +- Remove outdated `rewrite_expr.rs` example [#11085](https://github.com/apache/datafusion/pull/11085) (alamb) +- Implement TPCH substrait integration teset, support tpch_2 [#11234](https://github.com/apache/datafusion/pull/11234) (Lordworms) +- Enable `clone_on_ref_ptr` clippy lint on physical-expr crate [#11240](https://github.com/apache/datafusion/pull/11240) (lewiszlw) +- Add standalone `AnalyzerRule` example that implements row level access control [#11089](https://github.com/apache/datafusion/pull/11089) (alamb) +- Replace println! with assert! if possible in DataFusion examples [#11237](https://github.com/apache/datafusion/pull/11237) (Nishi46) +- minor: format `Expr::get_type()` [#11267](https://github.com/apache/datafusion/pull/11267) (jonahgao) +- Fix hash join for nested types [#11232](https://github.com/apache/datafusion/pull/11232) (eejbyfeldt) +- Infer count() aggregation is not null [#11256](https://github.com/apache/datafusion/pull/11256) (findepi) +- Remove unnecessary qualified names [#11292](https://github.com/apache/datafusion/pull/11292) (findepi) +- Fix running examples readme [#11225](https://github.com/apache/datafusion/pull/11225) (findepi) +- Minor: Add `ConstExpr::from` and use in physical optimizer [#11283](https://github.com/apache/datafusion/pull/11283) (alamb) +- Implement TPCH substrait integration teset, support tpch_3 [#11298](https://github.com/apache/datafusion/pull/11298) (Lordworms) +- Implement user defined planner for position [#11243](https://github.com/apache/datafusion/pull/11243) (xinlifoobar) +- Upgrade to arrow 52.1.0 (and fix clippy issues on main) [#11302](https://github.com/apache/datafusion/pull/11302) (alamb) +- AggregateExec: Take grouping sets into account for InputOrderMode [#11301](https://github.com/apache/datafusion/pull/11301) (thinkharderdev) +- Add user_defined_sql_planners(..) to FunctionRegistry [#11296](https://github.com/apache/datafusion/pull/11296) (Omega359) +- use safe cast in propagate_constraints [#11297](https://github.com/apache/datafusion/pull/11297) (Lordworms) +- Minor: Remove clone in optimizer [#11315](https://github.com/apache/datafusion/pull/11315) (jayzhan211) +- minor: Add `PhysicalSortExpr::new` [#11310](https://github.com/apache/datafusion/pull/11310) (andygrove) +- Fix data page statistics when all rows are null in a data page [#11295](https://github.com/apache/datafusion/pull/11295) (efredine) +- Made UserDefinedFunctionPlanner to uniform the usages [#11318](https://github.com/apache/datafusion/pull/11318) (xinlifoobar) +- Implement user defined planner for `create_struct` & `create_named_struct` [#11273](https://github.com/apache/datafusion/pull/11273) (dharanad) +- Improve stats convert performance for Binary/String/Boolean arrays [#11319](https://github.com/apache/datafusion/pull/11319) (Rachelint) +- Fix typos in datafusion-examples/datafusion-cli/docs [#11259](https://github.com/apache/datafusion/pull/11259) (lewiszlw) +- Minor: Fix Failing TPC-DS Test [#11331](https://github.com/apache/datafusion/pull/11331) (berkaysynnada) +- HashJoin can preserve the right ordering when join type is Right [#11276](https://github.com/apache/datafusion/pull/11276) (berkaysynnada) +- Update substrait requirement from 0.35.0 to 0.36.0 [#11328](https://github.com/apache/datafusion/pull/11328) (dependabot[bot]) +- Support to uparse logical plans with timestamp cast to string [#11326](https://github.com/apache/datafusion/pull/11326) (sgrebnov) +- Implement user defined planner for sql_substring_to_expr [#11327](https://github.com/apache/datafusion/pull/11327) (xinlifoobar) +- Improve volatile expression handling in `CommonSubexprEliminate` [#11265](https://github.com/apache/datafusion/pull/11265) (peter-toth) +- Support `IS NULL` and `IS NOT NULL` on Unions [#11321](https://github.com/apache/datafusion/pull/11321) (samuelcolvin) +- Implement TPCH substrait integration test, support tpch_4 and tpch_5 [#11311](https://github.com/apache/datafusion/pull/11311) (Lordworms) +- Enable `clone_on_ref_ptr` clippy lint on physical-plan crate [#11241](https://github.com/apache/datafusion/pull/11241) (lewiszlw) +- Remove any aliases in `Filter::try_new` rather than erroring [#11307](https://github.com/apache/datafusion/pull/11307) (samuelcolvin) +- Improve `DataFrame` Users Guide [#11324](https://github.com/apache/datafusion/pull/11324) (alamb) +- chore: Rename UserDefinedSQLPlanner to ExprPlanner [#11338](https://github.com/apache/datafusion/pull/11338) (andygrove) +- Revert "remove `derive(Copy)` from `Operator` (#11132)" [#11341](https://github.com/apache/datafusion/pull/11341) (alamb) + +## Credits + +Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. + +``` + 41 Andrew Lamb + 17 Jay Zhan + 12 Lordworms + 12 张林伟 + 10 Arttu + 9 Jax Liu + 9 Lorrens Pantelis + 8 Piotr Findeisen + 7 Dharan Aditya + 7 Jonah Gao + 7 Xin Li + 6 Andy Grove + 6 Marvin Lanhenke + 6 Trent Hauck + 5 Alex Huang + 5 Eric Fredine + 5 Mustafa Akur + 5 Oleks V + 5 dependabot[bot] + 4 Adrian Garcia Badaracco + 4 Berkay Şahin + 4 Kevin Su + 4 Peter Toth + 4 Ruihang Xia + 4 Samuel Colvin + 3 Bruce Ritchie + 3 Edmondo Porcu + 3 Emil Ejbyfeldt + 3 Heran Lin + 3 Leonardo Yvens + 3 jcsherin + 3 tmi + 2 Duong Cong Toai + 2 Liang-Chi Hsieh + 2 Max Burke + 2 kamille + 1 Albert Skalt + 1 Andrey Koshchiy + 1 Benjamin Bannier + 1 Bo Lin + 1 Chojan Shang + 1 Chunchun Ye + 1 Dan Harris + 1 Devin D'Angelo + 1 Eduardo Vega + 1 Georgi Krastev + 1 Hector Veiga + 1 Jeffrey Smith II + 1 Kirill Khramkov + 1 Matt Nawara + 1 Mohamed Abdeen + 1 Nga Tran + 1 Nishi + 1 Phillip LeBlanc + 1 R. Tyler Croy + 1 RT_Enzyme + 1 Sava Vranešević + 1 Sergei Grebnov + 1 Weston Pace + 1 Xiangpeng Hao + 1 advancedxy + 1 c8ef + 1 gstvg + 1 yfu +``` + +Thank you also to everyone who contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. diff --git a/dev/changelog/41.0.0.md b/dev/changelog/41.0.0.md new file mode 100644 index 000000000000..3e289112c7bb --- /dev/null +++ b/dev/changelog/41.0.0.md @@ -0,0 +1,363 @@ + + +# Apache DataFusion 41.0.0 Changelog + +This release consists of 245 commits from 69 contributors. See credits at the end of this changelog for more information. + +**Breaking changes:** + +- make unparser `Dialect` trait `Send` + `Sync` [#11504](https://github.com/apache/datafusion/pull/11504) (y-f-u) +- Implement physical plan serialization for csv COPY plans , add `as_any`, `Debug` to `FileFormatFactory` [#11588](https://github.com/apache/datafusion/pull/11588) (Lordworms) +- Consistent API to set parameters of aggregate and window functions (`AggregateExt` --> `ExprFunctionExt`) [#11550](https://github.com/apache/datafusion/pull/11550) (timsaucer) +- Rename `ColumnOptions` to `ParquetColumnOptions` [#11512](https://github.com/apache/datafusion/pull/11512) (alamb) +- Rename `input_type` --> `input_types` on AggregateFunctionExpr / AccumulatorArgs / StateFieldsArgs [#11666](https://github.com/apache/datafusion/pull/11666) (lewiszlw) +- Rename RepartitionExec metric `repart_time` to `repartition_time` [#11703](https://github.com/apache/datafusion/pull/11703) (alamb) +- Remove `AggregateFunctionDefinition` [#11803](https://github.com/apache/datafusion/pull/11803) (lewiszlw) +- Skipping partial aggregation when it is not helping for high cardinality aggregates [#11627](https://github.com/apache/datafusion/pull/11627) (korowa) +- Optionally create name of aggregate expression from expressions [#11776](https://github.com/apache/datafusion/pull/11776) (lewiszlw) + +**Performance related:** + +- feat: Optimize CASE expression for "column or null" use case [#11534](https://github.com/apache/datafusion/pull/11534) (andygrove) +- feat: Optimize CASE expression for usage where then and else values are literals [#11553](https://github.com/apache/datafusion/pull/11553) (andygrove) +- perf: Optimize IsNotNullExpr [#11586](https://github.com/apache/datafusion/pull/11586) (andygrove) + +**Implemented enhancements:** + +- feat: Add `fail_on_overflow` option to `BinaryExpr` [#11400](https://github.com/apache/datafusion/pull/11400) (andygrove) +- feat: add UDF to_local_time() [#11347](https://github.com/apache/datafusion/pull/11347) (appletreeisyellow) +- feat: switch to using proper Substrait types for IntervalYearMonth and IntervalDayTime [#11471](https://github.com/apache/datafusion/pull/11471) (Blizzara) +- feat: support UDWFs in Substrait [#11489](https://github.com/apache/datafusion/pull/11489) (Blizzara) +- feat: support `unnest` in GROUP BY clause [#11469](https://github.com/apache/datafusion/pull/11469) (JasonLi-cn) +- feat: support `COUNT()` [#11229](https://github.com/apache/datafusion/pull/11229) (tshauck) +- feat: consume and produce Substrait type extensions [#11510](https://github.com/apache/datafusion/pull/11510) (Blizzara) +- feat: Error when a SHOW command is passed in with an accompanying non-existant variable [#11540](https://github.com/apache/datafusion/pull/11540) (itsjunetime) +- feat: support Map literals in Substrait consumer and producer [#11547](https://github.com/apache/datafusion/pull/11547) (Blizzara) +- feat: add bounds for unary math scalar functions [#11584](https://github.com/apache/datafusion/pull/11584) (tshauck) +- feat: Add support for cardinality function on maps [#11801](https://github.com/apache/datafusion/pull/11801) (Weijun-H) +- feat: support `Utf8View` type in `starts_with` function [#11787](https://github.com/apache/datafusion/pull/11787) (tshauck) +- feat: Expose public method for optimizing physical plans [#11879](https://github.com/apache/datafusion/pull/11879) (andygrove) + +**Fixed bugs:** + +- fix: Fix eq properties regression from #10434 [#11363](https://github.com/apache/datafusion/pull/11363) (suremarc) +- fix: make sure JOIN ON expression is boolean type [#11423](https://github.com/apache/datafusion/pull/11423) (jonahgao) +- fix: `regexp_replace` fails when pattern or replacement is a scalar `NULL` [#11459](https://github.com/apache/datafusion/pull/11459) (Weijun-H) +- fix: unparser generates wrong sql for derived table with columns [#11505](https://github.com/apache/datafusion/pull/11505) (y-f-u) +- fix: make `UnKnownColumn`s not equal to others physical exprs [#11536](https://github.com/apache/datafusion/pull/11536) (jonahgao) +- fix: fixes trig function order by [#11559](https://github.com/apache/datafusion/pull/11559) (tshauck) +- fix: CASE with NULL [#11542](https://github.com/apache/datafusion/pull/11542) (Weijun-H) +- fix: panic and incorrect results in `LogFunc::output_ordering()` [#11571](https://github.com/apache/datafusion/pull/11571) (jonahgao) +- fix: expose the fluent API fn for approx_distinct instead of the module [#11644](https://github.com/apache/datafusion/pull/11644) (Michael-J-Ward) +- fix: dont try to coerce list for regex match [#11646](https://github.com/apache/datafusion/pull/11646) (tshauck) +- fix: regr_count now returns Uint64 [#11731](https://github.com/apache/datafusion/pull/11731) (Michael-J-Ward) +- fix: set `null_equals_null` to false when `convert_cross_join_to_inner_join` [#11738](https://github.com/apache/datafusion/pull/11738) (jonahgao) +- fix: Add additional required expression for natural join [#11713](https://github.com/apache/datafusion/pull/11713) (Lordworms) +- fix: hash join tests with forced collisions [#11806](https://github.com/apache/datafusion/pull/11806) (korowa) +- fix: `collect_columns` quadratic complexity [#11843](https://github.com/apache/datafusion/pull/11843) (crepererum) + +**Documentation updates:** + +- Minor: Add link to blog to main DataFusion website [#11356](https://github.com/apache/datafusion/pull/11356) (alamb) +- Add `to_local_time()` in function reference docs [#11401](https://github.com/apache/datafusion/pull/11401) (appletreeisyellow) +- Minor: Consolidate specification doc sections [#11427](https://github.com/apache/datafusion/pull/11427) (alamb) +- Combine the Roadmap / Quarterly Roadmap sections [#11426](https://github.com/apache/datafusion/pull/11426) (alamb) +- Minor: Add an example for backtrace pretty print [#11450](https://github.com/apache/datafusion/pull/11450) (goldmedal) +- Docs: Document creating new extension APIs [#11425](https://github.com/apache/datafusion/pull/11425) (alamb) +- Minor: Clarify which parquet options are used for reading/writing [#11511](https://github.com/apache/datafusion/pull/11511) (alamb) +- Support `newlines_in_values` CSV option [#11533](https://github.com/apache/datafusion/pull/11533) (connec) +- chore: Minor cleanup `simplify_demo()` example [#11576](https://github.com/apache/datafusion/pull/11576) (kavirajk) +- Move Datafusion Query Optimizer to library user guide [#11563](https://github.com/apache/datafusion/pull/11563) (devesh-2002) +- Fix typo in doc of Partitioning [#11612](https://github.com/apache/datafusion/pull/11612) (waruto210) +- Doc: A tiny typo in scalar function's doc [#11620](https://github.com/apache/datafusion/pull/11620) (2010YOUY01) +- Change default Parquet writer settings to match arrow-rs (except for compression & statistics) [#11558](https://github.com/apache/datafusion/pull/11558) (wiedld) +- Rename `functions-array` to `functions-nested` [#11602](https://github.com/apache/datafusion/pull/11602) (goldmedal) +- Add parser option enable_options_value_normalization [#11330](https://github.com/apache/datafusion/pull/11330) (xinlifoobar) +- Add reference to #comet channel in Arrow Rust Discord server [#11637](https://github.com/apache/datafusion/pull/11637) (ajmarcus) +- Extract catalog API to separate crate, change `TableProvider::scan` to take a trait rather than `SessionState` [#11516](https://github.com/apache/datafusion/pull/11516) (findepi) +- doc: why nullable of list item is set to true [#11626](https://github.com/apache/datafusion/pull/11626) (jcsherin) +- Docs: adding explicit mention of test_utils to docs [#11670](https://github.com/apache/datafusion/pull/11670) (edmondop) +- Ensure statistic defaults in parquet writers are in sync [#11656](https://github.com/apache/datafusion/pull/11656) (wiedld) +- Merge `string-view2` branch: reading from parquet up to 2x faster for some ClickBench queries (not on by default) [#11667](https://github.com/apache/datafusion/pull/11667) (alamb) +- Doc: Add Sail to known users list [#11791](https://github.com/apache/datafusion/pull/11791) (shehabgamin) +- Move min and max to user defined aggregate function, remove `AggregateFunction` / `AggregateFunctionDefinition::BuiltIn` [#11013](https://github.com/apache/datafusion/pull/11013) (edmondop) +- Change name of MAX/MIN udaf to lowercase max/min [#11795](https://github.com/apache/datafusion/pull/11795) (edmondop) +- doc: Add support for `map` and `make_map` functions [#11799](https://github.com/apache/datafusion/pull/11799) (Weijun-H) +- Improve readme page in crates.io [#11809](https://github.com/apache/datafusion/pull/11809) (lewiszlw) +- refactor: remove unneed mut for session context [#11864](https://github.com/apache/datafusion/pull/11864) (sunng87) + +**Other:** + +- Prepare 40.0.0 Release [#11343](https://github.com/apache/datafusion/pull/11343) (andygrove) +- Support `NULL` literals in where clause [#11266](https://github.com/apache/datafusion/pull/11266) (xinlifoobar) +- Implement TPCH substrait integration test, support tpch_6, tpch_10, t… [#11349](https://github.com/apache/datafusion/pull/11349) (Lordworms) +- Fix bug when pushing projection under joins [#11333](https://github.com/apache/datafusion/pull/11333) (jonahgao) +- Minor: some cosmetics in `filter.rs`, fix clippy due to logical conflict [#11368](https://github.com/apache/datafusion/pull/11368) (comphead) +- Update prost-derive requirement from 0.12 to 0.13 [#11355](https://github.com/apache/datafusion/pull/11355) (dependabot[bot]) +- Minor: update dashmap `6.0.1` [#11335](https://github.com/apache/datafusion/pull/11335) (alamb) +- Improve and test dataframe API examples in docs [#11290](https://github.com/apache/datafusion/pull/11290) (alamb) +- Remove redundant `unalias_nested` calls for creating Filter's [#11340](https://github.com/apache/datafusion/pull/11340) (alamb) +- Enable `clone_on_ref_ptr` clippy lint on optimizer [#11346](https://github.com/apache/datafusion/pull/11346) (lewiszlw) +- Update termtree requirement from 0.4.1 to 0.5.0 [#11383](https://github.com/apache/datafusion/pull/11383) (dependabot[bot]) +- Introduce `resources_err!` error macro [#11374](https://github.com/apache/datafusion/pull/11374) (comphead) +- Enable `clone_on_ref_ptr` clippy lint on common [#11384](https://github.com/apache/datafusion/pull/11384) (lewiszlw) +- Track parquet writer encoding memory usage on MemoryPool [#11345](https://github.com/apache/datafusion/pull/11345) (wiedld) +- Minor: remove clones and unnecessary Arcs in `from_substrait_rex` [#11337](https://github.com/apache/datafusion/pull/11337) (alamb) +- Minor: Change no-statement error message to be clearer [#11394](https://github.com/apache/datafusion/pull/11394) (itsjunetime) +- Change `array_agg` to return `null` on no input rather than empty list [#11299](https://github.com/apache/datafusion/pull/11299) (jayzhan211) +- Minor: return "not supported" for `COUNT DISTINCT` with multiple arguments [#11391](https://github.com/apache/datafusion/pull/11391) (jonahgao) +- Enable `clone_on_ref_ptr` clippy lint on sql [#11380](https://github.com/apache/datafusion/pull/11380) (lewiszlw) +- Move configuration information out of example usage page [#11300](https://github.com/apache/datafusion/pull/11300) (alamb) +- chore: reuse a single function to create the Substrait TPCH consumer test contexts [#11396](https://github.com/apache/datafusion/pull/11396) (Blizzara) +- refactor: change error type for "no statement" [#11411](https://github.com/apache/datafusion/pull/11411) (crepererum) +- Implement prettier SQL unparsing (more human readable) [#11186](https://github.com/apache/datafusion/pull/11186) (MohamedAbdeen21) +- Move `overlay` planning to`ExprPlanner` [#11398](https://github.com/apache/datafusion/pull/11398) (dharanad) +- Coerce types for all union children plans when eliminating nesting [#11386](https://github.com/apache/datafusion/pull/11386) (gruuya) +- Add customizable equality and hash functions to UDFs [#11392](https://github.com/apache/datafusion/pull/11392) (joroKr21) +- Implement ScalarFunction `MAKE_MAP` and `MAP` [#11361](https://github.com/apache/datafusion/pull/11361) (goldmedal) +- Improve `CommonSubexprEliminate` rule with surely and conditionally evaluated stats [#11357](https://github.com/apache/datafusion/pull/11357) (peter-toth) +- fix(11397): surface proper errors in ParquetSink [#11399](https://github.com/apache/datafusion/pull/11399) (wiedld) +- Minor: Add note about SQLLancer fuzz testing to docs [#11430](https://github.com/apache/datafusion/pull/11430) (alamb) +- Trivial: use arrow csv writer's timestamp_tz_format [#11407](https://github.com/apache/datafusion/pull/11407) (tmi) +- Improved unparser documentation [#11395](https://github.com/apache/datafusion/pull/11395) (alamb) +- Avoid calling shutdown after failed write of AsyncWrite [#11415](https://github.com/apache/datafusion/pull/11415) (joroKr21) +- Short term way to make `AggregateStatistics` still work when min/max is converted to udaf [#11261](https://github.com/apache/datafusion/pull/11261) (Rachelint) +- Implement TPCH substrait integration test, support tpch_13, tpch_14,16 [#11405](https://github.com/apache/datafusion/pull/11405) (Lordworms) +- Minor: fix giuthub action labeler rules [#11428](https://github.com/apache/datafusion/pull/11428) (alamb) +- Minor: change internal error to not supported error for nested field … [#11446](https://github.com/apache/datafusion/pull/11446) (alamb) +- Minor: change Datafusion --> DataFusion in docs [#11439](https://github.com/apache/datafusion/pull/11439) (alamb) +- Support serialization/deserialization for custom physical exprs in proto [#11387](https://github.com/apache/datafusion/pull/11387) (lewiszlw) +- remove termtree dependency [#11416](https://github.com/apache/datafusion/pull/11416) (Kev1n8) +- Add SessionStateBuilder and extract out the registration of defaults [#11403](https://github.com/apache/datafusion/pull/11403) (Omega359) +- integrate consumer tests, implement tpch query 18 to 22 [#11462](https://github.com/apache/datafusion/pull/11462) (Lordworms) +- Docs: Explain the usage of logical expressions for `create_aggregate_expr` [#11458](https://github.com/apache/datafusion/pull/11458) (jayzhan211) +- Return scalar result when all inputs are constants in `map` and `make_map` [#11461](https://github.com/apache/datafusion/pull/11461) (Rachelint) +- Enable `clone_on_ref_ptr` clippy lint on functions\* [#11468](https://github.com/apache/datafusion/pull/11468) (lewiszlw) +- minor: non-overlapping `repart_time` and `send_time` metrics [#11440](https://github.com/apache/datafusion/pull/11440) (korowa) +- Minor: rename `row_groups.rs` to `row_group_filter.rs` [#11481](https://github.com/apache/datafusion/pull/11481) (alamb) +- Support alternate formats for unparsing `datetime` to `timestamp` and `interval` [#11466](https://github.com/apache/datafusion/pull/11466) (y-f-u) +- chore: Add criterion benchmark for CaseExpr [#11482](https://github.com/apache/datafusion/pull/11482) (andygrove) +- Initial support for `StringView`, merge changes from `string-view` development branch [#11402](https://github.com/apache/datafusion/pull/11402) (alamb) +- Replace to_lowercase with to_string in sql example [#11486](https://github.com/apache/datafusion/pull/11486) (lewiszlw) +- Minor: Make execute_input_stream Accessible for Any Sinking Operators [#11449](https://github.com/apache/datafusion/pull/11449) (berkaysynnada) +- Enable `clone_on_ref_ptr` clippy lints on proto [#11465](https://github.com/apache/datafusion/pull/11465) (lewiszlw) +- upgrade sqlparser 0.47 -> 0.48 [#11453](https://github.com/apache/datafusion/pull/11453) (MohamedAbdeen21) +- Add extension hooks for encoding and decoding UDAFs and UDWFs [#11417](https://github.com/apache/datafusion/pull/11417) (joroKr21) +- Remove element's nullability of array_agg function [#11447](https://github.com/apache/datafusion/pull/11447) (jayzhan211) +- Get expr planners when creating new planner [#11485](https://github.com/apache/datafusion/pull/11485) (jayzhan211) +- Support alternate format for Utf8 unparsing (CHAR) [#11494](https://github.com/apache/datafusion/pull/11494) (sgrebnov) +- implement retract_batch for xor accumulator [#11500](https://github.com/apache/datafusion/pull/11500) (drewhayward) +- Refactor: more clearly delineate between `TableParquetOptions` and `ParquetWriterOptions` [#11444](https://github.com/apache/datafusion/pull/11444) (wiedld) +- chore: fix typos of common and core packages [#11520](https://github.com/apache/datafusion/pull/11520) (JasonLi-cn) +- Move spill related functions to spill.rs [#11509](https://github.com/apache/datafusion/pull/11509) (findepi) +- Add tests that show the different defaults for `ArrowWriter` and `TableParquetOptions` [#11524](https://github.com/apache/datafusion/pull/11524) (wiedld) +- Create `datafusion-physical-optimizer` crate [#11507](https://github.com/apache/datafusion/pull/11507) (lewiszlw) +- Minor: Assert `test_enabled_backtrace` requirements to run [#11525](https://github.com/apache/datafusion/pull/11525) (comphead) +- Move handlign of NULL literals in where clause to type coercion pass [#11491](https://github.com/apache/datafusion/pull/11491) (xinlifoobar) +- Update parquet page pruning code to use the `StatisticsExtractor` [#11483](https://github.com/apache/datafusion/pull/11483) (alamb) +- Enable SortMergeJoin LeftAnti filtered fuzz tests [#11535](https://github.com/apache/datafusion/pull/11535) (comphead) +- chore: fix typos of expr, functions, optimizer, physical-expr-common,… [#11538](https://github.com/apache/datafusion/pull/11538) (JasonLi-cn) +- Minor: Remove clone in `PushDownFilter` [#11532](https://github.com/apache/datafusion/pull/11532) (jayzhan211) +- Minor: avoid a clone in type coercion [#11530](https://github.com/apache/datafusion/pull/11530) (alamb) +- Move array `ArrayAgg` to a `UserDefinedAggregate` [#11448](https://github.com/apache/datafusion/pull/11448) (jayzhan211) +- Move `MAKE_MAP` to ExprPlanner [#11452](https://github.com/apache/datafusion/pull/11452) (goldmedal) +- chore: fix typos of sql, sqllogictest and substrait packages [#11548](https://github.com/apache/datafusion/pull/11548) (JasonLi-cn) +- Prevent bigger files from being checked in [#11508](https://github.com/apache/datafusion/pull/11508) (findepi) +- Add dialect param to use double precision for float64 in Postgres [#11495](https://github.com/apache/datafusion/pull/11495) (Sevenannn) +- Minor: move `SessionStateDefaults` into its own module [#11566](https://github.com/apache/datafusion/pull/11566) (alamb) +- refactor: rewrite mega type to an enum containing both cases [#11539](https://github.com/apache/datafusion/pull/11539) (LorrensP-2158466) +- Move `sql_compound_identifier_to_expr ` to `ExprPlanner` [#11487](https://github.com/apache/datafusion/pull/11487) (dharanad) +- Support SortMergeJoin spilling [#11218](https://github.com/apache/datafusion/pull/11218) (comphead) +- Fix unparser invalid sql for query with order [#11527](https://github.com/apache/datafusion/pull/11527) (y-f-u) +- Provide DataFrame API for `map` and move `map` to `functions-array` [#11560](https://github.com/apache/datafusion/pull/11560) (goldmedal) +- Move OutputRequirements to datafusion-physical-optimizer crate [#11579](https://github.com/apache/datafusion/pull/11579) (xinlifoobar) +- Minor: move `Column` related tests and rename `column.rs` [#11573](https://github.com/apache/datafusion/pull/11573) (jonahgao) +- Fix SortMergeJoin antijoin flaky condition [#11604](https://github.com/apache/datafusion/pull/11604) (comphead) +- Improve Union Equivalence Propagation [#11506](https://github.com/apache/datafusion/pull/11506) (mustafasrepo) +- Migrate `OrderSensitiveArrayAgg` to be a user defined aggregate [#11564](https://github.com/apache/datafusion/pull/11564) (jayzhan211) +- Minor:Disable flaky SMJ antijoin filtered test until the fix [#11608](https://github.com/apache/datafusion/pull/11608) (comphead) +- support Decimal256 type in datafusion-proto [#11606](https://github.com/apache/datafusion/pull/11606) (leoyvens) +- Chore/fifo tests cleanup [#11616](https://github.com/apache/datafusion/pull/11616) (ozankabak) +- Fix Internal Error for an INNER JOIN query [#11578](https://github.com/apache/datafusion/pull/11578) (xinlifoobar) +- test: get file size by func metadata [#11575](https://github.com/apache/datafusion/pull/11575) (zhuliquan) +- Improve unparser MySQL compatibility [#11589](https://github.com/apache/datafusion/pull/11589) (sgrebnov) +- Push scalar functions into cross join [#11528](https://github.com/apache/datafusion/pull/11528) (lewiszlw) +- Remove ArrayAgg Builtin in favor of UDF [#11611](https://github.com/apache/datafusion/pull/11611) (jayzhan211) +- refactor: simplify `DFSchema::field_with_unqualified_name` [#11619](https://github.com/apache/datafusion/pull/11619) (jonahgao) +- Minor: Use upstream `concat_batches` from arrow-rs [#11615](https://github.com/apache/datafusion/pull/11615) (alamb) +- Fix : `signum` function bug when `0.0` input [#11580](https://github.com/apache/datafusion/pull/11580) (getChan) +- Enforce uniqueness of `named_struct` field names [#11614](https://github.com/apache/datafusion/pull/11614) (dharanad) +- Minor: unecessary row_count calculation in `CrossJoinExec` and `NestedLoopsJoinExec` [#11632](https://github.com/apache/datafusion/pull/11632) (alamb) +- ExprBuilder for Physical Aggregate Expr [#11617](https://github.com/apache/datafusion/pull/11617) (jayzhan211) +- Minor: avoid copying order by exprs in planner [#11634](https://github.com/apache/datafusion/pull/11634) (alamb) +- Unify CI and pre-commit hook settings for clippy [#11640](https://github.com/apache/datafusion/pull/11640) (findepi) +- Parsing SQL strings to Exprs with the qualified schema [#11562](https://github.com/apache/datafusion/pull/11562) (Lordworms) +- Add some zero column tests covering LIMIT, GROUP BY, WHERE, JOIN, and WINDOW [#11624](https://github.com/apache/datafusion/pull/11624) (Kev1n8) +- Refactor/simplify window frame utils [#11648](https://github.com/apache/datafusion/pull/11648) (ozankabak) +- Minor: use `ready!` macro to simplify `FilterExec` [#11649](https://github.com/apache/datafusion/pull/11649) (alamb) +- Temporarily pin toolchain version to avoid clippy errors [#11655](https://github.com/apache/datafusion/pull/11655) (findepi) +- Fix clippy errors for Rust 1.80 [#11654](https://github.com/apache/datafusion/pull/11654) (findepi) +- Add `CsvExecBuilder` for creating `CsvExec` [#11633](https://github.com/apache/datafusion/pull/11633) (connec) +- chore(deps): update sqlparser requirement from 0.48 to 0.49 [#11630](https://github.com/apache/datafusion/pull/11630) (dependabot[bot]) +- Add support for USING to SQL unparser [#11636](https://github.com/apache/datafusion/pull/11636) (wackywendell) +- Run CI with latest (Rust 1.80), add ticket references to commented out tests [#11661](https://github.com/apache/datafusion/pull/11661) (alamb) +- Use `AccumulatorArgs::is_reversed` in `NthValueAgg` [#11669](https://github.com/apache/datafusion/pull/11669) (jcsherin) +- Implement physical plan serialization for json Copy plans [#11645](https://github.com/apache/datafusion/pull/11645) (Lordworms) +- Minor: improve documentation on `SessionState` [#11642](https://github.com/apache/datafusion/pull/11642) (alamb) +- Add LimitPushdown optimization rule and CoalesceBatchesExec fetch [#11652](https://github.com/apache/datafusion/pull/11652) (alihandroid) +- Update to arrow/parquet `52.2.0` [#11691](https://github.com/apache/datafusion/pull/11691) (alamb) +- Minor: Rename `RepartitionMetrics::repartition_time` to `RepartitionMetrics::repart_time` to match metric [#11478](https://github.com/apache/datafusion/pull/11478) (alamb) +- Update cache key used in rust CI script [#11641](https://github.com/apache/datafusion/pull/11641) (findepi) +- Fix bug in `remove_join_expressions` [#11693](https://github.com/apache/datafusion/pull/11693) (jonahgao) +- Initial changes to support using udaf min/max for statistics and opti… [#11696](https://github.com/apache/datafusion/pull/11696) (edmondop) +- Handle nulls in approx_percentile_cont [#11721](https://github.com/apache/datafusion/pull/11721) (Dandandan) +- Reduce repetition in try_process_group_by_unnest and try_process_unnest [#11714](https://github.com/apache/datafusion/pull/11714) (JasonLi-cn) +- Minor: Add example for `ScalarUDF::call` [#11727](https://github.com/apache/datafusion/pull/11727) (alamb) +- Use `cargo release` in `bench.sh` [#11722](https://github.com/apache/datafusion/pull/11722) (alamb) +- expose some fields on session state [#11716](https://github.com/apache/datafusion/pull/11716) (waynexia) +- Make DefaultSchemaAdapterFactory public [#11709](https://github.com/apache/datafusion/pull/11709) (adriangb) +- Check hashes first during probing the aggr hash table [#11718](https://github.com/apache/datafusion/pull/11718) (Rachelint) +- Implement physical plan serialization for parquet Copy plans [#11735](https://github.com/apache/datafusion/pull/11735) (Lordworms) +- Support cross-timezone `timestamp` comparison via coercsion [#11711](https://github.com/apache/datafusion/pull/11711) (jeffreyssmith2nd) +- Minor: Improve documentation for AggregateUDFImpl::state_fields [#11740](https://github.com/apache/datafusion/pull/11740) (lewiszlw) +- Do not push down Sorts if it violates the sort requirements [#11678](https://github.com/apache/datafusion/pull/11678) (alamb) +- Use upstream `StatisticsConverter` from arrow-rs in DataFusion [#11479](https://github.com/apache/datafusion/pull/11479) (alamb) +- Fix `plan_to_sql`: Add wildcard projection to SELECT statement if no projection was set [#11744](https://github.com/apache/datafusion/pull/11744) (LatrecheYasser) +- Use upstream `DataType::from_str` in arrow-cast [#11254](https://github.com/apache/datafusion/pull/11254) (alamb) +- Fix documentation warnings, make CsvExecBuilder and Unparsed pub [#11729](https://github.com/apache/datafusion/pull/11729) (alamb) +- [Minor] Add test for only nulls (empty) as input in APPROX_PERCENTILE_CONT [#11760](https://github.com/apache/datafusion/pull/11760) (Dandandan) +- Add `TrackedMemoryPool` with better error messages on exhaustion [#11665](https://github.com/apache/datafusion/pull/11665) (wiedld) +- Derive `Debug` for logical plan nodes [#11757](https://github.com/apache/datafusion/pull/11757) (lewiszlw) +- Minor: add "clickbench extended" queries to slt tests [#11763](https://github.com/apache/datafusion/pull/11763) (alamb) +- Minor: Add comment explaining rationale for hash check [#11750](https://github.com/apache/datafusion/pull/11750) (alamb) +- Fix bug that `COUNT(DISTINCT)` on StringView panics [#11768](https://github.com/apache/datafusion/pull/11768) (XiangpengHao) +- [Minor] Refactor approx_percentile [#11769](https://github.com/apache/datafusion/pull/11769) (Dandandan) +- minor: always time batch_filter even when the result is an empty batch [#11775](https://github.com/apache/datafusion/pull/11775) (andygrove) +- Improve OOM message when a single reservation request fails to get more bytes. [#11771](https://github.com/apache/datafusion/pull/11771) (wiedld) +- [Minor] Short circuit `ApplyFunctionRewrites` if there are no function rewrites [#11765](https://github.com/apache/datafusion/pull/11765) (gruuya) +- Fix #11692: Improve doc comments within macros [#11694](https://github.com/apache/datafusion/pull/11694) (Rafferty97) +- Extract `CoalesceBatchesStream` to a struct [#11610](https://github.com/apache/datafusion/pull/11610) (alamb) +- refactor: move ExecutionPlan and related structs into dedicated mod [#11759](https://github.com/apache/datafusion/pull/11759) (waynexia) +- Minor: Add references to github issue in comments [#11784](https://github.com/apache/datafusion/pull/11784) (findepi) +- Add docs and rename param for `Signature::numeric` [#11778](https://github.com/apache/datafusion/pull/11778) (matthewmturner) +- Support planning `Map` literal [#11780](https://github.com/apache/datafusion/pull/11780) (goldmedal) +- Support `LogicalPlan` `Debug` differently than `Display` [#11774](https://github.com/apache/datafusion/pull/11774) (lewiszlw) +- Remove redundant Aggregate when `DISTINCT` & `GROUP BY` are in the same query [#11781](https://github.com/apache/datafusion/pull/11781) (mertak-synnada) +- Minor: add ticket reference and fmt [#11805](https://github.com/apache/datafusion/pull/11805) (alamb) +- Improve MSRV CI check to print out problems to log [#11789](https://github.com/apache/datafusion/pull/11789) (alamb) +- Improve log func tests stability [#11808](https://github.com/apache/datafusion/pull/11808) (lewiszlw) +- Add valid Distinct case for aggregation [#11814](https://github.com/apache/datafusion/pull/11814) (mertak-synnada) +- Don't implement `create_sliding_accumulator` repeatedly [#11813](https://github.com/apache/datafusion/pull/11813) (lewiszlw) +- chore(deps): update rstest requirement from 0.21.0 to 0.22.0 [#11811](https://github.com/apache/datafusion/pull/11811) (dependabot[bot]) +- Minor: Update exected output due to logical conflict [#11824](https://github.com/apache/datafusion/pull/11824) (alamb) +- Pass scalar to `eq` inside `nullif` [#11697](https://github.com/apache/datafusion/pull/11697) (simonvandel) +- refactor: move `aggregate_statistics` to `datafusion-physical-optimizer` [#11798](https://github.com/apache/datafusion/pull/11798) (Weijun-H) +- Minor: refactor probe check into function `should_skip_aggregation` [#11821](https://github.com/apache/datafusion/pull/11821) (alamb) +- Minor: consolidate `path_partition` test into `core_integration` [#11831](https://github.com/apache/datafusion/pull/11831) (alamb) +- Move optimizer integration tests to `core_integration` [#11830](https://github.com/apache/datafusion/pull/11830) (alamb) +- Bump deprecated version of SessionState::new_with_config_rt to 41.0.0 [#11839](https://github.com/apache/datafusion/pull/11839) (kezhuw) +- Fix partial aggregation skipping with Decimal aggregators [#11833](https://github.com/apache/datafusion/pull/11833) (alamb) +- Fix bug with zero-sized buffer for StringViewArray [#11841](https://github.com/apache/datafusion/pull/11841) (XiangpengHao) +- Reduce clone of `Statistics` in `ListingTable` and `PartitionedFile` [#11802](https://github.com/apache/datafusion/pull/11802) (Rachelint) +- Add `LogicalPlan::CreateIndex` [#11817](https://github.com/apache/datafusion/pull/11817) (lewiszlw) +- Update `object_store` to 0.10.2 [#11860](https://github.com/apache/datafusion/pull/11860) (danlgrca) +- Add `skipped_aggregation_rows` metric to aggregate operator [#11706](https://github.com/apache/datafusion/pull/11706) (alamb) +- Cast `Utf8View` to `Utf8` to support `||` from `StringViewArray` [#11796](https://github.com/apache/datafusion/pull/11796) (dharanad) +- Improve nested loop join code [#11863](https://github.com/apache/datafusion/pull/11863) (lewiszlw) +- [Minor]: Refactor to use Result.transpose() [#11882](https://github.com/apache/datafusion/pull/11882) (djanderson) +- support `ANY()` op [#11849](https://github.com/apache/datafusion/pull/11849) (samuelcolvin) + +## Credits + +Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. + +``` + 48 Andrew Lamb + 20 张林伟 + 9 Jay Zhan + 9 Jonah Gao + 8 Andy Grove + 8 Lordworms + 8 Piotr Findeisen + 8 wiedld + 7 Oleks V + 6 Jax Liu + 5 Alex Huang + 5 Arttu + 5 JasonLi + 5 Trent Hauck + 5 Xin Li + 4 Dharan Aditya + 4 Edmondo Porcu + 4 dependabot[bot] + 4 kamille + 4 yfu + 3 Daniël Heres + 3 Eduard Karacharov + 3 Georgi Krastev + 2 Chris Connelly + 2 Chunchun Ye + 2 June + 2 Marco Neumann + 2 Marko Grujic + 2 Mehmet Ozan Kabak + 2 Michael J Ward + 2 Mohamed Abdeen + 2 Ruihang Xia + 2 Sergei Grebnov + 2 Xiangpeng Hao + 2 jcsherin + 2 kf zheng + 2 mertak-synnada + 1 Adrian Garcia Badaracco + 1 Alexander Rafferty + 1 Alihan Çelikcan + 1 Ariel Marcus + 1 Berkay Şahin + 1 Bruce Ritchie + 1 Devesh Rahatekar + 1 Douglas Anderson + 1 Drew Hayward + 1 Jeffrey Smith II + 1 Kaviraj Kanagaraj + 1 Kezhu Wang + 1 Leonardo Yvens + 1 Lorrens Pantelis + 1 Matthew Cramerus + 1 Matthew Turner + 1 Mustafa Akur + 1 Namgung Chan + 1 Ning Sun + 1 Peter Toth + 1 Qianqian + 1 Samuel Colvin + 1 Shehab Amin + 1 Simon Vandel Sillesen + 1 Tim Saucer + 1 Wendell Smith + 1 Yasser Latreche + 1 Yongting You + 1 danlgrca + 1 tmi + 1 waruto + 1 zhuliquan +``` + +Thank you also to everyone who contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. diff --git a/dev/changelog/42.0.0.md b/dev/changelog/42.0.0.md new file mode 100644 index 000000000000..2dff506a0f92 --- /dev/null +++ b/dev/changelog/42.0.0.md @@ -0,0 +1,418 @@ + + +# Apache DataFusion 42.0.0 Changelog + +This release consists of 296 commits from 73 contributors. See credits at the end of this changelog for more information. + +**Breaking changes:** + +- feat: expose centroids in approx_percentile_cont fluent api [#11878](https://github.com/apache/datafusion/pull/11878) (Michael-J-Ward) +- UDAF refactor: Add PhysicalExpr trait dependency on `datafusion-expr` and remove logical expressions requirement for creating physical aggregate expression [#11845](https://github.com/apache/datafusion/pull/11845) (jayzhan211) +- `datafusion.catalog.has_headers` default value set to `true` [#11919](https://github.com/apache/datafusion/pull/11919) (korowa) +- Use `schema_name` to create the `physical_name` [#11977](https://github.com/apache/datafusion/pull/11977) (joroKr21) +- Add ability to return `LogicalPlan` by value from `TableProvider::get_logical_plan` [#12113](https://github.com/apache/datafusion/pull/12113) (askalt) +- Remove Sort expression (`Expr::Sort`) [#12177](https://github.com/apache/datafusion/pull/12177) (findepi) +- Remove TableSource::supports_filter_pushdown function [#12239](https://github.com/apache/datafusion/pull/12239) (findepi) +- Remove Box from Sort [#12207](https://github.com/apache/datafusion/pull/12207) (findepi) +- Avoid unnecessary move when setting SessionConfig [#12260](https://github.com/apache/datafusion/pull/12260) (findepi) +- Remove unused `AggregateOptions` struct and `scalar_update_factor` config setting [#12241](https://github.com/apache/datafusion/pull/12241) (jc4x4) +- Remove deprecated LogicalPlan::with_new_inputs function [#12285](https://github.com/apache/datafusion/pull/12285) (findepi) +- Fixes missing `nth_value` UDAF expr function [#12279](https://github.com/apache/datafusion/pull/12279) (jcsherin) +- Remove unnecessary `Result` from return type in `NamePreserver` [#12358](https://github.com/apache/datafusion/pull/12358) (jonahgao) +- Removed Arc wrapping for AggregateFunctionExpr [#12353](https://github.com/apache/datafusion/pull/12353) (athultr1997) + +**Performance related:** + +- perf: avoid repeat format in calc_func_dependencies_for_project [#12305](https://github.com/apache/datafusion/pull/12305) (haohuaijin) + +**Implemented enhancements:** + +- feat: Add map_extract module and function [#11969](https://github.com/apache/datafusion/pull/11969) (Weijun-H) +- feat: use Substrait's PrecisionTimestamp and PrecisionTimestampTz instead of deprecated Timestamp [#11597](https://github.com/apache/datafusion/pull/11597) (Blizzara) +- feat: support upper and lower for stringview [#12138](https://github.com/apache/datafusion/pull/12138) (tshauck) +- feat: Add DateFieldExtractStyle::Strftime support for SqliteDialect unparser [#12161](https://github.com/apache/datafusion/pull/12161) (peasee) +- feat: Enforce the uniqueness of map key name for the map/make_map function [#12153](https://github.com/apache/datafusion/pull/12153) (Weijun-H) +- feat: Add projection to FilterExec [#12281](https://github.com/apache/datafusion/pull/12281) (eejbyfeldt) +- feat: Support `FixedSizedList` in `array_distance` function [#12381](https://github.com/apache/datafusion/pull/12381) (Weijun-H) + +**Fixed bugs:** + +- fix: invalid sqls when unparsing derived table with columns contains calculations, limit/order/distinct [#11756](https://github.com/apache/datafusion/pull/11756) (y-f-u) +- fix: make ScalarValue::Dictionary with NULL values produce NULL arrays [#11908](https://github.com/apache/datafusion/pull/11908) (mhilton) +- fix: throw error on sub-day generate_series increments [#11907](https://github.com/apache/datafusion/pull/11907) (tshauck) +- fix: impl ordering for serialization/deserialization for AggregateUdf [#11926](https://github.com/apache/datafusion/pull/11926) (haohuaijin) +- fix: Fix various complaints from the latest nightly clippy [#11958](https://github.com/apache/datafusion/pull/11958) (itsjunetime) +- fix: move coercion of union from builder to `TypeCoercion` [#11961](https://github.com/apache/datafusion/pull/11961) (jonahgao) +- fix: incorrect aggregation result of `bool_and` [#12017](https://github.com/apache/datafusion/pull/12017) (jonahgao) +- fix: support min/max for Float16 type [#12050](https://github.com/apache/datafusion/pull/12050) (korowa) +- fix: Panic non-integer for the second argument of `nth_value` function [#12076](https://github.com/apache/datafusion/pull/12076) (Weijun-H) +- fix: ser/de fetch in CoalesceBatchesExec [#12107](https://github.com/apache/datafusion/pull/12107) (haohuaijin) +- fix: UDF, UDAF, UDWF with_alias(..) should wrap the inner function fully [#12098](https://github.com/apache/datafusion/pull/12098) (Blizzara) +- fix: Produce buffered null join row only if all joined rows are failed on join filter in SMJ full join [#12090](https://github.com/apache/datafusion/pull/12090) (viirya) +- fix: single partition in SortPreservingMergeExec don't take fetch [#12109](https://github.com/apache/datafusion/pull/12109) (haohuaijin) +- fix: set `supports_retract_batch` to false for `ApproxPercentileAccumulator` [#12132](https://github.com/apache/datafusion/pull/12132) (jonahgao) +- fix: preserve expression names when replacing placeholders [#12126](https://github.com/apache/datafusion/pull/12126) (jonahgao) +- fix: Skip buffered rows which are not joined with streamed side when checking join filter results [#12159](https://github.com/apache/datafusion/pull/12159) (viirya) +- fix: preserve qualifiers when rewriting expressions [#12341](https://github.com/apache/datafusion/pull/12341) (jonahgao) +- fix: support Substrait VirtualTables with no columns [#12339](https://github.com/apache/datafusion/pull/12339) (Blizzara) +- fix: nested loop join requires outer table to be a FusedStream [#12189](https://github.com/apache/datafusion/pull/12189) (YjyJeff) + +**Documentation updates:** + +- chore: Prepare 41.0.0-rc1 [#11889](https://github.com/apache/datafusion/pull/11889) (andygrove) +- Enforce sorting handle fetchable operators, add option to repartition based on row count estimates [#11875](https://github.com/apache/datafusion/pull/11875) (mustafasrepo) +- Minor: change wording for PMC membership notice [#11930](https://github.com/apache/datafusion/pull/11930) (alamb) +- Minor: fix outdated link [#11964](https://github.com/apache/datafusion/pull/11964) (austin362667) +- Minor: polish `Accumulator::state` docs [#12053](https://github.com/apache/datafusion/pull/12053) (lewiszlw) +- Fix CI check when version changes -- remove checked in file that is created by doc example [#12034](https://github.com/apache/datafusion/pull/12034) (alamb) +- Add new user doc to translate logical plan to physical plan [#12026](https://github.com/apache/datafusion/pull/12026) (jc4x4) +- Remove vestigal `datafusion-docs` module compilation [#12081](https://github.com/apache/datafusion/pull/12081) (alamb) +- Minor: Add example for configuring SessionContext [#12139](https://github.com/apache/datafusion/pull/12139) (alamb) +- Make it easier to understand datafusion-cli exclusion [#12188](https://github.com/apache/datafusion/pull/12188) (findepi) +- Add documentation on `EXPLAIN` and `EXPLAIN ANALYZE` [#12122](https://github.com/apache/datafusion/pull/12122) (devanbenz) +- Add `array_distance` function [#12211](https://github.com/apache/datafusion/pull/12211) (austin362667) +- Minor: fix `list_distinct` alias link title [#12246](https://github.com/apache/datafusion/pull/12246) (austin362667) +- Support `map_keys` & `map_values` for MAP type [#12194](https://github.com/apache/datafusion/pull/12194) (dharanad) +- Minor: Improve ExecutionMode documentation [#12214](https://github.com/apache/datafusion/pull/12214) (alamb) +- Implement `kurtosis_pop` UDAF [#12273](https://github.com/apache/datafusion/pull/12273) (goldmedal) +- Update download page to reflect latest version (v41) [#12310](https://github.com/apache/datafusion/pull/12310) (phillipleblanc) +- Fix issue with "to_date" failing to process dates later than year 2262 [#12227](https://github.com/apache/datafusion/pull/12227) (MartinKolbAtWork) +- Coerce BinaryView/Utf8View to LargeBinary/LargeUtf8 on output. [#12271](https://github.com/apache/datafusion/pull/12271) (wiedld) +- Add documentation about performance PRs, add (TBD) section on feature criteria [#12372](https://github.com/apache/datafusion/pull/12372) (alamb) +- Implement native support StringView for `CONTAINS` function [#12168](https://github.com/apache/datafusion/pull/12168) (tlm365) +- Fix parquet statistics for ListingTable and Utf8View with `schema_force_string_view`, rename config option to `schema_force_view_types` [#12232](https://github.com/apache/datafusion/pull/12232) (wiedld) +- Minor: Fix project website links [#12419](https://github.com/apache/datafusion/pull/12419) (alamb) +- doc: Update MSRV policy, shortening to max(4 months, 4 releases) [#12402](https://github.com/apache/datafusion/pull/12402) (comphead) +- Add a `version()` UDF [#12429](https://github.com/apache/datafusion/pull/12429) (samuelcolvin) +- Support timestamps and steps of less than a day for range/generate_series [#12400](https://github.com/apache/datafusion/pull/12400) (Omega359) +- Improve comments on target user and unify intro summaries [#12418](https://github.com/apache/datafusion/pull/12418) (alamb) +- Add 'Extensions List' page to the documentation [#12420](https://github.com/apache/datafusion/pull/12420) (alamb) +- Added array_any_value function [#12329](https://github.com/apache/datafusion/pull/12329) (athultr1997) + +**Other:** + +- Sync rust docs params for CI and dev [#11890](https://github.com/apache/datafusion/pull/11890) (findepi) +- Update ASCII scalar function to support Utf8View #11834 [#11884](https://github.com/apache/datafusion/pull/11884) (dmitrybugakov) +- Fix `Duration` vs `Interval` comparisons and `Interval` as LHS [#11876](https://github.com/apache/datafusion/pull/11876) (samuelcolvin) +- Produce clear error message when build runs with conflicting features [#11895](https://github.com/apache/datafusion/pull/11895) (findepi) +- Add tests for StringView / character functions, fix `regexp_like` and `regexp_match` to work with StringView [#11753](https://github.com/apache/datafusion/pull/11753) (alamb) +- Avoid unecessary copy when reading arrow files [#11840](https://github.com/apache/datafusion/pull/11840) (XiangpengHao) +- Support NULL literal in Min/Max [#11812](https://github.com/apache/datafusion/pull/11812) (xinlifoobar) +- Remove many `crate::` imports in listing table provider module [#11887](https://github.com/apache/datafusion/pull/11887) (findepi) +- Rename `Expr::display_name` to `Expr::schema_name`, make `UNNEST` naming conform to convention [#11797](https://github.com/apache/datafusion/pull/11797) (jayzhan211) +- Make `CommonSubexprEliminate` top-down like [#11683](https://github.com/apache/datafusion/pull/11683) (peter-toth) +- Add `generate_series` tests for arrays [#11921](https://github.com/apache/datafusion/pull/11921) (alamb) +- Minor: use `lit(true)` and `lit(false)` more [#11904](https://github.com/apache/datafusion/pull/11904) (alamb) +- Fix: panics in `approx_percentile_cont()` aggregate function [#11934](https://github.com/apache/datafusion/pull/11934) (2010YOUY01) +- Ingore shebang at top of file in `datafusion-cli` [#11927](https://github.com/apache/datafusion/pull/11927) (PsiACE) +- Parse Sqllogictest column types from physical schema [#11929](https://github.com/apache/datafusion/pull/11929) (jonahgao) +- Update INITCAP scalar function to support Utf8View [#11888](https://github.com/apache/datafusion/pull/11888) (xinlifoobar) +- Implement native support StringView for Octet Length [#11906](https://github.com/apache/datafusion/pull/11906) (PsiACE) +- Implement native support StringView for Ends With [#11924](https://github.com/apache/datafusion/pull/11924) (PsiACE) +- Implement native support StringView for Levenshtein [#11925](https://github.com/apache/datafusion/pull/11925) (PsiACE) +- Implement native stringview support for BTRIM [#11920](https://github.com/apache/datafusion/pull/11920) (Kev1n8) +- Move `LimitPushdown` to physical-optimizer crate [#11945](https://github.com/apache/datafusion/pull/11945) (lewiszlw) +- Minor: Improve comments in row_hash.rs for skipping aggregation [#11820](https://github.com/apache/datafusion/pull/11820) (alamb) +- chore: Add SessionState to MockContextProvider just like SessionContextProvider [#11940](https://github.com/apache/datafusion/pull/11940) (dharanad) +- Update labeler.yml to match crates [#11937](https://github.com/apache/datafusion/pull/11937) (alamb) +- Support tuples as types [#11896](https://github.com/apache/datafusion/pull/11896) (samuelcolvin) +- Support `convert_to_state` for `AVG` accumulator [#11734](https://github.com/apache/datafusion/pull/11734) (alamb) +- minor: Update release documentation based on 41.0.0 release [#11947](https://github.com/apache/datafusion/pull/11947) (andygrove) +- Make `Precision` copy to make it clear clones are not expensive [#11828](https://github.com/apache/datafusion/pull/11828) (alamb) +- Minor: simplify SQL number parsing and add a comment about unused [#11965](https://github.com/apache/datafusion/pull/11965) (alamb) +- Support Arrays for the Map scalar functions [#11712](https://github.com/apache/datafusion/pull/11712) (dharanad) +- Implement Utf8View for lpad scalar function [#11941](https://github.com/apache/datafusion/pull/11941) (Omega359) +- Add native stringview support for LTRIM & RTRIM [#11948](https://github.com/apache/datafusion/pull/11948) (Kev1n8) +- Move wildcard expansions to the analyzer [#11681](https://github.com/apache/datafusion/pull/11681) (goldmedal) +- Add native stringview support for RIGHT [#11955](https://github.com/apache/datafusion/pull/11955) (Kev1n8) +- Register get_field by default [#11959](https://github.com/apache/datafusion/pull/11959) (leoyvens) +- Refactor `CoalesceBatches` to use an explicit state machine [#11966](https://github.com/apache/datafusion/pull/11966) (berkaysynnada) +- Implement native support StringView for find in set [#11970](https://github.com/apache/datafusion/pull/11970) (PsiACE) +- test: re-enable window function over parquet with forced collisions [#11939](https://github.com/apache/datafusion/pull/11939) (korowa) +- Implement native support StringView for `REPEAT` [#11962](https://github.com/apache/datafusion/pull/11962) (tlm365) +- Update RPAD scalar function to support Utf8View [#11942](https://github.com/apache/datafusion/pull/11942) (Lordworms) +- Improve lpad udf by using a GenericStringBuilder [#11987](https://github.com/apache/datafusion/pull/11987) (Omega359) +- Implement native support StringView for substr_index [#11974](https://github.com/apache/datafusion/pull/11974) (PsiACE) +- Add native stringview support for LEFT [#11976](https://github.com/apache/datafusion/pull/11976) (Kev1n8) +- Minor: Improve function documentation [#11996](https://github.com/apache/datafusion/pull/11996) (alamb) +- Implement native support StringView for overlay [#11968](https://github.com/apache/datafusion/pull/11968) (PsiACE) +- Keep the existing default catalog for `SessionStateBuilder::new_from_existing` [#11991](https://github.com/apache/datafusion/pull/11991) (goldmedal) +- Use tracked-consumers memory pool be the default. [#11949](https://github.com/apache/datafusion/pull/11949) (wiedld) +- Update REVERSE scalar function to support Utf8View [#11973](https://github.com/apache/datafusion/pull/11973) (Omega359) +- Support partial aggregation skip for boolean functions [#11847](https://github.com/apache/datafusion/pull/11847) (2010YOUY01) +- feat/11953: Support StringView for TRANSLATE() fn [#11967](https://github.com/apache/datafusion/pull/11967) (devanbenz) +- Update SPLIT_PART scalar function to support Utf8View [#11975](https://github.com/apache/datafusion/pull/11975) (Lordworms) +- Handle arguments checking of `min`/`max` function to avoid crashes [#12016](https://github.com/apache/datafusion/pull/12016) (tlm365) +- Fix: support NULL input for regular expression comparison operations [#11985](https://github.com/apache/datafusion/pull/11985) (HuSen8891) +- Remove physical sort parameters on aggregate window functions [#12009](https://github.com/apache/datafusion/pull/12009) (timsaucer) +- Minor: Use execution error in ScalarValue::iter_to_array for incorrect usage [#11999](https://github.com/apache/datafusion/pull/11999) (jayzhan211) +- Fix: support NULL input for like operations [#12025](https://github.com/apache/datafusion/pull/12025) (HuSen8891) +- Minor: Add error tests for min/max with 2 arguments [#12024](https://github.com/apache/datafusion/pull/12024) (alamb) +- Improve performance of REPEAT functions [#12015](https://github.com/apache/datafusion/pull/12015) (tlm365) +- Update SUBSTR scalar function to support Utf8View [#12019](https://github.com/apache/datafusion/pull/12019) (dmitrybugakov) +- Minor: Remove wrong comment on `Accumulator::evaluate` and `Accumulator::state` [#12001](https://github.com/apache/datafusion/pull/12001) (lewiszlw) +- Minor: cleanup `.gitignore` [#12035](https://github.com/apache/datafusion/pull/12035) (alamb) +- Improve documentation about `ParquetExec` / Parquet predicate pushdown [#11994](https://github.com/apache/datafusion/pull/11994) (alamb) +- refactor: Move `LimitedDistinctAggregation` to `physical-optimizer` crate [#12036](https://github.com/apache/datafusion/pull/12036) (Weijun-H) +- Convert built-in `row_number` to user-defined window function [#12030](https://github.com/apache/datafusion/pull/12030) (jcsherin) +- Fix projection name with DataFrame::with_column and window functions [#12000](https://github.com/apache/datafusion/pull/12000) (devanbenz) +- Update to `sqlparser-rs` v0.50.0 [#12014](https://github.com/apache/datafusion/pull/12014) (samuelcolvin) +- Minor: make some physical-plan properties public [#12022](https://github.com/apache/datafusion/pull/12022) (emgeee) +- chore: improve variable naming conventions [#12042](https://github.com/apache/datafusion/pull/12042) (caicancai) +- Fix: handle NULL input for regex match operations [#12028](https://github.com/apache/datafusion/pull/12028) (HuSen8891) +- Fix compilation, change row_number() expr_fn to 0 args [#12043](https://github.com/apache/datafusion/pull/12043) (alamb) +- Minor: Remove warning when building `datafusion-cli` from `Dockerfile` [#12018](https://github.com/apache/datafusion/pull/12018) (tlm365) +- Minor: add getter method for LogicalPlanBuilder.plan [#12038](https://github.com/apache/datafusion/pull/12038) (emgeee) +- Window UDF signature check [#12045](https://github.com/apache/datafusion/pull/12045) (jayzhan211) +- Fix: generate_series function support string type [#12002](https://github.com/apache/datafusion/pull/12002) (getChan) +- Do not add redundant subquery ordering into plan [#12003](https://github.com/apache/datafusion/pull/12003) (mertak-synnada) +- Fix: Remove Unrelated Fields When Expanding Wildcards in Functional Dependency Projections [#12060](https://github.com/apache/datafusion/pull/12060) (berkaysynnada) +- Update async-trait in CLI and catalog crates [#12061](https://github.com/apache/datafusion/pull/12061) (findepi) +- Minor: remove clones in `coerce_plan_expr_for_schema` [#12051](https://github.com/apache/datafusion/pull/12051) (jonahgao) +- implement utf8_view for replace [#12004](https://github.com/apache/datafusion/pull/12004) (thinh2) +- Minor: update sqllogictest to treat Utf8View as text [#12033](https://github.com/apache/datafusion/pull/12033) (alamb) +- [MINOR] correct document mistakes [#12068](https://github.com/apache/datafusion/pull/12068) (FANNG1) +- Plan `LATERAL` subqueries [#11456](https://github.com/apache/datafusion/pull/11456) (aalexandrov) +- Faster random() scalar function [#12078](https://github.com/apache/datafusion/pull/12078) (2010YOUY01) +- functions: support strpos with mixed string types [#12072](https://github.com/apache/datafusion/pull/12072) (nrc) +- Update to `clap` 4.5.16 [#12064](https://github.com/apache/datafusion/pull/12064) (findepi) +- Fix the schema mismatch between logical and physical for aggregate function, add `AggregateUDFImpl::is_null` [#11989](https://github.com/apache/datafusion/pull/11989) (jayzhan211) +- minor: Remove unused create_row_hashes [#12083](https://github.com/apache/datafusion/pull/12083) (andygrove) +- Improve rpad udf by using a GenericStringBuilder [#12070](https://github.com/apache/datafusion/pull/12070) (Lordworms) +- Add test to verify count aggregate function should not be nullable [#12100](https://github.com/apache/datafusion/pull/12100) (HuSen8891) +- Minor: Extract `BatchCoalescer` to its own module [#12047](https://github.com/apache/datafusion/pull/12047) (alamb) +- Add Utf8View support to STRPOS function [#12087](https://github.com/apache/datafusion/pull/12087) (dmitrybugakov) +- Update itertools requirement from 0.12 to 0.13 [#10556](https://github.com/apache/datafusion/pull/10556) (dependabot[bot]) +- Fix wildcard expansion for `HAVING` clause [#12046](https://github.com/apache/datafusion/pull/12046) (goldmedal) +- Convert LogicalPlanBuilder to use Arc [#12040](https://github.com/apache/datafusion/pull/12040) (jc4x4) +- Minor: rename `dictionary_coercion` to `dictionary_comparison_coercion`, add comments [#12102](https://github.com/apache/datafusion/pull/12102) (alamb) +- Improve documentation on `StringArrayType` trait [#12027](https://github.com/apache/datafusion/pull/12027) (alamb) +- Improve split_part udf by using a GenericStringBuilder [#12093](https://github.com/apache/datafusion/pull/12093) (Lordworms) +- Fix compilation on main [#12108](https://github.com/apache/datafusion/pull/12108) (alamb) +- minor: SortExec measure elapsed_compute time when sorting [#12099](https://github.com/apache/datafusion/pull/12099) (mhilton) +- Support string concat `||` for StringViewArray [#12063](https://github.com/apache/datafusion/pull/12063) (dharanad) +- Minor: make RowNumber public [#12110](https://github.com/apache/datafusion/pull/12110) (berkaysynnada) +- Add benchmark for SUBSTR to evaluate improvements using StringView [#12111](https://github.com/apache/datafusion/pull/12111) (Kev1n8) +- [minor] Use Vec instead of primitive builders [#12121](https://github.com/apache/datafusion/pull/12121) (Dandandan) +- Fix thread panic when "unreachable" SpawnedTask code is reachable. [#12086](https://github.com/apache/datafusion/pull/12086) (wiedld) +- Improve `CombinePartialFinalAggregate` code readability [#12128](https://github.com/apache/datafusion/pull/12128) (lewiszlw) +- Use `LexRequirement` alias as much as possible [#12130](https://github.com/apache/datafusion/pull/12130) (lewiszlw) +- `array_has` avoid row converter for string type [#12097](https://github.com/apache/datafusion/pull/12097) (jayzhan211) +- Throw `not_impl_error` for `approx_percentile_cont` parameters validation [#12133](https://github.com/apache/datafusion/pull/12133) (goldmedal) +- minor: Add comments for `GroupedHashAggregateStream` struct [#12127](https://github.com/apache/datafusion/pull/12127) (2010YOUY01) +- fix concat dictionary(int32, utf8) bug [#12143](https://github.com/apache/datafusion/pull/12143) (thinh2) +- `array_has` with eq kernel [#12125](https://github.com/apache/datafusion/pull/12125) (jayzhan211) +- Check for overflow in substring with negative start [#12141](https://github.com/apache/datafusion/pull/12141) (findepi) +- Minor: add test for panic propagation [#12134](https://github.com/apache/datafusion/pull/12134) (alamb) +- Add benchmark for STDDEV and VAR to Clickbench extended [#12146](https://github.com/apache/datafusion/pull/12146) (alamb) +- Use Result.unwrap_or_else where applicable [#12166](https://github.com/apache/datafusion/pull/12166) (findepi) +- Provide documentation of expose APIs to enable handling of type coercion at UNION plan construction. [#12142](https://github.com/apache/datafusion/pull/12142) (wiedld) +- Implement groups accumulator for stddev and variance [#12095](https://github.com/apache/datafusion/pull/12095) (eejbyfeldt) +- Minor: refine Partitioning documentation [#12145](https://github.com/apache/datafusion/pull/12145) (alamb) +- Minor: allow to build RuntimeEnv from RuntimeConfig [#12151](https://github.com/apache/datafusion/pull/12151) (theirix) +- benches: add lower benches for stringview [#12152](https://github.com/apache/datafusion/pull/12152) (tshauck) +- Replace Arc::try_unwrap with Arc::unwrap_or_clone where cloning anyway [#12173](https://github.com/apache/datafusion/pull/12173) (findepi) +- Enable the test for creating empty map [#12176](https://github.com/apache/datafusion/pull/12176) (goldmedal) +- Remove unwrap_arc helper [#12172](https://github.com/apache/datafusion/pull/12172) (findepi) +- Fix typo [#12169](https://github.com/apache/datafusion/pull/12169) (findepi) +- Minor: remove vestigal github workflow `pr_comment.yml` [#12182](https://github.com/apache/datafusion/pull/12182) (alamb) +- Remove `AggregateExpr` trait [#12096](https://github.com/apache/datafusion/pull/12096) (lewiszlw) +- Deduplicate sort unparsing logic [#12175](https://github.com/apache/datafusion/pull/12175) (findepi) +- Require sort expressions to be of type Sort [#12171](https://github.com/apache/datafusion/pull/12171) (findepi) +- Remove `parse_vec_expr` helper [#12178](https://github.com/apache/datafusion/pull/12178) (findepi) +- Reuse bulk serialization helpers for protobuf [#12179](https://github.com/apache/datafusion/pull/12179) (findepi) +- Remove unnecessary clones from `.../logical_plan/builder.rs` [#12196](https://github.com/apache/datafusion/pull/12196) (findepi) +- Remove unnecessary clones with clippy [#12197](https://github.com/apache/datafusion/pull/12197) (findepi) +- Make RuntimeEnvBuilder rather than RuntimeConfig [#12157](https://github.com/apache/datafusion/pull/12157) (devanbenz) +- Minor: Fix grouping set typo [#12216](https://github.com/apache/datafusion/pull/12216) (lewiszlw) +- Unbounded SortExec (and Top-K) Implementation When Req's Are Satisfied [#12174](https://github.com/apache/datafusion/pull/12174) (berkaysynnada) +- Remove normalize_with_schemas function [#12233](https://github.com/apache/datafusion/pull/12233) (findepi) +- Update AWS dependencies in CLI [#12229](https://github.com/apache/datafusion/pull/12229) (findepi) +- Avoid Arc::clone when serializing physical expressions [#12235](https://github.com/apache/datafusion/pull/12235) (findepi) +- Confirming UDF aliases are serialized correctly [#12219](https://github.com/apache/datafusion/pull/12219) (edmondop) +- Minor: Reuse `NamePreserver` in `SimplifyExpressions` [#12238](https://github.com/apache/datafusion/pull/12238) (jonahgao) +- Remove redundant argument and improve error message [#12217](https://github.com/apache/datafusion/pull/12217) (findepi) +- Remove deprecated from_plan function [#12265](https://github.com/apache/datafusion/pull/12265) (findepi) +- Remove redundant result of `AggregateFunctionExpr::field` [#12258](https://github.com/apache/datafusion/pull/12258) (lewiszlw) +- Define current arrow_cast behavior for BinaryView [#12200](https://github.com/apache/datafusion/pull/12200) (wiedld) +- Update prost dependency [#12237](https://github.com/apache/datafusion/pull/12237) (findepi) +- Bump webpack from 5.88.2 to 5.94.0 in /datafusion/wasmtest/datafusion-wasm-app [#12236](https://github.com/apache/datafusion/pull/12236) (dependabot[bot]) +- Avoid redundant pass-by-value in physical optimizer [#12261](https://github.com/apache/datafusion/pull/12261) (findepi) +- Remove FileScanConfig::repartition_file_groups function [#12242](https://github.com/apache/datafusion/pull/12242) (findepi) +- Make group expressions nullable more accurate [#12256](https://github.com/apache/datafusion/pull/12256) (lewiszlw) +- Avoid redundant pass-by-value in optimizer [#12262](https://github.com/apache/datafusion/pull/12262) (findepi) +- Support alternate format for Date32 unparsing (TEXT/SQLite) [#12282](https://github.com/apache/datafusion/pull/12282) (sgrebnov) +- Extract drive-by fixes from PR 12135 for easier reviewing [#12240](https://github.com/apache/datafusion/pull/12240) (itsjunetime) +- Move `CombinePartialFinalAggregate` rule into physical-optimizer crate [#12167](https://github.com/apache/datafusion/pull/12167) (lewiszlw) +- Minor: Add `RuntimeEnvBuilder::build_arc() [#12213](https://github.com/apache/datafusion/pull/12213) (alamb) +- Introduce `Signature::Coercible` [#12275](https://github.com/apache/datafusion/pull/12275) (jayzhan211) +- fix hash-repartition panic [#12297](https://github.com/apache/datafusion/pull/12297) (thinh2) +- Remove unsafe Send impl from PriorityMap [#12289](https://github.com/apache/datafusion/pull/12289) (findepi) +- test: check record count and types in parquet window test [#12277](https://github.com/apache/datafusion/pull/12277) (korowa) +- Optimize `struct` and `named_struct` functions [#11688](https://github.com/apache/datafusion/pull/11688) (Rafferty97) +- Update the CONCAT scalar function to support Utf8View [#12224](https://github.com/apache/datafusion/pull/12224) (devanbenz) +- chore: Fix warnings produced by shellcheck on bench.sh [#12303](https://github.com/apache/datafusion/pull/12303) (eejbyfeldt) +- test: re-enable merge join test with forced collisions [#12276](https://github.com/apache/datafusion/pull/12276) (korowa) +- Fix various typos in aggregation doc [#12301](https://github.com/apache/datafusion/pull/12301) (lewiszlw) +- Improve binary scalars display [#12192](https://github.com/apache/datafusion/pull/12192) (lewiszlw) +- Minor: Reduce string allocations in ScalarValue::binary display [#12322](https://github.com/apache/datafusion/pull/12322) (alamb) +- minor: Add PartialEq, Eq traits to StatsType [#12327](https://github.com/apache/datafusion/pull/12327) (andygrove) +- Update to `arrow`/`parquet` `53.0.0`, `tonic`, `prost`, `object_store`, `pyo3` [#12032](https://github.com/apache/datafusion/pull/12032) (alamb) +- Minor: Update Sanity Checker Error Messages [#12333](https://github.com/apache/datafusion/pull/12333) (berkaysynnada) +- Improve & unify validation in LogicalPlan::with_new_exprs [#12264](https://github.com/apache/datafusion/pull/12264) (findepi) +- Support the custom terminator for the CSV file format [#12263](https://github.com/apache/datafusion/pull/12263) (goldmedal) +- Support try_from_array and eq_array for ScalarValue::Union [#12208](https://github.com/apache/datafusion/pull/12208) (joroKr21) +- Fix some clippy warnings [#12346](https://github.com/apache/datafusion/pull/12346) (mbrobbel) +- minor: reuse SessionStateBuilder methods for default builder [#12330](https://github.com/apache/datafusion/pull/12330) (comphead) +- Push down null filters for more join types [#12348](https://github.com/apache/datafusion/pull/12348) (Dandandan) +- Move `TopKAggregation` rule into `physical-optimizer` crate [#12334](https://github.com/apache/datafusion/pull/12334) (lewiszlw) +- Support Utf8View and BinaryView in substrait serialization. [#12199](https://github.com/apache/datafusion/pull/12199) (wiedld) +- Fix Possible Congestion Scenario in `SortPreservingMergeExec` [#12302](https://github.com/apache/datafusion/pull/12302) (berkaysynnada) +- Minor: Re-export variable provider [#12351](https://github.com/apache/datafusion/pull/12351) (lewiszlw) +- Support protobuf encoding and decoding of `UnnestExec` [#12344](https://github.com/apache/datafusion/pull/12344) (joroKr21) +- Fix subquery alias table definition unparsing for SQLite [#12331](https://github.com/apache/datafusion/pull/12331) (sgrebnov) +- Remove deprecated ScalarValue::get_datatype [#12361](https://github.com/apache/datafusion/pull/12361) (findepi) +- Improve StringView support for SUBSTR [#12044](https://github.com/apache/datafusion/pull/12044) (Kev1n8) +- Minor: improve performance of `ScalarValue::Binary*` debug [#12323](https://github.com/apache/datafusion/pull/12323) (alamb) +- Apply non-nested kernel for non-nested in `array_has` and `inlist` [#12164](https://github.com/apache/datafusion/pull/12164) (jayzhan211) +- Faster `character_length()` string function for ASCII-only case [#12356](https://github.com/apache/datafusion/pull/12356) (2010YOUY01) +- Unparse TableScan with projections, filters or fetch to SQL string [#12158](https://github.com/apache/datafusion/pull/12158) (goldmedal) +- Minor: Support protobuf serialization for Utf8View and BinaryView [#12165](https://github.com/apache/datafusion/pull/12165) (Lordworms) +- Minor: Add tests for using FilterExec when parquet was pushed down [#12362](https://github.com/apache/datafusion/pull/12362) (alamb) +- Minor: Add getter for logical optimizer rules [#12379](https://github.com/apache/datafusion/pull/12379) (maronavenue) +- Update sqllogictest requirement from 0.21.0 to 0.22.0 [#12388](https://github.com/apache/datafusion/pull/12388) (dependabot[bot]) +- Support StringView for binary operators [#12212](https://github.com/apache/datafusion/pull/12212) (tlm365) +- Support for SIMILAR TO for physical plan [#12350](https://github.com/apache/datafusion/pull/12350) (theirix) +- Remove deprecated expression optimizer's utils [#12390](https://github.com/apache/datafusion/pull/12390) (findepi) +- Minor: Remove redundant usage of clone [#12392](https://github.com/apache/datafusion/pull/12392) (waruto210) +- Introduce the `DynamicFileCatalog` in `datafusion-catalog` [#11035](https://github.com/apache/datafusion/pull/11035) (goldmedal) +- tests: enable fuzz for filtered anti-semi NLJoin [#12360](https://github.com/apache/datafusion/pull/12360) (korowa) +- Refactor `SqlToRel::sql_expr_to_logical_expr_internal` to reduce stack size [#12384](https://github.com/apache/datafusion/pull/12384) (Jefffrey) +- Reuse `spill_record_batch_by_size` function [#12389](https://github.com/apache/datafusion/pull/12389) (lewiszlw) +- minor: improve join fuzz tests debug kit [#12397](https://github.com/apache/datafusion/pull/12397) (comphead) +- Fix invalid ref in UserDefinedLogicalNodeCore doc [#12396](https://github.com/apache/datafusion/pull/12396) (dttung2905) +- Don't push down IsNotNull for `null_equals_null` case [#12404](https://github.com/apache/datafusion/pull/12404) (Dandandan) +- Fix: `substr()` on StringView column's behavior is inconsistent with the old version [#12383](https://github.com/apache/datafusion/pull/12383) (2010YOUY01) +- validate and adjust Substrait NamedTable schemas (#12223) [#12245](https://github.com/apache/datafusion/pull/12245) (vbarua) +- Bump rstest from 0.17.0 to 0.22.0 in /datafusion-cli [#12413](https://github.com/apache/datafusion/pull/12413) (dependabot[bot]) +- fix tpc-h parquet setting to respect global options [#12405](https://github.com/apache/datafusion/pull/12405) (XiangpengHao) +- Bump dirs from 4.0.0 to 5.0.1 in /datafusion-cli [#12411](https://github.com/apache/datafusion/pull/12411) (dependabot[bot]) +- Allow using dictionary arrays as filters [#12382](https://github.com/apache/datafusion/pull/12382) (adriangb) +- Add support for Utf8View, Boolean, Date32/64, int32/64 for writing hive style partitions [#12283](https://github.com/apache/datafusion/pull/12283) (Omega359) +- Bump env_logger from 0.9.3 to 0.11.5 in /datafusion-cli [#12410](https://github.com/apache/datafusion/pull/12410) (dependabot[bot]) +- Check window functions by str for with_column [#12431](https://github.com/apache/datafusion/pull/12431) (timsaucer) +- Fix incorrect OFFSET during LIMIT pushdown. [#12399](https://github.com/apache/datafusion/pull/12399) (wiedld) +- Fix: upper case qualifier wildcard bug [#12426](https://github.com/apache/datafusion/pull/12426) (JasonLi-cn) +- Fix: Internal error in regexp_replace() for some StringView input [#12203](https://github.com/apache/datafusion/pull/12203) (devanbenz) +- Automate sqllogictest for StringView (for one function, `substr`) [#12433](https://github.com/apache/datafusion/pull/12433) (2010YOUY01) +- Update concat_ws scalar function to support Utf8View [#12309](https://github.com/apache/datafusion/pull/12309) (devanbenz) +- Bump serve-static and express in /datafusion/wasmtest/datafusion-wasm-app [#12434](https://github.com/apache/datafusion/pull/12434) (dependabot[bot]) +- Minor: add err on `create` `temporary` table [#12439](https://github.com/apache/datafusion/pull/12439) (hailelagi) +- Minor: Add a test for version() function [#12441](https://github.com/apache/datafusion/pull/12441) (alamb) +- Handle case-sensitive identifier when decorrelating predicate subquery [#12443](https://github.com/apache/datafusion/pull/12443) (goldmedal) +- Bump send and express in /datafusion/wasmtest/datafusion-wasm-app [#12447](https://github.com/apache/datafusion/pull/12447) (dependabot[bot]) +- Add PartialOrd for the DF subfields/structs for the WindowFunction expr [#12421](https://github.com/apache/datafusion/pull/12421) (ngli-me) +- Making avro_to_arrow::schema::to_arrow_schema public [#12452](https://github.com/apache/datafusion/pull/12452) (ameyc) +- Bump rustyline from 11.0.0 to 14.0.0 in /datafusion-cli [#12407](https://github.com/apache/datafusion/pull/12407) (dependabot[bot]) +- Minor: add `ListingOptions::with_file_extension_opt` [#12461](https://github.com/apache/datafusion/pull/12461) (alamb) +- Improve PhysicalExpr and Column documentation [#12457](https://github.com/apache/datafusion/pull/12457) (alamb) +- fix length error with `array_has` [#12459](https://github.com/apache/datafusion/pull/12459) (samuelcolvin) + +## Credits + +Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. + +``` + 39 Andrew Lamb + 35 Piotr Findeisen + 15 张林伟 + 10 Jax Liu + 9 Jay Zhan + 9 Jonah Gao + 9 dependabot[bot] + 8 wiedld + 7 Chojan Shang + 7 WeblWabl + 7 Yongting You + 6 Berkay Şahin + 6 Eduard Karacharov + 6 Tai Le Manh + 6 kf zheng + 5 Alex Huang + 5 Bruce Ritchie + 5 Lordworms + 5 Samuel Colvin + 4 Andy Grove + 4 Dharan Aditya + 4 HuSen + 4 Huaijin + 3 Arttu + 3 Austin Liu + 3 Daniël Heres + 3 Dmitry Bugakov + 3 Emil Ejbyfeldt + 3 Georgi Krastev + 3 JC + 3 Oleks V + 3 Trent Hauck + 3 iamthinh + 2 Athul T R + 2 June + 2 Liang-Chi Hsieh + 2 Martin Hilton + 2 Matt Green + 2 Sergei Grebnov + 2 Tim Saucer + 2 Xiangpeng Hao + 2 Xin Li + 2 jcsherin + 2 theirix + 1 Adrian Garcia Badaracco + 1 Albert Skalt + 1 Alexander Alexandrov + 1 Alexander Rafferty + 1 Amey Chaugule + 1 Cancai Cai + 1 Dao Thanh Tung + 1 Edmondo Porcu + 1 FANNG + 1 Haile + 1 JasonLi + 1 Jeffrey Vo + 1 Leonardo Yvens + 1 Maron Montano + 1 Martin Kolb + 1 Matthijs Brobbel + 1 Michael J Ward + 1 Mustafa Akur + 1 Namgung Chan + 1 Nick Cameron + 1 Peter Toth + 1 Phillip LeBlanc + 1 Victor Barua + 1 YjyJeff + 1 mertak-synnada + 1 ngli-me + 1 peasee + 1 waruto + 1 yfu +``` + +Thank you also to everyone who contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. diff --git a/dev/changelog/42.1.0.md b/dev/changelog/42.1.0.md new file mode 100644 index 000000000000..cf4f911150ac --- /dev/null +++ b/dev/changelog/42.1.0.md @@ -0,0 +1,42 @@ + + +# Apache DataFusion 42.1.0 Changelog + +This release consists of 5 commits from 4 contributors. See credits at the end of this changelog for more information. + +**Other:** + +- Backport update to arrow 53.1.0 on branch-42 [#12977](https://github.com/apache/datafusion/pull/12977) (alamb) +- Backport "Provide field and schema metadata missing on cross joins, and union with null fields" (#12729) [#12974](https://github.com/apache/datafusion/pull/12974) (matthewmturner) +- Backport "physical-plan: Cast nested group values back to dictionary if necessary" (#12586) [#12976](https://github.com/apache/datafusion/pull/12976) (matthewmturner) +- backport-to-DF-42: Provide field and schema metadata missing on distinct aggregations [#12975](https://github.com/apache/datafusion/pull/12975) (Xuanwo) + +## Credits + +Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. + +``` + 2 Matthew Turner + 1 Andrew Lamb + 1 Andy Grove + 1 Xuanwo +``` + +Thank you also to everyone who contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. diff --git a/dev/changelog/42.2.0.md b/dev/changelog/42.2.0.md new file mode 100644 index 000000000000..6c907162c65e --- /dev/null +++ b/dev/changelog/42.2.0.md @@ -0,0 +1,37 @@ + + +# Apache DataFusion 42.2.0 Changelog + +This release consists of 1 commits from 1 contributor. See credits at the end of this changelog for more information. + +**Other:** + +- +- Backport config option `skip_physical_aggregate_schema_check` #13176 to 42 [#13189](https://github.com/apache/datafusion/pull/13189) (alamb) + +## Credits + +Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. + +``` + 1 Andrew Lamb +``` + +Thank you also to everyone who contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. diff --git a/dev/changelog/43.0.0.md b/dev/changelog/43.0.0.md new file mode 100644 index 000000000000..e1fcc55b4b91 --- /dev/null +++ b/dev/changelog/43.0.0.md @@ -0,0 +1,545 @@ + + +# Apache DataFusion 43.0.0 Changelog + +This release consists of 403 commits from 96 contributors. See credits at the end of this changelog for more information. + +**Breaking changes:** + +- Remove Arc wrapping from create_udf's return_type [#12489](https://github.com/apache/datafusion/pull/12489) (findepi) +- Make make_scalar_function() result candidate for inlining, by removing the `Arc` [#12477](https://github.com/apache/datafusion/pull/12477) (findepi) +- Bump MSRV to 1.78 [#12398](https://github.com/apache/datafusion/pull/12398) (comphead) +- fix: DataFusion panics with "No candidates provided" [#12469](https://github.com/apache/datafusion/pull/12469) (Weijun-H) +- Implement PartialOrd for Expr and sub fields/structs without using hash values [#12481](https://github.com/apache/datafusion/pull/12481) (ngli-me) +- Add `field` trait method to `WindowUDFImpl`, remove `return_type`/`nullable` [#12374](https://github.com/apache/datafusion/pull/12374) (jcsherin) +- parquet: Make page_index/pushdown metrics consistent with row_group metrics [#12545](https://github.com/apache/datafusion/pull/12545) (progval) +- Make SessionContext::enable_url_table consume self [#12573](https://github.com/apache/datafusion/pull/12573) (alamb) +- LexRequirement as a struct, instead of a type [#12583](https://github.com/apache/datafusion/pull/12583) (ngli-me) +- Require `Debug` for `AnalyzerRule`, `FunctionRewriter`, and `OptimizerRule` [#12556](https://github.com/apache/datafusion/pull/12556) (alamb) +- Require `Debug` for `TableProvider`, `TableProviderFactory` and `PartitionStream` [#12557](https://github.com/apache/datafusion/pull/12557) (alamb) +- Require `Debug` for `PhysicalOptimizerRule` [#12624](https://github.com/apache/datafusion/pull/12624) (AnthonyZhOon) +- Rename aggregation modules, GroupColumn [#12619](https://github.com/apache/datafusion/pull/12619) (alamb) +- Update `register_table` functions args to take `Into` [#12630](https://github.com/apache/datafusion/pull/12630) (JasonLi-cn) +- Derive `Debug` for `SessionStateBuilder`, adding `Debug` requirements to fields [#12632](https://github.com/apache/datafusion/pull/12632) (AnthonyZhOon) +- Support REPLACE INTO for INSERT statements [#12516](https://github.com/apache/datafusion/pull/12516) (fmeringdal) +- Add `PartitionEvaluatorArgs` to `WindowUDFImpl::partition_evaluator` [#12804](https://github.com/apache/datafusion/pull/12804) (jcsherin) +- Convert `rank` / `dense_rank` and `percent_rank` builtin functions to UDWF [#12718](https://github.com/apache/datafusion/pull/12718) (jatin510) +- Bug-fix: MemoryExec sort expressions do NOT refer to the projected schema [#12876](https://github.com/apache/datafusion/pull/12876) (berkaysynnada) +- Minor: add flags for temporary ddl [#12561](https://github.com/apache/datafusion/pull/12561) (hailelagi) +- Convert `BuiltInWindowFunction::{Lead, Lag}` to a user defined window function [#12857](https://github.com/apache/datafusion/pull/12857) (jcsherin) +- Improve performance for physical plan creation with many columns [#12950](https://github.com/apache/datafusion/pull/12950) (askalt) +- Improve recursive `unnest` options API [#12836](https://github.com/apache/datafusion/pull/12836) (duongcongtoai) +- fix(substrait): disallow union with a single input [#13023](https://github.com/apache/datafusion/pull/13023) (tokoko) +- feat: support arbitrary expressions in `LIMIT` plan [#13028](https://github.com/apache/datafusion/pull/13028) (jonahgao) +- Remove unused `LogicalPlan::CrossJoin` as it is unused [#13076](https://github.com/apache/datafusion/pull/13076) (buraksenn) +- Minor: make `Expr::volatile` infallible [#13206](https://github.com/apache/datafusion/pull/13206) (alamb) +- Convert LexOrdering `type` to `struct`. [#13146](https://github.com/apache/datafusion/pull/13146) (ngli-me) + +**Implemented enhancements:** + +- feat(unparser): adding alias for table scan filter in sql unparser [#12453](https://github.com/apache/datafusion/pull/12453) (Lordworms) +- feat(substrait): set ProjectRel output_mapping in producer [#12495](https://github.com/apache/datafusion/pull/12495) (vbarua) +- feat:Support applying parquet bloom filters to StringView columns [#12503](https://github.com/apache/datafusion/pull/12503) (my-vegetable-has-exploded) +- feat: Support adding a single new table factory to SessionStateBuilder [#12563](https://github.com/apache/datafusion/pull/12563) (Weijun-H) +- feat(planner): Allowing setting sort order of parquet files without specifying the schema [#12466](https://github.com/apache/datafusion/pull/12466) (devanbenz) +- feat: add support for Substrait ExtendedExpression [#12728](https://github.com/apache/datafusion/pull/12728) (westonpace) +- feat(substrait): add intersect support to consumer [#12830](https://github.com/apache/datafusion/pull/12830) (tokoko) +- feat: Implement grouping function using grouping id [#12704](https://github.com/apache/datafusion/pull/12704) (eejbyfeldt) +- feat(substrait): add set operations to consumer, update substrait to `0.45.0` [#12863](https://github.com/apache/datafusion/pull/12863) (tokoko) +- feat(substrait): add wildcard handling to producer [#12987](https://github.com/apache/datafusion/pull/12987) (tokoko) +- feat: Add regexp_count function [#12970](https://github.com/apache/datafusion/pull/12970) (Omega359) +- feat: Decorrelate more predicate subqueries [#12945](https://github.com/apache/datafusion/pull/12945) (eejbyfeldt) +- feat: Run (logical) optimizers on subqueries [#13066](https://github.com/apache/datafusion/pull/13066) (eejbyfeldt) +- feat: Convert CumeDist to UDWF [#13051](https://github.com/apache/datafusion/pull/13051) (jonathanc-n) +- feat: Migrate Map Functions [#13047](https://github.com/apache/datafusion/pull/13047) (jonathanc-n) +- feat: improve type inference for WindowFrame [#13059](https://github.com/apache/datafusion/pull/13059) (notfilippo) +- feat: Move subquery check from analyzer to PullUpCorrelatedExpr (Fix TPC-DS q41) [#13091](https://github.com/apache/datafusion/pull/13091) (eejbyfeldt) +- feat: Add `Date32`/`Date64` in aggregate fuzz testing [#13041](https://github.com/apache/datafusion/pull/13041) (LeslieKid) +- feat(substrait): support order_by in aggregate functions [#13114](https://github.com/apache/datafusion/pull/13114) (bvolpato) +- feat: Support Substrait's IntervalCompound type/literal instead of interval-month-day-nano UDT [#12112](https://github.com/apache/datafusion/pull/12112) (Blizzara) +- feat: Implement LeftMark join to fix subquery correctness issue [#13134](https://github.com/apache/datafusion/pull/13134) (eejbyfeldt) +- feat: support logical plan for `EXECUTE` statement [#13194](https://github.com/apache/datafusion/pull/13194) (jonahgao) +- feat(substrait): handle emit_kind when consuming Substrait plans [#13127](https://github.com/apache/datafusion/pull/13127) (vbarua) +- feat(substrait): AggregateRel grouping_expressions support [#13173](https://github.com/apache/datafusion/pull/13173) (akoshchiy) + +**Fixed bugs:** + +- fix: Panic/correctness issue in variance GroupsAccumulator [#12615](https://github.com/apache/datafusion/pull/12615) (eejbyfeldt) +- fix: coalesce schema issues [#12308](https://github.com/apache/datafusion/pull/12308) (mesejo) +- fix: Correct results for grouping sets when columns contain nulls [#12571](https://github.com/apache/datafusion/pull/12571) (eejbyfeldt) +- fix(substrait): remove optimize calls from substrait consumer [#12800](https://github.com/apache/datafusion/pull/12800) (tokoko) +- fix(substrait): consuming AggregateRel as last node [#12875](https://github.com/apache/datafusion/pull/12875) (tokoko) +- fix: Update TO_DATE, TO_TIMESTAMP scalar functions to support LargeUtf8, Utf8View [#12929](https://github.com/apache/datafusion/pull/12929) (Omega359) +- fix: Add Int32 type override for Dialects [#12916](https://github.com/apache/datafusion/pull/12916) (peasee) +- fix: using simple string match replace regex match for contains udf [#12931](https://github.com/apache/datafusion/pull/12931) (zhuliquan) +- fix: Dialect requires derived table alias [#12994](https://github.com/apache/datafusion/pull/12994) (peasee) +- fix: join swap for projected semi/anti joins [#13022](https://github.com/apache/datafusion/pull/13022) (korowa) +- fix: Verify supported type for Unary::Plus in sql planner [#13019](https://github.com/apache/datafusion/pull/13019) (eejbyfeldt) +- fix: Do NOT preserve names (aliases) of Exprs for simplification in TableScan filters [#13048](https://github.com/apache/datafusion/pull/13048) (eejbyfeldt) +- fix: planning of prepare statement with limit clause [#13088](https://github.com/apache/datafusion/pull/13088) (jonahgao) +- fix: add missing `NotExpr::evaluate_bounds` [#13082](https://github.com/apache/datafusion/pull/13082) (crepererum) +- fix: Order by mentioning missing column multiple times [#13158](https://github.com/apache/datafusion/pull/13158) (eejbyfeldt) +- fix: import JoinTestType without triggering unused_qualifications lint [#13170](https://github.com/apache/datafusion/pull/13170) (smarticen) +- fix: default UDWFImpl::expressions returns all expressions [#13169](https://github.com/apache/datafusion/pull/13169) (Michael-J-Ward) +- fix: date_bin() on timstamps before 1970 [#13204](https://github.com/apache/datafusion/pull/13204) (mhilton) +- fix: array_resize null fix [#13209](https://github.com/apache/datafusion/pull/13209) (jonathanc-n) +- fix: CSV Infer Schema now properly supports escaped characters. [#13214](https://github.com/apache/datafusion/pull/13214) (mnorfolk03) + +**Documentation updates:** + +- chore: Prepare 42.0.0 Release [#12465](https://github.com/apache/datafusion/pull/12465) (andygrove) +- Minor: improve ParquetOpener docs [#12456](https://github.com/apache/datafusion/pull/12456) (alamb) +- Improve doc wording around scalar authoring [#12478](https://github.com/apache/datafusion/pull/12478) (findepi) +- Minor: improve `GroupsAccumulator` docs [#12501](https://github.com/apache/datafusion/pull/12501) (alamb) +- Minor: improve `GroupsAccumulatorAdapter` docs [#12502](https://github.com/apache/datafusion/pull/12502) (alamb) +- Improve flamegraph profiling instructions [#12521](https://github.com/apache/datafusion/pull/12521) (alamb) +- docs: :memo: Add expected answers to `DataFrame` method examples [#12564](https://github.com/apache/datafusion/pull/12564) (Eason0729) +- parquet: Add finer metrics on operations covered by `time_elapsed_opening` [#12585](https://github.com/apache/datafusion/pull/12585) (progval) +- Update scalar_functions.md [#12627](https://github.com/apache/datafusion/pull/12627) (Abdullahsab3) +- Move `kurtosis_pop` to datafusion-functions-extra and out of core [#12647](https://github.com/apache/datafusion/pull/12647) (dharanad) +- Update introduction.md for `blaze` project [#12577](https://github.com/apache/datafusion/pull/12577) (liyuance) +- docs: improve the documentation for Aggregate code [#12617](https://github.com/apache/datafusion/pull/12617) (alamb) +- doc: Fix malformed hex string literal in user guide [#12708](https://github.com/apache/datafusion/pull/12708) (kawadakk) +- docs: Update DataFusion introduction to clarify that DataFusion does provide an "out of the box" query engine [#12666](https://github.com/apache/datafusion/pull/12666) (andygrove) +- Framework for generating function docs from embedded code documentation [#12668](https://github.com/apache/datafusion/pull/12668) (Omega359) +- Fix misformatted links on project index page [#12750](https://github.com/apache/datafusion/pull/12750) (amoeba) +- Add `DocumentationBuilder::with_standard_argument` to reduce copy/paste [#12747](https://github.com/apache/datafusion/pull/12747) (alamb) +- Minor: doc how field name is to be set for `WindowUDF` [#12757](https://github.com/apache/datafusion/pull/12757) (jcsherin) +- Port / Add Documentation for `VarianceSample` and `VariancePopulation` [#12742](https://github.com/apache/datafusion/pull/12742) (alamb) +- Transformed::new_transformed: Fix documentation formatting [#12787](https://github.com/apache/datafusion/pull/12787) (progval) +- Migrate documentation for all string functions from scalar_functions.md to code [#12775](https://github.com/apache/datafusion/pull/12775) (Omega359) +- Minor: add README to Catalog Folder [#12797](https://github.com/apache/datafusion/pull/12797) (jonathanc-n) +- Remove redundant aggregate/window/scalar function documentation [#12745](https://github.com/apache/datafusion/pull/12745) (alamb) +- Improve description of function migration [#12743](https://github.com/apache/datafusion/pull/12743) (alamb) +- Crypto Function Migration [#12840](https://github.com/apache/datafusion/pull/12840) (jonathanc-n) +- Minor: more doc to `MemoryPool` module [#12849](https://github.com/apache/datafusion/pull/12849) (2010YOUY01) +- Migrate documentation for all core functions from scalar_functions.md to code [#12854](https://github.com/apache/datafusion/pull/12854) (Omega359) +- Migrate documentation for Aggregate Functions to code [#12861](https://github.com/apache/datafusion/pull/12861) (jonathanc-n) +- Wordsmith project description [#12778](https://github.com/apache/datafusion/pull/12778) (matthewmturner) +- Migrate Regex Functions from static docs [#12886](https://github.com/apache/datafusion/pull/12886) (jonathanc-n) +- Migrate documentation for all math functions from scalar_functions.md to code [#12908](https://github.com/apache/datafusion/pull/12908) (juroberttyb) +- Combine the logic of rank, dense_rank and percent_rank udwf to reduce duplications [#12893](https://github.com/apache/datafusion/pull/12893) (jatin510) +- Migrate Array function Documentation to code [#12948](https://github.com/apache/datafusion/pull/12948) (jonathanc-n) +- Minor: fix Aggregation Docs from review [#12880](https://github.com/apache/datafusion/pull/12880) (jonathanc-n) +- Minor: expr-doc small fixes [#12960](https://github.com/apache/datafusion/pull/12960) (jonathanc-n) +- docs: Add documentation about conventional commits [#12971](https://github.com/apache/datafusion/pull/12971) (andygrove) +- Migrate datetime documentation to code [#12966](https://github.com/apache/datafusion/pull/12966) (jatin510) +- Fix CI on main ( regenerate function docs) [#12991](https://github.com/apache/datafusion/pull/12991) (alamb) +- Split output batches of joins that do not respect batch size [#12969](https://github.com/apache/datafusion/pull/12969) (alihan-synnada) +- Minor: Fixed regexpr_match docs [#13008](https://github.com/apache/datafusion/pull/13008) (jonathanc-n) +- Minor: Fix spelling in regexpr_count docs [#13014](https://github.com/apache/datafusion/pull/13014) (jonathanc-n) +- Update version to 42.1.0, add CHANGELOG (#12986) [#12989](https://github.com/apache/datafusion/pull/12989) (alamb) +- Added expresion to "with_standard_argument" [#12926](https://github.com/apache/datafusion/pull/12926) (jonathanc-n) +- Migrate documentation for `regr*` aggregate functions to code [#12871](https://github.com/apache/datafusion/pull/12871) (alamb) +- Minor: Add documentation for `cot` [#13069](https://github.com/apache/datafusion/pull/13069) (alamb) +- Documentation: Add API deprecation policy [#13083](https://github.com/apache/datafusion/pull/13083) (comphead) +- docs: Fixed generate_series docs [#13097](https://github.com/apache/datafusion/pull/13097) (jonathanc-n) +- [docs]: migrate lead/lag window function docs to new docs [#13095](https://github.com/apache/datafusion/pull/13095) (buraksenn) +- minor: Add deprecated policy to the contributor guide contents [#13100](https://github.com/apache/datafusion/pull/13100) (comphead) +- Introduce `binary_as_string` parquet option, upgrade to arrow/parquet `53.2.0` [#12816](https://github.com/apache/datafusion/pull/12816) (goldmedal) +- Convert `ntile` builtIn function to UDWF [#13040](https://github.com/apache/datafusion/pull/13040) (jatin510) +- docs: Added Special Functions Page [#13102](https://github.com/apache/datafusion/pull/13102) (jonathanc-n) +- [docs]: added `alternative_syntax` function for docs [#13140](https://github.com/apache/datafusion/pull/13140) (jonathanc-n) +- Minor: Delete old cume_dist and percent_rank docs [#13137](https://github.com/apache/datafusion/pull/13137) (jonathanc-n) +- docs: Add alternative syntax for extract, trim and substring. [#13143](https://github.com/apache/datafusion/pull/13143) (Omega359) +- docs: switch completely to generated docs for scalar and aggregate functions [#13161](https://github.com/apache/datafusion/pull/13161) (Omega359) +- Minor: improve testing docs, mention `cargo nextest` [#13160](https://github.com/apache/datafusion/pull/13160) (alamb) +- minor: Update HOWTO to help with updating new docs [#13172](https://github.com/apache/datafusion/pull/13172) (jonathanc-n) +- Add config option `skip_physical_aggregate_schema_check ` [#13176](https://github.com/apache/datafusion/pull/13176) (alamb) +- Enable reading `StringViewArray` by default from Parquet (8% improvement for entire ClickBench suite) [#13101](https://github.com/apache/datafusion/pull/13101) (alamb) +- Forward port changes for `42.2.0` release (#13191) [#13193](https://github.com/apache/datafusion/pull/13193) (alamb) +- [minor] overload from_unixtime func to have optional timezone parameter [#13130](https://github.com/apache/datafusion/pull/13130) (buraksenn) + +**Other:** + +- Impl `convert_to_state` for `GroupsAccumulatorAdapter` (faster median for high cardinality aggregates) [#11827](https://github.com/apache/datafusion/pull/11827) (Rachelint) +- Upgrade sqlparser-rs to 0.51.0, support new interval logic from `sqlparse-rs` [#12222](https://github.com/apache/datafusion/pull/12222) (samuelcolvin) +- Do not silently ignore unsupported `CREATE TABLE` and `CREATE VIEW` syntax [#12450](https://github.com/apache/datafusion/pull/12450) (alamb) +- use FileFormat::get_ext as the default file extension filter [#12417](https://github.com/apache/datafusion/pull/12417) (waruto210) +- fix interval units parsing [#12448](https://github.com/apache/datafusion/pull/12448) (samuelcolvin) +- test(substrait): update TPCH tests [#12462](https://github.com/apache/datafusion/pull/12462) (vbarua) +- Add "Extended Clickbench" benchmark for median and approx_median for high cardinality aggregates [#12438](https://github.com/apache/datafusion/pull/12438) (alamb) +- date_trunc small update for readability [#12479](https://github.com/apache/datafusion/pull/12479) (findepi) +- cleanup `array_has` [#12460](https://github.com/apache/datafusion/pull/12460) (samuelcolvin) +- chore: bump chrono to 0.4.38 [#12485](https://github.com/apache/datafusion/pull/12485) (my-vegetable-has-exploded) +- Remove deprecated ScalarUDF::new [#12487](https://github.com/apache/datafusion/pull/12487) (findepi) +- Remove deprecated config setup functions [#12486](https://github.com/apache/datafusion/pull/12486) (findepi) +- Remove unnecessary shifts in gcd() [#12480](https://github.com/apache/datafusion/pull/12480) (findepi) +- Return TableProviderFilterPushDown::Exact when Parquet Pushdown Enabled [#12135](https://github.com/apache/datafusion/pull/12135) (itsjunetime) +- Update substrait requirement from 0.41 to 0.42, `prost-build` to `0.13.2` [#12483](https://github.com/apache/datafusion/pull/12483) (dependabot[bot]) +- Faster strpos() string function for ASCII-only case [#12401](https://github.com/apache/datafusion/pull/12401) (goldmedal) +- Specialize ASCII case for substr() [#12444](https://github.com/apache/datafusion/pull/12444) (2010YOUY01) +- Improve SQLite subquery tables aliasing unparsing [#12482](https://github.com/apache/datafusion/pull/12482) (sgrebnov) +- Minor: use Option rather than Result for not found suggestion [#12512](https://github.com/apache/datafusion/pull/12512) (alamb) +- Remove deprecated datafusion_physical_expr::functions module [#12505](https://github.com/apache/datafusion/pull/12505) (findepi) +- Remove deprecated AggregateUDF::new [#12508](https://github.com/apache/datafusion/pull/12508) (findepi) +- Make `required_guarantees` output to be deterministic [#12484](https://github.com/apache/datafusion/pull/12484) (austin362667) +- Deprecate unused ScalarUDF::fun [#12506](https://github.com/apache/datafusion/pull/12506) (findepi) +- Remove deprecated WindowUDF::new [#12507](https://github.com/apache/datafusion/pull/12507) (findepi) +- Preserve the order of right table in NestedLoopJoinExec [#12504](https://github.com/apache/datafusion/pull/12504) (alihan-synnada) +- Improve benchmark for ltrim [#12513](https://github.com/apache/datafusion/pull/12513) (Rachelint) +- Fix: check ambiguous column reference [#12467](https://github.com/apache/datafusion/pull/12467) (HuSen8891) +- Minor: move imports to top in `row_hash.rs` [#12530](https://github.com/apache/datafusion/pull/12530) (Rachelint) +- tests: Fix typo in config setting name [#12535](https://github.com/apache/datafusion/pull/12535) (progval) +- Expose DataFrame select_exprs method [#12520](https://github.com/apache/datafusion/pull/12520) (milenkovicm) +- Replace some usages of `Expr::to_field` with `Expr::qualified_name` [#12522](https://github.com/apache/datafusion/pull/12522) (jonahgao) +- Bump aws-sdk-sso to 1.43.0, aws-sdk-sts to 1.43.0 and aws-sdk-ssooidc from 1.40.0 to 1.44.0 in /datafusion-cli [#12409](https://github.com/apache/datafusion/pull/12409) (dependabot[bot]) +- Fix NestedLoopJoin performance regression [#12531](https://github.com/apache/datafusion/pull/12531) (alihan-synnada) +- Produce informative error message on insert plan type mismatch [#12540](https://github.com/apache/datafusion/pull/12540) (findepi) +- Fix unparse table scan with the projection pushdown [#12534](https://github.com/apache/datafusion/pull/12534) (goldmedal) +- Automate sqllogictest for String, LargeString and StringView behavior [#12525](https://github.com/apache/datafusion/pull/12525) (goldmedal) +- Fix unparsing offset [#12539](https://github.com/apache/datafusion/pull/12539) (Stazer) +- support EXTRACT on intervals and durations [#12514](https://github.com/apache/datafusion/pull/12514) (nrc) +- Support List type coercion for CASE-WHEN-THEN expression [#12490](https://github.com/apache/datafusion/pull/12490) (goldmedal) +- Sort metrics alphabetically in EXPLAIN ANALYZE output [#12568](https://github.com/apache/datafusion/pull/12568) (progval) +- Add `RuntimeEnv::try_new` and deprecate `RuntimeEnv::new` [#12566](https://github.com/apache/datafusion/pull/12566) (OussamaSaoudi) +- Reorgnize the StringView tests in sqllogictests [#12572](https://github.com/apache/datafusion/pull/12572) (goldmedal) +- fix parquet infer statistics for BinaryView types [#12575](https://github.com/apache/datafusion/pull/12575) (XiangpengHao) +- Minor: add example to of assert_batches_eq [#12580](https://github.com/apache/datafusion/pull/12580) (alamb) +- Use qualified aliases to simplify searching DFSchema [#12546](https://github.com/apache/datafusion/pull/12546) (jonahgao) +- return absent stats when filters are pushed down [#12471](https://github.com/apache/datafusion/pull/12471) (waruto210) +- Minor: add new() function for ParquetReadOptions [#12579](https://github.com/apache/datafusion/pull/12579) (Smith-Cruise) +- make `Debug` for `MemoryExec` prettier [#12582](https://github.com/apache/datafusion/pull/12582) (samuelcolvin) +- Add `SessionStateBuilder::with_object_store` method [#12578](https://github.com/apache/datafusion/pull/12578) (OussamaSaoudi) +- Fix and Improve Sort Pushdown for Nested Loop and Hash Join [#12559](https://github.com/apache/datafusion/pull/12559) (berkaysynnada) +- Add Docs and Examples and helper methods to `PhysicalSortExpr` [#12589](https://github.com/apache/datafusion/pull/12589) (alamb) +- Warn instead of error for unused imports [#12588](https://github.com/apache/datafusion/pull/12588) (samuelcolvin) +- Update prost-build requirement from =0.13.2 to =0.13.3 [#12587](https://github.com/apache/datafusion/pull/12587) (dependabot[bot]) +- Add JOB benchmark dataset [1/N] (imdb dataset) [#12497](https://github.com/apache/datafusion/pull/12497) (doupache) +- Improve documentation and add `Display` impl to `EquivalenceProperties` [#12590](https://github.com/apache/datafusion/pull/12590) (alamb) +- physical-plan: Cast nested group values back to dictionary if necessary [#12586](https://github.com/apache/datafusion/pull/12586) (brancz) +- Support `Date32` for `date_trunc` function [#12603](https://github.com/apache/datafusion/pull/12603) (goldmedal) +- Avoid RowConverter for multi column grouping (10% faster clickbench queries) [#12269](https://github.com/apache/datafusion/pull/12269) (jayzhan211) +- Refactor to support recursive unnest in physical plan [#11577](https://github.com/apache/datafusion/pull/11577) (duongcongtoai) +- Use original value when comparing with dictionary column in unparser [#12610](https://github.com/apache/datafusion/pull/12610) (Sevenannn) +- Fix to unparse the plan with multiple UNION statements into an SQL string [#12605](https://github.com/apache/datafusion/pull/12605) (goldmedal) +- Keep the float information in scalar_to_sql [#12609](https://github.com/apache/datafusion/pull/12609) (Sevenannn) +- Add Dictionary String (UTF8) type to String sqllogictests [#12621](https://github.com/apache/datafusion/pull/12621) (goldmedal) +- Improve SanityChecker error message [#12595](https://github.com/apache/datafusion/pull/12595) (alamb) +- Improve performance of `trim` for string view (10%) [#12395](https://github.com/apache/datafusion/pull/12395) (Rachelint) +- Simplify `update_skip_aggregation_probe` method [#12332](https://github.com/apache/datafusion/pull/12332) (lewiszlw) +- Minor: Encapsulate type check in GroupValuesColumn, avoid panic [#12620](https://github.com/apache/datafusion/pull/12620) (alamb) +- Fix sort node deserialization from proto [#12626](https://github.com/apache/datafusion/pull/12626) (palaska) +- Minor: improve documentation to StringView trim [#12629](https://github.com/apache/datafusion/pull/12629) (alamb) +- [MINOR]: Simplifications Sort Operator [#12639](https://github.com/apache/datafusion/pull/12639) (akurmustafa) +- [Minor] Remove redundant member from RepartitionExec [#12638](https://github.com/apache/datafusion/pull/12638) (akurmustafa) +- implement nested identifier access [#12614](https://github.com/apache/datafusion/pull/12614) (Lordworms) +- [MINOR]: Rename get_arrayref_at_indices to take_arrays [#12654](https://github.com/apache/datafusion/pull/12654) (akurmustafa) +- [MINOR]: Use take_arrays in repartition , fix build [#12657](https://github.com/apache/datafusion/pull/12657) (doupache) +- Add binary_view to string_view coercion [#12643](https://github.com/apache/datafusion/pull/12643) (doupache) +- [Minor] Improve error message when bitwise\_\* operator takes wrong unsupported type [#12646](https://github.com/apache/datafusion/pull/12646) (dharanad) +- Minor: Add github link to code that was upstreamed [#12660](https://github.com/apache/datafusion/pull/12660) (alamb) +- Minor: Improve documentation on execution error handling [#12651](https://github.com/apache/datafusion/pull/12651) (alamb) +- Adds `WindowUDFImpl::reverse_expr`trait method + Support for `IGNORE NULLS` [#12662](https://github.com/apache/datafusion/pull/12662) (jcsherin) +- Fill in missing `Debug` fields for `SessionState` [#12663](https://github.com/apache/datafusion/pull/12663) (AnthonyZhOon) +- Minor: add partial assertion for skip aggregation probe [#12640](https://github.com/apache/datafusion/pull/12640) (Rachelint) +- Add more functions for string sqllogictests [#12665](https://github.com/apache/datafusion/pull/12665) (goldmedal) +- Update rstest requirement from 0.22.0 to 0.23.0 [#12678](https://github.com/apache/datafusion/pull/12678) (dependabot[bot]) +- Minor: Change LiteralGuarantee try_new to new [#12669](https://github.com/apache/datafusion/pull/12669) (pgwhalen) +- Refactor PrimitiveGroupValueBuilder to use `MaybeNullBufferBuilder` [#12623](https://github.com/apache/datafusion/pull/12623) (alamb) +- Add `value_from_statisics` to AggregateUDFImpl, remove special case for min/max/count aggregate statistics [#12296](https://github.com/apache/datafusion/pull/12296) (edmondop) +- Provide field and schema metadata missing on distinct aggregations. [#12691](https://github.com/apache/datafusion/pull/12691) (wiedld) +- [MINOR]: Simplify required_input_ordering of BoundedWindowAggExec [#12656](https://github.com/apache/datafusion/pull/12656) (akurmustafa) +- handle 0 and NULL value of NTH_VALUE function [#12676](https://github.com/apache/datafusion/pull/12676) (thinh2) +- Improve documentation for AggregateUDFImpl::value_from_stats [#12689](https://github.com/apache/datafusion/pull/12689) (alamb) +- Add support for external tables with qualified names [#12645](https://github.com/apache/datafusion/pull/12645) (OussamaSaoudi) +- Fix Regex signature types [#12690](https://github.com/apache/datafusion/pull/12690) (blaginin) +- Refactor `ByteGroupValueBuilder` to use `MaybeNullBufferBuilder` [#12681](https://github.com/apache/datafusion/pull/12681) (alamb) +- Simplify match patterns in coercion rules [#12711](https://github.com/apache/datafusion/pull/12711) (findepi) +- Remove aggregate functions dependency on frontend [#12715](https://github.com/apache/datafusion/pull/12715) (findepi) +- Minor: Remove clone in `transform_to_states` [#12707](https://github.com/apache/datafusion/pull/12707) (jayzhan211) +- Refactor tests for union sorting properties, add tests for unions and constants [#12702](https://github.com/apache/datafusion/pull/12702) (alamb) +- Fix: support Qualified Wildcard in count aggregate function [#12673](https://github.com/apache/datafusion/pull/12673) (HuSen8891) +- Reduce code duplication in `PrimitiveGroupValueBuilder` with const generics [#12703](https://github.com/apache/datafusion/pull/12703) (alamb) +- Disallow duplicated qualified field names [#12608](https://github.com/apache/datafusion/pull/12608) (eejbyfeldt) +- Optimize base64/hex decoding by pre-allocating output buffers (~2x faster) [#12675](https://github.com/apache/datafusion/pull/12675) (simonvandel) +- Allow DynamicFileCatalog support to query partitioned file [#12683](https://github.com/apache/datafusion/pull/12683) (goldmedal) +- Support `LIMIT` Push-down logical plan optimization for `Extension` nodes [#12685](https://github.com/apache/datafusion/pull/12685) (austin362667) +- Fix AvroReader: Add union resolving for nested struct arrays [#12686](https://github.com/apache/datafusion/pull/12686) (JonasDev1) +- Adds macros for creating `WindowUDF` and `WindowFunction` expression [#12693](https://github.com/apache/datafusion/pull/12693) (jcsherin) +- Support unparsing plans with both Aggregation and Window functions [#12705](https://github.com/apache/datafusion/pull/12705) (sgrebnov) +- Fix strpos invocation with dictionary and null [#12712](https://github.com/apache/datafusion/pull/12712) (findepi) +- Add IMDB(JOB) Benchmark [2/N] (imdb queries) [#12529](https://github.com/apache/datafusion/pull/12529) (austin362667) +- Minor: avoid clone while calculating union equivalence properties [#12722](https://github.com/apache/datafusion/pull/12722) (alamb) +- Simplify streaming_merge function parameters [#12719](https://github.com/apache/datafusion/pull/12719) (mertak-synnada) +- Provide field and schema metadata missing on cross joins, and union with null fields. [#12729](https://github.com/apache/datafusion/pull/12729) (wiedld) +- Minor: Update string tests for strpos [#12739](https://github.com/apache/datafusion/pull/12739) (alamb) +- Apply `type_union_resolution` to array and values [#12753](https://github.com/apache/datafusion/pull/12753) (jayzhan211) +- fix `equal_to` in `PrimitiveGroupValueBuilder` [#12758](https://github.com/apache/datafusion/pull/12758) (Rachelint) +- Fix `equal_to` in `ByteGroupValueBuilder` [#12770](https://github.com/apache/datafusion/pull/12770) (alamb) +- Allow boolean Expr simplification even when nullable [#12746](https://github.com/apache/datafusion/pull/12746) (eejbyfeldt) +- Fix unnest conjunction with selecting wildcard expression [#12760](https://github.com/apache/datafusion/pull/12760) (goldmedal) +- Improve `round` scalar function unparsing for Postgres [#12744](https://github.com/apache/datafusion/pull/12744) (sgrebnov) +- Fix stack overflow calculating projected orderings [#12759](https://github.com/apache/datafusion/pull/12759) (alamb) +- Upgrade arrow/parquet to `53.1.0` / fix clippy [#12724](https://github.com/apache/datafusion/pull/12724) (alamb) +- Account for constant equivalence properties in union, tests [#12562](https://github.com/apache/datafusion/pull/12562) (alamb) +- Minor: clarify comment about empty dependencies [#12786](https://github.com/apache/datafusion/pull/12786) (alamb) +- Introduce Signature::String and return error if input of `strpos` is integer [#12751](https://github.com/apache/datafusion/pull/12751) (jayzhan211) +- Minor: improve docs on MovingMin/MovingMax [#12790](https://github.com/apache/datafusion/pull/12790) (alamb) +- Add union sorting equivalence end to end tests [#12721](https://github.com/apache/datafusion/pull/12721) (alamb) +- Fix bug in TopK aggregates [#12766](https://github.com/apache/datafusion/pull/12766) (avantgardnerio) +- Minor: clean up TODO comments in unnest.slt [#12795](https://github.com/apache/datafusion/pull/12795) (goldmedal) +- Refactor `DependencyMap` and `Dependencies` into structs [#12761](https://github.com/apache/datafusion/pull/12761) (alamb) +- Remove unnecessary `DFSchema::check_ambiguous_name` [#12805](https://github.com/apache/datafusion/pull/12805) (jonahgao) +- API from `ParquetExec` to `ParquetExecBuilder` [#12799](https://github.com/apache/datafusion/pull/12799) (alamb) +- Minor: add documentation note about `NullState` [#12791](https://github.com/apache/datafusion/pull/12791) (alamb) +- Chore: Move `aggregate statistics` optimizer test from core to optimizer crate [#12783](https://github.com/apache/datafusion/pull/12783) (jayzhan211) +- Clarify documentation on ArrowBytesMap and ArrowBytesViewMap [#12789](https://github.com/apache/datafusion/pull/12789) (alamb) +- Bump cookie and express in /datafusion/wasmtest/datafusion-wasm-app [#12825](https://github.com/apache/datafusion/pull/12825) (dependabot[bot]) +- Remove unused dependencies and features [#12808](https://github.com/apache/datafusion/pull/12808) (jonahgao) +- Add Aggregation fuzzer framework [#12667](https://github.com/apache/datafusion/pull/12667) (Rachelint) +- Retry apt-get and rustup on CI [#12714](https://github.com/apache/datafusion/pull/12714) (findepi) +- Support creating tables via SQL with `FixedSizeList` column (e.g. `a int[3]`) [#12810](https://github.com/apache/datafusion/pull/12810) (jandremarais) +- Make HashJoinExec::join_schema public [#12807](https://github.com/apache/datafusion/pull/12807) (progval) +- Fix convert_to_state bug in `GroupsAccumulatorAdapter` [#12834](https://github.com/apache/datafusion/pull/12834) (alamb) +- Fix: approx_percentile_cont_with_weight Panic [#12823](https://github.com/apache/datafusion/pull/12823) (jonathanc-n) +- Fix clippy error on wasmtest [#12844](https://github.com/apache/datafusion/pull/12844) (jonahgao) +- Fix panic on wrong number of arguments to substr [#12837](https://github.com/apache/datafusion/pull/12837) (eejbyfeldt) +- Fix Bug in Display for ScalarValue::Struct [#12856](https://github.com/apache/datafusion/pull/12856) (avantgardnerio) +- Support DictionaryString for Regex matching operators [#12768](https://github.com/apache/datafusion/pull/12768) (blaginin) +- Minor: Small comment changes in sql folder [#12838](https://github.com/apache/datafusion/pull/12838) (jonathanc-n) +- Add DuckDB struct test and row as alias [#12841](https://github.com/apache/datafusion/pull/12841) (jayzhan211) +- Support struct coercion in `type_union_resolution` [#12839](https://github.com/apache/datafusion/pull/12839) (jayzhan211) +- Added check for aggregate functions in optimizer rules [#12860](https://github.com/apache/datafusion/pull/12860) (jonathanc-n) +- Optimize `iszero` function (3-5x faster) [#12881](https://github.com/apache/datafusion/pull/12881) (simonvandel) +- Macro for creating record batch from literal slice [#12846](https://github.com/apache/datafusion/pull/12846) (timsaucer) +- Implement special min/max accumulator for Strings and Binary (10% faster for Clickbench Q28) [#12792](https://github.com/apache/datafusion/pull/12792) (alamb) +- Make PruningPredicate's rewrite public [#12850](https://github.com/apache/datafusion/pull/12850) (adriangb) +- octet_length + string view == ❤️ [#12900](https://github.com/apache/datafusion/pull/12900) (Omega359) +- Remove Expr clones in `select_to_plan` [#12887](https://github.com/apache/datafusion/pull/12887) (jonahgao) +- Minor: added to docs in expr folder [#12882](https://github.com/apache/datafusion/pull/12882) (jonathanc-n) +- Print undocumented functions to console while generating docs [#12874](https://github.com/apache/datafusion/pull/12874) (alamb) +- Fix: handle NULL offset of NTH_VALUE window function [#12851](https://github.com/apache/datafusion/pull/12851) (HuSen8891) +- Optimize `signum` function (3-25x faster) [#12890](https://github.com/apache/datafusion/pull/12890) (simonvandel) +- re-export PartitionEvaluatorArgs from datafusion_expr::function [#12878](https://github.com/apache/datafusion/pull/12878) (Michael-J-Ward) +- Unparse Sort with pushdown limit to SQL string [#12873](https://github.com/apache/datafusion/pull/12873) (goldmedal) +- Add spilling related metrics for aggregation [#12888](https://github.com/apache/datafusion/pull/12888) (2010YOUY01) +- Move equivalence fuzz testing to fuzz test binary [#12767](https://github.com/apache/datafusion/pull/12767) (alamb) +- Remove unused `math_expressions.rs` [#12917](https://github.com/apache/datafusion/pull/12917) (jonahgao) +- Improve AggregationFuzzer error reporting [#12832](https://github.com/apache/datafusion/pull/12832) (alamb) +- Import Arc consistently [#12899](https://github.com/apache/datafusion/pull/12899) (findepi) +- Optimize `isnan` (2-5x faster) [#12889](https://github.com/apache/datafusion/pull/12889) (simonvandel) +- Minor: Move StringArrayType, StringViewArrayBuilder, etc outside of string module [#12912](https://github.com/apache/datafusion/pull/12912) (Omega359) +- Remove redundant unsafe in test [#12914](https://github.com/apache/datafusion/pull/12914) (findepi) +- Ensure that math functions fulfil the ColumnarValue contract [#12922](https://github.com/apache/datafusion/pull/12922) (joroKr21) +- Optimization: support push down limit when full join [#12963](https://github.com/apache/datafusion/pull/12963) (JasonLi-cn) +- Implement `GroupColumn` support for `StringView` / `ByteView` (faster grouping performance) [#12809](https://github.com/apache/datafusion/pull/12809) (Rachelint) +- Implement native support StringView for `REGEXP_LIKE` [#12897](https://github.com/apache/datafusion/pull/12897) (tlm365) +- Minor: Refactor benchmark imports to use `util` module [#12885](https://github.com/apache/datafusion/pull/12885) (loloxwg) +- Fix zero data type in `expr % 1` simplification [#12913](https://github.com/apache/datafusion/pull/12913) (eejbyfeldt) +- Optimize performance of `math::cot` (~2x faster) [#12910](https://github.com/apache/datafusion/pull/12910) (tlm365) +- Expand wildcard expressions in distinct on [#12941](https://github.com/apache/datafusion/pull/12941) (epsio-banay) +- chores: remove redundant clone [#12964](https://github.com/apache/datafusion/pull/12964) (JasonLi-cn) +- Fix: handle NULL input in lead/lag window function [#12811](https://github.com/apache/datafusion/pull/12811) (HuSen8891) +- Fix logical vs physical schema mismatch for aliased `now()` [#12951](https://github.com/apache/datafusion/pull/12951) (wiedld) +- Optimize performance of `math::trunc` (~2.5x faster) [#12909](https://github.com/apache/datafusion/pull/12909) (tlm365) +- Minor: Add slt test for `DISTINCT ON` with wildcard [#12968](https://github.com/apache/datafusion/pull/12968) (alamb) +- Fix 'Too many open files' on fuzz test. [#12961](https://github.com/apache/datafusion/pull/12961) (dhegberg) +- Increase minimum supported Rust version (MSRV) to 1.79 [#12962](https://github.com/apache/datafusion/pull/12962) (findepi) +- Unparse `SubqueryAlias` without projections to SQL [#12896](https://github.com/apache/datafusion/pull/12896) (goldmedal) +- Fix 2 bugs related to push down partition filters [#12902](https://github.com/apache/datafusion/pull/12902) (eejbyfeldt) +- Move TableConstraint to Constraints conversion [#12953](https://github.com/apache/datafusion/pull/12953) (findepi) +- Added current_timestamp alias [#12958](https://github.com/apache/datafusion/pull/12958) (jonathanc-n) +- Improve unparsing for `ORDER BY`, `UNION`, Windows functions with Aggregation [#12946](https://github.com/apache/datafusion/pull/12946) (sgrebnov) +- Handle one-element array return value in ScalarFunctionExpr [#12965](https://github.com/apache/datafusion/pull/12965) (joroKr21) +- Add links to new_constraint_from_table_constraints doc [#12995](https://github.com/apache/datafusion/pull/12995) (findepi) +- Fix:fix HashJoin projection swap [#12967](https://github.com/apache/datafusion/pull/12967) (my-vegetable-has-exploded) +- refactor(substrait): refactor ReadRel consumer [#12983](https://github.com/apache/datafusion/pull/12983) (tokoko) +- Move SMJ join filtered part out of join_output stage. LeftOuter, LeftSemi [#12764](https://github.com/apache/datafusion/pull/12764) (comphead) +- Remove logical cross join in planning [#12985](https://github.com/apache/datafusion/pull/12985) (Dandandan) +- [MINOR]: Use arrow take_arrays, remove datafusion take_arrays [#13013](https://github.com/apache/datafusion/pull/13013) (akurmustafa) +- Don't preserve functional dependency when generating UNION logical plan [#12979](https://github.com/apache/datafusion/pull/12979) (Sevenannn) +- [Minor]: Add data based sort expression test [#12992](https://github.com/apache/datafusion/pull/12992) (akurmustafa) +- Removed last usages of scalar_inputs, scalar_input_types and inputs2 to use arrow unary/binary for performance [#12972](https://github.com/apache/datafusion/pull/12972) (buraksenn) +- Minor: Update release instructions to include new crates [#13024](https://github.com/apache/datafusion/pull/13024) (alamb) +- Extract CSE logic to `datafusion_common` [#13002](https://github.com/apache/datafusion/pull/13002) (peter-toth) +- Enhance table scan unparsing to avoid unnamed subqueries. [#13006](https://github.com/apache/datafusion/pull/13006) (goldmedal) +- Fix count on all null `VALUES` clause [#13029](https://github.com/apache/datafusion/pull/13029) (findepi) +- Support filter in cross join elimination [#13025](https://github.com/apache/datafusion/pull/13025) (Dandandan) +- [minor]: remove same util functions from the code base. [#13026](https://github.com/apache/datafusion/pull/13026) (akurmustafa) +- Improve `AggregateFuzz` testing: generate random queries [#12847](https://github.com/apache/datafusion/pull/12847) (alamb) +- Fix functions with Volatility::Volatile and parameters [#13001](https://github.com/apache/datafusion/pull/13001) (agscpp) +- refactor: Incorporate RewriteDisjunctivePredicate rule into SimplifyExpressions [#13032](https://github.com/apache/datafusion/pull/13032) (eejbyfeldt) +- Move filtered SMJ right join out of `join_partial` phase [#13053](https://github.com/apache/datafusion/pull/13053) (comphead) +- Remove functions and types deprecated since 37 [#13056](https://github.com/apache/datafusion/pull/13056) (findepi) +- Minor: Cleaned physical-plan Comments [#13055](https://github.com/apache/datafusion/pull/13055) (jonathanc-n) +- improve the condition checking for unparsing table_scan [#13062](https://github.com/apache/datafusion/pull/13062) (goldmedal) +- minor: simplify associated item bound of `hash_array_primitive` [#13070](https://github.com/apache/datafusion/pull/13070) (jonahgao) +- extended log.rs tests for unary/binary and f32/f64 casting [#13034](https://github.com/apache/datafusion/pull/13034) (buraksenn) +- Fix check_not_null_constraints null detection [#13033](https://github.com/apache/datafusion/pull/13033) (findepi) +- [Minor] Update info/list of TPC-DS queries [#13075](https://github.com/apache/datafusion/pull/13075) (Dandandan) +- Fix logical vs physical schema mismatch for UNION where some inputs are constants [#12954](https://github.com/apache/datafusion/pull/12954) (wiedld) +- Improve CSE stats [#13080](https://github.com/apache/datafusion/pull/13080) (peter-toth) +- Infer data type from schema for `Values` and add struct coercion to `coalesce` [#12864](https://github.com/apache/datafusion/pull/12864) (jayzhan211) +- [minor]: use arrow take_batch instead of get_record_batch_indices [#13084](https://github.com/apache/datafusion/pull/13084) (akurmustafa) +- chore: Added a number of physical planning join benchmarks [#13085](https://github.com/apache/datafusion/pull/13085) (mnorfolk03) +- Fix more instances of schema missing metadata [#13068](https://github.com/apache/datafusion/pull/13068) (itsjunetime) +- Bug-fix / Limit with_new_exprs() [#13109](https://github.com/apache/datafusion/pull/13109) (berkaysynnada) +- Minor: doc IMDB in benchmark README [#13107](https://github.com/apache/datafusion/pull/13107) (2010YOUY01) +- removed --prefer_hash_join option from parquet_filter command. [#13106](https://github.com/apache/datafusion/pull/13106) (neyama) +- Make CI error if a function has no documentation [#12938](https://github.com/apache/datafusion/pull/12938) (alamb) +- Allow using `cargo nextest` for running tests [#13045](https://github.com/apache/datafusion/pull/13045) (alamb) +- Add benchmark for memory-limited aggregation [#13090](https://github.com/apache/datafusion/pull/13090) (2010YOUY01) +- Add clickbench parquet based queries to sql_planner benchmark [#13103](https://github.com/apache/datafusion/pull/13103) (Omega359) +- Improve documentation and examples for `SchemaAdapterFactory`, make `record_batch` "hygenic" [#13063](https://github.com/apache/datafusion/pull/13063) (alamb) +- Move filtered SMJ Left Anti filtered join out of `join_partial` phase [#13111](https://github.com/apache/datafusion/pull/13111) (comphead) +- Improve TableScan with filters pushdown unparsing (multiple filters) [#13131](https://github.com/apache/datafusion/pull/13131) (sgrebnov) +- Raise a plan error on union if column count is not the same between plans [#13117](https://github.com/apache/datafusion/pull/13117) (Omega359) +- Add basic support for `unnest` unparsing [#13129](https://github.com/apache/datafusion/pull/13129) (sgrebnov) +- Improve TableScan with filters pushdown unparsing (joins) [#13132](https://github.com/apache/datafusion/pull/13132) (sgrebnov) +- Report offending plan node when In/Exist subquery misused [#13155](https://github.com/apache/datafusion/pull/13155) (findepi) +- Remove unused assert_analyzed_plan_ne test helper [#13121](https://github.com/apache/datafusion/pull/13121) (findepi) +- Fix Utf8View as Join Key [#13115](https://github.com/apache/datafusion/pull/13115) (demetribu) +- Add Support for `modulus` operation in substrait [#13108](https://github.com/apache/datafusion/pull/13108) (LatrecheYasser) +- unify cast_to function of ScalarValue [#13122](https://github.com/apache/datafusion/pull/13122) (JasonLi-cn) +- Add unused_qualifications rustic lint with deny lint level. [#13086](https://github.com/apache/datafusion/pull/13086) (dhegberg) +- [Optimization] Infer predicate under all JoinTypes [#13081](https://github.com/apache/datafusion/pull/13081) (JasonLi-cn) +- Support `negate` arithmetic expression in substrait [#13112](https://github.com/apache/datafusion/pull/13112) (LatrecheYasser) +- Fix to_char signature ordering [#13126](https://github.com/apache/datafusion/pull/13126) (Omega359) +- chore: re-export functions_window_common::ExpressionArgs [#13149](https://github.com/apache/datafusion/pull/13149) (Michael-J-Ward) +- minor: Fix build on main [#13159](https://github.com/apache/datafusion/pull/13159) (eejbyfeldt) +- minor: Update test case for issue #5771 showing it is resolved [#13180](https://github.com/apache/datafusion/pull/13180) (eejbyfeldt) +- Test LIKE with dynamic pattern [#13141](https://github.com/apache/datafusion/pull/13141) (findepi) +- Increase fuzz testing of streaming group by / low cardinality columns [#12990](https://github.com/apache/datafusion/pull/12990) (alamb) +- FFI initial implementation [#12920](https://github.com/apache/datafusion/pull/12920) (timsaucer) +- Report file location and offset when CSV schema mismatch [#13185](https://github.com/apache/datafusion/pull/13185) (findepi) +- Round robin polling between tied winners in sort preserving merge [#13133](https://github.com/apache/datafusion/pull/13133) (jayzhan211) +- Fix rendering of dictionary empty string values in SLT tests [#13198](https://github.com/apache/datafusion/pull/13198) (findepi) +- Improve push down filter of join [#13184](https://github.com/apache/datafusion/pull/13184) (JasonLi-cn) +- Minor: Reduce indirection for finding changlog [#13199](https://github.com/apache/datafusion/pull/13199) (alamb) +- Support `DictionaryArray` in `OVER` clause [#13153](https://github.com/apache/datafusion/pull/13153) (adriangb) +- Allow testing records with sibling whitespace in SLT tests and add more string tests [#13197](https://github.com/apache/datafusion/pull/13197) (findepi) +- Use single file write when an extension is present in the path. [#13079](https://github.com/apache/datafusion/pull/13079) (dhegberg) +- Deprecate ScalarUDF::invoke and invoke_no_args for invoke_batch [#13179](https://github.com/apache/datafusion/pull/13179) (findepi) +- consider volatile function in simply_expression [#13128](https://github.com/apache/datafusion/pull/13128) (Lordworms) +- Fix CI compile failure due to merge conflict [#13219](https://github.com/apache/datafusion/pull/13219) (alamb) +- Revert "Improve push down filter of join (#13184)" [#13229](https://github.com/apache/datafusion/pull/13229) (eejbyfeldt) +- Derive `Clone` for more ExecutionPlans [#13203](https://github.com/apache/datafusion/pull/13203) (alamb) +- feat(logical-types): add NativeType and LogicalType [#12853](https://github.com/apache/datafusion/pull/12853) (notfilippo) +- Apply projection to `Statistics` in `FilterExec` [#13187](https://github.com/apache/datafusion/pull/13187) (alamb) +- Minor: make LeftJoinData into a struct in CrossJoinExec [#13227](https://github.com/apache/datafusion/pull/13227) (alamb) +- Deprecate invoke and invoke_no_args in favor of invoke_batch [#13174](https://github.com/apache/datafusion/pull/13174) (findepi) +- Support timestamp(n) SQL type [#13231](https://github.com/apache/datafusion/pull/13231) (findepi) +- Remove elements deprecated since v 38. [#13245](https://github.com/apache/datafusion/pull/13245) (findepi) + +## Credits + +Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) per contributor. + +``` + 68 Andrew Lamb + 34 Piotr Findeisen + 24 Jonathan Chen + 19 Emil Ejbyfeldt + 17 Jax Liu + 12 Bruce Ritchie + 11 Jonah Gao + 9 Jay Zhan + 8 Mustafa Akur + 8 kamille + 7 Sergei Grebnov + 7 Tornike Gurgenidze + 6 JasonLi + 6 Oleks V + 6 Val Lorentz + 6 jcsherin + 5 Burak Şen + 5 Samuel Colvin + 5 Yongting You + 5 dependabot[bot] + 4 HuSen + 4 Jagdish Parihar + 4 Simon Vandel Sillesen + 4 wiedld + 3 Alihan Çelikcan + 3 Andy Grove + 3 AnthonyZhOon + 3 Austin Liu + 3 Berkay Şahin + 3 Daniel Hegberg + 3 Daniël Heres + 3 Lordworms + 3 Michael J Ward + 3 OussamaSaoudi + 3 Qianqian + 3 Tai Le Manh + 3 Victor Barua + 3 doupache + 3 ngli-me + 3 yi wang + 2 Adrian Garcia Badaracco + 2 Alex Huang + 2 Brent Gardner + 2 Dharan Aditya + 2 Dmitrii Blaginin + 2 Duong Cong Toai + 2 Filippo Rossi + 2 Georgi Krastev + 2 June + 2 Max Norfolk + 2 Peter Toth + 2 Tim Saucer + 2 Yasser Latreche + 2 peasee + 2 waruto + 1 Abdullah Sabaa Allil + 1 Agaev Guseyn + 1 Albert Skalt + 1 Andrey Koshchiy + 1 Arttu + 1 Baris Palaska + 1 Bruno Volpato + 1 Bryce Mecum + 1 Daniel Mesejo + 1 Dmitry Bugakov + 1 Eason + 1 Edmondo Porcu + 1 Eduard Karacharov + 1 Frederic Branczyk + 1 Fredrik Meringdal + 1 Haile + 1 Jan + 1 JonasDev1 + 1 Justus Flerlage + 1 Leslie Su + 1 Marco Neumann + 1 Marko Milenković + 1 Martin Hilton + 1 Matthew Turner + 1 Nick Cameron + 1 Paul + 1 Smith Cruise + 1 Tomoaki Kawada + 1 WeblWabl + 1 Weston Pace + 1 Xiangpeng Hao + 1 Xwg + 1 Yuance.Li + 1 epsio-banay + 1 iamthinh + 1 juroberttyb + 1 mertak-synnada + 1 neyama + 1 smarticen + 1 zhuliquan + 1 张林伟 +``` + +Thank you also to everyone who contributed in other ways such as filing issues, reviewing PRs, and providing feedback on this release. diff --git a/dev/depcheck/Cargo.toml b/dev/depcheck/Cargo.toml index cb4e77eabb22..23cefaec43be 100644 --- a/dev/depcheck/Cargo.toml +++ b/dev/depcheck/Cargo.toml @@ -22,4 +22,4 @@ name = "depcheck" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -cargo = "0.78.1" +cargo = "0.81.0" diff --git a/dev/depcheck/src/main.rs b/dev/depcheck/src/main.rs index 1599fdd4188d..80feefcd1b1c 100644 --- a/dev/depcheck/src/main.rs +++ b/dev/depcheck/src/main.rs @@ -23,7 +23,7 @@ use std::collections::{HashMap, HashSet}; use std::env; use std::path::Path; -use cargo::util::config::Config; +use cargo::util::context::GlobalContext; /// Verifies that there are no circular dependencies between DataFusion crates /// (which prevents publishing on crates.io) by parsing the Cargo.toml files and @@ -31,7 +31,7 @@ use cargo::util::config::Config; /// /// See https://github.com/apache/datafusion/issues/9278 for more details fn main() -> CargoResult<()> { - let config = Config::default()?; + let gctx = GlobalContext::default()?; // This is the path for the depcheck binary let path = env::var("CARGO_MANIFEST_DIR").unwrap(); let root_cargo_toml = Path::new(&path) @@ -47,7 +47,7 @@ fn main() -> CargoResult<()> { "Checking for circular dependencies in {}", root_cargo_toml.display() ); - let workspace = cargo::core::Workspace::new(&root_cargo_toml, &config)?; + let workspace = cargo::core::Workspace::new(&root_cargo_toml, &gctx)?; let (_, resolve) = cargo::ops::resolve_ws(&workspace)?; let mut package_deps = HashMap::new(); diff --git a/dev/release/README.md b/dev/release/README.md index 32735588ed8f..0e0daa9d6c40 100644 --- a/dev/release/README.md +++ b/dev/release/README.md @@ -48,8 +48,8 @@ patch release: - Created a personal access token in GitHub for changelog automation script. - Github PAT should be created with `repo` access - Make sure your signing key is added to the following files in SVN: - - https://dist.apache.org/repos/dist/dev/arrow/KEYS - - https://dist.apache.org/repos/dist/release/arrow/KEYS + - https://dist.apache.org/repos/dist/dev/datafusion/KEYS + - https://dist.apache.org/repos/dist/release/datafusion/KEYS ### How to add signing key @@ -57,16 +57,16 @@ See instructions at https://infra.apache.org/release-signing.html#generate for g Committers can add signing keys in Subversion client with their ASF account. e.g.: -```bash -$ svn co https://dist.apache.org/repos/dist/dev/arrow -$ cd arrow +```shell +$ svn co https://dist.apache.org/repos/dist/dev/datafusion +$ cd datafusion $ editor KEYS $ svn ci KEYS ``` Follow the instructions in the header of the KEYS file to append your key. Here is an example: -```bash +```shell (gpg --list-sigs "John Doe" && gpg --armor --export "John Doe") >> KEYS svn commit KEYS -m "Add key for John Doe" ``` @@ -89,36 +89,26 @@ to generate one if you do not already have one. The changelog is generated using a Python script. There is a dependency on `PyGitHub`, which can be installed using pip: -```bash +```shell pip3 install PyGitHub ``` -Run the following command to generate the changelog content. +To generate the changelog, set the `GITHUB_TOKEN` environment variable to a valid token and then run the script +providing two commit ids or tags followed by the version number of the release being created. The following +example generates a change log of all changes between the first commit and the current HEAD revision. -```bash -$ GITHUB_TOKEN= ./dev/release/generate-changelog.py apache/datafusion 24.0.0 HEAD > dev/changelog/25.0.0.md +```shell +export GITHUB_TOKEN= +./dev/release/generate-changelog.py 24.0.0 HEAD 25.0.0 > dev/changelog/25.0.0.md ``` This script creates a changelog from GitHub PRs based on the labels associated with them as well as looking for -titles starting with `feat:`, `fix:`, or `docs:` . The script will produce output similar to: - -``` -Fetching list of commits between 24.0.0 and HEAD -Fetching pull requests -Categorizing pull requests -Generating changelog content -``` +titles starting with `feat:`, `fix:`, or `docs:`. -This process is not fully automated, so there are some additional manual steps: +Once the change log is generated, run `prettier` to format the document: -- Add the ASF header to the generated file -- Add a link to this changelog from the top-level `/datafusion/CHANGELOG.md` -- Add the following content (copy from the previous version's changelog and update as appropriate: - -``` -## [24.0.0](https://github.com/apache/datafusion/tree/24.0.0) (2023-05-06) - -[Full Changelog](https://github.com/apache/datafusion/compare/23.0.0...24.0.0) +```shell +prettier -w dev/changelog/25.0.0md ``` ## Prepare release commits and PR @@ -128,26 +118,37 @@ release. See [#9697](https://github.com/apache/datafusion/pull/9697) for an example. -Here are the commands that could be used to prepare the `5.1.0` release: +Here are the commands that could be used to prepare the `38.0.0` release: ### Update Version Checkout the main commit to be released -``` +```shell git fetch apache git checkout apache/main ``` -Update datafusion version in `datafusion/Cargo.toml` to `5.1.0`: +Manually update the datafusion version in the root `Cargo.toml` to `38.0.0`. + +Run `cargo update` in the root directory and also in `datafusion-cli`: +```shell +cargo update +cd datafustion-cli +cargo update +cd .. ``` -./dev/update_datafusion_versions.py 5.1.0 + +Run `cargo test` to re-generate some example files: + +```shell +cargo test ``` Lastly commit the version change: -``` +```shell git commit -a -m 'Update version' ``` @@ -167,7 +168,7 @@ Pick numbers in sequential order, with `0` for `rc0`, `1` for `rc1`, etc. While the official release artifacts are signed tarballs and zip files, we also tag the commit it was created for convenience and code archaeology. -Using a string such as `5.1.0` as the ``, create and push the tag by running these commands: +Using a string such as `38.0.0` as the ``, create and push the tag by running these commands: ```shell git fetch apache @@ -181,51 +182,21 @@ git push apache Run `create-tarball.sh` with the `` tag and `` and you found in previous steps: ```shell -GH_TOKEN= ./dev/release/create-tarball.sh 5.1.0 0 +GH_TOKEN= ./dev/release/create-tarball.sh 38.0.0 0 ``` The `create-tarball.sh` script -1. creates and uploads all release candidate artifacts to the [arrow - dev](https://dist.apache.org/repos/dist/dev/arrow) location on the +1. creates and uploads all release candidate artifacts to the [datafusion + dev](https://dist.apache.org/repos/dist/dev/datafusion) location on the apache distribution svn server 2. provide you an email template to - send to dev@arrow.apache.org for release voting. + send to dev@datafusion.apache.org for release voting. ### Vote on Release Candidate artifacts -Send the email output from the script to dev@arrow.apache.org. The email should look like - -``` -To: dev@arrow.apache.org -Subject: [VOTE][DataFusion] Release Apache DataFusion 5.1.0 RC0 - -Hi, - -I would like to propose a release of Apache DataFusion version 5.1.0. - -This release candidate is based on commit: a5dd428f57e62db20a945e8b1895de91405958c4 [1] -The proposed release artifacts and signatures are hosted at [2]. -The changelog is located at [3]. - -Please download, verify checksums and signatures, run the unit tests, -and vote on the release. - -The vote will be open for at least 72 hours. - -[ ] +1 Release this as Apache DataFusion 5.1.0 -[ ] +0 -[ ] -1 Do not release this as Apache DataFusion 5.1.0 because... - -Here is my vote: - -+1 - -[1]: https://github.com/apache/datafusion/tree/a5dd428f57e62db20a945e8b1895de91405958c4 -[2]: https://dist.apache.org/repos/dist/dev/arrow/apache-datafusion-5.1.0 -[3]: https://github.com/apache/datafusion/blob/a5dd428f57e62db20a945e8b1895de91405958c4/CHANGELOG.md -``` +Send the email output from the script to dev@datafusion.apache.org. For the release to become "official" it needs at least three PMC members to vote +1 on it. @@ -233,8 +204,8 @@ For the release to become "official" it needs at least three PMC members to vote The `dev/release/verify-release-candidate.sh` is a script in this repository that can assist in the verification process. Run it like: -``` -./dev/release/verify-release-candidate.sh 5.1.0 0 +```shell +./dev/release/verify-release-candidate.sh 38.0.0 0 ``` #### If the release is not approved @@ -249,11 +220,11 @@ NOTE: steps in this section can only be done by PMC members. ### After the release is approved Move artifacts to the release location in SVN, e.g. -https://dist.apache.org/repos/dist/release/datafusion/datafusion-5.1.0/, using +https://dist.apache.org/repos/dist/release/datafusion/datafusion-38.0.0/, using the `release-tarball.sh` script: ```shell -./dev/release/release-tarball.sh 5.1.0 0 +./dev/release/release-tarball.sh 38.0.0 0 ``` Congratulations! The release is now official! @@ -262,10 +233,10 @@ Congratulations! The release is now official! Tag the same release candidate commit with the final release tag -``` -git co apache/5.1.0-rc0 -git tag 5.1.0 -git push apache 5.1.0 +```shell +git co apache/38.0.0-rc0 +git tag 38.0.0 +git push apache 38.0.0 ``` ### Publish on Crates.io @@ -280,53 +251,33 @@ been made to crates.io using the following instructions. Follow [these instructions](https://doc.rust-lang.org/cargo/reference/publishing.html) to create an account and login to crates.io before asking to be added as an owner -of the following crates: - -- [datafusion-common](https://crates.io/crates/datafusion-common) -- [datafusion-expr](https://crates.io/crates/datafusion-expr) -- [datafusion-execution](https://crates.io/crates/datafusion-execution) -- [datafusion-physical-expr](https://crates.io/crates/datafusion-physical-expr) -- [datafusion-functions](https://crates.io/crates/datafusion-functions) -- [datafusion-functions-array](https://crates.io/crates/datafusion-functions-array) -- [datafusion-sql](https://crates.io/crates/datafusion-sql) -- [datafusion-optimizer](https://crates.io/crates/datafusion-optimizer) -- [datafusion-common-runtime](https://crates.io/crates/datafusion-common-runtime) -- [datafusion-physical-plan](https://crates.io/crates/datafusion-physical-plan) -- [datafusion](https://crates.io/crates/datafusion) -- [datafusion-proto](https://crates.io/crates/datafusion-proto) -- [datafusion-substrait](https://crates.io/crates/datafusion-substrait) -- [datafusion-cli](https://crates.io/crates/datafusion-cli) +to all of the DataFusion crates. Download and unpack the official release tarball Verify that the Cargo.toml in the tarball contains the correct version -(e.g. `version = "5.1.0"`) and then publish the crates by running the script `release-crates.sh` -in a directory extracted from the source tarball that was voted on. Note that this script doesn't -work if run in a Git repo. - -Alternatively the crates can be published one at a time with the following commands. Crates need to be -published in the correct order as shown in this diagram. - -![](crate-deps.svg) - -_To update this diagram, manually edit the dependencies in [crate-deps.dot](crate-deps.dot) and then run:_ - -```bash -dot -Tsvg dev/release/crate-deps.dot > dev/release/crate-deps.svg -``` +(e.g. `version = "38.0.0"`) and then publish the crates by running the following commands ```shell (cd datafusion/common && cargo publish) +(cd datafusion/expr-common && cargo publish) +(cd datafusion/physical-expr-common && cargo publish) +(cd datafusion/functions-aggregate-common && cargo publish) (cd datafusion/expr && cargo publish) (cd datafusion/execution && cargo publish) (cd datafusion/physical-expr && cargo publish) (cd datafusion/functions && cargo publish) -(cd datafusion/functions-array && cargo publish) +(cd datafusion/functions-aggregate && cargo publish) +(cd datafusion/functions-window && cargo publish) +(cd datafusion/functions-nested && cargo publish) (cd datafusion/sql && cargo publish) (cd datafusion/optimizer && cargo publish) (cd datafusion/common-runtime && cargo publish) (cd datafusion/physical-plan && cargo publish) +(cd datafusion/physical-optimizer && cargo publish) +(cd datafusion/catalog && cargo publish) (cd datafusion/core && cargo publish) +(cd datafusion/proto-common && cargo publish) (cd datafusion/proto && cargo publish) (cd datafusion/substrait && cargo publish) ``` @@ -354,7 +305,7 @@ Please visit https://brew.sh/ to obtain Homebrew. In addition to that please che Before running the script make sure that you can run the following command in your bash to make sure that `brew` has been installed and configured properly: -```bash +```shell brew --version ``` @@ -369,7 +320,7 @@ To create a Github Personal Access Token, please visit https://docs.github.com/e After all of the above is complete execute the following command: -```bash +```shell dev/release/publish_homebrew.sh ``` @@ -394,29 +345,11 @@ The vote has passed with +1 votes. Thank you to all who helped with the release verification. ``` -You can include mention crates.io and PyPI version URLs in the email if applicable. - -``` -We have published new versions of DataFusion to crates.io: - -https://crates.io/crates/datafusion/28.0.0 -https://crates.io/crates/datafusion-cli/28.0.0 -https://crates.io/crates/datafusion-common/28.0.0 -https://crates.io/crates/datafusion-expr/28.0.0 -https://crates.io/crates/datafusion-optimizer/28.0.0 -https://crates.io/crates/datafusion-physical-expr/28.0.0 -https://crates.io/crates/datafusion-proto/28.0.0 -https://crates.io/crates/datafusion-sql/28.0.0 -https://crates.io/crates/datafusion-execution/28.0.0 -https://crates.io/crates/datafusion-substrait/28.0.0 -``` - ### Add the release to Apache Reporter -Add the release to https://reporter.apache.org/addrelease.html?arrow with a version name prefixed with `RS-DATAFUSION-`, -for example `RS-DATAFUSION-14.0.0`. +Add the release to https://reporter.apache.org/addrelease.html?datafusion using the version number e.g. 38.0.0. -The release information is used to generate a template for a board report (see example +The release information is used to generate a template for a board report (see example from Apache Arrow project [here](https://github.com/apache/arrow/pull/14357)). ### Delete old RCs and Releases @@ -430,14 +363,14 @@ Release candidates should be deleted once the release is published. Get a list of DataFusion release candidates: -```bash -svn ls https://dist.apache.org/repos/dist/dev/arrow | grep datafusion +```shell +svn ls https://dist.apache.org/repos/dist/dev/datafusion ``` Delete a release candidate: -```bash -svn delete -m "delete old DataFusion RC" https://dist.apache.org/repos/dist/dev/datafusion/apache-datafusion-7.1.0-rc1/ +```shell +svn delete -m "delete old DataFusion RC" https://dist.apache.org/repos/dist/dev/datafusion/apache-datafusion-38.0.0-rc1/ ``` #### Deleting old releases from `release` svn @@ -446,36 +379,26 @@ Only the latest release should be available. Delete old releases after publishin Get a list of DataFusion releases: -```bash -svn ls https://dist.apache.org/repos/dist/release/arrow | grep datafusion +```shell +svn ls https://dist.apache.org/repos/dist/release/datafusion ``` Delete a release: -```bash -svn delete -m "delete old DataFusion release" https://dist.apache.org/repos/dist/release/datafusion/datafusion-7.0.0 +```shell +svn delete -m "delete old DataFusion release" https://dist.apache.org/repos/dist/release/datafusion/datafusion-37.0.0 ``` -### Publish the User Guide to the Arrow Site - -- Run the `build.sh` in the `docs` directory from the release tarball. -- Clone the [arrow-site](https://github.com/apache/arrow-site) repository -- Checkout the `asf-site` branch -- Copy content from `docs/build/html/*` to the `datafusion` directory in arrow-site -- Create a PR against the `asf-site` branch ([example](https://github.com/apache/arrow-site/pull/237)) -- Once the PR is merged, the content will be published to https://datafusion.apache.org/ by GitHub Pages (this - can take some time). - ### Optional: Write a blog post announcing the release -We typically crowdsource release announcements by collaborating on a Google document, usually starting +We typically crowd source release announcements by collaborating on a Google document, usually starting with a copy of the previous release announcement. Run the following commands to get the number of commits and number of unique contributors for inclusion in the blog post. -```bash -git log --pretty=oneline 10.0.0..11.0.0 datafusion datafusion-cli datafusion-examples | wc -l -git shortlog -sn 10.0.0..11.0.0 datafusion datafusion-cli datafusion-examples | wc -l +```shell +git log --pretty=oneline 37.0.0..38.0.0 datafusion datafusion-cli datafusion-examples | wc -l +git shortlog -sn 37.0.0..38.0.0 datafusion datafusion-cli datafusion-examples | wc -l ``` Once there is consensus on the contents of the post, create a PR to add a blog post to the @@ -487,3 +410,7 @@ Here is an example blog post PR: - https://github.com/apache/arrow-site/pull/217 Once the PR is merged, a GitHub action will publish the new blog post to https://arrow.apache.org/blog/. + +### Update the version on the download page + +Update the version on the [download page](https://datafusion.apache.org/download) to point to the latest release [here](../../docs/source/download.md). diff --git a/dev/release/crate-deps.dot b/dev/release/crate-deps.dot deleted file mode 100644 index 69811c7d6109..000000000000 --- a/dev/release/crate-deps.dot +++ /dev/null @@ -1,91 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -digraph G { - datafusion_examples - datafusion_examples -> datafusion - datafusion_examples -> datafusion_common - datafusion_examples -> datafusion_expr - datafusion_examples -> datafusion_optimizer - datafusion_examples -> datafusion_physical_expr - datafusion_examples -> datafusion_sql - datafusion_expr - datafusion_expr -> datafusion_common - datafusion_functions - datafusion_functions -> datafusion_common - datafusion_functions -> datafusion_execution - datafusion_functions -> datafusion_expr - datafusion_wasmtest - datafusion_wasmtest -> datafusion - datafusion_wasmtest -> datafusion_common - datafusion_wasmtest -> datafusion_execution - datafusion_wasmtest -> datafusion_expr - datafusion_wasmtest -> datafusion_optimizer - datafusion_wasmtest -> datafusion_physical_expr - datafusion_wasmtest -> datafusion_physical_plan - datafusion_wasmtest -> datafusion_sql - datafusion_common - datafusion_sql - datafusion_sql -> datafusion_common - datafusion_sql -> datafusion_expr - datafusion_physical_plan - datafusion_physical_plan -> datafusion_common - datafusion_physical_plan -> datafusion_execution - datafusion_physical_plan -> datafusion_expr - datafusion_physical_plan -> datafusion_physical_expr - datafusion_benchmarks - datafusion_benchmarks -> datafusion - datafusion_benchmarks -> datafusion_common - datafusion_benchmarks -> datafusion_proto - datafusion_docs_tests - datafusion_docs_tests -> datafusion - datafusion_optimizer - datafusion_optimizer -> datafusion_common - datafusion_optimizer -> datafusion_expr - datafusion_optimizer -> datafusion_physical_expr - datafusion_optimizer -> datafusion_sql - datafusion_proto - datafusion_proto -> datafusion - datafusion_proto -> datafusion_common - datafusion_proto -> datafusion_expr - datafusion_physical_expr - datafusion_physical_expr -> datafusion_common - datafusion_physical_expr -> datafusion_execution - datafusion_physical_expr -> datafusion_expr - datafusion_sqllogictest - datafusion_sqllogictest -> datafusion - datafusion_sqllogictest -> datafusion_common - datafusion - datafusion -> datafusion_common - datafusion -> datafusion_execution - datafusion -> datafusion_expr - datafusion -> datafusion_functions - datafusion -> datafusion_functions_array - datafusion -> datafusion_optimizer - datafusion -> datafusion_physical_expr - datafusion -> datafusion_physical_plan - datafusion -> datafusion_sql - datafusion_functions_array - datafusion_functions_array -> datafusion_common - datafusion_functions_array -> datafusion_execution - datafusion_functions_array -> datafusion_expr - datafusion_execution - datafusion_execution -> datafusion_common - datafusion_execution -> datafusion_expr - datafusion_substrait - datafusion_substrait -> datafusion -} \ No newline at end of file diff --git a/dev/release/crate-deps.svg b/dev/release/crate-deps.svg deleted file mode 100644 index cf60bf752642..000000000000 --- a/dev/release/crate-deps.svg +++ /dev/null @@ -1,445 +0,0 @@ - - - - - - -G - - - -datafusion_examples - -datafusion_examples - - - -datafusion - -datafusion - - - -datafusion_examples->datafusion - - - - - -datafusion_common - -datafusion_common - - - -datafusion_examples->datafusion_common - - - - - -datafusion_expr - -datafusion_expr - - - -datafusion_examples->datafusion_expr - - - - - -datafusion_optimizer - -datafusion_optimizer - - - -datafusion_examples->datafusion_optimizer - - - - - -datafusion_physical_expr - -datafusion_physical_expr - - - -datafusion_examples->datafusion_physical_expr - - - - - -datafusion_sql - -datafusion_sql - - - -datafusion_examples->datafusion_sql - - - - - -datafusion->datafusion_common - - - - - -datafusion->datafusion_expr - - - - - -datafusion->datafusion_optimizer - - - - - -datafusion->datafusion_physical_expr - - - - - -datafusion->datafusion_sql - - - - - -datafusion_functions - -datafusion_functions - - - -datafusion->datafusion_functions - - - - - -datafusion_execution - -datafusion_execution - - - -datafusion->datafusion_execution - - - - - -datafusion_physical_plan - -datafusion_physical_plan - - - -datafusion->datafusion_physical_plan - - - - - -datafusion_functions_array - -datafusion_functions_array - - - -datafusion->datafusion_functions_array - - - - - -datafusion_expr->datafusion_common - - - - - -datafusion_optimizer->datafusion_common - - - - - -datafusion_optimizer->datafusion_expr - - - - - -datafusion_optimizer->datafusion_physical_expr - - - - - -datafusion_optimizer->datafusion_sql - - - - - -datafusion_physical_expr->datafusion_common - - - - - -datafusion_physical_expr->datafusion_expr - - - - - -datafusion_physical_expr->datafusion_execution - - - - - -datafusion_sql->datafusion_common - - - - - -datafusion_sql->datafusion_expr - - - - - -datafusion_functions->datafusion_common - - - - - -datafusion_functions->datafusion_expr - - - - - -datafusion_functions->datafusion_execution - - - - - -datafusion_execution->datafusion_common - - - - - -datafusion_execution->datafusion_expr - - - - - -datafusion_wasmtest - -datafusion_wasmtest - - - -datafusion_wasmtest->datafusion - - - - - -datafusion_wasmtest->datafusion_common - - - - - -datafusion_wasmtest->datafusion_expr - - - - - -datafusion_wasmtest->datafusion_optimizer - - - - - -datafusion_wasmtest->datafusion_physical_expr - - - - - -datafusion_wasmtest->datafusion_sql - - - - - -datafusion_wasmtest->datafusion_execution - - - - - -datafusion_wasmtest->datafusion_physical_plan - - - - - -datafusion_physical_plan->datafusion_common - - - - - -datafusion_physical_plan->datafusion_expr - - - - - -datafusion_physical_plan->datafusion_physical_expr - - - - - -datafusion_physical_plan->datafusion_execution - - - - - -datafusion_benchmarks - -datafusion_benchmarks - - - -datafusion_benchmarks->datafusion - - - - - -datafusion_benchmarks->datafusion_common - - - - - -datafusion_proto - -datafusion_proto - - - -datafusion_benchmarks->datafusion_proto - - - - - -datafusion_proto->datafusion - - - - - -datafusion_proto->datafusion_common - - - - - -datafusion_proto->datafusion_expr - - - - - -datafusion_docs_tests - -datafusion_docs_tests - - - -datafusion_docs_tests->datafusion - - - - - -datafusion_sqllogictest - -datafusion_sqllogictest - - - -datafusion_sqllogictest->datafusion - - - - - -datafusion_sqllogictest->datafusion_common - - - - - -datafusion_functions_array->datafusion_common - - - - - -datafusion_functions_array->datafusion_expr - - - - - -datafusion_functions_array->datafusion_execution - - - - - -datafusion_substrait - -datafusion_substrait - - - -datafusion_substrait->datafusion - - - - - diff --git a/dev/release/create-tarball.sh b/dev/release/create-tarball.sh index e345773287cf..c43e02bbbba5 100755 --- a/dev/release/create-tarball.sh +++ b/dev/release/create-tarball.sh @@ -21,9 +21,9 @@ # Adapted from https://github.com/apache/arrow-rs/tree/master/dev/release/create-tarball.sh # This script creates a signed tarball in -# dev/dist/apache-arrow-datafusion--.tar.gz and uploads it to -# the "dev" area of the dist.apache.arrow repository and prepares an -# email for sending to the dev@arrow.apache.org list for a formal +# dev/dist/apache-datafusion--.tar.gz and uploads it to +# the "dev" area of the dist.apache.datafusion repository and prepares an +# email for sending to the dev@datafusion.apache.org list for a formal # vote. # # See release/README.md for full release instructions @@ -65,22 +65,22 @@ tag="${version}-rc${rc}" echo "Attempting to create ${tarball} from tag ${tag}" release_hash=$(cd "${SOURCE_TOP_DIR}" && git rev-list --max-count=1 ${tag}) -release=apache-arrow-datafusion-${version} +release=apache-datafusion-${version} distdir=${SOURCE_TOP_DIR}/dev/dist/${release}-rc${rc} tarname=${release}.tar.gz tarball=${distdir}/${tarname} -url="https://dist.apache.org/repos/dist/dev/arrow/${release}-rc${rc}" +url="https://dist.apache.org/repos/dist/dev/datafusion/${release}-rc${rc}" if [ -z "$release_hash" ]; then echo "Cannot continue: unknown git tag: ${tag}" fi -echo "Draft email for dev@arrow.apache.org mailing list" +echo "Draft email for dev@datafusion.apache.org mailing list" echo "" echo "---------------------------------------------------------" cat < ${tarball}.sha256 (cd ${distdir} && shasum -a 512 ${tarname}) > ${tarball}.sha512 -echo "Uploading to apache dist/dev to ${url}" -svn co --depth=empty https://dist.apache.org/repos/dist/dev/arrow ${SOURCE_TOP_DIR}/dev/dist +echo "Uploading to datafusion dist/dev to ${url}" +svn co --depth=empty https://dist.apache.org/repos/dist/dev/datafusion ${SOURCE_TOP_DIR}/dev/dist svn add ${distdir} svn ci -m "Apache DataFusion ${version} ${rc}" ${distdir} diff --git a/dev/release/generate-changelog.py b/dev/release/generate-changelog.py index 74e77ce846e5..23b594214823 100755 --- a/dev/release/generate-changelog.py +++ b/dev/release/generate-changelog.py @@ -20,7 +20,7 @@ from github import Github import os import re - +import subprocess def print_pulls(repo_name, title, pulls): if len(pulls) > 0: @@ -32,7 +32,7 @@ def print_pulls(repo_name, title, pulls): print() -def generate_changelog(repo, repo_name, tag1, tag2): +def generate_changelog(repo, repo_name, tag1, tag2, version): # get a list of commits between two tags print(f"Fetching list of commits between {tag1} and {tag2}", file=sys.stderr) @@ -52,12 +52,12 @@ def generate_changelog(repo, repo_name, tag1, tag2): all_pulls.append((pull, commit)) # we split the pulls into categories - #TODO: make categories configurable breaking = [] bugs = [] docs = [] enhancements = [] performance = [] + other = [] # categorize the pull requests based on GitHub labels print("Categorizing pull requests", file=sys.stderr) @@ -75,7 +75,6 @@ def generate_changelog(repo, repo_name, tag1, tag2): cc_breaking = parts_tuple[2] == '!' labels = [label.name for label in pull.labels] - #print(pull.number, labels, parts, file=sys.stderr) if 'api change' in labels or cc_breaking: breaking.append((pull, commit)) elif 'bug' in labels or cc_type == 'fix': @@ -84,18 +83,64 @@ def generate_changelog(repo, repo_name, tag1, tag2): performance.append((pull, commit)) elif 'enhancement' in labels or cc_type == 'feat': enhancements.append((pull, commit)) - elif 'documentation' in labels or cc_type == 'docs': + elif 'documentation' in labels or cc_type == 'docs' or cc_type == 'doc': docs.append((pull, commit)) + else: + other.append((pull, commit)) # produce the changelog content print("Generating changelog content", file=sys.stderr) + + # ASF header + print("""\n""") + + print(f"# Apache DataFusion {version} Changelog\n") + + # get the number of commits + commit_count = subprocess.check_output(f"git log --pretty=oneline {tag1}..{tag2} | wc -l", shell=True, text=True).strip() + + # get number of contributors + contributor_count = subprocess.check_output(f"git shortlog -sn {tag1}..{tag2} | wc -l", shell=True, text=True).strip() + + print(f"This release consists of {commit_count} commits from {contributor_count} contributors. " + f"See credits at the end of this changelog for more information.\n") + print_pulls(repo_name, "Breaking changes", breaking) print_pulls(repo_name, "Performance related", performance) print_pulls(repo_name, "Implemented enhancements", enhancements) print_pulls(repo_name, "Fixed bugs", bugs) print_pulls(repo_name, "Documentation updates", docs) - print_pulls(repo_name, "Merged pull requests", all_pulls) + print_pulls(repo_name, "Other", other) + + # show code contributions + credits = subprocess.check_output(f"git shortlog -sn {tag1}..{tag2}", shell=True, text=True).rstrip() + + print("## Credits\n") + print("Thank you to everyone who contributed to this release. Here is a breakdown of commits (PRs merged) " + "per contributor.\n") + print("```") + print(credits) + print("```\n") + print("Thank you also to everyone who contributed in other ways such as filing issues, reviewing " + "PRs, and providing feedback on this release.\n") def cli(args=None): """Process command line arguments.""" @@ -103,16 +148,17 @@ def cli(args=None): args = sys.argv[1:] parser = argparse.ArgumentParser() - parser.add_argument("project", help="The project name e.g. apache/datafusion") - parser.add_argument("tag1", help="The previous release tag") - parser.add_argument("tag2", help="The current release tag") + parser.add_argument("tag1", help="The previous commit or tag (e.g. 0.1.0)") + parser.add_argument("tag2", help="The current commit or tag (e.g. HEAD)") + parser.add_argument("version", help="The version number to include in the changelog") args = parser.parse_args() token = os.getenv("GITHUB_TOKEN") + project = "apache/datafusion" g = Github(token) - repo = g.get_repo(args.project) - generate_changelog(repo, args.project, args.tag1, args.tag2) + repo = g.get_repo(project) + generate_changelog(repo, project, args.tag1, args.tag2, args.version) if __name__ == "__main__": cli() \ No newline at end of file diff --git a/dev/release/publish_homebrew.sh b/dev/release/publish_homebrew.sh index 1cf7160d4284..20955953e85a 100644 --- a/dev/release/publish_homebrew.sh +++ b/dev/release/publish_homebrew.sh @@ -39,8 +39,8 @@ else # Fallback num_processing_units=1 fi -url="https://www.apache.org/dyn/closer.lua?path=arrow/arrow-datafusion-${version}/apache-arrow-datafusion-${version}.tar.gz" -sha256="$(curl https://dist.apache.org/repos/dist/release/arrow/arrow-datafusion-${version}/apache-arrow-datafusion-${version}.tar.gz.sha256 | cut -d' ' -f1)" +url="https://www.apache.org/dyn/closer.lua?path=datafusion/datafusion-${version}/apache-datafusion-${version}.tar.gz" +sha256="$(curl https://dist.apache.org/repos/dist/release/datafusion/datafusion-${version}/apache-datafusion-${version}.tar.gz.sha256 | cut -d' ' -f1)" pushd "$(brew --repository homebrew/core)" @@ -52,7 +52,7 @@ fi echo "Updating working copy" git fetch --all --prune --tags --force -j$num_processing_units -branch=apache-arrow-datafusion-${version} +branch=apache-datafusion-${version} echo "Creating branch: ${branch}" git branch -D ${branch} || : git checkout -b ${branch} origin/master diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt index ce5635b6daf4..7953a5b4e291 100644 --- a/dev/release/rat_exclude_files.txt +++ b/dev/release/rat_exclude_files.txt @@ -15,84 +15,8 @@ ci/etc/*.patch ci/vcpkg/*.patch CHANGELOG.md datafusion/CHANGELOG.md -python/CHANGELOG.md -conbench/benchmarks.json -conbench/requirements.txt -conbench/requirements-test.txt -conbench/.flake8 -conbench/.isort.cfg dev/requirements*.txt -dev/archery/MANIFEST.in -dev/archery/requirements*.txt -dev/archery/archery/tests/fixtures/* -dev/archery/archery/crossbow/tests/fixtures/* dev/release/rat_exclude_files.txt -dev/tasks/homebrew-formulae/apache-arrow.rb -dev/tasks/linux-packages/apache-arrow-apt-source/debian/apache-arrow-apt-source.install -dev/tasks/linux-packages/apache-arrow-apt-source/debian/compat -dev/tasks/linux-packages/apache-arrow-apt-source/debian/control -dev/tasks/linux-packages/apache-arrow-apt-source/debian/rules -dev/tasks/linux-packages/apache-arrow-apt-source/debian/source/format -dev/tasks/linux-packages/apache-arrow/debian/compat -dev/tasks/linux-packages/apache-arrow/debian/control.in -dev/tasks/linux-packages/apache-arrow/debian/gir1.2-arrow-1.0.install -dev/tasks/linux-packages/apache-arrow/debian/gir1.2-arrow-cuda-1.0.install -dev/tasks/linux-packages/apache-arrow/debian/gir1.2-arrow-dataset-1.0.install -dev/tasks/linux-packages/apache-arrow/debian/gir1.2-gandiva-1.0.install -dev/tasks/linux-packages/apache-arrow/debian/gir1.2-parquet-1.0.install -dev/tasks/linux-packages/apache-arrow/debian/gir1.2-plasma-1.0.install -dev/tasks/linux-packages/apache-arrow/debian/libarrow-dev.install -dev/tasks/linux-packages/apache-arrow/debian/libarrow-glib-dev.install -dev/tasks/linux-packages/apache-arrow/debian/libarrow-glib-doc.doc-base -dev/tasks/linux-packages/apache-arrow/debian/libarrow-glib-doc.install -dev/tasks/linux-packages/apache-arrow/debian/libarrow-glib-doc.links -dev/tasks/linux-packages/apache-arrow/debian/libarrow-glib400.install -dev/tasks/linux-packages/apache-arrow/debian/libarrow-cuda-dev.install -dev/tasks/linux-packages/apache-arrow/debian/libarrow-cuda-glib-dev.install -dev/tasks/linux-packages/apache-arrow/debian/libarrow-cuda-glib400.install -dev/tasks/linux-packages/apache-arrow/debian/libarrow-cuda400.install -dev/tasks/linux-packages/apache-arrow/debian/libarrow-dataset-dev.install -dev/tasks/linux-packages/apache-arrow/debian/libarrow-dataset-glib-dev.install -dev/tasks/linux-packages/apache-arrow/debian/libarrow-dataset-glib-doc.doc-base -dev/tasks/linux-packages/apache-arrow/debian/libarrow-dataset-glib-doc.install -dev/tasks/linux-packages/apache-arrow/debian/libarrow-dataset-glib-doc.links -dev/tasks/linux-packages/apache-arrow/debian/libarrow-dataset-glib400.install -dev/tasks/linux-packages/apache-arrow/debian/libarrow-dataset400.install -dev/tasks/linux-packages/apache-arrow/debian/libarrow-flight-dev.install -dev/tasks/linux-packages/apache-arrow/debian/libarrow-flight400.install -dev/tasks/linux-packages/apache-arrow/debian/libarrow-python-dev.install -dev/tasks/linux-packages/apache-arrow/debian/libarrow-python-flight-dev.install -dev/tasks/linux-packages/apache-arrow/debian/libarrow-python-flight400.install -dev/tasks/linux-packages/apache-arrow/debian/libarrow-python400.install -dev/tasks/linux-packages/apache-arrow/debian/libarrow400.install -dev/tasks/linux-packages/apache-arrow/debian/libgandiva-dev.install -dev/tasks/linux-packages/apache-arrow/debian/libgandiva-glib-dev.install -dev/tasks/linux-packages/apache-arrow/debian/libgandiva-glib-doc.doc-base -dev/tasks/linux-packages/apache-arrow/debian/libgandiva-glib-doc.install -dev/tasks/linux-packages/apache-arrow/debian/libgandiva-glib-doc.links -dev/tasks/linux-packages/apache-arrow/debian/libgandiva-glib400.install -dev/tasks/linux-packages/apache-arrow/debian/libgandiva400.install -dev/tasks/linux-packages/apache-arrow/debian/libparquet-dev.install -dev/tasks/linux-packages/apache-arrow/debian/libparquet-glib-dev.install -dev/tasks/linux-packages/apache-arrow/debian/libparquet-glib-doc.doc-base -dev/tasks/linux-packages/apache-arrow/debian/libparquet-glib-doc.install -dev/tasks/linux-packages/apache-arrow/debian/libparquet-glib-doc.links -dev/tasks/linux-packages/apache-arrow/debian/libparquet-glib400.install -dev/tasks/linux-packages/apache-arrow/debian/libparquet400.install -dev/tasks/linux-packages/apache-arrow/debian/libplasma-dev.install -dev/tasks/linux-packages/apache-arrow/debian/libplasma-glib-dev.install -dev/tasks/linux-packages/apache-arrow/debian/libplasma-glib-doc.doc-base -dev/tasks/linux-packages/apache-arrow/debian/libplasma-glib-doc.install -dev/tasks/linux-packages/apache-arrow/debian/libplasma-glib-doc.links -dev/tasks/linux-packages/apache-arrow/debian/libplasma-glib400.install -dev/tasks/linux-packages/apache-arrow/debian/libplasma400.install -dev/tasks/linux-packages/apache-arrow/debian/patches/series -dev/tasks/linux-packages/apache-arrow/debian/plasma-store-server.install -dev/tasks/linux-packages/apache-arrow/debian/rules -dev/tasks/linux-packages/apache-arrow/debian/source/format -dev/tasks/linux-packages/apache-arrow/debian/watch -dev/tasks/requirements*.txt -dev/tasks/conda-recipes/* pax_global_header MANIFEST.in __init__.pxd @@ -109,8 +33,6 @@ requirements.txt .gitattributes rust-toolchain benchmarks/queries/q*.sql -python/rust-toolchain -python/requirements*.txt **/testdata/* benchmarks/queries/* benchmarks/expected-plans/* @@ -130,8 +52,11 @@ Cargo.lock .history parquet-testing/* *rat.txt +datafusion/proto/src/generated/datafusion_proto_common.rs datafusion/proto/src/generated/pbjson.rs datafusion/proto/src/generated/prost.rs +datafusion/proto-common/src/generated/pbjson.rs +datafusion/proto-common/src/generated/prost.rs .github/ISSUE_TEMPLATE/bug_report.yml .github/ISSUE_TEMPLATE/feature_request.yml .github/workflows/docs.yaml diff --git a/dev/release/release-crates.sh b/dev/release/release-crates.sh deleted file mode 100644 index 00ce77a86749..000000000000 --- a/dev/release/release-crates.sh +++ /dev/null @@ -1,45 +0,0 @@ -#!/bin/bash -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -# This script publishes datafusion crates to crates.io. -# -# This script should only be run after the release has been approved -# by the arrow PMC committee. -# -# See release/README.md for full release instructions - -set -eu - -# Do not run inside a git repo -if ! [ git rev-parse --is-inside-work-tree ]; then - cd datafusion/common && cargo publish - cd datafusion/expr && cargo publish - cd datafusion/sql && cargo publish - cd datafusion/physical-expr && cargo publish - cd datafusion/optimizer && cargo publish - cd datafusion/core && cargo publish - cd datafusion/proto && cargo publish - cd datafusion/execution && cargo publish - cd datafusion/substrait && cargo publish - cd datafusion-cli && cargo publish --no-verify -else - echo "Crates must be released from the source tarball that was voted on, not from the repo" - exit 1 -fi diff --git a/dev/release/release-tarball.sh b/dev/release/release-tarball.sh index 74a4bab3aecd..bd858d23a767 100755 --- a/dev/release/release-tarball.sh +++ b/dev/release/release-tarball.sh @@ -21,10 +21,10 @@ # Adapted from https://github.com/apache/arrow-rs/tree/master/dev/release/release-tarball.sh # This script copies a tarball from the "dev" area of the -# dist.apache.arrow repository to the "release" area +# dist.apache.datafusion repository to the "release" area # # This script should only be run after the release has been approved -# by the arrow PMC committee. +# by the Apache DataFusion PMC committee. # # See release/README.md for full release instructions # @@ -43,7 +43,7 @@ fi version=$1 rc=$2 -tmp_dir=tmp-apache-arrow-datafusion-dist +tmp_dir=tmp-apache-datafusion-dist echo "Recreate temporary directory: ${tmp_dir}" rm -rf ${tmp_dir} @@ -52,14 +52,14 @@ mkdir -p ${tmp_dir} echo "Clone dev dist repository" svn \ co \ - https://dist.apache.org/repos/dist/dev/arrow/apache-arrow-datafusion-${version}-rc${rc} \ + https://dist.apache.org/repos/dist/dev/datafusion/apache-datafusion-${version}-rc${rc} \ ${tmp_dir}/dev echo "Clone release dist repository" -svn co https://dist.apache.org/repos/dist/release/arrow ${tmp_dir}/release +svn co https://dist.apache.org/repos/dist/release/datafusion ${tmp_dir}/release echo "Copy ${version}-rc${rc} to release working copy" -release_version=arrow-datafusion-${version} +release_version=datafusion-${version} mkdir -p ${tmp_dir}/release/${release_version} cp -r ${tmp_dir}/dev/* ${tmp_dir}/release/${release_version}/ svn add ${tmp_dir}/release/${release_version} @@ -71,4 +71,4 @@ echo "Clean up" rm -rf ${tmp_dir} echo "Success! The release is available here:" -echo " https://dist.apache.org/repos/dist/release/arrow/${release_version}" +echo " https://dist.apache.org/repos/dist/release/datafusion/${release_version}" diff --git a/dev/release/verify-release-candidate.sh b/dev/release/verify-release-candidate.sh index 45e984dec3a0..2c0bd216b3ac 100755 --- a/dev/release/verify-release-candidate.sh +++ b/dev/release/verify-release-candidate.sh @@ -33,7 +33,7 @@ set -o pipefail SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]:-$0}")" && pwd)" ARROW_DIR="$(dirname $(dirname ${SOURCE_DIR}))" -ARROW_DIST_URL='https://dist.apache.org/repos/dist/dev/arrow' +ARROW_DIST_URL='https://dist.apache.org/repos/dist/dev/datafusion' download_dist_file() { curl \ @@ -45,7 +45,7 @@ download_dist_file() { } download_rc_file() { - download_dist_file apache-arrow-datafusion-${VERSION}-rc${RC_NUMBER}/$1 + download_dist_file apache-datafusion-${VERSION}-rc${RC_NUMBER}/$1 } import_gpg_keys() { @@ -143,11 +143,11 @@ test_source_distribution() { TEST_SUCCESS=no -setup_tempdir "arrow-${VERSION}" +setup_tempdir "datafusion-${VERSION}" echo "Working in sandbox ${ARROW_TMPDIR}" cd ${ARROW_TMPDIR} -dist_name="apache-arrow-datafusion-${VERSION}" +dist_name="apache-datafusion-${VERSION}" import_gpg_keys fetch_archive ${dist_name} tar xf ${dist_name}.tar.gz diff --git a/dev/requirements.txt b/dev/requirements.txt new file mode 100644 index 000000000000..7fcba0493129 --- /dev/null +++ b/dev/requirements.txt @@ -0,0 +1,2 @@ +tomlkit +PyGitHub \ No newline at end of file diff --git a/dev/update_arrow_deps.py b/dev/update_arrow_deps.py index b685ad2738b1..268ded38f6e8 100755 --- a/dev/update_arrow_deps.py +++ b/dev/update_arrow_deps.py @@ -17,7 +17,7 @@ # limitations under the License. # -# Script that updates the arrow dependencies in datafusion and ballista, locally +# Script that updates the arrow dependencies in datafusion locally # # installation: # pip install tomlkit requests diff --git a/dev/update_config_docs.sh b/dev/update_config_docs.sh index 836ba6772eac..585cb77839f9 100755 --- a/dev/update_config_docs.sh +++ b/dev/update_config_docs.sh @@ -24,7 +24,7 @@ SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" cd "${SOURCE_DIR}/../" && pwd TARGET_FILE="docs/source/user-guide/configs.md" -PRINT_DOCS_COMMAND="cargo run --manifest-path datafusion/core/Cargo.toml --bin print_config_docs" +PRINT_CONFIG_DOCS_COMMAND="cargo run --manifest-path datafusion/core/Cargo.toml --bin print_config_docs" echo "Inserting header" cat <<'EOF' > "$TARGET_FILE" @@ -67,8 +67,8 @@ Environment variables are read during `SessionConfig` initialisation so they mus EOF -echo "Running CLI and inserting docs table" -$PRINT_DOCS_COMMAND >> "$TARGET_FILE" +echo "Running CLI and inserting config docs table" +$PRINT_CONFIG_DOCS_COMMAND >> "$TARGET_FILE" echo "Running prettier" npx prettier@2.3.2 --write "$TARGET_FILE" diff --git a/dev/update_datafusion_versions.py b/dev/update_datafusion_versions.py index 12b0a90d4ab6..2e3374cd920b 100755 --- a/dev/update_datafusion_versions.py +++ b/dev/update_datafusion_versions.py @@ -29,13 +29,16 @@ crates = { 'datafusion-common': 'datafusion/common/Cargo.toml', + 'datafusion-common-runtime': 'datafusion/common-runtime/Cargo.toml', 'datafusion': 'datafusion/core/Cargo.toml', 'datafusion-execution': 'datafusion/execution/Cargo.toml', 'datafusion-expr': 'datafusion/expr/Cargo.toml', 'datafusion-functions': 'datafusion/functions/Cargo.toml', - 'datafusion-functions-array': 'datafusion/functions-array/Cargo.toml', + 'datafusion-functions-aggregate': 'datafusion/functions-aggregate/Cargo.toml', + 'datafusion-functions-nested': 'datafusion/functions-nested/Cargo.toml', 'datafusion-optimizer': 'datafusion/optimizer/Cargo.toml', 'datafusion-physical-expr': 'datafusion/physical-expr/Cargo.toml', + 'datafusion-physical-expr-common': 'datafusion/physical-expr-common/Cargo.toml', 'datafusion-physical-plan': 'datafusion/physical-plan/Cargo.toml', 'datafusion-proto': 'datafusion/proto/Cargo.toml', 'datafusion-sql': 'datafusion/sql/Cargo.toml', diff --git a/dev/update_function_docs.sh b/dev/update_function_docs.sh new file mode 100755 index 000000000000..ad3bc9c7f69c --- /dev/null +++ b/dev/update_function_docs.sh @@ -0,0 +1,287 @@ +#!/bin/bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +set -e + +SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "${SOURCE_DIR}/../" && pwd + + +TARGET_FILE="docs/source/user-guide/sql/aggregate_functions.md" +PRINT_AGGREGATE_FUNCTION_DOCS_COMMAND="cargo run --manifest-path datafusion/core/Cargo.toml --bin print_functions_docs -- aggregate" + +echo "Inserting header" +cat <<'EOF' > "$TARGET_FILE" + + + + +# Aggregate Functions + +Aggregate functions operate on a set of values to compute a single result. +EOF + +echo "Running CLI and inserting aggregate function docs table" +$PRINT_AGGREGATE_FUNCTION_DOCS_COMMAND >> "$TARGET_FILE" + +echo "Running prettier" +npx prettier@2.3.2 --write "$TARGET_FILE" + +echo "'$TARGET_FILE' successfully updated!" + +TARGET_FILE="docs/source/user-guide/sql/scalar_functions.md" +PRINT_SCALAR_FUNCTION_DOCS_COMMAND="cargo run --manifest-path datafusion/core/Cargo.toml --bin print_functions_docs -- scalar" + +echo "Inserting header" +cat <<'EOF' > "$TARGET_FILE" + + + + +# Scalar Functions + +EOF + +echo "Running CLI and inserting scalar function docs table" +$PRINT_SCALAR_FUNCTION_DOCS_COMMAND >> "$TARGET_FILE" + +echo "Running prettier" +npx prettier@2.3.2 --write "$TARGET_FILE" + +echo "'$TARGET_FILE' successfully updated!" + +TARGET_FILE="docs/source/user-guide/sql/window_functions_new.md" +PRINT_WINDOW_FUNCTION_DOCS_COMMAND="cargo run --manifest-path datafusion/core/Cargo.toml --bin print_functions_docs -- window" + +echo "Inserting header" +cat <<'EOF' > "$TARGET_FILE" + + + + + +# Window Functions (NEW) + +Note: this documentation is in the process of being migrated to be [automatically created from the codebase]. +Please see the [Window Functions (Old)](window_functions.md) page for +the rest of the documentation. + +[automatically created from the codebase]: https://github.com/apache/datafusion/issues/12740 + +A _window function_ performs a calculation across a set of table rows that are somehow related to the current row. +This is comparable to the type of calculation that can be done with an aggregate function. +However, window functions do not cause rows to become grouped into a single output row like non-window aggregate calls would. +Instead, the rows retain their separate identities. Behind the scenes, the window function is able to access more than just the current row of the query result + +Here is an example that shows how to compare each employee's salary with the average salary in his or her department: + +```sql +SELECT depname, empno, salary, avg(salary) OVER (PARTITION BY depname) FROM empsalary; + ++-----------+-------+--------+-------------------+ +| depname | empno | salary | avg | ++-----------+-------+--------+-------------------+ +| personnel | 2 | 3900 | 3700.0 | +| personnel | 5 | 3500 | 3700.0 | +| develop | 8 | 6000 | 5020.0 | +| develop | 10 | 5200 | 5020.0 | +| develop | 11 | 5200 | 5020.0 | +| develop | 9 | 4500 | 5020.0 | +| develop | 7 | 4200 | 5020.0 | +| sales | 1 | 5000 | 4866.666666666667 | +| sales | 4 | 4800 | 4866.666666666667 | +| sales | 3 | 4800 | 4866.666666666667 | ++-----------+-------+--------+-------------------+ +``` + +A window function call always contains an OVER clause directly following the window function's name and argument(s). This is what syntactically distinguishes it from a normal function or non-window aggregate. The OVER clause determines exactly how the rows of the query are split up for processing by the window function. The PARTITION BY clause within OVER divides the rows into groups, or partitions, that share the same values of the PARTITION BY expression(s). For each row, the window function is computed across the rows that fall into the same partition as the current row. The previous example showed how to count the average of a column per partition. + +You can also control the order in which rows are processed by window functions using ORDER BY within OVER. (The window ORDER BY does not even have to match the order in which the rows are output.) Here is an example: + +```sql +SELECT depname, empno, salary, + rank() OVER (PARTITION BY depname ORDER BY salary DESC) +FROM empsalary; + ++-----------+-------+--------+--------+ +| depname | empno | salary | rank | ++-----------+-------+--------+--------+ +| personnel | 2 | 3900 | 1 | +| develop | 8 | 6000 | 1 | +| develop | 10 | 5200 | 2 | +| develop | 11 | 5200 | 2 | +| develop | 9 | 4500 | 4 | +| develop | 7 | 4200 | 5 | +| sales | 1 | 5000 | 1 | +| sales | 4 | 4800 | 2 | +| personnel | 5 | 3500 | 2 | +| sales | 3 | 4800 | 2 | ++-----------+-------+--------+--------+ +``` + +There is another important concept associated with window functions: for each row, there is a set of rows within its partition called its window frame. Some window functions act only on the rows of the window frame, rather than of the whole partition. Here is an example of using window frames in queries: + +```sql +SELECT depname, empno, salary, + avg(salary) OVER(ORDER BY salary ASC ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS avg, + min(salary) OVER(ORDER BY empno ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS cum_min +FROM empsalary +ORDER BY empno ASC; + ++-----------+-------+--------+--------------------+---------+ +| depname | empno | salary | avg | cum_min | ++-----------+-------+--------+--------------------+---------+ +| sales | 1 | 5000 | 5000.0 | 5000 | +| personnel | 2 | 3900 | 3866.6666666666665 | 3900 | +| sales | 3 | 4800 | 4700.0 | 3900 | +| sales | 4 | 4800 | 4866.666666666667 | 3900 | +| personnel | 5 | 3500 | 3700.0 | 3500 | +| develop | 7 | 4200 | 4200.0 | 3500 | +| develop | 8 | 6000 | 5600.0 | 3500 | +| develop | 9 | 4500 | 4500.0 | 3500 | +| develop | 10 | 5200 | 5133.333333333333 | 3500 | +| develop | 11 | 5200 | 5466.666666666667 | 3500 | ++-----------+-------+--------+--------------------+---------+ +``` + +When a query involves multiple window functions, it is possible to write out each one with a separate OVER clause, but this is duplicative and error-prone if the same windowing behavior is wanted for several functions. Instead, each windowing behavior can be named in a WINDOW clause and then referenced in OVER. For example: + +```sql +SELECT sum(salary) OVER w, avg(salary) OVER w +FROM empsalary +WINDOW w AS (PARTITION BY depname ORDER BY salary DESC); +``` + +## Syntax + +The syntax for the OVER-clause is + +``` +function([expr]) + OVER( + [PARTITION BY expr[, …]] + [ORDER BY expr [ ASC | DESC ][, …]] + [ frame_clause ] + ) +``` + +where **frame_clause** is one of: + +``` + { RANGE | ROWS | GROUPS } frame_start + { RANGE | ROWS | GROUPS } BETWEEN frame_start AND frame_end +``` + +and **frame_start** and **frame_end** can be one of + +```sql +UNBOUNDED PRECEDING +offset PRECEDING +CURRENT ROW +offset FOLLOWING +UNBOUNDED FOLLOWING +``` + +where **offset** is an non-negative integer. + +RANGE and GROUPS modes require an ORDER BY clause (with RANGE the ORDER BY must specify exactly one column). + +## Aggregate functions + +All [aggregate functions](aggregate_functions.md) can be used as window functions. + +EOF + +echo "Running CLI and inserting window function docs table" +$PRINT_WINDOW_FUNCTION_DOCS_COMMAND >> "$TARGET_FILE" + +echo "Running prettier" +npx prettier@2.3.2 --write "$TARGET_FILE" + +echo "'$TARGET_FILE' successfully updated!" diff --git a/doap.rdf b/doap.rdf new file mode 100644 index 000000000000..c8b8cb361ad8 --- /dev/null +++ b/doap.rdf @@ -0,0 +1,57 @@ + + + + + + + 2024-04-17 + + Apache DataFusion + + + Apache DataFusion is a fast, extensible query engine for building high-quality data-centric systems in Rust. + + Apache DataFusion is a fast, extensible query engine for building high-quality data-centric systems + in Rust, using the Apache Arrow in-memory format. Python Bindings are also available. DataFusion offers SQL + and Dataframe APIs, excellent performance, built-in support for CSV, Parquet, JSON, and Avro, + extensive customization, and a great community. + + + + + Python + Rust + + + + + + + + + + + + diff --git a/docs/source/_static/images/flamegraph.svg b/docs/source/_static/images/flamegraph.svg new file mode 100644 index 000000000000..951cbb1ff366 --- /dev/null +++ b/docs/source/_static/images/flamegraph.svg @@ -0,0 +1,491 @@ +Flame Graph Reset ZoomSearch datafusion-cli`<tokio::runtime::coop::with_budget::ResetGuard as core::ops::drop::Drop>::drop (16 samples, 0.02%)datafusion-cli`datafusion_cli::main_inner::_{{closure}} (91 samples, 0.11%)datafusion-cli`<tokio::runtime::coop::with_budget::ResetGuard as core::ops::drop::Drop>::drop (69 samples, 0.08%)datafusion-cli`datafusion_cli::exec::exec_from_files::_{{closure}} (19 samples, 0.02%)datafusion-cli`datafusion_cli::exec::exec_and_print::_{{closure}} (101 samples, 0.12%)datafusion-cli`<futures_util::stream::try_stream::try_collect::TryCollect<St,C> as core::future::future::Future>::poll (66 samples, 0.08%)datafusion-cli`<parquet::format::FileMetaData as parquet::thrift::TSerializable>::read_from_in_protocol (13 samples, 0.02%)datafusion-cli`<parquet::format::ColumnChunk as parquet::thrift::TSerializable>::read_from_in_protocol (13 samples, 0.02%)datafusion-cli`datafusion::datasource::file_format::parquet::fetch_parquet_metadata::_{{closure}} (19 samples, 0.02%)datafusion-cli`parquet::file::footer::decode_metadata (17 samples, 0.02%)datafusion-cli`<datafusion::datasource::listing::table::ListingTable as datafusion_catalog::table::TableProvider>::scan::_{{closure}} (23 samples, 0.03%)datafusion-cli`<futures_util::stream::stream::buffered::Buffered<St> as futures_core::stream::Stream>::poll_next (23 samples, 0.03%)datafusion-cli`<futures_util::stream::futures_ordered::FuturesOrdered<Fut> as futures_core::stream::Stream>::poll_next (23 samples, 0.03%)datafusion-cli`<futures_util::stream::futures_unordered::FuturesUnordered<Fut> as futures_core::stream::Stream>::poll_next (23 samples, 0.03%)datafusion-cli`<futures_util::stream::futures_ordered::OrderWrapper<T> as core::future::future::Future>::poll (23 samples, 0.03%)datafusion-cli`<datafusion::datasource::file_format::parquet::ParquetFormat as datafusion::datasource::file_format::FileFormat>::infer_stats::_{{closure}} (23 samples, 0.03%)datafusion-cli`<datafusion::execution::session_state::DefaultQueryPlanner as datafusion::execution::context::QueryPlanner>::create_physical_plan::_{{closure}} (25 samples, 0.03%)datafusion-cli`<datafusion::physical_planner::DefaultPhysicalPlanner as datafusion::physical_planner::PhysicalPlanner>::create_physical_plan::_{{closure}} (25 samples, 0.03%)datafusion-cli`datafusion::physical_planner::DefaultPhysicalPlanner::create_initial_plan::_{{closure}} (24 samples, 0.03%)datafusion-cli`<futures_util::stream::try_stream::try_collect::TryCollect<St,C> as core::future::future::Future>::poll (24 samples, 0.03%)datafusion-cli`<S as futures_core::stream::TryStream>::try_poll_next (24 samples, 0.03%)datafusion-cli`<futures_util::stream::futures_unordered::FuturesUnordered<Fut> as futures_core::stream::Stream>::poll_next (24 samples, 0.03%)datafusion-cli`datafusion::physical_planner::DefaultPhysicalPlanner::map_logical_node_to_physical::_{{closure}} (24 samples, 0.03%)datafusion-cli`<parquet::format::ColumnChunk as parquet::thrift::TSerializable>::read_from_in_protocol (20 samples, 0.02%)datafusion-cli`<parquet::format::FileMetaData as parquet::thrift::TSerializable>::read_from_in_protocol (24 samples, 0.03%)datafusion-cli`<S as futures_core::stream::TryStream>::try_poll_next (36 samples, 0.04%)datafusion-cli`<futures_util::stream::futures_ordered::FuturesOrdered<Fut> as futures_core::stream::Stream>::poll_next (36 samples, 0.04%)datafusion-cli`<futures_util::stream::futures_unordered::FuturesUnordered<Fut> as futures_core::stream::Stream>::poll_next (36 samples, 0.04%)datafusion-cli`<futures_util::stream::futures_ordered::OrderWrapper<T> as core::future::future::Future>::poll (36 samples, 0.04%)datafusion-cli`datafusion::datasource::file_format::parquet::fetch_parquet_metadata::_{{closure}} (35 samples, 0.04%)datafusion-cli`parquet::file::footer::decode_metadata (35 samples, 0.04%)datafusion-cli`<datafusion_cli::catalog::DynamicObjectStoreSchemaProvider as datafusion_catalog::schema::SchemaProvider>::table::_{{closure}} (39 samples, 0.05%)datafusion-cli`<datafusion_catalog::dynamic_file::catalog::DynamicFileSchemaProvider as datafusion_catalog::schema::SchemaProvider>::table::_{{closure}} (39 samples, 0.05%)datafusion-cli`<datafusion::datasource::dynamic_file::DynamicListTableFactory as datafusion_catalog::dynamic_file::catalog::UrlTableFactory>::try_new::_{{closure}} (39 samples, 0.05%)datafusion-cli`datafusion::datasource::listing::table::ListingOptions::infer_schema::_{{closure}} (39 samples, 0.05%)datafusion-cli`<datafusion::datasource::file_format::parquet::ParquetFormat as datafusion::datasource::file_format::FileFormat>::infer_schema::_{{closure}} (39 samples, 0.05%)datafusion-cli`<datafusion_physical_plan::sorts::merge::SortPreservingMergeStream<C> as futures_core::stream::Stream>::poll_next (63 samples, 0.07%)datafusion-cli`<alloc::collections::vec_deque::VecDeque<T,A> as core::clone::Clone>::clone (103 samples, 0.12%)datafusion-cli`mi_malloc_aligned (98 samples, 0.12%)datafusion-cli`<alloc::collections::vec_deque::VecDeque<T,A> as core::clone::Clone>::clone (1,443 samples, 1.69%)libdyld.dylib`tlv_get_addr (57 samples, 0.07%)datafusion-cli`<datafusion_physical_plan::sorts::stream::FieldCursorStream<T> as datafusion_physical_plan::sorts::stream::PartitionedStream>::poll_next (57 samples, 0.07%)datafusion-cli`<datafusion_physical_plan::stream::RecordBatchStreamAdapter<S> as futures_core::stream::Stream>::poll_next (26 samples, 0.03%)datafusion-cli`<futures_util::stream::select_with_strategy::SelectWithStrategy<St1,St2,Clos,State> as futures_core::stream::Stream>::poll_next (57 samples, 0.07%)datafusion-cli`<futures_util::stream::once::Once<Fut> as futures_core::stream::Stream>::poll_next (61 samples, 0.07%)datafusion-cli`tokio::task::join_set::JoinSet<T>::poll_join_next (212 samples, 0.25%)datafusion-cli`tokio::util::idle_notified_set::IdleNotifiedSet<T>::pop_notified (147 samples, 0.17%)datafusion-cli`<futures_util::stream::once::Once<Fut> as futures_core::stream::Stream>::poll_next (361 samples, 0.42%)datafusion-cli`tokio::util::idle_notified_set::IdleNotifiedSet<T>::pop_notified (36 samples, 0.04%)datafusion-cli`<futures_util::stream::stream::filter_map::FilterMap<St,Fut,F> as futures_core::stream::Stream>::poll_next (511 samples, 0.60%)datafusion-cli`tokio::task::join_set::JoinSet<T>::poll_join_next (35 samples, 0.04%)datafusion-cli`<tokio::runtime::coop::RestoreOnPending as core::ops::drop::Drop>::drop (88 samples, 0.10%)datafusion-cli`<tokio::runtime::coop::RestoreOnPending as core::ops::drop::Drop>::drop (58 samples, 0.07%)datafusion-cli`tokio::runtime::park::clone (39 samples, 0.05%)datafusion-cli`tokio::sync::mpsc::list::Rx<T>::pop (276 samples, 0.32%)datafusion-cli`tokio::sync::task::atomic_waker::AtomicWaker::register_by_ref (169 samples, 0.20%)datafusion-cli`tokio::runtime::park::drop_waker (36 samples, 0.04%)datafusion-cli`tokio::sync::mpsc::chan::Rx<T,S>::recv (780 samples, 0.92%)libdyld.dylib`tlv_get_addr (68 samples, 0.08%)datafusion-cli`tokio::sync::mpsc::list::Rx<T>::pop (74 samples, 0.09%)datafusion-cli`tokio::sync::task::atomic_waker::AtomicWaker::register_by_ref (52 samples, 0.06%)datafusion-cli`<futures_util::stream::unfold::Unfold<T,F,Fut> as futures_core::stream::Stream>::poll_next (1,241 samples, 1.46%)libdyld.dylib`tlv_get_addr (88 samples, 0.10%)datafusion-cli`<futures_util::stream::select_with_strategy::SelectWithStrategy<St1,St2,Clos,State> as futures_core::stream::Stream>::poll_next (2,014 samples, 2.37%)da..datafusion-cli`tokio::sync::mpsc::chan::Rx<T,S>::recv (32 samples, 0.04%)datafusion-cli`<futures_util::stream::stream::filter_map::FilterMap<St,Fut,F> as futures_core::stream::Stream>::poll_next (90 samples, 0.11%)datafusion-cli`<futures_util::stream::unfold::Unfold<T,F,Fut> as futures_core::stream::Stream>::poll_next (47 samples, 0.06%)datafusion-cli`<datafusion_physical_plan::sorts::stream::FieldCursorStream<T> as datafusion_physical_plan::sorts::stream::PartitionedStream>::poll_next (2,396 samples, 2.81%)da..datafusion-cli`datafusion_physical_plan::sorts::stream::FusedStreams::poll_next (2,284 samples, 2.68%)da..datafusion-cli`futures_util::stream::select::select::round_robin (26 samples, 0.03%)datafusion-cli`datafusion_physical_plan::sorts::merge::SortPreservingMergeStream<C>::maybe_poll_stream (2,570 samples, 3.02%)dat..datafusion-cli`datafusion_physical_plan::sorts::stream::FusedStreams::poll_next (81 samples, 0.10%)datafusion-cli`<datafusion_physical_plan::sorts::merge::SortPreservingMergeStream<C> as futures_core::stream::Stream>::poll_next (4,392 samples, 5.16%)datafu..datafusion-cli`tokio::runtime::park::wake_by_ref (25 samples, 0.03%)datafusion-cli`datafusion_physical_plan::sorts::merge::SortPreservingMergeStream<C>::maybe_poll_stream (36 samples, 0.04%)datafusion-cli`mi_free (86 samples, 0.10%)datafusion-cli`<futures_util::stream::try_stream::try_collect::TryCollect<St,C> as core::future::future::Future>::poll (4,796 samples, 5.63%)datafus..datafusion-cli`tokio::runtime::park::wake_by_ref (23 samples, 0.03%)datafusion-cli`datafusion_cli::exec::exec_from_files::_{{closure}} (5,412 samples, 6.36%)datafusi..datafusion-cli`datafusion_cli::exec::exec_from_lines::_{{closure}} (5,240 samples, 6.15%)datafusi..datafusion-cli`datafusion_cli::exec::exec_and_print::_{{closure}} (5,081 samples, 5.97%)datafusi..datafusion-cli`datafusion_cli::main_inner::_{{closure}} (5,579 samples, 6.55%)datafusio..datafusion-cli`datafusion_cli::exec::exec_from_lines::_{{closure}} (61 samples, 0.07%)datafusion-cli`tokio::runtime::park::CachedParkThread::park (39 samples, 0.05%)datafusion-cli`tokio::runtime::park::Inner::park (62 samples, 0.07%)datafusion-cli`tokio::runtime::park::CachedParkThread::block_on (5,952 samples, 6.99%)datafusio..libdyld.dylib`tlv_get_addr (77 samples, 0.09%)datafusion-cli`tokio::runtime::park::Inner::park (27 samples, 0.03%)datafusion-cli`std::rt::lang_start::_{{closure}} (6,136 samples, 7.21%)datafusion..datafusion-cli`std::sys::backtrace::__rust_begin_short_backtrace (6,136 samples, 7.21%)datafusion..datafusion-cli`datafusion_cli::main (6,136 samples, 7.21%)datafusion..libdyld.dylib`tlv_get_addr (50 samples, 0.06%)datafusion-cli`main (6,137 samples, 7.21%)datafusion..datafusion-cli`std::rt::lang_start_internal (6,137 samples, 7.21%)datafusion..datafusion-cli`mi_arenas_try_purge (51 samples, 0.06%)datafusion-cli`mi_arena_purge (51 samples, 0.06%)libsystem_kernel.dylib`madvise (51 samples, 0.06%)dyld`start (6,193 samples, 7.27%)dyld`startlibdyld.dylib`dyld4::LibSystemHelpers::getenv (56 samples, 0.07%)libsystem_c.dylib`exit (56 samples, 0.07%)libsystem_c.dylib`__cxa_finalize_ranges (56 samples, 0.07%)datafusion-cli`mi_process_done (56 samples, 0.07%)libsystem_kernel.dylib`__exit (47 samples, 0.06%)datafusion-cli`parking_lot::condvar::Condvar::wait_until_internal (21 samples, 0.02%)libsystem_kernel.dylib`__psynch_cvwait (20 samples, 0.02%)datafusion-cli`tokio::runtime::scheduler::multi_thread::park::Parker::park (30 samples, 0.04%)datafusion-cli`tokio::runtime::time::Driver::park_internal (9 samples, 0.01%)datafusion-cli`tokio::runtime::scheduler::multi_thread::worker::Context::park_timeout (31 samples, 0.04%)datafusion-cli`<datafusion_functions_aggregate::average::AvgGroupsAccumulator<T,F> as datafusion_expr_common::groups_accumulator::GroupsAccumulator>::merge_batch (32 samples, 0.04%)datafusion-cli`datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState::accumulate (25 samples, 0.03%)datafusion-cli`<datafusion_functions_aggregate::count::CountGroupsAccumulator as datafusion_expr_common::groups_accumulator::GroupsAccumulator>::merge_batch (12 samples, 0.01%)datafusion-cli`_mi_malloc_generic (18 samples, 0.02%)datafusion-cli`mi_heap_malloc_zero_aligned_at_generic (19 samples, 0.02%)datafusion-cli`mi_malloc_aligned (12 samples, 0.01%)datafusion-cli`<alloc::vec::Vec<T,A> as core::clone::Clone>::clone (36 samples, 0.04%)datafusion-cli`core::ptr::drop_in_place<datafusion_common::scalar::ScalarValue> (13 samples, 0.02%)datafusion-cli`mi_heap_malloc_zero_aligned_at_generic (10 samples, 0.01%)datafusion-cli`_mi_malloc_generic (9 samples, 0.01%)datafusion-cli`mi_malloc_aligned (9 samples, 0.01%)libsystem_platform.dylib`_platform_memcmp (28 samples, 0.03%)datafusion-cli`<datafusion_functions_aggregate::min_max::MinAccumulator as datafusion_expr_common::accumulator::Accumulator>::update_batch (179 samples, 0.21%)libsystem_platform.dylib`_platform_memmove (23 samples, 0.03%)datafusion-cli`mi_malloc_aligned (10 samples, 0.01%)datafusion-cli`alloc::raw_vec::RawVec<T,A>::grow_one (27 samples, 0.03%)datafusion-cli`alloc::raw_vec::finish_grow (27 samples, 0.03%)datafusion-cli`core::ptr::drop_in_place<arrow_array::array::byte_array::GenericByteArray<arrow_array::types::GenericBinaryType<i32>>> (11 samples, 0.01%)datafusion-cli`alloc::sync::Arc<T,A>::drop_slow (22 samples, 0.03%)datafusion-cli`datafusion_common::scalar::ScalarValue::size (27 samples, 0.03%)datafusion-cli`mi_heap_realloc_zero_aligned_at (9 samples, 0.01%)datafusion-cli`arrow_select::take::take_bytes (61 samples, 0.07%)datafusion-cli`arrow_buffer::buffer::mutable::MutableBuffer::reallocate (49 samples, 0.06%)libsystem_platform.dylib`_platform_memmove (40 samples, 0.05%)datafusion-cli`datafusion_common::utils::get_arrayref_at_indices (142 samples, 0.17%)datafusion-cli`core::iter::adapters::try_process (142 samples, 0.17%)datafusion-cli`<alloc::vec::Vec<T> as alloc::vec::spec_from_iter::SpecFromIter<T,I>>::from_iter (142 samples, 0.17%)datafusion-cli`<core::iter::adapters::map::Map<I,F> as core::iter::traits::iterator::Iterator>::try_fold (142 samples, 0.17%)datafusion-cli`arrow_select::take::take (142 samples, 0.17%)datafusion-cli`arrow_select::take::take_impl (141 samples, 0.17%)libsystem_platform.dylib`_platform_memmove (77 samples, 0.09%)datafusion-cli`alloc::raw_vec::RawVec<T,A>::grow_one (17 samples, 0.02%)datafusion-cli`alloc::raw_vec::finish_grow (17 samples, 0.02%)libsystem_platform.dylib`_platform_memmove (17 samples, 0.02%)datafusion-cli`datafusion_common::scalar::ScalarValue::size (10 samples, 0.01%)datafusion-cli`<datafusion_common::scalar::ScalarValue as core::convert::TryFrom<&arrow_schema::datatype::DataType>>::try_from (14 samples, 0.02%)datafusion-cli`_mi_free_delayed_block (11 samples, 0.01%)datafusion-cli`mi_find_page (10 samples, 0.01%)datafusion-cli`mi_heap_malloc_zero_aligned_at_generic (28 samples, 0.03%)datafusion-cli`_mi_malloc_generic (28 samples, 0.03%)datafusion-cli`<datafusion_functions_aggregate::min_max::Min as datafusion_expr::udaf::AggregateUDFImpl>::accumulator (44 samples, 0.05%)datafusion-cli`datafusion_functions_aggregate_common::aggregate::groups_accumulator::GroupsAccumulatorAdapter::make_accumulators_if_needed (110 samples, 0.13%)datafusion-cli`datafusion_physical_expr::aggregate::AggregateFunctionExpr::create_accumulator (66 samples, 0.08%)datafusion-cli`_mi_malloc_generic (9 samples, 0.01%)datafusion-cli`mi_heap_malloc_zero_aligned_at_generic (10 samples, 0.01%)datafusion-cli`mi_malloc_aligned (13 samples, 0.02%)datafusion-cli`<arrow_array::array::byte_array::GenericByteArray<T> as arrow_array::array::Array>::slice (68 samples, 0.08%)datafusion-cli`mi_malloc_aligned (10 samples, 0.01%)datafusion-cli`<alloc::vec::Vec<T> as alloc::vec::spec_from_iter::SpecFromIter<T,I>>::from_iter (93 samples, 0.11%)datafusion-cli`datafusion_functions_aggregate_common::aggregate::groups_accumulator::slice_and_maybe_filter (115 samples, 0.14%)datafusion-cli`mi_free_block_delayed_mt (11 samples, 0.01%)datafusion-cli`<datafusion_functions_aggregate_common::aggregate::groups_accumulator::GroupsAccumulatorAdapter as datafusion_expr_common::groups_accumulator::GroupsAccumulator>::merge_batch (865 samples, 1.02%)libsystem_platform.dylib`_platform_memmove (83 samples, 0.10%)datafusion-cli`<hashbrown::raw::inner::RawTable<T> as datafusion_common::utils::proxy::RawTableAllocExt>::insert_accounted (52 samples, 0.06%)datafusion-cli`hashbrown::raw::inner::RawTable<T,A>::reserve_rehash (32 samples, 0.04%)datafusion-cli`<str as datafusion_common::hash_utils::HashValue>::hash_one (61 samples, 0.07%)datafusion-cli`arrow_buffer::buffer::mutable::MutableBuffer::reallocate (69 samples, 0.08%)libsystem_platform.dylib`_platform_memmove (64 samples, 0.08%)datafusion-cli`datafusion_physical_expr_common::binary_map::ArrowBytesMap<O,V>::insert_if_new (299 samples, 0.35%)libsystem_platform.dylib`_platform_memcmp (13 samples, 0.02%)datafusion-cli`<datafusion_physical_plan::aggregates::group_values::bytes::GroupValuesByes<O> as datafusion_physical_plan::aggregates::group_values::GroupValues>::intern (380 samples, 0.45%)libsystem_platform.dylib`_platform_memmove (62 samples, 0.07%)datafusion-cli`<alloc::vec::Vec<T> as alloc::vec::spec_from_iter::SpecFromIter<T,I>>::from_iter (13 samples, 0.02%)datafusion-cli`mi_heap_malloc_zero_aligned_at_generic (13 samples, 0.02%)datafusion-cli`_mi_malloc_generic (13 samples, 0.02%)datafusion-cli`_mi_free_delayed_block (13 samples, 0.02%)datafusion-cli`_mi_page_free (15 samples, 0.02%)datafusion-cli`mi_segment_page_clear (15 samples, 0.02%)datafusion-cli`mi_segment_span_free_coalesce (15 samples, 0.02%)datafusion-cli`mi_segment_span_free (15 samples, 0.02%)datafusion-cli`mi_segment_try_purge (15 samples, 0.02%)datafusion-cli`mi_segment_purge (15 samples, 0.02%)libsystem_kernel.dylib`madvise (15 samples, 0.02%)datafusion-cli`_mi_free_delayed_block (67 samples, 0.08%)datafusion-cli`<alloc::vec::Vec<T> as alloc::vec::spec_from_iter::SpecFromIter<T,I>>::from_iter (70 samples, 0.08%)datafusion-cli`mi_heap_malloc_zero_aligned_at_generic (69 samples, 0.08%)datafusion-cli`_mi_malloc_generic (69 samples, 0.08%)datafusion-cli`arrow_data::transform::variable_size::build_extend::_{{closure}} (13 samples, 0.02%)datafusion-cli`arrow_data::transform::utils::extend_offsets (13 samples, 0.02%)datafusion-cli`arrow_data::transform::MutableArrayData::extend (67 samples, 0.08%)libsystem_platform.dylib`_platform_memmove (54 samples, 0.06%)datafusion-cli`arrow_select::concat::concat (146 samples, 0.17%)datafusion-cli`arrow_select::concat::concat_fallback (144 samples, 0.17%)datafusion-cli`arrow_select::concat::concat_batches (164 samples, 0.19%)datafusion-cli`alloc::sync::Arc<T,A>::drop_slow (10 samples, 0.01%)datafusion-cli`datafusion_physical_plan::coalesce::BatchCoalescer::finish_batch (175 samples, 0.21%)datafusion-cli`core::ptr::drop_in_place<arrow_array::record_batch::RecordBatch> (11 samples, 0.01%)datafusion-cli`<alloc::vec::Vec<T> as alloc::vec::spec_from_iter::SpecFromIter<T,I>>::from_iter (29 samples, 0.03%)datafusion-cli`mi_heap_malloc_zero_aligned_at_generic (29 samples, 0.03%)datafusion-cli`_mi_malloc_generic (29 samples, 0.03%)datafusion-cli`_mi_free_delayed_block (29 samples, 0.03%)datafusion-cli`datafusion_physical_plan::coalesce::BatchCoalescer::push_batch (30 samples, 0.04%)datafusion-cli`<datafusion_physical_plan::coalesce_batches::CoalesceBatchesStream as futures_core::stream::Stream>::poll_next (218 samples, 0.26%)datafusion-cli`datafusion_functions_aggregate_common::aggregate::groups_accumulator::slice_and_maybe_filter (13 samples, 0.02%)datafusion-cli`_mi_free_delayed_block (128 samples, 0.15%)datafusion-cli`mi_heap_malloc_zero_aligned_at_generic (137 samples, 0.16%)datafusion-cli`_mi_malloc_generic (137 samples, 0.16%)datafusion-cli`mi_malloc_aligned (31 samples, 0.04%)datafusion-cli`<alloc::vec::Vec<T,A> as core::clone::Clone>::clone (182 samples, 0.21%)datafusion-cli`<datafusion_functions_aggregate::min_max::MaxAccumulator as datafusion_expr_common::accumulator::Accumulator>::evaluate (357 samples, 0.42%)datafusion-cli`<datafusion_common::scalar::ScalarValue as core::clone::Clone>::clone (353 samples, 0.41%)libsystem_platform.dylib`_platform_memmove (135 samples, 0.16%)datafusion-cli`core::ptr::drop_in_place<datafusion_common::scalar::ScalarValue> (12 samples, 0.01%)datafusion-cli`core::ptr::drop_in_place<datafusion_functions_aggregate_common::aggregate::groups_accumulator::AccumulatorState> (31 samples, 0.04%)datafusion-cli`mi_free_block_mt (11 samples, 0.01%)datafusion-cli`datafusion_common::scalar::ScalarValue::size (16 samples, 0.02%)datafusion-cli`mi_free (12 samples, 0.01%)datafusion-cli`mi_free_block_delayed_mt (53 samples, 0.06%)datafusion-cli`mi_free_block_mt (33 samples, 0.04%)datafusion-cli`<alloc::vec::into_iter::IntoIter<T,A> as core::iter::traits::iterator::Iterator>::try_fold (555 samples, 0.65%)datafusion-cli`mi_free_generic_mt (17 samples, 0.02%)datafusion-cli`alloc::raw_vec::RawVec<T,A>::reserve::do_reserve_and_handle (27 samples, 0.03%)datafusion-cli`alloc::raw_vec::finish_grow (27 samples, 0.03%)libsystem_platform.dylib`_platform_memmove (24 samples, 0.03%)datafusion-cli`mi_free_block_delayed_mt (9 samples, 0.01%)datafusion-cli`core::iter::adapters::try_process (650 samples, 0.76%)datafusion-cli`alloc::vec::in_place_collect::_<impl alloc::vec::spec_from_iter::SpecFromIter<T,I> for alloc::vec::Vec<T>>::from_iter (647 samples, 0.76%)datafusion-cli`<core::iter::adapters::map::Map<I,F> as core::iter::traits::iterator::Iterator>::try_fold (16 samples, 0.02%)datafusion-cli`<alloc::vec::into_iter::IntoIter<T,A> as core::iter::traits::iterator::Iterator>::try_fold (14 samples, 0.02%)datafusion-cli`mi_heap_realloc_zero_aligned_at (9 samples, 0.01%)datafusion-cli`arrow_buffer::buffer::mutable::MutableBuffer::reallocate (43 samples, 0.05%)libsystem_platform.dylib`_platform_memmove (34 samples, 0.04%)datafusion-cli`<arrow_array::array::byte_array::GenericByteArray<T> as core::iter::traits::collect::FromIterator<core::option::Option<Ptr>>>::from_iter (90 samples, 0.11%)datafusion-cli`datafusion_physical_plan::aggregates::row_hash::GroupedHashAggregateStream::set_input_done_and_produce_output (856 samples, 1.01%)datafusion-cli`datafusion_physical_plan::aggregates::row_hash::GroupedHashAggregateStream::emit (856 samples, 1.01%)datafusion-cli`<datafusion_functions_aggregate_common::aggregate::groups_accumulator::GroupsAccumulatorAdapter as datafusion_expr_common::groups_accumulator::GroupsAccumulator>::evaluate (856 samples, 1.01%)datafusion-cli`datafusion_common::scalar::ScalarValue::iter_to_array (206 samples, 0.24%)datafusion-cli`core::iter::adapters::try_process (206 samples, 0.24%)libsystem_platform.dylib`_platform_memmove (104 samples, 0.12%)datafusion-cli`mi_free (20 samples, 0.02%)datafusion-cli`<datafusion_physical_plan::aggregates::row_hash::GroupedHashAggregateStream as futures_core::stream::Stream>::poll_next (2,428 samples, 2.85%)da..datafusion-cli`<datafusion_physical_plan::projection::ProjectionStream as futures_core::stream::Stream>::poll_next (2,437 samples, 2.86%)da..datafusion-cli`<datafusion_physical_plan::coalesce_batches::CoalesceBatchesStream as futures_core::stream::Stream>::poll_next (2,437 samples, 2.86%)da..datafusion-cli`<datafusion_physical_plan::filter::FilterExecStream as futures_core::stream::Stream>::poll_next (2,436 samples, 2.86%)da..datafusion-cli`datafusion_physical_plan::common::spawn_buffered::_{{closure}} (2,439 samples, 2.86%)da..datafusion-cli`<futures_util::stream::try_stream::try_flatten::TryFlatten<St> as futures_core::stream::Stream>::poll_next (2,439 samples, 2.86%)da..datafusion-cli`<S as futures_core::stream::TryStream>::try_poll_next (2,438 samples, 2.86%)da..datafusion-cli`arrow_buffer::buffer::mutable::MutableBuffer::reallocate (47 samples, 0.06%)libsystem_platform.dylib`_platform_memmove (39 samples, 0.05%)datafusion-cli`arrow_select::take::take_bytes (79 samples, 0.09%)datafusion-cli`arrow_select::take::take_primitive (9 samples, 0.01%)datafusion-cli`<core::iter::adapters::map::Map<I,F> as core::iter::traits::iterator::Iterator>::next (261 samples, 0.31%)datafusion-cli`core::iter::adapters::try_process (260 samples, 0.31%)datafusion-cli`<alloc::vec::Vec<T> as alloc::vec::spec_from_iter::SpecFromIter<T,I>>::from_iter (260 samples, 0.31%)datafusion-cli`<core::iter::adapters::map::Map<I,F> as core::iter::traits::iterator::Iterator>::try_fold (260 samples, 0.31%)datafusion-cli`arrow_select::take::take (259 samples, 0.30%)datafusion-cli`arrow_select::take::take_impl (257 samples, 0.30%)libsystem_platform.dylib`_platform_memmove (158 samples, 0.19%)datafusion-cli`<datafusion_functions_aggregate::average::AvgGroupsAccumulator<T,F> as datafusion_expr_common::groups_accumulator::GroupsAccumulator>::update_batch (265 samples, 0.31%)datafusion-cli`datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState::accumulate (263 samples, 0.31%)datafusion-cli`<datafusion_functions_aggregate::count::CountGroupsAccumulator as datafusion_expr_common::groups_accumulator::GroupsAccumulator>::update_batch (122 samples, 0.14%)datafusion-cli`datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices (121 samples, 0.14%)datafusion-cli`<datafusion_functions_aggregate::min_max::MinAccumulator as datafusion_expr_common::accumulator::Accumulator>::update_batch (9 samples, 0.01%)datafusion-cli`mi_find_page (13 samples, 0.02%)datafusion-cli`mi_heap_malloc_zero_aligned_at_generic (20 samples, 0.02%)datafusion-cli`_mi_malloc_generic (20 samples, 0.02%)datafusion-cli`mi_malloc_aligned (22 samples, 0.03%)datafusion-cli`<alloc::vec::Vec<T,A> as core::clone::Clone>::clone (59 samples, 0.07%)datafusion-cli`DYLD-STUB$$memcmp (53 samples, 0.06%)datafusion-cli`arrow_arith::aggregate::min_max_helper (227 samples, 0.27%)datafusion-cli`core::ptr::drop_in_place<datafusion_common::scalar::ScalarValue> (13 samples, 0.02%)datafusion-cli`mi_malloc_aligned (17 samples, 0.02%)libsystem_platform.dylib`_platform_memcmp (395 samples, 0.46%)datafusion-cli`<datafusion_functions_aggregate::min_max::MinAccumulator as datafusion_expr_common::accumulator::Accumulator>::update_batch (913 samples, 1.07%)libsystem_platform.dylib`_platform_memmove (61 samples, 0.07%)datafusion-cli`__rust_dealloc (9 samples, 0.01%)datafusion-cli`_mi_malloc_generic (16 samples, 0.02%)datafusion-cli`mi_find_page (14 samples, 0.02%)datafusion-cli`_mi_heap_realloc_zero (20 samples, 0.02%)datafusion-cli`_mi_page_free (10 samples, 0.01%)datafusion-cli`mi_segment_page_clear (10 samples, 0.01%)datafusion-cli`mi_segment_span_free_coalesce (10 samples, 0.01%)datafusion-cli`mi_segment_span_free (10 samples, 0.01%)datafusion-cli`mi_segment_try_purge (10 samples, 0.01%)datafusion-cli`mi_segment_purge (10 samples, 0.01%)libsystem_kernel.dylib`madvise (10 samples, 0.01%)datafusion-cli`mi_find_page (12 samples, 0.01%)datafusion-cli`mi_heap_malloc_zero_aligned_at_generic (17 samples, 0.02%)datafusion-cli`_mi_malloc_generic (17 samples, 0.02%)datafusion-cli`alloc::raw_vec::RawVec<T,A>::grow_one (77 samples, 0.09%)datafusion-cli`alloc::raw_vec::finish_grow (72 samples, 0.08%)libsystem_platform.dylib`_platform_memmove (16 samples, 0.02%)datafusion-cli`_mi_page_free (24 samples, 0.03%)datafusion-cli`mi_segment_page_clear (24 samples, 0.03%)datafusion-cli`mi_segment_span_free_coalesce (24 samples, 0.03%)datafusion-cli`mi_segment_span_free (23 samples, 0.03%)datafusion-cli`mi_segment_try_purge (23 samples, 0.03%)datafusion-cli`mi_segment_purge (23 samples, 0.03%)libsystem_kernel.dylib`madvise (22 samples, 0.03%)datafusion-cli`_mi_heap_realloc_zero (27 samples, 0.03%)datafusion-cli`_mi_malloc_generic (26 samples, 0.03%)datafusion-cli`mi_find_page (26 samples, 0.03%)datafusion-cli`alloc::raw_vec::RawVec<T,A>::reserve::do_reserve_and_handle (44 samples, 0.05%)datafusion-cli`alloc::raw_vec::finish_grow (44 samples, 0.05%)libsystem_platform.dylib`_platform_memmove (15 samples, 0.02%)datafusion-cli`alloc::raw_vec::finish_grow (10 samples, 0.01%)datafusion-cli`core::ptr::drop_in_place<arrow_array::array::byte_array::GenericByteArray<arrow_array::types::GenericBinaryType<i32>>> (22 samples, 0.03%)datafusion-cli`alloc::sync::Arc<T,A>::drop_slow (35 samples, 0.04%)datafusion-cli`datafusion_common::scalar::ScalarValue::size (71 samples, 0.08%)datafusion-cli`DYLD-STUB$$memcpy (10 samples, 0.01%)datafusion-cli`_mi_malloc_generic (9 samples, 0.01%)datafusion-cli`mi_find_page (9 samples, 0.01%)datafusion-cli`mi_heap_malloc_zero_aligned_at_generic (10 samples, 0.01%)datafusion-cli`mi_heap_realloc_zero_aligned_at (29 samples, 0.03%)datafusion-cli`mi_heap_malloc_zero_aligned_at_overalloc (16 samples, 0.02%)datafusion-cli`_mi_malloc_generic (16 samples, 0.02%)datafusion-cli`mi_page_fresh_alloc (14 samples, 0.02%)datafusion-cli`mi_segments_page_alloc (14 samples, 0.02%)datafusion-cli`arrow_buffer::buffer::mutable::MutableBuffer::reallocate (352 samples, 0.41%)libsystem_platform.dylib`_platform_memmove (322 samples, 0.38%)datafusion-cli`arrow_select::take::take_bytes (567 samples, 0.67%)datafusion-cli`<core::iter::adapters::map::Map<I,F> as core::iter::traits::iterator::Iterator>::try_fold (1,437 samples, 1.69%)datafusion-cli`arrow_select::take::take (1,437 samples, 1.69%)datafusion-cli`arrow_select::take::take_impl (1,436 samples, 1.69%)libsystem_platform.dylib`_platform_memmove (859 samples, 1.01%)datafusion-cli`<alloc::vec::Vec<T> as alloc::vec::spec_from_iter::SpecFromIter<T,I>>::from_iter (1,438 samples, 1.69%)datafusion-cli`datafusion_common::utils::get_arrayref_at_indices (1,439 samples, 1.69%)datafusion-cli`core::iter::adapters::try_process (1,439 samples, 1.69%)datafusion-cli`alloc::raw_vec::RawVec<T,A>::grow_one (11 samples, 0.01%)datafusion-cli`alloc::raw_vec::finish_grow (11 samples, 0.01%)libsystem_platform.dylib`_platform_memmove (10 samples, 0.01%)datafusion-cli`<datafusion_common::scalar::ScalarValue as core::convert::TryFrom<&arrow_schema::datatype::DataType>>::try_from (16 samples, 0.02%)datafusion-cli`<datafusion_common::scalar::ScalarValue as core::convert::TryFrom<&arrow_schema::datatype::DataType>>::try_from (9 samples, 0.01%)datafusion-cli`_mi_page_free (37 samples, 0.04%)datafusion-cli`mi_segment_page_clear (37 samples, 0.04%)datafusion-cli`mi_segment_span_free_coalesce (37 samples, 0.04%)datafusion-cli`mi_segment_span_free (36 samples, 0.04%)datafusion-cli`mi_segment_try_purge (36 samples, 0.04%)datafusion-cli`mi_segment_purge (36 samples, 0.04%)libsystem_kernel.dylib`madvise (36 samples, 0.04%)datafusion-cli`mi_heap_malloc_zero_aligned_at_generic (41 samples, 0.05%)datafusion-cli`_mi_malloc_generic (41 samples, 0.05%)datafusion-cli`mi_find_page (39 samples, 0.05%)datafusion-cli`<datafusion_functions_aggregate::min_max::Min as datafusion_expr::udaf::AggregateUDFImpl>::accumulator (64 samples, 0.08%)datafusion-cli`datafusion_functions_aggregate_common::aggregate::groups_accumulator::GroupsAccumulatorAdapter::make_accumulators_if_needed (135 samples, 0.16%)datafusion-cli`datafusion_physical_expr::aggregate::AggregateFunctionExpr::create_accumulator (89 samples, 0.10%)datafusion-cli`arrow_buffer::buffer::immutable::Buffer::slice_with_length (9 samples, 0.01%)datafusion-cli`arrow_buffer::buffer::scalar::ScalarBuffer<T>::new (10 samples, 0.01%)datafusion-cli`mi_malloc_aligned (27 samples, 0.03%)datafusion-cli`<arrow_array::array::byte_array::GenericByteArray<T> as arrow_array::array::Array>::slice (92 samples, 0.11%)datafusion-cli`mi_malloc_aligned (12 samples, 0.01%)datafusion-cli`<alloc::vec::Vec<T> as alloc::vec::spec_from_iter::SpecFromIter<T,I>>::from_iter (123 samples, 0.14%)datafusion-cli`datafusion_functions_aggregate_common::aggregate::groups_accumulator::slice_and_maybe_filter (142 samples, 0.17%)datafusion-cli`mi_free (20 samples, 0.02%)datafusion-cli`<datafusion_functions_aggregate_common::aggregate::groups_accumulator::GroupsAccumulatorAdapter as datafusion_expr_common::groups_accumulator::GroupsAccumulator>::update_batch (3,755 samples, 4.41%)dataf..libsystem_platform.dylib`_platform_memmove (96 samples, 0.11%)datafusion-cli`DYLD-STUB$$memcmp (19 samples, 0.02%)datafusion-cli`<hashbrown::raw::inner::RawTable<T> as datafusion_common::utils::proxy::RawTableAllocExt>::insert_accounted (65 samples, 0.08%)datafusion-cli`hashbrown::raw::inner::RawTable<T,A>::reserve_rehash (46 samples, 0.05%)datafusion-cli`<str as datafusion_common::hash_utils::HashValue>::hash_one (424 samples, 0.50%)datafusion-cli`mi_heap_realloc_zero_aligned_at (9 samples, 0.01%)datafusion-cli`arrow_buffer::buffer::mutable::MutableBuffer::reallocate (53 samples, 0.06%)libsystem_platform.dylib`_platform_memmove (44 samples, 0.05%)datafusion-cli`datafusion_physical_expr_common::binary_map::ArrowBytesMap<O,V>::insert_if_new (1,298 samples, 1.52%)datafusion-cli`datafusion_common::hash_utils::create_hashes (18 samples, 0.02%)libsystem_platform.dylib`_platform_memcmp (381 samples, 0.45%)datafusion-cli`<datafusion_physical_plan::aggregates::group_values::bytes::GroupValuesByes<O> as datafusion_physical_plan::aggregates::group_values::GroupValues>::intern (1,777 samples, 2.09%)d..libsystem_platform.dylib`_platform_memmove (65 samples, 0.08%)datafusion-cli`<parquet::format::ColumnChunk as parquet::thrift::TSerializable>::read_from_in_protocol (18 samples, 0.02%)datafusion-cli`<parquet::format::FileMetaData as parquet::thrift::TSerializable>::read_from_in_protocol (25 samples, 0.03%)datafusion-cli`<parquet::arrow::async_reader::store::ParquetObjectReader as parquet::arrow::async_reader::AsyncFileReader>::get_metadata::_{{closure}} (43 samples, 0.05%)datafusion-cli`parquet::file::footer::decode_metadata (39 samples, 0.05%)datafusion-cli`<datafusion::datasource::physical_plan::parquet::opener::ParquetOpener as datafusion::datasource::physical_plan::file_stream::FileOpener>::open::_{{closure}} (54 samples, 0.06%)datafusion-cli`alloc::raw_vec::RawVec<T,A>::reserve::do_reserve_and_handle (59 samples, 0.07%)datafusion-cli`alloc::raw_vec::finish_grow (58 samples, 0.07%)libsystem_platform.dylib`_platform_memmove (51 samples, 0.06%)datafusion-cli`parquet::arrow::buffer::offset_buffer::OffsetBuffer<I>::try_push (202 samples, 0.24%)datafusion-cli`parquet::arrow::array_reader::byte_array::ByteArrayDecoderPlain::read (960 samples, 1.13%)libsystem_platform.dylib`_platform_memmove (620 samples, 0.73%)datafusion-cli`parquet::arrow::buffer::offset_buffer::OffsetBuffer<I>::try_push (111 samples, 0.13%)datafusion-cli`parquet::arrow::buffer::offset_buffer::OffsetBuffer<I>::extend_from_dictionary (90 samples, 0.11%)datafusion-cli`alloc::raw_vec::RawVec<T,A>::reserve::do_reserve_and_handle (49 samples, 0.06%)datafusion-cli`alloc::raw_vec::finish_grow (49 samples, 0.06%)libsystem_platform.dylib`_platform_memmove (40 samples, 0.05%)datafusion-cli`parquet::encodings::rle::RleDecoder::get_batch (14 samples, 0.02%)datafusion-cli`parquet::arrow::decoder::dictionary_index::DictIndexDecoder::read (266 samples, 0.31%)libsystem_platform.dylib`_platform_memmove (159 samples, 0.19%)datafusion-cli`parquet::arrow::array_reader::byte_array::ByteArrayDecoderPlain::read (44 samples, 0.05%)libsystem_platform.dylib`_platform_memmove (33 samples, 0.04%)datafusion-cli`<parquet::arrow::array_reader::byte_array::ByteArrayColumnValueDecoder<I> as parquet::column::reader::decoder::ColumnValueDecoder>::set_dict (46 samples, 0.05%)datafusion-cli`snap::decompress::Decoder::decompress (2,824 samples, 3.32%)dat..datafusion-cli`<parquet::compression::snappy_codec::SnappyCodec as parquet::compression::Codec>::decompress (2,966 samples, 3.48%)dat..libsystem_platform.dylib`_platform_memmove (142 samples, 0.17%)datafusion-cli`parquet::file::serialized_reader::decode_page (3,024 samples, 3.55%)data..libsystem_platform.dylib`__bzero (54 samples, 0.06%)libsystem_platform.dylib`__bzero (32 samples, 0.04%)libsystem_platform.dylib`_platform_memmove (43 samples, 0.05%)datafusion-cli`<parquet::file::serialized_reader::SerializedPageReader<R> as parquet::column::page::PageReader>::get_next_page (3,116 samples, 3.66%)data..datafusion-cli`parquet::column::reader::GenericColumnReader<R,D,V>::read_new_page (3,165 samples, 3.72%)data..datafusion-cli`<parquet::arrow::arrow_reader::ParquetRecordBatchReader as core::iter::traits::iterator::Iterator>::next (4,517 samples, 5.31%)datafu..datafusion-cli`<parquet::arrow::array_reader::struct_array::StructArrayReader as parquet::arrow::array_reader::ArrayReader>::read_records (4,509 samples, 5.30%)datafu..datafusion-cli`<parquet::arrow::array_reader::byte_array::ByteArrayReader<I> as parquet::arrow::array_reader::ArrayReader>::read_records (4,508 samples, 5.29%)datafu..datafusion-cli`parquet::arrow::record_reader::GenericRecordReader<V,CV>::read_records (4,508 samples, 5.29%)datafu..datafusion-cli`parquet::column::reader::GenericColumnReader<R,D,V>::read_records (4,508 samples, 5.29%)datafu..datafusion-cli`core::ptr::drop_in_place<parquet::arrow::arrow_reader::ParquetRecordBatchReader> (11 samples, 0.01%)datafusion-cli`core::ptr::drop_in_place<parquet::arrow::array_reader::struct_array::StructArrayReader> (11 samples, 0.01%)datafusion-cli`<alloc::vec::Vec<T,A> as core::ops::drop::Drop>::drop (11 samples, 0.01%)datafusion-cli`<futures_util::stream::stream::map::Map<St,F> as futures_core::stream::Stream>::poll_next (4,535 samples, 5.33%)datafus..datafusion-cli`<futures_util::stream::stream::map::Map<St,F> as futures_core::stream::Stream>::poll_next (4,532 samples, 5.32%)datafus..datafusion-cli`<S as futures_core::stream::TryStream>::try_poll_next (4,532 samples, 5.32%)datafus..datafusion-cli`<datafusion::datasource::physical_plan::file_stream::FileStream<F> as futures_core::stream::Stream>::poll_next (4,596 samples, 5.40%)datafus..datafusion-cli`DYLD-STUB$$memcmp (17 samples, 0.02%)datafusion-cli`arrow_ord::cmp::apply_op (201 samples, 0.24%)datafusion-cli`arrow_ord::cmp::compare_op (253 samples, 0.30%)datafusion-cli`arrow_ord::cmp::compare_op::_{{closure}} (252 samples, 0.30%)libsystem_platform.dylib`_platform_memcmp (33 samples, 0.04%)datafusion-cli`<datafusion_physical_expr::expressions::binary::BinaryExpr as datafusion_physical_expr_common::physical_expr::PhysicalExpr>::evaluate (260 samples, 0.31%)datafusion-cli`datafusion_physical_expr_common::datum::apply_cmp (258 samples, 0.30%)datafusion-cli`<arrow_buffer::util::bit_iterator::BitIndexIterator as core::iter::traits::iterator::Iterator>::next (16 samples, 0.02%)datafusion-cli`<arrow_buffer::util::bit_iterator::BitSliceIterator as core::iter::traits::iterator::Iterator>::next (15 samples, 0.02%)datafusion-cli`DYLD-STUB$$memcpy (12 samples, 0.01%)datafusion-cli`arrow_select::filter::FilterBytes<OffsetSize>::extend_idx (117 samples, 0.14%)datafusion-cli`arrow_buffer::buffer::mutable::MutableBuffer::reallocate (63 samples, 0.07%)libsystem_platform.dylib`_platform_memmove (58 samples, 0.07%)datafusion-cli`mi_heap_realloc_zero_aligned_at (15 samples, 0.02%)datafusion-cli`arrow_select::filter::FilterBytes<OffsetSize>::extend_slices (336 samples, 0.39%)datafusion-cli`arrow_buffer::buffer::mutable::MutableBuffer::reallocate (199 samples, 0.23%)libsystem_platform.dylib`_platform_memmove (182 samples, 0.21%)datafusion-cli`datafusion_physical_plan::filter::filter_and_project (1,154 samples, 1.36%)datafusion-cli`arrow_select::filter::filter_record_batch (892 samples, 1.05%)datafusion-cli`core::iter::adapters::try_process (891 samples, 1.05%)datafusion-cli`<alloc::vec::Vec<T> as alloc::vec::spec_from_iter::SpecFromIter<T,I>>::from_iter (890 samples, 1.05%)datafusion-cli`arrow_select::filter::filter_array (890 samples, 1.05%)datafusion-cli`arrow_select::filter::filter_bytes (889 samples, 1.04%)libsystem_platform.dylib`_platform_memmove (393 samples, 0.46%)datafusion-cli`<datafusion_physical_plan::filter::FilterExecStream as futures_core::stream::Stream>::poll_next (5,753 samples, 6.76%)datafusio..datafusion-cli`arrow_data::transform::variable_size::build_extend::_{{closure}} (62 samples, 0.07%)datafusion-cli`arrow_data::transform::utils::extend_offsets (62 samples, 0.07%)datafusion-cli`arrow_data::transform::MutableArrayData::extend (227 samples, 0.27%)libsystem_platform.dylib`_platform_memmove (165 samples, 0.19%)datafusion-cli`arrow_select::concat::concat_batches (232 samples, 0.27%)datafusion-cli`arrow_select::concat::concat (231 samples, 0.27%)datafusion-cli`arrow_select::concat::concat_fallback (231 samples, 0.27%)datafusion-cli`<datafusion_physical_plan::coalesce_batches::CoalesceBatchesStream as futures_core::stream::Stream>::poll_next (5,987 samples, 7.03%)datafusio..datafusion-cli`datafusion_physical_plan::coalesce::BatchCoalescer::finish_batch (233 samples, 0.27%)datafusion-cli`<core::iter::adapters::map::Map<I,F> as core::iter::traits::iterator::Iterator>::try_fold (1,376 samples, 1.62%)datafusion-cli`<datafusion_physical_expr::expressions::cast::CastExpr as datafusion_physical_expr_common::physical_expr::PhysicalExpr>::evaluate (1,376 samples, 1.62%)datafusion-cli`datafusion_expr_common::columnar_value::ColumnarValue::cast_to (1,376 samples, 1.62%)datafusion-cli`arrow_cast::cast::string::cast_binary_to_string (1,376 samples, 1.62%)datafusion-cli`arrow_array::array::string_array::_<impl arrow_array::array::byte_array::GenericByteArray<arrow_array::types::GenericStringType<OffsetSize>>>::try_from_binary (1,375 samples, 1.61%)datafusion-cli`<arrow_array::types::GenericStringType<O> as arrow_array::types::ByteArrayType>::validate (1,375 samples, 1.61%)datafusion-cli`core::str::converts::from_utf8 (1,316 samples, 1.55%)datafusion-cli`<datafusion_physical_plan::projection::ProjectionStream as futures_core::stream::Stream>::poll_next (7,365 samples, 8.65%)datafusion-c..datafusion-cli`core::iter::adapters::try_process (1,378 samples, 1.62%)datafusion-cli`<alloc::vec::Vec<T> as alloc::vec::spec_from_iter::SpecFromIter<T,I>>::from_iter (1,378 samples, 1.62%)datafusion-cli`core::ops::function::impls::_<impl core::ops::function::FnOnce<A> for &mut F>::call_once (254 samples, 0.30%)datafusion-cli`core::str::count::char_count_general_case (90 samples, 0.11%)datafusion-cli`core::str::count::do_count_chars (1,371 samples, 1.61%)datafusion-cli`<alloc::vec::Vec<T> as alloc::vec::spec_from_iter::SpecFromIter<T,I>>::from_iter (1,878 samples, 2.21%)d..datafusion-cli`core::ops::function::impls::_<impl core::ops::function::FnOnce<A> for &mut F>::call_once (41 samples, 0.05%)datafusion-cli`<arrow_buffer::buffer::immutable::Buffer as core::iter::traits::collect::FromIterator<T>>::from_iter (1,929 samples, 2.27%)d..datafusion-cli`<arrow_array::array::primitive_array::PrimitiveArray<T> as core::iter::traits::collect::FromIterator<Ptr>>::from_iter (1,932 samples, 2.27%)d..datafusion-cli`<datafusion_functions::unicode::character_length::CharacterLengthFunc as datafusion_expr::udf::ScalarUDFImpl>::invoke (1,935 samples, 2.27%)d..datafusion-cli`datafusion_functions::utils::make_scalar_function::_{{closure}} (1,935 samples, 2.27%)d..datafusion-cli`<datafusion_physical_expr::scalar_function::ScalarFunctionExpr as datafusion_physical_expr_common::physical_expr::PhysicalExpr>::evaluate (1,937 samples, 2.27%)d..datafusion-cli`arrow_array::array::primitive_array::PrimitiveArray<T>::try_unary (24 samples, 0.03%)datafusion-cli`mi_heap_malloc_zero_aligned_at_generic (18 samples, 0.02%)datafusion-cli`_mi_malloc_generic (18 samples, 0.02%)datafusion-cli`mi_find_page (16 samples, 0.02%)datafusion-cli`_mi_page_free (15 samples, 0.02%)datafusion-cli`mi_segment_page_clear (15 samples, 0.02%)datafusion-cli`mi_segment_span_free_coalesce (15 samples, 0.02%)datafusion-cli`mi_segment_span_free (15 samples, 0.02%)datafusion-cli`mi_segment_try_purge (15 samples, 0.02%)datafusion-cli`mi_segment_purge (15 samples, 0.02%)libsystem_kernel.dylib`madvise (15 samples, 0.02%)datafusion-cli`arrow_cast::cast::cast_numeric_arrays (36 samples, 0.04%)libsystem_platform.dylib`__bzero (11 samples, 0.01%)datafusion-cli`<datafusion_physical_expr::expressions::cast::CastExpr as datafusion_physical_expr_common::physical_expr::PhysicalExpr>::evaluate (1,979 samples, 2.32%)d..datafusion-cli`datafusion_expr_common::columnar_value::ColumnarValue::cast_to (40 samples, 0.05%)datafusion-cli`<core::iter::adapters::map::Map<I,F> as core::iter::traits::iterator::Iterator>::try_fold (1,994 samples, 2.34%)d..datafusion-cli`datafusion_expr_common::columnar_value::ColumnarValue::into_array (15 samples, 0.02%)datafusion-cli`datafusion_common::scalar::ScalarValue::to_array_of_size (15 samples, 0.02%)datafusion-cli`arrow_array::array::primitive_array::PrimitiveArray<T>::from_value (15 samples, 0.02%)datafusion-cli`<alloc::vec::Vec<T> as alloc::vec::spec_from_iter::SpecFromIter<T,I>>::from_iter (1,995 samples, 2.34%)d..datafusion-cli`core::iter::adapters::try_process (1,995 samples, 2.34%)d..datafusion-cli`<alloc::vec::Vec<T> as alloc::vec::spec_from_iter::SpecFromIter<T,I>>::from_iter (1,995 samples, 2.34%)d..datafusion-cli`core::iter::adapters::try_process (1,996 samples, 2.34%)d..datafusion-cli`core::ptr::drop_in_place<arrow_array::array::byte_array::GenericByteArray<arrow_array::types::GenericStringType<i32>>> (9 samples, 0.01%)datafusion-cli`alloc::sync::Arc<T,A>::drop_slow (9 samples, 0.01%)datafusion-cli`core::ptr::drop_in_place<alloc::vec::Vec<alloc::sync::Arc<dyn arrow_array::array::Array>>> (16 samples, 0.02%)datafusion-cli`alloc::sync::Arc<T,A>::drop_slow (14 samples, 0.02%)datafusion-cli`datafusion_common::scalar::ScalarValue::size (17 samples, 0.02%)datafusion-cli`datafusion_functions_aggregate_common::aggregate::groups_accumulator::slice_and_maybe_filter (19 samples, 0.02%)datafusion-cli`DYLD-STUB$$memcpy (28 samples, 0.03%)datafusion-cli`__rust_dealloc (11 samples, 0.01%)datafusion-cli`<core::iter::adapters::map::Map<I,F> as core::iter::traits::iterator::Iterator>::try_fold (132 samples, 0.16%)datafusion-cli`arrow_data::data::ArrayDataBuilder::build (452 samples, 0.53%)datafusion-cli`arrow_data::data::ArrayData::validate_values (452 samples, 0.53%)datafusion-cli`core::str::converts::from_utf8 (319 samples, 0.37%)datafusion-cli`<&str as regex::regex::string::Replacer>::no_expansion (50 samples, 0.06%)datafusion-cli`<core::iter::adapters::enumerate::Enumerate<I> as core::iter::traits::iterator::Iterator>::next (112 samples, 0.13%)datafusion-cli`DYLD-STUB$$memcpy (28 samples, 0.03%)datafusion-cli`core::ptr::drop_in_place<core::iter::adapters::peekable::Peekable<core::iter::adapters::enumerate::Enumerate<regex::regex::string::CaptureMatches>>> (25 samples, 0.03%)datafusion-cli`mi_free (112 samples, 0.13%)datafusion-cli`<&str as regex::regex::string::Replacer>::no_expansion (24 samples, 0.03%)datafusion-cli`regex_automata::util::determinize::next (16 samples, 0.02%)datafusion-cli`regex_automata::hybrid::dfa::Lazy::cache_next_state (37 samples, 0.04%)datafusion-cli`regex_automata::hybrid::regex::Regex::try_search (7,837 samples, 9.20%)datafusion-cl..datafusion-cli`regex_automata::hybrid::search::find_fwd (7,526 samples, 8.84%)datafusion-c..datafusion-cli`regex_automata::hybrid::search::find_fwd (93 samples, 0.11%)datafusion-cli`regex_automata::nfa::thompson::backtrack::BoundedBacktracker::search_imp (76 samples, 0.09%)datafusion-cli`DYLD-STUB$$bzero (46 samples, 0.05%)datafusion-cli`regex_automata::nfa::thompson::backtrack::BoundedBacktracker::search_imp (42,705 samples, 50.16%)datafusion-cli`regex_automata::nfa::thompson::backtrack::BoundedBacktracker::search..libsystem_platform.dylib`__bzero (54 samples, 0.06%)datafusion-cli`regex_automata::nfa::thompson::backtrack::BoundedBacktracker::try_search_slots (43,651 samples, 51.27%)datafusion-cli`regex_automata::nfa::thompson::backtrack::BoundedBacktracker::try_sea..datafusion-cli`regex_automata::nfa::thompson::backtrack::BoundedBacktracker::try_search_slots_imp (43,543 samples, 51.14%)datafusion-cli`regex_automata::nfa::thompson::backtrack::BoundedBacktracker::try_sea..libsystem_platform.dylib`_platform_memset (632 samples, 0.74%)datafusion-cli`regex_automata::meta::strategy::Core::search_slots_nofail (43,848 samples, 51.50%)datafusion-cli`regex_automata::meta::strategy::Core::search_slots_nofaildatafusion-cli`regex_automata::nfa::thompson::backtrack::BoundedBacktracker::try_search_slots_imp (51 samples, 0.06%)datafusion-cli`<regex_automata::meta::strategy::Core as regex_automata::meta::strategy::Strategy>::search_slots (51,998 samples, 61.07%)datafusion-cli`<regex_automata::meta::strategy::Core as regex_automata::meta::strategy::Strategy>::se..datafusion-cli`regex_automata::nfa::thompson::backtrack::BoundedBacktracker::try_search_slots (53 samples, 0.06%)datafusion-cli`_mi_malloc_generic (12 samples, 0.01%)datafusion-cli`mi_heap_malloc_zero_aligned_at_generic (16 samples, 0.02%)datafusion-cli`mi_malloc_aligned (108 samples, 0.13%)datafusion-cli`regex_automata::hybrid::regex::Regex::try_search (59 samples, 0.07%)datafusion-cli`regex_automata::meta::strategy::Core::search_slots_nofail (24 samples, 0.03%)datafusion-cli`<core::iter::adapters::enumerate::Enumerate<I> as core::iter::traits::iterator::Iterator>::next (52,763 samples, 61.97%)datafusion-cli`<core::iter::adapters::enumerate::Enumerate<I> as core::iter::traits::iterator::Iterator..libdyld.dylib`tlv_get_addr (32 samples, 0.04%)datafusion-cli`<regex_automata::meta::strategy::Core as regex_automata::meta::strategy::Strategy>::search_slots (57 samples, 0.07%)datafusion-cli`<regex_automata::meta::strategy::ReverseAnchored as regex_automata::meta::strategy::Strategy>::group_info (28 samples, 0.03%)datafusion-cli`DYLD-STUB$$bzero (23 samples, 0.03%)datafusion-cli`DYLD-STUB$$memcpy (14 samples, 0.02%)datafusion-cli`core::ptr::drop_in_place<core::iter::adapters::peekable::Peekable<core::iter::adapters::enumerate::Enumerate<regex::regex::string::CaptureMatches>>> (158 samples, 0.19%)datafusion-cli`core::ptr::drop_in_place<regex_automata::util::pool::PoolGuard<regex_automata::meta::regex::Cache,alloc::boxed::Box<dyn core::ops::function::Fn<()>+Output = regex_automata::meta::regex::Cache+core::panic::unwind_safe::UnwindSafe+core::marker::Sync+core::panic::unwind_safe::RefUnwindSafe+core::marker::Send>>> (101 samples, 0.12%)datafusion-cli`core::ptr::drop_in_place<regex_automata::util::pool::PoolGuard<regex_automata::meta::regex::Cache,alloc::boxed::Box<dyn core::ops::function::Fn<()>+Output = regex_automata::meta::regex::Cache+core::panic::unwind_safe::UnwindSafe+core::marker::Sync+core::panic::unwind_safe::RefUnwindSafe+core::marker::Send>>> (92 samples, 0.11%)datafusion-cli`mi_free (95 samples, 0.11%)datafusion-cli`_mi_free_delayed_block (34 samples, 0.04%)datafusion-cli`_mi_malloc_generic (62 samples, 0.07%)datafusion-cli`mi_find_page (19 samples, 0.02%)datafusion-cli`mi_heap_malloc_zero_aligned_at_generic (77 samples, 0.09%)datafusion-cli`mi_malloc_aligned (125 samples, 0.15%)datafusion-cli`regex::find_byte::find_byte (87 samples, 0.10%)datafusion-cli`regex_automata::meta::regex::Regex::create_captures (56 samples, 0.07%)datafusion-cli`_mi_malloc_generic (18 samples, 0.02%)datafusion-cli`mi_find_page (9 samples, 0.01%)datafusion-cli`mi_heap_malloc_zero_aligned_at_generic (21 samples, 0.02%)datafusion-cli`mi_malloc_aligned (135 samples, 0.16%)datafusion-cli`regex_automata::util::captures::Captures::all (412 samples, 0.48%)libdyld.dylib`tlv_get_addr (26 samples, 0.03%)datafusion-cli`DYLD-STUB$$memcpy (55 samples, 0.06%)datafusion-cli`regex_automata::util::captures::Captures::interpolate_string_into::_{{closure}} (48 samples, 0.06%)datafusion-cli`regex_automata::util::interpolate::find_cap_ref (74 samples, 0.09%)datafusion-cli`DYLD-STUB$$memcpy (13 samples, 0.02%)datafusion-cli`core::num::_<impl core::str::traits::FromStr for usize>::from_str (32 samples, 0.04%)datafusion-cli`core::str::converts::from_utf8 (15 samples, 0.02%)datafusion-cli`regex_automata::util::captures::Captures::interpolate_string_into::_{{closure}} (154 samples, 0.18%)datafusion-cli`core::num::_<impl core::str::traits::FromStr for usize>::from_str (95 samples, 0.11%)datafusion-cli`regex_automata::util::interpolate::find_cap_ref (302 samples, 0.35%)datafusion-cli`core::str::converts::from_utf8 (94 samples, 0.11%)datafusion-cli`regex_automata::util::interpolate::string (1,197 samples, 1.41%)libsystem_platform.dylib`_platform_memmove (479 samples, 0.56%)datafusion-cli`regex_automata::util::captures::Captures::interpolate_string_into (1,508 samples, 1.77%)d..libsystem_platform.dylib`_platform_memmove (92 samples, 0.11%)datafusion-cli`regex_automata::util::interpolate::string (62 samples, 0.07%)datafusion-cli`regex_automata::hybrid::dfa::Cache::new (16 samples, 0.02%)datafusion-cli`regex_automata::hybrid::dfa::Lazy::init_cache (15 samples, 0.02%)datafusion-cli`regex_automata::util::pool::inner::Pool<T,F>::get_slow (18 samples, 0.02%)datafusion-cli`<regex_automata::meta::strategy::Core as regex_automata::meta::strategy::Strategy>::create_cache (18 samples, 0.02%)libdyld.dylib`tlv_get_addr (37 samples, 0.04%)libsystem_platform.dylib`__bzero (14 samples, 0.02%)libsystem_platform.dylib`_platform_memmove (215 samples, 0.25%)datafusion-cli`regex::regex::string::Regex::replacen (57,245 samples, 67.23%)datafusion-cli`regex::regex::string::Regex::replacenlibsystem_platform.dylib`_platform_memset (154 samples, 0.18%)datafusion-cli`regex_automata::util::captures::Captures::all (116 samples, 0.14%)datafusion-cli`regex_automata::util::captures::Captures::interpolate_string_into (40 samples, 0.05%)libdyld.dylib`tlv_get_addr (41 samples, 0.05%)datafusion-cli`core::iter::traits::iterator::Iterator::fold (58,157 samples, 68.30%)datafusion-cli`core::iter::traits::iterator::Iterator::foldlibsystem_platform.dylib`_platform_memmove (114 samples, 0.13%)datafusion-cli`_mi_page_free (9 samples, 0.01%)datafusion-cli`mi_segment_page_clear (9 samples, 0.01%)datafusion-cli`_mi_page_free (10 samples, 0.01%)datafusion-cli`mi_segment_page_clear (10 samples, 0.01%)datafusion-cli`mi_segment_span_free_coalesce (10 samples, 0.01%)datafusion-cli`mi_segment_span_free (10 samples, 0.01%)datafusion-cli`mi_segment_try_purge (10 samples, 0.01%)datafusion-cli`mi_segment_purge (10 samples, 0.01%)libsystem_kernel.dylib`madvise (10 samples, 0.01%)datafusion-cli`core::ptr::drop_in_place<regex_automata::hybrid::dfa::Cache> (17 samples, 0.02%)datafusion-cli`core::ptr::drop_in_place<regex_automata::meta::regex::Cache> (28 samples, 0.03%)datafusion-cli`core::ptr::drop_in_place<regex_automata::util::pool::Pool<regex_automata::meta::regex::Cache,alloc::boxed::Box<dyn core::ops::function::Fn<()>+Output = regex_automata::meta::regex::Cache+core::panic::unwind_safe::UnwindSafe+core::marker::Sync+core::panic::unwind_safe::RefUnwindSafe+core::marker::Send>>> (36 samples, 0.04%)datafusion-cli`core::ptr::drop_in_place<regex::regex::string::Regex> (38 samples, 0.04%)datafusion-cli`mi_free (128 samples, 0.15%)datafusion-cli`core::ptr::drop_in_place<regex_automata::nfa::thompson::compiler::Compiler> (25 samples, 0.03%)datafusion-cli`core::ptr::drop_in_place<core::cell::RefCell<regex_automata::nfa::thompson::compiler::Utf8State>> (24 samples, 0.03%)datafusion-cli`regex_automata::meta::wrappers::Hybrid::new (11 samples, 0.01%)datafusion-cli`regex_automata::dfa::onepass::Builder::build_from_nfa (10 samples, 0.01%)datafusion-cli`regex_automata::meta::wrappers::OnePass::new (14 samples, 0.02%)datafusion-cli`regex_automata::nfa::thompson::compiler::Compiler::c (19 samples, 0.02%)datafusion-cli`regex_automata::nfa::thompson::compiler::Compiler::c (22 samples, 0.03%)datafusion-cli`regex_automata::nfa::thompson::compiler::Compiler::c_at_least (21 samples, 0.02%)datafusion-cli`regex_automata::nfa::thompson::compiler::Compiler::c_at_least (12 samples, 0.01%)datafusion-cli`DYLD-STUB$$memcpy (32 samples, 0.04%)datafusion-cli`alloc::vec::Vec<T,A>::extend_with (27 samples, 0.03%)datafusion-cli`regex_automata::nfa::thompson::compiler::Utf8Compiler::new (128 samples, 0.15%)datafusion-cli`regex_automata::nfa::thompson::map::Utf8BoundedMap::clear (127 samples, 0.15%)datafusion-cli`<T as alloc::vec::spec_from_elem::SpecFromElem>::from_elem (127 samples, 0.15%)libsystem_platform.dylib`_platform_memmove (67 samples, 0.08%)datafusion-cli`<core::iter::adapters::map::Map<I,F> as core::iter::traits::iterator::Iterator>::next (180 samples, 0.21%)datafusion-cli`regex_automata::nfa::thompson::compiler::Compiler::c_cap (158 samples, 0.19%)datafusion-cli`regex_automata::nfa::thompson::compiler::Compiler::c (156 samples, 0.18%)datafusion-cli`regex_automata::nfa::thompson::compiler::Compiler::c_cap (140 samples, 0.16%)datafusion-cli`regex_automata::nfa::thompson::compiler::Compiler::c_at_least (139 samples, 0.16%)datafusion-cli`regex_automata::nfa::thompson::compiler::Compiler::c (138 samples, 0.16%)datafusion-cli`regex_automata::nfa::thompson::nfa::Inner::into_nfa (9 samples, 0.01%)datafusion-cli`regex_automata::nfa::thompson::builder::Builder::build (23 samples, 0.03%)datafusion-cli`regex_automata::nfa::thompson::compiler::Compiler::compile (206 samples, 0.24%)datafusion-cli`regex_automata::meta::strategy::new (259 samples, 0.30%)datafusion-cli`regex_syntax::ast::parse::ParserI<P>::parse_with_comments (26 samples, 0.03%)datafusion-cli`regex_syntax::ast::parse::Parser::parse (28 samples, 0.03%)datafusion-cli`<regex_syntax::hir::translate::TranslatorI as regex_syntax::ast::visitor::Visitor>::visit_post (18 samples, 0.02%)datafusion-cli`regex_syntax::hir::translate::Translator::translate (23 samples, 0.03%)datafusion-cli`regex_syntax::ast::visitor::visit (23 samples, 0.03%)datafusion-cli`regex::regex::string::Regex::new (322 samples, 0.38%)datafusion-cli`regex::builders::Builder::build_one_string (322 samples, 0.38%)datafusion-cli`regex_automata::meta::regex::Builder::build (321 samples, 0.38%)datafusion-cli`regex::regex::string::Regex::replacen (75 samples, 0.09%)datafusion-cli`datafusion_functions::regex::regexpreplace::regexp_replace_func (59,574 samples, 69.97%)datafusion-cli`datafusion_functions::regex::regexpreplace::regexp_replace_funclibsystem_platform.dylib`_platform_memmove (344 samples, 0.40%)datafusion-cli`<datafusion_functions::regex::regexpreplace::RegexpReplaceFunc as datafusion_expr::udf::ScalarUDFImpl>::invoke (59,576 samples, 69.97%)datafusion-cli`<datafusion_functions::regex::regexpreplace::RegexpReplaceFunc as datafusion_expr::udf::ScalarUDFImpl..datafusion-cli`<datafusion_physical_expr::scalar_function::ScalarFunctionExpr as datafusion_physical_expr_common::physical_expr::PhysicalExpr>::evaluate (59,578 samples, 69.97%)datafusion-cli`<datafusion_physical_expr::scalar_function::ScalarFunctionExpr as datafusion_physical_expr_common::ph..datafusion-cli`datafusion_physical_plan::aggregates::evaluate_group_by (59,583 samples, 69.98%)datafusion-cli`datafusion_physical_plan::aggregates::evaluate_group_bydatafusion-cli`core::iter::adapters::try_process (59,582 samples, 69.98%)datafusion-cli`core::iter::adapters::try_processdatafusion-cli`<alloc::vec::Vec<T> as alloc::vec::spec_from_iter::SpecFromIter<T,I>>::from_iter (59,582 samples, 69.98%)datafusion-cli`<alloc::vec::Vec<T> as alloc::vec::spec_from_iter::SpecFromIter<T,I>>::from_iterdatafusion-cli`<core::iter::adapters::map::Map<I,F> as core::iter::traits::iterator::Iterator>::try_fold (59,582 samples, 69.98%)datafusion-cli`<core::iter::adapters::map::Map<I,F> as core::iter::traits::iterator::Iterator>::try_folddatafusion-cli`<alloc::vec::Vec<T,A> as core::clone::Clone>::clone (11 samples, 0.01%)datafusion-cli`mi_heap_malloc_zero_aligned_at_generic (12 samples, 0.01%)datafusion-cli`_mi_malloc_generic (12 samples, 0.01%)datafusion-cli`mi_page_free_list_extend (9 samples, 0.01%)datafusion-cli`mi_malloc_aligned (13 samples, 0.02%)datafusion-cli`<alloc::vec::Vec<T,A> as core::clone::Clone>::clone (37 samples, 0.04%)datafusion-cli`<datafusion_common::scalar::ScalarValue as core::clone::Clone>::clone (103 samples, 0.12%)libsystem_platform.dylib`_platform_memmove (54 samples, 0.06%)datafusion-cli`mi_malloc_aligned (18 samples, 0.02%)datafusion-cli`<datafusion_functions_aggregate::min_max::MaxAccumulator as datafusion_expr_common::accumulator::Accumulator>::state (140 samples, 0.16%)datafusion-cli`alloc::raw_vec::RawVec<T,A>::grow_one (13 samples, 0.02%)datafusion-cli`alloc::raw_vec::finish_grow (13 samples, 0.02%)libsystem_platform.dylib`_platform_memmove (11 samples, 0.01%)datafusion-cli`<core::iter::adapters::map::Map<I,F> as core::iter::traits::iterator::Iterator>::try_fold (15 samples, 0.02%)datafusion-cli`<alloc::vec::into_iter::IntoIter<T,A> as core::iter::traits::iterator::Iterator>::try_fold (14 samples, 0.02%)datafusion-cli`arrow_buffer::buffer::mutable::MutableBuffer::reallocate (54 samples, 0.06%)libsystem_platform.dylib`_platform_memmove (49 samples, 0.06%)datafusion-cli`<arrow_array::array::byte_array::GenericByteArray<T> as core::iter::traits::collect::FromIterator<core::option::Option<Ptr>>>::from_iter (87 samples, 0.10%)datafusion-cli`mi_free (9 samples, 0.01%)datafusion-cli`core::iter::adapters::try_process (202 samples, 0.24%)datafusion-cli`alloc::vec::in_place_collect::from_iter_in_place (202 samples, 0.24%)datafusion-cli`<alloc::vec::into_iter::IntoIter<T,A> as core::iter::traits::iterator::Iterator>::try_fold (202 samples, 0.24%)datafusion-cli`datafusion_common::scalar::ScalarValue::iter_to_array (202 samples, 0.24%)datafusion-cli`core::iter::adapters::try_process (202 samples, 0.24%)libsystem_platform.dylib`_platform_memmove (103 samples, 0.12%)datafusion-cli`core::ptr::drop_in_place<datafusion_common::scalar::ScalarValue> (10 samples, 0.01%)datafusion-cli`core::ptr::drop_in_place<datafusion_functions_aggregate_common::aggregate::groups_accumulator::AccumulatorState> (21 samples, 0.02%)datafusion-cli`mi_free (11 samples, 0.01%)datafusion-cli`mi_free_block_delayed_mt (24 samples, 0.03%)datafusion-cli`mi_free_block_mt (9 samples, 0.01%)datafusion-cli`<datafusion_functions_aggregate_common::aggregate::groups_accumulator::GroupsAccumulatorAdapter as datafusion_expr_common::groups_accumulator::GroupsAccumulator>::state (492 samples, 0.58%)datafusion-cli`mi_free_generic_mt (9 samples, 0.01%)datafusion-cli`core::ptr::drop_in_place<[alloc::vec::Vec<datafusion_common::scalar::ScalarValue>]> (12 samples, 0.01%)datafusion-cli`mi_free (20 samples, 0.02%)datafusion-cli`datafusion_physical_plan::aggregates::row_hash::GroupedHashAggregateStream::set_input_done_and_produce_output (546 samples, 0.64%)datafusion-cli`datafusion_physical_plan::aggregates::row_hash::GroupedHashAggregateStream::emit (546 samples, 0.64%)datafusion-cli`mi_free (16 samples, 0.02%)libsystem_platform.dylib`_platform_memmove (43 samples, 0.05%)datafusion-cli`<datafusion_physical_plan::aggregates::row_hash::GroupedHashAggregateStream as futures_core::stream::Stream>::poll_next (75,562 samples, 88.74%)datafusion-cli`<datafusion_physical_plan::aggregates::row_hash::GroupedHashAggregateStream as futures_core::stream::Stream>::poll_nextdatafusion-cli`<datafusion_physical_plan::repartition::distributor_channels::SendFuture<T> as core::future::future::Future>::poll (11 samples, 0.01%)datafusion-cli`tokio::runtime::task::waker::wake_by_val (9 samples, 0.01%)datafusion-cli`tokio::runtime::scheduler::multi_thread::worker::_<impl tokio::runtime::task::Schedule for alloc::sync::Arc<tokio::runtime::scheduler::multi_thread::handle::Handle>>::schedule (9 samples, 0.01%)datafusion-cli`tokio::runtime::context::with_scheduler (9 samples, 0.01%)datafusion-cli`mi_free_block_mt (9 samples, 0.01%)datafusion-cli`_mi_os_reset (9 samples, 0.01%)libsystem_kernel.dylib`madvise (9 samples, 0.01%)datafusion-cli`core::ptr::drop_in_place<arrow_array::record_batch::RecordBatch> (13 samples, 0.02%)datafusion-cli`alloc::sync::Arc<T,A>::drop_slow (13 samples, 0.02%)datafusion-cli`core::ptr::drop_in_place<arrow_array::array::byte_array::GenericByteArray<arrow_array::types::GenericBinaryType<i32>>> (13 samples, 0.02%)datafusion-cli`alloc::sync::Arc<T,A>::drop_slow (13 samples, 0.02%)datafusion-cli`datafusion_physical_plan::repartition::BatchPartitioner::partition_iter (57 samples, 0.07%)datafusion-cli`<str as datafusion_common::hash_utils::HashValue>::hash_one (49 samples, 0.06%)datafusion-cli`datafusion_physical_plan::repartition::RepartitionExec::pull_from_input::_{{closure}} (75,911 samples, 89.15%)datafusion-cli`datafusion_physical_plan::repartition::RepartitionExec::pull_from_input::_{{closure}}datafusion-cli`tokio::runtime::scheduler::multi_thread::worker::Context::run (78,392 samples, 92.07%)datafusion-cli`tokio::runtime::scheduler::multi_thread::worker::Context::rundatafusion-cli`tokio::runtime::scheduler::multi_thread::worker::Context::run_task (78,358 samples, 92.03%)datafusion-cli`tokio::runtime::scheduler::multi_thread::worker::Context::run_taskdatafusion-cli`tokio::runtime::task::harness::Harness<T,S>::poll (78,351 samples, 92.02%)datafusion-cli`tokio::runtime::task::harness::Harness<T,S>::polldatafusion-cli`tokio::runtime::task::harness::Harness<T,S>::poll (78,393 samples, 92.07%)datafusion-cli`tokio::runtime::task::harness::Harness<T,S>::polldatafusion-cli`tokio::runtime::task::core::Core<T,S>::poll (78,393 samples, 92.07%)datafusion-cli`tokio::runtime::task::core::Core<T,S>::polldatafusion-cli`<tokio::runtime::blocking::task::BlockingTask<T> as core::future::future::Future>::poll (78,393 samples, 92.07%)datafusion-cli`<tokio::runtime::blocking::task::BlockingTask<T> as core::future::future::Future>::polldatafusion-cli`tokio::runtime::scheduler::multi_thread::worker::run (78,393 samples, 92.07%)datafusion-cli`tokio::runtime::scheduler::multi_thread::worker::rundatafusion-cli`tokio::runtime::context::runtime::enter_runtime (78,393 samples, 92.07%)datafusion-cli`tokio::runtime::context::runtime::enter_runtimedatafusion-cli`_mi_page_free (30 samples, 0.04%)datafusion-cli`mi_segment_page_clear (30 samples, 0.04%)datafusion-cli`mi_segment_span_free_coalesce (30 samples, 0.04%)datafusion-cli`mi_segment_span_free (30 samples, 0.04%)datafusion-cli`mi_segment_try_purge (30 samples, 0.04%)datafusion-cli`mi_segment_purge (30 samples, 0.04%)libsystem_kernel.dylib`madvise (30 samples, 0.04%)datafusion-cli`_mi_free_delayed_block (31 samples, 0.04%)datafusion-cli`mi_heap_malloc_zero_aligned_at_generic (35 samples, 0.04%)datafusion-cli`_mi_malloc_generic (35 samples, 0.04%)datafusion-cli`core::iter::adapters::try_process (399 samples, 0.47%)datafusion-cli`alloc::vec::in_place_collect::_<impl alloc::vec::spec_from_iter::SpecFromIter<T,I> for alloc::vec::Vec<T>>::from_iter (399 samples, 0.47%)datafusion-cli`<alloc::vec::into_iter::IntoIter<T,A> as core::iter::traits::iterator::Iterator>::try_fold (399 samples, 0.47%)datafusion-cli`object_store::local::read_range (399 samples, 0.47%)datafusion-cli`std::io::default_read_to_end (363 samples, 0.43%)libsystem_kernel.dylib`read (363 samples, 0.43%)datafusion-cli`std::fs::OpenOptions::_open (47 samples, 0.06%)datafusion-cli`std::sys::pal::unix::fs::File::open_c (47 samples, 0.06%)libsystem_kernel.dylib`__open (47 samples, 0.06%)datafusion-cli`object_store::local::open_file (54 samples, 0.06%)datafusion-cli`<tokio::runtime::blocking::task::BlockingTask<T> as core::future::future::Future>::poll (470 samples, 0.55%)datafusion-cli`object_store::local::read_range (12 samples, 0.01%)datafusion-cli`tokio::runtime::blocking::pool::Inner::run (78,880 samples, 92.64%)datafusion-cli`tokio::runtime::blocking::pool::Inner::rundatafusion-cli`tokio::runtime::task::raw::poll (477 samples, 0.56%)datafusion-cli`std::sys::backtrace::__rust_begin_short_backtrace (78,881 samples, 92.64%)datafusion-cli`std::sys::backtrace::__rust_begin_short_backtracedatafusion-cli`std::sys::pal::unix::thread::Thread::new::thread_start (78,882 samples, 92.64%)datafusion-cli`std::sys::pal::unix::thread::Thread::new::thread_startdatafusion-cli`core::ops::function::FnOnce::call_once{{vtable.shim}} (78,882 samples, 92.64%)datafusion-cli`core::ops::function::FnOnce::call_once{{vtable.shim}}datafusion-cli`mi_heap_collect_ex (13 samples, 0.02%)datafusion-cli`_mi_free_delayed_block (11 samples, 0.01%)all (85,146 samples, 100%)libsystem_pthread.dylib`thread_start (78,906 samples, 92.67%)libsystem_pthread.dylib`thread_startlibsystem_pthread.dylib`_pthread_start (78,906 samples, 92.67%)libsystem_pthread.dylib`_pthread_startlibsystem_pthread.dylib`_pthread_exit (22 samples, 0.03%)libsystem_pthread.dylib`_pthread_tsd_cleanup (14 samples, 0.02%)datafusion-cli`_mi_thread_done (14 samples, 0.02%) \ No newline at end of file diff --git a/docs/source/contributor-guide/architecture.md b/docs/source/contributor-guide/architecture.md index 68541f877768..1a094968a274 100644 --- a/docs/source/contributor-guide/architecture.md +++ b/docs/source/contributor-guide/architecture.md @@ -25,3 +25,67 @@ possible. You can find the most up to date version in the [source code]. [crates.io documentation]: https://docs.rs/datafusion/latest/datafusion/index.html#architecture [source code]: https://github.com/apache/datafusion/blob/main/datafusion/core/src/lib.rs + +## Forks vs Extension APIs + +DataFusion is a fast moving project, which results in frequent internal changes. +This benefits DataFusion by allowing it to evolve and respond quickly to +requests, but also means that maintaining a fork with major modifications +sometimes requires non trivial work. + +The public API (what is accessible if you use the DataFusion releases from +crates.io) is typically much more stable (though it does change from release to +release as well). + +Thus, rather than forks, we recommend using one of the many extension APIs (such +as `TableProvider`, `OptimizerRule`, or `ExecutionPlan`) to customize +DataFusion. If you can not do what you want with the existing APIs, we would +welcome you working with us to add new APIs to enable your use case, as +described in the next section. + +Please see the [Extensions] section to find out more about existing DataFusion +extensions and how to contribute your extension to the community. + +[extensions]: ../library-user-guide/extensions.md + +## Creating new Extension APIs + +DataFusion aims to be a general-purpose query engine, and thus the core crates +contain features that are useful for a wide range of use cases. Use case specific +functionality (such as very specific time series or stream processing features) +are typically implemented using the extension APIs. + +If have a use case that is not covered by the existing APIs, we would love to +work with you to design a new general purpose API. There are often others who are +interested in similar extensions and the act of defining the API often improves +the code overall for everyone. + +Extension APIs that provide "safe" default behaviors are more likely to be +suitable for inclusion in DataFusion, while APIs that require major changes to +built-in operators are less likely. For example, it might make less sense +to add an API to support a stream processing feature if that would result in +slower performance for built-in operators. It may still make sense to add +extension APIs for such features, but leave implementation of such operators in +downstream projects. + +The process to create a new extension API is typically: + +- Look for an existing issue describing what you want to do, and file one if it + doesn't yet exist. +- Discuss what the API would look like. Feel free to ask contributors (via `@` + mentions) for feedback (you can find such people by looking at the most + recently changed PRs and issues) +- Prototype the new API, typically by adding an example (in + `datafusion-examples` or refactoring existing code) to show how it would work +- Create a PR with the new API, and work with the community to get it merged + +Some benefits of using an example based approach are + +- Any future API changes will also keep your example going ensuring no + regression in functionality +- There will be a blue print of any needed changes to your code if the APIs do change + (just look at what changed in your example) + +An example of this process was [creating a SQL Extension Planning API]. + +[creating a sql extension planning api]: https://github.com/apache/datafusion/issues/11207 diff --git a/docs/source/contributor-guide/communication.md b/docs/source/contributor-guide/communication.md index 3e5e816d2f90..43d412200201 100644 --- a/docs/source/contributor-guide/communication.md +++ b/docs/source/contributor-guide/communication.md @@ -37,18 +37,54 @@ We use the Slack and Discord platforms for informal discussions and coordination meet other contributors and get guidance on where to contribute. It is important to note that any technical designs and decisions are made fully in the open, on GitHub. -Most of us use the `#datafusion` and `#arrow-rust` channels in the [ASF Slack workspace](https://s.apache.org/slack-invite) . -Unfortunately, due to spammers, the ASF Slack workspace requires an invitation to join. To get an invitation, -request one in the `Arrow Rust` channel of the [Arrow Rust Discord server](https://discord.gg/Qw5gKqHxUM). +Most of us use the [ASF Slack +workspace](https://s.apache.org/slack-invite) and the [Arrow Rust Discord +server](https://discord.gg/Qw5gKqHxUM) for discussions. -## Mailing list +There are specific channels for Arrow, DataFusion, and the DataFusion subprojects (Ballista, Comet, Python, etc). -We also use arrow.apache.org's `dev@` mailing list for release coordination and occasional design discussions. Other -than the release process, most DataFusion mailing list traffic will link to a GitHub issue or PR for discussion. -([subscribe](mailto:dev-subscribe@datafusion.apache.org), -[unsubscribe](mailto:dev-unsubscribe@datafusion.apache.org), -[archives](https://lists.apache.org/list.html?dev@arrow.apache.org)). +In Slack we use these channels: -When emailing the dev list, please make sure to prefix the subject line with a -`[DataFusion]` tag, e.g. `"[DataFusion] New API for remote data sources"`, so -that the appropriate people in the Apache Arrow community notice the message. +- #arrow +- #arrow-rust +- #datafusion +- #datafusion-ballista +- #datafusion-comet +- #datafusion-python + +In Discord we use these channels: + +- #ballista +- #comet +- #contrib-federation +- #datafusion +- #datafusion-python +- #dolomite-optimizer +- #general +- #hiring +- #incremental-materialized-views + +Unfortunately, due to spammers, the ASF Slack workspace requires an invitation +to join. We are happy to invite you -- please ask for an invitation in the +Discord server. + +## Mailing Lists + +Like other Apache projects, we use [mailing lists] for certain purposes, most +importantly release coordination. Other than the release process, most +DataFusion mailing list traffic will simply link to a GitHub issue or PR where +the actual discussion occurs. The project mailing lists are: + +- [`dev@datafusion.apache.org`](mailto:dev@datafusion.apache.org): the main + mailing list for release coordination and other project-wide discussions. Links: + [archives](https://lists.apache.org/list.html?dev@datafusion.apache.org), + [subscribe](mailto:dev-subscribe@datafusion.apache.org), + [unsubscribe](mailto:dev-unsubscribe@datafusion.apache.org) +- `github@datafusion.apache.org`: read-only mailing list that receives all GitHub notifications for the project. Links: + [archives](https://lists.apache.org/list.html?github@datafusion.apache.org) +- `commits@datafusion.apache.org`: read-only mailing list that receives all GitHub commits for the project. Links: + [archives](https://lists.apache.org/list.html?commits@datafusion.apache.org) +- `private@datafusion.apache.org`: private mailing list for PMC members. This list has very little traffic, almost exclusively discussions on growing the committer and PMC membership. Links: + [archives](https://lists.apache.org/list.html?private@datafusion.apache.org) + +[mailing lists]: https://www.apache.org/foundation/mailinglists diff --git a/docs/source/contributor-guide/getting_started.md b/docs/source/contributor-guide/getting_started.md new file mode 100644 index 000000000000..696d6d3a0fe2 --- /dev/null +++ b/docs/source/contributor-guide/getting_started.md @@ -0,0 +1,87 @@ + + +# Getting Started + +This section describes how you can get started at developing DataFusion. + +## Windows setup + +```shell +wget https://az792536.vo.msecnd.net/vms/VMBuild_20190311/VirtualBox/MSEdge/MSEdge.Win10.VirtualBox.zip +choco install -y git rustup.install visualcpp-build-tools +git-bash.exe +cargo build +``` + +## Protoc Installation + +Compiling DataFusion from sources requires an installed version of the protobuf compiler, `protoc`. + +On most platforms this can be installed from your system's package manager + +``` +# Ubuntu +$ sudo apt install -y protobuf-compiler + +# Fedora +$ dnf install -y protobuf-devel + +# Arch Linux +$ pacman -S protobuf + +# macOS +$ brew install protobuf +``` + +You will want to verify the version installed is `3.15` or greater, which has support for explicit [field presence](https://github.com/protocolbuffers/protobuf/blob/v3.15.0/docs/field_presence.md). Older versions may fail to compile. + +```shell +$ protoc --version +libprotoc 3.15.0 +``` + +Alternatively a binary release can be downloaded from the [Release Page](https://github.com/protocolbuffers/protobuf/releases) or [built from source](https://github.com/protocolbuffers/protobuf/blob/main/src/README.md). + +## Bootstrap environment + +DataFusion is written in Rust and it uses a standard rust toolkit: + +- `cargo build` +- `cargo fmt` to format the code +- `cargo test` to test +- etc. + +Note that running `cargo test` requires significant memory resources, due to cargo running many tests in parallel by default. If you run into issues with slow tests or system lock ups, you can significantly reduce the memory required by instead running `cargo test -- --test-threads=1`. For more information see [this issue](https://github.com/apache/datafusion/issues/5347). + +Testing setup: + +- `rustup update stable` DataFusion uses the latest stable release of rust +- `git submodule init` +- `git submodule update` + +Formatting instructions: + +- [ci/scripts/rust_fmt.sh](../../../ci/scripts/rust_fmt.sh) +- [ci/scripts/rust_clippy.sh](../../../ci/scripts/rust_clippy.sh) +- [ci/scripts/rust_toml_fmt.sh](../../../ci/scripts/rust_toml_fmt.sh) + +or run them all at once: + +- [dev/rust_lint.sh](../../../dev/rust_lint.sh) diff --git a/docs/source/contributor-guide/governance.md b/docs/source/contributor-guide/governance.md new file mode 100644 index 000000000000..27ff90eb92c8 --- /dev/null +++ b/docs/source/contributor-guide/governance.md @@ -0,0 +1,93 @@ + + +# Governance + +The current PMC and committers are listed in the [Apache Phonebook]. + +[apache phonebook]: https://projects.apache.org/committee.html?datafusion + +## Overview + +DataFusion is part of the [Apache Software Foundation] and is governed following +the [Apache Way] and [project management guidelines], [independently of +commercial interests]. + +[apache software foundation]: https://www.apache.org/ +[apache way]: https://www.apache.org/theapacheway/ +[project management guidelines]: https://www.apache.org/foundation/how-it-works.html#management +[independently of commercial interests]: https://community.apache.org/projectIndependence.html + +As much as practicable, we strive to make decisions by consensus, and anyone in +the community is encouraged to propose ideas, start discussions, and contribute +to the project. + +## Roles + +- **Contributors**: Anyone who contributes to the project, whether it be code, + documentation, testing, issue reports, code, or some other forms. + +- **Committers**: Contributors who have been granted write access to the + project's source code repository. Committers are responsible for reviewing and + merging pull requests. Committers are chosen by the PMC. + +- **Project Management Committee (PMC)**: The PMC is responsible for the + oversight of the project. The PMC is responsible for making decisions about the + project, including the addition of new committers and PMC members. The PMC is + also responsible for [voting] on releases and ensuring that the project follows + the [Apache Way]. + +[voting]: https://www.apache.org/foundation/voting.html + +## Becoming a Committer + +Contributors with sustained, high-quality activity may be invited to become +committers by the PMC as a recognition of their contribution to the project and +their shared commitment. Committers have the significant responsibility of using +their status and access to improve the project for the entire community. + +When considering inviting someone to be a committer, the PMC looks for +contributors who are already doing the work and exercising the judgment expected +of a committer. After all, any contributor can do all of the things a committer +does except for merge a PR. While there is no set list of requirements, nor a +checklist that entitles one to commit privileges, typical behaviors include: + +- Contributions beyond pull requests, such as reviewing other pull requests, + fixing bugs and documentation, triaging issues, answering community questions, + improving usability, helping with CI, verifying releases, etc. + +- Contributions that are consistent in quality and sustained + over time, typically on the order of 6 months or more. + +- Assistance growing the size and health of the community via constructive, + respectful, and consensus driven interactions, as described in our [Code of + Conduct] and the [Apache Way]. + +If you feel you should be offered committer privileges, but have not been, you +can reach out to one of the PMC members or the private@datafusion.apache.org mailing +list. + +[code of conduct]: https://www.apache.org/foundation/policies/conduct.html + +## Becoming a PMC Member + +Committers with long term sustained contributions to the project may be invited +to join the PMC. This is a recognition of a significant contribution to growing +the community, improving the project, and helping to guide the project's +direction, typically over the course of a year or more. diff --git a/docs/source/contributor-guide/howtos.md b/docs/source/contributor-guide/howtos.md new file mode 100644 index 000000000000..f105ab2c42db --- /dev/null +++ b/docs/source/contributor-guide/howtos.md @@ -0,0 +1,133 @@ + + +# HOWTOs + +## How to add a new scalar function + +Below is a checklist of what you need to do to add a new scalar function to DataFusion: + +- Add the actual implementation of the function to a new module file within: + - [here](https://github.com/apache/datafusion/tree/main/datafusion/functions-nested) for arrays, maps and structs functions + - [here](https://github.com/apache/datafusion/tree/main/datafusion/functions/src/crypto) for crypto functions + - [here](https://github.com/apache/datafusion/tree/main/datafusion/functions/src/datetime) for datetime functions + - [here](https://github.com/apache/datafusion/tree/main/datafusion/functions/src/encoding) for encoding functions + - [here](https://github.com/apache/datafusion/tree/main/datafusion/functions/src/math) for math functions + - [here](https://github.com/apache/datafusion/tree/main/datafusion/functions/src/regex) for regex functions + - [here](https://github.com/apache/datafusion/tree/main/datafusion/functions/src/string) for string functions + - [here](https://github.com/apache/datafusion/tree/main/datafusion/functions/src/unicode) for unicode functions + - create a new module [here](https://github.com/apache/datafusion/tree/main/datafusion/functions/src/) for other functions. +- New function modules - for example a `vector` module, should use a [rust feature](https://doc.rust-lang.org/cargo/reference/features.html) (for example `vector_expressions`) to allow DataFusion + users to enable or disable the new module as desired. +- The implementation of the function is done via implementing `ScalarUDFImpl` trait for the function struct. + - See the [advanced_udf.rs] example for an example implementation + - Add tests for the new function +- To connect the implementation of the function add to the mod.rs file: + - a `mod xyz;` where xyz is the new module file + - a call to `make_udf_function!(..);` + - an item in `export_functions!(..);` +- In [sqllogictest/test_files], add new `sqllogictest` integration tests where the function is called through SQL against well known data and returns the expected result. + - Documentation for `sqllogictest` [here](https://github.com/apache/datafusion/blob/main/datafusion/sqllogictest/README.md) +- Add SQL reference documentation [here](https://github.com/apache/datafusion/blob/main/docs/source/user-guide/sql/scalar_functions.md) + - An example of this being done can be seen [here](https://github.com/apache/datafusion/pull/12775) + - Run `./dev/update_function_docs.sh` to update docs + +[advanced_udf.rs]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs +[sqllogictest/test_files]: https://github.com/apache/datafusion/tree/main/datafusion/sqllogictest/test_files + +## How to add a new aggregate function + +Below is a checklist of what you need to do to add a new aggregate function to DataFusion: + +- Add the actual implementation of an `Accumulator` and `AggregateExpr`: +- In [datafusion/expr/src](../../../datafusion/expr/src/aggregate_function.rs), add: + - a new variant to `AggregateFunction` + - a new entry to `FromStr` with the name of the function as called by SQL + - a new line in `return_type` with the expected return type of the function, given an incoming type + - a new line in `signature` with the signature of the function (number and types of its arguments) + - a new line in `create_aggregate_expr` mapping the built-in to the implementation + - tests to the function. +- In [sqllogictest/test_files], add new `sqllogictest` integration tests where the function is called through SQL against well known data and returns the expected result. + - Documentation for `sqllogictest` [here](https://github.com/apache/datafusion/blob/main/datafusion/sqllogictest/README.md) +- Add SQL reference documentation [here](https://github.com/apache/datafusion/blob/main/docs/source/user-guide/sql/aggregate_functions.md) + - An example of this being done can be seen [here](https://github.com/apache/datafusion/pull/12775) + - Run `./dev/update_function_docs.sh` to update docs + +## How to display plans graphically + +The query plans represented by `LogicalPlan` nodes can be graphically +rendered using [Graphviz](https://www.graphviz.org/). + +To do so, save the output of the `display_graphviz` function to a file.: + +```rust +// Create plan somehow... +let mut output = File::create("/tmp/plan.dot")?; +write!(output, "{}", plan.display_graphviz()); +``` + +Then, use the `dot` command line tool to render it into a file that +can be displayed. For example, the following command creates a +`/tmp/plan.pdf` file: + +```bash +dot -Tpdf < /tmp/plan.dot > /tmp/plan.pdf +``` + +## How to format `.md` document + +We are using `prettier` to format `.md` files. + +You can either use `npm i -g prettier` to install it globally or use `npx` to run it as a standalone binary. Using `npx` required a working node environment. Upgrading to the latest prettier is recommended (by adding `--upgrade` to the `npm` command). + +```bash +$ prettier --version +2.3.0 +``` + +After you've confirmed your prettier version, you can format all the `.md` files: + +```bash +prettier -w {datafusion,datafusion-cli,datafusion-examples,dev,docs}/**/*.md +``` + +## How to format `.toml` files + +We use `taplo` to format `.toml` files. + +For Rust developers, you can install it via: + +```sh +cargo install taplo-cli --locked +``` + +> Refer to the [Installation section][doc] on other ways to install it. +> +> [doc]: https://taplo.tamasfe.dev/cli/installation/binary.html + +```bash +$ taplo --version +taplo 0.9.0 +``` + +After you've confirmed your `taplo` version, you can format all the `.toml` files: + +```bash +taplo fmt +``` diff --git a/docs/source/contributor-guide/index.md b/docs/source/contributor-guide/index.md index 252848bb7132..4645fe5c8804 100644 --- a/docs/source/contributor-guide/index.md +++ b/docs/source/contributor-guide/index.md @@ -34,6 +34,10 @@ community as well as get more familiar with Rust and the relevant codebases. You can find a curated [good-first-issue] list to help you get started. +[good-first-issue]: https://github.com/apache/datafusion/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22 + +### Open Contribution and Assigning tickets + DataFusion is an open contribution project, and thus there is no particular project imposed deadline for completing any issue or any restriction on who can work on an issue, nor how many people can work on an issue at the same time. @@ -54,343 +58,137 @@ unable to make progress you should unassign the issue by using the `unassign me` link at the top of the issue page (and ask for help if are stuck) so that someone else can get involved in the work. +### File Tickets to Discuss New Features + If you plan to work on a new feature that doesn't have an existing ticket, it is a good idea to open a ticket to discuss the feature. Advanced discussion often helps avoid wasted effort by determining early if the feature is a good fit for -DataFusion before too much time is invested. It also often helps to discuss your -ideas with the community to get feedback on implementation. - -[good-first-issue]: https://github.com/apache/datafusion/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22 - -# Developer's guide - -## Pull Request Overview - -We welcome pull requests (PRs) from anyone from the community. - -DataFusion is a very active fast-moving project and we try to review and merge PRs quickly to keep the review backlog down and the pace up. After review and approval, one of the [many people with commit access](https://arrow.apache.org/committers/) will merge your PR. - -Review bandwidth is currently our most limited resource, and we highly encourage reviews by the broader community. If you are waiting for your PR to be reviewed, consider helping review other PRs that are waiting. Such review both helps the reviewer to learn the codebase and become more expert, as well as helps identify issues in the PR (such as lack of test coverage), that can be addressed and make future reviews faster and more efficient. - -## Creating Pull Requests - -We recommend splitting your contributions into smaller PRs rather than large PRs (500+ lines) because: - -1. The PR is more likely to be reviewed quickly -- our reviewers struggle to find the contiguous time needed to review large PRs. -2. The PR discussions tend to be more focused and less likely to get lost among several different threads. -3. It is often easier to accept and act on feedback when it comes early on in a small change, before a particular approach has been polished too much. - -If you are concerned that a larger design will be lost in a string of small PRs, creating a large draft PR that shows how they all work together can help. - -# Reviewing Pull Requests - -When reviewing PRs, please remember our primary goal is to improve DataFusion and its community together. PR feedback should be constructive with the aim to help improve the code as well as the understanding of the contributor. - -Please ensure any issues you raise contains a rationale and suggested alternative -- it is frustrating to be told "don't do it this way" without any clear reason or alternate provided. - -Some things to specifically check: - -1. Is the feature or fix covered sufficiently with tests (see `Test Organization` below)? -2. Is the code clear, and fits the style of the existing codebase? - -## "Major" and "Minor" PRs - -Since we are a worldwide community, we have contributors in many timezones who review and comment. To ensure anyone who wishes has an opportunity to review a PR, our committers try to ensure that at least 24 hours passes between when a "major" PR is approved and when it is merged. - -A "major" PR means there is a substantial change in design or a change in the API. Committers apply their best judgment to determine what constitutes a substantial change. A "minor" PR might be merged without a 24 hour delay, again subject to the judgment of the committer. Examples of potential "minor" PRs are: - -1. Documentation improvements/additions -2. Small bug fixes -3. Non-controversial build-related changes (clippy, version upgrades etc.) -4. Smaller non-controversial feature additions - -The good thing about open code and open development is that any issues in one change can almost always be fixed with a follow on PR. - -## Stale PRs - -Pull requests will be marked with a `stale` label after 60 days of inactivity and then closed 7 days after that. -Commenting on the PR will remove the `stale` label. - -## Getting Started - -This section describes how you can get started at developing DataFusion. - -### Windows setup - -```shell -wget https://az792536.vo.msecnd.net/vms/VMBuild_20190311/VirtualBox/MSEdge/MSEdge.Win10.VirtualBox.zip -choco install -y git rustup.install visualcpp-build-tools -git-bash.exe -cargo build -``` - -### Protoc Installation - -Compiling DataFusion from sources requires an installed version of the protobuf compiler, `protoc`. - -On most platforms this can be installed from your system's package manager - -``` -# Ubuntu -$ sudo apt install -y protobuf-compiler - -# Fedora -$ dnf install -y protobuf-devel +DataFusion before too much time is invested. Discussion on a ticket can help +gather feedback from the community and is likely easier to discuss than a 1000 +line PR. -# Arch Linux -$ pacman -S protobuf +If you open a ticket and it doesn't get any response, you can try `@`-mentioning +recently active community members in the ticket to get their attention. -# macOS -$ brew install protobuf -``` +### What Features are Good Fits for DataFusion? -You will want to verify the version installed is `3.12` or greater, which introduced support for explicit [field presence](https://github.com/protocolbuffers/protobuf/blob/v3.12.0/docs/field_presence.md). Older versions may fail to compile. +DataFusion is designed to highly extensible, and many features can be implemented +as extensions without changing the core of DataFusion. -```shell -$ protoc --version -libprotoc 3.12.4 -``` +We are [working on criteria for what features are good fits for DataFusion], and +will update this section when we have more to share. -Alternatively a binary release can be downloaded from the [Release Page](https://github.com/protocolbuffers/protobuf/releases) or [built from source](https://github.com/protocolbuffers/protobuf/blob/main/src/README.md). +[working on criteria for what features are good fits for datafusion]: https://github.com/apache/datafusion/issues/12357 -### Bootstrap environment - -DataFusion is written in Rust and it uses a standard rust toolkit: - -- `cargo build` -- `cargo fmt` to format the code -- `cargo test` to test -- etc. - -Note that running `cargo test` requires significant memory resources, due to cargo running many tests in parallel by default. If you run into issues with slow tests or system lock ups, you can significantly reduce the memory required by instead running `cargo test -- --test-threads=1`. For more information see [this issue](https://github.com/apache/datafusion/issues/5347). - -Testing setup: - -- `rustup update stable` DataFusion uses the latest stable release of rust -- `git submodule init` -- `git submodule update` - -Formatting instructions: - -- [ci/scripts/rust_fmt.sh](../../../ci/scripts/rust_fmt.sh) -- [ci/scripts/rust_clippy.sh](../../../ci/scripts/rust_clippy.sh) -- [ci/scripts/rust_toml_fmt.sh](../../../ci/scripts/rust_toml_fmt.sh) - -or run them all at once: - -- [dev/rust_lint.sh](../../../dev/rust_lint.sh) - -## Testing - -Tests are critical to ensure that DataFusion is working properly and -is not accidentally broken during refactorings. All new features -should have test coverage. - -DataFusion has several levels of tests in its [Test -Pyramid](https://martinfowler.com/articles/practical-test-pyramid.html) -and tries to follow the Rust standard [Testing Organization](https://doc.rust-lang.org/book/ch11-03-test-organization.html) in the The Book. - -### Unit tests - -Tests for code in an individual module are defined in the same source file with a `test` module, following Rust convention. - -### sqllogictests Tests - -DataFusion's SQL implementation is tested using [sqllogictest](https://github.com/apache/datafusion/tree/main/datafusion/sqllogictest) which are run like any other Rust test using `cargo test --test sqllogictests`. - -`sqllogictests` tests may be less convenient for new contributors who are familiar with writing `.rs` tests as they require learning another tool. However, `sqllogictest` based tests are much easier to develop and maintain as they 1) do not require a slow recompile/link cycle and 2) can be automatically updated via `cargo test --test sqllogictests -- --complete`. - -Like similar systems such as [DuckDB](https://duckdb.org/dev/testing), DataFusion has chosen to trade off a slightly higher barrier to contribution for longer term maintainability. - -### Rust Integration Tests - -There are several tests of the public interface of the DataFusion library in the [tests](https://github.com/apache/datafusion/tree/main/datafusion/core/tests) directory. - -You can run these tests individually using `cargo` as normal command such as - -```shell -cargo test -p datafusion --test parquet_exec -``` - -## Benchmarks - -### Criterion Benchmarks - -[Criterion](https://docs.rs/criterion/latest/criterion/index.html) is a statistics-driven micro-benchmarking framework used by DataFusion for evaluating the performance of specific code-paths. In particular, the criterion benchmarks help to both guide optimisation efforts, and prevent performance regressions within DataFusion. - -Criterion integrates with Cargo's built-in [benchmark support](https://doc.rust-lang.org/cargo/commands/cargo-bench.html) and a given benchmark can be run with - -``` -cargo bench --bench BENCHMARK_NAME -``` - -A full list of benchmarks can be found [here](https://github.com/apache/datafusion/tree/main/datafusion/core/benches). - -_[cargo-criterion](https://github.com/bheisler/cargo-criterion) may also be used for more advanced reporting._ - -### Parquet SQL Benchmarks - -The parquet SQL benchmarks can be run with - -``` - cargo bench --bench parquet_query_sql -``` - -These randomly generate a parquet file, and then benchmark queries sourced from [parquet_query_sql.sql](../../../datafusion/core/benches/parquet_query_sql.sql) against it. This can therefore be a quick way to add coverage of particular query and/or data paths. - -If the environment variable `PARQUET_FILE` is set, the benchmark will run queries against this file instead of a randomly generated one. This can be useful for performing multiple runs, potentially with different code, against the same source data, or for testing against a custom dataset. - -The benchmark will automatically remove any generated parquet file on exit, however, if interrupted (e.g. by CTRL+C) it will not. This can be useful for analysing the particular file after the fact, or preserving it to use with `PARQUET_FILE` in subsequent runs. +# Developer's guide -### Comparing Baselines +## Pull Request Overview -By default, Criterion.rs will compare the measurements against the previous run (if any). Sometimes it's useful to keep a set of measurements around for several runs. For example, you might want to make multiple changes to the code while comparing against the master branch. For this situation, Criterion.rs supports custom baselines. +We welcome pull requests (PRs) from anyone in the community. -``` - git checkout main - cargo bench --bench sql_planner -- --save-baseline main - git checkout YOUR_BRANCH - cargo bench --bench sql_planner -- --baseline main -``` +DataFusion is a rapidly evolving project and we try to review and merge PRs quickly. -Note: For MacOS it may be required to run `cargo bench` with `sudo` +Review bandwidth is currently our most limited resource, and we highly encourage reviews by the broader community. If you are waiting for your PR to be reviewed, consider helping review other PRs that are waiting. Such review both helps the reviewer to learn the codebase and become more expert, as well as helps identify issues in the PR (such as lack of test coverage), that can be addressed and make future reviews faster and more efficient. -``` -sudo cargo bench ... -``` +The lifecycle of a PR is: -More information on [Baselines](https://bheisler.github.io/criterion.rs/book/user_guide/command_line_options.html#baselines) +1. Create a PR targeting the `main` branch. +2. For new contributors a committer must first trigger the CI tasks. Please mention the members from committers list in the PR to help trigger the CI +3. Your PR will be reviewed. Please respond to all feedback on the PR: you don't have to change the code, but you should acknowledge the feedback. PRs waiting for the feedback for more than a few days will be marked as draft. +4. Once the PR is approved, one of the [committers] will merge your PR, typically within 24 hours. We leave approved "major" changes (see below) open for 24 hours prior to merging, and sometimes leave "minor" PRs open for the same time to permit additional feedback. -### Upstream Benchmark Suites +Note that the above time frames are estimates. Due to limited committer +bandwidth, it may take longer to merge your PR. Please wait +patiently. If it has been several days you can friendly ping the +committer who approved your PR to help remind them to merge it. -Instructions and tooling for running upstream benchmark suites against DataFusion can be found in [benchmarks](https://github.com/apache/datafusion/tree/main/benchmarks). +[committers]: https://people.apache.org/phonebook.html?unix=datafusion -These are valuable for comparative evaluation against alternative Arrow implementations and query engines. +## Creating Pull Requests -## HOWTOs +When possible, we recommend splitting your contributions into multiple smaller focused PRs rather than large PRs (500+ lines) because: -### How to add a new scalar function +1. The PR is more likely to be reviewed quickly -- our reviewers struggle to find the contiguous time needed to review large PRs. +2. The PR discussions tend to be more focused and less likely to get lost among several different threads. +3. It is often easier to accept and act on feedback when it comes early on in a small change, before a particular approach has been polished too much. -Below is a checklist of what you need to do to add a new scalar function to DataFusion: +If you are concerned that a larger design will be lost in a string of small PRs, creating a large draft PR that shows how they all work together can help. -- Add the actual implementation of the function to a new module file within: - - [here](../../../datafusion/functions-array/src) for array functions - - [here](../../../datafusion/functions/src/crypto) for crypto functions - - [here](../../../datafusion/functions/src/datetime) for datetime functions - - [here](../../../datafusion/functions/src/encoding) for encoding functions - - [here](../../../datafusion/functions/src/math) for math functions - - [here](../../../datafusion/functions/src/regex) for regex functions - - [here](../../../datafusion/functions/src/string) for string functions - - [here](../../../datafusion/functions/src/unicode) for unicode functions - - create a new module [here](../../../datafusion/functions/src) for other functions. -- New function modules - for example a `vector` module, should use a [rust feature](https://doc.rust-lang.org/cargo/reference/features.html) (for example `vector_expressions`) to allow DataFusion - users to enable or disable the new module as desired. -- The implementation of the function is done via implementing `ScalarUDFImpl` trait for the function struct. - - See the [advanced_udf.rs](../../../datafusion-examples/examples/advanced_udf.rs) example for an example implementation - - Add tests for the new function -- To connect the implementation of the function add to the mod.rs file: - - a `mod xyz;` where xyz is the new module file - - a call to `make_udf_function!(..);` - - an item in `export_functions!(..);` -- In [sqllogictest/test_files](../../../datafusion/sqllogictest/test_files), add new `sqllogictest` integration tests where the function is called through SQL against well known data and returns the expected result. - - Documentation for `sqllogictest` [here](../../../datafusion/sqllogictest/README.md) -- Add SQL reference documentation [here](../../../docs/source/user-guide/sql/scalar_functions.md) +Note all commits in a PR are squashed when merged to the `main` branch so there is one commit per PR after merge. -### How to add a new aggregate function +## Conventional Commits & Labeling PRs -Below is a checklist of what you need to do to add a new aggregate function to DataFusion: +We generate change logs for each release using an automated process that will categorize PRs based on the title +and/or the GitHub labels attached to the PR. -- Add the actual implementation of an `Accumulator` and `AggregateExpr`: - - [here](../../../datafusion/physical-expr/src/string_expressions.rs) for string functions - - [here](../../../datafusion/physical-expr/src/math_expressions.rs) for math functions - - [here](../../../datafusion/functions/src/datetime/mod.rs) for datetime functions - - create a new module [here](../../../datafusion/physical-expr/src) for other functions -- In [datafusion/expr/src](../../../datafusion/expr/src/aggregate_function.rs), add: - - a new variant to `AggregateFunction` - - a new entry to `FromStr` with the name of the function as called by SQL - - a new line in `return_type` with the expected return type of the function, given an incoming type - - a new line in `signature` with the signature of the function (number and types of its arguments) - - a new line in `create_aggregate_expr` mapping the built-in to the implementation - - tests to the function. -- In [sqllogictest/test_files](../../../datafusion/sqllogictest/test_files), add new `sqllogictest` integration tests where the function is called through SQL against well known data and returns the expected result. - - Documentation for `sqllogictest` [here](../../../datafusion/sqllogictest/README.md) -- Add SQL reference documentation [here](../../../docs/source/user-guide/sql/aggregate_functions.md) +We follow the [Conventional Commits] specification to categorize PRs based on the title. This most often simply means +looking for titles starting with prefixes such as `fix:`, `feat:`, `docs:`, or `chore:`. We do not enforce this +convention but encourage its use if you want your PR to feature in the correct section of the changelog. -### How to display plans graphically +The change log generator will also look at GitHub labels such as `bug`, `enhancement`, or `api change`, and labels +do take priority over the conventional commit approach, allowing maintainers to re-categorize PRs after they have been merged. -The query plans represented by `LogicalPlan` nodes can be graphically -rendered using [Graphviz](https://www.graphviz.org/). +[conventional commits]: https://www.conventionalcommits.org/en/v1.0.0/ -To do so, save the output of the `display_graphviz` function to a file.: +# Reviewing Pull Requests -```rust -// Create plan somehow... -let mut output = File::create("/tmp/plan.dot")?; -write!(output, "{}", plan.display_graphviz()); -``` +Some helpful links: -Then, use the `dot` command line tool to render it into a file that -can be displayed. For example, the following command creates a -`/tmp/plan.pdf` file: +- [PRs Waiting for Review] on GitHub +- [Approved PRs Waiting for Merge] on GitHub -```bash -dot -Tpdf < /tmp/plan.dot > /tmp/plan.pdf -``` +[prs waiting for review]: https://github.com/apache/datafusion/pulls?q=is%3Apr+is%3Aopen+-review%3Aapproved+-is%3Adraft+ +[approved prs waiting for merge]: https://github.com/apache/datafusion/pulls?q=is%3Apr+is%3Aopen+review%3Aapproved+-is%3Adraft -## Specifications +When reviewing PRs, our primary goal is to improve DataFusion and its community together. PR feedback should be constructive with the aim to help improve the code as well as the understanding of the contributor. -We formalize some DataFusion semantics and behaviors through specification -documents. These specifications are useful to be used as references to help -resolve ambiguities during development or code reviews. +Please ensure any issues you raise contains a rationale and suggested alternative -- it is frustrating to be told "don't do it this way" without any clear reason or alternate provided. -You are also welcome to propose changes to existing specifications or create -new specifications as you see fit. +Some things to specifically check: -Here is the list current active specifications: +1. Is the feature or fix covered sufficiently with tests (see the [Testing](testing.md) section)? +2. Is the code clear, and fits the style of the existing codebase? -- [Output field name semantic](https://datafusion.apache.org/contributor-guide/specification/output-field-name-semantic.html) -- [Invariants](https://datafusion.apache.org/contributor-guide/specification/invariants.html) +## Performance Improvements -All specifications are stored in the `docs/source/specification` folder. +Performance improvements are always welcome: performance is a key DataFusion +feature. -## How to format `.md` document +In general, the performance improvement from a change should be "enough" to +justify any added code complexity. How much is "enough" is a judgement made by +the committers, but generally means that the improvement should be noticeable in +a real-world scenario and is greater than the noise of the benchmarking system. -We are using `prettier` to format `.md` files. +To help committers evaluate the potential improvement, performance PRs should +in general be accompanied by benchmark results that demonstrate the improvement. -You can either use `npm i -g prettier` to install it globally or use `npx` to run it as a standalone binary. Using `npx` required a working node environment. Upgrading to the latest prettier is recommended (by adding `--upgrade` to the `npm` command). +The best way to demonstrate a performance improvement is with the existing +benchmarks: -```bash -$ prettier --version -2.3.0 -``` +- [System level SQL Benchmarks](https://github.com/apache/datafusion/tree/main/benchmarks) +- Microbenchmarks such as those in [functions/benches](https://github.com/apache/datafusion/tree/main/datafusion/functions/benches) -After you've confirmed your prettier version, you can format all the `.md` files: +If there is no suitable existing benchmark, you can create a new one. It helps +to isolate the effects of your change by creating a separate PR with the +benchmark, and then a PR with the code change that improves the benchmark. -```bash -prettier -w {datafusion,datafusion-cli,datafusion-examples,dev,docs}/**/*.md -``` +[system level sql benchmarks]: https://github.com/apache/datafusion/tree/main/benchmarks +[functions/benches]: https://github.com/apache/datafusion/tree/main/datafusion/functions/benches -## How to format `.toml` files +## "Major" and "Minor" PRs -We use `taplo` to format `.toml` files. - -For Rust developers, you can install it via: +Since we are a worldwide community, we have contributors in many timezones who review and comment. To ensure anyone who wishes has an opportunity to review a PR, our committers try to ensure that at least 24 hours passes between when a "major" PR is approved and when it is merged. -```sh -cargo install taplo-cli --locked -``` +A "major" PR means there is a substantial change in design or a change in the API. Committers apply their best judgment to determine what constitutes a substantial change. A "minor" PR might be merged without a 24 hour delay, again subject to the judgment of the committer. Examples of potential "minor" PRs are: -> Refer to the [Installation section][doc] on other ways to install it. -> -> [doc]: https://taplo.tamasfe.dev/cli/installation/binary.html +1. Documentation improvements/additions +2. Small bug fixes +3. Non-controversial build-related changes (clippy, version upgrades etc.) +4. Smaller non-controversial feature additions -```bash -$ taplo --version -taplo 0.9.0 -``` +The good thing about open code and open development is that any issues in one change can almost always be fixed with a follow on PR. -After you've confirmed your `taplo` version, you can format all the `.toml` files: +## Stale PRs -```bash -taplo fmt -``` +Pull requests will be marked with a `stale` label after 60 days of inactivity and then closed 7 days after that. +Commenting on the PR will remove the `stale` label. diff --git a/docs/source/contributor-guide/inviting.md b/docs/source/contributor-guide/inviting.md new file mode 100644 index 000000000000..4e7ffeb7518d --- /dev/null +++ b/docs/source/contributor-guide/inviting.md @@ -0,0 +1,427 @@ + + +# Inviting New Committers and PMC Members + +This is a cookbook of the recommended DataFusion specific process for inviting +new committers and PMC members. It is intended to follow the [Apache Project +Management Committee guidelines], which takes precedence over this document if +there is a conflict. + +This process is intended for PMC members. While the process is open, the actual +discussions and invitations are not (one of the very few things that we do not +do in the open). + +When following this process, in doubt, check the `private@datafusion.apache.org` +[mailing list archive] for examples of previous emails. + +The general process is: + +1. A PMC member starts a discussion on `private@datafusion.apache.org` + about a candidate they feel should be considered who would use the additional + trust granted to a committer or PMC member to help the community grow and thrive. +2. When consensus is reached a formal [vote] occurs +3. Assuming the vote is successful, that person is granted the appropriate + access via ASF systems. + +[apache project management committee guidelines]: https://www.apache.org/dev/pmc.html +[mailing list archive]: https://lists.apache.org/list.html?dev@datafusion.apache.org +[vote]: https://www.apache.org/foundation/voting + +## New Committers + +### Step 1: Start a Discussion Thread + +The goal of this step is to allow free form discussion prior to calling a vote. +This helps the other PMC members understand why you are proposing to invite the +person to become a committer and makes the voting process easier. + +Best practice is to include some details about why you are proposing to invite +the person. Here is an example: + +``` +To: private@datafusion.apache.org +Subject: [DISCUSS] $PERSONS_NAME for Committer + +$PERSONS_NAME has been an active contributor to the DataFusion community for the +last 6 months[1][2], helping others, answering questions, and improving the +project's code. + +Are there any thoughts about inviting $PERSONS_NAME to become a committer? + +Thanks, +Your Name + +[1]: https://github.com/apache/datafusion/issues?q=commenter%3A +[2]: https://github.com/apache/datafusion/commits?author= +``` + +### Step 2: Formal Vote + +Assuming the discussion thread goes well, start a formal vote, with an email like: + +``` +To: private@datafusion.apache.org +Subject: [VOTE] $PERSONS_NAME for DataFusion Committer +I propose to invite $PERSONS_NAME to be a committer. See discussion here [1]. + +[ ] +1 : Invite $PERSONS_NAME to become a committer +[ ] +0: ... +[ ] -1: I disagree because ... + +The vote will be open for at least 48 hours. + +My vote: +1 + +Thanks, +Your Name + +[1] LINK TO DISCUSSION THREAD (e.g. https://lists.apache.org/thread/7rocc026wckknrjt9j6bsqk3z4c0g5yf) +``` + +If the vote passes (requires 3 `+1`, and no `-1` voites) , send a result email +like the following (substitute `N` with the number of `+1` votes) + +``` +to: private@datafusion.apache.org +subject: [RESULT][VOTE] $PERSONS_NAME for DataFusion Committer + +The vote carries with N +1 votes. +``` + +### Step 3: Send Invitation to the Candidate + +Once the vote on `private@` has passed and the `[RESULT]` e-mail sent, send an +invitation to the new committer and cc: `private@datafusion.apache.org` + +In order to be given write access, the committer needs an apache.org account (e.g. `name@apache.org`). To get one they need: + +1. An [ICLA] on file. Note that Sending an ICLA to `secretary@apache.org` will trigger account creation. If they already have an ICLA on file, but no Apache account, see instructions below. +2. Add GitHub username to the account at [id.apache.org](http://id.apache.org/) +3. Connect [gitbox.apache.org] to their GitHub account +4. Follow the instructions at [gitbox.apache.org] to link your GitHub account with your + +[gitbox.apache.org]: https://gitbox.apache.org + +If the new committer is already a committer on another Apache project (so they +already had an Apache account), The PMC Chair (or an ASF member) simply needs to +explicitly add them to the roster on the [Whimsy Roster Tool]. + +### Step 4: Announce and Celebrate the New Committer + +Email to Send an email such as the following to +[dev@datafusion.apache.org](mailto:dev@datafusion.apache.org]) to celebrate and +acknowledge the new committer to the community. + +``` +To: dev@datafusion.apache.org +Subject: [ANNOUNCE] New DataFusion committer: $NEW_COMMITTER + +On behalf of the DataFusion PMC, I'm happy to announce that $NEW_COMMITTER +has accepted an invitation to become a committer on Apache +DataFusion. Welcome, and thank you for your contributions! + + +``` + +[icla]: http://www.apache.org/licenses/#clas + +## Email Templates for Inviting New Committers + +### Committers WITHOUT an Apache account and WITHOUT ICLA on file + +You can check here to see if someone has an Apache account: http://people.apache.org/committer-index.html + +If you aren't sure whether someone has an ICLA on file, ask the DataFusion PMC +chair or an ASF Member to check [the ICLA file list]. + +[the icla file list]: https://whimsy.apache.org/officers/unlistedclas.cgi + +``` +To: $EMAIL +Cc: private@datafusion.apache.org +Subject: Invitation to become a DataFusion Committer + +Dear $NEW_COMMITTER, + +The DataFusion Project Management Committee (PMC) hereby offers you +committer privileges to the project. These privileges are offered on +the understanding that you'll use them reasonably and with common +sense. We like to work on trust rather than unnecessary constraints. + +Being a committer enables you to merge PRs to DataFusion git repositories. + +Being a committer does not require you to participate any more than +you already do. It does tend to make one even more committed. You will +probably find that you spend more time here. + +Of course, you can decline and instead remain as a contributor, +participating as you do now. + +A. This personal invitation is a chance for you to accept or decline +in private. Either way, please let us know in reply to the +private@datafusion.apache.org address only. + +B. If you accept, the next step is to register an ICLA: + +Details of the iCLA and how to submit them are found through this link: +https://www.apache.org/licenses/contributor-agreements.html#clas + +When you transmit the completed ICLA, request to notify Apache +DataFusion and choose a unique Apache id. Look to see if your preferred id +is already taken at http://people.apache.org/committer-index.html This +will allow the Secretary to notify the PMC when your ICLA has been +recorded. + +Once you are notified that your Apache account has been created, you must then: + +1. Add your GitHub username to your account at id.apache.org + +2. Follow the instructions at gitbox.apache.org to link your GitHub account +with your Apache account. + +You will then have write (but not admin) access to the DataFusion repositories. +If you have questions or run into issues, please reply-all to this e-mail. +``` + +After the new account has been created, you can announce the new committer on `dev@` + +### Committers WITHOUT an Apache account but WITH an ICLA on file + +In this scenario, an officer (the project VP or an Apache Member) needs only to +request an account to be created for the new committer, since through the +ordinary process the ASF Secretary will do it automatically. This is done at +https://whimsy.apache.org/officers/acreq. Before requesting the new account, +send an e-mail to the new committer (cc'ing `private@`) like this: + +``` +To: $EMAIL +Cc: private@datafusion.apache.org +Subject: Invitation to become a DataFusion Committer + +Dear $NEW_COMMITTER, + +The DataFusion Project Management Committee (PMC) hereby offers you +committer privileges to the project. These privileges are offered on +the understanding that you'll use them reasonably and with common +sense. We like to work on trust rather than unnecessary constraints. + +Being a committer enables you to merge PRs to DataFusion git repositories. + +Being a committer does not require you to participate any more than +you already do. It does tend to make one even more committed. You will +probably find that you spend more time here. + +Of course, you can decline and instead remain as a contributor, +participating as you do now. + +This personal invitation is a chance for you to accept or decline +in private. Either way, please let us know in reply to the +private@datafusion.apache.org address only. We will have to request an +Apache account be created for you, so please let us know what user id +you would prefer. + +Once you are notified that your Apache account has been created, you must then: + +1. Add your GitHub username to your account at id.apache.org + +2. Follow the instructions at gitbox.apache.org to link your GitHub account +with your Apache account. + +You will then have write (but not admin) access to the DataFusion repositories. +If you have questions or run into issues, please reply-all to this e-mail. +``` + +### Committers WITH an existing Apache account + +In this scenario, an officer (the project PMC Chair or an ASF Member) can simply add +the new committer on the [Whimsy Roster Tool]. Before doing this, e-mail the new +committer inviting them to be a committer like so (cc private@): + +``` +To: $EMAIL +Cc: private@datafusion.apache.org +Subject: Invitation to become a DataFusion Committer + +Dear $NEW_COMMITTER, + +The DataFusion Project Management Committee (PMC) hereby offers you +committer privileges to the project. These privileges are offered on +the understanding that you'll use them reasonably and with common +sense. We like to work on trust rather than unnecessary constraints. + +Being a committer enables you to merge PRs to DataFusion git repositories. + +Being a committer does not require you to participate any more than +you already do. It does tend to make one even more committed. You will +probably find that you spend more time here. + +Of course, you can decline and instead remain as a contributor, +participating as you do now. + +If you accept, please let us know by replying to private@datafusion.apache.org. +``` + +## New PMC Members + +See also the ASF instructions on [how to add a PMC member]. + +[how to add a pmc member]: https://www.apache.org/dev/pmc.html#newpmc + +### Step 1: Start a Discussion Thread + +As for committers, start a discussion thread on the `private@` mailing list + +``` +To: private@datafusion.apache.org +Subject: [DISCUSS] $NEW_PMC_MEMBER for PMC + +I would like to propose adding $NEW_PMC_MEMBER[1] to the DataFusion PMC. + +$NEW_PMC_MEMBMER has been a committer since $COMMITTER_MONTH [2], has a +strong and sustained contribution record for more than a year, and focused on +helping the community and the project grow[3]. + +Are there any thoughts about inviting $NEW_PMC_MEMBER to become a PMC member? + +[1] https://github.com/$NEW_PMC_MEMBERS_GITHUB_ACCOUNT +[2] LINK TO COMMMITER VOTE RESULT THREAD (e.g. https://lists.apache.org/thread/ovgp8z97l1vh0wzjkgn0ktktggomxq9t) +[3]: https://github.com/apache/datafusion/pulls?q=commenter%3A<$NEW_PMC_MEMBERS_GITHUB_ACCOUNT>+ + +Thanks, +YOUR NAME +``` + +### Step 2: Formal Vote + +Assuming the discussion thread goes well, start a formal vote with an email like: + +``` +To: private@datafusion.apache.org +Subject: [VOTE] $NEW_PMC_MEMBER for PMC + +I propose inviting $NEW_PMC_MEMBER to join the DataFusion PMC. We previously +discussed the merits of inviting $NEW_PMC_MEMBER to join the PMC [1]. + +The vote will be open for at least 7 days. + +[ ] +1 : Invite $NEW_PMC_MEMBER to become a PMC member +[ ] +0: ... +[ ] -1: I disagree because ... + +My vote: +1 + +[1] LINK TO DISCUSSION THREAD (e.g. https://lists.apache.org/thread/x2zno2hs1ormvfy13n7h82hmsxp3j66c) + +Thanks, +Your Name +``` + +### Step 3: Send Notice to ASF Board + +The DataFusion PMC Chair then sends a NOTICE to `board@apache.org` (cc'ing +`private@`) like this: + +``` +To: board@apache.org +Cc: private@datafusion.apache.org +Subject: [NOTICE] $NEW_PMC_MEMBER to join DataFusion PMC + +DataFusion proposes to invite $NEW_PMC_MEMBER ($NEW_PMC_MEMBER_APACHE_ID) to join the PMC. + +The vote result is available here: +$VOTE_RESULT_URL + +FYI: Full vote details: +$VOTE_URL +``` + +### Step 4: Send invitation email + +Once, the PMC chair has confirmed that the email sent to `board@apache.org` has +made it to the archives, the Chair sends an invitation e-mail to the new PMC +member (cc'ing `private@`) like this: + +``` +To: $EMAIL +Cc: private@datafusion.apache.org +Subject: Invitation to join the DataFusion PMC +Dear $NEW_PMC_MEMBER, + +In recognition of your demonstrated commitment to, and alignment with, the +goals of the Apache DataFusion project, the DataFusion PMC has voted to offer you +membership in the DataFusion PMC ("Project Management Committee"). + +Please let us know if you accept by subscribing to the private alias [by +sending mail to private-subscribe@datafusion.apache.org], and posting +a message to private@datafusion.apache.org. + +The PMC for every top-level project is tasked by the Apache Board of +Directors with official oversight and binding votes in that project. + +As a PMC member, you are responsible for continuing the general project, code, +and community oversight that you have exhibited so far. The votes of the PMC +are legally binding. + +All PMC members are subscribed to the project's private mail list, which is +used to discuss issues unsuitable for an open, public forum, such as people +issues (e.g. new committers, problematic community members, etc.), security +issues, and the like. It can't be emphasized enough that care should be taken +to minimize the use of the private list, discussing everything possible on the +appropriate public list. + +The private PMC list is *private* - it is strictly for the use of the +PMC. Messages are not to be forwarded to anyone else without the express +permission of the PMC. Also note that any Member of the Foundation has the +right to review and participate in any PMC list, as a PMC is acting on behalf +of the Membership. + +Finally, the PMC is not meant to create a hierarchy within the committership or +the community. Therefore, in our day-to-day interactions with the rest of the +community, we continue to interact as peers, where every reasonable opinion is +considered, and all community members are invited to participate in our public +voting. If there ever is a situation where the PMC's view differs significantly +from that of the rest of the community, this is a symptom of a problem that +needs to be addressed. + +With the expectation of your acceptance, welcome! + +The Apache DataFusion PMC +``` + +### Step 5: Chair Promotes the Committer to PMC + +The PMC chair adds the user to the PMC using the [Whimsy Roster Tool]. + +### Step 6: Announce and Celebrate the New PMC Member + +Send an email such as the following to `dev@datafusion.apache.org` to celebrate: + +``` +To: dev@datafusion.apache.org +Subject: [ANNOUNCE] New DataFusion PMC member: $NEW_PMC_MEMBER + +The Project Management Committee (PMC) for Apache DataFusion has invited +$NEW_PMC_MEMBER to become a PMC member and we are pleased to announce +that $NEW_PMC_MEMBER has accepted. + +Congratulations and welcome! +``` + +[whimsy roster tool]: https://whimsy.apache.org/roster/committee/datafusion diff --git a/docs/source/contributor-guide/quarterly_roadmap.md b/docs/source/contributor-guide/quarterly_roadmap.md deleted file mode 100644 index ee82617225aa..000000000000 --- a/docs/source/contributor-guide/quarterly_roadmap.md +++ /dev/null @@ -1,96 +0,0 @@ - - -# Quarterly Roadmap - -A quarterly roadmap will be published to give the DataFusion community visibility into the priorities of the projects contributors. This roadmap is not binding. - -## 2023 Q4 - -- Improve data output (`COPY`, `INSERT` and DataFrame) output capability [#6569](https://github.com/apache/datafusion/issues/6569) -- Implementation of `ARRAY` types and related functions [#6980](https://github.com/apache/datafusion/issues/6980) -- Write an industrial paper about DataFusion for SIGMOD [#6782](https://github.com/apache/datafusion/issues/6782) - -## 2022 Q2 - -### DataFusion Core - -- IO Improvements - - Reading, registering, and writing more file formats from both DataFrame API and SQL - - Additional options for IO including partitioning and metadata support -- Work Scheduling - - Improve predictability, observability and performance of IO and CPU-bound work - - Develop a more explicit story for managing parallelism during plan execution -- Memory Management - - Add more operators for memory limited execution -- Performance - - Incorporate row-format into operators such as aggregate - - Add row-format benchmarks - - Explore JIT-compiling complex expressions - - Explore LLVM for JIT, with inline Rust functions as the primary goal - - Improve performance of Sort and Merge using Row Format / JIT expressions -- Documentation - - General improvements to DataFusion website - - Publish design documents -- Streaming - - Create `StreamProvider` trait - -### Ballista - -- Make production ready - - Shuffle file cleanup - - Fill functional gaps between DataFusion and Ballista - - Improve task scheduling and data exchange efficiency - - Better error handling - - Task failure - - Executor lost - - Schedule restart - - Improve monitoring and logging - - Auto scaling support -- Support for multi-scheduler deployments. Initially for resiliency and fault tolerance but ultimately to support sharding for scalability and more efficient caching. -- Executor deployment grouping based on resource allocation - -### Extensions ([datafusion-contrib](https://github.com/datafusion-contrib)) - -#### [DataFusion-Python](https://github.com/datafusion-contrib/datafusion-python) - -- Add missing functionality to DataFrame and SessionContext -- Improve documentation - -#### [DataFusion-S3](https://github.com/datafusion-contrib/datafusion-objectstore-s3) - -- Create Python bindings to use with datafusion-python - -#### [DataFusion-Tui](https://github.com/datafusion-contrib/datafusion-tui) - -- Create multiple SQL editors -- Expose more Context and query metadata -- Support new data sources - - BigTable, HDFS, HTTP APIs - -#### [DataFusion-BigTable](https://github.com/datafusion-contrib/datafusion-bigtable) - -- Python binding to use with datafusion-python -- Timestamp range predicate pushdown -- Multi-threaded partition aware execution -- Production ready Rust SDK - -#### [DataFusion-Streams](https://github.com/datafusion-contrib/datafusion-streams) - -- Create experimental implementation of `StreamProvider` trait diff --git a/docs/source/contributor-guide/roadmap.md b/docs/source/contributor-guide/roadmap.md index a6d78d9311aa..3d9c1ee371fe 100644 --- a/docs/source/contributor-guide/roadmap.md +++ b/docs/source/contributor-guide/roadmap.md @@ -43,3 +43,84 @@ start a conversation using a github issue or the make review efficient and avoid surprises. [The current list of `EPIC`s can be found here](https://github.com/apache/datafusion/issues?q=is%3Aissue+is%3Aopen+epic). + +# Quarterly Roadmap + +A quarterly roadmap will be published to give the DataFusion community +visibility into the priorities of the projects contributors. This roadmap is not +binding and we would welcome any/all contributions to help keep this list up to +date. + +## 2023 Q4 + +- Improve data output (`COPY`, `INSERT` and DataFrame) output capability [#6569](https://github.com/apache/datafusion/issues/6569) +- Implementation of `ARRAY` types and related functions [#6980](https://github.com/apache/datafusion/issues/6980) +- Write an industrial paper about DataFusion for SIGMOD [#6782](https://github.com/apache/datafusion/issues/6782) + +## 2022 Q2 + +### DataFusion Core + +- IO Improvements + - Reading, registering, and writing more file formats from both DataFrame API and SQL + - Additional options for IO including partitioning and metadata support +- Work Scheduling + - Improve predictability, observability and performance of IO and CPU-bound work + - Develop a more explicit story for managing parallelism during plan execution +- Memory Management + - Add more operators for memory limited execution +- Performance + - Incorporate row-format into operators such as aggregate + - Add row-format benchmarks + - Explore JIT-compiling complex expressions + - Explore LLVM for JIT, with inline Rust functions as the primary goal + - Improve performance of Sort and Merge using Row Format / JIT expressions +- Documentation + - General improvements to DataFusion website + - Publish design documents +- Streaming + - Create `StreamProvider` trait + +### Ballista + +- Make production ready + - Shuffle file cleanup + - Fill functional gaps between DataFusion and Ballista + - Improve task scheduling and data exchange efficiency + - Better error handling + - Task failure + - Executor lost + - Schedule restart + - Improve monitoring and logging + - Auto scaling support +- Support for multi-scheduler deployments. Initially for resiliency and fault tolerance but ultimately to support sharding for scalability and more efficient caching. +- Executor deployment grouping based on resource allocation + +### Extensions ([datafusion-contrib](https://github.com/datafusion-contrib)) + +### [DataFusion-Python](https://github.com/datafusion-contrib/datafusion-python) + +- Add missing functionality to DataFrame and SessionContext +- Improve documentation + +### [DataFusion-S3](https://github.com/datafusion-contrib/datafusion-objectstore-s3) + +- Create Python bindings to use with datafusion-python + +### [DataFusion-Tui](https://github.com/datafusion-contrib/datafusion-tui) + +- Create multiple SQL editors +- Expose more Context and query metadata +- Support new data sources + - BigTable, HDFS, HTTP APIs + +### [DataFusion-BigTable](https://github.com/datafusion-contrib/datafusion-bigtable) + +- Python binding to use with datafusion-python +- Timestamp range predicate pushdown +- Multi-threaded partition aware execution +- Production ready Rust SDK + +### [DataFusion-Streams](https://github.com/datafusion-contrib/datafusion-streams) + +- Create experimental implementation of `StreamProvider` trait diff --git a/docs/source/contributor-guide/specification/index.rst b/docs/source/contributor-guide/specification/index.rst index bcd5a895c4d2..a34f0b19e4de 100644 --- a/docs/source/contributor-guide/specification/index.rst +++ b/docs/source/contributor-guide/specification/index.rst @@ -18,6 +18,16 @@ Specifications ============== +We formalize some DataFusion semantics and behaviors through specification +documents. These specifications are useful to be used as references to help +resolve ambiguities during development or code reviews. + +You are also welcome to propose changes to existing specifications or create +new specifications as you see fit. All specifications are stored in the +`docs/source/specification` folder. Here is the list current active +specifications: + + .. toctree:: :maxdepth: 1 diff --git a/docs/source/contributor-guide/testing.md b/docs/source/contributor-guide/testing.md new file mode 100644 index 000000000000..b955b09050b3 --- /dev/null +++ b/docs/source/contributor-guide/testing.md @@ -0,0 +1,159 @@ + + +# Testing + +Tests are critical to ensure that DataFusion is working properly and +is not accidentally broken during refactorings. All new features +should have test coverage and the entire test suite is run as part of CI. + +DataFusion has several levels of tests in its [Test Pyramid] and tries to follow +the Rust standard [Testing Organization] described in [The Book]. + +Run tests using `cargo`: + +```shell +cargo test +``` + +You can also use other runners such as [cargo-nextest]. + +```shell +cargo nextest run +``` + +[test pyramid]: https://martinfowler.com/articles/practical-test-pyramid.html +[testing organization]: https://doc.rust-lang.org/book/ch11-03-test-organization.html +[the book]: https://doc.rust-lang.org/book/ +[cargo-nextest]: https://nexte.st/ + +## Unit tests + +Tests for code in an individual module are defined in the same source file with a `test` module, following Rust convention. +The [test_util](https://github.com/apache/datafusion/tree/main/datafusion/common/src/test_util.rs) module provides useful macros to write unit tests effectively, such as `assert_batches_sorted_eq` and `assert_batches_eq` for RecordBatches and `assert_contains` / `assert_not_contains` which are used extensively in the codebase. + +## sqllogictests Tests + +DataFusion's SQL implementation is tested using [sqllogictest](https://github.com/apache/datafusion/tree/main/datafusion/sqllogictest) which are run like other tests using `cargo test --test sqllogictests`. + +`sqllogictests` tests may be less convenient for new contributors who are familiar with writing `.rs` tests as they require learning another tool. However, `sqllogictest` based tests are much easier to develop and maintain as they 1) do not require a slow recompile/link cycle and 2) can be automatically updated via `cargo test --test sqllogictests -- --complete`. + +Like similar systems such as [DuckDB](https://duckdb.org/dev/testing), DataFusion has chosen to trade off a slightly higher barrier to contribution for longer term maintainability. + +## Rust Integration Tests + +There are several tests of the public interface of the DataFusion library in the [tests](https://github.com/apache/datafusion/tree/main/datafusion/core/tests) directory. + +You can run these tests individually using `cargo` as normal command such as + +```shell +cargo test -p datafusion --test parquet_exec +``` + +## SQL "Fuzz" testing + +DataFusion uses the [SQLancer] for "fuzz" testing: it generates random SQL +queries and execute them against DataFusion to find bugs. + +The code is in the [datafusion-sqllancer] repository, and we welcome further +contributions. Kudos to [@2010YOUY01] for the initial implementation. + +[sqlancer]: https://github.com/sqlancer/sqlancer +[datafusion-sqllancer]: https://github.com/datafusion-contrib/datafusion-sqllancer +[@2010youy01]: https://github.com/2010YOUY01 + +## Documentation Examples + +We use Rust [doctest] to verify examples from the documentation are correct and +up-to-date. These tests are run as part of our CI and you can run them them +locally with the following command: + +```shell +cargo test --doc +``` + +### API Documentation Examples + +As with other Rust projects, examples in doc comments in `.rs` files are +automatically checked to ensure they work and evolve along with the code. + +### User Guide Documentation + +Rust example code from the user guide (anything marked with \`\`\`rust) is also +tested in the same way using the [doc_comment] crate. See the end of +[core/src/lib.rs] for more details. + +[doctest]: https://doc.rust-lang.org/rust-by-example/testing/doc_testing.html +[doc_comment]: https://docs.rs/doc-comment/latest/doc_comment +[core/src/lib.rs]: https://github.com/apache/datafusion/blob/main/datafusion/core/src/lib.rs#L583 + +## Benchmarks + +### Criterion Benchmarks + +[Criterion](https://docs.rs/criterion/latest/criterion/index.html) is a statistics-driven micro-benchmarking framework used by DataFusion for evaluating the performance of specific code-paths. In particular, the criterion benchmarks help to both guide optimisation efforts, and prevent performance regressions within DataFusion. + +Criterion integrates with Cargo's built-in [benchmark support](https://doc.rust-lang.org/cargo/commands/cargo-bench.html) and a given benchmark can be run with + +``` +cargo bench --bench BENCHMARK_NAME +``` + +A full list of benchmarks can be found [here](https://github.com/apache/datafusion/tree/main/datafusion/core/benches). + +_[cargo-criterion](https://github.com/bheisler/cargo-criterion) may also be used for more advanced reporting._ + +### Parquet SQL Benchmarks + +The parquet SQL benchmarks can be run with + +``` + cargo bench --bench parquet_query_sql +``` + +These randomly generate a parquet file, and then benchmark queries sourced from [parquet_query_sql.sql](../../../datafusion/core/benches/parquet_query_sql.sql) against it. This can therefore be a quick way to add coverage of particular query and/or data paths. + +If the environment variable `PARQUET_FILE` is set, the benchmark will run queries against this file instead of a randomly generated one. This can be useful for performing multiple runs, potentially with different code, against the same source data, or for testing against a custom dataset. + +The benchmark will automatically remove any generated parquet file on exit, however, if interrupted (e.g. by CTRL+C) it will not. This can be useful for analysing the particular file after the fact, or preserving it to use with `PARQUET_FILE` in subsequent runs. + +### Comparing Baselines + +By default, Criterion.rs will compare the measurements against the previous run (if any). Sometimes it's useful to keep a set of measurements around for several runs. For example, you might want to make multiple changes to the code while comparing against the master branch. For this situation, Criterion.rs supports custom baselines. + +``` + git checkout main + cargo bench --bench sql_planner -- --save-baseline main + git checkout YOUR_BRANCH + cargo bench --bench sql_planner -- --baseline main +``` + +Note: For MacOS it may be required to run `cargo bench` with `sudo` + +``` +sudo cargo bench ... +``` + +More information on [Baselines](https://bheisler.github.io/criterion.rs/book/user_guide/command_line_options.html#baselines) + +### Upstream Benchmark Suites + +Instructions and tooling for running upstream benchmark suites against DataFusion can be found in [benchmarks](https://github.com/apache/datafusion/tree/main/benchmarks). + +These are valuable for comparative evaluation against alternative Arrow implementations and query engines. diff --git a/docs/source/download.md b/docs/source/download.md new file mode 100644 index 000000000000..33a6d7008877 --- /dev/null +++ b/docs/source/download.md @@ -0,0 +1,69 @@ + + +# Download + +While DataFusion is also distributed via the Rust [crates.io] package manager as a convenience, the +official Apache DataFusion releases are provided as source artifacts. + +[crates.io]: https://crates.io/crates/datafusion + +## Releases + +The latest source release is [41.0.0][source-link] ([asc][asc-link], +[sha512][sha512-link]). + +[source-link]: https://www.apache.org/dyn/closer.lua/datafusion/datafusion-41.0.0/apache-datafusion-41.0.0.tar.gz?action=download +[asc-link]: https://downloads.apache.org/datafusion/datafusion-41.0.0/apache-datafusion-41.0.0.tar.gz.asc +[sha512-link]: https://downloads.apache.org/datafusion/datafusion-41.0.0/apache-datafusion-41.0.0.tar.gz.sha512 + +For previous releases, please check the [archive](https://archive.apache.org/dist/datafusion/). + +For releases earlier than 37.0.0, please check [Arrow's archive](https://archive.apache.org/dist/arrow/). + +## Notes + +- When downloading a release, please verify the OpenPGP compatible signature (or failing that, check the SHA-512); these should be fetched from the main Apache site. +- The KEYS file contains the public keys used for signing release. It is recommended that (when possible) a web of trust is used to confirm the identity of these keys. +- Please download the [KEYS](https://downloads.apache.org/datafusion/KEYS) as well as the .asc signature files. + +### To verify the signature of the release artifact + +You will need to download both the release artifact and the .asc signature file for that artifact. Then verify the signature by: + +- Download the KEYS file and the .asc signature files for the relevant release artifacts. +- Import the KEYS file to your GPG keyring: + + ```shell + gpg --import KEYS + ``` + +- Verify the signature of the release artifact using the following command: + + ```shell + gpg --verify .asc + ``` + +### To verify the checksum of the release artifact + +You will need to download both the release artifact and the .sha512 checksum file for that artifact. Then verify the checksum by: + +```shell +shasum -a 512 -c .sha512 +``` diff --git a/docs/source/index.rst b/docs/source/index.rst index 5944d346ca95..9008950d3dd6 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -32,22 +32,45 @@ Apache DataFusion Fork

-DataFusion is a very fast, extensible query engine for building high-quality data-centric systems in -`Rust `_, using the `Apache Arrow `_ -in-memory format. -DataFusion offers SQL and Dataframe APIs, excellent -`performance `_, built-in support for -CSV, Parquet, JSON, and Avro, extensive customization, and a great -community. +DataFusion is an extensible query engine written in `Rust `_ that +uses `Apache Arrow `_ as its in-memory format. -The `example usage`_ section in the user guide and the `datafusion-examples`_ code in the crate contain information on using DataFusion. +The documentation on this site is for the `core DataFusion project `_, which contains +libraries and binaries for developers building fast and feature rich database and analytic systems, +customized to particular workloads. See `use cases `_ for examples. -Please see the `developer’s guide`_ for contributing and `communication`_ for getting in touch with us. +The following related subprojects target end users and have separate documentation. + +- `DataFusion Python `_ offers a Python interface for SQL and DataFrame + queries. +- `DataFusion Ray `_ provides a distributed version of DataFusion + that scales out on `Ray `_ clusters. +- `DataFusion Comet `_ is an accelerator for Apache Spark based on + DataFusion. + +"Out of the box," DataFusion offers `SQL `_ +and `Dataframe `_ APIs, +excellent `performance `_, built-in support for CSV, Parquet, JSON, and Avro, +extensive customization, and a great community. +`Python Bindings `_ are also available. + +DataFusion features a full query planner, a columnar, streaming, multi-threaded, +vectorized execution engine, and partitioned data sources. You can +customize DataFusion at almost all points including additional data sources, +query languages, functions, custom operators and more. +See the `Architecture `_ section for more details. + +To get started, see + +* The `example usage`_ section of the user guide and the `datafusion-examples`_ directory. +* The `library user guide`_ for examples of using DataFusion's extension APIs +* The `developer’s guide`_ for contributing and `communication`_ for getting in touch with us. .. _example usage: user-guide/example-usage.html -.. _datafusion-examples: https://github.com/apache/datafusion/tree/master/datafusion-examples +.. _datafusion-examples: https://github.com/apache/datafusion/tree/main/datafusion-examples .. _developer’s guide: contributor-guide/index.html#developer-s-guide +.. _library user guide: library-user-guide/index.html .. _communication: contributor-guide/communication.html .. _toc.asf-links: @@ -66,10 +89,12 @@ Please see the `developer’s guide`_ for contributing and `communication`_ for :maxdepth: 1 :caption: Links - Github and Issue Tracker + GitHub and Issue Tracker crates.io API Docs + Blog Code of conduct + Download .. _toc.guide: .. toctree:: @@ -78,11 +103,13 @@ Please see the `developer’s guide`_ for contributing and `communication`_ for user-guide/introduction user-guide/example-usage + user-guide/crate-configuration user-guide/cli/index user-guide/dataframe user-guide/expressions user-guide/sql/index user-guide/configs + user-guide/explain-usage user-guide/faq .. _toc.library-user-guide: @@ -92,6 +119,7 @@ Please see the `developer’s guide`_ for contributing and `communication`_ for :caption: Library User Guide library-user-guide/index + library-user-guide/extensions library-user-guide/using-the-sql-api library-user-guide/working-with-exprs library-user-guide/using-the-dataframe-api @@ -101,7 +129,8 @@ Please see the `developer’s guide`_ for contributing and `communication`_ for library-user-guide/custom-table-providers library-user-guide/extending-operators library-user-guide/profiling - + library-user-guide/query-optimizer + library-user-guide/api-health .. _toc.contributor-guide: .. toctree:: @@ -110,7 +139,21 @@ Please see the `developer’s guide`_ for contributing and `communication`_ for contributor-guide/index contributor-guide/communication + contributor-guide/getting_started contributor-guide/architecture + contributor-guide/testing + contributor-guide/howtos contributor-guide/roadmap - contributor-guide/quarterly_roadmap + contributor-guide/governance + contributor-guide/inviting contributor-guide/specification/index + +.. _toc.subprojects: + +.. toctree:: + :maxdepth: 1 + :caption: DataFusion Subprojects + + DataFusion Ballista + DataFusion Comet + DataFusion Python diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md index f805f0a99292..fe3990b90c3c 100644 --- a/docs/source/library-user-guide/adding-udfs.md +++ b/docs/source/library-user-guide/adding-udfs.md @@ -268,7 +268,7 @@ impl PartitionEvaluator for MyPartitionEvaluator { } } -/// Create a `PartitionEvalutor` to evaluate this function on a new +/// Create a `PartitionEvaluator` to evaluate this function on a new /// partition. fn make_partition_evaluator() -> Result> { Ok(Box::new(MyPartitionEvaluator::new())) @@ -474,7 +474,7 @@ impl Accumulator for GeometricMean { ### registering an Aggregate UDF -To register a Aggreate UDF, you need to wrap the function implementation in a [`AggregateUDF`] struct and then register it with the `SessionContext`. DataFusion provides the [`create_udaf`] helper functions to make this easier. +To register a Aggregate UDF, you need to wrap the function implementation in a [`AggregateUDF`] struct and then register it with the `SessionContext`. DataFusion provides the [`create_udaf`] helper functions to make this easier. There is a lower level API with more functionality but is more complex, that is documented in [`advanced_udaf.rs`]. ```rust diff --git a/docs/source/library-user-guide/api-health.md b/docs/source/library-user-guide/api-health.md new file mode 100644 index 000000000000..943a370e8172 --- /dev/null +++ b/docs/source/library-user-guide/api-health.md @@ -0,0 +1,37 @@ + + +# API health policy + +To maintain API health, developers must track and properly deprecate outdated methods. +When deprecating a method: + +- clearly mark the API as deprecated and specify the exact DataFusion version in which it was deprecated. +- concisely describe the preferred API, if relevant + +API deprecation example: + +```rust + #[deprecated(since = "41.0.0", note = "Use SessionStateBuilder")] + pub fn new_with_config_rt(config: SessionConfig, runtime: Arc) -> Self +``` + +Deprecated methods will remain in the codebase for a period of 6 major versions or 6 months, whichever is longer, to provide users ample time to transition away from them. + +Please refer to [DataFusion releases](https://crates.io/crates/datafusion/versions) to plan ahead API migration diff --git a/docs/source/library-user-guide/building-logical-plans.md b/docs/source/library-user-guide/building-logical-plans.md index fe922d8eaeb1..556deb02e980 100644 --- a/docs/source/library-user-guide/building-logical-plans.md +++ b/docs/source/library-user-guide/building-logical-plans.md @@ -31,44 +31,52 @@ explained in more detail in the [Query Planning and Execution Overview] section DataFusion's [LogicalPlan] is an enum containing variants representing all the supported operators, and also contains an `Extension` variant that allows projects building on DataFusion to add custom logical operators. -It is possible to create logical plans by directly creating instances of the [LogicalPlan] enum as follows, but is is +It is possible to create logical plans by directly creating instances of the [LogicalPlan] enum as shown, but it is much easier to use the [LogicalPlanBuilder], which is described in the next section. Here is an example of building a logical plan directly: - - ```rust -// create a logical table source -let schema = Schema::new(vec![ - Field::new("id", DataType::Int32, true), - Field::new("name", DataType::Utf8, true), -]); -let table_source = LogicalTableSource::new(SchemaRef::new(schema)); - -// create a TableScan plan -let projection = None; // optional projection -let filters = vec![]; // optional filters to push down -let fetch = None; // optional LIMIT -let table_scan = LogicalPlan::TableScan(TableScan::try_new( - "person", - Arc::new(table_source), - projection, - filters, - fetch, -)?); - -// create a Filter plan that evaluates `id > 500` that wraps the TableScan -let filter_expr = col("id").gt(lit(500)); -let plan = LogicalPlan::Filter(Filter::try_new(filter_expr, Arc::new(table_scan))?); - -// print the plan -println!("{}", plan.display_indent_schema()); +use datafusion::common::DataFusionError; +use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::logical_expr::{Filter, LogicalPlan, TableScan, LogicalTableSource}; +use datafusion::prelude::*; +use std::sync::Arc; + +fn main() -> Result<(), DataFusionError> { + // create a logical table source + let schema = Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("name", DataType::Utf8, true), + ]); + let table_source = LogicalTableSource::new(SchemaRef::new(schema)); + + // create a TableScan plan + let projection = None; // optional projection + let filters = vec![]; // optional filters to push down + let fetch = None; // optional LIMIT + let table_scan = LogicalPlan::TableScan(TableScan::try_new( + "person", + Arc::new(table_source), + projection, + filters, + fetch, + )? + ); + + // create a Filter plan that evaluates `id > 500` that wraps the TableScan + let filter_expr = col("id").gt(lit(500)); + let plan = LogicalPlan::Filter(Filter::try_new(filter_expr, Arc::new(table_scan)) ? ); + + // print the plan + println!("{}", plan.display_indent_schema()); + Ok(()) +} ``` This example produces the following plan: -``` +```text Filter: person.id > Int32(500) [id:Int32;N, name:Utf8;N] TableScan: person [id:Int32;N, name:Utf8;N] ``` @@ -78,7 +86,7 @@ Filter: person.id > Int32(500) [id:Int32;N, name:Utf8;N] DataFusion logical plans can be created using the [LogicalPlanBuilder] struct. There is also a [DataFrame] API which is a higher-level API that delegates to [LogicalPlanBuilder]. -The following associated functions can be used to create a new builder: +There are several functions that can can be used to create a new builder, such as - `empty` - create an empty plan with no fields - `values` - create a plan from a set of literal values @@ -102,41 +110,107 @@ The following example demonstrates building the same simple query plan as the pr ```rust -// create a logical table source -let schema = Schema::new(vec![ - Field::new("id", DataType::Int32, true), - Field::new("name", DataType::Utf8, true), -]); -let table_source = LogicalTableSource::new(SchemaRef::new(schema)); - -// optional projection -let projection = None; - -// create a LogicalPlanBuilder for a table scan -let builder = LogicalPlanBuilder::scan("person", Arc::new(table_source), projection)?; - -// perform a filter operation and build the plan -let plan = builder - .filter(col("id").gt(lit(500)))? // WHERE id > 500 - .build()?; - -// print the plan -println!("{}", plan.display_indent_schema()); +use datafusion::common::DataFusionError; +use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::logical_expr::{LogicalPlanBuilder, LogicalTableSource}; +use datafusion::prelude::*; +use std::sync::Arc; + +fn main() -> Result<(), DataFusionError> { + // create a logical table source + let schema = Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("name", DataType::Utf8, true), + ]); + let table_source = LogicalTableSource::new(SchemaRef::new(schema)); + + // optional projection + let projection = None; + + // create a LogicalPlanBuilder for a table scan + let builder = LogicalPlanBuilder::scan("person", Arc::new(table_source), projection)?; + + // perform a filter operation and build the plan + let plan = builder + .filter(col("id").gt(lit(500)))? // WHERE id > 500 + .build()?; + + // print the plan + println!("{}", plan.display_indent_schema()); + Ok(()) +} ``` This example produces the following plan: -``` +```text Filter: person.id > Int32(500) [id:Int32;N, name:Utf8;N] TableScan: person [id:Int32;N, name:Utf8;N] ``` +## Translating Logical Plan to Physical Plan + +Logical plans can not be directly executed. They must be "compiled" into an +[`ExecutionPlan`], which is often referred to as a "physical plan". + +Compared to `LogicalPlan`s `ExecutionPlans` have many more details such as +specific algorithms and detailed optimizations compared to. Given a +`LogicalPlan` the easiest way to create an `ExecutionPlan` is using +[`SessionState::create_physical_plan`] as shown below + +```rust +use datafusion::datasource::{provider_as_source, MemTable}; +use datafusion::common::DataFusionError; +use datafusion::physical_plan::display::DisplayableExecutionPlan; +use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::logical_expr::{LogicalPlanBuilder, LogicalTableSource}; +use datafusion::prelude::*; +use std::sync::Arc; + +// Creating physical plans may access remote catalogs and data sources +// thus it must be run with an async runtime. +#[tokio::main] +async fn main() -> Result<(), DataFusionError> { + + // create a default table source + let schema = Schema::new(vec![ + Field::new("id", DataType::Int32, true), + Field::new("name", DataType::Utf8, true), + ]); + // To create an ExecutionPlan we must provide an actual + // TableProvider. For this example, we don't provide any data + // but in production code, this would have `RecordBatch`es with + // in memory data + let table_provider = Arc::new(MemTable::try_new(Arc::new(schema), vec![])?); + // Use the provider_as_source function to convert the TableProvider to a table source + let table_source = provider_as_source(table_provider); + + // create a LogicalPlanBuilder for a table scan without projection or filters + let logical_plan = LogicalPlanBuilder::scan("person", table_source, None)?.build()?; + + // Now create the physical plan by calling `create_physical_plan` + let ctx = SessionContext::new(); + let physical_plan = ctx.state().create_physical_plan(&logical_plan).await?; + + // print the plan + println!("{}", DisplayableExecutionPlan::new(physical_plan.as_ref()).indent(true)); + Ok(()) +} +``` + +This example produces the following physical plan: + +```text +MemoryExec: partitions=0, partition_sizes=[] +``` + ## Table Sources -The previous example used a [LogicalTableSource], which is used for tests and documentation in DataFusion, and is also -suitable if you are using DataFusion to build logical plans but do not use DataFusion's physical planner. However, if you -want to use a [TableSource] that can be executed in DataFusion then you will need to use [DefaultTableSource], which is a -wrapper for a [TableProvider]. +The previous examples use a [LogicalTableSource], which is used for tests and documentation in DataFusion, and is also +suitable if you are using DataFusion to build logical plans but do not use DataFusion's physical planner. + +However, it is more common to use a [TableProvider]. To get a [TableSource] from a +[TableProvider], use [provider_as_source] or [DefaultTableSource]. [query planning and execution overview]: https://docs.rs/datafusion/latest/datafusion/index.html#query-planning-and-execution-overview [architecture guide]: https://docs.rs/datafusion/latest/datafusion/index.html#architecture @@ -145,5 +219,8 @@ wrapper for a [TableProvider]. [dataframe]: using-the-dataframe-api.md [logicaltablesource]: https://docs.rs/datafusion-expr/latest/datafusion_expr/logical_plan/builder/struct.LogicalTableSource.html [defaulttablesource]: https://docs.rs/datafusion/latest/datafusion/datasource/default_table_source/struct.DefaultTableSource.html +[provider_as_source]: https://docs.rs/datafusion/latest/datafusion/datasource/default_table_source/fn.provider_as_source.html [tableprovider]: https://docs.rs/datafusion/latest/datafusion/datasource/provider/trait.TableProvider.html [tablesource]: https://docs.rs/datafusion-expr/latest/datafusion_expr/trait.TableSource.html +[`executionplan`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html +[`sessionstate::create_physical_plan`]: https://docs.rs/datafusion/latest/datafusion/execution/session_state/struct.SessionState.html#method.create_physical_plan diff --git a/docs/source/library-user-guide/custom-table-providers.md b/docs/source/library-user-guide/custom-table-providers.md index f53ac6cfae97..f86cea0bda95 100644 --- a/docs/source/library-user-guide/custom-table-providers.md +++ b/docs/source/library-user-guide/custom-table-providers.md @@ -112,7 +112,7 @@ impl CustomDataSource { impl TableProvider for CustomDataSource { async fn scan( &self, - _state: &SessionState, + _state: &dyn Session, projection: Option<&Vec>, // filters and limit can be used here to inject some push-down operations if needed _filters: &[Expr], @@ -146,7 +146,7 @@ For filters that can be pushed down, they'll be passed to the `scan` method as t In order to use the custom table provider, we need to register it with DataFusion. This is done by creating a `TableProvider` and registering it with the `SessionContext`. ```rust -let mut ctx = SessionContext::new(); +let ctx = SessionContext::new(); let custom_table_provider = CustomDataSource::new(); ctx.register_table("custom_table", Arc::new(custom_table_provider)); diff --git a/docs/source/library-user-guide/extensions.md b/docs/source/library-user-guide/extensions.md new file mode 100644 index 000000000000..0c7c891f4b1e --- /dev/null +++ b/docs/source/library-user-guide/extensions.md @@ -0,0 +1,64 @@ + + +# Extensions List + +DataFusion tries to provide a good set of features "out of the box" to quickly +start with a working system, but it can't include every useful feature (e.g. +`TableProvider`s for all data formats). + +Thankfully one of the core features of DataFusion is a flexible extension API +that allows users to extend its behavior at all points. This page lists some +community maintained extensions available for DataFusion. These extensions are +not part of the core DataFusion project, and not under Apache Software +Foundation governance but we list them here to be useful to others in the +community. + +If you know of an available extension that is not listed below, please open a PR +to add it to this page. If there is some feature you would like to see in +DataFusion, please consider creating a new extension in the `datafusion-contrib` +project (see [below](#datafusion-contrib)). Please [contact] us via github issue, slack, or Discord and +we'll gladly set up a new repository for your extension. + +| Name | Type | Description | +| ---------------------------- | ----------------- | --------------------------------------------------------------------------------- | +| [DataFusion Table Providers] | [`TableProvider`] | Support for `PostgreSQL`, `MySQL`, `SQLite`, `DuckDB`, and `Flight SQL` | +| [DataFusion Federation] | Framework | Allows DataFusion to execute (part of) a query plan by a remote execution engine. | +| [DataFusion ORC] | [`TableProvider`] | [Apache ORC] file format | +| [DataFusion JSON Functions] | Functions | Scalar functions for querying JSON strings | + +[`tableprovider`]: https://docs.rs/datafusion/latest/datafusion/catalog/trait.TableProvider.html +[datafusion table providers]: https://github.com/datafusion-contrib/datafusion-table-providers +[datafusion federation]: https://github.com/datafusion-contrib/datafusion-federation +[datafusion orc]: https://github.com/datafusion-contrib/datafusion-orc +[apache orc]: https://orc.apache.org/ +[datafusion json functions]: https://github.com/datafusion-contrib/datafusion-functions-json + +## `datafusion-contrib` + +The [`datafusion-contrib`] project contains a collection of community maintained +extensions that are not part of the core DataFusion project, and not under +Apache Software Foundation governance but may be useful to others in the +community. If you are interested adding a feature to DataFusion, a new extension +in `datafusion-contrib` is likely a good place to start. Please [contact] us via +github issue, slack, or Discord and we'll gladly set up a new repository for +your extension. + +[`datafusion-contrib`]: https://github.com/datafusion-contrib +[contact]: ../contributor-guide/communication.md diff --git a/docs/source/library-user-guide/index.md b/docs/source/library-user-guide/index.md index 47257e0c926e..fd126a1120ed 100644 --- a/docs/source/library-user-guide/index.md +++ b/docs/source/library-user-guide/index.md @@ -19,8 +19,25 @@ # Introduction -The library user guide explains how to use the DataFusion library as a dependency in your Rust project. Please check out the user-guide for more details on how to use DataFusion's SQL and DataFrame APIs, or the contributor guide for details on how to contribute to DataFusion. +The library user guide explains how to use the DataFusion library as a +dependency in your Rust project and customize its behavior using its extension APIs. -If you haven't reviewed the [architecture section in the docs][docs], it's a useful place to get the lay of the land before starting down a specific path. +Please check out the [user guide] for getting started using +DataFusion's SQL and DataFrame APIs, or the [contributor guide] +for details on how to contribute to DataFusion. +If you haven't reviewed the [architecture section in the docs][docs], it's a +useful place to get the lay of the land before starting down a specific path. + +DataFusion is designed to be extensible at all points, including + +- [x] User Defined Functions (UDFs) +- [x] User Defined Aggregate Functions (UDAFs) +- [x] User Defined Table Source (`TableProvider`) for tables +- [x] User Defined `Optimizer` passes (plan rewrites) +- [x] User Defined `LogicalPlan` nodes +- [x] User Defined `ExecutionPlan` nodes + +[user guide]: ../user-guide/example-usage.md +[contributor guide]: ../contributor-guide/index.md [docs]: https://docs.rs/datafusion/latest/datafusion/#architecture diff --git a/docs/source/library-user-guide/profiling.md b/docs/source/library-user-guide/profiling.md index c8afe15f25a8..75f2394a22da 100644 --- a/docs/source/library-user-guide/profiling.md +++ b/docs/source/library-user-guide/profiling.md @@ -25,34 +25,44 @@ The section contains examples how to perform CPU profiling for Apache DataFusion ### Building a flamegraph -- [cargo-flamegraph](https://github.com/flamegraph-rs/flamegraph) +[Video: how to CPU profile DataFusion with a Flamegraph](https://youtu.be/2z11xtYw_xs) -Test: +A flamegraph is a visual representation of which functions are being run +You can create flamegraphs in many ways; The instructions below are for +[cargo-flamegraph](https://github.com/flamegraph-rs/flamegraph) which results +in images such as this: -```bash -CARGO_PROFILE_RELEASE_DEBUG=true cargo flamegraph --root --unit-test datafusion -- dataframe::tests::test_array_agg +![Flamegraph](../_static/images/flamegraph.svg) + +To create a flamegraph, you need to install the `flamegraph` tool: + +```shell +cargo install flamegraph ``` -Benchmark: +Then you can run the flamegraph tool with the `--` separator to pass arguments +to the binary you want to profile. -```bash -CARGO_PROFILE_RELEASE_DEBUG=true cargo flamegraph --root --bench sql_planner -- --bench +Example: Flamegraph for `datafusion-cli` executing `q28.sql`. Note this +must be run as root on Mac OSx to access DTrace. + +```shell +sudo flamegraph -- datafusion-cli -f q28.sq ``` -Open `flamegraph.svg` file with the browser +You can also invoke the flamegraph tool with `cargo` to profile a specific test or benchmark. -- dtrace with DataFusion CLI +Example: Flamegraph for a specific test: ```bash -git clone https://github.com/brendangregg/FlameGraph.git /tmp/fg -cd datafusion-cli -CARGO_PROFILE_RELEASE_DEBUG=true cargo build --release -echo "select * from table;" >> test.sql -sudo dtrace -c './target/debug/datafusion-cli -f test.sql' -o out.stacks -n 'profile-997 /execname == "datafusion-cli"/ { @[ustack(100)] = count(); }' -/tmp/fg/FlameGraph/stackcollapse.pl out.stacks | /tmp/fg/FlameGraph/flamegraph.pl > flamegraph.svg +CARGO_PROFILE_RELEASE_DEBUG=true cargo flamegraph --root --unit-test datafusion -- dataframe::tests::test_array_agg ``` -Open `flamegraph.svg` file with the browser +Example: Flamegraph for a benchmark + +```bash +CARGO_PROFILE_RELEASE_DEBUG=true cargo flamegraph --root --bench sql_planner -- --bench +``` ### CPU profiling with XCode Instruments diff --git a/docs/source/library-user-guide/query-optimizer.md b/docs/source/library-user-guide/query-optimizer.md new file mode 100644 index 000000000000..5aacfaf59cb1 --- /dev/null +++ b/docs/source/library-user-guide/query-optimizer.md @@ -0,0 +1,336 @@ + + +# DataFusion Query Optimizer + +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory +format. + +DataFusion has modular design, allowing individual crates to be re-used in other projects. + +This crate is a submodule of DataFusion that provides a query optimizer for logical plans, and +contains an extensive set of OptimizerRules that may rewrite the plan and/or its expressions so +they execute more quickly while still computing the same result. + +## Running the Optimizer + +The following code demonstrates the basic flow of creating the optimizer with a default set of optimization rules +and applying it to a logical plan to produce an optimized logical plan. + +```rust + +// We need a logical plan as the starting point. There are many ways to build a logical plan: +// +// The `datafusion-expr` crate provides a LogicalPlanBuilder +// The `datafusion-sql` crate provides a SQL query planner that can create a LogicalPlan from SQL +// The `datafusion` crate provides a DataFrame API that can create a LogicalPlan +let logical_plan = ... + +let mut config = OptimizerContext::default(); +let optimizer = Optimizer::new(&config); +let optimized_plan = optimizer.optimize(&logical_plan, &config, observe)?; + +fn observe(plan: &LogicalPlan, rule: &dyn OptimizerRule) { + println!( + "After applying rule '{}':\n{}", + rule.name(), + plan.display_indent() + ) +} +``` + +## Providing Custom Rules + +The optimizer can be created with a custom set of rules. + +```rust +let optimizer = Optimizer::with_rules(vec![ + Arc::new(MyRule {}) +]); +``` + +## Writing Optimization Rules + +Please refer to the +[optimizer_rule.rs](../../datafusion-examples/examples/optimizer_rule.rs) +example to learn more about the general approach to writing optimizer rules and +then move onto studying the existing rules. + +All rules must implement the `OptimizerRule` trait. + +```rust +/// `OptimizerRule` transforms one ['LogicalPlan'] into another which +/// computes the same results, but in a potentially more efficient +/// way. If there are no suitable transformations for the input plan, +/// the optimizer can simply return it as is. +pub trait OptimizerRule { + /// Rewrite `plan` to an optimized form + fn optimize( + &self, + plan: &LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result; + + /// A human readable name for this optimizer rule + fn name(&self) -> &str; +} +``` + +### General Guidelines + +Rules typical walk the logical plan and walk the expression trees inside operators and selectively mutate +individual operators or expressions. + +Sometimes there is an initial pass that visits the plan and builds state that is used in a second pass that performs +the actual optimization. This approach is used in projection push down and filter push down. + +### Expression Naming + +Every expression in DataFusion has a name, which is used as the column name. For example, in this example the output +contains a single column with the name `"COUNT(aggregate_test_100.c9)"`: + +```text +> select count(c9) from aggregate_test_100; ++------------------------------+ +| COUNT(aggregate_test_100.c9) | ++------------------------------+ +| 100 | ++------------------------------+ +``` + +These names are used to refer to the columns in both subqueries as well as internally from one stage of the LogicalPlan +to another. For example: + +```text +> select "COUNT(aggregate_test_100.c9)" + 1 from (select count(c9) from aggregate_test_100) as sq; ++--------------------------------------------+ +| sq.COUNT(aggregate_test_100.c9) + Int64(1) | ++--------------------------------------------+ +| 101 | ++--------------------------------------------+ +``` + +### Implication + +Because DataFusion identifies columns using a string name, it means it is critical that the names of expressions are +not changed by the optimizer when it rewrites expressions. This is typically accomplished by renaming a rewritten +expression by adding an alias. + +Here is a simple example of such a rewrite. The expression `1 + 2` can be internally simplified to 3 but must still be +displayed the same as `1 + 2`: + +```text +> select 1 + 2; ++---------------------+ +| Int64(1) + Int64(2) | ++---------------------+ +| 3 | ++---------------------+ +``` + +Looking at the `EXPLAIN` output we can see that the optimizer has effectively rewritten `1 + 2` into effectively +`3 as "1 + 2"`: + +```text +> explain select 1 + 2; ++---------------+-------------------------------------------------+ +| plan_type | plan | ++---------------+-------------------------------------------------+ +| logical_plan | Projection: Int64(3) AS Int64(1) + Int64(2) | +| | EmptyRelation | +| physical_plan | ProjectionExec: expr=[3 as Int64(1) + Int64(2)] | +| | PlaceholderRowExec | +| | | ++---------------+-------------------------------------------------+ +``` + +If the expression name is not preserved, bugs such as [#3704](https://github.com/apache/datafusion/issues/3704) +and [#3555](https://github.com/apache/datafusion/issues/3555) occur where the expected columns can not be found. + +### Building Expression Names + +There are currently two ways to create a name for an expression in the logical plan. + +```rust +impl Expr { + /// Returns the name of this expression as it should appear in a schema. This name + /// will not include any CAST expressions. + pub fn display_name(&self) -> Result { + create_name(self) + } + + /// Returns a full and complete string representation of this expression. + pub fn canonical_name(&self) -> String { + format!("{}", self) + } +} +``` + +When comparing expressions to determine if they are equivalent, `canonical_name` should be used, and when creating a +name to be used in a schema, `display_name` should be used. + +### Utilities + +There are a number of utility methods provided that take care of some common tasks. + +### ExprVisitor + +The `ExprVisitor` and `ExprVisitable` traits provide a mechanism for applying a visitor pattern to an expression tree. + +Here is an example that demonstrates this. + +```rust +fn extract_subquery_filters(expression: &Expr, extracted: &mut Vec) -> Result<()> { + struct InSubqueryVisitor<'a> { + accum: &'a mut Vec, + } + + impl ExpressionVisitor for InSubqueryVisitor<'_> { + fn pre_visit(self, expr: &Expr) -> Result> { + if let Expr::InSubquery(_) = expr { + self.accum.push(expr.to_owned()); + } + Ok(Recursion::Continue(self)) + } + } + + expression.accept(InSubqueryVisitor { accum: extracted })?; + Ok(()) +} +``` + +### Rewriting Expressions + +The `MyExprRewriter` trait can be implemented to provide a way to rewrite expressions. This rule can then be applied +to an expression by calling `Expr::rewrite` (from the `ExprRewritable` trait). + +The `rewrite` method will perform a depth first walk of the expression and its children to rewrite an expression, +consuming `self` producing a new expression. + +```rust +let mut expr_rewriter = MyExprRewriter {}; +let expr = expr.rewrite(&mut expr_rewriter)?; +``` + +Here is an example implementation which will rewrite `expr BETWEEN a AND b` as `expr >= a AND expr <= b`. Note that the +implementation does not need to perform any recursion since this is handled by the `rewrite` method. + +```rust +struct MyExprRewriter {} + +impl ExprRewriter for MyExprRewriter { + fn mutate(&mut self, expr: Expr) -> Result { + match expr { + Expr::Between { + negated, + expr, + low, + high, + } => { + let expr: Expr = expr.as_ref().clone(); + let low: Expr = low.as_ref().clone(); + let high: Expr = high.as_ref().clone(); + if negated { + Ok(expr.clone().lt(low).or(expr.clone().gt(high))) + } else { + Ok(expr.clone().gt_eq(low).and(expr.clone().lt_eq(high))) + } + } + _ => Ok(expr.clone()), + } + } +} +``` + +### optimize_children + +Typically a rule is applied recursively to all operators within a query plan. Rather than duplicate +that logic in each rule, an `optimize_children` method is provided. This recursively invokes the `optimize` method on +the plan's children and then returns a node of the same type. + +```rust +fn optimize( + &self, + plan: &LogicalPlan, + _config: &mut OptimizerConfig, +) -> Result { + // recurse down and optimize children first + let plan = utils::optimize_children(self, plan, _config)?; + + ... +} +``` + +### Writing Tests + +There should be unit tests in the same file as the new rule that test the effect of the rule being applied to a plan +in isolation (without any other rule being applied). + +There should also be a test in `integration-tests.rs` that tests the rule as part of the overall optimization process. + +### Debugging + +The `EXPLAIN VERBOSE` command can be used to show the effect of each optimization rule on a query. + +In the following example, the `type_coercion` and `simplify_expressions` passes have simplified the plan so that it returns the constant `"3.2"` rather than doing a computation at execution time. + +```text +> explain verbose select cast(1 + 2.2 as string) as foo; ++------------------------------------------------------------+---------------------------------------------------------------------------+ +| plan_type | plan | ++------------------------------------------------------------+---------------------------------------------------------------------------+ +| initial_logical_plan | Projection: CAST(Int64(1) + Float64(2.2) AS Utf8) AS foo | +| | EmptyRelation | +| logical_plan after type_coercion | Projection: CAST(CAST(Int64(1) AS Float64) + Float64(2.2) AS Utf8) AS foo | +| | EmptyRelation | +| logical_plan after simplify_expressions | Projection: Utf8("3.2") AS foo | +| | EmptyRelation | +| logical_plan after unwrap_cast_in_comparison | SAME TEXT AS ABOVE | +| logical_plan after decorrelate_where_exists | SAME TEXT AS ABOVE | +| logical_plan after decorrelate_where_in | SAME TEXT AS ABOVE | +| logical_plan after scalar_subquery_to_join | SAME TEXT AS ABOVE | +| logical_plan after subquery_filter_to_join | SAME TEXT AS ABOVE | +| logical_plan after simplify_expressions | SAME TEXT AS ABOVE | +| logical_plan after eliminate_filter | SAME TEXT AS ABOVE | +| logical_plan after reduce_cross_join | SAME TEXT AS ABOVE | +| logical_plan after common_sub_expression_eliminate | SAME TEXT AS ABOVE | +| logical_plan after eliminate_limit | SAME TEXT AS ABOVE | +| logical_plan after projection_push_down | SAME TEXT AS ABOVE | +| logical_plan after rewrite_disjunctive_predicate | SAME TEXT AS ABOVE | +| logical_plan after reduce_outer_join | SAME TEXT AS ABOVE | +| logical_plan after filter_push_down | SAME TEXT AS ABOVE | +| logical_plan after limit_push_down | SAME TEXT AS ABOVE | +| logical_plan after single_distinct_aggregation_to_group_by | SAME TEXT AS ABOVE | +| logical_plan | Projection: Utf8("3.2") AS foo | +| | EmptyRelation | +| initial_physical_plan | ProjectionExec: expr=[3.2 as foo] | +| | PlaceholderRowExec | +| | | +| physical_plan after aggregate_statistics | SAME TEXT AS ABOVE | +| physical_plan after join_selection | SAME TEXT AS ABOVE | +| physical_plan after coalesce_batches | SAME TEXT AS ABOVE | +| physical_plan after repartition | SAME TEXT AS ABOVE | +| physical_plan after add_merge_exec | SAME TEXT AS ABOVE | +| physical_plan | ProjectionExec: expr=[3.2 as foo] | +| | PlaceholderRowExec | +| | | ++------------------------------------------------------------+---------------------------------------------------------------------------+ +``` + +[df]: https://crates.io/crates/datafusion diff --git a/docs/source/library-user-guide/using-the-dataframe-api.md b/docs/source/library-user-guide/using-the-dataframe-api.md index c4f4ecd4f137..7f3e28c255c6 100644 --- a/docs/source/library-user-guide/using-the-dataframe-api.md +++ b/docs/source/library-user-guide/using-the-dataframe-api.md @@ -19,129 +19,268 @@ # Using the DataFrame API -## What is a DataFrame +The [Users Guide] introduces the [`DataFrame`] API and this section describes +that API in more depth. -`DataFrame` in `DataFrame` is modeled after the Pandas DataFrame interface, and is a thin wrapper over LogicalPlan that adds functionality for building and executing those plans. +## What is a DataFrame? -```rust -pub struct DataFrame { - session_state: SessionState, - plan: LogicalPlan, -} -``` - -You can build up `DataFrame`s using its methods, similarly to building `LogicalPlan`s using `LogicalPlanBuilder`: - -```rust -let df = ctx.table("users").await?; +As described in the [Users Guide], DataFusion [`DataFrame`]s are modeled after +the [Pandas DataFrame] interface, and are implemented as thin wrapper over a +[`LogicalPlan`] that adds functionality for building and executing those plans. -// Create a new DataFrame sorted by `id`, `bank_account` -let new_df = df.select(vec![col("id"), col("bank_account")])? - .sort(vec![col("id")])?; - -// Build the same plan using the LogicalPlanBuilder -let plan = LogicalPlanBuilder::from(&df.to_logical_plan()) - .project(vec![col("id"), col("bank_account")])? - .sort(vec![col("id")])? - .build()?; -``` - -You can use `collect` or `execute_stream` to execute the query. +The simplest possible dataframe is one that scans a table and that table can be +in a file or in memory. ## How to generate a DataFrame -You can directly use the `DataFrame` API or generate a `DataFrame` from a SQL query. - -For example, to use `sql` to construct `DataFrame`: +You can construct [`DataFrame`]s programmatically using the API, similarly to +other DataFrame APIs. For example, you can read an in memory `RecordBatch` into +a `DataFrame`: ```rust -let ctx = SessionContext::new(); -// Register the in-memory table containing the data -ctx.register_table("users", Arc::new(create_memtable()?))?; -let dataframe = ctx.sql("SELECT * FROM users;").await?; +use std::sync::Arc; +use datafusion::prelude::*; +use datafusion::arrow::array::{ArrayRef, Int32Array}; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::error::Result; + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = SessionContext::new(); + // Register an in-memory table containing the following data + // id | bank_account + // ---|------------- + // 1 | 9000 + // 2 | 8000 + // 3 | 7000 + let data = RecordBatch::try_from_iter(vec![ + ("id", Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef), + ("bank_account", Arc::new(Int32Array::from(vec![9000, 8000, 7000]))), + ])?; + // Create a DataFrame that scans the user table, and finds + // all users with a bank account at least 8000 + // and sorts the results by bank account in descending order + let dataframe = ctx + .read_batch(data)? + .filter(col("bank_account").gt_eq(lit(8000)))? // bank_account >= 8000 + .sort(vec![col("bank_account").sort(false, true)])?; // ORDER BY bank_account DESC + + Ok(()) +} ``` -To construct `DataFrame` using the API: +You can _also_ generate a `DataFrame` from a SQL query and use the DataFrame's APIs +to manipulate the output of the query. ```rust -let ctx = SessionContext::new(); -// Register the in-memory table containing the data -ctx.register_table("users", Arc::new(create_memtable()?))?; -let dataframe = ctx - .table("users") - .filter(col("a").lt_eq(col("b")))? - .sort(vec![col("a").sort(true, true), col("b").sort(false, false)])?; +use std::sync::Arc; +use datafusion::prelude::*; +use datafusion::assert_batches_eq; +use datafusion::arrow::array::{ArrayRef, Int32Array}; +use datafusion::arrow::record_batch::RecordBatch; +use datafusion::error::Result; + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = SessionContext::new(); + // Register the same in-memory table as the previous example + let data = RecordBatch::try_from_iter(vec![ + ("id", Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef), + ("bank_account", Arc::new(Int32Array::from(vec![9000, 8000, 7000]))), + ])?; + ctx.register_batch("users", data)?; + // Create a DataFrame using SQL + let dataframe = ctx.sql("SELECT * FROM users;") + .await? + // Note we can filter the output of the query using the DataFrame API + .filter(col("bank_account").gt_eq(lit(8000)))?; // bank_account >= 8000 + + let results = &dataframe.collect().await?; + + // use the `assert_batches_eq` macro to show the output + assert_batches_eq!( + vec![ + "+----+--------------+", + "| id | bank_account |", + "+----+--------------+", + "| 1 | 9000 |", + "| 2 | 8000 |", + "+----+--------------+", + ], + &results + ); + Ok(()) +} ``` ## Collect / Streaming Exec -DataFusion `DataFrame`s are "lazy", meaning they do not do any processing until they are executed, which allows for additional optimizations. +DataFusion [`DataFrame`]s are "lazy", meaning they do no processing until +they are executed, which allows for additional optimizations. -When you have a `DataFrame`, you can run it in one of three ways: +You can run a `DataFrame` in one of three ways: -1. `collect` which executes the query and buffers all the output into a `Vec` -2. `streaming_exec`, which begins executions and returns a `SendableRecordBatchStream` which incrementally computes output on each call to `next()` -3. `cache` which executes the query and buffers the output into a new in memory DataFrame. +1. `collect`: executes the query and buffers all the output into a `Vec` +2. `execute_stream`: begins executions and returns a `SendableRecordBatchStream` which incrementally computes output on each call to `next()` +3. `cache`: executes the query and buffers the output into a new in memory `DataFrame.` -You can just collect all outputs once like: +To collect all outputs into a memory buffer, use the `collect` method: ```rust -let ctx = SessionContext::new(); -let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; -let batches = df.collect().await?; +use datafusion::prelude::*; +use datafusion::error::Result; + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = SessionContext::new(); + // read the contents of a CSV file into a DataFrame + let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + // execute the query and collect the results as a Vec + let batches = df.collect().await?; + for record_batch in batches { + println!("{record_batch:?}"); + } + Ok(()) +} ``` -You can also use stream output to incrementally generate output one `RecordBatch` at a time +Use `execute_stream` to incrementally generate output one `RecordBatch` at a time: ```rust -let ctx = SessionContext::new(); -let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; -let mut stream = df.execute_stream().await?; -while let Some(rb) = stream.next().await { - println!("{rb:?}"); +use datafusion::prelude::*; +use datafusion::error::Result; +use futures::stream::StreamExt; + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = SessionContext::new(); + // read example.csv file into a DataFrame + let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + // begin execution (returns quickly, does not compute results) + let mut stream = df.execute_stream().await?; + // results are returned incrementally as they are computed + while let Some(record_batch) = stream.next().await { + println!("{record_batch:?}"); + } + Ok(()) } ``` # Write DataFrame to Files -You can also serialize `DataFrame` to a file. For now, `Datafusion` supports write `DataFrame` to `csv`, `json` and `parquet`. - -When writing a file, DataFusion will execute the DataFrame and stream the results to a file. +You can also write the contents of a `DataFrame` to a file. When writing a file, +DataFusion executes the `DataFrame` and streams the results to the output. +DataFusion comes with support for writing `csv`, `json` `arrow` `avro`, and +`parquet` files, and supports writing custom file formats via API (see +[`custom_file_format.rs`] for an example) -For example, to write a csv_file +For example, to read a CSV file and write it to a parquet file, use the +[`DataFrame::write_parquet`] method ```rust -let ctx = SessionContext::new(); -// Register the in-memory table containing the data -ctx.register_table("users", Arc::new(mem_table))?; -let dataframe = ctx.sql("SELECT * FROM users;").await?; - -dataframe - .write_csv("user_dataframe.csv", DataFrameWriteOptions::default(), None) - .await; +use datafusion::prelude::*; +use datafusion::error::Result; +use datafusion::dataframe::DataFrameWriteOptions; + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = SessionContext::new(); + // read example.csv file into a DataFrame + let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + // stream the contents of the DataFrame to the `example.parquet` file + let target_path = tempfile::tempdir()?.path().join("example.parquet"); + df.write_parquet( + target_path.to_str().unwrap(), + DataFrameWriteOptions::new(), + None, // writer_options + ).await; + Ok(()) +} ``` -and the file will look like (Example Output): +[`custom_file_format.rs`]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/custom_file_format.rs -``` -id,bank_account -1,9000 +The output file will look like (Example Output): + +```sql +> select * from '../datafusion/core/example.parquet'; ++---+---+---+ +| a | b | c | ++---+---+---+ +| 1 | 2 | 3 | ++---+---+---+ ``` -## Transform between LogicalPlan and DataFrame +## Relationship between `LogicalPlan`s and `DataFrame`s -As shown above, `DataFrame` is just a very thin wrapper of `LogicalPlan`, so you can easily go back and forth between them. +The `DataFrame` struct is defined like this: ```rust -// Just combine LogicalPlan with SessionContext and you get a DataFrame -let ctx = SessionContext::new(); -// Register the in-memory table containing the data -ctx.register_table("users", Arc::new(mem_table))?; -let dataframe = ctx.sql("SELECT * FROM users;").await?; +use datafusion::execution::session_state::SessionState; +use datafusion::logical_expr::LogicalPlan; +pub struct DataFrame { + // state required to execute a LogicalPlan + session_state: Box, + // LogicalPlan that describes the computation to perform + plan: LogicalPlan, +} +``` -// get LogicalPlan in dataframe -let plan = dataframe.logical_plan().clone(); +As shown above, `DataFrame` is a thin wrapper of `LogicalPlan`, so you can +easily go back and forth between them. -// construct a DataFrame with LogicalPlan -let new_df = DataFrame::new(ctx.state(), plan); +```rust +use datafusion::prelude::*; +use datafusion::error::Result; +use datafusion::logical_expr::LogicalPlanBuilder; + +#[tokio::main] +async fn main() -> Result<()>{ + let ctx = SessionContext::new(); + // read example.csv file into a DataFrame + let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + // You can easily get the LogicalPlan from the DataFrame + let (_state, plan) = df.into_parts(); + // Just combine LogicalPlan with SessionContext and you get a DataFrame + // get LogicalPlan in dataframe + let new_df = DataFrame::new(ctx.state(), plan); + Ok(()) +} ``` + +In fact, using the [`DataFrame`]s methods you can create the same +[`LogicalPlan`]s as when using [`LogicalPlanBuilder`]: + +```rust +use datafusion::prelude::*; +use datafusion::error::Result; +use datafusion::logical_expr::LogicalPlanBuilder; + +#[tokio::main] +async fn main() -> Result<()>{ + let ctx = SessionContext::new(); + // read example.csv file into a DataFrame + let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + // Create a new DataFrame sorted by `id`, `bank_account` + let new_df = df.select(vec![col("a"), col("b")])? + .sort_by(vec![col("a")])?; + // Build the same plan using the LogicalPlanBuilder + // Similar to `SELECT a, b FROM example.csv ORDER BY a` + let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + let (_state, plan) = df.into_parts(); // get the DataFrame's LogicalPlan + let plan = LogicalPlanBuilder::from(plan) + .project(vec![col("a"), col("b")])? + .sort_by(vec![col("a")])? + .build()?; + // prove they are the same + assert_eq!(new_df.logical_plan(), &plan); + Ok(()) +} +``` + +[users guide]: ../user-guide/dataframe.md +[pandas dataframe]: https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html +[`dataframe`]: https://docs.rs/datafusion/latest/datafusion/dataframe/struct.DataFrame.html +[`logicalplan`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.LogicalPlan.html +[`logicalplanbuilder`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.LogicalPlanBuilder.html +[`dataframe::write_parquet`]: https://docs.rs/datafusion/latest/datafusion/dataframe/struct.DataFrame.html#method.write_parquet diff --git a/docs/source/library-user-guide/using-the-sql-api.md b/docs/source/library-user-guide/using-the-sql-api.md index f4e85ee4e3a9..f78cf16f4cb6 100644 --- a/docs/source/library-user-guide/using-the-sql-api.md +++ b/docs/source/library-user-guide/using-the-sql-api.md @@ -19,4 +19,204 @@ # Using the SQL API +DataFusion has a full SQL API that allows you to interact with DataFusion using +SQL query strings. The simplest way to use the SQL API is to use the +[`SessionContext`] struct which provides the highest-level API for executing SQL +queries. + +To use SQL, you first register your data as a table and then run queries +using the [`SessionContext::sql`] method. For lower level control such as +preventing DDL, you can use [`SessionContext::sql_with_options`] or the +[`SessionState`] APIs + +## Registering Data Sources using `SessionContext::register*` + +The `SessionContext::register*` methods tell DataFusion the name of +the source and how to read data. Once registered, you can execute SQL queries +using the [`SessionContext::sql`] method referring to your data source as a table. + +The [`SessionContext::sql`] method returns a `DataFrame` for ease of +use. See the ["Using the DataFrame API"] section for more information on how to +work with DataFrames. + +### Read a CSV File + +```rust +use datafusion::error::Result; +use datafusion::prelude::*; +use arrow::record_batch::RecordBatch; + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = SessionContext::new(); + // register the "example" table + ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new()).await?; + // create a plan to run a SQL query + let df = ctx.sql("SELECT a, min(b) FROM example WHERE a <= b GROUP BY a LIMIT 100").await?; + // execute the plan and collect the results as Vec + let results: Vec = df.collect().await?; + // Use the assert_batches_eq macro to compare the results with expected output + datafusion::assert_batches_eq!(vec![ + "+---+----------------+", + "| a | min(example.b) |", + "+---+----------------+", + "| 1 | 2 |", + "+---+----------------+", + ], + &results + ); + Ok(()) +} +``` + +### Read an Apache Parquet file + +Similarly to CSV, you can register a Parquet file as a table using the `register_parquet` method. + +```rust +use datafusion::error::Result; +use datafusion::prelude::*; +#[tokio::main] +async fn main() -> Result<()> { + // create local session context + let ctx = SessionContext::new(); + let testdata = datafusion::test_util::parquet_test_data(); + + // register parquet file with the execution context + ctx.register_parquet( + "alltypes_plain", + &format!("{testdata}/alltypes_plain.parquet"), + ParquetReadOptions::default(), + ) + .await?; + + // execute the query + let df = ctx.sql( + "SELECT int_col, double_col, CAST(date_string_col as VARCHAR) \ + FROM alltypes_plain \ + WHERE id > 1 AND tinyint_col < double_col", + ).await?; + + // execute the plan, and compare to the expected results + let results = df.collect().await?; + datafusion::assert_batches_eq!(vec![ + "+---------+------------+--------------------------------+", + "| int_col | double_col | alltypes_plain.date_string_col |", + "+---------+------------+--------------------------------+", + "| 1 | 10.1 | 03/01/09 |", + "| 1 | 10.1 | 04/01/09 |", + "| 1 | 10.1 | 02/01/09 |", + "+---------+------------+--------------------------------+", + ], + &results + ); + Ok(()) +} +``` + +### Read an Apache Avro file + +DataFusion can also read Avro files using the `register_avro` method. + +```rust +use datafusion::arrow::util::pretty; +use datafusion::error::Result; +use datafusion::prelude::*; + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = SessionContext::new(); + // find the path to the avro test files + let testdata = datafusion::test_util::arrow_test_data(); + // register avro file with the execution context + let avro_file = &format!("{testdata}/avro/alltypes_plain.avro"); + ctx.register_avro("alltypes_plain", avro_file, AvroReadOptions::default()).await?; + + // execute the query + let df = ctx.sql( + "SELECT int_col, double_col, CAST(date_string_col as VARCHAR) \ + FROM alltypes_plain \ + WHERE id > 1 AND tinyint_col < double_col" + ).await?; + + // execute the plan, and compare to the expected results + let results = df.collect().await?; + datafusion::assert_batches_eq!(vec![ + "+---------+------------+--------------------------------+", + "| int_col | double_col | alltypes_plain.date_string_col |", + "+---------+------------+--------------------------------+", + "| 1 | 10.1 | 03/01/09 |", + "| 1 | 10.1 | 04/01/09 |", + "| 1 | 10.1 | 02/01/09 |", + "+---------+------------+--------------------------------+", + ], + &results + ); + Ok(()) +} +``` + +## Reading Multiple Files as a table + +It is also possible to read multiple files as a single table. This is done +with the ListingTableProvider which takes a list of file paths and reads them +as a single table, matching schemas as appropriate + Coming Soon + +```rust + +``` + +## Using `CREATE EXTERNAL TABLE` to register data sources via SQL + +You can also register files using SQL using the [`CREATE EXTERNAL TABLE`] +statement. + +[`create external table`]: ../user-guide/sql/ddl.md#create-external-table + +```rust +use datafusion::error::Result; +use datafusion::prelude::*; +#[tokio::main] +async fn main() -> Result<()> { + // create local session context + let ctx = SessionContext::new(); + let testdata = datafusion::test_util::parquet_test_data(); + + // register parquet file using SQL + let ddl = format!( + "CREATE EXTERNAL TABLE alltypes_plain \ + STORED AS PARQUET LOCATION '{testdata}/alltypes_plain.parquet'" + ); + ctx.sql(&ddl).await?; + + // execute the query referring to the alltypes_plain table we just registered + let df = ctx.sql( + "SELECT int_col, double_col, CAST(date_string_col as VARCHAR) \ + FROM alltypes_plain \ + WHERE id > 1 AND tinyint_col < double_col", + ).await?; + + // execute the plan, and compare to the expected results + let results = df.collect().await?; + datafusion::assert_batches_eq!(vec![ + "+---------+------------+--------------------------------+", + "| int_col | double_col | alltypes_plain.date_string_col |", + "+---------+------------+--------------------------------+", + "| 1 | 10.1 | 03/01/09 |", + "| 1 | 10.1 | 04/01/09 |", + "| 1 | 10.1 | 02/01/09 |", + "+---------+------------+--------------------------------+", + ], + &results + ); + Ok(()) +} +``` + +[`sessioncontext`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html +[`sessioncontext::sql`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.sql +[`sessioncontext::sql_with_options`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.sql_with_options +[`sessionstate`]: https://docs.rs/datafusion/latest/datafusion/execution/session_state/struct.SessionState.html +["using the dataframe api"]: ../library-user-guide/using-the-dataframe-api.md diff --git a/docs/source/library-user-guide/working-with-exprs.md b/docs/source/library-user-guide/working-with-exprs.md index e0c9e69eb6ed..e0b6f434a032 100644 --- a/docs/source/library-user-guide/working-with-exprs.md +++ b/docs/source/library-user-guide/working-with-exprs.md @@ -80,7 +80,11 @@ If you'd like to learn more about `Expr`s, before we get into the details of cre ## Rewriting `Expr`s -[rewrite_expr.rs](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/rewrite_expr.rs) contains example code for rewriting `Expr`s. +There are several examples of rewriting and working with `Exprs`: + +- [expr_api.rs](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/expr_api.rs) +- [analyzer_rule.rs](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/analyzer_rule.rs) +- [optimizer_rule.rs](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/optimizer_rule.rs) Rewriting Expressions is the process of taking an `Expr` and transforming it into another `Expr`. This is useful for a number of reasons, including: diff --git a/docs/source/user-guide/cli/datasources.md b/docs/source/user-guide/cli/datasources.md index c2c00b633479..2b11645c471a 100644 --- a/docs/source/user-guide/cli/datasources.md +++ b/docs/source/user-guide/cli/datasources.md @@ -166,8 +166,8 @@ Register a single file csv datasource with a header row. ```sql CREATE EXTERNAL TABLE test STORED AS CSV -WITH HEADER ROW -LOCATION '/path/to/aggregate_test_100.csv'; +LOCATION '/path/to/aggregate_test_100.csv' +OPTIONS ('has_header' 'true'); ``` Register a single file csv datasource with explicitly defined schema. diff --git a/docs/source/user-guide/cli/installation.md b/docs/source/user-guide/cli/installation.md index 3a71240783e5..a3dc4bd2bdb4 100644 --- a/docs/source/user-guide/cli/installation.md +++ b/docs/source/user-guide/cli/installation.md @@ -56,8 +56,9 @@ this to work. ```bash git clone https://github.com/apache/datafusion -cd arrow-datafusion -git checkout 12.0.0 +cd datafusion +# Note: the build can take a while docker build -f datafusion-cli/Dockerfile . --tag datafusion-cli -docker run -it -v $(your_data_location):/data datafusion-cli +# You can also bind persistent storage with `-v /path/to/data:/data` +docker run --rm -it datafusion-cli ``` diff --git a/docs/source/user-guide/cli/usage.md b/docs/source/user-guide/cli/usage.md index 617b462875c7..6a620fc69252 100644 --- a/docs/source/user-guide/cli/usage.md +++ b/docs/source/user-guide/cli/usage.md @@ -52,7 +52,7 @@ OPTIONS: --maxrows The max number of rows to display for 'Table' format - [default: 40] [possible values: numbers(0/10/...), inf(no limit)] + [possible values: numbers(0/10/...), inf(no limit)] [default: 40] --mem-pool-type Specify the memory pool type 'greedy' or 'fair', default to 'greedy' diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 3ee3778177c4..6a49fda668a9 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -35,78 +35,93 @@ Values are parsed according to the [same rules used in casts from Utf8](https:// If the value in the environment variable cannot be cast to the type of the configuration option, the default value will be used instead and a warning emitted. Environment variables are read during `SessionConfig` initialisation so they must be set beforehand and will not affect running sessions. -| key | default | description | -| ----------------------------------------------------------------------- | ------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| datafusion.catalog.create_default_catalog_and_schema | true | Whether the default catalog and schema should be created automatically. | -| datafusion.catalog.default_catalog | datafusion | The default catalog name - this impacts what SQL queries use if not specified | -| datafusion.catalog.default_schema | public | The default schema name - this impacts what SQL queries use if not specified | -| datafusion.catalog.information_schema | false | Should DataFusion provide access to `information_schema` virtual tables for displaying schema information | -| datafusion.catalog.location | NULL | Location scanned to load tables for `default` schema | -| datafusion.catalog.format | NULL | Type of `TableProvider` to use when loading `default` schema | -| datafusion.catalog.has_header | false | If the file has a header | -| datafusion.execution.batch_size | 8192 | Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption | -| datafusion.execution.coalesce_batches | true | When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting | -| datafusion.execution.collect_statistics | false | Should DataFusion collect statistics after listing files | -| datafusion.execution.target_partitions | 0 | Number of partitions for query execution. Increasing partitions can increase concurrency. Defaults to the number of CPU cores on the system | -| datafusion.execution.time_zone | +00:00 | The default time zone Some functions, e.g. `EXTRACT(HOUR from SOME_TIME)`, shift the underlying datetime according to this time zone, and then extract the hour | -| datafusion.execution.parquet.enable_page_index | true | If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce the I/O and number of rows decoded. | -| datafusion.execution.parquet.pruning | true | If true, the parquet reader attempts to skip entire row groups based on the predicate in the query and the metadata (min/max values) stored in the parquet file | -| datafusion.execution.parquet.skip_metadata | true | If true, the parquet reader skip the optional embedded metadata that may be in the file Schema. This setting can help avoid schema conflicts when querying multiple parquet files with schemas containing compatible types but different metadata | -| datafusion.execution.parquet.metadata_size_hint | NULL | If specified, the parquet reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. If not specified, two reads are required: One read to fetch the 8-byte parquet footer and another to fetch the metadata length encoded in the footer | -| datafusion.execution.parquet.pushdown_filters | false | If true, filter expressions are be applied during the parquet decoding operation to reduce the number of rows decoded. This optimization is sometimes called "late materialization". | -| datafusion.execution.parquet.reorder_filters | false | If true, filter expressions evaluated during the parquet decoding operation will be reordered heuristically to minimize the cost of evaluation. If false, the filters are applied in the same order as written in the query | -| datafusion.execution.parquet.data_pagesize_limit | 1048576 | Sets best effort maximum size of data page in bytes | -| datafusion.execution.parquet.write_batch_size | 1024 | Sets write_batch_size in bytes | -| datafusion.execution.parquet.writer_version | 1.0 | Sets parquet writer version valid values are "1.0" and "2.0" | -| datafusion.execution.parquet.compression | zstd(3) | Sets default parquet compression codec Valid values are: uncompressed, snappy, gzip(level), lzo, brotli(level), lz4, zstd(level), and lz4_raw. These values are not case sensitive. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.dictionary_enabled | NULL | Sets if dictionary encoding is enabled. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.dictionary_page_size_limit | 1048576 | Sets best effort maximum dictionary page size, in bytes | -| datafusion.execution.parquet.statistics_enabled | NULL | Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.max_statistics_size | NULL | Sets max statistics size for any column. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.max_row_group_size | 1048576 | Target maximum number of rows in each row group (defaults to 1M rows). Writing larger row groups requires more memory to write, but can get better compression and be faster to read. | -| datafusion.execution.parquet.created_by | datafusion version 37.1.0 | Sets "created by" property | -| datafusion.execution.parquet.column_index_truncate_length | NULL | Sets column index truncate length | -| datafusion.execution.parquet.data_page_row_count_limit | 18446744073709551615 | Sets best effort maximum number of rows in data page | -| datafusion.execution.parquet.encoding | NULL | Sets default encoding for any column Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.bloom_filter_enabled | false | Sets if bloom filter is enabled for any column | -| datafusion.execution.parquet.bloom_filter_fpp | NULL | Sets bloom filter false positive probability. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.bloom_filter_ndv | NULL | Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting | -| datafusion.execution.parquet.allow_single_file_parallelism | true | Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. | -| datafusion.execution.parquet.maximum_parallel_row_group_writers | 1 | By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | -| datafusion.execution.parquet.maximum_buffered_record_batches_per_stream | 2 | By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | -| datafusion.execution.aggregate.scalar_update_factor | 10 | Specifies the threshold for using `ScalarValue`s to update accumulators during high-cardinality aggregations for each input batch. The aggregation is considered high-cardinality if the number of affected groups is greater than or equal to `batch_size / scalar_update_factor`. In such cases, `ScalarValue`s are utilized for updating accumulators, rather than the default batch-slice approach. This can lead to performance improvements. By adjusting the `scalar_update_factor`, you can balance the trade-off between more efficient accumulator updates and the number of groups affected. | -| datafusion.execution.planning_concurrency | 0 | Fan-out during initial physical planning. This is mostly use to plan `UNION` children in parallel. Defaults to the number of CPU cores on the system | -| datafusion.execution.sort_spill_reservation_bytes | 10485760 | Specifies the reserved memory for each spillable sort operation to facilitate an in-memory merge. When a sort operation spills to disk, the in-memory data must be sorted and merged before being written to a file. This setting reserves a specific amount of memory for that in-memory sort/merge process. Note: This setting is irrelevant if the sort operation cannot spill (i.e., if there's no `DiskManager` configured). | -| datafusion.execution.sort_in_place_threshold_bytes | 1048576 | When sorting, below what size should data be concatenated and sorted in a single RecordBatch rather than sorted in batches and merged. | -| datafusion.execution.meta_fetch_concurrency | 32 | Number of files to read in parallel when inferring schema and statistics | -| datafusion.execution.minimum_parallel_output_files | 4 | Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. | -| datafusion.execution.soft_max_rows_per_output_file | 50000000 | Target number of rows in output files when writing multiple. This is a soft max, so it can be exceeded slightly. There also will be one file smaller than the limit if the total number of rows written is not roughly divisible by the soft max | -| datafusion.execution.max_buffered_batches_per_output_file | 2 | This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption | -| datafusion.execution.listing_table_ignore_subdirectory | true | Should sub directories be ignored when scanning directories for data files. Defaults to true (ignores subdirectories), consistent with Hive. Note that this setting does not affect reading partitioned tables (e.g. `/table/year=2021/month=01/data.parquet`). | -| datafusion.execution.enable_recursive_ctes | true | Should DataFusion support recursive CTEs | -| datafusion.optimizer.enable_distinct_aggregation_soft_limit | true | When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. | -| datafusion.optimizer.enable_round_robin_repartition | true | When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores | -| datafusion.optimizer.enable_topk_aggregation | true | When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible | -| datafusion.optimizer.filter_null_join_keys | false | When set to true, the optimizer will insert filters before a join between a nullable and non-nullable column to filter out nulls on the nullable side. This filter can add additional overhead when the file format does not fully support predicate push down. | -| datafusion.optimizer.repartition_aggregations | true | Should DataFusion repartition data using the aggregate keys to execute aggregates in parallel using the provided `target_partitions` level | -| datafusion.optimizer.repartition_file_min_size | 10485760 | Minimum total files size in bytes to perform file scan repartitioning. | -| datafusion.optimizer.repartition_joins | true | Should DataFusion repartition data using the join keys to execute joins in parallel using the provided `target_partitions` level | -| datafusion.optimizer.allow_symmetric_joins_without_pruning | true | Should DataFusion allow symmetric hash joins for unbounded data sources even when its inputs do not have any ordering or filtering If the flag is not enabled, the SymmetricHashJoin operator will be unable to prune its internal buffers, resulting in certain join types - such as Full, Left, LeftAnti, LeftSemi, Right, RightAnti, and RightSemi - being produced only at the end of the execution. This is not typical in stream processing. Additionally, without proper design for long runner execution, all types of joins may encounter out-of-memory errors. | -| datafusion.optimizer.repartition_file_scans | true | When set to `true`, file groups will be repartitioned to achieve maximum parallelism. Currently Parquet and CSV formats are supported. If set to `true`, all files will be repartitioned evenly (i.e., a single large file might be partitioned into smaller chunks) for parallel scanning. If set to `false`, different files will be read in parallel, but repartitioning won't happen within a single file. | -| datafusion.optimizer.repartition_windows | true | Should DataFusion repartition data using the partitions keys to execute window functions in parallel using the provided `target_partitions` level | -| datafusion.optimizer.repartition_sorts | true | Should DataFusion execute sorts in a per-partition fashion and merge afterwards instead of coalescing first and sorting globally. With this flag is enabled, plans in the form below `text "SortExec: [a@0 ASC]", " CoalescePartitionsExec", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ` would turn into the plan below which performs better in multithreaded environments `text "SortPreservingMergeExec: [a@0 ASC]", " SortExec: [a@0 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ` | -| datafusion.optimizer.prefer_existing_sort | false | When true, DataFusion will opportunistically remove sorts when the data is already sorted, (i.e. setting `preserve_order` to true on `RepartitionExec` and using `SortPreservingMergeExec`) When false, DataFusion will maximize plan parallelism using `RepartitionExec` even if this requires subsequently resorting data using a `SortExec`. | -| datafusion.optimizer.skip_failed_rules | false | When set to true, the logical plan optimizer will produce warning messages if any optimization rules produce errors and then proceed to the next rule. When set to false, any rules that produce errors will cause the query to fail | -| datafusion.optimizer.max_passes | 3 | Number of times that the optimizer will attempt to optimize the plan | -| datafusion.optimizer.top_down_join_key_reordering | true | When set to true, the physical plan optimizer will run a top down process to reorder the join keys | -| datafusion.optimizer.prefer_hash_join | true | When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. HashJoin can work more efficiently than SortMergeJoin but consumes more memory | -| datafusion.optimizer.hash_join_single_partition_threshold | 1048576 | The maximum estimated size in bytes for one input side of a HashJoin will be collected into a single partition | -| datafusion.optimizer.hash_join_single_partition_threshold_rows | 131072 | The maximum estimated size in rows for one input side of a HashJoin will be collected into a single partition | -| datafusion.optimizer.default_filter_selectivity | 20 | The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). | -| datafusion.explain.logical_plan_only | false | When set to true, the explain statement will only print logical plans | -| datafusion.explain.physical_plan_only | false | When set to true, the explain statement will only print physical plans | -| datafusion.explain.show_statistics | false | When set to true, the explain statement will print operator statistics for physical plans | -| datafusion.explain.show_sizes | true | When set to true, the explain statement will print the partition sizes | -| datafusion.sql_parser.parse_float_as_decimal | false | When set to true, SQL parser will parse float as decimal type | -| datafusion.sql_parser.enable_ident_normalization | true | When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted) | -| datafusion.sql_parser.dialect | generic | Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, and Ansi. | +| key | default | description | +| ----------------------------------------------------------------------- | ------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| datafusion.catalog.create_default_catalog_and_schema | true | Whether the default catalog and schema should be created automatically. | +| datafusion.catalog.default_catalog | datafusion | The default catalog name - this impacts what SQL queries use if not specified | +| datafusion.catalog.default_schema | public | The default schema name - this impacts what SQL queries use if not specified | +| datafusion.catalog.information_schema | false | Should DataFusion provide access to `information_schema` virtual tables for displaying schema information | +| datafusion.catalog.location | NULL | Location scanned to load tables for `default` schema | +| datafusion.catalog.format | NULL | Type of `TableProvider` to use when loading `default` schema | +| datafusion.catalog.has_header | true | Default value for `format.has_header` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. | +| datafusion.catalog.newlines_in_values | false | Specifies whether newlines in (quoted) CSV values are supported. This is the default value for `format.newlines_in_values` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. Parsing newlines in quoted values may be affected by execution behaviour such as parallel file scanning. Setting this to `true` ensures that newlines in values are parsed successfully, which may reduce performance. | +| datafusion.execution.batch_size | 8192 | Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption | +| datafusion.execution.coalesce_batches | true | When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. The target batch size is determined by the configuration setting | +| datafusion.execution.collect_statistics | false | Should DataFusion collect statistics after listing files | +| datafusion.execution.target_partitions | 0 | Number of partitions for query execution. Increasing partitions can increase concurrency. Defaults to the number of CPU cores on the system | +| datafusion.execution.time_zone | +00:00 | The default time zone Some functions, e.g. `EXTRACT(HOUR from SOME_TIME)`, shift the underlying datetime according to this time zone, and then extract the hour | +| datafusion.execution.parquet.enable_page_index | true | (reading) If true, reads the Parquet data page level metadata (the Page Index), if present, to reduce the I/O and number of rows decoded. | +| datafusion.execution.parquet.pruning | true | (reading) If true, the parquet reader attempts to skip entire row groups based on the predicate in the query and the metadata (min/max values) stored in the parquet file | +| datafusion.execution.parquet.skip_metadata | true | (reading) If true, the parquet reader skip the optional embedded metadata that may be in the file Schema. This setting can help avoid schema conflicts when querying multiple parquet files with schemas containing compatible types but different metadata | +| datafusion.execution.parquet.metadata_size_hint | NULL | (reading) If specified, the parquet reader will try and fetch the last `size_hint` bytes of the parquet file optimistically. If not specified, two reads are required: One read to fetch the 8-byte parquet footer and another to fetch the metadata length encoded in the footer | +| datafusion.execution.parquet.pushdown_filters | false | (reading) If true, filter expressions are be applied during the parquet decoding operation to reduce the number of rows decoded. This optimization is sometimes called "late materialization". | +| datafusion.execution.parquet.reorder_filters | false | (reading) If true, filter expressions evaluated during the parquet decoding operation will be reordered heuristically to minimize the cost of evaluation. If false, the filters are applied in the same order as written in the query | +| datafusion.execution.parquet.schema_force_view_types | true | (reading) If true, parquet reader will read columns of `Utf8/Utf8Large` with `Utf8View`, and `Binary/BinaryLarge` with `BinaryView`. | +| datafusion.execution.parquet.binary_as_string | false | (reading) If true, parquet reader will read columns of `Binary/LargeBinary` with `Utf8`, and `BinaryView` with `Utf8View`. Parquet files generated by some legacy writers do not correctly set the UTF8 flag for strings, causing string columns to be loaded as BLOB instead. | +| datafusion.execution.parquet.data_pagesize_limit | 1048576 | (writing) Sets best effort maximum size of data page in bytes | +| datafusion.execution.parquet.write_batch_size | 1024 | (writing) Sets write_batch_size in bytes | +| datafusion.execution.parquet.writer_version | 1.0 | (writing) Sets parquet writer version valid values are "1.0" and "2.0" | +| datafusion.execution.parquet.compression | zstd(3) | (writing) Sets default parquet compression codec. Valid values are: uncompressed, snappy, gzip(level), lzo, brotli(level), lz4, zstd(level), and lz4_raw. These values are not case sensitive. If NULL, uses default parquet writer setting Note that this default setting is not the same as the default parquet writer setting. | +| datafusion.execution.parquet.dictionary_enabled | true | (writing) Sets if dictionary encoding is enabled. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.dictionary_page_size_limit | 1048576 | (writing) Sets best effort maximum dictionary page size, in bytes | +| datafusion.execution.parquet.statistics_enabled | page | (writing) Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.max_statistics_size | 4096 | (writing) Sets max statistics size for any column. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.max_row_group_size | 1048576 | (writing) Target maximum number of rows in each row group (defaults to 1M rows). Writing larger row groups requires more memory to write, but can get better compression and be faster to read. | +| datafusion.execution.parquet.created_by | datafusion version 43.0.0 | (writing) Sets "created by" property | +| datafusion.execution.parquet.column_index_truncate_length | 64 | (writing) Sets column index truncate length | +| datafusion.execution.parquet.data_page_row_count_limit | 20000 | (writing) Sets best effort maximum number of rows in data page | +| datafusion.execution.parquet.encoding | NULL | (writing) Sets default encoding for any column. Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.bloom_filter_on_read | true | (writing) Use any available bloom filters when reading parquet files | +| datafusion.execution.parquet.bloom_filter_on_write | false | (writing) Write bloom filters for all columns when creating parquet files | +| datafusion.execution.parquet.bloom_filter_fpp | NULL | (writing) Sets bloom filter false positive probability. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.bloom_filter_ndv | NULL | (writing) Sets bloom filter number of distinct values. If NULL, uses default parquet writer setting | +| datafusion.execution.parquet.allow_single_file_parallelism | true | (writing) Controls whether DataFusion will attempt to speed up writing parquet files by serializing them in parallel. Each column in each row group in each output file are serialized in parallel leveraging a maximum possible core count of n_files*n_row_groups*n_columns. | +| datafusion.execution.parquet.maximum_parallel_row_group_writers | 1 | (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | +| datafusion.execution.parquet.maximum_buffered_record_batches_per_stream | 2 | (writing) By default parallel parquet writer is tuned for minimum memory usage in a streaming execution plan. You may see a performance benefit when writing large parquet files by increasing maximum_parallel_row_group_writers and maximum_buffered_record_batches_per_stream if your system has idle cores and can tolerate additional memory usage. Boosting these values is likely worthwhile when writing out already in-memory data, such as from a cached data frame. | +| datafusion.execution.planning_concurrency | 0 | Fan-out during initial physical planning. This is mostly use to plan `UNION` children in parallel. Defaults to the number of CPU cores on the system | +| datafusion.execution.skip_physical_aggregate_schema_check | false | When set to true, skips verifying that the schema produced by planning the input of `LogicalPlan::Aggregate` exactly matches the schema of the input plan. When set to false, if the schema does not match exactly (including nullability and metadata), a planning error will be raised. This is used to workaround bugs in the planner that are now caught by the new schema verification step. | +| datafusion.execution.sort_spill_reservation_bytes | 10485760 | Specifies the reserved memory for each spillable sort operation to facilitate an in-memory merge. When a sort operation spills to disk, the in-memory data must be sorted and merged before being written to a file. This setting reserves a specific amount of memory for that in-memory sort/merge process. Note: This setting is irrelevant if the sort operation cannot spill (i.e., if there's no `DiskManager` configured). | +| datafusion.execution.sort_in_place_threshold_bytes | 1048576 | When sorting, below what size should data be concatenated and sorted in a single RecordBatch rather than sorted in batches and merged. | +| datafusion.execution.meta_fetch_concurrency | 32 | Number of files to read in parallel when inferring schema and statistics | +| datafusion.execution.minimum_parallel_output_files | 4 | Guarantees a minimum level of output files running in parallel. RecordBatches will be distributed in round robin fashion to each parallel writer. Each writer is closed and a new file opened once soft_max_rows_per_output_file is reached. | +| datafusion.execution.soft_max_rows_per_output_file | 50000000 | Target number of rows in output files when writing multiple. This is a soft max, so it can be exceeded slightly. There also will be one file smaller than the limit if the total number of rows written is not roughly divisible by the soft max | +| datafusion.execution.max_buffered_batches_per_output_file | 2 | This is the maximum number of RecordBatches buffered for each output file being worked. Higher values can potentially give faster write performance at the cost of higher peak memory consumption | +| datafusion.execution.listing_table_ignore_subdirectory | true | Should sub directories be ignored when scanning directories for data files. Defaults to true (ignores subdirectories), consistent with Hive. Note that this setting does not affect reading partitioned tables (e.g. `/table/year=2021/month=01/data.parquet`). | +| datafusion.execution.enable_recursive_ctes | true | Should DataFusion support recursive CTEs | +| datafusion.execution.split_file_groups_by_statistics | false | Attempt to eliminate sorts by packing & sorting files with non-overlapping statistics into the same file groups. Currently experimental | +| datafusion.execution.keep_partition_by_columns | false | Should DataFusion keep the columns used for partition_by in the output RecordBatches | +| datafusion.execution.skip_partial_aggregation_probe_ratio_threshold | 0.8 | Aggregation ratio (number of distinct groups / number of input rows) threshold for skipping partial aggregation. If the value is greater then partial aggregation will skip aggregation for further input | +| datafusion.execution.skip_partial_aggregation_probe_rows_threshold | 100000 | Number of input rows partial aggregation partition should process, before aggregation ratio check and trying to switch to skipping aggregation mode | +| datafusion.execution.use_row_number_estimates_to_optimize_partitioning | false | Should DataFusion use row number estimates at the input to decide whether increasing parallelism is beneficial or not. By default, only exact row numbers (not estimates) are used for this decision. Setting this flag to `true` will likely produce better plans. if the source of statistics is accurate. We plan to make this the default in the future. | +| datafusion.execution.enforce_batch_size_in_joins | false | Should DataFusion enforce batch size in joins or not. By default, DataFusion will not enforce batch size in joins. Enforcing batch size in joins can reduce memory usage when joining large tables with a highly-selective join filter, but is also slightly slower. | +| datafusion.optimizer.enable_distinct_aggregation_soft_limit | true | When set to true, the optimizer will push a limit operation into grouped aggregations which have no aggregate expressions, as a soft limit, emitting groups once the limit is reached, before all rows in the group are read. | +| datafusion.optimizer.enable_round_robin_repartition | true | When set to true, the physical plan optimizer will try to add round robin repartitioning to increase parallelism to leverage more CPU cores | +| datafusion.optimizer.enable_topk_aggregation | true | When set to true, the optimizer will attempt to perform limit operations during aggregations, if possible | +| datafusion.optimizer.filter_null_join_keys | false | When set to true, the optimizer will insert filters before a join between a nullable and non-nullable column to filter out nulls on the nullable side. This filter can add additional overhead when the file format does not fully support predicate push down. | +| datafusion.optimizer.repartition_aggregations | true | Should DataFusion repartition data using the aggregate keys to execute aggregates in parallel using the provided `target_partitions` level | +| datafusion.optimizer.repartition_file_min_size | 10485760 | Minimum total files size in bytes to perform file scan repartitioning. | +| datafusion.optimizer.repartition_joins | true | Should DataFusion repartition data using the join keys to execute joins in parallel using the provided `target_partitions` level | +| datafusion.optimizer.allow_symmetric_joins_without_pruning | true | Should DataFusion allow symmetric hash joins for unbounded data sources even when its inputs do not have any ordering or filtering If the flag is not enabled, the SymmetricHashJoin operator will be unable to prune its internal buffers, resulting in certain join types - such as Full, Left, LeftAnti, LeftSemi, Right, RightAnti, and RightSemi - being produced only at the end of the execution. This is not typical in stream processing. Additionally, without proper design for long runner execution, all types of joins may encounter out-of-memory errors. | +| datafusion.optimizer.repartition_file_scans | true | When set to `true`, file groups will be repartitioned to achieve maximum parallelism. Currently Parquet and CSV formats are supported. If set to `true`, all files will be repartitioned evenly (i.e., a single large file might be partitioned into smaller chunks) for parallel scanning. If set to `false`, different files will be read in parallel, but repartitioning won't happen within a single file. | +| datafusion.optimizer.repartition_windows | true | Should DataFusion repartition data using the partitions keys to execute window functions in parallel using the provided `target_partitions` level | +| datafusion.optimizer.repartition_sorts | true | Should DataFusion execute sorts in a per-partition fashion and merge afterwards instead of coalescing first and sorting globally. With this flag is enabled, plans in the form below `text "SortExec: [a@0 ASC]", " CoalescePartitionsExec", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ` would turn into the plan below which performs better in multithreaded environments `text "SortPreservingMergeExec: [a@0 ASC]", " SortExec: [a@0 ASC]", " RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1", ` | +| datafusion.optimizer.prefer_existing_sort | false | When true, DataFusion will opportunistically remove sorts when the data is already sorted, (i.e. setting `preserve_order` to true on `RepartitionExec` and using `SortPreservingMergeExec`) When false, DataFusion will maximize plan parallelism using `RepartitionExec` even if this requires subsequently resorting data using a `SortExec`. | +| datafusion.optimizer.skip_failed_rules | false | When set to true, the logical plan optimizer will produce warning messages if any optimization rules produce errors and then proceed to the next rule. When set to false, any rules that produce errors will cause the query to fail | +| datafusion.optimizer.max_passes | 3 | Number of times that the optimizer will attempt to optimize the plan | +| datafusion.optimizer.top_down_join_key_reordering | true | When set to true, the physical plan optimizer will run a top down process to reorder the join keys | +| datafusion.optimizer.prefer_hash_join | true | When set to true, the physical plan optimizer will prefer HashJoin over SortMergeJoin. HashJoin can work more efficiently than SortMergeJoin but consumes more memory | +| datafusion.optimizer.hash_join_single_partition_threshold | 1048576 | The maximum estimated size in bytes for one input side of a HashJoin will be collected into a single partition | +| datafusion.optimizer.hash_join_single_partition_threshold_rows | 131072 | The maximum estimated size in rows for one input side of a HashJoin will be collected into a single partition | +| datafusion.optimizer.default_filter_selectivity | 20 | The default filter selectivity used by Filter Statistics when an exact selectivity cannot be determined. Valid values are between 0 (no selectivity) and 100 (all rows are selected). | +| datafusion.optimizer.prefer_existing_union | false | When set to true, the optimizer will not attempt to convert Union to Interleave | +| datafusion.optimizer.expand_views_at_output | false | When set to true, if the returned type is a view type then the output will be coerced to a non-view. Coerces `Utf8View` to `LargeUtf8`, and `BinaryView` to `LargeBinary`. | +| datafusion.explain.logical_plan_only | false | When set to true, the explain statement will only print logical plans | +| datafusion.explain.physical_plan_only | false | When set to true, the explain statement will only print physical plans | +| datafusion.explain.show_statistics | false | When set to true, the explain statement will print operator statistics for physical plans | +| datafusion.explain.show_sizes | true | When set to true, the explain statement will print the partition sizes | +| datafusion.explain.show_schema | false | When set to true, the explain statement will print schema information | +| datafusion.sql_parser.parse_float_as_decimal | false | When set to true, SQL parser will parse float as decimal type | +| datafusion.sql_parser.enable_ident_normalization | true | When set to true, SQL parser will normalize ident (convert ident to lowercase when not quoted) | +| datafusion.sql_parser.enable_options_value_normalization | true | When set to true, SQL parser will normalize options value (convert value to lowercase) | +| datafusion.sql_parser.dialect | generic | Configure the SQL dialect used by DataFusion's parser; supported values include: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, and Ansi. | +| datafusion.sql_parser.support_varchar_with_length | true | If true, permit lengths for `VARCHAR` such as `VARCHAR(20)`, but ignore the length. If false, error if a `VARCHAR` with a length is specified. The Arrow type system does not have a notion of maximum string length and thus DataFusion can not enforce such limits. | diff --git a/docs/source/user-guide/crate-configuration.md b/docs/source/user-guide/crate-configuration.md new file mode 100644 index 000000000000..9d22e3403097 --- /dev/null +++ b/docs/source/user-guide/crate-configuration.md @@ -0,0 +1,188 @@ + + +# Crate Configuration + +This section contains information on how to configure DataFusion in your Rust +project. See the [Configuration Settings] section for a list of options that +control DataFusion's behavior. + +[configuration settings]: configs.md + +## Add latest non published DataFusion dependency + +DataFusion changes are published to `crates.io` according to the [release schedule](https://github.com/apache/datafusion/blob/main/dev/release/README.md#release-process) + +If you would like to test out DataFusion changes which are merged but not yet +published, Cargo supports adding dependency directly to GitHub branch: + +```toml +datafusion = { git = "https://github.com/apache/datafusion", branch = "main"} +``` + +Also it works on the package level + +```toml +datafusion-common = { git = "https://github.com/apache/datafusion", branch = "main", package = "datafusion-common"} +``` + +And with features + +```toml +datafusion = { git = "https://github.com/apache/datafusion", branch = "main", default-features = false, features = ["unicode_expressions"] } +``` + +More on [Cargo dependencies](https://doc.rust-lang.org/cargo/reference/specifying-dependencies.html#specifying-dependencies) + +## Optimized Configuration + +For an optimized build several steps are required. First, use the below in your `Cargo.toml`. It is +worth noting that using the settings in the `[profile.release]` section will significantly increase the build time. + +```toml +[dependencies] +datafusion = { version = "22.0" } +tokio = { version = "^1.0", features = ["rt-multi-thread"] } +snmalloc-rs = "0.3" + +[profile.release] +lto = true +codegen-units = 1 +``` + +Then, in `main.rs.` update the memory allocator with the below after your imports: + +```rust ,ignore +use datafusion::prelude::*; + +#[global_allocator] +static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; + +#[tokio::main] +async fn main() -> datafusion::error::Result<()> { + Ok(()) +} +``` + +Based on the instruction set architecture you are building on you will want to configure the `target-cpu` as well, ideally +with `native` or at least `avx2`. + +```shell +RUSTFLAGS='-C target-cpu=native' cargo run --release +``` + +## Enable backtraces + +By default Datafusion returns errors as a plain message. There is option to enable more verbose details about the error, +like error backtrace. To enable a backtrace you need to add Datafusion `backtrace` feature to your `Cargo.toml` file: + +```toml +datafusion = { version = "31.0.0", features = ["backtrace"]} +``` + +Set environment [variables](https://doc.rust-lang.org/std/backtrace/index.html#environment-variables) + +```bash +RUST_BACKTRACE=1 ./target/debug/datafusion-cli +DataFusion CLI v31.0.0 +> select row_numer() over (partition by a order by a) from (select 1 a); +Error during planning: Invalid function 'row_numer'. +Did you mean 'ROW_NUMBER'? + +backtrace: 0: std::backtrace_rs::backtrace::libunwind::trace + at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/std/src/../../backtrace/src/backtrace/libunwind.rs:93:5 + 1: std::backtrace_rs::backtrace::trace_unsynchronized + at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/std/src/../../backtrace/src/backtrace/mod.rs:66:5 + 2: std::backtrace::Backtrace::create + at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/std/src/backtrace.rs:332:13 + 3: std::backtrace::Backtrace::capture + at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/std/src/backtrace.rs:298:9 + 4: datafusion_common::error::DataFusionError::get_back_trace + at /datafusion/datafusion/common/src/error.rs:436:30 + 5: datafusion_sql::expr::function::>::sql_function_to_expr + ............ +``` + +The backtraces are useful when debugging code. If there is a test in `datafusion/core/src/physical_planner.rs` + +```rust +#[tokio::test] +async fn test_get_backtrace_for_failed_code() -> Result<()> { + let ctx = SessionContext::new(); + + let sql = " + select row_numer() over (partition by a order by a) from (select 1 a); + "; + + let _ = ctx.sql(sql).await?.collect().await?; + + Ok(()) +} +``` + +To obtain a backtrace: + +```bash +cargo build --features=backtrace +RUST_BACKTRACE=1 cargo test --features=backtrace --package datafusion --lib -- physical_planner::tests::test_get_backtrace_for_failed_code --exact --nocapture + +running 1 test +Error: Plan("Invalid function 'row_numer'.\nDid you mean 'ROW_NUMBER'?\n\nbacktrace: 0: std::backtrace_rs::backtrace::libunwind::trace\n at /rustc/129f3b9964af4d4a709d1383930ade12dfe7c081/library/std/src/../../backtrace/src/backtrace/libunwind.rs:105:5\n 1: std::backtrace_rs::backtrace::trace_unsynchronized\n... +``` + +Note: The backtrace wrapped into systems calls, so some steps on top of the backtrace can be ignored + +To show the backtrace in a pretty-printed format use `eprintln!("{e}");`. + +```rust +#[tokio::test] +async fn test_get_backtrace_for_failed_code() -> Result<()> { + let ctx = SessionContext::new(); + + let sql = "select row_numer() over (partition by a order by a) from (select 1 a);"; + + let _ = match ctx.sql(sql).await { + Ok(result) => result.show().await?, + Err(e) => { + eprintln!("{e}"); + } + }; + + Ok(()) +} +``` + +Then run the test: + +```bash +$ RUST_BACKTRACE=1 cargo test --features=backtrace --package datafusion --lib -- physical_planner::tests::test_get_backtrace_for_failed_code --exact --nocapture + +running 1 test +Error during planning: Invalid function 'row_numer'. +Did you mean 'ROW_NUMBER'? + +backtrace: 0: std::backtrace_rs::backtrace::libunwind::trace + at /rustc/129f3b9964af4d4a709d1383930ade12dfe7c081/library/std/src/../../backtrace/src/backtrace/libunwind.rs:105:5 + 1: std::backtrace_rs::backtrace::trace_unsynchronized + at /rustc/129f3b9964af4d4a709d1383930ade12dfe7c081/library/std/src/../../backtrace/src/backtrace/mod.rs:66:5 + 2: std::backtrace::Backtrace::create + at /rustc/129f3b9964af4d4a709d1383930ade12dfe7c081/library/std/src/backtrace.rs:331:13 + 3: std::backtrace::Backtrace::capture + ... +``` diff --git a/docs/source/user-guide/dataframe.md b/docs/source/user-guide/dataframe.md index c0210200a246..96be1bb9e256 100644 --- a/docs/source/user-guide/dataframe.md +++ b/docs/source/user-guide/dataframe.md @@ -19,17 +19,30 @@ # DataFrame API -A DataFrame represents a logical set of rows with the same named columns, similar to a [Pandas DataFrame](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html) or -[Spark DataFrame](https://spark.apache.org/docs/latest/sql-programming-guide.html). +A DataFrame represents a logical set of rows with the same named columns, +similar to a [Pandas DataFrame] or [Spark DataFrame]. -DataFrames are typically created by calling a method on -`SessionContext`, such as `read_csv`, and can then be modified -by calling the transformation methods, such as `filter`, `select`, `aggregate`, and `limit` -to build up a query definition. +DataFrames are typically created by calling a method on [`SessionContext`], such +as [`read_csv`], and can then be modified by calling the transformation methods, +such as [`filter`], [`select`], [`aggregate`], and [`limit`] to build up a query +definition. -The query can be executed by calling the `collect` method. +The query can be executed by calling the [`collect`] method. -The DataFrame struct is part of DataFusion's prelude and can be imported with the following statement. +DataFusion DataFrames use lazy evaluation, meaning that each transformation +creates a new plan but does not actually perform any immediate actions. This +approach allows for the overall plan to be optimized before execution. The plan +is evaluated (executed) when an action method is invoked, such as [`collect`]. +See the [Library Users Guide] for more details. + +The DataFrame API is well documented in the [API reference on docs.rs]. +Please refer to the [Expressions Reference] for more information on +building logical expressions (`Expr`) to use with the DataFrame API. + +## Example + +The DataFrame struct is part of DataFusion's `prelude` and can be imported with +the following statement. ```rust use datafusion::prelude::*; @@ -38,71 +51,32 @@ use datafusion::prelude::*; Here is a minimal example showing the execution of a query using the DataFrame API. ```rust -let ctx = SessionContext::new(); -let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; -let df = df.filter(col("a").lt_eq(col("b")))? - .aggregate(vec![col("a")], vec![min(col("b"))])? - .limit(0, Some(100))?; -// Print results -df.show().await?; +use datafusion::prelude::*; +use datafusion::error::Result; +use datafusion::functions_aggregate::expr_fn::min; + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = SessionContext::new(); + let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?; + let df = df.filter(col("a").lt_eq(col("b")))? + .aggregate(vec![col("a")], vec![min(col("b"))])? + .limit(0, Some(100))?; + // Print results + df.show().await?; + Ok(()) +} ``` -The DataFrame API is well documented in the [API reference on docs.rs](https://docs.rs/datafusion/latest/datafusion/dataframe/struct.DataFrame.html). - -Refer to the [Expressions Reference](expressions) for available functions for building logical expressions for use with the -DataFrame API. - -## DataFrame Transformations - -These methods create a new DataFrame after applying a transformation to the logical plan that the DataFrame represents. - -DataFusion DataFrames use lazy evaluation, meaning that each transformation is just creating a new query plan and -not actually performing any transformations. This approach allows for the overall plan to be optimized before -execution. The plan is evaluated (executed) when an action method is invoked, such as `collect`. - -| Function | Notes | -| ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ | -| aggregate | Perform an aggregate query with optional grouping expressions. | -| distinct | Filter out duplicate rows. | -| except | Calculate the exception of two DataFrames. The two DataFrames must have exactly the same schema | -| filter | Filter a DataFrame to only include rows that match the specified filter expression. | -| intersect | Calculate the intersection of two DataFrames. The two DataFrames must have exactly the same schema | -| join | Join this DataFrame with another DataFrame using the specified columns as join keys. | -| join_on | Join this DataFrame with another DataFrame using arbitrary expressions. | -| limit | Limit the number of rows returned from this DataFrame. | -| repartition | Repartition a DataFrame based on a logical partitioning scheme. | -| sort | Sort the DataFrame by the specified sorting expressions. Any expression can be turned into a sort expression by calling its `sort` method. | -| select | Create a projection based on arbitrary expressions. Example: `df.select(vec![col("c1"), abs(col("c2"))])?` | -| select_columns | Create a projection based on column names. Example: `df.select_columns(&["id", "name"])?`. | -| union | Calculate the union of two DataFrames, preserving duplicate rows. The two DataFrames must have exactly the same schema. | -| union_distinct | Calculate the distinct union of two DataFrames. The two DataFrames must have exactly the same schema. | -| with_column | Add an additional column to the DataFrame. | -| with_column_renamed | Rename one column by applying a new projection. | - -## DataFrame Actions - -These methods execute the logical plan represented by the DataFrame and either collects the results into memory, prints them to stdout, or writes them to disk. - -| Function | Notes | -| -------------------------- | --------------------------------------------------------------------------------------------------------------------------- | -| collect | Executes this DataFrame and collects all results into a vector of RecordBatch. | -| collect_partitioned | Executes this DataFrame and collects all results into a vector of vector of RecordBatch maintaining the input partitioning. | -| count | Executes this DataFrame to get the total number of rows. | -| execute_stream | Executes this DataFrame and returns a stream over a single partition. | -| execute_stream_partitioned | Executes this DataFrame and returns one stream per partition. | -| show | Execute this DataFrame and print the results to stdout. | -| show_limit | Execute this DataFrame and print a subset of results to stdout. | -| write_csv | Execute this DataFrame and write the results to disk in CSV format. | -| write_json | Execute this DataFrame and write the results to disk in JSON format. | -| write_parquet | Execute this DataFrame and write the results to disk in Parquet format. | -| write_table | Execute this DataFrame and write the results via the insert_into method of the registered TableProvider | - -## Other DataFrame Methods - -| Function | Notes | -| ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| explain | Return a DataFrame with the explanation of its plan so far. | -| registry | Return a `FunctionRegistry` used to plan udf's calls. | -| schema | Returns the schema describing the output of this DataFrame in terms of columns returned, where each column has a name, data type, and nullability attribute. | -| to_logical_plan | Return the optimized logical plan represented by this DataFrame. | -| to_unoptimized_plan | Return the unoptimized logical plan represented by this DataFrame. | +[pandas dataframe]: https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html +[spark dataframe]: https://spark.apache.org/docs/latest/sql-programming-guide.html +[`sessioncontext`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html +[`read_csv`]: https://docs.rs/datafusion/latest/datafusion/execution/context/struct.SessionContext.html#method.read_csv +[`filter`]: https://docs.rs/datafusion/latest/datafusion/dataframe/struct.DataFrame.html#method.filter +[`select`]: https://docs.rs/datafusion/latest/datafusion/dataframe/struct.DataFrame.html#method.select +[`aggregate`]: https://docs.rs/datafusion/latest/datafusion/dataframe/struct.DataFrame.html#method.aggregate +[`limit`]: https://docs.rs/datafusion/latest/datafusion/dataframe/struct.DataFrame.html#method.limit +[`collect`]: https://docs.rs/datafusion/latest/datafusion/dataframe/struct.DataFrame.html#method.collect +[library users guide]: ../library-user-guide/using-the-dataframe-api.md +[api reference on docs.rs]: https://docs.rs/datafusion/latest/datafusion/dataframe/struct.DataFrame.html +[expressions reference]: expressions diff --git a/docs/source/user-guide/example-usage.md b/docs/source/user-guide/example-usage.md index 2fb4e55d698d..6108315f398a 100644 --- a/docs/source/user-guide/example-usage.md +++ b/docs/source/user-guide/example-usage.md @@ -30,33 +30,10 @@ crates.io] page. Add the dependency to your `Cargo.toml` file: ```toml datafusion = "latest_version" -tokio = "1.0" +tokio = { version = "1.0", features = ["rt-multi-thread"] } ``` -## Add latest non published DataFusion dependency - -DataFusion changes are published to `crates.io` according to [release schedule](https://github.com/apache/datafusion/blob/main/dev/release/README.md#release-process) -In case if it is required to test out DataFusion changes which are merged but yet to be published, Cargo supports adding dependency directly to Github branch - -```toml -datafusion = { git = "https://github.com/apache/datafusion", branch = "main"} -``` - -Also it works on the package level - -```toml -datafusion-common = { git = "https://github.com/apache/datafusion", branch = "main", package = "datafusion-common"} -``` - -And with features - -```toml -datafusion = { git = "https://github.com/apache/datafusion", branch = "main", default-features = false, features = ["unicode_expressions"] } -``` - -More on [Cargo dependencies](https://doc.rust-lang.org/cargo/reference/specifying-dependencies.html#specifying-dependencies) - -## Run a SQL query against data stored in a CSV: +## Run a SQL query against data stored in a CSV ```rust use datafusion::prelude::*; @@ -76,10 +53,14 @@ async fn main() -> datafusion::error::Result<()> { } ``` -## Use the DataFrame API to process data stored in a CSV: +See [the SQL API](../library-user-guide/using-the-sql-api.md) section of the +library guide for more information on the SQL API. + +## Use the DataFrame API to process data stored in a CSV ```rust use datafusion::prelude::*; +use datafusion::functions_aggregate::expr_fn::min; #[tokio::main] async fn main() -> datafusion::error::Result<()> { @@ -168,6 +149,7 @@ async fn main() -> datafusion::error::Result<()> { ```rust use datafusion::prelude::*; +use datafusion::functions_aggregate::expr_fn::min; #[tokio::main] async fn main() -> datafusion::error::Result<()> { @@ -198,109 +180,3 @@ async fn main() -> datafusion::error::Result<()> { | 1 | 2 | +---+--------+ ``` - -## Extensibility - -DataFusion is designed to be extensible at all points. To that end, you can provide your own custom: - -- [x] User Defined Functions (UDFs) -- [x] User Defined Aggregate Functions (UDAFs) -- [x] User Defined Table Source (`TableProvider`) for tables -- [x] User Defined `Optimizer` passes (plan rewrites) -- [x] User Defined `LogicalPlan` nodes -- [x] User Defined `ExecutionPlan` nodes - -## Optimized Configuration - -For an optimized build several steps are required. First, use the below in your `Cargo.toml`. It is -worth noting that using the settings in the `[profile.release]` section will significantly increase the build time. - -```toml -[dependencies] -datafusion = { version = "22.0" } -tokio = { version = "^1.0", features = ["rt-multi-thread"] } -snmalloc-rs = "0.3" - -[profile.release] -lto = true -codegen-units = 1 -``` - -Then, in `main.rs.` update the memory allocator with the below after your imports: - -```rust ,ignore -use datafusion::prelude::*; - -#[global_allocator] -static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc; - -#[tokio::main] -async fn main() -> datafusion::error::Result<()> { - Ok(()) -} -``` - -Based on the instruction set architecture you are building on you will want to configure the `target-cpu` as well, ideally -with `native` or at least `avx2`. - -```shell -RUSTFLAGS='-C target-cpu=native' cargo run --release -``` - -## Enable backtraces - -By default Datafusion returns errors as a plain message. There is option to enable more verbose details about the error, -like error backtrace. To enable a backtrace you need to add Datafusion `backtrace` feature to your `Cargo.toml` file: - -```toml -datafusion = { version = "31.0.0", features = ["backtrace"]} -``` - -Set environment [variables](https://doc.rust-lang.org/std/backtrace/index.html#environment-variables) - -```bash -RUST_BACKTRACE=1 ./target/debug/datafusion-cli -DataFusion CLI v31.0.0 -> select row_numer() over (partition by a order by a) from (select 1 a); -Error during planning: Invalid function 'row_numer'. -Did you mean 'ROW_NUMBER'? - -backtrace: 0: std::backtrace_rs::backtrace::libunwind::trace - at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/std/src/../../backtrace/src/backtrace/libunwind.rs:93:5 - 1: std::backtrace_rs::backtrace::trace_unsynchronized - at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/std/src/../../backtrace/src/backtrace/mod.rs:66:5 - 2: std::backtrace::Backtrace::create - at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/std/src/backtrace.rs:332:13 - 3: std::backtrace::Backtrace::capture - at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/std/src/backtrace.rs:298:9 - 4: datafusion_common::error::DataFusionError::get_back_trace - at /datafusion/datafusion/common/src/error.rs:436:30 - 5: datafusion_sql::expr::function::>::sql_function_to_expr - ............ -``` - -The backtraces are useful when debugging code. If there is a test in `datafusion/core/src/physical_planner.rs` - -``` -#[tokio::test] -async fn test_get_backtrace_for_failed_code() -> Result<()> { - let ctx = SessionContext::new(); - - let sql = " - select row_numer() over (partition by a order by a) from (select 1 a); - "; - - let _ = ctx.sql(sql).await?.collect().await?; - - Ok(()) -} -``` - -To obtain a backtrace: - -```bash -cargo build --features=backtrace -RUST_BACKTRACE=1 cargo test --features=backtrace --package datafusion --lib -- physical_planner::tests::test_get_backtrace_for_failed_code --exact --nocapture -``` - -Note: The backtrace wrapped into systems calls, so some steps on top of the backtrace can be ignored diff --git a/docs/source/user-guide/explain-usage.md b/docs/source/user-guide/explain-usage.md new file mode 100644 index 000000000000..2eb03aad2ef9 --- /dev/null +++ b/docs/source/user-guide/explain-usage.md @@ -0,0 +1,382 @@ + + +# Reading Explain Plans + +## Introduction + +This section describes of how to read a DataFusion query plan. While fully +comprehending all details of these plans requires significant expertise in the +DataFusion engine, this guide will help you get started with the basics. + +Datafusion executes queries using a `query plan`. To see the plan without +running the query, add the keyword `EXPLAIN` to your SQL query or call the +[DataFrame::explain] method + +[dataframe::explain]: https://docs.rs/datafusion/latest/datafusion/dataframe/struct.DataFrame.html#method.explain + +## Example: Select and filter + +In this section, we run example queries against the `hits.parquet` file. See +[below](#data-in-this-example)) for information on how to get this file. + +Let's see how DataFusion runs a query that selects the top 5 watch lists for the +site `http://domcheloveplanet.ru/`: + +```sql +EXPLAIN SELECT "WatchID" AS wid, "hits.parquet"."ClientIP" AS ip +FROM 'hits.parquet' +WHERE starts_with("URL", 'http://domcheloveplanet.ru/') +ORDER BY wid ASC, ip DESC +LIMIT 5; +``` + +The output will look like + +``` ++---------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| plan_type | plan | ++---------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| logical_plan | Sort: wid ASC NULLS LAST, ip DESC NULLS FIRST, fetch=5 | +| | Projection: hits.parquet.WatchID AS wid, hits.parquet.ClientIP AS ip | +| | Filter: starts_with(hits.parquet.URL, Utf8("http://domcheloveplanet.ru/")) | +| | TableScan: hits.parquet projection=[WatchID, ClientIP, URL], partial_filters=[starts_with(hits.parquet.URL, Utf8("http://domcheloveplanet.ru/"))] | +| physical_plan | SortPreservingMergeExec: [wid@0 ASC NULLS LAST,ip@1 DESC], fetch=5 | +| | SortExec: TopK(fetch=5), expr=[wid@0 ASC NULLS LAST,ip@1 DESC], preserve_partitioning=[true] | +| | ProjectionExec: expr=[WatchID@0 as wid, ClientIP@1 as ip] | +| | CoalesceBatchesExec: target_batch_size=8192 | +| | FilterExec: starts_with(URL@2, http://domcheloveplanet.ru/) | +| | ParquetExec: file_groups={16 groups: [[hits.parquet:0..923748528], ...]}, projection=[WatchID, ClientIP, URL], predicate=starts_with(URL@13, http://domcheloveplanet.ru/) | ++---------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +2 row(s) fetched. +Elapsed 0.060 seconds. +``` + +There are two sections: logical plan and physical plan + +- **Logical Plan:** is a plan generated for a specific SQL query, DataFrame, or other language without the + knowledge of the underlying data organization. +- **Physical Plan:** is a plan generated from a logical plan along with consideration of the hardware + configuration (e.g number of CPUs) and the underlying data organization (e.g number of files). + This physical plan is specific to your hardware configuration and your data. If you load the same + data to different hardware with different configurations, the same query may generate different query plans. + +Understanding a query plan can help to you understand its performance. For example, when the plan shows your query reads +many files, it signals you to either add more filter in the query to read less data or to modify your file +design to make fewer but larger files. This document focuses on how to read a query plan. How to make a +query run faster depends on the reason it is slow and beyond the scope of this document. + +## Query plans are trees + +A query plan is an upside down tree, and we always read from bottom up. The +physical plan in Figure 1 in tree format will look like + +``` + ▲ + │ + │ +┌─────────────────────────────────────────────────┐ +│ SortPreservingMergeExec │ +│ [wid@0 ASC NULLS LAST,ip@1 DESC] │ +│ fetch=5 │ +└─────────────────────────────────────────────────┘ + ▲ + │ +┌─────────────────────────────────────────────────┐ +│ SortExec TopK(fetch=5), │ +│ expr=[wid@0 ASC NULLS LAST,ip@1 DESC], │ +│ preserve_partitioning=[true] │ +└─────────────────────────────────────────────────┘ + ▲ + │ +┌─────────────────────────────────────────────────┐ +│ ProjectionExec │ +│ expr=[WatchID@0 as wid, ClientIP@1 as ip] │ +└─────────────────────────────────────────────────┘ + ▲ + │ +┌─────────────────────────────────────────────────┐ +│ CoalesceBatchesExec │ +└─────────────────────────────────────────────────┘ + ▲ + │ +┌─────────────────────────────────────────────────┐ +│ FilterExec │ +│ starts_with(URL@2, http://domcheloveplanet.ru/) │ +└─────────────────────────────────────────────────┘ + ▲ + │ +┌────────────────────────────────────────────────┐ +│ ParquetExec │ +│ hits.parquet (filter = ...) │ +└────────────────────────────────────────────────┘ +``` + +Each node in the tree/plan ends with `Exec` and is sometimes also called an `operator` or `ExecutionPlan` where data is +processed, transformed and sent up. + +1. First, data in parquet the `hits.parquet` file us read in parallel using 16 cores in 16 "partitions" (more on this later) from `ParquetExec`, which applies a first pass at filtering during the scan. +2. Next, the output is filtered using `FilterExec` to ensure only rows where `starts_with(URL, 'http://domcheloveplanet.ru/')` evaluates to true are passed on +3. The `CoalesceBatchesExec` then ensures that the data is grouped into larger batches for processing +4. The `ProjectionExec` then projects the data to rename the `WatchID` and `ClientIP` columns to `wid` and `ip` respectively. +5. The `SortExec` then sorts the data by `wid ASC, ip DESC`. The `Topk(fetch=5)` indicates that a special implementation is used that only tracks and emits the top 5 values in each partition. +6. Finally the `SortPreservingMergeExec` merges the sorted data from all partitions and returns the top 5 rows overall. + +## Understanding large query plans + +A large query plan may look intimidating, but you can quickly understand what it does by following these steps + +1. As always, read from bottom up, one operator at a time. +2. Understand the job of this operator by reading + the [Physical Plan documentation](https://docs.rs/datafusion/latest/datafusion/physical_plan/index.html). +3. Understand the input data of the operator and how large/small it may be. +4. Understand how much data that operator produces and what it would look like. + +If you can answer those questions, you will be able to estimate how much work +that plan has to do and thus how long it will take. However, the `EXPLAIN` just +shows you the plan without executing it. + +If you want to know more about how much work each operator in query plan does, +you can use the `EXPLAIN ANALYZE` to get the explain with runtime added (see +next section) + +## More Debugging Information: `EXPLAIN VERBOSE` + +If the plan has to read too many files, not all of them will be shown in the +`EXPLAIN`. To see them, use `EXPLAIN VEBOSE`. Like `EXPLAIN`, `EXPLAIN VERBOSE` +does not run the query. Instead it shows the full explain plan, with information +that is omitted from the default explain, as well as all intermediate physical +plans DataFusion generates before returning. This mode can be very helpful for +debugging to see why and when DataFusion added and removed operators from a plan. + +## Execution Counters: `EXPLAIN ANALYZE` + +During execution, DataFusion operators collect detailed metrics. You can access +them programmatically via [`ExecutionPlan::metrics`] as well as with the +`EXPLAIN ANALYZE` command. For example here is the same query query as +above but with `EXPLAIN ANALYZE` (note the output is edited for clarity) + +[`executionplan::metrics`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html#method.metrics + +``` +> EXPLAIN ANALYZE SELECT "WatchID" AS wid, "hits.parquet"."ClientIP" AS ip +FROM 'hits.parquet' +WHERE starts_with("URL", 'http://domcheloveplanet.ru/') +ORDER BY wid ASC, ip DESC +LIMIT 5; ++-------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| plan_type | plan | ++-------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| Plan with Metrics | SortPreservingMergeExec: [wid@0 ASC NULLS LAST,ip@1 DESC], fetch=5, metrics=[output_rows=5, elapsed_compute=2.375µs] | +| | SortExec: TopK(fetch=5), expr=[wid@0 ASC NULLS LAST,ip@1 DESC], preserve_partitioning=[true], metrics=[output_rows=75, elapsed_compute=7.243038ms, row_replacements=482] | +| | ProjectionExec: expr=[WatchID@0 as wid, ClientIP@1 as ip], metrics=[output_rows=811821, elapsed_compute=66.25µs] | +| | FilterExec: starts_with(URL@2, http://domcheloveplanet.ru/), metrics=[output_rows=811821, elapsed_compute=1.36923816s] | +| | ParquetExec: file_groups={16 groups: [[hits.parquet:0..923748528], ...]}, projection=[WatchID, ClientIP, URL], predicate=starts_with(URL@13, http://domcheloveplanet.ru/), metrics=[output_rows=99997497, elapsed_compute=16ns, ... bytes_scanned=3703192723, ... time_elapsed_opening=308.203002ms, time_elapsed_scanning_total=8.350342183s, ...] | ++-------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +1 row(s) fetched. +Elapsed 0.720 seconds. +``` + +In this case, DataFusion actually ran the query, but discarded any results, and +instead returned an annotated plan with a new field, `metrics=[...]` + +Most operators have the common metrics `output_rows` and `elapsed_compute` and +some have operator specific metrics such as `ParquetExec` which has +`bytes_scanned=3703192723`. Note that times and counters are reported across all +cores, so if you have 16 cores, the time reported is the sum of the time taken +by all 16 cores. + +Again, reading from bottom up: + +- `ParquetExec` + - `output_rows=99997497`: A total 99.9M rows were produced + - `bytes_scanned=3703192723`: Of the 14GB file, 3.7GB were actually read (due to projection pushdown) + - `time_elapsed_opening=308.203002ms`: It took 300ms to open the file and prepare to read it + - `time_elapsed_scanning_total=8.350342183s`: It took 8.3 seconds of CPU time (across 16 cores) to actually decode the parquet data +- `FilterExec` + - `output_rows=811821`: Of the 99.9M rows at its input, only 811K rows passed the filter and were produced at the output + - `elapsed_compute=1.36923816s`: In total, 1.36s of CPU time (across 16 cores) was spend evaluating the filter +- `CoalesceBatchesExec` + - `output_rows=811821`, `elapsed_compute=12.873379ms`: Produced 811K rows in 13ms +- `ProjectionExec` + - `output_rows=811821, elapsed_compute=66.25µs`: Produced 811K rows in 66µs (microseconds). This projection is almost instantaneous as it does not manipulate any data +- `SortExec` + - `output_rows=75`: Produced 75 rows in total. Each of 16 cores could produce up to 5 rows, but in this case not all cores did. + - `elapsed_compute=7.243038ms`: 7ms was used to determine the top 5 rows + - `row_replacements=482`: Internally, the TopK operator updated its top list 482 times +- `SortPreservingMergeExec` + - `output_rows=5`, `elapsed_compute=2.375µs`: Produced the final 5 rows in 2.375µs (microseconds) + +When predicate pushdown is enabled, `ParquetExec` gains the following metrics: + +- `page_index_rows_matched`: number of rows in pages that were tested by a page index filter, and passed +- `page_index_rows_pruned`: number of rows in pages that were tested by a page index filter, and did not pass +- `row_groups_matched_bloom_filter`: number of rows in row groups that were tested by a Bloom Filter, and passed +- `row_groups_pruned_bloom_filter`: number of rows in row groups that were tested by a Bloom Filter, and did not pass +- `row_groups_matched_statistics`: number of rows in row groups that were tested by row group statistics (min and max value), and passed +- `row_groups_pruned_statistics`: number of rows in row groups that were tested by row group statistics (min and max value), and did not pass +- `pushdown_rows_matched`: rows that were tested by any of the above filtered, and passed all of them (this should be minimum of `page_index_rows_matched`, `row_groups_pruned_bloom_filter`, and `row_groups_pruned_statistics`) +- `pushdown_rows_pruned`: rows that were tested by any of the above filtered, and did not pass one of them (this should be sum of `page_index_rows_matched`, `row_groups_pruned_bloom_filter`, and `row_groups_pruned_statistics`) +- `predicate_evaluation_errors`: number of times evaluating the filter expression failed (expected to be zero in normal operation) +- `num_predicate_creation_errors`: number of errors creating predicates (expected to be zero in normal operation) +- `bloom_filter_eval_time`: time spent parsing and evaluating Bloom Filters +- `statistics_eval_time`: time spent parsing and evaluating row group-level statistics +- `row_pushdown_eval_time`: time spent evaluating row-level filters +- `page_index_eval_time`: time required to evaluate the page index filters + +## Partitions and Execution + +DataFusion determines the optimal number of cores to use as part of query +planning. Roughly speaking, each "partition" in the plan is run independently using +a separate core. Data crosses between cores only within certain operators such as +`RepartitionExec`, `CoalescePartitions` and `SortPreservingMergeExec` + +You can read more about this in the [Partitoning Docs]. + +[partitoning docs]: https://docs.rs/datafusion/latest/datafusion/physical_expr/enum.Partitioning.html + +## Example of an Aggregate Query + +Let us delve into an example query that aggregates data from the `hits.parquet` +file. For example, this query from ClickBench finds the top 10 users by their +number of hits: + +```sql +SELECT "UserID", COUNT(*) +FROM 'hits.parquet' +GROUP BY "UserID" +ORDER BY COUNT(*) DESC +LIMIT 10; +``` + +We can again see the query plan by using `EXPLAIN`: + +``` +> EXPLAIN SELECT "UserID", COUNT(*) FROM 'hits.parquet' GROUP BY "UserID" ORDER BY COUNT(*) DESC LIMIT 10; ++---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| plan_type | plan | ++---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| logical_plan | Limit: skip=0, fetch=10 | +| | Sort: count(*) DESC NULLS FIRST, fetch=10 | +| | Aggregate: groupBy=[[hits.parquet.UserID]], aggr=[[count(Int64(1)) AS count(*)]] | +| | TableScan: hits.parquet projection=[UserID] | +| physical_plan | GlobalLimitExec: skip=0, fetch=10 | +| | SortPreservingMergeExec: [count(*)@1 DESC], fetch=10 | +| | SortExec: TopK(fetch=10), expr=[count(*)@1 DESC], preserve_partitioning=[true] | +| | AggregateExec: mode=FinalPartitioned, gby=[UserID@0 as UserID], aggr=[count(*)] | +| | CoalesceBatchesExec: target_batch_size=8192 | +| | RepartitionExec: partitioning=Hash([UserID@0], 10), input_partitions=10 | +| | AggregateExec: mode=Partial, gby=[UserID@0 as UserID], aggr=[count(*)] | +| | ParquetExec: file_groups={10 groups: [[hits.parquet:0..1477997645], [hits.parquet:1477997645..2955995290], [hits.parquet:2955995290..4433992935], [hits.parquet:4433992935..5911990580], [hits.parquet:5911990580..7389988225], ...]}, projection=[UserID] | +| | | ++---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +``` + +For this query, let's again read the plan from the bottom to the top: + +**Logical plan operators** + +- `TableScan` + - `hits.parquet`: Scans data from the file `hits.parquet`. + - `projection=[UserID]`: Reads only the `UserID` column +- `Aggregate` + - `groupBy=[[hits.parquet.UserID]]`: Groups by `UserID` column. + - `aggr=[[count(Int64(1)) AS count(*)]]`: Applies the `COUNT` aggregate on each distinct group. +- `Sort` + - `count(*) DESC NULLS FIRST`: Sorts the data in descending count order. + - `fetch=10`: Returns only the first 10 rows. +- `Limit` + - `skip=0`: Does not skip any data for the results. + - `fetch=10`: Limits the results to 10 values. + +**Physical plan operators** + +- `ParquetExec` + - `file_groups={10 groups: [...]}`: Reads 10 groups in parallel from `hits.parquet`file. (The example above was run on a machine with 10 cores.) + - `projection=[UserID]`: Pushes down projection of the `UserID` column. The parquet format is columnar and the DataFusion reader only decodes the columns required. +- `AggregateExec` + - `mode=Partial` Runs a [partial aggregation] in parallel across each of the 10 partitions from the `ParquetExec` immediately after reading. + - `gby=[UserID@0 as UserID]`: Represents `GROUP BY` in the [physical plan] and groups together the same values of `UserID`. + - `aggr=[count(*)]`: Applies the `COUNT` aggregate on all rows for each group. +- `RepartitionExec` + - `partitioning=Hash([UserID@0], 10)`: Divides the input into into 10 (new) output partitions based on the value of `hash(UserID)`. You can read more about this in the [partitioning] documentation. + - `input_partitions=10`: Number of input partitions. +- `CoalesceBatchesExec` + - `target_batch_size=8192`: Combines smaller batches in to larger batches. In this case approximately 8192 rows in each batch. +- `AggregateExec` + - `mode=FinalPartitioned`: Performs the final aggregation on each group. See the [documentation on multi phase grouping] for more information. + - `gby=[UserID@0 as UserID]`: Groups by `UserID`. + - `aggr=[count(*)]`: Applies the `COUNT` aggregate on all rows for each group. +- `SortExec` + - `TopK(fetch=10)`: Use a special "TopK" sort that keeps only the largest 10 values in memory at a time. You can read more about this in the [TopK] documentation. + - `expr=[count(*)@1 DESC]`: Sorts all rows in descending order. Note this represents the `ORDER BY` in the physical plan. + - `preserve_partitioning=[true]`: The sort is done in parallel on each partition. In this case the top 10 values are found for each of the 10 partitions, in parallel. +- `SortPreservingMergeExec` + - `[count(*)@1 DESC]`: This operator merges the 10 distinct streams into a single stream using this expression. + - `fetch=10`: Returns only the first 10 rows +- `GlobalLimitExec` + - `skip=0`: Does not skip any rows + - `fetch=10`: Returns only the first 10 rows, denoted by `LIMIT 10` in the query. + +[partial aggregation]: https://docs.rs/datafusion/latest/datafusion/physical_plan/aggregates/enum.AggregateMode.html#variant.Partial +[physical plan]: https://docs.rs/datafusion/latest/datafusion/physical_plan/aggregates/struct.PhysicalGroupBy.html +[partitioning]: https://docs.rs/datafusion/latest/datafusion/physical_plan/repartition/struct.RepartitionExec.html +[topk]: https://docs.rs/datafusion/latest/datafusion/physical_plan/struct.TopK.html +[documentation on multi phase grouping]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.Accumulator.html#tymethod.state + +### Data in this Example + +The examples in this section use data from [ClickBench], a benchmark for data +analytics. The examples are in terms of the 14GB [`hits.parquet`] file and can be +downloaded from the website or using the following commands: + +```shell +cd benchmarks +./bench.sh data clickbench_1 +*************************** +DataFusion Benchmark Runner and Data Generator +COMMAND: data +BENCHMARK: clickbench_1 +DATA_DIR: /Users/andrewlamb/Software/datafusion2/benchmarks/data +CARGO_COMMAND: cargo run --release +PREFER_HASH_JOIN: true +*************************** +Checking hits.parquet...... found 14779976446 bytes ... Done +``` + +Then you can run `datafusion-cli` to get plans: + +```shell +cd datafusion/benchmarks/data +datafusion-cli + +DataFusion CLI v41.0.0 +> select count(*) from 'hits.parquet'; ++----------+ +| count(*) | ++----------+ +| 99997497 | ++----------+ +1 row(s) fetched. +Elapsed 0.062 seconds. +> +``` + +[clickbench]: https://benchmark.clickhouse.com/ +[`hits.parquet`]: https://datasets.clickhouse.com/hits_compatible/hits.parquet diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index a5fc13491677..ababb001f5c5 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -26,8 +26,10 @@ available for creating logical expressions. These are documented below. Most functions and methods may receive and return an `Expr`, which can be chained together using a fluent-style API: ```rust +use datafusion::prelude::*; // create the expression `(a > 6) AND (b < 7)` -col("a").gt(lit(6)).and(col("b").lt(lit(7))) +col("a").gt(lit(6)).and(col("b").lt(lit(7))); + ``` ::: @@ -67,7 +69,7 @@ value ::: :::{note} -Since `&&` and `||` are existed as logical operators in Rust, but those are not overloadable and not works with expression API. +Since `&&` and `||` are logical operators in Rust and cannot be overloaded these are not available in the expression API. ::: ## Bitwise Expressions @@ -149,7 +151,7 @@ but these operators always return a `bool` which makes them not work with the ex | trunc(x) | truncate toward zero | :::{note} -Unlike to some databases the math functions in Datafusion works the same way as Rust math functions, avoiding failing on corner cases e.g +Unlike to some databases the math functions in Datafusion works the same way as Rust math functions, avoiding failing on corner cases e.g. ```sql select log(-1), log(0), sqrt(-1); @@ -209,6 +211,7 @@ select log(-1), log(0), sqrt(-1); | Syntax | Description | | ---------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| array_any_value(array) | Returns the first non-null element in the array. `array_any_value([NULL, 1, 2, 3]) -> 1` | | array_append(array, element) | Appends an element to the end of an array. `array_append([1, 2, 3], 4) -> [1, 2, 3, 4]` | | array_concat(array[, ..., array_n]) | Concatenates arrays. `array_concat([1, 2, 3], [4, 5, 6]) -> [1, 2, 3, 4, 5, 6]` | | array_has(array, element) | Returns true if the array contains the element `array_has([1,2,3], 1) -> true` | @@ -241,7 +244,7 @@ select log(-1), log(0), sqrt(-1); | array_except(array1, array2) | Returns an array of the elements that appear in the first array but not in the second. `array_except([1, 2, 3, 4], [5, 6, 3, 4]) -> [1, 2]` | | array_resize(array, size, value) | Resizes the list to contain size elements. Initializes new elements with value or empty if value is not set. `array_resize([1, 2, 3], 5, 0) -> [1, 2, 3, 0, 0]` | | array_sort(array, desc, null_first) | Returns sorted array. `array_sort([3, 1, 2, 5, 4]) -> [1, 2, 3, 4, 5]` | -| cardinality(array) | Returns the total number of elements in the array. `cardinality([[1, 2, 3], [4, 5, 6]]) -> 6` | +| cardinality(array/map) | Returns the total number of elements in the array or map. `cardinality([[1, 2, 3], [4, 5, 6]]) -> 6` | | make_array(value1, [value2 [, ...]]) | Returns an Arrow array using the specified input expressions. `make_array(1, 2, 3) -> [1, 2, 3]` | | range(start [, stop, step]) | Returns an Arrow array between start and stop with step. `SELECT range(2, 10, 3) -> [2, 5, 8]` | | string_to_array(array, delimiter, null_string) | Splits a `string` based on a `delimiter` and returns an array of parts. Any parts matching the optional `null_string` will be replaced with `NULL`. `string_to_array('abc#def#ghi', '#', ' ') -> ['abc', 'def', 'ghi']` | @@ -304,6 +307,16 @@ select log(-1), log(0), sqrt(-1); | rollup(exprs) | Creates a grouping set for rollup sets. | | sum(expr) | Сalculates the sum of `expr`. | +## Aggregate Function Builder + +You can also use the `ExprFunctionExt` trait to more easily build Aggregate arguments `Expr`. + +See `datafusion-examples/examples/expr_api.rs` for example usage. + +| Syntax | Equivalent to | +| ----------------------------------------------------------------------- | ----------------------------------- | +| first_value_udaf.call(vec![expr]).order_by(vec![expr]).build().unwrap() | first_value(expr, Some(vec![expr])) | + ## Subquery Expressions | Syntax | Description | diff --git a/docs/source/user-guide/introduction.md b/docs/source/user-guide/introduction.md index 676543b04034..7c975055d152 100644 --- a/docs/source/user-guide/introduction.md +++ b/docs/source/user-guide/introduction.md @@ -22,7 +22,7 @@ DataFusion is a very fast, extensible query engine for building high-quality data-centric systems in [Rust](http://rustlang.org), using the [Apache Arrow](https://arrow.apache.org) in-memory format. -DataFusion is part of the [Apache Arrow](https://arrow.apache.org/) +DataFusion originated as part of the [Apache Arrow](https://arrow.apache.org/) project. DataFusion offers SQL and Dataframe APIs, excellent [performance](https://benchmark.clickhouse.com/), built-in support for CSV, Parquet, JSON, and Avro, [python bindings], extensive customization, a great community, and more. @@ -96,25 +96,26 @@ Here are some active projects using DataFusion: - [Arroyo](https://github.com/ArroyoSystems/arroyo) Distributed stream processing engine in Rust - [Ballista](https://github.com/apache/datafusion-ballista) Distributed SQL Query Engine -- [Comet](https://github.com/apache/datafusion-comet) Apache Spark native query execution plugin +- [Blaze](https://github.com/kwai/blaze) The Blaze accelerator for Apache Spark leverages native vectorized execution to accelerate query processing - [CnosDB](https://github.com/cnosdb/cnosdb) Open Source Distributed Time Series Database +- [Comet](https://github.com/apache/datafusion-comet) Apache Spark native query execution plugin - [Cube Store](https://github.com/cube-js/cube.js/tree/master/rust) - [Dask SQL](https://github.com/dask-contrib/dask-sql) Distributed SQL query engine in Python -- [Exon](https://github.com/wheretrue/exon) Analysis toolkit for life-science applications - [delta-rs](https://github.com/delta-io/delta-rs) Native Rust implementation of Delta Lake -- [GreptimeDB](https://github.com/GreptimeTeam/greptimedb) Open Source & Cloud Native Distributed Time Series Database +- [Exon](https://github.com/wheretrue/exon) Analysis toolkit for life-science applications - [GlareDB](https://github.com/GlareDB/glaredb) Fast SQL database for querying and analyzing distributed data. +- [GreptimeDB](https://github.com/GreptimeTeam/greptimedb) Open Source & Cloud Native Distributed Time Series Database - [HoraeDB](https://github.com/apache/incubator-horaedb) Distributed Time-Series Database - [InfluxDB](https://github.com/influxdata/influxdb) Time Series Database - [Kamu](https://github.com/kamu-data/kamu-cli/) Planet-scale streaming data pipeline - [LakeSoul](https://github.com/lakesoul-io/LakeSoul) Open source LakeHouse framework with native IO in Rust. - [Lance](https://github.com/lancedb/lance) Modern columnar data format for ML -- [Parseable](https://github.com/parseablehq/parseable) Log storage and observability platform - [ParadeDB](https://github.com/paradedb/paradedb) PostgreSQL for Search & Analytics +- [Parseable](https://github.com/parseablehq/parseable) Log storage and observability platform - [qv](https://github.com/timvw/qv) Quickly view your data -- [bdt](https://github.com/andygrove/bdt) Boring Data Tool - [Restate](https://github.com/restatedev) Easily build resilient applications using distributed durable async/await - [ROAPI](https://github.com/roapi/roapi) +- [Sail](https://github.com/lakehq/sail) Unifying stream, batch, and AI workloads with Apache Spark compatibility - [Seafowl](https://github.com/splitgraph/seafowl) CDN-friendly analytical database - [Spice.ai](https://github.com/spiceai/spiceai) Unified SQL query interface & materialization engine - [Synnada](https://synnada.ai/) Streaming-first framework for data products @@ -123,7 +124,7 @@ Here are some active projects using DataFusion: Here are some less active projects that used DataFusion: -- [Blaze](https://github.com/blaze-init/blaze) Spark accelerator with DataFusion at its core +- [bdt](https://github.com/datafusion-contrib/bdt) Boring Data Tool - [Cloudfuse Buzz](https://github.com/cloudfuse-io/buzz-rust) - [datafusion-tui](https://github.com/datafusion-contrib/datafusion-tui) Text UI for DataFusion - [Flock](https://github.com/flock-lab/flock) diff --git a/docs/source/user-guide/sql/aggregate_functions.md b/docs/source/user-guide/sql/aggregate_functions.md index 427a7bf130a7..d9fc28a81772 100644 --- a/docs/source/user-guide/sql/aggregate_functions.md +++ b/docs/source/user-guide/sql/aggregate_functions.md @@ -17,12 +17,21 @@ under the License. --> + + # Aggregate Functions Aggregate functions operate on a set of values to compute a single result. -## General +## General Functions +- [array_agg](#array_agg) - [avg](#avg) - [bit_and](#bit_and) - [bit_or](#bit_or) @@ -30,14 +39,43 @@ Aggregate functions operate on a set of values to compute a single result. - [bool_and](#bool_and) - [bool_or](#bool_or) - [count](#count) +- [first_value](#first_value) +- [grouping](#grouping) +- [last_value](#last_value) - [max](#max) - [mean](#mean) - [median](#median) - [min](#min) +- [string_agg](#string_agg) - [sum](#sum) -- [array_agg](#array_agg) -- [first_value](#first_value) -- [last_value](#last_value) +- [var](#var) +- [var_pop](#var_pop) +- [var_population](#var_population) +- [var_samp](#var_samp) +- [var_sample](#var_sample) + +### `array_agg` + +Returns an array created from the expression elements. If ordering is required, elements are inserted in the specified order. + +``` +array_agg(expression [ORDER BY expression]) +``` + +#### Arguments + +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT array_agg(column_name ORDER BY other_column) FROM table_name; ++-----------------------------------------------+ +| array_agg(column_name ORDER BY other_column) | ++-----------------------------------------------+ +| [element1, element2, element3] | ++-----------------------------------------------+ +``` ### `avg` @@ -49,12 +87,22 @@ avg(expression) #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT avg(column_name) FROM table_name; ++---------------------------+ +| avg(column_name) | ++---------------------------+ +| 42.75 | ++---------------------------+ +``` #### Aliases -- `mean` +- mean ### `bit_and` @@ -66,8 +114,7 @@ bit_and(expression) #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression**: Integer expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `bit_or` @@ -79,8 +126,7 @@ bit_or(expression) #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression**: Integer expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `bit_xor` @@ -92,8 +138,7 @@ bit_xor(expression) #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression**: Integer expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `bool_and` @@ -105,29 +150,45 @@ bool_and(expression) #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT bool_and(column_name) FROM table_name; ++----------------------------+ +| bool_and(column_name) | ++----------------------------+ +| true | ++----------------------------+ +``` ### `bool_or` -Returns true if any non-null input value is true, otherwise false. +Returns true if all non-null input values are true, otherwise false. ``` -bool_or(expression) +bool_and(expression) ``` #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. -### `count` +#### Example -Returns the number of rows in the specified column. +```sql +> SELECT bool_and(column_name) FROM table_name; ++----------------------------+ +| bool_and(column_name) | ++----------------------------+ +| true | ++----------------------------+ +``` + +### `count` -Count includes _null_ values in the total count. -To exclude _null_ values from the total count, include ` IS NOT NULL` -in the `WHERE` clause. +Returns the number of non-null values in the specified column. To include null values in the total count, use `count(*)`. ``` count(expression) @@ -135,8 +196,98 @@ count(expression) #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT count(column_name) FROM table_name; ++-----------------------+ +| count(column_name) | ++-----------------------+ +| 100 | ++-----------------------+ + +> SELECT count(*) FROM table_name; ++------------------+ +| count(*) | ++------------------+ +| 120 | ++------------------+ +``` + +### `first_value` + +Returns the first element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group. + +``` +first_value(expression [ORDER BY expression]) +``` + +#### Arguments + +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT first_value(column_name ORDER BY other_column) FROM table_name; ++-----------------------------------------------+ +| first_value(column_name ORDER BY other_column)| ++-----------------------------------------------+ +| first_element | ++-----------------------------------------------+ +``` + +### `grouping` + +Returns 1 if the data is aggregated across the specified column, or 0 if it is not aggregated in the result set. + +``` +grouping(expression) +``` + +#### Arguments + +- **expression**: Expression to evaluate whether data is aggregated across the specified column. Can be a constant, column, or function. + +#### Example + +```sql +> SELECT column_name, GROUPING(column_name) AS group_column + FROM table_name + GROUP BY GROUPING SETS ((column_name), ()); ++-------------+-------------+ +| column_name | group_column | ++-------------+-------------+ +| value1 | 0 | +| value2 | 0 | +| NULL | 1 | ++-------------+-------------+ +``` + +### `last_value` + +Returns the first element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group. + +``` +first_value(expression [ORDER BY expression]) +``` + +#### Arguments + +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT first_value(column_name ORDER BY other_column) FROM table_name; ++-----------------------------------------------+ +| first_value(column_name ORDER BY other_column)| ++-----------------------------------------------+ +| first_element | ++-----------------------------------------------+ +``` ### `max` @@ -148,8 +299,18 @@ max(expression) #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT max(column_name) FROM table_name; ++----------------------+ +| max(column_name) | ++----------------------+ +| 150 | ++----------------------+ +``` ### `mean` @@ -165,86 +326,142 @@ median(expression) #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT median(column_name) FROM table_name; ++----------------------+ +| median(column_name) | ++----------------------+ +| 45.5 | ++----------------------+ +``` ### `min` -Returns the minimum value in the specified column. +Returns the maximum value in the specified column. ``` -min(expression) +max(expression) ``` #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. -### `sum` +#### Example -Returns the sum of all values in the specified column. +```sql +> SELECT max(column_name) FROM table_name; ++----------------------+ +| max(column_name) | ++----------------------+ +| 150 | ++----------------------+ +``` + +### `string_agg` + +Concatenates the values of string expressions and places separator values between them. ``` -sum(expression) +string_agg(expression, delimiter) ``` #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression**: The string expression to concatenate. Can be a column or any valid string expression. +- **delimiter**: A literal string used as a separator between the concatenated values. -### `array_agg` +#### Example + +```sql +> SELECT string_agg(name, ', ') AS names_list + FROM employee; ++--------------------------+ +| names_list | ++--------------------------+ +| Alice, Bob, Charlie | ++--------------------------+ +``` -Returns an array created from the expression elements. If ordering requirement is given, elements are inserted in the order of required ordering. +### `sum` + +Returns the sum of all values in the specified column. ``` -array_agg(expression [ORDER BY expression]) +sum(expression) ``` #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. -### `first_value` +#### Example -Returns the first element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group. +```sql +> SELECT sum(column_name) FROM table_name; ++-----------------------+ +| sum(column_name) | ++-----------------------+ +| 12345 | ++-----------------------+ +``` + +### `var` + +Returns the statistical sample variance of a set of numbers. ``` -first_value(expression [ORDER BY expression]) +var(expression) ``` #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. -### `last_value` +#### Aliases + +- var_sample +- var_samp -Returns the last element in an aggregation group according to the requested ordering. If no ordering is given, returns an arbitrary element from the group. +### `var_pop` + +Returns the statistical population variance of a set of numbers. ``` -last_value(expression [ORDER BY expression]) +var_pop(expression) ``` #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. -## Statistical +#### Aliases + +- var_population + +### `var_population` + +_Alias of [var_pop](#var_pop)._ + +### `var_samp` + +_Alias of [var](#var)._ + +### `var_sample` + +_Alias of [var](#var)._ + +## Statistical Functions - [corr](#corr) - [covar](#covar) - [covar_pop](#covar_pop) - [covar_samp](#covar_samp) -- [stddev](#stddev) -- [stddev_pop](#stddev_pop) -- [stddev_samp](#stddev_samp) -- [var](#var) -- [var_pop](#var_pop) -- [var_samp](#var_samp) +- [nth_value](#nth_value) - [regr_avgx](#regr_avgx) - [regr_avgy](#regr_avgy) - [regr_count](#regr_count) @@ -252,8 +469,11 @@ last_value(expression [ORDER BY expression]) - [regr_r2](#regr_r2) - [regr_slope](#regr_slope) - [regr_sxx](#regr_sxx) -- [regr_syy](#regr_syy) - [regr_sxy](#regr_sxy) +- [regr_syy](#regr_syy) +- [stddev](#stddev) +- [stddev_pop](#stddev_pop) +- [stddev_samp](#stddev_samp) ### `corr` @@ -265,40 +485,47 @@ corr(expression1, expression2) #### Arguments -- **expression1**: First expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression2**: Second expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - -### `covar` +- **expression1**: First expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression2**: Second expression to operate on. Can be a constant, column, or function, and any combination of operators. -Returns the covariance of a set of number pairs. +#### Example -``` -covar(expression1, expression2) +```sql +> SELECT corr(column1, column2) FROM table_name; ++--------------------------------+ +| corr(column1, column2) | ++--------------------------------+ +| 0.85 | ++--------------------------------+ ``` -#### Arguments +### `covar` -- **expression1**: First expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression2**: Second expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +_Alias of [covar_samp](#covar_samp)._ ### `covar_pop` -Returns the population covariance of a set of number pairs. +Returns the sample covariance of a set of number pairs. ``` -covar_pop(expression1, expression2) +covar_samp(expression1, expression2) ``` #### Arguments -- **expression1**: First expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression2**: Second expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression1**: First expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression2**: Second expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT covar_samp(column1, column2) FROM table_name; ++-----------------------------------+ +| covar_samp(column1, column2) | ++-----------------------------------+ +| 8.25 | ++-----------------------------------+ +``` ### `covar_samp` @@ -310,226 +537,225 @@ covar_samp(expression1, expression2) #### Arguments -- **expression1**: First expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression2**: Second expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression1**: First expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression2**: Second expression to operate on. Can be a constant, column, or function, and any combination of operators. -### `stddev` - -Returns the standard deviation of a set of numbers. +#### Example -``` -stddev(expression) +```sql +> SELECT covar_samp(column1, column2) FROM table_name; ++-----------------------------------+ +| covar_samp(column1, column2) | ++-----------------------------------+ +| 8.25 | ++-----------------------------------+ ``` -#### Arguments +#### Aliases -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- covar -### `stddev_pop` +### `nth_value` -Returns the population standard deviation of a set of numbers. +Returns the nth value in a group of values. ``` -stddev_pop(expression) +nth_value(expression, n ORDER BY expression) ``` #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression**: The column or expression to retrieve the nth value from. +- **n**: The position (nth) of the value to retrieve, based on the ordering. -### `stddev_samp` +#### Example + +```sql +> SELECT dept_id, salary, NTH_VALUE(salary, 2) OVER (PARTITION BY dept_id ORDER BY salary ASC) AS second_salary_by_dept + FROM employee; ++---------+--------+-------------------------+ +| dept_id | salary | second_salary_by_dept | ++---------+--------+-------------------------+ +| 1 | 30000 | NULL | +| 1 | 40000 | 40000 | +| 1 | 50000 | 40000 | +| 2 | 35000 | NULL | +| 2 | 45000 | 45000 | ++---------+--------+-------------------------+ +``` -Returns the sample standard deviation of a set of numbers. +### `regr_avgx` + +Computes the average of the independent variable (input) expression_x for the non-null paired data points. ``` -stddev_samp(expression) +regr_avgx(expression_y, expression_x) ``` #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. -### `var` +### `regr_avgy` -Returns the statistical variance of a set of numbers. +Computes the average of the dependent variable (output) expression_y for the non-null paired data points. ``` -var(expression) +regr_avgy(expression_y, expression_x) ``` #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. -### `var_pop` +### `regr_count` -Returns the statistical population variance of a set of numbers. +Counts the number of non-null paired data points. ``` -var_pop(expression) +regr_count(expression_y, expression_x) ``` #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. -### `var_samp` +### `regr_intercept` -Returns the statistical sample variance of a set of numbers. +Computes the y-intercept of the linear regression line. For the equation (y = kx + b), this function returns b. ``` -var_samp(expression) +regr_intercept(expression_y, expression_x) ``` #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. -### `regr_slope` +### `regr_r2` -Returns the slope of the linear regression line for non-null pairs in aggregate columns. -Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k\*X + b) using minimal RSS fitting. +Computes the square of the correlation coefficient between the independent and dependent variables. ``` -regr_slope(expression1, expression2) +regr_r2(expression_y, expression_x) ``` #### Arguments -- **expression_y**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_x**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. -### `regr_avgx` +### `regr_slope` -Computes the average of the independent variable (input) `expression_x` for the non-null paired data points. +Returns the slope of the linear regression line for non-null pairs in aggregate columns. Given input column Y and X: regr_slope(Y, X) returns the slope (k in Y = k\*X + b) using minimal RSS fitting. ``` -regr_avgx(expression_y, expression_x) +regr_slope(expression_y, expression_x) ``` #### Arguments -- **expression_y**: Dependent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_x**: Independent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. -### `regr_avgy` +### `regr_sxx` -Computes the average of the dependent variable (output) `expression_y` for the non-null paired data points. +Computes the sum of squares of the independent variable. ``` -regr_avgy(expression_y, expression_x) +regr_sxx(expression_y, expression_x) ``` #### Arguments -- **expression_y**: Dependent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_x**: Independent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. -### `regr_count` +### `regr_sxy` -Counts the number of non-null paired data points. +Computes the sum of products of paired data points. ``` -regr_count(expression_y, expression_x) +regr_sxy(expression_y, expression_x) ``` #### Arguments -- **expression_y**: Dependent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_x**: Independent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. -### `regr_intercept` +### `regr_syy` -Computes the y-intercept of the linear regression line. For the equation \(y = kx + b\), this function returns `b`. +Computes the sum of squares of the dependent variable. ``` -regr_intercept(expression_y, expression_x) +regr_syy(expression_y, expression_x) ``` #### Arguments -- **expression_y**: Dependent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_x**: Independent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_y**: Dependent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_x**: Independent variable expression to operate on. Can be a constant, column, or function, and any combination of operators. -### `regr_r2` +### `stddev` -Computes the square of the correlation coefficient between the independent and dependent variables. +Returns the standard deviation of a set of numbers. ``` -regr_r2(expression_y, expression_x) +stddev(expression) ``` #### Arguments -- **expression_y**: Dependent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_x**: Independent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. -### `regr_sxx` +#### Example -Computes the sum of squares of the independent variable. - -``` -regr_sxx(expression_y, expression_x) +```sql +> SELECT stddev(column_name) FROM table_name; ++----------------------+ +| stddev(column_name) | ++----------------------+ +| 12.34 | ++----------------------+ ``` -#### Arguments +#### Aliases -- **expression_y**: Dependent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_x**: Independent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. +- stddev_samp -### `regr_syy` +### `stddev_pop` -Computes the sum of squares of the dependent variable. +Returns the standard deviation of a set of numbers. ``` -regr_syy(expression_y, expression_x) +stddev(expression) ``` #### Arguments -- **expression_y**: Dependent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_x**: Independent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. -### `regr_sxy` - -Computes the sum of products of paired data points. +#### Example -``` -regr_sxy(expression_y, expression_x) +```sql +> SELECT stddev(column_name) FROM table_name; ++----------------------+ +| stddev(column_name) | ++----------------------+ +| 12.34 | ++----------------------+ ``` -#### Arguments +### `stddev_samp` -- **expression_y**: Dependent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_x**: Independent variable. - Can be a constant, column, or function, and any combination of arithmetic operators. +_Alias of [stddev](#stddev)._ -## Approximate +## Approximate Functions - [approx_distinct](#approx_distinct) - [approx_median](#approx_median) @@ -538,8 +764,7 @@ regr_sxy(expression_y, expression_x) ### `approx_distinct` -Returns the approximate number of distinct input values calculated using the -HyperLogLog algorithm. +Returns the approximate number of distinct input values calculated using the HyperLogLog algorithm. ``` approx_distinct(expression) @@ -547,13 +772,22 @@ approx_distinct(expression) #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT approx_distinct(column_name) FROM table_name; ++-----------------------------------+ +| approx_distinct(column_name) | ++-----------------------------------+ +| 42 | ++-----------------------------------+ +``` ### `approx_median` -Returns the approximate median (50th percentile) of input values. -It is an alias of `approx_percentile_cont(x, 0.5)`. +Returns the approximate median (50th percentile) of input values. It is an alias of `approx_percentile_cont(x, 0.5)`. ``` approx_median(expression) @@ -561,8 +795,18 @@ approx_median(expression) #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> SELECT approx_median(column_name) FROM table_name; ++-----------------------------------+ +| approx_median(column_name) | ++-----------------------------------+ +| 23.5 | ++-----------------------------------+ +``` ### `approx_percentile_cont` @@ -574,19 +818,24 @@ approx_percentile_cont(expression, percentile, centroids) #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. - **percentile**: Percentile to compute. Must be a float value between 0 and 1 (inclusive). -- **centroids**: Number of centroids to use in the t-digest algorithm. _Default is 100_. +- **centroids**: Number of centroids to use in the t-digest algorithm. _Default is 100_. A higher number results in more accurate approximation but requires more memory. - If there are this number or fewer unique values, you can expect an exact result. - A higher number of centroids results in a more accurate approximation, but - requires more memory to compute. +#### Example + +```sql +> SELECT approx_percentile_cont(column_name, 0.75, 100) FROM table_name; ++-------------------------------------------------+ +| approx_percentile_cont(column_name, 0.75, 100) | ++-------------------------------------------------+ +| 65.0 | ++-------------------------------------------------+ +``` ### `approx_percentile_cont_with_weight` -Returns the weighted approximate percentile of input values using the -t-digest algorithm. +Returns the weighted approximate percentile of input values using the t-digest algorithm. ``` approx_percentile_cont_with_weight(expression, weight, percentile) @@ -594,8 +843,17 @@ approx_percentile_cont_with_weight(expression, weight, percentile) #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **weight**: Expression to use as weight. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **weight**: Expression to use as weight. Can be a constant, column, or function, and any combination of arithmetic operators. - **percentile**: Percentile to compute. Must be a float value between 0 and 1 (inclusive). + +#### Example + +```sql +> SELECT approx_percentile_cont_with_weight(column_name, weight_column, 0.90) FROM table_name; ++----------------------------------------------------------------------+ +| approx_percentile_cont_with_weight(column_name, weight_column, 0.90) | ++----------------------------------------------------------------------+ +| 78.5 | ++----------------------------------------------------------------------+ +``` diff --git a/docs/source/user-guide/sql/data_types.md b/docs/source/user-guide/sql/data_types.md index 0e974550a84d..18c95cdea70e 100644 --- a/docs/source/user-guide/sql/data_types.md +++ b/docs/source/user-guide/sql/data_types.md @@ -97,7 +97,7 @@ select arrow_cast(now(), 'Timestamp(Second, None)'); | `BYTEA` | `Binary` | You can create binary literals using a hex string literal such as -`X'1234` to create a `Binary` value of two bytes, `0x12` and `0x34`. +`X'1234'` to create a `Binary` value of two bytes, `0x12` and `0x34`. ## Unsupported SQL Types diff --git a/docs/source/user-guide/sql/ddl.md b/docs/source/user-guide/sql/ddl.md index 3d8b632f6ec2..e16b9681eb80 100644 --- a/docs/source/user-guide/sql/ddl.md +++ b/docs/source/user-guide/sql/ddl.md @@ -60,9 +60,6 @@ CREATE [UNBOUNDED] EXTERNAL TABLE [ IF NOT EXISTS ] [ () ] STORED AS -[ WITH HEADER ROW ] -[ DELIMITER ] -[ COMPRESSION TYPE ] [ PARTITIONED BY () ] [ WITH ORDER () ] [ OPTIONS () ] @@ -100,8 +97,8 @@ scanning a subset of the file. ```sql CREATE EXTERNAL TABLE test STORED AS CSV -WITH HEADER ROW -LOCATION '/path/to/aggregate_simple.csv'; +LOCATION '/path/to/aggregate_simple.csv' +OPTIONS ('has_header' 'true'); ``` It is also possible to use compressed files, such as `.csv.gz`: @@ -110,8 +107,8 @@ It is also possible to use compressed files, such as `.csv.gz`: CREATE EXTERNAL TABLE test STORED AS CSV COMPRESSION TYPE GZIP -WITH HEADER ROW -LOCATION '/path/to/aggregate_simple.csv.gz'; +LOCATION '/path/to/aggregate_simple.csv.gz' +OPTIONS ('has_header' 'true'); ``` It is also possible to specify the schema manually. @@ -133,8 +130,8 @@ CREATE EXTERNAL TABLE test ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW -LOCATION '/path/to/aggregate_test_100.csv'; +LOCATION '/path/to/aggregate_test_100.csv' +OPTIONS ('has_header' 'true'); ``` It is also possible to specify a directory that contains a partitioned @@ -143,8 +140,8 @@ table (multiple files with the same schema) ```sql CREATE EXTERNAL TABLE test STORED AS CSV -WITH HEADER ROW -LOCATION '/path/to/directory/of/files'; +LOCATION '/path/to/directory/of/files' +OPTIONS ('has_header' 'true'); ``` With `CREATE UNBOUNDED EXTERNAL TABLE` SQL statement. We can create unbounded data sources such as following: @@ -181,9 +178,9 @@ CREATE EXTERNAL TABLE test ( c13 VARCHAR NOT NULL ) STORED AS CSV -WITH HEADER ROW WITH ORDER (c2 ASC, c5 + c8 DESC NULL FIRST) -LOCATION '/path/to/aggregate_test_100.csv'; +LOCATION '/path/to/aggregate_test_100.csv' +OPTIONS ('has_header' 'true'); ``` Where `WITH ORDER` clause specifies the sort order: diff --git a/docs/source/user-guide/sql/dml.md b/docs/source/user-guide/sql/dml.md index 42e0c8054c9b..dd016cabbfb7 100644 --- a/docs/source/user-guide/sql/dml.md +++ b/docs/source/user-guide/sql/dml.md @@ -39,7 +39,10 @@ TO 'file_name' clause is not specified, it will be inferred from the file extension if possible. `PARTITIONED BY` specifies the columns to use for partitioning the output files into -separate hive-style directories. +separate hive-style directories. By default, columns used in `PARTITIONED BY` will be removed +from the output format. If you want to keep the columns, you should provide the option +`execution.keep_partition_by_columns true`. `execution.keep_partition_by_columns` flag can also +be enabled through `ExecutionOptions` within `SessionConfig`. The output format is determined by the first match of the following rules: diff --git a/docs/source/user-guide/sql/explain.md b/docs/source/user-guide/sql/explain.md index 22f73e3d76d7..45bb3a57aa7c 100644 --- a/docs/source/user-guide/sql/explain.md +++ b/docs/source/user-guide/sql/explain.md @@ -21,6 +21,8 @@ The `EXPLAIN` command shows the logical and physical execution plan for the specified SQL statement. +See the [Reading Explain Plans](../explain-usage.md) page for more information on how to interpret these plans. +
 EXPLAIN [ANALYZE] [VERBOSE] statement
 
diff --git a/docs/source/user-guide/sql/index.rst b/docs/source/user-guide/sql/index.rst index 04d1fc228f81..4499aac53611 100644 --- a/docs/source/user-guide/sql/index.rst +++ b/docs/source/user-guide/sql/index.rst @@ -31,6 +31,8 @@ SQL Reference operators aggregate_functions window_functions + window_functions_new scalar_functions + special_functions sql_status write_options diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 624e86db3565..b92b815d7c95 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -17,6 +17,14 @@ under the License. --> + + # Scalar Functions ## Math Functions @@ -27,12 +35,13 @@ - [asin](#asin) - [asinh](#asinh) - [atan](#atan) -- [atanh](#atanh) - [atan2](#atan2) +- [atanh](#atanh) - [cbrt](#cbrt) - [ceil](#ceil) - [cos](#cos) - [cosh](#cosh) +- [cot](#cot) - [degrees](#degrees) - [exp](#exp) - [factorial](#factorial) @@ -47,8 +56,8 @@ - [log2](#log2) - [nanvl](#nanvl) - [pi](#pi) -- [power](#power) - [pow](#pow) +- [power](#power) - [radians](#radians) - [random](#random) - [round](#round) @@ -70,8 +79,7 @@ abs(numeric_expression) #### Arguments -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `acos` @@ -83,8 +91,7 @@ acos(numeric_expression) #### Arguments -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `acosh` @@ -96,8 +103,7 @@ acosh(numeric_expression) #### Arguments -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `asin` @@ -109,8 +115,7 @@ asin(numeric_expression) #### Arguments -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `asinh` @@ -122,8 +127,7 @@ asinh(numeric_expression) #### Arguments -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `atan` @@ -135,36 +139,34 @@ atan(numeric_expression) #### Arguments -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. -### `atanh` +### `atan2` -Returns the area hyperbolic tangent or inverse hyperbolic tangent of a number. +Returns the arc tangent or inverse tangent of `expression_y / expression_x`. ``` -atanh(numeric_expression) +atan2(expression_y, expression_x) ``` #### Arguments -- **numeric_expression**: Numeric expression to operate on. +- **expression_y**: First numeric expression to operate on. + Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_x**: Second numeric expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. -### `atan2` +### `atanh` -Returns the arc tangent or inverse tangent of `expression_y / expression_x`. +Returns the area hyperbolic tangent or inverse hyperbolic tangent of a number. ``` -atan2(expression_y, expression_x) +atanh(numeric_expression) ``` #### Arguments -- **expression_y**: First numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_x**: Second numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `cbrt` @@ -176,8 +178,7 @@ cbrt(numeric_expression) #### Arguments -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `ceil` @@ -189,8 +190,7 @@ ceil(numeric_expression) #### Arguments -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `cos` @@ -202,8 +202,7 @@ cos(numeric_expression) #### Arguments -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `cosh` @@ -213,23 +212,33 @@ Returns the hyperbolic cosine of a number. cosh(numeric_expression) ``` -### `degrees` +#### Arguments -Converts radians to degrees. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `cot` + +Returns the cotangent of a number. ``` -degrees(numeric_expression) +cot(numeric_expression) ``` #### Arguments -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. + +### `degrees` + +Converts radians to degrees. + +``` +degrees(numeric_expression) +``` #### Arguments -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `exp` @@ -241,8 +250,7 @@ exp(numeric_expression) #### Arguments -- **numeric_expression**: Numeric expression to use as the exponent. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `factorial` @@ -254,8 +262,7 @@ factorial(numeric_expression) #### Arguments -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `floor` @@ -267,8 +274,7 @@ floor(numeric_expression) #### Arguments -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `gcd` @@ -280,10 +286,8 @@ gcd(expression_x, expression_y) #### Arguments -- **expression_x**: First numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_y**: Second numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_x**: First numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_y**: Second numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `isnan` @@ -295,8 +299,7 @@ isnan(numeric_expression) #### Arguments -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `iszero` @@ -308,8 +311,7 @@ iszero(numeric_expression) #### Arguments -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `lcm` @@ -321,10 +323,8 @@ lcm(expression_x, expression_y) #### Arguments -- **expression_x**: First numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_y**: Second numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_x**: First numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **expression_y**: Second numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `ln` @@ -336,13 +336,11 @@ ln(numeric_expression) #### Arguments -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `log` -Returns the base-x logarithm of a number. -Can either provide a specified base, or if omitted then takes the base-10 of a number. +Returns the base-x logarithm of a number. Can either provide a specified base, or if omitted then takes the base-10 of a number. ``` log(base, numeric_expression) @@ -351,10 +349,8 @@ log(numeric_expression) #### Arguments -- **base**: Base numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **base**: Base numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `log10` @@ -366,8 +362,7 @@ log10(numeric_expression) #### Arguments -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `log2` @@ -379,8 +374,7 @@ log2(numeric_expression) #### Arguments -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `nanvl` @@ -393,10 +387,8 @@ nanvl(expression_x, expression_y) #### Arguments -- **expression_x**: Numeric expression to return if it's not _NaN_. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression_y**: Numeric expression to return if the first expression is _NaN_. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_x**: Numeric expression to return if it's not _NaN_. Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression_y**: Numeric expression to return if the first expression is _NaN_. Can be a constant, column, or function, and any combination of arithmetic operators. ### `pi` @@ -406,6 +398,10 @@ Returns an approximate value of π. pi() ``` +### `pow` + +_Alias of [power](#power)._ + ### `power` Returns a base expression raised to the power of an exponent. @@ -416,19 +412,13 @@ power(base, exponent) #### Arguments -- **base**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **exponent**: Exponent numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **base**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **exponent**: Exponent numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. #### Aliases - pow -### `pow` - -_Alias of [power](#power)._ - ### `radians` Converts degrees to radians. @@ -439,8 +429,7 @@ radians(numeric_expression) #### Arguments -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `random` @@ -461,10 +450,8 @@ round(numeric_expression[, decimal_places]) #### Arguments -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **decimal_places**: Optional. The number of decimal places to round to. - Defaults to 0. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **decimal_places**: Optional. The number of decimal places to round to. Defaults to 0. ### `signum` @@ -478,8 +465,7 @@ signum(numeric_expression) #### Arguments -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `sin` @@ -491,8 +477,7 @@ sin(numeric_expression) #### Arguments -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `sinh` @@ -504,8 +489,7 @@ sinh(numeric_expression) #### Arguments -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `sqrt` @@ -517,8 +501,7 @@ sqrt(numeric_expression) #### Arguments -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `tan` @@ -530,8 +513,7 @@ tan(numeric_expression) #### Arguments -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `tanh` @@ -543,8 +525,7 @@ tanh(numeric_expression) #### Arguments -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. ### `trunc` @@ -556,9 +537,7 @@ trunc(numeric_expression[, decimal_places]) #### Arguments -- **numeric_expression**: Numeric expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. - +- **numeric_expression**: Numeric expression to operate on. Can be a constant, column, or function, and any combination of operators. - **decimal_places**: Optional. The number of decimal places to truncate to. Defaults to 0 (truncate to a whole number). If `decimal_places` is a positive integer, truncates digits to the @@ -568,16 +547,14 @@ trunc(numeric_expression[, decimal_places]) ## Conditional Functions - [coalesce](#coalesce) +- [ifnull](#ifnull) - [nullif](#nullif) - [nvl](#nvl) - [nvl2](#nvl2) -- [ifnull](#ifnull) ### `coalesce` -Returns the first of its arguments that is not _null_. -Returns _null_ if all arguments are _null_. -This function is often used to substitute a default value for _null_ values. +Returns the first of its arguments that is not _null_. Returns _null_ if all arguments are _null_. This function is often used to substitute a default value for _null_ values. ``` coalesce(expression1[, ..., expression_n]) @@ -585,10 +562,22 @@ coalesce(expression1[, ..., expression_n]) #### Arguments -- **expression1, expression_n**: - Expression to use if previous expressions are _null_. - Can be a constant, column, or function, and any combination of arithmetic operators. - Pass as many expression arguments as necessary. +- **expression1, expression_n**: Expression to use if previous expressions are _null_. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary. + +#### Example + +```sql +> select coalesce(null, null, 'datafusion'); ++----------------------------------------+ +| coalesce(NULL,NULL,Utf8("datafusion")) | ++----------------------------------------+ +| datafusion | ++----------------------------------------+ +``` + +### `ifnull` + +_Alias of [nvl](#nvl)._ ### `nullif` @@ -601,14 +590,29 @@ nullif(expression1, expression2) #### Arguments -- **expression1**: Expression to compare and return if equal to expression2. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression2**: Expression to compare to expression1. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression1**: Expression to compare and return if equal to expression2. Can be a constant, column, or function, and any combination of operators. +- **expression2**: Expression to compare to expression1. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select nullif('datafusion', 'data'); ++-----------------------------------------+ +| nullif(Utf8("datafusion"),Utf8("data")) | ++-----------------------------------------+ +| datafusion | ++-----------------------------------------+ +> select nullif('datafusion', 'datafusion'); ++-----------------------------------------------+ +| nullif(Utf8("datafusion"),Utf8("datafusion")) | ++-----------------------------------------------+ +| | ++-----------------------------------------------+ +``` ### `nvl` -Returns _expression2_ if _expression1_ is NULL; otherwise it returns _expression1_. +Returns _expression2_ if _expression1_ is NULL otherwise it returns _expression1_. ``` nvl(expression1, expression2) @@ -616,10 +620,29 @@ nvl(expression1, expression2) #### Arguments -- **expression1**: return if expression1 not is NULL. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression2**: return if expression1 is NULL. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression1**: Expression to return if not null. Can be a constant, column, or function, and any combination of operators. +- **expression2**: Expression to return if expr1 is null. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select nvl(null, 'a'); ++---------------------+ +| nvl(NULL,Utf8("a")) | ++---------------------+ +| a | ++---------------------+\ +> select nvl('b', 'a'); ++--------------------------+ +| nvl(Utf8("b"),Utf8("a")) | ++--------------------------+ +| b | ++--------------------------+ +``` + +#### Aliases + +- ifnull ### `nvl2` @@ -631,16 +654,26 @@ nvl2(expression1, expression2, expression3) #### Arguments -- **expression1**: conditional expression. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression2**: return if expression1 is not NULL. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **expression3**: return if expression1 is NULL. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression1**: Expression to test for null. Can be a constant, column, or function, and any combination of operators. +- **expression2**: Expression to return if expr1 is not null. Can be a constant, column, or function, and any combination of operators. +- **expression3**: Expression to return if expr1 is null. Can be a constant, column, or function, and any combination of operators. -### `ifnull` +#### Example -_Alias of [nvl](#nvl)._ +```sql +> select nvl2(null, 'a', 'b'); ++--------------------------------+ +| nvl2(NULL,Utf8("a"),Utf8("b")) | ++--------------------------------+ +| b | ++--------------------------------+ +> select nvl2('data', 'a', 'b'); ++----------------------------------------+ +| nvl2(Utf8("data"),Utf8("a"),Utf8("b")) | ++----------------------------------------+ +| a | ++----------------------------------------+ +``` ## String Functions @@ -649,18 +682,22 @@ _Alias of [nvl](#nvl)._ - [btrim](#btrim) - [char_length](#char_length) - [character_length](#character_length) +- [chr](#chr) - [concat](#concat) - [concat_ws](#concat_ws) -- [chr](#chr) +- [contains](#contains) - [ends_with](#ends_with) +- [find_in_set](#find_in_set) - [initcap](#initcap) - [instr](#instr) - [left](#left) - [length](#length) +- [levenshtein](#levenshtein) - [lower](#lower) - [lpad](#lpad) - [ltrim](#ltrim) - [octet_length](#octet_length) +- [position](#position) - [repeat](#repeat) - [replace](#replace) - [reverse](#reverse) @@ -671,20 +708,18 @@ _Alias of [nvl](#nvl)._ - [starts_with](#starts_with) - [strpos](#strpos) - [substr](#substr) +- [substr_index](#substr_index) +- [substring](#substring) +- [substring_index](#substring_index) - [to_hex](#to_hex) - [translate](#translate) - [trim](#trim) - [upper](#upper) - [uuid](#uuid) -- [overlay](#overlay) -- [levenshtein](#levenshtein) -- [substr_index](#substr_index) -- [find_in_set](#find_in_set) -- [position](#position) ### `ascii` -Returns the ASCII value of the first character in a string. +Returns the Unicode character code of the first character in a string. ``` ascii(str) @@ -692,11 +727,28 @@ ascii(str) #### Arguments -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select ascii('abc'); ++--------------------+ +| ascii(Utf8("abc")) | ++--------------------+ +| 97 | ++--------------------+ +> select ascii('🚀'); ++-------------------+ +| ascii(Utf8("🚀")) | ++-------------------+ +| 128640 | ++-------------------+ +``` **Related functions**: -[chr](#chr) + +- [chr](#chr) ### `bit_length` @@ -708,18 +760,27 @@ bit_length(str) #### Arguments -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select bit_length('datafusion'); ++--------------------------------+ +| bit_length(Utf8("datafusion")) | ++--------------------------------+ +| 80 | ++--------------------------------+ +``` **Related functions**: -[length](#length), -[octet_length](#octet_length) + +- [length](#length) +- [octet_length](#octet_length) ### `btrim` -Trims the specified trim string from the start and end of a string. -If no trim string is provided, all whitespace is removed from the start and end -of the input string. +Trims the specified trim string from the start and end of a string. If no trim string is provided, all whitespace is removed from the start and end of the input string. ``` btrim(str[, trim_str]) @@ -727,62 +788,75 @@ btrim(str[, trim_str]) #### Arguments -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **trim_str**: String expression to trim from the beginning and end of the input string. - Can be a constant, column, or function, and any combination of arithmetic operators. - _Default is whitespace characters._ +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **trim_str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. _Default is whitespace characters._ -**Related functions**: -[ltrim](#ltrim), -[rtrim](#rtrim) +#### Example + +```sql +> select btrim('__datafusion____', '_'); ++-------------------------------------------+ +| btrim(Utf8("__datafusion____"),Utf8("_")) | ++-------------------------------------------+ +| datafusion | ++-------------------------------------------+ +``` + +#### Alternative Syntax + +```sql +trim(BOTH trim_str FROM str) +``` + +```sql +trim(trim_str FROM str) +``` #### Aliases - trim -### `char_length` +**Related functions**: -_Alias of [length](#length)._ +- [ltrim](#ltrim) +- [rtrim](#rtrim) -### `character_length` +### `char_length` -_Alias of [length](#length)._ +_Alias of [character_length](#character_length)._ -### `concat` +### `character_length` -Concatenates multiple strings together. +Returns the number of characters in a string. ``` -concat(str[, ..., str_n]) +character_length(str) ``` #### Arguments -- **str**: String expression to concatenate. - Can be a constant, column, or function, and any combination of string operators. -- **str_n**: Subsequent string column or literal string to concatenate. - -**Related functions**: -[concat_ws](#concat_ws) - -### `concat_ws` +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -Concatenates multiple strings together with a specified separator. +#### Example -``` -concat(separator, str[, ..., str_n]) +```sql +> select character_length('Ångström'); ++------------------------------------+ +| character_length(Utf8("Ångström")) | ++------------------------------------+ +| 8 | ++------------------------------------+ ``` -#### Arguments +#### Aliases -- **separator**: Separator to insert between concatenated strings. -- **str**: String expression to concatenate. - Can be a constant, column, or function, and any combination of string operators. -- **str_n**: Subsequent string column or literal string to concatenate. +- length +- char_length **Related functions**: -[concat](#concat) + +- [bit_length](#bit_length) +- [octet_length](#octet_length) ### `chr` @@ -794,94 +868,245 @@ chr(expression) #### Arguments -- **expression**: Expression containing the ASCII or Unicode code value to operate on. - Can be a constant, column, or function, and any combination of arithmetic or - string operators. +- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -**Related functions**: -[ascii](#ascii) - -### `ends_with` - -Tests if a string ends with a substring. +#### Example -``` -ends_with(str, substr) +```sql +> select chr(128640); ++--------------------+ +| chr(Int64(128640)) | ++--------------------+ +| 🚀 | ++--------------------+ ``` -#### Arguments +**Related functions**: -- **str**: String expression to test. - Can be a constant, column, or function, and any combination of string operators. -- **substr**: Substring to test for. +- [ascii](#ascii) -### `initcap` +### `concat` -Capitalizes the first character in each word in the input string. -Words are delimited by non-alphanumeric characters. +Concatenates multiple strings together. ``` -initcap(str) +concat(str[, ..., str_n]) ``` #### Arguments -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. - -**Related functions**: -[lower](#lower), -[upper](#upper) +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **str_n**: Subsequent string expressions to concatenate. -### `instr` +#### Example -_Alias of [strpos](#strpos)._ +```sql +> select concat('data', 'f', 'us', 'ion'); ++-------------------------------------------------------+ +| concat(Utf8("data"),Utf8("f"),Utf8("us"),Utf8("ion")) | ++-------------------------------------------------------+ +| datafusion | ++-------------------------------------------------------+ +``` -#### Arguments +**Related functions**: -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **substr**: Substring expression to search for. - Can be a constant, column, or function, and any combination of string operators. +- [concat_ws](#concat_ws) -### `left` +### `concat_ws` -Returns a specified number of characters from the left side of a string. +Concatenates multiple strings together with a specified separator. ``` -left(str, n) +concat_ws(separator, str[, ..., str_n]) ``` #### Arguments -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **n**: Number of characters to return. +- **separator**: Separator to insert between concatenated strings. +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **str_n**: Subsequent string expressions to concatenate. + +#### Example + +```sql +> select concat_ws('_', 'data', 'fusion'); ++--------------------------------------------------+ +| concat_ws(Utf8("_"),Utf8("data"),Utf8("fusion")) | ++--------------------------------------------------+ +| data_fusion | ++--------------------------------------------------+ +``` **Related functions**: -[right](#right) -### `length` +- [concat](#concat) -Returns the number of characters in a string. +### `contains` + +Return true if search_str is found within string (case-sensitive). ``` -length(str) +contains(str, search_str) ``` #### Arguments -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. - -#### Aliases +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **search_str**: The string to search for in str. -- char_length -- character_length +#### Example -**Related functions**: -[bit_length](#bit_length), -[octet_length](#octet_length) +```sql +> select contains('the quick brown fox', 'row'); ++---------------------------------------------------+ +| contains(Utf8("the quick brown fox"),Utf8("row")) | ++---------------------------------------------------+ +| true | ++---------------------------------------------------+ +``` + +### `ends_with` + +Tests if a string ends with a substring. + +``` +ends_with(str, substr) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **substr**: Substring to test for. + +#### Example + +```sql +> select ends_with('datafusion', 'soin'); ++--------------------------------------------+ +| ends_with(Utf8("datafusion"),Utf8("soin")) | ++--------------------------------------------+ +| false | ++--------------------------------------------+ +> select ends_with('datafusion', 'sion'); ++--------------------------------------------+ +| ends_with(Utf8("datafusion"),Utf8("sion")) | ++--------------------------------------------+ +| true | ++--------------------------------------------+ +``` + +### `find_in_set` + +Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings. + +``` +find_in_set(str, strlist) +``` + +#### Arguments + +- **str**: String expression to find in strlist. +- **strlist**: A string list is a string composed of substrings separated by , characters. + +#### Example + +```sql +> select find_in_set('b', 'a,b,c,d'); ++----------------------------------------+ +| find_in_set(Utf8("b"),Utf8("a,b,c,d")) | ++----------------------------------------+ +| 2 | ++----------------------------------------+ +``` + +### `initcap` + +Capitalizes the first character in each word in the input string. Words are delimited by non-alphanumeric characters. + +``` +initcap(str) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select initcap('apache datafusion'); ++------------------------------------+ +| initcap(Utf8("apache datafusion")) | ++------------------------------------+ +| Apache Datafusion | ++------------------------------------+ +``` + +**Related functions**: + +- [lower](#lower) +- [upper](#upper) + +### `instr` + +_Alias of [strpos](#strpos)._ + +### `left` + +Returns a specified number of characters from the left side of a string. + +``` +left(str, n) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **n**: Number of characters to return. + +#### Example + +```sql +> select left('datafusion', 4); ++-----------------------------------+ +| left(Utf8("datafusion"),Int64(4)) | ++-----------------------------------+ +| data | ++-----------------------------------+ +``` + +**Related functions**: + +- [right](#right) + +### `length` + +_Alias of [character_length](#character_length)._ + +### `levenshtein` + +Returns the [`Levenshtein distance`](https://en.wikipedia.org/wiki/Levenshtein_distance) between the two given strings. + +``` +levenshtein(str1, str2) +``` + +#### Arguments + +- **str1**: String expression to compute Levenshtein distance with str2. +- **str2**: String expression to compute Levenshtein distance with str1. + +#### Example + +```sql +> select levenshtein('kitten', 'sitting'); ++---------------------------------------------+ +| levenshtein(Utf8("kitten"),Utf8("sitting")) | ++---------------------------------------------+ +| 3 | ++---------------------------------------------+ +``` ### `lower` @@ -893,12 +1118,23 @@ lower(str) #### Arguments -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select lower('Ångström'); ++-------------------------+ +| lower(Utf8("Ångström")) | ++-------------------------+ +| ångström | ++-------------------------+ +``` **Related functions**: -[initcap](#initcap), -[upper](#upper) + +- [initcap](#initcap) +- [upper](#upper) ### `lpad` @@ -910,21 +1146,28 @@ lpad(str, n[, padding_str]) #### Arguments -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. - **n**: String length to pad to. -- **padding_str**: String expression to pad with. - Can be a constant, column, or function, and any combination of string operators. - _Default is a space._ +- **padding_str**: Optional string expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._ + +#### Example + +```sql +> select lpad('Dolly', 10, 'hello'); ++---------------------------------------------+ +| lpad(Utf8("Dolly"),Int64(10),Utf8("hello")) | ++---------------------------------------------+ +| helloDolly | ++---------------------------------------------+ +``` **Related functions**: -[rpad](#rpad) + +- [rpad](#rpad) ### `ltrim` -Trims the specified trim string from the beginning of a string. -If no trim string is provided, all whitespace is removed from the start -of the input string. +Trims the specified trim string from the beginning of a string. If no trim string is provided, all whitespace is removed from the start of the input string. ``` ltrim(str[, trim_str]) @@ -932,15 +1175,36 @@ ltrim(str[, trim_str]) #### Arguments -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **trim_str**: String expression to trim from the beginning of the input string. - Can be a constant, column, or function, and any combination of arithmetic operators. - _Default is whitespace characters._ +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **trim_str**: String expression to trim from the beginning of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. _Default is whitespace characters._ + +#### Example + +```sql +> select ltrim(' datafusion '); ++-------------------------------+ +| ltrim(Utf8(" datafusion ")) | ++-------------------------------+ +| datafusion | ++-------------------------------+ +> select ltrim('___datafusion___', '_'); ++-------------------------------------------+ +| ltrim(Utf8("___datafusion___"),Utf8("_")) | ++-------------------------------------------+ +| datafusion___ | ++-------------------------------------------+ +``` + +#### Alternative Syntax + +```sql +trim(LEADING trim_str FROM str) +``` **Related functions**: -[btrim](#btrim), -[rtrim](#rtrim) + +- [btrim](#btrim) +- [rtrim](#rtrim) ### `octet_length` @@ -952,12 +1216,27 @@ octet_length(str) #### Arguments -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select octet_length('Ångström'); ++--------------------------------+ +| octet_length(Utf8("Ångström")) | ++--------------------------------+ +| 10 | ++--------------------------------+ +``` **Related functions**: -[bit_length](#bit_length), -[length](#length) + +- [bit_length](#bit_length) +- [length](#length) + +### `position` + +_Alias of [strpos](#strpos)._ ### `repeat` @@ -969,10 +1248,20 @@ repeat(str, n) #### Arguments -- **str**: String expression to repeat. - Can be a constant, column, or function, and any combination of string operators. +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. - **n**: Number of times to repeat the input string. +#### Example + +```sql +> select repeat('data', 3); ++-------------------------------+ +| repeat(Utf8("data"),Int64(3)) | ++-------------------------------+ +| datadatadata | ++-------------------------------+ +``` + ### `replace` Replaces all occurrences of a specified substring in a string with a new substring. @@ -983,12 +1272,20 @@ replace(str, substr, replacement) #### Arguments -- **str**: String expression to repeat. - Can be a constant, column, or function, and any combination of string operators. -- **substr**: Substring expression to replace in the input string. - Can be a constant, column, or function, and any combination of string operators. -- **replacement**: Replacement substring expression. - Can be a constant, column, or function, and any combination of string operators. +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **substr**: Substring expression to replace in the input string. Substring expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **replacement**: Replacement substring expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select replace('ABabbaBA', 'ab', 'cd'); ++-------------------------------------------------+ +| replace(Utf8("ABabbaBA"),Utf8("ab"),Utf8("cd")) | ++-------------------------------------------------+ +| ABcdbaBA | ++-------------------------------------------------+ +``` ### `reverse` @@ -1000,8 +1297,18 @@ reverse(str) #### Arguments -- **str**: String expression to repeat. - Can be a constant, column, or function, and any combination of string operators. +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select reverse('datafusion'); ++-----------------------------+ +| reverse(Utf8("datafusion")) | ++-----------------------------+ +| noisufatad | ++-----------------------------+ +``` ### `right` @@ -1013,12 +1320,23 @@ right(str, n) #### Arguments -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **n**: Number of characters to return. +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **n**: Number of characters to return + +#### Example + +```sql +> select right('datafusion', 6); ++------------------------------------+ +| right(Utf8("datafusion"),Int64(6)) | ++------------------------------------+ +| fusion | ++------------------------------------+ +``` **Related functions**: -[left](#left) + +- [left](#left) ### `rpad` @@ -1030,21 +1348,28 @@ rpad(str, n[, padding_str]) #### Arguments -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. - **n**: String length to pad to. -- **padding_str**: String expression to pad with. - Can be a constant, column, or function, and any combination of string operators. - _Default is a space._ +- **padding_str**: String expression to pad with. Can be a constant, column, or function, and any combination of string operators. _Default is a space._ + +#### Example + +```sql +> select rpad('datafusion', 20, '_-'); ++-----------------------------------------------+ +| rpad(Utf8("datafusion"),Int64(20),Utf8("_-")) | ++-----------------------------------------------+ +| datafusion_-_-_-_-_- | ++-----------------------------------------------+ +``` **Related functions**: -[lpad](#lpad) + +- [lpad](#lpad) ### `rtrim` -Trims the specified trim string from the end of a string. -If no trim string is provided, all whitespace is removed from the end -of the input string. +Trims the specified trim string from the end of a string. If no trim string is provided, all whitespace is removed from the end of the input string. ``` rtrim(str[, trim_str]) @@ -1052,20 +1377,40 @@ rtrim(str[, trim_str]) #### Arguments -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **trim_str**: String expression to trim from the end of the input string. - Can be a constant, column, or function, and any combination of arithmetic operators. - _Default is whitespace characters._ +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **trim_str**: String expression to trim from the end of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. _Default is whitespace characters._ + +#### Example + +```sql +> select rtrim(' datafusion '); ++-------------------------------+ +| rtrim(Utf8(" datafusion ")) | ++-------------------------------+ +| datafusion | ++-------------------------------+ +> select rtrim('___datafusion___', '_'); ++-------------------------------------------+ +| rtrim(Utf8("___datafusion___"),Utf8("_")) | ++-------------------------------------------+ +| ___datafusion | ++-------------------------------------------+ +``` + +#### Alternative Syntax + +```sql +trim(TRAILING trim_str FROM str) +``` **Related functions**: -[btrim](#btrim), -[ltrim](#ltrim) + +- [btrim](#btrim) +- [ltrim](#ltrim) ### `split_part` -Splits a string based on a specified delimiter and returns the substring in the -specified position. +Splits a string based on a specified delimiter and returns the substring in the specified position. ``` split_part(str, delimiter, pos) @@ -1073,11 +1418,21 @@ split_part(str, delimiter, pos) #### Arguments -- **str**: String expression to spit. - Can be a constant, column, or function, and any combination of string operators. +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. - **delimiter**: String or character to split on. - **pos**: Position of the part to return. +#### Example + +```sql +> select split_part('1.2.3.4.5', '.', 3); ++--------------------------------------------------+ +| split_part(Utf8("1.2.3.4.5"),Utf8("."),Int64(3)) | ++--------------------------------------------------+ +| 3 | ++--------------------------------------------------+ +``` + ### `starts_with` Tests if a string starts with a substring. @@ -1088,15 +1443,23 @@ starts_with(str, substr) #### Arguments -- **str**: String expression to test. - Can be a constant, column, or function, and any combination of string operators. +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. - **substr**: Substring to test for. +#### Example + +```sql +> select starts_with('datafusion','data'); ++----------------------------------------------+ +| starts_with(Utf8("datafusion"),Utf8("data")) | ++----------------------------------------------+ +| true | ++----------------------------------------------+ +``` + ### `strpos` -Returns the starting position of a specified substring in a string. -Positions begin at 1. -If the substring does not exist in the string, the function returns 0. +Returns the starting position of a specified substring in a string. Positions begin at 1. If the substring does not exist in the string, the function returns 0. ``` strpos(str, substr) @@ -1104,32 +1467,133 @@ strpos(str, substr) #### Arguments -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. - **substr**: Substring expression to search for. - Can be a constant, column, or function, and any combination of string operators. -#### Aliases +#### Example + +```sql +> select strpos('datafusion', 'fus'); ++----------------------------------------+ +| strpos(Utf8("datafusion"),Utf8("fus")) | ++----------------------------------------+ +| 5 | ++----------------------------------------+ +``` + +#### Alternative Syntax + +```sql +position(substr in origstr) +``` + +#### Aliases + +- instr +- position + +### `substr` + +Extracts a substring of a specified number of characters from a specific starting position in a string. + +``` +substr(str, start_pos[, length]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **start_pos**: Character position to start the substring at. The first character in the string has a position of 1. +- **length**: Number of characters to extract. If not specified, returns the rest of the string after the start position. + +#### Example + +```sql +> select substr('datafusion', 5, 3); ++----------------------------------------------+ +| substr(Utf8("datafusion"),Int64(5),Int64(3)) | ++----------------------------------------------+ +| fus | ++----------------------------------------------+ +``` + +#### Alternative Syntax + +```sql +substring(str from start_pos for length) +``` + +#### Aliases + +- substring + +### `substr_index` + +Returns the substring from str before count occurrences of the delimiter delim. +If count is positive, everything to the left of the final delimiter (counting from the left) is returned. +If count is negative, everything to the right of the final delimiter (counting from the right) is returned. + +``` +substr_index(str, delim, count) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **delim**: The string to find in str to split str. +- **count**: The number of times to search for the delimiter. Can be either a positive or negative number. + +#### Example + +```sql +> select substr_index('www.apache.org', '.', 1); ++---------------------------------------------------------+ +| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(1)) | ++---------------------------------------------------------+ +| www | ++---------------------------------------------------------+ +> select substr_index('www.apache.org', '.', -1); ++----------------------------------------------------------+ +| substr_index(Utf8("www.apache.org"),Utf8("."),Int64(-1)) | ++----------------------------------------------------------+ +| org | ++----------------------------------------------------------+ +``` + +#### Aliases + +- substring_index + +### `substring` + +_Alias of [substr](#substr)._ + +### `substring_index` -- instr +_Alias of [substr_index](#substr_index)._ -### `substr` +### `to_hex` -Extracts a substring of a specified number of characters from a specific -starting position in a string. +Converts an integer to a hexadecimal string. ``` -substr(str, start_pos[, length]) +to_hex(int) ``` #### Arguments -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **start_pos**: Character position to start the substring at. - The first character in the string has a position of 1. -- **length**: Number of characters to extract. - If not specified, returns the rest of the string after the start position. +- **int**: Integer expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select to_hex(12345689); ++-------------------------+ +| to_hex(Int64(12345689)) | ++-------------------------+ +| bc6159 | ++-------------------------+ +``` ### `translate` @@ -1139,25 +1603,23 @@ Translates characters in a string to specified translation characters. translate(str, chars, translation) ``` -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **chars**: Characters to translate. -- **translation**: Translation characters. Translation characters replace only - characters at the same position in the **chars** string. +#### Arguments -### `to_hex` +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **chars**: Characters to translate. +- **translation**: Translation characters. Translation characters replace only characters at the same position in the **chars** string. -Converts an integer to a hexadecimal string. +#### Example -``` -to_hex(int) +```sql +> select translate('twice', 'wic', 'her'); ++--------------------------------------------------+ +| translate(Utf8("twice"),Utf8("wic"),Utf8("her")) | ++--------------------------------------------------+ +| there | ++--------------------------------------------------+ ``` -#### Arguments - -- **int**: Integer expression to convert. - Can be a constant, column, or function, and any combination of arithmetic operators. - ### `trim` _Alias of [btrim](#btrim)._ @@ -1172,85 +1634,63 @@ upper(str) #### Arguments -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. - -**Related functions**: -[initcap](#initcap), -[lower](#lower) - -### `uuid` - -Returns UUID v4 string value which is unique per row. - -``` -uuid() -``` - -### `overlay` +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. -Returns the string which is replaced by another string from the specified position and specified count length. -For example, `overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas` +#### Example -``` -overlay(str PLACING substr FROM pos [FOR count]) +```sql +> select upper('dataFusion'); ++---------------------------+ +| upper(Utf8("dataFusion")) | ++---------------------------+ +| DATAFUSION | ++---------------------------+ ``` -#### Arguments +**Related functions**: -- **str**: String expression to operate on. -- **substr**: the string to replace part of str. -- **pos**: the start position to replace of str. -- **count**: the count of characters to be replaced from start position of str. If not specified, will use substr length instead. +- [initcap](#initcap) +- [lower](#lower) -### `levenshtein` +### `uuid` -Returns the Levenshtein distance between the two given strings. -For example, `levenshtein('kitten', 'sitting') = 3` +Returns [`UUID v4`]() string value which is unique per row. ``` -levenshtein(str1, str2) +uuid() ``` -#### Arguments - -- **str1**: String expression to compute Levenshtein distance with str2. -- **str2**: String expression to compute Levenshtein distance with str1. - -### `substr_index` - -Returns the substring from str before count occurrences of the delimiter delim. -If count is positive, everything to the left of the final delimiter (counting from the left) is returned. -If count is negative, everything to the right of the final delimiter (counting from the right) is returned. -For example, `substr_index('www.apache.org', '.', 1) = www`, `substr_index('www.apache.org', '.', -1) = org` +#### Example -``` -substr_index(str, delim, count) +```sql +> select uuid(); ++--------------------------------------+ +| uuid() | ++--------------------------------------+ +| 6ec17ef8-1934-41cc-8d59-d0c8f9eea1f0 | ++--------------------------------------+ ``` -#### Arguments +## Binary String Functions -- **str**: String expression to operate on. -- **delim**: the string to find in str to split str. -- **count**: The number of times to search for the delimiter. Can be both a positive or negative number. +- [decode](#decode) +- [encode](#encode) -### `find_in_set` +### `decode` -Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings. -For example, `find_in_set('b', 'a,b,c,d') = 2` +Decode binary data from textual representation in string. ``` -find_in_set(str, strlist) +decode(expression, format) ``` #### Arguments -- **str**: String expression to find in strlist. -- **strlist**: A string list is a string composed of substrings separated by , characters. +- **expression**: Expression containing encoded string data +- **format**: Same arguments as [encode](#encode) -## Binary String Functions +**Related functions**: -- [decode](#decode) - [encode](#encode) ### `encode` @@ -1264,48 +1704,58 @@ encode(expression, format) #### Arguments - **expression**: Expression containing string or binary data - - **format**: Supported formats are: `base64`, `hex` **Related functions**: -[decode](#decode) - -### `decode` - -Decode binary data from textual representation in string. - -``` -decode(expression, format) -``` - -#### Arguments - -- **expression**: Expression containing encoded string data -- **format**: Same arguments as [encode](#encode) - -**Related functions**: -[encode](#encode) +- [decode](#decode) ## Regular Expression Functions -Apache DataFusion uses a [PCRE-like] regular expression [syntax] +Apache DataFusion uses a [PCRE-like](https://en.wikibooks.org/wiki/Regular_Expressions/Perl-Compatible_Regular_Expressions) +regular expression [syntax](https://docs.rs/regex/latest/regex/#syntax) (minus support for several features including look-around and backreferences). The following regular expression functions are supported: +- [regexp_count](#regexp_count) - [regexp_like](#regexp_like) - [regexp_match](#regexp_match) - [regexp_replace](#regexp_replace) -[pcre-like]: https://en.wikibooks.org/wiki/Regular_Expressions/Perl-Compatible_Regular_Expressions -[syntax]: https://docs.rs/regex/latest/regex/#syntax +### `regexp_count` -### `regexp_like` +Returns the number of matches that a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has in a string. + +``` +regexp_count(str, regexp[, start, flags]) +``` + +#### Arguments + +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **regexp**: Regular expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **start**: - **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function. +- **flags**: Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: + - **i**: case-insensitive: letters match both upper and lower case + - **m**: multi-line mode: ^ and $ match begin/end of line + - **s**: allow . to match \n + - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used + - **U**: swap the meaning of x* and x*? -Returns true if a [regular expression] has at least one match in a string, -false otherwise. +#### Example + +```sql +> select regexp_count('abcAbAbc', 'abc', 2, 'i'); ++---------------------------------------------------------------+ +| regexp_count(Utf8("abcAbAbc"),Utf8("abc"),Int64(2),Utf8("i")) | ++---------------------------------------------------------------+ +| 1 | ++---------------------------------------------------------------+ +``` -[regular expression]: https://docs.rs/regex/latest/regex/#syntax +### `regexp_like` + +Returns true if a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has at least one match in a string, false otherwise. ``` regexp_like(str, regexp[, flags]) @@ -1313,12 +1763,9 @@ regexp_like(str, regexp[, flags]) #### Arguments -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **regexp**: Regular expression to test against the string expression. - Can be a constant, column, or function. -- **flags**: Optional regular expression flags that control the behavior of the - regular expression. The following flags are supported: +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **regexp**: Regular expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **flags**: Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: - **i**: case-insensitive: letters match both upper and lower case - **m**: multi-line mode: ^ and $ match begin/end of line - **s**: allow . to match \n @@ -1346,7 +1793,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo ### `regexp_match` -Returns a list of [regular expression](https://docs.rs/regex/latest/regex/#syntax) matches in a string. +Returns the first [regular expression](https://docs.rs/regex/latest/regex/#syntax) matches in a string. ``` regexp_match(str, regexp[, flags]) @@ -1354,12 +1801,10 @@ regexp_match(str, regexp[, flags]) #### Arguments -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. - **regexp**: Regular expression to match against. Can be a constant, column, or function. -- **flags**: Optional regular expression flags that control the behavior of the - regular expression. The following flags are supported: +- **flags**: Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: - **i**: case-insensitive: letters match both upper and lower case - **m**: multi-line mode: ^ and $ match begin/end of line - **s**: allow . to match \n @@ -1369,18 +1814,18 @@ regexp_match(str, regexp[, flags]) #### Example ```sql -select regexp_match('Köln', '[a-zA-Z]ö[a-zA-Z]{2}'); -+---------------------------------------------------------+ -| regexp_match(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) | -+---------------------------------------------------------+ -| [Köln] | -+---------------------------------------------------------+ -SELECT regexp_match('aBc', '(b|d)', 'i'); -+---------------------------------------------------+ -| regexp_match(Utf8("aBc"),Utf8("(b|d)"),Utf8("i")) | -+---------------------------------------------------+ -| [B] | -+---------------------------------------------------+ + > select regexp_match('Köln', '[a-zA-Z]ö[a-zA-Z]{2}'); + +---------------------------------------------------------+ + | regexp_match(Utf8("Köln"),Utf8("[a-zA-Z]ö[a-zA-Z]{2}")) | + +---------------------------------------------------------+ + | [Köln] | + +---------------------------------------------------------+ + SELECT regexp_match('aBc', '(b|d)', 'i'); + +---------------------------------------------------+ + | regexp_match(Utf8("aBc"),Utf8("(b|d)"),Utf8("i")) | + +---------------------------------------------------+ + | [B] | + +---------------------------------------------------+ ``` Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) @@ -1395,25 +1840,22 @@ regexp_replace(str, regexp, replacement[, flags]) #### Arguments -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. +- **str**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. - **regexp**: Regular expression to match against. Can be a constant, column, or function. -- **replacement**: Replacement string expression. - Can be a constant, column, or function, and any combination of string operators. -- **flags**: Optional regular expression flags that control the behavior of the - regular expression. The following flags are supported: - - **g**: (global) Search globally and don't return after the first match - - **i**: case-insensitive: letters match both upper and lower case - - **m**: multi-line mode: ^ and $ match begin/end of line - - **s**: allow . to match \n - - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used - - **U**: swap the meaning of x* and x*? +- **replacement**: Replacement string expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **flags**: Optional regular expression flags that control the behavior of the regular expression. The following flags are supported: +- **g**: (global) Search globally and don't return after the first match +- **i**: case-insensitive: letters match both upper and lower case +- **m**: multi-line mode: ^ and $ match begin/end of line +- **s**: allow . to match \n +- **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used +- **U**: swap the meaning of x* and x*? #### Example ```sql -SELECT regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', 'g'); +> select regexp_replace('foobarbaz', 'b(..)', 'X\\1Y', 'g'); +------------------------------------------------------------------------+ | regexp_replace(Utf8("foobarbaz"),Utf8("b(..)"),Utf8("X\1Y"),Utf8("g")) | +------------------------------------------------------------------------+ @@ -1429,58 +1871,36 @@ SELECT regexp_replace('aBc', '(b|d)', 'Ab\\1a', 'i'); Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/regexp.rs) -### `position` - -Returns the position of `substr` in `origstr` (counting from 1). If `substr` does -not appear in `origstr`, return 0. - -``` -position(substr in origstr) -``` - -#### Arguments - -- **substr**: The pattern string. -- **origstr**: The model string. - ## Time and Date Functions -- [now](#now) - [current_date](#current_date) - [current_time](#current_time) +- [current_timestamp](#current_timestamp) - [date_bin](#date_bin) -- [date_trunc](#date_trunc) -- [datetrunc](#datetrunc) +- [date_format](#date_format) - [date_part](#date_part) +- [date_trunc](#date_trunc) - [datepart](#datepart) -- [extract](#extract) -- [today](#today) +- [datetrunc](#datetrunc) +- [from_unixtime](#from_unixtime) - [make_date](#make_date) +- [now](#now) - [to_char](#to_char) +- [to_date](#to_date) +- [to_local_time](#to_local_time) - [to_timestamp](#to_timestamp) -- [to_timestamp_millis](#to_timestamp_millis) - [to_timestamp_micros](#to_timestamp_micros) -- [to_timestamp_seconds](#to_timestamp_seconds) +- [to_timestamp_millis](#to_timestamp_millis) - [to_timestamp_nanos](#to_timestamp_nanos) -- [from_unixtime](#from_unixtime) - -### `now` - -Returns the current UTC timestamp. - -The `now()` return value is determined at query time and will return the same timestamp, -no matter when in the query plan the function executes. - -``` -now() -``` +- [to_timestamp_seconds](#to_timestamp_seconds) +- [to_unixtime](#to_unixtime) +- [today](#today) ### `current_date` Returns the current UTC date. -The `current_date()` return value is determined at query time and will return the same date, -no matter when in the query plan the function executes. +The `current_date()` return value is determined at query time and will return the same date, no matter when in the query plan the function executes. ``` current_date() @@ -1490,30 +1910,25 @@ current_date() - today -### `today` - -_Alias of [current_date](#current_date)._ - ### `current_time` Returns the current UTC time. -The `current_time()` return value is determined at query time and will return the same time, -no matter when in the query plan the function executes. +The `current_time()` return value is determined at query time and will return the same time, no matter when in the query plan the function executes. ``` current_time() ``` +### `current_timestamp` + +_Alias of [now](#now)._ + ### `date_bin` -Calculates time intervals and returns the start of the interval nearest to the specified timestamp. -Use `date_bin` to downsample time series data by grouping rows into time-based "bins" or "windows" -and applying an aggregate or selector function to each window. +Calculates time intervals and returns the start of the interval nearest to the specified timestamp. Use `date_bin` to downsample time series data by grouping rows into time-based "bins" or "windows" and applying an aggregate or selector function to each window. -For example, if you "bin" or "window" data into 15 minute intervals, an input -timestamp of `2023-01-01T18:18:18Z` will be updated to the start time of the 15 -minute bin it is in: `2023-01-01T18:15:00Z`. +For example, if you "bin" or "window" data into 15 minute intervals, an input timestamp of `2023-01-01T18:18:18Z` will be updated to the start time of the 15 minute bin it is in: `2023-01-01T18:15:00Z`. ``` date_bin(interval, expression, origin-timestamp) @@ -1522,10 +1937,8 @@ date_bin(interval, expression, origin-timestamp) #### Arguments - **interval**: Bin interval. -- **expression**: Time expression to operate on. - Can be a constant, column, or function. -- **origin-timestamp**: Optional. Starting point used to determine bin boundaries. If not specified - defaults `1970-01-01T00:00:00Z` (the UNIX epoch in UTC). +- **expression**: Time expression to operate on. Can be a constant, column, or function. +- **origin-timestamp**: Optional. Starting point used to determine bin boundaries. If not specified defaults 1970-01-01T00:00:00Z (the UNIX epoch in UTC). The following intervals are supported: @@ -1541,6 +1954,49 @@ The following intervals are supported: - years - century +### `date_format` + +_Alias of [to_char](#to_char)._ + +### `date_part` + +Returns the specified part of the date as an integer. + +``` +date_part(part, expression) +``` + +#### Arguments + +- **part**: Part of the date to return. The following date parts are supported: + + - year + - quarter (emits value in inclusive range [1, 4] based on which quartile of the year the date is in) + - month + - week (week of the year) + - day (day of the month) + - hour + - minute + - second + - millisecond + - microsecond + - nanosecond + - dow (day of the week) + - doy (day of the year) + - epoch (seconds since Unix epoch) + +- **expression**: Time expression to operate on. Can be a constant, column, or function. + +#### Alternative Syntax + +```sql +extract(field FROM source) +``` + +#### Aliases + +- datepart + ### `date_trunc` Truncates a timestamp value to a specified precision. @@ -1551,8 +2007,7 @@ date_trunc(precision, expression) #### Arguments -- **precision**: Time precision to truncate to. - The following precisions are supported: +- **precision**: Time precision to truncate to. The following precisions are supported: - year / YEAR - quarter / QUARTER @@ -1563,73 +2018,44 @@ date_trunc(precision, expression) - minute / MINUTE - second / SECOND -- **expression**: Time expression to operate on. - Can be a constant, column, or function. +- **expression**: Time expression to operate on. Can be a constant, column, or function. #### Aliases - datetrunc +### `datepart` + +_Alias of [date_part](#date_part)._ + ### `datetrunc` _Alias of [date_trunc](#date_trunc)._ -### `date_part` +### `from_unixtime` -Returns the specified part of the date as an integer. +Converts an integer to RFC3339 timestamp format (`YYYY-MM-DDT00:00:00.000000000Z`). Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`) return the corresponding timestamp. ``` -date_part(part, expression) +from_unixtime(expression[, timezone]) ``` #### Arguments -- **part**: Part of the date to return. - The following date parts are supported: - - - year - - quarter _(emits value in inclusive range [1, 4] based on which quartile of the year the date is in)_ - - month - - week _(week of the year)_ - - day _(day of the month)_ - - hour - - minute - - second - - millisecond - - microsecond - - nanosecond - - dow _(day of the week)_ - - doy _(day of the year)_ - - epoch _(seconds since Unix epoch)_ - -- **expression**: Time expression to operate on. - Can be a constant, column, or function. - -#### Aliases - -- datepart - -### `datepart` - -_Alias of [date_part](#date_part)._ - -### `extract` - -Returns a sub-field from a time value as an integer. +- **expression**: The expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **timezone**: Optional timezone to use when converting the integer to a timestamp. If not provided, the default timezone is UTC. -``` -extract(field FROM source) -``` - -Equivalent to calling `date_part('field', source)`. For example, these are equivalent: +#### Example ```sql -extract(day FROM '2024-04-13'::date) -date_part('day', '2024-04-13'::date) +> select from_unixtime(1599572549, 'America/New_York'); ++-----------------------------------------------------------+ +| from_unixtime(Int64(1599572549),Utf8("America/New_York")) | ++-----------------------------------------------------------+ +| 2020-09-08T09:42:29-04:00 | ++-----------------------------------------------------------+ ``` -See [date_part](#date_part). - ### `make_date` Make a date from year/month/day component parts. @@ -1640,16 +2066,13 @@ make_date(year, month, day) #### Arguments -- **year**: Year to use when making the date. - Can be a constant, column or function, and any combination of arithmetic operators. -- **month**: Month to use when making the date. - Can be a constant, column or function, and any combination of arithmetic operators. -- **day**: Day to use when making the date. - Can be a constant, column or function, and any combination of arithmetic operators. +- **year**: Year to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators. +- **month**: Month to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators. +- **day**: Day to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators. #### Example -``` +```sql > select make_date(2023, 1, 31); +-------------------------------------------+ | make_date(Int64(2023),Int64(1),Int64(31)) | @@ -1664,13 +2087,25 @@ make_date(year, month, day) +-----------------------------------------------+ ``` -Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/make_date.rs) +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/make_date.rs) + +### `now` + +Returns the current UTC timestamp. + +The `now()` return value is determined at query time and will return the same timestamp, no matter when in the query plan the function executes. + +``` +now() +``` + +#### Aliases + +- current_timestamp ### `to_char` -Returns a string representation of a date, time, timestamp or duration based -on a [Chrono format]. Unlike the PostgreSQL equivalent of this function -numerical formatting is not supported. +Returns a string representation of a date, time, timestamp or duration based on a [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html). Unlike the PostgreSQL equivalent of this function numerical formatting is not supported. ``` to_char(expression, format) @@ -1678,15 +2113,14 @@ to_char(expression, format) #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function that results in a - date, time, timestamp or duration. -- **format**: A [Chrono format] string to use to convert the expression. +- **expression**: Expression to operate on. Can be a constant, column, or function that results in a date, time, timestamp or duration. +- **format**: A [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) string to use to convert the expression. +- **day**: Day to use when making the date. Can be a constant, column or function, and any combination of arithmetic operators. #### Example -``` -> > select to_char('2023-03-01'::date, '%d-%m-%Y'); +```sql +> select to_char('2023-03-01'::date, '%d-%m-%Y'); +----------------------------------------------+ | to_char(Utf8("2023-03-01"),Utf8("%d-%m-%Y")) | +----------------------------------------------+ @@ -1694,43 +2128,131 @@ to_char(expression, format) +----------------------------------------------+ ``` -Additional examples can be found [here] - -[here]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_char.rs +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_char.rs) #### Aliases - date_format -### `to_timestamp` +### `to_date` -Converts a value to a timestamp (`YYYY-MM-DDT00:00:00Z`). -Supports strings, integer, unsigned integer, and double types as input. -Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats] are provided. -Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). -Returns the corresponding timestamp. +Converts a value to a date (`YYYY-MM-DD`). +Supports strings, integer and double types as input. +Strings are parsed as YYYY-MM-DD (e.g. '2023-07-20') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. +Integers and doubles are interpreted as days since the unix epoch (`1970-01-01T00:00:00Z`). +Returns the corresponding date. -Note: `to_timestamp` returns `Timestamp(Nanosecond)`. The supported range for integer input is between `-9223372037` and `9223372036`. -Supported range for string input is between `1677-09-21T00:12:44.0` and `2262-04-11T23:47:16.0`. Please use `to_timestamp_seconds` -for the input outside of supported bounds. +Note: `to_date` returns Date32, which represents its values as the number of days since unix epoch(`1970-01-01`) stored as signed 32 bit value. The largest supported date value is `9999-12-31`. ``` -to_timestamp(expression[, ..., format_n]) +to_date('2017-05-31', '%Y-%m-%d') ``` #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **format_n**: Optional [Chrono format] strings to use to parse the expression. Formats will be tried in the order +- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. -[chrono format]: https://docs.rs/chrono/latest/chrono/format/strftime/index.html +#### Example + +```sql +> select to_date('2023-01-31'); ++-----------------------------+ +| to_date(Utf8("2023-01-31")) | ++-----------------------------+ +| 2023-01-31 | ++-----------------------------+ +> select to_date('2023/01/31', '%Y-%m-%d', '%Y/%m/%d'); ++---------------------------------------------------------------+ +| to_date(Utf8("2023/01/31"),Utf8("%Y-%m-%d"),Utf8("%Y/%m/%d")) | ++---------------------------------------------------------------+ +| 2023-01-31 | ++---------------------------------------------------------------+ +``` + +Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_date.rs) + +### `to_local_time` + +Converts a timestamp with a timezone to a timestamp without a timezone (with no offset or timezone information). This function handles daylight saving time changes. + +``` +to_local_time(expression) +``` + +#### Arguments + +- **expression**: Time expression to operate on. Can be a constant, column, or function. #### Example +```sql +> SELECT to_local_time('2024-04-01T00:00:20Z'::timestamp); ++---------------------------------------------+ +| to_local_time(Utf8("2024-04-01T00:00:20Z")) | ++---------------------------------------------+ +| 2024-04-01T00:00:20 | ++---------------------------------------------+ + +> SELECT to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels'); ++---------------------------------------------+ +| to_local_time(Utf8("2024-04-01T00:00:20Z")) | ++---------------------------------------------+ +| 2024-04-01T00:00:20 | ++---------------------------------------------+ + +> SELECT + time, + arrow_typeof(time) as type, + to_local_time(time) as to_local_time, + arrow_typeof(to_local_time(time)) as to_local_time_type +FROM ( + SELECT '2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels' AS time +); ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ +| time | type | to_local_time | to_local_time_type | ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ +| 2024-04-01T00:00:20+02:00 | Timestamp(Nanosecond, Some("Europe/Brussels")) | 2024-04-01T00:00:20 | Timestamp(Nanosecond, None) | ++---------------------------+------------------------------------------------+---------------------+-----------------------------+ + +# combine `to_local_time()` with `date_bin()` to bin on boundaries in the timezone rather +# than UTC boundaries + +> SELECT date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AS date_bin; ++---------------------+ +| date_bin | ++---------------------+ +| 2024-04-01T00:00:00 | ++---------------------+ + +> SELECT date_bin(interval '1 day', to_local_time('2024-04-01T00:00:20Z'::timestamp AT TIME ZONE 'Europe/Brussels')) AT TIME ZONE 'Europe/Brussels' AS date_bin_with_timezone; ++---------------------------+ +| date_bin_with_timezone | ++---------------------------+ +| 2024-04-01T00:00:00+02:00 | ++---------------------------+ +``` + +### `to_timestamp` + +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00Z`). Supports strings, integer, unsigned integer, and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats] are provided. Integers, unsigned integers, and doubles are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp. + +Note: `to_timestamp` returns `Timestamp(Nanosecond)`. The supported range for integer input is between `-9223372037` and `9223372036`. Supported range for string input is between `1677-09-21T00:12:44.0` and `2262-04-11T23:47:16.0`. Please use `to_timestamp_seconds` for the input outside of supported bounds. + +``` +to_timestamp(expression[, ..., format_n]) ``` + +#### Arguments + +- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. + +#### Example + +```sql > select to_timestamp('2023-01-31T09:26:56.123456789-05:00'); +-----------------------------------------------------------+ | to_timestamp(Utf8("2023-01-31T09:26:56.123456789-05:00")) | @@ -1747,79 +2269,65 @@ to_timestamp(expression[, ..., format_n]) Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) -### `to_timestamp_millis` +### `to_timestamp_micros` -Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). -Supports strings, integer, and unsigned integer types as input. -Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format]s are provided. -Integers and unsigned integers are interpreted as milliseconds since the unix epoch (`1970-01-01T00:00:00Z`). -Returns the corresponding timestamp. +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and unsigned integers are interpreted as microseconds since the unix epoch (`1970-01-01T00:00:00Z`) Returns the corresponding timestamp. ``` -to_timestamp_millis(expression[, ..., format_n]) +to_timestamp_micros(expression[, ..., format_n]) ``` #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **format_n**: Optional [Chrono format] strings to use to parse the expression. Formats will be tried in the order - they appear with the first successful one being returned. If none of the formats successfully parse the expression - an error will be returned. +- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. #### Example -``` -> select to_timestamp_millis('2023-01-31T09:26:56.123456789-05:00'); +```sql +> select to_timestamp_micros('2023-01-31T09:26:56.123456789-05:00'); +------------------------------------------------------------------+ -| to_timestamp_millis(Utf8("2023-01-31T09:26:56.123456789-05:00")) | +| to_timestamp_micros(Utf8("2023-01-31T09:26:56.123456789-05:00")) | +------------------------------------------------------------------+ -| 2023-01-31T14:26:56.123 | +| 2023-01-31T14:26:56.123456 | +------------------------------------------------------------------+ -> select to_timestamp_millis('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); +> select to_timestamp_micros('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); +---------------------------------------------------------------------------------------------------------------+ -| to_timestamp_millis(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | +| to_timestamp_micros(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | +---------------------------------------------------------------------------------------------------------------+ -| 2023-05-17T03:59:00.123 | +| 2023-05-17T03:59:00.123456 | +---------------------------------------------------------------------------------------------------------------+ ``` Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) -### `to_timestamp_micros` +### `to_timestamp_millis` -Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000Z`). -Supports strings, integer, and unsigned integer types as input. -Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format]s are provided. -Integers and unsigned integers are interpreted as microseconds since the unix epoch (`1970-01-01T00:00:00Z`) -Returns the corresponding timestamp. +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. Integers and unsigned integers are interpreted as milliseconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp. ``` -to_timestamp_micros(expression[, ..., format_n]) +to_timestamp_millis(expression[, ..., format_n]) ``` #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **format_n**: Optional [Chrono format] strings to use to parse the expression. Formats will be tried in the order - they appear with the first successful one being returned. If none of the formats successfully parse the expression - an error will be returned. +- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. #### Example -``` -> select to_timestamp_micros('2023-01-31T09:26:56.123456789-05:00'); +```sql +> select to_timestamp_millis('2023-01-31T09:26:56.123456789-05:00'); +------------------------------------------------------------------+ -| to_timestamp_micros(Utf8("2023-01-31T09:26:56.123456789-05:00")) | +| to_timestamp_millis(Utf8("2023-01-31T09:26:56.123456789-05:00")) | +------------------------------------------------------------------+ -| 2023-01-31T14:26:56.123456 | +| 2023-01-31T14:26:56.123 | +------------------------------------------------------------------+ -> select to_timestamp_micros('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); +> select to_timestamp_millis('03:59:00.123456789 05-17-2023', '%c', '%+', '%H:%M:%S%.f %m-%d-%Y'); +---------------------------------------------------------------------------------------------------------------+ -| to_timestamp_micros(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | +| to_timestamp_millis(Utf8("03:59:00.123456789 05-17-2023"),Utf8("%c"),Utf8("%+"),Utf8("%H:%M:%S%.f %m-%d-%Y")) | +---------------------------------------------------------------------------------------------------------------+ -| 2023-05-17T03:59:00.123456 | +| 2023-05-17T03:59:00.123 | +---------------------------------------------------------------------------------------------------------------+ ``` @@ -1827,11 +2335,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo ### `to_timestamp_nanos` -Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000000Z`). -Supports strings, integer, and unsigned integer types as input. -Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format]s are provided. -Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`). -Returns the corresponding timestamp. +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000000000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp. ``` to_timestamp_nanos(expression[, ..., format_n]) @@ -1839,15 +2343,12 @@ to_timestamp_nanos(expression[, ..., format_n]) #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **format_n**: Optional [Chrono format] strings to use to parse the expression. Formats will be tried in the order - they appear with the first successful one being returned. If none of the formats successfully parse the expression - an error will be returned. +- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. #### Example -``` +```sql > select to_timestamp_nanos('2023-01-31T09:26:56.123456789-05:00'); +-----------------------------------------------------------------+ | to_timestamp_nanos(Utf8("2023-01-31T09:26:56.123456789-05:00")) | @@ -1866,11 +2367,7 @@ Additional examples can be found [here](https://github.com/apache/datafusion/blo ### `to_timestamp_seconds` -Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). -Supports strings, integer, and unsigned integer types as input. -Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format]s are provided. -Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). -Returns the corresponding timestamp. +Converts a value to a timestamp (`YYYY-MM-DDT00:00:00.000Z`). Supports strings, integer, and unsigned integer types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html)s are provided. Integers and unsigned integers are interpreted as seconds since the unix epoch (`1970-01-01T00:00:00Z`). Returns the corresponding timestamp. ``` to_timestamp_seconds(expression[, ..., format_n]) @@ -1878,15 +2375,12 @@ to_timestamp_seconds(expression[, ..., format_n]) #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. -- **format_n**: Optional [Chrono format] strings to use to parse the expression. Formats will be tried in the order - they appear with the first successful one being returned. If none of the formats successfully parse the expression - an error will be returned. +- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. #### Example -``` +```sql > select to_timestamp_seconds('2023-01-31T09:26:56.123456789-05:00'); +-------------------------------------------------------------------+ | to_timestamp_seconds(Utf8("2023-01-31T09:26:56.123456789-05:00")) | @@ -1903,73 +2397,96 @@ to_timestamp_seconds(expression[, ..., format_n]) Additional examples can be found [here](https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/to_timestamp.rs) -### `from_unixtime` +### `to_unixtime` -Converts an integer to RFC3339 timestamp format (`YYYY-MM-DDT00:00:00.000000000Z`). -Integers and unsigned integers are interpreted as nanoseconds since the unix epoch (`1970-01-01T00:00:00Z`) -return the corresponding timestamp. +Converts a value to seconds since the unix epoch (`1970-01-01T00:00:00Z`). Supports strings, dates, timestamps and double types as input. Strings are parsed as RFC3339 (e.g. '2023-07-20T05:44:00') if no [Chrono formats](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) are provided. ``` -from_unixtime(expression) +to_unixtime(expression[, ..., format_n]) ``` #### Arguments -- **expression**: Expression to operate on. - Can be a constant, column, or function, and any combination of arithmetic operators. +- **expression**: Expression to operate on. Can be a constant, column, or function, and any combination of arithmetic operators. +- **format_n**: Optional [Chrono format](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) strings to use to parse the expression. Formats will be tried in the order they appear with the first successful one being returned. If none of the formats successfully parse the expression an error will be returned. + +#### Example + +```sql +> select to_unixtime('2020-09-08T12:00:00+00:00'); ++------------------------------------------------+ +| to_unixtime(Utf8("2020-09-08T12:00:00+00:00")) | ++------------------------------------------------+ +| 1599566400 | ++------------------------------------------------+ +> select to_unixtime('01-14-2023 01:01:30+05:30', '%q', '%d-%m-%Y %H/%M/%S', '%+', '%m-%d-%Y %H:%M:%S%#z'); ++-----------------------------------------------------------------------------------------------------------------------------+ +| to_unixtime(Utf8("01-14-2023 01:01:30+05:30"),Utf8("%q"),Utf8("%d-%m-%Y %H/%M/%S"),Utf8("%+"),Utf8("%m-%d-%Y %H:%M:%S%#z")) | ++-----------------------------------------------------------------------------------------------------------------------------+ +| 1673638290 | ++-----------------------------------------------------------------------------------------------------------------------------+ +``` + +### `today` + +_Alias of [current_date](#current_date)._ ## Array Functions +- [array_any_value](#array_any_value) - [array_append](#array_append) -- [array_sort](#array_sort) - [array_cat](#array_cat) - [array_concat](#array_concat) - [array_contains](#array_contains) - [array_dims](#array_dims) +- [array_distance](#array_distance) - [array_distinct](#array_distinct) -- [array_has](#array_has) -- [array_has_all](#array_has_all) -- [array_has_any](#array_has_any) - [array_element](#array_element) - [array_empty](#array_empty) - [array_except](#array_except) - [array_extract](#array_extract) -- [array_fill](#array_fill) +- [array_has](#array_has) +- [array_has_all](#array_has_all) +- [array_has_any](#array_has_any) - [array_indexof](#array_indexof) - [array_intersect](#array_intersect) - [array_join](#array_join) - [array_length](#array_length) - [array_ndims](#array_ndims) -- [array_prepend](#array_prepend) -- [array_pop_front](#array_pop_front) - [array_pop_back](#array_pop_back) +- [array_pop_front](#array_pop_front) - [array_position](#array_position) - [array_positions](#array_positions) +- [array_prepend](#array_prepend) - [array_push_back](#array_push_back) - [array_push_front](#array_push_front) -- [array_repeat](#array_repeat) -- [array_resize](#array_resize) - [array_remove](#array_remove) -- [array_remove_n](#array_remove_n) - [array_remove_all](#array_remove_all) +- [array_remove_n](#array_remove_n) +- [array_repeat](#array_repeat) - [array_replace](#array_replace) -- [array_replace_n](#array_replace_n) - [array_replace_all](#array_replace_all) +- [array_replace_n](#array_replace_n) +- [array_resize](#array_resize) - [array_reverse](#array_reverse) - [array_slice](#array_slice) +- [array_sort](#array_sort) - [array_to_string](#array_to_string) - [array_union](#array_union) - [cardinality](#cardinality) - [empty](#empty) - [flatten](#flatten) - [generate_series](#generate_series) +- [list_any_value](#list_any_value) - [list_append](#list_append) -- [list_sort](#list_sort) - [list_cat](#list_cat) - [list_concat](#list_concat) +- [list_contains](#list_contains) - [list_dims](#list_dims) +- [list_distance](#list_distance) - [list_distinct](#list_distinct) - [list_element](#list_element) +- [list_empty](#list_empty) - [list_except](#list_except) - [list_extract](#list_extract) - [list_has](#list_has) @@ -1980,121 +2497,89 @@ from_unixtime(expression) - [list_join](#list_join) - [list_length](#list_length) - [list_ndims](#list_ndims) -- [list_prepend](#list_prepend) - [list_pop_back](#list_pop_back) - [list_pop_front](#list_pop_front) - [list_position](#list_position) - [list_positions](#list_positions) +- [list_prepend](#list_prepend) - [list_push_back](#list_push_back) - [list_push_front](#list_push_front) -- [list_repeat](#list_repeat) -- [list_resize](#list_resize) - [list_remove](#list_remove) -- [list_remove_n](#list_remove_n) - [list_remove_all](#list_remove_all) +- [list_remove_n](#list_remove_n) +- [list_repeat](#list_repeat) - [list_replace](#list_replace) -- [list_replace_n](#list_replace_n) - [list_replace_all](#list_replace_all) +- [list_replace_n](#list_replace_n) +- [list_resize](#list_resize) +- [list_reverse](#list_reverse) - [list_slice](#list_slice) +- [list_sort](#list_sort) - [list_to_string](#list_to_string) - [list_union](#list_union) - [make_array](#make_array) - [make_list](#make_list) +- [range](#range) - [string_to_array](#string_to_array) - [string_to_list](#string_to_list) -- [trim_array](#trim_array) -- [range](#range) - -### `array_append` - -Appends an element to the end of an array. - -``` -array_append(array, element) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **element**: Element to append to the array. - -#### Example - -``` -> select array_append([1, 2, 3], 4); -+--------------------------------------+ -| array_append(List([1,2,3]),Int64(4)) | -+--------------------------------------+ -| [1, 2, 3, 4] | -+--------------------------------------+ -``` - -#### Aliases - -- array_push_back -- list_append -- list_push_back -### `array_sort` +### `array_any_value` -Sort array. +Extracts the element with the index n from the array. ``` -array_sort(array, desc, nulls_first) +array_element(array, index) ``` #### Arguments -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **desc**: Whether to sort in descending order(`ASC` or `DESC`). -- **nulls_first**: Whether to sort nulls first(`NULLS FIRST` or `NULLS LAST`). +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **index**: Index to extract the element from the array. #### Example -``` -> select array_sort([3, 1, 2]); -+-----------------------------+ -| array_sort(List([3,1,2])) | -+-----------------------------+ -| [1, 2, 3] | -+-----------------------------+ +```sql +> select array_element([1, 2, 3, 4], 3); ++-----------------------------------------+ +| array_element(List([1,2,3,4]),Int64(3)) | ++-----------------------------------------+ +| 3 | ++-----------------------------------------+ ``` #### Aliases -- list_sort +- list_any_value -### `array_resize` +### `array_append` -Resizes the list to contain size elements. Initializes new elements with value or empty if value is not set. +Appends an element to the end of an array. ``` -array_resize(array, size, value) +array_append(array, element) ``` #### Arguments -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **size**: New size of given array. -- **value**: Defines new elements' value or empty if value is not set. +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **element**: Element to append to the array. #### Example -``` -> select array_resize([1, 2, 3], 5, 0); -+-------------------------------------+ -| array_resize(List([1,2,3],5,0)) | -+-------------------------------------+ -| [1, 2, 3, 0, 0] | -+-------------------------------------+ +```sql +> select array_append([1, 2, 3], 4); ++--------------------------------------+ +| array_append(List([1,2,3]),Int64(4)) | ++--------------------------------------+ +| [1, 2, 3, 4] | ++--------------------------------------+ ``` #### Aliases -- list_resize +- list_append +- array_push_back +- list_push_back ### `array_cat` @@ -2102,95 +2587,37 @@ _Alias of [array_concat](#array_concat)._ ### `array_concat` -Concatenates arrays. - -``` -array_concat(array[, ..., array_n]) -``` - -#### Arguments - -- **array**: Array expression to concatenate. - Can be a constant, column, or function, and any combination of array operators. -- **array_n**: Subsequent array column or literal array to concatenate. - -#### Example - -``` -> select array_concat([1, 2], [3, 4], [5, 6]); -+---------------------------------------------------+ -| array_concat(List([1,2]),List([3,4]),List([5,6])) | -+---------------------------------------------------+ -| [1, 2, 3, 4, 5, 6] | -+---------------------------------------------------+ -``` - -#### Aliases - -- array_cat -- list_cat -- list_concat - -### `array_contains` - -_Alias of [array_has](#array_has)._ - -### `array_has` - -Returns true if the array contains the element - -``` -array_has(array, element) -``` - -#### Arguments - -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **element**: Scalar or Array expression. - Can be a constant, column, or function, and any combination of array operators. - -#### Aliases - -- list_has - -### `array_has_all` - -Returns true if all elements of sub-array exist in array +Appends an element to the end of an array. ``` -array_has_all(array, sub-array) +array_append(array, element) ``` #### Arguments -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **sub-array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. - -#### Aliases - -- list_has_all - -### `array_has_any` +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **element**: Element to append to the array. -Returns true if any elements exist in both arrays +#### Example -``` -array_has_any(array, sub-array) +```sql +> select array_append([1, 2, 3], 4); ++--------------------------------------+ +| array_append(List([1,2,3]),Int64(4)) | ++--------------------------------------+ +| [1, 2, 3, 4] | ++--------------------------------------+ ``` -#### Arguments +#### Aliases -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **sub-array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. +- array_cat +- list_concat +- list_cat -#### Aliases +### `array_contains` -- list_has_any +_Alias of [array_has](#array_has)._ ### `array_dims` @@ -2202,12 +2629,11 @@ array_dims(array) #### Arguments -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. #### Example -``` +```sql > select array_dims([[1, 2, 3], [4, 5, 6]]); +---------------------------------+ | array_dims(List([1,2,3,4,5,6])) | @@ -2220,6 +2646,34 @@ array_dims(array) - list_dims +### `array_distance` + +Returns the Euclidean distance between two input arrays of equal length. + +``` +array_distance(array1, array2) +``` + +#### Arguments + +- **array1**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **array2**: Array expression. Can be a constant, column, or function, and any combination of array operators. + +#### Example + +```sql +> select array_distance([1, 2], [1, 4]); ++------------------------------------+ +| array_distance(List([1,2], [1,4])) | ++------------------------------------+ +| 2.0 | ++------------------------------------+ +``` + +#### Aliases + +- list_distance + ### `array_distinct` Returns distinct values from the array after removing duplicates. @@ -2230,12 +2684,11 @@ array_distinct(array) #### Arguments -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. #### Example -``` +```sql > select array_distinct([1, 3, 2, 3, 1, 2, 4]); +---------------------------------+ | array_distinct(List([1,2,3,4])) | @@ -2258,13 +2711,12 @@ array_element(array, index) #### Arguments -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. - **index**: Index to extract the element from the array. #### Example -``` +```sql > select array_element([1, 2, 3, 4], 3); +-----------------------------------------+ | array_element(List([1,2,3,4]),Int64(3)) | @@ -2279,82 +2731,161 @@ array_element(array, index) - list_element - list_extract +### `array_empty` + +_Alias of [empty](#empty)._ + +### `array_except` + +Returns an array of the elements that appear in the first array but not in the second. + +``` +array_except(array1, array2) +``` + +#### Arguments + +- **array1**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **array2**: Array expression. Can be a constant, column, or function, and any combination of array operators. + +#### Example + +```sql +> select array_except([1, 2, 3, 4], [5, 6, 3, 4]); ++----------------------------------------------------+ +| array_except([1, 2, 3, 4], [5, 6, 3, 4]); | ++----------------------------------------------------+ +| [1, 2] | ++----------------------------------------------------+ +> select array_except([1, 2, 3, 4], [3, 4, 5, 6]); ++----------------------------------------------------+ +| array_except([1, 2, 3, 4], [3, 4, 5, 6]); | ++----------------------------------------------------+ +| [1, 2] | ++----------------------------------------------------+ +``` + +#### Aliases + +- list_except + ### `array_extract` _Alias of [array_element](#array_element)._ -### `array_fill` - -Returns an array filled with copies of the given value. +### `array_has` -DEPRECATED: use `array_repeat` instead! +Returns true if the array contains the element. ``` -array_fill(element, array) +array_has(array, element) ``` #### Arguments -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **element**: Element to copy to the array. +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **element**: Scalar or Array expression. Can be a constant, column, or function, and any combination of array operators. -### `flatten` +#### Example + +```sql +> select array_has([1, 2, 3], 2); ++-----------------------------+ +| array_has(List([1,2,3]), 2) | ++-----------------------------+ +| true | ++-----------------------------+ +``` -Converts an array of arrays to a flat array +#### Aliases -- Applies to any depth of nested arrays -- Does not change arrays that are already flat +- list_has +- array_contains +- list_contains -The flattened array contains all the elements from all source arrays. +### `array_has_all` + +Returns true if the array contains the element. + +``` +array_has(array, element) +``` #### Arguments -- **array**: Array expression - Can be a constant, column, or function, and any combination of array operators. +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **element**: Scalar or Array expression. Can be a constant, column, or function, and any combination of array operators. + +#### Example + +```sql +> select array_has([1, 2, 3], 2); ++-----------------------------+ +| array_has(List([1,2,3]), 2) | ++-----------------------------+ +| true | ++-----------------------------+ +``` + +#### Aliases + +- list_has_all + +### `array_has_any` + +Returns true if the array contains the element. ``` -flatten(array) +array_has(array, element) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **element**: Scalar or Array expression. Can be a constant, column, or function, and any combination of array operators. + +#### Example + +```sql +> select array_has([1, 2, 3], 2); ++-----------------------------+ +| array_has(List([1,2,3]), 2) | ++-----------------------------+ +| true | ++-----------------------------+ ``` +#### Aliases + +- list_has_any + ### `array_indexof` _Alias of [array_position](#array_position)._ ### `array_intersect` -Returns an array of elements in the intersection of array1 and array2. +Returns distinct values from the array after removing duplicates. ``` -array_intersect(array1, array2) +array_distinct(array) ``` #### Arguments -- **array1**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **array2**: Array expression. - Can be a constant, column, or function, and any combination of array operators. +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. #### Example -``` -> select array_intersect([1, 2, 3, 4], [5, 6, 3, 4]); -+----------------------------------------------------+ -| array_intersect([1, 2, 3, 4], [5, 6, 3, 4]); | -+----------------------------------------------------+ -| [3, 4] | -+----------------------------------------------------+ -> select array_intersect([1, 2, 3, 4], [5, 6, 7, 8]); -+----------------------------------------------------+ -| array_intersect([1, 2, 3, 4], [5, 6, 7, 8]); | -+----------------------------------------------------+ -| [] | -+----------------------------------------------------+ +```sql +> select array_distinct([1, 3, 2, 3, 1, 2, 4]); ++---------------------------------+ +| array_distinct(List([1,2,3,4])) | ++---------------------------------+ +| [1, 2, 3, 4] | ++---------------------------------+ ``` ---- - #### Aliases - list_intersect @@ -2373,19 +2904,18 @@ array_length(array, dimension) #### Arguments -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. - **dimension**: Array dimension. #### Example -``` -> select array_length([1, 2, 3, 4, 5]); -+---------------------------------+ -| array_length(List([1,2,3,4,5])) | -+---------------------------------+ -| 5 | -+---------------------------------+ +```sql +> select array_length([1, 2, 3, 4, 5], 1); ++-------------------------------------------+ +| array_length(List([1,2,3,4,5]), 1) | ++-------------------------------------------+ +| 5 | ++-------------------------------------------+ ``` #### Aliases @@ -2394,120 +2924,126 @@ array_length(array, dimension) ### `array_ndims` -Returns the number of dimensions of the array. +Returns an array of the array's dimensions. ``` -array_ndims(array, element) +array_dims(array) ``` #### Arguments -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. #### Example -``` -> select array_ndims([[1, 2, 3], [4, 5, 6]]); -+----------------------------------+ -| array_ndims(List([1,2,3,4,5,6])) | -+----------------------------------+ -| 2 | -+----------------------------------+ +```sql +> select array_dims([[1, 2, 3], [4, 5, 6]]); ++---------------------------------+ +| array_dims(List([1,2,3,4,5,6])) | ++---------------------------------+ +| [2, 3] | ++---------------------------------+ ``` #### Aliases - list_ndims -### `array_prepend` +### `array_pop_back` -Prepends an element to the beginning of an array. +Extracts the element with the index n from the array. ``` -array_prepend(element, array) +array_element(array, index) ``` #### Arguments -- **element**: Element to prepend to the array. -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **index**: Index to extract the element from the array. #### Example -``` -> select array_prepend(1, [2, 3, 4]); -+---------------------------------------+ -| array_prepend(Int64(1),List([2,3,4])) | -+---------------------------------------+ -| [1, 2, 3, 4] | -+---------------------------------------+ +```sql +> select array_element([1, 2, 3, 4], 3); ++-----------------------------------------+ +| array_element(List([1,2,3,4]),Int64(3)) | ++-----------------------------------------+ +| 3 | ++-----------------------------------------+ ``` #### Aliases -- array_push_front -- list_prepend -- list_push_front +- list_pop_back ### `array_pop_front` -Returns the array without the first element. +Extracts the element with the index n from the array. ``` -array_pop_front(array) +array_element(array, index) ``` #### Arguments -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **index**: Index to extract the element from the array. #### Example -``` -> select array_pop_front([1, 2, 3]); -+-------------------------------+ -| array_pop_front(List([1,2,3])) | -+-------------------------------+ -| [2, 3] | -+-------------------------------+ +```sql +> select array_element([1, 2, 3, 4], 3); ++-----------------------------------------+ +| array_element(List([1,2,3,4]),Int64(3)) | ++-----------------------------------------+ +| 3 | ++-----------------------------------------+ ``` #### Aliases - list_pop_front -### `array_pop_back` +### `array_position` -Returns the array without the last element. +Returns the position of the first occurrence of the specified element in the array. ``` -array_pop_back(array) +array_position(array, element) +array_position(array, element, index) ``` #### Arguments -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **element**: Element to search for position in the array. +- **index**: Index at which to start searching. #### Example -``` -> select array_pop_back([1, 2, 3]); -+-------------------------------+ -| array_pop_back(List([1,2,3])) | -+-------------------------------+ -| [1, 2] | -+-------------------------------+ +```sql +> select array_position([1, 2, 2, 3, 1, 4], 2); ++----------------------------------------------+ +| array_position(List([1,2,2,3,1,4]),Int64(2)) | ++----------------------------------------------+ +| 2 | ++----------------------------------------------+ +> select array_position([1, 2, 2, 3, 1, 4], 2, 3); ++----------------------------------------------------+ +| array_position(List([1,2,2,3,1,4]),Int64(2), Int64(3)) | ++----------------------------------------------------+ +| 3 | ++----------------------------------------------------+ ``` #### Aliases -- list_pop_back +- list_position +- array_indexof +- list_indexof -### `array_position` +### `array_positions` Returns the position of the first occurrence of the specified element in the array. @@ -2518,56 +3054,60 @@ array_position(array, element, index) #### Arguments -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. - **element**: Element to search for position in the array. - **index**: Index at which to start searching. #### Example -``` +```sql > select array_position([1, 2, 2, 3, 1, 4], 2); +----------------------------------------------+ | array_position(List([1,2,2,3,1,4]),Int64(2)) | +----------------------------------------------+ | 2 | +----------------------------------------------+ +> select array_position([1, 2, 2, 3, 1, 4], 2, 3); ++----------------------------------------------------+ +| array_position(List([1,2,2,3,1,4]),Int64(2), Int64(3)) | ++----------------------------------------------------+ +| 3 | ++----------------------------------------------------+ ``` #### Aliases -- array_indexof -- list_indexof -- list_position +- list_positions -### `array_positions` +### `array_prepend` -Searches for an element in the array, returns all occurrences. +Appends an element to the end of an array. ``` -array_positions(array, element) +array_append(array, element) ``` #### Arguments -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **element**: Element to search for positions in the array. +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **element**: Element to append to the array. #### Example -``` -> select array_positions([1, 2, 2, 3, 1, 4], 2); -+-----------------------------------------------+ -| array_positions(List([1,2,2,3,1,4]),Int64(2)) | -+-----------------------------------------------+ -| [2, 3] | -+-----------------------------------------------+ +```sql +> select array_append([1, 2, 3], 4); ++--------------------------------------+ +| array_append(List([1,2,3]),Int64(4)) | ++--------------------------------------+ +| [1, 2, 3, 4] | ++--------------------------------------+ ``` #### Aliases -- list_positions +- list_prepend +- array_push_front +- list_push_front ### `array_push_back` @@ -2577,45 +3117,35 @@ _Alias of [array_append](#array_append)._ _Alias of [array_prepend](#array_prepend)._ -### `array_repeat` +### `array_remove` -Returns an array containing element `count` times. +Removes the first element from the array equal to the given value. ``` -array_repeat(element, count) +array_remove(array, element) ``` #### Arguments -- **element**: Element expression. - Can be a constant, column, or function, and any combination of array operators. -- **count**: Value of how many times to repeat the element. +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **element**: Element to be removed from the array. #### Example -``` -> select array_repeat(1, 3); -+---------------------------------+ -| array_repeat(Int64(1),Int64(3)) | -+---------------------------------+ -| [1, 1, 1] | -+---------------------------------+ -``` - -``` -> select array_repeat([1, 2], 2); -+------------------------------------+ -| array_repeat(List([1,2]),Int64(2)) | -+------------------------------------+ -| [[1, 2], [1, 2]] | -+------------------------------------+ +```sql +> select array_remove([1, 2, 2, 3, 2, 1, 4], 2); ++----------------------------------------------+ +| array_remove(List([1,2,2,3,2,1,4]),Int64(2)) | ++----------------------------------------------+ +| [1, 2, 3, 2, 1, 4] | ++----------------------------------------------+ ``` #### Aliases -- list_repeat +- list_remove -### `array_remove` +### `array_remove_all` Removes the first element from the array equal to the given value. @@ -2625,13 +3155,12 @@ array_remove(array, element) #### Arguments -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. - **element**: Element to be removed from the array. #### Example -``` +```sql > select array_remove([1, 2, 2, 3, 2, 1, 4], 2); +----------------------------------------------+ | array_remove(List([1,2,2,3,2,1,4]),Int64(2)) | @@ -2642,98 +3171,101 @@ array_remove(array, element) #### Aliases -- list_remove +- list_remove_all ### `array_remove_n` -Removes the first `max` elements from the array equal to the given value. +Removes the first element from the array equal to the given value. ``` -array_remove_n(array, element, max) +array_remove(array, element) ``` #### Arguments -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. - **element**: Element to be removed from the array. -- **max**: Number of first occurrences to remove. #### Example -``` -> select array_remove_n([1, 2, 2, 3, 2, 1, 4], 2, 2); -+---------------------------------------------------------+ -| array_remove_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(2)) | -+---------------------------------------------------------+ -| [1, 3, 2, 1, 4] | -+---------------------------------------------------------+ +```sql +> select array_remove([1, 2, 2, 3, 2, 1, 4], 2); ++----------------------------------------------+ +| array_remove(List([1,2,2,3,2,1,4]),Int64(2)) | ++----------------------------------------------+ +| [1, 2, 3, 2, 1, 4] | ++----------------------------------------------+ ``` #### Aliases - list_remove_n -### `array_remove_all` +### `array_repeat` -Removes all elements from the array equal to the given value. +Returns an array containing element `count` times. ``` -array_remove_all(array, element) +array_repeat(element, count) ``` #### Arguments -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **element**: Element to be removed from the array. +- **element**: Element expression. Can be a constant, column, or function, and any combination of array operators. +- **count**: Value of how many times to repeat the element. #### Example -``` -> select array_remove_all([1, 2, 2, 3, 2, 1, 4], 2); -+--------------------------------------------------+ -| array_remove_all(List([1,2,2,3,2,1,4]),Int64(2)) | -+--------------------------------------------------+ -| [1, 3, 1, 4] | -+--------------------------------------------------+ +```sql +> select array_repeat(1, 3); ++---------------------------------+ +| array_repeat(Int64(1),Int64(3)) | ++---------------------------------+ +| [1, 1, 1] | ++---------------------------------+ +> select array_repeat([1, 2], 2); ++------------------------------------+ +| array_repeat(List([1,2]),Int64(2)) | ++------------------------------------+ +| [[1, 2], [1, 2]] | ++------------------------------------+ ``` #### Aliases -- list_remove_all +- list_repeat ### `array_replace` -Replaces the first occurrence of the specified element with another specified element. +Replaces the first `max` occurrences of the specified element with another specified element. ``` -array_replace(array, from, to) +array_replace_n(array, from, to, max) ``` #### Arguments -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. - **from**: Initial element. - **to**: Final element. +- **max**: Number of first occurrences to replace. #### Example -``` -> select array_replace([1, 2, 2, 3, 2, 1, 4], 2, 5); -+--------------------------------------------------------+ -| array_replace(List([1,2,2,3,2,1,4]),Int64(2),Int64(5)) | -+--------------------------------------------------------+ -| [1, 5, 2, 3, 2, 1, 4] | -+--------------------------------------------------------+ +```sql +> select array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2); ++-------------------------------------------------------------------+ +| array_replace_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(5),Int64(2)) | ++-------------------------------------------------------------------+ +| [1, 5, 5, 3, 2, 1, 4] | ++-------------------------------------------------------------------+ ``` #### Aliases - list_replace -### `array_replace_n` +### `array_replace_all` Replaces the first `max` occurrences of the specified element with another specified element. @@ -2743,15 +3275,14 @@ array_replace_n(array, from, to, max) #### Arguments -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. - **from**: Initial element. - **to**: Final element. - **max**: Number of first occurrences to replace. #### Example -``` +```sql > select array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2); +-------------------------------------------------------------------+ | array_replace_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(5),Int64(2)) | @@ -2762,37 +3293,66 @@ array_replace_n(array, from, to, max) #### Aliases -- list_replace_n +- list_replace_all -### `array_replace_all` +### `array_replace_n` -Replaces all occurrences of the specified element with another specified element. +Replaces the first `max` occurrences of the specified element with another specified element. ``` -array_replace_all(array, from, to) +array_replace_n(array, from, to, max) ``` #### Arguments -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. - **from**: Initial element. - **to**: Final element. +- **max**: Number of first occurrences to replace. #### Example +```sql +> select array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2); ++-------------------------------------------------------------------+ +| array_replace_n(List([1,2,2,3,2,1,4]),Int64(2),Int64(5),Int64(2)) | ++-------------------------------------------------------------------+ +| [1, 5, 5, 3, 2, 1, 4] | ++-------------------------------------------------------------------+ +``` + +#### Aliases + +- list_replace_n + +### `array_resize` + +Resizes the list to contain size elements. Initializes new elements with value or empty if value is not set. + ``` -> select array_replace_all([1, 2, 2, 3, 2, 1, 4], 2, 5); -+------------------------------------------------------------+ -| array_replace_all(List([1,2,2,3,2,1,4]),Int64(2),Int64(5)) | -+------------------------------------------------------------+ -| [1, 5, 5, 3, 5, 1, 4] | -+------------------------------------------------------------+ +array_resize(array, size, value) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **size**: New size of given array. +- **value**: Defines new elements' value or empty if value is not set. + +#### Example + +```sql +> select array_resize([1, 2, 3], 5, 0); ++-------------------------------------+ +| array_resize(List([1,2,3],5,0)) | ++-------------------------------------+ +| [1, 2, 3, 0, 0] | ++-------------------------------------+ ``` #### Aliases -- list_replace_all +- list_resize ### `array_reverse` @@ -2804,12 +3364,11 @@ array_reverse(array) #### Arguments -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. #### Example -``` +```sql > select array_reverse([1, 2, 3, 4]); +------------------------------------------------------------+ | array_reverse(List([1, 2, 3, 4])) | @@ -2824,143 +3383,117 @@ array_reverse(array) ### `array_slice` -Returns a slice of the array based on 1-indexed start and end positions. +Extracts the element with the index n from the array. ``` -array_slice(array, begin, end) +array_element(array, index) ``` #### Arguments -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **begin**: Index of the first element. - If negative, it counts backward from the end of the array. -- **end**: Index of the last element. - If negative, it counts backward from the end of the array. -- **stride**: Stride of the array slice. The default is 1. +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **index**: Index to extract the element from the array. #### Example -``` -> select array_slice([1, 2, 3, 4, 5, 6, 7, 8], 3, 6); -+--------------------------------------------------------+ -| array_slice(List([1,2,3,4,5,6,7,8]),Int64(3),Int64(6)) | -+--------------------------------------------------------+ -| [3, 4, 5, 6] | -+--------------------------------------------------------+ +```sql +> select array_element([1, 2, 3, 4], 3); ++-----------------------------------------+ +| array_element(List([1,2,3,4]),Int64(3)) | ++-----------------------------------------+ +| 3 | ++-----------------------------------------+ ``` #### Aliases - list_slice -### `array_to_string` +### `array_sort` -Converts each element to its text representation. +Sort array. ``` -array_to_string(array, delimiter) +array_sort(array, desc, nulls_first) ``` #### Arguments -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **delimiter**: Array element separator. +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **desc**: Whether to sort in descending order(`ASC` or `DESC`). +- **nulls_first**: Whether to sort nulls first(`NULLS FIRST` or `NULLS LAST`). #### Example -``` -> select array_to_string([[1, 2, 3, 4], [5, 6, 7, 8]], ','); -+----------------------------------------------------+ -| array_to_string(List([1,2,3,4,5,6,7,8]),Utf8(",")) | -+----------------------------------------------------+ -| 1,2,3,4,5,6,7,8 | -+----------------------------------------------------+ +```sql +> select array_sort([3, 1, 2]); ++-----------------------------+ +| array_sort(List([3,1,2])) | ++-----------------------------+ +| [1, 2, 3] | ++-----------------------------+ ``` #### Aliases -- array_join -- list_join -- list_to_string +- list_sort -### `array_union` +### `array_to_string` -Returns an array of elements that are present in both arrays (all elements from both arrays) with out duplicates. +Converts each element to its text representation. ``` -array_union(array1, array2) +array_to_string(array, delimiter) ``` #### Arguments -- **array1**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **array2**: Array expression. - Can be a constant, column, or function, and any combination of array operators. +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **delimiter**: Array element separator. #### Example -``` -> select array_union([1, 2, 3, 4], [5, 6, 3, 4]); -+----------------------------------------------------+ -| array_union([1, 2, 3, 4], [5, 6, 3, 4]); | -+----------------------------------------------------+ -| [1, 2, 3, 4, 5, 6] | -+----------------------------------------------------+ -> select array_union([1, 2, 3, 4], [5, 6, 7, 8]); +```sql +> select array_to_string([[1, 2, 3, 4], [5, 6, 7, 8]], ','); +----------------------------------------------------+ -| array_union([1, 2, 3, 4], [5, 6, 7, 8]); | +| array_to_string(List([1,2,3,4,5,6,7,8]),Utf8(",")) | +----------------------------------------------------+ -| [1, 2, 3, 4, 5, 6, 7, 8] | +| 1,2,3,4,5,6,7,8 | +----------------------------------------------------+ ``` ---- - #### Aliases -- list_union +- list_to_string +- array_join +- list_join -### `array_except` +### `array_union` -Returns an array of the elements that appear in the first array but not in the second. +Returns distinct values from the array after removing duplicates. ``` -array_except(array1, array2) +array_distinct(array) ``` #### Arguments -- **array1**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **array2**: Array expression. - Can be a constant, column, or function, and any combination of array operators. +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. #### Example -``` -> select array_except([1, 2, 3, 4], [5, 6, 3, 4]); -+----------------------------------------------------+ -| array_except([1, 2, 3, 4], [5, 6, 3, 4]); | -+----------------------------------------------------+ -| [1, 2] | -+----------------------------------------------------+ -> select array_except([1, 2, 3, 4], [3, 4, 5, 6]); -+----------------------------------------------------+ -| array_except([1, 2, 3, 4], [3, 4, 5, 6]); | -+----------------------------------------------------+ -| [1, 2] | -+----------------------------------------------------+ +```sql +> select array_distinct([1, 3, 2, 3, 1, 2, 4]); ++---------------------------------+ +| array_distinct(List([1,2,3,4])) | ++---------------------------------+ +| [1, 2, 3, 4] | ++---------------------------------+ ``` ---- - #### Aliases -- list_except +- list_union ### `cardinality` @@ -2972,12 +3505,11 @@ cardinality(array) #### Arguments -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. #### Example -``` +```sql > select cardinality([[1, 2, 3, 4], [5, 6, 7, 8]]); +--------------------------------------+ | cardinality(List([1,2,3,4,5,6,7,8])) | @@ -2996,12 +3528,11 @@ empty(array) #### Arguments -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. #### Example -``` +```sql > select empty([1]); +------------------+ | empty(List([1])) | @@ -3010,10 +3541,38 @@ empty(array) +------------------+ ``` -#### Aliases +#### Aliases + +- array_empty +- list_empty + +### `flatten` + +Converts an array of arrays to a flat array. + +- Applies to any depth of nested arrays +- Does not change arrays that are already flat + +The flattened array contains all the elements from all source arrays. + +``` +flatten(array) +``` + +#### Arguments + +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. -- array_empty, -- list_empty +#### Example + +```sql +> select flatten([[1, 2], [3, 4]]); ++------------------------------+ +| flatten(List([1,2], [3,4])) | ++------------------------------+ +| [1, 2, 3, 4] | ++------------------------------+ +``` ### `generate_series` @@ -3025,13 +3584,13 @@ generate_series(start, stop, step) #### Arguments -- **start**: start of the range -- **end**: end of the range (included) -- **step**: increase by step (can not be 0) +- **start**: start of the series. Ints, timestamps, dates or string types that can be coerced to Date32 are supported. +- **end**: end of the series (included). Type must be the same as start. +- **step**: increase by step (can not be 0). Steps less than a day are supported only for timestamp ranges. #### Example -``` +```sql > select generate_series(1,3); +------------------------------------+ | generate_series(Int64(1),Int64(3)) | @@ -3040,6 +3599,10 @@ generate_series(start, stop, step) +------------------------------------+ ``` +### `list_any_value` + +_Alias of [array_any_value](#array_any_value)._ + ### `list_append` _Alias of [array_append](#array_append)._ @@ -3052,13 +3615,21 @@ _Alias of [array_concat](#array_concat)._ _Alias of [array_concat](#array_concat)._ +### `list_contains` + +_Alias of [array_has](#array_has)._ + ### `list_dims` _Alias of [array_dims](#array_dims)._ +### `list_distance` + +_Alias of [array_distance](#array_distance)._ + ### `list_distinct` -_Alias of [array_dims](#array_distinct)._ +_Alias of [array_distinct](#array_distinct)._ ### `list_element` @@ -3070,7 +3641,7 @@ _Alias of [empty](#empty)._ ### `list_except` -_Alias of [array_element](#array_except)._ +_Alias of [array_except](#array_except)._ ### `list_extract` @@ -3094,7 +3665,7 @@ _Alias of [array_position](#array_position)._ ### `list_intersect` -_Alias of [array_position](#array_intersect)._ +_Alias of [array_intersect](#array_intersect)._ ### `list_join` @@ -3108,10 +3679,6 @@ _Alias of [array_length](#array_length)._ _Alias of [array_ndims](#array_ndims)._ -### `list_prepend` - -_Alias of [array_prepend](#array_prepend)._ - ### `list_pop_back` _Alias of [array_pop_back](#array_pop_back)._ @@ -3128,6 +3695,10 @@ _Alias of [array_position](#array_position)._ _Alias of [array_positions](#array_positions)._ +### `list_prepend` + +_Alias of [array_prepend](#array_prepend)._ + ### `list_push_back` _Alias of [array_append](#array_append)._ @@ -3136,37 +3707,37 @@ _Alias of [array_append](#array_append)._ _Alias of [array_prepend](#array_prepend)._ -### `list_repeat` - -_Alias of [array_repeat](#array_repeat)._ - -### `list_resize` - -_Alias of [array_resize](#array_resize)._ - ### `list_remove` _Alias of [array_remove](#array_remove)._ +### `list_remove_all` + +_Alias of [array_remove_all](#array_remove_all)._ + ### `list_remove_n` _Alias of [array_remove_n](#array_remove_n)._ -### `list_remove_all` +### `list_repeat` -_Alias of [array_remove_all](#array_remove_all)._ +_Alias of [array_repeat](#array_repeat)._ ### `list_replace` _Alias of [array_replace](#array_replace)._ +### `list_replace_all` + +_Alias of [array_replace_all](#array_replace_all)._ + ### `list_replace_n` _Alias of [array_replace_n](#array_replace_n)._ -### `list_replace_all` +### `list_resize` -_Alias of [array_replace_all](#array_replace_all)._ +_Alias of [array_resize](#array_resize)._ ### `list_reverse` @@ -3190,25 +3761,19 @@ _Alias of [array_union](#array_union)._ ### `make_array` -Returns an Arrow array using the specified input expressions. +Returns an array using the specified input expressions. ``` make_array(expression1[, ..., expression_n]) ``` -### `array_empty` - -_Alias of [empty](#empty)._ - #### Arguments -- **expression_n**: Expression to include in the output array. - Can be a constant, column, or function, and any combination of arithmetic or - string operators. +- **expression_n**: Expression to include in the output array. Can be a constant, column, or function, and any combination of arithmetic or string operators. #### Example -``` +```sql > select make_array(1, 2, 3, 4, 5); +----------------------------------------------------------+ | make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)) | @@ -3225,94 +3790,114 @@ _Alias of [empty](#empty)._ _Alias of [make_array](#make_array)._ -### `string_to_array` +### `range` -Splits a string in to an array of substrings based on a delimiter. Any substrings matching the optional `null_str` argument are replaced with NULL. -`SELECT string_to_array('abc##def', '##')` or `SELECT string_to_array('abc def', ' ', 'def')` +Returns an Arrow array between start and stop with step. The range start..end contains all values with start <= x < end. It is empty if start >= end. Step cannot be 0. ``` -starts_with(str, delimiter[, null_str]) +range(start, stop, step) ``` #### Arguments -- **str**: String expression to split. -- **delimiter**: Delimiter string to split on. -- **null_str**: Substring values to be replaced with `NULL` +- **start**: Start of the range. Ints, timestamps, dates or string types that can be coerced to Date32 are supported. +- **end**: End of the range (not included). Type must be the same as start. +- **step**: Increase by step (cannot be 0). Steps less than a day are supported only for timestamp ranges. -#### Aliases - -- string_to_list - -### `string_to_list` +#### Example -_Alias of [string_to_array](#string_to_array)._ +```sql +> select range(2, 10, 3); ++-----------------------------------+ +| range(Int64(2),Int64(10),Int64(3))| ++-----------------------------------+ +| [2, 5, 8] | ++-----------------------------------+ -### `trim_array` +> select range(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH); ++--------------------------------------------------------------+ +| range(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH) | ++--------------------------------------------------------------+ +| [1992-09-01, 1992-10-01, 1992-11-01, 1992-12-01, 1993-01-01, 1993-02-01] | ++--------------------------------------------------------------+ +``` -Removes the last n elements from the array. +### `string_to_array` -DEPRECATED: use `array_slice` instead! +Converts each element to its text representation. ``` -trim_array(array, n) +array_to_string(array, delimiter) ``` #### Arguments -- **array**: Array expression. - Can be a constant, column, or function, and any combination of array operators. -- **n**: Element to trim the array. +- **array**: Array expression. Can be a constant, column, or function, and any combination of array operators. +- **delimiter**: Array element separator. -### `range` +#### Example -Returns an Arrow array between start and stop with step. `SELECT range(2, 10, 3) -> [2, 5, 8]` or `SELECT range(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH);` +```sql +> select array_to_string([[1, 2, 3, 4], [5, 6, 7, 8]], ','); ++----------------------------------------------------+ +| array_to_string(List([1,2,3,4,5,6,7,8]),Utf8(",")) | ++----------------------------------------------------+ +| 1,2,3,4,5,6,7,8 | ++----------------------------------------------------+ +``` -The range start..end contains all values with start <= x < end. It is empty if start >= end. +#### Aliases -Step can not be 0 (then the range will be nonsense.). +- string_to_list -Note that when the required range is a number, it accepts (stop), (start, stop), and (start, stop, step) as parameters, but when the required range is a date, it must be 3 non-NULL parameters. -For example, +### `string_to_list` -``` -SELECT range(3); -SELECT range(1,5); -SELECT range(1,5,1); -``` +_Alias of [string_to_array](#string_to_array)._ -are allowed in number ranges +## Struct Functions -but in date ranges, only +- [named_struct](#named_struct) +- [row](#row) +- [struct](#struct) -``` -SELECT range(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH); -``` +### `named_struct` -is allowed, and +Returns an Arrow struct using the specified name and input expressions pairs. ``` -SELECT range(DATE '1992-09-01', DATE '1993-03-01', NULL); -SELECT range(NULL, DATE '1993-03-01', INTERVAL '1' MONTH); -SELECT range(DATE '1992-09-01', NULL, INTERVAL '1' MONTH); +named_struct(expression1_name, expression1_input[, ..., expression_n_name, expression_n_input]) ``` -are not allowed - #### Arguments -- **start**: start of the range -- **end**: end of the range (not included) -- **step**: increase by step (can not be 0) +- **expression_n_name**: Name of the column field. Must be a constant string. +- **expression_n_input**: Expression to include in the output struct. Can be a constant, column, or function, and any combination of arithmetic or string operators. -#### Aliases +#### Example -- generate_series +For example, this query converts two columns `a` and `b` to a single column with +a struct type of fields `field_a` and `field_b`: -## Struct Functions +```sql +> select * from t; ++---+---+ +| a | b | ++---+---+ +| 1 | 2 | +| 3 | 4 | ++---+---+ +> select named_struct('field_a', a, 'field_b', b) from t; ++-------------------------------------------------------+ +| named_struct(Utf8("field_a"),t.a,Utf8("field_b"),t.b) | ++-------------------------------------------------------+ +| {field_a: 1, field_b: 2} | +| {field_a: 3, field_b: 4} | ++-------------------------------------------------------+ +``` -- [struct](#struct) -- [named_struct](#named_struct) +### `row` + +_Alias of [struct](#struct)._ ### `struct` @@ -3324,11 +3909,17 @@ For example: `c0`, `c1`, `c2`, etc. struct(expression1[, ..., expression_n]) ``` +#### Arguments + +- **expression1, expression_n**: Expression to include in the output struct. Can be a constant, column, or function, any combination of arithmetic or string operators. + +#### Example + For example, this query converts two columns `a` and `b` to a single column with a struct type of fields `field_a` and `c1`: -``` -select * from t; +```sql +> select * from t; +---+---+ | a | b | +---+---+ @@ -3355,48 +3946,153 @@ select struct(a as field_a, b) from t; +--------------------------------------------------+ ``` +#### Aliases + +- row + +## Map Functions + +- [element_at](#element_at) +- [map](#map) +- [map_extract](#map_extract) +- [map_keys](#map_keys) +- [map_values](#map_values) + +### `element_at` + +_Alias of [map_extract](#map_extract)._ + +### `map` + +Returns an Arrow map with the specified key-value pairs. + +The `make_map` function creates a map from two lists: one for keys and one for values. Each key must be unique and non-null. + +``` +map(key, value) +map(key: value) +make_map(['key1', 'key2'], ['value1', 'value2']) +``` + #### Arguments -- **expression_n**: Expression to include in the output struct. - Can be a constant, column, or function, any combination of arithmetic or - string operators, or a named expression of previous listed . +- **key**: For `map`: Expression to be used for key. Can be a constant, column, function, or any combination of arithmetic or string operators. + For `make_map`: The list of keys to be used in the map. Each key must be unique and non-null. +- **value**: For `map`: Expression to be used for value. Can be a constant, column, function, or any combination of arithmetic or string operators. + For `make_map`: The list of values to be mapped to the corresponding keys. -### `named_struct` +#### Example -Returns an Arrow struct using the specified name and input expressions pairs. +````sql + -- Using map function + SELECT MAP('type', 'test'); + ---- + {type: test} + + SELECT MAP(['POST', 'HEAD', 'PATCH'], [41, 33, null]); + ---- + {POST: 41, HEAD: 33, PATCH: } + + SELECT MAP([[1,2], [3,4]], ['a', 'b']); + ---- + {[1, 2]: a, [3, 4]: b} + + SELECT MAP { 'a': 1, 'b': 2 }; + ---- + {a: 1, b: 2} + + -- Using make_map function + SELECT MAKE_MAP(['POST', 'HEAD'], [41, 33]); + ---- + {POST: 41, HEAD: 33} + + SELECT MAKE_MAP(['key1', 'key2'], ['value1', null]); + ---- + {key1: value1, key2: } + ``` + + +### `map_extract` + +Returns a list containing the value for the given key or an empty list if the key is not present in the map. + +```` + +map_extract(map, key) + +```` +#### Arguments + +- **map**: Map expression. Can be a constant, column, or function, and any combination of map operators. +- **key**: Key to extract from the map. Can be a constant, column, or function, any combination of arithmetic or string operators, or a named expression of the previously listed. + +#### Example + +```sql +SELECT map_extract(MAP {'a': 1, 'b': NULL, 'c': 3}, 'a'); +---- +[1] + +SELECT map_extract(MAP {1: 'one', 2: 'two'}, 2); +---- +['two'] + +SELECT map_extract(MAP {'x': 10, 'y': NULL, 'z': 30}, 'y'); +---- +[] +```` + +#### Aliases + +- element_at + +### `map_keys` + +Returns a list of all keys in the map. ``` -named_struct(expression1_name, expression1_input[, ..., expression_n_name, expression_n_input]) +map_keys(map) ``` -For example, this query converts two columns `a` and `b` to a single column with -a struct type of fields `field_a` and `field_b`: +#### Arguments + +- **map**: Map expression. Can be a constant, column, or function, and any combination of map operators. +#### Example + +```sql +SELECT map_keys(MAP {'a': 1, 'b': NULL, 'c': 3}); +---- +[a, b, c] + +SELECT map_keys(map([100, 5], [42, 43])); +---- +[100, 5] ``` -select * from t; -+---+---+ -| a | b | -+---+---+ -| 1 | 2 | -| 3 | 4 | -+---+---+ -select named_struct('field_a', a, 'field_b', b) from t; -+-------------------------------------------------------+ -| named_struct(Utf8("field_a"),t.a,Utf8("field_b"),t.b) | -+-------------------------------------------------------+ -| {field_a: 1, field_b: 2} | -| {field_a: 3, field_b: 4} | -+-------------------------------------------------------+ +### `map_values` + +Returns a list of all values in the map. + +``` +map_values(map) ``` #### Arguments -- **expression_n_name**: Name of the column field. - Must be a constant string. -- **expression_n_input**: Expression to include in the output struct. - Can be a constant, column, or function, and any combination of arithmetic or - string operators. +- **map**: Map expression. Can be a constant, column, or function, and any combination of map operators. + +#### Example + +```sql +SELECT map_values(MAP {'a': 1, 'b': NULL, 'c': 3}); +---- +[1, , 3] + +SELECT map_values(map([100, 5], [42, 43])); +---- +[42, 43] +``` ## Hashing Functions @@ -3417,19 +4113,27 @@ digest(expression, algorithm) #### Arguments -- **expression**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. -- **algorithm**: String expression specifying algorithm to use. - Must be one of: +- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. +- **algorithm**: String expression specifying algorithm to use. Must be one of: +- md5 +- sha224 +- sha256 +- sha384 +- sha512 +- blake2s +- blake2b +- blake3 + +#### Example - - md5 - - sha224 - - sha256 - - sha384 - - sha512 - - blake2s - - blake2b - - blake3 +```sql +> select digest('foo', 'sha256'); ++------------------------------------------+ +| digest(Utf8("foo"), Utf8("sha256")) | ++------------------------------------------+ +| | ++------------------------------------------+ +``` ### `md5` @@ -3441,8 +4145,18 @@ md5(expression) #### Arguments -- **expression**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. +- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select md5('foo'); ++-------------------------------------+ +| md5(Utf8("foo")) | ++-------------------------------------+ +| | ++-------------------------------------+ +``` ### `sha224` @@ -3454,8 +4168,18 @@ sha224(expression) #### Arguments -- **expression**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. +- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select sha224('foo'); ++------------------------------------------+ +| sha224(Utf8("foo")) | ++------------------------------------------+ +| | ++------------------------------------------+ +``` ### `sha256` @@ -3467,8 +4191,18 @@ sha256(expression) #### Arguments -- **expression**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. +- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select sha256('foo'); ++--------------------------------------+ +| sha256(Utf8("foo")) | ++--------------------------------------+ +| | ++--------------------------------------+ +``` ### `sha384` @@ -3480,8 +4214,18 @@ sha384(expression) #### Arguments -- **expression**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. +- **expression**: String expression to operate on. Can be a constant, column, or function, and any combination of operators. + +#### Example + +```sql +> select sha384('foo'); ++-----------------------------------------+ +| sha384(Utf8("foo")) | ++-----------------------------------------+ +| | ++-----------------------------------------+ +``` ### `sha512` @@ -3493,17 +4237,29 @@ sha512(expression) #### Arguments -- **expression**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. +- **expression**: String + +#### Example + +```sql +> select sha512('foo'); ++-------------------------------------------+ +| sha512(Utf8("foo")) | ++-------------------------------------------+ +| | ++-------------------------------------------+ +``` ## Other Functions - [arrow_cast](#arrow_cast) - [arrow_typeof](#arrow_typeof) +- [get_field](#get_field) +- [version](#version) ### `arrow_cast` -Casts a value to a specific Arrow data type: +Casts a value to a specific Arrow data type. ``` arrow_cast(expression, datatype) @@ -3511,15 +4267,12 @@ arrow_cast(expression, datatype) #### Arguments -- **expression**: Expression to cast. - Can be a constant, column, or function, and any combination of arithmetic or - string operators. -- **datatype**: [Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) name - to cast to, as a string. The format is the same as that returned by [`arrow_typeof`] +- **expression**: Expression to cast. The expression can be a constant, column, or function, and any combination of operators. +- **datatype**: [Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) name to cast to, as a string. The format is the same as that returned by [`arrow_typeof`] #### Example -``` +```sql > select arrow_cast(-5, 'Int8') as a, arrow_cast('foo', 'Dictionary(Int32, Utf8)') as b, arrow_cast('bar', 'LargeUtf8') as c, @@ -3530,12 +4283,11 @@ arrow_cast(expression, datatype) +----+-----+-----+---------------------------+ | -5 | foo | bar | 2023-01-02T12:53:02+08:00 | +----+-----+-----+---------------------------+ -1 row in set. Query took 0.001 seconds. ``` ### `arrow_typeof` -Returns the name of the underlying [Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) of the expression: +Returns the name of the underlying [Arrow data type](https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html) of the expression. ``` arrow_typeof(expression) @@ -3543,18 +4295,77 @@ arrow_typeof(expression) #### Arguments -- **expression**: Expression to evaluate. - Can be a constant, column, or function, and any combination of arithmetic or - string operators. +- **expression**: Expression to evaluate. The expression can be a constant, column, or function, and any combination of operators. #### Example -``` +```sql > select arrow_typeof('foo'), arrow_typeof(1); +---------------------------+------------------------+ | arrow_typeof(Utf8("foo")) | arrow_typeof(Int64(1)) | +---------------------------+------------------------+ | Utf8 | Int64 | +---------------------------+------------------------+ -1 row in set. Query took 0.001 seconds. +``` + +### `get_field` + +Returns a field within a map or a struct with the given key. +Note: most users invoke `get_field` indirectly via field access +syntax such as `my_struct_col['field_name']` which results in a call to +`get_field(my_struct_col, 'field_name')`. + +``` +get_field(expression1, expression2) +``` + +#### Arguments + +- **expression1**: The map or struct to retrieve a field for. +- **expression2**: The field name in the map or struct to retrieve data for. Must evaluate to a string. + +#### Example + +```sql +> create table t (idx varchar, v varchar) as values ('data','fusion'), ('apache', 'arrow'); +> select struct(idx, v) from t as c; ++-------------------------+ +| struct(c.idx,c.v) | ++-------------------------+ +| {c0: data, c1: fusion} | +| {c0: apache, c1: arrow} | ++-------------------------+ +> select get_field((select struct(idx, v) from t), 'c0'); ++-----------------------+ +| struct(t.idx,t.v)[c0] | ++-----------------------+ +| data | +| apache | ++-----------------------+ +> select get_field((select struct(idx, v) from t), 'c1'); ++-----------------------+ +| struct(t.idx,t.v)[c1] | ++-----------------------+ +| fusion | +| arrow | ++-----------------------+ +``` + +### `version` + +Returns the version of DataFusion. + +``` +version() +``` + +#### Example + +```sql +> select version(); ++--------------------------------------------+ +| version() | ++--------------------------------------------+ +| Apache DataFusion 42.0.0, aarch64 on macos | ++--------------------------------------------+ ``` diff --git a/docs/source/user-guide/sql/special_functions.md b/docs/source/user-guide/sql/special_functions.md new file mode 100644 index 000000000000..7c9efbb66218 --- /dev/null +++ b/docs/source/user-guide/sql/special_functions.md @@ -0,0 +1,100 @@ + + +# Special Functions + +## Expansion Functions + +- [unnest](#unnest) +- [unnest(struct)](#unnest-struct) + +### `unnest` + +Expands an array or map into rows. + +#### Arguments + +- **array**: Array expression to unnest. + Can be a constant, column, or function, and any combination of array operators. + +#### Examples + +```sql +> select unnest(make_array(1, 2, 3, 4, 5)) as unnested; ++----------+ +| unnested | ++----------+ +| 1 | +| 2 | +| 3 | +| 4 | +| 5 | ++----------+ +``` + +```sql +> select unnest(range(0, 10)) as unnested_range; ++----------------+ +| unnested_range | ++----------------+ +| 0 | +| 1 | +| 2 | +| 3 | +| 4 | +| 5 | +| 6 | +| 7 | +| 8 | +| 9 | ++----------------+ +``` + +### `unnest (struct)` + +Expand a struct fields into individual columns. + +#### Arguments + +- **struct**: Object expression to unnest. + Can be a constant, column, or function, and any combination of object operators. + +#### Examples + +```sql +> create table foo as values ({a: 5, b: 'a string'}), ({a:6, b: 'another string'}); + +> create view foov as select column1 as struct_column from foo; + +> select * from foov; ++---------------------------+ +| struct_column | ++---------------------------+ +| {a: 5, b: a string} | +| {a: 6, b: another string} | ++---------------------------+ + +> select unnest(struct_column) from foov; ++------------------------------------------+------------------------------------------+ +| unnest_placeholder(foov.struct_column).a | unnest_placeholder(foov.struct_column).b | ++------------------------------------------+------------------------------------------+ +| 5 | a string | +| 6 | another string | ++------------------------------------------+------------------------------------------+ +``` diff --git a/docs/source/user-guide/sql/subqueries.md b/docs/source/user-guide/sql/subqueries.md index 6055b0fc5b78..ee75a6a1575c 100644 --- a/docs/source/user-guide/sql/subqueries.md +++ b/docs/source/user-guide/sql/subqueries.md @@ -19,80 +19,391 @@ # Subqueries -DataFusion supports `EXISTS`, `NOT EXISTS`, `IN`, `NOT IN` and Scalar Subqueries. +Subqueries (also known as inner queries or nested queries) are queries within +a query. +Subqueries can be used in `SELECT`, `FROM`, `WHERE`, and `HAVING` clauses. -The examples below are based on the following table. +The examples below are based on the following tables. ```sql -select * from x; +SELECT * FROM x; + +----------+----------+ | column_1 | column_2 | +----------+----------+ | 1 | 2 | +----------+----------+ +| 2 | 4 | ++----------+----------+ +``` + +```sql +SELECT * FROM y; + ++--------+--------+ +| number | string | ++--------+--------+ +| 1 | one | ++--------+--------+ +| 2 | two | ++--------+--------+ +| 3 | three | ++--------+--------+ +| 4 | four | ++--------+--------+ ``` -## EXISTS +## Subquery operators -The `EXISTS` syntax can be used to find all rows in a relation where a correlated subquery produces one or more matches -for that row. Only correlated subqueries are supported. +- [[ NOT ] EXISTS](#-not--exists) +- [[ NOT ] IN](#-not--in) + +### [ NOT ] EXISTS + +The `EXISTS` operator returns all rows where a +_[correlated subquery](#correlated-subqueries)_ produces one or more matches for +that row. `NOT EXISTS` returns all rows where a _correlated subquery_ produces +zero matches for that row. Only _correlated subqueries_ are supported. ```sql -select * from x y where exists (select * from x where x.column_1 = y.column_1); +[NOT] EXISTS (subquery) +``` + +### [ NOT ] IN + +The `IN` operator returns all rows where a given expression’s value can be found +in the results of a _[correlated subquery](#correlated-subqueries)_. +`NOT IN` returns all rows where a given expression’s value cannot be found in +the results of a subquery or list of values. + +```sql +expression [NOT] IN (subquery|list-literal) +``` + +#### Examples + +```sql +SELECT * FROM x WHERE column_1 IN (1,3); + +----------+----------+ | column_1 | column_2 | +----------+----------+ | 1 | 2 | +----------+----------+ -1 row in set. ``` -## NOT EXISTS +```sql +SELECT * FROM x WHERE column_1 NOT IN (1,3); + ++----------+----------+ +| column_1 | column_2 | ++----------+----------+ +| 2 | 4 | ++----------+----------+ +``` + +## SELECT clause subqueries + +`SELECT` clause subqueries use values returned from the inner query as part +of the outer query's `SELECT` list. +The `SELECT` clause only supports [scalar subqueries](#scalar-subqueries) that +return a single value per execution of the inner query. +The returned value can be unique per row. + +```sql +SELECT [expression1[, expression2, ..., expressionN],] () +``` + +**Note**: `SELECT` clause subqueries can be used as an alternative to `JOIN` +operations. -The `NOT EXISTS` syntax can be used to find all rows in a relation where a correlated subquery produces zero matches -for that row. Only correlated subqueries are supported. +### Example ```sql -select * from x y where not exists (select * from x where x.column_1 = y.column_1); -0 rows in set. +SELECT + column_1, + ( + SELECT + first_value(string) + FROM + y + WHERE + number = x.column_1 + ) AS "numeric string" +FROM + x; + ++----------+----------------+ +| column_1 | numeric string | ++----------+----------------+ +| 1 | one | +| 2 | two | ++----------+----------------+ ``` -## IN +## FROM clause subqueries -The `IN` syntax can be used to find all rows in a relation where a given expression's value can be found in the -results of a correlated subquery. +`FROM` clause subqueries return a set of results that is then queried and +operated on by the outer query. ```sql -select * from x where column_1 in (select column_1 from x); +SELECT expression1[, expression2, ..., expressionN] FROM () +``` + +### Example + +The following query returns the average of maximum values per room. +The inner query returns the maximum value for each field from each room. +The outer query uses the results of the inner query and returns the average +maximum value for each field. + +```sql +SELECT + column_2 +FROM + ( + SELECT + * + FROM + x + WHERE + column_1 > 1 + ); + ++----------+ +| column_2 | ++----------+ +| 4 | ++----------+ +``` + +## WHERE clause subqueries + +`WHERE` clause subqueries compare an expression to the result of the subquery +and return _true_ or _false_. +Rows that evaluate to _false_ or NULL are filtered from results. +The `WHERE` clause supports correlated and non-correlated subqueries +as well as scalar and non-scalar subqueries (depending on the the operator used +in the predicate expression). + +```sql +SELECT + expression1[, expression2, ..., expressionN] +FROM + +WHERE + expression operator () +``` + +**Note:** `WHERE` clause subqueries can be used as an alternative to `JOIN` +operations. + +### Examples + +#### `WHERE` clause with scalar subquery + +The following query returns all rows with `column_2` values above the average +of all `number` values in `y`. + +```sql +SELECT + * +FROM + x +WHERE + column_2 > ( + SELECT + AVG(number) + FROM + y + ); + +----------+----------+ | column_1 | column_2 | +----------+----------+ -| 1 | 2 | +| 2 | 4 | +----------+----------+ -1 row in set. ``` -## NOT IN +#### `WHERE` clause with non-scalar subquery -The `NOT IN` syntax can be used to find all rows in a relation where a given expression's value can not be found in the -results of a correlated subquery. +Non-scalar subqueries must use the `[NOT] IN` or `[NOT] EXISTS` operators and +can only return a single column. +The values in the returned column are evaluated as a list. + +The following query returns all rows with `column_2` values in table `x` that +are in the list of numbers with string lengths greater than three from table +`y`. ```sql -select * from x where column_1 not in (select column_1 from x); -0 rows in set. +SELECT + * +FROM + x +WHERE + column_2 IN ( + SELECT + number + FROM + y + WHERE + length(string) > 3 + ); + ++----------+----------+ +| column_1 | column_2 | ++----------+----------+ +| 2 | 4 | ++----------+----------+ ``` -## Scalar Subquery +### `WHERE` clause with correlated subquery -A scalar subquery can be used to produce a single value that can be used in many different contexts in a query. Here -is an example of a filter using a scalar subquery. Only correlated subqueries are supported. +The following query returns rows with `column_2` values from table `x` greater +than the average `string` value length from table `y`. +The subquery in the `WHERE` clause uses the `column_1` value from the outer +query to return the average `string` value length for that specific value. ```sql -select * from x y where column_1 < (select sum(column_2) from x where x.column_1 = y.column_1); +SELECT + * +FROM + x +WHERE + column_2 > ( + SELECT + AVG(length(string)) + FROM + y + WHERE + number = x.column_1 + ); + +----------+----------+ | column_1 | column_2 | +----------+----------+ -| 1 | 2 | +| 2 | 4 | +----------+----------+ -1 row in set. ``` + +## HAVING clause subqueries + +`HAVING` clause subqueries compare an expression that uses aggregate values +returned by aggregate functions in the `SELECT` clause to the result of the +subquery and return _true_ or _false_. +Rows that evaluate to _false_ are filtered from results. +The `HAVING` clause supports correlated and non-correlated subqueries +as well as scalar and non-scalar subqueries (depending on the the operator used +in the predicate expression). + +```sql +SELECT + aggregate_expression1[, aggregate_expression2, ..., aggregate_expressionN] +FROM + +WHERE + +GROUP BY + column_expression1[, column_expression2, ..., column_expressionN] +HAVING + expression operator () +``` + +### Examples + +The following query calculates the averages of even and odd numbers in table `y` +and returns the averages that are equal to the maximum value of `column_1` +in table `x`. + +#### `HAVING` clause with a scalar subquery + +```sql +SELECT + AVG(number) AS avg, + (number % 2 = 0) AS even +FROM + y +GROUP BY + even +HAVING + avg = ( + SELECT + MAX(column_1) + FROM + x + ); + ++-------+--------+ +| avg | even | ++-------+--------+ +| 2 | false | ++-------+--------+ +``` + +#### `HAVING` clause with a non-scalar subquery + +Non-scalar subqueries must use the `[NOT] IN` or `[NOT] EXISTS` operators and +can only return a single column. +The values in the returned column are evaluated as a list. + +The following query calculates the averages of even and odd numbers in table `y` +and returns the averages that are in `column_1` of table `x`. + +```sql +SELECT + AVG(number) AS avg, + (number % 2 = 0) AS even +FROM + y +GROUP BY + even +HAVING + avg IN ( + SELECT + column_1 + FROM + x + ); + ++-------+--------+ +| avg | even | ++-------+--------+ +| 2 | false | ++-------+--------+ +``` + +## Subquery categories + +Subqueries can be categorized as one or more of the following based on the +behavior of the subquery: + +- [correlated](#correlated-subqueries) or + [non-correlated](#non-correlated-subqueries) +- [scalar](#scalar-subqueries) or [non-scalar](#non-scalar-subqueries) + +### Correlated subqueries + +In a **correlated** subquery, the inner query depends on the values of the +current row being processed. + +**Note:** DataFusion internally rewrites correlated subqueries into JOINs to +improve performance. In general correlated subqueries are **less performant** +than non-correlated subqueries. + +### Non-correlated subqueries + +In a **non-correlated** subquery, the inner query _doesn't_ depend on the outer +query and executes independently. +The inner query executes first, and then passes the results to the outer query. + +### Scalar subqueries + +A **scalar** subquery returns a single value (one column of one row). +If no rows are returned, the subquery returns NULL. + +### Non-scalar subqueries + +A **non-scalar** subquery returns 0, 1, or multiple rows, each of which may +contain 1 or multiple columns. For each column, if there is no value to return, +the subquery returns NULL. If no rows qualify to be returned, the subquery +returns 0 rows. diff --git a/docs/source/user-guide/sql/window_functions.md b/docs/source/user-guide/sql/window_functions.md index 4d9d2557249f..8216a3b258b8 100644 --- a/docs/source/user-guide/sql/window_functions.md +++ b/docs/source/user-guide/sql/window_functions.md @@ -19,7 +19,15 @@ # Window Functions -A _window function_ performs a calculation across a set of table rows that are somehow related to the current row. This is comparable to the type of calculation that can be done with an aggregate function. However, window functions do not cause rows to become grouped into a single output row like non-window aggregate calls would. Instead, the rows retain their separate identities. Behind the scenes, the window function is able to access more than just the current row of the query result +A _window function_ performs a calculation across a set of table rows that are somehow related to the current row. + +Note: this documentation is in the process of being migrated to be [automatically created from the codebase]. +Please see the [Window Functions (new)](window_functions_new.md) page for +the rest of the documentation. + +[automatically created from the codebase]: https://github.com/apache/datafusion/issues/12740 + +Window functions are comparable to the type of calculation that can be done with an aggregate function. However, window functions do not cause rows to become grouped into a single output row like non-window aggregate calls would. Instead, the rows retain their separate identities. Behind the scenes, the window function is able to access more than just the current row of the query result Here is an example that shows how to compare each employee's salary with the average salary in his or her department: @@ -138,103 +146,12 @@ RANGE and GROUPS modes require an ORDER BY clause (with RANGE the ORDER BY must All [aggregate functions](aggregate_functions.md) can be used as window functions. -## Ranking functions - -- [row_number](#row_number) -- [rank](#rank) -- [dense_rank](#dense_rank) -- [ntile](#ntile) - -### `row_number` - -Number of the current row within its partition, counting from 1. - -```sql -row_number() -``` - -### `rank` - -Rank of the current row with gaps; same as row_number of its first peer. - -```sql -rank() -``` - -### `dense_rank` - -Rank of the current row without gaps; this function counts peer groups. - -```sql -dense_rank() -``` - -### `ntile` - -Integer ranging from 1 to the argument value, dividing the partition as equally as possible. - -```sql -ntile(expression) -``` - -#### Arguments - -- **expression**: An integer describing the number groups the partition should be split into - ## Analytical functions -- [cume_dist](#cume_dist) -- [percent_rank](#percent_rank) -- [lag](#lag) -- [lead](#lead) - [first_value](#first_value) - [last_value](#last_value) - [nth_value](#nth_value) -### `cume_dist` - -Relative rank of the current row: (number of rows preceding or peer with current row) / (total rows). - -```sql -cume_dist() -``` - -### `percent_rank` - -Relative rank of the current row: (rank - 1) / (total rows - 1). - -```sql -percent_rank() -``` - -### `lag` - -Returns value evaluated at the row that is offset rows before the current row within the partition; if there is no such row, instead return default (which must be of the same type as value). Both offset and default are evaluated with respect to the current row. If omitted, offset defaults to 1 and default to null. - -```sql -lag(expression, offset, default) -``` - -#### Arguments - -- **expression**: Expression to operate on -- **offset**: Integer. Specifies how many rows back the value of _expression_ should be retrieved. Defaults to 1. -- **default**: The default value if the offset is not within the partition. Must be of the same type as _expression_. - -### `lead` - -Returns value evaluated at the row that is offset rows after the current row within the partition; if there is no such row, instead return default (which must be of the same type as value). Both offset and default are evaluated with respect to the current row. If omitted, offset defaults to 1 and default to null. - -```sql -lead(expression, offset, default) -``` - -#### Arguments - -- **expression**: Expression to operate on -- **offset**: Integer. Specifies how many rows forward the value of _expression_ should be retrieved. Defaults to 1. -- **default**: The default value if the offset is not within the partition. Must be of the same type as _expression_. - ### `first_value` Returns value evaluated at the row that is the first row of the window frame. diff --git a/docs/source/user-guide/sql/window_functions_new.md b/docs/source/user-guide/sql/window_functions_new.md new file mode 100644 index 000000000000..ae3edb832fcb --- /dev/null +++ b/docs/source/user-guide/sql/window_functions_new.md @@ -0,0 +1,250 @@ + + + + +# Window Functions (NEW) + +Note: this documentation is in the process of being migrated to be [automatically created from the codebase]. +Please see the [Window Functions (Old)](window_functions.md) page for +the rest of the documentation. + +[automatically created from the codebase]: https://github.com/apache/datafusion/issues/12740 + +A _window function_ performs a calculation across a set of table rows that are somehow related to the current row. +This is comparable to the type of calculation that can be done with an aggregate function. +However, window functions do not cause rows to become grouped into a single output row like non-window aggregate calls would. +Instead, the rows retain their separate identities. Behind the scenes, the window function is able to access more than just the current row of the query result + +Here is an example that shows how to compare each employee's salary with the average salary in his or her department: + +```sql +SELECT depname, empno, salary, avg(salary) OVER (PARTITION BY depname) FROM empsalary; + ++-----------+-------+--------+-------------------+ +| depname | empno | salary | avg | ++-----------+-------+--------+-------------------+ +| personnel | 2 | 3900 | 3700.0 | +| personnel | 5 | 3500 | 3700.0 | +| develop | 8 | 6000 | 5020.0 | +| develop | 10 | 5200 | 5020.0 | +| develop | 11 | 5200 | 5020.0 | +| develop | 9 | 4500 | 5020.0 | +| develop | 7 | 4200 | 5020.0 | +| sales | 1 | 5000 | 4866.666666666667 | +| sales | 4 | 4800 | 4866.666666666667 | +| sales | 3 | 4800 | 4866.666666666667 | ++-----------+-------+--------+-------------------+ +``` + +A window function call always contains an OVER clause directly following the window function's name and argument(s). This is what syntactically distinguishes it from a normal function or non-window aggregate. The OVER clause determines exactly how the rows of the query are split up for processing by the window function. The PARTITION BY clause within OVER divides the rows into groups, or partitions, that share the same values of the PARTITION BY expression(s). For each row, the window function is computed across the rows that fall into the same partition as the current row. The previous example showed how to count the average of a column per partition. + +You can also control the order in which rows are processed by window functions using ORDER BY within OVER. (The window ORDER BY does not even have to match the order in which the rows are output.) Here is an example: + +```sql +SELECT depname, empno, salary, + rank() OVER (PARTITION BY depname ORDER BY salary DESC) +FROM empsalary; + ++-----------+-------+--------+--------+ +| depname | empno | salary | rank | ++-----------+-------+--------+--------+ +| personnel | 2 | 3900 | 1 | +| develop | 8 | 6000 | 1 | +| develop | 10 | 5200 | 2 | +| develop | 11 | 5200 | 2 | +| develop | 9 | 4500 | 4 | +| develop | 7 | 4200 | 5 | +| sales | 1 | 5000 | 1 | +| sales | 4 | 4800 | 2 | +| personnel | 5 | 3500 | 2 | +| sales | 3 | 4800 | 2 | ++-----------+-------+--------+--------+ +``` + +There is another important concept associated with window functions: for each row, there is a set of rows within its partition called its window frame. Some window functions act only on the rows of the window frame, rather than of the whole partition. Here is an example of using window frames in queries: + +```sql +SELECT depname, empno, salary, + avg(salary) OVER(ORDER BY salary ASC ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) AS avg, + min(salary) OVER(ORDER BY empno ASC ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS cum_min +FROM empsalary +ORDER BY empno ASC; + ++-----------+-------+--------+--------------------+---------+ +| depname | empno | salary | avg | cum_min | ++-----------+-------+--------+--------------------+---------+ +| sales | 1 | 5000 | 5000.0 | 5000 | +| personnel | 2 | 3900 | 3866.6666666666665 | 3900 | +| sales | 3 | 4800 | 4700.0 | 3900 | +| sales | 4 | 4800 | 4866.666666666667 | 3900 | +| personnel | 5 | 3500 | 3700.0 | 3500 | +| develop | 7 | 4200 | 4200.0 | 3500 | +| develop | 8 | 6000 | 5600.0 | 3500 | +| develop | 9 | 4500 | 4500.0 | 3500 | +| develop | 10 | 5200 | 5133.333333333333 | 3500 | +| develop | 11 | 5200 | 5466.666666666667 | 3500 | ++-----------+-------+--------+--------------------+---------+ +``` + +When a query involves multiple window functions, it is possible to write out each one with a separate OVER clause, but this is duplicative and error-prone if the same windowing behavior is wanted for several functions. Instead, each windowing behavior can be named in a WINDOW clause and then referenced in OVER. For example: + +```sql +SELECT sum(salary) OVER w, avg(salary) OVER w +FROM empsalary +WINDOW w AS (PARTITION BY depname ORDER BY salary DESC); +``` + +## Syntax + +The syntax for the OVER-clause is + +``` +function([expr]) + OVER( + [PARTITION BY expr[, …]] + [ORDER BY expr [ ASC | DESC ][, …]] + [ frame_clause ] + ) +``` + +where **frame_clause** is one of: + +``` + { RANGE | ROWS | GROUPS } frame_start + { RANGE | ROWS | GROUPS } BETWEEN frame_start AND frame_end +``` + +and **frame_start** and **frame_end** can be one of + +```sql +UNBOUNDED PRECEDING +offset PRECEDING +CURRENT ROW +offset FOLLOWING +UNBOUNDED FOLLOWING +``` + +where **offset** is an non-negative integer. + +RANGE and GROUPS modes require an ORDER BY clause (with RANGE the ORDER BY must specify exactly one column). + +## Aggregate functions + +All [aggregate functions](aggregate_functions.md) can be used as window functions. + +## Ranking Functions + +- [cume_dist](#cume_dist) +- [dense_rank](#dense_rank) +- [ntile](#ntile) +- [percent_rank](#percent_rank) +- [rank](#rank) +- [row_number](#row_number) + +### `cume_dist` + +Relative rank of the current row: (number of rows preceding or peer with current row) / (total rows). + +``` +cume_dist() +``` + +### `dense_rank` + +Returns the rank of the current row without gaps. This function ranks rows in a dense manner, meaning consecutive ranks are assigned even for identical values. + +``` +dense_rank() +``` + +### `ntile` + +Integer ranging from 1 to the argument value, dividing the partition as equally as possible + +``` +ntile(expression) +``` + +#### Arguments + +- **expression**: An integer describing the number groups the partition should be split into + +### `percent_rank` + +Returns the percentage rank of the current row within its partition. The value ranges from 0 to 1 and is computed as `(rank - 1) / (total_rows - 1)`. + +``` +percent_rank() +``` + +### `rank` + +Returns the rank of the current row within its partition, allowing gaps between ranks. This function provides a ranking similar to `row_number`, but skips ranks for identical values. + +``` +rank() +``` + +### `row_number` + +Number of the current row within its partition, counting from 1. + +``` +row_number() +``` + +## Analytical Functions + +- [lag](#lag) +- [lead](#lead) + +### `lag` + +Returns value evaluated at the row that is offset rows before the current row within the partition; if there is no such row, instead return default (which must be of the same type as value). + +``` +lag(expression, offset, default) +``` + +#### Arguments + +- **expression**: Expression to operate on +- **offset**: Integer. Specifies how many rows back the value of expression should be retrieved. Defaults to 1. +- **default**: The default value if the offset is not within the partition. Must be of the same type as expression. + +### `lead` + +Returns value evaluated at the row that is offset rows after the current row within the partition; if there is no such row, instead return default (which must be of the same type as value). + +``` +lead(expression, offset, default) +``` + +#### Arguments + +- **expression**: Expression to operate on +- **offset**: Integer. Specifies how many rows forward the value of expression should be retrieved. Defaults to 1. +- **default**: The default value if the offset is not within the partition. Must be of the same type as expression. diff --git a/docs/source/user-guide/sql/write_options.md b/docs/source/user-guide/sql/write_options.md index 7631525f13e5..6fb4ef215ff1 100644 --- a/docs/source/user-guide/sql/write_options.md +++ b/docs/source/user-guide/sql/write_options.md @@ -38,11 +38,11 @@ CREATE EXTERNAL TABLE my_table(a bigint, b bigint) STORED AS csv COMPRESSION TYPE gzip - WITH HEADER ROW DELIMITER ';' LOCATION '/test/location/my_csv_table/' OPTIONS( - NULL_VALUE 'NAN' + NULL_VALUE 'NAN', + 'has_header' 'true' ) ``` @@ -70,6 +70,16 @@ In this example, we write the entirety of `source_table` out to a folder of parq ## Available Options +### Execution Specific Options + +The following options are available when executing a `COPY` query. + +| Option | Description | Default Value | +| ----------------------------------- | ---------------------------------------------------------------------------------- | ------------- | +| execution.keep_partition_by_columns | Flag to retain the columns in the output data when using `PARTITIONED BY` queries. | false | + +Note: `execution.keep_partition_by_columns` flag can also be enabled through `ExecutionOptions` within `SessionConfig`. + ### JSON Format Specific Options The following options are available when writing JSON files. Note: If any unsupported option is specified, an error will be raised and the query will fail. diff --git a/docs/src/library_logical_plan.rs b/docs/src/library_logical_plan.rs deleted file mode 100644 index 355003941570..000000000000 --- a/docs/src/library_logical_plan.rs +++ /dev/null @@ -1,78 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use datafusion::error::Result; -use datafusion::logical_expr::builder::LogicalTableSource; -use datafusion::logical_expr::{Filter, LogicalPlan, LogicalPlanBuilder, TableScan}; -use datafusion::prelude::*; -use std::sync::Arc; - -#[test] -fn plan_1() -> Result<()> { - // create a logical table source - let schema = Schema::new(vec![ - Field::new("id", DataType::Int32, true), - Field::new("name", DataType::Utf8, true), - ]); - let table_source = LogicalTableSource::new(SchemaRef::new(schema)); - - // create a TableScan plan - let projection = None; // optional projection - let filters = vec![]; // optional filters to push down - let fetch = None; // optional LIMIT - let table_scan = LogicalPlan::TableScan(TableScan::try_new( - "person", - Arc::new(table_source), - projection, - filters, - fetch, - )?); - - // create a Filter plan that evaluates `id > 500` and wraps the TableScan - let filter_expr = col("id").gt(lit(500)); - let plan = LogicalPlan::Filter(Filter::try_new(filter_expr, Arc::new(table_scan))?); - - // print the plan - println!("{}", plan.display_indent_schema()); - - Ok(()) -} - -#[test] -fn plan_builder_1() -> Result<()> { - // create a logical table source - let schema = Schema::new(vec![ - Field::new("id", DataType::Int32, true), - Field::new("name", DataType::Utf8, true), - ]); - let table_source = LogicalTableSource::new(SchemaRef::new(schema)); - - // optional projection - let projection = None; - - // create a LogicalPlanBuilder for a table scan - let builder = LogicalPlanBuilder::scan("person", Arc::new(table_source), projection)?; - - // perform a filter that evaluates `id > 500`, and build the plan - let plan = builder.filter(col("id").gt(lit(500)))?.build()?; - - // print the plan - println!("{}", plan.display_indent_schema()); - - Ok(()) -} diff --git a/pre-commit.sh b/pre-commit.sh index 09cf431a1409..7ab121919861 100755 --- a/pre-commit.sh +++ b/pre-commit.sh @@ -25,6 +25,8 @@ # This file be run directly: # $ ./pre-commit.sh +set -e + function RED() { echo "\033[0;31m$@\033[0m" } @@ -57,13 +59,7 @@ fi # 1. cargo clippy echo -e "$(GREEN INFO): cargo clippy ..." - -# Cargo clippy always return exit code 0, and `tee` doesn't work. -# So let's just run cargo clippy. -cargo clippy --all-targets --workspace --features avro,pyarrow -- -D warnings -pushd datafusion-cli -cargo clippy --all-targets --all-features -- -D warnings -popd +./ci/scripts/rust_clippy.sh echo -e "$(GREEN INFO): cargo clippy done" # 2. cargo fmt: format with nightly and stable. diff --git a/test-utils/Cargo.toml b/test-utils/Cargo.toml index 325a2cc2fcc4..414fa5569cfe 100644 --- a/test-utils/Cargo.toml +++ b/test-utils/Cargo.toml @@ -29,4 +29,5 @@ workspace = true arrow = { workspace = true } datafusion-common = { workspace = true, default-features = true } env_logger = { workspace = true } +paste = "1.0.15" rand = { workspace = true } diff --git a/test-utils/src/array_gen/mod.rs b/test-utils/src/array_gen/mod.rs new file mode 100644 index 000000000000..4a799ae737d7 --- /dev/null +++ b/test-utils/src/array_gen/mod.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod primitive; +mod string; + +pub use primitive::PrimitiveArrayGenerator; +pub use string::StringArrayGenerator; diff --git a/test-utils/src/array_gen/primitive.rs b/test-utils/src/array_gen/primitive.rs new file mode 100644 index 000000000000..0581862d63bd --- /dev/null +++ b/test-utils/src/array_gen/primitive.rs @@ -0,0 +1,126 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, ArrowPrimitiveType, PrimitiveArray, UInt32Array}; +use arrow::datatypes::DataType; +use rand::distributions::Standard; +use rand::prelude::Distribution; +use rand::rngs::StdRng; +use rand::Rng; + +/// Trait for converting type safely from a native type T impl this trait. +pub trait FromNative: std::fmt::Debug + Send + Sync + Copy + Default { + /// Convert native type from i64. + fn from_i64(_: i64) -> Option { + None + } +} + +macro_rules! native_type { + ($t: ty $(, $from:ident)*) => { + impl FromNative for $t { + $( + #[inline] + fn $from(v: $t) -> Option { + Some(v) + } + )* + } + }; +} + +native_type!(i8); +native_type!(i16); +native_type!(i32); +native_type!(i64, from_i64); +native_type!(u8); +native_type!(u16); +native_type!(u32); +native_type!(u64); +native_type!(f32); +native_type!(f64); + +/// Randomly generate primitive array +pub struct PrimitiveArrayGenerator { + /// the total number of strings in the output + pub num_primitives: usize, + /// The number of distinct strings in the columns + pub num_distinct_primitives: usize, + /// The percentage of nulls in the columns + pub null_pct: f64, + /// Random number generator + pub rng: StdRng, +} + +// TODO: support generating more primitive arrays +impl PrimitiveArrayGenerator { + pub fn gen_data(&mut self) -> ArrayRef + where + A: ArrowPrimitiveType, + A::Native: FromNative, + Standard: Distribution<::Native>, + { + // table of primitives from which to draw + let distinct_primitives: PrimitiveArray = (0..self.num_distinct_primitives) + .map(|_| { + Some(match A::DATA_TYPE { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 + | DataType::Date32 => self.rng.gen::(), + + DataType::Date64 => { + // TODO: constrain this range to valid dates if necessary + let date_value = self.rng.gen_range(i64::MIN..=i64::MAX); + let millis_per_day = 86_400_000; + let adjusted_value = date_value - (date_value % millis_per_day); + A::Native::from_i64(adjusted_value).unwrap() + } + + _ => { + let arrow_type = A::DATA_TYPE; + panic!("Unsupported arrow data type: {arrow_type}") + } + }) + }) + .collect(); + + // pick num_primitves randomly from the distinct string table + let indicies: UInt32Array = (0..self.num_primitives) + .map(|_| { + if self.rng.gen::() < self.null_pct { + None + } else if self.num_distinct_primitives > 1 { + let range = 1..(self.num_distinct_primitives as u32); + Some(self.rng.gen_range(range)) + } else { + Some(0) + } + }) + .collect(); + + let options = None; + arrow::compute::take(&distinct_primitives, &indicies, options).unwrap() + } +} diff --git a/test-utils/src/array_gen/string.rs b/test-utils/src/array_gen/string.rs new file mode 100644 index 000000000000..fbfa2bb941e0 --- /dev/null +++ b/test-utils/src/array_gen/string.rs @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, UInt32Array}; +use rand::rngs::StdRng; +use rand::Rng; + +/// Randomly generate string arrays +pub struct StringArrayGenerator { + //// The maximum length of the strings + pub max_len: usize, + /// the total number of strings in the output + pub num_strings: usize, + /// The number of distinct strings in the columns + pub num_distinct_strings: usize, + /// The percentage of nulls in the columns + pub null_pct: f64, + /// Random number generator + pub rng: StdRng, +} + +impl StringArrayGenerator { + /// Creates a StringArray or LargeStringArray with random strings according + /// to the parameters of the BatchGenerator + pub fn gen_data(&mut self) -> ArrayRef { + // table of strings from which to draw + let distinct_strings: GenericStringArray = (0..self.num_distinct_strings) + .map(|_| Some(random_string(&mut self.rng, self.max_len))) + .collect(); + + // pick num_strings randomly from the distinct string table + let indicies: UInt32Array = (0..self.num_strings) + .map(|_| { + if self.rng.gen::() < self.null_pct { + None + } else if self.num_distinct_strings > 1 { + let range = 1..(self.num_distinct_strings as u32); + Some(self.rng.gen_range(range)) + } else { + Some(0) + } + }) + .collect(); + + let options = None; + arrow::compute::take(&distinct_strings, &indicies, options).unwrap() + } +} + +/// Return a string of random characters of length 1..=max_len +fn random_string(rng: &mut StdRng, max_len: usize) -> String { + // pick characters at random (not just ascii) + match max_len { + 0 => "".to_string(), + 1 => String::from(rng.gen::()), + _ => { + let len = rng.gen_range(1..=max_len); + rng.sample_iter::(rand::distributions::Standard) + .take(len) + .map(char::from) + .collect::() + } + } +} diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs index 3ddba2fec800..9db8920833ae 100644 --- a/test-utils/src/lib.rs +++ b/test-utils/src/lib.rs @@ -22,6 +22,7 @@ use datafusion_common::cast::as_int32_array; use rand::prelude::StdRng; use rand::{Rng, SeedableRng}; +pub mod array_gen; mod data_gen; mod string_gen; pub mod tpcds; diff --git a/test-utils/src/string_gen.rs b/test-utils/src/string_gen.rs index 530fc1535387..b598241db1e9 100644 --- a/test-utils/src/string_gen.rs +++ b/test-utils/src/string_gen.rs @@ -1,3 +1,4 @@ +use crate::array_gen::StringArrayGenerator; // Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information @@ -14,27 +15,14 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. -// -// use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, RecordBatch, UInt32Array}; + use crate::stagger_batch; -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, UInt32Array}; use arrow::record_batch::RecordBatch; use rand::rngs::StdRng; use rand::{thread_rng, Rng, SeedableRng}; /// Randomly generate strings -pub struct StringBatchGenerator { - //// The maximum length of the strings - pub max_len: usize, - /// the total number of strings in the output - pub num_strings: usize, - /// The number of distinct strings in the columns - pub num_distinct_strings: usize, - /// The percentage of nulls in the columns - pub null_pct: f64, - /// Random number generator - pub rng: StdRng, -} +pub struct StringBatchGenerator(StringArrayGenerator); impl StringBatchGenerator { /// Make batches of random strings with a random length columns "a" and "b". @@ -44,8 +32,8 @@ impl StringBatchGenerator { pub fn make_input_batches(&mut self) -> Vec { // use a random number generator to pick a random sized output let batch = RecordBatch::try_from_iter(vec![ - ("a", self.gen_data::()), - ("b", self.gen_data::()), + ("a", self.0.gen_data::()), + ("b", self.0.gen_data::()), ]) .unwrap(); stagger_batch(batch) @@ -57,9 +45,9 @@ impl StringBatchGenerator { /// if large is true, the array is a LargeStringArray pub fn make_sorted_input_batches(&mut self, large: bool) -> Vec { let array = if large { - self.gen_data::() + self.0.gen_data::() } else { - self.gen_data::() + self.0.gen_data::() }; let array = arrow::compute::sort(&array, None).unwrap(); @@ -68,39 +56,13 @@ impl StringBatchGenerator { stagger_batch(batch) } - /// Creates a StringArray or LargeStringArray with random strings according - /// to the parameters of the BatchGenerator - fn gen_data(&mut self) -> ArrayRef { - // table of strings from which to draw - let distinct_strings: GenericStringArray = (0..self.num_distinct_strings) - .map(|_| Some(random_string(&mut self.rng, self.max_len))) - .collect(); - - // pick num_strings randomly from the distinct string table - let indicies: UInt32Array = (0..self.num_strings) - .map(|_| { - if self.rng.gen::() < self.null_pct { - None - } else if self.num_distinct_strings > 1 { - let range = 1..(self.num_distinct_strings as u32); - Some(self.rng.gen_range(range)) - } else { - Some(0) - } - }) - .collect(); - - let options = None; - arrow::compute::take(&distinct_strings, &indicies, options).unwrap() - } - /// Return an set of `BatchGenerator`s that cover a range of interesting /// cases pub fn interesting_cases() -> Vec { let mut cases = vec![]; let mut rng = thread_rng(); for null_pct in [0.0, 0.01, 0.1, 0.5] { - for _ in 0..100 { + for _ in 0..10 { // max length of generated strings let max_len = rng.gen_range(1..50); let num_strings = rng.gen_range(1..100); @@ -109,31 +71,15 @@ impl StringBatchGenerator { } else { num_strings }; - cases.push(StringBatchGenerator { + cases.push(StringBatchGenerator(StringArrayGenerator { max_len, num_strings, num_distinct_strings, null_pct, rng: StdRng::from_seed(rng.gen()), - }) + })) } } cases } } - -/// Return a string of random characters of length 1..=max_len -fn random_string(rng: &mut StdRng, max_len: usize) -> String { - // pick characters at random (not just ascii) - match max_len { - 0 => "".to_string(), - 1 => String::from(rng.gen::()), - _ => { - let len = rng.gen_range(1..=max_len); - rng.sample_iter::(rand::distributions::Standard) - .take(len) - .map(char::from) - .collect::() - } - } -}